├── tests ├── __init__.py └── e2e │ ├── __init__.py │ ├── test_synchronous.py │ └── test_structured_outputs.py ├── tools ├── __init__.py └── pd_disaggregation_proxy_server.py ├── src └── stopwatch │ ├── __init__.py │ ├── __main__.py │ ├── llm_servers │ ├── __init__.py │ ├── sglang.py │ ├── tokasaurus.py │ ├── vllm.py │ ├── vllm_pd_disaggregation.py │ ├── tensorrt_llm.py │ └── dynamic.py │ ├── utils │ ├── __init__.py │ ├── openai.py │ └── loader.py │ ├── cli │ ├── utils.py │ ├── __init__.py │ ├── benchmark.py │ ├── provision.py │ ├── provision_and_benchmark.py │ └── profile.py │ ├── resources.py │ ├── constants.py │ ├── profile.py │ └── benchmark.py ├── pytest.ini ├── .gitignore ├── .github └── workflows │ └── main.yml ├── LICENSE.md ├── pyproject.toml ├── README.md └── CONTRIBUTING.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/e2e/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/stopwatch/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | addopts = -ra -q -n 32 --reruns 3 -------------------------------------------------------------------------------- /src/stopwatch/__main__.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | from src.stopwatch.cli import app 3 | 4 | app() 5 | -------------------------------------------------------------------------------- /src/stopwatch/llm_servers/__init__.py: -------------------------------------------------------------------------------- 1 | from .dynamic import create_dynamic_llm_server_class 2 | 3 | __all__ = ["create_dynamic_llm_server_class"] 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | env/ 2 | .venv/ 3 | venv/ 4 | __pycache__/ 5 | stopwatch.egg-info/ 6 | *.json 7 | *.ipynb 8 | .DS_Store 9 | traces/ 10 | *.db 11 | results/ 12 | *.json.gz 13 | \.*.yaml 14 | -------------------------------------------------------------------------------- /src/stopwatch/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .loader import CustomGenerativeRequestLoader 2 | from .openai import CustomOpenAIHTTPBackend 3 | 4 | __all__ = ["CustomGenerativeRequestLoader", "CustomOpenAIHTTPBackend"] 5 | -------------------------------------------------------------------------------- /src/stopwatch/cli/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import typer 4 | 5 | 6 | def config_callback(config: str | None) -> dict: 7 | """Parse JSON config strings into dicts.""" 8 | 9 | if isinstance(config, dict): 10 | return config 11 | 12 | if config is None: 13 | return {} 14 | 15 | try: 16 | return json.loads(config) 17 | except json.JSONDecodeError as err: 18 | 
msg = "Must be a valid JSON string" 19 | raise typer.BadParameter(msg) from err 20 | -------------------------------------------------------------------------------- /src/stopwatch/cli/__init__.py: -------------------------------------------------------------------------------- 1 | import typer 2 | 3 | from .benchmark import benchmark_cli 4 | from .profile import profile_cli 5 | from .provision import provision_cli 6 | from .provision_and_benchmark import provision_and_benchmark_cli 7 | 8 | app = typer.Typer() 9 | app.command(name="benchmark")(benchmark_cli) 10 | app.command(name="profile")(profile_cli) 11 | app.command(name="provision")(provision_cli) 12 | app.command(name="provision-and-benchmark")(provision_and_benchmark_cli) 13 | 14 | 15 | if __name__ == "__main__": 16 | app() 17 | -------------------------------------------------------------------------------- /src/stopwatch/utils/openai.py: -------------------------------------------------------------------------------- 1 | import httpx 2 | from guidellm.backend.openai import OpenAIHTTPBackend 3 | 4 | 5 | class CustomOpenAIHTTPBackend(OpenAIHTTPBackend): 6 | """A custom OpenAI HTTP backend that increases the number of maximum redirects.""" 7 | 8 | def _get_async_client(self) -> httpx.AsyncClient: 9 | if self._async_client is None or self._async_client.is_closed: 10 | client = super()._get_async_client() 11 | client.max_redirects = 1000 12 | self._async_client = client 13 | 14 | return self._async_client 15 | -------------------------------------------------------------------------------- /src/stopwatch/resources.py: -------------------------------------------------------------------------------- 1 | import modal 2 | 3 | app = modal.App() 4 | 5 | # Dicts 6 | startup_metrics_dict = modal.Dict.from_name( 7 | "stopwatch-startup-metrics", 8 | create_if_missing=True, 9 | ) 10 | 11 | # Secrets 12 | hf_secret = modal.Secret.from_name("huggingface-secret") 13 | 14 | # Volumes 15 | db_volume = modal.Volume.from_name("stopwatch-db", create_if_missing=True) 16 | hf_cache_volume = modal.Volume.from_name("stopwatch-hf-cache", create_if_missing=True) 17 | results_volume = modal.Volume.from_name("stopwatch-results", create_if_missing=True) 18 | traces_volume = modal.Volume.from_name("stopwatch-traces", create_if_missing=True) 19 | vllm_cache_volume = modal.Volume.from_name( 20 | "stopwatch-vllm-cache", 21 | create_if_missing=True, 22 | ) 23 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: Main 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | lint: 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | python: ["3.13"] 14 | steps: 15 | - uses: actions/checkout@v4 16 | - uses: astral-sh/ruff-action@v3 17 | 18 | test: 19 | runs-on: ubuntu-latest 20 | strategy: 21 | matrix: 22 | python: ["3.13"] 23 | steps: 24 | - uses: actions/checkout@v4 25 | - name: Set up Python ${{ matrix.python }} 26 | uses: actions/setup-python@v5 27 | with: 28 | python-version: ${{ matrix.python }} 29 | - name: Install dependencies 30 | run: | 31 | pip install --upgrade pip 32 | pip install . 
33 | - name: Run tests 34 | env: 35 | MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }} 36 | MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }} 37 | run: pytest 38 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | Copyright (c) 2025 Modal Labs 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/stopwatch/constants.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | SECONDS = 1 4 | MINUTES = 60 * SECONDS 5 | HOURS = 60 * MINUTES 6 | 7 | # Volume mount paths 8 | DB_PATH = "/db" 9 | HF_CACHE_PATH = "/cache" 10 | TRACES_PATH = "/traces" 11 | VLLM_CACHE_PATH = "/root/.cache/vllm" 12 | 13 | GUIDELLM_VERSION = "af89787" 14 | TENSORRT_LLM_CUDA_VERSION = "12.9.1" 15 | TOKASAURUS_CUDA_VERSION = "12.4.1" 16 | 17 | 18 | class RateType(str, Enum): 19 | """Types of rates for running benchmarks.""" 20 | 21 | constant = "constant" 22 | sweep = "sweep" 23 | synchronous = "synchronous" 24 | throughput = "throughput" 25 | 26 | 27 | class LLMServerType(str, Enum): 28 | """Types of LLM servers.""" 29 | 30 | sglang = "sglang" 31 | tensorrt_llm = "tensorrt-llm" 32 | tokasaurus = "tokasaurus" 33 | vllm = "vllm" 34 | vllm_pd_disaggregation = "vllm-pd-disaggregation" 35 | 36 | def get_version(self) -> str: 37 | """Get the latest version of the LLM server.""" 38 | 39 | versions = { 40 | LLMServerType.sglang: "v0.5.1.post3-cu126", 41 | LLMServerType.tensorrt_llm: "1.1.0rc1", 42 | LLMServerType.tokasaurus: "0.0.4", 43 | LLMServerType.vllm: "v0.10.1.1", 44 | LLMServerType.vllm_pd_disaggregation: "v0.10.1.1", 45 | } 46 | 47 | return versions[self] 48 | -------------------------------------------------------------------------------- /src/stopwatch/cli/benchmark.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | import modal 5 | 6 | from stopwatch.benchmark import GuideLLM 7 | from stopwatch.constants import RateType 8 | from stopwatch.resources import app 9 | 10 | 11 | def benchmark_cli( 12 | *, 13 | endpoint: str, 14 | model: str, 15 | output_path: str | None = "results.json", 16 | detach: bool = False, 17 | rate_type: RateType = RateType.synchronous, 18 | data: str = 
"prompt_tokens=128,output_tokens=128", 19 | duration: float | None = 120, 20 | client_config: str | None = None, 21 | rate: float | None = None, 22 | region: str | None = None, 23 | ) -> list[dict]: 24 | """Benchmark an OpenAI-compatible LLM server using GuideLLM.""" 25 | 26 | with modal.enable_output(), app.run(detach=detach): 27 | print(f"Running benchmark on {endpoint}...") 28 | 29 | results = GuideLLM.with_options(region=region)().run_benchmark.remote( 30 | endpoint, 31 | model, 32 | rate_type, 33 | data, 34 | duration, 35 | client_config, 36 | rate, 37 | ) 38 | 39 | if output_path is not None: 40 | with Path(output_path).open("w") as f: 41 | json.dump(results, f, indent=2) 42 | 43 | print(f"Results saved to {output_path}") 44 | 45 | return results 46 | -------------------------------------------------------------------------------- /tests/e2e/test_synchronous.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from stopwatch.cli import provision_and_benchmark_cli 4 | from stopwatch.constants import LLMServerType 5 | 6 | 7 | @pytest.mark.parametrize( 8 | "llm_server_type", 9 | [ 10 | LLMServerType.vllm, 11 | LLMServerType.sglang, 12 | LLMServerType.tensorrt_llm, 13 | LLMServerType.tokasaurus, 14 | ], 15 | ) 16 | @pytest.mark.timeout(300) 17 | def test_llama(llm_server_type: LLMServerType) -> None: 18 | """Test that a quick synchronous benchmark runs successfully.""" 19 | 20 | results = provision_and_benchmark_cli( 21 | "meta-llama/Llama-3.1-8B-Instruct", 22 | llm_server_type, 23 | gpu="H100!", 24 | duration=10, 25 | server_cloud="oci", 26 | client_config=( 27 | { 28 | "remove_from_body": ["max_completion_tokens", "stream"], 29 | } 30 | if llm_server_type == LLMServerType.tokasaurus 31 | else None 32 | ), 33 | ) 34 | 35 | # Only one benchmark should have been run 36 | assert len(results) == 1 37 | 38 | for result in results: 39 | # Check that the rate type and rate are in the result 40 | assert result["rate_type"] == "synchronous" 41 | assert result["rate"] is None 42 | 43 | # At least one successful request should have been made 44 | assert result["run_stats"]["requests_made"]["successful"] > 0 45 | 46 | # Cold start and queue durations should have been saved 47 | assert isinstance(result["cold_start_duration"], float) 48 | assert isinstance(result["queue_duration"], float) 49 | -------------------------------------------------------------------------------- /src/stopwatch/cli/provision.py: -------------------------------------------------------------------------------- 1 | import json 2 | import uuid 3 | 4 | import modal 5 | 6 | from stopwatch.constants import LLMServerType 7 | from stopwatch.llm_servers import create_dynamic_llm_server_class 8 | from stopwatch.resources import app 9 | 10 | 11 | def provision_cli( 12 | model: str, 13 | *, 14 | endpoint_label: str | None = None, 15 | gpu: str = "H100", 16 | llm_server_type: LLMServerType = LLMServerType.vllm, 17 | cpu: int | None = None, 18 | memory: int | None = None, 19 | min_containers: int = 0, 20 | max_containers: int = 1, 21 | cloud: str | None = None, 22 | region: str | None = None, 23 | llm_server_config: str | None = None, 24 | max_concurrent_inputs: int = 1000, 25 | ) -> None: 26 | """Deploy an LLM server on Modal.""" 27 | 28 | with modal.enable_output(): 29 | cls, _ = create_dynamic_llm_server_class( 30 | # Pick a random name for the server if not provided 31 | endpoint_label or uuid.uuid4().hex[:4], 32 | model, 33 | gpu=gpu, 34 | llm_server_type=llm_server_type, 35 
| cpu=cpu, 36 | memory=memory, 37 | min_containers=min_containers, 38 | max_containers=max_containers, 39 | cloud=cloud, 40 | region=region, 41 | llm_server_config=( 42 | json.loads(llm_server_config) if llm_server_config else None 43 | ), 44 | max_concurrent_inputs=max_concurrent_inputs, 45 | ) 46 | 47 | app.deploy(name="deployment") 48 | 49 | print("Your OpenAI-compatible endpoint is ready at:") 50 | print(cls().start.get_web_url()) 51 | -------------------------------------------------------------------------------- /src/stopwatch/utils/loader.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Iterator 2 | from typing import Any 3 | 4 | from guidellm.request import GenerationRequest, GenerativeRequestLoader 5 | 6 | 7 | class CustomGenerativeRequestLoader(GenerativeRequestLoader): 8 | """ 9 | A wrapper around GenerativeRequestLoader that allows for modifications to 10 | be made to GuideLLM requests. 11 | 12 | These are both useful when testing structured outputs, e.g. 13 | https://docs.vllm.ai/en/latest/features/structured_outputs.html 14 | """ 15 | 16 | def __init__( 17 | self, 18 | extra_body: dict[str, Any] | None = None, 19 | *, 20 | use_chat_completions: bool = False, 21 | **kwargs: dict[str, Any], 22 | ) -> None: 23 | """ 24 | Create a custom generative request loader. 25 | 26 | :param: extra_body: Extra parameters to add to the body of each request. 27 | :param: use_chat_completions: Whether to use the chat completions endpoint, 28 | as opposed to the text completions endpoint. 29 | :param: kwargs: Additional keyword arguments to pass to the 30 | GenerativeRequestLoader constructor. 31 | """ 32 | 33 | super().__init__(**kwargs) 34 | self.extra_body = extra_body or {} 35 | self.use_chat_completions = use_chat_completions 36 | 37 | def __iter__(self) -> Iterator[GenerationRequest]: 38 | """Iterate over the requests in the loader.""" 39 | 40 | for item in super().__iter__(): 41 | for k, v in self.extra_body.items(): 42 | item.params[k] = v 43 | 44 | if self.use_chat_completions: 45 | item.request_type = "chat_completions" 46 | 47 | yield item 48 | 49 | def __len__(self) -> int: 50 | """Return the number of unique requests in the loader.""" 51 | return super().__len__() 52 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools >= 77.0.3"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "stopwatch" 7 | version = "0.1.0" 8 | requires-python = ">=3.10,<4.0" 9 | dependencies = [ 10 | "guidellm @ git+https://github.com/vllm-project/guidellm.git#af89787", 11 | "modal~=1.1.3", 12 | "pytest~=8.4.1", 13 | "pytest-rerunfailures~=15.1", 14 | "pytest-timeout~=2.4.0", 15 | "pytest-xdist~=3.8.0", 16 | "requests~=2.32.5", 17 | ] 18 | authors = [ 19 | {name = "Jack Cook", email = "hello@jackcook.com"}, 20 | ] 21 | description = "A tool for benchmarking LLMs on Modal" 22 | readme = "README.md" 23 | license = "MIT" 24 | license-files = ["LICENSE.md"] 25 | keywords = ["machine-learning", "llms", "vllm", "sglang", "tensorrt-llm"] 26 | 27 | [project.scripts] 28 | stopwatch = "stopwatch.cli:app" 29 | 30 | [project.urls] 31 | Homepage = "https://github.com/modal-labs/stopwatch" 32 | Documentation = "https://github.com/modal-labs/stopwatch/blob/main/README.md" 33 | Repository = "https://github.com/modal-labs/stopwatch.git" 34 | 35 | [tool.ruff] 36 | 
line-length = 88 37 | indent-width = 4 38 | exclude = ["build", "dist", "env", ".venv"] 39 | 40 | [tool.ruff.format] 41 | quote-style = "double" 42 | indent-style = "space" 43 | 44 | [tool.ruff.lint] 45 | ignore = [ 46 | # Various dumb rules related to docstrings 47 | "D100", 48 | "D104", 49 | "D202", 50 | "D203", 51 | "D205", 52 | "D212", 53 | 54 | "PLC0415", # Allow imports that aren't at the top of the file 55 | "PLR0913", # Allow functions with more than 5 arguments 56 | "S113", # Allow requests without timeouts 57 | "UP017", # Allow timezone.utc on the SGLang image, which uses Python 3.10 58 | ] 59 | select = [ 60 | "ALL", 61 | ] 62 | 63 | [tool.ruff.lint.extend-per-file-ignores] 64 | "src/stopwatch/cli/*.py" = [ 65 | "T201", # Allow print statements in CLI 66 | ] 67 | "src/stopwatch/llm_servers/*.py" = [ 68 | "N801", # Allow class names to be lowercase 69 | "S104", # Allow binding to all interfaces 70 | "S602", # Allow subprocess.Popen with shell=True 71 | ] 72 | "tests/**/test_*.py" = [ 73 | "S101", # Allow assert statements in tests 74 | "PLR2004", # Allow constant values in tests 75 | ] -------------------------------------------------------------------------------- /tests/e2e/test_structured_outputs.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from stopwatch.cli import provision_and_benchmark_cli 4 | from stopwatch.constants import LLMServerType 5 | 6 | CLIENT_CONFIGS = { 7 | LLMServerType.sglang: { 8 | "extra_body": { 9 | "response_format": { 10 | "type": "json_schema", 11 | "json_schema": { 12 | "name": "Summary", 13 | "schema": { 14 | "properties": { 15 | "summary": {"title": "Summary", "type": "string"}, 16 | }, 17 | "required": ["summary"], 18 | "title": "Summary", 19 | "type": "object", 20 | }, 21 | }, 22 | }, 23 | }, 24 | }, 25 | LLMServerType.vllm: { 26 | "extra_body": { 27 | "guided_json": { 28 | "properties": { 29 | "summary": { 30 | "title": "Summary", 31 | "type": "string", 32 | }, 33 | }, 34 | "required": ["summary"], 35 | "title": "Summary", 36 | "type": "object", 37 | }, 38 | }, 39 | }, 40 | } 41 | 42 | SERVER_CONFIGS = { 43 | LLMServerType.sglang: { 44 | "extra_args": ["--grammar-backend", "outlines"], 45 | "image_kwargs": { 46 | "extra_python_packages": ["outlines", "transformers==4.53.3"], 47 | }, 48 | }, 49 | } 50 | 51 | 52 | @pytest.mark.parametrize("llm_server_type", [LLMServerType.vllm, LLMServerType.sglang]) 53 | @pytest.mark.timeout(300) 54 | def test_structured_outputs(llm_server_type: LLMServerType) -> None: 55 | """Test that a quick synchronous benchmark runs successfully.""" 56 | 57 | results = provision_and_benchmark_cli( 58 | "meta-llama/Llama-3.1-8B-Instruct", 59 | llm_server_type, 60 | gpu="H100!", 61 | duration=10, 62 | llm_server_config=SERVER_CONFIGS.get(llm_server_type), 63 | client_config=CLIENT_CONFIGS.get(llm_server_type), 64 | ) 65 | 66 | # Only one benchmark should have been run 67 | assert len(results) == 1 68 | 69 | # At least one successful request should have been made 70 | assert results[0]["run_stats"]["requests_made"]["successful"] > 0 71 | -------------------------------------------------------------------------------- /src/stopwatch/cli/provision_and_benchmark.py: -------------------------------------------------------------------------------- 1 | import json 2 | import uuid 3 | from pathlib import Path 4 | from typing import Annotated 5 | 6 | import modal 7 | import typer 8 | 9 | from stopwatch.benchmark import GuideLLM 10 | from stopwatch.constants import 
LLMServerType, RateType 11 | from stopwatch.llm_servers import create_dynamic_llm_server_class 12 | from stopwatch.resources import app 13 | 14 | from .utils import config_callback 15 | 16 | 17 | def rate_type_callback(ctx: typer.Context, rate_type: RateType) -> RateType: 18 | """Require rate to be provided when rate_type is constant.""" 19 | 20 | if rate_type == RateType.constant and ctx.params.get("rate") is None: 21 | msg = "Rate must be provided when rate_type is constant" 22 | raise typer.BadParameter(msg) 23 | 24 | return rate_type.value 25 | 26 | 27 | def provision_and_benchmark_cli( 28 | model: str, 29 | llm_server_type: LLMServerType, 30 | *, 31 | output_path: str | None = "results.json", 32 | detach: bool = False, 33 | data: str = "prompt_tokens=512,output_tokens=128", 34 | gpu: str = "H100", 35 | server_region: str | None = None, 36 | client_region: str | None = None, 37 | server_cloud: str | None = None, 38 | duration: float | None = 120, 39 | llm_server_config: Annotated[ 40 | str | None, 41 | typer.Option(callback=config_callback), 42 | ] = None, 43 | client_config: Annotated[str | None, typer.Option(callback=config_callback)] = None, 44 | rate_type: Annotated[ 45 | RateType, 46 | typer.Option(callback=rate_type_callback), 47 | ] = RateType.synchronous, 48 | rate: float | None = None, 49 | ) -> list[dict]: 50 | """Run a benchmark.""" 51 | 52 | server_class, server_id = create_dynamic_llm_server_class( 53 | uuid.uuid4().hex[:4], 54 | model, 55 | gpu=gpu, 56 | llm_server_type=llm_server_type, 57 | cloud=server_cloud, 58 | region=server_region, 59 | llm_server_config=llm_server_config, 60 | ) 61 | 62 | with modal.enable_output(), app.run(detach=detach): 63 | results = GuideLLM.with_options(region=client_region)().run_benchmark.remote( 64 | endpoint=f"{server_class().start.get_web_url()}/v1", 65 | model=model, 66 | rate_type=rate_type, 67 | data=data, 68 | duration=duration, 69 | client_config=client_config, 70 | rate=rate, 71 | server_id=server_id, 72 | ) 73 | 74 | if output_path is not None: 75 | with Path(output_path).open("w") as f: 76 | json.dump(results, f, indent=2) 77 | 78 | return results 79 | -------------------------------------------------------------------------------- /src/stopwatch/llm_servers/sglang.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import subprocess 4 | from datetime import datetime, timezone 5 | 6 | import modal 7 | 8 | from stopwatch.constants import HF_CACHE_PATH, MINUTES, LLMServerType 9 | from stopwatch.resources import hf_cache_volume, startup_metrics_dict 10 | 11 | PORT = 30000 12 | 13 | 14 | def sglang_image_factory( 15 | docker_tag: str = LLMServerType.sglang.get_version(), 16 | extra_python_packages: list[str] | None = None, 17 | ) -> modal.Image: 18 | """ 19 | Create a Modal image for running an SGLang server. 20 | 21 | :param: docker_tag: The tag of the SGLang Docker image to use. 22 | :return: A Modal image for running a SGLang server. 
23 | """ 24 | 25 | return ( 26 | modal.Image.from_registry(f"lmsysorg/sglang:{docker_tag}") 27 | .uv_pip_install( 28 | "hf-transfer", 29 | "grpclib", 30 | "requests", 31 | *(extra_python_packages or []), 32 | ) 33 | .env({"HF_HUB_CACHE": HF_CACHE_PATH, "HF_HUB_ENABLE_HF_TRANSFER": "1"}) 34 | .dockerfile_commands("ENTRYPOINT []") 35 | ) 36 | 37 | 38 | class SGLangBase: 39 | """A Modal class that runs an SGLang server.""" 40 | 41 | @modal.web_server(port=PORT, startup_timeout=30 * MINUTES) 42 | def start(self) -> None: 43 | """Start an SGLang server.""" 44 | 45 | # Save the startup time to a dictionary so we can measure cold start duration 46 | startup_metrics_dict[self.server_id] = datetime.now(timezone.utc).timestamp() 47 | 48 | hf_cache_volume.reload() 49 | 50 | if not self.model: 51 | msg = "model must be set, e.g. 'meta-llama/Llama-3.1-8B-Instruct'" 52 | raise ValueError(msg) 53 | 54 | server_config = json.loads(self.server_config) 55 | 56 | # Start SGLang server 57 | subprocess.Popen( 58 | " ".join( 59 | [ 60 | "python", 61 | "-m", 62 | "sglang.launch_server", 63 | "--model-path", 64 | self.model, 65 | "--host", 66 | "0.0.0.0", 67 | *( 68 | ["--tokenizer-path", server_config["tokenizer"]] 69 | if "tokenizer" in server_config 70 | else [] 71 | ), 72 | *server_config.get("extra_args", []), 73 | ], 74 | ) 75 | + f" || python -m http.server {PORT}", 76 | env={ 77 | **os.environ, 78 | **server_config.get("env_vars", {}), 79 | }, 80 | shell=True, 81 | ) 82 | -------------------------------------------------------------------------------- /src/stopwatch/cli/profile.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from pathlib import Path 3 | from typing import Annotated 4 | 5 | import modal 6 | import typer 7 | 8 | from stopwatch.constants import LLMServerType 9 | from stopwatch.llm_servers import create_dynamic_llm_server_class 10 | from stopwatch.profile import profile 11 | from stopwatch.resources import app, traces_volume 12 | 13 | from .utils import config_callback 14 | 15 | TRACES_PATH = "/traces" 16 | 17 | 18 | def llm_server_type_callback(llm_server_type: LLMServerType) -> LLMServerType: 19 | """Require that llm_server_type is supported for profiling.""" 20 | 21 | if llm_server_type in (LLMServerType.vllm, LLMServerType.sglang): 22 | return llm_server_type.value 23 | 24 | msg = "Profiling is only supported with vLLM or SGLang" 25 | raise typer.BadParameter(msg) 26 | 27 | 28 | def profile_cli( 29 | model: str, 30 | llm_server_type: Annotated[ 31 | LLMServerType, 32 | typer.Argument(callback=llm_server_type_callback), 33 | ], 34 | *, 35 | output_path: str = "trace.json.gz", 36 | gpu: str = "H100", 37 | server_region: str = "us-chicago-1", 38 | num_requests: int = 10, 39 | prompt_tokens: int = 512, 40 | output_tokens: int = 8, 41 | llm_server_config: Annotated[ 42 | str | None, 43 | typer.Option(callback=config_callback), 44 | ] = None, 45 | ) -> None: 46 | """Run an LLM server alongside the PyTorch profiler.""" 47 | 48 | if "env_vars" not in llm_server_config: 49 | if llm_server_type == LLMServerType.vllm: 50 | llm_server_config["env_vars"] = { 51 | "VLLM_TORCH_PROFILER_DIR": TRACES_PATH, 52 | "VLLM_RPC_TIMEOUT": "1800000", 53 | } 54 | elif llm_server_type == LLMServerType.sglang: 55 | llm_server_config["env_vars"] = { 56 | "SGLANG_TORCH_PROFILER_DIR": TRACES_PATH, 57 | } 58 | 59 | name = uuid.uuid4().hex[:4] 60 | server_cls, _ = create_dynamic_llm_server_class( 61 | name, 62 | model, 63 | gpu=gpu, 64 | 
llm_server_type=llm_server_type, 65 | region=server_region, 66 | llm_server_config=llm_server_config, 67 | ) 68 | 69 | with modal.enable_output(), app.run(): 70 | fc = profile.spawn( 71 | endpoint=server_cls().start.get_web_url(), 72 | model=model, 73 | num_requests=num_requests, 74 | prompt_tokens=prompt_tokens, 75 | output_tokens=output_tokens, 76 | ) 77 | 78 | print(f"Profiler running at {fc.object_id}...") 79 | trace_path = fc.get() 80 | 81 | with Path(output_path).open("wb") as f: 82 | for chunk in traces_volume.read_file(trace_path): # noqa: FURB122 83 | f.write(chunk) 84 | -------------------------------------------------------------------------------- /src/stopwatch/llm_servers/tokasaurus.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import subprocess 4 | from datetime import datetime, timezone 5 | 6 | import modal 7 | 8 | from stopwatch.constants import ( 9 | HF_CACHE_PATH, 10 | MINUTES, 11 | TOKASAURUS_CUDA_VERSION, 12 | LLMServerType, 13 | ) 14 | from stopwatch.resources import hf_cache_volume, startup_metrics_dict 15 | 16 | PORT = 10210 17 | 18 | 19 | def tokasaurus_image_factory( 20 | version: str = LLMServerType.tokasaurus.get_version(), 21 | cuda_version: str = TOKASAURUS_CUDA_VERSION, 22 | ) -> modal.Image: 23 | """ 24 | Create a Modal image for running a Tokasaurus server. 25 | 26 | :param: version: The version of Tokasaurus to install. 27 | :param: cuda_version: The version of CUDA to start the image from. 28 | :return: A Modal image for running a Tokasaurus server. 29 | """ 30 | 31 | return ( 32 | modal.Image.from_registry( 33 | f"nvidia/cuda:{cuda_version}-devel-ubuntu22.04", 34 | add_python="3.12", 35 | ) 36 | .entrypoint([]) # Remove verbose logging by base image on entry 37 | .apt_install("git") 38 | .uv_pip_install(f"tokasaurus=={version}") 39 | .env( 40 | { 41 | "HF_HUB_CACHE": HF_CACHE_PATH, 42 | "HF_HUB_ENABLE_HF_TRANSFER": "1", 43 | }, 44 | ) 45 | ) 46 | 47 | 48 | class TokasaurusBase: 49 | """A Modal class that runs a Tokasaurus server.""" 50 | 51 | @modal.web_server(port=PORT, startup_timeout=30 * MINUTES) 52 | def start(self) -> None: 53 | """Start a Tokasaurus server.""" 54 | 55 | # Save the startup time to a dictionary so we can measure cold start duration 56 | startup_metrics_dict[self.server_id] = datetime.now(timezone.utc).timestamp() 57 | 58 | hf_cache_volume.reload() 59 | 60 | if not self.model: 61 | msg = "model must be set, e.g. 
'meta-llama/Llama-3.1-8B-Instruct'" 62 | raise ValueError(msg) 63 | 64 | server_config = json.loads(self.server_config) 65 | 66 | # Start Tokasaurus server 67 | subprocess.Popen( 68 | " ".join( 69 | [ 70 | "toka", 71 | f"model={self.model}", 72 | *( 73 | [f"tokenizer={server_config['tokenizer']}"] 74 | if "tokenizer" in server_config 75 | else [] 76 | ), 77 | *server_config.get("extra_args", []), 78 | ], 79 | ) 80 | + f" || python -m http.server {PORT}", 81 | env={ 82 | **os.environ, 83 | **server_config.get("env_vars", {}), 84 | }, 85 | shell=True, 86 | ) 87 | -------------------------------------------------------------------------------- /src/stopwatch/llm_servers/vllm.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import subprocess 4 | from datetime import datetime, timezone 5 | 6 | import modal 7 | 8 | from stopwatch.constants import HF_CACHE_PATH, HOURS, LLMServerType 9 | from stopwatch.resources import hf_cache_volume, startup_metrics_dict, vllm_cache_volume 10 | 11 | PORT = 8000 12 | VLLM_PYTHON_BINARY = "/usr/bin/python3" 13 | 14 | 15 | def vllm_image_factory( 16 | docker_tag: str = LLMServerType.vllm.get_version(), 17 | extra_dockerfile_commands: list[str] | None = None, 18 | ) -> modal.Image: 19 | """ 20 | Create a Modal image for running a vLLM server. 21 | 22 | :param: docker_tag: The tag of the vLLM Docker image to use. 23 | :param: extra_dockerfile_commands: Extra Dockerfile commands to add to the image. 24 | :return: A Modal image for running a vLLM server. 25 | """ 26 | 27 | return ( 28 | modal.Image.from_registry( 29 | f"vllm/vllm-openai:{docker_tag}", 30 | add_python="3.13", 31 | ) 32 | .uv_pip_install("hf-transfer", "grpclib", "requests", "typer") 33 | .env( 34 | { 35 | "HF_HUB_CACHE": HF_CACHE_PATH, 36 | "HF_HUB_ENABLE_HF_TRANSFER": "1", 37 | "VLLM_SKIP_P2P_CHECK": "1", 38 | }, 39 | ) 40 | .dockerfile_commands(*(extra_dockerfile_commands or []), "ENTRYPOINT []") 41 | ) 42 | 43 | 44 | class vLLMBase: 45 | """A Modal class that runs a vLLM server.""" 46 | 47 | @modal.web_server(port=PORT, startup_timeout=1 * HOURS) 48 | def start(self) -> None: 49 | """Start a vLLM server.""" 50 | 51 | # Save the startup time to a dictionary so we can measure cold start duration 52 | startup_metrics_dict[self.server_id] = datetime.now(timezone.utc).timestamp() 53 | 54 | hf_cache_volume.reload() 55 | vllm_cache_volume.reload() 56 | 57 | if not self.model: 58 | msg = "model must be set, e.g. 
'meta-llama/Llama-3.1-8B-Instruct'" 59 | raise ValueError(msg) 60 | 61 | server_config = json.loads(self.server_config) 62 | 63 | # Start vLLM server 64 | subprocess.Popen( 65 | " ".join( 66 | [ 67 | VLLM_PYTHON_BINARY, 68 | "-m", 69 | "vllm.entrypoints.openai.api_server", 70 | "--model", 71 | self.model, 72 | *( 73 | ["--tokenizer", server_config["tokenizer"]] 74 | if "tokenizer" in server_config 75 | else [] 76 | ), 77 | *server_config.get("extra_args", []), 78 | ], 79 | ) 80 | + f" || {VLLM_PYTHON_BINARY} -m http.server {PORT}", 81 | env={ 82 | **os.environ, 83 | **server_config.get("env_vars", {}), 84 | }, 85 | shell=True, 86 | ) 87 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # stopwatch 2 | 3 | _A simple solution for benchmarking [vLLM](https://docs.vllm.ai/en/latest/), [SGLang](https://docs.sglang.ai/), and [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) on [Modal](https://modal.com/)._ ⏱️ 4 | 5 | ## Setup 6 | 7 | ### Install dependencies 8 | 9 | ```bash 10 | pip install -e . 11 | ``` 12 | 13 | ## Run a benchmark 14 | 15 | To run a single benchmark, you can use the `provision-and-benchmark` command, which will provision an LLM server, benchmark it, and save the results to a local file. 16 | For example, to run a synchronous (one request after another) benchmark with vLLM and save the results to `results.json`: 17 | 18 | ```bash 19 | LLM_SERVER_TYPE=vllm 20 | MODEL=meta-llama/Llama-3.1-8B-Instruct 21 | OUTPUT_PATH=results.json 22 | 23 | stopwatch provision-and-benchmark $MODEL $LLM_SERVER_TYPE --output-path $OUTPUT_PATH 24 | ``` 25 | 26 | Or, to run a fixed-rate (e.g. 5 requests per second) multi-GPU benchmark with SGLang: 27 | 28 | ```bash 29 | GPU_COUNT=4 30 | GPU_TYPE=H100 31 | LLM_SERVER_TYPE=sglang 32 | RATE_TYPE=constant 33 | REQUESTS_PER_SECOND=5 34 | 35 | stopwatch provision-and-benchmark $MODEL $LLM_SERVER_TYPE --output-path $OUTPUT_PATH --gpu "$GPU_TYPE:$GPU_COUNT" --rate-type $RATE_TYPE --rate $REQUESTS_PER_SECOND --llm-server-config "{\"extra_args\": [\"--tp-size\", \"$GPU_COUNT\"]}" 36 | ``` 37 | 38 | Or, to run a throughput (as many requests as the server can handle) test with TensorRT-LLM: 39 | 40 | ```bash 41 | LLM_SERVER_TYPE=tensorrt-llm 42 | RATE_TYPE=throughput 43 | 44 | stopwatch provision-and-benchmark $MODEL $LLM_SERVER_TYPE --output-path $OUTPUT_PATH --rate-type $RATE_TYPE 45 | ``` 46 | 47 | ## Run the profiler 48 | 49 | To profile a server with the PyTorch profiler, use the following command (only vLLM and SGLang are currently supported): 50 | 51 | ```bash 52 | LLM_SERVER_TYPE=vllm 53 | MODEL=meta-llama/Llama-3.1-8B-Instruct 54 | NUM_REQUESTS=10 55 | OUTPUT_PATH=trace.json.gz 56 | 57 | stopwatch profile $MODEL $LLM_SERVER_TYPE --output-path $OUTPUT_PATH --num-requests $NUM_REQUESTS 58 | ``` 59 | 60 | Once the profiling is done, the trace will be saved to `trace.json.gz`, which you can open and visualize at [https://ui.perfetto.dev](https://ui.perfetto.dev). 61 | Keep in mind that generated traces can get very large, so it is recommended to only send a few requests while profiling. 62 | 63 | ## Run tests 64 | 65 | Before committing any changes, you should make sure that your changes don't break any core functionality in Stopwatch. 
66 | You may verify this with: 67 | 68 | ```bash 69 | pytest 70 | ``` 71 | 72 | ### Lint 73 | 74 | To make sure that any code changes are compliant with our linting rules, you can run `ruff` with: 75 | 76 | ```bash 77 | ruff check 78 | ``` 79 | 80 | ## Contributing 81 | 82 | We welcome contributions, including those that add tuned benchmarks to our collection. 83 | See the [CONTRIBUTING](/CONTRIBUTING.md) file and the [Getting Started](https://github.com/modal-labs/big-benchmark/wiki/Getting-Started) document for more details on contributing to Stopwatch. 84 | 85 | ## License 86 | 87 | Stopwatch is available under the MIT license. See the [LICENSE](/LICENSE.md) file for more details. 88 | -------------------------------------------------------------------------------- /tools/pd_disaggregation_proxy_server.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/vllm-project/vllm/blob/559756214b770d0405939a05172804221c2f5677/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py 2 | 3 | # SPDX-License-Identifier: Apache-2.0 4 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project 5 | 6 | import logging 7 | import os 8 | from collections.abc import AsyncGenerator 9 | from typing import Any 10 | 11 | import aiohttp 12 | from quart import Quart, Response, make_response, request 13 | 14 | AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) 15 | 16 | app = Quart(__name__) 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | async def forward_request( 21 | url: str, 22 | *, 23 | body: dict[str, Any] | None = None, 24 | method: str = "POST", 25 | ) -> AsyncGenerator[bytes, None]: 26 | """Forward a request to the prefill and decode servers.""" 27 | 28 | async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: 29 | headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} 30 | async with session.request( 31 | method=method, 32 | url=url, 33 | json=body, 34 | headers=headers, 35 | ) as response: 36 | if response.status == 200: # noqa: PLR2004 37 | if response.headers.get("Transfer-Encoding") == "chunked": 38 | # if True: 39 | async for chunk_bytes in response.content.iter_chunked(1024): 40 | yield chunk_bytes 41 | else: 42 | content = await response.read() 43 | yield content 44 | 45 | 46 | @app.route("/v1/completions", methods=["POST"]) 47 | async def handle_request() -> Response: 48 | """ 49 | Forward the request to the prefill and decode servers, and then return the response 50 | from the decode server. 
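The prefill copy of the request is sent with max_tokens=1 so that only the prefill phase runs on the prefill server; the original request is then forwarded to the decode server, which continues generation using the KV cache produced during prefill (transferred via the KV connector configured in vllm_pd_disaggregation.py).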
51 | """ 52 | 53 | try: 54 | original_request_data = await request.get_json() 55 | 56 | prefill_request = original_request_data.copy() 57 | # change max_tokens = 1 to let it only do prefill 58 | prefill_request["max_tokens"] = 1 59 | 60 | # finish prefill 61 | async for _ in forward_request( 62 | "http://localhost:8100/v1/completions", 63 | body=prefill_request, 64 | ): 65 | continue 66 | 67 | # return from decode server 68 | generator = forward_request( 69 | "http://localhost:8200/v1/completions", 70 | body=original_request_data, 71 | ) 72 | response = await make_response(generator) 73 | response.timeout = None 74 | except Exception: 75 | logger.exception("Error occurred in disagg prefill proxy server") 76 | else: 77 | return response 78 | 79 | 80 | @app.route("/v1/models") 81 | async def handle_models() -> Response: 82 | """Return the models from the prefill server.""" 83 | generator = forward_request("http://localhost:8100/v1/models", method="GET") 84 | response = await make_response(generator) 85 | response.timeout = None 86 | return response 87 | 88 | 89 | @app.route("/ping") 90 | def handle_ping() -> str: 91 | """Return a 200 status code.""" 92 | return "pong" 93 | 94 | 95 | if __name__ == "__main__": 96 | app.run(host="0.0.0.0", port=8000) # noqa: S104 97 | -------------------------------------------------------------------------------- /src/stopwatch/profile.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | 4 | import modal 5 | 6 | from .constants import GUIDELLM_VERSION 7 | from .resources import app, hf_secret, traces_volume 8 | 9 | TIMEOUT = 60 * 60 # 1 hour 10 | TRACES_PATH = "/traces" 11 | 12 | logger = logging.getLogger(__name__) 13 | logging.basicConfig(level=logging.INFO) 14 | 15 | 16 | profiling_image = ( 17 | modal.Image.debian_slim() 18 | .apt_install("git") 19 | .uv_pip_install( 20 | f"git+https://github.com/vllm-project/guidellm.git#{GUIDELLM_VERSION}", 21 | "openai", 22 | ) 23 | ) 24 | 25 | with profiling_image.imports(): 26 | import requests 27 | from guidellm.dataset import SyntheticDatasetConfig, SyntheticTextItemsGenerator 28 | from openai import OpenAI 29 | from transformers import AutoTokenizer 30 | 31 | 32 | @app.function( 33 | image=profiling_image, 34 | secrets=[hf_secret], 35 | volumes={TRACES_PATH: traces_volume}, 36 | timeout=TIMEOUT, 37 | ) 38 | def profile( 39 | endpoint: str, 40 | model: str, 41 | num_requests: int = 10, 42 | prompt_tokens: int = 512, 43 | output_tokens: int = 8, 44 | ) -> str: 45 | """ 46 | Run the PyTorch profiler alongside an LLM server. Currently, only vLLM is 47 | supported. 48 | 49 | :param: endpoint: The endpoint of the OpenAI-compatible LLM server to use. 50 | :param: model: The model to use. 51 | :param: num_requests: The number of requests to make. Traces get large very 52 | quickly, so this should be kept small. 53 | :param: prompt_tokens: The number of tokens to include in each request's prompt. 54 | :param: output_tokens: The number of tokens to generate in each request. 55 | :return: The path to the trace file. 
56 | """ 57 | 58 | logger.info("Starting profiler with %s", model) 59 | 60 | generator_config = SyntheticDatasetConfig( 61 | prompt_tokens=prompt_tokens, 62 | output_tokens=output_tokens, 63 | ) 64 | tokenizer = AutoTokenizer.from_pretrained(model) 65 | text_generator = iter( 66 | SyntheticTextItemsGenerator( 67 | config=generator_config, 68 | processor=tokenizer, 69 | random_seed=42, 70 | ), 71 | ) 72 | 73 | # Start profiler 74 | requests.post(f"{endpoint}/start_profile") 75 | 76 | # Start vLLM server in background 77 | client = OpenAI(api_key="EMPTY", base_url=f"{endpoint}/v1") 78 | 79 | for _ in range(num_requests): 80 | client.completions.create( 81 | model=model, 82 | prompt=next(text_generator)["prompt"], 83 | max_tokens=output_tokens, 84 | echo=False, 85 | stream=False, 86 | ) 87 | 88 | # Stop profiler 89 | requests.post(f"{endpoint}/stop_profile") 90 | 91 | # Find and return trace path 92 | most_recent_path = None 93 | most_recent_size = 0 94 | most_recent_timestamp = 0 95 | 96 | for file in traces_volume.iterdir("/"): 97 | if file.mtime > most_recent_timestamp: 98 | most_recent_path = file.path 99 | most_recent_size = file.size 100 | most_recent_timestamp = file.mtime 101 | 102 | # Wait for profiler to finish writing profiling output before returning 103 | while True: 104 | time.sleep(5) 105 | 106 | traces_volume.reload() 107 | 108 | for file in traces_volume.iterdir("/"): 109 | if file.path == most_recent_path: 110 | latest_size = file.size 111 | break 112 | 113 | if latest_size == most_recent_size: 114 | break 115 | 116 | most_recent_size = latest_size 117 | 118 | return most_recent_path 119 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Stopwatch 2 | 3 | Thank you for taking the time to contribute to Stopwatch! 4 | All types of contributions are encouraged, but please first discuss the change you wish to make via issue, email, or another method with the owners of this repository. 5 | 6 | ## Adding benchmarks 7 | 8 | To generate the results you can view at [almanac.modal.com](https://almanac.modal.com), we use Big Benchmark, which runs Stopwatch internally. 9 | If you would like to contribute to this repository of benchmarks, please check out [Big Benchmark](https://github.com/modal-labs/big-benchmark) for more details. 10 | 11 | ## Code of Conduct 12 | 13 | ### Our Pledge 14 | 15 | In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation. 
16 | 17 | ### Our Standards 18 | 19 | Examples of behavior that contributes to creating a positive environment include: 20 | 21 | - Using welcoming and inclusive language 22 | - Being respectful of differing viewpoints and experiences 23 | - Gracefully accepting constructive criticism 24 | - Focusing on what is best for the community 25 | - Showing empathy towards other community members 26 | 27 | Examples of unacceptable behavior by participants include: 28 | 29 | - The use of sexualized language or imagery and unwelcome sexual attention or advances 30 | - Trolling, insulting/derogatory comments, and personal or political attacks 31 | - Public or private harassment 32 | - Publishing others' private information, such as a physical or electronic address, without explicit permission 33 | - Other conduct which could reasonably be considered inappropriate in a professional setting 34 | 35 | ### Our Responsibilities 36 | 37 | Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. 38 | 39 | Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. 40 | 41 | ### Scope 42 | 43 | This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. 44 | 45 | ### Enforcement 46 | 47 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at [INSERT EMAIL ADDRESS]. All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. 48 | 49 | Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. 
50 | 51 | ### Attribution 52 | 53 | This Code of Conduct is adapted from the Contributor Covenant, version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct/ 54 | -------------------------------------------------------------------------------- /src/stopwatch/llm_servers/vllm_pd_disaggregation.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import subprocess 4 | from datetime import datetime, timezone 5 | 6 | import modal 7 | 8 | from stopwatch.constants import HOURS, LLMServerType 9 | from stopwatch.resources import ( 10 | hf_cache_volume, 11 | startup_metrics_dict, 12 | vllm_cache_volume, 13 | ) 14 | 15 | from .vllm import PORT, VLLM_PYTHON_BINARY, vllm_image_factory 16 | 17 | DECODE_PORT = 8200 18 | PREFILL_PORT = 8100 19 | PROXY_SERVER_SCRIPT = "/root/tools/pd_disaggregation_proxy_server.py" 20 | 21 | 22 | def vllm_pd_disaggregation_image_factory( 23 | docker_tag: str = LLMServerType.vllm.get_version(), 24 | ) -> modal.Image: 25 | """ 26 | Create a Modal image for running a vLLM server with PD disaggregation. 27 | 28 | :param: docker_tag: The tag of the vLLM Docker image to use. 29 | :return: A Modal image for running a vLLM server with PD disaggregation. 30 | """ 31 | 32 | return vllm_image_factory( 33 | docker_tag, 34 | extra_dockerfile_commands=[ 35 | f"RUN {VLLM_PYTHON_BINARY} -m pip install quart --ignore-installed", 36 | ], 37 | ).add_local_python_source("tools") 38 | 39 | 40 | class vLLMPDDisaggregationBase: 41 | """A Modal class that runs a vLLM server with PD disaggregation.""" 42 | 43 | @modal.web_server(port=PORT, startup_timeout=1 * HOURS) 44 | def start(self) -> None: 45 | """Start a vLLM server with PD disaggregation.""" 46 | 47 | # Save the startup time to a dictionary so we can measure cold start duration 48 | startup_metrics_dict[self.server_id] = datetime.now(timezone.utc).timestamp() 49 | 50 | hf_cache_volume.reload() 51 | vllm_cache_volume.reload() 52 | 53 | if not self.model: 54 | msg = "model must be set, e.g. 'meta-llama/Llama-3.1-8B-Instruct'" 55 | raise ValueError(msg) 56 | 57 | server_config = json.loads(self.server_config) 58 | 59 | # vLLM currently only supports 2-GPU setups for disaggregated prefill. See:
See: 60 | # https://github.com/vllm-project/vllm/issues/13004 61 | prefill_devices = "0" 62 | decode_devices = "1" 63 | 64 | # Start prefill server 65 | subprocess.Popen( 66 | " ".join( 67 | [ 68 | VLLM_PYTHON_BINARY, 69 | "-m", 70 | "vllm.entrypoints.openai.api_server", 71 | "--model", 72 | self.model, 73 | *( 74 | ["--tokenizer", server_config["tokenizer"]] 75 | if "tokenizer" in server_config 76 | else [] 77 | ), 78 | "--kv-transfer-config", 79 | '\'{"kv_connector":"SharedStorageConnector","kv_role":"kv_both","kv_connector_extra_config":{"shared_storage_path":"local_storage"}}\'', 80 | "--port", 81 | str(PREFILL_PORT), 82 | *server_config.get("extra_args", []), 83 | ], 84 | ) 85 | + f" || {VLLM_PYTHON_BINARY} -m http.server {PREFILL_PORT}", 86 | env={ 87 | **os.environ, 88 | **server_config.get("env_vars", {}), 89 | "CUDA_VISIBLE_DEVICES": prefill_devices, 90 | }, 91 | shell=True, 92 | ) 93 | 94 | # Start decode server 95 | subprocess.Popen( 96 | " ".join( 97 | [ 98 | VLLM_PYTHON_BINARY, 99 | "-m", 100 | "vllm.entrypoints.openai.api_server", 101 | "--model", 102 | self.model, 103 | *( 104 | ["--tokenizer", server_config["tokenizer"]] 105 | if "tokenizer" in server_config 106 | else [] 107 | ), 108 | "--kv-transfer-config", 109 | '\'{"kv_connector":"SharedStorageConnector","kv_role":"kv_both","kv_connector_extra_config":{"shared_storage_path":"local_storage"}}\'', 110 | "--port", 111 | str(DECODE_PORT), 112 | *server_config.get("extra_args", []), 113 | ], 114 | ) 115 | + f" || {VLLM_PYTHON_BINARY} -m http.server {DECODE_PORT}", 116 | env={ 117 | **os.environ, 118 | **server_config.get("env_vars", {}), 119 | "CUDA_VISIBLE_DEVICES": decode_devices, 120 | }, 121 | shell=True, 122 | ) 123 | 124 | # Wait for both servers to start 125 | import time 126 | 127 | import requests 128 | 129 | for port in [PREFILL_PORT, DECODE_PORT]: 130 | while True: 131 | try: 132 | requests.get(f"http://localhost:{port}/v1/completions") 133 | break 134 | except requests.exceptions.ConnectionError: 135 | time.sleep(1) 136 | 137 | # Run proxy server 138 | subprocess.Popen(f"{VLLM_PYTHON_BINARY} {PROXY_SERVER_SCRIPT}", shell=True) 139 | -------------------------------------------------------------------------------- /src/stopwatch/llm_servers/tensorrt_llm.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import json 3 | import logging 4 | import os 5 | import subprocess 6 | import time 7 | import traceback 8 | from datetime import datetime, timezone 9 | from pathlib import Path 10 | 11 | import modal 12 | 13 | from stopwatch.constants import ( 14 | HF_CACHE_PATH, 15 | HOURS, 16 | TENSORRT_LLM_CUDA_VERSION, 17 | LLMServerType, 18 | ) 19 | from stopwatch.resources import hf_cache_volume, startup_metrics_dict 20 | 21 | LLM_KWARGS_PATH = "llm_kwargs.yaml" 22 | PORT = 8000 23 | 24 | logger = logging.getLogger(__name__) 25 | logging.basicConfig(level=logging.INFO) 26 | 27 | 28 | def tensorrt_llm_image_factory( 29 | tensorrt_llm_version: str = LLMServerType.tensorrt_llm.get_version(), 30 | cuda_version: str = TENSORRT_LLM_CUDA_VERSION, 31 | ) -> modal.Image: 32 | """ 33 | Create a Modal image for running a TensorRT-LLM server. 34 | 35 | :param: tensorrt_llm_version: The version of TensorRT-LLM to install. 36 | :param: cuda_version: The version of CUDA to start the image from. 37 | :return: A Modal image for running a TensorRT-LLM server. 
38 | """ 39 | 40 | return ( 41 | modal.Image.from_registry( 42 | f"nvidia/cuda:{cuda_version}-devel-ubuntu24.04", 43 | add_python="3.12", 44 | ) 45 | .entrypoint([]) # Remove verbose logging by base image on entry 46 | .apt_install("libopenmpi-dev", "git", "git-lfs", "wget") 47 | .uv_pip_install( 48 | # This next line is not needed for tensorrt-llm>=1.1.0rc0, but 1.1.0 has a 49 | # bug when running HF models at the moment. Once 1.1.0 is stable, the line 50 | # for cuda-python can be removed. 51 | "cuda-python<13.0", 52 | "pynvml", 53 | f"tensorrt-llm=={tensorrt_llm_version}", 54 | extra_index_url="https://pypi.nvidia.com", 55 | ) 56 | .uv_pip_install( 57 | "hf-transfer", 58 | "huggingface-hub[hf_xet]", 59 | "requests", 60 | ) 61 | .env( 62 | { 63 | "HF_HUB_CACHE": HF_CACHE_PATH, 64 | "HF_HUB_ENABLE_HF_TRANSFER": "1", 65 | "PMIX_MCA_gds": "hash", 66 | }, 67 | ) 68 | ) 69 | 70 | 71 | class TensorRTLLMBase: 72 | """A Modal class that runs a TensorRT-LLM server.""" 73 | 74 | @modal.enter() 75 | def enter(self) -> None: 76 | """Download the base model and build the TensorRT-LLM engine.""" 77 | import tensorrt_llm 78 | import yaml 79 | 80 | # This entire function needs to be wrapped in a try/except block. If an error 81 | # occurs here, a crash will result in the container getting automatically 82 | # restarted indefinitely. By logging the error and returning, the container 83 | # will fail on the trtllm-serve command and then start the http.server module, 84 | # which will return a 404 error that will be returned to the client. 85 | 86 | try: 87 | server_config = json.loads(self.server_config) 88 | llm_kwargs = server_config.get("llm_kwargs", {}) 89 | 90 | if len(llm_kwargs) > 0: 91 | logger.info("Received server config") 92 | logger.info(server_config) 93 | 94 | engine_fingerprint = hashlib.md5( # noqa: S324 95 | json.dumps(llm_kwargs, sort_keys=True).encode(), 96 | ).hexdigest() 97 | logger.info("Engine fingerprint: %s", engine_fingerprint) 98 | logger.info("%s", llm_kwargs) 99 | 100 | self.config_path = ( 101 | Path(HF_CACHE_PATH) 102 | / "tensorrt-llm-configs" 103 | / f"{tensorrt_llm.__version__}-{self.model}-{engine_fingerprint}" 104 | / LLM_KWARGS_PATH 105 | ) 106 | logger.info("Config path: %s", self.config_path) 107 | 108 | if not self.config_path.exists(): 109 | # Save the config 110 | if not self.config_path.parent.exists(): 111 | self.config_path.parent.mkdir(parents=True) 112 | 113 | with self.config_path.open("w") as f: 114 | yaml.dump(llm_kwargs, f) 115 | except Exception: # noqa: BLE001 116 | traceback.print_exc() 117 | 118 | @modal.web_server(port=PORT, startup_timeout=1 * HOURS) 119 | def start(self) -> None: 120 | """Start a TensorRT-LLM server.""" 121 | 122 | # Save the startup time to a dictionary so we can measure cold start duration 123 | startup_metrics_dict[self.server_id] = datetime.now(timezone.utc).timestamp() 124 | 125 | if not self.model: 126 | msg = "model must be set, e.g. 'meta-llama/Llama-3.1-8B-Instruct'" 127 | raise ValueError(msg) 128 | 129 | server_config = json.loads(self.server_config) 130 | 131 | if hasattr(self, "config_path"): 132 | # Make sure the volume is up-to-date and this container has access to the 133 | # saved config. 134 | for _ in range(10): 135 | if self.config_path.exists(): 136 | break 137 | time.sleep(5) 138 | hf_cache_volume.reload() 139 | else: 140 | # self.config_path is only not set if there was an error in enter(). 
By 141 | # setting self.config_path to "none", trtllm-serve will fail, as explained 142 | # in the comment at the start of enter(). 143 | self.config_path = None 144 | 145 | # Start TensorRT-LLM server 146 | subprocess.Popen( 147 | " ".join( 148 | [ 149 | "trtllm-serve", 150 | self.model, 151 | "--host", 152 | "0.0.0.0", 153 | *( 154 | ["--extra_llm_api_options", str(self.config_path)] 155 | if self.config_path is not None 156 | else [] 157 | ), 158 | *( 159 | ["--tokenizer", server_config["tokenizer"]] 160 | if "tokenizer" in server_config 161 | else [] 162 | ), 163 | *server_config.get("extra_args", []), 164 | ], 165 | ) 166 | + f" || python -m http.server {PORT}", 167 | env={ 168 | **os.environ, 169 | **server_config.get("env_vars", {}), 170 | }, 171 | shell=True, 172 | ) 173 | -------------------------------------------------------------------------------- /src/stopwatch/llm_servers/dynamic.py: -------------------------------------------------------------------------------- 1 | import io 2 | import json 3 | from pathlib import Path 4 | from typing import Any 5 | 6 | import modal 7 | 8 | from stopwatch.constants import ( 9 | DB_PATH, 10 | HF_CACHE_PATH, 11 | HOURS, 12 | MINUTES, 13 | SECONDS, 14 | TRACES_PATH, 15 | VLLM_CACHE_PATH, 16 | LLMServerType, 17 | ) 18 | from stopwatch.resources import ( 19 | app, 20 | db_volume, 21 | hf_cache_volume, 22 | hf_secret, 23 | traces_volume, 24 | vllm_cache_volume, 25 | ) 26 | 27 | from .sglang import SGLangBase, sglang_image_factory 28 | from .tensorrt_llm import TensorRTLLMBase, tensorrt_llm_image_factory 29 | from .tokasaurus import TokasaurusBase, tokasaurus_image_factory 30 | from .vllm import ( 31 | vllm_image_factory, 32 | vLLMBase, 33 | ) 34 | from .vllm_pd_disaggregation import ( 35 | vllm_pd_disaggregation_image_factory, 36 | ) 37 | 38 | 39 | def get_llm_server_class(llm_server_type: LLMServerType) -> type: 40 | """Get the base class for creating an LLM server with a given type.""" 41 | 42 | llm_server_classes = { 43 | LLMServerType.sglang: SGLangBase, 44 | LLMServerType.tensorrt_llm: TensorRTLLMBase, 45 | LLMServerType.tokasaurus: TokasaurusBase, 46 | LLMServerType.vllm: vLLMBase, 47 | LLMServerType.vllm_pd_disaggregation: vLLMBase, 48 | } 49 | 50 | return llm_server_classes[llm_server_type] 51 | 52 | 53 | def get_image( 54 | llm_server_type: LLMServerType, 55 | llm_server_config: dict[str, Any], 56 | ) -> modal.Image: 57 | """Create an image for an LLM server with a given type and configuration.""" 58 | image_factory_fn = { 59 | LLMServerType.sglang: sglang_image_factory, 60 | LLMServerType.tensorrt_llm: tensorrt_llm_image_factory, 61 | LLMServerType.tokasaurus: tokasaurus_image_factory, 62 | LLMServerType.vllm: vllm_image_factory, 63 | LLMServerType.vllm_pd_disaggregation: vllm_pd_disaggregation_image_factory, 64 | } 65 | 66 | return image_factory_fn[llm_server_type]( 67 | llm_server_config.get("version", llm_server_type.get_version()), 68 | **llm_server_config.get("image_kwargs", {}), 69 | ) 70 | 71 | 72 | def get_scaledown_window(llm_server_type: LLMServerType) -> int: 73 | """Get the scaledown window for an LLM server with a given type.""" 74 | 75 | scaledown_windows = { 76 | LLMServerType.sglang: 2 * MINUTES, 77 | LLMServerType.tensorrt_llm: 30 * SECONDS, 78 | LLMServerType.tokasaurus: 2 * MINUTES, 79 | LLMServerType.vllm: 30 * SECONDS, 80 | LLMServerType.vllm_pd_disaggregation: 30 * SECONDS, 81 | } 82 | 83 | return scaledown_windows[llm_server_type] 84 | 85 | 86 | def get_timeout(llm_server_type: LLMServerType) -> int: 87 | """Get 
the timeout for an LLM server with a given type.""" 88 | 89 | timeouts = { 90 | LLMServerType.sglang: 30 * MINUTES, 91 | LLMServerType.tensorrt_llm: 30 * MINUTES, 92 | LLMServerType.tokasaurus: 30 * MINUTES, 93 | LLMServerType.vllm: 1 * HOURS, 94 | LLMServerType.vllm_pd_disaggregation: 1 * HOURS, 95 | } 96 | 97 | return timeouts[llm_server_type] 98 | 99 | 100 | def get_volumes(llm_server_type: LLMServerType) -> dict[str, modal.Volume]: 101 | """Get the volumes for an LLM server with a given type.""" 102 | 103 | volumes = { 104 | DB_PATH: db_volume, 105 | HF_CACHE_PATH: hf_cache_volume, 106 | } 107 | 108 | if llm_server_type in (LLMServerType.vllm, LLMServerType.vllm_pd_disaggregation): 109 | volumes[VLLM_CACHE_PATH] = vllm_cache_volume 110 | 111 | if llm_server_type in (LLMServerType.vllm, LLMServerType.sglang): 112 | volumes[TRACES_PATH] = traces_volume 113 | 114 | return volumes 115 | 116 | 117 | def LLMServerClassFactory( # noqa: N802 118 | name: str, 119 | model: str, 120 | llm_server_type: LLMServerType, 121 | llm_server_config: dict[str, Any] | None = None, 122 | ) -> type: 123 | """ 124 | Create an LLM server class. 125 | 126 | :param: name: The name of the class. 127 | :param: model: Name of the model deployed on this server. 128 | :param: llm_server_type: Type of LLM server. 129 | :param: llm_server_config: Extra configuration for the LLM server. 130 | :return: A server class that hosts an OpenAI-compatible API endpoint. 131 | """ 132 | 133 | server_class = get_llm_server_class(llm_server_type) 134 | 135 | return type( 136 | name, 137 | (server_class,), 138 | { 139 | "model": model, 140 | "server_config": json.dumps(llm_server_config or {}), 141 | "server_id": name, 142 | "__annotations__": {}, 143 | }, 144 | ) 145 | 146 | 147 | def create_dynamic_llm_server_class( 148 | name: str, 149 | model: str, 150 | *, 151 | gpu: str, 152 | llm_server_type: LLMServerType, 153 | cpu: int | None = None, 154 | memory: int | None = None, 155 | min_containers: int | None = None, 156 | max_containers: int | None = None, 157 | cloud: str | None = None, 158 | region: str | None = None, 159 | llm_server_config: dict[str, Any] | None = None, 160 | max_concurrent_inputs: int = 1000, 161 | batch: modal.volume.AbstractVolumeUploadContextManager | None = None, 162 | ) -> tuple[type, str]: 163 | """ 164 | Create an LLM server class on the fly that will be included in the deployed Modal 165 | app. 
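    :param: name: The name of the class. It is prefixed with "LLM_" if it does not already start with that.
    :param: model: Name of the model deployed on this server.
    :param: gpu: GPU specification, e.g. "H100" or "H100:8" for eight GPUs per container.
    :param: llm_server_type: Type of LLM server.
    :param: cpu: CPU cores per container. Defaults to 4 per GPU.
    :param: memory: Memory per container in MiB. Defaults to 8 GiB per GPU.
    :param: min_containers: Minimum number of containers to keep running. Defaults to 0.
    :param: max_containers: Maximum number of containers to scale out to. Defaults to 1.
    :param: cloud: Cloud provider to run the server on.
    :param: region: Region to run the server in.
    :param: llm_server_config: Extra configuration for the LLM server.
    :param: max_concurrent_inputs: Maximum number of concurrent requests per container.
    :param: batch: An already-open batch upload to the DB volume to reuse when saving the server config; if None, a new batch upload is started.
    :return: A tuple of the created server class and its (possibly "LLM_"-prefixed) name.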
166 | """ 167 | 168 | # Set default values for parameters that are not provided 169 | num_gpus = 1 if ":" not in gpu else int(gpu.split(":")[1]) 170 | 171 | if llm_server_config is None: 172 | llm_server_config = {} 173 | 174 | # Name must start with "LLM_" 175 | if not name.startswith("LLM_"): 176 | name = f"LLM_{name}" 177 | 178 | # Save server config to the DB volume so the class can be recreated later 179 | server_config = { 180 | "model": model, 181 | "llm_server_type": llm_server_type.value, 182 | "llm_server_config": llm_server_config, 183 | } 184 | 185 | config_buf = io.BytesIO(json.dumps(server_config).encode()) 186 | config_path = f"deployments/{name}.json" 187 | 188 | if batch is None: 189 | with db_volume.batch_upload(force=True) as b: 190 | b.put_file(config_buf, config_path) 191 | else: 192 | batch.put_file(config_buf, config_path) 193 | 194 | # Deploy the newly created class 195 | return ( 196 | app.cls( 197 | image=get_image(llm_server_type, llm_server_config), 198 | secrets=[hf_secret], 199 | gpu=gpu, 200 | volumes=get_volumes(llm_server_type), 201 | cpu=cpu or 4 * num_gpus, 202 | memory=memory or 8 * 1024 * num_gpus, 203 | min_containers=min_containers or 0, 204 | max_containers=max_containers or 1, 205 | scaledown_window=get_scaledown_window(llm_server_type), 206 | timeout=get_timeout(llm_server_type), 207 | cloud=cloud, 208 | region=region, 209 | )( 210 | modal.concurrent(max_inputs=max_concurrent_inputs)( 211 | LLMServerClassFactory( 212 | name, 213 | model, 214 | llm_server_type, 215 | llm_server_config, 216 | ), 217 | ), 218 | ), 219 | name, 220 | ) 221 | 222 | 223 | def __getattr__(name: str): # noqa: ANN202 224 | """ 225 | When Stopwatch is run, classes will be created dynamically in order to meet the 226 | needs of the configured benchmark(s). Modal will then need to call these classes 227 | once the code is deployed. This function allows us to dynamically create these 228 | classes once Stopwatch has already been deployed. 
229 | """ 230 | 231 | if name in globals(): 232 | return globals()[name] 233 | 234 | if name.startswith("LLM_"): 235 | with Path(DB_PATH).joinpath("deployments", f"{name}.json").open("r") as f: 236 | server_config = json.load(f) 237 | 238 | return LLMServerClassFactory( 239 | name, 240 | server_config["model"], 241 | LLMServerType(server_config["llm_server_type"]), 242 | server_config["llm_server_config"], 243 | ) 244 | 245 | msg = f"No attribute {name}" 246 | raise AttributeError(msg) 247 | -------------------------------------------------------------------------------- /src/stopwatch/benchmark.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import itertools 3 | import logging 4 | from collections.abc import Mapping 5 | from datetime import datetime, timezone 6 | from typing import Any 7 | 8 | import modal 9 | 10 | from stopwatch.constants import GUIDELLM_VERSION, HOURS, SECONDS, RateType 11 | from stopwatch.resources import app, hf_secret, results_volume, startup_metrics_dict 12 | 13 | DELAY_BETWEEN_BENCHMARKS = 5 * SECONDS 14 | MEMORY = 1 * 1024 15 | NUM_CORES = 2 16 | RESULTS_PATH = "/results" 17 | SCALEDOWN_WINDOW = 5 * SECONDS 18 | TIMEOUT = 4 * HOURS 19 | 20 | logger = logging.getLogger(__name__) 21 | logging.basicConfig(level=logging.INFO) 22 | 23 | guidellm_image = ( 24 | modal.Image.debian_slim() 25 | .apt_install("git") 26 | .uv_pip_install( 27 | f"git+https://github.com/vllm-project/guidellm.git#{GUIDELLM_VERSION}", 28 | "tiktoken", 29 | ) 30 | .env( 31 | { 32 | "GUIDELLM__MAX_WORKER_PROCESSES": f"{NUM_CORES - 1}", 33 | }, 34 | ) 35 | ) 36 | 37 | with guidellm_image.imports(): 38 | from guidellm.benchmark.benchmarker import GenerativeBenchmarker 39 | from guidellm.benchmark.profile import create_profile 40 | from pydantic_core import ValidationError 41 | 42 | from stopwatch.utils import CustomGenerativeRequestLoader, CustomOpenAIHTTPBackend 43 | 44 | 45 | @app.cls( 46 | image=guidellm_image, 47 | secrets=[hf_secret], 48 | volumes={RESULTS_PATH: results_volume}, 49 | cpu=NUM_CORES, 50 | memory=MEMORY, 51 | scaledown_window=SCALEDOWN_WINDOW, 52 | timeout=TIMEOUT, 53 | ) 54 | class GuideLLM: 55 | """Run benchmarks with GuideLLM.""" 56 | 57 | @modal.method() 58 | async def run_benchmark( # noqa: C901, PLR0912 59 | self, 60 | endpoint: str, 61 | model: str, 62 | rate_type: RateType | list[RateType], 63 | data: str, 64 | duration: float | None = 120, # 2 minutes 65 | client_config: Mapping[str, Any] | None = None, 66 | rate: float | list[float] | None = None, 67 | server_id: str | None = None, 68 | ) -> list[dict[str, Any]]: 69 | """ 70 | Benchmarks a LLM deployment on Modal. 71 | 72 | :param: model: Name of the model to benchmark. 73 | :param: rate_type: The type of rate to use for benchmarking. If this is a list, 74 | benchmarks will be run sequentially at each rate type. 75 | :param: data: A configuration for emulated data (e.g.: 76 | 'prompt_tokens=128,output_tokens=128'). 77 | :param: duration: The duration of the benchmark in seconds. 78 | :param: client_config: Configuration for the GuideLLM client. 79 | :param: rate: If rate_type is RateType.constant, specify the number of requests 80 | that should be made per second. If this is a list, benchmarks will be run 81 | sequentially at each request rate. 82 | :param: server_id: The ID of the server being benchmarked. Useful for tracking 83 | cold start durations. 
84 | """ 85 | 86 | if client_config is None: 87 | client_config = {} 88 | 89 | extra_query = client_config.get("extra_query", {}) 90 | 91 | # Convert rate_type to a list 92 | if not isinstance(rate_type, list): 93 | rate_type = [rate_type] 94 | 95 | # Convert RateTypes to strings 96 | for i in range(len(rate_type)): 97 | if isinstance(rate_type[i], RateType): 98 | rate_type[i] = rate_type[i].value 99 | 100 | # Convert rate to a list 101 | if not isinstance(rate, list): 102 | rate = [rate] 103 | 104 | if len(rate_type) > 1 and len(rate) > 1: 105 | msg = ( 106 | f"All benchmarks must have either the same rate type or the same rate: " 107 | f"{rate_type} vs. {rate}" 108 | ) 109 | raise ValueError(msg) 110 | 111 | # Create the request loader before starting the LLM server, since this can take 112 | # a long time for data configs with many prompt tokens. 113 | processor = client_config.get("tokenizer", model) 114 | request_loader = CustomGenerativeRequestLoader( 115 | data=data, 116 | data_args=None, 117 | processor=processor, 118 | processor_args=None, 119 | shuffle=False, 120 | iter_type="infinite", 121 | random_seed=42, 122 | extra_body=client_config.get("extra_body", {}), 123 | use_chat_completions=client_config.get("use_chat_completions", False), 124 | ) 125 | unique_requests = request_loader.num_unique_items(raise_err=False) 126 | logger.info( 127 | ( 128 | f"Created loader with {unique_requests} unique requests from {data}" 129 | if unique_requests > 0 130 | else f"Created loader with unknown number unique requests from {data}" 131 | ), 132 | ) 133 | 134 | benchmark_results = [] 135 | 136 | # Create backend 137 | backend = CustomOpenAIHTTPBackend( 138 | target=endpoint, 139 | extra_query=extra_query, 140 | remove_from_body=[ 141 | # TensorRT-LLM returns a 400 error if max_completion_tokens is set 142 | "max_completion_tokens", 143 | *client_config.get("remove_from_body", []), 144 | ], 145 | ) 146 | 147 | # Start the LLM server and save queuing and cold start durations 148 | queue_time = datetime.now(timezone.utc).timestamp() 149 | await backend.validate() 150 | connection_time = datetime.now(timezone.utc).timestamp() 151 | queue_duration = connection_time - queue_time 152 | 153 | if ( 154 | server_id is not None 155 | and (container_start_time := startup_metrics_dict.get(server_id)) 156 | is not None 157 | ): 158 | cold_start_duration = connection_time - container_start_time 159 | queue_duration -= cold_start_duration 160 | else: 161 | cold_start_duration = None 162 | 163 | logger.info("Connected to backend for model %s", backend.model) 164 | 165 | rates_to_run = list(itertools.product(rate_type, rate)) 166 | logger.info("Running %d benchmarks", len(rates_to_run)) 167 | 168 | for i, (rate_type_i, rate_i) in enumerate(rates_to_run): 169 | logger.info( 170 | "Starting benchmark with rate type %s and rate %s", 171 | rate_type_i, 172 | rate_i, 173 | ) 174 | 175 | profile = create_profile(rate_type=rate_type_i, rate=rate_i) 176 | benchmarker = GenerativeBenchmarker( 177 | backend, 178 | request_loader, 179 | request_loader.description, 180 | processor=processor, 181 | ) 182 | 183 | try: 184 | async for result in benchmarker.run( 185 | profile=profile, 186 | max_number_per_strategy=None, 187 | max_duration_per_strategy=duration, 188 | warmup_percent_per_strategy=None, 189 | cooldown_percent_per_strategy=None, 190 | ): 191 | if result.type_ == "benchmark_compiled": 192 | if result.current_benchmark is None: 193 | logger.exception( 194 | "Error running benchmark: Current benchmark is None", 195 
| ) 196 | continue 197 | 198 | benchmark_results.append( 199 | { 200 | **result.current_benchmark.model_dump(), 201 | "rate_type": rate_type_i, 202 | "rate": rate_i, 203 | "queue_duration": queue_duration, 204 | "cold_start_duration": cold_start_duration, 205 | }, 206 | ) 207 | except (ValueError, ValidationError): 208 | logger.exception( 209 | "Error running benchmark: No requests completed successfully", 210 | ) 211 | continue 212 | 213 | if i != len(rates_to_run) - 1: 214 | await asyncio.sleep(DELAY_BETWEEN_BENCHMARKS) 215 | 216 | return benchmark_results 217 | --------------------------------------------------------------------------------
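The following is an illustrative sketch of how the GuideLLM class above might be invoked once the Stopwatch app is deployed. The app name, endpoint URL, and model name are assumptions for the example, not values taken from this repository.

import modal

from stopwatch.constants import RateType

# Look up the deployed benchmarking class by app and class name (app name assumed).
GuideLLM = modal.Cls.from_name("stopwatch", "GuideLLM")

# Run two constant-rate benchmarks (1 and 5 requests per second) against a
# hypothetical OpenAI-compatible endpoint, for two minutes each.
results = GuideLLM().run_benchmark.remote(
    endpoint="https://example--llm-server.modal.run/v1",  # placeholder URL
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model name
    rate_type=RateType.constant,
    rate=[1.0, 5.0],
    data="prompt_tokens=128,output_tokens=128",
    duration=120,
)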