├── mlx_embeddings ├── models │ ├── __init__.py │ ├── base.py │ ├── gemma3_text.py │ ├── lfm2.py │ ├── bert.py │ ├── colqwen2_5.py │ ├── xlm_roberta.py │ ├── qwen3.py │ ├── modernbert.py │ └── siglip.py ├── version.py ├── tests │ ├── __init__.py │ ├── test_base.py │ ├── test_smoke.py │ └── test_models.py ├── __init__.py ├── convert.py ├── tokenizer_utils.py └── utils.py ├── images ├── cats.jpg └── desktop_setup.png ├── requirements.txt ├── MANIFEST.in ├── AUTHORS.rst ├── .pre-commit-config.yaml ├── .editorconfig ├── .github ├── ISSUE_TEMPLATE │ ├── feature_request.md │ ├── config.yml │ └── bug_report.md └── workflows │ ├── python-publish.yaml │ ├── docs.yml │ ├── test.yaml │ └── docs-build.yml ├── LICENSE ├── .gitignore ├── pyproject.toml ├── mkdocs.yml └── README.md /mlx_embeddings/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mlx_embeddings/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.5" 2 | -------------------------------------------------------------------------------- /mlx_embeddings/tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Unit test package for mlx_embeddings.""" 2 | -------------------------------------------------------------------------------- /images/cats.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-embeddings/HEAD/images/cats.jpg -------------------------------------------------------------------------------- /images/desktop_setup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-embeddings/HEAD/images/desktop_setup.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | mlx>=0.16.3 2 | mlx-vlm>=0.1.21 3 | transformers[sentencepiece]>=4.44.0 4 | huggingface-hub>=0.25.1 5 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include README.md 3 | include requirements.txt 4 | 5 | recursive-exclude * __pycache__ 6 | recursive-exclude * *.py[co] 7 | 8 | -------------------------------------------------------------------------------- /AUTHORS.rst: -------------------------------------------------------------------------------- 1 | ======= 2 | Credits 3 | ======= 4 | 5 | Development Lead 6 | ---------------- 7 | 8 | * Prince Canuma 9 | 10 | Contributors 11 | ------------ 12 | 13 | None yet. Why not be the first? 
14 | -------------------------------------------------------------------------------- /mlx_embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | """Top-level package for MLX-Embeddings.""" 2 | 3 | __author__ = """Prince Canuma""" 4 | __email__ = "prince.gdt@gmail.com" 5 | 6 | from .utils import convert, generate, load 7 | from .version import __version__ 8 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black-pre-commit-mirror 3 | rev: 24.2.0 4 | hooks: 5 | - id: black 6 | - repo: https://github.com/pycqa/isort 7 | rev: 5.13.2 8 | hooks: 9 | - id: isort 10 | args: 11 | - --profile=black -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 4 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | end_of_line = lf 12 | 13 | [*.bat] 14 | indent_style = tab 15 | end_of_line = crlf 16 | 17 | [LICENSE] 18 | insert_final_newline = false 19 | 20 | [Makefile] 21 | indent_style = tab 22 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature Request 3 | about: Submit a feature request to help us improve 4 | labels: Feature Request 5 | --- 6 | 7 | 8 | 9 | ### Description 10 | 11 | Describe the feature (e.g., new functions/tutorials) you would like to propose. 12 | Tell us what can be achieved with this new feature and what's the expected outcome. 13 | 14 | ### Source code 15 | 16 | ``` 17 | Paste your source code here if have sample code to share. 18 | ``` 19 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | contact_links: 2 | - name: Ask questions 3 | url: https://github.com/Blaizzy/mlx-embeddings/discussions/categories/q-a 4 | about: Please ask and answer questions here. 5 | - name: Ideas 6 | url: https://github.com/Blaizzy/mlx-embeddings/discussions/categories/ideas 7 | about: Please share your ideas here. 8 | - name: Ask questions from the GIS community 9 | url: https://gis.stackexchange.com 10 | about: To get answers from questions in the GIS community, please ask and answer questions here. 11 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug Report 3 | about: Create a bug report to help us improve 4 | labels: bug 5 | --- 6 | 7 | 8 | 9 | ### Environment Information 10 | 11 | - mlx_embeddings version: 12 | - Python version: 13 | - Operating System: 14 | 15 | ### Description 16 | 17 | Describe what you were trying to get done. 18 | Tell us what happened, what went wrong, and what you expected to happen. 19 | 20 | ### What I Did 21 | 22 | ``` 23 | Paste the command(s) you ran and the output. 24 | If there was a crash, please include the traceback here. 
25 | ``` 26 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yaml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [published] 9 | 10 | permissions: 11 | contents: read 12 | 13 | jobs: 14 | deploy: 15 | 16 | runs-on: ubuntu-latest 17 | 18 | steps: 19 | - uses: actions/checkout@v3 20 | - name: Set up Python 21 | uses: actions/setup-python@v3 22 | with: 23 | python-version: '3.10' 24 | - name: Install dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | pip install build 28 | - name: Build package 29 | run: python -m build 30 | - name: Publish package 31 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 32 | with: 33 | user: __token__ 34 | password: ${{ secrets.PYPI_API_TOKEN }} 35 | packages_dir: dist -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | # name: docs 2 | # on: 3 | # push: 4 | # branches: 5 | # - main 6 | # - master 7 | # jobs: 8 | # deploy: 9 | # runs-on: ubuntu-latest 10 | # steps: 11 | # - uses: actions/checkout@v4 12 | # - uses: actions/setup-python@v5 13 | # with: 14 | # python-version: "3.11" 15 | 16 | # - name: Install dependencies 17 | # run: | 18 | # python -m pip install --upgrade pip 19 | # pip install --user --no-cache-dir Cython 20 | # pip install --user -r requirements.txt -r requirements_dev.txt 21 | # pip install . 22 | # - name: Discover typos with codespell 23 | # run: | 24 | # codespell --skip="*.csv,*.geojson,*.json,*.js,*.html,*cff,./.git" --ignore-words-list="aci,hist" 25 | # - name: PKG-TEST 26 | # run: | 27 | # python -m unittest discover tests/ 28 | # - run: mkdocs gh-deploy --force 29 | 30 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: Test PRs 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | test: 10 | runs-on: macos-14 11 | 12 | steps: 13 | - name: Checkout code 14 | uses: actions/checkout@v4 15 | 16 | - name: Set up Python 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: '3.10' 20 | 21 | - name: Install MLX 22 | run: | 23 | pip install mlx>=0.15 24 | 25 | - name: Install pre-commit 26 | run: | 27 | python -m pip install pre-commit 28 | pre-commit run --all 29 | if ! git diff --quiet; then 30 | echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change' 31 | exit 1 32 | fi 33 | 34 | - name: Install package and dependencies 35 | run: | 36 | python -m pip install pytest 37 | python -m pip install -e . 
38 | 39 | - name: Run Python tests 40 | run: | 41 | cd mlx_embeddings/ 42 | pytest -s ./tests --ignore=tests/test_smoke.py -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | MLX-Embeddings is a package for running Vision and Language Embedding models locally on your Mac using MLX. 5 | Copyright (C) 2024 Prince Canuma 6 | 7 | This program is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | This program is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with this program. If not, see <https://www.gnu.org/licenses/>. 19 | 20 | Also add information on how to contact you by electronic and paper mail. 21 | 22 | You should also get your employer (if you work as a programmer) or school, 23 | if any, to sign a "copyright disclaimer" for the program, if necessary. 24 | For more information on this, and how to apply and follow the GNU GPL, see 25 | <https://www.gnu.org/licenses/>. 26 | 27 | The GNU General Public License does not permit incorporating your program 28 | into proprietary programs. If your program is a subroutine library, you 29 | may consider it more useful to permit linking proprietary applications with 30 | the library. If this is what you want to do, use the GNU Lesser General 31 | Public License instead of this License. But first, please read 32 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
33 | 34 | -------------------------------------------------------------------------------- /mlx_embeddings/models/base.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from dataclasses import dataclass 3 | from typing import List, Optional 4 | 5 | import mlx.core as mx 6 | 7 | 8 | @dataclass 9 | class BaseModelArgs: 10 | @classmethod 11 | def from_dict(cls, params): 12 | return cls( 13 | **{ 14 | k: v 15 | for k, v in params.items() 16 | if k in inspect.signature(cls).parameters 17 | } 18 | ) 19 | 20 | 21 | @dataclass 22 | class BaseModelOutput: 23 | last_hidden_state: Optional[mx.array] = None 24 | pooler_output: Optional[mx.array] = None 25 | text_embeds: Optional[mx.array] = None # mean pooled and normalized embeddings 26 | hidden_states: Optional[List[mx.array]] = None 27 | 28 | 29 | @dataclass 30 | class ViTModelOutput: 31 | logits: Optional[mx.array] = None 32 | text_embeds: Optional[mx.array] = None 33 | image_embeds: Optional[mx.array] = None 34 | logits_per_text: Optional[mx.array] = None 35 | logits_per_image: Optional[mx.array] = None 36 | text_model_output: Optional[mx.array] = None 37 | vision_model_output: Optional[mx.array] = None 38 | 39 | 40 | def mean_pooling(token_embeddings: mx.array, attention_mask: mx.array): 41 | input_mask_expanded = mx.expand_dims(attention_mask, -1) 42 | input_mask_expanded = mx.broadcast_to( 43 | input_mask_expanded, token_embeddings.shape 44 | ).astype(mx.float32) 45 | sum_embeddings = mx.sum(token_embeddings * input_mask_expanded, axis=1) 46 | sum_mask = mx.maximum(mx.sum(input_mask_expanded, axis=1), 1e-9) 47 | return sum_embeddings / sum_mask 48 | 49 | 50 | def normalize_embeddings(embeddings, p=2, axis=-1, keepdims=True, eps=1e-9): 51 | return embeddings / mx.maximum( 52 | mx.linalg.norm(embeddings, ord=p, axis=axis, keepdims=keepdims), eps 53 | ) 54 | -------------------------------------------------------------------------------- /mlx_embeddings/convert.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from .utils import convert 4 | 5 | 6 | def configure_parser() -> argparse.ArgumentParser: 7 | """ 8 | Configures and returns the argument parser for the script. 9 | 10 | Returns: 11 | argparse.ArgumentParser: Configured argument parser. 12 | """ 13 | parser = argparse.ArgumentParser( 14 | description="Convert Hugging Face model to MLX format" 15 | ) 16 | 17 | parser.add_argument("--hf-path", type=str, help="Path to the Hugging Face model.") 18 | parser.add_argument( 19 | "--mlx-path", type=str, default="mlx_model", help="Path to save the MLX model." 
20 | ) 21 | parser.add_argument( 22 | "-q", "--quantize", help="Generate a quantized model.", action="store_true" 23 | ) 24 | parser.add_argument( 25 | "--q-group-size", help="Group size for quantization.", type=int, default=64 26 | ) 27 | parser.add_argument( 28 | "--q-bits", help="Bits per weight for quantization.", type=int, default=4 29 | ) 30 | parser.add_argument( 31 | "--dtype", 32 | help="Type to save the parameters, ignored if -q is given.", 33 | type=str, 34 | choices=["float16", "bfloat16", "float32"], 35 | default="float16", 36 | ) 37 | parser.add_argument( 38 | "--upload-repo", 39 | help="The Hugging Face repo to upload the model to.", 40 | type=str, 41 | default=None, 42 | ) 43 | parser.add_argument( 44 | "-d", 45 | "--dequantize", 46 | help="Dequantize a quantized model.", 47 | action="store_true", 48 | default=False, 49 | ) 50 | return parser 51 | 52 | 53 | def main(): 54 | parser = configure_parser() 55 | args = parser.parse_args() 56 | convert(**vars(args)) 57 | 58 | 59 | if __name__ == "__main__": 60 | main() 61 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | private/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | env/ 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | private/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # dotenv 86 | .env 87 | 88 | # virtualenv 89 | .venv 90 | venv/ 91 | ENV/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # IDE settings 107 | .vscode/ 108 | test.ipynb 109 | .lisapro/*.py 110 | .lisapro/*.yml 111 | 112 | .DS_Store 113 | 114 | # testing README scripts 115 | test_readme.py -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "mlx-embeddings" 3 | version = "0.0.5" 4 | dynamic = [ 5 | "dependencies", 6 | ] 7 | description = "MLX-Embeddings is a package for running Vision and Language Embedding models locally on your Mac using MLX." 
8 | readme = "README.md" 9 | requires-python = ">=3.8" 10 | keywords = [ 11 | "mlx-embeddings", 12 | ] 13 | license = {text = "GNU General Public License v3"} 14 | authors = [ 15 | {name = "Prince Canuma", email = "prince.gdt@gmail.com"}, 16 | ] 17 | classifiers = [ 18 | "Intended Audience :: Developers", 19 | "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", 20 | "Natural Language :: English", 21 | "Programming Language :: Python :: 3.8", 22 | "Programming Language :: Python :: 3.9", 23 | "Programming Language :: Python :: 3.10", 24 | "Programming Language :: Python :: 3.11", 25 | "Programming Language :: Python :: 3.12", 26 | ] 27 | 28 | [project.entry-points."console_scripts"] 29 | mlx_embeddings = "mlx_embeddings.cli:main" 30 | 31 | [project.optional-dependencies] 32 | all = [ 33 | "mlx-embeddings[extra]", 34 | ] 35 | 36 | extra = [ 37 | "pandas", 38 | ] 39 | 40 | 41 | [tool] 42 | [tool.setuptools.packages.find] 43 | include = ["mlx_embeddings*"] 44 | exclude = ["docs*"] 45 | 46 | [tool.setuptools.dynamic] 47 | dependencies = {file = ["requirements.txt"]} 48 | 49 | 50 | [tool.distutils.bdist_wheel] 51 | universal = true 52 | 53 | 54 | [tool.bumpversion] 55 | current_version = "0.0.5" 56 | commit = true 57 | tag = true 58 | 59 | [[tool.bumpversion.files]] 60 | filename = "pyproject.toml" 61 | search = 'version = "{current_version}"' 62 | replace = 'version = "{new_version}"' 63 | 64 | [[tool.bumpversion.files]] 65 | filename = "mlx_embeddings/version.py" 66 | search = '__version__ = "{current_version}"' 67 | replace = '__version__ = "{new_version}"' 68 | 69 | 70 | [tool.flake8] 71 | exclude = [ 72 | "docs", 73 | ] 74 | max-line-length = 88 75 | 76 | 77 | [project.urls] 78 | Homepage = "https://github.com/Blaizzy/mlx-embeddings" 79 | 80 | [build-system] 81 | requires = ["setuptools>=64", "setuptools_scm>=8"] 82 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /.github/workflows/docs-build.yml: -------------------------------------------------------------------------------- 1 | # name: docs-build 2 | # on: 3 | # pull_request: 4 | # branches: 5 | # - main 6 | # - master 7 | 8 | # jobs: 9 | # deploy: 10 | # runs-on: ubuntu-latest 11 | # steps: 12 | # - uses: actions/checkout@v4 13 | # with: 14 | # fetch-depth: 0 15 | # - uses: actions/setup-python@v5 16 | # with: 17 | # python-version: "3.11" 18 | 19 | # - name: Install GDAL 20 | # run: | 21 | # python -m pip install --upgrade pip 22 | # pip install --find-links=https://girder.github.io/large_image_wheels --no-cache GDAL pyproj 23 | # - name: Test GDAL installation 24 | # run: | 25 | # python -c "from osgeo import gdal" 26 | # gdalinfo --version 27 | # - name: Install dependencies 28 | # run: | 29 | # pip install --no-cache-dir Cython 30 | # pip install -r requirements.txt -r requirements_dev.txt 31 | # pip install . 
32 | # - name: Discover typos with codespell 33 | # run: codespell --skip="*.csv,*.geojson,*.json,*.js,*.html,*cff,*.pdf,./.git" --ignore-words-list="aci,acount,hist" 34 | # - name: PKG-TEST 35 | # run: | 36 | # python -m unittest discover tests/ 37 | # - name: Build docs 38 | # run: | 39 | # mkdocs build 40 | # - name: Deploy to Netlify 41 | # uses: nwtgck/actions-netlify@v2.0 42 | # with: 43 | # publish-dir: "./site" 44 | # production-branch: main 45 | 46 | # github-token: ${{ secrets.GITHUB_TOKEN }} 47 | # deploy-message: "Deploy from GitHub Actions" 48 | # enable-pull-request-comment: true 49 | # enable-commit-comment: false 50 | # overwrites-pull-request-comment: true 51 | # env: 52 | # NETLIFY_AUTH_TOKEN: ${{ secrets.NETLIFY_AUTH_TOKEN }} 53 | # NETLIFY_SITE_ID: ${{ secrets.NETLIFY_SITE_ID }} 54 | # timeout-minutes: 10 55 | 56 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: mlx-embeddings 2 | site_description: MLX-Embeddings is a package for running Vision and Language Embedding models locally on your Mac using MLX. 3 | site_author: Blaizzy 4 | site_url: https://Blaizzy.github.io/mlx-embeddings 5 | repo_url: https://github.com/Blaizzy/mlx-embeddings 6 | 7 | copyright: "Copyright © 2024 - 2024 Prince Canuma" 8 | 9 | theme: 10 | palette: 11 | - scheme: default 12 | # primary: blue 13 | # accent: indigo 14 | toggle: 15 | icon: material/toggle-switch-off-outline 16 | name: Switch to dark mode 17 | - scheme: slate 18 | primary: indigo 19 | accent: indigo 20 | toggle: 21 | icon: material/toggle-switch 22 | name: Switch to light mode 23 | name: material 24 | icon: 25 | repo: fontawesome/brands/github 26 | # logo: assets/logo.png 27 | # favicon: assets/favicon.png 28 | features: 29 | - navigation.instant 30 | - navigation.tracking 31 | - navigation.top 32 | - search.highlight 33 | - search.share 34 | custom_dir: "docs/overrides" 35 | font: 36 | text: Google Sans 37 | code: Regular 38 | 39 | plugins: 40 | - search 41 | - mkdocstrings 42 | - git-revision-date 43 | - git-revision-date-localized: 44 | enable_creation_date: true 45 | type: timeago 46 | # - pdf-export 47 | - mkdocs-jupyter: 48 | include_source: True 49 | ignore_h1_titles: True 50 | execute: True 51 | allow_errors: false 52 | ignore: ["conf.py"] 53 | execute_ignore: ["*ignore.ipynb"] 54 | 55 | markdown_extensions: 56 | - admonition 57 | - abbr 58 | - attr_list 59 | - def_list 60 | - footnotes 61 | - meta 62 | - md_in_html 63 | - pymdownx.superfences 64 | - pymdownx.highlight: 65 | linenums: true 66 | - toc: 67 | permalink: true 68 | 69 | # extra: 70 | # analytics: 71 | # provider: google 72 | # property: UA-XXXXXXXXX-X 73 | 74 | nav: 75 | - Home: index.md 76 | - Installation: installation.md 77 | - Usage: usage.md 78 | - Contributing: contributing.md 79 | - FAQ: faq.md 80 | - Changelog: changelog.md 81 | - Report Issues: https://github.com/Blaizzy/mlx-embeddings/issues 82 | - Examples: 83 | - examples/intro.ipynb 84 | - API Reference: 85 | - mlx_embeddings module: mlx_embeddings.md 86 | - common module: common.md 87 | -------------------------------------------------------------------------------- /mlx_embeddings/models/gemma3_text.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Optional 3 | 4 | import mlx.core as mx 5 | import mlx.nn as nn 6 | from mlx_lm.models.base import create_attention_mask 7 | from mlx_lm.models.gemma3_text import ModelArgs, RMSNorm, TransformerBlock 8 |
9 | from .base import BaseModelOutput, mean_pooling, normalize_embeddings 10 | 11 | 12 | class Gemma3Model(nn.Module): 13 | def __init__(self, args: ModelArgs): 14 | super().__init__() 15 | self.config = args 16 | self.vocab_size = args.vocab_size 17 | self.num_hidden_layers = args.num_hidden_layers 18 | assert self.vocab_size > 0 19 | self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) 20 | self.layers = [ 21 | TransformerBlock(args=args, layer_idx=layer_idx) 22 | for layer_idx in range(args.num_hidden_layers) 23 | ] 24 | self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 25 | 26 | def __call__( 27 | self, 28 | inputs: mx.array, 29 | mask: mx.array = None, 30 | cache=None, 31 | input_embeddings: Optional[mx.array] = None, 32 | ): 33 | if input_embeddings is not None: 34 | h = input_embeddings 35 | else: 36 | h = self.embed_tokens(inputs) 37 | h *= mx.array( 38 | self.config.hidden_size**0.5, self.embed_tokens.weight.dtype 39 | ).astype(h.dtype) 40 | 41 | if cache is None: 42 | cache = [None] * len(self.layers) 43 | 44 | if mask is None: 45 | j = self.config.sliding_window_pattern 46 | full_mask = create_attention_mask(h, cache[j - 1 : j]) 47 | sliding_window_mask = create_attention_mask(h, cache) 48 | 49 | for i, (layer, c) in enumerate(zip(self.layers, cache)): 50 | is_global = ( 51 | i % self.config.sliding_window_pattern 52 | == self.config.sliding_window_pattern - 1 53 | ) 54 | 55 | local_mask = mask 56 | if mask is None and is_global: 57 | local_mask = full_mask 58 | elif mask is None: 59 | local_mask = sliding_window_mask 60 | 61 | h = layer(h, local_mask, c) 62 | 63 | return self.norm(h) 64 | 65 | 66 | class Model(nn.Module): 67 | def __init__(self, config: ModelArgs): 68 | super().__init__() 69 | self.config = config 70 | self.model_type = config.model_type 71 | self.model = Gemma3Model(config) 72 | self.dense = [ 73 | nn.Linear(config.hidden_size, config.hidden_size * 4, bias=False), 74 | nn.Linear(config.hidden_size * 4, config.hidden_size, bias=False), 75 | ] 76 | 77 | def get_extended_attention_mask(self, attention_mask, input_shape): 78 | if attention_mask.ndim == 3: 79 | extended_attention_mask = attention_mask[:, None, :, :] 80 | elif attention_mask.ndim == 2: 81 | extended_attention_mask = attention_mask[:, None, None, :] 82 | extended_attention_mask = mx.repeat( 83 | extended_attention_mask, attention_mask.shape[-1], -2 84 | ) 85 | 86 | else: 87 | raise ValueError( 88 | f"Wrong shape for attention_mask (shape {attention_mask.shape})" 89 | ) 90 | return extended_attention_mask 91 | 92 | def __call__( 93 | self, 94 | inputs: mx.array, 95 | attention_mask: Optional[mx.array] = None, 96 | ): 97 | 98 | if attention_mask is None: 99 | attention_mask = mx.ones(inputs.shape) 100 | 101 | extended_attention_mask = self.get_extended_attention_mask( 102 | attention_mask, inputs.shape 103 | ) 104 | 105 | out = self.model(inputs, extended_attention_mask) 106 | 107 | for dense in self.dense: 108 | out = dense(out) 109 | 110 | # normalized features 111 | text_embeds = mean_pooling(out, attention_mask) 112 | text_embeds = normalize_embeddings(text_embeds) 113 | 114 | return BaseModelOutput( 115 | last_hidden_state=out, 116 | text_embeds=text_embeds, 117 | pooler_output=None, 118 | ) 119 | 120 | def sanitize(self, weights): 121 | sanitized_weights = {} 122 | for k, v in weights.items(): 123 | if "linear" not in k and "dense" not in k: 124 | new_key = f"model.{k}" if not k.startswith("model") else k 125 | sanitized_weights[new_key] = v 126 | elif "dense" not in k: 
127 | key_id = "0" if v.shape[0] > v.shape[1] else "1" 128 | new_key = re.sub(r"\d+_Dense\.linear", f"dense.{key_id}", k) 129 | sanitized_weights[new_key] = v 130 | else: 131 | sanitized_weights[k] = v 132 | 133 | return sanitized_weights 134 | 135 | @property 136 | def layers(self): 137 | return self.model.layers 138 | -------------------------------------------------------------------------------- /mlx_embeddings/models/lfm2.py: -------------------------------------------------------------------------------- 1 | import re 2 | from dataclasses import dataclass 3 | from typing import Optional 4 | 5 | import mlx.core as mx 6 | import mlx.nn as nn 7 | from mlx_lm.models.base import create_attention_mask, create_ssm_mask 8 | from mlx_lm.models.cache import ArraysCache, KVCache 9 | from mlx_lm.models.lfm2 import Lfm2DecoderLayer 10 | from mlx_lm.models.lfm2 import ModelArgs as Lfm2ModelArgs 11 | 12 | from .base import BaseModelOutput, mean_pooling, normalize_embeddings 13 | 14 | 15 | @dataclass 16 | class ModelArgs(Lfm2ModelArgs): 17 | out_features: int = 128 18 | 19 | 20 | class Lfm2Model(nn.Module): 21 | def __init__(self, args: ModelArgs): 22 | super().__init__() 23 | self.args = args 24 | self.vocab_size = args.vocab_size 25 | self.num_hidden_layers = args.num_hidden_layers 26 | self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) 27 | self.layers = [ 28 | Lfm2DecoderLayer(args, layer_idx=i) for i in range(args.num_hidden_layers) 29 | ] 30 | 31 | self.embedding_norm = nn.RMSNorm(args.hidden_size, eps=args.norm_eps) 32 | 33 | self.fa_idx = args.full_attn_idxs[0] 34 | self.conv_idx = 0 35 | for i in range(args.num_hidden_layers): 36 | if i in args.full_attn_idxs: 37 | self.conv_idx += 1 38 | else: 39 | break 40 | 41 | def __call__( 42 | self, 43 | inputs: mx.array, 44 | cache=None, 45 | input_embeddings: Optional[mx.array] = None, 46 | ): 47 | if input_embeddings is not None: 48 | h = input_embeddings 49 | else: 50 | h = self.embed_tokens(inputs) 51 | 52 | if cache is None: 53 | cache = [None] * len(self.layers) 54 | 55 | attn_mask = create_attention_mask(h, cache[self.fa_idx]) 56 | conv_mask = create_ssm_mask(h, cache[self.conv_idx]) 57 | 58 | for layer, c in zip(self.layers, cache): 59 | mask = attn_mask if layer.is_attention_layer else conv_mask 60 | h = layer(h, mask, cache=c) 61 | 62 | return self.embedding_norm(h) 63 | 64 | 65 | class Model(nn.Module): 66 | def __init__(self, config: ModelArgs): 67 | super().__init__() 68 | self.config = config 69 | self.model_type = config.model_type 70 | self.model = Lfm2Model(config) 71 | self.dense = [ 72 | nn.Linear(config.block_dim, config.out_features, bias=False), 73 | ] 74 | 75 | def get_extended_attention_mask(self, attention_mask, input_shape): 76 | if attention_mask.ndim == 3: 77 | extended_attention_mask = attention_mask[:, None, :, :] 78 | elif attention_mask.ndim == 2: 79 | extended_attention_mask = attention_mask[:, None, None, :] 80 | extended_attention_mask = mx.repeat( 81 | extended_attention_mask, attention_mask.shape[-1], -2 82 | ) 83 | 84 | else: 85 | raise ValueError( 86 | f"Wrong shape for attention_mask (shape {attention_mask.shape})" 87 | ) 88 | return extended_attention_mask 89 | 90 | def __call__( 91 | self, 92 | inputs: mx.array, 93 | attention_mask: Optional[mx.array] = None, 94 | ): 95 | 96 | if attention_mask is None: 97 | attention_mask = mx.ones(inputs.shape) 98 | 99 | h = self.model(inputs, cache=self.make_cache) 100 | out = h 101 | for dense in self.dense: 102 | out = dense(out) 103 | 104 | text_embeds 
= normalize_embeddings(out) 105 | 106 | # Mask pad tokens 107 | text_embeds = text_embeds * attention_mask[:, :, None] 108 | 109 | pooled = mean_pooling(text_embeds, attention_mask) 110 | 111 | return BaseModelOutput( 112 | last_hidden_state=h, 113 | text_embeds=text_embeds, 114 | pooler_output=pooled, 115 | ) 116 | 117 | def sanitize(self, weights): 118 | sanitized_weights = {} 119 | for k, v in weights.items(): 120 | 121 | if "linear" not in k and "dense" not in k: 122 | new_key = f"model.{k}" if not k.startswith("model") else k 123 | if "conv.weight" in new_key: 124 | if v.shape[-1] > v.shape[1]: 125 | v = v.transpose(0, 2, 1) 126 | 127 | sanitized_weights[new_key] = v 128 | elif "1_Dense.linear" in k: 129 | new_key = k.replace("1_Dense.linear", "dense.0") 130 | sanitized_weights[new_key] = v 131 | else: 132 | sanitized_weights[k] = v 133 | 134 | return sanitized_weights 135 | 136 | @property 137 | def layers(self): 138 | return self.model.layers 139 | 140 | @property 141 | def make_cache(self): 142 | return [ 143 | KVCache() if l.is_attention_layer else ArraysCache(size=1) 144 | for l in self.model.layers 145 | ] 146 | -------------------------------------------------------------------------------- /mlx_embeddings/tests/test_base.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | import numpy as np 3 | import pytest 4 | 5 | from mlx_embeddings.models.base import ( 6 | BaseModelArgs, 7 | BaseModelOutput, 8 | ViTModelOutput, 9 | mean_pooling, 10 | normalize_embeddings, 11 | ) 12 | 13 | 14 | class TestBaseModelArgs: 15 | def test_from_dict(self): 16 | # Create a sample class that inherits from BaseModelArgs 17 | class TestArgs(BaseModelArgs): 18 | def __init__(self, a=1, b=2, c=3): 19 | self.a = a 20 | self.b = b 21 | self.c = c 22 | 23 | # Test with exact params 24 | params = {"a": 10, "b": 20, "c": 30} 25 | args = TestArgs.from_dict(params) 26 | assert args.a == 10 27 | assert args.b == 20 28 | assert args.c == 30 29 | 30 | # Test with extra params (should be ignored) 31 | params = {"a": 10, "b": 20, "c": 30, "d": 40} 32 | args = TestArgs.from_dict(params) 33 | assert args.a == 10 34 | assert args.b == 20 35 | assert args.c == 30 36 | assert not hasattr(args, "d") 37 | 38 | # Test with missing params (should use defaults) 39 | params = {"a": 10} 40 | args = TestArgs.from_dict(params) 41 | assert args.a == 10 42 | assert args.b == 2 43 | assert args.c == 3 44 | 45 | 46 | class TestBaseModelOutput: 47 | def test_initialization(self): 48 | # Test default initialization 49 | output = BaseModelOutput() 50 | assert output.last_hidden_state is None 51 | assert output.pooler_output is None 52 | assert output.text_embeds is None 53 | assert output.hidden_states is None 54 | 55 | # Test with values 56 | mock_array = mx.array([1, 2, 3]) 57 | mock_list = [mx.array([1, 2]), mx.array([3, 4])] 58 | output = BaseModelOutput( 59 | last_hidden_state=mock_array, 60 | pooler_output=mock_array, 61 | text_embeds=mock_array, 62 | hidden_states=mock_list, 63 | ) 64 | assert output.last_hidden_state is mock_array 65 | assert output.pooler_output is mock_array 66 | assert output.text_embeds is mock_array 67 | assert output.hidden_states is mock_list 68 | 69 | 70 | class TestViTModelOutput: 71 | def test_initialization(self): 72 | # Test default initialization 73 | output = ViTModelOutput() 74 | assert output.logits is None 75 | assert output.text_embeds is None 76 | assert output.image_embeds is None 77 | assert output.logits_per_text is None 78 | 
assert output.logits_per_image is None 79 | assert output.text_model_output is None 80 | assert output.vision_model_output is None 81 | 82 | # Test with values 83 | mock_array = mx.array([1, 2, 3]) 84 | output = ViTModelOutput( 85 | logits=mock_array, 86 | text_embeds=mock_array, 87 | image_embeds=mock_array, 88 | logits_per_text=mock_array, 89 | logits_per_image=mock_array, 90 | text_model_output=mock_array, 91 | vision_model_output=mock_array, 92 | ) 93 | assert output.logits is mock_array 94 | assert output.text_embeds is mock_array 95 | assert output.image_embeds is mock_array 96 | assert output.logits_per_text is mock_array 97 | assert output.logits_per_image is mock_array 98 | assert output.text_model_output is mock_array 99 | assert output.vision_model_output is mock_array 100 | 101 | 102 | class TestMeanPooling: 103 | def test_mean_pooling(self): 104 | # Create sample inputs 105 | batch_size, seq_len, hidden_dim = 2, 3, 4 106 | token_embeddings = mx.random.normal((batch_size, seq_len, hidden_dim)) 107 | 108 | # Test case 1: No masking (all 1s) 109 | attention_mask = mx.ones((batch_size, seq_len)) 110 | result = mean_pooling(token_embeddings, attention_mask) 111 | 112 | # Expected result is the mean across sequence dimension 113 | expected = mx.mean(token_embeddings, axis=1) 114 | np.testing.assert_allclose(result.tolist(), expected.tolist(), rtol=1e-5) 115 | 116 | # Test case 2: With masking 117 | attention_mask = mx.array( 118 | [ 119 | [1, 1, 0], # Only first two tokens are valid 120 | [1, 0, 0], # Only first token is valid 121 | ] 122 | ) 123 | result = mean_pooling(token_embeddings, attention_mask) 124 | 125 | # Manual calculation for verification 126 | expected_0 = mx.sum(token_embeddings[0, :2], axis=0) / 2 127 | expected_1 = token_embeddings[1, 0] # Just the first embedding 128 | expected = mx.stack([expected_0, expected_1]) 129 | np.testing.assert_allclose(result.tolist(), expected.tolist(), rtol=1e-5) 130 | 131 | 132 | class TestNormalizeEmbeddings: 133 | def test_normalize_embeddings(self): 134 | # Test case 1: 2D array 135 | embeddings = mx.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 136 | normalized = normalize_embeddings(embeddings) 137 | 138 | # Check that each row has unit norm 139 | norms = mx.linalg.norm(normalized, ord=2, axis=-1) 140 | np.testing.assert_allclose(norms.tolist(), [1.0, 1.0], rtol=1e-5) 141 | 142 | # Test case 2: 3D array 143 | embeddings = mx.random.normal((2, 3, 4)) 144 | normalized = normalize_embeddings(embeddings) 145 | 146 | # Check shape is preserved 147 | assert normalized.shape == embeddings.shape 148 | 149 | # Check that each vector in the last dimension has unit norm 150 | norms = mx.linalg.norm(normalized, ord=2, axis=-1) 151 | expected_norms = mx.ones((2, 3)) 152 | np.testing.assert_allclose(norms.tolist(), expected_norms.tolist(), rtol=1e-5) 153 | 154 | # Test case 3: Small values (testing the epsilon) 155 | embeddings = mx.zeros((2, 3)) 156 | normalized = normalize_embeddings(embeddings, eps=1.0) 157 | expected = mx.zeros((2, 3)) 158 | np.testing.assert_allclose(normalized.tolist(), expected.tolist(), rtol=1e-5) 159 | -------------------------------------------------------------------------------- /mlx_embeddings/tests/test_smoke.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import platform 5 | import subprocess 6 | import sys 7 | import textwrap 8 | import time 9 | import traceback 10 | 11 | import mlx.core as mx 12 | import psutil 13 | import 
requests 14 | from PIL import Image 15 | from rich.console import Console 16 | from rich.panel import Panel 17 | from tqdm import tqdm 18 | from transformers import __version__ as transformers_version 19 | 20 | from mlx_embeddings import generate, load 21 | from mlx_embeddings.utils import load_config 22 | from mlx_embeddings.version import __version__ 23 | 24 | # Initialize console 25 | console = Console() 26 | 27 | 28 | def parse_args(): 29 | parser = argparse.ArgumentParser(description="Test MLX-VLM models") 30 | parser.add_argument( 31 | "--models", 32 | type=str, 33 | nargs="+", 34 | required=True, 35 | help="Path to file containing model paths, one per line", 36 | ) 37 | parser.add_argument( 38 | "--images", 39 | type=str, 40 | nargs="+", 41 | required=False, 42 | help="Path to file containing image paths, one per line", 43 | ) 44 | return parser.parse_args() 45 | 46 | 47 | def get_device_info(): 48 | # Disable tokenizers parallelism to avoid deadlocks after forking 49 | import os 50 | 51 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 52 | 53 | try: 54 | data = subprocess.check_output( 55 | ["system_profiler", "SPDisplaysDataType", "-json"], text=True 56 | ) 57 | device_info = json.loads(data) 58 | return device_info 59 | except Exception as e: 60 | print(f"Could not retrieve GPU information: {e}") 61 | return None 62 | 63 | 64 | def test_model_loading(model_path): 65 | try: 66 | console.print("[bold green]Loading model...") 67 | start_time = time.time() 68 | model, processor = load(model_path) 69 | end_time = time.time() 70 | console.print( 71 | f"[bold green]✓[/] Model loaded successfully in {end_time - start_time:.2f} seconds" 72 | ) 73 | return model, processor, False 74 | except Exception as e: 75 | console.print(f"[bold red]✗[/] Failed to load model: {str(e)}") 76 | traceback.print_exc() 77 | return None, None, True 78 | 79 | 80 | def test_generation(model, processor, images): 81 | try: 82 | console.print(f"[bold yellow]Testing embedding...") 83 | test_type = "Text embedding" 84 | 85 | if hasattr(model.config, "vision_config"): 86 | test_type = "ViT embedding" 87 | 88 | # Text descriptions 89 | texts = [ 90 | "a photo of cats", 91 | "a photo of a desktop setup", 92 | "a photo of a person", 93 | ] 94 | 95 | # Process all image-text pairs 96 | all_probs = [] 97 | 98 | for i, image in enumerate(images): 99 | # Process inputs for current image with all texts 100 | output = generate(model, processor, texts=texts, images=image) 101 | logits_per_image = output.logits_per_image 102 | probs = mx.sigmoid(logits_per_image)[0] # probabilities for this image 103 | all_probs.append(probs.tolist()) 104 | 105 | # Print results for this image 106 | print(f"Image {i+1}:") 107 | for j, text in enumerate(texts): 108 | print(f" {probs[j]:.1%} match with '{text}'") 109 | print() 110 | 111 | assert len(all_probs) == len(images) 112 | elif hasattr(model.config, "architectures") and model.config.architectures == [ 113 | "ModernBertForMaskedLM" 114 | ]: 115 | test_type = "Masked Language Modeling" 116 | texts = [ 117 | "The capital of France is [MASK].", 118 | "The capital of Poland is [MASK].", 119 | ] 120 | inputs = processor.batch_encode_plus( 121 | texts, 122 | return_tensors="mlx", 123 | padding=True, 124 | truncation=True, 125 | max_length=512, 126 | ) 127 | 128 | output = generate( 129 | model, 130 | processor, 131 | texts=texts, 132 | padding=True, 133 | truncation=True, 134 | max_length=512, 135 | ) 136 | mask_indices = mx.array( 137 | [ 138 | ids.tolist().index(processor.mask_token_id) 139 | 
for ids in inputs["input_ids"] 140 | ] 141 | ) 142 | 143 | # Get predictions for all masked tokens at once 144 | batch_indices = mx.arange(len(mask_indices)) 145 | predicted_token_ids = mx.argmax( 146 | output.pooler_output[batch_indices, mask_indices], axis=-1 147 | ).tolist() 148 | 149 | predicted_tokens = processor.batch_decode( 150 | predicted_token_ids, skip_special_tokens=True 151 | ) 152 | print("Predicted tokens:", predicted_tokens) 153 | 154 | else: 155 | test_type = "Text embedding" 156 | # Create text descriptions to compare with the image 157 | texts = [ 158 | "I like grapes", 159 | "I like fruits", 160 | "The slow green turtle crawls under the busy ant.", 161 | ] 162 | 163 | # Process inputs 164 | output = generate(model, processor, texts=texts) 165 | 166 | assert output.text_embeds.shape == (len(texts), model.config.hidden_size) 167 | 168 | # Calculate similarity between text embeddings 169 | embeddings = output.text_embeds 170 | # Compute dot product between normalized embeddings 171 | similarity_matrix = mx.matmul(embeddings, embeddings.T) 172 | 173 | print("\nSimilarity matrix between texts:") 174 | print(similarity_matrix) 175 | 176 | console.print(f"[bold green]✓[/] {test_type} generation successful") 177 | return False 178 | except Exception as e: 179 | console.print(f"[bold red]✗[/] {test_type} generation failed: {str(e)}") 180 | traceback.print_exc() 181 | return True 182 | 183 | 184 | def main(): 185 | args = parse_args() 186 | 187 | # Load models list 188 | if isinstance(args.models, str) and os.path.exists(args.models): 189 | with open(args.models, "r", encoding="utf-8") as f: 190 | models = [line.strip() for line in f.readlines()] 191 | else: 192 | models = args.models 193 | 194 | results = [] 195 | 196 | for model_path in tqdm(models): 197 | console.print(Panel(f"Testing {model_path}", style="bold blue")) 198 | 199 | # Run tests 200 | model, processor, error = test_model_loading(model_path) 201 | 202 | if not error and model: 203 | print("\n") 204 | # Test vision-language generation 205 | error |= test_generation(model, processor, args.images) 206 | 207 | print("\n") 208 | 209 | console.print("[bold blue]Cleaning up...") 210 | del model, processor 211 | mx.metal.clear_cache() 212 | mx.metal.reset_peak_memory() 213 | console.print("[bold green]✓[/] Cleanup complete\n") 214 | results.append( 215 | f"[bold {'green' if not error else 'red'}]{'✓' if not error else '✗'}[/] {model_path}" 216 | ) 217 | 218 | print("\n") 219 | success = all(result.startswith("[bold green]") for result in results) 220 | panel_style = "bold green" if success else "bold red" 221 | console.print(Panel("\n".join(results), title="Results", style=panel_style)) 222 | console.print( 223 | f"[bold {'green' if success else 'red'}]{'All' if success else 'Some'} models tested {'successfully' if success else 'failed to test'}" 224 | ) 225 | 226 | print("\n") 227 | device_info = get_device_info() 228 | console.print( 229 | Panel( 230 | title="System Information", 231 | renderable=textwrap.dedent( 232 | f"""{platform.machine() == 'arm64' and f''' 233 | MAC OS: v{platform.mac_ver()[0]} 234 | Python: v{sys.version.split()[0]} 235 | MLX: v{mx.__version__} 236 | MLX-VLM: v{__version__} 237 | Transformers: v{transformers_version} 238 | 239 | Hardware: 240 | • Chip: {device_info['SPDisplaysDataType'][0]['_name']} 241 | • RAM: {psutil.virtual_memory().total / (1024 ** 3):.1f} GB 242 | • CPU Cores: {psutil.cpu_count(logical=False)} 243 | • GPU Cores: {device_info['SPDisplaysDataType'][0]['sppci_cores']} 244 | 
''' or 'Not running on Apple Silicon'}""" 245 | ), 246 | style="bold blue", 247 | ) 248 | ) 249 | 250 | 251 | if __name__ == "__main__": 252 | main() 253 | -------------------------------------------------------------------------------- /mlx_embeddings/models/bert.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | from typing import Optional, Tuple 4 | 5 | import mlx.core as mx 6 | import mlx.nn as nn 7 | 8 | from .base import BaseModelArgs, BaseModelOutput, mean_pooling, normalize_embeddings 9 | 10 | 11 | @dataclass 12 | class ModelArgs(BaseModelArgs): 13 | model_type: str 14 | num_hidden_layers: int 15 | num_attention_heads: int 16 | hidden_size: int 17 | intermediate_size: int 18 | max_position_embeddings: int 19 | hidden_dropout_prob: float = 0.1 20 | attention_probs_dropout_prob: float = 0.1 21 | type_vocab_size: int = 2 22 | initializer_range: float = 0.02 23 | layer_norm_eps: float = 1e-12 24 | vocab_size: int = 30522 25 | 26 | 27 | class BertEmbeddings(nn.Module): 28 | def __init__(self, config: ModelArgs): 29 | super().__init__() 30 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) 31 | self.position_embeddings = nn.Embedding( 32 | config.max_position_embeddings, config.hidden_size 33 | ) 34 | self.token_type_embeddings = nn.Embedding( 35 | config.type_vocab_size, config.hidden_size 36 | ) 37 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 38 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 39 | 40 | def __call__(self, input_ids, token_type_ids=None, position_ids=None): 41 | seq_length = input_ids.shape[1] 42 | if position_ids is None: 43 | position_ids = mx.arange(seq_length, dtype=mx.int32)[None, :] 44 | if token_type_ids is None: 45 | token_type_ids = mx.zeros_like(input_ids) 46 | 47 | words_embeddings = self.word_embeddings(input_ids) 48 | position_embeddings = self.position_embeddings(position_ids) 49 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 50 | 51 | embeddings = words_embeddings + position_embeddings + token_type_embeddings 52 | embeddings = self.LayerNorm(embeddings) 53 | embeddings = self.dropout(embeddings) 54 | return embeddings 55 | 56 | 57 | class BertSelfAttention(nn.Module): 58 | def __init__(self, config: ModelArgs): 59 | super().__init__() 60 | self.num_attention_heads = config.num_attention_heads 61 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 62 | self.all_head_size = self.num_attention_heads * self.attention_head_size 63 | 64 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 65 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 66 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 67 | 68 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 69 | 70 | def transpose_for_scores(self, x): 71 | new_x_shape = x.shape[:-1] + ( 72 | self.num_attention_heads, 73 | self.attention_head_size, 74 | ) 75 | x = x.reshape(new_x_shape) 76 | return x.transpose(0, 2, 1, 3) 77 | 78 | def __call__(self, hidden_states, attention_mask=None): 79 | mixed_query_layer = self.query(hidden_states) 80 | mixed_key_layer = self.key(hidden_states) 81 | mixed_value_layer = self.value(hidden_states) 82 | 83 | query_layer = self.transpose_for_scores(mixed_query_layer) 84 | key_layer = self.transpose_for_scores(mixed_key_layer) 85 | value_layer = self.transpose_for_scores(mixed_value_layer) 86 | 87 | attention_scores = 
mx.matmul(query_layer, key_layer.transpose(0, 1, 3, 2)) 88 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 89 | if attention_mask is not None: 90 | attention_scores = attention_scores + attention_mask 91 | 92 | attention_probs = mx.softmax(attention_scores, axis=-1) 93 | attention_probs = self.dropout(attention_probs) 94 | 95 | context_layer = mx.matmul(attention_probs, value_layer) 96 | context_layer = context_layer.transpose(0, 2, 1, 3) 97 | new_context_layer_shape = context_layer.shape[:-2] + (self.all_head_size,) 98 | context_layer = context_layer.reshape(new_context_layer_shape) 99 | 100 | return context_layer 101 | 102 | 103 | class BertSelfOutput(nn.Module): 104 | def __init__(self, config: ModelArgs): 105 | super().__init__() 106 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 107 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 108 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 109 | 110 | def __call__(self, hidden_states, input_tensor): 111 | hidden_states = self.dense(hidden_states) 112 | hidden_states = self.dropout(hidden_states) 113 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 114 | return hidden_states 115 | 116 | 117 | class BertAttention(nn.Module): 118 | def __init__(self, config: ModelArgs): 119 | super().__init__() 120 | self.self = BertSelfAttention(config) 121 | self.output = BertSelfOutput(config) 122 | 123 | def __call__(self, hidden_states, attention_mask=None): 124 | self_outputs = self.self(hidden_states, attention_mask) 125 | attention_output = self.output(self_outputs, hidden_states) 126 | return attention_output 127 | 128 | 129 | class BertIntermediate(nn.Module): 130 | def __init__(self, config: ModelArgs): 131 | super().__init__() 132 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 133 | self.intermediate_act_fn = nn.GELU() 134 | 135 | def __call__(self, hidden_states): 136 | hidden_states = self.dense(hidden_states) 137 | hidden_states = self.intermediate_act_fn(hidden_states) 138 | return hidden_states 139 | 140 | 141 | class BertOutput(nn.Module): 142 | def __init__(self, config: ModelArgs): 143 | super().__init__() 144 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 145 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 146 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 147 | 148 | def __call__(self, hidden_states, input_tensor): 149 | hidden_states = self.dense(hidden_states) 150 | hidden_states = self.dropout(hidden_states) 151 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 152 | return hidden_states 153 | 154 | 155 | class BertLayer(nn.Module): 156 | def __init__(self, config: ModelArgs): 157 | super().__init__() 158 | self.attention = BertAttention(config) 159 | self.intermediate = BertIntermediate(config) 160 | self.output = BertOutput(config) 161 | 162 | def __call__(self, hidden_states, attention_mask=None): 163 | attention_output = self.attention(hidden_states, attention_mask) 164 | intermediate_output = self.intermediate(attention_output) 165 | layer_output = self.output(intermediate_output, attention_output) 166 | return layer_output 167 | 168 | 169 | class BertEncoder(nn.Module): 170 | def __init__(self, config: ModelArgs): 171 | super().__init__() 172 | self.layer = [BertLayer(config) for _ in range(config.num_hidden_layers)] 173 | 174 | def __call__(self, hidden_states, attention_mask=None): 175 | for layer_module in self.layer: 176 | 
hidden_states = layer_module(hidden_states, attention_mask) 177 | return hidden_states 178 | 179 | 180 | class BertPooler(nn.Module): 181 | def __init__(self, config: ModelArgs): 182 | super().__init__() 183 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 184 | self.activation = nn.Tanh() 185 | 186 | def __call__(self, hidden_states): 187 | first_token_tensor = hidden_states[:, 0] 188 | pooled_output = self.dense(first_token_tensor) 189 | pooled_output = self.activation(pooled_output) 190 | return pooled_output 191 | 192 | 193 | class Model(nn.Module): 194 | def __init__(self, config: ModelArgs): 195 | super().__init__() 196 | self.config = config 197 | self.embeddings = BertEmbeddings(config) 198 | self.encoder = BertEncoder(config) 199 | self.pooler = BertPooler(config) 200 | 201 | def get_extended_attention_mask(self, attention_mask): 202 | if attention_mask.ndim == 3: 203 | extended_attention_mask = attention_mask[:, None, :, :] 204 | elif attention_mask.ndim == 2: 205 | extended_attention_mask = attention_mask[:, None, None, :] 206 | else: 207 | raise ValueError( 208 | f"Wrong shape for attention_mask (shape {attention_mask.shape})" 209 | ) 210 | 211 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 212 | return extended_attention_mask 213 | 214 | def __call__(self, input_ids, token_type_ids=None, attention_mask=None): 215 | batch_size, seq_len = input_ids.shape 216 | embedding_output = self.embeddings(input_ids, token_type_ids) 217 | 218 | if attention_mask is None: 219 | attention_mask = mx.ones((batch_size, seq_len)) 220 | 221 | extended_attention_mask = self.get_extended_attention_mask(attention_mask) 222 | 223 | encoder_outputs = self.encoder(embedding_output, extended_attention_mask) 224 | sequence_output = encoder_outputs 225 | pooled_output = self.pooler(sequence_output) 226 | 227 | # normalized features 228 | text_embeds = mean_pooling(sequence_output, attention_mask) 229 | text_embeds = normalize_embeddings(text_embeds) 230 | 231 | return BaseModelOutput( 232 | last_hidden_state=sequence_output, 233 | text_embeds=text_embeds, 234 | pooler_output=pooled_output, 235 | ) 236 | 237 | def sanitize(self, weights): 238 | sanitized_weights = {} 239 | for k, v in weights.items(): 240 | if "position_ids" in k: 241 | # Remove unused position_ids 242 | continue 243 | else: 244 | sanitized_weights[k] = v 245 | return sanitized_weights 246 | -------------------------------------------------------------------------------- /mlx_embeddings/tokenizer_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from functools import partial 3 | 4 | from transformers import AutoTokenizer 5 | 6 | REPLACEMENT_CHAR = "\ufffd" 7 | 8 | 9 | def _remove_space(x): 10 | if x and x[0] == " ": 11 | return x[1:] 12 | return x 13 | 14 | 15 | class StreamingDetokenizer: 16 | """The streaming detokenizer interface so that we can detokenize one token at a time. 17 | 18 | Example usage is as follows: 19 | 20 | detokenizer = ... 21 | 22 | # Reset the tokenizer state 23 | detokenizer.reset() 24 | 25 | for token in generate(...): 26 | detokenizer.add_token(token.item()) 27 | 28 | # Contains the whole text so far. Some tokens may not be included 29 | # since it contains whole words usually. 
30 | detokenizer.text 31 | 32 | # Contains the printable segment (usually a word) since the last 33 | # time it was accessed 34 | detokenizer.last_segment 35 | 36 | # Contains all the tokens added so far 37 | detokenizer.tokens 38 | 39 | # Make sure that we detokenize any remaining tokens 40 | detokenizer.finalize() 41 | 42 | # Now detokenizer.text should match tokenizer.decode(detokenizer.tokens) 43 | """ 44 | 45 | __slots__ = ("text", "tokens", "offset") 46 | 47 | def reset(self): 48 | raise NotImplementedError() 49 | 50 | def add_token(self, token): 51 | raise NotImplementedError() 52 | 53 | def finalize(self): 54 | raise NotImplementedError() 55 | 56 | @property 57 | def last_segment(self): 58 | """Return the last segment of readable text since last time this property was accessed.""" 59 | text = self.text 60 | if text and text[-1] != REPLACEMENT_CHAR: 61 | segment = text[self.offset :] 62 | self.offset = len(text) 63 | return segment 64 | return "" 65 | 66 | 67 | class NaiveStreamingDetokenizer(StreamingDetokenizer): 68 | """NaiveStreamingDetokenizer relies on the underlying tokenizer 69 | implementation and should work with every tokenizer. 70 | 71 | Its complexity is O(T^2) where T is the longest line since it will 72 | repeatedly detokenize the same tokens until a new line is generated. 73 | """ 74 | 75 | def __init__(self, tokenizer): 76 | self._tokenizer = tokenizer 77 | self._tokenizer.decode([0]) 78 | self.reset() 79 | 80 | def reset(self): 81 | self.offset = 0 82 | self._tokens = [] 83 | self._text = "" 84 | self._current_tokens = [] 85 | self._current_text = "" 86 | 87 | def add_token(self, token): 88 | self._current_tokens.append(token) 89 | 90 | def finalize(self): 91 | self._tokens.extend(self._current_tokens) 92 | self._text += self._tokenizer.decode(self._current_tokens) 93 | self._current_tokens = [] 94 | self._current_text = "" 95 | 96 | @property 97 | def text(self): 98 | if self._current_tokens: 99 | self._current_text = self._tokenizer.decode(self._current_tokens) 100 | if self._current_text and self._current_text[-1] == "\n": 101 | self._tokens.extend(self._current_tokens) 102 | self._text += self._current_text 103 | self._current_tokens.clear() 104 | self._current_text = "" 105 | return self._text + self._current_text 106 | 107 | @property 108 | def tokens(self): 109 | return self._tokens 110 | 111 | 112 | class SPMStreamingDetokenizer(StreamingDetokenizer): 113 | """A streaming detokenizer for SPM models. 114 | 115 | It adds tokens to the text if the next token starts with the special SPM 116 | underscore which results in linear complexity. 
117 | """ 118 | 119 | def __init__(self, tokenizer, trim_space=True): 120 | self.trim_space = trim_space 121 | 122 | # Extract the tokens in a list from id to text 123 | self.tokenmap = [""] * (max(tokenizer.vocab.values()) + 1) 124 | for value, tokenid in tokenizer.vocab.items(): 125 | self.tokenmap[tokenid] = value 126 | 127 | # Replace bytes with their value 128 | for i in range(len(self.tokenmap)): 129 | if self.tokenmap[i].startswith("<0x"): 130 | self.tokenmap[i] = chr(int(self.tokenmap[i][3:5], 16)) 131 | 132 | self.reset() 133 | 134 | def reset(self): 135 | self.offset = 0 136 | self._unflushed = "" 137 | self.text = "" 138 | self.tokens = [] 139 | 140 | def add_token(self, token): 141 | v = self.tokenmap[token] 142 | if v[0] == "\u2581": 143 | if self.text or not self.trim_space: 144 | self.text += self._unflushed.replace("\u2581", " ") 145 | else: 146 | self.text = _remove_space(self._unflushed.replace("\u2581", " ")) 147 | self._unflushed = v 148 | else: 149 | self._unflushed += v 150 | 151 | def finalize(self): 152 | if self.text or not self.trim_space: 153 | self.text += self._unflushed.replace("\u2581", " ") 154 | else: 155 | self.text = _remove_space(self._unflushed.replace("\u2581", " ")) 156 | self._unflushed = "" 157 | 158 | 159 | class BPEStreamingDetokenizer(StreamingDetokenizer): 160 | """A streaming detokenizer for OpenAI style BPE models. 161 | 162 | It adds tokens to the text if the next token starts with a space similar to 163 | the SPM detokenizer. 164 | """ 165 | 166 | _byte_decoder = None 167 | 168 | def __init__(self, tokenizer, trim_space=False): 169 | self.trim_space = trim_space 170 | 171 | # Extract the tokens in a list from id to text 172 | self.tokenmap = [None] * len(tokenizer.vocab) 173 | for value, tokenid in tokenizer.vocab.items(): 174 | self.tokenmap[tokenid] = value 175 | 176 | self.reset() 177 | 178 | # Make the BPE byte decoder from 179 | # https://github.com/openai/gpt-2/blob/master/src/encoder.py 180 | self.make_byte_decoder() 181 | 182 | def reset(self): 183 | self.offset = 0 184 | self._unflushed = "" 185 | self.text = "" 186 | self.tokens = [] 187 | 188 | def add_token(self, token): 189 | v = self.tokenmap[token] 190 | # if the token starts with space 191 | if self._byte_decoder[v[0]] == 32: 192 | current_text = bytearray( 193 | self._byte_decoder[c] for c in self._unflushed 194 | ).decode("utf-8") 195 | if self.text or not self.trim_space: 196 | self.text += current_text 197 | else: 198 | self.text += _remove_space(current_text) 199 | self._unflushed = v 200 | else: 201 | self._unflushed += v 202 | 203 | def finalize(self): 204 | current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode( 205 | "utf-8" 206 | ) 207 | if self.text or not self.trim_space: 208 | self.text += current_text 209 | else: 210 | self.text += _remove_space(current_text) 211 | self._unflushed = "" 212 | 213 | @classmethod 214 | def make_byte_decoder(cls): 215 | """See https://github.com/openai/gpt-2/blob/master/src/encoder.py for the rationale.""" 216 | if cls._byte_decoder is not None: 217 | return 218 | 219 | char_to_bytes = {} 220 | limits = [ 221 | 0, 222 | ord("!"), 223 | ord("~") + 1, 224 | ord("¡"), 225 | ord("¬") + 1, 226 | ord("®"), 227 | ord("ÿ") + 1, 228 | ] 229 | n = 0 230 | for i, (start, stop) in enumerate(zip(limits, limits[1:])): 231 | if i % 2 == 0: 232 | for b in range(start, stop): 233 | char_to_bytes[chr(2**8 + n)] = b 234 | n += 1 235 | else: 236 | for b in range(start, stop): 237 | char_to_bytes[chr(b)] = b 238 | 
cls._byte_decoder = char_to_bytes 239 | 240 | 241 | class TokenizerWrapper: 242 | """A wrapper that combines an HF tokenizer and a detokenizer. 243 | 244 | Accessing any attribute other than the ``detokenizer`` is forwarded to the 245 | huggingface tokenizer. 246 | """ 247 | 248 | def __init__(self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer): 249 | self._tokenizer = tokenizer 250 | self._detokenizer = detokenizer_class(tokenizer) 251 | 252 | def __getattr__(self, attr): 253 | if attr == "detokenizer": 254 | return self._detokenizer 255 | else: 256 | return getattr(self._tokenizer, attr) 257 | 258 | 259 | def _match(a, b): 260 | if type(a) != type(b): 261 | return False 262 | if isinstance(a, dict): 263 | return len(a) == len(b) and all(k in b and _match(a[k], b[k]) for k in a) 264 | if isinstance(a, list): 265 | return len(a) == len(b) and all(_match(ai, bi) for ai, bi in zip(a, b)) 266 | 267 | return a == b 268 | 269 | 270 | def _is_spm_decoder(decoder): 271 | _target_description = { 272 | "type": "Sequence", 273 | "decoders": [ 274 | {"type": "Replace", "pattern": {"String": "▁"}, "content": " "}, 275 | {"type": "ByteFallback"}, 276 | {"type": "Fuse"}, 277 | {"type": "Strip", "content": " ", "start": 1, "stop": 0}, 278 | ], 279 | } 280 | return _match(_target_description, decoder) 281 | 282 | 283 | def _is_spm_decoder_no_space(decoder): 284 | _target_description = { 285 | "type": "Sequence", 286 | "decoders": [ 287 | {"type": "Replace", "pattern": {"String": "▁"}, "content": " "}, 288 | {"type": "ByteFallback"}, 289 | {"type": "Fuse"}, 290 | ], 291 | } 292 | return _match(_target_description, decoder) 293 | 294 | 295 | def _is_bpe_decoder(decoder): 296 | _target_description = { 297 | "type": "ByteLevel", 298 | "add_prefix_space": False, 299 | "trim_offsets": False, 300 | "use_regex": False, 301 | } 302 | 303 | return _match(_target_description, decoder) 304 | 305 | 306 | def load_tokenizer(model_path, tokenizer_config_extra={}): 307 | """Load a huggingface tokenizer and try to infer the type of streaming 308 | detokenizer to use. 309 | 310 | Note, to use a fast streaming tokenizer, pass a local file path rather than 311 | a Hugging Face repo ID. 312 | """ 313 | detokenizer_class = NaiveStreamingDetokenizer 314 | 315 | tokenizer_file = model_path / "tokenizer.json" 316 | if tokenizer_file.exists(): 317 | with open(tokenizer_file, "r") as fid: 318 | tokenizer_content = json.load(fid) 319 | if "decoder" in tokenizer_content: 320 | if _is_spm_decoder(tokenizer_content["decoder"]): 321 | detokenizer_class = SPMStreamingDetokenizer 322 | elif _is_spm_decoder_no_space(tokenizer_content["decoder"]): 323 | detokenizer_class = partial(SPMStreamingDetokenizer, trim_space=False) 324 | elif _is_bpe_decoder(tokenizer_content["decoder"]): 325 | detokenizer_class = BPEStreamingDetokenizer 326 | 327 | return TokenizerWrapper( 328 | AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra), 329 | detokenizer_class, 330 | ) 331 | -------------------------------------------------------------------------------- /mlx_embeddings/models/colqwen2_5.py: -------------------------------------------------------------------------------- 1 | """ 2 | ColQwen2.5 model implementation for MLX. 3 | 4 | ColQwen2.5 is a multimodal retrieval model that uses Qwen2.5-VL as its backbone 5 | to create efficient multi-vector embeddings from document images for retrieval. 6 | It follows the ColPali approach, eliminating the need for OCR pipelines. 
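Queries and page images are both encoded as sequences of low-dimensional token embeddings (128-d by default, one vector per text token or merged image patch), which are scored downstream with a late-interaction (MaxSim) comparison, e.g. via the processor's score_retrieval helper.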
7 | """ 8 | 9 | import inspect 10 | from dataclasses import asdict, dataclass 11 | from typing import Any, Dict, Optional 12 | 13 | import mlx.core as mx 14 | import mlx.nn as nn 15 | import numpy as np 16 | from mlx_vlm.models.qwen2_5_vl import Model as Qwen2_5VLModel 17 | from mlx_vlm.models.qwen2_5_vl import ModelConfig, TextConfig, VisionConfig 18 | 19 | from .base import ViTModelOutput, normalize_embeddings 20 | 21 | 22 | @dataclass 23 | class ModelArgs: 24 | text_config: Dict[str, Any] # Keep as dict for utils.py compatibility 25 | vision_config: Dict[str, Any] # Keep as dict for utils.py compatibility 26 | vlm_config: Dict[str, Any] 27 | embedding_dim: int = 128 28 | initializer_range: float = 0.02 29 | model_type: str = "colqwen2_5" 30 | 31 | @classmethod 32 | def from_dict(cls, params): 33 | # Extract vlm_config 34 | vlm_config = params.get("vlm_config", {}) 35 | 36 | # Extract and clean text_config and vision_config 37 | text_config_raw = vlm_config.get("text_config", {}) 38 | vision_config_raw = vlm_config.get("vision_config", {}) 39 | 40 | # Use the Config classes' from_dict methods to filter parameters, 41 | # then convert back to clean dictionaries using asdict() 42 | text_config = ( 43 | asdict(TextConfig.from_dict(text_config_raw)) if text_config_raw else {} 44 | ) 45 | vision_config = ( 46 | asdict(VisionConfig.from_dict(vision_config_raw)) 47 | if vision_config_raw 48 | else {} 49 | ) 50 | 51 | # Create the ModelArgs with the cleaned configs 52 | return cls( 53 | text_config=text_config, 54 | vision_config=vision_config, 55 | vlm_config=vlm_config, 56 | embedding_dim=params.get("embedding_dim", 128), 57 | initializer_range=params.get("initializer_range", 0.02), 58 | model_type=params.get("model_type", "colqwen2_5"), 59 | ) 60 | 61 | def __post_init__(self): 62 | # Ensure vlm_config is a dictionary 63 | if not isinstance(self.vlm_config, dict): 64 | self.vlm_config = ( 65 | self.vlm_config.__dict__ if hasattr(self.vlm_config, "__dict__") else {} 66 | ) 67 | 68 | 69 | class Model(nn.Module): 70 | def __init__(self, config: ModelArgs): 71 | super().__init__() 72 | self.config = config 73 | 74 | # Import Qwen2_5VL model from mlx-vlm 75 | 76 | # Create VLM config from the dictionary 77 | vlm_config = ModelConfig.from_dict(config.vlm_config) 78 | if isinstance(vlm_config.vision_config, dict): 79 | vlm_config.vision_config = VisionConfig.from_dict(vlm_config.vision_config) 80 | if isinstance(vlm_config.text_config, dict): 81 | vlm_config.text_config = TextConfig.from_dict(vlm_config.text_config) 82 | 83 | # Initialize the VLM 84 | self.vlm = Qwen2_5VLModel(vlm_config) 85 | 86 | # Initialize the embedding projection layer 87 | self.embedding_proj_layer = nn.Linear( 88 | vlm_config.text_config.hidden_size, config.embedding_dim, bias=True 89 | ) 90 | 91 | # Get special token IDs from the VLM config 92 | self.image_token_id = vlm_config.image_token_id 93 | self.video_token_id = vlm_config.video_token_id 94 | 95 | def get_image_features( 96 | self, 97 | pixel_values: mx.array, 98 | image_grid_thw: Optional[mx.array] = None, 99 | ) -> mx.array: 100 | """Extract image features using the vision model.""" 101 | # Get vision features from the vision tower 102 | dtype = self.vlm.vision_tower.patch_embed.proj.weight.dtype 103 | pixel_values = pixel_values.astype(dtype) 104 | 105 | hidden_states = self.vlm.vision_tower( 106 | pixel_values, image_grid_thw, output_hidden_states=False 107 | ) 108 | 109 | return hidden_states 110 | 111 | def get_input_embeddings_batch( 112 | self, 113 | 
input_ids: mx.array, 114 | pixel_values: Optional[mx.array] = None, 115 | image_grid_thw: Optional[mx.array] = None, 116 | ): 117 | """Override VLM's get_input_embeddings to handle batch processing correctly.""" 118 | if pixel_values is None: 119 | return self.vlm.language_model.model.embed_tokens(input_ids) 120 | 121 | dtype = self.vlm.vision_tower.patch_embed.proj.weight.dtype 122 | pixel_values = pixel_values.astype(dtype) 123 | 124 | # Get the input embeddings from the language model 125 | inputs_embeds = self.vlm.language_model.model.embed_tokens(input_ids) 126 | 127 | # Get the output hidden states from the vision model 128 | hidden_states = self.vlm.vision_tower( 129 | pixel_values, image_grid_thw, output_hidden_states=False 130 | ) 131 | 132 | # Reshape hidden_states to match batch structure if needed 133 | batch_size = input_ids.shape[0] 134 | if batch_size > 1 and hidden_states.ndim == 2: 135 | # Calculate features per image based on grid_thw 136 | features_per_image = [] 137 | start_idx = 0 138 | for i in range(batch_size): 139 | t, h, w = image_grid_thw[i].tolist() # Convert to Python integers 140 | num_features = int( 141 | (h // self.vlm.vision_tower.spatial_merge_size) 142 | * (w // self.vlm.vision_tower.spatial_merge_size) 143 | * t 144 | ) 145 | features_per_image.append( 146 | hidden_states[start_idx : start_idx + num_features] 147 | ) 148 | start_idx += num_features 149 | hidden_states = mx.stack(features_per_image) 150 | 151 | if hidden_states.ndim == 2: 152 | hidden_states = hidden_states[None, :, :] 153 | 154 | # Merge image features with input embeddings 155 | image_token_id = self.vlm.config.image_token_id 156 | video_token_id = self.vlm.config.video_token_id 157 | 158 | # Handle batch processing correctly 159 | image_positions = input_ids == image_token_id 160 | if mx.sum(image_positions) == 0: 161 | image_positions = input_ids == video_token_id 162 | 163 | if batch_size == 1: 164 | # Original single-batch logic using numpy for index finding 165 | image_positions_np = np.array(image_positions) 166 | image_indices = np.where(image_positions_np)[1].tolist() 167 | inputs_embeds[:, image_indices, :] = hidden_states 168 | else: 169 | # Multi-batch processing 170 | for batch_idx in range(batch_size): 171 | # Get positions for this batch item 172 | batch_positions = image_positions[batch_idx] 173 | # Convert to numpy to find indices 174 | batch_positions_np = np.array(batch_positions) 175 | batch_indices = np.where(batch_positions_np)[0].tolist() 176 | 177 | # Get the corresponding features for this batch 178 | batch_features = hidden_states[batch_idx] 179 | 180 | # Update embeddings for this batch 181 | inputs_embeds[batch_idx, batch_indices, :] = batch_features 182 | 183 | return inputs_embeds 184 | 185 | def __call__( 186 | self, 187 | input_ids: mx.array, 188 | pixel_values: Optional[mx.array] = None, 189 | attention_mask: Optional[mx.array] = None, 190 | image_grid_thw: Optional[mx.array] = None, 191 | position_ids: Optional[mx.array] = None, 192 | cache=None, 193 | **kwargs, 194 | ) -> ViTModelOutput: 195 | """ 196 | Forward pass for ColQwen2_5 model. 
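Returns one L2-normalized embedding vector per position; when ``pixel_values`` is None they are exposed as ``text_embeds``, otherwise as ``image_embeds``.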
197 | 198 | Args: 199 | input_ids: Input token IDs 200 | pixel_values: Pixel values for images 201 | attention_mask: Attention mask 202 | image_grid_thw: Image grid dimensions (temporal, height, width) 203 | position_ids: Position IDs 204 | cache: Cache for autoregressive generation 205 | 206 | Returns: 207 | ViTModelOutput with embeddings 208 | """ 209 | # Get input embeddings with merged image features using our batch-aware method 210 | inputs_embeds = self.get_input_embeddings_batch( 211 | input_ids, pixel_values, image_grid_thw 212 | ) 213 | 214 | # Run through the language model 215 | output = self.vlm.language_model.model( 216 | None, inputs_embeds=inputs_embeds, mask=None, cache=cache 217 | ) 218 | 219 | # Project to embedding dimension 220 | embeddings = self.embedding_proj_layer(output) 221 | 222 | # L2 normalize the embeddings 223 | embeddings = normalize_embeddings(embeddings) 224 | 225 | # Apply attention mask if provided 226 | if attention_mask is not None: 227 | embeddings = embeddings * attention_mask[:, :, None] 228 | 229 | if pixel_values is None: 230 | return ViTModelOutput( 231 | text_embeds=embeddings, 232 | ) 233 | else: 234 | return ViTModelOutput( 235 | image_embeds=embeddings, 236 | ) 237 | 238 | def sanitize(self, weights): 239 | """Sanitize weights for loading.""" 240 | sanitized_weights = {} 241 | 242 | for k, v in weights.items(): 243 | # Handle the projection layer 244 | if k.startswith("embedding_proj_layer"): 245 | sanitized_weights[k] = v 246 | # Handle VLM weights - need to fix the paths 247 | elif k.startswith("vlm."): 248 | # The HuggingFace model has a different structure: 249 | # HF: vlm.model.visual.* -> MLX: vlm.vision_tower.* 250 | # HF: vlm.model.language_model.* -> MLX: vlm.language_model.model.* 251 | 252 | new_key = k 253 | 254 | # First, fix vision/visual path 255 | if "vlm.model.visual." in k: 256 | new_key = k.replace("vlm.model.visual.", "vlm.vision_tower.") 257 | # Then fix the language model path structure 258 | elif "vlm.model.language_model." in k: 259 | # Replace vlm.model.language_model. with vlm.language_model.model. 260 | new_key = k.replace( 261 | "vlm.model.language_model.", "vlm.language_model.model." 262 | ) 263 | 264 | # Special handling for patch_embed.proj.weight 265 | if new_key == "vlm.vision_tower.patch_embed.proj.weight": 266 | # Check if we need to transpose based on the shape 267 | # HF format: (out_channels, in_channels, temporal, height, width) -> e.g., (1280, 3, 2, 14, 14) 268 | # MLX format: (out_channels, temporal, height, width, in_channels) -> e.g., (1280, 2, 14, 14, 3) 269 | if v.shape[1] == 3 and v.shape[2] == 2: # HF format detected 270 | # Transpose from HF format to MLX format 271 | v = v.transpose(0, 2, 3, 4, 1) 272 | 273 | # Now apply VLM-specific sanitization 274 | if hasattr(self.vlm, "sanitize"): 275 | # Remove the "vlm." 
prefix for VLM sanitization 276 | vlm_key = new_key[4:] 277 | vlm_weights = {vlm_key: v} 278 | vlm_weights = self.vlm.sanitize(vlm_weights) 279 | for vk, vv in vlm_weights.items(): 280 | sanitized_weights[f"vlm.{vk}"] = vv 281 | else: 282 | sanitized_weights[new_key] = v 283 | else: 284 | # Handle any other weights that might not have the vlm prefix 285 | sanitized_weights[k] = v 286 | 287 | return sanitized_weights 288 | 289 | @staticmethod 290 | def from_pretrained(path_or_hf_repo: str): 291 | """Load a pretrained ColQwen2_5 model.""" 292 | import json 293 | from pathlib import Path 294 | 295 | from huggingface_hub import snapshot_download 296 | 297 | path = Path(path_or_hf_repo) 298 | if not path.exists(): 299 | path = Path( 300 | snapshot_download( 301 | repo_id=path_or_hf_repo, 302 | allow_patterns=[ 303 | "*.json", 304 | "*.safetensors", 305 | "*.py", 306 | "tokenizer.model", 307 | "*.tiktoken", 308 | ], 309 | ) 310 | ) 311 | 312 | # Load config 313 | with open(path / "config.json", "r") as f: 314 | config_dict = json.load(f) 315 | 316 | # Create config object 317 | config = ModelArgs.from_dict(config_dict) 318 | 319 | # Create model 320 | model = Model(config) 321 | 322 | # Load weights 323 | weight_files = list(path.glob("*.safetensors")) 324 | if not weight_files: 325 | raise FileNotFoundError(f"No safetensors found in {path}") 326 | 327 | weights = {} 328 | for wf in weight_files: 329 | weights.update(mx.load(wf)) 330 | 331 | # Sanitize weights 332 | weights = model.sanitize(weights) 333 | 334 | # Load weights into model 335 | model.load_weights(list(weights.items())) 336 | 337 | return model 338 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MLX-Embeddings 2 | 3 | [![image](https://img.shields.io/pypi/v/mlx-embeddings.svg)](https://pypi.python.org/pypi/mlx-embeddings) [![Upload Python Package](https://github.com/Blaizzy/mlx-embeddings/actions/workflows/python-publish.yaml/badge.svg)](https://github.com/Blaizzy/mlx-embeddings/actions/workflows/python-publish.yaml) 4 | 5 | **MLX-Embeddings is a package for running Vision and Language Embedding models locally on your Mac using MLX.** 6 | 7 | - Free software: GNU General Public License v3 8 | 9 | ## Features 10 | 11 | - Generate embeddings for text and images using MLX models 12 | - Support for single-item and batch processing 13 | - Utilities for comparing text similarities 14 | 15 | ## Supported Model Architectures 16 | MLX-Embeddings supports a variety of model architectures for text embedding tasks. Here's a breakdown of the currently supported architectures: 17 | - XLM-RoBERTa (Cross-lingual Language Model - Robustly Optimized BERT Approach) 18 | - BERT (Bidirectional Encoder Representations from Transformers) 19 | - ModernBERT (modernized bidirectional encoder-only Transformer model) 20 | - Qwen3 (Qwen3's embedding model) 21 | 22 | We're continuously working to expand our support for additional model architectures. Check our GitHub repository or documentation for the most up-to-date list of supported models and their specific versions. 
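All of the architectures above are used through the same `load()` and forward-pass interface shown under Usage; as a minimal sketch for the Qwen3 entry (the repo ID below is a placeholder, substitute any Qwen3 embedding checkpoint converted for MLX):

```python
from mlx_embeddings.utils import load

# Placeholder repo ID: substitute an MLX-converted Qwen3 embedding checkpoint
model, tokenizer = load("<your-qwen3-embedding-repo>")

input_ids = tokenizer.encode("I like reading", return_tensors="mlx")
outputs = model(input_ids)
embedding = outputs.text_embeds  # pooled and L2-normalized embedding
```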
23 | 24 | ## Installation 25 | 26 | You can install mlx-embeddings using pip: 27 | 28 | ```bash 29 | pip install mlx-embeddings 30 | ``` 31 | 32 | ## Usage 33 | 34 | ### Single Item Embedding 35 | 36 | 37 | #### Text Embedding 38 | To generate an embedding for a single piece of text: 39 | 40 | ```python 41 | from mlx_embeddings.utils import load 42 | 43 | # Load the model and tokenizer 44 | model_name = "mlx-community/all-MiniLM-L6-v2-4bit" 45 | model, tokenizer = load(model_name) 46 | 47 | # Prepare the text 48 | text = "I like reading" 49 | 50 | # Tokenize and generate embedding 51 | input_ids = tokenizer.encode(text, return_tensors="mlx") 52 | outputs = model(input_ids) 53 | raw_embeds = outputs.last_hidden_state[:, 0, :] # CLS token 54 | text_embeds = outputs.text_embeds # mean pooled and normalized embeddings 55 | ``` 56 | 57 | Note: `text_embeds` uses mean pooling for BERT and XLM-RoBERTa. For ModernBERT, the pooling strategy is set through the config file and defaults to mean. 58 | 59 | #### Masked Language Modeling 60 | 61 | To generate embeddings for masked language modeling tasks: 62 | 63 | ```python 64 | import mlx.core as mx 65 | from mlx_embeddings.utils import load 66 | # Load ModernBERT model and tokenizer 67 | model, tokenizer = load("mlx-community/answerdotai-ModernBERT-base-4bit") 68 | 69 | # Masked Language Modeling example 70 | text = "The capital of France is [MASK]." 71 | inputs = tokenizer.encode(text, return_tensors="mlx") 72 | outputs = model(inputs) 73 | 74 | # Get predictions for the masked token 75 | masked_index = inputs.tolist()[0].index(tokenizer.mask_token_id) 76 | predicted_token_id = mx.argmax(outputs.pooler_output[0, masked_index]).tolist() 77 | predicted_token = tokenizer.decode(predicted_token_id) 78 | print("Predicted token:", predicted_token) # Should output: Paris 79 | ``` 80 | 81 | #### Sequence Classification 82 | ```python 83 | from mlx_embeddings.utils import load 84 | 85 | # Load ModernBERT model and tokenizer 86 | model, tokenizer = load( 87 | "NousResearch/Minos-v1", 88 | ) 89 | 90 | id2label = model.config.id2label 91 | 92 | # Sequence classification example 93 | text = "<|user|> Explain the theory of relativity in simple terms. <|assistant|> Imagine space and time are like a stretchy fabric. Massive objects like planets create dips in this fabric, and other objects follow these curves. That's gravity! 
Also, the faster you move, the slower time passes for you compared to someone standing still" 94 | inputs = tokenizer.encode(text, return_tensors="mlx") 95 | outputs = model(inputs) 96 | 97 | # Get the classification logits 98 | predictions = outputs.pooler_output[0] # Shape: (num_labels,) 99 | print(text) 100 | 101 | # Print results 102 | print("\nTop predictions for classification:") 103 | for idx, logit in enumerate(predictions.tolist()): 104 | label = id2label[str(idx)] 105 | print(f"{label}: {logit:.3f}") 106 | ``` 107 | 108 | ### Batch Processing 109 | 110 | #### Multiple Texts Comparison 111 | 112 | To embed multiple texts and compare them using their embeddings: 113 | 114 | ```python 115 | from sklearn.metrics.pairwise import cosine_similarity 116 | import matplotlib.pyplot as plt 117 | import seaborn as sns 118 | import mlx.core as mx 119 | from mlx_embeddings.utils import load 120 | 121 | # Load the model and tokenizer 122 | model, tokenizer = load("mlx-community/all-MiniLM-L6-v2-4bit") 123 | 124 | def get_embedding(texts, model, tokenizer): 125 | inputs = tokenizer.batch_encode_plus(texts, return_tensors="mlx", padding=True, truncation=True, max_length=512) 126 | outputs = model( 127 | inputs["input_ids"], 128 | attention_mask=inputs["attention_mask"] 129 | ) 130 | return outputs.text_embeds # mean pooled and normalized embeddings 131 | 132 | def compute_and_print_similarity(embeddings): 133 | B, _ = embeddings.shape 134 | similarity_matrix = cosine_similarity(embeddings) 135 | print("Similarity matrix between sequences:") 136 | print(similarity_matrix) 137 | print("\n") 138 | 139 | for i in range(B): 140 | for j in range(i+1, B): 141 | print(f"Similarity between sequence {i+1} and sequence {j+1}: {similarity_matrix[i][j]:.4f}") 142 | 143 | return similarity_matrix 144 | 145 | # Visualize results 146 | def plot_similarity_matrix(similarity_matrix, labels): 147 | plt.figure(figsize=(5, 4)) 148 | sns.heatmap(similarity_matrix, annot=True, cmap='coolwarm', xticklabels=labels, yticklabels=labels) 149 | plt.title('Similarity Matrix Heatmap') 150 | plt.tight_layout() 151 | plt.show() 152 | 153 | # Sample texts 154 | texts = [ 155 | "I like grapes", 156 | "I like fruits", 157 | "The slow green turtle crawls under the busy ant." 
158 | ] 159 | 160 | embeddings = get_embedding(texts, model, tokenizer) 161 | similarity_matrix = compute_and_print_similarity(embeddings) 162 | 163 | # Visualize results 164 | labels = [f"Text {i+1}" for i in range(len(texts))] 165 | plot_similarity_matrix(similarity_matrix, labels) 166 | ``` 167 | 168 | #### Masked Language Modeling 169 | 170 | To get predictions for the masked token in multiple texts: 171 | 172 | ```python 173 | import mlx.core as mx 174 | from mlx_embeddings.utils import load 175 | 176 | # Load the model and tokenizer 177 | model, tokenizer = load("mlx-community/answerdotai-ModernBERT-base-4bit") 178 | 179 | text = ["The capital of France is [MASK].", "The capital of Poland is [MASK]."] 180 | inputs = tokenizer.batch_encode_plus(text, return_tensors="mlx", padding=True, truncation=True, max_length=512) 181 | outputs = model(**inputs) 182 | 183 | # To get predictions for the mask: 184 | # Find mask token indices for each sequence in the batch 185 | # Find mask indices for all sequences in batch 186 | mask_indices = mx.array([ids.tolist().index(tokenizer.mask_token_id) for ids in inputs["input_ids"]]) 187 | 188 | # Get predictions for all masked tokens at once 189 | batch_indices = mx.arange(len(mask_indices)) 190 | predicted_token_ids = mx.argmax(outputs.pooler_output[batch_indices, mask_indices], axis=-1).tolist() 191 | 192 | # Decode the predicted tokens 193 | predicted_token = tokenizer.batch_decode(predicted_token_ids) 194 | 195 | print("Predicted token:", predicted_token) 196 | # Predicted token: Paris, Warsaw 197 | ``` 198 | 199 | 200 | ## Vision Transformer Models 201 | 202 | MLX-Embeddings also supports vision models that can generate embeddings for images or image-text pairs. 203 | 204 | ### Single Image Processing 205 | 206 | To evaluate how well an image matches different text descriptions: 207 | 208 | ```python 209 | import mlx.core as mx 210 | from mlx_embeddings.utils import load 211 | import requests 212 | from PIL import Image 213 | 214 | # Load vision model and processor 215 | model, processor = load("mlx-community/siglip-so400m-patch14-384") 216 | 217 | # Load an image 218 | url = "http://images.cocodataset.org/val2017/000000039769.jpg" 219 | image = Image.open(requests.get(url, stream=True).raw) 220 | 221 | # Create text descriptions to compare with the image 222 | texts = ["a photo of 2 dogs", "a photo of 2 cats"] 223 | 224 | # Process inputs 225 | inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt") 226 | pixel_values = mx.array(inputs.pixel_values).transpose(0, 2, 3, 1).astype(mx.float32) 227 | input_ids = mx.array(inputs.input_ids) 228 | 229 | # Generate embeddings and calculate similarity 230 | outputs = model(pixel_values=pixel_values, input_ids=input_ids) 231 | logits_per_image = outputs.logits_per_image 232 | probs = mx.sigmoid(logits_per_image) # probabilities of image matching each text 233 | 234 | # Print results 235 | print(f"{probs[0][0]:.1%} that image matches '{texts[0]}'") 236 | print(f"{probs[0][1]:.1%} that image matches '{texts[1]}'") 237 | ``` 238 | 239 | ### Batch Processing for Multiple Images comparison 240 | 241 | Process multiple images and compare them with text descriptions: 242 | 243 | ```python 244 | import mlx.core as mx 245 | from mlx_embeddings.utils import load 246 | import requests 247 | from PIL import Image 248 | import matplotlib.pyplot as plt 249 | import seaborn as sns 250 | 251 | # Load vision model and processor 252 | model, processor = 
load("mlx-community/siglip-so400m-patch14-384") 253 | 254 | # Load multiple images 255 | image_urls = [ 256 | "./images/cats.jpg", # cats 257 | "./images/desktop_setup.png" # desktop setup 258 | ] 259 | images = [Image.open(requests.get(url, stream=True).raw) if url.startswith("http") else Image.open(url) for url in image_urls] 260 | 261 | # Text descriptions 262 | texts = ["a photo of cats", "a photo of a desktop setup", "a photo of a person"] 263 | 264 | # Process all image-text pairs 265 | all_probs = [] 266 | 267 | 268 | # Process all image-text pairs in batch 269 | inputs = processor(text=texts, images=images, padding="max_length", return_tensors="pt") 270 | pixel_values = mx.array(inputs.pixel_values).transpose(0, 2, 3, 1).astype(mx.float32) 271 | input_ids = mx.array(inputs.input_ids) 272 | 273 | # Generate embeddings and calculate similarity 274 | outputs = model(pixel_values=pixel_values, input_ids=input_ids) 275 | logits_per_image = outputs.logits_per_image 276 | probs = mx.sigmoid(logits_per_image) # probabilities for this image 277 | all_probs.append(probs.tolist()) 278 | 279 | 280 | # Print results for this image 281 | for i, image in enumerate(images): 282 | print(f"Image {i+1}:") 283 | for j, text in enumerate(texts): 284 | print(f" {probs[i][j]:.1%} match with '{text}'") 285 | print() 286 | 287 | # Visualize results with a heatmap 288 | def plot_similarity_matrix(probs_matrix, image_labels, text_labels): 289 | # Convert to 2D numpy array if needed 290 | import numpy as np 291 | probs_matrix = np.array(probs_matrix) 292 | 293 | # Ensure we have a 2D matrix for the heatmap 294 | if probs_matrix.ndim > 2: 295 | probs_matrix = probs_matrix.squeeze() 296 | 297 | plt.figure(figsize=(8, 5)) 298 | sns.heatmap(probs_matrix, annot=True, cmap='viridis', 299 | xticklabels=text_labels, yticklabels=image_labels, 300 | fmt=".1%", vmin=0, vmax=1) 301 | plt.title('Image-Text Match Probability') 302 | plt.tight_layout() 303 | plt.show() 304 | 305 | # Plot the images for reference 306 | plt.figure(figsize=(8, 5)) 307 | for i, image in enumerate(images): 308 | plt.subplot(1, len(images), i+1) 309 | plt.imshow(image) 310 | plt.title(f"Image {i+1}") 311 | plt.axis('off') 312 | plt.tight_layout() 313 | plt.show() 314 | 315 | image_labels = [f"Image {i+1}" for i in range(len(images))] 316 | plot_similarity_matrix(all_probs, image_labels, texts) 317 | ``` 318 | 319 | ### Late Interaction Multimodal Retrival Models (ColPali/ColQwen) 320 | 321 | ```python 322 | import mlx.core as mx 323 | from mlx_embeddings.utils import load 324 | import requests 325 | from PIL import Image 326 | import torch 327 | 328 | # Load vision model and processor 329 | model, processor = load("qnguyen3/colqwen2.5-v0.2-mlx") 330 | 331 | # Load an image 332 | 333 | url_1 = "https://upload.wikimedia.org/wikipedia/commons/8/89/US-original-Declaration-1776.jpg" 334 | image_1 = Image.open(url_1) 335 | 336 | url_2 = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/4c/Romeoandjuliet1597.jpg/500px-Romeoandjuliet1597.jpg" 337 | image_2 = Image.open(url_2) 338 | 339 | # Create text descriptions to compare with the image 340 | texts = ["how many percent of data are books?", "evaluation results between models"] 341 | 342 | # Process inputs - text and images need to be processed separately for ColQwen2.5 343 | text_inputs = processor(text=texts, padding=True, return_tensors="pt") 344 | image_inputs = processor(images=[image_1, image_2], padding=True, return_tensors="pt") 345 | 346 | # Convert to MLX arrays 347 | text_input_ids = 
mx.array(text_inputs.input_ids) 348 | 349 | image_input_ids = mx.array(image_inputs.input_ids) 350 | pixel_values = mx.array(image_inputs.pixel_values) 351 | image_grid_thw = mx.array(image_inputs.image_grid_thw) 352 | 353 | text_embeddings = model(input_ids=text_input_ids) 354 | image_embeddings = model( 355 | input_ids=image_input_ids, 356 | pixel_values=pixel_values, 357 | image_grid_thw=image_grid_thw, 358 | ) 359 | 360 | print(text_embeddings.text_embeds.shape) 361 | print(image_embeddings.image_embeds.shape) 362 | 363 | ## convert to torch 364 | import torch 365 | text_embeddings = torch.tensor(text_embeddings.text_embeds) 366 | image_embeddings = torch.tensor(image_embeddings.image_embeds) 367 | 368 | scores = processor.score_retrieval(text_embeddings, image_embeddings) 369 | print(scores) 370 | ``` 371 | 372 | ## Contributing 373 | 374 | Contributions to MLX-Embeddings are welcome! Please refer to our contribution guidelines for more information. 375 | 376 | ## License 377 | 378 | This project is licensed under the GNU General Public License v3. 379 | 380 | ## Contact 381 | 382 | For any questions or issues, please open an issue on the [GitHub repository](https://github.com/Blaizzy/mlx-embeddings). 383 | -------------------------------------------------------------------------------- /mlx_embeddings/models/xlm_roberta.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | from typing import Optional, Tuple 4 | 5 | import mlx.core as mx 6 | import mlx.nn as nn 7 | 8 | from .base import BaseModelArgs, BaseModelOutput, mean_pooling, normalize_embeddings 9 | 10 | 11 | @dataclass 12 | class ModelArgs(BaseModelArgs): 13 | model_type: str 14 | hidden_size: int 15 | num_hidden_layers: int 16 | intermediate_size: int 17 | num_attention_heads: int 18 | max_position_embeddings: int 19 | layer_norm_eps: int = 1e-05 20 | vocab_size: int = 46166 21 | add_pooling_layer: bool = True 22 | attention_probs_dropout_prob: float = 0.1 23 | hidden_dropout_prob: float = 0.1 24 | type_vocab_size: int = 1 25 | output_past: bool = True 26 | pad_token_id: int = 1 27 | position_embedding_type: str = "absolute" 28 | pooling_config: dict = None 29 | 30 | 31 | class XLMRobertaEmbeddings(nn.Module): 32 | def __init__(self, config: ModelArgs): 33 | super().__init__() 34 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) 35 | self.position_embeddings = nn.Embedding( 36 | config.max_position_embeddings, config.hidden_size 37 | ) 38 | self.token_type_embeddings = nn.Embedding( 39 | config.type_vocab_size, config.hidden_size 40 | ) 41 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 42 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 43 | self.padding_idx = config.pad_token_id 44 | 45 | def create_position_ids_from_input_ids( 46 | self, input_ids, padding_idx, past_key_values_length=0 47 | ): 48 | mask = mx.where(input_ids != padding_idx, 1, 0) 49 | incremental_indices = (mx.cumsum(mask, axis=1) + past_key_values_length) * mask 50 | return incremental_indices + padding_idx 51 | 52 | def __call__( 53 | self, 54 | input_ids: mx.array, 55 | token_type_ids=None, 56 | position_ids=None, 57 | inputs_embeds=None, 58 | past_key_values_length=0, 59 | ) -> mx.array: 60 | if input_ids is not None: 61 | input_shape = input_ids.shape 62 | else: 63 | input_shape = inputs_embeds.shape[:-1] 64 | 65 | seq_length = input_shape[1] 66 | 67 | if position_ids is None: 68 | position_ids = 
self.create_position_ids_from_input_ids( 69 | input_ids, self.padding_idx, past_key_values_length 70 | ) 71 | 72 | if token_type_ids is None: 73 | token_type_ids = mx.zeros(input_shape, dtype=mx.int32) 74 | 75 | if inputs_embeds is None: 76 | inputs_embeds = self.word_embeddings(input_ids) 77 | 78 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 79 | 80 | embeddings = inputs_embeds + token_type_embeddings 81 | position_embeddings = self.position_embeddings(position_ids) 82 | embeddings += position_embeddings 83 | 84 | embeddings = self.LayerNorm(embeddings) 85 | embeddings = self.dropout(embeddings) 86 | return embeddings 87 | 88 | 89 | class XLMRobertaSelfAttention(nn.Module): 90 | def __init__(self, config: ModelArgs): 91 | super().__init__() 92 | self.num_attention_heads = config.num_attention_heads 93 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 94 | self.all_head_size = self.num_attention_heads * self.attention_head_size 95 | 96 | self.scale = self.all_head_size**-0.5 97 | 98 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 99 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 100 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 101 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 102 | 103 | def transpose_for_scores(self, x): 104 | new_x_shape = x.shape[:-1] + ( 105 | self.num_attention_heads, 106 | self.attention_head_size, 107 | ) 108 | x = x.reshape(new_x_shape) 109 | return x.transpose(0, 2, 1, 3) 110 | 111 | def __call__( 112 | self, x: mx.array, attention_mask=None, head_mask=None, output_attentions=False 113 | ): 114 | queries, keys, values = self.query(x), self.key(x), self.value(x) 115 | 116 | # Prepare the queries, keys and values for the attention computation 117 | queries = self.transpose_for_scores(queries) 118 | keys = self.transpose_for_scores(keys) 119 | values = self.transpose_for_scores(values) 120 | 121 | attention_scores = queries @ keys.swapaxes(-1, -2) 122 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 123 | 124 | if attention_mask is not None: 125 | attention_scores = attention_scores + attention_mask 126 | 127 | attention_probs = nn.softmax( 128 | attention_scores.astype(mx.float32), axis=-1 129 | ).astype(attention_scores.dtype) 130 | attention_probs = self.dropout(attention_probs) 131 | 132 | if head_mask is not None: 133 | attention_probs = attention_probs * mx.array(head_mask) 134 | 135 | context_layer = attention_probs @ values 136 | context_layer = context_layer.transpose(0, 2, 1, 3) 137 | new_context_layer_shape = context_layer.shape[:-2] + (self.all_head_size,) 138 | context_layer = context_layer.reshape(new_context_layer_shape) 139 | 140 | outputs = ( 141 | (context_layer, attention_probs) if output_attentions else (context_layer,) 142 | ) 143 | return outputs 144 | 145 | 146 | class XLMRobertaSelfOutput(nn.Module): 147 | def __init__(self, config: ModelArgs): 148 | super().__init__() 149 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 150 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 151 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 152 | 153 | def __call__(self, hidden_states, input_tensor): 154 | hidden_states = self.dense(hidden_states) 155 | hidden_states = self.dropout(hidden_states) 156 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 157 | return hidden_states 158 | 159 | 160 | class XLMRobertaAttention(nn.Module): 161 | def 
__init__(self, config: ModelArgs): 162 | super().__init__() 163 | self.self = XLMRobertaSelfAttention(config) 164 | self.output = XLMRobertaSelfOutput(config) 165 | 166 | def __call__( 167 | self, 168 | hidden_states, 169 | attention_mask=None, 170 | head_mask=None, 171 | output_attentions=False, 172 | ): 173 | self_outputs = self.self( 174 | hidden_states, attention_mask, head_mask, output_attentions 175 | ) 176 | attention_output = self.output(self_outputs[0], hidden_states) 177 | outputs = (attention_output,) + self_outputs[1:] 178 | return outputs 179 | 180 | 181 | class XLMRobertaIntermediate(nn.Module): 182 | def __init__(self, config: ModelArgs): 183 | super().__init__() 184 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 185 | 186 | def __call__(self, hidden_states): 187 | hidden_states = self.dense(hidden_states) 188 | hidden_states = nn.gelu(hidden_states) 189 | return hidden_states 190 | 191 | 192 | class XLMRobertaOutput(nn.Module): 193 | def __init__(self, config: ModelArgs): 194 | super().__init__() 195 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 196 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 197 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 198 | 199 | def __call__(self, hidden_states, input_tensor): 200 | hidden_states = self.dense(hidden_states) 201 | hidden_states = self.dropout(hidden_states) 202 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 203 | return hidden_states 204 | 205 | 206 | class XLMRobertaLayer(nn.Module): 207 | def __init__(self, config): 208 | super().__init__() 209 | self.attention = XLMRobertaAttention(config) 210 | self.intermediate = XLMRobertaIntermediate(config) 211 | self.output = XLMRobertaOutput(config) 212 | 213 | def __call__( 214 | self, 215 | hidden_states, 216 | attention_mask=None, 217 | head_mask=None, 218 | output_attentions=False, 219 | ): 220 | attention_outputs = self.attention( 221 | hidden_states, attention_mask, head_mask, output_attentions 222 | ) 223 | attention_output = attention_outputs[0] 224 | intermediate_output = self.intermediate(attention_output) 225 | layer_output = self.output(intermediate_output, attention_output) 226 | outputs = (layer_output,) + attention_outputs[1:] 227 | return outputs 228 | 229 | 230 | class XLMRobertaEncoder(nn.Module): 231 | def __init__(self, config: ModelArgs): 232 | super().__init__() 233 | self.config = config 234 | self.layer = [XLMRobertaLayer(config) for _ in range(config.num_hidden_layers)] 235 | 236 | def __call__( 237 | self, 238 | hidden_states, 239 | attention_mask=None, 240 | head_mask=None, 241 | output_attentions=False, 242 | output_hidden_states=False, 243 | ): 244 | all_hidden_states = () if output_hidden_states else None 245 | all_attentions = () if output_attentions else None 246 | 247 | for i, layer_module in enumerate(self.layer): 248 | if output_hidden_states: 249 | all_hidden_states = all_hidden_states + (hidden_states,) 250 | 251 | layer_head_mask = head_mask[i] if head_mask is not None else None 252 | 253 | layer_outputs = layer_module( 254 | hidden_states, attention_mask, layer_head_mask, output_attentions 255 | ) 256 | 257 | hidden_states = layer_outputs[0] 258 | 259 | if output_attentions: 260 | all_attentions = all_attentions + (layer_outputs[1],) 261 | 262 | if output_hidden_states: 263 | all_hidden_states = all_hidden_states + (hidden_states,) 264 | 265 | return tuple( 266 | v 267 | for v in [hidden_states, all_hidden_states, all_attentions] 268 | if v is 
not None 269 | ) 270 | 271 | 272 | class XLMRobertaPooler(nn.Module): 273 | def __init__(self, config: ModelArgs): 274 | super().__init__() 275 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 276 | self.activation = nn.Tanh() 277 | 278 | def __call__(self, hidden_states): 279 | # We "pool" the model by simply taking the hidden state corresponding 280 | # to the first token. 281 | first_token_tensor = hidden_states[:, 0] 282 | pooled_output = self.dense(first_token_tensor) 283 | pooled_output = self.activation(pooled_output) 284 | return pooled_output 285 | 286 | 287 | class Model(nn.Module): 288 | def __init__(self, config: ModelArgs): 289 | super().__init__() 290 | self.config = config 291 | self.embeddings = XLMRobertaEmbeddings(config) 292 | self.encoder = XLMRobertaEncoder(config) 293 | self.pooler = XLMRobertaPooler(config) if config.add_pooling_layer else None 294 | 295 | def get_extended_attention_mask(self, attention_mask, input_shape): 296 | if attention_mask.ndim == 3: 297 | extended_attention_mask = attention_mask[:, None, :, :] 298 | elif attention_mask.ndim == 2: 299 | extended_attention_mask = attention_mask[:, None, None, :] 300 | else: 301 | raise ValueError( 302 | f"Wrong shape for attention_mask (shape {attention_mask.shape})" 303 | ) 304 | 305 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 306 | return extended_attention_mask 307 | 308 | def get_head_mask(self, head_mask, num_hidden_layers): 309 | if head_mask is None: 310 | return [1] * num_hidden_layers 311 | 312 | if isinstance(head_mask, mx.array) and len(head_mask.shape) == 1: 313 | head_mask = mx.expand_dims(mx.expand_dims(head_mask, axis=0), axis=0) 314 | head_mask = mx.broadcast_to(head_mask, (num_hidden_layers, -1, -1)) 315 | elif isinstance(head_mask, mx.array) and len(head_mask.shape) == 2: 316 | head_mask = mx.expand_dims(mx.expand_dims(head_mask, axis=1), axis=-1) 317 | 318 | return mx.array(head_mask) 319 | 320 | def __call__( 321 | self, 322 | input_ids, 323 | attention_mask=None, 324 | token_type_ids=None, 325 | position_ids=None, 326 | head_mask=None, 327 | output_attentions=False, 328 | output_hidden_states=False, 329 | ): 330 | 331 | input_shape = input_ids.shape 332 | 333 | if attention_mask is None: 334 | attention_mask = mx.ones(input_shape) 335 | if token_type_ids is None: 336 | token_type_ids = mx.zeros(input_shape, dtype=mx.int64) 337 | 338 | extended_attention_mask = self.get_extended_attention_mask( 339 | attention_mask, input_shape 340 | ) 341 | head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) 342 | 343 | embedding_output = self.embeddings(input_ids, token_type_ids, position_ids) 344 | encoder_outputs = self.encoder( 345 | embedding_output, 346 | extended_attention_mask, 347 | output_attentions=output_attentions, 348 | output_hidden_states=output_hidden_states, 349 | ) 350 | sequence_output = encoder_outputs[0] 351 | pooled_output = ( 352 | self.pooler(sequence_output) if self.pooler is not None else None 353 | ) 354 | 355 | # normalized features 356 | text_embeds = mean_pooling(sequence_output, attention_mask) 357 | text_embeds = normalize_embeddings(text_embeds) 358 | 359 | return BaseModelOutput( 360 | last_hidden_state=sequence_output, 361 | text_embeds=text_embeds, 362 | pooler_output=pooled_output, 363 | hidden_states=encoder_outputs[1:], 364 | ) 365 | 366 | def sanitize(self, weights): 367 | sanitized_weights = {} 368 | for k, v in weights.items(): 369 | if "position_ids" in k: 370 | # Remove unused position_ids 371 | 
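# (some HF checkpoints persist position_ids as a buffer; this implementation recomputes them from the input ids in XLMRobertaEmbeddings)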
continue 372 | else: 373 | sanitized_weights[k] = v 374 | return sanitized_weights 375 | -------------------------------------------------------------------------------- /mlx_embeddings/tests/test_models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Tests for `mlx_embeddings` package.""" 4 | import unittest 5 | 6 | import mlx.core as mx 7 | from mlx.utils import tree_map 8 | 9 | 10 | class TestModels(unittest.TestCase): 11 | 12 | def model_test_runner( 13 | self, 14 | model, 15 | model_type, 16 | num_layers, 17 | last_hidden_state_is_sequence=True, 18 | text_embeds_is_sequence=False, 19 | ): 20 | self.assertEqual(model.config.model_type, model_type) 21 | if hasattr(model, "encoder"): 22 | self.assertEqual(len(model.encoder.layer), num_layers) 23 | elif hasattr(model, "model"): 24 | self.assertEqual(len(model.model.layers), num_layers) 25 | 26 | batch_size = 1 27 | seq_length = 5 28 | 29 | model.update(tree_map(lambda p: p.astype(mx.float32), model.parameters())) 30 | 31 | inputs = mx.array([[0, 1, 2, 3, 4]]) 32 | outputs = model(inputs) 33 | self.assertEqual(outputs.last_hidden_state.dtype, mx.float32) 34 | 35 | # Verify last_hidden_state shape 36 | expected_hidden_shape = ( 37 | (batch_size, seq_length, model.config.hidden_size) 38 | if last_hidden_state_is_sequence 39 | else (batch_size, model.config.hidden_size) 40 | ) 41 | self.assertEqual(outputs.last_hidden_state.shape, expected_hidden_shape) 42 | 43 | output_dim = getattr(model.config, "out_features", model.config.hidden_size) 44 | 45 | # Verify text_embeds shape 46 | expected_embeds_shape = ( 47 | (batch_size, seq_length, output_dim) 48 | if text_embeds_is_sequence 49 | else (batch_size, output_dim) 50 | ) 51 | self.assertEqual(outputs.text_embeds.shape, expected_embeds_shape) 52 | self.assertEqual(outputs.text_embeds.dtype, mx.float32) 53 | 54 | def vlm_model_test_runner(self, model, model_type, num_layers): 55 | self.assertEqual(model.config.model_type, model_type) 56 | 57 | # Check layers based on model architecture 58 | if hasattr(model, "encoder"): 59 | self.assertEqual(len(model.encoder.layer), num_layers) 60 | # For SigLIP models, check vision and text layers separately 61 | elif model_type == "siglip": 62 | if ( 63 | hasattr(model, "vision_model") 64 | and hasattr(model.vision_model, "vision_model") 65 | and hasattr(model.vision_model.vision_model, "encoder") 66 | ): 67 | self.assertEqual( 68 | len(model.vision_model.vision_model.encoder.layers), 69 | model.config.vision_config.num_hidden_layers, 70 | ) 71 | if hasattr(model, "text_model") and hasattr(model.text_model, "encoder"): 72 | self.assertEqual( 73 | len(model.text_model.encoder.layers), 74 | model.config.text_config.num_hidden_layers, 75 | ) 76 | 77 | batch_size = 1 78 | seq_length = 5 79 | 80 | # Convert model parameters to float32 for testing 81 | model.update(tree_map(lambda p: p.astype(mx.float32), model.parameters())) 82 | 83 | # Test text-only input if supported 84 | if hasattr(model, "get_text_features") or not hasattr(model, "vision_config"): 85 | text_inputs = mx.array([[0, 1, 2, 3, 4]]) 86 | attention_mask = mx.ones((batch_size, seq_length)) 87 | 88 | if hasattr(model, "get_text_features"): 89 | text_outputs = model.get_text_features(text_inputs, attention_mask) 90 | self.assertIsNotNone(text_outputs) 91 | self.assertEqual(text_outputs.dtype, mx.float32) 92 | else: 93 | text_outputs = model(text_inputs) 94 | self.assertEqual( 95 | text_outputs.last_hidden_state.shape, 96 | 
(batch_size, seq_length, model.config.hidden_size), 97 | ) 98 | self.assertEqual(text_outputs.last_hidden_state.dtype, mx.float32) 99 | 100 | # Test image-only input if supported 101 | if hasattr(model, "vision_config"): 102 | # Get image size from vision config 103 | image_size = model.vision_config.image_size 104 | # Create dummy image tensor [batch_size, height, width, channels] 105 | image_inputs = mx.random.normal((batch_size, 3, image_size, image_size)) 106 | 107 | if hasattr(model, "get_image_features"): 108 | image_outputs = model.get_image_features(image_inputs) 109 | self.assertIsNotNone(image_outputs) 110 | self.assertEqual(image_outputs.dtype, mx.float32) 111 | elif hasattr(model, "encode_image"): 112 | image_outputs = model.encode_image(image_inputs) 113 | self.assertIsNotNone(image_outputs) 114 | self.assertEqual(image_outputs.dtype, mx.float32) 115 | 116 | # Test multimodal input if model supports both text and image 117 | text_inputs = mx.array([[0, 1, 2, 3, 4]]) 118 | attention_mask = mx.ones((batch_size, seq_length)) 119 | image_size = model.config.vision_config.image_size 120 | image_inputs = mx.random.normal((batch_size, 3, image_size, image_size)) 121 | 122 | # Only try this if the model has a method that accepts both inputs 123 | multimodal_outputs = model( 124 | input_ids=text_inputs, 125 | attention_mask=attention_mask, 126 | pixel_values=image_inputs, 127 | ) 128 | 129 | self.assertEqual( 130 | multimodal_outputs.text_model_output[0].shape, 131 | (batch_size, seq_length, model.config.text_config.hidden_size), 132 | ) 133 | self.assertEqual( 134 | multimodal_outputs.vision_model_output[0].shape, 135 | (batch_size, image_size * 2, model.config.vision_config.hidden_size), 136 | ) 137 | self.assertEqual( 138 | multimodal_outputs.text_embeds.shape, 139 | (batch_size, model.config.text_config.hidden_size), 140 | ) 141 | self.assertEqual( 142 | multimodal_outputs.image_embeds.shape, 143 | (batch_size, model.config.vision_config.hidden_size), 144 | ) 145 | self.assertEqual(multimodal_outputs.logits_per_image.shape, (batch_size, 1)) 146 | self.assertEqual(multimodal_outputs.logits_per_text.shape, (batch_size, 1)) 147 | 148 | def test_xlm_roberta_model(self): 149 | from mlx_embeddings.models import xlm_roberta 150 | 151 | config = xlm_roberta.ModelArgs( 152 | model_type="xlm-roberta", 153 | hidden_size=768, 154 | num_hidden_layers=12, 155 | intermediate_size=3072, 156 | num_attention_heads=12, 157 | max_position_embeddings=512, 158 | vocab_size=250002, 159 | ) 160 | model = xlm_roberta.Model(config) 161 | 162 | self.model_test_runner( 163 | model, 164 | config.model_type, 165 | config.num_hidden_layers, 166 | ) 167 | 168 | def test_bert_model(self): 169 | from mlx_embeddings.models import bert 170 | 171 | config = bert.ModelArgs( 172 | model_type="bert", 173 | hidden_size=384, 174 | num_hidden_layers=6, 175 | intermediate_size=1536, 176 | num_attention_heads=12, 177 | max_position_embeddings=512, 178 | vocab_size=30522, 179 | ) 180 | model = bert.Model(config) 181 | 182 | self.model_test_runner( 183 | model, 184 | config.model_type, 185 | config.num_hidden_layers, 186 | ) 187 | 188 | def test_lfm2_model(self): 189 | from mlx_embeddings.models import lfm2 190 | 191 | config = lfm2.ModelArgs( 192 | model_type="lfm2", 193 | hidden_size=1024, 194 | num_hidden_layers=16, 195 | num_attention_heads=16, 196 | num_key_value_heads=8, 197 | max_position_embeddings=128000, 198 | vocab_size=64402, 199 | norm_eps=1e-05, 200 | layer_types=[ 201 | "conv", 202 | "conv", 203 | 
"full_attention", 204 | "conv", 205 | "conv", 206 | "full_attention", 207 | "conv", 208 | "conv", 209 | "full_attention", 210 | "conv", 211 | "full_attention", 212 | "conv", 213 | "full_attention", 214 | "conv", 215 | "full_attention", 216 | "conv", 217 | ], 218 | conv_bias=False, 219 | conv_L_cache=3, 220 | block_dim=1024, 221 | block_ff_dim=6656, 222 | block_multiple_of=256, 223 | block_ffn_dim_multiplier=1.0, 224 | block_auto_adjust_ff_dim=True, 225 | rope_theta=1000000.0, 226 | out_features=128, 227 | ) 228 | model = lfm2.Model(config) 229 | 230 | self.model_test_runner( 231 | model, 232 | config.model_type, 233 | config.num_hidden_layers, 234 | text_embeds_is_sequence=True, 235 | last_hidden_state_is_sequence=True, 236 | ) 237 | 238 | def test_modernbert_model_mask_token(self): 239 | from mlx_embeddings.models import modernbert 240 | 241 | config = modernbert.ModelArgs( 242 | architectures=["ModernBertForMaskedLM"], 243 | model_type="modernbert", 244 | hidden_size=768, 245 | num_hidden_layers=22, 246 | intermediate_size=1152, 247 | num_attention_heads=12, 248 | max_position_embeddings=8192, 249 | vocab_size=50368, 250 | ) 251 | model = modernbert.Model(config) 252 | 253 | self.model_test_runner( 254 | model, 255 | config.model_type, 256 | config.num_hidden_layers, 257 | text_embeds_is_sequence=True, 258 | last_hidden_state_is_sequence=True, 259 | ) 260 | 261 | def test_modernbert_model_embeddings(self): 262 | from mlx_embeddings.models import modernbert 263 | 264 | config = modernbert.ModelArgs( 265 | architectures=["ModernBertModel"], 266 | model_type="modernbert", 267 | hidden_size=768, 268 | num_hidden_layers=22, 269 | intermediate_size=1152, 270 | num_attention_heads=12, 271 | max_position_embeddings=8192, 272 | vocab_size=50368, 273 | ) 274 | model = modernbert.Model(config) 275 | 276 | self.model_test_runner( 277 | model, 278 | config.model_type, 279 | config.num_hidden_layers, 280 | last_hidden_state_is_sequence=False, 281 | text_embeds_is_sequence=False, 282 | ) 283 | 284 | def test_siglip_model(self): 285 | from mlx_embeddings.models import siglip 286 | 287 | config = siglip.ModelArgs( 288 | model_type="siglip", 289 | text_config=siglip.TextConfig( 290 | hidden_size=768, 291 | num_hidden_layers=12, 292 | intermediate_size=3072, 293 | num_attention_heads=12, 294 | max_position_embeddings=512, 295 | vocab_size=250002, 296 | ), 297 | vision_config=siglip.VisionConfig( 298 | hidden_size=768, 299 | num_hidden_layers=12, 300 | intermediate_size=3072, 301 | num_attention_heads=12, 302 | image_size=512, 303 | patch_size=16, 304 | ), 305 | ) 306 | model = siglip.Model(config) 307 | 308 | self.vlm_model_test_runner( 309 | model, 310 | config.model_type, 311 | config.text_config.num_hidden_layers, 312 | ) 313 | 314 | def test_siglip2_model(self): 315 | """Test SigLIP2 with new parameters including num_patches for dynamic resolution.""" 316 | from mlx_embeddings.models import siglip 317 | 318 | # Test SigLIP2 with num_patches specified (new SigLIP2 feature) 319 | config = siglip.ModelArgs( 320 | model_type="siglip", 321 | text_config=siglip.TextConfig( 322 | hidden_size=768, 323 | num_hidden_layers=12, 324 | intermediate_size=3072, 325 | num_attention_heads=12, 326 | max_position_embeddings=64, 327 | vocab_size=32000, 328 | ), 329 | vision_config=siglip.VisionConfig( 330 | hidden_size=768, 331 | num_hidden_layers=12, 332 | intermediate_size=3072, 333 | num_attention_heads=12, 334 | image_size=512, # Same as original SigLIP test 335 | patch_size=16, 336 | num_patches=1024, # SigLIP2 
feature: (512//16)**2 = 1024 337 | max_num_patches=1024, # SigLIP2 naflex feature 338 | ), 339 | ) 340 | model = siglip.Model(config) 341 | 342 | # Test basic functionality 343 | self.vlm_model_test_runner( 344 | model, 345 | config.model_type, 346 | config.text_config.num_hidden_layers, 347 | ) 348 | 349 | # Test SigLIP2-specific features 350 | import mlx.core as mx 351 | 352 | batch_size = 2 353 | image_size = config.vision_config.image_size # Use the config's image_size 354 | seq_len = 64 355 | 356 | # Test with pixel_attention_mask and spatial_shapes (SigLIP2 naflex features) 357 | pixel_values = mx.random.normal((batch_size, image_size, image_size, 3)) 358 | input_ids = mx.array([[1, 2, 3, 4, 5] + [0] * (seq_len - 5)] * batch_size) 359 | attention_mask = mx.ones((batch_size, seq_len)) 360 | 361 | # SigLIP2 specific parameters 362 | pixel_attention_mask = mx.ones((batch_size, image_size, image_size)) 363 | spatial_shapes = mx.array([[image_size, image_size]] * batch_size) 364 | 365 | # Test forward pass with SigLIP2 parameters 366 | outputs = model( 367 | input_ids=input_ids, 368 | pixel_values=pixel_values, 369 | attention_mask=attention_mask, 370 | pixel_attention_mask=pixel_attention_mask, 371 | spatial_shapes=spatial_shapes, 372 | ) 373 | 374 | # Verify outputs have expected shapes 375 | self.assertIsNotNone(outputs.logits_per_image) 376 | self.assertIsNotNone(outputs.logits_per_text) 377 | self.assertEqual(outputs.logits_per_image.shape, (batch_size, batch_size)) 378 | self.assertEqual(outputs.logits_per_text.shape, (batch_size, batch_size)) 379 | 380 | def test_qwen3_model(self): 381 | from mlx_embeddings.models import qwen3 382 | 383 | config = qwen3.ModelArgs( 384 | model_type="qwen3", 385 | hidden_size=1024, 386 | num_hidden_layers=28, 387 | intermediate_size=3072, 388 | num_attention_heads=16, 389 | num_key_value_heads=8, 390 | head_dim=128, 391 | max_position_embeddings=32768, 392 | vocab_size=151669, 393 | rope_theta=1000000, 394 | ) 395 | model = qwen3.Model(config) 396 | 397 | self.model_test_runner( 398 | model, 399 | config.model_type, 400 | config.num_hidden_layers, 401 | ) 402 | 403 | 404 | if __name__ == "__main__": 405 | unittest.main() 406 | -------------------------------------------------------------------------------- /mlx_embeddings/models/qwen3.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | from dataclasses import dataclass, field 4 | from typing import Dict, List, Optional, Union 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | 9 | from .base import BaseModelArgs, BaseModelOutput, normalize_embeddings 10 | 11 | 12 | def last_token_pool( 13 | last_hidden_states: mx.array, attention_mask: Optional[mx.array] = None 14 | ) -> mx.array: 15 | """ 16 | Last token pooling implementation 17 | 18 | Args: 19 | last_hidden_states: Hidden states from the model, shape (batch_size, seq_len, hidden_size) 20 | attention_mask: Attention mask, shape (batch_size, seq_len). If None, uses last position. 
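With left padding (or no padding) the last position of every sequence is valid and is used directly; otherwise the last non-padded position of each sequence is gathered.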
21 | 22 | Returns: 23 | Pooled embeddings, shape (batch_size, hidden_size) 24 | """ 25 | if attention_mask is None: 26 | return last_hidden_states[:, -1] 27 | 28 | # Check if we have left padding (all sequences end with valid tokens) 29 | left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0] 30 | if left_padding: 31 | return last_hidden_states[:, -1] 32 | else: 33 | # Find the last valid token position for each sequence 34 | sequence_lengths = attention_mask.sum(axis=1) - 1 35 | batch_size = last_hidden_states.shape[0] 36 | return last_hidden_states[mx.arange(batch_size), sequence_lengths] 37 | 38 | 39 | @dataclass 40 | class ModelArgs(BaseModelArgs): 41 | # Core architecture 42 | model_type: str = "qwen3" 43 | hidden_size: int = 1024 44 | num_hidden_layers: int = 28 45 | intermediate_size: int = 3072 46 | num_attention_heads: int = 16 47 | num_key_value_heads: Optional[int] = None 48 | head_dim: Optional[int] = None 49 | max_position_embeddings: int = 32768 50 | vocab_size: int = 151669 51 | 52 | # Normalization and regularization 53 | rms_norm_eps: float = 1e-6 54 | attention_dropout: float = 0.0 55 | hidden_dropout: float = 0.0 56 | 57 | # RoPE configuration 58 | rope_theta: float = 1000000.0 59 | rope_scaling: Optional[Dict[str, Union[float, str]]] = None 60 | 61 | # Attention configuration 62 | attention_bias: bool = False 63 | use_sliding_window: bool = False 64 | sliding_window: Optional[int] = None 65 | max_window_layers: Optional[int] = 28 66 | 67 | # Model behavior 68 | tie_word_embeddings: bool = False 69 | hidden_act: str = "silu" 70 | 71 | # Token IDs 72 | bos_token_id: Optional[int] = None 73 | eos_token_id: Optional[int] = None 74 | pad_token_id: Optional[int] = None 75 | 76 | # Architecture variants (for potential future use) 77 | architectures: List[str] = field(default_factory=lambda: ["Qwen3Model"]) 78 | 79 | # Initialization 80 | initializer_range: float = 0.02 81 | 82 | def __post_init__(self): 83 | """Validate and set derived parameters.""" 84 | if self.num_key_value_heads is None: 85 | self.num_key_value_heads = self.num_attention_heads 86 | 87 | if self.head_dim is None: 88 | if self.hidden_size % self.num_attention_heads != 0: 89 | raise ValueError( 90 | f"hidden_size ({self.hidden_size}) must be divisible by " 91 | f"num_attention_heads ({self.num_attention_heads})" 92 | ) 93 | self.head_dim = self.hidden_size // self.num_attention_heads 94 | 95 | 96 | class Qwen3MLP(nn.Module): 97 | """ 98 | Multi-Layer Perceptron (MLP) for Qwen3 with SwiGLU activation. 99 | 100 | Implements the SwiGLU activation function: SiLU(gate_proj(x)) * up_proj(x) 101 | This is a gated activation that has been shown to improve performance 102 | compared to standard activations like ReLU or GELU. 103 | """ 104 | 105 | def __init__(self, config: ModelArgs): 106 | super().__init__() 107 | self.config = config 108 | 109 | # Three linear projections for SwiGLU 110 | self.gate_proj = nn.Linear( 111 | config.hidden_size, config.intermediate_size, bias=False 112 | ) 113 | self.up_proj = nn.Linear( 114 | config.hidden_size, config.intermediate_size, bias=False 115 | ) 116 | self.down_proj = nn.Linear( 117 | config.intermediate_size, config.hidden_size, bias=False 118 | ) 119 | 120 | def __call__(self, x: mx.array) -> mx.array: 121 | """ 122 | Forward pass with SwiGLU activation. 
123 | 124 | Args: 125 | x: Input tensor, shape (..., hidden_size) 126 | 127 | Returns: 128 | Output tensor, shape (..., hidden_size) 129 | """ 130 | # SwiGLU: SiLU(gate_proj(x)) * up_proj(x) 131 | gate = self.gate_proj(x) 132 | up = self.up_proj(x) 133 | return self.down_proj(nn.silu(gate) * up) 134 | 135 | 136 | class Qwen3Attention(nn.Module): 137 | """ 138 | Multi-head attention mechanism for Qwen3 with query/key normalization. 139 | 140 | - Grouped query attention 141 | - Query and key normalization 142 | - RoPE (Rotary Position Embedding) support 143 | - Fallback attention computation 144 | """ 145 | 146 | def __init__(self, config: ModelArgs): 147 | super().__init__() 148 | self.config = config 149 | self.hidden_size = config.hidden_size 150 | self.num_heads = config.num_attention_heads 151 | self.head_dim = config.head_dim 152 | self.num_key_value_heads = config.num_key_value_heads 153 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads 154 | self.max_position_embeddings = config.max_position_embeddings 155 | self.rope_theta = config.rope_theta 156 | 157 | # Validate configuration 158 | if self.hidden_size % self.num_heads != 0: 159 | raise ValueError( 160 | f"hidden_size ({self.hidden_size}) must be divisible by " 161 | f"num_attention_heads ({self.num_heads})" 162 | ) 163 | 164 | if self.num_heads % self.num_key_value_heads != 0: 165 | raise ValueError( 166 | f"num_attention_heads ({self.num_heads}) must be divisible by " 167 | f"num_key_value_heads ({self.num_key_value_heads})" 168 | ) 169 | 170 | # Projection layers 171 | self.q_proj = nn.Linear( 172 | self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias 173 | ) 174 | self.k_proj = nn.Linear( 175 | self.hidden_size, 176 | self.num_key_value_heads * self.head_dim, 177 | bias=config.attention_bias, 178 | ) 179 | self.v_proj = nn.Linear( 180 | self.hidden_size, 181 | self.num_key_value_heads * self.head_dim, 182 | bias=config.attention_bias, 183 | ) 184 | self.o_proj = nn.Linear( 185 | self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias 186 | ) 187 | 188 | # Query and key normalization for training stability 189 | self.q_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps) 190 | self.k_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps) 191 | 192 | # Rotary position embeddings 193 | self.rotary_emb = nn.RoPE( 194 | self.head_dim, 195 | traditional=False, 196 | base=self.rope_theta, 197 | ) 198 | 199 | def __call__( 200 | self, 201 | hidden_states: mx.array, 202 | attention_mask: Optional[mx.array] = None, 203 | **kwargs, 204 | ) -> mx.array: 205 | """ 206 | Forward pass for Qwen3 attention. 
207 | 208 | Args: 209 | hidden_states: Input hidden states, shape (batch_size, seq_len, hidden_size) 210 | attention_mask: Attention mask, shape (batch_size, 1, seq_len, seq_len) 211 | 212 | Returns: 213 | Attention output, shape (batch_size, seq_len, hidden_size) 214 | """ 215 | bsz, q_len, _ = hidden_states.shape 216 | 217 | # Project to query, key, value 218 | query_states = self.q_proj(hidden_states) 219 | key_states = self.k_proj(hidden_states) 220 | value_states = self.v_proj(hidden_states) 221 | 222 | # Reshape for multi-head attention: (batch, seq_len, num_heads, head_dim) 223 | query_states = query_states.reshape( 224 | bsz, q_len, self.num_heads, self.head_dim 225 | ).transpose(0, 2, 1, 3) 226 | key_states = key_states.reshape( 227 | bsz, q_len, self.num_key_value_heads, self.head_dim 228 | ).transpose(0, 2, 1, 3) 229 | value_states = value_states.reshape( 230 | bsz, q_len, self.num_key_value_heads, self.head_dim 231 | ).transpose(0, 2, 1, 3) 232 | 233 | # Apply query and key normalization for training stability 234 | query_states = self.q_norm(query_states) 235 | key_states = self.k_norm(key_states) 236 | 237 | # Apply rotary position embeddings 238 | query_states = self.rotary_emb(query_states) 239 | key_states = self.rotary_emb(key_states) 240 | 241 | # Expand key/value states for grouped query attention if needed 242 | if self.num_key_value_groups > 1: 243 | key_states = mx.repeat(key_states, self.num_key_value_groups, axis=1) 244 | value_states = mx.repeat(value_states, self.num_key_value_groups, axis=1) 245 | 246 | # Compute attention with MLX's scaled_dot_product_attention 247 | scale = 1.0 / math.sqrt(self.head_dim) 248 | 249 | try: 250 | # Use MLX's fast scaled dot product attention with correct signature 251 | attn_output = mx.fast.scaled_dot_product_attention( 252 | query_states, key_states, value_states, scale=scale, mask=attention_mask 253 | ) 254 | except Exception as e: 255 | # Fallback to manual attention computation 256 | logging.warning(f"Fast attention failed, using fallback: {e}") 257 | 258 | attn_weights = (query_states @ key_states.transpose(0, 1, 3, 2)) * scale 259 | 260 | if attention_mask is not None: 261 | attn_weights = attn_weights + attention_mask 262 | 263 | attn_weights = mx.softmax(attn_weights, axis=-1) 264 | attn_output = attn_weights @ value_states 265 | 266 | # Reshape back to (batch_size, seq_len, hidden_size) 267 | attn_output = attn_output.transpose(0, 2, 1, 3).reshape( 268 | bsz, q_len, self.num_heads * self.head_dim 269 | ) 270 | 271 | # Final output projection 272 | attn_output = self.o_proj(attn_output) 273 | 274 | return attn_output 275 | 276 | 277 | class Qwen3DecoderLayer(nn.Module): 278 | """ 279 | Single decoder layer for Qwen3 transformer. 
280 | 281 | Implements the standard transformer decoder layer with: 282 | - Pre-normalization (RMSNorm before attention and MLP) 283 | - Residual connections 284 | - Self-attention mechanism 285 | - Feed-forward network (MLP) 286 | """ 287 | 288 | def __init__(self, config: ModelArgs): 289 | super().__init__() 290 | self.hidden_size = config.hidden_size 291 | 292 | # Self-attention mechanism 293 | self.self_attn = Qwen3Attention(config) 294 | 295 | # Feed-forward network 296 | self.mlp = Qwen3MLP(config) 297 | 298 | # Layer normalization (pre-norm architecture) 299 | self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 300 | self.post_attention_layernorm = nn.RMSNorm( 301 | config.hidden_size, eps=config.rms_norm_eps 302 | ) 303 | 304 | def __call__( 305 | self, 306 | hidden_states: mx.array, 307 | attention_mask: Optional[mx.array] = None, 308 | **kwargs, 309 | ) -> mx.array: 310 | """ 311 | Forward pass for decoder layer. 312 | 313 | Args: 314 | hidden_states: Input hidden states, shape (batch_size, seq_len, hidden_size) 315 | attention_mask: Attention mask for self-attention 316 | 317 | Returns: 318 | Output hidden states, shape (batch_size, seq_len, hidden_size) 319 | """ 320 | # Self-attention with pre-normalization and residual connection 321 | residual = hidden_states 322 | hidden_states = self.input_layernorm(hidden_states) 323 | hidden_states = self.self_attn( 324 | hidden_states, 325 | attention_mask=attention_mask, 326 | ) 327 | hidden_states = residual + hidden_states 328 | 329 | # Feed-forward network with pre-normalization and residual connection 330 | residual = hidden_states 331 | hidden_states = self.post_attention_layernorm(hidden_states) 332 | hidden_states = self.mlp(hidden_states) 333 | hidden_states = residual + hidden_states 334 | 335 | return hidden_states 336 | 337 | 338 | class Qwen3Model(nn.Module): 339 | """ 340 | Qwen3 transformer model 341 | 342 | Full transformer decoder stack with: 343 | - Token embeddings 344 | - Multiple decoder layers 345 | - Final layer normalization 346 | - Causal attention masking 347 | """ 348 | 349 | def __init__(self, config: ModelArgs): 350 | super().__init__() 351 | self.config = config 352 | self.vocab_size = config.vocab_size 353 | self.hidden_size = config.hidden_size 354 | 355 | # Token embeddings 356 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) 357 | 358 | # Decoder layers 359 | self.layers = [ 360 | Qwen3DecoderLayer(config) for _ in range(config.num_hidden_layers) 361 | ] 362 | 363 | # Final layer normalization 364 | self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 365 | 366 | def _create_causal_mask(self, seq_length: int, dtype: mx.Dtype) -> mx.array: 367 | """ 368 | Create a causal attention mask for autoregressive generation. 
369 | 370 | Args: 371 | seq_length: Sequence length 372 | dtype: Data type for the mask 373 | 374 | Returns: 375 | Causal mask of shape (1, 1, seq_length, seq_length) 376 | """ 377 | # Create lower triangular mask (causal mask) 378 | mask = mx.tril(mx.ones((seq_length, seq_length), dtype=mx.bool_)) 379 | # Convert to additive mask (0 for valid positions, -inf for masked) 380 | mask = mx.where(mask, 0.0, -mx.inf).astype(dtype) 381 | # Add batch and head dimensions: (1, 1, seq_len, seq_len) 382 | return mx.expand_dims(mask, axis=(0, 1)) 383 | 384 | def __call__( 385 | self, 386 | input_ids: mx.array, 387 | attention_mask: Optional[mx.array] = None, 388 | **kwargs, 389 | ) -> mx.array: 390 | """ 391 | Forward pass through the model 392 | 393 | Args: 394 | input_ids: Input token IDs, shape (batch_size, seq_len) 395 | attention_mask: Attention mask, shape (batch_size, seq_len) or (batch_size, 1, seq_len, seq_len) 396 | 397 | Returns: 398 | Hidden states, shape (batch_size, seq_len, hidden_size) 399 | """ 400 | batch_size, seq_length = input_ids.shape 401 | 402 | # Get token embeddings 403 | hidden_states = self.embed_tokens(input_ids) 404 | 405 | # Create or process attention mask 406 | if attention_mask is None: 407 | # Create causal mask for autoregressive generation 408 | attention_mask = self._create_causal_mask(seq_length, hidden_states.dtype) 409 | elif attention_mask.ndim == 2: 410 | # Convert padding mask to additive mask and combine with causal mask 411 | # attention_mask shape: (batch_size, seq_len) -> (batch_size, 1, 1, seq_len) 412 | padding_mask = attention_mask[:, None, None, :] 413 | padding_mask = mx.where(padding_mask == 0, -mx.inf, 0.0).astype( 414 | hidden_states.dtype 415 | ) 416 | 417 | # Create causal mask 418 | causal_mask = self._create_causal_mask(seq_length, hidden_states.dtype) 419 | 420 | # Combine masks (broadcast padding mask to match causal mask shape) 421 | attention_mask = causal_mask + padding_mask 422 | 423 | # Apply transformer layers 424 | for layer in self.layers: 425 | hidden_states = layer( 426 | hidden_states, 427 | attention_mask=attention_mask, 428 | ) 429 | 430 | # Apply final layer normalization 431 | hidden_states = self.norm(hidden_states) 432 | 433 | return hidden_states 434 | 435 | 436 | class Model(nn.Module): 437 | """ 438 | Qwen3 model for embedding generation 439 | 440 | The main model class that wraps the core Qwen3Model and adds 441 | embedding-specific functionality like last token pooling and normalization 442 | """ 443 | 444 | def __init__(self, config: ModelArgs): 445 | super().__init__() 446 | self.config = config 447 | self.model_type = config.model_type 448 | 449 | # Core transformer model 450 | self.model = Qwen3Model(config) 451 | 452 | def __call__( 453 | self, 454 | input_ids: mx.array, 455 | attention_mask: Optional[mx.array] = None, 456 | ) -> BaseModelOutput: 457 | """ 458 | Forward pass for embedding generation 459 | 460 | Args: 461 | input_ids: Input token IDs, shape (batch_size, seq_len) 462 | attention_mask: Attention mask, shape (batch_size, seq_len) 463 | 464 | Returns: 465 | BaseModelOutput containing: 466 | - text_embeds: Normalized embeddings from last token pooling 467 | - last_hidden_state: Full sequence hidden states 468 | """ 469 | # Validate inputs 470 | if input_ids.ndim != 2: 471 | raise ValueError(f"input_ids must be 2D, got shape {input_ids.shape}") 472 | 473 | batch_size, seq_len = input_ids.shape 474 | 475 | # Create default attention mask if not provided 476 | if attention_mask is None: 477 | 
attention_mask = mx.ones((batch_size, seq_len), dtype=mx.int32) 478 | elif attention_mask.shape != (batch_size, seq_len): 479 | raise ValueError( 480 | f"attention_mask shape {attention_mask.shape} doesn't match " 481 | f"input_ids shape {input_ids.shape}" 482 | ) 483 | 484 | # Forward pass through the transformer 485 | last_hidden_state = self.model(input_ids, attention_mask=attention_mask) 486 | 487 | # Apply last token pooling for embeddings (best for autoregressive models) 488 | pooled_output = last_token_pool(last_hidden_state, attention_mask) 489 | 490 | # Normalize embeddings for downstream tasks 491 | text_embeds = normalize_embeddings(pooled_output) 492 | 493 | return BaseModelOutput( 494 | text_embeds=text_embeds, last_hidden_state=last_hidden_state 495 | ) 496 | 497 | def sanitize(self, weights: dict) -> dict: 498 | """ 499 | Sanitize weights for loading from different checkpoint formats 500 | 501 | Handles parameter name transformations between different model formats 502 | and ensures compatibility with the MLX model structure 503 | 504 | Args: 505 | weights: Dictionary of model weights 506 | 507 | Returns: 508 | Sanitized weights dictionary 509 | """ 510 | sanitized_weights = {} 511 | 512 | for key, value in weights.items(): 513 | # Skip language model head weights (not used for embeddings) 514 | if "lm_head.weight" in key: 515 | continue 516 | 517 | # Handle different checkpoint formats 518 | new_key = key 519 | 520 | # Map common parameter naming patterns 521 | if key.startswith("transformer."): 522 | # Some checkpoints use "transformer." prefix 523 | new_key = key.replace("transformer.", "model.") 524 | elif key.startswith("model."): 525 | # Already has correct prefix 526 | new_key = key 527 | elif not key.startswith("model.") and "." 
in key: 528 | # Add model prefix for transformer parameters 529 | new_key = f"model.{key}" 530 | else: 531 | # Keep as is for other parameters 532 | new_key = key 533 | 534 | sanitized_weights[new_key] = value 535 | 536 | return sanitized_weights 537 | -------------------------------------------------------------------------------- /mlx_embeddings/models/modernbert.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass, field 3 | from typing import Any, Dict, List, Literal, Optional 4 | 5 | import mlx.core as mx 6 | import mlx.nn as nn 7 | 8 | from .base import BaseModelArgs, BaseModelOutput, mean_pooling, normalize_embeddings 9 | 10 | 11 | @dataclass 12 | class ModelArgs(BaseModelArgs): 13 | model_type: str 14 | vocab_size: int 15 | hidden_size: int 16 | num_hidden_layers: int 17 | intermediate_size: int 18 | num_attention_heads: int 19 | max_position_embeddings: Optional[int] = None 20 | norm_eps: float = 1e-05 21 | norm_bias: bool = False 22 | global_rope_theta: float = 160000.0 23 | attention_bias: bool = False 24 | attention_dropout: float = 0.0 25 | global_attn_every_n_layers: int = 3 26 | local_attention: int = 128 27 | local_rope_theta: float = 10000 28 | embedding_dropout: float = 0.0 29 | mlp_bias: bool = False 30 | mlp_dropout: float = 0.0 31 | pad_token_id = 50283 32 | eos_token_id = 50282 33 | bos_token_id = 50281 34 | cls_token_id = 50281 35 | sep_token_id = 50282 36 | output_hidden_states: bool = False 37 | use_return_dict: bool = True 38 | tie_word_embeddings: bool = True 39 | architectures: List[str] = field(default_factory=lambda: ["ModernBertForMaskedLM"]) 40 | 41 | # pipeline args, mostly for classification 42 | # not directly related to this project but consistent with original ModernBERT implementation 43 | # may be useful for future pipelines 44 | decoder_bias = (True,) 45 | classifier_pooling: Literal["cls", "mean"] = "mean" 46 | classifier_dropout = 0.0 # for Sequence Classification 47 | classifier_bias = False # for Sequence Classification 48 | sparse_prediction = True # for MLM 49 | sparse_pred_ignore_index = -100 # for MLM 50 | is_regression: Optional[bool] = None # for Sequence Classification 51 | label2id: Optional[Dict[str, int]] = None # for Sequence Classification 52 | id2label: Optional[Dict[int, str]] = None # for Sequence Classification 53 | pipeline_config: Optional[Dict[str, Any]] = None # for Sequence Classification 54 | 55 | @property 56 | def num_labels(self) -> int: # for Sequence Classification 57 | """ 58 | Number of labels is determined by: 59 | - For zero-shot classification: length of label_candidates 60 | - For regression or binary with sigmoid: 1 61 | - For classification: length of id2label mapping 62 | """ 63 | 64 | if self.is_regression: 65 | return 1 66 | 67 | if self.pipeline_config and self.pipeline_config.get("binary_sigmoid", False): 68 | return 1 69 | 70 | if self.id2label is None: 71 | raise ValueError( 72 | "id2label mapping must be provided for categorical classification. " 73 | "For regression or binary classification with sigmoid output, " 74 | "set is_regression=True or binary_sigmoid=True in pipeline_config." 75 | ) 76 | 77 | return len(self.id2label) 78 | 79 | 80 | class ModernBertEmbeddings(nn.Module): 81 | """ 82 | Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. 
83 | """ 84 | 85 | def __init__(self, config: ModelArgs): 86 | super().__init__() 87 | self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) 88 | self.norm = nn.LayerNorm( 89 | config.hidden_size, eps=config.norm_eps, bias=config.norm_bias 90 | ) 91 | self.drop = nn.Dropout(p=config.embedding_dropout) 92 | 93 | def __call__(self, input_ids): 94 | embeddings = self.tok_embeddings(input_ids) 95 | embeddings = self.norm(embeddings) 96 | embeddings = self.drop(embeddings) 97 | return embeddings 98 | 99 | 100 | class ModernBertMLP(nn.Module): 101 | """Applies the GLU at the end of each ModernBERT layer. 102 | 103 | Compared to the default BERT architecture, this block replaces class BertIntermediate` 104 | and class SelfOutput with a single module that has similar functionality. 105 | """ 106 | 107 | def __init__(self, config: ModelArgs): 108 | super().__init__() 109 | self.config = config 110 | self.Wi = nn.Linear( 111 | config.hidden_size, config.intermediate_size * 2, bias=config.mlp_bias 112 | ) 113 | self.act = nn.GELU(approx="precise") 114 | self.drop = nn.Dropout(p=config.mlp_dropout) 115 | self.Wo = nn.Linear( 116 | int(config.intermediate_size), config.hidden_size, bias=config.mlp_bias 117 | ) 118 | 119 | def __call__(self, hidden_states): 120 | x = self.Wi(hidden_states) 121 | # Implementing chunk operation 122 | split_dim = x.shape[-1] // 2 123 | input, gate = x[:, :, :split_dim], x[:, :, split_dim:] 124 | return self.Wo(self.drop(self.act(input) * gate)) 125 | 126 | 127 | class ModernBertAttention(nn.Module): 128 | """Performs multi-headed self attention on a batch of unpadded sequences. 129 | For now, only supports the Scaled Dot-Product Attention (SDPA) implementation. 130 | """ 131 | 132 | def __init__(self, config: ModelArgs, layer_id: Optional[int] = None): 133 | super().__init__() 134 | self.config = config 135 | self.layer_id = layer_id 136 | 137 | if config.hidden_size % config.num_attention_heads != 0: 138 | raise ValueError( 139 | f"hidden_size ({config.hidden_size}) must be divisible by num_attention_heads ({config.num_attention_heads})" 140 | ) 141 | 142 | self.attention_dropout = config.attention_dropout 143 | self.num_heads = config.num_attention_heads 144 | self.head_dim = config.hidden_size // config.num_attention_heads 145 | self.scale = self.head_dim**-0.5 146 | self.all_head_size = self.head_dim * self.num_heads 147 | self.Wqkv = nn.Linear( 148 | config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias 149 | ) 150 | 151 | if layer_id % config.global_attn_every_n_layers != 0: 152 | self.local_attention = ( 153 | config.local_attention // 2, 154 | config.local_attention // 2, 155 | ) 156 | else: 157 | self.local_attention = (-1, -1) 158 | 159 | rope_theta = config.global_rope_theta 160 | if self.local_attention != (-1, -1) and config.local_rope_theta is not None: 161 | rope_theta = config.local_rope_theta 162 | 163 | self.rotary_emb = nn.RoPE(dims=self.head_dim, base=rope_theta) 164 | 165 | self.Wo = nn.Linear( 166 | config.hidden_size, config.hidden_size, bias=config.attention_bias 167 | ) 168 | self.out_drop = ( 169 | nn.Dropout(p=config.attention_dropout) 170 | if config.attention_dropout > 0.0 171 | else nn.Identity() 172 | ) 173 | self.pruned_heads = set() 174 | 175 | def __call__( 176 | self, 177 | hidden_states, 178 | attention_mask=None, 179 | sliding_window_mask=None, 180 | position_ids=None, 181 | **kwargs, 182 | ): 183 | qkv = self.Wqkv(hidden_states) 184 | bs = hidden_states.shape[0] 185 | qkv = mx.reshape(qkv, (bs, -1, 
3, self.num_heads, self.head_dim)) 186 | 187 | # Get attention outputs using SDPA 188 | qkv = mx.transpose( 189 | qkv, [0, 3, 2, 1, 4] 190 | ) # [batch_size, nheads, 3, seqlen, headdim] 191 | query, key, value = mx.split( 192 | qkv, indices_or_sections=3, axis=2 193 | ) # each [batch_size, nheads, 1, seqlen, headdim] 194 | 195 | query = query.squeeze(2) # [batch_size, nheads, seqlen, headdim] 196 | key = key.squeeze(2) # [batch_size, nheads, seqlen, headdim] 197 | value = value.squeeze(2) # [batch_size, nheads, seqlen, headdim] 198 | 199 | # Applying rotary embeddings 200 | query = self.rotary_emb(query) 201 | key = self.rotary_emb(key) 202 | 203 | # Handling local attention if needed 204 | if self.local_attention != (-1, -1): 205 | attention_mask = sliding_window_mask 206 | 207 | attn_output = mx.fast.scaled_dot_product_attention( 208 | query, key, value, scale=self.scale, mask=attention_mask 209 | ) 210 | 211 | # Reshaping and apply output projection 212 | attn_output = mx.transpose(attn_output, [0, 2, 1, 3]) 213 | attn_output = mx.reshape(attn_output, (bs, -1, self.all_head_size)) 214 | 215 | # Applying output projection and dropout 216 | hidden_states = self.Wo(attn_output) 217 | hidden_states = self.out_drop(hidden_states) 218 | 219 | return (hidden_states,) 220 | 221 | 222 | class ModernBertEncoderLayer(nn.Module): 223 | def __init__(self, config: ModelArgs, layer_id: Optional[int] = None): 224 | super().__init__() 225 | self.config = config 226 | if layer_id == 0: 227 | self.attn_norm = nn.Identity() 228 | else: 229 | self.attn_norm = nn.LayerNorm( 230 | config.hidden_size, eps=config.norm_eps, bias=config.norm_bias 231 | ) 232 | self.attn = ModernBertAttention(config=config, layer_id=layer_id) 233 | self.mlp = ModernBertMLP(config) 234 | self.mlp_norm = nn.LayerNorm( 235 | config.hidden_size, eps=config.norm_eps, bias=config.norm_bias 236 | ) 237 | 238 | def __call__( 239 | self, 240 | hidden_states, 241 | attention_mask=None, 242 | sliding_window_mask=None, 243 | position_ids=None, 244 | ): 245 | normalized_hidden_states = self.attn_norm(hidden_states) 246 | attention_output = self.attn( 247 | normalized_hidden_states, 248 | attention_mask=attention_mask, 249 | sliding_window_mask=sliding_window_mask, 250 | position_ids=position_ids, 251 | ) 252 | hidden_states = hidden_states + attention_output[0] 253 | mlp_output = self.mlp(self.mlp_norm(hidden_states)) 254 | hidden_states = hidden_states + mlp_output 255 | 256 | return (hidden_states,) 257 | 258 | 259 | class ModernBertModel(nn.Module): 260 | def __init__(self, config: ModelArgs): 261 | super().__init__() 262 | self.config = config 263 | self.embeddings = ModernBertEmbeddings(config) 264 | self.layers = [ 265 | ModernBertEncoderLayer(config, i) for i in range(config.num_hidden_layers) 266 | ] 267 | self.final_norm = nn.LayerNorm( 268 | config.hidden_size, eps=config.norm_eps, bias=config.norm_bias 269 | ) 270 | self.gradient_checkpointing = False ### TBC 271 | 272 | def get_input_embeddings(self) -> ModernBertEmbeddings: 273 | return self.embeddings.tok_embeddings 274 | 275 | def set_input_embeddings(self, value): 276 | self.embeddings.tok_embeddings = value 277 | 278 | def __call__( 279 | self, 280 | input_ids, 281 | attention_mask=None, # shape: (batch_size, seq_len), updated with _update_attention_mask below 282 | sliding_window_mask=None, 283 | position_ids=None, 284 | output_hidden_states: Optional[bool] = False, 285 | return_dict: Optional[bool] = True, 286 | ): 287 | output_hidden_states = ( 288 | output_hidden_states 
289 | if output_hidden_states is not None 290 | else self.config.output_hidden_states 291 | ) 292 | return_dict = ( 293 | return_dict if return_dict is not None else self.config.use_return_dict 294 | ) 295 | 296 | all_hidden_states = () if output_hidden_states else None 297 | 298 | batch_size, seq_len = input_ids.shape[:2] 299 | 300 | if position_ids is None: 301 | position_ids = mx.arange(seq_len, dtype=mx.int32) 302 | position_ids = mx.repeat(position_ids, batch_size, axis=0) 303 | 304 | # get attention mask and sliding window mask 305 | attention_mask, sliding_window_mask = self._update_attention_mask( 306 | attention_mask=attention_mask, 307 | ) 308 | 309 | hidden_states = self.embeddings(input_ids) 310 | 311 | for encoder_layer in self.layers: 312 | if output_hidden_states: 313 | all_hidden_states = all_hidden_states + (hidden_states,) 314 | 315 | layer_outputs = encoder_layer( 316 | hidden_states, 317 | attention_mask=attention_mask, 318 | sliding_window_mask=sliding_window_mask, 319 | position_ids=position_ids, 320 | ) 321 | 322 | hidden_states = layer_outputs[0] 323 | 324 | hidden_states = self.final_norm(hidden_states) 325 | 326 | if not return_dict: 327 | return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) 328 | return { 329 | "last_hidden_state": hidden_states, 330 | "hidden_states": all_hidden_states, 331 | } 332 | 333 | def _update_attention_mask(self, attention_mask): 334 | dtype = attention_mask.dtype 335 | batch_size, seq_len = attention_mask.shape 336 | 337 | additive_mask = mx.where(attention_mask == 1, 0.0, -1e9) 338 | additive_mask = additive_mask[:, None, None, :] 339 | 340 | # Create the causal mask for global attention 341 | # (1, 1, seq_len, seq_len) 342 | global_attention_mask = mx.broadcast_to( 343 | additive_mask, (batch_size, 1, seq_len, seq_len) 344 | ) 345 | 346 | # Create position indices for sliding window 347 | rows = mx.arange(seq_len) 348 | rows = rows[None, :] # (1, seq_len) 349 | # Calculate position-wise distances 350 | distance = mx.abs(rows - rows.T) # (seq_len, seq_len) 351 | 352 | # Create sliding window mask using mx.where 353 | window_mask = mx.where( 354 | distance <= (self.config.local_attention // 2), 355 | mx.ones_like(distance), 356 | mx.zeros_like(distance), 357 | ) 358 | 359 | # Expand dimensions using None indexing 360 | window_mask = window_mask[None, None, :, :] # (1, 1, seq_len, seq_len) 361 | 362 | # Broadcast to match batch size 363 | window_mask = mx.broadcast_to(window_mask, global_attention_mask.shape) 364 | 365 | # Creating sliding window attention mask 366 | # Replacing non-window positions with large negative value 367 | sliding_window_mask = mx.where(window_mask, global_attention_mask, -1e9) 368 | 369 | return global_attention_mask.astype(dtype), sliding_window_mask.astype(dtype) 370 | 371 | 372 | class ModernBertPredictionHead(nn.Module): 373 | def __init__(self, config: ModelArgs): 374 | super().__init__() 375 | self.config = config 376 | self.dense = nn.Linear( 377 | config.hidden_size, config.hidden_size, config.classifier_bias 378 | ) 379 | self.act = nn.GELU(approx="precise") 380 | self.norm = nn.LayerNorm( 381 | config.hidden_size, eps=config.norm_eps, bias=config.norm_bias 382 | ) 383 | 384 | def __call__(self, hidden_states: mx.array) -> mx.array: 385 | return self.norm(self.act(self.dense(hidden_states))) 386 | 387 | 388 | # classes for specific pipelines 389 | class Model(nn.Module): 390 | """ 391 | Computes pooled, normalized embeddings for input sequences using a ModernBERT model. 
392 | 393 | Note: sanitization is a hack to align with the other models here while downloading weights 394 | with the masked-LM config from HF (the original ModernBERT model). 395 | The decoder.bias is ignored here. 396 | """ 397 | 398 | def __init__(self, config: ModelArgs): 399 | super().__init__() 400 | self.config = config 401 | self.model = ModernBertModel(config) 402 | if config.architectures == ["ModernBertForMaskedLM"]: 403 | self.head = ModernBertPredictionHead(config) 404 | self.decoder = nn.Linear( 405 | config.hidden_size, config.vocab_size, bias=config.decoder_bias 406 | ) 407 | elif config.architectures == ["ModernBertForSequenceClassification"]: 408 | self.num_labels = config.num_labels 409 | self.is_regression = config.is_regression 410 | self.head = ModernBertPredictionHead(config) 411 | self.drop = nn.Dropout(p=config.classifier_dropout) 412 | self.classifier = nn.Linear( 413 | config.hidden_size, 414 | config.num_labels, 415 | bias=True, ### bias=config.classifier_bias removed because mismatch with HF checkpoint 416 | ) 417 | 418 | def _process_outputs(self, logits: mx.array) -> mx.array: 419 | """Apply the appropriate activation function to the logits for classification tasks.""" 420 | if self.is_regression: 421 | return logits # No activation for regression 422 | elif self.num_labels == 1: 423 | return mx.sigmoid(logits) # Binary classification 424 | else: 425 | # Using softmax for multi-class classification 426 | return mx.softmax(logits, axis=-1) 427 | 428 | def __call__( 429 | self, 430 | input_ids: mx.array, 431 | attention_mask: Optional[mx.array] = None, 432 | position_ids: Optional[mx.array] = None, 433 | return_dict: Optional[bool] = False, 434 | ): 435 | 436 | if attention_mask is None: 437 | batch_size, seq_len = input_ids.shape 438 | attention_mask = mx.ones( 439 | (batch_size, seq_len), 440 | dtype=self.model.embeddings.tok_embeddings.weight.dtype, 441 | ) 442 | 443 | # Get embeddings and encoder outputs 444 | encoder_outputs = self.model( 445 | input_ids, 446 | attention_mask=attention_mask, 447 | position_ids=position_ids, 448 | output_hidden_states=None, # only last_hidden_state is returned 449 | return_dict=return_dict, 450 | ) 451 | last_hidden_state = ( 452 | encoder_outputs["last_hidden_state"] 453 | if isinstance(encoder_outputs, dict) 454 | else encoder_outputs[0] 455 | ) 456 | 457 | # Pooling strategy using config 458 | if self.config.architectures != ["ModernBertForMaskedLM"]: 459 | if self.config.classifier_pooling == "cls": 460 | last_hidden_state = last_hidden_state[:, 0] 461 | elif self.config.classifier_pooling == "mean": 462 | last_hidden_state = mean_pooling(last_hidden_state, attention_mask) 463 | else: 464 | raise ValueError( 465 | f"Invalid pooling strategy: {self.config.classifier_pooling}" 466 | ) 467 | 468 | pooled_output = None 469 | if self.config.architectures == ["ModernBertForMaskedLM"]: 470 | pooled_output = self.head(last_hidden_state) 471 | pooled_output = self.decoder(pooled_output) 472 | elif self.config.architectures == ["ModernBertForSequenceClassification"]: 473 | pooled_output = self.head(last_hidden_state) 474 | pooled_output = self.drop(pooled_output) 475 | pooled_output = self.classifier(pooled_output) 476 | pooled_output = self._process_outputs(pooled_output) 477 | 478 | # normalized features 479 | text_embeds = normalize_embeddings(last_hidden_state) 480 | 481 | return BaseModelOutput( 482 | last_hidden_state=last_hidden_state, 483 | text_embeds=text_embeds, 484 | pooler_output=pooled_output, 485 |
hidden_states=encoder_outputs[1:], 486 | ) 487 | 488 | def sanitize(self, weights): 489 | sanitized_weights = {} 490 | for k, v in weights.items(): 491 | if ( 492 | self.config.architectures != ["ModernBertForMaskedLM"] 493 | and self.config.architectures != ["ModernBertForSequenceClassification"] 494 | and not k.startswith("model") 495 | ): 496 | new_key = "model." + k 497 | sanitized_weights[new_key] = v 498 | elif ( 499 | self.config.tie_word_embeddings 500 | and "decoder.bias" in k 501 | and "decoder.biases" not in k 502 | ): 503 | sanitized_weights["decoder.bias"] = v 504 | sanitized_weights["decoder.weight"] = weights[ 505 | "model.embeddings.tok_embeddings.weight" 506 | ] 507 | else: 508 | sanitized_weights[k] = v 509 | return sanitized_weights 510 | 511 | 512 | class ModelSentenceTransformers(Model): 513 | """ 514 | Different sanitization method for sentence-transformers checkpoints. 515 | """ 516 | 517 | def __init__(self, config): 518 | super().__init__(config) 519 | 520 | def sanitize(self, weights): 521 | """Convert sentence transformer weights to ModernBERT format.""" 522 | sanitized_weights = {} 523 | 524 | for k, v in weights.items(): 525 | new_key = "model." + k 526 | sanitized_weights[new_key] = v 527 | return sanitized_weights 528 | -------------------------------------------------------------------------------- /mlx_embeddings/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 2 | 3 | import copy 4 | import glob 5 | import importlib 6 | import json 7 | import logging 8 | import re 9 | import shutil 10 | from pathlib import Path 11 | from textwrap import dedent 12 | from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union 13 | 14 | import mlx.core as mx 15 | import mlx.nn as nn 16 | from huggingface_hub import snapshot_download 17 | from huggingface_hub.errors import RepositoryNotFoundError 18 | from mlx.utils import tree_flatten, tree_unflatten 19 | from mlx_vlm.utils import process_image 20 | from transformers import AutoProcessor, PreTrainedTokenizer 21 | 22 | from .tokenizer_utils import TokenizerWrapper, load_tokenizer 23 | 24 | # Constants 25 | MODEL_REMAPPING = {} 26 | 27 | MAX_FILE_SIZE_GB = 5 28 | 29 | 30 | class ModelNotFoundError(Exception): 31 | def __init__(self, message): 32 | self.message = message 33 | super().__init__(self.message) 34 | 35 | 36 | def _get_classes(config: dict): 37 | """ 38 | Retrieve the model and model args classes based on the configuration. 39 | 40 | Args: 41 | config (dict): The model configuration. 42 | 43 | Returns: 44 | A tuple containing the Model class and the ModelArgs class. 45 | """ 46 | model_type = config["model_type"].replace("-", "_") 47 | model_type = MODEL_REMAPPING.get(model_type, model_type) 48 | try: 49 | arch = importlib.import_module(f"mlx_embeddings.models.{model_type}") 50 | except ImportError: 51 | msg = f"Model type {model_type} not supported." 52 | logging.error(msg) 53 | raise ValueError(msg) 54 | 55 | if hasattr(arch, "TextConfig") and hasattr(arch, "VisionConfig"): 56 | return arch.Model, arch.ModelArgs, arch.TextConfig, arch.VisionConfig 57 | 58 | return arch.Model, arch.ModelArgs, None, None 59 | 60 | 61 | def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path: 62 | """ 63 | Ensures the model is available locally. If the path does not exist locally, 64 | it is downloaded from the Hugging Face Hub.
65 | 66 | Args: 67 | path_or_hf_repo (str): The local path or Hugging Face repository ID of the model. 68 | revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash. 69 | 70 | Returns: 71 | Path: The path to the model. 72 | """ 73 | model_path = Path(path_or_hf_repo) 74 | if not model_path.exists(): 75 | try: 76 | model_path = Path( 77 | snapshot_download( 78 | repo_id=path_or_hf_repo, 79 | revision=revision, 80 | allow_patterns=[ 81 | "*.json", 82 | "*.safetensors", 83 | "*.py", 84 | "*.tiktoken", 85 | "*.txt", 86 | "*.model", 87 | ], 88 | ) 89 | ) 90 | except RepositoryNotFoundError: 91 | raise ModelNotFoundError( 92 | f"Model not found for path or HF repo: {path_or_hf_repo}.\n" 93 | "Please make sure you specified the local path or Hugging Face" 94 | " repo id correctly.\nIf you are trying to access a private or" 95 | " gated Hugging Face repo, make sure you are authenticated:\n" 96 | "https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-login" 97 | ) from None 98 | return model_path 99 | 100 | 101 | def load_config(model_path: Path) -> dict: 102 | try: 103 | with open(model_path / "config.json", "r") as f: 104 | config = json.load(f) 105 | except FileNotFoundError: 106 | logging.error(f"Config file not found in {model_path}") 107 | raise 108 | return config 109 | 110 | 111 | def load_model( 112 | model_path: Path, 113 | lazy: bool = False, 114 | model_config: dict = {}, 115 | get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes, 116 | **kwargs, 117 | ) -> nn.Module: 118 | """ 119 | Load and initialize the model from a given path. 120 | 121 | Args: 122 | model_path (Path): The path to load the model from. 123 | lazy (bool): If False eval the model parameters to make sure they are 124 | loaded in memory before returning, otherwise they will be loaded 125 | when needed. Default: ``False`` 126 | model_config (dict, optional): Configuration parameters for the model. 127 | Defaults to an empty dictionary. 128 | get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional): 129 | A function that returns the model class and model args class given a config. 130 | Defaults to the _get_classes function. 131 | 132 | Returns: 133 | nn.Module: The loaded and initialized model. 134 | 135 | Raises: 136 | FileNotFoundError: If the weight files (.safetensors) are not found. 137 | ValueError: If the model class or args class are not found or cannot be instantiated. 
138 | """ 139 | 140 | config = load_config(model_path) 141 | config.update(model_config) 142 | 143 | weight_files = glob.glob(str(model_path / "**/model*.safetensors"), recursive=True) 144 | 145 | if not weight_files: 146 | # Try weight for back-compat 147 | weight_files = glob.glob(str(model_path / "weight*.safetensors")) 148 | 149 | if not weight_files: 150 | logging.error(f"No safetensors found in {model_path}") 151 | raise FileNotFoundError(f"No safetensors found in {model_path}") 152 | 153 | weights = {} 154 | for wf in weight_files: 155 | loaded_weights = mx.load(wf) 156 | if Path(wf).parent != model_path: 157 | folder_name = Path(wf).parent.name 158 | renamed_weights = {} 159 | for key, value in loaded_weights.items(): 160 | new_key = f"{folder_name}.{key}" 161 | renamed_weights[new_key] = value 162 | weights.update(renamed_weights) 163 | else: 164 | weights.update(loaded_weights) 165 | 166 | model_class, model_args_class, text_config, vision_config = get_model_classes( 167 | config=config 168 | ) 169 | 170 | model_args = model_args_class.from_dict(config) 171 | 172 | if text_config is not None: 173 | model_args.text_config = text_config(**model_args.text_config) 174 | if vision_config is not None: 175 | model_args.vision_config = vision_config(**model_args.vision_config) 176 | 177 | # siglip models have a different image size 178 | if "siglip" in config["model_type"]: 179 | # Extract the image size 180 | image_size = re.search( 181 | r"patch\d+-(\d+)(?:-|$)", kwargs["path_to_repo"] 182 | ).group(1) 183 | # Extract the patch size 184 | patch_size = re.search(r"patch(\d+)", kwargs["path_to_repo"]).group(1) 185 | patch_size = ( 186 | re.search(r"\d+", patch_size).group() 187 | if re.search(r"\d+", patch_size) 188 | else patch_size 189 | ) 190 | if model_args.vision_config.image_size != int(image_size): 191 | model_args.vision_config.image_size = int(image_size) 192 | if model_args.vision_config.patch_size != int(patch_size): 193 | model_args.vision_config.patch_size = int(patch_size) 194 | 195 | model = model_class(model_args) 196 | 197 | if hasattr(model, "sanitize"): 198 | weights = model.sanitize(weights) 199 | 200 | if (quantization := config.get("quantization", None)) is not None: 201 | # Handle legacy models which may not have everything quantized 202 | def class_predicate(p, m): 203 | if not hasattr(m, "to_quantized"): 204 | return False 205 | return f"{p}.scales" in weights 206 | 207 | nn.quantize( 208 | model, 209 | **quantization, 210 | class_predicate=class_predicate, 211 | ) 212 | 213 | model.load_weights(list(weights.items())) 214 | 215 | if not lazy: 216 | mx.eval(model.parameters()) 217 | 218 | model.eval() 219 | return model 220 | 221 | 222 | def load( 223 | path_or_hf_repo: str, 224 | tokenizer_config={}, 225 | model_config={}, 226 | adapter_path: Optional[str] = None, 227 | lazy: bool = False, 228 | ) -> Tuple[nn.Module, TokenizerWrapper]: 229 | """ 230 | Load the model and tokenizer from a given path or a huggingface repository. 231 | 232 | Args: 233 | path_or_hf_repo (Path): The path or the huggingface repository to load the model from. 234 | tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer. 235 | Defaults to an empty dictionary. 236 | model_config(dict, optional): Configuration parameters specifically for the model. 237 | Defaults to an empty dictionary. 238 | adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers 239 | to the model. Default: ``None``. 
240 | lazy (bool): If False eval the model parameters to make sure they are 241 | loaded in memory before returning, otherwise they will be loaded 242 | when needed. Default: ``False`` 243 | Returns: 244 | Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer. 245 | 246 | Raises: 247 | FileNotFoundError: If config file or safetensors are not found. 248 | ValueError: If model class or args class are not found. 249 | """ 250 | model_path = get_model_path(path_or_hf_repo) 251 | 252 | model = load_model(model_path, lazy, model_config, path_to_repo=path_or_hf_repo) 253 | 254 | # Try to load tokenizer first, then fall back to processor if needed 255 | tokenizer = None 256 | 257 | # First attempt: load tokenizer 258 | try: 259 | if hasattr(model.config, "vision_config"): 260 | tokenizer = AutoProcessor.from_pretrained(model_path) 261 | else: 262 | tokenizer = load_tokenizer(model_path, tokenizer_config) 263 | except Exception as tokenizer_error: 264 | raise ValueError( 265 | f"Failed to initialize tokenizer or processor: {tokenizer_error}" 266 | ) from tokenizer_error 267 | 268 | return model, tokenizer 269 | 270 | 271 | def fetch_from_hub( 272 | model_path: Path, lazy: bool = False, **kwargs 273 | ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]: 274 | model = load_model(model_path, lazy, **kwargs) 275 | config = load_config(model_path) 276 | tokenizer = load_tokenizer(model_path) 277 | return model, config, tokenizer 278 | 279 | 280 | def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list: 281 | """ 282 | Splits the weights into smaller shards. 283 | 284 | Args: 285 | weights (dict): Model weights. 286 | max_file_size_gb (int): Maximum size of each shard in gigabytes. 287 | 288 | Returns: 289 | list: List of weight shards. 290 | """ 291 | max_file_size_bytes = max_file_size_gb << 30 292 | shards = [] 293 | shard, shard_size = {}, 0 294 | for k, v in weights.items(): 295 | if shard_size + v.nbytes > max_file_size_bytes: 296 | shards.append(shard) 297 | shard, shard_size = {}, 0 298 | shard[k] = v 299 | shard_size += v.nbytes 300 | shards.append(shard) 301 | return shards 302 | 303 | 304 | def upload_to_hub(path: str, upload_repo: str, hf_path: str, config: dict): 305 | """ 306 | Uploads the model to Hugging Face hub. 307 | 308 | Args: 309 | path (str): Local path to the model. 310 | upload_repo (str): Name of the HF repo to upload to. 311 | hf_path (str): Path to the original Hugging Face model. 312 | """ 313 | import os 314 | 315 | from huggingface_hub import HfApi, ModelCard, logging 316 | 317 | from . 
import __version__ 318 | 319 | # Determine appropriate example code based on model type 320 | if config.get("vision_config", None) is None: 321 | # Text-only model 322 | text_example = """ 323 | # For text embeddings 324 | output = generate(model, processor, texts=["I like grapes", "I like fruits"]) 325 | embeddings = output.text_embeds # Normalized embeddings 326 | 327 | # Compute dot product between normalized embeddings 328 | similarity_matrix = mx.matmul(embeddings, embeddings.T) 329 | 330 | print("Similarity matrix between texts:") 331 | print(similarity_matrix) 332 | """ 333 | 334 | if config.get("architectures", None) == ["ModernBertForMaskedLM"]: 335 | text_example = """ 336 | # For masked language modeling 337 | output = generate(model, processor, texts=["The capital of France is [MASK]."])\n 338 | mask_index = processor.encode("[MASK]", add_special_tokens=False)[0]\n 339 | predicted_token_id = mx.argmax(output.logits[0, mask_index], axis=-1)\n 340 | predicted_token = processor.decode([predicted_token_id.item()]) 341 | """ 342 | 343 | response = text_example 344 | else: 345 | # Vision-text model 346 | response = """ 347 | # For image-text embeddings 348 | images = [ 349 | "./images/cats.jpg", # cats 350 | ] 351 | texts = ["a photo of cats", "a photo of a desktop setup", "a photo of a person"] 352 | 353 | # Process all image-text pairs 354 | outputs = generate(model, processor, texts, images=images) 355 | logits_per_image = outputs.logits_per_image 356 | probs = mx.sigmoid(logits_per_image) # probabilities for this image 357 | for i, image in enumerate(images): 358 | print(f"Image {i+1}:") 359 | for j, text in enumerate(texts): 360 | print(f" {probs[i][j]:.1%} match with '{text}'") 361 | print() 362 | """ 363 | 364 | card = ModelCard.load(hf_path) 365 | card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"] 366 | card.text = dedent( 367 | f""" 368 | # {upload_repo} 369 | 370 | The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) using mlx-embeddings version **{__version__}**.
371 | 372 | ## Use with mlx 373 | 374 | ```bash 375 | pip install mlx-embeddings 376 | ``` 377 | 378 | ```python 379 | from mlx_embeddings import load, generate 380 | import mlx.core as mx 381 | 382 | model, tokenizer = load("{upload_repo}") 383 | {response} 384 | 385 | ``` 386 | """ 387 | ) 388 | card.save(os.path.join(path, "README.md")) 389 | 390 | logging.set_verbosity_info() 391 | 392 | api = HfApi() 393 | api.create_repo(repo_id=upload_repo, exist_ok=True) 394 | api.upload_folder(folder_path=path, repo_id=upload_repo, repo_type="model") 395 | print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.") 396 | 397 | 398 | def save_weights( 399 | save_path: Union[str, Path], 400 | weights: Dict[str, Any], 401 | *, 402 | donate_weights: bool = False, 403 | ) -> None: 404 | """Save model weights into specified directory.""" 405 | if isinstance(save_path, str): 406 | save_path = Path(save_path) 407 | save_path.mkdir(parents=True, exist_ok=True) 408 | 409 | shards = make_shards(weights) 410 | shards_count = len(shards) 411 | shard_file_format = ( 412 | "model-{:05d}-of-{:05d}.safetensors" 413 | if shards_count > 1 414 | else "model.safetensors" 415 | ) 416 | 417 | total_size = sum(v.nbytes for v in weights.values()) 418 | index_data = {"metadata": {"total_size": total_size}, "weight_map": {}} 419 | 420 | # Write the weights and make sure no references are kept other than the 421 | # necessary ones 422 | if donate_weights: 423 | weights.clear() 424 | del weights 425 | 426 | for i in range(len(shards)): 427 | shard = shards[i] 428 | shards[i] = None 429 | shard_name = shard_file_format.format(i + 1, shards_count) 430 | shard_path = save_path / shard_name 431 | 432 | mx.save_safetensors(str(shard_path), shard, metadata={"format": "mlx"}) 433 | 434 | for weight_name in shard.keys(): 435 | index_data["weight_map"][weight_name] = shard_name 436 | del shard 437 | 438 | index_data["weight_map"] = { 439 | k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"]) 440 | } 441 | 442 | with open(save_path / "model.safetensors.index.json", "w") as f: 443 | json.dump( 444 | index_data, 445 | f, 446 | indent=4, 447 | ) 448 | 449 | 450 | def get_class_predicate(skip_vision: bool, q_group_size: int, weights: dict = None): 451 | """ 452 | Returns a predicate function for quantization that handles vision model skipping 453 | and dimension compatibility checks. 454 | """ 455 | 456 | def class_predicate(p, m): 457 | # Must have a to_quantized method 458 | if not hasattr(m, "to_quantized"): 459 | return False 460 | 461 | # Optionally skip vision model layers 462 | if skip_vision and ("vision_model" in p or "vision_tower" in p): 463 | return False 464 | 465 | # Check for weight attribute and dimension compatibility 466 | if hasattr(m, "weight"): 467 | if m.weight.ndim < 2 or m.weight.shape[-1] % q_group_size != 0: 468 | print( 469 | f"Skipping quantization of {p}:" 470 | f" Last dimension {m.weight.shape[-1]} is not divisible by group size {q_group_size}." 471 | ) 472 | return False 473 | 474 | # Check against a whitelist of weights if provided 475 | if weights: 476 | return p in weights 477 | 478 | return True 479 | 480 | return class_predicate 481 | 482 | 483 | def quantize_model( 484 | model: nn.Module, 485 | config: dict, 486 | q_group_size: int, 487 | q_bits: int, 488 | skip_vision: bool = True, 489 | ) -> Tuple: 490 | """ 491 | Applies quantization to the model weights. 492 | 493 | Args: 494 | model (nn.Module): The model to be quantized. 
495 | config (dict): Model configuration. 496 | q_group_size (int): Group size for quantization. 497 | q_bits (int): Bits per weight for quantization. 498 | 499 | Returns: 500 | Tuple: Tuple containing quantized weights and config. 501 | """ 502 | quantized_config = copy.deepcopy(config) 503 | nn.quantize( 504 | model, 505 | q_group_size, 506 | q_bits, 507 | class_predicate=get_class_predicate( 508 | skip_vision=skip_vision, q_group_size=q_group_size 509 | ), 510 | ) 511 | quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits} 512 | quantized_weights = dict(tree_flatten(model.parameters())) 513 | 514 | return quantized_weights, quantized_config 515 | 516 | 517 | def save_config( 518 | config: dict, 519 | config_path: Union[str, Path], 520 | ) -> None: 521 | """Save the model configuration to the ``config_path``. 522 | 523 | The final configuration will be sorted before saving for better readability. 524 | 525 | Args: 526 | config (dict): The model configuration. 527 | config_path (Union[str, Path]): Model configuration file path. 528 | """ 529 | # Clean unused keys 530 | config.pop("_name_or_path", None) 531 | 532 | # sort the config for better readability 533 | config = dict(sorted(config.items())) 534 | 535 | # write the updated config to the config_path (if provided) 536 | with open(config_path, "w") as fid: 537 | json.dump(config, fid, indent=4) 538 | 539 | 540 | def dequantize_model(model: nn.Module) -> nn.Module: 541 | """ 542 | Dequantize the quantized linear layers in the model. 543 | 544 | Args: 545 | model (nn.Module): The model with quantized linear layers. 546 | 547 | Returns: 548 | nn.Module: The model with dequantized layers. 549 | """ 550 | de_quantize_layers = [] 551 | for name, module in model.named_modules(): 552 | if isinstance(module, nn.QuantizedLinear): 553 | bias = "bias" in module 554 | weight = module.weight 555 | weight = mx.dequantize( 556 | weight, 557 | module.scales, 558 | module.biases, 559 | module.group_size, 560 | module.bits, 561 | ).astype(mx.float16) 562 | output_dims, input_dims = weight.shape 563 | linear = nn.Linear(input_dims, output_dims, bias=bias) 564 | linear.weight = weight 565 | if bias: 566 | linear.bias = module.bias 567 | de_quantize_layers.append((name, linear)) 568 | if isinstance(module, nn.QuantizedEmbedding): 569 | weight = mx.dequantize( 570 | module.weight, 571 | module.scales, 572 | module.biases, 573 | module.group_size, 574 | module.bits, 575 | ).astype(mx.float16) 576 | num_embeddings, dims = weight.shape 577 | emb = nn.Embedding(num_embeddings, dims) 578 | emb.weight = weight 579 | de_quantize_layers.append((name, emb)) 580 | 581 | if len(de_quantize_layers) > 0: 582 | model.update_modules(tree_unflatten(de_quantize_layers)) 583 | return model 584 | 585 | 586 | def convert( 587 | hf_path: str, 588 | mlx_path: str = "mlx_model", 589 | quantize: bool = False, 590 | q_group_size: int = 64, 591 | q_bits: int = 4, 592 | dtype: str = "float16", 593 | upload_repo: str = None, 594 | revision: Optional[str] = None, 595 | dequantize: bool = False, 596 | skip_vision: bool = True, 597 | ): 598 | print("[INFO] Loading") 599 | model_path = get_model_path(hf_path, revision=revision) 600 | model, config, tokenizer = fetch_from_hub( 601 | model_path, lazy=True, path_to_repo=hf_path 602 | ) 603 | 604 | weights = dict(tree_flatten(model.parameters())) 605 | dtype = getattr(mx, dtype) 606 | weights = {k: v.astype(dtype) for k, v in weights.items()} 607 | 608 | if quantize and dequantize: 609 | raise ValueError("Choose 
either quantize or dequantize, not both.") 610 | 611 | if quantize: 612 | print("[INFO] Quantizing") 613 | model.load_weights(list(weights.items())) 614 | weights, config = quantize_model( 615 | model, config, q_group_size, q_bits, skip_vision=skip_vision 616 | ) 617 | 618 | if dequantize: 619 | print("[INFO] Dequantizing") 620 | model = dequantize_model(model) 621 | weights = dict(tree_flatten(model.parameters())) 622 | 623 | if isinstance(mlx_path, str): 624 | mlx_path = Path(mlx_path) 625 | 626 | del model 627 | save_weights(mlx_path, weights, donate_weights=True) 628 | 629 | # Copy Python and JSON files from the model path to the MLX path 630 | for pattern in ["*.py", "*.json"]: 631 | files = glob.glob(str(model_path / pattern)) 632 | for file in files: 633 | shutil.copy(file, mlx_path) 634 | 635 | tokenizer.save_pretrained(mlx_path) 636 | 637 | save_config(config, config_path=mlx_path / "config.json") 638 | 639 | if upload_repo is not None: 640 | upload_to_hub(mlx_path, upload_repo, hf_path, config) 641 | 642 | 643 | def load_images(images, processor, resize_shape=None): 644 | image_processor = ( 645 | processor.image_processor if hasattr(processor, "image_processor") else None 646 | ) 647 | if isinstance(images, str): 648 | images = [process_image(images, resize_shape, image_processor)] 649 | elif isinstance(images, list): 650 | images = [ 651 | process_image(image, resize_shape, image_processor) for image in images 652 | ] 653 | else: 654 | raise ValueError(f"Unsupported image type: {type(images)}") 655 | return images 656 | 657 | 658 | def prepare_inputs( 659 | processor, images, texts, max_length, padding, truncation, resize_shape=None 660 | ): 661 | # Preprocess image-text embeddings 662 | if images is not None: 663 | images = load_images(images, processor, resize_shape=resize_shape) 664 | inputs = processor( 665 | text=texts, images=images, padding="max_length", return_tensors="mlx" 666 | ) 667 | 668 | # Preprocess text embeddings 669 | elif isinstance(texts, str): 670 | inputs = processor.encode(texts, return_tensors="mlx") 671 | elif isinstance(texts, list): 672 | inputs = processor.batch_encode_plus( 673 | texts, 674 | return_tensors="mlx", 675 | padding=padding, 676 | truncation=truncation, 677 | max_length=max_length, 678 | ) 679 | else: 680 | raise ValueError(f"Unsupported input type: {type(texts)}") 681 | 682 | return inputs 683 | 684 | 685 | def generate( 686 | model: nn.Module, 687 | processor: Union[PreTrainedTokenizer, TokenizerWrapper, AutoProcessor], 688 | texts: Union[str, List[str]], 689 | images: Union[str, mx.array, List[str], List[mx.array]] = None, 690 | max_length: int = 512, 691 | padding: bool = True, 692 | truncation: bool = True, 693 | **kwargs, 694 | ) -> mx.array: 695 | """ 696 | Generate embeddings for input text(s) using the provided model and tokenizer. 697 | 698 | Args: 699 | model (nn.Module): The MLX model for generating embeddings. 700 | tokenizer (TokenizerWrapper): The tokenizer for preprocessing text. 701 | texts (Union[str, List[str]]): A single text string or a list of text strings. 702 | 703 | Returns: 704 | mx.array: The generated embeddings. 
705 | """ 706 | 707 | resize_shape = kwargs.get("resize_shape", None) 708 | inputs = prepare_inputs( 709 | processor, images, texts, max_length, padding, truncation, resize_shape 710 | ) 711 | 712 | # Generate embeddings 713 | if isinstance(inputs, mx.array): 714 | outputs = model(inputs) 715 | else: 716 | outputs = model(**inputs) 717 | 718 | return outputs 719 | -------------------------------------------------------------------------------- /mlx_embeddings/models/siglip.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from dataclasses import dataclass 3 | from typing import Optional 4 | 5 | import mlx.core as mx 6 | import mlx.nn as nn 7 | import numpy as np 8 | 9 | from .base import ViTModelOutput, normalize_embeddings 10 | 11 | 12 | @dataclass 13 | class TextConfig: 14 | vocab_size: int = 32000 15 | max_position_embeddings: int = 64 16 | hidden_size: int = 768 17 | intermediate_size: int = 3072 18 | num_attention_heads: int = 12 19 | num_hidden_layers: int = 12 20 | layer_norm_eps: float = 1e-6 21 | projection_size: Optional[int] = None 22 | model_type: str = "siglip_text_model" 23 | 24 | def __post_init__(self): 25 | if self.projection_size is None: 26 | self.projection_size = self.hidden_size 27 | 28 | 29 | @dataclass 30 | class VisionConfig: 31 | image_size: int = 224 32 | patch_size: int = 16 33 | num_channels: int = 3 34 | hidden_size: int = 768 35 | intermediate_size: int = 3072 36 | num_attention_heads: int = 12 37 | num_hidden_layers: int = 12 38 | layer_norm_eps: float = 1e-6 39 | model_type: str = "siglip_vision_model" 40 | vision_use_head: bool = True 41 | # SigLIP2 parameters 42 | num_patches: Optional[int] = None # For SigLIP2, defaults to 256 43 | max_num_patches: Optional[int] = None # For naflex variants 44 | 45 | 46 | @dataclass 47 | class ModelArgs: 48 | text_config: TextConfig 49 | vision_config: VisionConfig 50 | model_type: str = "siglip" 51 | output_hidden_states: bool = False 52 | output_attentions: bool = False 53 | use_return_dict: bool = True 54 | num_labels: int = 0 55 | 56 | @classmethod 57 | def from_dict(cls, params): 58 | return cls( 59 | **{ 60 | k: v 61 | for k, v in params.items() 62 | if k in inspect.signature(cls).parameters 63 | } 64 | ) 65 | 66 | 67 | def check_array_shape(arr): 68 | shape = arr.shape 69 | 70 | # Check if the shape has 4 dimensions 71 | if len(shape) != 4: 72 | return False 73 | 74 | out_channels, kH, KW, _ = shape 75 | 76 | # Check if out_channels is the largest, and kH and KW are the same 77 | if (out_channels >= kH) and (out_channels >= KW) and (kH == KW): 78 | return True 79 | else: 80 | return False 81 | 82 | 83 | class MHA(nn.Module): 84 | def __init__( 85 | self, 86 | dims: int, 87 | num_heads: int, 88 | bias: bool = False, 89 | ): 90 | super().__init__() 91 | 92 | if (dims % num_heads) != 0: 93 | raise ValueError( 94 | "The input feature dimensions should be divisible by the " 95 | f"number of heads ({dims} % {num_heads}) != 0" 96 | ) 97 | 98 | self.num_heads = num_heads 99 | head_dim = dims // num_heads 100 | self.scale = head_dim**-0.5 101 | 102 | self.in_proj = nn.Linear(dims, dims * 3, bias=bias) 103 | self.out_proj = nn.Linear(dims, dims, bias=bias) 104 | 105 | def __call__(self, queries: mx.array, keys: mx.array, values: mx.array, mask=None): 106 | B, L, D = queries.shape 107 | 108 | qkv = self.in_proj(keys) 109 | queries, keys, values = mx.split(qkv, 3, axis=-1) 110 | 111 | num_heads = self.num_heads 112 | B, L, D = queries.shape 113 | _, S, _ = keys.shape 
114 | queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) 115 | keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) 116 | values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) 117 | 118 | output = mx.fast.scaled_dot_product_attention( 119 | queries, keys, values, scale=self.scale, mask=mask 120 | ) 121 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 122 | return self.out_proj(output) 123 | 124 | 125 | class Attention(nn.Module): 126 | def __init__( 127 | self, 128 | dims: int, 129 | num_heads: int, 130 | query_input_dims: Optional[int] = None, 131 | key_input_dims: Optional[int] = None, 132 | value_input_dims: Optional[int] = None, 133 | value_dims: Optional[int] = None, 134 | value_output_dims: Optional[int] = None, 135 | bias: bool = False, 136 | ): 137 | super().__init__() 138 | 139 | if (dims % num_heads) != 0: 140 | raise ValueError( 141 | "The input feature dimensions should be divisible by the " 142 | f"number of heads ({dims} % {num_heads}) != 0" 143 | ) 144 | 145 | query_input_dims = query_input_dims or dims 146 | key_input_dims = key_input_dims or dims 147 | value_input_dims = value_input_dims or key_input_dims 148 | value_dims = value_dims or dims 149 | value_output_dims = value_output_dims or dims 150 | 151 | self.num_heads = num_heads 152 | head_dim = dims // num_heads 153 | self.scale = head_dim**-0.5 154 | 155 | self.q_proj = nn.Linear(query_input_dims, dims, bias=bias) 156 | self.k_proj = nn.Linear(key_input_dims, dims, bias=bias) 157 | self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias) 158 | self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias) 159 | 160 | def __call__(self, x, mask=None): 161 | queries = self.q_proj(x) 162 | keys = self.k_proj(x) 163 | values = self.v_proj(x) 164 | 165 | num_heads = self.num_heads 166 | B, L, D = queries.shape 167 | _, S, _ = keys.shape 168 | queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) 169 | keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) 170 | values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) 171 | 172 | # Process attention mask for multi-head attention if provided 173 | if mask is not None: 174 | if mask.ndim == 2: 175 | # mask shape: (batch_size, seq_len) -> (batch_size, 1, 1, seq_len) 176 | mask = mask[:, None, None, :] 177 | elif mask.ndim == 3: 178 | # mask shape: (batch_size, seq_len, seq_len) -> (batch_size, 1, seq_len, seq_len) 179 | mask = mask[:, None, :, :] 180 | # For boolean masks, convert to additive mask 181 | if mask.dtype == mx.bool_: 182 | # Convert boolean mask to additive mask (True -> 0.0, False -> -inf) 183 | mask = mx.where(mask, 0.0, -mx.inf) 184 | 185 | output = mx.fast.scaled_dot_product_attention( 186 | queries, keys, values, scale=self.scale, mask=mask 187 | ) 188 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 189 | return self.out_proj(output) 190 | 191 | 192 | class MLP(nn.Module): 193 | def __init__(self, config: ModelArgs, approx: str = "none"): 194 | super().__init__() 195 | self.activation_fn = nn.GELU(approx=approx) 196 | self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=True) 197 | self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=True) 198 | 199 | def __call__(self, x: mx.array) -> mx.array: 200 | x = self.fc1(x) 201 | x = self.activation_fn(x) 202 | x = self.fc2(x) 203 | return x 204 | 205 | 206 | class EncoderLayer(nn.Module): 207 | def __init__(self, config: ModelArgs, approx: str = "none"): 208 | 
super().__init__() 209 | self.embed_dim = config.hidden_size 210 | self.self_attn = Attention( 211 | config.hidden_size, config.num_attention_heads, bias=True 212 | ) 213 | self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) 214 | self.mlp = MLP(config, approx=approx) 215 | self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) 216 | 217 | def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array: 218 | 219 | r = self.self_attn(self.layer_norm1(x), mask) 220 | h = x + r 221 | r = self.mlp(self.layer_norm2(h)) 222 | return h + r 223 | 224 | 225 | class Encoder(nn.Module): 226 | def __init__(self, config: ModelArgs, approx: str = "none"): 227 | super().__init__() 228 | self.layers = [ 229 | EncoderLayer(config, approx=approx) for _ in range(config.num_hidden_layers) 230 | ] 231 | 232 | def __call__( 233 | self, 234 | x: mx.array, 235 | output_hidden_states: Optional[bool] = None, 236 | mask: Optional[mx.array] = None, 237 | ) -> mx.array: 238 | encoder_states = (x,) if output_hidden_states else None 239 | for l in self.layers: 240 | x = l(x, mask=mask) 241 | if output_hidden_states: 242 | encoder_states = encoder_states + (x,) 243 | 244 | return (x, encoder_states) 245 | 246 | 247 | class VisionEmbeddings(nn.Module): 248 | def __init__(self, config: ModelArgs): 249 | super().__init__() 250 | self.config = config 251 | self.embed_dim = config.hidden_size 252 | self.image_size = config.image_size 253 | self.patch_size = config.patch_size 254 | 255 | self.patch_embedding = nn.Conv2d( 256 | in_channels=config.num_channels, 257 | out_channels=self.embed_dim, 258 | kernel_size=self.patch_size, 259 | stride=self.patch_size, 260 | ) 261 | 262 | # For SigLIP2, use num_patches if provided, otherwise calculate from image_size 263 | if config.num_patches is not None: 264 | self.num_patches = config.num_patches 265 | else: 266 | self.num_patches = (self.image_size // self.patch_size) ** 2 267 | self.num_positions = self.num_patches 268 | self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) 269 | 270 | def interpolate_pos_encoding( 271 | self, embeddings: mx.array, height: int, width: int 272 | ) -> mx.array: 273 | # TODO: Implement this 274 | raise NotImplementedError( 275 | "Interpolation of positional encodings is not implemented for SigLIP" 276 | ) 277 | 278 | def __call__( 279 | self, 280 | x: mx.array, 281 | interpolate_pos_encoding: bool = False, 282 | pixel_attention_mask: Optional[mx.array] = None, 283 | ) -> mx.array: 284 | _, _, height, width = x.shape 285 | patch_embeddings = self.patch_embedding(x) 286 | patch_embeddings = mx.transpose(patch_embeddings, (0, 3, 1, 2)) 287 | patch_embeddings = mx.flatten(patch_embeddings, start_axis=2, end_axis=3) 288 | patch_embeddings = mx.transpose(patch_embeddings, (0, 2, 1)) 289 | 290 | # Handle variable sequence length for SigLIP2 naflex variants 291 | batch_size, seq_len, embed_dim = patch_embeddings.shape 292 | 293 | # If we have fewer patches than expected, pad to num_positions 294 | if seq_len < self.num_positions: 295 | padding_size = self.num_positions - seq_len 296 | padding = mx.zeros((batch_size, padding_size, embed_dim)) 297 | patch_embeddings = mx.concatenate([patch_embeddings, padding], axis=1) 298 | elif seq_len > self.num_positions: 299 | # Truncate if we have more patches than expected 300 | patch_embeddings = patch_embeddings[:, : self.num_positions, :] 301 | 302 | position_ids = mx.array(np.arange(self.num_positions)[None, :]) 303 | embeddings = 
patch_embeddings 304 | if interpolate_pos_encoding: 305 | embeddings = self.interpolate_pos_encoding(embeddings, height, width) 306 | else: 307 | embeddings += self.position_embedding(position_ids) 308 | 309 | return embeddings 310 | 311 | 312 | class SiglipMultiheadAttentionPoolingHead(nn.Module): 313 | """Multihead Attention Pooling.""" 314 | 315 | def __init__(self, config: VisionConfig, approx: str = "none"): 316 | super().__init__() 317 | 318 | self.probe = mx.ones((1, 1, config.hidden_size)) 319 | self.attention = MHA(config.hidden_size, config.num_attention_heads, bias=True) 320 | self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 321 | self.mlp = MLP(config, approx=approx) 322 | 323 | def __call__(self, hidden_state): 324 | batch_size = hidden_state.shape[0] 325 | # Repeat the probe for each item in the batch 326 | # mx.repeat only takes one axis at a time, so we need to do it sequentially 327 | probe = mx.repeat(self.probe, batch_size, axis=0) 328 | 329 | hidden_state = self.attention(probe, hidden_state, hidden_state) 330 | 331 | residual = hidden_state 332 | hidden_state = self.layernorm(hidden_state) 333 | hidden_state = residual + self.mlp(hidden_state) 334 | 335 | return hidden_state[:, 0] 336 | 337 | 338 | class SiglipVisionTransformer(nn.Module): 339 | def __init__(self, config: VisionConfig): 340 | super().__init__() 341 | self.embeddings = VisionEmbeddings(config) 342 | self.encoder = Encoder(config, approx="precise") 343 | self.post_layernorm = nn.LayerNorm( 344 | config.hidden_size, eps=config.layer_norm_eps 345 | ) 346 | self.use_head = ( 347 | True if not hasattr(config, "vision_use_head") else config.vision_use_head 348 | ) 349 | 350 | if self.use_head: 351 | self.head = SiglipMultiheadAttentionPoolingHead(config, approx="precise") 352 | 353 | def __call__( 354 | self, 355 | pixel_values: mx.array, 356 | output_hidden_states: Optional[bool] = None, 357 | interpolate_pos_encoding: bool = False, 358 | pixel_attention_mask: Optional[mx.array] = None, 359 | spatial_shapes: Optional[mx.array] = None, 360 | ) -> mx.array: 361 | x = self.embeddings( 362 | pixel_values, 363 | interpolate_pos_encoding=interpolate_pos_encoding, 364 | pixel_attention_mask=pixel_attention_mask, 365 | ) 366 | 367 | # For SigLIP2, we accept pixel_attention_mask but don't process it yet 368 | # This maintains API compatibility while keeping the original behavior 369 | attention_mask = None 370 | 371 | x, encoder_outputs = self.encoder( 372 | x=x, output_hidden_states=output_hidden_states, mask=attention_mask 373 | ) 374 | 375 | x = self.post_layernorm(x) 376 | pooler_output = self.head(x) if self.use_head else None 377 | 378 | if output_hidden_states: 379 | return x, pooler_output, encoder_outputs[1:] 380 | else: 381 | return x, pooler_output 382 | 383 | 384 | class SiglipVisionModel(nn.Module): 385 | def __init__(self, config: ModelArgs): 386 | super().__init__() 387 | self.model_type = config.model_type 388 | if self.model_type not in ["siglip_vision_model"]: 389 | raise ValueError(f"Unsupported model type: {self.model_type}") 390 | 391 | self.vision_model = SiglipVisionTransformer(config) 392 | 393 | def __call__( 394 | self, 395 | pixel_values: mx.array, 396 | output_hidden_states: Optional[bool] = None, 397 | interpolate_pos_encoding: bool = False, 398 | pixel_attention_mask: Optional[mx.array] = None, 399 | spatial_shapes: Optional[mx.array] = None, 400 | ) -> mx.array: 401 | return self.vision_model( 402 | pixel_values, 403 | output_hidden_states, 404 | 
interpolate_pos_encoding, 405 | pixel_attention_mask, 406 | spatial_shapes, 407 | ) 408 | 409 | def sanitize(self, weights): 410 | sanitized_weights = {} 411 | for k, v in weights.items(): 412 | if "position_ids" in k: 413 | # Remove unused position_ids 414 | continue 415 | if "patch_embedding.weight" in k: 416 | # PyTorch conv2d weight tensors have shape: 417 | # [out_channels, in_channels, kH, KW] 418 | # MLX conv2d expects the weight be of shape: 419 | # [out_channels, kH, KW, in_channels] 420 | if check_array_shape(v): 421 | sanitized_weights[k] = v 422 | else: 423 | sanitized_weights[k] = v.transpose(0, 2, 3, 1) 424 | else: 425 | sanitized_weights[k] = v 426 | 427 | return sanitized_weights 428 | 429 | 430 | class SiglipTextEmbeddings(nn.Module): 431 | def __init__(self, config: TextConfig): 432 | super().__init__() 433 | embed_dim = config.hidden_size 434 | 435 | self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) 436 | self.position_embedding = nn.Embedding( 437 | config.max_position_embeddings, embed_dim 438 | ) 439 | 440 | def __call__( 441 | self, 442 | input_ids: mx.array, 443 | position_ids: mx.array, 444 | inputs_embeds: Optional[mx.array] = None, 445 | ) -> mx.array: 446 | seq_length = ( 447 | input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] 448 | ) 449 | max_position_embedding = self.position_embedding.weight.shape[0] 450 | 451 | if seq_length > max_position_embedding: 452 | raise ValueError( 453 | f"Sequence length must be less than max_position_embeddings (got `sequence length`: " 454 | f"{seq_length} and max_position_embeddings: {max_position_embedding}" 455 | ) 456 | 457 | if position_ids is None: 458 | position_ids = mx.array(np.arange(seq_length)[None, :]) 459 | 460 | if inputs_embeds is None: 461 | inputs_embeds = self.token_embedding(input_ids) 462 | 463 | position_embeddings = self.position_embedding(position_ids) 464 | embeddings = inputs_embeds + position_embeddings 465 | 466 | return embeddings 467 | 468 | def sanitize(self, weights): 469 | sanitized_weights = {} 470 | for k, v in weights.items(): 471 | if "position_ids" in k: 472 | # Remove unused position_ids 473 | continue 474 | else: 475 | sanitized_weights[k] = v 476 | return sanitized_weights 477 | 478 | 479 | class SiglipTextTransformer(nn.Module): 480 | def __init__(self, config: TextConfig): 481 | super().__init__() 482 | self.config = config 483 | embed_dim = config.hidden_size 484 | self.embeddings = SiglipTextEmbeddings(config) 485 | self.encoder = Encoder(config, approx="precise") 486 | self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) 487 | 488 | self.head = nn.Linear(embed_dim, config.projection_size) 489 | 490 | def __call__( 491 | self, 492 | input_ids: mx.array, 493 | attention_mask: Optional[mx.array] = None, 494 | position_ids: Optional[mx.array] = None, 495 | output_hidden_states: Optional[bool] = None, 496 | ) -> mx.array: 497 | 498 | input_shape = input_ids.shape 499 | input_ids = input_ids.reshape(-1, input_shape[-1]) 500 | x = self.embeddings(input_ids, position_ids) 501 | x, encoder_states = self.encoder(x, output_hidden_states, attention_mask) 502 | x = self.final_layer_norm(x) 503 | 504 | # Assuming "sticky" EOS tokenization, last token is always EOS. 
505 | pooled_output = x[:, -1, :] 506 | pooled_output = self.head(pooled_output) 507 | 508 | if output_hidden_states: 509 | return x, pooled_output, encoder_states[1:] 510 | else: 511 | return x, pooled_output 512 | 513 | 514 | class SiglipTextModel(nn.Module): 515 | def __init__(self, config: ModelArgs): 516 | super().__init__() 517 | self.model_type = config.model_type 518 | if self.model_type not in ["siglip_text_model"]: 519 | raise ValueError(f"Unsupported model type: {self.model_type}") 520 | 521 | self.text_model = SiglipTextTransformer(config) 522 | 523 | def __call__( 524 | self, 525 | input_ids: mx.array, 526 | attention_mask: Optional[mx.array] = None, 527 | position_ids: Optional[mx.array] = None, 528 | output_hidden_states: Optional[bool] = None, 529 | ) -> mx.array: 530 | return self.text_model( 531 | input_ids, attention_mask, position_ids, output_hidden_states 532 | ) 533 | 534 | 535 | class Model(nn.Module): 536 | def __init__(self, config: ModelArgs): 537 | super().__init__() 538 | self.config = config 539 | vision_config = config.vision_config 540 | self.vision_model = SiglipVisionModel(vision_config) 541 | 542 | if config.num_labels > 0: 543 | # Classifier head 544 | self.classifier = ( 545 | nn.Linear(config.vision_config.hidden_size, config.num_labels) 546 | if config.num_labels > 0 547 | else nn.Identity() 548 | ) 549 | else: 550 | text_config = config.text_config 551 | self.text_model = SiglipTextModel(text_config) 552 | self.logit_scale = mx.zeros((1,)) 553 | self.logit_bias = mx.zeros((1,)) 554 | 555 | def get_text_features( 556 | self, 557 | input_ids: Optional[mx.array] = None, 558 | attention_mask: Optional[mx.array] = None, 559 | position_ids: Optional[mx.array] = None, 560 | output_hidden_states: Optional[bool] = None, 561 | ) -> mx.array: 562 | 563 | # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. 564 | output_hidden_states = ( 565 | output_hidden_states 566 | if output_hidden_states is not None 567 | else self.config.output_hidden_states 568 | ) 569 | 570 | text_outputs = self.text_model( 571 | input_ids=input_ids, 572 | attention_mask=attention_mask, 573 | position_ids=position_ids, 574 | output_hidden_states=output_hidden_states, 575 | ) 576 | 577 | pooled_output = text_outputs[1] 578 | if output_hidden_states: 579 | return pooled_output, text_outputs[2] 580 | else: 581 | return pooled_output 582 | 583 | def get_image_features( 584 | self, 585 | pixel_values: Optional[mx.array] = None, 586 | output_attentions: Optional[bool] = None, 587 | output_hidden_states: Optional[bool] = None, 588 | return_dict: Optional[bool] = None, 589 | interpolate_pos_encoding: bool = False, 590 | pixel_attention_mask: Optional[mx.array] = None, 591 | spatial_shapes: Optional[mx.array] = None, 592 | ) -> mx.array: 593 | 594 | if pixel_values is not None: 595 | dtype = ( 596 | self.vision_model.vision_model.embeddings.patch_embedding.weight.dtype 597 | ) 598 | if not isinstance(pixel_values, mx.array): 599 | pixel_values = mx.array(pixel_values) 600 | 601 | # MLX Conv2d expects NHWC, but processors usually output NCHW. 602 | # If the input looks like NCHW, we transpose it. 603 | if pixel_values.ndim == 4 and pixel_values.shape[1] == 3: 604 | pixel_values = pixel_values.transpose(0, 2, 3, 1) 605 | 606 | pixel_values = pixel_values.astype(dtype) 607 | 608 | # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components. 
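        # Note: `output_attentions` and `return_dict` are resolved below for parity with the Hugging Face API, but they are not used further in this method.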
609 | output_attentions = ( 610 | output_attentions 611 | if output_attentions is not None 612 | else self.config.output_attentions 613 | ) 614 | output_hidden_states = ( 615 | output_hidden_states 616 | if output_hidden_states is not None 617 | else self.config.output_hidden_states 618 | ) 619 | return_dict = ( 620 | return_dict if return_dict is not None else self.config.use_return_dict 621 | ) 622 | 623 | vision_outputs = self.vision_model( 624 | pixel_values=pixel_values, 625 | output_hidden_states=output_hidden_states, 626 | interpolate_pos_encoding=interpolate_pos_encoding, 627 | pixel_attention_mask=pixel_attention_mask, 628 | spatial_shapes=spatial_shapes, 629 | ) 630 | 631 | pooled_output = vision_outputs[1] 632 | 633 | return pooled_output 634 | 635 | def __call__( 636 | self, 637 | input_ids: Optional[mx.array] = None, 638 | pixel_values: Optional[mx.array] = None, 639 | attention_mask: Optional[mx.array] = None, 640 | position_ids: Optional[mx.array] = None, 641 | output_hidden_states: Optional[bool] = None, 642 | interpolate_pos_encoding: bool = False, 643 | pixel_attention_mask: Optional[mx.array] = None, 644 | spatial_shapes: Optional[mx.array] = None, 645 | ) -> mx.array: 646 | if pixel_values is not None: 647 | dtype = ( 648 | self.vision_model.vision_model.embeddings.patch_embedding.weight.dtype 649 | ) 650 | if not isinstance(pixel_values, mx.array): 651 | pixel_values = mx.array(pixel_values) 652 | 653 | # MLX Conv2d expects NHWC, but processors usually output NCHW. 654 | # If the input looks like NCHW, we transpose it. 655 | if pixel_values.ndim == 4 and pixel_values.shape[1] == 3: 656 | pixel_values = pixel_values.transpose(0, 2, 3, 1) 657 | 658 | pixel_values = pixel_values.astype(dtype) 659 | 660 | # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. 
661 | output_hidden_states = ( 662 | output_hidden_states 663 | if output_hidden_states is not None 664 | else self.config.output_hidden_states 665 | ) 666 | 667 | vision_outputs = self.vision_model( 668 | pixel_values=pixel_values, 669 | output_hidden_states=output_hidden_states, 670 | interpolate_pos_encoding=interpolate_pos_encoding, 671 | pixel_attention_mask=pixel_attention_mask, 672 | spatial_shapes=spatial_shapes, 673 | ) 674 | 675 | image_embeds = vision_outputs[1] 676 | 677 | # classifier head 678 | logits = None 679 | if self.config.num_labels > 0: 680 | # average pool the patch tokens of the last hidden state 681 | image_embeds_mean = mx.mean(vision_outputs[0], axis=1) 682 | # apply classifier 683 | logits = self.classifier(image_embeds_mean) 684 | 685 | return ViTModelOutput( 686 | logits=logits, 687 | image_embeds=image_embeds, 688 | vision_model_output=vision_outputs, 689 | ) 690 | 691 | else: 692 | text_outputs = self.text_model( 693 | input_ids=input_ids, 694 | attention_mask=attention_mask, 695 | position_ids=position_ids, 696 | output_hidden_states=output_hidden_states, 697 | ) 698 | 699 | text_embeds = text_outputs[1] 700 | 701 | # normalized features 702 | image_embeds = normalize_embeddings(image_embeds) 703 | text_embeds = normalize_embeddings(text_embeds) 704 | 705 | # cosine similarity as logits 706 | logits_per_text = mx.matmul(text_embeds, image_embeds.T) 707 | 708 | # Apply scale and bias 709 | logits_per_text = ( 710 | logits_per_text * mx.exp(self.logit_scale) + self.logit_bias 711 | ) 712 | 713 | logits_per_image = logits_per_text.T 714 | 715 | return ViTModelOutput( 716 | logits_per_text=logits_per_text, 717 | logits_per_image=logits_per_image, 718 | text_embeds=text_embeds, 719 | image_embeds=image_embeds, 720 | text_model_output=text_outputs, 721 | vision_model_output=vision_outputs, 722 | ) 723 | 724 | def sanitize(self, weights): 725 | sanitized_weights = {} 726 | for k, v in weights.items(): 727 | if k.startswith("text_model") and not k.startswith("text_model.text_model"): 728 | sanitized_weights["text_model." + k] = v 729 | elif k.startswith("vision_model") and not k.startswith( 730 | "vision_model.vision_model" 731 | ): 732 | if "in_proj_bias" in k: 733 | k = k.replace("in_proj_bias", "in_proj.bias") 734 | if "in_proj_weight" in k: 735 | k = k.replace("in_proj_weight", "in_proj.weight") 736 | sanitized_weights["vision_model." + k] = v 737 | else: 738 | sanitized_weights[k] = v 739 | 740 | if hasattr(self.text_model, "sanitize"): 741 | sanitized_weights = self.text_model.sanitize(sanitized_weights) 742 | if hasattr(self.vision_model, "sanitize"): 743 | sanitized_weights = self.vision_model.sanitize(sanitized_weights) 744 | 745 | return sanitized_weights 746 | --------------------------------------------------------------------------------
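For reference, below is a minimal, illustrative sketch of how the utilities and the SigLIP model above fit together; it is not part of the package source. It assumes `load` (exported from `mlx_embeddings` alongside `convert` and `generate`) returns a `(model, processor)` pair and accepts a local output directory; the Hugging Face repo id and image path are placeholders. For SigLIP, applying a sigmoid to `logits_per_text` yields pairwise text-image match probabilities.

import mlx.core as mx
from mlx_embeddings import convert, generate, load

# Convert a SigLIP checkpoint to MLX format (pass quantize=True for 4-bit weights).
# The repo id is illustrative; any compatible SigLIP checkpoint should work.
convert("google/siglip-base-patch16-224", mlx_path="mlx_siglip")

# Load the converted model and its processor (assumed to be returned as a pair by `load`).
model, processor = load("mlx_siglip")

# Score two captions against one image; `generate` routes both modalities through
# `prepare_inputs` and the SigLIP `Model.__call__` defined above.
output = generate(
    model,
    processor,
    texts=["a photo of two cats", "a photo of a dog"],
    images="path/to/image.jpg",  # placeholder path
)

# `logits_per_text` has shape (num_texts, num_images); sigmoid converts it to match probabilities.
probs = mx.sigmoid(output.logits_per_text)
print(probs)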