├── docs ├── load-latency.png └── cpu-sensitivity.png ├── requirements.txt ├── swiftllm ├── utils.py ├── __init__.py ├── worker │ ├── kernels │ │ ├── linear.py │ │ ├── silu_and_mul.py │ │ ├── rotary_emb.py │ │ ├── rmsnorm.py │ │ ├── kvcache_mgmt.py │ │ ├── prefill_attn.py │ │ ├── block_mgmt.py │ │ └── paged_attn.py │ ├── infer_state.py │ ├── buffer.py │ ├── layers │ │ ├── pre_layer.py │ │ └── post_layer.py │ ├── block_swapper.py │ └── weight.py ├── server │ ├── tokenization_engine.py │ ├── api_server.py │ ├── executor.py │ ├── engine.py │ └── block_manager.py ├── model_config.py ├── engine_config.py ├── perfpredictor.py └── structs.py ├── csrc ├── src │ ├── linear.h │ ├── block_swapping.h │ ├── attention.h │ ├── entrypoints.cpp │ ├── small_kernels.h │ ├── linear.cu │ ├── block_swapping.cpp │ └── attention.cu └── setup.py ├── setup.py ├── pacpu ├── build.sh ├── dtype.h ├── CMakeLists.txt ├── pacpu.cpp ├── pacpu.ispc └── core.h ├── evaluation ├── configs │ ├── config-t4-7b.json │ └── config-a10-8b.json ├── api_client.py ├── reproduce-fig6c.py ├── reproduce-fig10a.py ├── benchmark.py ├── server.py └── illustrator.py ├── examples ├── example.txt └── example.py ├── .gitignore ├── README.md └── LICENSE /docs/load-latency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NEO-MLSys25/NEO/HEAD/docs/load-latency.png -------------------------------------------------------------------------------- /docs/cpu-sensitivity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NEO-MLSys25/NEO/HEAD/docs/cpu-sensitivity.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi>=0.111 2 | ray[default]>=2.21 3 | safetensors>=0.4.3 4 | transformers>=4.40 5 | uvicorn>=0.29 6 | vllm_flash_attn>=2.6.1 7 | matplotlib 8 | -------------------------------------------------------------------------------- /swiftllm/utils.py: -------------------------------------------------------------------------------- 1 | def cdiv(a: int, b: int): 2 | return (a + b - 1) // b 3 | 4 | KB = 1024 5 | MB = 1024*1024 6 | GB = 1024*1024*1024 7 | TB = 1024*1024*1024*1024 8 | -------------------------------------------------------------------------------- /csrc/src/linear.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | torch::Tensor linear( 6 | torch::Tensor a, 7 | torch::Tensor w 8 | ); 9 | 10 | void linear_inplace( 11 | torch::Tensor a, 12 | torch::Tensor w, 13 | torch::Tensor r 14 | ); -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | 2 | from setuptools import setup 3 | 4 | setup( 5 | name="SwiftLLM", 6 | version="0.0.1", 7 | author="Shengyu Liu", 8 | description="A tiny yet powerful LLM inference system tailored for researching purpose", 9 | packages=["swiftllm"], 10 | zip_safe=False, 11 | ) 12 | -------------------------------------------------------------------------------- /pacpu/build.sh: -------------------------------------------------------------------------------- 1 | Torch_DIR=$(python -c 'import torch;print(torch.utils.cmake_prefix_path)')/Torch 2 | CUDA_HOST_COMPILER_PATH=$(which g++-11) 3 | CXX_COMPILER_PATH=$(which g++-13) 4 | 5 | mkdir -p build 6 | 
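# Configure step: $1 selects the model and $2 the tensor-parallel degree. CMakeLists.txt
# forwards them as -DModel / -DTP, which pick the matching per-model constants in dtype.h and
# name the resulting shared library libpacpu-<model>-tp<N>.so (the name the evaluation
# configs reference). Illustrative invocation: bash build.sh llama2_7b 1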
cmake -B build -S . -DTorch_DIR=$Torch_DIR -DModel=$1 -DTP=$2 -DCMAKE_CUDA_HOST_COMPILER=${CUDA_HOST_COMPILER_PATH} -DCMAKE_CXX_COMPILER=${CXX_COMPILER_PATH} 7 | cmake --build build 8 | -------------------------------------------------------------------------------- /csrc/src/block_swapping.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | void swap_blocks( 9 | const std::vector &source_block_ids, 10 | const std::vector &target_block_ids, 11 | const bool is_swap_out, 12 | const int gpu_layer, 13 | const int cpu_layer, 14 | 15 | torch::Tensor k_cache, 16 | torch::Tensor v_cache, 17 | torch::Tensor k_swap, 18 | torch::Tensor v_swap 19 | ); 20 | -------------------------------------------------------------------------------- /evaluation/configs/config-t4-7b.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": "Llama-2-7b-hf", 3 | "model_path": "/home/ubuntu/weights/Llama-2-7b-hf", 4 | "num_layers": 32, 5 | "block_size": 16, 6 | "max_model_len": 832, 7 | "max_num_seqs": 512, 8 | "max_num_batched_tokens": 832, 9 | "tensor_parallel_size": 1, 10 | "gpu_memory_utilization": 0.99, 11 | "num_gpu_blocks_override": 54, 12 | "swap_space": 20, 13 | "library": "libpacpu-llama2_7b-tp1.so" 14 | } -------------------------------------------------------------------------------- /evaluation/configs/config-a10-8b.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": "Llama-3-8B", 3 | "model_path": "/home/ubuntu/weights/Llama-3-8B", 4 | "num_layers": 32, 5 | "block_size": 16, 6 | "max_model_len": 20000, 7 | "max_num_seqs": 1024, 8 | "max_num_batched_tokens": 20480, 9 | "tensor_parallel_size": 1, 10 | "gpu_memory_utilization": 0.99, 11 | "num_gpu_blocks_override": 1650, 12 | "swap_space": 120, 13 | "library": "libpacpu-llama3_8b-tp1.so" 14 | } -------------------------------------------------------------------------------- /swiftllm/__init__.py: -------------------------------------------------------------------------------- 1 | # Config class for the engine 2 | from swiftllm.engine_config import EngineConfig 3 | 4 | # The Engine & RawRequest for online serving 5 | from swiftllm.server.engine import Engine, AsyncEngine 6 | from swiftllm.structs import RawRequest 7 | 8 | # The Model for offline inference 9 | from swiftllm.worker.model import LlamaModel, ModelPerfResult 10 | from swiftllm.structs import create_request, SubBatch 11 | 12 | # The Profiler 13 | from swiftllm.server.profiler import ModelProfiler 14 | -------------------------------------------------------------------------------- /swiftllm/worker/kernels/linear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def linear( 4 | a: torch.Tensor, # [a, b] 5 | w: torch.Tensor # [c, b] 6 | ) -> torch.Tensor: # [a, c] 7 | # pylint: disable=not-callable 8 | # NOTE. It seems that torch.nn.functional.linear automatically select 9 | # the best implementation for the given input shapes (GEMM or GEMV) while 10 | # torch.matmul always uses GEMM. So, we use torch.nn.functional.linear here 11 | # to get the best performance. 
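    # Shape note: `a` is [M, K] activations and `w` is an [N, K] row-major weight, so
    # F.linear(a, w) computes a @ w.T -> [M, N] with no bias term -- the same contraction
    # (einsum("ik,jk->ij")) that the cuBLAS path in csrc/src/linear.cu performs.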
12 | return torch.nn.functional.linear(a, w) 13 | -------------------------------------------------------------------------------- /csrc/src/attention.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | void paged_attention( 9 | torch::Tensor q, 10 | torch::Tensor k, 11 | torch::Tensor v, 12 | torch::Tensor o, 13 | torch::Tensor kcache, 14 | torch::Tensor vcache, 15 | float softmax_scale, 16 | torch::Tensor block_table, 17 | torch::Tensor seq_ids, 18 | torch::Tensor seq_lens, 19 | const int64_t cur_layer, 20 | const int64_t seq_block_size, 21 | const int64_t num_seq_blocks 22 | ); -------------------------------------------------------------------------------- /swiftllm/server/tokenization_engine.py: -------------------------------------------------------------------------------- 1 | import ray 2 | from transformers import AutoTokenizer 3 | 4 | from swiftllm.engine_config import EngineConfig 5 | 6 | @ray.remote 7 | class TokenizationEngine: 8 | def __init__(self, engine_config: EngineConfig): 9 | self.tokenizer = AutoTokenizer.from_pretrained(engine_config.model_path) 10 | 11 | def batched_tokenize(self, prompts: list[str]) -> list[list[int]]: 12 | prompt_token_ids = self.tokenizer(prompts, return_attention_mask=False)['input_ids'] 13 | return prompt_token_ids 14 | -------------------------------------------------------------------------------- /csrc/src/entrypoints.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "block_swapping.h" 4 | #include "small_kernels.h" 5 | #include "attention.h" 6 | #include "linear.h" 7 | 8 | PYBIND11_MODULE(swiftllm_c, m) { 9 | m.def("swap_blocks", &swap_blocks); 10 | 11 | m.def("fused_add_rmsnorm_inplace", &fused_add_rmsnorm_inplace); 12 | m.def("silu_and_mul_inplace", &silu_and_mul_inplace); 13 | m.def("rotary_embedding_inplace", &rotary_embedding_inplace); 14 | m.def("store_kvcache", &store_kvcache); 15 | m.def("embedding", &embedding); 16 | 17 | // m.def("paged_attention", &paged_attention); 18 | 19 | m.def("linear", &linear); 20 | } 21 | -------------------------------------------------------------------------------- /examples/example.txt: -------------------------------------------------------------------------------- 1 | Rome had begun expanding shortly after the founding of the Roman Republic in the 6th century BC, though not outside the Italian Peninsula until the 3rd century BC. Thus, it was an "empire" (a great power) long before it had an emperor.[21] The Republic was not a nation-state in the modern sense, but a network of self-ruled towns (with varying degrees of independence from the Senate) and provinces administered by military commanders. It was governed by annually elected magistrates (Roman consuls above all) in conjunction with the Senate.[22] The 1st century BC was a time of political and military upheaval, which ultimately led to rule by emperors.[23][24][25] The consuls' military power rested in the Roman legal concept of imperium, meaning "command" (typically in a military sense).[26] Occasionally, successful consuls or generals were given the honorary title imperator (commander); this is the origin of the word emperor, since this title was always bestowed upon the emperor. 
2 | The Roman -------------------------------------------------------------------------------- /csrc/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils import cpp_extension 4 | 5 | __version__ = "0.0.1" 6 | 7 | ext_modules = [ 8 | cpp_extension.CUDAExtension( 9 | "swiftllm_c", 10 | [ 11 | "src/entrypoints.cpp", 12 | "src/block_swapping.cpp", 13 | "src/small_kernels.cu", 14 | # "src/attention.cu", 15 | "src/linear.cu", 16 | ], 17 | extra_compile_args={ 18 | 'cxx': ['-O3'], 19 | 'nvcc': ['-O3', '--use_fast_math'] 20 | } 21 | ), 22 | ] 23 | 24 | setup( 25 | name="swiftllm_c", 26 | version=__version__, 27 | author="Shengyu Liu", 28 | author_email="shengyu.liu@stu.pku.edu.cn", 29 | url="", 30 | description="Some C++/CUDA sources for SwiftLLM.", 31 | long_description="", 32 | ext_modules=ext_modules, 33 | cmdclass={ 34 | 'build_ext': cpp_extension.BuildExtension 35 | }, 36 | zip_safe=False, 37 | python_requires=">=3.9", 38 | ) 39 | -------------------------------------------------------------------------------- /evaluation/api_client.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import asyncio 4 | from fastapi import HTTPException 5 | import aiohttp 6 | from transformers import AutoTokenizer 7 | 8 | AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) 9 | 10 | async def request_completions( 11 | api_url: str, 12 | prompt: str | list[int], 13 | output_len: int, 14 | model_path: str 15 | ): 16 | async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: 17 | payload = { 18 | "model": model_path, 19 | "prompt": prompt, 20 | "max_tokens": output_len, 21 | "temperature": 0.0, 22 | "ignore_eos": True 23 | } 24 | 25 | async with session.post(url=api_url, json=payload) as response: 26 | if response.status != 200: 27 | raise HTTPException(status_code=response.status, detail=await response.text()) 28 | data = json.loads(await response.text()) 29 | 30 | return data['choices'][0]['text'] 31 | 32 | -------------------------------------------------------------------------------- /csrc/src/small_kernels.h: -------------------------------------------------------------------------------- 1 | # pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | void fused_add_rmsnorm_inplace( 8 | torch::Tensor buffer, 9 | torch::Tensor residual, 10 | torch::Tensor weight, 11 | const float epsilon 12 | ); 13 | 14 | void silu_and_mul_inplace( 15 | torch::Tensor buffer 16 | ); 17 | 18 | void rotary_embedding_inplace( 19 | torch::Tensor q, 20 | torch::Tensor k, 21 | torch::Tensor sin_table, 22 | torch::Tensor cos_table 23 | ); 24 | 25 | void store_kvcache( 26 | torch::Tensor k, 27 | torch::Tensor v, 28 | torch::Tensor k_cache, 29 | torch::Tensor v_cache, 30 | torch::Tensor block_table, 31 | torch::Tensor seq_ids, 32 | torch::Tensor seq_start_locs, 33 | torch::Tensor seq_lens, 34 | const int64_t itm_layer, 35 | const int64_t gpu_layer, 36 | const int64_t num_cprfs, 37 | const int64_t max_pref_len 38 | ); 39 | 40 | void embedding( 41 | torch::Tensor input_tokens, 42 | torch::Tensor weights, 43 | torch::Tensor output, 44 | const int64_t token_offset 45 | ); -------------------------------------------------------------------------------- /swiftllm/worker/infer_state.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import torch 3 | 4 | @dataclasses.dataclass 5 | class 
LlamaInferState: 6 | batch_size: int 7 | num_tokens: int 8 | 9 | gpu_seq_ids: torch.Tensor # [batch_size] 10 | softmax_scale: float # Equal to 1/sqrt(head_dim) 11 | 12 | num_prefill_seqs: int 13 | num_prefill_tokens: int 14 | prefill_seq_lens: torch.Tensor # [batch_size] 15 | prefill_seq_start_locs: torch.Tensor # [batch_size] 16 | prefill_seq_start_locs_with_end: torch.Tensor # [batch_size+1], = prefill_seq_start_locs + [num_prefill_tokens] 17 | max_prefill_len: int 18 | 19 | gpu_num_decoding_seqs: int 20 | gpu_decoding_seq_lens: torch.Tensor # [batch_size] 21 | 22 | cpu_num_decoding_seqs: int 23 | cpu_seq_ids: torch.Tensor 24 | cpu_decoding_seq_lens: torch.Tensor 25 | 26 | @property 27 | def gpu_token_end(self) -> int: 28 | return self.num_tokens - self.cpu_num_decoding_seqs 29 | 30 | seq_block_size: int 31 | num_seq_blocks: int 32 | 33 | position_cos: torch.Tensor # [num_tokens, hidden_size] 34 | position_sin: torch.Tensor # [num_tokens, hidden_size] 35 | 36 | src_block_ids: list[int] 37 | dst_block_ids: list[int] 38 | -------------------------------------------------------------------------------- /swiftllm/worker/buffer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Useful buffers for model forward pass. 3 | """ 4 | 5 | import torch 6 | from swiftllm.model_config import LlamaModelConfig 7 | from swiftllm.engine_config import EngineConfig 8 | from swiftllm.structs import SubBatch 9 | 10 | class ModelForwardBuffers: 11 | """ 12 | Useful buffers for model forward pass. 13 | """ 14 | def __init__( 15 | self, 16 | engine_config: EngineConfig, 17 | model_config: LlamaModelConfig 18 | ): 19 | iter_width = engine_config.max_tokens_in_batch 20 | hidden_size = model_config.hidden_size 21 | ws = model_config.world_size 22 | self.attn_out_buf = torch.zeros((iter_width, hidden_size // ws), dtype=torch.float16, device='cuda') 23 | self.residual_buf = torch.zeros((iter_width, hidden_size), dtype=torch.float16, device='cuda') 24 | self.cur_residual_buf = None 25 | 26 | def alloc_for_batches(self, batches: list[SubBatch]): 27 | """ 28 | Allocate buffers for batches. 
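        Each sub-batch receives a contiguous slice of `attn_out_buf` and `residual_buf` of
        length `batch.iter_width`, handed out in order, and the portion of the residual
        buffer used in this iteration is zero-filled before the forward pass.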
29 | """ 30 | offs = 0 31 | for batch in batches: 32 | batch.attn_out_buf = self.attn_out_buf[offs: offs + batch.iter_width] 33 | batch.residual_buf = self.residual_buf[offs: offs + batch.iter_width] 34 | offs += batch.iter_width 35 | 36 | self.cur_residual_buf = self.residual_buf[:offs] 37 | self.cur_residual_buf.fill_(0.0) 38 | -------------------------------------------------------------------------------- /pacpu/dtype.h: -------------------------------------------------------------------------------- 1 | #if defined(__x86_64__) && !defined(ISPC) 2 | typedef _Float16 __fp16; 3 | #endif 4 | #ifdef ISPC 5 | typedef float16 data_t; 6 | #else 7 | typedef __fp16 data_t; 8 | #endif 9 | // #if defined(ISPC_TARGET_AVX2) || defined(__AVX2__) 10 | // typedef float itmd_t; 11 | // #else 12 | // typedef data_t itmd_t; 13 | // #endif 14 | typedef float itmd_t; 15 | typedef float otpt_t; 16 | 17 | #define GEMM cblas_hgemm 18 | #define HEAD_DIM 128 // Constant for all models 19 | #define BLOCK_SIZE 16 20 | 21 | #if defined(LLAMA3_8B) 22 | #define NUM_LAYERS 32 23 | #define NUM_Q_HEADS (32 / TP_DEGREE) 24 | #define NUM_KV_HEADS (8 / TP_DEGREE) 25 | #elif defined(LLAMA2_7B) 26 | #define NUM_LAYERS 32 27 | #define NUM_Q_HEADS (32 / TP_DEGREE) 28 | #define NUM_KV_HEADS (32 / TP_DEGREE) 29 | #elif defined(LLAMA2_13B) 30 | #define NUM_LAYERS 40 31 | #define NUM_Q_HEADS (40 / TP_DEGREE) 32 | #define NUM_KV_HEADS (40 / TP_DEGREE) 33 | #elif defined(LLAMA2_70B) || defined(LLAMA3_70B) 34 | #define NUM_LAYERS 80 35 | #define NUM_Q_HEADS (64 / TP_DEGREE) 36 | #define NUM_KV_HEADS (8 / TP_DEGREE) 37 | #else 38 | #error "Please define the model" 39 | #endif 40 | 41 | #define QH_PER_KVH (NUM_Q_HEADS / NUM_KV_HEADS) 42 | #define BLOCK_NELEM (NUM_KV_HEADS * BLOCK_SIZE * HEAD_DIM) 43 | 44 | #define MAX_BATCH_SIZE 4096 45 | #define MAX_WS 256 46 | #define MAX_TOK_NUM 1048576 // Maxinum number of token's KV to be scanned in one iteration 47 | #define MAX_TASK_NUM (MAX_BATCH_SIZE + MAX_WS) -------------------------------------------------------------------------------- /swiftllm/worker/layers/pre_layer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Pre-layer of the model. 3 | """ 4 | # pylint: disable=no-name-in-module 5 | 6 | import torch 7 | 8 | from swiftllm_c import embedding 9 | 10 | from swiftllm.model_config import LlamaModelConfig 11 | from swiftllm.worker.weight import LlamaWeight 12 | 13 | 14 | class LlamaPreLayer: 15 | """ 16 | Pre-layer of the model. 17 | """ 18 | def __init__( 19 | self, 20 | model_config: LlamaModelConfig, 21 | weights: LlamaWeight, 22 | ): 23 | self.model_config = model_config 24 | self.weights = weights 25 | seg_len = model_config.vocab_size // model_config.world_size 26 | self.token_offs = seg_len * model_config.rank 27 | 28 | def forward( 29 | self, 30 | input_ids: list[int] 31 | ) -> torch.Tensor: 32 | """ 33 | Forward pass of the pre-layer. 34 | 35 | Each shard of the model is responsible for a segment of the vocabulary, and only sets the embeddings 36 | for the tokens in its segment. For tokens outside of its segment, the embeddings are set to zeros. 37 | Then all the embeddings would be reduced across all the shards by the first transformer layer. 
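        For example (illustrative numbers): with vocab_size = 32000 and world_size = 2,
        rank 0 owns token ids [0, 16000) and rank 1 owns ids [16000, 32000); a token id
        outside the local segment keeps an all-zero embedding row here and receives its
        value from the owning shard during that later reduction.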
38 | """ 39 | input_gpu = torch.tensor(input_ids, dtype=torch.int32, device='cuda') 40 | embeddings = torch.zeros((len(input_ids), self.model_config.hidden_size), dtype=torch.float16, device='cuda') 41 | embedding(input_gpu, self.weights.wte, embeddings, self.token_offs) 42 | return embeddings 43 | -------------------------------------------------------------------------------- /swiftllm/worker/kernels/silu_and_mul.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | import swiftllm_c 5 | 6 | @triton.jit 7 | def _fwd_silu_and_mul( 8 | x: torch.Tensor, # [num_tokens, 2*ffn_inter_dim]. Result will be stored at input[:, :ffn_inter_dim] 9 | ffn_inter_dim: tl.constexpr, 10 | block_size: tl.constexpr 11 | ): 12 | # grid shape: [num_tokens, ffn_inter_dim / block_size] 13 | # require ffn_inter_dim % block_size == 0 14 | my_token_id = tl.program_id(0).to(tl.int64) 15 | my_block_id = tl.program_id(1) 16 | 17 | offs = my_token_id*(2*ffn_inter_dim) + my_block_id*block_size + tl.arange(0, block_size) 18 | gate = tl.load(x + (offs+ffn_inter_dim)) 19 | gate = gate.to(tl.float32) 20 | gate = gate / (1 + tl.exp(-gate)) 21 | gate = gate.to(tl.float16) 22 | up = tl.load(x + offs) 23 | down = up * gate 24 | tl.store(x + offs, down) 25 | 26 | def silu_and_mul_inplace( 27 | x: torch.Tensor # [num_tokens, 2*ffn_inter_dim] 28 | ): 29 | assert x.is_contiguous() 30 | num_tokens = x.shape[0] 31 | ffn_inter_dim = x.shape[1] // 2 32 | 33 | block_size = 256 34 | assert ffn_inter_dim % block_size == 0 35 | _fwd_silu_and_mul[(num_tokens, ffn_inter_dim//block_size)](x, ffn_inter_dim, block_size) 36 | 37 | if __name__ == "__main__": 38 | x = torch.randn(10, 512, dtype=torch.float16, device="cuda") 39 | y = x.clone() 40 | silu_and_mul_inplace(x) 41 | swiftllm_c.silu_and_mul_inplace(y) 42 | 43 | print(x) 44 | print(y) 45 | 46 | assert torch.allclose(x, y, atol=1e-5) 47 | -------------------------------------------------------------------------------- /evaluation/reproduce-fig6c.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reproduction script for Figure 6c in the paper. 3 | 4 | Please note that the input data for this script is only a small subset (100 requests) of the original data (2000 requests) used 5 | in the paper. This is for the purpose of demonstration and quick verification of the results. As a result, the latency would be 6 | lower than the original figure due to less average queuing latency. 7 | 8 | The original data can be generated by modifying the parameters in the script. 9 | """ 10 | 11 | import asyncio 12 | import json 13 | import os 14 | 15 | from server import start_server, stop_server 16 | from benchmark import run_test, prepare_real_test 17 | from illustrator import draw_one_rl_diagram 18 | 19 | # Tweak hyperparameters here: 20 | 21 | vllm_rates = [0.2, 0.4, 0.5, 0.6] 22 | ours_rates = [0.5, 1.5, 2.5, 3.1, 3.5, 3.7, 3.9] 23 | # Rates of requests per second, reduce the number of elements in the list to speed up the evaluation process. 
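# (Each rate listed above becomes one data point in the reproduced figure; the same lists
# are passed as `rate_lists` to `draw_one_rl_diagram` at the bottom of this script.)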
24 | 25 | 26 | cur_dir = os.path.dirname(os.path.realpath(__file__)) 27 | with open(f"{cur_dir}/configs/config-t4-7b.json", "r") as f: 28 | config = json.load(f) 29 | 30 | 31 | async def one_round(server_name: str): 32 | start_server(server_name, config) 33 | try: 34 | # Change the rate argument (in reqs/s) to other values to see more results 35 | # Feel free to comment out some of the following lines to reduce running time 36 | if server_name == "ours": 37 | for rate in ours_rates: 38 | await run_test(*prepare_real_test("osc", config, server_name), rate=rate) 39 | if server_name == "vllm": 40 | for rate in vllm_rates: 41 | await run_test(*prepare_real_test("osc", config, server_name), rate=rate) 42 | finally: 43 | stop_server() 44 | await asyncio.sleep(5) 45 | 46 | 47 | async def main(): 48 | await one_round("vllm") 49 | await one_round("ours") 50 | 51 | 52 | if __name__ == "__main__": 53 | asyncio.run(main()) 54 | draw_one_rl_diagram( 55 | title="fig6c", 56 | data_name="osc", 57 | sys_file_names=["vllm", "ours"], 58 | sys_legend_names=["VLLM", "Ours"], 59 | rate_lists=[vllm_rates, ours_rates], 60 | ylim=2, 61 | markers=["o", "x"], 62 | set_ylabel=True 63 | ) -------------------------------------------------------------------------------- /evaluation/reproduce-fig10a.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reproduction script for Figure 10a in the paper. 3 | 4 | Please note that this script should be run on g5.x16large instances, and only the corresponding line 5 | is drawn in the figure. 6 | 7 | The full test takes a long time (~5h) to run. You can reduce the number of requests or number of data 8 | points to speed up the evaluation process. 9 | """ 10 | 11 | import asyncio 12 | import json 13 | import os 14 | 15 | from server import start_server, stop_server 16 | from benchmark import run_test, prepare_mock_test 17 | from illustrator import draw_one_ps_diagram 18 | 19 | 20 | # Tweak hyperparameters here: 21 | 22 | num_data = 2000 23 | # Total number of requests sent to the serving engine; reduce this number to speed up the evaluation process. 24 | # However, the result may not be as accurate as the original one due to warm-up and cool-down effects. It is 25 | # not recommended to set this number below 800. 26 | 27 | input_len = 1000 28 | # Length of the input sequence; please keep it at 1000 to reproduce the original result. 29 | 30 | output_lens = [50, 100, 200, 300, 400] 31 | # Length of the output sequence; reduce the number of elements in the list to speed up the evaluation process. 32 | 33 | 34 | cur_dir = os.path.dirname(os.path.realpath(__file__)) 35 | with open(f"{cur_dir}/configs/config-a10-8b.json", "r") as f: 36 | config = json.load(f) 37 | 38 | 39 | async def one_round(server_name: str): 40 | start_server(server_name, config) 41 | try: 42 | for output_len in output_lens: 43 | await run_test(*prepare_mock_test(num_data, input_len, output_len, server_name, config)) 44 | finally: 45 | stop_server() 46 | await asyncio.sleep(5) 47 | 48 | 49 | async def main(): 50 | await one_round("base") 51 | await one_round("ours") 52 | 53 | 54 | if __name__ == "__main__": 55 | asyncio.run(main()) 56 | draw_one_ps_diagram( 57 | title="fig10a", 58 | base_sys_name="base", 59 | interv=[0.3, 0.7], # The interval used for calculating throughput; we ignore the first 30% and last 30% of the data in order to avoid warm-up and cool-down effects.
60 | num_datas=[num_data], 61 | sys_file_names=["ours"], 62 | legend_names=["x16large"], 63 | input_lens=[input_len], 64 | output_lens=output_lens, 65 | markers=["x"], 66 | show_ylabels=True, 67 | show_legend=True 68 | ) 69 | -------------------------------------------------------------------------------- /pacpu/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.18 FATAL_ERROR) 2 | project(PagedAttentionCPU C CXX) 3 | 4 | # Meta information 5 | set(TARGET_NAME_PREFIX pacpu) 6 | set(ISPC_SRC_NAME pacpu) 7 | # set(ISPC_TARGETS "avx2") 8 | # set(ISPC_TARGETS "avx512spr-x16") 9 | set(ISPC_FLAGS "-O3" "--opt=fast-math") 10 | set(TARGET_SOURCES "pacpu.cpp" "pacpu.ispc") 11 | 12 | set(ISPC_ARCH "x86-64") 13 | set(ISPC_ARCH_BIT "64") 14 | 15 | # Set a small number here, we don't actually compile any CUDA kernels 16 | set(CMAKE_CUDA_ARCHITECTURES 70) 17 | 18 | enable_language(ISPC) 19 | set(GEN_TORCH_LIBRARY TRUE) 20 | 21 | function(target_add_common_options target model_name tp_degree) 22 | set_property(TARGET ${target} PROPERTY CXX_STANDARD 17) 23 | set_property(TARGET ${target} PROPERTY POSITION_INDEPENDENT_CODE ON) 24 | set_property(TARGET ${target} PROPERTY ISPC_INSTRUCTION_SETS "${ISPC_TARGETS}") 25 | target_compile_options(${target} PRIVATE $<$:${ISPC_FLAGS}>) 26 | target_compile_options(${target} PRIVATE $<$:--arch=${ISPC_ARCH}>) 27 | 28 | set(arch_flag "-m${ISPC_ARCH_BIT}") 29 | target_compile_options(${target} PRIVATE $<$:-Ofast -march=native>) 30 | target_compile_options(${target} PRIVATE $<$:-fopenmp>) 31 | target_compile_options(${target} PRIVATE $<$:${arch_flag}>) 32 | string(TOUPPER ${model_name} MODEL_NAME) 33 | string(TOUPPER ${tp_degree} TP_DEGREE) 34 | target_compile_definitions(${target} PRIVATE "${MODEL_NAME}" TP_DEGREE=${TP_DEGREE}) 35 | 36 | find_package(OpenMP REQUIRED) 37 | target_link_libraries(${target} PRIVATE OpenMP::OpenMP_CXX) 38 | endfunction() 39 | 40 | function(gen_torch_lib model_name tp_degree) 41 | set(TARGET_NAME "${TARGET_NAME_PREFIX}-${model_name}-tp${tp_degree}") 42 | 43 | add_library(${TARGET_NAME} SHARED) 44 | target_sources(${TARGET_NAME} PRIVATE ${TARGET_SOURCES}) 45 | target_add_common_options(${TARGET_NAME} ${model_name} ${tp_degree}) 46 | target_compile_options(${TARGET_NAME} PRIVATE $<$:${TORCH_CXX_FLAGS}>) 47 | target_link_libraries(${TARGET_NAME} PRIVATE "${TORCH_LIBRARIES}") 48 | endfunction() 49 | 50 | if (GEN_TORCH_LIBRARY) 51 | find_package(Torch REQUIRED) 52 | # set(CMAKE_CXX_FLAGS "-Ofast -march=native ${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") 53 | 54 | # Create library, Model and TP should be pass as arguments 55 | gen_torch_lib(${Model} ${TP}) 56 | endif() 57 | -------------------------------------------------------------------------------- /csrc/src/linear.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "linear.h" 4 | 5 | bool has_handle = false; 6 | cublasHandle_t handle; 7 | 8 | void cublas_init_handle(){ 9 | cublasStatus_t status = cublasCreate(&handle); 10 | if (status != CUBLAS_STATUS_SUCCESS) { 11 | std::cerr << "cublas_init_handle failed: " << status << std::endl; 12 | throw std::runtime_error("cublas_init_handle failed"); 13 | } 14 | } 15 | 16 | /* Row major 17 | * A (m x k) einsum(ik, jk -> ij) B (n x k) = C (m x n) 18 | * Equivalent to column major 19 | * B (n x k) @ A^T (k x m) = C^T (m x n) 20 | */ 21 | void array_linear( 22 | int m, 23 | int n, 24 | int k, 25 | const half* 
Aarray, 26 | const half* Barray, 27 | half* Carray 28 | ) { 29 | const float alpha = 1.0; 30 | const float beta = 0.0; 31 | cublasStatus_t status = cublasSgemmEx( 32 | handle, 33 | CUBLAS_OP_T, 34 | CUBLAS_OP_N, 35 | n, 36 | m, 37 | k, 38 | &alpha, 39 | Barray, 40 | CUDA_R_16F, 41 | k, 42 | Aarray, 43 | CUDA_R_16F, 44 | k, 45 | &beta, 46 | Carray, 47 | CUDA_R_16F, 48 | n 49 | ); 50 | if (status != CUBLAS_STATUS_SUCCESS) { 51 | std::cerr << "cublasGemmEx failed: " << status << std::endl; 52 | throw std::runtime_error("cublasGemmEx failed"); 53 | } 54 | } 55 | 56 | torch::Tensor linear( 57 | torch::Tensor a, 58 | torch::Tensor w 59 | ) { 60 | int m = a.size(0); 61 | int k = a.size(1); 62 | int n = w.size(0); 63 | auto r = torch::empty({m, n}, a.options()); 64 | 65 | const half* a_data = (half*)a.data_ptr(); 66 | const half* w_data = (half*)w.data_ptr(); 67 | half* r_data = (half*)r.data_ptr(); 68 | 69 | if (!has_handle) { 70 | cublas_init_handle(); 71 | has_handle = true; 72 | } 73 | 74 | array_linear( 75 | m, 76 | n, 77 | k, 78 | a_data, 79 | w_data, 80 | r_data 81 | ); 82 | 83 | return r; 84 | } 85 | 86 | void linear_inplace( 87 | torch::Tensor a, 88 | torch::Tensor w, 89 | torch::Tensor r 90 | ) { 91 | int m = a.size(0); 92 | int k = a.size(1); 93 | int n = w.size(0); 94 | 95 | const half* a_data = (half*)a.data_ptr(); 96 | const half* w_data = (half*)w.data_ptr(); 97 | half* r_data = (half*)r.data_ptr(); 98 | 99 | if (!has_handle) { 100 | cublas_init_handle(); 101 | has_handle = true; 102 | } 103 | 104 | array_linear( 105 | m, 106 | n, 107 | k, 108 | a_data, 109 | w_data, 110 | r_data 111 | ); 112 | } -------------------------------------------------------------------------------- /swiftllm/worker/layers/post_layer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Post layer of the model. 3 | """ 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | # pylint: disable=no-name-in-module 9 | from swiftllm_c import fused_add_rmsnorm_inplace 10 | 11 | from swiftllm.model_config import LlamaModelConfig 12 | from swiftllm.worker.weight import LlamaWeight 13 | # from swiftllm.worker.kernels.rmsnorm import rmsnorm_inplace 14 | from swiftllm.worker.kernels.linear import linear 15 | from swiftllm.structs import SubBatch 16 | 17 | class LlamaPostLayer: 18 | """ 19 | Post layer of the model. 20 | """ 21 | def __init__( 22 | self, 23 | model_config: LlamaModelConfig, 24 | weights: LlamaWeight, 25 | ): 26 | self.model_config = model_config 27 | self.weights = weights 28 | 29 | 30 | def forward( 31 | self, 32 | batches: list[SubBatch], 33 | input_embeds: torch.Tensor, # [num_total_tokens, hidden_size] 34 | residual_buf: torch.Tensor # [num_total_tokens, hidden_size] 35 | ) -> list[int]: 36 | """ 37 | Forward pass of the post layer. 
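        Concretely: the last-token hidden states of all sub-batches are gathered, all-reduced
        across tensor-parallel shards when world_size > 1, passed through the final fused
        add + RMSNorm, and projected with `lm_head`. Since each shard holds a slice of the
        vocabulary, the logits are gathered on rank 0, which returns the argmax token ids;
        all other ranks return an empty list.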
38 | """ 39 | offs = 0 40 | for batch in batches: 41 | last_token_indices = torch.cat( 42 | ( 43 | last_token_indices, 44 | batch.last_token_indices + offs 45 | ), dim=0 46 | ) if offs else batch.last_token_indices 47 | offs += batch.iter_width 48 | 49 | input_embeds = input_embeds[last_token_indices, :] 50 | residual_buf = residual_buf[last_token_indices, :] 51 | 52 | if self.model_config.world_size > 1: 53 | dist.all_reduce(input_embeds) 54 | 55 | fused_add_rmsnorm_inplace( 56 | input_embeds, 57 | residual_buf, 58 | self.weights.final_norm, 59 | self.model_config.rms_norm_eps 60 | ) 61 | 62 | logits = linear(input_embeds, self.weights.lm_head) # [batch_size, vocab_size] 63 | if self.model_config.world_size > 1: 64 | gather_list = [torch.zeros_like(logits) for _ in range(self.model_config.world_size)] \ 65 | if self.model_config.rank == 0 else None 66 | dist.gather(logits, gather_list) # only rank 0 will have the final logits 67 | if self.model_config.rank == 0: 68 | logits = torch.cat(gather_list, dim=1) 69 | return torch.argmax(logits, dim=1).tolist() if self.model_config.rank == 0 else [] 70 | -------------------------------------------------------------------------------- /swiftllm/model_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | 5 | class LlamaModelConfig: 6 | """ 7 | The configuration of a LLaMA model (including LLaMA 1/2/3). 8 | """ 9 | 10 | def __init__( 11 | self, 12 | model_config: dict 13 | ): 14 | """ 15 | Initialize a LLaMA model configuration from a dict, which should be generated 16 | from a huggingface transformers config.json file. 17 | """ 18 | 19 | assert model_config["model_type"] == "llama" 20 | self.num_layers = model_config["num_hidden_layers"] 21 | self.num_q_heads = model_config["num_attention_heads"] 22 | self.num_kv_heads = model_config.get("num_key_value_heads", self.num_q_heads) 23 | self.hidden_size = model_config["hidden_size"] 24 | self.head_dim = self.hidden_size // self.num_q_heads 25 | self.vocab_size = model_config["vocab_size"] 26 | self.max_position_embeddings = model_config["max_position_embeddings"] 27 | self.ffn_inter_dim = model_config["intermediate_size"] 28 | self.rotary_base = model_config.get("rope_theta", model_config.get("rotary_base", 10000)) 29 | self.rms_norm_eps = model_config["rms_norm_eps"] 30 | self.rope_theta = model_config.get("rope_theta", 10000) 31 | rope_scaling = model_config.get("rope_scaling", None) 32 | if rope_scaling is None: 33 | self.rope_scaling_factor = 1.0 34 | else: 35 | self.rope_scaling_factor = 1.0 36 | # Here we use 1.0 for simplicity, but in practice it should be determined by the model_config["rope_scaling"] dictionary 37 | assert model_config["hidden_act"] == "silu" 38 | 39 | self.rank = None 40 | self.world_size = None 41 | 42 | def get_kvslot_size(self, extra_layer: bool = False, dtype: torch.dtype = torch.float16) -> int: 43 | """ 44 | Get the size of one kv slot (the kv cache of one token) (in bytes) 45 | """ 46 | return (2 * (self.num_layers + extra_layer) * self.num_kv_heads * self.head_dim) * dtype.itemsize 47 | 48 | @property 49 | def softmax_scale(self) -> float: 50 | """ 51 | Get the scale of the softmax function 52 | """ 53 | return self.head_dim ** -0.5 54 | 55 | @staticmethod 56 | def load_from_model_path(model_path: str) -> "LlamaModelConfig": 57 | with open(os.path.join(model_path, "config.json"), "r", encoding="utf-8") as f: 58 | model_config_dict = json.loads(f.read()) 59 | return 
LlamaModelConfig(model_config_dict) 60 | -------------------------------------------------------------------------------- /swiftllm/server/api_server.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | import os 3 | 4 | import argparse 5 | import asyncio 6 | import fastapi 7 | import uvicorn 8 | 9 | import swiftllm 10 | 11 | TIMEOUT_KEEP_ALIVE = 5 # in seconds 12 | 13 | app = fastapi.FastAPI() 14 | engine = None 15 | 16 | @app.post("/v1/completions") 17 | async def generate(req: fastapi.Request) -> fastapi.Response: 18 | """ 19 | Generate completion for the request. 20 | 21 | The request should be a JSON object with fields that match the `RawRequest` 22 | class plus the following fields: 23 | - `stream`: boolean, whether to stream the output or not 24 | """ 25 | req_dict = await req.json() 26 | raw_request = swiftllm.RawRequest( 27 | prompt = req_dict["prompt"], 28 | max_output_len = req_dict["max_tokens"] 29 | ) 30 | 31 | if req_dict.get("stream", False): 32 | generator = engine.add_request_and_stream(raw_request) 33 | async def wrapper(): 34 | async for step_output in generator: 35 | yield f"{step_output.token_id}\n" 36 | return fastapi.responses.StreamingResponse( 37 | wrapper(), 38 | media_type="text/plain" 39 | ) 40 | else: 41 | # TODO Abort the request when the client disconnects 42 | (_, output_token_ids) = await engine.add_request_and_wait(raw_request) 43 | return fastapi.responses.JSONResponse( 44 | content={"choices": [{"text": output_token_ids}]} 45 | ) 46 | 47 | if __name__ == "__main__": 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument("--host", type=str, default="localhost") 50 | parser.add_argument("--port", type=int, default=8000) 51 | swiftllm.EngineConfig.add_cli_args(parser) 52 | 53 | args = parser.parse_args() 54 | args = vars(args) 55 | 56 | host = args.pop("host") 57 | port = args.pop("port") 58 | engine = swiftllm.AsyncEngine(swiftllm.EngineConfig(**args)) 59 | 60 | uvicorn_config = uvicorn.Config( 61 | app, 62 | host=host, 63 | port=port, 64 | log_level="info", 65 | timeout_keep_alive=TIMEOUT_KEEP_ALIVE 66 | ) 67 | uvicorn_server = uvicorn.Server(uvicorn_config) 68 | 69 | async def main_coroutine(): 70 | await engine.initialize_async() 71 | 72 | uvicorn_task = asyncio.create_task(uvicorn_server.serve()) 73 | engine_task = asyncio.create_task(engine.start_all_event_loops()) 74 | 75 | try: 76 | await engine_task 77 | except: # pylint: disable=broad-except 78 | traceback.print_exc() 79 | uvicorn_task.cancel() 80 | os._exit(1) # Kill myself, or it will print tons of errors. Don't know why. 
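    # Illustrative request once the server is up (field names match the handler above,
    # host/port are the CLI defaults):
    #   curl -X POST http://localhost:8000/v1/completions \
    #        -H "Content-Type: application/json" \
    #        -d '{"prompt": "Hello, world", "max_tokens": 16}'
    # The non-streaming response has the form {"choices": [{"text": [...]}]}, where "text"
    # carries the generated token ids; pass "stream": true to receive token ids line by line.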
81 | 82 | asyncio.run(main_coroutine()) 83 | -------------------------------------------------------------------------------- /swiftllm/worker/kernels/rotary_emb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | import swiftllm_c 5 | 6 | from swiftllm.worker.infer_state import LlamaInferState 7 | 8 | @triton.jit 9 | def _fwd_rotary_embedding( 10 | q: torch.Tensor, # [num_tokens, num_q_heads, head_dim] 11 | k: torch.Tensor, # [num_tokens, num_k_heads, head_dim] 12 | cos_table: torch.Tensor, # [num_tokens, head_dim//2] 13 | sin_table: torch.Tensor, # [num_tokens, head_dim//2] 14 | 15 | num_q_heads: tl.constexpr, 16 | num_kv_heads: tl.constexpr, 17 | gqa_group_size: tl.constexpr, # = num_q_heads / num_kv_heads 18 | head_dim: tl.constexpr 19 | ): 20 | # grid: [num_tokens, num_kv_heads] 21 | my_token_id = tl.program_id(0) 22 | my_kv_head = tl.program_id(1) 23 | 24 | q += my_token_id*num_q_heads*head_dim + my_kv_head*gqa_group_size*head_dim # [gqa_group_size, head_dim] 25 | k += my_token_id*num_kv_heads*head_dim + my_kv_head*head_dim # [head_dim] 26 | 27 | offs0 = tl.arange(0, head_dim//2) 28 | offs1 = tl.arange(head_dim//2, head_dim) 29 | 30 | cos = tl.load(cos_table + my_token_id*(head_dim//2) + offs0) 31 | sin = tl.load(sin_table + my_token_id*(head_dim//2) + offs0) 32 | 33 | offs_q0 = (tl.arange(0, gqa_group_size)*head_dim)[:, None] + offs0[None, :] 34 | offs_q1 = (tl.arange(0, gqa_group_size)*head_dim)[:, None] + offs1[None, :] 35 | q0 = tl.load(q + offs_q0) 36 | q1 = tl.load(q + offs_q1) 37 | tl.store(q + offs_q0, q0*cos - q1*sin) 38 | tl.store(q + offs_q1, q0*sin + q1*cos) 39 | 40 | k0 = tl.load(k + offs0) 41 | k1 = tl.load(k + offs1) 42 | tl.store(k + offs0, k0*cos - k1*sin) 43 | tl.store(k + offs1, k0*sin + k1*cos) 44 | 45 | def rotary_embedding_inplace( 46 | q: torch.Tensor, # [num_tokens, num_q_heads, head_dim] 47 | k: torch.Tensor, # [num_tokens, num_k_heads, head_dim] 48 | sin_table: torch.Tensor, # [num_tokens, head_dim//2] 49 | cos_table: torch.Tensor # [num_tokens, head_dim//2] 50 | ): 51 | num_tokens = q.shape[0] 52 | num_q_heads = q.shape[1] 53 | num_kv_heads = k.shape[1] 54 | head_dim = k.shape[2] 55 | grid = (num_tokens, num_kv_heads) 56 | _fwd_rotary_embedding[grid]( 57 | q, k, 58 | cos_table, sin_table, 59 | num_q_heads, num_kv_heads, num_q_heads//num_kv_heads, head_dim 60 | ) 61 | 62 | if __name__ == '__main__': 63 | q0 = torch.randn(4, 32, 128, dtype=torch.float16, device='cuda') 64 | q1 = q0.clone() 65 | k0 = torch.randn(4, 32, 128, dtype=torch.float16, device='cuda') 66 | k1 = k0.clone() 67 | sin_table = torch.randn(4, 64, dtype=torch.float16, device='cuda') 68 | cos_table = torch.randn(4, 64, dtype=torch.float16, device='cuda') 69 | 70 | rotary_embedding_inplace(q0, k0, sin_table, cos_table) 71 | swiftllm_c.rotary_embedding_inplace(q1, k1, sin_table, cos_table) 72 | print(q0[0,0]) 73 | print(q1[0,0]) 74 | 75 | assert torch.allclose(q0, q1, atol=1e-6), "q0 != q1" 76 | assert torch.allclose(k0, k1, atol=1e-6), "k0 != k1" 77 | -------------------------------------------------------------------------------- /swiftllm/worker/kernels/rmsnorm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | import swiftllm_c 5 | 6 | @triton.jit 7 | def _fwd_rmsnorm( 8 | input_and_output: torch.Tensor, # [num_tokens, hidden_size], contiguous 9 | weight: torch.Tensor, # [hidden_size] 
10 | eps: float, 11 | 12 | hidden_size: tl.constexpr 13 | ): 14 | # grid shape: [num_tokens] 15 | my_token_id = tl.program_id(0) 16 | input_and_output += my_token_id * hidden_size # [hidden_size] 17 | 18 | offs = tl.arange(0, hidden_size) 19 | x = tl.load(input_and_output+offs).to(tl.float32) 20 | variance = tl.sum(x*x, axis=0) / hidden_size 21 | rstd = 1 / tl.sqrt(variance + eps) 22 | 23 | w = tl.load(weight+offs).to(tl.float32) 24 | x = x*rstd*w 25 | tl.store(input_and_output+offs, x.to(tl.float16)) 26 | 27 | def rmsnorm_inplace( 28 | input_and_output: torch.Tensor, # [num_tokens, hidden_size] 29 | weight: torch.Tensor, 30 | eps: float 31 | ): 32 | grid = (input_and_output.shape[0], ) 33 | _fwd_rmsnorm[grid]( 34 | input_and_output, 35 | weight, 36 | eps, 37 | input_and_output.shape[1] 38 | ) 39 | 40 | @triton.jit 41 | def _fwd_fused_add_rmsnorm( 42 | input_and_output: torch.Tensor, # [num_tokens, hidden_size], contiguous 43 | residual_io: torch.Tensor, # [num_tokens, hidden_size], contiguous 44 | weight: torch.Tensor, # [hidden_size] 45 | eps: float, 46 | 47 | hidden_size: tl.constexpr 48 | ): 49 | # grid shape: [num_tokens] 50 | my_token_id = tl.program_id(0) 51 | input_and_output += my_token_id * hidden_size # [hidden_size] 52 | residual_io += my_token_id * hidden_size 53 | 54 | offs = tl.arange(0, hidden_size) 55 | x = tl.load(input_and_output+offs) 56 | r = tl.load(residual_io+offs) 57 | x += r 58 | tl.store(residual_io+offs, x) 59 | 60 | x = x.to(tl.float32) 61 | variance = tl.sum(x*x, axis=0) / hidden_size 62 | rstd = 1 / tl.sqrt(variance + eps) 63 | 64 | w = tl.load(weight+offs).to(tl.float32) 65 | x = x*rstd*w 66 | tl.store(input_and_output+offs, x.to(tl.float16)) 67 | 68 | def fused_add_rmsnorm_inplace( 69 | input_and_output: torch.Tensor, # [num_tokens, hidden_size] 70 | residual_io: torch.Tensor, 71 | weight: torch.Tensor, 72 | eps: float 73 | ): 74 | """ 75 | Perform fused add & rmsnorm 76 | 77 | This function accepts input_and_output (x), residual_io (r), and weight(w) 78 | as inputs, set r = x+r, and x = rms_norm(x+r, w) 79 | """ 80 | assert input_and_output.is_contiguous() 81 | assert residual_io.is_contiguous() 82 | assert weight.is_contiguous() 83 | grid = (input_and_output.shape[0], ) 84 | _fwd_fused_add_rmsnorm[grid]( 85 | input_and_output, 86 | residual_io, 87 | weight, 88 | eps, 89 | input_and_output.shape[1] 90 | ) 91 | 92 | if __name__ == "__main__": 93 | x = torch.randn(10, 128, dtype=torch.float16, device="cuda") 94 | y = x.clone() 95 | r = torch.randn(10, 128, dtype=torch.float16, device="cuda") 96 | s = r.clone() 97 | w = torch.randn(1024, dtype=torch.float16, device="cuda") 98 | eps = 1e-6 99 | 100 | fused_add_rmsnorm_inplace(x, r, w, eps) 101 | swiftllm_c.fused_add_rmsnorm_inplace(y, s, w, eps) 102 | 103 | assert torch.allclose(x, y, atol=1e-5), "Mismatch x, y" 104 | assert torch.allclose(r, s, atol=1e-5), "Mismatch r, s" 105 | -------------------------------------------------------------------------------- /swiftllm/worker/kernels/kvcache_mgmt.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import triton 4 | import triton.language as tl 5 | 6 | from swiftllm.model_config import LlamaModelConfig 7 | from swiftllm.engine_config import EngineConfig 8 | from swiftllm.worker.infer_state import LlamaInferState 9 | from swiftllm.utils import cdiv 10 | 11 | @triton.jit 12 | def _fwd_kvcache_mgmt_prefill_kernel( 13 | k_cache: torch.Tensor, # [num_blocks, num_layers, num_kv_heads, block_size, 
head_dim], contiguous 14 | v_cache: torch.Tensor, # [num_blocks, num_layers, num_kv_heads, block_size, head_dim], contiguous 15 | k: torch.Tensor, # [num_prefill_tokens, num_kv_heads, head_dim], contiguous 16 | v: torch.Tensor, # [num_prefill_tokens, num_kv_heads, head_dim], contiguous 17 | block_table: torch.Tensor, # [*, max_blocks_per_seq], contiguous 18 | seq_ids: torch.Tensor, # [num_prefill_seqs], contiguous 19 | prefill_seq_start_locs: torch.Tensor, # [num_prefill_seqs], contiguous 20 | prefill_seq_lens: torch.Tensor, # [num_prefill_seqs], contiguous 21 | cur_layer: int, 22 | 23 | num_layers: tl.constexpr, 24 | num_kv_heads: tl.constexpr, 25 | block_size: tl.constexpr, 26 | head_dim: tl.constexpr, 27 | max_blocks_per_seq: tl.constexpr, 28 | ): 29 | # grid shape: [num_prefill_seqs, cdiv(max_prefill_len, block_size)] 30 | my_batch_id = tl.program_id(0) 31 | my_block_id = tl.program_id(1) 32 | my_seq_len = tl.load(prefill_seq_lens + my_batch_id) 33 | my_seq_start_loc = tl.load(prefill_seq_start_locs + my_batch_id) 34 | if my_block_id*block_size >= my_seq_len: 35 | return 36 | 37 | my_token_range = tl.arange(0, block_size).to(tl.int64) + my_block_id*block_size + my_seq_start_loc 38 | my_seq_id = tl.load(seq_ids + my_batch_id) 39 | my_block_index = tl.load(block_table + my_seq_id*max_blocks_per_seq + my_block_id).to(tl.int64) 40 | 41 | offs_kv = (my_token_range*num_kv_heads*head_dim).to(tl.int64)[:, None, None] + (tl.arange(0, num_kv_heads)*head_dim)[None, :, None] + tl.arange(0, head_dim)[None, None, :] 42 | offs_kvcache = (my_block_index*num_layers+cur_layer)*num_kv_heads*block_size*head_dim + \ 43 | (tl.arange(0, num_kv_heads)*block_size*head_dim)[None, :, None] + \ 44 | (tl.arange(0, block_size)*head_dim)[:, None, None] + \ 45 | tl.arange(0, head_dim)[None, None, :] 46 | 47 | mask = (my_token_range < my_seq_len + my_seq_start_loc)[:, None, None] 48 | tl.store(k_cache + offs_kvcache, tl.load(k + offs_kv, mask=mask), mask=mask) 49 | tl.store(v_cache + offs_kvcache, tl.load(v + offs_kv, mask=mask), mask=mask) 50 | 51 | def store_kvcache( 52 | k: torch.Tensor, 53 | v: torch.Tensor, 54 | k_cache: torch.Tensor, 55 | v_cache: torch.Tensor, 56 | block_table: torch.Tensor, 57 | seq_ids: torch.Tensor, 58 | prefill_seq_start_locs: torch.Tensor, 59 | prefill_seq_lens: torch.Tensor, 60 | cur_layer: int, 61 | max_prefill_len: int 62 | ): 63 | assert k.is_contiguous() 64 | assert v.is_contiguous() 65 | assert k_cache.is_contiguous() 66 | assert v_cache.is_contiguous() 67 | assert block_table.is_contiguous() 68 | assert seq_ids.is_contiguous() 69 | 70 | num_layers = k_cache.shape[1] 71 | num_kv_heads = k_cache.shape[2] 72 | block_size = k_cache.shape[3] 73 | head_dim = k_cache.shape[4] 74 | block_table_width = block_table.shape[1] 75 | num_prefill_seqs = seq_ids.shape[0] 76 | 77 | grid = (num_prefill_seqs, cdiv(max_prefill_len, block_size)) 78 | _fwd_kvcache_mgmt_prefill_kernel[grid]( 79 | k_cache, v_cache, 80 | k, v, 81 | block_table, 82 | seq_ids, 83 | prefill_seq_start_locs, 84 | prefill_seq_lens, 85 | cur_layer, 86 | num_layers, num_kv_heads, block_size, head_dim, block_table_width 87 | ) 88 | -------------------------------------------------------------------------------- /swiftllm/server/executor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model executor classes. 3 | 4 | Provides control plane APIs for the engine. Calls the data plane APIs under the hood. 
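Two implementations follow: SingleProcExecutor, which runs a single LlamaModel in-process
and only supports tensor parallelism degree 1, and RayExecutor, which launches one
RemoteLlamaModel Ray actor per tensor-parallel rank.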
5 | """ 6 | 7 | import os 8 | from abc import ABC, abstractmethod 9 | 10 | import ray 11 | 12 | from swiftllm.worker.model import ModelPerfResult, LlamaModel, RemoteLlamaModel 13 | from swiftllm.engine_config import EngineConfig 14 | from swiftllm.model_config import LlamaModelConfig 15 | 16 | class Executor(ABC): 17 | """ 18 | Base class for executors. 19 | """ 20 | def __init__( 21 | self, 22 | engine_config: EngineConfig, 23 | model_config: LlamaModelConfig 24 | ): 25 | raise NotImplementedError 26 | 27 | 28 | @abstractmethod 29 | def init_kvcache_and_swap(self): 30 | """ 31 | Initialize the key-value cache and swap. 32 | """ 33 | raise NotImplementedError 34 | 35 | 36 | @abstractmethod 37 | def do_one_iteration(self, *args) -> list[int]: 38 | """ 39 | Do one iteration of the model. 40 | """ 41 | raise NotImplementedError 42 | 43 | 44 | @abstractmethod 45 | def turn_on_perf_monitor(self): 46 | """ 47 | Turn on performance monitoring. 48 | """ 49 | raise NotImplementedError 50 | 51 | 52 | @abstractmethod 53 | def turn_off_perf_monitor_and_flush_results(self) -> list[ModelPerfResult]: 54 | """ 55 | Turn off performance monitoring and flush results. 56 | """ 57 | raise NotImplementedError 58 | 59 | 60 | 61 | class SingleProcExecutor(Executor): 62 | """ 63 | Single process executor. 64 | """ 65 | def __init__( 66 | self, 67 | engine_config: EngineConfig, 68 | model_config: LlamaModelConfig 69 | ): 70 | self.engine_config = engine_config 71 | self.model_config = model_config 72 | tpd = engine_config.tensor_parallel_degree 73 | assert tpd == 1, f"SingleProcExecutor does not support tensor parallelism degree({tpd}) == 1" 74 | self.model = LlamaModel(engine_config, model_config, rank=0) 75 | 76 | 77 | def init_kvcache_and_swap(self): 78 | self.model.init_kvcache_and_swap(self.engine_config) 79 | 80 | 81 | def do_one_iteration(self, *args) -> list[int]: 82 | return self.model.do_one_iteration(*args) 83 | 84 | 85 | def turn_on_perf_monitor(self): 86 | self.model.turn_on_perf_monitor() 87 | 88 | 89 | def turn_off_perf_monitor_and_flush_results(self) -> list[ModelPerfResult]: 90 | return self.model.turn_off_perf_monitor_and_flush_results() 91 | 92 | 93 | class RayExecutor(Executor): 94 | """ 95 | Ray executor. Inits ray framework when instantiated. 
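    Creates one RemoteLlamaModel actor per tensor-parallel rank, sets MASTER_ADDR /
    MASTER_PORT so the actors can form a torch.distributed group, fans each control-plane
    call out to every actor via ray.get, and reports results from rank 0.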
96 | """ 97 | # pylint: disable=no-member 98 | def __init__( 99 | self, 100 | engine_config: EngineConfig, 101 | model_config: LlamaModelConfig 102 | ): 103 | os.environ["MASTER_ADDR"] = "localhost" 104 | os.environ["MASTER_PORT"] = "29500" 105 | self.engine_config = engine_config 106 | self.model_config = model_config 107 | 108 | num_workers = engine_config.tensor_parallel_degree 109 | self.models = [RemoteLlamaModel.remote(engine_config, model_config, rank=i) for i in range(num_workers)] 110 | 111 | 112 | def init_kvcache_and_swap(self): 113 | ray.get([model.init_kvcache_and_swap.remote(self.engine_config) for model in self.models]) 114 | 115 | 116 | def do_one_iteration(self, *args) -> list[int]: 117 | return ray.get([model.do_one_iteration.remote(*args) for model in self.models])[0] 118 | 119 | 120 | def turn_on_perf_monitor(self): 121 | ray.get(self.models[0].turn_on_perf_monitor.remote()) 122 | 123 | 124 | def turn_off_perf_monitor_and_flush_results(self) -> list[ModelPerfResult]: 125 | return ray.get(self.models[0].turn_off_perf_monitor_and_flush_results.remote()) 126 | -------------------------------------------------------------------------------- /csrc/src/block_swapping.cpp: -------------------------------------------------------------------------------- 1 | #include "block_swapping.h" 2 | 3 | #include // for at::cuda::getCurrentCUDAStream() 4 | 5 | inline size_t getTensorSizeInBytes(const torch::Tensor &tensor) { 6 | return tensor.numel() * torch::elementSize(torch::typeMetaToScalarType(tensor.dtype())); 7 | } 8 | 9 | // swap_blocks - Perform swapping between GPU blocks and CPU blocks 10 | // The source_block_ids and target_block_ids are the block ids of the blocks to be swapped. 11 | // source_block_ids[0] will be copied to target_block_ids[0] and so on 12 | // `is_swap_out` defines whether the swap is a swap-in or swap-out (swap-in means 13 | // to swap from CPU to GPU, swap-out means to swap from GPU to CPU) 14 | // 15 | // Here we do not pass a cudaStream to the function. Instead we use the current 16 | // stream indicated by at::cuda::getCurrentCUDAStream(). So it is python's 17 | // responsibility to set the current stream before calling this function. 18 | // 19 | // Future work: Now the number of cudaMemcpyAsync calls is equal to 2x the number 20 | // of blocks to swap. 
We can reduce the number of cudaMemcpyAsync calls by 21 | // grouping nearby blocks together and perform a single invocation 22 | void swap_blocks( 23 | const std::vector &source_block_ids, 24 | const std::vector &target_block_ids, 25 | const bool is_swap_out, 26 | const int gpu_layer, 27 | const int cpu_layer, 28 | 29 | torch::Tensor k_cache, 30 | torch::Tensor v_cache, 31 | torch::Tensor k_swap, 32 | torch::Tensor v_swap 33 | ) { 34 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 35 | size_t gpu_layer_size_in_bytes = getTensorSizeInBytes(k_cache) / k_cache.size(0); 36 | size_t cpu_layer_size_in_bytes = getTensorSizeInBytes(k_swap) / k_swap.size(0); 37 | size_t block_layer_size_in_bytes = gpu_layer_size_in_bytes / k_cache.size(1); // same for gpu and cpu 38 | 39 | char* k_cache_ptr = (char*)k_cache.data_ptr() + gpu_layer * gpu_layer_size_in_bytes; 40 | char* v_cache_ptr = (char*)v_cache.data_ptr() + gpu_layer * gpu_layer_size_in_bytes; 41 | char* k_swap_ptr = (char*)k_swap.data_ptr() + cpu_layer * cpu_layer_size_in_bytes; 42 | char* v_swap_ptr = (char*)v_swap.data_ptr() + cpu_layer * cpu_layer_size_in_bytes; 43 | int num_blocks_to_swap = source_block_ids.size(); 44 | int next_index = 0; 45 | while (next_index < num_blocks_to_swap) { 46 | int start_index = next_index; 47 | int end_index = start_index+1; 48 | while (end_index < num_blocks_to_swap && 49 | source_block_ids[end_index] == source_block_ids[end_index-1]+1 && 50 | target_block_ids[end_index] == target_block_ids[end_index-1]+1) { 51 | end_index++; 52 | } 53 | int cur_segment_len = end_index - start_index; 54 | size_t cur_segment_size_in_bytes = cur_segment_len * block_layer_size_in_bytes; 55 | int64_t start_source_block_id = source_block_ids[start_index]; 56 | int64_t start_target_block_id = target_block_ids[start_index]; 57 | 58 | if (!is_swap_out) { 59 | // Copy from CPU to GPU 60 | cudaMemcpyAsync( 61 | k_cache_ptr + start_target_block_id * block_layer_size_in_bytes, 62 | k_swap_ptr + start_source_block_id * block_layer_size_in_bytes, 63 | cur_segment_size_in_bytes, 64 | cudaMemcpyHostToDevice, 65 | stream 66 | ); 67 | cudaMemcpyAsync( 68 | v_cache_ptr + start_target_block_id * block_layer_size_in_bytes, 69 | v_swap_ptr + start_source_block_id * block_layer_size_in_bytes, 70 | cur_segment_size_in_bytes, 71 | cudaMemcpyHostToDevice, 72 | stream 73 | ); 74 | } else { 75 | // Copy from GPU to CPU 76 | cudaMemcpyAsync( 77 | k_swap_ptr + start_target_block_id * block_layer_size_in_bytes, 78 | k_cache_ptr + start_source_block_id * block_layer_size_in_bytes, 79 | cur_segment_size_in_bytes, 80 | cudaMemcpyDeviceToHost, 81 | stream 82 | ); 83 | cudaMemcpyAsync( 84 | v_swap_ptr + start_target_block_id * block_layer_size_in_bytes, 85 | v_cache_ptr + start_source_block_id * block_layer_size_in_bytes, 86 | cur_segment_size_in_bytes, 87 | cudaMemcpyDeviceToHost, 88 | stream 89 | ); 90 | } 91 | 92 | next_index = end_index; 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /temp 2 | /.vscode 3 | /.history 4 | 5 | # NVIDIA Nsight Systems profiler files 6 | *.nsys-rep 7 | *.sqlite 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | 
var/ 30 | wheels/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | cover/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | .pybuilder/ 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | # For a library or package, you might want to ignore these files since the code is 95 | # intended to run in multiple environments; otherwise, check them in: 96 | # .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # poetry 106 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 107 | # This is especially recommended for binary packages to ensure reproducibility, and is more 108 | # commonly ignored for libraries. 109 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 110 | #poetry.lock 111 | 112 | # pdm 113 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 114 | #pdm.lock 115 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 116 | # in version control. 117 | # https://pdm.fming.dev/#use-with-ide 118 | .pdm.toml 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | 171 | # NEO profile results 172 | profile_results/ 173 | 174 | # NEO evaluation results 175 | evaluation/results/ 176 | evaluation/*.pdf 177 | -------------------------------------------------------------------------------- /swiftllm/worker/block_swapper.py: -------------------------------------------------------------------------------- 1 | """ 2 | BlockManager and Swapper 3 | 4 | Contains initalization and transition logics for the KV cache. 5 | """ 6 | 7 | import torch 8 | import swiftllm_c 9 | from swiftllm.engine_config import EngineConfig 10 | from swiftllm.model_config import LlamaModelConfig 11 | 12 | class Swapper: 13 | """ 14 | Swapper - Manage the swapping of sequences in and out of the model 15 | 16 | This manager is responsible for swapping sequences in and out of the model. 17 | It maintains the block manager, and provides methods to swap sequences in 18 | and out. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | engine_config: EngineConfig, 24 | model_config: LlamaModelConfig 25 | ): 26 | self.engine_config = engine_config 27 | self.model_config = model_config 28 | 29 | num_q_heads = model_config.num_q_heads // model_config.world_size 30 | num_kv_heads = model_config.num_kv_heads // model_config.world_size 31 | 32 | # Initialize KV cache 33 | kvcache_shape = ( 34 | model_config.num_layers + engine_config.extra_layer_for_cprf, 35 | engine_config.num_gpu_blocks, 36 | num_kv_heads, 37 | engine_config.block_size, 38 | model_config.head_dim 39 | ) 40 | # Here we use torch.zeros instead of torch.empty, since that torch.empty 41 | # has the possibility to contain NaNs, which will cause the model to output NaNs. 
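        # Sizing note (illustrative, not from the original source): with float16 entries, one
        # K or V block of a single layer occupies num_kv_heads * block_size * head_dim * 2 bytes;
        # e.g. 8 KV heads, block_size 16 and head_dim 128 give 32 KiB per block per layer.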
42 | self.k_cache = torch.zeros(kvcache_shape, dtype=torch.float16, device="cuda") 43 | self.v_cache = torch.zeros(kvcache_shape, dtype=torch.float16, device="cuda") 44 | 45 | # Initialize KV swap space 46 | kvswap_shape = ( 47 | model_config.num_layers, 48 | engine_config.num_cpu_blocks, 49 | num_kv_heads, 50 | engine_config.block_size, 51 | model_config.head_dim 52 | ) 53 | self.k_swap = torch.zeros(kvswap_shape, dtype=torch.float16, device="cpu", pin_memory=True) 54 | self.v_swap = torch.zeros(kvswap_shape, dtype=torch.float16, device="cpu", pin_memory=True) 55 | 56 | # Initialize CPU QKV buffer 57 | qo_cpu_shape = (engine_config.max_batch_size, num_q_heads, model_config.head_dim) 58 | kv_cpu_shape = (engine_config.max_batch_size, num_kv_heads, model_config.head_dim) 59 | self.q_cpu = torch.zeros(qo_cpu_shape, dtype=torch.float16, device="cpu", pin_memory=True) 60 | self.k_cpu = torch.zeros(kv_cpu_shape, dtype=torch.float16, device="cpu", pin_memory=True) 61 | self.v_cpu = torch.zeros(kv_cpu_shape, dtype=torch.float16, device="cpu", pin_memory=True) 62 | # We store float32 tensors for the output, but convert them to float16 after copying back to GPU 63 | self.o_cpu = torch.zeros(qo_cpu_shape, dtype=torch.float32, device="cpu", pin_memory=True) 64 | 65 | self.gpu_block_table = torch.zeros( 66 | (engine_config.max_seqs_in_block_table, engine_config.max_blocks_per_seq), 67 | dtype=torch.int32, 68 | device="cuda" 69 | ) 70 | self.cpu_block_table = torch.zeros( 71 | (engine_config.max_seqs_in_block_table, engine_config.max_blocks_per_seq), 72 | dtype=torch.int32, 73 | device="cpu" 74 | ) 75 | 76 | 77 | def swap_blocks( 78 | self, 79 | src_block_ids: list[int], 80 | dst_block_ids: list[int], 81 | is_swap_out: bool, 82 | gpu_layer: int, 83 | cpu_layer: int 84 | ): 85 | """ 86 | Swap blocks between the GPU and CPU, the physical indexes of the blocks are given. 
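        `src_block_ids[i]` is copied to `dst_block_ids[i]`. `is_swap_out=True` copies
        GPU -> CPU (swap-out) and `False` copies CPU -> GPU (swap-in); `gpu_layer` and
        `cpu_layer` select the layer slice of the GPU KV cache and the CPU swap space,
        respectively.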
87 | """ 88 | # pylint: disable=too-many-arguments, c-extension-no-member 89 | assert len(src_block_ids) == len(dst_block_ids), "Length mismatch between src_block_ids and dst_block_ids" 90 | if not src_block_ids: 91 | return 92 | swiftllm_c.swap_blocks( 93 | src_block_ids, 94 | dst_block_ids, 95 | is_swap_out, 96 | gpu_layer, 97 | cpu_layer, 98 | 99 | self.k_cache, self.v_cache, 100 | self.k_swap, self.v_swap 101 | ) 102 | 103 | @torch.inference_mode() 104 | def set_block_tables( 105 | self, 106 | mappings: tuple[tuple[list[int], list[int]], tuple[list[int], list[int]]] 107 | ): 108 | """ 109 | Establish new mappings in the block tables 110 | """ 111 | (gpu_vids, gpu_pids), (cpu_vids, cpu_pids) = mappings 112 | if gpu_vids: 113 | self.gpu_block_table.view(-1)[gpu_vids] = torch.tensor(gpu_pids, dtype=torch.int32, device="cuda") 114 | if cpu_vids: 115 | self.cpu_block_table.view(-1)[cpu_vids] = torch.tensor(cpu_pids, dtype=torch.int32, device="cpu") 116 | -------------------------------------------------------------------------------- /evaluation/benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | import asyncio 3 | import time 4 | import logging 5 | import json 6 | import random 7 | 8 | import numpy as np 9 | from tqdm import tqdm 10 | 11 | # pylint: disable=import-error 12 | from api_client import request_completions 13 | 14 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 15 | res_dir = f"{cur_dir}/results" 16 | os.makedirs(res_dir, exist_ok=True) 17 | 18 | logger = logging.getLogger(__name__) 19 | logging.basicConfig(filename=f"{cur_dir}/evaluation.log", level=logging.INFO, datefmt='%Y-%m-%d %H:%M:%S') 20 | 21 | api_url = "http://localhost:8000/v1/completions" 22 | 23 | async def request_completions_task(prompt: list[int], output_len: int, model_path: str): 24 | start = time.perf_counter() 25 | await request_completions(api_url, prompt, output_len, model_path) 26 | end = time.perf_counter() 27 | return start, end 28 | 29 | 30 | async def run_test( 31 | prompts: list[list[int]], 32 | output_lens: list[int], 33 | res_prefix: str, 34 | model_path: str, 35 | rate: float = -1 # -1 means throughput test 36 | ): 37 | if rate > 0: 38 | res_file = f"{res_prefix}-lat-{str(rate).replace('.', '_')}.json" 39 | else: 40 | res_file = f"{res_prefix}-tp.json" 41 | 42 | if os.path.exists(res_file): 43 | logger.info("Test result file already exists: %s", res_file) 44 | with open(res_file, "r") as f: 45 | data = json.load(f) 46 | times = [(d["start"], d["end"]) for d in data] 47 | else: 48 | logger.info("Running test, saving results to %s", res_file) 49 | 50 | tasks = [] 51 | np.random.seed(0) 52 | gaps = np.random.exponential(1 / rate, len(prompts)).tolist() if rate > 0 else [0] * len(prompts) 53 | for prompt, output_len in tqdm(zip(prompts, output_lens)): 54 | task = asyncio.create_task(request_completions_task(prompt, output_len, model_path)) 55 | tasks.append(task) 56 | if rate > 0: 57 | await asyncio.sleep(gaps.pop(0)) 58 | times = await asyncio.gather(*tasks) 59 | with open(res_file, "w") as f: 60 | json.dump([{ 61 | "input_len": len(prompt), 62 | "output_len": output_len, 63 | "start": start, 64 | "end": end 65 | } for (start, end), prompt, output_len in zip(times, prompts, output_lens)], f, indent=4) 66 | 67 | if rate > 0: 68 | comp_times = [end - start for start, end in times] 69 | pertok_times = [comp_time / (len(prompt) + output_len) for comp_time, prompt, output_len in zip(comp_times, prompts, output_lens)] 70 | 
average_completion_time = sum(comp_times) / len(comp_times) 71 | average_pertok_time = sum(pertok_times) / len(pertok_times) 72 | logger.info("Average completion time: %.3f s", average_completion_time) 73 | logger.info("Average per-token completion time: %.3f s", average_pertok_time) 74 | else: 75 | n = len(prompts) 76 | req_end_times = sorted([end for _, end in times]) 77 | req_end_times = req_end_times[n // 10: n - n // 10 * 3 + 1] 78 | throughput = (len(req_end_times) - 1) / (req_end_times[-1] - req_end_times[0]) 79 | logger.info("Throughput: %.3f req/s", throughput) 80 | 81 | 82 | def _get_rand_array(n: int, avg_val: int, ratio: float): 83 | """ 84 | Get a random array with average value `avg_val`, 85 | 86 | all values are uniformly distributed in the range of [avg_val * (1 - ratio), avg_val * (1 + ratio)] 87 | """ 88 | delta = int(avg_val * ratio) 89 | return [avg_val + random.randint(-delta, delta) for _ in range(n)] 90 | 91 | 92 | def prepare_mock_test( 93 | nreqs: int, 94 | input_len: int, 95 | output_len: int, 96 | server_name: str, 97 | config: dict 98 | ) -> tuple[list[list[int]], list[int], str]: 99 | input_lens = _get_rand_array(nreqs, input_len, 0.1) 100 | output_lens = _get_rand_array(nreqs, output_len, 0.1) 101 | prompts = [[10] * input_len for input_len in input_lens] 102 | res_file = f"{res_dir}/{server_name}-{nreqs}-{input_len}-{output_len}" 103 | return prompts, output_lens, res_file, config['model_path'] 104 | 105 | 106 | def prepare_real_test( 107 | dataset_name: str, 108 | config: dict, 109 | server_name: str 110 | ) -> tuple[list[list[int]], list[int], str]: 111 | input_file = f"{cur_dir}/data/{dataset_name}-{config['model']}.json" 112 | with open(input_file, "r") as f: 113 | # Remove the [:100] to use the full dataset. However, it may take a long time (~10h) to run the full test of fig6c. 
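    # Each dataset entry is expected to carry "prompt" (the input length in tokens) and
    # "max_tokens" (the output length); dummy prompts of those lengths are built below.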
114 | datas = json.load(f)[:100] 115 | prompts = [[10] * data["prompt"] for data in datas] 116 | output_lens = [data["max_tokens"] for data in datas] 117 | 118 | res_file = f"{res_dir}/{server_name}-{dataset_name}" 119 | return prompts, output_lens, res_file, config['model_path'] 120 | -------------------------------------------------------------------------------- /evaluation/server.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import subprocess 4 | import time 5 | import logging 6 | 7 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 8 | neo_dir = os.path.dirname(cur_dir) 9 | logger = logging.getLogger(__name__) 10 | logging.basicConfig(filename=f"{cur_dir}/evaluation.log", level=logging.INFO, datefmt='%Y-%m-%d %H:%M:%S') 11 | 12 | server_proc = None 13 | 14 | def start_server(name: str, config: dict): 15 | """ 16 | Start the server 17 | """ 18 | # pylint: disable=global-statement 19 | global server_proc 20 | 21 | numacmd = ["numactl", "-N", "0", "-m", "0"] 22 | with open(f"{cur_dir}/{name}-server.log", "w") as f: 23 | if name[:4] == "vllm": 24 | chunk_size_str = name[4:] if name != "vllm" else str(config["num_gpu_blocks_override"] * config["block_size"]) 25 | max_num_seqs = min(int(chunk_size_str), config["max_num_seqs"]) 26 | server_proc = subprocess.Popen( 27 | numacmd + [ 28 | "vllm", "serve", config["model_path"], "--port", "8000", 29 | "--block-size", str(config["block_size"]), 30 | "--max-model-len", str(config["max_model_len"]), 31 | "--max-num-seqs", str(max_num_seqs), 32 | "--max-num-batched-tokens", chunk_size_str, 33 | "--tensor-parallel-size", str(config["tensor_parallel_size"]), 34 | # "--gpu-memory-utilization", str(config["gpu_memory_utilization"]), 35 | "--num-gpu-blocks-override", str(config["num_gpu_blocks_override"]), 36 | "--swap-space", str(config["swap_space"] / config["tensor_parallel_size"]), 37 | "--enforce-eager", 38 | "--disable-sliding-window", 39 | "--disable-async-output-proc", 40 | "--disable-custom-all-reduce", 41 | "--disable-frontend-multiprocessing", 42 | "--tokenizer-pool-size", "1", 43 | "--enable-chunked-prefill", 44 | "--preemption-mode", "recompute", 45 | "--dtype", "float16" 46 | ], 47 | env=os.environ | {"VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1"}, 48 | stdout=f, 49 | stderr=f 50 | ) 51 | 52 | elif name in ["ours", "base", "fsdc"]: 53 | nl = config['num_layers'] 54 | if name == "base": 55 | cmd=["--always-use-gpu"] 56 | num_gpu_blocks_override = config["num_gpu_blocks_override"] 57 | swap_space = config["swap_space"] // 8 58 | elif name == "ours": 59 | cmd=["--extra-layer-for-cprf"] 60 | num_gpu_blocks_override = config["num_gpu_blocks_override"] * nl // (nl + 1) 61 | swap_space = config["swap_space"] 62 | else: 63 | cmd=["--disable-partial-offl", "--extra-layer-for-cprf"] 64 | num_gpu_blocks_override = config["num_gpu_blocks_override"] * nl // (nl + 1) 65 | swap_space = config["swap_space"] 66 | 67 | cmd = numacmd + [ 68 | sys.executable, "-m", "swiftllm.server.api_server", 69 | "--port", "8000", 70 | "--model-path", config["model_path"], 71 | "--block-size", str(config["block_size"]), 72 | "--max-blocks-per-seq", str((config["max_num_batched_tokens"] - 1) // config["block_size"] + 1), 73 | "--max-seqs-in-block-table", str(config["max_num_seqs"]), 74 | "--max-batch-size", str(config["max_num_seqs"]), 75 | "--max-tokens-in-batch", str(config["max_num_batched_tokens"]), 76 | "--tensor-parallel-degree", str(config["tensor_parallel_size"]), 77 | # "--gpu-mem-utilization", 
str(config["gpu_memory_utilization"]), 78 | "--num-gpu-blocks-override", str(num_gpu_blocks_override), 79 | "--swap-space", str(swap_space), 80 | "--library-path", f"{neo_dir}/pacpu/build/{config['library']}", 81 | "--profile-result-path", f"{neo_dir}/profile_results/", 82 | ] + cmd 83 | 84 | server_proc = subprocess.Popen( 85 | cmd, 86 | stdout=f, 87 | stderr=f 88 | ) 89 | 90 | else: 91 | raise ValueError(f"Unknown server name: {name}") 92 | 93 | # Check the server log every 5s, until the starting keyword is found 94 | time_counter = 0 95 | while True: 96 | time.sleep(5) 97 | time_counter += 5 98 | logger.info(f"{time_counter}s elapsed, checking server log ...") 99 | with open(f"{cur_dir}/{name}-server.log", "r") as f: 100 | if "Started server process" in f.read(): 101 | break 102 | time.sleep(0.5) 103 | 104 | logger.info("Server started") 105 | 106 | 107 | def stop_server(): 108 | """ 109 | Stop the server 110 | """ 111 | assert server_proc is not None, "Server not started" 112 | server_proc.terminate() 113 | logger.info("Server stopped") 114 | -------------------------------------------------------------------------------- /swiftllm/engine_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Configuration for the SwiftLLM engine. 3 | """ 4 | 5 | import dataclasses 6 | import argparse 7 | 8 | @dataclasses.dataclass 9 | class EngineConfig: 10 | """ 11 | Configuration for the SwiftLLM engine. 12 | """ 13 | 14 | # Model loading parameters 15 | model_path: str 16 | use_dummy: bool 17 | 18 | # PagedAttention-related parameters 19 | block_size: int 20 | gpu_mem_utilization: float 21 | num_gpu_blocks_override: int # -1 for not overriding 22 | swap_space: int 23 | max_seqs_in_block_table: int 24 | max_blocks_per_seq: int 25 | 26 | # Scheduling-related parameters 27 | max_batch_size: int 28 | max_tokens_in_batch: int 29 | 30 | # External paths 31 | library_path: str 32 | profile_result_path: str 33 | 34 | # Switches 35 | extra_layer_for_cprf: bool = False # Fixed after initialization 36 | disable_partial_offl: bool = False # Fixed after initialization 37 | monitor_performance: bool = False # Can be altered while running 38 | always_use_gpu: bool = False # Can be altered while running 39 | 40 | # Parallel parameter 41 | tensor_parallel_degree: int = 1 42 | 43 | # Derived parameters 44 | num_cpu_blocks: int = -1 45 | num_gpu_blocks: int = -1 46 | 47 | @property 48 | def max_seq_len(self) -> int: 49 | """ 50 | Maximum sequence length in tokens 51 | """ 52 | return self.block_size * self.max_blocks_per_seq 53 | 54 | @property 55 | def max_gpu_tokens(self) -> int: 56 | """ 57 | Maximum number of tokens that can be stored in the GPU 58 | """ 59 | return self.block_size * self.num_gpu_blocks 60 | 61 | @property 62 | def max_cpu_tokens(self) -> int: 63 | """ 64 | Maximum number of tokens that can be stored in the CPU 65 | """ 66 | return self.block_size * self.num_cpu_blocks 67 | 68 | @staticmethod 69 | def add_cli_args(parser: argparse.ArgumentParser): 70 | """ 71 | Add CLI arguments for the engine configuration 72 | """ 73 | parser.add_argument( 74 | "--model-path", 75 | type=str, 76 | required=True, 77 | help="Path to the model directory (currently SwiftLLM does not support downloading from HuggingFace, so please download in advance)", 78 | ) 79 | parser.add_argument( 80 | "--use-dummy", 81 | action="store_true", 82 | help="Use dummy weights (mainly for profiling)", 83 | ) 84 | 85 | parser.add_argument( 86 | "--block-size", 87 | type=int, 88 | 
default=16, 89 | help="Block size for PagedAttention", 90 | ) 91 | parser.add_argument( 92 | "--gpu-mem-utilization", 93 | type=float, 94 | default=0.99, 95 | help="Fraction of GPU memory to be used", 96 | ) 97 | parser.add_argument( 98 | "--num-gpu-blocks-override", 99 | type=int, 100 | default=-1, 101 | help="Override the number of GPU blocks", 102 | ) 103 | parser.add_argument( 104 | "--swap-space", 105 | type=int, 106 | default=20, 107 | help="Swap space in GB", 108 | ) 109 | parser.add_argument( 110 | "--max-seqs-in-block-table", 111 | type=int, 112 | default=768, 113 | help="Maximum number of sequences in the block table", 114 | ) 115 | parser.add_argument( 116 | "--max-blocks-per-seq", 117 | type=int, 118 | default=512, 119 | help="Maximum number of blocks per sequence", 120 | ) 121 | 122 | parser.add_argument( 123 | "--max-batch-size", 124 | type=int, 125 | default=512, 126 | help="Maximum batch size", 127 | ) 128 | parser.add_argument( 129 | "--max-tokens-in-batch", 130 | type=int, 131 | default=3072, 132 | help="Maximum number of tokens in a batch", 133 | ) 134 | 135 | parser.add_argument( 136 | "--library-path", 137 | type=str, 138 | help="Path to the external library", 139 | ) 140 | parser.add_argument( 141 | "--profile-result-path", 142 | type=str, 143 | help="Path to the profiling results", 144 | ) 145 | parser.add_argument( 146 | "--tensor-parallel-degree", 147 | type=int, 148 | default=1, 149 | help="Degree of tensor parallelism", 150 | ) 151 | parser.add_argument( 152 | "--disable-partial-offl", 153 | action="store_true", 154 | help="Disable partial offloading", 155 | ) 156 | parser.add_argument( 157 | "--always-use-gpu", 158 | action="store_true", 159 | help="Always use GPU", 160 | ) 161 | parser.add_argument( 162 | "--extra-layer-for-cprf", 163 | action="store_true", 164 | help="Use an extra layer for CPRF", 165 | ) 166 | -------------------------------------------------------------------------------- /pacpu/pacpu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "dtype.h" 5 | #include "core.h" 6 | 7 | typedef at::Half at_data_t; 8 | 9 | // #define USE_ATEN_OPER 10 | #ifdef USE_ATEN_OPER 11 | #include 12 | #include 13 | 14 | void paged_attention_cpu_torch( 15 | int64_t cur_layer, 16 | double softmax_scale, 17 | const std::vector &decoding_seq_ids, 18 | const std::vector &decoding_seq_lengths, 19 | 20 | at::Tensor q, 21 | at::Tensor k, 22 | at::Tensor v, 23 | at::Tensor k_cache, 24 | at::Tensor v_cache, 25 | at::Tensor block_table, 26 | at::Tensor o 27 | ) { 28 | for (auto i = 0; i < decoding_seq_ids.size(); i++) { 29 | auto seq_id = decoding_seq_ids[i]; 30 | auto seq_len = decoding_seq_lengths[i]; 31 | at::Tensor qi = q.index({i}); 32 | at::Tensor ki = k.index({i}); 33 | at::Tensor vi = v.index({i}); 34 | auto blkid = (seq_len - 1) / BLOCK_SIZE; 35 | auto blkoff = (seq_len - 1) % BLOCK_SIZE; 36 | std::cout << k_cache.sizes() << std::endl; 37 | k_cache.index_put_({cur_layer, blkid, at::indexing::Slice(), blkoff, at::indexing::Slice()}, ki); 38 | v_cache.index_put_({cur_layer, blkid, at::indexing::Slice(), blkoff, at::indexing::Slice()}, vi); 39 | at::Tensor nk = k_cache.index({cur_layer, block_ids}).permute({1, 2, 0}).to(at::kFloat); 40 | at::Tensor nv = v_cache.index({cur_layer, block_ids}).permute({1, 0, 2}).to(at::kFloat); 41 | at::Tensor attn_score = (at::bmm(qi, nk) * softmax_scale).softmax(-1); 42 | o.index_put_({i}, at::bmm(attn_score, nv).view(-1)); 43 | } 44 | } 45 | #endif 46 | 47 | 
void assert_hyper_params_expected(int num_q_heads, int num_kv_heads, int num_layers, int head_dim, int block_size) { 48 | if (num_q_heads != NUM_Q_HEADS) { 49 | throw std::invalid_argument("expected num_q_heads to be " + std::to_string(NUM_Q_HEADS) + ", but got " + std::to_string(num_q_heads)); 50 | } 51 | if (num_kv_heads != NUM_KV_HEADS) { 52 | throw std::invalid_argument("expected num_kv_heads to be " + std::to_string(NUM_KV_HEADS) + ", but got " + std::to_string(num_kv_heads)); 53 | } 54 | if (num_layers != NUM_LAYERS) { 55 | throw std::invalid_argument("expected num_layers to be " + std::to_string(NUM_LAYERS) + ", but got " + std::to_string(num_layers)); 56 | } 57 | if (head_dim != HEAD_DIM) { 58 | throw std::invalid_argument("expected head_dim to be " + std::to_string(HEAD_DIM) + ", but got " + std::to_string(head_dim)); 59 | } 60 | if (block_size != BLOCK_SIZE) { 61 | throw std::invalid_argument("expected block_size to be " + std::to_string(BLOCK_SIZE) + ", but got " + std::to_string(block_size)); 62 | } 63 | } 64 | 65 | /* 66 | * Paged attention, contains 3 implementations: 67 | */ 68 | 69 | #define USE_ISPC_TASKS_OPER 70 | // #define USE_ISPC_OPER 71 | 72 | void paged_attention_cpu( 73 | int64_t cur_layer, 74 | double softmax_scale, 75 | const std::vector &seq_ids, 76 | const std::vector &seq_lengths, 77 | 78 | at::Tensor q, // [batch_size, num_q_heads, head_dim] 79 | at::Tensor k, // [batch_size, num_kv_heads, head_dim] 80 | at::Tensor v, // [batch_size, num_kv_heads, head_dim] 81 | at::Tensor k_cache, // [..., num_layers, num_kv_heads, block_size, head_dim] 82 | at::Tensor v_cache, // [..., num_layers, num_kv_heads, block_size, head_dim] 83 | at::Tensor block_table, // [..., max_seq_len] 84 | at::Tensor o // [batch_size, num_kv_heads * qh_per_kvh * head_dim] 85 | ) { 86 | int batch_size = q.size(0); 87 | int num_q_heads = q.size(1); 88 | int num_layers = k_cache.size(0); 89 | int num_blocks = k_cache.size(1); 90 | int num_kv_heads = k_cache.size(2); 91 | int block_size = k_cache.size(3); 92 | int head_dim = k_cache.size(4); 93 | int block_table_width = block_table.size(1); 94 | 95 | assert_hyper_params_expected(num_q_heads, num_kv_heads, num_layers, head_dim, block_size); 96 | 97 | auto qbatch_p = (data_t*) q.data_ptr(); 98 | auto kbatch_p = (data_t*) k.data_ptr(); 99 | auto vbatch_p = (data_t*) v.data_ptr(); 100 | auto obatch_p = o.data_ptr(); 101 | auto kcache_p = (data_t*) k_cache.data_ptr(); 102 | auto vcache_p = (data_t*) v_cache.data_ptr(); 103 | auto block_table_p = block_table.data_ptr(); // [batch_size, max_seq_len] 104 | 105 | #ifdef USE_BRUTE_OPER 106 | brute_attention( 107 | cur_layer, num_blocks, batch_size, block_table_width, softmax_scale, 108 | seq_ids, seq_lengths, 109 | qbatch_p, kbatch_p, vbatch_p, obatch_p, kcache_p, vcache_p, block_table_p 110 | ); 111 | #elifdef USE_ISPC_OPER 112 | ispc_attention( 113 | cur_layer, num_blocks, batch_size, block_table_width, softmax_scale, 114 | seq_ids, seq_lengths, 115 | qbatch_p, kbatch_p, vbatch_p, obatch_p, kcache_p, vcache_p, block_table_p 116 | ); 117 | #elifdef USE_ISPC_TASKS_OPER 118 | ispc_attention_tasks( 119 | cur_layer, num_blocks, batch_size, block_table_width, softmax_scale, 120 | seq_ids, seq_lengths, 121 | qbatch_p, kbatch_p, vbatch_p, obatch_p, kcache_p, vcache_p, block_table_p 122 | ); 123 | #endif 124 | } 125 | 126 | TORCH_LIBRARY(pacpu, m) { 127 | #ifdef USE_ATEN_OPER 128 | m.def("paged_attention_cpu_torch", &paged_attention_cpu_torch); 129 | #endif 130 | m.def("paged_attention_cpu", 
&paged_attention_cpu); 131 | } -------------------------------------------------------------------------------- /evaluation/illustrator.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | cur_dir = os.path.dirname(os.path.realpath(__file__)) 7 | 8 | def get_lat_avg(file): 9 | with open(file) as f: 10 | data = json.load(f) 11 | # only take latter half 12 | data = data[len(data) // 4:] 13 | return sum([(x['end'] - x['start']) / (x['output_len']) for x in data]) / len(data) 14 | 15 | 16 | def draw_one_rl_diagram( 17 | title: str, 18 | data_name: str, 19 | sys_file_names: list[str], 20 | sys_legend_names: list[str], 21 | rate_lists: list[list[float]], 22 | ylim: float, 23 | markers: list[str], 24 | set_ylabel: bool = False, 25 | ): 26 | lats = [] 27 | max_rate = max([max(rate_list) for rate_list in rate_lists]) 28 | for sys_file_name, rate_list in zip(sys_file_names, rate_lists): 29 | lats.append([]) 30 | for rate in rate_list: 31 | rate_str = str(rate).replace(".", "_") 32 | lats[-1].append(get_lat_avg(f"{cur_dir}/results/{sys_file_name}-{data_name}-lat-{rate_str}.json")) 33 | 34 | # ax.set_title(title, y=-0.3, fontsize="x-large") 35 | 36 | fig, ax = plt.subplots(1, 1, figsize=(4, 3)) 37 | for i, sys_legend_name in enumerate(sys_legend_names): 38 | ax.plot(rate_lists[i], lats[i], label=sys_legend_name, marker=markers[i]) 39 | 40 | ax.set_xlabel("Ruquest rate (req/s)", fontsize="large") 41 | if set_ylabel: 42 | ax.set_ylabel("Average per token latency (s)", fontsize="large") 43 | ax.set_xlim(0, max_rate) 44 | ax.set_xticks([0.5 * x for x in range(round(max_rate * 2 + 1))]) 45 | ax.set_ylim(-ylim / 50, ylim) 46 | ax.set_yticks([ylim / 5 * x for x in range(6)]) 47 | ax.set_xticklabels([f"{x:.1f}" for x in ax.get_xticks()], fontsize="large") 48 | ax.set_yticklabels([f"{y:.2f}" for y in ax.get_yticks()], fontsize="large") 49 | ax.grid(True) 50 | handles, labels = ax.get_legend_handles_labels() 51 | ax.legend() 52 | plt.savefig(f"{cur_dir}/{title}.pdf", bbox_inches='tight') 53 | return handles, labels 54 | 55 | 56 | def get_tp(filenames: list[str], interv: tuple[float, float]): 57 | tps = [] 58 | for i, filename in enumerate(filenames): 59 | with open(filename) as f: 60 | data = json.load(f) 61 | 62 | times = sorted([d['end'] for d in data]) 63 | data = [times[j] - times[j-1] for j in range(1, len(times))] 64 | 65 | ndata = len(data) 66 | nwarmup = round((ndata + 1) * interv[0]) 67 | ncooldown = round((ndata + 1) * interv[1]) 68 | tps.append(1 / np.mean(data[nwarmup: ncooldown])) 69 | return tps 70 | 71 | def get_tp_token(filenames: list[str]): 72 | tps = [] 73 | for i, filename in enumerate(filenames): 74 | with open(filename) as f: 75 | data = json.load(f) 76 | 77 | first_start = min([d['start'] for d in data]) 78 | last_end = max([d['end'] for d in data]) 79 | total_time = last_end - first_start 80 | total_tokens = sum([d['output_len'] + d['input_len'] for d in data]) 81 | tps.append(total_tokens / total_time) 82 | 83 | return tps 84 | 85 | 86 | def draw_one_ps_diagram( 87 | title: str, 88 | base_sys_name: str, 89 | interv: list[float], 90 | num_datas: list[int], 91 | sys_file_names: list[str], 92 | legend_names: list[str | None], 93 | input_lens: list[int], 94 | output_lens: list[int], 95 | markers: list[str], 96 | show_ylabels: bool = False, 97 | show_legend: bool = True, 98 | ): 99 | fig, ax = plt.subplots(1, 1, figsize=(4, 3)) 100 | for i in range(len(num_datas)): 
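        # For each setup, compute the throughput of the offloading system relative to the
        # baseline system at every average output length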
101 | tps = [] 102 | for out_len in output_lens: 103 | file_names = [f'{cur_dir}/results/{sys_name}-{num_datas[i]}-{input_lens[i]}-{out_len}-tp.json' for sys_name in [base_sys_name, sys_file_names[i]]] 104 | tp_pair = get_tp(file_names, interv) 105 | tps.append(tp_pair) 106 | 107 | ratios = [tp1 / tp0 for tp0, tp1 in tps] 108 | ax.plot(output_lens, ratios, label=f'{legend_names[i]}', marker=markers[i]) 109 | 110 | # draw y = 1 line 111 | ax.plot([output_lens[0], output_lens[-1]], [1, 1], 'r--', label='baseline') 112 | 113 | ax.set_xlabel('Avg. output length', fontsize='large') 114 | if show_ylabels: 115 | ax.set_ylabel('Relative throughput', fontsize='large') 116 | ax.set_xticklabels([f'{x:.0f}' for x in ax.get_xticks()], fontsize='large') 117 | ax.set_yticklabels([f'{x:.2f}' for x in ax.get_yticks()], fontsize='large') 118 | handles, labels = ax.get_legend_handles_labels() 119 | if show_legend: 120 | ax.legend() 121 | fig.savefig(f'{cur_dir}/{title}.pdf', bbox_inches='tight') 122 | return handles, labels 123 | 124 | 125 | def parse_ours_server_log(file): 126 | # Get sizes list from lines like below 127 | # INFO:swiftllm.server.engine:Forwarding batches with sizes [(0, 1, 14, 0)], swap out: 0, swap in: 4 128 | with open(file) as f: 129 | lines = f.readlines() 130 | 131 | sizes = [] 132 | for line in lines: 133 | if 'Forwarding batches with sizes' in line: 134 | sizes.append(eval(line.split('[')[1].split(']')[0])) 135 | 136 | return sizes 137 | 138 | 139 | def parse_vllm_server_log(file): 140 | # Get Rumnings list from lines like below 141 | # INFO 10-28 08:28:24 metrics.py:351] Avg prompt throughput: 5806.9 tokens/s, Avg generation throughput: 53.0 tokens/s, Running: 12 reqs, Swapped: 0 reqs, Pending: 1466 reqs, GPU KV cache usage: 93.6%, CPU KV cache usage: 0.0%. 
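    # Each matching metrics line reports the current number of running requests, collected below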
142 | with open(file) as f: 143 | lines = f.readlines() 144 | 145 | runnings = [] 146 | for line in lines: 147 | if 'Running' in line and 'Avg prompt throughput' in line: 148 | runnings.append(int(line.split('Running: ')[1].split(' reqs')[0])) 149 | 150 | return runnings 151 | -------------------------------------------------------------------------------- /swiftllm/worker/kernels/prefill_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | 5 | from swiftllm.model_config import LlamaModelConfig 6 | from swiftllm.engine_config import EngineConfig 7 | from swiftllm.structs import SubBatch 8 | 9 | @triton.jit 10 | def _fwd_prefill_attention( 11 | o: torch.Tensor, # [num_prefill_tokens, num_q_heads, head_dim] 12 | q: torch.Tensor, # [num_prefill_tokens, num_q_heads, head_dim] 13 | k: torch.Tensor, # [num_prefill_tokens, num_kv_heads, head_dim] 14 | v: torch.Tensor, # [num_prefill_tokens, num_kv_heads, head_dim] 15 | softmax_scale: float, 16 | 17 | prefill_seq_start_locs: torch.Tensor, # [num_prefill_seqs+1] 18 | prefill_seq_lens: torch.Tensor, # [num_prefill_seqs] 19 | 20 | num_q_heads: tl.constexpr, 21 | num_kv_heads: tl.constexpr, 22 | gpa_group_size: tl.constexpr, # = num_q_heads // num_kv_heads 23 | head_dim: tl.constexpr, 24 | 25 | BLOCK_Q: tl.constexpr, 26 | BLOCK_K: tl.constexpr 27 | ): 28 | # grid shape: [num_prefill_seqs, num_q_heads, cdiv(max_prefill_len, BLOCK_Q)] 29 | # Require: BLOCK_Q % BLOCK_K == 0 30 | my_batch_id = tl.program_id(0) 31 | my_q_head = tl.program_id(1) 32 | my_q_block = tl.program_id(2) 33 | my_kv_head = my_q_head // gpa_group_size 34 | 35 | my_seq_len = tl.load(prefill_seq_lens + my_batch_id) 36 | if my_q_block * BLOCK_Q >= my_seq_len: 37 | return 38 | my_q_start_loc = tl.load(prefill_seq_start_locs + my_batch_id) 39 | 40 | q += (my_q_start_loc*num_q_heads+my_q_head)*head_dim 41 | k += (my_q_start_loc*num_kv_heads+my_kv_head)*head_dim 42 | v += (my_q_start_loc*num_kv_heads+my_kv_head)*head_dim 43 | o += (my_q_start_loc*num_q_heads+my_q_head)*head_dim 44 | 45 | range_my_q = my_q_block*BLOCK_Q + tl.arange(0, BLOCK_Q) 46 | offs_my_q = range_my_q[:, None]*(num_q_heads*head_dim) + tl.arange(0, head_dim)[None, :] 47 | my_q = tl.load(q + offs_my_q, mask = range_my_q[:, None] < my_seq_len, cache_modifier=".cg") # [BLOCK_Q, head_dim] 48 | 49 | k_ptrs = k + (tl.arange(0, BLOCK_K))[None, :]*(num_kv_heads*head_dim) + tl.arange(0, head_dim)[:, None] 50 | v_ptrs = v + (tl.arange(0, BLOCK_K))[:, None]*(num_kv_heads*head_dim) + tl.arange(0, head_dim)[None, :] 51 | 52 | m_i = tl.full([BLOCK_Q], value=float("-1e20"), dtype=tl.float32) 53 | l_i = tl.zeros([BLOCK_Q], dtype=tl.float32) 54 | acc = tl.zeros([BLOCK_Q, head_dim], dtype=tl.float32) 55 | 56 | # Calculate non-diagonal attention 57 | for k_block_start in range(0, my_q_block*BLOCK_Q, BLOCK_K): 58 | k_block_start = tl.multiple_of(k_block_start, BLOCK_K) 59 | # Here masking is unnecessary 60 | cur_k = tl.load(k_ptrs + k_block_start*(num_kv_heads*head_dim), cache_modifier=".cg") # [head_dim, BLOCK_K] 61 | qk = tl.dot(my_q, cur_k, out_dtype=tl.float32) * softmax_scale # [BLOCK_Q, BLOCK_K] 62 | cur_k = None 63 | cur_v = tl.load(v_ptrs + k_block_start*(num_kv_heads*head_dim), cache_modifier=".cg") # [BLOCK_K, head_dim] 64 | 65 | m_i_new = tl.maximum(m_i, tl.max(qk, 1)) 66 | alpha = tl.math.exp2(m_i - m_i_new) 67 | exp_qk = tl.math.exp2(qk - m_i_new[:, None]) 68 | 69 | m_i = m_i_new 70 | l_i = l_i*alpha + tl.sum(exp_qk, 1) 71 
| acc = acc*alpha[:, None] + tl.dot(exp_qk.to(tl.float16), cur_v) 72 | 73 | # Calculate the diagonal attention 74 | for k_block_start in range(my_q_block*BLOCK_Q, (my_q_block+1)*BLOCK_Q, BLOCK_K): 75 | k_block_start = tl.multiple_of(k_block_start, BLOCK_K) 76 | cur_k = tl.load(k_ptrs + k_block_start*(num_kv_heads*head_dim), 77 | mask = (k_block_start + tl.arange(0, BLOCK_K))[None, :] < my_seq_len, 78 | cache_modifier=".cg") # [head_dim, BLOCK_K] 79 | qk = tl.dot(my_q, cur_k, out_dtype=tl.float32) * softmax_scale # [BLOCK_Q, BLOCK_K] 80 | cur_k = None 81 | cur_v = tl.load(v_ptrs + k_block_start*(num_kv_heads*head_dim), 82 | mask = (k_block_start + tl.arange(0, BLOCK_K))[:, None] < my_seq_len, 83 | cache_modifier=".cg") # [BLOCK_K, head_dim] 84 | 85 | qk = tl.where( 86 | ((k_block_start + tl.arange(0, BLOCK_K)) < my_seq_len) & 87 | (range_my_q[:, None] >= (k_block_start + tl.arange(0, BLOCK_K))[None, :]), 88 | qk, 89 | float("-1e20") 90 | ) 91 | 92 | m_i_new = tl.maximum(m_i, tl.max(qk, 1)) 93 | alpha = tl.math.exp2(m_i - m_i_new) 94 | exp_qk = tl.math.exp2(qk - m_i_new[:, None]) 95 | 96 | m_i = m_i_new 97 | l_i = l_i*alpha + tl.sum(exp_qk, 1) 98 | acc = acc*alpha[:, None] + tl.dot(exp_qk.to(tl.float16), cur_v) 99 | 100 | tl.store(o + offs_my_q, acc / l_i[:, None], mask=range_my_q[:, None] < my_seq_len, cache_modifier=".cg") 101 | 102 | def prefill_attention( 103 | q: torch.Tensor, # [num_prefill_tokens, num_q_heads, head_dim] 104 | k: torch.Tensor, # [num_prefill_tokens, num_kv_heads, head_dim] 105 | v: torch.Tensor, # [num_prefill_tokens, num_kv_heads, head_dim] 106 | o: torch.Tensor, # [num_prefill_tokens, num_q_heads, head_dim] 107 | model_config: LlamaModelConfig, 108 | engine_config: EngineConfig, 109 | batch: SubBatch, 110 | ): 111 | is_rtx4090 = '4090' in torch.cuda.get_device_name(0) 112 | BLOCK_Q = 128 if not is_rtx4090 else 128 113 | BLOCK_K = 64 114 | 115 | # Here we reduce BLOCK_Q and BLOCK_K, since that when max_prefill_len is 116 | # small, large block size introduces unnecessary computation when computing 117 | # the attention score. 118 | # note: We restrict BLOCK_Q and BLOCK_K >= 16 due to a limitation proposed by tl.dot 119 | BLOCK_Q = min(BLOCK_Q, triton.next_power_of_2(max(batch.max_pref_toks, 16))) 120 | BLOCK_K = min(BLOCK_K, triton.next_power_of_2(max(batch.max_pref_toks, 16))) 121 | 122 | # Please refer to `paged_attn.py` for the reason of multiplying softmax_scale 123 | # by log2(e) 124 | softmax_scale2 = model_config.softmax_scale * 1.442695040888963 125 | 126 | assert BLOCK_Q % BLOCK_K == 0 127 | grid = (batch.num_prefs, model_config.num_q_heads, triton.cdiv(batch.max_pref_toks, BLOCK_Q)) 128 | num_warps = 8 129 | _fwd_prefill_attention[grid]( 130 | o, q, k, v, 131 | softmax_scale2, 132 | batch.pref_st_locs_we[:-1], 133 | batch.prgd_seq_lens[:batch.num_prefs], 134 | model_config.num_q_heads, model_config.num_kv_heads, 135 | model_config.num_q_heads // model_config.num_kv_heads, 136 | model_config.head_dim, 137 | BLOCK_Q, BLOCK_K, 138 | num_warps=num_warps, 139 | num_stages=3 140 | ) 141 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NEO: Saving GPU Memory Crisis with CPU Offloading for Online LLM Inference 2 | 3 | Online LLM inference powers many exciting applications such as intelligent chatbots and autonomous agents. 
Modern LLM inference engines widely rely on request batching to improve inference throughput, aiming to make serving cost-efficient on expensive GPU accelerators. However, limited GPU memory constrains the batch size achievable in practice, leaving significant GPU compute resources wasted. 4 | 5 | NEO is an online LLM inference system that offloads part of attention compute and KV cache states from the GPU to the local host CPU, effectively increasing the GPU batch size and thus inference throughput. To this end, NEO proposes asymmetric GPU-CPU pipelining and load-aware scheduling to balance GPU and CPU loads and fully utilize their compute and memory resources. Our MLSys'25 paper is [here](https://yangzhou1997.github.io/paper/neo_mlsys25.pdf). 6 | 7 | ## Requirements 8 | 9 | - Python >= 3.10 10 | - PyTorch >= 2.4 11 | 12 | - Two versions of g++ (see `pacpu/build.sh` for more details): 13 | 14 |   - one >= 13 (for compiling the CPU kernel) 15 |   - the other < 13 (for passing the NVCC version check) 16 | 17 | - Intel ISPC compiler == 1.23, which can be installed by `sudo snap install ispc --channel latest/edge` 18 | 19 | ## Installation 20 | 21 | 1. Clone the NEO repository and `cd` into the repo. 22 | 23 | 2. Install dependencies by `pip install -r requirements.txt`. 24 | 25 | 3. Install the SwiftLLM library into your local environment by `pip install -e .` 26 | 27 | 4. Build and install the auxiliary GPU operator library by `pip install -e csrc` 28 | 29 | 5. Build the CPU operator library by 30 | 31 | ```bash 32 | cd pacpu 33 | bash build.sh  34 | # e.g. bash build.sh llama2_7b 1 35 | cd .. 36 | ``` 37 | 38 | ## Offline Example 39 | 40 | ```bash 41 | cd NEO 42 | python examples/example.py --model-path ... --model-name ... 43 | # e.g. python examples/example.py --model-path /home/ubuntu/weights/Llama-2-7b-hf/ --model-name llama2_7b 44 | ``` 45 | 46 | Run `python examples/example.py --help` to see more options. 47 | 48 | ## Performance Results 49 | 50 | ### Load-latency Curves 51 | 52 | The figure below (Figure 6c in the paper) shows online latencies of NEO and other baselines under different request rates. 53 | 54 | vLLM-256 and vLLM-512 denote vLLM with chunked prefill at chunk sizes of 256 and 512 tokens, respectively. 55 | 56 | ![image-20250221101244560](docs/load-latency.png) 57 | 58 | - Hardware: AWS g4dn.4xlarge instance, with a Tesla T4 GPU, 8 cores of a Xeon P-8259CL CPU, and 64 GB main memory. 59 | - Model: LLaMa-2-7B 60 | - Workload: OpenAI summarization comparison ([CarperAI](https://huggingface.co/datasets/CarperAI/openai_summarize_comparisons)) 61 | 62 | ### Generation Throughput 63 | 64 | The figure below (Figure 10a in the paper) shows NEO's throughput gains over the non-CPU-offloading baseline under different workloads. NEO achieves up to 12.2%, 13.3%, 29.7%, and 79.3% higher throughput than the baseline under different CPU capacities. 65 | 66 | ![image-20250221101309717](docs/cpu-sensitivity.png) 67 | 68 | - Hardware: AWS g5.nxlarge instances (n=2,4,8,16), with an A10 GPU, 2n cores of an EPYC 7R32 CPU, and 16n GB main memory. 69 | - Model: LLaMa-3-8B 70 | - Workload: Synthetic workloads with various input and output lengths. For a pair of input length $l_i$ and output length $l_o$, we synthesize requests with input and output lengths sampled independently and uniformly from $[0.9l_i, 1.1l_i]$ and $[0.9l_o, 1.1l_o]$, respectively. Here we fix $l_i=1000$ and pick $l_o$ from $\{50, 100, 200, 300, 400\}$. The sampling scheme is sketched below.
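The snippet below is a minimal sketch of this sampling scheme (the helper name `sample_lengths` is illustrative; `evaluation/benchmark.py` implements the same idea in `_get_rand_array`):

```python
import random

def sample_lengths(n_reqs: int, avg_input: int = 1000, avg_output: int = 200, ratio: float = 0.1):
    """Sample (input_len, output_len) pairs uniformly from [avg*(1-ratio), avg*(1+ratio)]."""
    def sample(avg: int) -> int:
        delta = int(avg * ratio)
        return avg + random.randint(-delta, delta)
    return [(sample(avg_input), sample(avg_output)) for _ in range(n_reqs)]

# e.g. 100 requests with l_i = 1000 and l_o = 200
workload = sample_lengths(100, avg_input=1000, avg_output=200)
```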
71 | 72 | ## Reproduction 73 | 74 | Below are instructions for reproducing Figure 6c in the paper. Instructions for Figure 10a are the same except for specific details noted in parentheses. 75 | 76 | ### With an AWS Account 77 | 78 | 1. Launch a g4dn.4xlarge (g5.16xlarge) instance in the us-east-1 region with the community AMI neo-ae-g4-image (neo-ae-g5-image). 79 | 2. SSH to the instance and run `mamba activate neo` in the shell. 80 | 3. Run `cd NEO`. 81 | 4. Run `python evaluation/reproduce-fig6c.py` (`python evaluation/reproduce-fig10a.py`). 82 | 83 | > NOTE: Although the model weights are pre-packaged in the images, loading them for the first time takes about 1 hour. Therefore, it is recommended to download the weights from the internet and replace those embedded in the image, which usually takes less than 10 min. The following script can be used to retrieve the weights from Hugging Face: 84 | > 85 | > ```bash 86 | > cd ~ 87 | > rm -r weights/* 88 | > pip install 'huggingface_hub[cli]' 89 | > huggingface-cli login --token 90 | > # For g5 instance: 91 | > huggingface-cli download meta-llama/Llama-3.1-8B --local-dir weights/Llama-3-8B --exclude "*.pth" 92 | > # For g4 instance: 93 | > huggingface-cli download meta-llama/Llama-2-7b-hf --local-dir weights/Llama-2-7b-hf --exclude "*.pth" 94 | > ``` 95 | > 96 | > Alternatively, you may use the pre-packaged weights within the image. It is possible to encounter timeout issues during the initial execution of the evaluation script due to prolonged loading times. If this occurs, simply rerunning the script should resolve the issue. 97 | 98 | ### Without an AWS Account 99 | 100 | 1. Prepare a machine with: 101 | - an Nvidia Tesla T4 (A10G) GPU; 102 | - a CPU with AVX2 support; 103 | - at least 30 GB (120 GB) of main memory for the CPU KV cache; 104 | - Ubuntu >= 22.04. 105 | 2. Follow the steps in the Installation section to install dependencies. 106 | 3. Download the LLaMa-2-7B (LLaMa-3-8B) model weights. You can refer to the NOTE above for weight-retrieval scripts. 107 | 4. Modify the `model_path` entry in `evaluation/configs/config-t4-7b.json` (`evaluation/configs/config-a10-8b.json`) to the actual path of the model weights. 108 | 5. Run `python evaluation/reproduce-fig6c.py` (`python evaluation/reproduce-fig10a.py`) in the top-level directory of the NEO repository. 109 | 110 | ### Expected Results 111 | 112 | - The reproduced figure fig6c.pdf (fig10a.pdf) will be produced in the `evaluation` directory. 113 | - For Figure 6c, there will be only 2 lines (NEO and vLLM). By default the script only uses a small subset (100 requests) of the original input data (2000 requests) used in the original experiment. This subset is used for demonstration and quicker verification of the results. As a result, the measured latency is lower than in the original figure due to lower average queuing latency. 114 | - For Figure 10a, only 2 of the lines in the original figure (x16large and baseline) will be drawn. 115 | 116 | > NOTE: You can change the hyperparameters of the experiments by modifying the corresponding scripts. Please refer to comments in the code for detailed instructions. 117 | -------------------------------------------------------------------------------- /examples/example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Offline example of using the NEO engine to run inference on a model. 3 | 4 | Performance details are printed on the screen.
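Example invocation (the paths are illustrative):
    python examples/example.py --model-path /path/to/Llama-2-7b-hf --model-name llama2_7b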
5 | 6 | Note that this script is for demonstration purposes only and uses symmetric pipelining. In evaluation, we use asymmetric pipelining instead. 7 | """ 8 | import os 9 | import time 10 | import argparse 11 | from transformers import AutoTokenizer 12 | 13 | import swiftllm 14 | 15 | 16 | if __name__ == '__main__': 17 | script_dir = os.path.dirname(os.path.realpath(__file__)) 18 | repo_dir = os.path.dirname(script_dir) 19 | parser = argparse.ArgumentParser() 20 | parser.description = """ 21 | An example script to demonstrate how to use the NEO offline inference engine. 22 | """ 23 | parser.add_argument( 24 | "--model-path", 25 | help="Path to the model. Note: please download the model weights in advance and specify the path here.", 26 | type=str 27 | ) 28 | parser.add_argument( 29 | "--model-name", 30 | help="Name of the model in lowercase. Helps in loading CPU kernel library", 31 | type=str 32 | ) 33 | parser.add_argument( 34 | "--tp-degree", 35 | help="Tensor parallel degree", 36 | type=int, 37 | default=1 38 | ) 39 | parser.add_argument( 40 | "--profile-result-path", 41 | help="Path to folder of profiling results", 42 | type=str, 43 | default=f"{repo_dir}/profile_results/" 44 | ) 45 | parser.add_argument( 46 | "--num-gpu-blocks", 47 | help="Number of GPU blocks to use", 48 | type=int, 49 | default=50 50 | ) 51 | parser.add_argument( 52 | "--swap-space", 53 | help="CPU swap space in GB", 54 | type=int, 55 | default=2 56 | ) 57 | parser.add_argument( 58 | "--prompt-path", 59 | help="Path to the prompt file", 60 | type=str, 61 | default=f"{script_dir}/example.txt" 62 | ) 63 | parser.add_argument( 64 | "--num-gpu-requests", 65 | help="Number of GPU requests", 66 | type=int, 67 | default=2 68 | ) 69 | parser.add_argument( 70 | "--num-cpu-requests", 71 | help="Number of CPU requests", 72 | type=int, 73 | default=2 74 | ) 75 | parser.add_argument( 76 | "--monitor-performace", 77 | help="Performance monitoring switch", 78 | action="store_true", 79 | default=False 80 | ) 81 | args = parser.parse_args() 82 | 83 | 84 | # 1. Create the engine 85 | engine_config = swiftllm.EngineConfig( 86 | model_path = args.model_path, 87 | use_dummy = False, 88 | 89 | block_size = 16, 90 | gpu_mem_utilization = 0.99, 91 | num_gpu_blocks_override = args.num_gpu_blocks, 92 | swap_space = args.swap_space, 93 | max_seqs_in_block_table = 10, 94 | max_blocks_per_seq = 100, 95 | 96 | max_batch_size = 10, 97 | max_tokens_in_batch = 600, 98 | 99 | library_path=f"{repo_dir}/pacpu/build/libpacpu-{args.model_name}-tp{args.tp_degree}.so", 100 | profile_result_path=args.profile_result_path, 101 | 102 | extra_layer_for_cprf=True, 103 | tensor_parallel_degree=1 104 | ) 105 | 106 | start_time = time.perf_counter() 107 | engine = swiftllm.Engine(engine_config) 108 | engine.initialize() 109 | print(f"Engine creation time: {time.perf_counter() - start_time:.2f} seconds") 110 | 111 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 112 | tokenizer = AutoTokenizer.from_pretrained(args.model_path) 113 | 114 | 115 | # 2. Load the prompt and tokenize 116 | ngpu_prompts = args.num_gpu_requests 117 | ncpu_prompts = args.num_cpu_requests 118 | nprompts = ncpu_prompts + ngpu_prompts 119 | with open(args.prompt_path, "r") as f: 120 | prompt = ''.join(f.readlines()) 121 | 122 | input_ids = tokenizer(prompt)['input_ids'] 123 | print("Prompt token length: ", len(input_ids)) 124 | 125 | 126 | # 3. 
Prefill the prompts 127 | reqs = [None] * nprompts 128 | gpu_req_ids = list(range(ngpu_prompts // 2)) + list(range(nprompts // 2, nprompts // 2 + ngpu_prompts // 2)) 129 | gpu_reqs = [] 130 | if ngpu_prompts: 131 | batch = swiftllm.SubBatch() 132 | for i in gpu_req_ids: 133 | reqs[i] = swiftllm.create_request(input_ids, i) 134 | batch.add_pref(reqs[i], is_gpu=True) 135 | gpu_reqs = [reqs[i] for i in gpu_req_ids] 136 | engine.step([batch]) 137 | 138 | if ncpu_prompts: 139 | batch = swiftllm.SubBatch() 140 | for i in range(ngpu_prompts // 2, nprompts // 2): 141 | reqs[i] = swiftllm.create_request(input_ids, i) 142 | batch.add_pref(reqs[i], is_gpu=False) 143 | engine.step([batch]) 144 | 145 | batch = swiftllm.SubBatch() 146 | for i in range(nprompts // 2 + ngpu_prompts // 2, nprompts): 147 | reqs[i] = swiftllm.create_request(input_ids, i) 148 | batch.add_pref(reqs[i], is_gpu=False) 149 | engine.step([batch]) 150 | 151 | print("Prefilling phase done") 152 | 153 | 154 | # 4. Run the inference 155 | if args.monitor_performace: 156 | engine.executor.turn_on_perf_monitor() 157 | 158 | for iteration in range(16): 159 | batches = [swiftllm.SubBatch() for _ in range(2)] 160 | for i in range(ngpu_prompts // 2): 161 | batches[0].add_gdec(reqs[i]) 162 | for i in range(ngpu_prompts // 2, nprompts // 2): 163 | batches[1].add_cdec(reqs[i]) 164 | for i in range(nprompts // 2, nprompts // 2 + ngpu_prompts // 2): 165 | batches[1].add_gdec(reqs[i]) 166 | for i in range(nprompts // 2 + ngpu_prompts // 2, nprompts): 167 | batches[0].add_cdec(reqs[i]) 168 | 169 | # Un-comment the following 4 lines to run mixed batches 170 | # reqs.append(swiftllm.create_request(input_ids, len(reqs))) 171 | # reqs.append(swiftllm.create_request(input_ids, len(reqs))) 172 | # batches[0].add_pref(reqs[-2], is_gpu=False) 173 | # batches[1].add_pref(reqs[-1], is_gpu=False) 174 | 175 | start = time.perf_counter() 176 | engine.step(batches) 177 | end = time.perf_counter() 178 | print(f"Iteration {iteration:3} E2E time: {(end - start) * 1000:.4f} ms") 179 | 180 | for i in range(nprompts): 181 | if i in (0, nprompts // 2 - 1, nprompts - 1): 182 | output_text = tokenizer.decode(reqs[i].output_token_ids, skip_special_tokens=True) 183 | print(f"{prompt}|{output_text}") 184 | 185 | if args.monitor_performace: 186 | res = engine.executor.turn_off_perf_monitor_and_flush_results() 187 | print(res) 188 | -------------------------------------------------------------------------------- /swiftllm/worker/kernels/block_mgmt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | 5 | @triton.jit 6 | def _fwd_set_block_table_and_num_seq_alloc_blocks_kernel( 7 | block_table: torch.Tensor, # [max_seqs_in_block_table, max_blocks_per_seq] 8 | candidate_blocks: torch.Tensor, # [sum(block_needed)] 9 | seq_ids: torch.Tensor, # [batch_size] 10 | num_seq_allocated_blocks: torch.Tensor, # [max_seqs_in_block_table] 11 | block_needed: torch.Tensor, # [batch_size] 12 | block_needed_cumsum: torch.Tensor, # [batch_size] 13 | max_blocks_per_seq: tl.constexpr 14 | ): 15 | # grid shape: [batch_size] 16 | my_batch_id = tl.program_id(0) 17 | my_seq_id = tl.load(seq_ids + my_batch_id).to(tl.int64) 18 | my_block_needed = tl.load(block_needed + my_batch_id) 19 | my_candidate_block_start_index = tl.load(block_needed_cumsum + my_batch_id) - my_block_needed 20 | my_num_allocated_blocks = tl.load(num_seq_allocated_blocks + my_seq_id) 21 | for i in range(my_block_needed): 22 | 
my_block_id = tl.load(candidate_blocks + my_candidate_block_start_index + i) 23 | tl.store(block_table + my_seq_id * max_blocks_per_seq + my_num_allocated_blocks + i, my_block_id) 24 | tl.store(num_seq_allocated_blocks + my_seq_id, my_num_allocated_blocks + my_block_needed) 25 | 26 | def set_block_table_and_num_seq_alloc_blocks( 27 | num_seq_allocated_blocks: torch.Tensor, # [max_seqs_in_block_table] 28 | block_table: torch.Tensor, # [max_seqs_in_block_table, max_blocks_per_seq] 29 | candidate_blocks: torch.Tensor, # [sum(block_needed)] 30 | seq_ids: torch.Tensor, # [batch_size] 31 | block_needed: torch.Tensor, # [batch_size] 32 | ): 33 | """ 34 | Set block_table and num_seq_allocated_blocks 35 | 36 | For the ith sequence in the batch which has seq_id s: 37 | - Set block_table[s][num_seq_allocated_block[s]: num_seq_allocated_block[s] + block_needed[i]] to 38 | candidate_blocks[block_needed_cumsum[i-1]: block_needed_cumsum[i]] 39 | - Set num_seq_allocated_blocks[s] to num_seq_allocated_blocks[s] + block_needed[i] 40 | """ 41 | block_needed_cumsum = torch.cumsum(block_needed, 0) 42 | max_blocks_per_seq = block_table.shape[1] 43 | if block_table.is_cuda: 44 | grid = (block_needed.shape[0], ) 45 | _fwd_set_block_table_and_num_seq_alloc_blocks_kernel[grid]( 46 | block_table, candidate_blocks, seq_ids, num_seq_allocated_blocks, block_needed, block_needed_cumsum, max_blocks_per_seq 47 | ) 48 | else: 49 | for i, n in enumerate(block_needed): 50 | seq_id = seq_ids[i] 51 | bt_off = num_seq_allocated_blocks[seq_id] 52 | cb_off = block_needed_cumsum[i] - n 53 | block_table[seq_id, bt_off: bt_off + n] = candidate_blocks[cb_off: cb_off + n] 54 | num_seq_allocated_blocks[seq_id] += n 55 | 56 | 57 | @triton.jit 58 | def _fwd_unset_block_table_and_num_seq_alloc_blocks_kernel( 59 | num_seq_allocated_blocks: torch.Tensor, # [max_seqs_in_block_table] 60 | block_table: torch.Tensor, # [max_seqs_in_block_table, max_blocks_per_seq] 61 | seq_ids: torch.Tensor, # [batch_size] 62 | is_block_free: torch.Tensor, # [num_blocks], bool 63 | max_blocks_per_seq: tl.constexpr 64 | ): 65 | # grid shape: [batch_size] 66 | my_batch_id = tl.program_id(0) 67 | my_seq_id = tl.load(seq_ids + my_batch_id) 68 | my_num_blocks = tl.load(num_seq_allocated_blocks + my_seq_id) 69 | for i in range(my_num_blocks): 70 | my_block_id = tl.load(block_table + my_seq_id * max_blocks_per_seq + i) 71 | tl.store(is_block_free + my_block_id, True) 72 | tl.store(num_seq_allocated_blocks + my_seq_id, 0) 73 | 74 | def unset_block_table_and_num_seq_alloc_blocks( 75 | num_seq_allocated_blocks: torch.Tensor, # [max_seqs_in_block_table] 76 | block_table: torch.Tensor, # [max_seqs_in_block_table, max_blocks_per_seq] 77 | seq_ids: torch.Tensor, # [batch_size] 78 | is_block_free: torch.Tensor, # [num_blocks], bool 79 | ): 80 | """ 81 | Mark the blocks allocated for the specified sequences in the `is_block_free` 82 | as free, and set corresponding num_seq_allocated_blocks to 0 83 | """ 84 | max_blocks_per_seq = block_table.shape[1] 85 | grid = (seq_ids.shape[0], ) 86 | if block_table.is_cuda: 87 | _fwd_unset_block_table_and_num_seq_alloc_blocks_kernel[grid]( 88 | num_seq_allocated_blocks, block_table, seq_ids, is_block_free, max_blocks_per_seq 89 | ) 90 | else: 91 | for seq_id in seq_ids: 92 | n = num_seq_allocated_blocks[seq_id] 93 | is_block_free[block_table[seq_id, :n]] = True 94 | num_seq_allocated_blocks[seq_id] = 0 95 | 96 | 97 | @triton.jit 98 | def _fwd_gather_allocated_blocks_and_unset_kernel( 99 | num_seq_allocated_blocks: torch.Tensor, # 
[max_seqs_in_block_table] 100 | block_table: torch.Tensor, # [max_seqs_in_block_table, max_blocks_per_seq] 101 | seq_ids: torch.Tensor, # [batch_size] 102 | is_block_free: torch.Tensor, # [num_blocks], bool 103 | 104 | num_allocated_blocks_cumsum: torch.Tensor, # [batch_size] 105 | gathered_block_ids: torch.Tensor, # [sum(num_seq_allocated_blocks[seq_ids])] 106 | 107 | max_blocks_per_seq: tl.constexpr 108 | ): 109 | # grid shape: [batch_size] 110 | my_batch_id = tl.program_id(0) 111 | my_seq_id = tl.load(seq_ids + my_batch_id) 112 | my_num_blocks = tl.load(num_seq_allocated_blocks + my_seq_id) 113 | my_num_allocated_blocks_cumsum = tl.load(num_allocated_blocks_cumsum+my_batch_id-1, mask=my_batch_id>0, other=0) 114 | for i in range(my_num_blocks): 115 | my_block_id = tl.load(block_table + my_seq_id * max_blocks_per_seq + i) 116 | tl.store(gathered_block_ids + my_num_allocated_blocks_cumsum + i, my_block_id) 117 | tl.store(is_block_free + my_block_id, True) 118 | tl.store(num_seq_allocated_blocks + my_seq_id, 0) 119 | 120 | def gather_allocated_blocks_and_unset( 121 | num_seq_allocated_blocks: torch.Tensor, # [max_seqs_in_block_table] 122 | block_table: torch.Tensor, # [max_seqs_in_block_table, max_blocks_per_seq] 123 | seq_ids: torch.Tensor, # [batch_size] 124 | is_block_free: torch.Tensor, # [num_blocks], bool 125 | ) -> torch.Tensor: 126 | """ 127 | Gather the block IDs allocated for the specified sequences and mark them as free 128 | """ 129 | num_allocated_blocks_cumsum = torch.cumsum(num_seq_allocated_blocks[seq_ids], 0) 130 | gathered_block_ids = torch.empty((num_allocated_blocks_cumsum[-1].item(),), dtype=torch.int32, device=block_table.device) 131 | 132 | max_blocks_per_seq = block_table.shape[1] 133 | grid = (seq_ids.shape[0], ) 134 | if block_table.is_cuda: 135 | _fwd_gather_allocated_blocks_and_unset_kernel[grid]( 136 | num_seq_allocated_blocks, block_table, seq_ids, is_block_free, 137 | num_allocated_blocks_cumsum, gathered_block_ids, max_blocks_per_seq 138 | ) 139 | else: 140 | for i, seq_id in enumerate(seq_ids): 141 | n = num_seq_allocated_blocks[seq_id] 142 | gb_off = num_allocated_blocks_cumsum[i] - n 143 | gathered_block_ids[gb_off: gb_off + n] = block_table[seq_id, :n] 144 | is_block_free[block_table[seq_id, :n]] = True 145 | num_seq_allocated_blocks[seq_id] = 0 146 | 147 | return gathered_block_ids 148 | -------------------------------------------------------------------------------- /pacpu/pacpu.ispc: -------------------------------------------------------------------------------- 1 | /* ISPC kernels for paged attention algorithm on CPU */ 2 | #include "dtype.h" 3 | 4 | #define K_TILE_WIDTH 2 5 | export void qk_product( 6 | const uniform int cur_layer, 7 | const uniform int num_blocks, 8 | const uniform int seq_len, 9 | 10 | const uniform data_t q[], // [NUM_Q_HEADS, HEAD_DIM] 11 | const uniform data_t k_cache[], // [..., NUM_LAYERS, NUM_KV_HEADS, BLOCK_SIZE, HEAD_DIM] 12 | const uniform int block_table[], // [seq_len] 13 | 14 | uniform itmd_t a[] // [seq_len, NUM_KV_HEADS, HEAD_DIM] 15 | ) { 16 | uniform int imax = seq_len / BLOCK_SIZE + 1; 17 | for (uniform int i = 0; i < imax; i++) { 18 | const uniform data_t * k = k_cache + 19 | (1ll * cur_layer * num_blocks + block_table[i]) * BLOCK_NELEM; 20 | for (uniform int j = 0; j < NUM_KV_HEADS; j++) { 21 | uniform int q_off = j * QH_PER_KVH * HEAD_DIM; 22 | uniform int k_off = j * BLOCK_SIZE * HEAD_DIM; 23 | uniform int a_off = i * BLOCK_SIZE * NUM_Q_HEADS + j * QH_PER_KVH; 24 | uniform int tmax = min(BLOCK_SIZE, seq_len - 
i * BLOCK_SIZE); 25 | uniform int t; 26 | for (t = 0; t < tmax - K_TILE_WIDTH + 1; t += K_TILE_WIDTH) { 27 | itmd_t sum[QH_PER_KVH][K_TILE_WIDTH]; 28 | for (uniform int h = 0; h < QH_PER_KVH; h++) { 29 | for (uniform int g = 0; g < K_TILE_WIDTH; g++) { 30 | sum[h][g] = 0; 31 | } 32 | } 33 | foreach (l = 0 ... HEAD_DIM) { 34 | for (uniform int g = 0; g < K_TILE_WIDTH; g++) { 35 | itmd_t k_val = k[k_off + g * HEAD_DIM + l]; 36 | for (uniform int h = 0; h < QH_PER_KVH; h++) { 37 | sum[h][g] += q[q_off + h * HEAD_DIM + l] * k_val; 38 | } 39 | } 40 | } 41 | for (uniform int h = 0; h < QH_PER_KVH; h++) { 42 | for (uniform int g = 0; g < K_TILE_WIDTH; g++) { 43 | a[a_off + (t + g) * NUM_Q_HEADS + h] = reduce_add(sum[h][g]); 44 | } 45 | } 46 | k_off += HEAD_DIM * K_TILE_WIDTH; 47 | } 48 | for (; t < tmax; t++) { 49 | itmd_t sum[QH_PER_KVH]; 50 | for (uniform int h = 0; h < QH_PER_KVH; h++) { 51 | sum[h] = 0; 52 | } 53 | foreach (l = 0 ... HEAD_DIM) { 54 | itmd_t k_val = k[k_off + l]; 55 | for (uniform int h = 0; h < QH_PER_KVH; h++) { 56 | sum[h] += q[q_off + h * HEAD_DIM + l] * k_val; 57 | } 58 | } 59 | for (uniform int h = 0; h < QH_PER_KVH; h++) { 60 | a[a_off + t * NUM_Q_HEADS + h] = reduce_add(sum[h]); 61 | } 62 | k_off += HEAD_DIM; 63 | } 64 | } 65 | } 66 | } 67 | 68 | export void av_product( 69 | const uniform int cur_layer, 70 | const uniform int num_blocks, 71 | const uniform int seq_len, 72 | 73 | const uniform itmd_t a[], // [seq_len, NUM_KV_HEADS, HEAD_DIM] 74 | const uniform data_t v_cache[], // [..., NUM_LAYERS, NUM_KV_HEADS, BLOCK_SIZE, HEAD_DIM] 75 | const uniform int block_table[], // [seq_len] 76 | 77 | uniform otpt_t o[] // [NUM_Q_HEADS, HEAD_DIM] 78 | ) { 79 | uniform int imax = seq_len / BLOCK_SIZE + 1; 80 | 81 | memset(o, 0, NUM_Q_HEADS * HEAD_DIM * sizeof(uniform otpt_t)); 82 | for (uniform int i = 0; i < imax; i++) { 83 | const uniform data_t * uniform v = v_cache + 84 | (1ll * cur_layer * num_blocks + block_table[i]) * BLOCK_NELEM; 85 | for (uniform int j = 0; j < NUM_KV_HEADS; j++) { 86 | uniform int o_off = j * QH_PER_KVH * HEAD_DIM; 87 | uniform int v_off = j * BLOCK_SIZE * HEAD_DIM; 88 | uniform int tmax = min(BLOCK_SIZE, seq_len - i * BLOCK_SIZE); 89 | for (uniform int t = 0; t < tmax; t++) { 90 | uniform int a_off = (i * BLOCK_SIZE + t) * NUM_Q_HEADS + j * QH_PER_KVH; 91 | foreach (l = 0 ... HEAD_DIM) { 92 | otpt_t v_val = v[v_off + l]; 93 | for (uniform int h = 0; h < QH_PER_KVH; h++) { 94 | o[o_off + h * HEAD_DIM + l] += v_val * a[a_off + h]; 95 | } 96 | } 97 | v_off += HEAD_DIM; 98 | } 99 | } 100 | } 101 | } 102 | 103 | void softmax( 104 | const uniform int seq_len, 105 | const uniform itmd_t softmax_scale, 106 | 107 | uniform itmd_t a[], // [seq_len, NUM_Q_HEADS] 108 | uniform itmd_t asb[] // [NUM_Q_HEADS] 109 | ) { 110 | uniform itmd_t amb[NUM_Q_HEADS]; 111 | 112 | foreach (h = 0 ... NUM_Q_HEADS) { 113 | amb[h] = -1e20; 114 | for (uniform int i = 0; i < seq_len; i++) { 115 | a[i * NUM_Q_HEADS + h] *= softmax_scale; 116 | amb[h] = max(amb[h], a[i * NUM_Q_HEADS + h]); 117 | } 118 | } 119 | 120 | foreach (h = 0 ... NUM_Q_HEADS) { 121 | asb[h] = 0; 122 | for (uniform int i = 0; i < seq_len; i++) { 123 | a[i * NUM_Q_HEADS + h] = exp(a[i * NUM_Q_HEADS + h] - amb[h]); 124 | asb[h] += a[i * NUM_Q_HEADS + h]; 125 | } 126 | } 127 | 128 | foreach (h = 0 ... 
NUM_Q_HEADS) { 129 | for (uniform int i = 0; i < seq_len; i++) { 130 | a[i * NUM_Q_HEADS + h] /= asb[h]; 131 | } 132 | asb[h] = log(asb[h]) + amb[h]; 133 | } 134 | } 135 | 136 | export void attn_one_seq( 137 | const uniform int cur_layer, 138 | const uniform int num_blocks, 139 | const uniform int seq_len, 140 | const uniform itmd_t softmax_scale, 141 | 142 | const uniform data_t q[], // [NUM_Q_HEADS, HEAD_DIM] 143 | const uniform data_t k_cache[], // [..., num_blocks, NUM_KV_HEADS, BLOCK_SIZE, HEAD_DIM] 144 | const uniform data_t v_cache[], // [..., num_blocks, NUM_KV_HEADS, BLOCK_SIZE, HEAD_DIM] 145 | const uniform int block_table[], // [seq_len] 146 | 147 | uniform itmd_t a[], // [seq_len, NUM_KV_HEADS, HEAD_DIM] 148 | uniform otpt_t o[], // [NUM_Q_HEADS, HEAD_DIM] 149 | uniform itmd_t asb[] // [NUM_Q_HEADS] 150 | ) { 151 | qk_product(cur_layer, num_blocks, seq_len, q, k_cache, block_table, a); 152 | softmax(seq_len, softmax_scale, a, asb); 153 | av_product(cur_layer, num_blocks, seq_len, a, v_cache, block_table, o); 154 | } 155 | 156 | export void gather_output_one_seq( 157 | const uniform int num_blocks, 158 | const uniform otpt_t o_buf[], // [num_blocks, NUM_Q_HEADS, HEAD_DIM] 159 | 160 | uniform itmd_t as_buf[], // [num_blocks, NUM_Q_HEADS] 161 | uniform otpt_t o[] // [NUM_Q_HEADS, HEAD_DIM] 162 | ) { 163 | uniform itmd_t as_all[NUM_Q_HEADS]; 164 | uniform itmd_t am_all[NUM_Q_HEADS]; 165 | 166 | foreach(h = 0 ... NUM_Q_HEADS) { 167 | am_all[h] = -1e20; 168 | for (uniform int i = 0; i < num_blocks; i++) { 169 | am_all[h] = max(am_all[h], as_buf[i * NUM_Q_HEADS + h]); 170 | } 171 | } 172 | 173 | foreach(h = 0 ... NUM_Q_HEADS) { 174 | as_all[h] = 0; 175 | for (uniform int i = 0; i < num_blocks; i++) { 176 | as_buf[i * NUM_Q_HEADS + h] = exp(as_buf[i * NUM_Q_HEADS + h] - am_all[h]); 177 | as_all[h] += as_buf[i * NUM_Q_HEADS + h]; 178 | } 179 | } 180 | 181 | foreach(h = 0 ... NUM_Q_HEADS) { 182 | for (uniform int i = 0; i < num_blocks; i++) { 183 | as_buf[i * NUM_Q_HEADS + h] /= as_all[h]; 184 | } 185 | } 186 | 187 | memset(o, 0, NUM_Q_HEADS * HEAD_DIM * sizeof(uniform otpt_t)); 188 | for (uniform int i = 0; i < num_blocks; i++) { 189 | uniform int o_off = 0; 190 | uniform int o_buf_off = i * NUM_Q_HEADS * HEAD_DIM; 191 | for (uniform int j = 0; j < NUM_Q_HEADS; j++) { 192 | uniform itmd_t scale = as_buf[i * NUM_Q_HEADS + j]; 193 | foreach (l = 0 ... HEAD_DIM) { 194 | o[o_off + l] += o_buf[o_buf_off + l] * scale; 195 | } 196 | o_off += HEAD_DIM; 197 | o_buf_off += HEAD_DIM; 198 | } 199 | } 200 | } -------------------------------------------------------------------------------- /swiftllm/perfpredictor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Performance predictor for the SwiftLLM engine. 3 | """ 4 | 5 | from swiftllm.engine_config import EngineConfig 6 | 7 | class PerfPredictor: 8 | """ 9 | Base class for performance predictors. 
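    Subclasses expose per-iteration timing estimates: linear-layer time
    (`get_linr_T`), GPU prefill time (`get_pref_T`), GPU decoding time
    (`get_gdec_T`), CPU decoding time (`get_cdec_T`), and kernel-launch time
    (`get_lnch_T`). `BatchPerfData` in `swiftllm.structs` combines these into
    per-batch GPU/CPU time estimates. A rough usage sketch (the widths/token
    counts `S`, `N_gpu`, `S_cpu`, `N_cpu` are placeholders, and the lookup
    tables of `TablePerfPredictor`, defined later in this module, must first
    be filled in by the profiler):

        pred = TablePerfPredictor(engine_config)
        gpu_time = pred.get_linr_T(S) + pred.get_pref_T(S) + pred.get_gdec_T(N_gpu)
        cpu_time = pred.get_cdec_T(S_cpu, N_cpu) + pred.get_lnch_T()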
10 | """ 11 | def __init__( 12 | self, *args 13 | ): 14 | raise NotImplementedError 15 | 16 | def get_linr_T(self, S: int) -> float: 17 | """ 18 | Get the linear time for iteration width S 19 | """ 20 | raise NotImplementedError 21 | 22 | def get_pref_T(self, S: int) -> float: 23 | """ 24 | Get the GPU prefilling time for iteration width S 25 | """ 26 | raise NotImplementedError 27 | 28 | def get_gdec_T(self, N: int) -> float: 29 | """ 30 | Get the GPU decoding time for number of tokens N 31 | """ 32 | raise NotImplementedError 33 | 34 | def get_cdec_T(self, S: int, N: int) -> float: 35 | """ 36 | Get the CPU decoding time for iteration width S and number of tokens N 37 | """ 38 | raise NotImplementedError 39 | 40 | def get_lnch_T(self) -> float: 41 | """ 42 | Get the kernel launch time 43 | """ 44 | raise NotImplementedError 45 | 46 | class ZeroPerfPredictor(PerfPredictor): 47 | """ 48 | A dummy performance predictor that always returns 0. 49 | """ 50 | def __init__( 51 | self, *args 52 | ): 53 | pass 54 | 55 | def get_linr_T(self, S: int) -> float: 56 | return 0.0 57 | 58 | def get_pref_T(self, S: int) -> float: 59 | return 0.0 60 | 61 | def get_gdec_T(self, N: int) -> float: 62 | return 0.0 63 | 64 | def get_cdec_T(self, S: int, N: int) -> float: 65 | return 0.0 66 | 67 | def get_lnch_T(self) -> float: 68 | return 0.0 69 | 70 | class TablePerfPredictor(PerfPredictor): 71 | """ 72 | A perfomance predictor that uses a table to store the performance data. 73 | 74 | It uses linear interpolation to estimate the performance for unseen data. 75 | """ 76 | def __init__( 77 | self, 78 | engine_config: EngineConfig 79 | ): 80 | # Linr 81 | self.linr_S_list = list(range(1, 512)) + [ 82 | 2 ** i for i in range( 83 | 9, 84 | (engine_config.max_tokens_in_batch - 1).bit_length() 85 | ) 86 | ] + [engine_config.max_tokens_in_batch] 87 | self.linr_T_list = None 88 | self.linr_S_lb_idx = self._get_lb_idx_list(self.linr_S_list) 89 | self.linr_S_threshold = 128 # NOTE: This is a heuristic value 90 | 91 | # Pref 92 | self.pref_S_list = sum([[2 ** (i-2) * 3, 2 ** i] for i in range( 93 | (engine_config.block_size - 1).bit_length(), 94 | (engine_config.max_tokens_in_batch - 1).bit_length() 95 | )], []) + [engine_config.max_tokens_in_batch] 96 | self.pref_T_list = None 97 | self.pref_S_lb_idx = self._get_lb_idx_list(self.pref_S_list) 98 | 99 | # Gdec 100 | self.gdec_N_list = sum([[2 ** (i-2) * 3, 2 ** i] for i in range( 101 | (engine_config.block_size - 1).bit_length(), 102 | (engine_config.max_gpu_tokens - 1).bit_length() 103 | )], []) + [engine_config.max_gpu_tokens] 104 | self.gdec_T_list = None 105 | self.gdec_N_lb_idx = self._get_lb_idx_list(self.gdec_N_list) 106 | 107 | # Cdec 108 | self.cdec_S_list = [2 ** i for i in range( 109 | 0, 110 | (engine_config.max_batch_size - 1).bit_length() 111 | )] + [engine_config.max_batch_size] 112 | self.cdec_N_lists = [ 113 | [S * engine_config.block_size] + 114 | [2 ** i for i in range( 115 | (S * engine_config.block_size).bit_length(), 116 | (min(S * engine_config.max_seq_len, engine_config.max_cpu_tokens) - 1).bit_length() 117 | )] + 118 | [min(S * engine_config.max_seq_len, engine_config.max_cpu_tokens)] 119 | for S in self.cdec_S_list 120 | ] 121 | self.cdec_N_list_agg = sorted(list(set(sum(self.cdec_N_lists, [])))) 122 | 123 | self.cdec_T_lists = [None] 124 | self.cdec_S_lb_idx = self._get_lb_idx_list(self.cdec_S_list) 125 | self.cdec_N_lb_idx = self._get_lb_idx_list(self.cdec_N_list_agg) 126 | 127 | # Lnch 128 | self.lnch_T = 0.8 129 | # self.lnch_T = 
self._profile_lnch(lnch_S_list) 130 | 131 | def _get_lb_idx_list(self, input_list: list[int]) -> list[int]: 132 | """ 133 | Get the lower bound index list of x in the input list. 134 | 135 | Given i, find the smallest j s.t. input_list[j] >= i. 136 | """ 137 | return sum( 138 | [[i+1] * (input_list[i+1] - input_list[i]) for i in range(len(input_list) - 1)], 139 | [0] * (input_list[0] + 1) 140 | ) 141 | 142 | def _interp(self, x: int, x0: int, x1: int, y0: float, y1: float) -> float: 143 | """ 144 | Linear interpolation of 2 points (x0, y0) and (x1, y1) at x. 145 | """ 146 | return y0 + (y1 - y0) * (x - x0) / (x1 - x0) 147 | 148 | def _interp_1d(self, x, xs: list[int], ys: list[float], x_lb_idx: list[int]) -> float: 149 | """ 150 | Linear interpolation of 1D points (x_list, y_list) at x. Assume x <= x_list[-1]. 151 | """ 152 | assert x <= xs[-1], f"x={x} exceeds the maximum {xs[-1]}" 153 | if x == 0: 154 | return 0.0 155 | idx = x_lb_idx[x] 156 | if idx == 0 or x == xs[idx]: 157 | return ys[idx] 158 | return self._interp(x, xs[idx-1], xs[idx], ys[idx-1], ys[idx]) 159 | 160 | def get_linr_T(self, S: int) -> float: 161 | """ 162 | Get the linear time for iteration width S, using linear interpolation 163 | """ 164 | return self._interp_1d(S, self.linr_S_list, self.linr_T_list, self.linr_S_lb_idx) 165 | 166 | def get_pref_T(self, S: int) -> float: 167 | """ 168 | Get the GPU prefilling time for iteration width S, using linear interpolation 169 | """ 170 | return self._interp_1d(S, self.pref_S_list, self.pref_T_list, self.pref_S_lb_idx) 171 | 172 | def get_gdec_T(self, N: int) -> float: 173 | """ 174 | Get the GPU decoding time for number of tokens N, using linear interpolation 175 | """ 176 | return self._interp_1d(N, self.gdec_N_list, self.gdec_T_list, self.gdec_N_lb_idx) 177 | 178 | def get_cdec_T(self, S: int, N: int) -> float: 179 | """ 180 | Get the CPU decoding time for iteration width S and number of tokens N, 181 | using bilinear interpolation 182 | """ 183 | assert S < len(self.cdec_S_lb_idx), f"CPU batch size {S} exceeds the maximum {len(self.cdec_S_lb_idx)}" 184 | if S == 0: 185 | return 0.0 186 | s_idx = self.cdec_S_lb_idx[S] 187 | if s_idx == 0 or S == self.cdec_S_list[s_idx]: 188 | return self._interp_1d(N, self.cdec_N_list_agg, self.cdec_T_lists[s_idx], self.cdec_N_lb_idx) 189 | s1 = self.cdec_S_list[s_idx] 190 | s0 = self.cdec_S_list[s_idx - 1] 191 | ts1 = self._interp_1d(N, self.cdec_N_list_agg, self.cdec_T_lists[s_idx], self.cdec_N_lb_idx) 192 | ts0 = self._interp_1d(N, self.cdec_N_list_agg, self.cdec_T_lists[s_idx - 1], self.cdec_N_lb_idx) 193 | return self._interp(S, s0, s1, ts0, ts1) 194 | 195 | def get_lnch_T(self) -> float: 196 | return self.lnch_T 197 | -------------------------------------------------------------------------------- /swiftllm/server/engine.py: -------------------------------------------------------------------------------- 1 | """ 2 | The main engine of the server 3 | """ 4 | 5 | import time 6 | import sys 7 | import asyncio 8 | import functools 9 | import logging 10 | from typing import AsyncGenerator 11 | 12 | from swiftllm.engine_config import EngineConfig 13 | from swiftllm.model_config import LlamaModelConfig 14 | from swiftllm.server.executor import SingleProcExecutor, RayExecutor 15 | from swiftllm.server.profiler import ModelProfiler 16 | from swiftllm.structs import Request, RawRequest, StepOutput, SubBatch 17 | 18 | from swiftllm.server.tokenization_engine import TokenizationEngine 19 | from swiftllm.server.scheduler import Scheduler 20 | from 
swiftllm.server.block_manager import BlockManager 21 | 22 | logger = logging.getLogger(__name__) 23 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, datefmt='%Y-%m-%d %H:%M:%S') 24 | 25 | class Engine: 26 | """ 27 | Offline version of the engine, need to tokenize manually 28 | """ 29 | 30 | def __init__(self, engine_config: EngineConfig): 31 | self.engine_config = engine_config 32 | self.model_config = LlamaModelConfig.load_from_model_path(engine_config.model_path) 33 | self.initialized = False 34 | 35 | assert engine_config.max_batch_size <= engine_config.max_tokens_in_batch, \ 36 | f"max_batch_size {engine_config.max_batch_size} exceeds max_tokens_in_batch {engine_config.max_tokens_in_batch}" 37 | assert engine_config.max_batch_size <= engine_config.max_seqs_in_block_table, \ 38 | f"max_batch_size {engine_config.max_batch_size} exceeds max_seqs_in_block_table {engine_config.max_seqs_in_block_table}" 39 | assert engine_config.tensor_parallel_degree >= 1, "Tensor parallel degree should be positive" 40 | 41 | # The following fields will be created on `initialize()` 42 | self.executor = None 43 | self.event_loop = None 44 | self.profiler = None 45 | self.block_manager = None 46 | self.executor_class = SingleProcExecutor if engine_config.tensor_parallel_degree == 1 else RayExecutor 47 | 48 | 49 | def initialize(self): 50 | """ 51 | Initialize the engine 52 | """ 53 | logger.info("Initializing model...") 54 | self.executor = self.executor_class(self.engine_config, self.model_config) 55 | 56 | logger.info("Profiling model...") 57 | self.profiler = ModelProfiler(self.executor) 58 | self.profiler.profile_num_blocks() 59 | 60 | logger.info("Initializing block manager...") 61 | self.block_manager = BlockManager(self.engine_config, self.model_config) 62 | 63 | logger.info("Initializing KV cache and swap...") 64 | self.executor.init_kvcache_and_swap() 65 | 66 | logger.info("Model initialized") 67 | self.initialized = True 68 | 69 | 70 | def step(self, batches: list[SubBatch], cur_swap_out: list[Request]=None, cur_swap_in: list[Request]=None): 71 | """ 72 | Perform a step of the engine 73 | """ 74 | forward_args = self.block_manager.prepare(batches, cur_swap_out or [], cur_swap_in or []) 75 | output_token_ids = self.executor.do_one_iteration(batches, *forward_args) 76 | self.block_manager.update_and_free(batches, output_token_ids) 77 | 78 | 79 | 80 | class AsyncEngine(Engine): 81 | """ 82 | The main engine of the server 83 | """ 84 | 85 | def __init__(self, engine_config: EngineConfig): 86 | super().__init__(engine_config) 87 | 88 | # The following fields will be created on `init_model()` 89 | self.scheduler = None 90 | self.tokenization_engine = None 91 | 92 | self.untokenized_raw_requests: list[tuple[Request, str]] = [] 93 | 94 | 95 | async def _run_on_model_executor_async(self, func, *args, **kwargs): 96 | """ 97 | Run a function on the model asynchronously, and return the result 98 | """ 99 | func_partial = functools.partial(func, *args, **kwargs) 100 | return await self.event_loop.run_in_executor(None, func_partial) 101 | 102 | 103 | async def initialize_async(self): 104 | """ 105 | Initialize the engine 106 | """ 107 | self.event_loop = asyncio.get_event_loop() 108 | 109 | super().initialize() 110 | 111 | logger.info("Initializing performance table...") 112 | self.profiler.init_profile_tables(self.block_manager) 113 | 114 | logger.info("Initializing scheduler...") 115 | self.scheduler = Scheduler(self.engine_config, self.model_config, self.profiler.pp) 116 | 117 | 
logger.info("Initializing tokenization engine...") 118 | # pylint: disable=no-member 119 | self.tokenization_engine = TokenizationEngine.remote(self.engine_config) 120 | 121 | logger.info("Engine initialized") 122 | self.initialized = True 123 | 124 | 125 | async def add_request_and_stream(self, raw_request: RawRequest) -> AsyncGenerator[StepOutput, None]: 126 | """ 127 | Add a raw request to the engine and stream the output of the request (streaming mode) 128 | """ 129 | request = Request(raw_request) 130 | self.untokenized_raw_requests.append((request, raw_request.prompt)) 131 | while True: 132 | step_output = await request.output_q.get() 133 | yield step_output 134 | request.output_q.task_done() 135 | if step_output.request.is_finished(): 136 | break 137 | 138 | 139 | async def add_request_and_wait(self, raw_request: RawRequest) -> tuple[Request, list[int]]: 140 | """ 141 | Add a raw request to the engine and wait for the completion (non-streaming mode) 142 | 143 | Return the output token ids 144 | """ 145 | request = Request(raw_request) 146 | if isinstance(raw_request.prompt, str): 147 | self.untokenized_raw_requests.append((request, raw_request.prompt)) 148 | else: 149 | # Already tokenized, directly add to the scheduler 150 | request.prompt_token_ids = raw_request.prompt 151 | request.prompt_len = len(raw_request.prompt) 152 | assert request.prompt_len + request.max_output_len <= self.engine_config.max_seq_len, \ 153 | f"Request length {request.prompt_len + request.output_len} exceeds max_seq_len {self.engine_config.max_seq_len}" 154 | self.scheduler.on_requests_arrival([request]) 155 | 156 | await request.finished_event.wait() 157 | return (request, request.output_token_ids) 158 | 159 | 160 | async def _tokenize_raw_request_event_loop(self): 161 | """ 162 | Event loop for tokenizing raw requests 163 | """ 164 | while True: 165 | if not self.untokenized_raw_requests: 166 | # No new raw requests, sleep for a bit 167 | await asyncio.sleep(0.002) 168 | continue 169 | 170 | # Tokenize the raw request in batch 171 | cur_untokenized_raw_requests = self.untokenized_raw_requests 172 | self.untokenized_raw_requests = [] 173 | 174 | prompts = [prompt for _, prompt in cur_untokenized_raw_requests] 175 | prompt_token_ids = await self.tokenization_engine.batched_tokenize.remote(prompts) 176 | 177 | new_requests = [] 178 | for (request, _), prompt_token_id in zip(cur_untokenized_raw_requests, prompt_token_ids): 179 | request.prompt_token_ids = prompt_token_id 180 | request.prompt_len = len(prompt_token_id) 181 | assert request.prompt_len + request.max_output_len <= self.engine_config.max_seq_len, \ 182 | f"Request length {request.prompt_len + request.output_len} exceeds max_seq_len {self.engine_config.max_seq_len}" 183 | new_requests.append(request) 184 | 185 | self.scheduler.on_requests_arrival(new_requests) 186 | await asyncio.sleep(0.001) # yield the event loop 187 | 188 | 189 | async def _main_event_loop(self): 190 | """ 191 | Event loop for forwarding the model 192 | """ 193 | while True: 194 | # Get the next batch from the scheduler 195 | batches, cur_swap_out, cur_swap_in = self.scheduler.get_next_batch() 196 | if not (len(batches) or len(cur_swap_in) or len(cur_swap_out)): 197 | # Nothing to do, sleep for a bit 198 | await asyncio.sleep(0.001) 199 | continue 200 | 201 | # Prepare model forward arguments 202 | forward_args = self.block_manager.prepare(batches, cur_swap_out, cur_swap_in) 203 | 204 | # Forward the model 205 | if any(b.num_prefs for b in batches): 206 | 
logger.info(f"Forwarding batches with sizes {[(b.num_cprfs, b.num_gprfs, b.num_gdecs, b.num_cdecs) for b in batches]}, " 207 | f"swap out: {len(cur_swap_out)}, swap in: {len(cur_swap_in)}") 208 | output_token_ids = await self._run_on_model_executor_async(self.executor.do_one_iteration, batches, *forward_args) 209 | 210 | # Deal with output tokens 211 | finished_reqs = self.block_manager.update_and_free(batches, output_token_ids) 212 | self.scheduler.remove_finished_requests(finished_reqs) 213 | 214 | 215 | async def start_all_event_loops(self): 216 | """ 217 | Start all event loops 218 | """ 219 | assert self.initialized, "Engine not initialized. Please call `initialize()` before starting the event loop." 220 | await asyncio.gather( 221 | self._tokenize_raw_request_event_loop(), 222 | self._main_event_loop() 223 | ) 224 | -------------------------------------------------------------------------------- /swiftllm/structs.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import dataclasses 3 | from swiftllm.perfpredictor import PerfPredictor, ZeroPerfPredictor 4 | from swiftllm.model_config import LlamaModelConfig 5 | 6 | @dataclasses.dataclass 7 | class StepOutput: 8 | """ 9 | The output of one decoding step 10 | """ 11 | token_id: int 12 | request: "Request" 13 | 14 | 15 | class RawRequest: 16 | """ 17 | A request issued by user 18 | """ 19 | prompt: str | list[int] 20 | max_output_len: int 21 | 22 | def __init__(self, prompt: str | list[int], max_output_len: int): 23 | self.prompt = prompt 24 | self.max_output_len = max_output_len 25 | 26 | 27 | class Request: 28 | """ 29 | A (queuing, processing, or finished) request in the system 30 | """ 31 | 32 | prompt_token_ids: list[int] # Prompt token ids, generated by the tokenizer upon request arrival 33 | prompt_len: int # len(prompt_token_ids) 34 | output_len: int # Current output length 35 | max_output_len: int # Final output length 36 | 37 | output_q: asyncio.Queue[StepOutput] # Queue initialized when the raw request enters the 38 | # engine, and to be set upon a new token being generated 39 | # Mainly for streaming the output back to the user 40 | finished_event: asyncio.Event # Event to be set when the request is finished 41 | # Mainly for the non-streaming case 42 | 43 | request_id: int # Request ID, within range [0, max_seqs_in_block_table). 
44 | # Generated before being prefilled, and used as the index 45 | # into the block table 46 | output_token_ids: list[int] # Output token ids' 47 | 48 | @property 49 | def seq_len(self) -> int: 50 | return self.prompt_len + self.output_len 51 | 52 | def __init__(self, raw_request: RawRequest): 53 | # A request is __init__-ed when entering `untokenized_raw_requests`, and 54 | # its `prompt_token_ids` and `prompt_len` will be set upon tokenization 55 | self.prompt_token_ids = [] 56 | self.prompt_len = 0 57 | self.max_output_len = raw_request.max_output_len 58 | self.output_len = 0 59 | self.output_q = asyncio.Queue() 60 | self.finished_event = asyncio.Event() 61 | self.request_id = -1 62 | self.output_token_ids = [] 63 | 64 | def is_finished(self) -> bool: 65 | return self.output_len == self.max_output_len 66 | 67 | @staticmethod 68 | def get_ids(reqs: list["Request"]) -> list[int]: 69 | """ 70 | Get the request IDs of a list of requests 71 | """ 72 | return [req.request_id for req in reqs] 73 | 74 | @staticmethod 75 | def get_lens(reqs: list["Request"]) -> list[int]: 76 | """ 77 | Get the sequence lengths of a list of requests 78 | """ 79 | return [req.seq_len for req in reqs] 80 | 81 | 82 | @staticmethod 83 | def get_input_tokens(reqs: list["Request"]) -> list[list[int]]: 84 | """ 85 | Get the concatenated input tokens for model forward pass 86 | """ 87 | return sum([req.prompt_token_ids if req.output_len == 0 else req.output_token_ids[-1:] for req in reqs], []) 88 | 89 | 90 | @staticmethod 91 | def update_output(reqs: list["Request"], output_toks: list[int]) -> list["Request"]: 92 | """ 93 | Update the output of a list of requests, requires the reqs are in the order that the output tokens are generated 94 | 95 | Returns the list of requests that are finished 96 | """ 97 | assert len(reqs) == len(output_toks), f"Number of requests {len(reqs)} and output tokens {len(output_toks)} do not match" 98 | finished_reqs = [] 99 | for req, tok in zip(reqs, output_toks): 100 | req.output_len += 1 101 | req.output_token_ids.append(tok) 102 | req.output_q.put_nowait(StepOutput(tok, req)) 103 | if req.is_finished(): 104 | req.finished_event.set() 105 | finished_reqs.append(req) 106 | return finished_reqs 107 | 108 | 109 | def __getstate__(self): 110 | """ 111 | Get the state of the request for serialization, we only pass useful information 112 | """ 113 | return { 114 | "prompt_token_ids": self.prompt_token_ids if self.output_len == 0 else [], 115 | "output_token_ids": self.output_token_ids[-1:] if self.output_len > 0 else [], 116 | "prompt_len": self.prompt_len, 117 | "output_len": self.output_len, 118 | "request_id": self.request_id 119 | } 120 | 121 | 122 | def __setstate__(self, state): 123 | """ 124 | Set the state of the request from the serialized state 125 | """ 126 | self.prompt_token_ids = state["prompt_token_ids"] 127 | self.output_token_ids = state["output_token_ids"] 128 | self.prompt_len = state["prompt_len"] 129 | self.output_len = state["output_len"] 130 | self.request_id = state["request_id"] 131 | 132 | 133 | def create_request( 134 | prompt_token_ids: list[int], 135 | req_id: int, 136 | output_token_ids: list[int] | None = None, 137 | quick_stop: bool = False 138 | ) -> Request: 139 | ret = Request(RawRequest("", 0)) 140 | ret.prompt_token_ids = prompt_token_ids 141 | ret.output_token_ids = output_token_ids or [] 142 | ret.prompt_len = len(ret.prompt_token_ids) 143 | ret.output_len = len(ret.output_token_ids) 144 | ret.max_output_len = ret.output_len + 1 if quick_stop else 10 ** 9 
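    # NOTE: with quick_stop=True the request is considered finished after exactly
    # one more generated token (is_finished() compares output_len against
    # max_output_len); otherwise 10 ** 9 acts as an effectively unlimited output
    # length. A minimal usage sketch (hypothetical token ids):
    #   req = create_request([1, 2, 3], req_id=0, quick_stop=True)
    #   req.is_finished()  # False until one output token has been appended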
145 | ret.request_id = req_id 146 | return ret 147 | 148 | class BatchPerfData: 149 | """ 150 | Performance data for a batch 151 | """ 152 | # pylint: disable=too-many-instance-attributes, missing-function-docstring 153 | def __init__(self, predictor: PerfPredictor): 154 | self.x = 0 155 | self.s = 0 156 | self.n_g = 0 157 | self.x_c = 0 158 | self.n_c = 0 159 | 160 | self.predictor = predictor 161 | self.pref_T = 0 162 | self.gdec_T = 0 163 | self.lnch_T = predictor.get_lnch_T() 164 | 165 | def add_pref(self, prompt_len): 166 | self.x += 1 167 | self.s += prompt_len 168 | self.pref_T += self.predictor.get_pref_T(prompt_len) 169 | 170 | def pop_pref(self, prompt_len): 171 | self.x -= 1 172 | self.s -= prompt_len 173 | self.pref_T -= self.predictor.get_pref_T(prompt_len) 174 | 175 | def add_gdec(self, seq_len): 176 | self.x += 1 177 | self.s += 1 178 | self.n_g += seq_len 179 | self.gdec_T = self.predictor.get_gdec_T(self.n_g) 180 | 181 | def add_cdec(self, seq_len): 182 | self.x += 1 183 | self.s += 1 184 | self.x_c += 1 185 | self.n_c += seq_len 186 | 187 | def pop_cdec(self, seq_len): 188 | self.x -= 1 189 | self.s -= 1 190 | self.x_c -= 1 191 | self.n_c -= seq_len 192 | 193 | @property 194 | def linr_T(self) -> float: 195 | return self.predictor.get_linr_T(self.s) 196 | 197 | @property 198 | def cdec_T(self) -> float: 199 | return self.predictor.get_cdec_T(self.x_c, self.n_c) 200 | 201 | @property 202 | def gpu_time(self) -> float: 203 | return self.linr_T + self.pref_T + self.gdec_T 204 | 205 | @property 206 | def cpu_time(self) -> float: 207 | return self.cdec_T + self.lnch_T 208 | 209 | 210 | 211 | class SubBatch: 212 | """ 213 | A sub-batch of requests 214 | """ 215 | # pylint: disable=too-many-instance-attributes, missing-function-docstring 216 | def __init__(self, predictor: PerfPredictor=ZeroPerfPredictor()): 217 | self.gprf_reqs = [] 218 | self.cprf_reqs = [] 219 | self.gdec_reqs = [] 220 | self.cdec_reqs = [] 221 | self.perfdata = BatchPerfData(predictor) 222 | 223 | def __len__(self): 224 | return self.perfdata.x 225 | 226 | def add_pref(self, req: Request, is_gpu: bool): 227 | if is_gpu: 228 | self.gprf_reqs.append(req) 229 | else: 230 | self.cprf_reqs.append(req) 231 | self.perfdata.add_pref(req.prompt_len) 232 | 233 | def pop_pref(self) -> Request: 234 | is_gpu = not self.cprf_reqs 235 | req = self.gprf_reqs.pop() if is_gpu else self.cprf_reqs.pop() 236 | self.perfdata.pop_pref(req.prompt_len) 237 | return req, is_gpu 238 | 239 | def add_gdec(self, req: Request): 240 | self.gdec_reqs.append(req) 241 | self.perfdata.add_gdec(req.seq_len) 242 | 243 | def add_cdec(self, req: Request): 244 | self.cdec_reqs.append(req) 245 | self.perfdata.add_cdec(req.seq_len) 246 | 247 | def pop_cdec(self): 248 | req = self.cdec_reqs.pop() 249 | self.perfdata.pop_cdec(req.seq_len) 250 | 251 | def get_num_prefs(self) -> int: 252 | return len(self.gprf_reqs) + len(self.cprf_reqs) 253 | 254 | def set_model_forward_args(self, model_config: LlamaModelConfig): 255 | """ 256 | Set useful attributes for the model forward pass 257 | 258 | The comments indicate each attribute's usage in the model forward pass 259 | """ 260 | # pylint: disable=attribute-defined-outside-init 261 | self.batch_size = self.perfdata.x # post-layer 262 | self.iter_width = self.perfdata.s # post-layer 263 | del self.perfdata 264 | 265 | self.num_cprfs = len(self.cprf_reqs) 266 | self.num_gprfs = len(self.gprf_reqs) 267 | self.num_gdecs = len(self.gdec_reqs) 268 | self.num_cdecs = len(self.cdec_reqs) 269 | self.num_prefs = 
self.num_cprfs + self.num_gprfs 270 | self.num_prgds = self.num_prefs + self.num_gdecs 271 | 272 | self.all_reqs = self.cprf_reqs + self.gprf_reqs + self.gdec_reqs + self.cdec_reqs 273 | assert all(req.request_id >= 0 for req in self.all_reqs), "Request ID not set" 274 | del self.cprf_reqs, self.gprf_reqs, self.gdec_reqs, self.cdec_reqs 275 | 276 | self.seq_ids_list = Request.get_ids(self.all_reqs) 277 | self.seq_lens_list = Request.get_lens(self.all_reqs) 278 | 279 | # Useful for attn kernels 280 | self.sum_pref_toks = sum(self.seq_lens_list[:self.num_prefs]) # store-pref-KV, pref, gdec 281 | self.sum_prgd_toks = self.sum_pref_toks + self.num_gdecs # gdec 282 | self.max_pref_toks = max(self.seq_lens_list[:self.num_prefs], default=0) # store-pref-KV, pref 283 | 284 | # Useful for paged attention 285 | sum_gdec_toks = sum(self.seq_lens_list[self.num_prefs:self.num_prgds]) 286 | max_gdec_toks = max(self.seq_lens_list[self.num_prefs:self.num_prgds], default=0) 287 | seq_block_size = 2048 288 | num_kv_heads = model_config.num_kv_heads 289 | while num_kv_heads*(sum_gdec_toks/seq_block_size) < 1024 and seq_block_size//2 >= 64 and \ 290 | max_gdec_toks / (seq_block_size//2) <= 128: 291 | seq_block_size //= 2 292 | self.seq_block_size = seq_block_size 293 | self.num_seq_blocks = (max_gdec_toks + seq_block_size - 1) // seq_block_size 294 | 295 | 296 | def print_profile(self): 297 | print(f"cprf lens: {[req.prompt_len for req in self.cprf_reqs]}, gprf lens: {[req.prompt_len for req in self.gprf_reqs]}, " 298 | f"gdec lens: {[req.seq_len for req in self.gdec_reqs]}, cdec lens: {[req.seq_len for req in self.cdec_reqs]}") 299 | -------------------------------------------------------------------------------- /pacpu/core.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "dtype.h" 7 | #include "pacpu_ispc.h" 8 | 9 | namespace brute { 10 | 11 | void store_kv( 12 | int cur_layer, 13 | int num_blocks, 14 | int seq_len, 15 | data_t k[], 16 | data_t v[], 17 | data_t k_cache[], 18 | data_t v_cache[], 19 | int block_table[] 20 | ) { 21 | int block_pos = (seq_len - 1) / BLOCK_SIZE; 22 | int block_id = block_table[block_pos]; 23 | int block_off = (seq_len - 1) % BLOCK_SIZE; 24 | int64_t cache_off = (1ll * cur_layer * num_blocks + block_id) * BLOCK_NELEM + block_off * HEAD_DIM; 25 | data_t* kp = k_cache + cache_off; 26 | data_t* vp = v_cache + cache_off; 27 | for (int i = 0; i < NUM_KV_HEADS; i++) { 28 | memcpy(kp + i * BLOCK_SIZE * HEAD_DIM, k + i * HEAD_DIM, HEAD_DIM * sizeof(data_t)); 29 | memcpy(vp + i * BLOCK_SIZE * HEAD_DIM, v + i * HEAD_DIM, HEAD_DIM * sizeof(data_t)); 30 | } 31 | } 32 | 33 | void qk_product( 34 | int cur_layer, 35 | int num_blocks, 36 | int seq_len, 37 | 38 | data_t q[], 39 | data_t k_cache[], 40 | int block_table[], 41 | 42 | itmd_t a[] 43 | ) { 44 | for (auto j = 0; j < seq_len; j += BLOCK_SIZE) { 45 | auto kp = k_cache + (1ll * cur_layer * num_blocks + block_table[j / BLOCK_SIZE]) * BLOCK_NELEM; 46 | auto tlim = std::min(BLOCK_SIZE, seq_len - j); 47 | for (auto h = 0; h < NUM_KV_HEADS; h++) { 48 | for (auto t = 0; t < tlim; t++) { 49 | for (auto d = 0; d < HEAD_DIM; d++) { 50 | for (auto l = 0; l < QH_PER_KVH; l++) { 51 | a[(j + t) * NUM_Q_HEADS + h * QH_PER_KVH + l] += 52 | q[(h * QH_PER_KVH + l) * HEAD_DIM + d] * kp[(h * BLOCK_SIZE + t) * HEAD_DIM + d]; 53 | } 54 | } 55 | } 56 | } 57 | } 58 | } 59 | 60 | void av_product( 61 | int cur_layer, 62 | int num_blocks, 63 | int 
seq_len, 64 | 65 | itmd_t a[], 66 | data_t v_cache[], 67 | int block_table[], 68 | 69 | otpt_t o[] 70 | ) { 71 | memset(o, 0, NUM_Q_HEADS * HEAD_DIM * sizeof(otpt_t)); 72 | for (auto j = 0; j < seq_len; j += BLOCK_SIZE) { 73 | auto vjp = v_cache + (1ll * cur_layer * num_blocks + block_table[j / BLOCK_SIZE]) * BLOCK_NELEM; 74 | auto vp = vjp; 75 | auto tlim = std::min(BLOCK_SIZE, seq_len - j); 76 | for (auto h = 0; h < NUM_KV_HEADS; h++) { 77 | for (auto t = 0; t < tlim; t++) { 78 | for (auto d = 0; d < HEAD_DIM; d++) { 79 | for (auto l = 0; l < QH_PER_KVH; l++) { 80 | o[(h * QH_PER_KVH + l) * HEAD_DIM + d] += 81 | a[(j + t) * NUM_Q_HEADS + h * QH_PER_KVH + l] * vp[(h * BLOCK_SIZE + t) * HEAD_DIM + d]; 82 | } 83 | } 84 | } 85 | } 86 | } 87 | } 88 | 89 | void softmax( 90 | int seq_len, 91 | itmd_t softmax_scale, 92 | itmd_t a[], 93 | itmd_t s[], 94 | itmd_t m[] 95 | ) { 96 | for (auto h = 0; h < NUM_Q_HEADS; h++) { 97 | s[h] = 0; 98 | m[h] = -1e20; 99 | } 100 | 101 | auto ap = a; 102 | for (auto j = 0; j < seq_len; j++) { 103 | for (auto h = 0; h < NUM_Q_HEADS; h++) { 104 | ap[h] *= softmax_scale; 105 | m[h] = std::max(m[h], ap[h]); 106 | } 107 | ap += NUM_Q_HEADS; 108 | } 109 | 110 | ap = a; 111 | for (auto j = 0; j < seq_len; j++) { 112 | for (auto h = 0; h < NUM_Q_HEADS; h++) { 113 | ap[h] = std::exp(ap[h] - m[h]); 114 | s[h] += ap[h]; 115 | } 116 | ap += NUM_Q_HEADS; 117 | } 118 | 119 | ap = a; 120 | for (auto j = 0; j < seq_len; j++) { 121 | for (auto h = 0; h < NUM_Q_HEADS; h++) { 122 | ap[h] /= s[h]; 123 | } 124 | ap += NUM_Q_HEADS; 125 | } 126 | } 127 | 128 | } 129 | 130 | void brute_attention( 131 | int cur_layer, 132 | int num_blocks, 133 | int batch_size, 134 | int block_table_width, 135 | double softmax_scale, 136 | const std::vector &seq_ids, 137 | const std::vector &seq_lengths, 138 | 139 | data_t qbatch_p[], 140 | data_t kbatch_p[], 141 | data_t vbatch_p[], 142 | otpt_t obatch_p[], 143 | data_t kcache_p[], 144 | data_t vcache_p[], 145 | int block_table_p[] 146 | ){ 147 | auto max_seq_len = *std::max_element(seq_lengths.begin(), seq_lengths.end()); 148 | itmd_t* attn_score_buf = new itmd_t[max_seq_len * NUM_Q_HEADS]; 149 | itmd_t* attn_sum_buf = new itmd_t[NUM_Q_HEADS]; 150 | itmd_t* attn_max_buf = new itmd_t[NUM_Q_HEADS]; 151 | 152 | for (auto i = 0; i < batch_size; i++) { 153 | int seq_id = seq_ids[i]; 154 | int seq_len = seq_lengths[i]; 155 | auto qip = qbatch_p + i * NUM_Q_HEADS * HEAD_DIM; 156 | auto kip = kbatch_p + i * NUM_KV_HEADS * HEAD_DIM; 157 | auto vip = vbatch_p + i * NUM_KV_HEADS * HEAD_DIM; 158 | auto oip = obatch_p + i * NUM_Q_HEADS * HEAD_DIM; 159 | auto btp = block_table_p + seq_id * block_table_width; 160 | memset(attn_score_buf, 0, seq_len * NUM_Q_HEADS * sizeof(itmd_t)); 161 | 162 | brute::store_kv(cur_layer, num_blocks, seq_len, kip, vip, kcache_p, vcache_p, btp); 163 | brute::qk_product(cur_layer, num_blocks, seq_len, qip, kcache_p, btp, attn_score_buf); 164 | brute::softmax(seq_len, softmax_scale, attn_score_buf, attn_sum_buf, attn_max_buf); 165 | brute::av_product(cur_layer, num_blocks, seq_len, attn_score_buf, vcache_p, btp, oip); 166 | } 167 | delete [] attn_score_buf; 168 | delete [] attn_sum_buf; 169 | delete [] attn_max_buf; 170 | } 171 | 172 | void ispc_attention( 173 | int cur_layer, 174 | int num_blocks, 175 | int batch_size, 176 | int block_table_width, 177 | double softmax_scale, 178 | const std::vector &seq_ids, 179 | const std::vector &seq_lengths, 180 | 181 | data_t qbatch_p[], 182 | data_t kbatch_p[], 183 | data_t vbatch_p[], 184 | otpt_t 
obatch_p[], 185 | data_t kcache_p[], 186 | data_t vcache_p[], 187 | int block_table_p[] 188 | ) { 189 | int max_seq_len = *std::max_element(seq_lengths.begin(), seq_lengths.end()); 190 | itmd_t* attn_score_buf = new itmd_t[max_seq_len * NUM_Q_HEADS]; 191 | itmd_t attn_sum_buf[NUM_Q_HEADS]; 192 | 193 | for (auto i = 0; i < batch_size; i++) { 194 | int seq_id = seq_ids[i]; 195 | int seq_len = seq_lengths[i]; 196 | auto qip = qbatch_p + i * NUM_Q_HEADS * HEAD_DIM; 197 | auto kip = kbatch_p + i * NUM_KV_HEADS * HEAD_DIM; 198 | auto vip = vbatch_p + i * NUM_KV_HEADS * HEAD_DIM; 199 | auto oip = obatch_p + i * NUM_Q_HEADS * HEAD_DIM; 200 | auto btp = block_table_p + seq_id * block_table_width; 201 | 202 | brute::store_kv( 203 | cur_layer, num_blocks, seq_len, 204 | kip, vip, kcache_p, vcache_p, btp 205 | ); 206 | 207 | ispc::attn_one_seq( 208 | cur_layer, num_blocks, seq_len, softmax_scale, 209 | qip, kcache_p, vcache_p, btp, 210 | attn_score_buf, oip, attn_sum_buf 211 | ); 212 | } 213 | 214 | delete [] attn_score_buf; 215 | } 216 | 217 | // Here we use global buffers to store intermediate results 218 | itmd_t attn_score_buf[MAX_TOK_NUM * NUM_Q_HEADS]; 219 | otpt_t o_buf [MAX_TASK_NUM * NUM_Q_HEADS * HEAD_DIM]; 220 | itmd_t attn_sum_buf[MAX_TASK_NUM * NUM_Q_HEADS]; 221 | 222 | void ispc_attention_tasks( 223 | int cur_layer, 224 | int num_blocks, 225 | int batch_size, 226 | int block_table_width, 227 | double softmax_scale, 228 | const std::vector &seq_ids, 229 | const std::vector &seq_lengths, 230 | 231 | data_t qbatch_p[], 232 | data_t kbatch_p[], 233 | data_t vbatch_p[], 234 | otpt_t obatch_p[], 235 | data_t kcache_p[], 236 | data_t vcache_p[], 237 | int block_table_p[] 238 | ) { 239 | int ws = omp_get_max_threads(); 240 | int bch_blk_size = (batch_size - 1) / ws + 1; 241 | int tot_blks = 0; 242 | for (auto i = 0; i < batch_size; i++) { 243 | tot_blks += (seq_lengths[i] - 1) / BLOCK_SIZE + 1; 244 | } 245 | 246 | int thrd_rst_blks[MAX_WS]; 247 | for (auto i = 0; i < ws; i++) { 248 | thrd_rst_blks[i] = tot_blks / ws + (i < tot_blks % ws); 249 | } 250 | 251 | // Distribute tasks to threads, each thread processes no more than thrd_max_blks blocks 252 | std::vector > tasks; // specs of each task (batch_id, seq_offs, seg_len, cum_seg_len) 253 | int* thrd_start_task = new int[ws + 1]; // starting task id for each thread 254 | int* seq_start_task = new int[batch_size + 1]; // starting task id for each sequence 255 | int cur_thrd = 0; 256 | int cum_seg_len = 0; 257 | thrd_start_task[0] = 0; 258 | for (int i = 0; i < batch_size; i++) { 259 | int seq_offs = 0; 260 | int seq_len = seq_lengths[i]; 261 | seq_start_task[i] = tasks.size(); 262 | while(seq_offs < seq_len) { 263 | if (thrd_rst_blks[cur_thrd] == 0) { 264 | thrd_start_task[++cur_thrd] = tasks.size(); 265 | } 266 | int seg_len = std::min(seq_len - seq_offs, thrd_rst_blks[cur_thrd] * BLOCK_SIZE); 267 | tasks.emplace_back(i, seq_offs, seg_len, cum_seg_len); 268 | seq_offs += seg_len; 269 | cum_seg_len += seg_len; 270 | thrd_rst_blks[cur_thrd] -= (seg_len - 1) / BLOCK_SIZE + 1; 271 | } 272 | } 273 | seq_start_task[batch_size] = tasks.size(); 274 | for (;cur_thrd < ws; cur_thrd++) { 275 | thrd_start_task[cur_thrd + 1] = tasks.size(); 276 | } 277 | 278 | // for (int i = 0; i <= batch_size; i++) { 279 | // printf("seq_start_task[%d] = %d\n", i, seq_start_task[i]); 280 | // } 281 | // for (int i = 0; i <= ws; i++) { 282 | // printf("thrd_start_task[%d] = %d\n", i, thrd_start_task[i]); 283 | // } 284 | 285 | # pragma omp parallel 286 | { 287 | // Step 0: 
288 | // store the kv_cache 289 | int tid = omp_get_thread_num(); 290 | int l = tid * bch_blk_size, r = std::min((tid + 1) * bch_blk_size, batch_size); 291 | // NOTE: l >= r when batch_size < omp_get_max_threads() 292 | for (auto i = l; i < r; i++) { 293 | int seq_id = seq_ids[i]; 294 | int seq_len = seq_lengths[i]; 295 | auto kip = kbatch_p + i * NUM_KV_HEADS * HEAD_DIM; 296 | auto vip = vbatch_p + i * NUM_KV_HEADS * HEAD_DIM; 297 | auto btp = block_table_p + seq_id * block_table_width; 298 | 299 | brute::store_kv( 300 | cur_layer, num_blocks, seq_len, kip, vip, kcache_p, vcache_p, btp 301 | ); 302 | } 303 | # pragma omp barrier 304 | 305 | // Step 1: 306 | // compute intermediate output for each sequence block 307 | // output is stored in o_buf and attn_sum_buf 308 | for (auto t = thrd_start_task[tid]; t < thrd_start_task[tid + 1]; t++) { 309 | int i, seq_offs, seg_len, cum_seg_len; 310 | std::tie(i, seq_offs, seg_len, cum_seg_len) = tasks[t]; 311 | auto qip = qbatch_p + i * NUM_Q_HEADS * HEAD_DIM; 312 | auto oip = obatch_p + i * NUM_Q_HEADS * HEAD_DIM; 313 | auto btp = block_table_p + seq_ids[i] * block_table_width; 314 | ispc::attn_one_seq( 315 | cur_layer, num_blocks, seg_len, softmax_scale, 316 | qip, kcache_p, vcache_p, btp + seq_offs / BLOCK_SIZE, 317 | attn_score_buf + cum_seg_len * NUM_Q_HEADS, 318 | seg_len == seq_lengths[i] ? oip : o_buf + t * NUM_Q_HEADS * HEAD_DIM, 319 | attn_sum_buf + t * NUM_Q_HEADS 320 | ); 321 | } 322 | # pragma omp barrier 323 | 324 | // Step 2: 325 | // Gather intermediate output to final output 326 | for (auto i = l; i < r; i++) { 327 | int num_tasks = seq_start_task[i + 1] - seq_start_task[i]; 328 | if (num_tasks > 1) { 329 | int o_off = seq_start_task[i] * NUM_Q_HEADS * HEAD_DIM; 330 | int as_off = seq_start_task[i] * NUM_Q_HEADS; 331 | auto oip = obatch_p + i * NUM_Q_HEADS * HEAD_DIM; 332 | ispc::gather_output_one_seq( 333 | num_tasks, 334 | o_buf + o_off, 335 | attn_sum_buf + as_off, 336 | oip 337 | ); 338 | } 339 | } 340 | } 341 | } 342 | -------------------------------------------------------------------------------- /csrc/src/attention.cu: -------------------------------------------------------------------------------- 1 | #include "attention.h" 2 | 3 | template 4 | __global__ void paged_attention_phase1( 5 | float* __restrict__ mid_o, 6 | float* __restrict__ mid_o_logexpsum, 7 | const half* __restrict__ q, 8 | const half* __restrict__ k, 9 | const half* __restrict__ v, 10 | half* __restrict__ kcache, 11 | half* __restrict__ vcache, 12 | const float softmax_scale, 13 | const int* __restrict__ block_table, 14 | const int* __restrict__ seq_ids, 15 | const int* __restrict__ seq_lens, 16 | const int cur_layer, 17 | const int num_seq_blocks, 18 | const int seq_block_size, 19 | const int block_table_width 20 | ) { 21 | // grid shape: [num_decoding_seqs, NUM_Q_HEADS, num_seq_blocks] 22 | // block shape: [HEAD_DIM] 23 | const int QH_PER_KVH = NUM_Q_HEADS / NUM_KV_HEADS; 24 | const int batch_id = blockIdx.x; 25 | const int qhead_id = blockIdx.y; 26 | const int seq_block_id = blockIdx.z; 27 | const int head_offs = threadIdx.x; 28 | const int kvhead_id = qhead_id / QH_PER_KVH; 29 | 30 | const int seq_id = seq_ids[batch_id]; 31 | const int seq_len = seq_lens[batch_id]; 32 | 33 | const int start_token_id = seq_block_id * seq_block_size; 34 | 35 | if (start_token_id >= seq_len) { 36 | return; 37 | } 38 | 39 | if (start_token_id + seq_block_size >= seq_len) { 40 | // The last sequence block, need to store new KV 41 | // Note that the same value may be stored 
by different thread blocks, but it's safe since the value is the same 42 | const int last_block_pos = (seq_len - 1) / BLOCK_SIZE; 43 | const int last_tok_offs = (seq_len - 1) % BLOCK_SIZE; 44 | const int last_block_id = block_table[seq_id * block_table_width + last_block_pos]; 45 | const int kv_offs = (batch_id * 46 | NUM_KV_HEADS + kvhead_id) * 47 | HEAD_DIM + head_offs; 48 | const int64_t kvcache_offs = ((((int64_t)last_block_id * 49 | NUM_LAYERS + cur_layer) * 50 | NUM_KV_HEADS + kvhead_id) * 51 | BLOCK_SIZE + last_tok_offs) * 52 | HEAD_DIM + head_offs; 53 | kcache[kvcache_offs] = k[kv_offs]; 54 | vcache[kvcache_offs] = v[kv_offs]; 55 | } 56 | 57 | // We group blocks into sections so each thread can handle a "token" in the section 58 | // to compute Q*K^T 59 | const int SECTION_SIZE = HEAD_DIM; 60 | const int sec_offs = head_offs; 61 | 62 | __shared__ float acc[HEAD_DIM]; 63 | __shared__ float cur[HEAD_DIM]; 64 | __shared__ float a[HEAD_DIM]; // A = Q * K^T 65 | float max_score = -1e20f; 66 | float sum_exp = 0.0f; 67 | acc[head_offs] = 0.0f; 68 | 69 | const int maxi = min(seq_len, start_token_id + seq_block_size); 70 | const int q_offs = (batch_id * 71 | NUM_Q_HEADS + qhead_id) * 72 | HEAD_DIM; 73 | for (int i = start_token_id; i < maxi; i += SECTION_SIZE) { 74 | // Step1: Compute Q * K^T 75 | if (i + sec_offs < seq_len) { 76 | const int block_id = block_table[ 77 | seq_id * block_table_width + (i + sec_offs) / BLOCK_SIZE]; 78 | const int block_offs = sec_offs % BLOCK_SIZE; 79 | const int kcache_offs = ((((int64_t)block_id * 80 | NUM_LAYERS + cur_layer) * 81 | NUM_KV_HEADS + kvhead_id) * 82 | BLOCK_SIZE + block_offs) * 83 | HEAD_DIM; 84 | a[sec_offs] = 0.0f; 85 | for (int j = 0; j < HEAD_DIM; j++) { 86 | a[sec_offs] += __half2float(kcache[kcache_offs + j]) * __half2float(q[q_offs + j]); 87 | } 88 | a[sec_offs] *= softmax_scale; 89 | } 90 | else { 91 | a[sec_offs] = -1e20f; 92 | } 93 | __syncthreads(); 94 | // Step2: Compute softmax (TODO: optimize the reduction) 95 | __shared__ float cur_max_score; 96 | __shared__ float new_max_score; 97 | __shared__ float old_acc_scale; 98 | __shared__ float cur_sum_exp; 99 | if (sec_offs == 0) { 100 | cur_max_score = -1e20f; 101 | cur_sum_exp = 0.0f; 102 | for (int j = 0; j < SECTION_SIZE; j++) { 103 | cur_max_score = fmaxf(cur_max_score, a[j]); 104 | } 105 | new_max_score = fmaxf(max_score, cur_max_score); 106 | old_acc_scale = exp2f(max_score - new_max_score); 107 | for (int j = 0; j < SECTION_SIZE; j++) { 108 | a[j] = exp2f(a[j] - new_max_score); 109 | cur_sum_exp += a[j]; 110 | } 111 | } 112 | __syncthreads(); 113 | sum_exp = sum_exp * old_acc_scale + cur_sum_exp; 114 | max_score = new_max_score; 115 | // Step3: Compute exp(a - max_score) and o 116 | cur[head_offs] = 0.0f; 117 | const int maxj = min(maxi, i + SECTION_SIZE); 118 | for (int j = i; j < maxj; j += BLOCK_SIZE) { 119 | const int block_id = block_table[ 120 | seq_id * block_table_width + j / BLOCK_SIZE]; 121 | const int vcache_offs = ((((int64_t)block_id * 122 | NUM_LAYERS + cur_layer) * 123 | NUM_KV_HEADS + kvhead_id) * 124 | BLOCK_SIZE) * 125 | HEAD_DIM + head_offs; 126 | const int maxl = min(maxj, j + BLOCK_SIZE); 127 | for (int l = j; l < maxl; l++) { 128 | cur[head_offs] += a[l - i] * __half2float(vcache[vcache_offs + (l - j) * HEAD_DIM]); 129 | } 130 | } 131 | acc[head_offs] = acc[head_offs] * old_acc_scale + cur[head_offs]; 132 | } 133 | const int mid_o_logexpsum_offs = (batch_id * 134 | NUM_Q_HEADS + qhead_id) * 135 | num_seq_blocks + seq_block_id; 136 | const int mid_o_offs = 
mid_o_logexpsum_offs * 137 | HEAD_DIM + head_offs; 138 | mid_o[mid_o_offs] = acc[head_offs] / sum_exp; 139 | if (head_offs == 0) { 140 | mid_o_logexpsum[mid_o_logexpsum_offs] = log2f(sum_exp) + max_score; 141 | } 142 | } 143 | 144 | #define LAUNCH_PHASE1_KERNEL(num_layers, num_q_heads, num_kv_heads, head_dim, block_size) \ 145 | paged_attention_phase1 \ 146 | <<>>( \ 147 | (float*)mid_o.data_ptr(), \ 148 | (float*)mid_o_logexpsum.data_ptr(), \ 149 | (half*)q.data_ptr(), \ 150 | (half*)k.data_ptr(), \ 151 | (half*)v.data_ptr(), \ 152 | (half*)kcache.data_ptr(), \ 153 | (half*)vcache.data_ptr(), \ 154 | softmax_scale, \ 155 | (int*)block_table.data_ptr(), \ 156 | (int*)seq_ids.data_ptr(), \ 157 | (int*)seq_lens.data_ptr(), \ 158 | cur_layer, \ 159 | num_seq_blocks, \ 160 | seq_block_size, \ 161 | block_table_width \ 162 | ) 163 | 164 | template 165 | __global__ void paged_attention_phase2( 166 | const float* __restrict__ mid_o, 167 | const float* __restrict__ mid_o_logexpsum, 168 | half* __restrict__ o, 169 | 170 | const int* __restrict__ seq_ids, 171 | const int* __restrict__ seq_lens, 172 | 173 | const int seq_block_size 174 | ) { 175 | // grid shape: [num_decoding_seqs, NUM_Q_HEADS] 176 | // block shape: [HEAD_DIM] 177 | const int batch_id = blockIdx.x; 178 | const int qhead_id = blockIdx.y; 179 | const int head_offs = threadIdx.x; 180 | 181 | const int seq_len = seq_lens[batch_id]; 182 | const int num_seq_blocks = (seq_len - 1) / seq_block_size + 1; 183 | float sum_exp = 0.0f; 184 | float max_score = -1e20f; 185 | __shared__ float acc[HEAD_DIM]; 186 | acc[head_offs] = 0.0f; 187 | 188 | for (int i = 0; i < num_seq_blocks; i++) { 189 | const int mid_o_logexpsum_offs = (batch_id * 190 | NUM_Q_HEADS + qhead_id) * 191 | num_seq_blocks + i; 192 | const int mid_o_offs = mid_o_logexpsum_offs * 193 | HEAD_DIM + head_offs; 194 | const float cur_mid_o = mid_o[mid_o_offs]; 195 | const float cur_mid_o_logexpsum = mid_o_logexpsum[mid_o_logexpsum_offs]; 196 | const float new_max_score = fmaxf(max_score, cur_mid_o_logexpsum); 197 | const float old_acc_scale = exp2f(max_score - new_max_score); 198 | const float exp_score = exp2f(cur_mid_o_logexpsum - new_max_score); 199 | acc[head_offs] = acc[head_offs] * old_acc_scale + cur_mid_o * exp_score; 200 | sum_exp = sum_exp * old_acc_scale + exp_score; 201 | max_score = new_max_score; 202 | } 203 | 204 | const int o_offs = (batch_id * 205 | NUM_Q_HEADS + qhead_id) * 206 | HEAD_DIM + head_offs; 207 | 208 | o[o_offs] = __float2half(acc[head_offs] / sum_exp); 209 | } 210 | 211 | #define LAUNCH_PHASE2_KERNEL(num_layers, num_q_heads, num_kv_heads, head_dim, block_size) \ 212 | paged_attention_phase2 \ 213 | <<>>( \ 214 | (float*)mid_o.data_ptr(), \ 215 | (float*)mid_o_logexpsum.data_ptr(), \ 216 | (half*)o.data_ptr(), \ 217 | (int*)seq_ids.data_ptr(), \ 218 | (int*)seq_lens.data_ptr(), \ 219 | seq_block_size \ 220 | ) 221 | 222 | #define SELECT_KERNEL(block_size) \ 223 | if (num_layers == 32 && num_q_heads == 32 && num_kv_heads == 8 && head_dim == 128) { \ 224 | LAUNCH_PHASE1_KERNEL(32, 32, 8, 128, block_size); \ 225 | LAUNCH_PHASE2_KERNEL(32, 32, 8, 128, block_size); \ 226 | } \ 227 | else if (num_layers == 32 && num_q_heads == 32 && num_kv_heads == 32 && head_dim == 128) { \ 228 | LAUNCH_PHASE1_KERNEL(32, 32, 32, 128, block_size); \ 229 | LAUNCH_PHASE2_KERNEL(32, 32, 32, 128, block_size); \ 230 | } \ 231 | else if (num_layers == 40 && num_q_heads == 40 && num_kv_heads == 40 && head_dim == 128) { \ 232 | LAUNCH_PHASE1_KERNEL(40, 40, 40, 128, block_size); \ 233 | 
LAUNCH_PHASE2_KERNEL(40, 40, 40, 128, block_size); \ 234 | } \ 235 | else if (num_layers == 80 && num_q_heads == 64 && num_kv_heads == 8 && head_dim == 128) { \ 236 | LAUNCH_PHASE1_KERNEL(80, 64, 8, 128, block_size); \ 237 | LAUNCH_PHASE2_KERNEL(80, 64, 8, 128, block_size); \ 238 | } \ 239 | else { \ 240 | throw std::runtime_error("Unsupported configuration"); \ 241 | } 242 | 243 | void paged_attention( 244 | torch::Tensor q, // [num_decoding_seqs, num_q_heads, head_dim] 245 | torch::Tensor k, // [num_decoding_seqs, num_kv_heads, head_dim] 246 | torch::Tensor v, // [num_decoding_seqs, num_kv_heads, head_dim] 247 | torch::Tensor o, // [num_decoding_seqs, num_q_heads, head_dim] 248 | torch::Tensor kcache, // [..., num_layers, num_kv_heads, block_size, head_dim] 249 | torch::Tensor vcache, // [..., num_layers, num_kv_heads, block_size, head_dim] 250 | float softmax_scale, 251 | torch::Tensor block_table, // [..., block_table_width] 252 | torch::Tensor seq_ids, // [num_decoding_seqs] 253 | torch::Tensor seq_lens, // [num_decoding_seqs] 254 | const int64_t cur_layer, 255 | const int64_t seq_block_size, 256 | const int64_t num_seq_blocks 257 | ) { 258 | const int num_decoding_seqs = q.size(0); 259 | const int num_q_heads = q.size(1); 260 | const int head_dim = q.size(2); 261 | const int num_kv_heads = k.size(1); 262 | const int num_layers = kcache.size(1); 263 | const int block_size = kcache.size(-2); 264 | const int block_table_width = block_table.size(-1); 265 | 266 | assert (seq_block_size % block_size == 0); 267 | assert (seq_block_size % head_dim == 0); 268 | 269 | auto options = torch::TensorOptions() 270 | .dtype(torch::kFloat32) 271 | .device(torch::kCUDA); 272 | 273 | torch::Tensor mid_o = torch::empty( 274 | {num_decoding_seqs, num_q_heads, num_seq_blocks, head_dim}, 275 | options 276 | ); 277 | torch::Tensor mid_o_logexpsum = torch::zeros( 278 | {num_decoding_seqs, num_q_heads, num_seq_blocks}, 279 | options 280 | ); 281 | 282 | // softmax is defined to use exp, here we use 2^x so we need to scale the exponent by log2(e) 283 | softmax_scale *= 1.442695040888963; 284 | 285 | dim3 grid1(num_decoding_seqs, num_q_heads, num_seq_blocks); 286 | dim3 grid2(num_decoding_seqs, num_q_heads); 287 | 288 | switch(block_size){ 289 | case 16: 290 | SELECT_KERNEL(16); 291 | break; 292 | case 32: 293 | SELECT_KERNEL(32); 294 | break; 295 | case 64: 296 | SELECT_KERNEL(64); 297 | break; 298 | default: 299 | throw std::runtime_error("Unsupported block size"); 300 | } 301 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2024 DistServe Authors 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /swiftllm/worker/weight.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import dataclasses 4 | import torch 5 | import safetensors 6 | 7 | from swiftllm.model_config import LlamaModelConfig 8 | 9 | @dataclasses.dataclass 10 | class RegisteredWeightItem: 11 | attr_name: str 12 | key: str 13 | shape: tuple 14 | split: tuple # same shape as above, each element is a boolean indicating whether to split the tensor on that dimension 15 | dtype: torch.dtype 16 | 17 | @property 18 | def shape_with_split(self): 19 | return zip(self.shape, self.split) 20 | 21 | def get_real_shape(self, ws: int): 22 | return tuple(size // ws if split else size for size, split in self.shape_with_split) 23 | 24 | 25 | 26 | class WeightBase: 27 | """ 28 | The base class of all weight classes (i.e. LlamaTransformerLayerWeight or LlamaWeight) 29 | 30 | During weight initialization, each concrete weight class should first register 31 | all weight items. Each weight item has its own attribute name, key, shape, and dtype. 32 | 33 | During weight loading, RegisterWeightItem will be passed to the weight getter 34 | function, which should return the corresponding weight value (real/dummy). 35 | """ 36 | 37 | def __init__(self, model_config: LlamaModelConfig, dtype: torch.dtype): 38 | self.model_config = model_config 39 | self.dtype = dtype 40 | self.registered_weights: list[RegisteredWeightItem] = [] 41 | 42 | def register_weight(self, item: RegisteredWeightItem): 43 | self.registered_weights.append(item) 44 | 45 | def _post_process_after_load(self, getter: callable): 46 | """ 47 | This function is called after loading weights (real/dummy). 48 | Defined in each concrete weight class, called by load_weights(). 49 | """ 50 | raise NotImplementedError() 51 | 52 | def load_weights(self, getter: callable): 53 | """ 54 | Load weights 55 | """ 56 | ws = self.model_config.world_size 57 | for item in self.registered_weights: 58 | assert len(item.split) == len(item.shape), f"Length mismatch between split and shape for {item.attr_name}" 59 | assert all( 60 | not split or size % ws == 0 61 | for size, split in item.shape_with_split 62 | ), f"Cannot split tensor {item.attr_name} with shape {item.shape} into {ws} parts" 63 | 64 | weight_value = getter(item) 65 | 66 | assert weight_value is not None, f"getter() returned None for {item.key} ({item})" 67 | assert isinstance(weight_value, torch.Tensor), f"Weight {item.key} is not a tensor" 68 | assert weight_value.shape == item.get_real_shape(ws), \ 69 | f"Shape of weight {item.key} does not match {weight_value.shape} != {item.get_real_shape(ws)}" 70 | assert weight_value.device.type == "cuda", f"Weight {item.key} is not on GPU" 71 | setattr(self, item.attr_name, weight_value.to(item.dtype)) 72 | self._post_process_after_load(getter) 73 | 74 | 75 | 76 | class LlamaTransformerLayerWeight(WeightBase): 77 | """ 78 | Class stores the weights of one transformer layer (transformer block) in Llama model. 
79 | """ 80 | 81 | def __init__( 82 | self, 83 | layer_id: int, 84 | model_config: LlamaModelConfig, 85 | dtype: torch.dtype 86 | ): 87 | super().__init__(model_config, dtype) 88 | 89 | self.layer_id = layer_id 90 | 91 | self.register_weight(RegisteredWeightItem( 92 | "attn_norm", 93 | f"model.layers.{self.layer_id}.input_layernorm.weight", 94 | (self.model_config.hidden_size,), 95 | (False,), 96 | self.dtype 97 | )) 98 | self.register_weight(RegisteredWeightItem( 99 | "q_proj", 100 | f"model.layers.{self.layer_id}.self_attn.q_proj.weight", 101 | (self.model_config.hidden_size, self.model_config.hidden_size), 102 | (True, False), 103 | self.dtype 104 | )) 105 | self.register_weight(RegisteredWeightItem( 106 | "k_proj", 107 | f"model.layers.{self.layer_id}.self_attn.k_proj.weight", 108 | (self.model_config.num_kv_heads*self.model_config.head_dim, self.model_config.hidden_size), 109 | (True, False), 110 | self.dtype 111 | )) 112 | self.register_weight(RegisteredWeightItem( 113 | "v_proj", 114 | f"model.layers.{self.layer_id}.self_attn.v_proj.weight", 115 | (self.model_config.num_kv_heads*self.model_config.head_dim, self.model_config.hidden_size), 116 | (True, False), 117 | self.dtype 118 | )) 119 | self.register_weight(RegisteredWeightItem( 120 | "o_proj", 121 | f"model.layers.{self.layer_id}.self_attn.o_proj.weight", 122 | (self.model_config.hidden_size, self.model_config.hidden_size), 123 | (False, True), 124 | self.dtype 125 | )) 126 | 127 | self.register_weight(RegisteredWeightItem( 128 | "ffn_norm", 129 | f"model.layers.{self.layer_id}.post_attention_layernorm.weight", 130 | (self.model_config.hidden_size,), 131 | (False,), 132 | self.dtype 133 | )) 134 | self.register_weight(RegisteredWeightItem( 135 | "up_proj", 136 | f"model.layers.{self.layer_id}.mlp.up_proj.weight", 137 | (self.model_config.ffn_inter_dim, self.model_config.hidden_size), 138 | (True, False), 139 | self.dtype 140 | )) 141 | self.register_weight(RegisteredWeightItem( 142 | "gate_proj", 143 | f"model.layers.{self.layer_id}.mlp.gate_proj.weight", 144 | (self.model_config.ffn_inter_dim, self.model_config.hidden_size), 145 | (True, False), 146 | self.dtype 147 | )) 148 | self.register_weight(RegisteredWeightItem( 149 | "down_proj", 150 | f"model.layers.{self.layer_id}.mlp.down_proj.weight", 151 | (self.model_config.hidden_size, self.model_config.ffn_inter_dim), 152 | (False, True), 153 | self.dtype 154 | )) 155 | 156 | def _post_process_after_load(self, getter: callable): 157 | # pylint: disable=no-member, attribute-defined-outside-init 158 | # self.qkv_proj = torch.cat((self.q_proj, self.k_proj, self.v_proj), dim=0).contiguous() 159 | # del self.q_proj, self.k_proj, self.v_proj 160 | self.up_gate_proj = torch.cat((self.up_proj, self.gate_proj), dim=0).contiguous() 161 | del self.up_proj, self.gate_proj 162 | 163 | 164 | 165 | class LlamaWeight(WeightBase): 166 | def __init__( 167 | self, 168 | model_config: LlamaModelConfig, 169 | dtype: torch.dtype 170 | ): 171 | super().__init__(model_config, dtype) 172 | 173 | self.register_weight(RegisteredWeightItem( 174 | "wte", 175 | "model.embed_tokens.weight", 176 | (self.model_config.vocab_size, self.model_config.hidden_size), 177 | (True, False), 178 | self.dtype 179 | )) 180 | self.register_weight(RegisteredWeightItem( 181 | "lm_head", 182 | "lm_head.weight", 183 | (self.model_config.vocab_size, self.model_config.hidden_size), 184 | (True, False), 185 | self.dtype 186 | )) 187 | self.register_weight(RegisteredWeightItem( 188 | "final_norm", 189 | "model.norm.weight", 190 | 
(self.model_config.hidden_size,), 191 | (False,), 192 | self.dtype 193 | )) 194 | 195 | self.layers: list[LlamaTransformerLayerWeight] = [] 196 | for i in range(self.model_config.num_layers): 197 | layer = LlamaTransformerLayerWeight(i, self.model_config, self.dtype) 198 | self.layers.append(layer) 199 | 200 | def _post_process_after_load(self, getter: callable): 201 | for layer in self.layers: 202 | layer.load_weights(getter) 203 | 204 | 205 | 206 | def load_weights( 207 | model_config: LlamaModelConfig, 208 | dtype: torch.dtype, 209 | model_path: str, 210 | use_dummy: bool = False 211 | ) -> LlamaWeight: 212 | """ 213 | Load weights from a given path 214 | """ 215 | rk = model_config.rank 216 | ws = model_config.world_size 217 | if use_dummy: 218 | assert rk == 0 and ws == 1, "Model sharding is not supported for dummy weights" 219 | def weight_getter_dummy(item: RegisteredWeightItem): 220 | return torch.empty(item.shape, dtype=item.dtype, device="cuda").uniform_(-0.001, 0.001) 221 | getter = weight_getter_dummy 222 | else: 223 | safetensor_files = [name for name in os.listdir(model_path) if name.endswith(".safetensors")] 224 | if len(safetensor_files) > 0: 225 | # Use Safetensors 226 | safetensor_index_path = os.path.join(model_path, "model.safetensors.index.json") 227 | if os.path.exists(safetensor_index_path): 228 | # The weight is stored in multiple files 229 | f = open(safetensor_index_path, "r", encoding="utf-8") 230 | safetensor_index = json.load(f)["weight_map"] 231 | safetensor_filename = None 232 | else: 233 | # The weight is stored in a single file 234 | assert len(safetensor_files) == 1, "model.safetensors.index.json not found, but there are multiple .safetensors files" 235 | safetensor_index = None 236 | safetensor_filename = safetensor_files[0] 237 | 238 | def weight_getter_real(item: RegisteredWeightItem): 239 | file_name = safetensor_index[item.key] if safetensor_index is not None else safetensor_filename 240 | file_path = os.path.join(model_path, file_name) 241 | # For safetensor files, since "opening" it is cheap, we open it every time 242 | with safetensors.safe_open(file_path, framework="pt", device="cuda") as f: 243 | whole = f.get_slice(item.key) 244 | slices = [ 245 | slice(rk * size // ws, (rk + 1) * size // ws) 246 | if split else slice(0, size) 247 | for size, split in item.shape_with_split 248 | ] 249 | tensor = whole[slices] 250 | return tensor.to(item.dtype) 251 | getter = weight_getter_real 252 | 253 | else: 254 | # Use PyTorch 255 | # TODO: support Model sharding for PyTorch files 256 | assert rk == 0 and ws == 1, "Model sharding is not supported for PyTorch files" 257 | pytorch_index_path = os.path.join(model_path, "pytorch_model.bin.index.json") 258 | if os.path.exists(pytorch_index_path): 259 | # The weight is stored in multiple files 260 | f = open(pytorch_index_path, "r", encoding="utf-8") 261 | pytorch_index = json.load(f)["weight_map"] 262 | pytorch_filename = None 263 | else: 264 | # The weight is stored in a single file 265 | pytorch_index = None 266 | pytorch_filename = "pytorch_model.bin" 267 | 268 | # For PyTorch files, since "opening" it is slow (due to deserialization), 269 | # we open it only once and then store the opened files in a dictionary. 270 | # We add `mmap=True` to avoid loading the entire file into memory. 
271 |             opened_files = {}
272 |             def weight_getter_real(item: RegisteredWeightItem):
273 |                 file_name = pytorch_index[item.key] if pytorch_index is not None else pytorch_filename
274 |                 file_path = os.path.join(model_path, file_name)
275 |                 if file_path not in opened_files:
276 |                     opened_files[file_path] = torch.load(file_path, map_location="cuda", mmap=True)
277 |                 file = opened_files[file_path]
278 |                 return file[item.key].to(item.dtype)
279 |             getter = weight_getter_real
280 | 
281 |     weight = LlamaWeight(model_config, dtype)
282 |     weight.load_weights(getter)
283 |     return weight
284 | 
--------------------------------------------------------------------------------
/swiftllm/server/block_manager.py:
--------------------------------------------------------------------------------
1 | """
2 | Block manager classes on the control plane.
3 | 
4 | They are used to manage the allocated and free blocks on both CPU and GPU, but actual
5 | model computations don't involve these classes.
6 | """
7 | 
8 | import torch
9 | from swiftllm.engine_config import EngineConfig
10 | from swiftllm.model_config import LlamaModelConfig
11 | from swiftllm.structs import Request, SubBatch
12 | 
13 | 
14 | class DeviceBlockManager:
15 |     """
16 |     DeviceBlockManager - Manage the allocated & free blocks on one device (CPU / GPU)
17 | 
18 |     We may split the KV cache along the layer dimension, so there may be multiple free block tables.
19 | 
20 |     However, the sequences share the same namespace of sequence IDs, so we only need one set of
21 |     block table and `seq_num_blks`.
22 |     """
23 | 
24 |     @torch.inference_mode()
25 |     def __init__(
26 |         self,
27 |         device_name: str,
28 |         engine_config: EngineConfig
29 |     ):
30 |         self.device_name = device_name
31 |         self.num_blocks = engine_config.num_gpu_blocks if device_name == 'cuda' else engine_config.num_cpu_blocks
32 |         self.block_size = engine_config.block_size
33 |         self.block_table_width = engine_config.max_blocks_per_seq
34 |         nsplits = 1 + engine_config.extra_layer_for_cprf
35 | 
36 |         # seq_id |-> number of blocks allocated for this sequence
37 |         self.seq_num_blks = torch.zeros(
38 |             (engine_config.max_seqs_in_block_table,),
39 |             dtype=torch.int32,
40 |             device='cpu'
41 |         )
42 |         # (seq_id, block_index) |-> block_id
43 |         self.block_table = torch.empty(
44 |             (engine_config.max_seqs_in_block_table, engine_config.max_blocks_per_seq),
45 |             dtype=torch.int32,
46 |             device='cpu'
47 |         )
48 |         # block_id |-> whether this block is free or not
49 |         self.num_free_blocks = [self.num_blocks] * nsplits
50 |         self.is_block_free = [torch.ones(
51 |             (self.num_blocks,),
52 |             dtype=torch.bool,
53 |             device='cpu'
54 |         ) for _ in range(nsplits)]
55 | 
56 | 
57 |     @torch.inference_mode()
58 |     def _get_new_blk_ids(self, num_blocks: int, split_id: int=0) -> torch.Tensor:
59 |         """
60 |         Check the free block table and return the block IDs of the newly allocated blocks
61 |         """
62 |         if num_blocks == 0:
63 |             return torch.empty(0, dtype=torch.int32)
64 | 
65 |         is_block_free = self.is_block_free[split_id]
66 |         if num_blocks > self.num_free_blocks[split_id]:
67 |             raise RuntimeError(
68 |                 f"Not enough free blocks available on {self.device_name} split {split_id} ({self.num_blocks} in total, "
69 |                 f"{self.num_free_blocks[split_id]} free, {num_blocks} requested)"
70 |             )
71 | 
72 |         selected_blocks = torch.nonzero(is_block_free)[:num_blocks].view(-1).to(dtype=torch.int32)
73 |         self.num_free_blocks[split_id] -= num_blocks
74 |         is_block_free[selected_blocks] = False
75 |         return selected_blocks
76 | 
77 | 
78 |     @torch.inference_mode()
79 |     def alloc(self, reqs: list[Request], split_point: int=0, omit_last=False) -> tuple[list[int], list[int]]:
80 |         """
81 |         Allocate blocks for sequences, making sure that every request has enough blocks allocated for all its tokens.
82 | 
83 |         Those after split_point will be allocated in the first split, and the rest will be allocated in the second split.
84 | 
85 |         If omit_last is set to True, we don't need to allocate a block for the last token.
86 | 
87 |         Return the new mapping from block virtual IDs to block physical IDs.
88 |         """
89 |         if not reqs:
90 |             return [], []
91 | 
92 |         seq_ids = Request.get_ids(reqs)
93 |         seq_lens = torch.tensor(Request.get_lens(reqs), dtype=torch.int32) - int(omit_last)
94 |         tgt_num_blks = (seq_lens - 1) // self.block_size + 1
95 |         seq_num_blks = self.seq_num_blks[seq_ids]
96 | 
97 |         assert all(seq_num_blks <= tgt_num_blks), \
98 |             f"""(On {self.device_name}) Logic error: Some sequences have more blocks already allocated than needed.
99 |                 seq_ids: {seq_ids}, target_lens: {seq_lens}, target_num_blocks: {tgt_num_blks},
100 |                 seq_num_blks: {seq_num_blks}"""
101 | 
102 |         new_num_blks = tgt_num_blks - seq_num_blks
103 |         new_blk_ids0 = self._get_new_blk_ids(torch.sum(new_num_blks[split_point:]), 0)
104 |         new_blk_ids1 = self._get_new_blk_ids(torch.sum(new_num_blks[:split_point]), 1)
105 |         new_blk_pids = torch.cat([new_blk_ids1, new_blk_ids0])
106 | 
107 |         seq_num_blks_list = seq_num_blks.tolist()
108 |         new_num_blks_list = new_num_blks.tolist()
109 |         new_blk_vids = [seq_ids[i] * self.block_table_width + j + seq_num_blks_list[i] for i, n in enumerate(new_num_blks_list) for j in range(n)]
110 |         self.block_table.view(-1)[new_blk_vids] = new_blk_pids
111 |         self.seq_num_blks[seq_ids] = tgt_num_blks
112 |         return new_blk_vids, new_blk_pids.tolist()
113 | 
114 | 
115 |     @torch.inference_mode()
116 |     def free(self, reqs: list[Request], split_id: int=0) -> list[int]:
117 |         """
118 |         Free the blocks allocated for the specified sequences.
119 | 
120 |         Return the block physical IDs that are freed.
121 |         """
122 |         if not reqs:
123 |             return []
124 | 
125 |         seq_ids = Request.get_ids(reqs)
126 |         seq_num_blks_list = self.seq_num_blks[seq_ids].tolist()
127 |         blk_vids = [seq_ids[i] * self.block_table_width + j for i, n in enumerate(seq_num_blks_list) for j in range(n)]
128 |         blk_pids = self.block_table.view(-1)[blk_vids] # possibly on GPU
129 |         self.num_free_blocks[split_id] += len(blk_pids)
130 |         self.is_block_free[split_id][blk_pids] = True
131 |         self.seq_num_blks[seq_ids] = 0
132 |         return blk_pids.tolist()
133 | 
134 | 
135 | 
136 | class BlockManager:
137 |     """
138 |     BlockManager - Manage the allocated & free blocks on both CPU and GPU
139 |     """
140 | 
141 |     def __init__(
142 |         self,
143 |         engine_config: EngineConfig,
144 |         model_config: LlamaModelConfig
145 |     ):
146 |         self.engine_config = engine_config
147 |         self.model_config = model_config
148 |         self.extra_layer_for_cprf = engine_config.extra_layer_for_cprf
149 |         self.gpu_block_manager = DeviceBlockManager("cuda", engine_config)
150 |         self.cpu_block_manager = DeviceBlockManager("cpu", engine_config)
151 | 
152 | 
153 |     def _alloc_blocks_for_batch(self, batch: SubBatch) -> tuple[tuple[list[int], list[int]], tuple[list[int], list[int]]]:
154 |         """
155 |         Allocate blocks for a batch of sequences.
156 | 
157 |         Return the new block-VID-to-block-PID mappings, on GPU and CPU respectively.
158 |         """
159 |         return (
160 |             self.gpu_block_manager.alloc(batch.all_reqs[:batch.num_prgds], split_point=batch.num_cprfs * self.extra_layer_for_cprf, omit_last=False),
161 |             self.cpu_block_manager.alloc(batch.all_reqs[batch.num_prgds:], omit_last=False)
162 |         )
163 | 
164 | 
165 |     def _free_blocks_of_requests(self, reqs: list[Request]) -> tuple[list[int], list[int]]:
166 |         """
167 |         Free the blocks allocated for the specified requests.
168 |         """
169 |         return self.gpu_block_manager.free(reqs), self.cpu_block_manager.free(reqs)
170 | 
171 | 
172 |     def _initiate_swap(
173 |         self,
174 |         reqs: list[Request],
175 |         is_swap_out: bool,
176 |         use_itm: bool = False,  # Only True when swapping out from the intermediate cache to CPU
177 |         omit_last: bool = True  # Normally we don't need to allocate a block for the new token(s), except for CPU prefills
178 |     ) -> tuple[list[int], list[int], list[int]]:
179 |         """
180 |         Do all the set-up work for swapping in/out sequences.
181 |         Returns a triple of src block PIDs, dst block VIDs and dst block PIDs.
182 |         """
183 |         assert is_swap_out or not use_itm, "Cannot swap in to intermediate space"
184 | 
185 |         if not reqs:
186 |             return [], [], []
187 | 
188 |         src_block_manager = self.gpu_block_manager if is_swap_out else self.cpu_block_manager
189 |         dst_block_manager = self.cpu_block_manager if is_swap_out else self.gpu_block_manager
190 |         src_blk_pids = src_block_manager.free(reqs, int(use_itm))
191 |         dst_blk_vids, dst_blk_pids = dst_block_manager.alloc(reqs, omit_last=omit_last)
192 |         return src_blk_pids, dst_blk_vids, dst_blk_pids
193 | 
194 | 
195 |     def prepare(
196 |         self,
197 |         batches: list[SubBatch],
198 |         cur_swap_out: list[Request],
199 |         cur_swap_in: list[Request]
200 |     ) -> tuple[tuple[tuple[list[int], list[int]], tuple[list[int], list[int]]], tuple[list[int], list[int]], bool]:
201 |         """
202 |         Prepare KV cache and swapping related arguments for the model forward pass.
203 | 
204 |         Requires either cur_swap_out or cur_swap_in to be empty.
205 | 
206 |         Return a triple. The first element is itself a tuple of mappings:
207 | 
208 |             (GPU block VIDs, GPU block PIDs), (CPU block VIDs, CPU block PIDs)
209 | 
210 |         The second element is a tuple of lists of block IDs:
211 | 
212 |             (source block PIDs, destination block PIDs)
213 | 
214 |         The third element is a boolean indicating whether this is a swap-out operation.
215 |         """
216 |         assert not (cur_swap_out and cur_swap_in), "Swap out and swap in should be mutually exclusive"
217 |         assert len(batches) in (1, 2), "The number of batches should be 1 or 2"
218 | 
219 |         mappings = (([], []), ([], []))  # (GPU, CPU)
220 |         swappings = ([], [])  # (swap in: CPU -> GPU / swap out: GPU -> CPU)
221 | 
222 |         # print(f"Preparing model forward args with {len(batches)} batches, swap out: {len(cur_swap_out)}, swap in: {len(cur_swap_in)}")
223 | 
224 |         # 1. Do conventional swaps
225 |         is_swap_out = bool(cur_swap_out)
226 |         sp, dv, dp = self._initiate_swap(cur_swap_out or cur_swap_in, is_swap_out)
227 |         mappings[is_swap_out][0].extend(dv)
228 |         mappings[is_swap_out][1].extend(dp)
229 |         swappings[0].extend(sp)
230 |         swappings[1].extend(dp)
231 | 
232 |         # 2.
Allocate blocks for the batch, also prepare forward args 233 | sum_batch_size = 0 234 | sum_iter_width = 0 235 | for batch in batches: 236 | batch.set_model_forward_args(self.model_config) 237 | assert batch.batch_size > 0, "Batch size should be greater than 0" 238 | sum_batch_size += batch.batch_size 239 | sum_iter_width += batch.iter_width 240 | (gv, gp), (cv, cp) = self._alloc_blocks_for_batch(batch) 241 | mappings[0][0].extend(gv) 242 | mappings[0][1].extend(gp) 243 | mappings[1][0].extend(cv) 244 | mappings[1][1].extend(cp) 245 | assert sum_batch_size <= self.engine_config.max_batch_size, \ 246 | f"Batch size {sum_batch_size} exceeds max_batch_size {self.engine_config.max_batch_size}" 247 | assert sum_iter_width <= self.engine_config.max_tokens_in_batch, \ 248 | f"Iteration width {sum_iter_width} exceeds max_tokens_in_batch {self.engine_config.max_tokens_in_batch}" 249 | 250 | # 3. Do cprf swaps, this should happen after the batch allocation 251 | for batch in batches: 252 | sp, dv, dp = self._initiate_swap( 253 | batch.all_reqs[:batch.num_cprfs], is_swap_out=True, 254 | use_itm=self.engine_config.extra_layer_for_cprf, omit_last=False 255 | ) 256 | batch.src_blk_ids = sp 257 | batch.dst_blk_ids = dp 258 | mappings[1][0].extend(dv) 259 | mappings[1][1].extend(dp) 260 | 261 | return mappings, swappings, is_swap_out 262 | 263 | 264 | def update_and_free(self, batches: list[SubBatch], output_token_ids: list[int]) -> list[Request]: 265 | """ 266 | Called at the end of each iteration, 267 | 268 | Update the output token IDs of the requests and free the blocks allocated for the finished requests. 269 | 270 | Return the finished requests. 271 | """ 272 | all_reqs = sum([b.all_reqs for b in batches], []) 273 | finished_reqs = Request.update_output(all_reqs, output_token_ids) 274 | self._free_blocks_of_requests(finished_reqs) 275 | return finished_reqs 276 | 277 | -------------------------------------------------------------------------------- /swiftllm/worker/kernels/paged_attn.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import triton 4 | import triton.language as tl 5 | from triton.compiler.compiler import CompiledKernel 6 | 7 | from swiftllm.model_config import LlamaModelConfig 8 | from swiftllm.engine_config import EngineConfig 9 | from swiftllm.worker.infer_state import LlamaInferState 10 | 11 | # Map from seq_block_size to the compiled kernel, the former must be power of 2 12 | cached_phase1_bin = {} 13 | 14 | @triton.jit 15 | def _fwd_paged_attention_phase1( 16 | mid_o: torch.Tensor, # [num_decoding_seqs, num_q_heads, num_seq_blocks, head_dim], contiguous. 
num_seq_blocks = ceil(max_seq_len / seq_block_size) 17 | mid_o_logexpsum: torch.Tensor, # [num_decoding_seqs, num_q_heads, num_seq_blocks], contiguous 18 | q: torch.Tensor, # [num_decoding_seqs, num_q_heads, head_dim], contiguous 19 | k: torch.Tensor, # [num_decoding_seqs, num_kv_heads, head_dim], contiguous 20 | v: torch.Tensor, # [num_decoding_seqs, num_kv_heads, head_dim], contiguous 21 | k_cache: torch.Tensor, # [num_layers, num_blocks, num_kv_heads, block_size, head_dim], contiguous 22 | v_cache: torch.Tensor, # [num_layers, num_blocks, num_kv_heads, block_size, head_dim], contiguous 23 | block_table: torch.Tensor, # [*, max_blocks_per_seq], contiguous 24 | softmax_scale: tl.float16, 25 | decoding_seq_lens: torch.Tensor, # [num_decoding_seqs], contiguous 26 | seq_ids: torch.Tensor, # [num_decoding_seqs], contiguous 27 | num_seq_blocks: int, 28 | cur_layer: int, 29 | 30 | num_blocks: tl.constexpr, 31 | num_q_heads: tl.constexpr, 32 | num_kv_heads: tl.constexpr, 33 | num_my_heads: tl.constexpr, 34 | block_size: tl.constexpr, 35 | head_dim: tl.constexpr, 36 | seq_block_size: tl.constexpr, 37 | max_blocks_per_seq: tl.constexpr, 38 | ): 39 | # grid shape: [num_decoding_seqs, num_q_heads, num_seq_blocks] 40 | my_batch_id = tl.program_id(0).to(tl.int64) 41 | my_q_head_id = tl.program_id(1).to(tl.int64) 42 | my_seq_block_id = tl.program_id(2) 43 | # num_my_heads = num_q_heads // num_kv_heads 44 | my_kv_head_id = my_q_head_id // num_my_heads 45 | 46 | my_seq_id = tl.load(seq_ids + my_batch_id) 47 | my_seq_len = tl.load(decoding_seq_lens + my_batch_id) 48 | my_start_token_idx = my_seq_block_id * seq_block_size 49 | 50 | if my_start_token_idx >= my_seq_len: 51 | return 52 | 53 | offs_q = my_batch_id*num_q_heads*head_dim + my_q_head_id*head_dim + tl.arange(0, head_dim) 54 | my_q = tl.load(q + offs_q) # [head_dim] 55 | 56 | start_block_idx = my_seq_block_id*(seq_block_size//block_size) 57 | k_ptrs = k_cache + (cur_layer*num_blocks*num_kv_heads+my_kv_head_id)*block_size*head_dim + tl.arange(0, block_size)[:, None]*head_dim + tl.arange(0, head_dim)[None, :] 58 | v_ptrs = v_cache + (cur_layer*num_blocks*num_kv_heads+my_kv_head_id)*block_size*head_dim + tl.arange(0, block_size)[:, None]*head_dim + tl.arange(0, head_dim)[None, :] 59 | 60 | max_score = float("-1e20") 61 | sum_exp = 0.0 62 | acc = tl.zeros([head_dim], dtype=tl.float32) 63 | 64 | # In the following code we deal with the case where the sequence block is 65 | # the last one in the sequence separately, because: 66 | # - The last sequence block may not be a full block, therefore maskings 67 | # are needed. 68 | # - We can use tl.arange() when the sequence block is not the last one, 69 | # leading to better performance. 
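    # Both branches below implement the same online-softmax recurrence, written
    # in base 2 (the caller pre-multiplies softmax_scale by log2(e), so exp2 on
    # the scaled scores is equivalent to exp on the unscaled ones). With running
    # maximum m, running denominator sum_exp, and accumulator acc, each KV block
    # with scaled scores a and values V performs:
    #   m'       = max(m, max(a))
    #   sum_exp' = sum_exp * 2^(m - m') + sum(2^(a - m'))
    #   acc'     = acc * 2^(m - m') + 2^(a - m') @ V
    # so that acc / sum_exp equals the softmax-weighted sum of V over all tokens
    # processed so far.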
70 |     if my_start_token_idx + seq_block_size >= my_seq_len:
71 |         # First store the new KV cache
72 |         my_block_pos = (my_seq_len-1) // block_size
73 |         my_block_offset = (my_seq_len-1) % block_size
74 |         my_block_index = tl.load(block_table + my_seq_id*max_blocks_per_seq + my_block_pos).to(tl.int64)
75 |         offs_kv = (my_batch_id * num_kv_heads + my_kv_head_id) * head_dim + tl.arange(0, head_dim)
76 |         offs_kvcache = (((cur_layer*num_blocks + my_block_index) * num_kv_heads + my_kv_head_id) * block_size + my_block_offset) * head_dim + tl.arange(0, head_dim)
77 |         tl.store(k_cache + offs_kvcache, tl.load(k + offs_kv))
78 |         tl.store(v_cache + offs_kvcache, tl.load(v + offs_kv))
79 | 
80 |         # The seq block I am processing is the last one in the sequence
81 |         my_num_blocks = tl.cdiv(
82 |             my_seq_len - my_start_token_idx,
83 |             block_size
84 |         )
85 | 
86 |         for block_i in range(0, my_num_blocks):
87 |             block_idx = start_block_idx + block_i
88 |             block_index = tl.load(block_table + my_seq_id*max_blocks_per_seq + block_idx).to(tl.int64)
89 |             k_block = tl.load(k_ptrs + block_index*num_kv_heads*block_size*head_dim) # [block_size, head_dim]
90 |             attn_score = tl.sum(my_q[None, :] * k_block, axis=1) # [block_size]
91 |             attn_score = attn_score * softmax_scale
92 |             offs_token = block_i*block_size + my_start_token_idx + tl.arange(0, block_size)
93 |             attn_score = tl.where(offs_token < my_seq_len, attn_score, float('-1e20'))
94 |             v_block = tl.load(v_ptrs + block_index*num_kv_heads*block_size*head_dim) # [block_size, head_dim]
95 | 
96 |             cur_max_score = tl.max(attn_score, axis=0)
97 |             new_max_score = tl.maximum(max_score, cur_max_score)
98 |             exp_attn_score = tl.math.exp2(attn_score - new_max_score)
99 |             old_acc_scale = tl.math.exp2(max_score - new_max_score)
100 | 
101 |             acc = acc*old_acc_scale + tl.sum(exp_attn_score[:, None]*v_block, axis=0)
102 |             sum_exp = sum_exp*old_acc_scale + tl.sum(exp_attn_score, axis=0)
103 |             max_score = new_max_score
104 |     else:
105 |         # The seq block I am processing is NOT the last one in the sequence
106 |         for block_i in tl.static_range(0, seq_block_size // block_size):
107 |             block_idx = start_block_idx + block_i
108 |             block_index = tl.load(block_table + my_seq_id*max_blocks_per_seq + block_idx).to(tl.int64)
109 |             k_block = tl.load(k_ptrs + block_index*num_kv_heads*block_size*head_dim) # [block_size, head_dim]
110 |             attn_score = tl.sum(my_q[None, :] * k_block, axis=1) # [block_size]
111 |             attn_score = attn_score * softmax_scale
112 |             v_block = tl.load(v_ptrs + block_index*num_kv_heads*block_size*head_dim) # [block_size, head_dim]
113 | 
114 |             cur_max_score = tl.max(attn_score, axis=0)
115 |             new_max_score = tl.maximum(max_score, cur_max_score)
116 |             exp_attn_score = tl.math.exp2(attn_score - new_max_score)
117 |             old_acc_scale = tl.math.exp2(max_score - new_max_score)
118 | 
119 |             acc = acc*old_acc_scale + tl.sum(exp_attn_score[:, None]*v_block, axis=0)
120 |             sum_exp = sum_exp*old_acc_scale + tl.sum(exp_attn_score, axis=0)
121 |             max_score = new_max_score
122 | 
123 |     offs_mid_o = my_batch_id*num_q_heads*num_seq_blocks*head_dim + my_seq_block_id*head_dim + (my_q_head_id*num_seq_blocks*head_dim) + tl.arange(0, head_dim)
124 |     tl.store(mid_o + offs_mid_o, acc / sum_exp)
125 |     offs_mid_o_logexpsum = my_batch_id*num_q_heads*num_seq_blocks + my_seq_block_id + my_q_head_id*num_seq_blocks
126 |     tl.store(mid_o_logexpsum + offs_mid_o_logexpsum, tl.math.log2(sum_exp) + max_score) # Here log2(sum_exp) + max_score = log2(sum(2^{a_i})), the base-2 log-sum-exp of this block's scaled scores
127 | 
128 | 
129 | cached_phase2_bin = {}
130 | 
131 | @triton.jit
132 | def
_fwd_paged_attention_phase2( 133 | mid_o: torch.Tensor, # [num_decoding_seqs, num_q_heads, num_seq_blocks, head_dim], contiguous 134 | mid_o_logexpsum: torch.Tensor, # [num_decoding_seqs, num_q_heads, num_seq_blocks], contiguous 135 | o: torch.Tensor, # [num_decoding_seqs, num_q_heads, head_dim], contiguous 136 | 137 | decoding_seq_lens: torch.Tensor, # [num_decoding_seqs], contiguous 138 | 139 | num_q_heads: tl.constexpr, 140 | head_dim: tl.constexpr, 141 | num_seq_blocks: tl.constexpr, 142 | seq_block_size: tl.constexpr, 143 | ): 144 | # grid shape: [num_decoding_seqs, num_q_heads] 145 | my_batch_id = tl.program_id(0) 146 | my_q_head_id = tl.program_id(1) 147 | 148 | my_seq_len = tl.load(decoding_seq_lens + my_batch_id) 149 | my_num_seq_blocks = tl.cdiv(my_seq_len, seq_block_size) 150 | 151 | sum_exp = 0.0 152 | max_score = float("-1e20") 153 | acc = tl.zeros([head_dim], dtype=tl.float32) 154 | 155 | for seq_block_id in range(my_num_seq_blocks): 156 | offs_mid_o = ((my_batch_id*num_q_heads+my_q_head_id)*num_seq_blocks+seq_block_id)*head_dim + tl.arange(0, head_dim) 157 | offs_mid_o_logexpsum = (my_batch_id*num_q_heads+my_q_head_id)*num_seq_blocks+seq_block_id 158 | cur_mid_o = tl.load(mid_o + offs_mid_o) # [head_dim] 159 | cur_mid_o_logexpsum = tl.load(mid_o_logexpsum + offs_mid_o_logexpsum) 160 | 161 | new_max_score = tl.maximum(max_score, cur_mid_o_logexpsum) 162 | old_scale = tl.math.exp2(max_score - new_max_score) 163 | exp_score = tl.math.exp2(cur_mid_o_logexpsum - new_max_score) 164 | acc = acc * old_scale + exp_score * cur_mid_o 165 | sum_exp = sum_exp * old_scale + exp_score 166 | max_score = new_max_score 167 | 168 | offs_o = (my_batch_id*num_q_heads+my_q_head_id)*head_dim + tl.arange(0, head_dim) 169 | tl.store(o + offs_o, (acc / sum_exp).to(tl.float16)) 170 | 171 | 172 | def paged_attention( 173 | q: torch.Tensor, # [num_decoding_seqs, num_q_heads, head_dim] 174 | k: torch.Tensor, # [num_decoding_seqs, num_kv_heads, head_dim] 175 | v: torch.Tensor, # [num_decoding_seqs, num_kv_heads, head_dim] 176 | o: torch.Tensor, # [num_decoding_seqs, num_q_heads, head_dim] 177 | k_cache: torch.Tensor, 178 | v_cache: torch.Tensor, 179 | softmax_scale: float, 180 | block_table: torch.Tensor, 181 | seq_ids: torch.Tensor, 182 | seq_lens: torch.Tensor, 183 | cur_layer: int, 184 | seq_block_size: int, 185 | num_seq_blocks: int 186 | ): 187 | start = time.perf_counter() 188 | assert q.is_contiguous() 189 | assert k_cache.is_contiguous() 190 | assert v_cache.is_contiguous() 191 | assert block_table.is_contiguous() 192 | assert o.is_contiguous() 193 | 194 | num_q_heads = q.shape[1] 195 | head_dim = q.shape[2] 196 | num_blocks = k_cache.shape[1] 197 | num_kv_heads = k_cache.shape[2] 198 | block_size = k_cache.shape[3] 199 | block_table_width = block_table.shape[1] 200 | assert seq_block_size % block_size == 0 201 | 202 | num_decoding_seqs = seq_ids.shape[0] 203 | 204 | 205 | mid_o = torch.empty(( 206 | num_decoding_seqs, 207 | num_q_heads, 208 | num_seq_blocks, 209 | head_dim 210 | ), device=q.device, dtype=torch.float32) 211 | mid_o_logexpsum = torch.empty(( 212 | num_decoding_seqs, 213 | num_q_heads, 214 | num_seq_blocks 215 | ), device=q.device, dtype=torch.float32) 216 | 217 | grid = (num_decoding_seqs, num_q_heads, num_seq_blocks) 218 | 219 | global cached_phase1_bin 220 | if True: #seq_block_size not in cached_phase1_bin: 221 | cached_phase1_bin[seq_block_size] = _fwd_paged_attention_phase1[grid]( 222 | mid_o, mid_o_logexpsum, 223 | q, k, v, k_cache, v_cache, 224 | block_table, 225 | 226 | # 
Here we multiply softmax_scale by log2(e) and use `exp2` instead of 227 | # `exp` because of two reasons: 228 | # 1. Up to 12 Jun 2024, all NVIDIA GPUs does not have a `exp` instruction 229 | # in PTX. When calculating `exp`, they multiply the input by log2(e) 230 | # and use `exp2` instead. 231 | # 2. Some optimizations are disabled while using `exp` in a loop, see 232 | # https://github.com/triton-lang/triton/issues/2961 233 | softmax_scale * 1.442695040888963, 234 | seq_lens, 235 | seq_ids, 236 | num_seq_blocks, 237 | cur_layer, 238 | 239 | num_blocks, 240 | num_q_heads, 241 | num_kv_heads, 242 | num_q_heads // num_kv_heads, 243 | block_size, 244 | head_dim, 245 | seq_block_size, 246 | block_table_width, 247 | num_warps = 1, 248 | num_stages = 4 249 | ) 250 | else: 251 | cached_phase1_bin[seq_block_size][grid]( 252 | mid_o, mid_o_logexpsum, 253 | q, k, v, k_cache, v_cache, 254 | block_table, 255 | softmax_scale * 1.442695040888963, 256 | seq_lens, 257 | seq_ids, 258 | num_seq_blocks, 259 | cur_layer 260 | ) 261 | 262 | grid = (num_decoding_seqs, num_q_heads, 1) 263 | 264 | global cached_phase2_bin 265 | if True: #seq_block_size not in cached_phase2_bin: 266 | cached_phase2_bin[seq_block_size] = _fwd_paged_attention_phase2[grid]( 267 | mid_o, mid_o_logexpsum, 268 | o, 269 | seq_lens, 270 | num_q_heads, 271 | head_dim, 272 | num_seq_blocks, 273 | seq_block_size 274 | ) 275 | else: 276 | cached_phase2_bin[seq_block_size][grid]( 277 | mid_o, mid_o_logexpsum, 278 | o, 279 | seq_lens 280 | ) 281 | 282 | end = time.perf_counter() 283 | return end - start --------------------------------------------------------------------------------
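Both paged-attention implementations above (the CUDA `paged_attention` in csrc/src/attention.cu and the Triton kernels in swiftllm/worker/kernels/paged_attn.py) use the same two-phase reduction: phase 1 produces, per sequence block, a locally normalized partial output (`mid_o`) together with a base-2 log-sum-exp of that block's scaled scores (`mid_o_logexpsum`); phase 2 merges the partial results into the final attention output. The snippet below is a minimal, self-contained PyTorch sketch of that merge rule on synthetic scores and values; it illustrates the math only and is not part of the package.

import torch

torch.manual_seed(0)

# Synthetic attention scores and values for one (sequence, head) pair.
seq_len, head_dim, seq_block_size = 100, 8, 32
scores = torch.randn(seq_len)            # already multiplied by softmax_scale
values = torch.randn(seq_len, head_dim)

# Reference: ordinary softmax attention over the whole sequence.
ref = torch.softmax(scores, dim=0) @ values

# Phase 1: per sequence block, a normalized partial output and a base-2
# log-sum-exp (mirroring mid_o / mid_o_logexpsum).
LOG2E = 1.442695040888963
mid_o, mid_lse = [], []
for start in range(0, seq_len, seq_block_size):
    a = scores[start:start + seq_block_size] * LOG2E   # base-2 exponents
    v = values[start:start + seq_block_size]
    m = a.max()
    e = torch.exp2(a - m)
    mid_o.append((e @ v) / e.sum())                    # block-local softmax output
    mid_lse.append(torch.log2(e.sum()) + m)            # log2 of the block's denominator

# Phase 2: merge the partial results using their log-sum-exp weights.
mid_o = torch.stack(mid_o)        # [num_seq_blocks, head_dim]
mid_lse = torch.stack(mid_lse)    # [num_seq_blocks]
m = mid_lse.max()
w = torch.exp2(mid_lse - m)
merged = (w[:, None] * mid_o).sum(dim=0) / w.sum()

assert torch.allclose(merged, ref, atol=1e-5)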