├── .isort.cfg ├── .gitmodules ├── python ├── sglang │ ├── launch_server.py │ ├── srt │ │ ├── managers │ │ │ ├── openai_protocol.py │ │ │ ├── router │ │ │ │ ├── manager.py │ │ │ │ ├── scheduler.py │ │ │ │ ├── radix_cache.py │ │ │ │ └── infer_batch.py │ │ │ ├── io_struct.py │ │ │ ├── detokenizer_manager.py │ │ │ └── tokenizer_manager.py │ │ ├── layers │ │ │ ├── activation.py │ │ │ ├── layernorm.py │ │ │ ├── quantization │ │ │ │ ├── base_config.py │ │ │ │ ├── awq.py │ │ │ │ └── awq_triton.py │ │ │ ├── get_selected_logprob.py │ │ │ ├── logits_processor.py │ │ │ ├── vocab_parallel_embedding.py │ │ │ ├── context_flashattention_nopad.py │ │ │ ├── radix_attention.py │ │ │ ├── token_attention.py │ │ │ └── extend_attention.py │ │ ├── model_config.py │ │ ├── constrained │ │ │ ├── fsm_cache.py │ │ │ ├── tokenizer.py │ │ │ └── fsm.py │ │ ├── parallel_utils │ │ │ ├── utils.py │ │ │ └── parallel_state.py │ │ ├── sampling_params.py │ │ ├── server.py │ │ ├── memory_pool.py │ │ ├── server_args.py │ │ ├── hf_transformers_utils.py │ │ ├── models │ │ │ ├── llava.py │ │ │ └── llama2.py │ │ └── utils.py │ └── utils.py └── pyproject.toml ├── docs ├── flashinfer.md └── test_process.md ├── run.sh ├── .pre-commit-config.yaml ├── LICENSE ├── .gitignore └── README.md /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | profile=black 3 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "3rdparty/flashinfer"] 2 | path = 3rdparty/flashinfer 3 | url = git@github.com:flashinfer-ai/flashinfer.git 4 | -------------------------------------------------------------------------------- /python/sglang/launch_server.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from sglang.srt.server import ServerArgs, launch_server 4 | 5 | if __name__ == "__main__": 6 | parser = argparse.ArgumentParser() 7 | ServerArgs.add_cli_args(parser) 8 | args = parser.parse_args() 9 | server_args = ServerArgs.from_cli_args(args) 10 | 11 | launch_server(server_args) 12 | -------------------------------------------------------------------------------- /python/sglang/srt/managers/openai_protocol.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, List, Optional, Union 3 | 4 | 5 | @dataclass 6 | class CompletionRequest: 7 | prompt: Union[str, List[Any]] 8 | model: str = "default" 9 | temperature: Optional[float] = 0.7 10 | max_tokens: Optional[int] = 16 11 | n: Optional[int] = 1 12 | stop: Optional[Union[str, List[str]]] = None 13 | -------------------------------------------------------------------------------- /docs/flashinfer.md: -------------------------------------------------------------------------------- 1 | ## Flashinfer Mode 2 | 3 | [flashinfer](https://github.com/flashinfer-ai/flashinfer) is a kernel library for LLM serving. 4 | It can be used in SGLang runtime to accelerate attention computation. 5 | 6 | ### Install flashinfer 7 | 8 | Note: The compilation can take a very long time. 9 | 10 | ```bash 11 | git submodule update --init --recursive 12 | pip install 3rdparty/flashinfer/python 13 | ``` 14 | 15 | ### Run a Server With Flashinfer Mode 16 | 17 | Add `--model-mode flashinfer` argument to enable flashinfer when launching a server. 
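Before launching, you can confirm the installation from the previous step succeeded; the library should be importable as `flashinfer`:

```bash
python -c "import flashinfer"
```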
18 | 19 | Example: 20 | 21 | ```bash 22 | python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --model-mode flashinfer 23 | ``` 24 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -ex 3 | 4 | export CUDA_LAUNCH_BLOCKING=1 5 | export TORCH_USE_CUDA_DSA=1 # Enable device-side assertions for debugging 6 | 7 | # python3 -m sglang.launch_server \ 8 | # --model-path /tmp-data/models/llama-2-7b \ 9 | # --port 30000 \ 10 | # --mem-fraction-static 0.8 --tp 2 11 | 12 | # python3 -m sglang.launch_server \ 13 | # --model-path /tmp-data/models/Qwen3-8B \ 14 | # --port 30000 \ 15 | # --mem-fraction-static 0.8 --tp 1 16 | 17 | # --tp 1 \ 18 | # --trust-remote-code --host 0.0.0.0 19 | 20 | python3 -m sglang.launch_server \ 21 | --model-path /tmp-data/models/Llama-2-7B-AWQ \ 22 | --port 30000 \ 23 | --mem-fraction-static 0.8 --tp 2 24 | 25 | # python3 -m sglang.launch_server \ 26 | # --model-path /tmp-data/models/llama-2-7b \ 27 | # --port 30000 \ 28 | # --mem-fraction-static 0.8 --tp 2 \ 29 | # --model-mode flashinfer 30 | -------------------------------------------------------------------------------- /python/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "nano-sglang" 7 | version = "0.0.1" 8 | description = "nano SGLang: LLM inference framework" 9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | license = { file = "LICENSE" } 12 | classifiers = ["Programming Language :: Python :: 3"] 13 | dependencies = ["requests"] 14 | 15 | [project.optional-dependencies] 16 | all = [ 17 | "fastapi", 18 | "psutil", 19 | "rpyc", 20 | "torch>=2.8.0", 21 | "transformers", 22 | "triton", 23 | "uvloop", 24 | "uvicorn", 25 | "zmq", 26 | "interegular", 27 | "lark", 28 | "numba", 29 | "tqdm", 30 | "numpy", 31 | "requests", 32 | "huggingface_hub", 33 | "safetensors", 34 | "Pillow", 35 | "filelock", 36 | ] 37 | 38 | [project.urls] 39 | "Homepage" = "https://github.com/gogongxt/nano-sglang" 40 | "Bug Tracker" = "https://github.com/gogongxt/nano-sglang/issues" 41 | -------------------------------------------------------------------------------- /python/sglang/srt/layers/activation.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class SiluAndMul(nn.Module): 10 | """An activation function for SwiGLU. 11 | 12 | The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2. 
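    In a Llama-style MLP, x is typically the concatenation of the gate and up
    projections along the last dimension, so the first half is gated by SiLU and
    multiplied elementwise into the second half.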
13 | 14 | Shapes: 15 | x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d) 16 | return: (batch_size, seq_len, d) or (num_tokens, d) 17 | """ 18 | 19 | def forward(self, x: torch.Tensor) -> torch.Tensor: 20 | """PyTorch-native implementation equivalent to forward().""" 21 | d = x.shape[-1] // 2 22 | return F.silu(x[..., :d]) * x[..., d:] 23 | 24 | # def forward(self, x: torch.Tensor) -> torch.Tensor: 25 | # d = x.shape[-1] // 2 26 | # output_shape = x.shape[:-1] + (d,) 27 | # out = torch.empty(output_shape, dtype=x.dtype, device=x.device) 28 | # ops.silu_and_mul(out, x) 29 | # return out 30 | -------------------------------------------------------------------------------- /python/sglang/utils.py: -------------------------------------------------------------------------------- 1 | """Common utilities.""" 2 | 3 | 4 | def get_available_gpu_memory(gpu_id, distributed=True): 5 | """ 6 | Get available memory for cuda:gpu_id device. 7 | When distributed is True, the available memory is the minimum available memory of all GPUs. 8 | """ 9 | import torch 10 | 11 | num_gpus = torch.cuda.device_count() 12 | assert gpu_id < num_gpus 13 | 14 | if torch.cuda.current_device() != gpu_id: 15 | print( 16 | f"WARN: current device is not {gpu_id}, but {torch.cuda.current_device()}, ", 17 | "which may cause useless memory allocation for torch CUDA context.", 18 | ) 19 | 20 | free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id) 21 | 22 | if distributed: 23 | tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to( 24 | torch.device("cuda", gpu_id) 25 | ) 26 | torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN) 27 | free_gpu_memory = tensor.item() 28 | 29 | return free_gpu_memory / (1 << 30) 30 | -------------------------------------------------------------------------------- /python/sglang/srt/model_config.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | from sglang.srt.hf_transformers_utils import get_config, get_context_length 4 | 5 | 6 | class ModelConfig: 7 | def __init__( 8 | self, 9 | path: str, 10 | trust_remote_code: bool = True, 11 | revision: Optional[str] = None, 12 | ) -> None: 13 | self.path = path 14 | self.trust_remote_code = trust_remote_code 15 | self.revision = revision 16 | self.hf_config = get_config(self.path, trust_remote_code, revision) 17 | 18 | # Unify the config keys for hf_config 19 | self.context_len = get_context_length(self.hf_config) 20 | self.head_dim = self.hf_config.hidden_size // self.hf_config.num_attention_heads 21 | self.num_key_value_heads = self.hf_config.num_key_value_heads 22 | self.num_attention_heads = self.hf_config.num_attention_heads 23 | self.hidden_size = self.hf_config.hidden_size 24 | self.num_hidden_layers = self.hf_config.num_hidden_layers 25 | self.vocab_size = self.hf_config.vocab_size 26 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_stages: [pre-commit, pre-push, manual] 2 | 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v5.0.0 6 | hooks: 7 | - id: check-symlinks 8 | - id: destroyed-symlinks 9 | - id: trailing-whitespace 10 | - id: end-of-file-fixer 11 | - id: check-yaml 12 | args: [--allow-multiple-documents] 13 | - id: check-toml 14 | - id: check-ast 15 | - id: check-added-large-files 16 | - id: check-merge-conflict 17 | - id: 
check-shebang-scripts-are-executable 18 | - id: detect-private-key 19 | - id: debug-statements 20 | # - id: no-commit-to-branch 21 | - repo: https://github.com/PyCQA/isort 22 | rev: 6.0.1 23 | hooks: 24 | - id: isort 25 | - repo: https://github.com/psf/black 26 | rev: 25.1.0 27 | hooks: 28 | - id: black-jupyter 29 | - repo: https://github.com/pre-commit/mirrors-clang-format 30 | rev: v20.1.8 31 | hooks: 32 | - id: clang-format 33 | types_or: [c++, cuda] 34 | args: [--style=file, --verbose] 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 XiaoTian Gong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /docs/test_process.md: -------------------------------------------------------------------------------- 1 | ## SRT Unit Tests 2 | 3 | ### Low-level API 4 | ``` 5 | cd sglang/test/srt/model 6 | 7 | python3 test_llama_low_api.py 8 | python3 test_llama_extend.py 9 | python3 test_llava_low_api.py 10 | python3 bench_llama_low_api.py 11 | ``` 12 | 13 | ### High-level API 14 | 15 | ``` 16 | python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 17 | ``` 18 | 19 | ``` 20 | cd test/lang 21 | python3 test_srt_backend.py 22 | ``` 23 | 24 | ### Performance 25 | 26 | #### MMLU 27 | ``` 28 | cd benchmark/mmlu 29 | ``` 30 | Follow README.md to download the data. 
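The benchmark script runs against the SRT server launched in the High-level API section above, so keep that server running. A minimal smoke test against the `/generate` endpoint (request fields follow `GenerateReqInput` and `SamplingParams` in this repo) is:

```
curl http://127.0.0.1:30000/generate \
  -H "Content-Type: application/json" \
  -d '{"text": "The capital of France is", "sampling_params": {"temperature": 0, "max_new_tokens": 16}}'
```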
31 | 32 | ``` 33 | python3 bench_sglang.py --nsub 3 34 | 35 | # Expected performance on A10G 36 | # Total latency: 8.200 37 | # Average accuracy: 0.413 38 | ``` 39 | 40 | ### More Models 41 | 42 | #### LLaVA 43 | 44 | ``` 45 | python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000 46 | ``` 47 | 48 | ``` 49 | cd benchmark/llava_bench 50 | python3 bench_sglang.py 51 | ``` 52 | 53 | ## SGLang Unit Tests 54 | ``` 55 | export ANTHROPIC_API_KEY= 56 | export OPENAI_API_KEY= 57 | python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 58 | ``` 59 | 60 | ``` 61 | cd test/lang 62 | python3 run_all.py 63 | ``` 64 | -------------------------------------------------------------------------------- /python/sglang/srt/constrained/fsm_cache.py: -------------------------------------------------------------------------------- 1 | import threading 2 | 3 | from sglang.srt.constrained.fsm import RegexFSM 4 | from sglang.srt.constrained.tokenizer import TransformerTokenizer 5 | 6 | 7 | def get_fsm(regex, tokenizer, fsm_cache_entry): 8 | outlines_tokenizer = TransformerTokenizer(tokenizer) 9 | fsm = RegexFSM(regex, outlines_tokenizer) 10 | fsm_cache_entry.fsm = fsm 11 | fsm_cache_entry.event.set() 12 | 13 | 14 | class FSMCacheEntry: 15 | def __init__(self): 16 | self.fsm = None 17 | self.event = threading.Event() 18 | 19 | 20 | class FSMCache: 21 | def __init__(self, tokenizer): 22 | self.cache = {} 23 | self.tokenizer = tokenizer 24 | 25 | def init_fsm_in_background(self, regex): 26 | if regex not in self.cache: 27 | self.cache[regex] = FSMCacheEntry() 28 | threading.Thread( 29 | target=get_fsm, 30 | args=( 31 | regex, 32 | self.tokenizer, 33 | self.cache[regex], 34 | ), 35 | ).start() 36 | 37 | def get_fsm(self, regex): 38 | self.init_fsm_in_background(regex) 39 | entry = self.cache[regex] 40 | entry.event.wait() 41 | return entry.fsm 42 | -------------------------------------------------------------------------------- /python/sglang/srt/parallel_utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The vLLM team. 2 | # Adapted from 3 | # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py 4 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 5 | from typing import Sequence 6 | 7 | import torch 8 | 9 | 10 | def ensure_divisibility(numerator, denominator): 11 | """Ensure that numerator is divisible by the denominator.""" 12 | assert numerator % denominator == 0, "{} is not divisible by {}".format( 13 | numerator, denominator 14 | ) 15 | 16 | 17 | def divide(numerator, denominator): 18 | """Ensure that numerator is divisible by the denominator and return 19 | the division value.""" 20 | ensure_divisibility(numerator, denominator) 21 | return numerator // denominator 22 | 23 | 24 | def split_tensor_along_last_dim( 25 | tensor: torch.Tensor, 26 | num_partitions: int, 27 | contiguous_split_chunks: bool = False, 28 | ) -> Sequence[torch.Tensor]: 29 | """Split a tensor along its last dimension. 30 | 31 | Arguments: 32 | tensor: input tensor. 33 | num_partitions: number of partitions to split the tensor 34 | contiguous_split_chunks: If True, make each chunk contiguous 35 | in memory. 36 | 37 | Returns: 38 | A list of Tensors 39 | """ 40 | # Get the size and dimension. 41 | last_dim = tensor.dim() - 1 42 | last_dim_size = divide(tensor.size()[last_dim], num_partitions) 43 | # Split. 
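    # For example, splitting a (batch, seq, 4096) tensor into 2 partitions
    # yields two (batch, seq, 2048) views that share storage with the input.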
44 | tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) 45 | # NOTE: torch.split does not create contiguous tensors by default. 46 | if contiguous_split_chunks: 47 | return tuple(chunk.contiguous() for chunk in tensor_list) 48 | 49 | return tensor_list 50 | -------------------------------------------------------------------------------- /python/sglang/srt/layers/layernorm.py: -------------------------------------------------------------------------------- 1 | """Custom normalization layers.""" 2 | 3 | from typing import Optional, Tuple, Union 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class RMSNorm(nn.Module): 10 | """Root mean square normalization. 11 | 12 | Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight. 13 | Refer to https://arxiv.org/abs/1910.07467 14 | """ 15 | 16 | def __init__( 17 | self, 18 | hidden_size: int, 19 | eps: float = 1e-6, 20 | ) -> None: 21 | super().__init__() 22 | self.weight = nn.Parameter(torch.ones(hidden_size)) 23 | self.variance_epsilon = eps 24 | 25 | def forward( 26 | self, 27 | x: torch.Tensor, 28 | residual: Optional[torch.Tensor] = None, 29 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: 30 | """PyTorch-native implementation equivalent to forward().""" 31 | orig_dtype = x.dtype 32 | x = x.to(torch.float32) 33 | if residual is not None: 34 | x = x + residual.to(torch.float32) 35 | residual = x.to(orig_dtype) 36 | 37 | variance = x.pow(2).mean(dim=-1, keepdim=True) 38 | x = x * torch.rsqrt(variance + self.variance_epsilon) 39 | x = x.to(orig_dtype) * self.weight 40 | if residual is None: 41 | return x 42 | else: 43 | return x, residual 44 | 45 | # def forward( 46 | # self, 47 | # x: torch.Tensor, 48 | # residual: Optional[torch.Tensor] = None, 49 | # ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: 50 | # if residual is not None: 51 | # ops.fused_add_rms_norm( 52 | # x, 53 | # residual, 54 | # self.weight.data, 55 | # self.variance_epsilon, 56 | # ) 57 | # return x, residual 58 | # out = torch.empty_like(x) 59 | # ops.rms_norm( 60 | # out, 61 | # x, 62 | # self.weight.data, 63 | # self.variance_epsilon, 64 | # ) 65 | # return out 66 | -------------------------------------------------------------------------------- /python/sglang/srt/layers/quantization/base_config.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, Dict, List 3 | 4 | import torch 5 | from sglang.srt.layers.linear import LinearMethodBase 6 | 7 | 8 | class QuantizationConfig(ABC): 9 | """Base class for quantization configs.""" 10 | 11 | @abstractmethod 12 | def get_name(self) -> str: 13 | """Name of the quantization method.""" 14 | raise NotImplementedError 15 | 16 | @abstractmethod 17 | def get_supported_act_dtypes(self) -> List[torch.dtype]: 18 | """List of supported activation dtypes.""" 19 | raise NotImplementedError 20 | 21 | @abstractmethod 22 | def get_min_capability(self) -> int: 23 | """Minimum GPU capability to support the quantization method. 24 | 25 | E.g., 70 for Volta, 75 for Turing, 80 for Ampere. 26 | This requirement is due to the custom CUDA kernels used by the 27 | quantization method. 
28 | """ 29 | raise NotImplementedError 30 | 31 | @staticmethod 32 | @abstractmethod 33 | def get_config_filenames() -> List[str]: 34 | """List of filenames to search for in the model directory.""" 35 | raise NotImplementedError 36 | 37 | @classmethod 38 | @abstractmethod 39 | def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig": 40 | """Create a config class from the model's quantization config.""" 41 | raise NotImplementedError 42 | 43 | @staticmethod 44 | def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any: 45 | """Get a value from the model's quantization config.""" 46 | for key in keys: 47 | if key in config: 48 | return config[key] 49 | raise ValueError( 50 | f"Cannot find any of {keys} in the model's " "quantization config." 51 | ) 52 | 53 | @abstractmethod 54 | def get_linear_method(self) -> LinearMethodBase: 55 | """Get the linear method to use for the quantized linear layer.""" 56 | raise NotImplementedError 57 | 58 | @abstractmethod 59 | def get_scaled_act_names(self) -> List[str]: 60 | """Returns the activation function names that should be post-scaled. 61 | 62 | For now, this is only used by AWQ. 63 | """ 64 | raise NotImplementedError 65 | -------------------------------------------------------------------------------- /python/sglang/srt/layers/get_selected_logprob.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | from sglang.srt.utils import wrap_kernel_launcher 5 | 6 | 7 | @triton.jit 8 | def _fwd_segmented_gather( 9 | all_logits, 10 | len_add_1, 11 | cum_len, 12 | input_ids, 13 | logprobs, 14 | max_seq_len, 15 | voc_size: tl.constexpr, 16 | BLOCK_SIZE: tl.constexpr, 17 | ): 18 | cur_req = tl.program_id(0) 19 | cur_l = tl.load(len_add_1 + cur_req) 20 | cum_l = tl.load(cum_len + cur_req) 21 | 22 | for i in range(0, (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE): 23 | off = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 24 | mask = off < cur_l - 1 25 | 26 | idx = tl.load(input_ids + cum_l - cur_l + off + 1, mask=mask) 27 | data = tl.load(all_logits + (cum_l - cur_l + off) * voc_size + idx, mask=mask) 28 | tl.store(logprobs + cum_l - cur_l - cur_req + off, data, mask=mask) 29 | 30 | 31 | cached_kernel = None 32 | 33 | 34 | def get_selected_logprob(all_logits, len_add_1, input_ids, logprobs): 35 | cum_len = torch.cumsum(len_add_1, dtype=torch.int32, dim=0) 36 | voc_size = all_logits.shape[1] 37 | grid = (len_add_1.shape[0], 1, 1) 38 | max_seq_len = len_add_1.max().item() 39 | 40 | global cached_kernel 41 | if cached_kernel: 42 | cached_kernel( 43 | grid, 44 | 4, 45 | all_logits, 46 | len_add_1, 47 | cum_len, 48 | input_ids, 49 | logprobs, 50 | max_seq_len, 51 | voc_size, 52 | BLOCK_SIZE=128, 53 | ) 54 | return 55 | 56 | # Launch kernel using modern Triton API 57 | kernel_launcher = wrap_kernel_launcher(_fwd_segmented_gather) 58 | kernel_launcher( 59 | grid, 60 | 4, 61 | all_logits, 62 | len_add_1, 63 | cum_len, 64 | input_ids, 65 | logprobs, 66 | max_seq_len, 67 | voc_size, 68 | BLOCK_SIZE=128, 69 | ) 70 | cached_kernel = kernel_launcher 71 | 72 | 73 | if __name__ == "__main__": 74 | all_logits = torch.tensor( 75 | # s s s 76 | [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]], 77 | dtype=torch.float32, 78 | device="cuda", 79 | ) 80 | len_add_1 = torch.tensor([2, 3], dtype=torch.int32, device="cuda") 81 | input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda") 82 | logprobs = torch.empty((3), 
dtype=torch.float32, device="cuda") 83 | get_selected_logprobs(all_logits, len_add_1, input_ids, logprobs) 84 | print(logprobs) 85 | # assert logprobs == [2, 2, 4] 86 | -------------------------------------------------------------------------------- /python/sglang/srt/managers/router/manager.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | 4 | import uvloop 5 | import zmq 6 | import zmq.asyncio 7 | from sglang.srt.managers.router.model_rpc import ModelRpcClient 8 | from sglang.srt.server_args import PortArgs, ServerArgs 9 | from sglang.srt.utils import get_exception_traceback 10 | 11 | asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) 12 | 13 | 14 | class RouterManager: 15 | def __init__(self, model_client: ModelRpcClient, port_args: PortArgs): 16 | # Init communication 17 | context = zmq.asyncio.Context(2) 18 | self.recv_from_tokenizer = context.socket(zmq.PULL) 19 | self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}") 20 | 21 | self.send_to_detokenizer = context.socket(zmq.PUSH) 22 | self.send_to_detokenizer.connect( 23 | f"tcp://127.0.0.1:{port_args.detokenizer_port}" 24 | ) 25 | 26 | # Init status 27 | self.model_client = model_client 28 | self.recv_reqs = [] 29 | 30 | # Init Some Configs 31 | self.extend_dependency_time = 0.03 32 | 33 | async def loop_for_forward(self): 34 | while True: 35 | next_step_input = list(self.recv_reqs) 36 | self.recv_reqs = [] 37 | out_pyobjs = await self.model_client.step(next_step_input) 38 | 39 | for obj in out_pyobjs: 40 | self.send_to_detokenizer.send_pyobj(obj) 41 | 42 | # async sleep for recving the subsequent request, and avoiding cache miss 43 | if len(out_pyobjs) != 0: 44 | has_finished = any([obj.finished for obj in out_pyobjs]) 45 | if has_finished: 46 | await asyncio.sleep(self.extend_dependency_time) 47 | 48 | await asyncio.sleep(0.001) 49 | 50 | async def loop_for_recv_requests(self): 51 | while True: 52 | recv_req = await self.recv_from_tokenizer.recv_pyobj() 53 | self.recv_reqs.append(recv_req) 54 | 55 | 56 | def start_router_process( 57 | server_args: ServerArgs, 58 | port_args: PortArgs, 59 | pipe_writer, 60 | ): 61 | logging.basicConfig( 62 | level=getattr(logging, server_args.log_level.upper()), 63 | format="%(message)s", 64 | ) 65 | 66 | try: 67 | model_client = ModelRpcClient(server_args, port_args) 68 | router = RouterManager(model_client, port_args) 69 | except Exception: 70 | pipe_writer.send(get_exception_traceback()) 71 | raise 72 | 73 | pipe_writer.send("init ok") 74 | 75 | loop = asyncio.new_event_loop() 76 | asyncio.set_event_loop(loop) 77 | loop.create_task(router.loop_for_recv_requests()) 78 | loop.run_until_complete(router.loop_for_forward()) 79 | -------------------------------------------------------------------------------- /python/sglang/srt/managers/router/scheduler.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import defaultdict 3 | 4 | 5 | class Scheduler: 6 | def __init__( 7 | self, 8 | schedule_heuristic, 9 | max_running_seq, 10 | max_prefill_num_token, 11 | max_total_num_token, 12 | tree_cache, 13 | ): 14 | self.schedule_heuristic = schedule_heuristic 15 | self.max_running_seq = max_running_seq 16 | self.max_prefill_num_token = max_prefill_num_token 17 | self.max_total_num_token = max_total_num_token 18 | self.tree_cache = tree_cache 19 | 20 | def new_token_estimation_ratio(self): 21 | return 0.5 if self.schedule_heuristic != "fcfs" 
else 0.6 22 | 23 | def get_priority_queue(self, forward_queue): 24 | if self.schedule_heuristic == "lpm": 25 | # longest prefix match 26 | forward_queue.sort(key=lambda x: -len(x.prefix_indices)) 27 | return forward_queue 28 | elif self.schedule_heuristic == "random": 29 | random.shuffle(forward_queue) 30 | return forward_queue 31 | elif self.schedule_heuristic == "fcfs": 32 | return forward_queue 33 | elif self.schedule_heuristic == "weight": 34 | last_node_to_reqs = defaultdict(list) 35 | for req in forward_queue: 36 | last_node_to_reqs[req.last_node].append(req) 37 | for node in last_node_to_reqs: 38 | last_node_to_reqs[node].sort(key=lambda x: -len(x.prefix_indices)) 39 | 40 | node_to_weight = defaultdict(int) 41 | self._calc_weight_recursive( 42 | self.tree_cache.root_node, last_node_to_reqs, node_to_weight 43 | ) 44 | 45 | tmp_queue = [] 46 | self._get_weight_priority_recursive( 47 | self.tree_cache.root_node, node_to_weight, last_node_to_reqs, tmp_queue 48 | ) 49 | assert len(tmp_queue) == len(forward_queue) 50 | return tmp_queue 51 | else: 52 | raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}") 53 | 54 | def _calc_weight_recursive(self, cur_node, last_node_to_reqs, node_to_weight): 55 | node_to_weight[cur_node] = 1 56 | if cur_node in last_node_to_reqs: 57 | node_to_weight[cur_node] += len(last_node_to_reqs[cur_node]) 58 | for child in cur_node.children.values(): 59 | self._calc_weight_recursive(child, last_node_to_reqs, node_to_weight) 60 | node_to_weight[cur_node] += node_to_weight[child] 61 | 62 | def _get_weight_priority_recursive( 63 | self, cur_node, node_to_wight, last_node_to_reqs, tmp_queue 64 | ): 65 | visit_list = [child for child in cur_node.children.values()] 66 | visit_list.sort(key=lambda x: -node_to_wight[x]) 67 | # for node in visit_list: 68 | # print(f"{node_to_wight[node]} {len(node.value) if node.value is not None else 0}") 69 | for child in visit_list: 70 | self._get_weight_priority_recursive( 71 | child, node_to_wight, last_node_to_reqs, tmp_queue 72 | ) 73 | tmp_queue.extend(last_node_to_reqs[cur_node]) 74 | -------------------------------------------------------------------------------- /python/sglang/srt/sampling_params.py: -------------------------------------------------------------------------------- 1 | """Sampling parameters for text generation.""" 2 | 3 | from typing import List, Optional, Union 4 | 5 | _SAMPLING_EPS = 1e-6 6 | 7 | 8 | class SamplingParams: 9 | def __init__( 10 | self, 11 | max_new_tokens: int = 16, 12 | stop: Optional[Union[str, List[str]]] = None, 13 | temperature: float = 1.0, 14 | top_p: float = 1.0, 15 | top_k: int = -1, 16 | frequency_penalty: float = 0.0, 17 | presence_penalty: float = 0.0, 18 | ignore_eos: bool = False, 19 | skip_special_tokens: bool = True, 20 | dtype: Optional[str] = None, 21 | regex: Optional[str] = None, 22 | ) -> None: 23 | self.temperature = temperature 24 | self.top_p = top_p 25 | self.top_k = top_k 26 | self.frequency_penalty = frequency_penalty 27 | self.presence_penalty = presence_penalty 28 | self.stop_strs = stop 29 | self.max_new_tokens = max_new_tokens 30 | self.ignore_eos = ignore_eos 31 | self.skip_special_tokens = skip_special_tokens 32 | self.dtype = dtype 33 | self.regex = regex 34 | 35 | # Process some special cases 36 | if self.temperature < _SAMPLING_EPS: 37 | self.temperature = 1.0 38 | self.top_k = 1 39 | if self.top_k == -1: 40 | self.top_k = 1 << 30 # whole vocabulary 41 | if self.dtype == "int": 42 | self.stop_strs = [" ", "\n"] 43 | 44 | def verify(self): 45 
| if self.temperature < 0.0: 46 | raise ValueError( 47 | f"temperature must be non-negative, got {self.temperature}." 48 | ) 49 | if not 0.0 < self.top_p <= 1.0: 50 | raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.") 51 | if self.top_k < -1 or self.top_k == 0: 52 | raise ValueError( 53 | f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}." 54 | ) 55 | if not -2.0 <= self.frequency_penalty <= 2.0: 56 | raise ValueError( 57 | "frequency_penalty must be in [-2, 2], got " 58 | f"{self.frequency_penalty}." 59 | ) 60 | if not -2.0 <= self.presence_penalty <= 2.0: 61 | raise ValueError( 62 | "presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}." 63 | ) 64 | if self.max_new_tokens < 0: 65 | raise ValueError( 66 | f"max_new_tokens must be at least 0, got {self.max_new_tokens}." 67 | ) 68 | 69 | def normalize(self, tokenizer): 70 | # Process stop strings 71 | if self.stop_strs is None: 72 | self.stop_strs = [] 73 | self.stop_str_max_len = 0 74 | else: 75 | if isinstance(self.stop_strs, str): 76 | self.stop_strs = [self.stop_strs] 77 | 78 | stop_str_max_len = 0 79 | for stop_str in self.stop_strs: 80 | stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False) 81 | stop_str_max_len = max(stop_str_max_len, len(stop_str_ids)) 82 | self.stop_str_max_len = stop_str_max_len 83 | -------------------------------------------------------------------------------- /python/sglang/srt/managers/io_struct.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from dataclasses import dataclass 3 | from typing import Dict, List, Optional, Union 4 | 5 | from sglang.srt.sampling_params import SamplingParams 6 | 7 | 8 | @dataclass 9 | class GenerateReqInput: 10 | text: Union[List[str], str] 11 | image_data: Optional[Union[List[str], str]] = None 12 | sampling_params: Union[List[Dict], Dict] = None 13 | rid: Optional[Union[List[str], str]] = None 14 | return_normalized_logprob: Optional[Union[List[bool], bool]] = None 15 | normalized_logprob_start_len: Optional[Union[List[int], int]] = None 16 | stream: bool = False 17 | 18 | def post_init(self): 19 | is_single = isinstance(self.text, str) 20 | 21 | if is_single: 22 | if self.sampling_params is None: 23 | self.sampling_params = {} 24 | if self.rid is None: 25 | self.rid = uuid.uuid4().hex 26 | if self.return_normalized_logprob is None: 27 | self.return_normalized_logprob = False 28 | if self.normalized_logprob_start_len is None: 29 | self.normalized_logprob_start_len = 0 30 | else: 31 | num = len(self.text) 32 | 33 | if self.image_data is None: 34 | self.image_data = [None] * num 35 | elif not isinstance(self.image_data, list): 36 | self.image_data = [self.image_data] * num 37 | 38 | if self.sampling_params is None: 39 | self.sampling_params = [{}] * num 40 | elif not isinstance(self.sampling_params, list): 41 | self.sampling_params = [self.sampling_params] * num 42 | 43 | if self.rid is None: 44 | self.rid = [uuid.uuid4().hex for _ in range(num)] 45 | else: 46 | assert isinstance(self.rid, list) 47 | 48 | if self.return_normalized_logprob is None: 49 | self.return_normalized_logprob = [False] * num 50 | elif not isinstance(self.return_normalized_logprob, list): 51 | self.return_normalized_logprob = [self.return_normalized_logprob] * num 52 | 53 | if self.normalized_logprob_start_len is None: 54 | self.normalized_logprob_start_len = [0] * num 55 | elif not isinstance(self.normalized_logprob_start_len, list): 56 | self.normalized_logprob_start_len = [ 57 | 
self.normalized_logprob_start_len 58 | ] * num 59 | 60 | 61 | @dataclass 62 | class TokenizedGenerateReqInput: 63 | rid: str 64 | input_ids: List[int] 65 | pixel_values: List[float] 66 | image_hash: int 67 | sampling_params: SamplingParams 68 | return_normalized_logprob: bool 69 | normalized_logprob_start_len: int 70 | stream: bool 71 | 72 | 73 | @dataclass 74 | class BatchTokenIDOut: 75 | rids: List[str] 76 | output_tokens: List[List[int]] 77 | hit_stop_str: List[Optional[str]] 78 | skip_special_tokens: List[bool] 79 | meta_info: List[Dict] 80 | finished: List[bool] 81 | 82 | 83 | @dataclass 84 | class BatchStrOut: 85 | rids: List[str] 86 | output_str: List[str] 87 | meta_info: List[Dict] 88 | finished: List[bool] 89 | -------------------------------------------------------------------------------- /python/sglang/srt/managers/detokenizer_manager.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import uvloop 4 | import zmq 5 | import zmq.asyncio 6 | from sglang.srt.hf_transformers_utils import get_tokenizer 7 | from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut 8 | from sglang.srt.server_args import PortArgs, ServerArgs 9 | from sglang.srt.utils import get_exception_traceback 10 | 11 | asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) 12 | 13 | 14 | class DetokenizerManager: 15 | def __init__( 16 | self, 17 | server_args: ServerArgs, 18 | port_args: PortArgs, 19 | ): 20 | context = zmq.asyncio.Context(2) 21 | self.recv_from_router = context.socket(zmq.PULL) 22 | self.recv_from_router.bind(f"tcp://127.0.0.1:{port_args.detokenizer_port}") 23 | 24 | self.send_to_tokenizer = context.socket(zmq.PUSH) 25 | self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{port_args.tokenizer_port}") 26 | 27 | self.tokenizer = get_tokenizer( 28 | server_args.tokenizer_path, 29 | tokenizer_mode=server_args.tokenizer_mode, 30 | trust_remote_code=server_args.trust_remote_code, 31 | ) 32 | 33 | async def handle_loop(self): 34 | while True: 35 | recv_obj = await self.recv_from_router.recv_pyobj() 36 | 37 | if isinstance(recv_obj, BatchTokenIDOut): 38 | output_tokens = recv_obj.output_tokens 39 | 40 | # TODO(lmzheng): handle skip_special_tokens per request 41 | output_strs = self.tokenizer.batch_decode( 42 | output_tokens, 43 | skip_special_tokens=recv_obj.skip_special_tokens[0], 44 | ) 45 | 46 | # Trim stop str 47 | # TODO(lmzheng): handle the case where multiple stop strs are hit 48 | for i in range(len(output_strs)): 49 | if recv_obj.hit_stop_str[i] is not None: 50 | pos = output_strs[i].find(recv_obj.hit_stop_str[i]) 51 | if pos != -1: 52 | output_strs[i] = output_strs[i][:pos] 53 | 54 | if len(output_tokens[i]) > 0: 55 | first_token = self.tokenizer.convert_ids_to_tokens( 56 | int(output_tokens[i][0]) 57 | ) 58 | if first_token.startswith("▁"): 59 | output_strs[i] = " " + output_strs[i] 60 | 61 | self.send_to_tokenizer.send_pyobj( 62 | BatchStrOut( 63 | recv_obj.rids, 64 | output_strs, 65 | recv_obj.meta_info, 66 | recv_obj.finished, 67 | ) 68 | ) 69 | else: 70 | raise ValueError(f"Invalid object: {recv_obj}") 71 | 72 | 73 | def start_detokenizer_process( 74 | server_args: ServerArgs, 75 | port_args: PortArgs, 76 | pipe_writer, 77 | ): 78 | try: 79 | manager = DetokenizerManager(server_args, port_args) 80 | except Exception as e: 81 | pipe_writer.send(get_exception_traceback()) 82 | raise 83 | pipe_writer.send("init ok") 84 | loop = uvloop.new_event_loop() 85 | asyncio.set_event_loop(loop) 86 | 
loop.run_until_complete(manager.handle_loop()) 87 | -------------------------------------------------------------------------------- /python/sglang/srt/layers/logits_processor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from sglang.srt.layers.get_selected_logprob import get_selected_logprob 3 | from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata 4 | from sglang.srt.parallel_utils.parallel_state import ( 5 | get_tensor_model_parallel_world_size, 6 | tensor_model_parallel_all_gather, 7 | ) 8 | from torch import nn 9 | 10 | 11 | class LogitsProcessor(nn.Module): 12 | def __init__(self, config): 13 | super().__init__() 14 | self.config = config 15 | self.tp_size = get_tensor_model_parallel_world_size() 16 | 17 | def forward(self, input_ids, hidden_states, weight, input_metadata): 18 | if not input_metadata.return_normalized_logprob: 19 | if input_metadata.forward_mode == ForwardMode.DECODE: 20 | # For decode mode, hidden_states should be [batch, hidden_dim] 21 | # But if it has an extra sequence dimension, extract the last token 22 | if hidden_states.dim() == 3: 23 | last_hidden = hidden_states[ 24 | :, -1, : 25 | ] # [batch, seq, hidden] -> [batch, hidden] 26 | else: 27 | last_hidden = hidden_states 28 | else: 29 | last_index = ( 30 | torch.cumsum( 31 | input_metadata.seq_lens - input_metadata.prefix_lens, 32 | dim=0, 33 | dtype=torch.long, 34 | ) 35 | - 1 36 | ) 37 | # Clamp last_index to prevent out of bounds 38 | last_index = torch.clamp(last_index, 0, hidden_states.shape[0] - 1) 39 | last_hidden = hidden_states[last_index] 40 | hidden_states = None 41 | 42 | last_logits = torch.matmul(last_hidden, weight.T) 43 | if self.tp_size > 1: 44 | last_logits = tensor_model_parallel_all_gather(last_logits) 45 | last_logits = last_logits[:, : self.config.vocab_size] 46 | return last_logits, None 47 | else: 48 | assert input_metadata.forward_mode != ForwardMode.DECODE 49 | last_index = ( 50 | torch.cumsum( 51 | input_metadata.seq_lens - input_metadata.prefix_lens, 52 | dim=0, 53 | dtype=torch.long, 54 | ) 55 | - 1 56 | ) 57 | 58 | logits = torch.matmul(hidden_states, weight.T) 59 | if self.tp_size > 1: 60 | logits = tensor_model_parallel_all_gather(logits) 61 | logits = logits[:, : self.config.vocab_size] 62 | all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6) 63 | 64 | normalized_logprobs = compute_normalized_logprobs( 65 | all_logprobs, 66 | input_metadata.seq_lens - input_metadata.prefix_lens, 67 | input_ids, 68 | ) 69 | 70 | # Clamp last_index to prevent out of bounds 71 | last_index = torch.clamp(last_index, 0, logits.shape[0] - 1) 72 | last_logits = logits[last_index] 73 | return last_logits, normalized_logprobs 74 | 75 | 76 | def compute_normalized_logprobs(all_logprobs, len_add_1, input_ids): 77 | # assert all_logprobs.shape[0] == torch.sum(len_add_1) == input_ids.shape[0] 78 | logprobs = torch.zeros( 79 | (all_logprobs.shape[0] - len_add_1.shape[0]), dtype=torch.float32, device="cuda" 80 | ) 81 | get_selected_logprob(all_logprobs, len_add_1, input_ids, logprobs) 82 | cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32) 83 | end = torch.cumsum(len_add_1.sub_(1), dim=0) 84 | start = torch.cat((torch.tensor([0], device="cuda"), end[:-1]), 0) 85 | end.sub_(1) 86 | sum_logp = cumsum[end] - cumsum[start] + logprobs[start] 87 | res = sum_logp / len_add_1 88 | return res 89 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | 162 | # MacOS 163 | .DS_Store 164 | *.json 165 | 166 | # Vim 167 | *.swp 168 | 169 | # SGL 170 | benchmark/mmlu/data 171 | benchmark/mmlu/data.tar 172 | benchmark/llava_bench/images 173 | benchmark/llava_bench/mme_pack 174 | *.jsonl 175 | tmp*.txt 176 | 177 | # Plots 178 | *.png 179 | *.pdf 180 | -------------------------------------------------------------------------------- /python/sglang/srt/server.py: -------------------------------------------------------------------------------- 1 | """SRT: SGLang Runtime""" 2 | 3 | import asyncio 4 | import json 5 | import multiprocessing as mp 6 | import sys 7 | import threading 8 | 9 | # Fix a Python bug 10 | setattr(threading, "_register_atexit", lambda *args, **kwargs: None) 11 | 12 | import uvicorn 13 | import uvloop 14 | from fastapi import FastAPI 15 | from fastapi.responses import StreamingResponse 16 | from sglang.srt.managers.detokenizer_manager import start_detokenizer_process 17 | from sglang.srt.managers.io_struct import GenerateReqInput 18 | from sglang.srt.managers.openai_protocol import CompletionRequest 19 | from sglang.srt.managers.router.manager import start_router_process 20 | from sglang.srt.managers.tokenizer_manager import TokenizerManager 21 | from sglang.srt.server_args import PortArgs, ServerArgs 22 | from sglang.srt.utils import alloc_usable_network_port 23 | 24 | asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) 25 | 26 | 27 | app = FastAPI() 28 | tokenizer_manager = None 29 | 30 | 31 | @app.get("/get_model_info") 32 | async def get_model_info(): 33 | result = { 34 | "model_path": tokenizer_manager.model_path, 35 | } 36 | return result 37 | 38 | 39 | @app.post("/generate") 40 | async def generate_request(obj: GenerateReqInput): 41 | obj.post_init() 42 | result_generator = tokenizer_manager.generate_request(obj) 43 | 44 | if obj.stream: 45 | 46 | async def stream_results(): 47 | async for out in result_generator: 48 | yield (json.dumps(out) + "\0").encode("utf-8") 49 | 50 | return StreamingResponse(stream_results(), media_type="text/event-stream") 51 | else: 52 | ret = await result_generator.__anext__() 53 | return ret 54 | 55 | 56 | @app.post("/v1/completions") 57 | async def v1_completions(obj: CompletionRequest): 58 | assert obj.n == 1 59 | obj = GenerateReqInput( 60 | text=obj.prompt, 61 | sampling_params={ 62 | "temperature": obj.temperature, 63 | "max_new_tokens": 
obj.max_tokens, 64 | "stop": obj.stop, 65 | }, 66 | ) 67 | ret = await generate_request(obj) 68 | return { 69 | "choices": [{"text": ret["text"]}], 70 | } 71 | 72 | 73 | def launch_server(server_args): 74 | global tokenizer_manager 75 | 76 | # Allocate ports 77 | can_use_ports = alloc_usable_network_port( 78 | num=4 + server_args.tp_size, used_list=(server_args.port,) 79 | ) 80 | port_args = PortArgs( 81 | tokenizer_port=can_use_ports[0], 82 | router_port=can_use_ports[1], 83 | detokenizer_port=can_use_ports[2], 84 | nccl_port=can_use_ports[3], 85 | model_rpc_ports=can_use_ports[4:], 86 | ) 87 | 88 | # Launch processes 89 | tokenizer_manager = TokenizerManager(server_args, port_args) 90 | pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False) 91 | pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False) 92 | 93 | proc_router = mp.Process( 94 | target=start_router_process, 95 | args=( 96 | server_args, 97 | port_args, 98 | pipe_router_writer, 99 | ), 100 | ) 101 | proc_router.start() 102 | proc_detoken = mp.Process( 103 | target=start_detokenizer_process, 104 | args=( 105 | server_args, 106 | port_args, 107 | pipe_detoken_writer, 108 | ), 109 | ) 110 | proc_detoken.start() 111 | 112 | # Wait for the model to finish loading 113 | router_init_state = pipe_router_reader.recv() 114 | detoken_init_state = pipe_detoken_reader.recv() 115 | 116 | if router_init_state != "init ok" or detoken_init_state != "init ok": 117 | proc_router.kill() 118 | proc_detoken.kill() 119 | print("router init state:", router_init_state) 120 | print("detoken init state:", detoken_init_state) 121 | sys.exit(1) 122 | 123 | assert proc_router.is_alive() and proc_detoken.is_alive() 124 | 125 | def launch_server(): 126 | # Launch api server 127 | uvicorn.run( 128 | app, 129 | host=server_args.host, 130 | port=server_args.port, 131 | log_level=server_args.log_level, 132 | timeout_keep_alive=5, 133 | loop="uvloop", 134 | ) 135 | 136 | t = threading.Thread(target=launch_server) 137 | t.start() 138 | -------------------------------------------------------------------------------- /python/sglang/srt/memory_pool.py: -------------------------------------------------------------------------------- 1 | """Memory pool.""" 2 | 3 | import logging 4 | 5 | import torch 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class ReqToTokenPool: 11 | def __init__(self, size, max_context_len): 12 | self.mem_state = torch.ones((size,), dtype=torch.bool, device="cuda") 13 | self.can_use_mem_size = size 14 | self.req_to_token = torch.empty( 15 | (size, max_context_len), dtype=torch.int32, device="cuda" 16 | ) 17 | 18 | def alloc(self, need_size): 19 | if need_size > self.can_use_mem_size: 20 | return None 21 | 22 | available_indices = torch.nonzero(self.mem_state) 23 | if available_indices.numel() == 0: 24 | return None 25 | select_index = available_indices.squeeze(1)[:need_size] 26 | if select_index.numel() == 0: 27 | return None 28 | self.mem_state[select_index] = 0 29 | self.can_use_mem_size -= need_size 30 | return select_index.to(torch.int32) 31 | 32 | def free(self, free_index): 33 | if isinstance(free_index, (int,)): 34 | # Clamp individual integer indices 35 | free_index = max(0, min(free_index, len(self.mem_state) - 1)) 36 | self.can_use_mem_size += 1 37 | else: 38 | if len(free_index) == 0: 39 | return 40 | # Clamp tensor indices to prevent out of bounds access 41 | free_index = torch.clamp(free_index, 0, len(self.mem_state) - 1) 42 | self.can_use_mem_size += free_index.shape[0] 43 | self.mem_state[free_index] = 1 
44 | 45 | # if self.can_use_mem_size == len(self.mem_state): 46 | # print(f"ReqToTokenPool: freed all. size = {self.can_use_mem_size}.") 47 | 48 | 49 | class TokenToKVPool: 50 | def __init__(self, size, dtype, head_num, head_dim, layer_num): 51 | self.mem_state = torch.zeros((size,), dtype=torch.int16, device="cuda") 52 | self.alloc_ct = 0 53 | 54 | # [size, key/value, head_num, head_dim] for each layer 55 | self.kv_data = [ 56 | torch.empty((size, 2, head_num, head_dim), dtype=dtype, device="cuda") 57 | for _ in range(layer_num) 58 | ] 59 | 60 | def get_key_buffer(self, layer_id): 61 | return self.kv_data[layer_id][:, 0] 62 | 63 | def get_value_buffer(self, layer_id): 64 | return self.kv_data[layer_id][:, 1] 65 | 66 | def alloc(self, need_size): 67 | available_indices = torch.nonzero(self.mem_state == 0) 68 | if available_indices.numel() == 0: 69 | return None 70 | select_index = available_indices.squeeze(1)[:need_size] 71 | if select_index.numel() < need_size: 72 | return None 73 | 74 | self.add_refs(select_index) 75 | return select_index.to(torch.int32) 76 | 77 | def alloc_contiguous(self, need_size): 78 | available_indices = torch.nonzero(self.mem_state == 0) 79 | if available_indices.numel() == 0: 80 | return None 81 | empty_index = available_indices.squeeze(1)[:need_size] 82 | if empty_index.numel() < need_size: 83 | return None 84 | empty_size = len(empty_index) 85 | loc_sum = ( 86 | empty_index[need_size - 1 :] - empty_index[: empty_size - (need_size - 1)] 87 | ) 88 | can_used_loc = empty_index[: empty_size - (need_size - 1)][ 89 | loc_sum == need_size - 1 90 | ] 91 | if can_used_loc.shape[0] == 0: 92 | return None 93 | 94 | start_loc = can_used_loc[0].item() 95 | select_index = torch.arange(start_loc, start_loc + need_size, device="cuda") 96 | self.add_refs(select_index) 97 | return select_index.to(torch.int32), start_loc, start_loc + need_size 98 | 99 | def free(self, free_index): 100 | return self.decrease_refs(free_index) 101 | 102 | def available_size(self): 103 | return torch.sum(self.mem_state == 0).item() 104 | 105 | def add_refs(self, token_index: torch.Tensor): 106 | if len(token_index) == 0: 107 | return 108 | # Clamp indices to prevent out of bounds access 109 | token_index = torch.clamp(token_index, 0, len(self.mem_state) - 1) 110 | self.alloc_ct += len(token_index) 111 | self.mem_state[token_index] += 1 112 | 113 | def decrease_refs(self, token_index: torch.Tensor): 114 | if len(token_index) == 0: 115 | return 0 116 | # Clamp indices to prevent out of bounds access 117 | token_index = torch.clamp(token_index, 0, len(self.mem_state) - 1) 118 | self.alloc_ct -= len(token_index) 119 | self.mem_state[token_index] -= 1 120 | 121 | num_freed = torch.sum(self.mem_state[token_index] == 0) 122 | 123 | # if self.alloc_ct == 0: 124 | # print(f"TokenToKVPool: freed all. 
size = {len(self.mem_state)}.") 125 | 126 | return num_freed 127 | -------------------------------------------------------------------------------- /python/sglang/srt/server_args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import dataclasses 3 | from typing import List, Optional 4 | 5 | 6 | @dataclasses.dataclass 7 | class ServerArgs: 8 | model_path: str 9 | tokenizer_path: Optional[str] = None 10 | host: str = "127.0.0.1" 11 | port: int = 30000 12 | load_format: str = "auto" 13 | tokenizer_mode: str = "auto" 14 | trust_remote_code: bool = True 15 | mem_fraction_static: Optional[float] = None 16 | tp_size: int = 1 17 | model_mode: List[str] = () 18 | schedule_heuristic: str = "lpm" 19 | random_seed: int = 42 20 | log_level: str = "info" 21 | 22 | def __post_init__(self): 23 | if self.tokenizer_path is None: 24 | self.tokenizer_path = self.model_path 25 | if self.mem_fraction_static is None: 26 | if self.tp_size > 1: 27 | self.mem_fraction_static = 0.8 28 | else: 29 | self.mem_fraction_static = 0.9 30 | 31 | @staticmethod 32 | def add_cli_args(parser: argparse.ArgumentParser): 33 | parser.add_argument( 34 | "--model-path", 35 | type=str, 36 | help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.", 37 | required=True, 38 | ) 39 | parser.add_argument( 40 | "--tokenizer-path", 41 | type=str, 42 | default=ServerArgs.tokenizer_path, 43 | help="The path of the tokenizer.", 44 | ) 45 | parser.add_argument("--host", type=str, default=ServerArgs.host) 46 | parser.add_argument("--port", type=int, default=ServerArgs.port) 47 | parser.add_argument( 48 | "--load-format", 49 | type=str, 50 | default=ServerArgs.load_format, 51 | choices=["auto", "pt", "safetensors", "npcache", "dummy"], 52 | help="The format of the model weights to load. " 53 | '"auto" will try to load the weights in the safetensors format ' 54 | "and fall back to the pytorch bin format if safetensors format " 55 | "is not available. " 56 | '"pt" will load the weights in the pytorch bin format. ' 57 | '"safetensors" will load the weights in the safetensors format. ' 58 | '"npcache" will load the weights in pytorch format and store ' 59 | "a numpy cache to speed up the loading. " 60 | '"dummy" will initialize the weights with random values, ' 61 | "which is mainly for profiling.", 62 | ) 63 | parser.add_argument( 64 | "--tokenizer-mode", 65 | type=str, 66 | default=ServerArgs.tokenizer_mode, 67 | choices=["auto", "slow"], 68 | help="Tokenizer mode. 
'auto' will use the fast " 69 | "tokenizer if available, and 'slow' will " 70 | "always use the slow tokenizer.", 71 | ) 72 | parser.add_argument( 73 | "--trust-remote-code", 74 | action="store_true", 75 | help="Whether or not to allow for custom models defined on the Hub in their own modeling files.", 76 | ) 77 | parser.add_argument( 78 | "--mem-fraction-static", 79 | type=float, 80 | default=ServerArgs.mem_fraction_static, 81 | help="The fraction of the memory used for static allocation (model weights and KV cache memory pool)", 82 | ) 83 | parser.add_argument( 84 | "--tp-size", 85 | type=int, 86 | default=ServerArgs.tp_size, 87 | help="Tensor parallelism degree.", 88 | ) 89 | parser.add_argument( 90 | "--model-mode", 91 | type=str, 92 | default=[], 93 | nargs="+", 94 | choices=["flashinfer", "no-cache"], 95 | help="Model mode: [flashinfer, no-cache]", 96 | ) 97 | parser.add_argument( 98 | "--schedule-heuristic", 99 | type=str, 100 | default=ServerArgs.schedule_heuristic, 101 | help="Schudule mode: [lpm, weight, random, fcfs]", 102 | ) 103 | parser.add_argument( 104 | "--random-seed", 105 | type=int, 106 | default=ServerArgs.random_seed, 107 | help="Random seed.", 108 | ) 109 | parser.add_argument( 110 | "--log-level", 111 | type=str, 112 | default=ServerArgs.log_level, 113 | help="Log level", 114 | ) 115 | 116 | @classmethod 117 | def from_cli_args(cls, args: argparse.Namespace): 118 | attrs = [attr.name for attr in dataclasses.fields(cls)] 119 | return cls(**{attr: getattr(args, attr) for attr in attrs}) 120 | 121 | 122 | @dataclasses.dataclass 123 | class PortArgs: 124 | tokenizer_port: int # Port for tokenizer manager communication 125 | router_port: int # Port for router process communication and load balancing 126 | detokenizer_port: int # Port for detokenizer process communication 127 | nccl_port: int # Port for NCCL multi-GPU communication 128 | model_rpc_ports: List[int] # Ports for model RPC calls (tensor parallelism) 129 | -------------------------------------------------------------------------------- /python/sglang/srt/hf_transformers_utils.py: -------------------------------------------------------------------------------- 1 | """Utilities for Huggingface Transformers.""" 2 | 3 | import warnings 4 | from typing import List, Optional, Tuple, Union 5 | 6 | from huggingface_hub import snapshot_download 7 | from sglang.srt.utils import is_multimodal_model 8 | from transformers import ( 9 | AutoConfig, 10 | AutoProcessor, 11 | AutoTokenizer, 12 | PreTrainedTokenizer, 13 | PreTrainedTokenizerFast, 14 | ) 15 | 16 | 17 | def get_config(model: str, trust_remote_code: bool, revision: Optional[str] = None): 18 | config = AutoConfig.from_pretrained( 19 | model, trust_remote_code=trust_remote_code, revision=revision 20 | ) 21 | return config 22 | 23 | 24 | # Models don't use the same configuration key for determining the maximum 25 | # context length. Store them here so we can sanely check them. 26 | # NOTE: The ordering here is important. Some models have two of these and we 27 | # have a preference for which value gets used. 
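# For example, a Llama-2 config exposes "max_position_embeddings" = 4096, which
# get_context_length() below multiplies by rope_scaling["factor"] when rope
# scaling is configured.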
28 | CONTEXT_LENGTH_KEYS = [ 29 | "max_sequence_length", 30 | "seq_length", 31 | "max_position_embeddings", 32 | "max_seq_len", 33 | "model_max_length", 34 | ] 35 | 36 | 37 | def get_context_length(config): 38 | """Get the context length of a model from a huggingface model config.""" 39 | rope_scaling = getattr(config, "rope_scaling", None) 40 | if rope_scaling: 41 | rope_scaling_factor = config.rope_scaling["factor"] 42 | else: 43 | rope_scaling_factor = 1 44 | 45 | for key in CONTEXT_LENGTH_KEYS: 46 | val = getattr(config, key, None) 47 | if val is not None: 48 | return int(rope_scaling_factor * val) 49 | return 2048 50 | 51 | 52 | # A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file. 53 | _FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" 54 | 55 | 56 | def get_tokenizer( 57 | tokenizer_name: str, 58 | *args, 59 | tokenizer_mode: str = "auto", 60 | trust_remote_code: bool = False, 61 | tokenizer_revision: Optional[str] = None, 62 | **kwargs, 63 | ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: 64 | """Gets a tokenizer for the given model name via Huggingface.""" 65 | if is_multimodal_model(tokenizer_name): 66 | processor = get_processor( 67 | tokenizer_name, 68 | *args, 69 | trust_remote_code=trust_remote_code, 70 | tokenizer_revision=tokenizer_revision, 71 | **kwargs, 72 | ) 73 | tokenizer = processor.tokenizer 74 | return tokenizer 75 | 76 | if tokenizer_mode == "slow": 77 | if kwargs.get("use_fast", False): 78 | raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") 79 | kwargs["use_fast"] = False 80 | 81 | if ( 82 | "llama" in tokenizer_name.lower() 83 | and kwargs.get("use_fast", True) 84 | and tokenizer_name != _FAST_LLAMA_TOKENIZER 85 | ): 86 | pass 87 | # warnings.warn( 88 | # "For some LLaMA V1 models, initializing the fast tokenizer may " 89 | # "take a long time. To reduce the initialization time, consider " 90 | # f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " 91 | # "tokenizer." 92 | # ) 93 | try: 94 | tokenizer = AutoTokenizer.from_pretrained( 95 | tokenizer_name, 96 | *args, 97 | trust_remote_code=trust_remote_code, 98 | tokenizer_revision=tokenizer_revision, 99 | **kwargs, 100 | ) 101 | except TypeError as e: 102 | # The LLaMA tokenizer causes a protobuf error in some environments. 103 | err_msg = ( 104 | "Failed to load the tokenizer. If you are using a LLaMA V1 model " 105 | f"consider using '{_FAST_LLAMA_TOKENIZER}' instead of the " 106 | "original tokenizer." 107 | ) 108 | raise RuntimeError(err_msg) from e 109 | except ValueError as e: 110 | # If the error pertains to the tokenizer class not existing or not 111 | # currently being imported, suggest using the --trust-remote-code flag. 112 | if not trust_remote_code and ( 113 | "does not exist or is not currently imported." in str(e) 114 | or "requires you to execute the tokenizer file" in str(e) 115 | ): 116 | err_msg = ( 117 | "Failed to load the tokenizer. If the tokenizer is a custom " 118 | "tokenizer not yet available in the HuggingFace transformers " 119 | "library, consider setting `trust_remote_code=True` in LLM " 120 | "or using the `--trust-remote-code` flag in the CLI." 121 | ) 122 | raise RuntimeError(err_msg) from e 123 | else: 124 | raise e 125 | 126 | if not isinstance(tokenizer, PreTrainedTokenizerFast): 127 | warnings.warn( 128 | "Using a slow tokenizer. This might cause a significant " 129 | "slowdown. Consider using a fast tokenizer instead." 
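# Usage sketch for get_tokenizer (the model id is illustrative, not shipped
# with this repo): get_tokenizer("meta-llama/Llama-2-7b-hf", tokenizer_mode="auto")
# returns a fast tokenizer when one is available; "slow" forces use_fast=False.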
130 | ) 131 | return tokenizer 132 | 133 | 134 | def get_processor( 135 | tokenizer_name: str, 136 | *args, 137 | tokenizer_mode: str = "auto", 138 | trust_remote_code: bool = False, 139 | tokenizer_revision: Optional[str] = None, 140 | **kwargs, 141 | ): 142 | processor = AutoProcessor.from_pretrained( 143 | tokenizer_name, 144 | *args, 145 | trust_remote_code=trust_remote_code, 146 | tokenizer_revision=tokenizer_revision, 147 | **kwargs, 148 | ) 149 | return processor 150 | -------------------------------------------------------------------------------- /python/sglang/srt/layers/vocab_parallel_embedding.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from sglang.srt.parallel_utils.parallel_state import ( 6 | get_tensor_model_parallel_rank, 7 | get_tensor_model_parallel_world_size, 8 | tensor_model_parallel_all_reduce, 9 | ) 10 | from sglang.srt.parallel_utils.utils import divide 11 | from sglang.srt.utils import set_weight_attrs 12 | from torch.nn.parameter import Parameter 13 | 14 | 15 | def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int: 16 | """Pad the vocab size up to the nearest multiple of pad_to.""" 17 | return ((vocab_size + pad_to - 1) // pad_to) * pad_to 18 | 19 | 20 | def vocab_range_from_per_partition_vocab_size( 21 | per_partition_vocab_size: int, rank: int 22 | ) -> Sequence[int]: 23 | index_f = rank * per_partition_vocab_size 24 | index_l = index_f + per_partition_vocab_size 25 | return index_f, index_l 26 | 27 | 28 | def vocab_range_from_global_vocab_size( 29 | global_vocab_size: int, rank: int, world_size: int 30 | ) -> Sequence[int]: 31 | per_partition_vocab_size = divide(global_vocab_size, world_size) 32 | return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank) 33 | 34 | 35 | class VocabParallelEmbedding(torch.nn.Module): 36 | """Embedding parallelized in the vocabulary dimension. 37 | 38 | Adapted from torch.nn.Embedding, note that we pad the vocabulary size to 39 | make sure it is divisible by the number of model parallel GPUs. 40 | 41 | Args: 42 | num_embeddings: vocabulary size. 43 | embedding_dim: size of hidden state. 44 | params_dtype: type of the parameters. 45 | """ 46 | 47 | def __init__( 48 | self, 49 | num_embeddings: int, 50 | embedding_dim: int, 51 | params_dtype: Optional[torch.dtype] = None, 52 | ): 53 | super().__init__() 54 | 55 | # Keep the input dimensions. 56 | self.num_embeddings = num_embeddings 57 | self.num_embeddings_padded = pad_vocab_size(num_embeddings) 58 | self.embedding_dim = embedding_dim 59 | if params_dtype is None: 60 | params_dtype = torch.get_default_dtype() 61 | self.tp_size = get_tensor_model_parallel_world_size() 62 | # Divide the weight matrix along the vocabulary dimension. 
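# Illustrative numbers (not tied to a specific model here): with
# num_embeddings=32000 (already a multiple of 64) and tp_size=2, rank 0 owns
# token ids [0, 16000) and rank 1 owns [16000, 32000), so each rank holds a
# (16000, embedding_dim) shard of the embedding weight.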
63 | self.vocab_start_index, self.vocab_end_index = ( 64 | vocab_range_from_global_vocab_size( 65 | self.num_embeddings_padded, 66 | get_tensor_model_parallel_rank(), 67 | self.tp_size, 68 | ) 69 | ) 70 | self.num_embeddings_per_partition = ( 71 | self.vocab_end_index - self.vocab_start_index 72 | ) 73 | self.weight = Parameter( 74 | torch.empty( 75 | self.num_embeddings_per_partition, 76 | self.embedding_dim, 77 | device=torch.cuda.current_device(), 78 | dtype=params_dtype, 79 | ) 80 | ) 81 | set_weight_attrs( 82 | self.weight, {"parallel_dim": 0, "weight_loader": self.weight_loader} 83 | ) 84 | 85 | def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): 86 | parallel_dim = param.parallel_dim 87 | assert loaded_weight.shape[parallel_dim] == self.num_embeddings 88 | loaded_weight = loaded_weight[self.vocab_start_index : self.vocab_end_index] 89 | param[: loaded_weight.shape[0]].data.copy_(loaded_weight) 90 | 91 | def forward(self, input_): 92 | if self.tp_size > 1: 93 | # Build the mask. 94 | input_mask = (input_ < self.vocab_start_index) | ( 95 | input_ >= self.vocab_end_index 96 | ) 97 | # Mask the input. 98 | masked_input = input_.clone() - self.vocab_start_index 99 | masked_input[input_mask] = 0 100 | else: 101 | masked_input = input_ 102 | # Get the embeddings. 103 | output_parallel = F.embedding(masked_input, self.weight) 104 | # Mask the output embedding. 105 | if self.tp_size > 1: 106 | output_parallel[input_mask, :] = 0.0 107 | # Reduce across all the model parallel GPUs. 108 | output = tensor_model_parallel_all_reduce(output_parallel) 109 | return output 110 | 111 | 112 | class ParallelLMHead(VocabParallelEmbedding): 113 | """Parallelized LM head. 114 | 115 | Output logits weight matrices used in the Sampler. The weight and bias 116 | tensors are padded to make sure they are divisible by the number of 117 | model parallel GPUs. 118 | 119 | Args: 120 | num_embeddings: vocabulary size. 121 | embedding_dim: size of hidden state. 122 | bias: whether to use bias. 123 | params_dtype: type of the parameters. 124 | """ 125 | 126 | def __init__( 127 | self, 128 | num_embeddings: int, 129 | embedding_dim: int, 130 | bias: bool = False, 131 | params_dtype: Optional[torch.dtype] = None, 132 | ): 133 | super().__init__(num_embeddings, embedding_dim, params_dtype) 134 | if bias: 135 | self.bias = Parameter( 136 | torch.empty( 137 | self.num_embeddings_per_partition, 138 | device=torch.cuda.current_device(), 139 | dtype=params_dtype, 140 | ) 141 | ) 142 | set_weight_attrs( 143 | self.bias, {"parallel_dim": 0, "weight_loader": self.weight_loader} 144 | ) 145 | else: 146 | self.register_parameter("bias", None) 147 | 148 | def forward(self, input_): 149 | del input_ 150 | raise RuntimeError("LMHead's weights should be used in the sampler.") 151 | -------------------------------------------------------------------------------- /python/sglang/srt/layers/quantization/awq.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional 2 | 3 | import torch 4 | from sglang.srt.layers.linear import LinearMethodBase, set_weight_attrs 5 | from sglang.srt.layers.quantization.awq_triton import awq_gemm_triton 6 | from sglang.srt.layers.quantization.base_config import QuantizationConfig 7 | from torch.nn.parameter import Parameter 8 | 9 | 10 | class AWQConfig(QuantizationConfig): 11 | """Config class for AWQ. 
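A sketch of what this config implies, assuming the typical 4-bit / 128-group
setting: every group of 128 input channels shares one fp16 scale and one
4-bit zero point per output channel, and pack_factor = 32 // 4 = 8 packed
int4 weights fit in each int32 word of `qweight`.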
12 | 13 | Reference: https://arxiv.org/abs/2306.00978 14 | """ 15 | 16 | def __init__( 17 | self, 18 | weight_bits: int, 19 | group_size: int, 20 | zero_point: bool, 21 | ) -> None: 22 | self.weight_bits = weight_bits 23 | self.group_size = group_size 24 | self.zero_point = zero_point 25 | 26 | if self.weight_bits != 4: 27 | raise ValueError( 28 | "Currently, only 4-bit weight quantization is supported for " 29 | f"AWQ, but got {self.weight_bits} bits." 30 | ) 31 | self.pack_factor = 32 // self.weight_bits 32 | 33 | def __repr__(self) -> str: 34 | return ( 35 | f"AWQConfig(weight_bits={self.weight_bits}, " 36 | f"group_size={self.group_size}, " 37 | f"zero_point={self.zero_point})" 38 | ) 39 | 40 | def get_name(self) -> str: 41 | return "awq" 42 | 43 | def get_supported_act_dtypes(self) -> List[torch.dtype]: 44 | return [torch.half] 45 | 46 | def get_min_capability(self) -> int: 47 | # The AWQ kernel only supports Turing or newer GPUs. 48 | return 75 49 | 50 | @staticmethod 51 | def get_config_filenames() -> List[str]: 52 | return [ 53 | "quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq 54 | "quantize_config.json", # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq 55 | ] 56 | 57 | @classmethod 58 | def from_config(cls, config: Dict[str, Any]) -> "AWQConfig": 59 | weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) 60 | group_size = cls.get_from_keys(config, ["q_group_size", "group_size"]) 61 | zero_point = cls.get_from_keys(config, ["zero_point"]) 62 | return cls(weight_bits, group_size, zero_point) 63 | 64 | def get_linear_method(self) -> "AWQLinearMethod": 65 | return AWQLinearMethod(self) 66 | 67 | def get_scaled_act_names(self) -> List[str]: 68 | return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"] 69 | 70 | 71 | class AWQLinearMethod(LinearMethodBase): 72 | """Linear method for AWQ. 73 | 74 | Args: 75 | quant_config: The AWQ quantization config. 76 | """ 77 | 78 | def __init__(self, quant_config: AWQConfig): 79 | self.quant_config = quant_config 80 | 81 | def create_weights( 82 | self, input_size: int, output_size: int, params_dtype: torch.dtype 83 | ) -> Dict[str, torch.Tensor]: 84 | if input_size % self.quant_config.group_size != 0: 85 | raise ValueError( 86 | "The input size is not aligned with the quantized " 87 | "weight shape. This can be caused by too large " 88 | "tensor parallel size." 89 | ) 90 | if output_size % self.quant_config.pack_factor != 0: 91 | raise ValueError( 92 | "The output size is not aligned with the quantized " 93 | "weight shape. This can be caused by too large " 94 | "tensor parallel size." 
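# Shape sketch with illustrative sizes: for input_size=4096, output_size=4096,
# group_size=128, and pack_factor=8, the tensors created below are qweight
# (4096, 512) int32, qzeros (32, 512) int32, and scales (32, 4096) in
# params_dtype, roughly 4x smaller than the unquantized fp16 weight.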
95 | ) 96 | 97 | qweight = Parameter( 98 | torch.empty( 99 | input_size, 100 | output_size // self.quant_config.pack_factor, 101 | device="cuda", 102 | dtype=torch.int32, 103 | ), 104 | requires_grad=False, 105 | ) 106 | set_weight_attrs( 107 | qweight, 108 | { 109 | "input_dim": 0, 110 | "output_dim": 1, 111 | "packed_dim": 1, 112 | "pack_factor": self.quant_config.pack_factor, 113 | }, 114 | ) 115 | qzeros = Parameter( 116 | torch.empty( 117 | input_size // self.quant_config.group_size, 118 | output_size // self.quant_config.pack_factor, 119 | device="cuda", 120 | dtype=torch.int32, 121 | ), 122 | requires_grad=False, 123 | ) 124 | set_weight_attrs( 125 | qzeros, 126 | { 127 | "input_dim": 0, 128 | "output_dim": 1, 129 | "packed_dim": 1, 130 | "pack_factor": self.quant_config.pack_factor, 131 | }, 132 | ) 133 | scales = Parameter( 134 | torch.empty( 135 | input_size // self.quant_config.group_size, 136 | output_size, 137 | device="cuda", 138 | dtype=params_dtype, 139 | ), 140 | requires_grad=False, 141 | ) 142 | set_weight_attrs( 143 | scales, 144 | { 145 | "input_dim": 0, 146 | "output_dim": 1, 147 | }, 148 | ) 149 | return { 150 | "qweight": qweight, 151 | "qzeros": qzeros, 152 | "scales": scales, 153 | } 154 | 155 | def apply_weights( 156 | self, 157 | weights: Dict[str, torch.Tensor], 158 | x: torch.Tensor, 159 | bias: Optional[torch.Tensor] = None, 160 | ) -> torch.Tensor: 161 | qweight = weights["qweight"] 162 | qzeros = weights["qzeros"] 163 | scales = weights["scales"] 164 | pack_factor = self.quant_config.pack_factor 165 | out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,) 166 | reshaped_x = x.reshape(-1, x.shape[-1]) 167 | out = awq_gemm_triton(reshaped_x, qweight, scales, qzeros, pack_factor) 168 | if bias is not None: 169 | out = out + bias 170 | return out.reshape(out_shape) 171 | -------------------------------------------------------------------------------- /python/sglang/srt/layers/context_flashattention_nopad.py: -------------------------------------------------------------------------------- 1 | # Adapted from 2 | # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1 3 | import torch 4 | import triton 5 | import triton.language as tl 6 | from sglang.srt.utils import wrap_kernel_launcher 7 | 8 | 9 | @triton.jit 10 | def _fwd_kernel( 11 | Q, 12 | K, 13 | V, 14 | sm_scale, 15 | B_Start_Loc, 16 | B_Seqlen, 17 | Out, 18 | stride_qbs, 19 | stride_qh, 20 | stride_kbs, 21 | stride_kh, 22 | stride_vbs, 23 | stride_vh, 24 | stride_obs, 25 | stride_oh, 26 | kv_group_num: tl.constexpr, 27 | BLOCK_M: tl.constexpr, 28 | BLOCK_DMODEL: tl.constexpr, 29 | BLOCK_N: tl.constexpr, 30 | ): 31 | cur_batch = tl.program_id(0) 32 | cur_head = tl.program_id(1) 33 | start_m = tl.program_id(2) 34 | 35 | cur_kv_head = cur_head // kv_group_num 36 | 37 | cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) 38 | cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) 39 | 40 | block_start_loc = BLOCK_M * start_m 41 | 42 | # initialize offsets 43 | offs_n = tl.arange(0, BLOCK_N) 44 | offs_d = tl.arange(0, BLOCK_DMODEL) 45 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) 46 | off_q = ( 47 | (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs 48 | + cur_head * stride_qh 49 | + offs_d[None, :] 50 | ) 51 | off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] 52 | off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] 53 | 
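# The block below loads one query tile and then streams K/V tiles: a running
# row max (m_i) and running normalizer (l_i) are updated online
# (FlashAttention-style) and the accumulator is rescaled at each step, so the
# full attention matrix is never materialized.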
54 | q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) 55 | 56 | k_ptrs = K + off_k 57 | v_ptrs = V + off_v 58 | 59 | # initialize pointer to m and l 60 | m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") 61 | l_i = tl.zeros([BLOCK_M], dtype=tl.float32) 62 | acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) 63 | 64 | block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) 65 | 66 | for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): 67 | start_n = tl.multiple_of(start_n, BLOCK_N) 68 | # -- compute qk ---- 69 | k = tl.load( 70 | k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, 71 | mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, 72 | other=0.0, 73 | ) 74 | # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) 75 | 76 | qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 77 | qk += tl.dot(q, k) 78 | qk *= sm_scale 79 | qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) 80 | 81 | # -- compute m_ij, p, l_ij 82 | m_ij = tl.max(qk, 1) 83 | p = tl.exp(qk - m_ij[:, None]) 84 | l_ij = tl.sum(p, 1) 85 | # -- update m_i and l_i 86 | m_i_new = tl.maximum(m_i, m_ij) 87 | alpha = tl.exp(m_i - m_i_new) 88 | beta = tl.exp(m_ij - m_i_new) 89 | l_i_new = alpha * l_i + beta * l_ij 90 | # -- update output accumulator -- 91 | # scale p 92 | p_scale = beta / l_i_new 93 | p = p * p_scale[:, None] 94 | # scale acc 95 | acc_scale = l_i / l_i_new * alpha 96 | acc = acc * acc_scale[:, None] 97 | # update acc 98 | v = tl.load( 99 | v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, 100 | mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, 101 | other=0.0, 102 | ) 103 | 104 | p = p.to(v.dtype) 105 | acc += tl.dot(p, v) 106 | # update m_i and l_i 107 | l_i = l_i_new 108 | m_i = m_i_new 109 | # initialize pointers to output 110 | off_o = ( 111 | (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs 112 | + cur_head * stride_oh 113 | + offs_d[None, :] 114 | ) 115 | out_ptrs = Out + off_o 116 | tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) 117 | 118 | 119 | cached_kernel = None 120 | 121 | 122 | def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): 123 | BLOCK = 128 124 | Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] 125 | assert Lq == Lk and Lk == Lv 126 | assert Lk in {16, 32, 64, 128} 127 | 128 | sm_scale = 1.0 / (Lq**0.5) 129 | batch, head = b_seq_len.shape[0], q.shape[1] 130 | kv_group_num = q.shape[1] // k.shape[1] 131 | 132 | grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) 133 | num_warps = 4 if Lk <= 64 else 8 134 | 135 | global cached_kernel 136 | if cached_kernel: 137 | cached_kernel( 138 | grid, 139 | num_warps, 140 | q, 141 | k, 142 | v, 143 | sm_scale, 144 | b_start_loc, 145 | b_seq_len, 146 | o, 147 | q.stride(0), 148 | q.stride(1), 149 | k.stride(0), 150 | k.stride(1), 151 | v.stride(0), 152 | v.stride(1), 153 | o.stride(0), 154 | o.stride(1), 155 | kv_group_num=kv_group_num, 156 | BLOCK_M=BLOCK, 157 | BLOCK_DMODEL=Lk, 158 | BLOCK_N=BLOCK, 159 | num_stages=1, 160 | ) 161 | return 162 | 163 | # Launch kernel using modern Triton API 164 | kernel_launcher = wrap_kernel_launcher(_fwd_kernel) 165 | kernel_launcher( 166 | grid, 167 | num_warps, 168 | q, 169 | k, 170 | v, 171 | sm_scale, 172 | b_start_loc, 173 | b_seq_len, 174 | o, 175 | q.stride(0), 176 | q.stride(1), 177 | k.stride(0), 178 | k.stride(1), 179 | v.stride(0), 180 | v.stride(1), 181 | o.stride(0), 182 | 
o.stride(1), 183 | kv_group_num=kv_group_num, 184 | BLOCK_M=BLOCK, 185 | BLOCK_DMODEL=Lk, 186 | BLOCK_N=BLOCK, 187 | num_stages=1, 188 | ) 189 | cached_kernel = kernel_launcher 190 | -------------------------------------------------------------------------------- /python/sglang/srt/layers/radix_attention.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd 5 | from sglang.srt.layers.extend_attention import extend_attention_fwd 6 | from sglang.srt.layers.token_attention import token_attention_fwd 7 | from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata 8 | 9 | # from sglang.srt.parallel_utils.parallel_state import ( 10 | # get_tensor_model_parallel_rank, 11 | # get_tensor_model_parallel_world_size, 12 | # ) 13 | from torch import nn 14 | 15 | 16 | class RadixAttention(nn.Module): 17 | def __init__( 18 | self, 19 | num_heads, 20 | head_dim, 21 | scaling, 22 | num_kv_heads, 23 | layer_id, 24 | ): 25 | super().__init__() 26 | 27 | self.tp_q_head_num = num_heads 28 | self.tp_k_head_num = num_kv_heads 29 | self.tp_v_head_num = num_kv_heads 30 | self.head_dim = head_dim 31 | self.layer_id = layer_id 32 | 33 | from sglang.srt.managers.router.model_runner import global_model_mode 34 | 35 | self.use_flashinfer = "flashinfer" in global_model_mode 36 | 37 | if self.use_flashinfer: 38 | self.prefill_forward = self.prefill_forward_flashinfer 39 | self.extend_forward = self.prefill_forward_flashinfer 40 | self.decode_forward = self.decode_forward_flashinfer 41 | else: 42 | self.prefill_forward = self.prefill_forward_triton 43 | self.extend_forward = self.extend_forward_triton 44 | self.decode_forward = self.decode_forward_triton 45 | 46 | def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata): 47 | o = torch.empty_like(q) 48 | 49 | context_attention_fwd( 50 | q.view(-1, self.tp_q_head_num, self.head_dim), 51 | k, 52 | v, 53 | o.view(-1, self.tp_q_head_num, self.head_dim), 54 | input_metadata.start_loc, 55 | input_metadata.seq_lens, 56 | input_metadata.max_seq_len, 57 | ) 58 | self.store_kv_cache(k, v, input_metadata) 59 | 60 | return o 61 | 62 | def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata): 63 | o = torch.empty_like(q) 64 | self.store_kv_cache(k, v, input_metadata) 65 | 66 | extend_attention_fwd( 67 | q.view(-1, self.tp_q_head_num, self.head_dim), 68 | k.contiguous(), 69 | v.contiguous(), 70 | o.view(-1, self.tp_q_head_num, self.head_dim), 71 | input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id), 72 | input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id), 73 | input_metadata.req_to_token_pool.req_to_token, 74 | input_metadata.req_pool_indices, 75 | input_metadata.start_loc, 76 | input_metadata.seq_lens, 77 | input_metadata.prefix_lens, 78 | input_metadata.extend_start_loc, 79 | input_metadata.extend_seq_lens, 80 | input_metadata.max_seq_len, 81 | input_metadata.max_extend_len, 82 | ) 83 | 84 | return o 85 | 86 | def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata): 87 | o = torch.empty_like(q) 88 | self.store_kv_cache(k, v, input_metadata) 89 | 90 | token_attention_fwd( 91 | q.view(-1, self.tp_q_head_num, self.head_dim), 92 | input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id), 93 | input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id), 94 | o.view(-1, self.tp_q_head_num, self.head_dim), 95 | 
input_metadata.req_to_token_pool.req_to_token, 96 | input_metadata.req_pool_indices, 97 | input_metadata.start_loc, 98 | input_metadata.seq_lens, 99 | input_metadata.max_seq_len, 100 | input_metadata.other_kv_index, 101 | input_metadata.total_num_tokens, 102 | ) 103 | 104 | return o 105 | 106 | def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata): 107 | self.store_kv_cache(k, v, input_metadata) 108 | 109 | o = input_metadata.prefill_wrapper.forward( 110 | q.contiguous().view(-1, self.tp_q_head_num, self.head_dim), 111 | input_metadata.qo_indptr, 112 | input_metadata.token_to_kv_pool.kv_data[self.layer_id], 113 | input_metadata.kv_indptr, 114 | input_metadata.kv_indices, 115 | input_metadata.kv_last_page_len, 116 | allow_fp16_qk_reduction=True, 117 | ) 118 | 119 | return o.view(-1, self.tp_q_head_num * self.head_dim) 120 | 121 | def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata): 122 | self.store_kv_cache(k, v, input_metadata) 123 | 124 | o = input_metadata.decode_wrapper.forward( 125 | q.contiguous().view(-1, self.tp_q_head_num, self.head_dim), 126 | input_metadata.token_to_kv_pool.kv_data[self.layer_id], 127 | input_metadata.kv_indptr, 128 | input_metadata.kv_indices, 129 | input_metadata.kv_last_page_len, 130 | ) 131 | 132 | return o.view(-1, self.tp_q_head_num * self.head_dim) 133 | 134 | def forward(self, q, k, v, input_metadata: InputMetadata): 135 | k = k.view(-1, self.tp_k_head_num, self.head_dim) 136 | v = v.view(-1, self.tp_v_head_num, self.head_dim) 137 | 138 | if input_metadata.forward_mode == ForwardMode.PREFILL: 139 | return self.prefill_forward(q, k, v, input_metadata) 140 | elif input_metadata.forward_mode == ForwardMode.EXTEND: 141 | return self.extend_forward(q, k, v, input_metadata) 142 | elif input_metadata.forward_mode == ForwardMode.DECODE: 143 | return self.decode_forward(q, k, v, input_metadata) 144 | 145 | def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata): 146 | key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id) 147 | value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id) 148 | if input_metadata.out_cache_loc is not None: 149 | key_buffer[input_metadata.out_cache_loc] = cache_k 150 | value_buffer[input_metadata.out_cache_loc] = cache_v 151 | elif input_metadata.out_cache_cont_start is not None: 152 | key_buffer[ 153 | input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end 154 | ] = cache_k 155 | value_buffer[ 156 | input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end 157 | ] = cache_v 158 | else: 159 | raise RuntimeError() 160 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # nano-sglang 2 | 3 | A lightweight LLM inference framework inspired by SGLang, designed for learning and educational purposes. This project focuses on simplicity and clarity over raw performance, implementing core inference concepts while maintaining a clean, minimal codebase. 4 | 5 | ## 🎯 Project Goals 6 | 7 | nano-sglang is created to help developers understand the internal workings of modern LLM inference frameworks. 
We've stripped away complexity while preserving the essential components that make SGLang powerful: 8 | 9 | - **Educational Focus**: Clean, readable code that demonstrates core inference concepts 10 | - **Core Implementation**: RadixTree, scheduling, and other fundamental mechanisms 11 | - **Minimal Dependencies**: Reduced complexity without sacrificing functionality 12 | - **Modern Stack**: Built with torch and triton operators, updated for latest libraries 13 | 14 | ## ✨ Features 15 | 16 | ### Core Capabilities 17 | 18 | - **RadixTree**: Efficient attention key-value caching and management 19 | - **Advanced Scheduling**: Multiple scheduling heuristics (LPM, weight, random, FCFS) 20 | - **Tensor Parallelism**: Multi-GPU inference support 21 | - **AWQ Quantization**: Memory-efficient model quantization 22 | 23 | ### Model Support 24 | 25 | - ✅ **Llama2** models (base and chat variants) 26 | - ✅ **Llama2 AWQ** quantized models 27 | - 🚧 **More models coming soon** (see roadmap below) 28 | 29 | ### Technical Implementation 30 | 31 | - **Pure Torch/Triton**: All operators implemented using PyTorch and Triton 32 | - **VLLM-Free**: Completely removed VLLM dependencies for cleaner codebase 33 | - **Modern Dependencies**: Updated to work with latest library versions 34 | - **Bug Fixes**: Resolved issues from early SGLang versions 35 | 36 | ## 🚀 Quick Start 37 | 38 | ### Installation 39 | 40 | ```bash 41 | # Clone the repository 42 | git clone https://github.com/your-username/nano-sglang.git 43 | cd nano-sglang/python 44 | 45 | # Install with all dependencies 46 | pip install -e ".[all]" 47 | 48 | # Optional: Install flashinfer for acceleration 49 | git submodule update --init --recursive 50 | pip install 3rdparty/flashinfer/python 51 | ``` 52 | 53 | ### Basic Usage 54 | 55 | ```bash 56 | # Basic server launch 57 | python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 58 | 59 | # With tensor parallelism 60 | python3 -m sglang.launch_server --model-path /path/to/llama2-model --port 30000 --tp 2 61 | 62 | # With AWQ quantization 63 | python3 -m sglang.launch_server --model-path /path/to/Llama-2-7B-AWQ --port 30000 --mem-fraction-static 0.8 64 | 65 | # With flashinfer acceleration 66 | python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --model-mode flashinfer 67 | ``` 68 | 69 | ### API Usage 70 | 71 | The server provides OpenAI-compatible endpoints: 72 | 73 | #### curl Examples 74 | 75 | ```bash 76 | # Completions 77 | curl -X POST "http://localhost:30000/v1/completions" \ 78 | -H "Content-Type: application/json" \ 79 | -d '{ 80 | "model": "meta-llama/Llama-2-7b-chat-hf", 81 | "prompt": "What is the capital of France?", 82 | "max_tokens": 40, 83 | "temperature": 0 84 | }' 85 | # {"choices":[{"text":"\nFrance is a country in Western Europe. It is the largest country in the European Union. The capital of France is Paris.\nWhat is the capital of France?\nWhat is the capital of"}]} 86 | ``` 87 | 88 | #### Python Examples 89 | 90 | ```python 91 | import requests 92 | 93 | # Completions 94 | response = requests.post("http://localhost:30000/v1/completions", json={ 95 | "model": "meta-llama/Llama-2-7b-chat-hf", 96 | "prompt": "What is the capital of France?", 97 | "max_tokens": 40, 98 | "temperature": 0 99 | }) 100 | 101 | print(response.json()) 102 | # {"choices":[{"text":"\nFrance is a country in Western Europe. It is the largest country in the European Union. 
The capital of France is Paris.\nWhat is the capital of France?\nWhat is the capital of"}]} 103 | ``` 104 | 105 | ## 🏗️ Architecture 106 | 107 | ### Core Components 108 | 109 | - **Multi-process Architecture**: Separate processes for tokenizer, router, detokenizer, and model workers 110 | - **Memory Management**: Efficient GPU memory pool management 111 | - **OpenAI-Compatible API**: FastAPI server with standard endpoints 112 | 113 | ### Process Structure 114 | 115 | ``` 116 | ┌──────────────────────────────────────────────────────────┐ 117 | │ │ 118 | ▼ │ 119 | ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ │ 120 | │ Tokenizer │───▶│ Router │───▶│ Detokenizer │───┘ 121 | │ Manager │ │ Process │ │ Process │ 122 | └─────────────────┘ └─────────────────┘ └─────────────────┘ 123 | │ ▲ 124 | │ │ 125 | ▼ │ 126 | ┌───────────────────────────┐ 127 | │ Model Workers │ 128 | │ (Tensor Parallel) │ 129 | └───────────────────────────┘ 130 | ``` 131 | 132 | ## 🗺️ Roadmap 133 | 134 | We welcome contributions to implement these features: 135 | 136 | ### Model Support 137 | 138 | - [ ] **Qwen3 / Qwen3-MoE**: models from Alibaba 139 | - [ ] **More Models**: Additional model architectures 140 | 141 | ### Advanced Features 142 | 143 | - [ ] **DP-Attention**: parallel dp attention mechanisms 144 | - [ ] **Chunked Prefill**: Efficient long context processing 145 | - [ ] **More quantization methods**: Additional quantization methods (GPTQ, SmoothQuant, etc.) 146 | - [ ] **Speculative Decoding**: Faster inference techniques 147 | 148 | ## 🙏 Acknowledgments 149 | 150 | - **[SGLang](https://github.com/sgl-project/sglang)**: For the original inspiration and architectural design 151 | - **[VLLM](https://github.com/vllm-project/vllm)**: For pioneering many optimization techniques in LLM inference 152 | - **[nano-vllm](https://github.com/GeeeekExplorer/nano-vllm)**: For the lightweight VLLM implementation reference 153 | 154 | ## 🤝 Contributing 155 | 156 | We strongly encourage contributions! This project is designed to be a collaborative learning resource. 157 | 158 | Feel free to open issues, submit pull requests, or start discussions. We're here to learn together! 159 | 160 | Before submitting code, please set up pre-commit hooks to ensure code quality: 161 | 162 | - Follow the existing code style and structure 163 | - Ensure all pre-commit checks pass before submitting PRs 164 | 165 | ```bash 166 | # Install pre-commit hooks 167 | pre-commit install 168 | 169 | # Run pre-commit on all files 170 | pre-commit run --all-files 171 | ``` 172 | 173 | ## 🌟Star History 174 | 175 | [![Star History Chart](https://api.star-history.com/svg?repos=gogongxt/nano-sglang&type=date&legend=top-left)](https://www.star-history.com/#gogongxt/nano-sglang&type=date&legend=top-left) 176 | 177 | --- 178 | 179 | **Note**: This project prioritizes educational clarity over raw performance. For production workloads, consider using the original [SGLang](https://github.com/sgl-project/sglang) framework. 
180 | -------------------------------------------------------------------------------- /python/sglang/srt/managers/router/radix_cache.py: -------------------------------------------------------------------------------- 1 | import heapq 2 | import time 3 | from collections import defaultdict 4 | from dataclasses import dataclass 5 | from typing import Tuple 6 | 7 | import torch 8 | 9 | 10 | class TreeNode: 11 | def __init__(self): 12 | self.children = defaultdict(TreeNode) 13 | self.parent = None 14 | self.value = None 15 | self.ref_counter = 0 16 | self.last_access_time = time.time() 17 | 18 | def __lt__(self, other): 19 | return self.last_access_time < other.last_access_time 20 | 21 | 22 | def match(key, seq): 23 | i = 0 24 | for k, w in zip(key, seq): 25 | if k != w: 26 | break 27 | i += 1 28 | return i 29 | 30 | 31 | class RadixCache: 32 | def __init__(self, disable=False): 33 | self.root_node = TreeNode() 34 | self.root_node.value = [] 35 | self.root_node.ref_counter = 1 36 | self.evictable_size_ = 0 37 | 38 | self.disable = disable 39 | 40 | ##### Public API ##### 41 | def match_prefix(self, key): 42 | if self.disable: 43 | return [], self.root_node 44 | 45 | value = [] 46 | last_node = [self.root_node] 47 | self._match_prefix_helper(self.root_node, key, value, last_node) 48 | if value: 49 | value = torch.concat(value) 50 | return value, last_node[0] 51 | 52 | def insert(self, key, value=None): 53 | if self.disable: 54 | return len(key) 55 | 56 | if value is None: 57 | value = [x for x in key] 58 | return self._insert_helper(self.root_node, key, value) 59 | 60 | def pretty_print(self): 61 | self._print_helper(self.root_node, 0) 62 | print(f"#tokens: {self.total_size()}") 63 | 64 | def total_size(self): 65 | return self._total_size_helper(self.root_node) 66 | 67 | def evict(self, num_tokens, evict_callback): 68 | if self.disable: 69 | raise RuntimeError() 70 | 71 | leaves = self._collect_leaves() 72 | heapq.heapify(leaves) 73 | 74 | num_evicted = 0 75 | while num_evicted < num_tokens and len(leaves): 76 | x = heapq.heappop(leaves) 77 | 78 | if x == self.root_node: 79 | break 80 | if x.ref_counter > 0: 81 | continue 82 | 83 | num_evicted += evict_callback(x.value) 84 | self._delete_leaf(x) 85 | 86 | if len(x.parent.children) == 0: 87 | heapq.heappush(leaves, x.parent) 88 | 89 | def inc_ref_counter(self, node): 90 | delta = 0 91 | while node != self.root_node: 92 | if node.ref_counter == 0: 93 | self.evictable_size_ -= len(node.value) 94 | delta -= len(node.value) 95 | node.ref_counter += 1 96 | node = node.parent 97 | return delta 98 | 99 | def dec_ref_counter(self, node): 100 | delta = 0 101 | while node != self.root_node: 102 | if node.ref_counter == 1: 103 | self.evictable_size_ += len(node.value) 104 | delta += len(node.value) 105 | node.ref_counter -= 1 106 | node = node.parent 107 | return delta 108 | 109 | def evictable_size(self): 110 | return self.evictable_size_ 111 | 112 | ##### Internal Helper Functions ##### 113 | def _match_prefix_helper(self, node, key, value, last_node): 114 | node.last_access_time = time.time() 115 | 116 | for c_key, child in node.children.items(): 117 | prefix_len = match(c_key, key) 118 | if prefix_len != 0: 119 | if prefix_len < len(c_key): 120 | new_node = self._split_node(c_key, child, prefix_len) 121 | value.append(new_node.value) 122 | last_node[0] = new_node 123 | else: 124 | value.append(child.value) 125 | last_node[0] = child 126 | self._match_prefix_helper(child, key[prefix_len:], value, last_node) 127 | break 128 | 129 | def 
_split_node(self, key, child, split_len): 130 | # new_node -> child 131 | new_node = TreeNode() 132 | new_node.children = {key[split_len:]: child} 133 | new_node.parent = child.parent 134 | new_node.ref_counter = child.ref_counter 135 | new_node.value = child.value[:split_len] 136 | child.parent = new_node 137 | child.value = child.value[split_len:] 138 | new_node.parent.children[key[:split_len]] = new_node 139 | del new_node.parent.children[key] 140 | return new_node 141 | 142 | def _insert_helper(self, node, key, value): 143 | node.last_access_time = time.time() 144 | 145 | for c_key, child in node.children.items(): 146 | prefix_len = match(c_key, key) 147 | 148 | if prefix_len == len(c_key): 149 | if prefix_len == len(key): 150 | return prefix_len 151 | else: 152 | key = key[prefix_len:] 153 | value = value[prefix_len:] 154 | return prefix_len + self._insert_helper(child, key, value) 155 | 156 | if prefix_len: 157 | new_node = self._split_node(c_key, child, prefix_len) 158 | return prefix_len + self._insert_helper( 159 | new_node, key[prefix_len:], value[prefix_len:] 160 | ) 161 | 162 | if len(key): 163 | new_node = TreeNode() 164 | new_node.parent = node 165 | new_node.value = value 166 | node.children[key] = new_node 167 | self.evictable_size_ += len(value) 168 | return 0 169 | 170 | def _print_helper(self, node, indent): 171 | for key, child in node.children.items(): 172 | print(" " * indent, len(key), key[:10], f"r={child.ref_counter}") 173 | self._print_helper(child, indent=indent + 2) 174 | 175 | def _delete_leaf(self, node): 176 | for k, v in node.parent.children.items(): 177 | if v == node: 178 | break 179 | del node.parent.children[k] 180 | self.evictable_size_ -= len(k) 181 | 182 | def _total_size_helper(self, node): 183 | x = len(node.value) 184 | for child in node.children.values(): 185 | x += self._total_size_helper(child) 186 | return x 187 | 188 | def _collect_leaves(self): 189 | ret_list = [] 190 | 191 | def dfs_(cur_node): 192 | if len(cur_node.children) == 0: 193 | ret_list.append(cur_node) 194 | 195 | for x in cur_node.children.values(): 196 | dfs_(x) 197 | 198 | dfs_(self.root_node) 199 | return ret_list 200 | 201 | 202 | if __name__ == "__main__": 203 | tree = RadixCache(disable=False) 204 | 205 | tree.insert("Hello") 206 | tree.insert("Hello") 207 | tree.insert("Hello_L.A.!") 208 | # tree.insert("Hello_world! Happy") 209 | # tree.insert("I love you!") 210 | tree.pretty_print() 211 | 212 | # print(tree.match_prefix("I love you! 
aha")) 213 | 214 | # def evict_callback(x): 215 | # print("evict", x) 216 | # return len(x) 217 | 218 | # tree.evict(5, evict_callback) 219 | # tree.evict(10, evict_callback) 220 | # tree.pretty_print() 221 | -------------------------------------------------------------------------------- /python/sglang/srt/managers/tokenizer_manager.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import concurrent.futures 3 | import dataclasses 4 | import os 5 | from typing import List 6 | 7 | import numpy as np 8 | import transformers 9 | import uvloop 10 | import zmq 11 | import zmq.asyncio 12 | from sglang.srt.hf_transformers_utils import ( 13 | get_config, 14 | get_context_length, 15 | get_processor, 16 | get_tokenizer, 17 | ) 18 | from sglang.srt.managers.io_struct import ( 19 | BatchStrOut, 20 | GenerateReqInput, 21 | TokenizedGenerateReqInput, 22 | ) 23 | from sglang.srt.sampling_params import SamplingParams 24 | from sglang.srt.server_args import PortArgs, ServerArgs 25 | from sglang.srt.utils import get_exception_traceback, is_multimodal_model, load_image 26 | 27 | asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) 28 | 29 | 30 | @dataclasses.dataclass 31 | class ReqState: 32 | out_list: List 33 | finished: bool 34 | event: asyncio.Event 35 | lock: asyncio.Lock 36 | 37 | 38 | global global_processor 39 | 40 | 41 | def init_global_processor(server_args: ServerArgs): 42 | global global_processor 43 | transformers.logging.set_verbosity_error() 44 | global_processor = get_processor( 45 | server_args.tokenizer_path, 46 | tokenizer_mode=server_args.tokenizer_mode, 47 | trust_remote_code=server_args.trust_remote_code, 48 | ) 49 | 50 | 51 | def get_pixel_values(image_data, processor=None): 52 | try: 53 | processor = processor or global_processor 54 | image = load_image(image_data) 55 | image_hash = hash(image_data) 56 | pixel_values = processor.image_processor(image)["pixel_values"][0] 57 | pixel_values = pixel_values.astype(np.float16) 58 | return pixel_values, image_hash 59 | except Exception: 60 | print("Exception in TokenizerManager:\n" + get_exception_traceback()) 61 | 62 | 63 | class TokenizerManager: 64 | def __init__( 65 | self, 66 | server_args: ServerArgs, 67 | port_args: PortArgs, 68 | ): 69 | context = zmq.asyncio.Context(2) 70 | self.recv_from_detokenizer = context.socket(zmq.PULL) 71 | self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}") 72 | 73 | self.send_to_router = context.socket(zmq.PUSH) 74 | self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.router_port}") 75 | 76 | self.model_path = server_args.model_path 77 | self.hf_config = get_config( 78 | self.model_path, trust_remote_code=server_args.trust_remote_code 79 | ) 80 | self.context_len = get_context_length(self.hf_config) # 4096 llama2-7b 81 | 82 | if is_multimodal_model(self.model_path): 83 | self.processor = get_processor( 84 | server_args.tokenizer_path, 85 | tokenizer_mode=server_args.tokenizer_mode, 86 | trust_remote_code=server_args.trust_remote_code, 87 | ) 88 | self.tokenizer = self.processor.tokenizer 89 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 90 | self.executor = concurrent.futures.ProcessPoolExecutor( 91 | initializer=init_global_processor, initargs=(server_args,) 92 | ) 93 | else: 94 | self.tokenizer = get_tokenizer( 95 | server_args.tokenizer_path, 96 | tokenizer_mode=server_args.tokenizer_mode, 97 | trust_remote_code=server_args.trust_remote_code, 98 | ) 99 | 100 | self.to_create_loop = True 101 | self.rid_to_state 
= {} # Dict[str -> ReqState] 102 | 103 | async def get_pixel_values(self, image_data): 104 | if self.executor is not None: 105 | loop = asyncio.get_event_loop() 106 | return await loop.run_in_executor( 107 | self.executor, get_pixel_values, image_data 108 | ) 109 | else: 110 | return get_pixel_values(image_data, self.processor) 111 | 112 | async def generate_request(self, obj: GenerateReqInput): 113 | if self.to_create_loop: 114 | await self.create_handle_loop() 115 | 116 | is_single = isinstance(obj.text, str) 117 | 118 | if is_single: 119 | rid = obj.rid 120 | input_ids = self.tokenizer.encode(obj.text) 121 | sampling_params = SamplingParams(**obj.sampling_params) 122 | if sampling_params.max_new_tokens != 0: 123 | sampling_params.normalize(self.tokenizer) 124 | sampling_params.verify() 125 | if obj.image_data is None: 126 | pixel_values, image_hash = None, None 127 | else: 128 | pixel_values, image_hash = await self.get_pixel_values(obj.image_data) 129 | tokenized_obj = TokenizedGenerateReqInput( 130 | rid=rid, 131 | input_ids=input_ids, 132 | pixel_values=pixel_values, 133 | image_hash=image_hash, 134 | sampling_params=sampling_params, 135 | return_normalized_logprob=obj.return_normalized_logprob, 136 | normalized_logprob_start_len=obj.normalized_logprob_start_len, 137 | stream=obj.stream, 138 | ) 139 | self.send_to_router.send_pyobj(tokenized_obj) 140 | 141 | lock = asyncio.Lock() 142 | event = asyncio.Event() 143 | state = ReqState([], False, event, lock) 144 | self.rid_to_state[rid] = state 145 | 146 | while True: 147 | await event.wait() 148 | yield state.out_list[-1] 149 | state.out_list = [] 150 | if state.finished: 151 | del self.rid_to_state[rid] 152 | break 153 | event.clear() 154 | else: 155 | assert obj.stream is False 156 | bs = len(obj.text) 157 | for i in range(bs): 158 | rid = obj.rid[i] 159 | input_ids = self.tokenizer.encode(obj.text[i]) 160 | sampling_params = SamplingParams(**obj.sampling_params[i]) 161 | if sampling_params.max_new_tokens != 0: 162 | sampling_params.normalize(self.tokenizer) 163 | sampling_params.verify() 164 | if obj.image_data[i] is None: 165 | pixel_values, image_hash = None, None 166 | else: 167 | pixel_values, image_hash = await self.get_pixel_values( 168 | obj.image_data[i] 169 | ) 170 | tokenized_obj = TokenizedGenerateReqInput( 171 | rid=rid, 172 | input_ids=input_ids, 173 | pixel_values=pixel_values, 174 | image_hash=image_hash, 175 | sampling_params=sampling_params, 176 | return_normalized_logprob=obj.return_normalized_logprob[i], 177 | normalized_logprob_start_len=obj.normalized_logprob_start_len[i], 178 | stream=obj.stream, 179 | ) 180 | self.send_to_router.send_pyobj(tokenized_obj) 181 | 182 | lock = asyncio.Lock() 183 | event = asyncio.Event() 184 | state = ReqState([], False, event, lock) 185 | self.rid_to_state[rid] = state 186 | 187 | output_list = [] 188 | for i in range(bs): 189 | rid = obj.rid[i] 190 | state = self.rid_to_state[rid] 191 | await state.event.wait() 192 | output_list.append(state.out_list[-1]) 193 | assert state.finished 194 | del self.rid_to_state[rid] 195 | 196 | yield output_list 197 | 198 | async def create_handle_loop(self): 199 | self.to_create_loop = False 200 | loop = asyncio.get_event_loop() 201 | loop.create_task(self.handle_loop()) 202 | 203 | async def handle_loop(self): 204 | while True: 205 | recv_obj = await self.recv_from_detokenizer.recv_pyobj() 206 | 207 | if isinstance(recv_obj, BatchStrOut): 208 | for i, rid in enumerate(recv_obj.rids): 209 | recv_obj.meta_info[i]["id"] = rid 210 | out_dict = { 
211 | "text": recv_obj.output_str[i], 212 | "meta_info": recv_obj.meta_info[i], 213 | } 214 | state = self.rid_to_state[rid] 215 | state.out_list.append(out_dict) 216 | state.finished = recv_obj.finished[i] 217 | state.event.set() 218 | else: 219 | raise ValueError(f"Invalid object: {recv_obj}") 220 | -------------------------------------------------------------------------------- /python/sglang/srt/constrained/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Adapted from: 2 | # https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/tokenizer.py 3 | # https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/transformers.py 4 | from abc import abstractmethod 5 | from typing import ( 6 | TYPE_CHECKING, 7 | Dict, 8 | Hashable, 9 | List, 10 | Optional, 11 | Protocol, 12 | Set, 13 | Tuple, 14 | Union, 15 | ) 16 | 17 | import numpy as np 18 | import torch 19 | from numpy.typing import NDArray 20 | 21 | 22 | class Tokenizer(Protocol, Hashable): 23 | eos_token: str 24 | eos_token_id: int 25 | pad_token_id: int 26 | vocabulary: Dict[str, int] 27 | special_tokens: Set[int] 28 | 29 | @abstractmethod 30 | def encode( 31 | self, prompt: Union[str, List[str]] 32 | ) -> Tuple[NDArray[np.int64], NDArray[np.int64]]: 33 | """Translate the input prompts into NumPy arrays of token ids and attention mask.""" 34 | ... 35 | 36 | @abstractmethod 37 | def decode(self, token_ids: NDArray[np.int64]) -> List[str]: 38 | """Translate an array of token ids to a string or list of strings.""" 39 | ... 40 | 41 | @abstractmethod 42 | def convert_token_to_string(self, token: str) -> str: 43 | """Convert a token to its equivalent string. 44 | 45 | This is for instance useful for BPE tokenizers where whitespaces are 46 | represented by the special characted `Ġ`. This prevents matching a raw 47 | token that includes `Ġ` with a string. 48 | 49 | """ 50 | ... 51 | 52 | 53 | if TYPE_CHECKING: 54 | from transformers import PreTrainedModel, PreTrainedTokenizer 55 | 56 | __all__ = ["transformers"] 57 | 58 | 59 | KVCacheType = Tuple[Tuple[torch.DoubleTensor, torch.DoubleTensor], ...] 60 | 61 | 62 | def get_llama_tokenizer_types(): 63 | """Get all the Llama tokenizer types/classes that need work-arounds. 64 | 65 | When they can't be imported, a dummy class is created. 
66 | 67 | """ 68 | try: 69 | from transformers.models.llama import LlamaTokenizer 70 | except ImportError: 71 | 72 | class LlamaTokenizer: # type: ignore 73 | pass 74 | 75 | try: 76 | from transformers.models.llama import LlamaTokenizerFast 77 | except ImportError: 78 | 79 | class LlamaTokenizerFast: # type: ignore 80 | pass 81 | 82 | try: 83 | from transformers.models.code_llama import CodeLlamaTokenizer 84 | except ImportError: 85 | 86 | class CodeLlamaTokenizer: # type: ignore 87 | pass 88 | 89 | try: 90 | from transformers.models.code_llama import CodeLlamaTokenizerFast 91 | except ImportError: 92 | 93 | class CodeLlamaTokenizerFast: # type: ignore 94 | pass 95 | 96 | return ( 97 | LlamaTokenizer, 98 | LlamaTokenizerFast, 99 | CodeLlamaTokenizer, 100 | CodeLlamaTokenizerFast, 101 | ) 102 | 103 | 104 | class Transformer: 105 | """Represents a `transformers` model.""" 106 | 107 | def __init__( 108 | self, 109 | model: "PreTrainedModel", 110 | tokenizer: "PreTrainedTokenizer", 111 | ): 112 | self.device = model.device 113 | self.model = model 114 | self.tokenizer = tokenizer 115 | 116 | @torch.inference_mode 117 | def forward( 118 | self, 119 | input_ids: torch.LongTensor, 120 | attention_mask: torch.LongTensor, 121 | past_key_values: Optional[Tuple] = None, 122 | ) -> Tuple[torch.FloatTensor, Optional[KVCacheType]]: 123 | """Compute a forward pass through the transformer model. 124 | 125 | Parameters 126 | ---------- 127 | input_ids 128 | The input token ids. Must be one or two dimensional. 129 | attention_mask 130 | The attention mask. Must be one or two dimensional. 131 | past_key_values 132 | A tuple of tuples containing the cached key and value tensors for each 133 | attention head. 134 | 135 | Returns 136 | ------- 137 | The computed logits and the new cached key and value tensors. 138 | 139 | """ 140 | assert 0 < input_ids.ndim < 3 141 | 142 | if past_key_values: 143 | input_ids = input_ids[..., -1].unsqueeze(-1) 144 | 145 | output = self.model( 146 | input_ids, 147 | attention_mask=attention_mask, 148 | return_dict=True, 149 | output_attentions=False, 150 | output_hidden_states=False, 151 | past_key_values=past_key_values, 152 | ) 153 | 154 | return output.logits, output.past_key_values 155 | 156 | def __call__( 157 | self, 158 | input_ids: torch.LongTensor, 159 | attention_mask: torch.LongTensor, 160 | past_key_values: Optional[Tuple] = None, 161 | ) -> torch.FloatTensor: 162 | logits, kv_cache = self.forward(input_ids, attention_mask, past_key_values) 163 | next_token_logits = logits[..., -1, :] 164 | 165 | return next_token_logits, kv_cache 166 | 167 | 168 | class TransformerTokenizer(Tokenizer): 169 | """Represents a tokenizer for models in the `transformers` library.""" 170 | 171 | def __init__(self, tokenizer): 172 | # TODO: Do something to make this hashable? 
173 | self.tokenizer = tokenizer 174 | self.eos_token_id = self.tokenizer.eos_token_id 175 | self.eos_token = self.tokenizer.eos_token 176 | 177 | if not self.tokenizer.pad_token_id: 178 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id 179 | self.pad_token_id = self.eos_token_id 180 | else: 181 | self.pad_token_id = self.tokenizer.pad_token_id 182 | self.pad_token = self.tokenizer.pad_token 183 | 184 | self.special_tokens = set(self.tokenizer.all_special_tokens) 185 | 186 | self.vocabulary = self.tokenizer.get_vocab() 187 | self.is_llama = isinstance(self.tokenizer, get_llama_tokenizer_types()) 188 | 189 | def encode( 190 | self, prompt: Union[str, List[str]], **kwargs 191 | ) -> Tuple[torch.LongTensor, torch.LongTensor]: 192 | kwargs["padding"] = True 193 | kwargs["return_tensors"] = "pt" 194 | output = self.tokenizer(prompt, **kwargs) 195 | return output["input_ids"], output["attention_mask"] 196 | 197 | def decode(self, token_ids: torch.LongTensor) -> List[str]: 198 | text = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True) 199 | return text 200 | 201 | def convert_token_to_string(self, token: str) -> str: 202 | from transformers.file_utils import SPIECE_UNDERLINE 203 | 204 | string = self.tokenizer.convert_tokens_to_string([token]) 205 | 206 | if self.is_llama: 207 | # A hack to handle missing spaces to HF's Llama tokenizers 208 | if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": 209 | return " " + string 210 | 211 | return string 212 | 213 | def __eq__(self, other): 214 | if isinstance(other, type(self)): 215 | return False 216 | # TODO(lsyin): the lru_cache for the TransoformerTokenizer is useless ? 217 | # return other.model_name == self.model_name and other.kwargs == self.kwargs 218 | return NotImplemented 219 | 220 | def __hash__(self): 221 | from datasets.fingerprint import Hasher 222 | 223 | return hash(Hasher.hash(self.tokenizer)) 224 | 225 | 226 | def transformers( 227 | model_name: str, 228 | device: Optional[str] = None, 229 | model_kwargs: dict = {}, 230 | tokenizer_kwargs: dict = {}, 231 | ): 232 | """Instantiate a model from the `transformers` library and its tokenizer. 233 | 234 | Parameters 235 | ---------- 236 | model_name 237 | The name of the model as listed on Hugging Face's model page. 238 | device 239 | The device(s) on which the model should be loaded. This overrides 240 | the `device_map` entry in `model_kwargs` when provided. 241 | model_kwargs 242 | A dictionary that contains the keyword arguments to pass to the 243 | `from_pretrained` method when loading the model. 244 | tokenizer_kwargs 245 | A dictionary that contains the keyword arguments to pass to the 246 | `from_pretrained` method when loading the tokenizer. 247 | 248 | Returns 249 | ------- 250 | A `TransformersModel` model instance. 251 | 252 | """ 253 | try: 254 | from transformers import AutoModelForCausalLM 255 | except ImportError: 256 | raise ImportError( 257 | "The `transformers` library needs to be installed in order to use `transformers` models." 
258 | ) 259 | 260 | if device is not None: 261 | model_kwargs["device_map"] = device 262 | 263 | model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs) 264 | tokenizer = TransformerTokenizer(model_name, **tokenizer_kwargs) 265 | 266 | return Transformer(model, tokenizer) 267 | -------------------------------------------------------------------------------- /python/sglang/srt/models/llava.py: -------------------------------------------------------------------------------- 1 | """Inference-only LLaVa model compatible with HuggingFace weights.""" 2 | 3 | import json 4 | import os 5 | from typing import Any, Dict, List, Optional, Tuple 6 | 7 | import numpy as np 8 | import torch 9 | from sglang.srt.managers.router.infer_batch import ForwardMode 10 | from sglang.srt.managers.router.model_runner import InputMetadata 11 | from sglang.srt.models.llama2 import LlamaForCausalLM 12 | from torch import nn 13 | from transformers import CLIPImageProcessor, CLIPVisionModel, LlavaConfig 14 | from transformers.models.llava.modeling_llava import LlavaMultiModalProjector 15 | from vllm.model_executor.layers.linear import LinearMethodBase 16 | from vllm.model_executor.weight_utils import ( 17 | default_weight_loader, 18 | hf_model_weights_iterator, 19 | ) 20 | 21 | 22 | class LlavaLlamaForCausalLM(nn.Module): 23 | def __init__( 24 | self, 25 | config: LlavaConfig, 26 | linear_method: Optional[LinearMethodBase] = None, 27 | ) -> None: 28 | super().__init__() 29 | self.config = config 30 | self.vision_tower = None 31 | self.config.vision_config.hidden_size = config.mm_hidden_size 32 | self.config.text_config.hidden_size = config.hidden_size 33 | self.multi_modal_projector = LlavaMultiModalProjector(config) 34 | self.language_model = LlamaForCausalLM(config, linear_method) 35 | 36 | def pad_input_ids(self, input_ids, pad_value): 37 | pad_ids = pad_value * ( 38 | (self.image_feature_len + len(pad_value)) // len(pad_value) 39 | ) 40 | offset = input_ids.index(self.config.image_token_index) 41 | # old_len + pad_len - 1, because we need to remove image_token_id 42 | new_input_ids = ( 43 | input_ids[:offset] 44 | + pad_ids[: self.image_feature_len] 45 | + input_ids[offset + 1 :] 46 | ) 47 | return new_input_ids, offset 48 | 49 | def forward( 50 | self, 51 | input_ids: torch.LongTensor, 52 | positions: torch.Tensor, 53 | input_metadata: InputMetadata, 54 | pixel_values: Optional[List[Optional[np.array]]] = None, 55 | image_offsets: Optional[List[int]] = None, 56 | ) -> torch.Tensor: 57 | if input_metadata.forward_mode == ForwardMode.EXTEND: 58 | bs = input_metadata.batch_size 59 | 60 | # Embed text input 61 | input_embeds = self.language_model.model.embed_tokens(input_ids) 62 | 63 | # Embed vision input 64 | need_vision = ( 65 | (positions[input_metadata.extend_start_loc] < self.image_feature_len) 66 | .cpu() 67 | .numpy() 68 | ) 69 | # FIXME: We need to substract the length of the system prompt 70 | has_pixel = np.array([pixel_values[i] is not None for i in range(bs)]) 71 | need_vision = need_vision & has_pixel 72 | 73 | if need_vision.any(): 74 | pixel_values = torch.tensor( 75 | np.array([pixel_values[i] for i in range(bs) if need_vision[i]]), 76 | device=self.vision_tower.device, 77 | ) 78 | 79 | image_outputs = self.vision_tower( 80 | pixel_values, output_hidden_states=True 81 | ) 82 | # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated. 
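# Shape sketch, assuming the usual CLIP ViT-L/14 @ 336px vision tower:
# 336 / 14 = 24 patches per side, so image_feature_len = 24 * 24 = 576 image
# tokens; the "default"/"patch" strategy below drops the CLS token from the
# selected hidden state before projection.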
83 | 84 | selected_image_feature = image_outputs.hidden_states[ 85 | self.vision_feature_layer 86 | ] 87 | if self.vision_feature_select_strategy in ["default", "patch"]: 88 | selected_image_feature = selected_image_feature[:, 1:] 89 | elif self.vision_feature_select_strategy == "full": 90 | selected_image_feature = selected_image_feature 91 | else: 92 | raise ValueError( 93 | f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}" 94 | ) 95 | image_features = self.multi_modal_projector(selected_image_feature) 96 | 97 | extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy() 98 | pt = 0 99 | for i in range(bs): 100 | if not need_vision[i]: 101 | continue 102 | 103 | start_idx = extend_start_loc_cpu[i] 104 | pad_len, pad_dim = image_features[pt].shape 105 | dim = input_embeds.shape[1] 106 | assert ( 107 | pad_dim == dim 108 | ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim) 109 | # Fill in the placeholder for the image 110 | try: 111 | input_embeds[ 112 | start_idx 113 | + image_offsets[i] : start_idx 114 | + image_offsets[i] 115 | + pad_len 116 | ] = image_features[pt] 117 | except RuntimeError as e: 118 | print(f"RuntimeError in llava image encoding: {e}") 119 | print(input_embeds.shape) 120 | print(start_idx, image_offsets[i]) 121 | pt += 1 122 | 123 | return self.language_model( 124 | input_embeds, positions, input_metadata, skip_embed=True 125 | ) 126 | elif input_metadata.forward_mode == ForwardMode.DECODE: 127 | return self.language_model( 128 | input_ids, positions, input_metadata, skip_embed=False 129 | ) 130 | 131 | def load_weights( 132 | self, 133 | model_name_or_path: str, 134 | cache_dir: Optional[str] = None, 135 | load_format: str = "auto", 136 | revision: Optional[str] = None, 137 | ): 138 | # load clip vision model by cfg['mm_vision_tower']: 139 | # huggingface_name or path_of_clip_relative_to_llava_model_dir 140 | vision_path = self.config.mm_vision_tower 141 | self.vision_tower = CLIPVisionModel.from_pretrained( 142 | vision_path, torch_dtype=torch.float16 143 | ).cuda() 144 | self.vision_tower.eval() 145 | 146 | self.vision_feature_layer = self.config.mm_vision_select_layer 147 | self.vision_feature_select_strategy = self.config.mm_vision_select_feature 148 | self.image_size = self.vision_tower.config.image_size 149 | self.patch_size = self.vision_tower.config.patch_size 150 | self.image_feature_len = int((self.image_size / self.patch_size) ** 2) 151 | if self.vision_feature_select_strategy == "patch": 152 | pass 153 | elif self.vision_feature_select_strategy == "cls_patch": 154 | self.image_feature_len += 1 155 | else: 156 | raise ValueError(f"Unexpected select feature: {self.select_feature}") 157 | 158 | # load mm_projector 159 | # TODO: support TP? 160 | projector_weights = { 161 | "model.mm_projector.0": "multi_modal_projector.linear_1", 162 | "model.mm_projector.2": "multi_modal_projector.linear_2", 163 | } 164 | params_dict = dict(self.named_parameters()) 165 | for name, loaded_weight in hf_model_weights_iterator( 166 | model_name_or_path, cache_dir, load_format, revision 167 | ): 168 | # FIXME: why projector weights read two times? 
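            # Remap original LLaVA checkpoint keys (model.mm_projector.0 / .2)
            # to the HF-style multi_modal_projector.linear_1 / linear_2 names
            # used by this module's parameters before loading each weight.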
169 | if "projector" in name: 170 | for weight_name, param_name in projector_weights.items(): 171 | if weight_name in name: 172 | name = name.replace(weight_name, param_name) 173 | param = params_dict[name] 174 | weight_loader = getattr(param, "weight_loader", default_weight_loader) 175 | weight_loader(param, loaded_weight) 176 | 177 | # load language model 178 | self.language_model.load_weights( 179 | model_name_or_path, cache_dir, load_format, revision 180 | ) 181 | 182 | monkey_path_clip_vision_embed_forward() 183 | 184 | 185 | first_call = True 186 | 187 | 188 | def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: 189 | batch_size = pixel_values.shape[0] 190 | 191 | # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G. 192 | global first_call 193 | if first_call: 194 | self.patch_embedding.cpu().float() 195 | first_call = False 196 | pixel_values = pixel_values.to(dtype=torch.float32, device="cpu") 197 | patch_embeds = self.patch_embedding(pixel_values).cuda().half() 198 | 199 | patch_embeds = patch_embeds.flatten(2).transpose(1, 2) 200 | 201 | class_embeds = self.class_embedding.expand(batch_size, 1, -1) 202 | embeddings = torch.cat([class_embeds, patch_embeds], dim=1) 203 | embeddings = embeddings + self.position_embedding(self.position_ids) 204 | return embeddings 205 | 206 | 207 | def monkey_path_clip_vision_embed_forward(): 208 | import transformers 209 | 210 | setattr( 211 | transformers.models.clip.modeling_clip.CLIPVisionEmbeddings, 212 | "forward", 213 | clip_vision_embed_forward, 214 | ) 215 | -------------------------------------------------------------------------------- /python/sglang/srt/parallel_utils/parallel_state.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The vLLM team. 2 | # Adapted from 3 | # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py 4 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 5 | """Tensor and pipeline parallel groups.""" 6 | 7 | import torch 8 | 9 | # Tensor model parallel group that the current rank belongs to. 10 | _TENSOR_MODEL_PARALLEL_GROUP = None 11 | # Pipeline model parallel group that the current rank belongs to. 12 | _PIPELINE_MODEL_PARALLEL_GROUP = None 13 | 14 | # A list of global ranks for each pipeline group to ease calculation of the 15 | # source rank when broadcasting from the first or last pipeline stage. 16 | _PIPELINE_GLOBAL_RANKS = None 17 | 18 | 19 | def initialize_model_parallel( 20 | tensor_model_parallel_size: int = 1, 21 | pipeline_model_parallel_size: int = 1, 22 | ) -> None: 23 | """ 24 | Initialize model parallel groups. 25 | 26 | Arguments: 27 | tensor_model_parallel_size: number of GPUs used for tensor model 28 | parallelism. 29 | pipeline_model_parallel_size: number of GPUs used for pipeline model 30 | parallelism. 31 | 32 | Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we 33 | use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize 34 | the model pipeline. The present function will 35 | create 4 tensor model-parallel groups and 2 pipeline model-parallel groups: 36 | 4 tensor model-parallel groups: 37 | [g0, g1], [g2, g3], [g4, g5], [g6, g7] 38 | 2 pipeline model-parallel groups: 39 | [g0, g2, g4, g6], [g1, g3, g5, g7] 40 | Note that for efficiency, the caller should make sure adjacent ranks 41 | are on the same DGX box. 
For example if we are using 2 DGX-1 boxes 42 | with a total of 16 GPUs, rank 0 to 7 belong to the first box and 43 | ranks 8 to 15 belong to the second box. 44 | """ 45 | # Get world size and rank. Ensure some consistencies. 46 | assert torch.distributed.is_initialized() 47 | world_size: int = torch.distributed.get_world_size() 48 | 49 | if world_size != tensor_model_parallel_size * pipeline_model_parallel_size: 50 | raise RuntimeError( 51 | f"world_size ({world_size}) is not equal to " 52 | f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " 53 | f"pipeline_model_parallel_size ({pipeline_model_parallel_size})" 54 | ) 55 | 56 | num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size 57 | num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size 58 | rank = torch.distributed.get_rank() 59 | 60 | # Build the tensor model-parallel groups. 61 | global _TENSOR_MODEL_PARALLEL_GROUP 62 | assert ( 63 | _TENSOR_MODEL_PARALLEL_GROUP is None 64 | ), "tensor model parallel group is already initialized" 65 | for i in range(num_tensor_model_parallel_groups): 66 | ranks = range( 67 | i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size 68 | ) 69 | group = torch.distributed.new_group(ranks) 70 | if rank in ranks: 71 | _TENSOR_MODEL_PARALLEL_GROUP = group 72 | 73 | # Build the pipeline model-parallel groups. 74 | global _PIPELINE_MODEL_PARALLEL_GROUP 75 | global _PIPELINE_GLOBAL_RANKS 76 | assert ( 77 | _PIPELINE_MODEL_PARALLEL_GROUP is None 78 | ), "pipeline model parallel group is already initialized" 79 | for i in range(num_pipeline_model_parallel_groups): 80 | ranks = range(i, world_size, num_pipeline_model_parallel_groups) 81 | group = torch.distributed.new_group(ranks) 82 | if rank in ranks: 83 | _PIPELINE_MODEL_PARALLEL_GROUP = group 84 | _PIPELINE_GLOBAL_RANKS = ranks 85 | 86 | 87 | def model_parallel_is_initialized(): 88 | """Check if tensor and pipeline parallel groups are initialized.""" 89 | return ( 90 | _TENSOR_MODEL_PARALLEL_GROUP is not None 91 | and _PIPELINE_MODEL_PARALLEL_GROUP is not None 92 | ) 93 | 94 | 95 | def get_tensor_model_parallel_group(): 96 | """Get the tensor model parallel group the caller rank belongs to.""" 97 | assert ( 98 | _TENSOR_MODEL_PARALLEL_GROUP is not None 99 | ), "tenosr model parallel group is not initialized" 100 | return _TENSOR_MODEL_PARALLEL_GROUP 101 | 102 | 103 | def get_pipeline_model_parallel_group(): 104 | """Get the pipeline model parallel group the caller rank belongs to.""" 105 | assert ( 106 | _PIPELINE_MODEL_PARALLEL_GROUP is not None 107 | ), "pipeline model parallel group is not initialized" 108 | return _PIPELINE_MODEL_PARALLEL_GROUP 109 | 110 | 111 | def get_tensor_model_parallel_world_size(): 112 | """Return world size for the tensor model parallel group.""" 113 | return torch.distributed.get_world_size(group=get_tensor_model_parallel_group()) 114 | 115 | 116 | def get_pipeline_model_parallel_world_size(): 117 | """Return world size for the pipeline model parallel group.""" 118 | return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group()) 119 | 120 | 121 | def get_tensor_model_parallel_rank(): 122 | """Return my rank for the tensor model parallel group.""" 123 | return torch.distributed.get_rank(group=get_tensor_model_parallel_group()) 124 | 125 | 126 | def get_pipeline_model_parallel_rank(): 127 | """Return my rank for the pipeline model parallel group.""" 128 | return torch.distributed.get_rank(group=get_pipeline_model_parallel_group()) 
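# Worked example for the docstring above: with world_size = 8, TP = 2, PP = 4,
# num_tensor_model_parallel_groups = 8 // 2 = 4, giving TP groups
# [0, 1], [2, 3], [4, 5], [6, 7]; num_pipeline_model_parallel_groups = 8 // 4 = 2,
# and range(i, world_size, 2) gives PP groups [0, 2, 4, 6] and [1, 3, 5, 7].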
129 | 130 | 131 | def get_tensor_model_parallel_src_rank(): 132 | """Calculate the global rank corresponding to the first local rank 133 | in the tensor model parallel group.""" 134 | global_rank = torch.distributed.get_rank() 135 | local_world_size = get_tensor_model_parallel_world_size() 136 | return (global_rank // local_world_size) * local_world_size 137 | 138 | 139 | def get_pipeline_model_parallel_first_rank(): 140 | """Return the global rank of the first process in the pipeline for the 141 | current tensor parallel group""" 142 | assert ( 143 | _PIPELINE_GLOBAL_RANKS is not None 144 | ), "Pipeline parallel group is not initialized" 145 | return _PIPELINE_GLOBAL_RANKS[0] 146 | 147 | 148 | def get_pipeline_model_parallel_last_rank(): 149 | """Return the global rank of the last process in the pipeline for the 150 | current tensor parallel group""" 151 | assert ( 152 | _PIPELINE_GLOBAL_RANKS is not None 153 | ), "Pipeline parallel group is not initialized" 154 | last_rank_local = get_pipeline_model_parallel_world_size() - 1 155 | return _PIPELINE_GLOBAL_RANKS[last_rank_local] 156 | 157 | 158 | def get_pipeline_model_parallel_next_rank(): 159 | """Return the global rank that follows the caller in the pipeline""" 160 | assert ( 161 | _PIPELINE_GLOBAL_RANKS is not None 162 | ), "Pipeline parallel group is not initialized" 163 | rank_in_pipeline = get_pipeline_model_parallel_rank() 164 | world_size = get_pipeline_model_parallel_world_size() 165 | return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size] 166 | 167 | 168 | def get_pipeline_model_parallel_prev_rank(): 169 | """Return the global rank that preceeds the caller in the pipeline""" 170 | assert ( 171 | _PIPELINE_GLOBAL_RANKS is not None 172 | ), "Pipeline parallel group is not initialized" 173 | rank_in_pipeline = get_pipeline_model_parallel_rank() 174 | world_size = get_pipeline_model_parallel_world_size() 175 | return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] 176 | 177 | 178 | def destroy_model_parallel(): 179 | """Set the groups to none.""" 180 | global _TENSOR_MODEL_PARALLEL_GROUP 181 | _TENSOR_MODEL_PARALLEL_GROUP = None 182 | global _PIPELINE_MODEL_PARALLEL_GROUP 183 | _PIPELINE_MODEL_PARALLEL_GROUP = None 184 | global _PIPELINE_GLOBAL_RANKS 185 | _PIPELINE_GLOBAL_RANKS = None 186 | 187 | 188 | def tensor_model_parallel_all_reduce(input_): 189 | """All-reduce the input tensor across model parallel group. 190 | 191 | NOTE: This operation is applied in-place on the input tensor. 192 | """ 193 | # Bypass the function if we are using only 1 GPU. 194 | if get_tensor_model_parallel_world_size() == 1: 195 | return input_ 196 | # All-reduce. 197 | torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group()) 198 | return input_ 199 | 200 | 201 | def tensor_model_parallel_all_gather(input_, dim=-1): 202 | """All-gather the input tensor across model parallel group.""" 203 | world_size = get_tensor_model_parallel_world_size() 204 | # Bypass the function if we are using only 1 GPU. 205 | if world_size == 1: 206 | return input_ 207 | assert ( 208 | -input_.dim() <= dim < input_.dim() 209 | ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" 210 | if dim < 0: 211 | # Convert negative dim to positive. 212 | dim += input_.dim() 213 | input_size = input_.size() 214 | # Allocate output tensor. 215 | output_tensor = torch.empty( 216 | (world_size,) + input_size, dtype=input_.dtype, device=input_.device 217 | ) 218 | # All-gather. 
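    # Gather into a new leading dimension of size world_size, then move that
    # dimension to `dim` and fold it in, so the result is every rank's tensor
    # concatenated along `dim`.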
219 | torch.distributed.all_gather_into_tensor( 220 | output_tensor, input_, group=get_tensor_model_parallel_group() 221 | ) 222 | # Reshape 223 | output_tensor = output_tensor.movedim(0, dim) 224 | output_tensor = output_tensor.reshape( 225 | input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :] 226 | ) 227 | return output_tensor 228 | -------------------------------------------------------------------------------- /python/sglang/srt/layers/token_attention.py: -------------------------------------------------------------------------------- 1 | # Adapted from 2 | # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py 3 | # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py 4 | import torch 5 | import triton 6 | import triton.language as tl 7 | from sglang.srt.utils import wrap_kernel_launcher 8 | 9 | 10 | @triton.jit 11 | def _fwd_kernel_stage1( 12 | Q, 13 | K_Buffer, 14 | sm_scale, 15 | Req_to_tokens, 16 | B_req_idx, 17 | B_Start_Loc, 18 | B_Seqlen, 19 | Att_Out, 20 | stride_req_to_tokens_b, 21 | stride_qbs, 22 | stride_qh, 23 | stride_buf_kbs, 24 | stride_buf_kh, 25 | att_stride_h, 26 | kv_group_num: tl.constexpr, 27 | BLOCK_DMODEL: tl.constexpr, 28 | BLOCK_N: tl.constexpr, 29 | ): 30 | cur_batch = tl.program_id(0) 31 | cur_head = tl.program_id(1) 32 | start_n = tl.program_id(2) 33 | 34 | cur_kv_head = cur_head // kv_group_num 35 | 36 | offs_d = tl.arange(0, BLOCK_DMODEL) 37 | cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) 38 | cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) 39 | cur_batch_req_idx = tl.load(B_req_idx + cur_batch) 40 | 41 | cur_batch_start_index = 0 42 | cur_batch_end_index = cur_batch_seq_len 43 | 44 | off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d 45 | 46 | offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) 47 | 48 | block_stard_index = start_n * BLOCK_N 49 | block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0) 50 | 51 | for start_mark in range(0, block_mask, 1): 52 | q = tl.load(Q + off_q + start_mark) 53 | offs_n_new = cur_batch_start_index + offs_n 54 | k_loc = tl.load( 55 | Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, 56 | mask=offs_n_new < cur_batch_end_index, 57 | other=0, 58 | ) 59 | offs_buf_k = ( 60 | k_loc[:, None] * stride_buf_kbs 61 | + cur_kv_head * stride_buf_kh 62 | + offs_d[None, :] 63 | ) 64 | k = tl.load( 65 | K_Buffer + offs_buf_k, 66 | mask=offs_n_new[:, None] < cur_batch_end_index, 67 | other=0.0, 68 | ) 69 | att_value = tl.sum(q[None, :] * k, 1) 70 | att_value *= sm_scale 71 | off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) 72 | tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index) 73 | 74 | 75 | @triton.jit 76 | def _fwd_kernel_stage2( 77 | Logics, 78 | V_Buffer, 79 | Out, 80 | Req_to_tokens, 81 | B_req_idx, 82 | B_Start_Loc, 83 | B_Seqlen, 84 | stride_logic_h, 85 | stride_buf_vbs, 86 | stride_buf_vh, 87 | stride_obs, 88 | stride_oh, 89 | stride_req_to_token_b, 90 | other_kv_index, # To fix a NAN issue 91 | kv_group_num: tl.constexpr, 92 | BLOCK_DMODEL: tl.constexpr, 93 | BLOCK_N: tl.constexpr, 94 | ): 95 | cur_batch = tl.program_id(0) 96 | cur_head = tl.program_id(1) 97 | 98 | cur_kv_head = cur_head // kv_group_num 99 | 100 | cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) 101 | cur_batch_start_loc = 
tl.load(B_Start_Loc + cur_batch) 102 | cur_batch_req_idx = tl.load(B_req_idx + cur_batch) 103 | 104 | offs_n = tl.arange(0, BLOCK_N) 105 | offs_d = tl.arange(0, BLOCK_DMODEL) 106 | 107 | offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :] 108 | v_ptrs = V_Buffer + offs_buf_v 109 | 110 | e_max = float("-inf") 111 | e_sum = 0.0 112 | acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) 113 | 114 | for start_n in range(0, cur_batch_seq_len, BLOCK_N): 115 | start_n = tl.multiple_of(start_n, BLOCK_N) 116 | v_index = tl.load( 117 | Req_to_tokens 118 | + cur_batch_req_idx * stride_req_to_token_b 119 | + (start_n + offs_n), 120 | mask=(start_n + offs_n) < cur_batch_seq_len, 121 | other=other_kv_index, 122 | ) 123 | 124 | qk = tl.load( 125 | Logics 126 | + cur_head * stride_logic_h 127 | + (cur_batch_start_loc + start_n + offs_n), 128 | mask=start_n + offs_n < cur_batch_seq_len, 129 | other=float("-inf"), 130 | ) 131 | 132 | n_e_max = tl.maximum(tl.max(qk, 0), e_max) 133 | old_scale = tl.exp(e_max - n_e_max) 134 | p = tl.exp(qk - n_e_max) 135 | e_sum = e_sum * old_scale + tl.sum(p, 0) 136 | v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs) 137 | acc = acc * old_scale + tl.sum(p[:, None] * v, 0) 138 | e_max = n_e_max 139 | 140 | acc = acc / e_sum 141 | off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d 142 | out_ptrs = Out + off_o 143 | tl.store(out_ptrs, acc) 144 | 145 | 146 | cached_kernel_stage1 = None 147 | cached_kernel_stage2 = None 148 | 149 | 150 | def _token_att_m_fwd( 151 | q, 152 | k_buffer, 153 | att_out, 154 | Req_to_tokens, 155 | B_req_idx, 156 | B_Start_Loc, 157 | B_Seqlen, 158 | max_len_in_batch, 159 | ): 160 | BLOCK = 32 161 | # shape constraints 162 | Lq, Lk = q.shape[-1], k_buffer.shape[-1] 163 | assert Lq == Lk 164 | assert Lk in {16, 32, 64, 128} 165 | sm_scale = 1.0 / (Lk**0.5) 166 | 167 | batch, head_num = B_req_idx.shape[0], q.shape[1] 168 | 169 | grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK)) 170 | kv_group_num = q.shape[1] // k_buffer.shape[1] 171 | 172 | if kv_group_num == 1: 173 | num_warps = 4 174 | else: 175 | num_warps = 2 176 | 177 | global cached_kernel_stage1 178 | if cached_kernel_stage1: 179 | cached_kernel_stage1( 180 | grid, 181 | num_warps, 182 | q, 183 | k_buffer, 184 | sm_scale, 185 | Req_to_tokens, 186 | B_req_idx, 187 | B_Start_Loc, 188 | B_Seqlen, 189 | att_out, 190 | Req_to_tokens.stride(0), 191 | q.stride(0), 192 | q.stride(1), 193 | k_buffer.stride(0), 194 | k_buffer.stride(1), 195 | att_out.stride(0), 196 | kv_group_num=kv_group_num, 197 | BLOCK_DMODEL=Lk, 198 | BLOCK_N=BLOCK, 199 | num_stages=1, 200 | ) 201 | return 202 | 203 | # Launch kernel using modern Triton API 204 | kernel_launcher = wrap_kernel_launcher(_fwd_kernel_stage1) 205 | kernel_launcher( 206 | grid, 207 | num_warps, 208 | q, 209 | k_buffer, 210 | sm_scale, 211 | Req_to_tokens, 212 | B_req_idx, 213 | B_Start_Loc, 214 | B_Seqlen, 215 | att_out, 216 | Req_to_tokens.stride(0), 217 | q.stride(0), 218 | q.stride(1), 219 | k_buffer.stride(0), 220 | k_buffer.stride(1), 221 | att_out.stride(0), 222 | kv_group_num=kv_group_num, 223 | BLOCK_DMODEL=Lk, 224 | BLOCK_N=BLOCK, 225 | num_stages=1, 226 | ) 227 | cached_kernel_stage1 = kernel_launcher 228 | 229 | 230 | def _token_softmax_reducev_fwd( 231 | logics, 232 | v_buffer, 233 | o, 234 | req_to_tokens, 235 | b_req_idx, 236 | b_start_loc, 237 | b_seq_len, 238 | other_kv_index, 239 | ): 240 | BLOCK = 64 241 | batch, head = b_seq_len.shape[0], logics.shape[0] 242 | grid = (batch, head, 1) 243 | kv_group_num = 
logics.shape[0] // v_buffer.shape[1] 244 | 245 | num_warps = 1 246 | 247 | global cached_kernel_stage2 248 | if cached_kernel_stage2: 249 | cached_kernel_stage2( 250 | grid, 251 | num_warps, 252 | logics, 253 | v_buffer, 254 | o, 255 | req_to_tokens, 256 | b_req_idx, 257 | b_start_loc, 258 | b_seq_len, 259 | logics.stride(0), 260 | v_buffer.stride(0), 261 | v_buffer.stride(1), 262 | o.stride(0), 263 | o.stride(1), 264 | req_to_tokens.stride(0), 265 | other_kv_index, 266 | kv_group_num=kv_group_num, 267 | BLOCK_DMODEL=v_buffer.shape[-1], 268 | BLOCK_N=BLOCK, 269 | num_stages=3, 270 | ) 271 | return 272 | 273 | # Launch kernel using modern Triton API 274 | kernel_launcher = wrap_kernel_launcher(_fwd_kernel_stage2) 275 | kernel_launcher( 276 | grid, 277 | num_warps, 278 | logics, 279 | v_buffer, 280 | o, 281 | req_to_tokens, 282 | b_req_idx, 283 | b_start_loc, 284 | b_seq_len, 285 | logics.stride(0), 286 | v_buffer.stride(0), 287 | v_buffer.stride(1), 288 | o.stride(0), 289 | o.stride(1), 290 | req_to_tokens.stride(0), 291 | other_kv_index, 292 | kv_group_num=kv_group_num, 293 | BLOCK_DMODEL=v_buffer.shape[-1], 294 | BLOCK_N=BLOCK, 295 | num_stages=3, 296 | ) 297 | cached_kernel_stage2 = kernel_launcher 298 | 299 | 300 | def token_attention_fwd( 301 | q, 302 | k_buffer, 303 | v_buffer, 304 | o, 305 | req_to_token, 306 | b_req_idx, 307 | b_start_loc, 308 | b_seq_len, 309 | max_len_in_batch, 310 | other_kv_index, 311 | total_num_tokens, 312 | att_m=None, 313 | ): 314 | if att_m is None: 315 | att_m = torch.empty( 316 | (q.shape[-2], total_num_tokens), dtype=q.dtype, device="cuda" 317 | ) 318 | 319 | _token_att_m_fwd( 320 | q, 321 | k_buffer, 322 | att_m, 323 | req_to_token, 324 | b_req_idx, 325 | b_start_loc, 326 | b_seq_len, 327 | max_len_in_batch, 328 | ) 329 | _token_softmax_reducev_fwd( 330 | att_m, 331 | v_buffer, 332 | o, 333 | req_to_token, 334 | b_req_idx, 335 | b_start_loc, 336 | b_seq_len, 337 | other_kv_index, 338 | ) 339 | -------------------------------------------------------------------------------- /python/sglang/srt/models/llama2.py: -------------------------------------------------------------------------------- 1 | # Adapted from 2 | # https://github.com/vllm-project/vllm/blob/671af2b1c0b3ed6d856d37c21a561cc429a10701/vllm/model_executor/models/llama.py#L1 3 | """Inference-only LLaMA model compatible with HuggingFace weights.""" 4 | from typing import Any, Dict, List, Optional, Tuple 5 | 6 | import torch 7 | from sglang.srt.layers.activation import SiluAndMul 8 | from sglang.srt.layers.layernorm import RMSNorm 9 | from sglang.srt.layers.linear import ( 10 | LinearMethodBase, 11 | MergedColumnParallelLinear, 12 | QKVParallelLinear, 13 | RowParallelLinear, 14 | ) 15 | from sglang.srt.layers.logits_processor import LogitsProcessor 16 | from sglang.srt.layers.radix_attention import RadixAttention 17 | from sglang.srt.layers.rotary_embedding import get_rope 18 | from sglang.srt.layers.vocab_parallel_embedding import ( 19 | ParallelLMHead, 20 | VocabParallelEmbedding, 21 | ) 22 | from sglang.srt.managers.router.model_runner import InputMetadata 23 | from sglang.srt.parallel_utils.parallel_state import ( 24 | get_tensor_model_parallel_world_size, 25 | ) 26 | from sglang.srt.utils import ( 27 | default_weight_loader, 28 | hf_model_weights_iterator, 29 | ) 30 | from torch import nn 31 | from transformers import LlamaConfig 32 | 33 | 34 | class LlamaMLP(nn.Module): 35 | def __init__( 36 | self, 37 | hidden_size: int, 38 | intermediate_size: int, 39 | hidden_act: str, 40 | linear_method: 
Optional[LinearMethodBase] = None, 41 | ) -> None: 42 | super().__init__() 43 | self.gate_up_proj = MergedColumnParallelLinear( 44 | hidden_size, 45 | [intermediate_size] * 2, 46 | bias=False, 47 | linear_method=linear_method, 48 | ) 49 | self.down_proj = RowParallelLinear( 50 | intermediate_size, hidden_size, bias=False, linear_method=linear_method 51 | ) 52 | if hidden_act != "silu": 53 | raise ValueError( 54 | f"Unsupported activation: {hidden_act}. " 55 | "Only silu is supported for now." 56 | ) 57 | self.act_fn = SiluAndMul() 58 | 59 | def forward(self, x): 60 | gate_up, _ = self.gate_up_proj(x) 61 | x = self.act_fn(gate_up) 62 | x, _ = self.down_proj(x) 63 | return x 64 | 65 | 66 | class LlamaAttention(nn.Module): 67 | def __init__( 68 | self, 69 | hidden_size: int, 70 | num_heads: int, 71 | num_kv_heads: int, 72 | layer_id: int = 0, 73 | rope_theta: float = 10000, 74 | rope_scaling: Optional[Dict[str, Any]] = None, 75 | max_position_embeddings: int = 8192, 76 | linear_method: Optional[LinearMethodBase] = None, 77 | ) -> None: 78 | super().__init__() 79 | self.hidden_size = hidden_size 80 | tp_size = get_tensor_model_parallel_world_size() 81 | self.total_num_heads = num_heads 82 | assert self.total_num_heads % tp_size == 0 83 | self.num_heads = self.total_num_heads // tp_size 84 | self.total_num_kv_heads = num_kv_heads 85 | if self.total_num_kv_heads >= tp_size: 86 | # Number of KV heads is greater than TP size, so we partition 87 | # the KV heads across multiple tensor parallel GPUs. 88 | assert self.total_num_kv_heads % tp_size == 0 89 | else: 90 | # Number of KV heads is less than TP size, so we replicate 91 | # the KV heads across multiple tensor parallel GPUs. 92 | assert tp_size % self.total_num_kv_heads == 0 93 | self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) 94 | self.head_dim = hidden_size // self.total_num_heads 95 | self.q_size = self.num_heads * self.head_dim 96 | self.kv_size = self.num_kv_heads * self.head_dim 97 | self.scaling = self.head_dim**-0.5 98 | self.rope_theta = rope_theta 99 | self.max_position_embeddings = max_position_embeddings 100 | 101 | self.qkv_proj = QKVParallelLinear( 102 | hidden_size, 103 | self.head_dim, 104 | self.total_num_heads, 105 | self.total_num_kv_heads, 106 | bias=False, 107 | linear_method=linear_method, 108 | ) 109 | self.o_proj = RowParallelLinear( 110 | self.total_num_heads * self.head_dim, 111 | hidden_size, 112 | bias=False, 113 | linear_method=linear_method, 114 | ) 115 | 116 | self.rotary_emb = get_rope( 117 | self.head_dim, 118 | rotary_dim=self.head_dim, 119 | max_position=max_position_embeddings, 120 | base=rope_theta, 121 | rope_scaling=rope_scaling, 122 | ) 123 | self.attn = RadixAttention( 124 | self.num_heads, 125 | self.head_dim, 126 | self.scaling, 127 | num_kv_heads=self.num_kv_heads, 128 | layer_id=layer_id, 129 | ) 130 | 131 | def forward( 132 | self, 133 | positions: torch.Tensor, 134 | hidden_states: torch.Tensor, 135 | input_metadata: InputMetadata, 136 | ) -> torch.Tensor: 137 | qkv, _ = self.qkv_proj(hidden_states) 138 | q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) 139 | q, k = self.rotary_emb(positions, q, k) 140 | attn_output = self.attn(q, k, v, input_metadata) 141 | output, _ = self.o_proj(attn_output) 142 | return output 143 | 144 | 145 | class LlamaDecoderLayer(nn.Module): 146 | def __init__( 147 | self, 148 | config: LlamaConfig, 149 | layer_id: int = 0, 150 | linear_method: Optional[LinearMethodBase] = None, 151 | ) -> None: 152 | super().__init__() 153 | 
self.hidden_size = config.hidden_size 154 | rope_theta = getattr(config, "rope_theta", 10000) 155 | rope_scaling = getattr(config, "rope_scaling", None) 156 | max_position_embeddings = getattr(config, "max_position_embeddings", 8192) 157 | self.self_attn = LlamaAttention( 158 | hidden_size=self.hidden_size, 159 | num_heads=config.num_attention_heads, 160 | num_kv_heads=config.num_key_value_heads, 161 | layer_id=layer_id, 162 | rope_theta=rope_theta, 163 | rope_scaling=rope_scaling, 164 | max_position_embeddings=max_position_embeddings, 165 | linear_method=linear_method, 166 | ) 167 | self.mlp = LlamaMLP( 168 | hidden_size=self.hidden_size, 169 | intermediate_size=config.intermediate_size, 170 | hidden_act=config.hidden_act, 171 | linear_method=linear_method, 172 | ) 173 | self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 174 | self.post_attention_layernorm = RMSNorm( 175 | config.hidden_size, eps=config.rms_norm_eps 176 | ) 177 | 178 | def forward( 179 | self, 180 | positions: torch.Tensor, 181 | hidden_states: torch.Tensor, 182 | input_metadata: InputMetadata, 183 | residual: Optional[torch.Tensor], 184 | ) -> Tuple[torch.Tensor, torch.Tensor]: 185 | # Self Attention 186 | if residual is None: 187 | residual = hidden_states 188 | hidden_states = self.input_layernorm(hidden_states) 189 | else: 190 | hidden_states, residual = self.input_layernorm(hidden_states, residual) 191 | hidden_states = self.self_attn( 192 | positions=positions, 193 | hidden_states=hidden_states, 194 | input_metadata=input_metadata, 195 | ) 196 | 197 | # Fully Connected 198 | hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) 199 | hidden_states = self.mlp(hidden_states) 200 | return hidden_states, residual 201 | 202 | 203 | class LlamaModel(nn.Module): 204 | def __init__( 205 | self, 206 | config: LlamaConfig, 207 | linear_method: Optional[LinearMethodBase] = None, 208 | ) -> None: 209 | super().__init__() 210 | self.config = config 211 | self.padding_idx = config.pad_token_id 212 | self.vocab_size = config.vocab_size 213 | self.embed_tokens = VocabParallelEmbedding( 214 | config.vocab_size, 215 | config.hidden_size, 216 | ) 217 | self.layers = nn.ModuleList( 218 | [ 219 | LlamaDecoderLayer(config, i, linear_method) 220 | for i in range(config.num_hidden_layers) 221 | ] 222 | ) 223 | self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 224 | 225 | def forward( 226 | self, 227 | input_ids: torch.Tensor, 228 | positions: torch.Tensor, 229 | input_metadata: InputMetadata, 230 | skip_embed: bool = False, 231 | ) -> torch.Tensor: 232 | if not skip_embed: 233 | hidden_states = self.embed_tokens(input_ids) 234 | else: 235 | hidden_states = input_ids 236 | residual = None 237 | for i in range(len(self.layers)): 238 | layer = self.layers[i] 239 | hidden_states, residual = layer( 240 | positions, 241 | hidden_states, 242 | input_metadata, 243 | residual, 244 | ) 245 | hidden_states, _ = self.norm(hidden_states, residual) 246 | return hidden_states 247 | 248 | 249 | class LlamaForCausalLM(nn.Module): 250 | def __init__( 251 | self, 252 | config: LlamaConfig, 253 | linear_method: Optional[LinearMethodBase] = None, 254 | ) -> None: 255 | super().__init__() 256 | self.config = config 257 | self.linear_method = linear_method 258 | self.model = LlamaModel(config, linear_method) 259 | self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) 260 | self.logits_processor = LogitsProcessor(config) 261 | 262 | def forward( 263 | self, 264 | input_ids: 
torch.Tensor, 265 | positions: torch.Tensor, 266 | input_metadata: InputMetadata, 267 | skip_embed: bool = False, 268 | ) -> torch.Tensor: 269 | hidden_states = self.model(input_ids, positions, input_metadata, skip_embed) 270 | return self.logits_processor( 271 | input_ids, hidden_states, self.lm_head.weight, input_metadata 272 | ) 273 | 274 | def load_weights( 275 | self, 276 | model_name_or_path: str, 277 | cache_dir: Optional[str] = None, 278 | load_format: str = "auto", 279 | revision: Optional[str] = None, 280 | ): 281 | stacked_params_mapping = [ 282 | # (param_name, shard_name, shard_id) 283 | ("qkv_proj", "q_proj", "q"), 284 | ("qkv_proj", "k_proj", "k"), 285 | ("qkv_proj", "v_proj", "v"), 286 | ("gate_up_proj", "gate_proj", 0), 287 | ("gate_up_proj", "up_proj", 1), 288 | ] 289 | params_dict = dict(self.named_parameters()) 290 | for name, loaded_weight in hf_model_weights_iterator( 291 | model_name_or_path, cache_dir, load_format, revision 292 | ): 293 | if "rotary_emb.inv_freq" in name or "projector" in name: 294 | continue 295 | if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: 296 | # Models trained using ColossalAI may include these tensors in 297 | # the checkpoint. Skip them. 298 | continue 299 | for param_name, weight_name, shard_id in stacked_params_mapping: 300 | if weight_name not in name: 301 | continue 302 | name = name.replace(weight_name, param_name) 303 | # Skip loading extra bias for GPTQ models. 304 | if name.endswith(".bias") and name not in params_dict: 305 | continue 306 | param = params_dict[name] 307 | weight_loader = param.weight_loader 308 | weight_loader(param, loaded_weight, shard_id) 309 | break 310 | else: 311 | # Skip loading extra bias for GPTQ models. 312 | if name.endswith(".bias") and name not in params_dict: 313 | continue 314 | param = params_dict[name] 315 | weight_loader = getattr(param, "weight_loader", default_weight_loader) 316 | weight_loader(param, loaded_weight) 317 | -------------------------------------------------------------------------------- /python/sglang/srt/utils.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import glob 3 | import json 4 | import os 5 | import random 6 | import socket 7 | import sys 8 | import traceback 9 | from collections import defaultdict 10 | from io import BytesIO 11 | from typing import Any, Iterator, List, Optional, Tuple 12 | 13 | import numpy as np 14 | import requests 15 | import torch 16 | import torch.distributed as dist 17 | from huggingface_hub import snapshot_download 18 | from safetensors.torch import load_file, safe_open, save_file 19 | from tqdm.auto import tqdm 20 | from transformers import PretrainedConfig 21 | 22 | 23 | def alloc_usable_network_port(num, used_list=()): 24 | port_list = [] 25 | for port in range(10000, 65536): 26 | if port in used_list: 27 | continue 28 | 29 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 30 | try: 31 | s.bind(("", port)) 32 | port_list.append(port) 33 | except socket.error: 34 | pass 35 | 36 | if len(port_list) == num: 37 | return port_list 38 | return None 39 | 40 | 41 | def get_exception_traceback(): 42 | etype, value, tb = sys.exc_info() 43 | err_str = "".join(traceback.format_exception(etype, value, tb)) 44 | return err_str 45 | 46 | 47 | def get_int_token_logit_bias(tokenizer, vocab_size): 48 | from transformers import LlamaTokenizer, LlamaTokenizerFast 49 | 50 | logit_bias = np.zeros(vocab_size, dtype=np.float32) 51 | for t_id in range(vocab_size): 52 | ss = 
tokenizer.decode(t_id).strip() 53 | if not (ss.isdigit() or len(ss) == 0 or t_id == tokenizer.eos_token_id): 54 | logit_bias[t_id] = -1e5 55 | # else: 56 | # print(ss, t_id) 57 | 58 | return logit_bias 59 | 60 | 61 | def wrap_kernel_launcher(kernel): 62 | """A faster launcher for triton kernels compatible with modern Triton versions.""" 63 | import torch.distributed as dist 64 | 65 | if dist.is_initialized(): 66 | rank = dist.get_rank() 67 | else: 68 | rank = 0 69 | 70 | # Handle different kernel cache structures across Triton versions 71 | try: 72 | # Try to get the actual kernel instance 73 | if hasattr(kernel, "cache") and rank in kernel.cache: 74 | kernels = list(kernel.cache[rank].values()) 75 | if kernels: 76 | kernel_instance = kernels[0] 77 | else: 78 | kernel_instance = kernel 79 | else: 80 | kernel_instance = kernel 81 | except (AttributeError, KeyError): 82 | kernel_instance = kernel 83 | 84 | # Modern Triton 3.4.0 compatible wrapper 85 | def modern_launcher(grid, num_warps, *args, **kwargs): 86 | # For Triton 3.4.0, we need to use the [grid] indexing syntax 87 | try: 88 | # Extract kwargs that are meant for kernel launch parameters 89 | launch_kwargs = {} 90 | kernel_kwargs = {} 91 | 92 | # Separate launch parameters from kernel parameters 93 | for key, value in kwargs.items(): 94 | if key in ["num_warps", "num_stages"]: 95 | launch_kwargs[key] = value 96 | else: 97 | kernel_kwargs[key] = value 98 | 99 | # Add num_warps if not in kwargs 100 | if "num_warps" not in launch_kwargs: 101 | launch_kwargs["num_warps"] = num_warps 102 | 103 | # Call the kernel with modern syntax 104 | return kernel_instance[grid](*args, **launch_kwargs, **kernel_kwargs) 105 | except Exception as e: 106 | # Fallback for different calling conventions 107 | try: 108 | # Try without kwargs 109 | return kernel_instance[grid](*args, num_warps=num_warps) 110 | except Exception: 111 | try: 112 | # Try direct call (for cached kernels) 113 | if callable(kernel_instance): 114 | return kernel_instance(grid, num_warps, *args) 115 | else: 116 | raise RuntimeError(f"Cannot launch kernel: {kernel_instance}") 117 | except Exception: 118 | raise RuntimeError(f"Failed to launch kernel: {e}") 119 | 120 | return modern_launcher 121 | 122 | 123 | def is_multimodal_model(model): 124 | if isinstance(model, str): 125 | return "llava" in model 126 | from sglang.srt.model_config import ModelConfig 127 | 128 | if isinstance(model, ModelConfig): 129 | return "llava" in model.path.lower() 130 | raise Exception("unrecognized type") 131 | 132 | 133 | def load_image(image_file): 134 | from PIL import Image 135 | 136 | image = None 137 | 138 | if image_file.startswith("http://") or image_file.startswith("https://"): 139 | timeout = int(os.getenv("REQUEST_TIMEOUT", "3")) 140 | response = requests.get(image_file, timeout=timeout) 141 | image = Image.open(BytesIO(response.content)) 142 | elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")): 143 | image = Image.open(image_file) 144 | elif image_file.startswith("data:"): 145 | image_file = image_url.split(",")[1] 146 | image = Image.open(BytesIO(base64.b64decode(image_file))) 147 | else: 148 | image = Image.open(BytesIO(base64.b64decode(image_file))) 149 | 150 | return image 151 | 152 | 153 | """Utils for model executor.""" 154 | import random 155 | from typing import Any, Dict, Optional 156 | 157 | import numpy as np 158 | import torch 159 | 160 | 161 | def set_random_seed(seed: int) -> None: 162 | random.seed(seed) 163 | np.random.seed(seed) 164 | 
torch.manual_seed(seed) 165 | if torch.cuda.is_available(): 166 | torch.cuda.manual_seed_all(seed) 167 | 168 | 169 | def set_weight_attrs( 170 | weight: torch.Tensor, 171 | weight_attrs: Optional[Dict[str, Any]], 172 | ): 173 | """Set attributes on a weight tensor. 174 | 175 | This method is used to set attributes on a weight tensor. This method 176 | will not overwrite existing attributes. 177 | 178 | Args: 179 | weight: The weight tensor. 180 | weight_attrs: A dictionary of attributes to set on the weight tensor. 181 | """ 182 | if weight_attrs is None: 183 | return 184 | for key, value in weight_attrs.items(): 185 | assert not hasattr(weight, key), f"Overwriting existing tensor attribute: {key}" 186 | setattr(weight, key, value) 187 | 188 | 189 | def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: 190 | """Default weight loader.""" 191 | assert param.size() == loaded_weight.size() 192 | param.data.copy_(loaded_weight) 193 | 194 | 195 | def prepare_hf_model_weights( 196 | model_name_or_path: str, 197 | cache_dir: Optional[str] = None, 198 | use_safetensors: bool = False, 199 | fall_back_to_pt: bool = True, 200 | revision: Optional[str] = None, 201 | ) -> Tuple[str, List[str], bool]: 202 | # Download model weights from huggingface. 203 | is_local = os.path.isdir(model_name_or_path) 204 | # Some quantized models use .pt files for storing the weights. 205 | allow_patterns = ["*.safetensors"] if use_safetensors else ["*.bin", "*.pt"] 206 | if not is_local: 207 | # Use file lock to prevent multiple processes from 208 | # downloading the same model weights at the same time. 209 | with get_lock(model_name_or_path, cache_dir): 210 | hf_folder = snapshot_download( 211 | model_name_or_path, 212 | allow_patterns=allow_patterns, 213 | cache_dir=cache_dir, 214 | tqdm_class=Disabledtqdm, 215 | revision=revision, 216 | ) 217 | else: 218 | hf_folder = model_name_or_path 219 | hf_weights_files: List[str] = [] 220 | for pattern in allow_patterns: 221 | hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) 222 | if not use_safetensors: 223 | # Exclude files that are not needed for inference. 
224 | # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233 225 | blacklist = [ 226 | "training_args.bin", 227 | "optimizer.bin", 228 | "optimizer.pt", 229 | "scheduler.pt", 230 | "scaler.pt", 231 | ] 232 | hf_weights_files = [ 233 | f for f in hf_weights_files if not any(f.endswith(x) for x in blacklist) 234 | ] 235 | 236 | if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt: 237 | return prepare_hf_model_weights( 238 | model_name_or_path, 239 | cache_dir=cache_dir, 240 | use_safetensors=False, 241 | fall_back_to_pt=False, 242 | revision=revision, 243 | ) 244 | 245 | if len(hf_weights_files) == 0: 246 | raise RuntimeError(f"Cannot find any model weights with `{model_name_or_path}`") 247 | 248 | return hf_folder, hf_weights_files, use_safetensors 249 | 250 | 251 | def hf_model_weights_iterator( 252 | model_name_or_path: str, 253 | cache_dir: Optional[str] = None, 254 | load_format: str = "auto", 255 | revision: Optional[str] = None, 256 | ) -> Iterator[Tuple[str, torch.Tensor]]: 257 | use_safetensors = False 258 | use_np_cache = False 259 | fall_back_to_pt = False 260 | if load_format == "auto": 261 | use_safetensors = True 262 | fall_back_to_pt = True 263 | elif load_format == "safetensors": 264 | use_safetensors = True 265 | elif load_format == "pt": 266 | pass 267 | elif load_format == "npcache": 268 | use_np_cache = True 269 | else: 270 | raise ValueError(f"Unknown load_format: {load_format}") 271 | 272 | hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights( 273 | model_name_or_path, 274 | cache_dir=cache_dir, 275 | use_safetensors=use_safetensors, 276 | fall_back_to_pt=fall_back_to_pt, 277 | revision=revision, 278 | ) 279 | 280 | if use_np_cache: 281 | # Currently np_cache only support *.bin checkpoints 282 | assert use_safetensors is False 283 | 284 | # Convert the model weights from torch tensors to numpy arrays for 285 | # faster loading. 286 | np_folder = os.path.join(hf_folder, "np") 287 | os.makedirs(np_folder, exist_ok=True) 288 | weight_names_file = os.path.join(np_folder, "weight_names.json") 289 | # Use file lock to prevent multiple processes from 290 | # dumping the same model weights to numpy at the same time. 
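        # Only the first process to acquire the lock does the conversion: it
        # saves each tensor as a .npy file and records the tensor names in
        # weight_names.json, so later processes just read the cached copies.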
291 | with get_lock(model_name_or_path, cache_dir): 292 | if not os.path.exists(weight_names_file): 293 | weight_names = [] 294 | for bin_file in hf_weights_files: 295 | state = torch.load(bin_file, map_location="cpu") 296 | for name, param in state.items(): 297 | param_path = os.path.join(np_folder, name) 298 | with open(param_path, "wb") as f: 299 | np.save(f, param.cpu().detach().numpy()) 300 | weight_names.append(name) 301 | with open(weight_names_file, "w") as f: 302 | json.dump(weight_names, f) 303 | 304 | with open(weight_names_file, "r") as f: 305 | weight_names = json.load(f) 306 | 307 | for name in weight_names: 308 | param_path = os.path.join(np_folder, name) 309 | with open(param_path, "rb") as f: 310 | param = np.load(f) 311 | yield name, torch.from_numpy(param) 312 | elif use_safetensors: 313 | for st_file in hf_weights_files: 314 | with safe_open(st_file, framework="pt") as f: 315 | for name in f.keys(): # noqa: SIM118 316 | param = f.get_tensor(name) 317 | yield name, param 318 | else: 319 | for bin_file in hf_weights_files: 320 | state = torch.load(bin_file, map_location="cpu") 321 | for name, param in state.items(): 322 | yield name, param 323 | del state 324 | torch.cuda.empty_cache() 325 | 326 | 327 | import contextlib 328 | 329 | 330 | @contextlib.contextmanager 331 | def _set_default_torch_dtype(dtype: torch.dtype): 332 | """Sets the default torch dtype to the given dtype.""" 333 | old_dtype = torch.get_default_dtype() 334 | torch.set_default_dtype(dtype) 335 | yield 336 | torch.set_default_dtype(old_dtype) 337 | -------------------------------------------------------------------------------- /python/sglang/srt/managers/router/infer_batch.py: -------------------------------------------------------------------------------- 1 | from enum import Enum, auto 2 | from typing import List 3 | 4 | import numpy as np 5 | import torch 6 | from sglang.srt.managers.router.radix_cache import RadixCache 7 | from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool 8 | 9 | 10 | class ForwardMode(Enum): 11 | PREFILL = auto() 12 | EXTEND = auto() 13 | DECODE = auto() 14 | 15 | 16 | class FinishReason(Enum): 17 | LENGTH = auto() 18 | EOS_TOKEN = auto() 19 | STOP_STR = auto() 20 | 21 | 22 | class Req: 23 | def __init__(self, rid): 24 | self.rid = rid 25 | self.input_ids = [] 26 | self.output_ids = [] 27 | self.pixel_values = None 28 | self.image_offset = 0 29 | self.sampling_params = None 30 | self.return_normalized_logprob = False 31 | self.normalized_logprob_start_len = 0 32 | self.stream = False 33 | 34 | self.tokenizer = None 35 | self.finished = False 36 | self.finish_reason = None 37 | self.hit_stop_str = None 38 | 39 | self.adjust_input_len = 0 40 | self.prefix_indices = [] 41 | 42 | self.normalized_logprob = None 43 | 44 | # for constrained decoding 45 | self.regex_fsm = None 46 | self.regex_fsm_state = None 47 | 48 | def max_new_tokens(self): 49 | return self.sampling_params.max_new_tokens 50 | 51 | def check_finished(self): 52 | if self.finished: 53 | return 54 | 55 | if len(self.output_ids) >= self.sampling_params.max_new_tokens: 56 | self.finished = True 57 | self.finish_reason = FinishReason.LENGTH 58 | return 59 | 60 | if ( 61 | self.output_ids[-1] == self.tokenizer.eos_token_id 62 | and self.sampling_params.ignore_eos == False 63 | ): 64 | self.finished = True 65 | self.finish_reason = FinishReason.EOS_TOKEN 66 | return 67 | 68 | if len(self.sampling_params.stop_strs) > 0: 69 | tail_str = self.tokenizer.decode( 70 | 
self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :] 71 | ) 72 | 73 | for stop_str in self.sampling_params.stop_strs: 74 | if stop_str in tail_str: 75 | self.finished = True 76 | self.finish_reason = FinishReason.STOP_STR 77 | self.hit_stop_str = stop_str 78 | return 79 | 80 | def __repr__(self): 81 | return f"rid(n={self.rid}, " f"input_ids={self.input_ids}, " 82 | 83 | 84 | class Batch: 85 | def __init__( 86 | self, 87 | reqs: List[Req], 88 | req_to_token_pool: ReqToTokenPool, 89 | token_to_kv_pool: TokenToKVPool, 90 | tree_cache: RadixCache, 91 | ): 92 | self.reqs = reqs 93 | self.req_to_token_pool = req_to_token_pool 94 | self.token_to_kv_pool = token_to_kv_pool 95 | self.tree_cache = tree_cache 96 | 97 | self.return_normalized_logprob = any( 98 | req.return_normalized_logprob for req in reqs 99 | ) 100 | 101 | def is_empty(self): 102 | return len(self.reqs) == 0 103 | 104 | def init_extend_batch(self, vocab_size: int, int_token_logit_bias: torch.Tensor): 105 | device = "cuda" 106 | bs = len(self.reqs) 107 | reqs = self.reqs 108 | input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs] 109 | prefix_indices = [r.prefix_indices for r in reqs] 110 | 111 | # Handle prefix 112 | flatten_input_ids = [] 113 | extend_lens = [] 114 | prefix_lens = [] 115 | seq_lens = [] 116 | 117 | req_pool_indices = self.req_to_token_pool.alloc(bs) 118 | req_pool_indices_cpu = req_pool_indices.cpu().numpy() 119 | for i in range(bs): 120 | flatten_input_ids.extend(input_ids[i]) 121 | extend_lens.append(len(input_ids[i])) 122 | 123 | if len(prefix_indices[i]) == 0: 124 | prefix_lens.append(0) 125 | else: 126 | prefix_lens.append(len(prefix_indices[i])) 127 | self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][ 128 | : len(prefix_indices[i]) 129 | ] = prefix_indices[i] 130 | 131 | seq_lens.append(prefix_lens[-1] + extend_lens[-1]) 132 | 133 | position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device) 134 | 135 | # Alloc mem 136 | seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens) 137 | extend_num_tokens = seq_lens.sum() - prefix_lens.sum() 138 | out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens) 139 | if out_cache_loc is None: 140 | self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free) 141 | out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens) 142 | 143 | if out_cache_loc is None: 144 | print("Prefill out of memory.") 145 | self.tree_cache.pretty_print() 146 | exit() 147 | 148 | pt = 0 149 | for i in range(bs): 150 | self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][ 151 | prefix_lens[i] : prefix_lens[i] + extend_lens[i] 152 | ] = out_cache_loc[pt : pt + extend_lens[i]] 153 | pt += extend_lens[i] 154 | 155 | # Handle logit bias 156 | logit_bias = torch.zeros((bs, vocab_size), dtype=torch.float32, device=device) 157 | for i in range(bs): 158 | if reqs[i].sampling_params.dtype == "int": 159 | logit_bias[i] = int_token_logit_bias 160 | 161 | # Set fields 162 | self.input_ids = torch.tensor( 163 | flatten_input_ids, dtype=torch.int32, device=device 164 | ) 165 | self.pixel_values = [r.pixel_values for r in reqs] 166 | self.image_offsets = [ 167 | r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens) 168 | ] 169 | self.req_pool_indices = req_pool_indices 170 | self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device) 171 | self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device) 172 | self.position_ids_offsets = position_ids_offsets 173 | self.extend_num_tokens = 
extend_num_tokens 174 | self.out_cache_loc = out_cache_loc 175 | 176 | self.temperatures = torch.tensor( 177 | [r.sampling_params.temperature for r in reqs], 178 | dtype=torch.float, 179 | device=device, 180 | ).view(-1, 1) 181 | self.top_ps = torch.tensor( 182 | [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device 183 | ).view(-1, 1) 184 | self.top_ks = torch.tensor( 185 | [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device 186 | ).view(-1, 1) 187 | self.frequency_penalties = torch.tensor( 188 | [r.sampling_params.frequency_penalty for r in reqs], 189 | dtype=torch.float, 190 | device=device, 191 | ) 192 | self.presence_penalties = torch.tensor( 193 | [r.sampling_params.presence_penalty for r in reqs], 194 | dtype=torch.float, 195 | device=device, 196 | ) 197 | self.logit_bias = logit_bias 198 | 199 | def update_for_decode(self, input_ids=None): 200 | if input_ids is None: 201 | input_ids = [ 202 | r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs 203 | ] 204 | self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda") 205 | self.seq_lens.add_(1) 206 | self.prefix_lens = None 207 | 208 | # Alloc mem 209 | bs = len(self.reqs) 210 | alloc_res = self.token_to_kv_pool.alloc_contiguous(bs) 211 | if alloc_res is None: 212 | self.out_cache_loc = self.token_to_kv_pool.alloc(bs) 213 | 214 | if self.out_cache_loc is None: 215 | self.tree_cache.evict(bs, self.token_to_kv_pool.free) 216 | self.out_cache_loc = self.token_to_kv_pool.alloc(bs) 217 | 218 | if self.out_cache_loc is None: 219 | print("Decode out of memory.") 220 | self.tree_cache.pretty_print() 221 | exit() 222 | 223 | self.out_cache_cont_start = None 224 | self.out_cache_cont_end = None 225 | else: 226 | self.out_cache_loc = alloc_res[0] 227 | self.out_cache_cont_start = alloc_res[1] 228 | self.out_cache_cont_end = alloc_res[2] 229 | 230 | self.req_to_token_pool.req_to_token[ 231 | self.req_pool_indices, self.seq_lens - 1 232 | ] = self.out_cache_loc 233 | 234 | def filter_batch(self, unfinished_indices: List[int]): 235 | self.reqs = [self.reqs[i] for i in unfinished_indices] 236 | new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda") 237 | self.seq_lens = self.seq_lens[new_indices] 238 | self.input_ids = None 239 | self.req_pool_indices = self.req_pool_indices[new_indices] 240 | self.prefix_lens = None 241 | self.position_ids_offsets = self.position_ids_offsets[new_indices] 242 | self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None 243 | 244 | for item in [ 245 | "temperatures", 246 | "top_ps", 247 | "top_ks", 248 | "frequency_penalties", 249 | "presence_penalties", 250 | "logit_bias", 251 | ]: 252 | setattr(self, item, getattr(self, item)[new_indices]) 253 | 254 | def merge(self, other): 255 | self.reqs.extend(other.reqs) 256 | 257 | self.req_pool_indices = torch.concat( 258 | [self.req_pool_indices, other.req_pool_indices] 259 | ) 260 | self.seq_lens = torch.concat([self.seq_lens, other.seq_lens]) 261 | self.prefix_lens = None 262 | self.position_ids_offsets = torch.concat( 263 | [self.position_ids_offsets, other.position_ids_offsets] 264 | ) 265 | self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None 266 | 267 | for item in [ 268 | "temperatures", 269 | "top_ps", 270 | "top_ks", 271 | "frequency_penalties", 272 | "presence_penalties", 273 | "logit_bias", 274 | ]: 275 | setattr( 276 | self, item, torch.concat([getattr(self, item), getattr(other, item)]) 277 | ) 278 | 279 | def 
sample(self, logits: torch.Tensor): 280 | # Post process logits 281 | logits = logits.contiguous() 282 | logits = logits / self.temperatures 283 | logits = logits + self.logit_bias 284 | 285 | has_regex = any(req.regex_fsm is not None for req in self.reqs) 286 | if has_regex: 287 | allowed_mask = torch.empty_like(logits[0], dtype=torch.bool) 288 | for i, req in enumerate(self.reqs): 289 | if req.regex_fsm is not None: 290 | allowed_mask.zero_() 291 | allowed_mask[ 292 | req.regex_fsm.allowed_token_ids(req.regex_fsm_state) 293 | ] = 1 294 | logits[i].masked_fill_(~allowed_mask, float("-inf")) 295 | 296 | # TODO(lmzheng): apply penalty 297 | probs = torch.softmax(logits, dim=-1) 298 | probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks) 299 | sampled_index = torch.multinomial(probs_sort, num_samples=1) 300 | batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view( 301 | -1 302 | ) 303 | batch_next_token_probs = torch.gather( 304 | probs_sort, dim=1, index=sampled_index 305 | ).view(-1) 306 | 307 | if has_regex: 308 | batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy() 309 | for i, req in enumerate(self.reqs): 310 | if req.regex_fsm is not None: 311 | req.regex_fsm_state = req.regex_fsm.next_state( 312 | req.regex_fsm_state, batch_next_token_ids_cpu[i] 313 | ) 314 | 315 | return batch_next_token_ids, batch_next_token_probs 316 | 317 | 318 | def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor): 319 | # Handle both 2D [batch, vocab] and 3D [batch, seq, vocab] tensors 320 | if probs.dim() == 3: 321 | # For 3D tensors, we only want to sample from the last token 322 | probs = probs[:, -1, :] # Take the last token in sequence 323 | 324 | probs_sort, probs_idx = probs.sort(dim=-1, descending=True) 325 | probs_sum = torch.cumsum(probs_sort, dim=-1) 326 | probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0 327 | probs_sort[ 328 | torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks 329 | ] = 0.0 330 | probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0]) 331 | return probs_sort, probs_idx 332 | -------------------------------------------------------------------------------- /python/sglang/srt/layers/quantization/awq_triton.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/awq_triton.py 2 | 3 | # SPDX-License-Identifier: Apache-2.0 4 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project 5 | 6 | import torch 7 | import triton 8 | import triton.language as tl 9 | 10 | AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] 11 | 12 | 13 | @triton.jit 14 | def awq_dequantize_kernel( 15 | qweight_ptr, # quantized matrix 16 | scales_ptr, # scales, per group 17 | zeros_ptr, # zeros, per group 18 | group_size, # Should always be one of the supported group sizes 19 | result_ptr, # Output matrix 20 | num_cols, # input num cols in qweight 21 | num_rows, # input num rows in qweight 22 | BLOCK_SIZE_X: tl.constexpr, 23 | BLOCK_SIZE_Y: tl.constexpr, 24 | ): 25 | # Setup the pids. 26 | pid_x = tl.program_id(axis=0) 27 | pid_y = tl.program_id(axis=1) 28 | 29 | # Compute offsets and masks for qweight_ptr. 
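    # Each int32 in qweight packs eight 4-bit values, so one input column
    # expands into eight dequantized output columns; the result offsets below
    # therefore scale the x coordinates by 8.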
30 | offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y) 31 | offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X) 32 | offsets = num_cols * offsets_y[:, None] + offsets_x[None, :] 33 | 34 | masks_y = offsets_y < num_rows 35 | masks_x = offsets_x < num_cols 36 | 37 | masks = masks_y[:, None] & masks_x[None, :] 38 | 39 | # Compute offsets and masks for result output ptr. 40 | result_offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y) 41 | result_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(0, BLOCK_SIZE_X * 8) 42 | result_offsets = ( 43 | 8 * num_cols * result_offsets_y[:, None] + result_offsets_x[None, :] 44 | ) 45 | 46 | result_masks_y = result_offsets_y < num_rows 47 | result_masks_x = result_offsets_x < num_cols * 8 48 | result_masks = result_masks_y[:, None] & result_masks_x[None, :] 49 | 50 | # Load the weights. 51 | iweights = tl.load(qweight_ptr + offsets, masks, 0.0) 52 | iweights = tl.interleave(iweights, iweights) 53 | iweights = tl.interleave(iweights, iweights) 54 | iweights = tl.interleave(iweights, iweights) 55 | 56 | # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7] 57 | # that will map given indices to the correct order. 58 | reverse_awq_order_tensor = ( 59 | (tl.arange(0, 2) * 4)[None, :] + tl.arange(0, 4)[:, None] 60 | ).reshape(8) 61 | 62 | # Use this to compute a set of shifts that can be used to unpack and 63 | # reorder the values in iweights and zeros. 64 | shifts = reverse_awq_order_tensor * 4 65 | shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_Y * BLOCK_SIZE_X, 8)) 66 | shifts = tl.reshape(shifts, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) 67 | 68 | # Unpack and reorder: shift out the correct 4-bit value and mask. 69 | iweights = (iweights >> shifts) & 0xF 70 | 71 | # Compute zero offsets and masks. 72 | zero_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1) 73 | zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X) 74 | zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :] 75 | 76 | zero_masks_y = zero_offsets_y < num_rows // group_size 77 | zero_masks_x = zero_offsets_x < num_cols 78 | zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :] 79 | 80 | # Load the zeros. 81 | zeros = tl.load(zeros_ptr + zero_offsets, zero_masks, 0.0) 82 | zeros = tl.interleave(zeros, zeros) 83 | zeros = tl.interleave(zeros, zeros) 84 | zeros = tl.interleave(zeros, zeros) 85 | zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) 86 | 87 | # Unpack and reorder: shift out the correct 4-bit value and mask. 88 | zeros = (zeros >> shifts) & 0xF 89 | 90 | # Compute scale offsets and masks. 91 | scale_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1) 92 | scale_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(0, BLOCK_SIZE_X * 8) 93 | scale_offsets = num_cols * 8 * scale_offsets_y[:, None] + scale_offsets_x[None, :] 94 | scale_masks_y = scale_offsets_y < num_rows // group_size 95 | scale_masks_x = scale_offsets_x < num_cols * 8 96 | scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :] 97 | 98 | # Load the scales. 99 | scales = tl.load(scales_ptr + scale_offsets, scale_masks, 0.0) 100 | scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) 101 | 102 | # Dequantize. 103 | iweights = (iweights - zeros) * scales 104 | iweights = iweights.to(result_ptr.type.element_ty) 105 | 106 | # Finally, store. 
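    # Write the dequantized block into the 8x-wider result tensor at the
    # offsets computed above, with masks skipping out-of-range rows/columns.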
107 | tl.store(result_ptr + result_offsets, iweights, result_masks) 108 | 109 | 110 | @triton.jit 111 | def awq_gemm_kernel( 112 | a_ptr, 113 | b_ptr, 114 | c_ptr, 115 | zeros_ptr, 116 | scales_ptr, 117 | M, 118 | N, 119 | K, 120 | group_size, 121 | BLOCK_SIZE_M: tl.constexpr, 122 | BLOCK_SIZE_N: tl.constexpr, 123 | BLOCK_SIZE_K: tl.constexpr, 124 | SPLIT_K: tl.constexpr, 125 | ): 126 | pid = tl.program_id(axis=0) 127 | pid_z = tl.program_id(1) 128 | 129 | # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead. 130 | # num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N 131 | num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) 132 | 133 | pid_m = pid // num_pid_n 134 | pid_n = pid % num_pid_n 135 | 136 | accumulator_dtype = c_ptr.type.element_ty 137 | 138 | # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead. 139 | # accumulator = tl.arange(0, BLOCK_SIZE_N) 140 | # accumulator = tl.broadcast_to(accumulator[None, :], 141 | # (BLOCK_SIZE_M, BLOCK_SIZE_N)) 142 | # accumulator = accumulator & 0x0 143 | # accumulator = accumulator.to(accumulator_dtype) 144 | accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype) 145 | 146 | # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7] 147 | # that will map given indices to the correct order. 148 | reverse_awq_order_tensor = ( 149 | (tl.arange(0, 2) * 4)[None, :] + tl.arange(0, 4)[:, None] 150 | ).reshape(8) 151 | 152 | # Create the necessary shifts to use to unpack. 153 | shifts = reverse_awq_order_tensor * 4 154 | shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8)) 155 | shifts = tl.reshape(shifts, (BLOCK_SIZE_K, BLOCK_SIZE_N)) 156 | 157 | # Offsets and masks. 158 | offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 159 | masks_am = offsets_am < M 160 | 161 | offsets_bn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8) 162 | masks_bn = offsets_bn < N // 8 163 | 164 | offsets_zn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8) 165 | masks_zn = offsets_zn < N // 8 166 | 167 | offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) 168 | masks_sn = offsets_sn < N 169 | 170 | offsets_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) 171 | offsets_a = K * offsets_am[:, None] + offsets_k[None, :] 172 | offsets_b = (N // 8) * offsets_k[:, None] + offsets_bn[None, :] 173 | 174 | a_ptrs = a_ptr + offsets_a 175 | b_ptrs = b_ptr + offsets_b 176 | 177 | # NOTE: Use this in TRITON_INTERPRET=1 mode instead of tl.cdiv 178 | # block_offset = BLOCK_SIZE_K * SPLIT_K 179 | # for k in range(0, (K + block_offset - 1) // (block_offset)): 180 | for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): 181 | masks_k = offsets_k < K 182 | masks_a = masks_am[:, None] & masks_k[None, :] 183 | a = tl.load(a_ptrs, mask=masks_a, other=0.0) 184 | 185 | masks_b = masks_k[:, None] & masks_bn[None, :] 186 | b = tl.load(b_ptrs, mask=masks_b, other=0.0) 187 | b = tl.interleave(b, b) 188 | b = tl.interleave(b, b) 189 | b = tl.interleave(b, b) 190 | 191 | # Dequantize b. 
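# Fetch the per-group zeros and scales for the K rows handled in this iteration;
# the group row is the absolute K offset divided by group_size.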
192 | offsets_szk = ( 193 | BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K 194 | ) // group_size + tl.arange(0, 1) 195 | offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :] 196 | masks_zk = offsets_szk < K // group_size 197 | masks_z = masks_zk[:, None] & masks_zn[None, :] 198 | zeros_ptrs = zeros_ptr + offsets_z 199 | zeros = tl.load(zeros_ptrs, mask=masks_z, other=0.0) 200 | zeros = tl.interleave(zeros, zeros) 201 | zeros = tl.interleave(zeros, zeros) 202 | zeros = tl.interleave(zeros, zeros) 203 | zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_K, BLOCK_SIZE_N)) 204 | 205 | offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :] 206 | masks_sk = offsets_szk < K // group_size 207 | masks_s = masks_sk[:, None] & masks_sn[None, :] 208 | scales_ptrs = scales_ptr + offsets_s 209 | scales = tl.load(scales_ptrs, mask=masks_s, other=0.0) 210 | scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N)) 211 | 212 | b = (b >> shifts) & 0xF 213 | zeros = (zeros >> shifts) & 0xF 214 | b = (b - zeros) * scales 215 | b = b.to(c_ptr.type.element_ty) 216 | 217 | # Accumulate results. 218 | accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype) 219 | 220 | offsets_k += BLOCK_SIZE_K * SPLIT_K 221 | a_ptrs += BLOCK_SIZE_K * SPLIT_K 222 | b_ptrs += BLOCK_SIZE_K * SPLIT_K * (N // 8) 223 | 224 | c = accumulator.to(c_ptr.type.element_ty) 225 | offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 226 | offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) 227 | c_ptrs = c_ptr + pid_z * N * M + N * offs_cm[:, None] + offs_cn[None, :] 228 | c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) 229 | tl.store(c_ptrs, c, mask=c_mask) 230 | 231 | 232 | # qweights - [K , M // 8], int32 233 | # scales - [K // G, M ], float16 234 | # zeros - [K // G, M // 8], int32 235 | def awq_dequantize_triton( 236 | qweight: torch.Tensor, 237 | scales: torch.Tensor, 238 | zeros: torch.Tensor, 239 | block_size_x: int = 32, 240 | block_size_y: int = 32, 241 | ) -> torch.Tensor: 242 | K = qweight.shape[0] 243 | M = scales.shape[1] 244 | group_size = qweight.shape[0] // scales.shape[0] 245 | 246 | assert K > 0 and M > 0 247 | assert scales.shape[0] == K // group_size and scales.shape[1] == M 248 | assert zeros.shape[0] == K // group_size and zeros.shape[1] == M // 8 249 | assert group_size <= K 250 | assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K 251 | 252 | # Result tensor: 253 | # number of rows = same as input tensor 254 | # number of cols = 8 x input tensor num cols 255 | result = torch.empty( 256 | qweight.shape[0], 257 | qweight.shape[1] * 8, 258 | device=qweight.device, 259 | dtype=scales.dtype, 260 | ) 261 | 262 | Y = qweight.shape[0] # num rows 263 | X = qweight.shape[1] # num cols 264 | 265 | grid = lambda META: ( 266 | triton.cdiv(X, META["BLOCK_SIZE_X"]), 267 | triton.cdiv(Y, META["BLOCK_SIZE_Y"]), 268 | ) 269 | awq_dequantize_kernel[grid]( 270 | qweight, 271 | scales, 272 | zeros, 273 | group_size, 274 | result, 275 | X, 276 | Y, 277 | BLOCK_SIZE_X=block_size_x, 278 | BLOCK_SIZE_Y=block_size_y, 279 | ) 280 | 281 | return result 282 | 283 | 284 | # input - [M, K] 285 | # qweight - [K, N // 8] 286 | # qzeros - [K // G, N // 8] 287 | # scales - [K // G, N] 288 | # split_k_iters - parallelism along K-dimension, int, power of 2. 
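# output - [M, N], in the dtype of scales. Each of the split_k_iters slices
# accumulates a partial product over its share of K; the partials are summed at the end.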
289 | def awq_gemm_triton( 290 | input: torch.Tensor, 291 | qweight: torch.Tensor, 292 | scales: torch.Tensor, 293 | qzeros: torch.Tensor, 294 | split_k_iters: int, 295 | block_size_m: int = 32, 296 | block_size_n: int = 32, 297 | block_size_k: int = 32, 298 | ) -> torch.Tensor: 299 | M, K = input.shape 300 | N = qweight.shape[1] * 8 301 | group_size = qweight.shape[0] // qzeros.shape[0] 302 | 303 | assert N > 0 and K > 0 and M > 0 304 | assert qweight.shape[0] == K and qweight.shape[1] == N // 8 305 | assert qzeros.shape[0] == K // group_size and qzeros.shape[1] == N // 8 306 | assert scales.shape[0] == K // group_size and scales.shape[1] == N 307 | assert split_k_iters & (split_k_iters - 1) == 0 and split_k_iters != 0 308 | assert split_k_iters <= 32 309 | assert group_size <= K 310 | assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K 311 | 312 | grid = lambda META: ( 313 | triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), 314 | split_k_iters, 315 | ) 316 | 317 | result = torch.zeros((split_k_iters, M, N), dtype=scales.dtype, device=input.device) 318 | 319 | # A = input, B = qweight, C = result 320 | # A = M x K, B = K x N, C = M x N 321 | awq_gemm_kernel[grid]( 322 | input, 323 | qweight, 324 | result, 325 | qzeros, 326 | scales, 327 | M, 328 | N, 329 | K, 330 | group_size, 331 | BLOCK_SIZE_M=block_size_m, 332 | BLOCK_SIZE_N=block_size_n, 333 | BLOCK_SIZE_K=block_size_k, 334 | SPLIT_K=split_k_iters, 335 | ) 336 | 337 | result = result.sum(0) 338 | 339 | return result 340 | -------------------------------------------------------------------------------- /python/sglang/srt/layers/extend_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd 5 | from sglang.srt.utils import wrap_kernel_launcher 6 | 7 | 8 | @triton.jit 9 | def _fwd_kernel( 10 | Q_Extend, 11 | K_Extend, 12 | V_Extend, 13 | O_Extend, 14 | K_Buffer, 15 | V_Buffer, 16 | Req_to_tokens, 17 | B_req_idx, 18 | B_Seq_Len, 19 | B_Start_Loc_Extend, 20 | B_Seq_Len_Extend, 21 | sm_scale, 22 | kv_group_num, 23 | stride_qbs, 24 | stride_qh, 25 | stride_kbs, 26 | stride_kh, 27 | stride_vbs, 28 | stride_vh, 29 | stride_obs, 30 | stride_oh, 31 | stride_buf_kbs, 32 | stride_buf_kh, 33 | stride_buf_vbs, 34 | stride_buf_vh, 35 | stride_req_to_tokens_b, 36 | BLOCK_DMODEL: tl.constexpr, 37 | BLOCK_M: tl.constexpr, 38 | BLOCK_N: tl.constexpr, 39 | ): 40 | cur_seq = tl.program_id(0) 41 | cur_head = tl.program_id(1) 42 | cur_block_m = tl.program_id(2) 43 | cur_kv_head = cur_head // kv_group_num 44 | 45 | cur_seq_len = tl.load(B_Seq_Len + cur_seq) 46 | cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq) 47 | cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend 48 | 49 | cur_seq_prefix_start_in_loc = 0 50 | cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq) 51 | cur_batch_req_idx = tl.load(B_req_idx + cur_seq) 52 | 53 | offs_d = tl.arange(0, BLOCK_DMODEL) 54 | offs_m = tl.arange(0, BLOCK_M) 55 | mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend 56 | offs_q = ( 57 | (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None]) 58 | * stride_qbs 59 | + cur_head * stride_qh 60 | + offs_d[None, :] 61 | ) 62 | q = tl.load(Q_Extend + offs_q, mask=mask_m[:, None], other=0.0) 63 | 64 | # stage1: compute scores with prefix 65 | offs_n = tl.arange(0, BLOCK_N) 66 | 67 | 
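# Running state for the online softmax: acc accumulates the weighted sum of V rows,
# deno the softmax denominator, and e_max the running maximum score. Both stages
# below rescale acc and deno whenever a larger maximum is encountered.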
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) 68 | deno = tl.zeros([BLOCK_M], dtype=tl.float32) 69 | e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") 70 | 71 | for start_n in range(0, cur_seq_len_prefix, BLOCK_N): 72 | start_n = tl.multiple_of(start_n, BLOCK_N) 73 | mask_n = (start_n + offs_n) < cur_seq_len_prefix 74 | offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + ( 75 | cur_seq_prefix_start_in_loc + start_n + offs_n 76 | ) 77 | offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0) 78 | 79 | # load k in transposed way 80 | offs_buf_k = ( 81 | offs_kv_loc[None, :] * stride_buf_kbs 82 | + cur_kv_head * stride_buf_kh 83 | + offs_d[:, None] 84 | ) 85 | k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0) 86 | 87 | qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 88 | qk += tl.dot(q, k) 89 | qk *= sm_scale 90 | qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf")) 91 | 92 | n_e_max = tl.maximum(tl.max(qk, 1), e_max) 93 | re_scale = tl.exp(e_max - n_e_max) 94 | p = tl.exp(qk - n_e_max[:, None]) 95 | deno = deno * re_scale + tl.sum(p, 1) 96 | 97 | offs_buf_v = ( 98 | offs_kv_loc[:, None] * stride_buf_vbs 99 | + cur_kv_head * stride_buf_vh 100 | + offs_d[None, :] 101 | ) 102 | v = tl.load(V_Buffer + offs_buf_v, mask=mask_n[:, None], other=0.0) 103 | p = p.to(v.dtype) 104 | acc = acc * re_scale[:, None] + tl.dot(p, v) 105 | 106 | e_max = n_e_max 107 | 108 | # stage2: compute the trianlge part 109 | 110 | cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M) 111 | for start_n in range(0, cur_block_m_end, BLOCK_N): 112 | start_n = tl.multiple_of(start_n, BLOCK_N) 113 | mask_n = (start_n + offs_n) < cur_block_m_end 114 | 115 | # load k in transposed way 116 | offs_k = ( 117 | (cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs 118 | + cur_kv_head * stride_kh 119 | + offs_d[:, None] 120 | ) 121 | k = tl.load(K_Extend + offs_k, mask=mask_n[None, :], other=0.0) 122 | 123 | qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 124 | qk += tl.dot(q, k) 125 | qk *= sm_scale 126 | mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= ( 127 | start_n + offs_n[None, :] 128 | ) 129 | mask_causual &= mask_m[:, None] & mask_n[None, :] 130 | qk = tl.where(mask_causual, qk, float("-inf")) 131 | 132 | n_e_max = tl.maximum(tl.max(qk, 1), e_max) 133 | re_scale = tl.exp(e_max - n_e_max) 134 | p = tl.exp(qk - n_e_max[:, None]) 135 | deno = deno * re_scale + tl.sum(p, 1) 136 | 137 | offs_v = ( 138 | (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs 139 | + cur_kv_head * stride_vh 140 | + offs_d[None, :] 141 | ) 142 | v = tl.load(V_Extend + offs_v, mask=mask_n[:, None], other=0.0) 143 | p = p.to(v.dtype) 144 | acc = acc * re_scale[:, None] + tl.dot(p, v) 145 | 146 | e_max = n_e_max 147 | 148 | offs_o = ( 149 | (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None]) 150 | * stride_obs 151 | + cur_head * stride_oh 152 | + offs_d[None, :] 153 | ) 154 | tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None]) 155 | 156 | 157 | def extend_attention_fwd( 158 | q_extend, 159 | k_extend, 160 | v_extend, 161 | o_extend, 162 | k_buffer, 163 | v_buffer, 164 | req_to_tokens, 165 | b_req_idx, 166 | b_start_loc, 167 | b_seq_len, 168 | b_seq_len_prefix, 169 | b_start_loc_extend, 170 | b_seq_len_extend, 171 | max_len_in_batch, 172 | max_len_extend, 173 | ): 174 | """ 175 | q_extend, k_extend, v_extend, o_extend: contiguous tensors 
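(shape [num_extend_tokens, num_heads, head_dim], with query heads for q/o and KV heads for k/v; see test() below)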
176 | 177 | k_buffer, v_buffer: (prefix + extend) tensors in mem_manager 178 | """ 179 | BLOCK_M, BLOCK_N = 128, 128 180 | Lq, Lk, Lv, Lo = ( 181 | q_extend.shape[-1], 182 | k_extend.shape[-1], 183 | v_extend.shape[-1], 184 | o_extend.shape[-1], 185 | ) 186 | assert Lq == Lk and Lk == Lv and Lv == Lo 187 | assert Lq in {16, 32, 64, 128} 188 | 189 | sm_scale = 1.0 / (Lq**0.5) 190 | batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1] 191 | kv_group_num = q_extend.shape[1] // k_extend.shape[1] 192 | 193 | grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M)) 194 | num_warps = 4 if Lk <= 64 else 8 195 | num_stages = 1 196 | 197 | # Launch kernel using modern Triton API 198 | kernel_launcher = wrap_kernel_launcher(_fwd_kernel) 199 | kernel_launcher( 200 | grid, 201 | num_warps, 202 | q_extend, 203 | k_extend, 204 | v_extend, 205 | o_extend, 206 | k_buffer, 207 | v_buffer, 208 | req_to_tokens, 209 | b_req_idx, 210 | b_seq_len, 211 | b_start_loc_extend, 212 | b_seq_len_extend, 213 | sm_scale, 214 | kv_group_num, 215 | q_extend.stride(0), 216 | q_extend.stride(1), 217 | k_extend.stride(0), 218 | k_extend.stride(1), 219 | v_extend.stride(0), 220 | v_extend.stride(1), 221 | o_extend.stride(0), 222 | o_extend.stride(1), 223 | k_buffer.stride(0), 224 | k_buffer.stride(1), 225 | v_buffer.stride(0), 226 | v_buffer.stride(1), 227 | req_to_tokens.stride(0), 228 | BLOCK_DMODEL=Lq, 229 | BLOCK_M=BLOCK_M, 230 | BLOCK_N=BLOCK_N, 231 | num_stages=num_stages, 232 | ) 233 | 234 | 235 | def redundant_attention( 236 | q_extend, 237 | k_extend, 238 | v_extend, 239 | o_extend, 240 | k_buffer, 241 | v_buffer, 242 | req_to_tokens, 243 | b_req_idx, 244 | b_start_loc, 245 | b_seq_len, 246 | b_seq_len_prefix, 247 | max_len_in_batch, 248 | ): 249 | total_token_num = k_buffer.shape[0] 250 | B, H_Q, D = b_req_idx.shape[0], q_extend.shape[-2], q_extend.shape[-1] 251 | q_buffer = torch.empty( 252 | (total_token_num, H_Q, D), dtype=q_extend.dtype, device=q_extend.device 253 | ) 254 | 255 | pt = 0 256 | for i in range(B): 257 | cur_seq_len_extend = b_seq_len[i] - b_seq_len_prefix[i] 258 | pl, pr = b_start_loc[i] + b_seq_len_prefix[i], b_start_loc[i] + b_seq_len[i] 259 | q_buffer[pl:pr] = q_extend[pt : pt + cur_seq_len_extend] 260 | pt += cur_seq_len_extend 261 | 262 | o_buffer = torch.empty_like(q_buffer) 263 | context_attention_fwd( 264 | q_buffer, k_buffer, v_buffer, o_buffer, b_start_loc, b_seq_len, max_len_in_batch 265 | ) 266 | 267 | pt = 0 268 | for i in range(B): 269 | cur_seq_len_extend = b_seq_len[i] - b_seq_len_prefix[i] 270 | pl, pr = b_start_loc[i] + b_seq_len_prefix[i], b_start_loc[i] + b_seq_len[i] 271 | o_extend[pt : pt + cur_seq_len_extend] = o_buffer[pl:pr] 272 | pt += cur_seq_len_extend 273 | 274 | 275 | def test(): 276 | torch.manual_seed(0) 277 | 278 | B, N_CTX, H_Q, H_KV, D = 19, 12331, 12, 4, 128 279 | dtype = torch.float16 280 | 281 | b_seq_len_prefix = torch.randint( 282 | 1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda" 283 | ) 284 | b_seq_len_extend = torch.randint( 285 | 1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda" 286 | ) 287 | b_seq_len = b_seq_len_prefix + b_seq_len_extend 288 | max_len_in_batch = torch.max(b_seq_len, 0)[0].item() 289 | 290 | b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda") 291 | req_to_tokens = torch.empty((B, max_len_in_batch), dtype=torch.int32, device="cuda") 292 | b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda") 293 | b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0) 294 | b_start_loc_extend = torch.zeros((B,), 
dtype=torch.int32, device="cuda") 295 | b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) 296 | for i in range(B): 297 | req_to_tokens[i, : b_seq_len[i]] = torch.arange( 298 | b_start_loc[i], b_start_loc[i] + b_seq_len[i] 299 | ) 300 | 301 | total_token_num = torch.sum(b_seq_len).item() 302 | extend_token_num = torch.sum(b_seq_len_extend).item() 303 | k_buffer = torch.empty( 304 | (total_token_num, H_KV, D), dtype=dtype, device="cuda" 305 | ).normal_(mean=0.1, std=0.2) 306 | v_buffer = torch.empty( 307 | (total_token_num, H_KV, D), dtype=dtype, device="cuda" 308 | ).normal_(mean=0.1, std=0.2) 309 | 310 | k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda") 311 | v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda") 312 | q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") 313 | for i in range(B): 314 | extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i] 315 | extend_end_in_buffer = b_start_loc[i] + b_seq_len[i] 316 | extend_start = b_start_loc_extend[i] 317 | extend_end = b_start_loc_extend[i] + b_seq_len_extend[i] 318 | k_extend[extend_start:extend_end] = k_buffer[ 319 | extend_start_in_buffer:extend_end_in_buffer 320 | ] 321 | v_extend[extend_start:extend_end] = v_buffer[ 322 | extend_start_in_buffer:extend_end_in_buffer 323 | ] 324 | q_extend[extend_start:extend_end] = torch.empty( 325 | (b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda" 326 | ).normal_(mean=0.1, std=0.2) 327 | 328 | o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") 329 | o_redundant = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") 330 | 331 | b_seq_len_extend = b_seq_len - b_seq_len_prefix 332 | b_start_loc_extend = torch.zeros_like(b_seq_len) 333 | b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) 334 | max_len_extend = torch.max(b_seq_len_extend, 0)[0].item() 335 | extend_attention_fwd( 336 | q_extend, 337 | k_extend, 338 | v_extend, 339 | o_extend, 340 | k_buffer, 341 | v_buffer, 342 | req_to_tokens, 343 | b_req_idx, 344 | b_start_loc, 345 | b_seq_len, 346 | b_seq_len_prefix, 347 | b_start_loc_extend, 348 | b_seq_len_extend, 349 | max_len_in_batch, 350 | max_len_extend, 351 | ) 352 | 353 | redundant_attention( 354 | q_extend, 355 | k_extend, 356 | v_extend, 357 | o_redundant, 358 | k_buffer, 359 | v_buffer, 360 | req_to_tokens, 361 | b_req_idx, 362 | b_start_loc, 363 | b_seq_len, 364 | b_seq_len_prefix, 365 | max_len_in_batch, 366 | ) 367 | 368 | print("Mean: ", torch.mean(torch.abs(o_extend - o_redundant))) 369 | print("Max: ", torch.max(torch.abs(o_extend - o_redundant))) 370 | 371 | assert torch.allclose(o_extend, o_redundant, rtol=1e-2) 372 | 373 | 374 | if __name__ == "__main__": 375 | test() 376 | -------------------------------------------------------------------------------- /python/sglang/srt/constrained/fsm.py: -------------------------------------------------------------------------------- 1 | # Adapted from: 2 | # https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/fsm/fsm.py 3 | from typing import List, NewType, Protocol 4 | 5 | import interegular 6 | from lark import Lark 7 | 8 | # from outlines.fsm.parsing import PartialLark 9 | from sglang.srt.constrained.regex import ( 10 | create_fsm_index_tokenizer, 11 | make_deterministic_fsm, 12 | ) 13 | from sglang.srt.constrained.tokenizer import Tokenizer 14 | 15 | FSMState = NewType("FSMState", int) 16 | 17 | 18 | class FSM(Protocol): 19 | def 
allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]: ... 20 | 21 | def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState: ... 22 | 23 | def is_final_state(self, state: FSMState, idx: int = 0) -> bool: ... 24 | 25 | def reset(self) -> None: ... 26 | 27 | 28 | class StopAtTokenFSM(FSM): 29 | """FSM to generate text until a specified token id is generated or 30 | a specified number of tokens has been generated. 31 | 32 | Text is usually produced until the EOS token is generated by the 33 | model. 34 | 35 | """ 36 | 37 | def __init__( 38 | self, 39 | tokenizer: "Tokenizer", 40 | stop_token_id: int, 41 | ): 42 | self.stop_token_id = stop_token_id 43 | self.num_tokens_generated = 0 44 | self.vocabulary = tokenizer.vocabulary.values() 45 | self.final_states = {1} 46 | 47 | def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]: 48 | """Generate a list of allowed tokens for the next step. 49 | 50 | When in the initial state we allow every token to be generated. 51 | In the final state the only allowed token is `stop_token_id`. 52 | 53 | Parameters 54 | ---------- 55 | state 56 | The current state of the FSM. 57 | idx 58 | The index of the current input in the batch. 59 | 60 | Returns 61 | ------- 62 | A list that contains the tokens to mask. 63 | 64 | """ 65 | if state == 0: 66 | return list(self.vocabulary) 67 | else: 68 | return [self.stop_token_id] 69 | 70 | def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState: 71 | """Update the state of the FSM. 72 | 73 | The FSM stays in the initial state `0` unless the specified stop token 74 | has been generated or the maximum number of tokens has been reached. In 75 | which case the FSM moves to the final state `1`. 76 | 77 | Parameters 78 | ---------- 79 | state 80 | The current state of the FSM. 81 | token_id 82 | The id of the token that was just generated. 83 | idx 84 | The index of the current input in the batch. 85 | 86 | Returns 87 | ------- 88 | The new state of the FSM. 89 | 90 | """ 91 | if idx == 0: 92 | self.num_tokens_generated += 1 93 | 94 | if token_id == self.stop_token_id: 95 | return FSMState(1) 96 | 97 | return FSMState(0) 98 | 99 | def is_final_state(self, state: FSMState, idx: int = 0) -> bool: 100 | """Determine whether the current state of the FSM is a final state.""" 101 | return state in self.final_states 102 | 103 | def reset(self) -> None: 104 | """Reset the FSM to its initial state. Here this only resets the token counter.""" 105 | self.num_tokens_generated = 0 106 | 107 | 108 | class RegexFSM(FSM): 109 | """FSM to generate text that is in the language of a regular expression.""" 110 | 111 | def __init__( 112 | self, 113 | regex_string: str, 114 | tokenizer: "Tokenizer", 115 | ): 116 | regex_pattern = interegular.parse_pattern(regex_string) 117 | regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) 118 | ( 119 | self.states_to_token_maps, 120 | self.empty_token_ids, 121 | ) = create_fsm_index_tokenizer(regex_fsm, tokenizer) 122 | 123 | # We make sure that it is possible to generate strings in the language 124 | # of the regular expression with the tokens present in the model's 125 | # vocabulary. 
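# Concretely, at least one token-level transition in the index must land in a final
# FSM state; otherwise no sequence of vocabulary tokens can spell out a matching string.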
126 | if not any( 127 | regex_fsm.finals.intersection(v.values()) 128 | for v in self.states_to_token_maps.values() 129 | ): 130 | raise ValueError( 131 | "The vocabulary does not allow us to build a sequence that matches the input regex" 132 | ) 133 | 134 | self.final_states = regex_fsm.finals | { 135 | -1 136 | } # Include the EOS token in final states 137 | self.num_tokens_generated = 0 138 | self.vocabulary = tokenizer.vocabulary.values() 139 | self.end_token_id = tokenizer.eos_token_id 140 | 141 | def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]: 142 | """Generate a list of allowed tokens for the next step. 143 | 144 | The initialization of the FSM builds an index which maps FSM states to a 145 | map from authorized tokens to the state in which the FSM needs to move 146 | if said token is generated. Therefore the authorized tokens at the 147 | current state are the keys of the map returned by the value of the index 148 | for current state. 149 | 150 | If the current state is not contained in the end this means that we are 151 | in a final state of the FSM. We only authorize EOS tokens in the final 152 | state. 153 | 154 | Parameters 155 | ---------- 156 | state 157 | The current state of the FSM. 158 | idx 159 | The index of the current input in the batch. 160 | 161 | Returns 162 | ------- 163 | A list that contains the tokens to mask. 164 | 165 | """ 166 | next_tokens_to_end_states = self.states_to_token_maps.get(state) 167 | 168 | if next_tokens_to_end_states is None: 169 | return [self.end_token_id] 170 | else: 171 | return list(next_tokens_to_end_states.keys()) 172 | 173 | def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState: 174 | """Update the state of the FSM. 175 | 176 | We use the index to determine to which state the FSM should transition 177 | given the token that was just generated. 178 | 179 | Parameters 180 | ---------- 181 | state 182 | The current state of the FSM. 183 | token_id 184 | The id of the token that was just generated. 185 | idx 186 | The index of the current input in the batch. 187 | 188 | Returns 189 | ------- 190 | The new state of the FSM. 191 | 192 | """ 193 | if idx == 0: 194 | self.num_tokens_generated += 1 195 | 196 | if token_id == self.end_token_id: 197 | return FSMState(-1) 198 | 199 | last_token_to_end_state = self.states_to_token_maps[state] 200 | next_state = last_token_to_end_state.get(token_id) 201 | if next_state is None: 202 | next_state = -1 203 | 204 | return FSMState(next_state) 205 | 206 | def is_final_state(self, state: FSMState, idx: int = 0) -> bool: 207 | """Determine whether the current state of the FSM is a final state.""" 208 | return state in self.final_states 209 | 210 | def reset(self) -> None: 211 | """Reset the FSM to its initial state. 
Here this only resets the token counter.""" 212 | self.num_tokens_generated = 0 213 | 214 | 215 | class CFGFSM(FSM): 216 | """FSM to generate text that is in the language of a context-free grammar.""" 217 | 218 | def __init__( 219 | self, 220 | cfg_string: str, 221 | tokenizer: "Tokenizer", 222 | ): 223 | # self.parser = PartialLark(cfg_string, parser="lalr") 224 | self.parser = Lark( 225 | cfg_string, 226 | parser="lalr", 227 | lexer="contextual", 228 | propagate_positions=False, 229 | maybe_placeholders=False, 230 | regex=True, 231 | ) 232 | self.terminal_regexps = dict() 233 | for terminal in self.parser.terminals: 234 | if terminal.pattern is not None: 235 | self.terminal_regexps[terminal.name] = terminal.pattern.to_regexp() 236 | self.terminal_regexps["$END"] = tokenizer.eos_token 237 | 238 | self.tokenizer = tokenizer 239 | self.num_tokens_generated = 0 240 | self.generations: List[str] = [] 241 | self.regex_fsms: List[RegexFSM] = [] 242 | self.reset_state: List[bool] = [] 243 | self.allow_eos: List[bool] = [] 244 | self.done: List[bool] = [] 245 | 246 | def _set_next_regex_fsm(self, idx: int = 0) -> None: 247 | """Use the CFG incremental parser to set the next regex FSM. 248 | 249 | Check what the CFG incremental parser proposes next. 250 | If the only proposal is the EOS token, 251 | we set the state to done and return. 252 | If there are other proposals, 253 | we set a new regex FSM and return. 254 | 255 | """ 256 | interactive = self.parser.parse_interactive(self.generations[idx]) 257 | interactive.exhaust_lexer() 258 | options = {self.terminal_regexps[x] for x in interactive.accepts()} 259 | 260 | if self.terminal_regexps["$END"] in options: 261 | options.remove(self.terminal_regexps["$END"]) 262 | if len(options) == 0: 263 | self.done[idx] = True 264 | return 265 | self.allow_eos[idx] = True 266 | options.add("") 267 | assert len(options) > 1 268 | 269 | regex_string = r"(" + r"|".join([r"(" + x + r")" for x in options]) + r")" 270 | args = ( 271 | regex_string, 272 | self.tokenizer, 273 | ) 274 | if len(self.regex_fsms) <= idx: 275 | self.regex_fsms.append(RegexFSM(*args)) 276 | else: 277 | self.regex_fsms[idx] = RegexFSM(*args) 278 | self.reset_state[idx] = True 279 | 280 | def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]: 281 | """Generate a list of allowed tokens for the next step. 282 | 283 | Upon initialization, the CFG incremental parser is used to determine the first regex. 284 | 285 | This regex is used for proposals until either: 286 | - the regex is exhausted, and its only remaining option is the EOS token, 287 | in which case we always transition to the next regex 288 | - the regex can be exhausted, but the EOS token is not the only remaining option, 289 | in which case we transition to the next regex with probability P (TODO) 290 | or remove the possibility of generating the EOS token and continue with the current regex 291 | 292 | The CFG incremental parser is allowed to propose the EOS token from any final state, 293 | and once it is generated, the FSM will continue to always generate the EOS token. 294 | 295 | Parameters 296 | ---------- 297 | state 298 | The current state of the FSM. 299 | idx 300 | The index of the current input in the batch. 301 | 302 | Returns 303 | ------- 304 | A list that contains the tokens to mask. 
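(These are the allowed token ids; the caller is expected to mask out everything else, as done in the batch sampling code.)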
305 | 306 | """ 307 | if len(self.generations) <= idx: 308 | self.generations.append("") 309 | self.reset_state.append(False) 310 | self.allow_eos.append(False) 311 | self.done.append(False) 312 | 313 | if len(self.regex_fsms) > idx: 314 | proposal = self.regex_fsms[idx].allowed_token_ids(state) 315 | if self.tokenizer.eos_token_id not in proposal: 316 | return proposal 317 | if set(proposal) != {self.tokenizer.eos_token_id}: 318 | if False: # TODO: THIS NEEDS TO BE SAMPLED 319 | proposal = [x for x in proposal if x != self.tokenizer.eos_token_id] 320 | return proposal 321 | 322 | self._set_next_regex_fsm(idx) 323 | 324 | if self.done[idx]: 325 | return [self.tokenizer.eos_token_id] 326 | 327 | if self.reset_state[idx]: 328 | state = FSMState(0) 329 | 330 | proposal = self.regex_fsms[idx].allowed_token_ids(state) 331 | if self.allow_eos[idx]: 332 | self.allow_eos[idx] = False 333 | else: 334 | proposal = [x for x in proposal if x != self.tokenizer.eos_token_id] 335 | assert len(proposal) > 0 336 | return proposal 337 | 338 | def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState: 339 | """Update the state of the FSM. 340 | 341 | Transitions the underlying regex FSM to its next state. 342 | If at max tokens or EOS token, transition permanently to the final state. 343 | Update stored partial generations for subsequent incremental parsing. 344 | 345 | Parameters 346 | ---------- 347 | state 348 | The current state of the FSM. 349 | token_id 350 | The id of the token that was just generated. 351 | idx 352 | The index of the current input in the batch. 353 | 354 | Returns 355 | ------- 356 | The new state of the FSM. 357 | """ 358 | if idx == 0: 359 | self.num_tokens_generated += 1 360 | if token_id == self.tokenizer.eos_token_id: 361 | self.done[idx] = True 362 | return FSMState(-1) 363 | if self.reset_state[idx]: 364 | self.reset_state[idx] = False 365 | state = FSMState(0) 366 | 367 | self.generations[idx] += self.tokenizer.decode([token_id])[0] 368 | 369 | return self.regex_fsms[idx].next_state(state, token_id, idx) 370 | 371 | def is_final_state(self, state: FSMState, idx: int = 0) -> bool: 372 | """Return whether the current state of the FSM is a final state.""" 373 | return self.done[idx] 374 | 375 | def reset(self) -> None: 376 | """Reset the FSM to its initial state, so it can be called on a fresh batch on inputs.""" 377 | self.num_tokens_generated = 0 378 | self.generations = [] 379 | self.regex_fsms = [] 380 | self.reset_state = [] 381 | self.done = [] 382 | --------------------------------------------------------------------------------