├── .github └── workflows │ └── publish.yml ├── .gitignore ├── MANIFEST.in ├── Makefile ├── README.md ├── app ├── __init__.py ├── api │ ├── __init__.py │ └── endpoints.py ├── cli.py ├── core │ ├── __init__.py │ ├── audio_processor.py │ ├── base_processor.py │ ├── image_processor.py │ ├── queue.py │ └── video_processor.py ├── handler │ ├── __init__.py │ ├── mflux.py │ ├── mlx_embeddings.py │ ├── mlx_lm.py │ ├── mlx_vlm.py │ ├── mlx_whisper.py │ └── parser │ │ ├── __init__.py │ │ ├── base.py │ │ ├── glm4_moe.py │ │ ├── harmony.py │ │ └── qwen3.py ├── main.py ├── models │ ├── __init__.py │ ├── mflux.py │ ├── mlx_embeddings.py │ ├── mlx_lm.py │ ├── mlx_vlm.py │ └── mlx_whisper.py ├── schemas │ ├── __init__.py │ └── openai.py ├── utils │ ├── __init__.py │ ├── dill.py │ ├── errors.py │ └── outlines_transformer_tokenizer.py └── version.py ├── configure_mlx.sh ├── examples ├── audio_examples.ipynb ├── audios │ ├── audio.wav │ └── podcast.wav ├── embedding_examples.ipynb ├── image_edit.ipynb ├── image_generations.ipynb ├── images │ ├── attention.png │ ├── china.png │ ├── green_dog.jpeg │ └── password.jpg ├── lm_embeddings_examples.ipynb ├── pdfs │ └── lab03.pdf ├── simple_rag_demo.ipynb ├── structured_outputs_examples.ipynb ├── transcription_examples.ipynb ├── videos │ └── demo.mp4 ├── vision_examples.ipynb └── vlm_embeddings_examples.ipynb ├── setup.py └── tests └── test_base_tool_parser.py /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python 🐍 distribution 📦 to PyPI 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*' # Triggers on version tags like v1.0.0 7 | 8 | jobs: 9 | build-and-publish: 10 | runs-on: macos-latest 11 | 12 | steps: 13 | - uses: actions/checkout@v4 14 | 15 | - name: Set up Python 16 | uses: actions/setup-python@v5 17 | with: 18 | python-version: '3.11' 19 | 20 | - name: Install build tools 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install build twine 24 | 25 | - name: Build package 26 | run: python -m build 27 | 28 | - name: Publish package to PyPI 29 | env: 30 | TWINE_USERNAME: __token__ 31 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 32 | run: twine upload dist/* -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | oai-compat-server 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | # ignore .DS_Store 11 | .DS_Store 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # UV 102 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | #uv.lock 106 | 107 | # poetry 108 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 109 | # This is especially recommended for binary packages to ensure reproducibility, and is more 110 | # commonly ignored for libraries. 111 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 112 | #poetry.lock 113 | 114 | # pdm 115 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 116 | #pdm.lock 117 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 118 | # in version control. 119 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 120 | .pdm.toml 121 | .pdm-python 122 | .pdm-build/ 123 | 124 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 125 | __pypackages__/ 126 | 127 | # Celery stuff 128 | celerybeat-schedule 129 | celerybeat.pid 130 | 131 | # SageMath parsed files 132 | *.sage.py 133 | 134 | # Environments 135 | .env 136 | .venv 137 | env/ 138 | venv/ 139 | ENV/ 140 | env.bak/ 141 | venv.bak/ 142 | 143 | # Spyder project settings 144 | .spyderproject 145 | .spyproject 146 | 147 | # Rope project settings 148 | .ropeproject 149 | 150 | # mkdocs documentation 151 | /site 152 | 153 | # mypy 154 | .mypy_cache/ 155 | .dmypy.json 156 | dmypy.json 157 | 158 | # Pyre type checker 159 | .pyre/ 160 | 161 | # pytype static type analyzer 162 | .pytype/ 163 | 164 | # Cython debug symbols 165 | cython_debug/ 166 | 167 | # PyCharm 168 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 169 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 170 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 171 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 172 | #.idea/ 173 | 174 | # Ruff stuff: 175 | .ruff_cache/ 176 | 177 | # PyPI configuration file 178 | .pypirc 179 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include requirements.txt 3 | include MANIFEST.in 4 | include setup.py 5 | recursive-include app * -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | run: 2 | mlx-server launch \ 3 | --model-path mlx-community/Qwen3-1.7B-4bit \ 4 | --model-type lm \ 5 | --max-concurrency 1 \ 6 | --queue-timeout 300 \ 7 | --queue-size 100 8 | 9 | install: 10 | pip install -e . -------------------------------------------------------------------------------- /app/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from app.version import __version__ 3 | 4 | # Suppress transformers warnings 5 | os.environ['TRANSFORMERS_VERBOSITY'] = 'error' 6 | 7 | __all__ = ["__version__"] -------------------------------------------------------------------------------- /app/api/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /app/cli.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import asyncio 3 | import click 4 | import uvicorn 5 | from loguru import logger 6 | from functools import lru_cache 7 | from app.version import __version__ 8 | from app.main import setup_server 9 | 10 | class Config: 11 | """Configuration container for server parameters.""" 12 | def __init__(self, model_path, model_type, context_length, port, host, max_concurrency, queue_timeout, queue_size, disable_auto_resize=False, quantize=8, config_name=None, lora_paths=None, lora_scales=None, log_file=None, no_log_file=False, log_level="INFO"): 13 | self.model_path = model_path 14 | self.model_type = model_type 15 | self.context_length = context_length 16 | self.port = port 17 | self.host = host 18 | self.max_concurrency = max_concurrency 19 | self.queue_timeout = queue_timeout 20 | self.queue_size = queue_size 21 | self.disable_auto_resize = disable_auto_resize 22 | self.quantize = quantize 23 | self.config_name = config_name 24 | self.log_file = log_file 25 | self.no_log_file = no_log_file 26 | self.log_level = log_level 27 | 28 | # Process comma-separated LoRA paths and scales 29 | if lora_paths: 30 | self.lora_paths = [path.strip() for path in lora_paths.split(',') if path.strip()] 31 | else: 32 | self.lora_paths = None 33 | 34 | if lora_scales: 35 | self.lora_scales = [float(scale.strip()) for scale in lora_scales.split(',') if scale.strip()] 36 | else: 37 | self.lora_scales = None 38 | 39 | 40 | @property 41 | def model_identifier(self): 42 | """Get the appropriate model identifier based on model type.""" 43 | # For Flux models, we always use model_path (local directory path) 44 | return self.model_path 45 | 46 | 47 | # Configure basic logging for CLI (will be overridden by main.py) 48 | logger.remove() # Remove default handler 49 | logger.add( 50 | sys.stderr, 51 | format="{time:YYYY-MM-DD HH:mm:ss} | " 52 | "{level: <8} | " 53 | "{name}:{function}:{line} 
| " 54 | "✦ {message}", 55 | colorize=True, 56 | level="INFO" 57 | ) 58 | 59 | 60 | @click.group() 61 | @click.version_option( 62 | version=__version__, 63 | message=""" 64 | ✨ %(prog)s - OpenAI Compatible API Server for MLX models ✨ 65 | ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 66 | 🚀 Version: %(version)s 67 | """ 68 | ) 69 | def cli(): 70 | """MLX Server - OpenAI Compatible API for MLX models.""" 71 | pass 72 | 73 | 74 | @lru_cache(maxsize=1) 75 | def get_server_config(model_path, model_type, context_length, port, host, max_concurrency, queue_timeout, queue_size, quantize, config_name, lora_paths, lora_scales, disable_auto_resize, log_file, no_log_file, log_level): 76 | """Cache and return server configuration to avoid redundant processing.""" 77 | return Config( 78 | model_path=model_path, 79 | model_type=model_type, 80 | context_length=context_length, 81 | port=port, 82 | host=host, 83 | max_concurrency=max_concurrency, 84 | queue_timeout=queue_timeout, 85 | queue_size=queue_size, 86 | disable_auto_resize=disable_auto_resize, 87 | quantize=quantize, 88 | config_name=config_name, 89 | lora_paths=lora_paths, 90 | lora_scales=lora_scales, 91 | log_file=log_file, 92 | no_log_file=no_log_file, 93 | log_level=log_level 94 | ) 95 | 96 | 97 | def print_startup_banner(args): 98 | """Display beautiful startup banner with configuration details.""" 99 | logger.info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") 100 | logger.info(f"✨ MLX Server v{__version__} Starting ✨") 101 | logger.info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") 102 | logger.info(f"🔮 Model Path: {args.model_path}") 103 | logger.info(f"🔮 Model Type: {args.model_type}") 104 | if args.context_length: 105 | logger.info(f"🔮 Context Length: {args.context_length}") 106 | logger.info(f"🌐 Host: {args.host}") 107 | logger.info(f"🔌 Port: {args.port}") 108 | logger.info(f"⚡ Max Concurrency: {args.max_concurrency}") 109 | logger.info(f"⏱️ Queue Timeout: {args.queue_timeout} seconds") 110 | logger.info(f"📊 Queue Size: {args.queue_size}") 111 | if args.model_type in ["image-generation", "image-edit"]: 112 | logger.info(f"🔮 Quantize: {args.quantize}") 113 | logger.info(f"🔮 Config Name: {args.config_name}") 114 | if args.lora_paths: 115 | logger.info(f"🔮 LoRA Paths: {args.lora_paths}") 116 | if args.lora_scales: 117 | logger.info(f"🔮 LoRA Scales: {args.lora_scales}") 118 | if hasattr(args, 'disable_auto_resize') and args.disable_auto_resize and args.model_type == "multimodal": 119 | logger.info(f"🖼️ Auto-resize: Disabled") 120 | logger.info(f"📝 Log Level: {args.log_level}") 121 | if args.no_log_file: 122 | logger.info(f"📝 File Logging: Disabled") 123 | elif args.log_file: 124 | logger.info(f"📝 Log File: {args.log_file}") 125 | else: 126 | logger.info(f"📝 Log File: logs/app.log (default)") 127 | logger.info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") 128 | 129 | @cli.command() 130 | @click.option( 131 | "--model-path", 132 | help="Path to the model (required for lm, multimodal, embeddings, image-generation, image-edit, whisper model types). With `image-generation` or `image-edit` model types, it should be the local path to the model." 
133 | ) 134 | @click.option( 135 | "--model-type", 136 | default="lm", 137 | type=click.Choice(["lm", "multimodal", "image-generation", "image-edit", "embeddings", "whisper"]), 138 | help="Type of model to run (lm: text-only, multimodal: text+vision+audio, image-generation: flux image generation, image-edit: flux image edit, embeddings: text embeddings, whisper: audio transcription)" 139 | ) 140 | @click.option( 141 | "--context-length", 142 | default=None, 143 | type=int, 144 | help="Context length for language models. Only works with `lm` or `multimodal` model types." 145 | ) 146 | @click.option( 147 | "--port", 148 | default=8000, 149 | type=int, 150 | help="Port to run the server on" 151 | ) 152 | @click.option( 153 | "--host", 154 | default="0.0.0.0", 155 | help="Host to run the server on" 156 | ) 157 | @click.option( 158 | "--max-concurrency", 159 | default=1, 160 | type=int, 161 | help="Maximum number of concurrent requests" 162 | ) 163 | @click.option( 164 | "--queue-timeout", 165 | default=300, 166 | type=int, 167 | help="Request timeout in seconds" 168 | ) 169 | @click.option( 170 | "--queue-size", 171 | default=100, 172 | type=int, 173 | help="Maximum queue size for pending requests" 174 | ) 175 | @click.option( 176 | "--quantize", 177 | default=8, 178 | type=int, 179 | help="Quantization level for the model. Only used for image-generation and image-edit Flux models." 180 | ) 181 | @click.option( 182 | "--config-name", 183 | default=None, 184 | type=click.Choice(["flux-schnell", "flux-dev", "flux-krea-dev", "flux-kontext-dev"]), 185 | help="Config name of the model. Only used for image-generation and image-edit Flux models." 186 | ) 187 | @click.option( 188 | "--lora-paths", 189 | default=None, 190 | type=str, 191 | help="Path to the LoRA file(s). Multiple paths should be separated by commas." 192 | ) 193 | @click.option( 194 | "--lora-scales", 195 | default=None, 196 | type=str, 197 | help="Scale factor for the LoRA file(s). Multiple scales should be separated by commas." 198 | ) 199 | @click.option( 200 | "--disable-auto-resize", 201 | is_flag=True, 202 | help="Disable automatic model resizing. Only work for Vision Language Models." 203 | ) 204 | @click.option( 205 | "--log-file", 206 | default=None, 207 | type=str, 208 | help="Path to log file. If not specified, logs will be written to 'logs/app.log' by default." 209 | ) 210 | @click.option( 211 | "--no-log-file", 212 | is_flag=True, 213 | help="Disable file logging entirely. Only console output will be shown." 214 | ) 215 | @click.option( 216 | "--log-level", 217 | default="INFO", 218 | type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), 219 | help="Set the logging level. Default is INFO." 220 | ) 221 | def launch(model_path, model_type, context_length, port, host, max_concurrency, queue_timeout, queue_size, quantize, config_name, lora_paths, lora_scales, disable_auto_resize, log_file, no_log_file, log_level): 222 | """Launch the MLX server with the specified model.""" 223 | try: 224 | # Validate that config name is only used with image-generation and image-edit model types 225 | if config_name and model_type not in ["image-generation", "image-edit"]: 226 | logger.warning(f"Config name parameter '{config_name}' provided but model type is '{model_type}'. Config name is only used with image-generation and image-edit models.") 227 | elif model_type == "image-generation" and not config_name: 228 | logger.warning("Model type is 'image-generation' but no config name specified. 
Using default 'flux-schnell'.") 229 | config_name = "flux-schnell" 230 | elif model_type == "image-edit" and not config_name: 231 | logger.warning("Model type is 'image-edit' but no config name specified. Using default 'flux-kontext-dev'.") 232 | config_name = "flux-kontext-dev" 233 | 234 | # Get optimized configuration 235 | args = get_server_config(model_path, model_type, context_length, port, host, max_concurrency, queue_timeout, queue_size, quantize, config_name, lora_paths, lora_scales, disable_auto_resize, log_file, no_log_file, log_level) 236 | 237 | # Display startup information 238 | print_startup_banner(args) 239 | 240 | # Set up and start the server 241 | config = asyncio.run(setup_server(args)) 242 | logger.info("Server configuration complete.") 243 | logger.info("Starting Uvicorn server...") 244 | uvicorn.Server(config).run() 245 | except KeyboardInterrupt: 246 | logger.info("Server shutdown requested by user. Exiting...") 247 | except Exception as e: 248 | logger.error(f"Server startup failed: {str(e)}") 249 | sys.exit(1) 250 | 251 | 252 | if __name__ == "__main__": 253 | cli() -------------------------------------------------------------------------------- /app/core/__init__.py: -------------------------------------------------------------------------------- 1 | from app.core.base_processor import BaseProcessor 2 | from app.core.audio_processor import AudioProcessor 3 | from app.core.image_processor import ImageProcessor 4 | from app.core.video_processor import VideoProcessor 5 | 6 | __all__ = ["BaseProcessor", "AudioProcessor", "ImageProcessor", "VideoProcessor"] -------------------------------------------------------------------------------- /app/core/audio_processor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gc 3 | import asyncio 4 | from typing import List 5 | from app.core.base_processor import BaseProcessor 6 | 7 | 8 | class AudioProcessor(BaseProcessor): 9 | """Audio processor for handling audio files with caching and validation.""" 10 | 11 | def __init__(self, max_workers: int = 4, cache_size: int = 1000): 12 | super().__init__(max_workers, cache_size) 13 | # Supported audio formats 14 | self._supported_formats = {'.mp3', '.wav'} 15 | 16 | def _get_media_format(self, media_url: str, data: bytes = None) -> str: 17 | """Determine audio format from URL or data.""" 18 | if media_url.startswith("data:"): 19 | # Extract format from data URL 20 | mime_type = media_url.split(";")[0].split(":")[1] 21 | if "mp3" in mime_type or "mpeg" in mime_type: 22 | return "mp3" 23 | elif "wav" in mime_type: 24 | return "wav" 25 | elif "m4a" in mime_type or "mp4" in mime_type: 26 | return "m4a" 27 | elif "ogg" in mime_type: 28 | return "ogg" 29 | elif "flac" in mime_type: 30 | return "flac" 31 | elif "aac" in mime_type: 32 | return "aac" 33 | else: 34 | # Extract format from file extension 35 | ext = os.path.splitext(media_url.lower())[1] 36 | if ext in self._supported_formats: 37 | return ext[1:] # Remove the dot 38 | 39 | # Default to mp3 if format cannot be determined 40 | return "mp3" 41 | 42 | def _validate_media_data(self, data: bytes) -> bool: 43 | """Basic validation of audio data.""" 44 | if len(data) < 100: # Too small to be a valid audio file 45 | return False 46 | 47 | # Check for common audio file signatures 48 | audio_signatures = [ 49 | b'ID3', # MP3 with ID3 tag 50 | b'\xff\xfb', # MP3 frame header 51 | b'\xff\xf3', # MP3 frame header 52 | b'\xff\xf2', # MP3 frame header 53 | b'RIFF', # WAV/AVI 54 | b'OggS', 
# OGG 55 | b'fLaC', # FLAC 56 | b'\x00\x00\x00\x20ftypM4A', # M4A 57 | ] 58 | 59 | for sig in audio_signatures: 60 | if data.startswith(sig): 61 | return True 62 | 63 | # Check for WAV format (RIFF header might be at different position) 64 | if b'WAVE' in data[:50]: 65 | return True 66 | 67 | return True # Allow unknown formats to pass through 68 | 69 | def _get_timeout(self) -> int: 70 | """Get timeout for HTTP requests.""" 71 | return 60 # Longer timeout for audio files 72 | 73 | def _get_max_file_size(self) -> int: 74 | """Get maximum file size in bytes.""" 75 | return 500 * 1024 * 1024 # 500 MB limit for audio 76 | 77 | def _process_media_data(self, data: bytes, cached_path: str, **kwargs) -> str: 78 | """Process audio data and save to cached path.""" 79 | with open(cached_path, 'wb') as f: 80 | f.write(data) 81 | self._cleanup_old_files() 82 | return cached_path 83 | 84 | def _get_media_type_name(self) -> str: 85 | """Get media type name for logging.""" 86 | return "audio" 87 | 88 | async def process_audio_url(self, audio_url: str) -> str: 89 | """Process a single audio URL and return path to cached file.""" 90 | return await self._process_single_media(audio_url) 91 | 92 | async def process_audio_urls(self, audio_urls: List[str]) -> List[str]: 93 | """Process multiple audio URLs and return paths to cached files.""" 94 | tasks = [self.process_audio_url(url) for url in audio_urls] 95 | results = await asyncio.gather(*tasks, return_exceptions=True) 96 | # Force garbage collection after batch processing 97 | gc.collect() 98 | return results -------------------------------------------------------------------------------- /app/core/base_processor.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import hashlib 3 | import os 4 | import tempfile 5 | import aiohttp 6 | import time 7 | import gc 8 | from loguru import logger 9 | from typing import Dict, Optional, Any 10 | from concurrent.futures import ThreadPoolExecutor 11 | from abc import ABC, abstractmethod 12 | 13 | 14 | class BaseProcessor(ABC): 15 | """Base class for media processors with common caching and session management.""" 16 | 17 | def __init__(self, max_workers: int = 4, cache_size: int = 1000): 18 | # Use tempfile for macOS-efficient temporary file handling 19 | self.temp_dir = tempfile.TemporaryDirectory() 20 | self._session: Optional[aiohttp.ClientSession] = None 21 | self.executor = ThreadPoolExecutor(max_workers=max_workers) 22 | self._cache_size = cache_size 23 | self._last_cleanup = time.time() 24 | self._cleanup_interval = 3600 # 1 hour 25 | # Replace lru_cache with manual cache for better control 26 | self._hash_cache: Dict[str, str] = {} 27 | self._cache_access_times: Dict[str, float] = {} 28 | 29 | def _get_media_hash(self, media_url: str) -> str: 30 | """Get hash for media URL with manual caching that can be cleared.""" 31 | # Check if already cached 32 | if media_url in self._hash_cache: 33 | self._cache_access_times[media_url] = time.time() 34 | return self._hash_cache[media_url] 35 | 36 | # Generate hash 37 | if media_url.startswith("data:"): 38 | _, encoded = media_url.split(",", 1) 39 | data = base64.b64decode(encoded) 40 | else: 41 | data = media_url.encode('utf-8') 42 | 43 | hash_value = hashlib.md5(data).hexdigest() 44 | 45 | # Add to cache with size management 46 | if len(self._hash_cache) >= self._cache_size: 47 | self._evict_oldest_cache_entries() 48 | 49 | self._hash_cache[media_url] = hash_value 50 | self._cache_access_times[media_url] = 
time.time() 51 | return hash_value 52 | 53 | def _evict_oldest_cache_entries(self): 54 | """Remove oldest 20% of cache entries to make room.""" 55 | if not self._cache_access_times: 56 | return 57 | 58 | # Sort by access time and remove oldest 20% 59 | sorted_items = sorted(self._cache_access_times.items(), key=lambda x: x[1]) 60 | to_remove = len(sorted_items) // 5 # Remove 20% 61 | 62 | for url, _ in sorted_items[:to_remove]: 63 | self._hash_cache.pop(url, None) 64 | self._cache_access_times.pop(url, None) 65 | 66 | # Force garbage collection after cache eviction 67 | gc.collect() 68 | 69 | @abstractmethod 70 | def _get_media_format(self, media_url: str, data: bytes = None) -> str: 71 | """Determine media format from URL or data. Must be implemented by subclasses.""" 72 | pass 73 | 74 | @abstractmethod 75 | def _validate_media_data(self, data: bytes) -> bool: 76 | """Validate media data. Must be implemented by subclasses.""" 77 | pass 78 | 79 | @abstractmethod 80 | def _get_timeout(self) -> int: 81 | """Get timeout for HTTP requests. Must be implemented by subclasses.""" 82 | pass 83 | 84 | @abstractmethod 85 | def _get_max_file_size(self) -> int: 86 | """Get maximum file size in bytes. Must be implemented by subclasses.""" 87 | pass 88 | 89 | @abstractmethod 90 | def _process_media_data(self, data: bytes, cached_path: str, **kwargs) -> Dict[str, Any]: 91 | """Process media data and save to cached path. Must be implemented by subclasses.""" 92 | pass 93 | 94 | @abstractmethod 95 | def _get_media_type_name(self) -> str: 96 | """Get media type name for logging. Must be implemented by subclasses.""" 97 | pass 98 | 99 | async def _get_session(self) -> aiohttp.ClientSession: 100 | if self._session is None or self._session.closed: 101 | self._session = aiohttp.ClientSession( 102 | timeout=aiohttp.ClientTimeout(total=self._get_timeout()), 103 | headers={"User-Agent": "mlx-server-OAI-compat/1.0"} 104 | ) 105 | return self._session 106 | 107 | def _cleanup_old_files(self): 108 | current_time = time.time() 109 | if current_time - self._last_cleanup > self._cleanup_interval: 110 | try: 111 | for file in os.listdir(self.temp_dir.name): 112 | file_path = os.path.join(self.temp_dir.name, file) 113 | if os.path.getmtime(file_path) < current_time - self._cleanup_interval: 114 | os.remove(file_path) 115 | self._last_cleanup = current_time 116 | # Also clean up cache periodically 117 | if len(self._hash_cache) > self._cache_size * 0.8: 118 | self._evict_oldest_cache_entries() 119 | gc.collect() # Force garbage collection after cleanup 120 | except Exception as e: 121 | logger.warning(f"Failed to clean up old {self._get_media_type_name()} files: {str(e)}") 122 | 123 | async def _process_single_media(self, media_url: str, **kwargs) -> str: 124 | try: 125 | media_hash = self._get_media_hash(media_url) 126 | media_format = self._get_media_format(media_url) 127 | cached_path = os.path.join(self.temp_dir.name, f"{media_hash}.{media_format}") 128 | 129 | if os.path.exists(cached_path): 130 | logger.debug(f"Using cached {self._get_media_type_name()}: {cached_path}") 131 | return cached_path 132 | 133 | if os.path.exists(media_url): 134 | # Copy local file to cache 135 | with open(media_url, 'rb') as f: 136 | data = f.read() 137 | 138 | if not self._validate_media_data(data): 139 | raise ValueError(f"Invalid {self._get_media_type_name()} file format") 140 | 141 | return self._process_media_data(data, cached_path, **kwargs) 142 | 143 | elif media_url.startswith("data:"): 144 | _, encoded = media_url.split(",", 1) 
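                # Base64 encodes every 3 bytes of payload as 4 characters, so the decoded
                # size is roughly len(encoded) * 3 / 4; checking this estimate first avoids
                # decoding an oversized payload into memory.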
145 | estimated_size = len(encoded) * 3 / 4 146 | if estimated_size > self._get_max_file_size(): 147 | raise ValueError(f"Base64-encoded {self._get_media_type_name()} exceeds size limit") 148 | data = base64.b64decode(encoded) 149 | 150 | if not self._validate_media_data(data): 151 | raise ValueError(f"Invalid {self._get_media_type_name()} file format") 152 | 153 | return self._process_media_data(data, cached_path, **kwargs) 154 | else: 155 | session = await self._get_session() 156 | async with session.get(media_url) as response: 157 | response.raise_for_status() 158 | data = await response.read() 159 | 160 | if not self._validate_media_data(data): 161 | raise ValueError(f"Invalid {self._get_media_type_name()} file format") 162 | 163 | return self._process_media_data(data, cached_path, **kwargs) 164 | 165 | except Exception as e: 166 | logger.error(f"Failed to process {self._get_media_type_name()}: {str(e)}") 167 | raise ValueError(f"Failed to process {self._get_media_type_name()}: {str(e)}") 168 | finally: 169 | gc.collect() 170 | 171 | def clear_cache(self): 172 | """Manually clear the hash cache to free memory.""" 173 | self._hash_cache.clear() 174 | self._cache_access_times.clear() 175 | gc.collect() 176 | 177 | async def cleanup(self): 178 | if hasattr(self, '_cleaned') and self._cleaned: 179 | return 180 | self._cleaned = True 181 | try: 182 | # Clear caches before cleanup 183 | self.clear_cache() 184 | 185 | if self._session and not self._session.closed: 186 | await self._session.close() 187 | except Exception as e: 188 | logger.warning(f"Exception closing aiohttp session: {str(e)}") 189 | try: 190 | self.executor.shutdown(wait=True) 191 | except Exception as e: 192 | logger.warning(f"Exception shutting down executor: {str(e)}") 193 | try: 194 | self.temp_dir.cleanup() 195 | except Exception as e: 196 | logger.warning(f"Exception cleaning up temp directory: {str(e)}") 197 | 198 | async def __aenter__(self): 199 | return self 200 | 201 | async def __aexit__(self, exc_type, exc, tb): 202 | await self.cleanup() 203 | 204 | def __del__(self): 205 | # Async cleanup cannot be reliably performed in __del__ 206 | # Please use 'async with Processor()' or call 'await cleanup()' explicitly. 
207 | pass -------------------------------------------------------------------------------- /app/core/image_processor.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import asyncio 3 | from PIL import Image 4 | from loguru import logger 5 | from io import BytesIO 6 | from typing import List 7 | from app.core.base_processor import BaseProcessor 8 | 9 | 10 | class ImageProcessor(BaseProcessor): 11 | """Image processor for handling image files with caching, validation, and processing.""" 12 | 13 | def __init__(self, max_workers: int = 4, cache_size: int = 1000): 14 | super().__init__(max_workers, cache_size) 15 | Image.MAX_IMAGE_PIXELS = 100000000 # Limit to 100 megapixels 16 | 17 | def _get_media_format(self, media_url: str, data: bytes = None) -> str: 18 | """Determine image format from URL or data.""" 19 | # For images, we always save as JPEG for consistency 20 | return "jpg" 21 | 22 | def _validate_media_data(self, data: bytes) -> bool: 23 | """Basic validation of image data.""" 24 | if len(data) < 100: # Too small to be a valid image file 25 | return False 26 | 27 | # Check for common image file signatures 28 | image_signatures = [ 29 | b'\xff\xd8\xff', # JPEG 30 | b'\x89PNG\r\n\x1a\n', # PNG 31 | b'GIF87a', # GIF87a 32 | b'GIF89a', # GIF89a 33 | b'BM', # BMP 34 | b'II*\x00', # TIFF (little endian) 35 | b'MM\x00*', # TIFF (big endian) 36 | b'RIFF', # WebP (part of RIFF) 37 | ] 38 | 39 | for sig in image_signatures: 40 | if data.startswith(sig): 41 | return True 42 | 43 | # Additional check for WebP 44 | if data.startswith(b'RIFF') and b'WEBP' in data[:20]: 45 | return True 46 | 47 | return False 48 | 49 | def _get_timeout(self) -> int: 50 | """Get timeout for HTTP requests.""" 51 | return 30 # Standard timeout for images 52 | 53 | def _get_max_file_size(self) -> int: 54 | """Get maximum file size in bytes.""" 55 | return 100 * 1024 * 1024 # 100 MB limit for images 56 | 57 | def _get_media_type_name(self) -> str: 58 | """Get media type name for logging.""" 59 | return "image" 60 | 61 | def _resize_image_keep_aspect_ratio(self, image: Image.Image, max_size: int = 448) -> Image.Image: 62 | width, height = image.size 63 | if width <= max_size and height <= max_size: 64 | return image 65 | if width > height: 66 | new_width = max_size 67 | new_height = int(height * max_size / width) 68 | else: 69 | new_height = max_size 70 | new_width = int(width * max_size / height) 71 | 72 | image = image.resize((new_width, new_height), Image.Resampling.LANCZOS) 73 | logger.info(f"Resized image to {new_width}x{new_height} from {width}x{height}") 74 | 75 | return image 76 | 77 | def _prepare_image_for_saving(self, image: Image.Image) -> Image.Image: 78 | if image.mode in ('RGBA', 'LA'): 79 | background = Image.new('RGB', image.size, (255, 255, 255)) 80 | if image.mode == 'RGBA': 81 | background.paste(image, mask=image.split()[3]) 82 | else: 83 | background.paste(image, mask=image.split()[1]) 84 | return background 85 | elif image.mode != 'RGB': 86 | return image.convert('RGB') 87 | return image 88 | 89 | def _process_media_data(self, data: bytes, cached_path: str, **kwargs) -> str: 90 | """Process image data and save to cached path.""" 91 | image = None 92 | resize = kwargs.get("resize", True) 93 | try: 94 | with Image.open(BytesIO(data), mode='r') as image: 95 | if resize: 96 | image = self._resize_image_keep_aspect_ratio(image) 97 | image = self._prepare_image_for_saving(image) 98 | image.save(cached_path, 'PNG', quality=100, optimize=True) 99 | 100 | 
self._cleanup_old_files() 101 | return cached_path 102 | finally: 103 | # Ensure image object is closed to free memory 104 | if image: 105 | try: 106 | image.close() 107 | except: 108 | pass 109 | 110 | async def process_image_url(self, image_url: str, resize: bool = True) -> str: 111 | """Process a single image URL and return path to cached file.""" 112 | return await self._process_single_media(image_url, resize=resize) 113 | 114 | async def process_image_urls(self, image_urls: List[str], resize: bool = True) -> List[str]: 115 | """Process multiple image URLs and return paths to cached files.""" 116 | tasks = [self.process_image_url(url, resize=resize) for url in image_urls] 117 | results = await asyncio.gather(*tasks, return_exceptions=True) 118 | # Force garbage collection after batch processing 119 | gc.collect() 120 | return results -------------------------------------------------------------------------------- /app/core/queue.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import time 3 | from typing import Any, Dict, Optional, Callable, Awaitable, TypeVar, Generic 4 | import gc 5 | from loguru import logger 6 | 7 | T = TypeVar('T') 8 | 9 | class RequestItem(Generic[T]): 10 | """ 11 | Represents a single request in the queue. 12 | """ 13 | def __init__(self, request_id: str, data: Any): 14 | self.request_id = request_id 15 | self.data = data 16 | self.created_at = time.time() 17 | self.future = asyncio.Future() 18 | 19 | def set_result(self, result: T) -> None: 20 | """Set the result for this request.""" 21 | if not self.future.done(): 22 | self.future.set_result(result) 23 | 24 | def set_exception(self, exc: Exception) -> None: 25 | """Set an exception for this request.""" 26 | if not self.future.done(): 27 | self.future.set_exception(exc) 28 | 29 | async def get_result(self) -> T: 30 | """Wait for and return the result of this request.""" 31 | return await self.future 32 | 33 | class RequestQueue: 34 | """ 35 | A simple asynchronous request queue with configurable concurrency. 36 | """ 37 | def __init__(self, max_concurrency: int = 2, timeout: float = 300.0, queue_size: int = 100): 38 | """ 39 | Initialize the request queue. 40 | 41 | Args: 42 | max_concurrency (int): Maximum number of concurrent requests to process. 43 | timeout (float): Timeout in seconds for request processing. 44 | queue_size (int): Maximum queue size. 45 | """ 46 | self.max_concurrency = max_concurrency 47 | self.timeout = timeout 48 | self.queue_size = queue_size 49 | self.semaphore = asyncio.Semaphore(max_concurrency) 50 | self.queue = asyncio.Queue(maxsize=queue_size) 51 | self.active_requests: Dict[str, RequestItem] = {} 52 | self._worker_task = None 53 | self._running = False 54 | 55 | async def start(self, processor: Callable[[Any], Awaitable[Any]]): 56 | """ 57 | Start the queue worker. 58 | 59 | Args: 60 | processor: Async function that processes queue items. 
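
        Example (minimal sketch, using only the methods defined on this class)::

            queue = RequestQueue(max_concurrency=2, timeout=300.0, queue_size=100)

            async def processor(data):
                # Replace with real model inference.
                return data

            await queue.start(processor)
            result = await queue.submit("request-1", {"prompt": "hello"})
            await queue.stop()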
61 | """ 62 | if self._running: 63 | return 64 | 65 | self._running = True 66 | self._worker_task = asyncio.create_task(self._worker_loop(processor)) 67 | logger.info(f"Started request queue with max concurrency: {self.max_concurrency}") 68 | 69 | async def stop(self): 70 | """Stop the queue worker.""" 71 | if not self._running: 72 | return 73 | 74 | self._running = False 75 | 76 | # Cancel the worker task 77 | if self._worker_task and not self._worker_task.done(): 78 | self._worker_task.cancel() 79 | try: 80 | await self._worker_task 81 | except asyncio.CancelledError: 82 | pass 83 | 84 | # Cancel all pending requests 85 | pending_requests = list(self.active_requests.values()) 86 | for request in pending_requests: 87 | if not request.future.done(): 88 | request.future.cancel() 89 | # Clean up request data 90 | try: 91 | if hasattr(request, 'data'): 92 | del request.data 93 | except: 94 | pass 95 | 96 | self.active_requests.clear() 97 | 98 | # Clear the queue 99 | while not self.queue.empty(): 100 | try: 101 | self.queue.get_nowait() 102 | except asyncio.QueueEmpty: 103 | break 104 | 105 | # Force garbage collection after cleanup 106 | gc.collect() 107 | logger.info("Stopped request queue") 108 | 109 | async def _worker_loop(self, processor: Callable[[Any], Awaitable[Any]]): 110 | """ 111 | Main worker loop that processes queue items. 112 | 113 | Args: 114 | processor: Async function that processes queue items. 115 | """ 116 | while self._running: 117 | try: 118 | # Get the next item from the queue 119 | request = await self.queue.get() 120 | 121 | # Process the request with concurrency control 122 | asyncio.create_task(self._process_request(request, processor)) 123 | 124 | except asyncio.CancelledError: 125 | break 126 | except Exception as e: 127 | logger.error(f"Error in worker loop: {str(e)}") 128 | 129 | async def _process_request(self, request: RequestItem, processor: Callable[[Any], Awaitable[Any]]): 130 | """ 131 | Process a single request with timeout and error handling. 132 | 133 | Args: 134 | request: The request to process. 135 | processor: Async function that processes the request. 
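
        Concurrency is bounded by the shared semaphore, and each call to the processor
        is wrapped in asyncio.wait_for with the queue's configured timeout.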
136 | """ 137 | # Use semaphore to limit concurrency 138 | async with self.semaphore: 139 | try: 140 | # Process with timeout 141 | processing_start = time.time() 142 | result = await asyncio.wait_for( 143 | processor(request.data), 144 | timeout=self.timeout 145 | ) 146 | processing_time = time.time() - processing_start 147 | 148 | # Set the result 149 | request.set_result(result) 150 | logger.info(f"Request {request.request_id} processed in {processing_time:.2f}s") 151 | 152 | except asyncio.TimeoutError: 153 | request.set_exception(TimeoutError(f"Request processing timed out after {self.timeout}s")) 154 | logger.warning(f"Request {request.request_id} timed out after {self.timeout}s") 155 | 156 | except Exception as e: 157 | request.set_exception(e) 158 | logger.error(f"Error processing request {request.request_id}: {str(e)}") 159 | 160 | finally: 161 | # Always remove from active requests, even if an error occurred 162 | removed_request = self.active_requests.pop(request.request_id, None) 163 | if removed_request: 164 | # Clean up the request object 165 | try: 166 | if hasattr(removed_request, 'data'): 167 | del removed_request.data 168 | except: 169 | pass 170 | # Force garbage collection periodically to prevent memory buildup 171 | if len(self.active_requests) % 10 == 0: # Every 10 requests 172 | gc.collect() 173 | 174 | async def enqueue(self, request_id: str, data: Any) -> RequestItem: 175 | """ 176 | Add a request to the queue. 177 | 178 | Args: 179 | request_id: Unique ID for the request. 180 | data: The request data to process. 181 | 182 | Returns: 183 | RequestItem: The queued request item. 184 | 185 | Raises: 186 | asyncio.QueueFull: If the queue is full. 187 | """ 188 | if not self._running: 189 | raise RuntimeError("Queue is not running") 190 | 191 | # Create request item 192 | request = RequestItem(request_id, data) 193 | 194 | # Add to active requests and queue 195 | self.active_requests[request_id] = request 196 | 197 | try: 198 | # This will raise QueueFull if the queue is full 199 | await asyncio.wait_for( 200 | self.queue.put(request), 201 | timeout=1.0 # Short timeout for queue put 202 | ) 203 | queue_time = time.time() - request.created_at 204 | logger.info(f"Request {request_id} queued (wait: {queue_time:.2f}s)") 205 | return request 206 | 207 | except asyncio.TimeoutError: 208 | self.active_requests.pop(request_id, None) 209 | raise asyncio.QueueFull("Request queue is full and timed out waiting for space") 210 | 211 | async def submit(self, request_id: str, data: Any) -> Any: 212 | """ 213 | Submit a request and wait for its result. 214 | 215 | Args: 216 | request_id: Unique ID for the request. 217 | data: The request data to process. 218 | 219 | Returns: 220 | The result of processing the request. 221 | 222 | Raises: 223 | Various exceptions that may occur during processing. 224 | """ 225 | request = await self.enqueue(request_id, data) 226 | return await request.get_result() 227 | 228 | def get_queue_stats(self) -> Dict[str, Any]: 229 | """ 230 | Get queue statistics. 231 | 232 | Returns: 233 | Dict with queue statistics. 
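
            Example shape (values are illustrative)::

                {
                    "running": True,
                    "queue_size": 0,
                    "max_queue_size": 100,
                    "active_requests": 1,
                    "max_concurrency": 2
                }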
234 | """ 235 | return { 236 | "running": self._running, 237 | "queue_size": self.queue.qsize(), 238 | "max_queue_size": self.queue_size, 239 | "active_requests": len(self.active_requests), 240 | "max_concurrency": self.max_concurrency 241 | } 242 | 243 | # Alias for the async stop method to maintain consistency in cleanup interfaces 244 | async def stop_async(self): 245 | """Alias for stop - stops the queue worker asynchronously.""" 246 | await self.stop() -------------------------------------------------------------------------------- /app/core/video_processor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gc 3 | import asyncio 4 | from loguru import logger 5 | from typing import List 6 | from app.core.base_processor import BaseProcessor 7 | 8 | 9 | class VideoProcessor(BaseProcessor): 10 | """Video processor for handling video files with caching, validation, and processing.""" 11 | 12 | def __init__(self, max_workers: int = 4, cache_size: int = 1000): 13 | super().__init__(max_workers, cache_size) 14 | # Supported video formats 15 | self._supported_formats = {'.mp4', '.avi', '.mov'} 16 | 17 | def _get_media_format(self, media_url: str, data: bytes = None) -> str: 18 | """Determine video format from URL or data.""" 19 | if media_url.startswith("data:"): 20 | # Extract format from data URL 21 | mime_type = media_url.split(";")[0].split(":")[1] 22 | if "mp4" in mime_type: 23 | return "mp4" 24 | elif "quicktime" in mime_type or "mov" in mime_type: 25 | return "mov" 26 | elif "x-msvideo" in mime_type or "avi" in mime_type: 27 | return "avi" 28 | else: 29 | # Extract format from file extension 30 | ext = os.path.splitext(media_url.lower())[1] 31 | if ext in self._supported_formats: 32 | return ext[1:] # Remove the dot 33 | 34 | # Default to mp4 if format cannot be determined 35 | return "mp4" 36 | 37 | def _validate_media_data(self, data: bytes) -> bool: 38 | """Basic validation of video data.""" 39 | if len(data) < 100: # Too small to be a valid video file 40 | return False 41 | 42 | # Check for common video file signatures 43 | video_signatures = [ 44 | # MP4/M4V/MOV (ISO Base Media File Format) 45 | (b'\x00\x00\x00\x14ftypisom', 0), # MP4 46 | (b'\x00\x00\x00\x18ftyp', 0), # MP4/MOV 47 | (b'\x00\x00\x00\x1cftyp', 0), # MP4/MOV 48 | (b'\x00\x00\x00\x20ftyp', 0), # MP4/MOV 49 | (b'ftyp', 4), # MP4/MOV (ftyp at offset 4) 50 | 51 | # AVI 52 | (b'RIFF', 0), # AVI (also check for 'AVI ' at offset 8) 53 | 54 | # WebM/MKV (Matroska) 55 | (b'\x1a\x45\xdf\xa3', 0), # Matroska/WebM 56 | 57 | # FLV 58 | (b'FLV\x01', 0), # Flash Video 59 | 60 | # MPEG 61 | (b'\x00\x00\x01\xba', 0), # MPEG PS 62 | (b'\x00\x00\x01\xb3', 0), # MPEG PS 63 | 64 | # QuickTime 65 | (b'moov', 0), # QuickTime 66 | (b'mdat', 0), # QuickTime 67 | ] 68 | 69 | for sig, offset in video_signatures: 70 | if len(data) > offset + len(sig): 71 | if data[offset:offset+len(sig)] == sig: 72 | # Additional validation for AVI 73 | if sig == b'RIFF' and len(data) > 12: 74 | if data[8:12] == b'AVI ': 75 | return True 76 | elif sig == b'RIFF': 77 | continue # Not AVI, might be WAV 78 | else: 79 | return True 80 | 81 | # Check for ftyp box anywhere in first 32 bytes (MP4/MOV) 82 | if b'ftyp' in data[:32]: 83 | return True 84 | 85 | # Allow unknown formats to pass through for flexibility 86 | return True 87 | 88 | def _get_timeout(self) -> int: 89 | """Get timeout for HTTP requests.""" 90 | return 120 # Longer timeout for video files (2 minutes) 91 | 92 | def _get_max_file_size(self) 
-> int: 93 | """Get maximum file size in bytes.""" 94 | return 1024 * 1024 * 1024 # 1 GB limit for videos 95 | 96 | def _process_media_data(self, data: bytes, cached_path: str, **kwargs) -> str: 97 | """Process video data and save to cached path.""" 98 | try: 99 | with open(cached_path, 'wb') as f: 100 | f.write(data) 101 | 102 | logger.info(f"Saved video to {cached_path} ({len(data)} bytes)") 103 | self._cleanup_old_files() 104 | return cached_path 105 | except Exception as e: 106 | logger.error(f"Failed to save video data: {str(e)}") 107 | raise 108 | 109 | def _get_media_type_name(self) -> str: 110 | """Get media type name for logging.""" 111 | return "video" 112 | 113 | async def process_video_url(self, video_url: str) -> str: 114 | """ 115 | Process a single video URL and return path to cached file. 116 | 117 | Supports: 118 | - HTTP/HTTPS URLs (downloads video) 119 | - Local file paths (copies to cache) 120 | - Data URLs (base64 encoded videos) 121 | 122 | Args: 123 | video_url: URL, file path, or data URL of the video 124 | 125 | Returns: 126 | Path to the cached video file in temp directory 127 | """ 128 | return await self._process_single_media(video_url) 129 | 130 | async def process_video_urls(self, video_urls: List[str]) -> List[str]: 131 | """ 132 | Process multiple video URLs and return paths to cached files. 133 | 134 | Args: 135 | video_urls: List of URLs, file paths, or data URLs of videos 136 | 137 | Returns: 138 | List of paths to cached video files 139 | """ 140 | tasks = [self.process_video_url(url) for url in video_urls] 141 | results = await asyncio.gather(*tasks, return_exceptions=True) 142 | # Force garbage collection after batch processing 143 | gc.collect() 144 | return results 145 | -------------------------------------------------------------------------------- /app/handler/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | MLX model handlers for text, multimodal, image generation, and embeddings models. 3 | """ 4 | 5 | from app.handler.mlx_lm import MLXLMHandler 6 | from app.handler.mlx_vlm import MLXVLMHandler 7 | from app.handler.mlx_embeddings import MLXEmbeddingsHandler 8 | 9 | # Optional mflux import - only available if flux extra is installed 10 | try: 11 | from app.handler.mflux import MLXFluxHandler 12 | MFLUX_AVAILABLE = True 13 | except ImportError: 14 | MLXFluxHandler = None 15 | MFLUX_AVAILABLE = False 16 | 17 | __all__ = [ 18 | "MLXLMHandler", 19 | "MLXVLMHandler", 20 | "MLXFluxHandler", 21 | "MLXEmbeddingsHandler", 22 | "MFLUX_AVAILABLE" 23 | ] 24 | -------------------------------------------------------------------------------- /app/handler/mlx_embeddings.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import time 3 | import uuid 4 | from http import HTTPStatus 5 | from typing import Any, Dict, List 6 | 7 | from fastapi import HTTPException 8 | from loguru import logger 9 | 10 | from app.core.queue import RequestQueue 11 | from app.schemas.openai import EmbeddingRequest 12 | from app.utils.errors import create_error_response 13 | from app.models.mlx_embeddings import MLX_Embeddings 14 | 15 | class MLXEmbeddingsHandler: 16 | """ 17 | Handler class for making requests to the underlying MLX embeddings model service. 18 | Provides request queuing, metrics tracking, and robust error handling with memory management. 
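
    Typical usage (sketch; the model path is a placeholder)::

        handler = MLXEmbeddingsHandler("path/to/embeddings-model", max_concurrency=1)
        await handler.initialize({})
        response = await handler.generate_embeddings_response(request)  # request: EmbeddingRequest
        await handler.cleanup()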
19 | """ 20 | 21 | def __init__(self, model_path: str, max_concurrency: int = 1): 22 | """ 23 | Initialize the handler with the specified model path. 24 | 25 | Args: 26 | model_path (str): Path to the embeddings model to load. 27 | max_concurrency (int): Maximum number of concurrent model inference tasks. 28 | """ 29 | self.model_path = model_path 30 | self.model = MLX_Embeddings(model_path) 31 | self.model_created = int(time.time()) # Store creation time when model is loaded 32 | 33 | # Initialize request queue for embedding tasks 34 | self.request_queue = RequestQueue(max_concurrency=max_concurrency) 35 | 36 | logger.info(f"Initialized MLXEmbeddingsHandler with model path: {model_path}") 37 | 38 | async def get_models(self) -> List[Dict[str, Any]]: 39 | """ 40 | Get list of available models with their metadata. 41 | """ 42 | try: 43 | return [{ 44 | "id": self.model_path, 45 | "object": "model", 46 | "created": self.model_created, 47 | "owned_by": "local" 48 | }] 49 | except Exception as e: 50 | logger.error(f"Error getting models: {str(e)}") 51 | return [] 52 | 53 | async def initialize(self, config: Dict[str, Any]): 54 | """ 55 | Initialize the request queue with configuration. 56 | 57 | Args: 58 | config: Dictionary containing queue configuration. 59 | """ 60 | await self.request_queue.start(self._process_request) 61 | 62 | async def generate_embeddings_response(self, request: EmbeddingRequest): 63 | """ 64 | Generate embeddings for a given text input. 65 | 66 | Args: 67 | request: EmbeddingRequest object containing the text input. 68 | 69 | Returns: 70 | List[float]: Embeddings for the input text. 71 | """ 72 | try: 73 | # Create a unique request ID 74 | request_id = f"embeddings-{uuid.uuid4()}" 75 | if isinstance(request.input, str): 76 | request.input = [request.input] 77 | request_data = { 78 | "type": "embeddings", 79 | "input": request.input, 80 | "max_length": getattr(request, 'max_length', 512) 81 | } 82 | 83 | # Submit to the request queue 84 | response = await self.request_queue.submit(request_id, request_data) 85 | 86 | return response 87 | 88 | except Exception as e: 89 | logger.error(f"Error in embeddings generation: {str(e)}") 90 | content = create_error_response(f"Failed to generate embeddings: {str(e)}", "server_error", HTTPStatus.INTERNAL_SERVER_ERROR) 91 | raise HTTPException(status_code=500, detail=content) 92 | 93 | async def _process_request(self, request_data: Dict[str, Any]) -> List[List[float]]: 94 | """ 95 | Process an embeddings request. This is the worker function for the request queue. 96 | 97 | Args: 98 | request_data: Dictionary containing the request data. 99 | 100 | Returns: 101 | List[List[float]]: The embeddings for the input texts. 102 | """ 103 | try: 104 | # Check if the request is for embeddings 105 | if request_data.get("type") == "embeddings": 106 | result = self.model( 107 | texts=request_data["input"], 108 | max_length=request_data.get("max_length", 512) 109 | ) 110 | # Force garbage collection after embeddings 111 | gc.collect() 112 | return result 113 | 114 | raise ValueError(f"Unknown request type: {request_data.get('type')}") 115 | 116 | except Exception as e: 117 | logger.error(f"Error processing embeddings request: {str(e)}") 118 | # Clean up on error 119 | gc.collect() 120 | raise 121 | 122 | async def get_queue_stats(self) -> Dict[str, Any]: 123 | """ 124 | Get statistics from the request queue and performance metrics. 125 | 126 | Returns: 127 | Dict with queue and performance statistics. 
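
            Example shape: {"queue_stats": {...}}, where the nested dict comes from
            RequestQueue.get_queue_stats().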
128 | """ 129 | queue_stats = self.request_queue.get_queue_stats() 130 | 131 | return { 132 | "queue_stats": queue_stats, 133 | } 134 | 135 | async def cleanup(self): 136 | """ 137 | Cleanup resources and stop the request queue before shutdown. 138 | 139 | This method ensures all pending requests are properly cancelled 140 | and resources are released. 141 | """ 142 | try: 143 | logger.info("Cleaning up MLXEmbeddingsHandler resources") 144 | if hasattr(self, 'request_queue'): 145 | await self.request_queue.stop() 146 | if hasattr(self, 'model'): 147 | self.model.cleanup() 148 | logger.info("MLXEmbeddingsHandler cleanup completed successfully") 149 | except Exception as e: 150 | logger.error(f"Error during MLXEmbeddingsHandler cleanup: {str(e)}") 151 | raise 152 | 153 | def __del__(self): 154 | """ 155 | Destructor to ensure cleanup on object deletion. 156 | Note: Async cleanup cannot be reliably performed in __del__. 157 | Please use 'await cleanup()' explicitly. 158 | """ 159 | if hasattr(self, '_cleaned') and self._cleaned: 160 | return 161 | # Set flag to prevent multiple cleanup attempts 162 | self._cleaned = True -------------------------------------------------------------------------------- /app/handler/mlx_whisper.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import json 3 | import os 4 | import tempfile 5 | import time 6 | import uuid 7 | from typing import Any, AsyncGenerator, Dict, List, Optional 8 | from http import HTTPStatus 9 | 10 | from fastapi import HTTPException 11 | from loguru import logger 12 | 13 | from app.core.queue import RequestQueue 14 | from app.models.mlx_whisper import MLX_Whisper, calculate_audio_duration 15 | from app.schemas.openai import ( 16 | TranscriptionRequest, 17 | TranscriptionResponse, 18 | TranscriptionUsageAudio, 19 | TranscriptionResponseFormat, 20 | TranscriptionResponseStream, 21 | TranscriptionResponseStreamChoice, 22 | Delta 23 | ) 24 | from app.utils.errors import create_error_response 25 | 26 | class MLXWhisperHandler: 27 | """ 28 | Handler class for making requests to the underlying MLX Whisper model service. 29 | Provides request queuing, metrics tracking, and robust error handling for audio transcription. 30 | """ 31 | 32 | def __init__(self, model_path: str, max_concurrency: int = 1): 33 | """ 34 | Initialize the handler with the specified model path. 35 | 36 | Args: 37 | model_path (str): Path to the model directory. 38 | max_concurrency (int): Maximum number of concurrent model inference tasks. 39 | """ 40 | self.model_path = model_path 41 | self.model = MLX_Whisper(model_path) 42 | self.model_created = int(time.time()) # Store creation time when model is loaded 43 | 44 | # Initialize request queue for audio transcription tasks 45 | self.request_queue = RequestQueue(max_concurrency=max_concurrency) 46 | 47 | logger.info(f"Initialized MLXWhisperHandler with model path: {model_path}") 48 | 49 | async def get_models(self) -> List[Dict[str, Any]]: 50 | """ 51 | Get list of available models with their metadata. 
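
        Returns:
            List[Dict[str, Any]]: A single entry describing the loaded model, with
            "id", "object", "created", and "owned_by" fields (empty list on error).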
52 | """ 53 | try: 54 | return [{ 55 | "id": self.model_path, 56 | "object": "model", 57 | "created": self.model_created, 58 | "owned_by": "local" 59 | }] 60 | except Exception as e: 61 | logger.error(f"Error getting models: {str(e)}") 62 | return [] 63 | 64 | async def initialize(self, queue_config: Optional[Dict[str, Any]] = None): 65 | """Initialize the handler and start the request queue.""" 66 | if not queue_config: 67 | queue_config = { 68 | "max_concurrency": 1, 69 | "timeout": 600, # Longer timeout for audio processing 70 | "queue_size": 50 71 | } 72 | self.request_queue = RequestQueue( 73 | max_concurrency=queue_config.get("max_concurrency"), 74 | timeout=queue_config.get("timeout"), 75 | queue_size=queue_config.get("queue_size") 76 | ) 77 | await self.request_queue.start(self._process_request) 78 | logger.info("Initialized MLXWhisperHandler and started request queue") 79 | 80 | async def generate_transcription_response(self, request: TranscriptionRequest) -> TranscriptionResponse: 81 | """ 82 | Generate a transcription response for the given request. 83 | """ 84 | request_id = f"transcription-{uuid.uuid4()}" 85 | temp_file_path = None 86 | 87 | try: 88 | request_data = await self._prepare_transcription_request(request) 89 | temp_file_path = request_data.get("audio_path") 90 | response = await self.request_queue.submit(request_id, request_data) 91 | response_data = TranscriptionResponse( 92 | text=response["text"], 93 | usage=TranscriptionUsageAudio( 94 | type="duration", 95 | seconds=int(calculate_audio_duration(temp_file_path)) 96 | ) 97 | ) 98 | if request.response_format == TranscriptionResponseFormat.JSON: 99 | return response_data 100 | else: 101 | # dump to string for text response 102 | return json.dumps(response_data.model_dump()) 103 | finally: 104 | # Clean up temporary file 105 | if temp_file_path and os.path.exists(temp_file_path): 106 | try: 107 | os.unlink(temp_file_path) 108 | logger.debug(f"Cleaned up temporary file: {temp_file_path}") 109 | except Exception as e: 110 | logger.warning(f"Failed to clean up temporary file {temp_file_path}: {str(e)}") 111 | # Force garbage collection 112 | gc.collect() 113 | 114 | async def generate_transcription_stream_from_data( 115 | self, 116 | request_data: Dict[str, Any], 117 | response_format: TranscriptionResponseFormat 118 | ) -> AsyncGenerator[str, None]: 119 | """ 120 | Generate a transcription stream from prepared request data. 121 | Yields SSE-formatted chunks with timing information. 
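        Each yielded item is a Server-Sent Events chunk of the form "data: {json}\n\n";
        the stream ends with a final chunk whose finish_reason is "stop", followed by
        "data: [DONE]\n\n".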
122 | 123 | Args: 124 | request_data: Prepared request data with audio_path already saved 125 | response_format: The response format (json or text) 126 | """ 127 | request_id = f"transcription-{uuid.uuid4()}" 128 | created_time = int(time.time()) 129 | temp_file_path = request_data.get("audio_path") 130 | 131 | try: 132 | # Set stream mode 133 | request_data["stream"] = True 134 | 135 | # Get the generator directly from the model (bypass queue for streaming) 136 | generator = self.model( 137 | audio_path=request_data.pop("audio_path"), 138 | **request_data 139 | ) 140 | 141 | # Stream each chunk 142 | for chunk in generator: 143 | # Create streaming response 144 | stream_response = TranscriptionResponseStream( 145 | id=request_id, 146 | object="transcription.chunk", 147 | created=created_time, 148 | model=self.model_path, 149 | choices=[ 150 | TranscriptionResponseStreamChoice( 151 | delta=Delta( 152 | content=chunk.get("text", "") 153 | ), 154 | finish_reason=None 155 | ) 156 | ] 157 | ) 158 | 159 | # Yield as SSE format 160 | yield f"data: {stream_response.model_dump_json()}\n\n" 161 | 162 | # Send final chunk with finish_reason 163 | final_response = TranscriptionResponseStream( 164 | id=request_id, 165 | object="transcription.chunk", 166 | created=created_time, 167 | model=self.model_path, 168 | choices=[ 169 | TranscriptionResponseStreamChoice( 170 | delta=Delta(content=""), 171 | finish_reason="stop" 172 | ) 173 | ] 174 | ) 175 | yield f"data: {final_response.model_dump_json()}\n\n" 176 | yield "data: [DONE]\n\n" 177 | 178 | except Exception as e: 179 | logger.error(f"Error during transcription streaming: {str(e)}") 180 | raise 181 | finally: 182 | # Clean up temporary file 183 | if temp_file_path and os.path.exists(temp_file_path): 184 | try: 185 | os.unlink(temp_file_path) 186 | logger.debug(f"Cleaned up temporary file: {temp_file_path}") 187 | except Exception as e: 188 | logger.warning(f"Failed to clean up temporary file {temp_file_path}: {str(e)}") 189 | # Clean up 190 | gc.collect() 191 | 192 | 193 | async def _process_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: 194 | """ 195 | Process an audio transcription request. This is the worker function for the request queue. 196 | 197 | Args: 198 | request_data: Dictionary containing the request data. 199 | 200 | Returns: 201 | Dict: The model's response containing transcribed text. 202 | """ 203 | try: 204 | # Extract request parameters 205 | audio_path = request_data.pop("audio_path") 206 | 207 | # Call the model with the audio file 208 | result = self.model( 209 | audio_path=audio_path, 210 | **request_data 211 | ) 212 | 213 | # Force garbage collection after model inference 214 | gc.collect() 215 | 216 | return result 217 | 218 | except Exception as e: 219 | logger.error(f"Error processing audio transcription request: {str(e)}") 220 | # Clean up on error 221 | gc.collect() 222 | raise 223 | 224 | async def _save_uploaded_file(self, file) -> str: 225 | """ 226 | Save the uploaded file to a temporary location. 227 | 228 | Args: 229 | file: The uploaded file object. 230 | 231 | Returns: 232 | str: Path to the temporary file. 
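Example (an illustrative sketch assuming a FastAPI UploadFile-style object; the temporary path shown is made up, the real one is chosen by the operating system):

    >>> temp_path = await handler._save_uploaded_file(upload)
    >>> temp_path
    '/tmp/tmpab12cd3.wav'

Callers are expected to remove the file once transcription finishes, as the transcription methods above do in their finally blocks.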
233 | """ 234 | try: 235 | # Create a temporary file with the same extension as the uploaded file 236 | file_extension = os.path.splitext(file.filename)[1] if file.filename else ".wav" 237 | 238 | print("file_extension", file_extension) 239 | 240 | # Read file content first (this can only be done once with FastAPI uploads) 241 | content = await file.read() 242 | 243 | # Create temporary file 244 | with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file: 245 | # Write the file contents 246 | temp_file.write(content) 247 | temp_path = temp_file.name 248 | 249 | logger.debug(f"Saved uploaded file to temporary location: {temp_path}") 250 | return temp_path 251 | 252 | except Exception as e: 253 | logger.error(f"Error saving uploaded file: {str(e)}") 254 | raise 255 | 256 | async def _prepare_transcription_request( 257 | self, 258 | request: TranscriptionRequest 259 | ) -> Dict[str, Any]: 260 | """ 261 | Prepare a transcription request by parsing model parameters. 262 | 263 | Args: 264 | request: TranscriptionRequest object. 265 | audio_path: Path to the audio file. 266 | 267 | Returns: 268 | Dict containing the request data ready for the model. 269 | """ 270 | try: 271 | 272 | file = request.file 273 | 274 | file_path = await self._save_uploaded_file(file) 275 | request_data = { 276 | "audio_path": file_path, 277 | "verbose": False, 278 | } 279 | 280 | # Add optional parameters if provided 281 | if request.temperature is not None: 282 | request_data["temperature"] = request.temperature 283 | 284 | if request.language is not None: 285 | request_data["language"] = request.language 286 | 287 | if request.prompt is not None: 288 | request_data["initial_prompt"] = request.prompt 289 | 290 | # Map additional parameters if they exist 291 | decode_options = {} 292 | if request.language is not None: 293 | decode_options["language"] = request.language 294 | 295 | # Add decode options to request data 296 | request_data.update(decode_options) 297 | 298 | logger.debug(f"Prepared transcription request: {request_data}") 299 | 300 | return request_data 301 | 302 | except Exception as e: 303 | logger.error(f"Failed to prepare transcription request: {str(e)}") 304 | content = create_error_response( 305 | f"Failed to process request: {str(e)}", 306 | "bad_request", 307 | HTTPStatus.BAD_REQUEST 308 | ) 309 | raise HTTPException(status_code=400, detail=content) 310 | 311 | async def get_queue_stats(self) -> Dict[str, Any]: 312 | """ 313 | Get statistics from the request queue and performance metrics. 314 | 315 | Returns: 316 | Dict with queue and performance statistics. 317 | """ 318 | queue_stats = self.request_queue.get_queue_stats() 319 | 320 | return { 321 | "queue_stats": queue_stats, 322 | } 323 | 324 | async def cleanup(self): 325 | """ 326 | Cleanup resources and stop the request queue before shutdown. 327 | 328 | This method ensures all pending requests are properly cancelled 329 | and resources are released. 
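Example (illustrative; in this project the FastAPI lifespan shutdown in app/main.py makes the equivalent call on the active handler):

    >>> await handler.cleanup()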
330 | """ 331 | try: 332 | logger.info("Cleaning up MLXWhisperHandler resources") 333 | if hasattr(self, 'request_queue'): 334 | await self.request_queue.stop() 335 | logger.info("MLXWhisperHandler cleanup completed successfully") 336 | except Exception as e: 337 | logger.error(f"Error during MLXWhisperHandler cleanup: {str(e)}") 338 | raise 339 | 340 | -------------------------------------------------------------------------------- /app/handler/parser/__init__.py: -------------------------------------------------------------------------------- 1 | from app.handler.parser.harmony import HarmonyParser 2 | from app.handler.parser.base import BaseToolParser, BaseThinkingParser 3 | from app.handler.parser.qwen3 import Qwen3ToolParser, Qwen3ThinkingParser 4 | from app.handler.parser.glm4_moe import Glm4MoEToolParser, Glm4MoEThinkingParser 5 | 6 | 7 | __all__ = ['BaseToolParser', 'BaseThinkingParser', 'Qwen3ToolParser', 'Qwen3ThinkingParser', 'HarmonyParser', 'Glm4MoEToolParser', 'Glm4MoEThinkingParser'] -------------------------------------------------------------------------------- /app/handler/parser/base.py: -------------------------------------------------------------------------------- 1 | import json 2 | from json_repair import repair_json 3 | from typing import Any, Dict, List, Optional, Tuple 4 | 5 | 6 | class BaseThinkingParser: 7 | def __init__(self, thinking_open: str, thinking_close: str): 8 | self.thinking_open = thinking_open 9 | self.thinking_close = thinking_close 10 | self.is_thinking = False 11 | 12 | def parse(self, content: str) -> Tuple[Optional[str], str]: 13 | if self.thinking_open in content: 14 | start_thinking = content.find(self.thinking_open) 15 | end_thinking = content.find(self.thinking_close) 16 | if end_thinking != -1: 17 | return content[start_thinking + len(self.thinking_open):end_thinking].strip(), content[end_thinking + len(self.thinking_close):].strip() 18 | return None, content 19 | 20 | def parse_stream(self, chunk: Optional[str] = None) -> Tuple[Optional[Any], bool]: 21 | """ 22 | Parse streaming chunks for thinking content. 
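For example, with the Qwen3 tags "<think>" / "</think>" and assuming each tag arrives as its own chunk (an illustrative trace), feeding the chunks "<think>", "step one", "</think>", "Hi" yields, in order: (None, False), ({"reasoning_content": "step one"}, False), (None, True), ("Hi", False).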
23 | 24 | Returns: 25 | Tuple[parsed_content, is_complete]: 26 | - parsed_content: The parsed chunk (could be str, dict, or None) 27 | - is_complete: True if thinking section is complete 28 | """ 29 | if not self.is_thinking: 30 | if chunk == self.thinking_open: 31 | self.is_thinking = True 32 | return None, False 33 | return chunk, False 34 | if chunk == self.thinking_close: 35 | self.is_thinking = False 36 | return None, True 37 | 38 | return { 39 | "reasoning_content": chunk 40 | }, False 41 | 42 | class ParseToolState: 43 | NORMAL = 0 44 | FOUND_PREFIX = 1 45 | 46 | class BaseToolParser: 47 | def __init__(self, tool_open: str, tool_close: str): 48 | self.tool_open = tool_open 49 | self.tool_close = tool_close 50 | self.buffer = "" 51 | self.state = ParseToolState.NORMAL 52 | 53 | def get_tool_open(self): 54 | return self.tool_open 55 | 56 | def get_tool_close(self): 57 | return self.tool_close 58 | 59 | def parse(self, content: str) -> Tuple[Optional[List[Dict[str, Any]]], str]: 60 | tool_calls = [] 61 | remaining_content = "" 62 | start = 0 63 | while True: 64 | start_tool = content.find(self.tool_open, start) 65 | if start_tool == -1: 66 | break 67 | remaining_content += content[:start_tool].strip() 68 | end_tool = content.find(self.tool_close, start_tool + len(self.tool_open)) 69 | if end_tool == -1: 70 | break 71 | tool_content = content[start_tool + len(self.tool_open):end_tool].strip() 72 | 73 | try: 74 | repaired_json = repair_json(tool_content) 75 | json_output = json.loads(repaired_json) 76 | tool_calls.append(json_output) 77 | except json.JSONDecodeError: 78 | print("Error parsing tool call: ", tool_content) 79 | break 80 | content = content[end_tool + len(self.tool_close):].strip() 81 | return tool_calls, remaining_content 82 | 83 | def parse_stream(self, chunk: Optional[str] = None) -> Tuple[Optional[Any], bool]: 84 | """ 85 | Parse streaming chunks for tool calls. 
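For example, with tool_open/tool_close set to "<tool_call>"/"</tool_call>" (the Qwen3 defaults; get_weather is a made-up tool name), the single chunk '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>' parses to ({"name": "get_weather", "arguments": '{"city": "Paris"}'}, True), while ordinary text chunks pass through unchanged as (chunk, False).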
86 |  87 | Returns: 88 | Tuple[parsed_content, is_complete]: 89 | - parsed_content: The parsed chunk (could be str, dict, or None) 90 | - is_complete: True if tool call is complete 91 | """ 92 | if chunk is None: 93 | return None, True 94 | 95 | if self.tool_open in chunk: 96 | self.state = ParseToolState.FOUND_PREFIX 97 | start_tool_index = chunk.find(self.tool_open) 98 | end_tool_index = chunk.find(self.tool_close) 99 | if end_tool_index != -1: 100 | self.buffer = chunk[start_tool_index + len(self.tool_open):end_tool_index] 101 | self.state = ParseToolState.NORMAL 102 | try: 103 | repaired_json = repair_json(self.buffer) 104 | json_output = json.loads(repaired_json) 105 | except json.JSONDecodeError: 106 | print("Error parsing tool call: ", self.buffer) 107 | return None, True 108 | return { 109 | "name": json_output["name"], 110 | "arguments": json.dumps(json_output["arguments"]) 111 | }, True 112 | 113 | self.buffer += chunk[start_tool_index + len(self.tool_open):] 114 | 115 | return chunk[:start_tool_index], False 116 | 117 | if self.state == ParseToolState.FOUND_PREFIX: 118 | end_tool_index = chunk.find(self.tool_close) 119 | if end_tool_index != -1: 120 | self.buffer += chunk[:end_tool_index] 121 | try: 122 | repaired_json = repair_json(self.buffer) 123 | json_output = json.loads(repaired_json) 124 | except json.JSONDecodeError: 125 | print("Error parsing tool call: ", self.buffer) 126 | return None, False 127 | return { 128 | "name": json_output["name"], 129 | "arguments": json.dumps(json_output["arguments"]) 130 | }, True 131 | else: 132 | self.buffer += chunk 133 | return None, False 134 | 135 | return chunk, False 136 | -------------------------------------------------------------------------------- /app/handler/parser/glm4_moe.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | from typing import Any, Dict, List, Optional, Tuple 4 | from app.handler.parser.base import BaseToolParser, BaseThinkingParser 5 | 6 | TOOL_OPEN = "<tool_call>" 7 | TOOL_CLOSE = "</tool_call>" 8 | THINKING_OPEN = "<think>" 9 | THINKING_CLOSE = "</think>" 10 | 11 | class Glm4MoEThinkingParser(BaseThinkingParser): 12 | """Parser for GLM4 model's thinking response format.""" 13 | 14 | def __init__(self): 15 | super().__init__( 16 | thinking_open=THINKING_OPEN, 17 | thinking_close=THINKING_CLOSE 18 | ) 19 | 20 | class Glm4MoEToolParser(BaseToolParser): 21 | """Parser for GLM4 model's tool response format with XML-style arguments.""" 22 | 23 | def __init__(self): 24 | super().__init__( 25 | tool_open=TOOL_OPEN, 26 | tool_close=TOOL_CLOSE 27 | ) 28 | # Regex patterns for parsing GLM4 XML-style tool calls 29 | self.func_call_regex = re.compile(r"<tool_call>.*?</tool_call>", re.DOTALL) 30 | self.func_detail_regex = re.compile( 31 | r"<tool_call>([^\n]*)\n(.*)</tool_call>", re.DOTALL 32 | ) 33 | self.func_arg_regex = re.compile( 34 | r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>", re.DOTALL 35 | ) 36 | # State for streaming parsing 37 | self.stream_buffer = "" 38 | self.current_func_name = None 39 | self.current_args = {} 40 | self.parsing_tool = False 41 | 42 | def _deserialize_value(self, value: str) -> Any: 43 | """Try to deserialize a value from string to appropriate Python type.""" 44 | value = value.strip() 45 | 46 | # Try JSON parsing first 47 | try: 48 | return json.loads(value) 49 | except (json.JSONDecodeError, ValueError): 50 | pass 51 | 52 | # Try literal eval for Python literals 53 | try: 54 | import ast 55 | return ast.literal_eval(value) 56 | except (ValueError, SyntaxError): 57 | pass 58 | 59 | # Return as string if all else fails 60 | return 
value 61 | 62 | def parse(self, content: str) -> Tuple[List[Dict[str, Any]], str]: 63 | """ 64 | Parse complete content for GLM4 tool calls. 65 | 66 | Returns: 67 | Tuple of (list of tool calls, remaining content) 68 | """ 69 | tool_calls = [] 70 | matched_calls = self.func_call_regex.findall(content) 71 | 72 | try: 73 | for match in matched_calls: 74 | # Extract function name and arguments section 75 | detail_match = self.func_detail_regex.search(match) 76 | if not detail_match: 77 | continue 78 | 79 | func_name = detail_match.group(1).strip() 80 | args_section = detail_match.group(2) 81 | 82 | # Extract all key-value pairs 83 | arg_pairs = self.func_arg_regex.findall(args_section) 84 | arguments = {} 85 | for key, value in arg_pairs: 86 | arg_key = key.strip() 87 | arg_value = self._deserialize_value(value) 88 | arguments[arg_key] = arg_value 89 | 90 | # Build tool call object 91 | tool_calls.append({ 92 | "name": func_name, 93 | "arguments": json.dumps(arguments) 94 | }) 95 | except Exception as e: 96 | print(f"Error parsing GLM4 tool call: {e}") 97 | 98 | # Find content before first tool call 99 | first_tool_idx = content.find(self.tool_open) 100 | if first_tool_idx != -1: 101 | remaining_content = content[:first_tool_idx].strip() 102 | else: 103 | remaining_content = content.strip() 104 | 105 | return tool_calls, remaining_content 106 | 107 | def parse_stream(self, chunk: str) -> Tuple[Optional[Any], bool]: 108 | """ 109 | Parse streaming chunks for GLM4 tool calls. 110 | 111 | This handles the XML-style format incrementally. 112 | 113 | Returns: 114 | Tuple[parsed_content, is_complete]: 115 | - parsed_content: The parsed chunk (could be str, dict, or None) 116 | - is_complete: True if tool call is complete 117 | """ 118 | if chunk is None: 119 | return None, False 120 | 121 | self.stream_buffer += chunk 122 | 123 | # Check if we're starting a tool call 124 | if not self.parsing_tool: 125 | if self.tool_open in self.stream_buffer: 126 | tool_start_idx = self.stream_buffer.find(self.tool_open) 127 | # Return any content before the tool call 128 | content_before = self.stream_buffer[:tool_start_idx] 129 | self.stream_buffer = self.stream_buffer[tool_start_idx + len(self.tool_open):] 130 | self.parsing_tool = True 131 | self.current_func_name = None 132 | self.current_args = {} 133 | 134 | if content_before: 135 | return content_before, False 136 | return None, False 137 | else: 138 | # No tool call found yet, return the content (except last few chars as buffer) 139 | if len(self.stream_buffer) > len(self.tool_open): 140 | content_to_return = self.stream_buffer[:-len(self.tool_open)] 141 | self.stream_buffer = self.stream_buffer[-len(self.tool_open):] 142 | if content_to_return: 143 | return content_to_return, False 144 | return None, False 145 | 146 | # We're inside a tool call 147 | if self.tool_close in self.stream_buffer: 148 | tool_end_idx = self.stream_buffer.find(self.tool_close) 149 | tool_content = self.stream_buffer[:tool_end_idx] 150 | self.stream_buffer = self.stream_buffer[tool_end_idx + len(self.tool_close):] 151 | 152 | # Parse the complete tool call 153 | full_tool = f"{self.tool_open}{tool_content}{self.tool_close}" 154 | parsed_tools, _ = self.parse(full_tool) 155 | 156 | self.parsing_tool = False 157 | self.current_func_name = None 158 | self.current_args = {} 159 | 160 | if parsed_tools: 161 | tool = parsed_tools[0] 162 | # Return the complete tool call information 163 | return { 164 | "name": tool["name"], 165 | "arguments": tool["arguments"] 166 | }, True # Tool 
call complete 167 | return None, True 168 | 169 | # Still accumulating the tool call 170 | # Try to extract function name if we haven't yet 171 | if self.current_func_name is None: 172 | if '\n' in self.stream_buffer or len(self.stream_buffer) > 50: 173 | # Extract function name (first line) 174 | newline_idx = self.stream_buffer.find('\n') 175 | if newline_idx != -1: 176 | self.current_func_name = self.stream_buffer[:newline_idx].strip() 177 | self.stream_buffer = self.stream_buffer[newline_idx + 1:] 178 | # Return function name 179 | return { 180 | "name": self.current_func_name, 181 | "arguments": "" 182 | }, False 183 | 184 | # Check if we can parse any complete argument pairs 185 | arg_matches = list(self.func_arg_regex.finditer(self.stream_buffer)) 186 | if arg_matches: 187 | last_match = arg_matches[-1] 188 | # Only process if we have the complete closing tag 189 | if last_match.end() < len(self.stream_buffer): 190 | for match in arg_matches: 191 | arg_key = match.group(1).strip() 192 | arg_value = self._deserialize_value(match.group(2)) 193 | if arg_key not in self.current_args: 194 | self.current_args[arg_key] = arg_value 195 | 196 | # Remove processed content from buffer 197 | self.stream_buffer = self.stream_buffer[last_match.end():] 198 | 199 | # Return incremental arguments 200 | if self.current_args: 201 | return { 202 | "name": None, 203 | "arguments": json.dumps(self.current_args) 204 | }, False 205 | 206 | return None, False -------------------------------------------------------------------------------- /app/handler/parser/harmony.py: -------------------------------------------------------------------------------- 1 | from openai_harmony import ( 2 | load_harmony_encoding, 3 | HarmonyEncodingName, 4 | StreamableParser, 5 | Role 6 | ) 7 | from typing import Tuple, Dict, List, Optional, Any, Union 8 | import logging 9 | from enum import Enum 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | class ChannelType(Enum): 14 | """Enumeration of harmony channel types.""" 15 | ANALYSIS = "analysis" 16 | COMMENTARY = "commentary" 17 | FINAL = "final" 18 | 19 | class ParsingState(Enum): 20 | """Enumeration of parsing states.""" 21 | IDLE = "idle" 22 | PROCESSING_TOKENS = "processing_tokens" 23 | TOOL_PARSING = "tool_parsing" 24 | STREAM_ENDED = "stream_ended" 25 | 26 | # Harmony Parsing Helper Functions 27 | class HarmonyParser: 28 | """ 29 | Enhanced helper class for parsing GPT-OSS model responses using harmony encoding. 30 | 31 | This parser handles streaming and non-streaming responses with proper state management, 32 | error handling, and support for different harmony channels (analysis, commentary, final). 
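Example (an illustrative sketch; raw_completion is assumed to be a harmony-encoded completion string produced by a GPT-OSS model):

    >>> parser = HarmonyParser()
    >>> result = parser.parse(raw_completion)
    >>> sorted(result.keys())
    ['content', 'reasoning_content', 'tool_calls']

For incremental decoding, feed chunks to parse_stream() as they arrive and call reset() before reusing the parser for a new request.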
33 | """ 34 | 35 | def __init__(self): 36 | """Initialize the harmony parser with encoding and state management.""" 37 | try: 38 | self.enc = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) 39 | self.parser = StreamableParser(self.enc, role=Role.ASSISTANT) 40 | except Exception as e: 41 | logger.error(f"Failed to initialize harmony encoding: {e}") 42 | raise 43 | 44 | # Configuration 45 | self.end_tool_chunk = "<|call|>" 46 | 47 | # State management 48 | self._reset_state() 49 | 50 | def _reset_state(self) -> None: 51 | """Reset the parser state to initial values.""" 52 | self.tool_state = False 53 | self.end_stream = False 54 | self.parsing_state = ParsingState.IDLE 55 | self._accumulated_content = { 56 | ChannelType.ANALYSIS.value: [], 57 | ChannelType.COMMENTARY.value: [], 58 | ChannelType.FINAL.value: [] 59 | } 60 | self._current_function_name = None 61 | self._function_arguments = [] 62 | 63 | def parse_stream(self, text: Optional[str] = None) -> Tuple[Optional[Any], bool]: 64 | """ 65 | Parse streaming text input and return parsing state and extracted content. 66 | 67 | Args: 68 | text: The text chunk to parse, or None for empty chunks 69 | 70 | Returns: 71 | Tuple[parsed_content, is_complete]: 72 | - parsed_content: The parsed chunk (could be str, dict, or None) 73 | - is_complete: True if stream has ended 74 | 75 | Raises: 76 | Exception: If encoding or parsing fails 77 | """ 78 | # Handle end of stream marker 79 | if text == self.end_tool_chunk: 80 | logger.debug("End tool chunk detected, marking stream as ended") 81 | self.end_stream = True 82 | self.parsing_state = ParsingState.STREAM_ENDED 83 | return None, True 84 | 85 | # Handle empty or None text 86 | if not text: 87 | return None, self.end_stream 88 | 89 | try: 90 | self.parsing_state = ParsingState.PROCESSING_TOKENS 91 | text_tokens = self.enc.encode(text, allowed_special="all") 92 | 93 | # Initialize local variables for this chunk 94 | contents: List[str] = [] 95 | function_name: Optional[str] = None 96 | function_arguments: List[str] = [] 97 | reasoning_content: List[str] = [] 98 | current_channel: Optional[str] = None 99 | 100 | # Process each token 101 | for text_token in text_tokens: 102 | try: 103 | stream_text = self.parser.process(text_token) 104 | current_channel = stream_text.current_channel 105 | content = stream_text.last_content_delta 106 | 107 | if not content: 108 | continue 109 | 110 | # Handle different channels 111 | if current_channel == ChannelType.ANALYSIS.value: 112 | reasoning_content.append(content) 113 | self._accumulated_content[ChannelType.ANALYSIS.value].append(content) 114 | 115 | elif current_channel == ChannelType.COMMENTARY.value: 116 | self.parsing_state = ParsingState.TOOL_PARSING 117 | 118 | if self.tool_state: 119 | # Already parsing function arguments 120 | function_arguments.append(content) 121 | self._function_arguments.append(content) 122 | else: 123 | # Start of new function call 124 | self.tool_state = True 125 | if hasattr(stream_text, 'current_recipient') and stream_text.current_recipient: 126 | function_name = stream_text.current_recipient.replace("functions.", "") 127 | self._current_function_name = function_name 128 | function_arguments = [content] 129 | self._function_arguments = [content] 130 | 131 | elif current_channel == ChannelType.FINAL.value: 132 | contents.append(content) 133 | self._accumulated_content[ChannelType.FINAL.value].append(content) 134 | 135 | except Exception as token_error: 136 | logger.warning(f"Error processing token {text_token}: 
{token_error}") 137 | continue 138 | 139 | # Return appropriate response based on current channel 140 | return self._build_response(current_channel, { 141 | 'reasoning_content': reasoning_content, 142 | 'function_name': function_name, 143 | 'function_arguments': function_arguments, 144 | 'contents': contents 145 | }) 146 | 147 | except Exception as e: 148 | logger.error(f"Error in parse_stream: {e}") 149 | return None, self.end_stream 150 | 151 | def _build_response(self, current_channel: Optional[str], content_data: Dict[str, Any]) -> Tuple[Optional[Union[Dict[str, Any], str]], bool]: 152 | """ 153 | Build the appropriate response based on the current channel. 154 | 155 | Args: 156 | current_channel: The current harmony channel being processed 157 | content_data: Dictionary containing extracted content from different sources 158 | 159 | Returns: 160 | Tuple[parsed_content, is_complete]: 161 | - parsed_content: The parsed content (str or dict) 162 | - is_complete: Whether the stream has ended 163 | """ 164 | if not current_channel: 165 | return None, self.end_stream 166 | 167 | try: 168 | if current_channel == ChannelType.ANALYSIS.value: 169 | reasoning_content = content_data.get('reasoning_content', []) 170 | if reasoning_content: 171 | return { 172 | "reasoning_content": "".join(reasoning_content) 173 | }, self.end_stream 174 | 175 | elif current_channel == ChannelType.COMMENTARY.value: 176 | function_name = content_data.get('function_name') 177 | function_arguments = content_data.get('function_arguments', []) 178 | 179 | response = {} 180 | if function_name: 181 | response["name"] = function_name 182 | if function_arguments: 183 | response["arguments"] = "".join(function_arguments) 184 | 185 | if response: 186 | return response, self.end_stream 187 | 188 | elif current_channel == ChannelType.FINAL.value: 189 | contents = content_data.get('contents', []) 190 | if contents: 191 | return "".join(contents), self.end_stream 192 | except Exception as e: 193 | logger.error(f"Error building response for channel {current_channel}: {e}") 194 | 195 | return None, self.end_stream 196 | 197 | def reset(self) -> None: 198 | """Reset the parser to initial state for reuse.""" 199 | logger.debug("Resetting harmony parser state") 200 | self._reset_state() 201 | 202 | def get_accumulated_content(self, channel: Optional[str] = None) -> Dict[str, str]: 203 | """ 204 | Get accumulated content for all channels or a specific channel. 205 | 206 | Args: 207 | channel: Optional specific channel to retrieve content for 208 | 209 | Returns: 210 | Dictionary of channel content 211 | """ 212 | if channel and channel in self._accumulated_content: 213 | return {channel: "".join(self._accumulated_content[channel])} 214 | 215 | return { 216 | ch: "".join(content) for ch, content in self._accumulated_content.items() 217 | if content 218 | } 219 | 220 | def parse(self, text: str) -> Dict[str, Any]: 221 | """ 222 | Parse complete text response and extract structured content. 223 | 224 | This method processes the entire text at once (non-streaming) and extracts 225 | reasoning content, tool calls, and final content based on harmony channels. 
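A typical return value looks like the following (values are illustrative and get_weather is a made-up function name):

    {"reasoning_content": "...model analysis...", "tool_calls": [{"name": "get_weather", "arguments": '{"city": "Paris"}'}], "content": "The final answer."}

Channels missing from the response are left as None.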
226 | 227 | Args: 228 | text: The complete text response to parse 229 | 230 | Returns: 231 | Dictionary containing parsed content with keys: 232 | - reasoning_content: Analysis/thinking content 233 | - tool_calls: List of tool call objects 234 | - content: Final response content 235 | 236 | Raises: 237 | Exception: If encoding or parsing fails 238 | """ 239 | # Initialize result structure 240 | result = { 241 | "reasoning_content": None, 242 | "tool_calls": None, 243 | "content": None 244 | } 245 | 246 | if not text: 247 | logger.warning("Empty text provided to parse method") 248 | return result 249 | 250 | try: 251 | # Remove end tool chunk if present 252 | clean_text = text 253 | if self.end_tool_chunk in text: 254 | clean_text = text.split(self.end_tool_chunk)[0] 255 | logger.debug(f"Removed end tool chunk, processing {len(clean_text)} characters") 256 | 257 | # Encode and parse messages 258 | tokens = self.enc.encode(clean_text, allowed_special="all") 259 | parsed_messages = self.enc.parse_messages_from_completion_tokens(tokens, role=Role.ASSISTANT) 260 | 261 | # Process each parsed message 262 | for message in parsed_messages: 263 | try: 264 | if not hasattr(message, 'channel') or not hasattr(message, 'content'): 265 | logger.warning(f"Invalid message structure: {message}") 266 | continue 267 | 268 | if message.channel == ChannelType.ANALYSIS.value: 269 | if message.content and len(message.content) > 0: 270 | result["reasoning_content"] = message.content[0].text 271 | logger.debug("Extracted reasoning content") 272 | 273 | elif message.channel == ChannelType.COMMENTARY.value: 274 | if (hasattr(message, 'recipient') and message.recipient and 275 | message.content and len(message.content) > 0): 276 | 277 | tool_call = { 278 | "name": message.recipient.replace("functions.", ""), 279 | "arguments": message.content[0].text 280 | } 281 | result["tool_calls"] = [tool_call] 282 | logger.debug(f"Extracted tool call: {tool_call['name']}") 283 | 284 | elif message.channel == ChannelType.FINAL.value: 285 | if message.content and len(message.content) > 0: 286 | result["content"] = message.content[0].text 287 | logger.debug("Extracted final content") 288 | 289 | except Exception as msg_error: 290 | logger.warning(f"Error processing message: {msg_error}") 291 | continue 292 | 293 | except Exception as e: 294 | logger.error(f"Error in parse method: {e}") 295 | # Return partial results if available, don't raise 296 | 297 | return result 298 | 299 | def is_stream_ended(self) -> bool: 300 | """Check if the stream has ended.""" 301 | return self.end_stream 302 | 303 | def get_parsing_state(self) -> ParsingState: 304 | """Get the current parsing state.""" 305 | return self.parsing_state 306 | 307 | def is_tool_parsing_active(self) -> bool: 308 | """Check if currently parsing tool calls.""" 309 | return self.tool_state 310 | 311 | def get_current_function_info(self) -> Dict[str, Optional[str]]: 312 | """ 313 | Get information about the currently parsed function. 
314 |  315 | Returns: 316 | Dictionary with function name and accumulated arguments 317 | """ 318 | return { 319 | "name": self._current_function_name, 320 | "arguments": "".join(self._function_arguments) if self._function_arguments else None 321 | } 322 | 323 | def __repr__(self) -> str: 324 | """String representation of the parser state.""" 325 | return (f"HarmonyParser(state={self.parsing_state.value}, " 326 | f"tool_state={self.tool_state}, " 327 | f"stream_ended={self.end_stream})") -------------------------------------------------------------------------------- /app/handler/parser/qwen3.py: -------------------------------------------------------------------------------- 1 | from app.handler.parser.base import BaseToolParser, BaseThinkingParser 2 | 3 | TOOL_OPEN = "<tool_call>" 4 | TOOL_CLOSE = "</tool_call>" 5 | THINKING_OPEN = "<think>" 6 | THINKING_CLOSE = "</think>" 7 | 8 | class Qwen3ToolParser(BaseToolParser): 9 | """Parser for Qwen3 model's tool response format.""" 10 | 11 | def __init__(self): 12 | super().__init__( 13 | tool_open=TOOL_OPEN, 14 | tool_close=TOOL_CLOSE 15 | ) 16 | 17 | class Qwen3ThinkingParser(BaseThinkingParser): 18 | """Parser for Qwen3 model's thinking response format.""" 19 | 20 | def __init__(self): 21 | super().__init__( 22 | thinking_open=THINKING_OPEN, 23 | thinking_close=THINKING_CLOSE 24 | ) -------------------------------------------------------------------------------- /app/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import asyncio 3 | import gc 4 | import time 5 | from contextlib import asynccontextmanager 6 | 7 | import uvicorn 8 | from fastapi import FastAPI, Request 9 | from fastapi.middleware.cors import CORSMiddleware 10 | from fastapi.responses import JSONResponse 11 | from loguru import logger 12 | 13 | import mlx.core as mx 14 | from app.handler.mlx_vlm import MLXVLMHandler 15 | from app.handler.mlx_lm import MLXLMHandler 16 | from app.handler.mlx_embeddings import MLXEmbeddingsHandler 17 | from app.handler.mlx_whisper import MLXWhisperHandler 18 | from app.handler import MLXFluxHandler, MFLUX_AVAILABLE 19 | from app.api.endpoints import router 20 | from app.version import __version__ 21 | 22 | def configure_logging(log_file=None, no_log_file=False, log_level="INFO"): 23 | """Configure loguru logging based on CLI parameters.""" 24 | logger.remove() # Remove default handler 25 | 26 | # Add console handler 27 | logger.add( 28 | lambda msg: print(msg), 29 | level=log_level, 30 | format="{time:YYYY-MM-DD HH:mm:ss} | " 31 | "{level: <8} | " 32 | "{name}:{function}:{line} | " 33 | "✦ {message}", 34 | colorize=True 35 | ) 36 | 37 | # Add file handler if not disabled 38 | if not no_log_file: 39 | file_path = log_file if log_file else "logs/app.log" 40 | logger.add( 41 | file_path, 42 | rotation="500 MB", 43 | retention="10 days", 44 | level=log_level, 45 | format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}" 46 | ) 47 | 48 | # Logging will be configured in setup_server() with CLI arguments 49 | 50 | def parse_args(): 51 | parser = argparse.ArgumentParser(description="MLX OpenAI Compatible Server") 52 | parser.add_argument("--model-path", type=str, help="Path to the model (required for lm, multimodal, image-generation, image-edit, embeddings, whisper model types). 
With `image-generation` or `image-edit` model types, it should be the local path to the model.") 53 | parser.add_argument("--model-type", type=str, default="lm", choices=["lm", "multimodal", "image-generation", "image-edit", "embeddings", "whisper"], help="Model type") 54 | parser.add_argument("--context-length", type=int, default=None, help="Context length for language models. Only works with `lm` or `multimodal` model types.") 55 | parser.add_argument("--port", type=int, default=8000, help="Port to run the server on") 56 | parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on") 57 | parser.add_argument("--max-concurrency", type=int, default=1, help="Maximum number of concurrent requests") 58 | parser.add_argument("--queue-timeout", type=int, default=300, help="Request timeout in seconds") 59 | parser.add_argument("--queue-size", type=int, default=100, help="Maximum queue size for pending requests") 60 | parser.add_argument("--quantize", type=int, default=8, help="Quantization level for the model. Only used for image-generation and image-edit Flux models.") 61 | parser.add_argument("--config-name", type=str, default=None, choices=["flux-schnell", "flux-dev", "flux-krea-dev", "flux-kontext-dev"], help="Config name of the model. Only used for image-generation and image-edit Flux models.") 62 | parser.add_argument("--lora-paths", type=str, default=None, help="Path to the LoRA file(s). Multiple paths should be separated by commas.") 63 | parser.add_argument("--lora-scales", type=str, default=None, help="Scale factor for the LoRA file(s). Multiple scales should be separated by commas.") 64 | parser.add_argument("--disable-auto-resize", action="store_true", help="Disable automatic model resizing. Only work for Vision Language Models.") 65 | parser.add_argument("--log-file", type=str, default=None, help="Path to log file. If not specified, logs will be written to 'logs/app.log' by default.") 66 | parser.add_argument("--no-log-file", action="store_true", help="Disable file logging entirely. Only console output will be shown.") 67 | parser.add_argument("--log-level", type=str, default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="Set the logging level. Default is INFO.") 68 | 69 | args = parser.parse_args() 70 | 71 | return args 72 | 73 | 74 | def get_model_identifier(args): 75 | """Get the appropriate model identifier based on model type.""" 76 | return args.model_path 77 | 78 | def create_lifespan(config_args): 79 | """Factory function to create a lifespan context manager with access to config args.""" 80 | @asynccontextmanager 81 | async def lifespan(app: FastAPI): 82 | try: 83 | model_identifier = get_model_identifier(config_args) 84 | if config_args.model_type == "image-generation": 85 | logger.info(f"Initializing MLX handler with model name: {model_identifier}") 86 | else: 87 | logger.info(f"Initializing MLX handler with model path: {model_identifier}") 88 | 89 | if config_args.model_type == "multimodal": 90 | handler = MLXVLMHandler( 91 | model_path=model_identifier, 92 | context_length=getattr(config_args, 'context_length', None), 93 | max_concurrency=config_args.max_concurrency, 94 | disable_auto_resize=getattr(config_args, 'disable_auto_resize', False) 95 | ) 96 | elif config_args.model_type == "image-generation": 97 | if not MFLUX_AVAILABLE: 98 | raise ValueError("Image generation requires mflux. 
Install with: pip install git+https://github.com/cubist38/mflux.git") 99 | if not config_args.config_name in ["flux-schnell", "flux-dev", "flux-krea-dev"]: 100 | raise ValueError(f"Invalid config name: {config_args.config_name}. Only flux-schnell, flux-dev, and flux-krea-dev are supported for image generation.") 101 | handler = MLXFluxHandler( 102 | model_path=model_identifier, 103 | max_concurrency=config_args.max_concurrency, 104 | quantize=getattr(config_args, 'quantize', 8), 105 | config_name=config_args.config_name, 106 | lora_paths=getattr(config_args, 'lora_paths', None), 107 | lora_scales=getattr(config_args, 'lora_scales', None) 108 | ) 109 | elif config_args.model_type == "embeddings": 110 | handler = MLXEmbeddingsHandler( 111 | model_path=model_identifier, 112 | max_concurrency=config_args.max_concurrency 113 | ) 114 | elif config_args.model_type == "image-edit": 115 | if not MFLUX_AVAILABLE: 116 | raise ValueError("Image editing requires mflux. Install with: pip install git+https://github.com/cubist38/mflux.git") 117 | if config_args.config_name != "flux-kontext-dev": 118 | raise ValueError(f"Invalid config name: {config_args.config_name}. Only flux-kontext-dev is supported for image edit.") 119 | handler = MLXFluxHandler( 120 | model_path=model_identifier, 121 | max_concurrency=config_args.max_concurrency, 122 | quantize=getattr(config_args, 'quantize', 8), 123 | config_name=config_args.config_name, 124 | lora_paths=getattr(config_args, 'lora_paths', None), 125 | lora_scales=getattr(config_args, 'lora_scales', None) 126 | ) 127 | elif config_args.model_type == "whisper": 128 | handler = MLXWhisperHandler( 129 | model_path=model_identifier, 130 | max_concurrency=config_args.max_concurrency 131 | ) 132 | else: 133 | handler = MLXLMHandler( 134 | model_path=model_identifier, 135 | context_length=getattr(config_args, 'context_length', None), 136 | max_concurrency=config_args.max_concurrency 137 | ) 138 | # Initialize queue 139 | await handler.initialize({ 140 | "max_concurrency": config_args.max_concurrency, 141 | "timeout": config_args.queue_timeout, 142 | "queue_size": config_args.queue_size 143 | }) 144 | logger.info("MLX handler initialized successfully") 145 | app.state.handler = handler 146 | 147 | except Exception as e: 148 | logger.error(f"Failed to initialize MLX handler: {str(e)}") 149 | raise 150 | 151 | # Initial memory cleanup 152 | mx.clear_cache() 153 | gc.collect() 154 | 155 | yield 156 | 157 | # Shutdown 158 | logger.info("Shutting down application") 159 | if hasattr(app.state, "handler") and app.state.handler: 160 | try: 161 | # Use the proper cleanup method which handles both request queue and image processor 162 | logger.info("Cleaning up resources") 163 | await app.state.handler.cleanup() 164 | logger.info("Resources cleaned up successfully") 165 | except Exception as e: 166 | logger.error(f"Error during shutdown: {str(e)}") 167 | 168 | # Final memory cleanup 169 | mx.clear_cache() 170 | gc.collect() 171 | 172 | return lifespan 173 | 174 | # App instance will be created during setup with the correct lifespan 175 | app = None 176 | 177 | async def setup_server(args) -> uvicorn.Config: 178 | global app 179 | 180 | # Configure logging based on CLI parameters 181 | configure_logging( 182 | log_file=getattr(args, 'log_file', None), 183 | no_log_file=getattr(args, 'no_log_file', False), 184 | log_level=getattr(args, 'log_level', 'INFO') 185 | ) 186 | 187 | # Create FastAPI app with the configured lifespan 188 | app = FastAPI( 189 | title="OpenAI-compatible API", 190 
| description="API for OpenAI-compatible chat completion and text embedding", 191 | version=__version__, 192 | lifespan=create_lifespan(args) 193 | ) 194 | 195 | app.include_router(router) 196 | 197 | # Add CORS middleware 198 | app.add_middleware( 199 | CORSMiddleware, 200 | allow_origins=["*"], # In production, replace with specific origins 201 | allow_credentials=True, 202 | allow_methods=["*"], 203 | allow_headers=["*"], 204 | ) 205 | 206 | @app.middleware("http") 207 | async def add_process_time_header(request: Request, call_next): 208 | start_time = time.time() 209 | response = await call_next(request) 210 | process_time = time.time() - start_time 211 | response.headers["X-Process-Time"] = str(process_time) 212 | 213 | # Periodic memory cleanup for long-running processes 214 | if hasattr(request.app.state, 'request_count'): 215 | request.app.state.request_count += 1 216 | else: 217 | request.app.state.request_count = 1 218 | 219 | # Clean up memory every 50 requests 220 | if request.app.state.request_count % 50 == 0: 221 | mx.clear_cache() 222 | gc.collect() 223 | logger.debug(f"Performed memory cleanup after {request.app.state.request_count} requests") 224 | 225 | return response 226 | 227 | @app.exception_handler(Exception) 228 | async def global_exception_handler(request: Request, exc: Exception): 229 | logger.error(f"Global exception handler caught: {str(exc)}", exc_info=True) 230 | return JSONResponse( 231 | status_code=500, 232 | content={"error": {"message": "Internal server error", "type": "internal_error"}} 233 | ) 234 | 235 | logger.info(f"Starting server on {args.host}:{args.port}") 236 | config = uvicorn.Config( 237 | app=app, 238 | host=args.host, 239 | port=args.port, 240 | log_level="info", 241 | access_log=True 242 | ) 243 | return config 244 | 245 | if __name__ == "__main__": 246 | args = parse_args() 247 | config = asyncio.run(setup_server(args)) 248 | uvicorn.Server(config).run() -------------------------------------------------------------------------------- /app/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cubist38/mlx-openai-server/25fc00aa21fc6794358849ebfb9d866a6e203eda/app/models/__init__.py -------------------------------------------------------------------------------- /app/models/mflux.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from PIL import Image 4 | from abc import ABC, abstractmethod 5 | from mflux.flux.flux import Flux1, Config 6 | from mflux.config.model_config import ModelConfig 7 | from mflux.kontext.flux_kontext import Flux1Kontext 8 | from typing import Dict, Type, Any, Optional, Union, List 9 | 10 | 11 | # Custom Exceptions 12 | class FluxModelError(Exception): 13 | """Base exception for Flux model errors.""" 14 | pass 15 | 16 | 17 | class ModelLoadError(FluxModelError): 18 | """Raised when model loading fails.""" 19 | pass 20 | 21 | 22 | class ModelGenerationError(FluxModelError): 23 | """Raised when image generation fails.""" 24 | pass 25 | 26 | 27 | class InvalidConfigurationError(FluxModelError): 28 | """Raised when configuration is invalid.""" 29 | pass 30 | 31 | 32 | class ModelConfiguration: 33 | """Configuration class for Flux models.""" 34 | 35 | def __init__(self, 36 | model_type: str, 37 | model_config: Optional[ModelConfig] = None, 38 | quantize: int = 8, 39 | default_steps: int = 20, 40 | default_guidance: float = 2.5, 41 | lora_paths: Optional[List[str]] = None, 42 
| lora_scales: Optional[List[float]] = None): 43 | 44 | # Validate quantization level 45 | if quantize not in [4, 8, 16]: 46 | raise InvalidConfigurationError(f"Invalid quantization level: {quantize}. Must be 4, 8, or 16.") 47 | 48 | # Validate LoRA parameters: both must be provided together and have matching lengths 49 | if (lora_paths is None) != (lora_scales is None): 50 | raise InvalidConfigurationError( 51 | "Both lora_paths and lora_scales must be provided together." 52 | ) 53 | if lora_paths and lora_scales and len(lora_paths) != len(lora_scales): 54 | raise InvalidConfigurationError( 55 | f"lora_paths and lora_scales must have the same length (got {len(lora_paths)} and {len(lora_scales)})" 56 | ) 57 | 58 | self.model_type = model_type 59 | self.model_config = model_config 60 | self.quantize = quantize 61 | self.default_steps = default_steps 62 | self.default_guidance = default_guidance 63 | self.lora_paths = lora_paths 64 | self.lora_scales = lora_scales 65 | 66 | @classmethod 67 | def schnell(cls, quantize: int = 8, lora_paths: Optional[List[str]] = None, lora_scales: Optional[List[float]] = None) -> 'ModelConfiguration': 68 | """Create configuration for Flux Schnell model.""" 69 | return cls( 70 | model_type="schnell", 71 | model_config=ModelConfig.schnell(), 72 | quantize=quantize, 73 | default_steps=4, 74 | default_guidance=0.0, 75 | lora_paths=lora_paths, 76 | lora_scales=lora_scales 77 | ) 78 | 79 | @classmethod 80 | def dev(cls, quantize: int = 8, lora_paths: Optional[List[str]] = None, lora_scales: Optional[List[float]] = None) -> 'ModelConfiguration': 81 | """Create configuration for Flux Dev model.""" 82 | return cls( 83 | model_type="dev", 84 | model_config=ModelConfig.dev(), 85 | quantize=quantize, 86 | default_steps=25, 87 | default_guidance=3.5, 88 | lora_paths=lora_paths, 89 | lora_scales=lora_scales 90 | ) 91 | 92 | @classmethod 93 | def krea_dev(cls, quantize: int = 8, lora_paths: Optional[List[str]] = None, lora_scales: Optional[List[float]] = None) -> 'ModelConfiguration': 94 | """Create configuration for Flux Krea Dev model.""" 95 | return cls( 96 | model_type="krea-dev", 97 | model_config=ModelConfig.dev(), 98 | quantize=quantize, 99 | default_steps=28, 100 | default_guidance=4.5, 101 | lora_paths=lora_paths, 102 | lora_scales=lora_scales 103 | ) 104 | 105 | @classmethod 106 | def kontext(cls, quantize: int = 8) -> 'ModelConfiguration': 107 | """Create configuration for Flux Kontext model.""" 108 | return cls( 109 | model_type="kontext", 110 | model_config=None, # Kontext doesn't use ModelConfig 111 | quantize=quantize, 112 | default_steps=28, 113 | default_guidance=2.5, 114 | lora_paths=None, # Kontext doesn't support LoRA 115 | lora_scales=None 116 | ) 117 | 118 | 119 | class BaseFluxModel(ABC): 120 | """Abstract base class for Flux models with common functionality.""" 121 | 122 | def __init__(self, model_path: str, config: ModelConfiguration): 123 | self.model_path = model_path 124 | self.config = config 125 | self.logger = logging.getLogger(self.__class__.__name__) 126 | self._model = None 127 | self._is_loaded = False 128 | 129 | # Validate model path 130 | if not self._validate_model_path(): 131 | raise ModelLoadError(f"Invalid model path: {model_path}") 132 | 133 | self._load_model() 134 | 135 | def _validate_model_path(self) -> bool: 136 | """Validate that the model path exists or is a valid model name.""" 137 | # Check if it's a file path 138 | if os.path.exists(self.model_path): 139 | return True 140 | 141 | # Check if it's a valid model name (for 
downloading) 142 | valid_model_names = ["flux-dev", "flux-schnell", "flux-kontext-dev"] 143 | return self.model_path in valid_model_names 144 | 145 | @abstractmethod 146 | def _load_model(self): 147 | """Load the specific model implementation.""" 148 | pass 149 | 150 | @abstractmethod 151 | def _generate_image(self, prompt: str, seed: int, config: Config) -> Image.Image: 152 | """Generate image using the specific model implementation.""" 153 | pass 154 | 155 | def __call__(self, prompt: str, seed: int = 42, **kwargs) -> Image.Image: 156 | """Generate an image from a text prompt.""" 157 | if not self._is_loaded: 158 | raise ModelLoadError("Model is not loaded. Cannot generate image.") 159 | 160 | # Validate inputs 161 | if not prompt or not prompt.strip(): 162 | raise ModelGenerationError("Prompt cannot be empty.") 163 | 164 | if not isinstance(seed, int) or seed < 0: 165 | raise ModelGenerationError("Seed must be a non-negative integer.") 166 | 167 | # Merge default config values with provided kwargs 168 | try: 169 | generation_config = self._prepare_config(**kwargs) 170 | except Exception as e: 171 | raise ModelGenerationError(f"Failed to prepare configuration: {e}") 172 | 173 | self.logger.info(f"Generating image with prompt: '{prompt[:50]}...' " 174 | f"(steps: {generation_config.num_inference_steps}, seed: {seed})") 175 | 176 | try: 177 | result = self._generate_image(prompt, seed, generation_config) 178 | if result is None: 179 | raise ModelGenerationError("Model returned None instead of an image.") 180 | 181 | self.logger.info("Image generated successfully") 182 | return result 183 | except Exception as e: 184 | error_msg = f"Error generating image: {e}" 185 | self.logger.error(error_msg) 186 | raise ModelGenerationError(error_msg) from e 187 | 188 | def _prepare_config(self, **kwargs) -> Config: 189 | """Prepare configuration for image generation.""" 190 | # Validate dimensions 191 | width = kwargs.get('width', 1024) 192 | height = kwargs.get('height', 1024) 193 | 194 | if not isinstance(width, int) or width <= 0: 195 | raise ModelGenerationError("Width must be a positive integer.") 196 | if not isinstance(height, int) or height <= 0: 197 | raise ModelGenerationError("Height must be a positive integer.") 198 | 199 | # Validate steps 200 | steps = kwargs.get('num_inference_steps', self.config.default_steps) 201 | if not isinstance(steps, int) or steps <= 0: 202 | raise ModelGenerationError("Number of inference steps must be a positive integer.") 203 | 204 | # Validate guidance 205 | guidance = kwargs.get('guidance', self.config.default_guidance) 206 | if not isinstance(guidance, (int, float)) or guidance < 0: 207 | raise ModelGenerationError("Guidance must be a non-negative number.") 208 | 209 | config_params = { 210 | 'num_inference_steps': steps, 211 | 'guidance': guidance, 212 | 'width': width, 213 | 'height': height 214 | } 215 | 216 | # Add image_path if provided (for inpainting/editing) 217 | if 'image_path' in kwargs: 218 | image_path = kwargs['image_path'] 219 | if not os.path.exists(image_path): 220 | raise ModelGenerationError(f"Image path does not exist: {image_path}") 221 | config_params['image_path'] = image_path 222 | 223 | return Config(**config_params) 224 | 225 | 226 | class FluxStandardModel(BaseFluxModel): 227 | """Standard Flux model implementation for Dev and Schnell variants.""" 228 | 229 | def _load_model(self): 230 | """Load the standard Flux model.""" 231 | try: 232 | self.logger.info(f"Loading {self.config.model_type} model from {self.model_path}") 233 | 234 
| # Prepare lora parameters 235 | lora_paths = self.config.lora_paths 236 | lora_scales = self.config.lora_scales 237 | 238 | # Log LoRA information if provided 239 | if lora_paths: 240 | self.logger.info(f"Using LoRA adapters: {lora_paths}") 241 | if lora_scales: 242 | self.logger.info(f"LoRA scales: {lora_scales}") 243 | 244 | self._model = Flux1( 245 | model_config=self.config.model_config, 246 | local_path=self.model_path, 247 | quantize=self.config.quantize, 248 | lora_paths=lora_paths, 249 | lora_scales=lora_scales, 250 | ) 251 | self._is_loaded = True 252 | self.logger.info(f"{self.config.model_type} model loaded successfully") 253 | except Exception as e: 254 | error_msg = f"Failed to load {self.config.model_type} model: {e}" 255 | self.logger.error(error_msg) 256 | raise ModelLoadError(error_msg) from e 257 | 258 | def _generate_image(self, prompt: str, seed: int, config: Config) -> Image.Image: 259 | """Generate image using standard Flux model.""" 260 | try: 261 | result = self._model.generate_image( 262 | config=config, 263 | prompt=prompt, 264 | seed=seed, 265 | ) 266 | return result.image 267 | except Exception as e: 268 | raise ModelGenerationError(f"Standard model generation failed: {e}") from e 269 | 270 | 271 | class FluxKontextModel(BaseFluxModel): 272 | """Flux Kontext model implementation.""" 273 | 274 | def _load_model(self): 275 | """Load the Flux Kontext model.""" 276 | try: 277 | self.logger.info(f"Loading Kontext model from {self.model_path}") 278 | self._model = Flux1Kontext( 279 | quantize=self.config.quantize, 280 | local_path=self.model_path 281 | ) 282 | self._is_loaded = True 283 | self.logger.info("Kontext model loaded successfully") 284 | except Exception as e: 285 | error_msg = f"Failed to load Kontext model: {e}" 286 | self.logger.error(error_msg) 287 | raise ModelLoadError(error_msg) from e 288 | 289 | def _generate_image(self, prompt: str, seed: int, config: Config) -> Image.Image: 290 | """Generate image using Flux Kontext model.""" 291 | try: 292 | result = self._model.generate_image( 293 | config=config, 294 | prompt=prompt, 295 | seed=seed, 296 | ) 297 | return result.image 298 | except Exception as e: 299 | raise ModelGenerationError(f"Kontext model generation failed: {e}") from e 300 | 301 | 302 | class FluxModel: 303 | """Factory class for creating and managing Flux models.""" 304 | 305 | _MODEL_CONFIGS = { 306 | "flux-schnell": ModelConfiguration.schnell, 307 | "flux-dev": ModelConfiguration.dev, 308 | "flux-krea-dev": ModelConfiguration.krea_dev, 309 | "flux-kontext-dev": ModelConfiguration.kontext, 310 | } 311 | 312 | _MODEL_CLASSES = { 313 | "flux-schnell": FluxStandardModel, 314 | "flux-dev": FluxStandardModel, 315 | "flux-krea-dev": FluxStandardModel, 316 | "flux-kontext-dev": FluxKontextModel, 317 | } 318 | 319 | def __init__(self, model_path: str, config_name: str, quantize: int = 8, 320 | lora_paths: Optional[List[str]] = None, lora_scales: Optional[List[float]] = None): 321 | 322 | self.config_name = config_name 323 | self.model_path = model_path 324 | self.quantize = quantize 325 | self.lora_paths = lora_paths 326 | self.lora_scales = lora_scales 327 | self.logger = logging.getLogger(self.__class__.__name__) 328 | 329 | # Validate configuration 330 | if config_name not in self._MODEL_CONFIGS: 331 | available_configs = ", ".join(self._MODEL_CONFIGS.keys()) 332 | raise InvalidConfigurationError(f"Invalid config name: {config_name}. 
Available options: {available_configs}") 333 | 334 | # Validate LoRA parameters for kontext model 335 | if config_name == "flux-kontext-dev" and (lora_paths is not None or lora_scales is not None): 336 | raise InvalidConfigurationError("Flux Kontext model does not support LoRA adapters") 337 | 338 | try: 339 | # Create model configuration 340 | config_factory = self._MODEL_CONFIGS[config_name] 341 | if config_name == "flux-kontext-dev": 342 | self.config = config_factory(quantize=quantize) 343 | else: 344 | self.config = config_factory(quantize=quantize, lora_paths=lora_paths, lora_scales=lora_scales) 345 | 346 | # Create model instance 347 | model_class = self._MODEL_CLASSES[config_name] 348 | self.flux = model_class(model_path, self.config) 349 | 350 | self.logger.info(f"FluxModel initialized successfully with config: {config_name}") 351 | if lora_paths: 352 | self.logger.info(f"LoRA adapters: {lora_paths}") 353 | 354 | except Exception as e: 355 | error_msg = f"Failed to initialize FluxModel: {e}" 356 | self.logger.error(error_msg) 357 | raise ModelLoadError(error_msg) from e 358 | 359 | def __call__(self, prompt: str, seed: int = 42, **kwargs) -> Image.Image: 360 | """Generate an image using the configured model.""" 361 | return self.flux(prompt, seed, **kwargs) 362 | 363 | @classmethod 364 | def get_available_configs(cls) -> list[str]: 365 | """Get list of available model configurations.""" 366 | return list(cls._MODEL_CONFIGS.keys()) 367 | 368 | @classmethod 369 | def get_model_info(cls, config_name: str) -> Dict[str, Any]: 370 | """Get information about a specific model configuration.""" 371 | if config_name not in cls._MODEL_CONFIGS: 372 | raise InvalidConfigurationError(f"Unknown config: {config_name}") 373 | 374 | config = cls._MODEL_CONFIGS[config_name]() 375 | return { 376 | "name": config_name, 377 | "type": config.model_type, 378 | "default_steps": config.default_steps, 379 | "default_guidance": config.default_guidance, 380 | "model_class": cls._MODEL_CLASSES[config_name].__name__ 381 | } 382 | 383 | def get_current_config(self) -> Dict[str, Any]: 384 | """Get current model configuration information.""" 385 | return { 386 | "config_name": self.config_name, 387 | "model_path": self.model_path, 388 | "quantize": self.quantize, 389 | "type": self.config.model_type, 390 | "default_steps": self.config.default_steps, 391 | "default_guidance": self.config.default_guidance, 392 | "is_loaded": self.flux._is_loaded if hasattr(self.flux, '_is_loaded') else False, 393 | "lora_paths": self.config.lora_paths, 394 | "lora_scales": self.config.lora_scales, 395 | } 396 | 397 | def is_loaded(self) -> bool: 398 | """Check if the model is loaded and ready for inference.""" 399 | return hasattr(self.flux, '_is_loaded') and self.flux._is_loaded -------------------------------------------------------------------------------- /app/models/mlx_embeddings.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import mlx.core as mx 3 | from mlx_embeddings.utils import load 4 | from typing import List, Optional 5 | 6 | class MLX_Embeddings: 7 | """ 8 | A wrapper class for MLX Embeddings that handles memory management to prevent leaks. 9 | 10 | This class provides a unified interface for generating embeddings from text inputs, 11 | with proper cleanup of MLX arrays and memory management. 12 | """ 13 | 14 | def __init__(self, model_path: str): 15 | """ 16 | Initialize the MLX_Embeddings model. 17 | 18 | Args: 19 | model_name (str): Name of the model to load. 
20 | 21 | Raises: 22 | ValueError: If model loading fails. 23 | """ 24 | try: 25 | self.model, self.tokenizer = load(model_path) 26 | except Exception as e: 27 | raise ValueError(f"Error loading model: {str(e)}") 28 | 29 | def _get_embeddings(self, texts: List[str], max_length: int = 512) -> mx.array: 30 | """ 31 | Get embeddings for a list of texts with proper memory management. 32 | 33 | Args: 34 | texts: List of text inputs 35 | max_length: Maximum sequence length for tokenization 36 | 37 | Returns: 38 | MLX array of embeddings 39 | """ 40 | inputs = None 41 | outputs = None 42 | try: 43 | # Tokenize inputs 44 | inputs = self.tokenizer.batch_encode_plus( 45 | texts, 46 | return_tensors="mlx", 47 | padding=True, 48 | truncation=True, 49 | max_length=max_length 50 | ) 51 | 52 | # Generate embeddings 53 | outputs = self.model( 54 | inputs["input_ids"], 55 | attention_mask=inputs["attention_mask"] 56 | ).text_embeds 57 | 58 | # Return a copy to ensure the result persists after cleanup 59 | return mx.array(outputs) 60 | 61 | except Exception as e: 62 | # Clean up on error 63 | self._cleanup_arrays(inputs, outputs) 64 | raise 65 | finally: 66 | # Always clean up intermediate arrays 67 | self._cleanup_arrays(inputs, outputs) 68 | 69 | def _cleanup_arrays(self, *arrays): 70 | """Clean up MLX arrays to free memory.""" 71 | for array in arrays: 72 | if array is not None: 73 | try: 74 | if isinstance(array, dict): 75 | for key, value in array.items(): 76 | if hasattr(value, 'nbytes'): 77 | del value 78 | elif hasattr(array, 'nbytes'): 79 | del array 80 | except: 81 | pass 82 | 83 | # Clear MLX cache and force garbage collection 84 | mx.clear_cache() 85 | gc.collect() 86 | 87 | def __call__(self, texts: List[str], max_length: int = 512) -> List[List[float]]: 88 | """ 89 | Generate embeddings for a list of texts. 
90 | 91 | Args: 92 | texts: List of text inputs 93 | max_length: Maximum sequence length for tokenization 94 | 95 | Returns: 96 | List of embedding vectors as float lists 97 | """ 98 | try: 99 | embeddings = self._get_embeddings(texts, max_length) 100 | # Convert to Python list and return 101 | result = embeddings.tolist() 102 | # Clean up the embeddings array 103 | del embeddings 104 | mx.clear_cache() 105 | gc.collect() 106 | return result 107 | except Exception as e: 108 | # Clean up on error 109 | mx.clear_cache() 110 | gc.collect() 111 | raise 112 | 113 | def cleanup(self): 114 | """Explicitly cleanup resources.""" 115 | try: 116 | # Clear any cached model outputs 117 | if hasattr(self, 'model'): 118 | del self.model 119 | if hasattr(self, 'tokenizer'): 120 | del self.tokenizer 121 | 122 | # Clear MLX cache and force garbage collection 123 | mx.clear_cache() 124 | gc.collect() 125 | except Exception as e: 126 | # Log cleanup errors but don't raise 127 | pass 128 | 129 | def __del__(self): 130 | """Destructor to ensure cleanup on object deletion.""" 131 | self.cleanup() 132 | 133 | if __name__ == "__main__": 134 | model_path = "mlx-community/all-MiniLM-L6-v2-4bit" 135 | model = MLX_Embeddings(model_path) 136 | try: 137 | texts = ["I like reading", "I like writing"] 138 | embeddings = model(texts) 139 | print(f"Generated embeddings shape: {len(embeddings)} x {len(embeddings[0])}") 140 | finally: 141 | # Explicit cleanup 142 | model.cleanup() -------------------------------------------------------------------------------- /app/models/mlx_lm.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | import mlx.core as mx 4 | from mlx_lm.utils import load 5 | from mlx_lm.generate import ( 6 | generate, 7 | stream_generate, 8 | ) 9 | from outlines.processors import JSONLogitsProcessor 10 | from mlx_lm.models.cache import make_prompt_cache 11 | from mlx_lm.sample_utils import make_sampler, make_logits_processors 12 | from app.utils.outlines_transformer_tokenizer import OutlinesTransformerTokenizer 13 | from typing import List, Dict, Union, Generator 14 | 15 | DEFAULT_TEMPERATURE = os.getenv("DEFAULT_TEMPERATURE", 0.7) 16 | DEFAULT_TOP_P = os.getenv("DEFAULT_TOP_P", 0.95) 17 | DEFAULT_TOP_K = os.getenv("DEFAULT_TOP_K", 20) 18 | DEFAULT_MIN_P = os.getenv("DEFAULT_MIN_P", 0.0) 19 | DEFAULT_SEED = os.getenv("DEFAULT_SEED", 0) 20 | DEFAULT_MAX_TOKENS = os.getenv("DEFAULT_MAX_TOKENS", 8192) 21 | DEFAULT_BATCH_SIZE = os.getenv("DEFAULT_BATCH_SIZE", 32) 22 | 23 | class MLX_LM: 24 | """ 25 | A wrapper class for MLX Language Model that handles both streaming and non-streaming inference. 26 | 27 | This class provides a unified interface for generating text responses from text prompts, 28 | supporting both streaming and non-streaming modes. 
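
    Example (minimal sketch; the model path below is a placeholder and
    `chunk.text` assumes an mlx-lm release whose `stream_generate` yields
    objects with a `text` field):

        from app.models.mlx_lm import MLX_LM

        lm = MLX_LM("mlx-community/Qwen2.5-0.5B-Instruct-4bit")
        messages = [{"role": "user", "content": "Give me one fun fact."}]

        # Non-streaming: returns the full completion as a string
        print(lm(messages, max_tokens=128))

        # Streaming: returns a generator of incremental chunks
        for chunk in lm(messages, stream=True, max_tokens=128):
            print(chunk.text, end="", flush=True)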
29 | """ 30 | 31 | def __init__(self, model_path: str, context_length: int = 32768): 32 | try: 33 | self.model, self.tokenizer = load(model_path) 34 | self.pad_token_id = self.tokenizer.pad_token_id 35 | self.bos_token = self.tokenizer.bos_token 36 | self.model_type = self.model.model_type 37 | self.max_kv_size = context_length 38 | self.outlines_tokenizer = OutlinesTransformerTokenizer(self.tokenizer) 39 | except Exception as e: 40 | raise ValueError(f"Error loading model: {str(e)}") 41 | 42 | def _apply_pooling_strategy(self, embeddings: mx.array) -> mx.array: 43 | embeddings = mx.mean(embeddings, axis=1) 44 | return embeddings 45 | 46 | def _apply_l2_normalization(self, embeddings: mx.array) -> mx.array: 47 | l2_norms = mx.linalg.norm(embeddings, axis=1, keepdims=True) 48 | embeddings = embeddings / (l2_norms + 1e-8) 49 | return embeddings 50 | 51 | def _batch_process(self, prompts: List[str], batch_size: int = DEFAULT_BATCH_SIZE) -> List[List[int]]: 52 | """Process prompts in batches with optimized tokenization.""" 53 | all_tokenized = [] 54 | 55 | # Process prompts in batches 56 | for i in range(0, len(prompts), batch_size): 57 | batch = prompts[i:i + batch_size] 58 | tokenized_batch = [] 59 | 60 | # Tokenize all prompts in batch 61 | for p in batch: 62 | add_special_tokens = self.bos_token is None or not p.startswith(self.bos_token) 63 | tokens = self.tokenizer.encode(p, add_special_tokens=add_special_tokens) 64 | tokenized_batch.append(tokens) 65 | 66 | # Find max length in batch 67 | max_length = max(len(tokens) for tokens in tokenized_batch) 68 | 69 | # Pad tokens in a vectorized way 70 | for tokens in tokenized_batch: 71 | padding = [self.pad_token_id] * (max_length - len(tokens)) 72 | all_tokenized.append(tokens + padding) 73 | 74 | return all_tokenized 75 | 76 | def _preprocess_prompt(self, prompt: str) -> List[int]: 77 | """Tokenize a single prompt efficiently.""" 78 | add_special_tokens = self.bos_token is None or not prompt.startswith(self.bos_token) 79 | tokens = self.tokenizer.encode(prompt, add_special_tokens=add_special_tokens) 80 | return mx.array(tokens) 81 | 82 | def get_model_type(self) -> str: 83 | return self.model_type 84 | 85 | def get_embeddings( 86 | self, 87 | prompts: List[str], 88 | batch_size: int = DEFAULT_BATCH_SIZE, 89 | normalize: bool = True 90 | ) -> List[float]: 91 | """ 92 | Get embeddings for a list of prompts efficiently. 
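        Embeddings are mean-pooled over the sequence dimension and, when
        `normalize` is True, L2-normalized before being returned.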
93 | 94 | Args: 95 | prompts: List of text prompts 96 | batch_size: Size of batches for processing 97 | 98 | Returns: 99 | List of embeddings as float arrays 100 | """ 101 | # Process in batches to optimize memory usage 102 | all_embeddings = [] 103 | try: 104 | for i in range(0, len(prompts), batch_size): 105 | batch_prompts = prompts[i:i + batch_size] 106 | tokenized_batch = self._batch_process(batch_prompts, batch_size) 107 | 108 | # Convert to MLX array for efficient computation 109 | tokenized_batch = mx.array(tokenized_batch) 110 | 111 | try: 112 | # Compute embeddings for batch 113 | batch_embeddings = self.model.model(tokenized_batch) 114 | pooled_embedding = self._apply_pooling_strategy(batch_embeddings) 115 | if normalize: 116 | pooled_embedding = self._apply_l2_normalization(pooled_embedding) 117 | all_embeddings.extend(pooled_embedding.tolist()) 118 | finally: 119 | # Explicitly free MLX arrays to prevent memory leaks 120 | del tokenized_batch 121 | if 'batch_embeddings' in locals(): 122 | del batch_embeddings 123 | if 'pooled_embedding' in locals(): 124 | del pooled_embedding 125 | # Force MLX garbage collection 126 | mx.clear_cache() 127 | gc.collect() 128 | except Exception as e: 129 | # Clean up on error 130 | mx.clear_cache() 131 | gc.collect() 132 | raise 133 | 134 | return all_embeddings 135 | 136 | def __call__( 137 | self, 138 | messages: List[Dict[str, str]], 139 | stream: bool = False, 140 | **kwargs 141 | ) -> Union[str, Generator[str, None, None]]: 142 | """ 143 | Generate text response from the model. 144 | 145 | Args: 146 | messages (List[Dict[str, str]]): List of messages in the conversation. 147 | stream (bool): Whether to stream the response. 148 | **kwargs: Additional parameters for generation 149 | - temperature: Sampling temperature (default: 0.0) 150 | - top_p: Top-p sampling parameter (default: 1.0) 151 | - seed: Random seed (default: 0) 152 | - max_tokens: Maximum number of tokens to generate (default: 256) 153 | """ 154 | # Set default parameters if not provided 155 | seed = kwargs.get("seed", DEFAULT_SEED) 156 | max_tokens = kwargs.get("max_tokens", DEFAULT_MAX_TOKENS) 157 | chat_template_kwargs = kwargs.get("chat_template_kwargs", {}) 158 | 159 | sampler_kwargs = { 160 | "temp": kwargs.get("temperature", DEFAULT_TEMPERATURE), 161 | "top_p": kwargs.get("top_p", DEFAULT_TOP_P), 162 | "top_k": kwargs.get("top_k", DEFAULT_TOP_K), 163 | "min_p": kwargs.get("min_p", DEFAULT_MIN_P) 164 | } 165 | 166 | repetition_penalty = kwargs.get("repetition_penalty", 1.0) 167 | repetition_context_size = kwargs.get("repetition_context_size", 20) 168 | logits_processors = make_logits_processors(repetition_penalty=repetition_penalty, repetition_context_size=repetition_context_size) 169 | json_schema = kwargs.get("schema", None) 170 | if json_schema: 171 | logits_processors.append( 172 | JSONLogitsProcessor( 173 | schema = json_schema, 174 | tokenizer = self.outlines_tokenizer, 175 | tensor_library_name = "mlx" 176 | ) 177 | ) 178 | 179 | mx.random.seed(seed) 180 | prompt_cache = make_prompt_cache(self.model, self.max_kv_size) 181 | 182 | input_tokens = self.tokenizer.apply_chat_template( 183 | messages, 184 | add_generation_prompt=True, 185 | **chat_template_kwargs, 186 | ) 187 | 188 | sampler = make_sampler( 189 | **sampler_kwargs 190 | ) 191 | 192 | if not stream: 193 | return generate( 194 | self.model, 195 | self.tokenizer, 196 | input_tokens, 197 | sampler=sampler, 198 | max_tokens=max_tokens, 199 | prompt_cache=prompt_cache, 200 | logits_processors=logits_processors 
201 | ) 202 | else: 203 | # Streaming mode: return generator of chunks 204 | return stream_generate( 205 | self.model, 206 | self.tokenizer, 207 | input_tokens, 208 | sampler=sampler, 209 | max_tokens=max_tokens, 210 | prompt_cache=prompt_cache, 211 | logits_processors=logits_processors 212 | ) -------------------------------------------------------------------------------- /app/models/mlx_vlm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import mlx.core as mx 3 | from typing import List, Dict, Union, Generator 4 | from mlx_vlm.models.cache import make_prompt_cache 5 | from mlx_vlm import load, generate, stream_generate 6 | from mlx_vlm.video_generate import process_vision_info 7 | 8 | # Default model parameters 9 | DEFAULT_MAX_TOKENS = os.getenv("DEFAULT_MAX_TOKENS", 8192) 10 | DEFAULT_TEMPERATURE = os.getenv("DEFAULT_TEMPERATURE", 0.0) 11 | DEFAULT_TOP_P = os.getenv("DEFAULT_TOP_P", 1.0) 12 | DEFAULT_SEED = os.getenv("DEFAULT_SEED", 0) 13 | 14 | class MLX_VLM: 15 | """ 16 | A wrapper class for MLX Multimodal Model that handles both streaming and non-streaming inference. 17 | 18 | This class provides a unified interface for generating text responses from images and text prompts, 19 | supporting both streaming and non-streaming modes. 20 | """ 21 | 22 | def __init__(self, model_path: str, context_length: int = None): 23 | """ 24 | Initialize the MLX_VLM model. 25 | 26 | Args: 27 | model_path (str): Path to the model directory containing model weights and configuration. 28 | 29 | Raises: 30 | ValueError: If model loading fails. 31 | """ 32 | try: 33 | self.model, self.processor = load(model_path, lazy=False, trust_remote_code=True) 34 | self.max_kv_size = context_length 35 | self.config = self.model.config 36 | except Exception as e: 37 | raise ValueError(f"Error loading model: {str(e)}") 38 | 39 | def _is_video_model(self): 40 | return hasattr(self.config, "video_token_id") or hasattr( 41 | self.config, "video_token_index" 42 | ) 43 | 44 | def get_model_type(self): 45 | return self.config.model_type 46 | 47 | def __call__( 48 | self, 49 | messages: List[Dict[str, str]], 50 | images: List[str] = None, 51 | audios: List[str] = None, 52 | videos: List[str] = None, 53 | stream: bool = False, 54 | **kwargs 55 | ) -> Union[str, Generator[str, None, None]]: 56 | """ 57 | Generate text response from images and messages. 58 | 59 | Args: 60 | images (List[str]): List of image paths to process. 61 | messages (List[Dict[str, str]]): List of message dictionaries with 'role' and 'content' keys. 62 | stream (bool, optional): Whether to stream the response. Defaults to False. 63 | **kwargs: Additional model parameters (chat_template_kwargs, temperature, max_tokens, etc.) 
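            audios (List[str], optional): List of audio file paths, if supported by the loaded model.
            videos (List[str], optional): List of video file paths; requires a video-capable model and cannot be combined with images.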
64 | 65 | Returns: 66 | Union[str, Generator[str, None, None]]: 67 | - If stream=False: Complete response as string 68 | - If stream=True: Generator yielding response chunks 69 | """ 70 | 71 | if images and videos: 72 | raise ValueError("Cannot process both images and videos in the same request") 73 | 74 | if videos and not self._is_video_model(): 75 | raise ValueError("Model is not a video model") 76 | 77 | text = self.processor.apply_chat_template( 78 | messages, 79 | tokenize=False, 80 | add_generation_prompt=True, 81 | **kwargs.get("chat_template_kwargs", {}) 82 | ) 83 | 84 | image_inputs, video_inputs = process_vision_info(messages) 85 | 86 | inputs = self.processor( 87 | text=[text], 88 | images=image_inputs, 89 | videos=video_inputs, 90 | padding=True, 91 | return_tensors="pt" 92 | ) 93 | 94 | model_params = { 95 | "input_ids": mx.array(inputs["input_ids"]), 96 | "mask": mx.array(inputs["attention_mask"]), 97 | **kwargs 98 | } 99 | 100 | if images: 101 | model_params["pixel_values"] = mx.array(inputs["pixel_values"]) 102 | model_params["image_grid_thw"] = mx.array(inputs["image_grid_thw"]) 103 | 104 | if videos: 105 | model_params["pixel_values"] = mx.array(inputs["pixel_values_videos"]) 106 | model_params["video_grid_thw"] = mx.array(inputs["video_grid_thw"]) 107 | 108 | prompt_cache = make_prompt_cache(self.model, self.max_kv_size) 109 | 110 | if stream: 111 | return stream_generate( 112 | self.model, 113 | self.processor, 114 | prompt=text, 115 | prompt_cache=prompt_cache, 116 | **model_params 117 | ) 118 | else: 119 | return generate( 120 | self.model, 121 | self.processor, 122 | prompt=text, 123 | prompt_cache=prompt_cache, 124 | **model_params 125 | ) 126 | 127 | 128 | if __name__ == "__main__": 129 | image_path = "examples/images/attention.png" 130 | video_path = "examples/videos/demo.mp4" 131 | model_path = "mlx-community/GLM-4.5V-4bit" 132 | 133 | model = MLX_VLM(model_path) 134 | print("MODEL TYPE: ", model.get_model_type()) 135 | 136 | tools = [{ 137 | "type": "function", 138 | "function": { 139 | "name": "get_weather", 140 | "description": "Get the weather for a given city", 141 | "parameters": { 142 | "type": "object", 143 | "properties": { 144 | "city": {"type": "string", "description": "The city to get the weather for"} 145 | } 146 | }, 147 | "required": ["city"] 148 | }} 149 | ] 150 | kwargs = { 151 | "chat_template_kwargs": { 152 | "tools": tools, 153 | "enable_thinking": True, 154 | }, 155 | "temperature": 0.0, 156 | "top_p": 1.0, 157 | "seed": 0, 158 | "max_tokens": 8192, 159 | "frequency_penalty": 0.0, 160 | "presence_penalty": 0.0 161 | } 162 | messages = [ 163 | { 164 | "role": "user", 165 | "content": [ 166 | { 167 | "type": "text", 168 | "text": "Describe the video in detail" 169 | }, 170 | { 171 | "type": "image", 172 | "image": image_path 173 | } 174 | ] 175 | } 176 | ] 177 | response = model(messages, stream=False, **kwargs) 178 | print(response) -------------------------------------------------------------------------------- /app/models/mlx_whisper.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | from functools import lru_cache 4 | from mlx_whisper.transcribe import transcribe 5 | 6 | SAMPLING_RATE = 16000 7 | CHUNK_SIZE = 30 8 | 9 | 10 | @lru_cache(maxsize=32) 11 | def load_audio(fname): 12 | """Load and cache audio file. 
Cache size limited to 32 recent files.""" 13 | a, _ = librosa.load(fname, sr=SAMPLING_RATE, dtype=np.float32) 14 | return a 15 | 16 | @lru_cache(maxsize=32) 17 | def calculate_audio_duration(audio_path: str) -> int: 18 | """Calculate the duration of the audio file in seconds.""" 19 | audio = load_audio(audio_path) 20 | return len(audio) / SAMPLING_RATE 21 | 22 | class MLX_Whisper: 23 | def __init__(self, model_path: str): 24 | self.model_path = model_path 25 | 26 | def _transcribe_generator(self, audio_path: str, **kwargs): 27 | """Stream transcription by processing audio in larger chunks.""" 28 | # Load the audio file 29 | audio = load_audio(audio_path) 30 | duration = calculate_audio_duration(audio_path) 31 | 32 | beg = 0.0 33 | while beg < duration: 34 | # Calculate chunk boundaries 35 | chunk_end = min(beg + CHUNK_SIZE, duration) 36 | 37 | # Extract audio chunk 38 | beg_samples = int(beg * SAMPLING_RATE) 39 | end_samples = int(chunk_end * SAMPLING_RATE) 40 | audio_chunk = audio[beg_samples:end_samples] 41 | 42 | # Transcribe chunk 43 | result = transcribe(audio_chunk, path_or_hf_repo=self.model_path, **kwargs) 44 | 45 | # Add timing information 46 | result["chunk_start"] = beg 47 | result["chunk_end"] = chunk_end 48 | 49 | yield result 50 | 51 | beg += CHUNK_SIZE 52 | 53 | def __call__(self, audio_path: str, stream: bool = False, **kwargs): 54 | """ 55 | Transcribe audio file. 56 | 57 | Args: 58 | audio_path: Path to audio file 59 | stream: If True, yields chunks. If False, transcribes entire file at once. 60 | **kwargs: Additional arguments passed to transcribe() 61 | """ 62 | if stream: 63 | return self._transcribe_generator(audio_path, **kwargs) 64 | else: 65 | return transcribe(audio_path, path_or_hf_repo=self.model_path, **kwargs) 66 | 67 | 68 | if __name__ == "__main__": 69 | model = MLX_Whisper("mlx-community/whisper-tiny") 70 | # Non-streaming (fastest for most use cases) 71 | result = model("examples/audios/podcast.wav", stream=True) 72 | for chunk in result: 73 | print(f"[{chunk['chunk_start']:.1f}s - {chunk['chunk_end']:.1f}s]: {chunk['text']}") -------------------------------------------------------------------------------- /app/schemas/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /app/utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /app/utils/dill.py: -------------------------------------------------------------------------------- 1 | # copied from https://github.com/huggingface/datasets/blob/1e1d313/src/datasets/utils/_dill.py 2 | 3 | # Copyright 2023 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
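# Vendored so that app/utils/outlines_transformer_tokenizer.py can reuse the Hasher
# defined below without needing the full `datasets` package. Minimal usage sketch
# (hypothetical input value):
#
#     from app.utils.dill import Hasher
#     digest = Hasher.hash({"any": "picklable object"})  # deterministic hex digest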
16 | """Extends `dill` to support pickling more types and produce more consistent dumps.""" 17 | 18 | import sys 19 | from io import BytesIO 20 | from types import FunctionType 21 | from typing import Any, Dict, List, Union 22 | 23 | import dill 24 | import xxhash 25 | 26 | 27 | class Hasher: 28 | """Hasher that accepts python objects as inputs.""" 29 | 30 | dispatch: Dict = {} 31 | 32 | def __init__(self): 33 | self.m = xxhash.xxh64() 34 | 35 | @classmethod 36 | def hash_bytes(cls, value: Union[bytes, List[bytes]]) -> str: 37 | value = [value] if isinstance(value, bytes) else value 38 | m = xxhash.xxh64() 39 | for x in value: 40 | m.update(x) 41 | return m.hexdigest() 42 | 43 | @classmethod 44 | def hash(cls, value: Any) -> str: 45 | return cls.hash_bytes(dumps(value)) 46 | 47 | def update(self, value: Any) -> None: 48 | header_for_update = f"=={type(value)}==" 49 | value_for_update = self.hash(value) 50 | self.m.update(header_for_update.encode("utf8")) 51 | self.m.update(value_for_update.encode("utf-8")) 52 | 53 | def hexdigest(self) -> str: 54 | return self.m.hexdigest() 55 | 56 | 57 | class Pickler(dill.Pickler): 58 | dispatch = dill._dill.MetaCatchingDict(dill.Pickler.dispatch.copy()) 59 | _legacy_no_dict_keys_sorting = False 60 | 61 | def save(self, obj, save_persistent_id=True): 62 | obj_type = type(obj) 63 | if obj_type not in self.dispatch: 64 | if "regex" in sys.modules: 65 | import regex # type: ignore 66 | 67 | if obj_type is regex.Pattern: 68 | pklregister(obj_type)(_save_regexPattern) 69 | if "spacy" in sys.modules: 70 | import spacy # type: ignore 71 | 72 | if issubclass(obj_type, spacy.Language): 73 | pklregister(obj_type)(_save_spacyLanguage) 74 | if "tiktoken" in sys.modules: 75 | import tiktoken # type: ignore 76 | 77 | if obj_type is tiktoken.Encoding: 78 | pklregister(obj_type)(_save_tiktokenEncoding) 79 | if "torch" in sys.modules: 80 | import torch # type: ignore 81 | 82 | if issubclass(obj_type, torch.Tensor): 83 | pklregister(obj_type)(_save_torchTensor) 84 | 85 | if obj_type is torch.Generator: 86 | pklregister(obj_type)(_save_torchGenerator) 87 | 88 | # Unwrap `torch.compile`-ed modules 89 | if issubclass(obj_type, torch.nn.Module): 90 | obj = getattr(obj, "_orig_mod", obj) 91 | if "transformers" in sys.modules: 92 | import transformers # type: ignore 93 | 94 | if issubclass(obj_type, transformers.PreTrainedTokenizerBase): 95 | pklregister(obj_type)(_save_transformersPreTrainedTokenizerBase) 96 | 97 | # Unwrap `torch.compile`-ed functions 98 | if obj_type is FunctionType: 99 | obj = getattr(obj, "_torchdynamo_orig_callable", obj) 100 | dill.Pickler.save(self, obj, save_persistent_id=save_persistent_id) 101 | 102 | def _batch_setitems(self, items): 103 | if self._legacy_no_dict_keys_sorting: 104 | return super()._batch_setitems(items) 105 | # Ignore the order of keys in a dict 106 | try: 107 | # Faster, but fails for unorderable elements 108 | items = sorted(items) 109 | except Exception: # TypeError, decimal.InvalidOperation, etc. 
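            # Unorderable keys: fall back to sorting by each key's hash so the
            # resulting dump stays independent of dict insertion order.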
110 | items = sorted(items, key=lambda x: Hasher.hash(x[0])) 111 | dill.Pickler._batch_setitems(self, items) 112 | 113 | def memoize(self, obj): 114 | # Don't memoize strings since two identical strings can have different Python ids 115 | if type(obj) is not str: # noqa: E721 116 | dill.Pickler.memoize(self, obj) 117 | 118 | 119 | def pklregister(t): 120 | """Register a custom reducer for the type.""" 121 | 122 | def proxy(func): 123 | Pickler.dispatch[t] = func 124 | return func 125 | 126 | return proxy 127 | 128 | 129 | def dump(obj, file): 130 | """Pickle an object to a file.""" 131 | Pickler(file, recurse=True).dump(obj) 132 | 133 | 134 | def dumps(obj): 135 | """Pickle an object to a string.""" 136 | file = BytesIO() 137 | dump(obj, file) 138 | return file.getvalue() 139 | 140 | 141 | def log(pickler, msg): 142 | pass 143 | 144 | 145 | def _save_regexPattern(pickler, obj): 146 | import regex # type: ignore 147 | 148 | log(pickler, f"Re: {obj}") 149 | args = (obj.pattern, obj.flags) 150 | pickler.save_reduce(regex.compile, args, obj=obj) 151 | log(pickler, "# Re") 152 | 153 | 154 | def _save_tiktokenEncoding(pickler, obj): 155 | import tiktoken # type: ignore 156 | 157 | log(pickler, f"Enc: {obj}") 158 | args = (obj.name, obj._pat_str, obj._mergeable_ranks, obj._special_tokens) 159 | pickler.save_reduce(tiktoken.Encoding, args, obj=obj) 160 | log(pickler, "# Enc") 161 | 162 | 163 | def _save_torchTensor(pickler, obj): 164 | import torch # type: ignore 165 | 166 | # `torch.from_numpy` is not picklable in `torch>=1.11.0` 167 | def create_torchTensor(np_array, dtype=None): 168 | tensor = torch.from_numpy(np_array) 169 | if dtype: 170 | tensor = tensor.type(dtype) 171 | return tensor 172 | 173 | log(pickler, f"To: {obj}") 174 | if obj.dtype == torch.bfloat16: 175 | args = (obj.detach().to(torch.float).cpu().numpy(), torch.bfloat16) 176 | else: 177 | args = (obj.detach().cpu().numpy(),) 178 | pickler.save_reduce(create_torchTensor, args, obj=obj) 179 | log(pickler, "# To") 180 | 181 | 182 | def _save_torchGenerator(pickler, obj): 183 | import torch # type: ignore 184 | 185 | def create_torchGenerator(state): 186 | generator = torch.Generator() 187 | generator.set_state(state) 188 | return generator 189 | 190 | log(pickler, f"Ge: {obj}") 191 | args = (obj.get_state(),) 192 | pickler.save_reduce(create_torchGenerator, args, obj=obj) 193 | log(pickler, "# Ge") 194 | 195 | 196 | def _save_spacyLanguage(pickler, obj): 197 | import spacy # type: ignore 198 | 199 | def create_spacyLanguage(config, bytes): 200 | lang_cls = spacy.util.get_lang_class(config["nlp"]["lang"]) 201 | lang_inst = lang_cls.from_config(config) 202 | return lang_inst.from_bytes(bytes) 203 | 204 | log(pickler, f"Sp: {obj}") 205 | args = (obj.config, obj.to_bytes()) 206 | pickler.save_reduce(create_spacyLanguage, args, obj=obj) 207 | log(pickler, "# Sp") 208 | 209 | 210 | def _save_transformersPreTrainedTokenizerBase(pickler, obj): 211 | log(pickler, f"Tok: {obj}") 212 | # Ignore the `cache` attribute 213 | state = obj.__dict__ 214 | if "cache" in state and isinstance(state["cache"], dict): 215 | state["cache"] = {} 216 | pickler.save_reduce(type(obj), (), state=state, obj=obj) 217 | log(pickler, "# Tok") -------------------------------------------------------------------------------- /app/utils/errors.py: -------------------------------------------------------------------------------- 1 | from http import HTTPStatus 2 | from typing import Union 3 | 4 | def create_error_response( 5 | message: str, 6 | err_type: str = 
"internal_error", 7 | status_code: Union[int, HTTPStatus] = HTTPStatus.INTERNAL_SERVER_ERROR, 8 | param: str = None, 9 | code: str = None 10 | ): 11 | return { 12 | "error": { 13 | "message": message, 14 | "type": err_type, 15 | "param": param, 16 | "code": str(code or (status_code.value if isinstance(status_code, HTTPStatus) else status_code)) 17 | } 18 | } -------------------------------------------------------------------------------- /app/utils/outlines_transformer_tokenizer.py: -------------------------------------------------------------------------------- 1 | from app.utils.dill import Hasher 2 | from outlines.models.transformers import TransformerTokenizer 3 | 4 | 5 | class OutlinesTransformerTokenizer(TransformerTokenizer): 6 | """ 7 | Update the outlines TransformerTokenizer to use our own Hasher class, so that we don't need the datasets dependency 8 | 9 | This class and the external dependency can be removed when the following import is deleted 10 | https://github.com/dottxt-ai/outlines/blob/69418d/outlines/models/transformers.py#L117 11 | """ 12 | 13 | def __hash__(self): 14 | return hash(Hasher.hash(self.tokenizer)) -------------------------------------------------------------------------------- /app/version.py: -------------------------------------------------------------------------------- 1 | # Version number format: MAJOR.MINOR.PATCH 2 | # Major: Major version number (increments when breaking changes are introduced) 3 | # Minor: Minor version number (increments when new features are added) 4 | # Patch: Patch version number (increments when bug fixes are made) 5 | 6 | __version__ = "1.3.12" -------------------------------------------------------------------------------- /configure_mlx.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Get the total memory in MB 4 | TOTAL_MEM_MB=$(($(sysctl -n hw.memsize) / 1024 / 1024)) 5 | 6 | # Calculate 80% and TOTAL_MEM_GB-5GB in MB 7 | EIGHTY_PERCENT=$(($TOTAL_MEM_MB * 80 / 100)) 8 | MINUS_5GB=$((($TOTAL_MEM_MB - 5120))) 9 | 10 | # Calculate 70% and TOTAL_MEM_GB-8GB in MB 11 | SEVENTY_PERCENT=$(($TOTAL_MEM_MB * 70 / 100)) 12 | MINUS_8GB=$((($TOTAL_MEM_MB - 8192))) 13 | 14 | # Set WIRED_LIMIT_MB to higher value 15 | if [ $EIGHTY_PERCENT -gt $MINUS_5GB ]; then 16 | WIRED_LIMIT_MB=$EIGHTY_PERCENT 17 | else 18 | WIRED_LIMIT_MB=$MINUS_5GB 19 | fi 20 | 21 | # Set WIRED_LWM_MB to higher value 22 | if [ $SEVENTY_PERCENT -gt $MINUS_8GB ]; then 23 | WIRED_LWM_MB=$SEVENTY_PERCENT 24 | else 25 | WIRED_LWM_MB=$MINUS_8GB 26 | fi 27 | 28 | # Display the calculated values 29 | echo "Total memory: $TOTAL_MEM_MB MB" 30 | echo "Maximum limit (iogpu.wired_limit_mb): $WIRED_LIMIT_MB MB" 31 | echo "Lower bound (iogpu.wired_lwm_mb): $WIRED_LWM_MB MB" 32 | 33 | # Apply the values with sysctl, but check if we're already root 34 | if [ "$EUID" -eq 0 ]; then 35 | sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB 36 | sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB 37 | else 38 | # Try without sudo first, fall back to sudo if needed 39 | sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB 2>/dev/null || \ 40 | sudo sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB 41 | sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB 2>/dev/null || \ 42 | sudo sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB 43 | fi -------------------------------------------------------------------------------- /examples/audio_examples.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | 
"cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Audio Processing with MLX Server\n", 8 | "\n", 9 | "This notebook demonstrates how to process audio files using the MLX Server with OpenAI-compatible API.\n" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "## What You'll Learn\n", 17 | "\n", 18 | "- Connect to MLX Server\n", 19 | "- Load and encode audio files for processing\n", 20 | "- Send audio to the model for analysis\n", 21 | "- Get text descriptions of audio content\n", 22 | "\n", 23 | "## Prerequisites\n", 24 | "\n", 25 | "- MLX Server running on localhost:8000\n", 26 | "- Audio file in the `audios/` directory\n", 27 | "- OpenAI Python library installed\n" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "## Step 1: Setup and Connection\n" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 1, 40 | "metadata": {}, 41 | "outputs": [ 42 | { 43 | "name": "stdout", 44 | "output_type": "stream", 45 | "text": [ 46 | "✅ Connected to MLX Server\n" 47 | ] 48 | } 49 | ], 50 | "source": [ 51 | "# Import required libraries\n", 52 | "from openai import OpenAI\n", 53 | "import base64\n", 54 | "import os\n", 55 | "\n", 56 | "# Initialize OpenAI client to connect to MLX Server\n", 57 | "# The MLX Server runs locally and provides OpenAI-compatible endpoints\n", 58 | "client = OpenAI(\n", 59 | " base_url=\"http://localhost:8000/v1\", # MLX Server address\n", 60 | " api_key=\"fake-api-key\", # Any string works for local server\n", 61 | ")\n", 62 | "\n", 63 | "print(\"✅ Connected to MLX Server\")\n" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "## Step 2: Audio File Processing\n" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 2, 76 | "metadata": {}, 77 | "outputs": [ 78 | { 79 | "name": "stdout", 80 | "output_type": "stream", 81 | "text": [ 82 | "✅ Loaded audio file: audios/audio.wav\n", 83 | " File size: 372698 bytes\n", 84 | " Encoded size: 496932 characters\n" 85 | ] 86 | } 87 | ], 88 | "source": [ 89 | "def load_audio_file(audio_path: str) -> str:\n", 90 | " \"\"\"\n", 91 | " Load an audio file and encode it as base64 for API transmission.\n", 92 | " \n", 93 | " Args:\n", 94 | " audio_path (str): Path to the audio file\n", 95 | " \n", 96 | " Returns:\n", 97 | " str: Base64 encoded audio data\n", 98 | " \"\"\"\n", 99 | " if not os.path.exists(audio_path):\n", 100 | " raise FileNotFoundError(f\"Audio file not found: {audio_path}\")\n", 101 | " \n", 102 | " with open(audio_path, \"rb\") as audio_file:\n", 103 | " audio_data = audio_file.read()\n", 104 | " encoded_audio = base64.b64encode(audio_data).decode('utf-8')\n", 105 | " \n", 106 | " print(f\"✅ Loaded audio file: {audio_path}\")\n", 107 | " print(f\" File size: {len(audio_data)} bytes\")\n", 108 | " print(f\" Encoded size: {len(encoded_audio)} characters\")\n", 109 | " \n", 110 | " return encoded_audio\n", 111 | "\n", 112 | "# Load the sample audio file\n", 113 | "audio_path = \"audios/audio.wav\"\n", 114 | "audio_base64 = load_audio_file(audio_path)\n" 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "metadata": {}, 120 | "source": [ 121 | "## Step 3: Audio Analysis\n" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 3, 127 | "metadata": {}, 128 | "outputs": [ 129 | { 130 | "name": "stdout", 131 | "output_type": "stream", 132 | "text": [ 133 | "🎵 Audio Analysis Result:\n", 134 | " Dogs are sitting by 
the door.\n" 135 | ] 136 | } 137 | ], 138 | "source": [ 139 | "def analyze_audio(audio_base64: str, prompt: str = \"Describe what you hear in this audio.\") -> str:\n", 140 | " \"\"\"\n", 141 | " Send audio to MLX Server for analysis.\n", 142 | " \n", 143 | " Args:\n", 144 | " audio_base64 (str): Base64 encoded audio data\n", 145 | " prompt (str): Text prompt for the model\n", 146 | " \n", 147 | " Returns:\n", 148 | " str: Model's response\n", 149 | " \"\"\"\n", 150 | " try:\n", 151 | " response = client.chat.completions.create(\n", 152 | " model=\"local-multimodal\",\n", 153 | " messages=[\n", 154 | " {\n", 155 | " \"role\": \"user\", \n", 156 | " \"content\": [\n", 157 | " {\n", 158 | " \"type\": \"input_audio\",\n", 159 | " \"input_audio\": {\n", 160 | " \"data\": audio_base64,\n", 161 | " \"format\": \"wav\"\n", 162 | " }\n", 163 | " },\n", 164 | " {\n", 165 | " \"type\": \"text\",\n", 166 | " \"text\": prompt\n", 167 | " }\n", 168 | " ]\n", 169 | " }\n", 170 | " ],\n", 171 | " max_tokens=1024\n", 172 | " )\n", 173 | " \n", 174 | " return response.choices[0].message.content\n", 175 | " \n", 176 | " except Exception as e:\n", 177 | " return f\"Error analyzing audio: {str(e)}\"\n", 178 | "\n", 179 | "# Analyze the audio with a descriptive prompt\n", 180 | "result = analyze_audio(audio_base64, \"Describe the audio in detail.\")\n", 181 | "print(\"🎵 Audio Analysis Result:\")\n", 182 | "print(f\" {result}\")\n" 183 | ] 184 | }, 185 | { 186 | "cell_type": "markdown", 187 | "metadata": {}, 188 | "source": [ 189 | "## Conclusion\n", 190 | "\n", 191 | "This notebook demonstrated the audio processing capabilities of the MLX Server using OpenAI-compatible API endpoints. Key highlights include:\n", 192 | "\n", 193 | "- **Audio Input Support**: Successfully processed audio files by encoding them as base64 and sending them through the `input_audio` message type\n", 194 | "- **Multimodal Integration**: Combined audio input with text prompts to create rich, context-aware responses\n", 195 | "- **OpenAI Compatibility**: Leveraged familiar OpenAI API patterns for seamless integration with existing workflows\n", 196 | "- **Error Handling**: Implemented proper error handling for robust audio processing\n", 197 | "\n", 198 | "The MLX Server's audio processing capabilities enable powerful applications such as:\n", 199 | "- Audio transcription and analysis\n", 200 | "- Voice-controlled interfaces\n", 201 | "- Audio content summarization\n", 202 | "- Accessibility features for audio-based content\n", 203 | "\n", 204 | "This foundation opens up numerous possibilities for building audio-enabled AI applications with the performance benefits of MLX on Apple Silicon.\n" 205 | ] 206 | } 207 | ], 208 | "metadata": { 209 | "kernelspec": { 210 | "display_name": "testing", 211 | "language": "python", 212 | "name": "python3" 213 | }, 214 | "language_info": { 215 | "codemirror_mode": { 216 | "name": "ipython", 217 | "version": 3 218 | }, 219 | "file_extension": ".py", 220 | "mimetype": "text/x-python", 221 | "name": "python", 222 | "nbconvert_exporter": "python", 223 | "pygments_lexer": "ipython3", 224 | "version": "3.11.11" 225 | } 226 | }, 227 | "nbformat": 4, 228 | "nbformat_minor": 2 229 | } 230 | -------------------------------------------------------------------------------- /examples/audios/audio.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cubist38/mlx-openai-server/25fc00aa21fc6794358849ebfb9d866a6e203eda/examples/audios/audio.wav 
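# A streaming variant of the audio-analysis request from examples/audio_examples.ipynb
# above: a minimal sketch that reuses the notebook's `client` and `audio_base64`
# objects and assumes the local server accepts `stream=True` for audio inputs.
stream = client.chat.completions.create(
    model="local-multimodal",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "input_audio", "input_audio": {"data": audio_base64, "format": "wav"}},
                {"type": "text", "text": "Describe the audio in detail."},
            ],
        }
    ],
    max_tokens=1024,
    stream=True,
)
for chunk in stream:
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)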
-------------------------------------------------------------------------------- /examples/audios/podcast.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cubist38/mlx-openai-server/25fc00aa21fc6794358849ebfb9d866a6e203eda/examples/audios/podcast.wav -------------------------------------------------------------------------------- /examples/images/attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cubist38/mlx-openai-server/25fc00aa21fc6794358849ebfb9d866a6e203eda/examples/images/attention.png -------------------------------------------------------------------------------- /examples/images/china.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cubist38/mlx-openai-server/25fc00aa21fc6794358849ebfb9d866a6e203eda/examples/images/china.png -------------------------------------------------------------------------------- /examples/images/green_dog.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cubist38/mlx-openai-server/25fc00aa21fc6794358849ebfb9d866a6e203eda/examples/images/green_dog.jpeg -------------------------------------------------------------------------------- /examples/images/password.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cubist38/mlx-openai-server/25fc00aa21fc6794358849ebfb9d866a6e203eda/examples/images/password.jpg -------------------------------------------------------------------------------- /examples/lm_embeddings_examples.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Embeddings API Examples with MLX Server" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "This notebook demonstrates how to use the embeddings endpoint of MLX Server through the OpenAI-compatible API. You'll learn how to generate embeddings, work with batches, compare similarity between texts, and use embeddings for practical applications." 
15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "## Setup and Connection" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 1, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "# Import the OpenAI client for API communication\n", 31 | "from openai import OpenAI\n", 32 | "\n", 33 | "# Connect to the local MLX Server with OpenAI-compatible API\n", 34 | "client = OpenAI(\n", 35 | " base_url=\"http://localhost:8000/v1\",\n", 36 | " api_key=\"fake-api-key\",\n", 37 | ")" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "## Basic Embedding Generation\n", 45 | "\n", 46 | "### Single Text Embedding\n" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 2, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "# Generate embedding for a single text input\n", 56 | "single_text = \"Artificial intelligence is transforming how we interact with technology.\"\n", 57 | "response = client.embeddings.create(\n", 58 | " input=[single_text],\n", 59 | " model=\"mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-MLX-Q8\"\n", 60 | ")" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": {}, 66 | "source": [ 67 | "### Batch Processing Multiple Texts" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 3, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "text_batch = [\n", 77 | " \"Machine learning algorithms improve with more data\",\n", 78 | " \"Natural language processing helps computers understand human language\",\n", 79 | " \"Computer vision allows machines to interpret visual information\"\n", 80 | "]\n", 81 | "\n", 82 | "batch_response = client.embeddings.create(\n", 83 | " input=text_batch,\n", 84 | " model=\"mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-MLX-Q8\"\n", 85 | ")" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 4, 91 | "metadata": {}, 92 | "outputs": [ 93 | { 94 | "name": "stdout", 95 | "output_type": "stream", 96 | "text": [ 97 | "Number of embeddings generated: 3\n", 98 | "Dimensions of each embedding: 1536\n" 99 | ] 100 | } 101 | ], 102 | "source": [ 103 | "# Access all embeddings\n", 104 | "embeddings = [item.embedding for item in batch_response.data]\n", 105 | "print(f\"Number of embeddings generated: {len(embeddings)}\")\n", 106 | "print(f\"Dimensions of each embedding: {len(embeddings[0])}\")" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": {}, 112 | "source": [ 113 | "## Semantic Similarity Calculation\n", 114 | "\n", 115 | "One of the most common uses for embeddings is measuring semantic similarity between texts." 
116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 5, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "import numpy as np" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 6, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "def cosine_similarity_score(vec1, vec2):\n", 134 | " \"\"\"Calculate cosine similarity between two vectors\"\"\"\n", 135 | " dot_product = np.dot(vec1, vec2)\n", 136 | " norm1 = np.linalg.norm(vec1)\n", 137 | " norm2 = np.linalg.norm(vec2)\n", 138 | " return dot_product / (norm1 * norm2)" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 7, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "# Example texts to compare\n", 148 | "text1 = \"Dogs are loyal pets that provide companionship\"\n", 149 | "text2 = \"Canines make friendly companions for humans\"\n", 150 | "text3 = \"Quantum physics explores the behavior of matter at atomic scales\"" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 8, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "# Generate embeddings\n", 160 | "comparison_texts = [text1, text2, text3]\n", 161 | "comparison_response = client.embeddings.create(\n", 162 | " input=comparison_texts,\n", 163 | " model=\"mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-MLX-Q8\"\n", 164 | ")\n", 165 | "comparison_embeddings = [item.embedding for item in comparison_response.data]" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 9, 171 | "metadata": {}, 172 | "outputs": [ 173 | { 174 | "name": "stdout", 175 | "output_type": "stream", 176 | "text": [ 177 | "Similarity between text1 and text2: 0.8142\n", 178 | "Similarity between text1 and text3: 0.6082\n", 179 | "Similarity between text2 and text3: 0.5739\n" 180 | ] 181 | } 182 | ], 183 | "source": [ 184 | "# Compare similarities\n", 185 | "similarity_1_2 = cosine_similarity_score(comparison_embeddings[0], comparison_embeddings[1])\n", 186 | "similarity_1_3 = cosine_similarity_score(comparison_embeddings[0], comparison_embeddings[2])\n", 187 | "similarity_2_3 = cosine_similarity_score(comparison_embeddings[1], comparison_embeddings[2])\n", 188 | "\n", 189 | "print(f\"Similarity between text1 and text2: {similarity_1_2:.4f}\")\n", 190 | "print(f\"Similarity between text1 and text3: {similarity_1_3:.4f}\")\n", 191 | "print(f\"Similarity between text2 and text3: {similarity_2_3:.4f}\")" 192 | ] 193 | }, 194 | { 195 | "cell_type": "markdown", 196 | "metadata": {}, 197 | "source": [ 198 | "## Text Search Using Embeddings" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 10, 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [ 207 | "# Sample document collection\n", 208 | "documents = [\n", 209 | " \"The quick brown fox jumps over the lazy dog\",\n", 210 | " \"Machine learning models require training data\",\n", 211 | " \"Neural networks are inspired by biological neurons\",\n", 212 | " \"Deep learning is a subset of machine learning\",\n", 213 | " \"Natural language processing helps with text analysis\",\n", 214 | " \"Computer vision systems can detect objects in images\"\n", 215 | "]" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 11, 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [ 224 | "# Generate embeddings for all documents\n", 225 | "doc_response = client.embeddings.create(\n", 226 | " input=documents,\n", 227 | " 
model=\"mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-MLX-Q8\"\n", 228 | ")\n", 229 | "doc_embeddings = [item.embedding for item in doc_response.data]" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 12, 235 | "metadata": {}, 236 | "outputs": [ 237 | { 238 | "name": "stdout", 239 | "output_type": "stream", 240 | "text": [ 241 | "Search results:\n", 242 | "Score: 0.8574 - Computer vision systems can detect objects in images\n", 243 | "Score: 0.8356 - Neural networks are inspired by biological neurons\n", 244 | "Score: 0.8266 - Natural language processing helps with text analysis\n", 245 | "Score: 0.8141 - Deep learning is a subset of machine learning\n", 246 | "Score: 0.7474 - Machine learning models require training data\n", 247 | "Score: 0.5936 - The quick brown fox jumps over the lazy dog\n" 248 | ] 249 | } 250 | ], 251 | "source": [ 252 | "def search_documents(query, doc_collection, doc_embeddings):\n", 253 | " \"\"\"Search for documents similar to query\"\"\"\n", 254 | " # Generate embedding for query\n", 255 | " query_response = client.embeddings.create(\n", 256 | " input=[query],\n", 257 | " model=\"mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-MLX-Q8\"\n", 258 | " )\n", 259 | " query_embedding = query_response.data[0].embedding\n", 260 | " \n", 261 | " # Calculate similarity scores\n", 262 | " similarities = []\n", 263 | " for doc_embedding in doc_embeddings:\n", 264 | " similarity = cosine_similarity_score(query_embedding, doc_embedding)\n", 265 | " similarities.append(similarity)\n", 266 | " \n", 267 | " # Return results with scores\n", 268 | " results = []\n", 269 | " for i, score in enumerate(similarities):\n", 270 | " results.append((doc_collection[i], score))\n", 271 | " \n", 272 | " # Sort by similarity score (highest first)\n", 273 | " return sorted(results, key=lambda x: x[1], reverse=True)\n", 274 | "\n", 275 | "# Example search\n", 276 | "search_results = search_documents(\"How do AI models learn?\", documents, doc_embeddings)\n", 277 | "\n", 278 | "print(\"Search results:\")\n", 279 | "for doc, score in search_results:\n", 280 | " print(f\"Score: {score:.4f} - {doc}\")" 281 | ] 282 | } 283 | ], 284 | "metadata": { 285 | "kernelspec": { 286 | "display_name": "Python 3", 287 | "language": "python", 288 | "name": "python3" 289 | }, 290 | "language_info": { 291 | "codemirror_mode": { 292 | "name": "ipython", 293 | "version": 3 294 | }, 295 | "file_extension": ".py", 296 | "mimetype": "text/x-python", 297 | "name": "python", 298 | "nbconvert_exporter": "python", 299 | "pygments_lexer": "ipython3", 300 | "version": "3.11.12" 301 | } 302 | }, 303 | "nbformat": 4, 304 | "nbformat_minor": 2 305 | } 306 | -------------------------------------------------------------------------------- /examples/pdfs/lab03.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cubist38/mlx-openai-server/25fc00aa21fc6794358849ebfb9d866a6e203eda/examples/pdfs/lab03.pdf -------------------------------------------------------------------------------- /examples/structured_outputs_examples.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# MLX Server Structured Output Examples\n", 8 | "\n", 9 | "This is a detailed text version of the structured output examples for MLX Server with OpenAI-compatible API." 
10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "## Setup" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 8, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "from openai import OpenAI" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "## Initialize the client\n", 33 | "\n", 34 | "Connect to your local MLX server:" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 18, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "client = OpenAI(\n", 44 | " base_url = \"http://localhost:8000/v1\",\n", 45 | " api_key = \"mlx-server-api-key\"\n", 46 | ")" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "## Function calling example" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 19, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "# Define the user message\n", 63 | "messages = [\n", 64 | " {\n", 65 | " \"role\": \"user\",\n", 66 | " \"content\": \"What is the weather in Tokyo?\"\n", 67 | " }\n", 68 | "]\n", 69 | "\n", 70 | "# Define the available tools/functions\n", 71 | "tools = [\n", 72 | " {\n", 73 | " \"type\": \"function\",\n", 74 | " \"function\": {\n", 75 | " \"name\": \"get_weather\",\n", 76 | " \"description\": \"Get the weather in a given city\",\n", 77 | " \"parameters\": {\n", 78 | " \"type\": \"object\",\n", 79 | " \"properties\": {\n", 80 | " \"city\": {\"type\": \"string\", \"description\": \"The city to get the weather for\"}\n", 81 | " }\n", 82 | " }\n", 83 | " }\n", 84 | " }\n", 85 | "]" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": {}, 91 | "source": [ 92 | "### Non Streaming Function Calling Example" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 20, 98 | "metadata": {}, 99 | "outputs": [ 100 | { 101 | "name": "stdout", 102 | "output_type": "stream", 103 | "text": [ 104 | "ChatCompletion(id='chatcmpl_1754135306120611', choices=[Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content='', refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_1754135306725351', function=Function(arguments='{\"city\": \"Tokyo\"}', name='get_weather'), type='function', index=0)], reasoning_content=None))], created=1754135306, model='mlx-server-model', object='chat.completion', service_tier=None, system_fingerprint=None, usage=None)\n" 105 | ] 106 | } 107 | ], 108 | "source": [ 109 | "# Make the API call\n", 110 | "completion = client.chat.completions.create(\n", 111 | " model=\"mlx-server-model\",\n", 112 | " messages=messages,\n", 113 | " tools=tools,\n", 114 | " tool_choice=\"auto\",\n", 115 | " max_tokens = 512\n", 116 | ")\n", 117 | "\n", 118 | "# Get the result\n", 119 | "print(completion)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "metadata": {}, 125 | "source": [ 126 | "### Streaming Function Calling Example" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 21, 132 | "metadata": {}, 133 | "outputs": [ 134 | { 135 | "name": "stdout", 136 | "output_type": "stream", 137 | "text": [ 138 | "ChatCompletionChunk(id='chatcmpl_1754135306422307', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=None, reasoning_content=None), finish_reason=None, index=0, logprobs=None)], 
created=1754135306, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n", 139 | "ChatCompletionChunk(id='chatcmpl_1754135306422307', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id='call_1754135307829795', function=ChoiceDeltaToolCallFunction(arguments='', name='get_weather'), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1754135306, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n", 140 | "ChatCompletionChunk(id='chatcmpl_1754135306422307', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments=' {\"', name=None), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1754135306, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n", 141 | "ChatCompletionChunk(id='chatcmpl_1754135306422307', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='city', name=None), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1754135306, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n", 142 | "ChatCompletionChunk(id='chatcmpl_1754135306422307', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='\":', name=None), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1754135306, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n", 143 | "ChatCompletionChunk(id='chatcmpl_1754135306422307', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments=' \"', name=None), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1754135306, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n", 144 | "ChatCompletionChunk(id='chatcmpl_1754135306422307', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='Tok', name=None), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1754135306, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n", 145 | "ChatCompletionChunk(id='chatcmpl_1754135306422307', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='yo', name=None), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], 
created=1754135306, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n", 146 | "ChatCompletionChunk(id='chatcmpl_1754135306422307', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, refusal=None, role='assistant', tool_calls=[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='\"}', name=None), type='function')], reasoning_content=None), finish_reason=None, index=0, logprobs=None)], created=1754135306, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n", 147 | "ChatCompletionChunk(id='chatcmpl_1754135306422307', choices=[Choice(delta=ChoiceDelta(content='', function_call=None, refusal=None, role='assistant', tool_calls=None, reasoning_content=None), finish_reason='tool_calls', index=0, logprobs=None)], created=1754135308, model='mlx-server-model', object='chat.completion.chunk', service_tier=None, system_fingerprint=None, usage=None)\n" 148 | ] 149 | } 150 | ], 151 | "source": [ 152 | "# Set stream=True in the API call\n", 153 | "completion = client.chat.completions.create(\n", 154 | " model=\"mlx-server-model\",\n", 155 | " messages=messages,\n", 156 | " tools=tools,\n", 157 | " tool_choice=\"auto\",\n", 158 | " stream=True\n", 159 | ")\n", 160 | "\n", 161 | "# Process the streaming response\n", 162 | "for chunk in completion:\n", 163 | " print(chunk)" 164 | ] 165 | }, 166 | { 167 | "cell_type": "markdown", 168 | "metadata": {}, 169 | "source": [ 170 | "# JSON Schema Example" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 22, 176 | "metadata": {}, 177 | "outputs": [], 178 | "source": [ 179 | "messages = [\n", 180 | " {\n", 181 | " \"role\": \"system\",\n", 182 | " \"content\": \"Extract the address from the user input into the specified JSON format.\"\n", 183 | " },\n", 184 | " {\n", 185 | " \"role\": \"user\",\n", 186 | " \"content\": \"Please format this address: 1 Hacker Wy Menlo Park CA 94025\"\n", 187 | " }\n", 188 | "]\n", 189 | "\n", 190 | "response_format = {\n", 191 | " \"type\": \"json_schema\",\n", 192 | " \"json_schema\": {\n", 193 | " \"name\": \"Address\",\n", 194 | " \"schema\": {\n", 195 | " \"properties\": {\n", 196 | " \"address\": {\n", 197 | " \"type\": \"object\",\n", 198 | " \"properties\": {\n", 199 | " \"street\": {\"type\": \"string\"},\n", 200 | " \"city\": {\"type\": \"string\"},\n", 201 | " \"state\": {\n", 202 | " \"type\": \"string\", \n", 203 | " \"description\": \"2 letter abbreviation of the state\"\n", 204 | " },\n", 205 | " \"zip\": {\n", 206 | " \"type\": \"string\", \n", 207 | " \"description\": \"5 digit zip code\"\n", 208 | " }\n", 209 | " },\n", 210 | " \"required\": [\"street\", \"city\", \"state\", \"zip\"]\n", 211 | " }\n", 212 | " },\n", 213 | " \"required\": [\"address\"],\n", 214 | " \"type\": \"object\"\n", 215 | " }\n", 216 | " }\n", 217 | "}\n" 218 | ] 219 | }, 220 | { 221 | "cell_type": "markdown", 222 | "metadata": {}, 223 | "source": [ 224 | "### Non-streaming Structured Output Example" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": 23, 230 | "metadata": {}, 231 | "outputs": [ 232 | { 233 | "name": "stdout", 234 | "output_type": "stream", 235 | "text": [ 236 | "ChatCompletion(id='chatcmpl_1754135313793796', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='{\"address\": {\"street\": \"1 Hacker Wy\", \"city\": \"Menlo Park\", \"state\": \"CA\", 
\"zip\": \"94025\"}}', refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=None, reasoning_content=None))], created=1754135313, model='mlx-server-model', object='chat.completion', service_tier=None, system_fingerprint=None, usage=None)\n" 237 | ] 238 | } 239 | ], 240 | "source": [ 241 | "# Make the API call\n", 242 | "completion = client.chat.completions.create(\n", 243 | " model=\"mlx-server-model\",\n", 244 | " messages=messages,\n", 245 | " max_tokens = 512,\n", 246 | " response_format = response_format\n", 247 | ")\n", 248 | "\n", 249 | "# Get the result\n", 250 | "print(completion)" 251 | ] 252 | }, 253 | { 254 | "cell_type": "markdown", 255 | "metadata": {}, 256 | "source": [ 257 | "### Streaming Structured Output Example" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 25, 263 | "metadata": {}, 264 | "outputs": [ 265 | { 266 | "name": "stdout", 267 | "output_type": "stream", 268 | "text": [ 269 | "{\"address\": {\"street\": \"1 Hacker Wy\", \"city\": \"Menlo Park\", \"state\": \"CA\", \"zip\": \"94025\"}}" 270 | ] 271 | } 272 | ], 273 | "source": [ 274 | "# Make the API call\n", 275 | "completion = client.chat.completions.create(\n", 276 | " model=\"mlx-server-model\",\n", 277 | " messages=messages,\n", 278 | " max_tokens = 512,\n", 279 | " response_format = response_format,\n", 280 | " stream = True\n", 281 | ")\n", 282 | "\n", 283 | "# Process the streaming response\n", 284 | "for chunk in completion:\n", 285 | " if chunk.choices[0].delta.content:\n", 286 | " print(chunk.choices[0].delta.content, end=\"\", flush=True)" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": null, 292 | "metadata": {}, 293 | "outputs": [], 294 | "source": [] 295 | } 296 | ], 297 | "metadata": { 298 | "kernelspec": { 299 | "display_name": "testing", 300 | "language": "python", 301 | "name": "python3" 302 | }, 303 | "language_info": { 304 | "codemirror_mode": { 305 | "name": "ipython", 306 | "version": 3 307 | }, 308 | "file_extension": ".py", 309 | "mimetype": "text/x-python", 310 | "name": "python", 311 | "nbconvert_exporter": "python", 312 | "pygments_lexer": "ipython3", 313 | "version": "3.11.11" 314 | } 315 | }, 316 | "nbformat": 4, 317 | "nbformat_minor": 2 318 | } 319 | -------------------------------------------------------------------------------- /examples/transcription_examples.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "944bf441", 6 | "metadata": {}, 7 | "source": [ 8 | "# Transcription Tasks with MLX Server and OpenAI API" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "6dafc3b4", 14 | "metadata": {}, 15 | "source": [ 16 | "This notebook demonstrates how to use the MLX Server with OpenAI-compatible API for transcription tasks.\n" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "id": "6d89bacb", 22 | "metadata": {}, 23 | "source": [ 24 | "## Setup and Imports" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "id": "bfcfd46c", 30 | "metadata": {}, 31 | "source": [ 32 | "First, we'll import the necessary libraries and establish a connection to the MLX Server." 
33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 7, 38 | "id": "a74ac262", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "# Import the OpenAI client for API communication\n", 43 | "from openai import OpenAI\n", 44 | "\n", 45 | "# Connect to the local MLX Server with OpenAI-compatible API\n", 46 | "client = OpenAI(\n", 47 | " base_url=\"http://localhost:8000/v1\",\n", 48 | " api_key=\"fake-api-key\",\n", 49 | ")" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 2, 55 | "id": "d68dd370", 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "audio_path = \"audios/podcast.wav\"" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 5, 65 | "id": "7e8a4c6a", 66 | "metadata": {}, 67 | "outputs": [ 68 | { 69 | "name": "stdout", 70 | "output_type": "stream", 71 | "text": [ 72 | "Transcription(text=\" What if Tangero never had the demon slayer mark? Without the mark, Tangero's strength would have hit a ceiling, no insane speed boosts, no crazy recovery. Against upper moons, he'd be fighting on pure heart and swordsmanship alone. Imagine the red blade moment, weaker, slower and every fight, becoming a desperate struggle. Would he still beat a Kaza? Maybe, but Muzon, without that extra edge, Tangero's fate could have been completely different. And here's the twist. Tangero's biggest weapon has always been his willpower. Even without the mark, would his determination rewrite destiny? Do you think Tangero could win without the mark? Comment below.\", logprobs=None, usage={'type': 'duration', 'seconds': 72})\n" 73 | ] 74 | } 75 | ], 76 | "source": [ 77 | "with open(audio_path, \"rb\") as f:\n", 78 | " transcription = client.audio.transcriptions.create(\n", 79 | " file=f,\n", 80 | " model=\"mlx-community/whisper-tiny\",\n", 81 | " language=\"en\",\n", 82 | " response_format=\"json\",\n", 83 | " temperature=0.0,\n", 84 | " )\n", 85 | " print(transcription)" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 6, 91 | "id": "58062f32", 92 | "metadata": {}, 93 | "outputs": [ 94 | { 95 | "name": "stdout", 96 | "output_type": "stream", 97 | "text": [ 98 | "TranscriptionTextDeltaEvent(delta=None, type=None, logprobs=None, id='transcription-32cd0b68-a2f9-4240-bc9e-4a7dd1e7e17d', object='transcription.chunk', created=1759658874, model='mlx-community/whisper-tiny', choices=[{'delta': {'content': \" What if Tangero never had the demon slayer mark? Without the mark, Tangero's strength would have hit a ceiling, no insane speed boosts, no crazy recovery. Against upper moons, he'd be fighting on pure heart and swordsmanship alone. Imagine the red blade moment.\", 'function_call': None, 'refusal': None, 'role': None, 'tool_calls': None, 'reasoning_content': None}, 'finish_reason': None, 'stop_reason': None}], usage=None)\n", 99 | "TranscriptionTextDeltaEvent(delta=None, type=None, logprobs=None, id='transcription-32cd0b68-a2f9-4240-bc9e-4a7dd1e7e17d', object='transcription.chunk', created=1759658874, model='mlx-community/whisper-tiny', choices=[{'delta': {'content': \" weaker, slower, and every fight becoming a desperate struggle. Would he still beat Akaza? Maybe, but Muson, without that extra edge, Tangeros' fate could have been completely different. But here's the twist. Tangeros' biggest weapon has always been his willpower. 
Even without...\", 'function_call': None, 'refusal': None, 'role': None, 'tool_calls': None, 'reasoning_content': None}, 'finish_reason': None, 'stop_reason': None}], usage=None)\n", 100 | "TranscriptionTextDeltaEvent(delta=None, type=None, logprobs=None, id='transcription-32cd0b68-a2f9-4240-bc9e-4a7dd1e7e17d', object='transcription.chunk', created=1759658874, model='mlx-community/whisper-tiny', choices=[{'delta': {'content': ' The Mark would his determination rewrite destiny? Do you think Tangero could win without the Mark? Comment below!', 'function_call': None, 'refusal': None, 'role': None, 'tool_calls': None, 'reasoning_content': None}, 'finish_reason': None, 'stop_reason': None}], usage=None)\n", 101 | "TranscriptionTextDeltaEvent(delta=None, type=None, logprobs=None, id='transcription-32cd0b68-a2f9-4240-bc9e-4a7dd1e7e17d', object='transcription.chunk', created=1759658874, model='mlx-community/whisper-tiny', choices=[{'delta': {'content': '', 'function_call': None, 'refusal': None, 'role': None, 'tool_calls': None, 'reasoning_content': None}, 'finish_reason': 'stop', 'stop_reason': None}], usage=None)\n" 102 | ] 103 | } 104 | ], 105 | "source": [ 106 | "with open(audio_path, \"rb\") as f:\n", 107 | " stream = client.audio.transcriptions.create(\n", 108 | " file=f,\n", 109 | " model=\"mlx-community/whisper-tiny\",\n", 110 | " language=\"en\",\n", 111 | " response_format=\"json\",\n", 112 | " temperature=0.0,\n", 113 | " stream=True,\n", 114 | " )\n", 115 | " for chunk in stream:\n", 116 | " print(chunk)" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "id": "9495f63e", 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [] 126 | } 127 | ], 128 | "metadata": { 129 | "kernelspec": { 130 | "display_name": "Python 3", 131 | "language": "python", 132 | "name": "python3" 133 | }, 134 | "language_info": { 135 | "codemirror_mode": { 136 | "name": "ipython", 137 | "version": 3 138 | }, 139 | "file_extension": ".py", 140 | "mimetype": "text/x-python", 141 | "name": "python", 142 | "nbconvert_exporter": "python", 143 | "pygments_lexer": "ipython3", 144 | "version": "3.11.13" 145 | } 146 | }, 147 | "nbformat": 4, 148 | "nbformat_minor": 5 149 | } 150 | -------------------------------------------------------------------------------- /examples/videos/demo.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cubist38/mlx-openai-server/25fc00aa21fc6794358849ebfb9d866a6e203eda/examples/videos/demo.mp4 -------------------------------------------------------------------------------- /examples/vlm_embeddings_examples.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Vision-Language Model (VLM) Embeddings with MLX Server\n", 8 | "\n", 9 | "This notebook demonstrates how to leverage the embeddings endpoint of MLX Server through its OpenAI-compatible API. Vision-Language Models (VLMs) can process both images and text, allowing for multimodal understanding and representation.\n" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "\n", 17 | "## Introduction\n", 18 | "\n", 19 | "MLX Server provides an efficient way to serve multimodal models on Apple Silicon. 
In this notebook, we'll explore how to:\n", 20 | "\n", 21 | "- Generate embeddings for text and images\n", 22 | "- Work with the OpenAI-compatible API\n", 23 | "- Calculate similarity between text and image representations\n", 24 | "- Understand how these embeddings can be used for practical applications\n", 25 | "\n", 26 | "Embeddings are high-dimensional vector representations of content that capture semantic meaning, making them useful for search, recommendation systems, and other AI applications." 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "## 1. Setup and API Connection\n", 34 | "\n", 35 | "- A local server endpoint (`http://localhost:8000/v1`)\n", 36 | "- A placeholder API key (since MLX Server doesn't require authentication by default)\n", 37 | "\n", 38 | "Make sure you have MLX Server running locally before executing this notebook." 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 1, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "# Import the OpenAI client for API communication\n", 48 | "from openai import OpenAI\n", 49 | "\n", 50 | "# Connect to the local MLX Server with OpenAI-compatible API\n", 51 | "client = OpenAI(\n", 52 | " base_url=\"http://localhost:8000/v1\",\n", 53 | " api_key=\"fake-api-key\",\n", 54 | ")" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "## 2. Image Processing for API Requests\n", 62 | "\n", 63 | "When working with image inputs, we need to prepare them in a format that the API can understand. The OpenAI-compatible API expects images to be provided as base64-encoded data URIs.\n", 64 | "\n", 65 | "Below, we'll import the necessary libraries and define a helper function to convert PIL Image objects to the required format." 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 2, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "from PIL import Image\n", 75 | "from io import BytesIO\n", 76 | "import base64" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 3, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "# To send images to the API, we need to convert them to base64-encoded strings in a data URI format.\n", 86 | "\n", 87 | "def image_to_base64(image: Image.Image):\n", 88 | " \"\"\"\n", 89 | " Convert a PIL Image to a base64-encoded data URI string that can be sent to the API.\n", 90 | " \n", 91 | " Args:\n", 92 | " image: A PIL Image object\n", 93 | " \n", 94 | " Returns:\n", 95 | " A data URI string with the base64-encoded image\n", 96 | " \"\"\"\n", 97 | " # Convert image to bytes\n", 98 | " buffer = BytesIO()\n", 99 | " image.save(buffer, format=\"PNG\")\n", 100 | " buffer.seek(0)\n", 101 | " image_data = buffer.getvalue()\n", 102 | " \n", 103 | " # Encode as base64\n", 104 | " image_base64 = base64.b64encode(image_data).decode('utf-8')\n", 105 | " \n", 106 | " # Create the data URI format required by the API\n", 107 | " mime_type = \"image/png\" \n", 108 | " image_uri = f\"data:{mime_type};base64,{image_base64}\"\n", 109 | " \n", 110 | " return image_uri" 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "metadata": {}, 116 | "source": [ 117 | "## 3. Loading and Preparing an Image\n", 118 | "\n", 119 | "Now we'll load a sample image (a green dog in this case) and convert it to the base64 format required by the API. This image will be used to generate embeddings in the subsequent steps." 
120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 5, 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "image = Image.open(\"images/green_dog.jpeg\")\n", 129 | "image_uri = image_to_base64(image)" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": {}, 135 | "source": [ 136 | "## 4. Generating Embeddings" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 12, 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "# Generate embedding for a single text input\n", 146 | "prompt = \"Describe the image in detail\"\n", 147 | "image_embedding = client.embeddings.create(\n", 148 | " input=[prompt],\n", 149 | " model=\"mlx-community/Qwen2.5-VL-3B-Instruct-4bit\",\n", 150 | " extra_body = {\n", 151 | " \"image_url\": image_uri\n", 152 | " }\n", 153 | ").data[0].embedding\n", 154 | "\n", 155 | "text = \"A green dog looking at the camera\"\n", 156 | "text_embedding = client.embeddings.create(\n", 157 | " input=[text],\n", 158 | " model=\"mlx-community/Qwen2.5-VL-3B-Instruct-4bit\"\n", 159 | ").data[0].embedding" 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "metadata": {}, 165 | "source": [ 166 | "## 5. Comparing Text and Image Embeddings\n", 167 | "\n", 168 | "One of the powerful features of VLM embeddings is that they create a shared vector space for both text and images. This means we can directly compare how similar a text description is to an image's content by calculating the cosine similarity between their embeddings.\n", 169 | "\n", 170 | "A higher similarity score (closer to 1.0) indicates that the text description closely matches the image content according to the model's understanding." 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 13, 176 | "metadata": {}, 177 | "outputs": [ 178 | { 179 | "name": "stdout", 180 | "output_type": "stream", 181 | "text": [ 182 | "0.8473370724651375\n" 183 | ] 184 | } 185 | ], 186 | "source": [ 187 | "import numpy as np\n", 188 | "\n", 189 | "def cosine_similarity(a, b):\n", 190 | " a = np.array(a)\n", 191 | " b = np.array(b)\n", 192 | " return np.dot(a, b)\n", 193 | "\n", 194 | "similarity = cosine_similarity(image_embedding, text_embedding)\n", 195 | "print(similarity)" 196 | ] 197 | } 198 | ], 199 | "metadata": { 200 | "kernelspec": { 201 | "display_name": "Python 3", 202 | "language": "python", 203 | "name": "python3" 204 | }, 205 | "language_info": { 206 | "codemirror_mode": { 207 | "name": "ipython", 208 | "version": 3 209 | }, 210 | "file_extension": ".py", 211 | "mimetype": "text/x-python", 212 | "name": "python", 213 | "nbconvert_exporter": "python", 214 | "pygments_lexer": "ipython3", 215 | "version": "3.11.12" 216 | } 217 | }, 218 | "nbformat": 4, 219 | "nbformat_minor": 2 220 | } 221 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from app import __version__ 3 | 4 | 5 | setup( 6 | name="mlx-openai-server", 7 | url="https://github.com/cubist38/mlx-openai-server", 8 | author="Gia-Huy Vuong", 9 | author_email="cubist38@gmail.com", 10 | version=__version__, 11 | description="A high-performance API server that provides OpenAI-compatible endpoints for MLX models. Built with Python and FastAPI, it enables efficient, scalable, and user-friendly local deployment of MLX-based multimodal models with an OpenAI-compatible interface. 
Supports text, vision, and audio processing capabilities. Perfect for developers looking to run MLX models locally while maintaining compatibility with existing OpenAI-based applications.", 12 | long_description=open("README.md").read(), 13 | long_description_content_type="text/markdown", 14 | packages=find_packages(), 15 | install_requires=[ 16 | "mlx-vlm==0.3.4", 17 | "mlx-lm==0.28.3", 18 | "torchvision==0.23.0", 19 | "mlx-whisper==0.4.3", 20 | "mlx-embeddings==0.0.4", 21 | "fastapi==0.115.14", 22 | "av==16.0.1", 23 | "uvicorn==0.35.0", 24 | "Pillow==10.4.0", 25 | "click==8.2.1", 26 | "loguru==0.7.3", 27 | "outlines==1.1.1", 28 | "librosa==0.11.0", 29 | "openai-harmony==0.0.4", 30 | "json_repair==0.52.1", 31 | "python-multipart==0.0.20" 32 | ], 33 | extras_require={ 34 | "dev": [ 35 | "pytest", 36 | "black", 37 | "isort", 38 | "flake8", 39 | ] 40 | }, 41 | entry_points={ 42 | "console_scripts": [ 43 | "mlx-openai-server=app.cli:cli", 44 | ], 45 | }, 46 | python_requires=">=3.11", 47 | classifiers=[ 48 | "Programming Language :: Python :: 3", 49 | "License :: OSI Approved :: MIT License", 50 | "Operating System :: OS Independent", 51 | ], 52 | ) -------------------------------------------------------------------------------- /tests/test_base_tool_parser.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from app.handler.parser.base import BaseToolParser, ParseState 4 | 5 | 6 | class TestBaseToolParser(unittest.TestCase): 7 | def setUp(self): 8 | self.test_cases = [ 9 | { 10 | "name": "simple function call", 11 | "chunks": '''## 12 | #{"#name#":# "#get#_weather#",# "#arguments#":# {"#city#":# "#H#ue#"}} 13 | ## 14 | ## 15 | #{"#name#":# "#get#_weather#",# "#arguments#":# {"#city#":# "#Sy#dney#"}} 16 | ###'''.split('#') 17 | , 18 | "expected_outputs": [ 19 | {'name': 'get_weather', 'arguments': ''}, 20 | {'name': None, 'arguments': ' {"'}, 21 | {'name': None, 'arguments': 'city'}, 22 | {'name': None, 'arguments': '":'}, 23 | {'name': None, 'arguments': ' "'}, 24 | {'name': None, 'arguments': 'H'}, 25 | {'name': None, 'arguments': 'ue'}, 26 | {'name': None, 'arguments': '"}'}, 27 | '\n', 28 | {'name': 'get_weather', 'arguments': ''}, 29 | {'name': None, 'arguments': ' {"'}, 30 | {'name': None, 'arguments': 'city'}, 31 | {'name': None, 'arguments': '":'}, 32 | {'name': None, 'arguments': ' "'}, 33 | {'name': None, 'arguments': 'Sy'}, 34 | {'name': None, 'arguments': 'dney'}, 35 | {'name': None, 'arguments': '"}'}, 36 | ] 37 | }, 38 | { 39 | "name": "code function call", 40 | "chunks": r'''@@ 41 | @@{"@@name@@":@@ "@@python@@",@@ "@@arguments@@":@@ {"@@code@@":@@ "@@def@@ calculator@@(a@@,@@ b@@,@@ operation@@):\@@n@@ @@ if@@ operation@@ ==@@ '@@add@@'\@@n@@ @@ return@@ a@@ +@@ b@@\n@@ @@ if@@ operation@@ ==@@ '@@subtract@@'\@@n@@ @@ return@@ a@@ -@@ b@@\n@@ @@ if@@ operation@@ ==@@ '@@multiply@@'\@@n@@ @@ return@@ a@@ *@@ b@@\n@@ @@ if@@ operation@@ ==@@ '@@divide@@'\@@n@@ @@ return@@ a@@ /@@ b@@\n@@ @@ return@@ '@@Invalid@@ operation@@'@@"}} 42 | @@@@@@'''.split('@@') 43 | , 44 | "expected_outputs": [ 45 | {'name': 'python', 'arguments': ''}, 46 | {'name': None, 'arguments': ' {"'}, 47 | {'name': None, 'arguments': 'code'}, 48 | {'name': None, 'arguments': '":'}, 49 | {'name': None, 'arguments': ' "'}, 50 | {'name': None, 'arguments': 'def'}, 51 | {'name': None, 'arguments': ' calculator'}, 52 | {'name': None, 'arguments': '(a'}, 53 | {'name': None, 'arguments': ','}, 54 | {'name': None, 'arguments': ' b'}, 55 | {'name': None, 
'arguments': ','}, 56 | {'name': None, 'arguments': ' operation'}, 57 | {'name': None, 'arguments': '):\\'}, 58 | {'name': None, 'arguments': 'n'}, 59 | {'name': None, 'arguments': ' '}, 60 | {'name': None, 'arguments': ' if'}, 61 | {'name': None, 'arguments': ' operation'}, 62 | {'name': None, 'arguments': ' =='}, 63 | {'name': None, 'arguments': " '"}, 64 | {'name': None, 'arguments': 'add'}, 65 | {'name': None, 'arguments': "'\\"}, 66 | {'name': None, 'arguments': 'n'}, 67 | {'name': None, 'arguments': ' '}, 68 | {'name': None, 'arguments': ' return'}, 69 | {'name': None, 'arguments': ' a'}, 70 | {'name': None, 'arguments': ' +'}, 71 | {'name': None, 'arguments': ' b'}, 72 | {'name': None, 'arguments': '\\n'}, 73 | {'name': None, 'arguments': ' '}, 74 | {'name': None, 'arguments': ' if'}, 75 | {'name': None, 'arguments': ' operation'}, 76 | {'name': None, 'arguments': ' =='}, 77 | {'name': None, 'arguments': " '"}, 78 | {'name': None, 'arguments': 'subtract'}, 79 | {'name': None, 'arguments': "'\\"}, 80 | {'name': None, 'arguments': 'n'}, 81 | {'name': None, 'arguments': ' '}, 82 | {'name': None, 'arguments': ' return'}, 83 | {'name': None, 'arguments': ' a'}, 84 | {'name': None, 'arguments': ' -'}, 85 | {'name': None, 'arguments': ' b'}, 86 | {'name': None, 'arguments': '\\n'}, 87 | {'name': None, 'arguments': ' '}, 88 | {'name': None, 'arguments': ' if'}, 89 | {'name': None, 'arguments': ' operation'}, 90 | {'name': None, 'arguments': ' =='}, 91 | {'name': None, 'arguments': " '"}, 92 | {'name': None, 'arguments': 'multiply'}, 93 | {'name': None, 'arguments': "'\\"}, 94 | {'name': None, 'arguments': 'n'}, 95 | {'name': None, 'arguments': ' '}, 96 | {'name': None, 'arguments': ' return'}, 97 | {'name': None, 'arguments': ' a'}, 98 | {'name': None, 'arguments': ' *'}, 99 | {'name': None, 'arguments': ' b'}, 100 | {'name': None, 'arguments': '\\n'}, 101 | {'name': None, 'arguments': ' '}, 102 | {'name': None, 'arguments': ' if'}, 103 | {'name': None, 'arguments': ' operation'}, 104 | {'name': None, 'arguments': ' =='}, 105 | {'name': None, 'arguments': " '"}, 106 | {'name': None, 'arguments': 'divide'}, 107 | {'name': None, 'arguments': "'\\"}, 108 | {'name': None, 'arguments': 'n'}, 109 | {'name': None, 'arguments': ' '}, 110 | {'name': None, 'arguments': ' return'}, 111 | {'name': None, 'arguments': ' a'}, 112 | {'name': None, 'arguments': ' /'}, 113 | {'name': None, 'arguments': ' b'}, 114 | {'name': None, 'arguments': '\\n'}, 115 | {'name': None, 'arguments': ' '}, 116 | {'name': None, 'arguments': ' return'}, 117 | {'name': None, 'arguments': " '"}, 118 | {'name': None, 'arguments': 'Invalid'}, 119 | {'name': None, 'arguments': ' operation'}, 120 | {'name': None, 'arguments': "'"}, 121 | {'name': None, 'arguments': '"}'}, 122 | ] 123 | }, 124 | ] 125 | 126 | def test_parse_stream(self): 127 | for test_case in self.test_cases: 128 | with self.subTest(msg=test_case["name"]): 129 | parser = BaseToolParser("", "") 130 | outputs = [] 131 | 132 | for chunk in test_case["chunks"]: 133 | result = parser.parse_stream(chunk) 134 | if result: 135 | outputs.append(result) 136 | 137 | 138 | self.assertEqual(len(outputs), len(test_case["expected_outputs"]), 139 | f"Expected {len(test_case['expected_outputs'])} outputs, got {len(outputs)}") 140 | 141 | for i, (output, expected) in enumerate(zip(outputs, test_case["expected_outputs"])): 142 | self.assertEqual(output, expected, 143 | f"Chunk {i}: Expected {expected}, got {output}") 144 | 145 | if __name__ == '__main__': 146 | unittest.main() 147 | 
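Note: the fragment-by-fragment outputs asserted in this test are the same shape that API clients receive as streamed ChatCompletionChunk deltas (compare the streaming tool-call output in the structured outputs notebook above): the first delta of a call carries the id and function name, and later deltas carry argument fragments. As a minimal client-side sketch — not part of this repository, and with accumulate_tool_calls as a hypothetical helper name — the fragments can be reassembled into complete tool calls roughly like this:

import json

def accumulate_tool_calls(stream):
    # Reassemble streamed tool-call deltas into complete calls.
    # Assumes OpenAI-style chunks as printed above: the first delta of a
    # call carries `id` and `function.name`; later deltas append pieces
    # of the JSON string in `function.arguments`.
    calls = {}  # tool-call index -> {"id", "name", "arguments"}
    for chunk in stream:
        delta = chunk.choices[0].delta
        for tc in delta.tool_calls or []:
            entry = calls.setdefault(tc.index, {"id": None, "name": None, "arguments": ""})
            if tc.id:
                entry["id"] = tc.id
            if tc.function and tc.function.name:
                entry["name"] = tc.function.name
            if tc.function and tc.function.arguments:
                entry["arguments"] += tc.function.arguments
    # Parse the accumulated JSON argument strings once the stream is done.
    return [
        {"id": c["id"], "name": c["name"], "arguments": json.loads(c["arguments"])}
        for _, c in sorted(calls.items())
    ]

# Usage (hypothetical):
#   stream = client.chat.completions.create(..., tools=tools, stream=True)
#   for call in accumulate_tool_calls(stream):
#       print(call["name"], call["arguments"])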
--------------------------------------------------------------------------------
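A closing note on the VLM embeddings example above: its cosine_similarity helper returns np.dot(a, b) without normalizing, which equals a true cosine similarity only if the server returns L2-normalized embedding vectors — an assumption that may not hold for every model. A defensive variant that normalizes explicitly, as a sketch:

import numpy as np

def cosine_similarity(a, b):
    # Normalize explicitly so the score stays in [-1, 1] even when the
    # embedding vectors are not unit-length.
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))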