├── .dockerignore ├── .gitattributes ├── .github └── workflows │ └── deploy.yml ├── .gitignore ├── Justfile ├── LICENSE ├── Procfile ├── README.md ├── app.py ├── dalle.py ├── dalle_suite.sh ├── elevenlabs_api_benchmark.py ├── elevenlabs_stream_benchmark.py ├── elevenlabs_ws_benchmark.py ├── fly.toml ├── llm_benchmark.py ├── llm_benchmark_suite.py ├── llm_request.py ├── media ├── audio │ ├── boolq.wav │ ├── news.wav │ └── say_cheese.wav ├── image │ ├── great_wave.png │ └── inception.jpeg ├── text │ └── llama31.md ├── tools │ └── flights.json └── video │ └── psa.webm ├── openai_finetune.py ├── playht_benchmark.py ├── poetry.lock └── pyproject.toml /.dockerignore: -------------------------------------------------------------------------------- 1 | fly.toml 2 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | great_wave.png filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.github/workflows/deploy.yml: -------------------------------------------------------------------------------- 1 | name: Deploy to Fly.io 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | deploy: 10 | name: Deploy benchmarks to Fly.io 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v3 14 | - uses: superfly/flyctl-actions/setup-flyctl@master 15 | - run: | 16 | echo '${{secrets.GCP_SERVICE_ACCOUNT}}' > service_account.json 17 | flyctl deploy --remote-only 18 | env: 19 | FLY_API_TOKEN: ${{ secrets.FLY_API_TOKEN }} 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | venv 3 | __pycache__/ 4 | service_account.json 5 | -------------------------------------------------------------------------------- /Justfile: -------------------------------------------------------------------------------- 1 | # This is the main Justfile for the Fixie repo. 2 | # To install Just, see: https://github.com/casey/just#installation 3 | 4 | # Allow for positional arguments in Just recipes. 5 | set positional-arguments := true 6 | 7 | # Default recipe that runs if you type "just". 8 | default: format check test 9 | 10 | # Install dependencies for local development. 11 | install: 12 | pip install poetry==1.7.1 13 | poetry install --sync 14 | 15 | format: 16 | poetry run autoflake . --remove-all-unused-imports --quiet --in-place -r --exclude third_party 17 | poetry run isort . --force-single-line-imports 18 | poetry run black . 19 | 20 | check: 21 | poetry run black . --check 22 | poetry run isort . --check --force-single-line-imports 23 | poetry run autoflake . --check --quiet --remove-all-unused-imports -r --exclude third_party 24 | poetry run mypy . 
25 | 26 | deploy *FLAGS: 27 | flyctl deploy {{FLAGS}} 28 | 29 | server: 30 | just python app.py 31 | 32 | curl *FLAGS: 33 | curl -X POST "https://ai-benchmarks.fly.dev/bench?max_tokens=20&{{FLAGS}}" -H fly-prefer-region:sea 34 | 35 | curl_local *FLAGS: 36 | curl -X POST "http://localhost:8000/bench?max_tokens=20&{{FLAGS}}" 37 | 38 | test *FLAGS: 39 | poetry run pytest {{FLAGS}} 40 | 41 | python *FLAGS: 42 | poetry run python {{FLAGS}} 43 | 44 | llm *FLAGS: 45 | poetry run python llm_benchmark.py {{FLAGS}} 46 | 47 | llms *FLAGS: 48 | poetry run python llm_benchmark_suite.py {{FLAGS}} 49 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Fixie.ai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Procfile: -------------------------------------------------------------------------------- 1 | web: gunicorn -k uvicorn.workers.UvicornWorker app:app 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ai-benchmarks 2 | 3 | This repo contains a handful of utilities for benchmarking the response latency of popular AI services, including: 4 | 5 | Large Language Models (LLMs): 6 | 7 | - OpenAI GPT-3.5, GPT-4 (from OpenAI or Azure OpenAI service) 8 | - Anthropic Claude 3, Claude 2, Claude Instant 9 | - Google Gemini Pro and PaLM 2 Bison 10 | - Llama2 and 3 from several different providers, including 11 | - Anyscale 12 | - Azure 13 | - Cerebras 14 | - Cloudflare 15 | - Groq 16 | - OctoAI 17 | - OVHcloud 18 | - Perplexity 19 | - Together 20 | - Mixtral 8x7B from several different providers, including 21 | - Anyscale 22 | - Azure 23 | - Groq 24 | - OctoAI 25 | - OVHcloud 26 | - Perplexity 27 | 28 | Embedding Models: 29 | 30 | - Ada-002 31 | - Cohere 32 | 33 | Text-to-Speech Models (TTS): 34 | 35 | - ElevenLabs 36 | - PlayHT 37 | 38 | ## Leaderboard 39 | 40 | See [thefastest.ai](https://thefastest.ai) for the current leaderboard. 41 | 42 | ### Test methodology 43 | 44 | - Tests are run from a set of distributed benchmark runners. 45 | - Input requests are relatively brief, typically about 1000 tokens, and ask for a brief output response. 
46 | - Max output tokens is set to 20, to avoid distortion of TPS values from long outputs. 47 | - A warmup connection is made to remove any connection setup latency. 48 | - The TTFT clock starts when the HTTP request is made and stops when the first token result is received in the response stream. 49 | - For each provider, three separate inferences are done, and the best result is kept (to remove any outliers due to queuing, etc.). 50 | 51 | ## Initial setup 52 | 53 | This repo uses [Poetry](https://python-poetry.org/) for dependency management. To install the dependencies, run: 54 | 55 | ``` 56 | pip install poetry 57 | poetry install --sync 58 | ``` 59 | 60 | ## Running benchmarks 61 | 62 | To run a benchmark, first set the appropriate environment variable (e.g., OPENAI_API_KEY, ELEVEN_API_KEY, etc.), and then run 63 | the appropriate benchmark script. 64 | 65 | ### LLM benchmarks 66 | 67 | To generate LLM benchmarks, use the `llm_benchmark.py` script. For most providers, you can just pass the model name and the script will figure out which API endpoint to invoke, e.g., 68 | 69 | ``` 70 | poetry run python llm_benchmark.py -m gpt-3.5-turbo "Write me a haiku." 71 | ``` 72 | 73 | However, when invoking generic models like Llama2, you'll need to pass in the base URL and API key via the -b and -k parameters, e.g., 74 | 75 | ``` 76 | poetry run python llm_benchmark.py -k $OCTOML_API_KEY -b https://text.octoai.run/v1 \ 77 | -m llama-2-70b-chat-fp16 "Write me a haiku." 78 | ``` 79 | 80 | Similarly, when invoking Azure OpenAI, you'll need to specify your Azure API key and the base URL of your Azure deployment, e.g., 81 | 82 | ``` 83 | poetry run python llm_benchmark.py -b https://fixie-westus.openai.azure.com \ 84 | -m gpt-4-1106-preview "Write me a haiku." 85 | ``` 86 | 87 | See [this script](https://github.com/fixie-ai/ai-benchmarks/blob/main/llm_benchmark_suite.py) for more examples of how to invoke various providers.
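For longer or multimodal prompts, `llm_benchmark.py` can also read the prompt text from a file (prefix the path with `@`) and attach media files with `-f`; see the prompt and `--file` handling in `llm_benchmark.py`. A couple of illustrative invocations using the samples bundled under `media/` (the model choices here are just examples):

```
poetry run python llm_benchmark.py -m gpt-4o @media/text/llama31.md
poetry run python llm_benchmark.py -m gpt-4o -f media/image/inception.jpeg "Describe this image."
```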
88 | 89 | #### Options 90 | 91 | ``` 92 | usage: llm_benchmark.py [-h] [--model MODEL] [--temperature TEMPERATURE] [--max-tokens MAX_TOKENS] [--base-url BASE_URL] 93 | [--api-key API_KEY] [--no-warmup] [--num-requests NUM_REQUESTS] [--print] [--verbose] 94 | [prompt] 95 | 96 | positional arguments: 97 | prompt Prompt to send to the API 98 | 99 | optional arguments: 100 | -h, --help show this help message and exit 101 | --model MODEL, -m MODEL Model to benchmark 102 | --temperature TEMPERATURE, -t TEMPERATURE Temperature for the response 103 | --max-tokens MAX_TOKENS, -T MAX_TOKENS Max tokens for the response 104 | --base-url BASE_URL, -b BASE_URL Base URL for the LLM API endpoint 105 | --api-key API_KEY, -k API_KEY API key for the LLM API endpoint 106 | --no-warmup Don't do a warmup call to the API 107 | --num-requests NUM_REQUESTS, -n NUM_REQUESTS Number of requests to make 108 | --print, -p Print the response 109 | --verbose, -v Print verbose output 110 | ``` 111 | 112 | #### Output 113 | 114 | By default, a summary of the requests is printed: 115 | 116 | ``` 117 | Latency saved: 0.01 seconds <---- Difference between first response time and fastest response time 118 | Optimized response time: 0.14 seconds <---- fastest(http_response_time - http_start_time) of N requests 119 | Median response time: 0.15 seconds <---- median(http_response_time - http_start_time) of N requests 120 | Time to first token: 0.34 seconds <---- first_token_time - http_start_time 121 | Tokens: 147 (211 tokens/sec) <---- num_generated_tokens / (last_token_time - first_token_time) 122 | Total time: 1.03 seconds <---- last_token_time - http_start_time 123 | ``` 124 | 125 | You can specify -p to print the output of the LLM, or -v to see detailed timing for each request. 126 | 127 | ### TTS benchmarks 128 | 129 | To generate TTS benchmarks, use one of the provider-specific scripts, e.g., 130 | 131 | ``` 132 | python elevenlabs_stream_benchmark.py "Haikus I find tricky, With a 5-7-5 count, But I'll give it a go" 133 | ``` 134 | 135 | #### Playing audio 136 | 137 | By default, only timing information for TTS is emitted. Follow the steps below to play back the received audio. 138 | 139 | First, install `mpv` via 140 | 141 | ``` 142 | brew install mpv 143 | ``` 144 | 145 | Then, just pass the -p argument when generating text, e.g., 146 | 147 | ``` 148 | python playht_benchmark.py -p "Well, basically I have intuition." 149 | ``` 150 | 151 | You can use the -v parameter to select which voice to use for generation.
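### Benchmark suite

The full set of providers and models is defined in `llm_benchmark_suite.py`, and the Justfile provides shortcuts for running it. A couple of illustrative invocations, based on the flags defined in that script's argument parser (set the relevant provider API keys in your environment first):

```
just llms --mode text --filter llama
poetry run python llm_benchmark_suite.py --mode tools --format json
```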
152 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import fastapi 2 | 3 | import llm_benchmark_suite 4 | 5 | app = fastapi.FastAPI() 6 | 7 | 8 | @app.get("/") 9 | async def root(): 10 | return fastapi.Response( 11 | status_code=302, headers={"location": "https://thefastest.ai"} 12 | ) 13 | 14 | 15 | @app.route("/bench", methods=["GET", "POST"]) 16 | async def bench(req: fastapi.Request): 17 | text, content_type = await llm_benchmark_suite.run(req.query_params) 18 | return fastapi.Response(content=text, media_type=content_type) 19 | 20 | 21 | if __name__ == "__main__": 22 | import uvicorn 23 | 24 | uvicorn.run(app, host="0.0.0.0", port=8000) # Run the app with uvicorn on port 8000 25 | -------------------------------------------------------------------------------- /dalle.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import asyncio 4 | import base64 5 | import dataclasses 6 | import json 7 | import os 8 | import time 9 | import urllib 10 | from typing import Generator, Optional 11 | 12 | import aiohttp 13 | 14 | DEFAULT_PROMPT = "A pixel art version of the Mona Lisa." 15 | API_VERSION = "2023-12-01-preview" 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument( 19 | "prompt", 20 | type=str, 21 | nargs="?", 22 | default=DEFAULT_PROMPT, 23 | help="Prompt to send to the API", 24 | ) 25 | parser.add_argument( 26 | "--model", 27 | "-m", 28 | type=str, 29 | default="dall-e-3", 30 | help="Model to use", 31 | ) 32 | parser.add_argument( 33 | "--num-images", 34 | "-n", 35 | type=int, 36 | default=1, 37 | help="Number of images to generate", 38 | ) 39 | parser.add_argument( 40 | "--image-size", 41 | "-s", 42 | type=int, 43 | default=1024, 44 | help="Size of image to generate", 45 | ) 46 | parser.add_argument( 47 | "--base-url", 48 | "-b", 49 | type=str, 50 | ) 51 | parser.add_argument( 52 | "--api-key", 53 | "-k", 54 | type=str, 55 | ) 56 | parser.add_argument( 57 | "--play", 58 | "-p", 59 | action="store_true", 60 | help="Display the image after generation", 61 | ) 62 | parser.add_argument( 63 | "--minimal", 64 | action="store_true", 65 | dest="minimal", 66 | help="Print minimal output", 67 | ) 68 | args = parser.parse_args() 69 | 70 | 71 | @dataclasses.dataclass 72 | class ApiContext: 73 | session: aiohttp.ClientSession 74 | index: int 75 | model: str 76 | prompt: str 77 | 78 | 79 | @dataclasses.dataclass 80 | class ApiResult: 81 | def __init__(self, index, start_time, response): 82 | self.index = index 83 | self.start_time = start_time 84 | self.latency = time.time() - start_time 85 | self.response = response 86 | 87 | index: int 88 | start_time: int 89 | latency: float # HTTP response time 90 | response: aiohttp.ClientResponse 91 | chunk_gen: Generator[str, None, None] 92 | 93 | 94 | async def post(context: ApiContext, url: str, headers: dict, data: dict): 95 | start_time = time.time() 96 | response = await context.session.post(url, headers=headers, data=json.dumps(data)) 97 | return ApiResult(context.index, start_time, response) 98 | 99 | 100 | def get_api_key(env_var: str) -> str: 101 | if args.api_key: 102 | return args.api_key 103 | if env_var in os.environ: 104 | return os.environ[env_var] 105 | raise ValueError(f"Missing API key: {env_var}") 106 | 107 | 108 | def make_headers(auth_token: Optional[str] = None, x_api_key: Optional[str] = None): 109 | headers = { 
110 | "content-type": "application/json", 111 | } 112 | if auth_token: 113 | headers["authorization"] = f"Bearer {auth_token}" 114 | if x_api_key: 115 | headers["x-api-key"] = x_api_key 116 | return headers 117 | 118 | 119 | def make_openai_url_and_headers(model: str, path: str): 120 | url = args.base_url or "https://api.openai.com/v1" 121 | hostname = urllib.parse.urlparse(url).hostname 122 | use_azure = hostname and hostname.endswith(".azure.com") 123 | headers = { 124 | "Content-Type": "application/json", 125 | } 126 | if use_azure: 127 | api_key = get_api_key("AZURE_OPENAI_API_KEY") 128 | headers["Api-Key"] = api_key 129 | url += f"/openai/deployments/{model.replace('.', '')}{path}?api-version={API_VERSION}" 130 | else: 131 | api_key = get_api_key("OPENAI_API_KEY") 132 | headers["Authorization"] = f"Bearer {api_key}" 133 | url += path 134 | return url, headers 135 | 136 | 137 | async def dalle_image(context: ApiContext) -> ApiResult: 138 | url, headers = make_openai_url_and_headers(context.model, "/images/generations") 139 | data = { 140 | "model": context.model, 141 | "prompt": context.prompt, 142 | "n": args.num_images, 143 | "size": f"{args.image_size}x{args.image_size}", 144 | "response_format": "b64_json", 145 | } 146 | return await post(context, url, headers, data) 147 | 148 | 149 | async def async_main(): 150 | async with aiohttp.ClientSession() as session: 151 | fq_model = ( 152 | args.model if not args.base_url else f"{args.base_url[8:]}/{args.model}" 153 | ) 154 | if not args.minimal: 155 | print(f"Invoking {fq_model}...") 156 | result = await dalle_image(ApiContext(session, 0, args.model, args.prompt)) 157 | if not result.response.ok: 158 | print(f"Error: {result.response.status} {result.response.reason}") 159 | return 160 | 161 | data = await result.response.json() 162 | end_time = time.time() 163 | 164 | latency = result.latency 165 | total_time = end_time - result.start_time 166 | if not args.minimal: 167 | print(f"Response time: {latency:.2f} seconds") 168 | print(f"Total time: {total_time:.2f} seconds") 169 | else: 170 | print(f"{fq_model:48} | {latency:5.2f} | {total_time:5.2f}") 171 | if args.play: 172 | with open("image.png", "wb") as f: 173 | b64 = data["data"][0]["b64_json"] 174 | f.write(base64.b64decode(b64)) 175 | os.system("open image.png") 176 | 177 | 178 | asyncio.run(async_main()) 179 | -------------------------------------------------------------------------------- /dalle_suite.sh: -------------------------------------------------------------------------------- 1 | echo "Provider/Model | TTR | Total" 2 | python dalle.py --minimal "$@" 3 | python dalle.py --minimal -k $AZURE_SECENTRAL_OPENAI_API_KEY -b https://fixie-secentral.openai.azure.com "$@" 4 | python dalle.py --minimal -k $AZURE_EASTUS_OPENAI_API_KEY -b https://fixie-eastus.openai.azure.com "$@" 5 | -------------------------------------------------------------------------------- /elevenlabs_api_benchmark.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import json 3 | import time 4 | import os 5 | import argparse 6 | import asyncio 7 | import websockets 8 | import base64 9 | import logging 10 | from typing import Iterator 11 | 12 | logging.basicConfig(level=logging.INFO) 13 | 14 | # Defaults for both scripts 15 | DEFAULT_SAMPLES = 10 16 | DEFAULT_TEXT = "Hello World!" 
17 | DEFAULT_MODEL_ID = "eleven_monolingual_v1" 18 | DEFAULT_LATENCY_OPTIMIZER = 4 19 | DEFAULT_VOICE_ID = "pNInz6obpgDQGcFmaJgB" 20 | DEFAULT_OUTPUT_FORMAT = "mp3_44100" 21 | DEFAULT_STABILITY = 0.5 22 | DEFAULT_SIMILARITY_BOOST = False 23 | DEFAULT_XI_API_KEY = os.environ["ELEVEN_API_KEY"] 24 | 25 | # Configuration for HTTP API 26 | DEFAULT_CHUNK_SIZE = 7868 27 | 28 | # Configuration for WebSocket API 29 | chunk_length_schedule = [50] 30 | max_length = 10 # Maximum length for audio string truncation 31 | delay_time = 0.0001 # Use this to simulate the token output speed of your LLM 32 | try_trigger_generation = True 33 | 34 | 35 | # Argument parsing 36 | parser = argparse.ArgumentParser( 37 | formatter_class=argparse.RawDescriptionHelpFormatter, 38 | description='''\ 39 | The script allows for comprehensive benchmarking of the 11Labs API for text-to-speech generation to achieve the lowest possible latency, given any combination of parameters. 40 | ''') 41 | 42 | API_group = parser.add_argument_group('API Type') 43 | API_group.add_argument("--API", choices=["http", "websocket"], required=True, 44 | help="API type: 'http' or 'websocket'") 45 | 46 | input_group = parser.add_argument_group('Input Parameters') 47 | input_group.add_argument("--text", default=DEFAULT_TEXT, 48 | help="Input text for speech synthesis") 49 | input_group.add_argument("--model", default=DEFAULT_MODEL_ID, 50 | help="Model ID for speech synthesis. Options: 'eleven_monolingual_v1', 'eleven_english_v2', 'eleven_multilingual_v1', 'eleven_multilingual_v2'") 51 | 52 | output_group = parser.add_argument_group('Output Parameters') 53 | output_group.add_argument("--num_samples", type=int, default=DEFAULT_SAMPLES, 54 | help="Number of speech samples to generate") 55 | output_group.add_argument("--output_format", default=DEFAULT_OUTPUT_FORMAT, 56 | help="Speech output format. Options: 'mp3_44100', 'pcm_16000', 'pcm_22050', 'pcm_24000', 'pcm_44100'") 57 | 58 | http_group = parser.add_argument_group('HTTP API Parameters') 59 | http_group.add_argument("--chunk_size", type=int, default=DEFAULT_CHUNK_SIZE, 60 | help="Size of the first playable chunk in bytes, default is 7868") 61 | 62 | websocket_group = parser.add_argument_group('WebSocket API Parameters') 63 | websocket_group.add_argument("--latency_optimizer", type=int, default=DEFAULT_LATENCY_OPTIMIZER, 64 | help="Latency optimization level. Default is 4. Lower to 3 or less to improve pronunciation of numbers and dates.") 65 | websocket_group.add_argument("--text_chunker", action="store_true", default=False, 66 | help="Enable text chunker for input streaming. This chunks text blocks and sets last char to space, simulating the default behavior of the 11labs Library.") 67 | 68 | general_group = parser.add_argument_group('General Parameters') 69 | general_group.add_argument("--voice_id", default=DEFAULT_VOICE_ID, 70 | help="ID of the voice for speech synthesis") 71 | 72 | args = parser.parse_args() 73 | 74 | 75 | 76 | # Text chunker function 77 | def text_chunker(text: str) -> Iterator[str]: 78 | """ 79 | Used during input streaming to chunk text blocks and set last char to space. 80 | Use this function to simulate the default behavior of the official 11labs Library.
81 | """ 82 | splitters = (".", ",", "?", "!", ";", ":", "—", "-", "(", ")", "[", "]", "}", " ") 83 | buffer = "" 84 | for i, char in enumerate(text): 85 | buffer += char 86 | if i < len(text) - 1: # Check if this is not the last character 87 | next_char = text[i + 1] 88 | if buffer.endswith(splitters) and next_char == " ": 89 | yield buffer if buffer.endswith(" ") else buffer + " " 90 | buffer = "" 91 | if buffer != "": 92 | yield buffer + " " 93 | 94 | # Simulate text stream function 95 | def simulate_text_stream(): 96 | """ 97 | When use_text_chunker is True, use a single text chunk here to process via the text_chunker function from elevenlabs library. 98 | When use_text_chunker is False, you can simulate chunks of text from an LLM by adding more lines like this, in the above text_chunks list: 99 | text_chunks = [ 100 | "Hello ", 101 | "World, ", 102 | "this ", 103 | "is ", 104 | "a ", 105 | "voice ", 106 | "sample! ", 107 | ] 108 | """ 109 | text_chunks = ["Hello world! This is a sample of a streaming voice. "] 110 | for text_chunk in text_chunks: 111 | time.sleep(delay_time) 112 | yield text_chunk 113 | 114 | # Truncate audio string function 115 | def truncate_audio_string(audio_string): 116 | """ 117 | Truncate audio string if it exceeds the max_length 118 | """ 119 | if len(audio_string) > max_length: 120 | return audio_string[:max_length] + "..." 121 | return audio_string 122 | 123 | # HTTP API request function 124 | def http_api_request(): 125 | url = f"https://api.elevenlabs.io/v1/text-to-speech/{args.voice_id}/stream?optimize_streaming_latency={args.latency_optimizer}" 126 | headers = { 127 | "accept": "audio/mpeg", 128 | "xi-api-key": DEFAULT_XI_API_KEY, 129 | "Content-Type": "application/json", 130 | } 131 | data = { 132 | "text": args.text, 133 | "model_id": args.model, 134 | "voice_settings": {"stability": DEFAULT_STABILITY, "similarity_boost": DEFAULT_SIMILARITY_BOOST}, 135 | } 136 | response_latencies = [] 137 | chunk_latencies = [] 138 | for i in range(args.num_samples): 139 | print(f"\nAPI Call {i+1}:") 140 | start_time = time.perf_counter() 141 | response = requests.post(url, headers=headers, data=json.dumps(data), stream=True) 142 | if not response.ok: 143 | print("Error: " + response.json()["detail"]["message"]) 144 | exit(1) 145 | response_received_time = time.perf_counter() 146 | response_latency = (response_received_time - start_time) * 1000 147 | response_latencies.append(response_latency) 148 | print(f" Initial Response (Header) Time: {response_latency:.2f} ms") 149 | audio_data = b"" 150 | for chunk in response.iter_content(chunk_size=DEFAULT_CHUNK_SIZE): 151 | if chunk: 152 | audio_data += chunk 153 | if len(audio_data) >= args.chunk_size: 154 | chunk_received_time = time.perf_counter() 155 | chunk_latency = (chunk_received_time - start_time) * 1000 156 | chunk_latencies.append(chunk_latency) 157 | 158 | print(f" First Playable Chunk (Body) Time: {chunk_latency:.2f} ms") 159 | break 160 | 161 | average_response_latency = sum(response_latencies) / len(response_latencies) 162 | median_response_latency = sorted(response_latencies)[len(response_latencies) // 2] 163 | average_chunk_latency = sum(chunk_latencies) / len(chunk_latencies) 164 | median_chunk_latency = sorted(chunk_latencies)[len(chunk_latencies) // 2] 165 | return average_response_latency, median_response_latency, average_chunk_latency, median_chunk_latency 166 | 167 | async def websocket_api_request(): 168 | logging.basicConfig(level=logging.INFO) # Configure logging inside the function 169 | uri = 
f"wss://api.elevenlabs.io/v1/text-to-speech/{args.voice_id}/stream-input?model_type={args.model}&optimize_streaming_latency={args.latency_optimizer}&output_format={args.output_format}" 170 | start_time = time.time() # Record the time before the request is made 171 | chunk_times = [] 172 | first_chunk_received = False 173 | first_chunk_time = None 174 | async with websockets.connect(uri) as websocket: 175 | connection_open_time = time.time() 176 | time_to_open_connection = connection_open_time - start_time 177 | bos_message = { 178 | "text": " ", 179 | "voice_settings": { 180 | "stability": DEFAULT_STABILITY, 181 | "similarity_boost": DEFAULT_SIMILARITY_BOOST, 182 | }, 183 | "generation_config": {"chunk_length_schedule": chunk_length_schedule}, 184 | "xi_api_key": DEFAULT_XI_API_KEY, 185 | "try_trigger_generation": try_trigger_generation, 186 | } 187 | await websocket.send(json.dumps(bos_message)) 188 | for text_chunk in simulate_text_stream(): 189 | if args.text_chunker: 190 | for chunk in text_chunker(text_chunk): 191 | input_message = { 192 | "text": chunk, 193 | "try_trigger_generation": try_trigger_generation, 194 | } 195 | await websocket.send(json.dumps(input_message)) 196 | else: 197 | input_message = { 198 | "text": text_chunk, 199 | "try_trigger_generation": try_trigger_generation, 200 | } 201 | await websocket.send(json.dumps(input_message)) 202 | try: 203 | response = await asyncio.wait_for(websocket.recv(), timeout=delay_time) 204 | response_received_time = time.time() 205 | data = json.loads(response) 206 | if "audio" in data: 207 | chunk = base64.b64decode(data["audio"]) 208 | truncated_audio_string = truncate_audio_string(data["audio"]) 209 | logging.info(f"Truncated audio string: {truncated_audio_string}") 210 | chunk_received_time = time.time() 211 | if not first_chunk_received: 212 | first_chunk_received = True 213 | first_chunk_time = chunk_received_time - start_time # Calculate the time from the request to the first chunk 214 | chunk_times.append(chunk_received_time - connection_open_time) 215 | except asyncio.TimeoutError: 216 | pass 217 | eos_message = {"text": ""} 218 | await websocket.send(json.dumps(eos_message)) 219 | while True: 220 | try: 221 | response = await websocket.recv() 222 | response_received_time = time.time() 223 | data = json.loads(response) 224 | audio = data.get("audio") 225 | if audio is not None: 226 | chunk = base64.b64decode(data["audio"]) 227 | truncated_audio_string = truncate_audio_string(data["audio"]) 228 | logging.info(f"Truncated audio string: {truncated_audio_string}") 229 | chunk_received_time = time.time() 230 | chunk_times.append(chunk_received_time - connection_open_time) 231 | else: 232 | break 233 | except websockets.exceptions.ConnectionClosed: 234 | break 235 | connection_close_time = time.time() 236 | total_time_websocket_was_open = connection_close_time - connection_open_time 237 | return time_to_open_connection, first_chunk_time, chunk_times, total_time_websocket_was_open 238 | 239 | # Main function 240 | if args.API == "http": 241 | average_response_latency, median_response_latency, average_chunk_latency, median_chunk_latency = http_api_request() 242 | print(f"\nAverage Initial Response (Header) Time: {average_response_latency:.2f} ms") 243 | print(f"Median Initial Response (Header) Time: {median_response_latency:.2f} ms") 244 | print(f"Average First Playable Chunk (Body) Time: {average_chunk_latency:.2f} ms") 245 | print(f"Median First Playable Chunk (Body) Time: {median_chunk_latency:.2f} ms") 246 | elif args.API == 
"websocket": 247 | time_to_open_connection, first_chunk_time, chunk_times, total_time_websocket_was_open = asyncio.run(websocket_api_request()) 248 | print(f"\nTime to open connection: {time_to_open_connection:.4f} seconds") 249 | if first_chunk_time is not None: 250 | print(f"Time from request to first chunk: {first_chunk_time:.4f} seconds") # Updated print statement 251 | for i, chunk_time in enumerate(chunk_times, start=1): 252 | print(f"Time to receive chunk {i} after request: {chunk_time:.4f} seconds") # Updated print statement 253 | print(f"Total time WebSocket connection was open: {total_time_websocket_was_open:.4f} seconds") 254 | -------------------------------------------------------------------------------- /elevenlabs_stream_benchmark.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import json 3 | import time 4 | import os 5 | import argparse 6 | 7 | DEFAULT_SAMPLES = 10 8 | DEFAULT_TEXT = "I'm calling for Jim." 9 | DEFAULT_MODEL_ID = "eleven_monolingual_v1" 10 | DEFAULT_CHUNK_SIZE = 7868 #This defines the size of the first playable chunk in bytes, which is 7868, roughly equivalent to half a second of audio 11 | DEFAULT_LATENCY_OPTIMIZER = 4 # This can be set to values 1 through 4, with 4 disabling the text normalizer 12 | DEFAULT_VOICE_ID = "flq6f7yk4E4fJM5XTYuZ" 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("text", nargs="?", default=DEFAULT_TEXT) 16 | parser.add_argument( 17 | "--model", 18 | "-m", 19 | default=DEFAULT_MODEL_ID, 20 | ) 21 | parser.add_argument( 22 | "--num_samples", 23 | "-n", 24 | type=int, 25 | default=DEFAULT_SAMPLES, 26 | ) 27 | parser.add_argument( 28 | "--chunk_size", 29 | "-c", 30 | type=int, 31 | default=DEFAULT_CHUNK_SIZE, 32 | ) 33 | parser.add_argument( 34 | "--latency_optimizer", 35 | "-l", 36 | type=int, 37 | default=DEFAULT_LATENCY_OPTIMIZER, 38 | ) 39 | parser.add_argument( 40 | "--voice_id", 41 | "-v", 42 | default=DEFAULT_VOICE_ID, 43 | ) 44 | args = parser.parse_args() 45 | 46 | url = f"https://api.elevenlabs.io/v1/text-to-speech/{args.voice_id}/stream?optimize_streaming_latency={args.latency_optimizer}" 47 | 48 | headers = { 49 | "accept": "audio/mpeg", 50 | "xi-api-key": os.environ["ELEVEN_API_KEY"], 51 | "Content-Type": "application/json", 52 | } 53 | 54 | data = { 55 | "text": args.text, 56 | "model_id": args.model, 57 | "voice_settings": {"stability": 0.5, "similarity_boost": 1}, 58 | } 59 | 60 | response_latencies = [] 61 | chunk_latencies = [] 62 | 63 | for i in range(args.num_samples): 64 | print(f"\nAPI Call {i+1}:") 65 | start_time = time.perf_counter() 66 | response = requests.post(url, headers=headers, data=json.dumps(data), stream=True) 67 | if not response.ok: 68 | print("Error: " + response.json()["detail"]["message"]) 69 | exit(1) 70 | 71 | response_received_time = time.perf_counter() 72 | response_latency = (response_received_time - start_time) * 1000 73 | response_latencies.append(response_latency) 74 | print(f" Initial Response (Header) Time: {response_latency:.2f} ms") 75 | 76 | audio_data = b"" 77 | for chunk in response.iter_content(chunk_size=1024): 78 | if chunk: 79 | audio_data += chunk 80 | if len(audio_data) >= args.chunk_size: 81 | chunk_received_time = time.perf_counter() 82 | chunk_latency = (chunk_received_time - start_time) * 1000 83 | chunk_latencies.append(chunk_latency) 84 | print(f" First Playable Chunk (Body) Time: {chunk_latency:.2f} ms") 85 | break 86 | 87 | with open(f'audio_sample_{i+1}.mp3', 'wb') as f: 88 | 
f.write(audio_data) 89 | 90 | average_response_latency = sum(response_latencies) / len(response_latencies) 91 | median_response_latency = sorted(response_latencies)[len(response_latencies) // 2] 92 | print(f"\nAverage Initial Response (Header) Time: {average_response_latency:.2f} ms") 93 | print(f"Median Initial Response (Header) Time: {median_response_latency:.2f} ms") 94 | 95 | average_chunk_latency = sum(chunk_latencies) / len(chunk_latencies) 96 | median_chunk_latency = sorted(chunk_latencies)[len(chunk_latencies) // 2] 97 | print(f"\nAverage First Playable Chunk (Body) Time: {average_chunk_latency:.2f} ms") 98 | print(f"Median First Playable Chunk (Body) Time: {median_chunk_latency:.2f} ms") 99 | 100 | -------------------------------------------------------------------------------- /elevenlabs_ws_benchmark.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import websockets 3 | import json 4 | import base64 5 | import time 6 | import logging 7 | from typing import Iterator 8 | import os 9 | import argparse 10 | 11 | # Read some settings from command line 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--voice", "-v") 14 | parser.add_argument("--model", "-m", default="eleven_monolingual_v1") 15 | parser.add_argument("--text-chunker", action="store_true", default=False) 16 | args = parser.parse_args() 17 | 18 | # Set up logging 19 | logging.basicConfig(level=logging.INFO) 20 | 21 | # Configuration section 22 | voice_id = args.voice 23 | # choices: eleven_monolingual_v1, eleven_english_v2, eleven_multilingual_v1, eleven_multilingual_v2 24 | model = args.model 25 | stability = 0.5 26 | similarity_boost = False 27 | chunk_length_schedule = [50] 28 | xi_api_key = os.environ["ELEVEN_API_KEY"] 29 | max_length = 1 # Maximum length for audio string truncation 30 | delay_time = 0.0001 # Use this to simulate the token output speed of your LLM 31 | try_trigger_generation = True 32 | optimize_streaming_latency = "4" # The default setting in the WS API is 4. Change it to 3 or lower to improve the pronunciation of numbers and dates to enable the text normalizer. 33 | use_text_chunker = args.text_chunker 34 | output_format = "mp3_44100" # Output format of the generated audio. Must be one of: mp3_44100, pcm_16000, pcm_22050, pcm_24000, pcm_44100 35 | 36 | 37 | def text_chunker(text: str) -> Iterator[str]: 38 | """ 39 | Used during input streaming to chunk text blocks and set last char to space. 40 | Use this function to simulate the default behavior of the official 11labs Library. 41 | """ 42 | splitters = (".", ",", "?", "!", ";", ":", "—", "-", "(", ")", "[", "]", "}", " ") 43 | buffer = "" 44 | for i, char in enumerate(text): 45 | buffer += char 46 | if i < len(text) - 1: # Check if this is not the last character 47 | next_char = text[i + 1] 48 | if buffer.endswith(splitters) and next_char == " ": 49 | logging.info(f"Chunked text: {buffer}") 50 | yield buffer if buffer.endswith(" ") else buffer + " " 51 | buffer = "" 52 | if buffer != "": 53 | logging.info(f"Chunked text: {buffer}") 54 | yield buffer + " " 55 | 56 | 57 | def simulate_text_stream(): 58 | """ 59 | When use_text_chunker is True, use a single text chunk here to process via the text_chunker function from elevenlabs library. 
60 | When use_text_chunker is False, you can simulate chunks of text from an LLM by adding more lines like this, in the above text_chunks list: 61 | text_chunks = [ 62 | "Hello ", 63 | "World, ", 64 | "this ", 65 | "is ", 66 | "a ", 67 | "voice ", 68 | "sample! ", 69 | ] 70 | """ 71 | text_chunks = [ 72 | "Hello world! This is a sample of a streaming voice. ", 73 | ] 74 | for text_chunk in text_chunks: 75 | time.sleep(delay_time) 76 | yield text_chunk 77 | 78 | 79 | def truncate_audio_string(audio_string): 80 | """ 81 | Truncate audio string if it exceeds the max_length 82 | """ 83 | if len(audio_string) > max_length: 84 | return audio_string[:max_length] + "..." 85 | return audio_string 86 | 87 | 88 | async def text_to_speech(): 89 | uri = f"wss://api.elevenlabs.io/v1/text-to-speech/{voice_id}/stream-input?model_type={model}&optimize_streaming_latency={optimize_streaming_latency}&output_format={output_format}" 90 | 91 | start_time = time.time() 92 | chunk_times = [] 93 | first_chunk_received = False 94 | first_chunk_time = None 95 | 96 | async with websockets.connect(uri) as websocket: 97 | print("Connected to WebSocket") 98 | connection_open_time = time.time() 99 | time_to_open_connection = connection_open_time - start_time 100 | 101 | bos_message = { 102 | "text": " ", 103 | "voice_settings": { 104 | "stability": stability, 105 | "similarity_boost": similarity_boost, 106 | }, 107 | "generation_config": {"chunk_length_schedule": chunk_length_schedule}, 108 | "xi_api_key": xi_api_key, 109 | "try_trigger_generation": try_trigger_generation, 110 | } 111 | await websocket.send(json.dumps(bos_message)) 112 | 113 | for text_chunk in simulate_text_stream(): 114 | if use_text_chunker: 115 | for chunk in text_chunker(text_chunk): 116 | input_message = { 117 | "text": chunk, 118 | "try_trigger_generation": try_trigger_generation, 119 | } 120 | input_message_time = time.time() 121 | logging.info( 122 | f"[{input_message_time:.4f}] Sending input message: {chunk}" 123 | ) 124 | await websocket.send(json.dumps(input_message)) 125 | else: 126 | input_message = { 127 | "text": text_chunk, 128 | "try_trigger_generation": try_trigger_generation, 129 | } 130 | input_message_time = time.time() 131 | logging.info( 132 | f"[{input_message_time:.4f}] Sending input message: {text_chunk}" 133 | ) 134 | await websocket.send(json.dumps(input_message)) 135 | 136 | try: 137 | start_waiting_time = time.time() 138 | logging.info(f"[{start_waiting_time:.4f}] Start waiting for response") 139 | response = await asyncio.wait_for(websocket.recv(), timeout=0.0001) 140 | end_waiting_time = time.time() 141 | logging.info(f"[{end_waiting_time:.4f}] End waiting for response") 142 | response_received_time = time.time() 143 | logging.info(f"[{response_received_time:.4f}] Response received") 144 | data = json.loads(response) 145 | 146 | data_copy = data.copy() 147 | 148 | if "audio" in data_copy: 149 | data_copy["audio"] = truncate_audio_string(data_copy["audio"]) 150 | 151 | logging.info(f"Server response: {data_copy}") 152 | 153 | if "audio" in data: 154 | chunk = base64.b64decode(data["audio"]) 155 | logging.info("Received audio chunk") 156 | chunk_received_time = time.time() 157 | if not first_chunk_received: 158 | first_chunk_received = True 159 | first_chunk_time = chunk_received_time - connection_open_time 160 | logging.info( 161 | f"Time to receive first chunk after connection opened: {first_chunk_time:.4f} seconds" 162 | ) 163 | chunk_times.append(chunk_received_time - connection_open_time) 164 | else: 165 | 
logging.info("No audio data in the response") 166 | except asyncio.TimeoutError: 167 | pass 168 | 169 | eos_message = {"text": ""} 170 | eos_message_time = time.time() 171 | logging.info(f"[{eos_message_time:.4f}] Sending eos_message") 172 | await websocket.send(json.dumps(eos_message)) 173 | 174 | while True: 175 | try: 176 | response = await websocket.recv() 177 | response_received_time = time.time() 178 | logging.info(f"[{response_received_time:.4f}] Response received") 179 | data = json.loads(response) 180 | audio = data.get("audio") 181 | if audio is not None: 182 | truncated_audio = truncate_audio_string(data["audio"]) 183 | logging.info(f"Server response: {{'audio': '{truncated_audio}'}}") 184 | else: 185 | logging.info("Server response:", data) 186 | await asyncio.sleep(0) 187 | 188 | if audio is not None: 189 | chunk = base64.b64decode(data["audio"]) 190 | logging.info("Received audio chunk") 191 | await asyncio.sleep(0) 192 | chunk_received_time = time.time() 193 | chunk_times.append(chunk_received_time - connection_open_time) 194 | else: 195 | logging.info("No audio data in the response") 196 | await asyncio.sleep(0) 197 | break 198 | except websockets.exceptions.ConnectionClosed: 199 | logging.info("Connection closed") 200 | await asyncio.sleep(0) 201 | break 202 | 203 | connection_close_time = time.time() 204 | total_time_websocket_was_open = connection_close_time - connection_open_time 205 | 206 | logging.info("\n-----Latency Summary-----") 207 | logging.info(f"Time to open connection: {time_to_open_connection:.4f} seconds") 208 | if ( 209 | first_chunk_time is not None 210 | ): # Check if first_chunk_time is not None before trying to print it 211 | logging.info( 212 | f"Time to first chunk after connection opened: {first_chunk_time:.4f} seconds" 213 | ) 214 | for i, chunk_time in enumerate(chunk_times, start=1): 215 | logging.info( 216 | f"Time to receive chunk {i} after connection opened: {chunk_time:.4f} seconds" 217 | ) 218 | logging.info( 219 | f"Total time WebSocket connection was open: {total_time_websocket_was_open:.4f} seconds" 220 | ) 221 | 222 | 223 | asyncio.get_event_loop().run_until_complete(text_to_speech()) 224 | -------------------------------------------------------------------------------- /fly.toml: -------------------------------------------------------------------------------- 1 | # fly.toml app configuration file generated for ai-benchmarks on 2024-04-11T13:20:45-07:00 2 | # 3 | # See https://fly.io/docs/reference/configuration/ for information about how to use this file. 4 | # 5 | 6 | app = 'ai-benchmarks' 7 | primary_region = 'sea' 8 | 9 | [build] 10 | builder = "paketobuildpacks/builder:base" 11 | 12 | [env] 13 | PORT = '8080' 14 | 15 | [http_service] 16 | internal_port = 8080 17 | force_https = true 18 | auto_stop_machines = true 19 | auto_start_machines = true 20 | min_machines_running = 0 21 | processes = ['app'] 22 | 23 | [[vm]] 24 | size = "performance-2x" 25 | -------------------------------------------------------------------------------- /llm_benchmark.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import asyncio 4 | import json 5 | import time 6 | from typing import List 7 | 8 | import aiohttp 9 | 10 | import llm_request 11 | 12 | DEFAULT_PROMPT = "Write a nonet about a sunset." 
13 | DEFAULT_MAX_TOKENS = 100 14 | DEFAULT_NUM_REQUESTS = 4 15 | 16 | FMT_DEFAULT = "default" 17 | FMT_MINIMAL = "minimal" 18 | FMT_JSON = "json" 19 | FMT_NONE = "none" 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument( 23 | "prompt", 24 | type=str, 25 | nargs="?", 26 | default=DEFAULT_PROMPT, 27 | help="Prompt to send to the API", 28 | ) 29 | parser.add_argument( 30 | "--file", 31 | "-f", 32 | type=str, 33 | action="append", 34 | help="Multimedia file(s) to include with the prompt", 35 | ) 36 | parser.add_argument( 37 | "--tool", 38 | type=argparse.FileType("r"), 39 | action="append", 40 | help="JSON file defining tools that can be used", 41 | ) 42 | parser.add_argument( 43 | "--strict", action="store_true", help="Use strict mode when using tools" 44 | ) 45 | parser.add_argument( 46 | "--model", 47 | "-m", 48 | type=str, 49 | default="", 50 | help="Model to benchmark", 51 | ) 52 | parser.add_argument( 53 | "--display-name", 54 | "-N", 55 | type=str, 56 | help="Display name for the model", 57 | ) 58 | parser.add_argument( 59 | "--peft", 60 | type=str, 61 | help="PEFT adapter to use", 62 | ) 63 | parser.add_argument( 64 | "--temperature", 65 | "-t", 66 | type=float, 67 | default=0.0, 68 | help="Temperature for the response", 69 | ) 70 | parser.add_argument( 71 | "--max-tokens", 72 | "-T", 73 | type=int, 74 | default=DEFAULT_MAX_TOKENS, 75 | help="Max tokens for the response", 76 | ) 77 | parser.add_argument( 78 | "--detail", 79 | "-d", 80 | help="Image detail level to use, low or high", 81 | ) 82 | parser.add_argument( 83 | "--base-url", 84 | "-b", 85 | type=str, 86 | default=None, 87 | help="Base URL for the LLM API endpoint", 88 | ) 89 | parser.add_argument( 90 | "--api-key", 91 | "-k", 92 | type=str, 93 | default=None, 94 | help="API key for the LLM API endpoint", 95 | ) 96 | parser.add_argument( 97 | "--no-warmup", 98 | action="store_false", 99 | dest="warmup", 100 | help="Don't do a warmup call to the API", 101 | ) 102 | parser.add_argument( 103 | "--no-reuse-connections", 104 | action="store_false", 105 | dest="reuse_connections", 106 | help="Don't reuse connections", 107 | ) 108 | parser.add_argument( 109 | "--num-requests", 110 | "-n", 111 | type=int, 112 | default=DEFAULT_NUM_REQUESTS, 113 | help="Number of requests to make", 114 | ) 115 | parser.add_argument( 116 | "--parallel-requests", 117 | "-P", 118 | type=int, 119 | default=100, 120 | help="Number of requests to make in parallel", 121 | ) 122 | parser.add_argument( 123 | "--print", 124 | "-p", 125 | action="store_true", 126 | dest="print", 127 | help="Print the response", 128 | ) 129 | parser.add_argument( 130 | "--verbose", 131 | "-v", 132 | action="store_true", 133 | dest="verbose", 134 | help="Print verbose output", 135 | ) 136 | parser.add_argument( 137 | "--format", 138 | "-F", 139 | type=str, 140 | default=FMT_DEFAULT, 141 | ) 142 | parser.add_argument( 143 | "--timeout", 144 | type=float, 145 | default=30.0, 146 | help="Timeout for the API call", 147 | ) 148 | 149 | 150 | class LlmTraceConfig(aiohttp.TraceConfig): 151 | def __init__(self, *args, **kwargs): 152 | super().__init__() 153 | self.on_request_start.append(self._on_request_start_func) 154 | self.on_connection_create_end.append(self._on_connection_create_end_func) 155 | self.on_connection_reuseconn.append(self._on_connection_reuseconn_func) 156 | self.on_request_headers_sent.append(self._on_request_headers_sent_func) 157 | self.on_request_chunk_sent.append(self._on_request_chunk_sent_func) 158 | 159 | async def _on_request_start_func(self, 
session, ctx, params): 160 | ctx.url = params.url 161 | ctx.start_time = time.time() 162 | 163 | async def _on_connection_create_end_func(self, session, ctx, params): 164 | self._trace(ctx, "created connection") 165 | 166 | async def _on_connection_reuseconn_func(self, session, ctx, params): 167 | self._trace(ctx, "reused connection") 168 | 169 | async def _on_request_headers_sent_func(self, session, ctx, params): 170 | self._trace(ctx, "sent headers") 171 | 172 | async def _on_request_chunk_sent_func(self, session, ctx, params): 173 | self._trace(ctx, "sent chunk") 174 | 175 | def _trace(self, ctx, action): 176 | delta = time.time() - ctx.start_time 177 | print(f"[{delta:.3f}] {ctx.url.host}: {action}") 178 | 179 | 180 | async def main(args: argparse.Namespace): 181 | if not args.model and not args.base_url: 182 | print("Either MODEL or BASE_URL must be specified") 183 | return None 184 | 185 | # Run the queries. 186 | prompt = args.prompt 187 | if prompt.startswith("@"): 188 | with open(prompt[1:], "r") as f: 189 | prompt = f.read() 190 | tools = [json.load(tool) for tool in args.tool or []] 191 | files = [llm_request.InputFile.from_file(file) for file in args.file or []] 192 | timeout = aiohttp.ClientTimeout(total=args.timeout) 193 | trace_configs = [LlmTraceConfig()] if args.verbose else [] 194 | connector = aiohttp.TCPConnector(force_close=not args.reuse_connections) 195 | async with aiohttp.ClientSession( 196 | timeout=timeout, trace_configs=trace_configs, connector=connector 197 | ) as session: 198 | contexts = [ 199 | llm_request.make_context(session, i, args, prompt, files, tools) 200 | for i in range(args.num_requests) 201 | ] 202 | name = contexts[0].name 203 | chosen = None 204 | 205 | def on_token(ctx: llm_request.ApiContext, token: str): 206 | nonlocal chosen 207 | if not chosen: 208 | chosen = ctx 209 | if args.format == FMT_DEFAULT: 210 | ttft = chosen.metrics.ttft 211 | print(f"Chosen API Call: {chosen.index} ({ttft:.2f}s)") 212 | if ctx == chosen: 213 | if args.print: 214 | if token: 215 | print(token, end="", flush=True) 216 | else: 217 | print("\n") 218 | 219 | num_parallel = max(min(args.parallel_requests, args.num_requests), 1) 220 | warmup_contexts = [ 221 | llm_request.make_context(session, -1, args) for _ in range(num_parallel) 222 | ] 223 | if args.verbose: 224 | print(f"Warming up {len(warmup_contexts)} connections...") 225 | # Do a warmup call to make sure the connection is ready, 226 | # and sleep it off to make sure it doesn't affect rate limits. 227 | await asyncio.gather(*[ctx.run() for ctx in warmup_contexts]) 228 | await asyncio.sleep(1.0) 229 | for i in range(len(warmup_contexts)): 230 | if warmup_contexts[i].ws: 231 | contexts[i].ws = warmup_contexts[i].ws 232 | 233 | if args.format == FMT_DEFAULT: 234 | print( 235 | f"Sending {args.num_requests} API calls ({num_parallel} at a time) to {name}..." 236 | ) 237 | for i in range(0, args.num_requests, num_parallel): 238 | tasks = [ 239 | asyncio.create_task(ctx.run(on_token)) 240 | for ctx in contexts[i : i + num_parallel] 241 | ] 242 | await asyncio.gather(*tasks) 243 | 244 | # Bail out if there were no successful API calls. 245 | task0_metrics = contexts[0].metrics 246 | if not chosen: 247 | if args.format == FMT_DEFAULT: 248 | print( 249 | f"No successful API calls for {name}. Sample error: {task0_metrics.error}" 250 | ) 251 | return task0_metrics 252 | 253 | # Print results. 
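# Note: the summary below keeps only successful requests and sorts them by TTFT (fastest first); "Optimized TTFT" is the fastest request, and "Median TTFT" averages the two middle values, which equals the middle value when the request count is odd.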
254 | if args.verbose: 255 | for ctx in contexts: 256 | r = ctx.metrics 257 | if not r.error: 258 | print( 259 | f"API Call {ctx.index}: TTFT={r.ttft:.2f}s, Total={r.total_time:.2f}s" 260 | ) 261 | else: 262 | print(f"API Call {ctx.index}: {r.error}") 263 | print("") 264 | 265 | metrics = [ctx.metrics for ctx in contexts if not ctx.metrics.error] 266 | metrics.sort(key=lambda x: x.ttft) 267 | r = metrics[0] 268 | if args.format == FMT_DEFAULT: 269 | latency_saved = task0_metrics.ttft - r.ttft 270 | med_index1 = (len(metrics) - 1) // 2 271 | med_index2 = len(metrics) // 2 272 | median_latency = (metrics[med_index1].ttft + metrics[med_index2].ttft) / 2 273 | print(f"Latency saved: {latency_saved:.2f} seconds") 274 | print(f"Optimized TTFT: {r.ttft:.2f} seconds") 275 | print(f"Median TTFT: {median_latency:.2f} seconds") 276 | if r.output_tokens: 277 | print(f"Tokens: {r.output_tokens} ({r.tps:.0f} tokens/sec)") 278 | print(f"Total time: {r.total_time:.2f} seconds") 279 | elif args.format == "minimal": 280 | assert r.output 281 | minimal_output = r.error or r.output.replace("\n", "\\n").strip()[:64] 282 | print( 283 | f"| {r.model:42} | {r.ttr:4.2f} | {r.ttft:4.2f} | {r.tps:4.0f} " 284 | f"| {r.output_tokens:3} | {r.total_time:5.2f} | {minimal_output} |" 285 | ) 286 | elif args.format == "json": 287 | print(r.to_json(indent=2)) 288 | return r 289 | 290 | 291 | async def run(argv: List[str]): 292 | args = parser.parse_args(argv) 293 | return await main(args) 294 | 295 | 296 | if __name__ == "__main__": 297 | args = parser.parse_args() 298 | asyncio.run(main(args)) 299 | -------------------------------------------------------------------------------- /llm_benchmark_suite.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import asyncio 3 | import dataclasses 4 | import datetime 5 | import os 6 | import random 7 | import sys 8 | from typing import Any, Dict, List, Optional, Tuple 9 | 10 | import dataclasses_json 11 | import gcloud.aio.storage as gcs 12 | 13 | import llm_benchmark 14 | import llm_request 15 | 16 | DEFAULT_DISPLAY_LENGTH = 64 17 | DEFAULT_GCS_BUCKET = "thefastest-data" 18 | 19 | GPT_4O_REALTIME_PREVIEW = "gpt-4o-realtime-preview-2024-10-01" 20 | GPT_4O = "gpt-4o" 21 | GPT_4O_MINI = "gpt-4o-mini" 22 | GPT_4_TURBO = "gpt-4-turbo" 23 | GPT_4_0125_PREVIEW = "gpt-4-0125-preview" 24 | GPT_4_1106_PREVIEW = "gpt-4-1106-preview" 25 | GPT_35_TURBO = "gpt-3.5-turbo" 26 | GPT_35_TURBO_0125 = "gpt-3.5-turbo-0125" 27 | GPT_35_TURBO_1106 = "gpt-3.5-turbo-1106" 28 | GEMINI_1_5_PRO = "gemini-1.5-pro" 29 | GEMINI_1_5_FLASH = "gemini-1.5-flash" 30 | LLAMA_31_405B_CHAT = "llama-3.1-405b-chat" 31 | LLAMA_31_405B_CHAT_FP8 = "llama-3.1-405b-chat-fp8" 32 | LLAMA_31_70B_CHAT = "llama-3.1-70b-chat" 33 | LLAMA_31_70B_CHAT_FP8 = "llama-3.1-70b-chat-fp8" 34 | LLAMA_31_8B_CHAT = "llama-3.1-8b-chat" 35 | LLAMA_31_8B_CHAT_FP8 = "llama-3.1-8b-chat-fp8" 36 | LLAMA_3_70B_CHAT = "llama-3-70b-chat" 37 | LLAMA_3_70B_CHAT_FP8 = "llama-3-70b-chat-fp8" 38 | LLAMA_3_70B_CHAT_FP4 = "llama-3-70b-chat-fp4" 39 | LLAMA_3_8B_CHAT = "llama-3-8b-chat" 40 | LLAMA_3_8B_CHAT_FP8 = "llama-3-8b-chat-fp8" 41 | LLAMA_3_8B_CHAT_FP4 = "llama-3-8b-chat-fp4" 42 | MIXTRAL_8X7B_INSTRUCT = "mixtral-8x7b-instruct" 43 | MIXTRAL_8X7B_INSTRUCT_FP8 = "mixtral-8x7b-instruct-fp8" 44 | 45 | 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument( 48 | "--format", 49 | "-F", 50 | choices=["text", "json"], 51 | default="text", 52 | help="Output results in the specified format", 53 | ) 54 
| parser.add_argument( 55 | "--mode", 56 | "-m", 57 | choices=["text", "tools", "image", "audio", "video"], 58 | default="text", 59 | help="Mode to run benchmarks for", 60 | ) 61 | parser.add_argument( 62 | "--filter", 63 | "-r", 64 | help="Filter models by name", 65 | ) 66 | parser.add_argument( 67 | "--spread", 68 | "-s", 69 | type=float, 70 | default=0.0, 71 | help="Spread the requests out over the specified time in seconds", 72 | ) 73 | parser.add_argument( 74 | "--display-length", 75 | "-l", 76 | type=int, 77 | default=DEFAULT_DISPLAY_LENGTH, 78 | help="Amount of the generation response to display", 79 | ) 80 | parser.add_argument( 81 | "--store", 82 | action="store_true", 83 | help="Store the results in the configured GCP bucket", 84 | ) 85 | 86 | 87 | def _dict_to_argv(d: Dict[str, Any]) -> List[str]: 88 | return [ 89 | f"--{k.replace('_', '-')}" + (f"={v}" if v or v == 0 else "") 90 | for k, v in d.items() 91 | ] 92 | 93 | 94 | class _Llm: 95 | """ 96 | We maintain a dict of params for the llm, as well as any 97 | command-line flags that we didn't already handle. We'll 98 | turn this into a single command line for llm_benchmark.run 99 | to consume, which allows us to reuse the parsing logic 100 | from that script, rather than having to duplicate it here. 101 | """ 102 | 103 | def __init__( 104 | self, 105 | model: str, 106 | display_name: Optional[str] = None, 107 | peft: Optional[str] = None, 108 | **kwargs, 109 | ): 110 | self.args = { 111 | "format": "none", 112 | **kwargs, 113 | } 114 | if model: 115 | self.args["model"] = model 116 | if display_name: 117 | self.args["display_name"] = display_name 118 | if peft: 119 | self.args["peft"] = peft 120 | 121 | async def run(self, pass_argv: List[str], spread: float) -> asyncio.Task: 122 | if spread: 123 | await asyncio.sleep(spread) 124 | full_argv = _dict_to_argv(self.args) + pass_argv 125 | return await llm_benchmark.run(full_argv) 126 | 127 | 128 | class _CerebrasLlm(_Llm): 129 | """See https://docs.cerebras.ai/en/latest/wsc/Model-zoo/MZ-overview.html#list-of-models""" 130 | 131 | def __init__(self, model: str, display_model: Optional[str] = None): 132 | super().__init__( 133 | model, 134 | "cerebras.ai/" + (display_model or model), 135 | api_key=os.getenv("CEREBRAS_API_KEY"), 136 | base_url="https://api.cerebras.ai/v1", 137 | ) 138 | 139 | 140 | class _CloudflareLlm(_Llm): 141 | """See https://developers.cloudflare.com/workers-ai/models/""" 142 | 143 | def __init__(self, model: str, display_model: Optional[str] = None): 144 | super().__init__( 145 | model, 146 | "cloudflare.com/" + (display_model or model), 147 | ) 148 | 149 | 150 | class _DeepInfraLlm(_Llm): 151 | """See https://deepinfra.com/models""" 152 | 153 | def __init__(self, model: str, display_model: Optional[str] = None): 154 | super().__init__( 155 | model, 156 | "deepinfra.com/" + (display_model or model), 157 | api_key=os.getenv("DEEPINFRA_API_TOKEN"), 158 | base_url="https://api.deepinfra.com/v1/openai", 159 | ) 160 | 161 | 162 | class _DatabricksLlm(_Llm): 163 | """See https://docs.databricks.com/en/machine-learning/foundation-models/supported-models.html""" 164 | 165 | def __init__(self, model: str, display_model: Optional[str] = None): 166 | super().__init__( 167 | model, 168 | "databricks.com/" + (display_model or model), 169 | api_key=os.getenv("DATABRICKS_TOKEN"), 170 | base_url="https://adb-1558081827343359.19.azuredatabricks.net/serving-endpoints", 171 | ) 172 | 173 | 174 | class _FireworksLlm(_Llm): 175 | """See https://fireworks.ai/models""" 176 | 177 | 
def __init__(self, model: str, display_model: Optional[str] = None): 178 | super().__init__( 179 | model, 180 | "fireworks.ai/" + (display_model or model), 181 | api_key=os.getenv("FIREWORKS_API_KEY"), 182 | base_url="https://api.fireworks.ai/inference/v1", 183 | ) 184 | 185 | 186 | class _GroqLlm(_Llm): 187 | """See https://console.groq.com/docs/models""" 188 | 189 | def __init__(self, model: str, display_model: Optional[str] = None): 190 | super().__init__( 191 | model, 192 | "groq.com/" + (display_model or model), 193 | api_key=os.getenv("GROQ_API_KEY"), 194 | base_url="https://api.groq.com/openai/v1", 195 | ) 196 | 197 | 198 | class _MistralLlm(_Llm): 199 | """See https://docs.mistral.ai/getting-started/models""" 200 | 201 | def __init__(self, model: str, display_model: Optional[str] = None): 202 | super().__init__( 203 | model, 204 | "mistral.ai/" + (display_model or model), 205 | api_key=os.getenv("MISTRAL_API_KEY"), 206 | base_url="https://api.mistral.ai/v1", 207 | ) 208 | 209 | 210 | class _NvidiaLlm(_Llm): 211 | """See https://build.nvidia.com/explore/discover""" 212 | 213 | def __init__(self, model: str, display_model: Optional[str] = None): 214 | super().__init__( 215 | model, 216 | "nvidia.com/" + (display_model or model), 217 | api_key=os.getenv("NVIDIA_API_KEY"), 218 | base_url="https://integrate.api.nvidia.com/v1", 219 | ) 220 | 221 | 222 | class _OvhLlm(_Llm): 223 | """See https://llama-3-70b-instruct.endpoints.kepler.ai.cloud.ovh.net/doc""" 224 | 225 | def __init__(self, model: str, display_model: Optional[str] = None): 226 | super().__init__( 227 | "", 228 | f"endpoints.ai.cloud.ovh.net/{model}", 229 | api_key=os.getenv("OVH_AI_ENDPOINTS_API_KEY"), 230 | base_url=f"https://{model}.endpoints.kepler.ai.cloud.ovh.net/api/openai_compat/v1", 231 | ) 232 | 233 | 234 | class _PerplexityLlm(_Llm): 235 | """See https://docs.perplexity.ai/docs/model-cards""" 236 | 237 | def __init__(self, model: str, display_model: Optional[str] = None): 238 | super().__init__( 239 | model, 240 | "perplexity.ai/" + (display_model or model), 241 | api_key=os.getenv("PERPLEXITY_API_KEY"), 242 | base_url="https://api.perplexity.ai", 243 | ) 244 | 245 | 246 | class _TogetherLlm(_Llm): 247 | """See https://docs.together.ai/docs/inference-models""" 248 | 249 | def __init__(self, model: str, display_model: Optional[str] = None): 250 | super().__init__( 251 | model, 252 | "together.ai/" + (display_model or model), 253 | api_key=os.getenv("TOGETHER_API_KEY"), 254 | base_url="https://api.together.xyz/v1", 255 | ) 256 | 257 | 258 | class _UltravoxLlm(_Llm): 259 | """See https://docs.ultravox.ai/docs/models""" 260 | 261 | def __init__(self, model: str, display_model: Optional[str] = None): 262 | super().__init__( 263 | model, 264 | "ultravox.ai/" + (display_model or model), 265 | api_key=os.getenv("ULTRAVOX_API_KEY"), 266 | base_url="https://api.ultravox.ai/api", 267 | ) 268 | 269 | 270 | def _text_models(): 271 | AZURE_EASTUS2_OPENAI_API_KEY = os.getenv("AZURE_EASTUS2_OPENAI_API_KEY") 272 | return [ 273 | # GPT-4o 274 | _Llm(GPT_4O_REALTIME_PREVIEW), 275 | _Llm(GPT_4O), 276 | _Llm( 277 | GPT_4O, 278 | api_key=AZURE_EASTUS2_OPENAI_API_KEY, 279 | base_url="https://fixie-openai-sub-with-gpt4.openai.azure.com", 280 | ), 281 | _Llm(GPT_4O, base_url="https://fixie-westus.openai.azure.com"), 282 | _Llm( 283 | GPT_4O, 284 | api_key=os.getenv("AZURE_NCENTRALUS_OPENAI_API_KEY"), 285 | base_url="https://fixie-centralus.openai.azure.com", 286 | ), 287 | _Llm(GPT_4O_MINI), 288 | # GPT-4 Turbo 289 | _Llm(GPT_4_TURBO), 290 
| # GPT-4 Turbo Previews 291 | _Llm(GPT_4_0125_PREVIEW), 292 | _Llm( 293 | GPT_4_0125_PREVIEW, 294 | api_key=os.getenv("AZURE_SCENTRALUS_OPENAI_API_KEY"), 295 | base_url="https://fixie-scentralus.openai.azure.com", 296 | ), 297 | _Llm(GPT_4_1106_PREVIEW), 298 | _Llm(GPT_4_1106_PREVIEW, base_url="https://fixie-westus.openai.azure.com"), 299 | _Llm( 300 | GPT_4_1106_PREVIEW, 301 | api_key=AZURE_EASTUS2_OPENAI_API_KEY, 302 | base_url="https://fixie-openai-sub-with-gpt4.openai.azure.com", 303 | ), 304 | _Llm( 305 | GPT_4_1106_PREVIEW, 306 | api_key=os.getenv("AZURE_FRCENTRAL_OPENAI_API_KEY"), 307 | base_url="https://fixie-frcentral.openai.azure.com", 308 | ), 309 | _Llm( 310 | GPT_4_1106_PREVIEW, 311 | api_key=os.getenv("AZURE_SECENTRAL_OPENAI_API_KEY"), 312 | base_url="https://fixie-secentral.openai.azure.com", 313 | ), 314 | _Llm( 315 | GPT_4_1106_PREVIEW, 316 | api_key=os.getenv("AZURE_UKSOUTH_OPENAI_API_KEY"), 317 | base_url="https://fixie-uksouth.openai.azure.com", 318 | ), 319 | # GPT-3.5 320 | _Llm(GPT_35_TURBO_0125), 321 | _Llm(GPT_35_TURBO_1106), 322 | _Llm(GPT_35_TURBO_1106, base_url="https://fixie-westus.openai.azure.com"), 323 | _Llm( 324 | GPT_35_TURBO, 325 | api_key=AZURE_EASTUS2_OPENAI_API_KEY, 326 | base_url="https://fixie-openai-sub-with-gpt4.openai.azure.com", 327 | ), 328 | # Claude 329 | _Llm("claude-3-opus-20240229"), 330 | _Llm("claude-3-5-sonnet-20240620"), 331 | _Llm("claude-3-sonnet-20240229"), 332 | _Llm("claude-3-haiku-20240307"), 333 | # Cohere 334 | _Llm("command-r-plus"), 335 | _Llm("command-r"), 336 | _Llm("command-light"), 337 | # Gemini 338 | _Llm("gemini-pro"), 339 | _Llm(GEMINI_1_5_PRO), 340 | _Llm(GEMINI_1_5_FLASH), 341 | # Mistral 342 | _MistralLlm("mistral-large-latest", "mistral-large"), 343 | _MistralLlm("open-mistral-nemo", "mistral-nemo"), 344 | # Mistral 8x7b 345 | _DatabricksLlm("databricks-mixtral-8x7b-instruct", MIXTRAL_8X7B_INSTRUCT), 346 | _DeepInfraLlm("mistralai/Mixtral-8x7B-Instruct-v0.1", MIXTRAL_8X7B_INSTRUCT), 347 | _FireworksLlm( 348 | "accounts/fireworks/models/mixtral-8x7b-instruct", MIXTRAL_8X7B_INSTRUCT_FP8 349 | ), 350 | _FireworksLlm( 351 | "accounts/fireworks/models/mixtral-8x7b-instruct-hf", MIXTRAL_8X7B_INSTRUCT 352 | ), 353 | _GroqLlm("mixtral-8x7b-32768", MIXTRAL_8X7B_INSTRUCT_FP8), 354 | _NvidiaLlm("mistralai/mixtral-8x7b-instruct-v0.1-turbo", MIXTRAL_8X7B_INSTRUCT_FP8), 355 | _TogetherLlm("mistralai/Mixtral-8x7B-Instruct-v0.1", MIXTRAL_8X7B_INSTRUCT), 356 | _OvhLlm("mixtral-8x7b-instruct-v01", MIXTRAL_8X7B_INSTRUCT), 357 | # Llama 3.1 405b 358 | _DatabricksLlm("databricks-meta-llama-3.1-405b-instruct", LLAMA_31_405B_CHAT), 359 | _DeepInfraLlm( 360 | "meta-llama/Meta-Llama-3.1-405B-Instruct", LLAMA_31_405B_CHAT_FP8 361 | ), 362 | _FireworksLlm( 363 | "accounts/fireworks/models/llama-v3p1-405b-instruct", LLAMA_31_405B_CHAT_FP8 364 | ), 365 | _GroqLlm("llama-3.1-405b-reasoning", LLAMA_31_405B_CHAT_FP8), 366 | _NvidiaLlm("meta/llama-3.1-405b-instruct-turbo", LLAMA_31_405B_CHAT_FP8), 367 | _TogetherLlm( 368 | "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo", LLAMA_31_405B_CHAT_FP8 369 | ), 370 | # _OvhLlm("llama-3-1-405b-instruct", LLAMA_31_405B_CHAT), 371 | # Llama 3.1 70b 372 | _CerebrasLlm("llama3.1-70b", LLAMA_31_70B_CHAT), 373 | _CloudflareLlm("@cf/meta/llama-3.1-70b-preview", LLAMA_31_70B_CHAT), 374 | # _DatabricksLlm("databricks-meta-llama-3.1-70b-instruct", LLAMA_31_70B_CHAT), 375 | _DeepInfraLlm("meta-llama/Meta-Llama-3.1-70B-Instruct", LLAMA_31_70B_CHAT), 376 | _FireworksLlm( 377 | 
"accounts/fireworks/models/llama-v3p1-70b-instruct", LLAMA_31_70B_CHAT_FP8 378 | ), 379 | _GroqLlm("llama-3.1-70b-versatile", LLAMA_31_70B_CHAT_FP8), 380 | _NvidiaLlm("meta/llama-3.1-70b-instruct-turbo", LLAMA_31_70B_CHAT_FP8), 381 | _PerplexityLlm("llama-3.1-70b-instruct", LLAMA_31_70B_CHAT), 382 | _TogetherLlm( 383 | "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", LLAMA_31_70B_CHAT_FP8 384 | ), 385 | _OvhLlm("llama-3-1-70b-instruct", LLAMA_31_70B_CHAT), 386 | # Llama 3.1 8b 387 | _CerebrasLlm("llama3.1-8b", LLAMA_31_8B_CHAT), 388 | _CloudflareLlm("@cf/meta/llama-3.1-8b-preview", LLAMA_31_8B_CHAT), 389 | # _DatabricksLlm("databricks-meta-llama-3.1-8b-instruct", LLAMA_31_8B_CHAT), 390 | _DeepInfraLlm("meta-llama/Meta-Llama-3.1-8B-Instruct", LLAMA_31_8B_CHAT), 391 | _FireworksLlm( 392 | "accounts/fireworks/models/llama-v3p1-8b-instruct", LLAMA_31_8B_CHAT_FP8 393 | ), 394 | _GroqLlm("llama-3.1-8b-instant", LLAMA_31_8B_CHAT_FP8), 395 | _NvidiaLlm("meta/llama-3.1-8b-instruct-turbo", LLAMA_31_8B_CHAT_FP8), 396 | _PerplexityLlm("llama-3.1-8b-instruct", LLAMA_31_8B_CHAT), 397 | _TogetherLlm( 398 | "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", LLAMA_31_8B_CHAT_FP8 399 | ), 400 | # _OvhLlm("llama-3-1-8b-instruct", LLAMA_31_8B_CHAT), 401 | # Llama 3 70b 402 | _DatabricksLlm("databricks-meta-llama-3-70b-instruct", LLAMA_3_70B_CHAT), 403 | _DeepInfraLlm("meta-llama/Meta-Llama-3-70B-Instruct", LLAMA_3_70B_CHAT), 404 | _FireworksLlm( 405 | "accounts/fireworks/models/llama-v3-70b-instruct", LLAMA_3_70B_CHAT_FP8 406 | ), 407 | _FireworksLlm( 408 | "accounts/fireworks/models/llama-v3-70b-instruct-hf", LLAMA_3_70B_CHAT 409 | ), 410 | _GroqLlm("llama3-70b-8192", LLAMA_3_70B_CHAT_FP8), 411 | _TogetherLlm("meta-llama/Llama-3-70b-chat-hf", LLAMA_3_70B_CHAT), 412 | _TogetherLlm( 413 | "meta-llama/Meta-Llama-3-70B-Instruct-Turbo", LLAMA_3_70B_CHAT_FP8 414 | ), 415 | _TogetherLlm("meta-llama/Meta-Llama-3-70B-Instruct-Lite", LLAMA_3_70B_CHAT_FP4), 416 | _OvhLlm("llama-3-70b-instruct", LLAMA_3_70B_CHAT), 417 | # Finetunes on Llama 3 70b 418 | _FireworksLlm( 419 | "accounts/fixie/models/1b68538a063a49e2ae4513d4ef186e9a", 420 | LLAMA_3_70B_CHAT + "-lora-1b68", 421 | ), 422 | # Llama 3 8b 423 | _CloudflareLlm("@cf/meta/llama-3-8b-instruct", LLAMA_3_8B_CHAT), 424 | _DeepInfraLlm("meta-llama/Meta-Llama-3-8B-Instruct", LLAMA_3_8B_CHAT), 425 | _FireworksLlm( 426 | "accounts/fireworks/models/llama-v3-8b-instruct", LLAMA_3_8B_CHAT_FP8 427 | ), 428 | _FireworksLlm( 429 | "accounts/fireworks/models/llama-v3-8b-instruct-hf", LLAMA_3_8B_CHAT 430 | ), 431 | _GroqLlm("llama3-8b-8192", LLAMA_3_8B_CHAT_FP8), 432 | _TogetherLlm("meta-llama/Llama-3-8b-chat-hf", LLAMA_3_8B_CHAT), 433 | _TogetherLlm("meta-llama/Meta-Llama-3-8B-Instruct-Turbo", LLAMA_3_8B_CHAT_FP8), 434 | _TogetherLlm("meta-llama/Meta-Llama-3-8B-Instruct-Lite", LLAMA_3_8B_CHAT_FP4), 435 | _OvhLlm("llama-3-8b-instruct", LLAMA_3_8B_CHAT), 436 | # Fine-tunes on Llama 3 8b 437 | _FireworksLlm( 438 | "accounts/fixie/models/8ab03ea85d2a4b9da659ce63db36a9b1", 439 | LLAMA_3_8B_CHAT + "-lora-8ab0", 440 | ), 441 | ] 442 | 443 | 444 | def _tools_models(): 445 | return [ 446 | _Llm(GPT_4O), 447 | _Llm(GPT_4O_MINI), 448 | _Llm(GPT_4_TURBO), 449 | _Llm(GPT_4O, GPT_4O + "-strict", strict=None), 450 | _Llm(GPT_4O_MINI, GPT_4O_MINI + "-strict", strict=None), 451 | _Llm(GPT_4_TURBO, GPT_4_TURBO + "-strict", strict=None), 452 | _Llm("claude-3-opus-20240229"), 453 | _Llm("claude-3-5-sonnet-20240620"), 454 | _Llm("claude-3-sonnet-20240229"), 455 | _Llm("claude-3-haiku-20240307"), 456 | 
_Llm(GEMINI_1_5_PRO), 457 | _Llm(GEMINI_1_5_FLASH), 458 | _FireworksLlm("accounts/fireworks/models/firefunction-v2", "firefunction-v2"), 459 | # _FireworksLlm( 460 | # "accounts/fireworks/models/llama-v3p1-405b-instruct", LLAMA_31_405B_CHAT_FP8 461 | # ), returns "FUNCTION" and the call as text 462 | _GroqLlm("llama-3.1-405b-reasoning", LLAMA_31_405B_CHAT_FP8), 463 | _GroqLlm("llama-3.1-70b-versatile", LLAMA_31_70B_CHAT_FP8), 464 | _GroqLlm("llama-3.1-8b-instant", LLAMA_31_8B_CHAT_FP8), 465 | _GroqLlm("llama3-groq-70b-8192-tool-use-preview"), 466 | _GroqLlm("llama3-groq-8b-8192-tool-use-preview"), 467 | ] 468 | 469 | 470 | def _image_models(): 471 | return [ 472 | _Llm(GPT_4O), 473 | _Llm(GPT_4O_MINI), 474 | _Llm(GPT_4_TURBO), 475 | _Llm("gpt-4-vision-preview", base_url="https://fixie-westus.openai.azure.com"), 476 | _Llm("claude-3-opus-20240229"), 477 | _Llm("claude-3-5-sonnet-20240620"), 478 | _Llm("claude-3-sonnet-20240229"), 479 | _Llm("gemini-pro-vision"), 480 | _Llm(GEMINI_1_5_PRO), 481 | _Llm(GEMINI_1_5_FLASH), 482 | _FireworksLlm( 483 | "accounts/fireworks/models/phi-3-vision-128k-instruct", "phi-3-vision" 484 | ), 485 | _MistralLlm("pixtral-latest", "pixtral"), 486 | ] 487 | 488 | 489 | def _audio_models(): 490 | return [ 491 | _Llm(GPT_4O_REALTIME_PREVIEW), 492 | _Llm(GEMINI_1_5_PRO), 493 | _Llm(GEMINI_1_5_FLASH), 494 | _UltravoxLlm("fixie-ai/ultravox", "ultravox-v0.5-70b"), 495 | _Llm( 496 | "ultravox", 497 | "baseten.co/ultravox-v0.4", 498 | base_url="https://bridge.baseten.co/v1/direct", 499 | api_key=os.getenv("BASETEN_API_KEY"), 500 | ), 501 | ] 502 | 503 | 504 | def _video_models(): 505 | return [ 506 | # _Llm(GPT_4O), 507 | _Llm(GEMINI_1_5_PRO), 508 | _Llm(GEMINI_1_5_FLASH), 509 | ] 510 | 511 | 512 | def _get_models(mode: str, filter: Optional[str] = None): 513 | mode_map = { 514 | "text": _text_models, 515 | "tools": _tools_models, 516 | "image": _image_models, 517 | "audio": _audio_models, 518 | "video": _video_models, 519 | } 520 | if mode not in mode_map: 521 | raise ValueError(f"Unknown mode {mode}") 522 | models = mode_map[mode]() 523 | return [ 524 | m 525 | for m in models 526 | if not filter 527 | or filter in (m.args.get("display_name") or m.args["model"]).lower() 528 | ] 529 | 530 | 531 | def _get_prompt(mode: str) -> List[str]: 532 | if mode == "text": 533 | return ["@media/text/llama31.md"] 534 | elif mode == "tools": 535 | return [ 536 | "I have a flight booked for July 14, 2024, and the flight number is AA100. 
Please check its status for me.", 537 | "--tool", 538 | "media/tools/flights.json", 539 | ] 540 | elif mode == "image": 541 | return [ 542 | "Based on the image, explain what will happen next.", 543 | "--file", 544 | "media/image/inception.jpeg", 545 | ] 546 | elif mode == "audio": 547 | return [ 548 | "Listen and respond to the following:", 549 | "--file", 550 | "media/audio/boolq.wav", 551 | ] 552 | elif mode == "video": 553 | return [ 554 | "What color is the logo on the screen and how does it relate to what the actor is saying?", 555 | "--file", 556 | "media/video/psa.webm", 557 | ] 558 | raise ValueError(f"Unknown mode {mode}") 559 | 560 | 561 | @dataclasses.dataclass 562 | class _Response(dataclasses_json.DataClassJsonMixin): 563 | time: str 564 | duration: str 565 | region: str 566 | cmd: str 567 | results: List[llm_request.ApiMetrics] 568 | 569 | 570 | def _format_response( 571 | response: _Response, format: str, dlen: int = 0 572 | ) -> Tuple[str, str]: 573 | if format == "json": 574 | return response.to_json(indent=2), "application/json" 575 | else: 576 | s = ( 577 | "| Provider/Model | TTR | TTFT | TPS | ITk | OTk | ITim | OTim | Total |" 578 | f" {'Response':{dlen}.{dlen}} |\n" 579 | "| :----------------------------------------- | ---: | ---: | ---: | ---: | --: | ---: | ---: | ----: |" 580 | f" {':--':-<{dlen}.{dlen}} |\n" 581 | ) 582 | 583 | for r in response.results: 584 | ttr = r.ttr or 0.0 585 | ttft = r.ttft or 0.0 586 | tps = r.tps or 0.0 587 | in_tokens = r.input_tokens or 0 588 | out_tokens = r.output_tokens or 0 589 | in_time = r.provider_input_time or 0 590 | out_time = ( 591 | r.provider_output_time or r.total_time - r.ttft 592 | if out_tokens 593 | else r.ttft 594 | ) 595 | total_time = r.total_time or 0.0 596 | output = (r.error or r.output).strip().replace("\n", "\\n") 597 | s += ( 598 | f"| {r.model[:42]:42} | {ttr:4.2f} | {ttft:4.2f} | {tps:4.0f} " 599 | f"| {in_tokens:4} | {out_tokens:3} | {in_time:4.2f} | {out_time:4.2f} " 600 | f"| {total_time:5.2f} | {output:{dlen}.{dlen}} |\n" 601 | ) 602 | 603 | s += f"\ntime: {response.time}, duration: {response.duration} region: {response.region}, cmd: {response.cmd}\n" 604 | return s, "text/markdown" 605 | 606 | 607 | async def _store_response(gcp_bucket: str, key: str, text: str, content_type: str): 608 | print(f"Storing results in {gcp_bucket}/{key}") 609 | storage = gcs.Storage(service_file="service_account.json") 610 | await storage.upload(gcp_bucket, key, text, content_type=content_type) 611 | await storage.close() 612 | 613 | 614 | async def _run(argv: List[str]) -> Tuple[str, str]: 615 | """ 616 | This function is invoked either from the webapp (via run) or the main function below. 617 | The args we know about are stored in args, and any unknown args are stored in pass_argv, 618 | which we'll pass to the _Llm.run function, who will turn them back into a 619 | single list of flags for consumption by the llm_benchmark.run function. 
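    For example (flag spellings here are assumed from how args is used below, not
    taken from the parser definition): an invocation like
        python llm_benchmark_suite.py --mode text --filter llama --format json
    would be consumed by parse_known_args, while anything it doesn't recognize
    (e.g. a hypothetical --max-tokens 20) lands in pass_argv and is forwarded as-is.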
620 | """ 621 | time_start = datetime.datetime.now() 622 | time_str = time_start.isoformat() 623 | region = os.getenv("FLY_REGION", "local") 624 | cmd = " ".join(argv) 625 | args, pass_argv = parser.parse_known_args(argv) 626 | pass_argv += _get_prompt(args.mode) 627 | models = _get_models(args.mode, args.filter) 628 | tasks = [] 629 | for m in models: 630 | delay = random.uniform(0, args.spread) 631 | tasks.append(asyncio.create_task(m.run(pass_argv, delay))) 632 | await asyncio.gather(*tasks) 633 | results = [t.result() for t in tasks if t.result() is not None] 634 | elapsed = datetime.datetime.now() - time_start 635 | elapsed_str = f"{elapsed.total_seconds():.2f}s" 636 | response = _Response(time_str, elapsed_str, region, cmd, results) 637 | if args.store: 638 | path = f"{region}/{args.mode}/{time_str.split('T')[0]}.json" 639 | json, content_type = _format_response(response, "json") 640 | await _store_response(DEFAULT_GCS_BUCKET, path, json, content_type) 641 | return _format_response(response, args.format, args.display_length) 642 | 643 | 644 | async def run(params: Dict[str, Any]) -> Tuple[str, str]: 645 | return await _run(_dict_to_argv(params)) 646 | 647 | 648 | async def main(): 649 | text, _ = await _run(sys.argv[1:]) 650 | print(text) 651 | 652 | 653 | if __name__ == "__main__": 654 | asyncio.run(main()) 655 | -------------------------------------------------------------------------------- /llm_request.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import asyncio 3 | import base64 4 | import dataclasses 5 | import io 6 | import json 7 | import mimetypes 8 | import os 9 | import re 10 | import time 11 | import urllib 12 | from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Tuple 13 | 14 | import aiohttp 15 | import dataclasses_json 16 | import numpy as np 17 | import soundfile as sf 18 | import soxr 19 | 20 | TokenGenerator = AsyncGenerator[str, None] 21 | ApiResult = Tuple[aiohttp.ClientResponse, TokenGenerator] 22 | 23 | AZURE_OPENAI_API_VERSION = "2024-02-15-preview" 24 | MAX_TPS = 9999 25 | MAX_TTFT = 9.99 26 | MAX_TOTAL_TIME = 99.99 27 | 28 | 29 | @dataclasses.dataclass 30 | class InputFile: 31 | @classmethod 32 | def from_file(cls, path: str): 33 | mime_type, _ = mimetypes.guess_type(path) 34 | if not mime_type: 35 | raise ValueError(f"Unknown file type: {path}") 36 | with open(path, "rb") as f: 37 | data = f.read() 38 | return cls(mime_type, data) 39 | 40 | @classmethod 41 | def from_bytes(cls, mime_type: str, data: bytes): 42 | return cls(mime_type, data) 43 | 44 | mime_type: str 45 | data: bytes 46 | 47 | @property 48 | def is_image(self): 49 | return self.mime_type.startswith("image/") 50 | 51 | @property 52 | def is_audio(self): 53 | return self.mime_type.startswith("audio/") 54 | 55 | @property 56 | def is_video(self): 57 | return self.mime_type.startswith("video/") 58 | 59 | @property 60 | def base64_data(self): 61 | return base64.b64encode(self.data).decode("utf-8") 62 | 63 | 64 | @dataclasses.dataclass 65 | class ApiMetrics(dataclasses_json.DataClassJsonMixin): 66 | model: str 67 | ttr: Optional[float] = None 68 | ttft: Optional[float] = None 69 | tps: Optional[float] = None 70 | input_tokens: Optional[int] = None 71 | output_tokens: Optional[int] = None 72 | total_time: Optional[float] = None 73 | provider_queue_time: Optional[float] = None 74 | provider_input_time: Optional[float] = None 75 | provider_output_time: Optional[float] = None 76 | provider_total_time: Optional[float] = 
None 77 | output: Optional[str] = None 78 | error: Optional[str] = None 79 | 80 | 81 | @dataclasses.dataclass 82 | class ApiContext: 83 | session: aiohttp.ClientSession 84 | index: int 85 | name: str 86 | func: Callable 87 | model: str 88 | prompt: str 89 | files: List[InputFile] 90 | tools: List[Dict] 91 | strict: bool 92 | temperature: float 93 | max_tokens: int 94 | detail: Optional[str] = None 95 | api_key: Optional[str] = None 96 | base_url: Optional[str] = None 97 | peft: Optional[str] = None 98 | ws: Optional[aiohttp.ClientWebSocketResponse] = None 99 | 100 | def __init__(self, session, index, name, func, args, prompt, files, tools): 101 | self.session = session 102 | self.index = index 103 | self.name = name 104 | self.func = func 105 | self.model = args.model 106 | self.prompt = prompt 107 | self.files = files 108 | self.tools = tools 109 | self.strict = args.strict 110 | self.detail = args.detail 111 | self.temperature = args.temperature 112 | self.max_tokens = args.max_tokens 113 | self.api_key = args.api_key 114 | self.base_url = args.base_url 115 | self.peft = args.peft 116 | self.metrics = ApiMetrics(model=self.name) 117 | 118 | @property 119 | def is_warmup(self): 120 | return self.index == -1 121 | 122 | async def run(self, on_token: Optional[Callable[["ApiContext", str], None]] = None): 123 | response = None 124 | try: 125 | start_time = time.time() 126 | first_token_time = None 127 | response, chunk_gen = await self.func(self) 128 | self.metrics.ttr = time.time() - start_time 129 | if response.ok: 130 | if chunk_gen: 131 | self.metrics.output_tokens = 0 132 | self.metrics.output = "" 133 | async for chunk in chunk_gen: 134 | self.metrics.output += chunk 135 | self.metrics.output_tokens += 1 136 | if not first_token_time: 137 | first_token_time = time.time() 138 | self.metrics.ttft = first_token_time - start_time 139 | if on_token and chunk: 140 | on_token(self, chunk) 141 | if first_token_time: 142 | # Signal the end of the generation. 
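                        # The empty string passed to on_token below acts as an end-of-stream
                        # sentinel for the callback; it is not appended to output or counted
                        # in output_tokens.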
143 | if on_token: 144 | on_token(self, "") 145 | else: 146 | self.metrics.error = "No tokens received" 147 | else: 148 | text = await response.text() 149 | self.metrics.error = f"{response.status} {response.reason} {text}" 150 | except TimeoutError: 151 | self.metrics.error = "Timeout" 152 | except aiohttp.ClientError as e: 153 | self.metrics.error = str(e) 154 | end_time = time.time() 155 | if not self.metrics.error: 156 | token_time = end_time - first_token_time 157 | self.metrics.total_time = end_time - start_time 158 | self.metrics.tps = min( 159 | (self.metrics.output_tokens - 1) / token_time, MAX_TPS 160 | ) 161 | if self.metrics.tps == MAX_TPS: 162 | self.metrics.tps = 0.0 163 | else: 164 | self.metrics.ttft = MAX_TTFT 165 | self.metrics.tps = 0.0 166 | self.metrics.total_time = MAX_TOTAL_TIME 167 | if response: 168 | await response.release() 169 | 170 | 171 | async def post( 172 | ctx: ApiContext, 173 | url: str, 174 | headers: dict, 175 | data: dict, 176 | make_chunk_gen: Optional[Callable[[aiohttp.ClientResponse], TokenGenerator]] = None, 177 | ): 178 | response = await ctx.session.post(url, headers=headers, data=json.dumps(data)) 179 | chunk_gen = make_chunk_gen(ctx, response) if make_chunk_gen else None 180 | return response, chunk_gen 181 | 182 | 183 | def get_api_key(ctx: ApiContext, env_var: str) -> str: 184 | if ctx.api_key: 185 | return ctx.api_key 186 | if env_var in os.environ: 187 | return os.environ[env_var] 188 | raise ValueError(f"Missing API key: {env_var}") 189 | 190 | 191 | def make_headers( 192 | auth_token: Optional[str] = None, 193 | api_key: Optional[str] = None, 194 | x_api_key: Optional[str] = None, 195 | ): 196 | headers = { 197 | "content-type": "application/json", 198 | } 199 | if auth_token: 200 | headers["authorization"] = f"Bearer {auth_token}" 201 | if api_key: 202 | headers["api-key"] = api_key 203 | if x_api_key: 204 | headers["x-api-key"] = x_api_key 205 | return headers 206 | 207 | 208 | def make_openai_url_and_headers(ctx: ApiContext, path: str): 209 | url = ctx.base_url or "https://api.openai.com/v1" 210 | hostname = urllib.parse.urlparse(url).hostname 211 | use_azure_openai = hostname and hostname.endswith("openai.azure.com") 212 | use_ovh = hostname and hostname.endswith("cloud.ovh.net") 213 | if use_azure_openai: 214 | api_key = get_api_key(ctx, "AZURE_OPENAI_API_KEY") 215 | headers = make_headers(api_key=api_key) 216 | url += f"/openai/deployments/{ctx.model.replace('.', '')}{path}?api-version={AZURE_OPENAI_API_VERSION}" 217 | elif use_ovh: 218 | api_key = get_api_key(ctx, "OVH_AI_ENDPOINTS_API_KEY") 219 | headers = { 220 | "content-type": "application/json", 221 | "authorization": api_key 222 | } 223 | url += path 224 | else: 225 | api_key = ctx.api_key if ctx.base_url else get_api_key(ctx, "OPENAI_API_KEY") 226 | headers = make_headers(auth_token=api_key) 227 | url += path 228 | return url, headers 229 | 230 | 231 | def make_openai_messages(ctx: ApiContext): 232 | if not ctx.files: 233 | return [{"role": "user", "content": ctx.prompt}] 234 | 235 | content: List[Dict[str, Any]] = [{"type": "text", "text": ctx.prompt}] 236 | for file in ctx.files: 237 | url = f"data:{file.mime_type};base64,{file.base64_data}" 238 | media_url = {"url": url} 239 | url_type = "audio_url" if file.is_audio else "image_url" 240 | if ctx.detail: 241 | media_url["detail"] = ctx.detail 242 | content.append({"type": url_type, url_type: media_url}) 243 | return [{"role": "user", "content": content}] 244 | 245 | 246 | def make_openai_ws_message(ctx: ApiContext): 247 | 
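    # Text parts are passed through unchanged; audio files are resampled to 24 kHz,
    # converted to 16-bit PCM, and base64-encoded before being attached as
    # "input_audio" content (the shape this code assumes the Realtime API expects).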
content = [{"type": "input_text", "text": ctx.prompt}] 248 | for file in ctx.files: 249 | if file.is_audio: 250 | audio, sr = sf.read(io.BytesIO(file.data)) 251 | audio_24k = soxr.resample(audio, sr, 24000) 252 | audio_pcm = (audio_24k * 32767).astype(np.int16).tobytes() 253 | b64_data = base64.b64encode(audio_pcm).decode("utf-8") 254 | content.append({"type": "input_audio", "audio": b64_data}) 255 | else: 256 | raise NotImplementedError("Images not yet supported in WebSocket mode") 257 | return {"type": "message", "role": "user", "content": content} 258 | 259 | 260 | def make_openai_chat_body(ctx: ApiContext, **kwargs): 261 | # Models differ in how they want to receive the prompt, so 262 | # we let the caller specify the key and format. 263 | body = { 264 | "model": ctx.model or None, 265 | "max_tokens": ctx.max_tokens, 266 | "temperature": ctx.temperature, 267 | "stream": True, 268 | } 269 | for key, value in kwargs.items(): 270 | body[key] = value 271 | return body 272 | 273 | 274 | async def make_sse_chunk_gen(response) -> AsyncGenerator[Dict[str, Any], None]: 275 | done = False 276 | async for line in response.content: 277 | line = line.decode("utf-8").strip() 278 | if line.startswith("data:"): 279 | content = line[5:].strip() 280 | if content == "[DONE]": 281 | done = True 282 | elif not done: 283 | yield json.loads(content) 284 | 285 | 286 | async def openai_chunk_gen(ctx: ApiContext, response) -> TokenGenerator: 287 | async for chunk in make_sse_chunk_gen(response): 288 | if chunk.get("choices", []): 289 | delta = chunk["choices"][0]["delta"] 290 | delta_content = delta.get("content") 291 | delta_tool = delta.get("tool_calls") 292 | if delta_content: 293 | yield delta_content 294 | elif delta_tool: 295 | function = delta_tool[0]["function"] 296 | name = function.get("name", "").strip() 297 | if name: 298 | yield name 299 | args = function.get("arguments", "").strip() 300 | if args: 301 | yield args 302 | usage = chunk.get("usage") or chunk.get("x_groq", {}).get("usage") 303 | if usage: 304 | ctx.metrics.input_tokens = usage.get("prompt_tokens") 305 | ctx.metrics.output_tokens = usage.get("completion_tokens") 306 | ctx.metrics.provider_queue_time = usage.get("queue_time") 307 | ctx.metrics.provider_input_time = usage.get("prompt_time") 308 | ctx.metrics.provider_output_time = usage.get("completion_time") 309 | ctx.metrics.provider_total_time = usage.get("total_time") 310 | 311 | 312 | async def openai_chat(ctx: ApiContext, path: str = "/chat/completions") -> ApiResult: 313 | url, headers = make_openai_url_and_headers(ctx, path) 314 | kwargs = {"messages": make_openai_messages(ctx)} 315 | if ctx.tools: 316 | tools = ctx.tools[:] 317 | if ctx.strict: 318 | for t in tools: 319 | t["function"]["strict"] = True 320 | t["function"]["parameters"]["additionalProperties"] = False 321 | kwargs["tools"] = tools 322 | kwargs["tool_choice"] = "required" 323 | if ctx.peft: 324 | kwargs["peft"] = ctx.peft 325 | # Some providers require opt-in for stream stats, but some providers don't like this opt-in. 326 | # Regardless of opt-in, Azure and ovh.net don't return stream stats at the moment. 327 | # See https://github.com/Azure/azure-rest-api-specs/issues/25062 328 | if not any(p in ctx.name for p in ["azure", "databricks", "fireworks", "mistral"]): 329 | kwargs["stream_options"] = {"include_usage": True} 330 | # Hack to identify our baseten deployment, which isn't contained in the URL. 
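    # (The model_id below is specific to one Baseten deployment; any other
    # deployment reached via bridge.baseten.co would need its own id here.)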
331 | if ctx.name.startswith("baseten"): 332 | kwargs["baseten"] = {"model_id": "rwn2v41w"} 333 | data = make_openai_chat_body(ctx, **kwargs) 334 | return await post(ctx, url, headers, data, openai_chunk_gen) 335 | 336 | 337 | async def openai_embed(ctx: ApiContext) -> ApiResult: 338 | url, headers = make_openai_url_and_headers(ctx, "/embeddings") 339 | data = {"model": ctx.model, "input": ctx.prompt} 340 | return await post(ctx, url, headers, data) 341 | 342 | 343 | class WebSocketResponse: 344 | """Mirrors the aiohttp.ClientHttpResponse interface, but for a WebSocket.""" 345 | 346 | def __init__(self, ctx: ApiContext): 347 | self.ctx = ctx 348 | 349 | @property 350 | def ok(self): 351 | return True 352 | 353 | async def release(self): 354 | if not self.ctx.is_warmup: 355 | await self.ctx.ws.close() 356 | 357 | 358 | async def openai_ws(ctx: ApiContext) -> ApiResult: 359 | async def warmup_gen() -> TokenGenerator: 360 | yield " " 361 | 362 | async def chunk_gen(ctx: ApiContext) -> TokenGenerator: 363 | async for msg in ctx.ws: 364 | chunk = json.loads(msg.data) 365 | match chunk["type"]: 366 | case "error": 367 | print(chunk) 368 | break 369 | case "response.text.delta": 370 | yield chunk["delta"] 371 | case "response.audio_transcript.delta": 372 | yield chunk["delta"] 373 | case "response.done": 374 | response = chunk["response"] 375 | ctx.metrics.input_tokens = response["usage"]["input_tokens"] 376 | ctx.metrics.output_tokens = response["usage"]["output_tokens"] 377 | break 378 | 379 | if not ctx.ws: 380 | base_url = ctx.base_url or "wss://api.openai.com/v1/realtime" 381 | url = f"{base_url}?model={ctx.model}" 382 | api_key = get_api_key(ctx, "OPENAI_API_KEY") 383 | headers = {"Authorization": f"Bearer {api_key}", "OpenAI-Beta": "realtime=v1"} 384 | ctx.ws = await ctx.session.ws_connect(url, headers=headers) 385 | if ctx.is_warmup: 386 | return WebSocketResponse(ctx), warmup_gen() 387 | 388 | create_item = { 389 | "type": "conversation.item.create", 390 | "item": make_openai_ws_message(ctx), 391 | } 392 | await ctx.ws.send_json(create_item) 393 | 394 | modalities = ["text"] 395 | if any(file.is_audio for file in ctx.files): 396 | modalities.append("audio") 397 | create_response = { 398 | "type": "response.create", 399 | "response": {"modalities": modalities}, 400 | } 401 | await ctx.ws.send_json(create_response) 402 | return WebSocketResponse(ctx), chunk_gen(ctx) 403 | 404 | 405 | def make_anthropic_messages(prompt: str, files: Optional[List[InputFile]] = None): 406 | """Formats the prompt as a text chunk and any images as image chunks. 407 | Note that Anthropic's image protocol is somewhat different from OpenAI's.""" 408 | if not files: 409 | return [{"role": "user", "content": prompt}] 410 | 411 | content: List[Dict[str, Any]] = [{"type": "text", "text": prompt}] 412 | for file in files: 413 | if not file.mime_type.startswith("image/"): 414 | raise ValueError(f"Unsupported file type: {file.mime_type}") 415 | source = { 416 | "type": "base64", 417 | "media_type": file.mime_type, 418 | "data": file.base64_data, 419 | } 420 | content.append({"type": "image", "source": source}) 421 | return [{"role": "user", "content": content}] 422 | 423 | 424 | async def anthropic_chat(ctx: ApiContext) -> ApiResult: 425 | """Make an Anthropic chat completion request. 
The request protocol is similar to OpenAI's, 426 | but the response protocol is completely different.""" 427 | 428 | async def chunk_gen(ctx: ApiContext, response) -> TokenGenerator: 429 | async for chunk in make_sse_chunk_gen(response): 430 | delta = chunk.get("delta") 431 | if delta and delta.get("type") == "text_delta": 432 | yield delta["text"] 433 | 434 | type = chunk.get("type") 435 | if type == "message_start": 436 | usage = chunk["message"].get("usage") 437 | if usage: 438 | ctx.metrics.input_tokens = usage.get("input_tokens") 439 | elif type == "message_delta": 440 | usage = chunk.get("usage") 441 | if usage: 442 | ctx.metrics.output_tokens = usage.get("output_tokens") 443 | 444 | url = "https://api.anthropic.com/v1/messages" 445 | headers = { 446 | "content-type": "application/json", 447 | "x-api-key": get_api_key(ctx, "ANTHROPIC_API_KEY"), 448 | "anthropic-version": "2023-06-01", 449 | "anthropic-beta": "messages-2023-12-15", 450 | } 451 | # Anthropic's schema is slightly different than OpenAI's. 452 | tools = [t["function"].copy() for t in ctx.tools] 453 | for tool in tools: 454 | tool["input_schema"] = tool["parameters"] 455 | del tool["parameters"] 456 | data = make_openai_chat_body( 457 | ctx, messages=make_anthropic_messages(ctx.prompt, ctx.files), tools=tools 458 | ) 459 | return await post(ctx, url, headers, data, chunk_gen) 460 | 461 | 462 | async def cohere_chat(ctx: ApiContext) -> ApiResult: 463 | """Make a Cohere chat completion request.""" 464 | 465 | async def chunk_gen(ctx: ApiContext, response) -> TokenGenerator: 466 | async for line in response.content: 467 | chunk = json.loads(line) 468 | if chunk.get("event_type") == "text-generation" and "text" in chunk: 469 | yield chunk["text"] 470 | elif chunk.get("event_type") == "stream-end": 471 | meta = chunk["response"]["meta"] 472 | ctx.metrics.input_tokens = meta["tokens"]["input_tokens"] 473 | ctx.metrics.output_tokens = meta["tokens"]["output_tokens"] 474 | 475 | url = "https://api.cohere.ai/v1/chat" 476 | headers = make_headers(auth_token=get_api_key(ctx, "COHERE_API_KEY")) 477 | data = make_openai_chat_body(ctx, message=ctx.prompt) 478 | return await post(ctx, url, headers, data, chunk_gen) 479 | 480 | 481 | async def cloudflare_chat(ctx: ApiContext) -> ApiResult: 482 | """Make a Cloudflare chat completion request. The protocol is similar to OpenAI's, 483 | but the URL doesn't follow the same scheme and the response structure is different. 484 | """ 485 | 486 | async def chunk_gen(ctx: ApiContext, response) -> TokenGenerator: 487 | async for chunk in make_sse_chunk_gen(response): 488 | yield chunk["response"] 489 | 490 | account_id = os.environ["CF_ACCOUNT_ID"] 491 | url = ( 492 | f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/{ctx.model}" 493 | ) 494 | headers = make_headers(auth_token=get_api_key(ctx, "CF_API_KEY")) 495 | data = make_openai_chat_body(ctx, messages=make_openai_messages(ctx)) 496 | return await post(ctx, url, headers, data, chunk_gen) 497 | 498 | 499 | async def make_json_chunk_gen(response) -> AsyncGenerator[Dict[str, Any], None]: 500 | """Hacky parser for the JSON streaming format used by Google Vertex AI.""" 501 | buf = "" 502 | async for line in response.content: 503 | # Eat the first array bracket, we'll do the same for the last one below. 504 | line = line.decode("utf-8").strip() 505 | if not buf and line.startswith("["): 506 | line = line[1:] 507 | # Split on comma-only lines, otherwise concatenate. 
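        # Illustrative shape of the stream (each array element is pretty-printed
        # across several lines, separated by a line containing only a comma):
        #   [{ ...first object... }
        #   ,
        #   { ...second object... }]
        # The leading "[" is stripped above; the trailing "]" is stripped from the
        # final buffer below.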
508 | if line == ",": 509 | yield json.loads(buf) 510 | buf = "" 511 | else: 512 | buf += line 513 | yield json.loads(buf[:-1]) 514 | 515 | 516 | def get_google_access_token(): 517 | from google.auth.transport import requests 518 | from google.oauth2 import service_account 519 | 520 | creds = service_account.Credentials.from_service_account_file( 521 | "service_account.json", 522 | scopes=["https://www.googleapis.com/auth/cloud-platform"], 523 | ) 524 | if not creds.token: 525 | creds.refresh(requests.Request()) 526 | return creds.token 527 | 528 | 529 | def make_google_url_and_headers(ctx: ApiContext, method: str): 530 | region = "us-west1" 531 | project_id = os.environ["GCP_PROJECT"] 532 | url = f"https://{region}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{region}/publishers/google/models/{ctx.model}:{method}" 533 | api_key = ctx.api_key 534 | if not api_key: 535 | api_key = get_google_access_token() 536 | headers = make_headers(auth_token=api_key) 537 | return url, headers 538 | 539 | 540 | def make_gemini_messages(prompt: str, files: List[InputFile]): 541 | parts: List[Dict[str, Any]] = [{"text": prompt}] 542 | for file in files: 543 | parts.append( 544 | {"inline_data": {"mime_type": file.mime_type, "data": file.base64_data}} 545 | ) 546 | 547 | return [{"role": "user", "parts": parts}] 548 | 549 | 550 | async def gemini_chat(ctx: ApiContext) -> ApiResult: 551 | async def chunk_gen(ctx: ApiContext, response) -> TokenGenerator: 552 | async for chunk in make_json_chunk_gen(response): 553 | candidates = chunk.get("candidates") 554 | if candidates: 555 | content = candidates[0].get("content") 556 | if content and "parts" in content: 557 | part = content["parts"][0] 558 | if "text" in part: 559 | yield part["text"] 560 | elif "functionCall" in part: 561 | call = part["functionCall"] 562 | if "name" in call: 563 | yield call["name"] 564 | if "args" in call: 565 | yield str(call["args"]) 566 | usage = chunk.get("usageMetadata") 567 | if usage: 568 | ctx.metrics.input_tokens = usage.get("promptTokenCount") 569 | ctx.metrics.output_tokens = usage.get("candidatesTokenCount") 570 | 571 | # The Google AI Gemini API (URL below) doesn't return the number of generated tokens. 572 | # Instead we use the Google Cloud Vertex AI Gemini API, which does return the number of tokens, but requires an Oauth credential. 573 | # Also, setting safetySettings to BLOCK_NONE is not supported in the Vertex AI Gemini API, at least for now. 
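    # The constant condition below hardwires the Vertex AI path; flipping it to
    # False would use the generativelanguage.googleapis.com endpoint with
    # GOOGLE_GEMINI_API_KEY instead, at the cost of losing the generated-token
    # counts mentioned above.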
574 | if True: 575 | url, headers = make_google_url_and_headers(ctx, "streamGenerateContent") 576 | else: 577 | url = f"https://generativelanguage.googleapis.com/v1beta/models/{ctx.model}:streamGenerateContent?key={get_api_key(ctx, 'GOOGLE_GEMINI_API_KEY')}" 578 | headers = make_headers() 579 | harm_categories = [ 580 | "HARM_CATEGORY_HARASSMENT", 581 | "HARM_CATEGORY_HATE_SPEECH", 582 | "HARM_CATEGORY_SEXUALLY_EXPLICIT", 583 | "HARM_CATEGORY_DANGEROUS_CONTENT", 584 | ] 585 | data = { 586 | "contents": make_gemini_messages(ctx.prompt, ctx.files), 587 | "generationConfig": { 588 | "temperature": ctx.temperature, 589 | "maxOutputTokens": ctx.max_tokens, 590 | }, 591 | "safetySettings": [ 592 | {"category": category, "threshold": "BLOCK_NONE"} 593 | for category in harm_categories 594 | if not ctx.files or ctx.files[0].is_image 595 | ], 596 | } 597 | if ctx.tools: 598 | data["tools"] = ( 599 | [{"function_declarations": [tool["function"] for tool in ctx.tools]}], 600 | ) 601 | return await post(ctx, url, headers, data, chunk_gen) 602 | 603 | 604 | async def cohere_embed(ctx: ApiContext) -> ApiResult: 605 | url = "https://api.cohere.ai/v1/embed" 606 | headers = make_headers(auth_token=get_api_key(ctx, "COHERE_API_KEY")) 607 | data = { 608 | "model": ctx.model, 609 | "texts": [ctx.prompt], 610 | "input_type": "search_query", 611 | } 612 | return await post(ctx, url, headers, data) 613 | 614 | 615 | async def fake_chat(ctx: ApiContext) -> ApiResult: 616 | class FakeResponse(aiohttp.ClientResponse): 617 | def __init__(self, status, reason): 618 | self.status = status 619 | self.reason = reason 620 | 621 | # async def release(self): 622 | # pass 623 | 624 | async def make_fake_chunk_gen(output: str): 625 | for word in output.split(): 626 | yield word + " " 627 | await asyncio.sleep(0.05) 628 | 629 | output = "This is a fake response." 630 | if ctx.index % 2 == 0: 631 | response = FakeResponse(200, "OK") 632 | else: 633 | response = FakeResponse(500, "Internal Server Error") 634 | sleep = 0.5 * (ctx.index + 1) 635 | max_sleep = ctx.session.timeout.total 636 | if max_sleep: 637 | await asyncio.sleep(min(sleep, max_sleep)) 638 | if sleep > max_sleep: 639 | raise TimeoutError 640 | return (response, make_fake_chunk_gen(output)) 641 | 642 | 643 | def make_display_name(provider_or_url: str, model: str) -> str: 644 | # Clean up the base URL to get a nicer provider name. 645 | if provider_or_url.startswith("https://"): 646 | provider = ( 647 | provider_or_url[8:] 648 | .split("/")[0] 649 | .replace("openai-sub-with-gpt4", "eastus2") 650 | .replace("fixie-", "") 651 | .replace("-serverless", "") 652 | .replace("inference.ai.azure.com", "azure") 653 | .replace("openai.azure.com", "azure") 654 | ) 655 | # Get the last two segments of the domain, and swap foo.azure to azure.foo. 656 | provider = ".".join(provider.split(".")[-2:]) 657 | provider = re.sub(r"(\w+)\.azure$", r"azure.\1", provider) 658 | else: 659 | provider = provider_or_url 660 | model_segments = model.split("/") 661 | if provider: 662 | # We already have a provider, so just need to add the model name. 663 | # If we've got a model name, add the end of the split to the provider. 664 | # Otherwise, we have model.domain.com, so we need to swap to domain.com/model. 
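        # Illustrative examples, worked out by hand rather than asserted in tests:
        #   ("https://fixie-westus.openai.azure.com", "gpt-4o") -> "azure.westus/gpt-4o"
        #   ("https://api.groq.com/openai/v1", "llama3-70b-8192") -> "groq.com/llama3-70b-8192"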
665 | if model: 666 | name = provider + "/" + model_segments[-1] 667 | else: 668 | domain_segments = provider.split(".") 669 | name = ".".join(domain_segments[1:]) + "/" + domain_segments[0] 670 | elif len(model_segments) > 1: 671 | # We've got a provider/model string, from which we need to get the provider and model. 672 | provider = model_segments[0] 673 | name = provider + "/" + model_segments[-1] 674 | return name 675 | 676 | 677 | def make_context( 678 | session: aiohttp.ClientSession, 679 | index: int, 680 | args: argparse.Namespace, 681 | prompt: Optional[str] = None, 682 | files: Optional[List[InputFile]] = None, 683 | tools: Optional[List[Dict]] = None, 684 | ) -> ApiContext: 685 | model = args.model 686 | prefix = re.split("-|/", model)[0] 687 | provider = args.base_url 688 | match prefix: 689 | case "claude": 690 | provider = "anthropic" 691 | func = anthropic_chat 692 | case "command": 693 | provider = "cohere" 694 | func = cohere_chat 695 | case "@cf": 696 | provider = "cloudflare" 697 | func = cloudflare_chat 698 | case "gemini": 699 | provider = "google" 700 | func = gemini_chat 701 | case "text-embedding-ada": 702 | provider = "openai" 703 | func = openai_embed 704 | case "embed": 705 | provider = "cohere" 706 | func = cohere_embed 707 | case "fake": 708 | provider = "test" 709 | func = fake_chat 710 | case _ if "realtime" in model: 711 | func = openai_ws 712 | if not args.base_url: 713 | provider = "openai" 714 | case _ if args.base_url or model.startswith("gpt-") or model.startswith( 715 | "ft:gpt-" 716 | ): 717 | func = openai_chat 718 | if not args.base_url: 719 | provider = "openai" 720 | case _: 721 | raise ValueError(f"Unknown model: {model}") 722 | name = args.display_name or make_display_name(provider, model) 723 | return ApiContext( 724 | session, index, name, func, args, prompt or "", files or [], tools or [] 725 | ) 726 | -------------------------------------------------------------------------------- /media/audio/boolq.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fixie-ai/ai-benchmarks/265e8bcadad1c3749fb940e08b09928ee4a279e8/media/audio/boolq.wav -------------------------------------------------------------------------------- /media/audio/news.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fixie-ai/ai-benchmarks/265e8bcadad1c3749fb940e08b09928ee4a279e8/media/audio/news.wav -------------------------------------------------------------------------------- /media/audio/say_cheese.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fixie-ai/ai-benchmarks/265e8bcadad1c3749fb940e08b09928ee4a279e8/media/audio/say_cheese.wav -------------------------------------------------------------------------------- /media/image/great_wave.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:62d12f0d6e4804b96735cbaa827b6652532209b7b8df4f8573058f52d4698398 3 | size 1283028 4 | -------------------------------------------------------------------------------- /media/image/inception.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fixie-ai/ai-benchmarks/265e8bcadad1c3749fb940e08b09928ee4a279e8/media/image/inception.jpeg -------------------------------------------------------------------------------- 
/media/text/llama31.md: -------------------------------------------------------------------------------- 1 | Summarize the document below into a single informative sentence: 2 | --- 3 | Foundation models are general models of language, vision, speech, and/or other modalities that are designed to support a large variety of AI tasks. They form the basis of many modern AI systems. 4 | 5 | The development of modern foundation models consists of two main stages: (1) a pre-training stage in which the model is trained at massive scale using straightforward tasks such as next-word prediction or captioning and (2) a post-training stage in which the model is tuned to follow instructions, align with human preferences, and improve specific capabilities (for example, coding and reasoning). 6 | 7 | In this paper, we present a new set of foundation models for language, called Llama 3. The Llama 3 Herd of models natively supports multilinguality, coding, reasoning, and tool usage. Our largest model is dense Transformer with 405B parameters, processing information in a context window of up to 128K tokens. Each member of the herd is listed in Table 1. All the results presented in this paper are for the Llama 3.1 models, which we will refer to as Llama 3 throughout for brevity. 8 | 9 | We believe there are three key levers in the development of high-quality foundation models: data, scale, and managing complexity. We seek to optimize for these three levers in our development process: 10 | 11 | Data. Compared to prior versions of Llama (Touvron et al., 2023a,b), we improved both the quantity and quality of the data we use for pre-training and post-training. These improvements include the development of more careful pre-processing and curation pipelines for pre-training data and the development of more rigorous quality assurance and filtering approaches for post-training data. We pre-train Llama 3 on a corpus of about 15T multilingual tokens, compared to 1.8T tokens for Llama 2. 12 | 13 | Scale. We train a model at far larger scale than previous Llama models: our flagship language model was pre-trained using 3.8 × 1025 FLOPs, almost 50× more than the largest version of Llama 2. Specifically, we pre-trained a flagship model with 405B trainable parameters on 15.6T text tokens. As expected per. scaling laws for foundation models, our flagship model outperforms smaller models trained using the same procedure. While our scaling laws suggest our flagship model is an approximately compute-optimal size for our training budget, we also train our smaller models for much longer than is compute-optimal. The resulting models perform better than compute-optimal models at the same inference budget. We use the flagship model to further improve the quality of those smaller models during post-training. 14 | 15 | Managing complexity. We make design choices that seek to maximize our ability to scale the model development process. For example, we opt for a standard dense Transformer model architecture (Vaswani et al., 2017) with minor adaptations, rather than for a mixture-of-experts model (Shazeer et al., 2017) to maximize training stability. Similarly, we adopt a relatively simple post-training procedure based on supervised finetuning (SFT), rejection sampling (RS), and direct preference optimization (DPO; Rafailov et al. (2023)) as opposed to more complex reinforcement learning algorithms (Ouyang et al., 2022; Schulman et al., 2017) that tend to be less stable and harder to scale. 
16 | 17 | The result of our work is Llama 3: a herd of three multilingual language models with 8B, 70B, and 405B parameters. We evaluate the performance of Llama 3 on a plethora of benchmark datasets that span a wide range of language understanding tasks. In addition, we perform extensive human evaluations that compare Llama 3 with competing models. An overview of the performance of the flagship Llama 3 model on key benchmarks is presented in Table 2. Our experimental evaluation suggests that our flagship model performs on par with leading language models such as GPT-4 (OpenAI, 2023a) across a variety of tasks, and is close to matching the state-of-the-art. Our smaller models are best-in-class, outperforming alternative models with similar numbers of parameters (Bai et al., 2023; Jiang et al., 2023). Llama 3 also delivers a much better balance between helpfulness and harmlessness than its predecessor (Touvron et al., 2023b). We present a detailed analysis of the safety of Llama 3 in Section 5.4. 18 | 19 | We are publicly releasing all three Llama 3 models under an updated version of the Llama 3 Community License; see https://llama.meta.com. 20 | -------------------------------------------------------------------------------- /media/tools/flights.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "function", 3 | "function": { 4 | "name": "get_flight_status", 5 | "description": "Get the current status of a flight", 6 | "parameters": { 7 | "type": "object", 8 | "properties": { 9 | "flight_number": { 10 | "type": "string", 11 | "description": "The flight number, e.g., AA100" 12 | }, 13 | "date": { 14 | "type": "string", 15 | "description": "The date of the flight, e.g., 2024-06-17" 16 | } 17 | }, 18 | "required": ["flight_number", "date"] 19 | } 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /media/video/psa.webm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fixie-ai/ai-benchmarks/265e8bcadad1c3749fb940e08b09928ee4a279e8/media/video/psa.webm -------------------------------------------------------------------------------- /openai_finetune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import openai 5 | 6 | openai.api_key = os.getenv("OPENAI_API_KEY") 7 | training_file = "/Users/juberti/Downloads/pirate_tune.jsonl" 8 | training_file_response = openai.File.create( 9 | file=open(training_file, "rb"), purpose="fine-tune" 10 | ) 11 | training_file_id = training_file_response["id"] 12 | print(f"Training file uploaded with ID: {training_file_id}") 13 | 14 | fine_tuning_job = openai.FineTuningJob.create( 15 | training_file=training_file_id, model="gpt-3.5-turbo" 16 | ) 17 | job_id = fine_tuning_job["id"] 18 | print(f"Fine-tuning job created with ID: {job_id}") 19 | 20 | while True: 21 | try: 22 | fine_tuning_status = openai.FineTune.retrieve(job_id) 23 | except openai.error.InvalidRequestError as e: 24 | print(e) 25 | time.sleep(1) 26 | continue 27 | 28 | status = fine_tuning_status["status"] 29 | print(f"Fine-tuning job status: {status}") 30 | 31 | if status in ["completed", "failed"]: 32 | break 33 | 34 | time.sleep(60) 35 | fine_tuned_model_id = fine_tuning_status["fine_tuned_model_id"] 36 | 37 | completion = openai.ChatCompletion.create( 38 | model=fine_tuned_model_id, 39 | messages=[ 40 | {"role": "system", "content": "You are a helpful 
assistant."}, 41 | {"role": "user", "content": "Hello!"}, 42 | ], 43 | ) 44 | print(completion.choices[0].message) 45 | -------------------------------------------------------------------------------- /playht_benchmark.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import asyncio 3 | import dataclasses 4 | import logging 5 | import os 6 | import subprocess 7 | import time 8 | 9 | import aiohttp 10 | from pyht import client 11 | from pyht.protos import api_pb2 12 | 13 | logging.basicConfig( 14 | format="%(asctime)s [%(levelname)s] %(message)s", level=logging.INFO 15 | ) 16 | 17 | PLAYHT_API_KEY = os.environ.get("PLAYHT_API_KEY") 18 | PLAYHT_USER_ID = os.environ.get("PLAYHT_USER_ID") 19 | 20 | # Defaults Text and voice settings 21 | DEFAULT_TEXT = "Ah, these kids today! They don't know the struggle, I'll tell ya that much. Back in my day, the internet wasn't this instant-gratification paradise it is now." 22 | DEFAULT_VOICE_ID = "s3://voice-cloning-zero-shot/7c339a9d-370f-4643-adf5-4134e3ec9886/mlae02/manifest.json" 23 | 24 | # Audio settings 25 | DEFAULT_QUALITY = "draft" # Quality of the audio. Default is "draft". Other options are low, medium, high, and premium. 26 | DEFAULT_OUTPUT_FORMAT = "mp3" 27 | DEFAULT_SPEED = 1 # Values 1 to 5 28 | DEFAULT_SAMPLE_RATE = ( 29 | 24000 # Sample rate of the audio. A number between 8000 and 48000. 30 | ) 31 | 32 | # Random generator settings 33 | DEFAULT_SEED = ( 34 | None # Seed for the random generator. If None, a random seed will be used. 35 | ) 36 | DEFAULT_TEMPERATURE = None # Temperature for the random generator. Controls variance. If None, the model's default temperature will be used. 37 | 38 | # Voice engine settings 39 | DEFAULT_VOICE_ENGINE = "PlayHT2.0" # Voice engine to be used. Default is "PlayHT1.0". 40 | DEFAULT_EMOTION = "male_angry" # Emotion of the speech. Only supported when voice_engine is set to "PlayHT2.0". 41 | DEFAULT_VOICE_GUIDANCE = 2 # Voice guidance level. A number between 1 and 6. Lower numbers reduce the voice's uniqueness, higher numbers maximize its individuality. Only supported when voice_engine is set to "PlayHT2.0". 42 | DEFAULT_STYLE_GUIDANCE = 20 # Style guidance level. A number between 1 and 30. Lower numbers reduce the strength of the chosen emotion, higher numbers create a more emotional performance. Only supported when voice_engine is set to "PlayHT2.0". 
43 | 44 | WARMUP_TEXT = "a" 45 | CHUNK_SIZE = 4096 46 | 47 | # argument parser 48 | parser = argparse.ArgumentParser(description="Stream audio from server.") 49 | parser.add_argument( 50 | "text", default=DEFAULT_TEXT, nargs="?", help="Text to be converted to speech" 51 | ) 52 | parser.add_argument( 53 | "--play", "-p", default=False, action="store_true", help="Play the audio using mpv" 54 | ) 55 | parser.add_argument("--transport", "-t", default="rest", help="Transport to be use") 56 | parser.add_argument("--voice", "-v", default=DEFAULT_VOICE_ID, help="Voice to be used") 57 | parser.add_argument( 58 | "--quality", "-q", default=DEFAULT_QUALITY, help="Quality of the audio" 59 | ) 60 | parser.add_argument( 61 | "--speed", "-s", default=DEFAULT_SPEED, type=int, help="Speed of the speech" 62 | ) 63 | parser.add_argument( 64 | "--format", "-f", default=DEFAULT_OUTPUT_FORMAT, help="Output format of the audio" 65 | ) 66 | parser.add_argument( 67 | "--warmup", 68 | "-w", 69 | default=False, 70 | action="store_true", 71 | help="Perform a warmup call before generation", 72 | ) 73 | parser.add_argument( 74 | "--voice-engine", default=DEFAULT_VOICE_ENGINE, help="Voice engine to be used" 75 | ) 76 | parser.add_argument("--emotion", default=DEFAULT_EMOTION, help="Emotion of the speech") 77 | parser.add_argument( 78 | "--voice-guidance", 79 | default=DEFAULT_VOICE_GUIDANCE, 80 | type=int, 81 | help="Voice guidance level", 82 | ) 83 | parser.add_argument( 84 | "--style-guidance", 85 | default=DEFAULT_STYLE_GUIDANCE, 86 | type=int, 87 | help="Style guidance level", 88 | ) 89 | 90 | args = parser.parse_args() 91 | 92 | 93 | @dataclasses.dataclass 94 | class LatencyData: 95 | def __init__(self): 96 | self.start_time = 0 97 | self.headers_received = 0 98 | self.first_chunk = 0 99 | self.chunk_times = [] 100 | self.total_time = 0 101 | 102 | def start(self): 103 | self.start_time = time.perf_counter() 104 | 105 | def set_headers_received(self): 106 | self.headers_received = time.perf_counter() - self.start_time 107 | 108 | def set_first_chunk(self): 109 | self.first_chunk = time.perf_counter() - self.start_time 110 | 111 | def add_chunk_time(self): 112 | self.chunk_times.append(time.perf_counter() - self.start_time) 113 | 114 | def set_total_time(self): 115 | self.total_time = time.perf_counter() - self.start_time 116 | 117 | start_time: float 118 | headers_received: float 119 | chunk_times: list[float] 120 | total_time: float 121 | 122 | 123 | async def stream_rest(response, latency_data: LatencyData): 124 | if args.play: 125 | mpv_command = ["mpv", "--no-cache", "--no-terminal", "--", "fd://0"] 126 | mpv_process = subprocess.Popen( 127 | mpv_command, 128 | stdin=subprocess.PIPE, 129 | stdout=subprocess.DEVNULL, 130 | stderr=subprocess.DEVNULL, 131 | ) 132 | else: 133 | mpv_process = None 134 | 135 | bytes_received = 0 136 | async for chunk in response.content.iter_chunked(CHUNK_SIZE): 137 | latency_data.add_chunk_time() 138 | if mpv_process: 139 | mpv_process.stdin.write(chunk) 140 | bytes_received += len(chunk) 141 | print( 142 | f"Received chunk of size {len(chunk):<5} bytes | Total bytes received: {bytes_received:<6}", 143 | end="\r", 144 | ) 145 | 146 | latency_data.set_total_time() 147 | if mpv_process: 148 | mpv_process.stdin.close() 149 | mpv_process.wait() 150 | 151 | 152 | def create_rest_body(text: str): 153 | return { 154 | "text": text, 155 | "voice": args.voice, 156 | "quality": args.quality, 157 | "output_format": args.format, 158 | "speed": args.speed, 159 | "sample_rate": 
DEFAULT_SAMPLE_RATE, 160 | "seed": DEFAULT_SEED, 161 | "temperature": DEFAULT_TEMPERATURE, 162 | "voice_engine": args.voice_engine, 163 | "emotion": args.emotion, 164 | "voice_guidance": args.voice_guidance, 165 | "style_guidance": args.style_guidance, 166 | } 167 | 168 | 169 | async def async_generate_rest(latency_data: LatencyData): 170 | url = "https://play.ht/api/v2/tts/stream" 171 | headers = { 172 | "AUTHORIZATION": f"Bearer {PLAYHT_API_KEY}", 173 | "X-USER-ID": PLAYHT_USER_ID, 174 | "accept": "audio/mpeg", 175 | "content-type": "application/json", 176 | } 177 | async with aiohttp.ClientSession() as session: 178 | if args.warmup: 179 | logging.info("Sending warmup request...") 180 | async with session.post( 181 | url, headers=headers, json=create_rest_body(WARMUP_TEXT) 182 | ) as response: 183 | pass 184 | logging.info("Sending REST request...") 185 | if latency_data: 186 | latency_data.start() 187 | async with session.post( 188 | url, headers=headers, json=create_rest_body(args.text) 189 | ) as response: 190 | latency_data.set_headers_received() 191 | logging.info(f"Latency: {latency_data.headers_received*1000:.2f} ms") 192 | logging.info(f"Status code: {response.status}") 193 | logging.info("-" * 40) 194 | logging.info(f'Text: "{args.text}"') 195 | logging.info("-" * 40) 196 | if response.ok and "audio/mpeg" in response.headers.get("Content-Type"): 197 | logging.info("Streaming audio...") 198 | logging.info("-" * 40) 199 | await stream_rest(response, latency_data) 200 | else: 201 | logging.error("No audio data in the response.") 202 | 203 | 204 | def generate_rest(latency_data: LatencyData): 205 | return asyncio.get_event_loop().run_until_complete( 206 | async_generate_rest(latency_data) 207 | ) 208 | 209 | 210 | def stream_grpc(gen, latency_data: LatencyData): 211 | if args.play: 212 | mpv_command = ["mpv", "--no-cache", "--no-terminal", "--", "fd://0"] 213 | mpv_process = subprocess.Popen( 214 | mpv_command, 215 | stdin=subprocess.PIPE, 216 | stdout=subprocess.DEVNULL, 217 | stderr=subprocess.DEVNULL, 218 | ) 219 | else: 220 | mpv_process = None 221 | 222 | bytes_received = 0 223 | for chunk in gen: 224 | latency_data.add_chunk_time() 225 | if mpv_process: 226 | mpv_process.stdin.write(chunk) 227 | bytes_received += len(chunk) 228 | print( 229 | f"Received chunk of size {len(chunk):<5} bytes | Total bytes received: {bytes_received:<6}", 230 | end="\r", 231 | ) 232 | 233 | latency_data.set_total_time() 234 | if mpv_process: 235 | mpv_process.stdin.close() 236 | mpv_process.wait() 237 | 238 | 239 | def generate_grpc(latency_data: LatencyData): 240 | advanced = client.Client.AdvancedOptions(grpc_addr="prod.turbo.play.ht:443") 241 | grpc_client = client.Client(PLAYHT_USER_ID, PLAYHT_API_KEY, advanced=advanced) 242 | if args.format == "mp3": 243 | format = api_pb2.FORMAT_MP3 244 | elif args.format == "wav": 245 | format = api_pb2.FORMAT_WAV 246 | else: 247 | logging.error("Invalid format") 248 | exit(1) 249 | options = client.TTSOptions(format=format, voice=args.voice, quality=args.quality) 250 | if args.warmup: 251 | logging.info("Sending warmup request...") 252 | list(grpc_client.tts(WARMUP_TEXT, options)) 253 | logging.info("Sending GRPC request...") 254 | latency_data.start() 255 | result = grpc_client.tts(args.text, options) 256 | header = next(result) 257 | latency_data.set_headers_received() 258 | logging.info(f"Latency: {latency_data.headers_received*1000:.2f} ms") 259 | logging.info("-" * 40) 260 | logging.info(f'Text: "{args.text}"') 261 | logging.info("-" * 40) 262 | 
logging.info("Streaming audio...") 263 | logging.info("-" * 40) 264 | stream_grpc(result, latency_data) 265 | grpc_client.close() 266 | 267 | 268 | def main(): 269 | latency_data = LatencyData() 270 | if args.transport == "rest": 271 | generate_rest(latency_data) 272 | elif args.transport == "grpc": 273 | generate_grpc(latency_data) 274 | else: 275 | logging.error("Invalid transport") 276 | exit(1) 277 | 278 | # Latency Summary 279 | logging.info("\n" + "=" * 40) 280 | logging.info("LATENCY SUMMARY") 281 | logging.info("-" * 40) 282 | logging.info( 283 | f"Time to receive headers: {latency_data.headers_received*1000:.2f} ms" 284 | ) 285 | if latency_data.chunk_times: 286 | logging.info(f"Time to first chunk: {latency_data.chunk_times[0]*1000:.2f} ms") 287 | logging.info(f"Total time: {latency_data.total_time*1000:.2f} ms") 288 | logging.info("=" * 40 + "\n") 289 | 290 | 291 | if __name__ == "__main__": 292 | main() 293 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "ai-benchmarks" 3 | version = "0.1.0" 4 | description = "AI API benchmarking suite" 5 | authors = ["juberti "] 6 | license = "MIT" 7 | readme = "README.md" 8 | 9 | [tool.poetry.dependencies] 10 | python = "^3.11" 11 | requests = "2.31.0" 12 | websockets = "12.0" 13 | aiohttp = "3.9.3" 14 | pyht = "0.0.27" 15 | gcloud-aio-storage = "^9.3.0" 16 | gunicorn = "21.2.0" 17 | uvicorn = "^0.29.0" 18 | fastapi = "^0.110.2" 19 | google-auth = "^2.29.0" 20 | dataclasses-json = "^0.6.5" 21 | soxr = "^0.5.0.post1" 22 | soundfile = "^0.12.1" 23 | 24 | 25 | [tool.poetry.group.dev.dependencies] 26 | black = "^24.4.0" 27 | isort = "^5.13.2" 28 | autoflake = "^2.3.1" 29 | mypy = "^1.9.0" 30 | 31 | [build-system] 32 | requires = ["poetry-core"] 33 | build-backend = "poetry.core.masonry.api" 34 | 35 | [tool.isort] 36 | profile = "black" 37 | single_line_exclusions = ["typing", "collections.abc", "typing_extensions"] 38 | --------------------------------------------------------------------------------