├── .github
│   └── workflows
│       └── publish.yml
├── .gitignore
├── MANIFEST.in
├── Makefile
├── README.md
├── app
│   ├── __init__.py
│   ├── api
│   │   ├── __init__.py
│   │   └── endpoints.py
│   ├── cli.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── image_processor.py
│   │   └── queue.py
│   ├── handler
│   │   ├── __init__.py
│   │   ├── mlx_lm.py
│   │   ├── mlx_vlm.py
│   │   └── parser
│   │       ├── __init__.py
│   │       ├── base.py
│   │       └── qwen3.py
│   ├── main.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── mlx_lm.py
│   │   └── mlx_vlm.py
│   ├── schemas
│   │   ├── __init__.py
│   │   └── openai.py
│   ├── utils
│   │   ├── __init__.py
│   │   └── errors.py
│   └── version.py
├── configure_mlx.sh
├── examples
│   ├── function_calling_examples.ipynb
│   ├── images
│   │   ├── attention.png
│   │   ├── green_dog.jpeg
│   │   └── password.jpg
│   ├── lm_embeddings_examples.ipynb
│   ├── pdfs
│   │   └── lab03.pdf
│   ├── simple_rag_demo.ipynb
│   ├── vision_examples.ipynb
│   └── vlm_embeddings_examples.ipynb
├── setup.py
└── tests
    └── test_base_tool_parser.py
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish Python 🐍 distribution 📦 to PyPI
2 |
3 | on:
4 | push:
5 | tags:
6 | - 'v*' # Triggers on version tags like v1.0.0
7 |
8 | jobs:
9 | build-and-publish:
10 | runs-on: macos-latest
11 |
12 | steps:
13 | - uses: actions/checkout@v4
14 |
15 | - name: Set up Python
16 | uses: actions/setup-python@v5
17 | with:
18 | python-version: '3.11'
19 |
20 | - name: Install build tools
21 | run: |
22 | python -m pip install --upgrade pip
23 | pip install build twine
24 |
25 | - name: Build package
26 | run: python -m build
27 |
28 | - name: Publish package to PyPI
29 | env:
30 | TWINE_USERNAME: __token__
31 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
32 | run: twine upload dist/*
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 | oai-compat-server
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 | # ignore .DS_Store
11 | .DS_Store
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 | cover/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 | db.sqlite3
66 | db.sqlite3-journal
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | .pybuilder/
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | # For a library or package, you might want to ignore these files since the code is
91 | # intended to run in multiple environments; otherwise, check them in:
92 | # .python-version
93 |
94 | # pipenv
95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
98 | # install all needed dependencies.
99 | #Pipfile.lock
100 |
101 | # UV
102 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
103 | # This is especially recommended for binary packages to ensure reproducibility, and is more
104 | # commonly ignored for libraries.
105 | #uv.lock
106 |
107 | # poetry
108 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
109 | # This is especially recommended for binary packages to ensure reproducibility, and is more
110 | # commonly ignored for libraries.
111 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
112 | #poetry.lock
113 |
114 | # pdm
115 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
116 | #pdm.lock
117 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
118 | # in version control.
119 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
120 | .pdm.toml
121 | .pdm-python
122 | .pdm-build/
123 |
124 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
125 | __pypackages__/
126 |
127 | # Celery stuff
128 | celerybeat-schedule
129 | celerybeat.pid
130 |
131 | # SageMath parsed files
132 | *.sage.py
133 |
134 | # Environments
135 | .env
136 | .venv
137 | env/
138 | venv/
139 | ENV/
140 | env.bak/
141 | venv.bak/
142 |
143 | # Spyder project settings
144 | .spyderproject
145 | .spyproject
146 |
147 | # Rope project settings
148 | .ropeproject
149 |
150 | # mkdocs documentation
151 | /site
152 |
153 | # mypy
154 | .mypy_cache/
155 | .dmypy.json
156 | dmypy.json
157 |
158 | # Pyre type checker
159 | .pyre/
160 |
161 | # pytype static type analyzer
162 | .pytype/
163 |
164 | # Cython debug symbols
165 | cython_debug/
166 |
167 | # PyCharm
168 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
169 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
170 | # and can be added to the global gitignore or merged into this file. For a more nuclear
171 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
172 | #.idea/
173 |
174 | # Ruff stuff:
175 | .ruff_cache/
176 |
177 | # PyPI configuration file
178 | .pypirc
179 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.md
2 | include requirements.txt
3 | include MANIFEST.in
4 | include setup.py
5 | recursive-include app *
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | run:
2 | mlx-server launch \
3 | --model-path mlx-community/Qwen3-1.7B-4bit \
4 | --model-type lm \
5 | --max-concurrency 1 \
6 | --queue-timeout 300 \
7 | --queue-size 100
8 |
9 | install:
10 | pip install -e .
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # mlx-openai-server
2 |
3 | [MIT License](LICENSE)
4 | [Python 3.11](https://www.python.org/downloads/release/python-3110/)
5 |
6 | ## Description
7 | This repository hosts a high-performance API server that exposes OpenAI-compatible endpoints for MLX models. Built with Python and FastAPI, it offers an efficient, scalable, and user-friendly way to run MLX-based vision and language models locally behind an OpenAI-compatible interface.
8 |
9 | > **Note:** This project currently supports **macOS with M-series chips** only, as it specifically leverages MLX, Apple's framework optimized for Apple Silicon.
10 |
11 | ---
12 |
13 | ## Table of Contents
14 | - [Key Features](#key-features)
15 | - [Demo](#demo)
16 | - [OpenAI Compatibility](#openai-compatibility)
17 | - [Supported Model Types](#supported-model-types)
18 | - [Installation](#installation)
19 | - [Usage](#usage)
20 | - [Starting the Server](#starting-the-server)
21 | - [CLI Usage](#cli-usage)
22 | - [Using the API](#using-the-api)
23 | - [Request Queue System](#request-queue-system)
24 | - [API Response Schemas](#api-response-schemas)
25 | - [Example Notebooks](#example-notebooks)
26 | - [Large Models](#large-models)
27 | - [Contributing](#contributing)
28 | - [License](#license)
29 | - [Support](#support)
30 | - [Acknowledgments](#acknowledgments)
31 |
32 | ---
33 |
34 | ## Key Features
35 | - 🚀 **Fast, local OpenAI-compatible API** for MLX models
36 | - 🖼️ **Vision-language and text-only model support**
37 | - 🔌 **Drop-in replacement** for OpenAI API in your apps
38 | - 📈 **Performance and queue monitoring endpoints**
39 | - 🧑‍💻 **Easy Python and CLI usage**
40 | - 🛡️ **Robust error handling and request management**
41 |
42 | ---
43 |
44 | ## Demo
45 |
46 | ### 🚀 See It In Action
47 |
48 | Check out our [video demonstration](https://youtu.be/D9a3AZSj6v8) to see the server in action! The demo showcases:
49 |
50 | - Setting up and launching the server
51 | - Using the OpenAI Python SDK for seamless integration
52 |
53 |
54 |
55 |
56 |
57 |
58 | ---
59 |
60 | ## OpenAI Compatibility
61 |
62 | This server implements the OpenAI API interface, allowing you to use it as a drop-in replacement for OpenAI's services in your applications. It supports:
63 | - Chat completions (both streaming and non-streaming)
64 | - Vision-language model interactions
65 | - Embeddings generation
66 | - Function calling and tool use
67 | - Standard OpenAI request/response formats
68 | - Common OpenAI parameters (temperature, top_p, etc.)
69 |
70 | ## Supported Model Types
71 |
72 | The server supports two types of MLX models:
73 |
74 | 1. **Text-only models** (`--model-type lm`) - Uses the `mlx-lm` library for pure language models
75 | 2. **Vision-language models** (`--model-type vlm`) - Uses the `mlx-vlm` library for multimodal models that can process both text and images
76 |
77 | ## Installation
78 |
79 | Follow these steps to set up the MLX-powered server:
80 |
81 | ### Prerequisites
82 | - macOS with an Apple Silicon (M-series) chip
83 | - Python 3.11 (native ARM version)
84 | - pip package manager
85 |
86 | ### Setup Steps
87 | 1. Create a virtual environment for the project:
88 | ```bash
89 | python3.11 -m venv oai-compat-server
90 | ```
91 |
92 | 2. Activate the virtual environment:
93 | ```bash
94 | source oai-compat-server/bin/activate
95 | ```
96 |
97 | 3. Install the package:
98 | ```bash
99 | # Option 1: Install from PyPI
100 | pip install mlx-openai-server
101 |
102 | # Option 2: Install directly from GitHub
103 | pip install git+https://github.com/cubist38/mlx-openai-server.git
104 |
105 | # Option 3: Clone and install in development mode
106 | git clone https://github.com/cubist38/mlx-openai-server.git
107 | cd mlx-openai-server
108 | pip install -e .
109 | ```
110 |
111 | ### Troubleshooting
112 | **Issue:** My OS and Python versions meet the requirements, but `pip` cannot find a matching distribution.
113 |
114 | **Cause:** You might be using a non-native Python version. Run the following command to check:
115 | ```bash
116 | python -c "import platform; print(platform.processor())"
117 | ```
118 | If the output is `i386` (on an M-series machine), you are using a non-native Python. Switch to a native Python version. A good approach is to use [Conda](https://stackoverflow.com/questions/65415996/how-to-specify-the-architecture-or-platform-for-a-new-conda-environment-apple).
119 |
120 | ## Usage
121 |
122 | ### Starting the Server
123 |
124 | To start the MLX server, activate the virtual environment and run the main application file:
125 | ```bash
126 | source oai-compat-server/bin/activate
127 | python -m app.main \
128 |     --model-path <path-to-mlx-model> \
129 |     --model-type <lm|vlm> \
130 | --max-concurrency 1 \
131 | --queue-timeout 300 \
132 | --queue-size 100
133 | ```
134 |
135 | #### Server Parameters
136 | - `--model-path`: Path to the MLX model directory (local path or Hugging Face model repository)
137 | - `--model-type`: Type of model to run (`lm` for text-only models, `vlm` for vision-language models). Default: `lm`
138 | - `--max-concurrency`: Maximum number of concurrent requests (default: 1)
139 | - `--queue-timeout`: Request timeout in seconds (default: 300)
140 | - `--queue-size`: Maximum queue size for pending requests (default: 100)
141 | - `--port`: Port to run the server on (default: 8000)
142 | - `--host`: Host to run the server on (default: 0.0.0.0)
143 |
144 | #### Example Configurations
145 |
146 | Text-only model:
147 | ```bash
148 | python -m app.main \
149 | --model-path mlx-community/gemma-3-4b-it-4bit \
150 | --model-type lm \
151 | --max-concurrency 1 \
152 | --queue-timeout 300 \
153 | --queue-size 100
154 | ```
155 |
156 | > **Note:** Text embeddings via the `/v1/embeddings` endpoint are now available with both text-only models (`--model-type lm`) and vision-language models (`--model-type vlm`).
157 |
158 | Vision-language model:
159 | ```bash
160 | python -m app.main \
161 | --model-path mlx-community/llava-phi-3-vision-4bit \
162 | --model-type vlm \
163 | --max-concurrency 1 \
164 | --queue-timeout 300 \
165 | --queue-size 100
166 | ```
167 |
168 | ### CLI Usage
169 |
170 | CLI commands:
171 | ```bash
172 | mlx-openai-server --version
173 | mlx-openai-server --help
174 | mlx-openai-server launch --help
175 | ```
176 |
177 | To launch the server:
178 | ```bash
179 | mlx-openai-server launch --model-path <path-to-mlx-model> --model-type <lm|vlm> --port 8000
180 | ```
181 |
182 | ### Using the API
183 |
184 | The server provides OpenAI-compatible endpoints that you can use with standard OpenAI client libraries. Here are some examples:
185 |
186 | #### Text Completion
187 | ```python
188 | import openai
189 |
190 | client = openai.OpenAI(
191 | base_url="http://localhost:8000/v1",
192 | api_key="not-needed" # API key is not required for local server
193 | )
194 |
195 | response = client.chat.completions.create(
196 | model="local-model", # Model name doesn't matter for local server
197 | messages=[
198 | {"role": "user", "content": "What is the capital of France?"}
199 | ],
200 | temperature=0.7
201 | )
202 | print(response.choices[0].message.content)
203 | ```
204 |
205 | #### Vision-Language Model
206 | ```python
207 | import openai
208 | import base64
209 |
210 | client = openai.OpenAI(
211 | base_url="http://localhost:8000/v1",
212 | api_key="not-needed"
213 | )
214 |
215 | # Load and encode image
216 | with open("image.jpg", "rb") as image_file:
217 | base64_image = base64.b64encode(image_file.read()).decode('utf-8')
218 |
219 | response = client.chat.completions.create(
220 | model="local-vlm", # Model name doesn't matter for local server
221 | messages=[
222 | {
223 | "role": "user",
224 | "content": [
225 | {"type": "text", "text": "What's in this image?"},
226 | {
227 | "type": "image_url",
228 | "image_url": {
229 | "url": f"data:image/jpeg;base64,{base64_image}"
230 | }
231 | }
232 | ]
233 | }
234 | ]
235 | )
236 | print(response.choices[0].message.content)
237 | ```
238 |
239 | #### Function Calling
240 | ```python
241 | import openai
242 |
243 | client = openai.OpenAI(
244 | base_url="http://localhost:8000/v1",
245 | api_key="not-needed"
246 | )
247 |
248 | # Define the messages and tools
249 | messages = [
250 | {
251 | "role": "user",
252 | "content": "What is the weather in Tokyo?"
253 | }
254 | ]
255 |
256 | tools = [
257 | {
258 | "type": "function",
259 | "function": {
260 | "name": "get_weather",
261 | "description": "Get the weather in a given city",
262 | "parameters": {
263 | "type": "object",
264 | "properties": {
265 | "city": {"type": "string", "description": "The city to get the weather for"}
266 | }
267 | }
268 | }
269 | }
270 | ]
271 |
272 | # Make the API call
273 | completion = client.chat.completions.create(
274 | model="local-model",
275 | messages=messages,
276 | tools=tools,
277 | tool_choice="auto"
278 | )
279 |
280 | # Handle the tool call response
281 | if completion.choices[0].message.tool_calls:
282 | tool_call = completion.choices[0].message.tool_calls[0]
283 | print(f"Function called: {tool_call.function.name}")
284 | print(f"Arguments: {tool_call.function.arguments}")
285 |
286 | # Process the tool call - typically you would call your actual function here
287 | # For this example, we'll just hardcode a weather response
288 | weather_info = {"temperature": "22°C", "conditions": "Sunny", "humidity": "65%"}
289 |
290 | # Add the tool call and function response to the conversation
291 | messages.append(completion.choices[0].message)
292 | messages.append({
293 | "role": "tool",
294 | "tool_call_id": tool_call.id,
295 | "name": tool_call.function.name,
296 | "content": str(weather_info)
297 | })
298 |
299 | # Continue the conversation with the function result
300 | final_response = client.chat.completions.create(
301 | model="local-model",
302 | messages=messages
303 | )
304 | print("\nFinal response:")
305 | print(final_response.choices[0].message.content)
306 | ```
307 |
308 | #### Embeddings
309 |
310 | 1. Text-only model embeddings:
311 | ```python
312 | import openai
313 |
314 | client = openai.OpenAI(
315 | base_url="http://localhost:8000/v1",
316 | api_key="not-needed"
317 | )
318 |
319 | # Generate embeddings for a single text
320 | embedding_response = client.embeddings.create(
321 | model="mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-MLX-Q8",
322 | input=["The quick brown fox jumps over the lazy dog"]
323 | )
324 | print(f"Embedding dimension: {len(embedding_response.data[0].embedding)}")
325 |
326 | # Generate embeddings for multiple texts
327 | batch_response = client.embeddings.create(
328 | model="mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-MLX-Q8",
329 | input=[
330 | "Machine learning algorithms improve with more data",
331 | "Natural language processing helps computers understand human language",
332 | "Computer vision allows machines to interpret visual information"
333 | ]
334 | )
335 | print(f"Number of embeddings: {len(batch_response.data)}")
336 | ```
337 |
338 | 2. Vision-language model embeddings:
339 | ```python
340 | import openai
341 | import base64
342 | from PIL import Image
343 | from io import BytesIO
344 |
345 | client = openai.OpenAI(
346 | base_url="http://localhost:8000/v1",
347 | api_key="not-needed"
348 | )
349 |
350 | # Helper function to encode images as base64
351 | def image_to_base64(image_path):
352 | image = Image.open(image_path)
353 | buffer = BytesIO()
354 | image.save(buffer, format="PNG")
355 | buffer.seek(0)
356 | image_data = buffer.getvalue()
357 | image_base64 = base64.b64encode(image_data).decode('utf-8')
358 | return f"data:image/png;base64,{image_base64}"
359 |
360 | # Encode the image
361 | image_uri = image_to_base64("images/attention.png")
362 |
363 | # Generate embeddings for text+image
364 | vision_embedding = client.embeddings.create(
365 | model="mlx-community/Qwen2.5-VL-3B-Instruct-4bit",
366 | input=["Describe the image in detail"],
367 | extra_body={"image_url": image_uri}
368 | )
369 | print(f"Vision embedding dimension: {len(vision_embedding.data[0].embedding)}")
370 | ```
371 |
372 | > **Note:** Replace the model name and image path as needed. The `extra_body` parameter is used to pass the image data URI to the API.
373 |
374 | > **Warning:** Make sure you're running the server with `--model-type vlm` when making vision requests. If you send a vision request to a server running with `--model-type lm` (text-only model), you'll receive a 400 error with a message that vision requests are not supported with text-only models.
375 |
376 | ## Request Queue System
377 |
378 | The server implements a robust request queue system to manage and optimize MLX model inference requests. This system ensures efficient resource utilization and fair request processing.
379 |
380 | ### Key Features
381 |
382 | - **Concurrency Control**: Limits the number of simultaneous model inferences to prevent resource exhaustion
383 | - **Request Queuing**: Implements a fair, first-come-first-served queue for pending requests
384 | - **Timeout Management**: Automatically handles requests that exceed the configured timeout
385 | - **Real-time Monitoring**: Provides endpoints to monitor queue status and performance metrics
386 |
387 | ### Architecture
388 |
389 | The queue system consists of two main components (a short usage sketch follows this list):
390 |
391 | 1. **RequestQueue**: An asynchronous queue implementation that:
392 | - Manages pending requests with configurable queue size
393 | - Controls concurrent execution using semaphores
394 | - Handles timeouts and errors gracefully
395 | - Provides real-time queue statistics
396 |
397 | 2. **Model Handlers**: Specialized handlers for different model types:
398 | - `MLXLMHandler`: Manages text-only model requests
399 | - `MLXVLMHandler`: Manages vision-language model requests
400 |
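For illustration, here is a minimal sketch of how the `RequestQueue` from `app/core/queue.py` is driven. The `fake_inference` coroutine is a stand-in for the real model handler, not the server's actual wiring:

```python
import asyncio

from app.core.queue import RequestQueue

async def main():
    # Stand-in processor: any async callable that maps request data to a result.
    async def fake_inference(data):
        await asyncio.sleep(0.1)  # pretend to run the model
        return f"processed: {data}"

    queue = RequestQueue(max_concurrency=2, timeout=300.0, queue_size=100)
    await queue.start(fake_inference)

    # submit() enqueues the request and waits for its result (or raises on timeout).
    result = await queue.submit("req-1", "hello")
    print(result)
    print(queue.get_queue_stats())

    await queue.stop()

asyncio.run(main())
```
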
401 | ### Queue Monitoring
402 |
403 | Monitor queue statistics using the `/v1/queue/stats` endpoint:
404 |
405 | ```bash
406 | curl http://localhost:8000/v1/queue/stats
407 | ```
408 |
409 | Example response:
410 | ```json
411 | {
412 | "status": "ok",
413 | "queue_stats": {
414 | "running": true,
415 | "queue_size": 3,
416 | "max_queue_size": 100,
417 | "active_requests": 5,
418 | "max_concurrency": 2
419 | }
420 | }
421 | ```
422 |
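The same endpoint can be polled from Python, for example to drive a simple monitoring loop (a sketch using only the standard library):

```python
import json
import time
import urllib.request

# Poll the queue stats endpoint every 5 seconds and print a one-line summary.
while True:
    with urllib.request.urlopen("http://localhost:8000/v1/queue/stats") as resp:
        stats = json.load(resp)["queue_stats"]
    print(f"queued={stats['queue_size']} active={stats['active_requests']} "
          f"max_concurrency={stats['max_concurrency']}")
    time.sleep(5)
```
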
423 | ### Error Handling
424 |
425 | The queue system handles various error conditions (a client-side retry sketch follows this list):
426 |
427 | 1. **Queue Full (429)**: When the queue reaches its maximum size
428 | ```json
429 | {
430 | "detail": "Too many requests. Service is at capacity."
431 | }
432 | ```
433 |
434 | 2. **Request Timeout**: When a request exceeds the configured timeout
435 | ```json
436 | {
437 | "detail": "Request processing timed out after 300 seconds"
438 | }
439 | ```
440 |
441 | 3. **Model Errors**: When the model encounters an error during inference
442 | ```json
443 | {
444 |   "detail": "Failed to generate response: <error message>"
445 | }
446 | ```
447 |
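On the client side, a 429 response is a signal to back off and retry. Here is a minimal sketch using the OpenAI Python SDK, which raises `openai.RateLimitError` for HTTP 429 (the retry counts and delays are arbitrary):

```python
import time
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

def chat_with_backoff(messages, retries=3, delay=2.0):
    """Retry on 429 (queue full) with simple exponential backoff."""
    for attempt in range(retries):
        try:
            return client.chat.completions.create(model="local-model", messages=messages)
        except openai.RateLimitError:
            # The server's queue is at capacity; wait and try again.
            time.sleep(delay * (2 ** attempt))
    raise RuntimeError("Server stayed at capacity after several retries")

response = chat_with_backoff([{"role": "user", "content": "Hello!"}])
print(response.choices[0].message.content)
```
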
448 | ### Streaming Responses
449 |
450 | The server supports streaming responses with proper chunk formatting:
451 | ```json
452 | {
453 | "id": "chatcmpl-1234567890",
454 | "object": "chat.completion.chunk",
455 | "created": 1234567890,
456 | "model": "local-model",
457 | "choices": [{
458 | "index": 0,
459 | "delta": {"content": "chunk of text"},
460 | "finish_reason": null
461 | }]
462 | }
463 | ```
464 |
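On the client side, these chunks can be consumed with the OpenAI SDK's streaming interface. A brief sketch:

```python
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

stream = client.chat.completions.create(
    model="local-model",
    messages=[{"role": "user", "content": "Write a haiku about the sea."}],
    stream=True,
)

# Each event is a chat.completion.chunk; text arrives in choices[0].delta.content.
for chunk in stream:
    delta = chunk.choices[0].delta
    if delta.content:
        print(delta.content, end="", flush=True)
print()
```
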
465 | ## API Response Schemas
466 |
467 | The server implements OpenAI-compatible API response schemas to ensure seamless integration with existing applications. Below are the key response formats:
468 |
469 | ### Chat Completions Response
470 |
471 | ```json
472 | {
473 | "id": "chatcmpl-123456789",
474 | "object": "chat.completion",
475 | "created": 1677858242,
476 | "model": "local-model",
477 | "choices": [
478 | {
479 | "index": 0,
480 | "message": {
481 | "role": "assistant",
482 | "content": "This is the response content from the model."
483 | },
484 | "finish_reason": "stop"
485 | }
486 | ],
487 | "usage": {
488 | "prompt_tokens": 10,
489 | "completion_tokens": 20,
490 | "total_tokens": 30
491 | }
492 | }
493 | ```
494 |
495 | ### Embeddings Response
496 |
497 | ```json
498 | {
499 | "object": "list",
500 | "data": [
501 | {
502 | "object": "embedding",
503 | "embedding": [0.001, 0.002, ..., 0.999],
504 | "index": 0
505 | }
506 | ],
507 | "model": "local-model",
508 | "usage": {
509 | "prompt_tokens": 8,
510 | "total_tokens": 8
511 | }
512 | }
513 | ```
514 |
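The returned vectors can be compared directly, for example with cosine similarity. A small sketch using only the standard library (the model name here is a placeholder for whatever the server was launched with):

```python
import math
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

resp = client.embeddings.create(
    model="local-model",
    input=[
        "MLX runs natively on Apple Silicon",
        "Apple's MLX framework targets M-series chips",
    ],
)
print(f"similarity: {cosine(resp.data[0].embedding, resp.data[1].embedding):.3f}")
```
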
515 | ### Function/Tool Calling Response
516 |
517 | ```json
518 | {
519 | "id": "chatcmpl-123456789",
520 | "object": "chat.completion",
521 | "created": 1677858242,
522 | "model": "local-model",
523 | "choices": [
524 | {
525 | "index": 0,
526 | "message": {
527 | "role": "assistant",
528 | "content": null,
529 | "tool_calls": [
530 | {
531 | "id": "call_abc123",
532 | "type": "function",
533 | "function": {
534 | "name": "get_weather",
535 | "arguments": "{\"city\":\"Tokyo\"}"
536 | }
537 | }
538 | ]
539 | },
540 | "finish_reason": "tool_calls"
541 | }
542 | ],
543 | "usage": {
544 | "prompt_tokens": 15,
545 | "completion_tokens": 25,
546 | "total_tokens": 40
547 | }
548 | }
549 | ```
550 |
551 | ### Error Response
552 |
553 | ```json
554 | {
555 | "error": {
556 | "message": "Error message describing what went wrong",
557 | "type": "invalid_request_error",
558 | "param": null,
559 | "code": null
560 | }
561 | }
562 | ```
563 |
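With the OpenAI SDK these error bodies surface as exceptions. For example, a vision request sent to a server started with `--model-type lm` comes back as HTTP 400 and can be caught like this (a sketch; the data URI is a placeholder):

```python
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

try:
    client.chat.completions.create(
        model="local-model",
        messages=[{
            "role": "user",
            "content": [
                {"type": "text", "text": "What's in this image?"},
                {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}},
            ],
        }],
    )
except openai.BadRequestError as e:  # raised for HTTP 400 responses
    print("Request rejected:", e)
```
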
564 | ## Example Notebooks
565 |
566 | The repository includes example notebooks to help you get started with different aspects of the API:
567 |
568 | - **function_calling_examples.ipynb**: A practical guide to implementing and using function calling with local models, including:
569 | - Setting up function definitions
570 | - Making function calling requests
571 | - Handling function call responses
572 | - Working with streaming function calls
573 | - Building multi-turn conversations with tool use
574 |
575 | - **vision_examples.ipynb**: A comprehensive guide to using the vision capabilities of the API, including:
576 | - Processing image inputs in various formats
577 | - Vision analysis and object detection
578 | - Multi-turn conversations with images
579 | - Using vision models for detailed image description and analysis
580 |
581 | - **lm_embeddings_examples.ipynb**: A comprehensive guide to using the embeddings API for text-only models, including:
582 | - Generating embeddings for single and batch inputs
583 | - Computing semantic similarity between texts
584 | - Building a simple vector-based search system
585 | - Comparing semantic relationships between concepts
586 |
587 | - **vlm_embeddings_examples.ipynb**: A detailed guide to working with Vision-Language Model embeddings, including:
588 | - Generating embeddings for images with text prompts
589 | - Creating text-only embeddings with VLMs
590 | - Calculating similarity between text and image representations
591 | - Understanding the shared embedding space of multimodal models
592 | - Practical applications of VLM embeddings
593 |
594 | - **simple_rag_demo.ipynb**: A practical guide to building a lightweight Retrieval-Augmented Generation (RAG) pipeline over PDF documents using local MLX Server, including:
595 | - Reading and chunking PDF documents
596 | - Generating text embeddings via MLX Server
597 | - Creating a simple vector store for retrieval
598 | - Performing question answering based on relevant chunks
599 | - End-to-end demonstration of document QA using Qwen3 local model
600 |
601 |
602 |
603 |
604 |
605 |
606 |
607 | ## Large Models
608 | When using models that are large relative to your system's available RAM, performance may suffer. mlx-lm tries to improve speed by wiring the memory used by the model and its cache—this optimization is only available on macOS 15.0 or newer.
609 | If you see the following warning message:
610 | > [WARNING] Generating with a model that requires ...
611 | it means the model may run slowly on your machine. If the model fits in RAM, you can often improve performance by raising the system's wired memory limit. To do this, run:
612 | ```bash
613 | bash configure_mlx.sh
614 | ```
615 |
616 | ## Contributing
617 | We welcome contributions to improve this project! Here's how you can contribute:
618 | 1. Fork the repository to your GitHub account.
619 | 2. Create a new branch for your feature or bug fix:
620 | ```bash
621 | git checkout -b feature-name
622 | ```
623 | 3. Commit your changes with clear and concise messages:
624 | ```bash
625 | git commit -m "Add feature-name"
626 | ```
627 | 4. Push your branch to your forked repository:
628 | ```bash
629 | git push origin feature-name
630 | ```
631 | 5. Open a pull request to the main repository for review.
632 |
633 | ## License
634 | This project is licensed under the [MIT License](LICENSE). You are free to use, modify, and distribute it under the terms of the license.
635 |
636 | ## Support
637 | If you encounter any issues or have questions, please:
638 | - Open an issue in the repository.
639 | - Contact the maintainers via the provided contact information.
640 |
641 | Stay tuned for updates and enhancements!
642 |
643 | ## Acknowledgments
644 |
645 | We extend our heartfelt gratitude to the following individuals and organizations whose contributions have been instrumental in making this project possible:
646 |
647 | ### Core Technologies
648 | - [MLX team](https://github.com/ml-explore/mlx) for developing the groundbreaking MLX framework, which provides the foundation for efficient machine learning on Apple Silicon
649 | - [mlx-lm](https://github.com/ml-explore/mlx-lm) for efficient large language models support
650 | - [mlx-vlm](https://github.com/Blaizzy/mlx-vlm/tree/main) for pioneering multimodal model support within the MLX ecosystem
651 | - [mlx-community](https://huggingface.co/mlx-community) for curating and maintaining a diverse collection of high-quality MLX models
652 |
653 | ### Open Source Community
654 | We deeply appreciate the broader open-source community for their invaluable contributions. Your dedication to:
655 | - Innovation in machine learning and AI
656 | - Collaborative development practices
657 | - Knowledge sharing and documentation
658 | - Continuous improvement of tools and frameworks
659 |
660 | Your collective efforts continue to drive progress and make projects like this possible. We are proud to be part of this vibrant ecosystem.
661 |
662 | ### Special Thanks
663 | A special acknowledgment to all contributors, users, and supporters who have helped shape this project through their feedback, bug reports, and suggestions. Your engagement helps make this project better for everyone.
--------------------------------------------------------------------------------
/app/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from app.version import __version__
3 |
4 | # Suppress transformers warnings
5 | os.environ['TRANSFORMERS_VERBOSITY'] = 'error'
6 |
7 | __all__ = ["__version__"]
--------------------------------------------------------------------------------
/app/api/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/app/api/endpoints.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import time
4 | from http import HTTPStatus
5 | from typing import Any, Dict, List, Optional, Union, AsyncGenerator
6 |
7 | from fastapi import APIRouter, Request
8 | from fastapi.responses import JSONResponse, StreamingResponse
9 | from loguru import logger
10 |
11 | from app.handler.mlx_lm import MLXLMHandler
12 | from app.schemas.openai import (ChatCompletionChunk,
13 | ChatCompletionMessageToolCall,
14 | ChatCompletionRequest, ChatCompletionResponse,
15 | Choice, ChoiceDeltaFunctionCall,
16 | ChoiceDeltaToolCall, Delta, Embedding,
17 | EmbeddingRequest, EmbeddingResponse,
18 | FunctionCall, Message, Model, ModelsResponse,
19 | StreamingChoice)
20 | from app.utils.errors import create_error_response
21 |
22 | router = APIRouter()
23 |
24 |
25 | @router.post("/health")
26 | async def health(raw_request: Request):
27 | """
28 | Health check endpoint.
29 | """
30 | try:
31 | return {"status": "ok"}
32 | except Exception as e:
33 | logger.error(f"Health check failed: {str(e)}")
34 | return JSONResponse(content=create_error_response("Health check failed", "server_error", 500), status_code=500)
35 |
36 | @router.get("/v1/queue/stats")
37 | async def queue_stats(raw_request: Request):
38 | """
39 | Get queue statistics.
40 | """
41 | handler = raw_request.app.state.handler
42 | if handler is None:
43 | return JSONResponse(content=create_error_response("Model handler not initialized", "service_unavailable", 503), status_code=503)
44 |
45 | try:
46 | stats = await handler.get_queue_stats()
47 | return {
48 | "status": "ok",
49 | "queue_stats": stats
50 | }
51 | except Exception as e:
52 | logger.error(f"Failed to get queue stats: {str(e)}")
53 | return JSONResponse(content=create_error_response("Failed to get queue stats", "server_error", 500), status_code=500)
54 |
55 | @router.get("/v1/models")
56 | async def models(raw_request: Request):
57 | """
58 | Get list of available models.
59 | """
60 | handler = raw_request.app.state.handler
61 | models_data = handler.get_models()
62 | return ModelsResponse(data=[Model(**model) for model in models_data])
63 |
64 | @router.post("/v1/chat/completions")
65 | async def chat_completions(request: ChatCompletionRequest, raw_request: Request):
66 | """Handle chat completion requests."""
67 |
68 | handler = raw_request.app.state.handler
69 | if handler is None:
70 | return JSONResponse(content=create_error_response("Model handler not initialized", "service_unavailable", 503), status_code=503)
71 |
72 | try:
73 | # Check if this is a vision request
74 | is_vision_request = request.is_vision_request()
75 | # If it's a vision request but the handler is MLXLMHandler (text-only), reject it
76 | if is_vision_request and isinstance(handler, MLXLMHandler):
77 | return JSONResponse(
78 | content=create_error_response(
79 | "Vision requests are not supported with text-only models. Use a VLM model type instead.",
80 | "unsupported_request",
81 | 400
82 | ),
83 | status_code=400
84 | )
85 |
86 | # Process the request based on type
87 | return await process_vision_request(handler, request) if is_vision_request \
88 | else await process_text_request(handler, request)
89 | except Exception as e:
90 | logger.error(f"Error processing chat completion request: {str(e)}", exc_info=True)
91 | return JSONResponse(content=create_error_response(str(e)), status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
92 |
93 | @router.post("/v1/embeddings")
94 | async def embeddings(request: EmbeddingRequest, raw_request: Request):
95 | """Handle embedding requests."""
96 | handler = raw_request.app.state.handler
97 | if handler is None:
98 | return JSONResponse(content=create_error_response("Model handler not initialized", "service_unavailable", 503), status_code=503)
99 |
100 | try:
101 | embeddings = await handler.generate_embeddings_response(request)
102 | return create_response_embeddings(embeddings, request.model)
103 | except Exception as e:
104 | logger.error(f"Error processing embedding request: {str(e)}", exc_info=True)
105 | return JSONResponse(content=create_error_response(str(e)), status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
106 |
107 | def create_response_embeddings(embeddings: List[List[float]], model: str) -> EmbeddingResponse:
108 | embeddings_response = []
109 | for index, embedding in enumerate(embeddings):
110 | embeddings_response.append(Embedding(embedding=embedding, index=index))
111 | return EmbeddingResponse(data=embeddings_response, model=model)
112 |
113 | def create_response_chunk(chunk: Union[str, Dict[str, Any]], model: str, is_final: bool = False, finish_reason: Optional[str] = "stop", chat_id: Optional[str] = None, created_time: Optional[int] = None) -> ChatCompletionChunk:
114 | """Create a formatted response chunk for streaming."""
115 | chat_id = chat_id if chat_id else get_id()
116 | created_time = created_time if created_time else int(time.time())
117 | if isinstance(chunk, str):
118 | return ChatCompletionChunk(
119 | id=chat_id,
120 | object="chat.completion.chunk",
121 | created=created_time,
122 | model=model,
123 | choices=[StreamingChoice(
124 | index=0,
125 | delta=Delta(content=chunk, role="assistant"),
126 | finish_reason=finish_reason if is_final else None
127 | )]
128 | )
129 | if "reasoning_content" in chunk:
130 | return ChatCompletionChunk(
131 | id=chat_id,
132 | object="chat.completion.chunk",
133 | created=created_time,
134 | model=model,
135 | choices=[StreamingChoice(
136 | index=0,
137 | delta=Delta(reasoning_content=chunk["reasoning_content"], role="assistant"),
138 | finish_reason=finish_reason if is_final else None
139 | )]
140 | )
141 |
142 | if "name" in chunk and chunk["name"]:
143 | tool_chunk = ChoiceDeltaToolCall(
144 | index=chunk["index"],
145 | type="function",
146 | id=get_tool_call_id(),
147 | function=ChoiceDeltaFunctionCall(
148 | name=chunk["name"],
149 | arguments=""
150 | )
151 | )
152 | else:
153 | tool_chunk = ChoiceDeltaToolCall(
154 | index=chunk["index"],
155 | type="function",
156 | function= ChoiceDeltaFunctionCall(
157 | arguments=chunk["arguments"]
158 | )
159 | )
160 | delta = Delta(
161 | content = None,
162 | role = "assistant",
163 | tool_calls = [tool_chunk]
164 | )
165 | return ChatCompletionChunk(
166 | id=chat_id,
167 | object="chat.completion.chunk",
168 | created=created_time,
169 | model=model,
170 | choices=[StreamingChoice(index=0, delta=delta, finish_reason=None)]
171 | )
172 |
173 |
174 | async def handle_stream_response(generator: AsyncGenerator, model: str):
175 | """Handle streaming response generation (OpenAI-compatible)."""
176 | chat_index = get_id()
177 | created_time = int(time.time())
178 | try:
179 | finish_reason = "stop"
180 | index = -1
181 | # First chunk: role-only delta, as per OpenAI
182 | first_chunk = ChatCompletionChunk(
183 | id=chat_index,
184 | object="chat.completion.chunk",
185 | created=created_time,
186 | model=model,
187 | choices=[StreamingChoice(index=0, delta=Delta(role="assistant"), finish_reason=None)]
188 | )
189 | yield f"data: {json.dumps(first_chunk.model_dump())}\n\n"
190 | async for chunk in generator:
191 | if chunk:
192 | if isinstance(chunk, str):
193 | response_chunk = create_response_chunk(chunk, model, chat_id=chat_index, created_time=created_time)
194 | yield f"data: {json.dumps(response_chunk.model_dump())}\n\n"
195 | else:
196 | finish_reason = "tool_calls"
197 | if "name" in chunk and chunk["name"]:
198 | index += 1
199 | payload = {
200 | "index": index,
201 | **chunk
202 | }
203 | response_chunk = create_response_chunk(payload, model, chat_id=chat_index, created_time=created_time)
204 | yield f"data: {json.dumps(response_chunk.model_dump())}\n\n"
205 | except Exception as e:
206 | logger.error(f"Error in stream wrapper: {str(e)}")
207 | error_response = create_error_response(str(e), "server_error", HTTPStatus.INTERNAL_SERVER_ERROR)
208 | # Yield error as last chunk before [DONE]
209 | yield f"data: {json.dumps(error_response)}\n\n"
210 | finally:
211 | # Final chunk: finish_reason and [DONE], as per OpenAI
212 | final_chunk = create_response_chunk('', model, is_final=True, finish_reason=finish_reason, chat_id=chat_index)
213 | yield f"data: {json.dumps(final_chunk.model_dump())}\n\n"
214 | yield "data: [DONE]\n\n"
215 |
216 | async def process_vision_request(handler, request: ChatCompletionRequest):
217 | """Process vision-specific requests."""
218 | if request.stream:
219 | return StreamingResponse(
220 | handle_stream_response(handler.generate_vision_stream(request), request.model),
221 | media_type="text/event-stream",
222 | headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}
223 | )
224 | return format_final_response(await handler.generate_vision_response(request), request.model)
225 |
226 | async def process_text_request(handler, request: ChatCompletionRequest):
227 | """Process text-only requests."""
228 | if request.stream:
229 | return StreamingResponse(
230 | handle_stream_response(handler.generate_text_stream(request), request.model),
231 | media_type="text/event-stream",
232 | headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}
233 | )
234 | return format_final_response(await handler.generate_text_response(request), request.model)
235 |
236 | def get_id():
237 | """
238 | Generate a unique ID for chat completions with timestamp and random component.
239 | """
240 | timestamp = int(time.time())
241 | random_suffix = random.randint(0, 999999)
242 | return f"chatcmpl_{timestamp}{random_suffix:06d}"
243 |
244 | def get_tool_call_id():
245 | """
246 | Generate a unique ID for tool calls with timestamp and random component.
247 | """
248 | timestamp = int(time.time())
249 | random_suffix = random.randint(0, 999999)
250 | return f"call_{timestamp}{random_suffix:06d}"
251 |
252 | def format_final_response(response: Union[str, Dict[str, Any]], model: str) -> ChatCompletionResponse:
253 | """Format the final non-streaming response."""
254 |
255 | if isinstance(response, str):
256 | return ChatCompletionResponse(
257 | id=get_id(),
258 | object="chat.completion",
259 | created=int(time.time()),
260 | model=model,
261 | choices=[Choice(
262 | index=0,
263 | message=Message(role="assistant", content=response),
264 | finish_reason="stop"
265 | )]
266 | )
267 |
268 | reasoning_content = response.get("reasoning_content", None)
269 | tool_calls = response.get("tool_calls", [])
270 | tool_call_responses = []
271 | for idx, tool_call in enumerate(tool_calls):
272 | function_call = FunctionCall(
273 | name=tool_call.get("name"),
274 | arguments=json.dumps(tool_call.get("arguments"))
275 | )
276 | tool_call_response = ChatCompletionMessageToolCall(
277 | id=get_tool_call_id(),
278 | type="function",
279 | function=function_call,
280 | index=idx
281 | )
282 | tool_call_responses.append(tool_call_response)
283 |
284 | if len(tool_calls) > 0:
285 | message = Message(role="assistant", reasoning_content=reasoning_content, tool_calls=tool_call_responses)
286 | else:
287 | message = Message(role="assistant", content=response, reasoning_content=reasoning_content, tool_calls=tool_call_responses)
288 |
289 | return ChatCompletionResponse(
290 | id=get_id(),
291 | object="chat.completion",
292 | created=int(time.time()),
293 | model=model,
294 | choices=[Choice(
295 | index=0,
296 | message=message,
297 | finish_reason="tool_calls"
298 | )]
299 | )
--------------------------------------------------------------------------------
/app/cli.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import click
3 | import uvicorn
4 | from loguru import logger
5 | import sys
6 | from functools import lru_cache
7 | from app.version import __version__
8 | from app.main import setup_server
9 |
10 | class Config:
11 | """Configuration container for server parameters."""
12 | def __init__(self, model_path, model_type, port, host, max_concurrency, queue_timeout, queue_size):
13 | self.model_path = model_path
14 | self.model_type = model_type
15 | self.port = port
16 | self.host = host
17 | self.max_concurrency = max_concurrency
18 | self.queue_timeout = queue_timeout
19 | self.queue_size = queue_size
20 |
21 |
22 | # Configure Loguru once at module level
23 | def configure_logging():
24 | """Set up optimized logging configuration."""
25 | logger.remove() # Remove default handler
26 | logger.add(
27 | sys.stderr,
28 | format="{time:YYYY-MM-DD HH:mm:ss} | "
29 | "{level: <8} | "
30 | "{name}:{function}:{line} | "
31 | "✦ {message}",
32 | colorize=True,
33 | level="INFO"
34 | )
35 |
36 | # Apply logging configuration
37 | configure_logging()
38 |
39 |
40 | @click.group()
41 | @click.version_option(
42 | version=__version__,
43 | message="""
44 | ✨ %(prog)s - OpenAI Compatible API Server for MLX models ✨
45 | ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
46 | 🚀 Version: %(version)s
47 | """
48 | )
49 | def cli():
50 | """MLX Server - OpenAI Compatible API for MLX models."""
51 | pass
52 |
53 |
54 | @lru_cache(maxsize=1)
55 | def get_server_config(model_path, model_type, port, host, max_concurrency, queue_timeout, queue_size):
56 | """Cache and return server configuration to avoid redundant processing."""
57 | return Config(
58 | model_path=model_path,
59 | model_type=model_type,
60 | port=port,
61 | host=host,
62 | max_concurrency=max_concurrency,
63 | queue_timeout=queue_timeout,
64 | queue_size=queue_size
65 | )
66 |
67 |
68 | def print_startup_banner(args):
69 | """Display beautiful startup banner with configuration details."""
70 | logger.info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
71 | logger.info(f"✨ MLX Server v{__version__} Starting ✨")
72 | logger.info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
73 | logger.info(f"🔮 Model: {args.model_path}")
74 | logger.info(f"🔮 Model Type: {args.model_type}")
75 | logger.info(f"🌐 Host: {args.host}")
76 | logger.info(f"🔌 Port: {args.port}")
77 | logger.info(f"⚡ Max Concurrency: {args.max_concurrency}")
78 | logger.info(f"⏱️ Queue Timeout: {args.queue_timeout} seconds")
79 | logger.info(f"📊 Queue Size: {args.queue_size}")
80 | logger.info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
81 |
82 |
83 | @cli.command()
84 | @click.option(
85 | "--model-path",
86 | required=True,
87 | help="Path to the model"
88 | )
89 | @click.option(
90 | "--model-type",
91 | default="lm",
92 | type=click.Choice(["lm", "vlm"]),
93 | help="Type of model to run"
94 | )
95 | @click.option(
96 | "--port",
97 | default=8000,
98 | type=int,
99 | help="Port to run the server on"
100 | )
101 | @click.option(
102 | "--host",
103 | default="0.0.0.0",
104 | help="Host to run the server on"
105 | )
106 | @click.option(
107 | "--max-concurrency",
108 | default=1,
109 | type=int,
110 | help="Maximum number of concurrent requests"
111 | )
112 | @click.option(
113 | "--queue-timeout",
114 | default=300,
115 | type=int,
116 | help="Request timeout in seconds"
117 | )
118 | @click.option(
119 | "--queue-size",
120 | default=100,
121 | type=int,
122 | help="Maximum queue size for pending requests"
123 | )
124 | def launch(model_path, model_type, port, host, max_concurrency, queue_timeout, queue_size):
125 | """Launch the MLX server with the specified model."""
126 | try:
127 | # Get optimized configuration
128 | args = get_server_config(model_path, model_type, port, host, max_concurrency, queue_timeout, queue_size)
129 |
130 | # Display startup information
131 | print_startup_banner(args)
132 |
133 | # Set up and start the server
134 | config = asyncio.run(setup_server(args))
135 | logger.info("Server configuration complete.")
136 | logger.info("Starting Uvicorn server...")
137 | uvicorn.Server(config).run()
138 | except KeyboardInterrupt:
139 | logger.info("Server shutdown requested by user. Exiting...")
140 | except Exception as e:
141 | logger.error(f"Server startup failed: {str(e)}")
142 | sys.exit(1)
143 |
144 |
145 | if __name__ == "__main__":
146 | cli()
--------------------------------------------------------------------------------
/app/core/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/app/core/image_processor.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import base64
3 | import hashlib
4 | import os
5 | from concurrent.futures import ThreadPoolExecutor
6 | from functools import lru_cache
7 | from io import BytesIO
8 | from typing import List, Optional
9 | import tempfile
10 | import aiohttp
11 | import time
12 | from PIL import Image
13 | from loguru import logger
14 |
15 | class ImageProcessor:
16 | def __init__(self, max_workers: int = 4, cache_size: int = 1000):
17 | # Use tempfile for macOS-efficient temporary file handling
18 | self.temp_dir = tempfile.TemporaryDirectory()
19 | self._session: Optional[aiohttp.ClientSession] = None
20 | self.executor = ThreadPoolExecutor(max_workers=max_workers)
21 | self._cache_size = cache_size
22 | self._last_cleanup = time.time()
23 | self._cleanup_interval = 3600 # 1 hour
24 | Image.MAX_IMAGE_PIXELS = 100000000 # Limit to 100 megapixels
25 |
26 | @lru_cache(maxsize=1000)
27 | def _get_image_hash(self, image_url: str) -> str:
28 | if image_url.startswith("data:"):
29 | _, encoded = image_url.split(",", 1)
30 | data = base64.b64decode(encoded)
31 | else:
32 | data = image_url.encode('utf-8')
33 | return hashlib.md5(data).hexdigest()
34 |
35 | def _resize_image_keep_aspect_ratio(self, image: Image.Image, max_size: int = 448) -> Image.Image:
36 | width, height = image.size
37 | if width <= max_size and height <= max_size:
38 | return image
39 | if width > height:
40 | new_width = max_size
41 | new_height = int(height * max_size / width)
42 | else:
43 | new_height = max_size
44 | new_width = int(width * max_size / height)
45 |
46 | return image.resize((new_width, new_height), Image.Resampling.LANCZOS)
47 |
48 | def _prepare_image_for_saving(self, image: Image.Image) -> Image.Image:
49 | if image.mode in ('RGBA', 'LA'):
50 | background = Image.new('RGB', image.size, (255, 255, 255))
51 | if image.mode == 'RGBA':
52 | background.paste(image, mask=image.split()[3])
53 | else:
54 | background.paste(image, mask=image.split()[1])
55 | return background
56 | elif image.mode != 'RGB':
57 | return image.convert('RGB')
58 | return image
59 |
60 | async def _get_session(self) -> aiohttp.ClientSession:
61 | if self._session is None or self._session.closed:
62 | self._session = aiohttp.ClientSession(
63 | timeout=aiohttp.ClientTimeout(total=30),
64 | headers={"User-Agent": "mlx-server-OAI-compat/1.0"}
65 | )
66 | return self._session
67 |
68 | def _cleanup_old_files(self):
69 | current_time = time.time()
70 | if current_time - self._last_cleanup > self._cleanup_interval:
71 | try:
72 | for file in os.listdir(self.temp_dir.name):
73 | file_path = os.path.join(self.temp_dir.name, file)
74 | if os.path.getmtime(file_path) < current_time - self._cleanup_interval:
75 | os.remove(file_path)
76 | self._last_cleanup = current_time
77 | except Exception as e:
78 | logger.warning(f"Failed to clean up old files: {str(e)}")
79 |
80 | async def _process_single_image(self, image_url: str) -> str:
81 | try:
82 | image_hash = self._get_image_hash(image_url)
83 | cached_path = os.path.join(self.temp_dir.name, f"{image_hash}.jpg")
84 |
85 | if os.path.exists(cached_path):
86 | logger.debug(f"Using cached image: {cached_path}")
87 | return cached_path
88 |
89 | if os.path.exists(image_url):
90 | # Read-only image loading for memory efficiency
91 | with Image.open(image_url, mode='r') as image:
92 | image = self._resize_image_keep_aspect_ratio(image)
93 | image = self._prepare_image_for_saving(image)
94 | image.save(cached_path, 'JPEG', quality=100, optimize=True)
95 | return cached_path
96 |
97 | elif image_url.startswith("data:"):
98 | _, encoded = image_url.split(",", 1)
99 | estimated_size = len(encoded) * 3 / 4
100 | if estimated_size > 100 * 1024 * 1024:
101 | raise ValueError("Base64-encoded image exceeds 100 MB")
102 | data = base64.b64decode(encoded)
103 | with Image.open(BytesIO(data), mode='r') as image:
104 | image = self._resize_image_keep_aspect_ratio(image)
105 | image = self._prepare_image_for_saving(image)
106 | image.save(cached_path, 'JPEG', quality=100, optimize=True)
107 | else:
108 | session = await self._get_session()
109 | async with session.get(image_url) as response:
110 | response.raise_for_status()
111 | data = await response.read()
112 | with Image.open(BytesIO(data), mode='r') as image:
113 | image = self._resize_image_keep_aspect_ratio(image)
114 | image = self._prepare_image_for_saving(image)
115 | image.save(cached_path, 'JPEG', quality=100, optimize=True)
116 |
117 | self._cleanup_old_files()
118 | return cached_path
119 |
120 | except Exception as e:
121 | logger.error(f"Failed to process image: {str(e)}")
122 | raise ValueError(f"Failed to process image: {str(e)}")
123 |
124 | async def process_image_url(self, image_url: str) -> str:
125 | return await self._process_single_image(image_url)
126 |
127 | async def process_image_urls(self, image_urls: List[str]) -> List[str]:
128 | tasks = [self.process_image_url(url) for url in image_urls]
129 | return await asyncio.gather(*tasks, return_exceptions=True)
130 |
131 | async def cleanup(self):
132 | if hasattr(self, '_cleaned') and self._cleaned:
133 | return
134 | self._cleaned = True
135 | try:
136 | if self._session and not self._session.closed:
137 | await self._session.close()
138 | except Exception as e:
139 | logger.warning(f"Exception closing aiohttp session: {str(e)}")
140 | try:
141 | self.executor.shutdown(wait=True)
142 | except Exception as e:
143 | logger.warning(f"Exception shutting down executor: {str(e)}")
144 | try:
145 | self.temp_dir.cleanup()
146 | except Exception as e:
147 | logger.warning(f"Exception cleaning up temp_dir: {str(e)}")
148 |
149 | async def __aenter__(self):
150 | """Enter the async context manager."""
151 | return self
152 |
153 | async def __aexit__(self, exc_type, exc, tb):
154 | """Exit the async context manager and perform cleanup."""
155 | await self.cleanup()
156 |
157 | def __del__(self):
158 | # Async cleanup cannot be reliably performed in __del__
159 | # Please use 'async with ImageProcessor()' or call 'await cleanup()' explicitly.
160 | pass
--------------------------------------------------------------------------------
/app/core/queue.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import time
3 | from typing import Any, Dict, Optional, Callable, Awaitable, TypeVar, Generic
4 | from loguru import logger
5 |
6 | T = TypeVar('T')
7 |
8 | class RequestItem(Generic[T]):
9 | """
10 | Represents a single request in the queue.
11 | """
12 | def __init__(self, request_id: str, data: Any):
13 | self.request_id = request_id
14 | self.data = data
15 | self.created_at = time.time()
16 | self.future = asyncio.Future()
17 |
18 | def set_result(self, result: T) -> None:
19 | """Set the result for this request."""
20 | if not self.future.done():
21 | self.future.set_result(result)
22 |
23 | def set_exception(self, exc: Exception) -> None:
24 | """Set an exception for this request."""
25 | if not self.future.done():
26 | self.future.set_exception(exc)
27 |
28 | async def get_result(self) -> T:
29 | """Wait for and return the result of this request."""
30 | return await self.future
31 |
32 | class RequestQueue:
33 | """
34 | A simple asynchronous request queue with configurable concurrency.
35 | """
36 | def __init__(self, max_concurrency: int = 2, timeout: float = 300.0, queue_size: int = 100):
37 | """
38 | Initialize the request queue.
39 |
40 | Args:
41 | max_concurrency (int): Maximum number of concurrent requests to process.
42 | timeout (float): Timeout in seconds for request processing.
43 | queue_size (int): Maximum queue size.
44 | """
45 | self.max_concurrency = max_concurrency
46 | self.timeout = timeout
47 | self.queue_size = queue_size
48 | self.semaphore = asyncio.Semaphore(max_concurrency)
49 | self.queue = asyncio.Queue(maxsize=queue_size)
50 | self.active_requests: Dict[str, RequestItem] = {}
51 | self._worker_task = None
52 | self._running = False
53 |
54 | async def start(self, processor: Callable[[Any], Awaitable[Any]]):
55 | """
56 | Start the queue worker.
57 |
58 | Args:
59 | processor: Async function that processes queue items.
60 | """
61 | if self._running:
62 | return
63 |
64 | self._running = True
65 | self._worker_task = asyncio.create_task(self._worker_loop(processor))
66 | logger.info(f"Started request queue with max concurrency: {self.max_concurrency}")
67 |
68 | async def stop(self):
69 | """Stop the queue worker."""
70 | if not self._running:
71 | return
72 |
73 | self._running = False
74 |
75 | # Cancel the worker task
76 | if self._worker_task and not self._worker_task.done():
77 | self._worker_task.cancel()
78 | try:
79 | await self._worker_task
80 | except asyncio.CancelledError:
81 | pass
82 |
83 | # Cancel all pending requests
84 | pending_requests = list(self.active_requests.values())
85 | for request in pending_requests:
86 | if not request.future.done():
87 | request.future.cancel()
88 |
89 | self.active_requests.clear()
90 | logger.info("Stopped request queue")
91 |
92 | async def _worker_loop(self, processor: Callable[[Any], Awaitable[Any]]):
93 | """
94 | Main worker loop that processes queue items.
95 |
96 | Args:
97 | processor: Async function that processes queue items.
98 | """
99 | while self._running:
100 | try:
101 | # Get the next item from the queue
102 | request = await self.queue.get()
103 |
104 | # Process the request with concurrency control
105 | asyncio.create_task(self._process_request(request, processor))
106 |
107 | except asyncio.CancelledError:
108 | break
109 | except Exception as e:
110 | logger.error(f"Error in worker loop: {str(e)}")
111 |
112 | async def _process_request(self, request: RequestItem, processor: Callable[[Any], Awaitable[Any]]):
113 | """
114 | Process a single request with timeout and error handling.
115 |
116 | Args:
117 | request: The request to process.
118 | processor: Async function that processes the request.
119 | """
120 | # Use semaphore to limit concurrency
121 | async with self.semaphore:
122 | try:
123 | # Process with timeout
124 | processing_start = time.time()
125 | result = await asyncio.wait_for(
126 | processor(request.data),
127 | timeout=self.timeout
128 | )
129 | processing_time = time.time() - processing_start
130 |
131 | # Set the result
132 | request.set_result(result)
133 | logger.info(f"Request {request.request_id} processed in {processing_time:.2f}s")
134 |
135 | except asyncio.TimeoutError:
136 | request.set_exception(TimeoutError(f"Request processing timed out after {self.timeout}s"))
137 | logger.warning(f"Request {request.request_id} timed out after {self.timeout}s")
138 |
139 | except Exception as e:
140 | request.set_exception(e)
141 | logger.error(f"Error processing request {request.request_id}: {str(e)}")
142 |
143 | finally:
144 | # Remove from active requests
145 | self.active_requests.pop(request.request_id, None)
146 |
147 | async def enqueue(self, request_id: str, data: Any) -> RequestItem:
148 | """
149 | Add a request to the queue.
150 |
151 | Args:
152 | request_id: Unique ID for the request.
153 | data: The request data to process.
154 |
155 | Returns:
156 | RequestItem: The queued request item.
157 |
158 | Raises:
159 | asyncio.QueueFull: If the queue is full.
160 | """
161 | if not self._running:
162 | raise RuntimeError("Queue is not running")
163 |
164 | # Create request item
165 | request = RequestItem(request_id, data)
166 |
167 | # Add to active requests and queue
168 | self.active_requests[request_id] = request
169 |
170 | try:
171 |             # Wait briefly for queue space; a timeout here is converted to QueueFull below
172 | await asyncio.wait_for(
173 | self.queue.put(request),
174 | timeout=1.0 # Short timeout for queue put
175 | )
176 | queue_time = time.time() - request.created_at
177 | logger.info(f"Request {request_id} queued (wait: {queue_time:.2f}s)")
178 | return request
179 |
180 | except asyncio.TimeoutError:
181 | self.active_requests.pop(request_id, None)
182 | raise asyncio.QueueFull("Request queue is full and timed out waiting for space")
183 |
184 | async def submit(self, request_id: str, data: Any) -> Any:
185 | """
186 | Submit a request and wait for its result.
187 |
188 | Args:
189 | request_id: Unique ID for the request.
190 | data: The request data to process.
191 |
192 | Returns:
193 | The result of processing the request.
194 |
195 | Raises:
196 | Various exceptions that may occur during processing.
197 | """
198 | request = await self.enqueue(request_id, data)
199 | return await request.get_result()
200 |
201 | def get_queue_stats(self) -> Dict[str, Any]:
202 | """
203 | Get queue statistics.
204 |
205 | Returns:
206 | Dict with queue statistics.
207 | """
208 | return {
209 | "running": self._running,
210 | "queue_size": self.queue.qsize(),
211 | "max_queue_size": self.queue_size,
212 | "active_requests": len(self.active_requests),
213 | "max_concurrency": self.max_concurrency
214 | }
215 |
216 | # Alias for the async stop method to maintain consistency in cleanup interfaces
217 | async def stop_async(self):
218 | """Alias for stop - stops the queue worker asynchronously."""
219 | await self.stop()
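220 | 
221 | 
222 | if __name__ == "__main__":
223 |     # Illustrative usage sketch, not part of the original module: an echo
224 |     # processor stands in for real model inference so the queue can be
225 |     # exercised end to end (start -> submit -> stats -> stop).
226 |     async def _echo_processor(data):
227 |         await asyncio.sleep(0.1)  # pretend to do some work
228 |         return {"echo": data}
229 | 
230 |     async def _demo():
231 |         queue = RequestQueue(max_concurrency=2, timeout=30.0, queue_size=10)
232 |         await queue.start(_echo_processor)
233 |         result = await queue.submit("req-1", {"prompt": "hello"})
234 |         print(result)
235 |         print(queue.get_queue_stats())
236 |         await queue.stop()
237 | 
238 |     asyncio.run(_demo())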
--------------------------------------------------------------------------------
/app/handler/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | MLX model handlers for text and vision-language models.
3 | """
4 |
5 | from app.handler.mlx_lm import MLXLMHandler
6 | from app.handler.mlx_vlm import MLXVLMHandler
7 |
8 | __all__ = [
9 | "MLXLMHandler",
10 | "MLXVLMHandler"
11 | ]
12 |
--------------------------------------------------------------------------------
/app/handler/mlx_lm.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import time
3 | import uuid
4 | from http import HTTPStatus
5 | from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
6 |
7 | from fastapi import HTTPException
8 | from loguru import logger
9 |
10 | from app.core.queue import RequestQueue
11 | from app.handler.parser import get_parser
12 | from app.models.mlx_lm import MLX_LM
13 | from app.schemas.openai import ChatCompletionRequest, EmbeddingRequest
14 | from app.utils.errors import create_error_response
15 |
16 | class MLXLMHandler:
17 | """
18 | Handler class for making requests to the underlying MLX text-only language model service.
19 | Provides request queuing, metrics tracking, and robust error handling.
20 | """
21 |
22 | def __init__(self, model_path: str, max_concurrency: int = 1):
23 | """
24 | Initialize the handler with the specified model path.
25 |
26 | Args:
27 | model_path (str): Path to the model directory.
28 | max_concurrency (int): Maximum number of concurrent model inference tasks.
29 | """
30 | self.model_path = model_path
31 | self.model = MLX_LM(model_path)
32 | self.model_type = self.model.get_model_type()
33 | self.model_created = int(time.time()) # Store creation time when model is loaded
34 |
35 | # Initialize request queue for text tasks
36 | self.request_queue = RequestQueue(max_concurrency=max_concurrency)
37 |
38 |         logger.info(f"Initialized MLXLMHandler with model path: {model_path}")
39 |
40 | def get_models(self) -> List[Dict[str, Any]]:
41 | """
42 | Get list of available models with their metadata.
43 | """
44 | return [{
45 | "id": self.model_path,
46 | "object": "model",
47 | "created": self.model_created,
48 | "owned_by": "local"
49 | }]
50 |
51 | async def initialize(self, queue_config: Optional[Dict[str, Any]] = None):
52 | """Initialize the handler and start the request queue."""
53 | if not queue_config:
54 | queue_config = {
55 | "max_concurrency": 1,
56 | "timeout": 300,
57 | "queue_size": 100
58 | }
59 | self.request_queue = RequestQueue(
60 | max_concurrency=queue_config.get("max_concurrency"),
61 | timeout=queue_config.get("timeout"),
62 | queue_size=queue_config.get("queue_size")
63 | )
64 | await self.request_queue.start(self._process_request)
65 |         logger.info("Initialized MLXLMHandler and started request queue")
66 |
67 | async def generate_text_stream(self, request: ChatCompletionRequest) -> AsyncGenerator[str, None]:
68 | """
69 | Generate a streaming response for text-only chat completion requests.
70 | Uses the request queue for handling concurrent requests.
71 |
72 | Args:
73 | request: ChatCompletionRequest object containing the messages.
74 |
75 | Yields:
76 | str: Response chunks.
77 | """
78 | request_id = f"text-{uuid.uuid4()}"
79 |
80 | try:
81 | chat_messages, model_params = await self._prepare_text_request(request)
82 | request_data = {
83 | "messages": chat_messages,
84 | "stream": True,
85 | **model_params
86 | }
87 | response_generator = await self.request_queue.submit(request_id, request_data)
88 |
89 | tools = model_params.get("chat_template_kwargs", {}).get("tools", None)
90 | enable_thinking = model_params.get("chat_template_kwargs", {}).get("enable_thinking", None)
91 |
92 | tool_parser, thinking_parser = get_parser(self.model_type)
93 | if enable_thinking and thinking_parser:
94 | for chunk in response_generator:
95 | if chunk:
96 | chunk, is_finish = thinking_parser.parse_stream(chunk.text)
97 | if chunk:
98 | yield chunk
99 | if is_finish:
100 | break
101 |
102 | if tools and tool_parser:
103 | for chunk in response_generator:
104 | if chunk:
105 | chunk = tool_parser.parse_stream(chunk.text)
106 | if chunk:
107 | yield chunk
108 | else:
109 | for chunk in response_generator:
110 | if chunk:
111 | yield chunk.text
112 |
113 | except asyncio.QueueFull:
114 | logger.error("Too many requests. Service is at capacity.")
115 | content = create_error_response("Too many requests. Service is at capacity.", "rate_limit_exceeded", HTTPStatus.TOO_MANY_REQUESTS)
116 | raise HTTPException(status_code=429, detail=content)
117 | except Exception as e:
118 | logger.error(f"Error in text stream generation for request {request_id}: {str(e)}")
119 | content = create_error_response(f"Failed to generate text stream: {str(e)}", "server_error", HTTPStatus.INTERNAL_SERVER_ERROR)
120 | raise HTTPException(status_code=500, detail=content)
121 |
122 |     async def generate_text_response(self, request: ChatCompletionRequest) -> Union[str, Dict[str, Any]]:
123 | """
124 | Generate a complete response for text-only chat completion requests.
125 | Uses the request queue for handling concurrent requests.
126 |
127 | Args:
128 | request: ChatCompletionRequest object containing the messages.
129 |
130 | Returns:
131 |             Union[str, Dict[str, Any]]: Plain completion text, or a dict with 'reasoning_content', 'tool_calls', and 'content' when parsing applies.
132 | """
133 | request_id = f"text-{uuid.uuid4()}"
134 |
135 | try:
136 | chat_messages, model_params = await self._prepare_text_request(request)
137 | request_data = {
138 | "messages": chat_messages,
139 | "stream": False,
140 | **model_params
141 | }
142 | response = await self.request_queue.submit(request_id, request_data)
143 | tools = model_params.get("chat_template_kwargs", {}).get("tools", None)
144 | enable_thinking = model_params.get("chat_template_kwargs", {}).get("enable_thinking", None)
145 | if not tools and not enable_thinking:
146 | return response
147 |
148 | tool_parser, thinking_parser = get_parser(self.model_type)
149 | if not tool_parser and not thinking_parser:
150 | return response
151 | parsed_response = {
152 | "reasoning_content": None,
153 | "tool_calls": None,
154 | "content": None
155 | }
156 | if enable_thinking and thinking_parser:
157 | thinking_response, response = thinking_parser.parse(response)
158 | parsed_response["reasoning_content"] = thinking_response
159 | if tools and tool_parser:
160 | tool_response, response = tool_parser.parse(response)
161 | parsed_response["tool_calls"] = tool_response
162 | parsed_response["content"] = response
163 |
164 | return parsed_response
165 |
166 | except asyncio.QueueFull:
167 | logger.error("Too many requests. Service is at capacity.")
168 | content = create_error_response("Too many requests. Service is at capacity.", "rate_limit_exceeded", HTTPStatus.TOO_MANY_REQUESTS)
169 | raise HTTPException(status_code=429, detail=content)
170 | except Exception as e:
171 | logger.error(f"Error in text response generation: {str(e)}")
172 | content = create_error_response(f"Failed to generate text response: {str(e)}", "server_error", HTTPStatus.INTERNAL_SERVER_ERROR)
173 | raise HTTPException(status_code=500, detail=content)
174 |
175 | async def generate_embeddings_response(self, request: EmbeddingRequest):
176 | """
177 | Generate embeddings for a given text input.
178 |
179 | Args:
180 | request: EmbeddingRequest object containing the text input.
181 |
182 | Returns:
183 | List[float]: Embeddings for the input text.
184 | """
185 | try:
186 | # Create a unique request ID
187 | request_id = f"embeddings-{uuid.uuid4()}"
188 | request_data = {
189 | "type": "embeddings",
190 | "input": request.input,
191 | "model": request.model
192 | }
193 |
194 | # Submit to the request queue
195 | response = await self.request_queue.submit(request_id, request_data)
196 |
197 | return response
198 |
199 | except Exception as e:
200 | logger.error(f"Error in embeddings generation: {str(e)}")
201 | content = create_error_response(f"Failed to generate embeddings: {str(e)}", "server_error", HTTPStatus.INTERNAL_SERVER_ERROR)
202 | raise HTTPException(status_code=500, detail=content)
203 |
204 |
205 |     async def _process_request(self, request_data: Dict[str, Any]) -> Any:
206 | """
207 | Process a text request. This is the worker function for the request queue.
208 |
209 | Args:
210 | request_data: Dictionary containing the request data.
211 |
212 | Returns:
213 |             The model's response: a string, a stream generator, or an embeddings list.
214 | """
215 | try:
216 | # Check if the request is for embeddings
217 | if request_data.get("type") == "embeddings":
218 | return self.model.get_embeddings(request_data["input"])
219 |
220 | # Extract request parameters
221 | messages = request_data.get("messages", [])
222 | stream = request_data.get("stream", False)
223 |
224 | # Remove these keys from model_params
225 | model_params = request_data.copy()
226 | model_params.pop("messages", None)
227 | model_params.pop("stream", None)
228 |
229 | # Call the model
230 | response = self.model(
231 | messages=messages,
232 | stream=stream,
233 | **model_params
234 | )
235 |
236 | return response
237 |
238 | except Exception as e:
239 | logger.error(f"Error processing text request: {str(e)}")
240 | raise
241 |
242 | async def get_queue_stats(self) -> Dict[str, Any]:
243 | """
244 | Get statistics from the request queue and performance metrics.
245 |
246 | Returns:
247 | Dict with queue and performance statistics.
248 | """
249 | queue_stats = self.request_queue.get_queue_stats()
250 |
251 | return {
252 | "queue_stats": queue_stats,
253 | }
254 |
255 | async def cleanup(self):
256 | """
257 | Cleanup resources and stop the request queue before shutdown.
258 |
259 | This method ensures all pending requests are properly cancelled
260 | and resources are released.
261 | """
262 | try:
263 | logger.info("Cleaning up MLXLMHandler resources")
264 | if hasattr(self, 'request_queue'):
265 | await self.request_queue.stop()
266 | logger.info("MLXLMHandler cleanup completed successfully")
267 | except Exception as e:
268 | logger.error(f"Error during MLXLMHandler cleanup: {str(e)}")
269 | raise
270 |
271 | async def _prepare_text_request(self, request: ChatCompletionRequest) -> Tuple[List[Dict[str, str]], Dict[str, Any]]:
272 | """
273 | Prepare a text request by parsing model parameters and verifying the format of messages.
274 |
275 | Args:
276 | request: ChatCompletionRequest object containing the messages.
277 |
278 | Returns:
279 | Tuple containing the formatted chat messages and model parameters.
280 | """
281 | try:
282 | # Get model parameters from the request
283 | temperature = request.temperature or 0.7
284 | top_p = request.top_p or 1.0
285 | frequency_penalty = request.frequency_penalty or 0.0
286 | presence_penalty = request.presence_penalty or 0.0
287 | max_tokens = request.max_tokens or 1024
288 | tools = request.tools or None
289 | chat_template_kwargs = request.chat_template_kwargs
290 | if tools:
291 | chat_template_kwargs.tools = tools
292 |
293 | model_params = {
294 | "temperature": temperature,
295 | "top_p": top_p,
296 | "frequency_penalty": frequency_penalty,
297 | "presence_penalty": presence_penalty,
298 | "max_tokens": max_tokens,
299 | "tools": tools,
300 | "chat_template_kwargs": chat_template_kwargs.model_dump()
301 | }
302 |
303 | # Format chat messages
304 | chat_messages = []
305 | for message in request.messages:
306 | chat_messages.append({
307 | "role": message.role,
308 | "content": message.content
309 | })
310 |
311 | return chat_messages, model_params
312 |
313 | except Exception as e:
314 | logger.error(f"Failed to prepare text request: {str(e)}")
315 | content = create_error_response(f"Failed to process request: {str(e)}", "bad_request", HTTPStatus.BAD_REQUEST)
316 | raise HTTPException(status_code=400, detail=content)
317 |
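318 | 
319 | if __name__ == "__main__":
320 |     # Illustrative sketch, not part of the original module: exercises the handler
321 |     # lifecycle defined above. "path/to/model" is a placeholder for a local MLX
322 |     # checkpoint or Hugging Face repo id; running it requires real weights.
323 |     async def _demo():
324 |         handler = MLXLMHandler(model_path="path/to/model", max_concurrency=1)
325 |         await handler.initialize({"max_concurrency": 1, "timeout": 300, "queue_size": 100})
326 |         print(handler.get_models())
327 |         print(await handler.get_queue_stats())
328 |         await handler.cleanup()
329 | 
330 |     asyncio.run(_demo())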
--------------------------------------------------------------------------------
/app/handler/mlx_vlm.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import base64
3 | import time
4 | import uuid
5 | from typing import Any, Dict, List, Optional, Tuple
6 | from http import HTTPStatus
7 |
8 | from fastapi import HTTPException
9 | from loguru import logger
10 |
11 | from app.core.image_processor import ImageProcessor
12 | from app.core.queue import RequestQueue
13 | from app.models.mlx_vlm import MLX_VLM
14 | from app.schemas.openai import ChatCompletionRequest, EmbeddingRequest
15 | from app.utils.errors import create_error_response
16 |
17 | class MLXVLMHandler:
18 | """
19 | Handler class for making requests to the underlying MLX vision-language model service.
20 | Provides caching, concurrent image processing, and robust error handling.
21 | """
22 |
23 | def __init__(self, model_path: str, max_workers: int = 4, max_concurrency: int = 1):
24 | """
25 | Initialize the handler with the specified model path.
26 |
27 | Args:
28 | model_path (str): Path to the model directory.
29 | max_workers (int): Maximum number of worker threads for image processing.
30 | max_concurrency (int): Maximum number of concurrent model inference tasks.
31 | """
32 | self.model_path = model_path
33 | self.model = MLX_VLM(model_path)
34 | self.image_processor = ImageProcessor(max_workers)
35 | self.model_created = int(time.time()) # Store creation time when model is loaded
36 |
37 | # Initialize request queue for vision and text tasks
38 | # We use the same queue for both vision and text tasks for simplicity
39 | # and to ensure we don't overload the model with too many concurrent requests
40 | self.request_queue = RequestQueue(max_concurrency=max_concurrency)
41 |
42 |         logger.info(f"Initialized MLXVLMHandler with model path: {model_path}")
43 |
44 | def get_models(self) -> List[Dict[str, Any]]:
45 | """
46 | Get list of available models with their metadata.
47 | """
48 | return [{
49 | "id": self.model_path,
50 | "object": "model",
51 | "created": self.model_created,
52 | "owned_by": "local"
53 | }]
54 |
55 | async def initialize(self, queue_config: Optional[Dict[str, Any]] = None):
56 | """Initialize the handler and start the request queue."""
57 |
58 | if not queue_config:
59 | queue_config = {
60 | "max_concurrency": 1,
61 | "timeout": 300,
62 | "queue_size": 100
63 | }
64 | self.request_queue = RequestQueue(
65 | max_concurrency=queue_config.get("max_concurrency"),
66 | timeout=queue_config.get("timeout"),
67 | queue_size=queue_config.get("queue_size")
68 | )
69 | await self.request_queue.start(self._process_request)
70 |         logger.info("Initialized MLXVLMHandler and started request queue")
71 |
72 | async def generate_vision_stream(self, request: ChatCompletionRequest):
73 | """
74 | Generate a streaming response for vision-based chat completion requests.
75 |
76 | Args:
77 | request: ChatCompletionRequest object containing the messages.
78 |
79 | Returns:
80 | AsyncGenerator: Yields response chunks.
81 | """
82 |
83 | # Create a unique request ID
84 | request_id = f"vision-{uuid.uuid4()}"
85 |
86 | try:
87 | chat_messages, image_paths, model_params = await self._prepare_vision_request(request)
88 |
89 | # Create a request data object
90 | request_data = {
91 | "images": image_paths,
92 | "messages": chat_messages,
93 | "stream": True,
94 | **model_params
95 | }
96 |
97 | # Submit to the vision queue and get the generator
98 | response_generator = await self.request_queue.submit(request_id, request_data)
99 |
100 | # Process and yield each chunk asynchronously
101 | for chunk in response_generator:
102 | if chunk:
103 | chunk = chunk.text
104 | yield chunk
105 |
106 | except asyncio.QueueFull:
107 | logger.error("Too many requests. Service is at capacity.")
108 | content = create_error_response("Too many requests. Service is at capacity.", "rate_limit_exceeded", HTTPStatus.TOO_MANY_REQUESTS)
109 | raise HTTPException(status_code=429, detail=content)
110 |
111 | except Exception as e:
112 | logger.error(f"Error in vision stream generation for request {request_id}: {str(e)}")
113 | content = create_error_response(f"Failed to generate vision stream: {str(e)}", "server_error", HTTPStatus.INTERNAL_SERVER_ERROR)
114 | raise HTTPException(status_code=500, detail=content)
115 |
116 | async def generate_vision_response(self, request: ChatCompletionRequest):
117 | """
118 | Generate a complete response for vision-based chat completion requests.
119 | Uses the request queue for handling concurrent requests.
120 |
121 | Args:
122 | request: ChatCompletionRequest object containing the messages.
123 |
124 | Returns:
125 | str: Complete response.
126 | """
127 | try:
128 | # Create a unique request ID
129 | request_id = f"vision-{uuid.uuid4()}"
130 |
131 | # Prepare the vision request
132 | chat_messages, image_paths, model_params = await self._prepare_vision_request(request)
133 |
134 | # Create a request data object
135 | request_data = {
136 | "images": image_paths,
137 | "messages": chat_messages,
138 | "stream": False,
139 | **model_params
140 | }
141 |
142 | response = await self.request_queue.submit(request_id, request_data)
143 |
144 | return response
145 |
146 | except asyncio.QueueFull:
147 | logger.error("Too many requests. Service is at capacity.")
148 | content = create_error_response("Too many requests. Service is at capacity.", "rate_limit_exceeded", HTTPStatus.TOO_MANY_REQUESTS)
149 | raise HTTPException(status_code=429, detail=content)
150 | except Exception as e:
151 | logger.error(f"Error in vision response generation: {str(e)}")
152 | content = create_error_response(f"Failed to generate vision response: {str(e)}", "server_error", HTTPStatus.INTERNAL_SERVER_ERROR)
153 | raise HTTPException(status_code=500, detail=content)
154 |
155 | async def generate_text_stream(self, request: ChatCompletionRequest):
156 | """
157 | Generate a streaming response for text-only chat completion requests.
158 | Uses the request queue for handling concurrent requests.
159 |
160 | Args:
161 | request: ChatCompletionRequest object containing the messages.
162 |
163 | Returns:
164 | AsyncGenerator: Yields response chunks.
165 | """
166 | request_id = f"text-{uuid.uuid4()}"
167 |
168 | try:
169 | chat_messages, model_params = await self._prepare_text_request(request)
170 | request_data = {
171 | "messages": chat_messages,
172 | "stream": True,
173 | **model_params
174 | }
175 | response_generator = await self.request_queue.submit(request_id, request_data)
176 |
177 | for chunk in response_generator:
178 | if chunk:
179 | yield chunk.text
180 |
181 | except asyncio.QueueFull:
182 | logger.error("Too many requests. Service is at capacity.")
183 | content = create_error_response("Too many requests. Service is at capacity.", "rate_limit_exceeded", HTTPStatus.TOO_MANY_REQUESTS)
184 | raise HTTPException(status_code=429, detail=content)
185 | except Exception as e:
186 | logger.error(f"Error in text stream generation for request {request_id}: {str(e)}")
187 | content = create_error_response(f"Failed to generate text stream: {str(e)}", "server_error", HTTPStatus.INTERNAL_SERVER_ERROR)
188 | raise HTTPException(status_code=500, detail=content)
189 |
190 | async def generate_text_response(self, request: ChatCompletionRequest):
191 | """
192 | Generate a complete response for text-only chat completion requests.
193 | Uses the request queue for handling concurrent requests.
194 |
195 | Args:
196 | request: ChatCompletionRequest object containing the messages.
197 |
198 | Returns:
199 | str: Complete response.
200 | """
201 | try:
202 | # Create a unique request ID
203 | request_id = f"text-{uuid.uuid4()}"
204 |
205 | # Prepare the text request
206 | chat_messages, model_params = await self._prepare_text_request(request)
207 |
208 | # Create a request data object
209 | request_data = {
210 | "messages": chat_messages,
211 | "stream": False,
212 | **model_params
213 | }
214 |
215 | # Submit to the vision queue (reusing the same queue for text requests)
216 | response = await self.request_queue.submit(request_id, request_data)
217 | return response
218 |
219 | except asyncio.QueueFull:
220 | logger.error("Too many requests. Service is at capacity.")
221 | content = create_error_response("Too many requests. Service is at capacity.", "rate_limit_exceeded", HTTPStatus.TOO_MANY_REQUESTS)
222 | raise HTTPException(status_code=429, detail=content)
223 | except Exception as e:
224 | logger.error(f"Error in text response generation: {str(e)}")
225 | content = create_error_response(f"Failed to generate text response: {str(e)}", "server_error", HTTPStatus.INTERNAL_SERVER_ERROR)
226 | raise HTTPException(status_code=500, detail=content)
227 |
228 | async def generate_embeddings_response(self, request: EmbeddingRequest):
229 | """
230 | Generate embeddings for a given text input.
231 |
232 | Args:
233 | request: EmbeddingRequest object containing the text input.
234 |
235 | Returns:
236 | List[float]: Embeddings for the input text or images
237 | """
238 | try:
239 | # Create a unique request ID
240 | image_url = request.image_url
241 | # Process the image URL to get a local file path
242 | images = []
243 | if request.image_url:
244 | image_path = await self.image_processor.process_image_url(image_url)
245 | images.append(image_path)
246 | request_id = f"embeddings-{uuid.uuid4()}"
247 | request_data = {
248 | "type": "embeddings",
249 | "input": request.input,
250 | "model": request.model,
251 | "images": images
252 | }
253 |
254 | # Submit to the request queue
255 | response = await self.request_queue.submit(request_id, request_data)
256 |
257 | return response
258 |
259 | except Exception as e:
260 | logger.error(f"Error in embeddings generation: {str(e)}")
261 | content = create_error_response(f"Failed to generate embeddings: {str(e)}", "server_error", HTTPStatus.INTERNAL_SERVER_ERROR)
262 | raise HTTPException(status_code=500, detail=content)
263 |
264 |
265 | def __del__(self):
266 | """Cleanup resources on deletion."""
267 | # Removed async cleanup from __del__; use close() instead
268 | pass
269 |
270 | async def close(self):
271 | """Explicitly cleanup resources asynchronously."""
272 | if hasattr(self, 'image_processor'):
273 | await self.image_processor.cleanup()
274 |
275 | async def cleanup(self):
276 | """
277 | Cleanup resources and stop the request queue before shutdown.
278 |
279 | This method ensures all pending requests are properly cancelled
280 | and resources are released, including the image processor.
281 | """
282 | try:
283 | logger.info("Cleaning up MLXVLMHandler resources")
284 | if hasattr(self, 'request_queue'):
285 | await self.request_queue.stop()
286 | if hasattr(self, 'image_processor'):
287 | await self.image_processor.cleanup()
288 | logger.info("MLXVLMHandler cleanup completed successfully")
289 | except Exception as e:
290 | logger.error(f"Error during MLXVLMHandler cleanup: {str(e)}")
291 | raise
292 |
293 |     async def _process_request(self, request_data: Dict[str, Any]) -> Any:
294 | """
295 | Process a vision request. This is the worker function for the request queue.
296 |
297 | Args:
298 | request_data: Dictionary containing the request data.
299 |
300 | Returns:
301 |             The model's response: a string, a stream generator, or an embeddings list.
302 | """
303 | try:
304 | # Check if the request is for embeddings
305 | if request_data.get("type") == "embeddings":
306 | return self.model.get_embeddings(request_data["input"], request_data["images"])
307 |
308 | # Extract request parameters
309 | images = request_data.get("images", [])
310 | messages = request_data.get("messages", [])
311 | stream = request_data.get("stream", False)
312 |
313 | # Remove these keys from model_params
314 | model_params = request_data.copy()
315 | model_params.pop("images", None)
316 | model_params.pop("messages", None)
317 | model_params.pop("stream", None)
318 |
319 | # Start timing
320 | start_time = time.time()
321 |
322 | # Call the model
323 | response = self.model(
324 | images=images,
325 | messages=messages,
326 | stream=stream,
327 | **model_params
328 | )
329 |
330 | return response
331 |
332 | except Exception as e:
333 | logger.error(f"Error processing vision request: {str(e)}")
334 | raise
335 |
336 | async def get_queue_stats(self) -> Dict[str, Any]:
337 | """
338 | Get statistics from the request queue and performance metrics.
339 |
340 | Returns:
341 | Dict with queue and performance statistics.
342 | """
343 | queue_stats = self.request_queue.get_queue_stats()
344 |
345 | return {
346 | "queue_stats": queue_stats,
347 | }
348 |
349 | async def _prepare_text_request(self, request: ChatCompletionRequest) -> Tuple[List[Dict[str, str]], Dict[str, Any]]:
350 | """
351 | Prepare a text request by parsing model parameters and verifying the format of messages.
352 |
353 | Args:
354 | request: ChatCompletionRequest object containing the messages.
355 |
356 | Returns:
357 | Tuple containing the formatted chat messages and model parameters.
358 | """
359 | chat_messages = []
360 |
361 | try:
362 |
363 | # Convert Message objects to dictionaries with 'role' and 'content' keys
364 | chat_messages = []
365 | for message in request.messages:
366 | # Only handle simple string content for text-only requests
367 | if not isinstance(message.content, str):
368 | logger.warning(f"Non-string content in text request will be skipped: {message.role}")
369 | continue
370 |
371 | chat_messages.append({
372 | "role": message.role,
373 | "content": message.content
374 | })
375 |
376 | # Extract model parameters, filtering out None values
377 | model_params = {
378 | k: v for k, v in {
379 | "max_tokens": request.max_tokens,
380 | "temperature": request.temperature,
381 | "top_p": request.top_p,
382 | "frequency_penalty": request.frequency_penalty,
383 | "presence_penalty": request.presence_penalty,
384 | "stop": request.stop,
385 | "n": request.n,
386 | "seed": request.seed
387 | }.items() if v is not None
388 | }
389 |
390 | # Handle response format
391 | if request.response_format and request.response_format.get("type") == "json_object":
392 | model_params["response_format"] = "json"
393 |
394 | # Handle tools and tool choice
395 | if request.tools:
396 | model_params["tools"] = request.tools
397 | if request.tool_choice:
398 | model_params["tool_choice"] = request.tool_choice
399 |
400 | # Log processed data
401 | logger.debug(f"Processed text chat messages: {chat_messages}")
402 | logger.debug(f"Model parameters: {model_params}")
403 |
404 | return chat_messages, model_params
405 |
406 | except HTTPException:
407 | raise
408 | except Exception as e:
409 | logger.error(f"Failed to prepare text request: {str(e)}")
410 | content = create_error_response(f"Failed to process request: {str(e)}", "bad_request", HTTPStatus.BAD_REQUEST)
411 | raise HTTPException(status_code=400, detail=content)
412 |
413 | async def _prepare_vision_request(self, request: ChatCompletionRequest) -> Tuple[List[Dict[str, Any]], List[str], Dict[str, Any]]:
414 | """
415 | Prepare the vision request by processing messages and images.
416 |
417 | This method:
418 | 1. Extracts text messages and image URLs from the request
419 | 2. Processes image URLs to get local file paths
420 | 3. Prepares model parameters
421 | 4. Returns processed data ready for model inference
422 |
423 | Args:
424 | request (ChatCompletionRequest): The incoming request containing messages and parameters.
425 |
426 | Returns:
427 | Tuple[List[Dict[str, Any]], List[str], Dict[str, Any]]: A tuple containing:
428 | - List of processed chat messages
429 | - List of processed image paths
430 | - Dictionary of model parameters
431 | """
432 | chat_messages = []
433 | image_urls = []
434 |
435 | try:
436 | # Process each message in the request
437 | for message in request.messages:
438 | # Handle system and assistant messages (simple text content)
439 | if message.role in ["system", "assistant"]:
440 | chat_messages.append({"role": message.role, "content": message.content})
441 | continue
442 |
443 | # Handle user messages
444 | if message.role == "user":
445 | # Case 1: Simple string content
446 | if isinstance(message.content, str):
447 | chat_messages.append({"role": "user", "content": message.content})
448 | continue
449 |
450 | # Case 2: Content is a list of dictionaries or objects
451 | if isinstance(message.content, list):
452 | # Initialize containers for this message
453 | texts = []
454 | images = []
455 | # Process each content item in the list
456 | for item in message.content:
457 | if item.type == "text":
458 | text = getattr(item, "text", "").strip()
459 | if text:
460 | texts.append(text)
461 |
462 | elif item.type == "image_url":
463 | url = getattr(item, "image_url", None)
464 | if url and hasattr(url, "url"):
465 | url = url.url
466 | # Validate URL
467 | self._validate_image_url(url)
468 | images.append(url)
469 |
470 | # Add collected images to global list
471 | if images:
472 | image_urls.extend(images)
473 |
474 | # Validate constraints
475 | if len(images) > 4:
476 | content = create_error_response("Too many images in a single message (max: 4)", "invalid_request_error", HTTPStatus.BAD_REQUEST)
477 | raise HTTPException(status_code=400, detail=content)
478 |
479 |                         # Add text content if available, otherwise fall back to an empty user message
480 | if texts:
481 | chat_messages.append({"role": "user", "content": " ".join(texts)})
482 | else:
483 | chat_messages.append({"role": "user", "content": ""})
484 | else:
485 | content = create_error_response("Invalid message content format", "invalid_request_error", HTTPStatus.BAD_REQUEST)
486 | raise HTTPException(status_code=400, detail=content)
487 |
488 | # Process images and prepare model parameters
489 | image_paths = await self.image_processor.process_image_urls(image_urls)
490 |
491 |
492 | # Get model parameters from the request
493 | temperature = request.temperature or 0.7
494 | top_p = request.top_p or 1.0
495 | frequency_penalty = request.frequency_penalty or 0.0
496 | presence_penalty = request.presence_penalty or 0.0
497 | max_tokens = request.max_tokens or 1024
498 | tools = request.tools or None
499 | tool_choice = request.tool_choice or None
500 |
501 | model_params = {
502 | "temperature": temperature,
503 | "top_p": top_p,
504 | "frequency_penalty": frequency_penalty,
505 | "presence_penalty": presence_penalty,
506 | "max_tokens": max_tokens,
507 | "tools": tools,
508 | "tool_choice": tool_choice
509 | }
510 |
511 | # Log processed data at debug level
512 | logger.debug(f"Processed chat messages: {chat_messages}")
513 | logger.debug(f"Processed image paths: {image_paths}")
514 | logger.debug(f"Model parameters: {model_params}")
515 |
516 | return chat_messages, image_paths, model_params
517 |
518 | except HTTPException:
519 | raise
520 | except Exception as e:
521 | logger.error(f"Failed to prepare vision request: {str(e)}")
522 | content = create_error_response(f"Failed to process request: {str(e)}", "bad_request", HTTPStatus.BAD_REQUEST)
523 | raise HTTPException(status_code=400, detail=content)
524 |
525 | def _validate_image_url(self, url: str) -> None:
526 | """
527 | Validate image URL format.
528 |
529 | Args:
530 | url: The image URL to validate
531 |
532 | Raises:
533 | HTTPException: If URL is invalid
534 | """
535 | if not url:
536 | content = create_error_response("Empty image URL provided", "invalid_request_error", HTTPStatus.BAD_REQUEST)
537 | raise HTTPException(status_code=400, detail=content)
538 |
539 | # Validate base64 images
540 | if url.startswith("data:"):
541 | try:
542 | header, encoded = url.split(",", 1)
543 | if not header.startswith("data:image/"):
544 | raise ValueError("Invalid image format")
545 | base64.b64decode(encoded)
546 | except Exception as e:
547 | content = create_error_response(f"Invalid base64 image: {str(e)}", "invalid_request_error", HTTPStatus.BAD_REQUEST)
548 | raise HTTPException(status_code=400, detail=content)
549 |
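550 | 
551 | if __name__ == "__main__":
552 |     # Illustrative sketch, not part of the original module. "path/to/vlm" is a
553 |     # placeholder model path. A vision chat message is expected in the OpenAI
554 |     # content-parts shape handled by _prepare_vision_request above, e.g.:
555 |     #   {"role": "user", "content": [
556 |     #       {"type": "text", "text": "Describe this image"},
557 |     #       {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}}]}
558 |     async def _demo():
559 |         handler = MLXVLMHandler(model_path="path/to/vlm", max_concurrency=1)
560 |         await handler.initialize()
561 |         print(handler.get_models())
562 |         await handler.cleanup()
563 | 
564 |     asyncio.run(_demo())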
--------------------------------------------------------------------------------
/app/handler/parser/__init__.py:
--------------------------------------------------------------------------------
1 | from app.handler.parser.base import BaseToolParser, BaseThinkingParser
2 | from app.handler.parser.qwen3 import Qwen3ToolParser, Qwen3ThinkingParser
3 | from typing import Optional, Tuple
4 | __all__ = ['BaseToolParser', 'BaseThinkingParser', 'Qwen3ToolParser', 'Qwen3ThinkingParser']
5 |
6 | parser_map = {
7 | 'qwen3': {
8 | "tool_parser": Qwen3ToolParser,
9 | "thinking_parser": Qwen3ThinkingParser
10 | }
11 | }
12 |
13 | def get_parser(model_name: str) -> Tuple[Optional[BaseToolParser], Optional[BaseThinkingParser]]:
14 | if model_name not in parser_map:
15 | return None, None
16 |
17 | model_parsers = parser_map[model_name]
18 | tool_parser = model_parsers.get("tool_parser")
19 | thinking_parser = model_parsers.get("thinking_parser")
20 |
21 | return (tool_parser() if tool_parser else None,
22 | thinking_parser() if thinking_parser else None)
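23 | 
24 | 
25 | if __name__ == "__main__":
26 |     # Minimal sketch, not part of the original module: resolve parsers for the
27 |     # registered "qwen3" model type and for an unregistered name.
28 |     tool_parser, thinking_parser = get_parser("qwen3")
29 |     print(type(tool_parser).__name__, type(thinking_parser).__name__)
30 |     print(get_parser("unknown-model"))  # -> (None, None)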
--------------------------------------------------------------------------------
/app/handler/parser/base.py:
--------------------------------------------------------------------------------
1 | import json
2 | import uuid
3 | from typing import Any, Dict, List, Optional, Tuple
4 |
5 |
6 | class BaseThinkingParser:
7 | def __init__(self, thinking_open: str, thinking_close: str):
8 | self.thinking_open = thinking_open
9 | self.thinking_close = thinking_close
10 | self.is_thinking = False
11 |
12 |     def parse(self, content: str) -> Tuple[Optional[str], str]:
13 | if self.thinking_open in content:
14 | start_thinking = content.find(self.thinking_open)
15 | end_thinking = content.find(self.thinking_close)
16 | if end_thinking != -1:
17 | return content[start_thinking + len(self.thinking_open):end_thinking].strip(), content[end_thinking + len(self.thinking_close):].strip()
18 | return None, content
19 |
20 | def parse_stream(self, chunk: str) -> Tuple[str, bool]:
21 | if not self.is_thinking:
22 | if chunk == self.thinking_open:
23 | self.is_thinking = True
24 | return None, False
25 | return chunk, False
26 | if chunk == self.thinking_close:
27 | self.is_thinking = False
28 | return None, True
29 |
30 | return {
31 | "reasoning_content": chunk
32 | }, False
33 |
34 | class ParseState:
35 | NORMAL = 0
36 | FOUND_PREFIX = 1
37 | FOUND_FUNC_NAME = 2
38 | FOUND_FUNC_ARGS = 3
39 | PROCESS_FUNC_ARGS = 4
40 | @staticmethod
41 | def next_state(state):
42 | return (state + 1) % 5
43 |
44 | class BaseToolParser:
45 | def __init__(self, tool_open: str, tool_close: str):
46 | self.tool_open = tool_open
47 | self.tool_close = tool_close
48 | self.buffer = ""
49 | self.state = ParseState.NORMAL
50 |
51 | def get_tool_open(self):
52 | return self.tool_open
53 |
54 | def get_tool_close(self):
55 | return self.tool_close
56 |
57 | def parse(self, content: str) -> Tuple[List[Dict[str, Any]], str]:
58 | res = []
59 | start = 0
60 | while True:
61 | start_tool = content.find(self.tool_open, start)
62 | if start_tool == -1:
63 | break
64 | end_tool = content.find(self.tool_close, start_tool + len(self.tool_open))
65 | if end_tool == -1:
66 | break
67 | tool_content = content[start_tool + len(self.tool_open):end_tool].strip()
68 |
69 | try:
70 | json_output = json.loads(tool_content)
71 | res.append(json_output)
72 | except json.JSONDecodeError:
73 | print("Error parsing tool call: ", tool_content)
74 | break
75 | start = end_tool + len(self.tool_close)
76 | return res, content[start:].strip()
77 |
78 | def parse_stream(self, chunk: str):
79 | if self.state == ParseState.NORMAL:
80 | if chunk.strip() == self.tool_open:
81 | self.state = ParseState.next_state(self.state)
82 | self.buffer = ""
83 | self.current_func = None
84 | return None
85 | return chunk
86 |
87 | if self.state == ParseState.FOUND_PREFIX:
88 | self.buffer += chunk
89 | # Try to parse function name
90 | if self.buffer.count('"') >= 4:
91 | # try parse json
92 | try:
93 | json_output = json.loads(self.buffer.rstrip(',') + "}")
94 | self.current_func = {
95 | "name": None
96 | }
97 | self.state = ParseState.next_state(self.state)
98 | return {
99 | "name": json_output["name"],
100 | "arguments": ""
101 | }
102 | except json.JSONDecodeError:
103 | return None
104 | return None
105 |
106 | if self.state == ParseState.FOUND_FUNC_NAME:
107 | # Try to parse function arguments
108 | if chunk.strip() == "arguments":
109 | self.state = ParseState.next_state(self.state)
110 | return None
111 | return None
112 |
113 | if self.state == ParseState.FOUND_FUNC_ARGS:
114 | if ":" in chunk:
115 |                 chunk = chunk[chunk.find(":") + 1:].lstrip()  # keep the text after ":" as the first argument fragment
116 |                 self.state = ParseState.next_state(self.state)
117 |                 if chunk:
118 |                     return {"name": None, "arguments": chunk}
119 |             return None
120 |
121 | if '}\n' in chunk:
122 | chunk = chunk[:chunk.find('}\n')]
123 |
124 | if chunk == self.tool_close:
125 | # end of arguments
126 | # reset
127 | self.state = ParseState.NORMAL
128 | self.buffer = ""
129 | self.current_func = None
130 | return None
131 |
132 | return {
133 | "name": None,
134 | "arguments": chunk
135 | }
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
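148 | 
149 | if __name__ == "__main__":
150 |     # Minimal sketch, not part of the original module: non-streaming parsing with
151 |     # example <tool_call>/<think> delimiters (concrete subclasses such as the
152 |     # Qwen3 parsers supply the real delimiter strings).
153 |     tool_parser = BaseToolParser("<tool_call>", "</tool_call>")
154 |     text = 'Sure.<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'
155 |     calls, remainder = tool_parser.parse(text)
156 |     print(calls)      # [{'name': 'get_weather', 'arguments': {'city': 'Paris'}}]
157 |     print(remainder)  # ''
158 | 
159 |     thinking_parser = BaseThinkingParser("<think>", "</think>")
160 |     reasoning, answer = thinking_parser.parse("<think>check the units first</think>42")
161 |     print(reasoning, answer)  # check the units first 42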
--------------------------------------------------------------------------------
/app/handler/parser/qwen3.py:
--------------------------------------------------------------------------------
1 | from app.handler.parser.base import BaseToolParser, BaseThinkingParser
2 |
3 | TOOL_OPEN = "<tool_call>"
4 | TOOL_CLOSE = "</tool_call>"
5 | THINKING_OPEN = "<think>"
6 | THINKING_CLOSE = "</think>"
7 |
8 | class Qwen3ToolParser(BaseToolParser):
9 | """Parser for Qwen3 model's tool response format."""
10 |
11 | def __init__(self):
12 | super().__init__(
13 | tool_open=TOOL_OPEN,
14 | tool_close=TOOL_CLOSE
15 | )
16 |
17 | class Qwen3ThinkingParser(BaseThinkingParser):
18 | """Parser for Qwen3 model's thinking response format."""
19 |
20 | def __init__(self):
21 | super().__init__(
22 | thinking_open=THINKING_OPEN,
23 | thinking_close=THINKING_CLOSE
24 | )
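25 | 
26 | 
27 | if __name__ == "__main__":
28 |     # Minimal sketch, not part of the original module, showing the raw Qwen3-style
29 |     # output these parsers expect (the delimiter strings above follow the
30 |     # conventional Qwen3 <think> / <tool_call> format).
31 |     raw = ('<think>the user wants the weather</think>'
32 |            '<tool_call>{"name": "get_weather", "arguments": {"city": "Hanoi"}}</tool_call>')
33 |     reasoning, rest = Qwen3ThinkingParser().parse(raw)
34 |     tool_calls, remainder = Qwen3ToolParser().parse(rest)
35 |     print(reasoning)    # the user wants the weather
36 |     print(tool_calls)   # [{'name': 'get_weather', 'arguments': {'city': 'Hanoi'}}]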
--------------------------------------------------------------------------------
/app/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import asyncio
3 | import gc
4 | import time
5 | from contextlib import asynccontextmanager
6 |
7 | import uvicorn
8 | from fastapi import FastAPI, Request
9 | from fastapi.middleware.cors import CORSMiddleware
10 | from fastapi.responses import JSONResponse
11 | from loguru import logger
12 |
13 | from app.handler.mlx_vlm import MLXVLMHandler
14 | from app.handler.mlx_lm import MLXLMHandler
15 | from app.api.endpoints import router
16 | from app.version import __version__
17 |
18 | # Configure loguru
19 | logger.remove() # Remove default handler
20 | logger.add(
21 | "logs/app.log",
22 | rotation="500 MB",
23 | retention="10 days",
24 | level="INFO",
25 | format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}"
26 | )
27 | logger.add(lambda msg: print(msg, end=""), level="INFO")  # Also echo to console (loguru messages already end with a newline)
28 |
29 | def parse_args():
30 | parser = argparse.ArgumentParser(description="OAI-compatible proxy")
31 | parser.add_argument("--model-path", type=str, required=True, help="Huggingface model repo or local path")
32 | parser.add_argument("--model-type", type=str, default="lm", choices=["lm", "vlm"], help="Model type")
33 | parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
34 | parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on")
35 | parser.add_argument("--max-concurrency", type=int, default=1, help="Maximum number of concurrent requests")
36 | parser.add_argument("--queue-timeout", type=int, default=300, help="Request timeout in seconds")
37 | parser.add_argument("--queue-size", type=int, default=100, help="Maximum queue size for pending requests")
38 | return parser.parse_args()
39 |
40 |
41 | def create_lifespan(config_args):
42 | """Factory function to create a lifespan context manager with access to config args."""
43 | @asynccontextmanager
44 | async def lifespan(app: FastAPI):
45 | try:
46 | logger.info(f"Initializing MLX handler with model path: {config_args.model_path}")
47 | if config_args.model_type == "vlm":
48 | handler = MLXVLMHandler(
49 | model_path=config_args.model_path,
50 | max_concurrency=config_args.max_concurrency
51 | )
52 | else:
53 | handler = MLXLMHandler(
54 | model_path=config_args.model_path,
55 | max_concurrency=config_args.max_concurrency
56 | )
57 | # Initialize queue
58 | await handler.initialize({
59 | "max_concurrency": config_args.max_concurrency,
60 | "timeout": config_args.queue_timeout,
61 | "queue_size": config_args.queue_size
62 | })
63 | logger.info("MLX handler initialized successfully")
64 | app.state.handler = handler
65 | except Exception as e:
66 | logger.error(f"Failed to initialize MLX handler: {str(e)}")
67 | raise
68 | gc.collect()
69 | yield
70 | # Shutdown
71 | logger.info("Shutting down application")
72 | if hasattr(app.state, "handler") and app.state.handler:
73 | try:
74 | # Use the proper cleanup method which handles both request queue and image processor
75 | logger.info("Cleaning up resources")
76 | await app.state.handler.cleanup()
77 | logger.info("Resources cleaned up successfully")
78 | except Exception as e:
79 | logger.error(f"Error during shutdown: {str(e)}")
80 |
81 | return lifespan
82 |
83 | # App instance will be created during setup with the correct lifespan
84 | app = None
85 |
86 | async def setup_server(args) -> uvicorn.Config:
87 | global app
88 |
89 | # Create FastAPI app with the configured lifespan
90 | app = FastAPI(
91 | title="OpenAI-compatible API",
92 | description="API for OpenAI-compatible chat completion and text embedding",
93 | version=__version__,
94 | lifespan=create_lifespan(args)
95 | )
96 |
97 | app.include_router(router)
98 |
99 | # Add CORS middleware
100 | app.add_middleware(
101 | CORSMiddleware,
102 | allow_origins=["*"], # In production, replace with specific origins
103 | allow_credentials=True,
104 | allow_methods=["*"],
105 | allow_headers=["*"],
106 | )
107 |
108 | @app.middleware("http")
109 | async def add_process_time_header(request: Request, call_next):
110 | start_time = time.time()
111 | response = await call_next(request)
112 | process_time = time.time() - start_time
113 | response.headers["X-Process-Time"] = str(process_time)
114 | return response
115 |
116 | @app.exception_handler(Exception)
117 | async def global_exception_handler(request: Request, exc: Exception):
118 | logger.error(f"Global exception handler caught: {str(exc)}", exc_info=True)
119 | return JSONResponse(
120 | status_code=500,
121 | content={"error": {"message": "Internal server error", "type": "internal_error"}}
122 | )
123 |
124 | logger.info(f"Starting server on {args.host}:{args.port}")
125 | config = uvicorn.Config(
126 | app=app,
127 | host=args.host,
128 | port=args.port,
129 | log_level="info",
130 | access_log=True
131 | )
132 | return config
133 |
134 | if __name__ == "__main__":
135 | args = parse_args()
136 | config = asyncio.run(setup_server(args))
137 | uvicorn.Server(config).run()
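138 | 
139 | # Illustrative launch sketch (not part of the original module). The model path is a
140 | # placeholder; the flags correspond to parse_args() above:
141 | #
142 | #   python -m app.main --model-path <hf-repo-or-local-path> --model-type lm \
143 | #       --port 8000 --max-concurrency 1 --queue-timeout 300 --queue-size 100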
--------------------------------------------------------------------------------
/app/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cubist38/mlx-openai-server/1c0b466e28a61471b6bb8ed8fe21deb9db7a76a6/app/models/__init__.py
--------------------------------------------------------------------------------
/app/models/mlx_lm.py:
--------------------------------------------------------------------------------
1 | import mlx.core as mx
2 | from mlx_lm.utils import load
3 | from mlx_lm.generate import (
4 | generate,
5 | stream_generate,
6 | )
7 | from mlx_lm.sample_utils import make_sampler
8 | from typing import List, Dict, Union, Generator, Optional, Tuple
9 |
10 | DEFAULT_TEMPERATURE = 0.7
11 | DEFAULT_TOP_P = 0.95
12 | DEFAULT_TOP_K = 20
13 | DEFAULT_MIN_P = 0.0
14 | DEFAULT_SEED = 0
15 | DEFAULT_MAX_TOKENS = 512
16 | DEFAULT_BATCH_SIZE = 32
17 |
18 | class MLX_LM:
19 | """
20 | A wrapper class for MLX Language Model that handles both streaming and non-streaming inference.
21 |
22 | This class provides a unified interface for generating text responses from text prompts,
23 | supporting both streaming and non-streaming modes.
24 | """
25 |
26 | def __init__(self, model_path: str):
27 | try:
28 | self.model, self.tokenizer = load(model_path)
29 | self.pad_token_id = self.tokenizer.pad_token_id
30 | self.bos_token = self.tokenizer.bos_token
31 | self.model_type = self.model.model_type
32 | except Exception as e:
33 | raise ValueError(f"Error loading model: {str(e)}")
34 |
35 | def _apply_pooling_strategy(self, embeddings: mx.array) -> mx.array:
36 | embeddings = mx.mean(embeddings, axis=1)
37 | return embeddings
38 |
39 | def _apply_l2_normalization(self, embeddings: mx.array) -> mx.array:
40 | l2_norms = mx.linalg.norm(embeddings, axis=1, keepdims=True)
41 | embeddings = embeddings / (l2_norms + 1e-8)
42 | return embeddings
43 |
44 | def _batch_process(self, prompts: List[str], batch_size: int = DEFAULT_BATCH_SIZE) -> List[List[int]]:
45 | """Process prompts in batches with optimized tokenization."""
46 | all_tokenized = []
47 |
48 | # Process prompts in batches
49 | for i in range(0, len(prompts), batch_size):
50 | batch = prompts[i:i + batch_size]
51 | tokenized_batch = []
52 |
53 | # Tokenize all prompts in batch
54 | for p in batch:
55 | add_special_tokens = self.bos_token is None or not p.startswith(self.bos_token)
56 | tokens = self.tokenizer.encode(p, add_special_tokens=add_special_tokens)
57 | tokenized_batch.append(tokens)
58 |
59 | # Find max length in batch
60 | max_length = max(len(tokens) for tokens in tokenized_batch)
61 |
62 | # Pad tokens in a vectorized way
63 | for tokens in tokenized_batch:
64 | padding = [self.pad_token_id] * (max_length - len(tokens))
65 | all_tokenized.append(tokens + padding)
66 |
67 | return all_tokenized
68 |
69 |     def _preprocess_prompt(self, prompt: str) -> mx.array:
70 | """Tokenize a single prompt efficiently."""
71 | add_special_tokens = self.bos_token is None or not prompt.startswith(self.bos_token)
72 | tokens = self.tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
73 | return mx.array(tokens)
74 |
75 | def get_model_type(self) -> str:
76 | return self.model_type
77 |
78 | def get_embeddings(
79 | self,
80 | prompts: List[str],
81 | batch_size: int = DEFAULT_BATCH_SIZE,
82 | normalize: bool = True
83 |     ) -> List[List[float]]:
84 | """
85 | Get embeddings for a list of prompts efficiently.
86 |
87 | Args:
88 | prompts: List of text prompts
89 | batch_size: Size of batches for processing
90 |
91 | Returns:
92 | List of embeddings as float arrays
93 | """
94 | # Process in batches to optimize memory usage
95 | all_embeddings = []
96 | for i in range(0, len(prompts), batch_size):
97 | batch_prompts = prompts[i:i + batch_size]
98 | tokenized_batch = self._batch_process(batch_prompts, batch_size)
99 |
100 | # Convert to MLX array for efficient computation
101 | tokenized_batch = mx.array(tokenized_batch)
102 |
103 | # Compute embeddings for batch
104 | batch_embeddings = self.model.model(tokenized_batch)
105 | pooled_embedding = self._apply_pooling_strategy(batch_embeddings)
106 | if normalize:
107 | pooled_embedding = self._apply_l2_normalization(pooled_embedding)
108 | all_embeddings.extend(pooled_embedding.tolist())
109 |
110 | return all_embeddings
111 |
112 | def __call__(
113 | self,
114 | messages: List[Dict[str, str]],
115 | stream: bool = False,
116 | **kwargs
117 | ) -> Union[str, Generator[str, None, None]]:
118 | """
119 | Generate text response from the model.
120 |
121 | Args:
122 | messages (List[Dict[str, str]]): List of messages in the conversation.
123 | stream (bool): Whether to stream the response.
124 | **kwargs: Additional parameters for generation
125 |             - temperature: Sampling temperature (default: 0.7)
126 |             - top_p: Top-p sampling parameter (default: 0.95)
127 |             - seed: Random seed (default: 0)
128 |             - max_tokens: Maximum number of tokens to generate (default: 512)
129 | """
130 | # Set default parameters if not provided
131 | seed = kwargs.get("seed", DEFAULT_SEED)
132 | max_tokens = kwargs.get("max_tokens", DEFAULT_MAX_TOKENS)
133 | chat_template_kwargs = kwargs.get("chat_template_kwargs", {})
134 |
135 | sampler_kwargs = {
136 | "temp": kwargs.get("temperature", DEFAULT_TEMPERATURE),
137 | "top_p": kwargs.get("top_p", DEFAULT_TOP_P),
138 | "top_k": kwargs.get("top_k", DEFAULT_TOP_K),
139 | "min_p": kwargs.get("min_p", DEFAULT_MIN_P)
140 | }
141 |
142 | mx.random.seed(seed)
143 |
144 | # Prepare input tokens
145 | prompt = self.tokenizer.apply_chat_template(
146 | messages,
147 | **chat_template_kwargs
148 | )
149 |
150 | sampler = make_sampler(
151 | **sampler_kwargs
152 | )
153 |
154 | if not stream:
155 | return generate(
156 | self.model,
157 | self.tokenizer,
158 | prompt,
159 | sampler=sampler,
160 | max_tokens=max_tokens
161 | )
162 | else:
163 | # Streaming mode: return generator of chunks
164 | return stream_generate(
165 | self.model,
166 | self.tokenizer,
167 | prompt,
168 | sampler=sampler,
169 | max_tokens=max_tokens
170 | )
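171 | 
172 | 
173 | if __name__ == "__main__":
174 |     # Illustrative sketch, not part of the original module. "path/to/model" is a
175 |     # placeholder for a local MLX checkpoint or Hugging Face repo id; running it
176 |     # requires the weights to be available.
177 |     lm = MLX_LM("path/to/model")
178 |     reply = lm(
179 |         [{"role": "user", "content": "Write a one-line haiku about queues."}],
180 |         stream=False,
181 |         max_tokens=64,
182 |     )
183 |     print(reply)
184 |     # Streaming mode yields chunks with a .text attribute:
185 |     for chunk in lm([{"role": "user", "content": "Count to three."}], stream=True, max_tokens=32):
186 |         print(chunk.text, end="", flush=True)
187 |     # get_embeddings returns one mean-pooled (optionally L2-normalized) vector per prompt:
188 |     print(len(lm.get_embeddings(["hello world"])[0]))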
--------------------------------------------------------------------------------
/app/models/mlx_vlm.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict, Union, Generator, Optional
2 | from mlx_vlm import load
3 | from mlx_vlm.prompt_utils import apply_chat_template
4 | from mlx_vlm.utils import load_config, generate, stream_generate, prepare_inputs
5 | import mlx.core as mx
6 |
7 |
8 | # Default model parameters
9 | DEFAULT_MAX_TOKENS = 256
10 | DEFAULT_TEMPERATURE = 0.0
11 | DEFAULT_TOP_P = 1.0
12 | DEFAULT_SEED = 0
13 |
14 | class MLX_VLM:
15 | """
16 | A wrapper class for MLX Vision Language Model that handles both streaming and non-streaming inference.
17 |
18 | This class provides a unified interface for generating text responses from images and text prompts,
19 | supporting both streaming and non-streaming modes.
20 | """
21 |
22 | def __init__(self, model_path: str):
23 | """
24 | Initialize the MLX_VLM model.
25 |
26 | Args:
27 | model_path (str): Path to the model directory containing model weights and configuration.
28 |
29 | Raises:
30 | ValueError: If model loading fails.
31 | """
32 | try:
33 | self.model, self.processor = load(model_path, lazy=False, trust_remote_code=True)
34 | self.config = load_config(model_path, trust_remote_code=True)
35 | except Exception as e:
36 | raise ValueError(f"Error loading model: {str(e)}")
37 |
38 | def __call__(
39 | self,
40 | messages: List[Dict[str, str]],
41 | images: List[str] = None,
42 | stream: bool = False,
43 | **kwargs
44 | ) -> Union[str, Generator[str, None, None]]:
45 | """
46 | Generate text response from images and messages.
47 |
48 | Args:
49 |             messages (List[Dict[str, str]]): List of message dictionaries with 'role' and 'content' keys.
50 |             images (List[str], optional): List of image paths to process. Defaults to None.
51 | stream (bool, optional): Whether to stream the response. Defaults to False.
52 | **kwargs: Additional model parameters (temperature, max_tokens, etc.)
53 |
54 | Returns:
55 | Union[str, Generator[str, None, None]]:
56 | - If stream=False: Complete response as string
57 | - If stream=True: Generator yielding response chunks
58 | """
59 | # Prepare the prompt using the chat template
60 | prompt = apply_chat_template(
61 | self.processor,
62 | self.config,
63 | messages,
64 | add_generation_prompt=True,
65 | num_images=len(images) if images else 0
66 | )
67 | # Set default parameters if not provided
68 | model_params = {
69 | "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
70 | "max_tokens": kwargs.get("max_tokens", DEFAULT_MAX_TOKENS),
71 | **kwargs
72 | }
73 |
74 | if not stream:
75 | # Non-streaming mode: return complete response
76 | text, _ = generate(
77 | self.model,
78 | self.processor,
79 | prompt,
80 | image=images,
81 | **model_params
82 | )
83 | return text
84 | else:
85 | # Streaming mode: return generator of chunks
86 | return stream_generate(
87 | self.model,
88 | self.processor,
89 | prompt,
90 | images,
91 | **model_params
92 | )
93 |
94 | def get_embeddings(
95 | self,
96 | prompts: List[str],
97 | images: Optional[List[str]] = None,
98 | batch_size: int = 1,
99 | normalize: bool = True
100 | ) -> List[List[float]]:
101 | """
102 | Get embeddings for a list of prompts and optional images, supporting batch processing.
103 | Args:
104 | prompts: List of text prompts
105 | images: Optional list of image paths (must be same length as prompts if provided)
106 | batch_size: Size of batches for processing
107 | normalize: Whether to apply L2 normalization to embeddings
108 | Returns:
109 | List of embeddings as float arrays
110 | """
111 | if images is None:
112 | images = []
113 |
114 | # Text-only batch
115 | if not images:
116 | # Batch tokenize and pad
117 | tokenized = [self.processor.tokenizer.encode(self._format_prompt(p, 0), add_special_tokens=True) for p in prompts]
118 | max_len = max(len(t) for t in tokenized)
119 | pad_id = self.processor.tokenizer.pad_token_id or self.processor.tokenizer.eos_token_id
120 | batch_input_ids = [t + [pad_id] * (max_len - len(t)) for t in tokenized]
121 | batch_input_ids = mx.array(batch_input_ids)
122 |
123 | # Run in batches
124 | all_embeddings = []
125 | for i in range(0, len(prompts), batch_size):
126 | batch_ids = batch_input_ids[i:i+batch_size]
127 | embeddings = self.model.language_model.model(batch_ids)
128 | pooled = self._apply_pooling_strategy(embeddings)
129 | if normalize:
130 | pooled = self._apply_l2_normalization(pooled)
131 | all_embeddings.extend(pooled.tolist())
132 | return all_embeddings
133 |
134 | # Image+prompt batch
135 | if len(images) != len(prompts):
136 | raise ValueError("If images are provided, must be same length as prompts (one image per prompt)")
137 |
138 | all_embeddings = []
139 | for i in range(0, len(prompts), batch_size):
140 | batch_prompts = prompts[i:i+batch_size]
141 | batch_images = images[i:i+batch_size]
142 | formatted_prompts = [self._format_prompt(p, 1) for p in batch_prompts]
143 | inputs = prepare_inputs(
144 | self.processor,
145 | batch_images,
146 | formatted_prompts,
147 | getattr(self.model.config, "image_token_index", None)
148 | )
149 | input_ids = inputs["input_ids"]
150 | pixel_values = inputs.get("pixel_values", None)
151 | image_grid_thw = inputs.get("image_grid_thw", None)
152 | inputs_embeds = self.model.get_input_embeddings(input_ids, pixel_values, image_grid_thw)
153 | embeddings = self.model.language_model.model(None, inputs_embeds=inputs_embeds)
154 | pooled = self._apply_pooling_strategy(embeddings)
155 | if normalize:
156 | pooled = self._apply_l2_normalization(pooled)
157 | all_embeddings.extend(pooled.tolist())
158 | return all_embeddings
159 |
160 | def _format_prompt(self, prompt: str, n_images: int) -> str:
161 | """Format a single prompt using the chat template."""
162 | return apply_chat_template(
163 | self.processor,
164 | self.config,
165 | prompt,
166 | add_generation_prompt=True,
167 | num_images=n_images
168 | )
169 |
170 | def _prepare_single_input(self, formatted_prompt: str, images: List[str]) -> Dict:
171 | """Prepare inputs for a single prompt-image pair."""
172 | return prepare_inputs(
173 | self.processor,
174 | images,
175 | formatted_prompt,
176 | getattr(self.model.config, "image_token_index", None)
177 | )
178 |
179 | def _get_single_embedding(
180 | self,
181 | inputs: Dict,
182 | normalize: bool = True
183 | ) -> List[float]:
184 | """Get embedding for a single processed input."""
185 | input_ids = inputs["input_ids"]
186 | pixel_values = inputs.get("pixel_values", None)
187 |
188 | # Extract additional kwargs
189 | data_kwargs = {
190 | k: v for k, v in inputs.items()
191 | if k not in ["input_ids", "pixel_values", "attention_mask"]
192 | }
193 | image_grid_thw = data_kwargs.pop("image_grid_thw", None)
194 |
195 | inputs_embeds = self.model.get_input_embeddings(input_ids, pixel_values, image_grid_thw)
196 | embeddings = self.model.language_model.model(None, inputs_embeds=inputs_embeds)
197 |
198 | # Apply pooling
199 | pooled_embedding = self._apply_pooling_strategy(embeddings)
200 |
201 | # Apply normalization if requested
202 | if normalize:
203 | pooled_embedding = self._apply_l2_normalization(pooled_embedding)
204 |
205 | return pooled_embedding.tolist()
206 |
207 | def _apply_pooling_strategy(self, embeddings: mx.array) -> mx.array:
208 | """Apply mean pooling to embeddings."""
209 | return mx.mean(embeddings, axis=1)
210 |
211 | def _apply_l2_normalization(self, embeddings: mx.array) -> mx.array:
212 | """Apply L2 normalization to embeddings."""
213 | l2_norms = mx.linalg.norm(embeddings, axis=1, keepdims=True)
214 | embeddings = embeddings / (l2_norms + 1e-8)
215 | return embeddings
--------------------------------------------------------------------------------
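Note: the following sketch is not part of the repository; it only illustrates, with made-up shapes, how the mean-pooling and L2-normalization helpers above (`_apply_pooling_strategy`, `_apply_l2_normalization`) turn per-token hidden states into fixed-size embedding vectors.

```python
import mlx.core as mx

hidden = mx.random.normal((2, 5, 8))                   # (batch, tokens, hidden_dim), dummy values
pooled = mx.mean(hidden, axis=1)                       # mean pooling over the token axis -> (2, 8)
norms = mx.linalg.norm(pooled, axis=1, keepdims=True)  # per-row L2 norms
unit = pooled / (norms + 1e-8)                         # L2-normalized embeddings, one vector per prompt
print(unit.shape)                                      # (2, 8)
```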
/app/schemas/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/app/schemas/openai.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional, Union
2 |
3 | from pydantic import BaseModel, Field, validator
4 | from typing_extensions import Literal
5 | from loguru import logger
6 |
7 |
8 | # Configuration
9 | class Config:
10 | """
11 | Configuration class holding the default model names for different types of requests.
12 | """
13 | TEXT_MODEL = "gpt-4-turbo" # Default model for text-based chat completions
14 | VISION_MODEL = "gpt-4-vision-preview" # Model used for vision-based requests
15 | EMBEDDING_MODEL = "text-embedding-ada-002" # Model used for generating embeddings
16 |
17 | class ErrorResponse(BaseModel):
18 | object: str = Field("error", description="The object type, always 'error'.")
19 | message: str = Field(..., description="The error message.")
20 | type: str = Field(..., description="The type of error.")
21 | param: Optional[str] = Field(None, description="The parameter related to the error, if any.")
22 | code: int = Field(..., description="The error code.")
23 |
24 | # Common models used in both streaming and non-streaming contexts
25 | class ImageUrl(BaseModel):
26 | """
27 | Represents an image URL in a message.
28 | """
29 | url: str = Field(..., description="The image URL.")
30 |
31 | class VisionContentItem(BaseModel):
32 | """
33 | Represents a single content item in a message (text or image).
34 | """
35 | type: str = Field(..., description="The type of content, e.g., 'text' or 'image_url'.")
36 | text: Optional[str] = Field(None, description="The text content, if type is 'text'.")
37 | image_url: Optional[ImageUrl] = Field(None, description="The image URL object, if type is 'image_url'.")
38 |
39 | class FunctionCall(BaseModel):
40 | """
41 | Represents a function call in a message.
42 | """
43 | arguments: str = Field(..., description="The arguments for the function call.")
44 | name: str = Field(..., description="The name of the function to call.")
45 |
46 | class ChatCompletionMessageToolCall(BaseModel):
47 | """
48 | Represents a tool call in a message.
49 | """
50 | id: str = Field(..., description="The ID of the tool call.")
51 | function: FunctionCall = Field(..., description="The function call details.")
52 | type: Literal["function"] = Field(..., description="The type of tool call, always 'function'.")
53 | index: int = Field(..., description="The index of the tool call.")
54 |
55 | class Message(BaseModel):
56 | """
57 | Represents a message in a chat completion.
58 | """
59 |     content: Optional[Union[str, List[VisionContentItem]]] = Field(None, description="The content of the message, either text or a list of vision content items.")
60 | refusal: Optional[str] = Field(None, description="The refusal reason, if any.")
61 | role: Literal["system", "user", "assistant", "tool"] = Field(..., description="The role of the message sender.")
62 | function_call: Optional[FunctionCall] = Field(None, description="The function call, if any.")
63 | reasoning_content: Optional[str] = Field(None, description="The reasoning content, if any.")
64 | tool_calls: Optional[List[ChatCompletionMessageToolCall]] = Field(None, description="List of tool calls, if any.")
65 |
66 | # Common request base for both streaming and non-streaming
67 | class ChatCompletionRequestBase(BaseModel):
68 | """
69 | Base model for chat completion requests.
70 | """
71 | model: str = Field(Config.TEXT_MODEL, description="The model to use for completion.")
72 | messages: List[Message] = Field(..., description="The list of messages in the conversation.")
73 | tools: Optional[List[Dict[str, Any]]] = Field(None, description="List of tools available for the request.")
74 | tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(None, description="Tool choice for the request.")
75 | max_tokens: Optional[int] = Field(None, description="The maximum number of tokens to generate.")
76 | temperature: Optional[float] = Field(0.7, description="Sampling temperature.")
77 | top_p: Optional[float] = Field(1.0, description="Nucleus sampling probability.")
78 | frequency_penalty: Optional[float] = Field(0.0, description="Frequency penalty for token generation.")
79 | presence_penalty: Optional[float] = Field(0.0, description="Presence penalty for token generation.")
80 | stop: Optional[List[str]] = Field(None, description="List of stop sequences.")
81 | n: Optional[int] = Field(1, description="Number of completions to generate.")
82 | response_format: Optional[Dict[str, str]] = Field(None, description="Format for the response.")
83 | seed: Optional[int] = Field(None, description="Random seed for reproducibility.")
84 | user: Optional[str] = Field(None, description="User identifier.")
85 |
86 | @validator("messages")
87 | def check_messages_not_empty(cls, v):
88 | """
89 | Ensure that the messages list is not empty and validate message structure.
90 | """
91 | if not v:
92 | raise ValueError("messages cannot be empty")
93 |
94 | # Validate message history length
95 | if len(v) > 100: # OpenAI's limit is typically around 100 messages
96 | raise ValueError("message history too long")
97 |
98 | # Validate message roles
99 | valid_roles = {"user", "assistant", "system", "tool"}
100 | for msg in v:
101 | if msg.role not in valid_roles:
102 | raise ValueError(f"invalid role: {msg.role}")
103 |
104 | return v
105 |
106 | @validator("temperature")
107 | def check_temperature(cls, v):
108 | """
109 | Validate temperature is between 0 and 2.
110 | """
111 | if v is not None and (v < 0 or v > 2):
112 | raise ValueError("temperature must be between 0 and 2")
113 | return v
114 |
115 | @validator("max_tokens")
116 | def check_max_tokens(cls, v):
117 | """
118 | Validate max_tokens is positive and within reasonable limits.
119 | """
120 | if v is not None:
121 | if v <= 0:
122 | raise ValueError("max_tokens must be positive")
123 | if v > 4096: # Typical limit for GPT-4
124 | raise ValueError("max_tokens too high")
125 | return v
126 |
127 | def is_vision_request(self) -> bool:
128 | """
129 | Check if the request includes image content, indicating a vision-based request.
130 | """
131 | for message in self.messages:
132 | content = message.content
133 | if isinstance(content, list):
134 | for item in content:
135 | if hasattr(item, 'type') and item.type == "image_url":
136 | if hasattr(item, 'image_url') and item.image_url and item.image_url.url:
137 | logger.debug(f"Detected vision request with image: {item.image_url.url[:30]}...")
138 | return True
139 |
140 |         logger.debug("No images detected, treating as text-only request")
141 | return False
142 |
143 | class ChatTemplateKwargs(BaseModel):
144 | """
145 | Represents the arguments for a chat template.
146 | """
147 | enable_thinking: bool = Field(False, description="Whether to enable thinking mode.")
148 | tools: Optional[List[Dict[str, Any]]] = Field(None, description="List of tools to use in the request.")
149 | add_generation_prompt: bool = Field(True, description="Whether to add a generation prompt to the request.")
150 |
151 | # Non-streaming request and response
152 | class ChatCompletionRequest(ChatCompletionRequestBase):
153 | """
154 | Model for non-streaming chat completion requests.
155 | """
156 | stream: bool = Field(False, description="Whether to stream the response.")
157 | chat_template_kwargs: ChatTemplateKwargs = Field(ChatTemplateKwargs(), description="Arguments for the chat template.")
158 |
159 | class Choice(BaseModel):
160 | """
161 | Represents a choice in a chat completion response.
162 | """
163 | finish_reason: Literal["stop", "length", "tool_calls", "content_filter", "function_call"] = Field(..., description="The reason for the choice.")
164 | index: int = Field(..., description="The index of the choice.")
165 | message: Message = Field(..., description="The message of the choice.")
166 |
167 | class ChatCompletionResponse(BaseModel):
168 | """
169 | Represents a complete chat completion response.
170 | """
171 | id: str = Field(..., description="The response ID.")
172 | object: Literal["chat.completion"] = Field(..., description="The object type, always 'chat.completion'.")
173 | created: int = Field(..., description="The creation timestamp.")
174 | model: str = Field(..., description="The model used for completion.")
175 | choices: List[Choice] = Field(..., description="List of choices in the response.")
176 |
177 |
178 | class ChoiceDeltaFunctionCall(BaseModel):
179 | """
180 | Represents a function call delta in a streaming response.
181 | """
182 | arguments: Optional[str] = Field(None, description="Arguments for the function call delta.")
183 | name: Optional[str] = Field(None, description="Name of the function in the delta.")
184 |
185 | class ChoiceDeltaToolCall(BaseModel):
186 | """
187 | Represents a tool call delta in a streaming response.
188 | """
189 | index: Optional[int] = Field(None, description="Index of the tool call delta.")
190 | id: Optional[str] = Field(None, description="ID of the tool call delta.")
191 | function: Optional[ChoiceDeltaFunctionCall] = Field(None, description="Function call details in the delta.")
192 | type: Optional[str] = Field(None, description="Type of the tool call delta.")
193 |
194 | class Delta(BaseModel):
195 | """
196 | Represents a delta in a streaming response.
197 | """
198 | content: Optional[str] = Field(None, description="Content of the delta.")
199 | function_call: Optional[ChoiceDeltaFunctionCall] = Field(None, description="Function call delta, if any.")
200 | refusal: Optional[str] = Field(None, description="Refusal reason, if any.")
201 | role: Optional[Literal["system", "user", "assistant", "tool"]] = Field(None, description="Role in the delta.")
202 | tool_calls: Optional[List[ChoiceDeltaToolCall]] = Field(None, description="List of tool call deltas, if any.")
203 | reasoning_content: Optional[str] = Field(None, description="Reasoning content, if any.")
204 |
205 | class StreamingChoice(BaseModel):
206 | """
207 | Represents a choice in a streaming response.
208 | """
209 | delta: Delta = Field(..., description="The delta for this streaming choice.")
210 | finish_reason: Optional[Literal["stop", "length", "tool_calls", "content_filter", "function_call"]] = Field(None, description="The reason for finishing, if any.")
211 | index: int = Field(..., description="The index of the streaming choice.")
212 |
213 | class ChatCompletionChunk(BaseModel):
214 | """
215 | Represents a chunk in a streaming chat completion response.
216 | """
217 | id: str = Field(..., description="The chunk ID.")
218 | choices: List[StreamingChoice] = Field(..., description="List of streaming choices in the chunk.")
219 | created: int = Field(..., description="The creation timestamp of the chunk.")
220 | model: str = Field(..., description="The model used for the chunk.")
221 | object: Literal["chat.completion.chunk"] = Field(..., description="The object type, always 'chat.completion.chunk'.")
222 |
223 | # Embedding models
224 | class EmbeddingRequest(BaseModel):
225 | """
226 | Model for embedding requests.
227 | """
228 | model: str = Field(Config.EMBEDDING_MODEL, description="The embedding model to use.")
229 | input: List[str] = Field(..., description="List of text inputs for embedding.")
230 | image_url: Optional[str] = Field(default=None, description="Image URL to embed.")
231 |
232 | class Embedding(BaseModel):
233 | """
234 | Represents an embedding object in an embedding response.
235 | """
236 | embedding: List[float] = Field(..., description="The embedding vector.")
237 | index: int = Field(..., description="The index of the embedding in the list.")
238 | object: str = Field(default="embedding", description="The object type, always 'embedding'.")
239 |
240 | class EmbeddingResponse(BaseModel):
241 | """
242 | Represents an embedding response.
243 | """
244 | object: str = Field("list", description="The object type, always 'list'.")
245 | data: List[Embedding] = Field(..., description="List of embedding objects.")
246 | model: str = Field(..., description="The model used for embedding.")
247 |
248 | class Model(BaseModel):
249 | """
250 | Represents a model in the models list response.
251 | """
252 | id: str = Field(..., description="The model ID.")
253 | object: str = Field("model", description="The object type, always 'model'.")
254 | created: int = Field(..., description="The creation timestamp.")
255 | owned_by: str = Field("openai", description="The owner of the model.")
256 |
257 | class ModelsResponse(BaseModel):
258 | """
259 | Represents the response for the models list endpoint.
260 | """
261 | object: str = Field("list", description="The object type, always 'list'.")
262 | data: List[Model] = Field(..., description="List of models.")
--------------------------------------------------------------------------------
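A minimal usage sketch of the schemas above (not part of the repository; it assumes the package is importable as `app`): building a vision request and letting `is_vision_request()` detect the image content.

```python
from app.schemas.openai import (
    ChatCompletionRequest,
    ImageUrl,
    Message,
    VisionContentItem,
)

request = ChatCompletionRequest(
    messages=[
        Message(
            role="user",
            content=[
                VisionContentItem(type="text", text="What is in this picture?"),
                VisionContentItem(type="image_url", image_url=ImageUrl(url="https://example.com/cat.png")),
            ],
        )
    ],
)

print(request.model)                 # "gpt-4-turbo" (Config.TEXT_MODEL default)
print(request.is_vision_request())   # True -- an image_url item with a URL was found
```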
/app/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/app/utils/errors.py:
--------------------------------------------------------------------------------
1 | from http import HTTPStatus
2 | from typing import Optional, Union
3 |
4 | from app.schemas.openai import ErrorResponse
5 |
6 |
7 | def create_error_response(
8 | message: str,
9 | err_type: str = "internal_error",
10 | status_code: Union[int, HTTPStatus] = HTTPStatus.INTERNAL_SERVER_ERROR,
11 |     param: Optional[str] = None,
12 |     code: Optional[str] = None
13 | ):
14 | return {
15 | "error": {
16 | "message": message,
17 | "type": err_type,
18 | "param": param,
19 | "code": str(code or (status_code.value if isinstance(status_code, HTTPStatus) else status_code))
20 | }
21 | }
--------------------------------------------------------------------------------
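For reference, a quick sketch (not part of the repository) of what `create_error_response` returns; when no explicit `code` is passed, the HTTP status value is stringified.

```python
from http import HTTPStatus

from app.utils.errors import create_error_response

payload = create_error_response(
    message="model not found",
    err_type="invalid_request_error",
    status_code=HTTPStatus.NOT_FOUND,
    param="model",
)
print(payload)
# {'error': {'message': 'model not found', 'type': 'invalid_request_error', 'param': 'model', 'code': '404'}}
```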
/app/version.py:
--------------------------------------------------------------------------------
1 | __version__ = "1.0.11"
--------------------------------------------------------------------------------
/configure_mlx.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Get the total memory in MB
4 | TOTAL_MEM_MB=$(($(sysctl -n hw.memsize) / 1024 / 1024))
5 |
6 | # Calculate 80% of total memory and (total memory - 5 GB), both in MB
7 | EIGHTY_PERCENT=$(($TOTAL_MEM_MB * 80 / 100))
8 | MINUS_5GB=$(($TOTAL_MEM_MB - 5120))
9 |
10 | # Calculate 70% of total memory and (total memory - 8 GB), both in MB
11 | SEVENTY_PERCENT=$(($TOTAL_MEM_MB * 70 / 100))
12 | MINUS_8GB=$(($TOTAL_MEM_MB - 8192))
13 |
14 | # Set WIRED_LIMIT_MB to the higher of the two values
15 | if [ $EIGHTY_PERCENT -gt $MINUS_5GB ]; then
16 | WIRED_LIMIT_MB=$EIGHTY_PERCENT
17 | else
18 | WIRED_LIMIT_MB=$MINUS_5GB
19 | fi
20 |
21 | # Set WIRED_LWM_MB to the higher of the two values
22 | if [ $SEVENTY_PERCENT -gt $MINUS_8GB ]; then
23 | WIRED_LWM_MB=$SEVENTY_PERCENT
24 | else
25 | WIRED_LWM_MB=$MINUS_8GB
26 | fi
27 |
28 | # Display the calculated values
29 | echo "Total memory: $TOTAL_MEM_MB MB"
30 | echo "Maximum limit (iogpu.wired_limit_mb): $WIRED_LIMIT_MB MB"
31 | echo "Lower bound (iogpu.wired_lwm_mb): $WIRED_LWM_MB MB"
32 |
33 | # Apply the values with sysctl, but check if we're already root
34 | if [ "$EUID" -eq 0 ]; then
35 | sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB
36 | sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB
37 | else
38 | # Try without sudo first, fall back to sudo if needed
39 | sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB 2>/dev/null || \
40 | sudo sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB
41 | sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB 2>/dev/null || \
42 | sudo sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB
43 | fi
--------------------------------------------------------------------------------
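A worked example (not part of the repository) of the arithmetic in `configure_mlx.sh`, here for a hypothetical 32 GB machine; bash `$(( ))` arithmetic truncates to integers, mirrored below with `//`.

```python
total_mem_mb = 32 * 1024                            # 32768 MB on a hypothetical 32 GB machine

wired_limit_mb = max(total_mem_mb * 80 // 100,      # 26214 MB (80% of total)
                     total_mem_mb - 5120)           # 27648 MB (total - 5 GB)
wired_lwm_mb = max(total_mem_mb * 70 // 100,        # 22937 MB (70% of total)
                   total_mem_mb - 8192)             # 24576 MB (total - 8 GB)

print(wired_limit_mb, wired_lwm_mb)                 # 27648 24576 -> iogpu.wired_limit_mb / iogpu.wired_lwm_mb
```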
/examples/function_calling_examples.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# MLX Server Function Calling Example\n",
8 | "\n",
9 |     "This notebook walks through a function calling example with MLX Server via its OpenAI-compatible API."
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "## Setup"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 53,
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "from openai import OpenAI"
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {},
31 | "source": [
32 | "## Initialize the client\n",
33 | "\n",
34 | "Connect to your local MLX server:"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": 54,
40 | "metadata": {},
41 | "outputs": [],
42 | "source": [
43 | "client = OpenAI(\n",
44 | " base_url = \"http://localhost:8000/v1\",\n",
45 | " api_key = \"mlx-server-api-key\"\n",
46 | ")"
47 | ]
48 | },
49 | {
50 | "cell_type": "markdown",
51 | "metadata": {},
52 | "source": [
53 | "## Function calling example\n",
54 | "\n",
55 | "This example demonstrates how to use function calling with the MLX server:"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": 56,
61 | "metadata": {},
62 | "outputs": [
63 | {
64 | "name": "stdout",
65 | "output_type": "stream",
66 | "text": [
67 | "ChatCompletion(id='chatcmpl_1748500691067459', choices=[Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content=None, refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_1748500691650837', function=Function(arguments='{\"city\": \"Tokyo\"}', name='get_weather'), type='function', index=0)], reasoning_content='Okay, the user is asking for the weather in Tokyo. Let me check the tools provided. There\\'s a function called get_weather that takes a city parameter. So I need to call that function with the city set to Tokyo. I\\'ll make sure the JSON is correctly formatted with the city name as a string. Let me double-check the parameters. The function requires \"city\" as a string, so the arguments should be {\"city\": \"Tokyo\"}. Alright, that\\'s all I need for the tool call.'))], created=1748500691, model='mlx-server-model', object='chat.completion', service_tier=None, system_fingerprint=None, usage=None)\n"
68 | ]
69 | }
70 | ],
71 | "source": [
72 | "# Define the user message\n",
73 | "messages = [\n",
74 | " {\n",
75 | " \"role\": \"user\",\n",
76 | " \"content\": \"What is the weather in Tokyo?\"\n",
77 | " }\n",
78 | "]\n",
79 | "\n",
80 | "# Define the available tools/functions\n",
81 | "tools = [\n",
82 | " {\n",
83 | " \"type\": \"function\",\n",
84 | " \"function\": {\n",
85 | " \"name\": \"get_weather\",\n",
86 | " \"description\": \"Get the weather in a given city\",\n",
87 | " \"parameters\": {\n",
88 | " \"type\": \"object\",\n",
89 | " \"properties\": {\n",
90 | " \"city\": {\"type\": \"string\", \"description\": \"The city to get the weather for\"}\n",
91 | " }\n",
92 | " }\n",
93 | " }\n",
94 | " }\n",
95 | "]\n",
96 | "\n",
97 | "# Make the API call\n",
98 | "completion = client.chat.completions.create(\n",
99 | " model=\"mlx-server-model\",\n",
100 | " messages=messages,\n",
101 | " tools=tools,\n",
102 | " tool_choice=\"auto\",\n",
103 | " max_tokens = 512,\n",
104 | " extra_body = {\n",
105 | " \"chat_template_kwargs\": {\n",
106 | " \"enable_thinking\": True\n",
107 | " }\n",
108 | " }\n",
109 | ")\n",
110 | "\n",
111 | "# Get the result\n",
112 | "print(completion)"
113 | ]
114 | },
115 | {
116 | "cell_type": "markdown",
117 | "metadata": {},
118 | "source": [
119 | "## Streaming version"
120 | ]
121 | },
122 | {
123 | "cell_type": "code",
124 | "execution_count": 57,
125 | "metadata": {},
126 | "outputs": [
127 | {
128 | "name": "stdout",
129 | "output_type": "stream",
130 | "text": [
131 | "ChatCompletionChunk(id='chatcmpl_1748500691833604', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=None, reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1748500691, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n",
132 | "ChatCompletionChunk(id='chatcmpl_1748500691833604', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id='call_1748500694172265', function=ChoiceDeltaToolCallFunction(arguments='', name='get_weather'), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1748500691, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n",
133 | "ChatCompletionChunk(id='chatcmpl_1748500691833604', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments=' {\"', name=None), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1748500691, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n",
134 | "ChatCompletionChunk(id='chatcmpl_1748500691833604', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='city', name=None), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1748500691, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n",
135 | "ChatCompletionChunk(id='chatcmpl_1748500691833604', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='\":', name=None), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1748500691, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n",
136 | "ChatCompletionChunk(id='chatcmpl_1748500691833604', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments=' \"', name=None), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1748500691, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n",
137 | "ChatCompletionChunk(id='chatcmpl_1748500691833604', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='Tok', name=None), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1748500691, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n",
138 | "ChatCompletionChunk(id='chatcmpl_1748500691833604', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='yo', name=None), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1748500691, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n",
139 | "ChatCompletionChunk(id='chatcmpl_1748500691833604', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='\"}', name=None), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1748500691, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n",
140 | "ChatCompletionChunk(id='chatcmpl_1748500691833604', choices=[Choice(delta=ChoiceDelta(content='', function_call=None, refusal=None, role='assistant', tool_calls=None, reasoning_content=None), finish_reason='tool_calls', index=0, logprobs=None)], created=1748500695, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n"
141 | ]
142 | }
143 | ],
144 | "source": [
145 | "# Set stream=True in the API call\n",
146 | "completion = client.chat.completions.create(\n",
147 | " model=\"mlx-server-model\",\n",
148 | " messages=messages,\n",
149 | " tools=tools,\n",
150 | " tool_choice=\"auto\",\n",
151 | " stream=True,\n",
152 | " extra_body = {\n",
153 | " \"chat_template_kwargs\": {\n",
154 | " \"enable_thinking\": False\n",
155 | " }\n",
156 | " }\n",
157 | ")\n",
158 | "\n",
159 | "# Process the streaming response\n",
160 | "for chunk in completion:\n",
161 | " print(chunk)"
162 | ]
163 | },
164 | {
165 | "cell_type": "code",
166 | "execution_count": null,
167 | "metadata": {},
168 | "outputs": [],
169 | "source": []
170 | }
171 | ],
172 | "metadata": {
173 | "kernelspec": {
174 | "display_name": "Python 3",
175 | "language": "python",
176 | "name": "python3"
177 | },
178 | "language_info": {
179 | "codemirror_mode": {
180 | "name": "ipython",
181 | "version": 3
182 | },
183 | "file_extension": ".py",
184 | "mimetype": "text/x-python",
185 | "name": "python",
186 | "nbconvert_exporter": "python",
187 | "pygments_lexer": "ipython3",
188 | "version": "3.11.12"
189 | }
190 | },
191 | "nbformat": 4,
192 | "nbformat_minor": 2
193 | }
194 |
--------------------------------------------------------------------------------
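The notebook stops at printing the tool call; below is a hedged follow-up sketch (not in the notebook) of how the returned call could be executed and fed back, where `get_weather` stands in for your own implementation.

```python
import json

# `completion`, `messages`, and `client` are the objects from the non-streaming example above.
tool_call = completion.choices[0].message.tool_calls[0]
if tool_call.function.name == "get_weather":
    args = json.loads(tool_call.function.arguments)        # {'city': 'Tokyo'}
    result = get_weather(**args)                           # hypothetical user-defined function
    # Append the assistant tool call and the tool result, then ask for a final answer.
    messages.append(completion.choices[0].message)
    messages.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(result)})
    final = client.chat.completions.create(model="mlx-server-model", messages=messages)
    print(final.choices[0].message.content)
```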
/examples/images/attention.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cubist38/mlx-openai-server/1c0b466e28a61471b6bb8ed8fe21deb9db7a76a6/examples/images/attention.png
--------------------------------------------------------------------------------
/examples/images/green_dog.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cubist38/mlx-openai-server/1c0b466e28a61471b6bb8ed8fe21deb9db7a76a6/examples/images/green_dog.jpeg
--------------------------------------------------------------------------------
/examples/images/password.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cubist38/mlx-openai-server/1c0b466e28a61471b6bb8ed8fe21deb9db7a76a6/examples/images/password.jpg
--------------------------------------------------------------------------------
/examples/lm_embeddings_examples.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Embeddings API Examples with MLX Server"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "This notebook demonstrates how to use the embeddings endpoint of MLX Server through the OpenAI-compatible API. You'll learn how to generate embeddings, work with batches, compare similarity between texts, and use embeddings for practical applications."
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {},
20 | "source": [
21 | "## Setup and Connection"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 1,
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "# Import the OpenAI client for API communication\n",
31 | "from openai import OpenAI\n",
32 | "\n",
33 | "# Connect to the local MLX Server with OpenAI-compatible API\n",
34 | "client = OpenAI(\n",
35 | " base_url=\"http://localhost:8000/v1\",\n",
36 | " api_key=\"fake-api-key\",\n",
37 | ")"
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "metadata": {},
43 | "source": [
44 | "## Basic Embedding Generation\n",
45 | "\n",
46 | "### Single Text Embedding\n"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": 2,
52 | "metadata": {},
53 | "outputs": [],
54 | "source": [
55 | "# Generate embedding for a single text input\n",
56 | "single_text = \"Artificial intelligence is transforming how we interact with technology.\"\n",
57 | "response = client.embeddings.create(\n",
58 | " input=[single_text],\n",
59 | " model=\"mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-MLX-Q8\"\n",
60 | ")"
61 | ]
62 | },
63 | {
64 | "cell_type": "markdown",
65 | "metadata": {},
66 | "source": [
67 | "### Batch Processing Multiple Texts"
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "execution_count": 3,
73 | "metadata": {},
74 | "outputs": [],
75 | "source": [
76 | "text_batch = [\n",
77 | " \"Machine learning algorithms improve with more data\",\n",
78 | " \"Natural language processing helps computers understand human language\",\n",
79 | " \"Computer vision allows machines to interpret visual information\"\n",
80 | "]\n",
81 | "\n",
82 | "batch_response = client.embeddings.create(\n",
83 | " input=text_batch,\n",
84 | " model=\"mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-MLX-Q8\"\n",
85 | ")"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": 4,
91 | "metadata": {},
92 | "outputs": [
93 | {
94 | "name": "stdout",
95 | "output_type": "stream",
96 | "text": [
97 | "Number of embeddings generated: 3\n",
98 | "Dimensions of each embedding: 1536\n"
99 | ]
100 | }
101 | ],
102 | "source": [
103 | "# Access all embeddings\n",
104 | "embeddings = [item.embedding for item in batch_response.data]\n",
105 | "print(f\"Number of embeddings generated: {len(embeddings)}\")\n",
106 | "print(f\"Dimensions of each embedding: {len(embeddings[0])}\")"
107 | ]
108 | },
109 | {
110 | "cell_type": "markdown",
111 | "metadata": {},
112 | "source": [
113 | "## Semantic Similarity Calculation\n",
114 | "\n",
115 | "One of the most common uses for embeddings is measuring semantic similarity between texts."
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": 5,
121 | "metadata": {},
122 | "outputs": [],
123 | "source": [
124 | "import numpy as np"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": 6,
130 | "metadata": {},
131 | "outputs": [],
132 | "source": [
133 | "def cosine_similarity_score(vec1, vec2):\n",
134 | " \"\"\"Calculate cosine similarity between two vectors\"\"\"\n",
135 | " dot_product = np.dot(vec1, vec2)\n",
136 | " norm1 = np.linalg.norm(vec1)\n",
137 | " norm2 = np.linalg.norm(vec2)\n",
138 | " return dot_product / (norm1 * norm2)"
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": 7,
144 | "metadata": {},
145 | "outputs": [],
146 | "source": [
147 | "# Example texts to compare\n",
148 | "text1 = \"Dogs are loyal pets that provide companionship\"\n",
149 | "text2 = \"Canines make friendly companions for humans\"\n",
150 | "text3 = \"Quantum physics explores the behavior of matter at atomic scales\""
151 | ]
152 | },
153 | {
154 | "cell_type": "code",
155 | "execution_count": 8,
156 | "metadata": {},
157 | "outputs": [],
158 | "source": [
159 | "# Generate embeddings\n",
160 | "comparison_texts = [text1, text2, text3]\n",
161 | "comparison_response = client.embeddings.create(\n",
162 | " input=comparison_texts,\n",
163 | " model=\"mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-MLX-Q8\"\n",
164 | ")\n",
165 | "comparison_embeddings = [item.embedding for item in comparison_response.data]"
166 | ]
167 | },
168 | {
169 | "cell_type": "code",
170 | "execution_count": 9,
171 | "metadata": {},
172 | "outputs": [
173 | {
174 | "name": "stdout",
175 | "output_type": "stream",
176 | "text": [
177 | "Similarity between text1 and text2: 0.8142\n",
178 | "Similarity between text1 and text3: 0.6082\n",
179 | "Similarity between text2 and text3: 0.5739\n"
180 | ]
181 | }
182 | ],
183 | "source": [
184 | "# Compare similarities\n",
185 | "similarity_1_2 = cosine_similarity_score(comparison_embeddings[0], comparison_embeddings[1])\n",
186 | "similarity_1_3 = cosine_similarity_score(comparison_embeddings[0], comparison_embeddings[2])\n",
187 | "similarity_2_3 = cosine_similarity_score(comparison_embeddings[1], comparison_embeddings[2])\n",
188 | "\n",
189 | "print(f\"Similarity between text1 and text2: {similarity_1_2:.4f}\")\n",
190 | "print(f\"Similarity between text1 and text3: {similarity_1_3:.4f}\")\n",
191 | "print(f\"Similarity between text2 and text3: {similarity_2_3:.4f}\")"
192 | ]
193 | },
194 | {
195 | "cell_type": "markdown",
196 | "metadata": {},
197 | "source": [
198 | "## Text Search Using Embeddings"
199 | ]
200 | },
201 | {
202 | "cell_type": "code",
203 | "execution_count": 10,
204 | "metadata": {},
205 | "outputs": [],
206 | "source": [
207 | "# Sample document collection\n",
208 | "documents = [\n",
209 | " \"The quick brown fox jumps over the lazy dog\",\n",
210 | " \"Machine learning models require training data\",\n",
211 | " \"Neural networks are inspired by biological neurons\",\n",
212 | " \"Deep learning is a subset of machine learning\",\n",
213 | " \"Natural language processing helps with text analysis\",\n",
214 | " \"Computer vision systems can detect objects in images\"\n",
215 | "]"
216 | ]
217 | },
218 | {
219 | "cell_type": "code",
220 | "execution_count": 11,
221 | "metadata": {},
222 | "outputs": [],
223 | "source": [
224 | "# Generate embeddings for all documents\n",
225 | "doc_response = client.embeddings.create(\n",
226 | " input=documents,\n",
227 | " model=\"mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-MLX-Q8\"\n",
228 | ")\n",
229 | "doc_embeddings = [item.embedding for item in doc_response.data]"
230 | ]
231 | },
232 | {
233 | "cell_type": "code",
234 | "execution_count": 12,
235 | "metadata": {},
236 | "outputs": [
237 | {
238 | "name": "stdout",
239 | "output_type": "stream",
240 | "text": [
241 | "Search results:\n",
242 | "Score: 0.8574 - Computer vision systems can detect objects in images\n",
243 | "Score: 0.8356 - Neural networks are inspired by biological neurons\n",
244 | "Score: 0.8266 - Natural language processing helps with text analysis\n",
245 | "Score: 0.8141 - Deep learning is a subset of machine learning\n",
246 | "Score: 0.7474 - Machine learning models require training data\n",
247 | "Score: 0.5936 - The quick brown fox jumps over the lazy dog\n"
248 | ]
249 | }
250 | ],
251 | "source": [
252 | "def search_documents(query, doc_collection, doc_embeddings):\n",
253 | " \"\"\"Search for documents similar to query\"\"\"\n",
254 | " # Generate embedding for query\n",
255 | " query_response = client.embeddings.create(\n",
256 | " input=[query],\n",
257 | " model=\"mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-MLX-Q8\"\n",
258 | " )\n",
259 | " query_embedding = query_response.data[0].embedding\n",
260 | " \n",
261 | " # Calculate similarity scores\n",
262 | " similarities = []\n",
263 | " for doc_embedding in doc_embeddings:\n",
264 | " similarity = cosine_similarity_score(query_embedding, doc_embedding)\n",
265 | " similarities.append(similarity)\n",
266 | " \n",
267 | " # Return results with scores\n",
268 | " results = []\n",
269 | " for i, score in enumerate(similarities):\n",
270 | " results.append((doc_collection[i], score))\n",
271 | " \n",
272 | " # Sort by similarity score (highest first)\n",
273 | " return sorted(results, key=lambda x: x[1], reverse=True)\n",
274 | "\n",
275 | "# Example search\n",
276 | "search_results = search_documents(\"How do AI models learn?\", documents, doc_embeddings)\n",
277 | "\n",
278 | "print(\"Search results:\")\n",
279 | "for doc, score in search_results:\n",
280 | " print(f\"Score: {score:.4f} - {doc}\")"
281 | ]
282 | }
283 | ],
284 | "metadata": {
285 | "kernelspec": {
286 | "display_name": "Python 3",
287 | "language": "python",
288 | "name": "python3"
289 | },
290 | "language_info": {
291 | "codemirror_mode": {
292 | "name": "ipython",
293 | "version": 3
294 | },
295 | "file_extension": ".py",
296 | "mimetype": "text/x-python",
297 | "name": "python",
298 | "nbconvert_exporter": "python",
299 | "pygments_lexer": "ipython3",
300 | "version": "3.11.12"
301 | }
302 | },
303 | "nbformat": 4,
304 | "nbformat_minor": 2
305 | }
306 |
--------------------------------------------------------------------------------
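As a small aside (not in the notebook), the pairwise cosine helper can be replaced by one vectorized NumPy computation once the embeddings are stacked into a matrix; `doc_embeddings` is the list built above, and `query_embedding` is assumed to come from the same embeddings endpoint.

```python
import numpy as np

E = np.array(doc_embeddings, dtype=np.float32)         # (n_docs, dim)
E = E / np.linalg.norm(E, axis=1, keepdims=True)       # L2-normalize each document embedding
q = np.array(query_embedding, dtype=np.float32)
q = q / np.linalg.norm(q)
scores = E @ q                                         # cosine similarities, shape (n_docs,)
ranking = np.argsort(-scores)                          # indices of the best-matching documents first
```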
/examples/pdfs/lab03.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cubist38/mlx-openai-server/1c0b466e28a61471b6bb8ed8fe21deb9db7a76a6/examples/pdfs/lab03.pdf
--------------------------------------------------------------------------------
/examples/simple_rag_demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "7fe5da27",
6 | "metadata": {},
7 | "source": [
8 | "# PDF Question Answering with Embedding Search\n",
9 | "\n",
10 | "This notebook demonstrates how to build a simple RAG (Retrieval-Augmented Generation) system that:\n",
11 | "\n",
12 | "1. Extracts and chunks text from PDF documents.\n",
13 | "2. Embeds the text using a local embedding model served via MLX Server.\n",
14 | "3. Stores the embeddings in a FAISS index for fast retrieval.\n",
15 | "4. Answers user queries by retrieving relevant chunks and using a chat model to respond based on context."
16 | ]
17 | },
18 | {
19 | "cell_type": "markdown",
20 | "id": "fc490ab0",
21 | "metadata": {},
22 | "source": [
23 | "Before running the notebook, make sure to launch the local MLX Server by executing the following command in your terminal (`lm`: Text-only model): \n",
24 | "```bash\n",
25 | "mlx-server launch --model-path mlx-community/Qwen3-4B-8bit --model-type lm\n",
26 | "```\n",
27 | "This command starts the MLX API server locally at http://localhost:8000/v1, which exposes an OpenAI-compatible interface. It enables the specified model to be used for both embedding (vector representation of text) and response generation (chat completion).\n",
28 | "\n",
29 | "For this illustration, we use the model `mlx-community/Qwen3-4B-8bit`, a lightweight and efficient language model that supports both tasks. You can substitute this with any other compatible model depending on your use case and hardware capability."
30 | ]
31 | },
32 | {
33 | "cell_type": "markdown",
34 | "id": "036359c3",
35 | "metadata": {},
36 | "source": [
37 | "## Install dependencies\n"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 15,
43 | "id": "2d641abc",
44 | "metadata": {},
45 | "outputs": [
46 | {
47 | "name": "stdout",
48 | "output_type": "stream",
49 | "text": [
50 | "Note: you may need to restart the kernel to use updated packages.\n"
51 | ]
52 | }
53 | ],
54 | "source": [
55 | "# Install required packages\n",
56 | "%pip install -Uq numpy PyMuPDF faiss-cpu openai"
57 | ]
58 | },
59 | {
60 | "cell_type": "markdown",
61 | "id": "ab9a6f78",
62 | "metadata": {},
63 | "source": [
64 | "## Initialize MLX Server client\n"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": 16,
70 | "id": "546b6775",
71 | "metadata": {},
72 | "outputs": [],
73 | "source": [
74 | "# Connect to the local MLX Server that serves embedding and chat models\n",
75 | "from openai import OpenAI\n",
76 | "\n",
77 | "client = OpenAI(\n",
78 | " base_url=\"http://localhost:8000/v1\", # This is your MLX Server endpoint\n",
79 | " api_key=\"fake-api-key\" # Dummy key, not used by local server\n",
80 | ")"
81 | ]
82 | },
83 | {
84 | "cell_type": "markdown",
85 | "id": "c573141f",
86 | "metadata": {},
87 | "source": [
88 | "## Load and chunk PDF document\n",
89 | "\n",
90 | "We load the PDF file and split it into smaller chunks to ensure each chunk fits within the context window of the model."
91 | ]
92 | },
93 | {
94 | "cell_type": "markdown",
95 | "id": "8001354f",
96 | "metadata": {},
97 | "source": [
98 | "### Read PDF\n",
99 | "Extracts text from each page of the PDF."
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": 17,
105 | "id": "27be75d1",
106 | "metadata": {},
107 | "outputs": [],
108 | "source": [
109 | "import fitz # PyMuPDF\n",
110 | "\n",
111 | "def read_pdf(path):\n",
112 | " \"\"\"Extract text from each page of a PDF file.\"\"\"\n",
113 | " doc = fitz.open(path)\n",
114 | " texts = []\n",
115 | " for page in doc:\n",
116 | " text = page.get_text().strip()\n",
117 | " if text: \n",
118 | " texts.append(text)\n",
119 | " doc.close()\n",
120 | " return texts"
121 | ]
122 | },
123 | {
124 | "cell_type": "markdown",
125 | "id": "adae0912",
126 | "metadata": {},
127 | "source": [
128 | "### Chunk Text\n",
129 | "Splits the extracted text into smaller chunks with overlap to preserve context.\n",
130 | "We use the following parameters:\n",
131 |     "- `chunk_size=400`: Maximum number of words in a chunk.\n",
132 |     "- `overlap=200`: Number of words overlapping between consecutive chunks to avoid breaking context too harshly.\n",
133 | "\n",
134 | "Each chunk is created using simple whitespace (`\" \"`) tokenization and rejoined with spaces, which works well for general text."
135 | ]
136 | },
137 | {
138 | "cell_type": "code",
139 | "execution_count": 18,
140 | "id": "e093a75f",
141 | "metadata": {},
142 | "outputs": [],
143 | "source": [
144 | "def chunk_text(texts, chunk_size=400, overlap=200):\n",
145 | " \"\"\"Split text into smaller chunks with overlap for better context preservation.\"\"\"\n",
146 | " chunks = []\n",
147 | " for text in texts:\n",
148 | " words = text.split() \n",
149 | " i = 0\n",
150 | " while i < len(words):\n",
151 | " chunk = words[i:i + chunk_size]\n",
152 | " chunks.append(\" \".join(chunk))\n",
153 | " i += chunk_size - overlap\n",
154 | " return chunks"
155 | ]
156 | },
157 | {
158 | "cell_type": "markdown",
159 | "id": "603c9e71",
160 | "metadata": {},
161 | "source": [
162 | "## Save embeddings and chunks to FAISS"
163 | ]
164 | },
165 | {
166 | "cell_type": "markdown",
167 | "id": "d7d76e7d",
168 | "metadata": {},
169 | "source": [
170 | "### Embed Chunks\n",
171 | "\n",
172 | "Uses the MLX Server to generate embeddings for the text chunks."
173 | ]
174 | },
175 | {
176 | "cell_type": "code",
177 | "execution_count": 19,
178 | "id": "eeb628d3",
179 | "metadata": {},
180 | "outputs": [],
181 | "source": [
182 | "import numpy as np\n",
183 | "\n",
184 | "# Embed chunks using the model served by MLX Server\n",
185 | "def embed_chunks(chunks, model_name):\n",
186 | " \"\"\"Generate embeddings for text chunks using the MLX Server model.\"\"\"\n",
187 | " response = client.embeddings.create(input=chunks, model=model_name)\n",
188 | " embeddings = [np.array(item.embedding).astype('float32') for item in response.data]\n",
189 | " return embeddings"
190 | ]
191 | },
192 | {
193 | "cell_type": "markdown",
194 | "id": "294297ae",
195 | "metadata": {},
196 | "source": [
197 | "### Save to FAISS\n",
198 | "\n",
199 | "Saves the embeddings in a FAISS index and the chunks in a metadata file."
200 | ]
201 | },
202 | {
203 | "cell_type": "code",
204 | "execution_count": null,
205 | "id": "68d39501",
206 | "metadata": {},
207 | "outputs": [],
208 | "source": [
209 | "import faiss\n",
210 | "import os\n",
211 | "import pickle\n",
212 | "import numpy as np\n",
213 | "\n",
214 | "def normalize(vectors):\n",
215 | " \"\"\"\n",
216 | " Normalize a set of vectors.\n",
217 | " \"\"\"\n",
218 | " norms = np.linalg.norm(vectors, axis=1, keepdims=True) # Compute L2 norms for each vector\n",
219 | " return vectors / norms # Divide each vector by its norm to normalize\n",
220 | "\n",
221 | "def save_faiss_index(embeddings, chunks, index_path=\"db/index.faiss\", meta_path=\"db/meta.pkl\"):\n",
222 | " \"\"\"\n",
223 | " The embeddings are stored in a FAISS index, \n",
224 | " and the corresponding text chunks are saved in a metadata file locally. \n",
225 | " \"\"\"\n",
226 | " if not os.path.exists(\"db\"):\n",
227 | " os.makedirs(\"db\") \n",
228 | " dim = len(embeddings[0])\n",
229 | " \n",
230 | " # Normalize the embeddings to unit length for cosine similarity\n",
231 | " # This is required because FAISS's IndexFlatIP uses inner product\n",
232 | " embeddings = normalize(embeddings)\n",
233 | " index = faiss.IndexFlatIP(dim)\n",
234 | " index.add(np.array(embeddings))\n",
235 | " faiss.write_index(index, index_path)\n",
236 | " with open(meta_path, \"wb\") as f:\n",
237 | " pickle.dump(chunks, f)"
238 | ]
239 | },
240 | {
241 | "cell_type": "markdown",
242 | "id": "3d34ab09",
243 | "metadata": {},
244 | "source": [
245 | "Combines the above steps into a single pipeline to process a PDF."
246 | ]
247 | },
248 | {
249 | "cell_type": "code",
250 | "execution_count": 21,
251 | "id": "c2a4f4a2",
252 | "metadata": {},
253 | "outputs": [],
254 | "source": [
255 | "# Full pipeline: Read PDF → Chunk → Embed → Save\n",
256 | "def prepare_pdf(pdf_path, model_name):\n",
257 | " texts = read_pdf(pdf_path)\n",
258 | " chunks = chunk_text(texts)\n",
259 | " embeddings = embed_chunks(chunks, model_name)\n",
260 | " save_faiss_index(embeddings, chunks)"
261 | ]
262 | },
263 | {
264 | "cell_type": "markdown",
265 | "id": "012725f5",
266 | "metadata": {},
267 | "source": [
268 | "## Query PDF using FAISS"
269 | ]
270 | },
271 | {
272 | "cell_type": "markdown",
273 | "id": "4056aa04",
274 | "metadata": {},
275 | "source": [
276 | "### Load FAISS Index"
277 | ]
278 | },
279 | {
280 | "cell_type": "code",
281 | "execution_count": 22,
282 | "id": "4022de75",
283 | "metadata": {},
284 | "outputs": [],
285 | "source": [
286 | "def load_faiss_index(index_path=\"db/index.faiss\", meta_path=\"db/meta.pkl\"):\n",
287 | " \"\"\"Load the FAISS index and corresponding text chunks from disk.\"\"\"\n",
288 | " index = faiss.read_index(index_path)\n",
289 | " with open(meta_path, \"rb\") as f:\n",
290 | " chunks = pickle.load(f)\n",
291 | " return index, chunks"
292 | ]
293 | },
294 | {
295 | "cell_type": "markdown",
296 | "id": "142389ab",
297 | "metadata": {},
298 | "source": [
299 | "### Embed Query\n",
300 | "Embeds the user's query using the same model."
301 | ]
302 | },
303 | {
304 | "cell_type": "code",
305 | "execution_count": 23,
306 | "id": "4791b77c",
307 | "metadata": {},
308 | "outputs": [],
309 | "source": [
310 | "def embed_query(query, model_name):\n",
311 | " \"\"\"Convert a query string into an embedding vector.\"\"\"\n",
312 | " embedding = client.embeddings.create(input=[query], model=model_name).data[0].embedding\n",
313 | " return np.array(embedding).astype('float32')"
314 | ]
315 | },
316 | {
317 | "cell_type": "markdown",
318 | "id": "ada96dd2",
319 | "metadata": {},
320 | "source": [
321 | "### Retrieve Chunks\n",
322 | "Retrieves the top-k most relevant chunks based on the query embedding."
323 | ]
324 | },
325 | {
326 | "cell_type": "code",
327 | "execution_count": 24,
328 | "id": "267bcfe2",
329 | "metadata": {},
330 | "outputs": [],
331 | "source": [
332 | "def retrieve_chunks(query, index, chunks, model_name, top_k=5):\n",
333 | " \"\"\"Retrieve the top-k most relevant chunks from the FAISS index.\"\"\"\n",
334 | " query_vector = embed_query(query, model_name).reshape(1, -1)\n",
335 | " query_vector = normalize(query_vector) # Normalize the query vector\n",
336 | " distances, indices = index.search(query_vector, top_k) # Search for nearest neighbors\n",
337 | " relevant_chunks = [chunks[i] for i in indices[0]] \n",
338 | " return relevant_chunks"
339 | ]
340 | },
341 | {
342 | "cell_type": "markdown",
343 | "id": "6c3491a5",
344 | "metadata": {},
345 | "source": [
346 | "### Generate Answer with Context"
347 | ]
348 | },
349 | {
350 | "cell_type": "code",
351 | "execution_count": 25,
352 | "id": "34a811b3",
353 | "metadata": {},
354 | "outputs": [],
355 | "source": [
356 | "def answer_with_context(query, retrieved_chunks, model_name):\n",
357 | " \"\"\"Generate a response to the query using retrieved chunks as context.\"\"\"\n",
358 | " context = \"\\n\".join(retrieved_chunks)\n",
359 | " prompt = f\"\"\"You are a helpful assistant. Use the context below to answer the question.\n",
360 | " Context:\n",
361 | " {context}\n",
362 | " Question: {query}\n",
363 | " Answer:\"\"\"\n",
364 | " \n",
365 | " response = client.chat.completions.create(\n",
366 | " model=model_name,\n",
367 | " messages=[{\"role\": \"user\", \"content\": prompt}]\n",
368 | " )\n",
369 | " return response.choices[0].message.content"
370 | ]
371 | },
372 | {
373 | "cell_type": "markdown",
374 | "id": "487af341",
375 | "metadata": {},
376 | "source": [
377 | "Combines the query steps into a single function."
378 | ]
379 | },
380 | {
381 | "cell_type": "code",
382 | "execution_count": 26,
383 | "id": "2f09e17d",
384 | "metadata": {},
385 | "outputs": [],
386 | "source": [
387 | "# Full pipeline: Query → Embed → Retrieve → Answer\n",
388 | "def query_pdf(query, model_name=\"mlx-community/Qwen3-4B-8bit\"):\n",
389 | " index, chunks = load_faiss_index()\n",
390 | " top_chunks = retrieve_chunks(query, index, chunks, model_name)\n",
391 | " return answer_with_context(query, top_chunks, model_name)"
392 | ]
393 | },
394 | {
395 | "cell_type": "markdown",
396 | "id": "71c53682",
397 | "metadata": {},
398 | "source": [
399 | "## Example Usage"
400 | ]
401 | },
402 | {
403 | "cell_type": "markdown",
404 | "id": "7b022805",
405 | "metadata": {},
406 | "source": [
407 |     "Index the text chunks from the PDF into FAISS using the Qwen3-4B-8bit model"
408 | ]
409 | },
410 | {
411 | "cell_type": "code",
412 | "execution_count": 27,
413 | "id": "18552003",
414 | "metadata": {},
415 | "outputs": [],
416 | "source": [
417 | "prepare_pdf(\"./pdfs/lab03.pdf\", \"mlx-community/Qwen3-4B-8bit\")"
418 | ]
419 | },
420 | {
421 | "cell_type": "markdown",
422 | "id": "af4fcab9",
423 | "metadata": {},
424 | "source": [
425 | "Sample query:"
426 | ]
427 | },
428 | {
429 | "cell_type": "code",
430 | "execution_count": 28,
431 | "id": "9c333397",
432 | "metadata": {},
433 | "outputs": [
434 | {
435 | "name": "stdout",
436 | "output_type": "stream",
437 | "text": [
438 | "Query: What submissions do I need to submit in this lab?\n",
439 | "Response: For this lab, you need to submit the following:\n",
440 | "\n",
441 | "1. **StudentID1_StudentID2_Report.pdf**: \n",
442 | " - A short report explaining your solution, any problems encountered, and any approaches you tried that didn't work. This report should not include your source code but should clarify your solution and any issues you faced.\n",
443 | "\n",
444 | "2. **.patch**: \n",
445 | " - A git diff file showing the changes you made to the xv6 codebase. You can generate this by running the command: \n",
446 | " ```bash\n",
447 | " $ git diff > .patch\n",
448 | " ```\n",
449 | "\n",
450 | "3. **Zip file of xv6**: \n",
451 | " - A zip file containing the modified xv6 codebase. The code should be in a clean state (i.e., after running `make clean`). The filename should follow the format: \n",
452 | " ```bash\n",
453 | " .zip\n",
454 | " ``` \n",
455 | " For example, if the students' IDs are 2312001 and 2312002, the filename would be: \n",
456 | " ```bash\n",
457 | " 2312001_2312002.zip\n",
458 | " ```\n",
459 | "\n",
460 | "Make sure to follow these submission guidelines carefully to ensure your work is graded properly.\n"
461 | ]
462 | }
463 | ],
464 | "source": [
465 | "# Ask a question related to the content of the PDF\n",
466 | "query = \"What submissions do I need to submit in this lab?\"\n",
467 | "print(\"Query: \", query)\n",
468 | "response = query_pdf(query, model_name=\"mlx-community/Qwen3-4B-8bit\")\n",
469 | "print(\"Response: \", response)"
470 | ]
471 | },
472 | {
473 | "cell_type": "code",
474 | "execution_count": 29,
475 | "id": "b8eee4af",
476 | "metadata": {},
477 | "outputs": [
478 | {
479 | "name": "stdout",
480 | "output_type": "stream",
481 | "text": [
482 | "Query: What is the hint for Lab 4.2 – Speed up system calls?\n",
483 | "Response: The hint for Lab 4.2 – Speed up system calls is to choose permission bits that allow userspace to only read the page. This ensures that the shared page between userspace and the kernel can be accessed by userspace for reading the PID, but not modified, which is necessary for the optimization of the getpid() system call.\n"
484 | ]
485 | }
486 | ],
487 | "source": [
488 | "# Ask a question related to the content of the PDF\n",
489 | "query = \"What is the hint for Lab 4.2 – Speed up system calls?\"\n",
490 | "print(\"Query: \", query)\n",
491 | "response = query_pdf(query, model_name=\"mlx-community/Qwen3-4B-8bit\")\n",
492 | "print(\"Response: \", response)"
493 | ]
494 | }
495 | ],
496 | "metadata": {
497 | "kernelspec": {
498 | "display_name": "base",
499 | "language": "python",
500 | "name": "python3"
501 | },
502 | "language_info": {
503 | "codemirror_mode": {
504 | "name": "ipython",
505 | "version": 3
506 | },
507 | "file_extension": ".py",
508 | "mimetype": "text/x-python",
509 | "name": "python",
510 | "nbconvert_exporter": "python",
511 | "pygments_lexer": "ipython3",
512 | "version": "3.12.2"
513 | }
514 | },
515 | "nbformat": 4,
516 | "nbformat_minor": 5
517 | }
518 |
--------------------------------------------------------------------------------
/examples/vlm_embeddings_examples.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Vision-Language Model (VLM) Embeddings with MLX Server\n",
8 | "\n",
9 | "This notebook demonstrates how to leverage the embeddings endpoint of MLX Server through its OpenAI-compatible API. Vision-Language Models (VLMs) can process both images and text, allowing for multimodal understanding and representation.\n"
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "\n",
17 | "## Introduction\n",
18 | "\n",
19 | "MLX Server provides an efficient way to serve multimodal models on Apple Silicon. In this notebook, we'll explore how to:\n",
20 | "\n",
21 | "- Generate embeddings for text and images\n",
22 | "- Work with the OpenAI-compatible API\n",
23 | "- Calculate similarity between text and image representations\n",
24 | "- Understand how these embeddings can be used for practical applications\n",
25 | "\n",
26 | "Embeddings are high-dimensional vector representations of content that capture semantic meaning, making them useful for search, recommendation systems, and other AI applications."
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {},
32 | "source": [
33 | "## 1. Setup and API Connection\n",
34 | "\n",
35 | "- A local server endpoint (`http://localhost:8000/v1`)\n",
36 | "- A placeholder API key (since MLX Server doesn't require authentication by default)\n",
37 | "\n",
38 | "Make sure you have MLX Server running locally before executing this notebook."
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": 1,
44 | "metadata": {},
45 | "outputs": [],
46 | "source": [
47 | "# Import the OpenAI client for API communication\n",
48 | "from openai import OpenAI\n",
49 | "\n",
50 | "# Connect to the local MLX Server with OpenAI-compatible API\n",
51 | "client = OpenAI(\n",
52 | " base_url=\"http://localhost:8000/v1\",\n",
53 | " api_key=\"fake-api-key\",\n",
54 | ")"
55 | ]
56 | },
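57 |   {
58 |    "cell_type": "markdown",
59 |    "metadata": {},
60 |    "source": [
61 |     "Optionally, we can sanity-check the connection before moving on. The cell below is a small sketch that assumes the server exposes the standard OpenAI-compatible `/v1/models` route; if your build of MLX Server does not serve that route, simply skip this cell."
62 |    ]
63 |   },
64 |   {
65 |    "cell_type": "code",
66 |    "execution_count": null,
67 |    "metadata": {},
68 |    "outputs": [],
69 |    "source": [
70 |     "# Optional sanity check: list the models the server reports.\n",
71 |     "# Assumes the OpenAI-compatible /v1/models endpoint is available on this server.\n",
72 |     "for model in client.models.list().data:\n",
73 |     "    print(model.id)"
74 |    ]
75 |   },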
57 | {
58 | "cell_type": "markdown",
59 | "metadata": {},
60 | "source": [
61 | "## 2. Image Processing for API Requests\n",
62 | "\n",
63 | "When working with image inputs, we need to prepare them in a format that the API can understand. The OpenAI-compatible API expects images to be provided as base64-encoded data URIs.\n",
64 | "\n",
65 | "Below, we'll import the necessary libraries and define a helper function to convert PIL Image objects to the required format."
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": 2,
71 | "metadata": {},
72 | "outputs": [],
73 | "source": [
74 | "from PIL import Image\n",
75 | "from io import BytesIO\n",
76 | "import base64"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": 3,
82 | "metadata": {},
83 | "outputs": [],
84 | "source": [
85 | "# To send images to the API, we need to convert them to base64-encoded strings in a data URI format.\n",
86 | "\n",
87 | "def image_to_base64(image: Image.Image):\n",
88 | " \"\"\"\n",
89 | " Convert a PIL Image to a base64-encoded data URI string that can be sent to the API.\n",
90 | " \n",
91 | " Args:\n",
92 | " image: A PIL Image object\n",
93 | " \n",
94 | " Returns:\n",
95 | " A data URI string with the base64-encoded image\n",
96 | " \"\"\"\n",
97 | " # Convert image to bytes\n",
98 | " buffer = BytesIO()\n",
99 | " image.save(buffer, format=\"PNG\")\n",
100 | " buffer.seek(0)\n",
101 | " image_data = buffer.getvalue()\n",
102 | " \n",
103 | " # Encode as base64\n",
104 | " image_base64 = base64.b64encode(image_data).decode('utf-8')\n",
105 | " \n",
106 | " # Create the data URI format required by the API\n",
107 | " mime_type = \"image/png\" \n",
108 | " image_uri = f\"data:{mime_type};base64,{image_base64}\"\n",
109 | " \n",
110 | " return image_uri"
111 | ]
112 | },
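113 |   {
114 |    "cell_type": "markdown",
115 |    "metadata": {},
116 |    "source": [
117 |     "As a minimal check of the helper, we can encode a tiny in-memory image and confirm the result starts with the expected `data:image/png;base64,` prefix. This uses only the PIL import above, so no image files are needed yet."
118 |    ]
119 |   },
120 |   {
121 |    "cell_type": "code",
122 |    "execution_count": null,
123 |    "metadata": {},
124 |    "outputs": [],
125 |    "source": [
126 |     "# Minimal check: encode a small solid-color image and inspect the data URI prefix.\n",
127 |     "sample = Image.new(\"RGB\", (8, 8), color=(0, 128, 0))\n",
128 |     "sample_uri = image_to_base64(sample)\n",
129 |     "print(sample_uri[:40] + \"...\")"
130 |    ]
131 |   },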
113 | {
114 | "cell_type": "markdown",
115 | "metadata": {},
116 | "source": [
117 | "## 3. Loading and Preparing an Image\n",
118 | "\n",
119 | "Now we'll load a sample image (a green dog in this case) and convert it to the base64 format required by the API. This image will be used to generate embeddings in the subsequent steps."
120 | ]
121 | },
122 | {
123 | "cell_type": "code",
124 | "execution_count": 5,
125 | "metadata": {},
126 | "outputs": [],
127 | "source": [
128 | "image = Image.open(\"images/green_dog.jpeg\")\n",
129 | "image_uri = image_to_base64(image)"
130 | ]
131 | },
132 | {
133 | "cell_type": "markdown",
134 | "metadata": {},
135 | "source": [
136 | "## 4. Generating Embeddings"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": 12,
142 | "metadata": {},
143 | "outputs": [],
144 | "source": [
145 | "# Generate embedding for a single text input\n",
146 | "prompt = \"Describe the image in detail\"\n",
147 | "image_embedding = client.embeddings.create(\n",
148 | " input=[prompt],\n",
149 | " model=\"mlx-community/Qwen2.5-VL-3B-Instruct-4bit\",\n",
150 | " extra_body = {\n",
151 | " \"image_url\": image_uri\n",
152 | " }\n",
153 | ").data[0].embedding\n",
154 | "\n",
155 | "text = \"A green dog looking at the camera\"\n",
156 | "text_embedding = client.embeddings.create(\n",
157 | " input=[text],\n",
158 | " model=\"mlx-community/Qwen2.5-VL-3B-Instruct-4bit\"\n",
159 | ").data[0].embedding"
160 | ]
161 | },
162 | {
163 | "cell_type": "markdown",
164 | "metadata": {},
165 | "source": [
166 | "## 5. Comparing Text and Image Embeddings\n",
167 | "\n",
168 | "One of the powerful features of VLM embeddings is that they create a shared vector space for both text and images. This means we can directly compare how similar a text description is to an image's content by calculating the cosine similarity between their embeddings.\n",
169 | "\n",
170 | "A higher similarity score (closer to 1.0) indicates that the text description closely matches the image content according to the model's understanding."
171 | ]
172 | },
173 | {
174 | "cell_type": "code",
175 | "execution_count": 13,
176 | "metadata": {},
177 | "outputs": [
178 | {
179 | "name": "stdout",
180 | "output_type": "stream",
181 | "text": [
182 | "0.8473370724651375\n"
183 | ]
184 | }
185 | ],
186 | "source": [
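187 |     "# Compare the image and text embeddings with cosine similarity (values closer to 1.0 mean a closer match).\n",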
187 | "import numpy as np\n",
188 | "\n",
189 | "def cosine_similarity(a, b):\n",
190 | " a = np.array(a)\n",
191 | " b = np.array(b)\n",
192 | " return np.dot(a, b)\n",
193 | "\n",
194 | "similarity = cosine_similarity(image_embedding, text_embedding)\n",
195 | "print(similarity)"
196 | ]
197 |   },
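198 |   {
199 |    "cell_type": "markdown",
200 |    "metadata": {},
201 |    "source": [
202 |     "## 6. Ranking Candidate Captions\n",
203 |     "\n",
204 |     "As a small extension of the comparison above, we can score several candidate captions against the same image embedding and sort them by cosine similarity. The captions below are made-up examples and the exact scores depend on the model you serve, so treat this as a sketch of the pattern rather than a benchmark."
205 |    ]
206 |   },
207 |   {
208 |    "cell_type": "code",
209 |    "execution_count": null,
210 |    "metadata": {},
211 |    "outputs": [],
212 |    "source": [
213 |     "# Score several candidate captions against the image embedding computed earlier.\n",
214 |     "# The captions are illustrative examples, not part of the original dataset.\n",
215 |     "captions = [\n",
216 |     "    \"A green dog looking at the camera\",\n",
217 |     "    \"A red sports car parked on the street\",\n",
218 |     "    \"A bowl of fruit on a wooden table\",\n",
219 |     "]\n",
220 |     "\n",
221 |     "scores = []\n",
222 |     "for caption in captions:\n",
223 |     "    caption_embedding = client.embeddings.create(\n",
224 |     "        input=[caption],\n",
225 |     "        model=\"mlx-community/Qwen2.5-VL-3B-Instruct-4bit\"\n",
226 |     "    ).data[0].embedding\n",
227 |     "    scores.append((caption, cosine_similarity(image_embedding, caption_embedding)))\n",
228 |     "\n",
229 |     "# The highest-scoring caption should be the one that best matches the image.\n",
230 |     "for caption, score in sorted(scores, key=lambda item: item[1], reverse=True):\n",
231 |     "    print(f\"{score:.4f}  {caption}\")"
232 |    ]
233 |   }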
198 | ],
199 | "metadata": {
200 | "kernelspec": {
201 | "display_name": "Python 3",
202 | "language": "python",
203 | "name": "python3"
204 | },
205 | "language_info": {
206 | "codemirror_mode": {
207 | "name": "ipython",
208 | "version": 3
209 | },
210 | "file_extension": ".py",
211 | "mimetype": "text/x-python",
212 | "name": "python",
213 | "nbconvert_exporter": "python",
214 | "pygments_lexer": "ipython3",
215 | "version": "3.11.12"
216 | }
217 | },
218 | "nbformat": 4,
219 | "nbformat_minor": 2
220 | }
221 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | from app import __version__
3 |
4 |
5 | setup(
6 | name="mlx-openai-server",
7 | url="https://github.com/cubist38/mlx-openai-server",
8 | author="Huy Vuong",
9 | author_email="cubist38@gmail.com",
10 | version=__version__,
11 | description="A high-performance API server that provides OpenAI-compatible endpoints for MLX models. Built with Python and FastAPI, it enables efficient, scalable, and user-friendly local deployment of MLX-based vision and language models with an OpenAI-compatible interface. Perfect for developers looking to run MLX models locally while maintaining compatibility with existing OpenAI-based applications.",
12 |     long_description=open("README.md", encoding="utf-8").read(),
13 | long_description_content_type="text/markdown",
14 | packages=find_packages(),
15 | install_requires=[
16 | "mlx-vlm==0.1.27",
17 | "mlx-lm==0.25.2",
18 | "fastapi",
19 | "uvicorn",
20 | "Pillow",
21 | "click",
22 | "loguru",
23 | ],
24 | extras_require={
25 | "dev": [
26 | "pytest",
27 | "black",
28 | "isort",
29 | "flake8",
30 | ]
31 | },
32 | entry_points={
33 | "console_scripts": [
34 | "mlx-openai-server=app.cli:cli",
35 | ],
36 | },
37 | python_requires=">=3.11",
38 | classifiers=[
39 | "Programming Language :: Python :: 3",
40 | "License :: OSI Approved :: MIT License",
41 | "Operating System :: OS Independent",
42 | ],
43 | )
44 |
--------------------------------------------------------------------------------
/tests/test_base_tool_parser.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from app.handler.parser.base import BaseToolParser
4 |
5 |
6 | class TestBaseToolParser(unittest.TestCase):
7 | def setUp(self):
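8 |         # Each case simulates a streamed tool call: the raw tool-call text is split on a
9 |         # throwaway delimiter ('#' or '@@') to mimic the small chunks a streaming model
10 |         # would emit; expected_outputs lists what parse_stream should yield per chunk.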
8 | self.test_cases = [
9 | {
10 | "name": "simple function call",
11 | "chunks": '''##
12 | #{"#name#":# "#get#_weather#",# "#arguments#":# {"#city#":# "#H#ue#"}}
13 | ##
14 | ##
15 | #{"#name#":# "#get#_weather#",# "#arguments#":# {"#city#":# "#Sy#dney#"}}
16 | ###'''.split('#')
17 | ,
18 | "expected_outputs": [
19 | {'name': 'get_weather', 'arguments': ''},
20 | {'name': None, 'arguments': ' {"'},
21 | {'name': None, 'arguments': 'city'},
22 | {'name': None, 'arguments': '":'},
23 | {'name': None, 'arguments': ' "'},
24 | {'name': None, 'arguments': 'H'},
25 | {'name': None, 'arguments': 'ue'},
26 | {'name': None, 'arguments': '"}'},
27 | '\n',
28 | {'name': 'get_weather', 'arguments': ''},
29 | {'name': None, 'arguments': ' {"'},
30 | {'name': None, 'arguments': 'city'},
31 | {'name': None, 'arguments': '":'},
32 | {'name': None, 'arguments': ' "'},
33 | {'name': None, 'arguments': 'Sy'},
34 | {'name': None, 'arguments': 'dney'},
35 | {'name': None, 'arguments': '"}'},
36 | ]
37 | },
38 | {
39 | "name": "code function call",
40 | "chunks": r'''@@
41 | @@{"@@name@@":@@ "@@python@@",@@ "@@arguments@@":@@ {"@@code@@":@@ "@@def@@ calculator@@(a@@,@@ b@@,@@ operation@@):\@@n@@ @@ if@@ operation@@ ==@@ '@@add@@'\@@n@@ @@ return@@ a@@ +@@ b@@\n@@ @@ if@@ operation@@ ==@@ '@@subtract@@'\@@n@@ @@ return@@ a@@ -@@ b@@\n@@ @@ if@@ operation@@ ==@@ '@@multiply@@'\@@n@@ @@ return@@ a@@ *@@ b@@\n@@ @@ if@@ operation@@ ==@@ '@@divide@@'\@@n@@ @@ return@@ a@@ /@@ b@@\n@@ @@ return@@ '@@Invalid@@ operation@@'@@"}}
42 | @@@@@@'''.split('@@')
43 | ,
44 | "expected_outputs": [
45 | {'name': 'python', 'arguments': ''},
46 | {'name': None, 'arguments': ' {"'},
47 | {'name': None, 'arguments': 'code'},
48 | {'name': None, 'arguments': '":'},
49 | {'name': None, 'arguments': ' "'},
50 | {'name': None, 'arguments': 'def'},
51 | {'name': None, 'arguments': ' calculator'},
52 | {'name': None, 'arguments': '(a'},
53 | {'name': None, 'arguments': ','},
54 | {'name': None, 'arguments': ' b'},
55 | {'name': None, 'arguments': ','},
56 | {'name': None, 'arguments': ' operation'},
57 | {'name': None, 'arguments': '):\\'},
58 | {'name': None, 'arguments': 'n'},
59 | {'name': None, 'arguments': ' '},
60 | {'name': None, 'arguments': ' if'},
61 | {'name': None, 'arguments': ' operation'},
62 | {'name': None, 'arguments': ' =='},
63 | {'name': None, 'arguments': " '"},
64 | {'name': None, 'arguments': 'add'},
65 | {'name': None, 'arguments': "'\\"},
66 | {'name': None, 'arguments': 'n'},
67 | {'name': None, 'arguments': ' '},
68 | {'name': None, 'arguments': ' return'},
69 | {'name': None, 'arguments': ' a'},
70 | {'name': None, 'arguments': ' +'},
71 | {'name': None, 'arguments': ' b'},
72 | {'name': None, 'arguments': '\\n'},
73 | {'name': None, 'arguments': ' '},
74 | {'name': None, 'arguments': ' if'},
75 | {'name': None, 'arguments': ' operation'},
76 | {'name': None, 'arguments': ' =='},
77 | {'name': None, 'arguments': " '"},
78 | {'name': None, 'arguments': 'subtract'},
79 | {'name': None, 'arguments': "'\\"},
80 | {'name': None, 'arguments': 'n'},
81 | {'name': None, 'arguments': ' '},
82 | {'name': None, 'arguments': ' return'},
83 | {'name': None, 'arguments': ' a'},
84 | {'name': None, 'arguments': ' -'},
85 | {'name': None, 'arguments': ' b'},
86 | {'name': None, 'arguments': '\\n'},
87 | {'name': None, 'arguments': ' '},
88 | {'name': None, 'arguments': ' if'},
89 | {'name': None, 'arguments': ' operation'},
90 | {'name': None, 'arguments': ' =='},
91 | {'name': None, 'arguments': " '"},
92 | {'name': None, 'arguments': 'multiply'},
93 | {'name': None, 'arguments': "'\\"},
94 | {'name': None, 'arguments': 'n'},
95 | {'name': None, 'arguments': ' '},
96 | {'name': None, 'arguments': ' return'},
97 | {'name': None, 'arguments': ' a'},
98 | {'name': None, 'arguments': ' *'},
99 | {'name': None, 'arguments': ' b'},
100 | {'name': None, 'arguments': '\\n'},
101 | {'name': None, 'arguments': ' '},
102 | {'name': None, 'arguments': ' if'},
103 | {'name': None, 'arguments': ' operation'},
104 | {'name': None, 'arguments': ' =='},
105 | {'name': None, 'arguments': " '"},
106 | {'name': None, 'arguments': 'divide'},
107 | {'name': None, 'arguments': "'\\"},
108 | {'name': None, 'arguments': 'n'},
109 | {'name': None, 'arguments': ' '},
110 | {'name': None, 'arguments': ' return'},
111 | {'name': None, 'arguments': ' a'},
112 | {'name': None, 'arguments': ' /'},
113 | {'name': None, 'arguments': ' b'},
114 | {'name': None, 'arguments': '\\n'},
115 | {'name': None, 'arguments': ' '},
116 | {'name': None, 'arguments': ' return'},
117 | {'name': None, 'arguments': " '"},
118 | {'name': None, 'arguments': 'Invalid'},
119 | {'name': None, 'arguments': ' operation'},
120 | {'name': None, 'arguments': "'"},
121 | {'name': None, 'arguments': '"}'},
122 | ]
123 | },
124 | ]
125 |
126 | def test_parse_stream(self):
127 | for test_case in self.test_cases:
128 | with self.subTest(msg=test_case["name"]):
129 | parser = BaseToolParser("", "")
130 | outputs = []
131 |
132 | for chunk in test_case["chunks"]:
133 | result = parser.parse_stream(chunk)
134 | if result:
135 | outputs.append(result)
136 |
137 |
138 | self.assertEqual(len(outputs), len(test_case["expected_outputs"]),
139 | f"Expected {len(test_case['expected_outputs'])} outputs, got {len(outputs)}")
140 |
141 | for i, (output, expected) in enumerate(zip(outputs, test_case["expected_outputs"])):
142 | self.assertEqual(output, expected,
143 | f"Chunk {i}: Expected {expected}, got {output}")
144 |
145 | if __name__ == '__main__':
146 | unittest.main()
147 |
--------------------------------------------------------------------------------