├── .gitignore ├── LICENSE ├── README.md ├── examples ├── batch_inference_wildchat.py └── mp_inference_wildchat.py ├── llm_engines ├── __init__.py ├── cache │ ├── __init__.py │ ├── cache_dict.py │ ├── cache_lru.py │ ├── cache_sqlite3.py │ └── cache_utils.py ├── claude.py ├── cli.py ├── deepseek.py ├── fireworks.py ├── gemini.py ├── grok.py ├── mistral.py ├── openai_text.py ├── sglang.py ├── together.py ├── utils.py └── vllm.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | 164 | /llm_engines/generation_cache 165 | /test* 166 | .vscode 167 | /llm_engines/*.json -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Dongfu Jiang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LLM-Engines 2 | 3 | [Author: Dongfu Jiang](https://jdf-prog.github.io/), [Twitter](https://x.com/DongfuJiang/status/1833730295696334925), [PyPI Package](https://pypi.org/project/llm-engines/) 4 | 5 | A unified inference engine for large language models (LLMs) including open-source models (VLLM, SGLang, Together) and commercial models (OpenAI, Mistral, Claude). 6 | 7 | The correctness of the inference has been verified by comparing the outputs of the models with different engines when `temperature=0.0` and `max_tokens=None`. 8 | For example, the outputs of a single model using 3 enginer (VLLM, SGLang, Together) will be the same when `temperature=0.0` and `max_tokens=None`. 9 | Try examples below to see the outputs of different engines. 
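For instance, here is a minimal sketch of such a cross-engine comparison (the model identifiers are illustrative and assume a local GPU for the vllm worker plus a `TOGETHER_API_KEY` for Together; see the engine-specific examples in the Usage section below):

```python
from llm_engines import LLMEngine

# Illustrative model identifiers; the exact name each provider exposes may differ.
engine_models = {
    "vllm": "meta-llama/Meta-Llama-3-8B-Instruct",
    "together": "meta-llama/Llama-3-8b-chat-hf",
}
prompt = "What is the capital of France?"

outputs = {}
for engine, model_name in engine_models.items():
    llm = LLMEngine()
    llm.load_model(
        model_name=model_name,
        engine=engine,
        num_workers=1,
        num_gpu_per_worker=1,  # only needed for the local vllm worker
        use_cache=False,
    )
    outputs[engine] = llm.call_model(model_name, prompt, temperature=0.0, max_tokens=None)
    llm.unload_model(model_name)  # free resources before loading the next engine

# with greedy decoding, the two outputs are expected to match
print(outputs["vllm"] == outputs["together"])
```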
10 | 11 | ## News 12 | - 2025-03-03: support `sleep` for vllm models, see [Sleep Mode](#sleep-mode) for more details. 13 | - 2025-02-23: Support for vision input for all engines. See [Vision Input](#vision-input) for more details. 14 | - 2025-02-19: Add support for `fireworks` api services, which provide calling for deepseek-r1 models with high speed. 15 | - 2025-02-18: Add support for `grok` models. 16 | 17 | ## Installation 18 | We recommend to use `uv` to manage the environment due to its fast installation speed. 19 | ```bash 20 | pip install llm-engines # or 21 | # pip install git+https://github.com/jdf-prog/LLM-Engines.git 22 | pip install flash-attn --no-build-isolation 23 | ``` 24 | If you want to use SGLang, you need to install it separately: 25 | ```bash 26 | pip install "sglang[all]>=0.4.6.post5" 27 | ``` 28 | 29 | For development: 30 | ```bash 31 | pip install -e . # for development 32 | ``` 33 | 34 | ## Usage 35 | 36 | ### Engines 37 | - use vllm or sglang 38 | ```python 39 | from llm_engines import LLMEngine 40 | model_name="Qwen/Qwen2.5-0.5B-Instruct" 41 | llm = LLMEngine() 42 | llm.load_model( 43 | model_name=model_name, 44 | num_workers=1, # number of workers 45 | num_gpu_per_worker=1, # tensor parallelism size for each worker 46 | engine="vllm", # or "sglang" 47 | use_cache=False 48 | ) 49 | response = llm.call_model(model_name, "What is the capital of France?", temperature=0.0, max_tokens=None) 50 | print(response) 51 | ``` 52 | 53 | - use together 54 | ```python 55 | # export TOGETHER_API_KEY="your_together_api_key" 56 | from llm_engines import LLMEngine 57 | model_name="meta-llama/Llama-3-8b-chat-hf" 58 | llm = LLMEngine() 59 | llm.load_model( 60 | model_name=model_name, 61 | engine="together", # or "openai", "mistral", "claude" 62 | use_cache=False 63 | ) 64 | response = llm.call_model(model_name, "What is the capital of France?", temperature=0.0, max_tokens=None) 65 | print(response) 66 | ``` 67 | 68 | - openai models 69 | ```python 70 | # export OPENAI_API_KEY="your_openai_api_key" 71 | from llm_engines import LLMEngine 72 | model_name="gpt-3.5-turbo" 73 | llm = LLMEngine() 74 | llm.load_model( 75 | model_name=model_name, 76 | engine="openai", # or "vllm", "together", "mistral", "claude" 77 | use_cache=False 78 | ) 79 | response = llm.call_model(model_name, "What is the capital of France?", temperature=0.0, max_tokens=None) 80 | print(response) 81 | ``` 82 | 83 | - grok models 84 | ```python 85 | # export XAI_API_KEY="your_xai_api_key" 86 | from llm_engines import LLMEngine 87 | model_name="grok-2-latest" 88 | llm = LLMEngine() 89 | llm.load_model( 90 | model_name=model_name, 91 | engine="grok", # or "vllm", "together", "mistral", "claude" 92 | use_cache=False 93 | ) 94 | response = llm.call_model(model_name, "What is the capital of France?", temperature=0.0, max_tokens=None) 95 | print(response) 96 | ``` 97 | 98 | - mistral models 99 | ```python 100 | # export MISTRAL_API_KEY="your_mistral_api_key" 101 | from llm_engines import LLMEngine 102 | model_name="mistral-large-latest" 103 | llm = LLMEngine() 104 | llm.load_model( 105 | model_name=model_name, 106 | engine="mistral", # or "vllm", "together", "openai", "claude" 107 | use_cache=False 108 | ) 109 | response = llm.call_model(model_name, "What is the capital of France?", temperature=0.0, max_tokens=None) 110 | print(response) 111 | ``` 112 | 113 | - claude models 114 | ```python 115 | # export ANTHROPIC_API_KEY="your_claude_api_key" 116 | from llm_engines import LLMEngine 117 | model_name="claude-3-opus-20240229" 
118 | llm = LLMEngine() 119 | llm.load_model( 120 | model_name=model_name, 121 | engine="claude", # or "vllm", "together", "openai", "mistral" 122 | use_cache=False 123 | ) 124 | response = llm.call_model(model_name, "What is the capital of France?", temperature=0.0, max_tokens=None) 125 | print(response) 126 | ``` 127 | 128 | - gemini models 129 | ```python 130 | # export GEMINI_API_KEY="your_gemini_api_key" 131 | from llm_engines import LLMEngine 132 | model_name="gemini-1.5-flash" 133 | llm = LLMEngine() 134 | llm.load_model( 135 | model_name=model_name, 136 | engine="gemini", # or "vllm", "together", "openai", "mistral", "claude" 137 | use_cache=False 138 | ) 139 | response = llm.call_model(model_name, "What is the capital of France?", temperature=0.0, max_tokens=None) 140 | print(response) 141 | ``` 142 | 143 | - fireworks api 144 | ```python 146 | # export FIREWORKS_API_KEY="your_fireworks_api_key" 147 | from llm_engines import LLMEngine 148 | model_name="accounts/fireworks/models/deepseek-r1" 149 | llm = LLMEngine() 150 | llm.load_model( 151 | model_name=model_name, 152 | engine="fireworks", # or "vllm", "together", "openai", "mistral", "claude" 153 | use_cache=False 154 | ) 155 | response = llm.call_model(model_name, "What is the capital of France?", temperature=0.0, max_tokens=None) 156 | print(response) 157 | ``` 158 | 159 | ### unload model 160 | Remember to unload the model after using it to free up resources. By default, all the workers will be unloaded after the program exits. If you want to use different models in the same program, unload the current model before loading a new one if the new model needs GPU resources. 161 | ```python 162 | llm.unload_model(model_name) # unload all the workers named model_name 163 | llm.unload_model() # unload all the workers 164 | ``` 165 | 166 | ### Multi-turn conversation 167 | ```python 168 | from llm_engines import LLMEngine 169 | model_name="Qwen/Qwen2.5-0.5B-Instruct" 170 | llm = LLMEngine() 171 | llm.load_model( 172 | model_name="Qwen/Qwen2.5-0.5B-Instruct", 173 | num_workers=1, # number of workers 174 | num_gpu_per_worker=1, # tensor parallelism size for each worker 175 | engine="vllm", # or "sglang" 176 | use_cache=False 177 | ) 178 | messages = [ 179 | "Hello", # user message 180 | "Hello! It's nice to meet you. Is there something I can help you with, or would you like to chat?", # previous model response 181 | "What is the capital of France?" # user message 182 | ] 183 | # or you can use openai's multi-turn conversation format. 184 | messages = [ 185 | {"role": "user", "content": "Hello"}, # user message 186 | {"role": "assistant", "content": "Hello! It's nice to meet you. Is there something I can help you with, or would you like to chat?"}, # previous model response 187 | {"role": "user", "content": "What is the capital of France?"} # user message 188 | ] 189 | response = llm.call_model(model_name, messages, temperature=0.0, max_tokens=None) 190 | print(response) 191 | ``` 192 | The messages should be in one of the following formats: 193 | - `[user_message, model_response, user_message, model_response, ...]` 194 | - openai's multi-turn conversation format (a list of `{"role": ..., "content": ...}` dicts).
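A system message can also be supplied without adding it to `messages` by passing `conv_system_msg` (see the generation parameters below). A minimal sketch, reusing the `llm`, `model_name`, and `messages` from the example above:

```python
response = llm.call_model(
    model_name,
    messages,
    conv_system_msg="You are a helpful assistant.",  # prepended as the system message
    temperature=0.0,
    max_tokens=None,
)
print(response)
```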
195 | 196 | ### Vision Input 197 | ```python 198 | from llm_engines import LLMEngine 199 | from PIL import Image 200 | import requests 201 | from io import BytesIO 202 | response = requests.get("https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg") 203 | image = Image.open(BytesIO(response.content)).resize((256, 256)) 204 | image.save("./test.jpg") 205 | messages_with_image = [ 206 | { 207 | "role": "user", 208 | "content": [ 209 | { 210 | "type": "text", 211 | "text": "What's in the image?" 212 | }, 213 | { 214 | "type": "image", 215 | "image": image 216 | } 217 | ] 218 | } 219 | ] # the 'image' type is not offical format of openai API, LLM-Engines will convert it into image_url type internally 220 | messages_with_image_url = [ 221 | { 222 | "role": "user", 223 | "content": [ 224 | { 225 | "type": "text", 226 | "text": "What's in the image?" 227 | }, 228 | { 229 | "type": "image_url", 230 | "image_url": {"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"} 231 | } 232 | ] 233 | } 234 | ] # the 'image_url' type is the offical format of openai API 235 | additional_args=[] 236 | # engine="openai"; model_name="gpt-4o-mini" 237 | # engine="claude"; model_name="claude-3-5-sonnet-20241022" 238 | # engine="gemini"; model_name="gemini-2.0-flash" 239 | # engine="grok"; model_name="grok-2-vision-latest" 240 | # engine="sglang"; model_name="meta-llama/Llama-3.2-11B-Vision-Instruct"; additional_args=["--chat-template=llama_3_vision"] # refer to 241 | engine="vllm"; model_name="microsoft/Phi-3.5-vision-instruct"; additional_args=["--limit-mm-per-prompt", "image=2", "--max-model-len", "4096"] # refer to vllm serve api 242 | llm = LLMEngine() 243 | llm.load_model( 244 | model_name=model_name, 245 | engine=engine, # or "vllm", "together", "mist 246 | use_cache=False, 247 | additional_args=additional_args, 248 | ) 249 | response = llm.call_model(model_name, messages_with_image, temperature=0.0, max_tokens=None) 250 | print(response) 251 | response = llm.call_model(model_name, messages_with_image_url, temperature=0.0, max_tokens=None) 252 | print(response) 253 | ``` 254 | 255 | ### Sleep Mode 256 | We support vllm's sleep mode if you want to save the GPU resources when the model is not used. 
(requires `vllm>=0.7.3`) 257 | ```python 258 | import time 259 | from llm_engines import LLMEngine 260 | model_name="Qwen/Qwen2.5-0.5B-Instruct" 261 | llm = LLMEngine() 262 | llm.load_model( 263 | model_name=model_name, 264 | num_workers=1, # number of workers 265 | num_gpu_per_worker=1, # tensor parallelism size for each worker 266 | engine="vllm", # or "sglang" 267 | use_cache=False, 268 | additional_args=["--enable-sleep-mode"] # enable sleep mode 269 | ) 270 | response = llm.call_model(model_name, "What is the capital of France?", temperature=0.0, max_tokens=None) 271 | print(response) 272 | llm.sleep_model(model_name) # sleep all the instances named model_name 273 | time.sleep(20) # check your GPU usage, it should be almost 0 274 | llm.wake_up_model(model_name) # wake up all the instances named model_name 275 | response = llm.call_model(model_name, "What is the capital of France?", temperature=0.0, max_tokens=None) 276 | ``` 277 | 278 | ### Batch inference 279 | ```python 280 | from llm_engines import LLMEngine 281 | model_name="Qwen/Qwen2.5-0.5B-Instruct" 282 | llm = LLMEngine() 283 | llm.load_model( 284 | model_name="Qwen/Qwen2.5-0.5B-Instruct", 285 | num_workers=1, # number of workers 286 | num_gpu_per_worker=1, # tensor parallelism size for each worker 287 | engine="vllm", # or "sglang" 288 | use_cache=False 289 | ) 290 | batch_messages = [ 291 | "Hello", 292 | "Hello! It's nice to meet you. Is there something I can help you with, or would you like to chat?", 293 | "What is the capital of France?" 294 | ] * 100 # each string is treated as an independent single-turn query 295 | response = llm.batch_call_model(model_name, batch_messages, num_proc=32, temperature=0.0, max_tokens=None) 296 | print(response) 297 | # List of responses [response1, response2, ...] 298 | ``` 299 | Example inference file: [`./examples/batch_inference_wildchat.py`](./examples/batch_inference_wildchat.py) 300 | ```bash 301 | python examples/batch_inference_wildchat.py 302 | ``` 303 | 304 | **OpenAI Batch API** 305 | With the code above, the batch API is used automatically for openai models. If you don't want the batch API and prefer the normal API, set `disable_batch_api=True` when loading the model. `num_proc` is ignored when the batch API is used. 306 | 307 | OpenAI's batch API costs half the price of the normal API. The batch API is only available for models with `max_batch_size > 1`. 308 | 309 | LLM-Engines calculates a hash of the inputs and generation parameters and only sends a new batch request if they differ from previous requests. You can check a list of requested batch information in the [`~/llm_engines/generation_cache/openai_batch_cache/batch_submission_status.json`](~/llm_engines/generation_cache/openai_batch_cache/batch_submission_status.json) file. 310 | 311 | ### Parallel inference through huggingface dataset map 312 | Check out [`./examples/mp_inference_wildchat.py`](./examples/mp_inference_wildchat.py) for parallel inference with multiple models. 313 | ```bash 314 | python examples/mp_inference_wildchat.py 315 | ``` 316 | 317 | ### Cache 318 | 319 | If `use_cache=True`, all queries and responses are cached in the `generation_cache` folder, and no duplicate queries will be sent to the model.
320 | The cache of each model is saved to `generation_cache/{model_name}.jsonl` 321 | 322 | Example items in the cache: 323 | ```json 324 | {"cb0b4aaf80c43c9973aefeda1bd72890": {"input": ["What is the capital of France?"], "output": "The capital of France is Paris."}} 325 | ``` 326 | The hash key here is the hash of the concatenated inputs. 327 | 328 | ### Chat template 329 | For each open-source model, we use the default chat template as follows: 330 | ```python 331 | prompt = self.tokenizer.apply_chat_template( 332 | messages, 333 | add_generation_prompt=add_generation_prompt, 334 | tokenize=False, 335 | chat_template=chat_template, 336 | ) 337 | ``` 338 | An error will be raised if the model does not support the chat template. 339 | 340 | ### Worker initialization parameters (`load_model`) 341 | - `model_name`: the model name, e.g., "Qwen/Qwen2.5-0.5B-Instruct" (required) 342 | - `worker_addrs`: the list of worker addresses to use; if not provided, a new worker will be launched, otherwise the existing workers will be used (default: None) 343 | - `num_workers`: the number of workers to use for the model (default: 1) 344 | - `num_gpu_per_worker`: the number of GPUs to use for each worker (default: None) 345 | - `engine`: the engine to use, one of {vllm, sglang, together, openai, mistral, claude, gemini, grok, fireworks} (default: "vllm") 346 | - `additional_args`: list of str, additional arguments for launching the (vllm, sglang) worker, e.g., `["--max-model-len", "65536"]` (default: []) 347 | - `use_cache`: whether to use the cache for the queries and responses (default: True) 348 | - `cache_dir`: the cache directory; can also be set via the env variable `LLM_ENGINES_CACHE_DIR` (default: `~/llm_engines/generation_cache`) 349 | - `overwrite_cache`: whether to overwrite the cache (default: False) 350 | - `dtype`: the data type to use (default: "auto"; {auto,half,float16,bfloat16,float,float32}) 351 | - `quantization`: specify the quantization type, one of {aqlm,awq,deepspeedfp,tpu_int8,fp8,fbgemm_fp8,marlin,gguf,gptq_marlin_24,gptq_marlin,awq_marlin,gptq,squeezellm,compressed-tensors,bitsandbytes,qqq,experts_int8} (default: None) 352 | - `max_retry`: the maximum number of retries for the request (default: None) 353 | - `completion`: whether to use the completion API instead of the chat API; if you use completion, the inputs must be a single string (default: False) 354 | 355 | 356 | ### Generation parameters (`call_model`, `batch_call_model`) 357 | - `inputs`: the inputs for the model; either a list of strings or a list of dictionaries for multi-turn conversation in openai conversation format; if `completion` is True, it should be a single string (required) 358 | - `top_p`: the nucleus sampling parameter; 1.0 means no nucleus filtering (default: 1.0) 359 | - `temperature`: the randomness of the generation, 0.0 means deterministic generation (default: 0.0) 360 | - `max_tokens`: the maximum number of tokens to generate, `None` means no limit (default: None) 361 | - `timeout`: the maximum time to wait for the response, `None` means no limit (default: 300) 362 | - `frequency_penalty`: Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. (default: 0.0) 363 | - `presence_penalty`: Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. (default: 0.0) 364 | - `n`: Number of completions to generate for each prompt.
(**only vllm, sglang, openai have this feature**) (default: 1) 365 | - `stream`: Whether to stream the response or not. If True, `n` will be ignored. (default: False) 366 | - `conv_system_msg`: The system message for multi-turn conversation; if the messages already contain a system message, this parameter will be overwritten (default: None) 367 | - `logprobs`: Whether to return the log probabilities of the generated tokens, True/False/None (default: None) 368 | - all the other parameters that are supported by different engines. 369 | - for openai and sglang, check out [openai](https://platform.openai.com/docs/api-reference/chat) 370 | - for extra parameters of vllm, check out [vllm](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters) 371 | 372 | ### Launch a separate vllm/sglang model worker 373 | 374 | - launch a separate vllm worker 375 | 376 | ```bash 377 | CUDA_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-8B-Instruct --dtype auto --host "127.0.0.1" --port 34200 --tensor-parallel-size 1 --disable-log-requests & 378 | # address: http://127.0.0.1:34200 379 | ``` 380 | 381 | - launch a separate sglang worker 382 | ```bash 383 | CUDA_VISIBLE_DEVICES=1 python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --dtype auto --host "127.0.0.1" --port 34201 --tp-size 1 & 384 | CUDA_VISIBLE_DEVICES=1 python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --dtype auto --host "127.0.0.1" --port 34201 --tp-size 1 --disable-flashinfer & # disable flashinfer if it's not installed 385 | # address: http://127.0.0.1:34201 386 | ``` 387 | 388 | - query multiple existing workers 389 | ```python 390 | from llm_engines import ModelWorker 391 | call_worker_func = ModelWorker( 392 | model_name="Qwen/Qwen2.5-0.5B-Instruct", 393 | worker_addrs=["http://127.0.0.1:34200", "http://127.0.0.1:34201"], # many workers can be used, will be load balanced 394 | engine="sglang", 395 | use_cache=False 396 | ) 397 | response = call_worker_func(["What is the capital of France?"], temperature=0.0, max_tokens=None) 398 | print(response) 399 | # The capital of France is Paris. 400 | ``` 401 | 402 | ### Test notes 403 | 404 | When setting `temperature=0.0` and `max_tokens=None` and testing long generations: 405 | - VLLM (fp16) generates the same outputs as Hugging Face Transformers (fp16), but not for bf16. 406 | - Together AI generates almost the same outputs as vllm (fp16, bf16). 407 | - SGLang's outputs are sometimes not consistent with the others. 408 | - Note that some unusual inputs can cause a model to generate indefinitely; it's better to set `timeout` (default: 300) to drop the request after a certain number of seconds (see the sketch after this list). 409 | - Bug: [issue](https://github.com/vllm-project/vllm/issues/7196) of `vllm==0.5.4` when num_workers > 1, use `vllm==0.5.5` instead. 410 | - Try not to load the same openai model with different cache directories: the current code only loads the cache from the first provided cache directory, but writes go to each cache directory correspondingly, which might cause confusion.
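Putting the `timeout` and `max_retry` notes above into practice, a minimal sketch (the parameter values are illustrative):

```python
from llm_engines import LLMEngine

model_name = "Qwen/Qwen2.5-0.5B-Instruct"
llm = LLMEngine()
llm.load_model(
    model_name=model_name,
    engine="vllm",
    num_workers=1,
    num_gpu_per_worker=1,
    use_cache=False,
    max_retry=1,  # retry a failed request at most once
)
response = llm.call_model(
    model_name,
    "What is the capital of France?",
    timeout=120,  # drop the request after 120 seconds
    temperature=0.0,
    max_tokens=None,
)
print(response)  # None if the request still fails after the retries
```

`call_model` returns `None` once the retries are exhausted, so downstream code should handle missing responses.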
411 | 412 | ## Star History 413 | 414 | [![Star History Chart](https://api.star-history.com/svg?repos=jdf-prog/LLM-Engines&type=Date)](https://star-history.com/#jdf-prog/LLM-Engines&Date) 415 | 416 | ## Citation 417 | ```bibtex 418 | @misc{jiang2024llmengines, 419 | title = {LLM-Engines: A unified and parallel inference engine for large language models}, 420 | author = {Dongfu Jiang}, 421 | year = {2024}, 422 | publisher = {GitHub}, 423 | journal = {GitHub repository}, 424 | howpublished = {\url{https://github.com/jdf-prog/LLM-Engines}}, 425 | } 426 | ``` 427 | -------------------------------------------------------------------------------- /examples/batch_inference_wildchat.py: -------------------------------------------------------------------------------- 1 | import fire 2 | import json 3 | import os 4 | import datasets 5 | from llm_engines import LLMEngine 6 | from llm_engines.utils import MaxRetriesExceededError 7 | 8 | def main( 9 | dataset: str="allenai/WildChat", 10 | model_name: str="meta-llama/Meta-Llama-3-8B-Instruct", 11 | engine: str="vllm", 12 | worker_addrs: str=None, 13 | num_workers: int=2, 14 | num_gpu_per_worker: int=1, 15 | overwrite=False, 16 | max_size=100, 17 | ): 18 | # the dataset argument is a huggingface dataset name 19 | dataset = datasets.load_dataset(dataset, split='train') 20 | if max_size and max_size < len(dataset): 21 | dataset = dataset.select(range(max_size)) 22 | print(f"Dataset truncated to {max_size} examples") 23 | 24 | def format_func(item): 25 | return {"query": item["conversation"][0]['content']} 26 | 27 | dataset = dataset.map(format_func, remove_columns=dataset.column_names) 28 | 29 | output_file="./wildchat.inference.jsonl" 30 | 31 | if os.path.exists(output_file) and not overwrite: 32 | print(f"Output file {output_file} exists and overwrite is set to False.
Skipping.") 33 | exit(0) 34 | else: 35 | llm = LLMEngine() 36 | llm.load_model( 37 | model_name=model_name, 38 | engine=engine, 39 | worker_addrs=worker_addrs, 40 | num_workers=num_workers, 41 | num_gpu_per_worker=num_gpu_per_worker, 42 | use_cache=True, 43 | max_retry=1 44 | ) 45 | 46 | generation_kwargs = { 47 | "temperature": 0.0, 48 | "max_tokens": 4096, 49 | } 50 | batch_messages = [item['query'] for item in dataset] 51 | responses = llm.batch_call_model(model_name, batch_messages, **generation_kwargs, num_proc=num_workers * 16) 52 | dataset = dataset.add_column("response", responses) 53 | 54 | def filter_none(item): 55 | return item['response'] is not None 56 | print(f"Before filtering None responses: {len(dataset)}") 57 | dataset = dataset.filter(filter_none) 58 | print(f"After filtering None responses: {len(dataset)}") 59 | 60 | dataset.to_json(output_file, orient="records", lines=True) 61 | 62 | 63 | if __name__ == "__main__": 64 | fire.Fire(main) 65 | 66 | 67 | """ 68 | python mp_inference_wildchat.py 69 | """ -------------------------------------------------------------------------------- /examples/mp_inference_wildchat.py: -------------------------------------------------------------------------------- 1 | import fire 2 | import json 3 | import os 4 | import datasets 5 | from llm_engines import ModelWorker 6 | from llm_engines.utils import MaxRetriesExceededError 7 | 8 | def main( 9 | dataset: str="allenai/WildChat", 10 | model_name: str="meta-llama/Meta-Llama-3-8B-Instruct", 11 | engine: str="vllm", 12 | worker_addrs: str=None, 13 | num_workers: int=2, 14 | num_gpu_per_worker: int=1, 15 | overwrite=False, 16 | max_size=1000, 17 | ): 18 | # input_file is a hugingface dataset 19 | dataset = datasets.load_dataset(dataset, split='train') 20 | if max_size and max_size < len(dataset): 21 | dataset = dataset.select(range(max_size)) 22 | print(f"Dataset truncated to {max_size} examples") 23 | 24 | def format_func(item): 25 | return {"query": item["conversation"][0]['content']} 26 | 27 | dataset = dataset.map(format_func, remove_columns=dataset.column_names) 28 | 29 | output_file="./wildchat.inference.jsonl" 30 | 31 | if os.path.exists(output_file) and not overwrite: 32 | print(f"Output file {output_file} exists and overwrite is set to False. 
Skipping.") 33 | exit(0) 34 | else: 35 | call_worker = ModelWorker( 36 | model_name=model_name, 37 | engine=engine, 38 | worker_addrs=worker_addrs, 39 | num_workers=num_workers, 40 | num_gpu_per_worker=num_gpu_per_worker, 41 | use_cache=True, 42 | max_retry=1 43 | ) 44 | 45 | generation_kwargs = { 46 | "temperature": 0.0, 47 | "max_tokens": 4096, 48 | } 49 | def map_generate(item): 50 | try: 51 | response = call_worker([item['query']], **generation_kwargs) 52 | except MaxRetriesExceededError: 53 | response = None 54 | item['response'] = response 55 | return item 56 | 57 | 58 | new_dataset = dataset.map(map_generate, num_proc=num_workers * 16) 59 | 60 | def filter_none(item): 61 | return item['response'] is not None 62 | print(f"Before filtering None responses: {len(new_dataset)}") 63 | new_dataset = new_dataset.filter(filter_none) 64 | print(f"After filtering None responses: {len(new_dataset)}") 65 | 66 | new_dataset.to_json(output_file, orient="records", lines=True) 67 | 68 | 69 | if __name__ == "__main__": 70 | fire.Fire(main) 71 | 72 | 73 | """ 74 | python mp_inference_wildchat.py 75 | """ -------------------------------------------------------------------------------- /llm_engines/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import random 5 | import atexit 6 | import psutil 7 | import hashlib 8 | import requests 9 | import subprocess 10 | from packaging import version 11 | from functools import partial 12 | from .utils import retry_on_failure, convert_messages_wrapper, SubprocessMonitor, MaxRetriesExceededError, max_retry_wrapper 13 | from .cache import get_batch_cache_dir, generation_cache_wrapper 14 | from typing import Union, List 15 | from tqdm import tqdm 16 | 17 | import importlib.util 18 | flash_attn = importlib.util.find_spec("flash_attn") 19 | if not flash_attn: 20 | print("Warning: flash_attn not found, recommend to install flash_attn for better performance") 21 | print("Simple Command: pip install flash_attn --no-build-isolation") 22 | print("Please refer to https://github.com/Dao-AILab/flash-attention for detailed installation instructions") 23 | 24 | ENGINES = ["vllm", "sglang", "openai", "gemini", "mistral", "together", "claude"] 25 | all_workers = [] 26 | verbose = False 27 | def set_verbose(value): 28 | global verbose 29 | verbose = value 30 | 31 | class WorkerInstance: 32 | def __init__(self, model_name, worker_addr, proc, gpu_ids=None): 33 | self.model_name = model_name 34 | self.worker_addr = worker_addr 35 | self.proc = proc 36 | self.gpu_ids = gpu_ids 37 | 38 | def __str__(self): 39 | return f"WorkerInstance(model_name={self.model_name}, worker_addr={self.worker_addr}, proc={self.proc}, gpu_ids={self.gpu_ids})" 40 | 41 | def __repr__(self): 42 | return self.__str__() 43 | 44 | 45 | class ModelWorker: 46 | def __init__( 47 | self, 48 | model_name, 49 | worker_addrs=None, 50 | cache_dir=None, 51 | use_cache=True, 52 | overwrite_cache=False, 53 | completion=False, 54 | num_workers=1, 55 | num_gpu_per_worker=None, 56 | gpu_ids=None, 57 | dtype="auto", 58 | quantization=None, 59 | engine="vllm", 60 | additional_args=[], 61 | max_retry=None, 62 | verbose=False, 63 | ): 64 | """ 65 | Return a function that calls the model worker, takes a list of messages (user, gpt, user, ...) and returns the generated text 66 | Args: 67 | model_name: model name 68 | worker_addrs: worker addresses, if None, launch local workers 69 | cache_dir: cache directory 70 | use_cache: use cache or not. 
Cache is on the hash of the input message. 71 | overwrite_cache: overwrite cache or not. If True, previous cache will be overwritten. 72 | completion: use completion or not (use chat by default) 73 | num_workers: number of workers 74 | num_gpu_per_worker: number of gpus per worker 75 | dtype: data type 76 | engine: engine name 77 | additional_args: additional arguments for launching the worker (vllm, sglang) 78 | """ 79 | self.model_name = model_name 80 | self.worker_addrs = worker_addrs 81 | self.cache_dir = cache_dir 82 | self.use_cache = use_cache 83 | self.overwrite_cache = overwrite_cache 84 | self.completion = completion 85 | self.num_workers = num_workers 86 | self.num_gpu_per_worker = num_gpu_per_worker 87 | self.gpu_ids = gpu_ids 88 | self.dtype = dtype 89 | self.quantization = quantization 90 | self.engine = engine 91 | self.additional_args = additional_args 92 | self.max_retry = max_retry 93 | self.verbose = verbose 94 | self.worker_instances = [] # cuda workers instances 95 | self.is_sleeping = False 96 | 97 | if engine == "openai": 98 | from .openai_text import call_worker_openai, call_worker_openai_completion 99 | call_model_worker = call_worker_openai if not completion else call_worker_openai_completion 100 | elif engine == "gemini": 101 | if completion: 102 | raise ValueError(f"Engine {engine} does not support completion") 103 | from .gemini import call_worker_gemini 104 | call_model_worker = call_worker_gemini 105 | elif engine == "claude": 106 | if completion: 107 | raise ValueError(f"Engine {engine} does not support completion") 108 | from .claude import call_worker_claude 109 | call_model_worker = call_worker_claude 110 | elif engine == "mistral": 111 | if completion: 112 | raise ValueError(f"Engine {engine} does not support completion") 113 | from .mistral import call_worker_mistral 114 | call_model_worker = call_worker_mistral 115 | elif engine == "together": 116 | from .together import call_worker_together, call_worker_together_completion 117 | call_model_worker = call_worker_together if not completion else call_worker_together_completion 118 | elif engine == "grok": 119 | from .grok import call_worker_grok, call_worker_grok_completion 120 | call_model_worker = call_worker_grok if not completion else call_worker_grok_completion 121 | elif engine == "fireworks": 122 | from .fireworks import call_worker_fireworks, call_worker_fireworks_completion 123 | call_model_worker = call_worker_fireworks if not completion else call_worker_fireworks_completion 124 | elif engine in ["vllm", "sglang"]: 125 | assert num_gpu_per_worker is not None, "num_gpu_per_worker must be provided for vllm and sglang" 126 | if engine == "vllm": 127 | from .vllm import launch_vllm_worker, call_vllm_worker, call_vllm_worker_completion 128 | call_worker_func = call_vllm_worker if not completion else call_vllm_worker_completion 129 | launch_worker_func = launch_vllm_worker 130 | elif engine == "sglang": 131 | from .sglang import launch_sglang_worker, call_sglang_worker, call_sglang_worker_completion 132 | call_worker_func = call_sglang_worker if not completion else call_sglang_worker_completion 133 | launch_worker_func = launch_sglang_worker 134 | else: 135 | raise ValueError(f"Internal error: engine {engine} not supported") 136 | if worker_addrs is None: 137 | import torch 138 | 139 | print(f"Launching model worker {model_name} locally") 140 | worker_addrs = [] 141 | total_gpus = torch.cuda.device_count() 142 | if gpu_ids: 143 | gpu_ids = [int(gpu_id) for gpu_id in gpu_ids] 144 | assert len(gpu_ids) 
<= total_gpus, f"Error: number of gpus {len(gpu_ids)} is greater than total gpus {total_gpus}" 145 | total_gpus = len(gpu_ids) 146 | if total_gpus < num_workers * num_gpu_per_worker: 147 | if total_gpus >= num_gpu_per_worker: 148 | print(f"Warning: total gpus ({total_gpus}) is less than num_workers * num_gpu_per_worker ({num_workers * num_gpu_per_worker}), using {total_gpus // num_gpu_per_worker} workers instead") 149 | num_workers = total_gpus // num_gpu_per_worker 150 | else: 151 | print(f"Error: total gpus ({total_gpus}) is less than num_gpu_per_worker ({num_gpu_per_worker}), exiting...") 152 | sys.exit(1) 153 | if not gpu_ids: 154 | if os.environ.get("CUDA_VISIBLE_DEVICES") is not None: 155 | gpus_ids = os.environ.get("CUDA_VISIBLE_DEVICES").split(",") 156 | gpu_ids = [int(gpu_id) for gpu_id in gpus_ids] 157 | else: 158 | gpu_ids = list(range(total_gpus)) 159 | start_port = random.randint(31000, 32000) 160 | for i in range(num_workers): 161 | worker_addr, proc = launch_worker_func(model_name, 162 | num_gpus=num_gpu_per_worker, 163 | gpu_ids=gpu_ids[i*num_gpu_per_worker:(i+1)*num_gpu_per_worker], 164 | port=start_port+i*10, 165 | dtype=dtype, quantization=quantization, additional_args=additional_args) 166 | worker = WorkerInstance(model_name, worker_addr, proc, gpu_ids=gpu_ids[i*num_gpu_per_worker:(i+1)*num_gpu_per_worker]) 167 | worker_addrs.append(worker_addr) 168 | all_workers.append(worker) 169 | self.worker_instances.append(worker) 170 | else: 171 | if verbose: 172 | print(f"Using existing worker at {worker_addrs}") 173 | if not isinstance(worker_addrs, list): 174 | worker_addrs = [worker_addr] 175 | call_model_worker = partial(call_worker_func, worker_addrs=worker_addrs) 176 | self.worker_addrs = worker_addrs 177 | self.gpu_ids = gpu_ids 178 | else: 179 | raise ValueError(f"Engine {engine} not supported, available engines: {ENGINES}") 180 | 181 | # wrap the call_model_worker with the model_name and other arguments 182 | call_model_worker = partial(call_model_worker, model_name=model_name) 183 | # test local worker connection 184 | if not completion: 185 | test_response = call_model_worker([{"role": "user", "content": "Hello"}], temperature=0, max_tokens=256, timeout=None) 186 | else: 187 | test_response = call_model_worker("Hello", temperature=0, max_tokens=256, timeout=None) 188 | if not test_response: 189 | print("Error: failed to connect to the worker, exiting...") 190 | for worker in self.worker_instances: 191 | cleanup_process(worker) 192 | sys.exit(1) 193 | else: 194 | if verbose: 195 | print(f"Successfully connected to the workers") 196 | print("Test prompt: \n", "Hello") 197 | print("Test response: \n", test_response) 198 | 199 | # add cache wrapper 200 | if use_cache: 201 | call_model_worker = generation_cache_wrapper(call_model_worker, model_name, cache_dir, overwrite_cache) 202 | else: 203 | if verbose: 204 | print("Cache is disabled") 205 | call_model_worker = retry_on_failure(call_model_worker, num_retries=max_retry) 206 | call_model_worker = convert_messages_wrapper(call_model_worker, is_completion=completion) 207 | set_do_cleanup(True) 208 | set_verbose(verbose) 209 | self.call_model_worker = call_model_worker 210 | 211 | def __call__( 212 | self, 213 | messages:Union[str, List[str], List[dict]], 214 | *args, **kwds 215 | ): 216 | if self.is_sleeping: 217 | print("Warning: Worker is sleeping, waking up...") 218 | self.wake_up_worker() 219 | return self.call_model_worker(messages, *args, **kwds) 220 | 221 | def sleep_worker(self, level=1): 222 | if self.engine in 
["vllm"]: 223 | # vllm version should be >= 0.7.3 224 | from .vllm import vllm_version 225 | if version.parse(vllm_version) < version.parse("0.7.3"): 226 | raise ValueError(f"vllm version {vllm_version} does not support sleep mode, please upgrade to >= 0.7.3") 227 | for worker in self.worker_instances: 228 | response = requests.post(worker.worker_addr + "/sleep", 229 | data={"level": level}) 230 | print(worker.worker_addr + "/sleep") 231 | print(response) 232 | assert response.status_code == 200 233 | print(f"Worker {worker} is sleeping") 234 | self.is_sleeping = True 235 | elif self.engine in ["sglang"]: 236 | raise NotImplementedError(f"Engine {self.engine} does not support sleep") 237 | 238 | def wake_up_worker(self): 239 | if not self.is_sleeping: 240 | return 241 | if self.engine in ["vllm"]: 242 | # vllm version should be >= 0.7.3 243 | from .vllm import vllm_version 244 | if version.parse(vllm_version) < version.parse("0.7.3"): 245 | raise ValueError(f"vllm version {vllm_version} does not support sleep mode, please upgrade to >= 0.7.3") 246 | for worker in self.worker_instances: 247 | response = requests.post(worker.worker_addr + "/wake_up") 248 | assert response.status_code == 200 249 | print(f"Worker {worker} is woken up") 250 | elif self.engine in ["sglang"]: 251 | raise NotImplementedError(f"Engine {self.engine} does not support wake up") 252 | 253 | def __str__(self): 254 | return f"ModelWorker(model_name={self.model_name}, num_workers={self.num_workers}, num_gpu_per_worker={self.num_gpu_per_worker}, worker_addrs={self.worker_addrs}, cache_dir={self.cache_dir}, use_cache={self.use_cache}, overwrite_cache={self.overwrite_cache}, completion={self.completion}, dtype={self.dtype}, quantization={self.quantization}, engine={self.engine}, additional_args={self.additional_args}, max_retry={self.max_retry}, verbose={self.verbose})" 255 | 256 | def __repr__(self): 257 | return self.__str__() 258 | 259 | 260 | do_cleanup = True 261 | def set_do_cleanup(value): 262 | global do_cleanup 263 | do_cleanup = value 264 | 265 | def kill_process_and_children(pid): 266 | # check if the process is still alive 267 | if 'psutil' not in sys.modules: 268 | # possibly the main process is in the final stage of termination, no need to kill the child processes 269 | return None 270 | try: 271 | parent = psutil.Process(pid) 272 | except psutil.NoSuchProcess: 273 | print(f"No process with PID {pid} found.") 274 | return 275 | 276 | children = parent.children(recursive=True) 277 | for child in children: 278 | try: 279 | child.kill() 280 | # print(f"Killed child process {child.pid}") 281 | except psutil.NoSuchProcess: 282 | print(f"Child process {child.pid} already terminated.") 283 | 284 | try: 285 | parent.kill() 286 | # print(f"Killed parent process {pid}") 287 | except psutil.NoSuchProcess: 288 | print(f"Parent process {pid} already terminated.") 289 | return True 290 | 291 | def cleanup_process(worker:Union[ModelWorker, WorkerInstance, SubprocessMonitor, subprocess.Popen]): 292 | if isinstance(worker, ModelWorker): 293 | for worker_instance in worker.worker_instances: 294 | cleanup_process(worker_instance) 295 | return 296 | if isinstance(worker, WorkerInstance): 297 | proc = worker.proc.proc 298 | elif isinstance(worker, SubprocessMonitor): 299 | proc = worker.proc 300 | elif isinstance(worker, subprocess.Popen): 301 | proc = worker 302 | else: 303 | raise ValueError(f"Unknown process type {type(proc)}") 304 | killed = kill_process_and_children(proc.pid) 305 | if verbose and killed: 306 | print(f"Model 
Worker terminated: {worker} ") 307 | return killed 308 | 309 | @atexit.register 310 | def cleanup_all_workers(): 311 | if not do_cleanup: 312 | return 313 | for worker in all_workers: 314 | cleanup_process(worker) 315 | if all_workers and verbose: 316 | print("All workers terminated.") 317 | all_workers.clear() 318 | 319 | 320 | class LLMEngine: 321 | 322 | def __init__(self, verbose=False, num_gpus: int = None, gpu_ids: Union[List[int], str] = None): 323 | self.workers = [] 324 | self.loaded_model_worker = {} 325 | import torch 326 | self.verbose = verbose 327 | total_gpus = torch.cuda.device_count() 328 | if gpu_ids: 329 | assert isinstance(gpu_ids, (list, str)), "passed gpu_ids must be a list or a string" 330 | if isinstance(gpu_ids, str): 331 | gpu_ids = [int(gpu_id) for gpu_id in gpu_ids.split(",")] 332 | assert all(isinstance(gpu_id, int) for gpu_id in gpu_ids), "passed gpu_ids must be a list of integers" 333 | assert len(gpu_ids) <= total_gpus, f"Error: passed gpu_ids {gpu_ids} is greater than total gpus {total_gpus}" 334 | else: 335 | if os.environ.get("CUDA_VISIBLE_DEVICES") is not None: 336 | gpu_ids = [int(gpu_id) for gpu_id in os.environ.get("CUDA_VISIBLE_DEVICES").split(",")] 337 | else: 338 | gpu_ids = list(range(total_gpus)) 339 | if num_gpus is not None: 340 | num_gpus = int(num_gpus) 341 | assert num_gpus <= total_gpus, f"Error: passed num_gpus {num_gpus} is greater than total gpus {total_gpus}" 342 | gpu_ids = gpu_ids[:num_gpus] 343 | self.gpu_ids = gpu_ids 344 | self.num_gpus = len(gpu_ids) 345 | if verbose: 346 | print(f"LLMEngine initialized with {self.num_gpus} GPUs: {gpu_ids}") 347 | 348 | def get_available_gpu_ids(self): 349 | worker_used_gpu_ids = [] 350 | for worker in self.workers: 351 | for worker_instance in worker.worker_instances: 352 | worker_used_gpu_ids.extend(worker_instance.gpu_ids) 353 | available_gpu_ids = [gpu_id for gpu_id in self.gpu_ids if gpu_id not in worker_used_gpu_ids] 354 | return available_gpu_ids 355 | 356 | def load_model( 357 | self, 358 | model_name, 359 | worker_addrs=None, 360 | cache_dir=None, 361 | use_cache=True, 362 | overwrite_cache=False, 363 | completion=False, 364 | num_workers=1, 365 | num_gpu_per_worker=None, 366 | dtype="auto", 367 | quantization=None, 368 | engine="vllm", 369 | additional_args=[], 370 | max_retry=None, 371 | verbose=None 372 | ): 373 | """ 374 | Load a model 375 | Args: 376 | model_name: model name 377 | worker_addrs: worker addresses, if None, launch local workers 378 | cache_dir: cache directory 379 | use_cache: use cache or not. Cache is on the hash of the input message. 380 | overwrite_cache: overwrite cache or not. If True, previous cache will be overwritten. 
381 | completion: use completion or not (use chat by default) 382 | num_workers: number of workers 383 | num_gpu_per_worker: number of gpus per worker 384 | dtype: data type 385 | engine: engine name 386 | additional_args: additional arguments for launching the worker (vllm, sglang) 387 | max_retry: maximum number of retries 388 | verbose: verbose 389 | """ 390 | verbose = self.verbose or verbose 391 | if self.workers: 392 | print("Warning: previous workers are not cleaned up, please call unload_model() to clean up previous workers") 393 | self.model_name = model_name 394 | available_gpu_ids = self.get_available_gpu_ids() 395 | if engine in ["vllm", "sglang"]: 396 | if not num_gpu_per_worker: 397 | print("Warning: num_gpu_per_worker not provided, using 1 GPU per worker") 398 | num_gpu_per_worker = 1 399 | num_required_gpus = num_workers * num_gpu_per_worker 400 | if len(available_gpu_ids) < num_required_gpus: 401 | print("Error: No available GPU to launch the model worker") 402 | print("Provided GPU IDs for this LLMEngine class: ", self.gpu_ids) 403 | print("Used GPU IDs for all workers: ", [worker_instance.gpu_ids for worker in self.workers for worker_instance in worker.worker_instances]) 404 | print("Available GPU IDs: ", available_gpu_ids) 405 | print("Number of required GPUs: ", num_required_gpus) 406 | raise ValueError("Not enought available GPU to launch the model worker") 407 | gpu_ids = available_gpu_ids[:num_required_gpus] 408 | else: 409 | num_required_gpus = num_workers 410 | gpu_ids = None 411 | 412 | model_worker = ModelWorker( 413 | model_name, 414 | worker_addrs=worker_addrs, 415 | cache_dir=cache_dir, 416 | use_cache=use_cache, 417 | overwrite_cache=overwrite_cache, 418 | completion=completion, 419 | num_workers=num_workers, 420 | num_gpu_per_worker=num_gpu_per_worker, 421 | gpu_ids=gpu_ids, 422 | dtype=dtype, 423 | quantization=quantization, 424 | engine=engine, 425 | additional_args=additional_args, 426 | max_retry=max_retry, 427 | verbose=verbose, 428 | ) 429 | self.workers.append(model_worker) 430 | self.loaded_model_worker[model_name] = model_worker 431 | return model_worker 432 | 433 | def sleep_model( 434 | self, 435 | model_name, 436 | level=1 437 | ): 438 | model_worker = self.loaded_model_worker.get(model_name) 439 | model_worker.sleep_worker(level=level) 440 | 441 | def wake_up_model( 442 | self, 443 | model_name, 444 | ): 445 | model_worker = self.loaded_model_worker.get(model_name) 446 | model_worker.wake_up_worker() 447 | 448 | def call_model( 449 | self, 450 | model_name, 451 | messages:Union[List[str], List[dict], str], 452 | timeout:int=60, 453 | conv_system_msg=None, 454 | **generate_kwargs 455 | ): 456 | """ 457 | Call a model 458 | Args: 459 | model_name: model name 460 | messages: list of messages in openai format 461 | timeout: timeout 462 | conv_system_msg: conversation system message 463 | generate_kwargs: generation arguments 464 | """ 465 | call_model_worker = self.loaded_model_worker.get(model_name) 466 | if call_model_worker is None: 467 | raise ValueError(f"Model {model_name} not loaded, please call load_model() first") 468 | try: 469 | return call_model_worker(messages, timeout=timeout, conv_system_msg=conv_system_msg, **generate_kwargs) 470 | except MaxRetriesExceededError as e: 471 | print(e) 472 | return None 473 | 474 | def batch_call_model( 475 | self, 476 | model_name, 477 | batch_messages:List[Union[List[str], List[dict], str]], 478 | timeout:int=60, 479 | conv_system_msg=None, 480 | num_proc=8, 481 | desc=None, 482 | 
disable_batch_api=False, 483 | max_batch_size=None, 484 | **generate_kwargs 485 | ): 486 | """ 487 | Batch call a model 488 | Args: 489 | model_name: model name 490 | batch_messages: list of list of messages in openai format or list of strings 491 | timeout: timeout 492 | conv_system_msg: conversation system message 493 | num_proc: number of processes 494 | generate_kwargs: generation arguments 495 | """ 496 | supported_batch_api_engines = ["openai", "claude"] 497 | model_worker = self.loaded_model_worker.get(model_name) 498 | engine = model_worker.engine 499 | if engine not in supported_batch_api_engines or disable_batch_api: 500 | if model_worker is None: 501 | raise ValueError(f"Model {model_name} not loaded, please call load_model() first") 502 | 503 | batch_cache_dir = get_batch_cache_dir(model_name, None) 504 | to_write_batch_messages = [ 505 | {"input": message, "generation_kwargs": generate_kwargs} for message in batch_messages 506 | ] 507 | hash_str = hashlib.md5("".join([ 508 | str(x) for x in to_write_batch_messages 509 | ]).encode()).hexdigest() 510 | hash_result_file = batch_cache_dir / f"{hash_str}_batch_results.jsonl" 511 | if hash_result_file.exists(): 512 | results = [] 513 | with open(hash_result_file, "r") as f: 514 | for line in f: 515 | message = json.loads(line) 516 | results.append(message["output"]) 517 | else: 518 | from functools import partial 519 | from multiprocessing import Pool 520 | num_proc = min(num_proc, len(batch_messages)) 521 | if model_worker.is_sleeping: 522 | print("Warning: Worker is sleeping, waking up...") 523 | model_worker.wake_up_worker() 524 | call_model_worker_mp = partial(model_worker.call_model_worker, timeout=timeout, conv_system_msg=conv_system_msg, **generate_kwargs) 525 | call_model_worker_mp = partial(max_retry_wrapper, call_model_worker_mp) 526 | with Pool(num_proc) as p: 527 | results = list(tqdm(p.imap(call_model_worker_mp, batch_messages), total=len(batch_messages), desc=desc or "LLMEngine Batch Inference")) 528 | if results: 529 | for i, message in enumerate(to_write_batch_messages): 530 | message["output"] = results[i] 531 | with open(batch_cache_dir / f"{hash_str}_batch_results.jsonl", "w") as f: 532 | for message in to_write_batch_messages: 533 | f.write(json.dumps(message) + "\n") 534 | else: 535 | if engine == "openai": 536 | print("Using OpenAI batch API") 537 | from .openai_text import openai_batch_request, check_batch_status, get_batch_progress, get_batch_result 538 | batch_request_func = openai_batch_request 539 | elif engine == "claude": 540 | print("Using Claude batch API") 541 | from .claude import claude_batch_request, check_batch_status, get_batch_progress, get_batch_result 542 | batch_request_func = claude_batch_request 543 | else: 544 | raise ValueError(f"Engine {engine} not supported for batch API") 545 | if max_batch_size is None: 546 | results = batch_request_func(model_name, batch_messages, conv_system_msg=conv_system_msg, desc=desc, **generate_kwargs) 547 | else: 548 | # using multiprocess to submit batch request per batch 549 | from functools import partial 550 | import time 551 | 552 | max_slots = num_proc 553 | all_batch_inputs = [ 554 | batch_messages[i:i+max_batch_size] for i in range(0, len(batch_messages), max_batch_size) 555 | ] 556 | # submit detach jobs 557 | batch_ids = [None] * len(all_batch_inputs) 558 | # wait for all jobs to finish and periodically check the status 559 | idx = 0 560 | tqdm_bar = tqdm(total=len(batch_ids), desc=desc or "LLMEngine Batch Inference") 561 | all_batch_status = 
[check_batch_status(batch_id)['status'] if batch_id is not None else "pending" for batch_id in batch_ids] 562 | while True: 563 | batch_id = batch_ids[idx] 564 | if batch_id is None: 565 | cur_slots = len([bstatus for bstatus in all_batch_status if bstatus not in ['pending', 'completed', 'cancelled', 'canceled', "failed", "expired"]]) 566 | if cur_slots < max_slots: 567 | batch_id = batch_request_func(model_name, all_batch_inputs[idx], conv_system_msg=conv_system_msg, desc=desc, **generate_kwargs, detach=True) 568 | batch_ids[idx] = batch_id 569 | else: 570 | tqdm_bar.n = 0 571 | tqdm_bar.total = len(all_batch_inputs[idx]) 572 | tqdm_bar.desc = "pending" + f" (batch {idx+1}/{len(batch_ids)})" 573 | tqmd_postfix = tqdm_postfix = { 574 | "completed": 0, 575 | "total": len(all_batch_inputs[idx]), 576 | "failed": 0, 577 | } 578 | tqdm_bar.set_postfix_str(tqmd_postfix) 579 | tqdm_bar.refresh() 580 | idx = (idx + 1) % len(batch_ids) 581 | time.sleep(5) 582 | continue 583 | 584 | n, total, tqdm_postfix, cur_batch_status = get_batch_progress(batch_id) 585 | tqdm_bar.set_postfix_str(tqdm_postfix) 586 | tqdm_bar.n = n 587 | tqdm_bar.total = total 588 | tqdm_bar.desc = cur_batch_status + f" (batch {idx+1}/{len(batch_ids)})" 589 | tqdm_bar.refresh() 590 | all_batch_status[idx] = cur_batch_status 591 | if all(status == "completed" for status in all_batch_status): 592 | tqdm_bar.close() 593 | break 594 | idx = (idx + 1) % len(batch_ids) 595 | time.sleep(10) 596 | 597 | # collect results 598 | all_batch_results = [] 599 | for i, batch_id in enumerate(batch_ids): 600 | batch_results = get_batch_result(batch_id, generate_kwargs=generate_kwargs) 601 | if batch_results: 602 | all_batch_results.extend(batch_results) 603 | else: 604 | raise ValueError(f"Warning: batch {batch_id} has no results") 605 | results = all_batch_results 606 | 607 | return results 608 | 609 | def __call__(self, *args, **kwds): 610 | return self.call_model(*args, **kwds) 611 | 612 | def unload_model(self, model_name=None): 613 | to_remove_local_workers = [] 614 | to_remove_global_workers = [] 615 | for worker in self.workers: 616 | if model_name is None or worker.model_name == model_name: 617 | print(f"Unloading model worker: {worker}") 618 | cleanup_process(worker) 619 | if worker in all_workers: 620 | to_remove_global_workers.append(worker) 621 | if worker in self.workers: 622 | to_remove_local_workers.append(worker) 623 | for worker in to_remove_global_workers: 624 | all_workers.remove(worker) 625 | for worker in to_remove_local_workers: 626 | self.workers.remove(worker) 627 | 628 | def __del__(self): 629 | pass 630 | 631 | -------------------------------------------------------------------------------- /llm_engines/cache/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | from functools import partial 5 | from typing import Union, List 6 | from pathlib import Path 7 | from .cache_sqlite3 import load_cache as load_sqlite3_cache 8 | from .cache_lru import load_cache as load_lru_cache 9 | from .cache_dict import load_cache as load_dict_cache 10 | from .cache_utils import get_cache_file, get_inputs_hash, get_printable_messages, get_batch_cache_dir 11 | 12 | load_cache = load_sqlite3_cache 13 | 14 | def _generation_cache_wrapper(inputs: Union[str, List[dict]], call_model_worker, model_name, cache_dir=None, overwrite_cache=False, **generate_kwargs): 15 | cache_dict = load_cache(model_name, cache_dir) 16 | 17 | conv_system_msg = 
generate_kwargs.get("conv_system_msg", "") 18 | if "n" in generate_kwargs: 19 | non_hash_keys = ["timeout", "stream"] 20 | inputs_hash = get_inputs_hash(inputs, conv_system_msg, {k: v for k, v in generate_kwargs.items() if k not in non_hash_keys}) 21 | else: 22 | inputs_hash = get_inputs_hash(inputs, conv_system_msg) 23 | 24 | if not overwrite_cache: 25 | cached_value = cache_dict[inputs_hash] 26 | if cached_value: 27 | if "logprobs" not in generate_kwargs or not generate_kwargs["logprobs"]: 28 | return cached_value["output"] 29 | elif "logprobs" in cached_value: 30 | return cached_value["output"], cached_value["logprobs"] 31 | 32 | response = call_model_worker(inputs, **generate_kwargs) 33 | if isinstance(response, tuple): 34 | generated_text, logprobs = response 35 | else: 36 | generated_text = response 37 | logprobs = None 38 | cache_item = { 39 | "input": get_printable_messages(inputs), 40 | "output": generated_text, 41 | "logprobs": logprobs, 42 | "model_name": model_name, 43 | 'tstamp': time.time(), 44 | "time": time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()), 45 | "generate_kwargs": generate_kwargs 46 | } 47 | 48 | # cache_dict[inputs_hash] = cache_item 49 | 50 | cache_file = get_cache_file(model_name, cache_dir) 51 | with open(cache_file, "a+") as f: 52 | f.write(json.dumps({inputs_hash: cache_item}) + "\n") 53 | 54 | return response 55 | 56 | def generation_cache_wrapper(call_model_worker, model_name, cache_dir=None, overwrite_cache=False): 57 | print(f"Using efficient multi-level cache for model {model_name}") 58 | if cache_dir is None: 59 | env_cache_dir = os.getenv("LLM_ENGINES_CACHE_DIR") 60 | if env_cache_dir: 61 | cache_dir = Path(env_cache_dir) 62 | else: 63 | cache_dir = Path(os.path.expanduser(f"~/llm_engines/generation_cache")) 64 | print(f"Cache directory: {cache_dir}") 65 | load_cache(model_name, cache_dir) # preload cache 66 | 67 | return partial(_generation_cache_wrapper, call_model_worker=call_model_worker, model_name=model_name, cache_dir=cache_dir, overwrite_cache=overwrite_cache) -------------------------------------------------------------------------------- /llm_engines/cache/cache_dict.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from pathlib import Path 4 | from .cache_utils import get_cache_file 5 | from collections import defaultdict 6 | cache_dict = {} 7 | loaded_cache_files = defaultdict(list) 8 | def load_cache(model_name, cache_dir=None): 9 | global cache_dict 10 | global loaded_cache_files 11 | if model_name not in cache_dict: 12 | if cache_dir is None: 13 | cache_dir = Path(os.path.expanduser(f"~/llm_engines/generation_cache")) 14 | else: 15 | cache_dir = Path(cache_dir) 16 | cache_file = get_cache_file(model_name, cache_dir) 17 | if cache_file.exists(): 18 | print("Cache file exists at:", cache_file.absolute()) 19 | if model_name not in loaded_cache_files or cache_file.absolute() not in loaded_cache_files[model_name]: 20 | print(f"Loading cache from {cache_file}") 21 | with open(cache_file, "r") as f: 22 | model_cache_dict = [json.loads(line) for line in f.readlines()] 23 | model_cache_dict = {list(item.keys())[0]: list(item.values())[0] for item in model_cache_dict} 24 | # only keep the output in the value 25 | cache_dict[model_name] = defaultdict(lambda: None) 26 | for key, value in model_cache_dict.items(): 27 | cache_dict[model_name][key] = value["output"] 28 | loaded_cache_files[model_name].append(cache_file.absolute()) 29 | else: 30 | cache_dict[model_name] = {} 31 | 
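# --- Illustrative usage sketch (annotation, not part of cache_dict.py) ---
# The generation_cache_wrapper defined in cache/__init__.py above can wrap any callable that
# takes OpenAI-style messages plus generate_kwargs and returns a string. The worker function
# and model name below are made-up placeholders, not part of the package:
#
#   from llm_engines.cache import generation_cache_wrapper
#
#   def fake_worker(messages, **generate_kwargs):
#       # stand-in for a real engine call; the wrapper caches whatever this returns
#       return "echo: " + messages[-1]["content"]
#
#   cached_worker = generation_cache_wrapper(fake_worker, model_name="dummy/fake-model")
#   output = cached_worker([{"role": "user", "content": "Hello"}], temperature=0.0)
#
# The call above appends a {hash: {"input": ..., "output": ...}} line to the model's cache file;
# a later run with the same inputs hash reads that output back instead of calling the worker again.
# --- end of annotation ---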
return cache_dict[model_name] -------------------------------------------------------------------------------- /llm_engines/cache/cache_lru.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import os 4 | import json 5 | import hashlib 6 | from pathlib import Path 7 | from typing import Union, List 8 | from typing import List 9 | from cachetools import LRUCache 10 | from functools import lru_cache 11 | from .cache_utils import get_cache_file 12 | from tqdm import tqdm 13 | from collections import defaultdict 14 | 15 | # Global cache dictionary using LRUCache 16 | cache_dict = {} 17 | loaded_cache_files = defaultdict(list) 18 | # Adjust this value based on your memory constraints and requirements 19 | MAX_CACHE_SIZE = 500000 # Example: 100k items 20 | 21 | class LRUCacheManager: 22 | def __init__(self, maxsize=MAX_CACHE_SIZE): 23 | self.cache = LRUCache(maxsize=maxsize) 24 | 25 | def get(self, key): 26 | if key not in self.cache: 27 | return None 28 | return self.cache.get(key) 29 | 30 | def set(self, key, value): 31 | self.cache[key] = value 32 | 33 | def __getitem__(self, key): 34 | return self.get(key) 35 | 36 | def __setitem__(self, key, value): 37 | self.set(key, value) 38 | 39 | def __del__(self): 40 | self.save() 41 | 42 | def load_cache(model_name, cache_dir=None): 43 | global cache_dict 44 | global loaded_cache_files 45 | 46 | if model_name not in cache_dict: 47 | if cache_dir is None: 48 | cache_dir = Path(os.path.expanduser(f"~/llm_engines/generation_cache")) 49 | else: 50 | cache_dir = Path(cache_dir) 51 | cache_dict[model_name] = LRUCacheManager() 52 | cache_file = get_cache_file(model_name, cache_dir) 53 | if cache_file.exists(): 54 | print("Cache file exists at:", cache_file.absolute()) 55 | if model_name not in loaded_cache_files or cache_file.absolute() not in loaded_cache_files[model_name]: 56 | with open(cache_file, "r") as f: 57 | for line in tqdm(f, desc="Loading cache for model: " + model_name): 58 | item = json.loads(line) 59 | key = list(item.keys())[0] 60 | value = list(item.values())[0] 61 | # cache_dict[model_name][key] = {"output": value["output"]} 62 | cache_dict[model_name][key] = value 63 | loaded_cache_files[model_name].append(cache_file.absolute()) 64 | 65 | return cache_dict[model_name] 66 | 67 | -------------------------------------------------------------------------------- /llm_engines/cache/cache_sqlite3.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import shutil 4 | import atexit 5 | import sqlite3 6 | import mmap 7 | import threading 8 | import struct 9 | from .cache_utils import get_cache_file 10 | from pathlib import Path 11 | from typing import Union, List, Dict 12 | from cachetools import LRUCache 13 | from tqdm import tqdm 14 | from collections import defaultdict 15 | 16 | BLOCK_SIZE = 3072 * 1024 # 3MB 17 | MAX_CACHE_SIZE = 100000 # Example: 10k items 18 | MAX_MEMORY_BLOCKS = 32 # Example: 32 blocks 19 | # Global cache dictionary using MultiLevelCache 20 | cache_dict = {} 21 | loaded_cache_files = defaultdict(list) 22 | 23 | class BlockCache: 24 | def __init__(self, max_size=MAX_MEMORY_BLOCKS): 25 | self.cache = LRUCache(maxsize=max_size) 26 | self.lock = threading.Lock() 27 | 28 | def get(self, block_id): 29 | with self.lock: 30 | return self.cache.get(block_id) 31 | 32 | def set(self, block_id, data): 33 | with self.lock: 34 | self.cache[block_id] = data 35 | 36 | class EfficientDiskCache: 37 | def 
__init__(self, cache_dir, model_name, block_size=BLOCK_SIZE, max_memory_blocks=MAX_MEMORY_BLOCKS): 38 | self.cache_dir = Path(cache_dir) / f"{model_name}_disk_cache" 39 | self.cache_dir.mkdir(parents=True, exist_ok=True) 40 | self.model_name = model_name 41 | self.block_size = block_size 42 | self.index_db = self.cache_dir / "index.db" 43 | self.block_cache = BlockCache(max_size=max_memory_blocks) 44 | self.db_lock = threading.Lock() 45 | self.init_index_db() 46 | 47 | def __del__(self): 48 | self.cleanup() 49 | 50 | def cleanup(self): 51 | try: 52 | if self.cache_dir.exists(): 53 | shutil.rmtree(self.cache_dir) 54 | print(f"Cleaned up cache directory: {self.cache_dir}") 55 | except Exception as e: 56 | print(f"Error during cleanup: {e}") 57 | 58 | def init_index_db(self): 59 | with self.db_lock: 60 | with sqlite3.connect(self.index_db) as conn: 61 | cursor = conn.cursor() 62 | cursor.execute(''' 63 | CREATE TABLE IF NOT EXISTS cache_index 64 | (key TEXT PRIMARY KEY, block_id INTEGER, offset INTEGER, length INTEGER) 65 | ''') 66 | conn.commit() 67 | 68 | def get_block_file(self, block_id): 69 | return self.cache_dir / f"block_{block_id}.bin" 70 | 71 | def get(self, key): 72 | with self.db_lock: 73 | with sqlite3.connect(self.index_db) as conn: 74 | cursor = conn.cursor() 75 | cursor.execute("SELECT block_id, offset, length FROM cache_index WHERE key = ?", (key,)) 76 | result = cursor.fetchone() 77 | 78 | if result: 79 | block_id, offset, length = result 80 | block_data = self.block_cache.get(block_id) 81 | if block_data is None: 82 | block_file = self.get_block_file(block_id) 83 | if block_file.exists(): 84 | try: 85 | with open(block_file, 'rb') as f: 86 | mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) 87 | block_data = mm.read() 88 | self.block_cache.set(block_id, block_data) 89 | except Exception as e: 90 | print(f"Error reading block file: {e}") 91 | return None 92 | else: 93 | return None 94 | 95 | try: 96 | data = block_data[offset:offset+length] 97 | value_length = struct.unpack('!I', data[:4])[0] 98 | json_data = data[4:4+value_length].decode('utf-8') 99 | return json.loads(json_data) 100 | except (struct.error, json.JSONDecodeError, UnicodeDecodeError) as e: 101 | print(f"Error decoding data: {e}") 102 | return None 103 | return None 104 | 105 | def set(self, key, value): 106 | json_data = json.dumps(value).encode('utf-8') 107 | data_length = len(json_data) 108 | full_data = struct.pack('!I', data_length) + json_data 109 | 110 | with self.db_lock: 111 | with sqlite3.connect(self.index_db) as conn: 112 | cursor = conn.cursor() 113 | cursor.execute("SELECT MAX(block_id) FROM cache_index") 114 | result = cursor.fetchone() 115 | current_block_id = result[0] if result[0] is not None else 0 116 | 117 | block_file = self.get_block_file(current_block_id) 118 | if block_file.exists(): 119 | file_size = block_file.stat().st_size 120 | if file_size + len(full_data) > self.block_size: 121 | current_block_id += 1 122 | block_file = self.get_block_file(current_block_id) 123 | file_size = 0 124 | else: 125 | file_size = 0 126 | 127 | with open(block_file, 'ab') as f: 128 | f.write(full_data) 129 | 130 | cursor.execute(''' 131 | INSERT OR REPLACE INTO cache_index (key, block_id, offset, length) 132 | VALUES (?, ?, ?, ?) 
133 | ''', (key, current_block_id, file_size, len(full_data))) 134 | conn.commit() 135 | 136 | block_data = self.block_cache.get(current_block_id) 137 | if block_data is not None: 138 | self.block_cache.set(current_block_id, block_data + full_data) 139 | 140 | def bulk_insert(self, data: Dict[str, Dict]): 141 | with self.db_lock: 142 | with sqlite3.connect(self.index_db) as conn: 143 | cursor = conn.cursor() 144 | cursor.execute("SELECT MAX(block_id) FROM cache_index") 145 | result = cursor.fetchone() 146 | current_block_id = result[0] if result[0] is not None else 0 147 | 148 | block_file = self.get_block_file(current_block_id) 149 | if block_file.exists(): 150 | file_size = block_file.stat().st_size 151 | else: 152 | file_size = 0 153 | 154 | index_data = [] 155 | current_block_data = b"" 156 | 157 | for key, value in data.items(): 158 | json_data = json.dumps(value).encode('utf-8') 159 | data_length = len(json_data) 160 | full_data = struct.pack('!I', data_length) + json_data 161 | 162 | if file_size + len(full_data) > self.block_size: 163 | with open(block_file, 'ab') as f: 164 | f.write(current_block_data) 165 | self.block_cache.set(current_block_id, current_block_data) 166 | 167 | current_block_id += 1 168 | block_file = self.get_block_file(current_block_id) 169 | file_size = 0 170 | current_block_data = b"" 171 | 172 | index_data.append((key, current_block_id, file_size, len(full_data))) 173 | current_block_data += full_data 174 | file_size += len(full_data) 175 | 176 | if current_block_data: 177 | with open(block_file, 'ab') as f: 178 | f.write(current_block_data) 179 | self.block_cache.set(current_block_id, current_block_data) 180 | 181 | cursor.executemany(''' 182 | INSERT OR REPLACE INTO cache_index (key, block_id, offset, length) 183 | VALUES (?, ?, ?, ?) 
184 | ''', index_data) 185 | conn.commit() 186 | 187 | class MultiLevelCache: 188 | def __init__(self, model_name, cache_dir, memory_size=MAX_CACHE_SIZE): 189 | self.memory_cache = LRUCache(maxsize=memory_size) 190 | self.disk_cache = EfficientDiskCache(cache_dir, model_name) 191 | 192 | def get(self, key): 193 | value = self.memory_cache.get(key) 194 | if value is not None: 195 | return value 196 | 197 | value = self.disk_cache.get(key) 198 | if value is not None: 199 | self.memory_cache[key] = value 200 | return value 201 | 202 | return None 203 | 204 | def __getitem__(self, key): 205 | return self.get(key) 206 | 207 | def set(self, key, value): 208 | self.memory_cache[key] = value 209 | self.disk_cache.set(key, value) 210 | 211 | def __setitem__(self, key, value): 212 | self.set(key, value) 213 | 214 | def bulk_insert(self, data: Dict[str, Dict]): 215 | self.disk_cache.bulk_insert(data) 216 | for key, value in tqdm(data.items(), desc="Bulk inserting into Memory Cache"): 217 | self.memory_cache[key] = value 218 | 219 | 220 | def load_cache(model_name, cache_dir=None): 221 | global cache_dict 222 | global loaded_cache_files 223 | if model_name not in cache_dict: 224 | if cache_dir is None: 225 | cache_dir = Path(os.path.expanduser(f"~/llm_engines/generation_cache")) 226 | else: 227 | cache_dir = Path(cache_dir) 228 | cache_file = get_cache_file(model_name, cache_dir) 229 | cache_dict[model_name] = MultiLevelCache(model_name, cache_dir) 230 | 231 | if cache_file.exists(): 232 | print("Cache file exists at:", cache_file.absolute()) 233 | if model_name not in loaded_cache_files or cache_file.absolute() not in loaded_cache_files[model_name]: 234 | initial_data = {} 235 | with open(cache_file, 'r') as f: 236 | for line in tqdm(f, desc="Loading cache for model: " + model_name): 237 | data = json.loads(line) 238 | key = list(data.keys())[0] 239 | initial_data[key] = data[key] 240 | if initial_data: 241 | cache_dict[model_name].bulk_insert(initial_data) 242 | loaded_cache_files[model_name].append(cache_file.absolute()) 243 | 244 | return cache_dict[model_name] 245 | 246 | # Cleanup function to be called at exit 247 | def cleanup_all_caches(): 248 | global cache_dict 249 | # print("Cleaning up all caches...") 250 | for model_name, cache in cache_dict.items(): 251 | cache.disk_cache.cleanup() 252 | cache_dict.clear() 253 | 254 | # Register the cleanup function to be called at exit 255 | atexit.register(cleanup_all_caches) -------------------------------------------------------------------------------- /llm_engines/cache/cache_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hashlib 3 | import base64 4 | from copy import deepcopy 5 | from typing import List, Union 6 | from cachetools import LRUCache 7 | from functools import lru_cache 8 | from pathlib import Path 9 | 10 | @lru_cache(maxsize=None) 11 | def get_cache_file(model_name_or_path, cache_dir): 12 | model_name = model_name_or_path.split("/")[-2:] 13 | model_name = "/".join(model_name) 14 | if cache_dir is not None: 15 | cache_file = Path(cache_dir) / f"{model_name}.jsonl" 16 | else: 17 | cache_file = Path(os.path.expanduser(f"~/llm_engines/generation_cache/{model_name}.jsonl")) 18 | if not cache_file.parent.exists(): 19 | cache_file.parent.mkdir(parents=True) 20 | return cache_file 21 | 22 | @lru_cache(maxsize=None) 23 | def get_batch_cache_dir(model_name_or_path, cache_dir): 24 | model_name = model_name_or_path.split("/")[-2:] 25 | model_name = "/".join(model_name) 26 | if 
cache_dir is not None: 27 | batch_cache_dir = Path(cache_dir) / f"{model_name}_batch_cache" 28 | else: 29 | batch_cache_dir = Path(os.path.expanduser(f"~/llm_engines/generation_cache/{model_name}_batch_cache")) 30 | if not batch_cache_dir.exists(): 31 | batch_cache_dir.mkdir(parents=True) 32 | return batch_cache_dir 33 | 34 | MAX_PRINTABLE_IMAGE_URL_LENGTH = 100 35 | def get_printable_messages(messages): 36 | # mainly for image_url, only keep the first 100 characters 37 | messages = deepcopy(messages) 38 | for message in messages: 39 | if isinstance(message["content"], list): 40 | for sub_message in message["content"]: 41 | if sub_message["type"] == "image_url": 42 | sub_message["image_url"]["url"] = sub_message["image_url"]["url"][:MAX_PRINTABLE_IMAGE_URL_LENGTH] + \ 43 | (f"... ({len(sub_message['image_url']['url']) - MAX_PRINTABLE_IMAGE_URL_LENGTH} more characters)" if len(sub_message['image_url']['url']) > MAX_PRINTABLE_IMAGE_URL_LENGTH else "") 44 | elif sub_message["type"] == "image": 45 | sub_message["image"] = sub_message["image"][:100] 46 | return messages 47 | 48 | def get_inputs_hash(inputs:Union[str, List[dict]], conv_system_msg, generate_kwargs=None): 49 | 50 | inputs = inputs.copy() 51 | if isinstance(inputs, str): 52 | try: 53 | return hashlib.md5(inputs.encode()).hexdigest() 54 | except UnicodeEncodeError as e: 55 | return hashlib.md5(inputs.encode('utf-16', 'surrogatepass').decode('utf-16').encode('utf-8')).hexdigest() 56 | 57 | # inputs is a list of dicts in openai format 58 | 59 | if conv_system_msg: 60 | to_hash_messages = [{ 61 | "role": "system", 62 | "content": conv_system_msg 63 | }] + inputs 64 | else: 65 | to_hash_messages = inputs 66 | 67 | to_hash_inputs = [] 68 | for message in to_hash_messages: 69 | role = message["role"] 70 | content = message["content"] 71 | if isinstance(content, str): 72 | to_hash_inputs.append(f"{role}:{content}") 73 | elif isinstance(content, list): 74 | strs = [] 75 | for sub_content in content: 76 | if sub_content["type"] == "text": 77 | strs.append(sub_content["text"]) 78 | elif sub_content["type"] == "image_url": 79 | if "url" not in sub_content["image_url"]: 80 | raise ValueError("image_url must have a url key") 81 | image_url_hash = hashlib.md5(sub_content["image_url"]["url"].encode()).hexdigest() 82 | strs.append(f"image_url:{image_url_hash}") 83 | else: 84 | raise ValueError(f"Unknown content type {sub_content['type']}") 85 | to_hash_inputs.append(f"{role}:{''.join(strs)}") 86 | else: 87 | raise ValueError(f"Unknown content type {type(content)}") 88 | 89 | if generate_kwargs: 90 | to_hash_inputs.append(str(generate_kwargs)) 91 | 92 | try: 93 | return hashlib.md5("".join(to_hash_inputs).encode()).hexdigest() 94 | except UnicodeEncodeError as e: 95 | return hashlib.md5("".join(to_hash_inputs).encode('utf-16', 'surrogatepass').decode('utf-16').encode('utf-8')).hexdigest() 96 | 97 | 98 | 99 | -------------------------------------------------------------------------------- /llm_engines/claude.py: -------------------------------------------------------------------------------- 1 | import anthropic 2 | import os 3 | import requests 4 | import json 5 | import hashlib 6 | import time 7 | import filelock 8 | import random 9 | import regex as re 10 | from copy import deepcopy 11 | from typing import List, Union 12 | from anthropic import NOT_GIVEN 13 | from datetime import datetime 14 | from pathlib import Path 15 | from tqdm import tqdm 16 | from .utils import with_timeout, is_base64_image_url, encode_base64_image, load_image, 
decode_base64_image_url 17 | batch_submission_status_file = Path(os.path.expanduser(f"~/llm_engines/generation_cache")) / "claude_batch_cache" / "batch_submission_status.json" 18 | batch_submission_status_file.parent.mkdir(parents=True, exist_ok=True) 19 | 20 | 21 | def read_batch_submission_status(): 22 | print("Loading batch submission status from", batch_submission_status_file) 23 | if batch_submission_status_file.exists(): 24 | 25 | lock = filelock.FileLock(str(batch_submission_status_file) + ".lock", timeout=30) 26 | try: 27 | with lock: 28 | with open(batch_submission_status_file, "r") as f: 29 | batch_submission_status = json.load(f) 30 | except filelock.Timeout as e: 31 | print("Timeout acquiring lock") 32 | raise e 33 | else: 34 | batch_submission_status = {} 35 | return batch_submission_status 36 | 37 | def write_batch_submission_status(batch_submission_status): 38 | batch_submission_status_file.parent.mkdir(parents=True, exist_ok=True) 39 | lock = filelock.FileLock(str(batch_submission_status_file) + ".lock", timeout=30) 40 | try: 41 | with lock: 42 | with open(batch_submission_status_file, "w") as f: 43 | json.dump(batch_submission_status, f, indent=4) 44 | except filelock.Timeout as e: 45 | print("Timeout acquiring lock") 46 | raise e 47 | 48 | # 5MB 49 | MAX_IMAGE_SIZE = 5 * 1024 * 1024 50 | def preprocess_claude_messages(messages:List[dict]) -> List[dict]: 51 | messages = deepcopy(messages) 52 | for message in messages: 53 | if isinstance(message['content'], list): 54 | for sub_message in message['content']: 55 | if sub_message['type'] == "image_url": 56 | if is_base64_image_url(sub_message['image_url']['url']): 57 | # if image size is greater than 5MB, decode and resize and re-encode 58 | im64 = sub_message['image_url']['url'].split(",", 1)[1] 59 | current_size = len(im64) 60 | # print("current_size", current_size) 61 | if current_size > MAX_IMAGE_SIZE: 62 | print("Warning: Image size is greater than 5MB. Resizing image due to Claude API limit.") 63 | image = decode_base64_image_url(sub_message['image_url']['url']) 64 | image_format = image.format if image.format else "png" 65 | scale_factor = (MAX_IMAGE_SIZE / current_size) ** 0.6 66 | image = image.resize((int(image.width * scale_factor), int(image.height * scale_factor))) 67 | im64 = encode_base64_image(image, image_format) 68 | media_type = "image/png" 69 | else: 70 | start_idx = sub_message['image_url']['url'].find("image/") 71 | end_idx = sub_message['image_url']['url'].find(";base64") 72 | media_type = sub_message['image_url']['url'][start_idx:end_idx].lower() 73 | else: 74 | image = load_image(sub_message['image_url']['url']) 75 | current_size = image.size[0] * image.size[1] * 3 76 | if current_size > MAX_IMAGE_SIZE: 77 | print("Warning: Image size is greater than 5MB. 
Resizing image due to Claude API limit.") 78 | scale_factor = (MAX_IMAGE_SIZE / current_size) ** 0.6 79 | image = image.resize((int(image.size[0] * scale_factor), int(image.size[1] * scale_factor))) 80 | image_format = image.format if image.format else "png" 81 | im64= encode_base64_image(image, image_format) 82 | media_type = f"image/{image_format}".lower() 83 | sub_message['source'] = { 84 | "type": "base64", 85 | "media_type": media_type, 86 | "data": im64 87 | } 88 | sub_message['type'] = "image" 89 | sub_message.pop('image_url') 90 | return messages 91 | 92 | # no image, multi-turn, do not use openai_generate, but can refer to it 93 | def call_worker_claude(messages:List[str], model_name, timeout:int=60, conv_system_msg=None, **generate_kwargs) -> str: 94 | # change messages to mistral format 95 | client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")) 96 | # change messages to openai format 97 | if conv_system_msg: 98 | new_messages = [{"role": "system", "content": conv_system_msg}] + messages 99 | else: 100 | new_messages = messages 101 | new_messages = preprocess_claude_messages(new_messages) 102 | 103 | generate_kwargs.pop("n", None) # claude does not have n 104 | if not generate_kwargs.get("max_tokens", None): 105 | generate_kwargs["max_tokens"] = 1024 106 | stream = generate_kwargs.pop("stream", False) 107 | if "logprobs" in generate_kwargs: 108 | raise ValueError("Error: logprobs is not supported in claude") 109 | @with_timeout(timeout) 110 | def get_response(): 111 | completion = client.messages.create( 112 | model=model_name, 113 | messages=new_messages, 114 | system=conv_system_msg if conv_system_msg else NOT_GIVEN, 115 | timeout=timeout, 116 | **generate_kwargs, 117 | ) 118 | if len(completion.content) > 1: 119 | return [c.text for c in completion.content] 120 | else: 121 | return completion.content[0].text 122 | 123 | @with_timeout(timeout) 124 | def stream_response(): 125 | with client.messages.stream( 126 | model=model_name, 127 | messages=new_messages, 128 | system=conv_system_msg if conv_system_msg else NOT_GIVEN, 129 | timeout=timeout, 130 | **generate_kwargs, 131 | ) as stream: 132 | for text in stream.text_stream: 133 | yield text 134 | 135 | if not stream: 136 | return get_response() 137 | else: 138 | return stream_response() 139 | 140 | def save_batch_file(batch_messages:List[Union[str, dict]], model:str, batch_name:str=None, cache_dir=None, custom_ids=None, **generate_kwargs): 141 | if isinstance(batch_messages[0], str): 142 | batch_messages = [{"role": "user", "content": message} for message in batch_messages] 143 | 144 | if cache_dir is None: 145 | cache_dir = Path(os.path.expanduser(f"~/llm_engines/generation_cache")) / "openai_batch_cache" 146 | else: 147 | cache_dir = Path(cache_dir) / "openai_batch_cache" 148 | cache_dir.mkdir(parents=True, exist_ok=True) 149 | if custom_ids is None: 150 | custom_ids = [f"request-{i+1}" for i in range(len(batch_messages))] 151 | assert len(custom_ids) == len(batch_messages) 152 | if batch_name is None: 153 | # batch_name = f"{model}_{datetime.now().strftime('%Y%m%d_%H%M%S')}" 154 | hash_str = "".join([str(message) for message in batch_messages]) 155 | batch_name = f"{model}_{hashlib.md5(hash_str.encode()).hexdigest()}" 156 | file_path = cache_dir / f"{batch_name}.jsonl" 157 | with open(file_path, "w") as f: 158 | for custom_id, message in zip(custom_ids, batch_messages): 159 | if isinstance(message, str): 160 | message = {"role": "user", "content": message} 161 | f.write(json.dumps({"custom_id": custom_id, 
"params": {"model": model, "messages": message, **generate_kwargs}}, ensure_ascii=False) + "\n") 162 | return file_path 163 | 164 | def submit_batch_file(batch_file:str, output_path:str=None, project_name:str=None, description:str=None): 165 | """ 166 | Submit a batch of queries to OpenAI API 167 | Args: 168 | batch_file: str, path to the batch jsonl file 169 | output_path: str, path to the output file; if not provided, will be saved in the same directory as the batch_file, ending with ".batch_results.jsonl" 170 | project_name: str, project name, default to "vapo", keyword to filter batches 171 | description: str, description of the batch, default to the batch_file name, can also be used to filter batches 172 | Returns: 173 | batch_result_id: str, the id of the batch submission 174 | """ 175 | if isinstance(batch_file, str): 176 | batch_file = Path(batch_file) 177 | 178 | # internally maintain a batch submission status json 179 | batch_submission_status = read_batch_submission_status() 180 | 181 | 182 | client = anthropic.Anthropic() 183 | with open(batch_file, "r") as f: 184 | batch_inputs = [json.loads(line) for line in f.readlines()] 185 | if not project_name: 186 | project_name = "llm_engines" 187 | description = description if description is not None else batch_file.stem 188 | output_path = output_path if output_path is not None else batch_file.parent / f"{description}.batch_results.jsonl" 189 | 190 | batch_file_hash = hashlib.md5(batch_file.read_bytes()).hexdigest() 191 | batch_file_size = batch_file.stat().st_size 192 | 193 | for key, value in batch_submission_status.items(): 194 | value_input_file_metadata = value["input_path_metadata"] 195 | if not os.path.samefile(value["input_path"], batch_file): 196 | # print(f"Batch {key} has a different input file. need to resubmit.") 197 | continue 198 | if batch_file_size != value_input_file_metadata["size"]: 199 | # print(f"Batch {key} has a newer version of the input file. need to resubmit.") 200 | continue 201 | if batch_file_hash != value_input_file_metadata["hash"]: 202 | # print(f"Batch {key} has a different input hash. need to resubmit.") 203 | continue 204 | if value['status'] in ["errored", "expired", "canceled"]: 205 | continue 206 | print(f"Batch {key} is still in progress. 
Skipping submission.") 207 | return key 208 | 209 | 210 | batch_result = client.beta.messages.batches.create( 211 | requests=batch_inputs, 212 | ) 213 | print(f"Batch {batch_result.id} submitted") 214 | 215 | claude_batch_metadata = batch_result.to_dict() 216 | for key in claude_batch_metadata: 217 | if isinstance(claude_batch_metadata[key], datetime): 218 | claude_batch_metadata[key] = str(claude_batch_metadata[key]) 219 | # time should be in the current timezone, in the format like 2022-01-01T00:00:00 220 | batch_submission_status[batch_result.id] = { 221 | "project": project_name, 222 | "description": description, 223 | "endpoint": None, 224 | "completion_window": None, 225 | "input_path": str(batch_file), 226 | "input_path_metadata": { 227 | "hash": hashlib.md5(batch_file.read_bytes()).hexdigest(), 228 | "size": batch_file.stat().st_size, 229 | "total": len(batch_inputs) 230 | }, 231 | "output_path": str(output_path), 232 | "batch_input_file_id": None, 233 | "batch_result_id": batch_result.id, 234 | "status": "submitted", 235 | "timeline": { 236 | "submitted": str(datetime.now()), 237 | "completed": None, 238 | "failed": None, 239 | "downloaded": None, 240 | }, 241 | "last_updated": str(datetime.now()), 242 | "claude_batch_metadata": claude_batch_metadata 243 | } 244 | 245 | write_batch_submission_status(batch_submission_status) 246 | return batch_result.id 247 | 248 | def check_batch_status(batch_result_id, overwrite:bool=False): 249 | batch_submission_status = read_batch_submission_status() 250 | client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")) 251 | try: 252 | batch = client.beta.messages.batches.retrieve(batch_result_id) 253 | except anthropic.NotFoundError as e: 254 | return None 255 | batch_id = batch.id 256 | batch_status = batch.processing_status 257 | batch_desc = batch_submission_status[batch_id]["description"] if batch_id in batch_submission_status else "" 258 | batch_project_name = batch_submission_status[batch_id]["project"] if batch_id in batch_submission_status else "" 259 | if batch_submission_status[batch_id]["output_path"]: 260 | batch_output_path = batch_submission_status[batch_id]["output_path"] 261 | else: 262 | batch_output_path = f"./batch_results/{batch_id}.batch_results.jsonl" 263 | batch_submission_status[batch_id]["output_path"] = batch_output_path 264 | # print(f"{batch_id: <20} {batch_status: <20} {batch_project_name: <20} {batch_desc: <20}") 265 | claude_batch_metadata = batch.to_dict() 266 | for key in claude_batch_metadata: 267 | if isinstance(claude_batch_metadata[key], datetime): 268 | claude_batch_metadata[key] = str(claude_batch_metadata[key]) 269 | if batch_status == "ended": 270 | output_path = batch_output_path 271 | output_path = Path(output_path) 272 | output_path.parent.mkdir(parents=True, exist_ok=True) 273 | if output_path.exists() and not overwrite: 274 | print(f"File {output_path} already exists. 
Skipping writing to file.") 275 | else: 276 | if output_path.exists() and overwrite: 277 | print(f"Overwriting file {output_path}") 278 | 279 | # retrieve the results 280 | print(f"Downloading output file for batch {batch_id}") 281 | results = client.beta.messages.batches.results(batch_id) 282 | # iterate over the returned batch results and write them to output_path as JSONL 283 | with open(output_path, "wb") as f: 284 | for result in results: 285 | f.write(json.dumps(result.to_dict(), ensure_ascii=False).encode() + b"\n") 286 | print(f"Output file written to {output_path}") 287 | with open(output_path, "r") as f: 288 | results = [json.loads(line) for line in f.readlines()] 289 | all_result_types = [x['result']['type'] for x in results] 290 | if all(result_type in ("errored", "error") for result_type in all_result_types): 291 | batch_status = "errored" 292 | batch_submission_status[batch_id]["status"] = "errored" 293 | batch_submission_status[batch_id]["timeline"]["failed"] = str(datetime.now()) 294 | elif all(result_type == "canceled" for result_type in all_result_types): 295 | batch_status = "canceled" 296 | batch_submission_status[batch_id]["status"] = "canceled" 297 | batch_submission_status[batch_id]["timeline"]["canceled"] = str(datetime.now()) 298 | elif all(result_type == "expired" for result_type in all_result_types): 299 | batch_status = "expired" 300 | batch_submission_status[batch_id]["status"] = "expired" 301 | batch_submission_status[batch_id]["timeline"]["expired"] = str(datetime.now()) 302 | else: 303 | batch_submission_status[batch_id]["status"] = "completed" 304 | batch_submission_status[batch_id]["timeline"]["completed"] = str(datetime.now()) 305 | batch_submission_status[batch_id]["timeline"]["downloaded"] = str(datetime.now()) 306 | batch_submission_status[batch_id]["last_updated"] = str(datetime.now()) 307 | batch_submission_status[batch_id]["claude_batch_metadata"].update(claude_batch_metadata) 308 | else: 309 | batch_submission_status[batch_id]["status"] = batch_status 310 | batch_submission_status[batch_id]["last_updated"] = str(datetime.now()) 311 | batch_submission_status[batch_id]["claude_batch_metadata"].update(claude_batch_metadata) 312 | write_batch_submission_status(batch_submission_status) 313 | return batch_submission_status[batch_id] 314 | 315 | def get_batch_progress(batch_result_id): 316 | batch_status = check_batch_status(batch_result_id) 317 | num_succeeded = batch_status["claude_batch_metadata"]['request_counts']['succeeded'] 318 | num_processing = batch_status["claude_batch_metadata"]['request_counts']['processing'] 319 | num_errored = batch_status["claude_batch_metadata"]['request_counts']['errored'] 320 | num_expired = batch_status["claude_batch_metadata"]['request_counts']['expired'] 321 | num_canceled = batch_status["claude_batch_metadata"]['request_counts']['canceled'] 322 | num_total = num_succeeded + num_processing + num_errored + num_expired + num_canceled 323 | n = num_succeeded + num_errored + num_expired + num_canceled 324 | total = num_total 325 | tqdm_postfix = { 326 | "completed": num_succeeded, 327 | "processing": num_processing, 328 | "errored": num_errored, 329 | "expired": num_expired, 330 | "canceled": num_canceled 331 | } 332 | return n, total, tqdm_postfix, batch_status["status"] 333 | 334 | def get_batch_result(batch_id, generate_kwargs={}): 335 | batch_status = check_batch_status(batch_id) 336 | if batch_status["status"] == "completed": 337 | output_path = batch_status["output_path"] 338 | results = [] 339 | with open(output_path, "r") as f: 340 | raw_results = [json.loads(line) for 
line in f.readlines()] 341 | results = [] 342 | for item in raw_results: 343 | if item["result"]["type"] == "succeeded": 344 | results.append(item["result"]['message']["content"][0]['text']) 345 | else: 346 | results.append(None) 347 | print("Batch requests status counts:", batch_status["claude_batch_metadata"]['request_counts']) 348 | else: 349 | results = None 350 | return results 351 | 352 | def claude_batch_request( 353 | model_name:str, 354 | batch_messages:List[Union[str, List[str], List[dict]]], 355 | conv_system_msg:str=None, 356 | desc:str=None, 357 | detach:bool=False, 358 | **generate_kwargs 359 | ): 360 | if isinstance(batch_messages[0], str): 361 | batch_messages = [[{"role": "user", "content": message}] for message in batch_messages] 362 | elif isinstance(batch_messages[0], list): 363 | if isinstance(batch_messages[0][0], str): 364 | batch_messages = [[{ 365 | "role": "user" if i % 2 == 0 else "assistant", 366 | "content": message 367 | } for i, message in enumerate(messages)] for messages in batch_messages] 368 | elif isinstance(batch_messages[0][0], dict): 369 | assert all("role" in message for message in batch_messages[0]), "Error: role key not found in the message" 370 | assert all("content" in message for message in batch_messages[0]), "Error: content key not found in the message" 371 | else: 372 | raise ValueError("Error: unknown message format") 373 | else: 374 | raise ValueError("Error: unknown message format") 375 | if conv_system_msg: 376 | batch_messages = [[{"role": "system", "content": conv_system_msg}] + messages for messages in batch_messages] 377 | if "stream" in generate_kwargs: 378 | generate_kwargs.pop("stream") 379 | batch_file = save_batch_file(batch_messages, model_name, **generate_kwargs) 380 | batch_result_id = submit_batch_file(batch_file) 381 | if detach: 382 | return batch_result_id 383 | num_total = len(batch_messages) 384 | tqdm_bar = tqdm(total=num_total, desc=desc or "LLMEngine Batch Inference") 385 | while True: 386 | batch_status = check_batch_status(batch_result_id) 387 | assert batch_status is not None, f"Error: batch {batch_result_id} not found in Anthropic API" 388 | num_succeeded = batch_status["claude_batch_metadata"]['request_counts']['succeeded'] 389 | num_processing = batch_status["claude_batch_metadata"]['request_counts']['processing'] 390 | num_errored = batch_status["claude_batch_metadata"]['request_counts']['errored'] 391 | num_expired = batch_status["claude_batch_metadata"]['request_counts']['expired'] 392 | num_canceled = batch_status["claude_batch_metadata"]['request_counts']['canceled'] 393 | num_total = num_succeeded + num_processing + num_errored + num_expired + num_canceled 394 | assert num_total == len(batch_messages), f"Error: total number of requests {num_total} does not match the number of requests {len(batch_messages)}" 395 | tqdm_bar.n = num_succeeded + num_errored + num_expired + num_canceled 396 | tqdm_bar.total = num_total 397 | # tqdm_bar.set_postfix(completed=num_completed, total=num_total, failed=num_failed) 398 | tqdm_bar.set_postfix( 399 | completed=num_succeeded, 400 | processing=num_processing, 401 | errored=num_errored, 402 | expired=num_expired, 403 | canceled=num_canceled 404 | ) 405 | if batch_status["status"] in ["completed", "errored", "expired", "canceled"]: 406 | tqdm_bar.close() 407 | break 408 | elif batch_status["status"] == "in_progress": 409 | tqdm_bar.desc = "In Progress" 410 | tqdm_bar.refresh() 411 | else: 412 | tqdm_bar.desc = batch_status["status"] 413 | tqdm_bar.refresh() 414 | 
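# --- Illustrative note (annotation, not part of claude.py) ---
# This loop blocks until the submitted batch reaches a terminal state. With detach=True,
# claude_batch_request returns the batch id right after submission instead, and a caller can
# poll the helpers defined above on its own schedule, e.g.:
#
#   import time
#   batch_id = claude_batch_request("claude-3-opus-20240229", ["Hello", "What is 2+2?"], detach=True)
#   while True:
#       n, total, postfix, status = get_batch_progress(batch_id)
#       if status in ("completed", "errored", "expired", "canceled"):
#           break
#       time.sleep(10)
#   outputs = get_batch_result(batch_id)  # list of outputs once completed (None entries for failed requests)
# --- end of annotation ---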
time.sleep(random.randint(5, 10)) 415 | 416 | results = get_batch_result(batch_result_id) 417 | return results 418 | 419 | 420 | if __name__ == "__main__": 421 | from icecream import ic 422 | ic(call_worker_claude(["Hello", "Hi, I am claude", "What did I ask in the last response?"], "claude-3-opus-20240229")) -------------------------------------------------------------------------------- /llm_engines/cli.py: -------------------------------------------------------------------------------- 1 | import fire 2 | import random 3 | import os 4 | from typing import List, Union 5 | 6 | class Cli: 7 | def __init__(self): 8 | pass 9 | 10 | def serve(self, 11 | model_name: str, 12 | engine: str="vllm", 13 | num_gpus: int=1, 14 | gpu_ids: List[int]=None, 15 | dtype: str="auto", 16 | quantization: str=None, 17 | port: Union[int,None] =None, 18 | host: str="127.0.0.1", 19 | ) -> str: 20 | from . import set_do_cleanup 21 | set_do_cleanup(False) 22 | assert engine in ["vllm", "sglang"] 23 | if engine == "vllm": 24 | from .vllm import launch_vllm_worker 25 | launch_worker_func = launch_vllm_worker 26 | elif engine == "sglang": 27 | from .sglang import launch_sglang_worker 28 | launch_worker_func = launch_sglang_worker 29 | if port is None: 30 | print("Warning: port not provided, using random port in range 31000-32000") 31 | port = random.randint(31000, 32000) 32 | if isinstance(gpu_ids, int): 33 | gpu_ids = [gpu_ids] 34 | elif isinstance(gpu_ids, str): 35 | gpu_ids = [int(gpu_id) for gpu_id in gpu_ids.split(",")] 36 | else: 37 | gpu_ids = None 38 | workder_addr, proc = launch_worker_func( 39 | model_name=model_name, 40 | num_gpus=num_gpus, 41 | gpu_ids=gpu_ids, 42 | dtype=dtype, 43 | quantization=quantization, 44 | port=port, 45 | host=host 46 | ) 47 | 48 | def call_worker( 49 | self, 50 | model_name: str, 51 | worker_addr: str, 52 | prompt: str, 53 | engine: str="vllm", 54 | temperature: float=0.0, 55 | top_p: float=1.0, 56 | max_tokens: int=None, 57 | timeout: int=60, 58 | ) -> str: 59 | from . 
import ModelWorker 60 | call_worker_func = ModelWorker( 61 | model_name=model_name, 62 | worker_addrs=[worker_addr], 63 | engine=engine, 64 | use_cache=True, 65 | overwrite_cache=False 66 | ) 67 | response = call_worker_func([prompt], temperature=temperature, top_p=top_p, max_tokens=max_tokens, timeout=timeout) 68 | return response 69 | 70 | def clean_workers( 71 | self, 72 | engine: str="vllm", 73 | ): 74 | print("Note: This will kill all workers for the specified engine") 75 | if engine == "vllm": 76 | os.system("pkill -f vllm.entrypoints.openai.api_server") 77 | elif engine == "sglang": 78 | os.system("pkill -f sglang.launch_server") 79 | else: 80 | raise ValueError(f"Engine {engine} not supported") 81 | print("Workers cleaned") 82 | 83 | def main(): 84 | fire.Fire(Cli) 85 | 86 | if __name__ == "__main__": 87 | main() 88 | 89 | 90 | """ 91 | llm-engines serve "meta-llama/Meta-Llama-3-8B-Instruct" --engine vllm --num-gpus 1 --gpu-ids 2 --dtype auto 92 | llm-engines call-worker "meta-llama/Meta-Llama-3-8B-Instruct" "http://127.0.0.1:31845" "Hello" 93 | llm-engines clean-workers --engine vllm 94 | 95 | """ -------------------------------------------------------------------------------- /llm_engines/deepseek.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import hashlib 4 | import time 5 | import filelock 6 | import random 7 | import openai 8 | from datetime import datetime 9 | from openai import OpenAI 10 | from typing import List, Union 11 | from pathlib import Path 12 | from tqdm import tqdm 13 | 14 | # no image, multi-turn, do not use deepseek_generate, but can refer to it 15 | def call_worker_deepseek(messages:List[str], model_name, timeout:int=60, conv_system_msg=None, **generate_kwargs) -> str: 16 | # change messages to openai format 17 | if conv_system_msg: 18 | new_messages = [{"role": "system", "content": conv_system_msg}] + messages 19 | else: 20 | new_messages = messages 21 | # initialize openai client 22 | client = OpenAI(api_key=os.environ["DEEPSEEK_API_KEY"], base_url="https://api.deepseek.com") 23 | # call deepseek 24 | completion = client.chat.completions.create( 25 | model=model_name, 26 | messages=new_messages, 27 | timeout=timeout, 28 | **generate_kwargs, 29 | ) 30 | stream = generate_kwargs.get("stream", False) 31 | 32 | if "logprobs" in generate_kwargs: 33 | return_logprobs = True 34 | 35 | if not stream: 36 | if "logprobs" not in generate_kwargs or not generate_kwargs["logprobs"]: 37 | if len(completion.choices) > 1: 38 | return [c.message.content for c in completion.choices] 39 | else: 40 | return completion.choices[0].message.content 41 | else: 42 | if len(completion.choices) > 1: 43 | return [c.message.content for c in completion.choices], [c.logprobs.dict() for c in completion.choices] 44 | else: 45 | return completion.choices[0].message.content, completion.choices[0].logprobs.dict() 46 | else: 47 | def generate_stream(): 48 | for chunk in completion: 49 | if chunk.choices[0].delta.content is not None: 50 | yield chunk.choices[0].delta.content 51 | return generate_stream() 52 | 53 | def call_worker_deepseek_completion(prompt:str, model_name, timeout:int=60, **generate_kwargs) -> str: 54 | # initialize openai client 55 | client = OpenAI(api_key=os.environ["DEEPSEEK_API_KEY"], base_url="https://api.deepseek.com") 56 | # call deepseek 57 | print(generate_kwargs) 58 | if "max_tokens" not in generate_kwargs: 59 | generate_kwargs["max_tokens"] = 256 # have to set max_tokens to be explicit 60 | 
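# --- Illustrative note (annotation, not part of deepseek.py) ---
# Both helpers in this file read DEEPSEEK_API_KEY from the environment and forward the messages
# and keyword arguments unchanged to DeepSeek's OpenAI-compatible endpoint, so messages should
# already be in OpenAI chat format, e.g. (model name taken from the __main__ block below):
#
#   messages = [{"role": "user", "content": "Hello"}]
#   text = call_worker_deepseek(messages, "deepseek-reasoner")
#   for chunk in call_worker_deepseek(messages, "deepseek-reasoner", stream=True):
#       print(chunk, end="", flush=True)
# --- end of annotation ---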
completion = client.completions.create( 61 | model=model_name, 62 | prompt=prompt, 63 | timeout=timeout, 64 | **generate_kwargs, 65 | ) 66 | stream = generate_kwargs.get("stream", False) 67 | if not stream: 68 | if len(completion.choices) > 1: 69 | return [c.text for c in completion.choices] 70 | else: 71 | return completion.choices[0].text 72 | else: 73 | def generate_stream(): 74 | for chunk in completion: 75 | if chunk.choices[0].text is not None: 76 | yield chunk.choices[0].text 77 | return generate_stream() 78 | 79 | if __name__ == "__main__": 80 | from icecream import ic 81 | ic(call_worker_deepseek(["Hello"], "deepseek-reasoner")) 82 | ic(call_worker_deepseek_completion("Hello", "deepseek-reasoner")) -------------------------------------------------------------------------------- /llm_engines/fireworks.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Union 3 | from fireworks.client import Fireworks 4 | 5 | # no image, multi-turn 6 | def call_worker_fireworks(messages:List[str], model_name, timeout:int=60, conv_system_msg=None, **generate_kwargs) -> str: 7 | # change messages to openai format 8 | if conv_system_msg: 9 | new_messages = [{"role": "system", "content": conv_system_msg}] + messages 10 | else: 11 | new_messages = messages 12 | # initialize openai client 13 | client = Fireworks(api_key=os.environ["FIREWORKS_API_KEY"]) 14 | # call fireworks 15 | completion = client.chat.completions.create( 16 | model=model_name, 17 | messages=new_messages, 18 | **generate_kwargs, 19 | ) 20 | stream = generate_kwargs.get("stream", False) 21 | 22 | if "logprobs" in generate_kwargs: 23 | return_logprobs = True 24 | 25 | if not stream: 26 | if "logprobs" not in generate_kwargs or not generate_kwargs["logprobs"]: 27 | if len(completion.choices) > 1: 28 | return [c.message.content for c in completion.choices] 29 | else: 30 | return completion.choices[0].message.content 31 | else: 32 | if len(completion.choices) > 1: 33 | return [c.message.content for c in completion.choices], [c.logprobs.dict() for c in completion.choices] 34 | else: 35 | return completion.choices[0].message.content, completion.choices[0].logprobs.dict() 36 | else: 37 | def generate_stream(): 38 | for chunk in completion: 39 | if chunk.choices[0].delta.content is not None: 40 | yield chunk.choices[0].delta.content 41 | return generate_stream() 42 | 43 | def call_worker_fireworks_completion(prompt:str, model_name, timeout:int=60, **generate_kwargs) -> str: 44 | # initialize openai client 45 | client = Fireworks(api_key=os.environ["FIREWORKS_API_KEY"]) 46 | # call fireworks 47 | print(generate_kwargs) 48 | if "max_tokens" not in generate_kwargs: 49 | generate_kwargs["max_tokens"] = 256 # have to set max_tokens to be explicit 50 | completion = client.completions.create( 51 | model=model_name, 52 | prompt=prompt, 53 | **generate_kwargs, 54 | ) 55 | stream = generate_kwargs.get("stream", False) 56 | if not stream: 57 | if len(completion.choices) > 1: 58 | return [c.text for c in completion.choices] 59 | else: 60 | return completion.choices[0].text 61 | else: 62 | def generate_stream(): 63 | for chunk in completion: 64 | if chunk.choices[0].text is not None: 65 | yield chunk.choices[0].text 66 | return generate_stream() 67 | 68 | if __name__ == "__main__": 69 | from icecream import ic 70 | generate_kwargs = { 71 | "top_p": 1, 72 | "top_k": 40, 73 | "presence_penalty": 0, 74 | "frequency_penalty": 0, 75 | "temperature": 0.6, 76 | "max_tokens": 20480 77 | } 78 
| ic(call_worker_fireworks(["Hello"], "accounts/fireworks/models/deepseek-r1", **generate_kwargs)) 79 | ic(call_worker_fireworks_completion("Hello", "accounts/fireworks/models/deepseek-r1", **generate_kwargs)) -------------------------------------------------------------------------------- /llm_engines/gemini.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from typing import List 4 | import google.ai.generativelanguage as glm 5 | import google.generativeai as genai 6 | from google.api_core.exceptions import ServiceUnavailable, ResourceExhausted 7 | from .utils import with_timeout, decode_base64_image_url 8 | genai.configure(api_key=os.environ.get("GEMINI_API_KEY")) 9 | safety_settings = [ 10 | { 11 | "category": "HARM_CATEGORY_DANGEROUS", 12 | "threshold": "BLOCK_NONE", 13 | }, 14 | { 15 | "category": "HARM_CATEGORY_HARASSMENT", 16 | "threshold": "BLOCK_NONE", 17 | }, 18 | { 19 | "category": "HARM_CATEGORY_HATE_SPEECH", 20 | "threshold": "BLOCK_NONE", 21 | }, 22 | { 23 | "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", 24 | "threshold": "BLOCK_NONE", 25 | }, 26 | { 27 | "category": "HARM_CATEGORY_DANGEROUS_CONTENT", 28 | "threshold": "BLOCK_NONE", 29 | }, 30 | ] 31 | # no image, multi-turn, do not use openai_generate, but can refer to it 32 | def call_worker_gemini(messages:List[str], model_name, timeout:int=60, conv_system_msg=None, **generate_kwargs) -> str: 33 | # change messages to gemini format 34 | model = genai.GenerativeModel(model_name, system_instruction=conv_system_msg) 35 | 36 | new_messages = [] 37 | role_map = {"user": "user", "assistant": "model"} 38 | for i, message in enumerate(messages): 39 | role = role_map[message["role"]] 40 | if isinstance(message["content"], str): 41 | new_messages.append({"role": role, "parts": [glm.Part(text=message["content"])]}) 42 | elif isinstance(message["content"], list): 43 | parts = [] 44 | for sub_message in message["content"]: 45 | if sub_message["type"] == "text": 46 | parts.append(glm.Part(text=sub_message["text"])) 47 | elif sub_message["type"] == "image_url": 48 | try: 49 | image = decode_base64_image_url(sub_message["image_url"]['url']) 50 | except Exception as e: 51 | image = sub_message["image_url"]['url'] 52 | parts.append(image) 53 | else: 54 | raise ValueError("Invalid message format") 55 | new_messages.append({"role": role, "parts": parts}) 56 | else: 57 | raise ValueError("Invalid message format") 58 | 59 | stream = generate_kwargs.pop("stream", False) 60 | generation_config = genai.GenerationConfig( 61 | candidate_count=generate_kwargs.get("num_return_sequences", None), 62 | stop_sequences=generate_kwargs.get("stop", None), 63 | max_output_tokens=generate_kwargs.get("max_tokens", None), 64 | temperature=generate_kwargs.get("temperature", None), 65 | top_p=generate_kwargs.get("top_p", None), 66 | top_k=generate_kwargs.get("top_k", None), 67 | response_mime_type=generate_kwargs.get("response_mime_type", None), 68 | response_schema=generate_kwargs.get("response_schema", None), 69 | ) 70 | request_options = genai.types.RequestOptions( 71 | timeout=timeout, 72 | ) 73 | if "logprobs" in generate_kwargs: 74 | raise ValueError("logprobs is not supported in gemini") 75 | @with_timeout(timeout) 76 | def generate_content(): 77 | return model.generate_content(new_messages, safety_settings=safety_settings, generation_config=generation_config, request_options=request_options, stream=stream) 78 | while True: 79 | try: 80 | response = generate_content() 81 | break 82 | 
except ServiceUnavailable as e: 83 | # sleep for a while and retry 84 | # print("ServiceUnavailable, retrying...") 85 | time.sleep(2) 86 | continue 87 | except ResourceExhausted as e: 88 | # sleep for a while and retry 89 | # print("ResourceExhausted, retrying...") 90 | time.sleep(10) 91 | continue 92 | try: 93 | if not stream: 94 | return response.text 95 | else: 96 | def generate_stream(): 97 | for chunk in response: 98 | yield chunk.text 99 | return generate_stream() 100 | except ValueError as e: 101 | print(f"Empty response from gemini due to {e}") 102 | return None 103 | 104 | if __name__ == "__main__": 105 | from icecream import ic 106 | ic(call_worker_gemini(["Hello", "Hi, I am gemini", "What did I ask in the last response?"], "gemini-1.5-flash")) 107 | -------------------------------------------------------------------------------- /llm_engines/grok.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import hashlib 4 | import time 5 | import filelock 6 | import random 7 | import openai 8 | from datetime import datetime 9 | from openai import OpenAI 10 | from typing import List, Union 11 | from pathlib import Path 12 | from tqdm import tqdm 13 | from .cache import get_printable_messages 14 | 15 | # no image, multi-turn, do not use grok_generate, but can refer to it 16 | def call_worker_grok(messages:List[str], model_name, timeout:int=60, conv_system_msg=None, **generate_kwargs) -> str: 17 | # change messages to openai format 18 | if conv_system_msg: 19 | new_messages = [{"role": "system", "content": conv_system_msg}] + messages 20 | else: 21 | new_messages = messages 22 | # initialize openai client 23 | client = OpenAI(api_key=os.environ["XAI_API_KEY"], base_url="https://api.x.ai/v1") 24 | # call grok 25 | completion = client.chat.completions.create( 26 | model=model_name, 27 | messages=new_messages, 28 | timeout=timeout, 29 | **generate_kwargs, 30 | ) 31 | stream = generate_kwargs.get("stream", False) 32 | 33 | if "logprobs" in generate_kwargs: 34 | return_logprobs = True 35 | 36 | if not stream: 37 | if "logprobs" not in generate_kwargs or not generate_kwargs["logprobs"]: 38 | if len(completion.choices) > 1: 39 | return [c.message.content for c in completion.choices] 40 | else: 41 | return completion.choices[0].message.content 42 | else: 43 | if len(completion.choices) > 1: 44 | return [c.message.content for c in completion.choices], [c.logprobs.dict() for c in completion.choices] 45 | else: 46 | return completion.choices[0].message.content, completion.choices[0].logprobs.dict() 47 | else: 48 | def generate_stream(): 49 | for chunk in completion: 50 | if chunk.choices[0].delta.content is not None: 51 | yield chunk.choices[0].delta.content 52 | return generate_stream() 53 | 54 | def call_worker_grok_completion(prompt:str, model_name, timeout:int=60, **generate_kwargs) -> str: 55 | # initialize openai client 56 | client = OpenAI(api_key=os.environ["XAI_API_KEY"], base_url="https://api.x.ai/v1") 57 | # call grok 58 | if "max_tokens" not in generate_kwargs: 59 | generate_kwargs["max_tokens"] = 256 # have to set max_tokens to be explicit 60 | completion = client.completions.create( 61 | model=model_name, 62 | prompt=prompt, 63 | timeout=timeout, 64 | **generate_kwargs, 65 | ) 66 | stream = generate_kwargs.get("stream", False) 67 | if not stream: 68 | if len(completion.choices) > 1: 69 | return [c.text for c in completion.choices] 70 | else: 71 | return completion.choices[0].text 72 | else: 73 | def generate_stream(): 74 
| for chunk in completion: 75 | if chunk.choices[0].text is not None: 76 | yield chunk.choices[0].text 77 | return generate_stream() 78 | 79 | if __name__ == "__main__": 80 | from icecream import ic 81 | ic(call_worker_grok(["Hello"], "grok-2-latest")) 82 | ic(call_worker_grok_completion("Hello", "grok-2-latest")) -------------------------------------------------------------------------------- /llm_engines/mistral.py: -------------------------------------------------------------------------------- 1 | import os 2 | from mistralai import Mistral 3 | from typing import List 4 | from .utils import with_timeout 5 | 6 | # no image, multi-turn, do not use openai_generate, but can refer to it 7 | def call_worker_mistral(messages:List[str], model_name, timeout:int=120, conv_system_msg=None, **generate_kwargs) -> str: 8 | client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY")) 9 | # change messages to openai format 10 | if conv_system_msg: 11 | new_messages = [{"role": "system", "content": conv_system_msg}] + messages 12 | else: 13 | new_messages = messages 14 | 15 | if "n" in generate_kwargs: 16 | generate_kwargs.pop("n") # mistral does not have n 17 | if "logprobs" in generate_kwargs: 18 | raise ValueError("logprobs is not supported in mistral") 19 | stream = generate_kwargs.pop("stream", False) 20 | @with_timeout(timeout) 21 | def generate_content(): 22 | completion = client.chat.complete( 23 | model=model_name, 24 | messages=new_messages, 25 | **generate_kwargs, 26 | ) 27 | if len(completion.choices) > 1: 28 | return [c.message.content for c in completion.choices] 29 | else: 30 | return completion.choices[0].message.content 31 | 32 | @with_timeout(timeout) 33 | def stream_content(): 34 | completion = client.chat.stream( 35 | model=model_name, 36 | messages=new_messages, 37 | **generate_kwargs, 38 | ) 39 | def generate_stream(): 40 | for chunk in completion: 41 | if chunk.data.choices[0].delta.content is not None: 42 | yield chunk.data.choices[0].delta.content 43 | return generate_stream() 44 | if not stream: 45 | return generate_content() 46 | else: 47 | return stream_content() 48 | 49 | if __name__ == "__main__": 50 | from icecream import ic 51 | ic(call_worker_mistral(["Hello", "Hi, I am mistral", "What did I ask in the last response?"], "mistral-large-latest")) -------------------------------------------------------------------------------- /llm_engines/openai_text.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import hashlib 4 | import time 5 | import filelock 6 | import random 7 | import openai 8 | from datetime import datetime 9 | from openai import OpenAI 10 | from typing import List, Union 11 | from pathlib import Path 12 | from tqdm import tqdm 13 | 14 | 15 | batch_submission_status_file = Path(os.path.expanduser(f"~/llm_engines/generation_cache")) / "openai_batch_cache" / "batch_submission_status.json" 16 | batch_submission_status_file.parent.mkdir(parents=True, exist_ok=True) 17 | 18 | def read_batch_submission_status(): 19 | print("Loading batch submission status from", batch_submission_status_file) 20 | if batch_submission_status_file.exists(): 21 | 22 | lock = filelock.FileLock(str(batch_submission_status_file) + ".lock", timeout=30) 23 | try: 24 | with lock: 25 | with open(batch_submission_status_file, "r") as f: 26 | batch_submission_status = json.load(f) 27 | except filelock.Timeout as e: 28 | print("Timeout acquiring lock") 29 | raise e 30 | else: 31 | batch_submission_status = {} 32 | return 
batch_submission_status 33 | 34 | def write_batch_submission_status(batch_submission_status): 35 | lock = filelock.FileLock(str(batch_submission_status_file) + ".lock", timeout=30) 36 | try: 37 | with lock: 38 | with open(batch_submission_status_file, "w") as f: 39 | json.dump(batch_submission_status, f, indent=4) 40 | except filelock.Timeout as e: 41 | print("Timeout acquiring lock") 42 | raise e 43 | 44 | 45 | # no image, multi-turn, do not use openai_generate, but can refer to it 46 | def call_worker_openai(messages:List[str], model_name, timeout:int=60, conv_system_msg=None, **generate_kwargs) -> str: 47 | # change messages to openai format 48 | if conv_system_msg: 49 | new_messages = [{"role": "system", "content": conv_system_msg}] + messages 50 | else: 51 | new_messages = messages 52 | # initialize openai client 53 | client = OpenAI() 54 | o_series = ["o1", "o3"] 55 | if any(o in model_name for o in o_series): 56 | # fixed parameters for openai o1 models 57 | max_tokens = generate_kwargs.pop("max_tokens", None) 58 | if max_tokens is not None: 59 | generate_kwargs["max_completion_tokens"] = max_tokens 60 | generate_kwargs['temperature'] = 1 61 | generate_kwargs['top_p'] = 1 62 | generate_kwargs['n'] = 1 63 | generate_kwargs['frequency_penalty'] = 0 64 | generate_kwargs['presence_penalty'] = 0 65 | # call openai 66 | completion = client.chat.completions.create( 67 | model=model_name, 68 | messages=new_messages, 69 | timeout=timeout, 70 | **generate_kwargs, 71 | ) 72 | stream = generate_kwargs.get("stream", False) 73 | 74 | if "logprobs" in generate_kwargs: 75 | return_logprobs = True 76 | 77 | if not stream: 78 | if "logprobs" not in generate_kwargs or not generate_kwargs["logprobs"]: 79 | if len(completion.choices) > 1: 80 | return [c.message.content for c in completion.choices] 81 | else: 82 | return completion.choices[0].message.content 83 | else: 84 | if len(completion.choices) > 1: 85 | return [c.message.content for c in completion.choices], [c.logprobs.dict() for c in completion.choices] 86 | else: 87 | return completion.choices[0].message.content, completion.choices[0].logprobs.dict() 88 | else: 89 | def generate_stream(): 90 | for chunk in completion: 91 | if chunk.choices[0].delta.content is not None: 92 | yield chunk.choices[0].delta.content 93 | return generate_stream() 94 | 95 | def call_worker_openai_completion(prompt:str, model_name, timeout:int=60, **generate_kwargs) -> str: 96 | # initialize openai client 97 | client = OpenAI() 98 | # call openai 99 | completion = client.completions.create( 100 | model=model_name, 101 | prompt=prompt, 102 | timeout=timeout, 103 | **generate_kwargs, 104 | ) 105 | stream = generate_kwargs.get("stream", False) 106 | if not stream: 107 | if len(completion.choices) > 1: 108 | return [c.text for c in completion.choices] 109 | else: 110 | return completion.choices[0].text 111 | else: 112 | def generate_stream(): 113 | for chunk in completion: 114 | if chunk.choices[0].text is not None: 115 | yield chunk.choices[0].text 116 | return generate_stream() 117 | 118 | def save_batch_file(batch_messages:List[Union[str, dict]], model:str, batch_name:str=None, cache_dir=None, custom_ids=None, **generate_kwargs): 119 | if isinstance(batch_messages[0], str): 120 | batch_messages = [{"role": "user", "content": message} for message in batch_messages] 121 | 122 | if cache_dir is None: 123 | cache_dir = Path(os.path.expanduser(f"~/llm_engines/generation_cache")) / "openai_batch_cache" 124 | else: 125 | cache_dir = Path(cache_dir) / "openai_batch_cache" 126 | 
cache_dir.mkdir(parents=True, exist_ok=True) 127 | if custom_ids is None: 128 | custom_ids = [f"request-{i+1}" for i in range(len(batch_messages))] 129 | assert len(custom_ids) == len(batch_messages) 130 | if batch_name is None: 131 | # batch_name = f"{model}_{datetime.now().strftime('%Y%m%d_%H%M%S')}" 132 | hash_str = "".join([str(message) for message in batch_messages]) 133 | batch_name = f"{model}_{hashlib.md5(hash_str.encode()).hexdigest()}" 134 | file_path = cache_dir / f"{batch_name}.jsonl" 135 | with open(file_path, "w") as f: 136 | for custom_id, message in zip(custom_ids, batch_messages): 137 | if isinstance(message, str): 138 | message = {"role": "user", "content": message} 139 | f.write(json.dumps({"custom_id": custom_id, "method": "POST", "url": "/v1/chat/completions", "body": {"model": model, "messages": message, **generate_kwargs}}, ensure_ascii=False) + "\n") 140 | return file_path 141 | 142 | def submit_batch_file(batch_file:str, output_path:str=None, project_name:str=None, description:str=None): 143 | """ 144 | Submit a batch of queries to OpenAI API 145 | Args: 146 | batch_file: str, path to the batch jsonl file 147 | output_path: str, path to the output file; if not provided, will be saved in the same directory as the batch_file, ending with ".batch_results.jsonl" 148 | project_name: str, project name, default to "vapo", keyword to filter batches 149 | description: str, description of the batch, default to the batch_file name, can also be used to filter batches 150 | Returns: 151 | batch_result_id: str, the id of the batch submission 152 | """ 153 | if isinstance(batch_file, str): 154 | batch_file = Path(batch_file) 155 | 156 | # internally maintain a batch submission status json 157 | batch_submission_status = read_batch_submission_status() 158 | 159 | client = OpenAI() 160 | batch_input_file = client.files.create( 161 | file=open(batch_file, "rb"), 162 | purpose="batch" 163 | ) 164 | if not project_name: 165 | project_name = "llm_engines" 166 | description = description if description is not None else batch_file.stem 167 | output_path = output_path if output_path is not None else batch_file.parent / f"{description}.batch_results.jsonl" 168 | batch_input_file_id = batch_input_file.id 169 | 170 | batch_file_hash = hashlib.md5(batch_file.read_bytes()).hexdigest() 171 | batch_file_size = batch_file.stat().st_size 172 | 173 | for key, value in batch_submission_status.items(): 174 | value_input_file_metadata = value["input_path_metadata"] 175 | if not os.path.samefile(value["input_path"], batch_file): 176 | # print(f"Batch {key} has a different input file. need to resubmit.") 177 | continue 178 | if batch_file_size != value_input_file_metadata["size"]: 179 | # print(f"Batch {key} has a newer version of the input file. need to resubmit.") 180 | continue 181 | if batch_file_hash != value_input_file_metadata["hash"]: 182 | # print(f"Batch {key} has a different input hash. need to resubmit.") 183 | continue 184 | if value['status'] in ["validating", "in_progress", "finalizing", "completed"]: 185 | print(f"Batch {key} is still in progress. 
Skipping submission.") 186 | return key 187 | else: 188 | continue 189 | 190 | batch_result = None 191 | for batch in client.batches.list(limit=10): 192 | if batch.metadata and (batch.metadata.get('input_path') == str(batch_file)) and batch.status in ["validating", "in_progress", "finalizing", "completed"]: 193 | batch_result = batch 194 | break 195 | 196 | completion_window = "24h" 197 | endpoint = "/v1/chat/completions" 198 | if batch_result is None: 199 | batch_result = client.batches.create( 200 | input_file_id=batch_input_file_id, 201 | endpoint=endpoint, 202 | completion_window=completion_window, 203 | metadata={ 204 | "project": project_name, 205 | "description": description, 206 | "input_path": str(batch_file), 207 | "output_path": str(output_path) 208 | } 209 | ) 210 | print(f"Batch {batch_result.id} submitted") 211 | submit_time = str(datetime.now()) 212 | else: 213 | print(f"Batch already exists for {batch_file}, but not found in the managing file, writing to {batch_submission_status_file}") 214 | submit_time = batch_result.created_at 215 | 216 | # time should be in the current timezone, in the format like 2022-01-01T00:00:00 217 | batch_submission_status[batch_result.id] = { 218 | "project": project_name, 219 | "description": description, 220 | "endpoint": endpoint, 221 | "completion_window": completion_window, 222 | "input_path": str(batch_file), 223 | "input_path_metadata": { 224 | "hash": hashlib.md5(batch_file.read_bytes()).hexdigest(), 225 | "size": batch_file.stat().st_size 226 | }, 227 | "output_path": str(output_path), 228 | "batch_input_file_id": batch_input_file_id, 229 | "batch_result_id": batch_result.id, 230 | "status": batch_result.status, 231 | "timeline": { 232 | "submitted": submit_time, 233 | "completed": None, 234 | "failed": None, 235 | "downloaded": None 236 | }, 237 | "last_updated": str(datetime.now()), 238 | "openai_batch_metadata": batch_result.to_dict() 239 | } 240 | 241 | write_batch_submission_status(batch_submission_status) 242 | return batch_result.id 243 | 244 | def check_batch_status(batch_id, overwrite:bool=False): 245 | # internally maintain a batch submission status json 246 | batch_submission_status = read_batch_submission_status() 247 | if batch_id in batch_submission_status: 248 | batch_status = batch_submission_status[batch_id]["status"] 249 | else: 250 | client = OpenAI() 251 | try: 252 | batch = client.batches.retrieve(batch_id) 253 | except openai.error.NotFoundError: 254 | print(f"Batch {batch_id} not found.") 255 | return None 256 | if batch_status == "completed": 257 | output_path = Path(batch_submission_status[batch_id]["output_path"]) 258 | if output_path.exists() and not overwrite: 259 | print(f"Output file {output_path} already exists. 
Skipping writing.") 260 | else: 261 | print(f"Downloading output file for batch {batch_id}") 262 | client = OpenAI() 263 | batch = client.batches.retrieve(batch_id) 264 | batch_id = batch.id 265 | batch_status = batch.status 266 | if batch.metadata is not None and "output_path" in batch.metadata: 267 | batch_output_path = batch.metadata["output_path"] 268 | else: 269 | batch_output_path = f"./batch_results/{batch_id}.batch_results.jsonl" 270 | batch_submission_status[batch_id]["output_path"] = batch_output_path 271 | content = client.files.content(batch.output_file_id) 272 | output_path = batch_output_path 273 | output_path = Path(output_path) 274 | output_path.parent.mkdir(parents=True, exist_ok=True) 275 | if output_path.exists() and overwrite: 276 | print(f"Overwriting file {output_path}") 277 | content.write_to_file(output_path) 278 | print(f"Output file written to {output_path}") 279 | batch_submission_status[batch_id]["status"] = "completed" 280 | batch_submission_status[batch_id]["timeline"]["completed"] = str(datetime.now()) 281 | batch_submission_status[batch_id]["timeline"]["downloaded"] = str(datetime.now()) 282 | batch_submission_status[batch_id]["openai_batch_metadata"].update(batch.to_dict()) 283 | else: 284 | client = OpenAI() 285 | batch = client.batches.retrieve(batch_id) 286 | batch_id = batch.id 287 | batch_status = batch.status 288 | batch_desc = batch.metadata["description"] if batch.metadata is not None and "description" in batch.metadata else "" 289 | batch_project_name = batch.metadata["project"] if batch.metadata is not None and "project" in batch.metadata else "" 290 | if batch.metadata is not None and "output_path" in batch.metadata: 291 | batch_output_path = batch.metadata["output_path"] 292 | else: 293 | batch_output_path = f"./batch_results/{batch_id}.batch_results.jsonl" 294 | batch_submission_status[batch_id]["output_path"] = batch_output_path 295 | # print(f"{batch_id: <20} {batch_status: <20} {batch_project_name: <20} {batch_desc: <20}") 296 | if batch_status == "failed": 297 | print(f"Batch {batch_id} failed.") 298 | batch_submission_status[batch_id]["status"] = "failed" 299 | batch_submission_status[batch_id]["timeline"]["failed"] = str(datetime.now()) 300 | batch_submission_status[batch_id]["openai_batch_metadata"].update(batch.to_dict()) 301 | else: 302 | batch_submission_status[batch_id]["status"] = batch_status 303 | batch_submission_status[batch_id]["openai_batch_metadata"].update(batch.to_dict()) 304 | 305 | batch_submission_status[batch_id]["last_updated"] = str(datetime.now()) 306 | write_batch_submission_status(batch_submission_status) 307 | return batch_submission_status[batch_id] 308 | 309 | def get_batch_progress(batch_id): 310 | batch_status = check_batch_status(batch_id) 311 | num_completed = batch_status["openai_batch_metadata"]['request_counts']['completed'] 312 | num_total = batch_status["openai_batch_metadata"]['request_counts']['total'] 313 | num_failed = batch_status["openai_batch_metadata"]['request_counts']['failed'] 314 | n = num_completed 315 | total = num_total 316 | tqdm_postfix = { 317 | "completed": num_completed, 318 | "total": num_total, 319 | "failed": num_failed 320 | } 321 | return n, total, tqdm_postfix, batch_status['status'] 322 | 323 | def get_batch_result(batch_id, generate_kwargs={}): 324 | batch_status = check_batch_status(batch_id) 325 | if not batch_status["status"] == "completed": 326 | return None 327 | output_path = batch_status["output_path"] 328 | 329 | if not os.path.exists(output_path): 330 | client = 
OpenAI() 331 | batch = client.batches.retrieve(batch_id) 332 | content = client.files.content(batch.output_file_id) 333 | output_path = Path(output_path) 334 | output_path.parent.mkdir(parents=True, exist_ok=True) 335 | print(f"Downloading output file for batch {batch_id}") 336 | content.write_to_file(output_path) 337 | 338 | results = [] 339 | with open(output_path, "r") as f: 340 | results = [json.loads(line) for line in f.readlines()] 341 | if "logprobs" not in generate_kwargs or not generate_kwargs["logprobs"]: 342 | all_completions = [[choice['message']['content'] for choice in x['response']['body']['choices']] for x in results] 343 | else: 344 | all_completions = [[(choice['message']['content'], choice['logprobs']) for choice in x['response']['body']['choices']] for x in results] 345 | if all(len(x) == 1 for x in all_completions): 346 | all_completions = [x[0] for x in all_completions] 347 | results = all_completions 348 | return results 349 | 350 | def openai_batch_request( 351 | model_name:str, 352 | batch_messages:List[Union[str, List[str], List[dict]]], 353 | conv_system_msg:str=None, 354 | desc:str=None, 355 | detach:bool=False, 356 | **generate_kwargs 357 | ): 358 | if isinstance(batch_messages[0], str): 359 | batch_messages = [[{"role": "user", "content": message}] for message in batch_messages] 360 | elif isinstance(batch_messages[0], list): 361 | if isinstance(batch_messages[0][0], str): 362 | batch_messages = [[{ 363 | "role": "user" if i % 2 == 0 else "assistant", 364 | "content": message 365 | } for i, message in enumerate(messages)] for messages in batch_messages] 366 | elif isinstance(batch_messages[0][0], dict): 367 | assert all("role" in message for message in batch_messages[0]), "Error: role key not found in the message" 368 | assert all("content" in message for message in batch_messages[0]), "Error: content key not found in the message" 369 | else: 370 | raise ValueError("Error: unknown message format") 371 | else: 372 | raise ValueError("Error: unknown message format") 373 | if conv_system_msg: 374 | batch_messages = [[{"role": "system", "content": conv_system_msg}] + messages for messages in batch_messages] 375 | if "stream" in generate_kwargs: 376 | generate_kwargs.pop("stream") 377 | batch_file = save_batch_file(batch_messages, model_name, **generate_kwargs) 378 | batch_result_id = submit_batch_file(batch_file) 379 | if detach: 380 | return batch_result_id 381 | num_total = len(batch_messages) 382 | tqdm_bar = tqdm(total=num_total, desc=desc or "LLMEngine Batch Inference") 383 | while True: 384 | batch_status = check_batch_status(batch_result_id) 385 | assert batch_status is not None, f"Error: {batch_result_id} not found in batch submission status or OpenAI API" 386 | num_completed = batch_status["openai_batch_metadata"]['request_counts']['completed'] 387 | num_total = batch_status["openai_batch_metadata"]['request_counts']['total'] 388 | num_failed = batch_status["openai_batch_metadata"]['request_counts']['failed'] 389 | tqdm_bar.n = num_completed 390 | tqdm_bar.total = num_total 391 | tqdm_bar.set_postfix(completed=num_completed, total=num_total, failed=num_failed) 392 | if batch_status["status"] == "completed": 393 | tqdm_bar.close() 394 | break 395 | elif batch_status["status"] == "finalizing": 396 | tqdm_bar.desc = "Finalizing" 397 | tqdm_bar.refresh() 398 | elif batch_status["status"] == "validating": 399 | tqdm_bar.desc = "Validating" 400 | tqdm_bar.refresh() 401 | elif batch_status["status"] == "failed": 402 | tqdm_bar.close() 403 | print("Batch failed") 
404 | break 405 | elif batch_status["status"] == "in_progress": 406 | tqdm_bar.desc = "In Progress" 407 | tqdm_bar.refresh() 408 | elif batch_status["status"] == "cancelled": 409 | tqdm_bar.close() 410 | print("Batch cancelled") 411 | break 412 | elif batch_status["status"] == "expired": 413 | tqdm_bar.close() 414 | print("Batch expired") 415 | break 416 | elif batch_status["status"] == "cancelling": 417 | tqdm_bar.desc = "Cancelling" 418 | tqdm_bar.refresh() 419 | else: 420 | tqdm_bar.desc = batch_status["status"] 421 | tqdm_bar.refresh() 422 | time.sleep(random.randint(5, 10)) 423 | 424 | results = get_batch_result(batch_result_id, generate_kwargs) 425 | return results 426 | 427 | if __name__ == "__main__": 428 | from icecream import ic 429 | ic(call_worker_openai(["Hello"], "gpt-3.5-turbo")) 430 | ic(call_worker_openai_completion("Hello", "gpt-3.5-turbo-instruct")) -------------------------------------------------------------------------------- /llm_engines/sglang.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import random 5 | import openai 6 | import importlib.util 7 | from pathlib import Path 8 | from typing import List 9 | from sglang import function, system, user, assistant, gen 10 | from .utils import SubprocessMonitor, ChatTokenizer, with_timeout, get_function_arg_names 11 | worker_initiated = False 12 | sglang_workers = {} 13 | def launch_sglang_worker( 14 | model_name: str, 15 | num_gpus: int=None, 16 | gpu_ids: List[int]=None, 17 | dtype: str="auto", 18 | quantization: str=None, 19 | port: int=34200, 20 | host: str="127.0.0.1", 21 | additional_args: List[str]=[] 22 | ) -> str: 23 | """ 24 | Launch a model worker and return the address 25 | Args: 26 | model_name: the model name to launch 27 | Returns: 28 | the address of the launched model 29 | """ 30 | # python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 31 | ### For debug 32 | # port, additonal_ports = allocate_init_ports(port) 33 | # print(f"Launching SGLang model {model_name} on port {port}") 34 | # print(f"Additional ports: {additonal_ports}") 35 | ### For debug 36 | worker_addr = f"http://{host}:{port}" 37 | log_file = Path(os.path.abspath(__file__)).parent / "logs" / f"{model_name}.log" 38 | log_file.parent.mkdir(parents=True, exist_ok=True) 39 | if gpu_ids: 40 | num_gpus = len(gpu_ids) 41 | else: 42 | if not num_gpus: 43 | num_gpus = torch.cuda.device_count() 44 | print(f"Warning: num_gpus or gpu_ids not provided, using {num_gpus} GPUs") 45 | gpu_ids = list(range(num_gpus)) 46 | env = os.environ.copy() 47 | # Set the CUDA_VISIBLE_DEVICES environment variable 48 | env["CUDA_VISIBLE_DEVICES"] = ",".join([str(gpu_id) for gpu_id in gpu_ids]) 49 | 50 | # check flashinfer 51 | flashinfer = importlib.util.find_spec("flashinfer") 52 | if flashinfer is None: 53 | print("flashinfer not found, please first install flashinfer for sglang") 54 | print("Simple Command: pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/") 55 | print("Please refer to https://docs.flashinfer.ai/installation.html for detailed installation instructions") 56 | exit(1) 57 | else: 58 | print("flashinfer found, enable flashinfer for sglang") 59 | flashinfer_args = [] 60 | if quantization: 61 | available_quantizations = "awq,fp8,gptq,marlin,gptq_marlin,awq_marlin,squeezellm,bitsandbytes" 62 | available_quantizations = available_quantizations.split(",") 63 | if quantization not in available_quantizations: 64 | raise 
ValueError(f"Quantization {quantization} not supported, available quantizations: {available_quantizations}") 65 | flashinfer_args = ["--quantization", quantization] 66 | # additonal_ports = [port+i for i in range(1, 9)] 67 | proc = SubprocessMonitor([ 68 | "python3", "-m", "sglang.launch_server", 69 | "--model-path", model_name, 70 | "--host", host, 71 | "--port", str(port), 72 | "--dtype", dtype, 73 | # "--api-key", "sglang", 74 | "--log-level", "warning", 75 | "--tp-size", str(num_gpus) if num_gpus is not None else "1", 76 | # "--additional-ports"] + [str(port) for port in additonal_ports 77 | ] + flashinfer_args + additional_args ,env=env) 78 | print(f"Launching SGLang model {model_name} with CUDA_VISIBLE_DEVICES={env['CUDA_VISIBLE_DEVICES']}") 79 | sglang_workers[worker_addr] = proc 80 | return worker_addr, proc 81 | 82 | @function 83 | def multi_turn_question(s, messages, system_message=None): 84 | if system_message: 85 | s += system(system_message) 86 | for i, message in enumerate(messages): 87 | if i % 2 == 0: 88 | s += user(message) 89 | else: 90 | s += assistant(message) 91 | s += assistant(gen("answer")) 92 | 93 | @function 94 | def question(s, prompt): 95 | s += prompt 96 | s += gen("answer") 97 | 98 | chat_tokenizers = {} 99 | def call_sglang_worker(messages, model_name, worker_addrs, timeout:int=300, conv_system_msg=None, **generate_kwargs) -> str: 100 | global worker_initiated 101 | global chat_tokenizers 102 | 103 | if model_name not in chat_tokenizers: 104 | chat_tokenizers[model_name] = ChatTokenizer(model_name) 105 | chat_tokenizer = chat_tokenizers[model_name] 106 | 107 | # change messages to openai format 108 | if conv_system_msg: 109 | chat_messages = [{"role": "system", "content": conv_system_msg}] + messages 110 | else: 111 | chat_messages = messages 112 | 113 | # prompt = chat_tokenizer(chat_messages) 114 | 115 | worker_addr = random.choice(worker_addrs) 116 | 117 | client = openai.OpenAI( 118 | base_url=f"{worker_addr}/v1", 119 | api_key="sglang-engine-token", 120 | ) 121 | 122 | generate_kwargs['max_tokens'] = generate_kwargs.get('max_tokens', chat_tokenizer.max_length) # for sglang, max_tokens is required and must > 0 123 | args_names, kwargs_names = get_function_arg_names(client.chat.completions.create) 124 | extra_body_params = {} 125 | for key in list(generate_kwargs.keys()): 126 | if key not in args_names + kwargs_names: 127 | extra_body_params[key] = generate_kwargs[key] 128 | del generate_kwargs[key] 129 | generate_kwargs["extra_body"] = extra_body_params 130 | 131 | stream = generate_kwargs.get("stream", False) 132 | if stream: 133 | generate_kwargs.pop("n", None) 134 | @with_timeout(timeout) 135 | def get_response(): 136 | while True: 137 | try: 138 | completion = client.chat.completions.create( 139 | model=model_name, 140 | messages=chat_messages, 141 | **generate_kwargs, 142 | ) 143 | break 144 | except openai.APIConnectionError as e: 145 | if not worker_initiated: 146 | time.sleep(5) 147 | continue 148 | print(f"API connection error: {e}") 149 | time.sleep(5) 150 | continue 151 | if not stream: 152 | if "logprobs" not in generate_kwargs or not generate_kwargs["logprobs"]: 153 | if len(completion.choices) > 1: 154 | return [c.message.content for c in completion.choices] 155 | else: 156 | return completion.choices[0].message.content 157 | else: 158 | if len(completion.choices) > 1: 159 | return [c.message.content for c in completion.choices], [c.logprobs.dict() for c in completion.choices] 160 | else: 161 | return completion.choices[0].message.content, 
completion.choices[0].logprobs.dict() 162 | else: 163 | def generate_stream(): 164 | for chunk in completion: 165 | if chunk.choices[0].delta.content is not None: 166 | yield chunk.choices[0].delta.content 167 | return generate_stream() 168 | return get_response() 169 | 170 | def call_sglang_worker_completion(prompt:str, model_name, worker_addrs, timeout:int=300, **generate_kwargs) -> str: 171 | global worker_initiated 172 | global chat_tokenizers 173 | if model_name not in chat_tokenizers: 174 | chat_tokenizers[model_name] = ChatTokenizer(model_name) 175 | chat_tokenizer = chat_tokenizers[model_name] 176 | 177 | if "max_new_tokens" in generate_kwargs: 178 | if "max_tokens" not in generate_kwargs: 179 | generate_kwargs["max_tokens"] = generate_kwargs["max_new_tokens"] 180 | del generate_kwargs["max_new_tokens"] 181 | 182 | worker_addr = random.choice(worker_addrs) 183 | 184 | client = openai.OpenAI( 185 | base_url=f"{worker_addr}/v1", 186 | api_key="sglang-engine-token", 187 | ) 188 | 189 | generate_kwargs['max_tokens'] = generate_kwargs.get('max_tokens', chat_tokenizer.max_length) # for sglang, max_tokens is required and must > 0 190 | args_names, kwargs_names = get_function_arg_names(client.completions.create) 191 | extra_body_params = {} 192 | for key in list(generate_kwargs.keys()): 193 | if key not in args_names + kwargs_names: 194 | extra_body_params[key] = generate_kwargs[key] 195 | del generate_kwargs[key] 196 | generate_kwargs["extra_body"] = extra_body_params 197 | 198 | stream = generate_kwargs.get("stream", False) 199 | if stream: 200 | generate_kwargs.pop("n", None) 201 | @with_timeout(timeout) 202 | def get_response(): 203 | while True: 204 | try: 205 | completion = client.completions.create( 206 | model=model_name, 207 | prompt=prompt, 208 | **generate_kwargs, 209 | ) 210 | break 211 | except openai.APIConnectionError as e: 212 | if not worker_initiated: 213 | time.sleep(5) 214 | continue 215 | print(f"API connection error: {e}") 216 | time.sleep(5) 217 | continue 218 | 219 | # return completion.choices[0].text 220 | if not stream: 221 | if len(completion.choices) > 1: 222 | return [c.text for c in completion.choices] 223 | else: 224 | return completion.choices[0].text 225 | else: 226 | def generate_stream(): 227 | for chunk in completion: 228 | if chunk.choices[0].text is not None: 229 | yield chunk.choices[0].text 230 | return generate_stream() 231 | return get_response() -------------------------------------------------------------------------------- /llm_engines/together.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from .utils import with_timeout 4 | together_client = None 5 | def call_worker_together(messages, model_name, timeout:int=60, conv_system_msg=None, **generate_kwargs) -> str: 6 | from together import Together 7 | global together_client 8 | if together_client is None: 9 | together_client = Together(api_key=os.environ.get("TOGETHER_API_KEY"), timeout=timeout) 10 | 11 | if model_name.startswith("together_"): 12 | model_name = model_name.replace("together_", "") 13 | 14 | # change messages to openai format 15 | if conv_system_msg: 16 | new_messages = [{"role": "system", "content": conv_system_msg}] + messages 17 | else: 18 | new_messages = messages 19 | 20 | stream = generate_kwargs.get("stream", False) 21 | if stream and "n" in generate_kwargs: 22 | generate_kwargs.pop("n") 23 | 24 | @with_timeout(timeout) 25 | def get_response(): 26 | max_retry_for_unbound_local_error = 10 27 | retry_count 
= 0 28 | while True: 29 | try: 30 | completion = together_client.chat.completions.create( 31 | model=model_name, 32 | messages=new_messages, 33 | **generate_kwargs, 34 | ) 35 | break 36 | except UnboundLocalError as e: 37 | time.sleep(0.2) 38 | retry_count += 1 39 | if retry_count >= max_retry_for_unbound_local_error: 40 | 41 | raise e 42 | continue 43 | if not stream: 44 | if "logprobs" not in generate_kwargs or not generate_kwargs["logprobs"]: 45 | if len(completion.choices) > 1: 46 | return [c.message.content for c in completion.choices] 47 | else: 48 | return completion.choices[0].message.content 49 | else: 50 | if len(completion.choices) > 1: 51 | return [c.message.content for c in completion.choices], [c.logprobs.dict() for c in completion.choices] 52 | else: 53 | return completion.choices[0].message.content, completion.choices[0].logprobs.dict() 54 | else: 55 | def generate_stream(): 56 | for chunk in completion: 57 | if chunk.choices[0].delta.content is not None: 58 | yield chunk.choices[0].delta.content 59 | return generate_stream() 60 | return get_response() 61 | 62 | def call_worker_together_completion(prompt:str, model_name, timeout:int=60, **generate_kwargs) -> str: 63 | from together import Together 64 | global together_client 65 | if together_client is None: 66 | together_client = Together(api_key=os.environ.get("TOGETHER_API_KEY"), timeout=timeout) 67 | 68 | if model_name.startswith("together_"): 69 | model_name = model_name.replace("together_", "") 70 | 71 | stream = generate_kwargs.get("stream", False) 72 | if stream and "n" in generate_kwargs: 73 | generate_kwargs.pop("n") 74 | @with_timeout(timeout) 75 | def get_response(): 76 | completion = together_client.completions.create( 77 | model=model_name, 78 | prompt=prompt, 79 | **generate_kwargs, 80 | ) 81 | if not stream: 82 | if len(completion.choices) > 1: 83 | return [c.text for c in completion.choices] 84 | else: 85 | return completion.choices[0].text 86 | else: 87 | def generate_stream(): 88 | for chunk in completion: 89 | if chunk.choices[0].delta.content is not None: 90 | yield chunk.choices[0].delta.content 91 | return generate_stream() 92 | return get_response() -------------------------------------------------------------------------------- /llm_engines/utils.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import time 3 | import os 4 | import signal 5 | import os 6 | import signal 7 | import json 8 | import traceback 9 | import threading 10 | import openai 11 | import inspect 12 | import base64 13 | import datetime 14 | import regex as re 15 | import requests 16 | from io import BytesIO 17 | from PIL import Image 18 | from pathlib import Path 19 | from typing import Union, List 20 | from typing import List 21 | from transformers import AutoTokenizer 22 | from functools import partial 23 | from .cache import load_cache, get_inputs_hash, get_cache_file 24 | 25 | default_gen_params = { 26 | "temperature": 0.0, 27 | "max_tokens": None, 28 | "top_p": 1.0, 29 | "timeout": 600, 30 | } 31 | class SubprocessMonitor: 32 | def _monitor(self): 33 | while True: 34 | if self.proc.poll() is not None: 35 | print("Subprocess has exited with code", self.proc.returncode) 36 | os.kill(os.getpid(), signal.SIGTERM) # Exit the main process 37 | break 38 | time.sleep(5) 39 | 40 | def __init__(self, command, **kwargs): 41 | print("Launching subprocess with command:\n", " ".join(command)) 42 | self.proc = subprocess.Popen(command, **kwargs) 43 | # self.monitor_thread = 
threading.Thread(target=self._monitor) 44 | # self.monitor_thread.start() 45 | 46 | class ChatTokenizer: 47 | def __init__(self, model_name): 48 | 49 | self.model_name = model_name 50 | self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) 51 | self.system_message = None 52 | try: 53 | self.max_length = self.tokenizer.model_max_length 54 | except AttributeError: 55 | self.max_length = 4096 56 | if not isinstance(self.max_length, int): 57 | self.max_length = 4096 58 | if self.max_length > 1e6: 59 | self.max_length = 1e6 60 | 61 | if self.tokenizer.chat_template: 62 | self.apply_chat_template = self.apply_chat_template_default 63 | print("Using hugging face chat template for model", model_name) 64 | self.chat_template_source = "huggingface" 65 | else: 66 | self.apply_chat_template = None 67 | self.chat_template_source = None 68 | print("Example prompt: \n", self.example_prompt()) 69 | 70 | def apply_chat_template_default( 71 | self, 72 | messages:List[str], 73 | add_generation_prompt:bool=True, 74 | chat_template:str=None 75 | ): 76 | prompt = self.tokenizer.apply_chat_template( 77 | messages, 78 | add_generation_prompt=add_generation_prompt, 79 | tokenize=False, 80 | chat_template=chat_template, 81 | ) 82 | return prompt 83 | 84 | def example_prompt(self): 85 | if not self.apply_chat_template: 86 | return "Chat template not available for this model" 87 | else: 88 | example_messages = [ 89 | {"role": "user", "content": "Hello"}, 90 | {"role": "assistant", "content": "Hi"}, 91 | {"role": "user", "content": "How are you?"}, 92 | {"role": "assistant", "content": "I'm good, how about you?"}, 93 | ] 94 | return self.apply_chat_template(example_messages) 95 | 96 | def __call__(self, messages:List[str], **kwargs): 97 | if not self.apply_chat_template: 98 | raise NotImplementedError("Chat template not available for this model") 99 | return self.apply_chat_template(messages, **kwargs) 100 | 101 | 102 | def convert_messages(messages:Union[List[str], List[dict], str]): 103 | """ 104 | Convert messages to the format expected by the model 105 | """ 106 | if all(isinstance(item, dict) for item in messages): 107 | assert all("content" in item for item in messages), "content key not found in messages" 108 | assert all("role" in item for item in messages), "role key not found in messages" 109 | if messages[0]["role"] == "system": 110 | conv_system_msg = messages[0]["content"] 111 | messages = messages[1:] 112 | else: 113 | conv_system_msg = None 114 | new_messages = messages 115 | else: 116 | if isinstance(messages, str): 117 | messages = [messages] 118 | assert all(isinstance(item, str) for item in messages) 119 | new_messages = [] 120 | for i, message in enumerate(messages): 121 | if i % 2 == 0: 122 | new_messages.append({"role": "user", "content": message}) 123 | else: 124 | new_messages.append({"role": "assistant", "content": message}) 125 | conv_system_msg = None 126 | 127 | # assert the correct format of images 128 | for message in new_messages: 129 | assert "role" in message, "role key not found in message" 130 | assert "content" in message, "content key not found in message" 131 | if isinstance(message["content"], str): 132 | pass 133 | elif isinstance(message["content"], list): 134 | for sub_content in message["content"]: 135 | assert sub_content["type"] in sub_content, f"'{sub_content['type']}' key not found in sub_content of type {sub_content['type']}" 136 | if sub_content["type"] == "text": 137 | assert isinstance(sub_content["text"], str), "text key not found in 
sub_content" 138 | elif sub_content["type"] == "image_url": 139 | assert "url" in sub_content["image_url"] and isinstance(sub_content["image_url"]["url"], str), "url key not found in sub_content['image_url']" 140 | elif sub_content["type"] == "image": 141 | assert isinstance(sub_content["image"], Image.Image), "The image key in of 'image' type must be a PIL Image object" 142 | # change image to image_url 143 | sub_content["type"] = "image_url" 144 | sub_content["image_url"] = {"url": encode_base64_image_url(sub_content["image"])} 145 | del sub_content["image"] 146 | else: 147 | raise ValueError(f"Unsupported sub_content type: {sub_content['type']}") 148 | else: 149 | raise ValueError(f"Unsupported content type: {type(message['content'])}") 150 | 151 | return new_messages, conv_system_msg 152 | 153 | def _convert_messages_wrapper(messages:Union[List[str], List[dict], str], call_model_worker, is_completion=False, **generate_kwargs): 154 | if not is_completion: 155 | messages, conv_system_msg = convert_messages(messages) 156 | generate_kwargs["conv_system_msg"] = conv_system_msg 157 | else: 158 | assert isinstance(messages, str), "Completion model only accepts a single string input" 159 | # add default generation parameters 160 | for key, value in default_gen_params.items(): 161 | if key not in generate_kwargs: 162 | generate_kwargs[key] = value 163 | return call_model_worker(messages, **generate_kwargs) 164 | 165 | def convert_messages_wrapper(call_model_worker, is_completion=False): 166 | return partial(_convert_messages_wrapper, call_model_worker=call_model_worker, is_completion=is_completion) 167 | 168 | class MaxRetriesExceededError(Exception): 169 | pass 170 | 171 | short_error_instances = [ 172 | openai.BadRequestError, 173 | ] 174 | 175 | def _retry_on_failure(*args, call_model_worker=None, num_retries=5, **kwargs): 176 | if not call_model_worker: 177 | raise ValueError("call_model_worker is required") 178 | try: 179 | return call_model_worker(*args, **kwargs) 180 | except Exception as e: 181 | if not num_retries: 182 | if any(isinstance(e, error_instance) for error_instance in short_error_instances): 183 | print(str(e)) 184 | else: 185 | print(traceback.format_exc()) 186 | raise MaxRetriesExceededError(f"Max retries exceeded for call_model_worker (num_retries={num_retries})") 187 | for i in range(num_retries): 188 | try: 189 | return call_model_worker(*args, **kwargs) 190 | except Exception as e: 191 | print("Error in call_model_worker, retrying... 
(Error: {})".format(e)) 192 | time.sleep(1) 193 | if i >= num_retries - 1 and not isinstance(e, TimeoutError): 194 | if any(isinstance(e, error_instance) for error_instance in short_error_instances): 195 | print(str(e)) 196 | else: 197 | # format dump of the last error and 198 | print(traceback.format_exc()) 199 | raise MaxRetriesExceededError(f"Max retries exceeded for call_model_worker (num_retries={num_retries})") 200 | 201 | def retry_on_failure(call_model_worker, num_retries=5): 202 | return partial(_retry_on_failure, call_model_worker=call_model_worker, num_retries=num_retries) 203 | 204 | def timeout_handler(signum, frame): 205 | raise TimeoutError("Function call timed out") 206 | 207 | def with_timeout(timeout): 208 | def decorator(func): 209 | def wrapper(*args, **kwargs): 210 | result = [TimeoutError(f"Function call timed out (timeout={timeout})")] 211 | stop_event = threading.Event() 212 | 213 | def target(): 214 | try: 215 | result[0] = func(*args, **kwargs) 216 | except Exception as e: 217 | result[0] = e 218 | 219 | thread = threading.Thread(target=target) 220 | thread.start() 221 | thread.join(timeout) 222 | if thread.is_alive(): 223 | stop_event.set() 224 | raise TimeoutError(f"Function call timed out (timeout={timeout})") 225 | if isinstance(result[0], Exception): 226 | raise result[0] 227 | return result[0] 228 | return wrapper 229 | return decorator 230 | 231 | def max_retry_wrapper(func, *args, **kwargs): 232 | try: 233 | return func(*args, **kwargs) 234 | except MaxRetriesExceededError as e: 235 | print(str(e)) 236 | return None 237 | 238 | def get_function_arg_names(func): 239 | signature = inspect.signature(func) 240 | parameters = signature.parameters 241 | 242 | arg_names = [] 243 | kwarg_names = [] 244 | 245 | for name, param in parameters.items(): 246 | if param.kind == inspect.Parameter.VAR_POSITIONAL: 247 | arg_names.append(f"*{name}") 248 | elif param.kind == inspect.Parameter.VAR_KEYWORD: 249 | kwarg_names.append(f"**{name}") 250 | else: 251 | arg_names.append(name) 252 | 253 | return arg_names, kwarg_names 254 | 255 | def encode_base64_image(image:Image.Image, image_format="PNG") -> str: 256 | im_file = BytesIO() 257 | image.save(im_file, format=image_format) 258 | im_bytes = im_file.getvalue() 259 | im_64 = base64.b64encode(im_bytes).decode("utf-8") 260 | return im_64 261 | 262 | def encode_base64_image_url(image:Image.Image, image_format="PNG") -> str: 263 | return f"data:image/{image_format};base64,{encode_base64_image(image, image_format)}" 264 | 265 | def decode_base64_image_url(base64_uri:str) -> Image.Image: 266 | # Split the URI to get the base64 data 267 | try: 268 | # Remove the "data:image/format;base64," prefix 269 | header, base64_data = base64_uri.split(',', 1) 270 | # Get image format from header 271 | image_format = header.split('/')[1].split(';')[0] 272 | # Decode base64 string 273 | image_data = base64.b64decode(base64_data) 274 | # Create image from binary data 275 | image = Image.open(BytesIO(image_data)) 276 | return image 277 | except Exception as e: 278 | raise ValueError(f"Failed to decode base64 image: {str(e)}") 279 | 280 | def is_base64_image_url(base64_uri:str) -> bool: 281 | return base64_uri.startswith("data:image/") 282 | 283 | def load_image(image_path:str) -> Image.Image: 284 | # either http or local file path 285 | if image_path.startswith("http"): 286 | response = requests.get(image_path) 287 | image = Image.open(BytesIO(response.content)).convert("RGB") 288 | else: 289 | image = Image.open(image_path).convert("RGB") 290 | 
return image -------------------------------------------------------------------------------- /llm_engines/vllm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import random 5 | import json 6 | import openai 7 | import vllm 8 | import signal 9 | import regex as re 10 | from pathlib import Path 11 | from typing import List 12 | from .utils import SubprocessMonitor, ChatTokenizer, with_timeout, get_function_arg_names 13 | from huggingface_hub import HfApi, hf_hub_download, snapshot_download 14 | worker_initiated = False 15 | vllm_version = vllm.__version__ 16 | 17 | chat_tokenizers = {} 18 | def launch_vllm_worker( 19 | model_name: str, 20 | num_gpus: int=None, 21 | gpu_ids: List[int]=None, 22 | dtype: str="auto", 23 | quantization: str=None, 24 | port: int=34200, 25 | host: str="127.0.0.1", 26 | root_path: str=None, 27 | additional_args: List[str]=[], 28 | ) -> str: 29 | """ 30 | Launch a model worker and return the address 31 | Args: 32 | model_name: the model name to launch 33 | Returns: 34 | the address of the launched model 35 | """ 36 | print(f"Launching model {model_name}") 37 | worker_addr = f"http://{host}:{port}" 38 | log_file = Path(os.path.abspath(__file__)).parent / "logs" / f"{model_name}.log" 39 | log_file.parent.mkdir(parents=True, exist_ok=True) 40 | if gpu_ids: 41 | num_gpus = len(gpu_ids) 42 | else: 43 | if not num_gpus: 44 | num_gpus = torch.cuda.device_count() 45 | print(f"Warning: num_gpus or gpu_ids not provided, using {num_gpus} GPUs") 46 | gpu_ids = list(range(num_gpus)) 47 | env = os.environ.copy() 48 | # Set the CUDA_VISIBLE_DEVICES environment variable 49 | env["CUDA_VISIBLE_DEVICES"] = ",".join([str(gpu_id) for gpu_id in gpu_ids]) 50 | if "gemma-2" in model_name: 51 | env["VLLM_ATTENTION_BACKEND"] = "FLASHINFER" 52 | else: 53 | env["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN" 54 | env["VLLM_SERVER_DEV_MODE"] = "1" 55 | # print(num_gpus, gpu_ids) 56 | 57 | model_path = Path(model_name) 58 | if model_path.exists() and ((model_path / "config.json").exists() or (model_path / "adapter_config.json").exists()): 59 | if (model_path / "adapter_config.json").exists(): 60 | print(f"Loading local model {model_name} with adapter") 61 | use_lora = True 62 | with open(model_path / "adapter_config.json") as f: 63 | adapter_config = json.load(f) 64 | adapter_path = model_path 65 | base_model_name_or_path = adapter_config["base_model_name_or_path"] 66 | elif (model_path / "config.json").exists(): 67 | print(f"Loading local model {model_name}") 68 | use_lora = False 69 | adapter_path = None 70 | base_model_name_or_path = model_name 71 | else: 72 | raise ValueError(f"no config.json or adapter_config.json found in model {model_name}") 73 | else: 74 | # check whether there is a adapter_config.json 75 | api = HfApi() 76 | model_info = api.model_info(model_name) 77 | model_files = [x.rfilename for x in model_info.siblings] 78 | if "adapter_config.json" in model_files: 79 | use_lora = True 80 | adapter_path = snapshot_download(model_name) 81 | adapter_config_path = Path(adapter_path) / "adapter_config.json" 82 | with open(adapter_config_path) as f: 83 | adapter_config = json.load(f) 84 | base_model_name_or_path = adapter_config["base_model_name_or_path"] 85 | print(f"Loading model from Hugging Face {model_name} with adapter") 86 | elif "config.json" in model_files: 87 | use_lora = False 88 | adapter_path = None 89 | base_model_name_or_path = model_name 90 | print(f"Loading model from Hugging Face 
{model_name}") 91 | else: 92 | raise ValueError(f"no config.json or adapter_config.json found in model {model_name}") 93 | 94 | # python -m vllm.entrypoints.openai.api_server \ 95 | # --model meta-llama/Llama-2-7b-hf \ 96 | # --enable-lora \ 97 | # --lora-modules sql-lora=~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/ 98 | if use_lora: 99 | lora_args = [ 100 | "--enable-lora", 101 | "--lora-modules", f"{model_name}={adapter_path}", 102 | "--max-loras", "1", 103 | "--max-lora-rank", str(adapter_config["r"]) 104 | ] 105 | else: 106 | lora_args = [] 107 | if quantization: 108 | available_quantizations = "aqlm,awq,deepspeedfp,tpu_int8,fp8,fbgemm_fp8,marlin,gguf,gptq_marlin_24,gptq_marlin,awq_marlin,gptq,squeezellm,compressed-tensors,bitsandbytes,qqq,experts_int8" 109 | available_quantizations = available_quantizations.split(",") 110 | if quantization not in available_quantizations: 111 | raise ValueError(f"quantization {quantization} not in available quantizations: {available_quantizations}") 112 | lora_args += ["--quantization", quantization] 113 | if quantization == "bitsandbytes": 114 | lora_args += ["--load-format", "bitsandbytes", "--enforce-eager"] 115 | # python -m vllm.entrypoints.openai.api_server --model NousResearch/Meta-Llama-3-8B-Instruct --dtype auto --api-key token-abc123 116 | proc = SubprocessMonitor([ 117 | "vllm", "serve", 118 | base_model_name_or_path, 119 | "--dtype", dtype, 120 | "--api-key", "vllm-engine-token", 121 | "--port", str(port), 122 | "--host", host, 123 | "--tensor-parallel-size", str(num_gpus), 124 | "--disable-log-requests", 125 | "--trust-remote-code", 126 | # "--enable-sleep-mode" 127 | ] + (["--root-path", root_path] if root_path else []) 128 | + lora_args + additional_args, env=env) 129 | print(f"Launched VLLM model {model_name} at address {worker_addr} with CUDA_VISIBLE_DEVICES={env['CUDA_VISIBLE_DEVICES']}") 130 | if model_name not in chat_tokenizers: 131 | chat_tokenizers[model_name] = ChatTokenizer(base_model_name_or_path) 132 | if base_model_name_or_path not in chat_tokenizers: 133 | chat_tokenizers[base_model_name_or_path] = ChatTokenizer(base_model_name_or_path) 134 | return f"http://127.0.0.1:{port}", proc 135 | 136 | def call_vllm_worker(messages, model_name, worker_addrs, timeout:int=300, conv_system_msg=None, **generate_kwargs) -> str: 137 | global worker_initiated 138 | global chat_tokenizers 139 | if "max_new_tokens" in generate_kwargs: 140 | if "max_tokens" not in generate_kwargs: 141 | generate_kwargs["max_tokens"] = generate_kwargs["max_new_tokens"] 142 | del generate_kwargs["max_new_tokens"] 143 | # try: 144 | # if model_name not in chat_tokenizers: 145 | # chat_tokenizers[model_name] = ChatTokenizer(model_name) 146 | # chat_tokenizer = chat_tokenizers[model_name] 147 | # prompt = chat_tokenizer(chat_messages) 148 | # except Exception as e: 149 | # pass 150 | 151 | # change messages to openai format 152 | if conv_system_msg: 153 | chat_messages = [{"role": "system", "content": conv_system_msg}] + messages 154 | else: 155 | chat_messages = messages 156 | 157 | worker_addr = random.choice(worker_addrs) 158 | 159 | client = openai.OpenAI( 160 | base_url=f"{worker_addr}/v1", 161 | api_key="vllm-engine-token", 162 | ) 163 | 164 | args_names, kwargs_names = get_function_arg_names(client.chat.completions.create) 165 | extra_body_params = {} 166 | for key in list(generate_kwargs.keys()): 167 | if key not in args_names + kwargs_names: 168 | extra_body_params[key] = generate_kwargs[key] 169 | del generate_kwargs[key] 170 | 
generate_kwargs["extra_body"] = extra_body_params 171 | 172 | stream = generate_kwargs.get("stream", False) 173 | if stream: 174 | generate_kwargs.pop("n", None) 175 | 176 | # print(generate_kwargs) 177 | @with_timeout(timeout) 178 | def get_response(): 179 | while True: 180 | try: 181 | completion = client.chat.completions.create( 182 | model=model_name, 183 | messages=chat_messages, 184 | **generate_kwargs, 185 | ) 186 | break 187 | except openai.APIConnectionError as e: 188 | if not worker_initiated: 189 | time.sleep(5) 190 | continue 191 | print(f"API connection error: {e}") 192 | time.sleep(5) 193 | continue 194 | except openai.BadRequestError as e: 195 | error_response = e.response.json() 196 | if error_response['code'] == 400: 197 | pattern = r"This model's maximum context length is (\d+) tokens. However, you requested (\d+) tokens \((\d+) in the messages, (\d+) in the completion\). Please reduce the length of the messages or completion." 198 | res = re.match(pattern, error_response['message']) 199 | if res: 200 | max_context_length = int(res.group(1)) 201 | num_tokens_requested = int(res.group(2)) 202 | num_tokens_in_messages = int(res.group(3)) 203 | num_tokens_in_completion = int(res.group(4)) 204 | 205 | new_max_tokens = num_tokens_in_completion - (num_tokens_requested - max_context_length) 206 | if new_max_tokens <= 0: 207 | raise e 208 | print(f"Reducing max_tokens to {new_max_tokens}, and retrying") 209 | generate_kwargs["max_tokens"] = new_max_tokens 210 | continue 211 | else: 212 | raise e 213 | 214 | if not stream: 215 | if "logprobs" not in generate_kwargs or not generate_kwargs["logprobs"]: 216 | if len(completion.choices) > 1: 217 | return [c.message.content for c in completion.choices] 218 | else: 219 | return completion.choices[0].message.content 220 | else: 221 | if len(completion.choices) > 1: 222 | return [c.message.content for c in completion.choices], [c.logprobs.dict() for c in completion.choices] 223 | else: 224 | return completion.choices[0].message.content, completion.choices[0].logprobs.dict() 225 | else: 226 | def generate_stream(): 227 | for chunk in completion: 228 | if chunk.choices[0].delta.content is not None: 229 | yield chunk.choices[0].delta.content 230 | return generate_stream() 231 | 232 | return get_response() 233 | 234 | def call_vllm_worker_completion(prompt:str, model_name, worker_addrs, timeout:int=300, **generate_kwargs) -> str: 235 | global worker_initiated 236 | 237 | worker_addr = random.choice(worker_addrs) 238 | 239 | client = openai.OpenAI( 240 | base_url=f"{worker_addr}/v1", 241 | api_key="vllm-engine-token", 242 | ) 243 | 244 | stream = generate_kwargs.get("stream", False) 245 | if stream: 246 | generate_kwargs.pop("n", None) 247 | 248 | args_names, kwargs_names = get_function_arg_names(client.completions.create) 249 | extra_body_params = {} 250 | for key in list(generate_kwargs.keys()): 251 | if key not in args_names + kwargs_names: 252 | extra_body_params[key] = generate_kwargs[key] 253 | del generate_kwargs[key] 254 | # generate_kwargs["extra_body"] = extra_body_params 255 | 256 | @with_timeout(timeout) 257 | def get_response(): 258 | while True: 259 | try: 260 | completion = client.completions.create( 261 | model=model_name, 262 | prompt=prompt, 263 | **generate_kwargs, 264 | ) 265 | break 266 | except openai.APIConnectionError as e: 267 | if not worker_initiated: 268 | time.sleep(5) 269 | continue 270 | print(f"API connection error: {e}") 271 | time.sleep(5) 272 | continue 273 | if not stream: 274 | if len(completion.choices) > 
1: 275 | return [c.text for c in completion.choices] 276 | else: 277 | return completion.choices[0].text 278 | else: 279 | def generate_stream(): 280 | for chunk in completion: 281 | if chunk.choices[0].text is not None: 282 | yield chunk.choices[0].text 283 | return generate_stream() 284 | return get_response() -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | # read the contents of your README file 3 | from pathlib import Path 4 | this_directory = Path(__file__).parent 5 | long_description = (this_directory / "README.md").read_text() 6 | 7 | setup( 8 | name='llm-engines', 9 | version='0.0.25', 10 | description='A unified inference engine for large language models (LLMs) including open-source models (VLLM, SGLang, Together) and commercial models (OpenAI, Mistral, Claude).', 11 | long_description=long_description, 12 | long_description_content_type='text/markdown', 13 | author='Dongfu Jiang', 14 | author_email='dongfu.jiang@uwaterloo.ca', 15 | packages=find_packages(), 16 | url='https://github.com/jdf-prog/LLM-Engines', 17 | entry_points={"console_scripts": ["llm-engines=llm_engines.cli:main"]}, 18 | install_requires=[ 19 | "fire", 20 | "openai", 21 | "google-generativeai", 22 | "accelerate", 23 | "transformers", 24 | "torch", 25 | "Pillow", 26 | "torch", 27 | "tqdm", 28 | "numpy", 29 | "requests", 30 | "sentencepiece", 31 | "vllm", 32 | "together", 33 | "icecream", 34 | "prettytable", 35 | "mistralai", 36 | "anthropic", 37 | "fastapi", 38 | "fireworks-ai" 39 | ], 40 | extras_require={ 41 | "flash-attn": { 42 | "flash-attn" 43 | }, 44 | "sglang": { 45 | "sglang[all]" 46 | } 47 | } 48 | ) 49 | 50 | """ 51 | rm -rf dist build llm_engines.egg-info 52 | python setup.py sdist bdist_wheel 53 | twine upload dist/* 54 | """ 55 | --------------------------------------------------------------------------------
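A minimal usage sketch (not part of the repository files above) illustrating the `convert_messages` helper defined in `llm_engines/utils.py`; it assumes the package and its dependencies are installed (e.g. via `pip install llm-engines`), and the example strings are purely hypothetical:

    # Illustrative sketch only -- not a file from this repository.
    from llm_engines.utils import convert_messages

    # A list of alternating user/assistant strings is normalized into role dicts;
    # this form carries no system message, so the second return value is None.
    messages, system_msg = convert_messages(
        ["Hello", "Hi, how can I help?", "What did I just say?"]
    )
    for m in messages:
        print(m["role"], "->", m["content"])

    # OpenAI-style dicts pass through unchanged, except that a leading system
    # message is split off and returned separately.
    messages, system_msg = convert_messages([
        {"role": "system", "content": "You are concise."},
        {"role": "user", "content": "Hello"},
    ])
    print(system_msg)  # "You are concise."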