├── .github └── workflows │ └── publish.yml ├── .gitignore ├── MANIFEST.in ├── Makefile ├── README.md ├── app ├── __init__.py ├── api │ ├── __init__.py │ └── endpoints.py ├── cli.py ├── core │ ├── __init__.py │ ├── image_processor.py │ └── queue.py ├── handler │ ├── __init__.py │ ├── mlx_lm.py │ ├── mlx_vlm.py │ └── parser │ │ ├── __init__.py │ │ ├── base.py │ │ └── qwen3.py ├── main.py ├── models │ ├── __init__.py │ ├── mlx_lm.py │ └── mlx_vlm.py ├── schemas │ ├── __init__.py │ └── openai.py ├── utils │ ├── __init__.py │ └── errors.py └── version.py ├── configure_mlx.sh ├── examples ├── function_calling_examples.ipynb ├── images │ ├── attention.png │ ├── green_dog.jpeg │ └── password.jpg ├── lm_embeddings_examples.ipynb ├── pdfs │ └── lab03.pdf ├── simple_rag_demo.ipynb ├── vision_examples.ipynb └── vlm_embeddings_examples.ipynb ├── setup.py └── tests └── test_base_tool_parser.py /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python 🐍 distribution 📦 to PyPI 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*' # Triggers on version tags like v1.0.0 7 | 8 | jobs: 9 | build-and-publish: 10 | runs-on: macos-latest 11 | 12 | steps: 13 | - uses: actions/checkout@v4 14 | 15 | - name: Set up Python 16 | uses: actions/setup-python@v5 17 | with: 18 | python-version: '3.11' 19 | 20 | - name: Install build tools 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install build twine 24 | 25 | - name: Build package 26 | run: python -m build 27 | 28 | - name: Publish package to PyPI 29 | env: 30 | TWINE_USERNAME: __token__ 31 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 32 | run: twine upload dist/* -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | oai-compat-server 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | # ignore .DS_Store 11 | .DS_Store 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # UV 102 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | #uv.lock 106 | 107 | # poetry 108 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 109 | # This is especially recommended for binary packages to ensure reproducibility, and is more 110 | # commonly ignored for libraries. 111 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 112 | #poetry.lock 113 | 114 | # pdm 115 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 116 | #pdm.lock 117 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 118 | # in version control. 119 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 120 | .pdm.toml 121 | .pdm-python 122 | .pdm-build/ 123 | 124 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 125 | __pypackages__/ 126 | 127 | # Celery stuff 128 | celerybeat-schedule 129 | celerybeat.pid 130 | 131 | # SageMath parsed files 132 | *.sage.py 133 | 134 | # Environments 135 | .env 136 | .venv 137 | env/ 138 | venv/ 139 | ENV/ 140 | env.bak/ 141 | venv.bak/ 142 | 143 | # Spyder project settings 144 | .spyderproject 145 | .spyproject 146 | 147 | # Rope project settings 148 | .ropeproject 149 | 150 | # mkdocs documentation 151 | /site 152 | 153 | # mypy 154 | .mypy_cache/ 155 | .dmypy.json 156 | dmypy.json 157 | 158 | # Pyre type checker 159 | .pyre/ 160 | 161 | # pytype static type analyzer 162 | .pytype/ 163 | 164 | # Cython debug symbols 165 | cython_debug/ 166 | 167 | # PyCharm 168 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 169 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 170 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 171 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 172 | #.idea/ 173 | 174 | # Ruff stuff: 175 | .ruff_cache/ 176 | 177 | # PyPI configuration file 178 | .pypirc 179 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include requirements.txt 3 | include MANIFEST.in 4 | include setup.py 5 | recursive-include app * -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | run: 2 | mlx-server launch \ 3 | --model-path mlx-community/Qwen3-1.7B-4bit \ 4 | --model-type lm \ 5 | --max-concurrency 1 \ 6 | --queue-timeout 300 \ 7 | --queue-size 100 8 | 9 | install: 10 | pip install -e . -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mlx-openai-server 2 | 3 | [![MIT License](https://img.shields.io/badge/license-MIT-green.svg)](LICENSE) 4 | [![Python 3.11](https://img.shields.io/badge/python-3.11-blue.svg)](https://www.python.org/downloads/release/python-3110/) 5 | 6 | ## Description 7 | This repository hosts a high-performance API server that provides OpenAI-compatible endpoints for MLX models. Developed using Python and powered by the FastAPI framework, it provides an efficient, scalable, and user-friendly solution for running MLX-based vision and language models locally with an OpenAI-compatible interface. 8 | 9 | > **Note:** This project currently supports **MacOS with M-series chips** only as it specifically leverages MLX, Apple's framework optimized for Apple Silicon. 10 | 11 | --- 12 | 13 | ## Table of Contents 14 | - [Key Features](#key-features) 15 | - [Demo](#demo) 16 | - [OpenAI Compatibility](#openai-compatibility) 17 | - [Supported Model Types](#supported-model-types) 18 | - [Installation](#installation) 19 | - [Usage](#usage) 20 | - [Starting the Server](#starting-the-server) 21 | - [CLI Usage](#cli-usage) 22 | - [Using the API](#using-the-api) 23 | - [Request Queue System](#request-queue-system) 24 | - [API Response Schemas](#api-response-schemas) 25 | - [Example Notebooks](#example-notebooks) 26 | - [Large Models](#large-models) 27 | - [Contributing](#contributing) 28 | - [License](#license) 29 | - [Support](#support) 30 | - [Acknowledgments](#acknowledgments) 31 | 32 | --- 33 | 34 | ## Key Features 35 | - 🚀 **Fast, local OpenAI-compatible API** for MLX models 36 | - 🖼️ **Vision-language and text-only model support** 37 | - 🔌 **Drop-in replacement** for OpenAI API in your apps 38 | - 📈 **Performance and queue monitoring endpoints** 39 | - 🧑‍💻 **Easy Python and CLI usage** 40 | - 🛡️ **Robust error handling and request management** 41 | 42 | --- 43 | 44 | ## Demo 45 | 46 | ### 🚀 See It In Action 47 | 48 | Check out our [video demonstration](https://youtu.be/D9a3AZSj6v8) to see the server in action! The demo showcases: 49 | 50 | - Setting up and launching the server 51 | - Using the OpenAI Python SDK for seamless integration 52 |

53 | [Demo video thumbnail: MLX Server OAI-Compatible Demo] 54 | 55 |

57 | 58 | --- 59 | 60 | ## OpenAI Compatibility 61 | 62 | This server implements the OpenAI API interface, allowing you to use it as a drop-in replacement for OpenAI's services in your applications. It supports: 63 | - Chat completions (both streaming and non-streaming) 64 | - Vision-language model interactions 65 | - Embeddings generation 66 | - Function calling and tool use 67 | - Standard OpenAI request/response formats 68 | - Common OpenAI parameters (temperature, top_p, etc.) 69 | 70 | ## Supported Model Types 71 | 72 | The server supports two types of MLX models: 73 | 74 | 1. **Text-only models** (`--model-type lm`) - Uses the `mlx-lm` library for pure language models 75 | 2. **Vision-language models** (`--model-type vlm`) - Uses the `mlx-vlm` library for multimodal models that can process both text and images 76 | 77 | ## Installation 78 | 79 | Follow these steps to set up the MLX-powered server: 80 | 81 | ### Prerequisites 82 | - MacOS with Apple Silicon (M-series) chip 83 | - Python 3.11 (native ARM version) 84 | - pip package manager 85 | 86 | ### Setup Steps 87 | 1. Create a virtual environment for the project: 88 | ```bash 89 | python3.11 -m venv oai-compat-server 90 | ``` 91 | 92 | 2. Activate the virtual environment: 93 | ```bash 94 | source oai-compat-server/bin/activate 95 | ``` 96 | 97 | 3. Install the package: 98 | ```bash 99 | # Option 1: Install from PyPI 100 | pip install mlx-openai-server 101 | 102 | # Option 2: Install directly from GitHub 103 | pip install git+https://github.com/cubist38/mlx-openai-server.git 104 | 105 | # Option 3: Clone and install in development mode 106 | git clone https://github.com/cubist38/mlx-openai-server.git 107 | cd mlx-openai-server 108 | pip install -e . 109 | ``` 110 | 111 | ### Troubleshooting 112 | **Issue:** My OS and Python versions meet the requirements, but `pip` cannot find a matching distribution. 113 | 114 | **Cause:** You might be using a non-native Python version. Run the following command to check: 115 | ```bash 116 | python -c "import platform; print(platform.processor())" 117 | ``` 118 | If the output is `i386` (on an M-series machine), you are using a non-native Python. Switch to a native Python version. A good approach is to use [Conda](https://stackoverflow.com/questions/65415996/how-to-specify-the-architecture-or-platform-for-a-new-conda-environment-apple). 119 | 120 | ## Usage 121 | 122 | ### Starting the Server 123 | 124 | To start the MLX server, activate the virtual environment and run the main application file: 125 | ```bash 126 | source oai-compat-server/bin/activate 127 | python -m app.main \ 128 | --model-path \ 129 | --model-type \ 130 | --max-concurrency 1 \ 131 | --queue-timeout 300 \ 132 | --queue-size 100 133 | ``` 134 | 135 | #### Server Parameters 136 | - `--model-path`: Path to the MLX model directory (local path or Hugging Face model repository) 137 | - `--model-type`: Type of model to run (`lm` for text-only models, `vlm` for vision-language models). 
Default: `lm` 138 | - `--max-concurrency`: Maximum number of concurrent requests (default: 1) 139 | - `--queue-timeout`: Request timeout in seconds (default: 300) 140 | - `--queue-size`: Maximum queue size for pending requests (default: 100) 141 | - `--port`: Port to run the server on (default: 8000) 142 | - `--host`: Host to run the server on (default: 0.0.0.0) 143 | 144 | #### Example Configurations 145 | 146 | Text-only model: 147 | ```bash 148 | python -m app.main \ 149 | --model-path mlx-community/gemma-3-4b-it-4bit \ 150 | --model-type lm \ 151 | --max-concurrency 1 \ 152 | --queue-timeout 300 \ 153 | --queue-size 100 154 | ``` 155 | 156 | > **Note:** Text embeddings via the `/v1/embeddings` endpoint are now available with both text-only models (`--model-type lm`) and vision-language models (`--model-type vlm`). 157 | 158 | Vision-language model: 159 | ```bash 160 | python -m app.main \ 161 | --model-path mlx-community/llava-phi-3-vision-4bit \ 162 | --model-type vlm \ 163 | --max-concurrency 1 \ 164 | --queue-timeout 300 \ 165 | --queue-size 100 166 | ``` 167 | 168 | ### CLI Usage 169 | 170 | CLI commands: 171 | ```bash 172 | mlx-openai-server --version 173 | mlx-openai-server --help 174 | mlx-openai-server launch --help 175 | ``` 176 | 177 | To launch the server: 178 | ```bash 179 | mlx-openai-server launch --model-path --model-type --port 8000 180 | ``` 181 | 182 | ### Using the API 183 | 184 | The server provides OpenAI-compatible endpoints that you can use with standard OpenAI client libraries. Here are some examples: 185 | 186 | #### Text Completion 187 | ```python 188 | import openai 189 | 190 | client = openai.OpenAI( 191 | base_url="http://localhost:8000/v1", 192 | api_key="not-needed" # API key is not required for local server 193 | ) 194 | 195 | response = client.chat.completions.create( 196 | model="local-model", # Model name doesn't matter for local server 197 | messages=[ 198 | {"role": "user", "content": "What is the capital of France?"} 199 | ], 200 | temperature=0.7 201 | ) 202 | print(response.choices[0].message.content) 203 | ``` 204 | 205 | #### Vision-Language Model 206 | ```python 207 | import openai 208 | import base64 209 | 210 | client = openai.OpenAI( 211 | base_url="http://localhost:8000/v1", 212 | api_key="not-needed" 213 | ) 214 | 215 | # Load and encode image 216 | with open("image.jpg", "rb") as image_file: 217 | base64_image = base64.b64encode(image_file.read()).decode('utf-8') 218 | 219 | response = client.chat.completions.create( 220 | model="local-vlm", # Model name doesn't matter for local server 221 | messages=[ 222 | { 223 | "role": "user", 224 | "content": [ 225 | {"type": "text", "text": "What's in this image?"}, 226 | { 227 | "type": "image_url", 228 | "image_url": { 229 | "url": f"data:image/jpeg;base64,{base64_image}" 230 | } 231 | } 232 | ] 233 | } 234 | ] 235 | ) 236 | print(response.choices[0].message.content) 237 | ``` 238 | 239 | #### Function Calling 240 | ```python 241 | import openai 242 | 243 | client = openai.OpenAI( 244 | base_url="http://localhost:8000/v1", 245 | api_key="not-needed" 246 | ) 247 | 248 | # Define the messages and tools 249 | messages = [ 250 | { 251 | "role": "user", 252 | "content": "What is the weather in Tokyo?" 
253 | } 254 | ] 255 | 256 | tools = [ 257 | { 258 | "type": "function", 259 | "function": { 260 | "name": "get_weather", 261 | "description": "Get the weather in a given city", 262 | "parameters": { 263 | "type": "object", 264 | "properties": { 265 | "city": {"type": "string", "description": "The city to get the weather for"} 266 | } 267 | } 268 | } 269 | } 270 | ] 271 | 272 | # Make the API call 273 | completion = client.chat.completions.create( 274 | model="local-model", 275 | messages=messages, 276 | tools=tools, 277 | tool_choice="auto" 278 | ) 279 | 280 | # Handle the tool call response 281 | if completion.choices[0].message.tool_calls: 282 | tool_call = completion.choices[0].message.tool_calls[0] 283 | print(f"Function called: {tool_call.function.name}") 284 | print(f"Arguments: {tool_call.function.arguments}") 285 | 286 | # Process the tool call - typically you would call your actual function here 287 | # For this example, we'll just hardcode a weather response 288 | weather_info = {"temperature": "22°C", "conditions": "Sunny", "humidity": "65%"} 289 | 290 | # Add the tool call and function response to the conversation 291 | messages.append(completion.choices[0].message) 292 | messages.append({ 293 | "role": "tool", 294 | "tool_call_id": tool_call.id, 295 | "name": tool_call.function.name, 296 | "content": str(weather_info) 297 | }) 298 | 299 | # Continue the conversation with the function result 300 | final_response = client.chat.completions.create( 301 | model="local-model", 302 | messages=messages 303 | ) 304 | print("\nFinal response:") 305 | print(final_response.choices[0].message.content) 306 | ``` 307 | 308 | #### Embeddings 309 | 310 | 1. Text-only model embeddings: 311 | ```python 312 | import openai 313 | 314 | client = openai.OpenAI( 315 | base_url="http://localhost:8000/v1", 316 | api_key="not-needed" 317 | ) 318 | 319 | # Generate embeddings for a single text 320 | embedding_response = client.embeddings.create( 321 | model="mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-MLX-Q8", 322 | input=["The quick brown fox jumps over the lazy dog"] 323 | ) 324 | print(f"Embedding dimension: {len(embedding_response.data[0].embedding)}") 325 | 326 | # Generate embeddings for multiple texts 327 | batch_response = client.embeddings.create( 328 | model="mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-MLX-Q8", 329 | input=[ 330 | "Machine learning algorithms improve with more data", 331 | "Natural language processing helps computers understand human language", 332 | "Computer vision allows machines to interpret visual information" 333 | ] 334 | ) 335 | print(f"Number of embeddings: {len(batch_response.data)}") 336 | ``` 337 | 338 | 2. 
Vision-language model embeddings: 339 | ```python 340 | import openai 341 | import base64 342 | from PIL import Image 343 | from io import BytesIO 344 | 345 | client = openai.OpenAI( 346 | base_url="http://localhost:8000/v1", 347 | api_key="not-needed" 348 | ) 349 | 350 | # Helper function to encode images as base64 351 | def image_to_base64(image_path): 352 | image = Image.open(image_path) 353 | buffer = BytesIO() 354 | image.save(buffer, format="PNG") 355 | buffer.seek(0) 356 | image_data = buffer.getvalue() 357 | image_base64 = base64.b64encode(image_data).decode('utf-8') 358 | return f"data:image/png;base64,{image_base64}" 359 | 360 | # Encode the image 361 | image_uri = image_to_base64("images/attention.png") 362 | 363 | # Generate embeddings for text+image 364 | vision_embedding = client.embeddings.create( 365 | model="mlx-community/Qwen2.5-VL-3B-Instruct-4bit", 366 | input=["Describe the image in detail"], 367 | extra_body={"image_url": image_uri} 368 | ) 369 | print(f"Vision embedding dimension: {len(vision_embedding.data[0].embedding)}") 370 | ``` 371 | 372 | > **Note:** Replace the model name and image path as needed. The `extra_body` parameter is used to pass the image data URI to the API. 373 | 374 | > **Warning:** Make sure you're running the server with `--model-type vlm` when making vision requests. If you send a vision request to a server running with `--model-type lm` (text-only model), you'll receive a 400 error with a message that vision requests are not supported with text-only models. 375 | 376 | ## Request Queue System 377 | 378 | The server implements a robust request queue system to manage and optimize MLX model inference requests. This system ensures efficient resource utilization and fair request processing. 379 | 380 | ### Key Features 381 | 382 | - **Concurrency Control**: Limits the number of simultaneous model inferences to prevent resource exhaustion 383 | - **Request Queuing**: Implements a fair, first-come-first-served queue for pending requests 384 | - **Timeout Management**: Automatically handles requests that exceed the configured timeout 385 | - **Real-time Monitoring**: Provides endpoints to monitor queue status and performance metrics 386 | 387 | ### Architecture 388 | 389 | The queue system consists of two main components: 390 | 391 | 1. **RequestQueue**: An asynchronous queue implementation that: 392 | - Manages pending requests with configurable queue size 393 | - Controls concurrent execution using semaphores 394 | - Handles timeouts and errors gracefully 395 | - Provides real-time queue statistics 396 | 397 | 2. **Model Handlers**: Specialized handlers for different model types: 398 | - `MLXLMHandler`: Manages text-only model requests 399 | - `MLXVLMHandler`: Manages vision-language model requests 400 | 401 | ### Queue Monitoring 402 | 403 | Monitor queue statistics using the `/v1/queue/stats` endpoint: 404 | 405 | ```bash 406 | curl http://localhost:8000/v1/queue/stats 407 | ``` 408 | 409 | Example response: 410 | ```json 411 | { 412 | "status": "ok", 413 | "queue_stats": { 414 | "running": true, 415 | "queue_size": 3, 416 | "max_queue_size": 100, 417 | "active_requests": 5, 418 | "max_concurrency": 2 419 | } 420 | } 421 | ``` 422 | 423 | ### Error Handling 424 | 425 | The queue system handles various error conditions: 426 | 427 | 1. **Queue Full (429)**: When the queue reaches its maximum size 428 | ```json 429 | { 430 | "detail": "Too many requests. Service is at capacity." 431 | } 432 | ``` 433 | 434 | 2. 
**Request Timeout**: When a request exceeds the configured timeout 435 | ```json 436 | { 437 | "detail": "Request processing timed out after 300 seconds" 438 | } 439 | ``` 440 | 441 | 3. **Model Errors**: When the model encounters an error during inference 442 | ```json 443 | { 444 | "detail": "Failed to generate response: " 445 | } 446 | ``` 447 | 448 | ### Streaming Responses 449 | 450 | The server supports streaming responses with proper chunk formatting: 451 | ```python 452 | { 453 | "id": "chatcmpl-1234567890", 454 | "object": "chat.completion.chunk", 455 | "created": 1234567890, 456 | "model": "local-model", 457 | "choices": [{ 458 | "index": 0, 459 | "delta": {"content": "chunk of text"}, 460 | "finish_reason": null 461 | }] 462 | } 463 | ``` 464 | 465 | ## API Response Schemas 466 | 467 | The server implements OpenAI-compatible API response schemas to ensure seamless integration with existing applications. Below are the key response formats: 468 | 469 | ### Chat Completions Response 470 | 471 | ```json 472 | { 473 | "id": "chatcmpl-123456789", 474 | "object": "chat.completion", 475 | "created": 1677858242, 476 | "model": "local-model", 477 | "choices": [ 478 | { 479 | "index": 0, 480 | "message": { 481 | "role": "assistant", 482 | "content": "This is the response content from the model." 483 | }, 484 | "finish_reason": "stop" 485 | } 486 | ], 487 | "usage": { 488 | "prompt_tokens": 10, 489 | "completion_tokens": 20, 490 | "total_tokens": 30 491 | } 492 | } 493 | ``` 494 | 495 | ### Embeddings Response 496 | 497 | ```json 498 | { 499 | "object": "list", 500 | "data": [ 501 | { 502 | "object": "embedding", 503 | "embedding": [0.001, 0.002, ..., 0.999], 504 | "index": 0 505 | } 506 | ], 507 | "model": "local-model", 508 | "usage": { 509 | "prompt_tokens": 8, 510 | "total_tokens": 8 511 | } 512 | } 513 | ``` 514 | 515 | ### Function/Tool Calling Response 516 | 517 | ```json 518 | { 519 | "id": "chatcmpl-123456789", 520 | "object": "chat.completion", 521 | "created": 1677858242, 522 | "model": "local-model", 523 | "choices": [ 524 | { 525 | "index": 0, 526 | "message": { 527 | "role": "assistant", 528 | "content": null, 529 | "tool_calls": [ 530 | { 531 | "id": "call_abc123", 532 | "type": "function", 533 | "function": { 534 | "name": "get_weather", 535 | "arguments": "{\"city\":\"Tokyo\"}" 536 | } 537 | } 538 | ] 539 | }, 540 | "finish_reason": "tool_calls" 541 | } 542 | ], 543 | "usage": { 544 | "prompt_tokens": 15, 545 | "completion_tokens": 25, 546 | "total_tokens": 40 547 | } 548 | } 549 | ``` 550 | 551 | ### Error Response 552 | 553 | ```json 554 | { 555 | "error": { 556 | "message": "Error message describing what went wrong", 557 | "type": "invalid_request_error", 558 | "param": null, 559 | "code": null 560 | } 561 | } 562 | ``` 563 | 564 | ## Example Notebooks 565 | 566 | The repository includes example notebooks to help you get started with different aspects of the API: 567 | 568 | - **function_calling_examples.ipynb**: A practical guide to implementing and using function calling with local models, including: 569 | - Setting up function definitions 570 | - Making function calling requests 571 | - Handling function call responses 572 | - Working with streaming function calls 573 | - Building multi-turn conversations with tool use 574 | 575 | - **vision_examples.ipynb**: A comprehensive guide to using the vision capabilities of the API, including: 576 | - Processing image inputs in various formats 577 | - Vision analysis and object detection 578 | - Multi-turn conversations with 
images 579 | - Using vision models for detailed image description and analysis 580 | 581 | - **lm_embeddings_examples.ipynb**: A comprehensive guide to using the embeddings API for text-only models, including: 582 | - Generating embeddings for single and batch inputs 583 | - Computing semantic similarity between texts 584 | - Building a simple vector-based search system 585 | - Comparing semantic relationships between concepts 586 | 587 | - **vlm_embeddings_examples.ipynb**: A detailed guide to working with Vision-Language Model embeddings, including: 588 | - Generating embeddings for images with text prompts 589 | - Creating text-only embeddings with VLMs 590 | - Calculating similarity between text and image representations 591 | - Understanding the shared embedding space of multimodal models 592 | - Practical applications of VLM embeddings 593 | 594 | - **simple_rag_demo.ipynb**: A practical guide to building a lightweight Retrieval-Augmented Generation (RAG) pipeline over PDF documents using local MLX Server, including: 595 | - Reading and chunking PDF documents 596 | - Generating text embeddings via MLX Server 597 | - Creating a simple vector store for retrieval 598 | - Performing question answering based on relevant chunks 599 | - End-to-end demonstration of document QA using Qwen3 local model 600 |

601 | [Demo image: RAG Demo] 602 | 603 |
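The RAG pipeline in `simple_rag_demo.ipynb` boils down to a short loop: embed the document chunks, embed the question, rank chunks by cosine similarity, and pass the best match to the chat endpoint as context. The sketch below shows that flow against a running server; it is a simplified illustration rather than the notebook itself: the toy chunk list stands in for real PDF text, `numpy` is assumed to be installed, and the model name is a placeholder since the local server ignores it.

```python
import numpy as np
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

# Toy "document" already split into chunks (the notebook chunks PDF pages instead).
chunks = [
    "MLX is a machine learning framework optimized for Apple Silicon.",
    "The server exposes OpenAI-compatible chat and embeddings endpoints.",
    "Requests are queued and processed with configurable concurrency.",
]

def embed(texts):
    # Call the local /v1/embeddings endpoint and return a (n, dim) array.
    response = client.embeddings.create(model="local-model", input=texts)
    return np.array([item.embedding for item in response.data])

chunk_vectors = embed(chunks)

question = "How does the server handle many simultaneous requests?"
question_vector = embed([question])[0]

# Cosine similarity between the question and every chunk, then pick the best chunk.
scores = chunk_vectors @ question_vector / (
    np.linalg.norm(chunk_vectors, axis=1) * np.linalg.norm(question_vector)
)
context = chunks[int(np.argmax(scores))]

# Answer the question using the retrieved chunk as context.
answer = client.chat.completions.create(
    model="local-model",
    messages=[
        {"role": "system", "content": f"Answer using this context:\n{context}"},
        {"role": "user", "content": question},
    ],
)
print(answer.choices[0].message.content)
```

This mirrors the steps listed for the notebook above, just without the PDF parsing and the persistent vector store.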

605 | 606 | 607 | ## Large Models 608 | When using models that are large relative to your system's available RAM, performance may suffer. mlx-lm tries to improve speed by wiring the memory used by the model and its cache—this optimization is only available on macOS 15.0 or newer. 609 | If you see the following warning message: 610 | > [WARNING] Generating with a model that requires ... 611 | it means the model may run slowly on your machine. If the model fits in RAM, you can often improve performance by raising the system's wired memory limit. To do this, run: 612 | ```bash 613 | bash configure_mlx.sh 614 | ``` 615 | 616 | ## Contributing 617 | We welcome contributions to improve this project! Here's how you can contribute: 618 | 1. Fork the repository to your GitHub account. 619 | 2. Create a new branch for your feature or bug fix: 620 | ```bash 621 | git checkout -b feature-name 622 | ``` 623 | 3. Commit your changes with clear and concise messages: 624 | ```bash 625 | git commit -m "Add feature-name" 626 | ``` 627 | 4. Push your branch to your forked repository: 628 | ```bash 629 | git push origin feature-name 630 | ``` 631 | 5. Open a pull request to the main repository for review. 632 | 633 | ## License 634 | This project is licensed under the [MIT License](LICENSE). You are free to use, modify, and distribute it under the terms of the license. 635 | 636 | ## Support 637 | If you encounter any issues or have questions, please: 638 | - Open an issue in the repository. 639 | - Contact the maintainers via the provided contact information. 640 | 641 | Stay tuned for updates and enhancements! 642 | 643 | ## Acknowledgments 644 | 645 | We extend our heartfelt gratitude to the following individuals and organizations whose contributions have been instrumental in making this project possible: 646 | 647 | ### Core Technologies 648 | - [MLX team](https://github.com/ml-explore/mlx) for developing the groundbreaking MLX framework, which provides the foundation for efficient machine learning on Apple Silicon 649 | - [mlx-lm](https://github.com/ml-explore/mlx-lm) for efficient large language models support 650 | - [mlx-vlm](https://github.com/Blaizzy/mlx-vlm/tree/main) for pioneering multimodal model support within the MLX ecosystem 651 | - [mlx-community](https://huggingface.co/mlx-community) for curating and maintaining a diverse collection of high-quality MLX models 652 | 653 | ### Open Source Community 654 | We deeply appreciate the broader open-source community for their invaluable contributions. Your dedication to: 655 | - Innovation in machine learning and AI 656 | - Collaborative development practices 657 | - Knowledge sharing and documentation 658 | - Continuous improvement of tools and frameworks 659 | 660 | Your collective efforts continue to drive progress and make projects like this possible. We are proud to be part of this vibrant ecosystem. 661 | 662 | ### Special Thanks 663 | A special acknowledgment to all contributors, users, and supporters who have helped shape this project through their feedback, bug reports, and suggestions. Your engagement helps make this project better for everyone. 
-------------------------------------------------------------------------------- /app/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from app.version import __version__ 3 | 4 | # Suppress transformers warnings 5 | os.environ['TRANSFORMERS_VERBOSITY'] = 'error' 6 | 7 | __all__ = ["__version__"] -------------------------------------------------------------------------------- /app/api/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /app/api/endpoints.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import time 4 | from http import HTTPStatus 5 | from typing import Any, Dict, List, Optional, Union, AsyncGenerator 6 | 7 | from fastapi import APIRouter, Request 8 | from fastapi.responses import JSONResponse, StreamingResponse 9 | from loguru import logger 10 | 11 | from app.handler.mlx_lm import MLXLMHandler 12 | from app.schemas.openai import (ChatCompletionChunk, 13 | ChatCompletionMessageToolCall, 14 | ChatCompletionRequest, ChatCompletionResponse, 15 | Choice, ChoiceDeltaFunctionCall, 16 | ChoiceDeltaToolCall, Delta, Embedding, 17 | EmbeddingRequest, EmbeddingResponse, 18 | FunctionCall, Message, Model, ModelsResponse, 19 | StreamingChoice) 20 | from app.utils.errors import create_error_response 21 | 22 | router = APIRouter() 23 | 24 | 25 | @router.post("/health") 26 | async def health(raw_request: Request): 27 | """ 28 | Health check endpoint. 29 | """ 30 | try: 31 | return {"status": "ok"} 32 | except Exception as e: 33 | logger.error(f"Health check failed: {str(e)}") 34 | return JSONResponse(content=create_error_response("Health check failed", "server_error", 500), status_code=500) 35 | 36 | @router.get("/v1/queue/stats") 37 | async def queue_stats(raw_request: Request): 38 | """ 39 | Get queue statistics. 40 | """ 41 | handler = raw_request.app.state.handler 42 | if handler is None: 43 | return JSONResponse(content=create_error_response("Model handler not initialized", "service_unavailable", 503), status_code=503) 44 | 45 | try: 46 | stats = await handler.get_queue_stats() 47 | return { 48 | "status": "ok", 49 | "queue_stats": stats 50 | } 51 | except Exception as e: 52 | logger.error(f"Failed to get queue stats: {str(e)}") 53 | return JSONResponse(content=create_error_response("Failed to get queue stats", "server_error", 500), status_code=500) 54 | 55 | @router.get("/v1/models") 56 | async def models(raw_request: Request): 57 | """ 58 | Get list of available models. 
59 | """ 60 | handler = raw_request.app.state.handler 61 | models_data = handler.get_models() 62 | return ModelsResponse(data=[Model(**model) for model in models_data]) 63 | 64 | @router.post("/v1/chat/completions") 65 | async def chat_completions(request: ChatCompletionRequest, raw_request: Request): 66 | """Handle chat completion requests.""" 67 | 68 | handler = raw_request.app.state.handler 69 | if handler is None: 70 | return JSONResponse(content=create_error_response("Model handler not initialized", "service_unavailable", 503), status_code=503) 71 | 72 | try: 73 | # Check if this is a vision request 74 | is_vision_request = request.is_vision_request() 75 | # If it's a vision request but the handler is MLXLMHandler (text-only), reject it 76 | if is_vision_request and isinstance(handler, MLXLMHandler): 77 | return JSONResponse( 78 | content=create_error_response( 79 | "Vision requests are not supported with text-only models. Use a VLM model type instead.", 80 | "unsupported_request", 81 | 400 82 | ), 83 | status_code=400 84 | ) 85 | 86 | # Process the request based on type 87 | return await process_vision_request(handler, request) if is_vision_request \ 88 | else await process_text_request(handler, request) 89 | except Exception as e: 90 | logger.error(f"Error processing chat completion request: {str(e)}", exc_info=True) 91 | return JSONResponse(content=create_error_response(str(e)), status_code=HTTPStatus.INTERNAL_SERVER_ERROR) 92 | 93 | @router.post("/v1/embeddings") 94 | async def embeddings(request: EmbeddingRequest, raw_request: Request): 95 | """Handle embedding requests.""" 96 | handler = raw_request.app.state.handler 97 | if handler is None: 98 | return JSONResponse(content=create_error_response("Model handler not initialized", "service_unavailable", 503), status_code=503) 99 | 100 | try: 101 | embeddings = await handler.generate_embeddings_response(request) 102 | return create_response_embeddings(embeddings, request.model) 103 | except Exception as e: 104 | logger.error(f"Error processing embedding request: {str(e)}", exc_info=True) 105 | return JSONResponse(content=create_error_response(str(e)), status_code=HTTPStatus.INTERNAL_SERVER_ERROR) 106 | 107 | def create_response_embeddings(embeddings: List[float], model: str) -> EmbeddingResponse: 108 | embeddings_response = [] 109 | for index, embedding in enumerate(embeddings): 110 | embeddings_response.append(Embedding(embedding=embedding, index=index)) 111 | return EmbeddingResponse(data=embeddings_response, model=model) 112 | 113 | def create_response_chunk(chunk: Union[str, Dict[str, Any]], model: str, is_final: bool = False, finish_reason: Optional[str] = "stop", chat_id: Optional[str] = None, created_time: Optional[int] = None) -> ChatCompletionChunk: 114 | """Create a formatted response chunk for streaming.""" 115 | chat_id = chat_id if chat_id else get_id() 116 | created_time = created_time if created_time else int(time.time()) 117 | if isinstance(chunk, str): 118 | return ChatCompletionChunk( 119 | id=chat_id, 120 | object="chat.completion.chunk", 121 | created=created_time, 122 | model=model, 123 | choices=[StreamingChoice( 124 | index=0, 125 | delta=Delta(content=chunk, role="assistant"), 126 | finish_reason=finish_reason if is_final else None 127 | )] 128 | ) 129 | if "reasoning_content" in chunk: 130 | return ChatCompletionChunk( 131 | id=chat_id, 132 | object="chat.completion.chunk", 133 | created=created_time, 134 | model=model, 135 | choices=[StreamingChoice( 136 | index=0, 137 | 
delta=Delta(reasoning_content=chunk["reasoning_content"], role="assistant"), 138 | finish_reason=finish_reason if is_final else None 139 | )] 140 | ) 141 | 142 | if "name" in chunk and chunk["name"]: 143 | tool_chunk = ChoiceDeltaToolCall( 144 | index=chunk["index"], 145 | type="function", 146 | id=get_tool_call_id(), 147 | function=ChoiceDeltaFunctionCall( 148 | name=chunk["name"], 149 | arguments="" 150 | ) 151 | ) 152 | else: 153 | tool_chunk = ChoiceDeltaToolCall( 154 | index=chunk["index"], 155 | type="function", 156 | function= ChoiceDeltaFunctionCall( 157 | arguments=chunk["arguments"] 158 | ) 159 | ) 160 | delta = Delta( 161 | content = None, 162 | role = "assistant", 163 | tool_calls = [tool_chunk] 164 | ) 165 | return ChatCompletionChunk( 166 | id=chat_id, 167 | object="chat.completion.chunk", 168 | created=created_time, 169 | model=model, 170 | choices=[StreamingChoice(index=0, delta=delta, finish_reason=None)] 171 | ) 172 | 173 | 174 | async def handle_stream_response(generator: AsyncGenerator, model: str): 175 | """Handle streaming response generation (OpenAI-compatible).""" 176 | chat_index = get_id() 177 | created_time = int(time.time()) 178 | try: 179 | finish_reason = "stop" 180 | index = -1 181 | # First chunk: role-only delta, as per OpenAI 182 | first_chunk = ChatCompletionChunk( 183 | id=chat_index, 184 | object="chat.completion.chunk", 185 | created=created_time, 186 | model=model, 187 | choices=[StreamingChoice(index=0, delta=Delta(role="assistant"), finish_reason=None)] 188 | ) 189 | yield f"data: {json.dumps(first_chunk.model_dump())}\n\n" 190 | async for chunk in generator: 191 | if chunk: 192 | if isinstance(chunk, str): 193 | response_chunk = create_response_chunk(chunk, model, chat_id=chat_index, created_time=created_time) 194 | yield f"data: {json.dumps(response_chunk.model_dump())}\n\n" 195 | else: 196 | finish_reason = "tool_calls" 197 | if "name" in chunk and chunk["name"]: 198 | index += 1 199 | payload = { 200 | "index": index, 201 | **chunk 202 | } 203 | response_chunk = create_response_chunk(payload, model, chat_id=chat_index, created_time=created_time) 204 | yield f"data: {json.dumps(response_chunk.model_dump())}\n\n" 205 | except Exception as e: 206 | logger.error(f"Error in stream wrapper: {str(e)}") 207 | error_response = create_error_response(str(e), "server_error", HTTPStatus.INTERNAL_SERVER_ERROR) 208 | # Yield error as last chunk before [DONE] 209 | yield f"data: {json.dumps(error_response)}\n\n" 210 | finally: 211 | # Final chunk: finish_reason and [DONE], as per OpenAI 212 | final_chunk = create_response_chunk('', model, is_final=True, finish_reason=finish_reason, chat_id=chat_index) 213 | yield f"data: {json.dumps(final_chunk.model_dump())}\n\n" 214 | yield "data: [DONE]\n\n" 215 | 216 | async def process_vision_request(handler, request: ChatCompletionRequest): 217 | """Process vision-specific requests.""" 218 | if request.stream: 219 | return StreamingResponse( 220 | handle_stream_response(handler.generate_vision_stream(request), request.model), 221 | media_type="text/event-stream", 222 | headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"} 223 | ) 224 | return format_final_response(await handler.generate_vision_response(request), request.model) 225 | 226 | async def process_text_request(handler, request: ChatCompletionRequest): 227 | """Process text-only requests.""" 228 | if request.stream: 229 | return StreamingResponse( 230 | handle_stream_response(handler.generate_text_stream(request), 
request.model), 231 | media_type="text/event-stream", 232 | headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"} 233 | ) 234 | return format_final_response(await handler.generate_text_response(request), request.model) 235 | 236 | def get_id(): 237 | """ 238 | Generate a unique ID for chat completions with timestamp and random component. 239 | """ 240 | timestamp = int(time.time()) 241 | random_suffix = random.randint(0, 999999) 242 | return f"chatcmpl_{timestamp}{random_suffix:06d}" 243 | 244 | def get_tool_call_id(): 245 | """ 246 | Generate a unique ID for tool calls with timestamp and random component. 247 | """ 248 | timestamp = int(time.time()) 249 | random_suffix = random.randint(0, 999999) 250 | return f"call_{timestamp}{random_suffix:06d}" 251 | 252 | def format_final_response(response: Union[str, List[Dict[str, Any]]], model: str) -> ChatCompletionResponse: 253 | """Format the final non-streaming response.""" 254 | 255 | if isinstance(response, str): 256 | return ChatCompletionResponse( 257 | id=get_id(), 258 | object="chat.completion", 259 | created=int(time.time()), 260 | model=model, 261 | choices=[Choice( 262 | index=0, 263 | message=Message(role="assistant", content=response), 264 | finish_reason="stop" 265 | )] 266 | ) 267 | 268 | reasoning_content = response.get("reasoning_content", None) 269 | tool_calls = response.get("tool_calls", []) 270 | tool_call_responses = [] 271 | for idx, tool_call in enumerate(tool_calls): 272 | function_call = FunctionCall( 273 | name=tool_call.get("name"), 274 | arguments=json.dumps(tool_call.get("arguments")) 275 | ) 276 | tool_call_response = ChatCompletionMessageToolCall( 277 | id=get_tool_call_id(), 278 | type="function", 279 | function=function_call, 280 | index=idx 281 | ) 282 | tool_call_responses.append(tool_call_response) 283 | 284 | if len(tool_calls) > 0: 285 | message = Message(role="assistant", reasoning_content=reasoning_content, tool_calls=tool_call_responses) 286 | else: 287 | message = Message(role="assistant", content=response, reasoning_content=reasoning_content, tool_calls=tool_call_responses) 288 | 289 | return ChatCompletionResponse( 290 | id=get_id(), 291 | object="chat.completion", 292 | created=int(time.time()), 293 | model=model, 294 | choices=[Choice( 295 | index=0, 296 | message=message, 297 | finish_reason="tool_calls" 298 | )] 299 | ) -------------------------------------------------------------------------------- /app/cli.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import click 3 | import uvicorn 4 | from loguru import logger 5 | import sys 6 | from functools import lru_cache 7 | from app.version import __version__ 8 | from app.main import setup_server 9 | 10 | class Config: 11 | """Configuration container for server parameters.""" 12 | def __init__(self, model_path, model_type, port, host, max_concurrency, queue_timeout, queue_size): 13 | self.model_path = model_path 14 | self.model_type = model_type 15 | self.port = port 16 | self.host = host 17 | self.max_concurrency = max_concurrency 18 | self.queue_timeout = queue_timeout 19 | self.queue_size = queue_size 20 | 21 | 22 | # Configure Loguru once at module level 23 | def configure_logging(): 24 | """Set up optimized logging configuration.""" 25 | logger.remove() # Remove default handler 26 | logger.add( 27 | sys.stderr, 28 | format="{time:YYYY-MM-DD HH:mm:ss} | " 29 | "{level: <8} | " 30 | "{name}:{function}:{line} | " 31 | "✦ {message}", 32 | colorize=True, 
33 | level="INFO" 34 | ) 35 | 36 | # Apply logging configuration 37 | configure_logging() 38 | 39 | 40 | @click.group() 41 | @click.version_option( 42 | version=__version__, 43 | message=""" 44 | ✨ %(prog)s - OpenAI Compatible API Server for MLX models ✨ 45 | ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 46 | 🚀 Version: %(version)s 47 | """ 48 | ) 49 | def cli(): 50 | """MLX Server - OpenAI Compatible API for MLX models.""" 51 | pass 52 | 53 | 54 | @lru_cache(maxsize=1) 55 | def get_server_config(model_path, model_type, port, host, max_concurrency, queue_timeout, queue_size): 56 | """Cache and return server configuration to avoid redundant processing.""" 57 | return Config( 58 | model_path=model_path, 59 | model_type=model_type, 60 | port=port, 61 | host=host, 62 | max_concurrency=max_concurrency, 63 | queue_timeout=queue_timeout, 64 | queue_size=queue_size 65 | ) 66 | 67 | 68 | def print_startup_banner(args): 69 | """Display beautiful startup banner with configuration details.""" 70 | logger.info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") 71 | logger.info(f"✨ MLX Server v{__version__} Starting ✨") 72 | logger.info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") 73 | logger.info(f"🔮 Model: {args.model_path}") 74 | logger.info(f"🔮 Model Type: {args.model_type}") 75 | logger.info(f"🌐 Host: {args.host}") 76 | logger.info(f"🔌 Port: {args.port}") 77 | logger.info(f"⚡ Max Concurrency: {args.max_concurrency}") 78 | logger.info(f"⏱️ Queue Timeout: {args.queue_timeout} seconds") 79 | logger.info(f"📊 Queue Size: {args.queue_size}") 80 | logger.info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") 81 | 82 | 83 | @cli.command() 84 | @click.option( 85 | "--model-path", 86 | required=True, 87 | help="Path to the model" 88 | ) 89 | @click.option( 90 | "--model-type", 91 | default="lm", 92 | type=click.Choice(["lm", "vlm"]), 93 | help="Type of model to run" 94 | ) 95 | @click.option( 96 | "--port", 97 | default=8000, 98 | type=int, 99 | help="Port to run the server on" 100 | ) 101 | @click.option( 102 | "--host", 103 | default="0.0.0.0", 104 | help="Host to run the server on" 105 | ) 106 | @click.option( 107 | "--max-concurrency", 108 | default=1, 109 | type=int, 110 | help="Maximum number of concurrent requests" 111 | ) 112 | @click.option( 113 | "--queue-timeout", 114 | default=300, 115 | type=int, 116 | help="Request timeout in seconds" 117 | ) 118 | @click.option( 119 | "--queue-size", 120 | default=100, 121 | type=int, 122 | help="Maximum queue size for pending requests" 123 | ) 124 | def launch(model_path, model_type, port, host, max_concurrency, queue_timeout, queue_size): 125 | """Launch the MLX server with the specified model.""" 126 | try: 127 | # Get optimized configuration 128 | args = get_server_config(model_path, model_type, port, host, max_concurrency, queue_timeout, queue_size) 129 | 130 | # Display startup information 131 | print_startup_banner(args) 132 | 133 | # Set up and start the server 134 | config = asyncio.run(setup_server(args)) 135 | logger.info("Server configuration complete.") 136 | logger.info("Starting Uvicorn server...") 137 | uvicorn.Server(config).run() 138 | except KeyboardInterrupt: 139 | logger.info("Server shutdown requested by user. 
Exiting...") 140 | except Exception as e: 141 | logger.error(f"Server startup failed: {str(e)}") 142 | sys.exit(1) 143 | 144 | 145 | if __name__ == "__main__": 146 | cli() -------------------------------------------------------------------------------- /app/core/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /app/core/image_processor.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import base64 3 | import hashlib 4 | import os 5 | from concurrent.futures import ThreadPoolExecutor 6 | from functools import lru_cache 7 | from io import BytesIO 8 | from typing import List, Optional 9 | import tempfile 10 | import aiohttp 11 | import time 12 | from PIL import Image 13 | from loguru import logger 14 | 15 | class ImageProcessor: 16 | def __init__(self, max_workers: int = 4, cache_size: int = 1000): 17 | # Use tempfile for macOS-efficient temporary file handling 18 | self.temp_dir = tempfile.TemporaryDirectory() 19 | self._session: Optional[aiohttp.ClientSession] = None 20 | self.executor = ThreadPoolExecutor(max_workers=max_workers) 21 | self._cache_size = cache_size 22 | self._last_cleanup = time.time() 23 | self._cleanup_interval = 3600 # 1 hour 24 | Image.MAX_IMAGE_PIXELS = 100000000 # Limit to 100 megapixels 25 | 26 | @lru_cache(maxsize=1000) 27 | def _get_image_hash(self, image_url: str) -> str: 28 | if image_url.startswith("data:"): 29 | _, encoded = image_url.split(",", 1) 30 | data = base64.b64decode(encoded) 31 | else: 32 | data = image_url.encode('utf-8') 33 | return hashlib.md5(data).hexdigest() 34 | 35 | def _resize_image_keep_aspect_ratio(self, image: Image.Image, max_size: int = 448) -> Image.Image: 36 | width, height = image.size 37 | if width <= max_size and height <= max_size: 38 | return image 39 | if width > height: 40 | new_width = max_size 41 | new_height = int(height * max_size / width) 42 | else: 43 | new_height = max_size 44 | new_width = int(width * max_size / height) 45 | 46 | return image.resize((new_width, new_height), Image.Resampling.LANCZOS) 47 | 48 | def _prepare_image_for_saving(self, image: Image.Image) -> Image.Image: 49 | if image.mode in ('RGBA', 'LA'): 50 | background = Image.new('RGB', image.size, (255, 255, 255)) 51 | if image.mode == 'RGBA': 52 | background.paste(image, mask=image.split()[3]) 53 | else: 54 | background.paste(image, mask=image.split()[1]) 55 | return background 56 | elif image.mode != 'RGB': 57 | return image.convert('RGB') 58 | return image 59 | 60 | async def _get_session(self) -> aiohttp.ClientSession: 61 | if self._session is None or self._session.closed: 62 | self._session = aiohttp.ClientSession( 63 | timeout=aiohttp.ClientTimeout(total=30), 64 | headers={"User-Agent": "mlx-server-OAI-compat/1.0"} 65 | ) 66 | return self._session 67 | 68 | def _cleanup_old_files(self): 69 | current_time = time.time() 70 | if current_time - self._last_cleanup > self._cleanup_interval: 71 | try: 72 | for file in os.listdir(self.temp_dir.name): 73 | file_path = os.path.join(self.temp_dir.name, file) 74 | if os.path.getmtime(file_path) < current_time - self._cleanup_interval: 75 | os.remove(file_path) 76 | self._last_cleanup = current_time 77 | except Exception as e: 78 | logger.warning(f"Failed to clean up old files: {str(e)}") 79 | 80 | async def _process_single_image(self, image_url: str) -> str: 81 | try: 82 | image_hash = self._get_image_hash(image_url) 83 
| cached_path = os.path.join(self.temp_dir.name, f"{image_hash}.jpg") 84 | 85 | if os.path.exists(cached_path): 86 | logger.debug(f"Using cached image: {cached_path}") 87 | return cached_path 88 | 89 | if os.path.exists(image_url): 90 | # Read-only image loading for memory efficiency 91 | with Image.open(image_url, mode='r') as image: 92 | image = self._resize_image_keep_aspect_ratio(image) 93 | image = self._prepare_image_for_saving(image) 94 | image.save(cached_path, 'JPEG', quality=100, optimize=True) 95 | return cached_path 96 | 97 | elif image_url.startswith("data:"): 98 | _, encoded = image_url.split(",", 1) 99 | estimated_size = len(encoded) * 3 / 4 100 | if estimated_size > 100 * 1024 * 1024: 101 | raise ValueError("Base64-encoded image exceeds 100 MB") 102 | data = base64.b64decode(encoded) 103 | with Image.open(BytesIO(data), mode='r') as image: 104 | image = self._resize_image_keep_aspect_ratio(image) 105 | image = self._prepare_image_for_saving(image) 106 | image.save(cached_path, 'JPEG', quality=100, optimize=True) 107 | else: 108 | session = await self._get_session() 109 | async with session.get(image_url) as response: 110 | response.raise_for_status() 111 | data = await response.read() 112 | with Image.open(BytesIO(data), mode='r') as image: 113 | image = self._resize_image_keep_aspect_ratio(image) 114 | image = self._prepare_image_for_saving(image) 115 | image.save(cached_path, 'JPEG', quality=100, optimize=True) 116 | 117 | self._cleanup_old_files() 118 | return cached_path 119 | 120 | except Exception as e: 121 | logger.error(f"Failed to process image: {str(e)}") 122 | raise ValueError(f"Failed to process image: {str(e)}") 123 | 124 | async def process_image_url(self, image_url: str) -> str: 125 | return await self._process_single_image(image_url) 126 | 127 | async def process_image_urls(self, image_urls: List[str]) -> List[str]: 128 | tasks = [self.process_image_url(url) for url in image_urls] 129 | return await asyncio.gather(*tasks, return_exceptions=True) 130 | 131 | async def cleanup(self): 132 | if hasattr(self, '_cleaned') and self._cleaned: 133 | return 134 | self._cleaned = True 135 | try: 136 | if self._session and not self._session.closed: 137 | await self._session.close() 138 | except Exception as e: 139 | logger.warning(f"Exception closing aiohttp session: {str(e)}") 140 | try: 141 | self.executor.shutdown(wait=True) 142 | except Exception as e: 143 | logger.warning(f"Exception shutting down executor: {str(e)}") 144 | try: 145 | self.temp_dir.cleanup() 146 | except Exception as e: 147 | logger.warning(f"Exception cleaning up temp_dir: {str(e)}") 148 | 149 | async def __aenter__(self): 150 | """Enter the async context manager.""" 151 | return self 152 | 153 | async def __aexit__(self, exc_type, exc, tb): 154 | """Exit the async context manager and perform cleanup.""" 155 | await self.cleanup() 156 | 157 | def __del__(self): 158 | # Async cleanup cannot be reliably performed in __del__ 159 | # Please use 'async with ImageProcessor()' or call 'await cleanup()' explicitly. 160 | pass -------------------------------------------------------------------------------- /app/core/queue.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import time 3 | from typing import Any, Dict, Optional, Callable, Awaitable, TypeVar, Generic 4 | from loguru import logger 5 | 6 | T = TypeVar('T') 7 | 8 | class RequestItem(Generic[T]): 9 | """ 10 | Represents a single request in the queue. 
11 | """ 12 | def __init__(self, request_id: str, data: Any): 13 | self.request_id = request_id 14 | self.data = data 15 | self.created_at = time.time() 16 | self.future = asyncio.Future() 17 | 18 | def set_result(self, result: T) -> None: 19 | """Set the result for this request.""" 20 | if not self.future.done(): 21 | self.future.set_result(result) 22 | 23 | def set_exception(self, exc: Exception) -> None: 24 | """Set an exception for this request.""" 25 | if not self.future.done(): 26 | self.future.set_exception(exc) 27 | 28 | async def get_result(self) -> T: 29 | """Wait for and return the result of this request.""" 30 | return await self.future 31 | 32 | class RequestQueue: 33 | """ 34 | A simple asynchronous request queue with configurable concurrency. 35 | """ 36 | def __init__(self, max_concurrency: int = 2, timeout: float = 300.0, queue_size: int = 100): 37 | """ 38 | Initialize the request queue. 39 | 40 | Args: 41 | max_concurrency (int): Maximum number of concurrent requests to process. 42 | timeout (float): Timeout in seconds for request processing. 43 | queue_size (int): Maximum queue size. 44 | """ 45 | self.max_concurrency = max_concurrency 46 | self.timeout = timeout 47 | self.queue_size = queue_size 48 | self.semaphore = asyncio.Semaphore(max_concurrency) 49 | self.queue = asyncio.Queue(maxsize=queue_size) 50 | self.active_requests: Dict[str, RequestItem] = {} 51 | self._worker_task = None 52 | self._running = False 53 | 54 | async def start(self, processor: Callable[[Any], Awaitable[Any]]): 55 | """ 56 | Start the queue worker. 57 | 58 | Args: 59 | processor: Async function that processes queue items. 60 | """ 61 | if self._running: 62 | return 63 | 64 | self._running = True 65 | self._worker_task = asyncio.create_task(self._worker_loop(processor)) 66 | logger.info(f"Started request queue with max concurrency: {self.max_concurrency}") 67 | 68 | async def stop(self): 69 | """Stop the queue worker.""" 70 | if not self._running: 71 | return 72 | 73 | self._running = False 74 | 75 | # Cancel the worker task 76 | if self._worker_task and not self._worker_task.done(): 77 | self._worker_task.cancel() 78 | try: 79 | await self._worker_task 80 | except asyncio.CancelledError: 81 | pass 82 | 83 | # Cancel all pending requests 84 | pending_requests = list(self.active_requests.values()) 85 | for request in pending_requests: 86 | if not request.future.done(): 87 | request.future.cancel() 88 | 89 | self.active_requests.clear() 90 | logger.info("Stopped request queue") 91 | 92 | async def _worker_loop(self, processor: Callable[[Any], Awaitable[Any]]): 93 | """ 94 | Main worker loop that processes queue items. 95 | 96 | Args: 97 | processor: Async function that processes queue items. 98 | """ 99 | while self._running: 100 | try: 101 | # Get the next item from the queue 102 | request = await self.queue.get() 103 | 104 | # Process the request with concurrency control 105 | asyncio.create_task(self._process_request(request, processor)) 106 | 107 | except asyncio.CancelledError: 108 | break 109 | except Exception as e: 110 | logger.error(f"Error in worker loop: {str(e)}") 111 | 112 | async def _process_request(self, request: RequestItem, processor: Callable[[Any], Awaitable[Any]]): 113 | """ 114 | Process a single request with timeout and error handling. 115 | 116 | Args: 117 | request: The request to process. 118 | processor: Async function that processes the request. 
119 | """ 120 | # Use semaphore to limit concurrency 121 | async with self.semaphore: 122 | try: 123 | # Process with timeout 124 | processing_start = time.time() 125 | result = await asyncio.wait_for( 126 | processor(request.data), 127 | timeout=self.timeout 128 | ) 129 | processing_time = time.time() - processing_start 130 | 131 | # Set the result 132 | request.set_result(result) 133 | logger.info(f"Request {request.request_id} processed in {processing_time:.2f}s") 134 | 135 | except asyncio.TimeoutError: 136 | request.set_exception(TimeoutError(f"Request processing timed out after {self.timeout}s")) 137 | logger.warning(f"Request {request.request_id} timed out after {self.timeout}s") 138 | 139 | except Exception as e: 140 | request.set_exception(e) 141 | logger.error(f"Error processing request {request.request_id}: {str(e)}") 142 | 143 | finally: 144 | # Remove from active requests 145 | self.active_requests.pop(request.request_id, None) 146 | 147 | async def enqueue(self, request_id: str, data: Any) -> RequestItem: 148 | """ 149 | Add a request to the queue. 150 | 151 | Args: 152 | request_id: Unique ID for the request. 153 | data: The request data to process. 154 | 155 | Returns: 156 | RequestItem: The queued request item. 157 | 158 | Raises: 159 | asyncio.QueueFull: If the queue is full. 160 | """ 161 | if not self._running: 162 | raise RuntimeError("Queue is not running") 163 | 164 | # Create request item 165 | request = RequestItem(request_id, data) 166 | 167 | # Add to active requests and queue 168 | self.active_requests[request_id] = request 169 | 170 | try: 171 | # This will raise QueueFull if the queue is full 172 | await asyncio.wait_for( 173 | self.queue.put(request), 174 | timeout=1.0 # Short timeout for queue put 175 | ) 176 | queue_time = time.time() - request.created_at 177 | logger.info(f"Request {request_id} queued (wait: {queue_time:.2f}s)") 178 | return request 179 | 180 | except asyncio.TimeoutError: 181 | self.active_requests.pop(request_id, None) 182 | raise asyncio.QueueFull("Request queue is full and timed out waiting for space") 183 | 184 | async def submit(self, request_id: str, data: Any) -> Any: 185 | """ 186 | Submit a request and wait for its result. 187 | 188 | Args: 189 | request_id: Unique ID for the request. 190 | data: The request data to process. 191 | 192 | Returns: 193 | The result of processing the request. 194 | 195 | Raises: 196 | Various exceptions that may occur during processing. 197 | """ 198 | request = await self.enqueue(request_id, data) 199 | return await request.get_result() 200 | 201 | def get_queue_stats(self) -> Dict[str, Any]: 202 | """ 203 | Get queue statistics. 204 | 205 | Returns: 206 | Dict with queue statistics. 207 | """ 208 | return { 209 | "running": self._running, 210 | "queue_size": self.queue.qsize(), 211 | "max_queue_size": self.queue_size, 212 | "active_requests": len(self.active_requests), 213 | "max_concurrency": self.max_concurrency 214 | } 215 | 216 | # Alias for the async stop method to maintain consistency in cleanup interfaces 217 | async def stop_async(self): 218 | """Alias for stop - stops the queue worker asynchronously.""" 219 | await self.stop() -------------------------------------------------------------------------------- /app/handler/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | MLX model handlers for text and vision-language models. 
3 | """ 4 | 5 | from app.handler.mlx_lm import MLXLMHandler 6 | from app.handler.mlx_vlm import MLXVLMHandler 7 | 8 | __all__ = [ 9 | "MLXLMHandler", 10 | "MLXVLMHandler" 11 | ] 12 | -------------------------------------------------------------------------------- /app/handler/mlx_lm.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import time 3 | import uuid 4 | from http import HTTPStatus 5 | from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple 6 | 7 | from fastapi import HTTPException 8 | from loguru import logger 9 | 10 | from app.core.queue import RequestQueue 11 | from app.handler.parser import get_parser 12 | from app.models.mlx_lm import MLX_LM 13 | from app.schemas.openai import ChatCompletionRequest, EmbeddingRequest 14 | from app.utils.errors import create_error_response 15 | 16 | class MLXLMHandler: 17 | """ 18 | Handler class for making requests to the underlying MLX text-only language model service. 19 | Provides request queuing, metrics tracking, and robust error handling. 20 | """ 21 | 22 | def __init__(self, model_path: str, max_concurrency: int = 1): 23 | """ 24 | Initialize the handler with the specified model path. 25 | 26 | Args: 27 | model_path (str): Path to the model directory. 28 | max_concurrency (int): Maximum number of concurrent model inference tasks. 29 | """ 30 | self.model_path = model_path 31 | self.model = MLX_LM(model_path) 32 | self.model_type = self.model.get_model_type() 33 | self.model_created = int(time.time()) # Store creation time when model is loaded 34 | 35 | # Initialize request queue for text tasks 36 | self.request_queue = RequestQueue(max_concurrency=max_concurrency) 37 | 38 | logger.info(f"Initialized MLXHandler with model path: {model_path}") 39 | 40 | def get_models(self) -> List[Dict[str, Any]]: 41 | """ 42 | Get list of available models with their metadata. 43 | """ 44 | return [{ 45 | "id": self.model_path, 46 | "object": "model", 47 | "created": self.model_created, 48 | "owned_by": "local" 49 | }] 50 | 51 | async def initialize(self, queue_config: Optional[Dict[str, Any]] = None): 52 | """Initialize the handler and start the request queue.""" 53 | if not queue_config: 54 | queue_config = { 55 | "max_concurrency": 1, 56 | "timeout": 300, 57 | "queue_size": 100 58 | } 59 | self.request_queue = RequestQueue( 60 | max_concurrency=queue_config.get("max_concurrency"), 61 | timeout=queue_config.get("timeout"), 62 | queue_size=queue_config.get("queue_size") 63 | ) 64 | await self.request_queue.start(self._process_request) 65 | logger.info("Initialized MLXHandler and started request queue") 66 | 67 | async def generate_text_stream(self, request: ChatCompletionRequest) -> AsyncGenerator[str, None]: 68 | """ 69 | Generate a streaming response for text-only chat completion requests. 70 | Uses the request queue for handling concurrent requests. 71 | 72 | Args: 73 | request: ChatCompletionRequest object containing the messages. 74 | 75 | Yields: 76 | str: Response chunks. 
77 | """ 78 | request_id = f"text-{uuid.uuid4()}" 79 | 80 | try: 81 | chat_messages, model_params = await self._prepare_text_request(request) 82 | request_data = { 83 | "messages": chat_messages, 84 | "stream": True, 85 | **model_params 86 | } 87 | response_generator = await self.request_queue.submit(request_id, request_data) 88 | 89 | tools = model_params.get("chat_template_kwargs", {}).get("tools", None) 90 | enable_thinking = model_params.get("chat_template_kwargs", {}).get("enable_thinking", None) 91 | 92 | tool_parser, thinking_parser = get_parser(self.model_type) 93 | if enable_thinking and thinking_parser: 94 | for chunk in response_generator: 95 | if chunk: 96 | chunk, is_finish = thinking_parser.parse_stream(chunk.text) 97 | if chunk: 98 | yield chunk 99 | if is_finish: 100 | break 101 | 102 | if tools and tool_parser: 103 | for chunk in response_generator: 104 | if chunk: 105 | chunk = tool_parser.parse_stream(chunk.text) 106 | if chunk: 107 | yield chunk 108 | else: 109 | for chunk in response_generator: 110 | if chunk: 111 | yield chunk.text 112 | 113 | except asyncio.QueueFull: 114 | logger.error("Too many requests. Service is at capacity.") 115 | content = create_error_response("Too many requests. Service is at capacity.", "rate_limit_exceeded", HTTPStatus.TOO_MANY_REQUESTS) 116 | raise HTTPException(status_code=429, detail=content) 117 | except Exception as e: 118 | logger.error(f"Error in text stream generation for request {request_id}: {str(e)}") 119 | content = create_error_response(f"Failed to generate text stream: {str(e)}", "server_error", HTTPStatus.INTERNAL_SERVER_ERROR) 120 | raise HTTPException(status_code=500, detail=content) 121 | 122 | async def generate_text_response(self, request: ChatCompletionRequest) -> str: 123 | """ 124 | Generate a complete response for text-only chat completion requests. 125 | Uses the request queue for handling concurrent requests. 126 | 127 | Args: 128 | request: ChatCompletionRequest object containing the messages. 129 | 130 | Returns: 131 | str: Complete response. 132 | """ 133 | request_id = f"text-{uuid.uuid4()}" 134 | 135 | try: 136 | chat_messages, model_params = await self._prepare_text_request(request) 137 | request_data = { 138 | "messages": chat_messages, 139 | "stream": False, 140 | **model_params 141 | } 142 | response = await self.request_queue.submit(request_id, request_data) 143 | tools = model_params.get("chat_template_kwargs", {}).get("tools", None) 144 | enable_thinking = model_params.get("chat_template_kwargs", {}).get("enable_thinking", None) 145 | if not tools and not enable_thinking: 146 | return response 147 | 148 | tool_parser, thinking_parser = get_parser(self.model_type) 149 | if not tool_parser and not thinking_parser: 150 | return response 151 | parsed_response = { 152 | "reasoning_content": None, 153 | "tool_calls": None, 154 | "content": None 155 | } 156 | if enable_thinking and thinking_parser: 157 | thinking_response, response = thinking_parser.parse(response) 158 | parsed_response["reasoning_content"] = thinking_response 159 | if tools and tool_parser: 160 | tool_response, response = tool_parser.parse(response) 161 | parsed_response["tool_calls"] = tool_response 162 | parsed_response["content"] = response 163 | 164 | return parsed_response 165 | 166 | except asyncio.QueueFull: 167 | logger.error("Too many requests. Service is at capacity.") 168 | content = create_error_response("Too many requests. 
Service is at capacity.", "rate_limit_exceeded", HTTPStatus.TOO_MANY_REQUESTS) 169 | raise HTTPException(status_code=429, detail=content) 170 | except Exception as e: 171 | logger.error(f"Error in text response generation: {str(e)}") 172 | content = create_error_response(f"Failed to generate text response: {str(e)}", "server_error", HTTPStatus.INTERNAL_SERVER_ERROR) 173 | raise HTTPException(status_code=500, detail=content) 174 | 175 | async def generate_embeddings_response(self, request: EmbeddingRequest): 176 | """ 177 | Generate embeddings for a given text input. 178 | 179 | Args: 180 | request: EmbeddingRequest object containing the text input. 181 | 182 | Returns: 183 | List[float]: Embeddings for the input text. 184 | """ 185 | try: 186 | # Create a unique request ID 187 | request_id = f"embeddings-{uuid.uuid4()}" 188 | request_data = { 189 | "type": "embeddings", 190 | "input": request.input, 191 | "model": request.model 192 | } 193 | 194 | # Submit to the request queue 195 | response = await self.request_queue.submit(request_id, request_data) 196 | 197 | return response 198 | 199 | except Exception as e: 200 | logger.error(f"Error in embeddings generation: {str(e)}") 201 | content = create_error_response(f"Failed to generate embeddings: {str(e)}", "server_error", HTTPStatus.INTERNAL_SERVER_ERROR) 202 | raise HTTPException(status_code=500, detail=content) 203 | 204 | 205 | async def _process_request(self, request_data: Dict[str, Any]) -> str: 206 | """ 207 | Process a text request. This is the worker function for the request queue. 208 | 209 | Args: 210 | request_data: Dictionary containing the request data. 211 | 212 | Returns: 213 | str: The model's response. 214 | """ 215 | try: 216 | # Check if the request is for embeddings 217 | if request_data.get("type") == "embeddings": 218 | return self.model.get_embeddings(request_data["input"]) 219 | 220 | # Extract request parameters 221 | messages = request_data.get("messages", []) 222 | stream = request_data.get("stream", False) 223 | 224 | # Remove these keys from model_params 225 | model_params = request_data.copy() 226 | model_params.pop("messages", None) 227 | model_params.pop("stream", None) 228 | 229 | # Call the model 230 | response = self.model( 231 | messages=messages, 232 | stream=stream, 233 | **model_params 234 | ) 235 | 236 | return response 237 | 238 | except Exception as e: 239 | logger.error(f"Error processing text request: {str(e)}") 240 | raise 241 | 242 | async def get_queue_stats(self) -> Dict[str, Any]: 243 | """ 244 | Get statistics from the request queue and performance metrics. 245 | 246 | Returns: 247 | Dict with queue and performance statistics. 248 | """ 249 | queue_stats = self.request_queue.get_queue_stats() 250 | 251 | return { 252 | "queue_stats": queue_stats, 253 | } 254 | 255 | async def cleanup(self): 256 | """ 257 | Cleanup resources and stop the request queue before shutdown. 258 | 259 | This method ensures all pending requests are properly cancelled 260 | and resources are released. 
261 | """ 262 | try: 263 | logger.info("Cleaning up MLXLMHandler resources") 264 | if hasattr(self, 'request_queue'): 265 | await self.request_queue.stop() 266 | logger.info("MLXLMHandler cleanup completed successfully") 267 | except Exception as e: 268 | logger.error(f"Error during MLXLMHandler cleanup: {str(e)}") 269 | raise 270 | 271 | async def _prepare_text_request(self, request: ChatCompletionRequest) -> Tuple[List[Dict[str, str]], Dict[str, Any]]: 272 | """ 273 | Prepare a text request by parsing model parameters and verifying the format of messages. 274 | 275 | Args: 276 | request: ChatCompletionRequest object containing the messages. 277 | 278 | Returns: 279 | Tuple containing the formatted chat messages and model parameters. 280 | """ 281 | try: 282 | # Get model parameters from the request 283 | temperature = request.temperature or 0.7 284 | top_p = request.top_p or 1.0 285 | frequency_penalty = request.frequency_penalty or 0.0 286 | presence_penalty = request.presence_penalty or 0.0 287 | max_tokens = request.max_tokens or 1024 288 | tools = request.tools or None 289 | chat_template_kwargs = request.chat_template_kwargs 290 | if tools: 291 | chat_template_kwargs.tools = tools 292 | 293 | model_params = { 294 | "temperature": temperature, 295 | "top_p": top_p, 296 | "frequency_penalty": frequency_penalty, 297 | "presence_penalty": presence_penalty, 298 | "max_tokens": max_tokens, 299 | "tools": tools, 300 | "chat_template_kwargs": chat_template_kwargs.model_dump() 301 | } 302 | 303 | # Format chat messages 304 | chat_messages = [] 305 | for message in request.messages: 306 | chat_messages.append({ 307 | "role": message.role, 308 | "content": message.content 309 | }) 310 | 311 | return chat_messages, model_params 312 | 313 | except Exception as e: 314 | logger.error(f"Failed to prepare text request: {str(e)}") 315 | content = create_error_response(f"Failed to process request: {str(e)}", "bad_request", HTTPStatus.BAD_REQUEST) 316 | raise HTTPException(status_code=400, detail=content) 317 | -------------------------------------------------------------------------------- /app/handler/mlx_vlm.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import base64 3 | import time 4 | import uuid 5 | from typing import Any, Dict, List, Optional, Tuple 6 | from http import HTTPStatus 7 | 8 | from fastapi import HTTPException 9 | from loguru import logger 10 | 11 | from app.core.image_processor import ImageProcessor 12 | from app.core.queue import RequestQueue 13 | from app.models.mlx_vlm import MLX_VLM 14 | from app.schemas.openai import ChatCompletionRequest, EmbeddingRequest 15 | from app.utils.errors import create_error_response 16 | 17 | class MLXVLMHandler: 18 | """ 19 | Handler class for making requests to the underlying MLX vision-language model service. 20 | Provides caching, concurrent image processing, and robust error handling. 21 | """ 22 | 23 | def __init__(self, model_path: str, max_workers: int = 4, max_concurrency: int = 1): 24 | """ 25 | Initialize the handler with the specified model path. 26 | 27 | Args: 28 | model_path (str): Path to the model directory. 29 | max_workers (int): Maximum number of worker threads for image processing. 30 | max_concurrency (int): Maximum number of concurrent model inference tasks. 
31 | """ 32 | self.model_path = model_path 33 | self.model = MLX_VLM(model_path) 34 | self.image_processor = ImageProcessor(max_workers) 35 | self.model_created = int(time.time()) # Store creation time when model is loaded 36 | 37 | # Initialize request queue for vision and text tasks 38 | # We use the same queue for both vision and text tasks for simplicity 39 | # and to ensure we don't overload the model with too many concurrent requests 40 | self.request_queue = RequestQueue(max_concurrency=max_concurrency) 41 | 42 | logger.info(f"Initialized MLXHandler with model path: {model_path}") 43 | 44 | def get_models(self) -> List[Dict[str, Any]]: 45 | """ 46 | Get list of available models with their metadata. 47 | """ 48 | return [{ 49 | "id": self.model_path, 50 | "object": "model", 51 | "created": self.model_created, 52 | "owned_by": "local" 53 | }] 54 | 55 | async def initialize(self, queue_config: Optional[Dict[str, Any]] = None): 56 | """Initialize the handler and start the request queue.""" 57 | 58 | if not queue_config: 59 | queue_config = { 60 | "max_concurrency": 1, 61 | "timeout": 300, 62 | "queue_size": 100 63 | } 64 | self.request_queue = RequestQueue( 65 | max_concurrency=queue_config.get("max_concurrency"), 66 | timeout=queue_config.get("timeout"), 67 | queue_size=queue_config.get("queue_size") 68 | ) 69 | await self.request_queue.start(self._process_request) 70 | logger.info("Initialized MLXHandler and started request queue") 71 | 72 | async def generate_vision_stream(self, request: ChatCompletionRequest): 73 | """ 74 | Generate a streaming response for vision-based chat completion requests. 75 | 76 | Args: 77 | request: ChatCompletionRequest object containing the messages. 78 | 79 | Returns: 80 | AsyncGenerator: Yields response chunks. 81 | """ 82 | 83 | # Create a unique request ID 84 | request_id = f"vision-{uuid.uuid4()}" 85 | 86 | try: 87 | chat_messages, image_paths, model_params = await self._prepare_vision_request(request) 88 | 89 | # Create a request data object 90 | request_data = { 91 | "images": image_paths, 92 | "messages": chat_messages, 93 | "stream": True, 94 | **model_params 95 | } 96 | 97 | # Submit to the vision queue and get the generator 98 | response_generator = await self.request_queue.submit(request_id, request_data) 99 | 100 | # Process and yield each chunk asynchronously 101 | for chunk in response_generator: 102 | if chunk: 103 | chunk = chunk.text 104 | yield chunk 105 | 106 | except asyncio.QueueFull: 107 | logger.error("Too many requests. Service is at capacity.") 108 | content = create_error_response("Too many requests. Service is at capacity.", "rate_limit_exceeded", HTTPStatus.TOO_MANY_REQUESTS) 109 | raise HTTPException(status_code=429, detail=content) 110 | 111 | except Exception as e: 112 | logger.error(f"Error in vision stream generation for request {request_id}: {str(e)}") 113 | content = create_error_response(f"Failed to generate vision stream: {str(e)}", "server_error", HTTPStatus.INTERNAL_SERVER_ERROR) 114 | raise HTTPException(status_code=500, detail=content) 115 | 116 | async def generate_vision_response(self, request: ChatCompletionRequest): 117 | """ 118 | Generate a complete response for vision-based chat completion requests. 119 | Uses the request queue for handling concurrent requests. 120 | 121 | Args: 122 | request: ChatCompletionRequest object containing the messages. 123 | 124 | Returns: 125 | str: Complete response. 
126 | """ 127 | try: 128 | # Create a unique request ID 129 | request_id = f"vision-{uuid.uuid4()}" 130 | 131 | # Prepare the vision request 132 | chat_messages, image_paths, model_params = await self._prepare_vision_request(request) 133 | 134 | # Create a request data object 135 | request_data = { 136 | "images": image_paths, 137 | "messages": chat_messages, 138 | "stream": False, 139 | **model_params 140 | } 141 | 142 | response = await self.request_queue.submit(request_id, request_data) 143 | 144 | return response 145 | 146 | except asyncio.QueueFull: 147 | logger.error("Too many requests. Service is at capacity.") 148 | content = create_error_response("Too many requests. Service is at capacity.", "rate_limit_exceeded", HTTPStatus.TOO_MANY_REQUESTS) 149 | raise HTTPException(status_code=429, detail=content) 150 | except Exception as e: 151 | logger.error(f"Error in vision response generation: {str(e)}") 152 | content = create_error_response(f"Failed to generate vision response: {str(e)}", "server_error", HTTPStatus.INTERNAL_SERVER_ERROR) 153 | raise HTTPException(status_code=500, detail=content) 154 | 155 | async def generate_text_stream(self, request: ChatCompletionRequest): 156 | """ 157 | Generate a streaming response for text-only chat completion requests. 158 | Uses the request queue for handling concurrent requests. 159 | 160 | Args: 161 | request: ChatCompletionRequest object containing the messages. 162 | 163 | Returns: 164 | AsyncGenerator: Yields response chunks. 165 | """ 166 | request_id = f"text-{uuid.uuid4()}" 167 | 168 | try: 169 | chat_messages, model_params = await self._prepare_text_request(request) 170 | request_data = { 171 | "messages": chat_messages, 172 | "stream": True, 173 | **model_params 174 | } 175 | response_generator = await self.request_queue.submit(request_id, request_data) 176 | 177 | for chunk in response_generator: 178 | if chunk: 179 | yield chunk.text 180 | 181 | except asyncio.QueueFull: 182 | logger.error("Too many requests. Service is at capacity.") 183 | content = create_error_response("Too many requests. Service is at capacity.", "rate_limit_exceeded", HTTPStatus.TOO_MANY_REQUESTS) 184 | raise HTTPException(status_code=429, detail=content) 185 | except Exception as e: 186 | logger.error(f"Error in text stream generation for request {request_id}: {str(e)}") 187 | content = create_error_response(f"Failed to generate text stream: {str(e)}", "server_error", HTTPStatus.INTERNAL_SERVER_ERROR) 188 | raise HTTPException(status_code=500, detail=content) 189 | 190 | async def generate_text_response(self, request: ChatCompletionRequest): 191 | """ 192 | Generate a complete response for text-only chat completion requests. 193 | Uses the request queue for handling concurrent requests. 194 | 195 | Args: 196 | request: ChatCompletionRequest object containing the messages. 197 | 198 | Returns: 199 | str: Complete response. 200 | """ 201 | try: 202 | # Create a unique request ID 203 | request_id = f"text-{uuid.uuid4()}" 204 | 205 | # Prepare the text request 206 | chat_messages, model_params = await self._prepare_text_request(request) 207 | 208 | # Create a request data object 209 | request_data = { 210 | "messages": chat_messages, 211 | "stream": False, 212 | **model_params 213 | } 214 | 215 | # Submit to the vision queue (reusing the same queue for text requests) 216 | response = await self.request_queue.submit(request_id, request_data) 217 | return response 218 | 219 | except asyncio.QueueFull: 220 | logger.error("Too many requests. 
Service is at capacity.") 221 | content = create_error_response("Too many requests. Service is at capacity.", "rate_limit_exceeded", HTTPStatus.TOO_MANY_REQUESTS) 222 | raise HTTPException(status_code=429, detail=content) 223 | except Exception as e: 224 | logger.error(f"Error in text response generation: {str(e)}") 225 | content = create_error_response(f"Failed to generate text response: {str(e)}", "server_error", HTTPStatus.INTERNAL_SERVER_ERROR) 226 | raise HTTPException(status_code=500, detail=content) 227 | 228 | async def generate_embeddings_response(self, request: EmbeddingRequest): 229 | """ 230 | Generate embeddings for a given text input. 231 | 232 | Args: 233 | request: EmbeddingRequest object containing the text input. 234 | 235 | Returns: 236 | List[float]: Embeddings for the input text or images 237 | """ 238 | try: 239 | # Create a unique request ID 240 | image_url = request.image_url 241 | # Process the image URL to get a local file path 242 | images = [] 243 | if request.image_url: 244 | image_path = await self.image_processor.process_image_url(image_url) 245 | images.append(image_path) 246 | request_id = f"embeddings-{uuid.uuid4()}" 247 | request_data = { 248 | "type": "embeddings", 249 | "input": request.input, 250 | "model": request.model, 251 | "images": images 252 | } 253 | 254 | # Submit to the request queue 255 | response = await self.request_queue.submit(request_id, request_data) 256 | 257 | return response 258 | 259 | except Exception as e: 260 | logger.error(f"Error in embeddings generation: {str(e)}") 261 | content = create_error_response(f"Failed to generate embeddings: {str(e)}", "server_error", HTTPStatus.INTERNAL_SERVER_ERROR) 262 | raise HTTPException(status_code=500, detail=content) 263 | 264 | 265 | def __del__(self): 266 | """Cleanup resources on deletion.""" 267 | # Removed async cleanup from __del__; use close() instead 268 | pass 269 | 270 | async def close(self): 271 | """Explicitly cleanup resources asynchronously.""" 272 | if hasattr(self, 'image_processor'): 273 | await self.image_processor.cleanup() 274 | 275 | async def cleanup(self): 276 | """ 277 | Cleanup resources and stop the request queue before shutdown. 278 | 279 | This method ensures all pending requests are properly cancelled 280 | and resources are released, including the image processor. 281 | """ 282 | try: 283 | logger.info("Cleaning up MLXVLMHandler resources") 284 | if hasattr(self, 'request_queue'): 285 | await self.request_queue.stop() 286 | if hasattr(self, 'image_processor'): 287 | await self.image_processor.cleanup() 288 | logger.info("MLXVLMHandler cleanup completed successfully") 289 | except Exception as e: 290 | logger.error(f"Error during MLXVLMHandler cleanup: {str(e)}") 291 | raise 292 | 293 | async def _process_request(self, request_data: Dict[str, Any]) -> str: 294 | """ 295 | Process a vision request. This is the worker function for the request queue. 296 | 297 | Args: 298 | request_data: Dictionary containing the request data. 299 | 300 | Returns: 301 | str: The model's response. 
302 | """ 303 | try: 304 | # Check if the request is for embeddings 305 | if request_data.get("type") == "embeddings": 306 | return self.model.get_embeddings(request_data["input"], request_data["images"]) 307 | 308 | # Extract request parameters 309 | images = request_data.get("images", []) 310 | messages = request_data.get("messages", []) 311 | stream = request_data.get("stream", False) 312 | 313 | # Remove these keys from model_params 314 | model_params = request_data.copy() 315 | model_params.pop("images", None) 316 | model_params.pop("messages", None) 317 | model_params.pop("stream", None) 318 | 319 | # Start timing 320 | start_time = time.time() 321 | 322 | # Call the model 323 | response = self.model( 324 | images=images, 325 | messages=messages, 326 | stream=stream, 327 | **model_params 328 | ) 329 | 330 | return response 331 | 332 | except Exception as e: 333 | logger.error(f"Error processing vision request: {str(e)}") 334 | raise 335 | 336 | async def get_queue_stats(self) -> Dict[str, Any]: 337 | """ 338 | Get statistics from the request queue and performance metrics. 339 | 340 | Returns: 341 | Dict with queue and performance statistics. 342 | """ 343 | queue_stats = self.request_queue.get_queue_stats() 344 | 345 | return { 346 | "queue_stats": queue_stats, 347 | } 348 | 349 | async def _prepare_text_request(self, request: ChatCompletionRequest) -> Tuple[List[Dict[str, str]], Dict[str, Any]]: 350 | """ 351 | Prepare a text request by parsing model parameters and verifying the format of messages. 352 | 353 | Args: 354 | request: ChatCompletionRequest object containing the messages. 355 | 356 | Returns: 357 | Tuple containing the formatted chat messages and model parameters. 358 | """ 359 | chat_messages = [] 360 | 361 | try: 362 | 363 | # Convert Message objects to dictionaries with 'role' and 'content' keys 364 | chat_messages = [] 365 | for message in request.messages: 366 | # Only handle simple string content for text-only requests 367 | if not isinstance(message.content, str): 368 | logger.warning(f"Non-string content in text request will be skipped: {message.role}") 369 | continue 370 | 371 | chat_messages.append({ 372 | "role": message.role, 373 | "content": message.content 374 | }) 375 | 376 | # Extract model parameters, filtering out None values 377 | model_params = { 378 | k: v for k, v in { 379 | "max_tokens": request.max_tokens, 380 | "temperature": request.temperature, 381 | "top_p": request.top_p, 382 | "frequency_penalty": request.frequency_penalty, 383 | "presence_penalty": request.presence_penalty, 384 | "stop": request.stop, 385 | "n": request.n, 386 | "seed": request.seed 387 | }.items() if v is not None 388 | } 389 | 390 | # Handle response format 391 | if request.response_format and request.response_format.get("type") == "json_object": 392 | model_params["response_format"] = "json" 393 | 394 | # Handle tools and tool choice 395 | if request.tools: 396 | model_params["tools"] = request.tools 397 | if request.tool_choice: 398 | model_params["tool_choice"] = request.tool_choice 399 | 400 | # Log processed data 401 | logger.debug(f"Processed text chat messages: {chat_messages}") 402 | logger.debug(f"Model parameters: {model_params}") 403 | 404 | return chat_messages, model_params 405 | 406 | except HTTPException: 407 | raise 408 | except Exception as e: 409 | logger.error(f"Failed to prepare text request: {str(e)}") 410 | content = create_error_response(f"Failed to process request: {str(e)}", "bad_request", HTTPStatus.BAD_REQUEST) 411 | raise 
HTTPException(status_code=400, detail=content) 412 | 413 | async def _prepare_vision_request(self, request: ChatCompletionRequest) -> Tuple[List[Dict[str, Any]], List[str], Dict[str, Any]]: 414 | """ 415 | Prepare the vision request by processing messages and images. 416 | 417 | This method: 418 | 1. Extracts text messages and image URLs from the request 419 | 2. Processes image URLs to get local file paths 420 | 3. Prepares model parameters 421 | 4. Returns processed data ready for model inference 422 | 423 | Args: 424 | request (ChatCompletionRequest): The incoming request containing messages and parameters. 425 | 426 | Returns: 427 | Tuple[List[Dict[str, Any]], List[str], Dict[str, Any]]: A tuple containing: 428 | - List of processed chat messages 429 | - List of processed image paths 430 | - Dictionary of model parameters 431 | """ 432 | chat_messages = [] 433 | image_urls = [] 434 | 435 | try: 436 | # Process each message in the request 437 | for message in request.messages: 438 | # Handle system and assistant messages (simple text content) 439 | if message.role in ["system", "assistant"]: 440 | chat_messages.append({"role": message.role, "content": message.content}) 441 | continue 442 | 443 | # Handle user messages 444 | if message.role == "user": 445 | # Case 1: Simple string content 446 | if isinstance(message.content, str): 447 | chat_messages.append({"role": "user", "content": message.content}) 448 | continue 449 | 450 | # Case 2: Content is a list of dictionaries or objects 451 | if isinstance(message.content, list): 452 | # Initialize containers for this message 453 | texts = [] 454 | images = [] 455 | # Process each content item in the list 456 | for item in message.content: 457 | if item.type == "text": 458 | text = getattr(item, "text", "").strip() 459 | if text: 460 | texts.append(text) 461 | 462 | elif item.type == "image_url": 463 | url = getattr(item, "image_url", None) 464 | if url and hasattr(url, "url"): 465 | url = url.url 466 | # Validate URL 467 | self._validate_image_url(url) 468 | images.append(url) 469 | 470 | # Add collected images to global list 471 | if images: 472 | image_urls.extend(images) 473 | 474 | # Validate constraints 475 | if len(images) > 4: 476 | content = create_error_response("Too many images in a single message (max: 4)", "invalid_request_error", HTTPStatus.BAD_REQUEST) 477 | raise HTTPException(status_code=400, detail=content) 478 | 479 | # Add text content if available, otherwise raise an error 480 | if texts: 481 | chat_messages.append({"role": "user", "content": " ".join(texts)}) 482 | else: 483 | chat_messages.append({"role": "user", "content": ""}) 484 | else: 485 | content = create_error_response("Invalid message content format", "invalid_request_error", HTTPStatus.BAD_REQUEST) 486 | raise HTTPException(status_code=400, detail=content) 487 | 488 | # Process images and prepare model parameters 489 | image_paths = await self.image_processor.process_image_urls(image_urls) 490 | 491 | 492 | # Get model parameters from the request 493 | temperature = request.temperature or 0.7 494 | top_p = request.top_p or 1.0 495 | frequency_penalty = request.frequency_penalty or 0.0 496 | presence_penalty = request.presence_penalty or 0.0 497 | max_tokens = request.max_tokens or 1024 498 | tools = request.tools or None 499 | tool_choice = request.tool_choice or None 500 | 501 | model_params = { 502 | "temperature": temperature, 503 | "top_p": top_p, 504 | "frequency_penalty": frequency_penalty, 505 | "presence_penalty": presence_penalty, 506 | 
"max_tokens": max_tokens, 507 | "tools": tools, 508 | "tool_choice": tool_choice 509 | } 510 | 511 | # Log processed data at debug level 512 | logger.debug(f"Processed chat messages: {chat_messages}") 513 | logger.debug(f"Processed image paths: {image_paths}") 514 | logger.debug(f"Model parameters: {model_params}") 515 | 516 | return chat_messages, image_paths, model_params 517 | 518 | except HTTPException: 519 | raise 520 | except Exception as e: 521 | logger.error(f"Failed to prepare vision request: {str(e)}") 522 | content = create_error_response(f"Failed to process request: {str(e)}", "bad_request", HTTPStatus.BAD_REQUEST) 523 | raise HTTPException(status_code=400, detail=content) 524 | 525 | def _validate_image_url(self, url: str) -> None: 526 | """ 527 | Validate image URL format. 528 | 529 | Args: 530 | url: The image URL to validate 531 | 532 | Raises: 533 | HTTPException: If URL is invalid 534 | """ 535 | if not url: 536 | content = create_error_response("Empty image URL provided", "invalid_request_error", HTTPStatus.BAD_REQUEST) 537 | raise HTTPException(status_code=400, detail=content) 538 | 539 | # Validate base64 images 540 | if url.startswith("data:"): 541 | try: 542 | header, encoded = url.split(",", 1) 543 | if not header.startswith("data:image/"): 544 | raise ValueError("Invalid image format") 545 | base64.b64decode(encoded) 546 | except Exception as e: 547 | content = create_error_response(f"Invalid base64 image: {str(e)}", "invalid_request_error", HTTPStatus.BAD_REQUEST) 548 | raise HTTPException(status_code=400, detail=content) 549 | -------------------------------------------------------------------------------- /app/handler/parser/__init__.py: -------------------------------------------------------------------------------- 1 | from app.handler.parser.base import BaseToolParser, BaseThinkingParser 2 | from app.handler.parser.qwen3 import Qwen3ToolParser, Qwen3ThinkingParser 3 | from typing import Tuple 4 | __all__ = ['BaseToolParser', 'BaseThinkingParser', 'Qwen3ToolParser', 'Qwen3ThinkingParser'] 5 | 6 | parser_map = { 7 | 'qwen3': { 8 | "tool_parser": Qwen3ToolParser, 9 | "thinking_parser": Qwen3ThinkingParser 10 | } 11 | } 12 | 13 | def get_parser(model_name: str) -> Tuple[BaseToolParser, BaseThinkingParser]: 14 | if model_name not in parser_map: 15 | return None, None 16 | 17 | model_parsers = parser_map[model_name] 18 | tool_parser = model_parsers.get("tool_parser") 19 | thinking_parser = model_parsers.get("thinking_parser") 20 | 21 | return (tool_parser() if tool_parser else None, 22 | thinking_parser() if thinking_parser else None) -------------------------------------------------------------------------------- /app/handler/parser/base.py: -------------------------------------------------------------------------------- 1 | import json 2 | import uuid 3 | from typing import Any, Dict, List, Tuple 4 | 5 | 6 | class BaseThinkingParser: 7 | def __init__(self, thinking_open: str, thinking_close: str): 8 | self.thinking_open = thinking_open 9 | self.thinking_close = thinking_close 10 | self.is_thinking = False 11 | 12 | def parse(self, content: str) -> str: 13 | if self.thinking_open in content: 14 | start_thinking = content.find(self.thinking_open) 15 | end_thinking = content.find(self.thinking_close) 16 | if end_thinking != -1: 17 | return content[start_thinking + len(self.thinking_open):end_thinking].strip(), content[end_thinking + len(self.thinking_close):].strip() 18 | return None, content 19 | 20 | def parse_stream(self, chunk: str) -> Tuple[str, bool]: 21 | if 
not self.is_thinking: 22 | if chunk == self.thinking_open: 23 | self.is_thinking = True 24 | return None, False 25 | return chunk, False 26 | if chunk == self.thinking_close: 27 | self.is_thinking = False 28 | return None, True 29 | 30 | return { 31 | "reasoning_content": chunk 32 | }, False 33 | 34 | class ParseState: 35 | NORMAL = 0 36 | FOUND_PREFIX = 1 37 | FOUND_FUNC_NAME = 2 38 | FOUND_FUNC_ARGS = 3 39 | PROCESS_FUNC_ARGS = 4 40 | @staticmethod 41 | def next_state(state): 42 | return (state + 1) % 5 43 | 44 | class BaseToolParser: 45 | def __init__(self, tool_open: str, tool_close: str): 46 | self.tool_open = tool_open 47 | self.tool_close = tool_close 48 | self.buffer = "" 49 | self.state = ParseState.NORMAL 50 | 51 | def get_tool_open(self): 52 | return self.tool_open 53 | 54 | def get_tool_close(self): 55 | return self.tool_close 56 | 57 | def parse(self, content: str) -> Tuple[List[Dict[str, Any]], str]: 58 | res = [] 59 | start = 0 60 | while True: 61 | start_tool = content.find(self.tool_open, start) 62 | if start_tool == -1: 63 | break 64 | end_tool = content.find(self.tool_close, start_tool + len(self.tool_open)) 65 | if end_tool == -1: 66 | break 67 | tool_content = content[start_tool + len(self.tool_open):end_tool].strip() 68 | 69 | try: 70 | json_output = json.loads(tool_content) 71 | res.append(json_output) 72 | except json.JSONDecodeError: 73 | print("Error parsing tool call: ", tool_content) 74 | break 75 | start = end_tool + len(self.tool_close) 76 | return res, content[start:].strip() 77 | 78 | def parse_stream(self, chunk: str): 79 | if self.state == ParseState.NORMAL: 80 | if chunk.strip() == self.tool_open: 81 | self.state = ParseState.next_state(self.state) 82 | self.buffer = "" 83 | self.current_func = None 84 | return None 85 | return chunk 86 | 87 | if self.state == ParseState.FOUND_PREFIX: 88 | self.buffer += chunk 89 | # Try to parse function name 90 | if self.buffer.count('"') >= 4: 91 | # try parse json 92 | try: 93 | json_output = json.loads(self.buffer.rstrip(',') + "}") 94 | self.current_func = { 95 | "name": None 96 | } 97 | self.state = ParseState.next_state(self.state) 98 | return { 99 | "name": json_output["name"], 100 | "arguments": "" 101 | } 102 | except json.JSONDecodeError: 103 | return None 104 | return None 105 | 106 | if self.state == ParseState.FOUND_FUNC_NAME: 107 | # Try to parse function arguments 108 | if chunk.strip() == "arguments": 109 | self.state = ParseState.next_state(self.state) 110 | return None 111 | return None 112 | 113 | if self.state == ParseState.FOUND_FUNC_ARGS: 114 | if ":" in chunk: 115 | chunk = chunk[:chunk.find(":") + 1: ].lstrip() 116 | self.state = ParseState.next_state(self.state) 117 | if not chunk: 118 | return None 119 | return None 120 | 121 | if '}\n' in chunk: 122 | chunk = chunk[:chunk.find('}\n')] 123 | 124 | if chunk == self.tool_close: 125 | # end of arguments 126 | # reset 127 | self.state = ParseState.NORMAL 128 | self.buffer = "" 129 | self.current_func = None 130 | return None 131 | 132 | return { 133 | "name": None, 134 | "arguments": chunk 135 | } 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | -------------------------------------------------------------------------------- /app/handler/parser/qwen3.py: -------------------------------------------------------------------------------- 1 | from app.handler.parser.base import BaseToolParser, BaseThinkingParser 2 | 3 | TOOL_OPEN = "<tool_call>" 4 | TOOL_CLOSE = "</tool_call>" 5 | THINKING_OPEN = "<think>" 6 | THINKING_CLOSE = "</think>" 7 | 8 | class 
Qwen3ToolParser(BaseToolParser): 9 | """Parser for Qwen3 model's tool response format.""" 10 | 11 | def __init__(self): 12 | super().__init__( 13 | tool_open=TOOL_OPEN, 14 | tool_close=TOOL_CLOSE 15 | ) 16 | 17 | class Qwen3ThinkingParser(BaseThinkingParser): 18 | """Parser for Qwen3 model's thinking response format.""" 19 | 20 | def __init__(self): 21 | super().__init__( 22 | thinking_open=THINKING_OPEN, 23 | thinking_close=THINKING_CLOSE 24 | ) -------------------------------------------------------------------------------- /app/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import asyncio 3 | import gc 4 | import time 5 | from contextlib import asynccontextmanager 6 | 7 | import uvicorn 8 | from fastapi import FastAPI, Request 9 | from fastapi.middleware.cors import CORSMiddleware 10 | from fastapi.responses import JSONResponse 11 | from loguru import logger 12 | 13 | from app.handler.mlx_vlm import MLXVLMHandler 14 | from app.handler.mlx_lm import MLXLMHandler 15 | from app.api.endpoints import router 16 | from app.version import __version__ 17 | 18 | # Configure loguru 19 | logger.remove() # Remove default handler 20 | logger.add( 21 | "logs/app.log", 22 | rotation="500 MB", 23 | retention="10 days", 24 | level="INFO", 25 | format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}" 26 | ) 27 | logger.add(lambda msg: print(msg), level="INFO") # Also print to console 28 | 29 | def parse_args(): 30 | parser = argparse.ArgumentParser(description="OAI-compatible proxy") 31 | parser.add_argument("--model-path", type=str, required=True, help="Huggingface model repo or local path") 32 | parser.add_argument("--model-type", type=str, default="lm", choices=["lm", "vlm"], help="Model type") 33 | parser.add_argument("--port", type=int, default=8000, help="Port to run the server on") 34 | parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on") 35 | parser.add_argument("--max-concurrency", type=int, default=1, help="Maximum number of concurrent requests") 36 | parser.add_argument("--queue-timeout", type=int, default=300, help="Request timeout in seconds") 37 | parser.add_argument("--queue-size", type=int, default=100, help="Maximum queue size for pending requests") 38 | return parser.parse_args() 39 | 40 | 41 | def create_lifespan(config_args): 42 | """Factory function to create a lifespan context manager with access to config args.""" 43 | @asynccontextmanager 44 | async def lifespan(app: FastAPI): 45 | try: 46 | logger.info(f"Initializing MLX handler with model path: {config_args.model_path}") 47 | if config_args.model_type == "vlm": 48 | handler = MLXVLMHandler( 49 | model_path=config_args.model_path, 50 | max_concurrency=config_args.max_concurrency 51 | ) 52 | else: 53 | handler = MLXLMHandler( 54 | model_path=config_args.model_path, 55 | max_concurrency=config_args.max_concurrency 56 | ) 57 | # Initialize queue 58 | await handler.initialize({ 59 | "max_concurrency": config_args.max_concurrency, 60 | "timeout": config_args.queue_timeout, 61 | "queue_size": config_args.queue_size 62 | }) 63 | logger.info("MLX handler initialized successfully") 64 | app.state.handler = handler 65 | except Exception as e: 66 | logger.error(f"Failed to initialize MLX handler: {str(e)}") 67 | raise 68 | gc.collect() 69 | yield 70 | # Shutdown 71 | logger.info("Shutting down application") 72 | if hasattr(app.state, "handler") and app.state.handler: 73 | try: 74 | # Use the proper cleanup method which handles both 
request queue and image processor 75 | logger.info("Cleaning up resources") 76 | await app.state.handler.cleanup() 77 | logger.info("Resources cleaned up successfully") 78 | except Exception as e: 79 | logger.error(f"Error during shutdown: {str(e)}") 80 | 81 | return lifespan 82 | 83 | # App instance will be created during setup with the correct lifespan 84 | app = None 85 | 86 | async def setup_server(args) -> uvicorn.Config: 87 | global app 88 | 89 | # Create FastAPI app with the configured lifespan 90 | app = FastAPI( 91 | title="OpenAI-compatible API", 92 | description="API for OpenAI-compatible chat completion and text embedding", 93 | version=__version__, 94 | lifespan=create_lifespan(args) 95 | ) 96 | 97 | app.include_router(router) 98 | 99 | # Add CORS middleware 100 | app.add_middleware( 101 | CORSMiddleware, 102 | allow_origins=["*"], # In production, replace with specific origins 103 | allow_credentials=True, 104 | allow_methods=["*"], 105 | allow_headers=["*"], 106 | ) 107 | 108 | @app.middleware("http") 109 | async def add_process_time_header(request: Request, call_next): 110 | start_time = time.time() 111 | response = await call_next(request) 112 | process_time = time.time() - start_time 113 | response.headers["X-Process-Time"] = str(process_time) 114 | return response 115 | 116 | @app.exception_handler(Exception) 117 | async def global_exception_handler(request: Request, exc: Exception): 118 | logger.error(f"Global exception handler caught: {str(exc)}", exc_info=True) 119 | return JSONResponse( 120 | status_code=500, 121 | content={"error": {"message": "Internal server error", "type": "internal_error"}} 122 | ) 123 | 124 | logger.info(f"Starting server on {args.host}:{args.port}") 125 | config = uvicorn.Config( 126 | app=app, 127 | host=args.host, 128 | port=args.port, 129 | log_level="info", 130 | access_log=True 131 | ) 132 | return config 133 | 134 | if __name__ == "__main__": 135 | args = parse_args() 136 | config = asyncio.run(setup_server(args)) 137 | uvicorn.Server(config).run() -------------------------------------------------------------------------------- /app/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cubist38/mlx-openai-server/1c0b466e28a61471b6bb8ed8fe21deb9db7a76a6/app/models/__init__.py -------------------------------------------------------------------------------- /app/models/mlx_lm.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | from mlx_lm.utils import load 3 | from mlx_lm.generate import ( 4 | generate, 5 | stream_generate, 6 | ) 7 | from mlx_lm.sample_utils import make_sampler 8 | from typing import List, Dict, Union, Generator, Optional, Tuple 9 | 10 | DEFAULT_TEMPERATURE = 0.7 11 | DEFAULT_TOP_P = 0.95 12 | DEFAULT_TOP_K = 20 13 | DEFAULT_MIN_P = 0.0 14 | DEFAULT_SEED = 0 15 | DEFAULT_MAX_TOKENS = 512 16 | DEFAULT_BATCH_SIZE = 32 17 | 18 | class MLX_LM: 19 | """ 20 | A wrapper class for MLX Language Model that handles both streaming and non-streaming inference. 21 | 22 | This class provides a unified interface for generating text responses from text prompts, 23 | supporting both streaming and non-streaming modes. 
24 | """ 25 | 26 | def __init__(self, model_path: str): 27 | try: 28 | self.model, self.tokenizer = load(model_path) 29 | self.pad_token_id = self.tokenizer.pad_token_id 30 | self.bos_token = self.tokenizer.bos_token 31 | self.model_type = self.model.model_type 32 | except Exception as e: 33 | raise ValueError(f"Error loading model: {str(e)}") 34 | 35 | def _apply_pooling_strategy(self, embeddings: mx.array) -> mx.array: 36 | embeddings = mx.mean(embeddings, axis=1) 37 | return embeddings 38 | 39 | def _apply_l2_normalization(self, embeddings: mx.array) -> mx.array: 40 | l2_norms = mx.linalg.norm(embeddings, axis=1, keepdims=True) 41 | embeddings = embeddings / (l2_norms + 1e-8) 42 | return embeddings 43 | 44 | def _batch_process(self, prompts: List[str], batch_size: int = DEFAULT_BATCH_SIZE) -> List[List[int]]: 45 | """Process prompts in batches with optimized tokenization.""" 46 | all_tokenized = [] 47 | 48 | # Process prompts in batches 49 | for i in range(0, len(prompts), batch_size): 50 | batch = prompts[i:i + batch_size] 51 | tokenized_batch = [] 52 | 53 | # Tokenize all prompts in batch 54 | for p in batch: 55 | add_special_tokens = self.bos_token is None or not p.startswith(self.bos_token) 56 | tokens = self.tokenizer.encode(p, add_special_tokens=add_special_tokens) 57 | tokenized_batch.append(tokens) 58 | 59 | # Find max length in batch 60 | max_length = max(len(tokens) for tokens in tokenized_batch) 61 | 62 | # Pad tokens in a vectorized way 63 | for tokens in tokenized_batch: 64 | padding = [self.pad_token_id] * (max_length - len(tokens)) 65 | all_tokenized.append(tokens + padding) 66 | 67 | return all_tokenized 68 | 69 | def _preprocess_prompt(self, prompt: str) -> List[int]: 70 | """Tokenize a single prompt efficiently.""" 71 | add_special_tokens = self.bos_token is None or not prompt.startswith(self.bos_token) 72 | tokens = self.tokenizer.encode(prompt, add_special_tokens=add_special_tokens) 73 | return mx.array(tokens) 74 | 75 | def get_model_type(self) -> str: 76 | return self.model_type 77 | 78 | def get_embeddings( 79 | self, 80 | prompts: List[str], 81 | batch_size: int = DEFAULT_BATCH_SIZE, 82 | normalize: bool = True 83 | ) -> List[float]: 84 | """ 85 | Get embeddings for a list of prompts efficiently. 86 | 87 | Args: 88 | prompts: List of text prompts 89 | batch_size: Size of batches for processing 90 | 91 | Returns: 92 | List of embeddings as float arrays 93 | """ 94 | # Process in batches to optimize memory usage 95 | all_embeddings = [] 96 | for i in range(0, len(prompts), batch_size): 97 | batch_prompts = prompts[i:i + batch_size] 98 | tokenized_batch = self._batch_process(batch_prompts, batch_size) 99 | 100 | # Convert to MLX array for efficient computation 101 | tokenized_batch = mx.array(tokenized_batch) 102 | 103 | # Compute embeddings for batch 104 | batch_embeddings = self.model.model(tokenized_batch) 105 | pooled_embedding = self._apply_pooling_strategy(batch_embeddings) 106 | if normalize: 107 | pooled_embedding = self._apply_l2_normalization(pooled_embedding) 108 | all_embeddings.extend(pooled_embedding.tolist()) 109 | 110 | return all_embeddings 111 | 112 | def __call__( 113 | self, 114 | messages: List[Dict[str, str]], 115 | stream: bool = False, 116 | **kwargs 117 | ) -> Union[str, Generator[str, None, None]]: 118 | """ 119 | Generate text response from the model. 120 | 121 | Args: 122 | messages (List[Dict[str, str]]): List of messages in the conversation. 123 | stream (bool): Whether to stream the response. 
124 | **kwargs: Additional parameters for generation 125 | - temperature: Sampling temperature (default: 0.7) 126 | - top_p: Top-p sampling parameter (default: 0.95) 127 | - seed: Random seed (default: 0) 128 | - max_tokens: Maximum number of tokens to generate (default: 512) 129 | """ 130 | # Set default parameters if not provided 131 | seed = kwargs.get("seed", DEFAULT_SEED) 132 | max_tokens = kwargs.get("max_tokens", DEFAULT_MAX_TOKENS) 133 | chat_template_kwargs = kwargs.get("chat_template_kwargs", {}) 134 | 135 | sampler_kwargs = { 136 | "temp": kwargs.get("temperature", DEFAULT_TEMPERATURE), 137 | "top_p": kwargs.get("top_p", DEFAULT_TOP_P), 138 | "top_k": kwargs.get("top_k", DEFAULT_TOP_K), 139 | "min_p": kwargs.get("min_p", DEFAULT_MIN_P) 140 | } 141 | 142 | mx.random.seed(seed) 143 | 144 | # Prepare input tokens 145 | prompt = self.tokenizer.apply_chat_template( 146 | messages, 147 | **chat_template_kwargs 148 | ) 149 | 150 | sampler = make_sampler( 151 | **sampler_kwargs 152 | ) 153 | 154 | if not stream: 155 | return generate( 156 | self.model, 157 | self.tokenizer, 158 | prompt, 159 | sampler=sampler, 160 | max_tokens=max_tokens 161 | ) 162 | else: 163 | # Streaming mode: return generator of chunks 164 | return stream_generate( 165 | self.model, 166 | self.tokenizer, 167 | prompt, 168 | sampler=sampler, 169 | max_tokens=max_tokens 170 | ) -------------------------------------------------------------------------------- /app/models/mlx_vlm.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Union, Generator, Optional 2 | from mlx_vlm import load 3 | from mlx_vlm.prompt_utils import apply_chat_template 4 | from mlx_vlm.utils import load_config, generate, stream_generate, prepare_inputs 5 | import mlx.core as mx 6 | 7 | 8 | # Default model parameters 9 | DEFAULT_MAX_TOKENS = 256 10 | DEFAULT_TEMPERATURE = 0.0 11 | DEFAULT_TOP_P = 1.0 12 | DEFAULT_SEED = 0 13 | 14 | class MLX_VLM: 15 | """ 16 | A wrapper class for MLX Vision Language Model that handles both streaming and non-streaming inference. 17 | 18 | This class provides a unified interface for generating text responses from images and text prompts, 19 | supporting both streaming and non-streaming modes. 20 | """ 21 | 22 | def __init__(self, model_path: str): 23 | """ 24 | Initialize the MLX_VLM model. 25 | 26 | Args: 27 | model_path (str): Path to the model directory containing model weights and configuration. 28 | 29 | Raises: 30 | ValueError: If model loading fails. 31 | """ 32 | try: 33 | self.model, self.processor = load(model_path, lazy=False, trust_remote_code=True) 34 | self.config = load_config(model_path, trust_remote_code=True) 35 | except Exception as e: 36 | raise ValueError(f"Error loading model: {str(e)}") 37 | 38 | def __call__( 39 | self, 40 | messages: List[Dict[str, str]], 41 | images: List[str] = None, 42 | stream: bool = False, 43 | **kwargs 44 | ) -> Union[str, Generator[str, None, None]]: 45 | """ 46 | Generate text response from images and messages. 47 | 48 | Args: 49 | images (List[str]): List of image paths to process. 50 | messages (List[Dict[str, str]]): List of message dictionaries with 'role' and 'content' keys. 51 | stream (bool, optional): Whether to stream the response. Defaults to False. 52 | **kwargs: Additional model parameters (temperature, max_tokens, etc.) 
53 | 54 | Returns: 55 | Union[str, Generator[str, None, None]]: 56 | - If stream=False: Complete response as string 57 | - If stream=True: Generator yielding response chunks 58 | """ 59 | # Prepare the prompt using the chat template 60 | prompt = apply_chat_template( 61 | self.processor, 62 | self.config, 63 | messages, 64 | add_generation_prompt=True, 65 | num_images=len(images) if images else 0 66 | ) 67 | # Set default parameters if not provided 68 | model_params = { 69 | "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE), 70 | "max_tokens": kwargs.get("max_tokens", DEFAULT_MAX_TOKENS), 71 | **kwargs 72 | } 73 | 74 | if not stream: 75 | # Non-streaming mode: return complete response 76 | text, _ = generate( 77 | self.model, 78 | self.processor, 79 | prompt, 80 | image=images, 81 | **model_params 82 | ) 83 | return text 84 | else: 85 | # Streaming mode: return generator of chunks 86 | return stream_generate( 87 | self.model, 88 | self.processor, 89 | prompt, 90 | images, 91 | **model_params 92 | ) 93 | 94 | def get_embeddings( 95 | self, 96 | prompts: List[str], 97 | images: Optional[List[str]] = None, 98 | batch_size: int = 1, 99 | normalize: bool = True 100 | ) -> List[List[float]]: 101 | """ 102 | Get embeddings for a list of prompts and optional images, supporting batch processing. 103 | Args: 104 | prompts: List of text prompts 105 | images: Optional list of image paths (must be same length as prompts if provided) 106 | batch_size: Size of batches for processing 107 | normalize: Whether to apply L2 normalization to embeddings 108 | Returns: 109 | List of embeddings as float arrays 110 | """ 111 | if images is None: 112 | images = [] 113 | 114 | # Text-only batch 115 | if not images: 116 | # Batch tokenize and pad 117 | tokenized = [self.processor.tokenizer.encode(self._format_prompt(p, 0), add_special_tokens=True) for p in prompts] 118 | max_len = max(len(t) for t in tokenized) 119 | pad_id = self.processor.tokenizer.pad_token_id or self.processor.tokenizer.eos_token_id 120 | batch_input_ids = [t + [pad_id] * (max_len - len(t)) for t in tokenized] 121 | batch_input_ids = mx.array(batch_input_ids) 122 | 123 | # Run in batches 124 | all_embeddings = [] 125 | for i in range(0, len(prompts), batch_size): 126 | batch_ids = batch_input_ids[i:i+batch_size] 127 | embeddings = self.model.language_model.model(batch_ids) 128 | pooled = self._apply_pooling_strategy(embeddings) 129 | if normalize: 130 | pooled = self._apply_l2_normalization(pooled) 131 | all_embeddings.extend(pooled.tolist()) 132 | return all_embeddings 133 | 134 | # Image+prompt batch 135 | if len(images) != len(prompts): 136 | raise ValueError("If images are provided, must be same length as prompts (one image per prompt)") 137 | 138 | all_embeddings = [] 139 | for i in range(0, len(prompts), batch_size): 140 | batch_prompts = prompts[i:i+batch_size] 141 | batch_images = images[i:i+batch_size] 142 | formatted_prompts = [self._format_prompt(p, 1) for p in batch_prompts] 143 | inputs = prepare_inputs( 144 | self.processor, 145 | batch_images, 146 | formatted_prompts, 147 | getattr(self.model.config, "image_token_index", None) 148 | ) 149 | input_ids = inputs["input_ids"] 150 | pixel_values = inputs.get("pixel_values", None) 151 | image_grid_thw = inputs.get("image_grid_thw", None) 152 | inputs_embeds = self.model.get_input_embeddings(input_ids, pixel_values, image_grid_thw) 153 | embeddings = self.model.language_model.model(None, inputs_embeds=inputs_embeds) 154 | pooled = self._apply_pooling_strategy(embeddings) 155 | 
if normalize: 156 | pooled = self._apply_l2_normalization(pooled) 157 | all_embeddings.extend(pooled.tolist()) 158 | return all_embeddings 159 | 160 | def _format_prompt(self, prompt: str, n_images: int) -> str: 161 | """Format a single prompt using the chat template.""" 162 | return apply_chat_template( 163 | self.processor, 164 | self.config, 165 | prompt, 166 | add_generation_prompt=True, 167 | num_images=n_images 168 | ) 169 | 170 | def _prepare_single_input(self, formatted_prompt: str, images: List[str]) -> Dict: 171 | """Prepare inputs for a single prompt-image pair.""" 172 | return prepare_inputs( 173 | self.processor, 174 | images, 175 | formatted_prompt, 176 | getattr(self.model.config, "image_token_index", None) 177 | ) 178 | 179 | def _get_single_embedding( 180 | self, 181 | inputs: Dict, 182 | normalize: bool = True 183 | ) -> List[float]: 184 | """Get embedding for a single processed input.""" 185 | input_ids = inputs["input_ids"] 186 | pixel_values = inputs.get("pixel_values", None) 187 | 188 | # Extract additional kwargs 189 | data_kwargs = { 190 | k: v for k, v in inputs.items() 191 | if k not in ["input_ids", "pixel_values", "attention_mask"] 192 | } 193 | image_grid_thw = data_kwargs.pop("image_grid_thw", None) 194 | 195 | inputs_embeds = self.model.get_input_embeddings(input_ids, pixel_values, image_grid_thw) 196 | embeddings = self.model.language_model.model(None, inputs_embeds=inputs_embeds) 197 | 198 | # Apply pooling 199 | pooled_embedding = self._apply_pooling_strategy(embeddings) 200 | 201 | # Apply normalization if requested 202 | if normalize: 203 | pooled_embedding = self._apply_l2_normalization(pooled_embedding) 204 | 205 | return pooled_embedding.tolist() 206 | 207 | def _apply_pooling_strategy(self, embeddings: mx.array) -> mx.array: 208 | """Apply mean pooling to embeddings.""" 209 | return mx.mean(embeddings, axis=1) 210 | 211 | def _apply_l2_normalization(self, embeddings: mx.array) -> mx.array: 212 | """Apply L2 normalization to embeddings.""" 213 | l2_norms = mx.linalg.norm(embeddings, axis=1, keepdims=True) 214 | embeddings = embeddings / (l2_norms + 1e-8) 215 | return embeddings -------------------------------------------------------------------------------- /app/schemas/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /app/schemas/openai.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Union 2 | 3 | from pydantic import BaseModel, Field, validator 4 | from typing_extensions import Literal 5 | from loguru import logger 6 | 7 | 8 | # Configuration 9 | class Config: 10 | """ 11 | Configuration class holding the default model names for different types of requests. 
12 | """ 13 | TEXT_MODEL = "gpt-4-turbo" # Default model for text-based chat completions 14 | VISION_MODEL = "gpt-4-vision-preview" # Model used for vision-based requests 15 | EMBEDDING_MODEL = "text-embedding-ada-002" # Model used for generating embeddings 16 | 17 | class ErrorResponse(BaseModel): 18 | object: str = Field("error", description="The object type, always 'error'.") 19 | message: str = Field(..., description="The error message.") 20 | type: str = Field(..., description="The type of error.") 21 | param: Optional[str] = Field(None, description="The parameter related to the error, if any.") 22 | code: int = Field(..., description="The error code.") 23 | 24 | # Common models used in both streaming and non-streaming contexts 25 | class ImageUrl(BaseModel): 26 | """ 27 | Represents an image URL in a message. 28 | """ 29 | url: str = Field(..., description="The image URL.") 30 | 31 | class VisionContentItem(BaseModel): 32 | """ 33 | Represents a single content item in a message (text or image). 34 | """ 35 | type: str = Field(..., description="The type of content, e.g., 'text' or 'image_url'.") 36 | text: Optional[str] = Field(None, description="The text content, if type is 'text'.") 37 | image_url: Optional[ImageUrl] = Field(None, description="The image URL object, if type is 'image_url'.") 38 | 39 | class FunctionCall(BaseModel): 40 | """ 41 | Represents a function call in a message. 42 | """ 43 | arguments: str = Field(..., description="The arguments for the function call.") 44 | name: str = Field(..., description="The name of the function to call.") 45 | 46 | class ChatCompletionMessageToolCall(BaseModel): 47 | """ 48 | Represents a tool call in a message. 49 | """ 50 | id: str = Field(..., description="The ID of the tool call.") 51 | function: FunctionCall = Field(..., description="The function call details.") 52 | type: Literal["function"] = Field(..., description="The type of tool call, always 'function'.") 53 | index: int = Field(..., description="The index of the tool call.") 54 | 55 | class Message(BaseModel): 56 | """ 57 | Represents a message in a chat completion. 58 | """ 59 | content: Union[str, List[VisionContentItem]] = Field(None, description="The content of the message, either text or a list of vision content items.") 60 | refusal: Optional[str] = Field(None, description="The refusal reason, if any.") 61 | role: Literal["system", "user", "assistant", "tool"] = Field(..., description="The role of the message sender.") 62 | function_call: Optional[FunctionCall] = Field(None, description="The function call, if any.") 63 | reasoning_content: Optional[str] = Field(None, description="The reasoning content, if any.") 64 | tool_calls: Optional[List[ChatCompletionMessageToolCall]] = Field(None, description="List of tool calls, if any.") 65 | 66 | # Common request base for both streaming and non-streaming 67 | class ChatCompletionRequestBase(BaseModel): 68 | """ 69 | Base model for chat completion requests. 
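
    Mirrors the body of OpenAI's chat completions request so that existing OpenAI
    clients can target this server unchanged. A minimal request, shown only as an
    illustration (all other fields fall back to the defaults declared below; the
    model name is the placeholder used in the example notebooks):

        ChatCompletionRequest(
            model="mlx-server-model",
            messages=[Message(role="user", content="What is the weather in Tokyo?")],
        )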
70 | """ 71 | model: str = Field(Config.TEXT_MODEL, description="The model to use for completion.") 72 | messages: List[Message] = Field(..., description="The list of messages in the conversation.") 73 | tools: Optional[List[Dict[str, Any]]] = Field(None, description="List of tools available for the request.") 74 | tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(None, description="Tool choice for the request.") 75 | max_tokens: Optional[int] = Field(None, description="The maximum number of tokens to generate.") 76 | temperature: Optional[float] = Field(0.7, description="Sampling temperature.") 77 | top_p: Optional[float] = Field(1.0, description="Nucleus sampling probability.") 78 | frequency_penalty: Optional[float] = Field(0.0, description="Frequency penalty for token generation.") 79 | presence_penalty: Optional[float] = Field(0.0, description="Presence penalty for token generation.") 80 | stop: Optional[List[str]] = Field(None, description="List of stop sequences.") 81 | n: Optional[int] = Field(1, description="Number of completions to generate.") 82 | response_format: Optional[Dict[str, str]] = Field(None, description="Format for the response.") 83 | seed: Optional[int] = Field(None, description="Random seed for reproducibility.") 84 | user: Optional[str] = Field(None, description="User identifier.") 85 | 86 | @validator("messages") 87 | def check_messages_not_empty(cls, v): 88 | """ 89 | Ensure that the messages list is not empty and validate message structure. 90 | """ 91 | if not v: 92 | raise ValueError("messages cannot be empty") 93 | 94 | # Validate message history length 95 | if len(v) > 100: # OpenAI's limit is typically around 100 messages 96 | raise ValueError("message history too long") 97 | 98 | # Validate message roles 99 | valid_roles = {"user", "assistant", "system", "tool"} 100 | for msg in v: 101 | if msg.role not in valid_roles: 102 | raise ValueError(f"invalid role: {msg.role}") 103 | 104 | return v 105 | 106 | @validator("temperature") 107 | def check_temperature(cls, v): 108 | """ 109 | Validate temperature is between 0 and 2. 110 | """ 111 | if v is not None and (v < 0 or v > 2): 112 | raise ValueError("temperature must be between 0 and 2") 113 | return v 114 | 115 | @validator("max_tokens") 116 | def check_max_tokens(cls, v): 117 | """ 118 | Validate max_tokens is positive and within reasonable limits. 119 | """ 120 | if v is not None: 121 | if v <= 0: 122 | raise ValueError("max_tokens must be positive") 123 | if v > 4096: # Typical limit for GPT-4 124 | raise ValueError("max_tokens too high") 125 | return v 126 | 127 | def is_vision_request(self) -> bool: 128 | """ 129 | Check if the request includes image content, indicating a vision-based request. 130 | """ 131 | for message in self.messages: 132 | content = message.content 133 | if isinstance(content, list): 134 | for item in content: 135 | if hasattr(item, 'type') and item.type == "image_url": 136 | if hasattr(item, 'image_url') and item.image_url and item.image_url.url: 137 | logger.debug(f"Detected vision request with image: {item.image_url.url[:30]}...") 138 | return True 139 | 140 | logger.debug(f"No images detected, treating as text-only request") 141 | return False 142 | 143 | class ChatTemplateKwargs(BaseModel): 144 | """ 145 | Represents the arguments for a chat template. 
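
    These options are forwarded to the model's chat template when the prompt is built.
    With the OpenAI Python client they travel in extra_body, as in the example
    notebooks (client and messages as defined there):

        client.chat.completions.create(
            model="mlx-server-model",
            messages=messages,
            extra_body={"chat_template_kwargs": {"enable_thinking": True}},
        )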
146 | """ 147 | enable_thinking: bool = Field(False, description="Whether to enable thinking mode.") 148 | tools: Optional[List[Dict[str, Any]]] = Field(None, description="List of tools to use in the request.") 149 | add_generation_prompt: bool = Field(True, description="Whether to add a generation prompt to the request.") 150 | 151 | # Non-streaming request and response 152 | class ChatCompletionRequest(ChatCompletionRequestBase): 153 | """ 154 | Model for non-streaming chat completion requests. 155 | """ 156 | stream: bool = Field(False, description="Whether to stream the response.") 157 | chat_template_kwargs: ChatTemplateKwargs = Field(ChatTemplateKwargs(), description="Arguments for the chat template.") 158 | 159 | class Choice(BaseModel): 160 | """ 161 | Represents a choice in a chat completion response. 162 | """ 163 | finish_reason: Literal["stop", "length", "tool_calls", "content_filter", "function_call"] = Field(..., description="The reason for the choice.") 164 | index: int = Field(..., description="The index of the choice.") 165 | message: Message = Field(..., description="The message of the choice.") 166 | 167 | class ChatCompletionResponse(BaseModel): 168 | """ 169 | Represents a complete chat completion response. 170 | """ 171 | id: str = Field(..., description="The response ID.") 172 | object: Literal["chat.completion"] = Field(..., description="The object type, always 'chat.completion'.") 173 | created: int = Field(..., description="The creation timestamp.") 174 | model: str = Field(..., description="The model used for completion.") 175 | choices: List[Choice] = Field(..., description="List of choices in the response.") 176 | 177 | 178 | class ChoiceDeltaFunctionCall(BaseModel): 179 | """ 180 | Represents a function call delta in a streaming response. 181 | """ 182 | arguments: Optional[str] = Field(None, description="Arguments for the function call delta.") 183 | name: Optional[str] = Field(None, description="Name of the function in the delta.") 184 | 185 | class ChoiceDeltaToolCall(BaseModel): 186 | """ 187 | Represents a tool call delta in a streaming response. 188 | """ 189 | index: Optional[int] = Field(None, description="Index of the tool call delta.") 190 | id: Optional[str] = Field(None, description="ID of the tool call delta.") 191 | function: Optional[ChoiceDeltaFunctionCall] = Field(None, description="Function call details in the delta.") 192 | type: Optional[str] = Field(None, description="Type of the tool call delta.") 193 | 194 | class Delta(BaseModel): 195 | """ 196 | Represents a delta in a streaming response. 197 | """ 198 | content: Optional[str] = Field(None, description="Content of the delta.") 199 | function_call: Optional[ChoiceDeltaFunctionCall] = Field(None, description="Function call delta, if any.") 200 | refusal: Optional[str] = Field(None, description="Refusal reason, if any.") 201 | role: Optional[Literal["system", "user", "assistant", "tool"]] = Field(None, description="Role in the delta.") 202 | tool_calls: Optional[List[ChoiceDeltaToolCall]] = Field(None, description="List of tool call deltas, if any.") 203 | reasoning_content: Optional[str] = Field(None, description="Reasoning content, if any.") 204 | 205 | class StreamingChoice(BaseModel): 206 | """ 207 | Represents a choice in a streaming response. 
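
    Clients rebuild the full assistant message by concatenating delta.content across
    chunks and merging tool_calls fragments that share the same index, stopping once a
    chunk arrives with a non-null finish_reason.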
208 | """ 209 | delta: Delta = Field(..., description="The delta for this streaming choice.") 210 | finish_reason: Optional[Literal["stop", "length", "tool_calls", "content_filter", "function_call"]] = Field(None, description="The reason for finishing, if any.") 211 | index: int = Field(..., description="The index of the streaming choice.") 212 | 213 | class ChatCompletionChunk(BaseModel): 214 | """ 215 | Represents a chunk in a streaming chat completion response. 216 | """ 217 | id: str = Field(..., description="The chunk ID.") 218 | choices: List[StreamingChoice] = Field(..., description="List of streaming choices in the chunk.") 219 | created: int = Field(..., description="The creation timestamp of the chunk.") 220 | model: str = Field(..., description="The model used for the chunk.") 221 | object: Literal["chat.completion.chunk"] = Field(..., description="The object type, always 'chat.completion.chunk'.") 222 | 223 | # Embedding models 224 | class EmbeddingRequest(BaseModel): 225 | """ 226 | Model for embedding requests. 227 | """ 228 | model: str = Field(Config.EMBEDDING_MODEL, description="The embedding model to use.") 229 | input: List[str] = Field(..., description="List of text inputs for embedding.") 230 | image_url: Optional[str] = Field(default=None, description="Image URL to embed.") 231 | 232 | class Embedding(BaseModel): 233 | """ 234 | Represents an embedding object in an embedding response. 235 | """ 236 | embedding: List[float] = Field(..., description="The embedding vector.") 237 | index: int = Field(..., description="The index of the embedding in the list.") 238 | object: str = Field(default="embedding", description="The object type, always 'embedding'.") 239 | 240 | class EmbeddingResponse(BaseModel): 241 | """ 242 | Represents an embedding response. 243 | """ 244 | object: str = Field("list", description="The object type, always 'list'.") 245 | data: List[Embedding] = Field(..., description="List of embedding objects.") 246 | model: str = Field(..., description="The model used for embedding.") 247 | 248 | class Model(BaseModel): 249 | """ 250 | Represents a model in the models list response. 251 | """ 252 | id: str = Field(..., description="The model ID.") 253 | object: str = Field("model", description="The object type, always 'model'.") 254 | created: int = Field(..., description="The creation timestamp.") 255 | owned_by: str = Field("openai", description="The owner of the model.") 256 | 257 | class ModelsResponse(BaseModel): 258 | """ 259 | Represents the response for the models list endpoint. 
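
    Shaped like the response of OpenAI's model-listing endpoint (an object of type
    "list" wrapping Model entries), so client.models.list() from the OpenAI SDK is
    expected to work against this server as well.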
260 | """ 261 | object: str = Field("list", description="The object type, always 'list'.") 262 | data: List[Model] = Field(..., description="List of models.") -------------------------------------------------------------------------------- /app/utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /app/utils/errors.py: -------------------------------------------------------------------------------- 1 | from http import HTTPStatus 2 | from typing import Union 3 | 4 | from app.schemas.openai import ErrorResponse 5 | 6 | 7 | def create_error_response( 8 | message: str, 9 | err_type: str = "internal_error", 10 | status_code: Union[int, HTTPStatus] = HTTPStatus.INTERNAL_SERVER_ERROR, 11 | param: str = None, 12 | code: str = None 13 | ): 14 | return { 15 | "error": { 16 | "message": message, 17 | "type": err_type, 18 | "param": param, 19 | "code": str(code or (status_code.value if isinstance(status_code, HTTPStatus) else status_code)) 20 | } 21 | } -------------------------------------------------------------------------------- /app/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.11" -------------------------------------------------------------------------------- /configure_mlx.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Get the total memory in MB 4 | TOTAL_MEM_MB=$(($(sysctl -n hw.memsize) / 1024 / 1024)) 5 | 6 | # Calculate 80% and TOTAL_MEM_GB-5GB in MB 7 | EIGHTY_PERCENT=$(($TOTAL_MEM_MB * 80 / 100)) 8 | MINUS_5GB=$((($TOTAL_MEM_MB - 5120))) 9 | 10 | # Calculate 70% and TOTAL_MEM_GB-8GB in MB 11 | SEVENTY_PERCENT=$(($TOTAL_MEM_MB * 70 / 100)) 12 | MINUS_8GB=$((($TOTAL_MEM_MB - 8192))) 13 | 14 | # Set WIRED_LIMIT_MB to higher value 15 | if [ $EIGHTY_PERCENT -gt $MINUS_5GB ]; then 16 | WIRED_LIMIT_MB=$EIGHTY_PERCENT 17 | else 18 | WIRED_LIMIT_MB=$MINUS_5GB 19 | fi 20 | 21 | # Set WIRED_LWM_MB to higher value 22 | if [ $SEVENTY_PERCENT -gt $MINUS_8GB ]; then 23 | WIRED_LWM_MB=$SEVENTY_PERCENT 24 | else 25 | WIRED_LWM_MB=$MINUS_8GB 26 | fi 27 | 28 | # Display the calculated values 29 | echo "Total memory: $TOTAL_MEM_MB MB" 30 | echo "Maximum limit (iogpu.wired_limit_mb): $WIRED_LIMIT_MB MB" 31 | echo "Lower bound (iogpu.wired_lwm_mb): $WIRED_LWM_MB MB" 32 | 33 | # Apply the values with sysctl, but check if we're already root 34 | if [ "$EUID" -eq 0 ]; then 35 | sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB 36 | sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB 37 | else 38 | # Try without sudo first, fall back to sudo if needed 39 | sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB 2>/dev/null || \ 40 | sudo sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB 41 | sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB 2>/dev/null || \ 42 | sudo sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB 43 | fi -------------------------------------------------------------------------------- /examples/function_calling_examples.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# MLX Server Function Calling Example\n", 8 | "\n", 9 | "This is a detailed text version of the function calling example for MLX Server with OpenAI-compatible API." 
10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "## Setup" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 53, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "from openai import OpenAI" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "## Initialize the client\n", 33 | "\n", 34 | "Connect to your local MLX server:" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 54, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "client = OpenAI(\n", 44 | " base_url = \"http://localhost:8000/v1\",\n", 45 | " api_key = \"mlx-server-api-key\"\n", 46 | ")" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "## Function calling example\n", 54 | "\n", 55 | "This example demonstrates how to use function calling with the MLX server:" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 56, 61 | "metadata": {}, 62 | "outputs": [ 63 | { 64 | "name": "stdout", 65 | "output_type": "stream", 66 | "text": [ 67 | "ChatCompletion(id='chatcmpl_1748500691067459', choices=[Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content=None, refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_1748500691650837', function=Function(arguments='{\"city\": \"Tokyo\"}', name='get_weather'), type='function', index=0)], reasoning_content='Okay, the user is asking for the weather in Tokyo. Let me check the tools provided. There\\'s a function called get_weather that takes a city parameter. So I need to call that function with the city set to Tokyo. I\\'ll make sure the JSON is correctly formatted with the city name as a string. Let me double-check the parameters. The function requires \"city\" as a string, so the arguments should be {\"city\": \"Tokyo\"}. 
Alright, that\\'s all I need for the tool call.'))], created=1748500691, model='mlx-server-model', object='chat.completion', service_tier=None, system_fingerprint=None, usage=None)\n" 68 | ] 69 | } 70 | ], 71 | "source": [ 72 | "# Define the user message\n", 73 | "messages = [\n", 74 | " {\n", 75 | " \"role\": \"user\",\n", 76 | " \"content\": \"What is the weather in Tokyo?\"\n", 77 | " }\n", 78 | "]\n", 79 | "\n", 80 | "# Define the available tools/functions\n", 81 | "tools = [\n", 82 | " {\n", 83 | " \"type\": \"function\",\n", 84 | " \"function\": {\n", 85 | " \"name\": \"get_weather\",\n", 86 | " \"description\": \"Get the weather in a given city\",\n", 87 | " \"parameters\": {\n", 88 | " \"type\": \"object\",\n", 89 | " \"properties\": {\n", 90 | " \"city\": {\"type\": \"string\", \"description\": \"The city to get the weather for\"}\n", 91 | " }\n", 92 | " }\n", 93 | " }\n", 94 | " }\n", 95 | "]\n", 96 | "\n", 97 | "# Make the API call\n", 98 | "completion = client.chat.completions.create(\n", 99 | " model=\"mlx-server-model\",\n", 100 | " messages=messages,\n", 101 | " tools=tools,\n", 102 | " tool_choice=\"auto\",\n", 103 | " max_tokens = 512,\n", 104 | " extra_body = {\n", 105 | " \"chat_template_kwargs\": {\n", 106 | " \"enable_thinking\": True\n", 107 | " }\n", 108 | " }\n", 109 | ")\n", 110 | "\n", 111 | "# Get the result\n", 112 | "print(completion)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "metadata": {}, 118 | "source": [ 119 | "## Streaming version" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 57, 125 | "metadata": {}, 126 | "outputs": [ 127 | { 128 | "name": "stdout", 129 | "output_type": "stream", 130 | "text": [ 131 | "ChatCompletionChunk(id='chatcmpl_1748500691833604', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=None, reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1748500691, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n", 132 | "ChatCompletionChunk(id='chatcmpl_1748500691833604', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id='call_1748500694172265', function=ChoiceDeltaToolCallFunction(arguments='', name='get_weather'), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1748500691, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n", 133 | "ChatCompletionChunk(id='chatcmpl_1748500691833604', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments=' {\"', name=None), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1748500691, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n", 134 | "ChatCompletionChunk(id='chatcmpl_1748500691833604', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='city', name=None), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1748500691, model='mlx-server-model', 
object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n", 135 | "ChatCompletionChunk(id='chatcmpl_1748500691833604', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='\":', name=None), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1748500691, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n", 136 | "ChatCompletionChunk(id='chatcmpl_1748500691833604', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments=' \"', name=None), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1748500691, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n", 137 | "ChatCompletionChunk(id='chatcmpl_1748500691833604', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='Tok', name=None), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1748500691, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n", 138 | "ChatCompletionChunk(id='chatcmpl_1748500691833604', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='yo', name=None), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1748500691, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n", 139 | "ChatCompletionChunk(id='chatcmpl_1748500691833604', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='\"}', name=None), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1748500691, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n", 140 | "ChatCompletionChunk(id='chatcmpl_1748500691833604', choices=[Choice(delta=ChoiceDelta(content='', function_call=None, refusal=None, role='assistant', tool_calls=None, reasoning_content=None), finish_reason='tool_calls', index=0, logprobs=None)], created=1748500695, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n" 141 | ] 142 | } 143 | ], 144 | "source": [ 145 | "# Set stream=True in the API call\n", 146 | "completion = client.chat.completions.create(\n", 147 | " model=\"mlx-server-model\",\n", 148 | " messages=messages,\n", 149 | " tools=tools,\n", 150 | " tool_choice=\"auto\",\n", 151 | " stream=True,\n", 152 | " extra_body = {\n", 153 | " \"chat_template_kwargs\": {\n", 154 | " \"enable_thinking\": False\n", 155 | " }\n", 156 | " }\n", 157 | ")\n", 158 | "\n", 159 | "# Process the streaming response\n", 160 | "for chunk in completion:\n", 161 | " 
print(chunk)" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [] 170 | } 171 | ], 172 | "metadata": { 173 | "kernelspec": { 174 | "display_name": "Python 3", 175 | "language": "python", 176 | "name": "python3" 177 | }, 178 | "language_info": { 179 | "codemirror_mode": { 180 | "name": "ipython", 181 | "version": 3 182 | }, 183 | "file_extension": ".py", 184 | "mimetype": "text/x-python", 185 | "name": "python", 186 | "nbconvert_exporter": "python", 187 | "pygments_lexer": "ipython3", 188 | "version": "3.11.12" 189 | } 190 | }, 191 | "nbformat": 4, 192 | "nbformat_minor": 2 193 | } 194 | -------------------------------------------------------------------------------- /examples/images/attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cubist38/mlx-openai-server/1c0b466e28a61471b6bb8ed8fe21deb9db7a76a6/examples/images/attention.png -------------------------------------------------------------------------------- /examples/images/green_dog.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cubist38/mlx-openai-server/1c0b466e28a61471b6bb8ed8fe21deb9db7a76a6/examples/images/green_dog.jpeg -------------------------------------------------------------------------------- /examples/images/password.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cubist38/mlx-openai-server/1c0b466e28a61471b6bb8ed8fe21deb9db7a76a6/examples/images/password.jpg -------------------------------------------------------------------------------- /examples/lm_embeddings_examples.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Embeddings API Examples with MLX Server" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "This notebook demonstrates how to use the embeddings endpoint of MLX Server through the OpenAI-compatible API. You'll learn how to generate embeddings, work with batches, compare similarity between texts, and use embeddings for practical applications." 
15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "## Setup and Connection" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 1, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "# Import the OpenAI client for API communication\n", 31 | "from openai import OpenAI\n", 32 | "\n", 33 | "# Connect to the local MLX Server with OpenAI-compatible API\n", 34 | "client = OpenAI(\n", 35 | " base_url=\"http://localhost:8000/v1\",\n", 36 | " api_key=\"fake-api-key\",\n", 37 | ")" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "## Basic Embedding Generation\n", 45 | "\n", 46 | "### Single Text Embedding\n" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 2, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "# Generate embedding for a single text input\n", 56 | "single_text = \"Artificial intelligence is transforming how we interact with technology.\"\n", 57 | "response = client.embeddings.create(\n", 58 | " input=[single_text],\n", 59 | " model=\"mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-MLX-Q8\"\n", 60 | ")" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": {}, 66 | "source": [ 67 | "### Batch Processing Multiple Texts" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 3, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "text_batch = [\n", 77 | " \"Machine learning algorithms improve with more data\",\n", 78 | " \"Natural language processing helps computers understand human language\",\n", 79 | " \"Computer vision allows machines to interpret visual information\"\n", 80 | "]\n", 81 | "\n", 82 | "batch_response = client.embeddings.create(\n", 83 | " input=text_batch,\n", 84 | " model=\"mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-MLX-Q8\"\n", 85 | ")" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 4, 91 | "metadata": {}, 92 | "outputs": [ 93 | { 94 | "name": "stdout", 95 | "output_type": "stream", 96 | "text": [ 97 | "Number of embeddings generated: 3\n", 98 | "Dimensions of each embedding: 1536\n" 99 | ] 100 | } 101 | ], 102 | "source": [ 103 | "# Access all embeddings\n", 104 | "embeddings = [item.embedding for item in batch_response.data]\n", 105 | "print(f\"Number of embeddings generated: {len(embeddings)}\")\n", 106 | "print(f\"Dimensions of each embedding: {len(embeddings[0])}\")" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": {}, 112 | "source": [ 113 | "## Semantic Similarity Calculation\n", 114 | "\n", 115 | "One of the most common uses for embeddings is measuring semantic similarity between texts." 
116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 5, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "import numpy as np" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 6, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "def cosine_similarity_score(vec1, vec2):\n", 134 | " \"\"\"Calculate cosine similarity between two vectors\"\"\"\n", 135 | " dot_product = np.dot(vec1, vec2)\n", 136 | " norm1 = np.linalg.norm(vec1)\n", 137 | " norm2 = np.linalg.norm(vec2)\n", 138 | " return dot_product / (norm1 * norm2)" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 7, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "# Example texts to compare\n", 148 | "text1 = \"Dogs are loyal pets that provide companionship\"\n", 149 | "text2 = \"Canines make friendly companions for humans\"\n", 150 | "text3 = \"Quantum physics explores the behavior of matter at atomic scales\"" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 8, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "# Generate embeddings\n", 160 | "comparison_texts = [text1, text2, text3]\n", 161 | "comparison_response = client.embeddings.create(\n", 162 | " input=comparison_texts,\n", 163 | " model=\"mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-MLX-Q8\"\n", 164 | ")\n", 165 | "comparison_embeddings = [item.embedding for item in comparison_response.data]" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 9, 171 | "metadata": {}, 172 | "outputs": [ 173 | { 174 | "name": "stdout", 175 | "output_type": "stream", 176 | "text": [ 177 | "Similarity between text1 and text2: 0.8142\n", 178 | "Similarity between text1 and text3: 0.6082\n", 179 | "Similarity between text2 and text3: 0.5739\n" 180 | ] 181 | } 182 | ], 183 | "source": [ 184 | "# Compare similarities\n", 185 | "similarity_1_2 = cosine_similarity_score(comparison_embeddings[0], comparison_embeddings[1])\n", 186 | "similarity_1_3 = cosine_similarity_score(comparison_embeddings[0], comparison_embeddings[2])\n", 187 | "similarity_2_3 = cosine_similarity_score(comparison_embeddings[1], comparison_embeddings[2])\n", 188 | "\n", 189 | "print(f\"Similarity between text1 and text2: {similarity_1_2:.4f}\")\n", 190 | "print(f\"Similarity between text1 and text3: {similarity_1_3:.4f}\")\n", 191 | "print(f\"Similarity between text2 and text3: {similarity_2_3:.4f}\")" 192 | ] 193 | }, 194 | { 195 | "cell_type": "markdown", 196 | "metadata": {}, 197 | "source": [ 198 | "## Text Search Using Embeddings" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 10, 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [ 207 | "# Sample document collection\n", 208 | "documents = [\n", 209 | " \"The quick brown fox jumps over the lazy dog\",\n", 210 | " \"Machine learning models require training data\",\n", 211 | " \"Neural networks are inspired by biological neurons\",\n", 212 | " \"Deep learning is a subset of machine learning\",\n", 213 | " \"Natural language processing helps with text analysis\",\n", 214 | " \"Computer vision systems can detect objects in images\"\n", 215 | "]" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 11, 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [ 224 | "# Generate embeddings for all documents\n", 225 | "doc_response = client.embeddings.create(\n", 226 | " input=documents,\n", 227 | " 
model=\"mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-MLX-Q8\"\n", 228 | ")\n", 229 | "doc_embeddings = [item.embedding for item in doc_response.data]" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 12, 235 | "metadata": {}, 236 | "outputs": [ 237 | { 238 | "name": "stdout", 239 | "output_type": "stream", 240 | "text": [ 241 | "Search results:\n", 242 | "Score: 0.8574 - Computer vision systems can detect objects in images\n", 243 | "Score: 0.8356 - Neural networks are inspired by biological neurons\n", 244 | "Score: 0.8266 - Natural language processing helps with text analysis\n", 245 | "Score: 0.8141 - Deep learning is a subset of machine learning\n", 246 | "Score: 0.7474 - Machine learning models require training data\n", 247 | "Score: 0.5936 - The quick brown fox jumps over the lazy dog\n" 248 | ] 249 | } 250 | ], 251 | "source": [ 252 | "def search_documents(query, doc_collection, doc_embeddings):\n", 253 | " \"\"\"Search for documents similar to query\"\"\"\n", 254 | " # Generate embedding for query\n", 255 | " query_response = client.embeddings.create(\n", 256 | " input=[query],\n", 257 | " model=\"mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-MLX-Q8\"\n", 258 | " )\n", 259 | " query_embedding = query_response.data[0].embedding\n", 260 | " \n", 261 | " # Calculate similarity scores\n", 262 | " similarities = []\n", 263 | " for doc_embedding in doc_embeddings:\n", 264 | " similarity = cosine_similarity_score(query_embedding, doc_embedding)\n", 265 | " similarities.append(similarity)\n", 266 | " \n", 267 | " # Return results with scores\n", 268 | " results = []\n", 269 | " for i, score in enumerate(similarities):\n", 270 | " results.append((doc_collection[i], score))\n", 271 | " \n", 272 | " # Sort by similarity score (highest first)\n", 273 | " return sorted(results, key=lambda x: x[1], reverse=True)\n", 274 | "\n", 275 | "# Example search\n", 276 | "search_results = search_documents(\"How do AI models learn?\", documents, doc_embeddings)\n", 277 | "\n", 278 | "print(\"Search results:\")\n", 279 | "for doc, score in search_results:\n", 280 | " print(f\"Score: {score:.4f} - {doc}\")" 281 | ] 282 | } 283 | ], 284 | "metadata": { 285 | "kernelspec": { 286 | "display_name": "Python 3", 287 | "language": "python", 288 | "name": "python3" 289 | }, 290 | "language_info": { 291 | "codemirror_mode": { 292 | "name": "ipython", 293 | "version": 3 294 | }, 295 | "file_extension": ".py", 296 | "mimetype": "text/x-python", 297 | "name": "python", 298 | "nbconvert_exporter": "python", 299 | "pygments_lexer": "ipython3", 300 | "version": "3.11.12" 301 | } 302 | }, 303 | "nbformat": 4, 304 | "nbformat_minor": 2 305 | } 306 | -------------------------------------------------------------------------------- /examples/pdfs/lab03.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cubist38/mlx-openai-server/1c0b466e28a61471b6bb8ed8fe21deb9db7a76a6/examples/pdfs/lab03.pdf -------------------------------------------------------------------------------- /examples/simple_rag_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "7fe5da27", 6 | "metadata": {}, 7 | "source": [ 8 | "# PDF Question Answering with Embedding Search\n", 9 | "\n", 10 | "This notebook demonstrates how to build a simple RAG (Retrieval-Augmented Generation) system that:\n", 11 | "\n", 12 | "1. 
Extracts and chunks text from PDF documents.\n", 13 | "2. Embeds the text using a local embedding model served via MLX Server.\n", 14 | "3. Stores the embeddings in a FAISS index for fast retrieval.\n", 15 | "4. Answers user queries by retrieving relevant chunks and using a chat model to respond based on context." 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "id": "fc490ab0", 21 | "metadata": {}, 22 | "source": [ 23 | "Before running the notebook, make sure to launch the local MLX Server by executing the following command in your terminal (`lm`: Text-only model): \n", 24 | "```bash\n", 25 | "mlx-openai-server launch --model-path mlx-community/Qwen3-4B-8bit --model-type lm\n", 26 | "```\n", 27 | "This command starts the MLX API server locally at http://localhost:8000/v1, which exposes an OpenAI-compatible interface. It enables the specified model to be used for both embedding (vector representation of text) and response generation (chat completion).\n", 28 | "\n", 29 | "For this illustration, we use the model `mlx-community/Qwen3-4B-8bit`, a lightweight and efficient language model that supports both tasks. You can substitute this with any other compatible model depending on your use case and hardware capability." 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "id": "036359c3", 35 | "metadata": {}, 36 | "source": [ 37 | "## Install dependencies\n" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 15, 43 | "id": "2d641abc", 44 | "metadata": {}, 45 | "outputs": [ 46 | { 47 | "name": "stdout", 48 | "output_type": "stream", 49 | "text": [ 50 | "Note: you may need to restart the kernel to use updated packages.\n" 51 | ] 52 | } 53 | ], 54 | "source": [ 55 | "# Install required packages\n", 56 | "%pip install -Uq numpy PyMuPDF faiss-cpu openai" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "id": "ab9a6f78", 62 | "metadata": {}, 63 | "source": [ 64 | "## Initialize MLX Server client\n" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 16, 70 | "id": "546b6775", 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "# Connect to the local MLX Server that serves embedding and chat models\n", 75 | "from openai import OpenAI\n", 76 | "\n", 77 | "client = OpenAI(\n", 78 | " base_url=\"http://localhost:8000/v1\", # This is your MLX Server endpoint\n", 79 | " api_key=\"fake-api-key\" # Dummy key, not used by local server\n", 80 | ")" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "id": "c573141f", 86 | "metadata": {}, 87 | "source": [ 88 | "## Load and chunk PDF document\n", 89 | "\n", 90 | "We load the PDF file and split it into smaller chunks to ensure each chunk fits within the context window of the model." 91 | ] 92 | }, 93 | { 94 | "cell_type": "markdown", 95 | "id": "8001354f", 96 | "metadata": {}, 97 | "source": [ 98 | "### Read PDF\n", 99 | "Extracts text from each page of the PDF."
100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 17, 105 | "id": "27be75d1", 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "import fitz # PyMuPDF\n", 110 | "\n", 111 | "def read_pdf(path):\n", 112 | " \"\"\"Extract text from each page of a PDF file.\"\"\"\n", 113 | " doc = fitz.open(path)\n", 114 | " texts = []\n", 115 | " for page in doc:\n", 116 | " text = page.get_text().strip()\n", 117 | " if text: \n", 118 | " texts.append(text)\n", 119 | " doc.close()\n", 120 | " return texts" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "id": "adae0912", 126 | "metadata": {}, 127 | "source": [ 128 | "### Chunk Text\n", 129 | "Splits the extracted text into smaller chunks with overlap to preserve context.\n", 130 | "We use the following parameters:\n", 131 | "- `chunk_size=400`: Maximum number of words in a chunk.\n", 132 | "- `overlap=200`: Number of words overlapping between consecutive chunks to avoid breaking context too harshly.\n", 133 | "\n", 134 | "Each chunk is created using simple whitespace (`\" \"`) tokenization and rejoined with spaces, which works well for general text." 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 18, 140 | "id": "e093a75f", 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "def chunk_text(texts, chunk_size=400, overlap=200):\n", 145 | " \"\"\"Split text into smaller chunks with overlap for better context preservation.\"\"\"\n", 146 | " chunks = []\n", 147 | " for text in texts:\n", 148 | " words = text.split() \n", 149 | " i = 0\n", 150 | " while i < len(words):\n", 151 | " chunk = words[i:i + chunk_size]\n", 152 | " chunks.append(\" \".join(chunk))\n", 153 | " i += chunk_size - overlap\n", 154 | " return chunks" 155 | ] 156 | }, 157 | { 158 | "cell_type": "markdown", 159 | "id": "603c9e71", 160 | "metadata": {}, 161 | "source": [ 162 | "## Save embeddings and chunks to FAISS" 163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "id": "d7d76e7d", 168 | "metadata": {}, 169 | "source": [ 170 | "### Embed Chunks\n", 171 | "\n", 172 | "Uses the MLX Server to generate embeddings for the text chunks." 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 19, 178 | "id": "eeb628d3", 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "import numpy as np\n", 183 | "\n", 184 | "# Embed chunks using the model served by MLX Server\n", 185 | "def embed_chunks(chunks, model_name):\n", 186 | " \"\"\"Generate embeddings for text chunks using the MLX Server model.\"\"\"\n", 187 | " response = client.embeddings.create(input=chunks, model=model_name)\n", 188 | " embeddings = [np.array(item.embedding).astype('float32') for item in response.data]\n", 189 | " return embeddings" 190 | ] 191 | }, 192 | { 193 | "cell_type": "markdown", 194 | "id": "294297ae", 195 | "metadata": {}, 196 | "source": [ 197 | "### Save to FAISS\n", 198 | "\n", 199 | "Saves the embeddings in a FAISS index and the chunks in a metadata file."
200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "id": "68d39501", 206 | "metadata": {}, 207 | "outputs": [], 208 | "source": [ 209 | "import faiss\n", 210 | "import os\n", 211 | "import pickle\n", 212 | "import numpy as np\n", 213 | "\n", 214 | "def normalize(vectors):\n", 215 | " \"\"\"\n", 216 | " Normalize a set of vectors.\n", 217 | " \"\"\"\n", 218 | " norms = np.linalg.norm(vectors, axis=1, keepdims=True) # Compute L2 norms for each vector\n", 219 | " return vectors / norms # Divide each vector by its norm to normalize\n", 220 | "\n", 221 | "def save_faiss_index(embeddings, chunks, index_path=\"db/index.faiss\", meta_path=\"db/meta.pkl\"):\n", 222 | " \"\"\"\n", 223 | " The embeddings are stored in a FAISS index, \n", 224 | " and the corresponding text chunks are saved in a metadata file locally. \n", 225 | " \"\"\"\n", 226 | " if not os.path.exists(\"db\"):\n", 227 | " os.makedirs(\"db\") \n", 228 | " dim = len(embeddings[0])\n", 229 | " \n", 230 | " # Normalize the embeddings to unit length for cosine similarity\n", 231 | " # This is required because FAISS's IndexFlatIP uses inner product\n", 232 | " embeddings = normalize(embeddings)\n", 233 | " index = faiss.IndexFlatIP(dim)\n", 234 | " index.add(np.array(embeddings))\n", 235 | " faiss.write_index(index, index_path)\n", 236 | " with open(meta_path, \"wb\") as f:\n", 237 | " pickle.dump(chunks, f)" 238 | ] 239 | }, 240 | { 241 | "cell_type": "markdown", 242 | "id": "3d34ab09", 243 | "metadata": {}, 244 | "source": [ 245 | "Combines the above steps into a single pipeline to process a PDF." 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 21, 251 | "id": "c2a4f4a2", 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "# Full pipeline: Read PDF → Chunk → Embed → Save\n", 256 | "def prepare_pdf(pdf_path, model_name):\n", 257 | " texts = read_pdf(pdf_path)\n", 258 | " chunks = chunk_text(texts)\n", 259 | " embeddings = embed_chunks(chunks, model_name)\n", 260 | " save_faiss_index(embeddings, chunks)" 261 | ] 262 | }, 263 | { 264 | "cell_type": "markdown", 265 | "id": "012725f5", 266 | "metadata": {}, 267 | "source": [ 268 | "## Query PDF using FAISS" 269 | ] 270 | }, 271 | { 272 | "cell_type": "markdown", 273 | "id": "4056aa04", 274 | "metadata": {}, 275 | "source": [ 276 | "### Load FAISS Index" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": 22, 282 | "id": "4022de75", 283 | "metadata": {}, 284 | "outputs": [], 285 | "source": [ 286 | "def load_faiss_index(index_path=\"db/index.faiss\", meta_path=\"db/meta.pkl\"):\n", 287 | " \"\"\"Load the FAISS index and corresponding text chunks from disk.\"\"\"\n", 288 | " index = faiss.read_index(index_path)\n", 289 | " with open(meta_path, \"rb\") as f:\n", 290 | " chunks = pickle.load(f)\n", 291 | " return index, chunks" 292 | ] 293 | }, 294 | { 295 | "cell_type": "markdown", 296 | "id": "142389ab", 297 | "metadata": {}, 298 | "source": [ 299 | "### Embed Query\n", 300 | "Embeds the user's query using the same model." 
301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": 23, 306 | "id": "4791b77c", 307 | "metadata": {}, 308 | "outputs": [], 309 | "source": [ 310 | "def embed_query(query, model_name):\n", 311 | " \"\"\"Convert a query string into an embedding vector.\"\"\"\n", 312 | " embedding = client.embeddings.create(input=[query], model=model_name).data[0].embedding\n", 313 | " return np.array(embedding).astype('float32')" 314 | ] 315 | }, 316 | { 317 | "cell_type": "markdown", 318 | "id": "ada96dd2", 319 | "metadata": {}, 320 | "source": [ 321 | "### Retrieve Chunks\n", 322 | "Retrieves the top-k most relevant chunks based on the query embedding." 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": 24, 328 | "id": "267bcfe2", 329 | "metadata": {}, 330 | "outputs": [], 331 | "source": [ 332 | "def retrieve_chunks(query, index, chunks, model_name, top_k=5):\n", 333 | " \"\"\"Retrieve the top-k most relevant chunks from the FAISS index.\"\"\"\n", 334 | " query_vector = embed_query(query, model_name).reshape(1, -1)\n", 335 | " query_vector = normalize(query_vector) # Normalize the query vector\n", 336 | " distances, indices = index.search(query_vector, top_k) # Search for nearest neighbors\n", 337 | " relevant_chunks = [chunks[i] for i in indices[0]] \n", 338 | " return relevant_chunks" 339 | ] 340 | }, 341 | { 342 | "cell_type": "markdown", 343 | "id": "6c3491a5", 344 | "metadata": {}, 345 | "source": [ 346 | "### Generate Answer with Context" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": 25, 352 | "id": "34a811b3", 353 | "metadata": {}, 354 | "outputs": [], 355 | "source": [ 356 | "def answer_with_context(query, retrieved_chunks, model_name):\n", 357 | " \"\"\"Generate a response to the query using retrieved chunks as context.\"\"\"\n", 358 | " context = \"\\n\".join(retrieved_chunks)\n", 359 | " prompt = f\"\"\"You are a helpful assistant. Use the context below to answer the question.\n", 360 | " Context:\n", 361 | " {context}\n", 362 | " Question: {query}\n", 363 | " Answer:\"\"\"\n", 364 | " \n", 365 | " response = client.chat.completions.create(\n", 366 | " model=model_name,\n", 367 | " messages=[{\"role\": \"user\", \"content\": prompt}]\n", 368 | " )\n", 369 | " return response.choices[0].message.content" 370 | ] 371 | }, 372 | { 373 | "cell_type": "markdown", 374 | "id": "487af341", 375 | "metadata": {}, 376 | "source": [ 377 | "Combines the query steps into a single function." 
378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": 26, 383 | "id": "2f09e17d", 384 | "metadata": {}, 385 | "outputs": [], 386 | "source": [ 387 | "# Full pipeline: Query → Embed → Retrieve → Answer\n", 388 | "def query_pdf(query, model_name=\"mlx-community/Qwen3-4B-8bit\"):\n", 389 | " index, chunks = load_faiss_index()\n", 390 | " top_chunks = retrieve_chunks(query, index, chunks, model_name)\n", 391 | " return answer_with_context(query, top_chunks, model_name)" 392 | ] 393 | }, 394 | { 395 | "cell_type": "markdown", 396 | "id": "71c53682", 397 | "metadata": {}, 398 | "source": [ 399 | "## Example Usage" 400 | ] 401 | }, 402 | { 403 | "cell_type": "markdown", 404 | "id": "7b022805", 405 | "metadata": {}, 406 | "source": [ 407 | "Index text chunks from PDF into FAISS using Qwen3-4B-8bit model" 408 | ] 409 | }, 410 | { 411 | "cell_type": "code", 412 | "execution_count": 27, 413 | "id": "18552003", 414 | "metadata": {}, 415 | "outputs": [], 416 | "source": [ 417 | "prepare_pdf(\"./pdfs/lab03.pdf\", \"mlx-community/Qwen3-4B-8bit\")" 418 | ] 419 | }, 420 | { 421 | "cell_type": "markdown", 422 | "id": "af4fcab9", 423 | "metadata": {}, 424 | "source": [ 425 | "Sample query:" 426 | ] 427 | }, 428 | { 429 | "cell_type": "code", 430 | "execution_count": 28, 431 | "id": "9c333397", 432 | "metadata": {}, 433 | "outputs": [ 434 | { 435 | "name": "stdout", 436 | "output_type": "stream", 437 | "text": [ 438 | "Query: What submissions do I need to submit in this lab?\n", 439 | "Response: For this lab, you need to submit the following:\n", 440 | "\n", 441 | "1. **StudentID1_StudentID2_Report.pdf**: \n", 442 | " - A short report explaining your solution, any problems encountered, and any approaches you tried that didn't work. This report should not include your source code but should clarify your solution and any issues you faced.\n", 443 | "\n", 444 | "2. **.patch**: \n", 445 | " - A git diff file showing the changes you made to the xv6 codebase. You can generate this by running the command: \n", 446 | " ```bash\n", 447 | " $ git diff > .patch\n", 448 | " ```\n", 449 | "\n", 450 | "3. **Zip file of xv6**: \n", 451 | " - A zip file containing the modified xv6 codebase. The code should be in a clean state (i.e., after running `make clean`). The filename should follow the format: \n", 452 | " ```bash\n", 453 | " .zip\n", 454 | " ``` \n", 455 | " For example, if the students' IDs are 2312001 and 2312002, the filename would be: \n", 456 | " ```bash\n", 457 | " 2312001_2312002.zip\n", 458 | " ```\n", 459 | "\n", 460 | "Make sure to follow these submission guidelines carefully to ensure your work is graded properly.\n" 461 | ] 462 | } 463 | ], 464 | "source": [ 465 | "# Ask a question related to the content of the PDF\n", 466 | "query = \"What submissions do I need to submit in this lab?\"\n", 467 | "print(\"Query: \", query)\n", 468 | "response = query_pdf(query, model_name=\"mlx-community/Qwen3-4B-8bit\")\n", 469 | "print(\"Response: \", response)" 470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "execution_count": 29, 475 | "id": "b8eee4af", 476 | "metadata": {}, 477 | "outputs": [ 478 | { 479 | "name": "stdout", 480 | "output_type": "stream", 481 | "text": [ 482 | "Query: What is the hint for Lab 4.2 – Speed up system calls?\n", 483 | "Response: The hint for Lab 4.2 – Speed up system calls is to choose permission bits that allow userspace to only read the page. 
This ensures that the shared page between userspace and the kernel can be accessed by userspace for reading the PID, but not modified, which is necessary for the optimization of the getpid() system call.\n" 484 | ] 485 | } 486 | ], 487 | "source": [ 488 | "# Ask a question related to the content of the PDF\n", 489 | "query = \"What is the hint for Lab 4.2 – Speed up system calls?\"\n", 490 | "print(\"Query: \", query)\n", 491 | "response = query_pdf(query, model_name=\"mlx-community/Qwen3-4B-8bit\")\n", 492 | "print(\"Response: \", response)" 493 | ] 494 | } 495 | ], 496 | "metadata": { 497 | "kernelspec": { 498 | "display_name": "base", 499 | "language": "python", 500 | "name": "python3" 501 | }, 502 | "language_info": { 503 | "codemirror_mode": { 504 | "name": "ipython", 505 | "version": 3 506 | }, 507 | "file_extension": ".py", 508 | "mimetype": "text/x-python", 509 | "name": "python", 510 | "nbconvert_exporter": "python", 511 | "pygments_lexer": "ipython3", 512 | "version": "3.12.2" 513 | } 514 | }, 515 | "nbformat": 4, 516 | "nbformat_minor": 5 517 | } 518 | -------------------------------------------------------------------------------- /examples/vlm_embeddings_examples.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Vision-Language Model (VLM) Embeddings with MLX Server\n", 8 | "\n", 9 | "This notebook demonstrates how to leverage the embeddings endpoint of MLX Server through its OpenAI-compatible API. Vision-Language Models (VLMs) can process both images and text, allowing for multimodal understanding and representation.\n" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "\n", 17 | "## Introduction\n", 18 | "\n", 19 | "MLX Server provides an efficient way to serve multimodal models on Apple Silicon. In this notebook, we'll explore how to:\n", 20 | "\n", 21 | "- Generate embeddings for text and images\n", 22 | "- Work with the OpenAI-compatible API\n", 23 | "- Calculate similarity between text and image representations\n", 24 | "- Understand how these embeddings can be used for practical applications\n", 25 | "\n", 26 | "Embeddings are high-dimensional vector representations of content that capture semantic meaning, making them useful for search, recommendation systems, and other AI applications." 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "## 1. Setup and API Connection\n", 34 | "\n", 35 | "- A local server endpoint (`http://localhost:8000/v1`)\n", 36 | "- A placeholder API key (since MLX Server doesn't require authentication by default)\n", 37 | "\n", 38 | "Make sure you have MLX Server running locally before executing this notebook." 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 1, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "# Import the OpenAI client for API communication\n", 48 | "from openai import OpenAI\n", 49 | "\n", 50 | "# Connect to the local MLX Server with OpenAI-compatible API\n", 51 | "client = OpenAI(\n", 52 | " base_url=\"http://localhost:8000/v1\",\n", 53 | " api_key=\"fake-api-key\",\n", 54 | ")" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "## 2. Image Processing for API Requests\n", 62 | "\n", 63 | "When working with image inputs, we need to prepare them in a format that the API can understand. 
The OpenAI-compatible API expects images to be provided as base64-encoded data URIs.\n", 64 | "\n", 65 | "Below, we'll import the necessary libraries and define a helper function to convert PIL Image objects to the required format." 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 2, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "from PIL import Image\n", 75 | "from io import BytesIO\n", 76 | "import base64" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 3, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "# To send images to the API, we need to convert them to base64-encoded strings in a data URI format.\n", 86 | "\n", 87 | "def image_to_base64(image: Image.Image):\n", 88 | " \"\"\"\n", 89 | " Convert a PIL Image to a base64-encoded data URI string that can be sent to the API.\n", 90 | " \n", 91 | " Args:\n", 92 | " image: A PIL Image object\n", 93 | " \n", 94 | " Returns:\n", 95 | " A data URI string with the base64-encoded image\n", 96 | " \"\"\"\n", 97 | " # Convert image to bytes\n", 98 | " buffer = BytesIO()\n", 99 | " image.save(buffer, format=\"PNG\")\n", 100 | " buffer.seek(0)\n", 101 | " image_data = buffer.getvalue()\n", 102 | " \n", 103 | " # Encode as base64\n", 104 | " image_base64 = base64.b64encode(image_data).decode('utf-8')\n", 105 | " \n", 106 | " # Create the data URI format required by the API\n", 107 | " mime_type = \"image/png\" \n", 108 | " image_uri = f\"data:{mime_type};base64,{image_base64}\"\n", 109 | " \n", 110 | " return image_uri" 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "metadata": {}, 116 | "source": [ 117 | "## 3. Loading and Preparing an Image\n", 118 | "\n", 119 | "Now we'll load a sample image (a green dog in this case) and convert it to the base64 format required by the API. This image will be used to generate embeddings in the subsequent steps." 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 5, 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "image = Image.open(\"images/green_dog.jpeg\")\n", 129 | "image_uri = image_to_base64(image)" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": {}, 135 | "source": [ 136 | "## 4. Generating Embeddings" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 12, 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "# Generate embedding for a single text input\n", 146 | "prompt = \"Describe the image in detail\"\n", 147 | "image_embedding = client.embeddings.create(\n", 148 | " input=[prompt],\n", 149 | " model=\"mlx-community/Qwen2.5-VL-3B-Instruct-4bit\",\n", 150 | " extra_body = {\n", 151 | " \"image_url\": image_uri\n", 152 | " }\n", 153 | ").data[0].embedding\n", 154 | "\n", 155 | "text = \"A green dog looking at the camera\"\n", 156 | "text_embedding = client.embeddings.create(\n", 157 | " input=[text],\n", 158 | " model=\"mlx-community/Qwen2.5-VL-3B-Instruct-4bit\"\n", 159 | ").data[0].embedding" 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "metadata": {}, 165 | "source": [ 166 | "## 5. Comparing Text and Image Embeddings\n", 167 | "\n", 168 | "One of the powerful features of VLM embeddings is that they create a shared vector space for both text and images. 
This means we can directly compare how similar a text description is to an image's content by calculating the cosine similarity between their embeddings.\n", 169 | "\n", 170 | "A higher similarity score (closer to 1.0) indicates that the text description closely matches the image content according to the model's understanding." 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 13, 176 | "metadata": {}, 177 | "outputs": [ 178 | { 179 | "name": "stdout", 180 | "output_type": "stream", 181 | "text": [ 182 | "0.8473370724651375\n" 183 | ] 184 | } 185 | ], 186 | "source": [ 187 | "import numpy as np\n", 188 | "\n", 189 | "def cosine_similarity(a, b):\n", 190 | " a = np.array(a)\n", 191 | " b = np.array(b)\n", 192 | " return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))  # normalize so this is a true cosine similarity, not a raw dot product\n", 193 | "\n", 194 | "similarity = cosine_similarity(image_embedding, text_embedding)\n", 195 | "print(similarity)" 196 | ] 197 | } 198 | ], 199 | "metadata": { 200 | "kernelspec": { 201 | "display_name": "Python 3", 202 | "language": "python", 203 | "name": "python3" 204 | }, 205 | "language_info": { 206 | "codemirror_mode": { 207 | "name": "ipython", 208 | "version": 3 209 | }, 210 | "file_extension": ".py", 211 | "mimetype": "text/x-python", 212 | "name": "python", 213 | "nbconvert_exporter": "python", 214 | "pygments_lexer": "ipython3", 215 | "version": "3.11.12" 216 | } 217 | }, 218 | "nbformat": 4, 219 | "nbformat_minor": 2 220 | } 221 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from app import __version__ 3 | 4 | 5 | setup( 6 | name="mlx-openai-server", 7 | url="https://github.com/cubist38/mlx-openai-server", 8 | author="Huy Vuong", 9 | author_email="cubist38@gmail.com", 10 | version=__version__, 11 | description="A high-performance API server that provides OpenAI-compatible endpoints for MLX models. Built with Python and FastAPI, it enables efficient, scalable, and user-friendly local deployment of MLX-based vision and language models with an OpenAI-compatible interface. 
Perfect for developers looking to run MLX models locally while maintaining compatibility with existing OpenAI-based applications.", 12 | long_description=open("README.md").read(), 13 | long_description_content_type="text/markdown", 14 | packages=find_packages(), 15 | install_requires=[ 16 | "mlx-vlm==0.1.27", 17 | "mlx-lm==0.25.2", 18 | "fastapi", 19 | "uvicorn", 20 | "Pillow", 21 | "click", 22 | "loguru", 23 | ], 24 | extras_require={ 25 | "dev": [ 26 | "pytest", 27 | "black", 28 | "isort", 29 | "flake8", 30 | ] 31 | }, 32 | entry_points={ 33 | "console_scripts": [ 34 | "mlx-openai-server=app.cli:cli", 35 | ], 36 | }, 37 | python_requires=">=3.11", 38 | classifiers=[ 39 | "Programming Language :: Python :: 3", 40 | "License :: OSI Approved :: MIT License", 41 | "Operating System :: OS Independent", 42 | ], 43 | ) 44 | -------------------------------------------------------------------------------- /tests/test_base_tool_parser.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from app.handler.parser.base import BaseToolParser, ParseState 4 | 5 | 6 | class TestBaseToolParser(unittest.TestCase): 7 | def setUp(self): 8 | self.test_cases = [ 9 | { 10 | "name": "simple function call", 11 | "chunks": '''## 12 | #{"#name#":# "#get#_weather#",# "#arguments#":# {"#city#":# "#H#ue#"}} 13 | ## 14 | ## 15 | #{"#name#":# "#get#_weather#",# "#arguments#":# {"#city#":# "#Sy#dney#"}} 16 | ###'''.split('#') 17 | , 18 | "expected_outputs": [ 19 | {'name': 'get_weather', 'arguments': ''}, 20 | {'name': None, 'arguments': ' {"'}, 21 | {'name': None, 'arguments': 'city'}, 22 | {'name': None, 'arguments': '":'}, 23 | {'name': None, 'arguments': ' "'}, 24 | {'name': None, 'arguments': 'H'}, 25 | {'name': None, 'arguments': 'ue'}, 26 | {'name': None, 'arguments': '"}'}, 27 | '\n', 28 | {'name': 'get_weather', 'arguments': ''}, 29 | {'name': None, 'arguments': ' {"'}, 30 | {'name': None, 'arguments': 'city'}, 31 | {'name': None, 'arguments': '":'}, 32 | {'name': None, 'arguments': ' "'}, 33 | {'name': None, 'arguments': 'Sy'}, 34 | {'name': None, 'arguments': 'dney'}, 35 | {'name': None, 'arguments': '"}'}, 36 | ] 37 | }, 38 | { 39 | "name": "code function call", 40 | "chunks": r'''@@ 41 | @@{"@@name@@":@@ "@@python@@",@@ "@@arguments@@":@@ {"@@code@@":@@ "@@def@@ calculator@@(a@@,@@ b@@,@@ operation@@):\@@n@@ @@ if@@ operation@@ ==@@ '@@add@@'\@@n@@ @@ return@@ a@@ +@@ b@@\n@@ @@ if@@ operation@@ ==@@ '@@subtract@@'\@@n@@ @@ return@@ a@@ -@@ b@@\n@@ @@ if@@ operation@@ ==@@ '@@multiply@@'\@@n@@ @@ return@@ a@@ *@@ b@@\n@@ @@ if@@ operation@@ ==@@ '@@divide@@'\@@n@@ @@ return@@ a@@ /@@ b@@\n@@ @@ return@@ '@@Invalid@@ operation@@'@@"}} 42 | @@@@@@'''.split('@@') 43 | , 44 | "expected_outputs": [ 45 | {'name': 'python', 'arguments': ''}, 46 | {'name': None, 'arguments': ' {"'}, 47 | {'name': None, 'arguments': 'code'}, 48 | {'name': None, 'arguments': '":'}, 49 | {'name': None, 'arguments': ' "'}, 50 | {'name': None, 'arguments': 'def'}, 51 | {'name': None, 'arguments': ' calculator'}, 52 | {'name': None, 'arguments': '(a'}, 53 | {'name': None, 'arguments': ','}, 54 | {'name': None, 'arguments': ' b'}, 55 | {'name': None, 'arguments': ','}, 56 | {'name': None, 'arguments': ' operation'}, 57 | {'name': None, 'arguments': '):\\'}, 58 | {'name': None, 'arguments': 'n'}, 59 | {'name': None, 'arguments': ' '}, 60 | {'name': None, 'arguments': ' if'}, 61 | {'name': None, 'arguments': ' operation'}, 62 | {'name': None, 'arguments': ' =='}, 63 | {'name': None, 
'arguments': " '"}, 64 | {'name': None, 'arguments': 'add'}, 65 | {'name': None, 'arguments': "'\\"}, 66 | {'name': None, 'arguments': 'n'}, 67 | {'name': None, 'arguments': ' '}, 68 | {'name': None, 'arguments': ' return'}, 69 | {'name': None, 'arguments': ' a'}, 70 | {'name': None, 'arguments': ' +'}, 71 | {'name': None, 'arguments': ' b'}, 72 | {'name': None, 'arguments': '\\n'}, 73 | {'name': None, 'arguments': ' '}, 74 | {'name': None, 'arguments': ' if'}, 75 | {'name': None, 'arguments': ' operation'}, 76 | {'name': None, 'arguments': ' =='}, 77 | {'name': None, 'arguments': " '"}, 78 | {'name': None, 'arguments': 'subtract'}, 79 | {'name': None, 'arguments': "'\\"}, 80 | {'name': None, 'arguments': 'n'}, 81 | {'name': None, 'arguments': ' '}, 82 | {'name': None, 'arguments': ' return'}, 83 | {'name': None, 'arguments': ' a'}, 84 | {'name': None, 'arguments': ' -'}, 85 | {'name': None, 'arguments': ' b'}, 86 | {'name': None, 'arguments': '\\n'}, 87 | {'name': None, 'arguments': ' '}, 88 | {'name': None, 'arguments': ' if'}, 89 | {'name': None, 'arguments': ' operation'}, 90 | {'name': None, 'arguments': ' =='}, 91 | {'name': None, 'arguments': " '"}, 92 | {'name': None, 'arguments': 'multiply'}, 93 | {'name': None, 'arguments': "'\\"}, 94 | {'name': None, 'arguments': 'n'}, 95 | {'name': None, 'arguments': ' '}, 96 | {'name': None, 'arguments': ' return'}, 97 | {'name': None, 'arguments': ' a'}, 98 | {'name': None, 'arguments': ' *'}, 99 | {'name': None, 'arguments': ' b'}, 100 | {'name': None, 'arguments': '\\n'}, 101 | {'name': None, 'arguments': ' '}, 102 | {'name': None, 'arguments': ' if'}, 103 | {'name': None, 'arguments': ' operation'}, 104 | {'name': None, 'arguments': ' =='}, 105 | {'name': None, 'arguments': " '"}, 106 | {'name': None, 'arguments': 'divide'}, 107 | {'name': None, 'arguments': "'\\"}, 108 | {'name': None, 'arguments': 'n'}, 109 | {'name': None, 'arguments': ' '}, 110 | {'name': None, 'arguments': ' return'}, 111 | {'name': None, 'arguments': ' a'}, 112 | {'name': None, 'arguments': ' /'}, 113 | {'name': None, 'arguments': ' b'}, 114 | {'name': None, 'arguments': '\\n'}, 115 | {'name': None, 'arguments': ' '}, 116 | {'name': None, 'arguments': ' return'}, 117 | {'name': None, 'arguments': " '"}, 118 | {'name': None, 'arguments': 'Invalid'}, 119 | {'name': None, 'arguments': ' operation'}, 120 | {'name': None, 'arguments': "'"}, 121 | {'name': None, 'arguments': '"}'}, 122 | ] 123 | }, 124 | ] 125 | 126 | def test_parse_stream(self): 127 | for test_case in self.test_cases: 128 | with self.subTest(msg=test_case["name"]): 129 | parser = BaseToolParser("", "") 130 | outputs = [] 131 | 132 | for chunk in test_case["chunks"]: 133 | result = parser.parse_stream(chunk) 134 | if result: 135 | outputs.append(result) 136 | 137 | 138 | self.assertEqual(len(outputs), len(test_case["expected_outputs"]), 139 | f"Expected {len(test_case['expected_outputs'])} outputs, got {len(outputs)}") 140 | 141 | for i, (output, expected) in enumerate(zip(outputs, test_case["expected_outputs"])): 142 | self.assertEqual(output, expected, 143 | f"Chunk {i}: Expected {expected}, got {output}") 144 | 145 | if __name__ == '__main__': 146 | unittest.main() 147 | --------------------------------------------------------------------------------
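The streaming test above pins down the contract of BaseToolParser.parse_stream: each chunk yields either nothing, a plain string (such as the newline emitted between calls), or a delta dict whose 'name' is set only on the chunk that opens a new tool call, with later deltas carrying fragments of the 'arguments' string. Below is a minimal sketch of how a caller might fold those deltas back into complete tool calls, assuming only the interface exercised in the test (the collect_tool_calls helper is illustrative, not part of the package):

from app.handler.parser.base import BaseToolParser

def collect_tool_calls(parser: BaseToolParser, chunks):
    """Accumulate parse_stream deltas into a list of complete tool calls."""
    calls = []
    for chunk in chunks:
        delta = parser.parse_stream(chunk)
        if not isinstance(delta, dict):
            # None or plain text (e.g. the "\n" between calls): nothing to accumulate.
            continue
        if delta.get("name") is not None:
            # A delta carrying a name opens a new tool call.
            calls.append({"name": delta["name"], "arguments": delta.get("arguments") or ""})
        elif calls:
            # Subsequent deltas stream fragments of the JSON arguments string.
            calls[-1]["arguments"] += delta.get("arguments") or ""
    return calls

Fed the chunks from the first test case, with the parser constructed the same way as in the test (BaseToolParser("", "")), this would reassemble two get_weather calls whose argument fragments concatenate to the JSON payloads ' {"city": "Hue"}' and ' {"city": "Sydney"}'.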