├── .gitignore ├── ruff.toml ├── demo-data ├── toucan.jpeg └── chameleon.webp ├── tests ├── __init__.py ├── processors │ ├── __init__.py │ ├── README.md │ └── dump_logits_processor.py ├── conftest.py ├── shared.py ├── test_cache_wrapper.py ├── utils │ └── test_progress_decorators.py ├── data │ └── ben_franklin_autobiography_start.txt └── test_stop_string_processor.py ├── mlx_engine ├── external │ ├── models │ │ ├── ernie4_5 │ │ │ ├── README.md │ │ │ ├── configuration_ernie4_5.py │ │ │ └── tokenization_ernie4_5.py │ │ ├── ernie4_5_moe │ │ │ ├── README.md │ │ │ └── configuration_ernie4_5_moe.py │ │ └── lfm2_vl │ │ │ ├── README.md │ │ │ └── configuration_lfm2_vl.py │ └── datasets │ │ └── dill.py ├── utils │ ├── token.py │ ├── outlines_transformer_tokenizer.py │ ├── top_logprobs.py │ ├── disable_hf_download.py │ ├── logger.py │ ├── register_models.py │ ├── prompt_processing.py │ ├── set_seed.py │ ├── speculative_decoding.py │ ├── kv_cache_quantization.py │ ├── eot_tokens.py │ ├── fix_mistral_pre_tokenizer.py │ ├── progress_decorators.py │ └── image_utils.py ├── model_kit │ ├── __init__.py │ ├── patches │ │ ├── ernie_4_5.py │ │ └── gemma3n.py │ ├── vision_add_ons │ │ ├── base.py │ │ ├── qwen3_vl.py │ │ ├── qwen3_vl_moe.py │ │ ├── process_prompt_with_images.py │ │ ├── qwen2_vl.py │ │ ├── gemma3.py │ │ ├── qwen_vl_utils.py │ │ ├── pixtral.py │ │ ├── mistral3.py │ │ ├── gemma3n.py │ │ ├── lfm2_vl.py │ │ └── load_utils.py │ └── model_kit.py ├── __init__.py ├── processors │ └── repetition_penalty_processor.py ├── vision_model_kit │ ├── _transformers_compatibility.py │ ├── vision_model_kit.py │ └── vision_model_wrapper.py └── stop_string_processor.py ├── .pre-commit-config.yaml ├── .github └── workflows │ ├── ruff.yml │ └── cla.yml ├── LICENSE ├── requirements.txt ├── CONTRIBUTING.md ├── README.md └── demo.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .venv*/ 3 | .DS_Store 4 | .idea 5 | -------------------------------------------------------------------------------- /ruff.toml: -------------------------------------------------------------------------------- 1 | [lint.per-file-ignores] 2 | "mlx_engine/__init__.py" = ["E402"] 3 | -------------------------------------------------------------------------------- /demo-data/toucan.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmstudio-ai/mlx-engine/HEAD/demo-data/toucan.jpeg -------------------------------------------------------------------------------- /demo-data/chameleon.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmstudio-ai/mlx-engine/HEAD/demo-data/chameleon.webp -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Trigger syntax validation of processors 2 | from . 
import processors # noqa: F401 3 | -------------------------------------------------------------------------------- /tests/processors/__init__.py: -------------------------------------------------------------------------------- 1 | # Import for syntax and dependency checking only 2 | from .dump_logits_processor import DumpLogitsProcessor # noqa: F401 3 | -------------------------------------------------------------------------------- /mlx_engine/external/models/ernie4_5/README.md: -------------------------------------------------------------------------------- 1 | These files were retrieved from https://huggingface.co/mlx-community/ERNIE-4.5-0.3B-PT-4bit/tree/8af9e5fc3fe9f3c44ec933978869f309454d2238 -------------------------------------------------------------------------------- /mlx_engine/external/models/ernie4_5_moe/README.md: -------------------------------------------------------------------------------- 1 | These files were retrieved from https://huggingface.co/mlx-community/ERNIE-4.5-21B-A3B-PT-4bit/tree/6a55392c9930d43543bfa0a8ac6b67985bf4bd1d 2 | -------------------------------------------------------------------------------- /mlx_engine/utils/token.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | 5 | @dataclass 6 | class Token: 7 | """ 8 | Base dataclass for a single generated token. 9 | """ 10 | 11 | id: int 12 | text: str 13 | logprob: float 14 | from_draft: Optional[bool] = None 15 | -------------------------------------------------------------------------------- /mlx_engine/external/models/lfm2_vl/README.md: -------------------------------------------------------------------------------- 1 | These files were retrieved from https://huggingface.co/LiquidAI/LFM2-VL-1.6B/tree/95bd1b5ff38beb09619b894f8b6882a0c66eac2c. 2 | 3 | `configuration_lfm2_vl.py` is actually a custom file formed by ripping out all modeling code except the configuration class 4 | from [`modeling_lfm2_vl.py`](https://huggingface.co/LiquidAI/LFM2-VL-1.6B/blob/95bd1b5ff38beb09619b894f8b6882a0c66eac2c/modeling_lfm2_vl.py). 5 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | # Ruff version. 4 | rev: v0.11.0 5 | hooks: 6 | # Run the linter. 7 | - id: ruff 8 | args: [ --fix ] 9 | exclude: ^mlx_engine/external/ 10 | stages: [pre-commit] 11 | # Run the formatter. 12 | - id: ruff-format 13 | args: [--target-version, py311 ] 14 | exclude: ^mlx_engine/external/ 15 | stages: [pre-commit] 16 | 17 | -------------------------------------------------------------------------------- /mlx_engine/model_kit/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model Kit module with automatic compatibility patches. 3 | 4 | This module automatically applies compatibility patches for various model types 5 | by replacing classes in their respective modules with derived, compatible versions. 
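Each patch module exposes an apply_patches() function that swaps a derived, compatible class
into the upstream mlx-lm module. Condensed illustration of the pattern (the full version lives
in patches/gemma3n.py; the sanitize body is elided here):

    import mlx_lm.models.gemma3n
    from mlx_lm.models.gemma3n import Model

    class CompatibleModel(Model):
        def sanitize(self, weights):
            # remap mlx-vlm-style weight keys, then defer to the original implementation
            ...
            return super().sanitize(weights)

    def apply_patches():
        mlx_lm.models.gemma3n.Model = CompatibleModel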
6 | """ 7 | 8 | from .patches.gemma3n import apply_patches as _apply_patches_gemma3n 9 | from .patches.ernie_4_5 import apply_patches as _apply_patches_ernie_4_5 10 | 11 | _apply_patches_gemma3n() 12 | _apply_patches_ernie_4_5() 13 | -------------------------------------------------------------------------------- /mlx_engine/utils/outlines_transformer_tokenizer.py: -------------------------------------------------------------------------------- 1 | from outlines.models.transformers import TransformerTokenizer 2 | from mlx_engine.external.datasets.dill import Hasher 3 | 4 | 5 | class OutlinesTransformerTokenizer(TransformerTokenizer): 6 | """ 7 | Update the outlines TransformerTokenizer to use our own Hasher class, so that we don't need the datasets dependency 8 | 9 | This class and the external dependency can be removed when the following import is deleted 10 | https://github.com/dottxt-ai/outlines/blob/69418d/outlines/models/transformers.py#L117 11 | """ 12 | 13 | def __hash__(self): 14 | return hash(Hasher.hash(self.tokenizer)) 15 | -------------------------------------------------------------------------------- /mlx_engine/model_kit/patches/ernie_4_5.py: -------------------------------------------------------------------------------- 1 | """ " 2 | Patch outlines_core to handler ERNIE tokenizer. 3 | Specifically, fix the handling of these tokens: 4 | - `>�` 5 | - `�@` 6 | - `�@�@` 7 | 8 | An issue is opened in outlines_core tracking this issue: 9 | https://github.com/dottxt-ai/outlines-core/issues/222 10 | """ 11 | 12 | import re 13 | import outlines_core.fsm.regex 14 | 15 | 16 | def apply_patches(): 17 | """ 18 | Apply patches to the outlines_core module. 19 | """ 20 | # Update the replacement regex to fix the ernie tokenizer 21 | # Patching this line https://github.com/dottxt-ai/outlines-core/blob/0.1.26/python/outlines_core/fsm/regex.py#L349 22 | outlines_core.fsm.regex.re_replacement_seq = re.compile(r"^▁*\.*>*�+\.*s*@*(�@)*$") 23 | -------------------------------------------------------------------------------- /.github/workflows/ruff.yml: -------------------------------------------------------------------------------- 1 | name: Lint Checks with Ruff 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - '**' 7 | 8 | 9 | jobs: 10 | pre-commit-check: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - name: Checkout code 15 | uses: actions/checkout@v4 16 | 17 | - name: Set up Python 18 | uses: actions/setup-python@v5 19 | with: 20 | python-version: '3.11' 21 | cache: "pip" 22 | 23 | - name: Install pre-commit 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install pre-commit 27 | 28 | - name: Cache pre-commit 29 | uses: actions/cache@v4 30 | with: 31 | path: ~/.cache/pre-commit/ 32 | key: pre-commit|${{ env.pythonLocation }}|${{ hashFiles('.pre-commit-config.yaml') }} 33 | 34 | - name: Run pre-commit 35 | run: pre-commit run --show-diff-on-failure --color=always --all-files 36 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | def pytest_addoption(parser): 5 | """Add command line option for heavy tests.""" 6 | parser.addoption( 7 | "--heavy", 8 | action="store_true", 9 | default=False, 10 | help="run heavy tests (e.g., tests that require large models or long execution time)", 11 | ) 12 | 13 | 14 | def pytest_configure(config): 15 | """Configure pytest markers.""" 16 | config.addinivalue_line( 17 | "markers", 
"heavy: mark test as heavy (requires --heavy option to run)" 18 | ) 19 | 20 | 21 | def pytest_collection_modifyitems(config, items): 22 | """Skip heavy tests unless --heavy option is provided.""" 23 | if config.getoption("--heavy"): 24 | # --heavy given in cli: do not skip heavy tests 25 | return 26 | 27 | skip_heavy = pytest.mark.skip(reason="need --heavy option to run") 28 | for item in items: 29 | if "heavy" in item.keywords: 30 | item.add_marker(skip_heavy) 31 | -------------------------------------------------------------------------------- /mlx_engine/utils/top_logprobs.py: -------------------------------------------------------------------------------- 1 | from mlx_engine.utils.token import Token 2 | 3 | import mlx.core as mx 4 | 5 | 6 | def summarize_top_logprobs( 7 | tokenizer, logprobs: mx.array, top_logprobs: int 8 | ) -> list[Token]: 9 | # find the sorted indices (in descending order) of the logprobs 10 | sorted_indices = mx.argsort(-logprobs) 11 | 12 | # sort the logprobs in descending order 13 | sorted_logprobs = logprobs[..., sorted_indices] 14 | 15 | # slice the top logprobs 16 | top_indices = sorted_indices[:top_logprobs] 17 | top_logprobs = sorted_logprobs[:top_logprobs] 18 | 19 | # decode the top indices 20 | text_list = [tokenizer.decode(index) for index in top_indices.tolist()] 21 | 22 | # return list of TokenLogprob with id (int), text (str), and logprob (float) 23 | return [ 24 | Token(id=int(idx), text=txt, logprob=float(prob)) 25 | for idx, txt, prob in zip( 26 | top_indices.tolist(), text_list, top_logprobs.tolist() 27 | ) 28 | ] 29 | -------------------------------------------------------------------------------- /mlx_engine/utils/disable_hf_download.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | import sys 3 | import huggingface_hub 4 | 5 | # Store the original function before we patch anything 6 | _original_snapshot_download = huggingface_hub.snapshot_download 7 | 8 | 9 | @wraps(_original_snapshot_download) 10 | def snapshot_download(*args, **kwargs): 11 | """ 12 | Wrapper around huggingface_hub.snapshot_download that disables it 13 | """ 14 | raise RuntimeError( 15 | "Internal error: Cannot proceed without downloading from huggingface. Please report this error to the LM Studio team." 16 | ) 17 | 18 | 19 | def patch_huggingface_hub(): 20 | """ 21 | Patch the huggingface_hub module to use our local-only snapshot_download. 22 | This ensures that any import of snapshot_download from huggingface_hub 23 | will use our wrapped version. 24 | """ 25 | huggingface_hub.snapshot_download = snapshot_download 26 | # Also patch the module in sys.modules to ensure any other imports get our version 27 | sys.modules["huggingface_hub"].snapshot_download = snapshot_download 28 | -------------------------------------------------------------------------------- /mlx_engine/model_kit/vision_add_ons/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import mlx.core as mx 4 | from mlx import nn 5 | 6 | 7 | class BaseVisionAddOn: 8 | """ 9 | Base class that defines the interface for a VisionAddOn. 10 | """ 11 | 12 | @abstractmethod 13 | def __init__(self): 14 | """ 15 | Where load of vision model components is intended to occur. 
16 | """ 17 | 18 | @abstractmethod 19 | def compute_embeddings( 20 | self, 21 | text_model: nn.Module, 22 | prompt_tokens: mx.array, 23 | images_b64: list[str], 24 | max_size: tuple[int, int] | None, 25 | ) -> tuple[mx.array, mx.array]: 26 | """ 27 | Returns input ids and input embeddings for the language model after text/image merging of the prompt. 28 | 29 | Args: 30 | text_model: Text model for embedding tokens 31 | prompt_tokens: Input prompt tokens 32 | images_b64: List of base64-encoded images 33 | max_size: Maximum image size as (width, height) tuple. If None, no resizing. 34 | """ 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 LM Studio 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /mlx_engine/utils/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Basic logging setup for mlx_engine. 3 | 4 | This module configures standard library logging to output to stderr. 5 | Individual modules should get their own loggers using logging.getLogger(__name__). 
6 | """ 7 | 8 | import logging 9 | import sys 10 | 11 | 12 | def setup_logging(): 13 | """Setup basic logging configuration for mlx_engine.""" 14 | # Silence exceptions that happen within the logger 15 | logging.raiseExceptions = False 16 | 17 | # Configure root logger for mlx_engine 18 | logger = logging.getLogger("mlx_engine") 19 | logger.setLevel(logging.INFO) 20 | 21 | # Remove any existing handlers 22 | logger.handlers.clear() 23 | 24 | # Create handler that writes to stderr 25 | handler = logging.StreamHandler(sys.stderr) 26 | handler.setLevel(logging.INFO) 27 | 28 | # Simple formatter with logger name and level 29 | formatter = logging.Formatter("[%(module)s][%(levelname)s]: %(message)s") 30 | handler.setFormatter(formatter) 31 | 32 | logger.addHandler(handler) 33 | 34 | # Prevent propagation to root logger 35 | logger.propagate = False 36 | 37 | return logger 38 | -------------------------------------------------------------------------------- /mlx_engine/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | `mlx_engine` is LM Studio's LLM inferencing engine for Apple MLX 3 | """ 4 | 5 | __all__ = [ 6 | "load_model", 7 | "load_draft_model", 8 | "is_draft_model_compatible", 9 | "unload_draft_model", 10 | "create_generator", 11 | "tokenize", 12 | ] 13 | 14 | from pathlib import Path 15 | import os 16 | 17 | from .utils.disable_hf_download import patch_huggingface_hub 18 | from .utils.register_models import register_models 19 | from .utils.logger import setup_logging 20 | 21 | 22 | from .generate import ( 23 | load_model, 24 | load_draft_model, 25 | is_draft_model_compatible, 26 | unload_draft_model, 27 | create_generator, 28 | tokenize, 29 | ) 30 | 31 | patch_huggingface_hub() 32 | register_models() 33 | setup_logging() 34 | 35 | 36 | def _set_outlines_cache_dir(cache_dir: Path | str): 37 | """ 38 | Set the cache dir for Outlines. 39 | 40 | Outlines reads the OUTLINES_CACHE_DIR environment variable to 41 | determine where to read/write its cache files 42 | """ 43 | cache_dir = Path(cache_dir).expanduser().resolve() 44 | os.environ["OUTLINES_CACHE_DIR"] = str(cache_dir) 45 | 46 | 47 | _set_outlines_cache_dir(Path("~/.cache/lm-studio/.internal/outlines")) 48 | -------------------------------------------------------------------------------- /tests/processors/README.md: -------------------------------------------------------------------------------- 1 | This directory contains processors that are useful during testing, but do not have a prod use-case. They can easily be inserted into the generate flow during development by modifying `mlx_engine/generate.py`. 
2 | 3 | For example, we can add a `DumpLogitsProcessor` that writes the logits on each generated token as a CSV to the `logits-dump` directory: 4 | 5 | ```diff 6 | --- a/mlx_engine/generate.py 7 | +++ b/mlx_engine/generate.py 8 | @@ -12,6 +12,9 @@ from mlx_engine.processors.outlines_logits_processor import OutlinesLogitsProces 9 | from mlx_engine.processors.repetition_penalty_processor import ( 10 | RepetitionPenaltyProcessor, 11 | ) 12 | +from tests.processors.dump_logits_processor import ( 13 | + DumpLogitsProcessor, 14 | +) 15 | from mlx_engine.utils.token import Token 16 | from mlx_engine.utils.eot_tokens import get_eot_token_ids 17 | from mlx_engine.utils.top_logprobs import summarize_top_logprobs 18 | @@ -236,6 +239,9 @@ def create_generator( 19 | token_history=cached_tokens, **repetition_penalty_kwargs 20 | ) 21 | ) 22 | + generate_args["logits_processors"].append( 23 | + DumpLogitsProcessor(model_kit.tokenizer.vocab, Path("logits-dump")) 24 | + ) 25 | 26 | # Set up sampler 27 | generate_args["sampler"] = make_sampler( 28 | ``` 29 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | airportsdata==20250909 2 | annotated-doc==0.0.4 3 | annotated-types==0.7.0 4 | anyio==4.12.0 5 | certifi==2025.11.12 6 | charset-normalizer==3.4.4 7 | click==8.3.1 8 | cloudpickle==3.1.2 9 | dill==0.4.0 10 | diskcache==5.6.3 11 | fastapi==0.124.2 12 | genson==1.3.0 13 | h11==0.16.0 14 | hf-xet==1.2.0 15 | httpcore==1.0.9 16 | httpx==0.28.1 17 | huggingface-hub==1.2.2 18 | iniconfig==2.3.0 19 | interegular==0.3.3 20 | iso3166==2.1.1 21 | jsonpath-ng==1.7.0 22 | jsonschema==4.25.1 23 | jsonschema-specifications==2025.9.1 24 | lark==1.3.1 25 | mlx==0.30.0 26 | mlx-lm @ git+https://github.com/ml-explore/mlx-lm.git@f3ed856 27 | mlx-metal==0.30.0 28 | mlx-vlm @ git+https://github.com/Blaizzy/mlx-vlm.git@1d8622b0 29 | nest-asyncio==1.6.0 30 | outlines==1.1.1 31 | outlines-core==0.1.26 32 | packaging==25.0 33 | pillow==12.0.0 34 | pluggy==1.6.0 35 | ply==3.11 36 | protobuf==6.33.2 37 | pydantic==2.12.5 38 | pydantic-core==2.41.5 39 | pygments==2.19.2 40 | pytest==9.0.2 41 | pyyaml==6.0.3 42 | referencing==0.37.0 43 | regex==2025.11.3 44 | requests==2.32.5 45 | rpds-py==0.30.0 46 | safetensors==0.7.0 47 | sentencepiece==0.2.1 48 | shellingham==1.5.4 49 | starlette==0.50.0 50 | tokenizers==0.22.1 51 | torchvision==0.24.0 52 | tqdm==4.67.1 53 | transformers==5.0.0rc1 54 | typer-slim==0.20.0 55 | typing-inspection==0.4.2 56 | urllib3==2.6.2 57 | xxhash==3.6.0 58 | -------------------------------------------------------------------------------- /mlx_engine/utils/register_models.py: -------------------------------------------------------------------------------- 1 | """Register local model-specific code to bypass enabling `trust_remote_code`.""" 2 | 3 | from transformers import AutoTokenizer, AutoProcessor 4 | import transformers.models.auto.processing_auto as processing_auto 5 | from mlx_engine.external.models.ernie4_5.configuration_ernie4_5 import Ernie4_5_Config 6 | from mlx_engine.external.models.ernie4_5_moe.configuration_ernie4_5_moe import ( 7 | Ernie4_5_MoeConfig, 8 | ) 9 | from mlx_engine.external.models.ernie4_5.tokenization_ernie4_5 import Ernie4_5_Tokenizer 10 | from mlx_engine.external.models.lfm2_vl.processing_lfm2_vl import Lfm2VlProcessor 11 | from mlx_engine.external.models.lfm2_vl.configuration_lfm2_vl import Lfm2VlConfig 12 | 13 | 14 | def register_models(): 15 | # 
exist_ok=True should be an indication that we should remove external code 16 | # ref https://github.com/lmstudio-ai/mlx-engine/issues/211 17 | AutoTokenizer.register(Ernie4_5_Config, Ernie4_5_Tokenizer, exist_ok=True) 18 | AutoTokenizer.register(Ernie4_5_MoeConfig, Ernie4_5_Tokenizer, exist_ok=True) 19 | 20 | # mlx-vlm is not compatible with the transformers version of lfm2 21 | # See https://github.com/lmstudio-ai/mlx-engine/issues/211#issuecomment-3397933488 22 | del processing_auto.PROCESSOR_MAPPING_NAMES["lfm2_vl"] 23 | AutoProcessor.register(Lfm2VlConfig, Lfm2VlProcessor, exist_ok=False) 24 | -------------------------------------------------------------------------------- /mlx_engine/utils/prompt_processing.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Callable 2 | 3 | from mlx import nn 4 | import mlx.core as mx 5 | 6 | from mlx_engine.cache_wrapper import CacheWrapper 7 | 8 | 9 | def process_prompt_text_only( 10 | prompt_tokens: mx.array, 11 | cache_wrapper: CacheWrapper, 12 | generate_args: dict, 13 | draft_model: Optional[nn.Module], 14 | speculative_decoding_toggle: Optional[bool], 15 | prompt_progress_callback: Optional[Callable[[float], bool]], 16 | ): 17 | if cache_wrapper is None: 18 | raise ValueError("Cache wrapper is not initialized, cannot process prompt") 19 | # Make sure cache's draft model setting aligns with speculative decoding toggle 20 | should_use_draft_model = ( 21 | speculative_decoding_toggle 22 | if speculative_decoding_toggle is not None 23 | else draft_model is not None 24 | ) 25 | if should_use_draft_model: 26 | if not draft_model: 27 | raise ValueError( 28 | "Speculative decoding toggle is enabled for prompt processing but no " 29 | "draft model is loaded" 30 | ) 31 | cache_wrapper.set_draft_model(draft_model) 32 | else: 33 | cache_wrapper.unset_draft_model() 34 | 35 | # Check for common tokens with the previous cache and re-use the cache if possible 36 | prompt_tokens = cache_wrapper.update_cache( 37 | prompt_tokens, 38 | prompt_progress_callback, 39 | ) 40 | generate_args["prompt_cache"] = cache_wrapper.cache 41 | return prompt_tokens 42 | -------------------------------------------------------------------------------- /mlx_engine/utils/set_seed.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import mlx.core as mx 4 | import time 5 | from typing import Optional 6 | import random 7 | 8 | 9 | def set_seed(seed: Optional[int]) -> None: 10 | """ 11 | Set the seed for all random number generators used in mlx-engine to ensure reproducible results. 12 | This function synchronizes the random states across multiple libraries including MLX, NumPy, 13 | PyTorch, and Python's built-in random module. 14 | 15 | Args: 16 | seed: The seed value to initialize random number generators. If None, a seed will be 17 | automatically generated using the current nanosecond timestamp. The final seed 18 | value will be truncated to 32 bits for compatibility across all random number 19 | generators. 20 | 21 | Raises: 22 | ValueError: If the provided seed is negative. 
23 | 24 | Returns: 25 | None 26 | 27 | Note: 28 | This function affects the following random number generators: 29 | - MLX (mx.random) 30 | - NumPy (np.random) 31 | - PyTorch (torch.manual_seed) 32 | - Python's built-in random module 33 | """ 34 | if seed is None: 35 | # Get nanosecond timestamp and use it as seed 36 | seed = int(time.time_ns()) 37 | 38 | if seed < 0: 39 | raise ValueError("Seed must be a non-negative integer.") 40 | seed = seed & (2**32 - 1) # Ensure seed fits in 32 bits 41 | 42 | # For MLX and MLX_LM 43 | mx.random.seed(seed) 44 | 45 | # MLX_VLM depends on numpy and torch 46 | np.random.seed(seed) 47 | torch.manual_seed(seed) 48 | 49 | # Just in case 50 | random.seed(seed) 51 | -------------------------------------------------------------------------------- /mlx_engine/utils/speculative_decoding.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import mlx.nn as nn 3 | import logging 4 | 5 | from mlx_engine.model_kit.model_kit import ModelKit 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | def determine_draft_model_for_generation( 11 | model_kit: ModelKit, speculative_decoding_toggle: Optional[bool] 12 | ) -> Optional[nn.Module]: 13 | """ 14 | Based on ModelKit and speculative_decoding_toggle, determine draft model to use for 15 | generation, or None 16 | """ 17 | if speculative_decoding_toggle is None: 18 | # toggle not set, use draft model if available 19 | return model_kit.draft_model 20 | elif speculative_decoding_toggle and model_kit.draft_model is None: 21 | raise ValueError( 22 | "Speculative decoding toggle is explicitly enabled but no draft model is loaded" 23 | ) 24 | elif not speculative_decoding_toggle and model_kit.draft_model is not None: 25 | logger.info( 26 | "Draft model is loaded but speculative decoding is disabled for this generation" 27 | ) 28 | return None 29 | else: 30 | # toggle set to true, draft model available 31 | return model_kit.draft_model 32 | 33 | 34 | def configure_num_draft_tokens_in_generate_args( 35 | model_kit: ModelKit, 36 | draft_model: Optional[nn.Module], 37 | num_draft_tokens: Optional[int], 38 | generate_args: dict, 39 | ): 40 | """ 41 | Modifies generate_args in place to include num_draft_tokens if applicable 42 | """ 43 | # Only configure draft tokens when all required conditions are met 44 | should_add_num_draft_tokens_to_args = ( 45 | type(model_kit) is ModelKit 46 | and draft_model is not None 47 | and num_draft_tokens is not None 48 | ) 49 | if should_add_num_draft_tokens_to_args: 50 | generate_args["num_draft_tokens"] = num_draft_tokens 51 | -------------------------------------------------------------------------------- /mlx_engine/utils/kv_cache_quantization.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | # https://github.com/ml-explore/mlx/blob/f288db8d34c0bcfa0867b6458ab0277c5e86ed45/mlx/fast.cpp#L782 4 | VALID_KV_BITS = (2, 3, 4, 6, 8) 5 | # https://github.com/ml-explore/mlx/blob/f288db8d34c0bcfa0867b6458ab0277c5e86ed45/mlx/fast.cpp#L775 6 | VALID_KV_GROUP_SIZE = (32, 64, 128) 7 | 8 | 9 | def get_kv_cache_quantization_params( 10 | kv_bits: Optional[int], 11 | kv_group_size: Optional[int], 12 | quantized_kv_start: Optional[int], 13 | ) -> Tuple[Optional[int], Optional[int], Optional[int]]: 14 | """ 15 | Validates and processes KV cache quantization parameters. 16 | 17 | Args: 18 | kv_bits: Number of bits for quantization. If None, disables quantization. 
19 | kv_group_size: Group size for quantization. Defaults to 64 if quantization enabled. 20 | quantized_kv_start: Step to begin quantization. Defaults to 0 if quantization enabled. 21 | 22 | Returns: 23 | Tuple of (kv_bits, kv_group_size, quantized_kv_start), all None if quantization disabled. 24 | 25 | Raises: 26 | ValueError: If kv_bits is invalid or missing when other params are set. 27 | """ 28 | if any([kv_group_size, quantized_kv_start]) and kv_bits is None: 29 | raise ValueError("Enabling KV Cache Quantization requires kv_bits to be set") 30 | 31 | if kv_bits is None: 32 | return None, None, None 33 | 34 | # defaults taken from here: 35 | # https://github.com/ml-explore/mlx-examples/blob/3d793ec/llms/mlx_lm/utils.py#L352-L353 36 | if kv_group_size is None: 37 | kv_group_size = 64 38 | if quantized_kv_start is None: 39 | quantized_kv_start = 0 40 | 41 | if kv_bits not in VALID_KV_BITS: 42 | raise ValueError(f"Invalid kv_bits value. Must be one of {VALID_KV_BITS}") 43 | if kv_group_size not in VALID_KV_GROUP_SIZE: 44 | raise ValueError( 45 | f"Invalid kv_group_size value. Must be one of {VALID_KV_GROUP_SIZE}" 46 | ) 47 | 48 | return kv_bits, kv_group_size, quantized_kv_start 49 | -------------------------------------------------------------------------------- /tests/shared.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | import subprocess 4 | 5 | from mlx_engine.generate import load_model, load_draft_model, tokenize 6 | 7 | 8 | def model_getter(model_name: str): 9 | """Helper method to get a model, prompt user to download if not found""" 10 | 11 | with open(Path("~/.lmstudio-home-pointer").expanduser().resolve(), "r") as f: 12 | lmstudio_home = Path(f.read().strip()) 13 | model_path = lmstudio_home / "models" / model_name 14 | 15 | # Check if model exists, if not prompt user to download 16 | if not model_path.exists(): 17 | print(f"\nModel {model_name} not found at {model_path}") 18 | 19 | def greenify(text): 20 | return f"\033[92m{text}\033[0m" 21 | 22 | response = input( 23 | f"Would you like to download the model {greenify(model_name)}? 
(y/N): " 24 | ) 25 | if response.lower() == "y": 26 | print(f"Downloading model with command: lms get {model_name}") 27 | subprocess.run(["lms", "get", model_name], check=True) 28 | else: 29 | print(f"Model {model_name} not found") 30 | sys.exit(1) 31 | 32 | return model_path 33 | 34 | 35 | def model_load_and_tokenize_prompt( 36 | model_name: str, 37 | prompt: str, 38 | max_kv_size=4096, 39 | trust_remote_code=False, 40 | draft_model_name=None, 41 | ): 42 | """Helper method to test a model""" 43 | print(f"Testing model {model_name}") 44 | 45 | # Check if model exists, if not prompt user to download 46 | model_path = model_getter(model_name) 47 | 48 | # Load the model 49 | model_kit = load_model( 50 | model_path=model_path, 51 | max_kv_size=max_kv_size, 52 | trust_remote_code=trust_remote_code, 53 | ) 54 | 55 | # Load the draft model if any 56 | if draft_model_name is not None: 57 | draft_model_path = model_getter(draft_model_name) 58 | load_draft_model(model_kit, draft_model_path) 59 | 60 | # Tokenize the prompt 61 | prompt_tokens = tokenize(model_kit, prompt) 62 | 63 | return model_kit, prompt_tokens 64 | -------------------------------------------------------------------------------- /mlx_engine/processors/repetition_penalty_processor.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | from mlx_lm.sample_utils import make_repetition_penalty 3 | 4 | """ 5 | Wrapper for the standard mlx-lm repetition penalty processor 6 | ref: https://github.com/ml-explore/mlx-lm/blob/69195f8632869d35306d085de7dc4e7d6954baac/mlx_lm/sample_utils.py#L245-L255 7 | 8 | This wrapper enables the repetition penalty processor to take into account the tokens that have already been cached, 9 | without the need for recomputing the logits for those tokens. 10 | """ 11 | 12 | 13 | class RepetitionPenaltyProcessor: 14 | def __init__( 15 | self, 16 | token_history: list[int], 17 | repetition_penalty: float, 18 | repetition_context_size: int, 19 | ): 20 | self.token_history = token_history 21 | self.repetition_context_size = repetition_context_size 22 | self.repetition_penalty_function = make_repetition_penalty( 23 | repetition_penalty, repetition_context_size 24 | ) 25 | 26 | def __call__(self, tokens: mx.array, logits: mx.array) -> mx.array: 27 | """ 28 | Apply repetition penalty to the logits, accounting for tokens that have already been processed within 29 | the same prediction. 30 | 31 | Args: 32 | tokens: The tokens to be processed. 33 | logits: The logits to be processed. 34 | """ 35 | # append historical tokens s.t. 
repetition penalty accounts tokens that have already been processed in this gen 36 | num_tokens_to_prepend_from_history = max( 37 | self.repetition_context_size - len(tokens), 0 38 | ) 39 | historical_tokens = ( 40 | self.token_history[-num_tokens_to_prepend_from_history:] 41 | if num_tokens_to_prepend_from_history > 0 42 | else [] 43 | ) 44 | historical_tokens_mx = mx.array( 45 | historical_tokens, 46 | dtype=mx.int64, 47 | ) 48 | all_tokens_to_consider = mx.concat([historical_tokens_mx, tokens]) 49 | result = self.repetition_penalty_function(all_tokens_to_consider, logits) 50 | return result 51 | -------------------------------------------------------------------------------- /mlx_engine/model_kit/vision_add_ons/qwen3_vl.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | 4 | from mlx import nn 5 | import mlx.core as mx 6 | 7 | from mlx_engine.model_kit.vision_add_ons.base import BaseVisionAddOn 8 | from mlx_engine.model_kit.vision_add_ons.load_utils import load_vision_addon 9 | from mlx_engine.model_kit.vision_add_ons.qwen_vl_utils import compute_qwen_vl_embeddings 10 | 11 | from mlx_vlm.models.qwen3_vl import ( 12 | VisionModel as Qwen3VLVisionTower, 13 | ModelConfig as Qwen3VLModelConfig, 14 | VisionConfig as Qwen3VLVisionConfig, 15 | TextConfig as Qwen3VLTextConfig, 16 | Model as Qwen3VLModel, 17 | ) 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class Qwen3_VLVisionAddOn(BaseVisionAddOn): 23 | """ 24 | Vision add-on for Qwen3-VL Dense models. 25 | """ 26 | 27 | def __init__(self, model_path: Path): 28 | """Initialize Qwen3_VLVisionAddOn with vision components loaded from the given path.""" 29 | super().__init__() 30 | 31 | # Store the model class for use in compute_embeddings 32 | self.model_cls = Qwen3VLModel 33 | 34 | # Load vision components 35 | self.vision_tower, _, self.config, self.processor = load_vision_addon( 36 | model_path=model_path, 37 | model_config_class=Qwen3VLModelConfig, 38 | vision_config_class=Qwen3VLVisionConfig, 39 | text_config_class=Qwen3VLTextConfig, 40 | vision_tower_class=Qwen3VLVisionTower, 41 | multi_modal_projector_class=None, 42 | logger=logger, 43 | ) 44 | 45 | def compute_embeddings( 46 | self, 47 | text_model: nn.Module, 48 | prompt_tokens: mx.array, 49 | images_b64: list[str], 50 | max_size: tuple[int, int] | None, 51 | ) -> tuple[mx.array, mx.array]: 52 | """ 53 | Compute input_ids and embeddings for text with images. 
54 | """ 55 | 56 | return compute_qwen_vl_embeddings( 57 | addon=self, 58 | text_model=text_model, 59 | prompt_tokens=prompt_tokens, 60 | images_b64=images_b64, 61 | qwen_vl_version=3, 62 | max_size=max_size, 63 | ) 64 | -------------------------------------------------------------------------------- /mlx_engine/model_kit/vision_add_ons/qwen3_vl_moe.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | 4 | from mlx import nn 5 | import mlx.core as mx 6 | 7 | from mlx_engine.model_kit.vision_add_ons.base import BaseVisionAddOn 8 | from mlx_engine.model_kit.vision_add_ons.load_utils import load_vision_addon 9 | from mlx_engine.model_kit.vision_add_ons.qwen_vl_utils import compute_qwen_vl_embeddings 10 | 11 | from mlx_vlm.models.qwen3_vl_moe import ( 12 | VisionModel as Qwen3_VL_MoEVisionTower, 13 | ModelConfig as Qwen3_VL_MoEModelConfig, 14 | VisionConfig as Qwen3_VL_MoEVisionConfig, 15 | TextConfig as Qwen3_VL_MoETextConfig, 16 | Model as Qwen3_VL_MoEModel, 17 | ) 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class Qwen3_VL_MoEVisionAddOn(BaseVisionAddOn): 23 | """ 24 | Vision add-on for Qwen3-VL MoE models. 25 | """ 26 | 27 | def __init__(self, model_path: Path): 28 | """Initialize Qwen3_VL_MoEVisionAddOn with vision components loaded from the given path.""" 29 | super().__init__() 30 | 31 | # Store the model class for use in compute_embeddings 32 | self.model_cls = Qwen3_VL_MoEModel 33 | 34 | # Load vision components 35 | self.vision_tower, _, self.config, self.processor = load_vision_addon( 36 | model_path=model_path, 37 | model_config_class=Qwen3_VL_MoEModelConfig, 38 | vision_config_class=Qwen3_VL_MoEVisionConfig, 39 | text_config_class=Qwen3_VL_MoETextConfig, 40 | vision_tower_class=Qwen3_VL_MoEVisionTower, 41 | multi_modal_projector_class=None, 42 | logger=logger, 43 | ) 44 | 45 | def compute_embeddings( 46 | self, 47 | text_model: nn.Module, 48 | prompt_tokens: mx.array, 49 | images_b64: list[str], 50 | max_size: tuple[int, int] | None, 51 | ) -> tuple[mx.array, mx.array]: 52 | """ 53 | Compute input_ids and embeddings for text with images. 
54 | """ 55 | 56 | return compute_qwen_vl_embeddings( 57 | addon=self, 58 | text_model=text_model, 59 | prompt_tokens=prompt_tokens, 60 | images_b64=images_b64, 61 | qwen_vl_version=3, 62 | max_size=max_size, 63 | ) 64 | -------------------------------------------------------------------------------- /.github/workflows/cla.yml: -------------------------------------------------------------------------------- 1 | name: "CLA Assistant" 2 | 3 | on: 4 | issue_comment: 5 | types: [created] 6 | pull_request_target: 7 | types: [opened, closed, synchronize, labeled] # Added "labeled" event to check for label changes 8 | workflow_dispatch: # Allow manual triggering of the workflow 9 | 10 | permissions: 11 | actions: write 12 | contents: write 13 | pull-requests: write 14 | statuses: write 15 | checks: write 16 | 17 | jobs: 18 | CLAAssistant: 19 | runs-on: ubuntu-latest 20 | steps: 21 | - name: "CLA Assistant" 22 | if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') || github.event_name == 'pull_request_target' 23 | uses: contributor-assistant/github-action@v2.6.1 24 | env: 25 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 26 | PERSONAL_ACCESS_TOKEN: ${{ secrets.CLA_PAT }} 27 | with: 28 | path-to-signatures: 'signatures/version1/cla.json' 29 | path-to-document: 'https://lmstudio.ai/opensource/cla' 30 | remote-organization-name: lmstudio-ai 31 | remote-repository-name: cla-signatures 32 | branch: 'main' 33 | allowlist: yagil,ryan-the-crayon,azisislm,mattjcly,neilmehta24,ncoghlan 34 | 35 | - name: "Label PR as CLA Signed" 36 | if: success() 37 | run: | 38 | if [[ "${{ github.event_name }}" == "pull_request_target" ]]; then 39 | PR_NUMBER="${{ github.event.pull_request.number }}" 40 | elif [[ "${{ github.event_name }}" == "issue_comment" ]]; then 41 | PR_NUMBER="${{ github.event.issue.number }}" 42 | fi 43 | ENDPOINT="https://api.github.com/repos/${{ github.repository }}/issues/$PR_NUMBER/labels" 44 | curl -L -X POST \ 45 | -H "Accept: application/vnd.github+json" \ 46 | -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \ 47 | -H "X-GitHub-Api-Version: 2022-11-28" \ 48 | -d '{"labels":["CLA signed"]}' \ 49 | $ENDPOINT 50 | curl -L -X DELETE \ 51 | -H "Accept: application/vnd.github+json" \ 52 | -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \ 53 | -H "X-GitHub-Api-Version: 2022-11-28" \ 54 | "https://api.github.com/repos/${{ github.repository }}/issues/$PR_NUMBER/labels/Request%20CLA" || true 55 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Our Open Source Projects 2 | 3 | First off, thank you for considering contributing to our open source projects! 👾❤️ 4 | 5 | ## Communication 6 | 7 | - **The best way to communicate with the team is to open an issue in this repository** 8 | - For bug reports, include steps to reproduce, expected behavior, and actual behavior 9 | - For feature requests, explain the use case and benefits clearly 10 | 11 | ## Before You Contribute 12 | 13 | - **If you find an existing issue you'd like to work on, please comment on it first and tag the team** 14 | - This allows us to provide guidance and ensures your time is well spent 15 | - **We discourage drive-by feature PRs** without prior discussion - we want to make sure your efforts align with our roadmap and won't go to waste 16 | 17 | ## Development Workflow 18 | 19 | 1. 
**Fork the repository** to your own GitHub account 20 | 2. **Create a new branch** for your changes 21 | 3. Follow the development instructions in the README.md 22 | 4. Make your changes 23 | 5. Test thoroughly 24 | 6. Make sure the pre-commit hooks pass 25 | 7. Submit your PR 26 | 27 | ## Creating Good Pull Requests 28 | 29 | ### Keep PRs Small and Focused 30 | 31 | - Address one concern per PR 32 | - Smaller PRs are easier to review and more likely to be merged quickly 33 | 34 | ### Write Thoughtful PR Descriptions 35 | 36 | - Clearly explain what the PR does and why 37 | - When applicable, show before/after states or screenshots 38 | - Include any relevant context for reviewers 39 | - Reference the issue(s) your PR addresses with GitHub keywords (Fixes #123, Resolves #456) 40 | 41 | ### Quality Expectations 42 | 43 | - Follow existing code style and patterns 44 | - Include tests for new functionality 45 | - Ensure all tests pass 46 | - Update documentation as needed 47 | 48 | ## Code Review Process 49 | 50 | - Maintainers will review your PR as soon as possible 51 | - We may request changes or clarification 52 | - Once approved, a maintainer will merge your contribution 53 | 54 | ## Contributor License Agreement (CLA) 55 | 56 | - We require all contributors to sign a Contributor License Agreement (CLA) 57 | - For first-time contributors, a bot will automatically comment on your PR with instructions 58 | - You'll need to accept the CLA before we can merge your contribution 59 | - This is standard practice in open source and helps protect both contributors and the project 60 | 61 | Thank you for your contributions! 62 | -------------------------------------------------------------------------------- /mlx_engine/model_kit/patches/gemma3n.py: -------------------------------------------------------------------------------- 1 | """ 2 | Gemma3n compatibility patches using derive and override pattern. 3 | 4 | This module provides derived classes that inherit from the original mlx-lm classes 5 | and override specific methods to handle compatibility issues between mlx-vlm and mlx-lm. 6 | """ 7 | 8 | from mlx_lm.models.gemma3n import Model, TextConfig 9 | from mlx.utils import tree_flatten, tree_unflatten 10 | import inspect 11 | 12 | 13 | class CompatibleTextConfig(TextConfig): 14 | """ 15 | TextConfig that handles intermediate_size as list or integer. 16 | 17 | mlx-vlm's conversion (transformers under the hood) changes the 18 | "text_config" -> "intermediate_size" value from a single integer to 19 | a list of integers of length number of layers. 20 | mlx-lm's model loader expects it to be a single integer. 21 | This class handles both formats by taking the first value if it's a list. 22 | """ 23 | 24 | @classmethod 25 | def from_dict(cls, params): 26 | config_dict = { 27 | k: v for k, v in params.items() if k in inspect.signature(cls).parameters 28 | } 29 | if "intermediate_size" in config_dict and isinstance( 30 | config_dict["intermediate_size"], list 31 | ): 32 | config_dict["intermediate_size"] = config_dict["intermediate_size"][0] 33 | return cls(**config_dict) 34 | 35 | 36 | class CompatibleModel(Model): 37 | """ 38 | Model that handles mlx-vlm compatible weight ordering. 39 | 40 | mlx-vlm's conversion changes the weight keys from the original huggingface weights. 41 | For example, "model.language_model.embed_tokens.weight" becomes 42 | "language_model.model.embed_tokens.weight". 43 | mlx-lm expects the weight keys to be in the original huggingface order. 
44 | This class handles both weight formats. 45 | """ 46 | 47 | def sanitize(self, weights): 48 | weights = tree_unflatten(list(weights.items())) 49 | if weights.get("language_model", {}).get("model", None) is not None: 50 | weights = {"model": {"language_model": weights["language_model"]["model"]}} 51 | weights = dict(tree_flatten(weights)) 52 | return super().sanitize(weights) 53 | 54 | 55 | def apply_patches(): 56 | """ 57 | Apply gemma3n compatibility patches by replacing classes in the mlx_lm module. 58 | """ 59 | import mlx_lm.models.gemma3n 60 | 61 | mlx_lm.models.gemma3n.Model = CompatibleModel 62 | mlx_lm.models.gemma3n.TextConfig = CompatibleTextConfig 63 | -------------------------------------------------------------------------------- /mlx_engine/vision_model_kit/_transformers_compatibility.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | import logging 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | def fix_qwen2_5_vl_image_processor(model_path: Path): 9 | """ 10 | Update the `image_processor_type` in the preprocessor_config.json file to Qwen2VLImageProcessor 11 | Ref https://huggingface.co/mlx-community/Qwen2.5-VL-7B-Instruct-4bit/commit/fdcc572e8b05ba9daeaf71be8c9e4267c826ff9b 12 | """ 13 | try: 14 | # We are looking for a specific entry, so if any of this throws, we don't need to do anything 15 | with open(model_path / "preprocessor_config.json", "r") as f: 16 | image_processor_type = json.load(f)["image_processor_type"] 17 | with open(model_path / "config.json", "r") as f: 18 | model_type = json.load(f)["model_type"] 19 | except: # noqa: E722 20 | return 21 | 22 | if not ( 23 | image_processor_type == "Qwen2_5_VLImageProcessor" 24 | and model_type == "qwen2_5_vl" 25 | ): 26 | return 27 | 28 | # Replace image_processor_type with Qwen2VLImageProcessor 29 | logger.warning( 30 | "Replacing `image_processor_type` with Qwen2VLImageProcessor in preprocessor_config.json" 31 | ) 32 | with open(model_path / "preprocessor_config.json", "r") as f: 33 | preprocessor_config = json.load(f) 34 | preprocessor_config["image_processor_type"] = "Qwen2VLImageProcessor" 35 | with open(model_path / "preprocessor_config.json", "w") as f: 36 | json.dump(preprocessor_config, f) 37 | 38 | 39 | def fix_qwen2_vl_preprocessor(model_path: Path): 40 | """ 41 | Remove the `size` entry from the preprocessor_config.json file, which is broken as of transformers 5.50.0 42 | Ref the transformers implementation: https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct/blob/e28f5d3/preprocessor_config.json 43 | """ 44 | try: 45 | # We are looking for a specific entry, so if any of this throws, we don't need to do anything 46 | with open(model_path / "config.json", "r") as f: 47 | model_type = json.load(f)["model_type"] 48 | if model_type != "qwen2_vl": 49 | return 50 | with open(model_path / "preprocessor_config.json", "r") as f: 51 | json.load(f)["size"] 52 | except: # noqa: E722 53 | return 54 | 55 | logger.warning("Removing `size` entry from preprocessor_config.json") 56 | with open(model_path / "preprocessor_config.json", "r") as f: 57 | preprocessor_config = json.load(f) 58 | preprocessor_config.pop("size") 59 | with open(model_path / "preprocessor_config.json", "w") as f: 60 | json.dump(preprocessor_config, f) 61 | -------------------------------------------------------------------------------- /mlx_engine/utils/eot_tokens.py: -------------------------------------------------------------------------------- 1 | from typing 
import Optional 2 | from mlx_engine.model_kit.model_kit import ModelKit 3 | from mlx_engine.vision_model_kit.vision_model_kit import VisionModelKit 4 | 5 | # Taken from https://github.com/ggml-org/llama.cpp/blob/971f245/src/llama-vocab.cpp#L1807-L1814 6 | DEFAULT_EOT_TOKENS = [ 7 | "<|eot_id|>", 8 | "<|im_end|>", 9 | "<|end|>", 10 | "", 11 | "<|endoftext|>", 12 | "", 13 | "_", 14 | "<|end▁of▁sentence|>", 15 | ] 16 | 17 | MODEL_TYPE_TO_EOT_TOKENS = {"gpt_oss": ["<|return|>", "<|call|>"]} 18 | 19 | 20 | def _get_eot_token_ids(tokenizer, model_type: Optional[str] = None) -> set[int]: 21 | """ 22 | Get the token ID of common end-of-text tokens, using the provided tokenizer. 23 | 24 | If the EOT token str cannot be converted into a single token ID, it is discarded as a candidate. 25 | """ 26 | if ( 27 | isinstance(model_type, str) 28 | and len(MODEL_TYPE_TO_EOT_TOKENS.get(model_type, [])) > 0 29 | ): 30 | eot_tokens = MODEL_TYPE_TO_EOT_TOKENS[model_type] 31 | else: 32 | eot_tokens = DEFAULT_EOT_TOKENS 33 | 34 | # Convert EOT tokens to token IDs 35 | eot_token_ids = [ 36 | tokenizer.encode(eot_str, add_special_tokens=False) for eot_str in eot_tokens 37 | ] 38 | 39 | # Find all elements that are either a single integer or a list with a single integer 40 | single_int = [token_id for token_id in eot_token_ids if isinstance(token_id, int)] 41 | single_element_list = [ 42 | token_id[0] 43 | for token_id in eot_token_ids 44 | if isinstance(token_id, list) and len(token_id) == 1 45 | ] 46 | 47 | return set(single_int + single_element_list) 48 | 49 | 50 | def sanitize_eos_tokens(model_kit: ModelKit | VisionModelKit) -> None: 51 | # Remove (probably) incorrect EOS tokens 52 | tokenizer = model_kit.tokenizer 53 | temp_tokens = set() 54 | for id in tokenizer.eos_token_ids: 55 | text = tokenizer.decode(id) 56 | # Specific override for RNJ-1 57 | if model_kit.model_type == "gemma3_text" and id == 1 and text == '"': 58 | continue 59 | temp_tokens.add(id) 60 | temp_tokens = temp_tokens.union(_get_eot_token_ids(tokenizer, model_kit.model_type)) 61 | 62 | if len(temp_tokens) == 0: 63 | raise RuntimeError( 64 | f"EOS tokens cannot be empty. Before cleaning, the tokens were {tokenizer.eos_token_ids}" 65 | ) 66 | tokenizer.eos_token_ids = temp_tokens 67 | 68 | if tokenizer.eos_token_id not in tokenizer.eos_token_ids: 69 | tokenizer.eos_token_id = min(tokenizer.eos_token_ids) 70 | tokenizer._tokenizer.eos_token_id = tokenizer.eos_token_id 71 | -------------------------------------------------------------------------------- /mlx_engine/model_kit/vision_add_ons/process_prompt_with_images.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union, NamedTuple 2 | import mlx.core as mx 3 | from mlx_vlm import prepare_inputs 4 | from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast 5 | import logging 6 | 7 | from mlx_engine.utils.image_utils import convert_to_pil, custom_resize 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class ProcessedImagePrompt(NamedTuple): 13 | input_ids: mx.array 14 | pixel_values: mx.array 15 | attention_mask: mx.array 16 | other_inputs: dict 17 | 18 | 19 | def common_process_prompt_with_images( 20 | prompt_tokens: mx.array, 21 | images_b64: List[str], 22 | processor: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], 23 | config, # expected to be a ModelConfig object as defined by mlx-vlm. 
Can vary by model 24 | max_size: tuple[int, int] | None, 25 | ) -> ProcessedImagePrompt: 26 | """ 27 | Common prompt processing used by mlx-vlm vision add-ons. 28 | Returns a named tuple with all processed inputs. 29 | 30 | Args: 31 | prompt_tokens: Input prompt tokens 32 | images_b64: List of base64-encoded images 33 | processor: Tokenizer/processor for the model 34 | config: Model configuration object 35 | max_size: Maximum image size as (width, height) tuple. If None, no resizing. 36 | """ 37 | if len(images_b64) == 0: 38 | raise ValueError("Images must be non-empty") 39 | detokenizer = processor.detokenizer 40 | detokenizer.reset() 41 | [detokenizer.add_token(token) for token in prompt_tokens] 42 | detokenizer.finalize() 43 | prompt = detokenizer.text 44 | 45 | logger.info(f"Prompt dump: {prompt}\n") 46 | 47 | images = convert_to_pil(images_b64) 48 | images = custom_resize(images, max_size=max_size) 49 | 50 | if hasattr(config, "image_token_index"): 51 | image_token_index = config.image_token_index 52 | elif hasattr(config.vision_config, "image_token_id"): 53 | image_token_index = config.vision_config.image_token_id 54 | else: 55 | image_token_index = None 56 | 57 | inputs = prepare_inputs( 58 | processor=processor, 59 | images=images, 60 | prompts=prompt, 61 | image_token_index=image_token_index, 62 | resize_shape=None, 63 | ) 64 | 65 | input_ids = inputs["input_ids"] 66 | pixel_values = inputs["pixel_values"] 67 | attention_mask = inputs["attention_mask"] 68 | other_model_inputs = { 69 | k: v 70 | for k, v in inputs.items() 71 | if k not in ["input_ids", "pixel_values", "attention_mask"] 72 | } 73 | 74 | return ProcessedImagePrompt( 75 | input_ids=input_ids, 76 | pixel_values=pixel_values, 77 | attention_mask=attention_mask, 78 | other_inputs=other_model_inputs, 79 | ) 80 | -------------------------------------------------------------------------------- /mlx_engine/utils/fix_mistral_pre_tokenizer.py: -------------------------------------------------------------------------------- 1 | from tokenizers import Tokenizer 2 | from pathlib import Path 3 | from mlx_lm.tokenizer_utils import TokenizerWrapper 4 | import logging 5 | from transformers import LlamaTokenizer 6 | import traceback 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | # List taken from here 11 | # https://github.com/huggingface/transformers/blob/b9951b4/src/transformers/tokenization_utils_tokenizers.py#L1187 12 | _LEGACY_MISTRAL_MODEL_TYPES = [ 13 | "mistral", 14 | "mistral3", 15 | "voxtral", 16 | "ministral", 17 | "pixtral", 18 | ] 19 | 20 | 21 | def fix_mistral_pre_tokenizer( 22 | *, tokenizer: TokenizerWrapper, model_path: Path, model_type: str 23 | ) -> None: 24 | """ 25 | transformers v5 introduces breaking changes in their tokenization framework. 26 | Unfortunately, some of the mistral models were patched incorrectly in transformers during this breakage. 27 | 28 | In transformers-world, using mistral_common for tokenization is a possibility, but we can't use that here 29 | since mistral_common (by design) doesn't tokenize the special tokens correctly 30 | 31 | For mistral models that were introduced in transformers v4, check to see if tokenization is broken. The breakage 32 | specifically happens for LlamaTokenizer instances 33 | 34 | Tokenization is considered broken if it doesn't handle whitespace correctly. For example, tokenizing 35 | `Hello world` should result in tokens `["Hello", " world"]`, and not `["Hello", "world"]`. 
Note the missing 36 | whitespace before `world` 37 | """ 38 | if model_type not in _LEGACY_MISTRAL_MODEL_TYPES: 39 | return 40 | logger.info("Detected mistral model. Checking if tokenizer needs fixing...") 41 | if not isinstance(tokenizer._tokenizer, LlamaTokenizer): 42 | logger.info(f"Tokenizer is of type {type(tokenizer._tokenizer)}. Skipping fix.") 43 | return 44 | if not _tokenizer_is_broken(tokenizer): 45 | logger.info("Tokenizer working as expected.") 46 | return 47 | 48 | # Fix pre-tokenizer 49 | try: 50 | tok = Tokenizer.from_file(str(model_path / "tokenizer.json")) 51 | tokenizer._tokenizer._tokenizer.pre_tokenizer = tok.pre_tokenizer 52 | except Exception: 53 | logger.warning(f"Failed to fix tokenizer: {traceback.format_exc()}.") 54 | return 55 | 56 | if _tokenizer_is_broken(tokenizer): 57 | logger.warning("Tokenizer could not be fixed.") 58 | return 59 | 60 | logger.info("Successfully fixed tokenizer.") 61 | 62 | 63 | def _tokenizer_is_broken(tokenizer: TokenizerWrapper) -> bool: 64 | """ 65 | `["about", "Paris"]` shows us that the tokenization is broken because 66 | the whitespace is missing between `about` and `Paris`. 67 | """ 68 | test_prompt = "Tell me about Paris" 69 | tokens = tokenizer.tokenize(test_prompt) 70 | return tokens[-2:] == ["about", "Paris"] 71 | -------------------------------------------------------------------------------- /tests/processors/dump_logits_processor.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | from pathlib import Path 3 | from os import makedirs 4 | from datetime import datetime 5 | from csv import DictWriter 6 | from typing import Dict 7 | 8 | """ 9 | Wrapper to dump logits to a directory for debugging. 10 | """ 11 | 12 | 13 | class DumpLogitsProcessor: 14 | def __init__( 15 | self, 16 | vocab: Dict[str, int], 17 | dump_directory: Path, 18 | ): 19 | token_id_to_str_map = {} 20 | for token_str, token_id in vocab.items(): 21 | token_id_to_str_map[token_id] = token_str 22 | self._vocab = [token_id_to_str_map[i] for i in range(len(token_id_to_str_map))] 23 | if len(self._vocab) != len(vocab): 24 | raise RuntimeError( 25 | f"Vocab of size {len(vocab)} had {len(self._vocab)} unique token IDs." 26 | ) 27 | # Append current time so that we can re-run the same command without 28 | # overwriting previous outputs 29 | self._dump_directory = dump_directory / datetime.now().isoformat() 30 | makedirs(self._dump_directory, exist_ok=True) 31 | print(f"Will dump logits to {self._dump_directory}") 32 | 33 | def __call__(self, tokens: mx.array, logits: mx.array) -> mx.array: 34 | """ 35 | Dump the logits to a file in the specified directory 36 | 37 | Args: 38 | tokens: The tokens to be processed. 39 | logits: The logits to be processed. 40 | """ 41 | dump_file = self._dump_directory / f"logits_{len(tokens):04d}.csv" 42 | flat_logits = logits.squeeze(0).tolist() 43 | probs = mx.softmax(logits, 1).squeeze(0).tolist() 44 | vocab = self._vocab.copy() 45 | if len(flat_logits) < len(vocab): 46 | # Does not make sense for number of logits to be smaller than vocab 47 | raise RuntimeError( 48 | f"Got {len(flat_logits)} logits but vocab size {len(vocab)}" 49 | ) 50 | elif len(flat_logits) > len(vocab): 51 | # Also weird, but (maybe) expected for language models. 52 | # Pad the vocab with "!!!OUT OF RANGE!!!" 53 | vocab.extend( 54 | ["!!!OUT OF RANGE!!!" 
for _ in range(len(flat_logits) - len(vocab))] 55 | ) 56 | output = sorted( 57 | [ 58 | { 59 | "token_id": token_id, 60 | "token_str": token_str, 61 | "logit": logit, 62 | "prob": prob, 63 | } 64 | for (token_id, (token_str, logit, prob)) in enumerate( 65 | zip(vocab, flat_logits, probs, strict=True) 66 | ) 67 | ], 68 | key=lambda d: d["prob"], 69 | reverse=True, 70 | ) 71 | with open(dump_file, "w") as f: 72 | writer = DictWriter(f, ["token_id", "token_str", "logit", "prob"]) 73 | writer.writeheader() 74 | for row in output: 75 | writer.writerow(row) 76 | return logits 77 | -------------------------------------------------------------------------------- /mlx_engine/model_kit/vision_add_ons/qwen2_vl.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import json 3 | from pathlib import Path 4 | 5 | from mlx import nn 6 | import mlx.core as mx 7 | 8 | from mlx_engine.model_kit.vision_add_ons.base import BaseVisionAddOn 9 | from mlx_engine.model_kit.vision_add_ons.load_utils import load_vision_addon 10 | from mlx_engine.model_kit.vision_add_ons.qwen_vl_utils import compute_qwen_vl_embeddings 11 | 12 | from mlx_vlm.models.qwen2_5_vl import ( 13 | VisionModel as Qwen25VLVisionTower, 14 | ModelConfig as Qwen25VLModelConfig, 15 | VisionConfig as Qwen25VLVisionConfig, 16 | TextConfig as Qwen25VLTextConfig, 17 | Model as Qwen25VLModel, 18 | ) 19 | from mlx_vlm.models.qwen2_vl import ( 20 | VisionModel as Qwen2VLVisionTower, 21 | ModelConfig as Qwen2VLModelConfig, 22 | VisionConfig as Qwen2VLVisionConfig, 23 | TextConfig as Qwen2VLTextConfig, 24 | Model as Qwen2VLModel, 25 | ) 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | class Qwen2_VLVisionAddOn(BaseVisionAddOn): 31 | """ 32 | Vision add-on for Qwen2-VL and Qwen2.5-VL models. 33 | """ 34 | 35 | def __init__(self, model_path: Path): 36 | """Initialize Qwen2_VLVisionAddOn with vision components loaded from the given path.""" 37 | super().__init__() 38 | 39 | # Determine model type from config to select appropriate classes 40 | config_path = model_path / "config.json" 41 | with open(config_path, "r") as f: 42 | config_dict = json.load(f) 43 | model_type = config_dict.get("model_type") 44 | 45 | # Import appropriate classes based on model type 46 | if model_type == "qwen2_5_vl": 47 | vision_tower_cls = Qwen25VLVisionTower 48 | model_config_cls = Qwen25VLModelConfig 49 | vision_config_cls = Qwen25VLVisionConfig 50 | text_config_cls = Qwen25VLTextConfig 51 | model_cls = Qwen25VLModel 52 | else: # Default to qwen2_vl 53 | vision_tower_cls = Qwen2VLVisionTower 54 | model_config_cls = Qwen2VLModelConfig 55 | vision_config_cls = Qwen2VLVisionConfig 56 | text_config_cls = Qwen2VLTextConfig 57 | model_cls = Qwen2VLModel 58 | 59 | # Store the model class for use in compute_embeddings 60 | self.model_cls = model_cls 61 | 62 | # Load vision components 63 | self.vision_tower, _, self.config, self.processor = load_vision_addon( 64 | model_path=model_path, 65 | model_config_class=model_config_cls, 66 | vision_config_class=vision_config_cls, 67 | text_config_class=text_config_cls, 68 | vision_tower_class=vision_tower_cls, 69 | multi_modal_projector_class=None, 70 | logger=logger, 71 | ) 72 | 73 | def compute_embeddings( 74 | self, 75 | text_model: nn.Module, 76 | prompt_tokens: mx.array, 77 | images_b64: list[str], 78 | max_size: tuple[int, int] | None, 79 | ) -> tuple[mx.array, mx.array]: 80 | """ 81 | Compute input_ids and embeddings for text with images. 
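        Delegates to compute_qwen_vl_embeddings with qwen_vl_version=2, which covers both
        the Qwen2-VL and Qwen2.5-VL variants selected in __init__.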
82 | """ 83 | 84 | return compute_qwen_vl_embeddings( 85 | addon=self, 86 | text_model=text_model, 87 | prompt_tokens=prompt_tokens, 88 | images_b64=images_b64, 89 | qwen_vl_version=2, 90 | max_size=max_size, 91 | ) 92 | -------------------------------------------------------------------------------- /mlx_engine/model_kit/vision_add_ons/gemma3.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | 4 | from mlx import nn 5 | import mlx.core as mx 6 | 7 | from mlx_vlm.models.gemma3 import ( 8 | VisionModel as Gemma3VisionTower, 9 | ModelConfig as Gemma3ModelConfig, 10 | VisionConfig as Gemma3VisionConfig, 11 | TextConfig as Gemma3TextConfig, 12 | Model as Gemma3CombinedModel, # for prepare_inputs_for_multimodal 13 | ) 14 | from mlx_vlm.models.gemma3.gemma3 import Gemma3MultiModalProjector 15 | 16 | from mlx_engine.model_kit.vision_add_ons.base import BaseVisionAddOn 17 | from mlx_engine.model_kit.vision_add_ons.process_prompt_with_images import ( 18 | common_process_prompt_with_images, 19 | ) 20 | from mlx_engine.model_kit.vision_add_ons.load_utils import load_vision_addon 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | class Gemma3VisionAddOn(BaseVisionAddOn): 26 | """ 27 | Vision add-on for Gemma3 model. Uses mlx-vlm vision components of Gemma3. 28 | """ 29 | 30 | def __init__(self, model_path: Path): 31 | """Initialize Gemma3VisionAddOn with vision components loaded from the given path.""" 32 | super().__init__() 33 | 34 | # Load vision model components, configuration, and processor 35 | self.vision_tower, self.multi_modal_projector, self.config, self.processor = ( 36 | load_vision_addon( 37 | model_path=model_path, 38 | model_config_class=Gemma3ModelConfig, 39 | vision_config_class=Gemma3VisionConfig, 40 | text_config_class=Gemma3TextConfig, 41 | vision_tower_class=Gemma3VisionTower, 42 | multi_modal_projector_class=Gemma3MultiModalProjector, 43 | logger=logger, 44 | ) 45 | ) 46 | 47 | def compute_embeddings( 48 | self, 49 | text_model: nn.Module, 50 | prompt_tokens: mx.array, 51 | images_b64: list[str], 52 | max_size: tuple[int, int] | None, 53 | ) -> tuple[mx.array, mx.array]: 54 | """Compute input_ids and embeddings for text with images.""" 55 | input_ids, pixel_values, attention_mask, other_model_inputs = ( 56 | common_process_prompt_with_images( 57 | prompt_tokens=prompt_tokens, 58 | images_b64=images_b64, 59 | processor=self.processor, 60 | config=self.config, 61 | max_size=max_size, 62 | ) 63 | ) 64 | input_embeddings = text_model.language_model.model.embed_tokens(input_ids) 65 | 66 | # Process image through vision tower 67 | hidden_state, _, _ = self.vision_tower( 68 | pixel_values.transpose(0, 2, 3, 1).astype(input_embeddings.dtype), 69 | output_hidden_states=True, 70 | ) 71 | 72 | # Format image features 73 | image_features = hidden_state.astype(pixel_values.dtype) 74 | image_features = self.multi_modal_projector(image_features) 75 | 76 | # Combine image and text embeddings 77 | final_inputs_embeds, _ = Gemma3CombinedModel.prepare_inputs_for_multimodal( 78 | self.config.hidden_size, 79 | self.config.pad_token_id, 80 | self.config.image_token_index, 81 | image_features, 82 | input_embeddings, 83 | input_ids, 84 | attention_mask, 85 | ) 86 | # remove batch dimension 87 | return input_ids.squeeze(0), final_inputs_embeds.squeeze(0) 88 | -------------------------------------------------------------------------------- /mlx_engine/model_kit/vision_add_ons/qwen_vl_utils.py: 
-------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | from mlx import nn 3 | 4 | from mlx_engine.model_kit.vision_add_ons.base import BaseVisionAddOn 5 | from mlx_engine.utils.image_utils import convert_to_pil, custom_resize 6 | 7 | from mlx_vlm.utils import prepare_inputs 8 | 9 | 10 | def compute_qwen_vl_embeddings( 11 | addon: BaseVisionAddOn, 12 | text_model: nn.Module, 13 | prompt_tokens: mx.array, 14 | images_b64: list[str], 15 | qwen_vl_version: int, 16 | max_size: tuple[int, int] | None, 17 | ) -> tuple[mx.array, mx.array]: 18 | """ 19 | Compute input_ids and embeddings for Qwen2-VL, Qwen2.5-VL, and Qwen3-VL models. 20 | 21 | Args: 22 | addon: Vision add-on instance containing vision tower, config, and processor 23 | text_model: Text model for embedding tokens 24 | prompt_tokens: Input prompt tokens 25 | images_b64: List of base64-encoded images 26 | qwen_vl_version: Version number (2 for Qwen2/2.5-VL, 3 for Qwen3-VL) 27 | max_size: Maximum image size as (width, height) tuple. If None, no resizing. 28 | 29 | Returns: 30 | Tuple of (input_ids, final_embeddings) with batch dimension removed 31 | """ 32 | 33 | # Convert and resize images 34 | images = convert_to_pil(images_b64) 35 | images = custom_resize(images, max_size=max_size, should_pad=False) 36 | 37 | # Build prompt text 38 | tokens = ( 39 | prompt_tokens if isinstance(prompt_tokens, list) else prompt_tokens.tolist() 40 | ) 41 | prompt = addon.processor.decode(tokens) 42 | 43 | # Prepare inputs 44 | inputs = prepare_inputs( 45 | processor=addon.processor, 46 | images=images, 47 | prompts=prompt, 48 | image_token_index=addon.config.image_token_id, 49 | resize_shape=None, 50 | ) 51 | input_ids = inputs["input_ids"] 52 | pixel_values = inputs["pixel_values"] 53 | grid_thw = inputs.get("image_grid_thw") 54 | 55 | # Get text embeddings 56 | input_embeddings = text_model.language_model.model.embed_tokens(input_ids) 57 | 58 | # If no images, return input_ids and input_embeddings 59 | if pixel_values is None: 60 | return input_ids.squeeze(0), input_embeddings.squeeze(0) 61 | 62 | # Ensure pixel values are in the right format for vision tower 63 | if pixel_values.dtype != input_embeddings.dtype: 64 | pixel_values = pixel_values.astype(input_embeddings.dtype) 65 | 66 | # Process image through vision tower and merge embeddings 67 | if qwen_vl_version == 2: 68 | hidden_states = addon.vision_tower( 69 | pixel_values, grid_thw, output_hidden_states=False 70 | ) 71 | 72 | final_inputs_embeds = addon.model_cls.merge_input_ids_with_image_features( 73 | addon.config.image_token_id, 74 | addon.config.video_token_id, 75 | hidden_states, 76 | input_embeddings, 77 | input_ids, 78 | ) 79 | elif qwen_vl_version == 3: 80 | hidden_states, _ = addon.vision_tower( 81 | pixel_values, grid_thw, output_hidden_states=False 82 | ) 83 | 84 | final_inputs_embeds, _ = addon.model_cls.merge_input_ids_with_image_features( 85 | hidden_states, 86 | input_embeddings, 87 | input_ids, 88 | addon.config.image_token_id, 89 | addon.config.video_token_id, 90 | ) 91 | else: 92 | raise ValueError(f"Invalid Qwen-VL version: {qwen_vl_version}") 93 | 94 | # Remove batch dimension 95 | return input_ids.squeeze(0), final_inputs_embeds.squeeze(0) 96 | -------------------------------------------------------------------------------- /mlx_engine/utils/progress_decorators.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Callable 2 | from mlx_engine.cache_wrapper 
import StopPromptProcessing 3 | import logging 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | def ratchet( 9 | callback: Optional[Callable[[float], bool]], 10 | ) -> Optional[Callable[[float], bool]]: 11 | """ 12 | Wraps a progress callback to ensure progress values are monotonically increasing. 13 | 14 | This wrapper prevents progress from appearing to move backwards by using a ratchet 15 | mechanism. If a lower percentage is reported than previously seen, the callback 16 | returns True (continue) without calling the original callback. 17 | 18 | Args: 19 | callback: A callback that accepts progress (0.0–100.0) and returns 20 | True to continue or False to stop. May be None. 21 | 22 | Returns: 23 | A wrapped callback that ensures monotonic progress reporting. 24 | If callback is None, returns None. 25 | """ 26 | if callback is None: 27 | return None 28 | 29 | ratchet = float("-inf") 30 | 31 | def inner_callback(percentage: float) -> bool: 32 | nonlocal ratchet 33 | if percentage <= ratchet: 34 | return True 35 | ratchet = percentage 36 | return callback(percentage) 37 | 38 | return inner_callback 39 | 40 | 41 | def throw_to_stop( 42 | callback: Optional[Callable[[float], bool]], 43 | ) -> Optional[Callable[[float], bool]]: 44 | """ 45 | Wraps a progress callback to raise an exception when stopping is requested. 46 | 47 | Instead of returning False to indicate stopping, this wrapper raises a 48 | StopPromptProcessing exception when the original callback returns False. 49 | This allows for immediate termination of the generation process. 50 | 51 | Args: 52 | callback: A callback that accepts progress (0.0–100.0) and returns 53 | True to continue or False to stop. May be None. 54 | 55 | Returns: 56 | A wrapped callback that raises StopPromptProcessing when stopping 57 | is requested. If callback is None, returns None. 58 | 59 | Raises: 60 | StopPromptProcessing: When the original callback returns False. 61 | """ 62 | if callback is None: 63 | return None 64 | 65 | def inner_callback(percentage: float) -> bool: 66 | should_continue = callback(percentage) 67 | if not should_continue: 68 | logger.info("Prompt processing was cancelled by the user.") 69 | raise StopPromptProcessing 70 | return should_continue 71 | 72 | return inner_callback 73 | 74 | 75 | def token_count( 76 | callback: Optional[Callable[[float], bool]], 77 | ) -> Optional[Callable[[int, int], None]]: 78 | """ 79 | Adapts a float percentage based progress callback into a token count based one. 80 | 81 | Args: 82 | outer_callback: A callback that accepts progress (0.0–100.0) and returns 83 | True to continue or False to stop. May be None. 84 | 85 | Returns: 86 | A token-based callback (processed_tokens, total_tokens) -> None, 87 | as is expected by mlx-lm's stream_generate. 88 | If outer_callback is None, returns None. 
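    Example (illustrative sketch):
        cb = token_count(lambda pct: True)
        cb(30, 120)  # forwards 25.0 to the wrapped percentage callback
        cb(0, 0)     # a non-positive total is guarded against; forwards 0.0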
89 | """ 90 | if callback is None: 91 | return None 92 | 93 | def inner_callback(processed_tokens: int, total_tokens: int) -> None: 94 | if total_tokens <= 0: 95 | progress = 0.0 96 | else: 97 | progress = 100 * processed_tokens / total_tokens 98 | callback(progress) 99 | 100 | return inner_callback 101 | -------------------------------------------------------------------------------- /mlx_engine/model_kit/vision_add_ons/pixtral.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | 4 | from mlx import nn 5 | import mlx.core as mx 6 | 7 | from mlx_vlm.models.pixtral import ( 8 | VisionModel as PixtralVisionTower, 9 | ModelConfig as PixtralModelConfig, 10 | VisionConfig as PixtralVisionConfig, 11 | TextConfig as PixtralTextConfig, 12 | Model as PixtralCombinedModel, # for merge_input_ids_with_image_features 13 | ) 14 | from mlx_vlm.models.pixtral.pixtral import ( 15 | LlavaMultiModalProjector as PixtralMultiModalProjector, 16 | ) 17 | 18 | from mlx_engine.model_kit.vision_add_ons.base import BaseVisionAddOn 19 | from mlx_engine.model_kit.vision_add_ons.process_prompt_with_images import ( 20 | common_process_prompt_with_images, 21 | ) 22 | from mlx_engine.model_kit.vision_add_ons.load_utils import load_vision_addon 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | class PixtralVisionAddOn(BaseVisionAddOn): 28 | """ 29 | Vision add-on for Pixtral model. Uses mlx-vlm vision components of Pixtral. 30 | """ 31 | 32 | def __init__(self, model_path: Path): 33 | """Initialize PixtralVisionAddOn with vision components loaded from the given path.""" 34 | super().__init__() 35 | 36 | # Load vision model components, configuration, and processor 37 | self.vision_tower, self.multi_modal_projector, self.config, self.processor = ( 38 | load_vision_addon( 39 | model_path=model_path, 40 | model_config_class=PixtralModelConfig, 41 | vision_config_class=PixtralVisionConfig, 42 | text_config_class=PixtralTextConfig, 43 | vision_tower_class=PixtralVisionTower, 44 | multi_modal_projector_class=PixtralMultiModalProjector, 45 | logger=logger, 46 | ) 47 | ) 48 | 49 | def compute_embeddings( 50 | self, 51 | text_model: nn.Module, 52 | prompt_tokens: mx.array, 53 | images_b64: list[str], 54 | max_size: tuple[int, int] | None, 55 | ) -> tuple[mx.array, mx.array]: 56 | """Compute input_ids and embeddings for text with images.""" 57 | input_ids, pixel_values, attention_mask, other_model_inputs = ( 58 | common_process_prompt_with_images( 59 | prompt_tokens=prompt_tokens, 60 | images_b64=images_b64, 61 | processor=self.processor, 62 | config=self.config, 63 | max_size=max_size, 64 | ) 65 | ) 66 | input_embeddings = text_model.language_model.model.embed_tokens(input_ids) 67 | 68 | if isinstance(pixel_values, list): 69 | pixel_values = mx.concatenate( 70 | [mx.array(pv)[None, ...] for pv in pixel_values], axis=0 71 | ) 72 | if pixel_values.ndim == 3: 73 | pixel_values = pixel_values[None, ...] 
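        # At this point pixel_values should be a single batched array in channels-first
        # (num_images, channels, height, width) order; the transpose below hands the
        # vision tower the channels-last layout it expects.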
74 | 75 | # Process image through vision tower 76 | *_, hidden_states = self.vision_tower( 77 | pixel_values.transpose(0, 2, 3, 1), 78 | output_hidden_states=True, 79 | ) 80 | # Select the hidden states from the desired layer 81 | selected_image_feature = hidden_states[self.config.vision_feature_layer] 82 | 83 | # Pass image features through the multi-modal projector 84 | image_features = self.multi_modal_projector(selected_image_feature) 85 | 86 | # Insert special image tokens in the input_ids 87 | final_inputs_embeds = PixtralCombinedModel.merge_input_ids_with_image_features( 88 | self.config.image_token_index, image_features, input_embeddings, input_ids 89 | ) 90 | # remove batch dimension 91 | return input_ids.squeeze(0), final_inputs_embeds.squeeze(0) 92 | -------------------------------------------------------------------------------- /mlx_engine/utils/image_utils.py: -------------------------------------------------------------------------------- 1 | import base64 2 | from io import BytesIO 3 | from typing import List 4 | import PIL 5 | import logging 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | def convert_to_pil(images_b64: List[str]) -> list[PIL.Image.Image]: 11 | """Convert a list of base64 strings to PIL Images""" 12 | return [ 13 | PIL.Image.open(BytesIO(base64.b64decode(img))).convert("RGB") 14 | for img in images_b64 or [] 15 | ] 16 | 17 | 18 | def custom_resize( 19 | pil_images: list[PIL.Image.Image], 20 | *, 21 | max_size: tuple[int, int] | None, 22 | should_pad: bool = True, 23 | ): 24 | """ 25 | Resize and optionally pad a list of PIL images. 26 | 27 | This function resizes images that exceed the specified maximum dimensions, 28 | maintaining their aspect ratios. If there is more than one image, it also 29 | pads all images to the same size. 30 | 31 | Args: 32 | pil_images (list): A list of PIL Image objects to be processed. 33 | max_size (tuple): Maximum allowed dimensions (width, height) for the images. 34 | If None, no resizing is performed. 35 | should_pad (bool): Whether to pad the images to the same size. 36 | Defaults to True. 37 | 38 | Returns: 39 | list: A list of processed PIL Image objects. If there was only one input image, 40 | it returns the resized image without padding. If there were multiple input 41 | images, it returns padded images of uniform size. 42 | 43 | Side effects: 44 | Writes progress and status messages to sys.stderr. 
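    Example (illustrative; images_b64 is assumed to be a list of base64 strings):
        images = convert_to_pil(images_b64)
        images = custom_resize(images, max_size=(1024, 1024))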
45 | """ 46 | # Validate max_size parameter 47 | if max_size is not None: 48 | if not isinstance(max_size, tuple) or len(max_size) != 2: 49 | raise ValueError( 50 | "max_size must be a tuple of (width, height), e.g., (1024, 1024)" 51 | ) 52 | if not all(isinstance(dim, int) and dim > 0 for dim in max_size): 53 | raise ValueError("max_size dimensions must be positive integers") 54 | 55 | resized_images = [] 56 | max_width, max_height = 0, 0 57 | 58 | for i, img in enumerate(pil_images): 59 | original_width, original_height = img.width, img.height 60 | aspect_ratio = img.width / img.height 61 | 62 | if max_size is not None and ( 63 | img.width > max_size[0] or img.height > max_size[1] 64 | ): 65 | if img.width > img.height: 66 | new_width = max_size[0] 67 | new_height = int(new_width / aspect_ratio) 68 | else: 69 | new_height = max_size[1] 70 | new_width = int(new_height * aspect_ratio) 71 | img = img.resize((new_width, new_height), PIL.Image.LANCZOS) 72 | logger.info( 73 | f"Image {i + 1}: Resized from {original_width}x{original_height} to {img.width}x{img.height}\n", 74 | ) 75 | 76 | max_width = max(max_width, img.width) 77 | max_height = max(max_height, img.height) 78 | 79 | resized_images.append(img) 80 | 81 | if len(pil_images) > 1 and should_pad: 82 | logger.info( 83 | f"[mlx-engine] Maximum dimensions: {max_width}x{max_height}. " 84 | f"Adding padding so that all images are the same size.\n", 85 | ) 86 | 87 | final_images = [] 88 | for i, img in enumerate(resized_images): 89 | new_img = PIL.Image.new("RGB", (max_width, max_height), (0, 0, 0)) 90 | paste_x = (max_width - img.width) // 2 91 | paste_y = (max_height - img.height) // 2 92 | new_img.paste(img, (paste_x, paste_y)) 93 | final_images.append(new_img) 94 | return final_images 95 | else: 96 | return resized_images 97 | -------------------------------------------------------------------------------- /tests/test_cache_wrapper.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import mlx.core as mx 3 | from mlx_engine.cache_wrapper import CacheWrapper, StopPromptProcessing 4 | from tests.shared import model_getter 5 | from mlx_engine.generate import load_model, tokenize 6 | 7 | 8 | class TestCacheWrapper(unittest.TestCase): 9 | def test_find_common_prefix_with_mismatch(self): 10 | """Test when there's a mismatch in the tokens""" 11 | # Create two arrays with a known common prefix [1, 2, 3] 12 | current_tokens = mx.array([1, 2, 3, 4, 5]) 13 | prompt_tokens = mx.array([1, 2, 3, 6, 7]) # Mismatch at index 3 14 | num_tokens_to_exclude = 1 15 | 16 | print("\nTest with mismatch:") 17 | print(f"current_tokens: {current_tokens}") 18 | print(f"prompt_tokens: {prompt_tokens}") 19 | 20 | result = CacheWrapper._find_common_prefix( 21 | current_tokens, prompt_tokens, num_tokens_to_exclude 22 | ) 23 | self.assertEqual(result, 3) # Should find 3 matching tokens 24 | 25 | def test_find_common_prefix_all_match(self): 26 | """Test when all tokens match""" 27 | # Create two identical arrays 28 | current_tokens = mx.array([1, 2, 3, 4, 5]) 29 | prompt_tokens = mx.array([1, 2, 3, 4, 5]) # All tokens match 30 | num_tokens_to_exclude = 1 31 | 32 | print("\nTest with all matching:") 33 | print(f"current_tokens: {current_tokens}") 34 | print(f"prompt_tokens: {prompt_tokens}") 35 | 36 | result = CacheWrapper._find_common_prefix( 37 | current_tokens, prompt_tokens, num_tokens_to_exclude 38 | ) 39 | self.assertEqual( 40 | result, 4 41 | ) # Should find 4 matching tokens (5-1 due to num_tokens_to_exclude) 
42 | 43 | def test_prompt_processing_cancellation(self): 44 | """Test that progress is saved when processing is cancelled and cache is reused on retry""" 45 | 46 | model_path = model_getter("lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit") 47 | model_kit = load_model(model_path=model_path, max_kv_size=4096) 48 | 49 | chunk_size = 20 # Small chunk size to ensure multiple progress callbacks 50 | num_tokens_to_exclude = 1 51 | model_kit.cache_wrapper = CacheWrapper( 52 | model_kit.model, 53 | max_kv_size=4096, 54 | chunk_size=chunk_size, 55 | ) 56 | 57 | long_prompt = ( 58 | "This is a test prompt that needs to be long enough to require multiple chunks for processing. " 59 | * 50 60 | ) 61 | prompt_tokens = mx.array(tokenize(model_kit, long_prompt)) 62 | tokens_to_process = len(prompt_tokens) - num_tokens_to_exclude 63 | # ceiling division 64 | expected_chunks = (tokens_to_process + chunk_size - 1) // chunk_size 65 | 66 | # First attempt: Progress callback that cancels after a few updates 67 | first_progress_calls = [] 68 | 69 | def cancelling_progress_callback(progress): 70 | first_progress_calls.append(progress) 71 | if len(first_progress_calls) >= 3: 72 | return False 73 | return True 74 | 75 | with self.assertRaises(StopPromptProcessing): 76 | model_kit.cache_wrapper.update_cache( 77 | prompt_tokens=prompt_tokens, 78 | prompt_progress_callback=cancelling_progress_callback, 79 | num_tokens_to_exclude=1, 80 | ) 81 | first_attempt_progress_calls = len(first_progress_calls) 82 | 83 | # Second attempt: Progress callback that doesn't cancel 84 | second_progress_calls = [] 85 | 86 | def non_cancelling_progress_callback(progress): 87 | second_progress_calls.append(progress) 88 | return True 89 | 90 | result_tokens = model_kit.cache_wrapper.update_cache( 91 | prompt_tokens=prompt_tokens, 92 | prompt_progress_callback=non_cancelling_progress_callback, 93 | num_tokens_to_exclude=1, 94 | ) 95 | second_attempt_progress_calls = len(second_progress_calls) 96 | 97 | self.assertEqual( 98 | second_attempt_progress_calls, 99 | # +1 for the final 100% callback, +1 for the duplicate 0% callback 100 | expected_chunks - first_attempt_progress_calls + 2, 101 | ) 102 | 103 | # Verify that the second attempt completed successfully 104 | self.assertIsNotNone(result_tokens) 105 | 106 | 107 | if __name__ == "__main__": 108 | unittest.main(verbosity=2) 109 | -------------------------------------------------------------------------------- /tests/utils/test_progress_decorators.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from mlx_engine.utils.progress_decorators import ratchet, throw_to_stop, token_count 3 | from mlx_engine.cache_wrapper import StopPromptProcessing 4 | from typing import Callable, TypeVar 5 | 6 | 7 | def create_callback(received_calls: list[float], retval: bool): 8 | def callback(val: float) -> bool: 9 | received_calls.append(val) 10 | return retval 11 | 12 | return callback 13 | 14 | 15 | def execute_calls( 16 | system_under_test: Callable[[float], bool], inputs: list[float] 17 | ) -> list[bool]: 18 | return [system_under_test(i) for i in inputs] 19 | 20 | 21 | T = TypeVar("T") 22 | 23 | 24 | def unwrap_optional(value: T | None) -> T: 25 | if value is None: 26 | raise RuntimeError("Value cannot be None") 27 | return value 28 | 29 | 30 | class TestRatchet(unittest.TestCase): 31 | def test_none_callback_returns_none(self): 32 | """Test that ratchet returns None when given None.""" 33 | result = ratchet(None) 34 | self.assertIsNone(result) 
35 | 36 | def test_monotonic_progress(self): 37 | """Test that ratchet allows monotonic progress updates.""" 38 | inputs = [25.0, 50.0, 75.0] 39 | received_calls = [] 40 | original_callback = create_callback(received_calls, False) 41 | 42 | system_under_test = unwrap_optional(ratchet(original_callback)) 43 | results = [system_under_test(i) for i in inputs] 44 | 45 | self.assertEqual(received_calls, inputs) 46 | self.assertEqual(results, [False] * 3) 47 | 48 | def test_non_monotonic_progress(self): 49 | """Test that ratchet disallows non-monotonic progress updates.""" 50 | inputs = [0.0, 25.0, 0.0, 50.0, 30.0, 75.0, 60.0, 100.0, 99.9] 51 | # We construct the return value of the original_callback so that the wrapped callback will 52 | # return False when original_callback is called, and True otherwise. 53 | expected_results = [False, False, True, False, True, False, True, False, True] 54 | expected_calls = [ 55 | input for (input, result) in zip(inputs, expected_results) if not result 56 | ] 57 | received_calls = [] 58 | original_callback = create_callback(received_calls, False) 59 | 60 | system_under_test = unwrap_optional(ratchet(original_callback)) 61 | results = [system_under_test(i) for i in inputs] 62 | 63 | self.assertEqual(received_calls, expected_calls) 64 | self.assertEqual(results, expected_results) 65 | 66 | 67 | class TestThrowToStop(unittest.TestCase): 68 | def test_none_callback_returns_none(self): 69 | """Test that throw_to_stop returns None when given None.""" 70 | result = throw_to_stop(None) 71 | self.assertIsNone(result) 72 | 73 | def test_callback_returns_true(self): 74 | """Test that throw_to_stop continues when callback returns True.""" 75 | inputs = [25.0, 50.0, 75.0] 76 | received_calls = [] 77 | original_callback = create_callback(received_calls, True) 78 | 79 | system_under_test = unwrap_optional(throw_to_stop(original_callback)) 80 | results = [system_under_test(i) for i in inputs] 81 | 82 | self.assertEqual(received_calls, inputs) 83 | self.assertEqual(results, [True] * 3) 84 | 85 | def test_callback_returns_false_raises_exception(self): 86 | """Test that throw_to_stop raises StopPromptProcessing when callback returns False.""" 87 | input = 25.0 88 | received_calls = [] 89 | original_callback = create_callback(received_calls, False) 90 | 91 | system_under_test = unwrap_optional(throw_to_stop(original_callback)) 92 | with self.assertRaises(StopPromptProcessing): 93 | system_under_test(input) 94 | 95 | self.assertEqual(received_calls, [input]) 96 | 97 | 98 | class TestTokenCount(unittest.TestCase): 99 | def test_none_callback_returns_none(self): 100 | """Test that token_count returns None when given None.""" 101 | result = token_count(None) 102 | self.assertIsNone(result) 103 | 104 | def test_token_count_callback(self): 105 | """Test that token_count calls the callback with correct token counts.""" 106 | inputs = [(0, 30), (10, 30), (20, 30), (30, 30)] 107 | expected_calls = [input[0] / input[1] * 100.0 for input in inputs] 108 | received_calls = [] 109 | original_callback = create_callback(received_calls, True) 110 | 111 | system_under_test = unwrap_optional(token_count(original_callback)) 112 | results = [system_under_test(*i) for i in inputs] 113 | 114 | for received, expected in zip(received_calls, expected_calls): 115 | self.assertAlmostEqual(received, expected) 116 | self.assertEqual(results, [None] * len(inputs)) 117 | -------------------------------------------------------------------------------- /tests/data/ben_franklin_autobiography_start.txt: 
-------------------------------------------------------------------------------- 1 | B. FRANKLIN 2 | B. Franklin's signature 3 | From an engraving by J. Thomson from the original picture by J. A. Duplessis. 4 | INTRODUCTION 5 | 6 | block-W E Americans devour eagerly any piece of writing that purports to tell us the secret of success in life; yet how often we are disappointed to find nothing but commonplace statements, or receipts that we know by heart but never follow. Most of the life stories of our famous and successful men fail to inspire because they lack the human element that makes the record real and brings the story within our grasp. While we are searching far and near for some Aladdin's Lamp to give coveted fortune, there is ready at our hand if we will only reach out and take it, like the charm in Milton's Comus, 7 | 8 | "Unknown, and like esteemed, and the dull swain 9 | Treads on it daily with his clouted shoon;" 10 | the interesting, human, and vividly told story of one of the wisest and most useful lives in our own history, and perhaps in any history. In Franklin's Autobiography is offered not so much a ready-made formula for success, as the companionship of a real flesh and blood man of extraordinary mind and quality, whose daily walk and conversation will help us to meet our own difficulties, much as does the example of a wise and strong friend. While we are fascinated by the story, we absorb the human experience through which a strong and helpful character is building. 11 | 12 | The thing that makes Franklin's Autobiography different from every other life story of a great and successful man is just this human aspect of the account. Franklin told the story of his life, as he himself says, for the benefit of his posterity. He wanted to help them by the relation of his own rise from obscurity and poverty to eminence and wealth. He is not unmindful of the importance of his public services and their recognition, yet his accounts of these achievements are given only as a part of the story, and the vanity displayed is incidental and in keeping with the honesty of the recital. There is nothing of the impossible in the method and practice of Franklin as he sets them forth. The youth who reads the fascinating story is astonished to find that Franklin in his early years struggled with the same everyday passions and difficulties that he himself experiences, and he loses the sense of discouragement that comes from a realization of his own shortcomings and inability to attain. 13 | 14 | There are other reasons why the Autobiography should be an intimate friend of American young people. Here they may establish a close relationship with one of the foremost Americans as well as one of the wisest men of his age. 15 | 16 | The life of Benjamin Franklin is of importance to every American primarily because of the part he played in securing the independence of the United States and in establishing it as a nation. Franklin shares with Washington the honors of the Revolution, and of the events leading to the birth of the new nation. While Washington was the animating spirit of the struggle in the colonies, Franklin was its ablest champion abroad. To Franklin's cogent reasoning and keen satire, we owe the clear and forcible presentation of the American case in England and France; while to his personality and diplomacy as well as to his facile pen, we are indebted for the foreign alliance and the funds without which Washington's work must have failed. 
His patience, fortitude, and practical wisdom, coupled with self-sacrificing devotion to the cause of his country, are hardly less noticeable than similar qualities displayed by Washington. In fact, Franklin as a public man was much like Washington, especially in the entire disinterestedness of his public service. 17 | 18 | Franklin is also interesting to us because by his life and teachings he has done more than any other American to advance the material prosperity of his countrymen. It is said that his widely and faithfully read maxims made Philadelphia and Pennsylvania wealthy, while Poor Richard's pithy sayings, translated into many languages, have had a world-wide influence. 19 | 20 | Franklin is a good type of our American manhood. Although not the wealthiest or the most powerful, he is undoubtedly, in the versatility of his genius and achievements, the greatest of our self-made men. The simple yet graphic story in the Autobiography of his steady rise from humble boyhood in a tallow-chandler shop, by industry, economy, and perseverance in self-improvement, to eminence, is the most remarkable of all the remarkable histories of our self-made men. It is in itself a wonderful illustration of the results possible to be attained in a land of unequaled opportunity by following Franklin's maxims. 21 | 22 | Franklin's fame, however, was not confined to his own country. Although he lived in a century notable for the rapid evolution of scientific and political thought and activity, yet no less a keen judge and critic than Lord Jeffrey, the famous editor of the Edinburgh Review, a century ago said that "in one point of view the name of Franklin must be considered as standing higher than any of the others which illustrated the eighteenth century. Distinguished as a statesman, he was equally great as a philosopher, thus uniting in himself a rare degree of excellence in both these pursuits, to excel in either of which is deemed the highest praise." 23 | -------------------------------------------------------------------------------- /mlx_engine/model_kit/vision_add_ons/mistral3.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | 4 | from mlx import nn 5 | import mlx.core as mx 6 | 7 | from mlx_engine.model_kit.vision_add_ons.base import BaseVisionAddOn 8 | from mlx_vlm.models.mistral3 import ( 9 | VisionModel as Mistral3VisionTower, 10 | ModelConfig as Mistral3ModelConfig, 11 | VisionConfig as Mistral3VisionConfig, 12 | TextConfig as Mistral3TextConfig, 13 | Model as Mistral3CombinedModel, 14 | ) 15 | from mlx_vlm.models.mistral3.mistral3 import Mistral3MultiModalProjector 16 | from mlx_engine.model_kit.vision_add_ons.process_prompt_with_images import ( 17 | common_process_prompt_with_images, 18 | ) 19 | from mlx_engine.model_kit.vision_add_ons.load_utils import load_vision_addon 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | class Mistral3VisionAddOn(BaseVisionAddOn): 25 | """ 26 | Vision add-on for Mistral3 models. 27 | """ 28 | 29 | def __init__(self, model_path: Path): 30 | """Initialize Mistral3VisionAddOn with vision components loaded from the given path.""" 31 | super().__init__() 32 | 33 | processor_kwargs: dict | None = None 34 | if self._is_lmstudio_mistral_3_2_small(model_path): 35 | processor_kwargs = { 36 | "patch_size": 14, 37 | "spatial_merge_size": 2, 38 | } 39 | logger.info( 40 | "Detected LM Studio Mistral Small 3.2 model. 
" 41 | f"Using custom processor kwargs: {processor_kwargs}" 42 | ) 43 | 44 | self.vision_tower, self.multi_modal_projector, self.config, self.processor = ( 45 | load_vision_addon( 46 | model_path=model_path, 47 | model_config_class=Mistral3ModelConfig, 48 | vision_config_class=Mistral3VisionConfig, 49 | text_config_class=Mistral3TextConfig, 50 | vision_tower_class=Mistral3VisionTower, 51 | multi_modal_projector_class=Mistral3MultiModalProjector, 52 | logger=logger, 53 | processor_kwargs=processor_kwargs, 54 | ) 55 | ) 56 | 57 | def compute_embeddings( 58 | self, 59 | text_model: nn.Module, 60 | prompt_tokens: mx.array, 61 | images_b64: list[str], 62 | max_size: tuple[int, int] | None, 63 | ) -> tuple[mx.array, mx.array]: 64 | """ 65 | Compute embeddings for text with images. 66 | 67 | This method is heavily based on the mlx-vlm's mistral3 `get_input_embeddings` 68 | https://github.com/Blaizzy/mlx-vlm/blob/2c3014fd40962bd5320ad611502e7e26cae08926/mlx_vlm/models/mistral3/mistral3.py#L240-L279 69 | """ 70 | 71 | input_ids, pixel_values, attention_mask, other_model_inputs = ( 72 | common_process_prompt_with_images( 73 | prompt_tokens=prompt_tokens, 74 | images_b64=images_b64, 75 | processor=self.processor, 76 | config=self.config, 77 | max_size=max_size, 78 | ) 79 | ) 80 | 81 | image_sizes_list = other_model_inputs["image_sizes"] 82 | image_sizes = mx.array(image_sizes_list) 83 | 84 | if pixel_values is None: 85 | return text_model.language_model.model.embed_tokens(input_ids) 86 | 87 | # Get the input embeddings from the language model 88 | inputs_embeds = text_model.language_model.model.embed_tokens(input_ids) 89 | 90 | # Get the output hidden states from the vision model 91 | if isinstance(pixel_values, list): 92 | pixel_values = mx.concatenate( 93 | [mx.array(pv)[None, ...] for pv in pixel_values], axis=0 94 | ) 95 | if pixel_values.ndim == 3: 96 | pixel_values = pixel_values[None, ...] 97 | 98 | # Pass pixel_values as list of images, as each image is individually run through conv2d and position encoding 99 | # Reference code from transformers: https://github.com/huggingface/transformers/blob/main/src/transformers/models/pixtral/modeling_pixtral.py#L479C9-L479C21 100 | # and mistral_inference: https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/vision_encoder.py#L85 101 | *_, hidden_states = self.vision_tower( 102 | pixel_values.transpose(0, 2, 3, 1), 103 | output_hidden_states=True, 104 | ) 105 | # Select the hidden states from the desired layer 106 | selected_image_feature = hidden_states[self.config.vision_feature_layer] 107 | 108 | # Pass image features through the multi-modal projector 109 | image_features = self.multi_modal_projector(selected_image_feature, image_sizes) 110 | 111 | # Insert special image tokens in the input_ids 112 | final_inputs_embeds = Mistral3CombinedModel.merge_input_ids_with_image_features( 113 | self.config.image_token_index, image_features, inputs_embeds, input_ids 114 | ) 115 | # remove batch dimension 116 | return input_ids.squeeze(0), final_inputs_embeds.squeeze(0) 117 | 118 | @staticmethod 119 | def _is_lmstudio_mistral_3_2_small(model_path: Path) -> bool: 120 | return "lmstudio-community/Mistral-Small-3.2-24B-Instruct-2506-MLX" in str( 121 | model_path 122 | ) 123 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 
3 | [header image: lmstudio + MLX]
4 | 
5 | 
6 | 
7 | mlx-engine - Apple MLX LLM Engine for LM Studio
8 | 
9 | [Discord badge]
10 | 11 | # mlx-engine 12 | MLX engine for LM Studio 13 | 14 |
15 | 16 | ## Built with 17 | - [mlx-lm](https://github.com/ml-explore/mlx-lm) - Apple MLX inference engine (MIT) 18 | - [Outlines](https://github.com/dottxt-ai/outlines) - Structured output for LLMs (Apache 2.0) 19 | - [mlx-vlm](https://github.com/Blaizzy/mlx-vlm) - Vision model inferencing for MLX (MIT) 20 | 21 |
22 | 23 | ## How to use in LM Studio 24 | LM Studio 0.3.4 and newer for Mac ships pre-bundled with mlx-engine. 25 | Download LM Studio from [here](https://lmstudio.ai/download?os=mac) 26 | 27 |
28 | 29 | ## Standalone Demo 30 | 31 | ### Prerequisites 32 | 33 | - macOS 14.0 (Sonoma) or greater. 34 | - python3.11 35 | - The requirements.txt file is compiled specifically for python3.11. python3.11 is the python version bundled within the LM Studio MLX runtime 36 | - `brew install python@3.11` is a quick way to add python3.11 to your path that doesn't break your default python setup 37 | 38 | ### Install Steps 39 | To run a demo of model load and inference: 40 | 1. Clone the repository 41 | ``` 42 | git clone https://github.com/lmstudio-ai/mlx-engine.git 43 | cd mlx-engine 44 | ``` 45 | 2. Create a virtual environment (optional) 46 | ``` 47 | python3.11 -m venv .venv 48 | source .venv/bin/activate 49 | ``` 50 | 3. Install the required dependency packages 51 | ``` 52 | pip install -U -r requirements.txt 53 | ``` 54 | 55 | ### Text Model Demo 56 | Download models with the `lms` CLI tool. The `lms` CLI documentation can be found here: https://lmstudio.ai/docs/cli 57 | Run the `demo.py` script with an MLX text generation model: 58 | ```bash 59 | lms get mlx-community/Meta-Llama-3.1-8B-Instruct-4bit 60 | python demo.py --model mlx-community/Meta-Llama-3.1-8B-Instruct-4bit 61 | ``` 62 | [mlx-community/Meta-Llama-3.1-8B-Instruct-4bit](https://model.lmstudio.ai/download/mlx-community/Meta-Llama-3.1-8B-Instruct-4bit) - 4.53 GB 63 | 64 | This command will use a default prompt. For a different prompt, add a custom `--prompt` argument like: 65 | ```bash 66 | lms get mlx-community/Mistral-Small-Instruct-2409-4bit 67 | python demo.py --model mlx-community/Mistral-Small-Instruct-2409-4bit --prompt "How long will it take for an apple to fall from a 10m tree?" 68 | ``` 69 | [mlx-community/Mistral-Small-Instruct-2409-4bit](https://model.lmstudio.ai/download/mlx-community/Mistral-Small-Instruct-2409-4bit) - 12.52 GB 70 | 71 | ### Vision Model Demo 72 | Run the `demo.py` script with an MLX vision model: 73 | ```bash 74 | lms get mlx-community/pixtral-12b-4bit 75 | python demo.py --model mlx-community/pixtral-12b-4bit --prompt "Compare these images" --images demo-data/chameleon.webp demo-data/toucan.jpeg 76 | ``` 77 | Currently supported vision models include: 78 | - [Llama-3.2-Vision](https://model.lmstudio.ai/download/mlx-community/Llama-3.2-11B-Vision-Instruct-4bit) 79 | - `lms get mlx-community/Llama-3.2-11B-Vision-Instruct-4bit` 80 | - [Pixtral](https://model.lmstudio.ai/download/mlx-community/pixtral-12b-4bit) 81 | - `lms get mlx-community/pixtral-12b-4bit` 82 | - [Qwen2-VL](https://model.lmstudio.ai/download/mlx-community/Qwen2-VL-7B-Instruct-4bit) 83 | - `lms get mlx-community/Qwen2-VL-7B-Instruct-4bit` 84 | - [Llava-v1.6](https://model.lmstudio.ai/download/mlx-community/llava-v1.6-mistral-7b-4bit) 85 | - `lms get mlx-community/llava-v1.6-mistral-7b-4bit` 86 | 87 | ### Speculative Decoding Demo 88 | Run the `demo.py` script with an MLX text generation model and a compatible `--draft-model` 89 | ```bash 90 | lms get mlx-community/Qwen2.5-7B-Instruct-4bit 91 | lms get lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit 92 | python demo.py \ 93 | --model mlx-community/Qwen2.5-7B-Instruct-4bit \ 94 | --draft-model lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit \ 95 | --prompt "<|im_start|>system 96 | You are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.<|im_end|> 97 | <|im_start|>user 98 | Write a quick sort algorithm in C++<|im_end|> 99 | <|im_start|>assistant 100 | " 101 | ``` 102 | 103 | ## Development Setup 104 | 105 | ### Pre-commit Hooks 106 | 107 | We use pre-commit hooks to maintain code quality. Before contributing, please: 108 | 109 | 1. Install pre-commit: 110 | ```bash 111 | pip install pre-commit && pre-commit install 112 | ``` 113 | 2. Run pre-commit: 114 | ```bash 115 | pre-commit run --all-files 116 | ``` 117 | 3. Fix any issues before submitting your PR 118 | 119 | ## Testing 120 | 121 | To run tests, run the following from the root of this repo: 122 | ```bash 123 | python -m pip install pytest 124 | python -m pytest tests/ 125 | ``` 126 | 127 | To test specific vision models: 128 | ```bash 129 | python -m pytest tests/test_vision_models.py -k pixtral 130 | ``` 131 | 132 | ## Attribution 133 | 134 | Ernie 4.5 modeling code is sourced from [Baidu](https://huggingface.co/baidu/ERNIE-4.5-0.3B-PT/tree/da6f3b1158d5d0d2bbf552bfc3364c9ec64e8aa5) 135 | -------------------------------------------------------------------------------- /mlx_engine/external/models/ernie4_5/configuration_ernie4_5.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Baidu, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from transformers import PretrainedConfig 16 | 17 | 18 | class Ernie4_5_Config(PretrainedConfig): 19 | """ 20 | Configuration class. 21 | 22 | This class stores the configuration of an Ernie model, defining the model architecture. 23 | It inherits from PretrainedConfig and can be used to control model outputs. 
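    Example (illustrative):
        >>> config = Ernie4_5_Config(vocab_size=32000, hidden_size=768)
        >>> config.model_type
        'ernie4_5'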
24 | """ 25 | 26 | model_type = "ernie4_5" 27 | keys_to_ignore_at_inference = ["past_key_values"] 28 | 29 | # Default tensor parallel plan for base model `Qwen3` 30 | base_model_tp_plan = { 31 | "layers.*.self_attn.q_proj": "colwise", 32 | "layers.*.self_attn.k_proj": "colwise", 33 | "layers.*.self_attn.v_proj": "colwise", 34 | "layers.*.self_attn.o_proj": "rowwise", 35 | "layers.*.mlp.gate_proj": "colwise", 36 | "layers.*.mlp.up_proj": "colwise", 37 | "layers.*.mlp.down_proj": "rowwise", 38 | } 39 | base_model_pp_plan = { 40 | "embed_tokens": (["input_ids"], ["inputs_embeds"]), 41 | "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), 42 | "norm": (["hidden_states"], ["hidden_states"]), 43 | } 44 | 45 | def __init__( 46 | self, 47 | vocab_size=32000, 48 | hidden_size=768, 49 | intermediate_size=11008, 50 | max_position_embeddings=32768, 51 | num_hidden_layers=2, 52 | num_attention_heads=2, 53 | rms_norm_eps=1e-6, 54 | use_cache=False, 55 | use_flash_attention=False, 56 | pad_token_id=0, 57 | bos_token_id=1, 58 | eos_token_id=2, 59 | use_bias=False, 60 | rope_theta=10000, 61 | weight_share_add_bias=True, 62 | ignored_index=-100, 63 | attention_probs_dropout_prob=0.0, 64 | hidden_dropout_prob=0.0, 65 | compression_ratio: float = 1.0, 66 | num_key_value_heads=None, 67 | max_sequence_length=None, 68 | **kwargs, 69 | ): 70 | """ 71 | Initialize configuration with default or specified parameters. 72 | 73 | Args: 74 | vocab_size (int): Size of the vocabulary (number of unique tokens) 75 | hidden_size (int): Dimensionality of the encoder layers and the pooler layer 76 | intermediate_size (int): Dimensionality of the "intermediate" (feed-forward) layer 77 | max_position_embeddings (int): Maximum sequence length the model can handle 78 | num_hidden_layers (int): Number of hidden layers in the Transformer encoder 79 | num_attention_heads (int): Number of attention heads for each attention layer 80 | rms_norm_eps (float): The epsilon used by the RMS normalization layers 81 | use_cache (bool): Whether to use caching for faster generation (decoding) 82 | use_flash_attention (bool): Whether to use FlashAttention for optimized attention computation 83 | pad_token_id (int): Token ID used for padding sequences 84 | bos_token_id (int): Token ID used for beginning-of-sequence 85 | eos_token_id (int): Token ID used for end-of-sequence 86 | use_bias (bool): Whether to use bias terms in linear layers 87 | rope_theta (float): The base period of the RoPE embeddings 88 | weight_share_add_bias (bool): Whether to share bias weights in certain layers 89 | ignored_index (int): Target value that is ignored during loss computation 90 | attention_probs_dropout_prob (float): Dropout probability for attention weights 91 | hidden_dropout_prob (float): Dropout probability for hidden layers 92 | compression_ratio (float): Ratio for KV cache compression (1.0 = no compression) 93 | num_key_value_heads (int): Number of key/value heads (for Grouped Query Attention) 94 | max_sequence_length (int): Maximum sequence length for positional embeddings 95 | **kwargs: Additional keyword arguments passed to parent class 96 | """ 97 | 98 | # Set default for tied embeddings if not specified. 
99 | if "tie_word_embeddings" not in kwargs: 100 | kwargs["tie_word_embeddings"] = False 101 | super().__init__( 102 | pad_token_id=pad_token_id, 103 | bos_token_id=bos_token_id, 104 | eos_token_id=eos_token_id, 105 | **kwargs, 106 | ) 107 | self.vocab_size = vocab_size 108 | self.hidden_size = hidden_size 109 | self.intermediate_size = intermediate_size 110 | self.max_position_embeddings = max_position_embeddings 111 | self.num_hidden_layers = num_hidden_layers 112 | self.num_attention_heads = num_attention_heads 113 | self.rms_norm_eps = rms_norm_eps 114 | self.use_cache = use_cache 115 | self.use_flash_attention = use_flash_attention 116 | self.pad_token_id = pad_token_id 117 | self.bos_token_id = bos_token_id 118 | self.eos_token_id = eos_token_id 119 | self.use_bias = use_bias 120 | self.weight_share_add_bias = weight_share_add_bias 121 | self.rope_theta = rope_theta 122 | self.ignored_index = ignored_index 123 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 124 | self.hidden_dropout_prob = hidden_dropout_prob 125 | self.compression_ratio = compression_ratio 126 | self.num_key_value_heads = num_key_value_heads 127 | self.max_sequence_length = max_sequence_length 128 | -------------------------------------------------------------------------------- /mlx_engine/vision_model_kit/vision_model_kit.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Optional, List, Tuple 2 | from mlx_engine.model_kit.model_kit import ModelKit 3 | import logging 4 | 5 | from ._transformers_compatibility import ( 6 | fix_qwen2_5_vl_image_processor, 7 | fix_qwen2_vl_preprocessor, 8 | ) 9 | from .vision_model_wrapper import VisionModelWrapper 10 | import mlx_vlm 11 | import mlx_lm 12 | from pathlib import Path 13 | import mlx.core as mx 14 | from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | class VisionModelKit(ModelKit): 20 | """ 21 | Collection of objects and methods that are needed for operating a vision model 22 | """ 23 | 24 | config: dict = None 25 | trust_remote_code: bool = False 26 | model_path: Path = None 27 | vocab_only: bool = False 28 | model_weights = None 29 | 30 | processor: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None 31 | has_processed_prompt: bool = False 32 | 33 | def __init__( 34 | self, 35 | model_path: Path, 36 | vocab_only: bool, 37 | trust_remote_code: bool, 38 | ): 39 | fix_qwen2_5_vl_image_processor(model_path) 40 | fix_qwen2_vl_preprocessor(model_path) 41 | self.config = mlx_vlm.utils.load_config( 42 | model_path, trust_remote_code=trust_remote_code 43 | ) 44 | self.trust_remote_code = trust_remote_code 45 | self.vocab_only = vocab_only 46 | self.model_path = model_path 47 | self._initializer() 48 | 49 | def _vocab_only_init(self): 50 | self.tokenizer = mlx_vlm.tokenizer_utils.load_tokenizer(self.model_path) 51 | self.detokenizer = self.tokenizer.detokenizer 52 | 53 | def _full_model_init(self): 54 | additional_kwargs = {} 55 | if self.model_weights: 56 | additional_kwargs["weights"] = self.model_weights 57 | return_tuple = mlx_vlm.utils.load( 58 | self.model_path, 59 | processor_config={"trust_remote_code": self.trust_remote_code}, 60 | trust_remote_code=self.trust_remote_code, 61 | **additional_kwargs, 62 | ) 63 | if len(return_tuple) == 2: 64 | self.model, self.processor = return_tuple 65 | else: 66 | self.model, self.processor, self.model_weights = return_tuple 67 | self.model = 
VisionModelWrapper(self.model) 68 | 69 | # Set the eos_token_ids (check root level first, then text_config) 70 | eos_token_ids_raw = self.config.get("eos_token_id") 71 | if eos_token_ids_raw is None: 72 | eos_token_ids_raw = self.config.get("text_config", {}).get("eos_token_id") 73 | eos_token_ids = None 74 | if eos_token_ids_raw is not None: 75 | eos_token_ids = ( 76 | [eos_token_ids_raw] 77 | if isinstance(eos_token_ids_raw, int) 78 | else list(set(eos_token_ids_raw)) 79 | ) 80 | logger.info(f"Setting eos token ids: {eos_token_ids}") 81 | 82 | # Use the mlx_lm tokenizer since it's more robust 83 | self.tokenizer = mlx_lm.tokenizer_utils.load( 84 | self.model_path, eos_token_ids=eos_token_ids 85 | ) 86 | self.detokenizer = self.tokenizer.detokenizer 87 | 88 | self.cache_wrapper = None 89 | mx.clear_cache() 90 | 91 | def _initializer(self): 92 | if self.vocab_only: 93 | self._vocab_only_init() 94 | else: 95 | self._full_model_init() 96 | 97 | def _reset_for_prediction(self): 98 | # It's a shortcoming that the only way to reset the model for prediction 99 | # is to reload it. Worth investigating how to make resetting faster 100 | self._full_model_init() 101 | 102 | def process_prompt( 103 | self, 104 | prompt_tokens, 105 | images_b64: Optional[List[str]], 106 | prompt_progress_callback, 107 | generate_args, 108 | max_image_size: tuple[int, int] | None, 109 | speculative_decoding_toggle: Optional[bool] = None, 110 | ) -> Tuple[mx.array, Optional[mx.array]]: 111 | """ 112 | Call this before starting evaluation 113 | 114 | This method opens the image from the base64-encoded string, and adds the special image token to the prompt 115 | 116 | Returns the processed prompt tokens to be input to the `generate_step` function, and optionally input 117 | embeddings. For VisionModelKit, the input embeddings are always none. 
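        When images are provided, a placeholder pair (mx.array([0]), mx.array([0])) is
        returned so that mlx-lm's own prompt processing is bypassed; the real input_ids
        and image handling stay inside the wrapped vision model.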
118 | """ 119 | if self.has_processed_prompt: 120 | self._reset_for_prediction() 121 | 122 | self.model.process_prompt_with_images( 123 | images_b64, prompt_tokens, self.processor, self.detokenizer, max_image_size 124 | ) 125 | self.has_processed_prompt = True 126 | 127 | # The VLM input_ids shape is important, but mlx_lm expects a flattened array 128 | # Send back a fake shape and input_ids, and save the real shape in `self.model.input_ids` 129 | if images_b64 is None or len(images_b64) == 0: 130 | # For text-only, enable mlx-lm prompt processing 131 | return self.model.input_ids.reshape(-1), None 132 | # Disable mlx-lm prompt processing by returning a fake input 133 | return mx.array([0]), mx.array([0]) 134 | 135 | def is_cross_prompt_cache_active(self) -> bool: 136 | """VisionModelKit does not support cross prompt caching""" 137 | return False 138 | 139 | def record_token_to_cache(self, token: int) -> None: 140 | pass 141 | 142 | def record_sampled_token(self, token: int) -> None: 143 | self.model.record_sampled_token(token) 144 | 145 | def is_draft_model_compatible(self, path: str | Path) -> bool: 146 | return False 147 | 148 | def load_draft_model(self, path: str | Path) -> None: 149 | raise ValueError( 150 | "Speculative decoding is not currently supported for vision models" 151 | ) 152 | 153 | def unload_draft_model(self) -> None: 154 | raise ValueError( 155 | "Speculative decoding is not currently supported for vision models" 156 | ) 157 | 158 | @property 159 | def language_model(self): 160 | return self.model.language_model 161 | -------------------------------------------------------------------------------- /mlx_engine/external/models/lfm2_vl/configuration_lfm2_vl.py: -------------------------------------------------------------------------------- 1 | """PyTorch LFM2-VL model.""" 2 | 3 | import torch 4 | from transformers import AutoConfig 5 | from transformers.configuration_utils import PretrainedConfig 6 | from transformers.models.lfm2.configuration_lfm2 import Lfm2Config 7 | from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig 8 | from transformers.utils import logging 9 | 10 | logger = logging.get_logger(__name__) 11 | 12 | 13 | class Lfm2VlConfig(PretrainedConfig): 14 | r""" 15 | This is the configuration class to store the configuration of a [`Lfm2VlForConditionalGeneration`]. It is used to instantiate an 16 | Lfm2Vl model according to the specified arguments, defining the model architecture. Instantiating a configuration 17 | with the defaults will yield a similar configuration to that of the Lfm2-VL-1.6B. 18 | 19 | e.g. [LiquidAI/LFM2-VL-1.6B](https://huggingface.co/LiquidAI/LFM2-VL-1.6B) 20 | 21 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 22 | documentation from [`PretrainedConfig`] for more information. 23 | 24 | Args: 25 | vision_config (`AutoConfig | dict`, *optional*, defaults to `Siglip2ImageConfig`): 26 | The config object or dictionary of the vision backbone. 27 | text_config (`AutoConfig | dict`, *optional*, defaults to `Lfm2Config`): 28 | The config object or dictionary of the text backbone. 29 | image_token_id (`int`, *optional*, defaults to 396): 30 | The image token index to encode the image prompt. 31 | projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): 32 | The activation function used by the multimodal projector. 33 | projector_hidden_size (`int`, *optional*, defaults to 2056): 34 | The hidden size of the multimodal projector. 
35 | projector_bias (`bool`, *optional*, defaults to `True`): 36 | Whether to use bias in the multimodal projector. 37 | downsample_factor (`int`, *optional*, defaults to 2): 38 | The downsample_factor factor of the vision backbone. 39 | vision_feature_layer (`int`, *optional*, defaults to -1): 40 | The layer of the vision tower to use as features. 41 | min_image_tokens (`int`, *optional*, defaults to 64): 42 | The minimum number of image tokens for smart resize. 43 | max_image_tokens (`int`, *optional*, defaults to 256): 44 | The maximum number of image tokens for smart resize. 45 | encoder_patch_size (`int`, *optional*, defaults to 16): 46 | The patch size of the encoder. 47 | max_num_patches (`int`, *optional*, defaults to 1024): 48 | The maximum number of image tokens passed to the encoder per image or tile. 49 | use_image_special_tokens (`bool`, *optional*, defaults to `True`): 50 | Whether to use image special tokens. 51 | do_image_splitting (`bool`, *optional*, defaults to `True`): 52 | Whether to split large images into tiles. 53 | min_tiles (`int`, *optional*, defaults to 2): 54 | The minimum number of tiles to split the image into. 55 | max_tiles (`int`, *optional*, defaults to 10): 56 | The maximum number of tiles to split the image into. 57 | tile_size (`int`, *optional*, defaults to 512): 58 | The size of the tile to split the image into. 59 | max_pixels_tolerance (`float`, *optional*, defaults to 2.0): 60 | The maximum tolerance for the number of pixels in the image before splitting. 61 | use_thumbnail (`bool`, *optional*, defaults to `True`): 62 | Whether to append the thumbnail of the image when splitting. 63 | """ 64 | 65 | model_type = "lfm2-vl" 66 | attribute_map = { 67 | "image_token_id": "image_token_index", 68 | } 69 | sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig} 70 | 71 | def __init__( 72 | self, 73 | vision_config=None, 74 | text_config=None, 75 | image_token_index=396, 76 | projector_hidden_act="gelu", 77 | projector_hidden_size=2560, 78 | projector_bias=True, 79 | downsample_factor=2, 80 | vision_feature_layer=-1, 81 | min_image_tokens=64, 82 | max_image_tokens=256, 83 | encoder_patch_size=16, 84 | max_num_patches=1024, 85 | use_image_special_tokens=True, 86 | do_image_splitting=True, 87 | min_tiles=2, 88 | max_tiles=10, 89 | tile_size=512, 90 | max_pixels_tolerance=2.0, 91 | use_thumbnail=True, 92 | torch_dtype=torch.bfloat16, 93 | **kwargs, 94 | ): 95 | self.vision_config = vision_config 96 | self.text_config = text_config 97 | self.image_token_index = image_token_index 98 | self.projector_hidden_act = projector_hidden_act 99 | self.projector_hidden_size = projector_hidden_size 100 | self.projector_bias = projector_bias 101 | self.downsample_factor = downsample_factor 102 | self.vision_feature_layer = vision_feature_layer 103 | self.min_image_tokens = min_image_tokens 104 | self.max_image_tokens = max_image_tokens 105 | self.encoder_patch_size = encoder_patch_size 106 | self.max_num_patches = max_num_patches 107 | self.use_image_special_tokens = use_image_special_tokens 108 | self.do_image_splitting = do_image_splitting 109 | self.min_tiles = min_tiles 110 | self.max_tiles = max_tiles 111 | self.tile_size = tile_size 112 | self.max_pixels_tolerance = max_pixels_tolerance 113 | self.use_thumbnail = use_thumbnail 114 | self.torch_dtype = torch_dtype 115 | 116 | if isinstance(vision_config, dict): 117 | vision_config = Siglip2VisionConfig(**vision_config) 118 | elif vision_config is None: 119 | vision_config = Siglip2VisionConfig() 120 | 
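# As a rough usage sketch (not taken from the upstream repository, and assuming only the
# transformers classes imported at the top of this file), the dict/None handling here is
# what lets callers pass plain dictionaries and still get typed sub-configs back:
#
#     cfg = Lfm2VlConfig(vision_config={"hidden_size": 768}, text_config=None)
#     assert isinstance(cfg.vision_config, Siglip2VisionConfig)  # dict is upgraded
#     assert isinstance(cfg.text_config, Lfm2Config)  # None falls back to the defaults
#     assert cfg.image_token_index == 396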
self.vision_config = vision_config 121 | 122 | self.vision_config = vision_config 123 | 124 | if isinstance(text_config, dict): 125 | text_config = Lfm2Config(**text_config) 126 | elif text_config is None: 127 | text_config = Lfm2Config() 128 | 129 | self.text_config = text_config 130 | 131 | super().__init__(**kwargs) 132 | -------------------------------------------------------------------------------- /mlx_engine/model_kit/vision_add_ons/gemma3n.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from mlx import nn 3 | import mlx.core as mx 4 | from mlx_vlm.models.gemma3n import ( 5 | VisionModel as Gemma3nVisionTower, 6 | ModelConfig as Gemma3nModelConfig, 7 | VisionConfig as Gemma3nVisionConfig, 8 | TextConfig as Gemma3nTextConfig, 9 | Model as Gemma3nCombinedModel, 10 | ) 11 | from mlx_vlm.models.gemma3n.gemma3n import Gemma3nMultimodalEmbedder 12 | from mlx_vlm.utils import sanitize_weights, load_processor 13 | import logging 14 | 15 | 16 | from mlx_engine.model_kit.vision_add_ons.base import BaseVisionAddOn 17 | from mlx_engine.model_kit.vision_add_ons.process_prompt_with_images import ( 18 | common_process_prompt_with_images, 19 | ) 20 | from mlx_engine.model_kit.vision_add_ons.load_utils import ( 21 | load_and_filter_weights, 22 | load_and_parse_config, 23 | maybe_apply_quantization, 24 | prepare_components, 25 | ) 26 | import json 27 | 28 | 29 | logger = logging.getLogger(__name__) 30 | 31 | 32 | class Gemma3nVisionComponents(nn.Module): 33 | def __init__(self, vision_tower: nn.Module, embed_vision: nn.Module): 34 | super().__init__() 35 | self.vision_tower = vision_tower 36 | self.embed_vision = embed_vision 37 | 38 | 39 | class Gemma3nVisionAddOn(BaseVisionAddOn): 40 | """ 41 | Vision add-on for Gemma3n model. Uses mlx-vlm vision components of Gemma3n. 
42 | """ 43 | 44 | def __init__(self, model_path: Path): 45 | """Initialize Gemma3nVisionAddOn with vision components loaded from the given path.""" 46 | super().__init__() 47 | 48 | # The gemma3n weights were re-uploaded by google on 20250710 49 | # The re-upload transposed two of the axis of the weights 50 | # Here, check to see if we're using a model uploaded before 20250710 51 | self.using_legacy_weights = False 52 | with open(model_path / "config.json", "r") as f: 53 | config_json = json.load(f) 54 | if ( 55 | "text_config" in config_json 56 | and "query_pre_attn_scalar" in config_json["text_config"] 57 | ): 58 | self.using_legacy_weights = True 59 | 60 | config, config_dict = load_and_parse_config( 61 | model_path=model_path, 62 | model_config_class=Gemma3nModelConfig, 63 | vision_config_class=Gemma3nVisionConfig, 64 | text_config_class=Gemma3nTextConfig, 65 | ) 66 | 67 | components = Gemma3nVisionComponents( 68 | vision_tower=Gemma3nVisionTower(config.vision_config), 69 | embed_vision=Gemma3nMultimodalEmbedder( 70 | config.vision_config, config.text_config 71 | ), 72 | ) 73 | if self.using_legacy_weights: 74 | del components.vision_tower.timm_model.conv_stem.conv.bias 75 | processor = load_processor(model_path=model_path, add_detokenizer=True) 76 | vision_weights = load_and_filter_weights(model_path, components) 77 | vision_weights = sanitize_weights( 78 | components.vision_tower.__class__, vision_weights, config.vision_config 79 | ) 80 | maybe_apply_quantization(components, config_dict, vision_weights) 81 | prepare_components(components, vision_weights) 82 | 83 | logger.info( 84 | f"Vision add-on loaded successfully from {model_path}", 85 | ) 86 | 87 | self.vision_tower = components.vision_tower 88 | self.embed_vision = components.embed_vision 89 | self.config = config 90 | self.processor = processor 91 | 92 | def compute_embeddings( 93 | self, 94 | text_model: nn.Module, 95 | prompt_tokens: mx.array, 96 | images_b64: list[str], 97 | max_size: tuple[int, int] | None, 98 | ) -> tuple[mx.array, mx.array]: 99 | """Compute input_ids and embeddings for text with images.""" 100 | input_ids, pixel_values, attention_mask, other_model_inputs = ( 101 | common_process_prompt_with_images( 102 | prompt_tokens=prompt_tokens, 103 | images_b64=images_b64, 104 | processor=self.processor, 105 | config=self.config, 106 | max_size=max_size, 107 | ) 108 | ) 109 | assert input_ids is not None 110 | 111 | # See mlx_vlm.models.gemma3n.gemma3n.Model.get_input_embeddings 112 | # This implementation was based on commit mlx-vlm commit ebafa5a789ed1a8e050b8366ae4e845dbe640b90 113 | # It differs slightly from mlx-vlm in the bounds on the vision_mask. 
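# (mlx-vlm appears to write the upper bound of the mask in terms of the audio vocabulary
# offset, while the check below uses the vision embedder's own half-open range
# [vocab_offset, vocab_offset + vocab_size). As a sketch with made-up numbers rather than
# the real Gemma3n config values: if vocab_offset were 1000 and vocab_size were 128, ids
# 1000-1127 would be treated as vision tokens and 1128 would already fall into the next
# modality's block.)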
114 | # However, the two calculations should be equivalent (vision vocab offset + size) == audio vocab offset 115 | inputs_embeds = text_model.model.language_model.embed_tokens(input_ids) 116 | vision_mask = mx.logical_and( 117 | input_ids >= self.embed_vision.vocab_offset, 118 | input_ids < self.embed_vision.vocab_offset + self.embed_vision.vocab_size, 119 | ) 120 | dummy_vision_token_id = ( 121 | self.embed_vision.vocab_offset + self.embed_vision.vocab_size - 1 122 | ) 123 | vision_tokens = mx.where(vision_mask, input_ids, dummy_vision_token_id) 124 | vision_embeds_flat = self.embed_vision(input_ids=vision_tokens) 125 | inputs_embeds = mx.where( 126 | vision_mask[..., None], vision_embeds_flat, inputs_embeds 127 | ) 128 | 129 | if self.using_legacy_weights: 130 | # The array is still in pytorch format here (NCHW) 131 | # Transpose the HW axes 132 | pixel_values = pixel_values.swapaxes(2, 3) 133 | 134 | # Process image through vision tower, then embed into language model space 135 | image_features = Gemma3nCombinedModel.get_image_features( 136 | pixel_values, 137 | self.vision_tower, 138 | self.config, 139 | self.embed_vision, 140 | ) 141 | 142 | # Construct mask that matches image embedding locations 143 | special_modality_mask = mx.expand_dims( 144 | input_ids == self.config.image_token_id, -1 145 | ) 146 | special_modality_mask = mx.broadcast_to( 147 | special_modality_mask, inputs_embeds.shape 148 | ) 149 | 150 | # Construct embeddings with image and text tokens interleaved per special modality mask 151 | final_inputs_embeds = Gemma3nCombinedModel.merge_multimodal_and_text( 152 | inputs_embeds, image_features, special_modality_mask, "image" 153 | ) 154 | # remove batch dimension 155 | return input_ids.squeeze(0), final_inputs_embeds.squeeze(0) 156 | -------------------------------------------------------------------------------- /mlx_engine/model_kit/vision_add_ons/lfm2_vl.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | 4 | from mlx import nn 5 | import mlx.core as mx 6 | 7 | from mlx_engine.model_kit.vision_add_ons.base import BaseVisionAddOn 8 | from mlx_vlm.models.lfm2_vl import ( 9 | VisionModel as LFM2VlVisionTower, 10 | ModelConfig as LFM2VlModelConfig, 11 | VisionConfig as LFM2VlVisionConfig, 12 | TextConfig as LFM2VlTextConfig, 13 | Model as LFM2VlModel, 14 | ) 15 | from mlx_vlm.models.lfm2_vl.lfm2_vl import ( 16 | Lfm2VlMultiModalProjector, 17 | PixelUnshuffleBlock, 18 | ) 19 | from mlx_engine.model_kit.vision_add_ons.process_prompt_with_images import ( 20 | common_process_prompt_with_images, 21 | ) 22 | from mlx_engine.model_kit.vision_add_ons.load_utils import load_vision_addon 23 | from transformers.image_utils import ChannelDimension 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | class LFM2VisionAddOn(BaseVisionAddOn): 29 | """ 30 | Vision add-on for LFM2 models. 
31 | """ 32 | 33 | def __init__(self, model_path: Path): 34 | """Initialize LFM2VisionAddOn with vision components loaded from the given path.""" 35 | super().__init__() 36 | 37 | self.vision_tower, self.multi_modal_projector, self.config, self.processor = ( 38 | load_vision_addon( 39 | model_path=model_path, 40 | model_config_class=LFM2VlModelConfig, 41 | vision_config_class=LFM2VlVisionConfig, 42 | text_config_class=LFM2VlTextConfig, 43 | vision_tower_class=LFM2VlVisionTower, 44 | multi_modal_projector_class=Lfm2VlMultiModalProjector, 45 | logger=logger, 46 | ) 47 | ) 48 | 49 | self._ensure_channel_first_if_fast_processor(self.processor) 50 | 51 | # this particular block comes from 52 | # https://github.com/Blaizzy/mlx-vlm/blob/f02d63e8f5b521e8c75f129a63d2660efd132693/mlx_vlm/models/lfm2_vl/lfm2_vl.py#L102-L105 53 | if self.config.downsample_factor > 1: 54 | self.pixel_unshuffle = PixelUnshuffleBlock(self.config.downsample_factor) 55 | else: 56 | self.pixel_unshuffle = nn.Identity() 57 | 58 | def compute_embeddings( 59 | self, 60 | text_model: nn.Module, 61 | prompt_tokens: mx.array, 62 | images_b64: list[str], 63 | max_size: tuple[int, int] | None, 64 | ) -> tuple[mx.array, mx.array]: 65 | """ 66 | Compute embeddings for text with images. 67 | 68 | This method is heavily based on the mlx-vlm's lfm2_vl `get_input_embeddings` 69 | https://github.com/Blaizzy/mlx-vlm/blob/f02d63e8f5b521e8c75f129a63d2660efd132693/mlx_vlm/models/lfm2_vl/lfm2_vl.py#L110-L150 70 | """ 71 | 72 | input_ids, pixel_values, attention_mask, other_model_inputs = ( 73 | common_process_prompt_with_images( 74 | prompt_tokens=prompt_tokens, 75 | images_b64=images_b64, 76 | processor=self.processor, 77 | config=self.config, 78 | max_size=max_size, 79 | ) 80 | ) 81 | 82 | # Get the input embeddings from the language model 83 | inputs_embeds = text_model.language_model.model.embed_tokens(input_ids) 84 | 85 | if pixel_values is None: 86 | return inputs_embeds 87 | 88 | spatial_shapes = other_model_inputs["spatial_shapes"] 89 | pixel_attention_mask = other_model_inputs["pixel_attention_mask"] 90 | 91 | # Get the ouptut hidden states from the vision model 92 | *_, hidden_states = self.vision_tower( 93 | pixel_values, output_hidden_states=True, spatial_shapes=spatial_shapes 94 | ) 95 | 96 | img_feature_lengths = pixel_attention_mask.sum(axis=1).tolist() 97 | image_features = [] 98 | 99 | for img_idx in range(hidden_states.shape[0]): 100 | feature = hidden_states[img_idx] 101 | 102 | feature = feature[: img_feature_lengths[img_idx], :][None, ...] 
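# (Illustrative shape walk-through with made-up sizes, not values from a real checkpoint:
# if this image contributed a 16x24 grid of patches with hidden size 1024, `feature` is
# now (1, 16*24, 1024). The reshape below recovers (1, 16, 24, 1024); with
# downsample_factor=2 the pixel-unshuffle step folds each 2x2 block of patches into the
# channel dimension, giving (1, 8, 12, 4096), and the projector output is then flattened
# to one row per image token, i.e. 8*12 rows.)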
103 | 104 | feature_org_h, feature_org_w = spatial_shapes[img_idx] 105 | feature = feature.reshape(1, feature_org_h, feature_org_w, -1) 106 | feature = self.pixel_unshuffle(feature) 107 | 108 | img_embedding = self.multi_modal_projector(feature) 109 | 110 | img_embedding = img_embedding.reshape(-1, img_embedding.shape[-1]) 111 | image_features.append(img_embedding) 112 | 113 | image_features = mx.concatenate(image_features, axis=0) 114 | 115 | final_inputs_embeds = LFM2VlModel.merge_input_ids_with_image_features( 116 | image_features, inputs_embeds, input_ids, self.config.image_token_index 117 | ) 118 | 119 | if input_ids.shape[1] == final_inputs_embeds.shape[1]: 120 | return input_ids.squeeze(0), final_inputs_embeds.squeeze(0) 121 | return input_ids, final_inputs_embeds 122 | 123 | @staticmethod 124 | def _ensure_channel_first_if_fast_processor(processor) -> None: 125 | """Override input_data_format is "channels_first" to avoid double permutes.""" 126 | image_processor = getattr(processor, "image_processor", None) 127 | if image_processor and getattr(image_processor, "is_fast", False): 128 | # LFM2 model shipped with preprocessor_config.json "input_data_format": "channels_last" 129 | # ref: https://huggingface.co/mlx-community/LFM2-VL-450M-4bit/blob/main/preprocessor_config.json#L20 130 | 131 | # mlx-vlm sets use_fast=True when loading the processor 132 | # ref: https://github.com/Blaizzy/mlx-vlm/blob/1d8622b061cd39b7af6738500c60804a3f171095/mlx_vlm/utils.py#L367 133 | 134 | # In transformers fast processors, torchvision pil_to_tensor call naturally permutes 135 | # from "channels_last" -> "channels_first" 136 | # ref: https://github.com/huggingface/transformers/blob/dd24a80666b72c85f02c6cf9df18164cc174ab74/src/transformers/image_processing_utils_fast.py#L689-L690 137 | # ref: https://github.com/pytorch/vision/blob/96e779759a883651e6ec2b394bf89de8beb5b709/torchvision/transforms/functional.py#L211 138 | 139 | # transformers fast processors still sees "channels_last" in config, so still forces a 140 | # permute after pil_to_tensor, causing an incorrect double-permute. 
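# (Concretely, as a sketch with an assumed 512x512 RGB input: pil_to_tensor already
# yields a (3, 512, 512) channels-first tensor, and a second permute driven by the stale
# "channels_last" setting would reorder axes that are already in place, leaving something
# like a (512, 3, 512) layout that no longer matches any valid image format.)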
141 | # ref: https://github.com/huggingface/transformers/blob/dd24a80666b72c85f02c6cf9df18164cc174ab74/src/transformers/image_processing_utils_fast.py#L703-L705 142 | 143 | # So since we ship with torchvision, set to ChannelDimension.FIRST to avoid this 144 | image_processor.input_data_format = ChannelDimension.FIRST 145 | -------------------------------------------------------------------------------- /tests/test_stop_string_processor.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from pathlib import Path 3 | 4 | from mlx_engine.utils.disable_hf_download import _original_snapshot_download 5 | import mlx_lm 6 | 7 | from mlx_engine.stop_string_processor import StopStringProcessor 8 | 9 | 10 | class TestStopStringProcessor(unittest.TestCase): 11 | @classmethod 12 | def setUpClass(cls): 13 | """Set up any necessary resources that can be shared across all tests.""" 14 | # use Llama-3.1 tokenizer for testing 15 | cls.tokenizer = cls.download_tokenizer( 16 | "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit" 17 | ) 18 | 19 | # followed pattern from mlx-examples 20 | # https://github.com/ml-explore/mlx-examples/blob/cfc29c29f45372c78876335a44b0c99ab6565ae0/llms/tests/test_tokenizers.py#L17 21 | @staticmethod 22 | def download_tokenizer(repo): 23 | path = Path( 24 | _original_snapshot_download( 25 | repo_id=repo, 26 | allow_patterns=[ 27 | "tokenizer.json", 28 | "tokenizer_config.json", 29 | "special_tokens_map.json", 30 | "tokenizer.model", 31 | ], 32 | ) 33 | ) 34 | return mlx_lm.tokenizer_utils.load(path) 35 | 36 | def process_tokens(self, stop_strings, input_string): 37 | """Helper method to process tokens and collect results""" 38 | processor = StopStringProcessor(stop_strings, self.tokenizer) 39 | input_tokens = self.tokenizer.encode(input_string) 40 | results = [] 41 | for token in input_tokens: 42 | result = processor.process_token(token) 43 | results.append(result) 44 | if result.status == "full_stop": 45 | break 46 | return results 47 | 48 | def test_stop_string_processor_simple(self): 49 | results = self.process_tokens(["of"], "The objective of chess") 50 | 51 | self.assertEqual(results[-1].status, "full_stop") 52 | self.assertEqual(results[-1].stop_string, "of") 53 | self.assertEqual(results[-1].stop_tokens, self.tokenizer.encode(" of")) 54 | 55 | def test_stop_string_at_start(self): 56 | results = self.process_tokens(["Hello"], "Hello world") 57 | self.assertEqual(results[-1].status, "full_stop") 58 | self.assertEqual(results[-1].stop_string, "Hello") 59 | self.assertEqual(results[-1].stop_tokens, self.tokenizer.encode("Hello")) 60 | 61 | def test_stop_string_at_end(self): 62 | results = self.process_tokens(["world"], "Hello world") 63 | self.assertEqual(results[-1].status, "full_stop") 64 | self.assertEqual(results[-1].stop_string, "world") 65 | self.assertEqual(results[-1].stop_tokens, self.tokenizer.encode(" world")) 66 | 67 | def test_case_sensitivity(self): 68 | results = self.process_tokens(["Stop"], "This is a STOP sign") 69 | self.assertEqual(results[-1].status, "no_match") 70 | 71 | def test_stop_string_with_special_characters(self): 72 | results = self.process_tokens(["\n"], "Hello\nworld") 73 | self.assertEqual(results[-1].status, "full_stop") 74 | self.assertEqual(results[-1].stop_string, "\n") 75 | self.assertEqual(results[-1].stop_tokens, self.tokenizer.encode("\n")) 76 | 77 | def test_unicode_stop_strings(self): 78 | results = self.process_tokens(["é", "ñ", "北京"], "Hello 北京 é ñ") 79 | 
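# "北京" is expected to win here even though "é" and "ñ" are also stop strings: the
# processor stops on the earliest full match in the decoded text, and while a multi-byte
# character is still arriving token-by-token it reports "multi_byte" rather than
# declaring a premature match.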
self.assertEqual(results[-1].status, "full_stop") 80 | self.assertEqual(results[-1].stop_string, "北京") 81 | self.assertEqual(results[-1].stop_tokens, self.tokenizer.encode(" 北京")) 82 | 83 | def test_stop_string_processor_no_match(self): 84 | results = self.process_tokens(["other"], "The objective of chess") 85 | 86 | for i, result in enumerate(results): 87 | self.assertEqual( 88 | result.status, 89 | "no_match", 90 | f"Result at position {i} has status '{result.status}' instead of 'no_match'", 91 | ) 92 | 93 | def test_stop_string_processor_long_no_match(self): 94 | results = self.process_tokens( 95 | ["The objective of checkers"], "The objective of chess" 96 | ) 97 | 98 | for i, result in enumerate(results[:-1]): 99 | self.assertEqual( 100 | result.status, 101 | "partial_match", 102 | f"Result at position {i} has status '{result.status}' instead of 'partial_match'", 103 | ) 104 | self.assertEqual(results[-1].status, "no_match") 105 | 106 | def test_stop_string_processor_mid_word(self): 107 | results = self.process_tokens(["cti"], "The objective of chess") 108 | 109 | self.assertEqual(results[-1].status, "full_stop") 110 | self.assertEqual(results[-1].stop_string, "cti") 111 | self.assertEqual(results[-1].stop_tokens, self.tokenizer.encode(" objective")) 112 | 113 | def test_stop_string_processor_multi_token_multi_word(self): 114 | results = self.process_tokens(["objective of"], "The objective of chess") 115 | 116 | self.assertEqual(results[-1].status, "full_stop") 117 | self.assertEqual(results[-1].stop_string, "objective of") 118 | self.assertEqual( 119 | results[-1].stop_tokens, self.tokenizer.encode(" objective of") 120 | ) 121 | 122 | def test_stop_string_processor_multi_token_multi_token_single_char(self): 123 | results = self.process_tokens(["🌟"], "The objective 🌟 of chess") 124 | 125 | self.assertEqual(results[-1].status, "full_stop") 126 | self.assertEqual(results[-1].stop_string, "🌟") 127 | self.assertEqual(results[-1].stop_tokens, self.tokenizer.encode(" 🌟")) 128 | 129 | self.assertEqual(results[-3].status, "multi_byte") 130 | self.assertEqual(results[-2].status, "multi_byte") 131 | 132 | def test_multiple_stop_strings(self): 133 | results = self.process_tokens( 134 | ["of", "chess", "objective"], "The objective of chess" 135 | ) 136 | 137 | self.assertEqual(results[-1].status, "full_stop") 138 | self.assertEqual(results[-1].stop_string, "objective") 139 | self.assertEqual(results[-1].stop_tokens, self.tokenizer.encode(" objective")) 140 | 141 | def test_overlapping_stop_strings(self): 142 | results = self.process_tokens( 143 | ["objective of", "of chess"], "The objective of chess" 144 | ) 145 | 146 | self.assertEqual(results[-1].status, "full_stop") 147 | self.assertEqual(results[-1].stop_string, "objective of") 148 | self.assertEqual( 149 | results[-1].stop_tokens, self.tokenizer.encode(" objective of") 150 | ) 151 | 152 | def test_empty_stop_strings_list_raises(self): 153 | with self.assertRaises(ValueError): 154 | StopStringProcessor([], self.tokenizer) 155 | 156 | def test_non_string_stop_string_raises(self): 157 | stop_strings = ["valid", 123] 158 | with self.assertRaises(TypeError): 159 | StopStringProcessor(stop_strings, self.tokenizer) 160 | 161 | def test_none_stop_string_raises(self): 162 | stop_strings = ["valid", None] 163 | with self.assertRaises(TypeError): 164 | StopStringProcessor(stop_strings, self.tokenizer) 165 | 166 | def test_empty_stop_string_raises(self): 167 | stop_strings = ["valid", ""] 168 | with self.assertRaises(ValueError): 169 | 
StopStringProcessor(stop_strings, self.tokenizer) 170 | -------------------------------------------------------------------------------- /mlx_engine/stop_string_processor.py: -------------------------------------------------------------------------------- 1 | """Module for processing and handling stop strings during token generation.""" 2 | 3 | from typing import List, Literal, NamedTuple, Optional, Sequence 4 | 5 | StopStringProcessorStatus = Literal[ 6 | "full_stop", "partial_match", "no_match", "multi_byte" 7 | ] 8 | 9 | REPLACEMENT_CHAR = "\ufffd" 10 | 11 | 12 | class StopStringProcessorResult(NamedTuple): 13 | """Result of stop string processing containing status and details.""" 14 | 15 | status: StopStringProcessorStatus 16 | stop_string: Optional[str] = None # populated if status is "full_stop" 17 | # sequence of tokens that the stop_string was found in 18 | stop_tokens: Optional[List[int]] = None # populated if status is "full_stop" 19 | 20 | 21 | class StopStringProcessor: 22 | """Statefully processes tokens to check for stop strings during generation.""" 23 | 24 | def __init__(self, stop_strings: List[str], tokenizer): 25 | """ 26 | Args: 27 | stop_strings: List of strings that should stop generation if found 28 | tokenizer: Tokenizer instance used to decode token IDs into text 29 | 30 | Raises: 31 | ValueError: If stop_strings is empty or contains invalid values 32 | TypeError: If stop_strings contains non-string values 33 | """ 34 | if not stop_strings: 35 | raise ValueError("Must provide at least one stop string") 36 | 37 | if not all(isinstance(s, str) for s in stop_strings): 38 | raise TypeError("All stop strings must be strings") 39 | 40 | if any(not stop_string for stop_string in stop_strings): 41 | raise ValueError("Stop strings cannot be empty") 42 | 43 | self.stop_strings = stop_strings 44 | self.tokenizer = tokenizer 45 | self.token_id_buffer = [] 46 | 47 | def process_token(self, token: int) -> StopStringProcessorResult: 48 | """Process a newly generated token and check how it relates to the stop strings.
49 | 50 | Args: 51 | token: The id of the newly generated token to process 52 | 53 | Returns: 54 | StopStringProcessorResult indicating the state of stop string detection 55 | """ 56 | if len(self.stop_strings) == 0: 57 | return StopStringProcessorResult( 58 | status="no_match", stop_string=None, stop_tokens=None 59 | ) 60 | 61 | self.token_id_buffer.append(token) 62 | 63 | result = self._stopping_criteria( 64 | string=self.tokenizer.decode(self.token_id_buffer), 65 | stop_strings=self.stop_strings, 66 | ) 67 | 68 | if result.status == "no_match": 69 | # Can clear the buffer when there are no partial or full matches with stop strings 70 | self.token_id_buffer = [] 71 | return StopStringProcessorResult( 72 | status="no_match", stop_string=None, stop_tokens=None 73 | ) 74 | 75 | elif result.status == "partial_match": 76 | return StopStringProcessorResult( 77 | status="partial_match", stop_string=None, stop_tokens=None 78 | ) 79 | 80 | elif result.status == "multi_byte": 81 | return StopStringProcessorResult( 82 | status="multi_byte", stop_string=None, stop_tokens=None 83 | ) 84 | 85 | elif result.status == "full_stop": 86 | return StopStringProcessorResult( 87 | status="full_stop", 88 | stop_string=result.stop_string, 89 | stop_tokens=self.token_id_buffer, 90 | ) 91 | 92 | else: 93 | raise ValueError(f"Unknown StopStringProcessorStatus: {result.status}") 94 | 95 | class _StoppingCriteriaResult(NamedTuple): 96 | status: StopStringProcessorStatus 97 | stop_string: Optional[str] = None # populated if status is "full_stop" 98 | 99 | def _stopping_criteria( 100 | self, 101 | string: str, 102 | stop_strings: List[str], 103 | ) -> _StoppingCriteriaResult: 104 | """Check if stop strings match or partially match the input string 105 | 106 | Args: 107 | string: The string to check for stop strings 108 | stop_strings: List of strings that should stop generation if found 109 | 110 | Returns: 111 | _StoppingCriteriaResult indicating match status and stop string if matched 112 | 113 | Checks stopping conditions in priority order: 114 | 1. Incomplete UTF-8 string 115 | 2. Exact stop string match 116 | 3.
Partial stop string match 117 | """ 118 | 119 | result = ( 120 | self._check_incomplete_utf8(string) 121 | or self._check_full_text_match(string, stop_strings) 122 | or self._check_partial_text_match(string, stop_strings) 123 | or self._StoppingCriteriaResult(status="no_match", stop_string=None) 124 | ) 125 | 126 | return result 127 | 128 | def _check_incomplete_utf8(self, string: str) -> Optional[_StoppingCriteriaResult]: 129 | if len(string) == 0 or string[-1] == REPLACEMENT_CHAR: 130 | return self._StoppingCriteriaResult(status="multi_byte", stop_string=None) 131 | return None 132 | 133 | def _check_full_text_match( 134 | self, string: str, stop_strings: List[str] 135 | ) -> Optional[_StoppingCriteriaResult]: 136 | """Find earliest full text match of any stop sequence.""" 137 | earliest_match = {"position": float("inf"), "stop_string": None} 138 | 139 | for stop_string in stop_strings: 140 | position = string.find(stop_string) 141 | 142 | if position != -1 and position < earliest_match["position"]: 143 | earliest_match.update( 144 | {"position": position, "stop_string": stop_string} 145 | ) 146 | 147 | if earliest_match["stop_string"] is not None: 148 | return self._StoppingCriteriaResult( 149 | status="full_stop", stop_string=earliest_match["stop_string"] 150 | ) 151 | return None 152 | 153 | def check_partial_token_match( 154 | self, token_sequence: List[int], stop_token_sequences: List[List[int]] 155 | ) -> Optional[_StoppingCriteriaResult]: 156 | """Check for partial matches with any stop sequence.""" 157 | for stop_token_sequence in stop_token_sequences: 158 | if self._sequence_overlap(token_sequence, stop_token_sequence): 159 | return self._StoppingCriteriaResult( 160 | status="partial_match", stop_string=None 161 | ) 162 | return None 163 | 164 | def _check_partial_text_match( 165 | self, string: str, stop_strings: List[str] 166 | ) -> Optional[StopStringProcessorResult]: 167 | """Check for partial matches with any stop sequence.""" 168 | for stop_string in stop_strings: 169 | if self._sequence_overlap(string, stop_string): 170 | return StopStringProcessorResult( 171 | status="partial_match", stop_string=None 172 | ) 173 | return None 174 | 175 | def _sequence_overlap(self, s1: Sequence, s2: Sequence) -> bool: 176 | """ 177 | Checks if a suffix of s1 has overlap with a prefix of s2 178 | 179 | Args: 180 | s1 (Sequence): The first sequence 181 | s2 (Sequence): The second sequence 182 | 183 | Returns: 184 | bool: If the two sequences have overlap 185 | """ 186 | max_overlap = min(len(s1), len(s2)) 187 | return any(s1[-i:] == s2[:i] for i in range(1, max_overlap + 1)) 188 | -------------------------------------------------------------------------------- /mlx_engine/external/datasets/dill.py: -------------------------------------------------------------------------------- 1 | # copied from https://github.com/huggingface/datasets/blob/1e1d313/src/datasets/utils/_dill.py 2 | 3 | # Copyright 2023 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """Extends `dill` to support pickling more types and produce more consistent dumps.""" 17 | 18 | import sys 19 | from io import BytesIO 20 | from types import FunctionType 21 | from typing import Any, Dict, List, Union 22 | 23 | import dill 24 | import xxhash 25 | 26 | 27 | class Hasher: 28 | """Hasher that accepts python objects as inputs.""" 29 | 30 | dispatch: Dict = {} 31 | 32 | def __init__(self): 33 | self.m = xxhash.xxh64() 34 | 35 | @classmethod 36 | def hash_bytes(cls, value: Union[bytes, List[bytes]]) -> str: 37 | value = [value] if isinstance(value, bytes) else value 38 | m = xxhash.xxh64() 39 | for x in value: 40 | m.update(x) 41 | return m.hexdigest() 42 | 43 | @classmethod 44 | def hash(cls, value: Any) -> str: 45 | return cls.hash_bytes(dumps(value)) 46 | 47 | def update(self, value: Any) -> None: 48 | header_for_update = f"=={type(value)}==" 49 | value_for_update = self.hash(value) 50 | self.m.update(header_for_update.encode("utf8")) 51 | self.m.update(value_for_update.encode("utf-8")) 52 | 53 | def hexdigest(self) -> str: 54 | return self.m.hexdigest() 55 | 56 | 57 | class Pickler(dill.Pickler): 58 | dispatch = dill._dill.MetaCatchingDict(dill.Pickler.dispatch.copy()) 59 | _legacy_no_dict_keys_sorting = False 60 | 61 | def save(self, obj, save_persistent_id=True): 62 | obj_type = type(obj) 63 | if obj_type not in self.dispatch: 64 | if "regex" in sys.modules: 65 | import regex # type: ignore 66 | 67 | if obj_type is regex.Pattern: 68 | pklregister(obj_type)(_save_regexPattern) 69 | if "spacy" in sys.modules: 70 | import spacy # type: ignore 71 | 72 | if issubclass(obj_type, spacy.Language): 73 | pklregister(obj_type)(_save_spacyLanguage) 74 | if "tiktoken" in sys.modules: 75 | import tiktoken # type: ignore 76 | 77 | if obj_type is tiktoken.Encoding: 78 | pklregister(obj_type)(_save_tiktokenEncoding) 79 | if "torch" in sys.modules: 80 | import torch # type: ignore 81 | 82 | if issubclass(obj_type, torch.Tensor): 83 | pklregister(obj_type)(_save_torchTensor) 84 | 85 | if obj_type is torch.Generator: 86 | pklregister(obj_type)(_save_torchGenerator) 87 | 88 | # Unwrap `torch.compile`-ed modules 89 | if issubclass(obj_type, torch.nn.Module): 90 | obj = getattr(obj, "_orig_mod", obj) 91 | if "transformers" in sys.modules: 92 | import transformers # type: ignore 93 | 94 | if issubclass(obj_type, transformers.PreTrainedTokenizerBase): 95 | pklregister(obj_type)(_save_transformersPreTrainedTokenizerBase) 96 | 97 | # Unwrap `torch.compile`-ed functions 98 | if obj_type is FunctionType: 99 | obj = getattr(obj, "_torchdynamo_orig_callable", obj) 100 | dill.Pickler.save(self, obj, save_persistent_id=save_persistent_id) 101 | 102 | def _batch_setitems(self, items): 103 | if self._legacy_no_dict_keys_sorting: 104 | return super()._batch_setitems(items) 105 | # Ignore the order of keys in a dict 106 | try: 107 | # Faster, but fails for unorderable elements 108 | items = sorted(items) 109 | except Exception: # TypeError, decimal.InvalidOperation, etc. 
110 | items = sorted(items, key=lambda x: Hasher.hash(x[0])) 111 | dill.Pickler._batch_setitems(self, items) 112 | 113 | def memoize(self, obj): 114 | # Don't memoize strings since two identical strings can have different Python ids 115 | if type(obj) is not str: # noqa: E721 116 | dill.Pickler.memoize(self, obj) 117 | 118 | 119 | def pklregister(t): 120 | """Register a custom reducer for the type.""" 121 | 122 | def proxy(func): 123 | Pickler.dispatch[t] = func 124 | return func 125 | 126 | return proxy 127 | 128 | 129 | def dump(obj, file): 130 | """Pickle an object to a file.""" 131 | Pickler(file, recurse=True).dump(obj) 132 | 133 | 134 | def dumps(obj): 135 | """Pickle an object to a string.""" 136 | file = BytesIO() 137 | dump(obj, file) 138 | return file.getvalue() 139 | 140 | 141 | def log(pickler, msg): 142 | pass 143 | 144 | 145 | def _save_regexPattern(pickler, obj): 146 | import regex # type: ignore 147 | 148 | log(pickler, f"Re: {obj}") 149 | args = (obj.pattern, obj.flags) 150 | pickler.save_reduce(regex.compile, args, obj=obj) 151 | log(pickler, "# Re") 152 | 153 | 154 | def _save_tiktokenEncoding(pickler, obj): 155 | import tiktoken # type: ignore 156 | 157 | log(pickler, f"Enc: {obj}") 158 | args = (obj.name, obj._pat_str, obj._mergeable_ranks, obj._special_tokens) 159 | pickler.save_reduce(tiktoken.Encoding, args, obj=obj) 160 | log(pickler, "# Enc") 161 | 162 | 163 | def _save_torchTensor(pickler, obj): 164 | import torch # type: ignore 165 | 166 | # `torch.from_numpy` is not picklable in `torch>=1.11.0` 167 | def create_torchTensor(np_array, dtype=None): 168 | tensor = torch.from_numpy(np_array) 169 | if dtype: 170 | tensor = tensor.type(dtype) 171 | return tensor 172 | 173 | log(pickler, f"To: {obj}") 174 | if obj.dtype == torch.bfloat16: 175 | args = (obj.detach().to(torch.float).cpu().numpy(), torch.bfloat16) 176 | else: 177 | args = (obj.detach().cpu().numpy(),) 178 | pickler.save_reduce(create_torchTensor, args, obj=obj) 179 | log(pickler, "# To") 180 | 181 | 182 | def _save_torchGenerator(pickler, obj): 183 | import torch # type: ignore 184 | 185 | def create_torchGenerator(state): 186 | generator = torch.Generator() 187 | generator.set_state(state) 188 | return generator 189 | 190 | log(pickler, f"Ge: {obj}") 191 | args = (obj.get_state(),) 192 | pickler.save_reduce(create_torchGenerator, args, obj=obj) 193 | log(pickler, "# Ge") 194 | 195 | 196 | def _save_spacyLanguage(pickler, obj): 197 | import spacy # type: ignore 198 | 199 | def create_spacyLanguage(config, bytes): 200 | lang_cls = spacy.util.get_lang_class(config["nlp"]["lang"]) 201 | lang_inst = lang_cls.from_config(config) 202 | return lang_inst.from_bytes(bytes) 203 | 204 | log(pickler, f"Sp: {obj}") 205 | args = (obj.config, obj.to_bytes()) 206 | pickler.save_reduce(create_spacyLanguage, args, obj=obj) 207 | log(pickler, "# Sp") 208 | 209 | 210 | def _save_transformersPreTrainedTokenizerBase(pickler, obj): 211 | log(pickler, f"Tok: {obj}") 212 | # Ignore the `cache` attribute 213 | state = obj.__dict__ 214 | if "cache" in state and isinstance(state["cache"], dict): 215 | state["cache"] = {} 216 | pickler.save_reduce(type(obj), (), state=state, obj=obj) 217 | log(pickler, "# Tok") 218 | -------------------------------------------------------------------------------- /mlx_engine/external/models/ernie4_5/tokenization_ernie4_5.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Baidu, Inc. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | from shutil import copyfile 17 | from typing import List, Optional, Tuple 18 | import sentencepiece as spm 19 | 20 | from transformers import PreTrainedTokenizer 21 | from transformers.utils import logging 22 | 23 | 24 | logger = logging.get_logger(__name__) 25 | 26 | 27 | class Ernie4_5_Tokenizer(PreTrainedTokenizer): 28 | 29 | vocab_files_names = { 30 | "vocab_file": "tokenizer.model", 31 | } 32 | # Model input names expected by the tokenizer 33 | model_input_names = ["input_ids", "position_ids", "attention_mask", "labels"] 34 | # Padding side (where to add padding tokens) 35 | padding_side = "right" 36 | 37 | def __init__( 38 | self, 39 | vocab_file, 40 | bos_token="", 41 | cls_token="", 42 | eos_token="", 43 | mask_token="", 44 | pad_token="", 45 | sep_token="", 46 | unk_token="", 47 | additional_special_tokens=None, 48 | verbose=False, 49 | **kwargs, 50 | ): 51 | """ 52 | Initialize the ERNIE tokenizer. 53 | 54 | Args: 55 | vocab_file (str): Path to the SentencePiece model file. 56 | bos_token (str, optional): Beginning of sentence token. Defaults to "". 57 | cls_token (str, optional): Classification token. Defaults to "". 58 | eos_token (str, optional): End of sentence token. Defaults to "". 59 | mask_token (str, optional): Mask token. Defaults to "". 60 | pad_token (str, optional): Padding token. Defaults to "". 61 | sep_token (str, optional): Separator token. Defaults to "". 62 | unk_token (str, optional): Unknown token. Defaults to "". 63 | additional_special_tokens (List[str], optional): Additional special tokens. 64 | Defaults to ["", ""]. 65 | verbose (bool, optional): Whether to print detailed logs or progress information during execution. 66 | **kwargs: Additional keyword arguments passed to the parent class. 67 | """ 68 | 69 | self.vocab_file = vocab_file 70 | self.sp_model = spm.SentencePieceProcessor() 71 | self.sp_model.Load(vocab_file) 72 | 73 | if additional_special_tokens is None: 74 | additional_special_tokens = ["", ""] 75 | super().__init__( 76 | bos_token=bos_token, 77 | cls_token=cls_token, 78 | eos_token=eos_token, 79 | mask_token=mask_token, 80 | pad_token=pad_token, 81 | sep_token=sep_token, 82 | unk_token=unk_token, 83 | additional_special_tokens=additional_special_tokens, 84 | verbose=verbose, 85 | **kwargs, 86 | ) 87 | 88 | @property 89 | def vocab_size(self): 90 | """Returns the size of the vocabulary. 91 | 92 | Returns: 93 | int: The number of tokens in the vocabulary. 94 | """ 95 | return self.sp_model.vocab_size() 96 | 97 | def get_vocab(self): 98 | """Get the vocabulary as a dictionary mapping tokens to their IDs. 99 | 100 | Returns: 101 | dict: A dictionary mapping tokens to their corresponding IDs. 102 | """ 103 | vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} 104 | vocab.update(self.added_tokens_encoder) 105 | return vocab 106 | 107 | def _tokenize(self, text): 108 | """Tokenize text using SentencePiece. 
109 | 110 | Args: 111 | text (str): The text to tokenize. 112 | 113 | Returns: 114 | list: A list of tokens. 115 | """ 116 | return self.sp_model.encode_as_pieces(text) 117 | 118 | def _convert_token_to_id(self, token): 119 | """Convert a token (str) to an ID using the vocabulary. 120 | 121 | Args: 122 | token (str): The token to convert. 123 | 124 | Returns: 125 | int: The corresponding token ID. 126 | """ 127 | return self.sp_model.piece_to_id(token) 128 | 129 | def _convert_id_to_token(self, id): 130 | """Convert an ID to a token (str) using the vocabulary. 131 | 132 | Args: 133 | id (int): The token ID to convert. 134 | 135 | Returns: 136 | str: The corresponding token. 137 | """ 138 | if id >= self.vocab_size: 139 | return self.unk_token 140 | else: 141 | return self.sp_model.id_to_piece(id) 142 | 143 | def convert_tokens_to_string(self, tokens): 144 | """Convert a sequence of tokens back to a single string. 145 | 146 | Args: 147 | tokens (List[str]): A list of tokens to convert. 148 | 149 | Returns: 150 | str: The reconstructed string. 151 | """ 152 | current_sub_tokens = [] 153 | out_string = "" 154 | for token in tokens: 155 | # make sure that special tokens are not decoded using sentencepiece model 156 | if token in self.all_special_tokens: 157 | out_string += self.sp_model.decode(current_sub_tokens) + token 158 | current_sub_tokens = [] 159 | else: 160 | current_sub_tokens.append(token) 161 | out_string += self.sp_model.decode(current_sub_tokens) 162 | return out_string 163 | 164 | def prepare_for_model(self, *args, **kwargs): 165 | if "add_special_tokens" in kwargs: 166 | kwargs.pop("add_special_tokens") 167 | return super().prepare_for_model(*args, **kwargs) 168 | 169 | def save_vocabulary( 170 | self, save_directory, filename_prefix: Optional[str] = None 171 | ) -> Tuple[str]: 172 | """ 173 | Save the vocabulary and special tokens file to a directory. 174 | 175 | Args: 176 | save_directory (str): The directory in which to save the vocabulary. 177 | filename_prefix (Optional[str]): Optional prefix for the saved filename. 178 | 179 | Returns: 180 | Tuple[str]: Paths to the files saved. 181 | 182 | Raises: 183 | ValueError: If the save_directory is not a valid directory. 184 | """ 185 | if not os.path.isdir(save_directory): 186 | logger.error(f"Vocabulary path ({save_directory}) should be a directory") 187 | return 188 | out_vocab_file = os.path.join( 189 | save_directory, 190 | (filename_prefix + "-" if filename_prefix else "") 191 | + self.vocab_files_names["vocab_file"], 192 | ) 193 | 194 | if os.path.abspath(self.vocab_file) != os.path.abspath( 195 | out_vocab_file 196 | ) and os.path.isfile(self.vocab_file): 197 | copyfile(self.vocab_file, out_vocab_file) 198 | elif not os.path.isfile(self.vocab_file): 199 | with open(out_vocab_file, "wb") as fi: 200 | content_spiece_model = self.sp_model.serialized_model_proto() 201 | fi.write(content_spiece_model) 202 | 203 | return (out_vocab_file,) 204 | 205 | def _decode(self, *args, **kwargs): 206 | kwargs.pop("clean_up_tokenization_spaces", None) 207 | kwargs.pop("spaces_between_special_tokens", None) 208 | return super()._decode( 209 | *args, 210 | **kwargs, 211 | clean_up_tokenization_spaces=False, 212 | spaces_between_special_tokens=False, 213 | ) 214 | 215 | -------------------------------------------------------------------------------- /mlx_engine/external/models/ernie4_5_moe/configuration_ernie4_5_moe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Baidu, Inc. 
All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from transformers import PretrainedConfig 16 | 17 | 18 | 19 | class Ernie4_5_MoeConfig(PretrainedConfig): 20 | r""" 21 | This is the configuration class to store the configuration of a [`Ernie4_5_Model`]. 22 | 23 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 24 | documentation from [`PretrainedConfig`] for more information. 25 | 26 | 27 | Args: 28 | vocab_size (int): Size of the vocabulary (number of unique tokens) 29 | hidden_size (int): Dimensionality of the encoder layers and the pooler layer 30 | intermediate_size (int): Dimensionality of the "intermediate" (feed-forward) layer 31 | max_position_embeddings (int): Maximum sequence length the model can handle 32 | num_hidden_layers (int): Number of hidden layers in the Transformer encoder 33 | num_attention_heads (int): Number of attention heads for each attention layer 34 | rms_norm_eps (float): The epsilon used by the RMS normalization layers 35 | use_cache (bool): Whether to use caching for faster generation (decoding) 36 | use_flash_attention (bool): Whether to use FlashAttention for optimized attention computation 37 | pad_token_id (int): Token ID used for padding sequences 38 | bos_token_id (int): Token ID used for beginning-of-sequence 39 | eos_token_id (int): Token ID used for end-of-sequence 40 | use_bias (bool): Whether to use bias terms in linear layers 41 | rope_theta (float): The base period of the RoPE embeddings 42 | weight_share_add_bias (bool): Whether to share bias weights in certain layers 43 | ignored_index (int): Target value that is ignored during loss computation 44 | attention_probs_dropout_prob (float): Dropout probability for attention weights 45 | hidden_dropout_prob (float): Dropout probability for hidden layers 46 | num_key_value_heads (int): Number of key/value heads (for Grouped Query Attention) 47 | max_sequence_length (int): Maximum sequence length for positional embeddings 48 | moe_num_experts: Number of experts in MoE layers 49 | moe_capacity: Capacity configuration for MoE layers 50 | moe_layer_interval: Interval between MoE layers 51 | moe_layer_start_index: Starting layer index for MoE 52 | moe_layer_end_index: Ending layer index for MoE (-1 means last layer) 53 | sinkhorn_2gate: Whether to use sinkhorn 2-gate routing 54 | sinkhorn_temp: Temperature for sinkhorn routing 55 | moe_dropout_prob: Dropout probability for MoE layers 56 | moe_gate: Type of gating mechanism ('top2', etc.) 
57 | moe_intermediate_size: Intermediate size for MoE layers 58 | moe_gate_act: Activation function for gating 59 | moe_k: Number of experts to route to 60 | num_nextn_predict_layers: Number of mtp predict layers, if use mtp, set `num_nextn_predict_layers > 0` 61 | multi_token_pred_lambda: The weight of multi token prediction loss 62 | **kwargs: Additional base model configuration parameters 63 | """ 64 | 65 | model_type = "ernie4_5_moe" 66 | use_keep_in_fp32_modules = True 67 | keys_to_ignore_at_inference = ["past_key_values"] 68 | 69 | attribute_map = { 70 | "n_positions": "max_position_embeddings", 71 | "n_embd": "hidden_size", 72 | "n_layer": "num_hidden_layers", 73 | "n_head": "num_attention_heads", 74 | "n_inner": "intermediate_size", 75 | "activation_function": "hidden_act", 76 | } 77 | 78 | # Default tensor parallel plan for base model `ernie_4_5_moe` 79 | base_model_tp_plan = { 80 | "model.layers.*.self_attn.q_proj": "colwise_rep", 81 | "model.layers.*.self_attn.k_proj": "colwise_rep", 82 | "model.layers.*.self_attn.v_proj": "colwise_rep", 83 | "model.layers.*.self_attn.o_proj": "rowwise_rep", 84 | "model.layers.*.mlp.experts.*.gate_proj": "colwise", 85 | "model.layers.*.mlp.experts.*.up_proj": "colwise", 86 | "model.layers.*.mlp.experts.*.down_proj": "rowwise", 87 | "model.layers.*.mlp.gate_proj": "colwise", 88 | "model.layers.*.mlp.up_proj": "colwise", 89 | "model.layers.*.mlp.down_proj": "rowwise", 90 | } 91 | base_model_pp_plan = { 92 | "embed_tokens": (["input_ids"], ["inputs_embeds"]), 93 | "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), 94 | "norm": (["hidden_states"], ["hidden_states"]), 95 | } 96 | 97 | def __init__( 98 | self, 99 | vocab_size=32000, 100 | hidden_size=768, 101 | intermediate_size=11008, 102 | num_hidden_layers=2, 103 | num_attention_heads=2, 104 | num_key_value_heads=None, 105 | max_position_embeddings=32768, 106 | rms_norm_eps=1e-6, 107 | use_cache=False, 108 | pad_token_id=0, 109 | bos_token_id=1, 110 | eos_token_id=2, 111 | attention_probs_dropout_prob=0.0, 112 | hidden_dropout_prob=0.0, 113 | rope_theta=10000.0, 114 | use_flash_attention=False, 115 | use_rmsnorm=True, 116 | use_bias=False, 117 | weight_share_add_bias=True, 118 | max_sequence_length=None, 119 | ignored_index=-100, 120 | use_moe=True, 121 | moe_num_experts=64, 122 | moe_capacity=(64, 64, 64), 123 | moe_layer_interval=2, 124 | moe_layer_start_index=0, 125 | moe_layer_end_index=-1, 126 | sinkhorn_2gate=True, 127 | sinkhorn_temp=3e-2, 128 | moe_dropout_prob=0.0, 129 | moe_gate="top2", 130 | moe_intermediate_size=3584, 131 | moe_k=2, 132 | moe_gate_act: str = "softmax", 133 | moe_use_aux_free=False, 134 | num_nextn_predict_layers=0, 135 | multi_token_pred_lambda=1.0, 136 | **kwargs, 137 | ): 138 | self.vocab_size = vocab_size 139 | self.max_position_embeddings = max_position_embeddings 140 | self.hidden_size = hidden_size 141 | self.intermediate_size = intermediate_size 142 | self.num_hidden_layers = num_hidden_layers 143 | self.num_attention_heads = num_attention_heads 144 | 145 | if num_key_value_heads is None: 146 | num_key_value_heads = num_attention_heads 147 | 148 | self.num_key_value_heads = num_key_value_heads 149 | self.use_rmsnorm = use_rmsnorm 150 | self.rms_norm_eps = rms_norm_eps 151 | self.rope_theta = rope_theta 152 | self.max_sequence_length = max_sequence_length 153 | self.pad_token_id = pad_token_id 154 | self.bos_token_id = bos_token_id 155 | self.eos_token_id = eos_token_id 156 | self.ignored_index = ignored_index 157 | self.use_cache = use_cache 158 
| self.use_bias = use_bias 159 | self.weight_share_add_bias = weight_share_add_bias 160 | self.use_flash_attention = use_flash_attention 161 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 162 | self.hidden_dropout_prob = hidden_dropout_prob 163 | 164 | self.use_moe = moe_num_experts > 0 and use_moe 165 | self.moe_num_experts = moe_num_experts 166 | self.moe_capacity = moe_capacity 167 | self.sinkhorn_2gate = sinkhorn_2gate 168 | self.sinkhorn_temp = sinkhorn_temp 169 | self.moe_layer_interval = moe_layer_interval 170 | self.moe_dropout_prob = moe_dropout_prob 171 | self.moe_gate = moe_gate 172 | self.moe_intermediate_size = moe_intermediate_size 173 | self.moe_k = moe_k 174 | self.moe_layer_start_index = moe_layer_start_index 175 | self.moe_layer_end_index = ( 176 | self.num_hidden_layers - 1 177 | if moe_layer_end_index == -1 178 | else moe_layer_end_index 179 | ) 180 | self.moe_gate_act = moe_gate_act 181 | self.moe_use_aux_free = moe_use_aux_free 182 | self.num_nextn_predict_layers = num_nextn_predict_layers 183 | self.multi_token_pred_lambda = multi_token_pred_lambda 184 | 185 | # Set default for tied embeddings if not specified. 186 | if "tie_word_embeddings" not in kwargs: 187 | kwargs["tie_word_embeddings"] = False 188 | 189 | super().__init__( 190 | pad_token_id=pad_token_id, 191 | bos_token_id=bos_token_id, 192 | eos_token_id=eos_token_id, 193 | **kwargs, 194 | ) -------------------------------------------------------------------------------- /mlx_engine/vision_model_kit/vision_model_wrapper.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | import logging 3 | 4 | from mlx_vlm.models.cache import KVCache, SimpleKVCache 5 | from typing import List, Optional 6 | from mlx_engine.model_kit.vision_add_ons.process_prompt_with_images import ( 7 | common_process_prompt_with_images, 8 | ) 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class VisionModelWrapper: 14 | """ 15 | Wrapper class for Vision Models support 16 | This wrapper class adapts mlx-vlm models so that they can be slotted into the mlx_lm generation engine 17 | This wrapper defines `__getattr__` and `__setattr__` to allow the model properties to be set/get as if it were a text model 18 | 19 | Models are evaluated in `mlx_lm` with the `__call__` method, so define a custom `__call__` method to forward calls to the vision model 20 | """ 21 | 22 | def __init__(self, model): 23 | """ 24 | Set the class members in this unusual way, so that we can define `__getattr__` and `__setattr__` 25 | """ 26 | self.__dict__["_model_attrs"] = { 27 | "vision_model": model, 28 | "input_ids": None, 29 | "pixel_values": None, 30 | "mask": None, 31 | "first_call": False, 32 | "decoder_input_ids": None, 33 | "language_model_kwargs": {}, 34 | # vision model kwargs 35 | "model_inputs": {}, 36 | } 37 | 38 | def __getattr__(self, name): 39 | """ 40 | First, check if the name is a member of this class 41 | Then, check if the name is a member of the language model 42 | Finally, check if the name is a member of the vision model 43 | """ 44 | if name in self._model_attrs: 45 | return self._model_attrs[name] 46 | try: 47 | return getattr(self.vision_model.language_model, name) 48 | except AttributeError: 49 | pass 50 | return getattr(self.vision_model, name) 51 | 52 | def __setattr__(self, name, value): 53 | """ 54 | Set attribute of this class if it's not a member of the vision model 55 | """ 56 | if name in self._model_attrs or not hasattr(self.vision_model, 
name): 57 | self._model_attrs[name] = value 58 | else: 59 | setattr(self.vision_model, name, value) 60 | 61 | def __call__(self, *args, input_embeddings=None, **kwargs): 62 | """ 63 | See this reference implementation 64 | https://github.com/Blaizzy/mlx-vlm/blob/6c98971/mlx_vlm/utils.py#L783-L810 65 | 66 | In the reference implementation, the vision model is called once at the beginning, 67 | then all subsequent calls are forwarded to the language model. Mirror that behavior here. 68 | """ 69 | if self.pixel_values is not None and not self.first_call: 70 | self.first_call = True 71 | 72 | # taken from here https://github.com/Blaizzy/mlx-vlm/blob/2974401/mlx_vlm/utils.py#L987 73 | if hasattr(self.language_model, "make_cache"): 74 | cache = self.language_model.make_cache() 75 | else: 76 | kv_heads = ( 77 | [self.language_model.n_kv_heads] * len(self.language_model.layers) 78 | if isinstance(self.language_model.n_kv_heads, int) 79 | else self.language_model.n_kv_heads 80 | ) 81 | if self.vision_model.config.model_type == "florence2": 82 | cache = [ 83 | (SimpleKVCache(), SimpleKVCache()) 84 | for n in self.language_model.layers 85 | ] 86 | else: 87 | cache = [KVCache() for n in kv_heads] 88 | 89 | # Replace the mlx_lm cache with the one we created 90 | kwargs["cache"] = cache 91 | 92 | outputs = self.vision_model( 93 | self.input_ids, 94 | self.pixel_values, 95 | mask=self.mask, 96 | **self.model_inputs, 97 | **kwargs, 98 | ) 99 | 100 | # taken from here https://github.com/Blaizzy/mlx-vlm/blob/2974401/mlx_vlm/utils.py#L1045-L1056 101 | if outputs.cross_attention_states is not None: 102 | self.language_model_kwargs = { 103 | k: v 104 | for k, v in zip( 105 | ["cross_attention_states"], [outputs.cross_attention_states] 106 | ) 107 | } 108 | elif outputs.encoder_outputs is not None: 109 | self.decoder_input_ids = self.input_ids 110 | self.language_model_kwargs = { 111 | "decoder_input_ids": self.decoder_input_ids, 112 | "encoder_outputs": outputs.encoder_outputs, 113 | } 114 | 115 | # Add the cache we created here to the language model kwargs 116 | self.language_model_kwargs["cache"] = cache 117 | else: 118 | try: 119 | if ( 120 | "cache" in self.language_model_kwargs 121 | ): # This is only missing if self.pixel_values is None 122 | del kwargs["cache"] # Use the cache from self.language_model_kwargs 123 | 124 | # taken from here https://github.com/Blaizzy/mlx-vlm/blob/2974401/mlx_vlm/utils.py#L1009 125 | if "decoder_input_ids" in self.language_model_kwargs: 126 | self.language_model_kwargs["decoder_input_ids"] = ( 127 | self.decoder_input_ids 128 | ) 129 | outputs = self.language_model( 130 | **kwargs, 131 | **self.language_model_kwargs, 132 | ) 133 | else: 134 | outputs = self.language_model( 135 | *args, 136 | **kwargs, 137 | **self.language_model_kwargs, 138 | ) 139 | 140 | except ValueError as e: 141 | # Create a friendly error message if a user tries to use mllama without images 142 | if "Cross attention states must be provided for layer" in str(e): 143 | raise ValueError( 144 | "Using this model without any images attached is not supported yet." 
145 | ) 146 | raise e 147 | 148 | return outputs.logits 149 | 150 | def record_sampled_token(self, token: int) -> None: 151 | # Adapted from here https://github.com/Blaizzy/mlx-vlm/blob/2974401/mlx_vlm/utils.py#L1064 152 | self.decoder_input_ids = mx.array([token]) 153 | 154 | def process_prompt_with_images( 155 | self, 156 | images_b64: Optional[List[str]], 157 | prompt_tokens: mx.array, 158 | processor, 159 | detokenizer, 160 | max_image_size: tuple[int, int] | None, 161 | ): 162 | """ 163 | This method generates the input_ids, pixel_values, and mask for the vision model 164 | Call this before starting evaluation 165 | """ 166 | if images_b64 is None: 167 | images_b64 = [] 168 | 169 | # Handle the case with no images 170 | if len(images_b64) == 0: 171 | detokenizer.reset() 172 | [detokenizer.add_token(token) for token in prompt_tokens] 173 | detokenizer.finalize() 174 | prompt = detokenizer.text 175 | 176 | logger.debug(f"Prompt dump: {prompt}\n") 177 | 178 | try: 179 | if hasattr(processor, "process"): 180 | # Needed for Molmo 181 | self.input_ids = mx.array( 182 | processor.process(text=prompt)["input_ids"] 183 | ) 184 | else: 185 | self.input_ids = mx.array(processor(text=prompt).input_ids) 186 | except ValueError as e: 187 | if "`images` are expected as arguments" in str(e): 188 | raise ValueError( 189 | "Using this model without any images attached is not supported yet." 190 | ) 191 | raise e 192 | else: 193 | # Use the common function for image processing 194 | processed = common_process_prompt_with_images( 195 | prompt_tokens=prompt_tokens, 196 | images_b64=images_b64, 197 | processor=processor, 198 | config=self.vision_model.config, 199 | max_size=max_image_size, 200 | ) 201 | 202 | # Set class attributes from the processed result 203 | self.input_ids = processed.input_ids 204 | self.pixel_values = processed.pixel_values 205 | self.mask = processed.attention_mask 206 | self.model_inputs = processed.other_inputs 207 | 208 | @property 209 | def vision_model(self): 210 | return self._model_attrs["vision_model"] 211 | 212 | @property 213 | def language_model(self): 214 | return self.vision_model.language_model 215 | -------------------------------------------------------------------------------- /mlx_engine/model_kit/vision_add_ons/load_utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | from pathlib import Path 4 | from typing import Any, Tuple, Type 5 | import mlx.core as mx 6 | from mlx import nn 7 | from mlx_vlm.utils import sanitize_weights, load_processor, skip_multimodal_module 8 | import logging 9 | 10 | 11 | def load_and_parse_config( 12 | model_path: Path, 13 | model_config_class: Any, 14 | vision_config_class: Any, 15 | text_config_class: Any, 16 | ) -> Tuple[Any, dict]: 17 | """ 18 | Load and parse vision model configuration from config.json. 
19 | 20 | Args: 21 | model_path: Path to the model directory 22 | model_config_class: Configuration class for the model 23 | vision_config_class: Configuration class for vision component 24 | text_config_class: Configuration class for text component 25 | 26 | Returns: 27 | Tuple containing: 28 | - The fully initialized config object 29 | - The raw config dictionary (needed for quantization later) 30 | """ 31 | config_path = model_path / "config.json" 32 | if not config_path.exists(): 33 | raise FileNotFoundError(f"Configuration file not found at {config_path}") 34 | 35 | config_dict = json.loads(config_path.read_text()) 36 | config = model_config_class.from_dict(config_dict) 37 | config.vision_config = vision_config_class.from_dict(config.vision_config) 38 | config.text_config = text_config_class.from_dict(config.text_config) 39 | 40 | # hack for lfm2_vl, which uses a `vision_feature_layer` to reduce the number of actual layers 41 | # https://github.com/Blaizzy/mlx-vlm/blob/f02d63e8f5b521e8c75f129a63d2660efd132693/mlx_vlm/models/lfm2_vl/lfm2_vl.py#L98-L101 42 | if ( 43 | hasattr(config.text_config, "model_type") 44 | and "lfm2" in config.text_config.model_type 45 | ): 46 | vision_feature_layer = config_dict.get("vision_feature_layer", -1) 47 | if vision_feature_layer != -1: 48 | config.vision_config.num_hidden_layers += vision_feature_layer + 1 49 | config_dict["vision_config"]["num_hidden_layers"] = ( 50 | config.vision_config.num_hidden_layers 51 | ) 52 | 53 | return config, config_dict 54 | 55 | 56 | class VisionComponents(nn.Module): 57 | def __init__( 58 | self, vision_tower: nn.Module, multi_modal_projector: nn.Module | None = None 59 | ): 60 | super().__init__() 61 | self.vision_tower = vision_tower 62 | self.multi_modal_projector = multi_modal_projector 63 | 64 | 65 | def create_vision_components( 66 | config: Any, 67 | vision_tower_class: Type[nn.Module], 68 | multi_modal_projector_class: Type[nn.Module] | None, 69 | ) -> VisionComponents: 70 | """ 71 | Create vision model components and wrap them in a container module. 72 | 73 | Args: 74 | config: The fully initialized config object 75 | vision_tower_class: The vision tower model class 76 | multi_modal_projector_class: The multi-modal projector class 77 | 78 | Returns: 79 | The container module with both components 80 | """ 81 | components = VisionComponents( 82 | vision_tower_class(config.vision_config), 83 | multi_modal_projector_class(config) if multi_modal_projector_class else None, 84 | ) 85 | return components 86 | 87 | 88 | def load_and_filter_weights( 89 | model_path: Path, 90 | components: nn.Module, 91 | ) -> dict: 92 | """ 93 | Load model weights from safetensors files and filter for vision-related weights. 
94 | 95 | Args: 96 | model_path: Path to the model directory 97 | components: The vision components container module 98 | 99 | Returns: 100 | Dictionary containing filtered vision-related weights 101 | """ 102 | # Load model weights 103 | weight_files = glob.glob(str(model_path / "*.safetensors")) 104 | if not weight_files: 105 | raise FileNotFoundError( 106 | f"Failed to load vision add-on: {model_path} does not contain any safetensors files" 107 | ) 108 | 109 | # Load and filter weights 110 | weights = {} 111 | for wf in weight_files: 112 | weights.update(mx.load(wf)) 113 | 114 | # Filter only vision-related weights 115 | vision_weights = { 116 | k: v 117 | for k, v in weights.items() 118 | if any(k.startswith(name) for name in components.children().keys()) 119 | } 120 | 121 | return vision_weights 122 | 123 | 124 | def maybe_apply_quantization( 125 | components: nn.Module, 126 | config_dict: dict, 127 | vision_weights: dict, 128 | ) -> None: 129 | """ 130 | Apply quantization to vision components if specified in config. 131 | 132 | Args: 133 | components: The vision components container module 134 | config_dict: Raw config dictionary containing quantization settings 135 | vision_weights: The vision-related weights dictionary 136 | """ 137 | # Apply quantization if specified in config 138 | if (quantization := config_dict.get("quantization", None)) is not None: 139 | # Copied from mlx_vlm/utils.py at commit 140 | # 65ecc837f24d0f8b138f300c7efef87f00fba74d 141 | skip_vision = config_dict.get("vision_config", {}).get("skip_vision", False) 142 | 143 | def get_class_predicate(p, m): 144 | # Always skip vision and audio models 145 | if skip_multimodal_module(p) and skip_vision: 146 | return False 147 | # Handle custom per layer quantizations 148 | if p in config_dict["quantization"]: 149 | return config_dict["quantization"][p] 150 | if not hasattr(m, "to_quantized"): 151 | return False 152 | # Skip layers not divisible by 64 153 | if hasattr(m, "weight") and m.weight.size % 64 != 0: 154 | return False 155 | # Handle legacy models which may not have everything quantized 156 | return f"{p}.scales" in vision_weights 157 | 158 | quantize_kwargs = {} 159 | if "bits" in quantization: 160 | quantize_kwargs["bits"] = quantization["bits"] 161 | if "group_size" in quantization: 162 | quantize_kwargs["group_size"] = quantization["group_size"] 163 | nn.quantize( 164 | components, 165 | class_predicate=get_class_predicate, 166 | **quantize_kwargs, 167 | ) 168 | 169 | 170 | def prepare_components( 171 | components: nn.Module, 172 | vision_weights: dict, 173 | ) -> None: 174 | """ 175 | Prepare vision components by loading weights and setting to evaluation mode. 
176 | 177 | Args: 178 | components: The vision components container module 179 | vision_weights: The vision-related weights dictionary 180 | """ 181 | # Load weights into the model 182 | components.load_weights(list(vision_weights.items())) 183 | 184 | # Always load weights to memory here 185 | mx.eval(components.parameters()) 186 | 187 | # Set model to evaluation mode 188 | components.eval() 189 | 190 | 191 | def load_vision_addon( 192 | model_path: Path, 193 | model_config_class: Any, 194 | vision_config_class: Any, 195 | text_config_class: Any, 196 | vision_tower_class: Type[nn.Module], 197 | multi_modal_projector_class: Type[nn.Module] | None, 198 | logger: logging.Logger, 199 | processor_kwargs: dict | None = None, 200 | ) -> Tuple[nn.Module, nn.Module | None, Any, Any]: 201 | """ 202 | Load vision add-on components, configuration, and processor. 203 | 204 | Args: 205 | model_path: Path to the model directory 206 | model_config_class: Configuration class for the model 207 | vision_config_class: Configuration class for vision component 208 | text_config_class: Configuration class for text component 209 | vision_tower_class: The vision tower model class 210 | multi_modal_projector_class: The multi-modal projector class 211 | logger: logging.Logger 212 | 213 | Returns: 214 | Tuple containing: 215 | - The vision tower module 216 | - The multi-modal projector module 217 | - The model configuration 218 | - The processor for handling images and text 219 | """ 220 | # Load and parse configuration 221 | config, config_dict = load_and_parse_config( 222 | model_path, model_config_class, vision_config_class, text_config_class 223 | ) 224 | 225 | # Create model components 226 | components = create_vision_components( 227 | config, 228 | vision_tower_class, 229 | multi_modal_projector_class, 230 | ) 231 | 232 | # Load processor 233 | processor = load_processor( 234 | model_path=model_path, 235 | add_detokenizer=True, 236 | **(processor_kwargs or {}), 237 | ) 238 | 239 | # Load and filter weights 240 | vision_weights = load_and_filter_weights(model_path, components) 241 | 242 | # Sanitize weights for vision tower 243 | vision_weights = sanitize_weights( 244 | components.vision_tower.__class__, vision_weights, config.vision_config 245 | ) 246 | 247 | # Apply quantization if specified in config 248 | maybe_apply_quantization(components, config_dict, vision_weights) 249 | 250 | # Prepare components (load weights and set to eval mode) 251 | prepare_components(components, vision_weights) 252 | 253 | logger.info( 254 | f"Vision add-on loaded successfully from {model_path}", 255 | ) 256 | 257 | return components.vision_tower, components.multi_modal_projector, config, processor 258 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import base64 3 | import time 4 | import os 5 | 6 | from mlx_engine.generate import load_model, load_draft_model, create_generator, tokenize 7 | from mlx_engine.utils.token import Token 8 | from mlx_engine.utils.kv_cache_quantization import VALID_KV_BITS, VALID_KV_GROUP_SIZE 9 | from transformers import AutoTokenizer, AutoProcessor 10 | 11 | DEFAULT_PROMPT = "Explain the rules of chess in one sentence" 12 | DEFAULT_TEMP = 0.8 13 | 14 | DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant." 
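# The helper below is a minimal, illustrative sketch of the same load -> tokenize ->
# generate flow that the __main__ block at the bottom of this script performs with
# full argument handling and chat templating. It is not called anywhere, and the
# model_path it expects is a placeholder assumption rather than a shipped default.
def _example_minimal_generation(model_path: str) -> None:
    """Text-only generation sketch that mirrors the demo flow (illustrative only)."""
    model_kit = load_model(
        str(model_path),
        max_kv_size=None,
        trust_remote_code=False,
        kv_bits=None,
        kv_group_size=None,
        quantized_kv_start=None,
    )
    # Note: the real demo applies the tokenizer's chat template first; this sketch
    # feeds the raw prompt for brevity.
    prompt_tokens = tokenize(model_kit, DEFAULT_PROMPT)
    for generation_result in create_generator(
        model_kit,
        prompt_tokens,
        images_b64=[],
        max_tokens=128,
        temp=DEFAULT_TEMP,
    ):
        print(generation_result.text, end="", flush=True)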
15 | 16 | 17 | def setup_arg_parser(): 18 | """Set up and return the argument parser.""" 19 | parser = argparse.ArgumentParser( 20 | description="LM Studio mlx-engine inference script" 21 | ) 22 | parser.add_argument( 23 | "--model", 24 | required=True, 25 | type=str, 26 | help="The file system path to the model", 27 | ) 28 | parser.add_argument( 29 | "--prompt", 30 | default=DEFAULT_PROMPT, 31 | type=str, 32 | help="Message to be processed by the model", 33 | ) 34 | parser.add_argument( 35 | "--system", 36 | default=DEFAULT_SYSTEM_PROMPT, 37 | type=str, 38 | help="System prompt for the model", 39 | ) 40 | parser.add_argument( 41 | "--no-system", 42 | action="store_true", 43 | help="Disable the system prompt", 44 | ) 45 | parser.add_argument( 46 | "--images", 47 | type=str, 48 | nargs="+", 49 | help="Path of the images to process", 50 | ) 51 | parser.add_argument( 52 | "--temp", 53 | default=DEFAULT_TEMP, 54 | type=float, 55 | help="Sampling temperature", 56 | ) 57 | parser.add_argument( 58 | "--stop-strings", 59 | type=str, 60 | nargs="+", 61 | help="Strings that will stop the generation", 62 | ) 63 | parser.add_argument( 64 | "--top-logprobs", 65 | type=int, 66 | default=0, 67 | help="Number of top logprobs to return", 68 | ) 69 | parser.add_argument( 70 | "--max-kv-size", 71 | type=int, 72 | help="Max context size of the model", 73 | ) 74 | parser.add_argument( 75 | "--kv-bits", 76 | type=int, 77 | choices=VALID_KV_BITS, 78 | help="Number of bits for KV cache quantization. Must be between 3 and 8 (inclusive)", 79 | ) 80 | parser.add_argument( 81 | "--kv-group-size", 82 | type=int, 83 | choices=VALID_KV_GROUP_SIZE, 84 | help="Group size for KV cache quantization", 85 | ) 86 | parser.add_argument( 87 | "--quantized-kv-start", 88 | type=int, 89 | help="When --kv-bits is set, start quantizing the KV cache from this step onwards", 90 | ) 91 | parser.add_argument( 92 | "--draft-model", 93 | type=str, 94 | help="The file system path to the draft model for speculative decoding.", 95 | ) 96 | parser.add_argument( 97 | "--num-draft-tokens", 98 | type=int, 99 | help="Number of tokens to draft when using speculative decoding.", 100 | ) 101 | parser.add_argument( 102 | "--print-prompt-progress", 103 | action="store_true", 104 | help="Enable printed prompt processing progress callback", 105 | ) 106 | parser.add_argument( 107 | "--max-img-size", type=int, help="Downscale images to this side length (px)" 108 | ) 109 | return parser 110 | 111 | 112 | def image_to_base64(image_path): 113 | with open(image_path, "rb") as image_file: 114 | return base64.b64encode(image_file.read()).decode("utf-8") 115 | 116 | 117 | class GenerationStatsCollector: 118 | def __init__(self): 119 | self.start_time = time.time() 120 | self.first_token_time = None 121 | self.total_tokens = 0 122 | self.num_accepted_draft_tokens: int | None = None 123 | 124 | def add_tokens(self, tokens: list[Token]): 125 | """Record new tokens and their timing.""" 126 | if self.first_token_time is None: 127 | self.first_token_time = time.time() 128 | 129 | draft_tokens = sum(1 for token in tokens if token.from_draft) 130 | if self.num_accepted_draft_tokens is None: 131 | self.num_accepted_draft_tokens = 0 132 | self.num_accepted_draft_tokens += draft_tokens 133 | 134 | self.total_tokens += len(tokens) 135 | 136 | def print_stats(self): 137 | """Print generation statistics.""" 138 | end_time = time.time() 139 | total_time = end_time - self.start_time 140 | time_to_first_token = self.first_token_time - self.start_time 141 | effective_time = 
total_time - time_to_first_token 142 | tokens_per_second = ( 143 | self.total_tokens / effective_time if effective_time > 0 else float("inf") 144 | ) 145 | print("\n\nGeneration stats:") 146 | print(f" - Tokens per second: {tokens_per_second:.2f}") 147 | if self.num_accepted_draft_tokens is not None: 148 | print( 149 | f" - Number of accepted draft tokens: {self.num_accepted_draft_tokens}" 150 | ) 151 | print(f" - Time to first token: {time_to_first_token:.2f}s") 152 | print(f" - Total tokens generated: {self.total_tokens}") 153 | print(f" - Total time: {total_time:.2f}s") 154 | 155 | 156 | def resolve_model_path(model_arg): 157 | # If it's a full path or local file, return as-is 158 | if os.path.exists(model_arg): 159 | return model_arg 160 | 161 | # Check common local directories 162 | local_paths = [ 163 | os.path.expanduser("~/.lmstudio/models"), 164 | os.path.expanduser("~/.cache/lm-studio/models"), 165 | ] 166 | 167 | for path in local_paths: 168 | full_path = os.path.join(path, model_arg) 169 | if os.path.exists(full_path): 170 | return full_path 171 | 172 | raise ValueError(f"Could not find model '{model_arg}' in local directories") 173 | 174 | 175 | if __name__ == "__main__": 176 | # Parse arguments 177 | parser = setup_arg_parser() 178 | args = parser.parse_args() 179 | if isinstance(args.images, str): 180 | args.images = [args.images] 181 | 182 | # Set up prompt processing callback 183 | def prompt_progress_callback(percent): 184 | if args.print_prompt_progress: 185 | width = 40 # bar width 186 | filled = int(width * percent / 100) 187 | bar = "█" * filled + "░" * (width - filled) 188 | print(f"\rProcessing prompt: |{bar}| ({percent:.1f}%)", end="", flush=True) 189 | if percent >= 100: 190 | print() # new line when done 191 | return True # Progress callback must return True to continue 192 | 193 | # Load the model 194 | model_path = resolve_model_path(args.model) 195 | print("Loading model...", end="\n", flush=True) 196 | model_kit = load_model( 197 | str(model_path), 198 | max_kv_size=args.max_kv_size, 199 | trust_remote_code=False, 200 | kv_bits=args.kv_bits, 201 | kv_group_size=args.kv_group_size, 202 | quantized_kv_start=args.quantized_kv_start, 203 | ) 204 | print("\rModel load complete ✓", end="\n", flush=True) 205 | 206 | # Load draft model if requested 207 | if args.draft_model: 208 | load_draft_model(model_kit=model_kit, path=resolve_model_path(args.draft_model)) 209 | 210 | # Tokenize the prompt 211 | prompt = args.prompt 212 | 213 | # Build conversation with optional system prompt 214 | conversation = [] 215 | if not args.no_system: 216 | conversation.append({"role": "system", "content": args.system}) 217 | 218 | # Handle the prompt according to the input type 219 | # If images are provided, add them to the prompt 220 | images_base64 = [] 221 | if args.images: 222 | tf_tokenizer = AutoProcessor.from_pretrained(model_path) 223 | images_base64 = [image_to_base64(img_path) for img_path in args.images] 224 | conversation.append( 225 | { 226 | "role": "user", 227 | "content": [ 228 | *[ 229 | {"type": "image", "base64": image_b64} 230 | for image_b64 in images_base64 231 | ], 232 | {"type": "text", "text": prompt}, 233 | ], 234 | } 235 | ) 236 | else: 237 | tf_tokenizer = AutoTokenizer.from_pretrained(model_path) 238 | conversation.append({"role": "user", "content": prompt}) 239 | prompt = tf_tokenizer.apply_chat_template( 240 | conversation, tokenize=False, add_generation_prompt=True 241 | ) 242 | prompt_tokens = tokenize(model_kit, prompt) 243 | 244 | # Record top 
logprobs 245 | logprobs_list = [] 246 | 247 | # Initialize generation stats collector 248 | stats_collector = GenerationStatsCollector() 249 | 250 | # Clamp image size 251 | max_img_size = (args.max_img_size, args.max_img_size) if args.max_img_size else None 252 | 253 | # Generate the response 254 | generator = create_generator( 255 | model_kit, 256 | prompt_tokens, 257 | images_b64=images_base64, 258 | max_image_size=max_img_size, 259 | stop_strings=args.stop_strings, 260 | max_tokens=1024, 261 | top_logprobs=args.top_logprobs, 262 | prompt_progress_callback=prompt_progress_callback, 263 | num_draft_tokens=args.num_draft_tokens, 264 | temp=args.temp, 265 | ) 266 | for generation_result in generator: 267 | print(generation_result.text, end="", flush=True) 268 | stats_collector.add_tokens(generation_result.tokens) 269 | logprobs_list.extend(generation_result.top_logprobs) 270 | 271 | if generation_result.stop_condition: 272 | stats_collector.print_stats() 273 | print( 274 | f"\nStopped generation due to: {generation_result.stop_condition.stop_reason}" 275 | ) 276 | if generation_result.stop_condition.stop_string: 277 | print(f"Stop string: {generation_result.stop_condition.stop_string}") 278 | 279 | if args.top_logprobs: 280 | [print(x) for x in logprobs_list] 281 | -------------------------------------------------------------------------------- /mlx_engine/model_kit/model_kit.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Callable, Optional, List, Tuple 3 | import mlx_lm 4 | from mlx_lm.tokenizer_utils import TokenizerWrapper, StreamingDetokenizer 5 | from mlx_engine.cache_wrapper import CacheWrapper 6 | from pathlib import Path 7 | import mlx.nn as nn 8 | import mlx.core as mx 9 | import logging 10 | from mlx_engine.model_kit.vision_add_ons.base import BaseVisionAddOn 11 | from mlx_engine.model_kit.vision_add_ons.gemma3 import Gemma3VisionAddOn 12 | from mlx_engine.model_kit.vision_add_ons.pixtral import PixtralVisionAddOn 13 | from mlx_engine.model_kit.vision_add_ons.gemma3n import Gemma3nVisionAddOn 14 | from mlx_engine.model_kit.vision_add_ons.mistral3 import Mistral3VisionAddOn 15 | from mlx_engine.model_kit.vision_add_ons.lfm2_vl import LFM2VisionAddOn 16 | from mlx_engine.utils.kv_cache_quantization import get_kv_cache_quantization_params 17 | from mlx_engine.utils.prompt_processing import process_prompt_text_only 18 | from mlx_engine.utils.fix_mistral_pre_tokenizer import fix_mistral_pre_tokenizer 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class ModelKit: 24 | """ 25 | Collection of objects and methods that are needed for operating a model. 26 | 27 | Args: 28 | model_path (Path): Path to the model directory containing model files. 29 | vocab_only (bool): Only load vocabulary/tokenizer, not the full model. 30 | max_kv_size (int): Maximum size of the key-value cache used during model inference. 31 | kv_bits (Optional[int]): Number of bits for KV cache quantization. None disables quantization. 32 | kv_group_size (Optional[int]): Group size for KV cache quantization. Defaults to 64. 33 | quantized_kv_start (Optional[int]): Step to begin KV cache quantization when enabled. Defaults to 0. 
34 | """ 35 | 36 | VISION_ADD_ON_MAP = { 37 | "gemma3": Gemma3VisionAddOn, 38 | "gemma3n": Gemma3nVisionAddOn, 39 | "lfm2-vl": LFM2VisionAddOn, 40 | "mistral3": Mistral3VisionAddOn, 41 | "pixtral": PixtralVisionAddOn, 42 | # qwen vl ports are bugged: https://github.com/lmstudio-ai/mlx-engine/issues/237 43 | # "qwen2_vl": Qwen2_VLVisionAddOn, 44 | # "qwen2_5_vl": Qwen2_VLVisionAddOn, 45 | # "qwen3_vl_moe": Qwen3_VL_MoEVisionAddOn, 46 | # "qwen3_vl": Qwen3_VLVisionAddOn, 47 | } 48 | 49 | # model state tracking 50 | model: nn.Module = None 51 | tokenizer: TokenizerWrapper = None 52 | detokenizer: StreamingDetokenizer = None 53 | cache_wrapper: Optional[CacheWrapper] = None 54 | _cross_prompt_cache_active: bool = False 55 | max_kv_size: Optional[int] = None 56 | kv_bits: Optional[int] = None 57 | kv_group_size: Optional[int] = None 58 | quantized_kv_start: Optional[int] = None 59 | draft_model: Optional[nn.Module] = None 60 | model_type: Optional[str] = None 61 | 62 | # multi-modal add-ons 63 | vision_add_on: Optional[BaseVisionAddOn] = None 64 | 65 | def _vocab_only_init(self, model_path: Path): 66 | logger.info(f"Loading model (vocab-only) from {model_path}...") 67 | self.tokenizer = mlx_lm.tokenizer_utils.load(model_path) 68 | self.detokenizer = self.tokenizer.detokenizer 69 | logger.info("Model (vocab-only) loaded successfully") 70 | 71 | def _full_model_init( 72 | self, 73 | model_path: Path, 74 | max_kv_size: Optional[int] = None, 75 | kv_bits: Optional[int] = None, 76 | kv_group_size: Optional[int] = None, 77 | quantized_kv_start: Optional[int] = None, 78 | ): 79 | kv_bits, kv_group_size, quantized_kv_start = get_kv_cache_quantization_params( 80 | kv_bits, 81 | kv_group_size, 82 | quantized_kv_start, 83 | ) 84 | if kv_bits and max_kv_size is not None: 85 | # Quantized KV cache is only supported for non-rotating KV cache 86 | logger.warning("max_kv_size is ignored when using KV cache quantization") 87 | max_kv_size = None 88 | self.model_path = model_path 89 | logger.info(f"Loading model from {model_path}...") 90 | config_json = json.loads((model_path / "config.json").read_text()) 91 | self.model_type = config_json.get("model_type", None) 92 | 93 | self.model, self.tokenizer = mlx_lm.utils.load(self.model_path) 94 | fix_mistral_pre_tokenizer( 95 | tokenizer=self.tokenizer, model_path=model_path, model_type=self.model_type 96 | ) 97 | self.detokenizer = self.tokenizer.detokenizer 98 | self.cache_wrapper = CacheWrapper( 99 | self.model, 100 | max_kv_size, 101 | kv_bits=kv_bits, 102 | kv_group_size=kv_group_size, 103 | quantized_kv_start=quantized_kv_start, 104 | ) 105 | self.kv_bits = kv_bits 106 | self.kv_group_size = kv_group_size 107 | self.quantized_kv_start = quantized_kv_start 108 | vision_add_on_class = self.VISION_ADD_ON_MAP.get(self.model_type) 109 | should_load_vision_add_on = ( 110 | vision_add_on_class is not None and "vision_config" in config_json 111 | ) 112 | if should_load_vision_add_on: 113 | self.vision_add_on = vision_add_on_class(model_path) 114 | logger.info("Model loaded successfully") 115 | 116 | def __init__( 117 | self, 118 | model_path: Path, 119 | vocab_only: bool = False, 120 | max_kv_size: Optional[int] = None, 121 | kv_bits: Optional[int] = None, 122 | kv_group_size: Optional[int] = None, 123 | quantized_kv_start: Optional[int] = None, 124 | ): 125 | if vocab_only: 126 | self._vocab_only_init(model_path) 127 | else: 128 | self._full_model_init( 129 | model_path, 130 | max_kv_size, 131 | kv_bits, 132 | kv_group_size, 133 | quantized_kv_start, 134 | ) 135 | 136 
| def tokenize(self, prompt: str) -> List[int]: 137 | ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(prompt)) 138 | if isinstance(ids, int): 139 | return [ids] 140 | return ids 141 | 142 | def process_prompt( 143 | self, 144 | prompt_tokens, 145 | images_b64: Optional[List[str]], 146 | prompt_progress_callback: Optional[Callable[[float], bool]], 147 | generate_args: dict, 148 | max_image_size: tuple[int, int] | None, 149 | speculative_decoding_toggle: Optional[bool] = None, 150 | ) -> Tuple[mx.array, Optional[mx.array]]: 151 | ### TEXT-ONLY PROCESS_PROMPT ### 152 | is_text_only_processing = images_b64 is None or len(images_b64) == 0 153 | if is_text_only_processing: 154 | self._cross_prompt_cache_active = True 155 | if len(prompt_tokens) == 0: 156 | logger.warning( 157 | "Received empty prompt. Generation quality will likely be poor" 158 | ) 159 | # Models expect some sort of input, so add whitespace 160 | prompt_tokens = self.tokenize(" ") 161 | return process_prompt_text_only( 162 | mx.array(prompt_tokens), 163 | self.cache_wrapper, 164 | generate_args, 165 | self.draft_model, 166 | speculative_decoding_toggle, 167 | prompt_progress_callback, 168 | ), None 169 | ### WITH IMAGES PROMPT PROCESSING ### 170 | if self.vision_add_on is None: 171 | raise ValueError( 172 | "Vision add-on is not loaded, but images were provided for processing" 173 | ) 174 | self._cross_prompt_cache_active = False 175 | input_ids, embeddings = self.vision_add_on.compute_embeddings( 176 | self.model, prompt_tokens, images_b64, max_size=max_image_size 177 | ) 178 | return input_ids, embeddings 179 | 180 | def is_cross_prompt_cache_active(self) -> bool: 181 | """ 182 | Check if cross-prompt caching is currently enabled. 183 | Can be overridden by subclasses for custom behavior. 184 | """ 185 | return self._cross_prompt_cache_active 186 | 187 | def record_token_to_cache(self, token: int) -> None: 188 | self.cache_wrapper.record_generated_token(token) 189 | 190 | @staticmethod 191 | def is_supported_vision_arch(model_arch: str) -> bool: 192 | """ 193 | Determines if the specified model architecture has vision support.
194 | 195 | Args: 196 | model_arch (str): The model architecture identifier to check 197 | 198 | Returns: 199 | bool: True if vision is supported, False otherwise 200 | """ 201 | return model_arch in ModelKit.VISION_ADD_ON_MAP 202 | 203 | def is_draft_model_compatible(self, path: str | Path) -> bool: 204 | path = Path(path) 205 | if self.tokenizer is None: 206 | logger.warning( 207 | "Draft model compatibility check requires at least a vocab-only loaded main model" 208 | ) 209 | return False 210 | if self.vision_add_on is not None: 211 | logger.warning("Draft models are currently unsupported for vision models") 212 | return False 213 | draft_tokenizer = mlx_lm.tokenizer_utils.load(path) 214 | if draft_tokenizer.vocab_size != self.tokenizer.vocab_size: 215 | return False 216 | return True 217 | 218 | def load_draft_model(self, path: str | Path) -> None: 219 | logger.info(f"Loading draft model from {path}...") 220 | path = Path(path) 221 | if self.model is None: 222 | raise ValueError("Main model must be loaded before loading a draft model") 223 | if not self.is_draft_model_compatible(path): 224 | raise ValueError("Draft model is not compatible with main model") 225 | self.draft_model, _ = mlx_lm.utils.load(path) 226 | self.cache_wrapper.set_draft_model(self.draft_model) 227 | logger.info("Draft model loaded") 228 | 229 | def unload_draft_model(self) -> None: 230 | if self.draft_model is None: 231 | logger.info("No loaded draft model to unload") 232 | else: 233 | self.draft_model = None 234 | self.cache_wrapper.unset_draft_model() 235 | # Noticed that draft model memory would not be released without clearing metal cache 236 | mx.clear_cache() 237 | --------------------------------------------------------------------------------
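The ModelKit class above ties the tokenizer, the KV-cache wrapper, the optional vision add-on, and the optional draft model together behind a single object. The following is a minimal usage sketch, assuming a local text-only MLX model directory and an optional draft model directory; both paths are placeholders, not values taken from this repository:

from pathlib import Path

from mlx_engine.model_kit.model_kit import ModelKit

# Vision support can be checked per architecture without loading any weights.
assert ModelKit.is_supported_vision_arch("gemma3")
assert not ModelKit.is_supported_vision_arch("qwen2_vl")  # currently disabled in VISION_ADD_ON_MAP

# Load a model and tokenize a prompt.
kit = ModelKit(Path("/path/to/local/mlx-model"))  # placeholder path
prompt_ids = kit.tokenize("Explain the rules of chess in one sentence")
print(f"{len(prompt_ids)} prompt tokens")

# Optionally attach a draft model for speculative decoding; the compatibility check
# compares tokenizer vocabulary sizes and does not load the draft model's weights.
draft_path = Path("/path/to/local/draft-model")  # placeholder path
if kit.is_draft_model_compatible(draft_path):
    kit.load_draft_model(draft_path)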