├── assets ├── cat.jpeg ├── dog.jpeg └── README.md ├── requirements.txt ├── setup.py ├── example.py ├── LICENSE ├── tests └── test_mlx_clip.py ├── README.md ├── mlx_clip ├── image_processor.py ├── tokenizer.py ├── __init__.py ├── convert.py └── model.py └── .gitignore /assets/cat.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harperreed/mlx_clip/HEAD/assets/cat.jpeg -------------------------------------------------------------------------------- /assets/dog.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harperreed/mlx_clip/HEAD/assets/dog.jpeg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | mlx 2 | numpy 3 | transformers 4 | torch 5 | huggingface_hub 6 | Pillow 7 | pytest 8 | -------------------------------------------------------------------------------- /assets/README.md: -------------------------------------------------------------------------------- 1 | # Attribution 2 | 3 | - `cat.jpeg` is a "Cat" by London's, licensed under CC BY-SA 2.0. 4 | - `dog.jpeg` is a "Happy Dog" by tedmurphy, licensed under CC BY 2.0. 5 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open('requirements.txt') as f: 4 | required = f.read().splitlines() 5 | 6 | setup( 7 | name='mlx_clip', 8 | version='0.2', 9 | packages=find_packages(), 10 | description='A simple package to use CLIP on Apple Silicon using the MLX framework from Apple', 11 | long_description=open('README.md').read(), 12 | long_description_content_type='text/markdown', 13 | author='Harper Reed', 14 | author_email='harper@modest.com', 15 | url='https://github.com/harperreed/mlx_clip', 16 | install_requires=required, 17 | ) 18 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | import mlx_clip 2 | 3 | # Initialize the mlx_clip model with the given model directory. 4 | clip = mlx_clip.mlx_clip("mlx_model") 5 | 6 | # Encode the image from the specified file path and obtain the image embeddings. 7 | # The embeddings are a numerical representation of the image content. 8 | image_embeddings = clip.image_encoder("assets/cat.jpeg") 9 | # Optionally print the image embeddings to the console. 10 | # print(image_embeddings) 11 | 12 | # Encode the text description and obtain the text embeddings. 13 | # The embeddings are a numerical representation of the textual description. 14 | text_embeddings = clip.text_encoder("a photo of a cat") 15 | # Print the text embeddings to the console. 
16 | print(text_embeddings) 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Harper Reed 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /tests/test_mlx_clip.py: -------------------------------------------------------------------------------- 1 | # tests/test_mlx_clip.py 2 | 3 | import os 4 | import tempfile 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | import pytest 9 | from PIL import Image 10 | 11 | from mlx_clip import mlx_clip 12 | from mlx_clip.convert import convert_weights 13 | from mlx_clip.image_processor import CLIPImageProcessor 14 | from mlx_clip.model import CLIPModel 15 | from mlx_clip.tokenizer import CLIPTokenizer 16 | import mlx.core as mx 17 | 18 | # Helpers 19 | 20 | def create_dummy_image(size=(224, 224)): 21 | image = Image.new("RGB", size, "white") 22 | return image 23 | 24 | def create_temp_dir(): 25 | return tempfile.mkdtemp() 26 | 27 | # Tests 28 | 29 | def test_convert_weights(): 30 | hf_repo = "openai/clip-vit-base-patch32" 31 | with tempfile.TemporaryDirectory() as temp_dir: 32 | convert_weights(hf_repo, temp_dir) 33 | assert len(os.listdir(temp_dir)) > 0 34 | 35 | def test_clip_image_processor(): 36 | image_processor = CLIPImageProcessor() 37 | image = create_dummy_image() 38 | processed_image = image_processor([image]) 39 | assert processed_image.shape == (1, 224, 224, 3) 40 | 41 | def test_clip_tokenizer(): 42 | with tempfile.TemporaryDirectory() as temp_dir: 43 | convert_weights("openai/clip-vit-base-patch32", temp_dir) 44 | tokenizer = CLIPTokenizer.from_pretrained(temp_dir) 45 | text = "This is a test sentence." 46 | tokens = tokenizer(text) 47 | assert len(tokens) > 0 48 | 49 | def test_clip_model(): 50 | with tempfile.TemporaryDirectory() as temp_dir: 51 | convert_weights("openai/clip-vit-base-patch32", temp_dir) 52 | model = CLIPModel.from_pretrained(temp_dir) 53 | 54 | image = create_dummy_image() 55 | image_processor = CLIPImageProcessor() 56 | processed_image = image_processor([image]) 57 | 58 | tokenizer = CLIPTokenizer.from_pretrained(temp_dir) 59 | text = "This is a test sentence." 
60 | tokens = tokenizer(text) 61 | 62 | output = model(input_ids=mx.array(tokens).reshape(1, -1), pixel_values=processed_image) 63 | assert output.text_embeds is not None 64 | assert output.image_embeds is not None 65 | 66 | def test_mlx_clip_end_to_end(): 67 | hf_repo = "openai/clip-vit-base-patch32" 68 | with tempfile.TemporaryDirectory() as temp_dir: 69 | convert_weights(hf_repo, temp_dir) 70 | clip = mlx_clip(temp_dir) 71 | 72 | image_path = "assets/cat.jpeg" 73 | image_embeddings = clip.image_encoder(image_path) 74 | assert len(image_embeddings) > 0 75 | 76 | text = "a photo of a cat" 77 | text_embeddings = clip.text_encoder(text) 78 | assert len(text_embeddings) > 0 79 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MLX_CLIP 📚🤖 2 | 3 | [![GitHub](https://img.shields.io/github/license/harperreed/mlx_clip)](https://github.com/harperreed/mlx_clip/blob/main/LICENSE) 4 | 5 | Welcome to the MLX_CLIP repository! 🎉 This repository contains an implementation of the CLIP (Contrastive Language-Image Pre-training) model using the MLX library. CLIP is a powerful model that learns to associate images with their corresponding textual descriptions, enabling various downstream tasks such as image retrieval and zero-shot classification. 🖼️📝 6 | 7 | ## Features ✨ 8 | 9 | - Easy-to-use MLX_CLIP model for generating image and text embeddings 10 | - Support for loading pre-trained CLIP weights from Hugging Face 11 | - Efficient conversion of weights to MLX format for optimal performance 12 | - Seamless integration with the MLX library for accelerated inference on Apple Silicon devices 13 | 14 | ## Getting Started 🚀 15 | 16 | To get started with MLX_CLIP, follow these steps: 17 | 18 | 1. Clone the repository: 19 | ``` 20 | git clone https://github.com/harperreed/mlx_clip.git 21 | ``` 22 | 23 | 2. Install the required dependencies: 24 | ``` 25 | pip install -r requirements.txt 26 | ``` 27 | 28 | 3. Load the pre-trained CLIP model: 29 | ```python 30 | from mlx_clip import mlx_clip 31 | 32 | model_dir = "path/to/pretrained/model" 33 | clip = mlx_clip(model_dir) 34 | ``` 35 | 36 | 4. Use the CLIP model for generating image and text embeddings: 37 | ```python 38 | image_path = "path/to/image.jpg" 39 | image_embedding = clip.image_encoder(image_path) 40 | 41 | text = "A description of the image" 42 | text_embedding = clip.text_encoder(text) 43 | ``` 44 | 45 | 46 | 47 | ## Examples 💡 48 | 49 | Check out the `example.py` file for a simple example of how to use MLX_CLIP to generate image and text embeddings. 50 | 51 | ## Model Conversion 🔄 52 | 53 | MLX_CLIP provides a convenient utility to convert pre-trained CLIP weights from the Hugging Face Hub to the MLX format. To convert weights, use the `convert_weights` function from `mlx_clip.convert`: 54 | 55 | ```python 56 | from mlx_clip.convert import convert_weights 57 | 58 | hf_repo = "openai/clip-vit-base-patch32" 59 | mlx_path = "path/to/save/converted/model" 60 | convert_weights(hf_repo, mlx_path) 61 | ``` 62 | 63 | ## Contributing 🤝 64 | 65 | Contributions to MLX_CLIP are welcome! If you encounter any issues, have suggestions for improvements, or want to add new features, please open an issue or submit a pull request. Make sure to follow the existing code style and provide appropriate documentation for your changes. 66 | 67 | ## License 📜 68 | 69 | MLX_CLIP is licensed under the [MIT License](LICENSE). 
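
## Comparing Embeddings 🔍

Both encoders return plain Python lists of floats, so image and text embeddings can be compared with a few lines of NumPy. Below is a minimal sketch (illustrative, not part of the package API) that assumes the converted weights live in `mlx_model`, as in `example.py`:

```python
import numpy as np
from mlx_clip import mlx_clip

clip = mlx_clip("mlx_model")

image_embedding = np.array(clip.image_encoder("assets/cat.jpeg"))
text_embedding = np.array(clip.text_encoder("a photo of a cat"))

# Cosine similarity between the image and the text description.
similarity = float(
    image_embedding @ text_embedding
    / (np.linalg.norm(image_embedding) * np.linalg.norm(text_embedding))
)
print(similarity)
```

Since the model L2-normalizes its projection outputs, the dot product alone already gives the cosine similarity; the explicit normalization simply keeps the snippet robust to other embedding sources.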
70 | 71 | ## Acknowledgments 🙏 72 | 73 | MLX_CLIP is heavily based on the [mlx-examples CLIP implementation](https://github.com/ml-explore/mlx-examples/tree/main/clip). Special thanks to the MLX team for their incredible work! 74 | 75 | ## Contact 📞 76 | 77 | For any questions or inquiries, feel free to reach out to the project maintainer: 78 | 79 | Harper Reed 80 | - Email: harper@modest.com 81 | - GitHub: [harperreed](https://github.com/harperreed) 82 | 83 | Happy coding with MLX_CLIP! 😄💻🚀 84 | -------------------------------------------------------------------------------- /mlx_clip/image_processor.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 2 | 3 | import json 4 | from pathlib import Path 5 | from typing import List, Tuple 6 | 7 | import mlx.core as mx 8 | import numpy as np 9 | from PIL.Image import Image 10 | 11 | 12 | class CLIPImageProcessor: 13 | """ 14 | A simple port of 15 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/image_processing_clip.py. 16 | """ 17 | 18 | def __init__( 19 | self, 20 | crop_size: int = 224, 21 | do_center_crop: bool = True, 22 | do_normalize: bool = True, 23 | do_resize: bool = True, 24 | image_mean: List[float] = [0.48145466, 0.4578275, 0.40821073], 25 | image_std: List[float] = [0.26862954, 0.26130258, 0.27577711], 26 | size: int = 224, 27 | **kwargs 28 | ) -> None: 29 | self.crop_size = crop_size 30 | self.do_center_crop = do_center_crop 31 | self.do_normalize = do_normalize 32 | self.do_resize = do_resize 33 | self.image_mean = mx.array(image_mean) 34 | self.image_std = mx.array(image_std) 35 | self.size = size 36 | 37 | def __call__(self, images: List[Image]) -> mx.array: 38 | return mx.concatenate( 39 | [self._preprocess(image)[None] for image in images], axis=0 40 | ) 41 | 42 | def _preprocess(self, image: Image) -> mx.array: 43 | if self.do_resize: 44 | image = resize(image, self.size) 45 | if self.do_center_crop: 46 | image = center_crop(image, (self.crop_size, self.crop_size)) 47 | image = mx.array(np.array(image)) 48 | image = rescale(image) 49 | if self.do_normalize: 50 | image = normalize(image, self.image_mean, self.image_std) 51 | return image 52 | 53 | @staticmethod 54 | def from_pretrained(path: str): 55 | path = Path(path) 56 | with open(path / "preprocessor_config.json", encoding="utf-8") as f: 57 | config = json.load(f) 58 | return CLIPImageProcessor(**config) 59 | 60 | 61 | def resize(image: Image, short_size: int) -> Image: 62 | """ 63 | Resize so that the shorter side of the image matches `short_size`, preserving the aspect ratio. 64 | """ 65 | width, height = image.size 66 | short = min(width, height) 67 | long = max(width, height) 68 | if short == short_size: 69 | return image 70 | new_short = short_size 71 | new_long = int(short_size * long / short) 72 | new_size = (new_short, new_long) if width <= height else (new_long, new_short) 73 | return image.resize(new_size) 74 | 75 | 76 | def center_crop(image: Image, size: Tuple[int, int]) -> Image: 77 | if size[0] % 2 != 0 or size[1] % 2 != 0: 78 | raise ValueError("Only even crop sizes supported.") 79 | original_width, original_height = image.size 80 | crop_height, crop_width = size 81 | top = (original_height - crop_height) // 2 82 | bottom = top + crop_height 83 | left = (original_width - crop_width) // 2 84 | right = left + crop_width 85 | return image.crop((left, top, right, bottom)) 86 | 87 | 88 | def rescale(image: mx.array) -> mx.array: 89 | return image.astype(mx.float32) * (1 / 255.0) 90 | 91 | 92 | def 
normalize(image: mx.array, mean: mx.array, std: mx.array) -> mx.array: 93 | return (image - mean) / std 94 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | mlx_model 163 | -------------------------------------------------------------------------------- /mlx_clip/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 2 | 3 | import json 4 | from pathlib import Path 5 | from typing import Any 6 | 7 | import mlx.core as mx 8 | import regex 9 | 10 | 11 | class CLIPTokenizer: 12 | """A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ .""" 13 | 14 | def __init__(self, bpe_ranks, vocab): 15 | self.bpe_ranks = bpe_ranks 16 | self.vocab = vocab 17 | self.pat = regex.compile( 18 | r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", 19 | regex.IGNORECASE, 20 | ) 21 | self._cache = {self.bos: self.bos, self.eos: self.eos} 22 | 23 | @property 24 | def bos(self): 25 | return "<|startoftext|>" 26 | 27 | @property 28 | def bos_token(self): 29 | return self.vocab[self.bos] 30 | 31 | @property 32 | def eos(self): 33 | return "<|endoftext|>" 34 | 35 | @property 36 | def eos_token(self): 37 | return self.vocab[self.eos] 38 | 39 | def bpe(self, text): 40 | if text in self._cache: 41 | return self._cache[text] 42 | 43 | unigrams = list(text[:-1]) + [text[-1] + "</w>"] 44 | unique_bigrams = set(zip(unigrams, unigrams[1:])) 45 | 46 | if not unique_bigrams: 47 | return unigrams 48 | 49 | # In every iteration try to merge the two most likely bigrams. If none 50 | # was merged we are done. 
51 | # 52 | # Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_py 53 | while unique_bigrams: 54 | bigram = min( 55 | unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf")) 56 | ) 57 | if bigram not in self.bpe_ranks: 58 | break 59 | 60 | new_unigrams = [] 61 | skip = False 62 | for a, b in zip(unigrams, unigrams[1:]): 63 | if skip: 64 | skip = False 65 | continue 66 | 67 | if (a, b) == bigram: 68 | new_unigrams.append(a + b) 69 | skip = True 70 | 71 | else: 72 | new_unigrams.append(a) 73 | 74 | if not skip: 75 | new_unigrams.append(b) 76 | 77 | unigrams = new_unigrams 78 | unique_bigrams = set(zip(unigrams, unigrams[1:])) 79 | 80 | self._cache[text] = unigrams 81 | 82 | return unigrams 83 | 84 | def __call__(self, *args: Any, **kwargs: Any) -> Any: 85 | return self.tokenize(*args, **kwargs) 86 | 87 | def tokenize(self, text, prepend_bos=True, append_eos=True) -> mx.array: 88 | if isinstance(text, list): 89 | return mx.array([self.tokenize(t, prepend_bos, append_eos) for t in text]) 90 | 91 | # Lower case, cleanup, and split. Hugging Face does a much, 92 | # more thorough job here but this should suffice for 95% of 93 | # cases. 94 | clean_text = regex.sub(r"\s+", " ", text.lower()) 95 | tokens = regex.findall(self.pat, clean_text) 96 | 97 | # Split the tokens according to the byte-pair merge file 98 | bpe_tokens = [ti for t in tokens for ti in self.bpe(t)] 99 | 100 | # Map to token ids and return 101 | tokens = [] 102 | if prepend_bos: 103 | tokens.append(self.bos_token) 104 | tokens.extend(self.vocab[t] for t in bpe_tokens) 105 | if append_eos: 106 | tokens.append(self.eos_token) 107 | return mx.array(tokens) 108 | 109 | @staticmethod 110 | def from_pretrained(path: str): 111 | path = Path(path) 112 | 113 | with open(path / "vocab.json", encoding="utf-8") as f: 114 | vocab = json.load(f) 115 | with open(path / "merges.txt", encoding="utf-8") as f: 116 | bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1] 117 | 118 | bpe_merges = [tuple(m.split()) for m in bpe_merges] 119 | bpe_ranks = dict(map(reversed, enumerate(bpe_merges))) 120 | 121 | return CLIPTokenizer(bpe_ranks, vocab) 122 | -------------------------------------------------------------------------------- /mlx_clip/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from PIL import Image 3 | from typing import Tuple 4 | from pathlib import Path 5 | from .image_processor import CLIPImageProcessor 6 | from .model import CLIPModel 7 | from .tokenizer import CLIPTokenizer 8 | from .convert import convert_weights 9 | 10 | class mlx_clip: 11 | def __init__(self, model_dir: str, hf_repo: str = "openai/clip-vit-base-patch32"): 12 | """ 13 | Initialize the MLX_CLIP class by loading the CLIP model, tokenizer, and image processor. 14 | 15 | Args: 16 | model_dir (str): The directory where the CLIP model is stored. 17 | """ 18 | self.logger = logging.getLogger(__name__) 19 | self.hf_repo = hf_repo 20 | self.model_dir = model_dir 21 | self.model, self.tokenizer, self.img_processor = self.load_clip_model(model_dir) 22 | 23 | def download_and_convert_weights(self, hf_repo: str = "openai/clip-vit-base-patch32", dtype: str = "float32") -> str: 24 | """ 25 | Download the pre-trained weights from Hugging Face and convert them to the appropriate format. 26 | 27 | This method checks if the model directory already exists. 
If not, it will download the pre-trained 28 | weights from the specified Hugging Face repository and convert them to the required format and 29 | data type for the CLIP model. 30 | 31 | Args: 32 | hf_repo (str): The Hugging Face repository where the pre-trained CLIP weights are stored. 33 | dtype (str): The data type to which the weights should be converted. Default is 'float32'. 34 | 35 | Returns: 36 | str: The path to the directory where the weights are stored after conversion. 37 | 38 | Raises: 39 | Exception: If any error occurs during the download or conversion process. 40 | """ 41 | # Define the path to the model directory 42 | mlx_path = Path(self.model_dir) 43 | self.logger.debug(f"Checking if model directory {mlx_path} exists.") 44 | 45 | # If the model directory does not exist, download and convert weights 46 | if not mlx_path.exists(): 47 | self.logger.info(f"Model directory does not exist. Downloading and converting weights from {hf_repo}.") 48 | try: 49 | convert_weights(hf_repo, str(mlx_path), dtype) 50 | self.logger.info("Weights downloaded and converted successfully.") 51 | except Exception as e: 52 | self.logger.error(f"Failed to download and convert weights: {e}") 53 | raise 54 | 55 | # Return the model directory path as a string 56 | return str(mlx_path) 57 | 58 | 59 | def load_clip_model(self, model_dir: str) -> Tuple[CLIPModel, CLIPTokenizer, CLIPImageProcessor]: 60 | """ 61 | Loads the CLIP model, tokenizer, and image processor from a given directory. 62 | If the model directory does not exist or is empty, it attempts to download and convert weights. 63 | 64 | Args: 65 | model_dir (str): The directory where the CLIP model is stored. 66 | 67 | Returns: 68 | Tuple[CLIPModel, CLIPTokenizer, CLIPImageProcessor]: The loaded CLIP model, tokenizer, and image processor. 69 | 70 | Raises: 71 | FileNotFoundError: If the model directory does not exist and cannot be created. 72 | Exception: If there is an issue loading any of the model components. 73 | """ 74 | model_path = Path(model_dir) 75 | if not model_path.exists() or not any(model_path.iterdir()): 76 | self.logger.warning(f"Model directory {model_dir} not found or is empty. 
Attempting to download and convert weights.") 77 | try: 78 | model_dir = self.download_and_convert_weights(hf_repo=self.hf_repo) 79 | except Exception as e: 80 | self.logger.error(f"Failed to download and convert weights: {e}") 81 | raise FileNotFoundError(f"Model directory {model_dir} does not exist and weights could not be downloaded.") 82 | 83 | self.logger.info(f"Loading CLIP model from directory: {model_dir}") 84 | try: 85 | model = CLIPModel.from_pretrained(model_dir) 86 | self.logger.debug("CLIP model loaded successfully.") 87 | except Exception as e: 88 | self.logger.error(f"Failed to load CLIP model: {e}") 89 | raise Exception(f"Failed to load CLIP model from {model_dir}") from e 90 | 91 | try: 92 | tokenizer = CLIPTokenizer.from_pretrained(model_dir) 93 | self.logger.debug("CLIP tokenizer loaded successfully.") 94 | except Exception as e: 95 | self.logger.error(f"Failed to load CLIP tokenizer: {e}") 96 | raise Exception(f"Failed to load CLIP tokenizer from {model_dir}") from e 97 | 98 | try: 99 | img_processor = CLIPImageProcessor.from_pretrained(model_dir) 100 | self.logger.debug("CLIP image processor loaded successfully.") 101 | except Exception as e: 102 | self.logger.error(f"Failed to load CLIP image processor: {e}") 103 | raise Exception(f"Failed to load CLIP image processor from {model_dir}") from e 104 | 105 | return model, tokenizer, img_processor 106 | 107 | def image_encoder(self, image_path: str): 108 | """ 109 | Generate an image embedding using the CLIP model. 110 | 111 | Args: 112 | - image_path: Path to the image file to be processed. 113 | 114 | Returns: 115 | - A list of floats representing the embedding of the image. 116 | """ 117 | try: 118 | # Open the image file 119 | image = Image.open(image_path) 120 | self.logger.debug(f"Image {image_path} opened successfully.") 121 | except Exception as e: 122 | self.logger.error(f"Error opening image {image_path}: {e}") 123 | raise 124 | 125 | try: 126 | # Preprocess the image using the provided image processor 127 | processed_image = self.img_processor([image]) 128 | self.logger.debug(f"Image {image_path} processed successfully.") 129 | except Exception as e: 130 | self.logger.error(f"Error processing image {image_path}: {e}") 131 | raise 132 | 133 | try: 134 | # Generate embeddings using the CLIP model 135 | inputs = {"pixel_values": processed_image} 136 | output = self.model(**inputs) 137 | image_embed = output.image_embeds 138 | self.logger.debug(f"Image embedding for {image_path} generated successfully.") 139 | except Exception as e: 140 | self.logger.error(f"Error generating embedding for image {image_path}: {e}") 141 | raise 142 | 143 | # Return the first (and only) image embedding 144 | return image_embed[0].tolist() 145 | 146 | def text_encoder(self, text: str): 147 | """ 148 | Generate a text embedding using the CLIP model. 149 | 150 | Args: 151 | - text: The text string to be processed and embedded. 152 | 153 | Returns: 154 | - A list of floats representing the embedding of the text. 155 | 156 | Raises: 157 | - Exception: Propagates any exception that might occur during tokenization or embedding generation. 
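Example (illustrative; assumes converted weights are available in "mlx_model", as in example.py):
    clip = mlx_clip("mlx_model")
    embedding = clip.text_encoder("a photo of a cat")
    print(len(embedding))  # dimensionality of the projection space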
158 | """ 159 | try: 160 | # Tokenize the text using the provided tokenizer 161 | inputs = {"input_ids": self.tokenizer([text])} 162 | self.logger.debug(f"Text '{text}' tokenized successfully.") 163 | except Exception as e: 164 | self.logger.error(f"Error tokenizing text '{text}': {e}") 165 | raise 166 | 167 | try: 168 | # Generate embeddings using the CLIP model 169 | output = self.model(**inputs) 170 | text_embeds = output.text_embeds 171 | self.logger.debug(f"Text embedding for '{text}' generated successfully.") 172 | except Exception as e: 173 | self.logger.error(f"Error generating embedding for text '{text}': {e}") 174 | raise 175 | 176 | # Return the first (and only) text embedding 177 | return text_embeds[0].tolist() 178 | -------------------------------------------------------------------------------- /mlx_clip/convert.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 2 | # Hacked on by Harper Reed 2024 3 | 4 | import argparse 5 | import json 6 | import shutil 7 | import logging 8 | from pathlib import Path 9 | from typing import Any, Dict, Union 10 | 11 | import mlx.core as mx 12 | import torch 13 | from huggingface_hub import snapshot_download 14 | 15 | #set logger level info 16 | logging.basicConfig(level=logging.INFO) 17 | 18 | def convert_weights( 19 | hf_repo: str = "openai/clip-vit-base-patch32", 20 | mlx_path: str = "mlx_model", 21 | dtype: str = "float32", 22 | ) -> None: 23 | """ 24 | Convert the weights from a Hugging Face repository to the MLX format. 25 | 26 | Args: 27 | hf_repo (str): The name of the Hugging Face repository to download the weights from. 28 | mlx_path (str): The local directory path where the converted MLX weights will be saved. 29 | dtype (str): The target data type for the converted weights. Supported types include 30 | 'float32', 'float16', 'bfloat16', etc. 31 | 32 | Raises: 33 | ValueError: If the specified data type is not supported. 
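Example (illustrative; repository and output path as used elsewhere in this project):
    convert_weights("openai/clip-vit-base-patch32", "mlx_model", dtype="float16")
    # Writes the converted safetensors plus config.json, merges.txt, vocab.json and
    # preprocessor_config.json into ./mlx_model.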
34 | """ 35 | # Attempt to create the MLX model directory, if it doesn't exist 36 | mlx_path = Path(mlx_path) 37 | try: 38 | mlx_path.mkdir(parents=True, exist_ok=True) 39 | except Exception as e: 40 | raise IOError(f"Failed to create directory {mlx_path}: {e}") 41 | 42 | # Download the model from the Hugging Face repository 43 | torch_path = get_model_path(hf_repo) 44 | 45 | # Load the PyTorch model weights 46 | try: 47 | torch_weights = torch.load(torch_path / "pytorch_model.bin", map_location="cpu") 48 | logging.info("Loaded PyTorch weights") 49 | except FileNotFoundError: 50 | raise FileNotFoundError(f"Pytorch model file not found in {torch_path}") 51 | except Exception as e: 52 | raise IOError(f"Error loading PyTorch weights: {e}") 53 | 54 | # Convert the weights to MLX format 55 | try: 56 | mlx_weights = { 57 | k: torch_to_mx(v, dtype=dtype) for k, v in torch_weights.items() 58 | } 59 | logging.info("Converted weights to MLX format") 60 | except ValueError as e: 61 | raise ValueError(f"Unsupported dtype {dtype}: {e}") 62 | 63 | # Save the converted weights to the MLX format 64 | try: 65 | save_weights(mlx_path, mlx_weights) 66 | logging.info(f"Saved MLX weights to {mlx_path}") 67 | except Exception as e: 68 | raise IOError(f"Error saving MLX weights: {e}") 69 | 70 | # Copy additional necessary files from the Hugging Face repository to the MLX directory 71 | additional_files = ["config.json", "merges.txt", "vocab.json", "preprocessor_config.json"] 72 | for fn in additional_files: 73 | try: 74 | src_file = torch_path / fn 75 | dst_file = mlx_path / fn 76 | shutil.copyfile(str(src_file), str(dst_file)) 77 | logging.info(f" Copied {fn} to {mlx_path}") 78 | except FileNotFoundError: 79 | logging.warning(f" {fn} not found in {torch_path}, skipping.") 80 | except Exception as e: 81 | raise IOError(f"Error copying {fn}: {e}") 82 | 83 | def make_shards(weights: Dict[str, Any], max_file_size_gb: int = 5) -> list: 84 | """ 85 | Splits the weights dictionary into multiple shards, each with a maximum size limit. 86 | 87 | This function is used to avoid memory issues when saving large models by creating 88 | smaller, more manageable files that can be saved individually. 89 | 90 | Args: 91 | weights (Dict[str, Any]): The dictionary containing the model weights. 92 | max_file_size_gb (int): The maximum allowed file size for each shard in gigabytes. 93 | 94 | Returns: 95 | list: A list of dictionaries, where each dictionary represents a shard. 96 | 97 | Raises: 98 | ValueError: If the max_file_size_gb is less than or equal to zero. 
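Example (illustrative):
    shards = make_shards(mlx_weights, max_file_size_gb=5)
    # Each shard is a dict of weight name -> array whose combined size stays under
    # roughly 5 GB; save_weights() then writes one safetensors file per shard.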
99 | """ 100 | if max_file_size_gb <= 0: 101 | raise ValueError("max_file_size_gb must be greater than zero.") 102 | 103 | # Convert the maximum file size to bytes for comparison 104 | max_file_size_bytes = max_file_size_gb * (1 << 30) # 1 GB = 2^30 bytes 105 | shards = [] 106 | shard, shard_size = {}, 0 107 | 108 | # Iterate over the weights and partition them into shards 109 | for k, v in weights.items(): 110 | weight_size = v.nbytes 111 | if shard_size + weight_size > max_file_size_bytes: 112 | # If adding the weight to the current shard exceeds the size limit, 113 | # save the current shard and start a new one 114 | shards.append(shard) 115 | shard, shard_size = {}, 0 116 | shard[k] = v 117 | shard_size += weight_size 118 | 119 | # Add the last shard if it contains any weights 120 | if shard: 121 | shards.append(shard) 122 | 123 | # Log the sharding information 124 | num_shards = len(shards) 125 | logging.info(f"Created {num_shards} shards with a maximum size of {max_file_size_gb} GB each.") 126 | return shards 127 | 128 | 129 | def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None: 130 | """Save model weights into specified directory.""" 131 | if isinstance(save_path, str): 132 | save_path = Path(save_path) 133 | save_path.mkdir(parents=True, exist_ok=True) 134 | 135 | shards = make_shards(weights) 136 | shards_count = len(shards) 137 | shard_file_format = ( 138 | "model-{:05d}-of-{:05d}.safetensors" 139 | if shards_count > 1 140 | else "model.safetensors" 141 | ) 142 | 143 | total_size = sum(v.nbytes for v in weights.values()) 144 | index_data = {"metadata": {"total_size": total_size}, "weight_map": {}} 145 | 146 | for i, shard in enumerate(shards): 147 | shard_name = shard_file_format.format(i + 1, shards_count) 148 | shard_path = save_path / shard_name 149 | 150 | mx.save_safetensors(str(shard_path), shard) 151 | 152 | for weight_name in shard.keys(): 153 | index_data["weight_map"][weight_name] = shard_name 154 | 155 | index_data["weight_map"] = { 156 | k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"]) 157 | } 158 | 159 | with open(save_path / "model.safetensors.index.json", "w") as f: 160 | json.dump( 161 | index_data, 162 | f, 163 | indent=4, 164 | ) 165 | 166 | 167 | def get_model_path(path_or_hf_repo: str) -> Path: 168 | """ 169 | Retrieves the model path for a given local path or Hugging Face repository. 170 | 171 | If the input is a local path and it exists, it returns the Path object directly. 172 | If the input is a Hugging Face repository, it downloads the repository and returns 173 | the path to the downloaded files. 174 | 175 | Args: 176 | path_or_hf_repo (str): The local path or Hugging Face repository identifier. 177 | 178 | Returns: 179 | Path: The path to the model files. 180 | 181 | Raises: 182 | ValueError: If the input path does not exist and is not a valid Hugging Face repository. 
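Example (illustrative):
    model_path = get_model_path("openai/clip-vit-base-patch32")
    # Returns the local snapshot directory, downloading the *.bin / *.json / *.txt
    # files from the Hub first when the argument is not an existing local path.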
183 | """ 184 | model_path = Path(path_or_hf_repo) 185 | if model_path.exists(): 186 | # The input is a local path and it exists 187 | logging.info(f"Using existing local model path: {model_path}") 188 | else: 189 | # The input is assumed to be a Hugging Face repository identifier 190 | try: 191 | # Attempt to download the model from Hugging Face 192 | logging.info(f"Downloading model from Hugging Face repository: {path_or_hf_repo}") 193 | model_path = Path( 194 | snapshot_download( 195 | repo_id=path_or_hf_repo, 196 | allow_patterns=[ 197 | "*.bin", 198 | "*.json", 199 | "*.txt", 200 | ], 201 | ) 202 | ) 203 | logging.info(f"Model downloaded and extracted to: {model_path}") 204 | except Exception as e: 205 | # An error occurred during download, raise a more informative error 206 | raise ValueError( 207 | f"Failed to download the model from Hugging Face repository " 208 | f"'{path_or_hf_repo}'. Ensure that the repository exists and is accessible. " 209 | f"Error: {e}" 210 | ) 211 | return model_path 212 | 213 | 214 | def torch_to_mx(tensor: torch.Tensor, *, dtype: str) -> mx.array: 215 | """ 216 | Convert a PyTorch tensor to an MLX array with the specified data type. 217 | 218 | Args: 219 | tensor (torch.Tensor): The PyTorch tensor to convert. 220 | dtype (str): The target data type for the converted array. Supported types include 221 | 'float32', 'float16', 'bfloat16', etc. 222 | 223 | Returns: 224 | mx.array: The converted MLX array. 225 | 226 | Raises: 227 | TypeError: If the input is not a PyTorch tensor. 228 | ValueError: If the specified data type is not supported by MLX or PyTorch. 229 | """ 230 | if not isinstance(tensor, torch.Tensor): 231 | raise TypeError("The input must be a PyTorch tensor.") 232 | 233 | # Check if the specified dtype is supported by both PyTorch and MLX 234 | supported_dtypes = ["float32", "float16", "bfloat16"] 235 | if dtype not in supported_dtypes: 236 | raise ValueError(f"Unsupported dtype '{dtype}'. 
Supported types are: {supported_dtypes}") 237 | 238 | # Handle bfloat16 separately since it is not directly convertible to numpy 239 | if dtype == "bfloat16": 240 | # bfloat16 is not supported by NumPy, so we convert it to float32 first 241 | logging.info("Converting bfloat16 to float32 to avoid precision loss before conversion.") 242 | tensor = tensor.to(torch.float32) 243 | dtype = "float32" 244 | 245 | # Convert the PyTorch tensor to the specified dtype 246 | try: 247 | tensor = tensor.to(getattr(torch, dtype)) 248 | except AttributeError: 249 | raise ValueError(f"PyTorch does not support the specified dtype '{dtype}'.") 250 | 251 | # Convert the tensor to MLX array 252 | try: 253 | mlx_array = mx.array(tensor.numpy(), getattr(mx, dtype)) 254 | except AttributeError: 255 | raise ValueError(f"MLX does not support the specified dtype '{dtype}'.") 256 | except RuntimeError as e: 257 | raise RuntimeError(f"Error occurred during conversion to MLX array: {e}") 258 | 259 | logging.debug(f"Converted PyTorch tensor to MLX array with dtype '{dtype}'.") 260 | return mlx_array 261 | 262 | 263 | if __name__ == "__main__": 264 | parser = argparse.ArgumentParser( 265 | description="Download and Convert (OpenAI) CLIP weights to MLX" 266 | ) 267 | parser.add_argument( 268 | "--hf-repo", 269 | type=str, 270 | default="openai/clip-vit-large-patch14", 271 | help="Hugging Face repository name.", 272 | ) 273 | parser.add_argument( 274 | "--mlx-path", 275 | type=str, 276 | default="mlx_model", 277 | help="Path to save the MLX model.", 278 | ) 279 | parser.add_argument( 280 | "--dtype", 281 | help="The data type to save the converted model.", 282 | type=str, 283 | default="float32", 284 | ) 285 | args = parser.parse_args() 286 | 287 | torch_path = get_model_path(args.hf_repo) 288 | mlx_path = Path(args.mlx_path) 289 | mlx_path.mkdir(parents=True, exist_ok=True) 290 | 291 | print("[INFO] Loading") 292 | torch_weights = torch.load(torch_path / "pytorch_model.bin") 293 | print("[INFO] Converting") 294 | mlx_weights = { 295 | k: torch_to_mx(v, dtype=args.dtype) for k, v in torch_weights.items() 296 | } 297 | print("[INFO] Saving") 298 | save_weights(mlx_path, mlx_weights) 299 | for fn in ["config.json", "merges.txt", "vocab.json", "preprocessor_config.json"]: 300 | shutil.copyfile( 301 | str(torch_path / f"{fn}"), 302 | str(mlx_path / f"{fn}"), 303 | ) 304 | -------------------------------------------------------------------------------- /mlx_clip/model.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 
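# MLX implementation of the CLIP model: configuration dataclasses, the text and
# vision transformer encoders, the projection heads, and the contrastive CLIP loss.
# Weights are loaded from the safetensors files produced by convert.py via
# CLIPModel.from_pretrained().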
2 | 3 | import glob 4 | import json 5 | import logging 6 | import math 7 | from dataclasses import dataclass 8 | from pathlib import Path 9 | from typing import Optional, Union 10 | 11 | import mlx.core as mx 12 | import mlx.nn as nn 13 | from mlx.core import linalg as LA 14 | from mlx.nn.losses import cross_entropy 15 | from mlx.utils import tree_flatten 16 | 17 | 18 | @dataclass 19 | class CLIPVisionOutput: 20 | pooler_output: mx.array 21 | last_hidden_state: mx.array 22 | hidden_states: Optional[mx.array] 23 | 24 | 25 | @dataclass 26 | class CLIPTextOutput: 27 | pooler_output: mx.array 28 | last_hidden_state: mx.array 29 | 30 | 31 | @dataclass 32 | class CLIPModelOutput: 33 | loss: Optional[mx.array] 34 | text_embeds: Optional[mx.array] 35 | image_embeds: Optional[mx.array] 36 | text_model_output: CLIPTextOutput 37 | vision_model_output: CLIPVisionOutput 38 | 39 | 40 | @dataclass 41 | class CLIPTextConfig: 42 | num_hidden_layers: int 43 | hidden_size: int 44 | intermediate_size: int 45 | num_attention_heads: int 46 | max_position_embeddings: int 47 | vocab_size: int 48 | layer_norm_eps: float 49 | 50 | 51 | @dataclass 52 | class CLIPVisionConfig: 53 | num_hidden_layers: int 54 | hidden_size: int 55 | intermediate_size: int 56 | num_attention_heads: int 57 | num_channels: int 58 | image_size: int 59 | patch_size: int 60 | layer_norm_eps: float 61 | 62 | 63 | @dataclass 64 | class CLIPConfig: 65 | text_config: CLIPTextConfig 66 | vision_config: CLIPVisionConfig 67 | projection_dim: int 68 | 69 | 70 | def quick_gelu(x: mx.array) -> mx.array: 71 | """ 72 | A fast GELU approximation https://github.com/hendrycks/GELUs 73 | """ 74 | return x * mx.sigmoid(1.702 * x) 75 | 76 | 77 | def clip_loss(logits: mx.array) -> mx.array: 78 | N, M = logits.shape 79 | caption_loss = cross_entropy(logits, mx.arange(N), reduction="mean") 80 | image_loss = cross_entropy(logits.T, mx.arange(M), reduction="mean") 81 | return (caption_loss + image_loss) / 2.0 82 | 83 | 84 | class Attention(nn.Module): 85 | def __init__( 86 | self, 87 | dims: int, 88 | num_heads: int, 89 | query_input_dims: Optional[int] = None, 90 | key_input_dims: Optional[int] = None, 91 | value_input_dims: Optional[int] = None, 92 | value_dims: Optional[int] = None, 93 | value_output_dims: Optional[int] = None, 94 | bias: bool = False, 95 | ): 96 | super().__init__() 97 | 98 | if (dims % num_heads) != 0: 99 | raise ValueError( 100 | "The input feature dimensions should be divisible by the " 101 | f"number of heads ({dims} % {num_heads}) != 0" 102 | ) 103 | 104 | query_input_dims = query_input_dims or dims 105 | key_input_dims = key_input_dims or dims 106 | value_input_dims = value_input_dims or key_input_dims 107 | value_dims = value_dims or dims 108 | value_output_dims = value_output_dims or dims 109 | 110 | self.num_heads = num_heads 111 | self.q_proj = nn.Linear(query_input_dims, dims, bias=bias) 112 | self.k_proj = nn.Linear(key_input_dims, dims, bias=bias) 113 | self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias) 114 | self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias) 115 | 116 | def __call__(self, queries, keys, values, mask=None): 117 | queries = self.q_proj(queries) 118 | keys = self.k_proj(keys) 119 | values = self.v_proj(values) 120 | 121 | num_heads = self.num_heads 122 | B, L, D = queries.shape 123 | _, S, _ = keys.shape 124 | queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) 125 | keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1) 126 | values = values.reshape(B, S, 
num_heads, -1).transpose(0, 2, 1, 3) 127 | 128 | scale = math.sqrt(1 / queries.shape[-1]) 129 | scores = (queries * scale) @ keys 130 | if mask is not None: 131 | scores = scores + mask.astype(scores.dtype) 132 | scores = mx.softmax(scores, axis=-1) 133 | values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) 134 | 135 | return self.out_proj(values_hat) 136 | 137 | 138 | class MLP(nn.Module): 139 | def __init__(self, config: CLIPTextConfig): 140 | super().__init__() 141 | self.config = config 142 | self.activation_fn = quick_gelu 143 | self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) 144 | self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) 145 | 146 | def __call__(self, x: mx.array) -> mx.array: 147 | x = self.activation_fn(self.fc1(x)) 148 | x = self.fc2(x) 149 | return x 150 | 151 | 152 | class EncoderLayer(nn.Module): 153 | """The transformer encoder layer from CLIP.""" 154 | 155 | def __init__(self, config: CLIPTextConfig): 156 | super().__init__() 157 | self.embed_dim = config.hidden_size 158 | # Add biases to the attention projections 159 | self.self_attn = Attention( 160 | config.hidden_size, config.num_attention_heads, bias=True 161 | ) 162 | self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) 163 | self.mlp = MLP(config) 164 | self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) 165 | 166 | def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array: 167 | y = self.layer_norm1(x) 168 | y = self.self_attn(y, y, y, mask) 169 | x = x + y 170 | y = self.layer_norm2(x) 171 | y = self.mlp(y) 172 | return x + y 173 | 174 | 175 | class TextEmbeddings(nn.Module): 176 | def __init__(self, config: CLIPTextConfig): 177 | super().__init__() 178 | embed_dim = config.hidden_size 179 | 180 | self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) 181 | self.position_embedding = nn.Embedding( 182 | config.max_position_embeddings, embed_dim 183 | ) 184 | 185 | def __call__(self, x: mx.array) -> mx.array: 186 | embeddings = self.token_embedding(x) 187 | embeddings += self.position_embedding.weight[: x.shape[1]] 188 | return embeddings 189 | 190 | 191 | class Encoder(nn.Module): 192 | def __init__(self, config: CLIPTextConfig): 193 | self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)] 194 | 195 | 196 | class ClipTextModel(nn.Module): 197 | """Implements the text encoder transformer from CLIP.""" 198 | 199 | def __init__(self, config: CLIPTextConfig): 200 | super().__init__() 201 | self.embeddings = TextEmbeddings(config) 202 | self.encoder = Encoder(config) 203 | self.final_layer_norm = nn.LayerNorm(config.hidden_size) 204 | 205 | def __call__(self, x: mx.array) -> CLIPTextOutput: 206 | B, N = x.shape 207 | eot_tokens = mx.argmax(x, axis=-1) 208 | x = self.embeddings(x) 209 | mask = nn.MultiHeadAttention.create_additive_causal_mask(N, x.dtype) 210 | for l in self.encoder.layers: 211 | x = l(x, mask) 212 | last_hidden_state = self.final_layer_norm(x) 213 | pooler_output = last_hidden_state[mx.arange(B), eot_tokens] 214 | 215 | return CLIPTextOutput( 216 | pooler_output=pooler_output, last_hidden_state=last_hidden_state 217 | ) 218 | 219 | 220 | class VisionEmbeddings(nn.Module): 221 | def __init__(self, config: CLIPVisionConfig): 222 | super().__init__() 223 | self.config = config 224 | self.embed_dim = config.hidden_size 225 | self.image_size = config.image_size 226 | self.patch_size = config.patch_size 227 | 228 | self.class_embedding = 
mx.zeros((config.hidden_size,)) 229 | 230 | self.patch_embedding = nn.Conv2d( 231 | in_channels=config.num_channels, 232 | out_channels=self.embed_dim, 233 | kernel_size=self.patch_size, 234 | stride=self.patch_size, 235 | bias=False, 236 | ) 237 | 238 | self.num_patches = (self.image_size // self.patch_size) ** 2 239 | self.num_positions = self.num_patches + 1 240 | self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) 241 | 242 | def __call__(self, x: mx.array) -> mx.array: 243 | batch_size = x.shape[0] 244 | # Patchify using conv: 245 | # [batch_size, sqrt(num_patches), sqrt(num_patches), embed_dim] 246 | patch_embeddings = self.patch_embedding(x) 247 | # [batch_size, num_patches, embed_dim] 248 | patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2) 249 | embed_dim = patch_embeddings.shape[-1] 250 | # Prepend embeddings 251 | # [batch_size, 1, embed_dim] 252 | cls_embeddings = mx.broadcast_to( 253 | self.class_embedding, (batch_size, 1, embed_dim) 254 | ) 255 | # [batch_size, num_patches + 1, embed_dim] 256 | embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1) 257 | # Add positional encoding 258 | embeddings += self.position_embedding.weight 259 | return embeddings 260 | 261 | 262 | class ClipVisionModel(nn.Module): 263 | """Implements the vision encoder transformer from CLIP.""" 264 | 265 | def __init__(self, config: CLIPVisionConfig): 266 | super().__init__() 267 | self.embeddings = VisionEmbeddings(config) 268 | self.pre_layrnorm = nn.LayerNorm(config.hidden_size) 269 | self.encoder = Encoder(config) 270 | self.post_layernorm = nn.LayerNorm(config.hidden_size) 271 | 272 | def __call__( 273 | self, 274 | x: mx.array, 275 | output_hidden_states: Optional[bool] = None, 276 | ) -> CLIPVisionOutput: 277 | x = self.embeddings(x) 278 | x = self.pre_layrnorm(x) 279 | 280 | encoder_states = (x,) if output_hidden_states else None 281 | 282 | for l in self.encoder.layers: 283 | x = l(x, mask=None) 284 | if output_hidden_states: 285 | encoder_states = encoder_states + (x,) 286 | 287 | # Extract token embedding 288 | pooler_output = self.post_layernorm(x[:, 0, :]) 289 | return CLIPVisionOutput( 290 | pooler_output=pooler_output, 291 | last_hidden_state=x, 292 | hidden_states=encoder_states, 293 | ) 294 | 295 | 296 | class CLIPModel(nn.Module): 297 | def __init__(self, config: CLIPConfig): 298 | self.text_model = ClipTextModel(config.text_config) 299 | self.vision_model = ClipVisionModel(config.vision_config) 300 | 301 | text_embed_dim = config.text_config.hidden_size 302 | vision_embed_dim = config.vision_config.hidden_size 303 | projection_dim = config.projection_dim 304 | 305 | self.visual_projection = nn.Linear(vision_embed_dim, projection_dim, bias=False) 306 | self.text_projection = nn.Linear(text_embed_dim, projection_dim, bias=False) 307 | self.logit_scale = mx.array(0.0) 308 | 309 | def get_text_features(self, x: mx.array) -> mx.array: 310 | return self.text_projection(self.text_model(x).pooler_output) 311 | 312 | def get_image_features(self, x: mx.array) -> mx.array: 313 | return self.visual_projection(self.vision_model(x).pooler_output) 314 | 315 | def __call__( 316 | self, 317 | input_ids: Optional[mx.array] = None, 318 | pixel_values: Optional[mx.array] = None, 319 | return_loss=False, 320 | ) -> CLIPModelOutput: 321 | if input_ids is not None: 322 | text_model_output = self.text_model(input_ids) 323 | text_embeds = self.text_projection(text_model_output.pooler_output) 324 | text_embeds = text_embeds / LA.norm(text_embeds, 
axis=-1, keepdims=True) 325 | else: 326 | text_embeds = None 327 | text_model_output = None 328 | 329 | if pixel_values is not None: 330 | vision_model_output = self.vision_model(pixel_values) 331 | image_embeds = self.visual_projection(vision_model_output.pooler_output) 332 | image_embeds = image_embeds / LA.norm(image_embeds, axis=-1, keepdims=True) 333 | else: 334 | image_embeds = None 335 | vision_model_output = None 336 | 337 | if return_loss and (input_ids is None or pixel_values is None): 338 | raise ValueError("Must provide text and image inputs to compute loss.") 339 | 340 | if return_loss: 341 | logit_scale = mx.exp(self.logit_scale) 342 | logits = (text_embeds @ image_embeds.T) * logit_scale 343 | loss = clip_loss(logits) 344 | else: 345 | loss = None 346 | 347 | return CLIPModelOutput( 348 | loss=loss, 349 | text_embeds=text_embeds, 350 | image_embeds=image_embeds, 351 | vision_model_output=vision_model_output, 352 | text_model_output=text_model_output, 353 | ) 354 | 355 | @staticmethod 356 | def from_pretrained(path: str): 357 | path = Path(path) 358 | 359 | with open(path / "config.json", "r") as fid: 360 | config = json.load(fid) 361 | 362 | text_config = config["text_config"] 363 | text_config = CLIPTextConfig( 364 | num_hidden_layers=text_config["num_hidden_layers"], 365 | hidden_size=text_config["hidden_size"], 366 | intermediate_size=text_config["intermediate_size"], 367 | num_attention_heads=text_config["num_attention_heads"], 368 | max_position_embeddings=text_config["max_position_embeddings"], 369 | vocab_size=text_config["vocab_size"], 370 | layer_norm_eps=text_config["layer_norm_eps"], 371 | ) 372 | 373 | vision_config = config["vision_config"] 374 | 375 | vision_config = CLIPVisionConfig( 376 | num_hidden_layers=vision_config["num_hidden_layers"], 377 | hidden_size=vision_config["hidden_size"], 378 | intermediate_size=vision_config["intermediate_size"], 379 | num_attention_heads=vision_config["num_attention_heads"], 380 | num_channels=3, 381 | image_size=vision_config["image_size"], 382 | patch_size=vision_config["patch_size"], 383 | layer_norm_eps=vision_config["layer_norm_eps"], 384 | ) 385 | 386 | config = CLIPConfig( 387 | text_config=text_config, 388 | vision_config=vision_config, 389 | projection_dim=config["projection_dim"], 390 | ) 391 | model = CLIPModel(config) 392 | weight_files = glob.glob(str(path / "*.safetensors")) 393 | if not weight_files: 394 | logging.error(f"No safetensors found in {path}") 395 | raise FileNotFoundError(f"No safetensors found in {path}") 396 | 397 | weights = {} 398 | for wf in weight_files: 399 | weights.update(mx.load(wf)) 400 | 401 | weights = model.sanitize(weights) 402 | model.load_weights(list(weights.items())) 403 | return model 404 | 405 | @staticmethod 406 | def sanitize(weights): 407 | sanitized_weights = {} 408 | for k, v in weights.items(): 409 | if "position_ids" in k: 410 | # Remove unused position_ids 411 | continue 412 | elif "patch_embedding.weight" in k: 413 | # pytorch conv2d expects the weight tensor to be of shape [out_channels, in_channels, kH, KW] 414 | # mlx conv2d expects the weight tensor to be of shape [out_channels, kH, KW, in_channels] 415 | sanitized_weights[k] = v.transpose(0, 2, 3, 1) 416 | else: 417 | sanitized_weights[k] = v 418 | 419 | return sanitized_weights 420 | --------------------------------------------------------------------------------