├── .github ├── FUNDING.yml └── workflows │ ├── deploy-docs.yml │ ├── python-publish.yml │ ├── tests.yml │ └── update-changelog.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── computer_use ├── README.md ├── audio │ ├── ask.wav │ ├── ask_2.wav │ ├── ask_voice.wav │ ├── ask_voice_2.wav │ ├── ok.wav │ ├── output.wav │ └── task_completed.wav ├── autonomous_gui_agent.py ├── autonomous_gui_agent_voice.py ├── gui_agent.py ├── gui_agent_voice.py ├── navigation_history.csv ├── requirements.txt ├── screenshots │ ├── screenshot_20241210-191351.png │ ├── screenshot_20241210-191519.png │ ├── screenshot_20241210-192437.png │ ├── screenshot_20241210-193535.png │ ├── screenshot_20241210-193911.png │ ├── screenshot_20241210-194314.png │ ├── screenshot_20241210-195055.png │ ├── screenshot_20241210-195145.png │ ├── screenshot_20241210-230354.png │ ├── screenshot_20241210-230553.png │ ├── screenshot_20241210-230651.png │ ├── screenshot_20241210-230730.png │ ├── screenshot_20241210-230923.png │ ├── screenshot_20241210-230958.png │ ├── screenshot_20241210-231015.png │ ├── screenshot_20241210-231058.png │ ├── screenshot_20241210-231133.png │ ├── screenshot_20241210-231212.png │ ├── screenshot_20241210-231321.png │ ├── screenshot_20241210-231404.png │ ├── screenshot_20241210-231511.png │ ├── screenshot_20241210-231527.png │ ├── screenshot_20241210-231552.png │ ├── screenshot_20241210-231719.png │ ├── screenshot_20241210-231751.png │ ├── screenshot_20241211-101202.png │ └── screenshot_20241211-101219.png └── utils.py ├── docs ├── changelog.md ├── cli_reference.md ├── community_projects.md ├── contributing.md ├── examples.md ├── index.md ├── installation.md ├── overrides │ └── main.html ├── report_issues.md └── usage.md ├── examples ├── images │ ├── cats.jpg │ ├── desktop_setup.png │ ├── graph.png │ ├── handwritten_hard.webp │ ├── latex.png │ ├── menu.webp │ ├── paper.png │ └── renewables_california.png ├── multi_image_generation.ipynb ├── object_detection.ipynb ├── object_pointing.ipynb ├── ocr_with_region.ipynb ├── text_extraction.ipynb ├── utils.py ├── video_understanding.ipynb └── videos │ └── fastmlx_local_ai_hub.mp4 ├── mkdocs.yml ├── mlx_vlm ├── LORA.MD ├── __init__.py ├── chat.py ├── chat_ui.py ├── convert.py ├── generate.py ├── lora.py ├── models │ ├── __init__.py │ ├── aya_vision │ │ ├── __init__.py │ │ ├── aya_vision.py │ │ ├── interpolate.py │ │ ├── language.py │ │ └── vision.py │ ├── base.py │ ├── cache.py │ ├── deepseek_vl_v2 │ │ ├── __init__.py │ │ ├── conversation.py │ │ ├── deepseek_vl_v2.py │ │ ├── language.py │ │ ├── processing_deepsek_vl_v2.py │ │ └── vision.py │ ├── florence2 │ │ ├── __init__.py │ │ ├── florence2.py │ │ ├── language.py │ │ └── vision.py │ ├── gemma3 │ │ ├── __init__.py │ │ ├── gemma3.py │ │ ├── language.py │ │ └── vision.py │ ├── idefics2 │ │ ├── __init__.py │ │ ├── idefics2.py │ │ ├── language.py │ │ └── vision.py │ ├── idefics3 │ │ ├── __init__.py │ │ ├── idefics3.py │ │ ├── language.py │ │ └── vision.py │ ├── internvl_chat │ │ ├── __init__.py │ │ ├── internvl_chat.py │ │ ├── language.py │ │ ├── processor.py │ │ └── vision.py │ ├── kimi_vl │ │ ├── __init__.py │ │ ├── kimi_vl.py │ │ ├── language.py │ │ └── vision.py │ ├── llama4 │ │ ├── __init__.py │ │ ├── language.py │ │ ├── llama4.py │ │ └── vision.py │ ├── llava │ │ ├── __init__.py │ │ ├── language.py │ │ ├── llava.py │ │ └── vision.py │ ├── llava_bunny │ │ ├── __init__.py │ │ ├── language.py │ │ ├── llava_bunny.py │ │ └── vision.py │ ├── llava_next │ │ ├── 
__init__.py │ │ ├── language.py │ │ ├── llava_next.py │ │ └── vision.py │ ├── mistral3 │ │ ├── __init__.py │ │ └── mistral3.py │ ├── mllama │ │ ├── __init__.py │ │ ├── language.py │ │ ├── mllama.py │ │ └── vision.py │ ├── molmo │ │ ├── __init__.py │ │ ├── language.py │ │ ├── molmo.py │ │ └── vision.py │ ├── multi_modality │ │ ├── __init__.py │ │ ├── language.py │ │ ├── multi_modality.py │ │ ├── sam.py │ │ └── vision.py │ ├── paligemma │ │ ├── __init__.py │ │ ├── language.py │ │ ├── paligemma.py │ │ └── vision.py │ ├── phi3_v │ │ ├── __init__.py │ │ ├── language.py │ │ ├── phi3_v.py │ │ ├── su_rope.py │ │ └── vision.py │ ├── pixtral │ │ ├── __init__.py │ │ ├── language.py │ │ ├── pixtral.py │ │ └── vision.py │ ├── qwen2_5_vl │ │ ├── __init__.py │ │ ├── language.py │ │ ├── qwen2_5_vl.py │ │ └── vision.py │ ├── qwen2_vl │ │ ├── __init__.py │ │ ├── config.py │ │ ├── language.py │ │ ├── qwen2_vl.py │ │ └── vision.py │ └── smolvlm │ │ ├── __init__.py │ │ └── smolvlm.py ├── prompt_utils.py ├── sample_utils.py ├── server.py ├── smolvlm_video_generate.py ├── tests │ ├── test_models.py │ ├── test_smoke.py │ ├── test_trainer.py │ ├── test_trainer_utils.py │ └── test_utils.py ├── tokenizer_utils.py ├── trainer │ ├── __init__.py │ ├── lora.py │ ├── trainer.py │ └── utils.py ├── utils.py ├── version.py └── video_generate.py ├── pytest.ini ├── requirements.txt ├── setup.py └── update_changelog.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: Blaizzy # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 12 | polar: # Replace with a single Polar username 13 | buy_me_a_coffee: # Replace with a single Buy Me a Coffee username 14 | thanks_dev: # Replace with a single thanks.dev username 15 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 16 | -------------------------------------------------------------------------------- /.github/workflows/deploy-docs.yml: -------------------------------------------------------------------------------- 1 | name: Deploy Documentation 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | deploy: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v2 16 | 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.x' 21 | 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install mkdocs mkdocs-material mkdocstrings[python] mkdocs-autorefs mkdocs-git-revision-date-localized-plugin mkdocs-jupyter 26 | 27 | - name: Build documentation 28 | run: mkdocs build 29 | 30 | - name: Deploy to GitHub Pages 31 | if: github.event_name == 'push' && github.ref == 'refs/heads/main' 32 | uses: peaceiris/actions-gh-pages@v3 33 | with: 34 | github_token: ${{ secrets.DOCS_TOKEN }} 35 | publish_dir: ./site 
36 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [published] 9 | 10 | permissions: 11 | contents: read 12 | 13 | jobs: 14 | deploy: 15 | 16 | runs-on: ubuntu-latest 17 | 18 | steps: 19 | - uses: actions/checkout@v3 20 | - name: Set up Python 21 | uses: actions/setup-python@v3 22 | with: 23 | python-version: '3.10' 24 | - name: Install dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | pip install build 28 | - name: Build package 29 | run: python -m build 30 | - name: Publish package 31 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 32 | with: 33 | user: __token__ 34 | password: ${{ secrets.PYPI_API_TOKEN }} 35 | packages_dir: dist 36 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Test PRs 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | test: 10 | runs-on: macos-14 11 | 12 | steps: 13 | - name: Checkout code 14 | uses: actions/checkout@v4 15 | 16 | - name: Set up Python 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: '3.10' 20 | 21 | - name: Install MLX 22 | run: | 23 | pip install mlx>=0.15 24 | 25 | - name: Install pre-commit 26 | run: | 27 | python -m pip install pre-commit 28 | pre-commit run --all 29 | if ! git diff --quiet; then 30 | echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change' 31 | exit 1 32 | fi 33 | 34 | - name: Install package and dependencies 35 | run: | 36 | python -m pip install pytest 37 | python -m pip install -e . 38 | 39 | - name: Run Python tests 40 | run: | 41 | cd mlx_vlm/ 42 | pytest -s ./tests --ignore=tests/test_smoke.py 43 | -------------------------------------------------------------------------------- /.github/workflows/update-changelog.yml: -------------------------------------------------------------------------------- 1 | name: Update Changelog 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | update-changelog: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - name: Set up Python 13 | uses: actions/setup-python@v2 14 | with: 15 | python-version: '3.x' 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | pip install requests python-dotenv 20 | - name: Update Changelog 21 | env: 22 | GITHUB_TOKEN: ${{ secrets.DOCS_TOKEN }} 23 | run: python update_changelog.py 24 | - name: Commit changes 25 | run: | 26 | set -x 27 | git config --local user.email "action@github.com" 28 | git config --local user.name "GitHub Action" 29 | 30 | echo "Fetching latest changes..." 31 | git fetch origin 32 | 33 | echo "Checking out and updating branch..." 
34 | if [ "${{ github.event_name }}" = "pull_request" ]; then 35 | git checkout -B "${{ github.head_ref }}" "origin/${{ github.head_ref }}" 36 | git pull origin "${{ github.head_ref }}" 37 | else 38 | git checkout -B "${{ github.ref_name }}" "origin/${{ github.ref_name }}" 39 | git pull origin "${{ github.ref_name }}" 40 | fi 41 | 42 | echo "Running update script..." 43 | python update_changelog.py 44 | 45 | echo "Checking for changes..." 46 | git add docs/changelog.md 47 | git pull 48 | if git diff --staged --quiet; then 49 | echo "No changes to commit" 50 | else 51 | git commit -m "Update changelog for latest release" 52 | git push origin HEAD:"${{ github.head_ref || github.ref_name }}" || echo "Failed to push changes" 53 | fi 54 | 55 | git status 56 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | test.ipynb 2 | **.pyc 3 | 4 | Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | 8 | # Distribution / packaging 9 | bin/ 10 | build/ 11 | develop-eggs/ 12 | dist/ 13 | eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | *.egg-info/ 20 | .installed.cfg 21 | *.egg 22 | .DS_Store 23 | *.log 24 | 25 | # virtual environment 26 | .myenv 27 | .venv -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black-pre-commit-mirror 3 | rev: 24.2.0 4 | hooks: 5 | - id: black 6 | - repo: https://github.com/pycqa/isort 7 | rev: 5.13.2 8 | hooks: 9 | - id: isort 10 | args: 11 | - --profile=black 12 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to MLX VLM 2 | 3 | Below are some tips to port Vision LLMs available on Hugging Face to MLX. 4 | 5 | Next, from this directory, do an editable install: 6 | 7 | ```shell 8 | pip install -e . 9 | ``` 10 | 11 | Then check if the model has weights in the 12 | [safetensors](https://huggingface.co/docs/safetensors/index) format. If not 13 | [follow instructions](https://huggingface.co/spaces/safetensors/convert) to 14 | convert it. 15 | 16 | After that, add the model file to the 17 | [`mlx_vlm/models`](https://github.com/Blaizzy/mlx-vlm/tree/main/src/models) 18 | directory. You can see other examples there. We recommend starting from a model 19 | that is similar to the model you are porting. 20 | 21 | Make sure the name of the new model file is the same as the `model_type` in the 22 | `config.json`, for example 23 | [llava](https://huggingface.co/llava-hf/llava-1.5-7b-hf/blob/main/config.json#L7). 24 | 25 | To determine the model layer names, we suggest either: 26 | 27 | - Refer to the Transformers implementation if you are familiar with the 28 | codebase. 29 | - Load the model weights and check the weight names which will tell you about 30 | the model structure. 31 | - Look at the names of the weights by inspecting `model.safetensors.index.json` 32 | in the Hugging Face repo. 33 | 34 | Additionally, add a test for the new modle type to the [model 35 | tests](https://github.com/Blaizzy/mlx-vlm/tree/main/src/tests/test_models.py). 
36 | 37 | From the `mlx_vlm/` directory, you can run the tests with: 38 | 39 | ```shell 40 | python -m unittest discover tests/ 41 | ``` 42 | 43 | ## Pull Requests 44 | 45 | 1. Fork and submit pull requests to the repo. 46 | 2. If you've added code that should be tested, add tests. 47 | 3. Every PR should have passing tests and at least one review. 48 | 4. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`. 49 | This should install hooks for running `black` and `isort` to ensure 50 | consistent style for the Python code. 51 | 52 | You can also run the formatters manually as follows on individual files: 53 | 54 | ```bash 55 | isort file.py 56 | ``` 57 | 58 | ```bash 59 | black file.py 60 | ``` 61 | 62 | or, 63 | 64 | ```bash 65 | # single file 66 | pre-commit run --files file1.py 67 | 68 | # multiple files 69 | pre-commit run --files file1.py file2.py 70 | ``` 71 | 72 | or run `pre-commit run --all-files` to check all files in the repo. 73 | 74 | ## Issues 75 | 76 | We use GitHub issues to track public bugs. Please ensure your description is 77 | clear and has sufficient instructions to be able to reproduce the issue. 78 | 79 | ## License 80 | 81 | By contributing to mlx-vlm, you agree that your contributions will be licensed 82 | under the LICENSE file in the root directory of this source tree. 83 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright © 2023 Apple Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include ./requirements.txt 2 | recursive-include mlx_vlm/ *.py 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [](https://github.com/Blaizzy/mlx-vlm/actions/workflows/python-publish.yml) 2 | # MLX-VLM 3 | 4 | MLX-VLM is a package for inference and fine-tuning of Vision Language Models (VLMs) on your Mac using MLX.
5 | 6 | ## Table of Contents 7 | - [Installation](#installation) 8 | - [Usage](#usage) 9 | - [Command Line Interface (CLI)](#command-line-interface-cli) 10 | - [Chat UI with Gradio](#chat-ui-with-gradio) 11 | - [Python Script](#python-script) 12 | - [Multi-Image Chat Support](#multi-image-chat-support) 13 | - [Supported Models](#supported-models) 14 | - [Usage Examples](#usage-examples) 15 | - [Fine-tuning](#fine-tuning) 16 | 17 | ## Installation 18 | 19 | The easiest way to get started is to install the `mlx-vlm` package using pip: 20 | 21 | ```sh 22 | pip install mlx-vlm 23 | ``` 24 | 25 | ## Usage 26 | 27 | ### Command Line Interface (CLI) 28 | 29 | Generate output from a model using the CLI: 30 | 31 | ```sh 32 | python -m mlx_vlm.generate --model mlx-community/Qwen2-VL-2B-Instruct-4bit --max-tokens 100 --temperature 0.0 --image http://images.cocodataset.org/val2017/000000039769.jpg 33 | ``` 34 | 35 | ### Chat UI with Gradio 36 | 37 | Launch a chat interface using Gradio: 38 | 39 | ```sh 40 | python -m mlx_vlm.chat_ui --model mlx-community/Qwen2-VL-2B-Instruct-4bit 41 | ``` 42 | 43 | ### Python Script 44 | 45 | Here's an example of how to use MLX-VLM in a Python script: 46 | 47 | ```python 48 | import mlx.core as mx 49 | from mlx_vlm import load, generate 50 | from mlx_vlm.prompt_utils import apply_chat_template 51 | from mlx_vlm.utils import load_config 52 | 53 | # Load the model 54 | model_path = "mlx-community/Qwen2-VL-2B-Instruct-4bit" 55 | model, processor = load(model_path) 56 | config = load_config(model_path) 57 | 58 | # Prepare input 59 | image = ["http://images.cocodataset.org/val2017/000000039769.jpg"] 60 | # image = [Image.open("...")] can also be used with PIL.Image.Image objects 61 | prompt = "Describe this image." 62 | 63 | # Apply chat template 64 | formatted_prompt = apply_chat_template( 65 | processor, config, prompt, num_images=len(image) 66 | ) 67 | 68 | # Generate output 69 | output = generate(model, processor, formatted_prompt, image, verbose=False) 70 | print(output) 71 | ``` 72 | 73 | ### Server (FastAPI) 74 | To start the server 75 | ```sh 76 | python -m mlx_vlm.server 77 | ``` 78 | 79 | Models can be loaded or unloaded dynamically and they are cached (one at a time) when the server is running. 80 | 81 | Usage example: 82 | ```sh 83 | curl -X POST "http://localhost:8000/generate" \ 84 | -H "Content-Type: application/json" \ 85 | -d '{ 86 | "model": "mlx-community/Qwen2.5-VL-32B-Instruct-8bit", 87 | "image": ["/path/to/repo/examples/images/renewables_california.png"], 88 | "prompt": "This is today'\''s chart for energy demand in California. Can you provide an analysis of the chart and comment on the implications for renewable energy in California?", 89 | "system": "You are a helpful assistant.", 90 | "stream": true, 91 | "max_tokens": 1000 92 | }' 93 | ``` 94 | 95 | 96 | ## Multi-Image Chat Support 97 | 98 | MLX-VLM supports analyzing multiple images simultaneously with select models. This feature enables more complex visual reasoning tasks and comprehensive analysis across multiple images in a single conversation. 99 | 100 | ### Supported Models 101 | 102 | The following models support multi-image chat: 103 | 104 | 1. Idefics 2 105 | 2. LLaVA (Interleave) 106 | 3. Qwen2-VL 107 | 4. Phi3-Vision 108 | 5. 
Pixtral 109 | 110 | ### Usage Examples 111 | 112 | #### Python Script 113 | 114 | ```python 115 | from mlx_vlm import load, generate 116 | from mlx_vlm.prompt_utils import apply_chat_template 117 | from mlx_vlm.utils import load_config 118 | 119 | model_path = "mlx-community/Qwen2-VL-2B-Instruct-4bit" 120 | model, processor = load(model_path) 121 | config = load_config(model_path) 122 | 123 | images = ["path/to/image1.jpg", "path/to/image2.jpg"] 124 | prompt = "Compare these two images." 125 | 126 | formatted_prompt = apply_chat_template( 127 | processor, config, prompt, num_images=len(images) 128 | ) 129 | 130 | output = generate(model, processor, formatted_prompt, images, verbose=False) 131 | print(output) 132 | ``` 133 | 134 | #### Command Line 135 | 136 | ```sh 137 | python -m mlx_vlm.generate --model mlx-community/Qwen2-VL-2B-Instruct-4bit --max-tokens 100 --prompt "Compare these images" --image path/to/image1.jpg path/to/image2.jpg 138 | ``` 139 | 140 | ## Video Understanding 141 | 142 | MLX-VLM also supports video analysis such as captioning, summarization, and more, with select models. 143 | 144 | ### Supported Models 145 | 146 | The following models support video chat: 147 | 148 | 1. Qwen2-VL 149 | 2. Qwen2.5-VL 150 | 3. Idefics3 151 | 4. LLaVA 152 | 153 | With more coming soon. 154 | 155 | ### Usage Examples 156 | 157 | #### Command Line 158 | ```sh 159 | python -m mlx_vlm.video_generate --model mlx-community/Qwen2-VL-2B-Instruct-4bit --max-tokens 100 --prompt "Describe this video" --video path/to/video.mp4 --max-pixels 224 224 --fps 1.0 160 | ``` 161 | 162 | 163 | These examples demonstrate how to use multiple images with MLX-VLM for more complex visual reasoning tasks. 164 | 165 | # Fine-tuning 166 | 167 | MLX-VLM supports fine-tuning models with LoRA and QLoRA. 168 | 169 | ## LoRA & QLoRA 170 | 171 | To learn more about LoRA, please refer to the [LoRA.md](./mlx_vlm/LORA.MD) file. 172 | -------------------------------------------------------------------------------- /computer_use/README.md: -------------------------------------------------------------------------------- 1 | # Computer Use with MLX-VLM 2 | 3 |
15 | Automate your workflow with natural language commands and visual intelligence 16 |
17 | 18 | ## 🤖 Current Implementation Status 19 | The project now supports both Level 1 (GUI Agent) and Level 2 (Autonomous GUI Agent) capabilities: 20 | 21 | - **Level 1 (GUI Agent)**: Basic visual understanding and action capabilities 22 | - **Level 2 (Autonomous GUI Agent)**: Enhanced memory, planning, reasoning, and human-in-the-loop functionality 23 | 24 | *Community help is more than welcome!* We're looking for contributors to help us enhance these capabilities further. Join us in building the future of computer automation. 25 | 26 | 27 | ## 🔍 Overview 28 | 29 | Computer Use with MLX-VLM transforms how you interact with your Mac by combining the power of: 30 | 31 | - **MLX** - Apple's machine learning framework optimized for Apple Silicon 32 | - **Vision Language Models (VLMs)** - AI models that understand both visual and textual information 33 | - **Automation** - Seamless execution of tasks across your Mac's interface 34 | 35 | By processing screenshots and visual information from your screen, the system understands the current state of applications and executes appropriate actions to accomplish tasks you specify in natural language. 36 | 37 | ## ✨ Key Features 38 | 39 | - **Mac-Native Performance**: Optimized for Apple Silicon with MLX for efficient, local processing 40 | - **Visual Understanding**: Interprets screen content, UI elements, and application states 41 | - **Contextual Reasoning**: Makes intelligent decisions based on visual context 42 | - **Cross-Application Automation**: Works across multiple applications and system interfaces 43 | - **Natural Language Control**: Simple, human-like instructions to control your computer 44 | - **Privacy-Focused**: All processing happens locally on your device 45 | - **Customizable**: Adapt to your specific workflow and preferences 46 | - **Autonomous Operation**: Level 2 agent can plan and execute multi-step tasks with minimal supervision 47 | - **Voice Control**: Hands-free operation with voice commands using local speech recognition 48 | 49 | ## 🚀 Getting Started 50 | 51 | ### Prerequisites 52 | 53 | - **macOS** running on Apple Silicon (M series) 54 | - **Python 3.8+** 55 | - **pip** (Python package manager) 56 | 57 | ### Installation 58 | 1. **Install MLX-VLM package**: 59 | ```bash 60 | pip install mlx-vlm 61 | ``` 62 | 63 | 2. **Clone the repository**: 64 | ```bash 65 | git clone https://github.com/Blaizzy/mlx-vlm.git 66 | ``` 67 | 68 | 3. **Navigate to computer control directory**: 69 | ```bash 70 | cd computer_use 71 | ``` 72 | 73 | 4. 
**Install dependencies**: 74 | ```bash 75 | pip install -r requirements.txt 76 | ``` 77 | 78 | ## 💻 Usage 79 | 80 | ### Quick Start 81 | 82 | Launch the standard GUI agent with: 83 | 84 | ```bash 85 | python gui_agent.py 86 | ``` 87 | 88 | ### Autonomous GUI Agent 89 | 90 | For enhanced autonomous operation with planning capabilities: 91 | 92 | ```bash 93 | python autonomous_gui_agent.py 94 | ``` 95 | 96 | This launches the Level 2 autonomous agent that can: 97 | - Plan and execute multi-step tasks 98 | - Maintain context across actions 99 | - Make decisions based on visual feedback 100 | - Request human assistance when needed 101 | 102 | ### Voice Control Interface 103 | 104 | For hands-free operation, you can use the voice-enabled autonomous agent: 105 | 106 | ```bash 107 | python autonomous_gui_agent_voice.py 108 | ``` 109 | 110 | This launches a voice-controlled version that: 111 | - Listens for your voice commands using your Mac's microphone 112 | - Converts speech to text locally using [mlx-whisper](https://github.com/ml-explore/mlx-examples) 113 | - Processes your commands and executes them visually 114 | - Provides audio feedback on actions taken 115 | 116 | Voice commands work just like text commands, so you can simply speak any of the example commands below. 117 | 118 | ### Command Examples 119 | 120 | Control your Mac with natural language instructions like: 121 | 122 | ``` 123 | "Open Safari and navigate to apple.com" 124 | "Open the notifications tab and click on the first notification" 125 | "Open the email app and reply to the most recent email" 126 | ``` 127 | 128 | ## ⚙️ How It Works 129 | 130 | 1. **Screen Capture**: The system takes screenshots of your Mac display 131 | 2. **Visual Analysis**: MLX-VLM processes the visual information to understand: 132 | - UI elements and their states 133 | - Text content on screen 134 | - Application context 135 | - System status 136 | 3. **Instruction Processing**: Your natural language commands are interpreted 137 | 4. **Action Planning**: The system determines the sequence of actions needed 138 | 5. **Execution**: Actions are performed through macOS APIs or simulated inputs (click, scroll, etc.). A minimal code sketch of this loop is included at the end of this README. 139 | 140 | ## 🔒 Privacy & Security 141 | 142 | - **Local Processing**: All AI inference happens on your Mac using MLX 143 | - **No Cloud Dependency**: Your screenshots and data never leave your device 144 | - **Permission Control**: Fine-grained control over what the system can access 145 | - **Transparent Operation**: Clear visibility into actions being performed 146 | 147 | ## 🛠️ Troubleshooting 148 | 149 | ### Common Issues 150 | 151 | - **Permission Errors**: Make sure to grant screen recording permissions in System Preferences > Security & Privacy > Privacy 152 | - **Performance Issues**: Try reducing the screenshot resolution in config.json 153 | - **Application Compatibility**: Some applications with non-standard UI elements may have limited support 154 | 155 | ### Getting Help 156 | 157 | - Check the [Issues](https://github.com/Blaizzy/mlx-vlm/issues) page 158 | - Join our [Discord community](https://discord.gg/yourdiscord) 159 | 160 | ## 🤝 Contributing 161 | 162 | We welcome contributions! Here's how to get started: 163 | 164 | 1. **Fork the repository** 165 | 2. **Create your feature branch**: 166 | ```bash 167 | git checkout -b feature/amazing-feature 168 | ``` 169 | 3. **Make your changes** 170 | 4. **Run tests**: 171 | ```bash 172 | python -m pytest 173 | ``` 174 | 5.
**Commit your changes**: 175 | ```bash 176 | git commit -m 'Add some amazing feature' 177 | ``` 178 | 6. **Push to the branch**: 179 | ```bash 180 | git push origin feature/amazing-feature 181 | ``` 182 | 7. **Open a Pull Request** 183 | 184 | Please read [CONTRIBUTING.md](CONTRIBUTING.md) for detailed guidelines. 185 | 186 | ## 📜 License 187 | 188 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 189 | 190 | ## 🙏 Acknowledgments 191 | 192 | - The **MLX team at Apple** for creating the MLX framework 193 | - **Our community of testers and contributors** who help improve the project 194 | 195 | --- 196 | 197 |198 | Made with ❤️ for Mac users who love automation and AI 199 |
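## 🧪 Appendix: Minimal Agent Loop (Sketch)

The snippet below is a rough, hypothetical sketch of the capture → analyze → act loop described in the "How It Works" section. It is not the actual implementation (see `gui_agent.py` and `autonomous_gui_agent.py` for that); the action format, the JSON parsing, and the use of `pyautogui` for simulated input are assumptions made purely for illustration.

```python
import json

import pyautogui  # assumption: used here only to illustrate simulated input
from PIL import ImageGrab

from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config

from utils import update_navigation_history  # helper from this directory

model_path = "mlx-community/Qwen2-VL-2B-Instruct-4bit"
model, processor = load(model_path)
config = load_config(model_path)

# 1. Screen capture
screenshot = ImageGrab.grab()
screenshot_path = "screenshots/current.png"
screenshot.save(screenshot_path)

# 2-4. Visual analysis, instruction processing and action planning
task = "Click on notifications"
prompt = apply_chat_template(processor, config, task, num_images=1)
response = generate(model, processor, prompt, [screenshot_path], verbose=False)

# 5. Execution (the {"action": ..., "position": [x, y]} response format is an assumption)
action = json.loads(response)
if action.get("action") == "CLICK":
    x, y = action["position"]
    pyautogui.click(x, y)

update_navigation_history(task, action, screenshot_path)
```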
-------------------------------------------------------------------------------- /computer_use/audio/ask.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/audio/ask.wav -------------------------------------------------------------------------------- /computer_use/audio/ask_2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/audio/ask_2.wav -------------------------------------------------------------------------------- /computer_use/audio/ask_voice.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/audio/ask_voice.wav -------------------------------------------------------------------------------- /computer_use/audio/ask_voice_2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/audio/ask_voice_2.wav -------------------------------------------------------------------------------- /computer_use/audio/ok.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/audio/ok.wav -------------------------------------------------------------------------------- /computer_use/audio/output.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/audio/output.wav -------------------------------------------------------------------------------- /computer_use/audio/task_completed.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/audio/task_completed.wav -------------------------------------------------------------------------------- /computer_use/navigation_history.csv: -------------------------------------------------------------------------------- 1 | Query,Response,Screenshot Path 2 | Click on Explore,"{'action': 'CLICK', 'value': None, 'position': [211.68, 294.59999999999997]}",screenshots/screenshot_20241210-191351.png 3 | Click on notifications,"{'action': 'CLICK', 'value': None, 'position': [241.92000000000002, 333.88000000000005]}",screenshots/screenshot_20241210-191519.png 4 | Click on Profile,"{'action': 'CLICK', 'value': None, 'position': [196.56, 746.32]}",screenshots/screenshot_20241210-192437.png 5 | Click on the Image that Philipp posted,"{'action': 'CLICK', 'value': None, 'position': [589.6800000000001, 569.56]}",screenshots/screenshot_20241210-193535.png 6 | Select text on the screen,"{'action': 'CLICK', 'value': None, 'position': [740.88, 314.24]}",screenshots/screenshot_20241210-193911.png 7 | Click on show posts,"{'action': 'CLICK', 'value': None, 'position': [680.4, 343.7]}",screenshots/screenshot_20241210-194314.png 8 | Search youtube,"{'action': 'CLICK', 'value': None, 'position': [1179.3600000000001, 98.2]}",screenshots/screenshot_20241210-195055.png 9 | Click on the BMW X3 video,"{'action': 'CLICK', 'value': None, 'position': [1300.32, 
559.7399999999999]}",screenshots/screenshot_20241210-195145.png 10 | Click on notifications,"{'action': 'CLICK', 'value': None, 'position': [257.04, 333.88000000000005]}",screenshots/screenshot_20241210-230354.png 11 | Click on notifications,"{'action': 'CLICK', 'value': None, 'position': [468.71999999999997, 196.4]}",screenshots/screenshot_20241210-230553.png 12 | Click on notifications,"{'action': 'CLICK', 'value': None, 'position': [468.71999999999997, 196.4]}",screenshots/screenshot_20241210-230651.png 13 | Click on notifications,"{'action': 'CLICK', 'value': None, 'position': [468.71999999999997, 196.4]}",screenshots/screenshot_20241210-230730.png 14 | Click on notifications,"{'action': 'CLICK', 'value': None, 'position': [512.0, 299.52]}",screenshots/screenshot_20241210-230923.png 15 | Click on notifications,"{'action': 'CLICK', 'value': None, 'position': [793.6, 288.0]}",screenshots/screenshot_20241210-230958.png 16 | Click on Home,"{'action': 'CLICK', 'value': None, 'position': [742.4, 187.20000000000002]}",screenshots/screenshot_20241210-231015.png 17 | Click on Elon Musk,"{'action': 'CLICK', 'value': None, 'position': [1239.84, 1040.92]}",screenshots/screenshot_20241210-231058.png 18 | Click on Elon Musk's profile,"{'action': 'CLICK', 'value': None, 'position': [1058.3999999999999, 1060.5600000000002]}",screenshots/screenshot_20241210-231133.png 19 | Click on Elon Musks profile,"{'action': 'CLICK', 'value': None, 'position': [1058.3999999999999, 1060.5600000000002]}",screenshots/screenshot_20241210-231212.png 20 | Click on Nikita,"{'action': 'CLICK', 'value': None, 'position': [1028.16, 510.64000000000004]}",screenshots/screenshot_20241210-231321.png 21 | Click on Explore,"{'action': 'CLICK', 'value': None, 'position': [423.36, 549.9200000000001]}",screenshots/screenshot_20241210-231404.png 22 | Click on Explore,"{'action': 'CLICK', 'value': None, 'position': [211.68, 274.96000000000004]}",screenshots/screenshot_20241210-231511.png 23 | Click on Home,"{'action': 'CLICK', 'value': None, 'position': [181.44, 225.86]}",screenshots/screenshot_20241210-231527.png 24 | Click on Rohan Pauls profile,"{'action': 'CLICK', 'value': None, 'position': [483.84000000000003, 549.9200000000001]}",screenshots/screenshot_20241210-231552.png 25 | cClick on Home,"{'action': 'CLICK', 'value': None, 'position': [393.12, 432.08]}",screenshots/screenshot_20241210-231719.png 26 | Click on Home,"{'action': 'CLICK', 'value': None, 'position': [196.56, 216.04]}",screenshots/screenshot_20241210-231751.png 27 | Click on notifications,"{'action': 'CLICK', 'value': None, 'position': [153.6, 273.6]}",screenshots/screenshot_20241211-101202.png 28 | Click on Home,"{'action': 'CLICK', 'value': None, 'position': [166.4, 172.79999999999998]}",screenshots/screenshot_20241211-101219.png 29 | -------------------------------------------------------------------------------- /computer_use/requirements.txt: -------------------------------------------------------------------------------- 1 | SpeechRecognition>=3.12.0 2 | mlx-whisper>=0.4.1 3 | mlx-audio>=0.0.1 -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-191351.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-191351.png -------------------------------------------------------------------------------- 
/computer_use/screenshots/screenshot_20241210-191519.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-191519.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-192437.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-192437.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-193535.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-193535.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-193911.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-193911.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-194314.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-194314.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-195055.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-195055.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-195145.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-195145.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-230354.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-230354.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-230553.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-230553.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-230651.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-230651.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-230730.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-230730.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-230923.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-230923.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-230958.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-230958.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-231015.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-231015.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-231058.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-231058.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-231133.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-231133.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-231212.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-231212.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-231321.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-231321.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-231404.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-231404.png -------------------------------------------------------------------------------- 
/computer_use/screenshots/screenshot_20241210-231511.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-231511.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-231527.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-231527.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-231552.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-231552.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-231719.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-231719.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241210-231751.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241210-231751.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241211-101202.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241211-101202.png -------------------------------------------------------------------------------- /computer_use/screenshots/screenshot_20241211-101219.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/computer_use/screenshots/screenshot_20241211-101219.png -------------------------------------------------------------------------------- /computer_use/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from io import BytesIO 4 | 5 | import pandas as pd 6 | import requests 7 | from PIL import Image, ImageDraw 8 | 9 | 10 | def draw_point( 11 | image_input, point: tuple = None, radius: int = 8, color: str = "red" 12 | ) -> Image.Image: 13 | """ 14 | Draw a point on an image and return the modified image. 15 | `point` should be (x_norm, y_norm) in normalized coordinates (0 to 1). 
16 | """ 17 | # Load image from path/URL or use directly if it's already a PIL Image 18 | if isinstance(image_input, str): 19 | if image_input.startswith("http"): 20 | response = requests.get(image_input) 21 | response.raise_for_status() 22 | image = Image.open(BytesIO(response.content)) 23 | else: 24 | image = Image.open(image_input) 25 | elif isinstance(image_input, Image.Image): 26 | image = image_input 27 | else: 28 | raise ValueError("image_input must be a string path/URL or a PIL Image object") 29 | 30 | # Only draw if a valid point is provided 31 | if point is not None: 32 | x = int(point[0]) 33 | y = int(point[1]) 34 | print(f"Drawing ellipse at pixel coordinates: ({x}, {y})") 35 | 36 | draw = ImageDraw.Draw(image) 37 | draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=color) 38 | 39 | return image 40 | 41 | 42 | # Update CSV with query, response and image path 43 | def update_navigation_history( 44 | query, response, filepath, csv_path="navigation_history.csv" 45 | ): 46 | """ 47 | Update the navigation history CSV file with the query, response and screenshot filepath. 48 | 49 | Args: 50 | query: The user's query/task 51 | response: The system's response/action 52 | filepath: Path to the screenshot image 53 | csv_path: Path to the CSV file (default: navigation_history.csv) 54 | """ 55 | # Create new row as a DataFrame 56 | new_row = pd.DataFrame( 57 | {"Query": [query], "Response": [str(response)], "Screenshot Path": [filepath]} 58 | ) 59 | 60 | if os.path.exists(csv_path): 61 | # Append to existing CSV 62 | new_row.to_csv(csv_path, mode="a", header=False, index=False) 63 | else: 64 | # Create new CSV with headers 65 | new_row.to_csv(csv_path, index=False) 66 | -------------------------------------------------------------------------------- /docs/changelog.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | -------------------------------------------------------------------------------- /docs/cli_reference.md: -------------------------------------------------------------------------------- 1 | # CLI Reference 2 | 3 | MLX-VLM provides several command line entry points: 4 | 5 | - `mlx_vlm.convert` – convert Hugging Face models to MLX format. 6 | - `mlx_vlm.generate` – run inference on images. 7 | - `mlx_vlm.video_generate` – generate from a video file. 8 | - `mlx_vlm.smolvlm_video_generate` – lightweight video generation. 9 | - `mlx_vlm.chat_ui` – start an interactive Gradio UI. 10 | - `mlx_vlm.server` – run the FastAPI server. 11 | 12 | Each command accepts `--help` for full usage information. 13 | 14 | -------------------------------------------------------------------------------- /docs/community_projects.md: -------------------------------------------------------------------------------- 1 | # Community Projects 2 | 3 | If you have a project built on top of MLX-VLM let us know! We plan to showcase community examples and links here. 4 | 5 | -------------------------------------------------------------------------------- /docs/contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | To work on MLX-VLM in editable mode run: 4 | 5 | ```bash 6 | pip install -e . 7 | ``` 8 | 9 | Check that the model weights are available in the `safetensors` format, convert if necessary and add the model file to `mlx_vlm/models`. 
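To quickly confirm that a Hugging Face repo ships `safetensors` weights before porting, something like this can be used (a sketch; it assumes the `huggingface_hub` package is installed):

```python
from huggingface_hub import list_repo_files

repo_id = "llava-hf/llava-1.5-7b-hf"  # example repo
has_safetensors = any(f.endswith(".safetensors") for f in list_repo_files(repo_id))
print(has_safetensors)
```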
10 | 11 | Tests can be run from the `mlx_vlm/` directory: 12 | 13 | ```bash 14 | python -m unittest discover tests/ 15 | ``` 16 | 17 | Please format code using `pre-commit` before submitting a pull request. 18 | 19 | -------------------------------------------------------------------------------- /docs/examples.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | Example notebooks are available in the `examples/` directory: 4 | 5 | - [multi_image_generation.ipynb](https://github.com/Blaizzy/mlx-vlm/blob/main/examples/multi_image_generation.ipynb) 6 | - [object_detection.ipynb](https://github.com/Blaizzy/mlx-vlm/blob/main/examples/object_detection.ipynb) 7 | - [object_pointing.ipynb](https://github.com/Blaizzy/mlx-vlm/blob/main/examples/object_pointing.ipynb) 8 | - [ocr_with_region.ipynb](https://github.com/Blaizzy/mlx-vlm/blob/main/examples/ocr_with_region.ipynb) 9 | - [text_extraction.ipynb](https://github.com/Blaizzy/mlx-vlm/blob/main/examples/text_extraction.ipynb) 10 | - [video_understanding.ipynb](https://github.com/Blaizzy/mlx-vlm/blob/main/examples/video_understanding.ipynb) 11 | 12 | Images and videos used by the notebooks are stored in the `examples/images/` and `examples/videos/` folders. 13 | 14 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # MLX-VLM 2 | 3 | MLX-VLM is a package for inference and fine-tuning of Vision Language Models (VLMs) on Apple silicon using [MLX](https://github.com/ml-explore/mlx). 4 | 5 | -------------------------------------------------------------------------------- /docs/installation.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | Install the package from PyPI: 4 | 5 | ```bash 6 | pip install mlx-vlm 7 | ``` 8 | 9 | -------------------------------------------------------------------------------- /docs/overrides/main.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | 3 | {% block content %} 4 | {% if page.nb_url %} 5 | 6 | {% include ".icons/material/download.svg" %} 7 | 8 | {% endif %} 9 | 10 | {{ super() }} 11 | {% endblock content %} -------------------------------------------------------------------------------- /docs/report_issues.md: -------------------------------------------------------------------------------- 1 | # Report Issues 2 | 3 | Please open an issue on [GitHub](https://github.com/Blaizzy/mlx-vlm/issues) with clear steps to reproduce the problem. 
4 | 5 | -------------------------------------------------------------------------------- /docs/usage.md: -------------------------------------------------------------------------------- 1 | # Usage 2 | 3 | ## Command Line Interface (CLI) 4 | 5 | Generate output from a model: 6 | 7 | ```bash 8 | python -m mlx_vlm.generate --model mlx-community/Qwen2-VL-2B-Instruct-4bit --max-tokens 100 --temperature 0.0 --image http://images.cocodataset.org/val2017/000000039769.jpg 9 | ``` 10 | 11 | ## Chat UI with Gradio 12 | 13 | Launch the chat interface: 14 | 15 | ```bash 16 | python -m mlx_vlm.chat_ui --model mlx-community/Qwen2-VL-2B-Instruct-4bit 17 | ``` 18 | 19 | ## Python Script 20 | 21 | ```python 22 | from mlx_vlm import load, generate 23 | from mlx_vlm.prompt_utils import apply_chat_template 24 | from mlx_vlm.utils import load_config 25 | 26 | model_path = "mlx-community/Qwen2-VL-2B-Instruct-4bit" 27 | model, processor = load(model_path) 28 | config = load_config(model_path) 29 | 30 | image = ["http://images.cocodataset.org/val2017/000000039769.jpg"] 31 | prompt = "Describe this image." 32 | 33 | formatted_prompt = apply_chat_template(processor, config, prompt, num_images=len(image)) 34 | output = generate(model, processor, formatted_prompt, image, verbose=False) 35 | print(output) 36 | ``` 37 | 38 | ## Server (FastAPI) 39 | 40 | ```bash 41 | python -m mlx_vlm.server 42 | ``` 43 | 44 | See `README.md` for a complete `curl` example. 45 | 46 | -------------------------------------------------------------------------------- /examples/images/cats.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/examples/images/cats.jpg -------------------------------------------------------------------------------- /examples/images/desktop_setup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/examples/images/desktop_setup.png -------------------------------------------------------------------------------- /examples/images/graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/examples/images/graph.png -------------------------------------------------------------------------------- /examples/images/handwritten_hard.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/examples/images/handwritten_hard.webp -------------------------------------------------------------------------------- /examples/images/latex.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/examples/images/latex.png -------------------------------------------------------------------------------- /examples/images/menu.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/examples/images/menu.webp -------------------------------------------------------------------------------- /examples/images/paper.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/examples/images/paper.png -------------------------------------------------------------------------------- /examples/images/renewables_california.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/examples/images/renewables_california.png -------------------------------------------------------------------------------- /examples/video_understanding.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Video Understanding\n", 8 | "\n", 9 | "In this example, we will generate a description of a video using `Qwen2-VL`, `Qwen2-5-VL`, `LLava`, and `Idefics3`, with more models coming soon.\n", 10 | "\n", 11 | "This feature is currently in beta, may not work as expected.\n", 12 | "\n" 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "metadata": {}, 18 | "source": [ 19 | "## Install Dependencies" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "!pip install -U mlx-vlm" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "## Import Dependencies" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 1, 41 | "metadata": {}, 42 | "outputs": [ 43 | { 44 | "name": "stderr", 45 | "output_type": "stream", 46 | "text": [ 47 | "/opt/homebrew/Caskroom/miniconda/base/envs/mlx_code/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 48 | " from .autonotebook import tqdm as notebook_tqdm\n", 49 | "This is a beta version of the video understanding. 
It may not work as expected.\n" 50 | ] 51 | } 52 | ], 53 | "source": [ 54 | "from pprint import pprint\n", 55 | "from mlx_vlm import load\n", 56 | "from mlx_vlm.utils import generate\n", 57 | "from mlx_vlm.video_generate import process_vision_info\n", 58 | "\n", 59 | "import mlx.core as mx" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "# Load the model and processor\n", 69 | "model, processor = load(\"mlx-community/Qwen2.5-VL-7B-Instruct-4bit\")" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 3, 75 | "metadata": {}, 76 | "outputs": [ 77 | { 78 | "name": "stderr", 79 | "output_type": "stream", 80 | "text": [ 81 | "numpy reader: video_path=videos/fastmlx_local_ai_hub.mp4, total_frames=1134, video_fps=59.941855343141576, time=0.000s\n" 82 | ] 83 | } 84 | ], 85 | "source": [ 86 | "# Messages containing a video and a text query\n", 87 | "messages = [\n", 88 | " {\n", 89 | " \"role\": \"user\",\n", 90 | " \"content\": [\n", 91 | " {\n", 92 | " \"type\": \"video\",\n", 93 | " \"video\": \"videos/fastmlx_local_ai_hub.mp4\",\n", 94 | " \"max_pixels\": 360 * 360,\n", 95 | " \"fps\": 1.0,\n", 96 | " },\n", 97 | " {\"type\": \"text\", \"text\": \"Describe this video.\"},\n", 98 | " ],\n", 99 | " }\n", 100 | "]\n", 101 | "\n", 102 | "# Preparation for inference\n", 103 | "text = processor.apply_chat_template(\n", 104 | " messages, tokenize=False, add_generation_prompt=True\n", 105 | ")\n", 106 | "image_inputs, video_inputs = process_vision_info(messages)\n", 107 | "inputs = processor(\n", 108 | " text=[text],\n", 109 | " images=image_inputs,\n", 110 | " videos=video_inputs,\n", 111 | " padding=True,\n", 112 | " return_tensors=\"pt\",\n", 113 | ")\n" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 4, 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "# Convert inputs to mlx arrays\n", 123 | "input_ids = mx.array(inputs['input_ids'])\n", 124 | "pixel_values = mx.array(inputs['pixel_values_videos'])\n", 125 | "mask = mx.array(inputs['attention_mask'])\n", 126 | "image_grid_thw = mx.array(inputs['video_grid_thw'])\n", 127 | "\n", 128 | "kwargs = {\n", 129 | " \"image_grid_thw\": image_grid_thw,\n", 130 | "}" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 9, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "kwargs[\"video\"] = \"videos/fastmlx_local_ai_hub.mp4\"\n", 140 | "kwargs[\"input_ids\"] = input_ids\n", 141 | "kwargs[\"pixel_values\"] = pixel_values\n", 142 | "kwargs[\"mask\"] = mask\n", 143 | "response = generate(model, processor, prompt=text, temperature=0.7, max_tokens=100, **kwargs)" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 10, 149 | "metadata": {}, 150 | "outputs": [ 151 | { 152 | "name": "stdout", 153 | "output_type": "stream", 154 | "text": [ 155 | "('The video appears to be a live stream or a recording of a coding session, '\n", 156 | " 'likely on a platform like Discord, as indicated by the presence of text '\n", 157 | " \"chats and a streamer's interface. The video is primarily focused on a \"\n", 158 | " 'computer screen displaying a code editor with various programming languages '\n", 159 | " 'and snippets of code. 
The coder seems to be explaining or demonstrating '\n", 160 | " 'something related to the code, possibly working through a programming '\n", 161 | " 'problem, explaining the logic, or showing the process of solving a problem.\\n'\n", 162 | " '\\n'\n", 163 | " 'Here are some key observations from')\n" 164 | ] 165 | } 166 | ], 167 | "source": [ 168 | "pprint(response)\n" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": null, 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "# open video and play it\n", 178 | "from ipywidgets import Video\n", 179 | "Video.from_file(\"videos/fastmlx_local_ai_hub.mp4\", width=320, height=240)" 180 | ] 181 | } 182 | ], 183 | "metadata": { 184 | "kernelspec": { 185 | "display_name": "mlx_code", 186 | "language": "python", 187 | "name": "python3" 188 | }, 189 | "language_info": { 190 | "codemirror_mode": { 191 | "name": "ipython", 192 | "version": 3 193 | }, 194 | "file_extension": ".py", 195 | "mimetype": "text/x-python", 196 | "name": "python", 197 | "nbconvert_exporter": "python", 198 | "pygments_lexer": "ipython3", 199 | "version": "3.11.11" 200 | } 201 | }, 202 | "nbformat": 4, 203 | "nbformat_minor": 2 204 | } 205 | -------------------------------------------------------------------------------- /examples/videos/fastmlx_local_ai_hub.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/examples/videos/fastmlx_local_ai_hub.mp4 -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: MLX-VLM 2 | site_description: MLX-VLM is a package for inference and fine-tuning of Vision Language Models (VLMs) on your Mac using MLX. 
3 | site_author: Prince Canuma 4 | repo_name: Blaizzy/mlx-vlm 5 | site_url: https://Blaizzy.github.io/mlx-vlm 6 | repo_url: https://github.com/Blaizzy/mlx-vlm 7 | 8 | copyright: "Copyright \u00a9 2024 - 2024 Prince Canuma" 9 | 10 | theme: 11 | palette: 12 | - scheme: default 13 | primary: black 14 | toggle: 15 | icon: material/toggle-switch-off-outline 16 | name: Switch to dark mode 17 | - scheme: slate 18 | primary: black 19 | accent: indigo 20 | toggle: 21 | icon: material/toggle-switch 22 | name: Switch to light mode 23 | name: material 24 | icon: 25 | repo: fontawesome/brands/github 26 | features: 27 | - navigation.instant 28 | - navigation.tracking 29 | - navigation.top 30 | - navigation.footer 31 | - search.highlight 32 | - search.share 33 | - content.code.copy 34 | custom_dir: "docs/overrides" 35 | font: 36 | text: Google Sans 37 | code: Regular 38 | 39 | plugins: 40 | - search 41 | - mkdocstrings 42 | - mkdocs-jupyter: 43 | include_source: True 44 | ignore_h1_titles: True 45 | execute: True 46 | allow_errors: false 47 | ignore: ["conf.py"] 48 | execute_ignore: ["*ignore.ipynb"] 49 | 50 | markdown_extensions: 51 | - admonition 52 | - abbr 53 | - attr_list 54 | - def_list 55 | - footnotes 56 | - meta 57 | - md_in_html 58 | - pymdownx.superfences 59 | - pymdownx.highlight: 60 | linenums: true 61 | - toc: 62 | permalink: true 63 | 64 | extra: 65 | social: 66 | - icon: fontawesome/brands/github 67 | link: https://github.com/Blaizzy 68 | - icon: fontawesome/brands/twitter 69 | link: https://twitter.com/Prince_Canuma 70 | version: 71 | provider: mike 72 | consent: 73 | title: Cookie consent 74 | description: >- 75 | We use cookies to recognize your repeated visits and preferences, as well 76 | as to measure the effectiveness of our documentation and whether users 77 | find what they're searching for. With your consent, you're helping us to 78 | make our documentation better. 79 | 80 | extra_css: 81 | - stylesheets/extra.css 82 | 83 | nav: 84 | - Home: index.md 85 | - Installation: installation.md 86 | - CLI Reference: cli_reference.md 87 | - Examples: examples.md 88 | - Contributing: contributing.md 89 | - Community Projects: community_projects.md 90 | - Report Issues: report_issues.md 91 | - Changelog: changelog.md 92 | 93 | docs_dir: docs 94 | -------------------------------------------------------------------------------- /mlx_vlm/LORA.MD: -------------------------------------------------------------------------------- 1 | # LoRA Training Script 2 | 3 | ## Overview 4 | 5 | `lora.py` is a Python script for fine-tuning vision language models (VLMs) using Low-Rank Adaptation (LoRA or QLoRA). This script allows you to train the model on your custom dataset, adjusting various parameters through command-line arguments. 6 | 7 | ## Requirements 8 | 9 | - Python 3.7+ 10 | - Required Python packages: `mlx-vlm`, `numpy`, `transformers`, `datasets`, `PIL` 11 | 12 | ## Supported Models 13 | - Qwen2 14 | - LLaVA (except for LLaVA-Next) 15 | - Pixtral 16 | - Idefics 2 17 | - Deepseek-VL 18 | - Paligemma 19 | - Mllama (Llama-3.2-vision) 20 | 21 | ## Coming Soon 22 | - LLaVA-Next 23 | - Phi3_vision 24 | 25 | ## Usage 26 | 27 | To use the script, run it from the command line with the desired arguments: 28 | 29 | ``` 30 | python lora.py --dataset /path/to/your/dataset [other options] 31 | ``` 32 | 33 | ## Dataset format 34 | 35 | The dataset should be a Hugging Face dataset with an `images` column and a `messages` column. 
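One way to put together a toy dataset with this schema is sketched below. This is a minimal, hypothetical example using the `datasets` library; the image paths, conversation contents, and Hub repo id are made up and only illustrate the expected columns.

```python
from datasets import Dataset

# Two made-up samples: one image reference per row plus a chat-style conversation.
# Depending on the model and processor, "images" may hold file paths, URLs, or PIL images.
rows = {
    "images": ["images/receipt_001.png", "images/receipt_002.png"],
    "messages": [
        [
            {"role": "user", "content": "What is the total on this receipt?"},
            {"role": "assistant", "content": "The total is $42.50."},
        ],
        [
            {"role": "user", "content": "Which store issued this receipt?"},
            {"role": "assistant", "content": "Acme Market."},
        ],
    ],
}

dataset = Dataset.from_dict(rows)
# Push to the Hub (hypothetical repo id) so it can be passed to --dataset,
# or point --dataset at any existing Hugging Face dataset with these two columns.
dataset.push_to_hub("your-username/my-vlm-dataset")
```

Each row of the dataset then has the following shape: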
36 | 37 | ``` 38 | { 39 | "images": ..., 40 | "messages": ..., 41 | } 42 | ``` 43 | 44 | Support for other formats and column names will be added soon. 45 | 46 | ## Arguments 47 | 48 | The script accepts the following command-line arguments: 49 | 50 | - `--model-path`: Path to the pre-trained model (default: "mlx-community/Qwen2-VL-2B-Instruct-bf16") 51 | - `--dataset`: Path to your dataset (required) 52 | - `--learning-rate`: Learning rate for the optimizer (default: 1e-4) 53 | - `--batch-size`: Batch size for training (default: 1) 54 | - `--epochs`: Number of epochs to train (default: 1) 55 | - `--steps`: Number of steps per epoch (default: 0) 56 | - `--print-every`: Print loss every n steps (default: 10) 57 | - `--adapter-path`: Load path to resume training from a previously saved adapter (default: None) 58 | - `--output-path`: Path to save the trained adapter (default: "adapters.safetensors") 59 | 60 | ## Example 61 | 62 | Here's an example of how to run the script with custom parameters: 63 | 64 | ``` 65 | python lora.py --dataset /path/to/your/dataset --model-path /path/to/your/model --epochs 2 --batch-size 4 --learning-rate 5e-5 66 | ``` 67 | 68 | ## Output 69 | 70 | The script will print the training loss at regular intervals (defined by `--print-every`). After training, it will save the LoRA adapter to the specified output path. 71 | 72 | ## Note 73 | 74 | If you want to use QLoRA, you need to pass a pre-quantized model to the script using the `--model-path` argument (i.e. `mlx-community/Qwen2-VL-2B-Instruct-4bit`). 75 | Make sure you have the necessary permissions to read the dataset and write the output file. Also, ensure that your system has sufficient computational resources to handle the specified batch size and model. 76 | 77 | ## Contributing 78 | 79 | Feel free to submit issues or pull requests if you find any bugs or have suggestions for improvements. 80 | -------------------------------------------------------------------------------- /mlx_vlm/__init__.py: -------------------------------------------------------------------------------- 1 | from .prompt_utils import apply_chat_template, get_message_json 2 | from .utils import ( 3 | GenerationResult, 4 | convert, 5 | generate, 6 | load, 7 | prepare_inputs, 8 | process_image, 9 | quantize_model, 10 | stream_generate, 11 | ) 12 | from .version import __version__ 13 | -------------------------------------------------------------------------------- /mlx_vlm/chat_ui.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import gradio as gr 4 | 5 | from mlx_vlm import load 6 | 7 | from .prompt_utils import get_chat_template, get_message_json 8 | from .utils import load, load_config, load_image_processor, stream_generate 9 | 10 | 11 | def parse_arguments(): 12 | parser = argparse.ArgumentParser( 13 | description="Generate text from an image using a model." 
14 | ) 15 | parser.add_argument( 16 | "--model", 17 | type=str, 18 | default="qnguyen3/nanoLLaVA", 19 | help="The path to the local model directory or Hugging Face repo.", 20 | ) 21 | return parser.parse_args() 22 | 23 | 24 | args = parse_arguments() 25 | config = load_config(args.model) 26 | model, processor = load(args.model, processor_kwargs={"trust_remote_code": True}) 27 | image_processor = load_image_processor(args.model) 28 | 29 | 30 | def chat(message, history, temperature, max_tokens): 31 | if config["model_type"] != "paligemma": 32 | chat_history = [] 33 | for item in history: 34 | if isinstance(item[0], str): 35 | chat_history.append({"role": "user", "content": item[0]}) 36 | if item[1] is not None: 37 | chat_history.append({"role": "assistant", "content": item[1]}) 38 | 39 | chat_history.append({"role": "user", "content": message["text"]}) 40 | 41 | messages = [] 42 | for i, m in enumerate(chat_history): 43 | skip_token = True 44 | if i == len(chat_history) - 1 and m["role"] == "user": 45 | skip_token = False 46 | messages.append( 47 | get_message_json( 48 | config["model_type"], 49 | m["content"], 50 | role=m["role"], 51 | skip_image_token=skip_token, 52 | ) 53 | ) 54 | 55 | messages = get_chat_template(processor, messages, add_generation_prompt=True) 56 | 57 | else: 58 | messages = message["text"] 59 | 60 | files = "" 61 | if "files" in message and len(message["files"]) > 0: 62 | files = message["files"][-1] 63 | 64 | response = "" 65 | for chunk in stream_generate( 66 | model, 67 | processor, 68 | messages, 69 | files, 70 | max_tokens=max_tokens, 71 | temperature=temperature, 72 | ): 73 | response += chunk.text 74 | yield response 75 | 76 | 77 | demo = gr.ChatInterface( 78 | fn=chat, 79 | title="MLX-VLM Chat UI", 80 | additional_inputs_accordion=gr.Accordion( 81 | label="⚙️ Parameters", open=False, render=False 82 | ), 83 | additional_inputs=[ 84 | gr.Slider( 85 | minimum=0, maximum=1, step=0.1, value=0.1, label="Temperature", render=False 86 | ), 87 | gr.Slider( 88 | minimum=128, 89 | maximum=4096, 90 | step=1, 91 | value=200, 92 | label="Max new tokens", 93 | render=False, 94 | ), 95 | ], 96 | description=f"Now Running {args.model}", 97 | multimodal=True, 98 | ) 99 | 100 | demo.launch(inbrowser=True) 101 | -------------------------------------------------------------------------------- /mlx_vlm/convert.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 2 | 3 | import argparse 4 | 5 | from .utils import MODEL_CONVERSION_DTYPES, convert 6 | 7 | 8 | def configure_parser() -> argparse.ArgumentParser: 9 | """ 10 | Configures and returns the argument parser for the script. 11 | 12 | Returns: 13 | argparse.ArgumentParser: Configured argument parser. 14 | """ 15 | parser = argparse.ArgumentParser( 16 | description="Convert Hugging Face model to MLX format" 17 | ) 18 | 19 | parser.add_argument("--hf-path", type=str, help="Path to the Hugging Face model.") 20 | parser.add_argument( 21 | "--mlx-path", type=str, default="mlx_model", help="Path to save the MLX model." 22 | ) 23 | parser.add_argument( 24 | "-q", "--quantize", help="Generate a quantized model.", action="store_true" 25 | ) 26 | parser.add_argument( 27 | "--q-group-size", help="Group size for quantization.", type=int, default=64 28 | ) 29 | parser.add_argument( 30 | "--q-bits", help="Bits per weight for quantization.", type=int, default=4 31 | ) 32 | parser.add_argument( 33 | "--dtype", 34 | help="Type to save the parameter. 
Defaults to config.json's `torch_dtype` or the current model weights dtype", 35 | type=str, 36 | choices=MODEL_CONVERSION_DTYPES, 37 | default=None, 38 | ) 39 | parser.add_argument( 40 | "--upload-repo", 41 | help="The Hugging Face repo to upload the model to.", 42 | type=str, 43 | default=None, 44 | ) 45 | parser.add_argument( 46 | "-d", 47 | "--dequantize", 48 | help="Dequantize a quantized model.", 49 | action="store_true", 50 | default=False, 51 | ) 52 | parser.add_argument( 53 | "--skip-vision", 54 | help="Skip vision module quantization.", 55 | action="store_true", 56 | default=False, 57 | ) 58 | return parser 59 | 60 | 61 | def main(): 62 | parser = configure_parser() 63 | args = parser.parse_args() 64 | convert(**vars(args)) 65 | 66 | 67 | if __name__ == "__main__": 68 | main() 69 | -------------------------------------------------------------------------------- /mlx_vlm/generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import codecs 3 | 4 | from .prompt_utils import apply_chat_template 5 | from .utils import ( 6 | generate, 7 | get_model_path, 8 | load, 9 | load_config, 10 | load_image_processor, 11 | stream_generate, 12 | ) 13 | 14 | DEFAULT_MODEL_PATH = "mlx-community/nanoLLaVA-1.5-8bit" 15 | DEFAULT_IMAGE = [] 16 | DEFAULT_PROMPT = "What are these?" 17 | DEFAULT_MAX_TOKENS = 256 18 | DEFAULT_TEMPERATURE = 0.5 19 | DEFAULT_TOP_P = 1.0 20 | DEFAULT_SEED = 0 21 | 22 | 23 | def parse_arguments(): 24 | parser = argparse.ArgumentParser( 25 | description="Generate text from an image using a model." 26 | ) 27 | parser.add_argument( 28 | "--model", 29 | type=str, 30 | default=DEFAULT_MODEL_PATH, 31 | help="The path to the local model directory or Hugging Face repo.", 32 | ) 33 | parser.add_argument( 34 | "--adapter-path", 35 | type=str, 36 | default=None, 37 | help="The path to the adapter weights.", 38 | ) 39 | parser.add_argument( 40 | "--image", 41 | type=str, 42 | nargs="+", 43 | default=DEFAULT_IMAGE, 44 | help="URL or path of the image to process.", 45 | ) 46 | parser.add_argument( 47 | "--resize-shape", 48 | type=int, 49 | nargs="+", 50 | default=None, 51 | help="Resize shape for the image.", 52 | ) 53 | parser.add_argument( 54 | "--prompt", 55 | type=str, 56 | default=DEFAULT_PROMPT, 57 | help="Message to be processed by the model.", 58 | ) 59 | parser.add_argument( 60 | "--system", 61 | type=str, 62 | default=None, 63 | help="System message for the model.", 64 | ) 65 | parser.add_argument( 66 | "--max-tokens", 67 | type=int, 68 | default=DEFAULT_MAX_TOKENS, 69 | help="Maximum number of tokens to generate.", 70 | ) 71 | parser.add_argument( 72 | "--temperature", 73 | type=float, 74 | default=DEFAULT_TEMPERATURE, 75 | help="Temperature for sampling.", 76 | ) 77 | parser.add_argument("--chat", action="store_true", help="Chat in multi-turn style.") 78 | parser.add_argument("--verbose", action="store_false", help="Detailed output.") 79 | parser.add_argument( 80 | "--eos-tokens", 81 | type=str, 82 | nargs="+", 83 | default=None, 84 | help="EOS tokens to add to the tokenizer.", 85 | ) 86 | parser.add_argument( 87 | "--skip-special-tokens", 88 | action="store_true", 89 | help="Skip special tokens in the detokenizer.", 90 | ) 91 | 92 | return parser.parse_args() 93 | 94 | 95 | def get_model_and_processors(model_path, adapter_path): 96 | model_path = get_model_path(model_path) 97 | config = load_config(model_path, trust_remote_code=True) 98 | model, processor = load( 99 | model_path, adapter_path=adapter_path, 
lazy=False, trust_remote_code=True 100 | ) 101 | return model, processor, config 102 | 103 | 104 | def main(): 105 | args = parse_arguments() 106 | if isinstance(args.image, str): 107 | args.image = [args.image] 108 | 109 | model, processor, config = get_model_and_processors(args.model, args.adapter_path) 110 | 111 | prompt = codecs.decode(args.prompt, "unicode_escape") 112 | 113 | prompt = apply_chat_template(processor, config, prompt, num_images=len(args.image)) 114 | 115 | kwargs = {} 116 | 117 | if args.resize_shape is not None: 118 | if len(args.resize_shape) not in [1, 2]: 119 | raise ValueError("Resize shape must be 1 or 2 integers") 120 | kwargs["resize_shape"] = ( 121 | (args.resize_shape[0],) * 2 122 | if len(args.resize_shape) == 1 123 | else tuple(args.resize_shape) 124 | ) 125 | 126 | if args.eos_tokens is not None: 127 | kwargs["eos_tokens"] = [ 128 | codecs.decode(token, "unicode_escape") for token in args.eos_tokens 129 | ] 130 | 131 | if args.skip_special_tokens: 132 | kwargs["skip_special_tokens"] = args.skip_special_tokens 133 | 134 | if args.chat: 135 | chat = [] 136 | if args.system: 137 | chat.append({"role": "system", "content": args.system}) 138 | while user := input("User:"): 139 | chat.append({"role": "user", "content": user}) 140 | prompt = apply_chat_template( 141 | processor, config, chat, num_images=len(args.image) 142 | ) 143 | response = "" 144 | print("Assistant:", end="") 145 | for chunk in stream_generate( 146 | model, 147 | processor, 148 | prompt, 149 | args.image, 150 | max_tokens=args.max_tokens, 151 | temperature=args.temperature, 152 | **kwargs, 153 | ): 154 | response += chunk.text 155 | print(chunk.text, end="") 156 | 157 | chat.append({"role": "assistant", "content": response}) 158 | print() 159 | 160 | else: 161 | output = generate( 162 | model, 163 | processor, 164 | prompt, 165 | image=args.image, 166 | temperature=args.temperature, 167 | max_tokens=args.max_tokens, 168 | verbose=args.verbose, 169 | **kwargs, 170 | ) 171 | if not args.verbose: 172 | print(output) 173 | 174 | 175 | if __name__ == "__main__": 176 | main() 177 | -------------------------------------------------------------------------------- /mlx_vlm/lora.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | 5 | import mlx.optimizers as optim 6 | from datasets import load_dataset 7 | from tqdm import tqdm 8 | 9 | from .prompt_utils import apply_chat_template 10 | from .trainer import Dataset, Trainer, save_adapter 11 | from .trainer.utils import apply_lora_layers, find_all_linear_names, get_peft_model 12 | from .utils import load, load_image_processor 13 | 14 | logging.basicConfig(level=logging.INFO) 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | def custom_print(*args, **kwargs): 19 | tqdm.write(" ".join(map(str, args)), **kwargs) 20 | 21 | 22 | def main(args): 23 | logger.info(f"\033[32mLoading model from {args.model_path}\033[0m") 24 | model, processor = load( 25 | args.model_path, processor_config={"trust_remote_code": True} 26 | ) 27 | config = model.config.__dict__ 28 | image_processor = load_image_processor(args.model_path) 29 | 30 | logger.info(f"\033[32mLoading dataset from {args.dataset}\033[0m") 31 | dataset = load_dataset(args.dataset, split=args.split) 32 | 33 | if "messages" not in dataset.column_names: 34 | raise ValueError("Dataset must have a 'messages' column") 35 | if "images" not in dataset.column_names: 36 | raise ValueError("Dataset must have an 'images' 
column") 37 | 38 | if args.apply_chat_template: 39 | logger.info(f"\033[32mApplying chat template to the dataset\033[0m") 40 | 41 | def process_data(examples): 42 | if config["model_type"] == "pixtral": 43 | conversations = apply_chat_template( 44 | config=config, 45 | processor=processor, 46 | prompt=examples["messages"], 47 | return_messages=True, 48 | ) 49 | examples["messages"] = [ 50 | json.dumps(item, ensure_ascii=False) for item in conversations 51 | ] 52 | else: 53 | examples["messages"] = apply_chat_template( 54 | config=config, 55 | processor=processor, 56 | prompt=examples["messages"], 57 | return_messages=True, 58 | ) 59 | return examples 60 | 61 | dataset = dataset.map(process_data) 62 | 63 | dataset = Dataset( 64 | dataset, 65 | config, 66 | processor, 67 | image_processor=image_processor, 68 | image_resize_shape=args.image_resize_shape, 69 | ) 70 | 71 | adapter_path = args.adapter_path 72 | if adapter_path: 73 | logger.info(f"\033[32mResuming from adapter path {adapter_path}\033[0m") 74 | logger.info( 75 | f"\033[32mLora rank, alpha, and dropout will be loaded from adapter_config.json file\033[0m" 76 | ) 77 | 78 | model = apply_lora_layers(model, adapter_path) 79 | 80 | else: 81 | logger.info(f"\033[32mSetting up LoRA\033[0m") 82 | 83 | list_of_modules = find_all_linear_names(model.language_model) 84 | model = get_peft_model( 85 | model, 86 | list_of_modules, 87 | rank=args.lora_rank, 88 | alpha=args.lora_alpha, 89 | dropout=args.lora_dropout, 90 | ) 91 | 92 | logger.info(f"\033[32mSetting up optimizer\033[0m") 93 | optimizer = optim.Adam(learning_rate=args.learning_rate) 94 | 95 | logger.info(f"\033[32mSetting up trainer\033[0m") 96 | trainer = Trainer(model, optimizer) 97 | 98 | model.train() 99 | 100 | # Training loop 101 | logger.info(f"\033[32mTraining model\033[0m") 102 | for epoch in range(args.epochs): 103 | if args.steps == 0: 104 | args.steps = len(dataset) // args.batch_size 105 | 106 | progress_bar = tqdm(range(args.steps), position=0, leave=True) 107 | for i in progress_bar: 108 | loss = trainer.train_step( 109 | dataset[i * args.batch_size : (i + 1) * args.batch_size] 110 | ) 111 | # Update progress bar 112 | progress_bar.update(1) 113 | progress_bar.set_postfix( 114 | {"Epoch": epoch, "Step": i, "Loss": f"{loss.item():.4f}"} 115 | ) 116 | 117 | if i % args.print_every == 0: 118 | # Log additional information 119 | custom_print( 120 | { 121 | "Epoch": epoch, 122 | "Step": i, 123 | "Loss": f"{loss.item():.4f}", 124 | } 125 | ) 126 | 127 | # Save the adapter 128 | save_adapter(model, args.output_path) 129 | 130 | 131 | if __name__ == "__main__": 132 | parser = argparse.ArgumentParser(description="Train NanoLLaVA model") 133 | parser.add_argument( 134 | "--model-path", 135 | type=str, 136 | default="mlx-community/Qwen2-VL-2B-Instruct-bf16", 137 | help="Path to the pre-trained model", 138 | ) 139 | parser.add_argument( 140 | "--dataset", type=str, required=True, help="Path to the dataset" 141 | ) 142 | parser.add_argument( 143 | "--split", type=str, default="train", help="Split to use for training" 144 | ) 145 | parser.add_argument( 146 | "--image-resize-shape", 147 | type=int, 148 | nargs=2, 149 | default=None, 150 | help="Resize images to this shape", 151 | ) 152 | parser.add_argument( 153 | "--apply-chat-template", 154 | action="store_false", 155 | help="Apply chat template to the dataset", 156 | ) 157 | parser.add_argument( 158 | "--learning-rate", 159 | type=float, 160 | default=1e-4, 161 | help="Learning rate for the optimizer", 162 | ) 163 | 
parser.add_argument( 164 | "--batch-size", type=int, default=1, help="Batch size for training" 165 | ) 166 | parser.add_argument( 167 | "--epochs", type=int, default=1, help="Number of epochs to train" 168 | ) 169 | parser.add_argument( 170 | "--steps", type=int, default=0, help="Number of steps per epoch" 171 | ) 172 | parser.add_argument( 173 | "--print-every", type=int, default=10, help="Print loss every n steps" 174 | ) 175 | parser.add_argument( 176 | "--lora-alpha", type=int, default=0.1, help="LoRA alpha parameter" 177 | ) 178 | parser.add_argument("--lora-rank", type=int, default=10, help="LoRA rank") 179 | parser.add_argument("--lora-dropout", type=float, default=0.1, help="LoRA dropout") 180 | parser.add_argument( 181 | "--output-path", 182 | type=str, 183 | default="adapters", 184 | help="Path to save the trained adapter", 185 | ) 186 | parser.add_argument( 187 | "--adapter-path", 188 | type=str, 189 | default=None, 190 | help="Load path to resume training from a previously saved adapter", 191 | ) 192 | 193 | args = parser.parse_args() 194 | main(args) 195 | -------------------------------------------------------------------------------- /mlx_vlm/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blaizzy/mlx-vlm/09724c1453c2f6dfed8c975cf393dc4e8dc8955d/mlx_vlm/models/__init__.py -------------------------------------------------------------------------------- /mlx_vlm/models/aya_vision/__init__.py: -------------------------------------------------------------------------------- 1 | from .aya_vision import ( 2 | LanguageModel, 3 | Model, 4 | ModelConfig, 5 | TextConfig, 6 | VisionConfig, 7 | VisionModel, 8 | ) 9 | -------------------------------------------------------------------------------- /mlx_vlm/models/aya_vision/interpolate.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | import numpy as np 3 | 4 | 5 | def gaussian_blur_axis(image, sigma, axis): 6 | """ 7 | Applies a 1D Gaussian blur along the given axis. 8 | This version works for arrays with any number of dimensions. 9 | """ 10 | radius = int(3 * sigma) 11 | if radius < 1: 12 | return image 13 | x = mx.arange(-radius, radius + 1) 14 | kernel = mx.exp(-(x**2) / (2 * sigma**2)) 15 | kernel = kernel / mx.sum(kernel) 16 | 17 | # MLX doesn't have a direct apply_along_axis equivalent, 18 | # so we'll implement the convolution differently based on the axis 19 | 20 | # Helper function to apply 1D convolution along specific axis 21 | def conv_1d(array, kernel, axis): 22 | # Reshape kernel to broadcast along the right dimensions 23 | kernel_shape = [1] * image.ndim 24 | kernel_shape[axis] = len(kernel) 25 | kernel_reshaped = kernel.reshape(kernel_shape) 26 | 27 | # Pad the array 28 | pad_width = [(0, 0)] * image.ndim 29 | pad_width[axis] = (radius, radius) 30 | padded = mx.pad(array, pad_width, mode="edge") 31 | 32 | # Perform convolution via sliding window sum 33 | result = mx.zeros_like(array) 34 | slices = [slice(None)] * padded.ndim 35 | 36 | for i in range(2 * radius + 1): 37 | slices[axis] = slice(i, i + array.shape[axis]) 38 | result = result + padded[tuple(slices)] * kernel_reshaped 39 | 40 | return result 41 | 42 | return conv_1d(image, kernel, axis) 43 | 44 | 45 | def bilinear_interpolate(image, new_height, new_width, align_corners=False): 46 | """ 47 | Performs bilinear interpolation on an array whose spatial dimensions are the first two. 
48 | It supports extra dimensions (e.g. channels or batch dimensions that have been moved to the trailing axes). 49 | """ 50 | # image is assumed to have shape (H, W, ...) where H and W are spatial dimensions. 51 | H_in, W_in = image.shape[0], image.shape[1] 52 | 53 | # Compute sampling positions in the input image. 54 | if new_height == 1: 55 | row_positions = mx.array([0.0]) 56 | else: 57 | if align_corners: 58 | row_positions = mx.linspace(0, H_in - 1, new_height) 59 | else: 60 | row_positions = (mx.arange(new_height) + 0.5) * H_in / new_height - 0.5 61 | 62 | if new_width == 1: 63 | col_positions = mx.array([0.0]) 64 | else: 65 | if align_corners: 66 | col_positions = mx.linspace(0, W_in - 1, new_width) 67 | else: 68 | col_positions = (mx.arange(new_width) + 0.5) * W_in / new_width - 0.5 69 | 70 | # Compute floor and ceil indices. 71 | row_floor = mx.floor(row_positions).astype(mx.int32) 72 | col_floor = mx.floor(col_positions).astype(mx.int32) 73 | row_ceil = row_floor + 1 74 | col_ceil = col_floor + 1 75 | 76 | row_floor = mx.clip(row_floor, 0, H_in - 1) 77 | row_ceil = mx.clip(row_ceil, 0, H_in - 1) 78 | col_floor = mx.clip(col_floor, 0, W_in - 1) 79 | col_ceil = mx.clip(col_ceil, 0, W_in - 1) 80 | 81 | row_weight = row_positions - row_floor # shape (new_height,) 82 | col_weight = col_positions - col_floor # shape (new_width,) 83 | 84 | # Use advanced indexing for gather operations 85 | # Create meshgrid for coordinates 86 | row_floor_grid, col_floor_grid = mx.meshgrid(row_floor, col_floor, indexing="ij") 87 | row_ceil_grid, col_floor_grid = mx.meshgrid(row_ceil, col_floor, indexing="ij") 88 | row_floor_grid, col_ceil_grid = mx.meshgrid(row_floor, col_ceil, indexing="ij") 89 | row_ceil_grid, col_ceil_grid = mx.meshgrid(row_ceil, col_ceil, indexing="ij") 90 | 91 | # Gather the four surrounding pixels using take_along_axis 92 | # For higher dimensional arrays, we'll need to reshape and broadcast 93 | extra_dims = image.ndim - 2 94 | 95 | def gather_pixels(row_indices, col_indices): 96 | # Flatten the spatial dimensions for gathering 97 | flat_indices = row_indices * W_in + col_indices 98 | flat_image = mx.reshape(image, (-1,) + image.shape[2:]) 99 | # Gather and reshape back 100 | gathered = mx.take(flat_image, flat_indices.reshape(-1), axis=0) 101 | return mx.reshape(gathered, (new_height, new_width) + image.shape[2:]) 102 | 103 | top_left = gather_pixels(row_floor_grid, col_floor_grid) 104 | top_right = gather_pixels(row_floor_grid, col_ceil_grid) 105 | bottom_left = gather_pixels(row_ceil_grid, col_floor_grid) 106 | bottom_right = gather_pixels(row_ceil_grid, col_ceil_grid) 107 | 108 | # Expand the weights to have shape (new_height, new_width, *[1]*extra_dims) 109 | r_weight = row_weight.reshape(new_height, 1, *([1] * extra_dims)) 110 | c_weight = col_weight.reshape(1, new_width, *([1] * extra_dims)) 111 | 112 | # Perform bilinear interpolation. 113 | result = ( 114 | (1 - r_weight) * (1 - c_weight) * top_left 115 | + (1 - r_weight) * c_weight * top_right 116 | + r_weight * (1 - c_weight) * bottom_left 117 | + r_weight * c_weight * bottom_right 118 | ) 119 | return result 120 | 121 | 122 | def resize_bilinear(image, new_size, align_corners=False, antialias=True): 123 | """ 124 | Resizes an image (or embedding tensor) to new_size=(new_height, new_width) 125 | using bilinear interpolation with MLX. 
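For example (hypothetical shapes), a (1, 3, 32, 32) input resized with new_size=(64, 64) comes back with shape (1, 3, 64, 64).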
126 | 127 | Supports: 128 | - 2D: (H, W) 129 | - 3D: (H, W, C) 130 | - 4D: (B, C, H, W) (assumed for typical image batches) 131 | """ 132 | new_height, new_width = new_size 133 | 134 | # Convert numpy arrays to MLX arrays if needed 135 | if isinstance(image, np.ndarray): 136 | image = mx.array(image) 137 | 138 | if image.ndim == 2 or image.ndim == 3: 139 | # Assume spatial dims are the first two. 140 | resized = image 141 | H_in, W_in = image.shape[:2] 142 | if antialias: 143 | if new_height < H_in: 144 | scale_y = new_height / H_in 145 | sigma_y = (1 / scale_y - 1) / 2.0 # heuristic 146 | if sigma_y > 0: 147 | resized = gaussian_blur_axis(resized, sigma_y, axis=0) 148 | if new_width < W_in: 149 | scale_x = new_width / W_in 150 | sigma_x = (1 / scale_x - 1) / 2.0 151 | if sigma_x > 0: 152 | resized = gaussian_blur_axis(resized, sigma_x, axis=1) 153 | resized = bilinear_interpolate( 154 | resized, new_height, new_width, align_corners=align_corners 155 | ) 156 | return resized 157 | 158 | elif image.ndim == 4: 159 | # Assume shape is (B, C, H, W) (typical PyTorch/MLX format). 160 | B, C, H_in, W_in = image.shape 161 | # Permute to bring spatial dims to the front: (H, W, B, C) 162 | image_perm = mx.transpose(image, (2, 3, 0, 1)) 163 | resized = image_perm 164 | if antialias: 165 | if new_height < H_in: 166 | scale_y = new_height / H_in 167 | sigma_y = (1 / scale_y - 1) / 2.0 168 | if sigma_y > 0: 169 | resized = gaussian_blur_axis(resized, sigma_y, axis=0) 170 | if new_width < W_in: 171 | scale_x = new_width / W_in 172 | sigma_x = (1 / scale_x - 1) / 2.0 173 | if sigma_x > 0: 174 | resized = gaussian_blur_axis(resized, sigma_x, axis=1) 175 | resized = bilinear_interpolate( 176 | resized, new_height, new_width, align_corners=align_corners 177 | ) 178 | # Permute back to (B, C, new_height, new_width) 179 | resized = mx.transpose(resized, (2, 3, 0, 1)) 180 | return resized 181 | 182 | else: 183 | raise ValueError("Unsupported image dimensions.") 184 | 185 | 186 | # 187 | -------------------------------------------------------------------------------- /mlx_vlm/models/base.py: -------------------------------------------------------------------------------- 1 | import math 2 | from abc import ABC, abstractmethod 3 | from dataclasses import dataclass 4 | from typing import Any, Dict, List, Optional 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | from mlx_lm.models.base import create_attention_mask 9 | from mlx_lm.models.cache import RotatingKVCache 10 | from PIL import Image 11 | from transformers.image_processing_utils import BaseImageProcessor as ImageProcessor 12 | from transformers.image_processing_utils import get_size_dict 13 | from transformers.image_utils import ChannelDimension, PILImageResampling 14 | 15 | 16 | @dataclass 17 | class LanguageModelOutput: 18 | logits: mx.array 19 | cross_attention_states: Optional[List[mx.array]] = None 20 | encoder_outputs: Optional[List[mx.array]] = None 21 | 22 | 23 | def expand2square(pil_img, background_color): 24 | width, height = pil_img.size 25 | if width == height: 26 | return pil_img 27 | elif width > height: 28 | result = Image.new(pil_img.mode, (width, width), background_color) 29 | result.paste(pil_img, (0, (width - height) // 2)) 30 | return result 31 | else: 32 | result = Image.new(pil_img.mode, (height, height), background_color) 33 | result.paste(pil_img, ((height - width) // 2, 0)) 34 | return result 35 | 36 | 37 | class BaseImageProcessor(ImageProcessor): 38 | def __init__( 39 | self, 40 | image_mean=(0.5, 0.5, 0.5), 41 | 
image_std=(0.5, 0.5, 0.5), 42 | size=(384, 384), 43 | crop_size: Dict[str, int] = None, 44 | resample=PILImageResampling.BICUBIC, 45 | rescale_factor=1 / 255, 46 | data_format=ChannelDimension.FIRST, 47 | ): 48 | crop_size = ( 49 | crop_size if crop_size is not None else {"height": 384, "width": 384} 50 | ) 51 | crop_size = get_size_dict( 52 | crop_size, default_to_square=True, param_name="crop_size" 53 | ) 54 | 55 | self.image_mean = image_mean 56 | self.image_std = image_std 57 | self.size = size 58 | self.resample = resample 59 | self.rescale_factor = rescale_factor 60 | self.data_format = data_format 61 | self.crop_size = crop_size 62 | 63 | @abstractmethod 64 | def preprocess(self, images): 65 | pass 66 | 67 | 68 | # Add this code to visualize the chunked attention mask 69 | def visualize_attention_mask(mask): 70 | """Visualize attention mask with symbols for better readability.""" 71 | if mask is None: 72 | print("No mask") 73 | return 74 | 75 | seq_len = mask.shape[0] 76 | 77 | print(" ", end="") 78 | for i in range(seq_len): 79 | print(f"{i:2d} ", end="") 80 | print() 81 | 82 | for i in range(seq_len): 83 | print(f"Token {i:2d}: ", end="") 84 | for j in range(seq_len): 85 | if mask[i, j]: 86 | print(" ■ ", end="") 87 | else: 88 | print(" ⬚ ", end="") 89 | print() 90 | 91 | 92 | def check_activation_stats(name, tensor): 93 | """Helper function to check for anomalies and log stats.""" 94 | 95 | print(f"--- Activation Stats: {name} ---") 96 | # Check for NaNs/Infs 97 | has_nan = mx.isnan(tensor).any() 98 | has_inf = mx.isinf(tensor).any() 99 | if has_nan: 100 | print(f"WARNING: Found NaN in {name}") 101 | if has_inf: 102 | print(f"WARNING: Found Inf in {name}") 103 | 104 | # Calculate and print stats (ensure computation happens) 105 | min_val = mx.min(tensor).item() 106 | max_val = mx.max(tensor).item() 107 | mean_val = mx.mean(tensor).item() 108 | std_val = mx.std(tensor).item() 109 | print(f" Shape: {tensor.shape}") 110 | print(f" Min: {min_val:.4f}, Max: {max_val:.4f}") 111 | print(f" Mean: {mean_val:.4f}, Std: {std_val:.4f}") 112 | print("-" * (len(name) + 24)) 113 | 114 | 115 | def pixel_shuffle(input_tensor, shuffle_ratio): 116 | # input_tensor: [batch_size, num_patches, channels] 117 | batch_size, num_patches, channels = input_tensor.shape 118 | patch_size = int(math.sqrt(num_patches)) 119 | 120 | input_tensor = input_tensor.reshape(batch_size, patch_size, patch_size, -1) 121 | batch_size, height, width, channels = input_tensor.shape 122 | 123 | reshaped_tensor = input_tensor.reshape( 124 | batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio) 125 | ) 126 | reshaped_tensor = reshaped_tensor.transpose(0, 2, 1, 3) 127 | 128 | reshaped_tensor = reshaped_tensor.reshape( 129 | batch_size, 130 | int(height * shuffle_ratio), 131 | int(width * shuffle_ratio), 132 | int(channels / (shuffle_ratio**2)), 133 | ) 134 | reshaped_tensor = reshaped_tensor.transpose(0, 2, 1, 3) 135 | 136 | output_tensor = reshaped_tensor.reshape(batch_size, -1, reshaped_tensor.shape[-1]) 137 | return output_tensor 138 | 139 | 140 | def interpolate(pos_embed, size, mode="cubic", align_corners=False): 141 | """ 142 | MLX implementation of PyTorch's F.interpolate with bicubic mode 143 | 144 | Args: 145 | pos_embed: MLX array with shape [B, C, H_src, W_src] or [C, H_src, W_src] 146 | size: Tuple (H_dst, W_dst) - target size 147 | align_corners: Boolean - whether to align corners 148 | 149 | Returns: 150 | Interpolated array with shape [B, C, H_dst, W_dst] or [C, H_dst, W_dst] 151 | """ 152 | 
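    # Illustrative usage (hypothetical shapes): a position-embedding grid of shape
    # (1, 768, 16, 16) passed through interpolate(pos_embed, (24, 24)) comes back as
    # (1, 768, 24, 24); a 3D input of shape (768, 16, 16) comes back as (768, 24, 24).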
# Handle different input shapes 153 | input_dim = pos_embed.ndim 154 | original_shape = pos_embed.shape 155 | 156 | if input_dim == 3: 157 | # [C, H, W] -> [1, C, H, W] 158 | pos_embed = pos_embed.reshape(1, *original_shape) 159 | 160 | # Get source dimensions 161 | h_src, w_src = pos_embed.shape[-2:] 162 | h_dst, w_dst = size 163 | 164 | # Calculate scale factors 165 | scale_h = h_dst / h_src 166 | scale_w = w_dst / w_src 167 | 168 | # Create upsampler 169 | upsampler = nn.Upsample( 170 | scale_factor=(scale_h, scale_w), mode=mode, align_corners=align_corners 171 | ) 172 | 173 | # Apply upsampling 174 | result = upsampler(pos_embed) 175 | 176 | # Return in the original dimension format 177 | if input_dim == 3: 178 | return result.reshape(original_shape[0], *size) 179 | return result 180 | -------------------------------------------------------------------------------- /mlx_vlm/models/cache.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | from mlx_lm.models.cache import ChunkedKVCache, KVCache, RotatingKVCache, _BaseCache 3 | 4 | 5 | class SimpleKVCache: 6 | """A simple key-value cache for transformer attention layers. 7 | 8 | Stores and concatenates key/value tensors along sequence dimension. 9 | """ 10 | 11 | def __init__(self): 12 | self.keys = None 13 | self.values = None 14 | self.cache_length = 0 15 | 16 | def update_and_fetch(self, keys, values): 17 | """Update cache with new key/value tensors and return full cache. 18 | 19 | Args: 20 | keys: New key tensor to add [batch, heads, seq_len, head_dim] 21 | values: New value tensor to add [batch, heads, seq_len, head_dim] 22 | 23 | Returns: 24 | Tuple of (cached_keys, cached_values) containing full cache history 25 | """ 26 | if self.cache_length == 0: 27 | # First update - just store tensors 28 | self.keys = keys 29 | self.values = values 30 | else: 31 | # Concatenate with existing cache along sequence dimension 32 | self.keys = mx.concatenate([self.keys, keys], axis=2) 33 | self.values = mx.concatenate([self.values, values], axis=2) 34 | 35 | self.cache_length += keys.shape[2] 36 | return self.keys, self.values 37 | 38 | def fetch(self): 39 | return self.keys, self.values 40 | 41 | def update(self, keys, values): 42 | """Update cache with new key/value tensors without returning. 
43 | 44 | Args: 45 | keys: New key tensor to store 46 | values: New value tensor to store 47 | """ 48 | self.keys = keys 49 | self.values = values 50 | self.cache_length += keys.shape[2] 51 | -------------------------------------------------------------------------------- /mlx_vlm/models/deepseek_vl_v2/__init__.py: -------------------------------------------------------------------------------- 1 | from .deepseek_vl_v2 import ( 2 | DeepseekVLV2Processor, 3 | LanguageModel, 4 | Model, 5 | ModelConfig, 6 | ProjectorConfig, 7 | TextConfig, 8 | VisionConfig, 9 | VisionModel, 10 | ) 11 | -------------------------------------------------------------------------------- /mlx_vlm/models/florence2/__init__.py: -------------------------------------------------------------------------------- 1 | from .florence2 import ( 2 | LanguageModel, 3 | Model, 4 | ModelConfig, 5 | TextConfig, 6 | VisionConfig, 7 | VisionModel, 8 | ) 9 | -------------------------------------------------------------------------------- /mlx_vlm/models/gemma3/__init__.py: -------------------------------------------------------------------------------- 1 | from .gemma3 import ( 2 | LanguageModel, 3 | Model, 4 | ModelConfig, 5 | TextConfig, 6 | VisionConfig, 7 | VisionModel, 8 | ) 9 | -------------------------------------------------------------------------------- /mlx_vlm/models/idefics2/__init__.py: -------------------------------------------------------------------------------- 1 | from .idefics2 import ( 2 | LanguageModel, 3 | Model, 4 | ModelConfig, 5 | PerceiverConfig, 6 | TextConfig, 7 | VisionConfig, 8 | VisionModel, 9 | ) 10 | -------------------------------------------------------------------------------- /mlx_vlm/models/idefics2/language.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from dataclasses import dataclass 3 | from typing import Optional 4 | 5 | import mlx.core as mx 6 | import mlx.nn as nn 7 | 8 | from ..base import LanguageModelOutput, create_attention_mask 9 | from ..cache import KVCache 10 | 11 | 12 | @dataclass 13 | class TextConfig: 14 | model_type: str 15 | hidden_size: int 16 | num_hidden_layers: int 17 | intermediate_size: int 18 | num_attention_heads: int 19 | rms_norm_eps: float 20 | vocab_size: int 21 | num_key_value_heads: int 22 | rope_theta: float = 1000000.0 23 | rope_traditional: bool = False 24 | max_position_embeddings: int = 4096 25 | tie_word_embeddings: bool = False 26 | 27 | @classmethod 28 | def from_dict(cls, params): 29 | return cls( 30 | **{ 31 | k: v 32 | for k, v in params.items() 33 | if k in inspect.signature(cls).parameters 34 | } 35 | ) 36 | 37 | def __post_init__(self): 38 | if self.num_key_value_heads is None: 39 | self.num_key_value_heads = self.num_attention_heads 40 | 41 | 42 | class Attention(nn.Module): 43 | def __init__(self, config: TextConfig): 44 | super().__init__() 45 | 46 | dim = config.hidden_size 47 | self.n_heads = n_heads = config.num_attention_heads 48 | self.n_kv_heads = n_kv_heads = config.num_key_value_heads 49 | 50 | head_dim = config.hidden_size // n_heads 51 | self.scale = head_dim**-0.5 52 | 53 | self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) 54 | self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) 55 | self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) 56 | self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) 57 | 58 | self.rope = nn.RoPE( 59 | head_dim, 60 | traditional=config.rope_traditional, 61 | base=config.rope_theta, 62 | ) 63 | 64 | def 
__call__( 65 | self, 66 | x: mx.array, 67 | mask: Optional[mx.array] = None, 68 | cache: Optional[KVCache] = None, 69 | ) -> mx.array: 70 | B, L, D = x.shape 71 | 72 | queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) 73 | 74 | # Prepare the queries, keys and values for the attention computation 75 | queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) 76 | keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 77 | values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 78 | 79 | if cache is not None: 80 | queries = self.rope(queries, offset=cache.offset) 81 | keys = self.rope(keys, offset=cache.offset) 82 | keys, values = cache.update_and_fetch(keys, values) 83 | else: 84 | queries = self.rope(queries) 85 | keys = self.rope(keys) 86 | 87 | output = mx.fast.scaled_dot_product_attention( 88 | queries, keys, values, scale=self.scale, mask=mask 89 | ) 90 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 91 | return self.o_proj(output) 92 | 93 | 94 | class MLP(nn.Module): 95 | def __init__(self, dim, hidden_dim): 96 | super().__init__() 97 | self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) 98 | self.down_proj = nn.Linear(hidden_dim, dim, bias=False) 99 | self.up_proj = nn.Linear(dim, hidden_dim, bias=False) 100 | 101 | def __call__(self, x) -> mx.array: 102 | return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) 103 | 104 | 105 | class TransformerBlock(nn.Module): 106 | def __init__(self, config: TextConfig): 107 | super().__init__() 108 | self.num_attention_heads = config.num_attention_heads 109 | self.hidden_size = config.hidden_size 110 | self.self_attn = Attention(config) 111 | self.mlp = MLP(config.hidden_size, config.intermediate_size) 112 | self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 113 | self.post_attention_layernorm = nn.RMSNorm( 114 | config.hidden_size, eps=config.rms_norm_eps 115 | ) 116 | self.config = config 117 | 118 | def __call__( 119 | self, 120 | x: mx.array, 121 | mask: Optional[mx.array] = None, 122 | cache: Optional[KVCache] = None, 123 | ) -> mx.array: 124 | r = self.self_attn(self.input_layernorm(x), mask, cache) 125 | h = x + r 126 | r = self.mlp(self.post_attention_layernorm(h)) 127 | out = h + r 128 | return out 129 | 130 | 131 | class LanguageModel(nn.Module): 132 | def __init__(self, config: TextConfig): 133 | super().__init__() 134 | self.config = config 135 | self.model_type = config.model_type 136 | self.vocab_size = config.vocab_size 137 | self.num_hidden_layers = config.num_hidden_layers 138 | assert self.vocab_size > 0 139 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) 140 | self.layers = [ 141 | TransformerBlock(config=config) for _ in range(config.num_hidden_layers) 142 | ] 143 | self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 144 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 145 | 146 | def __call__( 147 | self, 148 | inputs: mx.array, 149 | inputs_embeds: Optional[mx.array] = None, 150 | mask: Optional[mx.array] = None, 151 | cache=None, 152 | ): 153 | # for passing merged input embeddings 154 | if inputs_embeds is None: 155 | h = self.embed_tokens(inputs) 156 | else: 157 | h = inputs_embeds 158 | 159 | if cache is None: 160 | cache = [None] * len(self.layers) 161 | 162 | if mask is None: 163 | mask = create_attention_mask(h, cache) 164 | 165 | for layer, c in zip(self.layers, cache): 166 | h = layer(h, mask, c) 167 | 168 | logits = 
self.lm_head(self.norm(h)) 169 | return LanguageModelOutput(logits=logits) 170 | 171 | def sanitize(self, weights): 172 | # Remove unused precomputed rotary freqs 173 | return { 174 | k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k 175 | } 176 | 177 | @property 178 | def layers(self): 179 | return self.model.layers 180 | 181 | @property 182 | def head_dim(self): 183 | return self.config.hidden_size // self.config.num_attention_heads 184 | 185 | @property 186 | def n_kv_heads(self): 187 | return self.config.num_key_value_heads 188 | -------------------------------------------------------------------------------- /mlx_vlm/models/idefics3/__init__.py: -------------------------------------------------------------------------------- 1 | from .idefics3 import ( 2 | LanguageModel, 3 | Model, 4 | ModelConfig, 5 | TextConfig, 6 | VisionConfig, 7 | VisionModel, 8 | ) 9 | -------------------------------------------------------------------------------- /mlx_vlm/models/idefics3/language.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import re 3 | from dataclasses import dataclass 4 | from typing import Dict, Optional, Tuple, Union 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | 9 | from ..base import LanguageModelOutput, create_attention_mask 10 | from ..cache import KVCache 11 | 12 | 13 | @dataclass 14 | class TextConfig: 15 | model_type: str 16 | hidden_size: int 17 | intermediate_size: int 18 | num_attention_heads: int 19 | rms_norm_eps: float 20 | vocab_size: int 21 | num_key_value_heads: int 22 | rope_theta: float = 1000000.0 23 | num_hidden_layers: int = 32 24 | rope_traditional: bool = False 25 | max_position_embeddings: int = 4096 26 | tie_word_embeddings: bool = False 27 | 28 | @classmethod 29 | def from_dict(cls, params): 30 | return cls( 31 | **{ 32 | k: v 33 | for k, v in params.items() 34 | if k in inspect.signature(cls).parameters 35 | } 36 | ) 37 | 38 | def __post_init__(self): 39 | if self.num_key_value_heads is None: 40 | self.num_key_value_heads = self.num_attention_heads 41 | 42 | 43 | class Attention(nn.Module): 44 | def __init__(self, config: TextConfig): 45 | super().__init__() 46 | 47 | dim = config.hidden_size 48 | self.n_heads = n_heads = config.num_attention_heads 49 | self.n_kv_heads = n_kv_heads = config.num_key_value_heads 50 | 51 | head_dim = config.hidden_size // n_heads 52 | self.scale = head_dim**-0.5 53 | 54 | self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) 55 | self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) 56 | self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) 57 | self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) 58 | 59 | self.rope = nn.RoPE( 60 | head_dim, 61 | traditional=config.rope_traditional, 62 | base=config.rope_theta, 63 | ) 64 | 65 | def __call__( 66 | self, 67 | x: mx.array, 68 | mask: Optional[mx.array] = None, 69 | cache: Optional[KVCache] = None, 70 | ) -> mx.array: 71 | B, L, D = x.shape 72 | 73 | queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) 74 | 75 | # Prepare the queries, keys and values for the attention computation 76 | queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) 77 | keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 78 | values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 79 | 80 | if cache is not None: 81 | queries = self.rope(queries, offset=cache.offset) 82 | keys = self.rope(keys, 
offset=cache.offset) 83 | keys, values = cache.update_and_fetch(keys, values) 84 | else: 85 | queries = self.rope(queries) 86 | keys = self.rope(keys) 87 | 88 | output = mx.fast.scaled_dot_product_attention( 89 | queries, keys, values, scale=self.scale, mask=mask 90 | ) 91 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 92 | return self.o_proj(output) 93 | 94 | 95 | class MLP(nn.Module): 96 | def __init__(self, dim, hidden_dim): 97 | super().__init__() 98 | self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) 99 | self.down_proj = nn.Linear(hidden_dim, dim, bias=False) 100 | self.up_proj = nn.Linear(dim, hidden_dim, bias=False) 101 | 102 | def __call__(self, x) -> mx.array: 103 | return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) 104 | 105 | 106 | class TransformerBlock(nn.Module): 107 | def __init__(self, config: TextConfig): 108 | super().__init__() 109 | self.num_attention_heads = config.num_attention_heads 110 | self.hidden_size = config.hidden_size 111 | self.self_attn = Attention(config) 112 | self.mlp = MLP(config.hidden_size, config.intermediate_size) 113 | self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 114 | self.post_attention_layernorm = nn.RMSNorm( 115 | config.hidden_size, eps=config.rms_norm_eps 116 | ) 117 | self.config = config 118 | 119 | def __call__( 120 | self, 121 | x: mx.array, 122 | mask: Optional[mx.array] = None, 123 | cache: Optional[KVCache] = None, 124 | ) -> mx.array: 125 | r = self.self_attn(self.input_layernorm(x), mask, cache) 126 | h = x + r 127 | r = self.mlp(self.post_attention_layernorm(h)) 128 | out = h + r 129 | return out 130 | 131 | 132 | class LanguageModel(nn.Module): 133 | def __init__(self, config: TextConfig): 134 | super().__init__() 135 | self.config = config 136 | self.model_type = config.model_type 137 | self.vocab_size = config.vocab_size 138 | self.num_hidden_layers = config.num_hidden_layers 139 | assert self.vocab_size > 0 140 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) 141 | self.layers = [ 142 | TransformerBlock(config=config) for _ in range(config.num_hidden_layers) 143 | ] 144 | self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 145 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 146 | 147 | def __call__( 148 | self, 149 | inputs: mx.array, 150 | inputs_embeds: Optional[mx.array] = None, 151 | mask: Optional[mx.array] = None, 152 | cache=None, 153 | ): 154 | # for passing merged input embeddings 155 | if inputs_embeds is None: 156 | h = self.embed_tokens(inputs) 157 | else: 158 | h = inputs_embeds.astype(self.norm.weight.dtype) 159 | 160 | if cache is None: 161 | cache = [None] * len(self.layers) 162 | 163 | if mask is None: 164 | mask = create_attention_mask(h, cache) 165 | 166 | for layer, c in zip(self.layers, cache): 167 | h = layer(h, mask, c) 168 | 169 | logits = self.lm_head(self.norm(h)) 170 | return LanguageModelOutput(logits=logits) 171 | 172 | def sanitize(self, weights): 173 | # Remove unused precomputed rotary freqs 174 | return { 175 | k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k 176 | } 177 | 178 | @property 179 | def layers(self): 180 | return self.model.layers 181 | 182 | @property 183 | def head_dim(self): 184 | return self.config.hidden_size // self.config.num_attention_heads 185 | 186 | @property 187 | def n_kv_heads(self): 188 | return self.config.num_key_value_heads 189 | 
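# A hedged sketch (illustration only) of how the per-layer KV cache above is typically
# driven during decoding; names like `model` and `prompt_ids` are assumptions, not part of this file.
#
#   from mlx_vlm.models.cache import KVCache
#   caches = [KVCache() for _ in range(model.num_hidden_layers)]
#   out = model(prompt_ids, cache=caches)                      # prefill: the whole prompt is cached
#   next_id = mx.argmax(out.logits[:, -1, :], axis=-1)[:, None]
#   out = model(next_id, cache=caches)                         # decode step: only the new token;
#                                                              # RoPE reads its position from cache.offset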
-------------------------------------------------------------------------------- /mlx_vlm/models/internvl_chat/__init__.py: -------------------------------------------------------------------------------- 1 | from .internvl_chat import ( 2 | LanguageModel, 3 | Model, 4 | ModelConfig, 5 | TextConfig, 6 | VisionConfig, 7 | VisionModel, 8 | ) 9 | from .processor import InternVLChatProcessor, InternVLImageProcessor 10 | -------------------------------------------------------------------------------- /mlx_vlm/models/internvl_chat/internvl_chat.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import inspect 3 | import json 4 | from dataclasses import dataclass 5 | from pathlib import Path 6 | from typing import List, Optional 7 | 8 | import mlx.core as mx 9 | import mlx.nn as nn 10 | import numpy as np 11 | from huggingface_hub import snapshot_download 12 | 13 | from ..base import pixel_shuffle 14 | from .language import LanguageModel, TextConfig 15 | from .vision import VisionConfig, VisionModel 16 | 17 | 18 | @dataclass 19 | class ModelConfig: 20 | text_config: TextConfig 21 | vision_config: VisionConfig 22 | model_type: str 23 | ignore_index: int = -100 24 | image_token_index: int = 151667 25 | video_token_index: int = 151656 26 | vision_feature_select_strategy: str = "default" 27 | vision_feature_layer: int = -1 28 | vocab_size: int = 32000 29 | downsample_ratio: float = 0.5 30 | eos_token_id: Optional[List[int]] = None 31 | 32 | @classmethod 33 | def from_dict(cls, params): 34 | return cls( 35 | **{ 36 | k: v 37 | for k, v in params.items() 38 | if k in inspect.signature(cls).parameters 39 | } 40 | ) 41 | 42 | 43 | class Model(nn.Module): 44 | def __init__(self, config: ModelConfig): 45 | super().__init__() 46 | self.config = config 47 | self.vision_model = VisionModel(config.vision_config) 48 | self.language_model = LanguageModel(config.text_config) 49 | 50 | self.downsample_ratio = config.downsample_ratio 51 | 52 | vit_hidden_size = self.config.vision_config.hidden_size 53 | llm_hidden_size = self.config.text_config.hidden_size 54 | 55 | self.mlp1 = [ 56 | nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2), 57 | nn.Linear( 58 | vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size 59 | ), 60 | nn.GELU(), 61 | nn.Linear(llm_hidden_size, llm_hidden_size), 62 | ] 63 | 64 | def get_input_embeddings( 65 | self, 66 | input_ids: Optional[mx.array] = None, 67 | pixel_values: Optional[mx.array] = None, 68 | ): 69 | 70 | if pixel_values is None: 71 | return self.language_model.model.embed_tokens(input_ids) 72 | 73 | dtype = self.vision_model.embeddings.patch_embedding.weight.dtype 74 | pixel_values = pixel_values.astype(dtype) 75 | 76 | # TODO: Remove this after transformers implementation is merged 77 | if pixel_values.ndim == 5: 78 | pixel_values = pixel_values[0] 79 | 80 | # Get the input embeddings from the language model 81 | inputs_embeds = self.language_model.model.embed_tokens(input_ids) 82 | 83 | # Get the ouptut hidden states from the vision model 84 | hidden_states, _, _ = self.vision_model( 85 | pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True 86 | ) 87 | 88 | # Extract vision embeddings, removing the class token (first token) 89 | hidden_states = hidden_states[:, 1:, :] 90 | 91 | # Apply pixel shuffle with downsampling 92 | hidden_states = pixel_shuffle( 93 | hidden_states, shuffle_ratio=self.downsample_ratio 94 | ) 95 | 96 | # Apply MLP transformation 97 | for layer in self.mlp1: 98 
| hidden_states = layer(hidden_states) 99 | 100 | # Insert special image tokens in the input_ids 101 | final_inputs_embeds = self._merge_input_ids_with_image_features( 102 | hidden_states, inputs_embeds, input_ids 103 | ) 104 | return final_inputs_embeds 105 | 106 | def _merge_input_ids_with_image_features( 107 | self, image_features, inputs_embeds, input_ids 108 | ): 109 | B, N, C = inputs_embeds.shape 110 | image_token_index = self.config.image_token_index 111 | video_token_index = self.config.video_token_index 112 | 113 | # Positions of