├── python ├── vllm_rs │ └── __init__.py ├── __init__.py └── ReadMe.md ├── .gitignore ├── example ├── rust-demo │ ├── .gitignore │ ├── Cargo.toml │ └── src │ │ └── main.rs ├── rust-demo-tokenize │ ├── .gitignore │ ├── Cargo.toml │ └── src │ │ └── main.rs ├── rust-demo-tools │ ├── .gitignore │ ├── Cargo.toml │ └── src │ │ └── main.rs ├── mcp.json ├── server.py ├── tokenize.py ├── completion.py └── tool_calling.py ├── src ├── models │ ├── mod.rs │ ├── mistral3_vl │ │ ├── config.rs │ │ └── mod.rs │ ├── layers │ │ ├── mask.rs │ │ ├── mod.rs │ │ ├── deepstack.rs │ │ ├── mlp.rs │ │ └── others.rs │ ├── qwen3_vl │ │ ├── config.rs │ │ ├── input.rs │ │ └── mod.rs │ └── gemma3 │ │ └── config.rs ├── mcp │ ├── mod.rs │ ├── transport.rs │ ├── client.rs │ └── types.rs ├── lib.rs ├── utils │ ├── guidance.rs │ ├── heartbeat.rs │ ├── gguf_varbuilder.rs │ ├── chat_template.rs │ ├── command.rs │ └── progress.rs ├── server │ └── streaming.rs ├── transfer │ └── cuda_remote.rs ├── core │ ├── sequence.rs │ └── mod.rs ├── tools │ └── parser.rs └── api.rs ├── myproject.toml ├── run.sh ├── docs ├── context-cache.md ├── embeddings.md ├── multimodal.md ├── performance.md ├── tokenizer_api.md ├── rust_crate.md └── get_started.md ├── Cargo.toml └── vllm_rs.pyi /python/vllm_rs/__init__.py: -------------------------------------------------------------------------------- 1 | from . import * -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /Cargo.lock 3 | .vscode/ 4 | .idea/ 5 | AGENTS.md -------------------------------------------------------------------------------- /example/rust-demo/.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /Cargo.lock 3 | .vscode/ 4 | .idea/ 5 | AGENTS.md -------------------------------------------------------------------------------- /example/rust-demo-tokenize/.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /Cargo.lock 3 | .vscode/ 4 | .idea/ 5 | AGENTS.md -------------------------------------------------------------------------------- /example/rust-demo-tools/.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /Cargo.lock 3 | .vscode/ 4 | .idea/ 5 | AGENTS.md -------------------------------------------------------------------------------- /python/__init__.py: -------------------------------------------------------------------------------- 1 | from .vllm_rs import * 2 | 3 | __doc__ = vllm_rs.__doc__ 4 | if hasattr(vllm_rs, "__all__"): 5 | __all__ = vllm_rs.__all__ -------------------------------------------------------------------------------- /src/models/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod gemma3; 2 | pub mod glm4; 3 | pub mod layers; 4 | pub mod llama; 5 | pub mod mistral3_vl; 6 | pub mod qwen3; 7 | pub mod qwen3_moe; 8 | pub mod qwen3_vl; 9 | -------------------------------------------------------------------------------- /example/rust-demo-tools/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rust-demo-tools" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | vllm-rs = { path = "../.." 
} 8 | anyhow = "1.0" 9 | serde_json = "1.0" 10 | -------------------------------------------------------------------------------- /example/mcp.json: -------------------------------------------------------------------------------- 1 | { 2 | "mcpServers": { 3 | "filesystem": { 4 | "command": "npx", 5 | "args": [ 6 | "-y", 7 | "@modelcontextprotocol/server-filesystem", 8 | "~/" 9 | ] 10 | } 11 | } 12 | } -------------------------------------------------------------------------------- /example/rust-demo-tokenize/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rust-demo-tokenize" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | reqwest = { version = "0.12", features = ["json", "blocking"] } 8 | serde = { version = "1.0", features = ["derive"] } 9 | serde_json = "1.0" 10 | -------------------------------------------------------------------------------- /example/rust-demo/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rust-demo" 3 | version = "0.0.1" 4 | edition = "2021" 5 | default-run = "rust-demo" 6 | 7 | [dependencies] 8 | anyhow = "1.0.100" 9 | vllm_rs = { path = "../../", package = "vllm-rs" } 10 | 11 | [features] 12 | cuda = ["vllm_rs/cuda"] 13 | flash-attn = ["vllm_rs/flash-attn"] 14 | flash-context = ["vllm_rs/flash-context"] 15 | graph = ["vllm_rs/graph"] 16 | nccl = ["vllm_rs/nccl"] 17 | metal = ["vllm_rs/metal"] 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /src/mcp/mod.rs: -------------------------------------------------------------------------------- 1 | // src/mcp/mod.rs 2 | //! Model Context Protocol (MCP) support for vLLM.rs 3 | //! 4 | //! MCP enables AI applications to connect with external data sources and tools 5 | //! through a standardized protocol. 6 | 7 | pub mod client; 8 | pub mod manager; 9 | pub mod server; 10 | pub mod transport; 11 | pub mod types; 12 | 13 | pub use client::McpClient; 14 | pub use manager::{McpClientManager, McpManagerConfig, McpServerDefinition, McpToolConfig}; 15 | pub use server::McpServer; 16 | pub use types::*; 17 | -------------------------------------------------------------------------------- /myproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["maturin>=0.14,<0.15"] 3 | build-backend = "maturin" 4 | 5 | [project] 6 | name = "vllm-rs" 7 | description = "A blazing-fast ⚡, lightweight Rust 🦀 implementation of vLLM." 
8 | requires-python = ">=3.8" 9 | classifiers = [ 10 | "Programming Language :: Rust", 11 | "Programming Language :: Python :: Implementation :: CPython", 12 | "Programming Language :: Python :: Implementation :: PyPy", 13 | ] 14 | 15 | [tool.maturin] 16 | python-source = "python" 17 | include = ["vllm_rs/runner", "vllm_rs/__init__.py", "vllm_rs/__init__.pyi", "py.typed"] 18 | sdist-include = ["ReadMe.md"] -------------------------------------------------------------------------------- /example/rust-demo/src/main.rs: -------------------------------------------------------------------------------- 1 | use vllm_rs::api::{EngineBuilder, ModelRepo}; 2 | use vllm_rs::server::{ChatMessage, MessageContentType}; 3 | use vllm_rs::utils::{config::SamplingParams, log_throughput}; 4 | 5 | fn main() -> anyhow::Result<()> { 6 | let mut engine = EngineBuilder::new(ModelRepo::ModelID(("Qwen/Qwen3-0.6B", None))).build()?; 7 | 8 | let messages = vec![ChatMessage { 9 | role: "user".to_string(), 10 | content: MessageContentType::PureText("Say hello from the Rust API.".to_string()), 11 | }]; 12 | 13 | let params = SamplingParams::default(); 14 | let output = engine.generate(params, messages)?; 15 | println!("\n\n{}", output.decode_output); 16 | 17 | log_throughput(&vec![output]); 18 | } 19 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "python")] 2 | use pyo3::prelude::*; 3 | pub mod api; 4 | pub mod core; 5 | pub mod mcp; 6 | pub mod models; 7 | #[cfg(feature = "python")] 8 | pub mod py; 9 | pub mod runner; 10 | pub mod server; 11 | pub mod tools; 12 | pub mod transfer; 13 | pub mod utils; 14 | #[cfg(feature = "python")] 15 | use crate::core::GenerationOutput; 16 | #[cfg(feature = "python")] 17 | use crate::py::Engine; 18 | #[cfg(feature = "python")] 19 | use crate::transfer::{PdConfig, PdMethod, PdRole}; 20 | #[cfg(feature = "python")] 21 | use crate::utils::chat_template::Message; 22 | #[cfg(feature = "python")] 23 | use crate::utils::config::{EngineConfig, GenerationConfig, SamplingParams}; 24 | /// A Python module implemented in Rust. 
25 | #[cfg(feature = "python")] 26 | #[pymodule] 27 | fn vllm_rs(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { 28 | m.add_class::()?; 29 | m.add_class::()?; 30 | m.add_class::()?; 31 | m.add_class::()?; 32 | m.add_class::()?; 33 | m.add_class::()?; 34 | m.add_class::()?; 35 | m.add_class::()?; 36 | m.add_class::()?; 37 | Ok(()) 38 | } 39 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | 4 | # Default build options 5 | RELEASE="" 6 | PROFILE="release" 7 | FEATURES="" 8 | 9 | # Arrays to hold build and run args separately 10 | RUN_ARGS=() 11 | 12 | # Parse build arguments 13 | while [[ "$#" -gt 0 ]]; do 14 | case $1 in 15 | --release) 16 | RELEASE="--release" 17 | PROFILE="release" 18 | shift 19 | ;; 20 | --features) 21 | FEATURES="$2" 22 | shift 2 23 | ;; 24 | --) # Separator: remaining args go to runtime 25 | shift 26 | RUN_ARGS+=("$@") 27 | break 28 | ;; 29 | *) # Anything unknown is forwarded as a runtime arg 30 | RUN_ARGS+=("$1") 31 | shift 32 | ;; 33 | esac 34 | done 35 | 36 | # Echo config 37 | echo "Building with profile: $PROFILE" 38 | echo "Features: $FEATURES" 39 | echo "Runtime arguments: ${RUN_ARGS[*]}" 40 | 41 | # Step 1: Build runner binary 42 | FEATURES_RUNNER=$(echo "$FEATURES" | sed -E 's/\bpython\b//g' | xargs) 43 | echo "Building runner binary..." 44 | cargo build $RELEASE --bin runner --features "$FEATURES_RUNNER" 45 | 46 | #FEATURES=$(echo "$FEATURES" | sed -E 's/\bflash-attn\b//g' | xargs) 47 | # Step 2: Run the program with runtime args 48 | cargo run $RELEASE --features "$FEATURES" -- "${RUN_ARGS[@]}" 49 | -------------------------------------------------------------------------------- /src/models/mistral3_vl/config.rs: -------------------------------------------------------------------------------- 1 | use crate::utils::config::Config; 2 | use serde::Deserialize; 3 | 4 | pub fn default_num_channels() -> usize { 5 | 3 6 | } 7 | 8 | pub fn default_activation() -> candle_nn::Activation { 9 | candle_nn::Activation::Silu 10 | } 11 | 12 | #[derive(Deserialize, Debug, Clone)] 13 | pub struct VisionConfig { 14 | pub hidden_size: usize, 15 | #[serde(default = "default_num_channels")] 16 | pub num_channels: usize, 17 | pub image_size: usize, 18 | pub patch_size: usize, 19 | pub rope_theta: f64, 20 | pub intermediate_size: usize, 21 | pub num_hidden_layers: usize, 22 | pub head_dim: Option, 23 | pub num_attention_heads: usize, 24 | #[serde(default = "default_activation")] 25 | pub hidden_act: candle_nn::Activation, 26 | } 27 | 28 | impl VisionConfig { 29 | pub fn head_dim(&self) -> usize { 30 | self.head_dim 31 | .unwrap_or(self.hidden_size / self.num_attention_heads) 32 | } 33 | } 34 | 35 | #[derive(Deserialize, Debug, Clone)] 36 | pub struct Mistral3Config { 37 | pub image_token_index: usize, 38 | pub multimodal_projector_bias: bool, 39 | pub projector_hidden_act: candle_nn::Activation, 40 | pub spatial_merge_size: usize, 41 | pub vision_feature_layer: isize, 42 | pub text_config: Config, 43 | pub vision_config: VisionConfig, 44 | } 45 | -------------------------------------------------------------------------------- /docs/context-cache.md: -------------------------------------------------------------------------------- 1 | # Context Cache Guide 2 | 3 | Context-cache lets the server reuse KV cache across turns via `session_id` when `--context-cache` is enabled (CUDA/Metal). 
This reduces prefill latency for long conversations. 4 | 5 | ## Enabling 6 | - Start server with cache on: 7 | ```bash 8 | ./run.sh --release --features cuda -- --server --m Qwen/Qwen3-30B-A3B-Instruct-2507 --context-cache 9 | ``` 10 | - Metal example: 11 | ```bash 12 | ./run.sh --release --features metal -- --server --m Qwen/Qwen3-4B-GGUF --f Qwen3-4B-Q4_K_M.gguf --context-cache --max-model-len 32768 13 | ``` 14 | 15 | ## Using `session_id` 16 | - First turn (creates cache): 17 | ```json 18 | {"model":"default","messages":[{"role":"user","content":"Explain KV cache"}],"session_id":"chat-123"} 19 | ``` 20 | - Follow-up reuses cache; only the new message is sent: 21 | ```json 22 | {"model":"default","messages":[{"role":"user","content":"continue"}],"session_id":"chat-123"} 23 | ``` 24 | - Cache limits follow `max_model_len` and block allocation; server logs warn when swapping or evicting. 25 | 26 | ## Notes 27 | - Set `--max-model-len` and `--kv-fraction` to balance cache size vs decode headroom; on Metal prefer smaller lengths. 28 | - `--fp8-kvcache` and `--flash-context` are optional CUDA optimizations (Ampere+); enable when building with the relevant features. 29 | - Avoid mixing context-cache with streaming `session_id` on heavy loads unless you need persistence; throughput may drop if cache swaps. 30 | -------------------------------------------------------------------------------- /docs/embeddings.md: -------------------------------------------------------------------------------- 1 | # Embedding Usage 2 | 3 | This repository now exposes OpenAI-style embeddings for text-only models (Qwen3, Qwen3-MoE, LLaMa, GLM4, Gemma3). Use the standard server run path and hit `/v1/embeddings`. 4 | 5 | ## Start the server (embeddings enabled) 6 | - CUDA example (Qwen3 text): 7 | ```bash 8 | ./run.sh --release --features cuda -- --server --m Qwen/Qwen2.5-7B-Instruct --context-cache 9 | ``` 10 | - Metal example (LLaMa3 text): 11 | ```bash 12 | ./run.sh --release --features metal -- --server --m meta-llama/Llama-3-8b --max-model-len 32768 13 | ``` 14 | 15 | ## Request examples 16 | - Float embeddings (default) with mean pooling: 17 | ```bash 18 | curl -X POST http://localhost:8000/v1/embeddings \ 19 | -H "Content-Type: application/json" \ 20 | -d '{"input":"hello world","model":"default","embedding_type":"mean"}' 21 | ``` 22 | - Base64-encoded embeddings with last-token pooling: 23 | ```bash 24 | curl -X POST http://localhost:8000/v1/embeddings \ 25 | -H "Content-Type: application/json" \ 26 | -d '{"input":["hello","hola"],"embedding_type":"last","encoding_format":"base64"}' 27 | ``` 28 | 29 | ## Notes 30 | - `model` defaults to the loaded model id; multiple models per request are not supported. 31 | - Uses existing tokenizer; long prompts must fit `max_model_len` (same as chat). 32 | - `embedding_type`: `mean` (default) averages tokens; `last` returns the final token hidden state. 33 | - Responses mirror OpenAI schema: `data[].embedding`, `usage.prompt_tokens`. 34 | -------------------------------------------------------------------------------- /docs/multimodal.md: -------------------------------------------------------------------------------- 1 | # Multimodal Model Usage 2 | 3 | This project supports vision-language models (Qwen3-VL dense/MoE, Gemma3, Mistral3-VL). The server exposes `/v1/chat/completions` with mixed text+image content and optional web UI. 
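
Once one of the servers below is running, any OpenAI-compatible client can send mixed text+image content. A minimal Python sketch (the base URL and image URL are placeholders; the message format mirrors the payload examples further down):

```python
import openai

# Assumes a vLLM.rs server started as shown below (default port 8000).
openai.api_key = "EMPTY"
openai.base_url = "http://localhost:8000/v1/"

response = openai.chat.completions.create(
    model="default",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image"},
            # vLLM.rs accepts the image URL as a plain string (see payload examples below).
            {"type": "image_url", "image_url": "https://example.com/cat.png"},
        ],
    }],
)
print(response.choices[0].message.content)
```
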
4 | 5 | ## Starting servers 6 | - Qwen3-VL (CUDA): 7 | ```bash 8 | ./run.sh --release --features cuda -- --server \ 9 | --m Qwen/Qwen3-VL-8B-Instruct --ui-server --context-cache 10 | ``` 11 | - Qwen3-VL (Metal/Mac): 12 | ```bash 13 | ./run.sh --release --features metal -- --server \ 14 | --m Qwen/Qwen3-VL-8B-Instruct --max-model-len 32768 --ui-server 15 | ``` 16 | - Gemma3 (vision): 17 | ```bash 18 | ./run.sh --release --features cuda -- --server \ 19 | --m google/gemma-3-4b-it --ui-server --context-cache 20 | ``` 21 | - Mistral3-VL (vision): 22 | ```bash 23 | ./run.sh --release --features cuda -- --server \ 24 | --m mistralai/Ministral-3-8B-Reasoning --ui-server --context-cache 25 | ``` 26 | 27 | ## Request payloads (OpenAI-compatible) 28 | - Text + image URL: 29 | ```json 30 | { 31 | "model": "default", 32 | "messages": [ 33 | {"role":"user","content":[ 34 | {"type":"text","text":"Describe this image"}, 35 | {"type":"image_url","image_url":"https://example.com/cat.png"} 36 | ]} 37 | ] 38 | } 39 | ``` 40 | - Text + base64 image: `{"type":"image_base64","image_base64":"data:image/png;base64,..."}` 41 | 42 | ## Tips 43 | - Use smaller `--max-model-len` on Metal if VRAM is tight; consider `--kv-fraction` on CUDA to reserve cache. 44 | - For batch image inputs, keep concurrent images modest; too many images will increase prefill time. 45 | - `--ui-server` opens the built-in chat UI for uploads; without it, send HTTP requests directly.*** 46 | -------------------------------------------------------------------------------- /src/models/layers/mask.rs: -------------------------------------------------------------------------------- 1 | use candle_core::{DType, Device, Tensor}; 2 | 3 | #[cfg(any(feature = "flash-attn", feature = "flash-context"))] // If flash-attn or metal is enabled, we don't implement this function. 4 | // The actual implementation would be embedded in the flash or metal attention kernel. 5 | pub fn get_attention_causal_mask( 6 | _: &Device, 7 | _: DType, 8 | _: &Tensor, 9 | _: Vec, 10 | _: Option, 11 | _: bool, 12 | ) -> Option> { 13 | None 14 | } 15 | 16 | #[allow(unreachable_code)] 17 | #[cfg(not(any(feature = "flash-attn", feature = "flash-context")))] 18 | fn get_causal_mask_internal( 19 | device: &Device, 20 | dtype: DType, 21 | tgt_len: usize, 22 | sliding_window: Option, 23 | ) -> candle_core::Result { 24 | use attention_rs::mask::causal_mask; 25 | let mask = Tensor::zeros((tgt_len, tgt_len), dtype, device)?; 26 | let _ = causal_mask(&mask, sliding_window)?; 27 | mask.unsqueeze(0)?.unsqueeze(0) 28 | } 29 | 30 | #[cfg(not(any(feature = "flash-attn", feature = "flash-context")))] 31 | pub fn get_attention_causal_mask( 32 | device: &Device, 33 | dtype: DType, 34 | _: &Tensor, 35 | seqlens: Vec, 36 | sliding_window: Option, 37 | is_prefill: bool, 38 | ) -> Option> { 39 | if !is_prefill { 40 | return None; 41 | } 42 | let mut offsets = vec![0u32]; 43 | offsets.extend(seqlens.clone()); 44 | let mut vec_mask = Vec::new(); 45 | let mut start = 0; 46 | for (_, seq_offset) in seqlens.iter().enumerate() { 47 | let seq_len = seq_offset - start; 48 | let mask = 49 | get_causal_mask_internal(device, dtype, seq_len as usize, sliding_window).unwrap(); 50 | vec_mask.push(mask); 51 | start = *seq_offset; 52 | } 53 | Some(vec_mask) 54 | } 55 | -------------------------------------------------------------------------------- /src/utils/guidance.rs: -------------------------------------------------------------------------------- 1 | // src/utils/guidance.rs 2 | //! 
Guided decoding support via llguidance. 3 | //! 4 | //! NOTE: This module is currently stubbed out due to API changes in llguidance >= 0.6. 5 | //! The TopLevelGrammar::from_json_schema method is no longer available. 6 | //! Guided decoding features are temporarily disabled. 7 | 8 | use serde_json::Value; 9 | use std::path::Path; 10 | use std::sync::Arc; 11 | 12 | // Import toktrie from the crate root (it's re-exported by llguidance) 13 | pub use toktrie::TokTrie; 14 | 15 | pub struct GuidanceState { 16 | // Placeholder for future implementation 17 | _phantom: std::marker::PhantomData<()>, 18 | } 19 | 20 | impl GuidanceState { 21 | pub fn new(_toktrie: Arc, _schema: Value) -> anyhow::Result { 22 | // Stubbed out - guided decoding temporarily disabled 23 | anyhow::bail!("Guided decoding is temporarily disabled due to llguidance API changes. \ 24 | The TopLevelGrammar::from_json_schema method is no longer available in llguidance >= 0.6") 25 | } 26 | 27 | pub fn compute_allowed_tokens(&mut self) -> anyhow::Result { 28 | anyhow::bail!("Guided decoding is temporarily disabled") 29 | } 30 | 31 | pub fn commit_token(&mut self, _token: u32) -> anyhow::Result<()> { 32 | anyhow::bail!("Guided decoding is temporarily disabled") 33 | } 34 | } 35 | 36 | pub struct AllowedTokens { 37 | pub tokens: Vec, 38 | pub is_stopped: bool, 39 | } 40 | 41 | pub fn build_toktrie_from_tokenizer_bytes(bytes: &[u8]) -> anyhow::Result { 42 | // Try to build TokTrie from bytes 43 | // The new API uses TokTrie::from() with TokRxInfo and words 44 | // For now, return an error as the exact migration path needs investigation 45 | anyhow::bail!("TokTrie construction from tokenizer bytes is temporarily disabled. \ 46 | The TokTrie::from_huggingface_bytes method is no longer available in toktrie >= 1.0. \ 47 | Input bytes length: {}", bytes.len()) 48 | } 49 | 50 | pub fn load_toktrie_from_path(path: &Path) -> Option { 51 | // Temporarily disabled - returns None 52 | crate::log_warn!("load_toktrie_from_path is disabled: {:?}", path); 53 | None 54 | } 55 | -------------------------------------------------------------------------------- /src/models/layers/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod attention; 2 | pub mod deepstack; 3 | pub mod distributed; 4 | pub mod linear; 5 | pub mod mask; 6 | pub mod mlp; 7 | pub mod moe; 8 | pub mod others; 9 | pub mod rotary_emb; 10 | pub mod wna16; 11 | use crate::utils::downloader::ModelPaths; 12 | use crate::utils::gguf_varbuilder::VarBuilder as QVarBuilder; 13 | use candle_core::DType; 14 | use candle_core::{Device, Result}; 15 | use candle_nn::var_builder::ShardedVarBuilder as VarBuilder; 16 | use either::Either; 17 | 18 | #[derive(Clone)] 19 | pub struct VarBuilderX<'a>(pub Either, QVarBuilder>); 20 | 21 | impl VarBuilderX<'_> { 22 | pub fn new( 23 | model_pathes: &ModelPaths, 24 | is_gguf: bool, 25 | dtype: DType, 26 | device: &Device, 27 | ) -> Result { 28 | assert!( 29 | !model_pathes.get_weight_filenames().is_empty(), 30 | "No weight files found!" 31 | ); 32 | let weight_files = model_pathes.get_weight_filenames(); 33 | if is_gguf { 34 | let vb = crate::utils::gguf_varbuilder::VarBuilder::from_gguf( 35 | weight_files[0].clone(), 36 | device, 37 | )?; 38 | Ok(Self(Either::Right(vb))) 39 | } else { 40 | let vb = unsafe { 41 | candle_nn::var_builder::ShardedSafeTensors::var_builder( 42 | &weight_files, 43 | dtype, 44 | device, 45 | )? 
46 | }; 47 | Ok(Self(Either::Left(vb))) 48 | } 49 | } 50 | 51 | pub fn is_var_builder(&self) -> bool { 52 | matches!(self.0, Either::Left(_)) 53 | } 54 | 55 | pub fn is_qvar_builder(&self) -> bool { 56 | matches!(self.0, Either::Right(_)) 57 | } 58 | 59 | pub fn device(&self) -> Device { 60 | match &self.0 { 61 | Either::Left(vb) => vb.device().clone(), 62 | Either::Right(vb) => vb.device().clone(), 63 | } 64 | } 65 | 66 | pub fn pp(&self, name: &str) -> VarBuilderX { 67 | match &self.0 { 68 | Either::Left(vb) => VarBuilderX(Either::Left(vb.pp(name))), 69 | Either::Right(vb) => VarBuilderX(Either::Right(vb.pp(name))), 70 | } 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /src/server/streaming.rs: -------------------------------------------------------------------------------- 1 | use super::ChatCompletionChunk; 2 | use axum::response::sse::Event; 3 | use flume::Receiver; 4 | use futures::Stream; 5 | use std::{ 6 | pin::Pin, 7 | task::{Context, Poll}, 8 | }; 9 | 10 | #[derive(PartialEq)] 11 | pub enum StreamingStatus { 12 | Uninitialized, 13 | Started, 14 | Interrupted, 15 | Stopped, 16 | } 17 | pub enum ChatResponse { 18 | InternalError(String), 19 | ValidationError(String), 20 | ModelError(String), 21 | Chunk(ChatCompletionChunk), 22 | Done, //finish flag 23 | } 24 | 25 | pub struct Streamer { 26 | pub rx: Receiver, 27 | pub status: StreamingStatus, 28 | } 29 | 30 | impl Stream for Streamer { 31 | type Item = Result; 32 | 33 | fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { 34 | if self.status == StreamingStatus::Stopped { 35 | return Poll::Ready(None); 36 | } 37 | match self.rx.try_recv() { 38 | Ok(resp) => match resp { 39 | ChatResponse::InternalError(e) => Poll::Ready(Some(Ok(Event::default().data(e)))), 40 | ChatResponse::ValidationError(e) => Poll::Ready(Some(Ok(Event::default().data(e)))), 41 | ChatResponse::ModelError(e) => Poll::Ready(Some(Ok(Event::default().data(e)))), 42 | ChatResponse::Chunk(response) => { 43 | if self.status != StreamingStatus::Started { 44 | self.status = StreamingStatus::Started; 45 | } 46 | Poll::Ready(Some(Event::default().json_data(response))) 47 | } 48 | ChatResponse::Done => { 49 | self.status = StreamingStatus::Stopped; 50 | Poll::Ready(Some(Ok(Event::default().data("[DONE]")))) 51 | } 52 | }, 53 | Err(e) => { 54 | if self.status == StreamingStatus::Started && e == flume::TryRecvError::Disconnected 55 | { 56 | self.status = StreamingStatus::Interrupted; 57 | Poll::Ready(None) 58 | } else { 59 | Poll::Pending 60 | } 61 | } 62 | } 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /src/models/qwen3_vl/config.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | serde_default, 3 | utils::config::{Config, QuantConfig}, 4 | }; 5 | use candle_nn::Activation; 6 | 7 | serde_default!(Activation, default_vision_hidden_act, Activation::Gelu); 8 | serde_default!(usize, default_in_channels, 3); 9 | serde_default!(usize, default_depth, 32); 10 | serde_default!(usize, default_hidden_size, 3584); 11 | serde_default!(usize, default_out_hidden_size, 3584); 12 | serde_default!(usize, default_intermediate_size, 3420); 13 | serde_default!(usize, default_num_heads, 16); 14 | serde_default!(usize, default_patch_size, 14); 15 | serde_default!(usize, default_spatial_merge_size, 2); 16 | serde_default!(usize, default_temporal_patch_size, 2); 17 | serde_default!(usize, default_num_position_embeddings, 
576); 18 | serde_default!(Vec, default_deepstack_visual_indexes, Vec::new()); 19 | 20 | #[derive(Debug, Clone, serde::Deserialize)] 21 | pub struct VisionConfig { 22 | #[serde(default = "default_depth")] 23 | pub depth: usize, 24 | #[serde(default = "default_hidden_size")] 25 | pub hidden_size: usize, 26 | #[serde(default = "default_out_hidden_size")] 27 | pub out_hidden_size: usize, 28 | #[serde(default = "default_vision_hidden_act")] 29 | pub hidden_act: Activation, 30 | #[serde(default = "default_intermediate_size")] 31 | pub intermediate_size: usize, 32 | #[serde(default = "default_num_heads")] 33 | pub num_heads: usize, 34 | #[serde(default = "default_in_channels")] 35 | pub in_chans: usize, 36 | #[serde(default = "default_patch_size")] 37 | pub patch_size: usize, 38 | #[serde(default = "default_spatial_merge_size")] 39 | pub spatial_merge_size: usize, 40 | #[serde(default = "default_temporal_patch_size")] 41 | pub temporal_patch_size: usize, 42 | #[serde(default = "default_num_position_embeddings")] 43 | pub num_position_embeddings: usize, 44 | #[serde(default = "default_deepstack_visual_indexes")] 45 | pub deepstack_visual_indexes: Vec, 46 | } 47 | 48 | #[derive(Debug, Clone, serde::Deserialize)] 49 | pub struct Qwen3VLConfig { 50 | pub architectures: Option>, 51 | pub text_config: Config, 52 | pub vision_config: VisionConfig, 53 | pub image_token_id: u32, 54 | pub video_token_id: u32, 55 | pub vision_start_token_id: u32, 56 | pub vision_end_token_id: u32, 57 | pub tie_word_embeddings: bool, 58 | pub quantization_config: Option, 59 | } 60 | -------------------------------------------------------------------------------- /src/models/layers/deepstack.rs: -------------------------------------------------------------------------------- 1 | use candle_core::{DType, Result, Tensor}; 2 | pub trait ApplyDeepStack { 3 | fn apply_deep_stack(&self, visual_pos_masks: &Tensor, visual_embeds: &Tensor) 4 | -> Result; 5 | } 6 | 7 | impl ApplyDeepStack for Tensor { 8 | fn apply_deep_stack( 9 | &self, 10 | visual_pos_masks: &Tensor, 11 | visual_embeds: &Tensor, 12 | ) -> Result { 13 | deepstack_process(&self, visual_pos_masks, visual_embeds) 14 | } 15 | } 16 | 17 | // Reference: https://github.com/EricLBuehler/mistral.rs/blob/master/mistralrs-core/src/vision_models/qwen3_vl_moe/text.rs#698 18 | fn deepstack_process( 19 | hidden_states: &Tensor, 20 | visual_pos_masks: &Tensor, 21 | visual_embeds: &Tensor, 22 | ) -> Result { 23 | let device = hidden_states.device(); 24 | let dtype = hidden_states.dtype(); 25 | 26 | let mask = visual_pos_masks.to_device(device)?.to_dtype(DType::F32)?; 27 | let mask_flat = mask.flatten_all()?; 28 | 29 | let masked_count = mask_flat.sum_all()?.to_scalar::()? as usize; 30 | let visual_embeds = visual_embeds.to_device(device)?.to_dtype(dtype)?; 31 | 32 | if masked_count == 0 { 33 | if visual_embeds.dim(0)? != 0 { 34 | candle_core::bail!( 35 | "DeepStack visual embeds ({}) provided but mask is empty", 36 | visual_embeds.dim(0)? 37 | ); 38 | } 39 | return Ok(hidden_states.clone()); 40 | } 41 | 42 | if visual_embeds.dim(0)? 
!= masked_count { 43 | candle_core::bail!( 44 | "Mismatch between DeepStack visual embeds ({}) and mask positions ({})", 45 | visual_embeds.dim(0)?, 46 | masked_count 47 | ); 48 | } 49 | 50 | let (total_positions, hidden) = hidden_states.dims2()?; 51 | 52 | let prefix = mask_flat.cumsum(0)?; 53 | let rank = (prefix - &mask_flat)?.mul(&mask_flat)?; 54 | let rank_u32 = rank.to_dtype(DType::U32)?; 55 | 56 | let positions = Tensor::arange(0u32, total_positions as u32, device)?; 57 | let positions_f32 = positions.to_dtype(DType::F32)?; 58 | let masked_positions = positions_f32.mul(&mask_flat)?; 59 | 60 | let mut position_per_rank = Tensor::zeros((masked_count,), DType::F32, device)?; 61 | position_per_rank = position_per_rank.scatter_add(&rank_u32, &masked_positions, 0)?; 62 | let position_per_rank = position_per_rank.to_dtype(DType::U32)?; 63 | 64 | let linear_index = position_per_rank.unsqueeze(1)?.repeat((1, hidden))?; 65 | 66 | hidden_states.scatter_add(&linear_index, &visual_embeds, 0) 67 | } 68 | -------------------------------------------------------------------------------- /src/utils/heartbeat.rs: -------------------------------------------------------------------------------- 1 | use super::command::CommandManager; 2 | use std::sync::{ 3 | atomic::{AtomicBool, Ordering}, 4 | Arc, 5 | }; 6 | use std::{process, thread, time}; 7 | 8 | pub fn heartbeat_worker( 9 | num_subprocess: Option, 10 | is_daemon: bool, 11 | stop_flag: Arc, 12 | uuid: &str, 13 | ) -> std::thread::JoinHandle<()> { 14 | let uuid_str = uuid.to_string(); 15 | let handle = thread::spawn(move || { 16 | let flag_clone = Arc::clone(&stop_flag); 17 | let sock_name = format!("{}@vllm-rs-runner-heartbeat", uuid_str); 18 | let mut connect_retry_count = 0; 19 | let mut command_manager = if is_daemon { 20 | let mut manager = CommandManager::new_command(&sock_name, None, is_daemon); 21 | while !flag_clone.load(Ordering::Relaxed) { 22 | if manager.is_ok() { 23 | break; 24 | } else if connect_retry_count < 120 { 25 | connect_retry_count += 1; 26 | crate::log_info!( 27 | "Retry connect to main process' command channel ({:?})!", 28 | manager 29 | ); 30 | let _ = thread::sleep(time::Duration::from_millis(1000 as u64)); 31 | manager = CommandManager::new_command(&sock_name, None, is_daemon); 32 | continue; 33 | } else { 34 | crate::log_warn!("{:?}", manager); 35 | break; 36 | } 37 | } 38 | manager 39 | } else { 40 | CommandManager::new_command(&sock_name, num_subprocess, is_daemon) 41 | }; 42 | 43 | let mut heartbeat_error_count = 0; 44 | crate::log_info!("enter heartbeat processing loop ({:?})", command_manager); 45 | while !flag_clone.load(Ordering::Relaxed) { 46 | let alive_result = command_manager.as_mut().unwrap().heartbeat(is_daemon); 47 | if alive_result.is_err() { 48 | if !flag_clone.load(Ordering::Relaxed) { 49 | crate::log_warn!("{:?}", alive_result); 50 | } 51 | if heartbeat_error_count > 5 { 52 | crate::log_error!( 53 | "heartbeat detection failed, exit the current process because of {:?}", 54 | alive_result 55 | ); 56 | process::abort(); 57 | } 58 | heartbeat_error_count += 1; 59 | } 60 | 61 | let _ = thread::sleep(time::Duration::from_millis(1000 as u64)); 62 | } 63 | }); 64 | handle 65 | } 66 | -------------------------------------------------------------------------------- /docs/performance.md: -------------------------------------------------------------------------------- 1 | # Performance Benchmarks 2 | 3 | This document contains detailed performance benchmarks for vLLM.rs across different hardware platforms. 
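
Throughput columns in the tables below are tokens divided by elapsed wall-clock time. A quick check against the A100 row of the comparison table further down:

```python
# Values taken from the "Performance Comparison" table below (vLLM.rs on A100).
tokens = 262_144
seconds = 23.88
print(f"{tokens / seconds:.2f} tokens/s")  # 10977.55, matching the reported figure
```
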
4 | 5 | ## 🚀 CUDA Performance (A100 40GB) 6 | 7 | ### Single Request Decoding Speed 8 | 9 | | Model | Format | Size | Decoding Speed | 10 | |-------|--------|------|----------------| 11 | | Llama-3.1-8B | ISQ (BF16→Q4K) | 8B | **90.19** tokens/s | 12 | | DeepSeek-R1-Distill-Llama-8B | Q2_K | 8B | **94.47** tokens/s | 13 | | DeepSeek-R1-0528-Qwen3-8B | Q4_K_M | 8B | **95** tokens/s | 14 | | GLM-4-9B-0414 | Q4_K_M | 9B | **70.38** tokens/s | 15 | | QwQ-32B | Q4_K_M | 32B | **35.69** tokens/s | 16 | | **Qwen3-30B-A3B** | Q4_K_M | **30B (MoE)** | **75.91** tokens/s | 17 | 18 | ## 🍎 Metal Performance (Apple Silicon M4) 19 | 20 | > Test Configuration: 21 | > - Models: Qwen3-0.6B (BF16), Qwen3-4B (Q4_K_M), Qwen3-8B (Q2_K) 22 | > - Concurrent Requests: 1 - 128 23 | > - Max Model Length: 512 - 2048 24 | > - Max Output Tokens/Request: 512 - 2048 25 | 26 | | Model | Batch Size | Output Tokens | Time (s) | Throughput (tokens/s) | 27 | |-------|------------|---------------|----------|----------------------| 28 | | Qwen3-0.6B (BF16) | 128 | 63,488 | 83.13s | **763.73** | 29 | | Qwen3-0.6B (BF16) | 32 | 15,872 | 23.53s | **674.43** | 30 | | Qwen3-0.6B (BF16) | 1 | 456 | 9.23s | 49.42 | 31 | | Qwen3-4B (Q4_K_M) | 1 | 1,683 | 52.62s | 31.98 | 32 | | Qwen3-8B (Q2_K) | 1 | 1,300 | 80.88s | 16.07 | 33 | 34 | ## 📊 Performance Comparison 35 | 36 | > Test Configuration: 37 | > - Model: Qwen3-0.6B (BF16) 38 | > - Concurrent Requests: 256 39 | > - Max Model Length: 1024 40 | > - Max Output Tokens/Request: 1024 41 | 42 | | Inference Engine | Hardware | Tokens | Time (s) | Throughput (tokens/s) | 43 | |------------------|----------|--------|----------|----------------------| 44 | | vLLM (Reference) | RTX 4070 | 133,966 | 98.37 | 1,361.84 | 45 | | Nano-vLLM (Reference) | RTX 4070 | 133,966 | 93.41 | 1,434.13 | 46 | | **vLLM.rs** | **A100** | 262,144 | 23.88 | **10,977.55** | 47 | | Nano-vLLM | A100 | 262,144 | 34.22 | 7,660.26 | 48 | 49 | ### Key Insights 50 | 51 | - **40%+ faster** than Nano-vLLM on A100 52 | - **7x faster** than reference implementations on consumer hardware 53 | - Efficient memory management with quantized models 54 | 55 | ## 🔧 Reproduce Benchmarks 56 | 57 | See [python/ReadMe.md](../python/ReadMe.md) for reproducible benchmark steps. 58 | 59 | ## Optimization Tips 60 | 61 | 1. **Use FP8 KV Cache** (`--fp8-kvcache`) for memory efficiency with slight accuracy tradeoff 62 | 2. **Enable Flash Attention** (`flash-attn` feature) for maximum CUDA performance 63 | 3. **Use Context Cache** (`--context-cache`) for multi-turn conversations 64 | 4. **Tune `--kv-fraction`** to balance memory usage and batch size 65 | 5. **Use PD Disaggregation** for long-context workloads to prevent decoding stalls 66 | -------------------------------------------------------------------------------- /src/transfer/cuda_remote.rs: -------------------------------------------------------------------------------- 1 | // src/core/transfer/cuda_remote.rs 2 | use super::CudaIpcMemHandle; 3 | use candle_core::cuda_backend::cudarc::driver::sys::{lib, CUdeviceptr, CUipcMemHandle}; 4 | use candle_core::cuda_backend::cudarc::driver::DevicePtr; 5 | use candle_core::{Device, Result, Tensor, WithDType}; 6 | use std::mem::{ManuallyDrop, MaybeUninit}; 7 | /// (Server) Gets a serializable IPC handle for a GPU tensor's memory. 
8 | pub(super) fn get_ipc_handle( 9 | tensor: &Tensor, 10 | ) -> Result { 11 | use candle_core::Storage; 12 | let (storage, _) = tensor.storage_and_layout(); 13 | let Storage::Cuda(src_storage) = &*storage else { 14 | candle_core::bail!("Invalid source kvcache storage!") 15 | }; 16 | let ptr = src_storage.as_cuda_slice::()?.device_ptr(); 17 | 18 | let mut handle = MaybeUninit::::uninit(); 19 | let handle = unsafe { 20 | lib() 21 | .cuIpcGetMemHandle(handle.as_mut_ptr(), *ptr) 22 | .result() 23 | .map_err(|e| candle_core::Error::Msg(format!("cuIpcGetMemHandle failed: {e:?}")))?; 24 | handle.assume_init() 25 | }; 26 | 27 | Ok(CudaIpcMemHandle( 28 | handle.reserved.to_vec(), 29 | tensor.shape().dims().to_vec(), 30 | tensor.dtype().into(), 31 | )) 32 | } 33 | 34 | /// (Client) Opens an IPC handle to get a local tensor pointing to remote GPU memory. 35 | pub(super) fn open_ipc_handle( 36 | handle: &CudaIpcMemHandle, 37 | device: &Device, 38 | ) -> Result> { 39 | use candle_core::cuda_backend::cudarc::driver::CudaSlice; 40 | 41 | let mut ptr: CUdeviceptr = 0; 42 | use core::ffi::c_char; 43 | if handle.0.len() != 64 { 44 | candle_core::bail!("Invalid CUipcMemHandle handle!"); 45 | } 46 | let raw_array: [i8; 64] = handle.0.clone().try_into().unwrap(); 47 | let handle_raw = CUipcMemHandle { 48 | reserved: raw_array.map(|b| b as c_char), 49 | }; 50 | unsafe { 51 | lib() 52 | .cuIpcOpenMemHandle_v2(&mut ptr, handle_raw, 1) 53 | .result() 54 | .map_err(|e| candle_core::Error::Msg(format!("cuIpcOpenMemHandle_v2 failed: {e:?}")))?; 55 | } 56 | let dev = device.as_cuda_device()?; 57 | let src_slice = unsafe { 58 | let slice: CudaSlice = dev.upgrade_device_ptr(ptr, handle.1.iter().sum()); 59 | // std::mem::ManuallyDrop::new(slice) 60 | slice 61 | }; 62 | 63 | let slice = candle_core::CudaStorage::wrap_cuda_slice(src_slice, dev.clone()); 64 | // We created a virtual Tensor, it stored the remote mem handle, so we should not release it 65 | Ok(ManuallyDrop::new(Tensor::from_storage( 66 | candle_core::Storage::Cuda(slice), 67 | handle.1.clone(), 68 | )?)) 69 | } 70 | -------------------------------------------------------------------------------- /docs/tokenizer_api.md: -------------------------------------------------------------------------------- 1 | # Tokenizer API 2 | 3 | The Tokenizer API provides direct access to the model's tokenizer for encoding/decoding text without running inference. This is useful for: 4 | 5 | - Pre-computing token counts for cost estimation 6 | - Validating inputs before sending to the model 7 | - Debugging tokenization issues 8 | - Building custom tooling around the tokenizer 9 | 10 | ## Endpoints 11 | 12 | ### POST /tokenize 13 | 14 | Convert text or chat messages to token IDs. 15 | 16 | **Request (plain text):** 17 | ```json 18 | { 19 | "prompt": "Hello, world!" 20 | } 21 | ``` 22 | 23 | **Request (chat messages - applies chat template):** 24 | ```json 25 | { 26 | "messages": [ 27 | {"role": "system", "content": "You are helpful."}, 28 | {"role": "user", "content": "Hello!"} 29 | ] 30 | } 31 | ``` 32 | 33 | **Optional parameters:** 34 | - `model`: Model name (optional, uses loaded model) 35 | - `add_special_tokens`: Whether to add special tokens (default: `true`) 36 | 37 | **Response:** 38 | ```json 39 | { 40 | "tokens": [1, 2, 3, 4], 41 | "count": 4, 42 | "max_model_len": 4096 43 | } 44 | ``` 45 | 46 | ### POST /detokenize 47 | 48 | Convert token IDs back to text. 
49 | 50 | **Request:** 51 | ```json 52 | { 53 | "tokens": [1, 2, 3, 4] 54 | } 55 | ``` 56 | 57 | **Optional parameters:** 58 | - `model`: Model name (optional, uses loaded model) 59 | - `skip_special_tokens`: Whether to skip special tokens in output (default: `true`) 60 | 61 | **Response:** 62 | ```json 63 | { 64 | "prompt": "Hello, world!" 65 | } 66 | ``` 67 | 68 | ## Examples 69 | 70 | ### Python 71 | 72 | ```python 73 | import requests 74 | 75 | # Tokenize text 76 | response = requests.post("http://localhost:8000/tokenize", json={ 77 | "prompt": "Hello, world!" 78 | }) 79 | print(response.json()) 80 | # {"tokens": [9906, 11, 1917, 0], "count": 4, "max_model_len": 4096} 81 | 82 | # Detokenize 83 | response = requests.post("http://localhost:8000/detokenize", json={ 84 | "tokens": [9906, 11, 1917, 0] 85 | }) 86 | print(response.json()) 87 | # {"prompt": "Hello, world!"} 88 | ``` 89 | 90 | ### cURL 91 | 92 | ```bash 93 | # Tokenize 94 | curl -X POST http://localhost:8000/tokenize \ 95 | -H "Content-Type: application/json" \ 96 | -d '{"prompt": "Hello, world!"}' 97 | 98 | # Detokenize 99 | curl -X POST http://localhost:8000/detokenize \ 100 | -H "Content-Type: application/json" \ 101 | -d '{"tokens": [9906, 11, 1917, 0]}' 102 | ``` 103 | 104 | ## Demo Scripts 105 | 106 | - **Python**: `example/tokenize.py` - Interactive demo with multiple examples 107 | - **Rust**: `example/rust-demo-tokenize/` - Rust client demo 108 | 109 | Run the Python demo: 110 | ```bash 111 | # Start the server first 112 | cargo run --release -- --m Qwen/Qwen2.5-0.5B-Instruct --server 113 | 114 | # In another terminal 115 | python example/tokenize.py 116 | ``` 117 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "vllm-rs" 3 | version = "0.6.2" 4 | edition = "2021" 5 | default-run = "vllm-rs" 6 | 7 | [dependencies] 8 | candle-core = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "8e43e56" } 9 | candle-nn = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "8e43e56" } 10 | serde = { version = "1.0.190", features = ["serde_derive"] } 11 | tokenizers = {version = "0.21.2", features = ["http"] } 12 | hf-hub = "0.4.1" 13 | anyhow = "1.0.75" 14 | itertools = "0.13.0" 15 | akin = "0.4.0" 16 | indicatif = "0.17.11" 17 | serde_json = "1.0.108" 18 | llguidance = "0.6" 19 | toktrie = "1.4" 20 | half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"] } 21 | tokio = { version = "1.38.0", features = ["sync"] } 22 | tracing = "0.1.40" 23 | tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } 24 | either = { version = "1.13.0", features = ["serde"] } 25 | minijinja = { version = "2.10.2", features = ["builtins", "json"] } 26 | minijinja-contrib = { version = "2.10.2", features = ["pycompat"] } 27 | lazy_static = {version = "1.4.0"} 28 | interprocess = "2.2.2" 29 | serde-big-array = "0.5.1" 30 | bincode = { version = "1.3.1" } 31 | twox-hash = "2.1.1" 32 | rand = "0.9.0" 33 | rayon="1.10.0" 34 | clap = { version = "4.4.7", features = ["derive"] } 35 | thiserror = "1.0.58" 36 | schemars = "0.8" 37 | ahash = "0.8.11" 38 | reedline = "0.40.0" 39 | pyo3 = { version = "0.25.1", features = ["extension-module", "abi3-py38"], optional = true } 40 | parking_lot = "0.12.4" 41 | attention-rs = { git = "https://github.com/guoqingbao/attention.rs.git", version="0.1.7", rev = "b34e402" } 42 | 
once_cell = "1.21.3" 43 | tqdm = "0.8.0" 44 | futures = "0.3.31" 45 | crossbeam = "0.8.4" 46 | ctrlc = "3.4.7" 47 | dirs = "5.0.1" 48 | base64 = "0.22.1" 49 | metal = { version = "0.27.0", features = ["mps"], optional = true } 50 | uuid = { version = "1.5.0", features = ["v4"] } 51 | axum = { version = "0.7.4", features = ["tokio"] } 52 | flume = "0.10.14" 53 | utoipa = { version = "4.2", features = ["axum_extras"] } 54 | colored = { version = "3.0.0" } 55 | tower-http = { version = "0.6.6", features = ["cors"] } 56 | rustchatui = { git = "https://github.com/guoqingbao/rustchatui.git", rev = "e5feeb0" } 57 | sysinfo = "0.37.2" 58 | image = { version = "0.25.6", default-features = false, features = ['bmp', 'gif', 'jpeg', 'png', 'tiff', 'webp'] } 59 | reqwest = { version = "0.12.24", features = ["blocking", "json", "rustls-tls"]} 60 | bytemuck = "1.24.0" 61 | regex = "1.12.2" 62 | 63 | [lib] 64 | name = "vllm_rs" 65 | path = "src/lib.rs" 66 | crate-type = ["rlib", "cdylib"] # cdylib needed for python extension 67 | 68 | [features] 69 | cuda = ["candle-core/cuda", "candle-nn/cuda", "attention-rs/cuda"] 70 | nccl = ["candle-core/nccl"] 71 | graph = ["cuda", "attention-rs/graph", "candle-core/graph"] 72 | flash-attn = ["attention-rs/flash-attn"] 73 | flash-context = ["attention-rs/flash-attn", "attention-rs/flash-decoding", "attention-rs/flash-context", "attention-rs/no-fp8-kvcache"] 74 | metal = ["candle-core/metal", "candle-nn/metal", "attention-rs/metal", "dep:metal"] 75 | python = ["pyo3"] 76 | 77 | [[bin]] 78 | name = "runner" 79 | path = "src/runner/runner.rs" 80 | -------------------------------------------------------------------------------- /python/ReadMe.md: -------------------------------------------------------------------------------- 1 | #### How to reproduce? 2 | **vLLM.rs** 3 | ```shell 4 | pip install vllm_rs 5 | python -m vllm_rs.completion --w /home/Qwen3-0.6B/ --batch 256 --max-tokens 1024 --max-model-len 1024 6 | 7 | # Log 8 | Allocating 8192 KV blocks (28672 MB) for [256 seqs x 1024 tokens] 9 | Maximum batched tokens 262144 (8192 blocks x Block_Size 32 for KV cache). 10 | Start inference with 256 prompts 11 | --- Performance Metrics --- 12 | ⏱️ Prompt tokens: 4096 in 0.28s (14894.55 tokens/s) 13 | ⏱️ Decoded tokens: 258048 in 23.60s (10944.62 tokens/s) 14 | ``` 15 | 16 | 17 | **Nano-vLLM** 18 | 19 | 💡 To ensure a fair comparison, revise each request to have a maximum of 1024 output tokens, instead of a random number between 100 and 1024. 
20 | ```shell 21 | pip install git+https://github.com/GeeeekExplorer/nano-vllm.git 22 | # with cuda graph, flash attention and model warmup 23 | python3 bench.py 24 | # log 25 | Generating: 100%|██████████████████| 1/1 [00:02<00:00, 2.65s/it, Prefill=1tok/s, Decode=369tok/s] 26 | Total: 262144tok, Time: 34.22s, Throughput: 7660.26tok/s 27 | ``` 28 | --- 29 | 30 | ### 🐍 Python API 31 | 32 | ```python 33 | from vllm_rs import Engine, EngineConfig, SamplingParams, Message 34 | cfg = EngineConfig(weight_path="/path/Qwen3-8B-Q2_K.gguf", max_model_len=4096) 35 | engine = Engine(cfg, "bf16") 36 | params = SamplingParams(temperature=0.6, max_tokens=256) 37 | message = Message("user", "How are you?")] 38 | 39 | # Synchronous batch generation 40 | outputs = engine.generate_sync([params, params], [[message], [message]]) 41 | print(outputs) 42 | 43 | params.session_id = xxx # Pass session_id to enable context cache 44 | 45 | # Single-request streaming generation 46 | (seq_id, prompt_length, stream) = engine.generate_stream(params, [message]) 47 | for item in stream: 48 | # item.datatype == "TOKEN" 49 | print(item.data) 50 | ``` 51 | 52 | ### 🤖 Client Usage of Context Cache 53 | 54 | **Key changes for the client:** 55 | 56 | ```python 57 | import uuid 58 | import openai 59 | use_context_cache = True #flag to use context_cache 60 | # create session_id for each new chat session and use it throughout that session (session cache will be cleared if the client aborted the connection) 61 | session_id = str(uuid.uuid4()) 62 | extra_body = {"session_id": session_id if use_context_cache else None } 63 | 64 | # vllm.rs service url 65 | openai.api_key = "EMPTY" 66 | openai.base_url = "http://localhost:8000/v1/" 67 | 68 | response = openai.chat.completions.create( 69 | model="", 70 | messages=messages + [user_msg], 71 | stream=True, 72 | max_tokens = max_tokens, 73 | temperature = temperature, 74 | top_p = top_p, 75 | extra_body = extra_body, #pass session_id through extra_body 76 | ) 77 | 78 | ``` 79 | 80 | ### 🧰 MCP Multi-Server Demo (Python Client) 81 | 82 | Start vLLM.rs with an MCP config file: 83 | 84 | ```shell 85 | vllm-rs --server --mcp-config ./mcp.json 86 | ``` 87 | 88 | Then call a prefixed MCP tool from Python: 89 | 90 | ```python 91 | import openai 92 | 93 | openai.api_key = "EMPTY" 94 | openai.base_url = "http://localhost:8000/v1/" 95 | 96 | response = openai.chat.completions.create( 97 | model="", 98 | messages=[{"role": "user", "content": "Use filesystem_read_file to read README.md"}], 99 | ) 100 | print(response.choices[0].message) 101 | ``` 102 | -------------------------------------------------------------------------------- /docs/rust_crate.md: -------------------------------------------------------------------------------- 1 | # Rust crate usage 2 | 3 | This crate exposes a Rust-facing API for loading models, running generation, and optionally running 4 | an OpenAI-compatible service without changing the existing project structure. 5 | 6 | ## Add dependency 7 | 8 | ```toml 9 | [dependencies] 10 | vllm-rs = { path = "/path/to/vllm.rs" } 11 | 12 | [features] 13 | cuda = ["vllm_rs/cuda"] 14 | ``` 15 | 16 | Use the same Cargo features you would use for the CLI (`cuda`, `metal`, `nccl`, etc.). 
17 | 18 | ## Direct generation (text) 19 | 20 | ```rust 21 | use vllm_rs::api::{EngineBuilder, ModelRepo}; 22 | use vllm_rs::server::{ChatMessage, MessageContentType}; 23 | use vllm_rs::utils::{config::SamplingParams, log_throughput}; 24 | 25 | fn main() -> anyhow::Result<()> { 26 | let mut engine = 27 | EngineBuilder::new(ModelRepo::ModelID(("google/gemma-3-4b-it", None))).build()?; 28 | 29 | let messages = vec![ChatMessage { 30 | role: "user".to_string(), 31 | content: MessageContentType::PureText("Say hello from the Rust API.".to_string()), 32 | }]; 33 | 34 | let params = SamplingParams::default(); 35 | let output = engine.generate(params, messages)?; 36 | println!("\n\n{}", output.decode_output); 37 | 38 | log_throughput(&vec![output]); 39 | } 40 | ``` 41 | 42 | ## Multimodal request (URL or base64) 43 | 44 | ```rust 45 | use vllm_rs::api::{EngineBuilder, ModelRepo}; 46 | use vllm_rs::server::{ChatMessage, MessageContent, MessageContentType}; 47 | use vllm_rs::utils::config::SamplingParams; 48 | 49 | fn main() -> candle_core::Result<()> { 50 | let mut engine = EngineBuilder::new(ModelRepo::ModelID(( 51 | "Qwen/Qwen3-VL-8B-Instruct".to_string(), 52 | None, 53 | ))) 54 | .build()?; 55 | 56 | let messages = vec![ChatMessage { 57 | role: "user".to_string(), 58 | content: MessageContentType::Multi(vec![ 59 | MessageContent::Text { 60 | text: "Describe this image:".to_string(), 61 | }, 62 | MessageContent::ImageUrl { 63 | image_url: "https://example.com/cat.png".to_string(), 64 | }, 65 | ]), 66 | }]; 67 | 68 | let params = SamplingParams::default(); 69 | let output = engine.generate(params, messages)?; 70 | println!("{}", output.decode_output); 71 | 72 | Ok(()) 73 | } 74 | ``` 75 | 76 | ## Serve API 77 | 78 | ```rust 79 | use vllm_rs::api::{EngineBuilder, ModelRepo}; 80 | 81 | fn main() -> candle_core::Result<()> { 82 | let mut engine = EngineBuilder::new(ModelRepo::ModelID(( 83 | "Qwen/Qwen3-0.6B".to_string(), 84 | None, 85 | ))) 86 | .build()?; 87 | 88 | engine.start_server(8000, true, false)?; 89 | Ok(()) 90 | } 91 | ``` 92 | 93 | ## Multi-rank / multi-GPU 94 | 95 | Provide `device_ids` with `with_multirank` (e.g. `"0,1"`) along with the same CUDA/NCCL features 96 | you use for the CLI. The Rust API reuses the same engine and scheduler path. 97 | 98 | ```rust 99 | use vllm_rs::api::{EngineBuilder, ModelRepo}; 100 | 101 | fn main() -> candle_core::Result<()> { 102 | let mut engine = EngineBuilder::new(ModelRepo::ModelFile(vec![ 103 | "/path/Qwen3-VL-8B-Instruct-GGUF-Q4_KM.gguf".to_string(), 104 | ])) 105 | .with_multirank("0,1")? 
106 | .build()?; 107 | 108 | engine.start_server(8000, true, true)?; 109 | Ok(()) 110 | } 111 | ``` 112 | ## Command to run 113 | 114 | [Reference Rust demo](/example/rust-demo/) 115 | 116 | ```shell 117 | # add `nccl` feature for multirank inference (and copy `runner` (which is built with build.sh) to your target path) 118 | cargo run --release --features cuda,graph 119 | ``` -------------------------------------------------------------------------------- /src/utils/gguf_varbuilder.rs: -------------------------------------------------------------------------------- 1 | use candle::quantized::QTensor; 2 | use candle::{Device, Result, Shape}; 3 | use candle_core as candle; 4 | use std::fs::File; 5 | use std::sync::Arc; 6 | use std::sync::Mutex; 7 | // light-cached qvarbuilder 8 | 9 | #[derive(Clone)] 10 | pub struct VarBuilder { 11 | content: Arc, 12 | file: Arc>, // Keep file open for lazy loading 13 | cache: Arc)>>>, // last cached tensor 14 | path: Vec, 15 | device: Device, 16 | } 17 | 18 | impl VarBuilder { 19 | pub fn from_gguf>(p: P, device: &Device) -> Result { 20 | let mut file = File::open(&p)?; 21 | let content = candle_core::quantized::gguf_file::Content::read(&mut file)?; 22 | Ok(Self { 23 | content: Arc::new(content), 24 | file: Arc::new(std::sync::Mutex::new(file)), 25 | cache: Arc::new(Mutex::new(None)), 26 | path: Vec::new(), 27 | device: device.clone(), 28 | }) 29 | } 30 | 31 | pub fn pp(&self, s: S) -> Self { 32 | let mut path = self.path.clone(); 33 | path.push(s.to_string()); 34 | Self { 35 | content: self.content.clone(), 36 | file: self.file.clone(), 37 | cache: self.cache.clone(), 38 | path, 39 | device: self.device.clone(), 40 | } 41 | } 42 | 43 | pub fn path(&self, tensor_name: &str) -> String { 44 | if self.path.is_empty() { 45 | tensor_name.to_string() 46 | } else { 47 | [&self.path.join("."), tensor_name].join(".") 48 | } 49 | } 50 | 51 | pub fn get>(&self, s: S, name: &str) -> Result> { 52 | let path = self.path(name); 53 | 54 | // Check cache 55 | { 56 | let cache_guard = self.cache.lock().unwrap(); 57 | if let Some((ref cached_name, ref cached_tensor)) = *cache_guard { 58 | if cached_name == &path { 59 | // Return cached tensor 60 | let shape = s.into(); 61 | if cached_tensor.shape() != &shape { 62 | candle::bail!( 63 | "shape mismatch for {name}, got {:?}, expected {shape:?}", 64 | cached_tensor.shape() 65 | ); 66 | } 67 | return Ok(cached_tensor.clone()); 68 | } 69 | } 70 | } 71 | 72 | let mut file = self.file.lock().unwrap(); 73 | let tensor = self.content.tensor(&mut *file, &path, &self.device)?; 74 | let tensor = Arc::new(tensor); 75 | // Update cache 76 | *self.cache.lock().unwrap() = Some((path.clone(), tensor.clone())); 77 | 78 | let shape = s.into(); 79 | if tensor.shape() != &shape { 80 | candle::bail!( 81 | "shape mismatch for {name}, got {:?}, expected {shape:?}", 82 | tensor.shape() 83 | ); 84 | } 85 | Ok(tensor) 86 | } 87 | 88 | pub fn get_no_shape(&self, name: &str) -> Result> { 89 | let mut file = self.file.lock().unwrap(); 90 | let tensor = self.content.tensor(&mut *file, name, &self.device)?; 91 | Ok(Arc::new(tensor)) 92 | } 93 | 94 | pub fn device(&self) -> &Device { 95 | &self.device 96 | } 97 | 98 | pub fn contains_key(&self, key: &str) -> bool { 99 | self.content.tensor_infos.contains_key(key) 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /src/models/gemma3/config.rs: -------------------------------------------------------------------------------- 1 | use crate::serde_default; 2 | 
use crate::utils::config::EosTokenId; 3 | use crate::utils::config::QuantConfig; 4 | use crate::utils::config::RopeScalingValue; 5 | use candle_nn::Activation; 6 | use std::collections::HashMap; 7 | 8 | serde_default!(usize, hidden_size, 768); 9 | serde_default!(usize, intermediate_size, 3072); 10 | serde_default!(usize, num_hidden_layers, 12); 11 | serde_default!(usize, num_attention_heads_vision, 12); 12 | serde_default!(usize, num_channels, 3); 13 | serde_default!(usize, image_size, 224); 14 | serde_default!(usize, patch_size, 16); 15 | serde_default!(Activation, hidden_act, Activation::GeluPytorchTanh); 16 | serde_default!(f64, layer_norm_eps, 1e-6); 17 | 18 | #[derive(Debug, Clone, serde::Deserialize)] 19 | pub struct VisionConfig { 20 | #[serde(default = "hidden_size")] 21 | pub hidden_size: usize, 22 | #[serde(default = "intermediate_size")] 23 | pub intermediate_size: usize, 24 | #[serde(default = "num_hidden_layers")] 25 | pub num_hidden_layers: usize, 26 | #[serde(default = "num_attention_heads_vision")] 27 | pub num_attention_heads: usize, 28 | #[serde(default = "num_channels")] 29 | pub num_channels: usize, 30 | #[serde(default = "image_size")] 31 | pub image_size: usize, 32 | #[serde(default = "patch_size")] 33 | pub patch_size: usize, 34 | #[serde(default = "hidden_act")] 35 | pub hidden_act: Activation, 36 | #[serde(default = "layer_norm_eps")] 37 | pub layer_norm_eps: f64, 38 | } 39 | 40 | serde_default!(bool, attention_bias, false); 41 | serde_default!(usize, head_dim, 256); 42 | serde_default!(Activation, hidden_activation, Activation::GeluPytorchTanh); 43 | serde_default!(f64, rms_norm_eps, 1e-6); 44 | serde_default!(f64, rope_theta, 1000000.); 45 | serde_default!(usize, vocab_size, 262208); 46 | serde_default!(bool, tie_word_embeddings, true); 47 | serde_default!(usize, query_pre_attn_scalar, 256); 48 | serde_default!(usize, max_position_embeddings, 131072); 49 | serde_default!(f64, rope_local_base_freq, 10000.); 50 | serde_default!(usize, sliding_window_pattern, 6); 51 | serde_default!(usize, num_attention_heads_text, 8); 52 | serde_default!(usize, num_key_value_heads, 4); 53 | 54 | #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] 55 | pub struct TextConfig { 56 | #[serde(default = "attention_bias")] 57 | pub attention_bias: bool, 58 | #[serde(default = "head_dim")] 59 | pub head_dim: usize, 60 | #[serde(default = "hidden_activation")] 61 | pub hidden_activation: Activation, 62 | pub hidden_size: usize, 63 | pub intermediate_size: usize, 64 | #[serde(default = "num_attention_heads_text")] 65 | pub num_attention_heads: usize, 66 | pub num_hidden_layers: usize, 67 | #[serde(default = "num_key_value_heads")] 68 | pub num_key_value_heads: usize, 69 | #[serde(default = "rms_norm_eps")] 70 | pub rms_norm_eps: f64, 71 | #[serde(default = "rope_theta")] 72 | pub rope_theta: f64, 73 | #[serde(default = "vocab_size")] 74 | pub vocab_size: usize, 75 | pub sliding_window: usize, 76 | pub attn_logit_softcapping: Option, 77 | pub final_logit_softcapping: Option, 78 | #[serde(default = "query_pre_attn_scalar")] 79 | pub query_pre_attn_scalar: usize, 80 | #[serde(default = "max_position_embeddings")] 81 | pub max_position_embeddings: usize, 82 | pub quantization_config: Option, 83 | #[serde(default = "tie_word_embeddings")] 84 | pub tie_word_embeddings: bool, 85 | #[serde(default = "rope_local_base_freq")] 86 | pub rope_local_base_freq: f64, 87 | #[serde(default = "sliding_window_pattern")] 88 | pub sliding_window_pattern: usize, 89 | pub rope_scaling: Option>, 90 | pub 
quant: Option, 91 | } 92 | 93 | fn has_vision() -> bool { 94 | true 95 | } 96 | 97 | #[derive(Debug, Clone, serde::Deserialize)] 98 | pub struct Gemma3Config { 99 | pub text_config: TextConfig, 100 | pub vision_config: VisionConfig, 101 | pub image_token_index: usize, 102 | pub mm_tokens_per_image: usize, 103 | pub eos_token_id: Option, 104 | #[serde(default = "has_vision")] 105 | pub has_vision: bool, 106 | } 107 | -------------------------------------------------------------------------------- /docs/get_started.md: -------------------------------------------------------------------------------- 1 | # Get Started 2 | 3 | This guide walks through building and running vLLM.rs across CUDA/Metal, different model formats, multi-rank, PD Disaggregation, and OpenAI-compatible APIs. Commands assume repo root and `./run.sh` (wrapper around `cargo build/run`). 4 | 5 | ## 1) Build & features 6 | - **Backends**: `--features cuda[,nccl,graph,flash-attn,flash-context]` or `--features metal`. CPU-only is supported but slow. 7 | - **Quant/accel toggles**: `--fp8-kvcache` (KV in FP8, CUDA), `--flash-context` (Ampere+, long compile), `--context-cache` (session reuse). 8 | - **Python bindings**: add feature `python` when building wheels (`./build.sh --features python`). 9 | 10 | ## 2) Model formats 11 | - **Safetensors (HF layout)**: `--m ` for cached download, or `--w ` for offline weights + configs. 12 | - **GGUF**: `--f `; no configs needed. For safetensors, you may in-situ quantize with `--isq `. 13 | - **Vision-Language** (Qwen3-VL, Gemma3, Mistral3-VL): require image tokens; use `--ui-server` for uploads or send image_url/base64 in the request. 14 | 15 | ## 3) Run patterns (single host) 16 | - **CUDA text model (chat/server)** 17 | ```bash 18 | ./run.sh --release --features cuda -- --server \ 19 | --m Qwen/Qwen2.5-7B-Instruct --max-model-len 131072 \ 20 | --kv-fraction 0.6 --context-cache --ui-server 21 | ``` 22 | - **Metal (Mac) text model** 23 | ```bash 24 | ./run.sh --release --features metal -- --server \ 25 | --m meta-llama/Llama-3-8b --max-model-len 32768 --ui-server 26 | ``` 27 | - **GGUF quantized** 28 | ```bash 29 | ./run.sh --release --features cuda -- --server \ 30 | --f /path/model-Q4_K_M.gguf --max-model-len 65536 --context-cache 31 | ``` 32 | - **Embeddings** (same server; OpenAI `/v1/embeddings`) 33 | ```bash 34 | ./run.sh --release --features cuda -- --server \ 35 | --m Qwen/Qwen2.5-7B-Instruct --context-cache 36 | # curl -d '{"input":"hello","embedding_type":"mean"}' http://localhost:8000/v1/embeddings 37 | ``` 38 | - **Multimodal** 39 | ```bash 40 | ./run.sh --release --features cuda -- --server \ 41 | --m Qwen/Qwen3-VL-8B-Instruct --ui-server --context-cache 42 | ``` 43 | 44 | Common runtime knobs: `--max-model-len`, `--max-num-seqs`, `--kv-fraction` (CUDA KV share), `--cpu-mem-fold` (CPU swap ratio), `--port`, `--fp8-kvcache`, `--context-cache`, `--ui-server`, `--batch` (perf test). 45 | 46 | ## 4) Multi-rank (single node) 47 | - **NCCL multi-GPU** 48 | ```bash 49 | ./run.sh --release --features cuda,nccl -- --server \ 50 | --m Qwen/Qwen3-30B-A3B-Instruct-2507 --d 0,1 --max-num-seqs 2 --kv-fraction 0.5 51 | ``` 52 | - **Graph capture (Ampere+)**: add `--features graph,flash-context` for fastest long-context prefill/decoding (compilation time increases). 
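- **Combined flags (sketch)**: the features and knobs above compose; the following invocation is an assumed example for a dual-GPU Ampere+ host, not a verified configuration, so adjust features, model, and KV settings to your hardware.
```bash
# assumed combination of flags documented in this guide
./run.sh --release --features cuda,nccl,graph,flash-context -- --server \
  --m Qwen/Qwen3-30B-A3B-Instruct-2507 --d 0,1 --max-num-seqs 2 --kv-fraction 0.5
```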
53 | 54 | ## 5) PD Disaggregation (prefill/decoding split) 55 | - **PD server (prefill host, usually memory-rich)** 56 | ```bash 57 | ./run.sh --release --features cuda -- --server --pd-server --port 8000 \ 58 | --m Qwen/Qwen3-30B-A3B-Instruct-2507 --context-cache 59 | ``` 60 | - **PD client (decode host)** 61 | ```bash 62 | ./run.sh --release --features cuda -- --server --pd-client --pd-url 0.0.0.0:8000 \ 63 | --m Qwen/Qwen3-30B-A3B-Instruct-2507 --context-cache 64 | ``` 65 | - Same weights/config on both ends; Local IPC used automatically on same node CUDA, TCP when `--pd-url` is set. Monitor logs for transfer and swap events. 66 | 67 | ## 6) Context cache 68 | - Enable with `--context-cache` (CUDA/Metal). Reuse a `session_id` across turns to skip re-prefill. 69 | First turn: `{"messages":[...],"session_id":"chat-123"}`; follow-up: send only new message with same `session_id`. 70 | - Tune `--max-model-len`, `--kv-fraction`, `--cpu-mem-fold`; avoid overcommitting KV or cache will swap/evict. 71 | 72 | ## 7) APIs (OpenAI-style) 73 | - Chat: `POST /v1/chat/completions` (supports `stream=true`, images for VL models). 74 | - Embeddings: `POST /v1/embeddings` (`embedding_type=mean|last`, `encoding_format=float|base64`). 75 | - Models: `GET /v1/models`; Usage: `GET /v1/usage?session_id=...`. 76 | - UI: add `--ui-server` to expose the built-in web UI on port 8001. 77 | 78 | ## 8) Troubleshooting & tuning 79 | - Use `--log` to view loading/progress; watch for “swap” messages (KV pressure). 80 | - If OOM on Metal, lower `--max-model-len` and batch; on CUDA, reduce `--kv-fraction` or `--max-num-seqs`. 81 | - For GGUF/ISQ, keep `--max-num-seqs` moderate to avoid bandwidth bottlenecks; consider `--fp8-kvcache` only on Ampere+. 82 | -------------------------------------------------------------------------------- /example/rust-demo-tools/src/main.rs: -------------------------------------------------------------------------------- 1 | //! Rust Tool Calling Example for vLLM.rs 2 | //! 3 | //! This example demonstrates: 4 | //! 1. Defining tools using the builder pattern 5 | //! 2. Parsing tool calls from model output 6 | //! 3. Handling tool results 7 | 8 | use vllm_rs::tools::parser::ToolParser; 9 | use vllm_rs::tools::schema::SchemaBuilder; 10 | use vllm_rs::tools::{Tool, ToolCall, ToolFormat, ToolResult}; 11 | 12 | fn main() -> anyhow::Result<()> { 13 | println!("🛠️ vLLM.rs Tool Calling Demo (Rust API)\n"); 14 | 15 | // === Part 1: Define Tools === 16 | println!("=== Part 1: Defining Tools ===\n"); 17 | 18 | // Using builder pattern 19 | let weather_tool = Tool::function("get_weather", "Get the current weather for a location") 20 | .param("location", "string", "The city name", true) 21 | .param( 22 | "unit", 23 | "string", 24 | "Temperature unit (celsius/fahrenheit)", 25 | false, 26 | ) 27 | .build(); 28 | 29 | // Using schema builder for more complex schemas 30 | let search_schema = SchemaBuilder::object() 31 | .description("Web search parameters") 32 | .string_prop("query", "Search query", true) 33 | .integer_prop("max_results", "Maximum results", false) 34 | .build(); 35 | 36 | let search_tool = Tool::function("search_web", "Search the web for information") 37 | .parameters_schema(search_schema) 38 | .build(); 39 | 40 | println!( 41 | "Weather Tool: {}", 42 | serde_json::to_string_pretty(&weather_tool)? 43 | ); 44 | println!( 45 | "\nSearch Tool: {}", 46 | serde_json::to_string_pretty(&search_tool)? 
47 | ); 48 | 49 | // === Part 2: Format Tools for Prompts === 50 | println!("\n=== Part 2: Tool Prompt Formatting ===\n"); 51 | 52 | let tools = vec![weather_tool.clone(), search_tool.clone()]; 53 | 54 | // Qwen format 55 | let qwen_prompt = ToolFormat::Qwen.format_tools(&tools); 56 | println!("Qwen Format:\n{}\n", qwen_prompt); 57 | 58 | // === Part 3: Parse Tool Calls from Model Output === 59 | println!("=== Part 3: Parsing Tool Calls ===\n"); 60 | 61 | let parser = ToolParser::new(); 62 | 63 | // Simulate model output with tool call 64 | let model_outputs = vec![ 65 | // Qwen format 66 | r#"I'll check the weather for you. 67 | 68 | {"name": "get_weather", "arguments": {"location": "Tokyo", "unit": "celsius"}} 69 | "#, 70 | // JSON format 71 | r#"Let me search for that. {"name": "search_web", "arguments": {"query": "Rust programming"}}"#, 72 | // Code block format 73 | r#"Here's my search: 74 | ```json 75 | {"name": "get_weather", "arguments": {"location": "London"}} 76 | ```"#, 77 | ]; 78 | 79 | for (i, output) in model_outputs.iter().enumerate() { 80 | println!("Model Output {}:", i + 1); 81 | println!(" Input: {:?}", output.lines().next().unwrap_or("")); 82 | 83 | let calls = parser.parse(output); 84 | if calls.is_empty() { 85 | println!(" No tool calls detected\n"); 86 | } else { 87 | for call in &calls { 88 | println!( 89 | " Parsed: {} with {}", 90 | call.function.name, call.function.arguments 91 | ); 92 | } 93 | println!(); 94 | } 95 | } 96 | 97 | // === Part 4: Handle Tool Results === 98 | println!("=== Part 4: Tool Results ===\n"); 99 | 100 | // Simulate executing a tool and creating result 101 | let call = ToolCall::new("call_001", "get_weather", r#"{"location": "Tokyo"}"#); 102 | 103 | // Execute tool (simulated) 104 | let weather_data = serde_json::json!({ 105 | "location": "Tokyo", 106 | "temperature": 22, 107 | "unit": "celsius", 108 | "condition": "sunny" 109 | }); 110 | 111 | // Create result 112 | let result = ToolResult::success(&call.id, weather_data.to_string()); 113 | println!("Tool Call ID: {}", call.id); 114 | println!("Tool Result: {}", result.content); 115 | 116 | // Error example 117 | let error_result = ToolResult::error("call_002", "API rate limit exceeded"); 118 | println!( 119 | "\nError Result: {} (is_error: {:?})", 120 | error_result.content, error_result.is_error 121 | ); 122 | 123 | println!("\n✅ Demo complete!"); 124 | 125 | Ok(()) 126 | } 127 | -------------------------------------------------------------------------------- /example/server.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import warnings 4 | from vllm_rs import Engine, EngineConfig, GenerationConfig, PdConfig, PdMethod, PdRole 5 | 6 | def parse_args(): 7 | parser = argparse.ArgumentParser(description="Run Chat Server") 8 | parser.add_argument("--host", type=str, default="0.0.0.0") 9 | parser.add_argument("--port", type=int, default=8000) 10 | parser.add_argument("--m", help="huggingface model id", type=str, default=None) 11 | parser.add_argument("--w", help="safetensor weight path", type=str, default=None) 12 | parser.add_argument("--f", help="gguf file path or gguf file name when model_id is given", type=str, default=None) 13 | parser.add_argument("--dtype", choices=["f16", "bf16", "f32"], default="bf16") 14 | parser.add_argument("--max-num-seqs", type=int, default=2) 15 | parser.add_argument("--max-model-len", type=int, default=None) 16 | parser.add_argument("--max-tokens", type=int, default=32768) 17 | 
parser.add_argument("--d", type=str, default="0") 18 | parser.add_argument("--isq", type=str, default=None) 19 | parser.add_argument("--temperature", type=float, default=None) 20 | parser.add_argument("--top-p", type=float, default=None) 21 | parser.add_argument("--top-k", type=int, default=None) 22 | parser.add_argument("--frequency-penalty", type=float, default=None) 23 | parser.add_argument("--presence-penalty", type=float, default=None) 24 | parser.add_argument("--context-cache", action="store_true") 25 | parser.add_argument("--fp8-kvcache", action="store_true") 26 | parser.add_argument("--cpu-mem-fold", type=float, default=None) 27 | parser.add_argument("--kv-fraction", type=float, default=None) 28 | parser.add_argument("--pd-server", action="store_true") 29 | parser.add_argument("--pd-client", action="store_true") 30 | parser.add_argument("--pd-url", help="Url like `192.168.1.100:8888` \ 31 | used for TCP/IP communication between PD server and client", type=str, default=None) 32 | parser.add_argument("--ui-server", action="store_true") 33 | 34 | return parser.parse_args() 35 | 36 | def main(): 37 | args = parse_args() 38 | 39 | # limit default max_num_seqs to 1 on MacOs (due to limited gpu memory) 40 | max_num_seqs = 1 if sys.platform == "darwin" else args.max_num_seqs 41 | # max_model_len = 32768 if sys.platform == "darwin" else args.max_model_len 42 | # if args.max_model_len is None: 43 | # if max_num_seqs > 0: 44 | # max_model_len = max_model_len // max_num_seqs 45 | # else: 46 | # max_model_len = args.max_model_len 47 | 48 | generation_cfg = None 49 | if (args.temperature != None and (args.top_p != None or args.top_k != None)) or args.frequency_penalty != None or args.presence_penalty != None: 50 | generation_cfg = GenerationConfig(args.temperature, args.top_p, args.top_k, args.frequency_penalty, args.presence_penalty) 51 | 52 | assert args.m or args.w or args.f, "Must provide model_id or weight_path or weight_file!" 53 | if args.max_model_len != None: 54 | args.max_tokens = args.max_model_len if args.max_tokens > args.max_model_len else args.max_tokens 55 | 56 | assert args.max_model_len == None or args.kv_fraction == None, "You provided both max_model_len and kv_fraction!" 
57 | 58 | pd_config = None 59 | if args.pd_server or args.pd_client: 60 | pd_role = PdRole.Server if args.pd_server else PdRole.Client 61 | pd_method = PdMethod.RemoteTcp if args.pd_url != None else PdMethod.LocalIpc 62 | pd_config = PdConfig(role=pd_role, method=pd_method, url=args.pd_url) 63 | 64 | cfg = EngineConfig( 65 | model_id=args.m, 66 | weight_path=args.w, 67 | weight_file=args.f, 68 | max_num_seqs=max_num_seqs, 69 | max_model_len=args.max_model_len, 70 | max_tokens=args.max_tokens, 71 | isq=args.isq, 72 | device_ids=[int(d) for d in args.d.split(",")], 73 | generation_cfg=generation_cfg, 74 | flash_context=args.context_cache, 75 | fp8_kvcache=args.fp8_kvcache, 76 | server_mode=True, 77 | cpu_mem_fold=args.cpu_mem_fold, 78 | kv_fraction=args.kv_fraction, 79 | pd_config=pd_config, 80 | ) 81 | 82 | engine = Engine(cfg, args.dtype) 83 | 84 | # max_kvcache_tokens = max_model_len * max_num_seqs 85 | # if args.max_model_len is None: 86 | # warnings.warn(f"Warning: max_model_len is not given, default to {max_model_len}, max kvcache tokens {max_kvcache_tokens}.") 87 | engine.start_server(args.port, args.ui_server) # this will block 88 | 89 | 90 | if __name__ == "__main__": 91 | main() 92 | -------------------------------------------------------------------------------- /src/utils/chat_template.rs: -------------------------------------------------------------------------------- 1 | use minijinja::{context, Environment}; 2 | #[cfg(feature = "python")] 3 | use pyo3::pyclass; 4 | 5 | #[cfg(feature = "python")] 6 | #[derive(serde::Serialize, serde::Deserialize, Clone, Debug)] 7 | #[pyclass] 8 | pub struct Message { 9 | #[pyo3(get)] 10 | pub role: String, 11 | #[pyo3(get)] 12 | pub content: String, 13 | pub num_images: usize, 14 | } 15 | 16 | #[cfg(not(feature = "python"))] 17 | #[derive(serde::Serialize, serde::Deserialize, Clone, Debug)] 18 | pub struct Message { 19 | pub role: String, 20 | pub content: String, 21 | pub num_images: usize, 22 | } 23 | 24 | #[cfg(not(feature = "python"))] 25 | impl Message { 26 | pub fn new(role: String, content: String, num_images: usize) -> Self { 27 | Message { 28 | role, 29 | content, 30 | num_images, 31 | } 32 | } 33 | } 34 | 35 | #[derive(thiserror::Error, Debug)] 36 | pub enum ApplyChatTemplateError { 37 | #[error("failed to add chat template")] 38 | AddTemplateError(#[source] minijinja::Error), 39 | #[error("failed to get chat template")] 40 | GetTemplateError(#[source] minijinja::Error), 41 | #[error("failed to render chat template")] 42 | RenderTemplateError(#[source] minijinja::Error), 43 | } 44 | 45 | #[derive(Clone, Debug)] 46 | pub struct ChatTemplate { 47 | system_message: Option, 48 | chat_template: Option, 49 | bos_token: Option, 50 | eos_token: Option, 51 | messages: Vec, 52 | add_generation_prompt: bool, 53 | enable_thinking: bool, 54 | } 55 | 56 | impl ChatTemplate { 57 | pub fn new( 58 | system_message: Option, 59 | chat_template: Option, 60 | bos_token: Option, 61 | eos_token: Option, 62 | prompt: Option, 63 | add_generation_prompt: bool, 64 | enable_thinking: bool, 65 | ) -> Self { 66 | let mut template = ChatTemplate { 67 | system_message: system_message.clone(), 68 | chat_template, 69 | bos_token, 70 | eos_token, 71 | messages: Vec::new(), 72 | add_generation_prompt, 73 | enable_thinking, 74 | }; 75 | if system_message.is_some() { 76 | template.append_message( 77 | "system".to_string(), 78 | template.system_message.clone().unwrap_or_default(), 79 | 0, 80 | ); 81 | } 82 | if let Some(prompt) = prompt { 83 | 
template.append_message("user".to_string(), prompt, 0); 84 | } 85 | template 86 | } 87 | 88 | pub fn append_message(&mut self, role: String, content: String, num_images: usize) { 89 | self.messages.push(Message { 90 | role, 91 | content, 92 | num_images, 93 | }); 94 | } 95 | 96 | pub fn set_messages(&mut self, messages: &Vec) { 97 | self.messages.clear(); 98 | self.messages.extend(messages.clone()); 99 | } 100 | 101 | #[allow(dead_code)] 102 | fn clear_message(&mut self) { 103 | self.messages.clear() 104 | } 105 | 106 | pub fn apply_chat_template(&self, log: bool) -> Result { 107 | if self.chat_template.is_none() { 108 | return Err(ApplyChatTemplateError::GetTemplateError( 109 | minijinja::Error::new(minijinja::ErrorKind::CannotDeserialize, "Not found!"), 110 | )); 111 | } 112 | let mut env = Environment::new(); 113 | env.set_lstrip_blocks(true); 114 | env.set_trim_blocks(true); 115 | env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback); 116 | let template = self.chat_template.as_ref().unwrap(); 117 | let mut template = template.replace("[::-1]", "|reverse"); 118 | if template.find("{{ meta }}").is_some() { 119 | template = template.replace("{%- set meta = message.get(\"metadata\", \"\") %}", ""); 120 | template = template.replace("{{ meta }}", ""); 121 | } 122 | env.add_template("vllm.rs", template.as_str()) 123 | .map_err(ApplyChatTemplateError::AddTemplateError)?; 124 | let template = env 125 | .get_template("vllm.rs") 126 | .map_err(ApplyChatTemplateError::GetTemplateError)?; 127 | 128 | if log { 129 | tracing::info!("messages {:?}", self.messages); 130 | } 131 | template 132 | .render(context! { 133 | messages => self.messages, 134 | add_generation_prompt => self.add_generation_prompt, 135 | bos_token => self.bos_token, 136 | eos_token => self.eos_token, 137 | enable_thinking => self.enable_thinking, 138 | }) 139 | .map_err(ApplyChatTemplateError::RenderTemplateError) 140 | } 141 | } 142 | -------------------------------------------------------------------------------- /vllm_rs.pyi: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Iterator, Tuple, Union, List, Literal, Mapping, Optional, Callable 3 | from enum import Enum 4 | from collections.abc import AsyncGenerator 5 | from typing import Any 6 | 7 | @dataclass 8 | class DType(Enum): 9 | F16 = "f16" 10 | BF16 = "bf16" 11 | F32 = "f32" 12 | 13 | @dataclass 14 | class PdRole(Enum): 15 | Client = 1 16 | PDServer = 2 17 | 18 | @dataclass 19 | class PdMethod(Enum): 20 | LocalIpc = 1 21 | RemoteTcp = 2 22 | 23 | @dataclass 24 | class PdConfig: 25 | role: PdRole 26 | method = PdMethod 27 | url: Optional[str] 28 | 29 | 30 | @dataclass 31 | class GenerationOutput: 32 | seq_id: int 33 | prompt_length: int 34 | prompt_start_time: int 35 | decode_start_time: int 36 | decode_finish_time: int 37 | decoded_length: int 38 | decode_output: str 39 | 40 | @dataclass 41 | class GenerationConfig: 42 | temperature: Optional[float] 43 | top_p: Optional[float] 44 | top_k: Optional[int] 45 | frequency_penalty: Optional[float] 46 | presence_penalty: Optional[float] 47 | 48 | @dataclass 49 | class EngineConfig: 50 | model_id: Optional[str] 51 | weight_path: Optional[str] 52 | weight_file: Optional[str] 53 | hf_token: Optional[str] 54 | hf_token_path: Optional[str] 55 | tokenizer: Optional[str] 56 | tokenizer_config: Optional[str] 57 | num_blocks: Optional[int] 58 | kv_fraction: Optional[float] 59 | cpu_mem_fold: Optional[float] 60 | 
max_num_seqs: Optional[int] 61 | max_model_len: Optional[int] 62 | max_tokens: Optional[int] 63 | max_num_batched_tokens: Optional[int] 64 | isq: Optional[str] 65 | num_shards: Optional[int] 66 | device_id: Optional[int] 67 | generation_cfg: Optional[GenerationConfig] 68 | seed: Optional[int] 69 | flash_context: Optional[bool] 70 | fp8_kvcache: Optional[bool] 71 | server_mode: Optional[bool] 72 | pd_config: Optional[PdConfig] 73 | 74 | @dataclass 75 | class SamplingParams: 76 | temperature: Optional[float] 77 | max_tokens: Optional[int] 78 | ignore_eos: Optional[bool] 79 | top_k: Optional[int] 80 | top_p: Optional[float] 81 | session_id: Optional[str] 82 | frequency_penalty: Optional[float] 83 | presence_penalty: Optional[float] 84 | 85 | @dataclass 86 | class Message: 87 | role: str 88 | content: str 89 | 90 | @dataclass 91 | class StepOutput(Enum): 92 | Token: int 93 | Tokens: List[int] 94 | 95 | @dataclass 96 | class StreamItem: 97 | """ 98 | An item returned by the EngineStream iterator. 99 | Check the `type` attribute to determine how to interpret the `data`. 100 | """ 101 | @property 102 | def datatype(self) -> Literal["TOKEN", "TOKEN_ID", "COMPLETION", "DONE", "ERROR"]: 103 | """The type of the stream item.""" 104 | ... 105 | 106 | @property 107 | def data(self) -> Union[ 108 | str, # For TOKEN or ERROR 109 | int, # For TOKEN_ID 110 | Tuple[int, int, int, int], # For DONE 111 | Tuple[int, int, int, List[int]] # For COMPLETION 112 | ]: 113 | """The data payload of the stream item.""" 114 | ... 115 | 116 | class EngineStream: 117 | finished: bool 118 | seq_id: int 119 | prompt_length: int 120 | cancelled: bool 121 | def cancel(self): ... 122 | def __iter__(self) -> Iterator[StreamItem]: ... 123 | def __next__(self) -> StreamItem: ... 124 | 125 | class Engine: 126 | def __init__(econfig: EngineConfig, dtype: DType) -> Engine: 127 | """ 128 | Create a vllm.rs engine with given engine config and dtype ("f16", "bf16", and "f32") 129 | """ 130 | 131 | def generate_sync(self, 132 | params: List[SamplingParams], 133 | message_list: List[List[Message]], 134 | ) -> List[GenerationOutput]: 135 | """ 136 | Chat completion using given prompts and sampling parameters 137 | """ 138 | def generate_stream( 139 | self, 140 | params: SamplingParams, 141 | messages: List[Message], 142 | ) -> Tuple[int, int, EngineStream]: 143 | """ 144 | Chat streaming using given prompts and sampling parameters. 
145 | 146 | Return: (seq_id, prompt_length, stream) tuples 147 | """ 148 | 149 | def get_num_cached_tokens( 150 | self, 151 | ) -> int: 152 | """ 153 | Call this function when context-cache feature enabled 154 | 155 | Return: total number of context cached for all requests 156 | """ 157 | 158 | def get_available_kv_tokens( 159 | self, 160 | ) -> int: 161 | """ 162 | Return: total number of available kvcache tokens 163 | """ 164 | 165 | def start_server( 166 | self, 167 | port: int, 168 | with_ui_server: bool, 169 | ): 170 | """ 171 | Start the API server with optional start of Chat UI server 172 | """ -------------------------------------------------------------------------------- /src/models/layers/mlp.rs: -------------------------------------------------------------------------------- 1 | use crate::models::layers::distributed::{ 2 | shard, Comm, ReplicatedLinear, TensorParallelColumnLinear, TensorParallelRowLinear, 3 | }; 4 | use crate::models::layers::VarBuilderX; 5 | use crate::utils::config::QuantConfig; 6 | use candle_core::{DType, Result, Tensor}; 7 | use candle_nn::{Activation, Module}; 8 | use std::collections::HashMap; 9 | use std::rc::Rc; 10 | 11 | pub struct MLP { 12 | gate_proj: TensorParallelColumnLinear, 13 | up_proj: TensorParallelColumnLinear, 14 | down_proj: TensorParallelRowLinear, 15 | activation: Activation, 16 | } 17 | 18 | impl MLP { 19 | pub fn new( 20 | vb: VarBuilderX, 21 | comm: Rc, 22 | hidden_size: usize, 23 | intermediate_size: usize, 24 | activation: &Activation, 25 | quant_cfg: &Option, 26 | quant: &Option, 27 | gate_up_merged: bool, 28 | dtype: DType, 29 | suffix: &str, 30 | ) -> Result { 31 | let key_map: HashMap<&str, &str> = [ 32 | ("gate_proj", "ffn_gate"), 33 | ("up_proj", "ffn_up"), 34 | ("gate_up_proj", "ffn_up"), 35 | ("down_proj", "ffn_down"), 36 | ] 37 | .iter() 38 | .cloned() 39 | .collect(); 40 | let is_qvar_builder = vb.is_qvar_builder(); 41 | 42 | let gate_proj = TensorParallelColumnLinear::load_with_shard( 43 | hidden_size, 44 | if gate_up_merged { 45 | intermediate_size * 2 46 | } else { 47 | intermediate_size 48 | }, 49 | false, 50 | if is_qvar_builder { 51 | vb.pp((key_map[if gate_up_merged { 52 | "up_proj" 53 | } else { 54 | "gate_proj" 55 | }] 56 | .to_string() 57 | + suffix) 58 | .as_str()) 59 | } else { 60 | vb.pp(if gate_up_merged { 61 | "gate_up_proj" 62 | } else { 63 | "gate_proj" 64 | }) 65 | }, 66 | if gate_up_merged { 67 | shard(0, comm.rank(), comm.world_size() * 2) 68 | } else { 69 | shard(0, comm.rank(), comm.world_size()) 70 | }, 71 | quant_cfg, 72 | quant, 73 | dtype, 74 | )?; 75 | 76 | let up_proj = TensorParallelColumnLinear::load_with_shard( 77 | hidden_size, 78 | if gate_up_merged { 79 | intermediate_size * 2 80 | } else { 81 | intermediate_size 82 | }, 83 | false, 84 | if is_qvar_builder { 85 | vb.pp((key_map["up_proj"].to_string() + suffix).as_str()) 86 | } else { 87 | vb.pp(if gate_up_merged { 88 | "gate_up_proj" 89 | } else { 90 | "up_proj" 91 | }) 92 | }, 93 | if gate_up_merged { 94 | shard(0, comm.world_size() + comm.rank(), comm.world_size() * 2) 95 | } else { 96 | shard(0, comm.rank(), comm.world_size()) 97 | }, 98 | quant_cfg, 99 | quant, 100 | dtype, 101 | )?; 102 | 103 | let down_proj = TensorParallelRowLinear::load_with_hints( 104 | intermediate_size, 105 | hidden_size, 106 | if is_qvar_builder { 107 | vb.pp((key_map["down_proj"].to_string() + suffix).as_str()) 108 | } else { 109 | vb.pp("down_proj") 110 | }, 111 | comm.clone(), 112 | quant_cfg, 113 | quant, 114 | dtype, 115 | )?; 116 | 117 | Ok(Self { 118 | 
gate_proj, 119 | up_proj, 120 | down_proj, 121 | activation: activation.clone(), 122 | }) 123 | } 124 | 125 | pub fn forward(&self, xs: &Tensor) -> Result { 126 | let gate = self.gate_proj.forward(xs)?; 127 | let up = self.up_proj.forward(xs)?; 128 | self.down_proj 129 | .forward(&(self.activation.forward(&gate)? * up)?) 130 | } 131 | } 132 | 133 | pub struct NaiveMLP { 134 | fc1: ReplicatedLinear, 135 | fc2: ReplicatedLinear, 136 | act: Activation, 137 | } 138 | 139 | impl NaiveMLP { 140 | pub fn new( 141 | vb: VarBuilderX, 142 | hidden_size: usize, 143 | intermediate_size: usize, 144 | bias: bool, 145 | names: &[&str], 146 | hidden_act: Activation, 147 | dtype: DType, 148 | ) -> Result { 149 | let fc1 = ReplicatedLinear::load_b( 150 | hidden_size, 151 | intermediate_size, 152 | bias, 153 | vb.pp(names[0]), 154 | &None, 155 | &None, 156 | dtype, 157 | )?; 158 | 159 | let fc2 = ReplicatedLinear::load_b( 160 | intermediate_size, 161 | hidden_size, 162 | bias, 163 | vb.pp(names[1]), 164 | &None, 165 | &None, 166 | dtype, 167 | )?; 168 | 169 | Ok(Self { 170 | fc1, 171 | fc2, 172 | act: hidden_act, 173 | }) 174 | } 175 | 176 | pub fn forward(&self, xs: &Tensor) -> Result { 177 | let gate_up = self.fc1.forward(xs)?; 178 | let down = self.act.forward(&gate_up)?; 179 | self.fc2.forward(&down) 180 | } 181 | } 182 | -------------------------------------------------------------------------------- /example/tokenize.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Tokenize/Detokenize API Demo for vLLM.rs 4 | 5 | This script demonstrates how to use the /tokenize and /detokenize API endpoints. 6 | Make sure the vllm.rs server is running before executing this script. 7 | 8 | Usage: 9 | python tokenize.py [--url URL] 10 | """ 11 | 12 | import argparse 13 | import requests 14 | import json 15 | 16 | 17 | def tokenize_text(base_url: str, text: str, add_special_tokens: bool = True) -> dict: 18 | """ 19 | Tokenize plain text using the /tokenize endpoint. 20 | 21 | Args: 22 | base_url: Base URL of the vllm.rs server 23 | text: Text to tokenize 24 | add_special_tokens: Whether to add special tokens (default: True) 25 | 26 | Returns: 27 | API response containing tokens, count, and max_model_len 28 | """ 29 | response = requests.post( 30 | f"{base_url}/tokenize", 31 | json={ 32 | "prompt": text, 33 | "add_special_tokens": add_special_tokens 34 | } 35 | ) 36 | response.raise_for_status() 37 | return response.json() 38 | 39 | 40 | def tokenize_messages(base_url: str, messages: list, add_special_tokens: bool = True) -> dict: 41 | """ 42 | Tokenize chat messages using the /tokenize endpoint. 43 | This applies the model's chat template before tokenization. 44 | 45 | Args: 46 | base_url: Base URL of the vllm.rs server 47 | messages: List of message dicts with 'role' and 'content' keys 48 | add_special_tokens: Whether to add special tokens (default: True) 49 | 50 | Returns: 51 | API response containing tokens, count, and max_model_len 52 | """ 53 | response = requests.post( 54 | f"{base_url}/tokenize", 55 | json={ 56 | "messages": messages, 57 | "add_special_tokens": add_special_tokens 58 | } 59 | ) 60 | response.raise_for_status() 61 | return response.json() 62 | 63 | 64 | def detokenize(base_url: str, tokens: list, skip_special_tokens: bool = True) -> dict: 65 | """ 66 | Convert token IDs back to text using the /detokenize endpoint. 
67 | 68 | Args: 69 | base_url: Base URL of the vllm.rs server 70 | tokens: List of token IDs to decode 71 | skip_special_tokens: Whether to skip special tokens in output (default: True) 72 | 73 | Returns: 74 | API response containing the decoded prompt 75 | """ 76 | response = requests.post( 77 | f"{base_url}/detokenize", 78 | json={ 79 | "tokens": tokens, 80 | "skip_special_tokens": skip_special_tokens 81 | } 82 | ) 83 | response.raise_for_status() 84 | return response.json() 85 | 86 | 87 | def main(): 88 | parser = argparse.ArgumentParser(description="Tokenize/Detokenize API Demo") 89 | parser.add_argument("--url", default="http://localhost:8000", help="vllm.rs server URL") 90 | args = parser.parse_args() 91 | 92 | base_url = args.url 93 | print(f"Using server at: {base_url}\n") 94 | 95 | # Example 1: Tokenize plain text 96 | print("=" * 50) 97 | print("Example 1: Tokenize plain text") 98 | print("=" * 50) 99 | text = "Hello, world! How are you today?" 100 | try: 101 | result = tokenize_text(base_url, text) 102 | print(f"Input: {text}") 103 | print(f"Tokens: {result['tokens']}") 104 | print(f"Token count: {result['count']}") 105 | if result.get('max_model_len'): 106 | print(f"Max model length: {result['max_model_len']}") 107 | except requests.exceptions.RequestException as e: 108 | print(f"Error: {e}") 109 | print("Make sure the vllm.rs server is running!") 110 | return 111 | 112 | print() 113 | 114 | # Example 2: Tokenize chat messages (applies chat template) 115 | print("=" * 50) 116 | print("Example 2: Tokenize chat messages") 117 | print("=" * 50) 118 | messages = [ 119 | {"role": "system", "content": "You are a helpful assistant."}, 120 | {"role": "user", "content": "What is 2 + 2?"}, 121 | ] 122 | try: 123 | result = tokenize_messages(base_url, messages) 124 | print(f"Messages: {json.dumps(messages, indent=2)}") 125 | print(f"Token count (with chat template): {result['count']}") 126 | print(f"First 10 tokens: {result['tokens'][:10]}...") 127 | except requests.exceptions.RequestException as e: 128 | print(f"Error: {e}") 129 | 130 | print() 131 | 132 | # Example 3: Detokenize 133 | print("=" * 50) 134 | print("Example 3: Detokenize tokens") 135 | print("=" * 50) 136 | # Use the tokens from Example 1 137 | try: 138 | tokens = tokenize_text(base_url, "Hello!")['tokens'] 139 | print(f"Input tokens: {tokens}") 140 | result = detokenize(base_url, tokens) 141 | print(f"Decoded text: {result['prompt']}") 142 | except requests.exceptions.RequestException as e: 143 | print(f"Error: {e}") 144 | 145 | print() 146 | 147 | # Example 4: Round-trip test 148 | print("=" * 50) 149 | print("Example 4: Round-trip test") 150 | print("=" * 50) 151 | original = "The quick brown fox jumps over the lazy dog." 
152 |     try:
153 |         tokenized = tokenize_text(base_url, original)
154 |         detokenized = detokenize(base_url, tokenized['tokens'])
155 |         print(f"Original: {original}")
156 |         print(f"After round-trip: {detokenized['prompt']}")
157 |         print(f"Match: {original == detokenized['prompt']}")
158 |     except requests.exceptions.RequestException as e:
159 |         print(f"Error: {e}")
160 | 
161 | 
162 | if __name__ == "__main__":
163 |     main()
164 | 
--------------------------------------------------------------------------------
/example/completion.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | import sys
4 | import argparse
5 | import warnings
6 | from vllm_rs import EngineConfig, SamplingParams, Message, GenerationOutput, GenerationConfig, Engine
7 | # Before running this code, first perform maturin build and then install the package in target/wheels
8 | 
9 | 
10 | def current_millis():
11 |     return int(time.time() * 1000)
12 | 
13 | 
14 | def run(args):
15 |     prompts = args.prompts
16 |     if prompts is None:
17 |         if args.batch > 1:
18 |             prompts = ["Please talk about China in more detail."] * args.batch
19 |         else:
20 |             prompts = ["How are you?", "How to make money?"]
21 |         print("⛔️ No prompts given, using default ", prompts)
22 |     else:
23 |         prompts = prompts.split("|")
24 |         if args.batch > 1:
25 |             prompts = [prompts[0]] * args.batch
26 | 
27 |     if args.batch > 1:
28 |         max_num_seqs = args.batch
29 |     elif len(prompts) > 0:
30 |         max_num_seqs = len(prompts)
31 |     else:
32 |         # limit default max_num_seqs to 8 on MacOs (due to limited gpu memory)
33 |         max_num_seqs = 8 if sys.platform == "darwin" else args.max_num_seqs
34 | 
35 |     if args.max_model_len is None:
36 |         max_model_len = 32768 // max_num_seqs
37 |         warnings.warn(f"max_model_len is not given, default to {max_model_len}.")
38 |     else:
39 |         max_model_len = args.max_model_len
40 | 
41 |     generation_cfg = None
42 |     if (args.temperature is not None and (args.top_p is not None or args.top_k is not None)) or args.frequency_penalty is not None or args.presence_penalty is not None:
43 |         generation_cfg = GenerationConfig(args.temperature, args.top_p, args.top_k, args.frequency_penalty, args.presence_penalty)
44 | 
45 |     assert args.m or args.w or args.f, "Must provide model_id or weight_path or weight_file!"
46 | cfg = EngineConfig( 47 | model_id=args.m, 48 | weight_path=args.w, 49 | weight_file=args.f, 50 | max_num_seqs=max_num_seqs, 51 | max_model_len=max_model_len, 52 | max_tokens=max_model_len if args.max_tokens > max_model_len else args.max_tokens, 53 | isq=args.isq, 54 | device_ids=[int(d) for d in args.d.split(",")], 55 | generation_cfg=generation_cfg, 56 | flash_context=args.context_cache, 57 | fp8_kvcache=args.fp8_kvcache, 58 | server_mode=False, 59 | cpu_mem_fold=args.cpu_mem_fold, 60 | kv_fraction=args.kv_fraction, 61 | ) 62 | 63 | 64 | engine = Engine(cfg, "bf16") 65 | 66 | sampling_params = [] 67 | params = SamplingParams() 68 | message_list = [] 69 | for i in range(len(prompts)): 70 | message_list.append([Message("user", prompts[i])]) 71 | sampling_params.append(params) 72 | 73 | print("Start inference with", len(prompts), "prompts") 74 | outputs: GenerationOutput = engine.generate_sync(sampling_params, message_list) 75 | outputs.sort(key=lambda o: o.seq_id) 76 | 77 | decode_time_taken = 0.0 78 | prompt_time_taken = 0.0 79 | total_decoded_tokens = 0 80 | total_prompt_tokens = 0 81 | 82 | for i, output in enumerate(outputs): 83 | if args.batch == 1: 84 | print(f"\n[Prompt {i + 1}]") 85 | print(f"Prompt: {prompts[i]}") 86 | print(f"Response: {output.decode_output}") 87 | 88 | total_prompt_tokens += output.prompt_length 89 | total_decoded_tokens += output.decoded_length 90 | 91 | prompt_latency = (output.decode_start_time - output.prompt_start_time) / 1000.0 92 | prompt_time_taken = max(prompt_time_taken, prompt_latency) 93 | 94 | decode_latency = (output.decode_finish_time - output.decode_start_time) / 1000.0 95 | decode_time_taken = max(decode_time_taken, decode_latency) 96 | 97 | print("\n--- Performance Metrics ---") 98 | print( 99 | f"⏱️ Prompt tokens: {total_prompt_tokens} in {prompt_time_taken:.2f}s " 100 | f"({total_prompt_tokens / max(prompt_time_taken, 0.001):.2f} tokens/s)" 101 | ) 102 | print( 103 | f"⏱️ Decoded tokens: {total_decoded_tokens} in {decode_time_taken:.2f}s " 104 | f"({total_decoded_tokens / max(decode_time_taken, 0.001):.2f} tokens/s)" 105 | ) 106 | 107 | if __name__ == "__main__": 108 | parser = argparse.ArgumentParser(description="vllm.rs Python CLI") 109 | parser.add_argument("--m", help="huggingface model id", type=str, default=None) 110 | parser.add_argument("--w", help="safetensor weight path", type=str, default=None) 111 | parser.add_argument("--f", help="gguf file path or gguf file name when model_id is given", type=str, default=None) 112 | parser.add_argument("--prompts", type=str, 113 | help="Use '|' to separate multiple prompts") 114 | parser.add_argument("--d", type=str, default="0") 115 | parser.add_argument("--max-num-seqs", type=int, default=1) 116 | parser.add_argument("--max-model-len", type=int, default=None) 117 | parser.add_argument("--max-tokens", type=int, default=4096) 118 | parser.add_argument("--batch", type=int, default=1) 119 | parser.add_argument("--isq", type=str, default=None) 120 | parser.add_argument("--temperature", type=float, default=None) 121 | parser.add_argument("--top-p", type=float, default=None) 122 | parser.add_argument("--top-k", type=int, default=None) 123 | parser.add_argument("--frequency-penalty", type=float, default=None) 124 | parser.add_argument("--presence-penalty", type=float, default=None) 125 | parser.add_argument("--context-cache", action="store_true") 126 | parser.add_argument("--fp8-kvcache", action="store_true") 127 | parser.add_argument("--cpu-mem-fold", type=float, default=None) 128 | 
parser.add_argument("--kv-fraction", type=float, default=None) 129 | 130 | args = parser.parse_args() 131 | if not os.path.exists(args.w): 132 | print("⛔️ Model path is not provided (--w)!") 133 | else: 134 | run(args) 135 | -------------------------------------------------------------------------------- /example/rust-demo-tokenize/src/main.rs: -------------------------------------------------------------------------------- 1 | //! Tokenize/Detokenize API Demo for vLLM.rs 2 | //! 3 | //! This demonstrates how to use the /tokenize and /detokenize endpoints. 4 | //! Make sure the vllm.rs server is running before executing this. 5 | //! 6 | //! Usage: 7 | //! cargo run [-- --url http://localhost:8000] 8 | 9 | use reqwest::blocking::Client; 10 | use serde::{Deserialize, Serialize}; 11 | use std::env; 12 | 13 | // === Request/Response Types === 14 | 15 | #[derive(Serialize)] 16 | struct TokenizeTextRequest { 17 | prompt: String, 18 | #[serde(skip_serializing_if = "Option::is_none")] 19 | add_special_tokens: Option, 20 | } 21 | 22 | #[derive(Serialize)] 23 | struct TokenizeMessagesRequest { 24 | messages: Vec, 25 | #[serde(skip_serializing_if = "Option::is_none")] 26 | add_special_tokens: Option, 27 | } 28 | 29 | #[derive(Serialize)] 30 | struct Message { 31 | role: String, 32 | content: String, 33 | } 34 | 35 | #[derive(Deserialize, Debug)] 36 | struct TokenizeResponse { 37 | tokens: Vec, 38 | count: usize, 39 | #[serde(default)] 40 | max_model_len: Option, 41 | } 42 | 43 | #[derive(Serialize)] 44 | struct DetokenizeRequest { 45 | tokens: Vec, 46 | #[serde(skip_serializing_if = "Option::is_none")] 47 | skip_special_tokens: Option, 48 | } 49 | 50 | #[derive(Deserialize, Debug)] 51 | struct DetokenizeResponse { 52 | prompt: String, 53 | } 54 | 55 | // === API Functions === 56 | 57 | fn tokenize_text(client: &Client, base_url: &str, text: &str) -> Result> { 58 | let request = TokenizeTextRequest { 59 | prompt: text.to_string(), 60 | add_special_tokens: Some(true), 61 | }; 62 | 63 | let response = client 64 | .post(format!("{}/tokenize", base_url)) 65 | .json(&request) 66 | .send()? 67 | .json::()?; 68 | 69 | Ok(response) 70 | } 71 | 72 | fn tokenize_messages(client: &Client, base_url: &str, messages: Vec) -> Result> { 73 | let request = TokenizeMessagesRequest { 74 | messages, 75 | add_special_tokens: Some(true), 76 | }; 77 | 78 | let response = client 79 | .post(format!("{}/tokenize", base_url)) 80 | .json(&request) 81 | .send()? 82 | .json::()?; 83 | 84 | Ok(response) 85 | } 86 | 87 | fn detokenize(client: &Client, base_url: &str, tokens: Vec) -> Result> { 88 | let request = DetokenizeRequest { 89 | tokens, 90 | skip_special_tokens: Some(true), 91 | }; 92 | 93 | let response = client 94 | .post(format!("{}/detokenize", base_url)) 95 | .json(&request) 96 | .send()? 97 | .json::()?; 98 | 99 | Ok(response) 100 | } 101 | 102 | fn main() { 103 | // Parse command line arguments 104 | let args: Vec = env::args().collect(); 105 | let base_url = if args.len() > 2 && args[1] == "--url" { 106 | args[2].clone() 107 | } else { 108 | "http://localhost:8000".to_string() 109 | }; 110 | 111 | println!("Using server at: {}\n", base_url); 112 | 113 | let client = Client::new(); 114 | 115 | // Example 1: Tokenize plain text 116 | println!("{}", "=".repeat(50)); 117 | println!("Example 1: Tokenize plain text"); 118 | println!("{}", "=".repeat(50)); 119 | 120 | let text = "Hello, world! 
How are you today?"; 121 | match tokenize_text(&client, &base_url, text) { 122 | Ok(result) => { 123 | println!("Input: {}", text); 124 | println!("Tokens: {:?}", result.tokens); 125 | println!("Token count: {}", result.count); 126 | if let Some(max_len) = result.max_model_len { 127 | println!("Max model length: {}", max_len); 128 | } 129 | } 130 | Err(e) => { 131 | eprintln!("Error: {}", e); 132 | eprintln!("Make sure the vllm.rs server is running!"); 133 | return; 134 | } 135 | } 136 | 137 | println!(); 138 | 139 | // Example 2: Tokenize chat messages 140 | println!("{}", "=".repeat(50)); 141 | println!("Example 2: Tokenize chat messages"); 142 | println!("{}", "=".repeat(50)); 143 | 144 | let messages = vec![ 145 | Message { 146 | role: "system".to_string(), 147 | content: "You are a helpful assistant.".to_string(), 148 | }, 149 | Message { 150 | role: "user".to_string(), 151 | content: "What is 2 + 2?".to_string(), 152 | }, 153 | ]; 154 | 155 | match tokenize_messages(&client, &base_url, messages) { 156 | Ok(result) => { 157 | println!("Token count (with chat template): {}", result.count); 158 | println!("First 10 tokens: {:?}...", &result.tokens[..result.tokens.len().min(10)]); 159 | } 160 | Err(e) => { 161 | eprintln!("Error: {}", e); 162 | } 163 | } 164 | 165 | println!(); 166 | 167 | // Example 3: Detokenize 168 | println!("{}", "=".repeat(50)); 169 | println!("Example 3: Detokenize tokens"); 170 | println!("{}", "=".repeat(50)); 171 | 172 | if let Ok(tokenized) = tokenize_text(&client, &base_url, "Hello!") { 173 | println!("Input tokens: {:?}", tokenized.tokens); 174 | match detokenize(&client, &base_url, tokenized.tokens) { 175 | Ok(result) => { 176 | println!("Decoded text: {}", result.prompt); 177 | } 178 | Err(e) => { 179 | eprintln!("Error: {}", e); 180 | } 181 | } 182 | } 183 | 184 | println!(); 185 | 186 | // Example 4: Round-trip test 187 | println!("{}", "=".repeat(50)); 188 | println!("Example 4: Round-trip test"); 189 | println!("{}", "=".repeat(50)); 190 | 191 | let original = "The quick brown fox jumps over the lazy dog."; 192 | if let Ok(tokenized) = tokenize_text(&client, &base_url, original) { 193 | if let Ok(detokenized) = detokenize(&client, &base_url, tokenized.tokens) { 194 | println!("Original: {}", original); 195 | println!("After round-trip: {}", detokenized.prompt); 196 | println!("Match: {}", original == detokenized.prompt); 197 | } 198 | } 199 | } 200 | -------------------------------------------------------------------------------- /src/models/qwen3_vl/input.rs: -------------------------------------------------------------------------------- 1 | use crate::utils::image::{to_tensor, ImageProcessConfig, ImageProcessTrait, ToFilter}; 2 | use crate::utils::image::{IMAGE_PLACEHOLDER, PLACEHOLDER}; 3 | use candle_core::{Result, Tensor}; 4 | use image::{DynamicImage, GenericImageView}; 5 | /// Qwen3-VL Image + Prompt Processor 6 | #[derive(Clone)] 7 | pub struct Qwen3VLImageProcessor { 8 | pub cfg: ImageProcessConfig, 9 | pub patch_size: usize, 10 | pub merge_size: usize, 11 | pub temporal_patch_size: usize, 12 | pub min_pixels: usize, 13 | pub max_pixels: usize, 14 | pub fixed_width: Option, 15 | pub fixed_height: Option, 16 | } 17 | 18 | impl Qwen3VLImageProcessor { 19 | #[allow(dead_code)] 20 | pub fn default(cfg: &ImageProcessConfig) -> Self { 21 | let max_row = std::cmp::max(cfg.max_height, cfg.max_width); 22 | Self { 23 | cfg: cfg.clone(), 24 | patch_size: cfg.patch_size, 25 | merge_size: cfg.spatial_merge_size, 26 | temporal_patch_size: 
cfg.temporal_patch_size.unwrap_or(2), 27 | min_pixels: 256 * 256, 28 | max_pixels: max_row * max_row, 29 | fixed_width: None, 30 | fixed_height: None, 31 | } 32 | } 33 | } 34 | 35 | impl Qwen3VLImageProcessor { 36 | pub const DEFAULT_MEAN: [f32; 3] = [0.48145466, 0.4578275, 0.40821073]; 37 | pub const DEFAULT_STD: [f32; 3] = [0.26862954, 0.26130258, 0.27577711]; 38 | 39 | pub const VISION_START: &str = "<|vision_start|>"; 40 | pub const VISION_END: &str = "<|vision_end|>"; 41 | pub const IMAGE_PAD: &str = "<|image_pad|>"; 42 | 43 | /// Resize respecting patch constraints 44 | fn smart_resize(&self, h: usize, w: usize) -> Result<(usize, usize)> { 45 | let factor = self.patch_size * self.merge_size; 46 | 47 | let mut nh = (h as f64 / factor as f64).round() as usize * factor; 48 | let mut nw = (w as f64 / factor as f64).round() as usize * factor; 49 | 50 | let pixels = nh * nw; 51 | 52 | if pixels > self.max_pixels { 53 | let beta = (pixels as f64 / self.max_pixels as f64).sqrt(); 54 | nh = ((nh as f64 / beta) as usize / factor) * factor; 55 | nw = ((nw as f64 / beta) as usize / factor) * factor; 56 | } else if pixels < self.min_pixels { 57 | let beta = (self.min_pixels as f64 / pixels as f64).sqrt(); 58 | nh = ((nh as f64 * beta) as usize / factor) * factor; 59 | nw = ((nw as f64 * beta) as usize / factor) * factor; 60 | } 61 | 62 | Ok((nh, nw)) 63 | } 64 | 65 | fn prepreprocess( 66 | &mut self, 67 | image: &DynamicImage, 68 | target_hw: (u32, u32), 69 | ) -> Result<(Tensor, (usize, usize))> { 70 | let (th, tw) = target_hw; 71 | 72 | let (mut nh, mut nw) = self.smart_resize(th as usize, tw as usize)?; 73 | if let (Some(h), Some(w)) = (self.fixed_height, self.fixed_width) { 74 | nh = h; 75 | nw = w; 76 | } else { 77 | self.fixed_height = Some(nh); 78 | self.fixed_width = Some(nw); 79 | }; 80 | let image = image 81 | .resize_exact(nw as u32, nh as u32, self.cfg.resampling.to_filter()?) 82 | .to_rgb8(); 83 | 84 | let image_mean = Some(self.cfg.image_mean.unwrap_or(Self::DEFAULT_MEAN)); 85 | let image_std = Some(self.cfg.image_std.unwrap_or(Self::DEFAULT_STD)); 86 | 87 | let (mut patches, _) = 88 | to_tensor(&vec![DynamicImage::ImageRgb8(image)], image_mean, image_std)?; 89 | 90 | if patches.dim(0)? == 1 { 91 | patches = patches.repeat((self.temporal_patch_size, 1, 1, 1))?; 92 | } 93 | 94 | let c = patches.dim(1)?; 95 | let grid_t = patches.dim(0)? 
/ self.temporal_patch_size; 96 | let grid_h = nh / self.patch_size; 97 | let grid_w = nw / self.patch_size; 98 | 99 | patches = patches.reshape(&[ 100 | grid_t, 101 | self.temporal_patch_size, 102 | c, 103 | grid_h / self.merge_size, 104 | self.merge_size, 105 | self.patch_size, 106 | grid_w / self.merge_size, 107 | self.merge_size, 108 | self.patch_size, 109 | ])?; 110 | 111 | patches = patches.permute([0, 3, 6, 4, 7, 2, 1, 5, 8])?; 112 | 113 | let patches = patches.reshape(( 114 | grid_t * grid_h * grid_w, 115 | c * self.temporal_patch_size * self.patch_size * self.patch_size, 116 | ))?; 117 | 118 | Ok((patches, (grid_h as usize, grid_w as usize))) 119 | } 120 | } 121 | 122 | impl ImageProcessTrait for Qwen3VLImageProcessor { 123 | /// 🔹 Main entry: processes prompt + images together 124 | fn process_inputs( 125 | &mut self, 126 | prompt: &mut String, 127 | images: &Vec, 128 | ) -> Result<(Tensor, Vec<(usize, usize)>)> { 129 | let (max_w, max_h) = images 130 | .iter() 131 | .map(|i| i.dimensions()) 132 | .fold((0, 0), |(mw, mh), (w, h)| (mw.max(w), mh.max(h))); 133 | 134 | let mut pixel_values = Vec::new(); 135 | let mut grid_thw = Vec::new(); 136 | 137 | for image in images { 138 | let (patches, (h, w)) = self.prepreprocess(image, (max_h, max_w))?; 139 | 140 | pixel_values.push(patches); 141 | grid_thw.push((h, w)); 142 | } 143 | 144 | let pixel_values = Tensor::stack(&pixel_values, 0)?; 145 | 146 | // ===== Prompt expansion logic (preserved & fixed) ===== 147 | let merge_len = self.merge_size * self.merge_size; 148 | let mut image_idx = 0; 149 | let mut replace_strings = Vec::new(); 150 | while prompt.contains(IMAGE_PLACEHOLDER) { 151 | let grid = grid_thw[image_idx]; 152 | let num_patches: usize = (grid.0 * grid.1) as usize / merge_len; 153 | let mut replace_tokens = vec![Self::VISION_START]; 154 | replace_tokens.extend(vec![Self::IMAGE_PAD; num_patches]); 155 | replace_tokens.push(Self::VISION_END); 156 | 157 | replace_strings.push(replace_tokens.join("")); 158 | *prompt = prompt.replace(IMAGE_PLACEHOLDER, PLACEHOLDER); 159 | image_idx += 1; 160 | } 161 | 162 | while prompt.contains(PLACEHOLDER) { 163 | let replace_str = replace_strings.pop().unwrap(); 164 | *prompt = prompt.replace(PLACEHOLDER, &replace_str); 165 | } 166 | 167 | Ok((pixel_values, grid_thw)) 168 | } 169 | } 170 | -------------------------------------------------------------------------------- /src/utils/command.rs: -------------------------------------------------------------------------------- 1 | use crate::runner::MessageType; 2 | use bincode; 3 | use interprocess::local_socket::traits::{Listener, Stream}; 4 | use interprocess::local_socket::{GenericNamespaced, Name, ToNsName}; 5 | use interprocess::local_socket::{ListenerOptions, Stream as LocalStream}; 6 | use std::fmt; 7 | use std::io::Read; 8 | use std::io::{BufRead, BufReader, Write}; 9 | 10 | pub struct CommandManager { 11 | daemon_streams: Option>, 12 | main_stream: Option, 13 | } 14 | 15 | impl fmt::Debug for CommandManager { 16 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 17 | f.debug_struct("CommandManager") 18 | .field("daemon_streams", &self.daemon_streams) 19 | .field("main_stream", &self.main_stream) 20 | .finish() 21 | } 22 | } 23 | 24 | impl CommandManager { 25 | pub fn ipc_default_name() -> anyhow::Result<&'static str> { 26 | Ok("vllm_rs_daemon") 27 | } 28 | 29 | pub fn ipc_command_name(command_name: &str) -> anyhow::Result { 30 | let printname = format!("command_{}", command_name); 31 | Ok(printname) 32 | } 33 | 34 | pub fn 
to_channel_name(name: &str) -> anyhow::Result<Name<'static>> {
35 |         let printname = format!("{}.sock", name);
36 |         Ok(printname.to_ns_name::<GenericNamespaced>()?)
37 |     }
38 | 
39 |     //inter-node communication
40 |     pub fn send_local(
41 |         streams: &mut Vec<LocalStream>,
42 |         message: &MessageType,
43 |     ) -> std::io::Result<()> {
44 |         let serialized = bincode::serialize(message).expect("Serialization failed");
45 |         for stream in streams.iter_mut() {
46 |             stream.write_all(&(serialized.len() as u32).to_le_bytes())?;
47 |             stream.write_all(&serialized)?;
48 |             stream.flush()?; // Ensure data is sent immediately
49 |             // Wait for acknowledgment
50 |             let mut ack_buf = [0u8; 1];
51 |             if let Err(e) = stream.read_exact(&mut ack_buf) {
52 |                 crate::log_info!(
53 |                     "Timeout waiting for acknowledgment from subprocess: {:?}",
54 |                     e
55 |                 );
56 |             } else if ack_buf[0] != 1 {
57 |                 crate::log_info!("Unexpected acknowledgment value from subprocess");
58 |             }
59 |         }
60 |         Ok(())
61 |     }
62 | 
63 |     pub fn send_message(&mut self, message: &MessageType) -> std::io::Result<()> {
64 |         assert!(self.daemon_streams.is_some(), "No daemon process found!");
65 |         let streams = self.daemon_streams.as_mut().unwrap();
66 |         CommandManager::send_local(streams, message)
67 |     }
68 | 
69 |     pub fn receive_local(stream: &mut LocalStream) -> std::io::Result<MessageType> {
70 |         let mut length_buf = [0u8; 4];
71 |         stream.read_exact(&mut length_buf)?;
72 |         let length = u32::from_le_bytes(length_buf) as usize;
73 | 
74 |         let mut serialized = vec![0u8; length];
75 |         stream.read_exact(&mut serialized)?;
76 |         let message: MessageType =
77 |             bincode::deserialize(&serialized).expect("Deserialization failed");
78 |         // Send acknowledgment
79 |         stream.write_all(&[1])?;
80 |         stream.flush()?;
81 |         Ok(message)
82 |     }
83 | 
84 |     pub fn receive_message(&mut self) -> std::io::Result<MessageType> {
85 |         assert!(
86 |             self.main_stream.is_some(),
87 |             "not connected to the main process!"
88 | ); 89 | CommandManager::receive_local(self.main_stream.as_mut().unwrap()) 90 | } 91 | 92 | pub fn new_command( 93 | command_name: &str, 94 | num_subprocess: Option, 95 | is_daemon: bool, 96 | ) -> std::io::Result { 97 | let name = CommandManager::ipc_command_name(command_name).unwrap(); 98 | CommandManager::new_channel(&name.as_str(), true, num_subprocess, is_daemon) 99 | } 100 | 101 | pub fn new_channel( 102 | channel_name: &str, 103 | is_command: bool, 104 | num_subprocess: Option, 105 | is_daemon: bool, 106 | ) -> std::io::Result { 107 | let sock_name = Self::to_channel_name(channel_name).unwrap(); 108 | if is_daemon { 109 | crate::log_info!( 110 | "connect to main process' {} channel!", 111 | if is_command { "command" } else { "data" } 112 | ); 113 | let mut stream = LocalStream::connect(sock_name)?; 114 | stream.write_all(b"ready\n")?; 115 | crate::log_warn!( 116 | "connected to the main process' {} channel!", 117 | if is_command { "command" } else { "data" } 118 | ); 119 | Ok(Self { 120 | daemon_streams: None, 121 | main_stream: Some(stream), 122 | }) 123 | } else { 124 | crate::log_info!( 125 | "build {} channel for the main process!", 126 | if is_command { "command" } else { "data" } 127 | ); 128 | let num_subprocess = num_subprocess.unwrap(); 129 | let listener = ListenerOptions::new().name(sock_name).create_sync()?; 130 | let mut streams = Vec::with_capacity(num_subprocess); 131 | for _ in 0..num_subprocess { 132 | let stream = listener.accept()?; 133 | crate::log_info!( 134 | "accept one daemon process in {} channel!", 135 | if is_command { "command" } else { "data" } 136 | ); 137 | streams.push(stream); 138 | } 139 | 140 | for stream in streams.iter_mut() { 141 | let mut reader = BufReader::new(stream); 142 | let mut message = String::new(); 143 | reader.read_line(&mut message)?; 144 | if message.trim() == "ready" { 145 | crate::log_info!( 146 | "one daemon process connected to the {} channel!", 147 | if is_command { "command" } else { "data" } 148 | ); 149 | } 150 | } 151 | crate::log_warn!( 152 | "{} channel is built!", 153 | if is_command { "command" } else { "data" } 154 | ); 155 | Ok(Self { 156 | daemon_streams: Some(streams), 157 | main_stream: None, 158 | }) 159 | } 160 | } 161 | 162 | pub fn heartbeat(&mut self, is_daemon: bool) -> std::io::Result<()> { 163 | if is_daemon { 164 | match CommandManager::receive_local(self.main_stream.as_mut().unwrap()) { 165 | Ok(MessageType::Heartbeat) => Ok(()), 166 | Err(e) => Err(e), 167 | _ => Ok(()), 168 | } 169 | } else { 170 | let streams = self.daemon_streams.as_mut().unwrap(); 171 | CommandManager::send_local(streams, &MessageType::Heartbeat) 172 | } 173 | } 174 | } 175 | -------------------------------------------------------------------------------- /src/core/sequence.rs: -------------------------------------------------------------------------------- 1 | // src/core/sequence.rs 2 | use crate::utils::config::SamplingParams; 3 | use crate::utils::image::ImageData; 4 | use serde::{Deserialize, Serialize}; 5 | use std::fmt; 6 | use std::time::{SystemTime, UNIX_EPOCH}; 7 | #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone, Copy)] 8 | pub enum SequenceStatus { 9 | Waiting, 10 | Running, 11 | Finished, 12 | Cached, //Finished but resources not freed 13 | Swapped, //kv cache swapped to CPU memory (may not finished) 14 | FinishSwapped, //Finished and kv cache swapped to CPU memory 15 | } 16 | 17 | impl fmt::Display for SequenceStatus { 18 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 19 | let s = match 
self { 20 | SequenceStatus::Waiting => "Waiting", 21 | SequenceStatus::Running => "Running", 22 | SequenceStatus::Finished => "Finished", 23 | SequenceStatus::Cached => "Cached", 24 | SequenceStatus::Swapped => "Swapped", 25 | SequenceStatus::FinishSwapped => "FinishSwapped", 26 | }; 27 | write!(f, "{}", s) 28 | } 29 | } 30 | 31 | #[derive(Serialize, Deserialize, Debug, Clone)] 32 | pub struct Sequence { 33 | pub id: usize, 34 | pub created_time: usize, 35 | pub swapped_time: Option, 36 | pub status: SequenceStatus, 37 | pub token_ids: Vec, 38 | pub output_ids: Vec, 39 | pub block_table: Vec, 40 | pub num_cached_tokens: usize, 41 | pub last_token: u32, 42 | pub block_size: usize, 43 | pub sampling_params: SamplingParams, 44 | pub pd_first_token: Option, // the first token generated by PD server 45 | pub images: Option, // Multimodel image data 46 | /// Session ID generated for tool call continuation (set by scheduler when tool call detected) 47 | pub tool_call_session_id: Option, 48 | } 49 | 50 | #[derive(Serialize, Deserialize, Debug, Clone)] 51 | pub struct DecodeSequence { 52 | pub id: usize, 53 | pub last_token: u32, 54 | pub len: usize, 55 | pub last_block_tokens: usize, 56 | pub block_table_last: u32, 57 | pub block_tables: Vec, 58 | } 59 | 60 | impl DecodeSequence { 61 | pub fn new(sequence: &Sequence) -> Self { 62 | let last_token = sequence.last_token; 63 | let len = sequence.len(); 64 | let last_block_tokens = sequence.last_block_num_tokens(); 65 | let block_table_last = *sequence.block_table.last().unwrap(); 66 | DecodeSequence { 67 | id: sequence.id, 68 | last_token, 69 | len, 70 | last_block_tokens, 71 | block_table_last, 72 | block_tables: sequence.block_table.clone(), 73 | } 74 | } 75 | 76 | pub fn tokens_len(&self) -> usize { 77 | self.len 78 | } 79 | } 80 | 81 | pub trait ToDecodeInput { 82 | fn last_token(&self) -> u32; 83 | fn len(&self) -> usize; 84 | fn last_block_tokens(&self) -> usize; 85 | fn block_table_last(&self) -> u32; 86 | fn block_table(&self) -> &Vec; 87 | fn id(&self) -> usize; 88 | } 89 | 90 | impl ToDecodeInput for DecodeSequence { 91 | fn last_token(&self) -> u32 { 92 | self.last_token 93 | } 94 | 95 | fn len(&self) -> usize { 96 | self.len 97 | } 98 | 99 | fn last_block_tokens(&self) -> usize { 100 | self.last_block_tokens 101 | } 102 | 103 | fn block_table_last(&self) -> u32 { 104 | self.block_table_last 105 | } 106 | 107 | fn block_table(&self) -> &Vec { 108 | &self.block_tables 109 | } 110 | 111 | fn id(&self) -> usize { 112 | self.id 113 | } 114 | } 115 | 116 | impl ToDecodeInput for &Sequence { 117 | fn last_token(&self) -> u32 { 118 | self.last_token 119 | } 120 | 121 | fn len(&self) -> usize { 122 | self.tokens_len() 123 | } 124 | 125 | fn last_block_tokens(&self) -> usize { 126 | self.last_block_num_tokens() 127 | } 128 | 129 | fn block_table_last(&self) -> u32 { 130 | *self.block_table.last().unwrap() 131 | } 132 | 133 | fn block_table(&self) -> &Vec { 134 | &self.block_table 135 | } 136 | 137 | fn id(&self) -> usize { 138 | self.id 139 | } 140 | } 141 | 142 | impl Sequence { 143 | pub fn new( 144 | token_ids: Vec, 145 | block_size: usize, 146 | sampling_params: SamplingParams, 147 | images: &Option, 148 | image_idx: i32, 149 | ) -> Self { 150 | let images = if let Some(img) = &images { 151 | let mut img = img.clone(); // update the images 152 | img.image_idx = image_idx; 153 | Some(img) 154 | } else { 155 | None 156 | }; 157 | Self { 158 | id: 0, // Will be set by scheduler 159 | created_time: SystemTime::now() 160 | 
.duration_since(UNIX_EPOCH) 161 | .expect("Time went backwards") 162 | .as_millis() as usize, 163 | swapped_time: None, 164 | status: SequenceStatus::Waiting, 165 | token_ids: token_ids.clone(), 166 | output_ids: Vec::new(), 167 | block_table: Vec::new(), 168 | num_cached_tokens: 0, 169 | sampling_params, 170 | block_size, 171 | last_token: *token_ids.last().unwrap_or(&0), 172 | pd_first_token: None, 173 | images, 174 | tool_call_session_id: None, 175 | } 176 | } 177 | 178 | pub fn tokens_len(&self) -> usize { 179 | self.token_ids.len() 180 | } 181 | 182 | pub fn len(&self) -> usize { 183 | self.token_ids.len() 184 | } 185 | 186 | pub fn output_len(&self) -> usize { 187 | self.output_ids.len() 188 | } 189 | 190 | pub fn is_finished(&self) -> bool { 191 | self.status == SequenceStatus::Finished 192 | || self.status == SequenceStatus::Cached 193 | || self.status == SequenceStatus::FinishSwapped 194 | } 195 | 196 | pub fn num_blocks(&self) -> usize { 197 | self.len().div_ceil(self.block_size) 198 | } 199 | 200 | pub fn last_block_num_tokens(&self) -> usize { 201 | self.len() - (self.num_blocks() - 1) * self.block_size 202 | } 203 | 204 | pub fn num_cached_blocks(&self) -> usize { 205 | self.num_cached_tokens / self.block_size 206 | } 207 | 208 | pub fn append_token(&mut self, token: u32) { 209 | self.token_ids.push(token); 210 | self.output_ids.push(token); 211 | self.last_token = token; 212 | } 213 | 214 | pub fn block(&self, index: usize) -> Vec { 215 | let start = index * self.block_size; 216 | let end = (index + 1) * self.block_size; 217 | self.token_ids[start..end.min(self.token_ids.len())].to_vec() 218 | } 219 | 220 | pub fn created_time(&self) -> usize { 221 | self.created_time 222 | } 223 | 224 | pub fn swapped_time(&self) -> Option { 225 | self.swapped_time 226 | } 227 | 228 | pub fn clear_block_table(&mut self) { 229 | self.block_table.clear(); 230 | } 231 | } 232 | -------------------------------------------------------------------------------- /src/core/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod block_manager; 2 | pub mod engine; 3 | pub mod runner; 4 | pub mod scheduler; 5 | pub mod sequence; 6 | #[cfg(feature = "python")] 7 | use pyo3::pyclass; 8 | 9 | #[cfg(feature = "python")] 10 | #[pyclass] 11 | #[derive(Debug, Clone)] 12 | pub struct GenerationOutput { 13 | #[pyo3(get)] 14 | pub seq_id: usize, 15 | #[pyo3(get)] 16 | pub prompt_length: usize, 17 | #[pyo3(get)] 18 | pub prompt_start_time: usize, 19 | #[pyo3(get)] 20 | pub decode_start_time: usize, 21 | #[pyo3(get)] 22 | pub decode_finish_time: usize, 23 | #[pyo3(get)] 24 | pub decoded_length: usize, 25 | #[pyo3(get)] 26 | pub decode_output: String, 27 | } 28 | 29 | #[cfg(not(feature = "python"))] 30 | #[derive(Debug, Clone)] 31 | pub struct GenerationOutput { 32 | pub seq_id: usize, 33 | pub prompt_length: usize, 34 | pub prompt_start_time: usize, 35 | pub decode_start_time: usize, 36 | pub decode_finish_time: usize, 37 | pub decoded_length: usize, 38 | pub decode_output: String, 39 | } 40 | 41 | /// Result from sync collection - either completed normally or paused for tool call 42 | #[derive(Debug, Clone)] 43 | pub enum SyncCollectionResult { 44 | /// Generation completed normally 45 | Completed(GenerationOutput), 46 | /// Generation paused for tool call, needs resume with session_id 47 | ToolCallPause { 48 | session_id: String, 49 | seq_id: usize, 50 | prompt_length: usize, 51 | output_ids: Vec, 52 | prompt_start_time: usize, 53 | decode_start_time: usize, 54 | 
decoded_length: usize, 55 | }, 56 | } 57 | 58 | #[macro_export] 59 | macro_rules! log_info { 60 | ($($arg:tt)*) => { 61 | { 62 | #[cfg(feature = "python")] 63 | { 64 | use colored::Colorize; 65 | let s = format!($($arg)*); 66 | println!("{}", String::from(s).truecolor(211, 211, 211)); 67 | } 68 | #[cfg(not(feature = "python"))] 69 | { 70 | tracing::info!($($arg)*); 71 | } 72 | } 73 | }; 74 | } 75 | 76 | #[macro_export] 77 | macro_rules! log_warn { 78 | ($($arg:tt)*) => { 79 | { 80 | #[cfg(feature = "python")] 81 | { 82 | use colored::Colorize; 83 | let s = format!($($arg)*); 84 | eprintln!("{}", String::from(s).yellow()); 85 | } 86 | #[cfg(not(feature = "python"))] 87 | { 88 | tracing::warn!($($arg)*); 89 | } 90 | } 91 | }; 92 | } 93 | 94 | #[macro_export] 95 | macro_rules! log_error { 96 | ($($arg:tt)*) => { 97 | { 98 | #[cfg(feature = "python")] 99 | { 100 | use colored::Colorize; 101 | let s = format!($($arg)*); 102 | eprintln!("{}", String::from(s).red()); 103 | } 104 | #[cfg(not(feature = "python"))] 105 | { 106 | tracing::error!($($arg)*); 107 | } 108 | } 109 | }; 110 | } 111 | 112 | pub trait DecodeStreamTrait: Send + Sync { 113 | fn step(&mut self, id: u32) -> Option; 114 | } 115 | 116 | struct StreamWithTokenizer 117 | where 118 | M: tokenizers::Model + Send + Sync + 'static, 119 | N: tokenizers::Normalizer + Send + Sync + 'static, 120 | PT: tokenizers::PreTokenizer + Send + Sync + 'static, 121 | PP: tokenizers::PostProcessor + Send + Sync + 'static, 122 | D: tokenizers::Decoder + Send + Sync + 'static, 123 | { 124 | _tokenizer: Box>, 125 | stream: tokenizers::DecodeStream<'static, M, N, PT, PP, D>, 126 | } 127 | 128 | impl DecodeStreamTrait for StreamWithTokenizer 129 | where 130 | M: tokenizers::Model + Send + Sync + 'static, 131 | N: tokenizers::Normalizer + Send + Sync + 'static, 132 | PT: tokenizers::PreTokenizer + Send + Sync + 'static, 133 | PP: tokenizers::PostProcessor + Send + Sync + 'static, 134 | D: tokenizers::Decoder + Send + Sync + 'static, 135 | { 136 | fn step(&mut self, id: u32) -> Option { 137 | self.stream.step(id).ok().flatten() 138 | } 139 | } 140 | 141 | type DecodeStreamType = Box; 142 | 143 | #[macro_export] 144 | macro_rules! build_model { 145 | ($model_type:expr, $vb:expr, $comm:expr, $config:expr, $dtype:expr, $is_rope_i:expr, $device:expr, $reporter:expr, 146 | { $( $variant:ident => $ctor:ident ),+ $(,)? } 147 | ) => {{ 148 | match $model_type { 149 | $( ModelType::$variant => Ok::(Model::$variant(Arc::new($ctor::new( 150 | $vb, 151 | $comm.clone(), 152 | $config, 153 | $dtype, 154 | $is_rope_i, 155 | $device, 156 | Arc::clone(&$reporter), 157 | )?))), )+ 158 | _ => { 159 | candle_core::bail!("Unsupported model type: {:?}", $model_type); 160 | } 161 | } 162 | }}; 163 | } 164 | 165 | #[macro_export] 166 | macro_rules! model_call { 167 | ($model:expr, $method:ident, 168 | ($input_ids:expr, $positions:expr, $kv:expr, $input_metadata:expr), 169 | { $( $variant:ident => $extra:expr ),+ $(,)? } 170 | $(, $fallback:expr )? 171 | ) => {{ 172 | match $model { 173 | $( Model::$variant(model) => model.$method($input_ids, $positions, $kv, $input_metadata, $extra), )+ 174 | $( _ => $fallback, )? 175 | } 176 | }}; 177 | } 178 | 179 | #[cfg(all(feature = "cuda", feature = "graph"))] 180 | #[macro_export] 181 | macro_rules! 
graph_extra_arg { 182 | (EmbedInputs, $embeded_inputs:ident) => { 183 | $embeded_inputs 184 | }; 185 | (NoneArg, $embeded_inputs:ident) => { 186 | None 187 | }; 188 | } 189 | 190 | #[cfg(all(feature = "cuda", feature = "graph"))] 191 | #[macro_export] 192 | macro_rules! graph_wrapper { 193 | ($model:expr, $device:expr, 194 | { $( $variant:ident => $arg:tt ),+ $(,)? } 195 | ) => {{ 196 | match $model { 197 | $( Model::$variant(m) => { 198 | let model_arc = Arc::clone(m); 199 | let closure = move |input_ids: &Tensor, 200 | positions: &Tensor, 201 | kv_caches: Option<&Vec<(Tensor, Tensor)>>, 202 | input_metadata: &InputMetadata, 203 | embeded_inputs: bool| { 204 | model_arc.forward( 205 | input_ids, 206 | positions, 207 | kv_caches, 208 | input_metadata, 209 | crate::graph_extra_arg!($arg, embeded_inputs), 210 | ) 211 | }; 212 | let boxed_closure: Box = Box::new(closure); 213 | CudaGraphWrapper::new(boxed_closure, $device.as_cuda_device()?.clone().into()) 214 | }, )+ 215 | } 216 | }}; 217 | } 218 | -------------------------------------------------------------------------------- /src/models/qwen3_vl/mod.rs: -------------------------------------------------------------------------------- 1 | use parking_lot::RwLock; 2 | use std::rc::Rc; 3 | use std::sync::Arc; 4 | pub mod config; 5 | pub mod input; 6 | pub mod vision; 7 | 8 | use crate::models::layers::VarBuilderX; 9 | use crate::models::qwen3::Qwen3ForCausalLM; 10 | use crate::models::qwen3_moe::Qwen3MoEForCausalLM; 11 | use crate::utils::config::Config; 12 | use crate::utils::progress::ProgressLike; 13 | use crate::{models::layers::distributed::Comm, utils::image::ImageData}; 14 | use attention_rs::InputMetadata; 15 | use candle_core::{DType, Device, Result, Tensor}; 16 | use config::Qwen3VLConfig; 17 | use vision::Qwen3VLVisionModel; 18 | 19 | pub enum Qwen3TextModel { 20 | Dense(Qwen3ForCausalLM), 21 | MoE(Qwen3MoEForCausalLM), 22 | } 23 | 24 | #[allow(dead_code)] 25 | pub struct Qwen3VLForConditionalGeneration { 26 | text_model: Qwen3TextModel, 27 | vision_model: Qwen3VLVisionModel, 28 | spatial_merge_size: usize, 29 | image_token_id: u32, 30 | vision_start_token_id: u32, 31 | vision_end_token_id: u32, 32 | } 33 | 34 | impl Qwen3VLForConditionalGeneration { 35 | pub fn new( 36 | vb: &VarBuilderX, 37 | comm: Rc, 38 | config: &Config, 39 | dtype: DType, 40 | is_rope_i: bool, 41 | device: &Device, 42 | progress_reporter: Arc>>, 43 | ) -> Result { 44 | assert!( 45 | config.extra_config_json.is_some(), 46 | "Invalid multimodel config file!" 47 | ); 48 | let mut cfg: Qwen3VLConfig = 49 | serde_json::from_str(config.extra_config_json.as_ref().unwrap()) 50 | .map_err(candle_core::Error::wrap)?; 51 | cfg.text_config = config.clone(); 52 | 53 | let vision_model = 54 | Qwen3VLVisionModel::new(&cfg.vision_config, &vb.pp("model.visual"), dtype, device)?; 55 | 56 | if cfg.quantization_config.is_some() { 57 | cfg.text_config.quantization_config = cfg.quantization_config.clone(); 58 | } 59 | 60 | let arch = cfg 61 | .architectures 62 | .unwrap_or(vec!["Qwen3VLForConditionalGeneration".to_string()]); 63 | let arch = arch[0].as_str(); 64 | let text_model = if matches!(arch, "Qwen3VLMoeForConditionalGeneration") { 65 | Qwen3TextModel::MoE(Qwen3MoEForCausalLM::new_with_prefix( 66 | &vb, 67 | comm.clone(), 68 | config, 69 | dtype, 70 | is_rope_i, 71 | device, 72 | progress_reporter, 73 | Some("model.language_model".to_string()), 74 | )?) 
75 | } else { 76 | Qwen3TextModel::Dense(Qwen3ForCausalLM::new_with_prefix( 77 | &vb, 78 | comm.clone(), 79 | config, 80 | dtype, 81 | is_rope_i, 82 | device, 83 | progress_reporter, 84 | Some("model.language_model".to_string()), 85 | )?) 86 | }; 87 | 88 | Ok(Self { 89 | text_model, 90 | vision_model, 91 | spatial_merge_size: cfg.vision_config.spatial_merge_size, 92 | image_token_id: cfg.image_token_id, 93 | vision_start_token_id: cfg.vision_start_token_id, 94 | vision_end_token_id: cfg.vision_end_token_id, 95 | }) 96 | } 97 | 98 | #[allow(clippy::too_many_arguments)] 99 | pub fn forward( 100 | &self, 101 | input_ids: &Tensor, 102 | positions: &Tensor, 103 | kv_caches: Option<&Vec<(Tensor, Tensor)>>, 104 | input_metadata: &InputMetadata, 105 | images: Option<&ImageData>, 106 | ) -> Result { 107 | let (mut input_embeds, dtype) = match &self.text_model { 108 | Qwen3TextModel::Dense(m) => (m.embed_forward(input_ids)?, m.dtype()), 109 | Qwen3TextModel::MoE(m) => (m.embed_forward(input_ids)?, m.dtype()), 110 | }; 111 | let device = input_embeds.device().clone(); 112 | let mut visual_pos_masks: Option = None; 113 | let mut deepstack_visual_embeds: Option> = None; 114 | 115 | if let Some(images) = &images { 116 | let mut pixel_values = images.to_tensor_f32(&device)?.to_dtype(dtype)?; 117 | let mut patches = Vec::new(); 118 | for (h, w) in &images.patches { 119 | patches.extend(vec![1, *h as u32, *w as u32]); 120 | } 121 | let mut image_grid_thw = Tensor::from_vec(patches, (images.patches.len(), 3), &device)?; 122 | let num_images = pixel_values.dim(0)?; 123 | assert!( 124 | num_images == image_grid_thw.dim(0)?, 125 | "Input image and patch dim mismatch!" 126 | ); 127 | if images.image_idx > 0 && (images.image_idx as usize) < num_images { 128 | pixel_values = pixel_values.narrow( 129 | 0, 130 | images.image_idx as usize, 131 | num_images - images.image_idx as usize, 132 | )?; 133 | image_grid_thw = image_grid_thw.narrow( 134 | 0, 135 | images.image_idx as usize, 136 | num_images - images.image_idx as usize, 137 | )?; 138 | crate::log_warn!( 139 | "Slicing images: start idx {} -> {:?}", 140 | images.image_idx, 141 | pixel_values.shape() 142 | ); 143 | } 144 | 145 | let dims = pixel_values.dims(); 146 | if dims.len() == 3 { 147 | pixel_values = pixel_values.reshape((dims[0] * dims[1], dims[2]))?; 148 | } 149 | let (image_embeds, deepstack_image_embeds) = 150 | self.vision_model.forward(&pixel_values, &image_grid_thw)?; 151 | 152 | let image_embeds = image_embeds 153 | .to_device(&device)? 154 | .to_dtype(input_embeds.dtype())?; 155 | let deepstack_image_embeds = deepstack_image_embeds 156 | .into_iter() 157 | .map(|t| t.to_device(&device)?.to_dtype(input_embeds.dtype())) 158 | .collect::>>()?; 159 | 160 | let image_mask = input_ids.eq(self.image_token_id as u32)?; 161 | visual_pos_masks = Some(image_mask.to_dtype(DType::U8)?); 162 | 163 | let image_mask = image_mask 164 | .unsqueeze(candle_core::D::Minus1)? 165 | .broadcast_as(input_embeds.shape().clone())? 
166 | .to_dtype(DType::U32)?; 167 | use attention_rs::ops::NonZeroOp; 168 | let indices = image_mask.flatten_all()?.nonzero()?.squeeze(1)?; 169 | 170 | let mut x_flat = input_embeds.flatten_all()?; 171 | let image_flat = image_embeds.flatten_all()?; 172 | 173 | x_flat = 174 | x_flat.scatter_add(&indices, &(image_flat - x_flat.gather(&indices, 0)?)?, 0)?; 175 | input_embeds = x_flat.reshape(input_embeds.shape())?; 176 | deepstack_visual_embeds = Some(deepstack_image_embeds); 177 | } 178 | 179 | match &self.text_model { 180 | Qwen3TextModel::Dense(m) => m.forward_with_deepstack( 181 | &input_embeds, 182 | &positions, 183 | kv_caches, 184 | input_metadata, 185 | true, 186 | &visual_pos_masks, 187 | &deepstack_visual_embeds, 188 | ), 189 | Qwen3TextModel::MoE(m) => m.forward_with_deepstack( 190 | &input_embeds, 191 | &positions, 192 | kv_caches, 193 | input_metadata, 194 | true, 195 | &visual_pos_masks, 196 | &deepstack_visual_embeds, 197 | ), 198 | } 199 | } 200 | 201 | pub fn get_vocab_size(&self) -> usize { 202 | todo!() 203 | } 204 | } 205 | -------------------------------------------------------------------------------- /src/models/layers/others.rs: -------------------------------------------------------------------------------- 1 | use crate::models::layers::VarBuilderX; 2 | use candle_core::{DType, IndexOp, Result, Tensor, WithDType}; 3 | use candle_nn::{var_builder::Shard, Module}; 4 | use candle_nn::{Embedding, LayerNorm, RmsNorm}; 5 | use either::Either; 6 | 7 | pub struct NormX { 8 | norm: Either, 9 | dtype: DType, 10 | } 11 | impl NormX { 12 | pub fn forward(&self, xs: &Tensor) -> Result { 13 | let in_dtype = xs.dtype(); 14 | let xs = if xs.dtype() != self.dtype { 15 | xs.to_dtype(self.dtype)? 16 | } else { 17 | xs.to_owned() 18 | }; 19 | let xs = match &self.norm { 20 | Either::Left(norm) => norm.forward(&xs)?, 21 | Either::Right(norm) => norm.forward(&xs)?, 22 | }; 23 | if xs.dtype() != in_dtype { 24 | xs.to_dtype(in_dtype) 25 | } else { 26 | Ok(xs) 27 | } 28 | } 29 | } 30 | 31 | pub fn rms_norm( 32 | size: usize, 33 | eps: f64, 34 | vb: VarBuilderX, 35 | dtype: DType, 36 | is_gemma: bool, 37 | ) -> Result { 38 | let (weight, dtype) = match &vb.0 { 39 | Either::Left(vb) => { 40 | let ws = vb.get_with_hints(size, "weight", Shard::default())?; 41 | if ws.dtype() != dtype { 42 | (ws.to_dtype(dtype)?, dtype) 43 | } else { 44 | (ws, dtype) 45 | } 46 | } 47 | Either::Right(vb) => (vb.get(size, "weight")?.dequantize(vb.device())?, DType::F32), 48 | }; 49 | 50 | let weight = if is_gemma { (weight + 1.0)? } else { weight }; 51 | Ok(NormX { 52 | norm: Either::Left(RmsNorm::new(weight, eps)), 53 | dtype, 54 | }) 55 | } 56 | 57 | pub fn layer_norm( 58 | size: usize, 59 | eps: f64, 60 | affine: bool, 61 | vb: VarBuilderX, 62 | dtype: DType, 63 | ) -> Result { 64 | let (weight, dtype) = match &vb.0 { 65 | Either::Left(vb) => ( 66 | vb.get_with_hints(size, "weight", Shard::default())? 
67 | .to_dtype(dtype)?, 68 | dtype, 69 | ), 70 | Either::Right(vb) => (vb.get(size, "weight")?.dequantize(vb.device())?, DType::F32), 71 | }; 72 | if affine { 73 | let bias = match &vb.0 { 74 | Either::Left(vb) => vb.get(size, "bias")?.to_dtype(dtype)?, 75 | Either::Right(vb) => vb.get(size, "bias")?.dequantize(vb.device())?, 76 | }; 77 | Ok(NormX { 78 | norm: Either::Right(LayerNorm::new(weight, bias, eps)), 79 | dtype, 80 | }) 81 | } else { 82 | Ok(NormX { 83 | norm: Either::Right(LayerNorm::new_no_bias(weight, eps)), 84 | dtype, 85 | }) 86 | } 87 | } 88 | 89 | pub fn embedding( 90 | vocab_size: Option, 91 | hidden_size: usize, 92 | vb: VarBuilderX, 93 | dtype: DType, 94 | ) -> Result<(Embedding, usize)> { 95 | let (embeddings, vocab_size) = match &vb.0 { 96 | Either::Left(vb) => { 97 | assert!( 98 | vocab_size.is_some(), 99 | "vocab_size must be specified for safetensor models" 100 | ); 101 | ( 102 | vb.get((vocab_size.unwrap(), hidden_size), "weight")? 103 | .to_dtype(dtype)?, 104 | vocab_size.unwrap(), 105 | ) 106 | } 107 | Either::Right(vb) => { 108 | let weight = if vocab_size.is_some() { 109 | vb.get((vocab_size.unwrap(), hidden_size), "weight")? 110 | } else { 111 | vb.get_no_shape("weight")? 112 | } 113 | .dequantize(vb.device())?; 114 | let vocab_size = vocab_size.unwrap_or(weight.dim(0)?); 115 | (weight, vocab_size) 116 | } 117 | }; 118 | Ok((Embedding::new(embeddings, hidden_size), vocab_size)) 119 | } 120 | 121 | pub fn conv2d( 122 | in_channels: usize, 123 | out_channels: usize, 124 | kernel_size: usize, 125 | cfg: candle_nn::Conv2dConfig, 126 | vb: VarBuilderX, 127 | bias: bool, 128 | ) -> Result { 129 | let (ws, bs) = match vb.0 { 130 | Either::Left(v) => { 131 | let ws = v.get( 132 | ( 133 | out_channels, 134 | in_channels / cfg.groups, 135 | kernel_size, 136 | kernel_size, 137 | ), 138 | "weight", 139 | )?; 140 | let bs = if bias { 141 | Some(v.get(out_channels, "bias")?) 
142 | } else { 143 | None 144 | }; 145 | (ws, bs) 146 | } 147 | _ => { 148 | todo!() 149 | } 150 | }; 151 | 152 | Ok(candle_nn::Conv2d::new(ws, bs, cfg)) 153 | } 154 | 155 | pub struct AvgPool2d { 156 | kernel_size: usize, 157 | stride: usize, 158 | } 159 | 160 | impl AvgPool2d { 161 | pub fn new(kernel_size: usize, stride: usize) -> Self { 162 | Self { 163 | kernel_size, 164 | stride, 165 | } 166 | } 167 | 168 | pub fn forward(&self, xs: &Tensor) -> Result { 169 | xs.avg_pool2d_with_stride(self.kernel_size, self.stride) 170 | } 171 | } 172 | 173 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 174 | pub struct Conv3dConfig { 175 | pub padding: usize, 176 | pub stride: usize, 177 | pub dilation: usize, 178 | pub groups: usize, 179 | } 180 | 181 | impl Default for Conv3dConfig { 182 | fn default() -> Self { 183 | Self { 184 | padding: 0, 185 | stride: 1, 186 | dilation: 1, 187 | groups: 1, 188 | } 189 | } 190 | } 191 | 192 | pub struct Conv3dNoBias { 193 | conv2d_1: candle_nn::Conv2d, 194 | conv2d_2: candle_nn::Conv2d, 195 | } 196 | 197 | impl Conv3dNoBias { 198 | pub fn new( 199 | in_channels: usize, 200 | out_channels: usize, 201 | kernel_sizes: [usize; 3], 202 | cfg: Conv3dConfig, 203 | vb: VarBuilderX, 204 | ) -> Result { 205 | use candle_nn::Conv2dConfig; 206 | let expected_shape = ( 207 | out_channels, 208 | in_channels / cfg.groups, 209 | kernel_sizes[0], 210 | kernel_sizes[1], 211 | kernel_sizes[2], 212 | ); 213 | let ws = match vb.0 { 214 | Either::Left(v) => v.get(expected_shape, "weight")?, 215 | _ => { 216 | panic!("Unsupported quantized format for conv3d") 217 | } 218 | }; 219 | 220 | let w1 = ws.i((.., .., 0, .., ..))?; 221 | let w2 = ws.i((.., .., 1, .., ..))?; 222 | 223 | let cfg = Conv2dConfig { 224 | padding: cfg.padding, 225 | stride: cfg.stride, 226 | dilation: cfg.dilation, 227 | groups: cfg.groups, 228 | }; 229 | 230 | Ok(Self { 231 | conv2d_1: candle_nn::Conv2d::new(w1.contiguous()?, None, cfg), 232 | conv2d_2: candle_nn::Conv2d::new(w2.contiguous()?, None, cfg), 233 | }) 234 | } 235 | 236 | pub fn weight(&self) -> Result { 237 | let w1 = self.conv2d_1.weight().clone().unsqueeze(2)?; 238 | let w2 = self.conv2d_2.weight().clone().unsqueeze(2)?; 239 | Tensor::cat(&[w1, w2], 2) 240 | } 241 | } 242 | 243 | impl Module for Conv3dNoBias { 244 | fn forward(&self, xs: &Tensor) -> Result { 245 | let xs1 = xs.i((.., .., 0, .., ..))?; 246 | let xs2 = xs.i((.., .., 1, .., ..))?; 247 | 248 | (self.conv2d_1.forward(&xs1)? + self.conv2d_2.forward(&xs2)?)?.unsqueeze(2) 249 | } 250 | } 251 | 252 | pub fn masked_fill(xs: &Tensor, mask: &Tensor, value: D) -> Result { 253 | let on_true = Tensor::full(value, xs.shape(), xs.device())?.to_dtype(xs.dtype())?; 254 | let on_false = xs; 255 | let res = mask 256 | .broadcast_as(xs.shape())? 
257 | .where_cond(&on_true, on_false)?; 258 | Ok(res) 259 | } 260 | -------------------------------------------------------------------------------- /src/utils/progress.rs: -------------------------------------------------------------------------------- 1 | use crate::runner::send_local; 2 | use crate::runner::{receive_local, MessageType}; 3 | use candle_core::Result; 4 | use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; 5 | use interprocess::local_socket::traits::{Listener, Stream}; 6 | use interprocess::local_socket::{GenericNamespaced, ToNsName}; 7 | use interprocess::local_socket::{ListenerOptions, Stream as LocalStream}; 8 | use parking_lot::RwLock; 9 | use std::collections::HashMap; 10 | use std::sync::Arc; 11 | use std::thread::JoinHandle; 12 | use std::{thread, time}; 13 | pub trait ProgressLike: Send + Sync { 14 | fn get_progress(&mut self) -> Vec<(usize, usize)>; 15 | fn set_progress(&mut self, p: usize); 16 | } 17 | 18 | pub struct ProgressReporter { 19 | pub rank: usize, 20 | pub progress: usize, 21 | } 22 | 23 | impl ProgressLike for ProgressReporter { 24 | fn get_progress(&mut self) -> Vec<(usize, usize)> { 25 | vec![(self.rank, self.progress)] 26 | } 27 | 28 | fn set_progress(&mut self, p: usize) { 29 | self.progress = p; 30 | } 31 | } 32 | 33 | impl ProgressReporter { 34 | pub fn new(rank: usize) -> Self { 35 | Self { rank, progress: 0 } 36 | } 37 | } 38 | 39 | unsafe impl Send for ProgressReporter {} 40 | unsafe impl Sync for ProgressReporter {} 41 | 42 | pub struct RemoteProgressReporter { 43 | pub rank: usize, 44 | pub progress: usize, 45 | pub streams: Vec, 46 | } 47 | 48 | impl RemoteProgressReporter { 49 | pub fn new(rank: usize, shards: usize, sock_name: String, client: bool) -> Result { 50 | let mut streams = Vec::::with_capacity(shards); 51 | if client { 52 | crate::log_info!("Remote progress reporter initialized for rank {}", rank); 53 | let name = sock_name.clone().to_ns_name::()?; 54 | let mut stream = LocalStream::connect(name.clone()); 55 | 56 | loop { 57 | if stream.is_ok() { 58 | break; 59 | } 60 | crate::log_info!("Runner retry connecting to socket: {}", sock_name); 61 | stream = LocalStream::connect(name.clone()); 62 | std::thread::sleep(std::time::Duration::from_millis(100)); 63 | } 64 | streams.push(stream.unwrap()); 65 | } else { 66 | let listener = ListenerOptions::new() 67 | .name( 68 | sock_name 69 | .clone() 70 | .to_ns_name::() 71 | .expect("Failed to to_ns_name"), 72 | ) 73 | .create_sync()?; 74 | 75 | crate::log_info!("listener starting accepting runner {}", rank); 76 | for _ in 0..shards { 77 | match listener.accept() { 78 | Ok(stream) => streams.push(stream), 79 | Err(e) => { 80 | crate::log_error!("Failed to accept connection: {}", e); 81 | } 82 | } 83 | } 84 | } 85 | Ok(Self { 86 | rank, 87 | progress: 0, 88 | streams, 89 | }) 90 | } 91 | } 92 | 93 | impl ProgressLike for RemoteProgressReporter { 94 | fn get_progress(&mut self) -> Vec<(usize, usize)> { 95 | let mut progress_values = Vec::with_capacity(self.streams.len()); 96 | for mut stream in &mut self.streams { 97 | if let Ok(msg) = receive_local(&mut stream, false) { 98 | if let MessageType::LoadingProgress((rank, progress)) = msg { 99 | progress_values.push((rank, progress)); 100 | } 101 | } else { 102 | panic!("Error when loading model!"); 103 | } 104 | } 105 | progress_values 106 | } 107 | 108 | fn set_progress(&mut self, p: usize) { 109 | let _ = send_local( 110 | &mut self.streams, 111 | &MessageType::LoadingProgress((self.rank, p)), 112 | false, 113 | ); 114 | } 115 | } 
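// Note on the manual `Send`/`Sync` impls just below (an observation about intended
// usage, not something the compiler verifies): the reporter is only ever shared as
// `Arc<RwLock<Box<dyn ProgressLike>>>` (see `progress_worker` further down), so the
// contained `LocalStream`s are always accessed under the lock from one thread at a
// time. The same reasoning applies to the simpler `ProgressReporter` above.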
116 | 117 | unsafe impl Send for RemoteProgressReporter {} 118 | unsafe impl Sync for RemoteProgressReporter {} 119 | 120 | pub struct Progress { 121 | m: MultiProgress, 122 | bars: Vec, 123 | size: usize, 124 | } 125 | 126 | impl Progress { 127 | pub fn new(n: usize, size: usize) -> Progress { 128 | let m = MultiProgress::new(); 129 | let sty = ProgressStyle::with_template( 130 | "[{elapsed_precise}] {bar:60.cyan/blue} {pos:>4}/{len:4} {msg}", 131 | ) 132 | .unwrap() 133 | .progress_chars("##-"); 134 | 135 | let mut bars = Vec::::new(); 136 | for i in 0..n { 137 | let pb = m.add(ProgressBar::new(size as u64)); 138 | pb.set_style(sty.clone()); 139 | if n > 1 { 140 | pb.set_message(format!("On Rank {} Device", i)); 141 | } 142 | bars.push(pb); 143 | } 144 | 145 | // if n > 1 { 146 | // m.println(format!("Loading model in {} ranks!", n)).unwrap(); 147 | // } 148 | Self { m, bars, size } 149 | } 150 | 151 | pub fn update(&self, idx: usize, progress: usize) { 152 | if idx < self.bars.len() && progress > 0 { 153 | let pos = self.bars[idx].position(); 154 | self.bars[idx].inc(progress as u64 - pos); 155 | if self.bars.len() > 1 { 156 | if progress >= self.size { 157 | self.bars[idx].set_message(format!("On Rank {} Device Finished", idx)); 158 | } else { 159 | self.bars[idx].set_message(format!("On Rank {} Device", idx)); 160 | } 161 | } 162 | } 163 | } 164 | 165 | pub fn finish(&self) { 166 | for idx in 0..self.bars.len() { 167 | let pos = self.bars[idx].position(); 168 | self.bars[idx].inc(self.size as u64 - pos); 169 | if self.bars.len() > 1 { 170 | self.bars[idx].set_message(format!("On Rank {} Device Finished", idx)); 171 | } 172 | } 173 | self.m.clear().unwrap(); 174 | } 175 | } 176 | 177 | #[allow(unused_variables)] 178 | pub fn progress_worker( 179 | num_shards: usize, 180 | length: usize, 181 | progress_reporter: &Arc>>, 182 | ) -> std::thread::JoinHandle<()> { 183 | let mut finished_map = HashMap::::new(); 184 | let reporter = progress_reporter.clone(); 185 | let progress_bar = Some(Progress::new(num_shards, length)); 186 | let handle = thread::spawn(move || loop { 187 | { 188 | let _ = thread::sleep(time::Duration::from_millis(10 as u64)); 189 | let progress = reporter.write().get_progress(); 190 | for (rank, progress) in progress { 191 | finished_map.insert(rank, progress); 192 | progress_bar.as_ref().unwrap().update(rank, progress); 193 | } 194 | 195 | if finished_map.values().all(|v| v >= &length) { 196 | progress_bar.as_ref().unwrap().finish(); 197 | break; 198 | } 199 | } 200 | }); 201 | 202 | handle 203 | } 204 | 205 | pub fn spawn_progress_thread( 206 | num_shards: usize, 207 | length: usize, 208 | progress_sock_name: String, 209 | ) -> JoinHandle>> { 210 | thread::spawn(move || { 211 | match RemoteProgressReporter::new(0, num_shards, progress_sock_name, false) { 212 | Ok(reporter) => { 213 | let reporter: Arc>> = 214 | Arc::new(RwLock::new(Box::new(reporter))); 215 | 216 | // Call the real worker — assumed to return a JoinHandle 217 | let handle = progress_worker(num_shards, length, &reporter); 218 | Some(handle) 219 | } 220 | Err(e) => { 221 | eprintln!("Unable to create progress monitor: {e}"); 222 | None 223 | } 224 | } 225 | }) 226 | } 227 | -------------------------------------------------------------------------------- /src/mcp/transport.rs: -------------------------------------------------------------------------------- 1 | // src/mcp/transport.rs 2 | //! MCP transport layer implementations 3 | //! 4 | //! 
Supports stdio (for local processes) and HTTP/SSE (for remote servers) 5 | 6 | use super::types::*; 7 | use std::io::{BufRead, BufReader, Write}; 8 | use std::process::{Child, ChildStdin, ChildStdout, Command, Stdio}; 9 | 10 | /// Transport trait for sending and receiving MCP messages 11 | pub trait Transport: Send + Sync { 12 | /// Send a message 13 | fn send(&mut self, message: &str) -> Result<(), TransportError>; 14 | 15 | /// Receive a message (blocking) 16 | fn receive(&mut self) -> Result; 17 | 18 | /// Close the transport 19 | fn close(&mut self) -> Result<(), TransportError>; 20 | } 21 | 22 | /// Transport errors 23 | #[derive(Debug, thiserror::Error)] 24 | pub enum TransportError { 25 | #[error("IO error: {0}")] 26 | Io(#[from] std::io::Error), 27 | 28 | #[error("Process error: {0}")] 29 | Process(String), 30 | 31 | #[error("Connection closed")] 32 | Closed, 33 | 34 | #[error("Timeout")] 35 | Timeout, 36 | 37 | #[error("Parse error: {0}")] 38 | Parse(String), 39 | } 40 | 41 | /// Stdio transport for communicating with local MCP server processes 42 | pub struct StdioTransport { 43 | child: Child, 44 | stdin: Option, 45 | stdout_reader: Option>, 46 | } 47 | 48 | impl StdioTransport { 49 | /// Create a new stdio transport by spawning a process 50 | pub fn spawn(command: &str, args: &[&str]) -> Result { 51 | Self::spawn_with_env(command, args, &std::collections::HashMap::new()) 52 | } 53 | 54 | /// Create a new stdio transport with additional environment variables 55 | pub fn spawn_with_env( 56 | command: &str, 57 | args: &[&str], 58 | env: &std::collections::HashMap, 59 | ) -> Result { 60 | let mut child = Command::new(command) 61 | .args(args) 62 | .envs(env) 63 | .stdin(Stdio::piped()) 64 | .stdout(Stdio::piped()) 65 | .stderr(Stdio::inherit()) 66 | .spawn()?; 67 | 68 | let stdin = child.stdin.take(); 69 | let stdout = child.stdout.take(); 70 | let stdout_reader = stdout.map(BufReader::new); 71 | 72 | Ok(Self { 73 | child, 74 | stdin, 75 | stdout_reader, 76 | }) 77 | } 78 | } 79 | 80 | impl Transport for StdioTransport { 81 | fn send(&mut self, message: &str) -> Result<(), TransportError> { 82 | if let Some(ref mut stdin) = self.stdin { 83 | writeln!(stdin, "{}", message)?; 84 | stdin.flush()?; 85 | Ok(()) 86 | } else { 87 | Err(TransportError::Closed) 88 | } 89 | } 90 | 91 | fn receive(&mut self) -> Result { 92 | if let Some(ref mut reader) = self.stdout_reader { 93 | let mut line = String::new(); 94 | let bytes_read = reader.read_line(&mut line)?; 95 | if bytes_read == 0 { 96 | return Err(TransportError::Closed); 97 | } 98 | Ok(line.trim().to_string()) 99 | } else { 100 | Err(TransportError::Closed) 101 | } 102 | } 103 | 104 | fn close(&mut self) -> Result<(), TransportError> { 105 | self.stdin = None; 106 | self.stdout_reader = None; 107 | let _ = self.child.kill(); 108 | Ok(()) 109 | } 110 | } 111 | 112 | impl Drop for StdioTransport { 113 | fn drop(&mut self) { 114 | let _ = self.close(); 115 | } 116 | } 117 | 118 | /// In-memory transport for testing (not thread-safe, for single-threaded tests only) 119 | pub struct MemoryTransport { 120 | tx: std::sync::mpsc::Sender, 121 | rx: std::sync::mpsc::Receiver, 122 | } 123 | 124 | impl MemoryTransport { 125 | /// Create a pair of connected transports for testing 126 | pub fn pair() -> (Self, Self) { 127 | let (tx1, rx1) = std::sync::mpsc::channel(); 128 | let (tx2, rx2) = std::sync::mpsc::channel(); 129 | 130 | (Self { tx: tx1, rx: rx2 }, Self { tx: tx2, rx: rx1 }) 131 | } 132 | 133 | pub fn send(&mut self, message: &str) -> 
Result<(), TransportError> { 134 | self.tx 135 | .send(message.to_string()) 136 | .map_err(|_| TransportError::Closed) 137 | } 138 | 139 | pub fn receive(&mut self) -> Result { 140 | self.rx.recv().map_err(|_| TransportError::Closed) 141 | } 142 | 143 | #[allow(dead_code)] 144 | pub fn close(&mut self) -> Result<(), TransportError> { 145 | Ok(()) 146 | } 147 | } 148 | 149 | /// Message framing utilities for MCP over different transports 150 | pub mod framing { 151 | use super::*; 152 | 153 | /// Encode a JSON-RPC message for line-based transport 154 | pub fn encode_line(message: &impl serde::Serialize) -> Result { 155 | serde_json::to_string(message).map_err(|e| TransportError::Parse(e.to_string())) 156 | } 157 | 158 | /// Decode a JSON-RPC message from line-based transport 159 | pub fn decode_line(line: &str) -> Result { 160 | serde_json::from_str(line).map_err(|e| TransportError::Parse(e.to_string())) 161 | } 162 | 163 | /// Parse any JSON-RPC message (request, response, or notification) 164 | pub fn parse_message(line: &str) -> Result { 165 | let value: serde_json::Value = 166 | serde_json::from_str(line).map_err(|e| TransportError::Parse(e.to_string()))?; 167 | 168 | // Check if it's a response (has result or error) 169 | if value.get("result").is_some() || value.get("error").is_some() { 170 | let response: JsonRpcResponse = 171 | serde_json::from_value(value).map_err(|e| TransportError::Parse(e.to_string()))?; 172 | return Ok(McpMessage::Response(response)); 173 | } 174 | 175 | // Check if it's a notification (no id) 176 | if value.get("id").is_none() { 177 | let notification: JsonRpcNotification = 178 | serde_json::from_value(value).map_err(|e| TransportError::Parse(e.to_string()))?; 179 | return Ok(McpMessage::Notification(notification)); 180 | } 181 | 182 | // It's a request 183 | let request: JsonRpcRequest = 184 | serde_json::from_value(value).map_err(|e| TransportError::Parse(e.to_string()))?; 185 | Ok(McpMessage::Request(request)) 186 | } 187 | } 188 | 189 | /// Parsed MCP message types 190 | #[derive(Debug)] 191 | pub enum McpMessage { 192 | Request(JsonRpcRequest), 193 | Response(JsonRpcResponse), 194 | Notification(JsonRpcNotification), 195 | } 196 | 197 | #[cfg(test)] 198 | mod tests { 199 | use super::*; 200 | 201 | #[test] 202 | fn test_memory_transport() { 203 | let (mut t1, mut t2) = MemoryTransport::pair(); 204 | 205 | t1.send("hello").unwrap(); 206 | assert_eq!(t2.receive().unwrap(), "hello"); 207 | 208 | t2.send("world").unwrap(); 209 | assert_eq!(t1.receive().unwrap(), "world"); 210 | } 211 | 212 | #[test] 213 | fn test_framing() { 214 | let req = JsonRpcRequest::new(1i64, "test", None); 215 | let encoded = framing::encode_line(&req).unwrap(); 216 | let decoded: JsonRpcRequest = framing::decode_line(&encoded).unwrap(); 217 | 218 | assert_eq!(decoded.method, "test"); 219 | } 220 | 221 | #[test] 222 | fn test_parse_message() { 223 | let request = r#"{"jsonrpc":"2.0","id":1,"method":"test"}"#; 224 | let response = r#"{"jsonrpc":"2.0","id":1,"result":{}}"#; 225 | let notification = r#"{"jsonrpc":"2.0","method":"notify"}"#; 226 | 227 | assert!(matches!( 228 | framing::parse_message(request).unwrap(), 229 | McpMessage::Request(_) 230 | )); 231 | assert!(matches!( 232 | framing::parse_message(response).unwrap(), 233 | McpMessage::Response(_) 234 | )); 235 | assert!(matches!( 236 | framing::parse_message(notification).unwrap(), 237 | McpMessage::Notification(_) 238 | )); 239 | } 240 | } 241 | -------------------------------------------------------------------------------- 
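Taken together, the transport above and the client defined in the next file give a complete stdio round-trip. The sketch below is illustrative only and makes a few assumptions not verified against the crate: that the `mcp` module and the `name` fields on `McpTool` are publicly visible, and that the `@modelcontextprotocol/server-filesystem` command (borrowed from example/mcp.json) actually exposes a `list_directory` tool with a `path` argument.

use std::collections::HashMap;
use vllm_rs::mcp::transport::StdioTransport;
use vllm_rs::mcp::McpClient;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Spawn a local MCP server as a child process speaking JSON-RPC over stdio.
    let transport = StdioTransport::spawn(
        "npx",
        &["-y", "@modelcontextprotocol/server-filesystem", "~/"],
    )?;

    // Handshake first, then discover what the server offers.
    let mut client = McpClient::new(transport, "vllm-rs-demo", "0.1.0");
    client.initialize()?;
    for tool in client.list_tools()? {
        println!("tool available: {}", tool.name);
    }

    // Call one tool; "list_directory" and its "path" argument are assumed names.
    let mut args = HashMap::new();
    args.insert("path".to_string(), serde_json::json!("~/"));
    let _result = client.call_tool("list_directory", args)?;

    client.close()?;
    Ok(())
}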
/src/mcp/client.rs: -------------------------------------------------------------------------------- 1 | // src/mcp/client.rs 2 | //! MCP Client implementation 3 | //! 4 | //! Connect to external MCP servers to discover and call tools. 5 | 6 | use super::transport::{framing, Transport, TransportError}; 7 | use super::types::*; 8 | use serde_json::Value; 9 | use std::collections::HashMap; 10 | use std::sync::atomic::{AtomicI64, Ordering}; 11 | 12 | /// MCP Client for connecting to MCP servers 13 | pub struct McpClient { 14 | /// Transport layer 15 | transport: T, 16 | /// Client info 17 | client_info: Implementation, 18 | /// Server info (after initialization) 19 | server_info: Option, 20 | /// Server capabilities (after initialization) 21 | server_capabilities: Option, 22 | /// Cached tools list 23 | tools_cache: Vec, 24 | /// Request ID counter 25 | request_counter: AtomicI64, 26 | /// Whether initialized 27 | initialized: bool, 28 | } 29 | 30 | impl McpClient { 31 | /// Create a new MCP client 32 | pub fn new(transport: T, name: impl Into, version: impl Into) -> Self { 33 | Self { 34 | transport, 35 | client_info: Implementation { 36 | name: name.into(), 37 | version: version.into(), 38 | }, 39 | server_info: None, 40 | server_capabilities: None, 41 | tools_cache: Vec::new(), 42 | request_counter: AtomicI64::new(1), 43 | initialized: false, 44 | } 45 | } 46 | 47 | /// Get the next request ID 48 | fn next_id(&self) -> RequestId { 49 | RequestId::Number(self.request_counter.fetch_add(1, Ordering::SeqCst)) 50 | } 51 | 52 | /// Send a request and wait for response 53 | fn send_request( 54 | &mut self, 55 | method: &str, 56 | params: Option, 57 | ) -> Result { 58 | let id = self.next_id(); 59 | let request = JsonRpcRequest::new(id.clone(), method, params); 60 | 61 | let request_str = framing::encode_line(&request).map_err(McpClientError::Transport)?; 62 | 63 | self.transport 64 | .send(&request_str) 65 | .map_err(McpClientError::Transport)?; 66 | 67 | // Wait for response 68 | loop { 69 | let line = self 70 | .transport 71 | .receive() 72 | .map_err(McpClientError::Transport)?; 73 | 74 | if line.is_empty() { 75 | continue; 76 | } 77 | 78 | let message = framing::parse_message(&line).map_err(McpClientError::Transport)?; 79 | 80 | match message { 81 | super::transport::McpMessage::Response(response) => { 82 | if response.id == id { 83 | if let Some(error) = response.error { 84 | return Err(McpClientError::ServerError(error)); 85 | } 86 | return response.result.ok_or(McpClientError::EmptyResponse); 87 | } 88 | } 89 | super::transport::McpMessage::Notification(_) => { 90 | // Handle notifications - continue waiting for response 91 | continue; 92 | } 93 | super::transport::McpMessage::Request(_) => { 94 | // Server sending request to client - handle if needed 95 | continue; 96 | } 97 | } 98 | } 99 | } 100 | 101 | /// Send a notification (no response expected) 102 | fn send_notification( 103 | &mut self, 104 | method: &str, 105 | params: Option, 106 | ) -> Result<(), McpClientError> { 107 | let notification = JsonRpcNotification { 108 | jsonrpc: "2.0".to_string(), 109 | method: method.to_string(), 110 | params, 111 | }; 112 | 113 | let notification_str = 114 | framing::encode_line(¬ification).map_err(McpClientError::Transport)?; 115 | 116 | self.transport 117 | .send(¬ification_str) 118 | .map_err(McpClientError::Transport)?; 119 | 120 | Ok(()) 121 | } 122 | 123 | /// Initialize the connection with the server 124 | pub fn initialize(&mut self) -> Result { 125 | let params = InitializeParams { 126 | 
protocol_version: MCP_VERSION.to_string(), 127 | capabilities: ClientCapabilities::default(), 128 | client_info: self.client_info.clone(), 129 | }; 130 | 131 | let result = self.send_request("initialize", Some(serde_json::to_value(¶ms)?))?; 132 | let init_result: InitializeResult = serde_json::from_value(result)?; 133 | 134 | self.server_info = Some(init_result.server_info.clone()); 135 | self.server_capabilities = Some(init_result.capabilities.clone()); 136 | 137 | // Send initialized notification 138 | self.send_notification("notifications/initialized", None)?; 139 | self.initialized = true; 140 | 141 | Ok(init_result) 142 | } 143 | 144 | /// List available tools from the server 145 | pub fn list_tools(&mut self) -> Result, McpClientError> { 146 | if !self.initialized { 147 | return Err(McpClientError::NotInitialized); 148 | } 149 | 150 | let result = self.send_request("tools/list", None)?; 151 | let list_result: ListToolsResult = serde_json::from_value(result)?; 152 | 153 | self.tools_cache = list_result.tools.clone(); 154 | Ok(list_result.tools) 155 | } 156 | 157 | /// Call a tool on the server 158 | pub fn call_tool( 159 | &mut self, 160 | name: impl Into, 161 | arguments: HashMap, 162 | ) -> Result { 163 | if !self.initialized { 164 | return Err(McpClientError::NotInitialized); 165 | } 166 | 167 | let params = CallToolParams { 168 | name: name.into(), 169 | arguments, 170 | }; 171 | 172 | let result = self.send_request("tools/call", Some(serde_json::to_value(¶ms)?))?; 173 | let call_result: CallToolResult = serde_json::from_value(result)?; 174 | 175 | Ok(call_result) 176 | } 177 | 178 | /// Get cached tools (from last list_tools call) 179 | pub fn cached_tools(&self) -> &[McpTool] { 180 | &self.tools_cache 181 | } 182 | 183 | /// Check if a tool exists (from cache) 184 | pub fn has_tool(&self, name: &str) -> bool { 185 | self.tools_cache.iter().any(|t| t.name == name) 186 | } 187 | 188 | /// Get server info (after initialization) 189 | pub fn server_info(&self) -> Option<&Implementation> { 190 | self.server_info.as_ref() 191 | } 192 | 193 | /// Get server capabilities (after initialization) 194 | pub fn capabilities(&self) -> Option<&ServerCapabilities> { 195 | self.server_capabilities.as_ref() 196 | } 197 | 198 | /// Close the connection 199 | pub fn close(mut self) -> Result<(), McpClientError> { 200 | self.transport.close().map_err(McpClientError::Transport) 201 | } 202 | } 203 | 204 | /// Errors that can occur during MCP client operations 205 | #[derive(Debug, thiserror::Error)] 206 | pub enum McpClientError { 207 | #[error("Transport error: {0}")] 208 | Transport(#[from] TransportError), 209 | 210 | #[error("Server error: {0}")] 211 | ServerError(JsonRpcError), 212 | 213 | #[error("Configuration error: {0}")] 214 | Config(String), 215 | 216 | #[error("Serialization error: {0}")] 217 | Serialization(#[from] serde_json::Error), 218 | 219 | #[error("Client not initialized")] 220 | NotInitialized, 221 | 222 | #[error("Empty response from server")] 223 | EmptyResponse, 224 | 225 | #[error("Tool not found: {0}")] 226 | ToolNotFound(String), 227 | } 228 | 229 | impl std::fmt::Display for JsonRpcError { 230 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 231 | write!(f, "{} (code: {})", self.message, self.code) 232 | } 233 | } 234 | 235 | #[cfg(test)] 236 | mod tests { 237 | use super::*; 238 | 239 | #[test] 240 | fn test_mcp_client_error_display() { 241 | let err = McpClientError::NotInitialized; 242 | assert!(format!("{}", err).contains("not initialized")); 243 
|
244 |         let err = McpClientError::ToolNotFound("test".to_string());
245 |         assert!(format!("{}", err).contains("test"));
246 |     }
247 | 
248 |     #[test]
249 |     fn test_json_rpc_error_display() {
250 |         let err = JsonRpcError {
251 |             code: -32600,
252 |             message: "Invalid Request".to_string(),
253 |             data: None,
254 |         };
255 |         assert!(format!("{}", err).contains("Invalid Request"));
256 |         assert!(format!("{}", err).contains("-32600"));
257 |     }
258 | }
259 | 
--------------------------------------------------------------------------------
/src/tools/parser.rs:
--------------------------------------------------------------------------------
1 | // src/tools/parser.rs
2 | //! Tool call parsing from model output
3 | //!
4 | //! Supports multiple formats used by different models.
5 | 
6 | use super::ToolCall;
7 | use regex::Regex;
8 | use serde_json::Value;
9 | 
10 | /// Parser for extracting tool calls from model output text
11 | #[allow(dead_code)]
12 | #[derive(Debug, Clone)]
13 | pub struct ToolParser {
14 |     /// Regex patterns for different formats
15 |     patterns: Vec<(String, Regex)>,
16 | }
17 | 
18 | impl Default for ToolParser {
19 |     fn default() -> Self {
20 |         Self::new()
21 |     }
22 | }
23 | 
24 | impl ToolParser {
25 |     /// Create a new parser with default patterns
26 |     pub fn new() -> Self {
27 |         let patterns = vec![
28 |             // Qwen format: <tool_call>{"name": "...", "arguments": {...}}</tool_call>
29 |             (
30 |                 "qwen".to_string(),
31 |                 Regex::new(r#"<tool_call>\s*(\{[^}]+\})\s*</tool_call>"#).unwrap()
32 |             ),
33 |             // Generic JSON object with name and arguments
34 |             (
35 |                 "json".to_string(),
36 |                 Regex::new(r#"\{\s*"name"\s*:\s*"([^"]+)"\s*,\s*"arguments"\s*:\s*(\{[^}]*\}|\[[^\]]*\]|"[^"]*"|\d+|true|false|null)\s*\}"#).unwrap()
37 |             ),
38 |             // Function call format in code blocks
39 |             (
40 |                 "func".to_string(),
41 |                 Regex::new(r#"```(?:json)?\s*\{[^}]*"name"[^}]*\}\s*```"#).unwrap()
42 |             ),
43 |         ];
44 |         Self { patterns }
45 |     }
46 | 
47 |     /// Parse tool calls from model output
48 |     pub fn parse(&self, text: &str) -> Vec<ToolCall> {
49 |         let mut calls = Vec::new();
50 |         let mut call_id = 0;
51 | 
52 |         // Try Qwen format first
53 |         if let Some(qwen_calls) = self.parse_qwen_format(text, &mut call_id) {
54 |             calls.extend(qwen_calls);
55 |         }
56 | 
57 |         // Try generic JSON format
58 |         if calls.is_empty() {
59 |             if let Some(json_calls) = self.parse_json_format(text, &mut call_id) {
60 |                 calls.extend(json_calls);
61 |             }
62 |         }
63 | 
64 |         // Try code block format
65 |         if calls.is_empty() {
66 |             if let Some(block_calls) = self.parse_code_block_format(text, &mut call_id) {
67 |                 calls.extend(block_calls);
68 |             }
69 |         }
70 | 
71 |         calls
72 |     }
73 | 
74 |     /// Parse Qwen's <tool_call> format
75 |     fn parse_qwen_format(&self, text: &str, call_id: &mut usize) -> Option<Vec<ToolCall>> {
76 |         let re = Regex::new(r"(?s)<tool_call>\s*(.*?)\s*</tool_call>").ok()?;
77 |         let mut calls = Vec::new();
78 | 
79 |         for cap in re.captures_iter(text) {
80 |             if let Some(json_str) = cap.get(1) {
81 |                 if let Ok(parsed) = serde_json::from_str::<Value>(json_str.as_str()) {
82 |                     if let Some(call) = self.value_to_tool_call(&parsed, call_id) {
83 |                         calls.push(call);
84 |                     }
85 |                 }
86 |             }
87 |         }
88 | 
89 |         if calls.is_empty() {
90 |             None
91 |         } else {
92 |             Some(calls)
93 |         }
94 |     }
95 | 
96 |     /// Parse generic JSON format with name and arguments
97 |     fn parse_json_format(&self, text: &str, call_id: &mut usize) -> Option<Vec<ToolCall>> {
98 |         // Try to find JSON objects that look like tool calls
99 |         let mut calls = Vec::new();
100 | 
101 |         // Simple approach: try to parse the entire text as JSON first
102 |         if let Ok(parsed) = serde_json::from_str::<Value>(text.trim()) {
103 |             if let
Some(call) = self.value_to_tool_call(&parsed, call_id) {
104 |                 return Some(vec![call]);
105 |             }
106 |         }
107 | 
108 |         // Look for JSON blocks in the text
109 |         let mut depth = 0;
110 |         let mut start = None;
111 | 
112 |         for (i, c) in text.char_indices() {
113 |             match c {
114 |                 '{' => {
115 |                     if depth == 0 {
116 |                         start = Some(i);
117 |                     }
118 |                     depth += 1;
119 |                 }
120 |                 '}' => {
121 |                     depth -= 1;
122 |                     if depth == 0 {
123 |                         if let Some(s) = start {
124 |                             let json_str = &text[s..=i];
125 |                             if let Ok(parsed) = serde_json::from_str::<Value>(json_str) {
126 |                                 if let Some(call) = self.value_to_tool_call(&parsed, call_id) {
127 |                                     calls.push(call);
128 |                                 }
129 |                             }
130 |                         }
131 |                         start = None;
132 |                     }
133 |                 }
134 |                 _ => {}
135 |             }
136 |         }
137 | 
138 |         if calls.is_empty() {
139 |             None
140 |         } else {
141 |             Some(calls)
142 |         }
143 |     }
144 | 
145 |     /// Parse tool calls from markdown code blocks
146 |     fn parse_code_block_format(&self, text: &str, call_id: &mut usize) -> Option<Vec<ToolCall>> {
147 |         let re = Regex::new(r"```(?:json)?\s*([\s\S]*?)\s*```").ok()?;
148 |         let mut calls = Vec::new();
149 | 
150 |         for cap in re.captures_iter(text) {
151 |             if let Some(content) = cap.get(1) {
152 |                 if let Ok(parsed) = serde_json::from_str::<Value>(content.as_str().trim()) {
153 |                     if let Some(call) = self.value_to_tool_call(&parsed, call_id) {
154 |                         calls.push(call);
155 |                     }
156 |                 }
157 |             }
158 |         }
159 | 
160 |         if calls.is_empty() {
161 |             None
162 |         } else {
163 |             Some(calls)
164 |         }
165 |     }
166 | 
167 |     /// Convert a JSON Value to a ToolCall if it has the right structure
168 |     fn value_to_tool_call(&self, value: &Value, call_id: &mut usize) -> Option<ToolCall> {
169 |         let name = value.get("name")?.as_str()?;
170 |         let arguments = value.get("arguments")?;
171 | 
172 |         let args_str = if arguments.is_string() {
173 |             arguments.as_str().unwrap().to_string()
174 |         } else {
175 |             serde_json::to_string(arguments).ok()?
176 |         };
177 | 
178 |         *call_id += 1;
179 |         Some(ToolCall::new(
180 |             format!("call_{}", call_id),
181 |             name.to_string(),
182 |             args_str,
183 |         ))
184 |     }
185 | 
186 |     /// Check if text contains any tool calls
187 |     pub fn has_tool_calls(&self, text: &str) -> bool {
188 |         // Quick checks for common patterns
189 |         text.contains("<tool_call>")
190 |             || (text.contains("\"name\"") && text.contains("\"arguments\""))
191 |     }
192 | }
193 | 
194 | #[cfg(test)]
195 | mod tests {
196 |     use super::*;
197 | 
198 |     #[test]
199 |     fn test_parse_qwen_format() {
200 |         let parser = ToolParser::new();
201 |         let text = r#"I'll help you with the weather.
202 | <tool_call>
203 | {"name": "get_weather", "arguments": {"location": "Tokyo", "unit": "celsius"}}
204 | </tool_call>"#;
205 | 
206 |         let calls = parser.parse(text);
207 |         assert_eq!(calls.len(), 1);
208 |         assert_eq!(calls[0].function.name, "get_weather");
209 |         assert!(calls[0].function.arguments.contains("Tokyo"));
210 |     }
211 | 
212 |     #[test]
213 |     fn test_parse_json_format() {
214 |         let parser = ToolParser::new();
215 |         let text = r#"{"name": "calculate", "arguments": {"expression": "2+2"}}"#;
216 | 
217 |         let calls = parser.parse(text);
218 |         assert_eq!(calls.len(), 1);
219 |         assert_eq!(calls[0].function.name, "calculate");
220 |     }
221 | 
222 |     #[test]
223 |     fn test_parse_code_block() {
224 |         let parser = ToolParser::new();
225 |         let text = r#"Let me search for that:
226 | 
227 | ```json
228 | {"name": "search", "arguments": {"query": "rust programming"}}
229 | ```"#;
230 | 
231 |         let calls = parser.parse(text);
232 |         assert_eq!(calls.len(), 1);
233 |         assert_eq!(calls[0].function.name, "search");
234 |     }
235 | 
236 |     #[test]
237 |     fn test_multiple_tool_calls() {
238 |         let parser = ToolParser::new();
239 |         let text = r#"<tool_call>
240 | {"name": "get_weather", "arguments": {"location": "Tokyo"}}
241 | </tool_call>
242 | <tool_call>
243 | {"name": "get_weather", "arguments": {"location": "London"}}
244 | </tool_call>"#;
245 | 
246 |         let calls = parser.parse(text);
247 |         assert_eq!(calls.len(), 2);
248 |     }
249 | 
250 |     #[test]
251 |     fn test_has_tool_calls() {
252 |         let parser = ToolParser::new();
253 | 
254 |         assert!(parser.has_tool_calls("<tool_call>{}</tool_call>"));
255 |         assert!(parser.has_tool_calls(r#"{"name": "foo", "arguments": {}}"#));
256 |         assert!(!parser.has_tool_calls("Just a normal response"));
257 |     }
258 | }
259 | 
--------------------------------------------------------------------------------
/example/tool_calling.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Tool Calling Example for vLLM.rs
4 | 
5 | This example demonstrates how to use tool calling with the vLLM.rs server.
6 | It shows:
7 | 1. Defining tools with JSON Schema
8 | 2. Sending chat completion requests with tools
9 | 3. Handling tool calls from the model
10 | 4. 
Sending tool results back to continue the conversation 11 | """ 12 | 13 | import requests 14 | import json 15 | import argparse 16 | 17 | # Server URL 18 | BASE_URL = "http://localhost:8000/v1" 19 | 20 | 21 | def define_tools(): 22 | """Define available tools with JSON Schema.""" 23 | return [ 24 | { 25 | "type": "function", 26 | "function": { 27 | "name": "get_weather", 28 | "description": "Get the current weather for a location", 29 | "parameters": { 30 | "type": "object", 31 | "properties": { 32 | "location": { 33 | "type": "string", 34 | "description": "The city name, e.g., 'Tokyo', 'New York'" 35 | }, 36 | "unit": { 37 | "type": "string", 38 | "enum": ["celsius", "fahrenheit"], 39 | "description": "Temperature unit" 40 | } 41 | }, 42 | "required": ["location"] 43 | } 44 | } 45 | }, 46 | { 47 | "type": "function", 48 | "function": { 49 | "name": "calculator", 50 | "description": "Evaluate a mathematical expression", 51 | "parameters": { 52 | "type": "object", 53 | "properties": { 54 | "expression": { 55 | "type": "string", 56 | "description": "The mathematical expression to evaluate, e.g., '2 + 2'" 57 | } 58 | }, 59 | "required": ["expression"] 60 | } 61 | } 62 | }, 63 | { 64 | "type": "function", 65 | "function": { 66 | "name": "search_web", 67 | "description": "Search the web for information", 68 | "parameters": { 69 | "type": "object", 70 | "properties": { 71 | "query": { 72 | "type": "string", 73 | "description": "The search query" 74 | } 75 | }, 76 | "required": ["query"] 77 | } 78 | } 79 | } 80 | ] 81 | 82 | 83 | def execute_tool(tool_name: str, arguments: dict) -> str: 84 | """Execute a tool and return the result.""" 85 | if tool_name == "get_weather": 86 | location = arguments.get("location", "Unknown") 87 | unit = arguments.get("unit", "celsius") 88 | # Simulated weather data 89 | temp = 22 if unit == "celsius" else 72 90 | return json.dumps({ 91 | "location": location, 92 | "temperature": temp, 93 | "unit": unit, 94 | "condition": "sunny", 95 | "humidity": 65 96 | }) 97 | 98 | elif tool_name == "calculator": 99 | expression = arguments.get("expression", "0") 100 | try: 101 | # WARNING: In production, use a safe expression parser! 
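            # One possible safe replacement (a sketch only, not wired into this
            # example): restrict evaluation to numeric literals and basic
            # arithmetic by walking the AST yourself, e.g.:
            #
            #   import ast, operator
            #   OPS = {ast.Add: operator.add, ast.Sub: operator.sub,
            #          ast.Mult: operator.mul, ast.Div: operator.truediv}
            #   def safe_eval(expr: str):
            #       def walk(node):
            #           if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            #               return node.value
            #           if isinstance(node, ast.BinOp) and type(node.op) in OPS:
            #               return OPS[type(node.op)](walk(node.left), walk(node.right))
            #           raise ValueError("unsupported expression")
            #       return walk(ast.parse(expr, mode="eval").body)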
102 | result = eval(expression) 103 | return json.dumps({"result": result}) 104 | except Exception as e: 105 | return json.dumps({"error": str(e)}) 106 | 107 | elif tool_name == "search_web": 108 | query = arguments.get("query", "") 109 | # Simulated search results 110 | return json.dumps({ 111 | "query": query, 112 | "results": [ 113 | {"title": f"Result 1 for '{query}'", "url": "https://example.com/1"}, 114 | {"title": f"Result 2 for '{query}'", "url": "https://example.com/2"} 115 | ] 116 | }) 117 | 118 | else: 119 | return json.dumps({"error": f"Unknown tool: {tool_name}"}) 120 | 121 | 122 | def chat_with_tools(user_message: str, tools: list, model: str = "default"): 123 | """Send a chat request with tools and handle tool calls.""" 124 | 125 | messages = [{"role": "user", "content": user_message}] 126 | 127 | print(f"📝 User: {user_message}") 128 | print("-" * 50) 129 | 130 | while True: 131 | # Make the API request 132 | response = requests.post( 133 | f"{BASE_URL}/chat/completions", 134 | json={ 135 | "model": model, 136 | "messages": messages, 137 | "tools": tools, 138 | "max_tokens": 2048 139 | } 140 | ) 141 | 142 | if response.status_code != 200: 143 | print(f"❌ Error: {response.status_code} - {response.text}") 144 | return 145 | 146 | result = response.json() 147 | choice = result["choices"][0] 148 | message = choice["message"] 149 | finish_reason = choice.get("finish_reason", "unknown") 150 | 151 | # Check if the model made tool calls 152 | if message.get("tool_calls"): 153 | print("🔧 Model is calling tools:") 154 | tool_results = [] 155 | 156 | for tool_call in message["tool_calls"]: 157 | tool_name = tool_call["function"]["name"] 158 | arguments = json.loads(tool_call["function"]["arguments"]) 159 | tool_call_id = tool_call["id"] 160 | 161 | print(f" - {tool_name}({json.dumps(arguments)})") 162 | 163 | # Execute the tool 164 | result = execute_tool(tool_name, arguments) 165 | print(f" → Result: {result}") 166 | 167 | tool_results.append({ 168 | "role": "tool", 169 | "tool_call_id": tool_call_id, 170 | "content": result 171 | }) 172 | 173 | # Add the assistant's tool call message and tool results 174 | messages.append({ 175 | "role": "assistant", 176 | "tool_calls": message["tool_calls"] 177 | }) 178 | messages.extend(tool_results) 179 | 180 | print("-" * 50) 181 | continue 182 | 183 | # Model gave a text response 184 | content = message.get("content", "") 185 | print(f"🤖 Assistant: {content}") 186 | 187 | if finish_reason != "tool_calls": 188 | break 189 | 190 | return content 191 | 192 | 193 | def main(): 194 | parser = argparse.ArgumentParser(description="Tool calling example for vLLM.rs") 195 | parser.add_argument("--url", default="http://localhost:8000/v1", 196 | help="Base URL for the vLLM.rs server") 197 | parser.add_argument("--model", default="default", 198 | help="Model name to use") 199 | args = parser.parse_args() 200 | 201 | global BASE_URL 202 | BASE_URL = args.url 203 | 204 | tools = define_tools() 205 | 206 | print("=" * 60) 207 | print("🛠️ vLLM.rs Tool Calling Demo") 208 | print("=" * 60) 209 | print() 210 | 211 | # Example 1: Weather query 212 | print("Example 1: Weather Query") 213 | print("=" * 60) 214 | chat_with_tools( 215 | "What's the weather like in Tokyo and New York?", 216 | tools, 217 | args.model 218 | ) 219 | print() 220 | 221 | # Example 2: Calculator 222 | print("Example 2: Calculator") 223 | print("=" * 60) 224 | chat_with_tools( 225 | "What is 25 * 17 + 43?", 226 | tools, 227 | args.model 228 | ) 229 | print() 230 | 231 | # Example 3: Web search 
232 | print("Example 3: Web Search") 233 | print("=" * 60) 234 | chat_with_tools( 235 | "Search for information about Rust programming language", 236 | tools, 237 | args.model 238 | ) 239 | print() 240 | 241 | # Interactive mode 242 | print("=" * 60) 243 | print("💬 Interactive Mode (type 'quit' to exit)") 244 | print("=" * 60) 245 | 246 | while True: 247 | try: 248 | user_input = input("\n🤖✨ Enter your prompt: ").strip() 249 | if user_input.lower() in ['quit', 'exit', 'q']: 250 | print("👋 Goodbye!") 251 | break 252 | if not user_input: 253 | continue 254 | 255 | print() 256 | chat_with_tools(user_input, tools, args.model) 257 | 258 | except KeyboardInterrupt: 259 | print("\n👋 Goodbye!") 260 | break 261 | except EOFError: 262 | print("\n👋 Goodbye!") 263 | break 264 | 265 | 266 | if __name__ == "__main__": 267 | main() 268 | -------------------------------------------------------------------------------- /src/api.rs: -------------------------------------------------------------------------------- 1 | use crate::core::engine::{LLMEngine, StreamItem, GLOBAL_RT}; 2 | use crate::core::{GenerationOutput, SyncCollectionResult}; 3 | use crate::server::{build_messages_and_images, run_server, ChatMessage}; 4 | use crate::utils::chat_template::Message; 5 | use crate::utils::config::{EngineConfig, SamplingParams}; 6 | use crate::utils::get_dtype; 7 | use candle_core::{DType, Result}; 8 | use parking_lot::RwLock; 9 | use std::str::FromStr; 10 | use std::sync::Arc; 11 | use tokio::sync::mpsc; 12 | 13 | #[derive(Clone, Debug)] 14 | pub enum ModelRepo { 15 | /// (model_id, filename) -- when filename is None, treat as safetensor model id. 16 | /// When filename is Some, treat as GGUF model id + GGUF filename. 17 | ModelID((&'static str, Option<&'static str>)), 18 | /// Safetensor local path. 19 | ModelPath(&'static str), 20 | /// GGUF file(s). Only the first file is used today. 
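/// e.g. `ModelRepo::ModelFile(vec!["/models/qwen3.gguf"])` (the path here is illustrative only).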
21 | ModelFile(Vec<&'static str>), 22 | } 23 | 24 | #[derive(Clone, Debug)] 25 | pub struct EngineBuilder { 26 | repo: ModelRepo, 27 | isq: Option<String>, 28 | dtype: Option<DType>, 29 | flash_attn: Option<bool>, 30 | fp8_kvcache: Option<bool>, 31 | context_cache: Option<bool>, 32 | device_ids: Option<Vec<usize>>, 33 | } 34 | 35 | impl EngineBuilder { 36 | pub fn new(repo: ModelRepo) -> Self { 37 | Self { 38 | repo, 39 | isq: None, 40 | dtype: None, 41 | flash_attn: None, 42 | fp8_kvcache: None, 43 | context_cache: None, 44 | device_ids: None, 45 | } 46 | } 47 | 48 | pub fn with_isq(mut self, isq: impl Into<String>) -> Self { 49 | self.isq = Some(isq.into()); 50 | self 51 | } 52 | 53 | pub fn with_dtype(mut self, dtype: DType) -> Self { 54 | self.dtype = Some(dtype); 55 | self 56 | } 57 | 58 | pub fn without_flash_attn(mut self) -> Self { 59 | self.flash_attn = Some(false); 60 | self 61 | } 62 | 63 | pub fn with_fp8_kvcache(mut self) -> Self { 64 | self.fp8_kvcache = Some(true); 65 | self 66 | } 67 | 68 | pub fn with_context_cache(mut self, enabled: bool) -> Self { 69 | self.context_cache = Some(enabled); 70 | self 71 | } 72 | 73 | pub fn with_multirank(mut self, device_ids: &str) -> Result<Self> { 74 | self.device_ids = Some(parse_device_ids(device_ids)?); 75 | Ok(self) 76 | } 77 | 78 | pub fn build(self) -> Result<Engine> { 79 | let disable_flash_attn = if let Some(enable_flash_attn) = self.flash_attn { 80 | Some(!enable_flash_attn) 81 | } else { 82 | None 83 | }; 84 | 85 | let (model_id, weight_path, weight_file) = match self.repo { 86 | ModelRepo::ModelID((model_id, filename)) => ( 87 | Some(model_id.to_owned()), 88 | None, 89 | filename.map(|f| f.to_owned()), 90 | ), 91 | ModelRepo::ModelPath(path) => (None, Some(path.to_owned()), None), 92 | ModelRepo::ModelFile(files) => { 93 | if files.len() > 1 { 94 | crate::log_warn!("Multiple GGUF files provided, using the first one."); 95 | } 96 | let weight_file = files.into_iter().next().map(|f| f.to_owned()); 97 | (None, None, weight_file) 98 | } 99 | }; 100 | 101 | let econfig = EngineConfig::new( 102 | model_id, 103 | weight_path, 104 | weight_file, 105 | None, 106 | None, 107 | None, 108 | None, 109 | None, 110 | None, 111 | self.isq, 112 | Some(self.device_ids.clone().unwrap_or(vec![0]).len()), 113 | self.device_ids.clone(), 114 | None, 115 | None, 116 | self.context_cache, 117 | self.fp8_kvcache, 118 | None, 119 | None, 120 | None, 121 | None, 122 | None, 123 | None, 124 | None, 125 | disable_flash_attn, 126 | ); 127 | 128 | let dtype = self.dtype.clone().map(dtype_to_str); 129 | let dtype = get_dtype(dtype); 130 | 131 | let engine = LLMEngine::new(&econfig, dtype)?; 132 | Ok(Engine { engine, econfig }) 133 | } 134 | } 135 | 136 | pub struct Engine { 137 | engine: Arc<RwLock<LLMEngine>>, 138 | econfig: EngineConfig, 139 | } 140 | 141 | impl Engine { 142 | pub fn start_server(&mut self, port: usize, with_ui_server: bool) -> Result<()> { 143 | GLOBAL_RT.block_on(async { 144 | run_server( 145 | self.engine.clone(), 146 | self.econfig.clone(), 147 | port, 148 | with_ui_server, 149 | ) 150 | .await 151 | }) 152 | } 153 | 154 | pub fn generate( 155 | &mut self, 156 | params: SamplingParams, 157 | messages: Vec<ChatMessage>, 158 | ) -> Result<GenerationOutput> { 159 | let img_cfg = { self.engine.read().img_cfg.clone() }; 160 | let (messages, image_data) = build_messages_and_images(&messages, img_cfg.as_ref())?; 161 | self.generate_messages(params, messages, image_data) 162 | } 163 | 164 | pub fn generate_messages( 165 | &mut self, 166 | params: SamplingParams, 167 | messages: Vec<Message>, 168 | images: Option<ImageData>, 169 | ) -> Result<GenerationOutput> { 170 | let (receivers, tokenizer)
= { 171 | let mut engine = self.engine.write(); 172 | ( 173 | engine.generate_sync(&vec![params], &vec![messages], images)?, 174 | Arc::new(engine.tokenizer.clone()), 175 | ) 176 | }; 177 | 178 | let results = GLOBAL_RT 179 | .block_on(async { LLMEngine::collect_sync_results(receivers, tokenizer).await })?; 180 | 181 | // Extract GenerationOutput from SyncCollectionResult 182 | for result in results { 183 | match result { 184 | SyncCollectionResult::Completed(output) => return Ok(output), 185 | SyncCollectionResult::ToolCallPause { .. } => { 186 | // In simple API, tool call pause is not supported - return error 187 | candle_core::bail!("Tool call detected but simple API does not support MCP tool calling. Use the server API instead."); 188 | } 189 | } 190 | } 191 | 192 | candle_core::bail!("No generation output returned") 193 | } 194 | 195 | pub fn generate_stream( 196 | &mut self, 197 | params: SamplingParams, 198 | messages: Vec<ChatMessage>, 199 | ) -> Result<EngineStream> { 200 | let img_cfg = { self.engine.read().img_cfg.clone() }; 201 | let (messages, image_data) = build_messages_and_images(&messages, img_cfg.as_ref())?; 202 | 203 | let (seq_id, prompt_length, stream) = { 204 | let mut engine = self.engine.write(); 205 | engine.generate_stream(&params, &messages, image_data)? 206 | }; 207 | 208 | Ok(EngineStream { 209 | engine: self.engine.clone(), 210 | rx: stream, 211 | finished: false, 212 | seq_id, 213 | prompt_length, 214 | cancelled: false, 215 | }) 216 | } 217 | 218 | pub fn get_num_cached_tokens(&self) -> usize { 219 | let engine = self.engine.read(); 220 | engine.get_num_cached_tokens() 221 | } 222 | 223 | pub fn get_available_kv_tokens(&self) -> usize { 224 | let engine = self.engine.read(); 225 | engine.get_available_kv_tokens() 226 | } 227 | } 228 | 229 | pub struct EngineStream { 230 | engine: Arc<RwLock<LLMEngine>>, 231 | rx: mpsc::Receiver<StreamItem>, 232 | finished: bool, 233 | pub seq_id: usize, 234 | pub prompt_length: usize, 235 | cancelled: bool, 236 | } 237 | 238 | impl EngineStream { 239 | pub fn cancel(&mut self) { 240 | self.cancelled = true; 241 | let mut engine_guard = self.engine.write(); 242 | engine_guard.cancel(self.seq_id); 243 | } 244 | 245 | pub async fn recv(&mut self) -> Option<StreamItem> { 246 | if self.finished { 247 | return None; 248 | } 249 | let item = self.rx.recv().await; 250 | if matches!(item, Some(StreamItem::Done(_) | StreamItem::Error(_))) { 251 | self.finished = true; 252 | } 253 | item 254 | } 255 | 256 | pub fn recv_blocking(&mut self) -> Option<StreamItem> { 257 | if self.finished { 258 | return None; 259 | } 260 | let item = GLOBAL_RT.block_on(self.rx.recv()); 261 | if matches!(item, Some(StreamItem::Done(_) | StreamItem::Error(_))) { 262 | self.finished = true; 263 | } 264 | item 265 | } 266 | 267 | pub fn is_finished(&self) -> bool { 268 | self.finished 269 | } 270 | 271 | pub fn is_cancelled(&self) -> bool { 272 | self.cancelled 273 | } 274 | } 275 | 276 | fn parse_device_ids(device_ids: &str) -> Result<Vec<usize>> { 277 | device_ids 278 | .split(',') 279 | .filter(|s| !s.trim().is_empty()) 280 | .map(|s| { 281 | usize::from_str(s.trim()) 282 | .map_err(|e| candle_core::Error::msg(format!("Invalid device id '{s}': {e}"))) 283 | }) 284 | .collect() 285 | } 286 | 287 | fn dtype_to_str(dtype: DType) -> String { 288 | match dtype { 289 | DType::F16 => "f16".to_string(), 290 | DType::BF16 => "bf16".to_string(), 291 | DType::F32 => "f32".to_string(), 292 | _ => "bf16".to_string(), 293 | } 294 | } 295 | -------------------------------------------------------------------------------- /src/models/mistral3_vl/mod.rs:
-------------------------------------------------------------------------------- 1 | use crate::models::layers::distributed::{Comm, ReplicatedLinear}; 2 | use crate::models::layers::others::{rms_norm, NormX}; 3 | use crate::models::layers::VarBuilderX; 4 | use crate::models::llama::LLaMaForCausalLM; 5 | use crate::utils::config::Config; 6 | use crate::utils::image::ImageData; 7 | use crate::utils::progress::ProgressLike; 8 | use attention_rs::InputMetadata; 9 | use candle_core::{DType, Device, Result, Tensor, D}; 10 | use parking_lot::RwLock; 11 | use std::rc::Rc; 12 | use std::sync::Arc; 13 | mod config; 14 | mod vision; 15 | use attention_rs::ops::{NonZeroOp, SplitOp}; 16 | pub use config::Mistral3Config; 17 | use vision::VisionModel; 18 | 19 | struct PatchMerger { 20 | merge: ReplicatedLinear, 21 | spatial_merge_size: usize, 22 | patch_size: usize, 23 | } 24 | 25 | impl PatchMerger { 26 | fn new(cfg: &Mistral3Config, vb: VarBuilderX, dtype: DType) -> Result<Self> { 27 | Ok(Self { 28 | merge: ReplicatedLinear::load_no_bias( 29 | cfg.vision_config.hidden_size * cfg.spatial_merge_size.pow(2), 30 | cfg.vision_config.hidden_size, 31 | vb.pp("merging_layer"), 32 | &None, 33 | &None, 34 | dtype, 35 | )?, 36 | spatial_merge_size: cfg.spatial_merge_size, 37 | patch_size: cfg.vision_config.patch_size, 38 | }) 39 | } 40 | 41 | fn forward(&self, image_features: &Tensor, image_sizes: Vec<(usize, usize)>) -> Result<Tensor> { 42 | let image_sizes = image_sizes 43 | .iter() 44 | .map(|&(h, w)| (h / self.patch_size, w / self.patch_size)) 45 | .collect::<Vec<_>>(); 46 | 47 | let tokens_per_image = image_sizes.iter().map(|&(h, w)| h * w).collect::<Vec<_>>(); 48 | let d = image_features.dim(D::Minus1)?; 49 | 50 | let mut permuted_tensor = Vec::new(); 51 | 52 | for (image_index, image_tokens) in image_features 53 | .split(&tokens_per_image, 0)? 54 | .iter() 55 | .enumerate() 56 | { 57 | let (h, w) = image_sizes[image_index]; 58 | let image_grid = image_tokens 59 | .reshape((h, w, d))? 60 | .permute((2, 0, 1))? 61 | .unsqueeze(0)?; 62 | let grid = { 63 | let patches = image_grid 64 | .unfold(2, self.spatial_merge_size, self.spatial_merge_size)? 65 | .unfold(3, self.spatial_merge_size, self.spatial_merge_size)?; 66 | 67 | let patches = patches.permute((0, 1, 4, 5, 2, 3))?; 68 | patches.contiguous()?.reshape(( 69 | 1, 70 | d * self.spatial_merge_size * self.spatial_merge_size, 71 | (), 72 | ))? 73 | }; 74 | let grid = grid 75 | .reshape((d * self.spatial_merge_size.pow(2), ()))?
76 | .t()?; 77 | permuted_tensor.push(grid); 78 | } 79 | 80 | let image_features = Tensor::cat(&permuted_tensor, 0)?; 81 | 82 | self.merge.forward(&image_features) 83 | } 84 | } 85 | 86 | struct MultiModalProjector { 87 | norm: NormX, 88 | linear_1: ReplicatedLinear, 89 | linear_2: ReplicatedLinear, 90 | act: candle_nn::Activation, 91 | patch_merger: PatchMerger, 92 | } 93 | 94 | impl MultiModalProjector { 95 | fn new(cfg: &Mistral3Config, vb: VarBuilderX, dtype: DType) -> Result<Self> { 96 | let is_qvar_builder = vb.is_qvar_builder(); 97 | let norm = rms_norm( 98 | cfg.vision_config.hidden_size, 99 | cfg.text_config.rms_norm_eps, 100 | vb.pp("norm"), 101 | if is_qvar_builder { DType::F32 } else { dtype }, 102 | false, 103 | )?; 104 | let num_feature_layers = 1; 105 | let linear_1 = ReplicatedLinear::load_b( 106 | cfg.vision_config.hidden_size * num_feature_layers, 107 | cfg.text_config.hidden_size, 108 | cfg.multimodal_projector_bias, 109 | if is_qvar_builder { 110 | vb.pp("ln1") 111 | } else { 112 | vb.pp("linear_1") 113 | }, 114 | &cfg.text_config.quantization_config, 115 | &cfg.text_config.quant, 116 | dtype, 117 | )?; 118 | 119 | let linear_2 = ReplicatedLinear::load_b( 120 | cfg.text_config.hidden_size, 121 | cfg.text_config.hidden_size, 122 | cfg.multimodal_projector_bias, 123 | if is_qvar_builder { 124 | vb.pp("ln2") 125 | } else { 126 | vb.pp("linear_2") 127 | }, 128 | &cfg.text_config.quantization_config, 129 | &cfg.text_config.quant, 130 | dtype, 131 | )?; 132 | 133 | let patch_merger = PatchMerger::new(cfg, vb.pp("patch_merger"), dtype)?; 134 | Ok(Self { 135 | norm, 136 | linear_1, 137 | linear_2, 138 | act: cfg.projector_hidden_act, 139 | patch_merger, 140 | }) 141 | } 142 | 143 | fn forward(&self, image_features: &Tensor, image_sizes: Vec<(usize, usize)>) -> Result<Tensor> { 144 | let mut hidden_states = self.norm.forward(image_features)?; 145 | hidden_states = self.patch_merger.forward(&hidden_states, image_sizes)?; 146 | hidden_states = self.linear_1.forward(&hidden_states)?.apply(&self.act)?; 147 | self.linear_2.forward(&hidden_states) 148 | } 149 | } 150 | 151 | pub struct Mistral3ForConditionalGeneration { 152 | text_model: LLaMaForCausalLM, 153 | vision_model: VisionModel, 154 | mmproj: MultiModalProjector, 155 | cfg: Mistral3Config, 156 | } 157 | 158 | impl Mistral3ForConditionalGeneration { 159 | pub fn new( 160 | vb: &VarBuilderX, 161 | comm: Rc<Comm>, 162 | config: &Config, 163 | dtype: DType, 164 | is_rope_i: bool, 165 | device: &Device, 166 | progress_reporter: Arc<RwLock<Box<dyn ProgressLike>>>, 167 | ) -> Result<Self> { 168 | assert!( 169 | config.extra_config_json.is_some(), 170 | "Invalid multimodal config file!"
171 | ); 172 | let mut cfg: Mistral3Config = 173 | serde_json::from_str(config.extra_config_json.as_ref().unwrap()) 174 | .map_err(candle_core::Error::wrap)?; 175 | cfg.text_config = config.clone(); 176 | let vision_model = VisionModel::new( 177 | &cfg.vision_config, 178 | vb.pp("vision_tower"), 179 | comm.clone(), 180 | dtype, 181 | )?; 182 | let mmproj = MultiModalProjector::new(&cfg, vb.pp("multi_modal_projector"), dtype)?; 183 | 184 | let text_model = LLaMaForCausalLM::new( 185 | &vb.pp("language_model"), 186 | comm.clone(), 187 | &cfg.text_config, 188 | dtype, 189 | is_rope_i, 190 | device, 191 | progress_reporter, 192 | )?; 193 | 194 | Ok(Self { 195 | vision_model, 196 | text_model, 197 | mmproj, 198 | cfg: cfg.clone(), 199 | }) 200 | } 201 | 202 | fn vision_tower( 203 | &self, 204 | image_features: &Tensor, 205 | image_sizes: Vec<(usize, usize)>, 206 | ) -> Result<Tensor> { 207 | let image_outputs = self 208 | .vision_model 209 | .forward(image_features, image_sizes.clone())?; 210 | let selected_image_feature = image_outputs; 211 | self.mmproj 212 | .forward(&selected_image_feature.squeeze(0)?, image_sizes) 213 | } 214 | 215 | #[allow(clippy::too_many_arguments)] 216 | pub fn forward( 217 | &self, 218 | input_ids: &Tensor, 219 | positions: &Tensor, 220 | kv_caches: Option<&Vec<(Tensor, Tensor)>>, 221 | input_metadata: &InputMetadata, 222 | images: Option<&ImageData>, 223 | ) -> Result<Tensor> { 224 | let (mut input_embeds, dtype) = ( 225 | self.text_model.embed_forward(input_ids)?, 226 | self.text_model.dtype(), 227 | ); 228 | 229 | if let Some(images) = &images { 230 | let mut image_tensor = images.to_tensor_f32(&input_ids.device())?.to_dtype(dtype)?; 231 | let image_mask = input_ids.eq(self.cfg.image_token_index as u32)?; 232 | let image_mask = image_mask 233 | .unsqueeze(D::Minus1)? 234 | .broadcast_as(input_embeds.shape().clone())? 235 | .to_dtype(DType::U32)?; 236 | 237 | let indices = image_mask.flatten_all()?.nonzero()?.squeeze(1)?; 238 | let mut image_sizes = images.patches.clone(); 239 | let num_images = image_tensor.dim(0)?; 240 | assert!( 241 | num_images == image_sizes.len(), 242 | "Input image and patch dim mismatch!" 243 | ); 244 | if images.image_idx > 0 && (images.image_idx as usize) < num_images { 245 | image_tensor = image_tensor.narrow( 246 | 0, 247 | images.image_idx as usize, 248 | num_images - images.image_idx as usize, 249 | )?; 250 | image_sizes = image_sizes[images.image_idx as usize..].to_vec(); 251 | crate::log_warn!( 252 | "Slicing images: start idx {} -> {:?}", 253 | images.image_idx, 254 | image_tensor.shape() 255 | ); 256 | } 257 | 258 | let image_features = self 259 | .vision_tower(&image_tensor, image_sizes)? 260 | .to_dtype(input_embeds.dtype())?; 261 | 262 | let mut x_flat = input_embeds.flatten_all()?; 263 | let image_flat = image_features.flatten_all()?; 264 | 265 | x_flat = 266 | x_flat.scatter_add(&indices, &(image_flat - x_flat.gather(&indices, 0)?)?, 0)?; 267 | input_embeds = x_flat.reshape(input_embeds.shape())?; 268 | } 269 | 270 | self.text_model 271 | .forward(&input_embeds, positions, kv_caches, input_metadata, true) 272 | } 273 | 274 | pub fn get_vocab_size(&self) -> usize { 275 | panic!("not impl") 276 | } 277 | } 278 | -------------------------------------------------------------------------------- /src/mcp/types.rs: -------------------------------------------------------------------------------- 1 | // src/mcp/types.rs 2 | //!
MCP protocol types and message definitions 3 | 4 | use serde::{Deserialize, Serialize}; 5 | use serde_json::Value; 6 | use std::collections::HashMap; 7 | 8 | /// MCP protocol version 9 | pub const MCP_VERSION: &str = "2024-11-05"; 10 | 11 | /// JSON-RPC request structure 12 | #[derive(Debug, Clone, Serialize, Deserialize)] 13 | pub struct JsonRpcRequest { 14 | pub jsonrpc: String, 15 | pub id: RequestId, 16 | pub method: String, 17 | #[serde(default, skip_serializing_if = "Option::is_none")] 18 | pub params: Option<Value>, 19 | } 20 | 21 | impl JsonRpcRequest { 22 | pub fn new(id: impl Into<RequestId>, method: impl Into<String>, params: Option<Value>) -> Self { 23 | Self { 24 | jsonrpc: "2.0".to_string(), 25 | id: id.into(), 26 | method: method.into(), 27 | params, 28 | } 29 | } 30 | } 31 | 32 | /// JSON-RPC response structure 33 | #[derive(Debug, Clone, Serialize, Deserialize)] 34 | pub struct JsonRpcResponse { 35 | pub jsonrpc: String, 36 | pub id: RequestId, 37 | #[serde(skip_serializing_if = "Option::is_none")] 38 | pub result: Option<Value>, 39 | #[serde(skip_serializing_if = "Option::is_none")] 40 | pub error: Option<JsonRpcError>, 41 | } 42 | 43 | impl JsonRpcResponse { 44 | pub fn success(id: RequestId, result: Value) -> Self { 45 | Self { 46 | jsonrpc: "2.0".to_string(), 47 | id, 48 | result: Some(result), 49 | error: None, 50 | } 51 | } 52 | 53 | pub fn error(id: RequestId, error: JsonRpcError) -> Self { 54 | Self { 55 | jsonrpc: "2.0".to_string(), 56 | id, 57 | result: None, 58 | error: Some(error), 59 | } 60 | } 61 | } 62 | 63 | /// JSON-RPC notification (no response expected) 64 | #[derive(Debug, Clone, Serialize, Deserialize)] 65 | pub struct JsonRpcNotification { 66 | pub jsonrpc: String, 67 | pub method: String, 68 | #[serde(default, skip_serializing_if = "Option::is_none")] 69 | pub params: Option<Value>, 70 | } 71 | 72 | /// JSON-RPC error 73 | #[derive(Debug, Clone, Serialize, Deserialize)] 74 | pub struct JsonRpcError { 75 | pub code: i32, 76 | pub message: String, 77 | #[serde(skip_serializing_if = "Option::is_none")] 78 | pub data: Option<Value>, 79 | } 80 | 81 | impl JsonRpcError { 82 | pub fn parse_error() -> Self { 83 | Self { 84 | code: -32700, 85 | message: "Parse error".to_string(), 86 | data: None, 87 | } 88 | } 89 | 90 | pub fn invalid_request() -> Self { 91 | Self { 92 | code: -32600, 93 | message: "Invalid Request".to_string(), 94 | data: None, 95 | } 96 | } 97 | 98 | pub fn method_not_found() -> Self { 99 | Self { 100 | code: -32601, 101 | message: "Method not found".to_string(), 102 | data: None, 103 | } 104 | } 105 | 106 | pub fn invalid_params(msg: impl Into<String>) -> Self { 107 | Self { 108 | code: -32602, 109 | message: msg.into(), 110 | data: None, 111 | } 112 | } 113 | 114 | pub fn internal_error(msg: impl Into<String>) -> Self { 115 | Self { 116 | code: -32603, 117 | message: msg.into(), 118 | data: None, 119 | } 120 | } 121 | } 122 | 123 | /// Request ID (can be string or number) 124 | #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] 125 | #[serde(untagged)] 126 | pub enum RequestId { 127 | String(String), 128 | Number(i64), 129 | } 130 | 131 | impl From<String> for RequestId { 132 | fn from(s: String) -> Self { 133 | RequestId::String(s) 134 | } 135 | } 136 | 137 | impl From<&str> for RequestId { 138 | fn from(s: &str) -> Self { 139 | RequestId::String(s.to_string()) 140 | } 141 | } 142 | 143 | impl From<i64> for RequestId { 144 | fn from(n: i64) -> Self { 145 | RequestId::Number(n) 146 | } 147 | } 148 | 149 | /// MCP server capabilities 150 | #[derive(Debug, Clone, Default, Serialize, Deserialize)] 151 | pub struct
ServerCapabilities { 152 | #[serde(skip_serializing_if = "Option::is_none")] 153 | pub tools: Option<ToolsCapability>, 154 | #[serde(skip_serializing_if = "Option::is_none")] 155 | pub resources: Option<ResourcesCapability>, 156 | #[serde(skip_serializing_if = "Option::is_none")] 157 | pub prompts: Option<PromptsCapability>, 158 | #[serde(skip_serializing_if = "Option::is_none")] 159 | pub logging: Option<LoggingCapability>, 160 | } 161 | 162 | #[derive(Debug, Clone, Default, Serialize, Deserialize)] 163 | pub struct ToolsCapability { 164 | #[serde(default)] 165 | pub list_changed: bool, 166 | } 167 | 168 | #[derive(Debug, Clone, Default, Serialize, Deserialize)] 169 | pub struct ResourcesCapability { 170 | #[serde(default)] 171 | pub subscribe: bool, 172 | #[serde(default)] 173 | pub list_changed: bool, 174 | } 175 | 176 | #[derive(Debug, Clone, Default, Serialize, Deserialize)] 177 | pub struct PromptsCapability { 178 | #[serde(default)] 179 | pub list_changed: bool, 180 | } 181 | 182 | #[derive(Debug, Clone, Default, Serialize, Deserialize)] 183 | pub struct LoggingCapability {} 184 | 185 | /// MCP client capabilities 186 | #[derive(Debug, Clone, Default, Serialize, Deserialize)] 187 | pub struct ClientCapabilities { 188 | #[serde(skip_serializing_if = "Option::is_none")] 189 | pub roots: Option<RootsCapability>, 190 | #[serde(skip_serializing_if = "Option::is_none")] 191 | pub sampling: Option<SamplingCapability>, 192 | } 193 | 194 | #[derive(Debug, Clone, Default, Serialize, Deserialize)] 195 | pub struct RootsCapability { 196 | #[serde(default)] 197 | pub list_changed: bool, 198 | } 199 | 200 | #[derive(Debug, Clone, Default, Serialize, Deserialize)] 201 | pub struct SamplingCapability {} 202 | 203 | /// Initialize request parameters 204 | #[derive(Debug, Clone, Serialize, Deserialize)] 205 | #[serde(rename_all = "camelCase")] 206 | pub struct InitializeParams { 207 | pub protocol_version: String, 208 | pub capabilities: ClientCapabilities, 209 | pub client_info: Implementation, 210 | } 211 | 212 | /// Initialize response result 213 | #[derive(Debug, Clone, Serialize, Deserialize)] 214 | #[serde(rename_all = "camelCase")] 215 | pub struct InitializeResult { 216 | pub protocol_version: String, 217 | pub capabilities: ServerCapabilities, 218 | pub server_info: Implementation, 219 | #[serde(skip_serializing_if = "Option::is_none")] 220 | pub instructions: Option<String>, 221 | } 222 | 223 | /// Implementation info (client or server) 224 | #[derive(Debug, Clone, Serialize, Deserialize)] 225 | pub struct Implementation { 226 | pub name: String, 227 | pub version: String, 228 | } 229 | 230 | /// MCP Tool definition (different from our internal Tool) 231 | #[derive(Debug, Clone, Serialize, Deserialize)] 232 | #[serde(rename_all = "camelCase")] 233 | pub struct McpTool { 234 | pub name: String, 235 | #[serde(skip_serializing_if = "Option::is_none")] 236 | pub description: Option<String>, 237 | pub input_schema: Value, 238 | #[serde(skip_serializing_if = "Option::is_none")] 239 | pub output_schema: Option<Value>, 240 | } 241 | 242 | /// List tools result 243 | #[derive(Debug, Clone, Serialize, Deserialize)] 244 | pub struct ListToolsResult { 245 | pub tools: Vec<McpTool>, 246 | #[serde(skip_serializing_if = "Option::is_none")] 247 | pub next_cursor: Option<String>, 248 | } 249 | 250 | /// Call tool parameters 251 | #[derive(Debug, Clone, Serialize, Deserialize)] 252 | pub struct CallToolParams { 253 | pub name: String, 254 | #[serde(default)] 255 | pub arguments: HashMap<String, Value>, 256 | } 257 | 258 | /// Call tool result 259 | #[derive(Debug, Clone, Serialize, Deserialize)] 260 | #[serde(rename_all = "camelCase")] 261 | pub struct CallToolResult { 262
| pub content: Vec<ToolContent>, 263 | #[serde(default)] 264 | pub is_error: bool, 265 | } 266 | 267 | /// Tool content types 268 | #[derive(Debug, Clone, Serialize, Deserialize)] 269 | #[serde(tag = "type", rename_all = "lowercase")] 270 | pub enum ToolContent { 271 | Text { 272 | text: String, 273 | }, 274 | Image { 275 | data: String, 276 | mime_type: String, 277 | }, 278 | Resource { 279 | uri: String, 280 | mime_type: Option<String>, 281 | text: Option<String>, 282 | }, 283 | } 284 | 285 | impl ToolContent { 286 | pub fn text(text: impl Into<String>) -> Self { 287 | ToolContent::Text { text: text.into() } 288 | } 289 | } 290 | 291 | /// MCP Resource 292 | #[derive(Debug, Clone, Serialize, Deserialize)] 293 | #[serde(rename_all = "camelCase")] 294 | pub struct Resource { 295 | pub uri: String, 296 | pub name: String, 297 | #[serde(skip_serializing_if = "Option::is_none")] 298 | pub description: Option<String>, 299 | #[serde(skip_serializing_if = "Option::is_none")] 300 | pub mime_type: Option<String>, 301 | } 302 | 303 | /// MCP Prompt 304 | #[derive(Debug, Clone, Serialize, Deserialize)] 305 | pub struct Prompt { 306 | pub name: String, 307 | #[serde(skip_serializing_if = "Option::is_none")] 308 | pub description: Option<String>, 309 | #[serde(default, skip_serializing_if = "Vec::is_empty")] 310 | pub arguments: Vec<PromptArgument>, 311 | } 312 | 313 | #[derive(Debug, Clone, Serialize, Deserialize)] 314 | pub struct PromptArgument { 315 | pub name: String, 316 | #[serde(skip_serializing_if = "Option::is_none")] 317 | pub description: Option<String>, 318 | #[serde(default)] 319 | pub required: bool, 320 | } 321 | 322 | #[cfg(test)] 323 | mod tests { 324 | use super::*; 325 | 326 | #[test] 327 | fn test_json_rpc_request() { 328 | let req = JsonRpcRequest::new(1i64, "tools/list", None); 329 | let json = serde_json::to_string(&req).unwrap(); 330 | assert!(json.contains("\"jsonrpc\":\"2.0\"")); 331 | assert!(json.contains("\"method\":\"tools/list\"")); 332 | } 333 | 334 | #[test] 335 | fn test_json_rpc_response() { 336 | let resp = JsonRpcResponse::success(RequestId::Number(1), serde_json::json!({"tools": []})); 337 | assert!(resp.result.is_some()); 338 | assert!(resp.error.is_none()); 339 | } 340 | 341 | #[test] 342 | fn test_mcp_tool() { 343 | let tool = McpTool { 344 | name: "get_weather".to_string(), 345 | description: Some("Get weather info".to_string()), 346 | input_schema: serde_json::json!({ 347 | "type": "object", 348 | "properties": { 349 | "location": {"type": "string"} 350 | } 351 | }), 352 | output_schema: None, 353 | }; 354 | 355 | let json = serde_json::to_string(&tool).unwrap(); 356 | let parsed: McpTool = serde_json::from_str(&json).unwrap(); 357 | assert_eq!(parsed.name, "get_weather"); 358 | } 359 | } 360 | --------------------------------------------------------------------------------
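Editor's sketch (not a file from this repository): the tests above only cover serialization of single types, so the following minimal Rust example shows how CallToolParams, JsonRpcRequest, and CallToolResult from src/mcp/types.rs compose into one tools/call exchange. It assumes those types are in scope (for example, inside the same module); the tool name, arguments, and response payload are invented for illustration.

use serde_json::json;
use std::collections::HashMap;

fn call_tool_roundtrip_demo() -> serde_json::Result<()> {
    // Client -> server: wrap CallToolParams in a JSON-RPC "tools/call" request.
    let mut arguments = HashMap::new();
    arguments.insert("location".to_string(), json!("Tokyo"));
    let params = CallToolParams { name: "get_weather".to_string(), arguments };
    let request = JsonRpcRequest::new(1i64, "tools/call", Some(serde_json::to_value(&params)?));
    println!("{}", serde_json::to_string(&request)?);

    // Server -> client: deserialize a result; `isError` defaults to false when omitted.
    let raw = json!({ "content": [{ "type": "text", "text": "22C, clear" }] });
    let result: CallToolResult = serde_json::from_value(raw)?;
    assert!(!result.is_error);
    Ok(())
}

The same pattern applies to tools/list, where the response result field deserializes into ListToolsResult.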