├── mlx_vlm ├── models │ ├── __init__.py │ ├── llava │ │ ├── __init__.py │ │ ├── README.md │ │ ├── llava.py │ │ ├── language.py │ │ └── vision.py │ ├── phi3_v │ │ ├── __init__.py │ │ ├── language.py │ │ ├── su_rope.py │ │ ├── phi3_v.py │ │ └── vision.py │ ├── paligemma │ │ ├── __init__.py │ │ ├── language.py │ │ ├── paligemma.py │ │ └── vision.py │ ├── llava_next │ │ ├── __init__.py │ │ ├── llava_next.py │ │ ├── language.py │ │ └── vision.py │ ├── idefics2 │ │ ├── __init__.py │ │ ├── language.py │ │ ├── vision.py │ │ └── idefics2.py │ ├── nanoLlava │ │ ├── __init__.py │ │ ├── language.py │ │ ├── nanoLlava.py │ │ └── vision.py │ ├── multi_modality │ │ ├── __init__.py │ │ ├── language.py │ │ └── multi_modality.py │ └── base.py ├── version.py ├── __init__.py ├── prompt_utils.py ├── sample_utils.py ├── convert.py ├── generate.py ├── chat_ui.py └── tokenizer_utils.py ├── MANIFEST.in ├── requirements.txt ├── .pre-commit-config.yaml ├── .gitignore ├── .github └── workflows │ ├── tests.yml │ └── python-publish.yml ├── README.md ├── LICENSE ├── setup.py └── CONTRIBUTING.md /mlx_vlm/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mlx_vlm/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.10" 2 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include ./requirements.txt 2 | recursive-include mlx_vlm/ *.py 3 | -------------------------------------------------------------------------------- /mlx_vlm/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import convert, generate, load 2 | from .version import __version__ 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | mlx>=0.14 2 | numpy 3 | transformers>=4.39.3 4 | scipy==1.13.1 5 | gradio 6 | Pillow 7 | requests 8 | -------------------------------------------------------------------------------- /mlx_vlm/models/llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .llava import ( 2 | LanguageModel, 3 | Model, 4 | ModelConfig, 5 | TextConfig, 6 | VisionConfig, 7 | VisionModel, 8 | ) 9 | -------------------------------------------------------------------------------- /mlx_vlm/models/phi3_v/__init__.py: -------------------------------------------------------------------------------- 1 | from .phi3_v import ( 2 | LanguageModel, 3 | Model, 4 | ModelConfig, 5 | TextConfig, 6 | VisionConfig, 7 | VisionModel, 8 | ) 9 | -------------------------------------------------------------------------------- /mlx_vlm/models/paligemma/__init__.py: -------------------------------------------------------------------------------- 1 | from .paligemma import ( 2 | LanguageModel, 3 | Model, 4 | ModelConfig, 5 | TextConfig, 6 | VisionConfig, 7 | VisionModel, 8 | ) 9 | -------------------------------------------------------------------------------- /mlx_vlm/models/llava_next/__init__.py: -------------------------------------------------------------------------------- 1 | from .llava_next import ( 2 | LanguageModel, 3 | Model, 4 | ModelConfig, 5 | TextConfig, 6 | VisionConfig, 7 | VisionModel, 8 | ) 9 | 
-------------------------------------------------------------------------------- /mlx_vlm/models/idefics2/__init__.py: -------------------------------------------------------------------------------- 1 | from .idefics2 import ( 2 | LanguageModel, 3 | Model, 4 | ModelConfig, 5 | PerceiverConfig, 6 | TextConfig, 7 | VisionConfig, 8 | VisionModel, 9 | ) 10 | -------------------------------------------------------------------------------- /mlx_vlm/models/nanoLlava/__init__.py: -------------------------------------------------------------------------------- 1 | from .nanoLlava import ( 2 | ImageProcessor, 3 | LanguageModel, 4 | Model, 5 | ModelConfig, 6 | TextConfig, 7 | VisionConfig, 8 | VisionModel, 9 | ) 10 | -------------------------------------------------------------------------------- /mlx_vlm/models/multi_modality/__init__.py: -------------------------------------------------------------------------------- 1 | from .multi_modality import ( 2 | AlignerConfig, 3 | ImageProcessor, 4 | LanguageModel, 5 | Model, 6 | ModelConfig, 7 | TextConfig, 8 | VisionConfig, 9 | VisionModel, 10 | ) 11 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black-pre-commit-mirror 3 | rev: 24.2.0 4 | hooks: 5 | - id: black 6 | - repo: https://github.com/pycqa/isort 7 | rev: 5.13.2 8 | hooks: 9 | - id: isort 10 | args: 11 | - --profile=black 12 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **.ipynb 2 | **.pyc 3 | 4 | Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | 8 | # Distribution / packaging 9 | bin/ 10 | build/ 11 | develop-eggs/ 12 | dist/ 13 | eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | *.egg-info/ 20 | .installed.cfg 21 | *.egg 22 | .DS_Store 23 | -------------------------------------------------------------------------------- /mlx_vlm/models/phi3_v/language.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from dataclasses import dataclass 3 | 4 | 5 | @dataclass 6 | class TextConfig: 7 | @classmethod 8 | def from_dict(cls, params): 9 | return cls( 10 | **{ 11 | k: v 12 | for k, v in params.items() 13 | if k in inspect.signature(cls).parameters 14 | } 15 | ) 16 | 17 | 18 | class LanguageModel: 19 | pass 20 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Test PRs 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | test: 10 | runs-on: macos-14 11 | 12 | steps: 13 | - name: Checkout code 14 | uses: actions/checkout@v2 15 | 16 | - name: Set up Python 17 | run: | 18 | brew install python@3.10 19 | python3 -m venv env 20 | source env/bin/activate 21 | 22 | 23 | - name: Run style checks 24 | run: | 25 | pip install pre-commit 26 | pre-commit run --all 27 | if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi 28 | 29 | - name: Install dependencies 30 | run: | 31 | pip install pytest 32 | pip install -e . 
33 | 34 | - name: Run Python tests 35 | run: | 36 | cd mlx_vlm/ 37 | pytest -s ./tests 38 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MLX-VLM 2 | 3 | MLX-VLM is a package for running Vision LLMs on your Mac using MLX. 4 | 5 | 6 | ## Get started 7 | 8 | The easiest way to get started is to install the `mlx-vlm` package: 9 | 10 | **With `pip`**: 11 | 12 | ```sh 13 | pip install mlx-vlm 14 | ``` 15 | 16 | ## Inference 17 | 18 | **CLI** 19 | ```sh 20 | python -m mlx_vlm.generate --model qnguyen3/nanoLLaVA --max-tokens 100 --temp 0.0 21 | ``` 22 | 23 | **Chat UI with Gradio** 24 | ```sh 25 | python -m mlx_vlm.chat_ui --model qnguyen3/nanoLLaVA 26 | ``` 27 | 28 | **Script** 29 | ```python 30 | import mlx.core as mx 31 | from mlx_vlm import load, generate 32 | 33 | model_path = "mlx-community/llava-1.5-7b-4bit" 34 | model, processor = load(model_path) 35 | 36 | prompt = processor.tokenizer.apply_chat_template( 37 | [{"role": "user", "content": f"<image>\nWhat are these?"}], 38 | tokenize=False, 39 | add_generation_prompt=True, 40 | ) 41 | 42 | output = generate(model, processor, "http://images.cocodataset.org/val2017/000000039769.jpg", prompt, verbose=False) 43 | ``` 44 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [published] 9 | 10 | permissions: 11 | contents: read 12 | 13 | jobs: 14 | deploy: 15 | 16 | runs-on: ubuntu-latest 17 | 18 | steps: 19 | - uses: actions/checkout@v3 20 | - name: Set up Python 21 | uses: actions/setup-python@v3 22 | with: 23 | python-version: '3.10' 24 | - name: Install dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | pip install build 28 | - name: Build package 29 | run: python -m build 30 | - name: Publish package 31 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 32 | with: 33 | user: __token__ 34 | password: ${{ secrets.PYPI_API_TOKEN }} 35 | packages_dir: dist 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright © 2023 Apple Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /mlx_vlm/prompt_utils.py: -------------------------------------------------------------------------------- 1 | def get_message_json(model_name, prompt): 2 | """ 3 | Get the appropriate JSON message based on the specified model. 4 | 5 | Args: 6 | model_name (str): The model type for which to generate the message. Options: 'idefics2', 'llava', 'llava_next', 'llava-qwen2', 'phi3_v', 'multi_modality', 'paligemma'. 7 | prompt (str): The text prompt to be included in the message. 8 | *args: Additional positional arguments (unused). 9 | **kwargs: Additional keyword arguments (unused). 10 | 11 | Returns: 12 | dict: A dictionary representing the JSON message for the specified model. 13 | """ 14 | if model_name.lower() == "idefics2": 15 | message = { 16 | "role": "user", 17 | "content": [{"type": "image"}, {"type": "text", "text": prompt}], 18 | } 19 | elif model_name.lower() in ["llava-qwen2", "llava", "llava_next"]: 20 | message = {"role": "user", "content": f"<image>\n{prompt}"} 21 | elif model_name.lower() == "phi3_v": 22 | message = {"role": "user", "content": f"<|image_1|>\n{prompt}"} 23 | elif model_name.lower() == "multi_modality": 24 | message = {"role": "user", "content": f"<image_placeholder>{prompt}"} 25 | elif model_name.lower() == "paligemma": 26 | message = prompt 27 | else: 28 | raise ValueError(f"Unsupported model: {model_name}") 29 | 30 | return message 31 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | from setuptools import find_packages, setup 5 | 6 | # Get the project root directory 7 | root_dir = Path(__file__).parent 8 | 9 | # Add the package directory to the Python path 10 | package_dir = root_dir / "mlx_vlm" 11 | sys.path.append(str(package_dir)) 12 | 13 | # Read the requirements from the requirements.txt file 14 | requirements_path = root_dir / "requirements.txt" 15 | with open(requirements_path) as fid: 16 | requirements = [l.strip() for l in fid.readlines()] 17 | 18 | # Import the version from the package 19 | from version import __version__ 20 | 21 | # Setup configuration 22 | setup( 23 | name="mlx-vlm", 24 | version=__version__, 25 | description="Vision LLMs on Apple silicon with MLX and the Hugging Face Hub", 26 | long_description=open(root_dir / "README.md", encoding="utf-8").read(), 27 | long_description_content_type="text/markdown", 28 | author_email="prince.gdt@gmail.com", 29 | author="Prince Canuma", 30 | url="https://github.com/Blaizzy/mlx-vlm", 31 | license="MIT", 32 | install_requires=requirements, 33 | packages=find_packages(where=root_dir), 34 | python_requires=">=3.8", 35 | entry_points={ 36 | "console_scripts": [ 37 | "mlx_vlm.convert = mlx_vlm.convert:main", 38 | "mlx_vlm.generate = mlx_vlm.generate:main", 39 | ] 40 | }, 41 | ) 42 | -------------------------------------------------------------------------------- /mlx_vlm/sample_utils.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | 3 | 4 | def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array: 5 | """ 6 | Apply top-p (nucleus) sampling to logits.
7 | 8 | Args: 9 | logits: The logits from the model's output. 10 | top_p: The cumulative probability threshold for top-p filtering. 11 | temperature: Temperature parameter for softmax distribution reshaping. 12 | Returns: 13 | token selected based on the top-p criterion. 14 | """ 15 | if ( 16 | logits.dtype == mx.bfloat16 17 | ): # workaround for unable to load kernel contiguous_scan_inclusive_sum_bfloat16_bfloat16 18 | logits = logits.astype(mx.float32) 19 | 20 | # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460 21 | probs = mx.softmax(logits / temperature, axis=-1) 22 | 23 | # sort probs in ascending order 24 | sorted_indices = mx.argsort(probs, axis=-1) 25 | sorted_probs = probs[..., sorted_indices.squeeze(0)] 26 | 27 | cumulative_probs = mx.cumsum(sorted_probs, axis=-1) 28 | 29 | # select tokens with cumulative probs below threshold 30 | top_probs = mx.where( 31 | cumulative_probs > 1 - top_p, 32 | sorted_probs, 33 | mx.zeros_like(sorted_probs), 34 | ) 35 | 36 | sorted_token = mx.random.categorical(mx.log(top_probs)) 37 | token = sorted_indices.squeeze(0)[sorted_token] 38 | 39 | return token 40 | -------------------------------------------------------------------------------- /mlx_vlm/models/llava/README.md: -------------------------------------------------------------------------------- 1 | # LLaVA 2 | 3 | An example of LLaVA: Large Language and Vision Assistant in MLX.[^1] LLaVA is 4 | a multimodal model that can generate text given combined image and text inputs. 5 | 6 | ## Setup 7 | 8 | Install the dependencies: 9 | 10 | ```bash 11 | pip install -r requirements.txt 12 | ``` 13 | 14 | ## Run 15 | 16 | You can use LLaVA to ask questions about images. 17 | 18 | For example, using the command line: 19 | 20 | ```bash 21 | python generate.py \ 22 | --model llava-hf/llava-1.5-7b-hf \ 23 | --image "http://images.cocodataset.org/val2017/000000039769.jpg" \ 24 | --prompt "USER: <image>\nWhat are these?\nASSISTANT:" \ 25 | --max-tokens 128 \ 26 | --temp 0 27 | ``` 28 | 29 | This uses the following image: 30 | 31 | ![alt text](http://images.cocodataset.org/val2017/000000039769.jpg) 32 | 33 | And generates the output: 34 | 35 | ``` 36 | These are two cats lying on a pink couch. 37 | ``` 38 | 39 | You can also use LLaVA in Python: 40 | 41 | ```python 42 | from generate import load_model, prepare_inputs, generate_text 43 | 44 | processor, model = load_model("llava-hf/llava-1.5-7b-hf") 45 | 46 | max_tokens, temperature = 128, 0.0 47 | 48 | prompt = "USER: <image>\nWhat are these?\nASSISTANT:" 49 | image = "http://images.cocodataset.org/val2017/000000039769.jpg" 50 | input_ids, pixel_values = prepare_inputs(processor, image, prompt) 51 | 52 | reply = generate_text( 53 | input_ids, pixel_values, model, processor, max_tokens, temperature 54 | ) 55 | 56 | print(reply) 57 | ``` 58 | 59 | [^1]: 60 | Refer to [LLaVA project webpage](https://llava-vl.github.io/) for more 61 | information.
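The `generate.py` interface shown above does not match `mlx_vlm/generate.py` in this package (which exposes a CLI rather than `load_model`/`prepare_inputs`/`generate_text`). If you are using the installed `mlx-vlm` package instead, a minimal CLI sketch — assuming the quantized community weights used in the top-level README — is:

```bash
python -m mlx_vlm.generate \
  --model mlx-community/llava-1.5-7b-4bit \
  --image "http://images.cocodataset.org/val2017/000000039769.jpg" \
  --prompt "What are these?" \
  --max-tokens 128 \
  --temp 0.0
```

Here the prompt is plain text: `mlx_vlm.generate` wraps it with the model's chat template (including the `<image>` token) via `get_message_json` before generation.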
62 | -------------------------------------------------------------------------------- /mlx_vlm/models/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict 3 | 4 | from PIL import Image 5 | from transformers.image_processing_utils import get_size_dict 6 | from transformers.image_utils import ChannelDimension, PILImageResampling 7 | 8 | 9 | def expand2square(pil_img, background_color): 10 | width, height = pil_img.size 11 | if width == height: 12 | return pil_img 13 | elif width > height: 14 | result = Image.new(pil_img.mode, (width, width), background_color) 15 | result.paste(pil_img, (0, (width - height) // 2)) 16 | return result 17 | else: 18 | result = Image.new(pil_img.mode, (height, height), background_color) 19 | result.paste(pil_img, ((height - width) // 2, 0)) 20 | return result 21 | 22 | 23 | class BaseImageProcessor(ABC): 24 | def __init__( 25 | self, 26 | image_mean=(0.5, 0.5, 0.5), 27 | image_std=(0.5, 0.5, 0.5), 28 | size=(384, 384), 29 | crop_size: Dict[str, int] = None, 30 | resample=PILImageResampling.BICUBIC, 31 | rescale_factor=1 / 255, 32 | data_format=ChannelDimension.FIRST, 33 | ): 34 | crop_size = ( 35 | crop_size if crop_size is not None else {"height": 384, "width": 384} 36 | ) 37 | crop_size = get_size_dict( 38 | crop_size, default_to_square=True, param_name="crop_size" 39 | ) 40 | 41 | self.image_mean = image_mean 42 | self.image_std = image_std 43 | self.size = size 44 | self.resample = resample 45 | self.rescale_factor = rescale_factor 46 | self.data_format = data_format 47 | self.crop_size = crop_size 48 | 49 | @abstractmethod 50 | def preprocess(self, images): 51 | pass 52 | -------------------------------------------------------------------------------- /mlx_vlm/convert.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 2 | 3 | import argparse 4 | 5 | from .utils import convert 6 | 7 | 8 | def configure_parser() -> argparse.ArgumentParser: 9 | """ 10 | Configures and returns the argument parser for the script. 11 | 12 | Returns: 13 | argparse.ArgumentParser: Configured argument parser. 14 | """ 15 | parser = argparse.ArgumentParser( 16 | description="Convert Hugging Face model to MLX format" 17 | ) 18 | 19 | parser.add_argument("--hf-path", type=str, help="Path to the Hugging Face model.") 20 | parser.add_argument( 21 | "--mlx-path", type=str, default="mlx_model", help="Path to save the MLX model." 
22 | ) 23 | parser.add_argument( 24 | "-q", "--quantize", help="Generate a quantized model.", action="store_true" 25 | ) 26 | parser.add_argument( 27 | "--q-group-size", help="Group size for quantization.", type=int, default=64 28 | ) 29 | parser.add_argument( 30 | "--q-bits", help="Bits per weight for quantization.", type=int, default=4 31 | ) 32 | parser.add_argument( 33 | "--dtype", 34 | help="Type to save the parameters, ignored if -q is given.", 35 | type=str, 36 | choices=["float16", "bfloat16", "float32"], 37 | default="float16", 38 | ) 39 | parser.add_argument( 40 | "--upload-repo", 41 | help="The Hugging Face repo to upload the model to.", 42 | type=str, 43 | default=None, 44 | ) 45 | parser.add_argument( 46 | "-d", 47 | "--dequantize", 48 | help="Dequantize a quantized model.", 49 | action="store_true", 50 | default=False, 51 | ) 52 | return parser 53 | 54 | 55 | def main(): 56 | parser = configure_parser() 57 | args = parser.parse_args() 58 | convert(**vars(args)) 59 | 60 | 61 | if __name__ == "__main__": 62 | main() 63 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to MLX VLM 2 | 3 | Below are some tips to port Vision LLMs available on Hugging Face to MLX. 4 | 5 | First, from this directory, do an editable install: 6 | 7 | ```shell 8 | pip install -e . 9 | ``` 10 | 11 | Then check if the model has weights in the 12 | [safetensors](https://huggingface.co/docs/safetensors/index) format. If not 13 | [follow instructions](https://huggingface.co/spaces/safetensors/convert) to 14 | convert it. 15 | 16 | After that, add the model file to the 17 | [`mlx_vlm/models`](https://github.com/Blaizzy/mlx-vlm/tree/main/src/models) 18 | directory. You can see other examples there. We recommend starting from a model 19 | that is similar to the model you are porting. 20 | 21 | Make sure the name of the new model file is the same as the `model_type` in the 22 | `config.json`, for example 23 | [llava](https://huggingface.co/llava-hf/llava-1.5-7b-hf/blob/main/config.json#L7). 24 | 25 | To determine the model layer names, we suggest either: 26 | 27 | - Refer to the Transformers implementation if you are familiar with the 28 | codebase. 29 | - Load the model weights and check the weight names which will tell you about 30 | the model structure. 31 | - Look at the names of the weights by inspecting `model.safetensors.index.json` 32 | in the Hugging Face repo. 33 | 34 | Additionally, add a test for the new model type to the [model 35 | tests](https://github.com/Blaizzy/mlx-vlm/tree/main/src/tests/test_models.py). 36 | 37 | From the `src/` directory, you can run the tests with: 38 | 39 | ```shell 40 | python -m unittest discover tests/ 41 | ``` 42 | 43 | ## Pull Requests 44 | 45 | 1. Fork and submit pull requests to the repo. 46 | 2. If you've added code that should be tested, add tests. 47 | 3. Every PR should have passing tests and at least one review. 48 | 4. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`. 49 | This should install hooks for running `black` and `clang-format` to ensure 50 | consistent style for C++ and python code.
51 | 52 | You can also run the formatters manually as follows on individual files: 53 | 54 | ```bash 55 | clang-format -i file.cpp 56 | ``` 57 | 58 | ```bash 59 | black file.py 60 | ``` 61 | 62 | or, 63 | 64 | ```bash 65 | # single file 66 | pre-commit run --files file1.py 67 | 68 | # specific files 69 | pre-commit run --files file1.py file2.py 70 | ``` 71 | 72 | or run `pre-commit run --all-files` to check all files in the repo. 73 | 74 | ## Issues 75 | 76 | We use GitHub issues to track public bugs. Please ensure your description is 77 | clear and has sufficient instructions to be able to reproduce the issue. 78 | 79 | ## License 80 | 81 | By contributing to mlx-examples, you agree that your contributions will be licensed 82 | under the LICENSE file in the root directory of this source tree. 83 | -------------------------------------------------------------------------------- /mlx_vlm/generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import codecs 3 | 4 | import mlx.core as mx 5 | 6 | from .prompt_utils import get_message_json 7 | from .utils import generate, get_model_path, load, load_config, load_image_processor 8 | 9 | MODEL_TYPE = "" 10 | 11 | 12 | def parse_arguments(): 13 | parser = argparse.ArgumentParser( 14 | description="Generate text from an image using a model." 15 | ) 16 | parser.add_argument( 17 | "--model", 18 | type=str, 19 | default="qnguyen3/nanoLLaVA", 20 | help="The path to the local model directory or Hugging Face repo.", 21 | ) 22 | parser.add_argument( 23 | "--image", 24 | type=str, 25 | default="http://images.cocodataset.org/val2017/000000039769.jpg", 26 | help="URL or path of the image to process.", 27 | ) 28 | parser.add_argument( 29 | "--prompt", 30 | type=str, 31 | default="What are these?", 32 | help="Message to be processed by the model.", 33 | ) 34 | parser.add_argument( 35 | "--max-tokens", 36 | type=int, 37 | default=100, 38 | help="Maximum number of tokens to generate.", 39 | ) 40 | parser.add_argument( 41 | "--temp", type=float, default=0.3, help="Temperature for sampling." 
42 | ) 43 | parser.add_argument( 44 | "--verbose", 45 | type=bool, 46 | help="Detailed output.", 47 | default=True, 48 | ) 49 | return parser.parse_args() 50 | 51 | 52 | def get_model_and_processors(model_path): 53 | model_path = get_model_path(model_path) 54 | config = load_config(model_path) 55 | model, processor = load(model_path, {"trust_remote_code": True}) 56 | image_processor = load_image_processor(model_path) 57 | return model, processor, image_processor, config 58 | 59 | 60 | def sample(logits, temperature=0.0): 61 | if temperature == 0: 62 | return mx.argmax(logits, axis=-1) 63 | else: 64 | return mx.random.categorical(logits * (1 / temperature)) 65 | 66 | 67 | def main(): 68 | args = parse_arguments() 69 | model, processor, image_processor, config = get_model_and_processors(args.model) 70 | 71 | prompt = codecs.decode(args.prompt, "unicode_escape") 72 | 73 | if "chat_template" in processor.__dict__.keys(): 74 | prompt = processor.apply_chat_template( 75 | [get_message_json(config["model_type"], prompt)], 76 | tokenize=False, 77 | add_generation_prompt=True, 78 | ) 79 | 80 | elif "tokenizer" in processor.__dict__.keys(): 81 | if model.config.model_type != "paligemma": 82 | prompt = processor.tokenizer.apply_chat_template( 83 | [get_message_json(config["model_type"], prompt)], 84 | tokenize=False, 85 | add_generation_prompt=True, 86 | ) 87 | 88 | else: 89 | ValueError( 90 | "Error: processor does not have 'chat_template' or 'tokenizer' attribute." 91 | ) 92 | 93 | output = generate( 94 | model, 95 | processor, 96 | args.image, 97 | prompt, 98 | image_processor, 99 | args.temp, 100 | args.max_tokens, 101 | args.verbose, 102 | ) 103 | if not args.verbose: 104 | print(output) 105 | 106 | 107 | if __name__ == "__main__": 108 | main() 109 | -------------------------------------------------------------------------------- /mlx_vlm/models/phi3_v/su_rope.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import mlx.core as mx 4 | 5 | 6 | class Phi3SuScaledRotaryEmbedding: 7 | def __init__( 8 | self, 9 | dims: int, 10 | traditional: bool = False, 11 | base: float = 10000.0, 12 | scale: float = 1.0, 13 | max_position_embeddings: int = 131072, 14 | original_max_position_embeddings: int = 4096, 15 | short_factor: list[float] | float = 1.0, 16 | long_factor: list[float] | float = 1.0, 17 | ): 18 | """ 19 | Phi3Su Scaled Rotary Embedding layer for Phi-3 models. 20 | 21 | Args: 22 | dims (int): The feature dimensions to be rotated. 23 | traditional (bool, optional): Unused. Default: ``False``. 24 | base (int, optional): Base for the exponential scaling. 25 | scale (float, optional): The scale used to scale the positions. Default: 1.0. 26 | max_position_embeddings (int, optional): The maximum sequence length that this model was trained with. This is used to determine the size of the original RoPE embeddings when using long scaling. Default: 131072. 27 | original_max_position_embeddings (int, optional): The maximum sequence length that this model was trained with. This is used to determine the size of the original RoPE embeddings when using long scaling. Default: 4096. 28 | short_factor (float or list of floats, optional): List of scaling factors for sequences of length lesser than original_max_position_embeddings. Default: 1.0. 29 | long_factor (float or list of floats, optional): List of scaling factors for sequences of length greater than original_max_position_embeddings. Default: 1.0. 
30 | """ 31 | self.inv_freq_short = 1.0 / ( 32 | mx.array(short_factor, dtype=mx.float32) 33 | * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims) 34 | ) 35 | self.inv_freq_long = 1.0 / ( 36 | scale 37 | * mx.array(long_factor, dtype=mx.float32) 38 | * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims) 39 | ) 40 | self.original_max_position_embeddings = original_max_position_embeddings 41 | self.scaling_factor = math.sqrt( 42 | 1 43 | + math.log(max_position_embeddings / original_max_position_embeddings) 44 | / math.log(original_max_position_embeddings) 45 | ) 46 | 47 | def _get_cos_sin(self, offset, L): 48 | position_ids = mx.arange(offset, offset + L, dtype=mx.float32)[None] 49 | inv_freq = ( 50 | self.inv_freq_long 51 | if position_ids.max() + 1 > self.original_max_position_embeddings 52 | else self.inv_freq_short 53 | ) 54 | inv_freq_expanded = mx.repeat( 55 | inv_freq[None, :, None], position_ids.shape[0], axis=0 56 | ) 57 | position_ids_expanded = position_ids[:, None, :] 58 | freqs = (inv_freq_expanded @ position_ids_expanded).transpose(0, 2, 1) 59 | emb = mx.concatenate([freqs, freqs], axis=-1) 60 | cos = mx.cos(emb) * self.scaling_factor 61 | sin = mx.sin(emb) * self.scaling_factor 62 | return mx.expand_dims(cos, axis=1), mx.expand_dims(sin, axis=1) 63 | 64 | def __call__(self, x, offset: int = 0): 65 | def _rotate_half(_x): 66 | midpoint = _x.shape[-1] // 2 67 | x1, x2 = _x[..., :midpoint], _x[..., midpoint:] 68 | return mx.concatenate([-x2, x1], axis=-1) 69 | 70 | cos, sin = self._get_cos_sin(offset, x.shape[2]) 71 | return (x * cos) + (_rotate_half(x) * sin) 72 | -------------------------------------------------------------------------------- /mlx_vlm/chat_ui.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Optional 3 | 4 | import gradio as gr 5 | import mlx.core as mx 6 | 7 | from mlx_vlm import load 8 | 9 | from .prompt_utils import get_message_json 10 | from .utils import ( 11 | generate_step, 12 | load, 13 | load_config, 14 | load_image_processor, 15 | prepare_inputs, 16 | sample, 17 | ) 18 | 19 | 20 | def parse_arguments(): 21 | parser = argparse.ArgumentParser( 22 | description="Generate text from an image using a model." 
23 | ) 24 | parser.add_argument( 25 | "--model", 26 | type=str, 27 | default="qnguyen3/nanoLLaVA", 28 | help="The path to the local model directory or Hugging Face repo.", 29 | ) 30 | return parser.parse_args() 31 | 32 | 33 | args = parse_arguments() 34 | config = load_config(args.model) 35 | model, processor = load(args.model, {"trust_remote_code": True}) 36 | image_processor = load_image_processor(args.model) 37 | 38 | 39 | def generate( 40 | model, 41 | processor, 42 | image: str, 43 | prompt: str, 44 | image_processor=None, 45 | temp: float = 0.0, 46 | max_tokens: int = 100, 47 | repetition_penalty: Optional[float] = None, 48 | repetition_context_size: Optional[int] = None, 49 | top_p: float = 1.0, 50 | ): 51 | 52 | if image_processor is not None: 53 | tokenizer = processor 54 | else: 55 | tokenizer = processor.tokenizer 56 | 57 | image_token_index = model.config.image_token_index 58 | input_ids, pixel_values, mask = prepare_inputs( 59 | image_processor, processor, image, prompt, image_token_index 60 | ) 61 | logits, cache = model(input_ids, pixel_values, mask=mask) 62 | logits = logits[:, -1, :] 63 | y, _ = sample(logits, temp, top_p) 64 | 65 | detokenizer = processor.detokenizer 66 | detokenizer.reset() 67 | 68 | detokenizer.add_token(y.item()) 69 | 70 | for (token, _), n in zip( 71 | generate_step( 72 | model.language_model, 73 | logits, 74 | mask, 75 | cache, 76 | temp, 77 | repetition_penalty, 78 | repetition_context_size, 79 | top_p, 80 | ), 81 | range(max_tokens), 82 | ): 83 | token = token.item() 84 | 85 | if token == tokenizer.eos_token_id: 86 | break 87 | 88 | detokenizer.add_token(token) 89 | detokenizer.finalize() 90 | yield detokenizer.last_segment 91 | 92 | 93 | def chat(message, history, temperature, max_tokens): 94 | 95 | chat = [] 96 | if len(message["files"]) >= 1: 97 | chat.append(get_message_json(config["model_type"], message["text"])) 98 | else: 99 | raise gr.Error("Please upload an image. 
Text only chat is not supported.") 100 | 101 | files = message["files"][-1] 102 | if "chat_template" in processor.__dict__.keys(): 103 | messages = processor.apply_chat_template( 104 | chat, 105 | tokenize=False, 106 | add_generation_prompt=True, 107 | ) 108 | 109 | elif "tokenizer" in processor.__dict__.keys(): 110 | if model.config.model_type != "paligemma": 111 | messages = processor.tokenizer.apply_chat_template( 112 | chat, 113 | tokenize=False, 114 | add_generation_prompt=True, 115 | ) 116 | else: 117 | messages = message["text"] 118 | 119 | response = "" 120 | for chunk in generate( 121 | model, 122 | processor, 123 | files, 124 | messages, 125 | image_processor, 126 | temperature, 127 | max_tokens, 128 | ): 129 | response += chunk 130 | yield response 131 | 132 | 133 | demo = gr.ChatInterface( 134 | fn=chat, 135 | title="MLX-VLM Chat UI", 136 | additional_inputs_accordion=gr.Accordion( 137 | label="⚙️ Parameters", open=False, render=False 138 | ), 139 | additional_inputs=[ 140 | gr.Slider( 141 | minimum=0, maximum=1, step=0.1, value=0.1, label="Temperature", render=False 142 | ), 143 | gr.Slider( 144 | minimum=128, 145 | maximum=4096, 146 | step=1, 147 | value=200, 148 | label="Max new tokens", 149 | render=False, 150 | ), 151 | ], 152 | description=f"Now Running {args.model}", 153 | multimodal=True, 154 | ) 155 | 156 | demo.launch(inbrowser=True) 157 | -------------------------------------------------------------------------------- /mlx_vlm/models/idefics2/language.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import re 3 | from dataclasses import dataclass 4 | from typing import Dict, Optional, Tuple, Union 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | 9 | 10 | @dataclass 11 | class TextConfig: 12 | model_type: str 13 | hidden_size: int 14 | num_hidden_layers: int 15 | intermediate_size: int 16 | num_attention_heads: int 17 | rms_norm_eps: float 18 | vocab_size: int 19 | num_key_value_heads: int 20 | rope_theta: float = 1000000.0 21 | rope_traditional: bool = False 22 | tie_word_embeddings: bool = False 23 | 24 | @classmethod 25 | def from_dict(cls, params): 26 | return cls( 27 | **{ 28 | k: v 29 | for k, v in params.items() 30 | if k in inspect.signature(cls).parameters 31 | } 32 | ) 33 | 34 | def __post_init__(self): 35 | if self.num_key_value_heads is None: 36 | self.num_key_value_heads = self.num_attention_heads 37 | 38 | 39 | class Attention(nn.Module): 40 | def __init__(self, args: TextConfig): 41 | super().__init__() 42 | 43 | dim = args.hidden_size 44 | self.n_heads = n_heads = args.num_attention_heads 45 | self.n_kv_heads = n_kv_heads = args.num_key_value_heads 46 | 47 | head_dim = args.hidden_size // n_heads 48 | self.scale = head_dim**-0.5 49 | 50 | self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) 51 | self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) 52 | self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) 53 | self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) 54 | 55 | self.rope = nn.RoPE( 56 | head_dim, 57 | traditional=args.rope_traditional, 58 | base=args.rope_theta, 59 | ) 60 | 61 | def __call__( 62 | self, 63 | x: mx.array, 64 | mask: Optional[mx.array] = None, 65 | cache: Optional[Tuple[mx.array, mx.array]] = None, 66 | ) -> mx.array: 67 | B, L, D = x.shape 68 | 69 | queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) 70 | 71 | # Prepare the queries, keys and values for the attention computation 72 | queries = 
queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) 73 | keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 74 | values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 75 | 76 | if cache is not None: 77 | key_cache, value_cache = cache 78 | queries = self.rope(queries, offset=key_cache.shape[2]) 79 | keys = self.rope(keys, offset=key_cache.shape[2]) 80 | keys = mx.concatenate([key_cache, keys], axis=2) 81 | values = mx.concatenate([value_cache, values], axis=2) 82 | else: 83 | queries = self.rope(queries) 84 | keys = self.rope(keys) 85 | 86 | output = mx.fast.scaled_dot_product_attention( 87 | queries, keys, values, scale=self.scale, mask=mask 88 | ) 89 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 90 | return self.o_proj(output), (keys, values) 91 | 92 | 93 | class MLP(nn.Module): 94 | def __init__(self, dim, hidden_dim): 95 | super().__init__() 96 | self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) 97 | self.down_proj = nn.Linear(hidden_dim, dim, bias=False) 98 | self.up_proj = nn.Linear(dim, hidden_dim, bias=False) 99 | 100 | def __call__(self, x) -> mx.array: 101 | return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) 102 | 103 | 104 | class TransformerBlock(nn.Module): 105 | def __init__(self, args: TextConfig): 106 | super().__init__() 107 | self.num_attention_heads = args.num_attention_heads 108 | self.hidden_size = args.hidden_size 109 | self.self_attn = Attention(args) 110 | self.mlp = MLP(args.hidden_size, args.intermediate_size) 111 | self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 112 | self.post_attention_layernorm = nn.RMSNorm( 113 | args.hidden_size, eps=args.rms_norm_eps 114 | ) 115 | self.args = args 116 | 117 | def __call__( 118 | self, 119 | x: mx.array, 120 | mask: Optional[mx.array] = None, 121 | cache: Optional[Tuple[mx.array, mx.array]] = None, 122 | ) -> mx.array: 123 | r, cache = self.self_attn(self.input_layernorm(x), mask, cache) 124 | h = x + r 125 | r = self.mlp(self.post_attention_layernorm(h)) 126 | out = h + r 127 | return out, cache 128 | 129 | 130 | class LanguageModel(nn.Module): 131 | def __init__(self, args: TextConfig): 132 | super().__init__() 133 | self.args = args 134 | self.model_type = args.model_type 135 | self.vocab_size = args.vocab_size 136 | self.num_hidden_layers = args.num_hidden_layers 137 | assert self.vocab_size > 0 138 | self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) 139 | self.layers = [ 140 | TransformerBlock(args=args) for _ in range(args.num_hidden_layers) 141 | ] 142 | self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 143 | self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) 144 | 145 | def __call__( 146 | self, 147 | inputs: mx.array, 148 | cache=None, 149 | inputs_embeds=None, 150 | mask: Optional[mx.array] = None, 151 | ): 152 | # for passing merged input embeddings 153 | if inputs_embeds is None: 154 | h = self.embed_tokens(inputs) 155 | else: 156 | h = inputs_embeds 157 | 158 | mask = None 159 | if h.shape[1] > 1: 160 | mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) 161 | mask = mask.astype(h.dtype) 162 | 163 | if cache is None: 164 | cache = [None] * len(self.layers) 165 | 166 | for e, layer in enumerate(self.layers): 167 | h, cache[e] = layer(h, mask, cache[e]) 168 | 169 | return self.lm_head(self.norm(h)), cache 170 | 171 | def sanitize(self, weights): 172 | # Remove unused precomputed rotary freqs 173 | return { 174 | k: v for k, v in 
weights.items() if "self_attn.rotary_emb.inv_freq" not in k 175 | } 176 | 177 | @property 178 | def layers(self): 179 | return self.model.layers 180 | -------------------------------------------------------------------------------- /mlx_vlm/models/llava/llava.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import inspect 3 | import json 4 | from dataclasses import dataclass 5 | from pathlib import Path 6 | from typing import Optional 7 | 8 | import mlx.core as mx 9 | import mlx.nn as nn 10 | import numpy as np 11 | from huggingface_hub import snapshot_download 12 | 13 | from .language import LanguageModel, TextConfig 14 | from .vision import VisionConfig, VisionModel 15 | 16 | 17 | @dataclass 18 | class ModelConfig: 19 | text_config: TextConfig 20 | vision_config: VisionConfig 21 | model_type: str 22 | ignore_index: int = -100 23 | image_token_index: int = 32000 24 | vision_feature_select_strategy: str = "default" 25 | vision_feature_layer: int = -2 26 | vocab_size: int = 32000 27 | 28 | @classmethod 29 | def from_dict(cls, params): 30 | return cls( 31 | **{ 32 | k: v 33 | for k, v in params.items() 34 | if k in inspect.signature(cls).parameters 35 | } 36 | ) 37 | 38 | 39 | class LlavaMultiModalProjector(nn.Module): 40 | def __init__(self, config: ModelConfig): 41 | super().__init__() 42 | self.linear_1 = nn.Linear( 43 | config.vision_config.hidden_size, config.text_config.hidden_size, bias=True 44 | ) 45 | self.gelu = nn.GELU() 46 | self.linear_2 = nn.Linear( 47 | config.text_config.hidden_size, config.text_config.hidden_size, bias=True 48 | ) 49 | 50 | def __call__(self, x: mx.array) -> mx.array: 51 | x = self.linear_1(x) 52 | x = self.gelu(x) 53 | x = self.linear_2(x) 54 | return x 55 | 56 | 57 | class Model(nn.Module): 58 | def __init__(self, config: ModelConfig): 59 | self.config = config 60 | self.vision_tower = VisionModel(config.vision_config) 61 | self.language_model = LanguageModel(config.text_config) 62 | self.multi_modal_projector = LlavaMultiModalProjector(config) 63 | self.vision_feature_layer = config.vision_feature_layer 64 | self.vision_feature_select_strategy = config.vision_feature_select_strategy 65 | 66 | def get_input_embeddings( 67 | self, 68 | input_ids: Optional[mx.array] = None, 69 | pixel_values: Optional[mx.array] = None, 70 | ): 71 | if pixel_values is None: 72 | return self.language_model(input_ids) 73 | 74 | # Get the input embeddings from the language model 75 | inputs_embeds = self.language_model.model.embed_tokens(input_ids) 76 | 77 | # Get the ouptut hidden states from the vision model 78 | *_, hidden_states = self.vision_tower( 79 | pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True 80 | ) 81 | 82 | # Select the hidden states from the desired layer 83 | selected_image_feature = hidden_states[self.vision_feature_layer] 84 | 85 | if self.vision_feature_select_strategy == "default": 86 | selected_image_feature = selected_image_feature[:, 1:] 87 | elif self.vision_feature_select_strategy == "full": 88 | selected_image_feature = selected_image_feature 89 | else: 90 | raise ValueError( 91 | "Unexpected feature selection strategy: " 92 | f"{self.vision_feature_select_strategy}" 93 | ) 94 | 95 | # Pass image features through the multi-modal projector 96 | image_features = self.multi_modal_projector(selected_image_feature) 97 | 98 | # Insert special image tokens in the input_ids 99 | final_inputs_embeds = self._merge_input_ids_with_image_features( 100 | image_features, inputs_embeds, input_ids 
101 | ) 102 | return final_inputs_embeds 103 | 104 | def _merge_input_ids_with_image_features( 105 | self, image_features, inputs_embeds, input_ids 106 | ): 107 | image_token_index = self.config.image_token_index 108 | num_images, num_image_patches, embed_dim = image_features.shape 109 | 110 | # Positions of tokens in input_ids, assuming batch size is 1 111 | image_positions = np.where(input_ids[0] == image_token_index)[0].tolist() 112 | 113 | if len(image_positions) != num_images: 114 | raise ValueError( 115 | f"The number of image tokens ({len(image_positions)}) does not " 116 | f" match the number of image inputs ({num_images})." 117 | ) 118 | 119 | text_segments = [] 120 | start_idx = 0 121 | 122 | for position in image_positions: 123 | text_segments.append(inputs_embeds[:, start_idx:position]) 124 | start_idx = position + 1 125 | 126 | image_embeddings = mx.split(image_features, image_features.shape[0]) 127 | final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p] 128 | final_embeddings += [inputs_embeds[:, start_idx:]] 129 | 130 | # Create a final embedding of shape 131 | # (1, num_image_patches*num_images + sequence_len, embed_dim) 132 | return mx.concatenate(final_embeddings, axis=1) 133 | 134 | def __call__( 135 | self, input_ids: mx.array, pixel_values: mx.array, mask: mx.array, cache=None 136 | ): 137 | input_embddings = self.get_input_embeddings(input_ids, pixel_values) 138 | logits, cache = self.language_model( 139 | input_ids, cache=cache, inputs_embeds=input_embddings 140 | ) 141 | return logits, cache 142 | 143 | @staticmethod 144 | def from_pretrained(path_or_hf_repo: str): 145 | path = Path(path_or_hf_repo) 146 | if not path.exists(): 147 | path = Path( 148 | snapshot_download( 149 | repo_id=path_or_hf_repo, 150 | allow_patterns=[ 151 | "*.json", 152 | "*.safetensors", 153 | "*.py", 154 | "tokenizer.model", 155 | "*.tiktoken", 156 | ], 157 | ) 158 | ) 159 | 160 | with open(path / "config.json", "r") as f: 161 | model_config = json.load(f) 162 | 163 | model_config = ModelConfig.from_dict(model_config) 164 | 165 | model_config.vision_config = VisionConfig.from_dict(model_config.vision_config) 166 | model_config.text_config = TextConfig.from_dict(model_config.text_config) 167 | 168 | model = Model(model_config) 169 | weight_files = glob.glob(str(path / "*.safetensors")) 170 | if not weight_files: 171 | raise FileNotFoundError(f"No safetensors found in {path}") 172 | 173 | weights = {} 174 | for wf in weight_files: 175 | weights.update(mx.load(wf)) 176 | 177 | weights = VisionModel.sanitize(weights) 178 | weights = LanguageModel.sanitize(weights) 179 | 180 | model.load_weights(list(weights.items())) 181 | return model 182 | -------------------------------------------------------------------------------- /mlx_vlm/models/paligemma/language.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from dataclasses import dataclass 3 | from typing import Optional, Tuple 4 | 5 | import mlx.core as mx 6 | import mlx.nn as nn 7 | 8 | 9 | @dataclass 10 | class TextConfig: 11 | model_type: str 12 | hidden_size: int 13 | num_hidden_layers: int 14 | intermediate_size: int 15 | num_attention_heads: int 16 | num_key_value_heads: int 17 | vocab_size: int 18 | rms_norm_eps: float = 1e-6 19 | rope_theta: float = 10000 20 | rope_traditional: bool = False 21 | 22 | @classmethod 23 | def from_dict(cls, params): 24 | return cls( 25 | **{ 26 | k: v 27 | for k, v in params.items() 28 | if k in 
inspect.signature(cls).parameters 29 | } 30 | ) 31 | 32 | 33 | class RMSNorm(nn.Module): 34 | def __init__(self, dims: int, eps: float = 1e-6): 35 | super().__init__() 36 | self.weight = mx.ones((dims,)) 37 | self.eps = eps 38 | 39 | def __call__(self, x): 40 | return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps) 41 | 42 | 43 | class Attention(nn.Module): 44 | def __init__(self, args: TextConfig): 45 | super().__init__() 46 | 47 | dim = args.hidden_size 48 | self.n_heads = n_heads = args.num_attention_heads 49 | self.n_kv_heads = n_kv_heads = args.num_key_value_heads 50 | 51 | head_dim = args.hidden_size // n_heads 52 | self.scale = head_dim**-0.5 53 | 54 | self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) 55 | self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) 56 | self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) 57 | self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) 58 | 59 | self.rope = nn.RoPE( 60 | head_dim, 61 | traditional=args.rope_traditional, 62 | base=args.rope_theta, 63 | ) 64 | 65 | def __call__( 66 | self, 67 | x: mx.array, 68 | mask: Optional[mx.array] = None, 69 | cache: Optional[Tuple[mx.array, mx.array]] = None, 70 | ) -> mx.array: 71 | B, L, D = x.shape 72 | 73 | queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) 74 | 75 | # Prepare the queries, keys and values for the attention computation 76 | queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) 77 | keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 78 | values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 79 | 80 | if cache is not None: 81 | key_cache, value_cache = cache 82 | queries = self.rope(queries, offset=key_cache.shape[2]) 83 | keys = self.rope(keys, offset=key_cache.shape[2]) 84 | keys = mx.concatenate([key_cache, keys], axis=2) 85 | values = mx.concatenate([value_cache, values], axis=2) 86 | else: 87 | queries = self.rope(queries) 88 | keys = self.rope(keys) 89 | 90 | output = mx.fast.scaled_dot_product_attention( 91 | queries, keys, values, scale=self.scale, mask=mask 92 | ) 93 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 94 | return self.o_proj(output), (keys, values) 95 | 96 | 97 | class MLP(nn.Module): 98 | def __init__(self, dim, hidden_dim): 99 | super().__init__() 100 | self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) 101 | self.down_proj = nn.Linear(hidden_dim, dim, bias=False) 102 | self.up_proj = nn.Linear(dim, hidden_dim, bias=False) 103 | 104 | def __call__(self, x) -> mx.array: 105 | return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x)) 106 | 107 | 108 | class TransformerBlock(nn.Module): 109 | def __init__(self, args: TextConfig): 110 | super().__init__() 111 | self.num_attention_heads = args.num_attention_heads 112 | self.hidden_size = args.hidden_size 113 | self.self_attn = Attention(args) 114 | self.mlp = MLP(args.hidden_size, args.intermediate_size) 115 | self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 116 | self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 117 | self.args = args 118 | 119 | def __call__( 120 | self, 121 | x: mx.array, 122 | mask: Optional[mx.array] = None, 123 | cache: Optional[Tuple[mx.array, mx.array]] = None, 124 | ) -> mx.array: 125 | r, cache = self.self_attn(self.input_layernorm(x), mask, cache) 126 | h = x + r 127 | r = self.mlp(self.post_attention_layernorm(h)) 128 | out = h + r 129 | return out, cache 130 | 131 | 132 | class GemmaModel(nn.Module): 
133 | def __init__(self, args: TextConfig): 134 | super().__init__() 135 | self.args = args 136 | self.vocab_size = args.vocab_size 137 | self.num_hidden_layers = args.num_hidden_layers 138 | assert self.vocab_size > 0 139 | self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) 140 | self.layers = [ 141 | TransformerBlock(args=args) for _ in range(args.num_hidden_layers) 142 | ] 143 | self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 144 | 145 | def __call__( 146 | self, 147 | inputs: mx.array, 148 | cache=None, 149 | inputs_embeds=None, 150 | mask: Optional[mx.array] = None, 151 | ): 152 | # for passing merged input embeddings 153 | if inputs_embeds is None: 154 | h = self.embed_tokens(inputs) 155 | 156 | else: 157 | h = inputs_embeds 158 | 159 | h = h * (self.args.hidden_size**0.5) 160 | 161 | if cache is not None: 162 | mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) 163 | mask = mask.astype(h.dtype) 164 | 165 | if cache is None: 166 | cache = [None] * len(self.layers) 167 | 168 | for e, layer in enumerate(self.layers): 169 | h, cache[e] = layer(h, mask, cache[e]) 170 | 171 | return self.norm(h), cache 172 | 173 | 174 | class LanguageModel(nn.Module): 175 | def __init__(self, args: TextConfig): 176 | super().__init__() 177 | self.args = args 178 | self.model_type = args.model_type 179 | self.model = GemmaModel(args) 180 | 181 | def __call__( 182 | self, 183 | inputs: mx.array, 184 | cache=None, 185 | inputs_embeds=None, 186 | mask: Optional[mx.array] = None, 187 | ): 188 | out, cache = self.model(inputs, cache, inputs_embeds=inputs_embeds, mask=mask) 189 | out = self.model.embed_tokens.as_linear(out) 190 | return out, cache 191 | 192 | def sanitize(self, weights): 193 | return { 194 | k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k 195 | } 196 | 197 | @property 198 | def layers(self): 199 | return self.model.layers 200 | -------------------------------------------------------------------------------- /mlx_vlm/models/llava_next/llava_next.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import inspect 3 | import json 4 | from dataclasses import dataclass 5 | from pathlib import Path 6 | from typing import Optional 7 | 8 | import mlx.core as mx 9 | import mlx.nn as nn 10 | import numpy as np 11 | from huggingface_hub import snapshot_download 12 | 13 | from .language import LanguageModel, TextConfig 14 | from .vision import VisionConfig, VisionModel 15 | 16 | 17 | @dataclass 18 | class ModelConfig: 19 | text_config: TextConfig 20 | vision_config: VisionConfig 21 | model_type: str 22 | ignore_index: int = -100 23 | image_token_index: int = 32000 24 | vision_feature_select_strategy: str = "default" 25 | vision_feature_layer: int = -2 26 | vocab_size: int = 32000 27 | 28 | @classmethod 29 | def from_dict(cls, params): 30 | return cls( 31 | **{ 32 | k: v 33 | for k, v in params.items() 34 | if k in inspect.signature(cls).parameters 35 | } 36 | ) 37 | 38 | 39 | class LlavaMultiModalProjector(nn.Module): 40 | def __init__(self, config: ModelConfig): 41 | super().__init__() 42 | self.linear_1 = nn.Linear( 43 | config.vision_config.hidden_size, config.text_config.hidden_size, bias=True 44 | ) 45 | self.gelu = nn.GELU() 46 | self.linear_2 = nn.Linear( 47 | config.text_config.hidden_size, config.text_config.hidden_size, bias=True 48 | ) 49 | 50 | def __call__(self, x: mx.array) -> mx.array: 51 | x = self.linear_1(x) 52 | x = self.gelu(x) 53 | x = self.linear_2(x) 54 | 
return x 55 | 56 | 57 | class Model(nn.Module): 58 | def __init__(self, config: ModelConfig): 59 | self.config = config 60 | self.vision_tower = VisionModel(config.vision_config) 61 | self.language_model = LanguageModel(config.text_config) 62 | embed_std = 1 / mx.sqrt(config.text_config.hidden_size) 63 | self.image_newline = ( 64 | mx.random.normal((config.text_config.hidden_size,)) * embed_std 65 | ) 66 | 67 | self.multi_modal_projector = LlavaMultiModalProjector(config) 68 | self.vision_feature_layer = config.vision_feature_layer 69 | self.vision_feature_select_strategy = config.vision_feature_select_strategy 70 | 71 | def get_input_embeddings( 72 | self, 73 | input_ids: Optional[mx.array] = None, 74 | pixel_values: Optional[mx.array] = None, 75 | ): 76 | if pixel_values is None: 77 | return self.language_model(input_ids) 78 | 79 | # Get the input embeddings from the language model 80 | inputs_embeds = self.language_model.model.embed_tokens(input_ids) 81 | 82 | # Get the ouptut hidden states from the vision model 83 | *_, hidden_states = self.vision_tower( 84 | pixel_values[0].transpose(0, 2, 3, 1), output_hidden_states=True 85 | ) 86 | 87 | # Select the hidden states from the desired layer 88 | selected_image_feature = hidden_states[self.vision_feature_layer] 89 | 90 | if self.vision_feature_select_strategy == "default": 91 | selected_image_feature = selected_image_feature[:, 1:] 92 | elif self.vision_feature_select_strategy == "full": 93 | selected_image_feature = selected_image_feature 94 | else: 95 | raise ValueError( 96 | "Unexpected feature selection strategy: " 97 | f"{self.vision_feature_select_strategy}" 98 | ) 99 | 100 | # Pass image features through the multi-modal projector 101 | image_features = self.multi_modal_projector(selected_image_feature) 102 | if self.image_newline is not None: 103 | self.image_newline = np.array(self.image_newline)[None, None, :] 104 | self.image_newline = np.broadcast_to( 105 | self.image_newline, image_features.shape 106 | ) 107 | image_newline = mx.array(self.image_newline) 108 | image_features = mx.concatenate([image_features, image_newline], axis=0) 109 | 110 | # Insert special image tokens in the input_ids 111 | final_inputs_embeds = self._merge_input_ids_with_image_features( 112 | image_features, inputs_embeds, input_ids 113 | ) 114 | return final_inputs_embeds 115 | 116 | def _merge_input_ids_with_image_features( 117 | self, image_features, inputs_embeds, input_ids 118 | ): 119 | image_token_index = self.config.image_token_index 120 | 121 | # Positions of tokens in input_ids, assuming batch size is 1 122 | image_positions = np.where(input_ids[0] == image_token_index)[0].tolist() 123 | text_segments = [] 124 | start_idx = 0 125 | 126 | for position in image_positions: 127 | text_segments.append(inputs_embeds[:, start_idx:position]) 128 | start_idx = position + 1 129 | 130 | image_embeddings = mx.split(image_features, image_features.shape[0]) 131 | final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p] 132 | final_embeddings += [inputs_embeds[:, start_idx:]] 133 | 134 | # Create a final embedding of shape 135 | # (1, num_image_patches*num_images + sequence_len, embed_dim) 136 | return mx.concatenate(final_embeddings, axis=1) 137 | 138 | def __call__( 139 | self, input_ids: mx.array, pixel_values: mx.array, mask: mx.array, cache=None 140 | ): 141 | 142 | input_embddings = self.get_input_embeddings(input_ids, pixel_values) 143 | logits, cache = self.language_model( 144 | input_ids, cache=cache, 
inputs_embeds=input_embddings 145 | ) 146 | return logits, cache 147 | 148 | @staticmethod 149 | def from_pretrained(path_or_hf_repo: str): 150 | path = Path(path_or_hf_repo) 151 | if not path.exists(): 152 | path = Path( 153 | snapshot_download( 154 | repo_id=path_or_hf_repo, 155 | allow_patterns=[ 156 | "*.json", 157 | "*.safetensors", 158 | "*.py", 159 | "tokenizer.model", 160 | "*.tiktoken", 161 | ], 162 | ) 163 | ) 164 | 165 | with open(path / "config.json", "r") as f: 166 | model_config = json.load(f) 167 | 168 | model_config = ModelConfig.from_dict(model_config) 169 | 170 | model_config.vision_config = VisionConfig.from_dict(model_config.vision_config) 171 | model_config.text_config = TextConfig.from_dict(model_config.text_config) 172 | 173 | model = Model(model_config) 174 | weight_files = glob.glob(str(path / "*.safetensors")) 175 | if not weight_files: 176 | raise FileNotFoundError(f"No safetensors found in {path}") 177 | 178 | weights = {} 179 | for wf in weight_files: 180 | weights.update(mx.load(wf)) 181 | 182 | weights = VisionModel.sanitize(weights) 183 | weights = LanguageModel.sanitize(weights) 184 | 185 | model.load_weights(list(weights.items())) 186 | return model 187 | -------------------------------------------------------------------------------- /mlx_vlm/models/paligemma/paligemma.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import inspect 3 | import json 4 | from dataclasses import dataclass 5 | from pathlib import Path 6 | from typing import Dict, Optional 7 | 8 | import mlx.core as mx 9 | import mlx.nn as nn 10 | import numpy as np 11 | from huggingface_hub import snapshot_download 12 | 13 | from .language import LanguageModel, TextConfig 14 | from .vision import VisionConfig, VisionModel 15 | 16 | 17 | @dataclass 18 | class ModelConfig: 19 | text_config: TextConfig 20 | vision_config: VisionConfig 21 | model_type: str 22 | vocab_size: int 23 | ignore_index: int = -100 24 | image_token_index: int = 257152 25 | hidden_size: int = 2048 26 | pad_token_id: int = 0 27 | 28 | @classmethod 29 | def from_dict(cls, params): 30 | return cls( 31 | **{ 32 | k: v 33 | for k, v in params.items() 34 | if k in inspect.signature(cls).parameters 35 | } 36 | ) 37 | 38 | 39 | class PaliGemmaMultiModalProjector(nn.Module): 40 | def __init__(self, config: ModelConfig): 41 | super().__init__() 42 | self.linear = nn.Linear( 43 | config.vision_config.hidden_size, 44 | config.vision_config.projection_dim, 45 | bias=True, 46 | ) 47 | 48 | def __call__(self, x: mx.array) -> mx.array: 49 | output = self.linear(x) 50 | return output 51 | 52 | 53 | class Model(nn.Module): 54 | def __init__(self, config: ModelConfig): 55 | self.model_type = config.model_type 56 | self.config = config 57 | 58 | self.vision_tower = VisionModel(config.vision_config) 59 | self.language_model = LanguageModel(config.text_config) 60 | self.multi_modal_projector = PaliGemmaMultiModalProjector(config) 61 | 62 | def get_input_embeddings( 63 | self, 64 | input_ids: Optional[mx.array] = None, 65 | pixel_values: Optional[mx.array] = None, 66 | mask: Optional[mx.array] = None, 67 | ): 68 | if pixel_values is None: 69 | return self.language_model(input_ids) 70 | 71 | inputs_embeds = self.language_model.model.embed_tokens(input_ids) 72 | 73 | hidden_state, _, _ = self.vision_tower( 74 | pixel_values.transpose(0, 2, 3, 1).astype(inputs_embeds.dtype), 75 | output_hidden_states=True, 76 | ) 77 | 78 | image_features = hidden_state[None, :].astype(pixel_values.dtype) 79 | 
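        # hidden_state comes from the SigLIP vision tower; the projector below maps it from
        # vision_config.hidden_size to vision_config.projection_dim so the image features can be
        # scattered into the image-token positions in _prepare_inputs_for_multimodal.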
image_features = self.multi_modal_projector(image_features) 80 | 81 | final_inputs_embeds, final_attention_mask_4d = ( 82 | self._prepare_inputs_for_multimodal( 83 | image_features, inputs_embeds, input_ids, mask 84 | ) 85 | ) 86 | return final_inputs_embeds, final_attention_mask_4d 87 | 88 | def _prepare_inputs_for_multimodal( 89 | self, image_features, inputs_embeds, input_ids, attention_mask 90 | ): 91 | _, _, embed_dim = image_features.shape 92 | 93 | batch_size, sequence_length = input_ids.shape 94 | scaled_image_features = image_features / (self.config.hidden_size**0.5) 95 | final_embedding = np.zeros((batch_size, sequence_length, embed_dim)) 96 | 97 | text_mask = (input_ids != self.config.image_token_index) & ( 98 | input_ids != self.config.pad_token_id 99 | ) 100 | image_mask = input_ids == self.config.image_token_index 101 | pad_mask = input_ids == self.config.pad_token_id 102 | 103 | # expand masks to match embedding dimension 104 | text_mask_expanded = np.expand_dims(text_mask, -1).repeat(embed_dim, axis=-1) 105 | pad_mask_expanded = np.expand_dims(pad_mask, -1).repeat(embed_dim, axis=-1) 106 | 107 | # insert padding and text token embeddings 108 | final_embedding = np.where(text_mask_expanded, inputs_embeds, final_embedding) 109 | final_embedding = np.where( 110 | pad_mask_expanded, np.zeros_like(final_embedding), final_embedding 111 | ) 112 | 113 | # insert image embeddings - the image mask is always less or equal to the sentence in length 114 | image_mask_expanded = np.expand_dims(image_mask, -1).repeat(embed_dim, axis=-1) 115 | final_embedding[image_mask_expanded] = scaled_image_features.flatten() 116 | 117 | final_embedding = np.where( 118 | pad_mask_expanded, np.zeros_like(final_embedding), final_embedding 119 | ) 120 | 121 | attention_mask_expanded_1 = np.expand_dims(attention_mask, 1) 122 | attention_mask_expanded_2 = np.expand_dims(attention_mask, 2) 123 | final_attention_mask_4d = attention_mask_expanded_1 * attention_mask_expanded_2 124 | final_attention_mask_4d = final_attention_mask_4d 125 | final_attention_mask_4d = np.expand_dims(final_attention_mask_4d, 1).repeat( 126 | self.config.text_config.num_key_value_heads, axis=1 127 | ) 128 | final_embedding = mx.array(final_embedding) 129 | final_attention_mask_4d = mx.array(final_attention_mask_4d) 130 | return final_embedding, final_attention_mask_4d 131 | 132 | def __call__( 133 | self, 134 | input_ids: mx.array, 135 | pixel_values: mx.array, 136 | mask: Optional[mx.array] = None, 137 | cache: Optional[mx.array] = None, 138 | ): 139 | input_embeddings, final_attention_mask_4d = self.get_input_embeddings( 140 | input_ids, pixel_values, mask 141 | ) 142 | 143 | logits, cache = self.language_model( 144 | inputs=input_ids, 145 | cache=cache, 146 | inputs_embeds=input_embeddings, 147 | mask=final_attention_mask_4d, 148 | ) 149 | return logits, cache 150 | 151 | @staticmethod 152 | def from_pretrained(path_or_hf_repo: str): 153 | path = Path(path_or_hf_repo) 154 | if not path.exists(): 155 | path = Path( 156 | snapshot_download( 157 | repo_id=path_or_hf_repo, 158 | allow_patterns=[ 159 | "*.json", 160 | "*.safetensors", 161 | "*.py", 162 | "tokenizer.model", 163 | "*.tiktoken", 164 | ], 165 | ) 166 | ) 167 | 168 | with open(path / "config.json", "r") as f: 169 | config = json.load(f) 170 | 171 | model_config = ModelConfig.from_dict(config) 172 | model_config.vision_config = VisionConfig.from_dict(config["vision_config"]) 173 | model_config.text_config = TextConfig.from_dict(config["text_config"]) 174 | 175 | model = 
Model(model_config) 176 | weight_files = glob.glob(str(path / "*.safetensors")) 177 | if not weight_files: 178 | raise FileNotFoundError(f"No safetensors found in {path}") 179 | 180 | weights = {} 181 | for wf in weight_files: 182 | weights.update(mx.load(wf)) 183 | 184 | weights = model.sanitize(weights=weights) 185 | 186 | weights = VisionModel(model_config.vision_config).sanitize(weights=weights) 187 | model.load_weights(list(weights.items())) 188 | return model 189 | -------------------------------------------------------------------------------- /mlx_vlm/models/nanoLlava/language.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from dataclasses import dataclass 3 | from typing import Dict, Optional, Tuple, Union 4 | 5 | import mlx.core as mx 6 | import mlx.nn as nn 7 | 8 | 9 | @dataclass 10 | class TextConfig: 11 | model_type: str 12 | hidden_size: int 13 | num_hidden_layers: int 14 | intermediate_size: int 15 | num_attention_heads: int 16 | rms_norm_eps: float 17 | vocab_size: int 18 | num_key_value_heads: int = None 19 | rope_theta: float = 1000000 20 | rope_traditional: bool = False 21 | rope_scaling: Optional[Dict[str, Union[float, str]]] = None 22 | tie_word_embeddings: bool = True 23 | 24 | @classmethod 25 | def from_dict(cls, params): 26 | return cls( 27 | **{ 28 | k: v 29 | for k, v in params.items() 30 | if k in inspect.signature(cls).parameters 31 | } 32 | ) 33 | 34 | def __post_init__(self): 35 | if self.num_key_value_heads is None: 36 | self.num_key_value_heads = self.num_attention_heads 37 | 38 | if self.rope_scaling: 39 | required_keys = {"factor", "type"} 40 | if not all(key in self.rope_scaling for key in required_keys): 41 | raise ValueError(f"rope_scaling must contain keys {required_keys}") 42 | 43 | if self.rope_scaling["type"] != "linear": 44 | raise ValueError("rope_scaling 'type' currently only supports 'linear'") 45 | 46 | 47 | class Attention(nn.Module): 48 | def __init__(self, args: TextConfig): 49 | super().__init__() 50 | 51 | dim = args.hidden_size 52 | self.n_heads = n_heads = args.num_attention_heads 53 | self.n_kv_heads = n_kv_heads = args.num_key_value_heads 54 | 55 | head_dim = args.hidden_size // n_heads 56 | self.scale = head_dim**-0.5 57 | 58 | self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True) 59 | self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) 60 | self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) 61 | self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) 62 | 63 | rope_scale = ( 64 | 1 / args.rope_scaling["factor"] 65 | if args.rope_scaling is not None and args.rope_scaling["type"] == "linear" 66 | else 1 67 | ) 68 | self.rope = nn.RoPE( 69 | head_dim, 70 | traditional=args.rope_traditional, 71 | base=args.rope_theta, 72 | scale=rope_scale, 73 | ) 74 | 75 | def __call__( 76 | self, 77 | x: mx.array, 78 | mask: Optional[mx.array] = None, 79 | cache: Optional[Tuple[mx.array, mx.array]] = None, 80 | ) -> mx.array: 81 | B, L, D = x.shape 82 | 83 | queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) 84 | 85 | # Prepare the queries, keys and values for the attention computation 86 | queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) 87 | keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 88 | values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 89 | 90 | if cache is not None: 91 | key_cache, value_cache = cache 92 | queries = self.rope(queries, offset=key_cache.shape[2]) 93 
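            # the keys get the same positional offset as the queries (the cached sequence
            # length) before being concatenated with the cached keys/values below.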
| keys = self.rope(keys, offset=key_cache.shape[2]) 94 | keys = mx.concatenate([key_cache, keys], axis=2) 95 | values = mx.concatenate([value_cache, values], axis=2) 96 | else: 97 | queries = self.rope(queries) 98 | keys = self.rope(keys) 99 | 100 | output = mx.fast.scaled_dot_product_attention( 101 | queries, keys, values, scale=self.scale, mask=mask 102 | ) 103 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 104 | return self.o_proj(output), (keys, values) 105 | 106 | 107 | class MLP(nn.Module): 108 | def __init__(self, dim, hidden_dim): 109 | super().__init__() 110 | self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) 111 | self.down_proj = nn.Linear(hidden_dim, dim, bias=False) 112 | self.up_proj = nn.Linear(dim, hidden_dim, bias=False) 113 | 114 | def __call__(self, x) -> mx.array: 115 | return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) 116 | 117 | 118 | class TransformerBlock(nn.Module): 119 | def __init__(self, args: TextConfig): 120 | super().__init__() 121 | self.num_attention_heads = args.num_attention_heads 122 | self.hidden_size = args.hidden_size 123 | self.self_attn = Attention(args) 124 | self.mlp = MLP(args.hidden_size, args.intermediate_size) 125 | self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 126 | self.post_attention_layernorm = nn.RMSNorm( 127 | args.hidden_size, eps=args.rms_norm_eps 128 | ) 129 | self.args = args 130 | 131 | def __call__( 132 | self, 133 | x: mx.array, 134 | mask: Optional[mx.array] = None, 135 | cache: Optional[Tuple[mx.array, mx.array]] = None, 136 | ) -> mx.array: 137 | r, cache = self.self_attn(self.input_layernorm(x), mask, cache) 138 | h = x + r 139 | r = self.mlp(self.post_attention_layernorm(h)) 140 | out = h + r 141 | return out, cache 142 | 143 | 144 | class Qwen2Model(nn.Module): 145 | def __init__(self, args: TextConfig): 146 | super().__init__() 147 | self.args = args 148 | self.vocab_size = args.vocab_size 149 | self.num_hidden_layers = args.num_hidden_layers 150 | assert self.vocab_size > 0 151 | self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) 152 | self.layers = [ 153 | TransformerBlock(args=args) for _ in range(args.num_hidden_layers) 154 | ] 155 | self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 156 | self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) 157 | 158 | def __call__( 159 | self, 160 | inputs: mx.array, 161 | cache=None, 162 | inputs_embeds=None, 163 | ): 164 | # for passing merged input embeddings 165 | if inputs_embeds is None: 166 | h = self.embed_tokens(inputs) 167 | else: 168 | h = inputs_embeds 169 | 170 | mask = None 171 | if h.shape[1] > 1: 172 | mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) 173 | mask = mask.astype(h.dtype) 174 | 175 | if cache is None: 176 | cache = [None] * len(self.layers) 177 | 178 | for e, layer in enumerate(self.layers): 179 | h, cache[e] = layer(h, mask, cache[e]) 180 | 181 | return self.lm_head(self.norm(h)), cache 182 | 183 | 184 | class LanguageModel(nn.Module): 185 | def __init__(self, args: TextConfig): 186 | super().__init__() 187 | self.args = args 188 | self.model_type = args.model_type 189 | self.model = Qwen2Model(args) 190 | 191 | def __call__( 192 | self, 193 | inputs: mx.array, 194 | cache=None, 195 | inputs_embeds=None, 196 | mask: Optional[mx.array] = None, 197 | ): 198 | out, cache = self.model(inputs, cache, inputs_embeds=inputs_embeds) 199 | return out, cache 200 | 201 | def sanitize(self, weights): 202 | if ( 203 | 
self.args.tie_word_embeddings 204 | and "language_model.model.lm_head.weight" not in weights 205 | ): 206 | weights["language_model.model.lm_head.weight"] = weights[ 207 | "language_model.model.embed_tokens.weight" 208 | ] 209 | 210 | return { 211 | k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k 212 | } 213 | 214 | @property 215 | def layers(self): 216 | return self.model.layers 217 | -------------------------------------------------------------------------------- /mlx_vlm/models/llava/language.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from dataclasses import dataclass 3 | from typing import Dict, Optional, Tuple, Union 4 | 5 | import mlx.core as mx 6 | import mlx.nn as nn 7 | 8 | 9 | @dataclass 10 | class TextConfig: 11 | model_type: str 12 | hidden_size: int = 4096 13 | num_hidden_layers: int = 32 14 | intermediate_size: int = 11008 15 | num_attention_heads: int = 32 16 | rms_norm_eps: float = 1e-6 17 | vocab_size: int = 32000 18 | num_key_value_heads: int = None 19 | rope_theta: float = 10000 20 | rope_traditional: bool = False 21 | rope_scaling: Optional[Dict[str, Union[float, str]]] = None 22 | 23 | @classmethod 24 | def from_dict(cls, params): 25 | return cls( 26 | **{ 27 | k: v 28 | for k, v in params.items() 29 | if k in inspect.signature(cls).parameters 30 | } 31 | ) 32 | 33 | def __post_init__(self): 34 | if self.num_key_value_heads is None: 35 | self.num_key_value_heads = self.num_attention_heads 36 | 37 | if self.rope_scaling: 38 | required_keys = {"factor", "type"} 39 | if not all(key in self.rope_scaling for key in required_keys): 40 | raise ValueError(f"rope_scaling must contain keys {required_keys}") 41 | 42 | if self.rope_scaling["type"] != "linear": 43 | raise ValueError("rope_scaling 'type' currently only supports 'linear'") 44 | 45 | 46 | class Attention(nn.Module): 47 | def __init__(self, config: TextConfig): 48 | super().__init__() 49 | 50 | dim = config.hidden_size 51 | self.n_heads = n_heads = config.num_attention_heads 52 | self.n_kv_heads = n_kv_heads = config.num_key_value_heads 53 | 54 | self.repeats = n_heads // n_kv_heads 55 | 56 | head_dim = config.hidden_size // n_heads 57 | self.scale = head_dim**-0.5 58 | 59 | self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) 60 | self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) 61 | self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) 62 | self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) 63 | 64 | rope_scale = ( 65 | 1 / config.rope_scaling["factor"] 66 | if config.rope_scaling is not None 67 | and config.rope_scaling["type"] == "linear" 68 | else 1 69 | ) 70 | self.rope = nn.RoPE( 71 | head_dim, 72 | traditional=config.rope_traditional, 73 | base=config.rope_theta, 74 | scale=rope_scale, 75 | ) 76 | 77 | def __call__( 78 | self, 79 | x: mx.array, 80 | mask: Optional[mx.array] = None, 81 | cache: Optional[Tuple[mx.array, mx.array]] = None, 82 | ) -> mx.array: 83 | B, L, D = x.shape 84 | 85 | queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) 86 | 87 | # Prepare the queries, keys and values for the attention computation 88 | queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) 89 | keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 90 | values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 91 | 92 | if cache is not None: 93 | key_cache, value_cache = cache 94 | queries = self.rope(queries, 
offset=key_cache.shape[2]) 95 | keys = self.rope(keys, offset=key_cache.shape[2]) 96 | keys = mx.concatenate([key_cache, keys], axis=2) 97 | values = mx.concatenate([value_cache, values], axis=2) 98 | else: 99 | queries = self.rope(queries) 100 | keys = self.rope(keys) 101 | 102 | output = mx.fast.scaled_dot_product_attention( 103 | queries, keys, values, scale=self.scale, mask=mask 104 | ) 105 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 106 | return self.o_proj(output), (keys, values) 107 | 108 | 109 | class MLP(nn.Module): 110 | def __init__(self, dim, hidden_dim): 111 | super().__init__() 112 | self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) 113 | self.down_proj = nn.Linear(hidden_dim, dim, bias=False) 114 | self.up_proj = nn.Linear(dim, hidden_dim, bias=False) 115 | 116 | def __call__(self, x) -> mx.array: 117 | return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) 118 | 119 | 120 | class TransformerBlock(nn.Module): 121 | def __init__(self, config: TextConfig): 122 | super().__init__() 123 | self.num_attention_heads = config.num_attention_heads 124 | self.hidden_size = config.hidden_size 125 | self.self_attn = Attention(config) 126 | self.mlp = MLP(config.hidden_size, config.intermediate_size) 127 | self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 128 | self.post_attention_layernorm = nn.RMSNorm( 129 | config.hidden_size, eps=config.rms_norm_eps 130 | ) 131 | self.config = config 132 | 133 | def __call__( 134 | self, 135 | x: mx.array, 136 | mask: Optional[mx.array] = None, 137 | cache: Optional[Tuple[mx.array, mx.array]] = None, 138 | ) -> mx.array: 139 | r, cache = self.self_attn(self.input_layernorm(x), mask, cache) 140 | h = x + r 141 | r = self.mlp(self.post_attention_layernorm(h)) 142 | out = h + r 143 | return out, cache 144 | 145 | 146 | class Llama(nn.Module): 147 | def __init__(self, config: TextConfig): 148 | super().__init__() 149 | self.config = config 150 | self.vocab_size = config.vocab_size 151 | self.num_hidden_layers = config.num_hidden_layers 152 | assert self.vocab_size > 0 153 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) 154 | self.layers = [ 155 | TransformerBlock(config=config) for _ in range(config.num_hidden_layers) 156 | ] 157 | self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 158 | 159 | def __call__( 160 | self, 161 | inputs: mx.array, 162 | cache=None, 163 | inputs_embeds=None, 164 | ): 165 | # for passing merged input embeddings 166 | if inputs_embeds is None: 167 | h = self.embed_tokens(inputs) 168 | else: 169 | h = inputs_embeds 170 | 171 | mask = None 172 | if h.shape[1] > 1: 173 | mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) 174 | mask = mask.astype(h.dtype) 175 | 176 | if cache is None: 177 | cache = [None] * len(self.layers) 178 | 179 | for e, layer in enumerate(self.layers): 180 | h, cache[e] = layer(h, mask, cache[e]) 181 | 182 | return self.norm(h), cache 183 | 184 | 185 | class LanguageModel(nn.Module): 186 | def __init__(self, config: TextConfig): 187 | super().__init__() 188 | self.model_type = config.model_type 189 | if self.model_type != "llama": 190 | raise ValueError( 191 | f"Model type {self.model_type} not supported. 
Currently only 'llama' is supported" 192 | ) 193 | self.model = Llama(config) 194 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 195 | 196 | def __call__( 197 | self, 198 | inputs: mx.array, 199 | cache=None, 200 | inputs_embeds=None, 201 | mask: Optional[mx.array] = None, 202 | ): 203 | out, cache = self.model(inputs, cache, inputs_embeds) 204 | return self.lm_head(out), cache 205 | 206 | @staticmethod 207 | def sanitize(weights): 208 | # Remove unused precomputed rotary freqs 209 | return { 210 | k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k 211 | } 212 | 213 | @property 214 | def layers(self): 215 | return self.model.layers 216 | -------------------------------------------------------------------------------- /mlx_vlm/models/multi_modality/language.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from dataclasses import dataclass 3 | from typing import Dict, Optional, Tuple, Union 4 | 5 | import mlx.core as mx 6 | import mlx.nn as nn 7 | 8 | 9 | @dataclass 10 | class TextConfig: 11 | model_type: str 12 | hidden_size: int = 4096 13 | num_hidden_layers: int = 32 14 | intermediate_size: int = 11008 15 | num_attention_heads: int = 32 16 | rms_norm_eps: float = 1e-6 17 | vocab_size: int = 102400 18 | num_key_value_heads: int = None 19 | rope_theta: float = 10000 20 | rope_traditional: bool = False 21 | rope_scaling: Optional[Dict[str, Union[float, str]]] = None 22 | 23 | @classmethod 24 | def from_dict(cls, params): 25 | return cls( 26 | **{ 27 | k: v 28 | for k, v in params.items() 29 | if k in inspect.signature(cls).parameters 30 | } 31 | ) 32 | 33 | def __post_init__(self): 34 | if self.num_key_value_heads is None: 35 | self.num_key_value_heads = self.num_attention_heads 36 | 37 | if self.rope_scaling: 38 | required_keys = {"factor", "type"} 39 | if not all(key in self.rope_scaling for key in required_keys): 40 | raise ValueError(f"rope_scaling must contain keys {required_keys}") 41 | 42 | if self.rope_scaling["type"] != "linear": 43 | raise ValueError("rope_scaling 'type' currently only supports 'linear'") 44 | 45 | 46 | class Attention(nn.Module): 47 | def __init__(self, config: TextConfig): 48 | super().__init__() 49 | 50 | dim = config.hidden_size 51 | self.n_heads = n_heads = config.num_attention_heads 52 | self.n_kv_heads = n_kv_heads = config.num_key_value_heads 53 | 54 | self.repeats = n_heads // n_kv_heads 55 | 56 | head_dim = config.hidden_size // n_heads 57 | self.scale = head_dim**-0.5 58 | 59 | self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) 60 | self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) 61 | self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) 62 | self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) 63 | 64 | rope_scale = ( 65 | 1 / config.rope_scaling["factor"] 66 | if config.rope_scaling is not None 67 | and config.rope_scaling["type"] == "linear" 68 | else 1 69 | ) 70 | self.rope = nn.RoPE( 71 | head_dim, 72 | traditional=config.rope_traditional, 73 | base=config.rope_theta, 74 | scale=rope_scale, 75 | ) 76 | 77 | def __call__( 78 | self, 79 | x: mx.array, 80 | mask: Optional[mx.array] = None, 81 | cache: Optional[Tuple[mx.array, mx.array]] = None, 82 | ) -> mx.array: 83 | B, L, D = x.shape 84 | 85 | queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) 86 | 87 | # Prepare the queries, keys and values for the attention computation 88 | queries = queries.reshape(B, L, self.n_heads, 
-1).transpose(0, 2, 1, 3) 89 | keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 90 | values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 91 | 92 | if cache is not None: 93 | key_cache, value_cache = cache 94 | queries = self.rope(queries, offset=key_cache.shape[2]) 95 | keys = self.rope(keys, offset=key_cache.shape[2]) 96 | keys = mx.concatenate([key_cache, keys], axis=2) 97 | values = mx.concatenate([value_cache, values], axis=2) 98 | else: 99 | queries = self.rope(queries) 100 | keys = self.rope(keys) 101 | 102 | output = mx.fast.scaled_dot_product_attention( 103 | queries, keys, values, scale=self.scale, mask=mask 104 | ) 105 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 106 | return self.o_proj(output), (keys, values) 107 | 108 | 109 | class MLP(nn.Module): 110 | def __init__(self, dim, hidden_dim): 111 | super().__init__() 112 | self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) 113 | self.down_proj = nn.Linear(hidden_dim, dim, bias=False) 114 | self.up_proj = nn.Linear(dim, hidden_dim, bias=False) 115 | 116 | def __call__(self, x) -> mx.array: 117 | return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) 118 | 119 | 120 | class TransformerBlock(nn.Module): 121 | def __init__(self, config: TextConfig): 122 | super().__init__() 123 | self.num_attention_heads = config.num_attention_heads 124 | self.hidden_size = config.hidden_size 125 | self.self_attn = Attention(config) 126 | self.mlp = MLP(config.hidden_size, config.intermediate_size) 127 | self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 128 | self.post_attention_layernorm = nn.RMSNorm( 129 | config.hidden_size, eps=config.rms_norm_eps 130 | ) 131 | self.config = config 132 | 133 | def __call__( 134 | self, 135 | x: mx.array, 136 | mask: Optional[mx.array] = None, 137 | cache: Optional[Tuple[mx.array, mx.array]] = None, 138 | ) -> mx.array: 139 | r, cache = self.self_attn(self.input_layernorm(x), mask, cache) 140 | h = x + r 141 | r = self.mlp(self.post_attention_layernorm(h)) 142 | out = h + r 143 | return out, cache 144 | 145 | 146 | class Llama(nn.Module): 147 | def __init__(self, config: TextConfig): 148 | super().__init__() 149 | self.config = config 150 | self.vocab_size = config.vocab_size 151 | self.num_hidden_layers = config.num_hidden_layers 152 | assert self.vocab_size > 0 153 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) 154 | self.layers = [ 155 | TransformerBlock(config=config) for _ in range(config.num_hidden_layers) 156 | ] 157 | self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 158 | 159 | def __call__( 160 | self, 161 | inputs: mx.array, 162 | cache=None, 163 | inputs_embeds=None, 164 | ): 165 | # for passing merged input embeddings 166 | if inputs_embeds is None: 167 | h = self.embed_tokens(inputs) 168 | else: 169 | h = inputs_embeds 170 | 171 | mask = None 172 | if h.shape[1] > 1: 173 | mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) 174 | mask = mask.astype(h.dtype) 175 | 176 | if cache is None: 177 | cache = [None] * len(self.layers) 178 | 179 | for e, layer in enumerate(self.layers): 180 | h, cache[e] = layer(h, mask, cache[e]) 181 | 182 | return self.norm(h), cache 183 | 184 | 185 | class LanguageModel(nn.Module): 186 | def __init__(self, config: TextConfig): 187 | super().__init__() 188 | self.model_type = config.model_type 189 | if self.model_type != "llama": 190 | raise ValueError( 191 | f"Model type {self.model_type} not supported. 
Currently only 'llama' is supported" 192 | ) 193 | self.model = Llama(config) 194 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 195 | 196 | def __call__( 197 | self, 198 | inputs: mx.array, 199 | cache=None, 200 | inputs_embeds=None, 201 | mask: Optional[mx.array] = None, 202 | ): 203 | out, cache = self.model(inputs, cache, inputs_embeds) 204 | return self.lm_head(out), cache 205 | 206 | @staticmethod 207 | def sanitize(weights): 208 | # Remove unused precomputed rotary freqs 209 | return { 210 | k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k 211 | } 212 | 213 | @property 214 | def layers(self): 215 | return self.model.layers 216 | -------------------------------------------------------------------------------- /mlx_vlm/models/llava_next/language.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from dataclasses import dataclass 3 | from typing import Dict, Optional, Tuple, Union 4 | 5 | import mlx.core as mx 6 | import mlx.nn as nn 7 | 8 | 9 | @dataclass 10 | class TextConfig: 11 | model_type: str 12 | hidden_size: int = 4096 13 | num_hidden_layers: int = 32 14 | intermediate_size: int = 14336 15 | num_attention_heads: int = 32 16 | rms_norm_eps: float = 1e-05 17 | vocab_size: int = 32064 18 | num_key_value_heads: int = 8 19 | rope_theta: float = 1000000 20 | rope_traditional: bool = False 21 | rope_scaling: Optional[Dict[str, Union[float, str]]] = None 22 | 23 | @classmethod 24 | def from_dict(cls, params): 25 | return cls( 26 | **{ 27 | k: v 28 | for k, v in params.items() 29 | if k in inspect.signature(cls).parameters 30 | } 31 | ) 32 | 33 | def __post_init__(self): 34 | if self.num_key_value_heads is None: 35 | self.num_key_value_heads = self.num_attention_heads 36 | 37 | if self.rope_scaling: 38 | required_keys = {"factor", "type"} 39 | if not all(key in self.rope_scaling for key in required_keys): 40 | raise ValueError(f"rope_scaling must contain keys {required_keys}") 41 | 42 | if self.rope_scaling["type"] != "linear": 43 | raise ValueError("rope_scaling 'type' currently only supports 'linear'") 44 | 45 | 46 | class Attention(nn.Module): 47 | def __init__(self, config: TextConfig): 48 | super().__init__() 49 | 50 | dim = config.hidden_size 51 | self.n_heads = n_heads = config.num_attention_heads 52 | self.n_kv_heads = n_kv_heads = config.num_key_value_heads 53 | 54 | self.repeats = n_heads // n_kv_heads 55 | 56 | head_dim = config.hidden_size // n_heads 57 | self.scale = head_dim**-0.5 58 | 59 | self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) 60 | self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) 61 | self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) 62 | self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) 63 | 64 | rope_scale = ( 65 | 1 / config.rope_scaling["factor"] 66 | if config.rope_scaling is not None 67 | and config.rope_scaling["type"] == "linear" 68 | else 1 69 | ) 70 | self.rope = nn.RoPE( 71 | head_dim, 72 | traditional=config.rope_traditional, 73 | base=config.rope_theta, 74 | scale=rope_scale, 75 | ) 76 | 77 | def __call__( 78 | self, 79 | x: mx.array, 80 | mask: Optional[mx.array] = None, 81 | cache: Optional[Tuple[mx.array, mx.array]] = None, 82 | ) -> mx.array: 83 | B, L, D = x.shape 84 | 85 | queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) 86 | 87 | # Prepare the queries, keys and values for the attention computation 88 | queries = queries.reshape(B, L, self.n_heads, 
-1).transpose(0, 2, 1, 3) 89 | keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 90 | values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 91 | 92 | if cache is not None: 93 | key_cache, value_cache = cache 94 | queries = self.rope(queries, offset=key_cache.shape[2]) 95 | keys = self.rope(keys, offset=key_cache.shape[2]) 96 | keys = mx.concatenate([key_cache, keys], axis=2) 97 | values = mx.concatenate([value_cache, values], axis=2) 98 | else: 99 | queries = self.rope(queries) 100 | keys = self.rope(keys) 101 | 102 | output = mx.fast.scaled_dot_product_attention( 103 | queries, keys, values, scale=self.scale, mask=mask 104 | ) 105 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 106 | return self.o_proj(output), (keys, values) 107 | 108 | 109 | class MLP(nn.Module): 110 | def __init__(self, dim, hidden_dim): 111 | super().__init__() 112 | self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) 113 | self.down_proj = nn.Linear(hidden_dim, dim, bias=False) 114 | self.up_proj = nn.Linear(dim, hidden_dim, bias=False) 115 | 116 | def __call__(self, x) -> mx.array: 117 | return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) 118 | 119 | 120 | class TransformerBlock(nn.Module): 121 | def __init__(self, config: TextConfig): 122 | super().__init__() 123 | self.num_attention_heads = config.num_attention_heads 124 | self.hidden_size = config.hidden_size 125 | self.self_attn = Attention(config) 126 | self.mlp = MLP(config.hidden_size, config.intermediate_size) 127 | self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 128 | self.post_attention_layernorm = nn.RMSNorm( 129 | config.hidden_size, eps=config.rms_norm_eps 130 | ) 131 | self.config = config 132 | 133 | def __call__( 134 | self, 135 | x: mx.array, 136 | mask: Optional[mx.array] = None, 137 | cache: Optional[Tuple[mx.array, mx.array]] = None, 138 | ) -> mx.array: 139 | r, cache = self.self_attn(self.input_layernorm(x), mask, cache) 140 | h = x + r 141 | r = self.mlp(self.post_attention_layernorm(h)) 142 | out = h + r 143 | return out, cache 144 | 145 | 146 | class Llama(nn.Module): 147 | def __init__(self, config: TextConfig): 148 | super().__init__() 149 | self.config = config 150 | self.vocab_size = config.vocab_size 151 | self.num_hidden_layers = config.num_hidden_layers 152 | assert self.vocab_size > 0 153 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) 154 | self.layers = [ 155 | TransformerBlock(config=config) for _ in range(config.num_hidden_layers) 156 | ] 157 | self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 158 | 159 | def __call__( 160 | self, 161 | inputs: mx.array, 162 | cache=None, 163 | inputs_embeds=None, 164 | ): 165 | # for passing merged input embeddings 166 | if inputs_embeds is None: 167 | h = self.embed_tokens(inputs) 168 | else: 169 | h = inputs_embeds 170 | 171 | mask = None 172 | if h.shape[1] > 1: 173 | mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) 174 | mask = mask.astype(h.dtype) 175 | 176 | if cache is None: 177 | cache = [None] * len(self.layers) 178 | 179 | for e, layer in enumerate(self.layers): 180 | h, cache[e] = layer(h, mask, cache[e]) 181 | 182 | return self.norm(h), cache 183 | 184 | 185 | class LanguageModel(nn.Module): 186 | def __init__(self, config: TextConfig): 187 | super().__init__() 188 | self.model_type = config.model_type 189 | if self.model_type not in ["mistral", "llama"]: 190 | raise ValueError( 191 | f"Model type {self.model_type} not 
supported. Currently only 'llama' is supported" 192 | ) 193 | self.model = Llama(config) 194 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 195 | 196 | def __call__( 197 | self, 198 | inputs: mx.array, 199 | cache=None, 200 | inputs_embeds=None, 201 | mask: Optional[mx.array] = None, 202 | ): 203 | out, cache = self.model(inputs, cache, inputs_embeds) 204 | return self.lm_head(out), cache 205 | 206 | @staticmethod 207 | def sanitize(weights): 208 | # Remove unused precomputed rotary freqs 209 | return { 210 | k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k 211 | } 212 | 213 | @property 214 | def layers(self): 215 | return self.model.layers 216 | -------------------------------------------------------------------------------- /mlx_vlm/models/phi3_v/phi3_v.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import math 3 | from dataclasses import dataclass 4 | from types import SimpleNamespace 5 | from typing import Dict, Optional, Tuple, Union 6 | 7 | import mlx.core as mx 8 | import mlx.nn as nn 9 | import numpy as np 10 | 11 | from .language import LanguageModel, TextConfig 12 | from .su_rope import Phi3SuScaledRotaryEmbedding 13 | from .vision import VisionConfig, VisionModel 14 | 15 | 16 | @dataclass 17 | class ModelConfig: 18 | text_config: TextConfig 19 | vision_config: VisionConfig 20 | model_type: str 21 | vocab_size: int 22 | 23 | num_hidden_layers: int 24 | intermediate_size: int 25 | num_attention_heads: int 26 | rms_norm_eps: float 27 | 28 | ignore_index: int = -100 29 | image_token_index: int = 257152 30 | hidden_size: int = 2048 31 | pad_token_id: int = 0 32 | 33 | num_key_value_heads: int = None 34 | rope_theta: float = 10000 35 | rope_traditional: bool = False 36 | rope_scaling: Optional[Dict[str, Union[float, str]]] = None 37 | max_position_embeddings: int = 131072 38 | original_max_position_embeddings: int = 4096 39 | 40 | @classmethod 41 | def from_dict(cls, params): 42 | return cls( 43 | **{ 44 | k: v 45 | for k, v in params.items() 46 | if k in inspect.signature(cls).parameters 47 | } 48 | ) 49 | 50 | 51 | class Attention(nn.Module): 52 | def __init__(self, args: TextConfig): 53 | super().__init__() 54 | 55 | dim = args.hidden_size 56 | self.n_heads = n_heads = args.num_attention_heads 57 | self.n_kv_heads = n_kv_heads = args.num_key_value_heads 58 | self.num_hidden_layers = args.num_hidden_layers 59 | 60 | self.head_dim = head_dim = args.hidden_size // n_heads 61 | self.scale = head_dim**-0.5 62 | 63 | op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim) 64 | self.qkv_proj = nn.Linear(dim, op_size, bias=False) 65 | self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) 66 | 67 | rope_scale = 1.0 68 | if args.rope_scaling and args.rope_scaling["type"] == "su": 69 | self.rope = Phi3SuScaledRotaryEmbedding( 70 | head_dim, 71 | traditional=False, 72 | base=args.rope_theta, 73 | scale=rope_scale, 74 | max_position_embeddings=args.max_position_embeddings, 75 | original_max_position_embeddings=args.original_max_position_embeddings, 76 | short_factor=args.rope_scaling["short_factor"], 77 | long_factor=args.rope_scaling["long_factor"], 78 | ) 79 | else: 80 | if args.rope_scaling and args.rope_scaling["type"] == "linear": 81 | rope_scale = 1 / args.rope_scaling["factor"] 82 | self.rope = nn.RoPE( 83 | head_dim, 84 | traditional=args.rope_traditional, 85 | base=args.rope_theta, 86 | scale=rope_scale, 87 | ) 88 | 89 | def __call__( 90 | self, 91 | x: 
mx.array, 92 | mask: Optional[mx.array] = None, 93 | cache: Optional[Tuple[mx.array, mx.array]] = None, 94 | ) -> mx.array: 95 | B, L, D = x.shape 96 | 97 | qkv = self.qkv_proj(x) 98 | query_pos = self.n_heads * self.head_dim 99 | queries, keys, values = mx.split( 100 | qkv, [query_pos, query_pos + self.n_kv_heads * self.head_dim], axis=-1 101 | ) 102 | 103 | queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) 104 | keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 105 | values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 106 | 107 | if cache is not None: 108 | offset = cache[0].shape[2] 109 | queries = self.rope(queries, offset=offset) 110 | keys = self.rope(keys, offset=offset) 111 | keys = mx.concatenate([cache[0], keys], axis=2) 112 | values = mx.concatenate([cache[1], values], axis=2) 113 | else: 114 | queries = self.rope(queries) 115 | keys = self.rope(keys) 116 | 117 | output = mx.fast.scaled_dot_product_attention( 118 | queries, keys, values, scale=self.scale, mask=mask 119 | ) 120 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 121 | return self.o_proj(output), (keys, values) 122 | 123 | 124 | class MLP(nn.Module): 125 | def __init__(self, dim, hidden_dim): 126 | super().__init__() 127 | self.gate_up_proj = nn.Linear(dim, 2 * hidden_dim, bias=False) 128 | self.down_proj = nn.Linear(hidden_dim, dim, bias=False) 129 | 130 | def __call__(self, x) -> mx.array: 131 | x = self.gate_up_proj(x) 132 | gate, x = mx.split(x, 2, axis=-1) 133 | return self.down_proj(nn.silu(gate) * x) 134 | 135 | 136 | class TransformerBlock(nn.Module): 137 | def __init__(self, args: TextConfig): 138 | super().__init__() 139 | self.num_attention_heads = args.num_attention_heads 140 | self.hidden_size = args.hidden_size 141 | self.self_attn = Attention(args) 142 | self.mlp = MLP(args.hidden_size, args.intermediate_size) 143 | self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 144 | self.post_attention_layernorm = nn.RMSNorm( 145 | args.hidden_size, eps=args.rms_norm_eps 146 | ) 147 | self.args = args 148 | 149 | def __call__( 150 | self, 151 | x: mx.array, 152 | mask: Optional[mx.array] = None, 153 | cache: Optional[Tuple[mx.array, mx.array]] = None, 154 | ) -> mx.array: 155 | r, cache = self.self_attn(self.input_layernorm(x), mask, cache) 156 | h = x + r 157 | r = self.mlp(self.post_attention_layernorm(h)) 158 | out = h + r 159 | return out, cache 160 | 161 | 162 | class Phi3V(nn.Module): 163 | def __init__(self, args: TextConfig): 164 | super().__init__() 165 | self.args = args 166 | self.vocab_size = args.vocab_size 167 | self.num_hidden_layers = args.num_hidden_layers 168 | self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) 169 | self.vision_embed_tokens = VisionModel(args) 170 | self.layers = [ 171 | TransformerBlock(args=args) for _ in range(args.num_hidden_layers) 172 | ] 173 | self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 174 | 175 | def __call__( 176 | self, 177 | inputs: mx.array, 178 | pixel_values=None, 179 | image_sizes=None, 180 | cache=None, 181 | ): 182 | # print('inputs', inputs) # debug 183 | h = self.embed_tokens(inputs) 184 | p = np.argwhere(inputs < 0).tolist() 185 | if pixel_values is not None: 186 | h = self.vision_embed_tokens(pixel_values, h, image_sizes, p) 187 | mask = None 188 | if h.shape[1] > 1: 189 | mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) 190 | mask = mask.astype(h.dtype) 191 | if cache is None: 192 | cache = [None] * 
len(self.layers) 193 | for i, layer in enumerate(self.layers): 194 | h, cache[i] = layer(h, mask, cache[i]) 195 | return self.norm(h), cache 196 | 197 | 198 | class Model(nn.Module): 199 | def __init__(self, args: TextConfig): 200 | super().__init__() 201 | self.model_type = args.model_type 202 | self.model = Phi3V(args) 203 | self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) 204 | self.config = args 205 | 206 | def __call__( 207 | self, 208 | inputs: mx.array, 209 | pixel_values=None, 210 | mask=None, 211 | cache=None, 212 | ): 213 | out, cache = self.model(inputs, pixel_values, mask, cache) 214 | return self.lm_head(out).astype(self.lm_head.weight.dtype), cache 215 | 216 | @property 217 | def layers(self): 218 | return self.model.layers 219 | 220 | @property 221 | def head_dim(self): 222 | return self.args.hidden_size // self.args.num_attention_heads 223 | 224 | @property 225 | def n_kv_heads(self): 226 | return self.args.num_key_value_heads 227 | 228 | @property 229 | def language_model(self): 230 | return self 231 | 232 | @property 233 | def vision_model(self): 234 | return self.model.vision_embed_tokens 235 | -------------------------------------------------------------------------------- /mlx_vlm/models/llava/vision.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import math 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | import numpy as np 9 | 10 | 11 | @dataclass 12 | class VisionConfig: 13 | model_type: str 14 | num_hidden_layers: int = 24 15 | hidden_size: int = 1024 16 | intermediate_size: int = 4096 17 | num_attention_heads: int = 16 18 | image_size: int = 336 19 | patch_size: int = 14 20 | projection_dim: int = 768 21 | vocab_size: int = 32000 22 | num_channels: int = 3 23 | layer_norm_eps: float = 1e-5 24 | 25 | @classmethod 26 | def from_dict(cls, params): 27 | return cls( 28 | **{ 29 | k: v 30 | for k, v in params.items() 31 | if k in inspect.signature(cls).parameters 32 | } 33 | ) 34 | 35 | 36 | def check_array_shape(arr): 37 | shape = arr.shape 38 | 39 | # Check if the shape has 4 dimensions 40 | if len(shape) != 4: 41 | return False 42 | 43 | out_channels, kH, KW, _ = shape 44 | 45 | # Check if out_channels is the largest, and kH and KW are the same 46 | if (out_channels >= kH) and (out_channels >= KW) and (kH == KW): 47 | return True 48 | else: 49 | return False 50 | 51 | 52 | class Attention(nn.Module): 53 | def __init__( 54 | self, 55 | dims: int, 56 | num_heads: int, 57 | query_input_dims: Optional[int] = None, 58 | key_input_dims: Optional[int] = None, 59 | value_input_dims: Optional[int] = None, 60 | value_dims: Optional[int] = None, 61 | value_output_dims: Optional[int] = None, 62 | bias: bool = False, 63 | ): 64 | super().__init__() 65 | 66 | if (dims % num_heads) != 0: 67 | raise ValueError( 68 | "The input feature dimensions should be divisible by the " 69 | f"number of heads ({dims} % {num_heads}) != 0" 70 | ) 71 | 72 | query_input_dims = query_input_dims or dims 73 | key_input_dims = key_input_dims or dims 74 | value_input_dims = value_input_dims or key_input_dims 75 | value_dims = value_dims or dims 76 | value_output_dims = value_output_dims or dims 77 | 78 | self.num_heads = num_heads = num_heads 79 | head_dim = dims // num_heads 80 | self.scale = head_dim**-0.5 81 | 82 | self.q_proj = nn.Linear(query_input_dims, dims, bias=bias) 83 | self.k_proj = nn.Linear(key_input_dims, dims, bias=bias) 84 | 
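        # queries and keys project to dims; values may use a separate value_dims
        # (both default to dims when not specified).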
self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias) 85 | self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias) 86 | 87 | def __call__(self, queries, keys, values, mask=None): 88 | queries = self.q_proj(queries) 89 | keys = self.k_proj(keys) 90 | values = self.v_proj(values) 91 | 92 | num_heads = self.num_heads 93 | B, L, D = queries.shape 94 | _, S, _ = keys.shape 95 | queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) 96 | keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) 97 | values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) 98 | 99 | output = mx.fast.scaled_dot_product_attention( 100 | queries, keys, values, scale=self.scale, mask=mask 101 | ) 102 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 103 | 104 | return self.out_proj(output) 105 | 106 | 107 | class MLP(nn.Module): 108 | def __init__(self, config: VisionConfig): 109 | super().__init__() 110 | self.activation_fn = nn.GELU(approx="fast") 111 | self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) 112 | self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) 113 | 114 | def __call__(self, x: mx.array) -> mx.array: 115 | x = self.activation_fn(self.fc1(x)) 116 | x = self.fc2(x) 117 | return x 118 | 119 | 120 | class EncoderLayer(nn.Module): 121 | def __init__(self, config: VisionConfig): 122 | super().__init__() 123 | self.embed_dim = config.hidden_size 124 | self.self_attn = Attention( 125 | config.hidden_size, config.num_attention_heads, bias=True 126 | ) 127 | self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) 128 | self.mlp = MLP(config) 129 | self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) 130 | 131 | def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array: 132 | y = self.layer_norm1(x) 133 | y = self.self_attn(y, y, y, mask) 134 | x = x + y 135 | y = self.layer_norm2(x) 136 | y = self.mlp(y) 137 | return x + y 138 | 139 | 140 | class Encoder(nn.Module): 141 | def __init__(self, config: VisionConfig): 142 | super().__init__() 143 | self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)] 144 | 145 | 146 | class VisionEmbeddings(nn.Module): 147 | def __init__(self, config: VisionConfig): 148 | super().__init__() 149 | self.config = config 150 | self.embed_dim = config.hidden_size 151 | self.image_size = config.image_size 152 | self.patch_size = config.patch_size 153 | 154 | self.class_embedding = mx.zeros((config.hidden_size,)) 155 | 156 | self.patch_embedding = nn.Conv2d( 157 | in_channels=config.num_channels, 158 | out_channels=self.embed_dim, 159 | kernel_size=self.patch_size, 160 | stride=self.patch_size, 161 | bias=False, 162 | ) 163 | 164 | self.num_patches = (self.image_size // self.patch_size) ** 2 165 | self.num_positions = self.num_patches + 1 166 | self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) 167 | 168 | def __call__(self, x: mx.array) -> mx.array: 169 | batch_size = x.shape[0] 170 | patch_embeddings = self.patch_embedding(x) 171 | patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2) 172 | embed_dim = patch_embeddings.shape[-1] 173 | cls_embeddings = mx.broadcast_to( 174 | self.class_embedding, (batch_size, 1, embed_dim) 175 | ) 176 | position_ids = mx.array(np.arange(self.num_positions)[None, :]) 177 | 178 | embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1) 179 | embeddings += self.position_embedding(position_ids) 180 | return embeddings 181 | 
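        # embeddings: (batch, 1 + num_patches, hidden_size), a learned class token followed
        # by the patch embeddings, with learned position embeddings added.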
182 | 183 | class ClipVisionModel(nn.Module): 184 | def __init__(self, config: VisionConfig): 185 | super().__init__() 186 | self.embeddings = VisionEmbeddings(config) 187 | self.pre_layrnorm = nn.LayerNorm(config.hidden_size) 188 | self.encoder = Encoder(config) 189 | self.post_layernorm = nn.LayerNorm(config.hidden_size) 190 | 191 | def __call__( 192 | self, 193 | x: mx.array, 194 | output_hidden_states: Optional[bool] = None, 195 | ) -> mx.array: 196 | x = self.embeddings(x) 197 | x = self.pre_layrnorm(x) 198 | 199 | encoder_states = (x,) if output_hidden_states else None 200 | 201 | for l in self.encoder.layers: 202 | x = l(x, mask=None) 203 | if output_hidden_states: 204 | encoder_states = encoder_states + (x,) 205 | 206 | pooler_output = self.post_layernorm(x[:, 0, :]) 207 | return pooler_output, x, encoder_states 208 | 209 | 210 | class VisionModel(nn.Module): 211 | def __init__(self, config: VisionConfig): 212 | super().__init__() 213 | 214 | self.model_type = config.model_type 215 | if self.model_type != "clip_vision_model": 216 | raise ValueError(f"Unsupported model type: {self.model_type}") 217 | 218 | self.vision_model = ClipVisionModel(config) 219 | 220 | def __call__( 221 | self, x: mx.array, output_hidden_states: Optional[bool] = None 222 | ) -> mx.array: 223 | return self.vision_model(x, output_hidden_states) 224 | 225 | def sanitize(self, weights): 226 | sanitized_weights = {} 227 | for k, v in weights.items(): 228 | if "position_ids" in k: 229 | # Remove unused position_ids 230 | continue 231 | elif "patch_embedding.weight" in k: 232 | # PyTorch conv2d weight tensors have shape: 233 | # [out_channels, in_channels, kH, KW] 234 | # MLX conv2d expects the weight be of shape: 235 | # [out_channels, kH, KW, in_channels] 236 | if check_array_shape(v): 237 | sanitized_weights[k] = v 238 | else: 239 | sanitized_weights[k] = v.transpose(0, 2, 3, 1) 240 | else: 241 | sanitized_weights[k] = v 242 | 243 | return sanitized_weights 244 | -------------------------------------------------------------------------------- /mlx_vlm/models/llava_next/vision.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import math 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | import numpy as np 9 | 10 | 11 | @dataclass 12 | class VisionConfig: 13 | model_type: str 14 | num_hidden_layers: int = 24 15 | hidden_size: int = 1024 16 | intermediate_size: int = 4096 17 | num_attention_heads: int = 16 18 | image_size: int = 336 19 | patch_size: int = 14 20 | projection_dim: int = 768 21 | vocab_size: int = 32000 22 | num_channels: int = 3 23 | layer_norm_eps: float = 1e-5 24 | 25 | @classmethod 26 | def from_dict(cls, params): 27 | return cls( 28 | **{ 29 | k: v 30 | for k, v in params.items() 31 | if k in inspect.signature(cls).parameters 32 | } 33 | ) 34 | 35 | 36 | def check_array_shape(arr): 37 | shape = arr.shape 38 | 39 | # Check if the shape has 4 dimensions 40 | if len(shape) != 4: 41 | return False 42 | 43 | out_channels, kH, KW, _ = shape 44 | 45 | # Check if out_channels is the largest, and kH and KW are the same 46 | if (out_channels >= kH) and (out_channels >= KW) and (kH == KW): 47 | return True 48 | else: 49 | return False 50 | 51 | 52 | class Attention(nn.Module): 53 | def __init__( 54 | self, 55 | dims: int, 56 | num_heads: int, 57 | query_input_dims: Optional[int] = None, 58 | key_input_dims: Optional[int] = None, 59 | value_input_dims: Optional[int] = None, 
60 | value_dims: Optional[int] = None, 61 | value_output_dims: Optional[int] = None, 62 | bias: bool = False, 63 | ): 64 | super().__init__() 65 | 66 | if (dims % num_heads) != 0: 67 | raise ValueError( 68 | "The input feature dimensions should be divisible by the " 69 | f"number of heads ({dims} % {num_heads}) != 0" 70 | ) 71 | 72 | query_input_dims = query_input_dims or dims 73 | key_input_dims = key_input_dims or dims 74 | value_input_dims = value_input_dims or key_input_dims 75 | value_dims = value_dims or dims 76 | value_output_dims = value_output_dims or dims 77 | 78 | self.num_heads = num_heads = num_heads 79 | head_dim = dims // num_heads 80 | self.scale = head_dim**-0.5 81 | 82 | self.q_proj = nn.Linear(query_input_dims, dims, bias=bias) 83 | self.k_proj = nn.Linear(key_input_dims, dims, bias=bias) 84 | self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias) 85 | self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias) 86 | 87 | def __call__(self, queries, keys, values, mask=None): 88 | queries = self.q_proj(queries) 89 | keys = self.k_proj(keys) 90 | values = self.v_proj(values) 91 | 92 | num_heads = self.num_heads 93 | B, L, D = queries.shape 94 | _, S, _ = keys.shape 95 | queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) 96 | keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) 97 | values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) 98 | 99 | output = mx.fast.scaled_dot_product_attention( 100 | queries, keys, values, scale=self.scale, mask=mask 101 | ) 102 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 103 | 104 | return self.out_proj(output) 105 | 106 | 107 | class MLP(nn.Module): 108 | def __init__(self, config: VisionConfig): 109 | super().__init__() 110 | self.activation_fn = nn.GELU(approx="fast") 111 | self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) 112 | self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) 113 | 114 | def __call__(self, x: mx.array) -> mx.array: 115 | x = self.activation_fn(self.fc1(x)) 116 | x = self.fc2(x) 117 | return x 118 | 119 | 120 | class EncoderLayer(nn.Module): 121 | def __init__(self, config: VisionConfig): 122 | super().__init__() 123 | self.embed_dim = config.hidden_size 124 | self.self_attn = Attention( 125 | config.hidden_size, config.num_attention_heads, bias=True 126 | ) 127 | self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) 128 | self.mlp = MLP(config) 129 | self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) 130 | 131 | def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array: 132 | y = self.layer_norm1(x) 133 | y = self.self_attn(y, y, y, mask) 134 | x = x + y 135 | y = self.layer_norm2(x) 136 | y = self.mlp(y) 137 | return x + y 138 | 139 | 140 | class Encoder(nn.Module): 141 | def __init__(self, config: VisionConfig): 142 | super().__init__() 143 | self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)] 144 | 145 | 146 | class VisionEmbeddings(nn.Module): 147 | def __init__(self, config: VisionConfig): 148 | super().__init__() 149 | self.config = config 150 | self.embed_dim = config.hidden_size 151 | self.image_size = config.image_size 152 | self.patch_size = config.patch_size 153 | 154 | self.class_embedding = mx.zeros((config.hidden_size,)) 155 | 156 | self.patch_embedding = nn.Conv2d( 157 | in_channels=config.num_channels, 158 | out_channels=self.embed_dim, 159 | kernel_size=self.patch_size, 160 | stride=self.patch_size, 161 | 
bias=False, 162 | ) 163 | 164 | self.num_patches = (self.image_size // self.patch_size) ** 2 165 | self.num_positions = self.num_patches + 1 166 | self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) 167 | 168 | def __call__(self, x: mx.array) -> mx.array: 169 | batch_size = x.shape[0] 170 | patch_embeddings = self.patch_embedding(x) 171 | patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2) 172 | embed_dim = patch_embeddings.shape[-1] 173 | cls_embeddings = mx.broadcast_to( 174 | self.class_embedding, (batch_size, 1, embed_dim) 175 | ) 176 | position_ids = mx.array(np.arange(self.num_positions)[None, :]) 177 | 178 | embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1) 179 | embeddings += self.position_embedding(position_ids) 180 | return embeddings 181 | 182 | 183 | class ClipVisionModel(nn.Module): 184 | def __init__(self, config: VisionConfig): 185 | super().__init__() 186 | self.embeddings = VisionEmbeddings(config) 187 | self.pre_layrnorm = nn.LayerNorm(config.hidden_size) 188 | self.encoder = Encoder(config) 189 | self.post_layernorm = nn.LayerNorm(config.hidden_size) 190 | 191 | def __call__( 192 | self, 193 | x: mx.array, 194 | output_hidden_states: Optional[bool] = None, 195 | ) -> mx.array: 196 | x = self.embeddings(x) 197 | x = self.pre_layrnorm(x) 198 | 199 | encoder_states = (x,) if output_hidden_states else None 200 | 201 | for l in self.encoder.layers: 202 | x = l(x, mask=None) 203 | if output_hidden_states: 204 | encoder_states = encoder_states + (x,) 205 | 206 | pooler_output = self.post_layernorm(x[:, 0, :]) 207 | return pooler_output, x, encoder_states 208 | 209 | 210 | class VisionModel(nn.Module): 211 | def __init__(self, config: VisionConfig): 212 | super().__init__() 213 | 214 | self.model_type = config.model_type 215 | if self.model_type != "clip_vision_model": 216 | raise ValueError(f"Unsupported model type: {self.model_type}") 217 | 218 | self.vision_model = ClipVisionModel(config) 219 | 220 | def __call__( 221 | self, x: mx.array, output_hidden_states: Optional[bool] = None 222 | ) -> mx.array: 223 | return self.vision_model(x, output_hidden_states) 224 | 225 | def sanitize(self, weights): 226 | sanitized_weights = {} 227 | for k, v in weights.items(): 228 | if "position_ids" in k: 229 | # Remove unused position_ids 230 | continue 231 | elif "patch_embedding.weight" in k: 232 | # PyTorch conv2d weight tensors have shape: 233 | # [out_channels, in_channels, kH, KW] 234 | # MLX conv2d expects the weight be of shape: 235 | # [out_channels, kH, KW, in_channels] 236 | if check_array_shape(v): 237 | sanitized_weights[k] = v 238 | else: 239 | sanitized_weights[k] = v.transpose(0, 2, 3, 1) 240 | else: 241 | sanitized_weights[k] = v 242 | 243 | return sanitized_weights 244 | -------------------------------------------------------------------------------- /mlx_vlm/models/paligemma/vision.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from dataclasses import dataclass 3 | from typing import Optional 4 | 5 | import mlx.core as mx 6 | import mlx.nn as nn 7 | import numpy as np 8 | 9 | 10 | @dataclass 11 | class VisionConfig: 12 | model_type: str 13 | num_hidden_layers: int 14 | hidden_size: int 15 | intermediate_size: int 16 | num_attention_heads: int 17 | patch_size: int 18 | projection_dim: int 19 | image_size: int = 224 20 | num_channels: int = 3 21 | layer_norm_eps: float = 1e-6 22 | 23 | @classmethod 24 | def from_dict(cls, params): 25 | return cls( 
26 | **{ 27 | k: v 28 | for k, v in params.items() 29 | if k in inspect.signature(cls).parameters 30 | } 31 | ) 32 | 33 | 34 | def check_array_shape(arr): 35 | shape = arr.shape 36 | 37 | # Check if the shape has 4 dimensions 38 | if len(shape) != 4: 39 | return False 40 | 41 | out_channels, kH, KW, _ = shape 42 | 43 | # Check if out_channels is the largest, and kH and KW are the same 44 | if (out_channels >= kH) and (out_channels >= KW) and (kH == KW): 45 | return True 46 | else: 47 | return False 48 | 49 | 50 | class Attention(nn.Module): 51 | def __init__( 52 | self, 53 | dims: int, 54 | num_heads: int, 55 | query_input_dims: Optional[int] = None, 56 | key_input_dims: Optional[int] = None, 57 | value_input_dims: Optional[int] = None, 58 | value_dims: Optional[int] = None, 59 | value_output_dims: Optional[int] = None, 60 | bias: bool = True, 61 | ): 62 | super().__init__() 63 | 64 | if (dims % num_heads) != 0: 65 | raise ValueError( 66 | "The input feature dimensions should be divisible by the " 67 | f"number of heads ({dims} % {num_heads}) != 0" 68 | ) 69 | 70 | query_input_dims = query_input_dims or dims 71 | key_input_dims = key_input_dims or dims 72 | value_input_dims = value_input_dims or key_input_dims 73 | value_dims = value_dims or dims 74 | value_output_dims = value_output_dims or dims 75 | 76 | self.num_heads = num_heads 77 | head_dim = dims // num_heads 78 | self.scale = head_dim**-0.5 79 | 80 | self.q_proj = nn.Linear(query_input_dims, dims, bias=bias) 81 | self.k_proj = nn.Linear(key_input_dims, dims, bias=bias) 82 | self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias) 83 | self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias) 84 | 85 | def __call__(self, x, mask=None): 86 | queries = self.q_proj(x) 87 | keys = self.k_proj(x) 88 | values = self.v_proj(x) 89 | 90 | num_heads = self.num_heads 91 | B, L, D = queries.shape 92 | _, S, _ = keys.shape 93 | queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) 94 | keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) 95 | values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) 96 | 97 | output = mx.fast.scaled_dot_product_attention( 98 | queries, keys, values, scale=self.scale, mask=mask 99 | ) 100 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 101 | return self.out_proj(output) 102 | 103 | 104 | class FastGELUActivation(nn.Module): 105 | """ 106 | Applies GELU approximation that is slower than QuickGELU but more accurate. 
See: https://github.com/hendrycks/GELUs 107 | """ 108 | 109 | def __call__(self, input: mx.array) -> mx.array: 110 | return ( 111 | 0.5 112 | * input 113 | * (1.0 + mx.tanh(np.sqrt(2 / np.pi) * (input + 0.044715 * (input**3)))) 114 | ) 115 | 116 | 117 | class MLP(nn.Module): 118 | def __init__(self, config: VisionConfig): 119 | super().__init__() 120 | self.activation_fn = FastGELUActivation() 121 | self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=True) 122 | self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=True) 123 | 124 | def __call__(self, x: mx.array) -> mx.array: 125 | x = self.fc1(x) 126 | x = self.activation_fn(x) 127 | x = self.fc2(x) 128 | return x 129 | 130 | 131 | class EncoderLayer(nn.Module): 132 | def __init__(self, config: VisionConfig): 133 | super().__init__() 134 | self.embed_dim = config.hidden_size 135 | self.self_attn = Attention( 136 | config.hidden_size, config.num_attention_heads, bias=True 137 | ) 138 | self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) 139 | self.mlp = MLP(config) 140 | self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) 141 | 142 | def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array: 143 | r = self.self_attn(self.layer_norm1(x), mask) 144 | h = x + r 145 | r = self.mlp(self.layer_norm2(h)) 146 | return h + r 147 | 148 | 149 | class Encoder(nn.Module): 150 | def __init__(self, config: VisionConfig): 151 | super().__init__() 152 | self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)] 153 | 154 | def __call__( 155 | self, 156 | x: mx.array, 157 | output_hidden_states: Optional[bool] = None, 158 | mask: Optional[mx.array] = None, 159 | ) -> mx.array: 160 | encoder_states = (x,) if output_hidden_states else None 161 | h = x 162 | for l in self.layers: 163 | x = l(x, mask=mask) 164 | if output_hidden_states: 165 | encoder_states = encoder_states + (x,) 166 | 167 | h = x[0] 168 | 169 | return (h, encoder_states) 170 | 171 | 172 | class VisionEmbeddings(nn.Module): 173 | def __init__(self, config: VisionConfig): 174 | super().__init__() 175 | self.config = config 176 | self.embed_dim = config.hidden_size 177 | self.image_size = config.image_size 178 | self.patch_size = config.patch_size 179 | 180 | self.patch_embedding = nn.Conv2d( 181 | in_channels=config.num_channels, 182 | out_channels=self.embed_dim, 183 | kernel_size=self.patch_size, 184 | stride=self.patch_size, 185 | ) 186 | 187 | self.num_patches = (self.image_size // self.patch_size) ** 2 188 | self.num_positions = self.num_patches 189 | self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) 190 | 191 | def __call__(self, x: mx.array) -> mx.array: 192 | patch_embeddings = self.patch_embedding(x) 193 | patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2) 194 | position_ids = mx.array(np.arange(self.num_positions)[None, :]) 195 | embeddings = patch_embeddings 196 | embeddings += self.position_embedding(position_ids) 197 | return embeddings 198 | 199 | 200 | class SigLipVisionModel(nn.Module): 201 | def __init__(self, config: VisionConfig): 202 | super().__init__() 203 | self.embeddings = VisionEmbeddings(config) 204 | self.encoder = Encoder(config) 205 | self.post_layernorm = nn.LayerNorm(config.hidden_size) 206 | 207 | def __call__( 208 | self, 209 | x: mx.array, 210 | output_hidden_states: Optional[bool] = None, 211 | ) -> mx.array: 212 | x = self.embeddings(x) 213 | 214 | encoder_outputs = self.encoder( 215 | x=x, 
output_hidden_states=output_hidden_states, mask=None 216 | ) 217 | 218 | pooler_output = self.post_layernorm(encoder_outputs[0]) 219 | 220 | return pooler_output, x, encoder_outputs[-1] 221 | 222 | 223 | class VisionModel(nn.Module): 224 | def __init__(self, config: VisionConfig): 225 | super().__init__() 226 | self.model_type = config.model_type 227 | if self.model_type != "siglip_vision_model": 228 | raise ValueError(f"Unsupported model type: {self.model_type}") 229 | 230 | self.vision_model = SigLipVisionModel(config) 231 | 232 | def __call__( 233 | self, x: mx.array, output_hidden_states: Optional[bool] = None 234 | ) -> mx.array: 235 | return self.vision_model(x, output_hidden_states) 236 | 237 | def sanitize(self, weights): 238 | sanitized_weights = {} 239 | for k, v in weights.items(): 240 | if "position_ids" in k: 241 | # Remove unused position_ids 242 | continue 243 | elif "patch_embedding.weight" in k: 244 | # PyTorch conv2d weight tensors have shape: 245 | # [out_channels, in_channels, kH, KW] 246 | # MLX conv2d expects the weight be of shape: 247 | # [out_channels, kH, KW, in_channels] 248 | if check_array_shape(v): 249 | sanitized_weights[k] = v 250 | else: 251 | sanitized_weights[k] = v.transpose(0, 2, 3, 1) 252 | else: 253 | sanitized_weights[k] = v 254 | 255 | return sanitized_weights 256 | -------------------------------------------------------------------------------- /mlx_vlm/models/idefics2/vision.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from dataclasses import dataclass 3 | from typing import Optional 4 | 5 | import mlx.core as mx 6 | import mlx.nn as nn 7 | import numpy as np 8 | 9 | 10 | @dataclass 11 | class VisionConfig: 12 | model_type: str 13 | hidden_size: int 14 | intermediate_size: int 15 | num_hidden_layers: int 16 | num_attention_heads: int 17 | image_size: int 18 | patch_size: int 19 | layer_norm_eps: float 20 | num_channels: int = 3 21 | 22 | @classmethod 23 | def from_dict(cls, params): 24 | return cls( 25 | **{ 26 | k: v 27 | for k, v in params.items() 28 | if k in inspect.signature(cls).parameters 29 | } 30 | ) 31 | 32 | 33 | def check_array_shape(arr): 34 | shape = arr.shape 35 | 36 | # Check if the shape has 4 dimensions 37 | if len(shape) != 4: 38 | return False 39 | 40 | out_channels, kH, KW, _ = shape 41 | 42 | # Check if out_channels is the largest, and kH and KW are the same 43 | if (out_channels >= kH) and (out_channels >= KW) and (kH == KW): 44 | return True 45 | else: 46 | return False 47 | 48 | 49 | class Attention(nn.Module): 50 | def __init__( 51 | self, 52 | dims: int, 53 | num_heads: int, 54 | query_input_dims: Optional[int] = None, 55 | key_input_dims: Optional[int] = None, 56 | value_input_dims: Optional[int] = None, 57 | value_dims: Optional[int] = None, 58 | value_output_dims: Optional[int] = None, 59 | ): 60 | super().__init__() 61 | 62 | if (dims % num_heads) != 0: 63 | raise ValueError( 64 | "The input feature dimensions should be divisible by the " 65 | f"number of heads ({dims} % {num_heads}) != 0" 66 | ) 67 | 68 | query_input_dims = query_input_dims or dims 69 | key_input_dims = key_input_dims or dims 70 | value_input_dims = value_input_dims or key_input_dims 71 | value_dims = value_dims or dims 72 | value_output_dims = value_output_dims or dims 73 | 74 | self.num_heads = num_heads 75 | head_dim = dims // num_heads 76 | self.scale = head_dim**-0.5 77 | 78 | self.q_proj = nn.Linear(query_input_dims, dims, bias=True) 79 | self.k_proj = nn.Linear(key_input_dims, 
dims, bias=True) 80 | self.v_proj = nn.Linear(value_input_dims, value_dims, bias=True) 81 | self.out_proj = nn.Linear(value_dims, value_output_dims, bias=True) 82 | 83 | def __call__(self, x: mx.array, mask=None): 84 | B, L, _ = x.shape 85 | queries = self.q_proj(x) 86 | keys = self.k_proj(x) 87 | values = self.v_proj(x) 88 | 89 | num_heads = self.num_heads 90 | 91 | queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) 92 | keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) 93 | values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) 94 | 95 | output = mx.fast.scaled_dot_product_attention( 96 | queries, keys, values, scale=self.scale, mask=mask 97 | ) 98 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 99 | return self.out_proj(output) 100 | 101 | 102 | class MLP(nn.Module): 103 | def __init__(self, config: VisionConfig): 104 | super().__init__() 105 | self.activation_fn = nn.GELU(approx="fast") 106 | self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) 107 | self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) 108 | 109 | def __call__(self, x: mx.array) -> mx.array: 110 | x = self.activation_fn(self.fc1(x)) 111 | x = self.fc2(x) 112 | return x 113 | 114 | 115 | class EncoderLayer(nn.Module): 116 | def __init__(self, config: VisionConfig): 117 | super().__init__() 118 | self.embed_dim = config.hidden_size 119 | self.self_attn = Attention(config.hidden_size, config.num_attention_heads) 120 | self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) 121 | self.mlp = MLP(config) 122 | self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) 123 | 124 | def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array: 125 | y = self.layer_norm1(x) 126 | y = self.self_attn(y, mask) 127 | x = x + y 128 | y = self.layer_norm2(x) 129 | y = self.mlp(y) 130 | return x + y 131 | 132 | 133 | class Encoder(nn.Module): 134 | def __init__(self, config: VisionConfig): 135 | super().__init__() 136 | self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)] 137 | 138 | def __call__( 139 | self, 140 | x: mx.array, 141 | output_hidden_states: Optional[bool] = None, 142 | mask: Optional[mx.array] = None, 143 | ) -> mx.array: 144 | encoder_states = (x,) if output_hidden_states else None 145 | h = x 146 | for l in self.layers: 147 | x = l(x, mask=mask) 148 | if output_hidden_states: 149 | encoder_states = encoder_states + (x,) 150 | 151 | h = x[0] 152 | 153 | return (h, encoder_states) 154 | 155 | 156 | class VisionEmbeddings(nn.Module): 157 | def __init__(self, config: VisionConfig): 158 | super().__init__() 159 | self.config = config 160 | self.embed_dim = config.hidden_size 161 | self.image_size = config.image_size 162 | self.patch_size = config.patch_size 163 | 164 | self.patch_embedding = nn.Conv2d( 165 | in_channels=config.num_channels, 166 | out_channels=self.embed_dim, 167 | kernel_size=self.patch_size, 168 | stride=self.patch_size, 169 | ) 170 | 171 | self.num_patches = self.image_size // self.patch_size 172 | self.num_positions = self.num_patches**2 173 | self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) 174 | 175 | def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array: 176 | B, H, W, C = x.shape 177 | patch_embeddings = self.patch_embedding(x) 178 | patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2) 179 | max_nb_patches_h, max_nb_patches_w = ( 180 | H // self.patch_size, 181 | W // 
self.patch_size, 182 | ) 183 | boundaries = np.linspace( 184 | 1 / self.num_patches, 1.0, self.num_patches, endpoint=False 185 | ) 186 | position_ids = np.zeros((B, max_nb_patches_h * max_nb_patches_w), dtype=int) 187 | 188 | for batch_idx, p_attn_mask in enumerate(mask): 189 | p_attn_mask = np.array(p_attn_mask) 190 | nb_patches_h = p_attn_mask[:, 0].sum() 191 | nb_patches_w = p_attn_mask[0, :].sum() 192 | 193 | fractional_coords_h = np.linspace(0, 1, nb_patches_h, endpoint=False) 194 | fractional_coords_w = np.linspace(0, 1, nb_patches_w, endpoint=False) 195 | 196 | bucket_coords_h = ( 197 | np.digitize(fractional_coords_h, boundaries, right=True) - 1 198 | ) 199 | bucket_coords_w = ( 200 | np.digitize(fractional_coords_w, boundaries, right=True) - 1 201 | ) 202 | 203 | pos_ids = ( 204 | bucket_coords_h[:, None] * self.num_patches + bucket_coords_w 205 | ).flatten() 206 | position_ids[batch_idx][p_attn_mask.reshape(-1)] = pos_ids 207 | 208 | embeddings = patch_embeddings 209 | embeddings += self.position_embedding(mx.array(position_ids)) 210 | return embeddings 211 | 212 | 213 | class VisionModel(nn.Module): 214 | def __init__(self, config: VisionConfig): 215 | super().__init__() 216 | self.config = config 217 | self.model_type = config.model_type 218 | if self.model_type != "idefics2": 219 | raise ValueError(f"Unsupported model type: {self.model_type}") 220 | self.embeddings = VisionEmbeddings(config) 221 | self.encoder = Encoder(config) 222 | self.post_layernorm = nn.LayerNorm(config.hidden_size) 223 | 224 | def __call__( 225 | self, 226 | x: mx.array, 227 | patch_attention_mask: Optional[mx.array] = None, 228 | output_hidden_states: Optional[bool] = None, 229 | ) -> mx.array: 230 | 231 | B, L, D, C = x.shape 232 | if patch_attention_mask is None: 233 | patch_size = self.config.patch_size 234 | patch_attention_mask = mx.ones( 235 | ( 236 | B, 237 | L // patch_size, 238 | D // patch_size, 239 | ), 240 | dtype=mx.bool_, 241 | ) 242 | 243 | x = self.embeddings(x, mask=patch_attention_mask) 244 | 245 | encoder_outputs = self.encoder(x=x, output_hidden_states=output_hidden_states) 246 | 247 | pooler_output = self.post_layernorm(encoder_outputs[0]) 248 | 249 | return pooler_output, x, encoder_outputs[-1] 250 | 251 | def sanitize(self, weights): 252 | sanitized_weights = {} 253 | for k, v in weights.items(): 254 | if "patch_embedding.weight" in k: 255 | # PyTorch conv2d weight tensors have shape: 256 | # [out_channels, in_channels, kH, KW] 257 | # MLX conv2d expects the weight be of shape: 258 | # [out_channels, kH, KW, in_channels] 259 | if check_array_shape(v): 260 | sanitized_weights[k] = v 261 | else: 262 | sanitized_weights[k] = v.transpose(0, 2, 3, 1) 263 | else: 264 | sanitized_weights[k] = v 265 | 266 | return sanitized_weights 267 | -------------------------------------------------------------------------------- /mlx_vlm/models/nanoLlava/nanoLlava.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import inspect 3 | import json 4 | import re 5 | from dataclasses import dataclass 6 | from functools import partial, reduce 7 | from pathlib import Path 8 | from typing import Dict, Optional 9 | 10 | import mlx.core as mx 11 | import mlx.nn as nn 12 | import numpy as np 13 | from huggingface_hub import snapshot_download 14 | from PIL import Image 15 | from transformers import AutoConfig 16 | from transformers.image_transforms import ( 17 | convert_to_rgb, 18 | normalize, 19 | rescale, 20 | resize, 21 | to_channel_dimension_format, 22 | 
) 23 | from transformers.image_utils import to_numpy_array 24 | 25 | from ..base import BaseImageProcessor 26 | from .language import LanguageModel, TextConfig 27 | from .vision import VisionConfig, VisionModel 28 | 29 | 30 | @dataclass 31 | class ModelConfig: 32 | text_config: TextConfig 33 | vision_config: VisionConfig 34 | model_type: str 35 | auto_map: dict 36 | hidden_size: int 37 | mm_hidden_size: int 38 | mm_vision_tower: str 39 | mm_projector_type: str = "mlp2x_gelu" 40 | ignore_index: int = -100 41 | image_token_index: int = -200 42 | vocab_size: int = 151936 43 | 44 | @classmethod 45 | def from_dict(cls, params): 46 | return cls( 47 | **{ 48 | k: v 49 | for k, v in params.items() 50 | if k in inspect.signature(cls).parameters 51 | } 52 | ) 53 | 54 | 55 | class ImageProcessor(BaseImageProcessor): 56 | def preprocess(self, images): 57 | if isinstance(images, Image.Image): 58 | images = [images] 59 | else: 60 | assert isinstance(images, list) 61 | 62 | transforms = [ 63 | convert_to_rgb, 64 | to_numpy_array, 65 | partial( 66 | resize, 67 | size=self.size, 68 | resample=self.resample, 69 | data_format=self.data_format, 70 | ), 71 | partial(rescale, scale=self.rescale_factor, data_format=self.data_format), 72 | partial( 73 | normalize, 74 | mean=self.image_mean, 75 | std=self.image_std, 76 | data_format=self.data_format, 77 | ), 78 | partial( 79 | to_channel_dimension_format, 80 | channel_dim=self.data_format, 81 | input_channel_dim=self.data_format, 82 | ), 83 | ] 84 | 85 | images = reduce(lambda x, f: [*map(f, x)], transforms, images) 86 | 87 | return images 88 | 89 | 90 | class LlavaMultiModalProjector(nn.Module): 91 | def __init__(self, config: ModelConfig): 92 | super().__init__() 93 | self.linear_1 = nn.Linear( 94 | config.vision_config.hidden_size, config.text_config.hidden_size, bias=True 95 | ) 96 | self.gelu = nn.GELU() 97 | self.linear_2 = nn.Linear( 98 | config.text_config.hidden_size, config.text_config.hidden_size, bias=True 99 | ) 100 | 101 | def __call__(self, x: mx.array) -> mx.array: 102 | x = self.linear_1(x) 103 | x = self.gelu(x) 104 | x = self.linear_2(x) 105 | return x 106 | 107 | 108 | class SigLipVisionTower(nn.Module): 109 | def __init__(self, config: VisionConfig): 110 | super().__init__() 111 | self.vision_tower = VisionModel(config) 112 | 113 | def __call__( 114 | self, x: mx.array, output_hidden_states: Optional[bool] = None 115 | ) -> mx.array: 116 | return self.vision_tower(x, output_hidden_states) 117 | 118 | 119 | class Model(nn.Module): 120 | def __init__(self, config: ModelConfig): 121 | self.model_type = config.model_type 122 | self.config = config 123 | 124 | self.vision_tower = SigLipVisionTower(config.vision_config) 125 | self.language_model = LanguageModel(config.text_config) 126 | self.mm_projector = LlavaMultiModalProjector(config) 127 | 128 | def get_input_embeddings( 129 | self, 130 | input_ids: Optional[mx.array] = None, 131 | pixel_values: Optional[mx.array] = None, 132 | ): 133 | if pixel_values is None: 134 | return self.language_model(input_ids) 135 | 136 | inputs_embeds = self.language_model.model.embed_tokens(input_ids) 137 | 138 | *_, hidden_state = self.vision_tower( 139 | pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True 140 | ) 141 | 142 | image_features = hidden_state[-1].astype(pixel_values.dtype) 143 | assert image_features.shape[-2] == 729 144 | 145 | image_features = self.mm_projector(image_features) 146 | 147 | final_inputs_embeds = self._prepare_inputs_for_multimodal( 148 | image_features, inputs_embeds, 
input_ids 149 | ) 150 | return final_inputs_embeds 151 | 152 | def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_ids): 153 | image_token_index = self.config.image_token_index 154 | num_images, num_image_patches, embed_dim = image_features.shape 155 | 156 | # Positions of tokens in input_ids, assuming batch size is 1 157 | image_positions = np.where(input_ids[0] == image_token_index)[0].tolist() 158 | 159 | if len(image_positions) != num_images: 160 | raise ValueError( 161 | f"The number of image tokens ({len(image_positions)}) does not " 162 | f" match the number of image inputs ({num_images})." 163 | ) 164 | 165 | text_segments = [] 166 | start_idx = 0 167 | 168 | for position in image_positions: 169 | text_segments.append(inputs_embeds[:, start_idx:position]) 170 | start_idx = position + 1 171 | 172 | image_embeddings = mx.split(image_features, image_features.shape[0]) 173 | final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p] 174 | final_embeddings += [inputs_embeds[:, start_idx:]] 175 | 176 | # Create a final embedding of shape 177 | # (1, num_image_patches*num_images + sequence_len, embed_dim) 178 | return mx.concatenate(final_embeddings, axis=1) 179 | 180 | def __call__( 181 | self, 182 | input_ids: mx.array, 183 | pixel_values: mx.array, 184 | mask: Optional[mx.array] = None, 185 | cache=None, 186 | ): 187 | input_embeddings = self.get_input_embeddings(input_ids, pixel_values) 188 | logits, cache = self.language_model( 189 | inputs=input_ids, cache=cache, inputs_embeds=input_embeddings 190 | ) 191 | return logits, cache 192 | 193 | @staticmethod 194 | def from_pretrained(path_or_hf_repo: str): 195 | path = Path(path_or_hf_repo) 196 | if not path.exists(): 197 | path = Path( 198 | snapshot_download( 199 | repo_id=path_or_hf_repo, 200 | allow_patterns=[ 201 | "*.json", 202 | "*.safetensors", 203 | "*.py", 204 | "tokenizer.model", 205 | "*.tiktoken", 206 | ], 207 | ) 208 | ) 209 | 210 | with open(path / "config.json", "r") as f: 211 | config = json.load(f) 212 | 213 | siglip_config = AutoConfig.from_pretrained(config["mm_vision_tower"]) 214 | text_config = AutoConfig.from_pretrained(config["language_model"]) 215 | siglip_config = siglip_config.to_dict() 216 | text_config = text_config.to_dict() 217 | config["vision_config"] = siglip_config["vision_config"] 218 | config["text_config"] = text_config 219 | 220 | model_config = ModelConfig.from_dict(config) 221 | model_config.vision_config = VisionConfig.from_dict(config["vision_config"]) 222 | model_config.text_config = TextConfig.from_dict(config["text_config"]) 223 | 224 | model = Model(model_config) 225 | weight_files = glob.glob(str(path / "*.safetensors")) 226 | if not weight_files: 227 | raise FileNotFoundError(f"No safetensors found in {path}") 228 | 229 | weights = {} 230 | for wf in weight_files: 231 | weights.update(mx.load(wf)) 232 | 233 | weights = model.sanitize(weights=weights) 234 | 235 | weights = VisionModel(model_config.vision_config).sanitize(weights=weights) 236 | weights = LanguageModel(model_config.text_config).sanitize(weights=weights) 237 | model.load_weights(list(weights.items())) 238 | return model 239 | 240 | def sanitize(self, weights): 241 | weights = { 242 | ( 243 | f"{k.split('.', 1)[1]}" 244 | if re.match(r"^model\.vision_tower", k) 245 | else ( 246 | f"mm_projector.linear_1.{k.split('.')[-1]}" 247 | if re.match(r"^model\.mm_projector\.0", k) 248 | else ( 249 | f"mm_projector.linear_2.{k.split('.')[-1]}" 250 | if re.match(r"^model\.mm_projector\.2", k) 
251 | else ( 252 | f"language_model.model.{k}" 253 | if re.match(r"^lm_head", k) 254 | else ( 255 | f"language_model.{k}" 256 | if re.match(r"^model\.(embed_tokens|norm|layers)", k) 257 | else k 258 | ) 259 | ) 260 | ) 261 | ) 262 | ): v 263 | for k, v in weights.items() 264 | } 265 | 266 | weights = { 267 | ( 268 | f"vision_tower.vision_tower.vision_model.head.attention.in_proj.bias" 269 | if re.match( 270 | r"^vision_tower\.vision_tower\.vision_model\.head\.attention\.in_proj_bias", 271 | k, 272 | ) 273 | else ( 274 | f"vision_tower.vision_tower.vision_model.head.attention.in_proj.weight" 275 | if re.match( 276 | r"^vision_tower\.vision_tower\.vision_model\.head\.attention\.in_proj_weight", 277 | k, 278 | ) 279 | else k 280 | ) 281 | ): v 282 | for k, v in weights.items() 283 | } 284 | 285 | return weights 286 | -------------------------------------------------------------------------------- /mlx_vlm/models/nanoLlava/vision.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from dataclasses import dataclass 3 | from typing import Optional 4 | 5 | import mlx.core as mx 6 | import mlx.nn as nn 7 | import numpy as np 8 | 9 | 10 | @dataclass 11 | class VisionConfig: 12 | model_type: str 13 | num_hidden_layers: int = 27 14 | hidden_size: int = 1152 15 | intermediate_size: int = 4304 16 | num_attention_heads: int = 16 17 | image_size: int = 384 18 | patch_size: int = 14 19 | projection_dim: int = 768 20 | vocab_size: int = 32000 21 | num_channels: int = 3 22 | layer_norm_eps: float = 1e-6 23 | 24 | @classmethod 25 | def from_dict(cls, params): 26 | return cls( 27 | **{ 28 | k: v 29 | for k, v in params.items() 30 | if k in inspect.signature(cls).parameters 31 | } 32 | ) 33 | 34 | 35 | def check_array_shape(arr): 36 | shape = arr.shape 37 | 38 | # Check if the shape has 4 dimensions 39 | if len(shape) != 4: 40 | return False 41 | 42 | out_channels, kH, KW, _ = shape 43 | 44 | # Check if out_channels is the largest, and kH and KW are the same 45 | if (out_channels >= kH) and (out_channels >= KW) and (kH == KW): 46 | return True 47 | else: 48 | return False 49 | 50 | 51 | class Attention(nn.Module): 52 | def __init__( 53 | self, 54 | dims: int, 55 | num_heads: int, 56 | query_input_dims: Optional[int] = None, 57 | key_input_dims: Optional[int] = None, 58 | value_input_dims: Optional[int] = None, 59 | value_dims: Optional[int] = None, 60 | value_output_dims: Optional[int] = None, 61 | bias: bool = False, 62 | ): 63 | super().__init__() 64 | 65 | if (dims % num_heads) != 0: 66 | raise ValueError( 67 | "The input feature dimensions should be divisible by the " 68 | f"number of heads ({dims} % {num_heads}) != 0" 69 | ) 70 | 71 | query_input_dims = query_input_dims or dims 72 | key_input_dims = key_input_dims or dims 73 | value_input_dims = value_input_dims or key_input_dims 74 | value_dims = value_dims or dims 75 | value_output_dims = value_output_dims or dims 76 | 77 | self.num_heads = num_heads 78 | head_dim = dims // num_heads 79 | self.scale = head_dim**-0.5 80 | 81 | self.q_proj = nn.Linear(query_input_dims, dims, bias=bias) 82 | self.k_proj = nn.Linear(key_input_dims, dims, bias=bias) 83 | self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias) 84 | self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias) 85 | 86 | def __call__(self, queries, keys, values, mask=None): 87 | queries = self.q_proj(queries) 88 | keys = self.k_proj(keys) 89 | values = self.v_proj(values) 90 | 91 | num_heads = self.num_heads 92 | B, L, D = 
queries.shape 93 | _, S, _ = keys.shape 94 | queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) 95 | keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) 96 | values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) 97 | 98 | output = mx.fast.scaled_dot_product_attention( 99 | queries, keys, values, scale=self.scale, mask=mask 100 | ) 101 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 102 | return self.out_proj(output) 103 | 104 | 105 | class MHA(nn.Module): 106 | def __init__( 107 | self, 108 | dims: int, 109 | num_heads: int, 110 | bias: bool = False, 111 | ): 112 | super().__init__() 113 | 114 | if (dims % num_heads) != 0: 115 | raise ValueError( 116 | "The input feature dimensions should be divisible by the " 117 | f"number of heads ({dims} % {num_heads}) != 0" 118 | ) 119 | 120 | self.num_heads = num_heads 121 | head_dim = dims // num_heads 122 | self.scale = head_dim**-0.5 123 | 124 | self.in_proj = nn.Linear(dims, dims * 3, bias=bias) 125 | self.out_proj = nn.Linear(dims, dims, bias=bias) 126 | 127 | def __call__(self, queries: mx.array, kv: mx.array, mask=None, cache=None): 128 | B, L, D = queries.shape 129 | 130 | qkv = self.in_proj(queries) 131 | _, keys, values = mx.split(qkv, 3, axis=-1) 132 | 133 | num_heads = self.num_heads 134 | B, L, D = queries.shape 135 | _, S, _ = keys.shape 136 | queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) 137 | keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) 138 | values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) 139 | 140 | output = mx.fast.scaled_dot_product_attention( 141 | queries, keys, values, scale=self.scale, mask=mask 142 | ) 143 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 144 | return self.out_proj(output) 145 | 146 | 147 | class MLP(nn.Module): 148 | def __init__(self, config: VisionConfig): 149 | super().__init__() 150 | self.activation_fn = nn.GELU(approx="fast") 151 | self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) 152 | self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) 153 | 154 | def __call__(self, x: mx.array) -> mx.array: 155 | x = self.activation_fn(self.fc1(x)) 156 | x = self.fc2(x) 157 | return x 158 | 159 | 160 | class EncoderLayer(nn.Module): 161 | def __init__(self, config: VisionConfig): 162 | super().__init__() 163 | self.embed_dim = config.hidden_size 164 | self.self_attn = Attention( 165 | config.hidden_size, config.num_attention_heads, bias=True 166 | ) 167 | self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) 168 | self.mlp = MLP(config) 169 | self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) 170 | 171 | def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array: 172 | y = self.layer_norm1(x) 173 | y = self.self_attn(y, y, y, mask) 174 | x = x + y 175 | y = self.layer_norm2(x) 176 | y = self.mlp(y) 177 | return x + y 178 | 179 | 180 | class Encoder(nn.Module): 181 | def __init__(self, config: VisionConfig): 182 | super().__init__() 183 | self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)] 184 | 185 | 186 | class VisionEmbeddings(nn.Module): 187 | def __init__(self, config: VisionConfig): 188 | super().__init__() 189 | self.config = config 190 | self.embed_dim = config.hidden_size 191 | self.image_size = config.image_size 192 | self.patch_size = config.patch_size 193 | 194 | self.patch_embedding = nn.Conv2d( 195 | in_channels=config.num_channels, 196 | out_channels=self.embed_dim, 
197 | kernel_size=self.patch_size, 198 | stride=self.patch_size, 199 | bias=True, 200 | ) 201 | 202 | self.num_patches = (self.image_size // self.patch_size) ** 2 203 | self.num_positions = self.num_patches 204 | self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) 205 | 206 | def __call__(self, x: mx.array) -> mx.array: 207 | batch_size = x.shape[0] 208 | patch_embeddings = self.patch_embedding(x) 209 | patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2) 210 | self.position_ids = mx.array(np.arange(self.num_positions)[None, :]) 211 | embeddings = patch_embeddings 212 | embeddings += self.position_embedding(self.position_ids) 213 | return embeddings 214 | 215 | 216 | class SigLipVisionModel(nn.Module): 217 | def __init__(self, config: VisionConfig): 218 | super().__init__() 219 | self.embeddings = VisionEmbeddings(config) 220 | self.encoder = Encoder(config) 221 | self.post_layernorm = nn.LayerNorm(config.hidden_size) 222 | self.head = SigLipMultiheadAttentionPoolingHead(config) 223 | 224 | def __call__( 225 | self, 226 | x: mx.array, 227 | output_hidden_states: Optional[bool] = None, 228 | ) -> mx.array: 229 | x = self.embeddings(x) 230 | 231 | encoder_states = (x,) if output_hidden_states else None 232 | 233 | for l in self.encoder.layers: 234 | x = l(x, mask=None) 235 | if output_hidden_states: 236 | encoder_states = encoder_states + (x,) 237 | 238 | pooler_output = self.post_layernorm(x[:, 0, :]) 239 | pooler_output = self.head(pooler_output) 240 | return pooler_output, x, encoder_states 241 | 242 | 243 | class SigLipMultiheadAttentionPoolingHead(nn.Module): 244 | 245 | def __init__(self, config: VisionConfig): 246 | super().__init__() 247 | 248 | self.probe = mx.ones( 249 | ( 250 | 1, 251 | 1, 252 | config.hidden_size, 253 | ) 254 | ) 255 | self.attention = MHA( 256 | config.hidden_size, num_heads=config.num_attention_heads, bias=True 257 | ) 258 | self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 259 | self.mlp = MLP(config) 260 | 261 | def __call__(self, x: mx.array): 262 | x = self.attention(self.probe, x)[0] 263 | 264 | residual = x 265 | x = self.layernorm(x) 266 | x = residual + self.mlp(x) 267 | 268 | return x[:, 0] 269 | 270 | 271 | class VisionModel(nn.Module): 272 | def __init__(self, config: VisionConfig): 273 | super().__init__() 274 | self.model_type = config.model_type 275 | if self.model_type != "siglip_vision_model": 276 | raise ValueError(f"Unsupported model type: {self.model_type}") 277 | 278 | self.vision_model = SigLipVisionModel(config) 279 | 280 | def __call__( 281 | self, x: mx.array, output_hidden_states: Optional[bool] = None 282 | ) -> mx.array: 283 | return self.vision_model(x, output_hidden_states) 284 | 285 | def sanitize(self, weights): 286 | sanitized_weights = {} 287 | for k, v in weights.items(): 288 | if "position_ids" in k: 289 | # Remove unused position_ids 290 | continue 291 | elif "patch_embedding.weight" in k: 292 | # PyTorch conv2d weight tensors have shape: 293 | # [out_channels, in_channels, kH, KW] 294 | # MLX conv2d expects the weight be of shape: 295 | # [out_channels, kH, KW, in_channels] 296 | if check_array_shape(v): 297 | sanitized_weights[k] = v 298 | else: 299 | sanitized_weights[k] = v.transpose(0, 2, 3, 1) 300 | else: 301 | sanitized_weights[k] = v 302 | 303 | return sanitized_weights 304 | -------------------------------------------------------------------------------- /mlx_vlm/tokenizer_utils.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | from functools import partial 3 | 4 | from transformers import AutoTokenizer 5 | 6 | REPLACEMENT_CHAR = "\ufffd" 7 | 8 | 9 | def _remove_space(x): 10 | if x and x[0] == " ": 11 | return x[1:] 12 | return x 13 | 14 | 15 | class StreamingDetokenizer: 16 | """The streaming detokenizer interface so that we can detokenize one token at a time. 17 | 18 | Example usage is as follows: 19 | 20 | detokenizer = ... 21 | 22 | # Reset the tokenizer state 23 | detokenizer.reset() 24 | 25 | for token in generate(...): 26 | detokenizer.add_token(token.item()) 27 | 28 | # Contains the whole text so far. Some tokens may not be included 29 | # since it contains whole words usually. 30 | detokenizer.text 31 | 32 | # Contains the printable segment (usually a word) since the last 33 | # time it was accessed 34 | detokenizer.last_segment 35 | 36 | # Contains all the tokens added so far 37 | detokenizer.tokens 38 | 39 | # Make sure that we detokenize any remaining tokens 40 | detokenizer.finalize() 41 | 42 | # Now detokenizer.text should match tokenizer.decode(detokenizer.tokens) 43 | """ 44 | 45 | __slots__ = ("text", "tokens", "offset") 46 | 47 | def reset(self): 48 | raise NotImplementedError() 49 | 50 | def add_token(self, token): 51 | raise NotImplementedError() 52 | 53 | def finalize(self): 54 | raise NotImplementedError() 55 | 56 | @property 57 | def last_segment(self): 58 | """Return the last segment of readable text since last time this property was accessed.""" 59 | text = self.text 60 | if text and text[-1] != REPLACEMENT_CHAR: 61 | segment = text[self.offset :] 62 | self.offset = len(text) 63 | return segment 64 | return "" 65 | 66 | 67 | class NaiveStreamingDetokenizer(StreamingDetokenizer): 68 | """NaiveStreamingDetokenizer relies on the underlying tokenizer 69 | implementation and should work with every tokenizer. 70 | 71 | Its complexity is O(T^2) where T is the longest line since it will 72 | repeatedly detokenize the same tokens until a new line is generated. 73 | """ 74 | 75 | def __init__(self, tokenizer): 76 | self._tokenizer = tokenizer 77 | self._tokenizer.decode([0]) 78 | self.reset() 79 | 80 | def reset(self): 81 | self.offset = 0 82 | self._tokens = [] 83 | self._text = "" 84 | self._current_tokens = [] 85 | self._current_text = "" 86 | 87 | def add_token(self, token): 88 | self._current_tokens.append(token) 89 | 90 | def finalize(self): 91 | self._tokens.extend(self._current_tokens) 92 | self._text += self._tokenizer.decode(self._current_tokens) 93 | self._current_tokens = [] 94 | self._current_text = "" 95 | 96 | @property 97 | def text(self): 98 | if self._current_tokens: 99 | self._current_text = self._tokenizer.decode(self._current_tokens) 100 | if self._current_text and self._current_text[-1] == "\n": 101 | self._tokens.extend(self._current_tokens) 102 | self._text += self._current_text 103 | self._current_tokens.clear() 104 | self._current_text = "" 105 | return self._text + self._current_text 106 | 107 | @property 108 | def tokens(self): 109 | return self._tokens 110 | 111 | 112 | class SPMStreamingDetokenizer(StreamingDetokenizer): 113 | """A streaming detokenizer for SPM models. 114 | 115 | It adds tokens to the text if the next token starts with the special SPM 116 | underscore which results in linear complexity. 
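    For example, with the default trim_space=True, feeding the SPM pieces
    "\u2581Hello", "\u2581wor", "ld" leaves text == "Hello" after the second
    token and text == "Hello world" once finalize() is called.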
117 | """ 118 | 119 | def __init__(self, tokenizer, trim_space=True): 120 | self.trim_space = trim_space 121 | 122 | # Extract the tokens in a list from id to text 123 | self.tokenmap = [None] * len(tokenizer.vocab) 124 | for value, tokenid in tokenizer.vocab.items(): 125 | self.tokenmap[tokenid] = value 126 | 127 | # Replace bytes with their value 128 | for i in range(len(self.tokenmap)): 129 | if self.tokenmap[i].startswith("<0x"): 130 | self.tokenmap[i] = chr(int(self.tokenmap[i][3:5], 16)) 131 | 132 | self.reset() 133 | 134 | def reset(self): 135 | self.offset = 0 136 | self._unflushed = "" 137 | self.text = "" 138 | self.tokens = [] 139 | 140 | def add_token(self, token): 141 | v = self.tokenmap[token] 142 | if v[0] == "\u2581": 143 | if self.text or not self.trim_space: 144 | self.text += self._unflushed.replace("\u2581", " ") 145 | else: 146 | self.text = _remove_space(self._unflushed.replace("\u2581", " ")) 147 | self._unflushed = v 148 | else: 149 | self._unflushed += v 150 | 151 | def finalize(self): 152 | if self.text or not self.trim_space: 153 | self.text += self._unflushed.replace("\u2581", " ") 154 | else: 155 | self.text = _remove_space(self._unflushed.replace("\u2581", " ")) 156 | self._unflushed = "" 157 | 158 | 159 | class BPEStreamingDetokenizer(StreamingDetokenizer): 160 | """A streaming detokenizer for OpenAI style BPE models. 161 | 162 | It adds tokens to the text if the next token starts with a space similar to 163 | the SPM detokenizer. 164 | """ 165 | 166 | _byte_decoder = None 167 | 168 | def __init__(self, tokenizer, trim_space=False): 169 | self.trim_space = trim_space 170 | 171 | # Extract the tokens in a list from id to text 172 | self.tokenmap = [None] * len(tokenizer.vocab) 173 | for value, tokenid in tokenizer.vocab.items(): 174 | self.tokenmap[tokenid] = value 175 | 176 | self.reset() 177 | 178 | # Make the BPE byte decoder from 179 | # https://github.com/openai/gpt-2/blob/master/src/encoder.py 180 | self.make_byte_decoder() 181 | 182 | def reset(self): 183 | self.offset = 0 184 | self._unflushed = "" 185 | self.text = "" 186 | self.tokens = [] 187 | 188 | def add_token(self, token): 189 | v = self.tokenmap[token] 190 | # if the token starts with space 191 | if self._byte_decoder[v[0]] == 32: 192 | current_text = bytearray( 193 | self._byte_decoder[c] for c in self._unflushed 194 | ).decode("utf-8") 195 | if self.text or not self.trim_space: 196 | self.text += current_text 197 | else: 198 | self.text += _remove_space(current_text) 199 | self._unflushed = v 200 | else: 201 | self._unflushed += v 202 | 203 | def finalize(self): 204 | current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode( 205 | "utf-8" 206 | ) 207 | if self.text or not self.trim_space: 208 | self.text += current_text 209 | else: 210 | self.text += _remove_space(current_text) 211 | self._unflushed = "" 212 | 213 | @classmethod 214 | def make_byte_decoder(cls): 215 | """See https://github.com/openai/gpt-2/blob/master/src/encoder.py for the rationale.""" 216 | if cls._byte_decoder is not None: 217 | return 218 | 219 | char_to_bytes = {} 220 | limits = [ 221 | 0, 222 | ord("!"), 223 | ord("~") + 1, 224 | ord("¡"), 225 | ord("¬") + 1, 226 | ord("®"), 227 | ord("ÿ") + 1, 228 | ] 229 | n = 0 230 | for i, (start, stop) in enumerate(zip(limits, limits[1:])): 231 | if i % 2 == 0: 232 | for b in range(start, stop): 233 | char_to_bytes[chr(2**8 + n)] = b 234 | n += 1 235 | else: 236 | for b in range(start, stop): 237 | char_to_bytes[chr(b)] = b 238 | cls._byte_decoder = 
char_to_bytes 239 | 240 | 241 | class TokenizerWrapper: 242 | """A wrapper that combines an HF tokenizer and a detokenizer. 243 | 244 | Accessing any attribute other than the ``detokenizer`` is forwarded to the 245 | huggingface tokenizer. 246 | """ 247 | 248 | def __init__(self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer): 249 | self._tokenizer = tokenizer 250 | self._detokenizer = detokenizer_class(tokenizer) 251 | 252 | def __getattr__(self, attr): 253 | if attr == "detokenizer": 254 | return self._detokenizer 255 | else: 256 | return getattr(self._tokenizer, attr) 257 | 258 | 259 | def _match(a, b): 260 | if type(a) != type(b): 261 | return False 262 | if isinstance(a, dict): 263 | return len(a) == len(b) and all(k in b and _match(a[k], b[k]) for k in a) 264 | if isinstance(a, list): 265 | return len(a) == len(b) and all(_match(ai, bi) for ai, bi in zip(a, b)) 266 | 267 | return a == b 268 | 269 | 270 | def _is_spm_decoder(decoder): 271 | _target_description = { 272 | "type": "Sequence", 273 | "decoders": [ 274 | {"type": "Replace", "pattern": {"String": "▁"}, "content": " "}, 275 | {"type": "ByteFallback"}, 276 | {"type": "Fuse"}, 277 | {"type": "Strip", "content": " ", "start": 1, "stop": 0}, 278 | ], 279 | } 280 | return _match(_target_description, decoder) 281 | 282 | 283 | def _is_spm_decoder_no_space(decoder): 284 | _target_description = { 285 | "type": "Sequence", 286 | "decoders": [ 287 | {"type": "Replace", "pattern": {"String": "▁"}, "content": " "}, 288 | {"type": "ByteFallback"}, 289 | {"type": "Fuse"}, 290 | ], 291 | } 292 | return _match(_target_description, decoder) 293 | 294 | 295 | def _is_bpe_decoder(decoder): 296 | _target_description = { 297 | "type": "ByteLevel", 298 | "add_prefix_space": False, 299 | "trim_offsets": False, 300 | "use_regex": False, 301 | } 302 | 303 | return _match(_target_description, decoder) 304 | 305 | 306 | def load_tokenizer(model_path, return_tokenizer=True, tokenizer_config_extra={}): 307 | """Load a huggingface tokenizer and try to infer the type of streaming 308 | detokenizer to use. 309 | 310 | Note, to use a fast streaming tokenizer, pass a local file path rather than 311 | a Hugging Face repo ID. 
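    A hedged usage sketch (the local path is illustrative):

        from pathlib import Path
        tokenizer = load_tokenizer(Path("path/to/local/model"))
        detokenizer = tokenizer.detokenizer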
312 | """ 313 | detokenizer_class = NaiveStreamingDetokenizer 314 | 315 | tokenizer_file = model_path / "tokenizer.json" 316 | if tokenizer_file.exists(): 317 | tokenizer_content = json.load(tokenizer_file.open()) 318 | if "decoder" in tokenizer_content: 319 | if _is_spm_decoder(tokenizer_content["decoder"]): 320 | detokenizer_class = SPMStreamingDetokenizer 321 | elif _is_spm_decoder_no_space(tokenizer_content["decoder"]): 322 | detokenizer_class = partial(SPMStreamingDetokenizer, trim_space=False) 323 | elif _is_bpe_decoder(tokenizer_content["decoder"]): 324 | detokenizer_class = BPEStreamingDetokenizer 325 | 326 | if return_tokenizer: 327 | return TokenizerWrapper( 328 | AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra), 329 | detokenizer_class, 330 | ) 331 | else: 332 | return detokenizer_class 333 | -------------------------------------------------------------------------------- /mlx_vlm/models/phi3_v/vision.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import math 3 | from dataclasses import dataclass 4 | from types import SimpleNamespace 5 | from typing import Optional 6 | 7 | import mlx.core as mx 8 | import mlx.nn as nn 9 | import numpy as np 10 | 11 | 12 | @dataclass 13 | class VisionConfig: 14 | model_type: str = "phi3_v" 15 | num_hidden_layers: int = 24 16 | hidden_size: int = 1024 17 | intermediate_size: int = 4096 18 | num_attention_heads: int = 16 19 | image_size: int = 336 20 | patch_size: int = 14 21 | projection_dim: int = 768 22 | vocab_size: int = 32000 23 | num_channels: int = 3 24 | layer_norm_eps: float = 1e-5 25 | image_dim_out: int = (1024,) 26 | model_name: str = "openai/clip-vit-large-patch14-336" 27 | name: str = "clip_vision_model" 28 | num_img_tokens: int = 144 29 | 30 | @classmethod 31 | def from_dict(cls, params): 32 | return cls( 33 | **{ 34 | k: v 35 | for k, v in params.items() 36 | if k in inspect.signature(cls).parameters 37 | } 38 | ) 39 | 40 | 41 | def check_array_shape(arr): 42 | shape = arr.shape 43 | 44 | # Check if the shape has 4 dimensions 45 | if len(shape) != 4: 46 | return False 47 | 48 | out_channels, kH, KW, _ = shape 49 | 50 | # Check if out_channels is the largest, and kH and KW are the same 51 | if (out_channels >= kH) and (out_channels >= KW) and (kH == KW): 52 | return True 53 | else: 54 | return False 55 | 56 | 57 | class Attention(nn.Module): 58 | def __init__( 59 | self, 60 | dims: int, 61 | num_heads: int, 62 | query_input_dims: Optional[int] = None, 63 | key_input_dims: Optional[int] = None, 64 | value_input_dims: Optional[int] = None, 65 | value_dims: Optional[int] = None, 66 | value_output_dims: Optional[int] = None, 67 | bias: bool = False, 68 | ): 69 | super().__init__() 70 | 71 | if (dims % num_heads) != 0: 72 | raise ValueError( 73 | "The input feature dimensions should be divisible by the " 74 | f"number of heads ({dims} % {num_heads}) != 0" 75 | ) 76 | 77 | query_input_dims = query_input_dims or dims 78 | key_input_dims = key_input_dims or dims 79 | value_input_dims = value_input_dims or key_input_dims 80 | value_dims = value_dims or dims 81 | value_output_dims = value_output_dims or dims 82 | 83 | self.num_heads = num_heads = num_heads 84 | head_dim = dims // num_heads 85 | self.scale = head_dim**-0.5 86 | 87 | self.q_proj = nn.Linear(query_input_dims, dims, bias=bias) 88 | self.k_proj = nn.Linear(key_input_dims, dims, bias=bias) 89 | self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias) 90 | self.out_proj = nn.Linear(value_dims, 
value_output_dims, bias=bias) 91 | 92 | def __call__(self, queries, keys, values, mask=None): 93 | queries = self.q_proj(queries) 94 | keys = self.k_proj(keys) 95 | values = self.v_proj(values) 96 | 97 | num_heads = self.num_heads 98 | B, L, D = queries.shape 99 | _, S, _ = keys.shape 100 | queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) 101 | keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) 102 | values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) 103 | 104 | output = mx.fast.scaled_dot_product_attention( 105 | queries, keys, values, scale=self.scale, mask=mask 106 | ) 107 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 108 | 109 | return self.out_proj(output) 110 | 111 | 112 | class MLP(nn.Module): 113 | def __init__(self, config: VisionConfig): 114 | super().__init__() 115 | self.activation_fn = nn.GELU(approx="fast") 116 | self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) 117 | self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) 118 | 119 | def __call__(self, x: mx.array) -> mx.array: 120 | x = self.activation_fn(self.fc1(x)) 121 | x = self.fc2(x) 122 | return x 123 | 124 | 125 | class EncoderLayer(nn.Module): 126 | def __init__(self, config: VisionConfig): 127 | super().__init__() 128 | self.embed_dim = config.hidden_size 129 | self.self_attn = Attention( 130 | config.hidden_size, config.num_attention_heads, bias=True 131 | ) 132 | self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) 133 | self.mlp = MLP(config) 134 | self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) 135 | 136 | def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array: 137 | y = self.layer_norm1(x) 138 | y = self.self_attn(y, y, y, mask) 139 | x = x + y 140 | y = self.layer_norm2(x) 141 | y = self.mlp(y) 142 | return x + y 143 | 144 | 145 | class Encoder(nn.Module): 146 | def __init__(self, config: VisionConfig): 147 | super().__init__() 148 | self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)] 149 | 150 | 151 | class VisionEmbeddings(nn.Module): 152 | def __init__(self, config: VisionConfig): 153 | super().__init__() 154 | self.config = config 155 | self.embed_dim = config.hidden_size 156 | self.image_size = config.image_size 157 | self.patch_size = config.patch_size 158 | 159 | self.class_embedding = mx.zeros((config.hidden_size,)) 160 | 161 | self.patch_embedding = nn.Conv2d( 162 | in_channels=config.num_channels, 163 | out_channels=self.embed_dim, 164 | kernel_size=self.patch_size, 165 | stride=self.patch_size, 166 | bias=False, 167 | ) 168 | 169 | self.num_patches = (self.image_size // self.patch_size) ** 2 170 | self.num_positions = self.num_patches + 1 171 | self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) 172 | 173 | def __call__(self, x: mx.array) -> mx.array: 174 | batch_size = x.shape[0] 175 | patch_embeddings = self.patch_embedding(x) 176 | patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2) 177 | embed_dim = patch_embeddings.shape[-1] 178 | cls_embeddings = mx.broadcast_to( 179 | self.class_embedding, (batch_size, 1, embed_dim) 180 | ) 181 | position_ids = mx.array(np.arange(self.num_positions)[None, :]) 182 | 183 | embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1) 184 | embeddings += self.position_embedding(position_ids) 185 | return embeddings 186 | 187 | 188 | class ClipModel(nn.Module): 189 | def __init__(self, config: VisionConfig): 190 | 
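        # Note (hedged): this mirrors Hugging Face's CLIPVisionTransformer layout,
        # and the misspelled "pre_layrnorm" attribute below is presumably kept on
        # purpose so pretrained CLIP checkpoint weights load by matching names.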
super().__init__() 191 | self.model_type = config.model_type 192 | self.embeddings = VisionEmbeddings(config) 193 | self.pre_layrnorm = nn.LayerNorm(config.hidden_size) 194 | self.encoder = Encoder(config) 195 | self.post_layernorm = nn.LayerNorm(config.hidden_size) 196 | 197 | def __call__( 198 | self, 199 | x: mx.array, 200 | output_hidden_states: Optional[bool] = None, 201 | ) -> mx.array: 202 | x = self.embeddings(x) 203 | x = self.pre_layrnorm(x) 204 | 205 | encoder_states = (x,) if output_hidden_states else None 206 | 207 | for l in self.encoder.layers: 208 | x = l(x, mask=None) 209 | if output_hidden_states: 210 | encoder_states = encoder_states + (x,) 211 | 212 | pooler_output = self.post_layernorm(x[:, 0, :]) 213 | return pooler_output, x, encoder_states 214 | 215 | 216 | class ClipVModel(nn.Module): 217 | def __init__(self, config): 218 | super().__init__() 219 | self.model_type = config.model_type 220 | self.vision_model = ClipModel(config) 221 | 222 | 223 | class VisionModel(nn.Module): 224 | CLIP_VIT_LARGE_PATCH14_336_CONFIG = SimpleNamespace( 225 | model_type="phi3_v", 226 | hidden_size=1024, 227 | image_size=336, 228 | intermediate_size=4096, 229 | layer_norm_eps=1e-05, 230 | num_attention_heads=16, 231 | num_channels=3, 232 | num_hidden_layers=24, 233 | patch_size=14, 234 | ) 235 | 236 | def __init__(self, config): 237 | super().__init__() 238 | self.model_type = config.model_type 239 | self.img_processor = ClipVModel(self.CLIP_VIT_LARGE_PATCH14_336_CONFIG) 240 | self.image_dim_out = image_dim_out = 1024 241 | self.glb_GN = mx.zeros([1, 1, image_dim_out * 4]) 242 | self.sub_GN = mx.zeros([1, 1, 1, image_dim_out * 4]) 243 | self.img_projection = [ 244 | nn.Linear(image_dim_out * 4, config.hidden_size), 245 | nn.GELU(), 246 | nn.Linear(config.hidden_size, config.hidden_size), 247 | ] 248 | 249 | def __call__( 250 | self, 251 | img_embeds, 252 | txt_embeds=None, 253 | img_sizes=None, 254 | positions=None, 255 | output_hidden_states=None, 256 | ): 257 | if output_hidden_states: 258 | return self.img_processor.vision_model( 259 | img_embeds, output_hidden_states=output_hidden_states 260 | ) 261 | img_embeds = mx.array(img_embeds) 262 | img_sizes = mx.array(img_sizes) 263 | B = img_embeds.shape[0] 264 | img_sizes = (img_sizes // 336).tolist() 265 | img_features = self.img_processor.vision_model( 266 | img_embeds.reshape(-1, *img_embeds.shape[2:]).transpose(0, 2, 3, 1), True 267 | )[-1][-2][:, 1:] 268 | img_features = img_features.reshape(B, -1, *img_features.shape[1:]) 269 | C, H = self.image_dim_out, int(img_features.shape[2] ** 0.5) 270 | output_imgs, output_len = [], [] 271 | for _bs in range(B): 272 | h, w = img_sizes[_bs] 273 | B_ = h * w 274 | 275 | def _reshape_and_concatenate(img, shape, tile_shape): 276 | return mx.concatenate( 277 | [ 278 | img.reshape(shape) 279 | .transpose(0, 1, 3, 2, 4, 5) 280 | .reshape(tile_shape), 281 | mx.tile(self.sub_GN, (1, tile_shape[1], 1, 1)), 282 | ], 283 | axis=2, 284 | ).reshape(1, -1, 4 * C) 285 | 286 | glb_img = _reshape_and_concatenate( 287 | img_features[_bs, :1], 288 | (1, H // 2, 2, H // 2, 2, C), 289 | (1, H // 2, H // 2, 4 * C), 290 | ) 291 | sub_img = _reshape_and_concatenate( 292 | img_features[_bs, 1 : B_ + 1], 293 | (B_, H // 2, 2, H // 2, 2, C), 294 | (1, h * 12, w * 12, 4 * C), 295 | ) 296 | x = mx.concatenate([sub_img, self.glb_GN, glb_img], axis=1) 297 | for l in self.img_projection: 298 | x = l(x) 299 | output_imgs.append(np.array(x.astype(mx.float32))) 300 | output_len.append(int((h * w + 1) * 144 + 1 + (h + 1) * 12)) 
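            # Token-count sketch, assuming the ViT-L/14-336 tower above (24x24 patch
            # grid, so H // 2 == 12 and C == 1024): each crop folds into a 12x12 grid
            # of 4*C "super-patches", a sub_GN separator is appended to every row, and
            # one glb_GN token sits between the sub-crops and the global crop, i.e.
            # h*w*144 + 12*h + 1 + 144 + 12 = (h*w + 1)*144 + 1 + (h + 1)*12 tokens,
            # which is exactly the value appended to output_len above.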
301 | idx = 0 302 | txt_embeds = np.array(txt_embeds.astype(mx.float32)) 303 | for i, cnt in enumerate(output_len): 304 | txt_embeds[ 305 | positions[idx][0], positions[idx][1] : positions[idx][1] + cnt 306 | ] = output_imgs[i] 307 | idx += cnt 308 | txt_embeds = mx.array(txt_embeds) 309 | return txt_embeds 310 | 311 | def sanitize(self, weights): 312 | sanitized_weights = {} 313 | for k, v in weights.items(): 314 | if "position_ids" in k: 315 | continue 316 | elif "patch_embedding.weight" in k: 317 | if check_array_shape(v): 318 | sanitized_weights[k] = v 319 | else: 320 | sanitized_weights[k] = v.transpose(0, 2, 3, 1) 321 | else: 322 | sanitized_weights[k] = v 323 | 324 | return sanitized_weights 325 | -------------------------------------------------------------------------------- /mlx_vlm/models/idefics2/idefics2.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import inspect 3 | import json 4 | import re 5 | from dataclasses import dataclass 6 | from pathlib import Path 7 | from typing import Optional, Tuple 8 | 9 | import mlx.core as mx 10 | import mlx.nn as nn 11 | import numpy as np 12 | from huggingface_hub import snapshot_download 13 | from transformers import AutoConfig 14 | 15 | from .language import LanguageModel, TextConfig 16 | from .vision import VisionConfig, VisionModel 17 | 18 | 19 | @dataclass 20 | class PerceiverConfig: 21 | model_type: str 22 | num_key_value_heads: int = 4 23 | resampler_depth: int = 3 24 | resampler_head_dim: int = 96 25 | resampler_n_heads: int = 16 26 | resampler_n_latents: int = 64 27 | 28 | @classmethod 29 | def from_dict(cls, params): 30 | return cls( 31 | **{ 32 | k: v 33 | for k, v in params.items() 34 | if k in inspect.signature(cls).parameters 35 | } 36 | ) 37 | 38 | 39 | @dataclass 40 | class ModelConfig: 41 | text_config: TextConfig 42 | vision_config: VisionConfig 43 | perceiver_config: PerceiverConfig 44 | model_type: str 45 | ignore_index: int = -100 46 | image_token_index: int = 32001 47 | vocab_size: int = 151936 48 | 49 | @classmethod 50 | def from_dict(cls, params): 51 | return cls( 52 | **{ 53 | k: v 54 | for k, v in params.items() 55 | if k in inspect.signature(cls).parameters 56 | } 57 | ) 58 | 59 | 60 | class Idefics2PerceiverAttention(nn.Module): 61 | def __init__(self, config: ModelConfig): 62 | super().__init__() 63 | 64 | dim = config.text_config.hidden_size 65 | self.n_heads = n_heads = config.perceiver_config.resampler_n_heads 66 | self.n_kv_heads = n_kv_heads = config.perceiver_config.num_key_value_heads 67 | 68 | head_dim = config.perceiver_config.resampler_head_dim 69 | self.scale = head_dim**-0.5 70 | 71 | self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) 72 | self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) 73 | self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) 74 | self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) 75 | 76 | def __call__( 77 | self, 78 | x: mx.array, 79 | kv: mx.array, 80 | mask: Optional[mx.array] = None, 81 | cache: Optional[Tuple[mx.array, mx.array]] = None, 82 | ) -> mx.array: 83 | B, L, D = x.shape 84 | kv_seq_len = L + kv.shape[1] 85 | hidden_states = mx.concatenate([kv, x], axis=-2) 86 | 87 | queries = self.q_proj(x) 88 | keys = self.k_proj(hidden_states) 89 | values = self.v_proj(hidden_states) 90 | 91 | # Prepare the queries, keys and values for the attention computation 92 | queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) 93 | keys = keys.reshape(B, kv_seq_len, 
self.n_kv_heads, -1).transpose(0, 2, 1, 3) 94 | values = values.reshape(B, kv_seq_len, self.n_kv_heads, -1).transpose( 95 | 0, 2, 1, 3 96 | ) 97 | 98 | if cache is not None: 99 | key_cache, value_cache = cache 100 | keys = mx.concatenate([key_cache, keys], axis=2) 101 | values = mx.concatenate([value_cache, values], axis=2) 102 | 103 | output = mx.fast.scaled_dot_product_attention( 104 | queries, keys, values, scale=self.scale 105 | ) 106 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 107 | return self.o_proj(output) 108 | 109 | 110 | class Idefics2PerceiverLayer(nn.Module): 111 | def __init__(self, config: ModelConfig): 112 | super().__init__() 113 | self.hidden_size = config.text_config.hidden_size 114 | self.n_latents = config.perceiver_config.resampler_n_latents 115 | self.depth = config.perceiver_config.resampler_depth 116 | self.rms_norm_eps = config.text_config.rms_norm_eps 117 | 118 | self.input_latents_norm = nn.RMSNorm(self.hidden_size, eps=self.rms_norm_eps) 119 | self.input_context_norm = nn.RMSNorm(self.hidden_size, eps=self.rms_norm_eps) 120 | self.self_attn = Idefics2PerceiverAttention(config) 121 | self.post_attention_layernorm = nn.RMSNorm( 122 | self.hidden_size, eps=self.rms_norm_eps 123 | ) 124 | self.mlp = MLP(self.hidden_size, self.hidden_size * 4, self.hidden_size) 125 | 126 | def __call__( 127 | self, 128 | x: mx.array, 129 | hidden_states: mx.array, 130 | mask: Optional[mx.array] = None, 131 | ) -> mx.array: 132 | latents = self.input_latents_norm(x) 133 | context = self.input_context_norm(hidden_states) 134 | 135 | latents = self.self_attn(latents, context, mask=mask) 136 | 137 | latents = x + latents 138 | r = latents 139 | 140 | latents = self.post_attention_layernorm(latents) 141 | latents = self.mlp(latents) 142 | latents = r + latents 143 | return latents 144 | 145 | 146 | class Idefics2PerceiverResampler(nn.Module): 147 | def __init__(self, config: ModelConfig): 148 | super().__init__() 149 | self.hidden_size = config.text_config.hidden_size 150 | self.n_latents = config.perceiver_config.resampler_n_latents 151 | 152 | self.latents = mx.ones((self.n_latents, self.hidden_size)) 153 | self.layers = [ 154 | Idefics2PerceiverLayer(config) 155 | for _ in range(config.perceiver_config.resampler_depth) 156 | ] 157 | self.norm = nn.RMSNorm(self.hidden_size, eps=config.text_config.rms_norm_eps) 158 | 159 | def __call__(self, x: mx.array, mask: Optional[mx.array] = None): 160 | 161 | h = mx.expand_dims(self.latents, axis=0) 162 | h = mx.repeat(h, x.shape[0], axis=0) 163 | 164 | for layer in self.layers: 165 | h = layer(h, x, mask=mask) 166 | 167 | return self.norm(h) 168 | 169 | 170 | class MLP(nn.Module): 171 | def __init__(self, dim, hidden_dim, output_size): 172 | super().__init__() 173 | self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) 174 | self.down_proj = nn.Linear(hidden_dim, output_size, bias=False) 175 | self.up_proj = nn.Linear(dim, hidden_dim, bias=False) 176 | 177 | def __call__(self, x) -> mx.array: 178 | return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) 179 | 180 | 181 | class Idefics2Connector(nn.Module): 182 | def __init__(self, config: ModelConfig): 183 | super().__init__() 184 | self.modality_projection = MLP( 185 | config.vision_config.hidden_size, 186 | config.text_config.intermediate_size, 187 | config.text_config.hidden_size, 188 | ) 189 | 190 | self.perceiver_resampler = Idefics2PerceiverResampler(config) 191 | 192 | def __call__(self, x: mx.array, mask=None) -> mx.array: 193 | x = self.modality_projection(x) 
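        # Perceiver pooling: a fixed set of resampler_n_latents learned latents
        # (64 by default) attends over the projected image features, so the connector
        # output always has resampler_n_latents tokens regardless of the input length.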
194 | x = self.perceiver_resampler(x, mask=mask) 195 | return x 196 | 197 | 198 | class Model(nn.Module): 199 | def __init__(self, config: ModelConfig): 200 | self.model_type = config.model_type 201 | self.config = config 202 | 203 | self.vision_model = VisionModel(config.vision_config) 204 | self.language_model = LanguageModel(config.text_config) 205 | self.connector = Idefics2Connector(config) 206 | 207 | def get_input_embeddings( 208 | self, 209 | input_ids: Optional[mx.array] = None, 210 | pixel_values: Optional[mx.array] = None, 211 | pixel_attention_mask: Optional[mx.array] = None, 212 | ): 213 | if pixel_values is None: 214 | return self.language_model(input_ids) 215 | 216 | inputs_embeds = self.language_model.embed_tokens(input_ids) 217 | 218 | pooler_output, embeddings, hidden_state = self.vision_model( 219 | pixel_values[0].transpose(0, 2, 3, 1), output_hidden_states=True 220 | ) 221 | 222 | image_features = pooler_output[None, :].astype(pixel_values.dtype) 223 | 224 | image_features = self.connector(image_features, mask=None) 225 | 226 | final_inputs_embeds = self._prepare_inputs_for_multimodal( 227 | image_features, inputs_embeds, input_ids 228 | ) 229 | return final_inputs_embeds 230 | 231 | def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_ids): 232 | image_token_index = self.config.image_token_index 233 | num_images, num_image_patches, embed_dim = image_features.shape 234 | 235 | # Positions of tokens in input_ids, assuming batch size is 1 236 | image_positions = np.where(input_ids[0] == image_token_index)[0].tolist() 237 | 238 | text_segments = [] 239 | start_idx = 0 240 | 241 | for position in image_positions: 242 | text_segments.append(inputs_embeds[:, start_idx:position]) 243 | start_idx = position + 1 244 | 245 | image_embeddings = mx.split(image_features, image_features.shape[0]) 246 | final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p] 247 | final_embeddings += [inputs_embeds[:, start_idx:]] 248 | 249 | # Create a final embedding of shape 250 | # (1, num_image_patches*num_images + sequence_len, embed_dim) 251 | return mx.concatenate(final_embeddings, axis=1) 252 | 253 | def __call__( 254 | self, input_ids: mx.array, pixel_values: mx.array, mask: mx.array, cache=None 255 | ): 256 | input_embeddings = self.get_input_embeddings(input_ids, pixel_values) 257 | logits, cache = self.language_model( 258 | inputs=input_ids, cache=cache, inputs_embeds=input_embeddings 259 | ) 260 | return logits, cache 261 | 262 | @staticmethod 263 | def from_pretrained(path_or_hf_repo: str): 264 | path = Path(path_or_hf_repo) 265 | if not path.exists(): 266 | path = Path( 267 | snapshot_download( 268 | repo_id=path_or_hf_repo, 269 | allow_patterns=[ 270 | "*.json", 271 | "*.safetensors", 272 | "*.py", 273 | "tokenizer.model", 274 | "*.tiktoken", 275 | ], 276 | ) 277 | ) 278 | 279 | with open(path / "config.json", "r") as f: 280 | config = json.load(f) 281 | 282 | text_config = AutoConfig.from_pretrained(config["text_config"]["model_type"]) 283 | text_config = text_config.to_dict() 284 | config["text_config"] = text_config 285 | model_config = ModelConfig.from_dict(config) 286 | model_config.vision_config = VisionConfig.from_dict(config["vision_config"]) 287 | model_config.text_config = TextConfig.from_dict(config["text_config"]) 288 | model_config.perceiver_config = PerceiverConfig.from_dict( 289 | config["perceiver_config"] 290 | ) 291 | 292 | model = Model(model_config) 293 | weight_files = glob.glob(str(path / "*.safetensors")) 294 | 
if not weight_files: 295 | raise FileNotFoundError(f"No safetensors found in {path}") 296 | 297 | weights = {} 298 | for wf in weight_files: 299 | weights.update(mx.load(wf)) 300 | 301 | weights = model.sanitize(weights=weights) 302 | weights = VisionModel(model_config.vision_config).sanitize(weights=weights) 303 | weights = LanguageModel(model_config.text_config).sanitize(weights=weights) 304 | model.load_weights(list(weights.items())) 305 | return model 306 | 307 | def sanitize(self, weights): 308 | weights = { 309 | ( 310 | f"{k.split('.', 1)[1]}" 311 | if re.match(r"^model\.", k) 312 | else (f"language_model.{k}" if re.match(r"^lm_head\.", k) else k) 313 | ): v 314 | for k, v in weights.items() 315 | } 316 | 317 | weights = { 318 | ( 319 | f"language_model.{k.split('.', 1)[1]}" 320 | if re.match( 321 | r"^text_model\.", 322 | k, 323 | ) 324 | else k 325 | ): v 326 | for k, v in weights.items() 327 | } 328 | 329 | return weights 330 | -------------------------------------------------------------------------------- /mlx_vlm/models/multi_modality/multi_modality.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import inspect 3 | import json 4 | from dataclasses import dataclass 5 | from pathlib import Path 6 | from typing import List, Optional, Tuple, Union 7 | 8 | import mlx.core as mx 9 | import mlx.nn as nn 10 | import numpy as np 11 | from huggingface_hub import snapshot_download 12 | from PIL import Image 13 | from transformers.image_processing_utils import BaseImageProcessor, BatchFeature 14 | from transformers.image_utils import to_numpy_array 15 | 16 | from ..base import expand2square 17 | from .language import LanguageModel, TextConfig 18 | from .vision import VisionConfig, VisionModel 19 | 20 | 21 | @dataclass 22 | class AlignerConfig: 23 | cls: str 24 | model_type: str 25 | params: dict 26 | 27 | @classmethod 28 | def from_dict(cls, params): 29 | return cls( 30 | **{ 31 | k: v 32 | for k, v in params.items() 33 | if k in inspect.signature(cls).parameters 34 | } 35 | ) 36 | 37 | 38 | @dataclass 39 | class ModelConfig: 40 | text_config: TextConfig 41 | vision_config: VisionConfig 42 | aligner_config: AlignerConfig 43 | model_type: str 44 | ignore_index: int = -100 45 | image_token_index: int = 100015 46 | vision_feature_select_strategy: str = "default" 47 | select_layer: int = -1 48 | pad_id: int = 100001 49 | num_image_tokens: int = 576 50 | vocab_size: int = 32000 51 | 52 | @classmethod 53 | def from_dict(cls, params): 54 | return cls( 55 | **{ 56 | k: v 57 | for k, v in params.items() 58 | if k in inspect.signature(cls).parameters 59 | } 60 | ) 61 | 62 | 63 | class ImageProcessor(BaseImageProcessor): 64 | model_input_names = ["pixel_values"] 65 | 66 | def __init__( 67 | self, 68 | config, 69 | image_size: int = 384, 70 | min_size: int = 14, 71 | image_mean: Union[Tuple[float, float, float], List[float]] = ( 72 | 0.5, 73 | 0.5, 74 | 0.5, 75 | ), 76 | image_std: Union[Tuple[float, float, float], List[float]] = ( 77 | 0.5, 78 | 0.5, 79 | 0.5, 80 | ), 81 | rescale_factor: float = 1.0 / 255.0, 82 | do_normalize: bool = True, 83 | **kwargs, 84 | ): 85 | super().__init__(**kwargs) 86 | if "high_res_cfg" in config["vision_config"]["params"]: 87 | self.image_size = config["vision_config"]["params"]["high_res_cfg"][ 88 | "image_size" 89 | ] 90 | self.image_mean = config["vision_config"]["params"]["high_res_cfg"][ 91 | "pixel_mean" 92 | ] 93 | self.image_std = config["vision_config"]["params"]["high_res_cfg"][ 94 | "pixel_std" 95 | ] 
96 | self.do_normalize = False 97 | else: 98 | self.image_size = image_size 99 | self.image_mean = image_mean 100 | self.image_std = image_std 101 | self.do_normalize = do_normalize 102 | 103 | self.rescale_factor = rescale_factor 104 | self.min_size = min_size 105 | 106 | if image_mean is None: 107 | self.background_color = (127, 127, 127) 108 | else: 109 | self.background_color = tuple([int(x * 255) for x in self.image_mean]) 110 | 111 | def resize(self, pil_img: Image) -> np.ndarray: 112 | """ 113 | 114 | Args: 115 | pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB 116 | 117 | Returns: 118 | x (np.ndarray): [3, self.image_size, self.image_size] 119 | """ 120 | 121 | width, height = pil_img.size 122 | max_size = max(width, height) 123 | 124 | size = [ 125 | max(int(height / max_size * self.image_size), self.min_size), 126 | max(int(width / max_size * self.image_size), self.min_size), 127 | ] 128 | 129 | if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0: 130 | print(f"orig size = {pil_img.size}, new size = {size}") 131 | raise ValueError("Invalid size!") 132 | 133 | pil_img = pil_img.resize(size=tuple(size[::-1]), resample=Image.BICUBIC) 134 | 135 | pil_img = expand2square(pil_img, self.background_color) 136 | x = to_numpy_array(pil_img) 137 | 138 | # [H, W, 3] -> [3, H, W] 139 | x = np.transpose(x, (2, 0, 1)) 140 | 141 | return x 142 | 143 | def preprocess(self, images, **kwargs) -> BatchFeature: 144 | # resize and pad to [self.image_size, self.image_size] 145 | # then convert from [H, W, 3] to [3, H, W] 146 | images: List[np.ndarray] = [self.resize(image) for image in images] 147 | 148 | # resacle from [0, 255] -> [0, 1] 149 | images = [ 150 | self.rescale( 151 | image=image, 152 | scale=self.rescale_factor, 153 | input_data_format="channels_first", 154 | ) 155 | for image in images 156 | ] 157 | 158 | # normalize 159 | if self.do_normalize: 160 | images = [ 161 | self.normalize( 162 | image=image, 163 | mean=self.image_mean, 164 | std=self.image_std, 165 | input_data_format="channels_first", 166 | ) 167 | for image in images 168 | ] 169 | 170 | return images 171 | 172 | 173 | class MlpProjector(nn.Module): 174 | def __init__(self, config: ModelConfig): 175 | super().__init__() 176 | 177 | if config.aligner_config.params["projector_type"] == "mlp_gelu": 178 | self.layers = [ 179 | nn.Linear( 180 | config.vision_config.hidden_size, 181 | config.text_config.hidden_size, 182 | bias=True, 183 | ) 184 | ] 185 | mlp_depth = config.aligner_config.params["depth"] 186 | for _ in range(1, mlp_depth): 187 | self.layers.append(nn.GELU()) 188 | self.layers.append( 189 | nn.Linear( 190 | config.text_config.hidden_size, 191 | config.text_config.hidden_size, 192 | bias=True, 193 | ) 194 | ) 195 | elif ( 196 | config.aligner_config.params["projector_type"] 197 | == "low_high_hybrid_split_mlp_gelu" 198 | ): 199 | mlp_depth = config.aligner_config.params["depth"] 200 | self.high_up_proj = nn.Linear( 201 | config.vision_config.hidden_size, config.text_config.hidden_size // 2 202 | ) 203 | self.low_up_proj = nn.Linear( 204 | config.vision_config.hidden_size, config.text_config.hidden_size // 2 205 | ) 206 | 207 | self.layers = [] 208 | for _ in range(1, mlp_depth): 209 | self.layers.append(nn.GELU()) 210 | self.layers.append( 211 | nn.Linear( 212 | config.text_config.hidden_size, config.text_config.hidden_size 213 | ) 214 | ) 215 | 216 | else: 217 | projector_type = config.aligner_config.params["projector_type"] 218 | raise ValueError(f"Unknown projector type: {projector_type}") 219 | 220 | def 
__call__(self, x: Union[mx.array, Tuple]) -> mx.array: 221 | 222 | if isinstance(x, tuple): 223 | high_x, low_x = x 224 | 225 | high_x = self.high_up_proj(high_x) 226 | low_x = self.low_up_proj(low_x) 227 | 228 | B, D = high_x.shape[0], high_x.shape[-1] 229 | high_x = high_x.reshape(B, -1, D) 230 | 231 | x = mx.concatenate([high_x, low_x], axis=-1) 232 | 233 | for layer in self.layers: 234 | x = layer(x) 235 | 236 | return x 237 | 238 | 239 | class Model(nn.Module): 240 | def __init__(self, config: ModelConfig): 241 | self.config = config 242 | self.vision_model = VisionModel(config.vision_config) 243 | self.language_model = LanguageModel(config.text_config) 244 | self.aligner = MlpProjector(config) 245 | self.vision_feature_layer = config.select_layer 246 | self.vision_feature_select_strategy = config.vision_feature_select_strategy 247 | 248 | def add_image_token( 249 | self, 250 | image_indices: list, 251 | input_ids: np.ndarray, 252 | image_token_index: int, 253 | num_image_tokens: int, 254 | add_special_token: bool = False, 255 | ): 256 | """ 257 | Inserts image tokens into an array of input IDs at specified indices. 258 | 259 | Args: 260 | image_indices (List[int]): Indices where image tokens should be inserted. 261 | input_ids (np.ndarray): Original array of input IDs, expected to be two-dimensional. 262 | image_token_index (int): The ID used to represent an image token. 263 | num_image_tokens (int): Number of image tokens to insert at each index. 264 | add_special_token (bool): If True, adjusts the indices to include a special token. 265 | 266 | Returns: 267 | Tuple of (np.ndarray, np.ndarray): 268 | - Updated array of input IDs with image tokens inserted. 269 | - Array indicating the number of image tokens added at each position. 270 | """ 271 | input_slices = [] 272 | 273 | start = 0 274 | flat_input_ids = input_ids.flatten() 275 | 276 | for index in image_indices: 277 | end = (index[0] + 1) if add_special_token else index[0] 278 | 279 | input_slices.append(flat_input_ids[start:end]) 280 | input_slices.append( 281 | np.full((num_image_tokens,), image_token_index, dtype=np.int64) 282 | ) 283 | start = index[0] + 1 # Move start past the current image insertion point 284 | 285 | input_slices.append(flat_input_ids[start:]) 286 | 287 | input_ids = np.concatenate(input_slices, axis=0) 288 | num_image_tokens_array = np.array( 289 | [num_image_tokens] * len(image_indices), dtype=np.int64 290 | ) 291 | input_ids = input_ids.reshape(1, -1) 292 | 293 | return input_ids, num_image_tokens_array 294 | 295 | def get_input_embeddings( 296 | self, 297 | input_ids: Optional[mx.array] = None, 298 | pixel_values: Optional[mx.array] = None, 299 | ): 300 | if pixel_values is None: 301 | return self.language_model(input_ids) 302 | 303 | image_token_index = self.config.image_token_index 304 | num_image_tokens = self.config.num_image_tokens 305 | 306 | image_token_mask = np.array(input_ids[0] == image_token_index).astype(bool) 307 | image_indices = np.nonzero(image_token_mask) 308 | 309 | input_ids, num_image_tokens = self.add_image_token( 310 | image_indices=image_indices, 311 | input_ids=np.array(input_ids), 312 | image_token_index=image_token_index, 313 | num_image_tokens=num_image_tokens, 314 | ) 315 | 316 | input_ids = mx.array(input_ids) 317 | 318 | # Get the input embeddings from the language model 319 | inputs_embeds = self.language_model.model.embed_tokens(input_ids) 320 | 321 | # Get the ouptut hidden states from the vision model 322 | if self.config.vision_config.cls == "HybridVisionTower": 323 
| hidden_states = self.vision_model( 324 | pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True 325 | ) 326 | else: 327 | hidden_states, _, _ = self.vision_model( 328 | pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True 329 | ) 330 | 331 | # Pass image features through the multi-modal projector 332 | image_features = self.aligner(hidden_states) 333 | 334 | # Insert special image tokens in the input_ids 335 | final_inputs_embeds = self._merge_input_ids_with_image_features( 336 | image_features, inputs_embeds, input_ids 337 | ) 338 | return final_inputs_embeds 339 | 340 | def _merge_input_ids_with_image_features( 341 | self, image_features, inputs_embeds, input_ids 342 | ): 343 | image_token_index = self.config.image_token_index 344 | 345 | # Positions of tokens in input_ids, assuming batch size is 1 346 | image_positions = np.where(input_ids[0] == image_token_index)[0].tolist() 347 | text_segments = [] 348 | start_idx = 0 349 | 350 | for position in image_positions: 351 | text_segments.append(inputs_embeds[:, start_idx:position]) 352 | start_idx = position + 1 353 | 354 | image_embeddings = mx.split(image_features, image_features.shape[0]) 355 | final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p] 356 | final_embeddings += [inputs_embeds[:, start_idx:]] 357 | 358 | # Create a final embedding of shape 359 | # (1, num_image_patches*num_images + sequence_len, embed_dim) 360 | return mx.concatenate(final_embeddings, axis=1) 361 | 362 | def __call__( 363 | self, input_ids: mx.array, pixel_values: mx.array, mask: mx.array, cache=None 364 | ): 365 | 366 | input_embeddings = self.get_input_embeddings(input_ids, pixel_values) 367 | logits, cache = self.language_model( 368 | input_ids, cache=cache, inputs_embeds=input_embeddings 369 | ) 370 | return logits, cache 371 | 372 | @staticmethod 373 | def from_pretrained(path_or_hf_repo: str): 374 | path = Path(path_or_hf_repo) 375 | if not path.exists(): 376 | path = Path( 377 | snapshot_download( 378 | repo_id=path_or_hf_repo, 379 | allow_patterns=[ 380 | "*.json", 381 | "*.safetensors", 382 | "*.py", 383 | "tokenizer.model", 384 | "*.tiktoken", 385 | ], 386 | ) 387 | ) 388 | 389 | with open(path / "config.json", "r") as f: 390 | model_config = json.load(f) 391 | 392 | model_config = ModelConfig.from_dict(model_config) 393 | 394 | model_config.vision_config = VisionConfig.from_dict(model_config.vision_config) 395 | model_config.aligner_config = AlignerConfig.from_dict( 396 | model_config.aligner_config 397 | ) 398 | model_config.text_config = TextConfig.from_dict(model_config.text_config) 399 | 400 | model = Model(model_config) 401 | weight_files = glob.glob(str(path / "*.safetensors")) 402 | if not weight_files: 403 | raise FileNotFoundError(f"No safetensors found in {path}") 404 | 405 | weights = {} 406 | for wf in weight_files: 407 | weights.update(mx.load(wf)) 408 | 409 | weights = VisionModel.sanitize(weights) 410 | weights = LanguageModel.sanitize(weights) 411 | 412 | model.load_weights(list(weights.items())) 413 | return model 414 | --------------------------------------------------------------------------------
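Usage sketch: the idefics2 and multi_modality models above splice image features into the token-embedding sequence with the same interleaving pattern (_prepare_inputs_for_multimodal and _merge_input_ids_with_image_features, both assuming batch size 1). Below is a minimal standalone sketch of that pattern with toy arrays; the function name, the placeholder id 99, and all shapes are illustrative assumptions, not part of this package's API.

import mlx.core as mx
import numpy as np

def merge_text_and_images(image_features, inputs_embeds, input_ids, image_token_index):
    # Positions of the image placeholder tokens (batch size 1, as in the models above).
    image_positions = np.where(np.array(input_ids[0]) == image_token_index)[0].tolist()

    text_segments, start_idx = [], 0
    for position in image_positions:
        text_segments.append(inputs_embeds[:, start_idx:position])
        start_idx = position + 1

    # Interleave text segments with one block of image embeddings per placeholder,
    # then append the trailing text segment.
    image_embeddings = mx.split(image_features, image_features.shape[0])
    merged = [v for pair in zip(text_segments, image_embeddings) for v in pair]
    merged += [inputs_embeds[:, start_idx:]]
    return mx.concatenate(merged, axis=1)

# Toy check: 5 text tokens with one placeholder (id 99) at position 2, one image of 3 patches.
ids = mx.array([[10, 11, 99, 12, 13]])
embeds = mx.zeros((1, 5, 8))
img = mx.ones((1, 3, 8))
out = merge_text_and_images(img, embeds, ids, image_token_index=99)
print(out.shape)  # (1, 7, 8): 2 text embeddings, 3 image patches, 2 text embeddings

The real models obtain image_features from their vision towers and projectors (connector/aligner) before this merge; the sketch only isolates the sequence-splicing step.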