├── .gitignore ├── .typos.toml ├── res └── candle_vllm_logo.png ├── src ├── openai │ ├── utils.rs │ ├── models │ │ └── mod.rs │ ├── conversation │ │ ├── mod.rs │ │ └── default_conversation.rs │ ├── streaming.rs │ ├── mod.rs │ ├── requests.rs │ ├── responses.rs │ ├── pipelines │ │ ├── mod.rs │ │ └── llama.rs │ ├── openai_server.rs │ └── sampling_params.rs ├── paged_attention │ ├── input_metadata.rs │ ├── utils.rs │ ├── attn_bias.rs │ ├── mod.rs │ └── memory_efficient_attention.rs ├── lib.rs ├── main.rs ├── backend │ ├── mod.rs │ ├── layers.rs │ ├── paged_attention.rs │ └── cache.rs └── scheduler │ ├── cache_engine.rs │ ├── sequence.rs │ ├── block_engine.rs │ └── mod.rs ├── examples └── llama.py ├── LICENSE ├── .devcontainer └── devcontainer.json ├── .github └── workflows │ ├── install.py │ └── ci.yml ├── .vscode └── settings.json ├── Cargo.toml ├── tests ├── tests.rs └── test_flan-t5-quantized.rs ├── README.md ├── kernels ├── copy_blocks_kernel.cu └── reshape_and_cache_kernel.cu └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | .hf_token 3 | __pycache__ 4 | *.ptx -------------------------------------------------------------------------------- /.typos.toml: -------------------------------------------------------------------------------- 1 | [default] 2 | extend-ignore-identifiers-re = [ 3 | "mmaped", 4 | ] -------------------------------------------------------------------------------- /res/candle_vllm_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LLukas22/candle-vllm/master/res/candle_vllm_logo.png -------------------------------------------------------------------------------- /src/openai/utils.rs: -------------------------------------------------------------------------------- 1 | use std::time::{SystemTime, UNIX_EPOCH}; 2 | 3 | pub(crate) fn get_created_time_secs() -> u64 { 4 | SystemTime::now() 5 | .duration_since(UNIX_EPOCH) 6 | .expect("Time travel has occurred...") 7 | .as_secs() 8 | } 9 | -------------------------------------------------------------------------------- /src/openai/models/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod llama; 2 | 3 | pub trait ConfigLike { 4 | fn get_num_kv_heads(&self) -> usize; 5 | fn get_hidden_size(&self) -> usize; 6 | fn get_num_hidden_layers(&self) -> usize; 7 | fn get_num_attention_heads(&self) -> usize; 8 | fn get_vocab_size(&self) -> usize; 9 | fn get_sliding_window(&self) -> Option; 10 | fn get_head_size(&self) -> usize { 11 | self.get_hidden_size() / self.get_num_attention_heads() 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /src/openai/conversation/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod default_conversation; 2 | 3 | /// A trait for using conversation managers with a `ModulePipeline`. 
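/// For illustration, a minimal usage sketch (hypothetical; the concrete implementor lives in
/// `default_conversation`, and the variable names here are made up rather than taken from this file):
/// ```ignore
/// let mut conv = /* a type implementing `Conversation` */;
/// conv.set_system_message("You are a helpful assistant.".to_string());
/// conv.append_message("user".to_string(), "Explain how to best learn Rust.".to_string());
/// let prompt = conv.get_prompt();
/// ```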
4 | pub trait Conversation { 5 | fn set_system_message(&mut self, system_message: String); 6 | 7 | fn append_message(&mut self, role: String, message: String); 8 | 9 | fn append_none_message(&mut self, role: String); 10 | 11 | fn update_last_message(&mut self); 12 | 13 | fn get_roles(&self) -> &(String, String); 14 | 15 | fn get_prompt(&mut self) -> String; 16 | } 17 | -------------------------------------------------------------------------------- /examples/llama.py: -------------------------------------------------------------------------------- 1 | import openai 2 | 3 | # Run: HF_TOKEN=... cargo run --release -- --hf-token HF_TOKEN --port 2000 llama7b --repeat-last-n 64 4 | 5 | openai.api_key = "EMPTY" 6 | 7 | openai.base_url = "http://localhost:2000/v1/" 8 | 9 | completion = openai.chat.completions.create( 10 | model="llama7b", 11 | messages=[ 12 | { 13 | "role": "user", 14 | "content": "Explain how to best learn Rust.", 15 | }, 16 | ], 17 | max_tokens = 32, 18 | ) 19 | print(completion.choices[0].message.content) -------------------------------------------------------------------------------- /src/openai/streaming.rs: -------------------------------------------------------------------------------- 1 | use std::{error::Error, sync::Arc}; 2 | 3 | use actix_web::web::Bytes; 4 | use futures::Stream; 5 | use tokio::sync::mpsc::{channel, Receiver, Sender}; 6 | 7 | pub(crate) type SenderError = Arc; 8 | 9 | pub(crate) fn new_streaming_conn() -> (Sender>, Client) { 10 | let (tx, rx) = channel(128); 11 | (tx, Client(rx)) 12 | } 13 | 14 | pub(crate) struct Client(Receiver>); 15 | 16 | impl Stream for Client { 17 | type Item = Result; 18 | 19 | fn poll_next( 20 | mut self: std::pin::Pin<&mut Self>, 21 | cx: &mut std::task::Context<'_>, 22 | ) -> std::task::Poll> { 23 | self.0.poll_recv(cx) 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Eric Buehler 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | // For format details, see https://aka.ms/devcontainer.json. 
For config options, see the 2 | // README at: https://github.com/devcontainers/templates/tree/main/src/docker-existing-dockerfile 3 | { 4 | "name": "Existing Dockerfile", 5 | "build": { 6 | // Sets the run context to one level up instead of the .devcontainer folder. 7 | "context": "..", 8 | // Update the 'dockerFile' property if you aren't using the standard 'Dockerfile' filename. 9 | "dockerfile": "../Dockerfile" 10 | } 11 | // Features to add to the dev container. More info: https://containers.dev/features. 12 | // "features": {}, 13 | // Use 'forwardPorts' to make a list of ports inside the container available locally. 14 | // "forwardPorts": [], 15 | // Uncomment the next line to run commands after the container is created. 16 | // "postCreateCommand": "cat /etc/os-release", 17 | // Configure tool-specific properties. 18 | // "customizations": {}, 19 | // Uncomment to connect as an existing user other than the container default. More info: https://aka.ms/dev-containers-non-root. 20 | // "remoteUser": "devcontainer" 21 | } 22 | -------------------------------------------------------------------------------- /src/openai/mod.rs: -------------------------------------------------------------------------------- 1 | use std::sync::{Arc, Mutex}; 2 | 3 | use candle_core::Device; 4 | use tokenizers::{EncodeInput, Encoding, Tokenizer}; 5 | 6 | use self::{pipelines::llm_engine::LLMEngine, responses::APIError}; 7 | 8 | pub mod requests; 9 | pub mod responses; 10 | pub mod sampling_params; 11 | mod streaming; 12 | 13 | pub trait TokenizerWrapper<'s, E> 14 | where 15 | E: Into>, 16 | { 17 | fn tokenize(&self, input: E) -> Result; 18 | fn detokenize(&self, input: &[u32]) -> Result; 19 | } 20 | 21 | impl<'s, E> TokenizerWrapper<'s, E> for Tokenizer 22 | where 23 | E: Into>, 24 | { 25 | fn tokenize(&self, input: E) -> Result { 26 | self.encode(input, false).map_err(APIError::from) 27 | } 28 | 29 | fn detokenize(&self, input: &[u32]) -> Result { 30 | self.decode(input, false).map_err(APIError::from) 31 | } 32 | } 33 | 34 | #[derive(Clone)] 35 | pub struct PipelineConfig { 36 | pub max_model_len: usize, 37 | } 38 | 39 | #[derive(Clone)] 40 | pub struct OpenAIServerData<'s> { 41 | pub model: Arc>>, 42 | pub pipeline_config: PipelineConfig, 43 | pub device: Device, 44 | } 45 | 46 | pub mod conversation; 47 | pub mod models; 48 | pub mod openai_server; 49 | pub mod pipelines; 50 | pub mod utils; 51 | -------------------------------------------------------------------------------- /src/paged_attention/input_metadata.rs: -------------------------------------------------------------------------------- 1 | use candle_core::Tensor; 2 | 3 | use super::attn_bias::AttentionBiasBlockDiagonal; 4 | 5 | pub struct InputMetadata { 6 | pub prompt_lens: Vec, 7 | pub max_context_len: Option, 8 | pub block_tables: Option, 9 | pub context_lens: Option, 10 | pub slot_mapping: Tensor, 11 | pub attn_bias: Option>, 12 | pub is_prompt: bool, 13 | pub kv_cache_dtype: String, 14 | } 15 | 16 | impl InputMetadata { 17 | /// prompt_lens: Lengths of prompts. 18 | /// slot_mapping: The address to write the new KV to of each token. 19 | /// context_lens: the length of attention context for each generation token. 20 | /// max_context_len: The maximum context length. 21 | /// block_tables: The block tables. 
(Seq id -> list of physical block) 22 | /// kv_cache_dtype: KV cache datatype (auto or fp8_e5m2) 23 | pub fn new( 24 | prompt_lens: Vec, 25 | max_context_len: Option, 26 | block_tables: Option, 27 | context_lens: Option, 28 | slot_mapping: Tensor, 29 | kv_cache_dtype: String, 30 | ) -> Self { 31 | let is_prompt = !prompt_lens.is_empty(); 32 | Self { 33 | prompt_lens, 34 | max_context_len, 35 | block_tables, 36 | context_lens, 37 | slot_mapping, 38 | attn_bias: None, 39 | is_prompt, 40 | kv_cache_dtype, 41 | } 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /.github/workflows/install.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | 4 | subprocess.run(["sudo", "apt", "update", "-y"]) 5 | subprocess.run(["sudo", "apt", "install", "libssl-dev", "-y"]) 6 | subprocess.run(["sudo", "apt", "install", "pkg-config", "-y"]) 7 | 8 | try: 9 | import torch 10 | works = True 11 | except Exception: 12 | works = False 13 | 14 | if works: 15 | # Capture stdout as text so it can be split into lines. 16 | first = subprocess.run(["sudo", "find", "/", "-name", "libtorch_cpu.so"], capture_output=True, text=True).stdout.split("\n") 17 | else: 18 | first = [] 19 | 20 | nvcc_release = subprocess.run(["nvcc", "--version"], capture_output=True, text=True) 21 | assert nvcc_release.returncode == 0 22 | 23 | nvcc_release = nvcc_release.stdout.split("\n")[3] # Cuda compilation tools, release 11.5, V11.5.119 24 | nvcc_release = nvcc_release.split("release ")[1] # "11.5, V11.5.119" 25 | nvcc_release = float(nvcc_release.split(",")[0]) # 11.5 26 | 27 | print(f"Got nvcc version {nvcc_release}") 28 | if nvcc_release <= 11.8: 29 | subprocess.run(["pip", "install", "torch==2.1.0", "torchvision==0.16.0", "torchaudio==2.1.0", "--index-url", "https://download.pytorch.org/whl/cu118"]) 30 | else: 31 | subprocess.run(["pip", "install", "torch==2.1.0", "torchvision==0.16.0", "torchaudio==2.1.0", "--index-url", "https://download.pytorch.org/whl/cu121"]) 32 | 33 | after = subprocess.run(["sudo", "find", "/", "-name", "libtorch_cpu.so"], capture_output=True, text=True).stdout.split("\n") 34 | different = list(filter(lambda x: x not in first, after))[0] 35 | 36 | # open() does not expand "~", so expand it explicitly and terminate each entry with a newline. 37 | with open(os.path.expanduser("~/.bashrc"), "a") as f: 38 | f.write("\n# candle-vllm\n") 39 | f.write(f"export LD_LIBRARY_PATH={different}:$LD_LIBRARY_PATH\n") 40 | f.write("export LIBTORCH_USE_PYTORCH=1\n") 41 | 42 | # `source` is a shell builtin, so it must run inside a shell rather than being exec'd directly. 43 | subprocess.run(["bash", "-lc", "source ~/.bashrc"]) -------------------------------------------------------------------------------- /src/openai/requests.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use serde::{Deserialize, Serialize}; 4 | 5 | #[derive(Debug, Clone, Serialize, Deserialize)] 6 | #[serde(untagged)] 7 | pub enum Messages { 8 | Map(Vec>), 9 | Literal(String), 10 | } 11 | 12 | #[derive(Debug, Clone, Serialize, Deserialize)] 13 | pub enum StopTokens { 14 | Multi(Vec), 15 | Single(String), 16 | } 17 | 18 | #[derive(Debug, Clone, Serialize, Deserialize)] 19 | pub struct ChatCompletionRequest { 20 | pub model: String, 21 | pub messages: Messages, 22 | #[serde(default)] 23 | pub temperature: Option, //0.7 24 | #[serde(default)] 25 | pub top_p: Option, //1.0 26 | #[serde(default)] 27 | pub n: Option, //1 28 | #[serde(default)] 29 | pub max_tokens: Option, //None 30 | #[serde(default)] 31 | pub stop: Option, 32 | #[serde(default)] 33 | pub stream: Option, //false 34 | #[serde(default)] 35 | pub presence_penalty: Option, //0.0 36 | #[serde(default)] 37 | pub frequency_penalty: Option, //0.0 38 | #[serde(default)] 39 | pub logit_bias: Option>, //None 40 | 
#[serde(default)] 41 | pub user: Option, //None 42 | #[serde(default)] 43 | //Additional candle-vllm params 44 | pub top_k: Option, //-1 45 | #[serde(default)] 46 | pub best_of: Option, //None 47 | #[serde(default)] 48 | pub use_beam_search: Option, //false 49 | #[serde(default)] 50 | pub ignore_eos: Option, //false 51 | #[serde(default)] 52 | pub skip_special_tokens: Option, //false 53 | #[serde(default)] 54 | pub stop_token_ids: Option>, //[] 55 | } 56 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "files.associations": { 3 | "*.yml": "yaml", 4 | "*.ke": "Kestrel", 5 | "array": "cpp", 6 | "atomic": "cpp", 7 | "bit": "cpp", 8 | "*.tcc": "cpp", 9 | "bitset": "cpp", 10 | "cctype": "cpp", 11 | "chrono": "cpp", 12 | "clocale": "cpp", 13 | "cmath": "cpp", 14 | "compare": "cpp", 15 | "concepts": "cpp", 16 | "cstdarg": "cpp", 17 | "cstddef": "cpp", 18 | "cstdint": "cpp", 19 | "cstdio": "cpp", 20 | "cstdlib": "cpp", 21 | "cstring": "cpp", 22 | "ctime": "cpp", 23 | "cwchar": "cpp", 24 | "cwctype": "cpp", 25 | "deque": "cpp", 26 | "unordered_map": "cpp", 27 | "vector": "cpp", 28 | "exception": "cpp", 29 | "algorithm": "cpp", 30 | "functional": "cpp", 31 | "iterator": "cpp", 32 | "memory": "cpp", 33 | "memory_resource": "cpp", 34 | "numeric": "cpp", 35 | "optional": "cpp", 36 | "random": "cpp", 37 | "ratio": "cpp", 38 | "string": "cpp", 39 | "string_view": "cpp", 40 | "system_error": "cpp", 41 | "tuple": "cpp", 42 | "type_traits": "cpp", 43 | "utility": "cpp", 44 | "initializer_list": "cpp", 45 | "iosfwd": "cpp", 46 | "istream": "cpp", 47 | "limits": "cpp", 48 | "mutex": "cpp", 49 | "new": "cpp", 50 | "numbers": "cpp", 51 | "ostream": "cpp", 52 | "ranges": "cpp", 53 | "stdexcept": "cpp", 54 | "stop_token": "cpp", 55 | "streambuf": "cpp", 56 | "thread": "cpp", 57 | "typeinfo": "cpp", 58 | "__nullptr": "cpp", 59 | "__bit_reference": "cpp", 60 | "__functional_base": "cpp", 61 | "__memory": "cpp" 62 | } 63 | } -------------------------------------------------------------------------------- /src/paged_attention/utils.rs: -------------------------------------------------------------------------------- 1 | use candle_core::{DType, Device, Shape, Tensor}; 2 | 3 | use crate::{openai::responses::APIError, try_api}; 4 | 5 | // https://github.com/mokeyish/candle-ext/blob/main/src/triangular.rs 6 | pub(crate) fn apply_triangular( 7 | xs: &Tensor, 8 | diagonal: isize, 9 | upper: bool, 10 | ) -> Result { 11 | let device = xs.device(); 12 | let (l, s) = try_api!(xs.dims2()); 13 | let mut xs_tri = vec![]; 14 | for i in 0..l.try_into().unwrap() { 15 | for j in 0..s.try_into().unwrap() { 16 | let cond = if upper { 17 | i + diagonal > j 18 | } else { 19 | i + diagonal < j 20 | }; 21 | xs_tri.push(if cond { 0u8 } else { 1u8 }); 22 | } 23 | } 24 | (xs * try_api!(try_api!(Tensor::from_vec(xs_tri, (l * s,), device)).to_dtype(xs.dtype()))) 25 | .map_err(APIError::from) 26 | } 27 | 28 | pub(crate) fn materialize_causal_mask( 29 | shape: &Shape, 30 | dtype: DType, 31 | device: &Device, 32 | window_size: Option, 33 | from_bottomright: bool, 34 | ) -> Result { 35 | let create_as = if dtype != DType::BF16 { 36 | dtype 37 | } else { 38 | DType::F32 39 | }; 40 | let tensor = try_api!(Tensor::ones(shape, create_as, device)); 41 | 42 | let mut shift = 0usize; 43 | if from_bottomright { 44 | let num_queries = shape.dims()[shape.dims().len() - 2]; 45 | let num_keys = 
shape.dims()[shape.dims().len() - 1]; 46 | shift = num_keys - num_queries; 47 | } 48 | 49 | let mut mask = try_api!(apply_triangular(&tensor, shift.try_into().unwrap(), false)); 50 | if let Some(window_size) = window_size { 51 | mask = try_api!(apply_triangular( 52 | &mask, 53 | (shift - window_size + 1).try_into().unwrap(), 54 | false 55 | )); 56 | } 57 | try_api!(mask.log()).to_dtype(dtype).map_err(APIError::from) 58 | } 59 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "candle-vllm" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | 8 | [dependencies] 9 | actix-web = "4.4.0" 10 | anyhow = "1.0.75" 11 | candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.4.0" } 12 | candle-examples = { git = "https://github.com/huggingface/candle.git", version = "0.4.0" } 13 | candle-lora = { git = "https://github.com/EricLBuehler/candle-lora.git", version = "0.2.0" } 14 | candle-lora-macro = { git = "https://github.com/EricLBuehler/candle-lora.git", version = "0.2.0" } 15 | candle-lora-transformers = { git = "https://github.com/EricLBuehler/candle-lora.git", version = "0.2.0" } 16 | candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.4.0" } 17 | dyn-fmt = "0.4.0" 18 | serde = { version = "1.0.190", features = ["serde_derive"] } 19 | tokenizers = "0.15.0" 20 | uuid = { version = "1.5.0", features = ["v4"] } 21 | candle-transformers = { git = "https://github.com/huggingface/candle.git", version = "0.4.0" } 22 | hf-hub = "0.3.2" 23 | serde_json = "1.0.108" 24 | derive_more = "0.99.17" 25 | accelerate-src = { version = "0.3.2", optional = true } 26 | intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"], optional = true } 27 | cudarc = { version = "0.9.14", features = ["f16"], optional = true } 28 | half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] } 29 | candle-flash-attn = { git = "https://github.com/huggingface/candle.git", version = "0.4.0", optional = true } 30 | clap = { version = "4.4.7", features = ["derive"] } 31 | candle-sampling = { git = "https://github.com/EricLBuehler/candle-sampling.git", version = "0.2.0" } 32 | futures = "0.3.29" 33 | tokio = { version = "1.33.0", features = ["sync"] } 34 | env_logger = "0.10.1" 35 | tracing = "0.1.40" 36 | range-checked = { git = "https://github.com/EricLBuehler/range-checked.git", version = "0.1.0" } 37 | chrono = { version = "0.4.31", features = ["clock"] } 38 | either = "1.9.0" 39 | dirs = "5.0.1" 40 | 41 | [features] 42 | default = ["cuda"] 43 | accelerate = ["dep:accelerate-src", "candle-core/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"] 44 | cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:candle-flash-attn"] 45 | cudnn = ["candle-core/cudnn"] 46 | flash-attn = ["cuda", "candle-transformers/flash-attn"] 47 | mkl = ["dep:intel-mkl-src", "candle-core/mkl", "candle-nn/mkl", "candle-transformers/mkl"] 48 | nccl = ["cuda", "cudarc/nccl"] 49 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #![warn(clippy::cast_lossless)] 2 | 3 | use clap::Subcommand; 4 | use openai::pipelines::{ 5 | llama::{LlamaLoader, 
LlamaSpecificConfig}, 6 | ModelLoader, 7 | }; 8 | 9 | #[derive(Debug, Subcommand)] 10 | pub enum ModelSelected { 11 | /// Select the llama7b model. 12 | Llama7b { 13 | /// Control the application of repeat penalty for the last n tokens 14 | #[arg(long)] 15 | repeat_last_n: usize, 16 | }, 17 | 18 | /// Select the llama13b model. 19 | Llama13b { 20 | /// Control the application of repeat penalty for the last n tokens 21 | #[arg(long)] 22 | repeat_last_n: usize, 23 | }, 24 | 25 | /// Select the llama70b model. 26 | Llama70b { 27 | /// Control the application of repeat penalty for the last n tokens 28 | #[arg(long)] 29 | repeat_last_n: usize, 30 | }, 31 | } 32 | 33 | impl ToString for ModelSelected { 34 | fn to_string(&self) -> String { 35 | match self { 36 | ModelSelected::Llama7b { repeat_last_n: _ } => "llama7b".to_string(), 37 | ModelSelected::Llama13b { repeat_last_n: _ } => "llama13b".to_string(), 38 | ModelSelected::Llama70b { repeat_last_n: _ } => "llama70b".to_string(), 39 | } 40 | } 41 | } 42 | 43 | pub fn get_model_loader<'a>(selected_model: ModelSelected) -> (Box>, String) { 44 | match selected_model { 45 | ModelSelected::Llama7b { repeat_last_n } => ( 46 | Box::new(LlamaLoader::new( 47 | LlamaSpecificConfig::new(repeat_last_n), 48 | "llama7b".to_string(), 49 | )), 50 | "meta-llama/Llama-2-7b-chat-hf".to_string(), 51 | ), 52 | ModelSelected::Llama13b { repeat_last_n } => ( 53 | Box::new(LlamaLoader::new( 54 | LlamaSpecificConfig::new(repeat_last_n), 55 | "llama13b".to_string(), 56 | )), 57 | "meta-llama/Llama-2-13b-chat-hf".to_string(), 58 | ), 59 | ModelSelected::Llama70b { repeat_last_n } => ( 60 | Box::new(LlamaLoader::new( 61 | LlamaSpecificConfig::new(repeat_last_n), 62 | "llama70b".to_string(), 63 | )), 64 | "meta-llama/Llama-2-70b-chat-hf".to_string(), 65 | ), 66 | } 67 | } 68 | 69 | pub fn log_warning(message: &str) { 70 | eprintln!("Warning at {:?}: '{}'", chrono::offset::Utc::now(), message); 71 | } 72 | 73 | pub mod backend; 74 | pub mod openai; 75 | pub mod paged_attention; 76 | pub mod scheduler; 77 | -------------------------------------------------------------------------------- /src/openai/responses.rs: -------------------------------------------------------------------------------- 1 | use actix_web::error; 2 | use candle_sampling::logits_processor::Logprobs; 3 | use derive_more::{Display, Error}; 4 | 5 | use serde::{Deserialize, Serialize}; 6 | 7 | #[derive(Debug, Display, Error, Serialize)] 8 | #[display(fmt = "Error: {}", data)] 9 | pub struct APIError { 10 | data: String, 11 | } 12 | 13 | impl error::ResponseError for APIError {} 14 | 15 | impl APIError { 16 | pub fn new(data: String) -> Self { 17 | Self { data } 18 | } 19 | 20 | pub fn new_str(data: &str) -> Self { 21 | Self { 22 | data: data.to_string(), 23 | } 24 | } 25 | 26 | pub fn from(value: T) -> Self { 27 | //panic!("{}", value.to_string()); 28 | Self::new(value.to_string()) 29 | } 30 | } 31 | 32 | #[macro_export] 33 | macro_rules! try_api { 34 | ($candle_result:expr) => { 35 | match $candle_result { 36 | Ok(v) => v, 37 | Err(e) => { 38 | return Err(APIError::from(e)); 39 | } 40 | } 41 | }; 42 | } 43 | 44 | #[derive(Debug, Clone, Serialize, Deserialize)] 45 | pub struct ChatCompletionUsageResponse { 46 | pub completion_tokens: usize, 47 | pub prompt_tokens: usize, 48 | pub total_tokens: usize, 49 | } 50 | 51 | // tool_calls, function_call not supported! 
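// For illustration, a `ChatCompletionResponse` built from the structs below serializes to
// OpenAI-style JSON roughly like this (all values are hypothetical):
//
// {
//   "id": "cmpl-123",
//   "choices": [
//     { "message": { "content": "Hello!", "role": "assistant" },
//       "finish_reason": "stop", "index": 0, "logprobs": null }
//   ],
//   "created": 1700000000,
//   "model": "llama7b",
//   "object": "chat.completion",
//   "usage": { "completion_tokens": 2, "prompt_tokens": 5, "total_tokens": 7 }
// }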
52 | #[derive(Debug, Clone, Serialize, Deserialize)] 53 | pub struct ChatChoiceData { 54 | pub content: Option, 55 | pub role: String, 56 | } 57 | 58 | #[derive(Debug, Clone, Serialize, Deserialize)] 59 | pub struct WrapperLogprobs { 60 | pub content: Vec, 61 | } 62 | 63 | #[derive(Debug, Clone, Serialize, Deserialize)] 64 | pub struct ChatChoice { 65 | pub message: ChatChoiceData, 66 | pub finish_reason: Option, 67 | pub index: usize, 68 | pub logprobs: Option, 69 | } 70 | 71 | #[derive(Debug, Clone, Serialize, Deserialize)] 72 | pub struct ChatCompletionResponse { 73 | pub id: String, 74 | pub choices: Vec, 75 | pub created: u64, 76 | pub model: String, 77 | pub object: &'static str, 78 | pub usage: ChatCompletionUsageResponse, 79 | } 80 | 81 | // tool_calls, function_call not supported! 82 | #[derive(Debug, Clone, Serialize, Deserialize)] 83 | pub struct StreamingChoiceData { 84 | pub content: Option, 85 | pub role: String, 86 | } 87 | 88 | #[derive(Debug, Clone, Serialize, Deserialize)] 89 | pub struct StreamingChoice { 90 | pub delta: StreamingChoiceData, 91 | pub finish_reason: Option, 92 | pub index: usize, 93 | } 94 | 95 | #[derive(Debug, Clone, Serialize, Deserialize)] 96 | pub struct StreamingChatCompletionResponse { 97 | pub id: String, 98 | pub choices: Vec, 99 | pub created: u64, 100 | pub model: String, 101 | pub object: &'static str, 102 | } 103 | -------------------------------------------------------------------------------- /tests/tests.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::HashMap, 3 | sync::{Arc, Mutex}, 4 | }; 5 | 6 | use actix_web::{http::header::ContentType, test, web::Data, App}; 7 | use candle_core::{DType, Device}; 8 | use candle_vllm::{ 9 | get_model_loader, 10 | openai::{ 11 | self, openai_server::chat_completions, pipelines::llm_engine::LLMEngine, 12 | requests::Messages, responses::APIError, OpenAIServerData, 13 | }, 14 | scheduler::{cache_engine::CacheConfig, SchedulerConfig}, 15 | ModelSelected, 16 | }; 17 | 18 | #[actix_web::test] 19 | async fn test_llama() -> Result<(), APIError> { 20 | let (loader, model_id) = get_model_loader(ModelSelected::Llama7b { repeat_last_n: 64 }); 21 | let paths = loader.download_model( 22 | model_id, 23 | None, 24 | Some(std::env::var("TESTS_HF_TOKEN").unwrap()), 25 | None, 26 | )?; 27 | let model = loader.load_model(paths, DType::F16, Device::Cpu)?; 28 | let llm_engine = LLMEngine::new( 29 | model.0, 30 | SchedulerConfig { max_num_seqs: 256 }, 31 | CacheConfig { 32 | block_size: 16, 33 | num_gpu_blocks: None, 34 | num_cpu_blocks: None, 35 | fully_init: false, 36 | }, 37 | )?; 38 | 39 | let server_data = OpenAIServerData { 40 | pipeline_config: model.1, 41 | model: Arc::new(Mutex::new(llm_engine)), 42 | device: Device::Cpu, 43 | }; 44 | 45 | let app = test::init_service( 46 | App::new() 47 | .service(chat_completions) 48 | .app_data(Data::new(server_data)), 49 | ) 50 | .await; 51 | 52 | let mut system = HashMap::new(); 53 | system.insert("role".to_string(), "system".to_string()); 54 | system.insert( 55 | "content".to_string(), 56 | "You are a talented author who specializes in writing poems.".to_string(), 57 | ); 58 | 59 | let mut user = HashMap::new(); 60 | user.insert("role".to_string(), "user".to_string()); 61 | user.insert( 62 | "content".to_string(), 63 | "Please write me a poem about why Rust is a great programming language:".to_string(), 64 | ); 65 | 66 | let req = test::TestRequest::with_uri("/v1/chat/completions") 67 | 
.insert_header(ContentType::json()) 68 | .set_json(openai::requests::ChatCompletionRequest { 69 | model: "llama".to_string(), 70 | messages: Messages::Map(vec![system, user]), 71 | temperature: None, 72 | top_p: None, 73 | n: None, 74 | max_tokens: None, 75 | stop: None, 76 | stream: None, 77 | presence_penalty: None, 78 | frequency_penalty: None, 79 | logit_bias: None, 80 | user: None, 81 | top_k: None, 82 | best_of: None, 83 | use_beam_search: None, 84 | skip_special_tokens: None, 85 | ignore_eos: None, 86 | stop_token_ids: None, 87 | }) 88 | .to_request(); 89 | 90 | let resp = test::call_service(&app, req).await; 91 | println!("{:?}", resp.status()); 92 | println!("{:?}", resp.into_body()); 93 | Ok(()) 94 | } 95 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

 2 | <img src="./res/candle_vllm_logo.png" alt="candle vLLM"/> 3 | 

 4 | 5 | [![Continuous integration](https://github.com/EricLBuehler/candle-vllm/actions/workflows/ci.yml/badge.svg)](https://github.com/EricLBuehler/candle-vllm/actions/workflows/ci.yml) 6 | [![Discord server](https://dcbadge.vercel.app/api/server/FAeJRRJ8)](https://discord.gg/FAeJRRJ8) 7 | 8 | Efficient, easy-to-use platform for inference and serving local LLMs, including an OpenAI compatible API server. 9 | 10 | Please see [mistral.rs](https://github.com/EricLBuehler/mistral.rs), an efficient inference platform for many models, including quantized support. Additionally, it implements X-LoRA, a recently released method available [here](https://github.com/EricLBuehler/xlora). X-LoRA introduces a MoE-inspired method to densely gate LoRA adapters, powered by a model self-reflection forward pass. 11 | 12 | **candle-vllm is in flux and under breaking development; as such, it is currently unstable.** 13 | 14 | ## Features 15 | - OpenAI compatible API server provided for serving LLMs. 16 | - Highly extensible trait-based system to allow rapid implementation of new module pipelines. 17 | - Streaming support in generation. 18 | - Efficient management of key-value cache with PagedAttention. 19 | - Continuous batching. 20 | 21 | ### Pipelines 22 | - Llama 23 | - 7b 24 | - 13b 25 | - 70b 26 | - Mistral 27 | - 7b 28 | 29 | ## Examples 30 | See [this folder](examples/) for some examples. 31 | 32 | ### Example with Llama 7b 33 | In your terminal, install the `openai` Python package by running `pip install openai`. I use version `1.3.5`. 34 | 35 | Then, create a new Python file and write the following code: 36 | ```python 37 | import openai 38 | 39 | openai.api_key = "EMPTY" 40 | 41 | openai.base_url = "http://localhost:2000/v1/" 42 | 43 | completion = openai.chat.completions.create( 44 | model="llama7b", 45 | messages=[ 46 | { 47 | "role": "user", 48 | "content": "Explain how to best learn Rust.", 49 | }, 50 | ], 51 | max_tokens = 64, 52 | ) 53 | print(completion.choices[0].message.content) 54 | ``` 55 | Next, launch a `candle-vllm` instance by running `cargo run --release -- --port 2000 llama7b --repeat-last-n 64`. 56 | 57 | After the `candle-vllm` instance is running, run the Python script and enjoy efficient inference with an OpenAI compatible API server! 58 | 59 | ## Usage Help 60 | For general configuration help, run `cargo run -- --help`. 61 | 62 | For model-specific help, run `cargo run -- --port 1234 --help`. 63 | 64 | ## Installation 65 | Installing `candle-vllm` is as simple as the following steps. If you have any problems, please create an 66 | [issue](https://github.com/EricLBuehler/candle-vllm/issues). 
67 | 68 | 0) Be sure to install Rust here: https://www.rust-lang.org/tools/install 69 | 1) Run `sudo apt install libssl-dev` or equivalent install command 70 | 2) Run `sudo apt install pkg-config` or equivalent install command 71 | 72 | ## Contributing 73 | The following features are planned to be implemented, but contributions are especially welcome: 74 | - Sampling methods: 75 | - Beam search ([huggingface/candle#1319](https://github.com/huggingface/candle/issues/1319)) 76 | - More pipelines (from `candle-transformers`) 77 | 78 | ## Resources 79 | - Python implementation: [`vllm-project`](https://github.com/vllm-project/vllm) 80 | - [`vllm` paper](https://arxiv.org/abs/2309.06180) 81 | -------------------------------------------------------------------------------- /src/main.rs: -------------------------------------------------------------------------------- 1 | use std::sync::{Arc, Mutex}; 2 | 3 | use actix_web::middleware::Logger; 4 | use actix_web::web::Data; 5 | use actix_web::{App, HttpServer}; 6 | use candle_core::{DType, Device}; 7 | use candle_vllm::openai::openai_server::chat_completions; 8 | use candle_vllm::openai::pipelines::llm_engine::LLMEngine; 9 | use candle_vllm::openai::responses::APIError; 10 | use candle_vllm::openai::OpenAIServerData; 11 | use candle_vllm::scheduler::cache_engine::CacheConfig; 12 | use candle_vllm::scheduler::SchedulerConfig; 13 | use candle_vllm::{get_model_loader, ModelSelected}; 14 | use clap::Parser; 15 | 16 | #[derive(Parser, Debug)] 17 | #[command(author, version, about, long_about = None)] 18 | struct Args { 19 | /// Huggingface token environment variable (optional). If not specified, load using hf_token_path. 20 | #[arg(long)] 21 | hf_token: Option, 22 | 23 | /// Huggingface token file (optional). If neither `hf_token` or `hf_token_path` are specified this is used with the value 24 | /// of `~/.cache/huggingface/token` 25 | #[arg(long)] 26 | hf_token_path: Option, 27 | 28 | /// Port to serve on (localhost:port) 29 | #[arg(long)] 30 | port: u16, 31 | 32 | /// Set verbose mode (print all requests) 33 | #[arg(long)] 34 | verbose: bool, 35 | 36 | #[clap(subcommand)] 37 | command: ModelSelected, 38 | 39 | /// Maximum number of sequences to allow 40 | #[arg(long, default_value_t = 256)] 41 | max_num_seqs: usize, 42 | 43 | /// Size of a block 44 | #[arg(long, default_value_t = 16)] 45 | block_size: usize, 46 | } 47 | 48 | #[actix_web::main] 49 | async fn main() -> Result<(), APIError> { 50 | let args = Args::parse(); 51 | 52 | let (loader, model_id) = get_model_loader(args.command); 53 | let paths = loader.download_model(model_id, None, args.hf_token, args.hf_token_path)?; 54 | let model = loader.load_model(paths, DType::F16, Device::Cpu)?; 55 | let llm_engine = LLMEngine::new( 56 | model.0, 57 | SchedulerConfig { 58 | max_num_seqs: args.max_num_seqs, 59 | }, 60 | CacheConfig { 61 | block_size: args.block_size, 62 | num_gpu_blocks: None, 63 | num_cpu_blocks: None, 64 | fully_init: false, 65 | }, 66 | )?; 67 | 68 | let server_data = OpenAIServerData { 69 | pipeline_config: model.1, 70 | model: Arc::new(Mutex::new(llm_engine)), 71 | device: Device::Cpu, 72 | }; 73 | 74 | println!("Server started at http://127.0.0.1:{}.", args.port); 75 | if args.verbose { 76 | env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); 77 | 78 | HttpServer::new(move || { 79 | App::new() 80 | .wrap(Logger::default()) 81 | .service(chat_completions) 82 | .app_data(Data::new(server_data.clone())) 83 | }) 84 | .bind(("127.0.0.1", args.port)) 85 | 
.map_err(|e| APIError::new(e.to_string()))? 86 | .run() 87 | .await 88 | .map_err(|e| APIError::new(e.to_string()))?; 89 | } else { 90 | HttpServer::new(move || { 91 | App::new() 92 | .service(chat_completions) 93 | .app_data(Data::new(server_data.clone())) 94 | }) 95 | .bind(("127.0.0.1", args.port)) 96 | .map_err(|e| APIError::new(e.to_string()))? 97 | .run() 98 | .await 99 | .map_err(|e| APIError::new(e.to_string()))?; 100 | } 101 | 102 | Ok(()) 103 | } 104 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | on: 2 | schedule: 3 | - cron: '0 0 * * 1' 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | 9 | name: Continuous integration 10 | 11 | jobs: 12 | check: 13 | name: Check 14 | runs-on: ${{ matrix.os }} 15 | strategy: 16 | matrix: 17 | os: [ubuntu-latest, windows-latest, macOS-latest] 18 | rust: [stable] 19 | steps: 20 | - uses: actions/checkout@v2 21 | - uses: actions-rs/toolchain@v1 22 | with: 23 | profile: minimal 24 | toolchain: ${{ matrix.rust }} 25 | override: true 26 | - uses: actions-rs/cargo@v1 27 | with: 28 | command: check 29 | args: --examples 30 | 31 | fmt: 32 | name: Rustfmt 33 | runs-on: ubuntu-latest 34 | steps: 35 | - uses: actions/checkout@v2 36 | - uses: actions-rs/toolchain@v1 37 | with: 38 | profile: minimal 39 | toolchain: stable 40 | override: true 41 | - run: rustup component add rustfmt 42 | - uses: actions-rs/cargo@v1 43 | with: 44 | command: fmt 45 | args: --all -- --check 46 | 47 | #build: 48 | # name: Build 49 | # runs-on: ${{ matrix.os }} 50 | # strategy: 51 | # matrix: 52 | # os: [ubuntu-latest] 53 | # rust: [stable] 54 | # steps: 55 | # - uses: actions/checkout@v2 56 | # - uses: actions-rs/toolchain@v1 57 | # with: 58 | # profile: minimal 59 | # toolchain: ${{ matrix.rust }} 60 | # override: true 61 | # - uses: Swatinem/rust-cache@v2 62 | # - run: sudo apt-get update -y && sudo apt-get install -y libssl-dev pkg-config 63 | # - uses: Jimver/cuda-toolkit@v0.2.14 64 | # id: cuda-toolkit 65 | # with: 66 | # cuda: '12.2.2' 67 | # # - run: echo "Installed cuda version is: ${{steps.cuda-toolkit.outputs.cuda}}" 68 | # # - run: echo "Cuda install location: ${{steps.cuda-toolkit.outputs.CUDA_PATH}}" 69 | # - run: CUDA_COMPUTE_CAP=86 cargo build --features cuda,cudnn,flash-attn,nccl 70 | 71 | #clippy: 72 | # name: Clippy 73 | # runs-on: ${{ matrix.os }} 74 | # strategy: 75 | # matrix: 76 | # os: [ubuntu-latest, windows-latest, macOS-latest] 77 | # rust: [stable] 78 | # steps: 79 | # - uses: actions/checkout@v2 80 | # - uses: actions-rs/toolchain@v1 81 | # with: 82 | # profile: minimal 83 | # toolchain: ${{ matrix.rust }} 84 | # override: true 85 | # - run: rustup component add clippy 86 | # - uses: actions-rs/cargo@v1 87 | # with: 88 | # command: clippy 89 | # args: --workspace --tests --examples -- -D warnings 90 | 91 | #docs: 92 | # name: Docs 93 | # runs-on: ubuntu-latest 94 | # steps: 95 | # - uses: actions/checkout@v2 96 | # - uses: actions-rs/toolchain@v1 97 | # with: 98 | # profile: minimal 99 | # toolchain: stable 100 | # override: true 101 | # - uses: actions-rs/cargo@v1 102 | # with: 103 | # command: doc 104 | # args: --workspace 105 | 106 | typos: 107 | name: Typos 108 | runs-on: ubuntu-latest 109 | steps: 110 | - uses: actions/checkout@v2 111 | - uses: actions-rs/toolchain@v1 112 | with: 113 | profile: minimal 114 | toolchain: stable 115 | override: true 116 | - name: Typos check with custom config file 
117 | uses: crate-ci/typos@master 118 | with: 119 | config: .typos.toml -------------------------------------------------------------------------------- /kernels/copy_blocks_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | // Grid: (num_layers, num_pairs) 4 | template 5 | __device__ void copy_blocks_internal_kernel( 6 | int64_t* key_cache_ptrs, 7 | int64_t* value_cache_ptrs, 8 | const int64_t* __restrict__ block_mapping, 9 | const int numel_per_block) { 10 | const int layer_idx = blockIdx.x; 11 | const int pair_idx = blockIdx.y; 12 | 13 | scalar_t* key_cache = reinterpret_cast(key_cache_ptrs[layer_idx]); 14 | scalar_t* value_cache = reinterpret_cast(value_cache_ptrs[layer_idx]); 15 | int64_t src_block_number = block_mapping[2 * pair_idx]; 16 | int64_t dst_block_number = block_mapping[2 * pair_idx + 1]; 17 | 18 | const int64_t src_block_offset = src_block_number * numel_per_block; 19 | const int64_t dst_block_offset = dst_block_number * numel_per_block; 20 | for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) { 21 | int64_t src_offset = src_block_offset + i; 22 | int64_t dst_offset = dst_block_offset + i; 23 | key_cache[dst_offset] = key_cache[src_offset]; 24 | } 25 | for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) { 26 | int64_t src_offset = src_block_offset + i; 27 | int64_t dst_offset = dst_block_offset + i; 28 | value_cache[dst_offset] = value_cache[src_offset]; 29 | } 30 | } 31 | 32 | // Monomorphize the generics ourselves 33 | extern "C" __global__ void copy_blocks_kernel_u8(int64_t* key_cache_ptrs, 34 | int64_t* value_cache_ptrs, 35 | const int64_t* __restrict__ block_mapping, 36 | const int numel_per_block) { 37 | copy_blocks_internal_kernel(key_cache_ptrs, value_cache_ptrs, block_mapping, numel_per_block); 38 | } 39 | 40 | extern "C" __global__ void copy_blocks_kernel_u32(int64_t* key_cache_ptrs, 41 | int64_t* value_cache_ptrs, 42 | const int64_t* __restrict__ block_mapping, 43 | const int numel_per_block) { 44 | copy_blocks_internal_kernel(key_cache_ptrs, value_cache_ptrs, block_mapping, numel_per_block); 45 | } 46 | 47 | extern "C" __global__ void copy_blocks_kernel_i64(int64_t* key_cache_ptrs, 48 | int64_t* value_cache_ptrs, 49 | const int64_t* __restrict__ block_mapping, 50 | const int numel_per_block) { 51 | copy_blocks_internal_kernel(key_cache_ptrs, value_cache_ptrs, block_mapping, numel_per_block); 52 | } 53 | 54 | extern "C" __global__ void copy_blocks_kernel_f32(int64_t* key_cache_ptrs, 55 | int64_t* value_cache_ptrs, 56 | const int64_t* __restrict__ block_mapping, 57 | const int numel_per_block) { 58 | copy_blocks_internal_kernel(key_cache_ptrs, value_cache_ptrs, block_mapping, numel_per_block); 59 | } 60 | 61 | extern "C" __global__ void copy_blocks_kernel_f64(int64_t* key_cache_ptrs, 62 | int64_t* value_cache_ptrs, 63 | const int64_t* __restrict__ block_mapping, 64 | const int numel_per_block) { 65 | copy_blocks_internal_kernel(key_cache_ptrs, value_cache_ptrs, block_mapping, numel_per_block); 66 | } 67 | 68 | // f16, bf16 are special cases: We use a 16-bit integer to simulate the bit width. 69 | // SAFETY: This is technically UB due to aliasing, but it is OK because the width is compatible. 
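// For illustration: the host side (see `get_or_load_func` in src/backend/mod.rs) selects the
// "_f16"/"_bf16" entry points by dtype and passes the raw cache pointers unchanged, so only the
// 16-bit patterns are moved and no arithmetic is performed on the values. The check below is a
// sketch of the width assumption this relies on.
static_assert(sizeof(int16_t) == 2, "f16/bf16 copies assume a 16-bit stand-in type");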
70 | extern "C" __global__ void copy_blocks_kernel_f16(int64_t* key_cache_ptrs, 71 | int64_t* value_cache_ptrs, 72 | const int64_t* __restrict__ block_mapping, 73 | const int numel_per_block) { 74 | copy_blocks_internal_kernel(key_cache_ptrs, value_cache_ptrs, block_mapping, numel_per_block); 75 | } 76 | 77 | extern "C" __global__ void copy_blocks_kernel_bf16(int64_t* key_cache_ptrs, 78 | int64_t* value_cache_ptrs, 79 | const int64_t* __restrict__ block_mapping, 80 | const int numel_per_block) { 81 | copy_blocks_internal_kernel(key_cache_ptrs, value_cache_ptrs, block_mapping, numel_per_block); 82 | } -------------------------------------------------------------------------------- /src/backend/mod.rs: -------------------------------------------------------------------------------- 1 | mod cache; 2 | mod layers; 3 | mod paged_attention; 4 | 5 | const COPY_BLOCKS_PTX: &str = "kernels/copy_blocks_kernel.ptx"; 6 | 7 | const COPY_BLOCKS_KERNEL: &str = "copy_blocks_kernel"; 8 | 9 | const RESHAPE_AND_CACHE_PTX: &str = "kernels/reshape_and_cache_kernel.ptx"; 10 | 11 | const RESHAPE_AND_CACHE_KERNEL: &str = "reshape_and_cache_kernel"; 12 | 13 | const ROTARY_EMBDEDDING_PTX: &str = "kernels/rotary_embedding_kernel.ptx"; 14 | 15 | const ROTARY_EMBDEDDING_KERNEL: &str = "rotary_embedding_kernel"; 16 | 17 | pub fn get_or_load_func( 18 | ptx_file: &'static str, 19 | kernel_base: &str, 20 | dtype: DType, 21 | suffix: Option<&str>, 22 | device: &CudaDevice, 23 | ) -> Result { 24 | let spec = match dtype { 25 | DType::U8 => "_u8", 26 | DType::U32 => "_u32", 27 | DType::I64 => "_i64", 28 | DType::BF16 => "_bf16", 29 | DType::F16 => "_f16", 30 | DType::F32 => "_f32", 31 | DType::F64 => "_f64", 32 | }; 33 | let spec = if let Some(suffix) = suffix { 34 | spec.to_owned() + suffix 35 | } else { 36 | spec.to_owned() 37 | }; 38 | let kernel = kernel_base.to_owned() + &spec; 39 | device 40 | .get_or_load_func(&kernel, ptx_file) 41 | .map_err(APIError::from) 42 | } 43 | 44 | #[repr(transparent)] 45 | struct Conjoined<'a, T, R> { 46 | raw: *mut T, 47 | _ref: PhantomData<&'a mut R>, 48 | } 49 | 50 | impl<'a, T, R> Conjoined<'a, T, R> { 51 | fn new(raw: NonNull, _ref: &'a mut R) -> Self { 52 | Self { 53 | raw: raw.as_ptr(), 54 | _ref: PhantomData, 55 | } 56 | } 57 | } 58 | 59 | /// According to the docs: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1gb8f3dc3031b40da29d5f9a7139e52e15 60 | /// Each of the kernel params (*mut c_void) "must point to a region of memory from which the actual kernel parameter will be copied". 61 | /// This means that we must return a pointer to our pointer. 62 | /// 63 | /// ## Safety 64 | /// - The returned pointer **must not** outlive the &self reference. Otherwise, a dangling pointer is created. 
65 | unsafe impl<'a, T, R> DeviceRepr for Conjoined<'a, T, R> { 66 | fn as_kernel_param(&self) -> *mut std::ffi::c_void { 67 | addr_of!(self.raw) as *mut _ 68 | } 69 | } 70 | 71 | fn dispatch_get_cuda_pointer(tensor: Tensor) -> u64 { 72 | match tensor.dtype() { 73 | DType::BF16 => get_cuda_pointer::(tensor), 74 | DType::F16 => get_cuda_pointer::(tensor), 75 | DType::U8 => get_cuda_pointer::(tensor), 76 | DType::U32 => get_cuda_pointer::(tensor), 77 | DType::I64 => get_cuda_pointer::(tensor), 78 | DType::F32 => get_cuda_pointer::(tensor), 79 | DType::F64 => get_cuda_pointer::(tensor), 80 | } 81 | } 82 | 83 | fn get_cuda_pointer(tensor: Tensor) -> u64 { 84 | match &*tensor.storage_and_layout().0 { 85 | Storage::Cuda(cuda_storage) => *cuda_storage.as_cuda_slice::().unwrap().device_ptr(), 86 | other => panic!("Unsupported storage `{:?}`", other), 87 | } 88 | } 89 | 90 | pub use cache::*; 91 | use candle_core::{ 92 | cuda_backend::{ 93 | cudarc::driver::{CudaFunction, DevicePtr, DeviceRepr}, 94 | CudaDType, 95 | }, 96 | CudaDevice, DType, Storage, Tensor, 97 | }; 98 | use half::{bf16, f16}; 99 | pub use layers::*; 100 | pub use paged_attention::*; 101 | pub use std::ops::Deref; 102 | use std::{ 103 | marker::PhantomData, 104 | ptr::{addr_of, NonNull}, 105 | }; 106 | 107 | use crate::openai::responses::APIError; 108 | -------------------------------------------------------------------------------- /src/openai/pipelines/mod.rs: -------------------------------------------------------------------------------- 1 | use candle_core::{DType, Device, Tensor, WithDType}; 2 | use candle_sampling::logits_processor::Logprobs; 3 | use dirs; 4 | use either::Either; 5 | use std::{env, fs, path::PathBuf, sync::Arc}; 6 | 7 | use crate::{ 8 | paged_attention::input_metadata::InputMetadata, scheduler::sequence::Sequence, try_api, 9 | }; 10 | 11 | use super::{ 12 | conversation::Conversation, models::ConfigLike, responses::APIError, 13 | sampling_params::SamplingParams, PipelineConfig, TokenizerWrapper, 14 | }; 15 | 16 | pub mod llama; 17 | /// The LLMEngine is effectively a wrapper around a ModulePipeline. It contains a Scheduler and a CacheEngine 18 | /// which are used to scheduler and manage the cache during generation requests, respectively. 
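/// A construction sketch, mirroring how `main.rs` and the integration test wire it up (the
/// `pipeline` variable is assumed to be an already-loaded `ModulePipeline`):
/// ```ignore
/// let engine = LLMEngine::new(
///     pipeline,
///     SchedulerConfig { max_num_seqs: 256 },
///     CacheConfig { block_size: 16, num_gpu_blocks: None, num_cpu_blocks: None, fully_init: false },
/// )?;
/// ```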
19 | pub mod llm_engine; 20 | 21 | type TokenOrFinishReason = Either; 22 | 23 | pub trait ModulePipeline<'s>: Send + Sync { 24 | fn forward( 25 | &mut self, 26 | input_tokens: Tensor, 27 | input_positions: Tensor, 28 | kv_cache: Option<&Vec<(Tensor, Tensor)>>, 29 | input_metadata: InputMetadata, 30 | ) -> Result; 31 | 32 | fn sample( 33 | &mut self, 34 | logits: Tensor, 35 | sampling_params: &SamplingParams, 36 | seqs: &[(&usize, &Arc)], 37 | ) -> Result, APIError>; 38 | 39 | fn name(&self) -> &str; 40 | 41 | fn tokenizer(&self) -> &dyn TokenizerWrapper<'s, String>; 42 | 43 | fn get_conversation(&mut self) -> &mut dyn Conversation; 44 | 45 | fn get_model_config(&self) -> Box; 46 | 47 | fn get_dtype(&self) -> DType; 48 | } 49 | 50 | // TODO(EricLBuehler): Ensure the padding token matches tokenizer 51 | fn _make_tensor_with_pad( 52 | x: Vec>, 53 | max_len: usize, 54 | pad: D, 55 | ) -> Result { 56 | let mut padded_x = Vec::new(); 57 | for mut x_i in x { 58 | assert!(x_i.len() <= max_len); 59 | x_i.extend([pad].repeat(max_len - x_i.len())); 60 | let shape = (x_i.len(),); 61 | padded_x.push(try_api!(Tensor::from_vec( 62 | x_i, 63 | shape, 64 | &try_api!(Device::new_cuda(0)) 65 | ))); 66 | } 67 | Tensor::cat(&padded_x[..], 0).map_err(APIError::from) 68 | } 69 | 70 | pub(crate) fn get_token( 71 | hf_token: Option, 72 | hf_token_path: Option, 73 | ) -> Result { 74 | Ok(match (hf_token, hf_token_path) { 75 | (Some(envvar), None) => try_api!(env::var(envvar)), 76 | (None, Some(path)) => try_api!(fs::read_to_string(path)), 77 | (None, None) => try_api!(fs::read_to_string(format!( 78 | "{}/.cache/huggingface/token", 79 | dirs::home_dir() 80 | .ok_or(APIError::new_str("No home directory"))? 81 | .display() 82 | ))), 83 | _ => { 84 | return Err(APIError::new_str( 85 | "Do not specify `hf_token` and `hf_token_path` at the same time.", 86 | )) 87 | } 88 | }) 89 | } 90 | 91 | pub trait ModelPaths { 92 | fn get_weight_filenames(&self) -> &Vec; 93 | fn get_config_filename(&self) -> &PathBuf; 94 | fn get_tokenizer_filename(&self) -> &PathBuf; 95 | } 96 | 97 | pub trait ModelLoader<'a> { 98 | fn download_model( 99 | &self, 100 | model_id: String, 101 | revision: Option, 102 | hf_token: Option, 103 | hf_token_path: Option, 104 | ) -> Result, APIError>; 105 | 106 | fn load_model( 107 | &self, 108 | paths: Box, 109 | dtype: DType, 110 | device: Device, 111 | ) -> Result<(Box>, PipelineConfig), APIError>; 112 | } 113 | -------------------------------------------------------------------------------- /src/backend/layers.rs: -------------------------------------------------------------------------------- 1 | use candle_core::{ 2 | cuda_backend::cudarc::driver::{LaunchAsync, LaunchConfig}, 3 | DType, Device, Tensor, 4 | }; 5 | 6 | use crate::{ 7 | backend::{get_or_load_func, ROTARY_EMBDEDDING_KERNEL, ROTARY_EMBDEDDING_PTX}, 8 | openai::responses::APIError, 9 | try_api, 10 | }; 11 | 12 | use super::dispatch_get_cuda_pointer; 13 | 14 | /// # Safety 15 | /// Unsafe due to passing pointers 16 | pub unsafe fn rotary_embedding( 17 | positions: Tensor, 18 | query: &mut Tensor, 19 | key: &mut Tensor, 20 | head_size: usize, 21 | cos_sin_cache: Tensor, 22 | is_neox: bool, 23 | ) -> Result<(), APIError> { 24 | let positions_dev = positions.device().clone(); 25 | let Device::Cuda(dev) = positions_dev else { 26 | panic!("Expected the positions to be on a CUDA device.") 27 | }; 28 | 29 | if positions.dtype() != DType::I64 { 30 | return Err(APIError::new(format!( 31 | "`positions` has {:?} type, expected I64 type.", 32 | 
positions.dtype() 33 | ))); 34 | } 35 | 36 | if !query.device().same_device(positions.device()) { 37 | return Err(APIError::new(format!( 38 | "`query` and `positions` have different devices, got {:?} and {:?} respectively.", 39 | query.device(), 40 | positions.device() 41 | ))); 42 | } 43 | 44 | if !key.device().same_device(positions.device()) { 45 | return Err(APIError::new(format!( 46 | "`key` and `positions` have different devices, got {:?} and {:?} respectively.", 47 | key.device(), 48 | positions.device() 49 | ))); 50 | } 51 | 52 | if !cos_sin_cache.device().same_device(positions.device()) { 53 | return Err(APIError::new(format!( 54 | "`cos_sin_cache` and `positions` have different devices, got {:?} and {:?} respectively.", 55 | cos_sin_cache.device(), 56 | positions.device() 57 | ))); 58 | } 59 | 60 | let num_tokens = query.shape().elem_count() / query.shape().dims().last().unwrap(); 61 | let cache_shape = cos_sin_cache.shape().clone(); 62 | let rot_dim = cache_shape.dims().get(1).unwrap(); 63 | let num_heads = query.shape().dims().last().unwrap() / head_size; 64 | let num_kv_heads = key.shape().dims().last().unwrap() / head_size; 65 | let query_stride = query.stride().get(key.stride().len() - 2).unwrap(); 66 | let key_stride = key.stride().get(key.stride().len() - 2).unwrap(); 67 | 68 | let launch_conf = LaunchConfig { 69 | grid_dim: (num_tokens.try_into().unwrap(), 1u32, 1u32), 70 | block_dim: ( 71 | 512.min((num_heads * rot_dim / 2).try_into().unwrap()), 72 | 1u32, 73 | 1u32, 74 | ), 75 | shared_mem_bytes: 0, 76 | }; 77 | 78 | let positions_ptr = dispatch_get_cuda_pointer(positions); 79 | let key_ptr = dispatch_get_cuda_pointer(key.clone()); 80 | let query_ptr = dispatch_get_cuda_pointer(query.clone()); 81 | let cos_sin_cache_ptr = dispatch_get_cuda_pointer(cos_sin_cache); 82 | 83 | let stream = try_api!(dev.fork_default_stream()); 84 | 85 | let kernel = if is_neox { 86 | try_api!(get_or_load_func( 87 | ROTARY_EMBDEDDING_PTX, 88 | ROTARY_EMBDEDDING_KERNEL, 89 | query.dtype(), 90 | Some("_neox"), 91 | &dev 92 | )) 93 | } else { 94 | try_api!(get_or_load_func( 95 | ROTARY_EMBDEDDING_PTX, 96 | ROTARY_EMBDEDDING_KERNEL, 97 | query.dtype(), 98 | None, 99 | &dev 100 | )) 101 | }; 102 | 103 | try_api!(unsafe { 104 | kernel.launch_on_stream( 105 | &stream, 106 | launch_conf, 107 | ( 108 | positions_ptr, 109 | query_ptr, 110 | key_ptr, 111 | cos_sin_cache_ptr, 112 | *rot_dim, 113 | *query_stride, 114 | *key_stride, 115 | num_heads, 116 | num_kv_heads, 117 | head_size, 118 | ), 119 | ) 120 | }); 121 | 122 | Ok(()) 123 | } 124 | -------------------------------------------------------------------------------- /src/paged_attention/attn_bias.rs: -------------------------------------------------------------------------------- 1 | #![allow(dead_code)] 2 | 3 | use std::iter::zip; 4 | 5 | use candle_core::{DType, Device, Shape, Tensor}; 6 | 7 | use crate::openai::responses::APIError; 8 | 9 | use crate::paged_attention::utils; 10 | use crate::try_api; 11 | 12 | pub trait AttentionBiasBlockDiagonal { 13 | /// Queries and Keys are each divided into the same number of blocks. 14 | /// A query Q in block i cannot attend to a key which is not in block i, 15 | /// nor one which is farther from the initial key in block i than Q 16 | /// is from the initial query in block i. 
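/// For example, with two blocks of lengths 2 and 3, the attention pattern described above
/// allows (1) and masks (.) the following query/key pairs:
///
///            k0 k1 | k2 k3 k4
///   q0 (i=0)  1  . |  .  .  .
///   q1 (i=0)  1  1 |  .  .  .
///   q2 (i=1)  .  . |  1  .  .
///   q3 (i=1)  .  . |  1  1  .
///   q4 (i=1)  .  . |  1  1  1
///
/// `materialize` builds a bias tensor realizing such a pattern by slice-assigning a per-block
/// mask (from `_create_block_mask`) onto a full-size tensor.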
17 | fn materialize( 18 | &self, 19 | shape: &Shape, 20 | dtype: DType, 21 | device: &Device, 22 | ) -> Result { 23 | let mut mask = 24 | try_api!( 25 | try_api!(Tensor::zeros(&shape.dims().to_vec()[2..], dtype, device,)) 26 | .to_dtype(dtype) 27 | ); 28 | 29 | for ((q_start, q_end), (k_start, k_end)) in zip( 30 | self.get_q_seqinfo().intervals(), 31 | self.get_k_seqinfo().intervals(), 32 | ) { 33 | try_api!(mask.slice_assign( 34 | &[ 35 | q_start as usize..q_end as usize, 36 | k_start as usize..k_end as usize, 37 | ], 38 | &self._create_block_mask( 39 | &Shape::from_dims(&[(q_end - q_start) as usize, (k_end - k_start) as usize]), 40 | dtype, 41 | device, 42 | )?, 43 | )); 44 | } 45 | 46 | for _ in 0..shape.dims().len() - 2 { 47 | mask = try_api!(mask.unsqueeze(0)); 48 | } 49 | mask.expand(shape).map_err(APIError::from) 50 | } 51 | 52 | fn get_q_seqinfo(&self) -> &SeqLenInfo; 53 | 54 | fn get_k_seqinfo(&self) -> &SeqLenInfo; 55 | 56 | fn _create_block_mask( 57 | &self, 58 | shape: &Shape, 59 | dtype: DType, 60 | device: &Device, 61 | ) -> Result; 62 | 63 | fn make_local_attention( 64 | &self, 65 | _window_size: usize, 66 | ) -> Result, APIError> { 67 | unimplemented!(); 68 | } 69 | } 70 | 71 | #[derive(Clone)] 72 | pub struct SeqLenInfo { 73 | seqstart_py: Vec, 74 | } 75 | 76 | impl SeqLenInfo { 77 | fn new(seqstart_py: Vec) -> Self { 78 | Self { seqstart_py } 79 | } 80 | 81 | fn from_seqlens<'a>(seqlens: impl Iterator) -> Result { 82 | let mut seqstart_py = vec![0]; 83 | for seqlen in seqlens.into_iter() { 84 | seqstart_py.push(seqstart_py[seqstart_py.len() - 1] + seqlen); 85 | } 86 | Ok(Self::new(seqstart_py)) 87 | } 88 | 89 | fn intervals(&self) -> Box> { 90 | Box::new(zip( 91 | self.seqstart_py.clone(), 92 | self.seqstart_py[1..].to_vec(), 93 | )) 94 | } 95 | } 96 | 97 | pub struct BlockDiagonalCausalMask { 98 | q_seqinfo: SeqLenInfo, 99 | k_seqinfo: SeqLenInfo, 100 | _batch_sizes: Option>, 101 | } 102 | 103 | impl BlockDiagonalCausalMask { 104 | fn new(q_seqinfo: SeqLenInfo, k_seqinfo: SeqLenInfo, _batch_sizes: Option>) -> Self { 105 | Self { 106 | q_seqinfo, 107 | k_seqinfo, 108 | _batch_sizes, 109 | } 110 | } 111 | 112 | pub fn from_seqlens( 113 | q_seqlen: Vec, 114 | kv_seqlen: Option>, 115 | ) -> Result, APIError> { 116 | assert!(kv_seqlen.is_none() || q_seqlen.len() == kv_seqlen.as_ref().unwrap().len()); 117 | let q_seqinfo = try_api!(SeqLenInfo::from_seqlens(q_seqlen.iter())); 118 | let k_seqinfo = if kv_seqlen.is_none() || &q_seqlen == kv_seqlen.as_ref().unwrap() { 119 | q_seqinfo.clone() 120 | } else { 121 | try_api!(SeqLenInfo::from_seqlens(kv_seqlen.unwrap().iter())) 122 | }; 123 | Ok(Box::new(Self::new(q_seqinfo, k_seqinfo, None))) 124 | } 125 | } 126 | 127 | impl AttentionBiasBlockDiagonal for BlockDiagonalCausalMask { 128 | fn _create_block_mask( 129 | &self, 130 | shape: &Shape, 131 | dtype: DType, 132 | device: &Device, 133 | ) -> Result { 134 | Tensor::zeros(shape, dtype, device).map_err(APIError::from) 135 | } 136 | 137 | fn get_k_seqinfo(&self) -> &SeqLenInfo { 138 | &self.k_seqinfo 139 | } 140 | 141 | fn get_q_seqinfo(&self) -> &SeqLenInfo { 142 | &self.q_seqinfo 143 | } 144 | 145 | fn make_local_attention( 146 | &self, 147 | window_size: usize, 148 | ) -> Result, APIError> { 149 | Ok(Box::new(BlockDiagonalCausalLocalAttentionMask::new( 150 | self.q_seqinfo.clone(), 151 | self.k_seqinfo.clone(), 152 | self._batch_sizes.clone(), 153 | window_size, 154 | ))) 155 | } 156 | } 157 | 158 | pub struct BlockDiagonalCausalLocalAttentionMask { 159 | q_seqinfo: SeqLenInfo, 160 
| k_seqinfo: SeqLenInfo, 161 | _batch_sizes: Option>, 162 | _window_size: usize, 163 | } 164 | 165 | impl BlockDiagonalCausalLocalAttentionMask { 166 | fn new( 167 | q_seqinfo: SeqLenInfo, 168 | k_seqinfo: SeqLenInfo, 169 | _batch_sizes: Option>, 170 | window_size: usize, 171 | ) -> Self { 172 | Self { 173 | q_seqinfo, 174 | k_seqinfo, 175 | _batch_sizes, 176 | _window_size: window_size, 177 | } 178 | } 179 | } 180 | 181 | impl AttentionBiasBlockDiagonal for BlockDiagonalCausalLocalAttentionMask { 182 | fn _create_block_mask( 183 | &self, 184 | shape: &Shape, 185 | dtype: DType, 186 | device: &Device, 187 | ) -> Result { 188 | utils::materialize_causal_mask(shape, dtype, device, Some(self._window_size), false) 189 | } 190 | 191 | fn get_k_seqinfo(&self) -> &SeqLenInfo { 192 | &self.k_seqinfo 193 | } 194 | 195 | fn get_q_seqinfo(&self) -> &SeqLenInfo { 196 | &self.q_seqinfo 197 | } 198 | } 199 | 200 | pub struct LowerTriangularMaskWithTensorBias { 201 | bias: Tensor, 202 | } 203 | 204 | impl LowerTriangularMaskWithTensorBias { 205 | pub fn new(bias: Tensor) -> Self { 206 | Self { bias } 207 | } 208 | } 209 | 210 | impl AttentionBiasBlockDiagonal for LowerTriangularMaskWithTensorBias { 211 | fn materialize( 212 | &self, 213 | shape: &Shape, 214 | dtype: DType, 215 | device: &Device, 216 | ) -> Result { 217 | (try_api!(utils::materialize_causal_mask( 218 | shape, dtype, device, None, false 219 | )) + &self.bias) 220 | .map_err(APIError::from) 221 | } 222 | fn _create_block_mask( 223 | &self, 224 | _shape: &Shape, 225 | _dtype: DType, 226 | _device: &Device, 227 | ) -> Result { 228 | unimplemented!("should not be called"); 229 | } 230 | fn get_k_seqinfo(&self) -> &SeqLenInfo { 231 | unimplemented!("should not be called"); 232 | } 233 | fn get_q_seqinfo(&self) -> &SeqLenInfo { 234 | unimplemented!("should not be called"); 235 | } 236 | fn make_local_attention( 237 | &self, 238 | _window_size: usize, 239 | ) -> Result, APIError> { 240 | unimplemented!("should not be called"); 241 | } 242 | } 243 | -------------------------------------------------------------------------------- /kernels/reshape_and_cache_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | template 4 | __device__ void reshape_and_cache_internal_kernel( 5 | const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] 6 | const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] 7 | scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] 8 | scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, blosudo ck_size] 9 | const int64_t* __restrict__ slot_mapping, // [num_tokens] 10 | const int key_stride, 11 | const int value_stride, 12 | const int num_heads, 13 | const int head_size, 14 | const int block_size, 15 | const int x) { 16 | const int64_t token_idx = blockIdx.x; 17 | const int64_t slot_idx = slot_mapping[token_idx]; 18 | if (slot_idx < 0) { 19 | // Padding token that should be ignored. 
20 | return; 21 | } 22 | 23 | const int64_t block_idx = slot_idx / block_size; 24 | const int64_t block_offset = slot_idx % block_size; 25 | 26 | const int n = num_heads * head_size; 27 | for (int i = threadIdx.x; i < n; i += blockDim.x) { 28 | const int64_t src_key_idx = token_idx * key_stride + i; 29 | const int64_t src_value_idx = token_idx * value_stride + i; 30 | 31 | const int head_idx = i / head_size; 32 | const int head_offset = i % head_size; 33 | const int x_idx = head_offset / x; 34 | const int x_offset = head_offset % x; 35 | 36 | const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x 37 | + head_idx * (head_size / x) * block_size * x 38 | + x_idx * block_size * x 39 | + block_offset * x 40 | + x_offset; 41 | const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size 42 | + head_idx * head_size * block_size 43 | + head_offset * block_size 44 | + block_offset; 45 | key_cache[tgt_key_idx] = key[src_key_idx]; 46 | value_cache[tgt_value_idx] = value[src_value_idx]; 47 | } 48 | } 49 | 50 | // Monomorphize the generics ourselves 51 | extern "C" __global__ void reshape_and_cache_kernel_u8( 52 | const uint8_t* __restrict__ key, // [num_tokens, num_heads, head_size] 53 | const uint8_t* __restrict__ value, // [num_tokens, num_heads, head_size] 54 | uint8_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] 55 | uint8_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] 56 | const int64_t* __restrict__ slot_mapping, // [num_tokens] 57 | const int key_stride, 58 | const int value_stride, 59 | const int num_heads, 60 | const int head_size, 61 | const int block_size, 62 | const int x) { 63 | reshape_and_cache_internal_kernel(key, value, key_cache, value_cache, slot_mapping, key_stride, value_stride, num_heads, head_size, block_size, x); 64 | } 65 | 66 | extern "C" __global__ void reshape_and_cache_kernel_u32( 67 | const uint32_t* __restrict__ key, // [num_tokens, num_heads, head_size] 68 | const uint32_t* __restrict__ value, // [num_tokens, num_heads, head_size] 69 | uint32_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] 70 | uint32_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] 71 | const int64_t* __restrict__ slot_mapping, // [num_tokens] 72 | const int key_stride, 73 | const int value_stride, 74 | const int num_heads, 75 | const int head_size, 76 | const int block_size, 77 | const int x) { 78 | reshape_and_cache_internal_kernel(key, value, key_cache, value_cache, slot_mapping, key_stride, value_stride, num_heads, head_size, block_size, x); 79 | } 80 | 81 | extern "C" __global__ void reshape_and_cache_kernel_i64( 82 | const int64_t* __restrict__ key, // [num_tokens, num_heads, head_size] 83 | const int64_t* __restrict__ value, // [num_tokens, num_heads, head_size] 84 | int64_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] 85 | int64_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] 86 | const int64_t* __restrict__ slot_mapping, // [num_tokens] 87 | const int key_stride, 88 | const int value_stride, 89 | const int num_heads, 90 | const int head_size, 91 | const int block_size, 92 | const int x) { 93 | reshape_and_cache_internal_kernel(key, value, key_cache, value_cache, slot_mapping, key_stride, value_stride, num_heads, head_size, block_size, x); 94 | } 95 | 96 | extern "C" __global__ void reshape_and_cache_kernel_f32( 97 | const float* __restrict__ key, // 
[num_tokens, num_heads, head_size] 98 | const float* __restrict__ value, // [num_tokens, num_heads, head_size] 99 | float* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] 100 | float* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] 101 | const int64_t* __restrict__ slot_mapping, // [num_tokens] 102 | const int key_stride, 103 | const int value_stride, 104 | const int num_heads, 105 | const int head_size, 106 | const int block_size, 107 | const int x) { 108 | reshape_and_cache_internal_kernel(key, value, key_cache, value_cache, slot_mapping, key_stride, value_stride, num_heads, head_size, block_size, x); 109 | } 110 | 111 | extern "C" __global__ void reshape_and_cache_kernel_f64( 112 | const double* __restrict__ key, // [num_tokens, num_heads, head_size] 113 | const double* __restrict__ value, // [num_tokens, num_heads, head_size] 114 | double* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] 115 | double* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] 116 | const int64_t* __restrict__ slot_mapping, // [num_tokens] 117 | const int key_stride, 118 | const int value_stride, 119 | const int num_heads, 120 | const int head_size, 121 | const int block_size, 122 | const int x) { 123 | reshape_and_cache_internal_kernel(key, value, key_cache, value_cache, slot_mapping, key_stride, value_stride, num_heads, head_size, block_size, x); 124 | } 125 | 126 | extern "C" __global__ void reshape_and_cache_kernel_f16( 127 | const int16_t* __restrict__ key, // [num_tokens, num_heads, head_size] 128 | const int16_t* __restrict__ value, // [num_tokens, num_heads, head_size] 129 | int16_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] 130 | int16_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] 131 | const int64_t* __restrict__ slot_mapping, // [num_tokens] 132 | const int key_stride, 133 | const int value_stride, 134 | const int num_heads, 135 | const int head_size, 136 | const int block_size, 137 | const int x) { 138 | reshape_and_cache_internal_kernel(key, value, key_cache, value_cache, slot_mapping, key_stride, value_stride, num_heads, head_size, block_size, x); 139 | } 140 | 141 | extern "C" __global__ void reshape_and_cache_kernel_bf16( 142 | const int16_t* __restrict__ key, // [num_tokens, num_heads, head_size] 143 | const int16_t* __restrict__ value, // [num_tokens, num_heads, head_size] 144 | int16_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] 145 | int16_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] 146 | const int64_t* __restrict__ slot_mapping, // [num_tokens] 147 | const int key_stride, 148 | const int value_stride, 149 | const int num_heads, 150 | const int head_size, 151 | const int block_size, 152 | const int x) { 153 | reshape_and_cache_internal_kernel(key, value, key_cache, value_cache, slot_mapping, key_stride, value_stride, num_heads, head_size, block_size, x); 154 | } 155 | -------------------------------------------------------------------------------- /src/paged_attention/mod.rs: -------------------------------------------------------------------------------- 1 | use candle_core::{DType, Device, Tensor}; 2 | 3 | use crate::{ 4 | backend::{paged_attention_v1, paged_attention_v2, reshape_and_cache}, 5 | openai::responses::APIError, 6 | try_api, 7 | }; 8 | 9 | use self::input_metadata::InputMetadata; 10 | mod attn_bias; 11 | pub(crate) mod 
input_metadata; 12 | mod memory_efficient_attention; 13 | use memory_efficient_attention::_memory_efficient_attention; 14 | pub(crate) mod utils; 15 | 16 | const _PARTITION_SIZE: usize = 512; 17 | 18 | #[allow(dead_code)] 19 | pub struct PagedAttention { 20 | num_attention_heads: usize, 21 | head_dim: usize, 22 | num_key_value_heads: usize, 23 | scale: f32, 24 | sliding_window: Option, 25 | num_queries_per_kv: usize, 26 | alibi_slopes: Option, 27 | } 28 | 29 | impl PagedAttention { 30 | pub fn new( 31 | num_attention_heads: usize, 32 | head_dim: usize, 33 | scale: f32, 34 | num_key_value_heads: Option, 35 | sliding_window: Option, 36 | device: Device, 37 | alibi_slopes: Option>, 38 | ) -> Result { 39 | let num_key_value_heads = num_key_value_heads.unwrap_or(num_attention_heads); 40 | let num_queries_per_kv = num_attention_heads / num_key_value_heads; 41 | let alibi_slopes = if let Some(alibi_slopes) = alibi_slopes { 42 | Some(try_api!(Tensor::new(alibi_slopes, &device))) 43 | } else { 44 | None 45 | }; 46 | Ok(Self { 47 | num_attention_heads, 48 | head_dim, 49 | num_key_value_heads, 50 | scale, 51 | sliding_window, 52 | num_queries_per_kv, 53 | alibi_slopes, 54 | }) 55 | } 56 | 57 | /// Args: 58 | /// output: shape = [num_generation_tokens, num_heads, head_size] 59 | /// 60 | /// query: shape = [num_generation_tokens, num_heads, head_size] 61 | /// 62 | /// key_cache: shape = [num_blocks, num_kv_heads, head_size/x, 63 | /// block_size, x] 64 | /// 65 | /// value_cache: shape = [num_blocks, num_kv_heads, head_size, 66 | /// block_size] 67 | /// 68 | /// input_metadata: metadata for paged attention. 69 | /// 70 | /// alibi_slopes: shape = [num_heads] 71 | pub fn _paged_attention( 72 | &mut self, 73 | query: Tensor, 74 | key_cache: Tensor, 75 | value_cache: Tensor, 76 | input_metadata: &mut InputMetadata, 77 | alibi_slopes: Option, 78 | ) -> Result { 79 | let block_size = *value_cache.shape().dims().get(3).unwrap(); 80 | let (num_seqs, num_heads, _head_size) = try_api!(query.shape().dims3()); 81 | let max_num_partitions = 82 | (input_metadata.max_context_len.unwrap() + _PARTITION_SIZE - 1) / _PARTITION_SIZE; 83 | 84 | let use_v1 = input_metadata.max_context_len.unwrap() <= 8192 85 | && (max_num_partitions == 1 || num_seqs * num_heads > 512); 86 | let output = if use_v1 { 87 | //Run PagedAttention V1 88 | paged_attention_v1( 89 | query, 90 | key_cache, 91 | value_cache, 92 | self.num_key_value_heads.try_into().unwrap(), 93 | self.scale, 94 | input_metadata.block_tables.as_ref().unwrap().clone(), 95 | input_metadata.context_lens.as_ref().unwrap().clone(), 96 | block_size, 97 | input_metadata.max_context_len.unwrap(), 98 | alibi_slopes, 99 | &input_metadata.kv_cache_dtype, 100 | )? 
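// (V1 handles each sequence's whole context in a single kernel pass and is chosen above
// roughly when the context is short or there is already plenty of parallel work. The V2
// path below instead splits the context into _PARTITION_SIZE-token partitions and combines
// the partial results using the exp_sums / max_logits buffers it allocates.)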
101 | } else { 102 | //Run PagedAttention V2 103 | assert_eq!(_PARTITION_SIZE % block_size, 0); 104 | 105 | let exp_sums = try_api!(Tensor::zeros( 106 | (num_seqs, num_heads, max_num_partitions), 107 | DType::F32, 108 | query.device(), 109 | )); 110 | let max_logits = try_api!(exp_sums.zeros_like()); 111 | 112 | paged_attention_v2( 113 | exp_sums, 114 | max_logits, 115 | query, 116 | key_cache, 117 | value_cache, 118 | self.num_key_value_heads.try_into().unwrap(), 119 | self.scale, 120 | input_metadata.block_tables.as_ref().unwrap().clone(), 121 | input_metadata.context_lens.as_ref().unwrap().clone(), 122 | block_size, 123 | input_metadata.max_context_len.unwrap(), 124 | alibi_slopes, 125 | ) 126 | }; 127 | Ok(output) 128 | } 129 | 130 | #[allow(clippy::too_many_arguments)] 131 | fn _normal_attention( 132 | &self, 133 | query: Tensor, 134 | key: Tensor, 135 | value: Tensor, 136 | input_metadata: &mut InputMetadata, 137 | seq_len: usize, 138 | batch_size: usize, 139 | device: &Device, 140 | dtype: DType, 141 | ) -> Result { 142 | _memory_efficient_attention( 143 | self, 144 | query, 145 | key, 146 | value, 147 | input_metadata, 148 | seq_len, 149 | batch_size, 150 | device, 151 | dtype, 152 | ) 153 | } 154 | 155 | #[allow(clippy::too_many_arguments)] 156 | #[allow(unused_variables)] 157 | /// query: shape = [batch_size, seq_len, num_heads * head_size] 158 | /// key: shape = [batch_size, seq_len, num_kv_heads * head_size] 159 | /// value: shape = [batch_size, num_kv_heads * head_size] 160 | /// key_cache: shape = [num_blocks, num_kv_heads, head_size/x, 161 | /// block_size, x] 162 | /// value_cache: shape = [num_blocks, num_kv_heads, head_size, 163 | /// block_size] 164 | /// input_metadata: metadata for paged attention. 165 | pub fn forward( 166 | &mut self, 167 | query: Tensor, 168 | key: Tensor, 169 | value: Tensor, 170 | mut key_cache: Option, 171 | mut value_cache: Option, 172 | input_metadata: &mut InputMetadata, 173 | dtype: DType, 174 | device: Device, 175 | ) -> Result { 176 | let (batch_size, seq_len, hidden_size) = try_api!(query.shape().dims3()); 177 | let query = try_api!(query.reshape(((), self.num_attention_heads, self.head_dim))); 178 | let key = try_api!(key.reshape(((), self.num_key_value_heads, self.head_dim))); 179 | let value = try_api!(value.reshape(((), self.num_key_value_heads, self.head_dim))); 180 | let slot_mapping = try_api!(input_metadata 181 | .slot_mapping 182 | .flatten(0, input_metadata.slot_mapping.dims().len())); 183 | 184 | if key_cache.as_ref().is_some_and(|_| value_cache.is_some()) { 185 | try_api!(unsafe { 186 | reshape_and_cache( 187 | key.clone(), 188 | value.clone(), 189 | key_cache.as_mut().unwrap(), 190 | value_cache.as_mut().unwrap(), 191 | slot_mapping, 192 | ) 193 | }); 194 | } 195 | 196 | let output = if input_metadata.is_prompt { 197 | self._normal_attention( 198 | query, 199 | key, 200 | value, 201 | input_metadata, 202 | seq_len, 203 | batch_size, 204 | &device, 205 | dtype, 206 | )? 207 | } else { 208 | self._paged_attention( 209 | query, 210 | key_cache.as_ref().unwrap().clone(), 211 | value_cache.as_ref().unwrap().clone(), 212 | input_metadata, 213 | None, 214 | )? 
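// (Decode path: each sequence contributes a single query token, and attention is computed
// against the previously cached keys/values located through input_metadata's block tables,
// rather than against freshly computed K/V.)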
215 | }; 216 | 217 | output 218 | .reshape((batch_size, seq_len, hidden_size)) 219 | .map_err(APIError::from) 220 | } 221 | } 222 | -------------------------------------------------------------------------------- /src/openai/openai_server.rs: -------------------------------------------------------------------------------- 1 | use std::thread; 2 | 3 | use super::requests::ChatCompletionRequest; 4 | use super::requests::Messages; 5 | use super::responses::{APIError, ChatCompletionResponse, ChatCompletionUsageResponse}; 6 | use super::sampling_params::{EarlyStoppingCondition, SamplingParams}; 7 | use super::streaming::new_streaming_conn; 8 | use super::utils::get_created_time_secs; 9 | use super::OpenAIServerData; 10 | use actix_web::web::Bytes; 11 | use actix_web::{post, web, Either, HttpResponse}; 12 | use tokenizers::Encoding; 13 | use uuid::Uuid; 14 | 15 | fn verify_model(data: &OpenAIServerData<'_>, model_name: &String) -> Result<(), APIError> { 16 | let current_name = { 17 | let model = data.model.lock().unwrap(); 18 | model.get_pipeline().name().to_string() 19 | }; 20 | if ¤t_name != model_name { 21 | Err(APIError::new(format!( 22 | "Model name `{model_name}` is invalid." 23 | ))) 24 | } else { 25 | Ok(()) 26 | } 27 | } 28 | 29 | // Get prompt, roles 30 | async fn get_gen_prompt( 31 | data: &OpenAIServerData<'_>, 32 | request: &web::Json, 33 | ) -> Result { 34 | let mut model = data.model.lock().unwrap(); 35 | let conversation = model.get_mut_pipeline().get_conversation(); 36 | 37 | match &request.messages { 38 | Messages::Literal(msg) => { 39 | return Ok(msg.clone()); 40 | } 41 | Messages::Map(messages) => { 42 | for message in messages { 43 | let role = message 44 | .get("role") 45 | .ok_or(APIError::new("Message key `role` not found.".to_string()))?; 46 | let content = message 47 | .get("content") 48 | .ok_or(APIError::new( 49 | "Message key `content` not found.".to_string(), 50 | ))? 51 | .clone(); 52 | 53 | if role == "system" { 54 | conversation.set_system_message(content); 55 | } else if role == "user" { 56 | conversation.append_message(conversation.get_roles().0.clone(), content) 57 | } else if role == "assistant" { 58 | conversation.append_message(conversation.get_roles().1.clone(), content) 59 | } else { 60 | return Err(APIError::new(format!("Unknown role: {role}"))); 61 | } 62 | } 63 | } 64 | } 65 | 66 | conversation.append_none_message(conversation.get_roles().1.clone()); 67 | 68 | Ok(conversation.get_prompt()) 69 | } 70 | 71 | fn check_length( 72 | request: &web::Json, 73 | prompt: String, 74 | data: &OpenAIServerData<'_>, 75 | ) -> Result { 76 | let token_ids = { 77 | let model = data.model.lock().unwrap(); 78 | model.get_pipeline().tokenizer().tokenize(prompt)? 79 | }; 80 | 81 | let max_tokens = if let Some(max_toks) = request.max_tokens { 82 | max_toks 83 | } else { 84 | data.pipeline_config.max_model_len - token_ids.len() 85 | }; 86 | 87 | if token_ids.len() + max_tokens > data.pipeline_config.max_model_len { 88 | Err(APIError::new(format!( 89 | "This model's maximum context length is {} tokens. \ 90 | However, you requested {} tokens ({} in the messages, \ 91 | {} in the completion). 
Please reduce the length of the \ 92 | messages or completion.", 93 | data.pipeline_config.max_model_len, 94 | max_tokens + token_ids.len(), 95 | token_ids.len(), 96 | max_tokens 97 | ))) 98 | } else { 99 | Ok(token_ids) 100 | } 101 | } 102 | 103 | #[post("/v1/chat/completions")] 104 | async fn chat_completions( 105 | data: web::Data>, 106 | request: web::Json, 107 | ) -> Either, APIError>, HttpResponse> { 108 | let model_name = &request.model; 109 | let res = verify_model(&data, model_name); 110 | if res.is_err() { 111 | return Either::Left(Err(res.err().unwrap())); 112 | } 113 | 114 | if request.logit_bias.as_ref().is_some() 115 | && request.logit_bias.as_ref().is_some_and(|x| !x.is_empty()) 116 | { 117 | return Either::Left(Err(APIError::new_str( 118 | "`logit_bias` is not currently supported.", 119 | ))); 120 | } 121 | 122 | let prompt = get_gen_prompt(&data, &request).await; 123 | if prompt.is_err() { 124 | return Either::Left(Err(prompt.err().unwrap())); 125 | } 126 | let prompt = prompt.unwrap(); 127 | 128 | let token_ids = check_length(&request, prompt, &data); 129 | if token_ids.is_err() { 130 | return Either::Left(Err(token_ids.err().unwrap())); 131 | } 132 | let token_ids = token_ids.unwrap(); 133 | 134 | let request_id = format!("cmpl-{}", Uuid::new_v4()); 135 | 136 | let sampling_params = SamplingParams::new( 137 | request.n.unwrap_or(1), 138 | request.best_of, 139 | request.presence_penalty.unwrap_or(0.0), 140 | request.frequency_penalty.unwrap_or(0.0), 141 | 1.0, 142 | request.temperature.unwrap_or(0.7), 143 | request.top_p.unwrap_or(1.0), 144 | request.top_k.unwrap_or(-1), 145 | request.use_beam_search.unwrap_or(false), 146 | 1.0, 147 | EarlyStoppingCondition::UnlikelyBetterCandidates, 148 | request.stop.clone(), 149 | request.stop_token_ids.clone().unwrap_or_default(), 150 | request.ignore_eos.unwrap_or(false), 151 | request.max_tokens.unwrap_or(16), 152 | None, 153 | None, 154 | request.skip_special_tokens.unwrap_or(true), 155 | ); 156 | if sampling_params.is_err() { 157 | return Either::Left(Err(sampling_params.err().unwrap())); 158 | } 159 | let sampling_params = sampling_params.unwrap(); 160 | 161 | let created = get_created_time_secs(); 162 | 163 | if request.stream.is_some_and(|x| x) { 164 | let (sender, receiver) = new_streaming_conn(); 165 | let _ = thread::spawn(move || { 166 | let mut model = data.model.lock().unwrap(); 167 | let model_res = model.generate(token_ids, request_id, created, sampling_params); 168 | if model_res.is_err() { 169 | let runtime = tokio::runtime::Builder::new_current_thread() 170 | .enable_all() 171 | .build() 172 | .unwrap(); 173 | 174 | // Ignore sending errors 175 | let _ = runtime.block_on(sender.send(Ok(Bytes::from( 176 | serde_json::to_vec(&model_res.err().unwrap()).unwrap(), 177 | )))); 178 | } 179 | }); 180 | 181 | return Either::Right( 182 | HttpResponse::Ok() 183 | .append_header(("content-type", "text/event-stream")) 184 | //.no_chunking(asdf) 185 | .streaming(receiver), 186 | ); 187 | } 188 | 189 | let result = { 190 | let mut model = data.model.lock().unwrap(); 191 | let model_res = model.generate(token_ids, request_id.clone(), created, sampling_params); 192 | if model_res.is_err() { 193 | return Either::Left(Err(model_res.err().unwrap())); 194 | } 195 | model_res.unwrap() 196 | }; 197 | 198 | let choices = result 199 | .iter() 200 | .flat_map(|(choices, _)| choices.clone()) 201 | .collect::>(); 202 | let usage = ChatCompletionUsageResponse { 203 | completion_tokens: result 204 | .iter() 205 | .map(|(_, usage)| 
usage.completion_tokens) 206 | .sum(), 207 | prompt_tokens: result.iter().map(|(_, usage)| usage.prompt_tokens).sum(), 208 | total_tokens: result.iter().map(|(_, usage)| usage.total_tokens).sum(), 209 | }; 210 | 211 | Either::Left(Ok(web::Json(ChatCompletionResponse { 212 | id: request_id, 213 | choices, 214 | created, 215 | model: request.model.clone(), 216 | object: "chat.completion", 217 | usage, 218 | }))) 219 | } 220 | -------------------------------------------------------------------------------- /src/scheduler/cache_engine.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::HashMap, 3 | sync::{Arc, Mutex, MutexGuard}, 4 | }; 5 | 6 | use candle_core::{DType, Device, Tensor}; 7 | 8 | use crate::{ 9 | backend::{copy_blocks, swap_blocks}, 10 | openai::{models::ConfigLike, responses::APIError}, 11 | try_api, 12 | }; 13 | 14 | #[derive(Clone)] 15 | pub struct CacheConfig { 16 | pub block_size: usize, 17 | pub num_gpu_blocks: Option, // Set after profiling init 18 | pub num_cpu_blocks: Option, // Set after profiling init 19 | pub fully_init: bool, 20 | } 21 | 22 | impl CacheConfig { 23 | pub fn set_num_gpu_blocks(&mut self, num_gpu_blocks: usize) { 24 | if self.num_cpu_blocks.is_some() { 25 | self.fully_init = true; 26 | } 27 | self.num_gpu_blocks = Some(num_gpu_blocks); 28 | } 29 | pub fn set_num_cpu_blocks(&mut self, num_cpu_blocks: usize) { 30 | if self.num_gpu_blocks.is_some() { 31 | self.fully_init = true; 32 | } 33 | self.num_cpu_blocks = Some(num_cpu_blocks); 34 | } 35 | } 36 | 37 | pub type KVCache = (Tensor, Tensor); 38 | 39 | pub struct CacheEngine { 40 | gpu_cache: Arc>>, 41 | cpu_cache: Vec, 42 | num_layers: usize, 43 | } 44 | 45 | impl CacheEngine { 46 | pub fn new( 47 | model_config: Box, 48 | cache_config: CacheConfig, 49 | dtype: DType, 50 | ) -> Result { 51 | Ok(Self { 52 | gpu_cache: Arc::new(Mutex::new(Self::allocate_gpu_cache( 53 | &*model_config, 54 | &cache_config, 55 | dtype, 56 | )?)), 57 | cpu_cache: Self::allocate_cpu_cache(&*model_config, &cache_config, dtype)?, 58 | num_layers: model_config.get_num_hidden_layers(), 59 | }) 60 | } 61 | 62 | pub fn get_kv_cache(&self) -> MutexGuard<'_, Vec> { 63 | loop { 64 | if let Ok(v) = self.gpu_cache.try_lock() { 65 | return v; 66 | } 67 | } 68 | } 69 | 70 | fn allocate_gpu_cache( 71 | model_config: &dyn ConfigLike, 72 | cache_config: &CacheConfig, 73 | dtype: DType, 74 | ) -> Result, APIError> { 75 | assert!(cache_config.fully_init); 76 | 77 | let key_block_shape = 78 | Self::calculate_key_block_shape(model_config, dtype, cache_config.block_size); 79 | let value_block_shape = 80 | Self::calculate_value_block_shape(model_config, cache_config.block_size); 81 | let mut gpu_cache = Vec::new(); 82 | for _ in 0..model_config.get_num_hidden_layers() { 83 | let cuda_device = try_api!(Device::new_cuda(0)); 84 | let key_blocks = try_api!(Tensor::zeros( 85 | ( 86 | cache_config.num_gpu_blocks.unwrap(), 87 | key_block_shape.0, 88 | key_block_shape.1, 89 | key_block_shape.2, 90 | key_block_shape.3, 91 | ), 92 | dtype, 93 | &cuda_device, 94 | )); 95 | let value_blocks = try_api!(Tensor::zeros( 96 | ( 97 | cache_config.num_gpu_blocks.unwrap(), 98 | value_block_shape.0, 99 | value_block_shape.1, 100 | value_block_shape.2, 101 | ), 102 | dtype, 103 | &cuda_device, 104 | )); 105 | gpu_cache.push((key_blocks, value_blocks)); 106 | } 107 | Ok(gpu_cache) 108 | } 109 | 110 | fn allocate_cpu_cache( 111 | model_config: &dyn ConfigLike, 112 | cache_config: &CacheConfig, 113 | dtype: 
DType, 114 | ) -> Result, APIError> { 115 | assert!(cache_config.fully_init); 116 | 117 | let key_block_shape = 118 | Self::calculate_key_block_shape(model_config, dtype, cache_config.block_size); 119 | let value_block_shape = 120 | Self::calculate_value_block_shape(model_config, cache_config.block_size); 121 | let mut cpu_cache = Vec::new(); 122 | for _ in 0..model_config.get_num_hidden_layers() { 123 | let cuda_device = try_api!(Device::new_cuda(0)); 124 | let key_blocks = try_api!(Tensor::zeros( 125 | ( 126 | cache_config.num_cpu_blocks.unwrap(), 127 | key_block_shape.0, 128 | key_block_shape.1, 129 | key_block_shape.2, 130 | key_block_shape.3, 131 | ), 132 | dtype, 133 | &cuda_device, 134 | )); 135 | let value_blocks = try_api!(Tensor::zeros( 136 | ( 137 | cache_config.num_cpu_blocks.unwrap(), 138 | value_block_shape.0, 139 | value_block_shape.1, 140 | value_block_shape.2, 141 | ), 142 | dtype, 143 | &cuda_device, 144 | )); 145 | cpu_cache.push((key_blocks, value_blocks)); 146 | } 147 | Ok(cpu_cache) 148 | } 149 | } 150 | 151 | impl CacheEngine { 152 | fn calculate_key_block_shape( 153 | model_config: &dyn ConfigLike, 154 | dtype: DType, 155 | block_size: usize, 156 | ) -> (usize, usize, usize, usize) { 157 | let element_size = dtype.size_in_bytes(); 158 | let x = 16 / element_size; 159 | ( 160 | model_config.get_num_kv_heads(), 161 | model_config.get_head_size() / x, 162 | block_size, 163 | x, 164 | ) 165 | } 166 | 167 | fn calculate_value_block_shape( 168 | model_config: &dyn ConfigLike, 169 | block_size: usize, 170 | ) -> (usize, usize, usize) { 171 | ( 172 | model_config.get_num_kv_heads(), 173 | model_config.get_head_size(), 174 | block_size, 175 | ) 176 | } 177 | } 178 | 179 | impl CacheEngine { 180 | pub fn swap_in(&self, src_to_dst: HashMap) -> Result<(), APIError> { 181 | for i in 0..self.num_layers { 182 | let (src_key_cache, src_value_cache) = self.cpu_cache.get(i).unwrap(); 183 | let mut gpu_cache = self.get_kv_cache(); 184 | let (dst_key_cache, dst_value_cache) = gpu_cache.get_mut(i).unwrap(); 185 | // Swap (copy) key blocks 186 | try_api!(swap_blocks( 187 | src_key_cache.clone(), 188 | dst_key_cache, 189 | src_to_dst.clone() 190 | )); 191 | // Swap (copy) key blocks 192 | try_api!(swap_blocks( 193 | src_value_cache.clone(), 194 | dst_value_cache, 195 | src_to_dst.clone() 196 | )); 197 | } 198 | Ok(()) 199 | } 200 | 201 | pub fn swap_out(&mut self, src_to_dst: HashMap) -> Result<(), APIError> { 202 | for i in 0..self.num_layers { 203 | let gpu_cache = self.get_kv_cache(); 204 | let (src_key_cache, src_value_cache) = gpu_cache.get(i).unwrap().clone(); 205 | drop(gpu_cache); 206 | 207 | let (dst_key_cache, dst_value_cache) = self.cpu_cache.get_mut(i).unwrap(); 208 | // Swap (copy) key blocks 209 | try_api!(swap_blocks( 210 | src_key_cache.clone(), 211 | dst_key_cache, 212 | src_to_dst.clone() 213 | )); 214 | // Swap (copy) key blocks 215 | try_api!(swap_blocks( 216 | src_value_cache.clone(), 217 | dst_value_cache, 218 | src_to_dst.clone() 219 | )); 220 | } 221 | Ok(()) 222 | } 223 | 224 | pub fn copy(&mut self, src_to_dst: HashMap>) -> Result<(), APIError> { 225 | let mut gpu_cache = self.get_kv_cache(); 226 | #[allow(clippy::map_identity)] 227 | let caches: (Vec<&mut Tensor>, Vec<&mut Tensor>) = 228 | gpu_cache.iter_mut().map(|(a, b)| (a, b)).unzip(); 229 | let (key_caches, value_caches) = caches; 230 | 231 | // NOTE(EricLBuehler): This may synchronize the CPU and GPU 232 | try_api!(unsafe { copy_blocks(key_caches, value_caches, src_to_dst) }); 233 | 234 | Ok(()) 235 | } 
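// Together these methods are the block-level cache-management hooks: swap_in/swap_out move
// whole KV blocks between the CPU and GPU caches (for example when a sequence is preempted
// and later resumed), while copy duplicates blocks inside the GPU cache when forked
// sequences stop sharing a block. A hypothetical call, with made-up block indices and a
// made-up `cache_engine` binding (a CacheEngine):
//
//     let mut mapping = HashMap::new();
//     mapping.insert(0, vec![3, 4]); // copy physical block 0 into blocks 3 and 4
//     cache_engine.copy(mapping)?;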
236 | } 237 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import re 4 | import subprocess 5 | from typing import List, Set 6 | import warnings 7 | 8 | from packaging.version import parse, Version 9 | import setuptools 10 | import torch 11 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME 12 | 13 | ROOT_DIR = os.path.dirname(__file__) 14 | 15 | MAIN_CUDA_VERSION = "12.1" 16 | 17 | # Supported NVIDIA GPU architectures. 18 | SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"} 19 | 20 | # Compiler flags. 21 | CXX_FLAGS = ["-g", "-O2", "-std=c++17"] 22 | # TODO(EricLBuehler): Should we use -O3? 23 | NVCC_FLAGS = ["-O2", "-std=c++17"] 24 | 25 | ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0 26 | CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] 27 | NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] 28 | 29 | if CUDA_HOME is None: 30 | raise RuntimeError( 31 | "Cannot find CUDA_HOME. CUDA must be available to build the package.") 32 | 33 | 34 | def get_nvcc_cuda_version(cuda_dir: str) -> Version: 35 | """Get the CUDA version from nvcc. 36 | 37 | Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py 38 | """ 39 | nvcc_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], 40 | universal_newlines=True) 41 | output = nvcc_output.split() 42 | release_idx = output.index("release") + 1 43 | nvcc_cuda_version = parse(output[release_idx].split(",")[0]) 44 | return nvcc_cuda_version 45 | 46 | 47 | def get_torch_arch_list() -> Set[str]: 48 | # TORCH_CUDA_ARCH_LIST can have one or more architectures, 49 | # e.g. "8.0" or "7.5,8.0,8.6+PTX". Here, the "8.6+PTX" option asks the 50 | # compiler to additionally include PTX code that can be runtime-compiled 51 | # and executed on the 8.6 or newer architectures. While the PTX code will 52 | # not give the best performance on the newer architectures, it provides 53 | # forward compatibility. 54 | env_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None) 55 | if env_arch_list is None: 56 | return set() 57 | 58 | # List are separated by ; or space. 59 | torch_arch_list = set(env_arch_list.replace(" ", ";").split(";")) 60 | if not torch_arch_list: 61 | return set() 62 | 63 | # Filter out the invalid architectures and print a warning. 64 | valid_archs = SUPPORTED_ARCHS.union({s + "+PTX" for s in SUPPORTED_ARCHS}) 65 | arch_list = torch_arch_list.intersection(valid_archs) 66 | # If none of the specified architectures are valid, raise an error. 67 | if not arch_list: 68 | raise RuntimeError( 69 | "None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env " 70 | f"variable ({env_arch_list}) is supported. " 71 | f"Supported CUDA architectures are: {valid_archs}.") 72 | invalid_arch_list = torch_arch_list - valid_archs 73 | if invalid_arch_list: 74 | warnings.warn( 75 | f"Unsupported CUDA architectures ({invalid_arch_list}) are " 76 | "excluded from the `TORCH_CUDA_ARCH_LIST` env variable " 77 | f"({env_arch_list}). Supported CUDA architectures are: " 78 | f"{valid_archs}.", 79 | stacklevel=2) 80 | return arch_list 81 | 82 | 83 | # First, check the TORCH_CUDA_ARCH_LIST environment variable. 84 | compute_capabilities = get_torch_arch_list() 85 | if not compute_capabilities: 86 | # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available 87 | # GPUs on the current machine. 
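# For example, a machine with a single RTX 3090 (compute capability 8.6) ends up with
# {"8.6"} here, which the loop further down turns into
# "-gencode arch=compute_86,code=sm_86".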
88 | device_count = torch.cuda.device_count() 89 | for i in range(device_count): 90 | major, minor = torch.cuda.get_device_capability(i) 91 | if major < 7: 92 | raise RuntimeError( 93 | "GPUs with compute capability below 7.0 are not supported.") 94 | compute_capabilities.add(f"{major}.{minor}") 95 | 96 | nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME) 97 | if not compute_capabilities: 98 | # If no GPU is specified nor available, add all supported architectures 99 | # based on the NVCC CUDA version. 100 | compute_capabilities = SUPPORTED_ARCHS.copy() 101 | if nvcc_cuda_version < Version("11.1"): 102 | compute_capabilities.remove("8.6") 103 | if nvcc_cuda_version < Version("11.8"): 104 | compute_capabilities.remove("8.9") 105 | compute_capabilities.remove("9.0") 106 | 107 | # Validate the NVCC CUDA version. 108 | if nvcc_cuda_version < Version("11.0"): 109 | raise RuntimeError("CUDA 11.0 or higher is required to build the package.") 110 | if (nvcc_cuda_version < Version("11.1") 111 | and any(cc.startswith("8.6") for cc in compute_capabilities)): 112 | raise RuntimeError( 113 | "CUDA 11.1 or higher is required for compute capability 8.6.") 114 | if nvcc_cuda_version < Version("11.8"): 115 | if any(cc.startswith("8.9") for cc in compute_capabilities): 116 | # CUDA 11.8 is required to generate the code targeting compute capability 8.9. 117 | # However, GPUs with compute capability 8.9 can also run the code generated by 118 | # the previous versions of CUDA 11 and targeting compute capability 8.0. 119 | # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0 120 | # instead of 8.9. 121 | warnings.warn( 122 | "CUDA 11.8 or higher is required for compute capability 8.9. " 123 | "Targeting compute capability 8.0 instead.", 124 | stacklevel=2) 125 | compute_capabilities = set(cc for cc in compute_capabilities 126 | if not cc.startswith("8.9")) 127 | compute_capabilities.add("8.0+PTX") 128 | if any(cc.startswith("9.0") for cc in compute_capabilities): 129 | raise RuntimeError( 130 | "CUDA 11.8 or higher is required for compute capability 9.0.") 131 | 132 | # Add target compute capabilities to NVCC flags. 133 | for capability in compute_capabilities: 134 | num = capability[0] + capability[2] 135 | NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"] 136 | if capability.endswith("+PTX"): 137 | NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"] 138 | 139 | # Use NVCC threads to parallelize the build. 140 | if nvcc_cuda_version >= Version("11.2"): 141 | num_threads = min(os.cpu_count(), 8) 142 | NVCC_FLAGS += ["--threads", str(num_threads)] 143 | 144 | ext_modules = [] 145 | vllm_extension = CUDAExtension( 146 | name="candle_vllm._C", 147 | sources=[ 148 | "csrc/cache_kernels.cu", 149 | "csrc/attention/attention_kernels.cu", 150 | "csrc/pos_encoding_kernels.cu", 151 | #"csrc/activation_kernels.cu", 152 | #"csrc/layernorm_kernels.cu", 153 | #"csrc/quantization/awq/gemm_kernels.cu", 154 | #"csrc/quantization/squeezellm/quant_cuda_kernel.cu", 155 | #"csrc/cuda_utils_kernels.cu", 156 | 157 | #"csrc/ops.h", 158 | #"csrc/cache.h", 159 | 160 | "csrc/rustbind.cpp" 161 | ], 162 | extra_compile_args={ 163 | "cxx": CXX_FLAGS, 164 | "nvcc": NVCC_FLAGS, 165 | }, 166 | ) 167 | ext_modules.append(vllm_extension) 168 | 169 | 170 | def get_path(*filepath) -> str: 171 | return os.path.join(ROOT_DIR, *filepath) 172 | 173 | 174 | def find_version(filepath: str) -> str: 175 | """Extract version information from the given filepath. 
176 | 177 | Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py 178 | """ 179 | with open(filepath) as fp: 180 | version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", 181 | fp.read(), re.M) 182 | if version_match: 183 | return version_match.group(1) 184 | raise RuntimeError("Unable to find version string.") 185 | 186 | 187 | def get_candle_vllm_version() -> str: 188 | return "0.1.0" 189 | 190 | 191 | def read_readme() -> str: 192 | return "" 193 | 194 | 195 | def get_requirements() -> List[str]: 196 | return "" 197 | 198 | 199 | setuptools.setup( 200 | name="candle-vllm", 201 | version=get_candle_vllm_version(), 202 | author="Eric Buehler", 203 | license="MIT LICENSE", 204 | description=None, 205 | long_description=read_readme(), 206 | long_description_content_type="text/markdown", 207 | url="https://github.com/EricLBuehler/candle-vllm/", 208 | project_urls={ 209 | "Homepage": "https://github.com/EricLBuehler/candle-vllm/", 210 | "Documentation": None, 211 | }, 212 | classifiers=[], 213 | packages=setuptools.find_packages(exclude=("benchmarks", "csrc", "docs", 214 | "examples", "tests")), 215 | python_requires=">=3.8", 216 | install_requires=get_requirements(), 217 | ext_modules=ext_modules, 218 | cmdclass={"build_ext": BuildExtension}, 219 | ) 220 | -------------------------------------------------------------------------------- /src/scheduler/sequence.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::HashMap, 3 | sync::{Arc, Mutex, MutexGuard}, 4 | }; 5 | 6 | use candle_sampling::logits_processor::Logprobs; 7 | 8 | use super::block_engine::LogicalTokenBlock; 9 | 10 | #[derive(Clone)] 11 | pub enum SequenceStatus { 12 | FinishedIgnored, 13 | Waiting, 14 | Running, 15 | Swapped, 16 | FinishedAborted, 17 | Finished(String), 18 | } 19 | 20 | pub struct SequenceData { 21 | prompt_token_ids: Vec, 22 | output_token_ids: Vec, 23 | cumulative_logprob: f32, 24 | status: SequenceStatus, 25 | } 26 | 27 | impl SequenceData { 28 | pub fn new(prompt_token_ids: Vec) -> Self { 29 | Self { 30 | prompt_token_ids, 31 | output_token_ids: Vec::new(), 32 | cumulative_logprob: 0., 33 | status: SequenceStatus::Waiting, 34 | } 35 | } 36 | 37 | pub fn append_token_id(&mut self, logprobs: Logprobs) { 38 | self.cumulative_logprob += logprobs.logprob; 39 | self.output_token_ids.push(logprobs); 40 | } 41 | 42 | pub fn set_status(&mut self, status: SequenceStatus) { 43 | self.status = status; 44 | } 45 | 46 | fn get_cumulative_logprob(&self) -> f32 { 47 | self.cumulative_logprob 48 | } 49 | } 50 | 51 | /// A Sequence holds information about the data it contains (the tokens), and the logical token blocks 52 | /// to which it is mapped. 
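/// A logical block groups up to `block_size` consecutive token ids; the scheduler's block
/// engine maps each logical block onto a physical block of the paged KV cache.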
53 | pub struct _Sequence { 54 | data: Mutex, 55 | seq_id: usize, 56 | logical_token_blocks: Vec, 57 | block_size: usize, 58 | } 59 | 60 | impl _Sequence { 61 | pub fn new(prompt_token_ids: Vec, seq_id: usize, block_size: usize) -> Self { 62 | let mut this = Self { 63 | data: Mutex::new(SequenceData::new(prompt_token_ids.clone())), 64 | seq_id, 65 | logical_token_blocks: Vec::new(), 66 | block_size, 67 | }; 68 | this.append_tokens_to_blocks(prompt_token_ids); 69 | this 70 | } 71 | 72 | pub fn add_token(&mut self, logprobs: Logprobs) { 73 | self.append_token_to_blocks(logprobs.token); 74 | self.deref_mut().append_token_id(logprobs); 75 | } 76 | 77 | pub fn blocks_to_add_new_tok(&mut self) -> usize { 78 | let last = self.logical_token_blocks.last_mut(); 79 | if !last.is_some_and(|last| last.is_full()) { 80 | // If we have space 81 | 0 82 | } else { 83 | 1 84 | } 85 | } 86 | 87 | pub fn get_logical_token_blocks(&self) -> usize { 88 | self.logical_token_blocks.len() 89 | } 90 | 91 | pub fn get_id(&self) -> usize { 92 | self.seq_id 93 | } 94 | 95 | pub fn is_prompt(&self) -> bool { 96 | self.deref().output_token_ids.is_empty() 97 | } 98 | 99 | pub fn get_prompt_len(&self) -> usize { 100 | self.deref().prompt_token_ids.len() 101 | } 102 | 103 | pub fn get_len(&self) -> usize { 104 | self.deref().prompt_token_ids.len() + self.deref().output_token_ids.len() 105 | } 106 | 107 | pub fn get_token_ids(&self) -> Vec { 108 | let mut res = self.deref().prompt_token_ids.clone(); 109 | res.extend( 110 | self.deref() 111 | .output_token_ids 112 | .iter() 113 | .map(|logprobs| logprobs.token) 114 | .clone(), 115 | ); 116 | res 117 | } 118 | 119 | pub fn get_last_token_id(&self) -> usize { 120 | if self.deref().output_token_ids.is_empty() { 121 | *self.deref().prompt_token_ids.last().unwrap() 122 | } else { 123 | self.deref().output_token_ids.last().unwrap().token 124 | } 125 | } 126 | 127 | pub fn is_finished(&self) -> bool { 128 | matches!( 129 | self.deref().status, 130 | SequenceStatus::FinishedAborted 131 | | SequenceStatus::FinishedIgnored 132 | | SequenceStatus::Finished(_) 133 | ) 134 | } 135 | 136 | pub fn get_cumulative_logprob(&self) -> f32 { 137 | self.deref().get_cumulative_logprob() 138 | } 139 | 140 | pub fn set_finish_reason(&mut self, finish_reason: String) { 141 | self.deref() 142 | .set_status(SequenceStatus::Finished(finish_reason.clone())); 143 | } 144 | 145 | pub fn get_finish_reason(&self) -> String { 146 | match &self.deref().status { 147 | SequenceStatus::Finished(state) => state.clone(), 148 | SequenceStatus::FinishedAborted => "abort".to_string(), 149 | SequenceStatus::FinishedIgnored => "length".to_string(), 150 | _ => { 151 | unreachable!("No finish reason.") 152 | } 153 | } 154 | } 155 | 156 | #[must_use] 157 | /// Clones the internal logprobs. 158 | pub fn get_output_tokens(&self) -> Vec { 159 | self.deref().output_token_ids.clone() // TODO(EricLBuehler): Better way to do this? 
160 | } 161 | 162 | fn append_tokens_to_blocks(&mut self, tokens: Vec) { 163 | for tok in tokens { 164 | self.append_token_to_blocks(tok); 165 | } 166 | } 167 | 168 | fn append_token_to_blocks(&mut self, token: usize) { 169 | let last = self.logical_token_blocks.last_mut(); 170 | if !last.as_ref().is_some_and(|last| last.is_full()) { 171 | // If we have space 172 | let last = last.unwrap(); 173 | last.append_token_id(token); 174 | } else { 175 | self.logical_token_blocks 176 | .push(LogicalTokenBlock::new(self.block_size)); 177 | self.logical_token_blocks 178 | .last_mut() 179 | .unwrap() 180 | .append_token_id(token); 181 | } 182 | } 183 | } 184 | 185 | impl _Sequence { 186 | pub fn deref(&self) -> MutexGuard<'_, SequenceData> { 187 | loop { 188 | if let Ok(res) = self.data.try_lock() { 189 | return res; 190 | } 191 | } 192 | } 193 | 194 | pub fn deref_mut(&self) -> MutexGuard<'_, SequenceData> { 195 | loop { 196 | if let Ok(res) = self.data.try_lock() { 197 | return res; 198 | } 199 | } 200 | } 201 | } 202 | 203 | pub struct Sequence(pub Mutex<_Sequence>); 204 | 205 | impl Sequence { 206 | pub fn deref_mut(&self) -> MutexGuard<'_, _Sequence> { 207 | loop { 208 | if let Ok(v) = self.0.try_lock() { 209 | return v; 210 | } 211 | } 212 | } 213 | } 214 | 215 | type SeqID = usize; 216 | 217 | /// A SequenceGroup holds the `n` (see SamplingParams) sequences generated from a single prompt. 218 | /// A SequenceGroup contains only sequences with the same prompt. They will always be scheduled together. 219 | pub struct SequenceGroup { 220 | seqs: HashMap>, 221 | arrival_time: u64, 222 | group_id: usize, 223 | request_id: String, 224 | created: u64, 225 | } 226 | 227 | impl SequenceGroup { 228 | pub fn new( 229 | seqs: &[Arc], 230 | arrival_time: u64, 231 | group_id: usize, 232 | request_id: String, 233 | created: u64, 234 | ) -> Self { 235 | let mut seq_map = HashMap::new(); 236 | for seq in seqs { 237 | seq_map.insert(seq.deref_mut().get_id(), seq.clone()); 238 | } 239 | Self { 240 | seqs: seq_map, 241 | arrival_time, 242 | group_id, 243 | request_id, 244 | created, 245 | } 246 | } 247 | 248 | pub fn set_status(&self, status: SequenceStatus) { 249 | for seq in self.seqs.values() { 250 | seq.deref_mut().deref().set_status(status.clone()); 251 | } 252 | } 253 | 254 | /// Blocks to add one new token to each sequence 255 | pub fn total_blocks_to_add_new_tok(&self) -> usize { 256 | self.seqs 257 | .values() 258 | .map(|seq| seq.deref_mut().blocks_to_add_new_tok()) 259 | .sum() 260 | } 261 | 262 | pub fn get_prompt_len(&self) -> usize { 263 | self.seqs.len() 264 | } 265 | 266 | pub fn get_total_logical_token_blocks(&self) -> usize { 267 | self.seqs 268 | .values() 269 | .map(|seq| seq.deref_mut().get_logical_token_blocks()) 270 | .sum() 271 | } 272 | 273 | pub fn get_seqs(&self) -> &HashMap> { 274 | &self.seqs 275 | } 276 | 277 | pub fn arrival_time(&self) -> u64 { 278 | self.arrival_time 279 | } 280 | 281 | pub fn get_id(&self) -> &usize { 282 | &self.group_id 283 | } 284 | 285 | pub fn is_finished(&self) -> bool { 286 | self.seqs.iter().all(|(_, x)| x.deref_mut().is_finished()) 287 | } 288 | 289 | pub fn get_request_id(&self) -> &String { 290 | &self.request_id 291 | } 292 | 293 | pub fn get_created_time(&self) -> u64 { 294 | self.created 295 | } 296 | } 297 | -------------------------------------------------------------------------------- /tests/test_flan-t5-quantized.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "mkl")] 2 | extern crate 
intel_mkl_src; 3 | 4 | #[cfg(feature = "accelerate")] 5 | extern crate accelerate_src; 6 | use std::io::Write; 7 | use std::path::PathBuf; 8 | 9 | use actix_web::{post, web, App, HttpResponse, HttpServer, Responder}; 10 | use serde::{Deserialize, Serialize}; 11 | 12 | use candle_transformers::models::quantized_t5 as t5; 13 | 14 | use anyhow::{Error as E, Result}; 15 | use candle_core::{Device, Tensor}; 16 | use candle_transformers::generation::LogitsProcessor; 17 | use clap::{Parser, ValueEnum}; 18 | use hf_hub::{api::sync::Api, api::sync::ApiRepo, Repo, RepoType}; 19 | use tokenizers::Tokenizer; 20 | 21 | #[derive(Clone, Debug, Copy, ValueEnum)] 22 | enum Which { 23 | T5Small, 24 | FlanT5Small, 25 | FlanT5Base, 26 | FlanT5Large, 27 | FlanT5Xl, 28 | FlanT5Xxl, 29 | } 30 | 31 | #[derive(Parser, Debug, Clone)] 32 | #[command(author, version, about, long_about = None)] 33 | 34 | struct Args { 35 | /// Enable tracing (generates a trace-timestamp.json file). 36 | #[arg(long)] 37 | tracing: bool, 38 | 39 | /// The model repository to use on the HuggingFace hub. 40 | #[arg(long)] 41 | model_id: Option, 42 | 43 | #[arg(long)] 44 | revision: Option, 45 | 46 | #[arg(long)] 47 | weight_file: Option, 48 | 49 | #[arg(long)] 50 | config_file: Option, 51 | 52 | // Enable/disable decoding. 53 | #[arg(long, default_value = "false")] 54 | disable_cache: bool, 55 | 56 | /// Use this prompt, otherwise compute sentence similarities. 57 | // #[arg(long)] 58 | // prompt: Option, 59 | 60 | /// The temperature used to generate samples. 61 | #[arg(long, default_value_t = 0.8)] 62 | temperature: f64, 63 | 64 | /// Nucleus sampling probability cutoff. 65 | #[arg(long)] 66 | top_p: Option, 67 | 68 | /// Penalty to be applied for repeating tokens, 1. means no penalty. 69 | #[arg(long, default_value_t = 1.1)] 70 | repeat_penalty: f32, 71 | 72 | /// The context size to consider for the repeat penalty. 73 | #[arg(long, default_value_t = 64)] 74 | repeat_last_n: usize, 75 | 76 | /// The model size to use. 
77 | #[arg(long, default_value = "flan-t5-large")] 78 | which: Which, 79 | } 80 | 81 | struct T5ModelBuilder { 82 | device: Device, 83 | config: t5::Config, 84 | weights_filename: PathBuf, 85 | } 86 | 87 | impl T5ModelBuilder { 88 | pub fn load(args: &Args) -> Result<(Self, Tokenizer)> { 89 | let device = Device::Cpu; 90 | let default_model = "lmz/candle-quantized-t5".to_string(); 91 | let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) { 92 | (Some(model_id), Some(revision)) => (model_id, revision), 93 | (Some(model_id), None) => (model_id, "main".to_string()), 94 | (None, Some(revision)) => (default_model, revision), 95 | (None, None) => (default_model, "main".to_string()), 96 | }; 97 | 98 | let repo = Repo::with_revision(model_id, RepoType::Model, revision); 99 | let api = Api::new()?; 100 | let api = api.repo(repo); 101 | let config_filename = match &args.config_file { 102 | Some(filename) => Self::get_local_or_remote_file(filename, &api)?, 103 | None => match args.which { 104 | Which::T5Small => api.get("config.json")?, 105 | Which::FlanT5Small => api.get("config-flan-t5-small.json")?, 106 | Which::FlanT5Base => api.get("config-flan-t5-base.json")?, 107 | Which::FlanT5Large => api.get("config-flan-t5-large.json")?, 108 | Which::FlanT5Xl => api.get("config-flan-t5-xl.json")?, 109 | Which::FlanT5Xxl => api.get("config-flan-t5-xxl.json")?, 110 | }, 111 | }; 112 | let tokenizer_filename = api.get("tokenizer.json")?; 113 | let weights_filename = match &args.weight_file { 114 | Some(filename) => Self::get_local_or_remote_file(filename, &api)?, 115 | None => match args.which { 116 | Which::T5Small => api.get("model.gguf")?, 117 | Which::FlanT5Small => api.get("model-flan-t5-small.gguf")?, 118 | Which::FlanT5Base => api.get("model-flan-t5-base.gguf")?, 119 | Which::FlanT5Large => api.get("model-flan-t5-large.gguf")?, 120 | Which::FlanT5Xl => api.get("model-flan-t5-xl.gguf")?, 121 | Which::FlanT5Xxl => api.get("model-flan-t5-xxl.gguf")?, 122 | }, 123 | }; 124 | 125 | let config = std::fs::read_to_string(config_filename)?; 126 | let mut config: t5::Config = serde_json::from_str(&config)?; 127 | config.use_cache = !args.disable_cache; 128 | let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; 129 | Ok(( 130 | Self { 131 | device, 132 | config, 133 | weights_filename, 134 | }, 135 | tokenizer, 136 | )) 137 | } 138 | 139 | pub fn build_model(&self) -> Result { 140 | let vb = t5::VarBuilder::from_gguf(&self.weights_filename, &self.device)?; 141 | Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?) 142 | } 143 | 144 | fn get_local_or_remote_file(filename: &str, api: &ApiRepo) -> Result { 145 | let local_filename = std::path::PathBuf::from(filename); 146 | if local_filename.exists() { 147 | Ok(local_filename) 148 | } else { 149 | Ok(api.get(filename)?) 150 | } 151 | } 152 | } 153 | fn generate_answer(_prompt: String, args: &Args) -> Result { 154 | let mut generated_text = String::new(); 155 | 156 | let (_builder, mut _tokenizer) = T5ModelBuilder::load(args)?; 157 | let device = &_builder.device; 158 | let _tokenizer = _tokenizer 159 | .with_padding(None) 160 | .with_truncation(None) 161 | .map_err(E::msg)?; 162 | let _tokens = _tokenizer 163 | .encode(_prompt, true) 164 | .map_err(E::msg)? 
165 | .get_ids() 166 | .to_vec(); 167 | let input_token_ids = Tensor::new(&_tokens[..], device)?.unsqueeze(0)?; 168 | let mut model = _builder.build_model()?; 169 | let mut output_token_ids = [_builder.config.pad_token_id as u32].to_vec(); 170 | let temperature = 0.8f64; 171 | 172 | let mut logits_processor = LogitsProcessor::new(299792458, Some(temperature), None); 173 | let encoder_output = model.encode(&input_token_ids)?; 174 | 175 | let start = std::time::Instant::now(); 176 | 177 | for index in 0.. { 178 | if output_token_ids.len() > 512 { 179 | break; 180 | } 181 | let decoder_token_ids = if index == 0 || !_builder.config.use_cache { 182 | Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)? 183 | } else { 184 | let last_token = *output_token_ids.last().unwrap(); 185 | Tensor::new(&[last_token], device)?.unsqueeze(0)? 186 | }; 187 | let logits = model 188 | .decode(&decoder_token_ids, &encoder_output)? 189 | .squeeze(0)?; 190 | let logits = if args.repeat_penalty == 1. { 191 | logits 192 | } else { 193 | let start_at = output_token_ids.len().saturating_sub(args.repeat_last_n); 194 | candle_transformers::utils::apply_repeat_penalty( 195 | &logits, 196 | args.repeat_penalty, 197 | &output_token_ids[start_at..], 198 | )? 199 | }; 200 | 201 | let next_token_id = logits_processor.sample(&logits)?; 202 | if next_token_id as usize == _builder.config.eos_token_id { 203 | break; 204 | } 205 | output_token_ids.push(next_token_id); 206 | if let Some(text) = _tokenizer.id_to_token(next_token_id) { 207 | let text = text.replace('▁', " ").replace("<0x0A>", "\n"); 208 | generated_text.push_str(&text); 209 | print!("{}", text); 210 | std::io::stdout().flush()?; 211 | } 212 | } 213 | let dt = start.elapsed(); 214 | println!( 215 | "\n{} tokens generated ({:.2} token/s)\n", 216 | output_token_ids.len(), 217 | output_token_ids.len() as f64 / dt.as_secs_f64(), 218 | ); 219 | 220 | Ok(generated_text) 221 | } 222 | 223 | // request struct 224 | #[derive(Deserialize)] 225 | struct Request { 226 | prompt: String, 227 | } 228 | 229 | #[derive(Serialize)] 230 | struct Response { 231 | answer: String, 232 | } 233 | 234 | #[post("/generate")] 235 | async fn generate(req_body: web::Json) -> impl Responder { 236 | let args = Args::parse(); 237 | let generated_answer = generate_answer(req_body.prompt.clone(), &args); 238 | HttpResponse::Ok().json(Response { 239 | answer: generated_answer.unwrap(), 240 | }) 241 | } 242 | 243 | #[actix_web::main] 244 | async fn main() -> std::io::Result<()> { 245 | HttpServer::new(|| App::new().service(generate)) 246 | .bind("localhost:1111")? 
247 | .run() 248 | .await 249 | } 250 | -------------------------------------------------------------------------------- /src/paged_attention/memory_efficient_attention.rs: -------------------------------------------------------------------------------- 1 | use candle_core::{DType, Device, IndexOp, Shape, Tensor, D}; 2 | 3 | use crate::{ 4 | openai::responses::APIError, 5 | paged_attention::attn_bias::{BlockDiagonalCausalMask, LowerTriangularMaskWithTensorBias}, 6 | try_api, 7 | }; 8 | 9 | use super::{input_metadata::InputMetadata, PagedAttention}; 10 | 11 | #[allow(clippy::too_many_arguments)] 12 | pub fn _memory_efficient_attention( 13 | this: &PagedAttention, 14 | query: Tensor, 15 | key: Tensor, 16 | value: Tensor, 17 | input_metadata: &mut InputMetadata, 18 | seq_len: usize, 19 | batch_size: usize, 20 | device: &Device, 21 | dtype: DType, 22 | ) -> Result { 23 | let (query, key, value) = if this.num_key_value_heads != this.num_attention_heads { 24 | let query = try_api!(query.reshape(( 25 | *query.shape().dims().first().unwrap(), 26 | this.num_key_value_heads, 27 | this.num_queries_per_kv, 28 | *query.shape().dims().last().unwrap(), 29 | ))); 30 | 31 | let key = try_api!( 32 | try_api!(try_api!(key.i((.., .., .., ..))).unsqueeze(2)).expand(( 33 | *key.shape().dims().first().unwrap(), 34 | this.num_key_value_heads, 35 | this.num_queries_per_kv, 36 | *key.shape().dims().last().unwrap(), 37 | )) 38 | ); 39 | 40 | let value = try_api!( 41 | try_api!(try_api!(value.i((.., .., .., ..))).unsqueeze(2)).expand(( 42 | *value.shape().dims().first().unwrap(), 43 | this.num_key_value_heads, 44 | this.num_queries_per_kv, 45 | *value.shape().dims().last().unwrap(), 46 | )) 47 | ); 48 | 49 | (query, key, value) 50 | } else { 51 | (query, key, value) 52 | }; 53 | 54 | if input_metadata.attn_bias.is_none() { 55 | if let Some(alibi_slopes) = &this.alibi_slopes { 56 | //make alibi bias 57 | let bias = try_api!(try_api!(Tensor::arange( 58 | 0f64, 59 | TryInto::::try_into(seq_len).unwrap().into(), 60 | device, 61 | )) 62 | .to_dtype(dtype)); 63 | let bias = try_api!(try_api!( 64 | (try_api!(bias.unsqueeze(0)) - try_api!(bias.unsqueeze(1))) 65 | ) 66 | .to_device(alibi_slopes.device())); 67 | 68 | let padded_len = ((seq_len + 7) / 8) * 8; 69 | let mut bias_new = try_api!(try_api!(Tensor::zeros( 70 | ( 71 | batch_size, 72 | alibi_slopes.shape().dims()[0], 73 | seq_len, 74 | padded_len, 75 | ), 76 | dtype, 77 | device, 78 | )) 79 | .i((.., .., .., ..seq_len))); 80 | 81 | bias_new = try_api!(bias_new.slice_assign(&[.., .., .., ..], &bias)); 82 | 83 | bias_new = try_api!(bias_new.mul(&try_api!(try_api!( 84 | try_api!(alibi_slopes.i(..)).unsqueeze(1) 85 | ) 86 | .unsqueeze(2)),)); 87 | let attn_bias = LowerTriangularMaskWithTensorBias::new(bias_new); 88 | input_metadata.attn_bias = Some(Box::new(attn_bias)); 89 | } else { 90 | let mut attn_bias = try_api!(BlockDiagonalCausalMask::from_seqlens( 91 | [seq_len.try_into().unwrap()].repeat(batch_size), 92 | None, 93 | )); 94 | if let Some(sliding_window) = this.sliding_window { 95 | attn_bias = try_api!(attn_bias.make_local_attention(sliding_window)); 96 | } 97 | input_metadata.attn_bias = Some(attn_bias); 98 | } 99 | } 100 | 101 | let (query, key, value) = if this.alibi_slopes.is_none() { 102 | ( 103 | try_api!(query.unsqueeze(0)), 104 | try_api!(key.unsqueeze(0)), 105 | try_api!(value.unsqueeze(0)), 106 | ) 107 | } else { 108 | assert_eq!(query.shape().dims().len(), key.shape().dims().len()); 109 | assert_eq!(value.shape().dims().len(), key.shape().dims().len()); 
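// (ALiBi path: the tokens arrive flattened into a single (batch * seq_len) dimension, so
// the reshapes below restore an explicit batch dimension; the batched per-head ALiBi bias
// built earlier relies on that layout to broadcast against the attention scores.)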
110 | assert!(query.shape().dims().len() == 3 || query.shape().dims().len() == 4); 111 | if query.shape().dims().len() == 3 { 112 | ( 113 | try_api!(query.reshape(( 114 | batch_size, 115 | seq_len, 116 | query.shape().dims()[1], 117 | query.shape().dims()[2], 118 | ))), 119 | try_api!(key.reshape(( 120 | batch_size, 121 | seq_len, 122 | key.shape().dims()[1], 123 | key.shape().dims()[2], 124 | ))), 125 | try_api!(value.reshape(( 126 | batch_size, 127 | seq_len, 128 | value.shape().dims()[1], 129 | value.shape().dims()[2], 130 | ))), 131 | ) 132 | } else { 133 | ( 134 | try_api!(query.reshape(( 135 | batch_size, 136 | seq_len, 137 | query.shape().dims()[1], 138 | query.shape().dims()[2], 139 | query.shape().dims()[3], 140 | ))), 141 | try_api!(key.reshape(( 142 | batch_size, 143 | seq_len, 144 | key.shape().dims()[1], 145 | key.shape().dims()[2], 146 | key.shape().dims()[3], 147 | ))), 148 | try_api!(value.reshape(( 149 | batch_size, 150 | seq_len, 151 | value.shape().dims()[1], 152 | value.shape().dims()[2], 153 | value.shape().dims()[3], 154 | ))), 155 | ) 156 | } 157 | }; 158 | 159 | let l = try_api!(query.dim(D::Minus2)); 160 | let s = try_api!(key.dim(D::Minus2)); 161 | 162 | scaled_dot_product_attention( 163 | &query, 164 | &key, 165 | &value, 166 | &try_api!(input_metadata.attn_bias.as_ref().unwrap().materialize( 167 | &Shape::from_dims(&[l, s]), 168 | query.dtype(), 169 | device 170 | )), 171 | None, 172 | this.scale, 173 | ) 174 | } 175 | 176 | #[cfg(feature = "cuda")] 177 | /// Flash-attention v2 layer. 178 | /// 179 | /// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. 180 | /// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads 181 | /// than q, the number of heads in k and v has to be divisible by the number of heads in q. 182 | /// 183 | /// # Arguments 184 | /// 185 | /// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`. 186 | /// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. 187 | /// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. 188 | /// 189 | /// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`. 190 | pub fn scaled_dot_product_attention( 191 | query: &Tensor, 192 | key: &Tensor, 193 | value: &Tensor, 194 | _attn_bias: &Tensor, 195 | _dropout_p: Option, 196 | scale_factor: f32, 197 | ) -> Result { 198 | candle_flash_attn::flash_attn(query, key, value, scale_factor, false).map_err(APIError::from) 199 | } 200 | 201 | #[cfg(not(feature = "cuda"))] 202 | // https://github.com/mokeyish/candle-ext/blob/main/src/scaled_dot_product_attention.rs 203 | 204 | /// Computes scaled dot product attention on query, key and value tensors, 205 | /// using an optional attention mask if passed, and applying dropout 206 | /// if a probability greater than 0.0 is specified. 207 | /// 208 | /// # Arguments 209 | /// - query - Query tensor; shape (N, ..., L, E) 210 | /// - key - Key tensor; shape (N, ..., S, E) 211 | /// - value - Value tensor; shape (N, ..., S, E) 212 | /// 213 | /// https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html 214 | /// # Errors 215 | /// 216 | /// This function will return an error if . 
217 | pub fn scaled_dot_product_attention( 218 | query: &Tensor, 219 | key: &Tensor, 220 | value: &Tensor, 221 | attn_bias: &Tensor, 222 | dropout_p: Option, 223 | scale_factor: f32, 224 | ) -> Result { 225 | let mut attn_weights = try_api!( 226 | try_api!(query.matmul(&try_api!( 227 | try_api!(key.transpose(D::Minus2, D::Minus1)).contiguous() 228 | ),)) * f64::from(scale_factor) 229 | ); 230 | 231 | attn_weights = try_api!(&attn_weights + try_api!(attn_bias.broadcast_as(attn_weights.shape()))); 232 | attn_weights = try_api!(candle_nn::ops::softmax_last_dim(&attn_weights)); 233 | 234 | if let Some(drop_p) = dropout_p { 235 | attn_weights = try_api!(candle_nn::ops::dropout(&attn_weights, drop_p)); 236 | } 237 | attn_weights.matmul(value).map_err(APIError::from) 238 | } 239 | -------------------------------------------------------------------------------- /src/openai/pipelines/llama.rs: -------------------------------------------------------------------------------- 1 | use std::{iter::zip, path::PathBuf, sync::Arc}; 2 | 3 | use crate::{ 4 | openai::{ 5 | conversation::{ 6 | default_conversation::{ 7 | DefaultConversation, DefaultConversationSeparators, SeparatorStyle, 8 | }, 9 | Conversation, 10 | }, 11 | models::{ 12 | llama::{Llama, LlamaConfig}, 13 | ConfigLike, 14 | }, 15 | requests::StopTokens, 16 | responses::APIError, 17 | sampling_params::SamplingParams, 18 | PipelineConfig, TokenizerWrapper, 19 | }, 20 | paged_attention::input_metadata::InputMetadata, 21 | scheduler::sequence::Sequence, 22 | try_api, 23 | }; 24 | use candle_core::{DType, Device, IndexOp, Tensor}; 25 | use candle_lora_transformers::varbuilder_utils::from_mmaped_safetensors; 26 | use either::Either::{Left, Right}; 27 | use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; 28 | use tokenizers::Tokenizer; 29 | 30 | use super::{get_token, ModelLoader, ModelPaths, ModulePipeline, TokenOrFinishReason}; 31 | 32 | const EOS_TOKEN: &str = ""; 33 | const SAMPLING_SEED: u64 = 299792458; 34 | 35 | #[derive(Debug, Clone)] 36 | pub struct LlamaSpecificConfig { 37 | repeat_last_n: usize, 38 | } 39 | 40 | impl LlamaSpecificConfig { 41 | pub fn new(repeat_last_n: usize) -> Self { 42 | Self { repeat_last_n } 43 | } 44 | } 45 | 46 | /// top-p, multinomial, and argmax sampling are implemented. Beam search is not implemented. 47 | pub struct LlamaPipeline { 48 | llama: Llama, 49 | args: LlamaSpecificConfig, 50 | tokenizer: Tokenizer, 51 | conversation: DefaultConversation, 52 | name: String, 53 | } 54 | 55 | pub struct LlamaLoader { 56 | config: LlamaSpecificConfig, 57 | name: String, 58 | } 59 | 60 | pub struct LlamaModelPaths

<P> { 61 | tokenizer_filename: P, 62 | config_filename: P, 63 | filenames: Vec<P>

, 64 | } 65 | 66 | impl ModelPaths for LlamaModelPaths { 67 | fn get_config_filename(&self) -> &PathBuf { 68 | &self.config_filename 69 | } 70 | fn get_tokenizer_filename(&self) -> &PathBuf { 71 | &self.tokenizer_filename 72 | } 73 | fn get_weight_filenames(&self) -> &Vec { 74 | &self.filenames 75 | } 76 | } 77 | 78 | impl LlamaLoader { 79 | pub fn new(config: LlamaSpecificConfig, name: String) -> Self { 80 | Self { config, name } 81 | } 82 | } 83 | 84 | impl<'a> ModelLoader<'a> for LlamaLoader { 85 | fn download_model( 86 | &self, 87 | model_id: String, 88 | revision: Option, 89 | hf_token: Option, 90 | hf_token_path: Option, 91 | ) -> Result, APIError> { 92 | let api = try_api!(ApiBuilder::new() 93 | .with_progress(true) 94 | .with_token(Some(get_token(hf_token, hf_token_path)?)) 95 | .build()); 96 | let revision = revision.unwrap_or("main".to_string()); 97 | let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision)); 98 | 99 | let tokenizer_filename = try_api!(api.get("tokenizer.json")); 100 | 101 | let config_filename = try_api!(api.get("config.json")); 102 | 103 | let mut filenames = vec![]; 104 | for rfilename in try_api!(api.info()) 105 | .siblings 106 | .iter() 107 | .map(|x| x.rfilename.clone()) 108 | .filter(|x| x.ends_with(".safetensors")) 109 | { 110 | let filename = try_api!(api.get(&rfilename)); 111 | filenames.push(filename); 112 | } 113 | 114 | Ok(Box::new(LlamaModelPaths { 115 | tokenizer_filename, 116 | config_filename, 117 | filenames, 118 | })) 119 | } 120 | 121 | fn load_model( 122 | &self, 123 | paths: Box, 124 | dtype: DType, 125 | device: Device, 126 | ) -> Result<(Box>, PipelineConfig), APIError> { 127 | let args = self.config.clone(); 128 | 129 | let config: LlamaConfig = try_api!(serde_json::from_slice(&try_api!(std::fs::read( 130 | paths.get_config_filename() 131 | )),)); 132 | let config = config.into_config(); 133 | 134 | println!("Loading {} model.", self.name); 135 | 136 | let vb = try_api!(from_mmaped_safetensors( 137 | paths.get_weight_filenames(), 138 | dtype, 139 | &device, 140 | false 141 | )); 142 | 143 | let llama = try_api!(Llama::load(vb, &config, dtype, &device)); 144 | 145 | let tokenizer = Tokenizer::from_file(paths.get_tokenizer_filename()) 146 | .map_err(|x| APIError::new(x.to_string()))?; 147 | 148 | println!("Done loading."); 149 | 150 | //max is https://huggingface.co/docs/transformers/model_doc/llama2#transformers.LlamaConfig.max_position_embeddings 151 | let pipeline_config = PipelineConfig { 152 | max_model_len: 4096, 153 | }; 154 | 155 | //reference: https://huggingface.co/blog/codellama#conversational-instructions, 156 | //reference: https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/generation.py#L212 157 | Ok(( 158 | Box::new(LlamaPipeline { 159 | llama, 160 | args, 161 | tokenizer, 162 | conversation: DefaultConversation::new( 163 | "llama-2".to_string(), 164 | "[INST] <>\n{}\n<>\n\n".to_string(), 165 | Vec::default(), 166 | 0, 167 | SeparatorStyle::Llama2, 168 | "".to_string(), 169 | Vec::default(), 170 | ("[INST]".to_string(), "[/INST]".to_string()), 171 | DefaultConversationSeparators { 172 | sep: " ".to_string(), 173 | sep2: Some(" ".to_string()), 174 | }, 175 | ), 176 | name: self.name.clone(), 177 | }), 178 | pipeline_config, 179 | )) 180 | } 181 | } 182 | 183 | impl<'s> ModulePipeline<'s> for LlamaPipeline { 184 | fn forward( 185 | &mut self, 186 | input_tokens: Tensor, 187 | input_positions: Tensor, 188 | kv_cache: Option<&Vec<(Tensor, Tensor)>>, 189 | mut 
input_metadata: InputMetadata, 190 | ) -> Result { 191 | self.llama.forward( 192 | &input_tokens, 193 | &input_positions, 194 | kv_cache, 195 | &mut input_metadata, 196 | ) 197 | } 198 | 199 | fn sample( 200 | &mut self, 201 | logits: Tensor, 202 | sampling_params: &SamplingParams, 203 | seqs: &[(&usize, &Arc)], 204 | ) -> Result, APIError> { 205 | let eos_token_id = self.tokenizer.token_to_id(EOS_TOKEN); 206 | 207 | let mut logits_processor = sampling_params.get_logits_processor( 208 | SAMPLING_SEED, 209 | &self.tokenizer, 210 | sampling_params.logprobs.unwrap_or(1), 211 | ); 212 | let stop_tokens = match sampling_params.stop.clone() { 213 | Some(stop) => match stop { 214 | StopTokens::Multi(multi) => multi, 215 | StopTokens::Single(single) => vec![single], 216 | }, 217 | 218 | None => vec![], 219 | }; 220 | 221 | let n_seqs = logits.dims()[0]; 222 | 223 | let mut result = Vec::new(); 224 | for (seq_n, (_, seq)) in zip(0..n_seqs, seqs) { 225 | let logits = try_api!(logits.i((seq_n, try_api!(logits.dim(1)) - 1))); 226 | 227 | let tokens = seq 228 | .deref_mut() 229 | .get_token_ids() 230 | .iter() 231 | .map(|x| *x as u32) 232 | .collect::>(); 233 | let tokens_generated = seq.deref_mut().get_len() - seq.deref_mut().get_prompt_len(); 234 | 235 | let logits = if sampling_params.repetition_penalty == 1. { 236 | logits 237 | } else { 238 | let start_at = tokens.len().saturating_sub(self.args.repeat_last_n); 239 | try_api!(candle_transformers::utils::apply_repeat_penalty( 240 | &logits, 241 | sampling_params.repetition_penalty, 242 | &tokens[start_at..], 243 | )) 244 | }; 245 | 246 | let next_token = try_api!(logits_processor.sample(&logits)); 247 | if let Some(text) = self.tokenizer.id_to_token(next_token.token as u32) { 248 | let text = text.replace('▁', " ").replace("<0x0A>", "\n"); 249 | if stop_tokens.contains(&text) { 250 | result.push(Right("stop".to_string())); 251 | continue; 252 | } 253 | } 254 | 255 | if Some(next_token.token) == eos_token_id.map(|x| x as usize) { 256 | result.push(Right("stop".to_string())); 257 | continue; 258 | } 259 | if tokens_generated >= sampling_params.max_tokens { 260 | result.push(Right("length".to_string())); 261 | continue; 262 | } 263 | result.push(Left(next_token)); 264 | } 265 | 266 | Ok(result) 267 | } 268 | 269 | fn name(&self) -> &str { 270 | &self.name 271 | } 272 | 273 | fn tokenizer(&self) -> &dyn TokenizerWrapper<'s, String> { 274 | &self.tokenizer 275 | } 276 | 277 | fn get_conversation(&mut self) -> &mut dyn Conversation { 278 | &mut self.conversation 279 | } 280 | 281 | fn get_model_config(&self) -> Box { 282 | Box::new(self.llama.get_config().clone()) 283 | } 284 | 285 | fn get_dtype(&self) -> DType { 286 | todo!() 287 | } 288 | } 289 | 290 | unsafe impl Send for LlamaPipeline {} 291 | unsafe impl Sync for LlamaPipeline {} 292 | -------------------------------------------------------------------------------- /src/openai/sampling_params.rs: -------------------------------------------------------------------------------- 1 | use std::ops::Range; 2 | 3 | use candle_sampling::logits_processor::{LogitsProcessor, SamplingMethod}; 4 | use tokenizers::Tokenizer; 5 | 6 | use super::{requests::StopTokens, responses::APIError}; 7 | 8 | const SAMPLING_EPS: f32 = 1e-5; 9 | 10 | #[derive(Clone, Debug, PartialEq)] 11 | pub enum EarlyStoppingCondition { 12 | ///True 13 | BestOfCompleteCandidates, 14 | ///False 15 | UnlikelyBetterCandidates, 16 | ///"never" 17 | CanonicalNoBetterCandidates, 18 | } 19 | 20 | #[derive(Clone, Debug, PartialEq, Eq, Hash)] 21 | 
pub enum SamplingType { 22 | BEAM, 23 | GREEDY, 24 | RANDOM, 25 | } 26 | 27 | #[derive(Clone, Debug)] 28 | pub struct SamplingParams { 29 | /// Number of output seqs to return for a prompt. 30 | pub n: usize, 31 | /// Number of output seqs that are generated from the prompt, from these `best_of` seqs, the top `n` sequences are returned. `best_of` must be `>=n`. Default = `n`. 32 | /// Beam width when `use_beam_search` is true. 33 | pub best_of: usize, 34 | /// Penalize new tokens based upon whether they appear in the generated text so far, >0 encourage new, <0 encourage repeat. 35 | /// rec. default = 0 36 | pub presence_penalty: f32, 37 | /// Penalize new tokens based upon whether their frequency in the generated text so far, >0 encourage new, <0 encourage repeat. 38 | /// rec. default = 0 39 | pub frequency_penalty: f32, 40 | /// Penalize new tokens based upon whether their frequency in the generated text so far, >1 encourage new, <1 encourage repeat 41 | /// rec. default = 1 42 | pub repetition_penalty: f32, 43 | /// Randomness of sampling. 44 | /// rec. default = 1 45 | pub temperature: f32, 46 | /// Cumulative prob of the top tokens to consider, must be in (0, 1]. Set 1 to consider all toks. 47 | /// rec. default = 1 48 | pub top_p: f32, 49 | /// Control the number of top tokens to consider, set -1 to consider all. 50 | /// rec. default = -1 51 | pub top_k: isize, 52 | /// Use beam search instead of sampling. 53 | /// rec. default = false 54 | pub use_beam_search: bool, 55 | /// Penalize based on length. 56 | /// rec. default = 1 57 | pub length_penalty: f32, 58 | /// Control stopping for beam search. 59 | /// rec. default = EarlyStoppingCondition::UnlikelyBetterCandidates 60 | pub early_stopping: EarlyStoppingCondition, 61 | /// Strings that stop generation when generated. 62 | pub stop: Option, 63 | /// Tokens to stop on. 64 | pub stop_token_ids: Vec, 65 | /// Whether to ignore EOS token. 66 | pub ignore_eos: bool, 67 | /// Max number of toks to gen per output seq. 68 | /// rec. default = 16 69 | pub max_tokens: usize, 70 | /// Num of log probs to return per output token. Follows OpenAI API, return result include the log probabilities on the `logprobs` most likely tokens. 71 | /// will always return the log prob of the sampled token, so there may be up to `logprobs+1` elements in the response. 72 | /// Default = 1 73 | pub logprobs: Option, 74 | /// Num of log probs to return per prompt token. 75 | pub prompt_logprobs: Option, 76 | /// Skip special toks in output. 77 | /// rec. 
default = true 78 | pub skip_special_tokens: bool, 79 | } 80 | 81 | impl SamplingParams { 82 | #[allow(clippy::too_many_arguments)] 83 | pub fn new( 84 | n: usize, 85 | best_of: Option, 86 | presence_penalty: f32, 87 | frequency_penalty: f32, 88 | repetition_penalty: f32, 89 | temperature: f32, 90 | top_p: f32, 91 | top_k: isize, 92 | use_beam_search: bool, 93 | length_penalty: f32, 94 | early_stopping: EarlyStoppingCondition, 95 | stop: Option, 96 | stop_token_ids: Vec, 97 | ignore_eos: bool, 98 | max_tokens: usize, 99 | logprobs: Option, 100 | prompt_logprobs: Option, 101 | skip_special_tokens: bool, 102 | ) -> Result { 103 | let this = Self { 104 | n, 105 | best_of: best_of.unwrap_or(n), 106 | presence_penalty, 107 | frequency_penalty, 108 | repetition_penalty, 109 | temperature, 110 | top_p, 111 | top_k, 112 | use_beam_search, 113 | length_penalty, 114 | early_stopping, 115 | stop, 116 | stop_token_ids, 117 | ignore_eos, 118 | max_tokens, 119 | logprobs, 120 | prompt_logprobs, 121 | skip_special_tokens, 122 | }; 123 | 124 | this.verify_args()?; 125 | if this.use_beam_search { 126 | this.verify_beam_search()?; 127 | } else { 128 | this.verify_non_beam_search()?; 129 | if this.temperature < SAMPLING_EPS { 130 | this.verify_greedy_sampling()?; 131 | } 132 | } 133 | 134 | Ok(this) 135 | } 136 | 137 | pub fn get_logits_processor<'a>( 138 | &self, 139 | seed: u64, 140 | tokenizer: &'a Tokenizer, 141 | top_n_logprobs: usize, 142 | ) -> LogitsProcessor<'a> { 143 | if self.top_k == -1 && self.top_p == 1. { 144 | // Greedy 145 | LogitsProcessor::new( 146 | seed, 147 | Some(self.temperature.into()), 148 | SamplingMethod::Multinomial, 149 | top_n_logprobs, 150 | tokenizer, 151 | ) 152 | } else if self.top_k > 0 && self.top_p == 1. { 153 | // Top-k 154 | LogitsProcessor::new( 155 | seed, 156 | Some(self.temperature.into()), 157 | SamplingMethod::TopK(self.top_k.try_into().unwrap()), 158 | top_n_logprobs, 159 | tokenizer, 160 | ) 161 | } else if self.top_k == -1 && self.top_p != 1. 
{ 162 | // Top-p 163 | LogitsProcessor::new( 164 | seed, 165 | Some(self.temperature.into()), 166 | SamplingMethod::TopP(self.top_p.into()), 167 | top_n_logprobs, 168 | tokenizer, 169 | ) 170 | } else { 171 | unreachable!() 172 | } 173 | } 174 | 175 | fn verify_args(&self) -> Result<(), APIError> { 176 | if self.n < 1 { 177 | return Err(APIError::new(format!( 178 | "n must be at leas 1, got {}.", 179 | self.n 180 | ))); 181 | } 182 | if self.best_of < self.n { 183 | return Err(APIError::new(format!( 184 | "best_of must be greater than or equal to n, got n={} and best_of={}", 185 | self.n, self.best_of 186 | ))); 187 | } 188 | if !(-2.0..=2.0).contains(&self.presence_penalty) { 189 | return Err(APIError::new(format!( 190 | "presence_penalty must be in [-2, 2], got {}", 191 | self.presence_penalty 192 | ))); 193 | } 194 | if !(-2.0..=2.0).contains(&self.frequency_penalty) { 195 | return Err(APIError::new(format!( 196 | "frequency_penalty must be in [-2, 2], got {}", 197 | self.frequency_penalty 198 | ))); 199 | } 200 | if !(Range { 201 | start: 0.0, 202 | end: 2.0, 203 | }) 204 | .contains(&self.repetition_penalty) 205 | || self.repetition_penalty == 0.0 206 | { 207 | return Err(APIError::new(format!( 208 | "repetition_penalty must be in (0, 2], got {}", 209 | self.repetition_penalty 210 | ))); 211 | } 212 | if self.temperature < 0.0 { 213 | return Err(APIError::new(format!( 214 | "temperature must be non-negative, got {}", 215 | self.temperature 216 | ))); 217 | } 218 | if self.max_tokens < 1 { 219 | return Err(APIError::new(format!( 220 | "max_tokens must be at least 1, got {}", 221 | self.max_tokens 222 | ))); 223 | } 224 | Ok(()) 225 | } 226 | 227 | fn verify_beam_search(&self) -> Result<(), APIError> { 228 | if self.best_of <= 1 { 229 | return Err(APIError::new(format!( 230 | "best_of must be greater than 1 when using beam search. Got {}", 231 | self.best_of 232 | ))); 233 | } 234 | if self.temperature > SAMPLING_EPS { 235 | return Err(APIError::new_str( 236 | "temperature must be 0 when using beam search", 237 | )); 238 | } 239 | if self.top_p < 1.0 - SAMPLING_EPS { 240 | return Err(APIError::new_str("top_p must be 1 when using beam search")); 241 | } 242 | if self.top_k != -1 { 243 | return Err(APIError::new_str("top_k must be -1 when using beam search")); 244 | } 245 | Ok(()) 246 | } 247 | 248 | fn verify_non_beam_search(&self) -> Result<(), APIError> { 249 | if self.early_stopping != EarlyStoppingCondition::UnlikelyBetterCandidates { 250 | return Err(APIError::new_str("early_stopping is not effective and must be UnlikelyBetterCandidates when not using beam search.")); 251 | } 252 | if self.length_penalty < 1.0 - SAMPLING_EPS || self.length_penalty > 1.0 + SAMPLING_EPS { 253 | return Err(APIError::new_str("length_penalty is not effective and must be the default value of 1.0 when not using beam search.")); 254 | } 255 | Ok(()) 256 | } 257 | 258 | fn verify_greedy_sampling(&self) -> Result<(), APIError> { 259 | if self.best_of > 1 { 260 | return Err(APIError::new(format!( 261 | "best_of must be 1 when using greedy sampling. 
Got {}.", 262 | self.best_of 263 | ))); 264 | } 265 | if self.top_p < 1.0 - SAMPLING_EPS { 266 | return Err(APIError::new_str( 267 | "top_p must be 1 when using greedy sampling.", 268 | )); 269 | } 270 | if self.top_p != -1.0 { 271 | return Err(APIError::new_str( 272 | "top_k must be -1 when using greedy sampling.", 273 | )); 274 | } 275 | Ok(()) 276 | } 277 | } 278 | -------------------------------------------------------------------------------- /src/backend/paged_attention.rs: -------------------------------------------------------------------------------- 1 | use candle_core::{cuda_backend::cudarc::driver::CudaFunction, DType, Tensor}; 2 | 3 | use candle_core::cuda_backend::cudarc::driver::sys as cudarc_sys; 4 | 5 | use crate::openai::responses::APIError; 6 | 7 | fn set_max_dynamic_shared_memory_size(func: CudaFunction, size: usize) { 8 | // let attr = cudarc_sys::CUfunction_attribute::CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES; 9 | // func.set_attribute(attr, size.try_into().unwrap()); 10 | } 11 | 12 | const WARP_SIZE: usize = 32; 13 | 14 | #[allow(clippy::too_many_arguments)] 15 | fn paged_attention_v1_launcher( 16 | query: Tensor, // [num_seqs, num_heads, head_size] 17 | key_cache: Tensor, // [num_blocks, num_heads, head_size/x, block_size, x] 18 | value_cache: Tensor, // [num_blocks, num_heads, head_size, block_size] 19 | num_key_value_heads: i32, // [num_heads] 20 | scale: f32, 21 | block_tables: Tensor, // [num_seqs, max_num_blocks_per_seq] 22 | context_lens: Tensor, // [num_seqs] 23 | block_size: usize, 24 | max_context_len: usize, 25 | alibi_slopes: Option, 26 | dtype: DType, 27 | is_fp8_e5m2_kv_cache: bool, 28 | ) -> Tensor { 29 | let num_seqs = query.shape().dims()[0]; 30 | let num_heads = query.shape().dims()[1]; 31 | let head_size = query.shape().dims()[2]; 32 | let max_num_blocks_per_seq = block_tables.shape().dims()[1]; 33 | let q_stride = query.stride()[0]; 34 | let kv_block_stride = key_cache.stride()[0]; 35 | let kv_head_stride = key_cache.stride()[1]; 36 | 37 | let thread_group_size = 1.max(WARP_SIZE / block_size); 38 | debug_assert_eq!(head_size % thread_group_size, 0); 39 | todo!(); 40 | } 41 | 42 | #[allow(clippy::too_many_arguments)] 43 | pub fn paged_attention_v1( 44 | query: Tensor, // [num_seqs, num_heads, head_size] 45 | key_cache: Tensor, // [num_blocks, num_heads, head_size/x, block_size, x] 46 | value_cache: Tensor, // [num_blocks, num_heads, head_size, block_size] 47 | num_key_value_heads: i32, // [num_heads] 48 | scale: f32, 49 | block_tables: Tensor, // [num_seqs, max_num_blocks_per_seq] 50 | context_lens: Tensor, // [num_seqs] 51 | block_size: usize, 52 | max_context_len: usize, 53 | alibi_slopes: Option, 54 | kv_cache_dtype: &str, 55 | ) -> Result { 56 | let query_dtype = query.dtype(); 57 | if kv_cache_dtype == "auto" { 58 | match query_dtype { 59 | DType::F32 | DType::F16 | DType::BF16 => Ok(paged_attention_v1_launcher( 60 | query, 61 | key_cache, 62 | value_cache, 63 | num_key_value_heads, 64 | scale, 65 | block_tables, 66 | context_lens, 67 | block_size, 68 | max_context_len, 69 | alibi_slopes, 70 | query_dtype, 71 | false, 72 | )), 73 | _ => Err(APIError::new(format!( 74 | "Unsupported data type {:?}", 75 | query_dtype 76 | ))), 77 | } 78 | } else if kv_cache_dtype == "fp8_e5m2" { 79 | match query_dtype { 80 | DType::F32 | DType::F16 | DType::BF16 => Ok(paged_attention_v1_launcher( 81 | query, 82 | key_cache, 83 | value_cache, 84 | num_key_value_heads, 85 | scale, 86 | block_tables, 87 | context_lens, 88 | block_size, 89 | max_context_len, 90 | 
alibi_slopes, 91 | query_dtype, 92 | true, 93 | )), 94 | _ => Err(APIError::new(format!( 95 | "Unsupported data type {:?}", 96 | query_dtype 97 | ))), 98 | } 99 | } else { 100 | Err(APIError::new(format!( 101 | "Unsupported data type {:?}", 102 | query_dtype 103 | ))) 104 | } 105 | } 106 | 107 | #[allow(clippy::too_many_arguments)] 108 | pub fn paged_attention_v2( 109 | _exp_sums: Tensor, 110 | _max_logits: Tensor, 111 | _query: Tensor, 112 | _key_cache: Tensor, 113 | _value_cache: Tensor, 114 | _num_key_value_heads: i32, 115 | _scale: f32, 116 | _block_tables: Tensor, 117 | _context_lens: Tensor, 118 | _block_size: usize, 119 | _max_context_len: usize, 120 | _alibi_slopes: Option, 121 | ) -> Tensor { 122 | todo!(); 123 | } 124 | 125 | /* 126 | #ifndef USE_ROCM 127 | #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ 128 | cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL) 129 | #else 130 | #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ 131 | hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) 132 | #endif 133 | 134 | #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ 135 | VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ 136 | ((void*)vllm::paged_attention_v1_kernel), shared_mem_size); \ 138 | vllm::paged_attention_v1_kernel<<>>( \ 140 | out_ptr, \ 141 | query_ptr, \ 142 | key_cache_ptr, \ 143 | value_cache_ptr, \ 144 | num_kv_heads, \ 145 | scale, \ 146 | block_tables_ptr, \ 147 | context_lens_ptr, \ 148 | max_num_blocks_per_seq, \ 149 | alibi_slopes_ptr, \ 150 | q_stride, \ 151 | kv_block_stride, \ 152 | kv_head_stride); 153 | 154 | // TODO(woosuk): Tune NUM_THREADS. 155 | template< 156 | typename T, 157 | typename CACHE_T, 158 | int BLOCK_SIZE, 159 | bool IS_FP8_E5M2_KV_CACHE, 160 | int NUM_THREADS = 128> 161 | void paged_attention_v1_launcher( 162 | torch::Tensor& out, 163 | torch::Tensor& query, 164 | torch::Tensor& key_cache, 165 | torch::Tensor& value_cache, 166 | int num_kv_heads, 167 | float scale, 168 | torch::Tensor& block_tables, 169 | torch::Tensor& context_lens, 170 | int max_context_len, 171 | const c10::optional& alibi_slopes) { 172 | int num_seqs = query.size(0); 173 | int num_heads = query.size(1); 174 | int head_size = query.size(2); 175 | int max_num_blocks_per_seq = block_tables.size(1); 176 | int q_stride = query.stride(0); 177 | int kv_block_stride = key_cache.stride(0); 178 | int kv_head_stride = key_cache.stride(1); 179 | 180 | int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); 181 | assert(head_size % thread_group_size == 0); 182 | 183 | // NOTE: alibi_slopes is optional. 184 | const float* alibi_slopes_ptr = alibi_slopes ? 
185 | reinterpret_cast(alibi_slopes.value().data_ptr()) 186 | : nullptr; 187 | 188 | T* out_ptr = reinterpret_cast(out.data_ptr()); 189 | T* query_ptr = reinterpret_cast(query.data_ptr()); 190 | CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); 191 | CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); 192 | int* block_tables_ptr = block_tables.data_ptr(); 193 | int* context_lens_ptr = context_lens.data_ptr(); 194 | 195 | constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; 196 | int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE; 197 | int logits_size = padded_max_context_len * sizeof(float); 198 | int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); 199 | // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len 200 | // Keep that in sync with the logic here! 201 | int shared_mem_size = std::max(logits_size, outputs_size); 202 | 203 | dim3 grid(num_heads, num_seqs, 1); 204 | dim3 block(NUM_THREADS); 205 | const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); 206 | const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 207 | switch (head_size) { 208 | // NOTE(woosuk): To reduce the compilation time, we only compile for the 209 | // head sizes that we use in the model. However, we can easily extend this 210 | // to support any head size which is a multiple of 16. 211 | case 64: 212 | LAUNCH_PAGED_ATTENTION_V1(64); 213 | break; 214 | case 80: 215 | LAUNCH_PAGED_ATTENTION_V1(80); 216 | break; 217 | case 96: 218 | LAUNCH_PAGED_ATTENTION_V1(96); 219 | break; 220 | case 112: 221 | LAUNCH_PAGED_ATTENTION_V1(112); 222 | break; 223 | case 128: 224 | LAUNCH_PAGED_ATTENTION_V1(128); 225 | break; 226 | case 256: 227 | LAUNCH_PAGED_ATTENTION_V1(256); 228 | break; 229 | default: 230 | TORCH_CHECK(false, "Unsupported head size: ", head_size); 231 | break; 232 | } 233 | } 234 | 235 | #define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \ 236 | paged_attention_v1_launcher( \ 237 | out, \ 238 | query, \ 239 | key_cache, \ 240 | value_cache, \ 241 | num_kv_heads, \ 242 | scale, \ 243 | block_tables, \ 244 | context_lens, \ 245 | max_context_len, \ 246 | alibi_slopes); 247 | 248 | // NOTE(woosuk): To reduce the compilation time, we omitted block sizes 249 | // 1, 2, 4, 64, 128, 256. 
250 | #define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \ 251 | switch (block_size) { \ 252 | case 8: \ 253 | CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \ 254 | break; \ 255 | case 16: \ 256 | CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \ 257 | break; \ 258 | case 32: \ 259 | CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \ 260 | break; \ 261 | default: \ 262 | TORCH_CHECK(false, "Unsupported block size: ", block_size); \ 263 | break; \ 264 | } 265 | 266 | */ 267 | -------------------------------------------------------------------------------- /src/scheduler/block_engine.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::{hash_map::Entry, HashMap}, 3 | hash::Hash, 4 | marker::PhantomData, 5 | ops::Deref, 6 | sync::{Arc, Mutex, MutexGuard}, 7 | }; 8 | 9 | use super::sequence::{Sequence, SequenceGroup}; 10 | 11 | pub struct LogicalTokenBlock { 12 | tokens: Vec, 13 | block_size: usize, 14 | num_tokens: usize, 15 | } 16 | 17 | impl LogicalTokenBlock { 18 | pub fn new(block_size: usize) -> Self { 19 | Self { 20 | tokens: [0].repeat(block_size), 21 | block_size, 22 | num_tokens: 0, 23 | } 24 | } 25 | 26 | pub fn is_full(&self) -> bool { 27 | self.num_tokens == self.block_size 28 | } 29 | 30 | pub fn append_token_id(&mut self, token: usize) { 31 | assert!(!self.is_full()); 32 | self.tokens[self.num_tokens] = token; 33 | self.num_tokens += 1; 34 | } 35 | 36 | pub fn append_tokens(&mut self, tokens: &[usize]) { 37 | for token in tokens { 38 | self.append_token_id(*token); 39 | } 40 | } 41 | } 42 | 43 | #[derive(Hash, PartialEq, Eq)] 44 | pub struct _PhysicalTokenBlock { 45 | pub block_id: usize, 46 | block_size: usize, 47 | refcount: usize, 48 | is_gpu: bool, 49 | } 50 | 51 | pub struct PhysicalTokenBlock(pub Mutex<_PhysicalTokenBlock>); 52 | 53 | impl PhysicalTokenBlock { 54 | pub fn deref_mut(&self) -> MutexGuard<'_, _PhysicalTokenBlock> { 55 | loop { 56 | if let Ok(v) = self.0.try_lock() { 57 | return v; 58 | } 59 | } 60 | } 61 | } 62 | 63 | impl PartialEq for PhysicalTokenBlock { 64 | fn eq(&self, other: &Self) -> bool { 65 | *self.deref_mut() == *other.deref_mut() 66 | } 67 | } 68 | 69 | impl Hash for PhysicalTokenBlock { 70 | fn hash(&self, state: &mut H) { 71 | self.deref_mut().hash(state) 72 | } 73 | } 74 | 75 | impl Eq for PhysicalTokenBlock {} 76 | 77 | type BlockTable = Vec>; 78 | struct GPUAllocator; 79 | struct CPUAllocator; 80 | 81 | struct GPUAllocatorWrapper(usize); 82 | struct CPUAllocatorWrapper(usize); 83 | 84 | impl Deref for GPUAllocatorWrapper { 85 | type Target = usize; 86 | 87 | fn deref(&self) -> &Self::Target { 88 | &self.0 89 | } 90 | } 91 | 92 | impl Deref for CPUAllocatorWrapper { 93 | type Target = usize; 94 | 95 | fn deref(&self) -> &Self::Target { 96 | &self.0 97 | } 98 | } 99 | 100 | struct Allocator { 101 | free_blocks: BlockTable, 102 | _ghost: PhantomData, 103 | } 104 | 105 | impl Allocator { 106 | fn allocate(&mut self) -> Arc { 107 | let block = self.free_blocks.pop().unwrap(); 108 | block.deref_mut().refcount = 1; 109 | block 110 | } 111 | 112 | fn free_block(&mut self, block: Arc) { 113 | if block.deref_mut().refcount == 0 { 114 | panic!( 115 | "PhysicalTokenBlock with id {} experienced a double free!", 116 | block.deref_mut().block_id 117 | ); 118 | } 119 | block.deref_mut().refcount -= 1; 120 | if block.deref_mut().refcount == 0 { 121 | self.free_blocks.push(block); 122 | } 123 | } 124 | } 125 | 126 | impl Allocator { 127 | fn 
new(block_size: usize, num_blocks: usize) -> Self { 128 | let mut free_blocks = Vec::new(); 129 | for id in 0..num_blocks { 130 | free_blocks.push(Arc::new(PhysicalTokenBlock(Mutex::new( 131 | _PhysicalTokenBlock { 132 | block_id: id, 133 | block_size, 134 | refcount: 0, 135 | is_gpu: true, 136 | }, 137 | )))) 138 | } 139 | Allocator { 140 | free_blocks, 141 | _ghost: PhantomData, 142 | } 143 | } 144 | 145 | fn get_num_free_blocks(&self) -> GPUAllocatorWrapper { 146 | GPUAllocatorWrapper(self.free_blocks.len()) 147 | } 148 | } 149 | 150 | impl Allocator { 151 | fn new(block_size: usize, num_blocks: usize) -> Self { 152 | let mut free_blocks = Vec::new(); 153 | for id in 0..num_blocks { 154 | free_blocks.push(Arc::new(PhysicalTokenBlock(Mutex::new( 155 | _PhysicalTokenBlock { 156 | block_id: id, 157 | block_size, 158 | refcount: 0, 159 | is_gpu: true, 160 | }, 161 | )))) 162 | } 163 | Allocator { 164 | free_blocks, 165 | _ghost: PhantomData, 166 | } 167 | } 168 | } 169 | 170 | pub enum AllocStatus { 171 | Ok, 172 | Later, 173 | Impossible, 174 | } 175 | 176 | type SeqID = usize; 177 | 178 | /// A BlockEngine maps each Sequence (identified by its SeqID), to physical token blocks. 179 | /// The physical token blocks may not match the logical token blocks because during 180 | /// scheduling, physical blocks are allocated to accommodate the new tokens generated. 181 | /// These new tokens will be added to the logical token block for each sequence. 182 | pub struct BlockEngine { 183 | num_gpu_blocks: usize, 184 | gpu_allocator: Allocator, 185 | cpu_allocator: Allocator, 186 | pub block_tables: HashMap, 187 | } 188 | 189 | impl BlockEngine { 190 | #[must_use] 191 | pub fn new(block_size: usize, num_gpu_blocks: usize, num_cpu_blocks: usize) -> Self { 192 | Self { 193 | num_gpu_blocks, 194 | gpu_allocator: Allocator::::new(block_size, num_gpu_blocks), 195 | cpu_allocator: Allocator::::new(block_size, num_cpu_blocks), 196 | block_tables: HashMap::new(), 197 | } 198 | } 199 | 200 | pub fn can_allocate(&self, seq_group: &SequenceGroup) -> AllocStatus { 201 | let num_required_blocks = seq_group.get_total_logical_token_blocks(); 202 | let num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks(); 203 | 204 | if self.num_gpu_blocks > *num_free_gpu_blocks + num_required_blocks { 205 | AllocStatus::Later 206 | } else if self.num_gpu_blocks < num_required_blocks { 207 | AllocStatus::Impossible 208 | } else { 209 | AllocStatus::Ok 210 | } 211 | } 212 | 213 | pub fn allocate(&mut self, seq_group: &SequenceGroup) { 214 | let mut block_table = Vec::new(); 215 | for _logcical_idx in 0..seq_group.get_total_logical_token_blocks() { 216 | block_table.push(self.gpu_allocator.allocate()); 217 | } 218 | for seq_id in seq_group.get_seqs().keys() { 219 | self.block_tables.insert(*seq_id, block_table.clone()); 220 | } 221 | } 222 | 223 | pub fn can_append_token_to_seq(&self, seq_group: &SequenceGroup) -> bool { 224 | let free_blocks = self.gpu_allocator.get_num_free_blocks(); 225 | // Physical blocks = logical blocks 226 | seq_group.total_blocks_to_add_new_tok() <= *free_blocks 227 | } 228 | 229 | pub fn free_sequence(&mut self, sequence: &Sequence) { 230 | let block_table = self 231 | .block_tables 232 | .get(&sequence.deref_mut().get_id()) 233 | .unwrap(); 234 | 235 | // Free from block table 236 | for block in block_table { 237 | if block.deref_mut().is_gpu { 238 | self.gpu_allocator.free_block(block.clone()) 239 | } else { 240 | self.cpu_allocator.free_block(block.clone()) 241 | } 242 | } 243 | 244 | 
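        // `free_block` above only decrements the reference count; a physical block is
        // returned to the allocator's free list once no other sequence still holds it.
        // With the blocks released, drop this sequence's entry from the block table.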
self.block_tables.remove(&sequence.deref_mut().get_id()); 245 | } 246 | 247 | pub fn can_swap_out_seq_group(&self, seq_group: &SequenceGroup) -> bool { 248 | let blocks_required: usize = self 249 | .block_tables 250 | .iter() 251 | .filter(|(id, _)| seq_group.get_seqs().contains_key(id)) 252 | .map(|(_, table)| table.len()) 253 | .sum(); 254 | blocks_required <= self.cpu_allocator.free_blocks.len() 255 | } 256 | 257 | /// Update the block table so that the sequence does no longer reserve any GPU 258 | /// physical blocks, and only has CPU physical blocks. 259 | pub fn swap_out(&mut self, seq_group: &SequenceGroup) -> HashMap { 260 | // GPU block to a CPU block 261 | let mut new_mapping = HashMap::new(); 262 | for seq_id in seq_group.get_seqs().keys() { 263 | let mut new_block_table = Vec::new(); 264 | let block_table = self.block_tables.get(seq_id).unwrap(); 265 | 266 | for gpu_block in block_table { 267 | let cpu_block = 268 | if let Entry::Vacant(e) = new_mapping.entry(gpu_block.deref_mut().block_id) { 269 | // Create a new block 270 | let cpu_block = self.cpu_allocator.allocate(); 271 | e.insert(cpu_block.clone()); 272 | cpu_block 273 | } else { 274 | // Reuse a block 275 | let cpu_block = new_mapping 276 | .get(&gpu_block.deref_mut().block_id) 277 | .unwrap() 278 | .clone(); 279 | cpu_block.deref_mut().refcount += 1; 280 | cpu_block 281 | }; 282 | new_block_table.push(cpu_block); 283 | self.gpu_allocator.free_block(gpu_block.clone()); 284 | } 285 | self.block_tables.insert(*seq_id, new_block_table); 286 | } 287 | 288 | new_mapping 289 | .iter() 290 | .map(|(k, v)| (*k, v.deref_mut().block_id)) 291 | .collect::>() 292 | } 293 | 294 | // Returns the COW mapping (src, dst). 295 | // COW is performed if there are multiple references to the last physical block. 296 | pub fn append_token_slot_to_seq(&mut self, sequence: &Sequence) -> Option<(usize, usize)> { 297 | let table = self 298 | .block_tables 299 | .get_mut(&sequence.deref_mut().get_id()) 300 | .unwrap(); 301 | 302 | match sequence.deref_mut().blocks_to_add_new_tok() { 303 | 1 => { 304 | table.push(self.gpu_allocator.allocate()); 305 | None 306 | } 307 | 0 => { 308 | let last_block = table.last_mut().unwrap(); 309 | assert!(last_block.deref_mut().is_gpu); 310 | if last_block.deref_mut().refcount == 1 { 311 | None 312 | } else { 313 | // We would be writing into shared, so COW. 314 | let new_block = self.gpu_allocator.allocate(); 315 | self.gpu_allocator.free_block(last_block.clone()); 316 | let old_number = last_block.deref_mut().block_id; 317 | let new_number = new_block.deref_mut().block_id; 318 | *last_block = new_block; 319 | Some((old_number, new_number)) 320 | } 321 | } 322 | _ => { 323 | unreachable!() 324 | } 325 | } 326 | } 327 | 328 | pub fn can_swap_in_seq_group(&self, seq_group: &SequenceGroup) -> bool { 329 | let blocks_required: usize = self 330 | .block_tables 331 | .iter() 332 | .filter(|(id, _)| seq_group.get_seqs().contains_key(id)) 333 | .map(|(_, table)| table.len()) 334 | .sum(); 335 | blocks_required <= self.gpu_allocator.free_blocks.len() 336 | } 337 | 338 | /// Update the block table so that the sequence does no longer reserve any CPU 339 | /// physical blocks, and only has GPU physical blocks. 
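    /// The returned map is keyed by the source CPU block id and gives the GPU block id
    /// the cache engine should copy it into; the scheduler forwards it as this step's
    /// swap-in operation, e.g. (editor sketch mirroring the call site in
    /// `Scheduler::schedule`):
    ///
    /// ```ignore
    /// let to_swap_in = block_engine.swap_in(&seq_group);
    /// blocks_to_swap_in.extend(to_swap_in);
    /// ```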
340 | pub fn swap_in(&mut self, seq_group: &SequenceGroup) -> HashMap { 341 | // CPU block to a GPU block 342 | let mut new_mapping = HashMap::new(); 343 | for seq_id in seq_group.get_seqs().keys() { 344 | let mut new_block_table = Vec::new(); 345 | let block_table = self.block_tables.get(seq_id).unwrap(); 346 | 347 | for cpu_block in block_table { 348 | let gpu_block = 349 | if let Entry::Vacant(e) = new_mapping.entry(cpu_block.deref_mut().block_id) { 350 | // Create a new block 351 | let gpu_block = self.cpu_allocator.allocate(); 352 | e.insert(gpu_block.clone()); 353 | gpu_block 354 | } else { 355 | // Reuse a block 356 | let gpu_block = new_mapping 357 | .get(&cpu_block.deref_mut().block_id) 358 | .unwrap() 359 | .clone(); 360 | gpu_block.deref_mut().refcount += 1; 361 | gpu_block 362 | }; 363 | new_block_table.push(gpu_block); 364 | self.gpu_allocator.free_block(cpu_block.clone()); 365 | } 366 | self.block_tables.insert(*seq_id, new_block_table); 367 | } 368 | 369 | new_mapping 370 | .iter() 371 | .map(|(k, v)| (*k, v.deref_mut().block_id)) 372 | .collect::>() 373 | } 374 | } 375 | -------------------------------------------------------------------------------- /src/openai/conversation/default_conversation.rs: -------------------------------------------------------------------------------- 1 | use dyn_fmt::AsStrFormatExt; 2 | 3 | use super::Conversation; 4 | 5 | pub const ROLES: (&str, &str) = ("USER", "ASSISTANT"); 6 | pub const SYSTEM_TEMPLATE: &str = "{}"; 7 | pub const DEFAULT_SEP: &str = "\n"; 8 | 9 | /// Separator style for default conversation. 10 | #[derive(Default)] 11 | pub enum SeparatorStyle { 12 | #[default] 13 | AddColonSingle, 14 | AddColonTwo, 15 | AddColonSpaceSingle, 16 | NoColonSingle, 17 | NoColonTwo, 18 | AddNewLineSingle, 19 | Llama2, 20 | ChatGLM, 21 | ChatML, 22 | ChatIntern, 23 | Dolly, 24 | RWKV, 25 | Phoenix, 26 | Robin, 27 | FalconChat, 28 | } 29 | 30 | /// A struct for managing prompt templates and conversation history. 31 | #[allow(dead_code)] 32 | pub struct DefaultConversation { 33 | name: String, 34 | system_message: String, 35 | system_template: String, 36 | messages: Vec, 37 | offset: usize, 38 | sep_style: SeparatorStyle, 39 | stop_criteria: String, 40 | stop_token_ids: Vec, 41 | roles: (String, String), 42 | sep: String, 43 | sep2: Option, 44 | } 45 | 46 | /// Default conversion separators 47 | pub struct DefaultConversationSeparators { 48 | pub sep: String, 49 | pub sep2: Option, 50 | } 51 | 52 | /// A message in a conversation 53 | pub struct Message((String, Option)); 54 | 55 | impl Message { 56 | pub fn new(message: (String, String)) -> Message { 57 | Message((message.0, Some(message.1))) 58 | } 59 | } 60 | 61 | impl DefaultConversation { 62 | #[allow(clippy::too_many_arguments)] 63 | pub fn new( 64 | name: String, 65 | system_template: String, 66 | messages: Vec, 67 | offset: usize, 68 | sep_style: SeparatorStyle, 69 | stop_criteria: String, 70 | stop_token_ids: Vec, 71 | roles: (String, String), 72 | seps: DefaultConversationSeparators, 73 | ) -> Self { 74 | Self { 75 | name, 76 | system_message: "".to_string(), 77 | system_template, 78 | messages, 79 | offset, 80 | sep_style, 81 | stop_criteria, 82 | stop_token_ids, 83 | roles, 84 | sep: seps.sep, 85 | sep2: seps.sep2, 86 | } 87 | } 88 | } 89 | 90 | impl Conversation for DefaultConversation { 91 | /// Set the system message. 92 | fn set_system_message(&mut self, system_message: String) { 93 | self.system_message = system_message; 94 | } 95 | 96 | /// Append a new message. 
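    /// A usage sketch (editor addition; the bindings are illustrative, not from the
    /// original docs):
    ///
    /// ```ignore
    /// let (user_role, assistant_role) = conversation.get_roles().clone();
    /// conversation.append_message(user_role, "Explain borrowing in Rust.".to_string());
    /// conversation.append_none_message(assistant_role); // prompt ends at the assistant turn
    /// let prompt = conversation.get_prompt();
    /// ```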
97 | fn append_message(&mut self, role: String, message: String) { 98 | self.messages.push(Message((role, Some(message)))); 99 | } 100 | 101 | /// Append a new `None` message. 102 | fn append_none_message(&mut self, role: String) { 103 | self.messages.push(Message((role, None))); 104 | } 105 | 106 | /// Set the last message to `None`. 107 | fn update_last_message(&mut self) { 108 | self.messages.last_mut().unwrap().0 .1 = None; 109 | } 110 | 111 | fn get_roles(&self) -> &(String, String) { 112 | &self.roles 113 | } 114 | 115 | /// Convert this conversation to a String prompt 116 | fn get_prompt(&mut self) -> String { 117 | let system_prompt = self.system_template.format(&[self.system_message.clone()]); 118 | match self.sep_style { 119 | SeparatorStyle::AddColonSingle => { 120 | let mut accum = system_prompt + &self.sep; 121 | for message in &self.messages { 122 | let Message((role, message)) = message; 123 | if let Some(message) = message { 124 | accum += &format!("{role}: {message}{}", self.sep); 125 | } else { 126 | accum += &format!("{role}:"); 127 | } 128 | } 129 | accum 130 | } 131 | 132 | SeparatorStyle::AddColonTwo => { 133 | let seps = [&self.sep, &self.sep2.clone().unwrap_or("".to_string())]; 134 | let mut accum = system_prompt + &self.sep; 135 | for (i, message) in self.messages.iter().enumerate() { 136 | let Message((role, message)) = message; 137 | if let Some(message) = message { 138 | accum += &format!("{role}: {message}{}", seps[i % 2]); 139 | } else { 140 | accum += &format!("{role}:"); 141 | } 142 | } 143 | accum 144 | } 145 | 146 | SeparatorStyle::AddColonSpaceSingle => { 147 | let mut accum = system_prompt + &self.sep; 148 | for message in &self.messages { 149 | let Message((role, message)) = message; 150 | if let Some(message) = message { 151 | accum += &format!("{role}: {message}{}", self.sep); 152 | } else { 153 | accum += &format!("{role}: "); //must end with space 154 | } 155 | } 156 | accum 157 | } 158 | 159 | SeparatorStyle::AddNewLineSingle => { 160 | let mut accum = if system_prompt.is_empty() { 161 | "".to_string() 162 | } else { 163 | system_prompt.clone() + &self.sep 164 | }; 165 | for message in &self.messages { 166 | let Message((role, message)) = message; 167 | if let Some(message) = message { 168 | accum += &format!("{role}\n{message}{}", self.sep); 169 | } else { 170 | accum += &format!("{role}\n"); 171 | } 172 | } 173 | accum 174 | } 175 | 176 | SeparatorStyle::NoColonSingle => { 177 | let mut accum = system_prompt.clone(); 178 | for message in &self.messages { 179 | let Message((role, message)) = message; 180 | if let Some(message) = message { 181 | accum += &format!("{role}{message}{}", self.sep); 182 | } else { 183 | accum += role; 184 | } 185 | } 186 | accum 187 | } 188 | 189 | SeparatorStyle::NoColonTwo => { 190 | let seps = [&self.sep, &self.sep2.clone().unwrap_or("".to_string())]; 191 | let mut accum = system_prompt.clone(); 192 | for (i, message) in self.messages.iter().enumerate() { 193 | let Message((role, message)) = message; 194 | if let Some(message) = message { 195 | accum += &format!("{role}{message}{}", seps[i % 2]); 196 | } else { 197 | accum += role; 198 | } 199 | } 200 | accum 201 | } 202 | 203 | SeparatorStyle::RWKV => { 204 | let mut accum = system_prompt.clone() + &self.sep; 205 | for message in &self.messages { 206 | let Message((role, message)) = message; 207 | if let Some(message) = message { 208 | accum += &format!( 209 | "{role}: {}\n\n", 210 | message.replace("\r\n", "\n").replace("\n\n", "\n") 211 | ); 212 | } else { 213 | 
accum += &format!("{role}:"); 214 | } 215 | } 216 | accum 217 | } 218 | 219 | SeparatorStyle::Llama2 => { 220 | let seps = [&self.sep, &self.sep2.clone().unwrap_or("".to_string())]; 221 | let mut accum = if !system_prompt.is_empty() { 222 | system_prompt.clone() 223 | } else { 224 | "[INST] ".to_string() 225 | }; 226 | for (i, message) in self.messages.iter().enumerate() { 227 | let Message((_role, message)) = message; 228 | 229 | let tag = &[self.roles.0.clone(), self.roles.1.clone()][i % 2]; 230 | 231 | if let Some(message) = message { 232 | if i == 0 { 233 | accum += &format!("{message} "); 234 | } else { 235 | accum += &format!("{tag} {message}{}", seps[i % 2]); 236 | } 237 | } else { 238 | accum += tag; 239 | } 240 | } 241 | accum 242 | } 243 | 244 | SeparatorStyle::ChatGLM => { 245 | let round_add_n = if self.name == "chatglm2" { 1 } else { 0 }; 246 | 247 | let mut accum = if !system_prompt.is_empty() { 248 | system_prompt.clone() 249 | } else { 250 | "".to_string() 251 | }; 252 | 253 | for (i, message) in self.messages.iter().enumerate() { 254 | if i % 2 == 0 { 255 | accum += &format!("[Round {}]{}", i / 2 + round_add_n, self.sep); 256 | } 257 | let Message((role, message)) = message; 258 | if let Some(message) = message { 259 | accum += &format!("{role}: {message}{}", self.sep); 260 | } else { 261 | accum += &format!("{role}: "); 262 | } 263 | } 264 | accum 265 | } 266 | 267 | SeparatorStyle::ChatML => { 268 | let mut accum = if !system_prompt.is_empty() { 269 | format!("{}{}\n", system_prompt, self.sep) 270 | } else { 271 | "".to_string() 272 | }; 273 | for message in &self.messages { 274 | let Message((role, message)) = message; 275 | if let Some(message) = message { 276 | accum += &format!("{role}\n{message}{}\n", self.sep); 277 | } else { 278 | accum += &format!("{role}\n"); 279 | } 280 | } 281 | accum 282 | } 283 | 284 | SeparatorStyle::ChatIntern => { 285 | let seps = [&self.sep, &self.sep2.clone().unwrap_or("".to_string())]; 286 | let mut accum = system_prompt.clone(); 287 | for (i, message) in self.messages.iter().enumerate() { 288 | if i % 2 == 0 { 289 | accum += ""; 290 | } 291 | 292 | let Message((role, message)) = message; 293 | 294 | if let Some(message) = message { 295 | accum += &format!("{role}:{message}{}\n", seps[i % 2]); 296 | } else { 297 | accum += &format!("{role}:"); 298 | } 299 | } 300 | accum 301 | } 302 | 303 | SeparatorStyle::Dolly => { 304 | let seps = [&self.sep, &self.sep2.clone().unwrap_or("".to_string())]; 305 | let mut accum = system_prompt.clone(); 306 | for (i, message) in self.messages.iter().enumerate() { 307 | let Message((role, message)) = message; 308 | 309 | if let Some(message) = message { 310 | accum += &format!("{role}:\n{message}{}", seps[i % 2]); 311 | if i % 2 == 1 { 312 | accum += "\n\n"; 313 | } 314 | } else { 315 | accum += &format!("{role}:\n"); 316 | } 317 | } 318 | accum 319 | } 320 | 321 | SeparatorStyle::Phoenix => { 322 | let mut accum = system_prompt.clone() + &self.sep; 323 | for message in &self.messages { 324 | let Message((role, message)) = message; 325 | if let Some(message) = message { 326 | accum += &format!("{role}: {message}"); 327 | } else { 328 | accum += &format!("{role}: "); 329 | } 330 | } 331 | accum 332 | } 333 | 334 | SeparatorStyle::Robin => { 335 | let mut accum = system_prompt.clone() + &self.sep; 336 | for message in &self.messages { 337 | let Message((role, message)) = message; 338 | if let Some(message) = message { 339 | accum += &format!("{role}:\n{message}{}", self.sep); 340 | } else { 341 | accum += 
&format!("{role}:\n"); 342 | } 343 | } 344 | accum 345 | } 346 | 347 | SeparatorStyle::FalconChat => { 348 | let mut accum = "".to_string(); 349 | if !system_prompt.is_empty() { 350 | accum += &format!("{}{}", system_prompt, self.sep) 351 | } 352 | for message in &self.messages { 353 | let Message((role, message)) = message; 354 | if let Some(message) = message { 355 | accum += &format!("{role}: {message}{}", self.sep); 356 | } else { 357 | accum += &format!("{role}:"); 358 | } 359 | } 360 | accum 361 | } 362 | } 363 | } 364 | } 365 | -------------------------------------------------------------------------------- /src/scheduler/mod.rs: -------------------------------------------------------------------------------- 1 | //! The Scheduler uses a BlockEngine to schedule and automatically batch sequences. The 2 | //! primary method `schedule` returns the batched sequences as inputs, as well as the 3 | //! operations to be executed on the cache by the CacheEngine. 4 | 5 | /// The higher-level manager of the blocks allocated. Operations performed by the block engine do 6 | /// not directly change memory. 7 | pub mod block_engine; 8 | /// This is the lower-level manager of the cache. It manages swapping and copying the blocks and 9 | /// actually allocates the KV cache for the CPU and GPU. It is used by the LLMEngine to execute 10 | /// operations issued by the scheduler. 11 | pub mod cache_engine; 12 | pub mod sequence; 13 | 14 | type CPUBlockFrom = usize; 15 | type GPUBlockFrom = usize; 16 | type CPUBlockTo = usize; 17 | type GPUBlockTo = usize; 18 | type SrcBlockFrom = usize; 19 | type DstBlocksTo = Vec; 20 | 21 | use std::{ 22 | collections::{HashMap, VecDeque}, 23 | sync::Arc, 24 | }; 25 | 26 | use crate::{ 27 | log_warning, 28 | scheduler::{block_engine::AllocStatus, sequence::SequenceStatus}, 29 | }; 30 | 31 | use self::{block_engine::BlockEngine, cache_engine::CacheConfig, sequence::SequenceGroup}; 32 | 33 | pub struct SchedulerOutput { 34 | pub scheduled: Arc>>, 35 | pub blocks_to_swap_in: HashMap, 36 | pub blocks_to_swap_out: HashMap, 37 | pub blocks_to_copy: HashMap, 38 | pub ignored_seq_groups: Arc>>, 39 | } 40 | 41 | pub struct SchedulerConfig { 42 | pub max_num_seqs: usize, 43 | } 44 | 45 | pub struct Scheduler { 46 | waiting: VecDeque>, 47 | running: VecDeque>, 48 | swapped_out: VecDeque>, 49 | config: SchedulerConfig, 50 | pub block_engine: BlockEngine, 51 | } 52 | 53 | impl Scheduler { 54 | pub fn new(config: SchedulerConfig, cache_config: &CacheConfig) -> Self { 55 | assert!(cache_config.fully_init); 56 | Self { 57 | waiting: VecDeque::new(), 58 | running: VecDeque::new(), 59 | swapped_out: VecDeque::new(), 60 | config, 61 | block_engine: BlockEngine::new( 62 | cache_config.block_size, 63 | cache_config.num_gpu_blocks.unwrap(), 64 | cache_config.num_cpu_blocks.unwrap(), 65 | ), 66 | } 67 | } 68 | 69 | pub fn add_sequence(&mut self, seq_group: SequenceGroup) { 70 | self.waiting.push_back(Arc::new(seq_group)); 71 | } 72 | 73 | pub fn schedule(&mut self) -> SchedulerOutput { 74 | // If there are no swapped seqs (they have higher priority), add seqs that are in the 75 | // waiting queue to the running queue. 76 | if self.swapped_out.is_empty() { 77 | let mut scheduled = VecDeque::new(); 78 | let mut ignored_seq_groups = VecDeque::new(); 79 | while !self.waiting.is_empty() { 80 | let seq_group = self.waiting.front().unwrap().clone(); 81 | 82 | // If adding this seq means we will have too many, stop as no more could be added. 
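                // Note that the incoming group is counted as a single sequence here (the
                // `+ 1` below), while the already-running groups contribute all of their
                // sequences to the total.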
83 | if self.config.max_num_seqs 84 | == self 85 | .running 86 | .iter() 87 | .map(|group| group.get_seqs().len()) 88 | .sum::() 89 | + 1 90 | { 91 | break; 92 | } 93 | 94 | // If we cannot allocate either now or in the future, either do not continue or remove the sequence. 95 | let can_allocate = self.block_engine.can_allocate(&seq_group); 96 | match can_allocate { 97 | AllocStatus::Later => break, //If we can only allocate later, do not bother iterating over the rest. 98 | AllocStatus::Impossible => { 99 | log_warning( 100 | &format!("Input prompt with length of {} tokens is too long and exceeds capacity of block engine.", 101 | seq_group.get_prompt_len()) 102 | ); 103 | seq_group.set_status(SequenceStatus::FinishedIgnored); 104 | ignored_seq_groups.push_back(self.waiting.pop_front().unwrap()); 105 | } 106 | _ => {} 107 | } 108 | 109 | seq_group.set_status(SequenceStatus::Running); 110 | self._allocate(&seq_group); 111 | 112 | let seq_group = self.waiting.pop_front().unwrap(); 113 | self.running.push_back(seq_group.clone()); 114 | scheduled.push_back(seq_group); 115 | } 116 | 117 | // If we did schedule, or we ignored sequences. 118 | if !scheduled.is_empty() || !ignored_seq_groups.is_empty() { 119 | return SchedulerOutput { 120 | scheduled: Arc::new(scheduled), 121 | blocks_to_swap_in: HashMap::new(), 122 | blocks_to_copy: HashMap::new(), 123 | blocks_to_swap_out: HashMap::new(), 124 | ignored_seq_groups: Arc::new(ignored_seq_groups), 125 | }; 126 | } 127 | } 128 | 129 | let mut blocks_to_swap_out = HashMap::new(); 130 | let mut blocks_to_swap_in = HashMap::new(); 131 | let mut blocks_to_copy = HashMap::new(); 132 | 133 | // Reserve token slots for the running sequence groups, preempting the lowest (earliest) first. 134 | // Preempt lowest priority sequences that are in the running queue, forming a 135 | // new running queue that has the actually running sequences. Remember the preempted 136 | // sequences, which will be put into the waiting or swapped out state depending on 137 | // the preemption method (recompute or swap, respectively). 138 | 139 | // Sorts by creation time, in descending order so that earliest are latest (first come first serve). 140 | self.sort_running_by_priority_fcfs(); 141 | 142 | let mut running = VecDeque::new(); 143 | let mut preempted = VecDeque::new(); 144 | while !self.running.is_empty() { 145 | let seq_group = self.running.pop_front().unwrap(); 146 | let mut finished_with_break = false; 147 | while !self.block_engine.can_append_token_to_seq(&seq_group) { 148 | // If we cannot, now we need to preempt some seqs 149 | if !self.running.is_empty() { 150 | // There is something to preempt. 151 | let seq_to_preempt = self.running.pop_back().unwrap(); 152 | self._preempt(seq_to_preempt.clone(), &mut blocks_to_swap_out); 153 | preempted.push_back(seq_to_preempt); 154 | } else { 155 | // Nothing to preempt, preempt ourselves. Also, do not bother looking at anything else. 156 | self._preempt(seq_group.clone(), &mut blocks_to_swap_out); 157 | preempted.push_back(seq_group.clone()); 158 | finished_with_break = true; 159 | break; 160 | } 161 | } 162 | if !finished_with_break { 163 | // If we need to, append physical blocks for a new token. We do not need to if there is enough space. 
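                // `_append_token_slot_to_seq_group` may also record a copy-on-write pair
                // in `blocks_to_copy` when the last physical block is shared with another
                // sequence.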
164 | // If we just got preempted, there is no reason to allocate 165 | self._append_token_slot_to_seq_group(&seq_group, &mut blocks_to_copy); 166 | running.push_back(seq_group); 167 | } 168 | } 169 | self.running = running; 170 | 171 | // Try to swap in the swapped out sequences and add these to the 172 | // running state if possible. 173 | 174 | // Sorts by creation time, in descending order so that earliest are latest (first come first serve). 175 | self.sort_swapped_out_by_priority_fcfs(); 176 | 177 | if preempted.is_empty() { 178 | while !self.swapped_out.is_empty() { 179 | let seq_group = self.swapped_out.front().unwrap(); 180 | 181 | // If the GPU cannot handle the group being swapped in, stop 182 | if !self.block_engine.can_swap_in_seq_group(seq_group) { 183 | break; 184 | } 185 | 186 | let seq_group = self.swapped_out.pop_front().unwrap(); 187 | // Swap in the blocks 188 | let to_swap_in = self.block_engine.swap_in(&seq_group); 189 | blocks_to_swap_in.extend(to_swap_in); 190 | // Reserve a new slot 191 | self._append_token_slot_to_seq_group(&seq_group, &mut blocks_to_copy); 192 | self.running.push_back(seq_group); 193 | } 194 | } 195 | 196 | SchedulerOutput { 197 | scheduled: self.running.clone().into(), 198 | blocks_to_swap_in, 199 | blocks_to_copy, 200 | blocks_to_swap_out, 201 | ignored_seq_groups: Arc::new(VecDeque::new()), 202 | } 203 | } 204 | 205 | pub fn has_unfinished_sequences(&self) -> bool { 206 | !self.running.is_empty() 207 | } 208 | 209 | pub fn free_finished_sequence_groups(&mut self) { 210 | let mut to_free = Vec::new(); 211 | let clone = self.running.clone(); 212 | self.running = clone 213 | .iter() 214 | .filter(|group| { 215 | if group.is_finished() { 216 | to_free.push((*group).clone()); 217 | false 218 | } else { 219 | true 220 | } 221 | }) 222 | .cloned() 223 | .collect::>(); 224 | for group in to_free { 225 | self._free(&group); 226 | } 227 | } 228 | } 229 | 230 | impl Scheduler { 231 | fn remove_seq_group(&mut self, seq_group: &SequenceGroup) { 232 | // Remove it if it is in waiting 233 | if let Some(idx) = self 234 | .waiting 235 | .iter() 236 | .position(|grp| grp.get_id() == seq_group.get_id()) 237 | { 238 | self.waiting.remove(idx); 239 | }; 240 | // Remove it if it is in running 241 | if let Some(idx) = self 242 | .running 243 | .iter() 244 | .position(|grp| grp.get_id() == seq_group.get_id()) 245 | { 246 | self.running.remove(idx); 247 | }; 248 | // Remove it if it is in swapped out 249 | if let Some(idx) = self 250 | .swapped_out 251 | .iter() 252 | .position(|grp| grp.get_id() == seq_group.get_id()) 253 | { 254 | self.swapped_out.remove(idx); 255 | }; 256 | } 257 | fn _append_token_slot_to_seq_group( 258 | &mut self, 259 | seq_group: &SequenceGroup, 260 | blocks_to_copy: &mut HashMap>, 261 | ) { 262 | for seq in seq_group.get_seqs().values() { 263 | let op = self.block_engine.append_token_slot_to_seq(seq); 264 | if let Some((src_block, dst_block)) = op { 265 | if let std::collections::hash_map::Entry::Vacant(e) = 266 | blocks_to_copy.entry(src_block) 267 | { 268 | e.insert(vec![dst_block]); 269 | } else { 270 | blocks_to_copy.get_mut(&src_block).unwrap().push(dst_block); 271 | } 272 | } 273 | } 274 | } 275 | 276 | fn _abort_seq_group(&mut self, seq_group: &SequenceGroup) { 277 | self.remove_seq_group(seq_group); 278 | seq_group.set_status(SequenceStatus::FinishedAborted); 279 | self._free(seq_group); 280 | } 281 | 282 | /// Preempt either by recomputation (for single sequence), or by swapping (for multiple). 
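    /// Recomputation frees the group's blocks and pushes it back onto the front of the
    /// waiting queue so its prompt is processed again later; swapping moves the group's
    /// GPU blocks out to CPU memory and records the mapping in `blocks_to_swap_out` (the
    /// group is aborted instead if it cannot be swapped out).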
283 |     fn _preempt(
284 |         &mut self,
285 |         seq_group: Arc<SequenceGroup>,
286 |         blocks_to_swap_out: &mut HashMap<usize, usize>,
287 |     ) {
288 |         match seq_group.get_seqs().len() {
289 |             1 => self._preempt_by_recompute(seq_group),
290 |             _ => self._preempt_by_swap(seq_group, blocks_to_swap_out),
291 |         }
292 |     }
293 | 
294 |     fn _preempt_by_recompute(&mut self, seq_group: Arc<SequenceGroup>) {
295 |         seq_group.set_status(SequenceStatus::Waiting);
296 |         self._free(&seq_group);
297 |         self.waiting.push_front(seq_group);
298 |     }
299 | 
300 |     fn _preempt_by_swap(
301 |         &mut self,
302 |         seq_group: Arc<SequenceGroup>,
303 |         blocks_to_swap_out: &mut HashMap<usize, usize>,
304 |     ) {
305 |         if !self.block_engine.can_swap_out_seq_group(&seq_group) {
306 |             // If we cannot swap it out, abort the sequence group.
307 |             self._abort_seq_group(&seq_group);
308 |             return;
309 |         }
310 |         let new_to_swap = self.block_engine.swap_out(&seq_group);
311 |         blocks_to_swap_out.extend(new_to_swap);
312 |         seq_group.set_status(SequenceStatus::Swapped);
313 | 
314 |         self.swapped_out.push_back(seq_group);
315 |     }
316 | 
317 |     fn _allocate(&mut self, seq_group: &SequenceGroup) {
318 |         self.block_engine.allocate(seq_group)
319 |     }
320 | 
321 |     fn _free(&mut self, seq_group: &SequenceGroup) {
322 |         for seq in seq_group.get_seqs().values() {
323 |             self.block_engine.free_sequence(seq);
324 |         }
325 |     }
326 | 
327 |     fn sort_running_by_priority_fcfs(&mut self) {
328 |         self.running
329 |             .make_contiguous()
330 |             .sort_by_key(|seq_group| seq_group.arrival_time());
331 |         self.running.make_contiguous().reverse();
332 |     }
333 | 
334 |     fn sort_swapped_out_by_priority_fcfs(&mut self) {
335 |         self.swapped_out
336 |             .make_contiguous()
337 |             .sort_by_key(|seq_group| seq_group.arrival_time());
338 |         self.swapped_out.make_contiguous().reverse();
339 |     }
340 | }
341 | -------------------------------------------------------------------------------- /src/backend/cache.rs: --------------------------------------------------------------------------------
1 | use std::{collections::HashMap, iter::zip, ptr::NonNull};
2 | 
3 | use candle_core::{
4 |     cuda_backend::cudarc::driver::{CudaSlice, DevicePtr, LaunchAsync, LaunchConfig},
5 |     DType, Device, IndexOp, Storage, Tensor,
6 | };
7 | 
8 | use crate::{
9 |     backend::{
10 |         dispatch_get_cuda_pointer, get_or_load_func, Conjoined, COPY_BLOCKS_KERNEL, COPY_BLOCKS_PTX,
11 |     },
12 |     openai::responses::APIError,
13 |     try_api,
14 | };
15 | 
16 | use super::{RESHAPE_AND_CACHE_KERNEL, RESHAPE_AND_CACHE_PTX};
17 | 
18 | /// # Safety
19 | /// Unsafe due to passing pointers
20 | pub unsafe fn reshape_and_cache(
21 |     key: Tensor,              // [num_tokens, num_heads, head_size]
22 |     value: Tensor,            // [num_tokens, num_heads, head_size]
23 |     key_cache: &mut Tensor,   // [num_blocks, num_heads, head_size/x, block_size, x]
24 |     value_cache: &mut Tensor, // [num_blocks, num_heads, head_size, block_size]
25 |     slot_mapping: Tensor,     // [num_tokens]
26 | ) -> Result<(), APIError> {
27 |     let cache_dev = key.device();
28 |     let Device::Cuda(dev) = cache_dev else {
29 |         panic!("Expected the key to be on a CUDA device.")
30 |     };
31 | 
32 |     if slot_mapping.dtype() != DType::I64 {
33 |         return Err(APIError::new(format!(
34 |             "`slot_mapping` has {:?} type, expected I64 type.",
35 |             slot_mapping.dtype()
36 |         )));
37 |     }
38 | 
39 |     if key.dtype() != value.dtype() {
40 |         return Err(APIError::new(format!(
41 |             "`key` and `value` have different data types, got {:?} and {:?} respectively.",
42 |             key.dtype(),
43 |             value.dtype()
44 |         )));
45 |     }
46 | 
47 |     if key.dtype() != key_cache.dtype() {
48 |         return Err(APIError::new(format!(
49 |             "`key` and
`key_cache` have different data types, got {:?} and {:?} respectively.", 50 | key.dtype(), 51 | key_cache.dtype() 52 | ))); 53 | } 54 | 55 | if key.dtype() != value_cache.dtype() { 56 | return Err(APIError::new(format!( 57 | "`key` and `value_cache` have different data types, got {:?} and {:?} respectively.", 58 | key.dtype(), 59 | value_cache.dtype() 60 | ))); 61 | } 62 | 63 | if !key.device().is_cuda() { 64 | return Err(APIError::new(format!( 65 | "`key` must be on a CUDA device, got {:?}.", 66 | key.device() 67 | ))); 68 | } 69 | 70 | if !key.device().same_device(value.device()) { 71 | return Err(APIError::new(format!( 72 | "`key` and `value` have different devices, got {:?} and {:?} respectively.", 73 | key.device(), 74 | value.device() 75 | ))); 76 | } 77 | 78 | if !key.device().same_device(key_cache.device()) { 79 | return Err(APIError::new(format!( 80 | "`key` and `key_cache` have different devices, got {:?} and {:?} respectively.", 81 | key.device(), 82 | key_cache.device() 83 | ))); 84 | } 85 | 86 | if !key.device().same_device(value_cache.device()) { 87 | return Err(APIError::new(format!( 88 | "`key` and `value_cache` have different devices, got {:?} and {:?} respectively.", 89 | key.device(), 90 | value_cache.device() 91 | ))); 92 | } 93 | 94 | let num_tokens = key.dims()[0]; 95 | let num_heads = key.dims()[1]; 96 | let head_size = key.dims()[2]; 97 | let block_size = key_cache.dims()[3]; 98 | let x = key_cache.dims()[4]; 99 | 100 | let key_stride = key.stride()[0]; 101 | let value_stride = value.stride()[0]; 102 | 103 | let stream = try_api!(dev.fork_default_stream()); 104 | 105 | let launch_conf = LaunchConfig { 106 | grid_dim: (num_tokens.try_into().unwrap(), 1u32, 1u32), 107 | block_dim: ( 108 | 512.min((num_heads * head_size).try_into().unwrap()), 109 | 1u32, 110 | 1u32, 111 | ), 112 | shared_mem_bytes: 0, 113 | }; 114 | 115 | let kernel = try_api!(get_or_load_func( 116 | RESHAPE_AND_CACHE_PTX, 117 | RESHAPE_AND_CACHE_KERNEL, 118 | key.dtype(), 119 | None, 120 | dev 121 | )); 122 | 123 | let key_ptr = dispatch_get_cuda_pointer(key); 124 | let value_ptr = dispatch_get_cuda_pointer(value); 125 | let key_cache_ptr = dispatch_get_cuda_pointer(key_cache.clone()); 126 | let value_cache_ptr = dispatch_get_cuda_pointer(value_cache.clone()); 127 | 128 | try_api!(unsafe { 129 | kernel.launch_on_stream( 130 | &stream, 131 | launch_conf, 132 | ( 133 | key_ptr, 134 | value_ptr, 135 | key_cache_ptr, 136 | value_cache_ptr, 137 | key_stride, 138 | value_stride, 139 | num_heads, 140 | head_size, 141 | block_size, 142 | x, 143 | ), 144 | ) 145 | }); 146 | 147 | Ok(()) 148 | } 149 | 150 | /// # Safety 151 | /// Unsafe due to passing pointers 152 | pub unsafe fn copy_blocks( 153 | key_caches: Vec<&mut Tensor>, 154 | value_caches: Vec<&mut Tensor>, 155 | block_mapping: HashMap>, 156 | ) -> Result<(), APIError> { 157 | let cache_dev = key_caches.first().unwrap().device(); 158 | let Device::Cuda(dev) = cache_dev else { 159 | panic!("Expected the key caches to be on a CUDA device.") 160 | }; 161 | if !cache_dev.same_device(value_caches.first().unwrap().device()) { 162 | return Err(APIError::new(format!( 163 | "`key` and `value` caches have different devices, got {:?} and {:?} respectively.", 164 | cache_dev, 165 | value_caches.first().unwrap().device() 166 | ))); 167 | } 168 | if key_caches.first().unwrap().dtype() != value_caches.first().unwrap().dtype() { 169 | return Err(APIError::new(format!( 170 | "Key and value caches have different types, got {:?} and {:?}.", 171 | 
key_caches.first().unwrap().dtype(), 172 | value_caches.first().unwrap().dtype() 173 | ))); 174 | } 175 | let num_layers: u32 = key_caches.len().try_into().unwrap(); 176 | if num_layers == 0 { 177 | return Ok(()); 178 | } 179 | 180 | let mut key_cache_ptrs = Vec::new(); 181 | key_cache_ptrs.reserve_exact(num_layers as usize); 182 | let mut value_cache_ptrs = Vec::new(); 183 | value_cache_ptrs.reserve_exact(num_layers as usize); 184 | for (key_cache, value_cache) in zip(&key_caches, &value_caches) { 185 | try_api!(key_cache.to_device(cache_dev)); 186 | try_api!(value_cache.to_device(cache_dev)); 187 | 188 | let key_offset: u64 = key_cache 189 | .storage_and_layout() 190 | .1 191 | .start_offset() 192 | .try_into() 193 | .unwrap(); 194 | let Storage::Cuda(key_storage) = &*key_cache.storage_and_layout().0 else { 195 | unreachable!() 196 | }; 197 | let key_ptr = *try_api!(key_storage.as_cuda_slice::()).device_ptr(); 198 | key_cache_ptrs.push(key_ptr + key_offset); 199 | 200 | let value_offset: u64 = value_cache 201 | .storage_and_layout() 202 | .1 203 | .start_offset() 204 | .try_into() 205 | .unwrap(); 206 | let Storage::Cuda(value_storage) = &*value_cache.storage_and_layout().0 else { 207 | unreachable!() 208 | }; 209 | let value_ptr = *try_api!(value_storage.as_cuda_slice::()).device_ptr(); 210 | value_cache_ptrs.push(value_ptr + value_offset); 211 | } 212 | 213 | let mut block_mapping_vec: Vec = Vec::new(); 214 | for (src_block_number, dst_blocks) in block_mapping { 215 | for dst_block_number in dst_blocks { 216 | block_mapping_vec.push(src_block_number.try_into().unwrap()); 217 | block_mapping_vec.push(dst_block_number.try_into().unwrap()); 218 | } 219 | } 220 | let num_pairs: u32 = (block_mapping_vec.len() / 2).try_into().unwrap(); 221 | let block_mapping_ptr = Conjoined::new( 222 | NonNull::new(block_mapping_vec.as_mut_ptr()).unwrap(), 223 | &mut block_mapping_vec, 224 | ); 225 | 226 | let key_cache_ptr = Conjoined::new( 227 | NonNull::new(key_cache_ptrs.as_mut_ptr()).unwrap(), 228 | &mut key_cache_ptrs, 229 | ); 230 | let value_cache_ptr = Conjoined::new( 231 | NonNull::new(value_cache_ptrs.as_mut_ptr()).unwrap(), 232 | &mut value_cache_ptrs, 233 | ); 234 | 235 | let numel_per_block: u32 = try_api!(key_caches.first().unwrap().i(0)) 236 | .shape() 237 | .dims() 238 | .iter() 239 | .product::() 240 | .try_into() 241 | .unwrap(); 242 | let launch_conf = LaunchConfig { 243 | grid_dim: (num_layers, num_pairs, 1u32), 244 | block_dim: (numel_per_block.min(1024), 1u32, 1u32), 245 | shared_mem_bytes: 0, 246 | }; 247 | let stream = try_api!(dev.fork_default_stream()); 248 | 249 | let kernel = try_api!(get_or_load_func( 250 | COPY_BLOCKS_PTX, 251 | COPY_BLOCKS_KERNEL, 252 | key_caches.first().unwrap().dtype(), 253 | None, 254 | dev, 255 | )); 256 | 257 | try_api!(unsafe { 258 | kernel.launch_on_stream( 259 | &stream, 260 | launch_conf, 261 | (key_cache_ptr, value_cache_ptr, block_mapping_ptr), 262 | ) 263 | }); 264 | 265 | Ok(()) 266 | } 267 | 268 | pub fn swap_blocks( 269 | src: Tensor, 270 | dst: &mut Tensor, 271 | block_mapping: HashMap, 272 | ) -> Result<(), APIError> { 273 | let block_size_in_bytes = src.dtype().size_in_bytes() * src.dims()[0]; 274 | match (src.device(), dst.device()) { 275 | (Device::Cuda(src_dev), Device::Cuda(dst_dev)) => { 276 | if src_dev.ordinal() != dst_dev.ordinal() { 277 | return Err(APIError::new(format!("Tensors must be on the same device to copy, got ordinals {} (src) and {} (dst).", src_dev.ordinal(), dst_dev.ordinal()))) 278 | } 279 | let (src_storage, 
src_layout) = src.storage_and_layout();
280 |             let (dst_storage, dst_layout) = dst.storage_and_layout();
281 |             assert!(matches!(&*src_storage, Storage::Cuda(_)));
282 |             assert!(matches!(&*dst_storage, Storage::Cuda(_)));
283 |             let Storage::Cuda(src_storage) = &*src_storage else { unreachable!() };
284 |             let Storage::Cuda(dst_storage) = &*dst_storage else { unreachable!() };
285 |             let src_ptr = src_storage.as_cuda_slice::<u8>().map_err(APIError::from)?.device_ptr() + TryInto::<u64>::try_into(src_layout.start_offset()).unwrap();
286 |             let dst_ptr = dst_storage.as_cuda_slice::<u8>().map_err(APIError::from)?.device_ptr() + TryInto::<u64>::try_into(dst_layout.start_offset()).unwrap();
287 | 
288 |             for (src_block_number, dst_block_number) in block_mapping {
289 |                 let src_offset: u64 = (src_block_number * block_size_in_bytes).try_into().unwrap();
290 |                 let dst_offset: u64 = (dst_block_number * block_size_in_bytes).try_into().unwrap();
291 |                 // u8s because we copy by bytes
292 |                 let src_slice: CudaSlice<u8> = unsafe { src_dev.upgrade_device_ptr(src_ptr+src_offset, block_size_in_bytes) };
293 |                 let mut dst_slice = unsafe { dst_dev.upgrade_device_ptr(dst_ptr+dst_offset, block_size_in_bytes) };
294 | 
295 |                 try_api!(src_dev.dtod_copy(&src_slice, &mut dst_slice));
296 |             }
297 |         }
298 |         (Device::Cpu, Device::Cuda(dst_dev)) => {
299 |             let (src_storage, _src_layout) = src.storage_and_layout();
300 |             let (dst_storage, dst_layout) = dst.storage_and_layout();
301 |             assert!(matches!(&*src_storage, Storage::Cpu(_)));
302 |             assert!(matches!(&*dst_storage, Storage::Cuda(_)));
303 |             let Storage::Cpu(src_storage) = &*src_storage else { unreachable!() };
304 |             let Storage::Cuda(dst_storage) = &*dst_storage else { unreachable!() };
305 |             let dst_ptr = dst_storage.as_cuda_slice::<u8>().map_err(APIError::from)?.device_ptr() + TryInto::<u64>::try_into(dst_layout.start_offset()).unwrap();
306 |             let src_slice = try_api!(src_storage.as_slice());
307 | 
308 |             for (src_block_number, dst_block_number) in block_mapping {
309 |                 let src_offset = src_block_number * block_size_in_bytes;
310 |                 let dst_offset: u64 = (dst_block_number * block_size_in_bytes).try_into().unwrap();
311 |                 // u8s because we copy by bytes
312 |                 let mut dst_slice: CudaSlice<u8> = unsafe { dst_dev.upgrade_device_ptr(dst_ptr+dst_offset, block_size_in_bytes) };
313 | 
314 |                 try_api!(dst_dev.htod_sync_copy_into(&src_slice[src_offset..src_offset+block_size_in_bytes], &mut dst_slice));
315 |             }
316 |         }
317 |         (Device::Cuda(src_dev), Device::Cpu) => {
318 |             // Pending on huggingface/candle#1467
319 |             todo!();
320 |             /*let (src_storage, src_layout) = src.storage_and_layout();
321 |             let (dst_storage, dst_layout) = dst.storage_mut_and_layout();
322 |             assert!(matches!(&*src_storage, Storage::Cuda(_)));
323 |             assert!(matches!(&*dst_storage, Storage::Cpu(_)));
324 |             let Storage::Cuda(src_storage) = &*src_storage else { unreachable!() };
325 |             let Storage::Cpu(dst_storage) = &*dst_storage else { unreachable!() };
326 |             let src_ptr = src_storage.as_cuda_slice::<u8>().map_err(APIError::from)?.device_ptr() + TryInto::<u64>::try_into(src_layout.start_offset()).unwrap();
327 |             let dst_slice: &[u8] = try_api!(dst_storage.as_slice());
328 |             let ptr = dst_slice.as_ptr() as *mut u8;
329 |             // Safety:
330 |             let dst_slice = unsafe { slice::from_raw_parts_mut(ptr, dst_slice.len()) };
331 | 
332 |             for (src_block_number, dst_block_number) in block_mapping {
333 |                 let src_offset: u64 = (src_block_number * block_size_in_bytes).try_into().unwrap();
334 |                 let dst_offset: u64 = (dst_block_number * block_size_in_bytes).try_into().unwrap();
335 |                 // u8s because we copy by bytes
336 |                 let src_slice: CudaSlice<u8> = unsafe { src_dev.upgrade_device_ptr(src_ptr+src_offset, block_size_in_bytes) };
337 | 
338 |                 try_api!(src_dev.dtoh_sync_copy_into(&src_slice, dst_slice));
339 |             }*/
340 |         }
341 |         (src, dst) => {
342 |             return Err(APIError::new(format!("Tensors must be on either the GPU or CPU to swap, got {src:?} (src) and {dst:?} (dst).")))
343 |         }
344 |     }
345 | 
346 |     Ok(())
347 | }
348 | --------------------------------------------------------------------------------
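For orientation, here is a minimal, illustrative sketch of how a caller might drive `swap_blocks` above. It is not part of the repository: the import path, the tensor shapes, and the block indices are assumptions made for the example, and `U8` caches are used so the byte-wise copy path maps directly onto elements.

use std::collections::HashMap;

use candle_core::{DType, Device, Tensor};

// Assumed import path; `swap_blocks` is defined in src/backend/cache.rs.
use crate::{backend::cache::swap_blocks, openai::responses::APIError};

// Hypothetical helper: copy two CPU cache blocks into a GPU cache tensor.
// The cache shape [num_blocks, num_heads, head_size, block_size] and the
// block indices below are invented for illustration only.
fn swap_in_example() -> Result<(), APIError> {
    let gpu = Device::new_cuda(0).map_err(APIError::from)?;
    let cpu_cache = Tensor::zeros((16, 32, 128, 16), DType::U8, &Device::Cpu).map_err(APIError::from)?;
    let mut gpu_cache = Tensor::zeros((16, 32, 128, 16), DType::U8, &gpu).map_err(APIError::from)?;

    // block_mapping: source (CPU) block index -> destination (GPU) block index.
    let block_mapping: HashMap<usize, usize> = HashMap::from([(3, 0), (7, 1)]);
    swap_blocks(cpu_cache, &mut gpu_cache, block_mapping)
}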