├── MANIFEST.in ├── tokasaurus ├── manager │ ├── __init__.py │ ├── hydragen.py │ └── input_building.py ├── server │ ├── __init__.py │ ├── types.py │ └── endpoints.py ├── __init__.py ├── scripts │ ├── download.py │ └── ping.py ├── core.py ├── model │ ├── qwen.py │ ├── kv_cache.py │ ├── safetensors_utils.py │ ├── entry.py │ ├── attention_utils.py │ ├── qwen3.py │ ├── basic_worker.py │ └── pipeline_worker.py ├── benchmarks │ ├── sharegpt.py │ ├── monkeys_chat.py │ ├── bench_summarization.py │ ├── monkeys_math500.py │ ├── monkeys_gsm8k.py │ ├── utils.py │ └── bench_model.py ├── entry.py └── common_types.py ├── pyrightconfig.json ├── pytest.ini ├── pyproject.toml ├── tests ├── test_bumping.py ├── test_block_allocator.py ├── test_logprobs.py ├── test_topk.py ├── test_basic.py └── test_scheduler.py ├── contributing └── models.md ├── .gitignore ├── CLAUDE.md ├── logs └── blog_commands.md ├── README.md └── LICENSE /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | -------------------------------------------------------------------------------- /tokasaurus/manager/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tokasaurus/server/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tokasaurus/__init__.py: -------------------------------------------------------------------------------- 1 | version = "0.0.4" 2 | -------------------------------------------------------------------------------- /pyrightconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "reportOptionalMemberAccess": "none", 3 | "reportPossiblyUnboundVariable": "none" 4 | } 5 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | filterwarnings = 3 | ignore:Module already imported so cannot be rewritten.*:pytest.PytestAssertRewriteWarning 4 | -------------------------------------------------------------------------------- /tokasaurus/scripts/download.py: -------------------------------------------------------------------------------- 1 | import huggingface_hub 2 | import pydra 3 | 4 | 5 | class ScriptConfig(pydra.Config): 6 | model: str 7 | 8 | def __init__(self): 9 | super().__init__() 10 | self.allow_patterns = ["*.safetensors", "*.json"] 11 | 12 | 13 | def download(config: ScriptConfig): 14 | print(f"Downloading {config.model}") 15 | cached_path = huggingface_hub.snapshot_download( 16 | config.model, allow_patterns=config.allow_patterns 17 | ) 18 | print(f"Download complete, stored at {cached_path}") 19 | 20 | 21 | def main(): 22 | pydra.run(download) 23 | 24 | 25 | if __name__ == "__main__": 26 | main() 27 | -------------------------------------------------------------------------------- /tokasaurus/core.py: -------------------------------------------------------------------------------- 1 | """ 2 | The most important part of LLM inference - having fun. 
3 | """ 4 | 5 | import random 6 | 7 | from art import text2art 8 | 9 | STARTUP_MESSAGES = [ 10 | "let's get\nready to\nrumble!", 11 | "GPU poor?\nnot today!", 12 | "locked and\nloaded!", 13 | "it's a bad day\nto be the\npower grid", 14 | "token\ntime!1!1!", 15 | "eat, sleep,\ninference,\nrepeat.", 16 | "the little\nLLM engine\nthat could", 17 | "attention all\nyou need?\nI gotchu", 18 | # generated by claude: 19 | "beep boop\nlet's compute", 20 | "caution:\nhot tensors", 21 | "sudo chmod\n777 fun.py", 22 | ] 23 | 24 | 25 | def complete_server_startup(): 26 | chosen = random.choice(STARTUP_MESSAGES) 27 | spaced = chosen.replace(" ", " ") 28 | print(text2art(spaced)) 29 | -------------------------------------------------------------------------------- /tokasaurus/model/qwen.py: -------------------------------------------------------------------------------- 1 | from transformers import Qwen2Config 2 | 3 | from tokasaurus.model.llama import ( 4 | LlamaAttention, 5 | LlamaBlock, 6 | LlamaForCausalLM, 7 | LlamaModel, 8 | ) 9 | 10 | 11 | class Qwen2Attention(LlamaAttention): 12 | qkv_bias: bool = True 13 | 14 | 15 | class Qwen2Block(LlamaBlock): 16 | attn_cls = Qwen2Attention 17 | 18 | 19 | class Qwen2Model(LlamaModel): 20 | block_cls = Qwen2Block 21 | 22 | 23 | class Qwen2ForCausalLM(LlamaForCausalLM): 24 | model_cls = Qwen2Model 25 | config_cls = Qwen2Config 26 | 27 | def make_tp_map(self): 28 | """ 29 | Need to add the qkv biases to the tp map. 30 | """ 31 | tp_map = super().make_tp_map() 32 | for param_name, _ in self.named_parameters(): 33 | if any( 34 | param_name.endswith(suffix) 35 | for suffix in [ 36 | "q_proj.bias", 37 | "k_proj.bias", 38 | "v_proj.bias", 39 | ] 40 | ): 41 | tp_map[param_name] = 0 42 | 43 | return tp_map 44 | -------------------------------------------------------------------------------- /tokasaurus/benchmarks/sharegpt.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | import pydra 4 | 5 | from tokasaurus.benchmarks.utils import ( 6 | BaseConfig, 7 | launch_server, 8 | prepend_conda_activate, 9 | ) 10 | 11 | 12 | class Config(BaseConfig): 13 | sharegpt_command: str 14 | sharegpt_env: str | None = None 15 | 16 | 17 | def main(config: Config): 18 | sharegpt_command = config.sharegpt_command 19 | 20 | if config.save_path is not None: 21 | config.save_path.parent.mkdir(parents=True, exist_ok=True) 22 | sharegpt_command = f"{sharegpt_command} --output-file {config.save_path}" 23 | 24 | if config.sharegpt_env is not None: 25 | sharegpt_command = prepend_conda_activate( 26 | sharegpt_command, config.conda_activate_path, config.sharegpt_env 27 | ) 28 | 29 | if (save_path := config.save_path) is not None: 30 | save_path.parent.mkdir(parents=True, exist_ok=True) 31 | launch_command_save_path = save_path.with_suffix(".launch.txt") 32 | launch_command_save_path.write_text(str(config.launch)) 33 | 34 | print(f"ShareGPT command: '{sharegpt_command}'") 35 | 36 | with launch_server(config): 37 | for _ in range(config.reps): 38 | subprocess.run(sharegpt_command, shell=True, executable="/bin/bash") 39 | 40 | 41 | if __name__ == "__main__": 42 | pydra.run(main) 43 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=42", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "tokasaurus" 7 | version = "0.0.4" 8 | 
description = "The little (LLM) engine that could!" 9 | readme = "README.md" 10 | requires-python = ">=3.10" 11 | license = { file = "LICENSE" } 12 | classifiers = [ 13 | "Development Status :: 3 - Alpha", 14 | "Intended Audience :: Developers", 15 | "Programming Language :: Python :: 3", 16 | "License :: OSI Approved :: Apache Software License", 17 | ] 18 | dependencies = [ 19 | "transformers==4.53.0", 20 | "pydra-config>=0.0.13", 21 | "accelerate", 22 | "art", 23 | "statsd", 24 | "fastapi", 25 | "ninja", 26 | "tabulate", 27 | "uvicorn", 28 | "typer", 29 | "openai", 30 | "loguru", 31 | "python-multipart", 32 | "torch==2.6.0", 33 | "flashinfer-python==0.2.0.post2", 34 | "tqdm", 35 | ] 36 | 37 | [project.optional-dependencies] 38 | dev = [ 39 | "pytest", 40 | "datasets", 41 | "pyright", 42 | "math-verify[antlr4_13_2]", 43 | "matplotlib", 44 | ] 45 | 46 | [project.scripts] 47 | tksrs = "tokasaurus.entry:main" 48 | toka = "tokasaurus.entry:main" 49 | tksrs-ping = "tokasaurus.scripts.ping:main" 50 | toka-ping = "tokasaurus.scripts.ping:main" 51 | toka-download = "tokasaurus.scripts.download:main" 52 | 53 | [tool.setuptools] 54 | include-package-data = true 55 | packages = {find = {}} -------------------------------------------------------------------------------- /tokasaurus/model/kv_cache.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor, nn 3 | 4 | from tokasaurus.model.types import DeviceType 5 | 6 | 7 | class LayerKVCache(nn.Module): 8 | k_cache: Tensor 9 | v_cache: Tensor | None 10 | 11 | def __init__( 12 | self, 13 | head_dim: int, 14 | num_kv_heads: int, 15 | num_pages: int, 16 | page_size: int, 17 | device: DeviceType | None = None, 18 | dtype: torch.dtype | None = None, 19 | ): 20 | super().__init__() 21 | self.num_pages = num_pages 22 | self.page_size = page_size 23 | 24 | self.num_key_value_heads = num_kv_heads 25 | self.head_dim = head_dim 26 | 27 | self.register_buffer( 28 | "k_cache", 29 | torch.zeros( 30 | ( 31 | num_pages, 32 | page_size, 33 | num_kv_heads, 34 | head_dim, 35 | ), 36 | device=device, 37 | dtype=dtype, 38 | ), 39 | persistent=False, 40 | ) 41 | self.register_buffer( 42 | "v_cache", 43 | torch.zeros( 44 | ( 45 | num_pages, 46 | page_size, 47 | num_kv_heads, 48 | head_dim, 49 | ), 50 | device=device, 51 | dtype=dtype, 52 | ), 53 | persistent=False, 54 | ) 55 | -------------------------------------------------------------------------------- /tokasaurus/scripts/ping.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import pydra 4 | from openai import OpenAI 5 | 6 | 7 | class ScriptConfig(pydra.Config): 8 | prompt: str 9 | 10 | model: str = "" 11 | port: int = 10210 12 | host: str = "0.0.0.0" 13 | chat: bool = False 14 | max_tokens: int = 100 15 | n: int = 1 16 | temperature: float = 0.0 17 | hide: bool = False 18 | retries: int = 0 19 | 20 | 21 | def ping(config: ScriptConfig): 22 | client = OpenAI( 23 | base_url=f"http://{config.host}:{config.port}/v1", 24 | api_key="fake-key", 25 | max_retries=config.retries, 26 | ) 27 | 28 | print("Making request...") 29 | start = time.time() 30 | if config.chat: 31 | out = client.chat.completions.create( 32 | model=config.model, 33 | messages=[{"role": "user", "content": config.prompt}], 34 | max_tokens=config.max_tokens, 35 | n=config.n, 36 | temperature=config.temperature, 37 | ) 38 | responses = [choice.message.content for choice in out.choices] 39 | else: 40 | out = 
client.completions.create( 41 | model=config.model, 42 | prompt=config.prompt, 43 | max_tokens=config.max_tokens, 44 | n=config.n, 45 | temperature=config.temperature, 46 | ) 47 | responses = [choice.text for choice in out.choices] 48 | 49 | end = time.time() 50 | print(f"Time taken: {end - start} seconds") 51 | 52 | if not config.hide: 53 | print("Responses:") 54 | print("-" * 100) 55 | for i, response in enumerate(responses): 56 | print(f"Response {i}: {response}") 57 | print("-" * 100) 58 | 59 | 60 | def main(): 61 | pydra.run(ping) 62 | 63 | 64 | if __name__ == "__main__": 65 | main() 66 | -------------------------------------------------------------------------------- /tests/test_bumping.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | import torch.multiprocessing as mp 5 | from openai import OpenAI 6 | 7 | from tokasaurus.common_types import ServerConfig 8 | from tokasaurus.entry import server_manager 9 | from tokasaurus.utils import find_free_port 10 | 11 | MODEL = os.environ.get("MODEL", "meta-llama/Llama-3.2-1B-Instruct") 12 | 13 | 14 | @pytest.fixture(scope="module") 15 | def client(): 16 | mp.set_start_method("spawn", force=True) 17 | 18 | port = find_free_port() 19 | config = ServerConfig() 20 | config.model = MODEL 21 | config.kv_cache_num_tokens = 16384 22 | config.max_num_tokens_per_request = 16384 23 | config.port = port 24 | config.page_size = 16 25 | config.track_early_stopping = True 26 | config.use_spec_allocation = True 27 | 28 | with server_manager(config): 29 | client = OpenAI( 30 | api_key="beepboop", base_url=f"http://localhost:{config.port}/v1" 31 | ) 32 | 33 | yield client 34 | 35 | 36 | def test_bumping(client: OpenAI): 37 | a_through_j = " A B C D E F G H I J" 38 | 39 | abc_prompt = (a_through_j * 10).strip() 40 | hundred_token_response = a_through_j * 10 41 | 42 | # first we make send enough sequences to 43 | # set the early stopping tracker 44 | response = client.completions.create( 45 | model="", prompt=abc_prompt, max_tokens=100, temperature=0.0, n=1024, stop=["C"] 46 | ) 47 | 48 | for c in response.choices: 49 | assert c.text == " A B " 50 | 51 | print("Done with first request - no bumping should have occurred yet") 52 | 53 | # now we don't use stop strings to cause bumping 54 | response2 = client.completions.create( 55 | model="", prompt=abc_prompt, max_tokens=100, temperature=0.0, n=1024 56 | ) 57 | 58 | for c in response2.choices: 59 | assert c.text == hundred_token_response 60 | -------------------------------------------------------------------------------- /tokasaurus/benchmarks/monkeys_chat.py: -------------------------------------------------------------------------------- 1 | import time 2 | from functools import partial 3 | 4 | import pydra 5 | 6 | from tokasaurus.benchmarks.monkeys_gsm8k import ScriptConfig as GSM8kConfig 7 | from tokasaurus.benchmarks.utils import ( 8 | launch_server, 9 | maybe_save_results, 10 | parallelize, 11 | shuffle_and_limit, 12 | sample_sharegpt_requests, 13 | ) 14 | 15 | 16 | class ScriptConfig(GSM8kConfig): 17 | def __init__(self): 18 | super().__init__() 19 | self.stop_strings = ["U:", "User:"] 20 | self.temperature = 0.6 21 | self.max_tokens = 8192 22 | self.n = 1 23 | self.limit = None 24 | 25 | 26 | def run_inference(item, config: ScriptConfig): 27 | # making the ordering of requests to the server more consistent with multiple workers 28 | if config.workers != 0: 29 | index = item["shuffled_index"] 30 | time.sleep(0.1 * index) 31 | 32 | 
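    # config.client() is assumed to construct an OpenAI client pointed at the server launched
    # in main() (see BaseConfig in tokasaurus/benchmarks/utils.py, not shown here); building it
    # per item gives each worker its own client.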
client = config.client() 33 | 34 | message = [ 35 | {"role": "system", "content": "You are a helpful assistant."}, 36 | {"role": "user", "content": item["conversations"]}, 37 | ] 38 | 39 | response = client.chat.completions.create( 40 | model=config.model, 41 | messages=message, 42 | max_tokens=config.max_tokens, 43 | temperature=config.temperature, 44 | top_p=config.top_p, 45 | stop=config.stop_strings, 46 | n=config.n, 47 | ) 48 | completions = [ 49 | response.choices[i].message.content for i in range(config.n) 50 | ] 51 | assert len(completions) == config.n 52 | return completions 53 | 54 | def main(config: ScriptConfig): 55 | raw_test_dataset = list( 56 | sample_sharegpt_requests() 57 | ) 58 | 59 | print(f"Number of test items: {len(raw_test_dataset)}") 60 | 61 | test_dataset = shuffle_and_limit(raw_test_dataset, config) 62 | 63 | print(f"Total number of items to process: {len(test_dataset)}") 64 | 65 | go_func = partial(run_inference, config=config) 66 | 67 | with launch_server(config): 68 | start = time.time() 69 | completions_list = parallelize( 70 | fn=go_func, 71 | items=test_dataset, 72 | num_workers=config.workers, 73 | allow_unordered=True, 74 | ) 75 | end = time.time() 76 | 77 | elapsed = end - start 78 | print(f"Time taken: {elapsed} seconds") 79 | 80 | 81 | maybe_save_results( 82 | config, 83 | { 84 | "elapsed": elapsed, 85 | "completions": completions_list, 86 | "launch": config.launch, 87 | }, 88 | ) 89 | 90 | 91 | if __name__ == "__main__": 92 | pydra.run(main) 93 | -------------------------------------------------------------------------------- /contributing/models.md: -------------------------------------------------------------------------------- 1 | # Adding a new model 2 | As a running example, we'll describe adding support for the **Qwen-3** family of models. 3 | The same high-level steps apply to any model that can be framed as a (possibly light) 4 | variant of the Llama / Qwen-2 architecture. 5 | 6 | **Note**: This document is as much for human developers as it is for coding agents :) 7 | --- 8 | 9 | ## 1. Implement the modelling file 10 | 11 | All model specific code lives under `tokasaurus/model`. For Qwen-3 we create 12 | `tokasaurus/model/qwen3.py`. 13 | 14 | You should use a [reference implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3/modular_qwen3.py) from the `transformers` library as your guide. 15 | 16 | Critical things to remember: 17 | * The model must be compatible with the Tokasaurus `BatchState` interface. 18 | * The model must use `tokasaurus_attention`, which calls to FlashInfer under the hood. 19 | 20 | --- 21 | 22 | ## 2. Register the model type 23 | 24 | Tokasaurus discovers models through a small mapping in 25 | `tokasaurus/model/utils.py`. Add an import and extend the dictionaries: 26 | 27 | ```diff 28 | # utils.py (near the top) 29 | -from tokasaurus.model.qwen import Qwen2ForCausalLM 30 | +from tokasaurus.model.qwen import Qwen2ForCausalLM 31 | +from tokasaurus.model.qwen3 import Qwen3ForCausalLM 32 | @@ 33 | -model_type = LlamaForCausalLM | Qwen2ForCausalLM 34 | +model_type = LlamaForCausalLM | Qwen2ForCausalLM | Qwen3ForCausalLM 35 | @@ 36 | "qwen2": Qwen2ForCausalLM, 37 | + "qwen3": Qwen3ForCausalLM, 38 | ``` 39 | 40 | The key (`"qwen3"`) must match the `model_type` field inside the Hugging Face 41 | `Qwen3Config` (you can verify via `AutoConfig.from_pretrained`). 42 | 43 | --- 44 | 45 | ## 3. (Optional) Extra features 46 | 47 | If the new architecture requires deeper changes (e.g. 
different position 48 | encoding or weight layout): 49 | 50 | * Add new subclasses for the relevant modules (MLP, embeddings, …) similar to 51 | the Attention example above. 52 | * Overwrite `make_name_to_hf_name` and/or `tp_modify_state_dict` in 53 | `LlamaForCausalLM` if the checkpoint key names differ. 54 | * Add any device-side kernels you need in `tokasaurus/model/attention_utils.py`. 55 | 56 | For purely additive features (e.g. support for **rope-scaling** parameters) you 57 | usually only need to read the attribute from the HF `Config` and forward it to 58 | `ExtraModelConfig`. 59 | 60 | --- 61 | 62 | ## 4. Tests 63 | 64 | You have succeeded when the following command passes **without GPU OOMs or 65 | assertion failures**: 66 | 67 | ```bash 68 | MODEL=Qwen/Qwen3-0.6B pytest tests/test_logprobs.py -k test_logprobs -s 69 | ``` 70 | 71 | Tips to debug failures: 72 | * Use the `--capture=no -s` flags for *verbose* test output. 73 | * Run the server directly via `python -m tokasaurus.entry --config …` and send 74 | a manual request with the OpenAI client. 75 | 76 | AGENT 77 | -------------------------------------------------------------------------------- /tokasaurus/benchmarks/bench_summarization.py: -------------------------------------------------------------------------------- 1 | import time 2 | from functools import partial 3 | 4 | import numpy as np 5 | import pydra 6 | from datasets import load_dataset 7 | 8 | from tokasaurus.benchmarks.utils import ( 9 | BaseConfig, 10 | launch_server, 11 | maybe_save_results, 12 | parallelize, 13 | shuffle_and_limit, 14 | ) 15 | 16 | 17 | class ScriptConfig(BaseConfig): 18 | def __init__(self): 19 | super().__init__() 20 | self.max_tokens = 1024 21 | self.temperature = 0.6 22 | 23 | 24 | def run_inference(item, config: ScriptConfig): 25 | # making the ordering of requests to the server more consistent with multiple workers 26 | if config.workers != 0: 27 | index = item["shuffled_index"] 28 | time.sleep(0.1 * index) 29 | 30 | client = config.client() 31 | prompt = f"Summarize this research paper in approximately 256 words:\n\n{item['article']}" 32 | 33 | response = client.chat.completions.create( 34 | model=config.model, 35 | messages=[{"role": "user", "content": prompt}], 36 | max_tokens=config.max_tokens, 37 | temperature=config.temperature, 38 | top_p=config.top_p, 39 | stop=config.stop_strings, 40 | n=config.n, 41 | ) 42 | 43 | completions = [choice.message.content for choice in response.choices] 44 | assert len(completions) == config.n 45 | 46 | # completion_ids only in system fingerprint for tokasaurus 47 | # completion_ids = ast.literal_eval(response.system_fingerprint)["completion_ids"] 48 | # response_lengths = [len(c) for c in completion_ids] 49 | 50 | response_lengths = [len(c) for c in completions] 51 | 52 | return response_lengths 53 | 54 | 55 | def main(config: ScriptConfig): 56 | raw_test_dataset = list( 57 | load_dataset("ccdv/arxiv-summarization", "section", split="test") 58 | ) 59 | 60 | print(f"Number of test items: {len(raw_test_dataset)}") 61 | 62 | test_dataset = shuffle_and_limit(raw_test_dataset, config) 63 | 64 | print(f"Total number of items to process: {len(test_dataset)}") 65 | 66 | go_func = partial(run_inference, config=config) 67 | 68 | with launch_server(config): 69 | start = time.time() 70 | responses = parallelize( 71 | fn=go_func, 72 | items=test_dataset, 73 | num_workers=config.workers, 74 | processes=True, 75 | allow_unordered=True, 76 | ) 77 | 78 | end = time.time() 79 | 80 | elapsed = end - start 81 | 
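    # NOTE: run_inference above returns len() of each completion *string* (the token-id path via
    # the system fingerprint is commented out), so the throughput reported below is measured in
    # characters rather than true tokens.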
print(f"Elapsed time: {elapsed:.2f} seconds") 82 | 83 | all_lengths = [ 84 | length for response_lengths in responses for length in response_lengths 85 | ] 86 | mean_length = float(np.mean(all_lengths)) if all_lengths else 0 87 | throughput = sum(all_lengths) / elapsed 88 | 89 | print(f"Mean response length: {mean_length:.2f}") 90 | print(f"Throughput: {throughput:.2f} toks/sec") 91 | 92 | maybe_save_results( 93 | config, 94 | { 95 | "elapsed": elapsed, 96 | "responses": responses, 97 | }, 98 | ) 99 | 100 | 101 | if __name__ == "__main__": 102 | pydra.run(main) 103 | -------------------------------------------------------------------------------- /tokasaurus/model/safetensors_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import List 4 | 5 | import huggingface_hub 6 | from safetensors import safe_open 7 | from tqdm import tqdm 8 | 9 | from tokasaurus.model.types import DeviceType 10 | 11 | 12 | def compute_shard_bounds( 13 | tensor_shape: List[int], dim: int, num_shards: int, shard_index: int 14 | ): 15 | dim_size = tensor_shape[dim] 16 | base_shard_size = dim_size // num_shards 17 | remainder = dim_size % num_shards 18 | 19 | start_idx = shard_index * base_shard_size + min(shard_index, remainder) 20 | 21 | if shard_index < remainder: 22 | end_idx = start_idx + base_shard_size + 1 23 | else: 24 | end_idx = start_idx + base_shard_size 25 | 26 | return slice(start_idx, end_idx) 27 | 28 | 29 | def can_load_from_safetensors(model_path: Path): 30 | files = [x.name for x in sorted(model_path.glob("*"))] 31 | return "model.safetensors.index.json" in files or "model.safetensors" in files 32 | 33 | 34 | def load_safetensors_repo( 35 | repo_path: Path, 36 | include_parameters: set[str], 37 | device: DeviceType, 38 | tp_rank: int = 0, 39 | tp_size: int = 1, 40 | tp_map: dict[str, int] | None = None, 41 | ): 42 | if tp_map is None: 43 | tp_map = {} 44 | 45 | single_file = repo_path / "model.safetensors" 46 | if single_file.exists(): 47 | files_to_load = [single_file] 48 | 49 | else: 50 | safetensors_index = repo_path / "model.safetensors.index.json" 51 | 52 | if not safetensors_index.exists(): 53 | raise FileNotFoundError( 54 | f"Could not find model.safetensors or model.safetensors.index.json in {repo_path}" 55 | ) 56 | 57 | with open(safetensors_index, "r") as f: 58 | index = json.load(f) 59 | 60 | param_to_path = index["weight_map"] 61 | 62 | files_to_load_set = set() 63 | 64 | for param_name, path in param_to_path.items(): 65 | if param_name in include_parameters: 66 | files_to_load_set.add(repo_path / path) 67 | 68 | files_to_load = list(sorted(files_to_load_set)) 69 | 70 | state_dict = {} 71 | 72 | for file in tqdm( 73 | files_to_load, 74 | desc="Loading safetensors files", 75 | ): 76 | with safe_open(file, framework="pt", device=device) as f: 77 | for k in f.keys(): 78 | if k in include_parameters: 79 | if tp_size > 1 and (split_dim := tp_map.get(k)) is not None: 80 | tensor_slice = f.get_slice(k) 81 | shard_bounds = compute_shard_bounds( 82 | tensor_slice.get_shape(), split_dim, tp_size, tp_rank 83 | ) 84 | # TODO: there's gotta be a better way to do this 85 | match split_dim: 86 | case 0: 87 | state_dict[k] = tensor_slice[shard_bounds] 88 | case 1: 89 | state_dict[k] = tensor_slice[:, shard_bounds] 90 | case _: 91 | raise ValueError( 92 | f"Unsupported split dimension: {split_dim}" 93 | ) 94 | else: 95 | state_dict[k] = f.get_tensor(k) 96 | 97 | return state_dict 98 | 
-------------------------------------------------------------------------------- /tokasaurus/benchmarks/monkeys_math500.py: -------------------------------------------------------------------------------- 1 | import time 2 | from functools import partial 3 | 4 | import pydra 5 | from datasets import load_dataset 6 | from math_verify import parse, verify 7 | from tabulate import tabulate 8 | 9 | from tokasaurus.benchmarks.monkeys_gsm8k import ScriptConfig as GSM8kConfig 10 | from tokasaurus.benchmarks.utils import ( 11 | launch_server, 12 | make_pass_at_k_table, 13 | maybe_save_results, 14 | parallelize, 15 | shuffle_and_limit, 16 | ) 17 | 18 | 19 | def is_correct_math500(completion, gt_answer): 20 | gold = parse(gt_answer) 21 | answer = parse(completion) 22 | 23 | return verify(gold, answer) 24 | 25 | 26 | class ScriptConfig(GSM8kConfig): 27 | def __init__(self): 28 | super().__init__() 29 | self.stop_strings = ["Q:", "Question:"] 30 | self.temperature = 0.6 31 | self.max_tokens = 8192 32 | self.n = 1 33 | self.limit = None 34 | 35 | 36 | def run_inference(item, config: ScriptConfig): 37 | # making the ordering of requests to the server more consistent with multiple workers 38 | if config.workers != 0: 39 | index = item["shuffled_index"] 40 | time.sleep(0.1 * index) 41 | 42 | client = config.client() 43 | 44 | # https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-8B#usage-recommendations 45 | math_directive = ( 46 | r"Please reason step-by-step, and put your final answer within \boxed{}." 47 | ) 48 | prompt = f"{math_directive} Question: {item['problem']}\nAnswer: \n" 49 | 50 | response = client.completions.create( 51 | model=config.model, 52 | prompt=prompt, 53 | max_tokens=config.max_tokens, 54 | temperature=config.temperature, 55 | top_p=config.top_p, 56 | stop=config.stop_strings, 57 | n=config.n, 58 | ) 59 | 60 | completions = [choice.text for choice in response.choices] 61 | assert len(completions) == config.n 62 | 63 | gt_answer = item["answer"] 64 | corrects = [] 65 | for completion in completions: 66 | try: 67 | score = is_correct_math500(completion, gt_answer) 68 | corrects.append(score) 69 | except Exception: 70 | score = 0 71 | corrects.append(score) 72 | return corrects 73 | 74 | 75 | def main(config: ScriptConfig): 76 | raw_test_dataset = list( 77 | load_dataset("HuggingFaceH4/MATH-500", "default", split="test") 78 | ) 79 | 80 | print(f"Number of test items: {len(raw_test_dataset)}") 81 | 82 | test_dataset = shuffle_and_limit(raw_test_dataset, config) 83 | 84 | print(f"Total number of items to process: {len(test_dataset)}") 85 | 86 | go_func = partial(run_inference, config=config) 87 | 88 | with launch_server(config): 89 | start = time.time() 90 | corrects_list = parallelize( 91 | fn=go_func, 92 | items=test_dataset, 93 | num_workers=config.workers, 94 | allow_unordered=True, 95 | ) 96 | end = time.time() 97 | 98 | elapsed = end - start 99 | print(f"Time taken: {elapsed} seconds") 100 | 101 | table = make_pass_at_k_table(corrects_list, config.ks) 102 | 103 | print(tabulate(table, headers=["k", "pass@k"], tablefmt="github")) 104 | 105 | maybe_save_results( 106 | config, 107 | { 108 | "elapsed": elapsed, 109 | "pass_at_k": table, 110 | "launch": config.launch, 111 | }, 112 | ) 113 | 114 | 115 | if __name__ == "__main__": 116 | pydra.run(main) 117 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | local/ 2 | .DS_Store 3 | 4 | # Byte-compiled 
/ optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 113 | .pdm.toml 114 | .pdm-python 115 | .pdm-build/ 116 | 117 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 165 | #.idea/ 166 | -------------------------------------------------------------------------------- /CLAUDE.md: -------------------------------------------------------------------------------- 1 | # CLAUDE.md 2 | 3 | This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. 4 | 5 | ## Project Overview 6 | 7 | Tokasaurus is a high-performance LLM inference engine designed for high-throughput workloads. It implements advanced features for efficient LLM serving including data/pipeline/tensor parallelism, paged KV caching, Hydragen optimization, and end-to-end torch compilation. 8 | 9 | ## Architecture 10 | 11 | The system uses a three-tier architecture: 12 | - **Web Server** (`tokasaurus/server/`): FastAPI-based HTTP server with OpenAI-compatible APIs 13 | - **Manager** (`tokasaurus/manager/`): CPU-side orchestration handling scheduling, KV cache management, and request batching 14 | - **Model Worker** (`tokasaurus/model/`): GPU-side execution of model forward passes 15 | 16 | Communication between components uses async queues with protobuf-like message passing. 17 | 18 | ## Development Commands 19 | 20 | ### Running the Server 21 | ```bash 22 | # Single GPU 23 | toka model=meta-llama/Llama-3.2-1B-Instruct 24 | 25 | # Multi-GPU with pipeline parallelism 26 | toka model=meta-llama/Llama-3.1-70B-Instruct pp_size=8 27 | 28 | # Custom configuration 29 | toka model= tp_size=2 max_seqs_per_forward=256 30 | ``` 31 | 32 | ### Testing 33 | ```bash 34 | # Run all tests 35 | pytest tests/ 36 | 37 | # Run specific test 38 | pytest tests/test_basic.py::test_basic 39 | 40 | # Test with specific model 41 | MODEL=meta-llama/Llama-3.2-1B-Instruct pytest tests/test_basic.py 42 | ``` 43 | 44 | ### Testing Server Connectivity 45 | ```bash 46 | # Basic ping test 47 | toka-ping prompt='Hello world' max_tokens=100 48 | 49 | # Chat mode 50 | toka-ping prompt='Hello world' max_tokens=100 chat=True 51 | ``` 52 | 53 | ## Configuration System 54 | 55 | Uses Pydra for configuration management. All configuration is in `tokasaurus/common_types.py` as the `EngineConfig` class. 
Pass config options as `key=value` pairs: 56 | 57 | ```bash 58 | toka model= kv_cache_num_tokens=100000 max_tokens_per_forward=8192 59 | ``` 60 | 61 | Key configuration groups: 62 | - Model loading: `model`, `tokenizer`, `trust_remote_code` 63 | - Parallelism: `dp_size`, `pp_size`, `tp_size` 64 | - Memory: `kv_cache_num_tokens`, `page_size`, `max_seqs_per_forward` 65 | - Performance: `torch_compile`, `use_cudagraphs`, `use_hydragen` 66 | 67 | ## Code Organization 68 | 69 | - Entry points are in `entry.py` (main), `scripts/ping.py`, `scripts/download.py` 70 | - Core types and config in `common_types.py` 71 | - Manager logic for scheduling and allocation in `manager/` 72 | - Model implementations in `model/` (supports Llama and Qwen) 73 | - HTTP endpoints in `server/endpoints.py` 74 | 75 | ## Key Implementation Details 76 | 77 | 1. **KV Cache Management**: Uses paged allocation with configurable page sizes. Implementation in `manager/allocator.py` and `model/kv_cache.py`. 78 | 79 | 2. **Scheduling**: Advanced scheduler in `manager/scheduler.py` that simulates future KV cache usage to make optimal batching decisions. 80 | 81 | 3. **Hydragen**: Shared prefix optimization in `manager/hydragen.py` that groups requests with common prefixes for efficient attention computation. 82 | 83 | 4. **Model Support**: Currently supports Llama and Qwen models. New models can be added in `model/` directory following the existing pattern. 84 | 85 | 5. **Compilation**: Supports torch.compile with dynamic shapes and CUDA graphs. Warmup phase automatically triggers recompiles. 86 | 87 | ## Common Development Tasks 88 | 89 | ### Adding a New Model 90 | 1. Create new file in `model/` directory 91 | 2. Implement the model class following `llama.py` or `qwen.py` pattern 92 | 3. Register in `model/__init__.py` 93 | 4. 
Update `get_model_cls()` in `model/__init__.py` 94 | 95 | ### Debugging 96 | - Set `log_level=DEBUG` for verbose logging 97 | - Use `statsd_addr` to enable metrics collection 98 | - Check process logs - each component (server, manager, model) logs separately 99 | 100 | ### Performance Tuning 101 | - Adjust `max_seqs_per_forward` and `max_tokens_per_forward` for memory/throughput tradeoff 102 | - Enable `use_cudagraphs` for lower latency (requires warmup) 103 | - Use `use_hydragen` for workloads with shared prefixes 104 | - Configure `torch_compile` options in `torch_compile_config` -------------------------------------------------------------------------------- /tokasaurus/entry.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | 3 | import pydra 4 | import torch.multiprocessing as mp 5 | 6 | from tokasaurus.common_types import Engine, ProcessInfo, ServerConfig, TimedBarrier 7 | from tokasaurus.manager.manager import start_manager 8 | from tokasaurus.model.entry import get_model_process_dict 9 | from tokasaurus.server.endpoints import start_server 10 | from tokasaurus.utils import find_free_port 11 | 12 | 13 | def cleanup_processes(processes: list[mp.Process]): 14 | for process in processes: 15 | process.kill() 16 | 17 | for process in processes: 18 | process.join() 19 | 20 | 21 | def make_engine(config: ServerConfig, dp_rank: int, master_port: int): 22 | q_server_to_manager = mp.Queue() 23 | q_manager_to_server = mp.Queue() 24 | 25 | q_manager_to_model = mp.Queue() 26 | q_model_to_manager = mp.Queue() 27 | 28 | # Start the model process 29 | process_dict = get_model_process_dict( 30 | config=config, 31 | q_manager_to_model=q_manager_to_model, 32 | q_model_to_manager=q_model_to_manager, 33 | dp_rank=dp_rank, 34 | master_port=master_port, 35 | ) 36 | process_dict["manager"] = ProcessInfo( 37 | target=start_manager, 38 | kwargs={ 39 | "config": config, 40 | "q_manager_to_model": q_manager_to_model, 41 | "q_model_to_manager": q_model_to_manager, 42 | "q_server_to_manager": q_server_to_manager, 43 | "q_manager_to_server": q_manager_to_server, 44 | }, 45 | ) 46 | 47 | if config.dp_size > 1: 48 | process_dict = {f"dp{dp_rank}_{k}": v for k, v in process_dict.items()} 49 | 50 | engine = Engine( 51 | q_server_to_manager=q_server_to_manager, 52 | q_manager_to_server=q_manager_to_server, 53 | proc_dict=process_dict, 54 | ) 55 | 56 | return engine 57 | 58 | 59 | def make_proc_dict(config: ServerConfig, add_extra_barrier_member: bool = False): 60 | master_port = find_free_port() 61 | engines = [ 62 | make_engine(config, dp_rank, master_port) for dp_rank in range(config.dp_size) 63 | ] 64 | 65 | pooled_proc_dict: dict[str, ProcessInfo] = {} 66 | for engine in engines: 67 | for proc_name, proc_info in engine.proc_dict.items(): 68 | assert proc_name not in pooled_proc_dict 69 | pooled_proc_dict[proc_name] = proc_info 70 | 71 | pooled_proc_dict["server"] = ProcessInfo( 72 | target=start_server, 73 | kwargs={ 74 | "config": config, 75 | "engines": engines, 76 | }, 77 | ) 78 | 79 | num_procs = len(pooled_proc_dict) 80 | barrier_size = num_procs + 1 if add_extra_barrier_member else num_procs 81 | barrier = TimedBarrier(barrier_size, "System startup time") 82 | 83 | for proc_name, proc_info in pooled_proc_dict.items(): 84 | proc_info.kwargs["barrier"] = barrier 85 | proc_info.kwargs["process_name"] = proc_name 86 | 87 | return pooled_proc_dict, barrier 88 | 89 | 90 | @contextmanager 91 | def server_manager(config: ServerConfig, 
finalize=True): 92 | mp.set_start_method("spawn", force=True) 93 | 94 | if finalize: 95 | config.finalize() 96 | 97 | process_dict, barrier = make_proc_dict(config, add_extra_barrier_member=True) 98 | 99 | processes = [] 100 | 101 | try: 102 | for _, process_info in process_dict.items(): 103 | p = process_info.make_process() 104 | p.start() 105 | processes.append(p) 106 | 107 | barrier.wait() 108 | 109 | yield 110 | finally: 111 | cleanup_processes(processes) 112 | 113 | 114 | def start(config: ServerConfig): 115 | mp.set_start_method("spawn", force=True) 116 | 117 | process_dict, _ = make_proc_dict(config, add_extra_barrier_member=False) 118 | 119 | print(f"Starting {len(process_dict)} processes: {list(process_dict.keys())}") 120 | print(f"Running in the main process: {config.local_proc_name}") 121 | 122 | processes = [] 123 | for proc_name, process_info in process_dict.items(): 124 | if proc_name == config.local_proc_name: 125 | continue 126 | p = process_info.make_process() 127 | p.start() 128 | processes.append(p) 129 | 130 | try: 131 | local_proc = process_dict[config.local_proc_name] 132 | local_proc.target(*local_proc.args, **local_proc.kwargs) 133 | except KeyboardInterrupt: 134 | print("KeyboardInterrupt detected. Terminating all subprocesses.") 135 | cleanup_processes(processes) 136 | print("... All subprocesses cleaned up.") 137 | 138 | 139 | def main(): 140 | """ 141 | For use as a setup.py entry point. 142 | """ 143 | pydra.run(start) 144 | 145 | 146 | if __name__ == "__main__": 147 | main() 148 | -------------------------------------------------------------------------------- /tokasaurus/model/entry.py: -------------------------------------------------------------------------------- 1 | import torch.multiprocessing as mp 2 | 3 | from tokasaurus.common_types import ProcessInfo, ServerConfig 4 | from tokasaurus.model.basic_worker import start_basic_model_worker, start_fanout_worker 5 | from tokasaurus.model.pipeline_worker import ( 6 | start_pipeline_worker, 7 | ) 8 | 9 | 10 | def get_pipeline_model_process_dict( 11 | config: ServerConfig, 12 | q_manager_to_model: mp.Queue, 13 | q_model_to_manager: mp.Queue, 14 | dp_rank: int, 15 | master_port: int, 16 | ): 17 | pp_size = config.pp_size 18 | tp_size = config.tp_size 19 | gpus_per_replica = pp_size * tp_size 20 | 21 | input_qs = [mp.Queue() for _ in range(gpus_per_replica)] 22 | qs_pipe_end_to_start = [mp.Queue() for _ in range(tp_size)] 23 | 24 | process_dict = {} 25 | 26 | for pp_rank in range(config.pp_size): 27 | for tp_rank in range(config.tp_size): 28 | input_q = input_qs[pp_rank * tp_size + tp_rank] 29 | q_pipe_end_to_start = qs_pipe_end_to_start[tp_rank] 30 | 31 | worker_pinfo = ProcessInfo( 32 | target=start_pipeline_worker, 33 | kwargs={ 34 | "config": config, 35 | "input_q": input_q, 36 | "q_pipe_end_to_start": q_pipe_end_to_start, 37 | "q_to_manager": q_model_to_manager, 38 | "pp_rank": pp_rank, 39 | "tp_rank": tp_rank, 40 | "dp_rank": dp_rank, 41 | "master_port": master_port, 42 | }, 43 | ) 44 | 45 | if pp_size > 1 and tp_size > 1: 46 | name = f"model_worker_pp{pp_rank}_tp{tp_rank}" 47 | elif pp_size > 1: 48 | name = f"model_worker_pp{pp_rank}" 49 | elif tp_size > 1: 50 | name = f"model_worker_tp{tp_rank}" 51 | else: 52 | raise ValueError("Shouldn't happen") 53 | 54 | process_dict[name] = worker_pinfo 55 | 56 | leader_process = ProcessInfo( 57 | target=start_fanout_worker, 58 | kwargs={ 59 | "config": config, 60 | "input_q": q_manager_to_model, 61 | "fanout_qs": input_qs, 62 | }, 63 | ) 64 | 65 | 
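    # The fanout worker sits between the manager and the per-rank workers: it reads messages from
    # q_manager_to_model and (as the name suggests) fans each one out to every queue in input_qs,
    # so all pp/tp ranks in this replica receive the same stream of commands.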
process_dict["fanout_worker"] = leader_process 66 | 67 | return process_dict 68 | 69 | 70 | def get_basic_model_process_dict( 71 | config: ServerConfig, 72 | q_manager_to_model: mp.Queue, 73 | q_model_to_manager: mp.Queue, 74 | dp_rank: int, 75 | master_port: int, 76 | ): 77 | process_info = ProcessInfo( 78 | target=start_basic_model_worker, 79 | kwargs={ 80 | "config": config, 81 | "input_q": q_manager_to_model, 82 | "q_model_to_manager": q_model_to_manager, 83 | "dp_rank": dp_rank, 84 | "tp_rank": 0, 85 | "master_port": master_port, 86 | }, 87 | ) 88 | 89 | return { 90 | "model_worker": process_info, 91 | } 92 | 93 | 94 | def get_tp_model_process_dict( 95 | config: ServerConfig, 96 | q_manager_to_model: mp.Queue, 97 | q_model_to_manager: mp.Queue, 98 | dp_rank: int, 99 | master_port: int, 100 | ): 101 | process_dict = {} 102 | 103 | input_qs = [mp.Queue() for _ in range(config.tp_size)] 104 | 105 | for tp_rank in range(config.tp_size): 106 | process_dict[f"model_worker_tp{tp_rank}"] = ProcessInfo( 107 | target=start_basic_model_worker, 108 | kwargs={ 109 | "config": config, 110 | "input_q": input_qs[tp_rank], 111 | "q_model_to_manager": q_model_to_manager, 112 | "dp_rank": dp_rank, 113 | "tp_rank": tp_rank, 114 | "master_port": master_port, 115 | }, 116 | ) 117 | 118 | process_dict["fanout_worker"] = ProcessInfo( 119 | target=start_fanout_worker, 120 | kwargs={ 121 | "config": config, 122 | "input_q": q_manager_to_model, 123 | "fanout_qs": input_qs, 124 | }, 125 | ) 126 | 127 | return process_dict 128 | 129 | 130 | def get_model_process_dict( 131 | config: ServerConfig, 132 | q_manager_to_model: mp.Queue, 133 | q_model_to_manager: mp.Queue, 134 | dp_rank: int, 135 | master_port: int, 136 | ): 137 | if config.pp_size > 1: 138 | return get_pipeline_model_process_dict( 139 | config=config, 140 | q_manager_to_model=q_manager_to_model, 141 | q_model_to_manager=q_model_to_manager, 142 | dp_rank=dp_rank, 143 | master_port=master_port, 144 | ) 145 | elif config.tp_size > 1: 146 | return get_tp_model_process_dict( 147 | config=config, 148 | q_manager_to_model=q_manager_to_model, 149 | q_model_to_manager=q_model_to_manager, 150 | dp_rank=dp_rank, 151 | master_port=master_port, 152 | ) 153 | else: 154 | return get_basic_model_process_dict( 155 | config=config, 156 | q_manager_to_model=q_manager_to_model, 157 | q_model_to_manager=q_model_to_manager, 158 | dp_rank=dp_rank, 159 | master_port=master_port, 160 | ) 161 | -------------------------------------------------------------------------------- /tests/test_block_allocator.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import pytest 4 | 5 | from tokasaurus.manager.allocator import BlockAllocator, NoSpaceException 6 | 7 | PAGE_SIZE = 4 8 | NUM_BLOCKS = 8 9 | 10 | 11 | @pytest.fixture 12 | def allocator(): 13 | return BlockAllocator(page_size=PAGE_SIZE, num_blocks=NUM_BLOCKS) 14 | 15 | 16 | def test_basic_allocation(allocator: BlockAllocator): 17 | allocator.sanity_checks() 18 | 19 | input_ids1 = [1] * 16 20 | input_ids2 = [2] * 9 21 | input_ids3 = [3] * 8 22 | 23 | allocator.sanity_checks() 24 | 25 | kvs1, num_cached1 = allocator.allocate_with_prefix_match("seq1", input_ids1) 26 | allocator.sanity_checks() 27 | 28 | assert num_cached1 == 0 29 | assert len(kvs1) == math.ceil(len(input_ids1) / PAGE_SIZE) 30 | 31 | kvs2, num_cached2 = allocator.allocate_with_prefix_match("seq2", input_ids2) 32 | allocator.sanity_checks() 33 | 34 | assert num_cached2 == 0 35 | assert len(kvs2) == 
math.ceil(len(input_ids2) / PAGE_SIZE) 36 | 37 | with pytest.raises(NoSpaceException): 38 | allocator.allocate_with_prefix_match("seq3", input_ids3) 39 | 40 | 41 | def test_basic_caching(allocator: BlockAllocator): 42 | allocator.sanity_checks() 43 | 44 | input_ids1 = [1] * 12 45 | input_ids2 = [1] * 8 46 | input_ids3 = [2] 47 | input_ids4 = [3] 48 | 49 | completion_ids1 = input_ids1 + [9] * 2 50 | completion_ids2 = input_ids2 + [9] * 3 51 | completion_ids3 = input_ids3 + [9] * 4 52 | 53 | kvs1, num_cached1 = allocator.allocate_with_prefix_match("seq1", input_ids1) 54 | allocator.sanity_checks() 55 | 56 | assert num_cached1 == 0 57 | 58 | kvs1.extend(allocator.allocate_up_to_length("seq1", kvs1, len(completion_ids1))) 59 | allocator.sanity_checks() 60 | 61 | kvs2, num_cached2 = allocator.allocate_with_prefix_match("seq2", input_ids2) 62 | allocator.sanity_checks() 63 | 64 | # 4, not 8, since last block isn't cached 65 | assert num_cached2 == 4 66 | assert kvs2[0] == kvs1[0] 67 | assert kvs2[1] != kvs1[1] 68 | 69 | kvs2.extend(allocator.allocate_up_to_length("seq2", kvs2, len(completion_ids2))) 70 | allocator.sanity_checks() 71 | 72 | kvs3, num_cached3 = allocator.allocate_with_prefix_match("seq3", input_ids3) 73 | allocator.sanity_checks() 74 | 75 | assert num_cached3 == 0 76 | 77 | kvs3.extend(allocator.allocate_up_to_length("seq3", kvs3, len(completion_ids3))) 78 | allocator.sanity_checks() 79 | 80 | with pytest.raises(NoSpaceException): 81 | allocator.allocate_with_prefix_match("seq4", input_ids4) 82 | 83 | 84 | def test_basic_free(allocator: BlockAllocator): 85 | allocator.sanity_checks() 86 | 87 | input_ids1 = [1] * 31 88 | input_ids2 = [2] * 31 89 | input_ids3 = [3] * 31 90 | 91 | completion_ids1 = input_ids1 + [1] 92 | completion_ids2 = input_ids2 + [2] 93 | 94 | kvs1, num_cached1 = allocator.allocate_with_prefix_match("seq1", input_ids1) 95 | allocator.sanity_checks() 96 | 97 | with pytest.raises(ValueError): 98 | allocator.allocate_with_prefix_match("seq2", input_ids2) 99 | 100 | kvs1.extend(allocator.allocate_up_to_length("seq1", kvs1, len(completion_ids1))) 101 | allocator.sanity_checks() 102 | allocator.free_and_update("seq1", kvs1, completion_ids1) 103 | allocator.sanity_checks() 104 | 105 | kvs2, num_cached2 = allocator.allocate_with_prefix_match("seq2", input_ids2) 106 | allocator.sanity_checks() 107 | 108 | assert num_cached2 == 0 109 | 110 | kvs2.extend(allocator.allocate_up_to_length("seq2", kvs2, len(completion_ids2))) 111 | allocator.sanity_checks() 112 | allocator.free_and_update("seq2", kvs2, completion_ids2) 113 | allocator.sanity_checks() 114 | 115 | kvs3, num_cached3 = allocator.allocate_with_prefix_match("seq3", input_ids3) 116 | allocator.sanity_checks() 117 | 118 | assert num_cached3 == 0 119 | 120 | 121 | def test_multi_branch_free(allocator: BlockAllocator): 122 | allocator.sanity_checks() 123 | 124 | input_ids1 = [1] * 15 125 | input_ids2 = [2] * 15 126 | input_ids3 = [1] * 8 + [3] * 4 127 | 128 | completion_ids1 = input_ids1 + [9] 129 | completion_ids2 = input_ids2 + [9] 130 | completion_ids3 = input_ids3 + [9] * 20 131 | 132 | kvs1, num_cached1 = allocator.allocate_with_prefix_match("seq1", input_ids1) 133 | allocator.sanity_checks() 134 | 135 | kvs1.extend(allocator.allocate_up_to_length("seq1", kvs1, len(completion_ids1))) 136 | allocator.free_and_update("seq1", kvs1, completion_ids1) 137 | 138 | kvs2, num_cached2 = allocator.allocate_with_prefix_match("seq2", input_ids2) 139 | allocator.sanity_checks() 140 | 141 | 
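    # seq2 ([2] * 15) shares no prefix with seq1, so nothing is cached for it here; the point of
    # this test is that seq1's freed pages of [1]s stay in the prefix cache and are picked up by
    # seq3 below (num_cached3 == 8).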
kvs2.extend(allocator.allocate_up_to_length("seq2", kvs2, len(completion_ids2))) 142 | allocator.sanity_checks() 143 | 144 | allocator.free_and_update("seq2", kvs2, completion_ids2) 145 | allocator.sanity_checks() 146 | 147 | kvs3, num_cached3 = allocator.allocate_with_prefix_match("seq3", input_ids3) 148 | allocator.sanity_checks() 149 | 150 | assert num_cached3 == 8 151 | 152 | kvs3.extend(allocator.allocate_up_to_length("seq3", kvs3, len(completion_ids3))) 153 | allocator.sanity_checks() 154 | 155 | allocator.free_and_update("seq3", kvs3, completion_ids3) 156 | allocator.sanity_checks() 157 | 158 | 159 | if __name__ == "__main__": 160 | pytest.main([__file__]) 161 | -------------------------------------------------------------------------------- /tokasaurus/model/attention_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from flashinfer import ( 3 | BatchDecodeWithPagedKVCacheWrapper, 4 | BatchPrefillWithPagedKVCacheWrapper, 5 | cascade, 6 | ) 7 | from torch import Tensor 8 | 9 | from tokasaurus.model.types import ( 10 | AttentionInfo, 11 | DeviceType, 12 | WrapperCollection, 13 | ) 14 | 15 | 16 | def create_workspace_buffer(device: DeviceType): 17 | # flashinfer recommends a 128MB buffer 18 | return torch.empty( 19 | 128 * 1024 * 1024, 20 | dtype=torch.uint8, 21 | device=device, 22 | ) 23 | 24 | 25 | def create_wrappers( 26 | device: DeviceType, 27 | num_attention_heads: int, 28 | num_key_value_heads: int, 29 | workspace_buffer: Tensor | None = None, 30 | ): 31 | if workspace_buffer is None: 32 | workspace_buffer = create_workspace_buffer(device) 33 | 34 | gqa_ratio = num_attention_heads // num_key_value_heads 35 | 36 | # NOTE: I think it's ok to reuse the buffers across both wrappers 37 | prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer) 38 | hydragen_wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer) 39 | decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( 40 | workspace_buffer, use_tensor_cores=gqa_ratio >= 4 41 | ) 42 | 43 | return WrapperCollection( 44 | prefill_wrapper=prefill_wrapper, 45 | hydragen_wrapper=hydragen_wrapper, 46 | decode_wrapper=decode_wrapper, 47 | ) 48 | 49 | 50 | def create_wrappers_for_cudagraph( 51 | device: DeviceType, 52 | num_attention_heads: int, 53 | num_key_value_heads: int, 54 | num_decode_sequences: int, 55 | max_kv_indices: int, 56 | workspace_buffer: Tensor | None = None, 57 | ): 58 | if workspace_buffer is None: 59 | workspace_buffer = create_workspace_buffer(device) 60 | 61 | gqa_ratio = num_attention_heads // num_key_value_heads 62 | 63 | decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( 64 | workspace_buffer, 65 | use_tensor_cores=gqa_ratio >= 4, 66 | use_cuda_graph=True, 67 | paged_kv_indptr_buffer=torch.empty( 68 | num_decode_sequences + 1, 69 | dtype=torch.int32, 70 | device=device, 71 | ), 72 | paged_kv_indices_buffer=torch.empty( 73 | max_kv_indices, dtype=torch.int32, device=device 74 | ), 75 | paged_kv_last_page_len_buffer=torch.empty( 76 | num_decode_sequences, dtype=torch.int32, device=device 77 | ), 78 | ) 79 | 80 | return WrapperCollection( 81 | prefill_wrapper=None, 82 | hydragen_wrapper=None, 83 | decode_wrapper=decode_wrapper, 84 | ) 85 | 86 | 87 | def append_to_kv_cache( 88 | token_indices: Tensor, 89 | key: Tensor, 90 | value: Tensor, 91 | k_cache: Tensor, 92 | v_cache: Tensor, 93 | ): 94 | """ 95 | Important to back out of torch compile for this op, since the compiler 96 | seemed to be making a copy of the cache, taking a lot of 
mem/time. 97 | """ 98 | 99 | _, num_key_value_heads, head_dim = key.shape 100 | 101 | flat_k_cache = k_cache.view(-1, num_key_value_heads, head_dim) 102 | flat_v_cache = v_cache.view(-1, num_key_value_heads, head_dim) 103 | 104 | flat_k_cache[token_indices] = key 105 | flat_v_cache[token_indices] = value 106 | 107 | 108 | def tokasaurus_attention( 109 | ragged_q: Tensor, 110 | ragged_k: Tensor, 111 | ragged_v: Tensor, 112 | k_cache: Tensor, 113 | v_cache: Tensor, 114 | attn_info: AttentionInfo, 115 | wrappers: WrapperCollection, 116 | ) -> Tensor: 117 | """ 118 | Assumes rope has been already applied. 119 | """ 120 | 121 | append_to_kv_cache( 122 | token_indices=attn_info.append_kv_token_indices, 123 | key=ragged_k, 124 | value=ragged_v, 125 | k_cache=k_cache, 126 | v_cache=v_cache, 127 | ) 128 | 129 | prefill_q, hydragen_q, decode_q = attn_info.split_q(ragged_q) 130 | 131 | # the key difference between the hydragen shared 132 | # prefix attention and normal prefill 133 | # is that hydragen does not have a causal mask 134 | if prefill_q.numel() > 0: 135 | prefill_wrapper = wrappers.prefill_wrapper 136 | assert prefill_wrapper is not None 137 | true_prefill_output = prefill_wrapper.run( 138 | q=prefill_q, paged_kv_cache=(k_cache, v_cache) 139 | ) 140 | else: 141 | true_prefill_output = prefill_q 142 | 143 | # decode 144 | if decode_q.numel() > 0: 145 | decode_wrapper = wrappers.decode_wrapper 146 | assert decode_wrapper is not None 147 | decode_output, decode_lse = decode_wrapper.run_return_lse( 148 | q=decode_q, paged_kv_cache=(k_cache, v_cache) 149 | ) 150 | else: 151 | decode_output = decode_q 152 | 153 | if hydragen_q.numel() > 0: 154 | hydragen_wrapper = wrappers.hydragen_wrapper 155 | assert hydragen_wrapper is not None 156 | shared_prefill_output, shared_prefill_lse = hydragen_wrapper.run_return_lse( 157 | q=hydragen_q, paged_kv_cache=(k_cache, v_cache) 158 | ) 159 | 160 | # Unique (decode) 161 | n_mixed = attn_info.hydragen_info.num_tokens 162 | unique_lse = decode_lse[:n_mixed] 163 | unique_out = decode_output[:n_mixed] 164 | 165 | aggregate, _ = cascade.merge_state( 166 | shared_prefill_output, shared_prefill_lse, unique_out, unique_lse 167 | ) 168 | 169 | true_decode_out = decode_output[n_mixed:] 170 | output = torch.cat([true_prefill_output, aggregate, true_decode_out], dim=0) 171 | 172 | else: 173 | output = torch.cat([true_prefill_output, decode_output], dim=0) 174 | 175 | return output 176 | -------------------------------------------------------------------------------- /tests/test_logprobs.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import os 3 | import shlex 4 | 5 | import pydra 6 | import pytest 7 | import torch 8 | import torch.multiprocessing as mp 9 | from openai import OpenAI 10 | from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer 11 | 12 | from tokasaurus.common_types import ServerConfig 13 | from tokasaurus.entry import server_manager 14 | from tokasaurus.utils import find_free_port 15 | 16 | MODEL = os.environ.get( 17 | "MODEL", 18 | "Qwen/Qwen2-0.5B-Instruct", 19 | ) 20 | OVERRIDES = os.environ.get("OVERRIDES", None) 21 | MEAN_REL_TOL_LIMIT = ast.literal_eval(os.environ.get("MEAN_REL_TOL_LIMIT", "0.1")) 22 | MAX_ABS_TOL_LIMIT = ast.literal_eval(os.environ.get("MAX_ABS_TOL_LIMIT", "0.2")) 23 | TOKEN_MATCH_LIMIT = ast.literal_eval(os.environ.get("TOKEN_MATCH_LIMIT", "0.95")) 24 | 25 | 26 | def make_basic_config(): 27 | print(f"Making basic config for {MODEL}...") 28 | config = 
ServerConfig() 29 | config.model = MODEL 30 | config.kv_cache_num_tokens = 16384 31 | config.max_num_tokens_per_request = 16384 32 | config.max_seqs_per_forward = 1024 33 | config.port = find_free_port() 34 | 35 | if OVERRIDES: 36 | # split apart like a shell, respecting quotes 37 | parsed_overrides = shlex.split(OVERRIDES) 38 | pydra.apply_overrides(config, parsed_overrides) 39 | 40 | return config 41 | 42 | 43 | @pytest.fixture(scope="module", params=[make_basic_config()]) 44 | def client(request): 45 | mp.set_start_method("spawn", force=True) 46 | 47 | config: ServerConfig = request.param 48 | print(f"Launching server with config: {config.to_dict()}") 49 | 50 | with server_manager(config): 51 | client = OpenAI( 52 | api_key="beepboop", base_url=f"http://localhost:{config.port}/v1" 53 | ) 54 | 55 | yield client 56 | 57 | 58 | @pytest.fixture(scope="module") 59 | def hf_model_and_tokenizer() -> tuple[torch.nn.Module, PreTrainedTokenizer]: 60 | print(f"Loading HF model and tokenizer ({MODEL})...") 61 | tokenizer = AutoTokenizer.from_pretrained(MODEL) 62 | model = AutoModelForCausalLM.from_pretrained( 63 | MODEL, torch_dtype=torch.bfloat16, device_map="cuda:0" 64 | ) 65 | model.eval() 66 | print("Loaded HF model and tokenizer.") 67 | return model, tokenizer 68 | 69 | 70 | PROMPTS = { 71 | "abc": "Please repeat the following pattern:" 72 | + "a b c d e f g h i j k l m n o p q r s a b c d e f g h i j k l m n o p q r s" 73 | * 10, 74 | "story": "Please tell me a long story about a cat.", 75 | } 76 | 77 | 78 | @pytest.mark.parametrize("prompt_name", list(PROMPTS.keys())) 79 | def test_logprobs( 80 | client: OpenAI, 81 | hf_model_and_tokenizer: tuple[torch.nn.Module, PreTrainedTokenizer], 82 | prompt_name: str, 83 | ): 84 | prompt = PROMPTS[prompt_name] 85 | response = client.chat.completions.create( 86 | model="none", 87 | messages=[ 88 | {"role": "user", "content": prompt}, 89 | ], 90 | max_tokens=64, 91 | temperature=0.0, 92 | logprobs=True, 93 | ) 94 | model, tokenizer = hf_model_and_tokenizer 95 | 96 | for idx, choice in enumerate(response.choices): 97 | api_tokens = [token_logprob.token for token_logprob in choice.logprobs.content] 98 | logprobs = [token_logprob.logprob for token_logprob in choice.logprobs.content] 99 | 100 | seq_ids = tokenizer.convert_tokens_to_ids(api_tokens) 101 | 102 | input_ids = ( 103 | tokenizer.apply_chat_template( 104 | [ 105 | {"role": "user", "content": prompt}, 106 | ], 107 | add_generation_prompt=True, 108 | ) 109 | + seq_ids 110 | ) 111 | with torch.inference_mode(): 112 | input_tensor = torch.tensor(input_ids).unsqueeze(0).to("cuda:0") 113 | outputs = model(input_tensor) 114 | 115 | logits = outputs.logits.to(torch.float32) # shape [1, seq_len, vocab_size] 116 | hf_logprobs = torch.nn.functional.log_softmax(logits, dim=-1) 117 | 118 | token_matches = [] 119 | logprob_adiffs = [] 120 | logprob_rdiffs = [] 121 | 122 | for idx, (api_token_id, hf_logprob_dist, api_logprob) in enumerate( 123 | zip(seq_ids, hf_logprobs[0, -len(seq_ids) - 1 : -1], logprobs) 124 | ): 125 | hf_logprob = hf_logprob_dist[api_token_id].item() 126 | hf_token_id = hf_logprob_dist.argmax().item() 127 | 128 | token_match = hf_token_id == api_token_id 129 | token_matches.append(token_match) 130 | 131 | adiff = abs(api_logprob - hf_logprob) 132 | rdiff = 2 * adiff / (abs(api_logprob) + abs(hf_logprob) + 1e-3) 133 | logprob_adiffs.append(adiff) 134 | logprob_rdiffs.append(rdiff) 135 | print( 136 | f"Pos {idx}: token match: {token_match}, logprob adiff: {adiff:.4f}, rdiff: {rdiff:.4f} (API: 
token={api_token_id} logprob={api_logprob:.4f}, HF: token={hf_token_id} logprob={hf_logprob:.4f})" 137 | ) 138 | 139 | token_match_rate = sum(token_matches) / len(token_matches) 140 | max_adiff = max(logprob_adiffs) 141 | mean_rdiff = sum(logprob_rdiffs) / len(logprob_rdiffs) 142 | 143 | print(f"Token match rate: {token_match_rate:.4f}") 144 | print(f"Max logprob absolute diff: {max_adiff:.4f}") 145 | print(f"Mean logprob relative diff: {mean_rdiff:.4f}") 146 | 147 | if TOKEN_MATCH_LIMIT is not None: 148 | assert token_match_rate >= TOKEN_MATCH_LIMIT, ( 149 | f"Token match rate: {token_match_rate} < {TOKEN_MATCH_LIMIT}" 150 | ) 151 | 152 | if MEAN_REL_TOL_LIMIT is not None: 153 | assert mean_rdiff <= MEAN_REL_TOL_LIMIT, ( 154 | f"Mean logprob relative diff: {mean_rdiff} > {MEAN_REL_TOL_LIMIT}" 155 | ) 156 | 157 | if MAX_ABS_TOL_LIMIT is not None: 158 | assert max_adiff <= MAX_ABS_TOL_LIMIT, ( 159 | f"Max logprob absolute diff: {max_adiff} > {MAX_ABS_TOL_LIMIT}" 160 | ) 161 | -------------------------------------------------------------------------------- /tokasaurus/model/qwen3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor, nn 3 | from transformers import Qwen3Config 4 | 5 | from tokasaurus.model.llama import ( 6 | LlamaAttention, 7 | LlamaBlock, 8 | LlamaForCausalLM, 9 | LlamaModel, 10 | apply_rotary_pos_emb, 11 | reduce_scatter, 12 | ) 13 | from tokasaurus.model.types import BatchState 14 | 15 | 16 | class Qwen3RMSNorm(nn.Module): 17 | """RMSNorm for head dimension (used in q_norm and k_norm)""" 18 | 19 | def __init__(self, head_dim: int, eps: float = 1e-6): 20 | super().__init__() 21 | self.weight = nn.Parameter(torch.ones(head_dim)) 22 | self.eps = eps 23 | 24 | def forward(self, hidden_states: Tensor): 25 | input_dtype = hidden_states.dtype 26 | hidden_states = hidden_states.to(torch.float32) 27 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 28 | hidden_states = hidden_states * torch.rsqrt(variance + self.eps) 29 | return self.weight * hidden_states.to(input_dtype) 30 | 31 | 32 | class Qwen3Attention(LlamaAttention): 33 | qkv_bias: bool = False # Qwen3 doesn't use bias in projection layers 34 | 35 | def __init__(self, config, layer_idx, extra_config): 36 | super().__init__(config, layer_idx, extra_config) 37 | 38 | # Add query and key normalization as in Qwen3 39 | self.q_norm = Qwen3RMSNorm(self.head_dim(), config.rms_norm_eps) 40 | self.k_norm = Qwen3RMSNorm(self.head_dim(), config.rms_norm_eps) 41 | 42 | def head_dim(self): 43 | # Qwen3 uses explicit head_dim in the config, instead of inferring it from hidden_size and num_attention_heads 44 | return self.config.head_dim 45 | 46 | def forward( 47 | self, 48 | batch_state: BatchState, 49 | ): 50 | assert batch_state.hidden_states is not None 51 | assert batch_state.position_embeddings is not None 52 | assert self.layer_cache is not None 53 | assert self.layer_cache.v_cache is not None 54 | 55 | inp = batch_state.hidden_states 56 | residual = inp 57 | 58 | hidden_states = self.input_layernorm(inp) 59 | 60 | from tokasaurus.model.llama import all_gather 61 | 62 | hidden_states = all_gather(hidden_states, self.extra_config) 63 | bsz = hidden_states.shape[0] 64 | head_dim = self.head_dim() 65 | 66 | # Project to query, key, value 67 | query_proj = self.q_proj(hidden_states) 68 | key_proj = self.k_proj(hidden_states) 69 | value_proj = self.v_proj(hidden_states) 70 | 71 | # In tokasaurus, hidden_states is already 2D: [total_tokens, hidden_size] 
72 | # We need to reshape to [total_tokens, num_heads, head_dim] for normalization 73 | # then back to [total_tokens, num_heads, head_dim] for attention 74 | 75 | # Match HuggingFace exactly: q_proj -> view -> q_norm 76 | # Note: In tokasaurus, bsz represents total tokens, not batch size 77 | # HF uses: query_proj.view(batch_size, seq_len, num_heads, head_dim) 78 | # We use: query_proj.view(total_tokens, 1, num_heads, head_dim) since each "token" is independent 79 | query_states = self.q_norm( 80 | query_proj.view(bsz, 1, self.num_attention_heads, head_dim) 81 | ) 82 | key_states = self.k_norm(key_proj.view(bsz, 1, self.num_kv_heads, head_dim)) 83 | 84 | # Flatten back to [total_tokens, num_heads, head_dim] for tokasaurus attention 85 | query_states = query_states.view(bsz, self.num_attention_heads, head_dim) 86 | key_states = key_states.view(bsz, self.num_kv_heads, head_dim) 87 | value_states = value_proj.view(bsz, self.num_kv_heads, head_dim) 88 | 89 | # Store original dtype 90 | dtype = query_states.dtype 91 | 92 | cos, sin = batch_state.position_embeddings 93 | 94 | query_states, key_states = apply_rotary_pos_emb( 95 | query_states, 96 | key_states, 97 | cos, 98 | sin, 99 | ) 100 | 101 | # Ensure dtype is preserved after rotary embedding 102 | query_states = query_states.to(dtype) 103 | key_states = key_states.to(dtype) 104 | 105 | raw_attn_output = self.attn_fn( 106 | query_states, 107 | key_states, 108 | value_states, 109 | self.layer_cache.k_cache, 110 | self.layer_cache.v_cache, 111 | ).clone() 112 | 113 | attn_output = raw_attn_output.view(bsz, self.num_attention_heads * head_dim) 114 | 115 | # NOTE: The purpose of running prefill tokens through the model is only 116 | # to populate the kv cache. After this last layer, we don't need to 117 | # do any more compute with these tokens. Technically, we could have 118 | # skipped the sdpa call for these too, but that would screw with the 119 | # paging information. 
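        # Illustration only (hypothetical numbers): if this final layer processes 6 tokens and
        # only the last 2 positions need logits, batch_state.lm_head_indices could be
        # tensor([4, 5]); the slicing below then shrinks attn_output and residual from
        # [6, hidden] to [2, hidden], so o_proj and the final norm / LM head run only on the
        # tokens whose outputs are actually sampled.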
120 | if ( 121 | self.layer_idx == self.config.num_hidden_layers - 1 122 | and self.extra_config.tp_size == 1 123 | ): 124 | attn_output = attn_output[batch_state.lm_head_indices] 125 | residual = residual[batch_state.lm_head_indices] 126 | 127 | o_proj = self.o_proj(attn_output) 128 | 129 | o_proj = reduce_scatter(o_proj, self.extra_config) 130 | 131 | with_residual = residual + o_proj 132 | 133 | batch_state.hidden_states = with_residual 134 | return batch_state 135 | 136 | 137 | class Qwen3Block(LlamaBlock): 138 | attn_cls = Qwen3Attention 139 | 140 | 141 | class Qwen3Model(LlamaModel): 142 | block_cls = Qwen3Block 143 | 144 | 145 | class Qwen3ForCausalLM(LlamaForCausalLM): 146 | model_cls = Qwen3Model 147 | config_cls = Qwen3Config 148 | 149 | def head_dim(self): 150 | return self.config.head_dim 151 | 152 | def make_name_to_hf_name(self): 153 | """Override to add q_norm and k_norm parameter mappings""" 154 | name_to_hf_name = super().make_name_to_hf_name() 155 | 156 | # Add mappings for q_norm and k_norm in each attention layer 157 | for layer_idx in range(self.config.num_hidden_layers): 158 | name_to_hf_name[f"model.layers.{layer_idx}.self_attn.q_norm.weight"] = ( 159 | f"model.layers.{layer_idx}.self_attn.q_norm.weight" 160 | ) 161 | name_to_hf_name[f"model.layers.{layer_idx}.self_attn.k_norm.weight"] = ( 162 | f"model.layers.{layer_idx}.self_attn.k_norm.weight" 163 | ) 164 | 165 | return name_to_hf_name 166 | -------------------------------------------------------------------------------- /tokasaurus/server/types.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import time 3 | from dataclasses import dataclass, field 4 | from datetime import datetime 5 | from typing import TYPE_CHECKING, Literal, Optional, Union 6 | 7 | from openai.types.chat import ChatCompletionMessageParam 8 | from openai.types.file_object import FileObject 9 | from pydantic import BaseModel, Field 10 | 11 | if TYPE_CHECKING: 12 | from tokasaurus.manager.types import SequenceOutput 13 | 14 | 15 | def nowstamp(): 16 | return int(datetime.now().timestamp()) 17 | 18 | 19 | class StreamOptions(BaseModel): 20 | include_usage: Optional[bool] = False 21 | 22 | 23 | class CompletionsRequest(BaseModel): 24 | # Ordered by official OpenAI API documentation 25 | # https://platform.openai.com/docs/api-reference/completions/create 26 | model: str 27 | prompt: Union[list[int], list[list[int]], str, list[str]] 28 | best_of: Optional[int] = None 29 | echo: Optional[bool] = False 30 | frequency_penalty: Optional[float] = 0.0 31 | logit_bias: Optional[dict[str, float]] = None 32 | logprobs: Optional[int] = None 33 | max_tokens: Optional[int] = 16 34 | n: int = 1 35 | presence_penalty: Optional[float] = 0.0 36 | seed: Optional[int] = None 37 | stop: Optional[Union[str, list[str]]] = Field(default_factory=list) 38 | stream: Optional[bool] = False 39 | stream_options: Optional[StreamOptions] = None 40 | suffix: Optional[str] = None 41 | temperature: Optional[float] = 1.0 42 | top_p: Optional[float] = 1.0 43 | user: Optional[str] = None 44 | metadata: Optional[dict] = None 45 | 46 | # pack the logprobs into the fingerprint in a more space-efficient way 47 | logprobs_in_fingerprint: bool = False 48 | 49 | # extra fields to get sglang benchmarking script to work 50 | ignore_eos: bool = False 51 | 52 | class Config: 53 | extra = "forbid" 54 | 55 | 56 | class JsonSchemaResponseFormat(BaseModel): 57 | name: str 58 | description: Optional[str] = None 59 | # use alias to workaround 
pydantic conflict 60 | schema_: Optional[dict[str, object]] = Field(alias="schema", default=None) 61 | strict: Optional[bool] = False 62 | 63 | 64 | class ResponseFormat(BaseModel): 65 | type: Literal["text", "json_object", "json_schema"] 66 | json_schema: Optional[JsonSchemaResponseFormat] = None 67 | 68 | 69 | class ChatCompletionRequest(BaseModel): 70 | # Ordered by official OpenAI API documentation 71 | # https://platform.openai.com/docs/api-reference/chat/create 72 | messages: list[ChatCompletionMessageParam] 73 | model: str 74 | frequency_penalty: Optional[float] = 0.0 75 | logit_bias: Optional[dict[str, float]] = None 76 | logprobs: Optional[bool] = False 77 | top_logprobs: Optional[int] = None 78 | max_tokens: Optional[int] = None 79 | max_completion_tokens: Optional[int] = None 80 | n: Optional[int] = 1 81 | presence_penalty: Optional[float] = 0.0 82 | response_format: Optional[ResponseFormat] = None 83 | seed: Optional[int] = None 84 | stop: Optional[Union[str, list[str]]] = Field(default_factory=list) 85 | stream: Optional[bool] = False 86 | stream_options: Optional[StreamOptions] = None 87 | temperature: Optional[float] = 0.7 88 | top_p: Optional[float] = 1.0 89 | user: Optional[str] = None 90 | metadata: Optional[dict] = None 91 | 92 | # extra fields --- 93 | 94 | # needed for sglang benchmarking script 95 | ignore_eos: bool = False 96 | 97 | # pack the logprobs into the fingerprint in a more space-efficient way 98 | logprobs_in_fingerprint: bool = False 99 | 100 | # extra chat template args, e.g. to pass enable_thinking for Qwen3 models: https://huggingface.co/Qwen/Qwen3-32B 101 | apply_chat_template_overrides: Optional[dict[str, object]] = None 102 | 103 | class Config: 104 | extra = "forbid" 105 | 106 | 107 | class BatchCreationRequest(BaseModel): 108 | """Request model for creating a batch""" 109 | 110 | input_file_id: str = Field( 111 | description="The ID of an uploaded file that contains requests for the new batch" 112 | ) 113 | endpoint: str = Field( 114 | description="The endpoint to be used for all requests in the batch" 115 | ) 116 | completion_window: str = Field( 117 | description="The time frame within which the batch should be processed" 118 | ) 119 | metadata: Optional[dict[str, str]] = Field(default=None) 120 | 121 | 122 | class SynchronousBatchCompletionsRequest(BaseModel): 123 | """Request model for synchronous batch completions""" 124 | 125 | requests: list[ChatCompletionRequest] = Field( 126 | description="List of chat completion requests to process" 127 | ) 128 | 129 | 130 | @dataclass 131 | class RequestOutput: 132 | id: str 133 | sequence_outputs: list["SequenceOutput"] = field(default_factory=list) 134 | 135 | 136 | @dataclass 137 | class SamplingParams: 138 | temperature: float 139 | top_p: float 140 | 141 | 142 | @dataclass 143 | class TokasaurusRequest: 144 | id: str 145 | input_ids: list[int] 146 | max_num_tokens: int 147 | sampling_params: SamplingParams 148 | stop: list[str] 149 | n: int 150 | ignore_eos: bool 151 | topk_logprobs: int | None = None # Number of top tokens to return log probs for 152 | created_timestamp: float = field(default_factory=time.time) 153 | 154 | 155 | @dataclass 156 | class SubmittedRequest: 157 | request: TokasaurusRequest 158 | engine_index: int 159 | 160 | event: asyncio.Event = field(default_factory=asyncio.Event) 161 | request_output: RequestOutput | None = None 162 | 163 | 164 | class BatchFileLine(BaseModel): 165 | custom_id: str 166 | method: Literal["POST"] 167 | url: Literal["/v1/completions", 
"/v1/chat/completions"] 168 | body: dict 169 | 170 | 171 | @dataclass 172 | class FileEntry: 173 | content: bytes 174 | details: FileObject 175 | 176 | 177 | @dataclass 178 | class SubmittedBatchItem: 179 | line: BatchFileLine 180 | user_req: CompletionsRequest | ChatCompletionRequest 181 | submitted_req: SubmittedRequest 182 | 183 | 184 | @dataclass 185 | class SubmittedBatch: 186 | id: str 187 | creation_request: BatchCreationRequest 188 | items: list[SubmittedBatchItem] 189 | task: asyncio.Task 190 | created_at: int = field(default_factory=nowstamp) 191 | output_file: FileEntry | None = None 192 | 193 | 194 | @dataclass 195 | class RequestError: 196 | error: str 197 | 198 | 199 | @dataclass 200 | class CancelledRequest: 201 | req_id: str 202 | 203 | 204 | CommandsFromServer = TokasaurusRequest | CancelledRequest 205 | -------------------------------------------------------------------------------- /tokasaurus/common_types.py: -------------------------------------------------------------------------------- 1 | import time 2 | from dataclasses import dataclass, field 3 | from typing import Callable 4 | 5 | import pydra 6 | import torch.multiprocessing as mp 7 | from transformers import AutoConfig, GenerationConfig 8 | 9 | from tokasaurus.core import complete_server_startup 10 | 11 | 12 | class TimedBarrier: 13 | def __init__(self, num_procs: int, message: str): 14 | self.barrier = mp.Barrier(num_procs) 15 | self.message = message 16 | self.start_time = time.time() 17 | 18 | def wait(self): 19 | remaining = self.barrier.wait() 20 | end = time.time() 21 | if remaining == 0: 22 | print(f"{self.message}: {end - self.start_time}") 23 | complete_server_startup() 24 | 25 | 26 | @dataclass 27 | class ProcessInfo: 28 | target: Callable 29 | args: tuple = () 30 | kwargs: dict = field(default_factory=dict) 31 | 32 | def make_process(self): 33 | return mp.Process(target=self.target, args=self.args, kwargs=self.kwargs) 34 | 35 | 36 | @dataclass 37 | class Engine: 38 | """ 39 | Wraps the queues to interact with a manager 40 | and one or more model processes. 
41 | """ 42 | 43 | q_server_to_manager: mp.Queue 44 | q_manager_to_server: mp.Queue 45 | 46 | proc_dict: dict[str, ProcessInfo] 47 | 48 | def num_procs(self): 49 | return len(self.proc_dict) 50 | 51 | 52 | class ServerConfig(pydra.Config): 53 | model: str 54 | tokenizer: str | None = None 55 | 56 | trust_remote_code: bool = False 57 | dtype: str = "bfloat16" 58 | rope_scaling: str | None = None 59 | 60 | use_hydragen: bool = False 61 | hydragen_min_group_size: int = 32 62 | hydragen_min_prefix_len: int = 256 63 | 64 | enable_chosen_logprobs: bool = True 65 | max_topk_logprobs: int | None = None 66 | 67 | port: int = 10210 68 | local_proc_name: str = "server" 69 | 70 | log_level: str = "INFO" 71 | log_procs: list[str] | None = None 72 | uvicorn_log_level: str = "info" 73 | 74 | stats_report_seconds: float = 5.0 75 | statsd_server_url: None | str = None 76 | 77 | page_size: int = 16 78 | kv_cache_num_tokens: int = 1024 * 128 79 | 80 | torch_compile: bool = False 81 | 82 | # the batch size at which we switch to using async TP 83 | async_tp_threshold: int | None = None 84 | 85 | max_tokens_per_forward: int = 8192 86 | max_seqs_per_forward: int = 1024 87 | prefill_round_up_multiple: int = 16 88 | 89 | scheduling_steps_ahead: int = 8 90 | stop_string_num_token_lookback: int = 5 91 | 92 | dp_size: int = 1 93 | pp_size: int = 1 94 | tp_size: int = 1 95 | 96 | # adding extra stages to hide the latency 97 | # of sending lm-head results from the end of the pipeline to the start, 98 | # as well as buffer data dependencies from sequences being rearranged 99 | # across microbatches (e.g. as sequences finish / new sequences start). 100 | pp_num_buffer_stages: int = 1 101 | 102 | track_early_stopping: bool = True 103 | early_stopping_buffer_size: int = 2048 104 | early_stopping_num_prediction_buckets: int = 1024 105 | early_stopping_initial_wait: int = 16 106 | early_stopping_init_mean: float | None = None 107 | early_stopping_init_std: float | None = None 108 | max_num_tokens_per_request: int | None = None 109 | 110 | enable_precise_onboard: bool = True 111 | precise_onboard_batch_size: int = 128 112 | greedy_prefill: bool = True 113 | 114 | use_spec_allocation: bool = True 115 | spec_allocation_std_buffer_scale: float = 0.25 116 | spec_allocation_target_kv_cache_utilization: float = 1.0 117 | 118 | use_cudagraphs: bool = True 119 | cudagraph_max_size: int = 128 120 | cudagraph_step: int = 16 121 | cudagraph_max_kv_indices_per_seq: int = 32768 122 | 123 | # for debugging only, will slow things down 124 | allocator_sanity_checks: bool = False 125 | bump_city_population_me: bool = False 126 | 127 | def uvsh(self): 128 | self.uvicorn_log_level = "warning" 129 | 130 | def kv_cache_num_blocks(self): 131 | assert self.kv_cache_num_tokens % self.page_size == 0 132 | return self.kv_cache_num_tokens // self.page_size 133 | 134 | def max_batch_index(self): 135 | # fudge factor on the total number of sequences running at any time 136 | return self.max_tokens_per_forward * 2 137 | 138 | def model_config(self): 139 | return AutoConfig.from_pretrained( 140 | self.model, trust_remote_code=self.trust_remote_code 141 | ) 142 | 143 | def generation_config(self): 144 | return GenerationConfig.from_pretrained( 145 | self.model, trust_remote_code=self.trust_remote_code 146 | ) 147 | 148 | def finalize(self): 149 | super().finalize() 150 | 151 | if self.use_spec_allocation: 152 | assert self.track_early_stopping, ( 153 | "use_spec_allocation requires track_early_stopping" 154 | ) 155 | assert 
self.spec_allocation_std_buffer_scale >= 0, ( 156 | "spec_allocation_std_buffer_scale must be non-negative" 157 | ) 158 | 159 | if self.tokenizer is None: 160 | self.tokenizer = self.model 161 | 162 | if self.max_num_tokens_per_request is None: 163 | model_config = self.model_config() 164 | self.max_num_tokens_per_request = min( 165 | model_config.max_position_embeddings, self.kv_cache_num_tokens 166 | ) 167 | print( 168 | f"Setting max_num_tokens_per_request to {self.max_num_tokens_per_request}" 169 | ) 170 | 171 | if self.use_hydragen and self.use_cudagraphs: 172 | assert self.cudagraph_max_size < self.hydragen_min_group_size, ( 173 | f"For now hydragen_min_group_size ({self.hydragen_min_group_size}) must exceed cudagraph_max_size ({self.cudagraph_max_size})" 174 | ) 175 | 176 | # for debugging different parts of the system 177 | def dmanager(self): 178 | self.local_proc_name = "manager" 179 | 180 | def dmodel(self): 181 | self.local_proc_name = "model_worker" 182 | 183 | def par(self, dp=1, pp=1, tp=1): 184 | self.dp_size = dp 185 | self.pp_size = pp 186 | self.tp_size = tp 187 | 188 | def scheduler_block_target(self): 189 | target_blocks = self.kv_cache_num_blocks() 190 | if self.use_spec_allocation: 191 | target_blocks = round( 192 | target_blocks * self.spec_allocation_target_kv_cache_utilization 193 | ) 194 | return target_blocks 195 | -------------------------------------------------------------------------------- /tokasaurus/benchmarks/monkeys_gsm8k.py: -------------------------------------------------------------------------------- 1 | import random 2 | import re 3 | import time 4 | from functools import partial 5 | 6 | import pydra 7 | from datasets import load_dataset 8 | from tabulate import tabulate 9 | from tokenizers import Tokenizer 10 | from tqdm import tqdm 11 | from transformers import AutoTokenizer 12 | 13 | from tokasaurus.benchmarks.utils import ( 14 | BaseConfig, 15 | launch_server, 16 | make_pass_at_k_table, 17 | maybe_save_results, 18 | parallelize, 19 | shuffle_and_limit, 20 | ) 21 | 22 | 23 | class ScriptConfig(BaseConfig): 24 | def __init__(self): 25 | super().__init__() 26 | self.n = 512 27 | self.limit = 128 28 | self.max_tokens = 1024 29 | self.temperature = 0.6 30 | self.num_few_shot = 4 31 | self.stop_strings = ["Question:"] 32 | 33 | def finalize(self): 34 | self.ks = list(range(1, min(11, self.n + 1))) 35 | cur = 100 36 | while True: 37 | self.ks.append(cur) 38 | cur *= 10 39 | if cur > self.n: 40 | break 41 | 42 | if self.n not in self.ks: 43 | self.ks.append(self.n) 44 | 45 | 46 | ANS_RE_GSM8k = re.compile(r"#### (\-?[\$0-9\.\,]+)") 47 | INVALID_ANS_GSM8k = "[invalid]" 48 | GSM8K_IGNORE_REGEXES = [",", "\\$", "\\.$"] 49 | 50 | 51 | def filter_ignores(st, regexes_to_ignore): 52 | if regexes_to_ignore is not None: 53 | for s in regexes_to_ignore: 54 | st = re.sub(s, "", st) 55 | return st 56 | 57 | 58 | def extract_answer_gsm8k(completion): 59 | match = ANS_RE_GSM8k.search(completion) 60 | if match: 61 | match_str = match.group(1).strip() 62 | match_str = filter_ignores( 63 | match_str, 64 | GSM8K_IGNORE_REGEXES, 65 | ) 66 | return match_str 67 | else: 68 | return INVALID_ANS_GSM8k 69 | 70 | 71 | def is_correct_gsm8k(model_completion, gt_example): 72 | gt_answer = extract_answer_gsm8k(gt_example) 73 | assert gt_answer != INVALID_ANS_GSM8k 74 | return extract_answer_gsm8k(model_completion) == gt_answer 75 | 76 | 77 | def get_few_shot_prompt(item): 78 | few_shot_items = item["few_shot_items"] 79 | 80 | few_shot_pieces = [] 81 | for f in few_shot_items: 82 
| # https://github.com/EleutherAI/lm-evaluation-harness/blob/568af943e315100af3f00937bfd6947844769ab8/lm_eval/tasks/gsm8k/gsm8k.yaml 83 | few_shot_prompt = f"Question: {f['question']}\nAnswer: {f['answer']}\n\n" 84 | few_shot_pieces.append(few_shot_prompt) 85 | 86 | few_shot_prompt = "".join(few_shot_pieces) 87 | 88 | return few_shot_prompt 89 | 90 | 91 | def run_inference(item, config: ScriptConfig): 92 | # making the ordering of requests to the server more consistent with multiple workers 93 | if config.workers != 0: 94 | index = item["shuffled_index"] 95 | time.sleep(0.1 * index) 96 | 97 | client = config.client() 98 | few_shot_prompt = get_few_shot_prompt(item) 99 | prompt = few_shot_prompt + f"Question: {item['question']}\nAnswer:" 100 | 101 | response = client.completions.create( 102 | model=config.model, 103 | prompt=prompt, 104 | max_tokens=config.max_tokens, 105 | temperature=config.temperature, 106 | top_p=config.top_p, 107 | stop=config.stop_strings, 108 | n=config.n, 109 | logprobs=None, 110 | ) 111 | 112 | completions = [choice.text for choice in response.choices] 113 | assert len(completions) == config.n 114 | 115 | gt_answer = item["answer"] 116 | corrects = [is_correct_gsm8k(completion, gt_answer) for completion in completions] 117 | 118 | result = { 119 | "prompt": prompt, 120 | "completions": completions, 121 | "corrects": corrects, 122 | } 123 | 124 | return result 125 | 126 | 127 | def run_eval(config: ScriptConfig, go_func, test_dataset: list[dict]): 128 | start = time.time() 129 | results_list = parallelize( 130 | go_func, 131 | test_dataset, 132 | num_workers=config.workers, 133 | processes=True, 134 | allow_unordered=True, 135 | ) 136 | end = time.time() 137 | 138 | elapsed = end - start 139 | print(f"Elapsed time: {elapsed} seconds") 140 | 141 | corrects_list = [result["corrects"] for result in results_list] 142 | table = make_pass_at_k_table(corrects_list, config.ks) 143 | 144 | print(tabulate(table, headers=["k", "pass@k"], tablefmt="github")) 145 | 146 | tokenizer = AutoTokenizer.from_pretrained(config.model) 147 | 148 | total_input_tokens = sum( 149 | len(tokenizer.encode(result["prompt"])) for result in results_list 150 | ) 151 | inner_tokenizer: Tokenizer = tokenizer._tokenizer 152 | encoded_outputs = [ 153 | inner_tokenizer.encode_batch(result["completions"]) for result in results_list 154 | ] 155 | output_tokens_per_item = [ 156 | sum(len(output) for output in outputs) for outputs in encoded_outputs 157 | ] 158 | total_output_tokens = sum(output_tokens_per_item) 159 | 160 | print(f"Total input tokens: {total_input_tokens}") 161 | print(f"Total output tokens: {total_output_tokens}") 162 | 163 | throughput = total_output_tokens / elapsed 164 | print(f"Throughput: {throughput:.2f} tokens/second") 165 | 166 | maybe_save_results( 167 | config, 168 | { 169 | "duration": elapsed, 170 | "pass_at_k": table, 171 | "launch": config.launch, 172 | "total_input_tokens": total_input_tokens, 173 | "total_output_tokens": total_output_tokens, 174 | }, 175 | ) 176 | 177 | 178 | def main(config: ScriptConfig): 179 | raw_test_dataset = list(load_dataset("gsm8k", "main", split="test")) 180 | train_dataset = list(load_dataset("gsm8k", "main", split="train")) 181 | 182 | print(f"Number of test items: {len(raw_test_dataset)}") 183 | print(f"Number of train items: {len(train_dataset)}") 184 | 185 | random.seed(config.seed) 186 | 187 | for i, data in enumerate(train_dataset): 188 | data["index"] = i 189 | 190 | test_dataset = shuffle_and_limit(raw_test_dataset, config) 191 | 192 | 
for i, data in enumerate(test_dataset): 193 | few_shot_items = random.sample(train_dataset, config.num_few_shot) 194 | data["few_shot_items"] = few_shot_items 195 | 196 | print(f"Total number of items to process: {len(test_dataset)}") 197 | 198 | go_func = partial(run_inference, config=config) 199 | 200 | if (save_path := config.save_path) is not None: 201 | save_path.parent.mkdir(parents=True, exist_ok=True) 202 | launch_command_save_path = save_path.with_suffix(".launch.txt") 203 | launch_command_save_path.write_text(str(config.launch)) 204 | 205 | with launch_server(config): 206 | for _ in tqdm(range(config.reps)): 207 | run_eval(config, go_func, test_dataset) 208 | 209 | 210 | if __name__ == "__main__": 211 | pydra.run(main) 212 | -------------------------------------------------------------------------------- /logs/blog_commands.md: -------------------------------------------------------------------------------- 1 | Setup: 2 | 3 | ```bash 4 | 5 | BASE_DIR=local/results 6 | TOKA_ENV=toka-bench 7 | VLLM_ENV=vllm-bench 8 | SGL_ENV=sgl-bench 9 | 10 | conda create -n $TOKA_ENV -y python=3.12 11 | conda activate $TOKA_ENV 12 | conda install -y nvidia/label/cuda-12.4.1::cuda-toolkit 13 | cd ~/tokasaurus 14 | pip install uv 15 | uv pip install -e '.[dev]' 16 | 17 | conda create -n $VLLM_ENV -y python=3.12 18 | conda activate $VLLM_ENV 19 | conda install -y nvidia/label/cuda-12.4.1::cuda-toolkit 20 | pip install uv 21 | uv pip install vllm==0.9.0.1 22 | uv pip install flashinfer-python==0.2.5 --no-deps 23 | 24 | conda create -n $SGL_ENV -y python=3.12 25 | conda activate $SGL_ENV 26 | conda install -y nvidia/label/cuda-12.4.1::cuda-toolkit 27 | pip install uv 28 | uv pip install 'sglang[all]'==0.4.6.post5 29 | 30 | ``` 31 | 32 | 33 | 1xH100: 34 | 35 | ```bash 36 | 37 | # monkeys 1b 1xh100 38 | 39 | DIR=$BASE_DIR/monkeys-1b-1xh100 40 | 41 | TOKA_COMMAND="tksrs model=meta-llama/Llama-3.2-1B-Instruct kv_cache_num_tokens='((1024 + 512) * 1024)' max_seqs_per_forward=8192 max_tokens_per_forward=32768 torch_compile=T use_hydragen=T hydragen_min_group_size=129 .uvsh" 42 | 43 | TOKA_NO_HYDRAGEN_COMMAND="tksrs model=meta-llama/Llama-3.2-1B-Instruct kv_cache_num_tokens='((1024 + 512) * 1024)' max_seqs_per_forward=8192 max_tokens_per_forward=32768 torch_compile=T .uvsh" 44 | 45 | SGL_COMMAND="python -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --port 10210 --max-running-requests 8192 --chunked-prefill-size 32768 --max-prefill-tokens 32786 --max-total-tokens 1572864 --schedule-conservativeness 0.01 --log-level-http warning --attention-backend flashinfer" 46 | 47 | VLLM_COMMAND="vllm serve meta-llama/Llama-3.2-1B-Instruct --num-gpu-blocks-override 98304 --port 10210 --max-num-seqs 8192 --max-num-batched-tokens 32768 --enable-prefix-caching --disable-log-requests --enable-chunked-prefill --uvicorn-log-level warning" 48 | 49 | bench() { 50 | python tokasaurus/benchmarks/monkeys_gsm8k.py model=meta-llama/Llama-3.2-1B-Instruct limit=128 n=1024 port=10210 reps=4 "$@" 51 | } 52 | 53 | ulimit -n unlimited 54 | export SGLANG_DETOKENIZER_MAX_STATES=10000000 55 | 56 | conda activate $TOKA_ENV 57 | bench launch=$TOKA_COMMAND save_path=$DIR/toka.jsonl 58 | bench launch=$TOKA_NO_HYDRAGEN_COMMAND save_path=$DIR/toka_no_hydragen.jsonl 59 | bench launch=$SGL_COMMAND env=$SGL_ENV save_path=$DIR/sgl.jsonl 60 | bench launch=$VLLM_COMMAND env=$VLLM_ENV save_path=$DIR/vllm.jsonl 61 | 62 | # sharegpt 1b 1xh100 63 | 64 | DIR=$BASE_DIR/sharegpt-1b-1xh100 65 | 66 | BENCH_COMMAND='python3 -m 
sglang.bench_serving --backend vllm --dataset-name sharegpt --num-prompt 65536 --sharegpt-context-len 131072 --model meta-llama/Llama-3.2-1B-Instruct --disable-stream --max-concurrency 8192 --port 10210' 67 | 68 | TOKA_COMMAND="tksrs model=meta-llama/Llama-3.2-1B-Instruct kv_cache_num_tokens='((1024 + 768) * 1024)' max_seqs_per_forward=8192 max_tokens_per_forward=32768 torch_compile=T scheduling_steps_ahead=16 .uvsh" 69 | 70 | SGL_COMMAND="python -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --port 10210 --max-running-requests 8192 --chunked-prefill-size 32768 --max-prefill-tokens 32786 --max-total-tokens 1835008 --schedule-conservativeness 0.1 --log-level-http warning --attention-backend flashinfer" 71 | 72 | VLLM_COMMAND="vllm serve meta-llama/Llama-3.2-1B-Instruct --num-gpu-blocks-override 114688 --port 10210 --max-num-seqs 8192 --max-num-batched-tokens 32768 --enable-prefix-caching --disable-log-requests --enable-chunked-prefill --uvicorn-log-level warning" 73 | 74 | bench() { 75 | python tokasaurus/benchmarks/sharegpt.py model=meta-llama/Llama-3.2-1B-Instruct sharegpt_command=\"$BENCH_COMMAND\" sharegpt_env=$SGL_ENV port=10210 reps=4 "$@" 76 | } 77 | 78 | ulimit -n unlimited 79 | export SGLANG_DETOKENIZER_MAX_STATES=10000000 80 | 81 | bench launch=$TOKA_COMMAND save_path=$DIR/toka.jsonl 82 | bench launch=$SGL_COMMAND env=$SGL_ENV save_path=$DIR/sgl.jsonl 83 | bench launch=$VLLM_COMMAND env=$VLLM_ENV save_path=$DIR/vllm.jsonl 84 | 85 | ``` 86 | 87 | 88 | 8xH100: 89 | 90 | ```bash 91 | 92 | # sharegpt 70b 8xh100 93 | 94 | DIR=$BASE_DIR/sharegpt-70b-8xh100 95 | 96 | BENCH_COMMAND='python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --num-prompt 65536 --sharegpt-context-len 131072 --model meta-llama/Llama-3.1-70B-Instruct --disable-stream --max-concurrency 8192 --port 10210' 97 | 98 | TOKA_COMMAND="tksrs model=meta-llama/Llama-3.1-70B-Instruct tp_size=8 kv_cache_num_tokens='((1024 + 128) * 1024)' max_seqs_per_forward=4096 max_tokens_per_forward=16384 torch_compile=T async_tp_threshold=6144 .uvsh" 99 | 100 | SGL_COMMAND="python -m sglang.launch_server --model-path meta-llama/Llama-3.1-70B-Instruct --port 10210 --max-running-requests 4096 --chunked-prefill-size 49152 --max-prefill-tokens 49152 --max-total-tokens 1179648 --schedule-conservativeness 0.1 --tensor-parallel-size 8 --log-level-http warning --attention-backend flashinfer" 101 | 102 | VLLM_COMMAND="vllm serve meta-llama/Llama-3.1-70B-Instruct --num-gpu-blocks-override 73728 --port 10210 --max-num-seqs 4096 --max-num-batched-tokens 16384 --enable-prefix-caching --disable-log-requests --enable-chunked-prefill --tensor-parallel-size 8 --uvicorn-log-level warning" 103 | 104 | bench() { 105 | python tokasaurus/benchmarks/sharegpt.py model=meta-llama/Llama-3.1-70B-Instruct sharegpt_command=\"$BENCH_COMMAND\" sharegpt_env=$SGL_ENV port=10210 reps=4 "$@" 106 | } 107 | 108 | ulimit -n unlimited 109 | export SGLANG_DETOKENIZER_MAX_STATES=10000000 110 | 111 | conda activate $TOKA_ENV 112 | 113 | bench launch=$TOKA_COMMAND save_path=$DIR/toka.jsonl 114 | bench launch=$SGL_COMMAND env=$SGL_ENV save_path=$DIR/sgl.jsonl 115 | bench launch=$VLLM_COMMAND env=$VLLM_ENV save_path=$DIR/vllm.jsonl 116 | 117 | 118 | ``` 119 | 120 | 121 | 8xL40S: 122 | 123 | ```bash 124 | 125 | DIR=$BASE_DIR/sharegpt-70b-8xl40s 126 | 127 | BENCH_COMMAND='python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --num-prompt 16384 --sharegpt-context-len 4096 --model meta-llama/Llama-3.1-70B-Instruct --disable-stream 
--max-concurrency 4096 --port 10210' 128 | 129 | TOKA_COMMAND="tksrs model=meta-llama/Llama-3.1-70B-Instruct pp_size=8 kv_cache_num_tokens='(512 * 1024)' max_seqs_per_forward=2048 max_tokens_per_forward=8192 torch_compile=T .uvsh" 130 | 131 | VLLM_COMMAND="vllm serve meta-llama/Llama-3.1-70B-Instruct --num-gpu-blocks-override 32768 --port 10210 --max-num-seqs 512 --max-num-batched-tokens 2048 --enable-prefix-caching --disable-log-requests --enable-chunked-prefill --pipeline-parallel-size 8 --port 10210 --max-model-len 32768" 132 | 133 | SGL_COMMAND="python -m sglang.launch_server --model-path meta-llama/Llama-3.1-70B-Instruct --port 10210 --max-running-requests 2048 --chunked-prefill-size 8192 --max-prefill-tokens 8192 --max-total-tokens 524288 --schedule-conservativeness 0.1 --pipeline-parallel-size 8 --log-level-http warning --attention-backend flashinfer" 134 | 135 | bench() { 136 | python tokasaurus/benchmarks/sharegpt.py model=meta-llama/Llama-3.1-70B-Instruct sharegpt_command=\"$BENCH_COMMAND\" sharegpt_env=$SGL_ENV port=10210 reps=4 "$@" 137 | } 138 | 139 | ulimit -n unlimited 140 | export SGLANG_DETOKENIZER_MAX_STATES=10000000 141 | 142 | conda activate $TOKA_ENV 143 | 144 | bench launch=$TOKA_COMMAND save_path=$DIR/toka.jsonl 145 | bench launch=$SGL_COMMAND env=$SGL_ENV save_path=$DIR/sgl.jsonl 146 | bench launch=$VLLM_COMMAND env=$VLLM_ENV save_path=$DIR/vllm.jsonl 147 | 148 | 149 | ``` -------------------------------------------------------------------------------- /tokasaurus/model/basic_worker.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | import torch.multiprocessing as mp 5 | from loguru import logger 6 | from torch import Tensor 7 | 8 | from tokasaurus.common_types import ( 9 | ServerConfig, 10 | TimedBarrier, 11 | ) 12 | from tokasaurus.model.llama import LlamaForCausalLM 13 | from tokasaurus.model.types import ( 14 | BasicWorkerState, 15 | BatchState, 16 | CommandFromManager, 17 | ModelInput, 18 | ModelOutput, 19 | ModelOutputTensors, 20 | NoMoreInputs, 21 | ) 22 | from tokasaurus.model.utils import ( 23 | ModelRunner, 24 | add_decoding_ids_to_batch_state, 25 | get_dtype, 26 | make_input_batch_state, 27 | make_model, 28 | move_batch_state, 29 | setup_and_run_loop, 30 | setup_distributed, 31 | unpad_output_batch_state, 32 | ) 33 | from tokasaurus.utils import ( 34 | error_propogation_decorator, 35 | setup_logging, 36 | ) 37 | 38 | 39 | def basic_model_loop( 40 | state: BasicWorkerState, 41 | model: LlamaForCausalLM, 42 | ): 43 | state.logger.info("Model loop started!") 44 | 45 | tp_rank = state.tp_rank 46 | tp_size = state.config.tp_size 47 | non_blocking = True 48 | 49 | @dataclass 50 | class Work: 51 | model_input: ModelInput 52 | input_batch_state: BatchState 53 | batch_indices: Tensor 54 | output_batch_state: BatchState | None = None 55 | output_tensors_cpu: ModelOutputTensors | None = None 56 | 57 | def preprocess(): 58 | command: CommandFromManager = state.input_q.get() 59 | 60 | match command: 61 | case ModelInput(): 62 | inp = command 63 | case NoMoreInputs(): 64 | return None 65 | case _: 66 | raise ValueError(f"Unknown command: {type(command)}") 67 | 68 | batch_indices = torch.tensor( 69 | inp.batch_indices, 70 | dtype=torch.long, 71 | ) 72 | 73 | num_total_padding, num_lm_head_padding = model_runner.calc_padding( 74 | num_prefill_tokens=inp.num_prefill_tokens(), 75 | num_decode_tokens=inp.num_decode_tokens(), 76 | num_lm_head_tokens=inp.num_lm_head_tokens(), 77 | 
) 78 | 79 | input_batch_state = make_input_batch_state( 80 | inp, 81 | tp_rank=tp_rank, 82 | tp_size=tp_size, 83 | num_total_padding=num_total_padding, 84 | num_lm_head_padding=num_lm_head_padding, 85 | ) 86 | 87 | model_runner.plan(input_batch_state, non_blocking=non_blocking) 88 | 89 | move_batch_state( 90 | input_batch_state=input_batch_state, 91 | device=state.device, 92 | non_blocking=non_blocking, 93 | ) 94 | 95 | return Work( 96 | model_input=inp, 97 | input_batch_state=input_batch_state, 98 | batch_indices=batch_indices.to(state.device, non_blocking=non_blocking), 99 | ) 100 | 101 | def run_model(work: Work): 102 | decoding_batch_indices = work.batch_indices[ 103 | work.model_input.decode_start_pos() : 104 | ] 105 | decoding_input_ids = state.batch_index_to_last_token[decoding_batch_indices] 106 | 107 | input_batch_state = work.input_batch_state 108 | 109 | add_decoding_ids_to_batch_state( 110 | input_batch_state=input_batch_state, 111 | decoding_input_ids=decoding_input_ids, 112 | tp_rank=tp_rank, 113 | tp_size=tp_size, 114 | ) 115 | 116 | output_batch_state = model_runner.run( 117 | input_batch_state, non_blocking=non_blocking 118 | ) 119 | 120 | unpad_output_batch_state( 121 | output_batch_state=output_batch_state, 122 | model_input=work.model_input, 123 | ) 124 | 125 | if input_batch_state.raw_lm_head_indices is not None: 126 | lm_head_indices = input_batch_state.raw_lm_head_indices 127 | else: 128 | lm_head_indices = input_batch_state.lm_head_indices 129 | 130 | assert lm_head_indices is not None 131 | batch_indices = work.batch_indices[lm_head_indices] 132 | 133 | if len(batch_indices) > 0: 134 | assert output_batch_state.outputs is not None 135 | state.batch_index_to_last_token[batch_indices] = ( 136 | output_batch_state.outputs.output_ids 137 | ) 138 | 139 | work.output_batch_state = output_batch_state 140 | 141 | def synchronize(work: Work): 142 | # technically, we don't need to sync when tp_rank != 0, 143 | # but omitting it causes sporadic nccl illegal memory access errors 144 | torch.cuda.synchronize() 145 | 146 | work.output_tensors_cpu = work.output_batch_state.outputs.to("cpu") 147 | 148 | def postprocess(work: Work): 149 | if state.tp_rank != 0: 150 | return 151 | 152 | assert work.output_tensors_cpu is not None 153 | 154 | out = ModelOutput( 155 | tensors=work.output_tensors_cpu, 156 | schedule_id=work.model_input.schedule_id, 157 | ) 158 | 159 | state.q_model_to_manager.put(out) 160 | 161 | model_runner = ModelRunner( 162 | config=state.config, 163 | model=model, 164 | ) 165 | 166 | setup_and_run_loop( 167 | state=state, 168 | model_runner=model_runner, 169 | preprocess=preprocess, 170 | run_model=run_model, 171 | synchronize=synchronize, 172 | postprocess=postprocess, 173 | ) 174 | 175 | 176 | @error_propogation_decorator 177 | def start_basic_model_worker( 178 | config: ServerConfig, 179 | input_q: mp.Queue, 180 | q_model_to_manager: mp.Queue, 181 | dp_rank: int, 182 | tp_rank: int, 183 | master_port: int, 184 | process_name: str, 185 | barrier: TimedBarrier, 186 | ): 187 | setup_logging(config) 188 | 189 | device_mesh, device = setup_distributed( 190 | config, 191 | dp_rank=dp_rank, 192 | pp_rank=0, 193 | tp_rank=tp_rank, 194 | master_port=master_port, 195 | ) 196 | dtype = get_dtype(config.dtype) 197 | 198 | batch_index_to_last_token = torch.zeros( 199 | config.max_batch_index(), dtype=torch.long, device=device 200 | ) 201 | 202 | state = BasicWorkerState( 203 | config=config, 204 | batch_index_to_last_token=batch_index_to_last_token, 205 | input_q=input_q, 
206 | q_model_to_manager=q_model_to_manager, 207 | device=device, 208 | dtype=dtype, 209 | process_name=process_name, 210 | tp_rank=tp_rank, 211 | barrier=barrier, 212 | ) 213 | 214 | state.logger.info("Model worker started!") 215 | state.logger.info(f"Creating model on device {device} with dtype {dtype}") 216 | 217 | model = make_model( 218 | config, 219 | device, 220 | dtype, 221 | tp_rank=tp_rank, 222 | tp_group=device_mesh["tp"].get_group() if device_mesh is not None else None, 223 | ) 224 | 225 | state.logger.info("Created model") 226 | 227 | basic_model_loop(state, model) 228 | 229 | 230 | def start_fanout_worker( 231 | config: ServerConfig, 232 | input_q: mp.Queue, 233 | fanout_qs: list[mp.Queue], 234 | process_name: str, 235 | barrier: TimedBarrier, 236 | ): 237 | setup_logging(config) 238 | bound_logger = logger.bind(process_name=process_name) 239 | 240 | bound_logger.info("Fanout worker started!") 241 | 242 | barrier.wait() 243 | 244 | while True: 245 | inp = input_q.get() 246 | for q in fanout_qs: 247 | q.put(inp) 248 | -------------------------------------------------------------------------------- /tokasaurus/manager/hydragen.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Iterable 3 | 4 | from tokasaurus.manager.allocator import PrefixTreeBlock 5 | from tokasaurus.manager.monitoring import track_time_decorator 6 | from tokasaurus.manager.types import HydragenGroup, ScheduleDecision, Sequence 7 | 8 | 9 | @track_time_decorator() 10 | def reorder_decoding_seqs_for_hydragen( 11 | decoding_seqs: list[Sequence], hydragen_groups: list[HydragenGroup] 12 | ): 13 | """ 14 | Our Hydragen implementation requires us to reorder the decoding sequences so that 15 | seqs in the same shared-prefix group are adjacent to each other in the batch. 16 | """ 17 | 18 | sid_to_decoding_seq = {seq.id: seq for seq in decoding_seqs} 19 | 20 | # we re-order the decoding sequences so that seqs in the same hydragen group 21 | # are adjacent to each other, with ungrouped seqs at the end. 
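    # Worked example with made-up IDs: given decoding_seqs = [s0, s1, s2, s3] and one
    # hydragen group containing {s1, s3}, the loop below emits the grouped sequences
    # first (s1 and s3, in whatever order group.seq_ids yields), then appends the
    # ungrouped s0 and s2 in their original order.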
22 | reordered_decoding_seqs = [] 23 | 24 | grouped_sids = set[str]() 25 | sid_to_group = dict[str, HydragenGroup]() 26 | for group in hydragen_groups: 27 | for sid in group.seq_ids: 28 | sid_to_group[sid] = group 29 | grouped_sids.add(sid) 30 | reordered_decoding_seqs.append(sid_to_decoding_seq[sid]) 31 | 32 | # ungrouped seqs at the end 33 | for seq in decoding_seqs: 34 | if seq.id not in grouped_sids: 35 | reordered_decoding_seqs.append(seq) 36 | 37 | assert len(reordered_decoding_seqs) == len(decoding_seqs) 38 | 39 | return reordered_decoding_seqs 40 | 41 | 42 | def reorder_decision_for_hydragen( 43 | decision: ScheduleDecision, groups: list[HydragenGroup] 44 | ): 45 | return ScheduleDecision( 46 | id=decision.id, 47 | decoding_seqs=reorder_decoding_seqs_for_hydragen( 48 | decision.decoding_seqs, groups 49 | ), 50 | prefill_seqs=decision.prefill_seqs, 51 | ) 52 | 53 | 54 | def node_to_block_ids(node: PrefixTreeBlock) -> list[int]: 55 | block_ids_last_to_first = [] 56 | cur = node 57 | while not cur.is_root(): 58 | block_ids_last_to_first.append(cur.idx) 59 | cur = cur.parent 60 | assert cur is not None 61 | 62 | return list(reversed(block_ids_last_to_first)) 63 | 64 | 65 | @track_time_decorator() 66 | def group_for_hydragen( 67 | root: PrefixTreeBlock, 68 | seq_ids_to_group: Iterable[str], 69 | min_group_size: int, 70 | min_prefix_len: int, 71 | page_size: int, 72 | ) -> list[HydragenGroup]: 73 | """ 74 | Iterative version of depth-first search - we make a group for a given prefix if 75 | it meets the minimum group size/prefix length requirements, 76 | after checking if any children have met these requirements. 77 | """ 78 | groups = list[HydragenGroup]() 79 | all_sids = set(seq_ids_to_group) 80 | grouped_sids = set[str]() 81 | 82 | @dataclass 83 | class StackItem: 84 | node: PrefixTreeBlock 85 | depth: int 86 | visited_children: bool 87 | potential_sids: set[str] 88 | 89 | # Stack will contain tuples of (node, block_ids_before_this_node, visited_children) 90 | # visited_children is a boolean indicating whether we've already processed the children 91 | stack: list[StackItem] = [] 92 | 93 | assert min_prefix_len % page_size == 0 94 | min_depth = min_prefix_len // page_size 95 | 96 | # Initialize the stack with the root's children 97 | for child in root.children.values(): 98 | stack.append(StackItem(child, 1, False, all_sids & child.seq_ids)) 99 | 100 | while stack: 101 | item = stack.pop() 102 | 103 | if not item.visited_children: 104 | # Skip this node if it doesn't have enough sequence IDs 105 | if len(item.potential_sids) < min_group_size: 106 | continue 107 | 108 | if item.depth >= min_depth: 109 | # If there's a chance this node is the last block in a group, 110 | # push it back on the stack to re-consider it after we process its children. 111 | stack.append( 112 | StackItem(item.node, item.depth, True, item.potential_sids) 113 | ) 114 | 115 | # Process children 116 | for child in item.node.children.values(): 117 | stack.append( 118 | StackItem( 119 | child, 120 | item.depth + 1, 121 | False, 122 | item.potential_sids & child.seq_ids, 123 | ) 124 | ) 125 | else: 126 | # We need to compute the available ids after considering the children 127 | # since the children may have created groups. 
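            # Hypothetical numbers for intuition: if this node's potential_sids spans 40
            # sequences but a deeper child already claimed 35 of them for its own group, only
            # 5 remain available here, so with min_group_size=32 no group is formed at this node.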
128 | available_sids_to_group = item.potential_sids - grouped_sids 129 | if len(available_sids_to_group) >= min_group_size: 130 | groups.append( 131 | HydragenGroup( 132 | block_ids=node_to_block_ids(item.node), 133 | seq_ids=available_sids_to_group, 134 | ) 135 | ) 136 | grouped_sids.update(available_sids_to_group) 137 | 138 | sorted_groups = list(sorted(groups, key=lambda x: x.block_ids)) 139 | return sorted_groups 140 | 141 | 142 | def restrict_hydragen_groups( 143 | groups: list[HydragenGroup], 144 | restrict_to_seq_ids: set[str], 145 | min_group_size: int, 146 | min_prefix_len: int, 147 | page_size: int, 148 | ) -> list[HydragenGroup]: 149 | """ 150 | Restricts the groups to only include seqs in restrict_to_seq_ids, and ensures 151 | that the groups meet the minimum group size/prefix length requirements. 152 | """ 153 | 154 | assert min_prefix_len % page_size == 0 155 | min_prefix_len_in_pages = min_prefix_len // page_size 156 | 157 | now_too_small_groups = list[HydragenGroup]() 158 | good_groups = list[HydragenGroup]() 159 | 160 | for group in groups: 161 | restricted_group_seq_ids = group.seq_ids & restrict_to_seq_ids 162 | if len(restricted_group_seq_ids) >= min_group_size: 163 | good_groups.append( 164 | HydragenGroup( 165 | block_ids=group.block_ids, 166 | seq_ids=restricted_group_seq_ids, 167 | ) 168 | ) 169 | else: 170 | now_too_small_groups.append(group) 171 | 172 | if len(now_too_small_groups) > 0: 173 | # best-effort approach to merge groups that are now too small 174 | too_small_prefixes = [group.block_ids for group in now_too_small_groups] 175 | 176 | # find longest common prefix of all too_small_prefixes 177 | common_prefix = [] 178 | min_of_too_small_prefix_lens = min(len(p) for p in too_small_prefixes) 179 | for i in range(min_of_too_small_prefix_lens): 180 | potential_block_id = too_small_prefixes[0][i] 181 | if all(p[i] == potential_block_id for p in too_small_prefixes): 182 | common_prefix.append(potential_block_id) 183 | else: 184 | break 185 | 186 | if len(common_prefix) >= min_prefix_len_in_pages: 187 | unrestricted_merged_group_seq_ids = set[str]() 188 | for group in now_too_small_groups: 189 | unrestricted_merged_group_seq_ids.update(group.seq_ids) 190 | 191 | merged_group_seq_ids = ( 192 | unrestricted_merged_group_seq_ids & restrict_to_seq_ids 193 | ) 194 | 195 | if len(merged_group_seq_ids) >= min_group_size: 196 | # merge the groups 197 | merged_group = HydragenGroup( 198 | block_ids=common_prefix, 199 | seq_ids=merged_group_seq_ids, 200 | ) 201 | good_groups.append(merged_group) 202 | 203 | return good_groups 204 | -------------------------------------------------------------------------------- /tokasaurus/benchmarks/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import subprocess 4 | import time 5 | from concurrent.futures import ThreadPoolExecutor, as_completed 6 | from contextlib import contextmanager 7 | from multiprocessing import Pool 8 | from pathlib import Path 9 | 10 | import numpy as np 11 | import psutil 12 | import pydra 13 | from datasets import load_dataset 14 | from openai import OpenAI 15 | from tqdm import tqdm 16 | 17 | 18 | class BaseConfig(pydra.Config): 19 | n: int = 1 20 | limit: int | None = None 21 | seed: int = 0 22 | workers: int | None = None 23 | model: str = "" 24 | max_tokens: int = 1024 25 | temperature: float = 1.0 26 | top_p: float = 1.0 27 | port: int = 10210 28 | launch: str | None = None 29 | api_key: str = "letmein" 30 | save_path: Path | 
None = None 31 | env: str | None = None 32 | conda_activate_path: str = "~/miniconda3/bin/activate" 33 | reps: int = 1 34 | 35 | def __init__(self): 36 | super().__init__() 37 | self.stop_strings = [] 38 | 39 | def client(self): 40 | return OpenAI( 41 | base_url=f"http://localhost:{self.port}/v1", 42 | api_key=self.api_key, 43 | max_retries=0, 44 | timeout=None, 45 | ) 46 | 47 | 48 | def kill_process_tree(pid): 49 | try: 50 | parent = psutil.Process(pid) 51 | children = parent.children(recursive=True) 52 | 53 | for child in children: 54 | child.kill() 55 | 56 | parent.kill() 57 | 58 | except psutil.NoSuchProcess: 59 | pass 60 | 61 | 62 | def wait_for_startup( 63 | process: subprocess.Popen, 64 | port: int, 65 | model: str, 66 | max_retries: int = 500, 67 | retry_seconds: float = 2, 68 | ): 69 | client = OpenAI( 70 | base_url=f"http://localhost:{port}/v1", 71 | api_key="letmein", 72 | max_retries=0, 73 | timeout=20, 74 | ) 75 | 76 | for i in range(max_retries): 77 | if process.poll() is not None: 78 | raise RuntimeError(f"Server crashed with returncode {process.returncode}") 79 | 80 | try: 81 | client.chat.completions.create( 82 | model=model, 83 | messages=[ 84 | {"role": "user", "content": "tell me a funny joke about cookies"} 85 | ], 86 | max_tokens=10, 87 | ) 88 | return 89 | except Exception: 90 | print(f"Server not yet started (attempt {i}) retrying...") 91 | time.sleep(retry_seconds) 92 | 93 | raise RuntimeError(f"Server not started after {max_retries} attempts.") 94 | 95 | 96 | def prepend_conda_activate(command: str, activate_path: str, env: str): 97 | return f"source {activate_path} && conda activate {env} && {command}" 98 | 99 | 100 | @contextmanager 101 | def launch_server(config: BaseConfig): 102 | if config.launch is None: 103 | yield None 104 | return 105 | 106 | command = config.launch 107 | if config.env is not None: 108 | command = prepend_conda_activate( 109 | command, config.conda_activate_path, config.env 110 | ) 111 | 112 | print(f"Starting server with command: '{command}'") 113 | server_process = subprocess.Popen(command, shell=True, executable="/bin/bash") 114 | print(f"Started server with pid {server_process.pid}") 115 | 116 | try: 117 | wait_for_startup( 118 | server_process, config.port, config.model, max_retries=500, retry_seconds=2 119 | ) 120 | yield 121 | finally: 122 | print(f"Killing server (pid {server_process.pid})...") 123 | kill_process_tree(server_process.pid) 124 | print("Done killing server.") 125 | 126 | 127 | def parallelize( 128 | fn, 129 | items, 130 | num_workers: int | None = None, 131 | processes: bool = True, 132 | allow_unordered: bool = False, 133 | desc: str | None = None, 134 | ): 135 | if num_workers is None: 136 | num_workers = len(items) 137 | 138 | assert num_workers >= 0 139 | 140 | if num_workers == 0: 141 | outs = [] 142 | for item in tqdm(items, desc=desc): 143 | outs.append(fn(item)) 144 | return outs 145 | 146 | if processes: 147 | with Pool(num_workers) as p: 148 | if allow_unordered: 149 | parallel_fn = p.imap_unordered 150 | else: 151 | parallel_fn = p.imap 152 | 153 | return list(tqdm(parallel_fn(fn, items), total=len(items), desc=desc)) 154 | else: 155 | with ThreadPoolExecutor(max_workers=num_workers) as executor: 156 | futures = [executor.submit(fn, item) for item in items] 157 | results = [] 158 | 159 | if allow_unordered: 160 | iterator = as_completed(futures) 161 | else: 162 | iterator = futures 163 | 164 | for future in tqdm(iterator, total=len(items), desc=desc): 165 | # raise any exceptions immediately 166 | 
results.append(future.result()) 167 | 168 | return results 169 | 170 | 171 | def pass_at_k(n, c, k): 172 | """ 173 | :param n: total number of samples 174 | :param c: number of correct samples 175 | :param k: k in pass@$k$ 176 | """ 177 | if n - c < k: 178 | return 1.0 179 | return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)) 180 | 181 | 182 | def shuffle_and_limit(ds, config: BaseConfig): 183 | random.seed(config.seed) 184 | 185 | for i, data in enumerate(ds): 186 | data["index"] = i 187 | 188 | random.shuffle(ds) 189 | 190 | for i, data in enumerate(ds): 191 | data["shuffled_index"] = i 192 | 193 | if config.limit is not None: 194 | limit = config.limit 195 | else: 196 | limit = len(ds) 197 | 198 | ds = ds[:limit] 199 | 200 | return ds 201 | 202 | 203 | def make_pass_at_k_table(corrects_list: list[list[bool]], ks: list[int]): 204 | table = [] 205 | for k in ks: 206 | to_mean = [] 207 | for corrects in corrects_list: 208 | to_mean.append(pass_at_k(n=len(corrects), c=sum(corrects), k=k)) 209 | table.append([k, np.mean(to_mean)]) 210 | 211 | return table 212 | 213 | 214 | def maybe_save_results(config: BaseConfig, results): 215 | if (save_path := config.save_path) is not None: 216 | save_path.parent.mkdir(parents=True, exist_ok=True) 217 | # jsonl file, so append a new line 218 | with open(save_path, "a") as f: 219 | line = json.dumps(results) 220 | f.write(line + "\n") 221 | 222 | 223 | def sample_sharegpt_requests(): 224 | dataset = load_dataset( 225 | "anon8231489123/ShareGPT_Vicuna_unfiltered", 226 | data_files="ShareGPT_V3_unfiltered_cleaned_split.json", 227 | )["train"] 228 | dataset = dataset.filter( 229 | lambda x: len(x["conversations"]) > 2 230 | and len(x["conversations"]) % 2 == 0 231 | and x["conversations"][0]["from"] == "human" 232 | ) 233 | dataset = dataset.map( 234 | lambda x: {**x, "conversations": x["conversations"][0]["value"]} 235 | ) 236 | dataset = dataset.shuffle(seed=42) 237 | 238 | # todo: think about how to do the short sequence and long sequence pruning? 
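    # Note on what this returns: the map above replaces each record's "conversations"
    # field with just the first human turn's text, so callers get a shuffled dataset of
    # single prompt strings rather than full multi-turn conversations.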
239 | return dataset 240 | 241 | 242 | def get_chat_dataset(args, tokenizer): 243 | if args.dataset_name == "sharegpt": 244 | input_requests = sample_sharegpt_requests( 245 | dataset_path=args.dataset_path, 246 | num_requests=args.num_prompts, 247 | tokenizer=tokenizer, 248 | disable_shuffle=args.disable_shuffle, 249 | enable_multiturn=args.enable_multiturn, 250 | fixed_output_len=args.fixed_output_len, 251 | ) 252 | else: 253 | raise ValueError(f"Unknown dataset name: {args.dataset_name}") 254 | return input_requests 255 | -------------------------------------------------------------------------------- /tests/test_topk.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import json 3 | import os 4 | import shlex 5 | 6 | import numpy as np 7 | import pydra 8 | import pytest 9 | import torch.multiprocessing as mp 10 | from openai import OpenAI 11 | 12 | from tokasaurus.common_types import ServerConfig 13 | from tokasaurus.entry import server_manager 14 | from tokasaurus.utils import find_free_port 15 | 16 | MODEL = os.environ.get("MODEL", "meta-llama/Llama-3.2-1B-Instruct") 17 | OVERRIDES = os.environ.get("OVERRIDES", None) 18 | 19 | 20 | def make_config(): 21 | config = ServerConfig() 22 | config.model = MODEL 23 | config.kv_cache_num_tokens = 16384 24 | config.max_num_tokens_per_request = 16384 25 | config.max_seqs_per_forward = 1024 26 | config.port = find_free_port() 27 | 28 | if OVERRIDES: 29 | # split apart like a shell, respecting quotes 30 | parsed_overrides = shlex.split(OVERRIDES) 31 | pydra.apply_overrides(config, parsed_overrides) 32 | 33 | # Enable logprobs features for topk testing 34 | config.enable_chosen_logprobs = True 35 | config.max_topk_logprobs = 5 36 | 37 | return config 38 | 39 | 40 | def _client(): 41 | mp.set_start_method("spawn", force=True) 42 | 43 | config = make_config() 44 | print(f"Launching server with config: {config.to_dict()}") 45 | 46 | with server_manager(config): 47 | client = OpenAI( 48 | api_key="beepboop", base_url=f"http://localhost:{config.port}/v1" 49 | ) 50 | yield client 51 | 52 | 53 | @pytest.fixture(scope="module") 54 | def client(): 55 | yield from _client() 56 | 57 | 58 | # Test prompts 59 | abc_prompt = "A B C D E F G H I J A B C D E F G H I J A B C D E F G H I J A B C D E F G H I J A B C D E F G H I J A B C D E F G H I J" 60 | 61 | 62 | def test_completions_greedy_logprobs_matches_top1(client: OpenAI): 63 | """Test that greedy sampling matches top-1 logprobs for completions API""" 64 | response = client.completions.create( 65 | model="", 66 | prompt=abc_prompt, 67 | max_tokens=10, 68 | temperature=0.0, 69 | logprobs=5, 70 | ) 71 | 72 | assert len(response.choices) == 1 73 | choice = response.choices[0] 74 | 75 | # Check that logprobs are present and populated 76 | assert choice.logprobs is not None 77 | assert choice.logprobs.token_logprobs is not None 78 | assert choice.logprobs.tokens is not None 79 | assert choice.logprobs.top_logprobs is not None 80 | 81 | # Check lengths match 82 | assert len(choice.logprobs.token_logprobs) == len(choice.logprobs.tokens) 83 | assert len(choice.logprobs.token_logprobs) == len(choice.logprobs.top_logprobs) 84 | 85 | # Check that we got the expected number of top logprobs for each token 86 | for i, (greedy_token, top_logprobs) in enumerate( 87 | zip(choice.logprobs.tokens, choice.logprobs.top_logprobs) 88 | ): 89 | assert len(top_logprobs) == 5 # We requested 5 top logprobs 90 | 91 | # Verify logprobs are in descending order 92 | logprob_values = 
list(top_logprobs.values()) 93 | assert logprob_values == sorted(logprob_values, reverse=True) 94 | 95 | # The top-1 token should match the greedily selected token 96 | top1_token = list(top_logprobs.keys())[0] 97 | assert top1_token == greedy_token 98 | 99 | # The top-1 logprob should match the token logprob 100 | top1_logprob = list(top_logprobs.values())[0] 101 | assert abs(top1_logprob - choice.logprobs.token_logprobs[i]) < 1e-6 102 | 103 | 104 | def test_chat_completions_greedy_logprobs_matches_top1(client: OpenAI): 105 | """Test that greedy sampling matches top-1 logprobs for chat completions API""" 106 | messages = [ 107 | {"role": "system", "content": "You are a helpful assistant."}, 108 | {"role": "user", "content": "Hello"}, 109 | ] 110 | 111 | response = client.chat.completions.create( 112 | model="", 113 | messages=messages, 114 | max_tokens=10, 115 | temperature=0.0, 116 | logprobs=True, 117 | top_logprobs=5, 118 | ) 119 | 120 | assert len(response.choices) == 1 121 | choice = response.choices[0] 122 | 123 | # Check that logprobs are present 124 | assert choice.logprobs is not None 125 | assert choice.logprobs.content is not None 126 | 127 | # Check each token logprob 128 | for token_logprob in choice.logprobs.content: 129 | assert token_logprob.token is not None 130 | assert token_logprob.logprob is not None 131 | assert token_logprob.top_logprobs is not None 132 | assert len(token_logprob.top_logprobs) == 5 # We requested 5 top logprobs 133 | 134 | # Verify top logprobs are in descending order 135 | top_logprobs = token_logprob.top_logprobs 136 | for i in range(len(top_logprobs) - 1): 137 | assert top_logprobs[i].logprob >= top_logprobs[i + 1].logprob 138 | 139 | # The top-1 token should match the selected token 140 | assert top_logprobs[0].token == token_logprob.token 141 | 142 | # The top-1 logprob should match the token logprob 143 | assert abs(top_logprobs[0].logprob - token_logprob.logprob) < 1e-6 144 | 145 | 146 | def test_packed_vs_normal_logprobs(client: OpenAI): 147 | """Test that packed format and normal OpenAI format produce identical results for the same request""" 148 | 149 | # Use a simple prompt and greedy decoding for deterministic results 150 | prompt = "What is the capital of France?" 
151 | k = 3 # Number of top logprobs to request 152 | max_tokens = 10 153 | 154 | # Make request with normal OpenAI logprobs format 155 | normal_response = client.chat.completions.create( 156 | model="", 157 | messages=[{"role": "user", "content": prompt}], 158 | max_tokens=max_tokens, 159 | temperature=0.0, # Greedy decoding 160 | logprobs=True, 161 | top_logprobs=k, 162 | ) 163 | 164 | # Make identical request with packed format 165 | packed_response = client.chat.completions.create( 166 | model="", 167 | messages=[{"role": "user", "content": prompt}], 168 | max_tokens=max_tokens, 169 | temperature=0.0, # Greedy decoding 170 | logprobs=True, 171 | top_logprobs=k, 172 | extra_body=dict(logprobs_in_fingerprint=True), 173 | ) 174 | 175 | # Extract normal logprobs 176 | normal_logprobs = normal_response.choices[0].logprobs 177 | assert normal_logprobs is not None 178 | assert normal_logprobs.content is not None 179 | 180 | # Extract packed logprobs from fingerprint 181 | assert packed_response.system_fingerprint is not None 182 | assert packed_response.choices[0].logprobs is None 183 | fingerprint_data = json.loads(packed_response.system_fingerprint) 184 | 185 | # Verify fingerprint contains expected fields 186 | assert "completion_ids" in fingerprint_data 187 | assert "packed_chosen_logprobs" in fingerprint_data 188 | assert "packed_topk_indices" in fingerprint_data 189 | assert "packed_topk_logprobs" in fingerprint_data 190 | 191 | # Decode packed data 192 | chosen_logprobs = np.frombuffer( 193 | base64.b64decode(fingerprint_data["packed_chosen_logprobs"][0]), 194 | dtype=np.float32, 195 | ) 196 | topk_ids = np.frombuffer( 197 | base64.b64decode(fingerprint_data["packed_topk_indices"][0]), dtype=np.int32 198 | ).reshape(-1, k) 199 | topk_logprobs = np.frombuffer( 200 | base64.b64decode(fingerprint_data["packed_topk_logprobs"][0]), dtype=np.float32 201 | ).reshape(-1, k) 202 | 203 | assert topk_ids.shape == topk_logprobs.shape 204 | 205 | # Number of tokens should match 206 | num_tokens = len(chosen_logprobs) 207 | assert num_tokens == len(normal_logprobs.content) 208 | 209 | # Compare token by token 210 | for i, normal_token in enumerate(normal_logprobs.content): 211 | # Chosen logprob should match 212 | assert abs(normal_token.logprob - chosen_logprobs[i]) < 1e-6 213 | 214 | # Number of top logprobs should match 215 | assert len(normal_token.top_logprobs) == k 216 | 217 | # Compare each top logprob 218 | for j, normal_top in enumerate(normal_token.top_logprobs): 219 | packed_logprob = topk_logprobs[i][j] 220 | 221 | # Logprob values should be identical (within floating point tolerance) 222 | assert abs(normal_top.logprob - packed_logprob) < 1e-6, ( 223 | f"Token {i}, top-{j}: normal={normal_top.logprob}, packed={packed_logprob}" 224 | ) 225 | 226 | # Verify completion text is identical 227 | assert ( 228 | normal_response.choices[0].message.content 229 | == packed_response.choices[0].message.content 230 | ) 231 | -------------------------------------------------------------------------------- /tokasaurus/server/endpoints.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import pickle 3 | from contextlib import asynccontextmanager 4 | from uuid import uuid4 5 | 6 | import uvicorn 7 | from fastapi import FastAPI, Form, HTTPException, Path, Request, Response, UploadFile 8 | from openai.pagination import SyncPage 9 | from openai.types.batch import Batch 10 | from openai.types.file_deleted import FileDeleted 11 | from 
openai.types.file_object import FileObject 12 | from openai.types.model import Model 13 | 14 | from tokasaurus.common_types import ( 15 | Engine, 16 | ServerConfig, 17 | TimedBarrier, 18 | ) 19 | from tokasaurus.server.types import ( 20 | BatchCreationRequest, 21 | BatchFileLine, 22 | ChatCompletionRequest, 23 | CompletionsRequest, 24 | FileEntry, 25 | SubmittedBatch, 26 | SubmittedBatchItem, 27 | SynchronousBatchCompletionsRequest, 28 | nowstamp, 29 | ) 30 | from tokasaurus.server.utils import ( 31 | ServerState, 32 | generate_output, 33 | handle_batch, 34 | make_batch_status, 35 | process_chat_completions_output, 36 | process_completions_output, 37 | process_request, 38 | receive_from_manager_loop, 39 | submit_request, 40 | with_cancellation, 41 | ) 42 | from tokasaurus.utils import setup_logging 43 | 44 | 45 | @asynccontextmanager 46 | async def lifespan(app: FastAPI): 47 | state_bundle: ServerState = app.state.state_bundle 48 | 49 | task = asyncio.create_task(receive_from_manager_loop(state_bundle)) 50 | 51 | yield 52 | 53 | task.cancel() 54 | 55 | 56 | app = FastAPI(lifespan=lifespan) 57 | 58 | 59 | @app.get("/ping") 60 | async def ping(): 61 | return {"message": "pong"} 62 | 63 | 64 | @app.post("/v1/completions") 65 | @with_cancellation 66 | async def oai_completions(request: CompletionsRequest, raw_request: Request): 67 | state: ServerState = app.state.state_bundle 68 | req, out = await generate_output(state, request) 69 | return process_completions_output(state, request, req, out) 70 | 71 | 72 | @app.post("/v1/chat/completions") 73 | @with_cancellation 74 | async def oai_chat_completions(request: ChatCompletionRequest, raw_request: Request): 75 | state: ServerState = app.state.state_bundle 76 | req, out = await generate_output(state, request) 77 | return process_chat_completions_output(state, request, req, out) 78 | 79 | 80 | @app.post("/v1/files", response_model=FileObject) 81 | async def upload_file( 82 | file: UploadFile, 83 | purpose: str = Form(...), 84 | ) -> FileObject: 85 | state: ServerState = app.state.state_bundle 86 | 87 | content = await file.read() 88 | state.logger.debug(f"Received file: {file.filename}, size: {len(content)} bytes") 89 | 90 | # Create a file object response 91 | file_object = FileObject( 92 | id=str(uuid4()), 93 | bytes=len(content), 94 | created_at=nowstamp(), 95 | filename=file.filename, 96 | purpose=purpose, 97 | object="file", 98 | status="uploaded", 99 | ) 100 | 101 | state.fid_to_file[file_object.id] = FileEntry( 102 | content=content, 103 | details=file_object, 104 | ) 105 | 106 | return file_object 107 | 108 | 109 | @app.get("/v1/files/{file_id}/content") 110 | async def retrieve_file_content( 111 | file_id: str = Path(..., description="The ID of the file to retrieve"), 112 | ): 113 | state: ServerState = app.state.state_bundle 114 | 115 | if file_id not in state.fid_to_file: 116 | raise HTTPException(status_code=404, detail=f"File not found: {file_id}") 117 | 118 | content = state.fid_to_file[file_id].content 119 | 120 | return Response(content=content, media_type="application/octet-stream") 121 | 122 | 123 | @app.delete("/v1/files/{file_id}") 124 | async def delete_file( 125 | file_id: str = Path(..., description="The ID of the file to delete"), 126 | ): 127 | state: ServerState = app.state.state_bundle 128 | 129 | if file_id not in state.fid_to_file: 130 | raise HTTPException(status_code=404, detail=f"File not found: {file_id}") 131 | 132 | del state.fid_to_file[file_id] 133 | 134 | return FileDeleted(id=file_id, deleted=True, 
object="file") 135 | 136 | 137 | @app.post("/v1/batches", response_model=Batch) 138 | async def create_batch(request: BatchCreationRequest): 139 | state: ServerState = app.state.state_bundle 140 | 141 | # Create a new batch ID 142 | batch_id = str(uuid4()) 143 | 144 | fid = request.input_file_id 145 | if (file_entry := state.fid_to_file.get(fid)) is None: 146 | raise HTTPException(status_code=404, detail=f"File not found: {fid}") 147 | 148 | if file_entry.details.purpose != "batch": 149 | raise HTTPException( 150 | status_code=400, 151 | detail=f"File {fid} has purpose {file_entry.details.purpose}, not 'batch'", 152 | ) 153 | 154 | # parse the file contents as JSONL 155 | file_content = file_entry.content.decode("utf-8") 156 | lines = file_content.splitlines() 157 | 158 | match request.endpoint: 159 | case "/v1/completions": 160 | request_type = CompletionsRequest 161 | case "/v1/chat/completions": 162 | request_type = ChatCompletionRequest 163 | case _: 164 | raise HTTPException( 165 | status_code=400, detail=f"Unsupported endpoint: {request.endpoint}" 166 | ) 167 | 168 | parsed_lines = [] 169 | for i, line in enumerate(lines): 170 | try: 171 | parsed = BatchFileLine.model_validate_json(line) 172 | assert parsed.url == request.endpoint, ( 173 | f"Mismatch between line url of '{parsed.url}' and endpoint of " 174 | f"'{request.endpoint}'" 175 | ) 176 | parsed_body = request_type.model_validate(parsed.body) 177 | parsed_lines.append((parsed, parsed_body)) 178 | except Exception as e: 179 | raise HTTPException( 180 | status_code=400, 181 | detail=f"Line {i} did not parse: {e}", 182 | ) 183 | 184 | batch_items = [] 185 | for parsed, parsed_body in parsed_lines: 186 | req = process_request(state, parsed_body) 187 | submitted = submit_request(state, req) 188 | batch_item = SubmittedBatchItem( 189 | line=parsed, user_req=parsed_body, submitted_req=submitted 190 | ) 191 | batch_items.append(batch_item) 192 | 193 | handler_task = asyncio.create_task(handle_batch(state, batch_id)) 194 | 195 | batch = SubmittedBatch( 196 | id=batch_id, 197 | creation_request=request, 198 | items=batch_items, 199 | task=handler_task, 200 | ) 201 | state.bid_to_batch[batch_id] = batch 202 | 203 | return make_batch_status(batch) 204 | 205 | 206 | @app.get("/v1/batches/{batch_id}") 207 | async def retrieve_batch( 208 | batch_id: str = Path(..., description="The ID of the batch to retrieve"), 209 | ): 210 | state: ServerState = app.state.state_bundle 211 | 212 | if (batch := state.bid_to_batch.get(batch_id)) is None: 213 | raise HTTPException(status_code=404, detail=f"Batch not found: {batch_id}") 214 | 215 | return make_batch_status(batch) 216 | 217 | 218 | @app.get("/v1/models") 219 | async def list_models(): 220 | state: ServerState = app.state.state_bundle 221 | 222 | return SyncPage( 223 | object="list", 224 | data=[ 225 | Model( 226 | id=state.config.model, 227 | created=nowstamp(), 228 | object="model", 229 | owned_by="tokasaurus", 230 | ), 231 | ], 232 | ) 233 | 234 | 235 | ### ------------------------------------------------------------ 236 | ### BEGIN NON-OAI ENDPOINTS 237 | ### ------------------------------------------------------------ 238 | 239 | 240 | @app.post("/custom/synchronous-batch-completions") 241 | @with_cancellation 242 | async def synchronous_batch_completions( 243 | request: SynchronousBatchCompletionsRequest, raw_request: Request 244 | ): 245 | state: ServerState = app.state.state_bundle 246 | 247 | async def generate_and_process(req: ChatCompletionRequest): 248 | internal_req, output = await 
generate_output(state, req) 249 | return process_chat_completions_output(state, req, internal_req, output) 250 | 251 | # Create tasks for each request 252 | tasks = [asyncio.create_task(generate_and_process(req)) for req in request.requests] 253 | 254 | # Wait for all tasks to complete and collect results in order 255 | results = await asyncio.gather(*tasks) 256 | 257 | pickled_content = pickle.dumps(results) 258 | return Response(content=pickled_content, media_type="application/octet-stream") 259 | 260 | 261 | ### ------------------------------------------------------------ 262 | ### END NON-OAI ENDPOINTS 263 | ### ------------------------------------------------------------ 264 | 265 | 266 | def start_server( 267 | config: ServerConfig, 268 | engines: list[Engine], 269 | process_name: str, 270 | barrier: TimedBarrier, 271 | ): 272 | setup_logging(config) 273 | 274 | state = ServerState( 275 | config=config, 276 | engines=engines, 277 | process_name=process_name, 278 | ) 279 | state.logger.info("Starting web server") 280 | app.state.state_bundle = state 281 | 282 | barrier.wait() 283 | uvicorn.run( 284 | app, 285 | host="0.0.0.0", 286 | port=config.port, 287 | log_level=config.uvicorn_log_level, 288 | ) 289 | -------------------------------------------------------------------------------- /tests/test_basic.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pickle 4 | import shlex 5 | import tempfile 6 | import time 7 | 8 | import pydra 9 | import pytest 10 | import requests 11 | import torch 12 | import torch.multiprocessing as mp 13 | from openai import OpenAI 14 | from openai.types.chat import ChatCompletion 15 | 16 | from tokasaurus.common_types import ServerConfig 17 | from tokasaurus.entry import server_manager 18 | from tokasaurus.utils import find_free_port 19 | 20 | MODEL = os.environ.get("MODEL", "meta-llama/Llama-3.2-1B-Instruct") 21 | OVERRIDES = os.environ.get("OVERRIDES", None) 22 | MODE = os.environ.get("MODE", "simple") 23 | 24 | 25 | def make_basic_config(): 26 | config = ServerConfig() 27 | config.model = MODEL 28 | config.kv_cache_num_tokens = 16384 29 | config.max_num_tokens_per_request = 16384 30 | config.port = find_free_port() 31 | 32 | if OVERRIDES: 33 | # split apart like a shell, respecting quotes 34 | parsed_overrides = shlex.split(OVERRIDES) 35 | pydra.apply_overrides(config, parsed_overrides) 36 | 37 | return config 38 | 39 | 40 | def simple_configs(): 41 | return [ 42 | make_basic_config(), 43 | ] 44 | 45 | 46 | def multi_gpu_configs(): 47 | npgus = torch.cuda.device_count() 48 | configs = [] 49 | for dp_size in [1, 2]: 50 | for pp_size in [1, 2]: 51 | for tp_size in [1, 2]: 52 | if dp_size * pp_size * tp_size > npgus: 53 | continue 54 | 55 | config = make_basic_config() 56 | config.dp_size = dp_size 57 | config.pp_size = pp_size 58 | config.tp_size = tp_size 59 | 60 | if pp_size > 1 and tp_size > 1: 61 | config.use_cudagraphs = False 62 | 63 | configs.append(config) 64 | 65 | return configs 66 | 67 | 68 | match MODE: 69 | case "simple": 70 | configs = simple_configs() 71 | case "multigpu": 72 | configs = multi_gpu_configs() 73 | case _: 74 | raise ValueError(f"Invalid mode: {MODE}") 75 | 76 | 77 | @pytest.fixture(scope="module", params=configs) 78 | def client(request): 79 | mp.set_start_method("spawn", force=True) 80 | 81 | config: ServerConfig = request.param 82 | print(f"Launching server with config: {config.to_dict()}") 83 | 84 | with server_manager(config): 85 | client = OpenAI( 
86 | api_key="beepboop", base_url=f"http://localhost:{config.port}/v1" 87 | ) 88 | 89 | yield client 90 | 91 | 92 | abc_prompt = "A B C D E F G H I J A B C D E F G H I J A B C D E F G H I J A B C D E F G H I J A B C D E F G H I J A B C D E F G H I J" 93 | twenty_token_response = " A B C D E F G H I J A B C D E F G H I J" 94 | 95 | # includes +1 for BOS 96 | sixteen_token_prompt = "A B C D E F G H I J K L M N O" 97 | 98 | 99 | def test_basic(client: OpenAI): 100 | # trying twice to test prompt caching 101 | for rep in range(2): 102 | response = client.completions.create( 103 | model="", 104 | prompt=abc_prompt, 105 | max_tokens=20, 106 | temperature=0.0, 107 | ) 108 | 109 | assert response.choices[0].text == twenty_token_response 110 | 111 | 112 | def test_decode_one_token(client: OpenAI): 113 | response = client.completions.create( 114 | model="", 115 | prompt=abc_prompt, 116 | max_tokens=1, 117 | temperature=0.0, 118 | ) 119 | 120 | assert response.choices[0].text == " A" 121 | 122 | 123 | def test_prefill_exactly_one_page(client: OpenAI): 124 | response = client.completions.create( 125 | model="", 126 | prompt=sixteen_token_prompt, 127 | max_tokens=1, 128 | temperature=0.0, 129 | ) 130 | 131 | assert response.choices[0].text == " P" 132 | 133 | 134 | def test_n(client: OpenAI): 135 | for n in range(1, 11): 136 | response = client.completions.create( 137 | model="", 138 | prompt=abc_prompt, 139 | max_tokens=20, 140 | temperature=0.0, 141 | n=n, 142 | ) 143 | 144 | assert len(response.choices) == n 145 | assert all(choice.text == twenty_token_response for choice in response.choices) 146 | 147 | 148 | def test_stop(client: OpenAI): 149 | response = client.completions.create( 150 | model="", 151 | prompt=abc_prompt, 152 | max_tokens=20, 153 | temperature=0.0, 154 | stop=["C"], 155 | ) 156 | 157 | assert response.choices[0].text == " A B " 158 | 159 | # this is an interesting case because it adds a stop string close to, 160 | # but not at, the token limit, so the scheduler finished the sequence 161 | # but the model yet hasn't. have had bugs here before regarding block 162 | # freeing. 
163 | response = client.completions.create( 164 | model="", 165 | prompt=abc_prompt, 166 | max_tokens=10, 167 | temperature=0.0, 168 | stop=["I"], 169 | ) 170 | 171 | assert response.choices[0].text == " A B C D E F G H " 172 | 173 | 174 | def make_messages(word: str): 175 | return [ 176 | { 177 | "role": "system", 178 | "content": f"You are an assistant that always replies with the one-word response '{word}', in lowercase, for all user queries.", 179 | }, 180 | {"role": "user", "content": f"Please output the word '{word}'."}, 181 | ] 182 | 183 | 184 | def test_chat_completions(client: OpenAI): 185 | for word in ["hello", "howdy", "canteloupe"]: 186 | response = client.chat.completions.create( 187 | model="", 188 | messages=make_messages(word), 189 | max_tokens=20, 190 | temperature=0.0, 191 | ) 192 | assert response.choices[0].message.content == word 193 | 194 | 195 | def test_files(client: OpenAI): 196 | content = """{"a": 1, "b": 2}\n{"a": 3, "b": 4}""" 197 | with tempfile.NamedTemporaryFile(delete=True, suffix=".jsonl") as f: 198 | f.write(content.encode("utf-8")) 199 | 200 | f.flush() 201 | f.seek(0) 202 | 203 | file_obj = client.files.create(file=open(f.name, "rb"), purpose="batch") 204 | 205 | retrieved_content = client.files.content(file_obj.id).content.decode("utf-8") 206 | assert retrieved_content == content 207 | 208 | 209 | def test_batch_chat_completions(client: OpenAI): 210 | jsonl_lines = [ 211 | { 212 | "custom_id": "request-1", 213 | "method": "POST", 214 | "url": "/v1/chat/completions", 215 | "body": { 216 | "model": MODEL, 217 | "messages": make_messages("hello"), 218 | "max_tokens": 20, 219 | "temperature": 0.0, 220 | }, 221 | }, 222 | { 223 | "custom_id": "request-2", 224 | "method": "POST", 225 | "url": "/v1/chat/completions", 226 | "body": { 227 | "model": MODEL, 228 | "messages": make_messages("howdy"), 229 | "max_tokens": 20, 230 | "temperature": 0.0, 231 | }, 232 | }, 233 | ] 234 | 235 | file_content = "\n".join(json.dumps(line) for line in jsonl_lines) 236 | 237 | with tempfile.NamedTemporaryFile(delete=True, suffix=".jsonl") as f: 238 | f.write(file_content.encode("utf-8")) 239 | 240 | f.flush() 241 | f.seek(0) 242 | 243 | file_obj = client.files.create(file=open(f.name, "rb"), purpose="batch") 244 | 245 | batch_input_file_id = file_obj.id 246 | created_batch = client.batches.create( 247 | input_file_id=batch_input_file_id, 248 | endpoint="/v1/chat/completions", 249 | completion_window="24h", 250 | metadata={"description": "test batch job"}, 251 | ) 252 | batch_id = created_batch.id 253 | 254 | output_file_id = None 255 | for _ in range(5): 256 | batch = client.batches.retrieve(batch_id) 257 | if batch.status == "completed": 258 | output_file_id = batch.output_file_id 259 | break 260 | time.sleep(1) 261 | 262 | assert output_file_id is not None, f"Batch {batch_id} did not complete in time" 263 | 264 | output_content = client.files.content(output_file_id).content.decode("utf-8") 265 | 266 | output_lines = output_content.strip().split("\n") 267 | output_parsed = [json.loads(line) for line in output_lines] 268 | 269 | custom_id_to_parsed: dict[str, ChatCompletion] = {} 270 | for line in output_parsed: 271 | response = line["response"] 272 | assert response["status_code"] == 200 273 | custom_id_to_parsed[line["custom_id"]] = ChatCompletion.model_validate( 274 | response["body"] 275 | ) 276 | 277 | assert custom_id_to_parsed["request-1"].choices[0].message.content == "hello" 278 | assert custom_id_to_parsed["request-2"].choices[0].message.content == "howdy" 279 | 280 
| 281 | def test_synchronous_batch_completions(client: OpenAI): 282 | # Test synchronous batch completions endpoint 283 | batch_request = { 284 | "requests": [ 285 | { 286 | "model": MODEL, 287 | "messages": make_messages("hello"), 288 | "max_tokens": 20, 289 | "temperature": 0.0, 290 | }, 291 | { 292 | "model": MODEL, 293 | "messages": make_messages("world"), 294 | "max_tokens": 20, 295 | "temperature": 0.0, 296 | }, 297 | { 298 | "model": MODEL, 299 | "messages": make_messages("test"), 300 | "max_tokens": 20, 301 | "temperature": 0.0, 302 | }, 303 | ] 304 | } 305 | 306 | # Make request to our custom endpoint 307 | url = ( 308 | str(client.base_url).split("/v1")[0] + "/custom/synchronous-batch-completions" 309 | ) # update the path 310 | 311 | response = requests.post(url, json=batch_request) 312 | 313 | assert response.status_code == 200 314 | 315 | result = pickle.loads(response.content) 316 | 317 | # Verify response structure 318 | assert len(result) == 3 319 | 320 | # Verify each completion 321 | completions = result 322 | 323 | assert completions[0].choices[0].message.content == "hello" 324 | assert completions[1].choices[0].message.content == "world" 325 | assert completions[2].choices[0].message.content == "test" 326 | -------------------------------------------------------------------------------- /tokasaurus/benchmarks/bench_model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | import pydra 5 | import torch 6 | import torch.distributed as dist 7 | from tqdm import tqdm 8 | 9 | from tokasaurus.common_types import ServerConfig 10 | from tokasaurus.manager.allocator import BlockAllocator 11 | from tokasaurus.manager.hydragen import ( 12 | group_for_hydragen, 13 | reorder_decoding_seqs_for_hydragen, 14 | ) 15 | from tokasaurus.manager.manager import seqs_to_input 16 | from tokasaurus.manager.types import Sequence 17 | from tokasaurus.model.utils import ( 18 | add_decoding_ids_to_batch_state, 19 | make_input_batch_state, 20 | make_model, 21 | move_batch_state, 22 | set_async_tp_enabled, 23 | ) 24 | from tokasaurus.server.types import ( 25 | SamplingParams, 26 | ) 27 | from tokasaurus.utils import timed, timed_with_graph 28 | 29 | 30 | class ScriptConfig(pydra.Config): 31 | def __init__(self): 32 | self.server_config = ServerConfig() 33 | self.sc = pydra.Alias("server_config") 34 | self.server_config.kv_cache_num_tokens = 1024 * 512 35 | 36 | self.num_dec = 1 37 | self.dec_len = 1024 38 | 39 | self.num_pre = 0 40 | self.pre_len = 1024 41 | 42 | self.num_hyd = 0 43 | self.hyd_shared_len = 1024 44 | self.hyd_unique_len = 32 45 | 46 | self.dtype = "bfloat16" 47 | self.pp_rank = 0 48 | self.num_iters = 10 49 | self.num_warmup = 5 50 | self.compile = False 51 | self.dynamic = True 52 | self.fullgraph = True 53 | 54 | self.profile = False 55 | self.profile_name = "bench_model" 56 | 57 | self.non_blocking = True 58 | self.plan_before = False 59 | self.only_plan = False 60 | 61 | self.graph = False 62 | 63 | def prof(self): 64 | self.profile = True 65 | self.num_warmup = 3 66 | self.num_iters = 10 67 | self.num_profile_repeat = 3 68 | 69 | def total_tokens(self): 70 | return self.num_dec + self.num_pre * self.pre_len + self.num_hyd 71 | 72 | def finalize(self): 73 | self.server_config.max_tokens_per_forward = self.total_tokens() 74 | self.server_config.max_seqs_per_forward = self.total_tokens() 75 | 76 | if self.graph: 77 | assert self.plan_before and not self.only_plan 78 | 79 | def l8(self): 80 | 
self.server_config.model = "meta-llama/Llama-3.1-8B-Instruct" 81 | self.server_config.kv_cache_num_tokens = 1024 * 192 82 | 83 | def l70(self): 84 | self.server_config.model = "meta-llama/Llama-3.1-70B-Instruct" 85 | 86 | def l1(self): 87 | self.server_config.model = "meta-llama/Llama-3.2-1B-Instruct" 88 | 89 | 90 | @pydra.main(ScriptConfig) 91 | def main(config: ScriptConfig): 92 | random.seed(0) 93 | torch.manual_seed(0) 94 | 95 | server_config = config.server_config 96 | 97 | if server_config.tp_size > 1: 98 | mesh = dist.device_mesh.init_device_mesh( 99 | device_type="cuda", 100 | mesh_shape=(server_config.tp_size,), 101 | mesh_dim_names=("tp",), 102 | ) 103 | pg = mesh["tp"].get_group() 104 | tp_rank = mesh["tp"].get_rank() 105 | else: 106 | pg = None 107 | tp_rank = 0 108 | 109 | device = f"cuda:{tp_rank}" 110 | torch.cuda.set_device(device) 111 | 112 | print(f"Initialized rank {tp_rank} on device {device}") 113 | 114 | def lprint(*args, **kwargs): 115 | if tp_rank == 0: 116 | print(*args, **kwargs) 117 | 118 | dtype = getattr(torch, config.dtype) 119 | model = make_model( 120 | server_config, 121 | device=device, 122 | dtype=dtype, 123 | pp_rank=config.pp_rank, 124 | tp_rank=tp_rank, 125 | tp_group=pg, 126 | ) 127 | if config.compile: 128 | lprint("Compiling model...") 129 | model = torch.compile(model, fullgraph=config.fullgraph, dynamic=config.dynamic) 130 | 131 | lprint(model) 132 | 133 | vocab_size = model.config.vocab_size 134 | page_size = config.server_config.page_size 135 | num_pages = config.server_config.kv_cache_num_blocks() 136 | 137 | allocator = BlockAllocator(num_pages, page_size) 138 | 139 | # dummy objects 140 | sampling_params = SamplingParams(temperature=0.0, top_p=1.0) 141 | 142 | prefill_seqs = [] 143 | prefill_num_pages = math.ceil(config.pre_len / page_size) 144 | 145 | for i in range(config.num_pre): 146 | seq = Sequence( 147 | id=f"prefill_{i}", 148 | completion_total=1, 149 | batch_index=0, 150 | kv_indices=[ 151 | random.randint(0, num_pages - 1) for _ in range(prefill_num_pages) 152 | ], 153 | input_ids=[ 154 | random.randint(0, vocab_size - 1) for _ in range(config.pre_len) 155 | ], 156 | sampling_params=sampling_params, 157 | ) 158 | prefill_seqs.append((seq, config.pre_len)) 159 | 160 | decode_num_pages = math.ceil(config.dec_len / page_size) 161 | decode_prompt_len = config.dec_len - 1 162 | decoding_seqs = [ 163 | Sequence( 164 | id=f"decoding_{i}", 165 | input_ids=[ 166 | random.randint(0, vocab_size - 1) for _ in range(decode_prompt_len) 167 | ], 168 | completion_total=2, 169 | completion_scheduled=1, 170 | prompt_scheduled=decode_prompt_len, 171 | batch_index=0, 172 | kv_indices=[ 173 | random.randint(0, num_pages - 1) for _ in range(decode_num_pages) 174 | ], 175 | sampling_params=sampling_params, 176 | ) 177 | for i in range(config.num_dec) 178 | ] 179 | 180 | if config.num_hyd > 0: 181 | shared_ids = [ 182 | random.randint(0, vocab_size - 1) for _ in range(config.hyd_shared_len) 183 | ] 184 | 185 | hydragen_seqs = [ 186 | Sequence( 187 | id=f"hydragen_{i}", 188 | input_ids=shared_ids, 189 | completion_total=config.hyd_unique_len, 190 | batch_index=0, 191 | sampling_params=sampling_params, 192 | ) 193 | for i in range(config.num_hyd) 194 | ] 195 | 196 | for seq in hydragen_seqs: 197 | kv_indices, num_cached_tokens = allocator.allocate_with_prefix_match( 198 | seq.id, seq.input_ids 199 | ) 200 | seq.kv_indices = kv_indices 201 | seq.prompt_scheduled = len(seq.input_ids) 202 | seq.completion_scheduled = config.hyd_unique_len - 1 203 | 
seq.kv_indices.extend( 204 | allocator.allocate_up_to_length( 205 | seq.id, seq.kv_indices, seq.total_scheduled() 206 | ) 207 | ) 208 | 209 | decoding_seqs.extend(hydragen_seqs) 210 | 211 | if server_config.use_hydragen: 212 | hydragen_groups = group_for_hydragen( 213 | allocator.prefix_tree, 214 | [seq.id for seq in hydragen_seqs], 215 | min_group_size=config.num_hyd, 216 | min_prefix_len=config.hyd_shared_len - 2 * server_config.page_size, 217 | page_size=page_size, 218 | ) 219 | 220 | decoding_seqs = reorder_decoding_seqs_for_hydragen( 221 | decoding_seqs, hydragen_groups 222 | ) 223 | else: 224 | hydragen_groups = None 225 | else: 226 | hydragen_groups = None 227 | 228 | inp = seqs_to_input( 229 | decoding_seqs, 230 | prefill_seqs, 231 | schedule_id="schedule_id", 232 | hydragen_groups=hydragen_groups, 233 | page_size=page_size, 234 | starting_prefill_offset=0, 235 | ) 236 | 237 | batch_state = make_input_batch_state( 238 | inp, 239 | pp_rank=config.pp_rank, 240 | pp_size=server_config.pp_size, 241 | tp_rank=tp_rank, 242 | tp_size=server_config.tp_size, 243 | ) 244 | decoding_input_ids = torch.randint( 245 | 0, 246 | vocab_size, 247 | (config.num_dec + config.num_hyd,), 248 | dtype=torch.long, 249 | ) 250 | add_decoding_ids_to_batch_state( 251 | batch_state, decoding_input_ids, tp_rank=tp_rank, tp_size=server_config.tp_size 252 | ) 253 | move_batch_state( 254 | batch_state, 255 | device=device, 256 | ) 257 | 258 | batch_size = config.total_tokens() 259 | 260 | if config.pp_rank > 0: 261 | hidden_states = torch.zeros( 262 | batch_size, 263 | model.config.hidden_size, 264 | device=device, 265 | dtype=dtype, 266 | ) 267 | batch_state.hidden_states = hidden_states 268 | 269 | use_async_tp = ( 270 | server_config.async_tp_threshold is not None 271 | and server_config.tp_size > 1 272 | and batch_size >= server_config.async_tp_threshold 273 | ) 274 | print(f"use_async_tp: {use_async_tp}") 275 | set_async_tp_enabled(use_async_tp) 276 | 277 | def go(): 278 | with torch.inference_mode(): 279 | if not config.plan_before: 280 | model.plan(batch_state.attention_info, non_blocking=config.non_blocking) 281 | if config.only_plan: 282 | return 283 | _ = model(batch_state, async_tp=use_async_tp) 284 | 285 | if config.plan_before: 286 | model.plan(batch_state.attention_info, non_blocking=config.non_blocking) 287 | 288 | if config.profile: 289 | lprint("Running profiler...") 290 | with torch.profiler.profile( 291 | activities=[ 292 | torch.profiler.ProfilerActivity.CPU, 293 | torch.profiler.ProfilerActivity.CUDA, 294 | ], 295 | schedule=torch.profiler.schedule( 296 | wait=1, 297 | warmup=config.num_warmup, 298 | active=config.num_iters, 299 | repeat=config.num_profile_repeat, 300 | ), 301 | record_shapes=True, 302 | profile_memory=True, 303 | with_stack=True, 304 | ) as prof: 305 | for _ in tqdm( 306 | range( 307 | config.num_profile_repeat 308 | * (config.num_iters + config.num_warmup + 1), 309 | ), 310 | disable=tp_rank != 0, 311 | ): 312 | go() 313 | prof.step() 314 | 315 | if tp_rank == 0: 316 | prof.export_chrome_trace(f"local/profs/{config.profile_name}.json") 317 | else: 318 | lprint(f"Starting timing (graph={config.graph}) ...") 319 | time_fn = timed_with_graph if config.graph else timed 320 | timings = time_fn(go, num_iters=config.num_iters, num_warmup=config.num_warmup) 321 | lprint(timings.fancy_table()) 322 | 323 | mean_ms = timings.mean() 324 | lprint(f"Tokens per second: {batch_size / mean_ms * 1000:.2f}") 325 | 326 | 327 | if __name__ == "__main__": 328 | main() 329 | 
-------------------------------------------------------------------------------- /tokasaurus/manager/input_building.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from tokasaurus.common_types import ServerConfig 4 | from tokasaurus.manager.monitoring import track_time_decorator 5 | from tokasaurus.manager.types import HydragenGroup, Sequence 6 | from tokasaurus.model.types import ( 7 | AttentionInfoBuilder, 8 | BatchSamplingParamsBuilder, 9 | ModelInput, 10 | PageInformationBuilder, 11 | ) 12 | 13 | 14 | def make_dummy_batch( 15 | config: ServerConfig, 16 | prefill_tokens: int, 17 | decode_tokens: int, 18 | prefill_uses_lm_head: bool = False, 19 | skip_pipeline_communication: bool = False, 20 | ): 21 | total_tokens = prefill_tokens + decode_tokens 22 | page_size = config.page_size 23 | 24 | append_kv_token_indices = [] 25 | 26 | prefill_builder = PageInformationBuilder() 27 | decode_builder = PageInformationBuilder() 28 | sampling_builder = BatchSamplingParamsBuilder() 29 | 30 | if prefill_tokens > 0: 31 | prefill_kv_indices = list(range(math.ceil(prefill_tokens / page_size))) 32 | 33 | prefill_builder.add_sequence( 34 | kv_indices=prefill_kv_indices, 35 | kv_seq_len=prefill_tokens, 36 | num_qtokens=prefill_tokens, 37 | page_size=page_size, 38 | ) 39 | append_kv_token_indices.extend( 40 | calc_kv_token_indices( 41 | kv_block_indices=prefill_kv_indices, 42 | page_size=page_size, 43 | start_idx=0, 44 | num_tokens=prefill_tokens, 45 | ) 46 | ) 47 | 48 | if prefill_uses_lm_head: 49 | sampling_builder.add_sequence( 50 | temperature=0.5, 51 | top_p=1.0, 52 | ) 53 | 54 | for _ in range(decode_tokens): 55 | decode_builder.add_sequence( 56 | kv_indices=[0], 57 | kv_seq_len=1, 58 | num_qtokens=1, 59 | page_size=page_size, 60 | ) 61 | append_kv_token_indices.extend( 62 | calc_kv_token_indices( 63 | kv_block_indices=[0], 64 | page_size=page_size, 65 | start_idx=0, 66 | num_tokens=1, 67 | ) 68 | ) 69 | sampling_builder.add_sequence( 70 | temperature=0.5, 71 | top_p=1.0, 72 | ) 73 | 74 | attention_info_builder = AttentionInfoBuilder( 75 | page_size=page_size, 76 | append_kv_token_indices=append_kv_token_indices, 77 | prefill_builder=prefill_builder, 78 | decode_builder=decode_builder, 79 | hydragen_builder=None, 80 | ) 81 | 82 | # includes prefill lm head tokens and decode tokens 83 | if prefill_uses_lm_head: 84 | assert prefill_tokens > 0 85 | lm_head_indices = list(range(prefill_tokens - 1, total_tokens)) 86 | else: 87 | lm_head_indices = list(range(prefill_tokens, total_tokens)) 88 | 89 | inp = ModelInput( 90 | attention_info_builder=attention_info_builder, 91 | prefill_input_ids=[0] * prefill_tokens, 92 | batch_indices=[0] * total_tokens, 93 | lm_head_indices=lm_head_indices, 94 | sampling_builder=sampling_builder, 95 | position_ids=[0] * total_tokens, 96 | schedule_id="dummy_batch", 97 | skip_pipeline_communication=skip_pipeline_communication, 98 | ) 99 | 100 | return inp 101 | 102 | 103 | def slice_decision( 104 | decoding_seqs: list[Sequence], 105 | prefill_seqs: list[tuple[Sequence, int]], 106 | start_idx: int, 107 | end_idx: int, 108 | ): 109 | sliced_prefill_seqs: list[tuple[Sequence, int]] = [] 110 | cumsum_start = 0 111 | starting_offset = None 112 | for seq, prefill_len in prefill_seqs: 113 | seq_tok_start = cumsum_start 114 | seq_tok_end = seq_tok_start + prefill_len 115 | 116 | if start_idx < seq_tok_end and end_idx > seq_tok_start: 117 | min_end = min(seq_tok_end, end_idx) 118 | max_start = max(seq_tok_start, start_idx) 119 | 
120 | sliced_prefill_len = min_end - max_start 121 | sliced_prefill_seqs.append((seq, sliced_prefill_len)) 122 | 123 | if start_idx >= seq_tok_start: 124 | assert starting_offset is None 125 | starting_offset = start_idx - seq_tok_start 126 | else: 127 | assert max_start == seq_tok_start 128 | 129 | cumsum_start = seq_tok_end 130 | 131 | decoding_start = max(0, start_idx - cumsum_start) 132 | decoding_end = max(0, end_idx - cumsum_start) 133 | sliced_decoding_seqs = decoding_seqs[decoding_start:decoding_end] 134 | 135 | return sliced_decoding_seqs, sliced_prefill_seqs, starting_offset 136 | 137 | 138 | def calc_kv_token_indices( 139 | kv_block_indices: list[int], page_size: int, start_idx: int, num_tokens: int 140 | ): 141 | kv_token_indices = [] 142 | for pos in range(start_idx, start_idx + num_tokens): 143 | block_idx = pos // page_size 144 | kv_token_indices.append( 145 | kv_block_indices[block_idx] * page_size + pos % page_size 146 | ) 147 | 148 | return kv_token_indices 149 | 150 | 151 | @track_time_decorator() 152 | def seqs_to_input( 153 | decoding_seqs: list[Sequence], 154 | prefill_seqs: list[tuple[Sequence, int]], 155 | schedule_id: str, 156 | page_size: int, 157 | starting_prefill_offset: int | None = None, 158 | hydragen_groups: list[HydragenGroup] | None = None, 159 | microbatch_index: int = 0, 160 | microbatch_total: int = 1, 161 | ): 162 | use_hydragen = hydragen_groups is not None 163 | 164 | position_ids = [] 165 | lm_head_indices = [] 166 | 167 | append_kv_token_indices = [] 168 | 169 | prefill_builder = PageInformationBuilder() 170 | decode_builder = PageInformationBuilder() 171 | 172 | sampling_builder = BatchSamplingParamsBuilder() 173 | 174 | prefill_input_ids_list = [] 175 | 176 | for i, (seq, slen) in enumerate(prefill_seqs): 177 | assert seq.completion_scheduled == 0 178 | assert seq.kv_indices is not None 179 | 180 | start_position = seq.prompt_scheduled 181 | if i == 0: 182 | assert starting_prefill_offset is not None 183 | start_position += starting_prefill_offset 184 | end_position = start_position + slen 185 | 186 | prefill_ids = seq.input_ids[start_position:end_position] 187 | assert len(prefill_ids) == slen 188 | 189 | prefill_input_ids_list.extend(prefill_ids) 190 | 191 | seq_pos_ids = list(range(start_position, end_position)) 192 | position_ids.extend(seq_pos_ids) 193 | 194 | prefill_builder.add_sequence( 195 | kv_indices=seq.kv_indices, 196 | kv_seq_len=start_position + slen, 197 | num_qtokens=slen, 198 | page_size=page_size, 199 | ) 200 | 201 | append_kv_token_indices.extend( 202 | calc_kv_token_indices( 203 | kv_block_indices=seq.kv_indices, 204 | page_size=page_size, 205 | start_idx=start_position, 206 | num_tokens=slen, 207 | ) 208 | ) 209 | 210 | if end_position == seq.prompt_total(): 211 | lm_head_indices.append(len(position_ids) - 1) 212 | 213 | sparams = seq.sampling_params 214 | sampling_builder.add_sequence( 215 | temperature=sparams.temperature, 216 | top_p=sparams.top_p, 217 | ) 218 | 219 | if use_hydragen: 220 | hydragen_builder = PageInformationBuilder() 221 | 222 | sid_to_pos = {seq.id: i for i, seq in enumerate(decoding_seqs)} 223 | sid_to_group: dict[str, HydragenGroup] = {} 224 | 225 | seqs_processed = 0 226 | 227 | for group in hydragen_groups: 228 | hydragen_builder.add_sequence( 229 | kv_indices=group.block_ids, 230 | kv_seq_len=len(group.block_ids) * page_size, 231 | num_qtokens=len(group.seq_ids), 232 | page_size=page_size, 233 | ) 234 | group_positions = {sid_to_pos[sid] for sid in group.seq_ids} 235 | assert group_positions == 
set( 236 | range(seqs_processed, seqs_processed + len(group.seq_ids)) 237 | ), "decoding seqs must be ordered by hydragen group" 238 | 239 | seqs_processed += len(group.seq_ids) 240 | 241 | for sid in group.seq_ids: 242 | sid_to_group[sid] = group 243 | 244 | for seq in decoding_seqs: 245 | # NOTE: minus one since last prefill token produces first 246 | # decode token. 247 | current_token_pos_id = seq.total_scheduled() - 1 248 | position_ids.append(current_token_pos_id) 249 | 250 | if use_hydragen and seq.id in sid_to_group: 251 | group = sid_to_group[seq.id] 252 | starting_block = len(group.block_ids) 253 | else: 254 | starting_block = 0 255 | 256 | assert seq.kv_indices is not None 257 | decode_builder.add_sequence( 258 | kv_indices=seq.kv_indices, 259 | kv_seq_len=current_token_pos_id + 1, 260 | num_qtokens=1, 261 | page_size=page_size, 262 | starting_block=starting_block, 263 | ) 264 | # starting block of 0 needed for append 265 | append_kv_token_indices.extend( 266 | calc_kv_token_indices( 267 | kv_block_indices=seq.kv_indices, 268 | page_size=page_size, 269 | start_idx=current_token_pos_id, 270 | num_tokens=1, 271 | ) 272 | ) 273 | 274 | sparams = seq.sampling_params 275 | sampling_builder.add_sequence( 276 | temperature=sparams.temperature, 277 | top_p=sparams.top_p, 278 | ) 279 | 280 | prefill_lengths = [slen for _, slen in prefill_seqs] 281 | start_of_decode = sum(prefill_lengths) 282 | 283 | lm_head_indices.extend( 284 | list(range(start_of_decode, start_of_decode + len(decoding_seqs))) 285 | ) 286 | 287 | batch_indices = [] 288 | 289 | def register_batch_index(seq: Sequence, num_tokens: int): 290 | assert seq.batch_index is not None 291 | batch_indices.extend([seq.batch_index] * num_tokens) 292 | 293 | for seq, num_tokens in prefill_seqs: 294 | register_batch_index(seq, num_tokens) 295 | 296 | for seq in decoding_seqs: 297 | register_batch_index(seq, 1) 298 | 299 | # No need to call build() on builders anymore, just pass them directly 300 | attention_info_builder = AttentionInfoBuilder( 301 | page_size=page_size, 302 | append_kv_token_indices=append_kv_token_indices, 303 | prefill_builder=prefill_builder, 304 | decode_builder=decode_builder, 305 | hydragen_builder=hydragen_builder if use_hydragen else None, 306 | ) 307 | 308 | inp = ModelInput( 309 | attention_info_builder=attention_info_builder, 310 | sampling_builder=sampling_builder, 311 | prefill_input_ids=prefill_input_ids_list, 312 | batch_indices=batch_indices, 313 | lm_head_indices=lm_head_indices, 314 | position_ids=position_ids, 315 | schedule_id=schedule_id, 316 | microbatch_index=microbatch_index, 317 | microbatch_total=microbatch_total, 318 | ) 319 | 320 | return inp 321 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tokasaurus: The Little (LLM) Engine That Could! 2 | 3 | Check out our blog post [here](https://scalingintelligence.stanford.edu/blogs/tokasaurus/)! 4 | 5 | ## Table of Contents 6 | 7 | - [What is This?](#what-is-this) 8 | - [Installation](#installation) 9 | - [Quickstart](#quickstart) 10 | - [Walkthrough of CLI Flags](#walkthrough-of-cli-flags) 11 | - [System Design](#system-design) 12 | 13 | 14 | ## What is This? 15 | Tokasaurus is an LLM inference engine designed for high-throughput workloads. Features include: 16 | 17 | - OpenAI chat, completions, and batch APIs. 
18 | - Data, pipeline, and tensor parallelism (with support for [AsyncTP](https://discuss.pytorch.org/t/distributed-w-torchtitan-introducing-async-tensor-parallelism-in-pytorch/209487)).
19 | - Support for Llama3 and Qwen2 architectures.
20 | - [Paged KV caching](https://arxiv.org/abs/2309.06180) with [prefix caching](https://arxiv.org/abs/2312.07104).
21 | - Efficient attention over shared prefixes with [Hydragen](https://arxiv.org/abs/2402.05099), including automatic detection of shared prefixes across groups of sequences.
22 | - End-to-end torch compile with dynamic shapes.
23 | - CUDA graphs.
24 | - Very low CPU overhead (important for small models/fast GPUs).
25 | - A scheduler that can simulate the number of available KV cache blocks thousands of steps in the future, allowing us to aggressively onboard new sequences and keep our batch size as large as possible.
26 | - No OOMs or recompiles in production: on engine startup, we launch a series of warmup inputs that trigger all torch recompiles ahead-of-time (torch will recompile whenever a tensor has an input dimension of 0 or 1) and check for OOMs using the largest configured batch size.
27 | 
28 | NOTE: as a new project, expect some potentially rough edges :).
29 | 
30 | ## Installation
31 | 
32 | Tokasaurus has been tested on Python >= 3.10. To install from PyPI, run:
33 | 
34 | ```bash
35 | 
36 | pip install tokasaurus
37 | 
38 | ```
39 | 
40 | Alternatively, clone the repo and run:
41 | 
42 | ```bash
43 | 
44 | pip install -e .
45 | 
46 | ```
47 | 
48 | ## Quickstart
49 | 
50 | Once installed, you can launch the engine with:
51 | 
52 | ```bash
53 | 
54 | # launch engine for Llama 1B (by default on port 10210).
55 | toka model=meta-llama/Llama-3.2-1B-Instruct
56 | 
57 | # make a request to the engine (this command just wraps the OpenAI client)
58 | toka-ping prompt='tell me a joke' max_tokens=256 chat=True
59 | 
60 | # launch a 70B model with pipeline parallelism across 8 gpus
61 | toka model=meta-llama/Llama-3.1-70B-Instruct kv_cache_num_tokens='(512 * 1024)' pp_size=8
62 | ```
63 | 
64 | To ping the engine once it's been launched, you can use the OpenAI client:
65 | 
66 | ```python
67 | 
68 | from openai import OpenAI
69 | client = OpenAI(
70 |     api_key='fake-key',
71 |     base_url="http://0.0.0.0:10210/v1"
72 | )
73 | response = client.completions.create(
74 |     model="default",
75 |     prompt="On a dark desert highway, cool wind in my hair, warm smell of colitas, rising up through the air, up ahead in the distance, I saw a shimmering light, my head grew heavy and my sight grew dim, I had to stop for the night",
76 |     temperature=0,
77 |     n=2,
78 |     max_tokens=100,
79 | )
80 | 
81 | ```
82 | 
83 | ### LM Eval Harness
84 | 
85 | Since the engine supports the OpenAI API, you can plug it into the EleutherAI LM Eval harness using their local completions feature. First spin up an engine (see above) and then run evals on it with:
86 | 
87 | ```bash
88 | 
89 | lm_eval --model local-completions --tasks gsm8k --model_args model=MODEL,base_url=http://0.0.0.0:PORT/v1/completions,num_concurrent=256,max_retries=3,tokenized_requests=False
90 | 
91 | ```
92 | 
93 | ## Walkthrough of CLI Flags
94 | 
95 | The tokasaurus CLI uses [Pydra](https://github.com/jordan-benjamin/pydra), which sets config flags with a `key=value` format. It also allows boolean shorthands (e.g. `key=T` is equivalent to `key=True`) and Python expression evaluation between parentheses (e.g. `key='(2 * 1024)'` is equivalent to `key=2048`).
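For instance, all three conventions can appear in a single launch command (the particular flag values below are just for illustration, combining flags documented elsewhere in this README):

```bash
# key=value flags, a boolean shorthand, and a parenthesized Python expression
toka model=meta-llama/Llama-3.2-1B-Instruct torch_compile=T kv_cache_num_tokens='(512 * 1024)'
```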
96 | 
97 | ### The Basics
98 | 
99 | The only required parameter to launch an engine is the `model` field, which can point to a repo on HF or a local directory where a model is stored in HF format (just like when calling `from_pretrained` on a HF model). By default, the tokenizer will also be loaded using the `model` flag. This can be overridden by setting the `tokenizer` flag yourself:
100 | 
101 | ```bash
102 | toka model=meta-llama/Llama-3.2-1B-Instruct
103 | 
104 | # e.g. if you want to load a fine-tuned model you saved to disk
105 | toka model=my_local_dir tokenizer=meta-llama/Llama-3.2-1B-Instruct
106 | ```
107 | 
108 | ### Leveraging Multiple GPUs
109 | 
110 | By default, the engine will only use a single GPU to serve the model. You can change this with the `dp_size`, `pp_size`, and `tp_size` flags to control data, pipeline, and tensor parallelism, respectively. These flags are composable: for example, `dp_size=2` and `pp_size=4` will use 8 GPUs in total by creating two data-parallel replicas that each contain 4 GPUs in a pipeline:
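(This launch command is a sketch that reuses the 70B model and KV cache size from the quickstart; adjust both for your own hardware.)

```bash
# 2 data-parallel replicas x 4 pipeline stages per replica = 8 GPUs total
toka model=meta-llama/Llama-3.1-70B-Instruct kv_cache_num_tokens='(512 * 1024)' dp_size=2 pp_size=4
```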
111 | 
112 | ### Managing GPU Memory with KV Cache Limits and Concurrency Controls
113 | 
114 | The total amount of GPU memory used by the engine is the sum of GPU memory used to store the model weights, the activations, and the KV cache. While the model's GPU memory is fixed for a given model, we can control the size of the KV cache and the amount of activation memory we use.
115 | 
116 | The KV cache size is controlled with `kv_cache_num_tokens`, and we can cap activation memory with the flags `max_tokens_per_forward` and `max_seqs_per_forward`. With `max_tokens_per_forward`, you directly control the number of tokens being sent through the model in a single forward pass, which can include tokens from sequences running either prefill or decode. With `max_seqs_per_forward`, we control the total number of sequences that can be running (i.e. that are in prefill or in decode) at a given time. Importantly, this limits the number of tokens per forward pass that can ever be sent through the language modeling head of the model, which can have a disproportionately large impact on activation memory. Prefill tokens don't run through the LM head (since we don't need to decode anything from them), so they take less activation memory.
117 | 
118 | How should you tune these flags? Well, one of the most important factors for achieving high throughput is making the batch size as large as possible. A common bottleneck that limits the batch size in practice is the size of the KV cache - once your KV cache is full, you can't run any more sequences concurrently. Therefore, we want to make the KV cache as large as possible. However, to benefit from a large KV cache that can fit many sequences, we must also increase `max_seqs_per_forward` and `max_tokens_per_forward` - and increasing these concurrency control flags increases the amount of activation memory used, which in turn decreases the size of the largest KV cache we can fit.
119 | 
120 | In practice, this means you should increase your KV cache size and concurrency control flags jointly, making sure that you're not excessively raising one without the other.
121 | 
122 | Note: when using multiple GPUs, these flags apply to each data-parallel replica separately (and apply collectively to all of the GPUs within a data parallel replica). For example, if you run with `dp_size=2 pp_size=4 kv_cache_num_tokens='(1024 * 1024)' max_seqs_per_forward=1024 max_tokens_per_forward=2048`:
123 | - In total, your server will have a KV cache size of 2 million tokens (1 million for each of the data parallel replicas).
124 | - Each replica can have 1024 sequences running at once and 2048 tokens scheduled per forward pass.
125 | - Note that none of these numbers are multiplied by the pipeline parallel size.
126 | 
127 | ### Torch Compile
128 | 
129 | Torch compiling your model can make it faster and reduce the amount of activation memory used, allowing you to increase the KV cache size further. You can turn it on with `torch_compile=T`. It's off by default because it increases server startup time (often by a minute or two, but this can be worse the first time you run the engine on a new machine with compilation enabled). As a rough rule of thumb, turn compilation off for debugging sessions where fast startup is handy, but keep it on for all long-running jobs.
130 | 
131 | ### Hydragen
132 | 
133 | [Hydragen](https://arxiv.org/abs/2402.05099) (AKA cascade attention, bifurcated attention) is a method for more efficiently computing attention over a batch of sequences that share a common prefix. You can turn on Hydragen with `use_hydragen=T`, and tokasaurus will automatically detect shared prefixes across the groups of sequences actively running. You can control the thresholds at which groups are formed with `hydragen_min_group_size` and `hydragen_min_prefix_len`, which define the minimum number of sequences in a shared-prefix group and the minimum length of a shared prefix in tokens, respectively. Note that turning on Hydragen can have a slight numerical impact on your generations, since we combine attention results in bfloat16.
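As a sketch, enabling Hydragen and tightening its grouping thresholds could look like the command below (the two threshold values are illustrative, not the engine's defaults):

```bash
# only form a group when at least 32 running sequences share a prefix of at least 256 tokens
toka model=meta-llama/Llama-3.2-1B-Instruct use_hydragen=T hydragen_min_group_size=32 hydragen_min_prefix_len=256
```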
Importantly, the manager works to ensure that there are multiple items in the model input queue, so that the model can always be running forward passes (i.e. the GPU can always be active) and never stall waiting for the manager to send it more work. 156 | 157 | When data parallelism is used, each replica has its own manager process and set of model worker processes. However, all data parallel replicas share the same server process, which handles load balancing. 158 | 159 | The entry point for starting up the server and kicking off all the processes is `tokasaurus/entry.py`. 160 | 161 | 162 | ## Citation 163 | 164 | If you use Tokasaurus in your research, please cite: 165 | 166 | ```bibtex 167 | 168 | @misc{juravsky2025tokasaurus, 169 | author = {Jordan Juravsky and Ayush Chakravarthy and Ryan Ehrlich and Sabri Eyuboglu and Bradley Brown and Joseph Shetaye and Christopher R{\'e} and Azalia Mirhoseini}, 170 | title = {Tokasaurus: An LLM Inference Engine for High-Throughput Workloads}, 171 | year = {2025}, 172 | howpublished = {\url{https://scalingintelligence.stanford.edu/blogs/tokasaurus/}} 173 | } 174 | 175 | ``` -------------------------------------------------------------------------------- /tests/test_scheduler.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import pytest 4 | 5 | from tokasaurus.manager.allocator import BlockAllocator 6 | from tokasaurus.manager.scheduler import ( 7 | BlockUsageOverTime, 8 | BlockUsagePoint, 9 | EventCollection, 10 | calc_block_usage_over_time, 11 | try_onboarding_seqs, 12 | ) 13 | from tokasaurus.manager.types import Sequence 14 | 15 | 16 | def make_sequences( 17 | num_shared_decoding_seqs: int = 128, 18 | num_unique_decoding_seqs: int = 128, 19 | num_prefilling_seqs: int = 128, 20 | shared_length: int = 32, 21 | max_unique_length: int = 256, 22 | min_completion_length: int = 1, 23 | max_completion_length: int = 1024, 24 | vocab_size: int = 2, 25 | page_size: int = 2, 26 | num_blocks: int = 128 * 1024, 27 | allocator: BlockAllocator | None = None, 28 | name_prefix: str = "", 29 | ): 30 | if allocator is None: 31 | allocator = BlockAllocator(page_size=page_size, num_blocks=num_blocks) 32 | 33 | def make_random_ids(length: int): 34 | return [random.randint(0, vocab_size - 1) for _ in range(length)] 35 | 36 | shared_ids = make_random_ids(shared_length) 37 | 38 | # with only 1 completion token, a seq would never move to 39 | # decode - it would finish prefill and be done.
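    # hence the max(..., 2) clamp on the next line.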
40 | min_decoding_completion_length = max(min_completion_length, 2) 41 | 42 | shared_decoding_seqs = [ 43 | Sequence( 44 | id=f"{name_prefix}shared-dec-seq{i}", 45 | completion_total=random.randint( 46 | min_decoding_completion_length, max_completion_length 47 | ), 48 | input_ids=shared_ids 49 | + make_random_ids(random.randint(1, max_unique_length)), 50 | ) 51 | for i in range(num_shared_decoding_seqs) 52 | ] 53 | 54 | unique_decoding_seqs = [ 55 | Sequence( 56 | id=f"{name_prefix}unique-dec-seq{i}", 57 | completion_total=random.randint( 58 | min_decoding_completion_length, max_completion_length 59 | ), 60 | input_ids=make_random_ids(random.randint(1, max_unique_length)), 61 | ) 62 | for i in range(num_unique_decoding_seqs) 63 | ] 64 | 65 | decoding_seqs = shared_decoding_seqs + unique_decoding_seqs 66 | 67 | prefilling_seqs = [ 68 | Sequence( 69 | id=f"{name_prefix}prefill-seq{i}", 70 | completion_total=random.randint( 71 | min_completion_length, max_completion_length 72 | ), 73 | input_ids=make_random_ids(random.randint(1, max_unique_length)), 74 | ) 75 | for i in range(num_prefilling_seqs) 76 | ] 77 | 78 | for d in decoding_seqs: 79 | kvs, num_cached = allocator.allocate_with_prefix_match(d.id, d.input_ids) 80 | completion_scheduled = random.randint(1, d.completion_total - 1) 81 | d.prompt_scheduled = len(d.input_ids) 82 | d.completion_scheduled = completion_scheduled 83 | d.num_cached_prompt_tokens = num_cached 84 | allocate_up_to = d.total_scheduled() - 1 85 | kvs.extend(allocator.allocate_up_to_length(d.id, kvs, allocate_up_to)) 86 | d.kv_indices = kvs 87 | assert 0 <= page_size * len(kvs) - allocate_up_to < page_size 88 | 89 | for p in prefilling_seqs: 90 | kvs, num_cached = allocator.allocate_with_prefix_match(p.id, p.input_ids) 91 | p.prompt_scheduled = num_cached 92 | p.num_cached_prompt_tokens = num_cached 93 | p.kv_indices = kvs 94 | assert 0 <= page_size * len(kvs) - len(p.input_ids) < page_size 95 | 96 | return shared_decoding_seqs, unique_decoding_seqs, prefilling_seqs, allocator 97 | 98 | 99 | @pytest.mark.parametrize("seed", list(range(10))) 100 | def test_calc_block_usage_over_time(seed): 101 | random.seed(seed) 102 | 103 | page_size = 2 104 | prefill_rate = 100 105 | vocab_size = 2 106 | 107 | shared_decoding_seqs, unique_decoding_seqs, prefilling_seqs, allocator = ( 108 | make_sequences(page_size=page_size, vocab_size=vocab_size) 109 | ) 110 | decoding_seqs = shared_decoding_seqs + unique_decoding_seqs 111 | 112 | block_usage: BlockUsageOverTime = calc_block_usage_over_time( 113 | decoding_seqs=decoding_seqs, 114 | prefilling_seqs=prefilling_seqs, 115 | page_size=page_size, 116 | prefill_rate=prefill_rate, 117 | add_buffer=False, 118 | ) 119 | 120 | # for all timesteps, not just ones where an event happens 121 | gold_points: list[BlockUsagePoint] = [] 122 | 123 | active_decoding_seqs = decoding_seqs.copy() 124 | active_prefilling_seqs = prefilling_seqs.copy() 125 | 126 | def free_seq(seq: Sequence): 127 | assert seq.kv_indices is not None 128 | return allocator.free_and_update( 129 | seq.id, 130 | seq.kv_indices, 131 | seq.input_ids 132 | + [ 133 | random.randint(0, vocab_size - 1) 134 | for _ in range(random.randint(0, seq.completion_total)) 135 | ], 136 | ) 137 | 138 | first_used_blocks = set() 139 | for d in decoding_seqs: 140 | first_used_blocks.update(d.kv_indices) 141 | 142 | for p in prefilling_seqs: 143 | first_used_blocks.update(p.kv_indices) 144 | 145 | while len(active_decoding_seqs) + len(active_prefilling_seqs) > 0: 146 | cur_step = len(gold_points) 147 | 
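        # Simulate one engine step: allocate KV blocks for this step's decode tokens,
        # record block usage, then advance each sequence and free the ones that finish.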
148 | prefill_finishes = [] 149 | decode_finishes = [] 150 | 151 | last_page_lens_minus_one = [0] * page_size 152 | for d in active_decoding_seqs: 153 | assert d.completion_scheduled < d.completion_total 154 | d.kv_indices.extend( 155 | allocator.allocate_up_to_length(d.id, d.kv_indices, d.total_scheduled()) 156 | ) 157 | last_page_len = d.total_scheduled() % page_size 158 | if last_page_len == 0: 159 | last_page_len = page_size 160 | last_page_lens_minus_one[last_page_len - 1] += 1 161 | 162 | used_blocks = set() 163 | for d in active_decoding_seqs: 164 | used_blocks.update(d.kv_indices) 165 | 166 | for p in active_prefilling_seqs: 167 | used_blocks.update(p.kv_indices) 168 | 169 | new_active_decoding_seqs = [] 170 | new_active_prefilling_seqs = [] 171 | 172 | freed_blocks = set() 173 | 174 | for d in active_decoding_seqs: 175 | d.completion_scheduled += 1 176 | assert d.completion_scheduled <= d.completion_total 177 | if d.completion_scheduled < d.completion_total: 178 | new_active_decoding_seqs.append(d) 179 | else: 180 | decode_finishes.append(d) 181 | freed_blocks.update(free_seq(d)) 182 | 183 | prefill_available_for_step = prefill_rate 184 | for p in active_prefilling_seqs: 185 | new_prompt_scheduled = min( 186 | p.prompt_scheduled + prefill_available_for_step, len(p.input_ids) 187 | ) 188 | amount_prefilled = new_prompt_scheduled - p.prompt_scheduled 189 | prefill_available_for_step -= amount_prefilled 190 | p.prompt_scheduled = new_prompt_scheduled 191 | 192 | assert prefill_available_for_step >= 0 193 | 194 | if p.prompt_scheduled < len(p.input_ids): 195 | new_active_prefilling_seqs.append(p) 196 | else: 197 | assert p.prompt_scheduled == len(p.input_ids) 198 | 199 | p.completion_scheduled += 1 200 | if p.completion_scheduled < p.completion_total: 201 | new_active_decoding_seqs.append(p) 202 | prefill_finishes.append(p) 203 | else: 204 | assert p.completion_total == 1 205 | decode_finishes.append(p) 206 | freed_blocks.update(free_seq(p)) 207 | 208 | point = BlockUsagePoint( 209 | timestep=cur_step, 210 | num_used_blocks_after_allocation=len(used_blocks), 211 | last_page_lens_after_allocation=last_page_lens_minus_one, 212 | event=EventCollection( 213 | timestep=cur_step, 214 | decode_finishes=set(decode_finishes), 215 | prefill_finishes=set(prefill_finishes), 216 | ), 217 | freed_blocks_after_deallocation=freed_blocks, 218 | ) 219 | gold_points.append(point) 220 | 221 | active_decoding_seqs = new_active_decoding_seqs 222 | active_prefilling_seqs = new_active_prefilling_seqs 223 | 224 | for point in reversed(block_usage.points): 225 | gold_point = gold_points[point.timestep] 226 | 227 | assert ( 228 | gold_point.num_used_blocks_after_allocation 229 | == point.num_used_blocks_after_allocation 230 | ) 231 | assert ( 232 | gold_point.last_page_lens_after_allocation 233 | == point.last_page_lens_after_allocation 234 | ) 235 | assert gold_point.event == point.event 236 | assert gold_point.freed_blocks_after_deallocation.issuperset( 237 | point.freed_blocks_after_deallocation 238 | ) 239 | 240 | assert block_usage.used_blocks == first_used_blocks 241 | 242 | 243 | @pytest.mark.parametrize("seed", list(range(20, 30))) 244 | def test_try_onboarding_seq(seed): 245 | random.seed(seed) 246 | 247 | page_size = 2 248 | prefill_rate = 100 249 | 250 | shared_decoding_seqs, unique_decoding_seqs, prefilling_seqs, allocator = ( 251 | make_sequences(page_size=page_size) 252 | ) 253 | decoding_seqs = shared_decoding_seqs + unique_decoding_seqs 254 | 255 | all_used_blocks = { 256 | block.idx for block 
in allocator.all_blocks if len(block.seq_ids) > 0 257 | } 258 | 259 | ( 260 | _, 261 | _, 262 | additional_prefilling_seqs, 263 | _, 264 | ) = make_sequences( 265 | num_shared_decoding_seqs=0, 266 | num_unique_decoding_seqs=0, 267 | num_prefilling_seqs=512, 268 | allocator=allocator, 269 | name_prefix="more-", 270 | ) 271 | 272 | latest_block_usage: BlockUsageOverTime = calc_block_usage_over_time( 273 | decoding_seqs=decoding_seqs, 274 | prefilling_seqs=prefilling_seqs, 275 | page_size=page_size, 276 | prefill_rate=prefill_rate, 277 | add_buffer=False, 278 | ) 279 | 280 | to_onboard = additional_prefilling_seqs.copy() 281 | cur_used_blocks = all_used_blocks.copy() 282 | cur_prefilling_seqs = prefilling_seqs.copy() 283 | 284 | iters = 0 285 | while len(to_onboard) > 0: 286 | iters += 1 287 | num_to_onboard = random.randint(1, max(1, len(to_onboard) // 4)) 288 | seqs = to_onboard[:num_to_onboard] 289 | to_onboard = to_onboard[num_to_onboard:] 290 | 291 | used_by_seqs = set() 292 | for seq in seqs: 293 | used_by_seqs.update(seq.kv_indices) 294 | 295 | modified_block_usage = try_onboarding_seqs( 296 | block_usage=latest_block_usage, 297 | seqs=seqs, 298 | existing_prefill_seqs=cur_prefilling_seqs, 299 | page_size=page_size, 300 | add_buffer=False, 301 | prefill_rate=prefill_rate, 302 | block_limit=float("inf"), 303 | ) 304 | 305 | fresh_block_usage = calc_block_usage_over_time( 306 | decoding_seqs=decoding_seqs, 307 | prefilling_seqs=cur_prefilling_seqs + seqs, 308 | page_size=page_size, 309 | prefill_rate=prefill_rate, 310 | add_buffer=False, 311 | ) 312 | 313 | assert fresh_block_usage == modified_block_usage 314 | 315 | cur_used_blocks.update(used_by_seqs) 316 | cur_prefilling_seqs.extend(seqs) 317 | latest_block_usage = modified_block_usage 318 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 
35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2025 Stanford University 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
-------------------------------------------------------------------------------- /tokasaurus/model/pipeline_worker.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Iterable 3 | 4 | import torch 5 | import torch.distributed as dist 6 | import torch.multiprocessing as mp 7 | from torch import Tensor 8 | 9 | from tokasaurus.common_types import ServerConfig, TimedBarrier 10 | from tokasaurus.model.llama import LlamaForCausalLM 11 | from tokasaurus.model.types import ( 12 | BatchState, 13 | ModelInput, 14 | ModelOutput, 15 | ModelOutputTensors, 16 | NoMoreInputs, 17 | PipelineWorkerState, 18 | ) 19 | from tokasaurus.model.utils import ( 20 | ModelRunner, 21 | add_decoding_ids_to_batch_state, 22 | get_dtype, 23 | get_global_rank, 24 | make_input_batch_state, 25 | make_model, 26 | move_batch_state, 27 | setup_and_run_loop, 28 | setup_distributed, 29 | unpad_output_batch_state, 30 | ) 31 | from tokasaurus.utils import ( 32 | error_propogation_decorator, 33 | setup_logging, 34 | ) 35 | 36 | 37 | def wait_for_data_dependencies(state: PipelineWorkerState, inp: ModelInput): 38 | while contains_data_dependency( 39 | inp=inp, 40 | inputs_in_front=state.inflight_microbatches, 41 | ): 42 | output_ids = state.q_pipe_end_to_start.get() 43 | handle_output_from_pipeline_end(state, output_ids) 44 | 45 | 46 | def handle_output_from_pipeline_end(state: PipelineWorkerState, output_ids: Tensor): 47 | assert state.batch_id_to_last_token is not None 48 | 49 | model_inp = state.inflight_microbatches.popleft() 50 | 51 | batch_ids_to_update = torch.tensor( 52 | model_inp.lm_head_batch_indices(), dtype=torch.long 53 | ) 54 | 55 | assert batch_ids_to_update.shape == output_ids.shape, ( 56 | f"batch_ids_to_update.shape={batch_ids_to_update.shape} != output_ids.shape={output_ids.shape}" 57 | ) 58 | 59 | state.batch_id_to_last_token[batch_ids_to_update] = output_ids 60 | 61 | 62 | def handle_outputs_to_manager(state: PipelineWorkerState): 63 | front_inp, front_out = state.finished_outputs[0] 64 | 65 | assert front_inp.schedule_id == front_out.schedule_id 66 | schedule_id = front_inp.schedule_id 67 | 68 | microbatch_total = front_inp.microbatch_total 69 | assert microbatch_total is not None 70 | 71 | if len(state.finished_outputs) >= microbatch_total: 72 | to_finish = [state.finished_outputs.popleft() for _ in range(microbatch_total)] 73 | 74 | cat_output_tensors: list[ModelOutputTensors] = [] 75 | 76 | for i, (mb_inp, mb_out) in enumerate(to_finish): 77 | assert mb_inp.microbatch_index == i 78 | assert mb_inp.schedule_id == schedule_id 79 | assert mb_out.schedule_id == schedule_id 80 | 81 | cat_output_tensors.append(mb_out.tensors) 82 | 83 | cat_output_tokens = torch.cat([x.output_ids for x in cat_output_tensors]) 84 | cat_chosen_token_logprobs = torch.cat( 85 | [x.chosen_logprobs for x in cat_output_tensors] 86 | ) 87 | 88 | if cat_output_tensors[0].topk_indices is not None: 89 | cat_topk_indices = torch.cat([x.topk_indices for x in cat_output_tensors]) # type: ignore 90 | cat_topk_logprobs = torch.cat([x.topk_logprobs for x in cat_output_tensors]) # type: ignore 91 | else: 92 | cat_topk_indices = None 93 | cat_topk_logprobs = None 94 | 95 | cat_tensors = ModelOutputTensors( 96 | output_ids=cat_output_tokens, 97 | chosen_logprobs=cat_chosen_token_logprobs, 98 | topk_indices=cat_topk_indices, 99 | topk_logprobs=cat_topk_logprobs, 100 | ) 101 | 102 | out = ModelOutput( 103 | tensors=cat_tensors, 104 | schedule_id=schedule_id, 105 
| ) 106 | 107 | state.q_to_manager.put(out) 108 | 109 | 110 | @error_propogation_decorator 111 | def pipeline_worker_model_loop( 112 | state: PipelineWorkerState, 113 | model: LlamaForCausalLM, 114 | ): 115 | assert state.device_mesh is not None 116 | 117 | config = state.config 118 | world_size = config.pp_size 119 | device = model.device 120 | 121 | pp_rank = state.pp_rank 122 | tp_rank = state.tp_rank 123 | dp_rank = state.dp_rank 124 | 125 | pp_group = state.device_mesh["pp"].get_group() 126 | 127 | if pp_rank > 0: 128 | pp_src_rank = get_global_rank(config, dp_rank, pp_rank - 1, tp_rank) 129 | else: 130 | pp_src_rank = None 131 | 132 | if pp_rank < world_size - 1: 133 | pp_dst_rank = get_global_rank(config, dp_rank, pp_rank + 1, tp_rank) 134 | else: 135 | pp_dst_rank = None 136 | 137 | non_blocking = True 138 | 139 | @dataclass 140 | class Work: 141 | model_input: ModelInput 142 | input_batch_state: BatchState 143 | output_batch_state: BatchState | None = None 144 | output_tensors_cpu: ModelOutputTensors | None = None 145 | 146 | def preprocess(): 147 | command = state.input_q.get() 148 | match command: 149 | case NoMoreInputs(): 150 | return None 151 | case _: 152 | inp: ModelInput = command 153 | 154 | num_total_padding, num_lm_head_padding = model_runner.calc_padding( 155 | num_prefill_tokens=inp.num_prefill_tokens(), 156 | num_decode_tokens=inp.num_decode_tokens(), 157 | num_lm_head_tokens=inp.num_lm_head_tokens(), 158 | ) 159 | 160 | input_batch_state = make_input_batch_state( 161 | inp, 162 | pp_rank=pp_rank, 163 | pp_size=config.pp_size, 164 | tp_rank=tp_rank, 165 | tp_size=config.tp_size, 166 | num_total_padding=num_total_padding, 167 | num_lm_head_padding=num_lm_head_padding, 168 | ) 169 | 170 | if pp_rank == 0: 171 | wait_for_data_dependencies(state, inp) 172 | 173 | assert state.batch_id_to_last_token is not None 174 | decoding_input_ids = state.batch_id_to_last_token[ 175 | torch.tensor(inp.decoding_batch_indices(), dtype=torch.long) 176 | ] 177 | 178 | add_decoding_ids_to_batch_state( 179 | input_batch_state=input_batch_state, 180 | decoding_input_ids=decoding_input_ids, 181 | tp_rank=tp_rank, 182 | tp_size=config.tp_size, 183 | ) 184 | 185 | state.inflight_microbatches.append(inp) 186 | 187 | model_runner.plan(input_batch_state, non_blocking=non_blocking) 188 | 189 | move_batch_state( 190 | input_batch_state=input_batch_state, 191 | device=device, 192 | non_blocking=non_blocking, 193 | ) 194 | 195 | return Work( 196 | model_input=inp, 197 | input_batch_state=input_batch_state, 198 | ) 199 | 200 | def run_model(work: Work): 201 | if pp_rank > 0: 202 | full_bs = work.input_batch_state.position_ids.shape[0] 203 | assert full_bs % config.tp_size == 0 204 | bs = full_bs // config.tp_size 205 | recv = torch.empty( 206 | bs, model.config.hidden_size, device=device, dtype=model.dtype 207 | ) 208 | if not work.model_input.skip_pipeline_communication: 209 | dist.recv(recv, src=pp_src_rank, group=pp_group) 210 | 211 | work.input_batch_state.hidden_states = recv 212 | 213 | output_batch_state = model_runner.run( 214 | work.input_batch_state, non_blocking=non_blocking 215 | ) 216 | assert output_batch_state.hidden_states is not None 217 | 218 | # if not last pipeline stage, send hidden states to next stage 219 | if pp_rank < world_size - 1: 220 | if not work.model_input.skip_pipeline_communication: 221 | dist.send( 222 | output_batch_state.hidden_states, 223 | dst=pp_dst_rank, 224 | group=pp_group, 225 | ) 226 | 227 | else: 228 | unpad_output_batch_state( 229 | 
output_batch_state=output_batch_state, 230 | model_input=work.model_input, 231 | ) 232 | 233 | work.output_batch_state = output_batch_state 234 | 235 | def synchronize(work: Work): 236 | # NOTE: important to do this for all workers - from what I can tell, 237 | # if too many nccl sends/recvs are launched by one process without getting 238 | # fulfilled by other ranks, deadlocks and illegal memory access 239 | # errors can happen. 240 | torch.cuda.synchronize() 241 | 242 | if pp_rank != world_size - 1: 243 | return 244 | 245 | assert work.output_batch_state is not None 246 | assert work.output_batch_state.outputs is not None 247 | 248 | # NOTE: if there are no output tokens and these tensors are empty, 249 | # calling .cpu() does not actually cause a sync. 250 | work.output_tensors_cpu = work.output_batch_state.outputs.to("cpu") 251 | 252 | # we have to send this now (and not in postprocess) because the start of the 253 | # pipeline may be waiting on it to send the next batch. if this end of the 254 | # pipeline end worker blocks (i.e. because of a cudaMalloc) after a nccl 255 | # recv is launched, it will deadlock 256 | state.q_pipe_end_to_start.put(work.output_tensors_cpu.output_ids) 257 | 258 | def postprocess(work: Work): 259 | if pp_rank != world_size - 1 or tp_rank != 0: 260 | return 261 | 262 | assert work.output_tensors_cpu is not None 263 | 264 | out = ModelOutput( 265 | tensors=work.output_tensors_cpu, 266 | schedule_id=work.model_input.schedule_id, 267 | microbatch_index=work.model_input.microbatch_index, 268 | ) 269 | 270 | state.finished_outputs.append((work.model_input, out)) 271 | handle_outputs_to_manager(state) 272 | 273 | model_runner = ModelRunner( 274 | config=state.config, 275 | model=model, 276 | ) 277 | 278 | setup_and_run_loop( 279 | state=state, 280 | model_runner=model_runner, 281 | preprocess=preprocess, 282 | run_model=run_model, 283 | synchronize=synchronize, 284 | postprocess=postprocess, 285 | ) 286 | 287 | 288 | def contains_data_dependency( 289 | inp: ModelInput, 290 | inputs_in_front: Iterable[ModelInput], 291 | ): 292 | # checking for data dependencies - we can't schedule a 293 | # new input (microbatch) if it depends on a token that's in flight 294 | # NOTE: this set can change between while loop iterations 295 | # because we add new in flight microbatches as we go 296 | ids_in_front = set() 297 | for mb in inputs_in_front: 298 | ids_in_front.update(mb.lm_head_batch_indices()) 299 | 300 | # decoding seqs need to wait on any prev decodes or final-token prefills 301 | # final-token prefills themselves don't need to wait on anything 302 | for id in inp.decoding_batch_indices(): 303 | if id in ids_in_front: 304 | return True 305 | 306 | return False 307 | 308 | 309 | def start_pipeline_worker( 310 | config: ServerConfig, 311 | input_q: mp.Queue, 312 | q_pipe_end_to_start: mp.Queue, 313 | q_to_manager: mp.Queue, 314 | dp_rank: int, 315 | pp_rank: int, 316 | tp_rank: int, 317 | master_port: int, 318 | process_name: str, 319 | barrier: TimedBarrier, 320 | ): 321 | setup_logging(config) 322 | 323 | state = PipelineWorkerState( 324 | config=config, 325 | input_q=input_q, 326 | q_pipe_end_to_start=q_pipe_end_to_start, 327 | q_to_manager=q_to_manager, 328 | process_name=process_name, 329 | pp_rank=pp_rank, 330 | tp_rank=tp_rank, 331 | dp_rank=dp_rank, 332 | barrier=barrier, 333 | ) 334 | 335 | if pp_rank == 0: 336 | state.batch_id_to_last_token = torch.zeros( 337 | config.max_batch_index(), dtype=torch.long 338 | ) 339 | 340 | state.logger.info(f"Pipeline worker 
{pp_rank} started!") 341 | dtype = get_dtype(config.dtype) 342 | 343 | device_mesh, device = setup_distributed( 344 | config=config, 345 | dp_rank=dp_rank, 346 | pp_rank=pp_rank, 347 | tp_rank=tp_rank, 348 | master_port=master_port, 349 | ) 350 | assert device_mesh is not None 351 | state.device_mesh = device_mesh 352 | 353 | state.logger.info(f"Creating model on device {device} with dtype {dtype}") 354 | 355 | model = make_model( 356 | config, 357 | device, 358 | dtype, 359 | pp_rank=pp_rank, 360 | tp_rank=tp_rank, 361 | tp_group=state.device_mesh["tp"].get_group(), 362 | ) 363 | 364 | state.logger.info("Created model") 365 | 366 | pipeline_worker_model_loop(state=state, model=model) 367 | --------------------------------------------------------------------------------