├── .github
│   └── workflows
│       └── publish.yml
├── .gitignore
├── MANIFEST.in
├── Makefile
├── README.md
├── app
│   ├── __init__.py
│   ├── api
│   │   ├── __init__.py
│   │   └── endpoints.py
│   ├── cli.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── audio_processor.py
│   │   ├── base_processor.py
│   │   ├── image_processor.py
│   │   ├── queue.py
│   │   └── video_processor.py
│   ├── handler
│   │   ├── __init__.py
│   │   ├── mflux.py
│   │   ├── mlx_embeddings.py
│   │   ├── mlx_lm.py
│   │   ├── mlx_vlm.py
│   │   ├── mlx_whisper.py
│   │   └── parser
│   │       ├── __init__.py
│   │       ├── base.py
│   │       ├── glm4_moe.py
│   │       ├── harmony.py
│   │       └── qwen3.py
│   ├── main.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── mflux.py
│   │   ├── mlx_embeddings.py
│   │   ├── mlx_lm.py
│   │   ├── mlx_vlm.py
│   │   └── mlx_whisper.py
│   ├── schemas
│   │   ├── __init__.py
│   │   └── openai.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── dill.py
│   │   ├── errors.py
│   │   └── outlines_transformer_tokenizer.py
│   └── version.py
├── configure_mlx.sh
├── examples
│   ├── audio_examples.ipynb
│   ├── audios
│   │   ├── audio.wav
│   │   └── podcast.wav
│   ├── embedding_examples.ipynb
│   ├── image_edit.ipynb
│   ├── image_generations.ipynb
│   ├── images
│   │   ├── attention.png
│   │   ├── china.png
│   │   ├── green_dog.jpeg
│   │   └── password.jpg
│   ├── lm_embeddings_examples.ipynb
│   ├── pdfs
│   │   └── lab03.pdf
│   ├── simple_rag_demo.ipynb
│   ├── structured_outputs_examples.ipynb
│   ├── transcription_examples.ipynb
│   ├── videos
│   │   └── demo.mp4
│   ├── vision_examples.ipynb
│   └── vlm_embeddings_examples.ipynb
├── setup.py
└── tests
    └── test_base_tool_parser.py
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish Python 🐍 distribution 📦 to PyPI
2 |
3 | on:
4 | push:
5 | tags:
6 | - 'v*' # Triggers on version tags like v1.0.0
7 |
8 | jobs:
9 | build-and-publish:
10 | runs-on: macos-latest
11 |
12 | steps:
13 | - uses: actions/checkout@v4
14 |
15 | - name: Set up Python
16 | uses: actions/setup-python@v5
17 | with:
18 | python-version: '3.11'
19 |
20 | - name: Install build tools
21 | run: |
22 | python -m pip install --upgrade pip
23 | pip install build twine
24 |
25 | - name: Build package
26 | run: python -m build
27 |
28 | - name: Publish package to PyPI
29 | env:
30 | TWINE_USERNAME: __token__
31 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
32 | run: twine upload dist/*
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 | oai-compat-server
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 | # ignore .DS_Store
11 | .DS_Store
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 | cover/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 | db.sqlite3
66 | db.sqlite3-journal
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | .pybuilder/
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | # For a library or package, you might want to ignore these files since the code is
91 | # intended to run in multiple environments; otherwise, check them in:
92 | # .python-version
93 |
94 | # pipenv
95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
98 | # install all needed dependencies.
99 | #Pipfile.lock
100 |
101 | # UV
102 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
103 | # This is especially recommended for binary packages to ensure reproducibility, and is more
104 | # commonly ignored for libraries.
105 | #uv.lock
106 |
107 | # poetry
108 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
109 | # This is especially recommended for binary packages to ensure reproducibility, and is more
110 | # commonly ignored for libraries.
111 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
112 | #poetry.lock
113 |
114 | # pdm
115 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
116 | #pdm.lock
117 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
118 | # in version control.
119 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
120 | .pdm.toml
121 | .pdm-python
122 | .pdm-build/
123 |
124 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
125 | __pypackages__/
126 |
127 | # Celery stuff
128 | celerybeat-schedule
129 | celerybeat.pid
130 |
131 | # SageMath parsed files
132 | *.sage.py
133 |
134 | # Environments
135 | .env
136 | .venv
137 | env/
138 | venv/
139 | ENV/
140 | env.bak/
141 | venv.bak/
142 |
143 | # Spyder project settings
144 | .spyderproject
145 | .spyproject
146 |
147 | # Rope project settings
148 | .ropeproject
149 |
150 | # mkdocs documentation
151 | /site
152 |
153 | # mypy
154 | .mypy_cache/
155 | .dmypy.json
156 | dmypy.json
157 |
158 | # Pyre type checker
159 | .pyre/
160 |
161 | # pytype static type analyzer
162 | .pytype/
163 |
164 | # Cython debug symbols
165 | cython_debug/
166 |
167 | # PyCharm
168 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
169 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
170 | # and can be added to the global gitignore or merged into this file. For a more nuclear
171 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
172 | #.idea/
173 |
174 | # Ruff stuff:
175 | .ruff_cache/
176 |
177 | # PyPI configuration file
178 | .pypirc
179 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.md
2 | include requirements.txt
3 | include MANIFEST.in
4 | include setup.py
5 | recursive-include app *
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | run:
2 | mlx-server launch \
3 | --model-path mlx-community/Qwen3-1.7B-4bit \
4 | --model-type lm \
5 | --max-concurrency 1 \
6 | --queue-timeout 300 \
7 | --queue-size 100
8 |
9 | install:
10 | pip install -e .
--------------------------------------------------------------------------------
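The `run` target above launches the server with defaults of host 0.0.0.0 and port 8000. A minimal client-side sketch, assuming the usual OpenAI-style `/v1` routes (endpoints.py is not shown in this section), using the official `openai` Python client:

```python
# Hypothetical client call against a locally running mlx-server instance.
# Assumes the OpenAI-compatible routes are served under /v1 on port 8000.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

response = client.chat.completions.create(
    model="mlx-community/Qwen3-1.7B-4bit",  # same model the Makefile launches
    messages=[{"role": "user", "content": "Say hello from MLX."}],
)
print(response.choices[0].message.content)
```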
/app/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from app.version import __version__
3 |
4 | # Suppress transformers warnings
5 | os.environ['TRANSFORMERS_VERBOSITY'] = 'error'
6 |
7 | __all__ = ["__version__"]
--------------------------------------------------------------------------------
/app/api/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/app/cli.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import asyncio
3 | import click
4 | import uvicorn
5 | from loguru import logger
6 | from functools import lru_cache
7 | from app.version import __version__
8 | from app.main import setup_server
9 |
10 | class Config:
11 | """Configuration container for server parameters."""
12 | def __init__(self, model_path, model_type, context_length, port, host, max_concurrency, queue_timeout, queue_size, disable_auto_resize=False, quantize=8, config_name=None, lora_paths=None, lora_scales=None, log_file=None, no_log_file=False, log_level="INFO"):
13 | self.model_path = model_path
14 | self.model_type = model_type
15 | self.context_length = context_length
16 | self.port = port
17 | self.host = host
18 | self.max_concurrency = max_concurrency
19 | self.queue_timeout = queue_timeout
20 | self.queue_size = queue_size
21 | self.disable_auto_resize = disable_auto_resize
22 | self.quantize = quantize
23 | self.config_name = config_name
24 | self.log_file = log_file
25 | self.no_log_file = no_log_file
26 | self.log_level = log_level
27 |
28 | # Process comma-separated LoRA paths and scales
29 | if lora_paths:
30 | self.lora_paths = [path.strip() for path in lora_paths.split(',') if path.strip()]
31 | else:
32 | self.lora_paths = None
33 |
34 | if lora_scales:
35 | self.lora_scales = [float(scale.strip()) for scale in lora_scales.split(',') if scale.strip()]
36 | else:
37 | self.lora_scales = None
38 |
39 |
40 | @property
41 | def model_identifier(self):
42 | """Get the appropriate model identifier based on model type."""
43 | # For Flux models, we always use model_path (local directory path)
44 | return self.model_path
45 |
46 |
47 | # Configure basic logging for CLI (will be overridden by main.py)
48 | logger.remove() # Remove default handler
49 | logger.add(
50 | sys.stderr,
51 | format="{time:YYYY-MM-DD HH:mm:ss} | "
52 | "{level: <8} | "
53 | "{name}:{function}:{line} | "
54 | "✦ {message}",
55 | colorize=True,
56 | level="INFO"
57 | )
58 |
59 |
60 | @click.group()
61 | @click.version_option(
62 | version=__version__,
63 | message="""
64 | ✨ %(prog)s - OpenAI Compatible API Server for MLX models ✨
65 | ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
66 | 🚀 Version: %(version)s
67 | """
68 | )
69 | def cli():
70 | """MLX Server - OpenAI Compatible API for MLX models."""
71 | pass
72 |
73 |
74 | @lru_cache(maxsize=1)
75 | def get_server_config(model_path, model_type, context_length, port, host, max_concurrency, queue_timeout, queue_size, quantize, config_name, lora_paths, lora_scales, disable_auto_resize, log_file, no_log_file, log_level):
76 | """Cache and return server configuration to avoid redundant processing."""
77 | return Config(
78 | model_path=model_path,
79 | model_type=model_type,
80 | context_length=context_length,
81 | port=port,
82 | host=host,
83 | max_concurrency=max_concurrency,
84 | queue_timeout=queue_timeout,
85 | queue_size=queue_size,
86 | disable_auto_resize=disable_auto_resize,
87 | quantize=quantize,
88 | config_name=config_name,
89 | lora_paths=lora_paths,
90 | lora_scales=lora_scales,
91 | log_file=log_file,
92 | no_log_file=no_log_file,
93 | log_level=log_level
94 | )
95 |
96 |
97 | def print_startup_banner(args):
98 | """Display beautiful startup banner with configuration details."""
99 | logger.info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
100 | logger.info(f"✨ MLX Server v{__version__} Starting ✨")
101 | logger.info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
102 | logger.info(f"🔮 Model Path: {args.model_path}")
103 | logger.info(f"🔮 Model Type: {args.model_type}")
104 | if args.context_length:
105 | logger.info(f"🔮 Context Length: {args.context_length}")
106 | logger.info(f"🌐 Host: {args.host}")
107 | logger.info(f"🔌 Port: {args.port}")
108 | logger.info(f"⚡ Max Concurrency: {args.max_concurrency}")
109 | logger.info(f"⏱️ Queue Timeout: {args.queue_timeout} seconds")
110 | logger.info(f"📊 Queue Size: {args.queue_size}")
111 | if args.model_type in ["image-generation", "image-edit"]:
112 | logger.info(f"🔮 Quantize: {args.quantize}")
113 | logger.info(f"🔮 Config Name: {args.config_name}")
114 | if args.lora_paths:
115 | logger.info(f"🔮 LoRA Paths: {args.lora_paths}")
116 | if args.lora_scales:
117 | logger.info(f"🔮 LoRA Scales: {args.lora_scales}")
118 | if hasattr(args, 'disable_auto_resize') and args.disable_auto_resize and args.model_type == "multimodal":
119 | logger.info(f"🖼️ Auto-resize: Disabled")
120 | logger.info(f"📝 Log Level: {args.log_level}")
121 | if args.no_log_file:
122 | logger.info(f"📝 File Logging: Disabled")
123 | elif args.log_file:
124 | logger.info(f"📝 Log File: {args.log_file}")
125 | else:
126 | logger.info(f"📝 Log File: logs/app.log (default)")
127 | logger.info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
128 |
129 | @cli.command()
130 | @click.option(
131 | "--model-path",
132 | help="Path to the model (required for lm, multimodal, embeddings, image-generation, image-edit, whisper model types). With `image-generation` or `image-edit` model types, it should be the local path to the model."
133 | )
134 | @click.option(
135 | "--model-type",
136 | default="lm",
137 | type=click.Choice(["lm", "multimodal", "image-generation", "image-edit", "embeddings", "whisper"]),
138 | help="Type of model to run (lm: text-only, multimodal: text+vision+audio, image-generation: flux image generation, image-edit: flux image edit, embeddings: text embeddings, whisper: audio transcription)"
139 | )
140 | @click.option(
141 | "--context-length",
142 | default=None,
143 | type=int,
144 | help="Context length for language models. Only works with `lm` or `multimodal` model types."
145 | )
146 | @click.option(
147 | "--port",
148 | default=8000,
149 | type=int,
150 | help="Port to run the server on"
151 | )
152 | @click.option(
153 | "--host",
154 | default="0.0.0.0",
155 | help="Host to run the server on"
156 | )
157 | @click.option(
158 | "--max-concurrency",
159 | default=1,
160 | type=int,
161 | help="Maximum number of concurrent requests"
162 | )
163 | @click.option(
164 | "--queue-timeout",
165 | default=300,
166 | type=int,
167 | help="Request timeout in seconds"
168 | )
169 | @click.option(
170 | "--queue-size",
171 | default=100,
172 | type=int,
173 | help="Maximum queue size for pending requests"
174 | )
175 | @click.option(
176 | "--quantize",
177 | default=8,
178 | type=int,
179 | help="Quantization level for the model. Only used for image-generation and image-edit Flux models."
180 | )
181 | @click.option(
182 | "--config-name",
183 | default=None,
184 | type=click.Choice(["flux-schnell", "flux-dev", "flux-krea-dev", "flux-kontext-dev"]),
185 | help="Config name of the model. Only used for image-generation and image-edit Flux models."
186 | )
187 | @click.option(
188 | "--lora-paths",
189 | default=None,
190 | type=str,
191 | help="Path to the LoRA file(s). Multiple paths should be separated by commas."
192 | )
193 | @click.option(
194 | "--lora-scales",
195 | default=None,
196 | type=str,
197 | help="Scale factor for the LoRA file(s). Multiple scales should be separated by commas."
198 | )
199 | @click.option(
200 | "--disable-auto-resize",
201 | is_flag=True,
202 |     help="Disable automatic model resizing. Only works for Vision Language Models."
203 | )
204 | @click.option(
205 | "--log-file",
206 | default=None,
207 | type=str,
208 | help="Path to log file. If not specified, logs will be written to 'logs/app.log' by default."
209 | )
210 | @click.option(
211 | "--no-log-file",
212 | is_flag=True,
213 | help="Disable file logging entirely. Only console output will be shown."
214 | )
215 | @click.option(
216 | "--log-level",
217 | default="INFO",
218 | type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]),
219 | help="Set the logging level. Default is INFO."
220 | )
221 | def launch(model_path, model_type, context_length, port, host, max_concurrency, queue_timeout, queue_size, quantize, config_name, lora_paths, lora_scales, disable_auto_resize, log_file, no_log_file, log_level):
222 | """Launch the MLX server with the specified model."""
223 | try:
224 | # Validate that config name is only used with image-generation and image-edit model types
225 | if config_name and model_type not in ["image-generation", "image-edit"]:
226 | logger.warning(f"Config name parameter '{config_name}' provided but model type is '{model_type}'. Config name is only used with image-generation and image-edit models.")
227 | elif model_type == "image-generation" and not config_name:
228 | logger.warning("Model type is 'image-generation' but no config name specified. Using default 'flux-schnell'.")
229 | config_name = "flux-schnell"
230 | elif model_type == "image-edit" and not config_name:
231 | logger.warning("Model type is 'image-edit' but no config name specified. Using default 'flux-kontext-dev'.")
232 | config_name = "flux-kontext-dev"
233 |
234 | # Get optimized configuration
235 | args = get_server_config(model_path, model_type, context_length, port, host, max_concurrency, queue_timeout, queue_size, quantize, config_name, lora_paths, lora_scales, disable_auto_resize, log_file, no_log_file, log_level)
236 |
237 | # Display startup information
238 | print_startup_banner(args)
239 |
240 | # Set up and start the server
241 | config = asyncio.run(setup_server(args))
242 | logger.info("Server configuration complete.")
243 | logger.info("Starting Uvicorn server...")
244 | uvicorn.Server(config).run()
245 | except KeyboardInterrupt:
246 | logger.info("Server shutdown requested by user. Exiting...")
247 | except Exception as e:
248 | logger.error(f"Server startup failed: {str(e)}")
249 | sys.exit(1)
250 |
251 |
252 | if __name__ == "__main__":
253 | cli()
--------------------------------------------------------------------------------
/app/core/__init__.py:
--------------------------------------------------------------------------------
1 | from app.core.base_processor import BaseProcessor
2 | from app.core.audio_processor import AudioProcessor
3 | from app.core.image_processor import ImageProcessor
4 | from app.core.video_processor import VideoProcessor
5 |
6 | __all__ = ["BaseProcessor", "AudioProcessor", "ImageProcessor", "VideoProcessor"]
--------------------------------------------------------------------------------
/app/core/audio_processor.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gc
3 | import asyncio
4 | from typing import List
5 | from app.core.base_processor import BaseProcessor
6 |
7 |
8 | class AudioProcessor(BaseProcessor):
9 | """Audio processor for handling audio files with caching and validation."""
10 |
11 | def __init__(self, max_workers: int = 4, cache_size: int = 1000):
12 | super().__init__(max_workers, cache_size)
13 | # Supported audio formats
14 | self._supported_formats = {'.mp3', '.wav'}
15 |
16 | def _get_media_format(self, media_url: str, data: bytes = None) -> str:
17 | """Determine audio format from URL or data."""
18 | if media_url.startswith("data:"):
19 | # Extract format from data URL
20 | mime_type = media_url.split(";")[0].split(":")[1]
21 | if "mp3" in mime_type or "mpeg" in mime_type:
22 | return "mp3"
23 | elif "wav" in mime_type:
24 | return "wav"
25 | elif "m4a" in mime_type or "mp4" in mime_type:
26 | return "m4a"
27 | elif "ogg" in mime_type:
28 | return "ogg"
29 | elif "flac" in mime_type:
30 | return "flac"
31 | elif "aac" in mime_type:
32 | return "aac"
33 | else:
34 | # Extract format from file extension
35 | ext = os.path.splitext(media_url.lower())[1]
36 | if ext in self._supported_formats:
37 | return ext[1:] # Remove the dot
38 |
39 | # Default to mp3 if format cannot be determined
40 | return "mp3"
41 |
42 | def _validate_media_data(self, data: bytes) -> bool:
43 | """Basic validation of audio data."""
44 | if len(data) < 100: # Too small to be a valid audio file
45 | return False
46 |
47 | # Check for common audio file signatures
48 | audio_signatures = [
49 | b'ID3', # MP3 with ID3 tag
50 | b'\xff\xfb', # MP3 frame header
51 | b'\xff\xf3', # MP3 frame header
52 | b'\xff\xf2', # MP3 frame header
53 | b'RIFF', # WAV/AVI
54 | b'OggS', # OGG
55 | b'fLaC', # FLAC
56 | b'\x00\x00\x00\x20ftypM4A', # M4A
57 | ]
58 |
59 | for sig in audio_signatures:
60 | if data.startswith(sig):
61 | return True
62 |
63 | # Check for WAV format (RIFF header might be at different position)
64 | if b'WAVE' in data[:50]:
65 | return True
66 |
67 | return True # Allow unknown formats to pass through
68 |
69 | def _get_timeout(self) -> int:
70 | """Get timeout for HTTP requests."""
71 | return 60 # Longer timeout for audio files
72 |
73 | def _get_max_file_size(self) -> int:
74 | """Get maximum file size in bytes."""
75 | return 500 * 1024 * 1024 # 500 MB limit for audio
76 |
77 | def _process_media_data(self, data: bytes, cached_path: str, **kwargs) -> str:
78 | """Process audio data and save to cached path."""
79 | with open(cached_path, 'wb') as f:
80 | f.write(data)
81 | self._cleanup_old_files()
82 | return cached_path
83 |
84 | def _get_media_type_name(self) -> str:
85 | """Get media type name for logging."""
86 | return "audio"
87 |
88 | async def process_audio_url(self, audio_url: str) -> str:
89 | """Process a single audio URL and return path to cached file."""
90 | return await self._process_single_media(audio_url)
91 |
92 | async def process_audio_urls(self, audio_urls: List[str]) -> List[str]:
93 | """Process multiple audio URLs and return paths to cached files."""
94 | tasks = [self.process_audio_url(url) for url in audio_urls]
95 | results = await asyncio.gather(*tasks, return_exceptions=True)
96 | # Force garbage collection after batch processing
97 | gc.collect()
98 | return results
--------------------------------------------------------------------------------
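A minimal usage sketch for `AudioProcessor` (illustrative, not part of the repo): the base class provides `async with` support, and batch results may contain exception objects because `asyncio.gather` is called with `return_exceptions=True`. The paths below are the sample audios bundled under `examples/audios`.

```python
import asyncio

from app.core.audio_processor import AudioProcessor


async def main():
    async with AudioProcessor(max_workers=2, cache_size=100) as processor:
        paths = await processor.process_audio_urls([
            "examples/audios/audio.wav",
            "examples/audios/podcast.wav",
        ])
        for item in paths:
            # Failed downloads or validations come back as exception objects.
            print(f"failed: {item}" if isinstance(item, Exception) else item)


asyncio.run(main())
```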
/app/core/base_processor.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import hashlib
3 | import os
4 | import tempfile
5 | import aiohttp
6 | import time
7 | import gc
8 | from loguru import logger
9 | from typing import Dict, Optional, Any
10 | from concurrent.futures import ThreadPoolExecutor
11 | from abc import ABC, abstractmethod
12 |
13 |
14 | class BaseProcessor(ABC):
15 | """Base class for media processors with common caching and session management."""
16 |
17 | def __init__(self, max_workers: int = 4, cache_size: int = 1000):
18 | # Use tempfile for macOS-efficient temporary file handling
19 | self.temp_dir = tempfile.TemporaryDirectory()
20 | self._session: Optional[aiohttp.ClientSession] = None
21 | self.executor = ThreadPoolExecutor(max_workers=max_workers)
22 | self._cache_size = cache_size
23 | self._last_cleanup = time.time()
24 | self._cleanup_interval = 3600 # 1 hour
25 | # Replace lru_cache with manual cache for better control
26 | self._hash_cache: Dict[str, str] = {}
27 | self._cache_access_times: Dict[str, float] = {}
28 |
29 | def _get_media_hash(self, media_url: str) -> str:
30 | """Get hash for media URL with manual caching that can be cleared."""
31 | # Check if already cached
32 | if media_url in self._hash_cache:
33 | self._cache_access_times[media_url] = time.time()
34 | return self._hash_cache[media_url]
35 |
36 | # Generate hash
37 | if media_url.startswith("data:"):
38 | _, encoded = media_url.split(",", 1)
39 | data = base64.b64decode(encoded)
40 | else:
41 | data = media_url.encode('utf-8')
42 |
43 | hash_value = hashlib.md5(data).hexdigest()
44 |
45 | # Add to cache with size management
46 | if len(self._hash_cache) >= self._cache_size:
47 | self._evict_oldest_cache_entries()
48 |
49 | self._hash_cache[media_url] = hash_value
50 | self._cache_access_times[media_url] = time.time()
51 | return hash_value
52 |
53 | def _evict_oldest_cache_entries(self):
54 | """Remove oldest 20% of cache entries to make room."""
55 | if not self._cache_access_times:
56 | return
57 |
58 | # Sort by access time and remove oldest 20%
59 | sorted_items = sorted(self._cache_access_times.items(), key=lambda x: x[1])
60 | to_remove = len(sorted_items) // 5 # Remove 20%
61 |
62 | for url, _ in sorted_items[:to_remove]:
63 | self._hash_cache.pop(url, None)
64 | self._cache_access_times.pop(url, None)
65 |
66 | # Force garbage collection after cache eviction
67 | gc.collect()
68 |
69 | @abstractmethod
70 | def _get_media_format(self, media_url: str, data: bytes = None) -> str:
71 | """Determine media format from URL or data. Must be implemented by subclasses."""
72 | pass
73 |
74 | @abstractmethod
75 | def _validate_media_data(self, data: bytes) -> bool:
76 | """Validate media data. Must be implemented by subclasses."""
77 | pass
78 |
79 | @abstractmethod
80 | def _get_timeout(self) -> int:
81 | """Get timeout for HTTP requests. Must be implemented by subclasses."""
82 | pass
83 |
84 | @abstractmethod
85 | def _get_max_file_size(self) -> int:
86 | """Get maximum file size in bytes. Must be implemented by subclasses."""
87 | pass
88 |
89 | @abstractmethod
90 |     def _process_media_data(self, data: bytes, cached_path: str, **kwargs) -> str:
91 | """Process media data and save to cached path. Must be implemented by subclasses."""
92 | pass
93 |
94 | @abstractmethod
95 | def _get_media_type_name(self) -> str:
96 | """Get media type name for logging. Must be implemented by subclasses."""
97 | pass
98 |
99 | async def _get_session(self) -> aiohttp.ClientSession:
100 | if self._session is None or self._session.closed:
101 | self._session = aiohttp.ClientSession(
102 | timeout=aiohttp.ClientTimeout(total=self._get_timeout()),
103 | headers={"User-Agent": "mlx-server-OAI-compat/1.0"}
104 | )
105 | return self._session
106 |
107 | def _cleanup_old_files(self):
108 | current_time = time.time()
109 | if current_time - self._last_cleanup > self._cleanup_interval:
110 | try:
111 | for file in os.listdir(self.temp_dir.name):
112 | file_path = os.path.join(self.temp_dir.name, file)
113 | if os.path.getmtime(file_path) < current_time - self._cleanup_interval:
114 | os.remove(file_path)
115 | self._last_cleanup = current_time
116 | # Also clean up cache periodically
117 | if len(self._hash_cache) > self._cache_size * 0.8:
118 | self._evict_oldest_cache_entries()
119 | gc.collect() # Force garbage collection after cleanup
120 | except Exception as e:
121 | logger.warning(f"Failed to clean up old {self._get_media_type_name()} files: {str(e)}")
122 |
123 | async def _process_single_media(self, media_url: str, **kwargs) -> str:
124 | try:
125 | media_hash = self._get_media_hash(media_url)
126 | media_format = self._get_media_format(media_url)
127 | cached_path = os.path.join(self.temp_dir.name, f"{media_hash}.{media_format}")
128 |
129 | if os.path.exists(cached_path):
130 | logger.debug(f"Using cached {self._get_media_type_name()}: {cached_path}")
131 | return cached_path
132 |
133 | if os.path.exists(media_url):
134 | # Copy local file to cache
135 | with open(media_url, 'rb') as f:
136 | data = f.read()
137 |
138 | if not self._validate_media_data(data):
139 | raise ValueError(f"Invalid {self._get_media_type_name()} file format")
140 |
141 | return self._process_media_data(data, cached_path, **kwargs)
142 |
143 | elif media_url.startswith("data:"):
144 | _, encoded = media_url.split(",", 1)
145 | estimated_size = len(encoded) * 3 / 4
146 | if estimated_size > self._get_max_file_size():
147 | raise ValueError(f"Base64-encoded {self._get_media_type_name()} exceeds size limit")
148 | data = base64.b64decode(encoded)
149 |
150 | if not self._validate_media_data(data):
151 | raise ValueError(f"Invalid {self._get_media_type_name()} file format")
152 |
153 | return self._process_media_data(data, cached_path, **kwargs)
154 | else:
155 | session = await self._get_session()
156 | async with session.get(media_url) as response:
157 | response.raise_for_status()
158 | data = await response.read()
159 |
160 | if not self._validate_media_data(data):
161 | raise ValueError(f"Invalid {self._get_media_type_name()} file format")
162 |
163 | return self._process_media_data(data, cached_path, **kwargs)
164 |
165 | except Exception as e:
166 | logger.error(f"Failed to process {self._get_media_type_name()}: {str(e)}")
167 | raise ValueError(f"Failed to process {self._get_media_type_name()}: {str(e)}")
168 | finally:
169 | gc.collect()
170 |
171 | def clear_cache(self):
172 | """Manually clear the hash cache to free memory."""
173 | self._hash_cache.clear()
174 | self._cache_access_times.clear()
175 | gc.collect()
176 |
177 | async def cleanup(self):
178 | if hasattr(self, '_cleaned') and self._cleaned:
179 | return
180 | self._cleaned = True
181 | try:
182 | # Clear caches before cleanup
183 | self.clear_cache()
184 |
185 | if self._session and not self._session.closed:
186 | await self._session.close()
187 | except Exception as e:
188 | logger.warning(f"Exception closing aiohttp session: {str(e)}")
189 | try:
190 | self.executor.shutdown(wait=True)
191 | except Exception as e:
192 | logger.warning(f"Exception shutting down executor: {str(e)}")
193 | try:
194 | self.temp_dir.cleanup()
195 | except Exception as e:
196 | logger.warning(f"Exception cleaning up temp directory: {str(e)}")
197 |
198 | async def __aenter__(self):
199 | return self
200 |
201 | async def __aexit__(self, exc_type, exc, tb):
202 | await self.cleanup()
203 |
204 | def __del__(self):
205 | # Async cleanup cannot be reliably performed in __del__
206 | # Please use 'async with Processor()' or call 'await cleanup()' explicitly.
207 | pass
--------------------------------------------------------------------------------
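To make the contract concrete, here is a toy subclass (purely illustrative, not in the repo) that fills in the six abstract hooks; the audio, image, and video processors below follow the same shape.

```python
from app.core.base_processor import BaseProcessor


class TextAttachmentProcessor(BaseProcessor):
    """Illustrative subclass: caches plain-text attachments."""

    def _get_media_format(self, media_url: str, data: bytes = None) -> str:
        return "txt"  # always cache with a .txt extension

    def _validate_media_data(self, data: bytes) -> bool:
        return len(data) > 0

    def _get_timeout(self) -> int:
        return 30  # seconds for HTTP downloads

    def _get_max_file_size(self) -> int:
        return 10 * 1024 * 1024  # 10 MB

    def _process_media_data(self, data: bytes, cached_path: str, **kwargs) -> str:
        with open(cached_path, "wb") as f:
            f.write(data)
        self._cleanup_old_files()
        return cached_path

    def _get_media_type_name(self) -> str:
        return "text"
```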
/app/core/image_processor.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import asyncio
3 | from PIL import Image
4 | from loguru import logger
5 | from io import BytesIO
6 | from typing import List
7 | from app.core.base_processor import BaseProcessor
8 |
9 |
10 | class ImageProcessor(BaseProcessor):
11 | """Image processor for handling image files with caching, validation, and processing."""
12 |
13 | def __init__(self, max_workers: int = 4, cache_size: int = 1000):
14 | super().__init__(max_workers, cache_size)
15 | Image.MAX_IMAGE_PIXELS = 100000000 # Limit to 100 megapixels
16 |
17 | def _get_media_format(self, media_url: str, data: bytes = None) -> str:
18 | """Determine image format from URL or data."""
19 | # For images, we always save as JPEG for consistency
20 | return "jpg"
21 |
22 | def _validate_media_data(self, data: bytes) -> bool:
23 | """Basic validation of image data."""
24 | if len(data) < 100: # Too small to be a valid image file
25 | return False
26 |
27 | # Check for common image file signatures
28 | image_signatures = [
29 | b'\xff\xd8\xff', # JPEG
30 | b'\x89PNG\r\n\x1a\n', # PNG
31 | b'GIF87a', # GIF87a
32 | b'GIF89a', # GIF89a
33 | b'BM', # BMP
34 | b'II*\x00', # TIFF (little endian)
35 | b'MM\x00*', # TIFF (big endian)
36 | b'RIFF', # WebP (part of RIFF)
37 | ]
38 |
39 | for sig in image_signatures:
40 | if data.startswith(sig):
41 | return True
42 |
43 | # Additional check for WebP
44 | if data.startswith(b'RIFF') and b'WEBP' in data[:20]:
45 | return True
46 |
47 | return False
48 |
49 | def _get_timeout(self) -> int:
50 | """Get timeout for HTTP requests."""
51 | return 30 # Standard timeout for images
52 |
53 | def _get_max_file_size(self) -> int:
54 | """Get maximum file size in bytes."""
55 | return 100 * 1024 * 1024 # 100 MB limit for images
56 |
57 | def _get_media_type_name(self) -> str:
58 | """Get media type name for logging."""
59 | return "image"
60 |
61 | def _resize_image_keep_aspect_ratio(self, image: Image.Image, max_size: int = 448) -> Image.Image:
62 | width, height = image.size
63 | if width <= max_size and height <= max_size:
64 | return image
65 | if width > height:
66 | new_width = max_size
67 | new_height = int(height * max_size / width)
68 | else:
69 | new_height = max_size
70 | new_width = int(width * max_size / height)
71 |
72 | image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
73 | logger.info(f"Resized image to {new_width}x{new_height} from {width}x{height}")
74 |
75 | return image
76 |
77 | def _prepare_image_for_saving(self, image: Image.Image) -> Image.Image:
78 | if image.mode in ('RGBA', 'LA'):
79 | background = Image.new('RGB', image.size, (255, 255, 255))
80 | if image.mode == 'RGBA':
81 | background.paste(image, mask=image.split()[3])
82 | else:
83 | background.paste(image, mask=image.split()[1])
84 | return background
85 | elif image.mode != 'RGB':
86 | return image.convert('RGB')
87 | return image
88 |
89 | def _process_media_data(self, data: bytes, cached_path: str, **kwargs) -> str:
90 | """Process image data and save to cached path."""
91 | image = None
92 | resize = kwargs.get("resize", True)
93 | try:
94 | with Image.open(BytesIO(data), mode='r') as image:
95 | if resize:
96 | image = self._resize_image_keep_aspect_ratio(image)
97 | image = self._prepare_image_for_saving(image)
98 |                 image.save(cached_path, 'JPEG', quality=95, optimize=True)  # save as JPEG to match the .jpg extension from _get_media_format
99 |
100 | self._cleanup_old_files()
101 | return cached_path
102 | finally:
103 | # Ensure image object is closed to free memory
104 | if image:
105 | try:
106 | image.close()
107 | except:
108 | pass
109 |
110 | async def process_image_url(self, image_url: str, resize: bool = True) -> str:
111 | """Process a single image URL and return path to cached file."""
112 | return await self._process_single_media(image_url, resize=resize)
113 |
114 | async def process_image_urls(self, image_urls: List[str], resize: bool = True) -> List[str]:
115 | """Process multiple image URLs and return paths to cached files."""
116 | tasks = [self.process_image_url(url, resize=resize) for url in image_urls]
117 | results = await asyncio.gather(*tasks, return_exceptions=True)
118 | # Force garbage collection after batch processing
119 | gc.collect()
120 | return results
--------------------------------------------------------------------------------
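A short usage sketch (assumed, not from the repo): `_resize_image_keep_aspect_ratio` caps the longer side at 448 px, so a 1024x768 input is stored at 448x336, while `resize=False` keeps the original resolution. The image paths are the bundled examples.

```python
import asyncio

from app.core.image_processor import ImageProcessor


async def main():
    async with ImageProcessor() as processor:
        # Resized copy: longer side capped at 448 px, aspect ratio preserved.
        resized = await processor.process_image_url("examples/images/attention.png")
        # Full-resolution copy, e.g. when small text must stay legible.
        full_res = await processor.process_image_url(
            "examples/images/password.jpg", resize=False
        )
        print(resized, full_res)


asyncio.run(main())
```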
/app/core/queue.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import time
3 | from typing import Any, Dict, Optional, Callable, Awaitable, TypeVar, Generic
4 | import gc
5 | from loguru import logger
6 |
7 | T = TypeVar('T')
8 |
9 | class RequestItem(Generic[T]):
10 | """
11 | Represents a single request in the queue.
12 | """
13 | def __init__(self, request_id: str, data: Any):
14 | self.request_id = request_id
15 | self.data = data
16 | self.created_at = time.time()
17 | self.future = asyncio.Future()
18 |
19 | def set_result(self, result: T) -> None:
20 | """Set the result for this request."""
21 | if not self.future.done():
22 | self.future.set_result(result)
23 |
24 | def set_exception(self, exc: Exception) -> None:
25 | """Set an exception for this request."""
26 | if not self.future.done():
27 | self.future.set_exception(exc)
28 |
29 | async def get_result(self) -> T:
30 | """Wait for and return the result of this request."""
31 | return await self.future
32 |
33 | class RequestQueue:
34 | """
35 | A simple asynchronous request queue with configurable concurrency.
36 | """
37 | def __init__(self, max_concurrency: int = 2, timeout: float = 300.0, queue_size: int = 100):
38 | """
39 | Initialize the request queue.
40 |
41 | Args:
42 | max_concurrency (int): Maximum number of concurrent requests to process.
43 | timeout (float): Timeout in seconds for request processing.
44 | queue_size (int): Maximum queue size.
45 | """
46 | self.max_concurrency = max_concurrency
47 | self.timeout = timeout
48 | self.queue_size = queue_size
49 | self.semaphore = asyncio.Semaphore(max_concurrency)
50 | self.queue = asyncio.Queue(maxsize=queue_size)
51 | self.active_requests: Dict[str, RequestItem] = {}
52 | self._worker_task = None
53 | self._running = False
54 |
55 | async def start(self, processor: Callable[[Any], Awaitable[Any]]):
56 | """
57 | Start the queue worker.
58 |
59 | Args:
60 | processor: Async function that processes queue items.
61 | """
62 | if self._running:
63 | return
64 |
65 | self._running = True
66 | self._worker_task = asyncio.create_task(self._worker_loop(processor))
67 | logger.info(f"Started request queue with max concurrency: {self.max_concurrency}")
68 |
69 | async def stop(self):
70 | """Stop the queue worker."""
71 | if not self._running:
72 | return
73 |
74 | self._running = False
75 |
76 | # Cancel the worker task
77 | if self._worker_task and not self._worker_task.done():
78 | self._worker_task.cancel()
79 | try:
80 | await self._worker_task
81 | except asyncio.CancelledError:
82 | pass
83 |
84 | # Cancel all pending requests
85 | pending_requests = list(self.active_requests.values())
86 | for request in pending_requests:
87 | if not request.future.done():
88 | request.future.cancel()
89 | # Clean up request data
90 | try:
91 | if hasattr(request, 'data'):
92 | del request.data
93 | except:
94 | pass
95 |
96 | self.active_requests.clear()
97 |
98 | # Clear the queue
99 | while not self.queue.empty():
100 | try:
101 | self.queue.get_nowait()
102 | except asyncio.QueueEmpty:
103 | break
104 |
105 | # Force garbage collection after cleanup
106 | gc.collect()
107 | logger.info("Stopped request queue")
108 |
109 | async def _worker_loop(self, processor: Callable[[Any], Awaitable[Any]]):
110 | """
111 | Main worker loop that processes queue items.
112 |
113 | Args:
114 | processor: Async function that processes queue items.
115 | """
116 | while self._running:
117 | try:
118 | # Get the next item from the queue
119 | request = await self.queue.get()
120 |
121 | # Process the request with concurrency control
122 | asyncio.create_task(self._process_request(request, processor))
123 |
124 | except asyncio.CancelledError:
125 | break
126 | except Exception as e:
127 | logger.error(f"Error in worker loop: {str(e)}")
128 |
129 | async def _process_request(self, request: RequestItem, processor: Callable[[Any], Awaitable[Any]]):
130 | """
131 | Process a single request with timeout and error handling.
132 |
133 | Args:
134 | request: The request to process.
135 | processor: Async function that processes the request.
136 | """
137 | # Use semaphore to limit concurrency
138 | async with self.semaphore:
139 | try:
140 | # Process with timeout
141 | processing_start = time.time()
142 | result = await asyncio.wait_for(
143 | processor(request.data),
144 | timeout=self.timeout
145 | )
146 | processing_time = time.time() - processing_start
147 |
148 | # Set the result
149 | request.set_result(result)
150 | logger.info(f"Request {request.request_id} processed in {processing_time:.2f}s")
151 |
152 | except asyncio.TimeoutError:
153 | request.set_exception(TimeoutError(f"Request processing timed out after {self.timeout}s"))
154 | logger.warning(f"Request {request.request_id} timed out after {self.timeout}s")
155 |
156 | except Exception as e:
157 | request.set_exception(e)
158 | logger.error(f"Error processing request {request.request_id}: {str(e)}")
159 |
160 | finally:
161 | # Always remove from active requests, even if an error occurred
162 | removed_request = self.active_requests.pop(request.request_id, None)
163 | if removed_request:
164 | # Clean up the request object
165 | try:
166 | if hasattr(removed_request, 'data'):
167 | del removed_request.data
168 | except:
169 | pass
170 | # Force garbage collection periodically to prevent memory buildup
171 | if len(self.active_requests) % 10 == 0: # Every 10 requests
172 | gc.collect()
173 |
174 | async def enqueue(self, request_id: str, data: Any) -> RequestItem:
175 | """
176 | Add a request to the queue.
177 |
178 | Args:
179 | request_id: Unique ID for the request.
180 | data: The request data to process.
181 |
182 | Returns:
183 | RequestItem: The queued request item.
184 |
185 | Raises:
186 | asyncio.QueueFull: If the queue is full.
187 | """
188 | if not self._running:
189 | raise RuntimeError("Queue is not running")
190 |
191 | # Create request item
192 | request = RequestItem(request_id, data)
193 |
194 | # Add to active requests and queue
195 | self.active_requests[request_id] = request
196 |
197 | try:
198 |             # Wait briefly for space; a timeout below is converted to QueueFull
199 | await asyncio.wait_for(
200 | self.queue.put(request),
201 | timeout=1.0 # Short timeout for queue put
202 | )
203 | queue_time = time.time() - request.created_at
204 | logger.info(f"Request {request_id} queued (wait: {queue_time:.2f}s)")
205 | return request
206 |
207 | except asyncio.TimeoutError:
208 | self.active_requests.pop(request_id, None)
209 | raise asyncio.QueueFull("Request queue is full and timed out waiting for space")
210 |
211 | async def submit(self, request_id: str, data: Any) -> Any:
212 | """
213 | Submit a request and wait for its result.
214 |
215 | Args:
216 | request_id: Unique ID for the request.
217 | data: The request data to process.
218 |
219 | Returns:
220 | The result of processing the request.
221 |
222 | Raises:
223 | Various exceptions that may occur during processing.
224 | """
225 | request = await self.enqueue(request_id, data)
226 | return await request.get_result()
227 |
228 | def get_queue_stats(self) -> Dict[str, Any]:
229 | """
230 | Get queue statistics.
231 |
232 | Returns:
233 | Dict with queue statistics.
234 | """
235 | return {
236 | "running": self._running,
237 | "queue_size": self.queue.qsize(),
238 | "max_queue_size": self.queue_size,
239 | "active_requests": len(self.active_requests),
240 | "max_concurrency": self.max_concurrency
241 | }
242 |
243 | # Alias for the async stop method to maintain consistency in cleanup interfaces
244 | async def stop_async(self):
245 | """Alias for stop - stops the queue worker asynchronously."""
246 | await self.stop()
--------------------------------------------------------------------------------
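A self-contained sketch of the queue lifecycle (illustrative only): start the queue with an async processor, submit a few requests concurrently, read the stats, then stop it.

```python
import asyncio

from app.core.queue import RequestQueue


async def echo_processor(data):
    await asyncio.sleep(0.1)  # stand-in for real model inference
    return {"echo": data}


async def main():
    queue = RequestQueue(max_concurrency=2, timeout=30.0, queue_size=10)
    await queue.start(echo_processor)
    try:
        results = await asyncio.gather(
            *(queue.submit(f"req-{i}", {"n": i}) for i in range(5))
        )
        print(results)
        print(queue.get_queue_stats())
    finally:
        await queue.stop()


asyncio.run(main())
```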
/app/core/video_processor.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gc
3 | import asyncio
4 | from loguru import logger
5 | from typing import List
6 | from app.core.base_processor import BaseProcessor
7 |
8 |
9 | class VideoProcessor(BaseProcessor):
10 | """Video processor for handling video files with caching, validation, and processing."""
11 |
12 | def __init__(self, max_workers: int = 4, cache_size: int = 1000):
13 | super().__init__(max_workers, cache_size)
14 | # Supported video formats
15 | self._supported_formats = {'.mp4', '.avi', '.mov'}
16 |
17 | def _get_media_format(self, media_url: str, data: bytes = None) -> str:
18 | """Determine video format from URL or data."""
19 | if media_url.startswith("data:"):
20 | # Extract format from data URL
21 | mime_type = media_url.split(";")[0].split(":")[1]
22 | if "mp4" in mime_type:
23 | return "mp4"
24 | elif "quicktime" in mime_type or "mov" in mime_type:
25 | return "mov"
26 | elif "x-msvideo" in mime_type or "avi" in mime_type:
27 | return "avi"
28 | else:
29 | # Extract format from file extension
30 | ext = os.path.splitext(media_url.lower())[1]
31 | if ext in self._supported_formats:
32 | return ext[1:] # Remove the dot
33 |
34 | # Default to mp4 if format cannot be determined
35 | return "mp4"
36 |
37 | def _validate_media_data(self, data: bytes) -> bool:
38 | """Basic validation of video data."""
39 | if len(data) < 100: # Too small to be a valid video file
40 | return False
41 |
42 | # Check for common video file signatures
43 | video_signatures = [
44 | # MP4/M4V/MOV (ISO Base Media File Format)
45 | (b'\x00\x00\x00\x14ftypisom', 0), # MP4
46 | (b'\x00\x00\x00\x18ftyp', 0), # MP4/MOV
47 | (b'\x00\x00\x00\x1cftyp', 0), # MP4/MOV
48 | (b'\x00\x00\x00\x20ftyp', 0), # MP4/MOV
49 | (b'ftyp', 4), # MP4/MOV (ftyp at offset 4)
50 |
51 | # AVI
52 | (b'RIFF', 0), # AVI (also check for 'AVI ' at offset 8)
53 |
54 | # WebM/MKV (Matroska)
55 | (b'\x1a\x45\xdf\xa3', 0), # Matroska/WebM
56 |
57 | # FLV
58 | (b'FLV\x01', 0), # Flash Video
59 |
60 | # MPEG
61 | (b'\x00\x00\x01\xba', 0), # MPEG PS
62 | (b'\x00\x00\x01\xb3', 0), # MPEG PS
63 |
64 | # QuickTime
65 | (b'moov', 0), # QuickTime
66 | (b'mdat', 0), # QuickTime
67 | ]
68 |
69 | for sig, offset in video_signatures:
70 | if len(data) > offset + len(sig):
71 | if data[offset:offset+len(sig)] == sig:
72 | # Additional validation for AVI
73 | if sig == b'RIFF' and len(data) > 12:
74 | if data[8:12] == b'AVI ':
75 | return True
76 | elif sig == b'RIFF':
77 | continue # Not AVI, might be WAV
78 | else:
79 | return True
80 |
81 | # Check for ftyp box anywhere in first 32 bytes (MP4/MOV)
82 | if b'ftyp' in data[:32]:
83 | return True
84 |
85 | # Allow unknown formats to pass through for flexibility
86 | return True
87 |
88 | def _get_timeout(self) -> int:
89 | """Get timeout for HTTP requests."""
90 | return 120 # Longer timeout for video files (2 minutes)
91 |
92 | def _get_max_file_size(self) -> int:
93 | """Get maximum file size in bytes."""
94 | return 1024 * 1024 * 1024 # 1 GB limit for videos
95 |
96 | def _process_media_data(self, data: bytes, cached_path: str, **kwargs) -> str:
97 | """Process video data and save to cached path."""
98 | try:
99 | with open(cached_path, 'wb') as f:
100 | f.write(data)
101 |
102 | logger.info(f"Saved video to {cached_path} ({len(data)} bytes)")
103 | self._cleanup_old_files()
104 | return cached_path
105 | except Exception as e:
106 | logger.error(f"Failed to save video data: {str(e)}")
107 | raise
108 |
109 | def _get_media_type_name(self) -> str:
110 | """Get media type name for logging."""
111 | return "video"
112 |
113 | async def process_video_url(self, video_url: str) -> str:
114 | """
115 | Process a single video URL and return path to cached file.
116 |
117 | Supports:
118 | - HTTP/HTTPS URLs (downloads video)
119 | - Local file paths (copies to cache)
120 | - Data URLs (base64 encoded videos)
121 |
122 | Args:
123 | video_url: URL, file path, or data URL of the video
124 |
125 | Returns:
126 | Path to the cached video file in temp directory
127 | """
128 | return await self._process_single_media(video_url)
129 |
130 | async def process_video_urls(self, video_urls: List[str]) -> List[str]:
131 | """
132 | Process multiple video URLs and return paths to cached files.
133 |
134 | Args:
135 | video_urls: List of URLs, file paths, or data URLs of videos
136 |
137 | Returns:
138 | List of paths to cached video files
139 | """
140 | tasks = [self.process_video_url(url) for url in video_urls]
141 | results = await asyncio.gather(*tasks, return_exceptions=True)
142 | # Force garbage collection after batch processing
143 | gc.collect()
144 | return results
145 |
--------------------------------------------------------------------------------
/app/handler/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | MLX model handlers for text, multimodal, image generation, and embeddings models.
3 | """
4 |
5 | from app.handler.mlx_lm import MLXLMHandler
6 | from app.handler.mlx_vlm import MLXVLMHandler
7 | from app.handler.mlx_embeddings import MLXEmbeddingsHandler
8 |
9 | # Optional mflux import - only available if flux extra is installed
10 | try:
11 | from app.handler.mflux import MLXFluxHandler
12 | MFLUX_AVAILABLE = True
13 | except ImportError:
14 | MLXFluxHandler = None
15 | MFLUX_AVAILABLE = False
16 |
17 | __all__ = [
18 | "MLXLMHandler",
19 | "MLXVLMHandler",
20 | "MLXFluxHandler",
21 | "MLXEmbeddingsHandler",
22 | "MFLUX_AVAILABLE"
23 | ]
24 |
--------------------------------------------------------------------------------
/app/handler/mlx_embeddings.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import time
3 | import uuid
4 | from http import HTTPStatus
5 | from typing import Any, Dict, List
6 |
7 | from fastapi import HTTPException
8 | from loguru import logger
9 |
10 | from app.core.queue import RequestQueue
11 | from app.schemas.openai import EmbeddingRequest
12 | from app.utils.errors import create_error_response
13 | from app.models.mlx_embeddings import MLX_Embeddings
14 |
15 | class MLXEmbeddingsHandler:
16 | """
17 | Handler class for making requests to the underlying MLX embeddings model service.
18 | Provides request queuing, metrics tracking, and robust error handling with memory management.
19 | """
20 |
21 | def __init__(self, model_path: str, max_concurrency: int = 1):
22 | """
23 | Initialize the handler with the specified model path.
24 |
25 | Args:
26 | model_path (str): Path to the embeddings model to load.
27 | max_concurrency (int): Maximum number of concurrent model inference tasks.
28 | """
29 | self.model_path = model_path
30 | self.model = MLX_Embeddings(model_path)
31 | self.model_created = int(time.time()) # Store creation time when model is loaded
32 |
33 | # Initialize request queue for embedding tasks
34 | self.request_queue = RequestQueue(max_concurrency=max_concurrency)
35 |
36 | logger.info(f"Initialized MLXEmbeddingsHandler with model path: {model_path}")
37 |
38 | async def get_models(self) -> List[Dict[str, Any]]:
39 | """
40 | Get list of available models with their metadata.
41 | """
42 | try:
43 | return [{
44 | "id": self.model_path,
45 | "object": "model",
46 | "created": self.model_created,
47 | "owned_by": "local"
48 | }]
49 | except Exception as e:
50 | logger.error(f"Error getting models: {str(e)}")
51 | return []
52 |
53 | async def initialize(self, config: Dict[str, Any]):
54 | """
55 | Initialize the request queue with configuration.
56 |
57 | Args:
58 | config: Dictionary containing queue configuration.
59 | """
60 | await self.request_queue.start(self._process_request)
61 |
62 | async def generate_embeddings_response(self, request: EmbeddingRequest):
63 | """
64 | Generate embeddings for a given text input.
65 |
66 | Args:
67 | request: EmbeddingRequest object containing the text input.
68 |
69 | Returns:
70 | List[float]: Embeddings for the input text.
71 | """
72 | try:
73 | # Create a unique request ID
74 | request_id = f"embeddings-{uuid.uuid4()}"
75 | if isinstance(request.input, str):
76 | request.input = [request.input]
77 | request_data = {
78 | "type": "embeddings",
79 | "input": request.input,
80 | "max_length": getattr(request, 'max_length', 512)
81 | }
82 |
83 | # Submit to the request queue
84 | response = await self.request_queue.submit(request_id, request_data)
85 |
86 | return response
87 |
88 | except Exception as e:
89 | logger.error(f"Error in embeddings generation: {str(e)}")
90 | content = create_error_response(f"Failed to generate embeddings: {str(e)}", "server_error", HTTPStatus.INTERNAL_SERVER_ERROR)
91 | raise HTTPException(status_code=500, detail=content)
92 |
93 | async def _process_request(self, request_data: Dict[str, Any]) -> List[List[float]]:
94 | """
95 | Process an embeddings request. This is the worker function for the request queue.
96 |
97 | Args:
98 | request_data: Dictionary containing the request data.
99 |
100 | Returns:
101 | List[List[float]]: The embeddings for the input texts.
102 | """
103 | try:
104 | # Check if the request is for embeddings
105 | if request_data.get("type") == "embeddings":
106 | result = self.model(
107 | texts=request_data["input"],
108 | max_length=request_data.get("max_length", 512)
109 | )
110 | # Force garbage collection after embeddings
111 | gc.collect()
112 | return result
113 |
114 | raise ValueError(f"Unknown request type: {request_data.get('type')}")
115 |
116 | except Exception as e:
117 | logger.error(f"Error processing embeddings request: {str(e)}")
118 | # Clean up on error
119 | gc.collect()
120 | raise
121 |
122 | async def get_queue_stats(self) -> Dict[str, Any]:
123 | """
124 | Get statistics from the request queue and performance metrics.
125 |
126 | Returns:
127 | Dict with queue and performance statistics.
128 | """
129 | queue_stats = self.request_queue.get_queue_stats()
130 |
131 | return {
132 | "queue_stats": queue_stats,
133 | }
134 |
135 | async def cleanup(self):
136 | """
137 | Cleanup resources and stop the request queue before shutdown.
138 |
139 | This method ensures all pending requests are properly cancelled
140 | and resources are released.
141 | """
142 | try:
143 | logger.info("Cleaning up MLXEmbeddingsHandler resources")
144 | if hasattr(self, 'request_queue'):
145 | await self.request_queue.stop()
146 | if hasattr(self, 'model'):
147 | self.model.cleanup()
148 | logger.info("MLXEmbeddingsHandler cleanup completed successfully")
149 | except Exception as e:
150 | logger.error(f"Error during MLXEmbeddingsHandler cleanup: {str(e)}")
151 | raise
152 |
153 | def __del__(self):
154 | """
155 | Destructor to ensure cleanup on object deletion.
156 | Note: Async cleanup cannot be reliably performed in __del__.
157 | Please use 'await cleanup()' explicitly.
158 | """
159 | if hasattr(self, '_cleaned') and self._cleaned:
160 | return
161 | # Set flag to prevent multiple cleanup attempts
162 | self._cleaned = True
--------------------------------------------------------------------------------
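A hedged end-to-end sketch for the embeddings handler (not from the repo): the model path is a placeholder, and the `EmbeddingRequest` fields beyond `input` are assumed from how the handler reads them, since `app/schemas/openai.py` is not shown in this section.

```python
import asyncio

from app.handler.mlx_embeddings import MLXEmbeddingsHandler
from app.schemas.openai import EmbeddingRequest


async def main():
    # Placeholder model path; substitute a local MLX embeddings model.
    handler = MLXEmbeddingsHandler("path/to/embeddings-model", max_concurrency=1)
    await handler.initialize({})  # starts the request queue
    try:
        request = EmbeddingRequest(  # field names assumed from the handler's usage
            model=handler.model_path,
            input=["hello world", "mlx on apple silicon"],
        )
        embeddings = await handler.generate_embeddings_response(request)
        print(len(embeddings))
    finally:
        await handler.cleanup()


asyncio.run(main())
```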
/app/handler/mlx_whisper.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import json
3 | import os
4 | import tempfile
5 | import time
6 | import uuid
7 | from typing import Any, AsyncGenerator, Dict, List, Optional
8 | from http import HTTPStatus
9 |
10 | from fastapi import HTTPException
11 | from loguru import logger
12 |
13 | from app.core.queue import RequestQueue
14 | from app.models.mlx_whisper import MLX_Whisper, calculate_audio_duration
15 | from app.schemas.openai import (
16 | TranscriptionRequest,
17 | TranscriptionResponse,
18 | TranscriptionUsageAudio,
19 | TranscriptionResponseFormat,
20 | TranscriptionResponseStream,
21 | TranscriptionResponseStreamChoice,
22 | Delta
23 | )
24 | from app.utils.errors import create_error_response
25 |
26 | class MLXWhisperHandler:
27 | """
28 | Handler class for making requests to the underlying MLX Whisper model service.
29 | Provides request queuing, metrics tracking, and robust error handling for audio transcription.
30 | """
31 |
32 | def __init__(self, model_path: str, max_concurrency: int = 1):
33 | """
34 | Initialize the handler with the specified model path.
35 |
36 | Args:
37 | model_path (str): Path to the model directory.
38 | max_concurrency (int): Maximum number of concurrent model inference tasks.
39 | """
40 | self.model_path = model_path
41 | self.model = MLX_Whisper(model_path)
42 | self.model_created = int(time.time()) # Store creation time when model is loaded
43 |
44 | # Initialize request queue for audio transcription tasks
45 | self.request_queue = RequestQueue(max_concurrency=max_concurrency)
46 |
47 | logger.info(f"Initialized MLXWhisperHandler with model path: {model_path}")
48 |
49 | async def get_models(self) -> List[Dict[str, Any]]:
50 | """
51 | Get list of available models with their metadata.
52 | """
53 | try:
54 | return [{
55 | "id": self.model_path,
56 | "object": "model",
57 | "created": self.model_created,
58 | "owned_by": "local"
59 | }]
60 | except Exception as e:
61 | logger.error(f"Error getting models: {str(e)}")
62 | return []
63 |
64 | async def initialize(self, queue_config: Optional[Dict[str, Any]] = None):
65 | """Initialize the handler and start the request queue."""
66 | if not queue_config:
67 | queue_config = {
68 | "max_concurrency": 1,
69 | "timeout": 600, # Longer timeout for audio processing
70 | "queue_size": 50
71 | }
72 | self.request_queue = RequestQueue(
73 | max_concurrency=queue_config.get("max_concurrency"),
74 | timeout=queue_config.get("timeout"),
75 | queue_size=queue_config.get("queue_size")
76 | )
77 | await self.request_queue.start(self._process_request)
78 | logger.info("Initialized MLXWhisperHandler and started request queue")
79 |
80 | async def generate_transcription_response(self, request: TranscriptionRequest) -> TranscriptionResponse:
81 | """
82 | Generate a transcription response for the given request.
83 | """
84 | request_id = f"transcription-{uuid.uuid4()}"
85 | temp_file_path = None
86 |
87 | try:
88 | request_data = await self._prepare_transcription_request(request)
89 | temp_file_path = request_data.get("audio_path")
90 | response = await self.request_queue.submit(request_id, request_data)
91 | response_data = TranscriptionResponse(
92 | text=response["text"],
93 | usage=TranscriptionUsageAudio(
94 | type="duration",
95 | seconds=int(calculate_audio_duration(temp_file_path))
96 | )
97 | )
98 | if request.response_format == TranscriptionResponseFormat.JSON:
99 | return response_data
100 | else:
101 | # dump to string for text response
102 | return json.dumps(response_data.model_dump())
103 | finally:
104 | # Clean up temporary file
105 | if temp_file_path and os.path.exists(temp_file_path):
106 | try:
107 | os.unlink(temp_file_path)
108 | logger.debug(f"Cleaned up temporary file: {temp_file_path}")
109 | except Exception as e:
110 | logger.warning(f"Failed to clean up temporary file {temp_file_path}: {str(e)}")
111 | # Force garbage collection
112 | gc.collect()
113 |
114 | async def generate_transcription_stream_from_data(
115 | self,
116 | request_data: Dict[str, Any],
117 | response_format: TranscriptionResponseFormat
118 | ) -> AsyncGenerator[str, None]:
119 | """
120 | Generate a transcription stream from prepared request data.
121 | Yields SSE-formatted chunks with timing information.
122 |
123 | Args:
124 | request_data: Prepared request data with audio_path already saved
125 | response_format: The response format (json or text)
126 | """
127 | request_id = f"transcription-{uuid.uuid4()}"
128 | created_time = int(time.time())
129 | temp_file_path = request_data.get("audio_path")
130 |
131 | try:
132 | # Set stream mode
133 | request_data["stream"] = True
134 |
135 | # Get the generator directly from the model (bypass queue for streaming)
136 | generator = self.model(
137 | audio_path=request_data.pop("audio_path"),
138 | **request_data
139 | )
140 |
141 | # Stream each chunk
142 | for chunk in generator:
143 | # Create streaming response
144 | stream_response = TranscriptionResponseStream(
145 | id=request_id,
146 | object="transcription.chunk",
147 | created=created_time,
148 | model=self.model_path,
149 | choices=[
150 | TranscriptionResponseStreamChoice(
151 | delta=Delta(
152 | content=chunk.get("text", "")
153 | ),
154 | finish_reason=None
155 | )
156 | ]
157 | )
158 |
159 | # Yield as SSE format
160 | yield f"data: {stream_response.model_dump_json()}\n\n"
161 |
162 | # Send final chunk with finish_reason
163 | final_response = TranscriptionResponseStream(
164 | id=request_id,
165 | object="transcription.chunk",
166 | created=created_time,
167 | model=self.model_path,
168 | choices=[
169 | TranscriptionResponseStreamChoice(
170 | delta=Delta(content=""),
171 | finish_reason="stop"
172 | )
173 | ]
174 | )
175 | yield f"data: {final_response.model_dump_json()}\n\n"
176 | yield "data: [DONE]\n\n"
177 |
178 | except Exception as e:
179 | logger.error(f"Error during transcription streaming: {str(e)}")
180 | raise
181 | finally:
182 | # Clean up temporary file
183 | if temp_file_path and os.path.exists(temp_file_path):
184 | try:
185 | os.unlink(temp_file_path)
186 | logger.debug(f"Cleaned up temporary file: {temp_file_path}")
187 | except Exception as e:
188 | logger.warning(f"Failed to clean up temporary file {temp_file_path}: {str(e)}")
189 | # Clean up
190 | gc.collect()
191 |
192 |
193 | async def _process_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
194 | """
195 | Process an audio transcription request. This is the worker function for the request queue.
196 |
197 | Args:
198 | request_data: Dictionary containing the request data.
199 |
200 | Returns:
201 | Dict: The model's response containing transcribed text.
202 | """
203 | try:
204 | # Extract request parameters
205 | audio_path = request_data.pop("audio_path")
206 |
207 | # Call the model with the audio file
208 | result = self.model(
209 | audio_path=audio_path,
210 | **request_data
211 | )
212 |
213 | # Force garbage collection after model inference
214 | gc.collect()
215 |
216 | return result
217 |
218 | except Exception as e:
219 | logger.error(f"Error processing audio transcription request: {str(e)}")
220 | # Clean up on error
221 | gc.collect()
222 | raise
223 |
224 | async def _save_uploaded_file(self, file) -> str:
225 | """
226 | Save the uploaded file to a temporary location.
227 |
228 | Args:
229 | file: The uploaded file object.
230 |
231 | Returns:
232 | str: Path to the temporary file.
233 | """
234 | try:
235 | # Create a temporary file with the same extension as the uploaded file
236 | file_extension = os.path.splitext(file.filename)[1] if file.filename else ".wav"
237 |
238 |         logger.debug(f"Uploaded file extension: {file_extension}")
239 |
240 | # Read file content first (this can only be done once with FastAPI uploads)
241 | content = await file.read()
242 |
243 | # Create temporary file
244 | with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
245 | # Write the file contents
246 | temp_file.write(content)
247 | temp_path = temp_file.name
248 |
249 | logger.debug(f"Saved uploaded file to temporary location: {temp_path}")
250 | return temp_path
251 |
252 | except Exception as e:
253 | logger.error(f"Error saving uploaded file: {str(e)}")
254 | raise
255 |
256 | async def _prepare_transcription_request(
257 | self,
258 | request: TranscriptionRequest
259 | ) -> Dict[str, Any]:
260 | """
261 | Prepare a transcription request by parsing model parameters.
262 |
263 | Args:
264 |             request: TranscriptionRequest object containing the uploaded
265 |                 audio file and optional decoding parameters.
266 |
267 | Returns:
268 | Dict containing the request data ready for the model.
269 | """
270 | try:
271 |
272 | file = request.file
273 |
274 | file_path = await self._save_uploaded_file(file)
275 | request_data = {
276 | "audio_path": file_path,
277 | "verbose": False,
278 | }
279 |
280 | # Add optional parameters if provided
281 | if request.temperature is not None:
282 | request_data["temperature"] = request.temperature
283 |
284 | if request.language is not None:
285 | request_data["language"] = request.language
286 |
287 | if request.prompt is not None:
288 | request_data["initial_prompt"] = request.prompt
289 |
290 | # Map additional parameters if they exist
291 | decode_options = {}
292 | if request.language is not None:
293 | decode_options["language"] = request.language
294 |
295 | # Add decode options to request data
296 | request_data.update(decode_options)
297 |
298 | logger.debug(f"Prepared transcription request: {request_data}")
299 |
300 | return request_data
301 |
302 | except Exception as e:
303 | logger.error(f"Failed to prepare transcription request: {str(e)}")
304 | content = create_error_response(
305 | f"Failed to process request: {str(e)}",
306 | "bad_request",
307 | HTTPStatus.BAD_REQUEST
308 | )
309 | raise HTTPException(status_code=400, detail=content)
310 |
311 | async def get_queue_stats(self) -> Dict[str, Any]:
312 | """
313 | Get statistics from the request queue and performance metrics.
314 |
315 | Returns:
316 | Dict with queue and performance statistics.
317 | """
318 | queue_stats = self.request_queue.get_queue_stats()
319 |
320 | return {
321 | "queue_stats": queue_stats,
322 | }
323 |
324 | async def cleanup(self):
325 | """
326 | Cleanup resources and stop the request queue before shutdown.
327 |
328 | This method ensures all pending requests are properly cancelled
329 | and resources are released.
330 | """
331 | try:
332 | logger.info("Cleaning up MLXWhisperHandler resources")
333 | if hasattr(self, 'request_queue'):
334 | await self.request_queue.stop()
335 | logger.info("MLXWhisperHandler cleanup completed successfully")
336 | except Exception as e:
337 | logger.error(f"Error during MLXWhisperHandler cleanup: {str(e)}")
338 | raise
339 |
340 |
--------------------------------------------------------------------------------
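
A minimal usage sketch (not part of the repository) showing how MLXWhisperHandler could be driven directly from an asyncio script. The model path and audio file below are placeholders, and the request dictionary mirrors what _prepare_transcription_request builds internally.

    import asyncio
    from app.handler.mlx_whisper import MLXWhisperHandler

    async def main():
        # Placeholder checkpoint; any MLX Whisper model directory or hub id would go here.
        handler = MLXWhisperHandler(model_path="mlx-community/whisper-large-v3-mlx")
        await handler.initialize({"max_concurrency": 1, "timeout": 600, "queue_size": 50})
        # Submit prepared request data to the queue, just as the handler does internally.
        result = await handler.request_queue.submit(
            "transcription-demo",
            {"audio_path": "examples/audios/audio.wav", "verbose": False},
        )
        print(result["text"])
        await handler.cleanup()

    asyncio.run(main())
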
/app/handler/parser/__init__.py:
--------------------------------------------------------------------------------
1 | from app.handler.parser.harmony import HarmonyParser
2 | from app.handler.parser.base import BaseToolParser, BaseThinkingParser
3 | from app.handler.parser.qwen3 import Qwen3ToolParser, Qwen3ThinkingParser
4 | from app.handler.parser.glm4_moe import Glm4MoEToolParser, Glm4MoEThinkingParser
5 |
6 |
7 | __all__ = ['BaseToolParser', 'BaseThinkingParser', 'Qwen3ToolParser', 'Qwen3ThinkingParser', 'HarmonyParser', 'Glm4MoEToolParser', 'Glm4MoEThinkingParser']
--------------------------------------------------------------------------------
/app/handler/parser/base.py:
--------------------------------------------------------------------------------
1 | import json
2 | from json_repair import repair_json
3 | from typing import Any, Dict, List, Optional, Tuple
4 |
5 |
6 | class BaseThinkingParser:
7 | def __init__(self, thinking_open: str, thinking_close: str):
8 | self.thinking_open = thinking_open
9 | self.thinking_close = thinking_close
10 | self.is_thinking = False
11 |
12 | def parse(self, content: str) -> Tuple[Optional[str], str]:
13 | if self.thinking_open in content:
14 | start_thinking = content.find(self.thinking_open)
15 | end_thinking = content.find(self.thinking_close)
16 | if end_thinking != -1:
17 | return content[start_thinking + len(self.thinking_open):end_thinking].strip(), content[end_thinking + len(self.thinking_close):].strip()
18 | return None, content
19 |
20 | def parse_stream(self, chunk: Optional[str] = None) -> Tuple[Optional[Any], bool]:
21 | """
22 | Parse streaming chunks for thinking content.
23 |
24 | Returns:
25 | Tuple[parsed_content, is_complete]:
26 | - parsed_content: The parsed chunk (could be str, dict, or None)
27 | - is_complete: True if thinking section is complete
28 | """
29 | if not self.is_thinking:
30 | if chunk == self.thinking_open:
31 | self.is_thinking = True
32 | return None, False
33 | return chunk, False
34 | if chunk == self.thinking_close:
35 | self.is_thinking = False
36 | return None, True
37 |
38 | return {
39 | "reasoning_content": chunk
40 | }, False
41 |
42 | class ParseToolState:
43 | NORMAL = 0
44 | FOUND_PREFIX = 1
45 |
46 | class BaseToolParser:
47 | def __init__(self, tool_open: str, tool_close: str):
48 | self.tool_open = tool_open
49 | self.tool_close = tool_close
50 | self.buffer = ""
51 | self.state = ParseToolState.NORMAL
52 |
53 | def get_tool_open(self):
54 | return self.tool_open
55 |
56 | def get_tool_close(self):
57 | return self.tool_close
58 |
59 | def parse(self, content: str) -> Tuple[Optional[List[Dict[str, Any]]], str]:
60 | tool_calls = []
61 | remaining_content = ""
62 | start = 0
63 | while True:
64 | start_tool = content.find(self.tool_open, start)
65 | if start_tool == -1:
66 | break
67 | remaining_content += content[:start_tool].strip()
68 | end_tool = content.find(self.tool_close, start_tool + len(self.tool_open))
69 | if end_tool == -1:
70 | break
71 | tool_content = content[start_tool + len(self.tool_open):end_tool].strip()
72 |
73 | try:
74 | repaired_json = repair_json(tool_content)
75 | json_output = json.loads(repaired_json)
76 | tool_calls.append(json_output)
77 | except json.JSONDecodeError:
78 | print("Error parsing tool call: ", tool_content)
79 | break
80 | content = content[end_tool + len(self.tool_close):].strip()
81 | return tool_calls, remaining_content
82 |
83 | def parse_stream(self, chunk: Optional[str] = None) -> Tuple[Optional[Any], bool]:
84 | """
85 | Parse streaming chunks for tool calls.
86 |
87 | Returns:
88 | Tuple[parsed_content, is_complete]:
89 | - parsed_content: The parsed chunk (could be str, dict, or None)
90 | - is_complete: True if tool call is complete
91 | """
92 | if chunk is None:
93 | return None, True
94 |
95 | if self.tool_open in chunk:
96 | self.state = ParseToolState.FOUND_PREFIX
97 | start_tool_index = chunk.find(self.tool_open)
98 | end_tool_index = chunk.find(self.tool_close)
99 | if end_tool_index != -1:
100 | self.buffer = chunk[start_tool_index + len(self.tool_open):end_tool_index]
101 | self.state = ParseToolState.NORMAL
102 | try:
103 | repaired_json = repair_json(self.buffer)
104 | json_output = json.loads(repaired_json)
105 | except json.JSONDecodeError:
106 | print("Error parsing tool call: ", self.buffer)
107 | return None, True
108 | return {
109 | "name": json_output["name"],
110 | "arguments": json.dumps(json_output["arguments"])
111 | }, True
112 |
113 | self.buffer += chunk[start_tool_index + len(self.tool_open):]
114 |
115 | return chunk[:start_tool_index], False
116 |
117 | if self.state == ParseToolState.FOUND_PREFIX:
118 | end_tool_index = chunk.find(self.tool_close)
119 | if end_tool_index != -1:
120 | self.buffer += chunk[:end_tool_index]
121 | try:
122 | repaired_json = repair_json(self.buffer)
123 | json_output = json.loads(repaired_json)
124 | except json.JSONDecodeError:
125 | print("Error parsing tool call: ", self.buffer)
126 | return None, False
127 | return {
128 | "name": json_output["name"],
129 | "arguments": json.dumps(json_output["arguments"])
130 | }, True
131 | else:
132 | self.buffer += chunk
133 | return None, False
134 |
135 | return chunk, False
136 |
--------------------------------------------------------------------------------
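
To make the contract of BaseToolParser.parse concrete, here is a small sketch (not from the repository); the <tool_call> delimiters are example values passed to the constructor, matching the Qwen3-style format used elsewhere in this package.

    from app.handler.parser.base import BaseToolParser

    # Example delimiters; concrete parsers pass their model-specific tags.
    parser = BaseToolParser(tool_open="<tool_call>", tool_close="</tool_call>")
    completion = (
        "Let me check the weather. "
        '<tool_call>{"name": "get_weather", "arguments": {"city": "Hanoi"}}</tool_call>'
    )
    tool_calls, remaining = parser.parse(completion)
    print(tool_calls)   # [{'name': 'get_weather', 'arguments': {'city': 'Hanoi'}}]
    print(remaining)    # Let me check the weather.
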
/app/handler/parser/glm4_moe.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | from typing import Any, Dict, List, Optional, Tuple
4 | from app.handler.parser.base import BaseToolParser, BaseThinkingParser
5 |
6 | TOOL_OPEN = "<tool_call>"
7 | TOOL_CLOSE = "</tool_call>"
8 | THINKING_OPEN = "<think>"
9 | THINKING_CLOSE = "</think>"
10 |
11 | class Glm4MoEThinkingParser(BaseThinkingParser):
12 | """Parser for GLM4 model's thinking response format."""
13 |
14 | def __init__(self):
15 | super().__init__(
16 | thinking_open=THINKING_OPEN,
17 | thinking_close=THINKING_CLOSE
18 | )
19 |
20 | class Glm4MoEToolParser(BaseToolParser):
21 | """Parser for GLM4 model's tool response format with XML-style arguments."""
22 |
23 | def __init__(self):
24 | super().__init__(
25 | tool_open=TOOL_OPEN,
26 | tool_close=TOOL_CLOSE
27 | )
28 | # Regex patterns for parsing GLM4 XML-style tool calls
29 |         self.func_call_regex = re.compile(r"<tool_call>.*?</tool_call>", re.DOTALL)
30 |         self.func_detail_regex = re.compile(
31 |             r"<tool_call>([^\n]*)\n(.*)</tool_call>", re.DOTALL
32 |         )
33 |         self.func_arg_regex = re.compile(
34 |             r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>", re.DOTALL
35 |         )
36 | # State for streaming parsing
37 | self.stream_buffer = ""
38 | self.current_func_name = None
39 | self.current_args = {}
40 | self.parsing_tool = False
41 |
42 | def _deserialize_value(self, value: str) -> Any:
43 | """Try to deserialize a value from string to appropriate Python type."""
44 | value = value.strip()
45 |
46 | # Try JSON parsing first
47 | try:
48 | return json.loads(value)
49 | except (json.JSONDecodeError, ValueError):
50 | pass
51 |
52 | # Try literal eval for Python literals
53 | try:
54 | import ast
55 | return ast.literal_eval(value)
56 | except (ValueError, SyntaxError):
57 | pass
58 |
59 | # Return as string if all else fails
60 | return value
61 |
62 | def parse(self, content: str) -> Tuple[List[Dict[str, Any]], str]:
63 | """
64 | Parse complete content for GLM4 tool calls.
65 |
66 | Returns:
67 | Tuple of (list of tool calls, remaining content)
68 | """
69 | tool_calls = []
70 | matched_calls = self.func_call_regex.findall(content)
71 |
72 | try:
73 | for match in matched_calls:
74 | # Extract function name and arguments section
75 | detail_match = self.func_detail_regex.search(match)
76 | if not detail_match:
77 | continue
78 |
79 | func_name = detail_match.group(1).strip()
80 | args_section = detail_match.group(2)
81 |
82 | # Extract all key-value pairs
83 | arg_pairs = self.func_arg_regex.findall(args_section)
84 | arguments = {}
85 | for key, value in arg_pairs:
86 | arg_key = key.strip()
87 | arg_value = self._deserialize_value(value)
88 | arguments[arg_key] = arg_value
89 |
90 | # Build tool call object
91 | tool_calls.append({
92 | "name": func_name,
93 | "arguments": json.dumps(arguments)
94 | })
95 | except Exception as e:
96 | print(f"Error parsing GLM4 tool call: {e}")
97 |
98 | # Find content before first tool call
99 | first_tool_idx = content.find(self.tool_open)
100 | if first_tool_idx != -1:
101 | remaining_content = content[:first_tool_idx].strip()
102 | else:
103 | remaining_content = content.strip()
104 |
105 | return tool_calls, remaining_content
106 |
107 | def parse_stream(self, chunk: str) -> Tuple[Optional[Any], bool]:
108 | """
109 | Parse streaming chunks for GLM4 tool calls.
110 |
111 | This handles the XML-style format incrementally.
112 |
113 | Returns:
114 | Tuple[parsed_content, is_complete]:
115 | - parsed_content: The parsed chunk (could be str, dict, or None)
116 | - is_complete: True if tool call is complete
117 | """
118 | if chunk is None:
119 | return None, False
120 |
121 | self.stream_buffer += chunk
122 |
123 | # Check if we're starting a tool call
124 | if not self.parsing_tool:
125 | if self.tool_open in self.stream_buffer:
126 | tool_start_idx = self.stream_buffer.find(self.tool_open)
127 | # Return any content before the tool call
128 | content_before = self.stream_buffer[:tool_start_idx]
129 | self.stream_buffer = self.stream_buffer[tool_start_idx + len(self.tool_open):]
130 | self.parsing_tool = True
131 | self.current_func_name = None
132 | self.current_args = {}
133 |
134 | if content_before:
135 | return content_before, False
136 | return None, False
137 | else:
138 | # No tool call found yet, return the content (except last few chars as buffer)
139 | if len(self.stream_buffer) > len(self.tool_open):
140 | content_to_return = self.stream_buffer[:-len(self.tool_open)]
141 | self.stream_buffer = self.stream_buffer[-len(self.tool_open):]
142 | if content_to_return:
143 | return content_to_return, False
144 | return None, False
145 |
146 | # We're inside a tool call
147 | if self.tool_close in self.stream_buffer:
148 | tool_end_idx = self.stream_buffer.find(self.tool_close)
149 | tool_content = self.stream_buffer[:tool_end_idx]
150 | self.stream_buffer = self.stream_buffer[tool_end_idx + len(self.tool_close):]
151 |
152 | # Parse the complete tool call
153 | full_tool = f"{self.tool_open}{tool_content}{self.tool_close}"
154 | parsed_tools, _ = self.parse(full_tool)
155 |
156 | self.parsing_tool = False
157 | self.current_func_name = None
158 | self.current_args = {}
159 |
160 | if parsed_tools:
161 | tool = parsed_tools[0]
162 | # Return the complete tool call information
163 | return {
164 | "name": tool["name"],
165 | "arguments": tool["arguments"]
166 | }, True # Tool call complete
167 | return None, True
168 |
169 | # Still accumulating the tool call
170 | # Try to extract function name if we haven't yet
171 | if self.current_func_name is None:
172 | if '\n' in self.stream_buffer or len(self.stream_buffer) > 50:
173 | # Extract function name (first line)
174 | newline_idx = self.stream_buffer.find('\n')
175 | if newline_idx != -1:
176 | self.current_func_name = self.stream_buffer[:newline_idx].strip()
177 | self.stream_buffer = self.stream_buffer[newline_idx + 1:]
178 | # Return function name
179 | return {
180 | "name": self.current_func_name,
181 | "arguments": ""
182 | }, False
183 |
184 | # Check if we can parse any complete argument pairs
185 | arg_matches = list(self.func_arg_regex.finditer(self.stream_buffer))
186 | if arg_matches:
187 | last_match = arg_matches[-1]
188 | # Only process if we have the complete closing tag
189 | if last_match.end() < len(self.stream_buffer):
190 | for match in arg_matches:
191 | arg_key = match.group(1).strip()
192 | arg_value = self._deserialize_value(match.group(2))
193 | if arg_key not in self.current_args:
194 | self.current_args[arg_key] = arg_value
195 |
196 | # Remove processed content from buffer
197 | self.stream_buffer = self.stream_buffer[last_match.end():]
198 |
199 | # Return incremental arguments
200 | if self.current_args:
201 | return {
202 | "name": None,
203 | "arguments": json.dumps(self.current_args)
204 | }, False
205 |
206 | return None, False
--------------------------------------------------------------------------------
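
Assuming the GLM-4 style <tool_call>/<arg_key>/<arg_value> markup shown above, a hand-written completion parses roughly as follows (illustrative sketch, not repository code):

    from app.handler.parser.glm4_moe import Glm4MoEToolParser

    parser = Glm4MoEToolParser()
    completion = (
        "Checking the forecast.\n"
        "<tool_call>get_weather\n"
        "<arg_key>city</arg_key>\n"
        "<arg_value>Hanoi</arg_value>\n"
        "</tool_call>"
    )
    tool_calls, remaining = parser.parse(completion)
    print(tool_calls)   # [{'name': 'get_weather', 'arguments': '{"city": "Hanoi"}'}]
    print(remaining)    # Checking the forecast.
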
/app/handler/parser/harmony.py:
--------------------------------------------------------------------------------
1 | from openai_harmony import (
2 | load_harmony_encoding,
3 | HarmonyEncodingName,
4 | StreamableParser,
5 | Role
6 | )
7 | from typing import Tuple, Dict, List, Optional, Any, Union
8 | import logging
9 | from enum import Enum
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 | class ChannelType(Enum):
14 | """Enumeration of harmony channel types."""
15 | ANALYSIS = "analysis"
16 | COMMENTARY = "commentary"
17 | FINAL = "final"
18 |
19 | class ParsingState(Enum):
20 | """Enumeration of parsing states."""
21 | IDLE = "idle"
22 | PROCESSING_TOKENS = "processing_tokens"
23 | TOOL_PARSING = "tool_parsing"
24 | STREAM_ENDED = "stream_ended"
25 |
26 | # Harmony Parsing Helper Functions
27 | class HarmonyParser:
28 | """
29 | Enhanced helper class for parsing GPT-OSS model responses using harmony encoding.
30 |
31 | This parser handles streaming and non-streaming responses with proper state management,
32 | error handling, and support for different harmony channels (analysis, commentary, final).
33 | """
34 |
35 | def __init__(self):
36 | """Initialize the harmony parser with encoding and state management."""
37 | try:
38 | self.enc = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
39 | self.parser = StreamableParser(self.enc, role=Role.ASSISTANT)
40 | except Exception as e:
41 | logger.error(f"Failed to initialize harmony encoding: {e}")
42 | raise
43 |
44 | # Configuration
45 | self.end_tool_chunk = "<|call|>"
46 |
47 | # State management
48 | self._reset_state()
49 |
50 | def _reset_state(self) -> None:
51 | """Reset the parser state to initial values."""
52 | self.tool_state = False
53 | self.end_stream = False
54 | self.parsing_state = ParsingState.IDLE
55 | self._accumulated_content = {
56 | ChannelType.ANALYSIS.value: [],
57 | ChannelType.COMMENTARY.value: [],
58 | ChannelType.FINAL.value: []
59 | }
60 | self._current_function_name = None
61 | self._function_arguments = []
62 |
63 | def parse_stream(self, text: Optional[str] = None) -> Tuple[Optional[Any], bool]:
64 | """
65 | Parse streaming text input and return parsing state and extracted content.
66 |
67 | Args:
68 | text: The text chunk to parse, or None for empty chunks
69 |
70 | Returns:
71 | Tuple[parsed_content, is_complete]:
72 | - parsed_content: The parsed chunk (could be str, dict, or None)
73 | - is_complete: True if stream has ended
74 |
75 | Raises:
76 | Exception: If encoding or parsing fails
77 | """
78 | # Handle end of stream marker
79 | if text == self.end_tool_chunk:
80 | logger.debug("End tool chunk detected, marking stream as ended")
81 | self.end_stream = True
82 | self.parsing_state = ParsingState.STREAM_ENDED
83 | return None, True
84 |
85 | # Handle empty or None text
86 | if not text:
87 | return None, self.end_stream
88 |
89 | try:
90 | self.parsing_state = ParsingState.PROCESSING_TOKENS
91 | text_tokens = self.enc.encode(text, allowed_special="all")
92 |
93 | # Initialize local variables for this chunk
94 | contents: List[str] = []
95 | function_name: Optional[str] = None
96 | function_arguments: List[str] = []
97 | reasoning_content: List[str] = []
98 | current_channel: Optional[str] = None
99 |
100 | # Process each token
101 | for text_token in text_tokens:
102 | try:
103 | stream_text = self.parser.process(text_token)
104 | current_channel = stream_text.current_channel
105 | content = stream_text.last_content_delta
106 |
107 | if not content:
108 | continue
109 |
110 | # Handle different channels
111 | if current_channel == ChannelType.ANALYSIS.value:
112 | reasoning_content.append(content)
113 | self._accumulated_content[ChannelType.ANALYSIS.value].append(content)
114 |
115 | elif current_channel == ChannelType.COMMENTARY.value:
116 | self.parsing_state = ParsingState.TOOL_PARSING
117 |
118 | if self.tool_state:
119 | # Already parsing function arguments
120 | function_arguments.append(content)
121 | self._function_arguments.append(content)
122 | else:
123 | # Start of new function call
124 | self.tool_state = True
125 | if hasattr(stream_text, 'current_recipient') and stream_text.current_recipient:
126 | function_name = stream_text.current_recipient.replace("functions.", "")
127 | self._current_function_name = function_name
128 | function_arguments = [content]
129 | self._function_arguments = [content]
130 |
131 | elif current_channel == ChannelType.FINAL.value:
132 | contents.append(content)
133 | self._accumulated_content[ChannelType.FINAL.value].append(content)
134 |
135 | except Exception as token_error:
136 | logger.warning(f"Error processing token {text_token}: {token_error}")
137 | continue
138 |
139 | # Return appropriate response based on current channel
140 | return self._build_response(current_channel, {
141 | 'reasoning_content': reasoning_content,
142 | 'function_name': function_name,
143 | 'function_arguments': function_arguments,
144 | 'contents': contents
145 | })
146 |
147 | except Exception as e:
148 | logger.error(f"Error in parse_stream: {e}")
149 | return None, self.end_stream
150 |
151 | def _build_response(self, current_channel: Optional[str], content_data: Dict[str, Any]) -> Tuple[Optional[Union[Dict[str, Any], str]], bool]:
152 | """
153 | Build the appropriate response based on the current channel.
154 |
155 | Args:
156 | current_channel: The current harmony channel being processed
157 | content_data: Dictionary containing extracted content from different sources
158 |
159 | Returns:
160 | Tuple[parsed_content, is_complete]:
161 | - parsed_content: The parsed content (str or dict)
162 | - is_complete: Whether the stream has ended
163 | """
164 | if not current_channel:
165 | return None, self.end_stream
166 |
167 | try:
168 | if current_channel == ChannelType.ANALYSIS.value:
169 | reasoning_content = content_data.get('reasoning_content', [])
170 | if reasoning_content:
171 | return {
172 | "reasoning_content": "".join(reasoning_content)
173 | }, self.end_stream
174 |
175 | elif current_channel == ChannelType.COMMENTARY.value:
176 | function_name = content_data.get('function_name')
177 | function_arguments = content_data.get('function_arguments', [])
178 |
179 | response = {}
180 | if function_name:
181 | response["name"] = function_name
182 | if function_arguments:
183 | response["arguments"] = "".join(function_arguments)
184 |
185 | if response:
186 | return response, self.end_stream
187 |
188 | elif current_channel == ChannelType.FINAL.value:
189 | contents = content_data.get('contents', [])
190 | if contents:
191 | return "".join(contents), self.end_stream
192 | except Exception as e:
193 | logger.error(f"Error building response for channel {current_channel}: {e}")
194 |
195 | return None, self.end_stream
196 |
197 | def reset(self) -> None:
198 | """Reset the parser to initial state for reuse."""
199 | logger.debug("Resetting harmony parser state")
200 | self._reset_state()
201 |
202 | def get_accumulated_content(self, channel: Optional[str] = None) -> Dict[str, str]:
203 | """
204 | Get accumulated content for all channels or a specific channel.
205 |
206 | Args:
207 | channel: Optional specific channel to retrieve content for
208 |
209 | Returns:
210 | Dictionary of channel content
211 | """
212 | if channel and channel in self._accumulated_content:
213 | return {channel: "".join(self._accumulated_content[channel])}
214 |
215 | return {
216 | ch: "".join(content) for ch, content in self._accumulated_content.items()
217 | if content
218 | }
219 |
220 | def parse(self, text: str) -> Dict[str, Any]:
221 | """
222 | Parse complete text response and extract structured content.
223 |
224 | This method processes the entire text at once (non-streaming) and extracts
225 | reasoning content, tool calls, and final content based on harmony channels.
226 |
227 | Args:
228 | text: The complete text response to parse
229 |
230 | Returns:
231 | Dictionary containing parsed content with keys:
232 | - reasoning_content: Analysis/thinking content
233 | - tool_calls: List of tool call objects
234 | - content: Final response content
235 |
236 | Raises:
237 | Exception: If encoding or parsing fails
238 | """
239 | # Initialize result structure
240 | result = {
241 | "reasoning_content": None,
242 | "tool_calls": None,
243 | "content": None
244 | }
245 |
246 | if not text:
247 | logger.warning("Empty text provided to parse method")
248 | return result
249 |
250 | try:
251 | # Remove end tool chunk if present
252 | clean_text = text
253 | if self.end_tool_chunk in text:
254 | clean_text = text.split(self.end_tool_chunk)[0]
255 | logger.debug(f"Removed end tool chunk, processing {len(clean_text)} characters")
256 |
257 | # Encode and parse messages
258 | tokens = self.enc.encode(clean_text, allowed_special="all")
259 | parsed_messages = self.enc.parse_messages_from_completion_tokens(tokens, role=Role.ASSISTANT)
260 |
261 | # Process each parsed message
262 | for message in parsed_messages:
263 | try:
264 | if not hasattr(message, 'channel') or not hasattr(message, 'content'):
265 | logger.warning(f"Invalid message structure: {message}")
266 | continue
267 |
268 | if message.channel == ChannelType.ANALYSIS.value:
269 | if message.content and len(message.content) > 0:
270 | result["reasoning_content"] = message.content[0].text
271 | logger.debug("Extracted reasoning content")
272 |
273 | elif message.channel == ChannelType.COMMENTARY.value:
274 | if (hasattr(message, 'recipient') and message.recipient and
275 | message.content and len(message.content) > 0):
276 |
277 | tool_call = {
278 | "name": message.recipient.replace("functions.", ""),
279 | "arguments": message.content[0].text
280 | }
281 | result["tool_calls"] = [tool_call]
282 | logger.debug(f"Extracted tool call: {tool_call['name']}")
283 |
284 | elif message.channel == ChannelType.FINAL.value:
285 | if message.content and len(message.content) > 0:
286 | result["content"] = message.content[0].text
287 | logger.debug("Extracted final content")
288 |
289 | except Exception as msg_error:
290 | logger.warning(f"Error processing message: {msg_error}")
291 | continue
292 |
293 | except Exception as e:
294 | logger.error(f"Error in parse method: {e}")
295 | # Return partial results if available, don't raise
296 |
297 | return result
298 |
299 | def is_stream_ended(self) -> bool:
300 | """Check if the stream has ended."""
301 | return self.end_stream
302 |
303 | def get_parsing_state(self) -> ParsingState:
304 | """Get the current parsing state."""
305 | return self.parsing_state
306 |
307 | def is_tool_parsing_active(self) -> bool:
308 | """Check if currently parsing tool calls."""
309 | return self.tool_state
310 |
311 | def get_current_function_info(self) -> Dict[str, Optional[str]]:
312 | """
313 | Get information about the currently parsed function.
314 |
315 | Returns:
316 | Dictionary with function name and accumulated arguments
317 | """
318 | return {
319 | "name": self._current_function_name,
320 | "arguments": "".join(self._function_arguments) if self._function_arguments else None
321 | }
322 |
323 | def __repr__(self) -> str:
324 | """String representation of the parser state."""
325 | return (f"HarmonyParser(state={self.parsing_state.value}, "
326 | f"tool_state={self.tool_state}, "
327 | f"stream_ended={self.end_stream})")
--------------------------------------------------------------------------------
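
For the non-streaming path, a sketch of HarmonyParser.parse on a hand-written harmony-format completion (written in the style of the openai_harmony documentation, not captured model output):

    from app.handler.parser.harmony import HarmonyParser

    parser = HarmonyParser()
    # Hand-written completion: an analysis channel followed by a final channel.
    raw_completion = (
        "<|channel|>analysis<|message|>User asks a simple question; answer directly.<|end|>"
        "<|start|>assistant<|channel|>final<|message|>2 + 2 = 4.<|return|>"
    )
    result = parser.parse(raw_completion)
    print(result["reasoning_content"])  # analysis channel text
    print(result["content"])            # final channel text
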
/app/handler/parser/qwen3.py:
--------------------------------------------------------------------------------
1 | from app.handler.parser.base import BaseToolParser, BaseThinkingParser
2 |
3 | TOOL_OPEN = "<tool_call>"
4 | TOOL_CLOSE = "</tool_call>"
5 | THINKING_OPEN = "<think>"
6 | THINKING_CLOSE = "</think>"
7 |
8 | class Qwen3ToolParser(BaseToolParser):
9 | """Parser for Qwen3 model's tool response format."""
10 |
11 | def __init__(self):
12 | super().__init__(
13 | tool_open=TOOL_OPEN,
14 | tool_close=TOOL_CLOSE
15 | )
16 |
17 | class Qwen3ThinkingParser(BaseThinkingParser):
18 | """Parser for Qwen3 model's thinking response format."""
19 |
20 | def __init__(self):
21 | super().__init__(
22 | thinking_open=THINKING_OPEN,
23 | thinking_close=THINKING_CLOSE
24 | )
--------------------------------------------------------------------------------
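
Assuming the <think> delimiters restored above, the Qwen3 thinking parser separates reasoning from the visible answer like this (illustrative sketch):

    from app.handler.parser.qwen3 import Qwen3ThinkingParser

    parser = Qwen3ThinkingParser()
    completion = "<think>The user greets me, so greet back.</think>Hello! How can I help?"
    reasoning, answer = parser.parse(completion)
    print(reasoning)  # The user greets me, so greet back.
    print(answer)     # Hello! How can I help?
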
/app/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import asyncio
3 | import gc
4 | import time
5 | from contextlib import asynccontextmanager
6 |
7 | import uvicorn
8 | from fastapi import FastAPI, Request
9 | from fastapi.middleware.cors import CORSMiddleware
10 | from fastapi.responses import JSONResponse
11 | from loguru import logger
12 |
13 | import mlx.core as mx
14 | from app.handler.mlx_vlm import MLXVLMHandler
15 | from app.handler.mlx_lm import MLXLMHandler
16 | from app.handler.mlx_embeddings import MLXEmbeddingsHandler
17 | from app.handler.mlx_whisper import MLXWhisperHandler
18 | from app.handler import MLXFluxHandler, MFLUX_AVAILABLE
19 | from app.api.endpoints import router
20 | from app.version import __version__
21 |
22 | def configure_logging(log_file=None, no_log_file=False, log_level="INFO"):
23 | """Configure loguru logging based on CLI parameters."""
24 | logger.remove() # Remove default handler
25 |
26 | # Add console handler
27 | logger.add(
28 |         lambda msg: print(msg, end=""),  # loguru already appends a newline
29 | level=log_level,
30 | format="{time:YYYY-MM-DD HH:mm:ss} | "
31 | "{level: <8} | "
32 | "{name}:{function}:{line} | "
33 | "✦ {message}",
34 | colorize=True
35 | )
36 |
37 | # Add file handler if not disabled
38 | if not no_log_file:
39 | file_path = log_file if log_file else "logs/app.log"
40 | logger.add(
41 | file_path,
42 | rotation="500 MB",
43 | retention="10 days",
44 | level=log_level,
45 | format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}"
46 | )
47 |
48 | # Logging will be configured in setup_server() with CLI arguments
49 |
50 | def parse_args():
51 | parser = argparse.ArgumentParser(description="MLX OpenAI Compatible Server")
52 | parser.add_argument("--model-path", type=str, help="Path to the model (required for lm, multimodal, image-generation, image-edit, embeddings, whisper model types). With `image-generation` or `image-edit` model types, it should be the local path to the model.")
53 | parser.add_argument("--model-type", type=str, default="lm", choices=["lm", "multimodal", "image-generation", "image-edit", "embeddings", "whisper"], help="Model type")
54 | parser.add_argument("--context-length", type=int, default=None, help="Context length for language models. Only works with `lm` or `multimodal` model types.")
55 | parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
56 | parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on")
57 | parser.add_argument("--max-concurrency", type=int, default=1, help="Maximum number of concurrent requests")
58 | parser.add_argument("--queue-timeout", type=int, default=300, help="Request timeout in seconds")
59 | parser.add_argument("--queue-size", type=int, default=100, help="Maximum queue size for pending requests")
60 | parser.add_argument("--quantize", type=int, default=8, help="Quantization level for the model. Only used for image-generation and image-edit Flux models.")
61 | parser.add_argument("--config-name", type=str, default=None, choices=["flux-schnell", "flux-dev", "flux-krea-dev", "flux-kontext-dev"], help="Config name of the model. Only used for image-generation and image-edit Flux models.")
62 | parser.add_argument("--lora-paths", type=str, default=None, help="Path to the LoRA file(s). Multiple paths should be separated by commas.")
63 | parser.add_argument("--lora-scales", type=str, default=None, help="Scale factor for the LoRA file(s). Multiple scales should be separated by commas.")
64 |     parser.add_argument("--disable-auto-resize", action="store_true", help="Disable automatic model resizing. Only works for Vision Language Models.")
65 | parser.add_argument("--log-file", type=str, default=None, help="Path to log file. If not specified, logs will be written to 'logs/app.log' by default.")
66 | parser.add_argument("--no-log-file", action="store_true", help="Disable file logging entirely. Only console output will be shown.")
67 | parser.add_argument("--log-level", type=str, default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="Set the logging level. Default is INFO.")
68 |
69 | args = parser.parse_args()
70 |
71 | return args
72 |
73 |
74 | def get_model_identifier(args):
75 | """Get the appropriate model identifier based on model type."""
76 | return args.model_path
77 |
78 | def create_lifespan(config_args):
79 | """Factory function to create a lifespan context manager with access to config args."""
80 | @asynccontextmanager
81 | async def lifespan(app: FastAPI):
82 | try:
83 | model_identifier = get_model_identifier(config_args)
84 | if config_args.model_type == "image-generation":
85 | logger.info(f"Initializing MLX handler with model name: {model_identifier}")
86 | else:
87 | logger.info(f"Initializing MLX handler with model path: {model_identifier}")
88 |
89 | if config_args.model_type == "multimodal":
90 | handler = MLXVLMHandler(
91 | model_path=model_identifier,
92 | context_length=getattr(config_args, 'context_length', None),
93 | max_concurrency=config_args.max_concurrency,
94 | disable_auto_resize=getattr(config_args, 'disable_auto_resize', False)
95 | )
96 | elif config_args.model_type == "image-generation":
97 | if not MFLUX_AVAILABLE:
98 | raise ValueError("Image generation requires mflux. Install with: pip install git+https://github.com/cubist38/mflux.git")
99 |                 if config_args.config_name not in ["flux-schnell", "flux-dev", "flux-krea-dev"]:
100 | raise ValueError(f"Invalid config name: {config_args.config_name}. Only flux-schnell, flux-dev, and flux-krea-dev are supported for image generation.")
101 | handler = MLXFluxHandler(
102 | model_path=model_identifier,
103 | max_concurrency=config_args.max_concurrency,
104 | quantize=getattr(config_args, 'quantize', 8),
105 | config_name=config_args.config_name,
106 | lora_paths=getattr(config_args, 'lora_paths', None),
107 | lora_scales=getattr(config_args, 'lora_scales', None)
108 | )
109 | elif config_args.model_type == "embeddings":
110 | handler = MLXEmbeddingsHandler(
111 | model_path=model_identifier,
112 | max_concurrency=config_args.max_concurrency
113 | )
114 | elif config_args.model_type == "image-edit":
115 | if not MFLUX_AVAILABLE:
116 | raise ValueError("Image editing requires mflux. Install with: pip install git+https://github.com/cubist38/mflux.git")
117 | if config_args.config_name != "flux-kontext-dev":
118 | raise ValueError(f"Invalid config name: {config_args.config_name}. Only flux-kontext-dev is supported for image edit.")
119 | handler = MLXFluxHandler(
120 | model_path=model_identifier,
121 | max_concurrency=config_args.max_concurrency,
122 | quantize=getattr(config_args, 'quantize', 8),
123 | config_name=config_args.config_name,
124 | lora_paths=getattr(config_args, 'lora_paths', None),
125 | lora_scales=getattr(config_args, 'lora_scales', None)
126 | )
127 | elif config_args.model_type == "whisper":
128 | handler = MLXWhisperHandler(
129 | model_path=model_identifier,
130 | max_concurrency=config_args.max_concurrency
131 | )
132 | else:
133 | handler = MLXLMHandler(
134 | model_path=model_identifier,
135 | context_length=getattr(config_args, 'context_length', None),
136 | max_concurrency=config_args.max_concurrency
137 | )
138 | # Initialize queue
139 | await handler.initialize({
140 | "max_concurrency": config_args.max_concurrency,
141 | "timeout": config_args.queue_timeout,
142 | "queue_size": config_args.queue_size
143 | })
144 | logger.info("MLX handler initialized successfully")
145 | app.state.handler = handler
146 |
147 | except Exception as e:
148 | logger.error(f"Failed to initialize MLX handler: {str(e)}")
149 | raise
150 |
151 | # Initial memory cleanup
152 | mx.clear_cache()
153 | gc.collect()
154 |
155 | yield
156 |
157 | # Shutdown
158 | logger.info("Shutting down application")
159 | if hasattr(app.state, "handler") and app.state.handler:
160 | try:
161 | # Use the proper cleanup method which handles both request queue and image processor
162 | logger.info("Cleaning up resources")
163 | await app.state.handler.cleanup()
164 | logger.info("Resources cleaned up successfully")
165 | except Exception as e:
166 | logger.error(f"Error during shutdown: {str(e)}")
167 |
168 | # Final memory cleanup
169 | mx.clear_cache()
170 | gc.collect()
171 |
172 | return lifespan
173 |
174 | # App instance will be created during setup with the correct lifespan
175 | app = None
176 |
177 | async def setup_server(args) -> uvicorn.Config:
178 | global app
179 |
180 | # Configure logging based on CLI parameters
181 | configure_logging(
182 | log_file=getattr(args, 'log_file', None),
183 | no_log_file=getattr(args, 'no_log_file', False),
184 | log_level=getattr(args, 'log_level', 'INFO')
185 | )
186 |
187 | # Create FastAPI app with the configured lifespan
188 | app = FastAPI(
189 | title="OpenAI-compatible API",
190 | description="API for OpenAI-compatible chat completion and text embedding",
191 | version=__version__,
192 | lifespan=create_lifespan(args)
193 | )
194 |
195 | app.include_router(router)
196 |
197 | # Add CORS middleware
198 | app.add_middleware(
199 | CORSMiddleware,
200 | allow_origins=["*"], # In production, replace with specific origins
201 | allow_credentials=True,
202 | allow_methods=["*"],
203 | allow_headers=["*"],
204 | )
205 |
206 | @app.middleware("http")
207 | async def add_process_time_header(request: Request, call_next):
208 | start_time = time.time()
209 | response = await call_next(request)
210 | process_time = time.time() - start_time
211 | response.headers["X-Process-Time"] = str(process_time)
212 |
213 | # Periodic memory cleanup for long-running processes
214 | if hasattr(request.app.state, 'request_count'):
215 | request.app.state.request_count += 1
216 | else:
217 | request.app.state.request_count = 1
218 |
219 | # Clean up memory every 50 requests
220 | if request.app.state.request_count % 50 == 0:
221 | mx.clear_cache()
222 | gc.collect()
223 | logger.debug(f"Performed memory cleanup after {request.app.state.request_count} requests")
224 |
225 | return response
226 |
227 | @app.exception_handler(Exception)
228 | async def global_exception_handler(request: Request, exc: Exception):
229 | logger.error(f"Global exception handler caught: {str(exc)}", exc_info=True)
230 | return JSONResponse(
231 | status_code=500,
232 | content={"error": {"message": "Internal server error", "type": "internal_error"}}
233 | )
234 |
235 | logger.info(f"Starting server on {args.host}:{args.port}")
236 | config = uvicorn.Config(
237 | app=app,
238 | host=args.host,
239 | port=args.port,
240 | log_level="info",
241 | access_log=True
242 | )
243 | return config
244 |
245 | if __name__ == "__main__":
246 | args = parse_args()
247 | config = asyncio.run(setup_server(args))
248 | uvicorn.Server(config).run()
--------------------------------------------------------------------------------
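
Once the server is running (for example via `python -m app.main --model-path <model> --model-type lm`), it can be exercised with the official OpenAI Python client. The sketch below assumes the default host/port and the standard /v1 chat-completions route exposed by app/api/endpoints.py; the model id and API key are placeholders.

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")
    response = client.chat.completions.create(
        model="local-model",  # the server reports the loaded --model-path as the model id
        messages=[{"role": "user", "content": "Hello!"}],
    )
    print(response.choices[0].message.content)
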
/app/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cubist38/mlx-openai-server/25fc00aa21fc6794358849ebfb9d866a6e203eda/app/models/__init__.py
--------------------------------------------------------------------------------
/app/models/mflux.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from PIL import Image
4 | from abc import ABC, abstractmethod
5 | from mflux.flux.flux import Flux1, Config
6 | from mflux.config.model_config import ModelConfig
7 | from mflux.kontext.flux_kontext import Flux1Kontext
8 | from typing import Dict, Type, Any, Optional, Union, List
9 |
10 |
11 | # Custom Exceptions
12 | class FluxModelError(Exception):
13 | """Base exception for Flux model errors."""
14 | pass
15 |
16 |
17 | class ModelLoadError(FluxModelError):
18 | """Raised when model loading fails."""
19 | pass
20 |
21 |
22 | class ModelGenerationError(FluxModelError):
23 | """Raised when image generation fails."""
24 | pass
25 |
26 |
27 | class InvalidConfigurationError(FluxModelError):
28 | """Raised when configuration is invalid."""
29 | pass
30 |
31 |
32 | class ModelConfiguration:
33 | """Configuration class for Flux models."""
34 |
35 | def __init__(self,
36 | model_type: str,
37 | model_config: Optional[ModelConfig] = None,
38 | quantize: int = 8,
39 | default_steps: int = 20,
40 | default_guidance: float = 2.5,
41 | lora_paths: Optional[List[str]] = None,
42 | lora_scales: Optional[List[float]] = None):
43 |
44 | # Validate quantization level
45 | if quantize not in [4, 8, 16]:
46 | raise InvalidConfigurationError(f"Invalid quantization level: {quantize}. Must be 4, 8, or 16.")
47 |
48 | # Validate LoRA parameters: both must be provided together and have matching lengths
49 | if (lora_paths is None) != (lora_scales is None):
50 | raise InvalidConfigurationError(
51 | "Both lora_paths and lora_scales must be provided together."
52 | )
53 | if lora_paths and lora_scales and len(lora_paths) != len(lora_scales):
54 | raise InvalidConfigurationError(
55 | f"lora_paths and lora_scales must have the same length (got {len(lora_paths)} and {len(lora_scales)})"
56 | )
57 |
58 | self.model_type = model_type
59 | self.model_config = model_config
60 | self.quantize = quantize
61 | self.default_steps = default_steps
62 | self.default_guidance = default_guidance
63 | self.lora_paths = lora_paths
64 | self.lora_scales = lora_scales
65 |
66 | @classmethod
67 | def schnell(cls, quantize: int = 8, lora_paths: Optional[List[str]] = None, lora_scales: Optional[List[float]] = None) -> 'ModelConfiguration':
68 | """Create configuration for Flux Schnell model."""
69 | return cls(
70 | model_type="schnell",
71 | model_config=ModelConfig.schnell(),
72 | quantize=quantize,
73 | default_steps=4,
74 | default_guidance=0.0,
75 | lora_paths=lora_paths,
76 | lora_scales=lora_scales
77 | )
78 |
79 | @classmethod
80 | def dev(cls, quantize: int = 8, lora_paths: Optional[List[str]] = None, lora_scales: Optional[List[float]] = None) -> 'ModelConfiguration':
81 | """Create configuration for Flux Dev model."""
82 | return cls(
83 | model_type="dev",
84 | model_config=ModelConfig.dev(),
85 | quantize=quantize,
86 | default_steps=25,
87 | default_guidance=3.5,
88 | lora_paths=lora_paths,
89 | lora_scales=lora_scales
90 | )
91 |
92 | @classmethod
93 | def krea_dev(cls, quantize: int = 8, lora_paths: Optional[List[str]] = None, lora_scales: Optional[List[float]] = None) -> 'ModelConfiguration':
94 | """Create configuration for Flux Krea Dev model."""
95 | return cls(
96 | model_type="krea-dev",
97 | model_config=ModelConfig.dev(),
98 | quantize=quantize,
99 | default_steps=28,
100 | default_guidance=4.5,
101 | lora_paths=lora_paths,
102 | lora_scales=lora_scales
103 | )
104 |
105 | @classmethod
106 | def kontext(cls, quantize: int = 8) -> 'ModelConfiguration':
107 | """Create configuration for Flux Kontext model."""
108 | return cls(
109 | model_type="kontext",
110 | model_config=None, # Kontext doesn't use ModelConfig
111 | quantize=quantize,
112 | default_steps=28,
113 | default_guidance=2.5,
114 | lora_paths=None, # Kontext doesn't support LoRA
115 | lora_scales=None
116 | )
117 |
118 |
119 | class BaseFluxModel(ABC):
120 | """Abstract base class for Flux models with common functionality."""
121 |
122 | def __init__(self, model_path: str, config: ModelConfiguration):
123 | self.model_path = model_path
124 | self.config = config
125 | self.logger = logging.getLogger(self.__class__.__name__)
126 | self._model = None
127 | self._is_loaded = False
128 |
129 | # Validate model path
130 | if not self._validate_model_path():
131 | raise ModelLoadError(f"Invalid model path: {model_path}")
132 |
133 | self._load_model()
134 |
135 | def _validate_model_path(self) -> bool:
136 | """Validate that the model path exists or is a valid model name."""
137 | # Check if it's a file path
138 | if os.path.exists(self.model_path):
139 | return True
140 |
141 | # Check if it's a valid model name (for downloading)
142 | valid_model_names = ["flux-dev", "flux-schnell", "flux-kontext-dev"]
143 | return self.model_path in valid_model_names
144 |
145 | @abstractmethod
146 | def _load_model(self):
147 | """Load the specific model implementation."""
148 | pass
149 |
150 | @abstractmethod
151 | def _generate_image(self, prompt: str, seed: int, config: Config) -> Image.Image:
152 | """Generate image using the specific model implementation."""
153 | pass
154 |
155 | def __call__(self, prompt: str, seed: int = 42, **kwargs) -> Image.Image:
156 | """Generate an image from a text prompt."""
157 | if not self._is_loaded:
158 | raise ModelLoadError("Model is not loaded. Cannot generate image.")
159 |
160 | # Validate inputs
161 | if not prompt or not prompt.strip():
162 | raise ModelGenerationError("Prompt cannot be empty.")
163 |
164 | if not isinstance(seed, int) or seed < 0:
165 | raise ModelGenerationError("Seed must be a non-negative integer.")
166 |
167 | # Merge default config values with provided kwargs
168 | try:
169 | generation_config = self._prepare_config(**kwargs)
170 | except Exception as e:
171 | raise ModelGenerationError(f"Failed to prepare configuration: {e}")
172 |
173 | self.logger.info(f"Generating image with prompt: '{prompt[:50]}...' "
174 | f"(steps: {generation_config.num_inference_steps}, seed: {seed})")
175 |
176 | try:
177 | result = self._generate_image(prompt, seed, generation_config)
178 | if result is None:
179 | raise ModelGenerationError("Model returned None instead of an image.")
180 |
181 | self.logger.info("Image generated successfully")
182 | return result
183 | except Exception as e:
184 | error_msg = f"Error generating image: {e}"
185 | self.logger.error(error_msg)
186 | raise ModelGenerationError(error_msg) from e
187 |
188 | def _prepare_config(self, **kwargs) -> Config:
189 | """Prepare configuration for image generation."""
190 | # Validate dimensions
191 | width = kwargs.get('width', 1024)
192 | height = kwargs.get('height', 1024)
193 |
194 | if not isinstance(width, int) or width <= 0:
195 | raise ModelGenerationError("Width must be a positive integer.")
196 | if not isinstance(height, int) or height <= 0:
197 | raise ModelGenerationError("Height must be a positive integer.")
198 |
199 | # Validate steps
200 | steps = kwargs.get('num_inference_steps', self.config.default_steps)
201 | if not isinstance(steps, int) or steps <= 0:
202 | raise ModelGenerationError("Number of inference steps must be a positive integer.")
203 |
204 | # Validate guidance
205 | guidance = kwargs.get('guidance', self.config.default_guidance)
206 | if not isinstance(guidance, (int, float)) or guidance < 0:
207 | raise ModelGenerationError("Guidance must be a non-negative number.")
208 |
209 | config_params = {
210 | 'num_inference_steps': steps,
211 | 'guidance': guidance,
212 | 'width': width,
213 | 'height': height
214 | }
215 |
216 | # Add image_path if provided (for inpainting/editing)
217 | if 'image_path' in kwargs:
218 | image_path = kwargs['image_path']
219 | if not os.path.exists(image_path):
220 | raise ModelGenerationError(f"Image path does not exist: {image_path}")
221 | config_params['image_path'] = image_path
222 |
223 | return Config(**config_params)
224 |
225 |
226 | class FluxStandardModel(BaseFluxModel):
227 | """Standard Flux model implementation for Dev and Schnell variants."""
228 |
229 | def _load_model(self):
230 | """Load the standard Flux model."""
231 | try:
232 | self.logger.info(f"Loading {self.config.model_type} model from {self.model_path}")
233 |
234 | # Prepare lora parameters
235 | lora_paths = self.config.lora_paths
236 | lora_scales = self.config.lora_scales
237 |
238 | # Log LoRA information if provided
239 | if lora_paths:
240 | self.logger.info(f"Using LoRA adapters: {lora_paths}")
241 | if lora_scales:
242 | self.logger.info(f"LoRA scales: {lora_scales}")
243 |
244 | self._model = Flux1(
245 | model_config=self.config.model_config,
246 | local_path=self.model_path,
247 | quantize=self.config.quantize,
248 | lora_paths=lora_paths,
249 | lora_scales=lora_scales,
250 | )
251 | self._is_loaded = True
252 | self.logger.info(f"{self.config.model_type} model loaded successfully")
253 | except Exception as e:
254 | error_msg = f"Failed to load {self.config.model_type} model: {e}"
255 | self.logger.error(error_msg)
256 | raise ModelLoadError(error_msg) from e
257 |
258 | def _generate_image(self, prompt: str, seed: int, config: Config) -> Image.Image:
259 | """Generate image using standard Flux model."""
260 | try:
261 | result = self._model.generate_image(
262 | config=config,
263 | prompt=prompt,
264 | seed=seed,
265 | )
266 | return result.image
267 | except Exception as e:
268 | raise ModelGenerationError(f"Standard model generation failed: {e}") from e
269 |
270 |
271 | class FluxKontextModel(BaseFluxModel):
272 | """Flux Kontext model implementation."""
273 |
274 | def _load_model(self):
275 | """Load the Flux Kontext model."""
276 | try:
277 | self.logger.info(f"Loading Kontext model from {self.model_path}")
278 | self._model = Flux1Kontext(
279 | quantize=self.config.quantize,
280 | local_path=self.model_path
281 | )
282 | self._is_loaded = True
283 | self.logger.info("Kontext model loaded successfully")
284 | except Exception as e:
285 | error_msg = f"Failed to load Kontext model: {e}"
286 | self.logger.error(error_msg)
287 | raise ModelLoadError(error_msg) from e
288 |
289 | def _generate_image(self, prompt: str, seed: int, config: Config) -> Image.Image:
290 | """Generate image using Flux Kontext model."""
291 | try:
292 | result = self._model.generate_image(
293 | config=config,
294 | prompt=prompt,
295 | seed=seed,
296 | )
297 | return result.image
298 | except Exception as e:
299 | raise ModelGenerationError(f"Kontext model generation failed: {e}") from e
300 |
301 |
302 | class FluxModel:
303 | """Factory class for creating and managing Flux models."""
304 |
305 | _MODEL_CONFIGS = {
306 | "flux-schnell": ModelConfiguration.schnell,
307 | "flux-dev": ModelConfiguration.dev,
308 | "flux-krea-dev": ModelConfiguration.krea_dev,
309 | "flux-kontext-dev": ModelConfiguration.kontext,
310 | }
311 |
312 | _MODEL_CLASSES = {
313 | "flux-schnell": FluxStandardModel,
314 | "flux-dev": FluxStandardModel,
315 | "flux-krea-dev": FluxStandardModel,
316 | "flux-kontext-dev": FluxKontextModel,
317 | }
318 |
319 | def __init__(self, model_path: str, config_name: str, quantize: int = 8,
320 | lora_paths: Optional[List[str]] = None, lora_scales: Optional[List[float]] = None):
321 |
322 | self.config_name = config_name
323 | self.model_path = model_path
324 | self.quantize = quantize
325 | self.lora_paths = lora_paths
326 | self.lora_scales = lora_scales
327 | self.logger = logging.getLogger(self.__class__.__name__)
328 |
329 | # Validate configuration
330 | if config_name not in self._MODEL_CONFIGS:
331 | available_configs = ", ".join(self._MODEL_CONFIGS.keys())
332 | raise InvalidConfigurationError(f"Invalid config name: {config_name}. Available options: {available_configs}")
333 |
334 | # Validate LoRA parameters for kontext model
335 | if config_name == "flux-kontext-dev" and (lora_paths is not None or lora_scales is not None):
336 | raise InvalidConfigurationError("Flux Kontext model does not support LoRA adapters")
337 |
338 | try:
339 | # Create model configuration
340 | config_factory = self._MODEL_CONFIGS[config_name]
341 | if config_name == "flux-kontext-dev":
342 | self.config = config_factory(quantize=quantize)
343 | else:
344 | self.config = config_factory(quantize=quantize, lora_paths=lora_paths, lora_scales=lora_scales)
345 |
346 | # Create model instance
347 | model_class = self._MODEL_CLASSES[config_name]
348 | self.flux = model_class(model_path, self.config)
349 |
350 | self.logger.info(f"FluxModel initialized successfully with config: {config_name}")
351 | if lora_paths:
352 | self.logger.info(f"LoRA adapters: {lora_paths}")
353 |
354 | except Exception as e:
355 | error_msg = f"Failed to initialize FluxModel: {e}"
356 | self.logger.error(error_msg)
357 | raise ModelLoadError(error_msg) from e
358 |
359 | def __call__(self, prompt: str, seed: int = 42, **kwargs) -> Image.Image:
360 | """Generate an image using the configured model."""
361 | return self.flux(prompt, seed, **kwargs)
362 |
363 | @classmethod
364 | def get_available_configs(cls) -> list[str]:
365 | """Get list of available model configurations."""
366 | return list(cls._MODEL_CONFIGS.keys())
367 |
368 | @classmethod
369 | def get_model_info(cls, config_name: str) -> Dict[str, Any]:
370 | """Get information about a specific model configuration."""
371 | if config_name not in cls._MODEL_CONFIGS:
372 | raise InvalidConfigurationError(f"Unknown config: {config_name}")
373 |
374 | config = cls._MODEL_CONFIGS[config_name]()
375 | return {
376 | "name": config_name,
377 | "type": config.model_type,
378 | "default_steps": config.default_steps,
379 | "default_guidance": config.default_guidance,
380 | "model_class": cls._MODEL_CLASSES[config_name].__name__
381 | }
382 |
383 | def get_current_config(self) -> Dict[str, Any]:
384 | """Get current model configuration information."""
385 | return {
386 | "config_name": self.config_name,
387 | "model_path": self.model_path,
388 | "quantize": self.quantize,
389 | "type": self.config.model_type,
390 | "default_steps": self.config.default_steps,
391 | "default_guidance": self.config.default_guidance,
392 | "is_loaded": self.flux._is_loaded if hasattr(self.flux, '_is_loaded') else False,
393 | "lora_paths": self.config.lora_paths,
394 | "lora_scales": self.config.lora_scales,
395 | }
396 |
397 | def is_loaded(self) -> bool:
398 | """Check if the model is loaded and ready for inference."""
399 | return hasattr(self.flux, '_is_loaded') and self.flux._is_loaded
--------------------------------------------------------------------------------
/app/models/mlx_embeddings.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import mlx.core as mx
3 | from mlx_embeddings.utils import load
4 | from typing import List, Optional
5 |
6 | class MLX_Embeddings:
7 | """
8 | A wrapper class for MLX Embeddings that handles memory management to prevent leaks.
9 |
10 | This class provides a unified interface for generating embeddings from text inputs,
11 | with proper cleanup of MLX arrays and memory management.
12 | """
13 |
14 | def __init__(self, model_path: str):
15 | """
16 | Initialize the MLX_Embeddings model.
17 |
18 | Args:
19 |             model_path (str): Local path or Hugging Face repo ID of the model to load.
20 |
21 | Raises:
22 | ValueError: If model loading fails.
23 | """
24 | try:
25 | self.model, self.tokenizer = load(model_path)
26 | except Exception as e:
27 | raise ValueError(f"Error loading model: {str(e)}")
28 |
29 | def _get_embeddings(self, texts: List[str], max_length: int = 512) -> mx.array:
30 | """
31 | Get embeddings for a list of texts with proper memory management.
32 |
33 | Args:
34 | texts: List of text inputs
35 | max_length: Maximum sequence length for tokenization
36 |
37 | Returns:
38 | MLX array of embeddings
39 | """
40 | inputs = None
41 | outputs = None
42 | try:
43 | # Tokenize inputs
44 | inputs = self.tokenizer.batch_encode_plus(
45 | texts,
46 | return_tensors="mlx",
47 | padding=True,
48 | truncation=True,
49 | max_length=max_length
50 | )
51 |
52 | # Generate embeddings
53 | outputs = self.model(
54 | inputs["input_ids"],
55 | attention_mask=inputs["attention_mask"]
56 | ).text_embeds
57 |
58 | # Return a copy to ensure the result persists after cleanup
59 | return mx.array(outputs)
60 |
61 |         except Exception:
62 |             # The `finally` block below already performs cleanup on error,
63 |             # so just propagate the exception.
64 |             raise
65 | finally:
66 | # Always clean up intermediate arrays
67 | self._cleanup_arrays(inputs, outputs)
68 |
69 | def _cleanup_arrays(self, *arrays):
70 | """Clean up MLX arrays to free memory."""
71 | for array in arrays:
72 | if array is not None:
73 | try:
74 | if isinstance(array, dict):
75 | for key, value in array.items():
76 | if hasattr(value, 'nbytes'):
77 | del value
78 | elif hasattr(array, 'nbytes'):
79 | del array
80 |                 except Exception:
81 | pass
82 |
83 | # Clear MLX cache and force garbage collection
84 | mx.clear_cache()
85 | gc.collect()
86 |
87 | def __call__(self, texts: List[str], max_length: int = 512) -> List[List[float]]:
88 | """
89 | Generate embeddings for a list of texts.
90 |
91 | Args:
92 | texts: List of text inputs
93 | max_length: Maximum sequence length for tokenization
94 |
95 | Returns:
96 | List of embedding vectors as float lists
97 | """
98 | try:
99 | embeddings = self._get_embeddings(texts, max_length)
100 | # Convert to Python list and return
101 | result = embeddings.tolist()
102 | # Clean up the embeddings array
103 | del embeddings
104 | mx.clear_cache()
105 | gc.collect()
106 | return result
107 |         except Exception:
108 | # Clean up on error
109 | mx.clear_cache()
110 | gc.collect()
111 | raise
112 |
113 | def cleanup(self):
114 | """Explicitly cleanup resources."""
115 | try:
116 | # Clear any cached model outputs
117 | if hasattr(self, 'model'):
118 | del self.model
119 | if hasattr(self, 'tokenizer'):
120 | del self.tokenizer
121 |
122 | # Clear MLX cache and force garbage collection
123 | mx.clear_cache()
124 | gc.collect()
125 |         except Exception:
126 |             # Ignore cleanup errors during teardown; there is nothing useful to do here
127 |             pass
128 |
129 | def __del__(self):
130 | """Destructor to ensure cleanup on object deletion."""
131 | self.cleanup()
132 |
133 | if __name__ == "__main__":
134 | model_path = "mlx-community/all-MiniLM-L6-v2-4bit"
135 | model = MLX_Embeddings(model_path)
136 | try:
137 | texts = ["I like reading", "I like writing"]
138 | embeddings = model(texts)
139 | print(f"Generated embeddings shape: {len(embeddings)} x {len(embeddings[0])}")
140 | finally:
141 | # Explicit cleanup
142 | model.cleanup()
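143 | 
144 |     # Illustrative addition (not in the original file): cosine similarity
145 |     # between the two embeddings above, computed with plain Python.
146 |     import math
147 |     dot = sum(a * b for a, b in zip(embeddings[0], embeddings[1]))
148 |     norms = math.sqrt(sum(a * a for a in embeddings[0])) * math.sqrt(sum(b * b for b in embeddings[1]))
149 |     print(f"Cosine similarity between the two texts: {dot / norms:.4f}")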
--------------------------------------------------------------------------------
/app/models/mlx_lm.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import os
3 | import mlx.core as mx
4 | from mlx_lm.utils import load
5 | from mlx_lm.generate import (
6 | generate,
7 | stream_generate,
8 | )
9 | from outlines.processors import JSONLogitsProcessor
10 | from mlx_lm.models.cache import make_prompt_cache
11 | from mlx_lm.sample_utils import make_sampler, make_logits_processors
12 | from app.utils.outlines_transformer_tokenizer import OutlinesTransformerTokenizer
13 | from typing import List, Dict, Union, Generator
14 |
15 | DEFAULT_TEMPERATURE = float(os.getenv("DEFAULT_TEMPERATURE", 0.7))
16 | DEFAULT_TOP_P = float(os.getenv("DEFAULT_TOP_P", 0.95))
17 | DEFAULT_TOP_K = int(os.getenv("DEFAULT_TOP_K", 20))
18 | DEFAULT_MIN_P = float(os.getenv("DEFAULT_MIN_P", 0.0))
19 | DEFAULT_SEED = int(os.getenv("DEFAULT_SEED", 0))
20 | DEFAULT_MAX_TOKENS = int(os.getenv("DEFAULT_MAX_TOKENS", 8192))
21 | DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", 32))
22 |
23 | class MLX_LM:
24 | """
25 | A wrapper class for MLX Language Model that handles both streaming and non-streaming inference.
26 |
27 | This class provides a unified interface for generating text responses from text prompts,
28 | supporting both streaming and non-streaming modes.
29 | """
30 |
31 | def __init__(self, model_path: str, context_length: int = 32768):
32 | try:
33 | self.model, self.tokenizer = load(model_path)
34 | self.pad_token_id = self.tokenizer.pad_token_id
35 | self.bos_token = self.tokenizer.bos_token
36 | self.model_type = self.model.model_type
37 | self.max_kv_size = context_length
38 | self.outlines_tokenizer = OutlinesTransformerTokenizer(self.tokenizer)
39 | except Exception as e:
40 | raise ValueError(f"Error loading model: {str(e)}")
41 |
42 | def _apply_pooling_strategy(self, embeddings: mx.array) -> mx.array:
43 | embeddings = mx.mean(embeddings, axis=1)
44 | return embeddings
45 |
46 | def _apply_l2_normalization(self, embeddings: mx.array) -> mx.array:
47 | l2_norms = mx.linalg.norm(embeddings, axis=1, keepdims=True)
48 | embeddings = embeddings / (l2_norms + 1e-8)
49 | return embeddings
50 |
51 | def _batch_process(self, prompts: List[str], batch_size: int = DEFAULT_BATCH_SIZE) -> List[List[int]]:
52 | """Process prompts in batches with optimized tokenization."""
53 | all_tokenized = []
54 |
55 | # Process prompts in batches
56 | for i in range(0, len(prompts), batch_size):
57 | batch = prompts[i:i + batch_size]
58 | tokenized_batch = []
59 |
60 | # Tokenize all prompts in batch
61 | for p in batch:
62 | add_special_tokens = self.bos_token is None or not p.startswith(self.bos_token)
63 | tokens = self.tokenizer.encode(p, add_special_tokens=add_special_tokens)
64 | tokenized_batch.append(tokens)
65 |
66 | # Find max length in batch
67 | max_length = max(len(tokens) for tokens in tokenized_batch)
68 |
69 | # Pad tokens in a vectorized way
70 | for tokens in tokenized_batch:
71 | padding = [self.pad_token_id] * (max_length - len(tokens))
72 | all_tokenized.append(tokens + padding)
73 |
74 | return all_tokenized
75 |
76 | def _preprocess_prompt(self, prompt: str) -> List[int]:
77 | """Tokenize a single prompt efficiently."""
78 | add_special_tokens = self.bos_token is None or not prompt.startswith(self.bos_token)
79 | tokens = self.tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
80 | return mx.array(tokens)
81 |
82 | def get_model_type(self) -> str:
83 | return self.model_type
84 |
85 | def get_embeddings(
86 | self,
87 | prompts: List[str],
88 | batch_size: int = DEFAULT_BATCH_SIZE,
89 | normalize: bool = True
90 |     ) -> List[List[float]]:
91 | """
92 | Get embeddings for a list of prompts efficiently.
93 |
94 | Args:
95 | prompts: List of text prompts
96 |             batch_size: Size of batches for processing
97 |             normalize: Whether to L2-normalize the pooled embeddings
98 |         Returns:
99 |             List of embedding vectors, one list of floats per prompt
100 | """
101 | # Process in batches to optimize memory usage
102 | all_embeddings = []
103 | try:
104 | for i in range(0, len(prompts), batch_size):
105 | batch_prompts = prompts[i:i + batch_size]
106 | tokenized_batch = self._batch_process(batch_prompts, batch_size)
107 |
108 | # Convert to MLX array for efficient computation
109 | tokenized_batch = mx.array(tokenized_batch)
110 |
111 | try:
112 | # Compute embeddings for batch
113 | batch_embeddings = self.model.model(tokenized_batch)
114 | pooled_embedding = self._apply_pooling_strategy(batch_embeddings)
115 | if normalize:
116 | pooled_embedding = self._apply_l2_normalization(pooled_embedding)
117 | all_embeddings.extend(pooled_embedding.tolist())
118 | finally:
119 | # Explicitly free MLX arrays to prevent memory leaks
120 | del tokenized_batch
121 | if 'batch_embeddings' in locals():
122 | del batch_embeddings
123 | if 'pooled_embedding' in locals():
124 | del pooled_embedding
125 | # Force MLX garbage collection
126 | mx.clear_cache()
127 | gc.collect()
128 | except Exception as e:
129 | # Clean up on error
130 | mx.clear_cache()
131 | gc.collect()
132 | raise
133 |
134 | return all_embeddings
135 |
136 | def __call__(
137 | self,
138 | messages: List[Dict[str, str]],
139 | stream: bool = False,
140 | **kwargs
141 | ) -> Union[str, Generator[str, None, None]]:
142 | """
143 | Generate text response from the model.
144 |
145 | Args:
146 | messages (List[Dict[str, str]]): List of messages in the conversation.
147 | stream (bool): Whether to stream the response.
148 | **kwargs: Additional parameters for generation
149 |                 - temperature: Sampling temperature (default: 0.7)
150 |                 - top_p: Top-p sampling parameter (default: 0.95)
151 |                 - seed: Random seed (default: 0)
152 |                 - max_tokens: Maximum number of tokens to generate (default: 8192)
153 | """
154 | # Set default parameters if not provided
155 | seed = kwargs.get("seed", DEFAULT_SEED)
156 | max_tokens = kwargs.get("max_tokens", DEFAULT_MAX_TOKENS)
157 | chat_template_kwargs = kwargs.get("chat_template_kwargs", {})
158 |
159 | sampler_kwargs = {
160 | "temp": kwargs.get("temperature", DEFAULT_TEMPERATURE),
161 | "top_p": kwargs.get("top_p", DEFAULT_TOP_P),
162 | "top_k": kwargs.get("top_k", DEFAULT_TOP_K),
163 | "min_p": kwargs.get("min_p", DEFAULT_MIN_P)
164 | }
165 |
166 | repetition_penalty = kwargs.get("repetition_penalty", 1.0)
167 | repetition_context_size = kwargs.get("repetition_context_size", 20)
168 | logits_processors = make_logits_processors(repetition_penalty=repetition_penalty, repetition_context_size=repetition_context_size)
169 | json_schema = kwargs.get("schema", None)
170 | if json_schema:
171 | logits_processors.append(
172 | JSONLogitsProcessor(
173 | schema = json_schema,
174 | tokenizer = self.outlines_tokenizer,
175 | tensor_library_name = "mlx"
176 | )
177 | )
178 |
179 | mx.random.seed(seed)
180 | prompt_cache = make_prompt_cache(self.model, self.max_kv_size)
181 |
182 | input_tokens = self.tokenizer.apply_chat_template(
183 | messages,
184 | add_generation_prompt=True,
185 | **chat_template_kwargs,
186 | )
187 |
188 | sampler = make_sampler(
189 | **sampler_kwargs
190 | )
191 |
192 | if not stream:
193 | return generate(
194 | self.model,
195 | self.tokenizer,
196 | input_tokens,
197 | sampler=sampler,
198 | max_tokens=max_tokens,
199 | prompt_cache=prompt_cache,
200 | logits_processors=logits_processors
201 | )
202 | else:
203 | # Streaming mode: return generator of chunks
204 | return stream_generate(
205 | self.model,
206 | self.tokenizer,
207 | input_tokens,
208 | sampler=sampler,
209 | max_tokens=max_tokens,
210 | prompt_cache=prompt_cache,
211 | logits_processors=logits_processors
212 | )
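213 | 
214 | 
215 | if __name__ == "__main__":
216 |     # Illustrative usage sketch (added for documentation; not part of the
217 |     # original module). The model id below is a placeholder; use any
218 |     # mlx-lm compatible local path or Hugging Face repo.
219 |     model = MLX_LM("mlx-community/Qwen2.5-0.5B-Instruct-4bit")
220 |     messages = [{"role": "user", "content": "Say one sentence about Apple Silicon."}]
221 | 
222 |     # Non-streaming: returns the completed text as a single string
223 |     print(model(messages, stream=False, max_tokens=128))
224 | 
225 |     # Streaming: yields response chunks exposing a `.text` attribute.
226 |     # A JSON schema can also be passed via `schema=...` for structured output.
227 |     for chunk in model(messages, stream=True, max_tokens=128):
228 |         print(chunk.text, end="", flush=True)
229 |     print()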
--------------------------------------------------------------------------------
/app/models/mlx_vlm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import mlx.core as mx
3 | from typing import List, Dict, Union, Generator
4 | from mlx_vlm.models.cache import make_prompt_cache
5 | from mlx_vlm import load, generate, stream_generate
6 | from mlx_vlm.video_generate import process_vision_info
7 |
8 | # Default model parameters
9 | DEFAULT_MAX_TOKENS = int(os.getenv("DEFAULT_MAX_TOKENS", 8192))
10 | DEFAULT_TEMPERATURE = float(os.getenv("DEFAULT_TEMPERATURE", 0.0))
11 | DEFAULT_TOP_P = float(os.getenv("DEFAULT_TOP_P", 1.0))
12 | DEFAULT_SEED = int(os.getenv("DEFAULT_SEED", 0))
13 |
14 | class MLX_VLM:
15 | """
16 | A wrapper class for MLX Multimodal Model that handles both streaming and non-streaming inference.
17 |
18 | This class provides a unified interface for generating text responses from images and text prompts,
19 | supporting both streaming and non-streaming modes.
20 | """
21 |
22 | def __init__(self, model_path: str, context_length: int = None):
23 | """
24 | Initialize the MLX_VLM model.
25 |
26 | Args:
27 | model_path (str): Path to the model directory containing model weights and configuration.
28 |
29 | Raises:
30 | ValueError: If model loading fails.
31 | """
32 | try:
33 | self.model, self.processor = load(model_path, lazy=False, trust_remote_code=True)
34 | self.max_kv_size = context_length
35 | self.config = self.model.config
36 | except Exception as e:
37 | raise ValueError(f"Error loading model: {str(e)}")
38 |
39 | def _is_video_model(self):
40 | return hasattr(self.config, "video_token_id") or hasattr(
41 | self.config, "video_token_index"
42 | )
43 |
44 | def get_model_type(self):
45 | return self.config.model_type
46 |
47 | def __call__(
48 | self,
49 | messages: List[Dict[str, str]],
50 | images: List[str] = None,
51 | audios: List[str] = None,
52 | videos: List[str] = None,
53 | stream: bool = False,
54 | **kwargs
55 | ) -> Union[str, Generator[str, None, None]]:
56 | """
57 | Generate text response from images and messages.
58 |
59 | Args:
60 |             messages (List[Dict[str, str]]): List of message dictionaries with 'role' and 'content' keys.
61 |             images/audios/videos (List[str], optional): Paths to the media files to process.
62 | stream (bool, optional): Whether to stream the response. Defaults to False.
63 | **kwargs: Additional model parameters (chat_template_kwargs, temperature, max_tokens, etc.)
64 |
65 | Returns:
66 | Union[str, Generator[str, None, None]]:
67 | - If stream=False: Complete response as string
68 | - If stream=True: Generator yielding response chunks
69 | """
70 |
71 | if images and videos:
72 | raise ValueError("Cannot process both images and videos in the same request")
73 |
74 | if videos and not self._is_video_model():
75 | raise ValueError("Model is not a video model")
76 |
77 | text = self.processor.apply_chat_template(
78 | messages,
79 | tokenize=False,
80 | add_generation_prompt=True,
81 | **kwargs.get("chat_template_kwargs", {})
82 | )
83 |
84 | image_inputs, video_inputs = process_vision_info(messages)
85 |
86 | inputs = self.processor(
87 | text=[text],
88 | images=image_inputs,
89 | videos=video_inputs,
90 | padding=True,
91 | return_tensors="pt"
92 | )
93 |
94 | model_params = {
95 | "input_ids": mx.array(inputs["input_ids"]),
96 | "mask": mx.array(inputs["attention_mask"]),
97 | **kwargs
98 | }
99 |
100 | if images:
101 | model_params["pixel_values"] = mx.array(inputs["pixel_values"])
102 | model_params["image_grid_thw"] = mx.array(inputs["image_grid_thw"])
103 |
104 | if videos:
105 | model_params["pixel_values"] = mx.array(inputs["pixel_values_videos"])
106 | model_params["video_grid_thw"] = mx.array(inputs["video_grid_thw"])
107 |
108 | prompt_cache = make_prompt_cache(self.model, self.max_kv_size)
109 |
110 | if stream:
111 | return stream_generate(
112 | self.model,
113 | self.processor,
114 | prompt=text,
115 | prompt_cache=prompt_cache,
116 | **model_params
117 | )
118 | else:
119 | return generate(
120 | self.model,
121 | self.processor,
122 | prompt=text,
123 | prompt_cache=prompt_cache,
124 | **model_params
125 | )
126 |
127 |
128 | if __name__ == "__main__":
129 | image_path = "examples/images/attention.png"
130 | video_path = "examples/videos/demo.mp4"
131 | model_path = "mlx-community/GLM-4.5V-4bit"
132 |
133 | model = MLX_VLM(model_path)
134 | print("MODEL TYPE: ", model.get_model_type())
135 |
136 | tools = [{
137 | "type": "function",
138 | "function": {
139 | "name": "get_weather",
140 | "description": "Get the weather for a given city",
141 | "parameters": {
142 | "type": "object",
143 | "properties": {
144 | "city": {"type": "string", "description": "The city to get the weather for"}
145 |                 },
146 |                 "required": ["city"]
147 |             }
148 | }}
149 | ]
150 | kwargs = {
151 | "chat_template_kwargs": {
152 | "tools": tools,
153 | "enable_thinking": True,
154 | },
155 | "temperature": 0.0,
156 | "top_p": 1.0,
157 | "seed": 0,
158 | "max_tokens": 8192,
159 | "frequency_penalty": 0.0,
160 | "presence_penalty": 0.0
161 | }
162 | messages = [
163 | {
164 | "role": "user",
165 | "content": [
166 | {
167 | "type": "text",
168 |                     "text": "Describe the image in detail"
169 | },
170 | {
171 | "type": "image",
172 | "image": image_path
173 | }
174 | ]
175 | }
176 | ]
177 | response = model(messages, stream=False, **kwargs)
178 | print(response)
--------------------------------------------------------------------------------
/app/models/mlx_whisper.py:
--------------------------------------------------------------------------------
1 | import librosa
2 | import numpy as np
3 | from functools import lru_cache
4 | from mlx_whisper.transcribe import transcribe
5 |
6 | SAMPLING_RATE = 16000
7 | CHUNK_SIZE = 30  # seconds of audio per streamed transcription chunk
8 |
9 |
10 | @lru_cache(maxsize=32)
11 | def load_audio(fname):
12 | """Load and cache audio file. Cache size limited to 32 recent files."""
13 | a, _ = librosa.load(fname, sr=SAMPLING_RATE, dtype=np.float32)
14 | return a
15 |
16 | @lru_cache(maxsize=32)
17 | def calculate_audio_duration(audio_path: str) -> float:
18 | """Calculate the duration of the audio file in seconds."""
19 | audio = load_audio(audio_path)
20 | return len(audio) / SAMPLING_RATE
21 |
22 | class MLX_Whisper:
23 | def __init__(self, model_path: str):
24 | self.model_path = model_path
25 |
26 | def _transcribe_generator(self, audio_path: str, **kwargs):
27 | """Stream transcription by processing audio in larger chunks."""
28 | # Load the audio file
29 | audio = load_audio(audio_path)
30 | duration = calculate_audio_duration(audio_path)
31 |
32 | beg = 0.0
33 | while beg < duration:
34 | # Calculate chunk boundaries
35 | chunk_end = min(beg + CHUNK_SIZE, duration)
36 |
37 | # Extract audio chunk
38 | beg_samples = int(beg * SAMPLING_RATE)
39 | end_samples = int(chunk_end * SAMPLING_RATE)
40 | audio_chunk = audio[beg_samples:end_samples]
41 |
42 | # Transcribe chunk
43 | result = transcribe(audio_chunk, path_or_hf_repo=self.model_path, **kwargs)
44 |
45 | # Add timing information
46 | result["chunk_start"] = beg
47 | result["chunk_end"] = chunk_end
48 |
49 | yield result
50 |
51 | beg += CHUNK_SIZE
52 |
53 | def __call__(self, audio_path: str, stream: bool = False, **kwargs):
54 | """
55 | Transcribe audio file.
56 |
57 | Args:
58 | audio_path: Path to audio file
59 | stream: If True, yields chunks. If False, transcribes entire file at once.
60 | **kwargs: Additional arguments passed to transcribe()
61 | """
62 | if stream:
63 | return self._transcribe_generator(audio_path, **kwargs)
64 | else:
65 | return transcribe(audio_path, path_or_hf_repo=self.model_path, **kwargs)
66 |
67 |
68 | if __name__ == "__main__":
69 | model = MLX_Whisper("mlx-community/whisper-tiny")
70 |     # Streaming: transcribe in 30-second chunks and print each chunk as it completes
71 | result = model("examples/audios/podcast.wav", stream=True)
72 | for chunk in result:
73 | print(f"[{chunk['chunk_start']:.1f}s - {chunk['chunk_end']:.1f}s]: {chunk['text']}")
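74 | 
75 |     # Illustrative addition (not in the original file): a non-streaming call
76 |     # transcribes the whole file in one pass and returns a single result dict.
77 |     full_result = model("examples/audios/podcast.wav", stream=False)
78 |     print(full_result["text"])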
--------------------------------------------------------------------------------
/app/schemas/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/app/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/app/utils/dill.py:
--------------------------------------------------------------------------------
1 | # copied from https://github.com/huggingface/datasets/blob/1e1d313/src/datasets/utils/_dill.py
2 |
3 | # Copyright 2023 The HuggingFace Team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """Extends `dill` to support pickling more types and produce more consistent dumps."""
17 |
18 | import sys
19 | from io import BytesIO
20 | from types import FunctionType
21 | from typing import Any, Dict, List, Union
22 |
23 | import dill
24 | import xxhash
25 |
26 |
27 | class Hasher:
28 | """Hasher that accepts python objects as inputs."""
29 |
30 | dispatch: Dict = {}
31 |
32 | def __init__(self):
33 | self.m = xxhash.xxh64()
34 |
35 | @classmethod
36 | def hash_bytes(cls, value: Union[bytes, List[bytes]]) -> str:
37 | value = [value] if isinstance(value, bytes) else value
38 | m = xxhash.xxh64()
39 | for x in value:
40 | m.update(x)
41 | return m.hexdigest()
42 |
43 | @classmethod
44 | def hash(cls, value: Any) -> str:
45 | return cls.hash_bytes(dumps(value))
46 |
47 | def update(self, value: Any) -> None:
48 | header_for_update = f"=={type(value)}=="
49 | value_for_update = self.hash(value)
50 | self.m.update(header_for_update.encode("utf8"))
51 | self.m.update(value_for_update.encode("utf-8"))
52 |
53 | def hexdigest(self) -> str:
54 | return self.m.hexdigest()
55 |
56 |
57 | class Pickler(dill.Pickler):
58 | dispatch = dill._dill.MetaCatchingDict(dill.Pickler.dispatch.copy())
59 | _legacy_no_dict_keys_sorting = False
60 |
61 | def save(self, obj, save_persistent_id=True):
62 | obj_type = type(obj)
63 | if obj_type not in self.dispatch:
64 | if "regex" in sys.modules:
65 | import regex # type: ignore
66 |
67 | if obj_type is regex.Pattern:
68 | pklregister(obj_type)(_save_regexPattern)
69 | if "spacy" in sys.modules:
70 | import spacy # type: ignore
71 |
72 | if issubclass(obj_type, spacy.Language):
73 | pklregister(obj_type)(_save_spacyLanguage)
74 | if "tiktoken" in sys.modules:
75 | import tiktoken # type: ignore
76 |
77 | if obj_type is tiktoken.Encoding:
78 | pklregister(obj_type)(_save_tiktokenEncoding)
79 | if "torch" in sys.modules:
80 | import torch # type: ignore
81 |
82 | if issubclass(obj_type, torch.Tensor):
83 | pklregister(obj_type)(_save_torchTensor)
84 |
85 | if obj_type is torch.Generator:
86 | pklregister(obj_type)(_save_torchGenerator)
87 |
88 | # Unwrap `torch.compile`-ed modules
89 | if issubclass(obj_type, torch.nn.Module):
90 | obj = getattr(obj, "_orig_mod", obj)
91 | if "transformers" in sys.modules:
92 | import transformers # type: ignore
93 |
94 | if issubclass(obj_type, transformers.PreTrainedTokenizerBase):
95 | pklregister(obj_type)(_save_transformersPreTrainedTokenizerBase)
96 |
97 | # Unwrap `torch.compile`-ed functions
98 | if obj_type is FunctionType:
99 | obj = getattr(obj, "_torchdynamo_orig_callable", obj)
100 | dill.Pickler.save(self, obj, save_persistent_id=save_persistent_id)
101 |
102 | def _batch_setitems(self, items):
103 | if self._legacy_no_dict_keys_sorting:
104 | return super()._batch_setitems(items)
105 | # Ignore the order of keys in a dict
106 | try:
107 | # Faster, but fails for unorderable elements
108 | items = sorted(items)
109 | except Exception: # TypeError, decimal.InvalidOperation, etc.
110 | items = sorted(items, key=lambda x: Hasher.hash(x[0]))
111 | dill.Pickler._batch_setitems(self, items)
112 |
113 | def memoize(self, obj):
114 | # Don't memoize strings since two identical strings can have different Python ids
115 | if type(obj) is not str: # noqa: E721
116 | dill.Pickler.memoize(self, obj)
117 |
118 |
119 | def pklregister(t):
120 | """Register a custom reducer for the type."""
121 |
122 | def proxy(func):
123 | Pickler.dispatch[t] = func
124 | return func
125 |
126 | return proxy
127 |
128 |
129 | def dump(obj, file):
130 | """Pickle an object to a file."""
131 | Pickler(file, recurse=True).dump(obj)
132 |
133 |
134 | def dumps(obj):
135 | """Pickle an object to a string."""
136 | file = BytesIO()
137 | dump(obj, file)
138 | return file.getvalue()
139 |
140 |
141 | def log(pickler, msg):
142 | pass
143 |
144 |
145 | def _save_regexPattern(pickler, obj):
146 | import regex # type: ignore
147 |
148 | log(pickler, f"Re: {obj}")
149 | args = (obj.pattern, obj.flags)
150 | pickler.save_reduce(regex.compile, args, obj=obj)
151 | log(pickler, "# Re")
152 |
153 |
154 | def _save_tiktokenEncoding(pickler, obj):
155 | import tiktoken # type: ignore
156 |
157 | log(pickler, f"Enc: {obj}")
158 | args = (obj.name, obj._pat_str, obj._mergeable_ranks, obj._special_tokens)
159 | pickler.save_reduce(tiktoken.Encoding, args, obj=obj)
160 | log(pickler, "# Enc")
161 |
162 |
163 | def _save_torchTensor(pickler, obj):
164 | import torch # type: ignore
165 |
166 | # `torch.from_numpy` is not picklable in `torch>=1.11.0`
167 | def create_torchTensor(np_array, dtype=None):
168 | tensor = torch.from_numpy(np_array)
169 | if dtype:
170 | tensor = tensor.type(dtype)
171 | return tensor
172 |
173 | log(pickler, f"To: {obj}")
174 | if obj.dtype == torch.bfloat16:
175 | args = (obj.detach().to(torch.float).cpu().numpy(), torch.bfloat16)
176 | else:
177 | args = (obj.detach().cpu().numpy(),)
178 | pickler.save_reduce(create_torchTensor, args, obj=obj)
179 | log(pickler, "# To")
180 |
181 |
182 | def _save_torchGenerator(pickler, obj):
183 | import torch # type: ignore
184 |
185 | def create_torchGenerator(state):
186 | generator = torch.Generator()
187 | generator.set_state(state)
188 | return generator
189 |
190 | log(pickler, f"Ge: {obj}")
191 | args = (obj.get_state(),)
192 | pickler.save_reduce(create_torchGenerator, args, obj=obj)
193 | log(pickler, "# Ge")
194 |
195 |
196 | def _save_spacyLanguage(pickler, obj):
197 | import spacy # type: ignore
198 |
199 | def create_spacyLanguage(config, bytes):
200 | lang_cls = spacy.util.get_lang_class(config["nlp"]["lang"])
201 | lang_inst = lang_cls.from_config(config)
202 | return lang_inst.from_bytes(bytes)
203 |
204 | log(pickler, f"Sp: {obj}")
205 | args = (obj.config, obj.to_bytes())
206 | pickler.save_reduce(create_spacyLanguage, args, obj=obj)
207 | log(pickler, "# Sp")
208 |
209 |
210 | def _save_transformersPreTrainedTokenizerBase(pickler, obj):
211 | log(pickler, f"Tok: {obj}")
212 | # Ignore the `cache` attribute
213 | state = obj.__dict__
214 | if "cache" in state and isinstance(state["cache"], dict):
215 | state["cache"] = {}
216 | pickler.save_reduce(type(obj), (), state=state, obj=obj)
217 | log(pickler, "# Tok")
--------------------------------------------------------------------------------
/app/utils/errors.py:
--------------------------------------------------------------------------------
1 | from http import HTTPStatus
2 | from typing import Union
3 |
4 | def create_error_response(
5 | message: str,
6 | err_type: str = "internal_error",
7 | status_code: Union[int, HTTPStatus] = HTTPStatus.INTERNAL_SERVER_ERROR,
8 |     param: Union[str, None] = None,
9 |     code: Union[str, None] = None
10 | ):
11 | return {
12 | "error": {
13 | "message": message,
14 | "type": err_type,
15 | "param": param,
16 | "code": str(code or (status_code.value if isinstance(status_code, HTTPStatus) else status_code))
17 | }
18 | }
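19 | 
20 | # Illustrative example (not part of the original module):
21 | #   create_error_response("Model not found",
22 | #                         err_type="invalid_request_error",
23 | #                         status_code=HTTPStatus.NOT_FOUND,
24 | #                         param="model")
25 | # returns:
26 | #   {"error": {"message": "Model not found", "type": "invalid_request_error",
27 | #              "param": "model", "code": "404"}}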
--------------------------------------------------------------------------------
/app/utils/outlines_transformer_tokenizer.py:
--------------------------------------------------------------------------------
1 | from app.utils.dill import Hasher
2 | from outlines.models.transformers import TransformerTokenizer
3 |
4 |
5 | class OutlinesTransformerTokenizer(TransformerTokenizer):
6 | """
7 | Update the outlines TransformerTokenizer to use our own Hasher class, so that we don't need the datasets dependency
8 |
9 | This class and the external dependency can be removed when the following import is deleted
10 | https://github.com/dottxt-ai/outlines/blob/69418d/outlines/models/transformers.py#L117
11 | """
12 |
13 | def __hash__(self):
14 | return hash(Hasher.hash(self.tokenizer))
--------------------------------------------------------------------------------
/app/version.py:
--------------------------------------------------------------------------------
1 | # Version number format: MAJOR.MINOR.PATCH
2 | # Major: Major version number (increments when breaking changes are introduced)
3 | # Minor: Minor version number (increments when new features are added)
4 | # Patch: Patch version number (increments when bug fixes are made)
5 |
6 | __version__ = "1.3.12"
--------------------------------------------------------------------------------
/configure_mlx.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Get the total memory in MB
4 | TOTAL_MEM_MB=$(($(sysctl -n hw.memsize) / 1024 / 1024))
5 |
6 | # Calculate 80% of total memory and (total - 5 GB), both in MB
7 | EIGHTY_PERCENT=$(($TOTAL_MEM_MB * 80 / 100))
8 | MINUS_5GB=$((($TOTAL_MEM_MB - 5120)))
9 |
10 | # Calculate 70% of total memory and (total - 8 GB), both in MB
11 | SEVENTY_PERCENT=$(($TOTAL_MEM_MB * 70 / 100))
12 | MINUS_8GB=$((($TOTAL_MEM_MB - 8192)))
13 |
14 | # Set WIRED_LIMIT_MB to the higher of the two values
15 | if [ $EIGHTY_PERCENT -gt $MINUS_5GB ]; then
16 | WIRED_LIMIT_MB=$EIGHTY_PERCENT
17 | else
18 | WIRED_LIMIT_MB=$MINUS_5GB
19 | fi
20 |
21 | # Set WIRED_LWM_MB to the higher of the two values
22 | if [ $SEVENTY_PERCENT -gt $MINUS_8GB ]; then
23 | WIRED_LWM_MB=$SEVENTY_PERCENT
24 | else
25 | WIRED_LWM_MB=$MINUS_8GB
26 | fi
27 |
28 | # Display the calculated values
29 | echo "Total memory: $TOTAL_MEM_MB MB"
30 | echo "Maximum limit (iogpu.wired_limit_mb): $WIRED_LIMIT_MB MB"
31 | echo "Lower bound (iogpu.wired_lwm_mb): $WIRED_LWM_MB MB"
32 |
33 | # Apply the values with sysctl, but check if we're already root
34 | if [ "$EUID" -eq 0 ]; then
35 | sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB
36 | sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB
37 | else
38 | # Try without sudo first, fall back to sudo if needed
39 | sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB 2>/dev/null || \
40 | sudo sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB
41 | sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB 2>/dev/null || \
42 | sudo sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB
43 | fi
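44 | 
45 | # Worked example (illustrative): on a 64 GB machine, TOTAL_MEM_MB = 65536, so
46 | #   80% = 52428 MB, total - 5 GB = 60416 MB  ->  iogpu.wired_limit_mb = 60416
47 | #   70% = 45875 MB, total - 8 GB = 57344 MB  ->  iogpu.wired_lwm_mb  = 57344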
--------------------------------------------------------------------------------
/examples/audio_examples.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Audio Processing with MLX Server\n",
8 | "\n",
9 | "This notebook demonstrates how to process audio files using the MLX Server with OpenAI-compatible API.\n"
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "## What You'll Learn\n",
17 | "\n",
18 | "- Connect to MLX Server\n",
19 | "- Load and encode audio files for processing\n",
20 | "- Send audio to the model for analysis\n",
21 | "- Get text descriptions of audio content\n",
22 | "\n",
23 | "## Prerequisites\n",
24 | "\n",
25 | "- MLX Server running on localhost:8000\n",
26 | "- Audio file in the `audios/` directory\n",
27 | "- OpenAI Python library installed\n"
28 | ]
29 | },
30 | {
31 | "cell_type": "markdown",
32 | "metadata": {},
33 | "source": [
34 | "## Step 1: Setup and Connection\n"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": 1,
40 | "metadata": {},
41 | "outputs": [
42 | {
43 | "name": "stdout",
44 | "output_type": "stream",
45 | "text": [
46 | "✅ Connected to MLX Server\n"
47 | ]
48 | }
49 | ],
50 | "source": [
51 | "# Import required libraries\n",
52 | "from openai import OpenAI\n",
53 | "import base64\n",
54 | "import os\n",
55 | "\n",
56 | "# Initialize OpenAI client to connect to MLX Server\n",
57 | "# The MLX Server runs locally and provides OpenAI-compatible endpoints\n",
58 | "client = OpenAI(\n",
59 | " base_url=\"http://localhost:8000/v1\", # MLX Server address\n",
60 | " api_key=\"fake-api-key\", # Any string works for local server\n",
61 | ")\n",
62 | "\n",
63 | "print(\"✅ Connected to MLX Server\")\n"
64 | ]
65 | },
66 | {
67 | "cell_type": "markdown",
68 | "metadata": {},
69 | "source": [
70 | "## Step 2: Audio File Processing\n"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": 2,
76 | "metadata": {},
77 | "outputs": [
78 | {
79 | "name": "stdout",
80 | "output_type": "stream",
81 | "text": [
82 | "✅ Loaded audio file: audios/audio.wav\n",
83 | " File size: 372698 bytes\n",
84 | " Encoded size: 496932 characters\n"
85 | ]
86 | }
87 | ],
88 | "source": [
89 | "def load_audio_file(audio_path: str) -> str:\n",
90 | " \"\"\"\n",
91 | " Load an audio file and encode it as base64 for API transmission.\n",
92 | " \n",
93 | " Args:\n",
94 | " audio_path (str): Path to the audio file\n",
95 | " \n",
96 | " Returns:\n",
97 | " str: Base64 encoded audio data\n",
98 | " \"\"\"\n",
99 | " if not os.path.exists(audio_path):\n",
100 | " raise FileNotFoundError(f\"Audio file not found: {audio_path}\")\n",
101 | " \n",
102 | " with open(audio_path, \"rb\") as audio_file:\n",
103 | " audio_data = audio_file.read()\n",
104 | " encoded_audio = base64.b64encode(audio_data).decode('utf-8')\n",
105 | " \n",
106 | " print(f\"✅ Loaded audio file: {audio_path}\")\n",
107 | " print(f\" File size: {len(audio_data)} bytes\")\n",
108 | " print(f\" Encoded size: {len(encoded_audio)} characters\")\n",
109 | " \n",
110 | " return encoded_audio\n",
111 | "\n",
112 | "# Load the sample audio file\n",
113 | "audio_path = \"audios/audio.wav\"\n",
114 | "audio_base64 = load_audio_file(audio_path)\n"
115 | ]
116 | },
117 | {
118 | "cell_type": "markdown",
119 | "metadata": {},
120 | "source": [
121 | "## Step 3: Audio Analysis\n"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": 3,
127 | "metadata": {},
128 | "outputs": [
129 | {
130 | "name": "stdout",
131 | "output_type": "stream",
132 | "text": [
133 | "🎵 Audio Analysis Result:\n",
134 | " Dogs are sitting by the door.\n"
135 | ]
136 | }
137 | ],
138 | "source": [
139 | "def analyze_audio(audio_base64: str, prompt: str = \"Describe what you hear in this audio.\") -> str:\n",
140 | " \"\"\"\n",
141 | " Send audio to MLX Server for analysis.\n",
142 | " \n",
143 | " Args:\n",
144 | " audio_base64 (str): Base64 encoded audio data\n",
145 | " prompt (str): Text prompt for the model\n",
146 | " \n",
147 | " Returns:\n",
148 | " str: Model's response\n",
149 | " \"\"\"\n",
150 | " try:\n",
151 | " response = client.chat.completions.create(\n",
152 | " model=\"local-multimodal\",\n",
153 | " messages=[\n",
154 | " {\n",
155 | " \"role\": \"user\", \n",
156 | " \"content\": [\n",
157 | " {\n",
158 | " \"type\": \"input_audio\",\n",
159 | " \"input_audio\": {\n",
160 | " \"data\": audio_base64,\n",
161 | " \"format\": \"wav\"\n",
162 | " }\n",
163 | " },\n",
164 | " {\n",
165 | " \"type\": \"text\",\n",
166 | " \"text\": prompt\n",
167 | " }\n",
168 | " ]\n",
169 | " }\n",
170 | " ],\n",
171 | " max_tokens=1024\n",
172 | " )\n",
173 | " \n",
174 | " return response.choices[0].message.content\n",
175 | " \n",
176 | " except Exception as e:\n",
177 | " return f\"Error analyzing audio: {str(e)}\"\n",
178 | "\n",
179 | "# Analyze the audio with a descriptive prompt\n",
180 | "result = analyze_audio(audio_base64, \"Describe the audio in detail.\")\n",
181 | "print(\"🎵 Audio Analysis Result:\")\n",
182 | "print(f\" {result}\")\n"
183 | ]
184 | },
185 | {
186 | "cell_type": "markdown",
187 | "metadata": {},
188 | "source": [
189 | "## Conclusion\n",
190 | "\n",
191 | "This notebook demonstrated the audio processing capabilities of the MLX Server using OpenAI-compatible API endpoints. Key highlights include:\n",
192 | "\n",
193 | "- **Audio Input Support**: Successfully processed audio files by encoding them as base64 and sending them through the `input_audio` message type\n",
194 | "- **Multimodal Integration**: Combined audio input with text prompts to create rich, context-aware responses\n",
195 | "- **OpenAI Compatibility**: Leveraged familiar OpenAI API patterns for seamless integration with existing workflows\n",
196 | "- **Error Handling**: Implemented proper error handling for robust audio processing\n",
197 | "\n",
198 | "The MLX Server's audio processing capabilities enable powerful applications such as:\n",
199 | "- Audio transcription and analysis\n",
200 | "- Voice-controlled interfaces\n",
201 | "- Audio content summarization\n",
202 | "- Accessibility features for audio-based content\n",
203 | "\n",
204 | "This foundation opens up numerous possibilities for building audio-enabled AI applications with the performance benefits of MLX on Apple Silicon.\n"
205 | ]
206 | }
207 | ],
208 | "metadata": {
209 | "kernelspec": {
210 | "display_name": "testing",
211 | "language": "python",
212 | "name": "python3"
213 | },
214 | "language_info": {
215 | "codemirror_mode": {
216 | "name": "ipython",
217 | "version": 3
218 | },
219 | "file_extension": ".py",
220 | "mimetype": "text/x-python",
221 | "name": "python",
222 | "nbconvert_exporter": "python",
223 | "pygments_lexer": "ipython3",
224 | "version": "3.11.11"
225 | }
226 | },
227 | "nbformat": 4,
228 | "nbformat_minor": 2
229 | }
230 |
--------------------------------------------------------------------------------
/examples/audios/audio.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cubist38/mlx-openai-server/25fc00aa21fc6794358849ebfb9d866a6e203eda/examples/audios/audio.wav
--------------------------------------------------------------------------------
/examples/audios/podcast.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cubist38/mlx-openai-server/25fc00aa21fc6794358849ebfb9d866a6e203eda/examples/audios/podcast.wav
--------------------------------------------------------------------------------
/examples/images/attention.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cubist38/mlx-openai-server/25fc00aa21fc6794358849ebfb9d866a6e203eda/examples/images/attention.png
--------------------------------------------------------------------------------
/examples/images/china.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cubist38/mlx-openai-server/25fc00aa21fc6794358849ebfb9d866a6e203eda/examples/images/china.png
--------------------------------------------------------------------------------
/examples/images/green_dog.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cubist38/mlx-openai-server/25fc00aa21fc6794358849ebfb9d866a6e203eda/examples/images/green_dog.jpeg
--------------------------------------------------------------------------------
/examples/images/password.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cubist38/mlx-openai-server/25fc00aa21fc6794358849ebfb9d866a6e203eda/examples/images/password.jpg
--------------------------------------------------------------------------------
/examples/lm_embeddings_examples.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Embeddings API Examples with MLX Server"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "This notebook demonstrates how to use the embeddings endpoint of MLX Server through the OpenAI-compatible API. You'll learn how to generate embeddings, work with batches, compare similarity between texts, and use embeddings for practical applications."
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {},
20 | "source": [
21 | "## Setup and Connection"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 1,
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "# Import the OpenAI client for API communication\n",
31 | "from openai import OpenAI\n",
32 | "\n",
33 | "# Connect to the local MLX Server with OpenAI-compatible API\n",
34 | "client = OpenAI(\n",
35 | " base_url=\"http://localhost:8000/v1\",\n",
36 | " api_key=\"fake-api-key\",\n",
37 | ")"
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "metadata": {},
43 | "source": [
44 | "## Basic Embedding Generation\n",
45 | "\n",
46 | "### Single Text Embedding\n"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": 2,
52 | "metadata": {},
53 | "outputs": [],
54 | "source": [
55 | "# Generate embedding for a single text input\n",
56 | "single_text = \"Artificial intelligence is transforming how we interact with technology.\"\n",
57 | "response = client.embeddings.create(\n",
58 | " input=[single_text],\n",
59 | " model=\"mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-MLX-Q8\"\n",
60 | ")"
61 | ]
62 | },
63 | {
64 | "cell_type": "markdown",
65 | "metadata": {},
66 | "source": [
67 | "### Batch Processing Multiple Texts"
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "execution_count": 3,
73 | "metadata": {},
74 | "outputs": [],
75 | "source": [
76 | "text_batch = [\n",
77 | " \"Machine learning algorithms improve with more data\",\n",
78 | " \"Natural language processing helps computers understand human language\",\n",
79 | " \"Computer vision allows machines to interpret visual information\"\n",
80 | "]\n",
81 | "\n",
82 | "batch_response = client.embeddings.create(\n",
83 | " input=text_batch,\n",
84 | " model=\"mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-MLX-Q8\"\n",
85 | ")"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": 4,
91 | "metadata": {},
92 | "outputs": [
93 | {
94 | "name": "stdout",
95 | "output_type": "stream",
96 | "text": [
97 | "Number of embeddings generated: 3\n",
98 | "Dimensions of each embedding: 1536\n"
99 | ]
100 | }
101 | ],
102 | "source": [
103 | "# Access all embeddings\n",
104 | "embeddings = [item.embedding for item in batch_response.data]\n",
105 | "print(f\"Number of embeddings generated: {len(embeddings)}\")\n",
106 | "print(f\"Dimensions of each embedding: {len(embeddings[0])}\")"
107 | ]
108 | },
109 | {
110 | "cell_type": "markdown",
111 | "metadata": {},
112 | "source": [
113 | "## Semantic Similarity Calculation\n",
114 | "\n",
115 | "One of the most common uses for embeddings is measuring semantic similarity between texts."
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": 5,
121 | "metadata": {},
122 | "outputs": [],
123 | "source": [
124 | "import numpy as np"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": 6,
130 | "metadata": {},
131 | "outputs": [],
132 | "source": [
133 | "def cosine_similarity_score(vec1, vec2):\n",
134 | " \"\"\"Calculate cosine similarity between two vectors\"\"\"\n",
135 | " dot_product = np.dot(vec1, vec2)\n",
136 | " norm1 = np.linalg.norm(vec1)\n",
137 | " norm2 = np.linalg.norm(vec2)\n",
138 | " return dot_product / (norm1 * norm2)"
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": 7,
144 | "metadata": {},
145 | "outputs": [],
146 | "source": [
147 | "# Example texts to compare\n",
148 | "text1 = \"Dogs are loyal pets that provide companionship\"\n",
149 | "text2 = \"Canines make friendly companions for humans\"\n",
150 | "text3 = \"Quantum physics explores the behavior of matter at atomic scales\""
151 | ]
152 | },
153 | {
154 | "cell_type": "code",
155 | "execution_count": 8,
156 | "metadata": {},
157 | "outputs": [],
158 | "source": [
159 | "# Generate embeddings\n",
160 | "comparison_texts = [text1, text2, text3]\n",
161 | "comparison_response = client.embeddings.create(\n",
162 | " input=comparison_texts,\n",
163 | " model=\"mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-MLX-Q8\"\n",
164 | ")\n",
165 | "comparison_embeddings = [item.embedding for item in comparison_response.data]"
166 | ]
167 | },
168 | {
169 | "cell_type": "code",
170 | "execution_count": 9,
171 | "metadata": {},
172 | "outputs": [
173 | {
174 | "name": "stdout",
175 | "output_type": "stream",
176 | "text": [
177 | "Similarity between text1 and text2: 0.8142\n",
178 | "Similarity between text1 and text3: 0.6082\n",
179 | "Similarity between text2 and text3: 0.5739\n"
180 | ]
181 | }
182 | ],
183 | "source": [
184 | "# Compare similarities\n",
185 | "similarity_1_2 = cosine_similarity_score(comparison_embeddings[0], comparison_embeddings[1])\n",
186 | "similarity_1_3 = cosine_similarity_score(comparison_embeddings[0], comparison_embeddings[2])\n",
187 | "similarity_2_3 = cosine_similarity_score(comparison_embeddings[1], comparison_embeddings[2])\n",
188 | "\n",
189 | "print(f\"Similarity between text1 and text2: {similarity_1_2:.4f}\")\n",
190 | "print(f\"Similarity between text1 and text3: {similarity_1_3:.4f}\")\n",
191 | "print(f\"Similarity between text2 and text3: {similarity_2_3:.4f}\")"
192 | ]
193 | },
194 | {
195 | "cell_type": "markdown",
196 | "metadata": {},
197 | "source": [
198 | "## Text Search Using Embeddings"
199 | ]
200 | },
201 | {
202 | "cell_type": "code",
203 | "execution_count": 10,
204 | "metadata": {},
205 | "outputs": [],
206 | "source": [
207 | "# Sample document collection\n",
208 | "documents = [\n",
209 | " \"The quick brown fox jumps over the lazy dog\",\n",
210 | " \"Machine learning models require training data\",\n",
211 | " \"Neural networks are inspired by biological neurons\",\n",
212 | " \"Deep learning is a subset of machine learning\",\n",
213 | " \"Natural language processing helps with text analysis\",\n",
214 | " \"Computer vision systems can detect objects in images\"\n",
215 | "]"
216 | ]
217 | },
218 | {
219 | "cell_type": "code",
220 | "execution_count": 11,
221 | "metadata": {},
222 | "outputs": [],
223 | "source": [
224 | "# Generate embeddings for all documents\n",
225 | "doc_response = client.embeddings.create(\n",
226 | " input=documents,\n",
227 | " model=\"mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-MLX-Q8\"\n",
228 | ")\n",
229 | "doc_embeddings = [item.embedding for item in doc_response.data]"
230 | ]
231 | },
232 | {
233 | "cell_type": "code",
234 | "execution_count": 12,
235 | "metadata": {},
236 | "outputs": [
237 | {
238 | "name": "stdout",
239 | "output_type": "stream",
240 | "text": [
241 | "Search results:\n",
242 | "Score: 0.8574 - Computer vision systems can detect objects in images\n",
243 | "Score: 0.8356 - Neural networks are inspired by biological neurons\n",
244 | "Score: 0.8266 - Natural language processing helps with text analysis\n",
245 | "Score: 0.8141 - Deep learning is a subset of machine learning\n",
246 | "Score: 0.7474 - Machine learning models require training data\n",
247 | "Score: 0.5936 - The quick brown fox jumps over the lazy dog\n"
248 | ]
249 | }
250 | ],
251 | "source": [
252 | "def search_documents(query, doc_collection, doc_embeddings):\n",
253 | " \"\"\"Search for documents similar to query\"\"\"\n",
254 | " # Generate embedding for query\n",
255 | " query_response = client.embeddings.create(\n",
256 | " input=[query],\n",
257 | " model=\"mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-MLX-Q8\"\n",
258 | " )\n",
259 | " query_embedding = query_response.data[0].embedding\n",
260 | " \n",
261 | " # Calculate similarity scores\n",
262 | " similarities = []\n",
263 | " for doc_embedding in doc_embeddings:\n",
264 | " similarity = cosine_similarity_score(query_embedding, doc_embedding)\n",
265 | " similarities.append(similarity)\n",
266 | " \n",
267 | " # Return results with scores\n",
268 | " results = []\n",
269 | " for i, score in enumerate(similarities):\n",
270 | " results.append((doc_collection[i], score))\n",
271 | " \n",
272 | " # Sort by similarity score (highest first)\n",
273 | " return sorted(results, key=lambda x: x[1], reverse=True)\n",
274 | "\n",
275 | "# Example search\n",
276 | "search_results = search_documents(\"How do AI models learn?\", documents, doc_embeddings)\n",
277 | "\n",
278 | "print(\"Search results:\")\n",
279 | "for doc, score in search_results:\n",
280 | " print(f\"Score: {score:.4f} - {doc}\")"
281 | ]
282 | }
283 | ],
284 | "metadata": {
285 | "kernelspec": {
286 | "display_name": "Python 3",
287 | "language": "python",
288 | "name": "python3"
289 | },
290 | "language_info": {
291 | "codemirror_mode": {
292 | "name": "ipython",
293 | "version": 3
294 | },
295 | "file_extension": ".py",
296 | "mimetype": "text/x-python",
297 | "name": "python",
298 | "nbconvert_exporter": "python",
299 | "pygments_lexer": "ipython3",
300 | "version": "3.11.12"
301 | }
302 | },
303 | "nbformat": 4,
304 | "nbformat_minor": 2
305 | }
306 |
--------------------------------------------------------------------------------
/examples/pdfs/lab03.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cubist38/mlx-openai-server/25fc00aa21fc6794358849ebfb9d866a6e203eda/examples/pdfs/lab03.pdf
--------------------------------------------------------------------------------
/examples/structured_outputs_examples.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# MLX Server Structured Output Examples\n",
8 | "\n",
9 | "This is a detailed text version of the structured output examples for MLX Server with OpenAI-compatible API."
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "## Setup"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 8,
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "from openai import OpenAI"
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {},
31 | "source": [
32 | "## Initialize the client\n",
33 | "\n",
34 | "Connect to your local MLX server:"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": 18,
40 | "metadata": {},
41 | "outputs": [],
42 | "source": [
43 | "client = OpenAI(\n",
44 | " base_url = \"http://localhost:8000/v1\",\n",
45 | " api_key = \"mlx-server-api-key\"\n",
46 | ")"
47 | ]
48 | },
49 | {
50 | "cell_type": "markdown",
51 | "metadata": {},
52 | "source": [
53 | "## Function calling example"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": 19,
59 | "metadata": {},
60 | "outputs": [],
61 | "source": [
62 | "# Define the user message\n",
63 | "messages = [\n",
64 | " {\n",
65 | " \"role\": \"user\",\n",
66 | " \"content\": \"What is the weather in Tokyo?\"\n",
67 | " }\n",
68 | "]\n",
69 | "\n",
70 | "# Define the available tools/functions\n",
71 | "tools = [\n",
72 | " {\n",
73 | " \"type\": \"function\",\n",
74 | " \"function\": {\n",
75 | " \"name\": \"get_weather\",\n",
76 | " \"description\": \"Get the weather in a given city\",\n",
77 | " \"parameters\": {\n",
78 | " \"type\": \"object\",\n",
79 | " \"properties\": {\n",
80 | " \"city\": {\"type\": \"string\", \"description\": \"The city to get the weather for\"}\n",
81 | " }\n",
82 | " }\n",
83 | " }\n",
84 | " }\n",
85 | "]"
86 | ]
87 | },
88 | {
89 | "cell_type": "markdown",
90 | "metadata": {},
91 | "source": [
92 | "### Non Streaming Function Calling Example"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": 20,
98 | "metadata": {},
99 | "outputs": [
100 | {
101 | "name": "stdout",
102 | "output_type": "stream",
103 | "text": [
104 | "ChatCompletion(id='chatcmpl_1754135306120611', choices=[Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content='', refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_1754135306725351', function=Function(arguments='{\"city\": \"Tokyo\"}', name='get_weather'), type='function', index=0)], reasoning_content=None))], created=1754135306, model='mlx-server-model', object='chat.completion', service_tier=None, system_fingerprint=None, usage=None)\n"
105 | ]
106 | }
107 | ],
108 | "source": [
109 | "# Make the API call\n",
110 | "completion = client.chat.completions.create(\n",
111 | " model=\"mlx-server-model\",\n",
112 | " messages=messages,\n",
113 | " tools=tools,\n",
114 | " tool_choice=\"auto\",\n",
115 | " max_tokens = 512\n",
116 | ")\n",
117 | "\n",
118 | "# Get the result\n",
119 | "print(completion)"
120 | ]
121 | },
122 | {
123 | "cell_type": "markdown",
124 | "metadata": {},
125 | "source": [
126 | "### Streaming Function Calling Example"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": 21,
132 | "metadata": {},
133 | "outputs": [
134 | {
135 | "name": "stdout",
136 | "output_type": "stream",
137 | "text": [
138 | "ChatCompletionChunk(id='chatcmpl_1754135306422307', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=None, reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1754135306, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n",
139 | "ChatCompletionChunk(id='chatcmpl_1754135306422307', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id='call_1754135307829795', function=ChoiceDeltaToolCallFunction(arguments='', name='get_weather'), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1754135306, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n",
140 | "ChatCompletionChunk(id='chatcmpl_1754135306422307', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments=' {\"', name=None), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1754135306, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n",
141 | "ChatCompletionChunk(id='chatcmpl_1754135306422307', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='city', name=None), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1754135306, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n",
142 | "ChatCompletionChunk(id='chatcmpl_1754135306422307', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='\":', name=None), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1754135306, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n",
143 | "ChatCompletionChunk(id='chatcmpl_1754135306422307', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments=' \"', name=None), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1754135306, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n",
144 | "ChatCompletionChunk(id='chatcmpl_1754135306422307', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='Tok', name=None), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1754135306, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n",
145 | "ChatCompletionChunk(id='chatcmpl_1754135306422307', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='yo', name=None), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1754135306, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n",
146 | "ChatCompletionChunk(id='chatcmpl_1754135306422307', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='\"}', name=None), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1754135306, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n",
147 | "ChatCompletionChunk(id='chatcmpl_1754135306422307', choices=[Choice(delta=ChoiceDelta(content='', function_call=None, refusal=None, role='assistant', tool_calls=None, reasoning_content=None), finish_reason='tool_calls', index=0, logprobs=None)], created=1754135308, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n"
148 | ]
149 | }
150 | ],
151 | "source": [
152 | "# Set stream=True in the API call\n",
153 | "completion = client.chat.completions.create(\n",
154 | " model=\"mlx-server-model\",\n",
155 | " messages=messages,\n",
156 | " tools=tools,\n",
157 | " tool_choice=\"auto\",\n",
158 | " stream=True\n",
159 | ")\n",
160 | "\n",
161 | "# Process the streaming response\n",
162 | "for chunk in completion:\n",
163 | " print(chunk)"
164 | ]
165 | },
166 | {
167 | "cell_type": "markdown",
168 | "metadata": {},
169 | "source": [
170 | "# JSON Schema Example"
171 | ]
172 | },
173 | {
174 | "cell_type": "code",
175 | "execution_count": 22,
176 | "metadata": {},
177 | "outputs": [],
178 | "source": [
179 | "messages = [\n",
180 | " {\n",
181 | " \"role\": \"system\",\n",
182 | " \"content\": \"Extract the address from the user input into the specified JSON format.\"\n",
183 | " },\n",
184 | " {\n",
185 | " \"role\": \"user\",\n",
186 | " \"content\": \"Please format this address: 1 Hacker Wy Menlo Park CA 94025\"\n",
187 | " }\n",
188 | "]\n",
189 | "\n",
190 | "response_format = {\n",
191 | " \"type\": \"json_schema\",\n",
192 | " \"json_schema\": {\n",
193 | " \"name\": \"Address\",\n",
194 | " \"schema\": {\n",
195 | " \"properties\": {\n",
196 | " \"address\": {\n",
197 | " \"type\": \"object\",\n",
198 | " \"properties\": {\n",
199 | " \"street\": {\"type\": \"string\"},\n",
200 | " \"city\": {\"type\": \"string\"},\n",
201 | " \"state\": {\n",
202 | " \"type\": \"string\", \n",
203 | " \"description\": \"2 letter abbreviation of the state\"\n",
204 | " },\n",
205 | " \"zip\": {\n",
206 | " \"type\": \"string\", \n",
207 | " \"description\": \"5 digit zip code\"\n",
208 | " }\n",
209 | " },\n",
210 | " \"required\": [\"street\", \"city\", \"state\", \"zip\"]\n",
211 | " }\n",
212 | " },\n",
213 | " \"required\": [\"address\"],\n",
214 | " \"type\": \"object\"\n",
215 | " }\n",
216 | " }\n",
217 | "}\n"
218 | ]
219 | },
220 | {
221 | "cell_type": "markdown",
222 | "metadata": {},
223 | "source": [
224 | "### Non-streaming Structured Output Example"
225 | ]
226 | },
227 | {
228 | "cell_type": "code",
229 | "execution_count": 23,
230 | "metadata": {},
231 | "outputs": [
232 | {
233 | "name": "stdout",
234 | "output_type": "stream",
235 | "text": [
236 | "ChatCompletion(id='chatcmpl_1754135313793796', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='{\"address\": {\"street\": \"1 Hacker Wy\", \"city\": \"Menlo Park\", \"state\": \"CA\", \"zip\": \"94025\"}}', refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=None, reasoning_content=None))], created=1754135313, model='mlx-server-model', object='chat.completion', service_tier=None, system_fingerprint=None, usage=None)\n"
237 | ]
238 | }
239 | ],
240 | "source": [
241 | "# Make the API call\n",
242 | "completion = client.chat.completions.create(\n",
243 | " model=\"mlx-server-model\",\n",
244 | " messages=messages,\n",
245 |     "    max_tokens=512,\n",
246 |     "    response_format=response_format\n",
247 | ")\n",
248 | "\n",
249 | "# Get the result\n",
250 | "print(completion)"
251 | ]
252 | },
253 | {
254 | "cell_type": "markdown",
255 | "metadata": {},
256 | "source": [
257 | "### Streaming Structured Output Example"
258 | ]
259 | },
260 | {
261 | "cell_type": "code",
262 | "execution_count": 25,
263 | "metadata": {},
264 | "outputs": [
265 | {
266 | "name": "stdout",
267 | "output_type": "stream",
268 | "text": [
269 | "{\"address\": {\"street\": \"1 Hacker Wy\", \"city\": \"Menlo Park\", \"state\": \"CA\", \"zip\": \"94025\"}}"
270 | ]
271 | }
272 | ],
273 | "source": [
274 | "# Make the API call\n",
275 | "completion = client.chat.completions.create(\n",
276 | " model=\"mlx-server-model\",\n",
277 | " messages=messages,\n",
278 |     "    max_tokens=512,\n",
279 |     "    response_format=response_format,\n",
280 |     "    stream=True\n",
281 | ")\n",
282 | "\n",
283 | "# Process the streaming response\n",
284 | "for chunk in completion:\n",
285 | " if chunk.choices[0].delta.content:\n",
286 | " print(chunk.choices[0].delta.content, end=\"\", flush=True)"
287 | ]
288 | },
289 | {
290 | "cell_type": "code",
291 | "execution_count": null,
292 | "metadata": {},
293 | "outputs": [],
294 | "source": []
295 | }
296 | ],
297 | "metadata": {
298 | "kernelspec": {
299 | "display_name": "testing",
300 | "language": "python",
301 | "name": "python3"
302 | },
303 | "language_info": {
304 | "codemirror_mode": {
305 | "name": "ipython",
306 | "version": 3
307 | },
308 | "file_extension": ".py",
309 | "mimetype": "text/x-python",
310 | "name": "python",
311 | "nbconvert_exporter": "python",
312 | "pygments_lexer": "ipython3",
313 | "version": "3.11.11"
314 | }
315 | },
316 | "nbformat": 4,
317 | "nbformat_minor": 2
318 | }
319 |
--------------------------------------------------------------------------------
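
A note on the streaming tool-call output above: each `ChatCompletionChunk` carries only a fragment of the call (the function name arrives first, then the JSON arguments a few characters at a time). The sketch below is not part of the notebook; it shows one way to accumulate those deltas into complete tool calls, assuming the same `client`, `messages`, and `tools` objects defined earlier in that notebook.

import json

# Re-issue the streaming request (same parameters as the notebook cell).
completion = client.chat.completions.create(
    model="mlx-server-model",
    messages=messages,
    tools=tools,
    tool_choice="auto",
    stream=True,
)

# Accumulate streamed fragments, keyed by tool-call index.
calls = {}
for chunk in completion:
    delta = chunk.choices[0].delta
    for tc in delta.tool_calls or []:
        entry = calls.setdefault(tc.index, {"id": None, "name": None, "arguments": ""})
        if tc.id:
            entry["id"] = tc.id
        if tc.function and tc.function.name:
            entry["name"] = tc.function.name
        if tc.function and tc.function.arguments:
            entry["arguments"] += tc.function.arguments

# Each accumulated `arguments` value is now a complete JSON string.
for call in calls.values():
    print(call["name"], json.loads(call["arguments"]))
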
/examples/transcription_examples.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "944bf441",
6 | "metadata": {},
7 | "source": [
8 | "# Transcription Tasks with MLX Server and OpenAI API"
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "id": "6dafc3b4",
14 | "metadata": {},
15 | "source": [
16 |     "This notebook demonstrates how to use the MLX Server with its OpenAI-compatible API for transcription tasks.\n"
17 | ]
18 | },
19 | {
20 | "cell_type": "markdown",
21 | "id": "6d89bacb",
22 | "metadata": {},
23 | "source": [
24 | "## Setup and Imports"
25 | ]
26 | },
27 | {
28 | "cell_type": "markdown",
29 | "id": "bfcfd46c",
30 | "metadata": {},
31 | "source": [
32 | "First, we'll import the necessary libraries and establish a connection to the MLX Server."
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 7,
38 | "id": "a74ac262",
39 | "metadata": {},
40 | "outputs": [],
41 | "source": [
42 | "# Import the OpenAI client for API communication\n",
43 | "from openai import OpenAI\n",
44 | "\n",
45 | "# Connect to the local MLX Server with OpenAI-compatible API\n",
46 | "client = OpenAI(\n",
47 | " base_url=\"http://localhost:8000/v1\",\n",
48 | " api_key=\"fake-api-key\",\n",
49 | ")"
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": 2,
55 | "id": "d68dd370",
56 | "metadata": {},
57 | "outputs": [],
58 | "source": [
59 | "audio_path = \"audios/podcast.wav\""
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": 5,
65 | "id": "7e8a4c6a",
66 | "metadata": {},
67 | "outputs": [
68 | {
69 | "name": "stdout",
70 | "output_type": "stream",
71 | "text": [
72 | "Transcription(text=\" What if Tangero never had the demon slayer mark? Without the mark, Tangero's strength would have hit a ceiling, no insane speed boosts, no crazy recovery. Against upper moons, he'd be fighting on pure heart and swordsmanship alone. Imagine the red blade moment, weaker, slower and every fight, becoming a desperate struggle. Would he still beat a Kaza? Maybe, but Muzon, without that extra edge, Tangero's fate could have been completely different. And here's the twist. Tangero's biggest weapon has always been his willpower. Even without the mark, would his determination rewrite destiny? Do you think Tangero could win without the mark? Comment below.\", logprobs=None, usage={'type': 'duration', 'seconds': 72})\n"
73 | ]
74 | }
75 | ],
76 | "source": [
77 | "with open(audio_path, \"rb\") as f:\n",
78 | " transcription = client.audio.transcriptions.create(\n",
79 | " file=f,\n",
80 | " model=\"mlx-community/whisper-tiny\",\n",
81 | " language=\"en\",\n",
82 | " response_format=\"json\",\n",
83 | " temperature=0.0,\n",
84 | " )\n",
85 | " print(transcription)"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": 6,
91 | "id": "58062f32",
92 | "metadata": {},
93 | "outputs": [
94 | {
95 | "name": "stdout",
96 | "output_type": "stream",
97 | "text": [
98 | "TranscriptionTextDeltaEvent(delta=None, type=None, logprobs=None, id='transcription-32cd0b68-a2f9-4240-bc9e-4a7dd1e7e17d', object='transcription.chunk', created=1759658874, model='mlx-community/whisper-tiny', choices=[{'delta': {'content': \" What if Tangero never had the demon slayer mark? Without the mark, Tangero's strength would have hit a ceiling, no insane speed boosts, no crazy recovery. Against upper moons, he'd be fighting on pure heart and swordsmanship alone. Imagine the red blade moment.\", 'function_call': None, 'refusal': None, 'role': None, 'tool_calls': None, 'reasoning_content': None}, 'finish_reason': None, 'stop_reason': None}], usage=None)\n",
99 | "TranscriptionTextDeltaEvent(delta=None, type=None, logprobs=None, id='transcription-32cd0b68-a2f9-4240-bc9e-4a7dd1e7e17d', object='transcription.chunk', created=1759658874, model='mlx-community/whisper-tiny', choices=[{'delta': {'content': \" weaker, slower, and every fight becoming a desperate struggle. Would he still beat Akaza? Maybe, but Muson, without that extra edge, Tangeros' fate could have been completely different. But here's the twist. Tangeros' biggest weapon has always been his willpower. Even without...\", 'function_call': None, 'refusal': None, 'role': None, 'tool_calls': None, 'reasoning_content': None}, 'finish_reason': None, 'stop_reason': None}], usage=None)\n",
100 | "TranscriptionTextDeltaEvent(delta=None, type=None, logprobs=None, id='transcription-32cd0b68-a2f9-4240-bc9e-4a7dd1e7e17d', object='transcription.chunk', created=1759658874, model='mlx-community/whisper-tiny', choices=[{'delta': {'content': ' The Mark would his determination rewrite destiny? Do you think Tangero could win without the Mark? Comment below!', 'function_call': None, 'refusal': None, 'role': None, 'tool_calls': None, 'reasoning_content': None}, 'finish_reason': None, 'stop_reason': None}], usage=None)\n",
101 | "TranscriptionTextDeltaEvent(delta=None, type=None, logprobs=None, id='transcription-32cd0b68-a2f9-4240-bc9e-4a7dd1e7e17d', object='transcription.chunk', created=1759658874, model='mlx-community/whisper-tiny', choices=[{'delta': {'content': '', 'function_call': None, 'refusal': None, 'role': None, 'tool_calls': None, 'reasoning_content': None}, 'finish_reason': 'stop', 'stop_reason': None}], usage=None)\n"
102 | ]
103 | }
104 | ],
105 | "source": [
106 | "with open(audio_path, \"rb\") as f:\n",
107 | " stream = client.audio.transcriptions.create(\n",
108 | " file=f,\n",
109 | " model=\"mlx-community/whisper-tiny\",\n",
110 | " language=\"en\",\n",
111 | " response_format=\"json\",\n",
112 | " temperature=0.0,\n",
113 | " stream=True,\n",
114 | " )\n",
115 | " for chunk in stream:\n",
116 | " print(chunk)"
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": null,
122 | "id": "9495f63e",
123 | "metadata": {},
124 | "outputs": [],
125 | "source": []
126 | }
127 | ],
128 | "metadata": {
129 | "kernelspec": {
130 | "display_name": "Python 3",
131 | "language": "python",
132 | "name": "python3"
133 | },
134 | "language_info": {
135 | "codemirror_mode": {
136 | "name": "ipython",
137 | "version": 3
138 | },
139 | "file_extension": ".py",
140 | "mimetype": "text/x-python",
141 | "name": "python",
142 | "nbconvert_exporter": "python",
143 | "pygments_lexer": "ipython3",
144 | "version": "3.11.13"
145 | }
146 | },
147 | "nbformat": 4,
148 | "nbformat_minor": 5
149 | }
150 |
--------------------------------------------------------------------------------
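
The streaming transcription cell above prints whole `TranscriptionTextDeltaEvent` objects. To surface only the transcribed text, a small variation can pull the partial text out of each chunk. This is a sketch, not a notebook cell, and it assumes the chunk layout shown in the output above, where `choices` is a list of plain dicts with a nested `delta` dict.

with open(audio_path, "rb") as f:
    stream = client.audio.transcriptions.create(
        file=f,
        model="mlx-community/whisper-tiny",
        language="en",
        response_format="json",
        temperature=0.0,
        stream=True,
    )
    for chunk in stream:
        # Partial text lives under choices[i]["delta"]["content"] in this layout.
        for choice in chunk.choices or []:
            if isinstance(choice, dict):
                content = (choice.get("delta") or {}).get("content")
                if content:
                    print(content, end="", flush=True)
print()
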
/examples/videos/demo.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cubist38/mlx-openai-server/25fc00aa21fc6794358849ebfb9d866a6e203eda/examples/videos/demo.mp4
--------------------------------------------------------------------------------
/examples/vlm_embeddings_examples.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Vision-Language Model (VLM) Embeddings with MLX Server\n",
8 | "\n",
9 | "This notebook demonstrates how to leverage the embeddings endpoint of MLX Server through its OpenAI-compatible API. Vision-Language Models (VLMs) can process both images and text, allowing for multimodal understanding and representation.\n"
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "\n",
17 | "## Introduction\n",
18 | "\n",
19 | "MLX Server provides an efficient way to serve multimodal models on Apple Silicon. In this notebook, we'll explore how to:\n",
20 | "\n",
21 | "- Generate embeddings for text and images\n",
22 | "- Work with the OpenAI-compatible API\n",
23 | "- Calculate similarity between text and image representations\n",
24 | "- Understand how these embeddings can be used for practical applications\n",
25 | "\n",
26 | "Embeddings are high-dimensional vector representations of content that capture semantic meaning, making them useful for search, recommendation systems, and other AI applications."
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {},
32 | "source": [
33 | "## 1. Setup and API Connection\n",
34 |     "\nWe'll connect to MLX Server's OpenAI-compatible API using the OpenAI Python client, configured with:\n\n",
35 | "- A local server endpoint (`http://localhost:8000/v1`)\n",
36 | "- A placeholder API key (since MLX Server doesn't require authentication by default)\n",
37 | "\n",
38 | "Make sure you have MLX Server running locally before executing this notebook."
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": 1,
44 | "metadata": {},
45 | "outputs": [],
46 | "source": [
47 | "# Import the OpenAI client for API communication\n",
48 | "from openai import OpenAI\n",
49 | "\n",
50 | "# Connect to the local MLX Server with OpenAI-compatible API\n",
51 | "client = OpenAI(\n",
52 | " base_url=\"http://localhost:8000/v1\",\n",
53 | " api_key=\"fake-api-key\",\n",
54 | ")"
55 | ]
56 | },
57 | {
58 | "cell_type": "markdown",
59 | "metadata": {},
60 | "source": [
61 | "## 2. Image Processing for API Requests\n",
62 | "\n",
63 | "When working with image inputs, we need to prepare them in a format that the API can understand. The OpenAI-compatible API expects images to be provided as base64-encoded data URIs.\n",
64 | "\n",
65 | "Below, we'll import the necessary libraries and define a helper function to convert PIL Image objects to the required format."
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": 2,
71 | "metadata": {},
72 | "outputs": [],
73 | "source": [
74 | "from PIL import Image\n",
75 | "from io import BytesIO\n",
76 | "import base64"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": 3,
82 | "metadata": {},
83 | "outputs": [],
84 | "source": [
85 | "# To send images to the API, we need to convert them to base64-encoded strings in a data URI format.\n",
86 | "\n",
87 | "def image_to_base64(image: Image.Image):\n",
88 | " \"\"\"\n",
89 | " Convert a PIL Image to a base64-encoded data URI string that can be sent to the API.\n",
90 | " \n",
91 | " Args:\n",
92 | " image: A PIL Image object\n",
93 | " \n",
94 | " Returns:\n",
95 | " A data URI string with the base64-encoded image\n",
96 | " \"\"\"\n",
97 | " # Convert image to bytes\n",
98 | " buffer = BytesIO()\n",
99 | " image.save(buffer, format=\"PNG\")\n",
100 | " buffer.seek(0)\n",
101 | " image_data = buffer.getvalue()\n",
102 | " \n",
103 | " # Encode as base64\n",
104 | " image_base64 = base64.b64encode(image_data).decode('utf-8')\n",
105 | " \n",
106 | " # Create the data URI format required by the API\n",
107 | " mime_type = \"image/png\" \n",
108 | " image_uri = f\"data:{mime_type};base64,{image_base64}\"\n",
109 | " \n",
110 | " return image_uri"
111 | ]
112 | },
113 | {
114 | "cell_type": "markdown",
115 | "metadata": {},
116 | "source": [
117 | "## 3. Loading and Preparing an Image\n",
118 | "\n",
119 | "Now we'll load a sample image (a green dog in this case) and convert it to the base64 format required by the API. This image will be used to generate embeddings in the subsequent steps."
120 | ]
121 | },
122 | {
123 | "cell_type": "code",
124 | "execution_count": 5,
125 | "metadata": {},
126 | "outputs": [],
127 | "source": [
128 | "image = Image.open(\"images/green_dog.jpeg\")\n",
129 | "image_uri = image_to_base64(image)"
130 | ]
131 | },
132 | {
133 | "cell_type": "markdown",
134 | "metadata": {},
135 | "source": [
136 | "## 4. Generating Embeddings"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": 12,
142 | "metadata": {},
143 | "outputs": [],
144 | "source": [
145 | "# Generate embedding for a single text input\n",
146 | "prompt = \"Describe the image in detail\"\n",
147 | "image_embedding = client.embeddings.create(\n",
148 | " input=[prompt],\n",
149 | " model=\"mlx-community/Qwen2.5-VL-3B-Instruct-4bit\",\n",
150 |     "    extra_body={\n",
151 | " \"image_url\": image_uri\n",
152 | " }\n",
153 | ").data[0].embedding\n",
154 | "\n",
155 | "text = \"A green dog looking at the camera\"\n",
156 | "text_embedding = client.embeddings.create(\n",
157 | " input=[text],\n",
158 | " model=\"mlx-community/Qwen2.5-VL-3B-Instruct-4bit\"\n",
159 | ").data[0].embedding"
160 | ]
161 | },
162 | {
163 | "cell_type": "markdown",
164 | "metadata": {},
165 | "source": [
166 | "## 5. Comparing Text and Image Embeddings\n",
167 | "\n",
168 | "One of the powerful features of VLM embeddings is that they create a shared vector space for both text and images. This means we can directly compare how similar a text description is to an image's content by calculating the cosine similarity between their embeddings.\n",
169 | "\n",
170 | "A higher similarity score (closer to 1.0) indicates that the text description closely matches the image content according to the model's understanding."
171 | ]
172 | },
173 | {
174 | "cell_type": "code",
175 | "execution_count": 13,
176 | "metadata": {},
177 | "outputs": [
178 | {
179 | "name": "stdout",
180 | "output_type": "stream",
181 | "text": [
182 | "0.8473370724651375\n"
183 | ]
184 | }
185 | ],
186 | "source": [
187 | "import numpy as np\n",
188 | "\n",
189 | "def cosine_similarity(a, b):\n",
190 | " a = np.array(a)\n",
191 | " b = np.array(b)\n",
192 |     "    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))\n",
193 | "\n",
194 | "similarity = cosine_similarity(image_embedding, text_embedding)\n",
195 | "print(similarity)"
196 | ]
197 | }
198 | ],
199 | "metadata": {
200 | "kernelspec": {
201 | "display_name": "Python 3",
202 | "language": "python",
203 | "name": "python3"
204 | },
205 | "language_info": {
206 | "codemirror_mode": {
207 | "name": "ipython",
208 | "version": 3
209 | },
210 | "file_extension": ".py",
211 | "mimetype": "text/x-python",
212 | "name": "python",
213 | "nbconvert_exporter": "python",
214 | "pygments_lexer": "ipython3",
215 | "version": "3.11.12"
216 | }
217 | },
218 | "nbformat": 4,
219 | "nbformat_minor": 2
220 | }
221 |
--------------------------------------------------------------------------------
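
Building on the similarity check above, the same embeddings endpoint can be used for a small text-to-image retrieval test: embed several candidate captions and rank them against the image embedding. The sketch below is not part of the notebook; it reuses the `client` and `image_embedding` objects defined there, and the caption list is made up for illustration.

import numpy as np

def cosine_similarity(a, b):
    # Normalize explicitly so the score is a true cosine similarity.
    a, b = np.asarray(a), np.asarray(b)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Hypothetical candidate captions to rank against the green-dog image.
captions = [
    "A green dog looking at the camera",
    "A red sports car parked on a street",
    "A bowl of fruit on a wooden table",
]

scores = []
for caption in captions:
    emb = client.embeddings.create(
        input=[caption],
        model="mlx-community/Qwen2.5-VL-3B-Instruct-4bit",
    ).data[0].embedding
    scores.append((cosine_similarity(image_embedding, emb), caption))

# The caption that best matches the image should score highest.
for score, caption in sorted(scores, reverse=True):
    print(f"{score:.4f}  {caption}")
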
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | from app import __version__
3 |
4 |
5 | setup(
6 | name="mlx-openai-server",
7 | url="https://github.com/cubist38/mlx-openai-server",
8 | author="Gia-Huy Vuong",
9 | author_email="cubist38@gmail.com",
10 | version=__version__,
11 | description="A high-performance API server that provides OpenAI-compatible endpoints for MLX models. Built with Python and FastAPI, it enables efficient, scalable, and user-friendly local deployment of MLX-based multimodal models with an OpenAI-compatible interface. Supports text, vision, and audio processing capabilities. Perfect for developers looking to run MLX models locally while maintaining compatibility with existing OpenAI-based applications.",
12 | long_description=open("README.md").read(),
13 | long_description_content_type="text/markdown",
14 | packages=find_packages(),
15 | install_requires=[
16 | "mlx-vlm==0.3.4",
17 | "mlx-lm==0.28.3",
18 | "torchvision==0.23.0",
19 | "mlx-whisper==0.4.3",
20 | "mlx-embeddings==0.0.4",
21 | "fastapi==0.115.14",
22 | "av==16.0.1",
23 | "uvicorn==0.35.0",
24 | "Pillow==10.4.0",
25 | "click==8.2.1",
26 | "loguru==0.7.3",
27 | "outlines==1.1.1",
28 | "librosa==0.11.0",
29 | "openai-harmony==0.0.4",
30 | "json_repair==0.52.1",
31 | "python-multipart==0.0.20"
32 | ],
33 | extras_require={
34 | "dev": [
35 | "pytest",
36 | "black",
37 | "isort",
38 | "flake8",
39 | ]
40 | },
41 | entry_points={
42 | "console_scripts": [
43 | "mlx-openai-server=app.cli:cli",
44 | ],
45 | },
46 | python_requires=">=3.11",
47 | classifiers=[
48 | "Programming Language :: Python :: 3",
49 | "License :: OSI Approved :: MIT License",
50 | "Operating System :: OS Independent",
51 | ],
52 | )
--------------------------------------------------------------------------------
/tests/test_base_tool_parser.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from app.handler.parser.base import BaseToolParser, ParseState
4 |
5 |
6 | class TestBaseToolParser(unittest.TestCase):
7 | def setUp(self):
8 | self.test_cases = [
9 | {
10 | "name": "simple function call",
11 | "chunks": '''##
12 | #{"#name#":# "#get#_weather#",# "#arguments#":# {"#city#":# "#H#ue#"}}
13 | ##
14 | ##
15 | #{"#name#":# "#get#_weather#",# "#arguments#":# {"#city#":# "#Sy#dney#"}}
16 | ###'''.split('#')
17 | ,
18 | "expected_outputs": [
19 | {'name': 'get_weather', 'arguments': ''},
20 | {'name': None, 'arguments': ' {"'},
21 | {'name': None, 'arguments': 'city'},
22 | {'name': None, 'arguments': '":'},
23 | {'name': None, 'arguments': ' "'},
24 | {'name': None, 'arguments': 'H'},
25 | {'name': None, 'arguments': 'ue'},
26 | {'name': None, 'arguments': '"}'},
27 | '\n',
28 | {'name': 'get_weather', 'arguments': ''},
29 | {'name': None, 'arguments': ' {"'},
30 | {'name': None, 'arguments': 'city'},
31 | {'name': None, 'arguments': '":'},
32 | {'name': None, 'arguments': ' "'},
33 | {'name': None, 'arguments': 'Sy'},
34 | {'name': None, 'arguments': 'dney'},
35 | {'name': None, 'arguments': '"}'},
36 | ]
37 | },
38 | {
39 | "name": "code function call",
40 | "chunks": r'''@@
41 | @@{"@@name@@":@@ "@@python@@",@@ "@@arguments@@":@@ {"@@code@@":@@ "@@def@@ calculator@@(a@@,@@ b@@,@@ operation@@):\@@n@@ @@ if@@ operation@@ ==@@ '@@add@@'\@@n@@ @@ return@@ a@@ +@@ b@@\n@@ @@ if@@ operation@@ ==@@ '@@subtract@@'\@@n@@ @@ return@@ a@@ -@@ b@@\n@@ @@ if@@ operation@@ ==@@ '@@multiply@@'\@@n@@ @@ return@@ a@@ *@@ b@@\n@@ @@ if@@ operation@@ ==@@ '@@divide@@'\@@n@@ @@ return@@ a@@ /@@ b@@\n@@ @@ return@@ '@@Invalid@@ operation@@'@@"}}
42 | @@@@@@'''.split('@@')
43 | ,
44 | "expected_outputs": [
45 | {'name': 'python', 'arguments': ''},
46 | {'name': None, 'arguments': ' {"'},
47 | {'name': None, 'arguments': 'code'},
48 | {'name': None, 'arguments': '":'},
49 | {'name': None, 'arguments': ' "'},
50 | {'name': None, 'arguments': 'def'},
51 | {'name': None, 'arguments': ' calculator'},
52 | {'name': None, 'arguments': '(a'},
53 | {'name': None, 'arguments': ','},
54 | {'name': None, 'arguments': ' b'},
55 | {'name': None, 'arguments': ','},
56 | {'name': None, 'arguments': ' operation'},
57 | {'name': None, 'arguments': '):\\'},
58 | {'name': None, 'arguments': 'n'},
59 | {'name': None, 'arguments': ' '},
60 | {'name': None, 'arguments': ' if'},
61 | {'name': None, 'arguments': ' operation'},
62 | {'name': None, 'arguments': ' =='},
63 | {'name': None, 'arguments': " '"},
64 | {'name': None, 'arguments': 'add'},
65 | {'name': None, 'arguments': "'\\"},
66 | {'name': None, 'arguments': 'n'},
67 | {'name': None, 'arguments': ' '},
68 | {'name': None, 'arguments': ' return'},
69 | {'name': None, 'arguments': ' a'},
70 | {'name': None, 'arguments': ' +'},
71 | {'name': None, 'arguments': ' b'},
72 | {'name': None, 'arguments': '\\n'},
73 | {'name': None, 'arguments': ' '},
74 | {'name': None, 'arguments': ' if'},
75 | {'name': None, 'arguments': ' operation'},
76 | {'name': None, 'arguments': ' =='},
77 | {'name': None, 'arguments': " '"},
78 | {'name': None, 'arguments': 'subtract'},
79 | {'name': None, 'arguments': "'\\"},
80 | {'name': None, 'arguments': 'n'},
81 | {'name': None, 'arguments': ' '},
82 | {'name': None, 'arguments': ' return'},
83 | {'name': None, 'arguments': ' a'},
84 | {'name': None, 'arguments': ' -'},
85 | {'name': None, 'arguments': ' b'},
86 | {'name': None, 'arguments': '\\n'},
87 | {'name': None, 'arguments': ' '},
88 | {'name': None, 'arguments': ' if'},
89 | {'name': None, 'arguments': ' operation'},
90 | {'name': None, 'arguments': ' =='},
91 | {'name': None, 'arguments': " '"},
92 | {'name': None, 'arguments': 'multiply'},
93 | {'name': None, 'arguments': "'\\"},
94 | {'name': None, 'arguments': 'n'},
95 | {'name': None, 'arguments': ' '},
96 | {'name': None, 'arguments': ' return'},
97 | {'name': None, 'arguments': ' a'},
98 | {'name': None, 'arguments': ' *'},
99 | {'name': None, 'arguments': ' b'},
100 | {'name': None, 'arguments': '\\n'},
101 | {'name': None, 'arguments': ' '},
102 | {'name': None, 'arguments': ' if'},
103 | {'name': None, 'arguments': ' operation'},
104 | {'name': None, 'arguments': ' =='},
105 | {'name': None, 'arguments': " '"},
106 | {'name': None, 'arguments': 'divide'},
107 | {'name': None, 'arguments': "'\\"},
108 | {'name': None, 'arguments': 'n'},
109 | {'name': None, 'arguments': ' '},
110 | {'name': None, 'arguments': ' return'},
111 | {'name': None, 'arguments': ' a'},
112 | {'name': None, 'arguments': ' /'},
113 | {'name': None, 'arguments': ' b'},
114 | {'name': None, 'arguments': '\\n'},
115 | {'name': None, 'arguments': ' '},
116 | {'name': None, 'arguments': ' return'},
117 | {'name': None, 'arguments': " '"},
118 | {'name': None, 'arguments': 'Invalid'},
119 | {'name': None, 'arguments': ' operation'},
120 | {'name': None, 'arguments': "'"},
121 | {'name': None, 'arguments': '"}'},
122 | ]
123 | },
124 | ]
125 |
126 | def test_parse_stream(self):
127 | for test_case in self.test_cases:
128 | with self.subTest(msg=test_case["name"]):
129 | parser = BaseToolParser("", "")
130 | outputs = []
131 |
132 | for chunk in test_case["chunks"]:
133 | result = parser.parse_stream(chunk)
134 | if result:
135 | outputs.append(result)
136 |
137 |
138 | self.assertEqual(len(outputs), len(test_case["expected_outputs"]),
139 | f"Expected {len(test_case['expected_outputs'])} outputs, got {len(outputs)}")
140 |
141 | for i, (output, expected) in enumerate(zip(outputs, test_case["expected_outputs"])):
142 | self.assertEqual(output, expected,
143 | f"Chunk {i}: Expected {expected}, got {output}")
144 |
145 | if __name__ == '__main__':
146 | unittest.main()
147 |
--------------------------------------------------------------------------------
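
To run the parser test above without installing the `dev` extras, a minimal runner using only the standard library can be used; this is a sketch, not a file in the repository, and it assumes it is executed from the repository root.

import unittest

# Discover and run everything under tests/ (currently test_base_tool_parser.py).
suite = unittest.defaultTestLoader.discover("tests")
unittest.TextTestRunner(verbosity=2).run(suite)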