├── .github ├── FUNDING.yml └── workflows │ ├── deploy-docs.yml │ ├── python-publish.yml │ ├── tests.yml │ └── update-changelog.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── computer_use ├── README.md ├── audio │ ├── ask.wav │ ├── ask_2.wav │ ├── ask_voice.wav │ ├── ask_voice_2.wav │ ├── ok.wav │ ├── output.wav │ └── task_completed.wav ├── autonomous_gui_agent.py ├── autonomous_gui_agent_voice.py ├── gui_agent.py ├── gui_agent_voice.py ├── navigation_history.csv ├── requirements.txt ├── screenshots │ ├── screenshot_20241210-191351.png │ ├── screenshot_20241210-191519.png │ ├── screenshot_20241210-192437.png │ ├── screenshot_20241210-193535.png │ ├── screenshot_20241210-193911.png │ ├── screenshot_20241210-194314.png │ ├── screenshot_20241210-195055.png │ ├── screenshot_20241210-195145.png │ ├── screenshot_20241210-230354.png │ ├── screenshot_20241210-230553.png │ ├── screenshot_20241210-230651.png │ ├── screenshot_20241210-230730.png │ ├── screenshot_20241210-230923.png │ ├── screenshot_20241210-230958.png │ ├── screenshot_20241210-231015.png │ ├── screenshot_20241210-231058.png │ ├── screenshot_20241210-231133.png │ ├── screenshot_20241210-231212.png │ ├── screenshot_20241210-231321.png │ ├── screenshot_20241210-231404.png │ ├── screenshot_20241210-231511.png │ ├── screenshot_20241210-231527.png │ ├── screenshot_20241210-231552.png │ ├── screenshot_20241210-231719.png │ ├── screenshot_20241210-231751.png │ ├── screenshot_20241211-101202.png │ └── screenshot_20241211-101219.png └── utils.py ├── docs ├── changelog.md ├── cli_reference.md ├── community_projects.md ├── contributing.md ├── examples.md ├── index.md ├── installation.md ├── overrides │ └── main.html ├── report_issues.md └── usage.md ├── examples ├── images │ ├── cats.jpg │ ├── desktop_setup.png │ ├── graph.png │ ├── handwritten_hard.webp │ ├── latex.png │ ├── menu.webp │ ├── paper.png │ └── renewables_california.png ├── multi_image_generation.ipynb ├── object_detection.ipynb ├── object_pointing.ipynb ├── ocr_with_region.ipynb ├── text_extraction.ipynb ├── utils.py ├── video_understanding.ipynb └── videos │ └── fastmlx_local_ai_hub.mp4 ├── mkdocs.yml ├── mlx_vlm ├── LORA.MD ├── __init__.py ├── chat.py ├── chat_ui.py ├── convert.py ├── generate.py ├── lora.py ├── models │ ├── __init__.py │ ├── aya_vision │ │ ├── __init__.py │ │ ├── aya_vision.py │ │ ├── interpolate.py │ │ ├── language.py │ │ └── vision.py │ ├── base.py │ ├── cache.py │ ├── deepseek_vl_v2 │ │ ├── __init__.py │ │ ├── conversation.py │ │ ├── deepseek_vl_v2.py │ │ ├── language.py │ │ ├── processing_deepsek_vl_v2.py │ │ └── vision.py │ ├── florence2 │ │ ├── __init__.py │ │ ├── florence2.py │ │ ├── language.py │ │ └── vision.py │ ├── gemma3 │ │ ├── __init__.py │ │ ├── gemma3.py │ │ ├── language.py │ │ └── vision.py │ ├── idefics2 │ │ ├── __init__.py │ │ ├── idefics2.py │ │ ├── language.py │ │ └── vision.py │ ├── idefics3 │ │ ├── __init__.py │ │ ├── idefics3.py │ │ ├── language.py │ │ └── vision.py │ ├── internvl_chat │ │ ├── __init__.py │ │ ├── internvl_chat.py │ │ ├── language.py │ │ ├── processor.py │ │ └── vision.py │ ├── kimi_vl │ │ ├── __init__.py │ │ ├── kimi_vl.py │ │ ├── language.py │ │ └── vision.py │ ├── llama4 │ │ ├── __init__.py │ │ ├── language.py │ │ ├── llama4.py │ │ └── vision.py │ ├── llava │ │ ├── __init__.py │ │ ├── language.py │ │ ├── llava.py │ │ └── vision.py │ ├── llava_bunny │ │ ├── __init__.py │ │ ├── language.py │ │ ├── llava_bunny.py │ │ └── vision.py │ ├── llava_next │ │ ├── 
__init__.py │ │ ├── language.py │ │ ├── llava_next.py │ │ └── vision.py │ ├── mistral3 │ │ ├── __init__.py │ │ └── mistral3.py │ ├── mllama │ │ ├── __init__.py │ │ ├── language.py │ │ ├── mllama.py │ │ └── vision.py │ ├── molmo │ │ ├── __init__.py │ │ ├── language.py │ │ ├── molmo.py │ │ └── vision.py │ ├── multi_modality │ │ ├── __init__.py │ │ ├── language.py │ │ ├── multi_modality.py │ │ ├── sam.py │ │ └── vision.py │ ├── paligemma │ │ ├── __init__.py │ │ ├── language.py │ │ ├── paligemma.py │ │ └── vision.py │ ├── phi3_v │ │ ├── __init__.py │ │ ├── language.py │ │ ├── phi3_v.py │ │ ├── su_rope.py │ │ └── vision.py │ ├── pixtral │ │ ├── __init__.py │ │ ├── language.py │ │ ├── pixtral.py │ │ └── vision.py │ ├── qwen2_5_vl │ │ ├── __init__.py │ │ ├── language.py │ │ ├── qwen2_5_vl.py │ │ └── vision.py │ ├── qwen2_vl │ │ ├── __init__.py │ │ ├── config.py │ │ ├── language.py │ │ ├── qwen2_vl.py │ │ └── vision.py │ └── smolvlm │ │ ├── __init__.py │ │ └── smolvlm.py ├── prompt_utils.py ├── sample_utils.py ├── server.py ├── smolvlm_video_generate.py ├── tests │ ├── test_models.py │ ├── test_smoke.py │ ├── test_trainer.py │ ├── test_trainer_utils.py │ └── test_utils.py ├── tokenizer_utils.py ├── trainer │ ├── __init__.py │ ├── lora.py │ ├── trainer.py │ └── utils.py ├── utils.py ├── version.py └── video_generate.py ├── pytest.ini ├── requirements.txt ├── setup.py └── update_changelog.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: Blaizzy # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 12 | polar: # Replace with a single Polar username 13 | buy_me_a_coffee: # Replace with a single Buy Me a Coffee username 14 | thanks_dev: # Replace with a single thanks.dev username 15 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 16 | -------------------------------------------------------------------------------- /.github/workflows/deploy-docs.yml: -------------------------------------------------------------------------------- 1 | name: Deploy Documentation 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | deploy: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v2 16 | 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.x' 21 | 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install mkdocs mkdocs-material mkdocstrings[python] mkdocs-autorefs mkdocs-git-revision-date-localized-plugin mkdocs-jupyter 26 | 27 | - name: Build documentation 28 | run: mkdocs build 29 | 30 | - name: Deploy to GitHub Pages 31 | if: github.event_name == 'push' && github.ref == 'refs/heads/main' 32 | uses: peaceiris/actions-gh-pages@v3 33 | with: 34 | github_token: ${{ secrets.DOCS_TOKEN }} 35 | publish_dir: ./site 
36 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [published] 9 | 10 | permissions: 11 | contents: read 12 | 13 | jobs: 14 | deploy: 15 | 16 | runs-on: ubuntu-latest 17 | 18 | steps: 19 | - uses: actions/checkout@v3 20 | - name: Set up Python 21 | uses: actions/setup-python@v3 22 | with: 23 | python-version: '3.10' 24 | - name: Install dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | pip install build 28 | - name: Build package 29 | run: python -m build 30 | - name: Publish package 31 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 32 | with: 33 | user: __token__ 34 | password: ${{ secrets.PYPI_API_TOKEN }} 35 | packages_dir: dist 36 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Test PRs 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | test: 10 | runs-on: macos-14 11 | 12 | steps: 13 | - name: Checkout code 14 | uses: actions/checkout@v4 15 | 16 | - name: Set up Python 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: '3.10' 20 | 21 | - name: Install MLX 22 | run: | 23 | pip install mlx>=0.15 24 | 25 | - name: Install pre-commit 26 | run: | 27 | python -m pip install pre-commit 28 | pre-commit run --all 29 | if ! git diff --quiet; then 30 | echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change' 31 | exit 1 32 | fi 33 | 34 | - name: Install package and dependencies 35 | run: | 36 | python -m pip install pytest 37 | python -m pip install -e . 38 | 39 | - name: Run Python tests 40 | run: | 41 | cd mlx_vlm/ 42 | pytest -s ./tests --ignore=tests/test_smoke.py 43 | -------------------------------------------------------------------------------- /.github/workflows/update-changelog.yml: -------------------------------------------------------------------------------- 1 | name: Update Changelog 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | update-changelog: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - name: Set up Python 13 | uses: actions/setup-python@v2 14 | with: 15 | python-version: '3.x' 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | pip install requests python-dotenv 20 | - name: Update Changelog 21 | env: 22 | GITHUB_TOKEN: ${{ secrets.DOCS_TOKEN }} 23 | run: python update_changelog.py 24 | - name: Commit changes 25 | run: | 26 | set -x 27 | git config --local user.email "action@github.com" 28 | git config --local user.name "GitHub Action" 29 | 30 | echo "Fetching latest changes..." 31 | git fetch origin 32 | 33 | echo "Checking out and updating branch..." 
34 | if [ "${{ github.event_name }}" = "pull_request" ]; then 35 | git checkout -B "${{ github.head_ref }}" "origin/${{ github.head_ref }}" 36 | git pull origin "${{ github.head_ref }}" 37 | else 38 | git checkout -B "${{ github.ref_name }}" "origin/${{ github.ref_name }}" 39 | git pull origin "${{ github.ref_name }}" 40 | fi 41 | 42 | echo "Running update script..." 43 | python update_changelog.py 44 | 45 | echo "Checking for changes..." 46 | git add docs/changelog.md 47 | git pull 48 | if git diff --staged --quiet; then 49 | echo "No changes to commit" 50 | else 51 | git commit -m "Update changelog for latest release" 52 | git push origin HEAD:"${{ github.head_ref || github.ref_name }}" || echo "Failed to push changes" 53 | fi 54 | 55 | git status 56 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | test.ipynb 2 | **.pyc 3 | 4 | Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | 8 | # Distribution / packaging 9 | bin/ 10 | build/ 11 | develop-eggs/ 12 | dist/ 13 | eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | *.egg-info/ 20 | .installed.cfg 21 | *.egg 22 | .DS_Store 23 | *.log 24 | 25 | # virtual environment 26 | .myenv 27 | .venv -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black-pre-commit-mirror 3 | rev: 24.2.0 4 | hooks: 5 | - id: black 6 | - repo: https://github.com/pycqa/isort 7 | rev: 5.13.2 8 | hooks: 9 | - id: isort 10 | args: 11 | - --profile=black 12 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to MLX VLM 2 | 3 | Below are some tips to port Vision LLMs available on Hugging Face to MLX. 4 | 5 | Next, from this directory, do an editable install: 6 | 7 | ```shell 8 | pip install -e . 9 | ``` 10 | 11 | Then check if the model has weights in the 12 | [safetensors](https://huggingface.co/docs/safetensors/index) format. If not 13 | [follow instructions](https://huggingface.co/spaces/safetensors/convert) to 14 | convert it. 15 | 16 | After that, add the model file to the 17 | [`mlx_vlm/models`](https://github.com/Blaizzy/mlx-vlm/tree/main/src/models) 18 | directory. You can see other examples there. We recommend starting from a model 19 | that is similar to the model you are porting. 20 | 21 | Make sure the name of the new model file is the same as the `model_type` in the 22 | `config.json`, for example 23 | [llava](https://huggingface.co/llava-hf/llava-1.5-7b-hf/blob/main/config.json#L7). 24 | 25 | To determine the model layer names, we suggest either: 26 | 27 | - Refer to the Transformers implementation if you are familiar with the 28 | codebase. 29 | - Load the model weights and check the weight names which will tell you about 30 | the model structure. 31 | - Look at the names of the weights by inspecting `model.safetensors.index.json` 32 | in the Hugging Face repo. 33 | 34 | Additionally, add a test for the new modle type to the [model 35 | tests](https://github.com/Blaizzy/mlx-vlm/tree/main/src/tests/test_models.py). 
36 | 37 | From the `mlx_vlm/` directory, you can run the tests with: 38 | 39 | ```shell 40 | python -m unittest discover tests/ 41 | ``` 42 | 43 | ## Pull Requests 44 | 45 | 1. Fork and submit pull requests to the repo. 46 | 2. If you've added code that should be tested, add tests. 47 | 3. Every PR should have passing tests and at least one review. 48 | 4. For code formatting, install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`. 49 | This should install hooks for running `black` and `isort` to ensure 50 | consistent style for Python code. 51 | 52 | You can also run the formatters manually as follows on individual files: 53 | 54 | ```bash 55 | isort file.py 56 | ``` 57 | 58 | ```bash 59 | black file.py 60 | ``` 61 | 62 | or, 63 | 64 | ```bash 65 | # single file 66 | pre-commit run --files file1.py 67 | 68 | # specific files 69 | pre-commit run --files file1.py file2.py 70 | ``` 71 | 72 | or run `pre-commit run --all-files` to check all files in the repo. 73 | 74 | ## Issues 75 | 76 | We use GitHub issues to track public bugs. Please ensure your description is 77 | clear and has sufficient instructions to be able to reproduce the issue. 78 | 79 | ## License 80 | 81 | By contributing to mlx-vlm, you agree that your contributions will be licensed 82 | under the LICENSE file in the root directory of this source tree. 83 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright © 2023 Apple Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include ./requirements.txt 2 | recursive-include mlx_vlm/ *.py 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Upload Python Package](https://github.com/Blaizzy/mlx-vlm/actions/workflows/python-publish.yml/badge.svg)](https://github.com/Blaizzy/mlx-vlm/actions/workflows/python-publish.yml) 2 | # MLX-VLM 3 | 4 | MLX-VLM is a package for inference and fine-tuning of Vision Language Models (VLMs) on your Mac using MLX.
5 | 6 | ## Table of Contents 7 | - [Installation](#installation) 8 | - [Usage](#usage) 9 | - [Command Line Interface (CLI)](#command-line-interface-cli) 10 | - [Chat UI with Gradio](#chat-ui-with-gradio) 11 | - [Python Script](#python-script) 12 | - [Multi-Image Chat Support](#multi-image-chat-support) 13 | - [Supported Models](#supported-models) 14 | - [Usage Examples](#usage-examples) 15 | - [Fine-tuning](#fine-tuning) 16 | 17 | ## Installation 18 | 19 | The easiest way to get started is to install the `mlx-vlm` package using pip: 20 | 21 | ```sh 22 | pip install mlx-vlm 23 | ``` 24 | 25 | ## Usage 26 | 27 | ### Command Line Interface (CLI) 28 | 29 | Generate output from a model using the CLI: 30 | 31 | ```sh 32 | python -m mlx_vlm.generate --model mlx-community/Qwen2-VL-2B-Instruct-4bit --max-tokens 100 --temperature 0.0 --image http://images.cocodataset.org/val2017/000000039769.jpg 33 | ``` 34 | 35 | ### Chat UI with Gradio 36 | 37 | Launch a chat interface using Gradio: 38 | 39 | ```sh 40 | python -m mlx_vlm.chat_ui --model mlx-community/Qwen2-VL-2B-Instruct-4bit 41 | ``` 42 | 43 | ### Python Script 44 | 45 | Here's an example of how to use MLX-VLM in a Python script: 46 | 47 | ```python 48 | import mlx.core as mx 49 | from mlx_vlm import load, generate 50 | from mlx_vlm.prompt_utils import apply_chat_template 51 | from mlx_vlm.utils import load_config 52 | 53 | # Load the model 54 | model_path = "mlx-community/Qwen2-VL-2B-Instruct-4bit" 55 | model, processor = load(model_path) 56 | config = load_config(model_path) 57 | 58 | # Prepare input 59 | image = ["http://images.cocodataset.org/val2017/000000039769.jpg"] 60 | # image = [Image.open("...")] can also be used with PIL.Image.Image objects 61 | prompt = "Describe this image." 62 | 63 | # Apply chat template 64 | formatted_prompt = apply_chat_template( 65 | processor, config, prompt, num_images=len(image) 66 | ) 67 | 68 | # Generate output 69 | output = generate(model, processor, formatted_prompt, image, verbose=False) 70 | print(output) 71 | ``` 72 | 73 | ### Server (FastAPI) 74 | To start the server 75 | ```sh 76 | python -m mlx_vlm.server 77 | ``` 78 | 79 | Models can be loaded or unloaded dynamically and they are cached (one at a time) when the server is running. 80 | 81 | Usage example: 82 | ```sh 83 | curl -X POST "http://localhost:8000/generate" \ 84 | -H "Content-Type: application/json" \ 85 | -d '{ 86 | "model": "mlx-community/Qwen2.5-VL-32B-Instruct-8bit", 87 | "image": ["/path/to/repo/examples/images/renewables_california.png"], 88 | "prompt": "This is today'\''s chart for energy demand in California. Can you provide an analysis of the chart and comment on the implications for renewable energy in California?", 89 | "system": "You are a helpful assistant.", 90 | "stream": true, 91 | "max_tokens": 1000 92 | }' 93 | ``` 94 | 95 | 96 | ## Multi-Image Chat Support 97 | 98 | MLX-VLM supports analyzing multiple images simultaneously with select models. This feature enables more complex visual reasoning tasks and comprehensive analysis across multiple images in a single conversation. 99 | 100 | ### Supported Models 101 | 102 | The following models support multi-image chat: 103 | 104 | 1. Idefics 2 105 | 2. LLaVA (Interleave) 106 | 3. Qwen2-VL 107 | 4. Phi3-Vision 108 | 5. 
Pixtral 109 | 110 | ### Usage Examples 111 | 112 | #### Python Script 113 | 114 | ```python 115 | from mlx_vlm import load, generate 116 | from mlx_vlm.prompt_utils import apply_chat_template 117 | from mlx_vlm.utils import load_config 118 | 119 | model_path = "mlx-community/Qwen2-VL-2B-Instruct-4bit" 120 | model, processor = load(model_path) 121 | config = load_config(model_path) 122 | 123 | images = ["path/to/image1.jpg", "path/to/image2.jpg"] 124 | prompt = "Compare these two images." 125 | 126 | formatted_prompt = apply_chat_template( 127 | processor, config, prompt, num_images=len(images) 128 | ) 129 | 130 | output = generate(model, processor, formatted_prompt, images, verbose=False) 131 | print(output) 132 | ``` 133 | 134 | #### Command Line 135 | 136 | ```sh 137 | python -m mlx_vlm.generate --model mlx-community/Qwen2-VL-2B-Instruct-4bit --max-tokens 100 --prompt "Compare these images" --image path/to/image1.jpg path/to/image2.jpg 138 | ``` 139 | 140 | ## Video Understanding 141 | 142 | MLX-VLM also supports video analysis such as captioning, summarization, and more, with select models. 143 | 144 | ### Supported Models 145 | 146 | The following models support video chat: 147 | 148 | 1. Qwen2-VL 149 | 2. Qwen2.5-VL 150 | 3. Idefics3 151 | 4. LLaVA 152 | 153 | With more coming soon. 154 | 155 | ### Usage Examples 156 | 157 | #### Command Line 158 | ```sh 159 | python -m mlx_vlm.video_generate --model mlx-community/Qwen2-VL-2B-Instruct-4bit --max-tokens 100 --prompt "Describe this video" --video path/to/video.mp4 --max-pixels 224 224 --fps 1.0 160 | ``` 161 | 162 | 163 | These examples demonstrate how to use multiple images with MLX-VLM for more complex visual reasoning tasks. 164 | 165 | # Fine-tuning 166 | 167 | MLX-VLM supports fine-tuning models with LoRA and QLoRA. 168 | 169 | ## LoRA & QLoRA 170 | 171 | To learn more about LoRA, please refer to the [LoRA.md](./mlx_vlm/LORA.MD) file. 172 | -------------------------------------------------------------------------------- /computer_use/README.md: -------------------------------------------------------------------------------- 1 | # Computer Use with MLX-VLM 2 | 3 |
4 | 5 | ![MLX-VLM Computer Control](https://img.shields.io/badge/MLX--VLM-Computer%20Control-blue) 6 | ![macOS](https://img.shields.io/badge/platform-macOS-lightgrey) 7 | ![Apple Silicon](https://img.shields.io/badge/optimized-Apple%20Silicon-orange) 8 | ![License](https://img.shields.io/badge/license-MIT-green) 9 | 10 |
11 | 12 | A powerful tool that leverages Vision Language Models (VLMs) to enable AI-driven control of your Mac through visual understanding and contextual reasoning. 13 | 14 |

15 | Automate your workflow with natural language commands and visual intelligence 16 |

17 | 18 | ## 🤖 Current Implementation Status 19 | The project now supports both Level 1 (GUI Agent) and Level 2 (Autonomous GUI Agent) capabilities: 20 | 21 | - **Level 1 (GUI Agent)**: Basic visual understanding and action capabilities 22 | - **Level 2 (Autonomous GUI Agent)**: Enhanced memory, planning, reasoning, and human-in-the-loop functionality 23 | 24 | *Community help is more than welcome!* We're looking for contributors to help us enhance these capabilities further. Join us in building the future of computer automation. 25 | 26 | 27 | ## 🔍 Overview 28 | 29 | Computer Use with MLX-VLM transforms how you interact with your Mac by combining the power of: 30 | 31 | - **MLX** - Apple's machine learning framework optimized for Apple Silicon 32 | - **Vision Language Models (VLMs)** - AI models that understand both visual and textual information 33 | - **Automation** - Seamless execution of tasks across your Mac's interface 34 | 35 | By processing screenshots and visual information from your screen, the system understands the current state of applications and executes appropriate actions to accomplish tasks you specify in natural language. 36 | 37 | ## ✨ Key Features 38 | 39 | - **Mac-Native Performance**: Optimized for Apple Silicon with MLX for efficient, local processing 40 | - **Visual Understanding**: Interprets screen content, UI elements, and application states 41 | - **Contextual Reasoning**: Makes intelligent decisions based on visual context 42 | - **Cross-Application Automation**: Works across multiple applications and system interfaces 43 | - **Natural Language Control**: Simple, human-like instructions to control your computer 44 | - **Privacy-Focused**: All processing happens locally on your device 45 | - **Customizable**: Adapt to your specific workflow and preferences 46 | - **Autonomous Operation**: Level 2 agent can plan and execute multi-step tasks with minimal supervision 47 | - **Voice Control**: Hands-free operation with voice commands using local speech recognition 48 | 49 | ## 🚀 Getting Started 50 | 51 | ### Prerequisites 52 | 53 | - **macOS** running on Apple Silicon (M series) 54 | - **Python 3.8+** 55 | - **pip** (Python package manager) 56 | 57 | ### Installation 58 | 1. **Install MLX-VLM package**: 59 | ```bash 60 | pip install mlx-vlm 61 | ``` 62 | 63 | 2. **Clone the repository**: 64 | ```bash 65 | git clone https://github.com/Blaizzy/mlx-vlm.git 66 | ``` 67 | 68 | 3. **Navigate to computer control directory**: 69 | ```bash 70 | cd computer_use 71 | ``` 72 | 73 | 4. 
**Install dependencies**: 74 | ```bash 75 | pip install -r requirements.txt 76 | ``` 77 | 78 | ## 💻 Usage 79 | 80 | ### Quick Start 81 | 82 | Launch the standard application with: 83 | 84 | ```bash 85 | python main.py 86 | ``` 87 | 88 | ### Autonomous GUI Agent 89 | 90 | For enhanced autonomous operation with planning capabilities: 91 | 92 | ```bash 93 | python autonomous_gui_agent.py 94 | ``` 95 | 96 | This launches the Level 2 autonomous agent that can: 97 | - Plan and execute multi-step tasks 98 | - Maintain context across actions 99 | - Make decisions based on visual feedback 100 | - Request human assistance when needed 101 | 102 | ### Voice Control Interface 103 | 104 | For hands-free operation, you can use the voice-enabled autonomous agent: 105 | 106 | ```bash 107 | python autonomous_gui_agent_voice.py 108 | ``` 109 | 110 | This launches a voice-controlled version that: 111 | - Listens for your voice commands using your Mac's microphone 112 | - Converts speech to text using local speech recognition using [mlx-whisper](https://github.com/ml-explore/mlx-examples) 113 | - Processes your commands and executes them visually 114 | - Provides audio feedback on actions taken 115 | 116 | Voice commands work just like text commands, so you can say things like: 117 | 118 | ### Command Examples 119 | 120 | Control your Mac with natural language instructions like: 121 | 122 | ``` 123 | "Open Safari and navigate to apple.com" 124 | "Open the notifications tab and click on the first notification" 125 | "Open the email app and reply to the most recent email" 126 | ``` 127 | 128 | ## ⚙️ How It Works 129 | 130 | 1. **Screen Capture**: The system takes screenshots of your Mac display 131 | 2. **Visual Analysis**: MLX-VLM processes the visual information to understand: 132 | - UI elements and their states 133 | - Text content on screen 134 | - Application context 135 | - System status 136 | 3. **Instruction Processing**: Your natural language commands are interpreted 137 | 4. **Action Planning**: The system determines the sequence of actions needed 138 | 5. **Execution**: Actions are performed through macOS APIs or simulated inputs (click, scroll, etc) 139 | 140 | ## 🔒 Privacy & Security 141 | 142 | - **Local Processing**: All AI inference happens on your Mac using MLX 143 | - **No Cloud Dependency**: Your screenshots and data never leave your device 144 | - **Permission Control**: Fine-grained control over what the system can access 145 | - **Transparent Operation**: Clear visibility into actions being performed 146 | 147 | ## 🛠️ Troubleshooting 148 | 149 | ### Common Issues 150 | 151 | - **Permission Errors**: Make sure to grant screen recording permissions in System Preferences > Security & Privacy > Privacy 152 | - **Performance Issues**: Try reducing the screenshot resolution in config.json 153 | - **Application Compatibility**: Some applications with non-standard UI elements may have limited support 154 | 155 | ### Getting Help 156 | 157 | - Check the [Issues](https://github.com/yourusername/computer_use/issues) page 158 | - Join our [Discord community](https://discord.gg/yourdiscord) 159 | 160 | ## 🤝 Contributing 161 | 162 | We welcome contributions! Here's how to get started: 163 | 164 | 1. **Fork the repository** 165 | 2. **Create your feature branch**: 166 | ```bash 167 | git checkout -b feature/amazing-feature 168 | ``` 169 | 3. **Make your changes** 170 | 4. **Run tests**: 171 | ```bash 172 | python -m pytest 173 | ``` 174 | 5. 
**Commit your changes**: 175 | ```bash 176 | git commit -m 'Add some amazing feature' 177 | ``` 178 | 6. **Push to the branch**: 179 | ```bash 180 | git push origin feature/amazing-feature 181 | ``` 182 | 7. **Open a Pull Request** 183 | 184 | Please read [CONTRIBUTING.md](CONTRIBUTING.md) for detailed guidelines. 185 | 186 | ## 📜 License 187 | 188 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 189 | 190 | ## 🙏 Acknowledgments 191 | 192 | - The **MLX team at Apple** for creating the MLX framework 193 | - **Our community of testers and contributors** who help improve the project 194 | 195 | --- 196 | 197 |

198 | Made with ❤️ for Mac users who love automation and AI 199 |

-------------------------------------------------------------------------------- /computer_use/audio/ask.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/audio/ask.wav -------------------------------------------------------------------------------- /computer_use/audio/ask_2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/audio/ask_2.wav -------------------------------------------------------------------------------- /computer_use/audio/ask_voice.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/audio/ask_voice.wav -------------------------------------------------------------------------------- /computer_use/audio/ask_voice_2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/audio/ask_voice_2.wav -------------------------------------------------------------------------------- /computer_use/audio/ok.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/audio/ok.wav -------------------------------------------------------------------------------- /computer_use/audio/output.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/audio/output.wav -------------------------------------------------------------------------------- /computer_use/audio/task_completed.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/audio/task_completed.wav -------------------------------------------------------------------------------- /computer_use/navigation_history.csv: -------------------------------------------------------------------------------- 1 | Query,Response,Screenshot Path 2 | Click on Explore,"{'action': 'CLICK', 'value': None, 'position': [211.68, 294.59999999999997]}",screenshots/screenshot_20241210-191351.png 3 | Click on notifications,"{'action': 'CLICK', 'value': None, 'position': [241.92000000000002, 333.88000000000005]}",screenshots/screenshot_20241210-191519.png 4 | Click on Profile,"{'action': 'CLICK', 'value': None, 'position': [196.56, 746.32]}",screenshots/screenshot_20241210-192437.png 5 | Click on the Image that Philipp posted,"{'action': 'CLICK', 'value': None, 'position': [589.6800000000001, 569.56]}",screenshots/screenshot_20241210-193535.png 6 | Select text on the screen,"{'action': 'CLICK', 'value': None, 'position': [740.88, 314.24]}",screenshots/screenshot_20241210-193911.png 7 | Click on show posts,"{'action': 'CLICK', 'value': None, 'position': [680.4, 343.7]}",screenshots/screenshot_20241210-194314.png 8 | Search youtube,"{'action': 'CLICK', 'value': None, 'position': [1179.3600000000001, 98.2]}",screenshots/screenshot_20241210-195055.png 9 | Click on the BMW X3 video,"{'action': 'CLICK', 'value': None, 'position': [1300.32, 
559.7399999999999]}",screenshots/screenshot_20241210-195145.png 10 | Click on notifications,"{'action': 'CLICK', 'value': None, 'position': [257.04, 333.88000000000005]}",screenshots/screenshot_20241210-230354.png 11 | Click on notifications,"{'action': 'CLICK', 'value': None, 'position': [468.71999999999997, 196.4]}",screenshots/screenshot_20241210-230553.png 12 | Click on notifications,"{'action': 'CLICK', 'value': None, 'position': [468.71999999999997, 196.4]}",screenshots/screenshot_20241210-230651.png 13 | Click on notifications,"{'action': 'CLICK', 'value': None, 'position': [468.71999999999997, 196.4]}",screenshots/screenshot_20241210-230730.png 14 | Click on notifications,"{'action': 'CLICK', 'value': None, 'position': [512.0, 299.52]}",screenshots/screenshot_20241210-230923.png 15 | Click on notifications,"{'action': 'CLICK', 'value': None, 'position': [793.6, 288.0]}",screenshots/screenshot_20241210-230958.png 16 | Click on Home,"{'action': 'CLICK', 'value': None, 'position': [742.4, 187.20000000000002]}",screenshots/screenshot_20241210-231015.png 17 | Click on Elon Musk,"{'action': 'CLICK', 'value': None, 'position': [1239.84, 1040.92]}",screenshots/screenshot_20241210-231058.png 18 | Click on Elon Musk's profile,"{'action': 'CLICK', 'value': None, 'position': [1058.3999999999999, 1060.5600000000002]}",screenshots/screenshot_20241210-231133.png 19 | Click on Elon Musks profile,"{'action': 'CLICK', 'value': None, 'position': [1058.3999999999999, 1060.5600000000002]}",screenshots/screenshot_20241210-231212.png 20 | Click on Nikita,"{'action': 'CLICK', 'value': None, 'position': [1028.16, 510.64000000000004]}",screenshots/screenshot_20241210-231321.png 21 | Click on Explore,"{'action': 'CLICK', 'value': None, 'position': [423.36, 549.9200000000001]}",screenshots/screenshot_20241210-231404.png 22 | Click on Explore,"{'action': 'CLICK', 'value': None, 'position': [211.68, 274.96000000000004]}",screenshots/screenshot_20241210-231511.png 23 | Click on Home,"{'action': 'CLICK', 'value': None, 'position': [181.44, 225.86]}",screenshots/screenshot_20241210-231527.png 24 | Click on Rohan Pauls profile,"{'action': 'CLICK', 'value': None, 'position': [483.84000000000003, 549.9200000000001]}",screenshots/screenshot_20241210-231552.png 25 | cClick on Home,"{'action': 'CLICK', 'value': None, 'position': [393.12, 432.08]}",screenshots/screenshot_20241210-231719.png 26 | Click on Home,"{'action': 'CLICK', 'value': None, 'position': [196.56, 216.04]}",screenshots/screenshot_20241210-231751.png 27 | Click on notifications,"{'action': 'CLICK', 'value': None, 'position': [153.6, 273.6]}",screenshots/screenshot_20241211-101202.png 28 | Click on Home,"{'action': 'CLICK', 'value': None, 'position': [166.4, 172.79999999999998]}",screenshots/screenshot_20241211-101219.png 29 | -------------------------------------------------------------------------------- /computer_use/requirements.txt: -------------------------------------------------------------------------------- 1 | SpeechRecognition>=3.12.0 2 | mlx-whisper>=0.4.1 3 | mlx-audio>=0.0.1 -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-191351.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-191351.png -------------------------------------------------------------------------------- 
/computer_use/screenshots/screenshot_20241210-191519.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-191519.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-192437.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-192437.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-193535.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-193535.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-193911.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-193911.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-194314.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-194314.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-195055.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-195055.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-195145.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-195145.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-230354.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-230354.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-230553.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-230553.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-230651.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-230651.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-230730.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-230730.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-230923.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-230923.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-230958.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-230958.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-231015.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-231015.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-231058.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-231058.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-231133.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-231133.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-231212.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-231212.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-231321.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-231321.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-231404.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-231404.png -------------------------------------------------------------------------------- 
/computer_use/screenshots/screenshot_20241210-231511.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-231511.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-231527.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-231527.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-231552.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-231552.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-231719.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-231719.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-231751.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-231751.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241211-101202.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241211-101202.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241211-101219.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241211-101219.png -------------------------------------------------------------------------------- /computer_use/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from io import BytesIO 4 | 5 | import pandas as pd 6 | import requests 7 | from PIL import Image, ImageDraw 8 | 9 | 10 | def draw_point( 11 | image_input, point: tuple = None, radius: int = 8, color: str = "red" 12 | ) -> Image.Image: 13 | """ 14 | Draw a point on an image and return the modified image. 15 | `point` should be (x_norm, y_norm) in normalized coordinates (0 to 1). 
16 | """ 17 | # Load image from path/URL or use directly if it's already a PIL Image 18 | if isinstance(image_input, str): 19 | if image_input.startswith("http"): 20 | response = requests.get(image_input) 21 | response.raise_for_status() 22 | image = Image.open(BytesIO(response.content)) 23 | else: 24 | image = Image.open(image_input) 25 | elif isinstance(image_input, Image.Image): 26 | image = image_input 27 | else: 28 | raise ValueError("image_input must be a string path/URL or a PIL Image object") 29 | 30 | # Only draw if a valid point is provided 31 | if point is not None: 32 | x = int(point[0]) 33 | y = int(point[1]) 34 | print(f"Drawing ellipse at pixel coordinates: ({x}, {y})") 35 | 36 | draw = ImageDraw.Draw(image) 37 | draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=color) 38 | 39 | return image 40 | 41 | 42 | # Update CSV with query, response and image path 43 | def update_navigation_history( 44 | query, response, filepath, csv_path="navigation_history.csv" 45 | ): 46 | """ 47 | Update the navigation history CSV file with the query, response and screenshot filepath. 48 | 49 | Args: 50 | query: The user's query/task 51 | response: The system's response/action 52 | filepath: Path to the screenshot image 53 | csv_path: Path to the CSV file (default: navigation_history.csv) 54 | """ 55 | # Create new row as a DataFrame 56 | new_row = pd.DataFrame( 57 | {"Query": [query], "Response": [str(response)], "Screenshot Path": [filepath]} 58 | ) 59 | 60 | if os.path.exists(csv_path): 61 | # Append to existing CSV 62 | new_row.to_csv(csv_path, mode="a", header=False, index=False) 63 | else: 64 | # Create new CSV with headers 65 | new_row.to_csv(csv_path, index=False) 66 | -------------------------------------------------------------------------------- /docs/changelog.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | -------------------------------------------------------------------------------- /docs/cli_reference.md: -------------------------------------------------------------------------------- 1 | # CLI Reference 2 | 3 | MLX-VLM provides several command line entry points: 4 | 5 | - `mlx_vlm.convert` – convert Hugging Face models to MLX format. 6 | - `mlx_vlm.generate` – run inference on images. 7 | - `mlx_vlm.video_generate` – generate from a video file. 8 | - `mlx_vlm.smolvlm_video_generate` – lightweight video generation. 9 | - `mlx_vlm.chat_ui` – start an interactive Gradio UI. 10 | - `mlx_vlm.server` – run the FastAPI server. 11 | 12 | Each command accepts `--help` for full usage information. 13 | 14 | -------------------------------------------------------------------------------- /docs/community_projects.md: -------------------------------------------------------------------------------- 1 | # Community Projects 2 | 3 | If you have a project built on top of MLX-VLM let us know! We plan to showcase community examples and links here. 4 | 5 | -------------------------------------------------------------------------------- /docs/contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | To work on MLX-VLM in editable mode run: 4 | 5 | ```bash 6 | pip install -e . 7 | ``` 8 | 9 | Check that the model weights are available in the `safetensors` format, convert if necessary and add the model file to `mlx_vlm/models`. 
10 | 11 | Tests can be run from the `mlx_vlm/` directory: 12 | 13 | ```bash 14 | python -m unittest discover tests/ 15 | ``` 16 | 17 | Please format code using `pre-commit` before submitting a pull request. 18 | 19 | -------------------------------------------------------------------------------- /docs/examples.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | Example notebooks are available in the `examples/` directory: 4 | 5 | - [multi_image_generation.ipynb](https://github.com/Blaizzy/mlx-vlm/blob/main/examples/multi_image_generation.ipynb) 6 | - [object_detection.ipynb](https://github.com/Blaizzy/mlx-vlm/blob/main/examples/object_detection.ipynb) 7 | - [object_pointing.ipynb](https://github.com/Blaizzy/mlx-vlm/blob/main/examples/object_pointing.ipynb) 8 | - [ocr_with_region.ipynb](https://github.com/Blaizzy/mlx-vlm/blob/main/examples/ocr_with_region.ipynb) 9 | - [text_extraction.ipynb](https://github.com/Blaizzy/mlx-vlm/blob/main/examples/text_extraction.ipynb) 10 | - [video_understanding.ipynb](https://github.com/Blaizzy/mlx-vlm/blob/main/examples/video_understanding.ipynb) 11 | 12 | Images and videos used by the notebooks are stored in the `examples/images/` and `examples/videos/` folders. 13 | 14 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # MLX-VLM 2 | 3 | MLX-VLM is a package for inference and fine-tuning of Vision Language Models (VLMs) on Apple silicon using [MLX](https://github.com/ml-explore/mlx). 4 | 5 | -------------------------------------------------------------------------------- /docs/installation.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | Install the package from PyPI: 4 | 5 | ```bash 6 | pip install mlx-vlm 7 | ``` 8 | 9 | -------------------------------------------------------------------------------- /docs/overrides/main.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | 3 | {% block content %} 4 | {% if page.nb_url %} 5 | 6 | {% include ".icons/material/download.svg" %} 7 | 8 | {% endif %} 9 | 10 | {{ super() }} 11 | {% endblock content %} -------------------------------------------------------------------------------- /docs/report_issues.md: -------------------------------------------------------------------------------- 1 | # Report Issues 2 | 3 | Please open an issue on [GitHub](https://github.com/Blaizzy/mlx-vlm/issues) with clear steps to reproduce the problem. 
4 | 5 | -------------------------------------------------------------------------------- /docs/usage.md: -------------------------------------------------------------------------------- 1 | # Usage 2 | 3 | ## Command Line Interface (CLI) 4 | 5 | Generate output from a model: 6 | 7 | ```bash 8 | python -m mlx_vlm.generate --model mlx-community/Qwen2-VL-2B-Instruct-4bit --max-tokens 100 --temperature 0.0 --image http://images.cocodataset.org/val2017/000000039769.jpg 9 | ``` 10 | 11 | ## Chat UI with Gradio 12 | 13 | Launch the chat interface: 14 | 15 | ```bash 16 | python -m mlx_vlm.chat_ui --model mlx-community/Qwen2-VL-2B-Instruct-4bit 17 | ``` 18 | 19 | ## Python Script 20 | 21 | ```python 22 | from mlx_vlm import load, generate 23 | from mlx_vlm.prompt_utils import apply_chat_template 24 | from mlx_vlm.utils import load_config 25 | 26 | model_path = "mlx-community/Qwen2-VL-2B-Instruct-4bit" 27 | model, processor = load(model_path) 28 | config = load_config(model_path) 29 | 30 | image = ["http://images.cocodataset.org/val2017/000000039769.jpg"] 31 | prompt = "Describe this image." 32 | 33 | formatted_prompt = apply_chat_template(processor, config, prompt, num_images=len(image)) 34 | output = generate(model, processor, formatted_prompt, image, verbose=False) 35 | print(output) 36 | ``` 37 | 38 | ## Server (FastAPI) 39 | 40 | ```bash 41 | python -m mlx_vlm.server 42 | ``` 43 | 44 | See `README.md` for a complete `curl` example. 45 | 46 | -------------------------------------------------------------------------------- /examples/images/cats.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/examples/images/cats.jpg -------------------------------------------------------------------------------- /examples/images/desktop_setup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/examples/images/desktop_setup.png -------------------------------------------------------------------------------- /examples/images/graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/examples/images/graph.png -------------------------------------------------------------------------------- /examples/images/handwritten_hard.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/examples/images/handwritten_hard.webp -------------------------------------------------------------------------------- /examples/images/latex.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/examples/images/latex.png -------------------------------------------------------------------------------- /examples/images/menu.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/examples/images/menu.webp -------------------------------------------------------------------------------- /examples/images/paper.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/examples/images/paper.png -------------------------------------------------------------------------------- /examples/images/renewables_california.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/examples/images/renewables_california.png -------------------------------------------------------------------------------- /examples/video_understanding.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Video Understanding\n", 8 | "\n", 9 | "In this example, we will generate a description of a video using `Qwen2-VL`, `Qwen2-5-VL`, `LLava`, and `Idefics3`, with more models coming soon.\n", 10 | "\n", 11 | "This feature is currently in beta, may not work as expected.\n", 12 | "\n" 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "metadata": {}, 18 | "source": [ 19 | "## Install Dependencies" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "!pip install -U mlx-vlm" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "## Import Dependencies" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 1, 41 | "metadata": {}, 42 | "outputs": [ 43 | { 44 | "name": "stderr", 45 | "output_type": "stream", 46 | "text": [ 47 | "/opt/homebrew/Caskroom/miniconda/base/envs/mlx_code/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 48 | " from .autonotebook import tqdm as notebook_tqdm\n", 49 | "This is a beta version of the video understanding. 
It may not work as expected.\n" 50 | ] 51 | } 52 | ], 53 | "source": [ 54 | "from pprint import pprint\n", 55 | "from mlx_vlm import load\n", 56 | "from mlx_vlm.utils import generate\n", 57 | "from mlx_vlm.video_generate import process_vision_info\n", 58 | "\n", 59 | "import mlx.core as mx" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "# Load the model and processor\n", 69 | "model, processor = load(\"mlx-community/Qwen2.5-VL-7B-Instruct-4bit\")" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 3, 75 | "metadata": {}, 76 | "outputs": [ 77 | { 78 | "name": "stderr", 79 | "output_type": "stream", 80 | "text": [ 81 | "numpy reader: video_path=videos/fastmlx_local_ai_hub.mp4, total_frames=1134, video_fps=59.941855343141576, time=0.000s\n" 82 | ] 83 | } 84 | ], 85 | "source": [ 86 | "# Messages containing a video and a text query\n", 87 | "messages = [\n", 88 | " {\n", 89 | " \"role\": \"user\",\n", 90 | " \"content\": [\n", 91 | " {\n", 92 | " \"type\": \"video\",\n", 93 | " \"video\": \"videos/fastmlx_local_ai_hub.mp4\",\n", 94 | " \"max_pixels\": 360 * 360,\n", 95 | " \"fps\": 1.0,\n", 96 | " },\n", 97 | " {\"type\": \"text\", \"text\": \"Describe this video.\"},\n", 98 | " ],\n", 99 | " }\n", 100 | "]\n", 101 | "\n", 102 | "# Preparation for inference\n", 103 | "text = processor.apply_chat_template(\n", 104 | " messages, tokenize=False, add_generation_prompt=True\n", 105 | ")\n", 106 | "image_inputs, video_inputs = process_vision_info(messages)\n", 107 | "inputs = processor(\n", 108 | " text=[text],\n", 109 | " images=image_inputs,\n", 110 | " videos=video_inputs,\n", 111 | " padding=True,\n", 112 | " return_tensors=\"pt\",\n", 113 | ")\n" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 4, 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "# Convert inputs to mlx arrays\n", 123 | "input_ids = mx.array(inputs['input_ids'])\n", 124 | "pixel_values = mx.array(inputs['pixel_values_videos'])\n", 125 | "mask = mx.array(inputs['attention_mask'])\n", 126 | "image_grid_thw = mx.array(inputs['video_grid_thw'])\n", 127 | "\n", 128 | "kwargs = {\n", 129 | " \"image_grid_thw\": image_grid_thw,\n", 130 | "}" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 9, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "kwargs[\"video\"] = \"videos/fastmlx_local_ai_hub.mp4\"\n", 140 | "kwargs[\"input_ids\"] = input_ids\n", 141 | "kwargs[\"pixel_values\"] = pixel_values\n", 142 | "kwargs[\"mask\"] = mask\n", 143 | "response = generate(model, processor, prompt=text, temperature=0.7, max_tokens=100, **kwargs)" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 10, 149 | "metadata": {}, 150 | "outputs": [ 151 | { 152 | "name": "stdout", 153 | "output_type": "stream", 154 | "text": [ 155 | "('The video appears to be a live stream or a recording of a coding session, '\n", 156 | " 'likely on a platform like Discord, as indicated by the presence of text '\n", 157 | " \"chats and a streamer's interface. The video is primarily focused on a \"\n", 158 | " 'computer screen displaying a code editor with various programming languages '\n", 159 | " 'and snippets of code. 
The coder seems to be explaining or demonstrating '\n", 160 | " 'something related to the code, possibly working through a programming '\n", 161 | " 'problem, explaining the logic, or showing the process of solving a problem.\\n'\n", 162 | " '\\n'\n", 163 | " 'Here are some key observations from')\n" 164 | ] 165 | } 166 | ], 167 | "source": [ 168 | "pprint(response)\n" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": null, 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "# open video and play it\n", 178 | "from ipywidgets import Video\n", 179 | "Video.from_file(\"videos/fastmlx_local_ai_hub.mp4\", width=320, height=240)" 180 | ] 181 | } 182 | ], 183 | "metadata": { 184 | "kernelspec": { 185 | "display_name": "mlx_code", 186 | "language": "python", 187 | "name": "python3" 188 | }, 189 | "language_info": { 190 | "codemirror_mode": { 191 | "name": "ipython", 192 | "version": 3 193 | }, 194 | "file_extension": ".py", 195 | "mimetype": "text/x-python", 196 | "name": "python", 197 | "nbconvert_exporter": "python", 198 | "pygments_lexer": "ipython3", 199 | "version": "3.11.11" 200 | } 201 | }, 202 | "nbformat": 4, 203 | "nbformat_minor": 2 204 | } 205 | -------------------------------------------------------------------------------- /examples/videos/fastmlx_local_ai_hub.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/examples/videos/fastmlx_local_ai_hub.mp4 -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: MLX-VLM 2 | site_description: MLX-VLM is a package for inference and fine-tuning of Vision Language Models (VLMs) on your Mac using MLX. 
3 | site_author: Prince Canuma 4 | repo_name: Blaizzy/mlx-vlm 5 | site_url: https://Blaizzy.github.io/mlx-vlm 6 | repo_url: https://github.com/Blaizzy/mlx-vlm 7 | 8 | copyright: "Copyright \u00a9 2024 - 2024 Prince Canuma" 9 | 10 | theme: 11 | palette: 12 | - scheme: default 13 | primary: black 14 | toggle: 15 | icon: material/toggle-switch-off-outline 16 | name: Switch to dark mode 17 | - scheme: slate 18 | primary: black 19 | accent: indigo 20 | toggle: 21 | icon: material/toggle-switch 22 | name: Switch to light mode 23 | name: material 24 | icon: 25 | repo: fontawesome/brands/github 26 | features: 27 | - navigation.instant 28 | - navigation.tracking 29 | - navigation.top 30 | - navigation.footer 31 | - search.highlight 32 | - search.share 33 | - content.code.copy 34 | custom_dir: "docs/overrides" 35 | font: 36 | text: Google Sans 37 | code: Regular 38 | 39 | plugins: 40 | - search 41 | - mkdocstrings 42 | - mkdocs-jupyter: 43 | include_source: True 44 | ignore_h1_titles: True 45 | execute: True 46 | allow_errors: false 47 | ignore: ["conf.py"] 48 | execute_ignore: ["*ignore.ipynb"] 49 | 50 | markdown_extensions: 51 | - admonition 52 | - abbr 53 | - attr_list 54 | - def_list 55 | - footnotes 56 | - meta 57 | - md_in_html 58 | - pymdownx.superfences 59 | - pymdownx.highlight: 60 | linenums: true 61 | - toc: 62 | permalink: true 63 | 64 | extra: 65 | social: 66 | - icon: fontawesome/brands/github 67 | link: https://github.com/Blaizzy 68 | - icon: fontawesome/brands/twitter 69 | link: https://twitter.com/Prince_Canuma 70 | version: 71 | provider: mike 72 | consent: 73 | title: Cookie consent 74 | description: >- 75 | We use cookies to recognize your repeated visits and preferences, as well 76 | as to measure the effectiveness of our documentation and whether users 77 | find what they're searching for. With your consent, you're helping us to 78 | make our documentation better. 79 | 80 | extra_css: 81 | - stylesheets/extra.css 82 | 83 | nav: 84 | - Home: index.md 85 | - Installation: installation.md 86 | - CLI Reference: cli_reference.md 87 | - Examples: examples.md 88 | - Contributing: contributing.md 89 | - Community Projects: community_projects.md 90 | - Report Issues: report_issues.md 91 | - Changelog: changelog.md 92 | 93 | docs_dir: docs 94 | -------------------------------------------------------------------------------- /mlx_vlm/LORA.MD: -------------------------------------------------------------------------------- 1 | # LoRA Training Script 2 | 3 | ## Overview 4 | 5 | `lora.py` is a Python script for fine-tuning a vision language models (VLMs) using Low-Rank Adaptation (LoRA or QLoRA). This script allows you to train the model on your custom dataset, adjusting various parameters through command-line arguments. 6 | 7 | ## Requirements 8 | 9 | - Python 3.7+ 10 | - Required Python packages: `mlx-vlm`, `numpy`, `transformers`, `datasets`, `PIL` 11 | 12 | ## Supported Models 13 | - Qwen2 14 | - LLaVA (except for LLaVA-Next) 15 | - Pixtral 16 | - Idefics 2 17 | - Deepseek-VL 18 | - Paligemma 19 | - Mllama (Llama-3.2-vision) 20 | 21 | ## Coming Soon 22 | - LLaVA-Next 23 | - Phi3_vision 24 | 25 | ## Usage 26 | 27 | To use the script, run it from the command line with the desired arguments: 28 | 29 | ``` 30 | python lora.py --dataset /path/to/your/dataset [other options] 31 | ``` 32 | 33 | ## Dataset format 34 | 35 | The dataset should be a Hugging Face dataset with a `images` column and a `messages` column. 
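For reference, a minimal sketch (not part of this repo) of assembling such a dataset with the Hugging Face `datasets` library; the image path, message text, and Hub repo id below are illustrative placeholders:

```python
from datasets import Dataset

# One record per training sample: a list of images plus a chat-style message list.
records = [
    {
        "images": ["path/to/example_image.jpg"],  # placeholder path
        "messages": [
            {"role": "user", "content": "What is in this image?"},
            {"role": "assistant", "content": "A short description of the image."},
        ],
    }
]

dataset = Dataset.from_list(records)
dataset.push_to_hub("<username>/my-vlm-dataset")  # placeholder repo id; pass it to --dataset
```

Each record therefore has the shape: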
36 | 37 | ``` 38 | { 39 | "images": ..., 40 | "messages": ..., 41 | } 42 | ``` 43 | 44 | Support for other formats and column names will be added soon. 45 | 46 | ## Arguments 47 | 48 | The script accepts the following command-line arguments: 49 | 50 | - `--model-path`: Path to the pre-trained model (default: "mlx-community/Qwen2-VL-2B-Instruct-bf16") 51 | - `--dataset`: Path to your dataset (required) 52 | - `--learning-rate`: Learning rate for the optimizer (default: 1e-4) 53 | - `--batch-size`: Batch size for training (default: 1) 54 | - `--epochs`: Number of epochs to train (default: 1) 55 | - `--steps`: Number of steps per epoch (default: 0) 56 | - `--print-every`: Print loss every n steps (default: 10) 57 | - `--adapter-path`: Load path to resume training from a previously saved adapter (default: None) 58 | - `--output-path`: Path to save the trained adapter (default: "adapters.safetensors") 59 | 60 | ## Example 61 | 62 | Here's an example of how to run the script with custom parameters: 63 | 64 | ``` 65 | python lora.py --dataset /path/to/your/dataset --model-path /path/to/your/model --epochs 2 --batch-size 4 --learning-rate 5e-5 66 | ``` 67 | 68 | ## Output 69 | 70 | The script will print the training loss at regular intervals (defined by `--print-every`). After training, it will save the LoRA adapter to the specified output path. 71 | 72 | ## Note 73 | 74 | If you want to use QLoRA, you need to pass a pre-quantized model to the script using the `--model-path` argument (i.e. `mlx-community/Qwen2-VL-2B-Instruct-4bit`). 75 | Make sure you have the necessary permissions to read the dataset and write the output file. Also, ensure that your system has sufficient computational resources to handle the specified batch size and model. 76 | 77 | ## Contributing 78 | 79 | Feel free to submit issues or pull requests if you find any bugs or have suggestions for improvements. 80 | -------------------------------------------------------------------------------- /mlx_vlm/__init__.py: -------------------------------------------------------------------------------- 1 | from .prompt_utils import apply_chat_template, get_message_json 2 | from .utils import ( 3 | GenerationResult, 4 | convert, 5 | generate, 6 | load, 7 | prepare_inputs, 8 | process_image, 9 | quantize_model, 10 | stream_generate, 11 | ) 12 | from .version import __version__ 13 | -------------------------------------------------------------------------------- /mlx_vlm/chat_ui.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import gradio as gr 4 | 5 | from mlx_vlm import load 6 | 7 | from .prompt_utils import get_chat_template, get_message_json 8 | from .utils import load, load_config, load_image_processor, stream_generate 9 | 10 | 11 | def parse_arguments(): 12 | parser = argparse.ArgumentParser( 13 | description="Generate text from an image using a model." 
14 | ) 15 | parser.add_argument( 16 | "--model", 17 | type=str, 18 | default="qnguyen3/nanoLLaVA", 19 | help="The path to the local model directory or Hugging Face repo.", 20 | ) 21 | return parser.parse_args() 22 | 23 | 24 | args = parse_arguments() 25 | config = load_config(args.model) 26 | model, processor = load(args.model, processor_kwargs={"trust_remote_code": True}) 27 | image_processor = load_image_processor(args.model) 28 | 29 | 30 | def chat(message, history, temperature, max_tokens): 31 | if config["model_type"] != "paligemma": 32 | chat_history = [] 33 | for item in history: 34 | if isinstance(item[0], str): 35 | chat_history.append({"role": "user", "content": item[0]}) 36 | if item[1] is not None: 37 | chat_history.append({"role": "assistant", "content": item[1]}) 38 | 39 | chat_history.append({"role": "user", "content": message["text"]}) 40 | 41 | messages = [] 42 | for i, m in enumerate(chat_history): 43 | skip_token = True 44 | if i == len(chat_history) - 1 and m["role"] == "user": 45 | skip_token = False 46 | messages.append( 47 | get_message_json( 48 | config["model_type"], 49 | m["content"], 50 | role=m["role"], 51 | skip_image_token=skip_token, 52 | ) 53 | ) 54 | 55 | messages = get_chat_template(processor, messages, add_generation_prompt=True) 56 | 57 | else: 58 | messages = message["text"] 59 | 60 | files = "" 61 | if "files" in message and len(message["files"]) > 0: 62 | files = message["files"][-1] 63 | 64 | response = "" 65 | for chunk in stream_generate( 66 | model, 67 | processor, 68 | messages, 69 | files, 70 | max_tokens=max_tokens, 71 | temperature=temperature, 72 | ): 73 | response += chunk.text 74 | yield response 75 | 76 | 77 | demo = gr.ChatInterface( 78 | fn=chat, 79 | title="MLX-VLM Chat UI", 80 | additional_inputs_accordion=gr.Accordion( 81 | label="⚙️ Parameters", open=False, render=False 82 | ), 83 | additional_inputs=[ 84 | gr.Slider( 85 | minimum=0, maximum=1, step=0.1, value=0.1, label="Temperature", render=False 86 | ), 87 | gr.Slider( 88 | minimum=128, 89 | maximum=4096, 90 | step=1, 91 | value=200, 92 | label="Max new tokens", 93 | render=False, 94 | ), 95 | ], 96 | description=f"Now Running {args.model}", 97 | multimodal=True, 98 | ) 99 | 100 | demo.launch(inbrowser=True) 101 | -------------------------------------------------------------------------------- /mlx_vlm/convert.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 2 | 3 | import argparse 4 | 5 | from .utils import MODEL_CONVERSION_DTYPES, convert 6 | 7 | 8 | def configure_parser() -> argparse.ArgumentParser: 9 | """ 10 | Configures and returns the argument parser for the script. 11 | 12 | Returns: 13 | argparse.ArgumentParser: Configured argument parser. 14 | """ 15 | parser = argparse.ArgumentParser( 16 | description="Convert Hugging Face model to MLX format" 17 | ) 18 | 19 | parser.add_argument("--hf-path", type=str, help="Path to the Hugging Face model.") 20 | parser.add_argument( 21 | "--mlx-path", type=str, default="mlx_model", help="Path to save the MLX model." 22 | ) 23 | parser.add_argument( 24 | "-q", "--quantize", help="Generate a quantized model.", action="store_true" 25 | ) 26 | parser.add_argument( 27 | "--q-group-size", help="Group size for quantization.", type=int, default=64 28 | ) 29 | parser.add_argument( 30 | "--q-bits", help="Bits per weight for quantization.", type=int, default=4 31 | ) 32 | parser.add_argument( 33 | "--dtype", 34 | help="Type to save the parameter. 
Defaults to config.json's `torch_dtype` or the current model weights dtype", 35 | type=str, 36 | choices=MODEL_CONVERSION_DTYPES, 37 | default=None, 38 | ) 39 | parser.add_argument( 40 | "--upload-repo", 41 | help="The Hugging Face repo to upload the model to.", 42 | type=str, 43 | default=None, 44 | ) 45 | parser.add_argument( 46 | "-d", 47 | "--dequantize", 48 | help="Dequantize a quantized model.", 49 | action="store_true", 50 | default=False, 51 | ) 52 | parser.add_argument( 53 | "--skip-vision", 54 | help="Skip vision module quantization.", 55 | action="store_true", 56 | default=False, 57 | ) 58 | return parser 59 | 60 | 61 | def main(): 62 | parser = configure_parser() 63 | args = parser.parse_args() 64 | convert(**vars(args)) 65 | 66 | 67 | if __name__ == "__main__": 68 | main() 69 | -------------------------------------------------------------------------------- /mlx_vlm/generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import codecs 3 | 4 | from .prompt_utils import apply_chat_template 5 | from .utils import ( 6 | generate, 7 | get_model_path, 8 | load, 9 | load_config, 10 | load_image_processor, 11 | stream_generate, 12 | ) 13 | 14 | DEFAULT_MODEL_PATH = "mlx-community/nanoLLaVA-1.5-8bit" 15 | DEFAULT_IMAGE = [] 16 | DEFAULT_PROMPT = "What are these?" 17 | DEFAULT_MAX_TOKENS = 256 18 | DEFAULT_TEMPERATURE = 0.5 19 | DEFAULT_TOP_P = 1.0 20 | DEFAULT_SEED = 0 21 | 22 | 23 | def parse_arguments(): 24 | parser = argparse.ArgumentParser( 25 | description="Generate text from an image using a model." 26 | ) 27 | parser.add_argument( 28 | "--model", 29 | type=str, 30 | default=DEFAULT_MODEL_PATH, 31 | help="The path to the local model directory or Hugging Face repo.", 32 | ) 33 | parser.add_argument( 34 | "--adapter-path", 35 | type=str, 36 | default=None, 37 | help="The path to the adapter weights.", 38 | ) 39 | parser.add_argument( 40 | "--image", 41 | type=str, 42 | nargs="+", 43 | default=DEFAULT_IMAGE, 44 | help="URL or path of the image to process.", 45 | ) 46 | parser.add_argument( 47 | "--resize-shape", 48 | type=int, 49 | nargs="+", 50 | default=None, 51 | help="Resize shape for the image.", 52 | ) 53 | parser.add_argument( 54 | "--prompt", 55 | type=str, 56 | default=DEFAULT_PROMPT, 57 | help="Message to be processed by the model.", 58 | ) 59 | parser.add_argument( 60 | "--system", 61 | type=str, 62 | default=None, 63 | help="System message for the model.", 64 | ) 65 | parser.add_argument( 66 | "--max-tokens", 67 | type=int, 68 | default=DEFAULT_MAX_TOKENS, 69 | help="Maximum number of tokens to generate.", 70 | ) 71 | parser.add_argument( 72 | "--temperature", 73 | type=float, 74 | default=DEFAULT_TEMPERATURE, 75 | help="Temperature for sampling.", 76 | ) 77 | parser.add_argument("--chat", action="store_true", help="Chat in multi-turn style.") 78 | parser.add_argument("--verbose", action="store_false", help="Detailed output.") 79 | parser.add_argument( 80 | "--eos-tokens", 81 | type=str, 82 | nargs="+", 83 | default=None, 84 | help="EOS tokens to add to the tokenizer.", 85 | ) 86 | parser.add_argument( 87 | "--skip-special-tokens", 88 | action="store_true", 89 | help="Skip special tokens in the detokenizer.", 90 | ) 91 | 92 | return parser.parse_args() 93 | 94 | 95 | def get_model_and_processors(model_path, adapter_path): 96 | model_path = get_model_path(model_path) 97 | config = load_config(model_path, trust_remote_code=True) 98 | model, processor = load( 99 | model_path, adapter_path=adapter_path, 
lazy=False, trust_remote_code=True 100 | ) 101 | return model, processor, config 102 | 103 | 104 | def main(): 105 | args = parse_arguments() 106 | if isinstance(args.image, str): 107 | args.image = [args.image] 108 | 109 | model, processor, config = get_model_and_processors(args.model, args.adapter_path) 110 | 111 | prompt = codecs.decode(args.prompt, "unicode_escape") 112 | 113 | prompt = apply_chat_template(processor, config, prompt, num_images=len(args.image)) 114 | 115 | kwargs = {} 116 | 117 | if args.resize_shape is not None: 118 | if len(args.resize_shape) not in [1, 2]: 119 | raise ValueError("Resize shape must be 1 or 2 integers") 120 | kwargs["resize_shape"] = ( 121 | (args.resize_shape[0],) * 2 122 | if len(args.resize_shape) == 1 123 | else tuple(args.resize_shape) 124 | ) 125 | 126 | if args.eos_tokens is not None: 127 | kwargs["eos_tokens"] = [ 128 | codecs.decode(token, "unicode_escape") for token in args.eos_tokens 129 | ] 130 | 131 | if args.skip_special_tokens: 132 | kwargs["skip_special_tokens"] = args.skip_special_tokens 133 | 134 | if args.chat: 135 | chat = [] 136 | if args.system: 137 | chat.append({"role": "system", "content": args.system}) 138 | while user := input("User:"): 139 | chat.append({"role": "user", "content": user}) 140 | prompt = apply_chat_template( 141 | processor, config, chat, num_images=len(args.image) 142 | ) 143 | response = "" 144 | print("Assistant:", end="") 145 | for chunk in stream_generate( 146 | model, 147 | processor, 148 | prompt, 149 | args.image, 150 | max_tokens=args.max_tokens, 151 | temperature=args.temperature, 152 | **kwargs, 153 | ): 154 | response += chunk.text 155 | print(chunk.text, end="") 156 | 157 | chat.append({"role": "assistant", "content": response}) 158 | print() 159 | 160 | else: 161 | output = generate( 162 | model, 163 | processor, 164 | prompt, 165 | image=args.image, 166 | temperature=args.temperature, 167 | max_tokens=args.max_tokens, 168 | verbose=args.verbose, 169 | **kwargs, 170 | ) 171 | if not args.verbose: 172 | print(output) 173 | 174 | 175 | if __name__ == "__main__": 176 | main() 177 | -------------------------------------------------------------------------------- /mlx_vlm/lora.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | 5 | import mlx.optimizers as optim 6 | from datasets import load_dataset 7 | from tqdm import tqdm 8 | 9 | from .prompt_utils import apply_chat_template 10 | from .trainer import Dataset, Trainer, save_adapter 11 | from .trainer.utils import apply_lora_layers, find_all_linear_names, get_peft_model 12 | from .utils import load, load_image_processor 13 | 14 | logging.basicConfig(level=logging.INFO) 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | def custom_print(*args, **kwargs): 19 | tqdm.write(" ".join(map(str, args)), **kwargs) 20 | 21 | 22 | def main(args): 23 | logger.info(f"\033[32mLoading model from {args.model_path}\033[0m") 24 | model, processor = load( 25 | args.model_path, processor_config={"trust_remote_code": True} 26 | ) 27 | config = model.config.__dict__ 28 | image_processor = load_image_processor(args.model_path) 29 | 30 | logger.info(f"\033[32mLoading dataset from {args.dataset}\033[0m") 31 | dataset = load_dataset(args.dataset, split=args.split) 32 | 33 | if "messages" not in dataset.column_names: 34 | raise ValueError("Dataset must have a 'messages' column") 35 | if "images" not in dataset.column_names: 36 | raise ValueError("Dataset must have an 'images' 
column") 37 | 38 | if args.apply_chat_template: 39 | logger.info(f"\033[32mApplying chat template to the dataset\033[0m") 40 | 41 | def process_data(examples): 42 | if config["model_type"] == "pixtral": 43 | conversations = apply_chat_template( 44 | config=config, 45 | processor=processor, 46 | prompt=examples["messages"], 47 | return_messages=True, 48 | ) 49 | examples["messages"] = [ 50 | json.dumps(item, ensure_ascii=False) for item in conversations 51 | ] 52 | else: 53 | examples["messages"] = apply_chat_template( 54 | config=config, 55 | processor=processor, 56 | prompt=examples["messages"], 57 | return_messages=True, 58 | ) 59 | return examples 60 | 61 | dataset = dataset.map(process_data) 62 | 63 | dataset = Dataset( 64 | dataset, 65 | config, 66 | processor, 67 | image_processor=image_processor, 68 | image_resize_shape=args.image_resize_shape, 69 | ) 70 | 71 | adapter_path = args.adapter_path 72 | if adapter_path: 73 | logger.info(f"\033[32mResuming from adapter path {adapter_path}\033[0m") 74 | logger.info( 75 | f"\033[32mLora rank, alpha, and dropout will be loaded from adapter_config.json file\033[0m" 76 | ) 77 | 78 | model = apply_lora_layers(model, adapter_path) 79 | 80 | else: 81 | logger.info(f"\033[32mSetting up LoRA\033[0m") 82 | 83 | list_of_modules = find_all_linear_names(model.language_model) 84 | model = get_peft_model( 85 | model, 86 | list_of_modules, 87 | rank=args.lora_rank, 88 | alpha=args.lora_alpha, 89 | dropout=args.lora_dropout, 90 | ) 91 | 92 | logger.info(f"\033[32mSetting up optimizer\033[0m") 93 | optimizer = optim.Adam(learning_rate=args.learning_rate) 94 | 95 | logger.info(f"\033[32mSetting up trainer\033[0m") 96 | trainer = Trainer(model, optimizer) 97 | 98 | model.train() 99 | 100 | # Training loop 101 | logger.info(f"\033[32mTraining model\033[0m") 102 | for epoch in range(args.epochs): 103 | if args.steps == 0: 104 | args.steps = len(dataset) // args.batch_size 105 | 106 | progress_bar = tqdm(range(args.steps), position=0, leave=True) 107 | for i in progress_bar: 108 | loss = trainer.train_step( 109 | dataset[i * args.batch_size : (i + 1) * args.batch_size] 110 | ) 111 | # Update progress bar 112 | progress_bar.update(1) 113 | progress_bar.set_postfix( 114 | {"Epoch": epoch, "Step": i, "Loss": f"{loss.item():.4f}"} 115 | ) 116 | 117 | if i % args.print_every == 0: 118 | # Log additional information 119 | custom_print( 120 | { 121 | "Epoch": epoch, 122 | "Step": i, 123 | "Loss": f"{loss.item():.4f}", 124 | } 125 | ) 126 | 127 | # Save the adapter 128 | save_adapter(model, args.output_path) 129 | 130 | 131 | if __name__ == "__main__": 132 | parser = argparse.ArgumentParser(description="Train NanoLLaVA model") 133 | parser.add_argument( 134 | "--model-path", 135 | type=str, 136 | default="mlx-community/Qwen2-VL-2B-Instruct-bf16", 137 | help="Path to the pre-trained model", 138 | ) 139 | parser.add_argument( 140 | "--dataset", type=str, required=True, help="Path to the dataset" 141 | ) 142 | parser.add_argument( 143 | "--split", type=str, default="train", help="Split to use for training" 144 | ) 145 | parser.add_argument( 146 | "--image-resize-shape", 147 | type=int, 148 | nargs=2, 149 | default=None, 150 | help="Resize images to this shape", 151 | ) 152 | parser.add_argument( 153 | "--apply-chat-template", 154 | action="store_false", 155 | help="Apply chat template to the dataset", 156 | ) 157 | parser.add_argument( 158 | "--learning-rate", 159 | type=float, 160 | default=1e-4, 161 | help="Learning rate for the optimizer", 162 | ) 163 | 
parser.add_argument( 164 | "--batch-size", type=int, default=1, help="Batch size for training" 165 | ) 166 | parser.add_argument( 167 | "--epochs", type=int, default=1, help="Number of epochs to train" 168 | ) 169 | parser.add_argument( 170 | "--steps", type=int, default=0, help="Number of steps per epoch" 171 | ) 172 | parser.add_argument( 173 | "--print-every", type=int, default=10, help="Print loss every n steps" 174 | ) 175 | parser.add_argument( 176 | "--lora-alpha", type=int, default=0.1, help="LoRA alpha parameter" 177 | ) 178 | parser.add_argument("--lora-rank", type=int, default=10, help="LoRA rank") 179 | parser.add_argument("--lora-dropout", type=float, default=0.1, help="LoRA dropout") 180 | parser.add_argument( 181 | "--output-path", 182 | type=str, 183 | default="adapters", 184 | help="Path to save the trained adapter", 185 | ) 186 | parser.add_argument( 187 | "--adapter-path", 188 | type=str, 189 | default=None, 190 | help="Load path to resume training from a previously saved adapter", 191 | ) 192 | 193 | args = parser.parse_args() 194 | main(args) 195 | -------------------------------------------------------------------------------- /mlx_vlm/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/mlx_vlm/models/__init__.py -------------------------------------------------------------------------------- /mlx_vlm/models/aya_vision/__init__.py: -------------------------------------------------------------------------------- 1 | from .aya_vision import ( 2 | LanguageModel, 3 | Model, 4 | ModelConfig, 5 | TextConfig, 6 | VisionConfig, 7 | VisionModel, 8 | ) 9 | -------------------------------------------------------------------------------- /mlx_vlm/models/aya_vision/interpolate.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | import numpy as np 3 | 4 | 5 | def gaussian_blur_axis(image, sigma, axis): 6 | """ 7 | Applies a 1D Gaussian blur along the given axis. 8 | This version works for arrays with any number of dimensions. 9 | """ 10 | radius = int(3 * sigma) 11 | if radius < 1: 12 | return image 13 | x = mx.arange(-radius, radius + 1) 14 | kernel = mx.exp(-(x**2) / (2 * sigma**2)) 15 | kernel = kernel / mx.sum(kernel) 16 | 17 | # MLX doesn't have a direct apply_along_axis equivalent, 18 | # so we'll implement the convolution differently based on the axis 19 | 20 | # Helper function to apply 1D convolution along specific axis 21 | def conv_1d(array, kernel, axis): 22 | # Reshape kernel to broadcast along the right dimensions 23 | kernel_shape = [1] * image.ndim 24 | kernel_shape[axis] = len(kernel) 25 | kernel_reshaped = kernel.reshape(kernel_shape) 26 | 27 | # Pad the array 28 | pad_width = [(0, 0)] * image.ndim 29 | pad_width[axis] = (radius, radius) 30 | padded = mx.pad(array, pad_width, mode="edge") 31 | 32 | # Perform convolution via sliding window sum 33 | result = mx.zeros_like(array) 34 | slices = [slice(None)] * padded.ndim 35 | 36 | for i in range(2 * radius + 1): 37 | slices[axis] = slice(i, i + array.shape[axis]) 38 | result = result + padded[tuple(slices)] * kernel_reshaped 39 | 40 | return result 41 | 42 | return conv_1d(image, kernel, axis) 43 | 44 | 45 | def bilinear_interpolate(image, new_height, new_width, align_corners=False): 46 | """ 47 | Performs bilinear interpolation on an array whose spatial dimensions are the first two. 
48 | It supports extra dimensions (e.g. channels or batch dimensions that have been moved to the trailing axes). 49 | """ 50 | # image is assumed to have shape (H, W, ...) where H and W are spatial dimensions. 51 | H_in, W_in = image.shape[0], image.shape[1] 52 | 53 | # Compute sampling positions in the input image. 54 | if new_height == 1: 55 | row_positions = mx.array([0.0]) 56 | else: 57 | if align_corners: 58 | row_positions = mx.linspace(0, H_in - 1, new_height) 59 | else: 60 | row_positions = (mx.arange(new_height) + 0.5) * H_in / new_height - 0.5 61 | 62 | if new_width == 1: 63 | col_positions = mx.array([0.0]) 64 | else: 65 | if align_corners: 66 | col_positions = mx.linspace(0, W_in - 1, new_width) 67 | else: 68 | col_positions = (mx.arange(new_width) + 0.5) * W_in / new_width - 0.5 69 | 70 | # Compute floor and ceil indices. 71 | row_floor = mx.floor(row_positions).astype(mx.int32) 72 | col_floor = mx.floor(col_positions).astype(mx.int32) 73 | row_ceil = row_floor + 1 74 | col_ceil = col_floor + 1 75 | 76 | row_floor = mx.clip(row_floor, 0, H_in - 1) 77 | row_ceil = mx.clip(row_ceil, 0, H_in - 1) 78 | col_floor = mx.clip(col_floor, 0, W_in - 1) 79 | col_ceil = mx.clip(col_ceil, 0, W_in - 1) 80 | 81 | row_weight = row_positions - row_floor # shape (new_height,) 82 | col_weight = col_positions - col_floor # shape (new_width,) 83 | 84 | # Use advanced indexing for gather operations 85 | # Create meshgrid for coordinates 86 | row_floor_grid, col_floor_grid = mx.meshgrid(row_floor, col_floor, indexing="ij") 87 | row_ceil_grid, col_floor_grid = mx.meshgrid(row_ceil, col_floor, indexing="ij") 88 | row_floor_grid, col_ceil_grid = mx.meshgrid(row_floor, col_ceil, indexing="ij") 89 | row_ceil_grid, col_ceil_grid = mx.meshgrid(row_ceil, col_ceil, indexing="ij") 90 | 91 | # Gather the four surrounding pixels using take_along_axis 92 | # For higher dimensional arrays, we'll need to reshape and broadcast 93 | extra_dims = image.ndim - 2 94 | 95 | def gather_pixels(row_indices, col_indices): 96 | # Flatten the spatial dimensions for gathering 97 | flat_indices = row_indices * W_in + col_indices 98 | flat_image = mx.reshape(image, (-1,) + image.shape[2:]) 99 | # Gather and reshape back 100 | gathered = mx.take(flat_image, flat_indices.reshape(-1), axis=0) 101 | return mx.reshape(gathered, (new_height, new_width) + image.shape[2:]) 102 | 103 | top_left = gather_pixels(row_floor_grid, col_floor_grid) 104 | top_right = gather_pixels(row_floor_grid, col_ceil_grid) 105 | bottom_left = gather_pixels(row_ceil_grid, col_floor_grid) 106 | bottom_right = gather_pixels(row_ceil_grid, col_ceil_grid) 107 | 108 | # Expand the weights to have shape (new_height, new_width, *[1]*extra_dims) 109 | r_weight = row_weight.reshape(new_height, 1, *([1] * extra_dims)) 110 | c_weight = col_weight.reshape(1, new_width, *([1] * extra_dims)) 111 | 112 | # Perform bilinear interpolation. 113 | result = ( 114 | (1 - r_weight) * (1 - c_weight) * top_left 115 | + (1 - r_weight) * c_weight * top_right 116 | + r_weight * (1 - c_weight) * bottom_left 117 | + r_weight * c_weight * bottom_right 118 | ) 119 | return result 120 | 121 | 122 | def resize_bilinear(image, new_size, align_corners=False, antialias=True): 123 | """ 124 | Resizes an image (or embedding tensor) to new_size=(new_height, new_width) 125 | using bilinear interpolation with MLX. 
126 | 127 | Supports: 128 | - 2D: (H, W) 129 | - 3D: (H, W, C) 130 | - 4D: (B, C, H, W) (assumed for typical image batches) 131 | """ 132 | new_height, new_width = new_size 133 | 134 | # Convert numpy arrays to MLX arrays if needed 135 | if isinstance(image, np.ndarray): 136 | image = mx.array(image) 137 | 138 | if image.ndim == 2 or image.ndim == 3: 139 | # Assume spatial dims are the first two. 140 | resized = image 141 | H_in, W_in = image.shape[:2] 142 | if antialias: 143 | if new_height < H_in: 144 | scale_y = new_height / H_in 145 | sigma_y = (1 / scale_y - 1) / 2.0 # heuristic 146 | if sigma_y > 0: 147 | resized = gaussian_blur_axis(resized, sigma_y, axis=0) 148 | if new_width < W_in: 149 | scale_x = new_width / W_in 150 | sigma_x = (1 / scale_x - 1) / 2.0 151 | if sigma_x > 0: 152 | resized = gaussian_blur_axis(resized, sigma_x, axis=1) 153 | resized = bilinear_interpolate( 154 | resized, new_height, new_width, align_corners=align_corners 155 | ) 156 | return resized 157 | 158 | elif image.ndim == 4: 159 | # Assume shape is (B, C, H, W) (typical PyTorch/MLX format). 160 | B, C, H_in, W_in = image.shape 161 | # Permute to bring spatial dims to the front: (H, W, B, C) 162 | image_perm = mx.transpose(image, (2, 3, 0, 1)) 163 | resized = image_perm 164 | if antialias: 165 | if new_height < H_in: 166 | scale_y = new_height / H_in 167 | sigma_y = (1 / scale_y - 1) / 2.0 168 | if sigma_y > 0: 169 | resized = gaussian_blur_axis(resized, sigma_y, axis=0) 170 | if new_width < W_in: 171 | scale_x = new_width / W_in 172 | sigma_x = (1 / scale_x - 1) / 2.0 173 | if sigma_x > 0: 174 | resized = gaussian_blur_axis(resized, sigma_x, axis=1) 175 | resized = bilinear_interpolate( 176 | resized, new_height, new_width, align_corners=align_corners 177 | ) 178 | # Permute back to (B, C, new_height, new_width) 179 | resized = mx.transpose(resized, (2, 3, 0, 1)) 180 | return resized 181 | 182 | else: 183 | raise ValueError("Unsupported image dimensions.") 184 | 185 | 186 | # 187 | -------------------------------------------------------------------------------- /mlx_vlm/models/base.py: -------------------------------------------------------------------------------- 1 | import math 2 | from abc import ABC, abstractmethod 3 | from dataclasses import dataclass 4 | from typing import Any, Dict, List, Optional 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | from mlx_lm.models.base import create_attention_mask 9 | from mlx_lm.models.cache import RotatingKVCache 10 | from PIL import Image 11 | from transformers.image_processing_utils import BaseImageProcessor as ImageProcessor 12 | from transformers.image_processing_utils import get_size_dict 13 | from transformers.image_utils import ChannelDimension, PILImageResampling 14 | 15 | 16 | @dataclass 17 | class LanguageModelOutput: 18 | logits: mx.array 19 | cross_attention_states: Optional[List[mx.array]] = None 20 | encoder_outputs: Optional[List[mx.array]] = None 21 | 22 | 23 | def expand2square(pil_img, background_color): 24 | width, height = pil_img.size 25 | if width == height: 26 | return pil_img 27 | elif width > height: 28 | result = Image.new(pil_img.mode, (width, width), background_color) 29 | result.paste(pil_img, (0, (width - height) // 2)) 30 | return result 31 | else: 32 | result = Image.new(pil_img.mode, (height, height), background_color) 33 | result.paste(pil_img, ((height - width) // 2, 0)) 34 | return result 35 | 36 | 37 | class BaseImageProcessor(ImageProcessor): 38 | def __init__( 39 | self, 40 | image_mean=(0.5, 0.5, 0.5), 41 | 
image_std=(0.5, 0.5, 0.5), 42 | size=(384, 384), 43 | crop_size: Dict[str, int] = None, 44 | resample=PILImageResampling.BICUBIC, 45 | rescale_factor=1 / 255, 46 | data_format=ChannelDimension.FIRST, 47 | ): 48 | crop_size = ( 49 | crop_size if crop_size is not None else {"height": 384, "width": 384} 50 | ) 51 | crop_size = get_size_dict( 52 | crop_size, default_to_square=True, param_name="crop_size" 53 | ) 54 | 55 | self.image_mean = image_mean 56 | self.image_std = image_std 57 | self.size = size 58 | self.resample = resample 59 | self.rescale_factor = rescale_factor 60 | self.data_format = data_format 61 | self.crop_size = crop_size 62 | 63 | @abstractmethod 64 | def preprocess(self, images): 65 | pass 66 | 67 | 68 | # Add this code to visualize the chunked attention mask 69 | def visualize_attention_mask(mask): 70 | """Visualize attention mask with symbols for better readability.""" 71 | if mask is None: 72 | print("No mask") 73 | return 74 | 75 | seq_len = mask.shape[0] 76 | 77 | print(" ", end="") 78 | for i in range(seq_len): 79 | print(f"{i:2d} ", end="") 80 | print() 81 | 82 | for i in range(seq_len): 83 | print(f"Token {i:2d}: ", end="") 84 | for j in range(seq_len): 85 | if mask[i, j]: 86 | print(" ■ ", end="") 87 | else: 88 | print(" ⬚ ", end="") 89 | print() 90 | 91 | 92 | def check_activation_stats(name, tensor): 93 | """Helper function to check for anomalies and log stats.""" 94 | 95 | print(f"--- Activation Stats: {name} ---") 96 | # Check for NaNs/Infs 97 | has_nan = mx.isnan(tensor).any() 98 | has_inf = mx.isinf(tensor).any() 99 | if has_nan: 100 | print(f"WARNING: Found NaN in {name}") 101 | if has_inf: 102 | print(f"WARNING: Found Inf in {name}") 103 | 104 | # Calculate and print stats (ensure computation happens) 105 | min_val = mx.min(tensor).item() 106 | max_val = mx.max(tensor).item() 107 | mean_val = mx.mean(tensor).item() 108 | std_val = mx.std(tensor).item() 109 | print(f" Shape: {tensor.shape}") 110 | print(f" Min: {min_val:.4f}, Max: {max_val:.4f}") 111 | print(f" Mean: {mean_val:.4f}, Std: {std_val:.4f}") 112 | print("-" * (len(name) + 24)) 113 | 114 | 115 | def pixel_shuffle(input_tensor, shuffle_ratio): 116 | # input_tensor: [batch_size, num_patches, channels] 117 | batch_size, num_patches, channels = input_tensor.shape 118 | patch_size = int(math.sqrt(num_patches)) 119 | 120 | input_tensor = input_tensor.reshape(batch_size, patch_size, patch_size, -1) 121 | batch_size, height, width, channels = input_tensor.shape 122 | 123 | reshaped_tensor = input_tensor.reshape( 124 | batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio) 125 | ) 126 | reshaped_tensor = reshaped_tensor.transpose(0, 2, 1, 3) 127 | 128 | reshaped_tensor = reshaped_tensor.reshape( 129 | batch_size, 130 | int(height * shuffle_ratio), 131 | int(width * shuffle_ratio), 132 | int(channels / (shuffle_ratio**2)), 133 | ) 134 | reshaped_tensor = reshaped_tensor.transpose(0, 2, 1, 3) 135 | 136 | output_tensor = reshaped_tensor.reshape(batch_size, -1, reshaped_tensor.shape[-1]) 137 | return output_tensor 138 | 139 | 140 | def interpolate(pos_embed, size, mode="cubic", align_corners=False): 141 | """ 142 | MLX implementation of PyTorch's F.interpolate with bicubic mode 143 | 144 | Args: 145 | pos_embed: MLX array with shape [B, C, H_src, W_src] or [C, H_src, W_src] 146 | size: Tuple (H_dst, W_dst) - target size 147 | align_corners: Boolean - whether to align corners 148 | 149 | Returns: 150 | Interpolated array with shape [B, C, H_dst, W_dst] or [C, H_dst, W_dst] 151 | """ 152 | 
# Handle different input shapes 153 | input_dim = pos_embed.ndim 154 | original_shape = pos_embed.shape 155 | 156 | if input_dim == 3: 157 | # [C, H, W] -> [1, C, H, W] 158 | pos_embed = pos_embed.reshape(1, *original_shape) 159 | 160 | # Get source dimensions 161 | h_src, w_src = pos_embed.shape[-2:] 162 | h_dst, w_dst = size 163 | 164 | # Calculate scale factors 165 | scale_h = h_dst / h_src 166 | scale_w = w_dst / w_src 167 | 168 | # Create upsampler 169 | upsampler = nn.Upsample( 170 | scale_factor=(scale_h, scale_w), mode=mode, align_corners=align_corners 171 | ) 172 | 173 | # Apply upsampling 174 | result = upsampler(pos_embed) 175 | 176 | # Return in the original dimension format 177 | if input_dim == 3: 178 | return result.reshape(original_shape[0], *size) 179 | return result 180 | -------------------------------------------------------------------------------- /mlx_vlm/models/cache.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | from mlx_lm.models.cache import ChunkedKVCache, KVCache, RotatingKVCache, _BaseCache 3 | 4 | 5 | class SimpleKVCache: 6 | """A simple key-value cache for transformer attention layers. 7 | 8 | Stores and concatenates key/value tensors along sequence dimension. 9 | """ 10 | 11 | def __init__(self): 12 | self.keys = None 13 | self.values = None 14 | self.cache_length = 0 15 | 16 | def update_and_fetch(self, keys, values): 17 | """Update cache with new key/value tensors and return full cache. 18 | 19 | Args: 20 | keys: New key tensor to add [batch, heads, seq_len, head_dim] 21 | values: New value tensor to add [batch, heads, seq_len, head_dim] 22 | 23 | Returns: 24 | Tuple of (cached_keys, cached_values) containing full cache history 25 | """ 26 | if self.cache_length == 0: 27 | # First update - just store tensors 28 | self.keys = keys 29 | self.values = values 30 | else: 31 | # Concatenate with existing cache along sequence dimension 32 | self.keys = mx.concatenate([self.keys, keys], axis=2) 33 | self.values = mx.concatenate([self.values, values], axis=2) 34 | 35 | self.cache_length += keys.shape[2] 36 | return self.keys, self.values 37 | 38 | def fetch(self): 39 | return self.keys, self.values 40 | 41 | def update(self, keys, values): 42 | """Update cache with new key/value tensors without returning. 
43 | 44 | Args: 45 | keys: New key tensor to store 46 | values: New value tensor to store 47 | """ 48 | self.keys = keys 49 | self.values = values 50 | self.cache_length += keys.shape[2] 51 | -------------------------------------------------------------------------------- /mlx_vlm/models/deepseek_vl_v2/__init__.py: -------------------------------------------------------------------------------- 1 | from .deepseek_vl_v2 import ( 2 | DeepseekVLV2Processor, 3 | LanguageModel, 4 | Model, 5 | ModelConfig, 6 | ProjectorConfig, 7 | TextConfig, 8 | VisionConfig, 9 | VisionModel, 10 | ) 11 | -------------------------------------------------------------------------------- /mlx_vlm/models/florence2/__init__.py: -------------------------------------------------------------------------------- 1 | from .florence2 import ( 2 | LanguageModel, 3 | Model, 4 | ModelConfig, 5 | TextConfig, 6 | VisionConfig, 7 | VisionModel, 8 | ) 9 | -------------------------------------------------------------------------------- /mlx_vlm/models/gemma3/__init__.py: -------------------------------------------------------------------------------- 1 | from .gemma3 import ( 2 | LanguageModel, 3 | Model, 4 | ModelConfig, 5 | TextConfig, 6 | VisionConfig, 7 | VisionModel, 8 | ) 9 | -------------------------------------------------------------------------------- /mlx_vlm/models/idefics2/__init__.py: -------------------------------------------------------------------------------- 1 | from .idefics2 import ( 2 | LanguageModel, 3 | Model, 4 | ModelConfig, 5 | PerceiverConfig, 6 | TextConfig, 7 | VisionConfig, 8 | VisionModel, 9 | ) 10 | -------------------------------------------------------------------------------- /mlx_vlm/models/idefics2/language.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from dataclasses import dataclass 3 | from typing import Optional 4 | 5 | import mlx.core as mx 6 | import mlx.nn as nn 7 | 8 | from ..base import LanguageModelOutput, create_attention_mask 9 | from ..cache import KVCache 10 | 11 | 12 | @dataclass 13 | class TextConfig: 14 | model_type: str 15 | hidden_size: int 16 | num_hidden_layers: int 17 | intermediate_size: int 18 | num_attention_heads: int 19 | rms_norm_eps: float 20 | vocab_size: int 21 | num_key_value_heads: int 22 | rope_theta: float = 1000000.0 23 | rope_traditional: bool = False 24 | max_position_embeddings: int = 4096 25 | tie_word_embeddings: bool = False 26 | 27 | @classmethod 28 | def from_dict(cls, params): 29 | return cls( 30 | **{ 31 | k: v 32 | for k, v in params.items() 33 | if k in inspect.signature(cls).parameters 34 | } 35 | ) 36 | 37 | def __post_init__(self): 38 | if self.num_key_value_heads is None: 39 | self.num_key_value_heads = self.num_attention_heads 40 | 41 | 42 | class Attention(nn.Module): 43 | def __init__(self, config: TextConfig): 44 | super().__init__() 45 | 46 | dim = config.hidden_size 47 | self.n_heads = n_heads = config.num_attention_heads 48 | self.n_kv_heads = n_kv_heads = config.num_key_value_heads 49 | 50 | head_dim = config.hidden_size // n_heads 51 | self.scale = head_dim**-0.5 52 | 53 | self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) 54 | self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) 55 | self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) 56 | self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) 57 | 58 | self.rope = nn.RoPE( 59 | head_dim, 60 | traditional=config.rope_traditional, 61 | base=config.rope_theta, 62 | ) 63 | 64 | def 
__call__( 65 | self, 66 | x: mx.array, 67 | mask: Optional[mx.array] = None, 68 | cache: Optional[KVCache] = None, 69 | ) -> mx.array: 70 | B, L, D = x.shape 71 | 72 | queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) 73 | 74 | # Prepare the queries, keys and values for the attention computation 75 | queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) 76 | keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 77 | values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 78 | 79 | if cache is not None: 80 | queries = self.rope(queries, offset=cache.offset) 81 | keys = self.rope(keys, offset=cache.offset) 82 | keys, values = cache.update_and_fetch(keys, values) 83 | else: 84 | queries = self.rope(queries) 85 | keys = self.rope(keys) 86 | 87 | output = mx.fast.scaled_dot_product_attention( 88 | queries, keys, values, scale=self.scale, mask=mask 89 | ) 90 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 91 | return self.o_proj(output) 92 | 93 | 94 | class MLP(nn.Module): 95 | def __init__(self, dim, hidden_dim): 96 | super().__init__() 97 | self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) 98 | self.down_proj = nn.Linear(hidden_dim, dim, bias=False) 99 | self.up_proj = nn.Linear(dim, hidden_dim, bias=False) 100 | 101 | def __call__(self, x) -> mx.array: 102 | return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) 103 | 104 | 105 | class TransformerBlock(nn.Module): 106 | def __init__(self, config: TextConfig): 107 | super().__init__() 108 | self.num_attention_heads = config.num_attention_heads 109 | self.hidden_size = config.hidden_size 110 | self.self_attn = Attention(config) 111 | self.mlp = MLP(config.hidden_size, config.intermediate_size) 112 | self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 113 | self.post_attention_layernorm = nn.RMSNorm( 114 | config.hidden_size, eps=config.rms_norm_eps 115 | ) 116 | self.config = config 117 | 118 | def __call__( 119 | self, 120 | x: mx.array, 121 | mask: Optional[mx.array] = None, 122 | cache: Optional[KVCache] = None, 123 | ) -> mx.array: 124 | r = self.self_attn(self.input_layernorm(x), mask, cache) 125 | h = x + r 126 | r = self.mlp(self.post_attention_layernorm(h)) 127 | out = h + r 128 | return out 129 | 130 | 131 | class LanguageModel(nn.Module): 132 | def __init__(self, config: TextConfig): 133 | super().__init__() 134 | self.config = config 135 | self.model_type = config.model_type 136 | self.vocab_size = config.vocab_size 137 | self.num_hidden_layers = config.num_hidden_layers 138 | assert self.vocab_size > 0 139 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) 140 | self.layers = [ 141 | TransformerBlock(config=config) for _ in range(config.num_hidden_layers) 142 | ] 143 | self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 144 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 145 | 146 | def __call__( 147 | self, 148 | inputs: mx.array, 149 | inputs_embeds: Optional[mx.array] = None, 150 | mask: Optional[mx.array] = None, 151 | cache=None, 152 | ): 153 | # for passing merged input embeddings 154 | if inputs_embeds is None: 155 | h = self.embed_tokens(inputs) 156 | else: 157 | h = inputs_embeds 158 | 159 | if cache is None: 160 | cache = [None] * len(self.layers) 161 | 162 | if mask is None: 163 | mask = create_attention_mask(h, cache) 164 | 165 | for layer, c in zip(self.layers, cache): 166 | h = layer(h, mask, c) 167 | 168 | logits = 
self.lm_head(self.norm(h)) 169 | return LanguageModelOutput(logits=logits) 170 | 171 | def sanitize(self, weights): 172 | # Remove unused precomputed rotary freqs 173 | return { 174 | k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k 175 | } 176 | 177 | @property 178 | def layers(self): 179 | return self.model.layers 180 | 181 | @property 182 | def head_dim(self): 183 | return self.config.hidden_size // self.config.num_attention_heads 184 | 185 | @property 186 | def n_kv_heads(self): 187 | return self.config.num_key_value_heads 188 | -------------------------------------------------------------------------------- /mlx_vlm/models/idefics3/__init__.py: -------------------------------------------------------------------------------- 1 | from .idefics3 import ( 2 | LanguageModel, 3 | Model, 4 | ModelConfig, 5 | TextConfig, 6 | VisionConfig, 7 | VisionModel, 8 | ) 9 | -------------------------------------------------------------------------------- /mlx_vlm/models/idefics3/language.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import re 3 | from dataclasses import dataclass 4 | from typing import Dict, Optional, Tuple, Union 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | 9 | from ..base import LanguageModelOutput, create_attention_mask 10 | from ..cache import KVCache 11 | 12 | 13 | @dataclass 14 | class TextConfig: 15 | model_type: str 16 | hidden_size: int 17 | intermediate_size: int 18 | num_attention_heads: int 19 | rms_norm_eps: float 20 | vocab_size: int 21 | num_key_value_heads: int 22 | rope_theta: float = 1000000.0 23 | num_hidden_layers: int = 32 24 | rope_traditional: bool = False 25 | max_position_embeddings: int = 4096 26 | tie_word_embeddings: bool = False 27 | 28 | @classmethod 29 | def from_dict(cls, params): 30 | return cls( 31 | **{ 32 | k: v 33 | for k, v in params.items() 34 | if k in inspect.signature(cls).parameters 35 | } 36 | ) 37 | 38 | def __post_init__(self): 39 | if self.num_key_value_heads is None: 40 | self.num_key_value_heads = self.num_attention_heads 41 | 42 | 43 | class Attention(nn.Module): 44 | def __init__(self, config: TextConfig): 45 | super().__init__() 46 | 47 | dim = config.hidden_size 48 | self.n_heads = n_heads = config.num_attention_heads 49 | self.n_kv_heads = n_kv_heads = config.num_key_value_heads 50 | 51 | head_dim = config.hidden_size // n_heads 52 | self.scale = head_dim**-0.5 53 | 54 | self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) 55 | self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) 56 | self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) 57 | self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) 58 | 59 | self.rope = nn.RoPE( 60 | head_dim, 61 | traditional=config.rope_traditional, 62 | base=config.rope_theta, 63 | ) 64 | 65 | def __call__( 66 | self, 67 | x: mx.array, 68 | mask: Optional[mx.array] = None, 69 | cache: Optional[KVCache] = None, 70 | ) -> mx.array: 71 | B, L, D = x.shape 72 | 73 | queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) 74 | 75 | # Prepare the queries, keys and values for the attention computation 76 | queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) 77 | keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 78 | values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 79 | 80 | if cache is not None: 81 | queries = self.rope(queries, offset=cache.offset) 82 | keys = self.rope(keys, 
offset=cache.offset) 83 | keys, values = cache.update_and_fetch(keys, values) 84 | else: 85 | queries = self.rope(queries) 86 | keys = self.rope(keys) 87 | 88 | output = mx.fast.scaled_dot_product_attention( 89 | queries, keys, values, scale=self.scale, mask=mask 90 | ) 91 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 92 | return self.o_proj(output) 93 | 94 | 95 | class MLP(nn.Module): 96 | def __init__(self, dim, hidden_dim): 97 | super().__init__() 98 | self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) 99 | self.down_proj = nn.Linear(hidden_dim, dim, bias=False) 100 | self.up_proj = nn.Linear(dim, hidden_dim, bias=False) 101 | 102 | def __call__(self, x) -> mx.array: 103 | return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) 104 | 105 | 106 | class TransformerBlock(nn.Module): 107 | def __init__(self, config: TextConfig): 108 | super().__init__() 109 | self.num_attention_heads = config.num_attention_heads 110 | self.hidden_size = config.hidden_size 111 | self.self_attn = Attention(config) 112 | self.mlp = MLP(config.hidden_size, config.intermediate_size) 113 | self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 114 | self.post_attention_layernorm = nn.RMSNorm( 115 | config.hidden_size, eps=config.rms_norm_eps 116 | ) 117 | self.config = config 118 | 119 | def __call__( 120 | self, 121 | x: mx.array, 122 | mask: Optional[mx.array] = None, 123 | cache: Optional[KVCache] = None, 124 | ) -> mx.array: 125 | r = self.self_attn(self.input_layernorm(x), mask, cache) 126 | h = x + r 127 | r = self.mlp(self.post_attention_layernorm(h)) 128 | out = h + r 129 | return out 130 | 131 | 132 | class LanguageModel(nn.Module): 133 | def __init__(self, config: TextConfig): 134 | super().__init__() 135 | self.config = config 136 | self.model_type = config.model_type 137 | self.vocab_size = config.vocab_size 138 | self.num_hidden_layers = config.num_hidden_layers 139 | assert self.vocab_size > 0 140 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) 141 | self.layers = [ 142 | TransformerBlock(config=config) for _ in range(config.num_hidden_layers) 143 | ] 144 | self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 145 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 146 | 147 | def __call__( 148 | self, 149 | inputs: mx.array, 150 | inputs_embeds: Optional[mx.array] = None, 151 | mask: Optional[mx.array] = None, 152 | cache=None, 153 | ): 154 | # for passing merged input embeddings 155 | if inputs_embeds is None: 156 | h = self.embed_tokens(inputs) 157 | else: 158 | h = inputs_embeds.astype(self.norm.weight.dtype) 159 | 160 | if cache is None: 161 | cache = [None] * len(self.layers) 162 | 163 | if mask is None: 164 | mask = create_attention_mask(h, cache) 165 | 166 | for layer, c in zip(self.layers, cache): 167 | h = layer(h, mask, c) 168 | 169 | logits = self.lm_head(self.norm(h)) 170 | return LanguageModelOutput(logits=logits) 171 | 172 | def sanitize(self, weights): 173 | # Remove unused precomputed rotary freqs 174 | return { 175 | k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k 176 | } 177 | 178 | @property 179 | def layers(self): 180 | return self.model.layers 181 | 182 | @property 183 | def head_dim(self): 184 | return self.config.hidden_size // self.config.num_attention_heads 185 | 186 | @property 187 | def n_kv_heads(self): 188 | return self.config.num_key_value_heads 189 | 
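# --- Illustrative sketch (not part of this file) ---
# The attention blocks above use grouped-query attention: `n_heads` query heads
# share `n_kv_heads` key/value heads, and `mx.fast.scaled_dot_product_attention`
# broadcasts the key/value heads across the query-head groups. A minimal,
# self-contained example of that call, with made-up shapes:
#
#   import mlx.core as mx
#
#   B, L, n_heads, n_kv_heads, head_dim = 1, 6, 8, 2, 64
#   q = mx.random.normal((B, n_heads, L, head_dim))
#   k = mx.random.normal((B, n_kv_heads, L, head_dim))
#   v = mx.random.normal((B, n_kv_heads, L, head_dim))
#   out = mx.fast.scaled_dot_product_attention(q, k, v, scale=head_dim**-0.5, mask=None)
#   print(out.shape)  # -> (1, 8, 6, 64)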
-------------------------------------------------------------------------------- /mlx_vlm/models/internvl_chat/__init__.py: -------------------------------------------------------------------------------- 1 | from .internvl_chat import ( 2 | LanguageModel, 3 | Model, 4 | ModelConfig, 5 | TextConfig, 6 | VisionConfig, 7 | VisionModel, 8 | ) 9 | from .processor import InternVLChatProcessor, InternVLImageProcessor 10 | -------------------------------------------------------------------------------- /mlx_vlm/models/internvl_chat/internvl_chat.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import inspect 3 | import json 4 | from dataclasses import dataclass 5 | from pathlib import Path 6 | from typing import List, Optional 7 | 8 | import mlx.core as mx 9 | import mlx.nn as nn 10 | import numpy as np 11 | from huggingface_hub import snapshot_download 12 | 13 | from ..base import pixel_shuffle 14 | from .language import LanguageModel, TextConfig 15 | from .vision import VisionConfig, VisionModel 16 | 17 | 18 | @dataclass 19 | class ModelConfig: 20 | text_config: TextConfig 21 | vision_config: VisionConfig 22 | model_type: str 23 | ignore_index: int = -100 24 | image_token_index: int = 151667 25 | video_token_index: int = 151656 26 | vision_feature_select_strategy: str = "default" 27 | vision_feature_layer: int = -1 28 | vocab_size: int = 32000 29 | downsample_ratio: float = 0.5 30 | eos_token_id: Optional[List[int]] = None 31 | 32 | @classmethod 33 | def from_dict(cls, params): 34 | return cls( 35 | **{ 36 | k: v 37 | for k, v in params.items() 38 | if k in inspect.signature(cls).parameters 39 | } 40 | ) 41 | 42 | 43 | class Model(nn.Module): 44 | def __init__(self, config: ModelConfig): 45 | super().__init__() 46 | self.config = config 47 | self.vision_model = VisionModel(config.vision_config) 48 | self.language_model = LanguageModel(config.text_config) 49 | 50 | self.downsample_ratio = config.downsample_ratio 51 | 52 | vit_hidden_size = self.config.vision_config.hidden_size 53 | llm_hidden_size = self.config.text_config.hidden_size 54 | 55 | self.mlp1 = [ 56 | nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2), 57 | nn.Linear( 58 | vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size 59 | ), 60 | nn.GELU(), 61 | nn.Linear(llm_hidden_size, llm_hidden_size), 62 | ] 63 | 64 | def get_input_embeddings( 65 | self, 66 | input_ids: Optional[mx.array] = None, 67 | pixel_values: Optional[mx.array] = None, 68 | ): 69 | 70 | if pixel_values is None: 71 | return self.language_model.model.embed_tokens(input_ids) 72 | 73 | dtype = self.vision_model.embeddings.patch_embedding.weight.dtype 74 | pixel_values = pixel_values.astype(dtype) 75 | 76 | # TODO: Remove this after transformers implementation is merged 77 | if pixel_values.ndim == 5: 78 | pixel_values = pixel_values[0] 79 | 80 | # Get the input embeddings from the language model 81 | inputs_embeds = self.language_model.model.embed_tokens(input_ids) 82 | 83 | # Get the ouptut hidden states from the vision model 84 | hidden_states, _, _ = self.vision_model( 85 | pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True 86 | ) 87 | 88 | # Extract vision embeddings, removing the class token (first token) 89 | hidden_states = hidden_states[:, 1:, :] 90 | 91 | # Apply pixel shuffle with downsampling 92 | hidden_states = pixel_shuffle( 93 | hidden_states, shuffle_ratio=self.downsample_ratio 94 | ) 95 | 96 | # Apply MLP transformation 97 | for layer in self.mlp1: 98 
| hidden_states = layer(hidden_states) 99 | 100 | # Insert special image tokens in the input_ids 101 | final_inputs_embeds = self._merge_input_ids_with_image_features( 102 | hidden_states, inputs_embeds, input_ids 103 | ) 104 | return final_inputs_embeds 105 | 106 | def _merge_input_ids_with_image_features( 107 | self, image_features, inputs_embeds, input_ids 108 | ): 109 | B, N, C = inputs_embeds.shape 110 | image_token_index = self.config.image_token_index 111 | video_token_index = self.config.video_token_index 112 | 113 | # Positions of tokens in input_ids, assuming batch size is 1 114 | image_positions = input_ids == image_token_index 115 | if mx.sum(image_positions) == 0: 116 | image_positions = input_ids == video_token_index 117 | 118 | image_indices = np.where(image_positions)[1].tolist() 119 | 120 | image_features = image_features.reshape(-1, image_features.shape[-1]) 121 | 122 | inputs_embeds[:, image_indices, :] = image_features 123 | 124 | return inputs_embeds.reshape(B, N, C) 125 | 126 | def __call__( 127 | self, 128 | input_ids: mx.array, 129 | pixel_values: mx.array, 130 | mask: mx.array, 131 | cache=None, 132 | **kwargs, 133 | ): 134 | input_embddings = self.get_input_embeddings(input_ids, pixel_values) 135 | logits = self.language_model(None, cache=cache, inputs_embeds=input_embddings) 136 | return logits 137 | 138 | @staticmethod 139 | def from_pretrained(path_or_hf_repo: str): 140 | path = Path(path_or_hf_repo) 141 | if not path.exists(): 142 | path = Path( 143 | snapshot_download( 144 | repo_id=path_or_hf_repo, 145 | allow_patterns=[ 146 | "*.json", 147 | "*.safetensors", 148 | "*.py", 149 | "tokenizer.model", 150 | "*.tiktoken", 151 | ], 152 | ) 153 | ) 154 | 155 | with open(path / "config.json", "r") as f: 156 | model_config = json.load(f) 157 | 158 | model_config = ModelConfig.from_dict(model_config) 159 | 160 | model_config.vision_config = VisionConfig.from_dict(model_config.vision_config) 161 | model_config.text_config = TextConfig.from_dict(model_config) 162 | 163 | model = Model(model_config) 164 | weight_files = glob.glob(str(path / "*.safetensors")) 165 | if not weight_files: 166 | raise FileNotFoundError(f"No safetensors found in {path}") 167 | 168 | weights = {} 169 | for wf in weight_files: 170 | weights.update(mx.load(wf)) 171 | 172 | weights = VisionModel.sanitize(weights) 173 | weights = LanguageModel.sanitize(weights) 174 | 175 | model.load_weights(list(weights.items())) 176 | return model 177 | -------------------------------------------------------------------------------- /mlx_vlm/models/internvl_chat/language.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from dataclasses import dataclass 3 | from typing import Dict, Optional, Union 4 | 5 | import mlx.core as mx 6 | import mlx.nn as nn 7 | 8 | from ..base import LanguageModelOutput, create_attention_mask 9 | from ..cache import KVCache 10 | 11 | 12 | @dataclass 13 | class TextConfig: 14 | model_type: str 15 | hidden_size: int 16 | num_hidden_layers: int 17 | intermediate_size: int 18 | num_attention_heads: int 19 | rms_norm_eps: float 20 | vocab_size: int 21 | max_window_layers: int 22 | hidden_act: str 23 | num_key_value_heads: Optional[int] = 8 24 | max_position_embeddings: Optional[int] = 40960 25 | rope_theta: float = 1000000.0 26 | rope_traditional: bool = False 27 | rope_scaling: Optional[Dict[str, Union[float, str]]] = None 28 | tie_word_embeddings: bool = False 29 | sliding_window: int = 32768 30 | use_sliding_window: bool = False 31 
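# NOTE (editor): the snippet below is an illustrative, self-contained sketch of the
# config-loading pattern shared by the ModelConfig/TextConfig/VisionConfig dataclasses in
# this package: `from_dict` keeps only the keys that appear in the dataclass signature, so
# unknown fields in config.json are ignored instead of raising a TypeError. The class name
# `ExampleConfig` and the values used here are made up for demonstration.
import inspect
from dataclasses import dataclass

@dataclass
class ExampleConfig:
    hidden_size: int
    vocab_size: int = 32000

    @classmethod
    def from_dict(cls, params):
        return cls(
            **{
                k: v
                for k, v in params.items()
                if k in inspect.signature(cls).parameters
            }
        )

# "rope_theta" is not a field of ExampleConfig, so it is silently dropped:
cfg = ExampleConfig.from_dict({"hidden_size": 2048, "rope_theta": 1e6})
assert cfg.vocab_size == 32000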
| use_cache: bool = True 32 | 33 | def __post_init__(self): 34 | if self.num_key_value_heads is None: 35 | self.num_key_value_heads = self.num_attention_heads 36 | 37 | @classmethod 38 | def from_dict(cls, params): 39 | return cls( 40 | **{ 41 | k: v 42 | for k, v in params.items() 43 | if k in inspect.signature(cls).parameters 44 | } 45 | ) 46 | 47 | 48 | class Attention(nn.Module): 49 | def __init__(self, args: TextConfig): 50 | super().__init__() 51 | 52 | dim = args.hidden_size 53 | self.n_heads = n_heads = args.num_attention_heads 54 | assert args.num_key_value_heads is not None 55 | self.n_kv_heads = n_kv_heads = args.num_key_value_heads 56 | 57 | self.head_dim = head_dim = args.hidden_size // n_heads 58 | self.scale = head_dim**-0.5 59 | 60 | self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True) 61 | self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) 62 | self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) 63 | self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) 64 | 65 | self.rotary_emb = nn.RoPE( 66 | head_dim, 67 | base=args.rope_theta, 68 | traditional=args.rope_traditional, 69 | ) 70 | 71 | def __call__( 72 | self, 73 | x: mx.array, 74 | mask: Optional[mx.array] = None, 75 | cache: Optional[KVCache] = None, 76 | ) -> mx.array: 77 | B, L, D = x.shape 78 | 79 | queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) 80 | 81 | # Prepare the queries, keys and values for the attention computation 82 | queries = queries.reshape(B, L, self.n_heads, self.head_dim).transpose( 83 | 0, 2, 1, 3 84 | ) 85 | keys = keys.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3) 86 | values = values.reshape(B, L, self.n_kv_heads, self.head_dim).transpose( 87 | 0, 2, 1, 3 88 | ) 89 | 90 | offset = cache.offset if cache else 0 91 | 92 | if mask is not None and isinstance(mask, mx.array): 93 | mask = mask[..., : keys.shape[-2]] 94 | 95 | queries = self.rotary_emb(queries, offset=offset) 96 | keys = self.rotary_emb(keys, offset=offset) 97 | 98 | if cache is not None: 99 | keys, values = cache.update_and_fetch(keys, values) 100 | 101 | output = mx.fast.scaled_dot_product_attention( 102 | queries, keys, values, scale=self.scale, mask=mask 103 | ) 104 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 105 | return self.o_proj(output) 106 | 107 | 108 | class MLP(nn.Module): 109 | def __init__(self, dim, hidden_dim): 110 | super().__init__() 111 | self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) 112 | self.down_proj = nn.Linear(hidden_dim, dim, bias=False) 113 | self.up_proj = nn.Linear(dim, hidden_dim, bias=False) 114 | 115 | def __call__(self, x) -> mx.array: 116 | return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) 117 | 118 | 119 | class Qwen2VLDecoderLayer(nn.Module): 120 | def __init__(self, args: TextConfig): 121 | super().__init__() 122 | self.num_attention_heads = args.num_attention_heads 123 | self.hidden_size = args.hidden_size 124 | self.self_attn = Attention(args) 125 | self.mlp = MLP(args.hidden_size, args.intermediate_size) 126 | self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 127 | self.post_attention_layernorm = nn.RMSNorm( 128 | args.hidden_size, eps=args.rms_norm_eps 129 | ) 130 | self.args = args 131 | 132 | def __call__( 133 | self, 134 | x: mx.array, 135 | mask: Optional[mx.array] = None, 136 | cache: Optional[KVCache] = None, 137 | ) -> mx.array: 138 | r = self.self_attn(self.input_layernorm(x), mask, cache) 139 | h = x + r 140 | r = 
self.mlp(self.post_attention_layernorm(h)) 141 | out = h + r 142 | return out 143 | 144 | 145 | class Qwen2Model(nn.Module): 146 | def __init__(self, args: TextConfig): 147 | super().__init__() 148 | self.args = args 149 | self.vocab_size = args.vocab_size 150 | self.num_hidden_layers = args.num_hidden_layers 151 | assert self.vocab_size > 0 152 | self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) 153 | self.layers = [ 154 | Qwen2VLDecoderLayer(args=args) for _ in range(args.num_hidden_layers) 155 | ] 156 | self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 157 | 158 | def __call__( 159 | self, 160 | inputs: mx.array, 161 | inputs_embeds: Optional[mx.array] = None, 162 | mask: Optional[mx.array] = None, 163 | cache=None, 164 | ): 165 | if inputs_embeds is None: 166 | h = self.embed_tokens(inputs) 167 | else: 168 | h = inputs_embeds 169 | 170 | if cache is None: 171 | cache = [None] * len(self.layers) 172 | 173 | if mask is None: 174 | mask = create_attention_mask(h, cache) 175 | 176 | for layer, c in zip(self.layers, cache): 177 | h = layer(h, mask, c) 178 | 179 | return self.norm(h) 180 | 181 | 182 | class LanguageModel(nn.Module): 183 | def __init__(self, args: TextConfig): 184 | super().__init__() 185 | self.args = args 186 | self.model_type = args.model_type 187 | self.model = Qwen2Model(args) 188 | 189 | if not args.tie_word_embeddings: 190 | self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) 191 | 192 | def __call__( 193 | self, 194 | inputs: mx.array, 195 | inputs_embeds: Optional[mx.array] = None, 196 | mask: Optional[mx.array] = None, 197 | cache=None, 198 | ): 199 | out = self.model(inputs, cache=cache, inputs_embeds=inputs_embeds) 200 | if self.args.tie_word_embeddings: 201 | out = self.model.embed_tokens.as_linear(out) 202 | else: 203 | out = self.lm_head(out) 204 | return LanguageModelOutput(logits=out) 205 | 206 | @property 207 | def layers(self): 208 | return self.model.layers 209 | 210 | @property 211 | def head_dim(self): 212 | return self.args.hidden_size // self.args.num_attention_heads 213 | 214 | @property 215 | def n_kv_heads(self): 216 | return self.args.num_key_value_heads 217 | -------------------------------------------------------------------------------- /mlx_vlm/models/kimi_vl/__init__.py: -------------------------------------------------------------------------------- 1 | from .kimi_vl import ( 2 | LanguageModel, 3 | Model, 4 | ModelConfig, 5 | TextConfig, 6 | VisionConfig, 7 | VisionModel, 8 | ) 9 | -------------------------------------------------------------------------------- /mlx_vlm/models/kimi_vl/kimi_vl.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import inspect 3 | import json 4 | import re 5 | from dataclasses import dataclass 6 | from pathlib import Path 7 | from typing import List, Optional 8 | 9 | import mlx.core as mx 10 | import mlx.nn as nn 11 | import numpy as np 12 | from huggingface_hub import snapshot_download 13 | from transformers import AutoConfig 14 | 15 | from .language import LanguageModel, TextConfig 16 | from .vision import VisionConfig, VisionModel 17 | 18 | 19 | @dataclass 20 | class ModelConfig: 21 | text_config: TextConfig 22 | vision_config: VisionConfig 23 | model_type: str 24 | ignore_index: int = -100 25 | vocab_size: int = 128259 26 | scale_factor: int = 2 27 | media_placeholder_token_id: int = 163606 28 | image_token_index: Optional[int] = None 29 | eos_token_id: Optional[List[int]] = None 30 | 31 | def 
__post_init__(self): 32 | if self.image_token_index is None: 33 | self.image_token_index = self.media_placeholder_token_id 34 | 35 | @classmethod 36 | def from_dict(cls, params): 37 | return cls( 38 | **{ 39 | k: v 40 | for k, v in params.items() 41 | if k in inspect.signature(cls).parameters 42 | } 43 | ) 44 | 45 | 46 | class KimiVLMultiModalProjector(nn.Module): 47 | 48 | def __init__(self, config: ModelConfig): 49 | super().__init__() 50 | 51 | self.hidden_size = ( 52 | config.vision_config.hidden_size 53 | * config.vision_config.merge_kernel_size[0] 54 | * config.vision_config.merge_kernel_size[1] 55 | ) 56 | 57 | self.pre_norm = nn.LayerNorm(config.vision_config.hidden_size, eps=1e-05) 58 | self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True) 59 | self.act = nn.GELU() 60 | self.linear_2 = nn.Linear( 61 | self.hidden_size, config.text_config.hidden_size, bias=True 62 | ) 63 | 64 | def __call__(self, image_features: list[mx.array]) -> mx.array: 65 | image_features = mx.concatenate(image_features, axis=0) 66 | h = self.pre_norm(image_features).reshape(-1, self.hidden_size) 67 | h = self.linear_1(h) 68 | h = self.act(h) 69 | h = self.linear_2(h) 70 | return h 71 | 72 | 73 | class Model(nn.Module): 74 | def __init__(self, config: ModelConfig): 75 | super().__init__() 76 | self.model_type = config.model_type 77 | self.config = config 78 | 79 | self.vision_tower = VisionModel(config.vision_config) 80 | self.language_model = LanguageModel(config.text_config) 81 | self.multi_modal_projector = KimiVLMultiModalProjector(config) 82 | 83 | def get_input_embeddings( 84 | self, 85 | input_ids: Optional[mx.array] = None, 86 | pixel_values: Optional[mx.array] = None, 87 | grid_thw: Optional[mx.array] = None, 88 | ): 89 | if pixel_values is None: 90 | return self.language_model.embed_tokens(input_ids) 91 | 92 | inputs_embeds = self.language_model.embed_tokens(input_ids) 93 | 94 | hidden_state = self.vision_tower( 95 | pixel_values.transpose(0, 2, 3, 1), 96 | output_hidden_states=True, 97 | grid_thw=grid_thw, 98 | ) 99 | 100 | image_features = self.multi_modal_projector(hidden_state) 101 | 102 | final_inputs_embeds = self._prepare_inputs_for_multimodal( 103 | image_features, inputs_embeds, input_ids 104 | ) 105 | return final_inputs_embeds 106 | 107 | def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_ids): 108 | image_token_index = self.config.image_token_index 109 | 110 | # Positions of tokens in input_ids, assuming batch size is 1 111 | image_positions = np.where(input_ids == image_token_index)[1].tolist() 112 | 113 | inputs_embeds[:, image_positions, :] = image_features 114 | 115 | return inputs_embeds 116 | 117 | def __call__( 118 | self, 119 | input_ids: mx.array, 120 | pixel_values: mx.array, 121 | cache=None, 122 | **kwargs, 123 | ): 124 | image_grid_thw = kwargs.pop("image_grid_hws", None) 125 | video_grid_thw = kwargs.pop("video_grid_hws", None) 126 | grid_thw = image_grid_thw if image_grid_thw is not None else video_grid_thw 127 | input_embeddings = self.get_input_embeddings( 128 | input_ids, pixel_values, grid_thw=grid_thw 129 | ) 130 | logits = self.language_model( 131 | inputs=input_ids, cache=cache, inputs_embeds=input_embeddings 132 | ) 133 | return logits 134 | 135 | @staticmethod 136 | def from_pretrained(path_or_hf_repo: str): 137 | path = Path(path_or_hf_repo) 138 | if not path.exists(): 139 | path = Path( 140 | snapshot_download( 141 | repo_id=path_or_hf_repo, 142 | allow_patterns=[ 143 | "*.json", 144 | "*.safetensors", 145 | "*.py", 
146 | "tokenizer.model", 147 | "*.tiktoken", 148 | ], 149 | ) 150 | ) 151 | 152 | with open(path / "config.json", "r") as f: 153 | config = json.load(f) 154 | 155 | text_config = AutoConfig.from_pretrained(config["text_config"]["model_type"]) 156 | text_config = text_config.to_dict() 157 | config["text_config"] = text_config 158 | model_config = ModelConfig.from_dict(config) 159 | model_config.vision_config = VisionConfig.from_dict(config["vision_config"]) 160 | model_config.text_config = TextConfig.from_dict(config["text_config"]) 161 | 162 | model = Model(model_config) 163 | weight_files = glob.glob(str(path / "*.safetensors")) 164 | if not weight_files: 165 | raise FileNotFoundError(f"No safetensors found in {path}") 166 | 167 | weights = {} 168 | for wf in weight_files: 169 | weights.update(mx.load(wf)) 170 | 171 | weights = model.sanitize(weights=weights) 172 | weights = VisionModel(model_config.vision_config).sanitize(weights=weights) 173 | weights = LanguageModel(model_config.text_config).sanitize(weights=weights) 174 | model.load_weights(list(weights.items())) 175 | return model 176 | 177 | def sanitize(self, weights): 178 | return { 179 | k.replace("encoder.", "") if "vision_tower" in k else k: v 180 | for k, v in weights.items() 181 | } 182 | -------------------------------------------------------------------------------- /mlx_vlm/models/llama4/__init__.py: -------------------------------------------------------------------------------- 1 | from .llama4 import ( 2 | LanguageModel, 3 | Model, 4 | ModelConfig, 5 | TextConfig, 6 | VisionConfig, 7 | VisionModel, 8 | ) 9 | -------------------------------------------------------------------------------- /mlx_vlm/models/llama4/llama4.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import inspect 3 | import json 4 | import re 5 | from dataclasses import dataclass 6 | from pathlib import Path 7 | from typing import Any, Callable, List, Optional, Tuple, Union 8 | 9 | import mlx.core as mx 10 | import mlx.nn as nn 11 | import numpy as np 12 | 13 | from .language import LanguageModel, TextConfig 14 | from .vision import Llama4MultiModalProjector, VisionConfig, VisionModel 15 | 16 | 17 | @dataclass 18 | class ModelConfig: 19 | text_config: TextConfig 20 | vision_config: VisionConfig 21 | model_type: str 22 | ignore_index: int = -100 23 | image_token_id: int = 200092 24 | image_token_index: Optional[int] = None 25 | eos_token_id: Optional[List[int]] = None 26 | 27 | def __post_init__(self): 28 | if self.image_token_index is None: 29 | self.image_token_index = self.image_token_id 30 | 31 | @classmethod 32 | def from_dict(cls, params): 33 | return cls( 34 | **{ 35 | k: v 36 | for k, v in params.items() 37 | if k in inspect.signature(cls).parameters 38 | } 39 | ) 40 | 41 | 42 | class Model(nn.Module): 43 | def __init__(self, config: ModelConfig): 44 | super().__init__() 45 | self.config = config 46 | self.vision_model = VisionModel(config.vision_config) 47 | self.multi_modal_projector = Llama4MultiModalProjector(config) 48 | self.language_model = LanguageModel(config.text_config) 49 | self.vocab_size = config.text_config.vocab_size 50 | 51 | def set_input_embeddings(self, value): 52 | self.language_model.set_input_embeddings(value) 53 | 54 | def get_output_embeddings(self): 55 | return self.language_model.get_output_embeddings() 56 | 57 | def set_output_embeddings(self, new_embeddings): 58 | self.language_model.set_output_embeddings(new_embeddings) 59 | 60 | def set_decoder(self, decoder): 
61 | self.language_model.set_decoder(decoder) 62 | 63 | def get_decoder(self): 64 | return self.language_model.get_decoder() 65 | 66 | def get_image_features( 67 | self, 68 | pixel_values: mx.array, 69 | vision_feature_layer: Union[int, List[int]], 70 | vision_feature_select_strategy: str, 71 | **kwargs, 72 | ): 73 | if vision_feature_select_strategy not in ["default", "full"]: 74 | raise ValueError( 75 | f"Unexpected select feature strategy: {self.vision_feature_select_strategy}" 76 | ) 77 | kwargs = {k: v for k, v in kwargs.items() if v is not None} 78 | hidden_state = self.vision_model( 79 | pixel_values, output_hidden_states=False, **kwargs 80 | ) 81 | return hidden_state 82 | 83 | def get_input_embeddings( 84 | self, 85 | input_ids: Optional[mx.array] = None, 86 | pixel_values: Optional[mx.array] = None, 87 | **kwargs, 88 | ): 89 | if pixel_values is None: 90 | return self.language_model.model.embed_tokens(input_ids) 91 | 92 | # Get the input embeddings from the language model 93 | inputs_embeds = self.language_model.model.embed_tokens(input_ids) 94 | 95 | image_features = self.get_image_features( 96 | pixel_values=pixel_values, 97 | vision_feature_layer=kwargs.get("vision_feature_layer", -1), 98 | vision_feature_select_strategy=kwargs.get( 99 | "vision_feature_select_strategy", "default" 100 | ), 101 | ) 102 | 103 | vision_flat = image_features.reshape(-1, image_features.shape[-1]) 104 | projected_vision_flat = self.multi_modal_projector(vision_flat) 105 | 106 | # Insert special image tokens in the input_ids 107 | final_inputs_embeds = self._prepare_inputs_for_multimodal( 108 | projected_vision_flat, inputs_embeds, input_ids 109 | ) 110 | return final_inputs_embeds 111 | 112 | def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_ids): 113 | image_token_index = self.config.image_token_index 114 | 115 | # Positions of tokens in input_ids, assuming batch size is 1 116 | image_positions = np.where(input_ids == image_token_index)[1].tolist() 117 | 118 | inputs_embeds[:, image_positions, :] = image_features 119 | 120 | return inputs_embeds 121 | 122 | def __call__( 123 | self, 124 | input_ids: mx.array, 125 | pixel_values: mx.array, 126 | cache=None, 127 | **kwargs, 128 | ): 129 | 130 | input_embeddings = self.get_input_embeddings(input_ids, pixel_values, **kwargs) 131 | logits = self.language_model( 132 | input_ids=input_ids, cache=cache, input_embeds=input_embeddings 133 | ) 134 | return logits 135 | -------------------------------------------------------------------------------- /mlx_vlm/models/llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .llava import ( 2 | LanguageModel, 3 | Model, 4 | ModelConfig, 5 | TextConfig, 6 | VisionConfig, 7 | VisionModel, 8 | ) 9 | -------------------------------------------------------------------------------- /mlx_vlm/models/llava/llava.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import inspect 3 | import json 4 | from dataclasses import dataclass 5 | from pathlib import Path 6 | from typing import List, Optional 7 | 8 | import mlx.core as mx 9 | import mlx.nn as nn 10 | import numpy as np 11 | from huggingface_hub import snapshot_download 12 | 13 | from .language import LanguageModel, TextConfig 14 | from .vision import VisionConfig, VisionModel 15 | 16 | 17 | @dataclass 18 | class ModelConfig: 19 | text_config: TextConfig 20 | vision_config: VisionConfig 21 | model_type: str 22 | ignore_index: int = -100 23 
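# NOTE (editor): an illustrative, self-contained sketch of the embedding-splice step used by
# `_prepare_inputs_for_multimodal` above (and by several other models in this package): the
# positions of the image placeholder token are located with np.where and the projected image
# features are written into those rows of the text embeddings. The shapes and values below
# are toy numbers chosen for the example; 200092 is the llama4-style placeholder id from the
# config above.
import mlx.core as mx
import numpy as np

image_token_id = 200092
input_ids = mx.array([[1, 200092, 200092, 42]])           # (batch=1, seq_len=4)
inputs_embeds = mx.zeros((1, 4, 8))                       # (batch, seq_len, hidden)
image_features = mx.ones((2, 8))                          # one projected feature per image token

image_positions = np.where(np.array(input_ids) == image_token_id)[1].tolist()
inputs_embeds[:, image_positions, :] = image_features     # splice vision features in place
assert inputs_embeds[0, 1, 0].item() == 1.0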
| image_token_index: int = 32000 24 | vision_feature_select_strategy: str = "default" 25 | vision_feature_layer: int = -2 26 | vocab_size: int = 32000 27 | eos_token_id: Optional[List[int]] = None 28 | 29 | @classmethod 30 | def from_dict(cls, params): 31 | return cls( 32 | **{ 33 | k: v 34 | for k, v in params.items() 35 | if k in inspect.signature(cls).parameters 36 | } 37 | ) 38 | 39 | 40 | class LlavaMultiModalProjector(nn.Module): 41 | def __init__(self, config: ModelConfig): 42 | super().__init__() 43 | self.linear_1 = nn.Linear( 44 | config.vision_config.hidden_size, config.text_config.hidden_size, bias=True 45 | ) 46 | self.gelu = nn.GELU() 47 | self.linear_2 = nn.Linear( 48 | config.text_config.hidden_size, config.text_config.hidden_size, bias=True 49 | ) 50 | 51 | def __call__(self, x: mx.array) -> mx.array: 52 | x = self.linear_1(x) 53 | x = self.gelu(x) 54 | x = self.linear_2(x) 55 | return x 56 | 57 | 58 | class Model(nn.Module): 59 | def __init__(self, config: ModelConfig): 60 | super().__init__() 61 | self.config = config 62 | self.vision_tower = VisionModel(config.vision_config) 63 | self.language_model = LanguageModel(config.text_config) 64 | self.multi_modal_projector = LlavaMultiModalProjector(config) 65 | self.vision_feature_layer = config.vision_feature_layer 66 | self.vision_feature_select_strategy = config.vision_feature_select_strategy 67 | 68 | def get_input_embeddings( 69 | self, 70 | input_ids: Optional[mx.array] = None, 71 | pixel_values: Optional[mx.array] = None, 72 | ): 73 | if pixel_values is None: 74 | return self.language_model.model.embed_tokens(input_ids) 75 | 76 | # Get the input embeddings from the language model 77 | inputs_embeds = self.language_model.model.embed_tokens(input_ids) 78 | 79 | # Get the ouptut hidden states from the vision model 80 | *_, hidden_states = self.vision_tower( 81 | pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True 82 | ) 83 | 84 | # Select the hidden states from the desired layer 85 | selected_image_feature = hidden_states[self.vision_feature_layer] 86 | 87 | if isinstance(self.vision_feature_layer, int): 88 | if self.vision_feature_select_strategy == "default": 89 | selected_image_feature = selected_image_feature[:, 1:] 90 | 91 | else: 92 | hs_pool = [ 93 | hidden_states[layer_idx] for layer_idx in self.vision_feature_layer 94 | ] 95 | # For default; crop CLS from each hidden state in the hidden state pool 96 | if self.vision_feature_select_strategy == "default": 97 | hs_pool = [hs[:, 1:] for hs in hs_pool] 98 | selected_image_feature = mx.concatenate(hs_pool, axis=-1) 99 | 100 | # Pass image features through the multi-modal projector 101 | image_features = self.multi_modal_projector(selected_image_feature) 102 | 103 | # Insert special image tokens in the input_ids 104 | final_inputs_embeds = self._merge_input_ids_with_image_features( 105 | image_features, inputs_embeds, input_ids 106 | ) 107 | return final_inputs_embeds 108 | 109 | def _merge_input_ids_with_image_features( 110 | self, image_features, inputs_embeds, input_ids 111 | ): 112 | image_token_index = self.config.image_token_index 113 | 114 | # Positions of tokens in input_ids, assuming batch size is 1 115 | image_positions = np.where(input_ids == image_token_index)[1].tolist() 116 | num_images, _, vision_hidden_size = image_features.shape 117 | 118 | reshaped_image_hidden_states = image_features.reshape(-1, vision_hidden_size) 119 | 120 | # cast to the dtype of the input_embeds to support quantized models 121 | reshaped_image_hidden_states = 
reshaped_image_hidden_states.astype( 122 | inputs_embeds.dtype 123 | ) 124 | 125 | # Pad image_positions to match the length of reshaped_image_hidden_states 126 | num_positions_needed = len(image_positions) 127 | 128 | if reshaped_image_hidden_states.shape[0] > num_positions_needed: 129 | # TODO: Think about how to handle this case 130 | raise ValueError( 131 | "Llava model supports only one image per input. Please check your input_ids and pixel_values." 132 | ) 133 | 134 | inputs_embeds[:, image_positions, :] = reshaped_image_hidden_states 135 | return inputs_embeds 136 | 137 | def __call__( 138 | self, 139 | input_ids: mx.array, 140 | pixel_values: mx.array, 141 | mask: mx.array, 142 | cache=None, 143 | **kwargs, 144 | ): 145 | input_embddings = self.get_input_embeddings(input_ids, pixel_values) 146 | logits = self.language_model( 147 | input_ids, mask=mask, cache=cache, inputs_embeds=input_embddings 148 | ) 149 | return logits 150 | 151 | @staticmethod 152 | def from_pretrained(path_or_hf_repo: str): 153 | path = Path(path_or_hf_repo) 154 | if not path.exists(): 155 | path = Path( 156 | snapshot_download( 157 | repo_id=path_or_hf_repo, 158 | allow_patterns=[ 159 | "*.json", 160 | "*.safetensors", 161 | "*.py", 162 | "tokenizer.model", 163 | "*.tiktoken", 164 | ], 165 | ) 166 | ) 167 | 168 | with open(path / "config.json", "r") as f: 169 | model_config = json.load(f) 170 | 171 | model_config = ModelConfig.from_dict(model_config) 172 | 173 | model_config.vision_config = VisionConfig.from_dict(model_config.vision_config) 174 | model_config.text_config = TextConfig.from_dict(model_config.text_config) 175 | 176 | model = Model(model_config) 177 | weight_files = glob.glob(str(path / "*.safetensors")) 178 | if not weight_files: 179 | raise FileNotFoundError(f"No safetensors found in {path}") 180 | 181 | weights = {} 182 | for wf in weight_files: 183 | weights.update(mx.load(wf)) 184 | 185 | weights = VisionModel.sanitize(weights) 186 | weights = LanguageModel.sanitize(weights) 187 | 188 | model.load_weights(list(weights.items())) 189 | return model 190 | -------------------------------------------------------------------------------- /mlx_vlm/models/llava_bunny/__init__.py: -------------------------------------------------------------------------------- 1 | from .llava_bunny import ( 2 | ImageProcessor, 3 | LanguageModel, 4 | Model, 5 | ModelConfig, 6 | TextConfig, 7 | VisionConfig, 8 | VisionModel, 9 | ) 10 | -------------------------------------------------------------------------------- /mlx_vlm/models/llava_next/__init__.py: -------------------------------------------------------------------------------- 1 | from .llava_next import ( 2 | LanguageModel, 3 | Model, 4 | ModelConfig, 5 | TextConfig, 6 | VisionConfig, 7 | VisionModel, 8 | ) 9 | -------------------------------------------------------------------------------- /mlx_vlm/models/llava_next/llava_next.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import inspect 3 | import json 4 | from dataclasses import dataclass 5 | from pathlib import Path 6 | from typing import List, Optional 7 | 8 | import mlx.core as mx 9 | import mlx.nn as nn 10 | import numpy as np 11 | from huggingface_hub import snapshot_download 12 | 13 | from .language import LanguageModel, TextConfig 14 | from .vision import VisionConfig, VisionModel 15 | 16 | 17 | @dataclass 18 | class ModelConfig: 19 | text_config: TextConfig 20 | vision_config: VisionConfig 21 | model_type: str 22 | ignore_index: int = -100 23 
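# NOTE (editor): an illustrative sketch of the vision-feature selection performed in
# `get_input_embeddings` of the LLaVA-family models above: one hidden-state layer is picked
# (vision_feature_layer, -2 by default) and, under the "default" strategy, the leading CLS
# token is cropped before the features go through the multi-modal projector. The layer count
# and tensor shapes below are toy values for the example only.
import mlx.core as mx

hidden_states = [mx.zeros((1, 577, 1024)) for _ in range(25)]   # toy per-layer vision outputs
vision_feature_layer = -2
selected = hidden_states[vision_feature_layer]                  # (1, 577, 1024)
selected = selected[:, 1:]                                      # "default": drop the CLS token
assert selected.shape == (1, 576, 1024)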
| image_token_index: int = 32000 24 | vision_feature_select_strategy: str = "default" 25 | vision_feature_layer: int = -2 26 | vocab_size: int = 32000 27 | eos_token_id: Optional[List[int]] = None 28 | 29 | @classmethod 30 | def from_dict(cls, params): 31 | return cls( 32 | **{ 33 | k: v 34 | for k, v in params.items() 35 | if k in inspect.signature(cls).parameters 36 | } 37 | ) 38 | 39 | 40 | class LlavaMultiModalProjector(nn.Module): 41 | def __init__(self, config: ModelConfig): 42 | super().__init__() 43 | self.linear_1 = nn.Linear( 44 | config.vision_config.hidden_size, config.text_config.hidden_size, bias=True 45 | ) 46 | self.gelu = nn.GELU() 47 | self.linear_2 = nn.Linear( 48 | config.text_config.hidden_size, config.text_config.hidden_size, bias=True 49 | ) 50 | 51 | def __call__(self, x: mx.array) -> mx.array: 52 | x = self.linear_1(x) 53 | x = self.gelu(x) 54 | x = self.linear_2(x) 55 | return x 56 | 57 | 58 | class Model(nn.Module): 59 | def __init__(self, config: ModelConfig): 60 | super().__init__() 61 | self.config = config 62 | self.vision_tower = VisionModel(config.vision_config) 63 | self.language_model = LanguageModel(config.text_config) 64 | embed_std = 1 / mx.sqrt(config.text_config.hidden_size) 65 | self.image_newline = ( 66 | mx.random.normal((config.text_config.hidden_size,)) * embed_std 67 | ) 68 | 69 | self.multi_modal_projector = LlavaMultiModalProjector(config) 70 | self.vision_feature_layer = config.vision_feature_layer 71 | self.vision_feature_select_strategy = config.vision_feature_select_strategy 72 | 73 | def get_input_embeddings( 74 | self, 75 | input_ids: Optional[mx.array] = None, 76 | pixel_values: Optional[mx.array] = None, 77 | ): 78 | if pixel_values is None: 79 | return self.language_model.model.embed_tokens(input_ids) 80 | 81 | # Get the input embeddings from the language model 82 | inputs_embeds = self.language_model.model.embed_tokens(input_ids) 83 | 84 | # Get the ouptut hidden states from the vision model 85 | *_, hidden_states = self.vision_tower( 86 | pixel_values[0].transpose(0, 2, 3, 1), output_hidden_states=True 87 | ) 88 | 89 | # Select the hidden states from the desired layer 90 | selected_image_feature = hidden_states[self.vision_feature_layer] 91 | 92 | if self.vision_feature_select_strategy == "default": 93 | selected_image_feature = selected_image_feature[:, 1:] 94 | elif self.vision_feature_select_strategy == "full": 95 | selected_image_feature = selected_image_feature 96 | else: 97 | raise ValueError( 98 | "Unexpected feature selection strategy: " 99 | f"{self.vision_feature_select_strategy}" 100 | ) 101 | 102 | # Pass image features through the multi-modal projector 103 | image_features = self.multi_modal_projector(selected_image_feature) 104 | 105 | # Add a newline token to the image features 106 | if self.image_newline is not None: 107 | self.image_newline = np.array(self.image_newline)[None, None, :] 108 | self.image_newline = np.broadcast_to( 109 | self.image_newline, image_features.shape 110 | ) 111 | image_newline = mx.array(self.image_newline) 112 | image_features = mx.concatenate([image_features, image_newline], axis=0) 113 | 114 | # Insert special image tokens in the input_ids 115 | final_inputs_embeds = self._merge_input_ids_with_image_features( 116 | image_features, inputs_embeds, input_ids 117 | ) 118 | return final_inputs_embeds 119 | 120 | def _merge_input_ids_with_image_features( 121 | self, image_features, inputs_embeds, input_ids 122 | ): 123 | image_token_index = self.config.image_token_index 124 | num_images, 
num_image_patches, embed_dim = image_features.shape 125 | 126 | image_positions = np.where(input_ids == image_token_index)[1].tolist() 127 | 128 | text_segments = [] 129 | start_idx = 0 130 | 131 | for position in image_positions: 132 | text_segments.append(inputs_embeds[:, start_idx:position]) 133 | start_idx = position + 1 134 | 135 | image_embeddings = mx.split(image_features, image_features.shape[0]) 136 | final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p] 137 | final_embeddings += [inputs_embeds[:, start_idx:]] 138 | 139 | # Create a final embedding of shape 140 | # (1, num_image_patches*num_images + sequence_len, embed_dim) 141 | return mx.concatenate(final_embeddings, axis=1) 142 | 143 | def __call__( 144 | self, 145 | input_ids: mx.array, 146 | pixel_values: mx.array, 147 | mask: mx.array, 148 | cache=None, 149 | **kwargs, 150 | ): 151 | 152 | input_embddings = self.get_input_embeddings(input_ids, pixel_values) 153 | logits = self.language_model( 154 | input_ids, cache=cache, inputs_embeds=input_embddings 155 | ) 156 | return logits 157 | 158 | @staticmethod 159 | def from_pretrained(path_or_hf_repo: str): 160 | path = Path(path_or_hf_repo) 161 | if not path.exists(): 162 | path = Path( 163 | snapshot_download( 164 | repo_id=path_or_hf_repo, 165 | allow_patterns=[ 166 | "*.json", 167 | "*.safetensors", 168 | "*.py", 169 | "tokenizer.model", 170 | "*.tiktoken", 171 | ], 172 | ) 173 | ) 174 | 175 | with open(path / "config.json", "r") as f: 176 | model_config = json.load(f) 177 | 178 | model_config = ModelConfig.from_dict(model_config) 179 | 180 | model_config.vision_config = VisionConfig.from_dict(model_config.vision_config) 181 | model_config.text_config = TextConfig.from_dict(model_config.text_config) 182 | 183 | model = Model(model_config) 184 | weight_files = glob.glob(str(path / "*.safetensors")) 185 | if not weight_files: 186 | raise FileNotFoundError(f"No safetensors found in {path}") 187 | 188 | weights = {} 189 | for wf in weight_files: 190 | weights.update(mx.load(wf)) 191 | 192 | weights = VisionModel.sanitize(weights) 193 | weights = LanguageModel.sanitize(weights) 194 | 195 | model.load_weights(list(weights.items())) 196 | return model 197 | -------------------------------------------------------------------------------- /mlx_vlm/models/mistral3/__init__.py: -------------------------------------------------------------------------------- 1 | from .mistral3 import ( 2 | LanguageModel, 3 | Model, 4 | ModelConfig, 5 | TextConfig, 6 | VisionConfig, 7 | VisionModel, 8 | ) 9 | -------------------------------------------------------------------------------- /mlx_vlm/models/mllama/__init__.py: -------------------------------------------------------------------------------- 1 | from .mllama import ( 2 | LanguageModel, 3 | Model, 4 | ModelConfig, 5 | TextConfig, 6 | VisionConfig, 7 | VisionModel, 8 | ) 9 | -------------------------------------------------------------------------------- /mlx_vlm/models/mllama/mllama.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import inspect 3 | import json 4 | from dataclasses import dataclass 5 | from pathlib import Path 6 | from typing import List, Optional, Tuple 7 | 8 | import mlx.core as mx 9 | import mlx.nn as nn 10 | from huggingface_hub import snapshot_download 11 | 12 | from ..cache import KVCache 13 | from .language import LanguageModel, TextConfig 14 | from .vision import VisionConfig, VisionModel 15 | 16 | 17 | @dataclass 18 | class ModelConfig: 
19 | text_config: TextConfig 20 | vision_config: VisionConfig 21 | model_type: str 22 | ignore_index: int = -100 23 | image_token_index: int = 128256 24 | vision_feature_select_strategy: str = "default" 25 | vision_feature_layer: int = -2 26 | vocab_size: int = 32000 27 | eos_token_id: Optional[List[int]] = None 28 | 29 | @classmethod 30 | def from_dict(cls, params): 31 | return cls( 32 | **{ 33 | k: v 34 | for k, v in params.items() 35 | if k in inspect.signature(cls).parameters 36 | } 37 | ) 38 | 39 | 40 | class Model(nn.Module): 41 | def __init__(self, config: ModelConfig): 42 | super().__init__() 43 | self.config = config 44 | self.vision_tower = VisionModel(config.vision_config) 45 | self.language_model = LanguageModel(config.text_config) 46 | self.multi_modal_projector = nn.Linear( 47 | config.vision_config.vision_output_dim, 48 | config.text_config.hidden_size, 49 | bias=True, 50 | ) 51 | 52 | def __call__( 53 | self, 54 | input_ids: mx.array, 55 | pixel_values: mx.array, 56 | mask: mx.array, 57 | cache: Optional[KVCache] = None, 58 | **kwargs, 59 | ) -> Tuple[mx.array, Optional[mx.array]]: 60 | 61 | aspect_ratio_ids = kwargs.pop("aspect_ratio_ids", None) 62 | aspect_ratio_mask = kwargs.pop("aspect_ratio_mask", None) 63 | cross_attention_mask = kwargs.pop("cross_attention_mask", None) 64 | 65 | inputs_embeds = None 66 | 67 | # Process vision input if provided 68 | if pixel_values is not None: 69 | if aspect_ratio_ids is None: 70 | raise ValueError( 71 | "`aspect_ratio_ids` must be provided if `pixel_values` is provided" 72 | ) 73 | 74 | vision_outputs = self.vision_tower( 75 | pixel_values=pixel_values, 76 | aspect_ratio_ids=aspect_ratio_ids, 77 | aspect_ratio_mask=aspect_ratio_mask, 78 | ) 79 | cross_attention_states = vision_outputs[0] 80 | 81 | cross_attention_states = self.multi_modal_projector( 82 | cross_attention_states 83 | ).reshape( 84 | -1, 85 | cross_attention_states.shape[-2], 86 | self.config.text_config.hidden_size, 87 | ) 88 | 89 | else: 90 | cross_attention_states = None 91 | 92 | # Prepare cross attention mask 93 | if cross_attention_mask is not None: 94 | cross_attention_mask, full_text_row_masked_out_mask = ( 95 | self._prepare_cross_attention_mask( 96 | cross_attention_mask, 97 | num_vision_tokens=( 98 | self.config.vision_config.image_size 99 | // self.config.vision_config.patch_size 100 | ) 101 | ** 2 102 | + 1, 103 | ) 104 | ) 105 | else: 106 | full_text_row_masked_out_mask = None 107 | 108 | if cross_attention_mask is not None: 109 | cache_position = mx.arange(input_ids.shape[1], dtype=mx.int32) 110 | cross_attention_mask = cross_attention_mask[:, :, cache_position] 111 | full_text_row_masked_out_mask = full_text_row_masked_out_mask[ 112 | :, :, cache_position 113 | ] 114 | 115 | # Process language input 116 | outputs = self.language_model( 117 | input_ids=input_ids, 118 | mask=mask, 119 | cross_attention_states=cross_attention_states, 120 | cross_attention_mask=cross_attention_mask, 121 | full_text_row_masked_out_mask=full_text_row_masked_out_mask, 122 | inputs_embeds=inputs_embeds, 123 | cache=cache, 124 | ) 125 | 126 | return outputs 127 | 128 | def _prepare_cross_attention_mask( 129 | self, 130 | cross_attention_mask: mx.array, 131 | num_vision_tokens: int, 132 | ) -> Tuple[mx.array, mx.array]: 133 | batch_size, text_total_length, *_ = cross_attention_mask.shape 134 | cross_attention_mask = mx.repeat( 135 | cross_attention_mask, num_vision_tokens, axis=3 136 | ) 137 | cross_attention_mask = cross_attention_mask.reshape( 138 | batch_size, 
text_total_length, -1 139 | ) 140 | cross_attention_mask = mx.expand_dims(cross_attention_mask, 1) 141 | 142 | # Invert the mask 143 | inverted_cross_attn_mask = 1.0 - cross_attention_mask 144 | fill_array = mx.array(-1e9) 145 | fill_array = mx.broadcast_to(fill_array, inverted_cross_attn_mask.shape) 146 | cross_attention_mask = mx.where( 147 | inverted_cross_attn_mask, 148 | fill_array, 149 | cross_attention_mask, 150 | ) 151 | 152 | # Apply full-row bias 153 | full_text_row_masked_out_mask = mx.any( 154 | cross_attention_mask != -1e9, 155 | axis=-1, 156 | keepdims=True, 157 | ) 158 | cross_attention_mask *= full_text_row_masked_out_mask 159 | 160 | return cross_attention_mask, full_text_row_masked_out_mask 161 | 162 | @staticmethod 163 | def from_pretrained(path_or_hf_repo: str): 164 | path = Path(path_or_hf_repo) 165 | if not path.exists(): 166 | path = Path( 167 | snapshot_download( 168 | repo_id=path_or_hf_repo, 169 | allow_patterns=[ 170 | "*.json", 171 | "*.safetensors", 172 | "*.py", 173 | "tokenizer.model", 174 | "*.tiktoken", 175 | ], 176 | ) 177 | ) 178 | 179 | with open(path / "config.json", "r") as f: 180 | model_config = json.load(f) 181 | 182 | model_config = ModelConfig.from_dict(model_config) 183 | 184 | model_config.vision_config = VisionConfig.from_dict(model_config.vision_config) 185 | model_config.text_config = TextConfig.from_dict(model_config) 186 | 187 | model = Model(model_config) 188 | weight_files = glob.glob(str(path / "*.safetensors")) 189 | if not weight_files: 190 | raise FileNotFoundError(f"No safetensors found in {path}") 191 | 192 | weights = {} 193 | for wf in weight_files: 194 | weights.update(mx.load(wf)) 195 | 196 | weights = VisionModel.sanitize(weights) 197 | weights = LanguageModel.sanitize(weights) 198 | 199 | model.load_weights(list(weights.items())) 200 | return model 201 | 202 | def sanitize(self, weights): 203 | def transform_key(key): 204 | if "vision_tower" not in key: 205 | key = key.replace("vision_model", "vision_tower") 206 | return key 207 | 208 | return {transform_key(k): v for k, v in weights.items()} 209 | -------------------------------------------------------------------------------- /mlx_vlm/models/molmo/__init__.py: -------------------------------------------------------------------------------- 1 | from .molmo import ( 2 | LanguageModel, 3 | Model, 4 | ModelConfig, 5 | TextConfig, 6 | VisionConfig, 7 | VisionModel, 8 | ) 9 | -------------------------------------------------------------------------------- /mlx_vlm/models/molmo/molmo.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import inspect 3 | import json 4 | from dataclasses import dataclass, field 5 | from pathlib import Path 6 | from typing import Dict, List, Optional, Tuple, Union 7 | 8 | import mlx.core as mx 9 | import mlx.nn as nn 10 | import numpy as np 11 | from huggingface_hub import snapshot_download 12 | 13 | from .language import LanguageModel, TextConfig 14 | from .vision import VisionConfig, VisionModel 15 | 16 | 17 | @dataclass 18 | class ModelConfig: 19 | text_config: TextConfig = field(default_factory=TextConfig) 20 | vision_config: VisionConfig = field(default_factory=VisionConfig) 21 | model_type: str = "molmo" 22 | image_feature_dropout: float = 0.0 23 | image_pooling_h: int = 2 24 | image_pooling_w: int = 2 25 | image_pooling_2d: str = "attention" 26 | image_projector: str = "mlp" 27 | eos_token_id: Optional[List[int]] = None 28 | 29 | @classmethod 30 | def from_dict(cls, params): 31 | return cls( 
32 | **{ 33 | k: v 34 | for k, v in params.items() 35 | if k in inspect.signature(cls).parameters 36 | } 37 | ) 38 | 39 | 40 | class Model(nn.Module): 41 | def __init__(self, config: ModelConfig): 42 | super().__init__() 43 | self.config = config 44 | self.language_model = LanguageModel(config.text_config) 45 | self.vision_tower = VisionModel(config.vision_config) 46 | 47 | def __call__( 48 | self, 49 | input_ids: mx.array, 50 | pixel_values: mx.array, 51 | mask: mx.array, 52 | cache=None, 53 | **kwargs, 54 | ) -> Dict[str, Union[mx.array, List[Tuple[mx.array, mx.array]]]]: 55 | if input_ids.ndim == 1: 56 | input_ids = input_ids[None, :] 57 | 58 | batch_size, seq_len = input_ids.shape 59 | 60 | image_input_idx = kwargs.get("image_input_idx", None) 61 | image_masks = kwargs.get("image_masks", None) 62 | 63 | if pixel_values is not None: 64 | assert ( 65 | image_masks is not None and image_input_idx is not None 66 | ), "image_masks and image_input_idx must be provided when images are given" 67 | 68 | dtype = self.vision_tower.image_vit.patch_embedding.weight.dtype 69 | pixel_values = pixel_values.astype(dtype) 70 | 71 | # Process images 72 | if pixel_values.ndim == 3: 73 | pixel_values = mx.expand_dims(pixel_values, 0) 74 | image_masks = ( 75 | mx.expand_dims(image_masks, 0) if image_masks is not None else None 76 | ) 77 | image_input_idx = ( 78 | mx.expand_dims(image_input_idx, 0) 79 | if image_input_idx is not None 80 | else None 81 | ) 82 | 83 | image_features, cls_embed = self.vision_tower(pixel_values, image_masks) 84 | 85 | # Insert image features into the input embeddings 86 | num_image, num_patch = image_features.shape[1:3] 87 | 88 | assert image_input_idx.shape == ( 89 | batch_size, 90 | num_image, 91 | num_patch, 92 | ), f"image_input_idx.shape: {image_input_idx.shape}, expected: {(batch_size, num_image, num_patch)}" 93 | 94 | # Insert image features into the input embeddings 95 | image_features = image_features.reshape( 96 | batch_size, num_image * num_patch, -1 97 | ) 98 | image_input_idx = image_input_idx.reshape(batch_size, num_image * num_patch) 99 | 100 | valid = np.where(image_input_idx >= 0)[0].tolist() 101 | batch_idx = mx.arange(batch_size) 102 | batch_idx = mx.tile(batch_idx[:, None], [1, image_features.shape[1]]) 103 | 104 | input_embeddings = self.language_model.model.wte(input_ids) 105 | input_embeddings[ 106 | batch_idx[valid], image_input_idx[valid] 107 | ] += image_features[valid] 108 | else: 109 | input_embeddings = None 110 | 111 | # Forward pass through the language model 112 | logits = self.language_model( 113 | input_ids, 114 | inputs_embeds=input_embeddings, 115 | mask=mask, 116 | cache=cache, 117 | ) 118 | 119 | return logits 120 | 121 | @staticmethod 122 | def from_pretrained(path_or_hf_repo: str): 123 | path = Path(path_or_hf_repo) 124 | if not path.exists(): 125 | path = Path( 126 | snapshot_download( 127 | repo_id=path_or_hf_repo, 128 | allow_patterns=[ 129 | "*.json", 130 | "*.safetensors", 131 | "*.py", 132 | "tokenizer.model", 133 | "*.tiktoken", 134 | ], 135 | ) 136 | ) 137 | 138 | with open(path / "config.json", "r") as f: 139 | model_config = json.load(f) 140 | 141 | model_config = ModelConfig.from_dict(model_config) 142 | 143 | model_config.vision_config = VisionConfig.from_dict(model_config.vision_config) 144 | model_config.text_config = TextConfig.from_dict(model_config.text_config) 145 | 146 | model = Model(model_config) 147 | weight_files = glob.glob(str(path / "*.safetensors")) 148 | if not weight_files: 149 | raise FileNotFoundError(f"No 
safetensors found in {path}") 150 | 151 | weights = {} 152 | for wf in weight_files: 153 | weights.update(mx.load(wf)) 154 | 155 | weights = VisionModel.sanitize(weights) 156 | weights = LanguageModel.sanitize(weights) 157 | 158 | model.load_weights(list(weights.items())) 159 | return model 160 | 161 | def sanitize(self, weights): 162 | def transform_key(key): 163 | if "model.transformer" in key: 164 | key = key.replace("model.transformer", "language_model.model") 165 | if "model.vision_backbone" in key: 166 | key = key.replace("model.vision_backbone", "vision_tower") 167 | return key 168 | 169 | return {transform_key(k): v for k, v in weights.items()} 170 | -------------------------------------------------------------------------------- /mlx_vlm/models/multi_modality/__init__.py: -------------------------------------------------------------------------------- 1 | from .multi_modality import ( 2 | ImageProcessor, 3 | LanguageModel, 4 | Model, 5 | ModelConfig, 6 | ProjectorConfig, 7 | TextConfig, 8 | VisionConfig, 9 | VisionModel, 10 | ) 11 | -------------------------------------------------------------------------------- /mlx_vlm/models/paligemma/__init__.py: -------------------------------------------------------------------------------- 1 | from .paligemma import ( 2 | LanguageModel, 3 | Model, 4 | ModelConfig, 5 | TextConfig, 6 | VisionConfig, 7 | VisionModel, 8 | ) 9 | -------------------------------------------------------------------------------- /mlx_vlm/models/paligemma/paligemma.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import inspect 3 | import json 4 | from dataclasses import dataclass 5 | from pathlib import Path 6 | from typing import List, Optional 7 | 8 | import mlx.core as mx 9 | import mlx.nn as nn 10 | from huggingface_hub import snapshot_download 11 | 12 | from .language import LanguageModel, TextConfig 13 | from .vision import VisionConfig, VisionModel 14 | 15 | 16 | @dataclass 17 | class ModelConfig: 18 | text_config: TextConfig 19 | vision_config: VisionConfig 20 | model_type: str 21 | vocab_size: int = 257152 22 | ignore_index: int = -100 23 | image_token_index: int = 257152 24 | hidden_size: int = 2048 25 | pad_token_id: int = 0 26 | eos_token_id: Optional[List[int]] = None 27 | 28 | @classmethod 29 | def from_dict(cls, params): 30 | return cls( 31 | **{ 32 | k: v 33 | for k, v in params.items() 34 | if k in inspect.signature(cls).parameters 35 | } 36 | ) 37 | 38 | 39 | class PaliGemmaMultiModalProjector(nn.Module): 40 | def __init__(self, config: ModelConfig): 41 | super().__init__() 42 | self.linear = nn.Linear( 43 | config.vision_config.hidden_size, 44 | config.vision_config.projection_dim, 45 | bias=True, 46 | ) 47 | 48 | def __call__(self, x: mx.array) -> mx.array: 49 | output = self.linear(x) 50 | return output 51 | 52 | 53 | class Model(nn.Module): 54 | def __init__(self, config: ModelConfig): 55 | super().__init__() 56 | self.model_type = config.model_type 57 | self.config = config 58 | 59 | self.vision_tower = VisionModel(config.vision_config) 60 | self.language_model = LanguageModel(config.text_config) 61 | self.multi_modal_projector = PaliGemmaMultiModalProjector(config) 62 | 63 | def get_input_embeddings( 64 | self, 65 | input_ids: Optional[mx.array] = None, 66 | pixel_values: Optional[mx.array] = None, 67 | mask: Optional[mx.array] = None, 68 | ): 69 | if pixel_values is None: 70 | return self.language_model.model.embed_tokens(input_ids), None 71 | 72 | inputs_embeds = 
self.language_model.model.embed_tokens(input_ids) 73 | 74 | hidden_state, _, _ = self.vision_tower( 75 | pixel_values.transpose(0, 2, 3, 1).astype(inputs_embeds.dtype), 76 | output_hidden_states=True, 77 | ) 78 | 79 | image_features = hidden_state[None, :].astype(pixel_values.dtype) 80 | image_features = self.multi_modal_projector(image_features) 81 | 82 | final_inputs_embeds, final_attention_mask_4d = ( 83 | self._prepare_inputs_for_multimodal( 84 | image_features, inputs_embeds, input_ids, mask 85 | ) 86 | ) 87 | return final_inputs_embeds, final_attention_mask_4d 88 | 89 | def _prepare_inputs_for_multimodal( 90 | self, image_features, inputs_embeds, input_ids, attention_mask 91 | ): 92 | _, _, embed_dim = image_features.shape 93 | 94 | batch_size, sequence_length = input_ids.shape 95 | scaled_image_features = image_features / (self.config.hidden_size**0.5) 96 | final_embedding = mx.zeros((batch_size, sequence_length, embed_dim)) 97 | 98 | text_mask = (input_ids != self.config.image_token_index) & ( 99 | input_ids != self.config.pad_token_id 100 | ) 101 | image_mask = input_ids == self.config.image_token_index 102 | pad_mask = input_ids == self.config.pad_token_id 103 | 104 | # expand masks to match embedding dimension 105 | text_mask_expanded = mx.expand_dims(text_mask, -1) 106 | text_mask_expanded = mx.repeat(text_mask_expanded, embed_dim, axis=-1) 107 | pad_mask_expanded = mx.expand_dims(pad_mask, -1) 108 | pad_mask_expanded = mx.repeat(pad_mask_expanded, embed_dim, axis=-1) 109 | 110 | # insert padding and text token embeddings 111 | final_embedding = mx.where(text_mask_expanded, inputs_embeds, final_embedding) 112 | final_embedding = mx.where( 113 | pad_mask_expanded, mx.zeros_like(final_embedding), final_embedding 114 | ) 115 | pad_size = final_embedding.shape[1] - scaled_image_features.shape[1] 116 | scaled_image_features = mx.pad( 117 | scaled_image_features, ((0, 0), (0, pad_size), (0, 0)) 118 | ) 119 | # insert image embeddings - the image mask is always less or equal to the sentence in length 120 | image_mask_expanded = mx.expand_dims(image_mask, -1) 121 | image_mask_expanded = mx.repeat(image_mask_expanded, embed_dim, axis=-1) 122 | final_embedding = mx.where( 123 | image_mask_expanded, scaled_image_features, final_embedding 124 | ) 125 | 126 | final_embedding = mx.where( 127 | pad_mask_expanded, mx.zeros_like(final_embedding), final_embedding 128 | ) 129 | 130 | attention_mask_expanded_1 = mx.expand_dims(attention_mask, 1) 131 | attention_mask_expanded_2 = mx.expand_dims(attention_mask, 2) 132 | final_attention_mask_4d = attention_mask_expanded_1 * attention_mask_expanded_2 133 | final_attention_mask_4d = final_attention_mask_4d 134 | final_attention_mask_4d = mx.expand_dims(final_attention_mask_4d, 1) 135 | final_embedding = mx.array(final_embedding) 136 | return final_embedding, final_attention_mask_4d 137 | 138 | def __call__( 139 | self, 140 | input_ids: mx.array, 141 | pixel_values: mx.array, 142 | mask: Optional[mx.array] = None, 143 | cache: Optional[mx.array] = None, 144 | **kwargs, 145 | ): 146 | input_embeddings, final_attention_mask_4d = self.get_input_embeddings( 147 | input_ids, pixel_values, mask 148 | ) 149 | 150 | logits = self.language_model( 151 | inputs=input_ids, 152 | cache=cache, 153 | inputs_embeds=input_embeddings, 154 | mask=final_attention_mask_4d, 155 | ) 156 | return logits 157 | 158 | @staticmethod 159 | def from_pretrained(path_or_hf_repo: str): 160 | path = Path(path_or_hf_repo) 161 | if not path.exists(): 162 | path = Path( 163 | 
snapshot_download( 164 | repo_id=path_or_hf_repo, 165 | allow_patterns=[ 166 | "*.json", 167 | "*.safetensors", 168 | "*.py", 169 | "tokenizer.model", 170 | "*.tiktoken", 171 | ], 172 | ) 173 | ) 174 | 175 | with open(path / "config.json", "r") as f: 176 | config = json.load(f) 177 | 178 | model_config = ModelConfig.from_dict(config) 179 | model_config.vision_config = VisionConfig.from_dict(config["vision_config"]) 180 | model_config.text_config = TextConfig.from_dict(config["text_config"]) 181 | 182 | model = Model(model_config) 183 | weight_files = glob.glob(str(path / "*.safetensors")) 184 | if not weight_files: 185 | raise FileNotFoundError(f"No safetensors found in {path}") 186 | 187 | weights = {} 188 | for wf in weight_files: 189 | weights.update(mx.load(wf)) 190 | 191 | weights = model.sanitize(weights=weights) 192 | 193 | weights = VisionModel(model_config.vision_config).sanitize(weights=weights) 194 | model.load_weights(list(weights.items())) 195 | return model 196 | -------------------------------------------------------------------------------- /mlx_vlm/models/phi3_v/__init__.py: -------------------------------------------------------------------------------- 1 | from .phi3_v import ( 2 | LanguageModel, 3 | Model, 4 | ModelConfig, 5 | TextConfig, 6 | VisionConfig, 7 | VisionModel, 8 | ) 9 | -------------------------------------------------------------------------------- /mlx_vlm/models/phi3_v/language.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from dataclasses import dataclass 3 | 4 | 5 | @dataclass 6 | class TextConfig: 7 | max_position_embeddings: int = 4096 8 | 9 | @classmethod 10 | def from_dict(cls, params): 11 | return cls( 12 | **{ 13 | k: v 14 | for k, v in params.items() 15 | if k in inspect.signature(cls).parameters 16 | } 17 | ) 18 | 19 | 20 | class LanguageModel: 21 | pass 22 | -------------------------------------------------------------------------------- /mlx_vlm/models/phi3_v/su_rope.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import mlx.core as mx 4 | 5 | 6 | class Phi3SuScaledRotaryEmbedding: 7 | def __init__( 8 | self, 9 | dims: int, 10 | traditional: bool = False, 11 | base: float = 10000.0, 12 | scale: float = 1.0, 13 | max_position_embeddings: int = 131072, 14 | original_max_position_embeddings: int = 4096, 15 | short_factor: list[float] | float = 1.0, 16 | long_factor: list[float] | float = 1.0, 17 | ): 18 | """ 19 | Phi3Su Scaled Rotary Embedding layer for Phi-3 models. 20 | 21 | Args: 22 | dims (int): The feature dimensions to be rotated. 23 | traditional (bool, optional): Unused. Default: ``False``. 24 | base (int, optional): Base for the exponential scaling. 25 | scale (float, optional): The scale used to scale the positions. Default: 1.0. 26 | max_position_embeddings (int, optional): The maximum sequence length that this model was trained with. This is used to determine the size of the original RoPE embeddings when using long scaling. Default: 131072. 27 | original_max_position_embeddings (int, optional): The maximum sequence length that this model was trained with. This is used to determine the size of the original RoPE embeddings when using long scaling. Default: 4096. 28 | short_factor (float or list of floats, optional): List of scaling factors for sequences of length lesser than original_max_position_embeddings. Default: 1.0. 
29 | long_factor (float or list of floats, optional): List of scaling factors for sequences of length greater than original_max_position_embeddings. Default: 1.0. 30 | """ 31 | self.inv_freq_short = 1.0 / ( 32 | mx.array(short_factor, dtype=mx.float32) 33 | * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims) 34 | ) 35 | self.inv_freq_long = 1.0 / ( 36 | scale 37 | * mx.array(long_factor, dtype=mx.float32) 38 | * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims) 39 | ) 40 | self.original_max_position_embeddings = original_max_position_embeddings 41 | self.scaling_factor = math.sqrt( 42 | 1 43 | + math.log(max_position_embeddings / original_max_position_embeddings) 44 | / math.log(original_max_position_embeddings) 45 | ) 46 | 47 | def _get_cos_sin(self, offset, L): 48 | position_ids = mx.arange(offset, offset + L, dtype=mx.float32)[None] 49 | inv_freq = ( 50 | self.inv_freq_long 51 | if position_ids.max() + 1 > self.original_max_position_embeddings 52 | else self.inv_freq_short 53 | ) 54 | inv_freq_expanded = mx.repeat( 55 | inv_freq[None, :, None], position_ids.shape[0], axis=0 56 | ) 57 | position_ids_expanded = position_ids[:, None, :] 58 | freqs = (inv_freq_expanded @ position_ids_expanded).transpose(0, 2, 1) 59 | emb = mx.concatenate([freqs, freqs], axis=-1) 60 | cos = mx.cos(emb) * self.scaling_factor 61 | sin = mx.sin(emb) * self.scaling_factor 62 | return mx.expand_dims(cos, axis=1), mx.expand_dims(sin, axis=1) 63 | 64 | def __call__(self, x, offset: int = 0): 65 | def _rotate_half(_x): 66 | midpoint = _x.shape[-1] // 2 67 | x1, x2 = _x[..., :midpoint], _x[..., midpoint:] 68 | return mx.concatenate([-x2, x1], axis=-1) 69 | 70 | cos, sin = self._get_cos_sin(offset, x.shape[2]) 71 | return (x * cos) + (_rotate_half(x) * sin) 72 | -------------------------------------------------------------------------------- /mlx_vlm/models/pixtral/__init__.py: -------------------------------------------------------------------------------- 1 | from .pixtral import ( 2 | LanguageModel, 3 | Model, 4 | ModelConfig, 5 | TextConfig, 6 | VisionConfig, 7 | VisionModel, 8 | ) 9 | -------------------------------------------------------------------------------- /mlx_vlm/models/qwen2_5_vl/__init__.py: -------------------------------------------------------------------------------- 1 | from .qwen2_5_vl import ( 2 | LanguageModel, 3 | Model, 4 | ModelConfig, 5 | TextConfig, 6 | VisionConfig, 7 | VisionModel, 8 | ) 9 | -------------------------------------------------------------------------------- /mlx_vlm/models/qwen2_vl/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import ModelConfig, TextConfig, VisionConfig 2 | from .qwen2_vl import LanguageModel, Model, VisionModel 3 | -------------------------------------------------------------------------------- /mlx_vlm/models/qwen2_vl/config.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from dataclasses import dataclass 3 | from typing import Dict, List, Optional, Union 4 | 5 | 6 | @dataclass 7 | class VisionConfig: 8 | model_type: str = "qwen2_vl" 9 | depth: int = 32 10 | embed_dim: int = 1280 11 | hidden_size: int = 1536 12 | num_heads: int = 16 13 | image_size: int = 384 14 | patch_size: int = 14 15 | vocab_size: int = 32000 16 | mlp_ratio: float = 4.0 17 | in_channels: int = 3 18 | layer_norm_eps: float = 1e-6 19 | spatial_patch_size: int = 14 20 | spatial_merge_size: int = 2 21 | temporal_patch_size: int = 2 22 | 23 | 
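# NOTE (editor): a worked example of the long-context attention scaling factor computed in
# Phi3SuScaledRotaryEmbedding.__init__ (phi3_v/su_rope.py above), using its default
# max_position_embeddings=131072 and original_max_position_embeddings=4096. The number is
# only a check of the formula, not a parameter taken from any checkpoint.
import math

max_pos, orig_max_pos = 131072, 4096
scaling_factor = math.sqrt(1 + math.log(max_pos / orig_max_pos) / math.log(orig_max_pos))
# log(32)/log(4096) = 5/12 ≈ 0.4167, so scaling_factor ≈ sqrt(1.4167) ≈ 1.1902
print(round(scaling_factor, 4))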
@classmethod 24 | def from_dict(cls, params): 25 | return cls( 26 | **{ 27 | k: v 28 | for k, v in params.items() 29 | if k in inspect.signature(cls).parameters 30 | } 31 | ) 32 | 33 | 34 | @dataclass 35 | class TextConfig: 36 | model_type: str 37 | hidden_size: int 38 | num_hidden_layers: int 39 | intermediate_size: int 40 | num_attention_heads: int 41 | rms_norm_eps: float 42 | vocab_size: int 43 | num_key_value_heads: Optional[int] = 8 44 | max_position_embeddings: Optional[int] = 40960 45 | rope_theta: float = 1000000.0 46 | rope_traditional: bool = False 47 | rope_scaling: Optional[Dict[str, Union[float, str]]] = None 48 | tie_word_embeddings: bool = False 49 | sliding_window: int = 32768 50 | use_sliding_window: bool = False 51 | use_cache: bool = True 52 | 53 | def __post_init__(self): 54 | if self.num_key_value_heads is None: 55 | self.num_key_value_heads = self.num_attention_heads 56 | 57 | if self.rope_scaling: 58 | required_keys = {"mrope_section", "type"} 59 | if not all(key in self.rope_scaling for key in required_keys): 60 | raise ValueError(f"rope_scaling must contain keys {required_keys}") 61 | 62 | if not self.rope_scaling["type"] in ["mrope", "default"]: 63 | raise ValueError(f"rope_scaling type must be 'mrope' or 'default'") 64 | 65 | @classmethod 66 | def from_dict(cls, params): 67 | return cls( 68 | **{ 69 | k: v 70 | for k, v in params.items() 71 | if k in inspect.signature(cls).parameters 72 | } 73 | ) 74 | 75 | 76 | @dataclass 77 | class ModelConfig: 78 | text_config: TextConfig 79 | vision_config: VisionConfig 80 | model_type: str 81 | ignore_index: int = -100 82 | image_token_index: int = 151655 83 | video_token_index: int = 151656 84 | vision_start_token_id: int = 151652 85 | vision_feature_select_strategy: str = "default" 86 | vision_feature_layer: int = -2 87 | vocab_size: int = 32000 88 | eos_token_id: Optional[List[int]] = None 89 | 90 | @classmethod 91 | def from_dict(cls, params): 92 | # Copy text config parameters from root level 93 | excluded_keys = {"vision_config"} 94 | params["text_config"] = dict( 95 | filter(lambda x: x[0] not in excluded_keys, params.items()) 96 | ) 97 | 98 | return cls( 99 | **{ 100 | k: v 101 | for k, v in params.items() 102 | if k in inspect.signature(cls).parameters 103 | } 104 | ) 105 | -------------------------------------------------------------------------------- /mlx_vlm/models/qwen2_vl/qwen2_vl.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import inspect 3 | import json 4 | from dataclasses import dataclass 5 | from pathlib import Path 6 | from typing import List, Optional 7 | 8 | import mlx.core as mx 9 | import mlx.nn as nn 10 | import numpy as np 11 | from huggingface_hub import snapshot_download 12 | 13 | from .config import ModelConfig, TextConfig, VisionConfig 14 | from .language import LanguageModel 15 | from .vision import VisionModel 16 | 17 | 18 | class Model(nn.Module): 19 | def __init__(self, config: ModelConfig): 20 | super().__init__() 21 | self.config = config 22 | self.vision_tower = VisionModel(config.vision_config) 23 | self.language_model = LanguageModel(config.text_config, config) 24 | 25 | def get_input_embeddings( 26 | self, 27 | input_ids: Optional[mx.array] = None, 28 | pixel_values: Optional[mx.array] = None, 29 | grid_thw: Optional[mx.array] = None, 30 | ): 31 | 32 | if pixel_values is None: 33 | return self.language_model.model.embed_tokens(input_ids) 34 | 35 | dtype = self.vision_tower.patch_embed.proj.weight.dtype 36 | pixel_values = 
pixel_values.astype(dtype) 37 | 38 | # Get the input embeddings from the language model 39 | inputs_embeds = self.language_model.model.embed_tokens(input_ids) 40 | 41 | # Get the output hidden states from the vision model 42 | hidden_states = self.vision_tower( 43 | pixel_values, grid_thw, output_hidden_states=False 44 | ) 45 | 46 | # hidden_states is already in the correct shape (num_features, hidden_dim) 47 | # Don't add extra batch dimension 48 | 49 | # Insert special image tokens in the input_ids 50 | final_inputs_embeds = self._merge_input_ids_with_image_features( 51 | hidden_states, inputs_embeds, input_ids 52 | ) 53 | return final_inputs_embeds 54 | 55 | def _merge_input_ids_with_image_features( 56 | self, image_features, inputs_embeds, input_ids 57 | ): 58 | """Merge image features into input embeddings at image token positions. 59 | 60 | Args: 61 | image_features: Vision features from the vision tower [num_features, hidden_dim] 62 | inputs_embeds: Input embeddings [batch_size, seq_len, hidden_dim] 63 | input_ids: Input token IDs [batch_size, seq_len] 64 | 65 | Returns: 66 | Updated input embeddings with image features inserted 67 | """ 68 | image_token_index = self.config.image_token_index 69 | video_token_index = self.config.video_token_index 70 | 71 | # Positions of image (or video) tokens in input_ids 72 | image_positions = input_ids == image_token_index 73 | if mx.sum(image_positions) == 0: 74 | image_positions = input_ids == video_token_index 75 | 76 | # Get dimensions 77 | batch_size, seq_len = input_ids.shape 78 | 79 | # Process each batch item 80 | batch_outputs = [] 81 | feature_start_idx = 0 82 | 83 | for batch_idx in range(batch_size): 84 | # Get mask for this batch 85 | image_mask = image_positions[batch_idx] 86 | num_positions = mx.sum(image_mask).item() 87 | 88 | if num_positions > 0: 89 | # Extract features for this batch 90 | batch_features = image_features[ 91 | feature_start_idx : feature_start_idx + num_positions 92 | ] 93 | 94 | # Validate we have the right number of features 95 | if batch_features.shape[0] != num_positions: 96 | raise ValueError( 97 | f"Number of image token positions ({num_positions}) does not match " 98 | f"number of image features ({batch_features.shape[0]}) for batch {batch_idx}" 99 | ) 100 | 101 | # Create indices for gathering 102 | cumsum = mx.cumsum(image_mask.astype(mx.int32)) 103 | feature_indices = mx.where(image_mask, cumsum - 1, 0) 104 | 105 | # Gather features 106 | gathered_features = batch_features[feature_indices] 107 | 108 | # Combine with original embeddings 109 | image_mask_expanded = mx.expand_dims(image_mask, axis=-1) 110 | batch_output = mx.where( 111 | image_mask_expanded, gathered_features, inputs_embeds[batch_idx] 112 | ) 113 | 114 | feature_start_idx += num_positions 115 | else: 116 | # No image tokens in this batch item 117 | batch_output = inputs_embeds[batch_idx] 118 | 119 | batch_outputs.append(batch_output) 120 | 121 | # Stack all batch outputs 122 | return mx.stack(batch_outputs, axis=0) 123 | 124 | def __call__( 125 | self, 126 | input_ids: mx.array, 127 | pixel_values: Optional[mx.array] = None, 128 | mask: Optional[mx.array] = None, 129 | cache=None, 130 | **kwargs, 131 | ): 132 | 133 | image_grid_thw = kwargs.pop("image_grid_thw", None) 134 | video_grid_thw = kwargs.pop("video_grid_thw", None) 135 | grid_thw = image_grid_thw if image_grid_thw is not None else video_grid_thw 136 | input_embeddings = self.get_input_embeddings(input_ids, pixel_values, grid_thw) 137 | kwargs = { 138 | "pixel_values": pixel_values, 139 | 
"image_grid_thw": image_grid_thw, 140 | "video_grid_thw": video_grid_thw, 141 | **kwargs, 142 | } 143 | logits = self.language_model( 144 | input_ids, input_embeddings, mask=mask, cache=cache, **kwargs 145 | ) 146 | return logits 147 | 148 | @staticmethod 149 | def from_pretrained(path_or_hf_repo: str): 150 | path = Path(path_or_hf_repo) 151 | if not path.exists(): 152 | path = Path( 153 | snapshot_download( 154 | repo_id=path_or_hf_repo, 155 | allow_patterns=[ 156 | "*.json", 157 | "*.safetensors", 158 | "*.py", 159 | "tokenizer.model", 160 | "*.tiktoken", 161 | ], 162 | ) 163 | ) 164 | 165 | with open(path / "config.json", "r") as f: 166 | model_config = json.load(f) 167 | 168 | model_config = ModelConfig.from_dict(model_config) 169 | 170 | model_config.vision_config = VisionConfig.from_dict(model_config.vision_config) 171 | model_config.text_config = TextConfig.from_dict(model_config.text_config) 172 | 173 | model = Model(model_config) 174 | weight_files = glob.glob(str(path / "*.safetensors")) 175 | if not weight_files: 176 | raise FileNotFoundError(f"No safetensors found in {path}") 177 | 178 | weights = {} 179 | for wf in weight_files: 180 | weights.update(mx.load(wf)) 181 | 182 | weights = VisionModel.sanitize(weights) 183 | weights = LanguageModel.sanitize(weights) 184 | 185 | model.load_weights(list(weights.items())) 186 | return model 187 | 188 | def sanitize(self, weights): 189 | def transform_key(key): 190 | if "vision_tower" not in key: 191 | key = key.replace("visual", "vision_tower") 192 | if "language_model" not in key: 193 | if "model" in key: 194 | key = key.replace("model", "language_model.model") 195 | elif "lm_head" in key: 196 | key = key.replace("lm_head", "language_model.lm_head") 197 | return key 198 | 199 | return {transform_key(k): v for k, v in weights.items()} 200 | -------------------------------------------------------------------------------- /mlx_vlm/models/smolvlm/__init__.py: -------------------------------------------------------------------------------- 1 | from .smolvlm import ( 2 | LanguageModel, 3 | Model, 4 | ModelConfig, 5 | TextConfig, 6 | VisionConfig, 7 | VisionModel, 8 | ) 9 | -------------------------------------------------------------------------------- /mlx_vlm/models/smolvlm/smolvlm.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | import numpy as np 3 | 4 | from ..idefics3 import LanguageModel 5 | from ..idefics3 import Model as Idefics3Model 6 | from ..idefics3 import ModelConfig, TextConfig, VisionConfig, VisionModel 7 | 8 | 9 | class Model(Idefics3Model): 10 | def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_ids): 11 | # Assumes bs == 1 12 | 13 | B, T, D_text = inputs_embeds.shape 14 | N, S, D_img = image_features.shape 15 | 16 | image_offset = 0 17 | cur_embeds = inputs_embeds[0] 18 | 19 | # Find positions of image tokens in the text 20 | image_token_index = self.config.image_token_index 21 | image_positions = np.where(input_ids == image_token_index)[1].tolist() 22 | num_image_tokens = len(image_positions) 23 | 24 | # If no image tokens => text-only 25 | if num_image_tokens == 0: 26 | empty_slice = image_features[0][:0, :] # shape (0, D) 27 | return mx.concatenate([cur_embeds, empty_slice], axis=0) 28 | 29 | # Typically, if each image is S embeddings, we expect the total # of image tokens 30 | # in this sample to be a multiple of S => each group of S tokens = 1 image 31 | if num_image_tokens % S != 0: 32 | raise ValueError( 33 | f"Input has {num_image_tokens} image tokens, not a multiple 
of S={S}. " 34 | "Cannot map them to blocks of shape (S, D)." 35 | ) 36 | 37 | chunks = [image_positions[i : i + S] for i in range(0, num_image_tokens, S)] 38 | 39 | segments = [] 40 | text_start = 0 41 | 42 | # For each chunk (each chunk => 1 image) 43 | for chunk in chunks: 44 | cur_block = image_features[image_offset] 45 | image_offset += 1 46 | 47 | # We'll iterate over the S positions in ascending order 48 | for i_s, pos in enumerate(chunk): 49 | if pos > text_start: 50 | segments.append(cur_embeds[text_start:pos]) 51 | # Then add one row from cur_block => shape (1, D) 52 | row_of_block = cur_block[i_s : i_s + 1, :] 53 | segments.append(row_of_block) 54 | text_start = pos + 1 55 | 56 | # leftover text after the final token 57 | if text_start < T: 58 | segments.append(cur_embeds[text_start:]) 59 | 60 | # cat them into a single (T_b, D) tensor 61 | merged_sample = mx.concatenate(segments, axis=0) 62 | return mx.expand_dims(merged_sample, axis=0) 63 | -------------------------------------------------------------------------------- /mlx_vlm/sample_utils.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | 3 | 4 | def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array: 5 | """ 6 | Apply top-p (nucleus) sampling to logits. 7 | 8 | Args: 9 | logits: The logits from the model's output. 10 | top_p: The cumulative probability threshold for top-p filtering. 11 | temperature: Temperature parameter for softmax distribution reshaping. 12 | Returns: 13 | token selected based on the top-p criterion. 14 | """ 15 | if ( 16 | logits.dtype == mx.bfloat16 17 | ): # workaround for unable to load kernel contiguous_scan_inclusive_sum_bfloat16_bfloat16 18 | logits = logits.astype(mx.float32) 19 | 20 | # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460 21 | probs = mx.softmax(logits / temperature, axis=-1) 22 | 23 | # sort probs in ascending order 24 | sorted_indices = mx.argsort(probs, axis=-1) 25 | sorted_probs = probs[..., sorted_indices.squeeze(0)] 26 | 27 | cumulative_probs = mx.cumsum(sorted_probs, axis=-1) 28 | 29 | # select tokens with cumulative probs below threshold 30 | top_probs = mx.where( 31 | cumulative_probs > 1 - top_p, 32 | sorted_probs, 33 | mx.zeros_like(sorted_probs), 34 | ) 35 | 36 | sorted_token = mx.random.categorical(mx.log(top_probs)) 37 | token = sorted_indices.squeeze(0)[sorted_token] 38 | 39 | return token 40 | -------------------------------------------------------------------------------- /mlx_vlm/smolvlm_video_generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | 4 | import mlx.core as mx 5 | 6 | from .utils import generate, load 7 | 8 | # This is a proof-of-concept script for video generation with SmolVLM2. 
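# Illustrative usage sketch (not part of the original script): one way the CLI
# defined in main() below might be invoked. The video path is only an example
# (the repository ships a sample clip under examples/videos/), and the flags
# mirror the argparse options declared in main().
#
#   python -m mlx_vlm.smolvlm_video_generate \
#       --video examples/videos/fastmlx_local_ai_hub.mp4 \
#       --system "Focus on the main activity in the clip." \
#       --prompt "Describe this video." \
#       --max-tokens 200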
9 | 10 | logger = logging.getLogger(__name__) 11 | logger.setLevel(logging.INFO) 12 | logger.addHandler(logging.StreamHandler()) 13 | 14 | 15 | def main(): 16 | parser = argparse.ArgumentParser(description="Video Description CLI") 17 | parser.add_argument( 18 | "--video", type=str, required=True, help="Path to the video file" 19 | ) 20 | parser.add_argument( 21 | "--max-frames", type=int, default=None, help="Maximum number of frames" 22 | ) 23 | parser.add_argument( 24 | "--prompt", default="Describe this video.", help="Text prompt for the model" 25 | ) 26 | parser.add_argument("--system", type=str, required=False, help="System prompt") 27 | parser.add_argument( 28 | "--temp", type=float, default=0.7, help="Temperature for generation" 29 | ) 30 | parser.add_argument( 31 | "--max-tokens", 32 | type=int, 33 | default=100, 34 | help="Maximum number of tokens to generate", 35 | ) 36 | parser.add_argument( 37 | "--model", 38 | default="HuggingFaceTB/SmolVLM2-500M-Video-Instruct-mlx", 39 | help="Select the model to use", 40 | ) 41 | parser.add_argument("--verbose", action="store_false", help="Disable verbose output (on by default)") 42 | 43 | args = parser.parse_args() 44 | 45 | print(f"\033[32mLoading model:\033[0m {args.model}") 46 | model, processor = load(args.model) 47 | 48 | messages = [ 49 | { 50 | "role": "user", 51 | "content": [ 52 | { 53 | "type": "video", 54 | "path": args.video, 55 | }, 56 | {"type": "text", "text": args.prompt}, 57 | ], 58 | } 59 | ] 60 | if args.system: 61 | messages.insert( 62 | 0, 63 | { 64 | "role": "system", 65 | "content": [ 66 | {"type": "text", "text": args.system}, 67 | ], 68 | }, 69 | ) 70 | 71 | inputs = processor.apply_chat_template( 72 | messages, 73 | tokenize=True, 74 | add_generation_prompt=True, 75 | return_dict=True, 76 | return_tensors="np", 77 | ) 78 | 79 | input_ids = mx.array(inputs["input_ids"]) 80 | pixel_values = mx.array(inputs["pixel_values"][0]) 81 | pixel_values = mx.expand_dims(pixel_values, 0) 82 | mask = mx.array(inputs["attention_mask"]) 83 | pixel_mask = mx.array(inputs["pixel_attention_mask"]) 84 | 85 | logger.info("\033[32mGenerating response...\033[0m") 86 | 87 | kwargs = {} 88 | kwargs["input_ids"] = input_ids 89 | kwargs["pixel_values"] = pixel_values 90 | kwargs["mask"] = mask 91 | kwargs["pixel_mask"] = pixel_mask 92 | kwargs["temp"] = args.temp 93 | kwargs["max_tokens"] = args.max_tokens 94 | 95 | response = generate( 96 | model, 97 | processor, 98 | prompt="", 99 | verbose=args.verbose, 100 | **kwargs, 101 | ) 102 | 103 | if not args.verbose: 104 | print(response) 105 | 106 | 107 | if __name__ == "__main__": 108 | main() 109 | -------------------------------------------------------------------------------- /mlx_vlm/tests/test_trainer.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import MagicMock, Mock, patch 3 | 4 | import mlx.core as mx 5 | import mlx.nn as nn 6 | 7 | from mlx_vlm.trainer.trainer import Dataset, Trainer 8 | 9 | 10 | class TestDataset(unittest.TestCase): 11 | def setUp(self): 12 | self.mock_hf_dataset = MagicMock() 13 | self.mock_config = {"model_type": "test_model", "image_token_index": 1} 14 | self.mock_processor = MagicMock() 15 | self.mock_image_processor = MagicMock() 16 | 17 | def test_dataset_initialization(self): 18 | dataset = Dataset( 19 | self.mock_hf_dataset, 20 | self.mock_config, 21 | self.mock_processor, 22 | self.mock_image_processor, 23 | take=10, 24 | split="train", 25 | ) 26 | 27 | self.assertEqual(len(dataset), 
len(self.mock_hf_dataset["train"].take(10))) 28 | self.assertEqual(dataset.config, self.mock_config) 29 | self.assertEqual(dataset.processor, self.mock_processor) 30 | self.assertEqual(dataset.image_processor, self.mock_image_processor) 31 | 32 | @patch("mlx_vlm.trainer.trainer.get_prompt") 33 | @patch("mlx_vlm.utils.prepare_inputs") 34 | def test_dataset_getitem(self, mock_prepare_inputs, mock_get_prompt): 35 | dataset = Dataset( 36 | self.mock_hf_dataset, 37 | self.mock_config, 38 | self.mock_processor, 39 | self.mock_image_processor, 40 | ) 41 | 42 | mock_item = { 43 | "images": ["image1.jpg"], 44 | "messages": [{"role": "user", "content": "Hello"}], 45 | } 46 | self.mock_hf_dataset.__getitem__.return_value = mock_item 47 | 48 | mock_get_prompt.return_value = "Mocked prompt" 49 | 50 | mock_prepare_inputs.return_value = { 51 | "input_ids": mx.array([1, 2, 3]), # input_ids 52 | "pixel_values": mx.array( 53 | [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]] 54 | ), # pixel_values 55 | "attention_mask": mx.array([1, 1, 1]), # mask 56 | "image_grid_thw": (1, 1, 1), # image_grid_thw 57 | "image_sizes": [224, 224], # image_sizes 58 | } 59 | 60 | result = dataset[0] 61 | 62 | mock_prepare_inputs.assert_called_once() 63 | self.assertIn("pixel_values", result) 64 | self.assertIn("input_ids", result) 65 | self.assertIn("attention_mask", result) 66 | self.assertIn("image_grid_thw", result) 67 | self.assertIn("image_sizes", result) 68 | 69 | # Check if the returned values match the mocked input 70 | self.assertTrue(mx.array_equal(result["input_ids"], mx.array([1, 2, 3]))) 71 | self.assertTrue( 72 | mx.array_equal( 73 | result["pixel_values"], 74 | mx.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]), 75 | ) 76 | ) 77 | self.assertTrue(mx.array_equal(result["attention_mask"], mx.array([1, 1, 1]))) 78 | self.assertEqual(result["image_grid_thw"], (1, 1, 1)) 79 | self.assertEqual(result["image_sizes"], [224, 224]) 80 | 81 | 82 | class TestTrainer(unittest.TestCase): 83 | def setUp(self): 84 | self.mock_model = MagicMock(spec=nn.Module) 85 | self.mock_optimizer = MagicMock() 86 | self.trainer = Trainer(self.mock_model, self.mock_optimizer) 87 | 88 | def test_trainer_initialization(self): 89 | self.assertEqual(self.trainer.model, self.mock_model) 90 | self.assertEqual(self.trainer.optimizer, self.mock_optimizer) 91 | self.assertFalse(self.trainer.train_on_completions) 92 | self.assertEqual(self.trainer.assistant_id, 77091) 93 | 94 | def test_loss_fn(self): 95 | batch = { 96 | "pixel_values": mx.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]), 97 | "input_ids": mx.array([[1, 2, 3], [4, 5, 6]]), 98 | "attention_mask": mx.array([[1, 1, 1], [1, 1, 0]]), 99 | "image_grid_thw": (1, 1, 1), 100 | "image_sizes": [224, 224], 101 | "aspect_ratio_ids": mx.array([[1, 2], [3, 4]]), 102 | "aspect_ratio_mask": mx.array([[1, 1], [1, 0]]), 103 | "cross_attention_mask": mx.array([[1, 1], [1, 0]]), 104 | } 105 | 106 | mock_logits = mx.array([[[0.1, 0.2, 0.3]], [[0.4, 0.5, 0.6]]]) 107 | # Create a mock LanguageModelOutput with the logits 108 | mock_output = Mock() 109 | mock_output.logits = mock_logits 110 | self.mock_model.return_value = mock_output 111 | 112 | loss = self.trainer.loss_fn(self.mock_model, batch) 113 | 114 | self.assertIsInstance(loss, mx.array) 115 | self.assertEqual(loss.shape, ()) # Scalar value 116 | 117 | @patch.object(Trainer, "loss_fn") 118 | @patch("mlx.nn.value_and_grad") 119 | def test_train_step(self, mock_value_and_grad, mock_loss_fn): 120 | mock_batch = MagicMock() 121 | mock_loss = 
mx.array(0.5) 122 | mock_grads = {"param1": mx.array([0.1, 0.2]), "param2": mx.array([0.3, 0.4])} 123 | 124 | mock_value_and_grad.return_value = lambda *args, **kwargs: ( 125 | mock_loss, 126 | mock_grads, 127 | ) 128 | 129 | loss = self.trainer.train_step(mock_batch) 130 | 131 | self.mock_optimizer.update.assert_called_once_with(self.mock_model, mock_grads) 132 | self.assertEqual(loss, mock_loss) 133 | 134 | 135 | if __name__ == "__main__": 136 | unittest.main() 137 | -------------------------------------------------------------------------------- /mlx_vlm/tests/test_trainer_utils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import MagicMock, patch 3 | 4 | import mlx.nn as nn 5 | 6 | from mlx_vlm.trainer.utils import ( 7 | find_all_linear_names, 8 | get_module_by_name, 9 | get_peft_model, 10 | set_module_by_name, 11 | ) 12 | 13 | 14 | class TestTrainerUtils(unittest.TestCase): 15 | 16 | def test_get_module_by_name(self): 17 | model = MagicMock() 18 | model.layer1.layer2.layer3 = "test_module" 19 | 20 | result = get_module_by_name(model, "layer1.layer2.layer3") 21 | self.assertEqual(result, "test_module") 22 | 23 | def test_set_module_by_name(self): 24 | model = MagicMock() 25 | new_module = MagicMock() 26 | 27 | set_module_by_name(model, "layer1.layer2.layer3", new_module) 28 | self.assertEqual(model.layer1.layer2.layer3, new_module) 29 | 30 | @patch("mlx_vlm.trainer.utils.freeze_model") 31 | @patch("mlx_vlm.trainer.utils.print_trainable_parameters") 32 | def test_get_peft_model(self, mock_print, mock_freeze): 33 | model = MagicMock() 34 | model.language_model.named_modules.return_value = [ 35 | ("layer1", nn.Linear(256, 512)), 36 | ("layer2", nn.QuantizedLinear(256, 512, 8)), 37 | ] 38 | 39 | result = get_peft_model(model, ["layer1", "layer2"]) 40 | 41 | self.assertTrue(mock_freeze.called) 42 | self.assertTrue(mock_print.called) 43 | self.assertTrue(hasattr(model.config, "lora")) 44 | 45 | def test_find_all_linear_names(self): 46 | model = MagicMock() 47 | model.named_modules.return_value = [ 48 | ("layer1", nn.Linear(256, 512)), 49 | ("layer2", nn.QuantizedLinear(256, 512, 8)), 50 | ("mm_projector", nn.Linear(256, 512)), 51 | ("lm_head", nn.Linear(256, 512)), 52 | ] 53 | 54 | result = find_all_linear_names(model) 55 | self.assertEqual(set(result), {"layer1", "layer2"}) 56 | 57 | 58 | if __name__ == "__main__": 59 | unittest.main() 60 | -------------------------------------------------------------------------------- /mlx_vlm/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .lora import LoRaLayer, replace_lora_with_linear 2 | from .trainer import Dataset, Trainer, save_adapter 3 | from .utils import ( 4 | apply_lora_layers, 5 | count_parameters, 6 | find_all_linear_names, 7 | get_peft_model, 8 | print_trainable_parameters, 9 | ) 10 | -------------------------------------------------------------------------------- /mlx_vlm/trainer/lora.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Union 3 | 4 | import mlx.core as mx 5 | import mlx.nn as nn 6 | 7 | 8 | class LoRaLayer(nn.Module): 9 | def __init__( 10 | self, 11 | linear: Union[nn.Linear, nn.QuantizedLinear], 12 | rank: int, 13 | alpha: float = 0.1, 14 | dropout: float = 0.0, 15 | ): 16 | super().__init__() 17 | 18 | self.original_layer = linear 19 | 20 | self.dropout = nn.Dropout(p=dropout) 21 | 22 | output_dims, 
input_dims = linear.weight.shape 23 | if isinstance(linear, nn.QuantizedLinear): 24 | input_dims *= 32 // linear.bits 25 | 26 | std_dev = 1 / math.sqrt(rank) 27 | 28 | self.A = mx.random.uniform( 29 | low=-std_dev, 30 | high=std_dev, 31 | shape=(input_dims, rank), 32 | ) 33 | self.B = mx.zeros((rank, output_dims)) 34 | self.alpha = alpha 35 | 36 | def __call__(self, x): 37 | y = self.original_layer(x) 38 | lora_update = (self.dropout(x) @ self.A) @ self.B 39 | return y + (self.alpha * lora_update).astype(x.dtype) 40 | 41 | 42 | def replace_lora_with_linear(model): 43 | for i, layer in enumerate(model.layers): 44 | if isinstance(layer, LoRaLayer): 45 | original = layer.original_layer 46 | is_quantized = isinstance(original, nn.QuantizedLinear) 47 | 48 | # Recover the full-precision weight before merging the LoRA update 49 | if is_quantized: 50 | weight = mx.dequantize( 51 | original.weight, 52 | original.scales, 53 | original.biases, 54 | group_size=original.group_size, 55 | bits=original.bits, 56 | ) 57 | else: 58 | weight = original.weight 59 | 60 | # Compute the final merged weight; A @ B is (input_dims, output_dims), 61 | # while nn.Linear stores its weight as (output_dims, input_dims) 62 | lora_update = layer.alpha * (layer.A @ layer.B) 63 | updated_weight = weight + lora_update.T 64 | use_bias = "bias" in original 65 | 66 | # Create a new Linear layer with the updated parameters 67 | new_linear_layer = nn.Linear( 68 | updated_weight.shape[1], updated_weight.shape[0], bias=use_bias 69 | ) 70 | new_linear_layer.weight = updated_weight 71 | if use_bias: 72 | new_linear_layer.bias = original.bias 73 | 74 | if is_quantized: 75 | # Re-quantize with the original layer's settings 76 | new_linear_layer = nn.QuantizedLinear.from_linear( 77 | new_linear_layer, 78 | group_size=original.group_size, 79 | bits=original.bits, 80 | ) 81 | 82 | # Replace the LoRaLayer with the new Linear layer in the model 83 | model.layers[i] = new_linear_layer 84 | -------------------------------------------------------------------------------- /mlx_vlm/trainer/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | import mlx.nn as nn 5 | from mlx.utils import tree_flatten 6 | 7 | from .lora import LoRaLayer 8 | 9 | 10 | def get_module_by_name(model, name): 11 | parts = name.split(".") 12 | module = model 13 | for part in parts: 14 | if part.isdigit(): 15 | module = module[int(part)] 16 | else: 17 | module = getattr(module, part) 18 | return module 19 | 20 | 21 | def set_module_by_name(model, name, new_module): 22 | parts = name.split(".") 23 | module = model 24 | for part in parts[:-1]: 25 | if part.isdigit(): 26 | module = module[int(part)] 27 | else: 28 | module = getattr(module, part) 29 | if parts[-1].isdigit(): 30 | module[int(parts[-1])] = new_module 31 | else: 32 | setattr(module, parts[-1], new_module) 33 | 34 | 35 | def get_peft_model( 36 | model, linear_layers, rank=10, alpha=0.1, dropout=0.1, freeze=True, verbose=True 37 | ): 38 | if freeze: 39 | freeze_model(model) 40 | 41 | for name, module in model.language_model.named_modules(): 42 | if isinstance(module, nn.Linear) or isinstance(module, nn.QuantizedLinear): 43 | if name.split(".")[-1] in linear_layers: 44 | lora_layer = LoRaLayer(module, rank, alpha, dropout) 45 | set_module_by_name(model.language_model, name, lora_layer) 46 | 47 | model.config.lora = {} 48 | model.config.lora["rank"] = rank 49 | model.config.lora["alpha"] = alpha 50 | model.config.lora["dropout"] = dropout 51 | 52 | if verbose: 53 | print_trainable_parameters(model.language_model) 54 | 55 | return model 56 | 57 | 58 | def freeze_model(model): 59 | for name, module in model.named_modules(): 60 | name = name.split(".")[0] 61 | if name in [ 62 | "language_model", 63 | "vision_model", 64 | "vision_tower", 65 | "aligner", 66 | "connector", 67 | "multi_modal_projector", 68 | "mm_projector", 69 | ]: 70 | 
model[f"{name}"].freeze() 71 | 72 | 73 | def find_all_linear_names(model): 74 | cls = nn.Linear 75 | quantized_cls = nn.QuantizedLinear 76 | lora_module_names = set() 77 | multimodal_keywords = [ 78 | "mm_projector", 79 | "vision_tower", 80 | "vision_resampler", 81 | "aligner", 82 | ] 83 | for name, module in model.named_modules(): 84 | if any(mm_keyword in name for mm_keyword in multimodal_keywords): 85 | continue 86 | if isinstance(module, cls) or isinstance(module, quantized_cls): 87 | names = name.split(".") 88 | lora_module_names.add(names[0] if len(names) == 1 else names[-1]) 89 | 90 | if "lm_head" in lora_module_names: # needed for 16-bit 91 | lora_module_names.remove("lm_head") 92 | return list(lora_module_names) 93 | 94 | 95 | def count_parameters(model): 96 | def nparams(m): 97 | if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)): 98 | return m.weight.size * (32 // m.bits) 99 | return sum(v.size for _, v in tree_flatten(m.parameters())) 100 | 101 | leaf_modules = tree_flatten( 102 | model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module) 103 | ) 104 | total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6 105 | 106 | return total_p 107 | 108 | 109 | def print_trainable_parameters(model): 110 | def nparams(m): 111 | if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)): 112 | return m.weight.size * (32 // m.bits) 113 | return sum(v.size for _, v in tree_flatten(m.parameters())) 114 | 115 | leaf_modules = tree_flatten( 116 | model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module) 117 | ) 118 | total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6 119 | trainable_p = ( 120 | sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6 121 | ) 122 | 123 | print( 124 | f"#trainable params: {trainable_p} M || all params: {total_p} M || trainable%: {(trainable_p * 100 / total_p):.3f}%" 125 | ) 126 | 127 | 128 | def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module: 129 | """ 130 | Apply LoRA layers to the model. 131 | 132 | Args: 133 | model (nn.Module): The neural network model. 134 | adapter_path (str): Path to the adapter configuration file. 135 | 136 | Returns: 137 | nn.Module: The updated model with LoRA layers applied. 
138 | """ 139 | adapter_path = Path(adapter_path) 140 | 141 | if not adapter_path.exists(): 142 | raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}") 143 | 144 | # Check if the adapter has lora params in the config (adapter_config.json) 145 | with open(adapter_path / "adapter_config.json", "r") as f: 146 | config = json.load(f) 147 | if "rank" not in config: 148 | raise ValueError("The adapter does not have lora params in the config") 149 | 150 | # TODO: add lora params to the config and load them here 151 | list_of_modules = find_all_linear_names(model.language_model.model) 152 | if config is not None: 153 | model = get_peft_model(model, list_of_modules, **config) 154 | else: 155 | model = get_peft_model(model, list_of_modules) 156 | 157 | # TODO: Use custom adapter name 158 | model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False) 159 | 160 | return model 161 | -------------------------------------------------------------------------------- /mlx_vlm/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.26" 2 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | asyncio_mode = auto 3 | asyncio_default_fixture_loop_scope = function -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | mlx>=0.25.0 2 | datasets>=2.19.1 3 | tqdm>=4.66.2 4 | numpy>=1.23.4 5 | transformers>=4.51.3 6 | gradio>=5.19.0 7 | Pillow>=10.3.0 8 | requests>=2.31.0 9 | opencv-python==4.10.0.84 10 | mlx-lm>=0.23.0 11 | fastapi>=0.95.1 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | from setuptools import find_packages, setup 5 | 6 | # Get the project root directory 7 | root_dir = Path(__file__).parent 8 | 9 | # Add the package directory to the Python path 10 | package_dir = root_dir / "mlx_vlm" 11 | sys.path.append(str(package_dir)) 12 | 13 | # Read the requirements from the requirements.txt file 14 | requirements_path = root_dir / "requirements.txt" 15 | with open(requirements_path) as fid: 16 | requirements = [l.strip() for l in fid.readlines()] 17 | 18 | # Import the version from the package 19 | from version import __version__ 20 | 21 | # Setup configuration 22 | setup( 23 | name="mlx-vlm", 24 | version=__version__, 25 | description="Vision LLMs on Apple silicon with MLX and the Hugging Face Hub", 26 | long_description=open(root_dir / "README.md", encoding="utf-8").read(), 27 | long_description_content_type="text/markdown", 28 | author_email="prince.gdt@gmail.com", 29 | author="Prince Canuma", 30 | url="https://github.com/Blaizzy/mlx-vlm", 31 | license="MIT", 32 | install_requires=requirements, 33 | packages=find_packages(where=root_dir), 34 | python_requires=">=3.8", 35 | entry_points={ 36 | "console_scripts": [ 37 | "mlx_vlm.convert = mlx_vlm.convert:main", 38 | "mlx_vlm.generate = mlx_vlm.generate:main", 39 | ] 40 | }, 41 | ) 42 | -------------------------------------------------------------------------------- /update_changelog.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import 
requests 5 | 6 | 7 | def fetch_releases(repo: str, token: str | None = None): 8 | url = f"https://api.github.com/repos/{repo}/releases" 9 | headers = {"Accept": "application/vnd.github+json"} 10 | if token: 11 | headers["Authorization"] = f"Bearer {token}" 12 | response = requests.get(url, headers=headers, timeout=30) 13 | response.raise_for_status() 14 | return response.json() 15 | 16 | 17 | def write_changelog(releases): 18 | path = Path("docs/changelog.md") 19 | lines = [ 20 | "# Changelog", 21 | "", 22 | "_This file is automatically generated from GitHub releases._", 23 | "", 24 | ] 25 | for rel in releases: 26 | tag = rel.get("tag_name") or "" 27 | name = rel.get("name") or tag 28 | date = (rel.get("published_at") or "")[:10] 29 | body = rel.get("body") or "" 30 | header = f"## {name} - {date}".strip() 31 | lines.append(header) 32 | if body: 33 | lines.append("") 34 | lines.extend(body.splitlines()) 35 | lines.append("") 36 | path.write_text("\n".join(lines)) 37 | 38 | 39 | def main() -> None: 40 | repo = os.getenv("GITHUB_REPOSITORY", "Blaizzy/mlx-vlm") 41 | token = os.getenv("GITHUB_TOKEN") 42 | releases = fetch_releases(repo, token) 43 | write_changelog(releases) 44 | 45 | 46 | if __name__ == "__main__": 47 | main() 48 | --------------------------------------------------------------------------------
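The snippet below is an illustrative sketch, not a file from the repository: it shows one way update_changelog.py (above) could be exercised locally. It reuses the module's own fetch_releases and write_changelog helpers, reads an optional GITHUB_TOKEN from the environment to raise the GitHub API rate limit, and assumes it is run from the repository root with network access.

import os

from update_changelog import fetch_releases, write_changelog

# Fetch published releases for the default repository used in main().
releases = fetch_releases("Blaizzy/mlx-vlm", os.getenv("GITHUB_TOKEN"))

# Regenerate docs/changelog.md from the release notes.
write_changelog(releases)
print(f"Wrote {len(releases)} releases to docs/changelog.md")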