├── .gitignore ├── ACKNOWLEDGEMENTS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── LICENSE_weights_data ├── README.md ├── docs └── fig_accuracy_latency.png ├── eval └── zeroshot_imagenet.py ├── get_pretrained_models.sh ├── hf_dataset_example.py ├── ios_app ├── ACKNOWLEDGEMENTS ├── MobileCLIPExplore.xcodeproj │ ├── project.pbxproj │ ├── project.xcworkspace │ │ ├── contents.xcworkspacedata │ │ ├── xcshareddata │ │ │ └── IDEWorkspaceChecks.plist │ │ └── xcuserdata │ │ │ └── pavankumaranasosaluvasu.xcuserdatad │ │ │ └── UserInterfaceState.xcuserstate │ └── xcuserdata │ │ └── pavankumaranasosaluvasu.xcuserdatad │ │ └── xcschemes │ │ └── xcschememanagement.plist ├── MobileCLIPExplore │ ├── Assets.xcassets │ │ ├── AccentColor.colorset │ │ │ └── Contents.json │ │ ├── AddClassName.colorset │ │ │ └── Contents.json │ │ ├── AppIcon.appiconset │ │ │ └── Contents.json │ │ ├── Contents.json │ │ ├── Preset.colorset │ │ │ └── Contents.json │ │ ├── PresetDisabled.colorset │ │ │ └── Contents.json │ │ ├── PromptEditor.colorset │ │ │ └── Contents.json │ │ └── PromptEditorBackground.colorset │ │ │ └── Contents.json │ ├── AsyncFactory.swift │ ├── CameraController.swift │ ├── Configuration │ │ └── Build.xcconfig │ ├── ContentView.swift │ ├── DisplayPrediction.swift │ ├── Extensions.swift │ ├── MobileCLIPExploreApp.swift │ ├── Models.swift │ ├── Models │ │ └── README.md │ ├── Preview Content │ │ └── Preview Assets.xcassets │ │ │ └── Contents.json │ ├── Prompt.swift │ ├── Resources │ │ ├── clip-merges.txt │ │ └── clip-vocab.json │ ├── Tokenizer │ │ ├── CLIPTokenizer.swift │ │ ├── GPT2ByteEncoder.swift │ │ └── Utils.swift │ ├── Views │ │ ├── Helpers │ │ │ ├── PredictionsTable.swift │ │ │ ├── PromptPreview.swift │ │ │ └── VideoFrameView.swift │ │ ├── InferenceView.swift │ │ ├── InfoView.swift │ │ └── PromptEditor.swift │ └── ZSImageClassification.swift ├── README.md └── docs │ └── app_screenshots │ ├── examples.png │ └── model_directory_guide.png ├── mobileclip ├── __init__.py ├── clip.py ├── configs │ ├── mobileclip_b.json │ ├── mobileclip_s0.json │ ├── mobileclip_s1.json │ └── mobileclip_s2.json ├── image_encoder.py ├── logger.py ├── models │ ├── __init__.py │ ├── mci.py │ └── vit.py ├── modules │ ├── __init__.py │ ├── common │ │ ├── __init__.py │ │ ├── mobileone.py │ │ └── transformer.py │ ├── image │ │ ├── __init__.py │ │ ├── image_projection.py │ │ └── replknet.py │ └── text │ │ ├── __init__.py │ │ ├── repmixer.py │ │ └── tokenizer.py └── text_encoder.py ├── requirements.txt ├── results ├── mobileclip_b.jsonl ├── mobileclip_b_lt.jsonl ├── mobileclip_s0.jsonl ├── mobileclip_s1.jsonl └── mobileclip_s2.jsonl ├── setup.py └── training ├── README.md ├── __init__.py ├── configs ├── datacompdr12m.json ├── datacompdr1b.json ├── run_datacomp12m.sh ├── run_datacompdr12m.sh └── run_datacompdr1b.sh ├── dr ├── __init__.py ├── transforms.py └── transforms_base.py └── open_clip.patch /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before 
PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # macOS 132 | **/.DS_Store 133 | 134 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4, 71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. 
This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository. 4 | 5 | While we welcome new pull requests and issues please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged. 6 | 7 | ## Before you get started 8 | 9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE). 10 | 11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright © 2024 Apple Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | ------------------------------------------------------------------------------- 24 | SOFTWARE DISTRIBUTED WITH ML-MobileCLIP: 25 | 26 | The ML-MobileCLIP model weights and data copyright and license terms can be 27 | found in LICENSE_weights_data. 28 | 29 | The ML-MobileCLIP software includes a number of subcomponents with separate 30 | copyright notices and license terms - please see the file ACKNOWLEDGEMENTS. 31 | ------------------------------------------------------------------------------- 32 | -------------------------------------------------------------------------------- /LICENSE_weights_data: -------------------------------------------------------------------------------- 1 | ML-MobileCLIP Model Weights and Data 2 | 3 | Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | 5 | IMPORTANT: This Apple software is supplied to you by Apple 6 | Inc. ("Apple") in consideration of your agreement to the following 7 | terms, and your use, installation, modification or redistribution of 8 | this Apple software constitutes acceptance of these terms. If you do 9 | not agree with these terms, please do not use, install, modify or 10 | redistribute this Apple software. 
11 | 12 | In consideration of your agreement to abide by the following terms, and 13 | subject to these terms, Apple grants you a personal, non-exclusive 14 | license, under Apple's copyrights in this original Apple software (the 15 | "Apple Software"), to use, reproduce, modify and redistribute the Apple 16 | Software, with or without modifications, in source and/or binary forms; 17 | provided that if you redistribute the Apple Software in its entirety and 18 | without modifications, you must retain this notice and the following 19 | text and disclaimers in all such redistributions of the Apple Software. 20 | Neither the name, trademarks, service marks or logos of Apple Inc. may 21 | be used to endorse or promote products derived from the Apple Software 22 | without specific prior written permission from Apple. Except as 23 | expressly stated in this notice, no other rights or licenses, express or 24 | implied, are granted by Apple herein, including but not limited to any 25 | patent rights that may be infringed by your derivative works or by other 26 | works in which the Apple Software may be incorporated. 27 | 28 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE 29 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION 30 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS 31 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND 32 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. 33 | 34 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL 35 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 36 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 37 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, 38 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED 39 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), 40 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE 41 | POSSIBILITY OF SUCH DAMAGE. 42 | 43 | ------------------------------------------------------------------------------- 44 | SOFTWARE DISTRIBUTED WITH ML-MobileCLIP: 45 | 46 | The ML-MobileCLIP software copyright and license terms can be found in LICENSE. 47 | 48 | The ML-MobileCLIP software includes a number of subcomponents with separate 49 | copyright notices and license terms - please see the file ACKNOWLEDGEMENTS. 50 | ------------------------------------------------------------------------------- 51 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MobileCLIP: Fast Image-Text Models through Multi-Modal Reinforced Training 2 | 3 | This is the official repository of 4 | **[MobileCLIP: Fast Image-Text Models through Multi-Modal Reinforced Training](https://arxiv.org/pdf/2311.17049.pdf). (CVPR 2024)** 5 | *Pavan Kumar Anasosalu Vasu, Hadi Pouransari, Fartash Faghri, Raviteja Vemulapalli, Oncel Tuzel.* 6 | The repository contains code for inference, training, and evaluation of MobileCLIP models trained on DataCompDR datasets. 7 | 8 | [//]: # (![MobileCLIP Performance](docs/fig_accuracy_latency.png)) 9 |

10 | ![Accuracy vs latency figure.](docs/fig_accuracy_latency.png) 11 |

12 | 13 | - **Update 2024/11/22:** Releasing iOS app to demonstrate the use of our model for real-time zero-shot image classification. See [ios_app](./ios_app/). 14 | - **Update 2024/06/13:** Releasing the code and scripts to train using [OpenCLIP](https://github.com/mlfoundations/open_clip/tree/main/src/open_clip) on DataCompDR datasets. See [training/](./training/). 15 | - **Update 2024/06/13:** MobileCLIP models and DataCompDR datasets are now available on HuggingFace in [MobileCLIP/DataCompDR Collection](https://huggingface.co/collections/apple/mobileclip-models-datacompdr-data-665789776e1aa2b59f35f7c8). 16 | 17 | ### Highlights 18 | * Our smallest variant `MobileCLIP-S0` obtains similar zero-shot performance as [OpenAI](https://arxiv.org/abs/2103.00020)'s ViT-B/16 model while being 4.8x faster and 2.8x smaller. 19 | * `MobileCLIP-S2` obtains better avg zero-shot performance than [SigLIP](https://arxiv.org/abs/2303.15343)'s ViT-B/16 model while being 2.3x faster and 2.1x smaller, and trained with 3x less seen samples. 20 | * `MobileCLIP-B`(LT) attains zero-shot ImageNet performance of **77.2%** which is significantly better than recent works like [DFN](https://arxiv.org/abs/2309.17425) and [SigLIP](https://arxiv.org/abs/2303.15343) with similar architectures or even [OpenAI's ViT-L/14@336](https://arxiv.org/abs/2103.00020). 21 | * iOS app to demonstrate the superior performance of our model on a mobile device. 22 | 23 | ![Examples](ios_app/docs/app_screenshots/examples.png) 24 | 25 | ## Getting Started 26 | 27 | ### Setup 28 | ```bash 29 | conda create -n clipenv python=3.10 30 | conda activate clipenv 31 | pip install -e . 32 | ``` 33 | To download pretrained checkpoints follow the code snippet below 34 | ```bash 35 | source get_pretrained_models.sh # Files will be downloaded to `checkpoints` directory. 36 | ``` 37 | 38 | ### Usage Example 39 | To models from the official repo, follow the code snippet below 40 | ```python 41 | import torch 42 | from PIL import Image 43 | import mobileclip 44 | 45 | model, _, preprocess = mobileclip.create_model_and_transforms('mobileclip_s0', pretrained='/path/to/mobileclip_s0.pt') 46 | tokenizer = mobileclip.get_tokenizer('mobileclip_s0') 47 | 48 | image = preprocess(Image.open("docs/fig_accuracy_latency.png").convert('RGB')).unsqueeze(0) 49 | text = tokenizer(["a diagram", "a dog", "a cat"]) 50 | 51 | with torch.no_grad(), torch.cuda.amp.autocast(): 52 | image_features = model.encode_image(image) 53 | text_features = model.encode_text(text) 54 | image_features /= image_features.norm(dim=-1, keepdim=True) 55 | text_features /= text_features.norm(dim=-1, keepdim=True) 56 | 57 | text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) 58 | 59 | print("Label probs:", text_probs) 60 | ``` 61 | 62 | For an example of loading the data from HuggingFace see 63 | [hf_dataset_example.py](./hf_dataset_example.py). 64 | 65 | ### OpenCLIP Support 66 | Our models are now natively supported in OpenCLIP. 
To use MobileCLIP models in OpenCLIP, setup your environment as shown below, 67 | ```bash 68 | conda create -n clipenv python=3.10 69 | conda activate clipenv 70 | 71 | pip install git+https://github.com/mlfoundations/open_clip 72 | pip install git+https://github.com/huggingface/pytorch-image-models 73 | ``` 74 | 75 | To run inference, see example below, 76 | ```python 77 | import open_clip 78 | from mobileclip.modules.common.mobileone import reparameterize_model 79 | 80 | model, _, preprocess = open_clip.create_model_and_transforms('MobileCLIP-S2', pretrained='datacompdr') 81 | tokenizer = open_clip.get_tokenizer('MobileCLIP-S2') 82 | 83 | # For inference/model exporting purposes, please reparameterize first 84 | model.eval() 85 | model = reparameterize_model(model) 86 | 87 | # ... follow examples in open_clip repo ... 88 | ``` 89 | Variants currently available on OpenCLIP, 90 | `[('MobileCLIP-S1', 'datacompdr'), 91 | ('MobileCLIP-S2', 'datacompdr'), 92 | ('MobileCLIP-B', 'datacompdr'), 93 | ('MobileCLIP-B', 'datacompdr_lt')]` 94 | 95 | 96 | ## Evaluation 97 | Please find the detailed evaluation results [here](./results). 98 | To reproduce results, we provide script to perform zero-shot evaluation on ImageNet-1k dataset. 99 | To evaluate on all the 38 datasets, please follow instructions in [datacomp](https://github.com/mlfoundations/datacomp). 100 | ```bash 101 | # Run evaluation with single GPU 102 | python eval/zeroshot_imagenet.py --model-arch mobileclip_s0 --model-path /path/to/mobileclip_s0.pt 103 | ``` 104 | 105 | Please refer to [Open CLIP Results](https://github.com/mlfoundations/open_clip/blob/main/docs/openclip_results.csv) to compare with other models. 106 | 107 | | Model | # Seen
Samples (B) | # Params (M) <br> (img + txt) | Latency (ms) <br> (img + txt) | IN-1k Zero-Shot <br> Top-1 Acc. (%) | Avg. Perf. (%) <br>
on 38 datasets | Pytorch Checkpoint (url) | 108 | |:------------------|:----------------------:|:-----------------------------:|:------------------------------:|:-----------------------------------:|:----------------------------------:|:--------------------------------------------------------------------------------------------------------------:| 109 | | MobileCLIP-S0 | 13 | 11.4 + 42.4 | 1.5 + 1.6 | 67.8 | 58.1 | [mobileclip_s0.pt](https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s0.pt) | 110 | | MobileCLIP-S1 | 13 | 21.5 + 63.4 | 2.5 + 3.3 | 72.6 | 61.3 | [mobileclip_s1.pt](https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s1.pt) | 111 | | MobileCLIP-S2 | 13 | 35.7 + 63.4 | 3.6 + 3.3 | 74.4 | 63.7 | [mobileclip_s2.pt](https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s2.pt) | 112 | | MobileCLIP-B | 13 | 86.3 + 63.4 | 10.4 + 3.3 | 76.8 | 65.2 | [mobileclip_b.pt](https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_b.pt) | 113 | | MobileCLIP-B (LT) | 36 | 86.3 + 63.4 | 10.4 + 3.3 | 77.2 | 65.8 | [mobileclip_blt.pt](https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_blt.pt) | 114 | 115 | Note: MobileCLIP-B(LT) is trained for 300k iterations with constant learning rate schedule and 300k iterations with cosine learning rate schedule. 116 | 117 | ## Citation 118 | If you found this code useful, please cite the following paper: 119 | ``` 120 | @InProceedings{mobileclip2024, 121 | author = {Pavan Kumar Anasosalu Vasu, Hadi Pouransari, Fartash Faghri, Raviteja Vemulapalli, Oncel Tuzel}, 122 | title = {MobileCLIP: Fast Image-Text Models through Multi-Modal Reinforced Training}, 123 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 124 | month = {June}, 125 | year = {2024}, 126 | } 127 | ``` 128 | 129 | ## Acknowledgements 130 | Our codebase is built using multiple opensource contributions, please see [ACKNOWLEDGEMENTS](ACKNOWLEDGEMENTS) for more details. 131 | -------------------------------------------------------------------------------- /docs/fig_accuracy_latency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-mobileclip/1140b8d197e4ed7d56b3a92216ded98bb1c2ac87/docs/fig_accuracy_latency.png -------------------------------------------------------------------------------- /eval/zeroshot_imagenet.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | """ 6 | Evaluation code is borrowed from https://github.com/mlfoundations/datacomp/blob/main/eval_utils/wds_eval.py 7 | Licensed under MIT License, see ACKNOWLEDGEMENTS for details. 
8 | """ 9 | 10 | import os 11 | import argparse 12 | 13 | import mobileclip 14 | import torch 15 | from clip_benchmark.datasets.builder import build_dataset 16 | from clip_benchmark.metrics import zeroshot_classification as zsc 17 | 18 | 19 | def parse_args(parser): 20 | parser.add_argument( 21 | "--model-arch", 22 | type=str, 23 | required=True, 24 | help="Specify model arch from the available choices.", 25 | choices=['mobileclip_s0', 'mobileclip_s1', 'mobileclip_s2', 'mobileclip_b'] 26 | ) 27 | parser.add_argument( 28 | "--model-path", 29 | type=str, 30 | required=True, 31 | help="Specify location of model checkpoint.", 32 | ) 33 | return parser 34 | 35 | 36 | def create_model(model_arch, model_path): 37 | device = "cuda" if torch.cuda.is_available() else "cpu" 38 | torch.manual_seed(0) 39 | 40 | model_path = str(model_path) 41 | model, _, transform = mobileclip.create_model_and_transforms( 42 | model_arch, pretrained=model_path 43 | ) 44 | model.eval() 45 | model = model.to(device) 46 | 47 | return model, transform, device 48 | 49 | 50 | def create_webdataset( 51 | task, transform, data_root=None, dataset_len=None, batch_size=64, num_workers=4 52 | ): 53 | data_folder = f"wds_{task.replace('/','-')}_test" 54 | if data_root is None: 55 | data_root = f"https://huggingface.co/datasets/djghosh/{data_folder}/tree/main" 56 | else: 57 | data_root = os.path.join(data_root, data_folder) 58 | dataset = build_dataset( 59 | dataset_name=f"wds/{task}", 60 | root=data_root, 61 | transform=transform, 62 | split="test", 63 | download=False, 64 | ) 65 | if dataset_len: 66 | dataset = dataset.with_length((dataset_len + batch_size - 1) // batch_size) 67 | dataloader = torch.utils.data.DataLoader( 68 | dataset.batched(batch_size), 69 | batch_size=None, 70 | shuffle=False, 71 | num_workers=num_workers, 72 | ) 73 | return dataset, dataloader 74 | 75 | 76 | def evaluate_webdataset( 77 | task, 78 | model_arch, 79 | model_path, 80 | data_root=None, 81 | dataset_len=None, 82 | batch_size=64, 83 | num_workers=4, 84 | ): 85 | """Evaluate CLIP model on classification task.""" 86 | 87 | # Create model 88 | model, transform, device = create_model(model_arch, model_path) 89 | 90 | # Load data 91 | dataset, dataloader = create_webdataset( 92 | task, transform, data_root, dataset_len, batch_size, num_workers 93 | ) 94 | 95 | zeroshot_templates = dataset.templates if hasattr(dataset, "templates") else None 96 | classnames = dataset.classes if hasattr(dataset, "classes") else None 97 | assert ( 98 | zeroshot_templates is not None and classnames is not None 99 | ), "Dataset does not support classification" 100 | 101 | # Evaluate 102 | classifier = zsc.zero_shot_classifier( 103 | model, 104 | mobileclip.get_tokenizer(model_arch), 105 | classnames, 106 | zeroshot_templates, 107 | device, 108 | ) 109 | logits, target = zsc.run_classification( 110 | model, classifier, dataloader, device, amp=False 111 | ) 112 | 113 | # Compute metrics 114 | acc1, acc5 = zsc.accuracy(logits, target, topk=(1, 5)) 115 | metrics = { 116 | "acc1": acc1, 117 | "acc5": acc5, 118 | } 119 | return metrics 120 | 121 | 122 | if __name__ == "__main__": 123 | parser = argparse.ArgumentParser(description="Webdataset evaluation script.") 124 | parser = parse_args(parser) 125 | args = parser.parse_args() 126 | 127 | metric = evaluate_webdataset( 128 | task="imagenet1k", model_arch=args.model_arch, model_path=args.model_path 129 | ) 130 | print(f"ImageNet1k Eval Metrics: {metric}") 131 | 
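# Usage sketch: a hypothetical helper (not wired into the CLI above) showing how
# `evaluate_webdataset` can be driven programmatically, e.g. to score every
# released checkpoint in one pass. The `checkpoints/` layout is an assumption
# based on `get_pretrained_models.sh`; adjust the paths to your local setup.
def evaluate_all_checkpoints(checkpoint_dir="checkpoints"):
    results = {}
    for arch in ["mobileclip_s0", "mobileclip_s1", "mobileclip_s2", "mobileclip_b"]:
        # Checkpoint filenames follow the naming used by get_pretrained_models.sh.
        model_path = os.path.join(checkpoint_dir, f"{arch}.pt")
        results[arch] = evaluate_webdataset(
            task="imagenet1k", model_arch=arch, model_path=model_path
        )
    return results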
-------------------------------------------------------------------------------- /get_pretrained_models.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # For licensing see accompanying LICENSE file. 4 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 5 | # 6 | mkdir -p checkpoints 7 | wget https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_b.pt -P checkpoints 8 | wget https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_blt.pt -P checkpoints 9 | wget https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s0.pt -P checkpoints 10 | wget https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s1.pt -P checkpoints 11 | wget https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s2.pt -P checkpoints 12 | -------------------------------------------------------------------------------- /hf_dataset_example.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | from PIL import Image 6 | import requests 7 | from io import BytesIO 8 | from datasets import load_dataset 9 | import json 10 | import numpy as np 11 | import torch 12 | 13 | from training.dr.transforms import compose_from_config 14 | 15 | 16 | if __name__ == '__main__': 17 | rconfig_aug = { 18 | "normalize": { 19 | "mean": [0.48145466, 0.4578275, 0.40821073], 20 | "std": [0.26862954, 0.26130258, 0.27577711] 21 | }, 22 | "rand_augment": {"enable": True, "p": 1.0}, 23 | "random_resized_crop": {"interpolation": "bicubic", "size": 224}, 24 | "to_rgb": {"enable": True}, 25 | "to_tensor": {"enable": True} 26 | } 27 | dr_transforms = compose_from_config(rconfig_aug) 28 | 29 | dataset = load_dataset("apple/DataCompDR-12M", split="train", streaming=True) 30 | sample = next(iter(dataset)) 31 | 32 | # Load image from URL 33 | url = sample['url.txt'] 34 | response = requests.get(url) 35 | img = Image.open(BytesIO(response.content)) 36 | sample["image"] = img 37 | 38 | # Preprocess image 39 | # Sample an image augmentation 40 | param_augs = json.loads(sample["paug.json"]["param_aug"]) 41 | aug_idx = np.random.randint(0, len(param_augs)) 42 | params = param_augs[aug_idx] 43 | params = dr_transforms.decompress(params) 44 | image = sample["image"].convert('RGB') 45 | image, _ = dr_transforms.reapply(image, params) 46 | 47 | # Preprocess synthetic text 48 | scapi = np.random.randint(0, len(sample["syn.json"]["syn_text"])) 49 | syn_text = sample["syn.json"]["syn_text"][scapi] 50 | 51 | # Preprocess embeddings 52 | if "npz" in sample: 53 | image_emb = sample["npz"]["image_emb"][aug_idx] 54 | text_emb_all = sample["npz"]["text_emb"] 55 | elif "pth.gz" in sample: 56 | image_emb = sample["pth.gz"]["image_emb"][aug_idx] 57 | text_emb_all = sample["pth.gz"]["text_emb"] 58 | capi = 0 59 | text_emb = text_emb_all[capi] 60 | syn_text_emb = text_emb_all[1+scapi] 61 | if not isinstance(image_emb, torch.Tensor): 62 | image_emb = torch.tensor(image_emb) 63 | text_emb = torch.tensor(text_emb) 64 | syn_text_emb = torch.tensor(syn_text_emb) 65 | image_emb = image_emb.type(torch.float32) 66 | text_emb = text_emb.type(torch.float32) 67 | syn_text_emb = syn_text_emb.type(torch.float32) 68 | 69 | print( 70 | { 71 | 'image': image.shape, 72 | 'image_emb': image_emb.shape, 73 | 'text_emb': text_emb.shape, 74 | 
"syn_text": syn_text, 75 | 'syn_text_emb': syn_text_emb.shape, 76 | } 77 | ) 78 | -------------------------------------------------------------------------------- /ios_app/ACKNOWLEDGEMENTS: -------------------------------------------------------------------------------- 1 | Acknowledgements 2 | Portions of this Software may utilize the following copyrighted 3 | material, the use of which is hereby acknowledged. 4 | 5 | ------------------------------------------------ 6 | swift-coreml-transformers (huggingface) 7 | https://github.com/huggingface/swift-coreml-transformers/tree/master 8 | 9 | Apache License 10 | Version 2.0, January 2004 11 | http://www.apache.org/licenses/ 12 | 13 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 14 | 15 | 1. Definitions. 16 | 17 | "License" shall mean the terms and conditions for use, reproduction, 18 | and distribution as defined by Sections 1 through 9 of this document. 19 | 20 | "Licensor" shall mean the copyright owner or entity authorized by 21 | the copyright owner that is granting the License. 22 | 23 | "Legal Entity" shall mean the union of the acting entity and all 24 | other entities that control, are controlled by, or are under common 25 | control with that entity. For the purposes of this definition, 26 | "control" means (i) the power, direct or indirect, to cause the 27 | direction or management of such entity, whether by contract or 28 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 29 | outstanding shares, or (iii) beneficial ownership of such entity. 30 | 31 | "You" (or "Your") shall mean an individual or Legal Entity 32 | exercising permissions granted by this License. 33 | 34 | "Source" form shall mean the preferred form for making modifications, 35 | including but not limited to software source code, documentation 36 | source, and configuration files. 37 | 38 | "Object" form shall mean any form resulting from mechanical 39 | transformation or translation of a Source form, including but 40 | not limited to compiled object code, generated documentation, 41 | and conversions to other media types. 42 | 43 | "Work" shall mean the work of authorship, whether in Source or 44 | Object form, made available under the License, as indicated by a 45 | copyright notice that is included in or attached to the work 46 | (an example is provided in the Appendix below). 47 | 48 | "Derivative Works" shall mean any work, whether in Source or Object 49 | form, that is based on (or derived from) the Work and for which the 50 | editorial revisions, annotations, elaborations, or other modifications 51 | represent, as a whole, an original work of authorship. For the purposes 52 | of this License, Derivative Works shall not include works that remain 53 | separable from, or merely link (or bind by name) to the interfaces of, 54 | the Work and Derivative Works thereof. 55 | 56 | "Contribution" shall mean any work of authorship, including 57 | the original version of the Work and any modifications or additions 58 | to that Work or Derivative Works thereof, that is intentionally 59 | submitted to Licensor for inclusion in the Work by the copyright owner 60 | or by an individual or Legal Entity authorized to submit on behalf of 61 | the copyright owner. 
For the purposes of this definition, "submitted" 62 | means any form of electronic, verbal, or written communication sent 63 | to the Licensor or its representatives, including but not limited to 64 | communication on electronic mailing lists, source code control systems, 65 | and issue tracking systems that are managed by, or on behalf of, the 66 | Licensor for the purpose of discussing and improving the Work, but 67 | excluding communication that is conspicuously marked or otherwise 68 | designated in writing by the copyright owner as "Not a Contribution." 69 | 70 | "Contributor" shall mean Licensor and any individual or Legal Entity 71 | on behalf of whom a Contribution has been received by Licensor and 72 | subsequently incorporated within the Work. 73 | 74 | 2. Grant of Copyright License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | copyright license to reproduce, prepare Derivative Works of, 78 | publicly display, publicly perform, sublicense, and distribute the 79 | Work and such Derivative Works in Source or Object form. 80 | 81 | 3. Grant of Patent License. Subject to the terms and conditions of 82 | this License, each Contributor hereby grants to You a perpetual, 83 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 84 | (except as stated in this section) patent license to make, have made, 85 | use, offer to sell, sell, import, and otherwise transfer the Work, 86 | where such license applies only to those patent claims licensable 87 | by such Contributor that are necessarily infringed by their 88 | Contribution(s) alone or by combination of their Contribution(s) 89 | with the Work to which such Contribution(s) was submitted. If You 90 | institute patent litigation against any entity (including a 91 | cross-claim or counterclaim in a lawsuit) alleging that the Work 92 | or a Contribution incorporated within the Work constitutes direct 93 | or contributory patent infringement, then any patent licenses 94 | granted to You under this License for that Work shall terminate 95 | as of the date such litigation is filed. 96 | 97 | 4. Redistribution. 
You may reproduce and distribute copies of the 98 | Work or Derivative Works thereof in any medium, with or without 99 | modifications, and in Source or Object form, provided that You 100 | meet the following conditions: 101 | 102 | (a) You must give any other recipients of the Work or 103 | Derivative Works a copy of this License; and 104 | 105 | (b) You must cause any modified files to carry prominent notices 106 | stating that You changed the files; and 107 | 108 | (c) You must retain, in the Source form of any Derivative Works 109 | that You distribute, all copyright, patent, trademark, and 110 | attribution notices from the Source form of the Work, 111 | excluding those notices that do not pertain to any part of 112 | the Derivative Works; and 113 | 114 | (d) If the Work includes a "NOTICE" text file as part of its 115 | distribution, then any Derivative Works that You distribute must 116 | include a readable copy of the attribution notices contained 117 | within such NOTICE file, excluding those notices that do not 118 | pertain to any part of the Derivative Works, in at least one 119 | of the following places: within a NOTICE text file distributed 120 | as part of the Derivative Works; within the Source form or 121 | documentation, if provided along with the Derivative Works; or, 122 | within a display generated by the Derivative Works, if and 123 | wherever such third-party notices normally appear. The contents 124 | of the NOTICE file are for informational purposes only and 125 | do not modify the License. You may add Your own attribution 126 | notices within Derivative Works that You distribute, alongside 127 | or as an addendum to the NOTICE text from the Work, provided 128 | that such additional attribution notices cannot be construed 129 | as modifying the License. 130 | 131 | You may add Your own copyright statement to Your modifications and 132 | may provide additional or different license terms and conditions 133 | for use, reproduction, or distribution of Your modifications, or 134 | for any such Derivative Works as a whole, provided Your use, 135 | reproduction, and distribution of the Work otherwise complies with 136 | the conditions stated in this License. 137 | 138 | 5. Submission of Contributions. Unless You explicitly state otherwise, 139 | any Contribution intentionally submitted for inclusion in the Work 140 | by You to the Licensor shall be under the terms and conditions of 141 | this License, without any additional terms or conditions. 142 | Notwithstanding the above, nothing herein shall supersede or modify 143 | the terms of any separate license agreement you may have executed 144 | with Licensor regarding such Contributions. 145 | 146 | 6. Trademarks. This License does not grant permission to use the trade 147 | names, trademarks, service marks, or product names of the Licensor, 148 | except as required for reasonable and customary use in describing the 149 | origin of the Work and reproducing the content of the NOTICE file. 150 | 151 | 7. Disclaimer of Warranty. Unless required by applicable law or 152 | agreed to in writing, Licensor provides the Work (and each 153 | Contributor provides its Contributions) on an "AS IS" BASIS, 154 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 155 | implied, including, without limitation, any warranties or conditions 156 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 157 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 158 | appropriateness of using or redistributing the Work and assume any 159 | risks associated with Your exercise of permissions under this License. 160 | 161 | 8. Limitation of Liability. In no event and under no legal theory, 162 | whether in tort (including negligence), contract, or otherwise, 163 | unless required by applicable law (such as deliberate and grossly 164 | negligent acts) or agreed to in writing, shall any Contributor be 165 | liable to You for damages, including any direct, indirect, special, 166 | incidental, or consequential damages of any character arising as a 167 | result of this License or out of the use or inability to use the 168 | Work (including but not limited to damages for loss of goodwill, 169 | work stoppage, computer failure or malfunction, or any and all 170 | other commercial damages or losses), even if such Contributor 171 | has been advised of the possibility of such damages. 172 | 173 | 9. Accepting Warranty or Additional Liability. While redistributing 174 | the Work or Derivative Works thereof, You may choose to offer, 175 | and charge a fee for, acceptance of support, warranty, indemnity, 176 | or other liability obligations and/or rights consistent with this 177 | License. However, in accepting such obligations, You may act only 178 | on Your own behalf and on Your sole responsibility, not on behalf 179 | of any other Contributor, and only if You agree to indemnify, 180 | defend, and hold each Contributor harmless for any liability 181 | incurred by, or claims asserted against, such Contributor by reason 182 | of your accepting any such warranty or additional liability. 183 | 184 | END OF TERMS AND CONDITIONS 185 | 186 | APPENDIX: How to apply the Apache License to your work. 187 | 188 | To apply the Apache License to your work, attach the following 189 | boilerplate notice, with the fields enclosed by brackets "[]" 190 | replaced with your own identifying information. (Don't include 191 | the brackets!) The text should be enclosed in the appropriate 192 | comment syntax for the file format. We also recommend that a 193 | file or class name and description of purpose be included on the 194 | same "printed page" as the copyright notice for easier 195 | identification within third-party archives. 196 | 197 | Copyright [yyyy] [name of copyright owner] 198 | 199 | Licensed under the Apache License, Version 2.0 (the "License"); 200 | you may not use this file except in compliance with the License. 201 | You may obtain a copy of the License at 202 | 203 | http://www.apache.org/licenses/LICENSE-2.0 204 | 205 | Unless required by applicable law or agreed to in writing, software 206 | distributed under the License is distributed on an "AS IS" BASIS, 207 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 208 | See the License for the specific language governing permissions and 209 | limitations under the License. 
210 | -------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore.xcodeproj/project.xcworkspace/contents.xcworkspacedata: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | IDEDidComputeMac32BitWarning 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore.xcodeproj/project.xcworkspace/xcuserdata/pavankumaranasosaluvasu.xcuserdatad/UserInterfaceState.xcuserstate: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-mobileclip/1140b8d197e4ed7d56b3a92216ded98bb1c2ac87/ios_app/MobileCLIPExplore.xcodeproj/project.xcworkspace/xcuserdata/pavankumaranasosaluvasu.xcuserdatad/UserInterfaceState.xcuserstate -------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore.xcodeproj/xcuserdata/pavankumaranasosaluvasu.xcuserdatad/xcschemes/xcschememanagement.plist: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | SchemeUserState 6 | 7 | MobileCLIPExplore.xcscheme_^#shared#^_ 8 | 9 | orderHint 10 | 0 11 | 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore/Assets.xcassets/AccentColor.colorset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "colors" : [ 3 | { 4 | "idiom" : "universal" 5 | } 6 | ], 7 | "info" : { 8 | "author" : "xcode", 9 | "version" : 1 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore/Assets.xcassets/AddClassName.colorset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "colors" : [ 3 | { 4 | "color" : { 5 | "platform" : "universal", 6 | "reference" : "systemBlueColor" 7 | }, 8 | "idiom" : "universal" 9 | }, 10 | { 11 | "appearances" : [ 12 | { 13 | "appearance" : "luminosity", 14 | "value" : "dark" 15 | } 16 | ], 17 | "color" : { 18 | "platform" : "universal", 19 | "reference" : "systemBlueColor" 20 | }, 21 | "idiom" : "universal" 22 | } 23 | ], 24 | "info" : { 25 | "author" : "xcode", 26 | "version" : 1 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore/Assets.xcassets/AppIcon.appiconset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "images" : [ 3 | { 4 | "idiom" : "universal", 5 | "platform" : "ios", 6 | "size" : "1024x1024" 7 | } 8 | ], 9 | "info" : { 10 | "author" : "xcode", 11 | "version" : 1 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore/Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore/Assets.xcassets/Preset.colorset/Contents.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "colors" : [ 3 | { 4 | "color" : { 5 | "platform" : "ios", 6 | "reference" : "linkColor" 7 | }, 8 | "idiom" : "universal" 9 | }, 10 | { 11 | "appearances" : [ 12 | { 13 | "appearance" : "luminosity", 14 | "value" : "dark" 15 | } 16 | ], 17 | "color" : { 18 | "platform" : "ios", 19 | "reference" : "linkColor" 20 | }, 21 | "idiom" : "universal" 22 | } 23 | ], 24 | "info" : { 25 | "author" : "xcode", 26 | "version" : 1 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore/Assets.xcassets/PresetDisabled.colorset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "colors" : [ 3 | { 4 | "color" : { 5 | "color-space" : "display-p3", 6 | "components" : { 7 | "alpha" : "1.000", 8 | "blue" : "0.966", 9 | "green" : "0.661", 10 | "red" : "0.544" 11 | } 12 | }, 13 | "idiom" : "universal" 14 | }, 15 | { 16 | "appearances" : [ 17 | { 18 | "appearance" : "luminosity", 19 | "value" : "dark" 20 | } 21 | ], 22 | "color" : { 23 | "color-space" : "display-p3", 24 | "components" : { 25 | "alpha" : "1.000", 26 | "blue" : "0.967", 27 | "green" : "0.692", 28 | "red" : "0.522" 29 | } 30 | }, 31 | "idiom" : "universal" 32 | } 33 | ], 34 | "info" : { 35 | "author" : "xcode", 36 | "version" : 1 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore/Assets.xcassets/PromptEditor.colorset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "colors" : [ 3 | { 4 | "color" : { 5 | "color-space" : "gray-gamma-22", 6 | "components" : { 7 | "alpha" : "1.000", 8 | "white" : "1.000" 9 | } 10 | }, 11 | "idiom" : "universal" 12 | }, 13 | { 14 | "appearances" : [ 15 | { 16 | "appearance" : "luminosity", 17 | "value" : "dark" 18 | } 19 | ], 20 | "color" : { 21 | "color-space" : "gray-gamma-22", 22 | "components" : { 23 | "alpha" : "1.000", 24 | "white" : "27" 25 | } 26 | }, 27 | "idiom" : "universal" 28 | } 29 | ], 30 | "info" : { 31 | "author" : "xcode", 32 | "version" : 1 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore/Assets.xcassets/PromptEditorBackground.colorset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "colors" : [ 3 | { 4 | "color" : { 5 | "color-space" : "display-p3", 6 | "components" : { 7 | "alpha" : "1.000", 8 | "blue" : "247", 9 | "green" : "242", 10 | "red" : "242" 11 | } 12 | }, 13 | "idiom" : "universal" 14 | }, 15 | { 16 | "appearances" : [ 17 | { 18 | "appearance" : "luminosity", 19 | "value" : "dark" 20 | } 21 | ], 22 | "color" : { 23 | "color-space" : "srgb", 24 | "components" : { 25 | "alpha" : "1.000", 26 | "blue" : "0.000", 27 | "green" : "0.000", 28 | "red" : "0.000" 29 | } 30 | }, 31 | "idiom" : "universal" 32 | } 33 | ], 34 | "info" : { 35 | "author" : "xcode", 36 | "version" : 1 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore/AsyncFactory.swift: -------------------------------------------------------------------------------- 1 | // 2 | // For licensing see accompanying LICENSE file. 3 | // Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | // 5 | 6 | import Foundation 7 | 8 | /// Asynchronous factory for slow-to-load types. 
9 | public actor AsyncFactory { 10 | 11 | private enum State { 12 | case idle(() -> T) 13 | case initializing(Task) 14 | case initialized(T) 15 | } 16 | 17 | private var state: State 18 | 19 | public init(factory: @escaping () -> T) { 20 | self.state = .idle(factory) 21 | } 22 | 23 | public func get() async -> T { 24 | switch state { 25 | case .idle(let factory): 26 | let task = Task { 27 | factory() 28 | } 29 | self.state = .initializing(task) 30 | let value = await task.value 31 | self.state = .initialized(value) 32 | return value 33 | 34 | case .initializing(let task): 35 | return await task.value 36 | 37 | case .initialized(let v): 38 | return v 39 | } 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore/CameraController.swift: -------------------------------------------------------------------------------- 1 | // 2 | // For licensing see accompanying LICENSE file. 3 | // Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | // 5 | 6 | import AVFoundation 7 | import CoreImage 8 | import UIKit 9 | 10 | @Observable 11 | class CameraController: NSObject { 12 | 13 | private var framesContinuation: AsyncStream.Continuation? 14 | 15 | public var backCamera = true { 16 | didSet { 17 | stop() 18 | start() 19 | } 20 | } 21 | 22 | private var permissionGranted = true 23 | private var captureSession: AVCaptureSession? 24 | private let sessionQueue = DispatchQueue(label: "sessionQueue") 25 | 26 | public func attach(continuation: AsyncStream.Continuation) { 27 | sessionQueue.async { 28 | self.framesContinuation = continuation 29 | } 30 | } 31 | 32 | public func detatch() { 33 | sessionQueue.async { 34 | self.framesContinuation = nil 35 | } 36 | } 37 | 38 | public func stop() { 39 | sessionQueue.sync { [self] in 40 | captureSession?.stopRunning() 41 | captureSession = nil 42 | } 43 | 44 | } 45 | 46 | public func start() { 47 | sessionQueue.async { [self] in 48 | let captureSession = AVCaptureSession() 49 | self.captureSession = captureSession 50 | 51 | self.checkPermission() 52 | self.setupCaptureSession(position: backCamera ? .back : .front) 53 | captureSession.startRunning() 54 | } 55 | } 56 | 57 | private func setOrientation(_ orientation: UIDeviceOrientation) { 58 | guard let captureSession else { return } 59 | 60 | let angle: Double? 61 | switch orientation { 62 | case .unknown, .faceDown: 63 | angle = nil 64 | case .portrait, .faceUp: 65 | angle = 90 66 | case .portraitUpsideDown: 67 | angle = 270 68 | case .landscapeLeft: 69 | angle = 0 70 | case .landscapeRight: 71 | angle = 180 72 | @unknown default: 73 | angle = nil 74 | } 75 | 76 | if let angle { 77 | for output in captureSession.outputs { 78 | output.connection(with: .video)?.videoRotationAngle = angle 79 | } 80 | } 81 | } 82 | 83 | func checkPermission() { 84 | switch AVCaptureDevice.authorizationStatus(for: .video) { 85 | case .authorized: 86 | // The user has previously granted access to the camera. 87 | self.permissionGranted = true 88 | 89 | case .notDetermined: 90 | // The user has not yet been asked for camera access. 91 | self.requestPermission() 92 | 93 | // Combine the two other cases into the default case 94 | default: 95 | self.permissionGranted = false 96 | } 97 | } 98 | 99 | func requestPermission() { 100 | // Strong reference not a problem here but might become one in the future. 
101 | AVCaptureDevice.requestAccess(for: .video) { [unowned self] granted in 102 | self.permissionGranted = granted 103 | } 104 | } 105 | 106 | func setupCaptureSession(position: AVCaptureDevice.Position) { 107 | guard let captureSession else { return } 108 | 109 | let videoOutput = AVCaptureVideoDataOutput() 110 | 111 | guard permissionGranted else { 112 | print("No permission for camera") 113 | return 114 | } 115 | 116 | let videoDeviceDiscoverySession = AVCaptureDevice.DiscoverySession( 117 | deviceTypes: [.builtInDualCamera, .builtInWideAngleCamera], 118 | mediaType: .video, 119 | position: position) 120 | 121 | guard 122 | let videoDevice = videoDeviceDiscoverySession.devices.first 123 | else { 124 | print("Unable to find video device") 125 | return 126 | } 127 | guard let videoDeviceInput = try? AVCaptureDeviceInput(device: videoDevice) else { 128 | print("Unable to create AVCaptureDeviceInput") 129 | return 130 | } 131 | guard captureSession.canAddInput(videoDeviceInput) else { 132 | print("Unable to add input") 133 | return 134 | } 135 | captureSession.addInput(videoDeviceInput) 136 | 137 | videoOutput.setSampleBufferDelegate(self, queue: DispatchQueue(label: "sampleBufferQueue")) 138 | captureSession.addOutput(videoOutput) 139 | captureSession.sessionPreset = AVCaptureSession.Preset.vga640x480 140 | 141 | if videoDevice.isContinuityCamera { 142 | setOrientation(.portrait) 143 | } else { 144 | setOrientation(UIDevice.current.orientation) 145 | } 146 | } 147 | } 148 | 149 | extension CameraController: AVCaptureVideoDataOutputSampleBufferDelegate { 150 | func captureOutput( 151 | _ output: AVCaptureOutput, didOutput sampleBuffer: CMSampleBuffer, 152 | from connection: AVCaptureConnection 153 | ) { 154 | if sampleBuffer.isValid && sampleBuffer.imageBuffer != nil { 155 | framesContinuation?.yield(sampleBuffer) 156 | } 157 | } 158 | } 159 | -------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore/Configuration/Build.xcconfig: -------------------------------------------------------------------------------- 1 | // The `DISAMBIGUATOR` configuration is to make it easier to build 2 | // and run a sample code project. Once you set your project's development team, 3 | // you'll have a unique bundle identifier. This is because the bundle identifier 4 | // is derived based on the 'DISAMBIGUATOR' value. Do not use this 5 | // approach in your own projects—it's only useful for example projects because 6 | // they are frequently downloaded and don't have a development team set. 7 | DISAMBIGUATOR=${DEVELOPMENT_TEAM} 8 | -------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore/ContentView.swift: -------------------------------------------------------------------------------- 1 | // 2 | // For licensing see accompanying LICENSE file. 3 | // Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | // 5 | 6 | import CoreML 7 | import SwiftUI 8 | 9 | struct ContentView: View { 10 | @State private var camera = CameraController() 11 | 12 | var body: some View { 13 | InferenceView( 14 | camera: camera, 15 | backCamera: $camera.backCamera 16 | ) 17 | .ignoresSafeArea(edges: [.bottom]) 18 | .task { 19 | camera.start() 20 | } 21 | } 22 | } 23 | 24 | struct ContentView_Previews: PreviewProvider { 25 | static var previews: some View { 26 | ContentView() 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore/DisplayPrediction.swift: -------------------------------------------------------------------------------- 1 | // 2 | // For licensing see accompanying LICENSE file. 3 | // Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | // 5 | 6 | import Foundation 7 | import SwiftUI 8 | 9 | struct DisplayPrediction: Identifiable { 10 | var id: String { className } 11 | 12 | var className: String 13 | var cosineSimilarity: Float 14 | } 15 | 16 | /// Top N percent will be bold 17 | private let boldPercent: Float = 0.05 18 | 19 | /// It has to exceed this to be bold 20 | private let boldThreshold: Float = 0.22 21 | 22 | /// Colors for interpolating based on color scheme 23 | private func colors(for colorScheme: ColorScheme) -> [Color] { 24 | var result = [Color]() 25 | 26 | let topGray: CGFloat = colorScheme == .dark ? 1.0 : 0.0 27 | let bottomGray: CGFloat = colorScheme == .dark ? 0.2 : 0.8 28 | 29 | for position in stride(from: 0, through: 1, by: 0.05) { 30 | let gray = (bottomGray + CGFloat(position) * (topGray - bottomGray)) 31 | 32 | result.append(.init(cgColor: CGColor(gray: gray, alpha: 1.0))) 33 | } 34 | 35 | return result 36 | } 37 | 38 | /// Boldness threshold 39 | private let maximumThreshold: Float = 0.22 40 | private let minimumTreshold: Float = 0.15 41 | 42 | struct DisplayPredictionFormatter { 43 | 44 | let actualMaximumCosineSimilarity: Float 45 | let colorScheme: ColorScheme 46 | 47 | init(predictions: [DisplayPrediction], colorScheme: ColorScheme) { 48 | self.actualMaximumCosineSimilarity = predictions.map { $0.cosineSimilarity }.max() ?? 0 49 | self.colorScheme = colorScheme 50 | } 51 | 52 | func isBold(_ prediction: DisplayPrediction) -> Bool { 53 | prediction.cosineSimilarity >= boldThreshold 54 | && prediction.cosineSimilarity >= actualMaximumCosineSimilarity * (1 - boldPercent) 55 | } 56 | 57 | func color(_ prediction: DisplayPrediction) -> Color { 58 | let position = 59 | min(max(prediction.cosineSimilarity - minimumTreshold, 0), maximumThreshold) 60 | / (maximumThreshold - minimumTreshold) 61 | let index = min( 62 | Int(round(Float(colors(for: colorScheme).count - 1) * position)), 63 | colors(for: colorScheme).count - 1) 64 | return colors(for: colorScheme)[index] 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore/Extensions.swift: -------------------------------------------------------------------------------- 1 | // 2 | // For licensing see accompanying LICENSE file. 3 | // Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | // 5 | 6 | import CoreImage 7 | 8 | extension CIImage { 9 | 10 | /// cropToSquare image 11 | /// - Returns: Cropped image 12 | func cropToSquare() -> CIImage? 
{ 13 | let size = min(self.extent.width, self.extent.height) 14 | let x = round((self.extent.width - size) / 2) 15 | let y = round((self.extent.height - size) / 2) 16 | 17 | let cropRect = CGRect( 18 | x: x, 19 | y: y, 20 | width: size, 21 | height: size 22 | ) 23 | 24 | let translate = CGAffineTransform(translationX: -x, y: -y) 25 | 26 | return 27 | self 28 | .cropped(to: cropRect) 29 | .transformed(by: translate) 30 | } 31 | 32 | /// Resize image 33 | /// - Parameter size: Size to resize to 34 | /// - Returns: Resized image 35 | func resize(size: CGSize) -> CIImage? { 36 | let scaleX = size.width / self.extent.width 37 | let scaleY = size.height / self.extent.height 38 | return self.transformed(by: CGAffineTransform(scaleX: scaleX, y: scaleY)) 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore/MobileCLIPExploreApp.swift: -------------------------------------------------------------------------------- 1 | // 2 | // For licensing see accompanying LICENSE file. 3 | // Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | // 5 | 6 | import SwiftUI 7 | 8 | @main 9 | struct MobileCLIPExploreApp: App { 10 | var body: some Scene { 11 | WindowGroup { 12 | ContentView() 13 | } 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore/Models.swift: -------------------------------------------------------------------------------- 1 | // 2 | // For licensing see accompanying LICENSE file. 3 | // Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | // 5 | 6 | import CoreML 7 | import Foundation 8 | 9 | protocol CLIPEncoder { 10 | 11 | var targetImageSize: CGSize { get } 12 | 13 | func load() async 14 | 15 | func encode(image: CVPixelBuffer) async throws -> MLMultiArray 16 | 17 | func encode(text: MLMultiArray) async throws -> MLMultiArray 18 | 19 | } 20 | 21 | public struct ModelConfiguration: Identifiable, Hashable { 22 | public let name: String 23 | let factory: () -> CLIPEncoder 24 | public var id: String { name } 25 | 26 | public static func == (lhs: ModelConfiguration, rhs: ModelConfiguration) -> Bool { 27 | lhs.name == rhs.name 28 | } 29 | 30 | public func hash(into hasher: inout Hasher) { 31 | hasher.combine(name) 32 | } 33 | } 34 | 35 | public let models: [ModelConfiguration] = [ 36 | .init(name: "MobileCLIP-S0", factory: { S0Model() }), 37 | .init(name: "MobileCLIP-S1", factory: { S1Model() }), 38 | .init(name: "MobileCLIP-S2", factory: { S2Model() }), 39 | .init(name: "MobileCLIP-BLT", factory: { BLTModel() }), 40 | ] 41 | 42 | public let defaultModel = ModelConfiguration(name: "MobileCLIP-S2", factory: { S2Model() }) 43 | 44 | public struct S0Model: CLIPEncoder { 45 | 46 | let imageEncoder = AsyncFactory { 47 | do { 48 | return try mobileclip_s0_image() 49 | } catch { 50 | fatalError("Failed to initialize ML model: \(error)") 51 | } 52 | } 53 | 54 | let textEncoder = AsyncFactory { 55 | do { 56 | return try mobileclip_s0_text() 57 | } catch { 58 | fatalError("Failed to initialize ML model: \(error)") 59 | } 60 | } 61 | 62 | func load() async { 63 | async let t = textEncoder.get() 64 | async let i = imageEncoder.get() 65 | _ = await (t, i) 66 | } 67 | 68 | let targetImageSize = CGSize(width: 256, height: 256) 69 | 70 | func encode(image: CVPixelBuffer) async throws -> MLMultiArray { 71 | try await imageEncoder.get().prediction(image: image).final_emb_1 72 | } 73 | 74 | func encode(text: MLMultiArray) async throws -> MLMultiArray { 75 | try await 
textEncoder.get().prediction(text: text).final_emb_1 76 | } 77 | } 78 | 79 | public struct S1Model: CLIPEncoder { 80 | 81 | let imageEncoder = AsyncFactory { 82 | do { 83 | return try mobileclip_s1_image() 84 | } catch { 85 | fatalError("Failed to initialize ML model: \(error)") 86 | } 87 | } 88 | 89 | let textEncoder = AsyncFactory { 90 | do { 91 | return try mobileclip_s1_text() 92 | } catch { 93 | fatalError("Failed to initialize ML model: \(error)") 94 | } 95 | } 96 | 97 | func load() async { 98 | async let t = textEncoder.get() 99 | async let i = imageEncoder.get() 100 | _ = await (t, i) 101 | } 102 | 103 | let targetImageSize = CGSize(width: 256, height: 256) 104 | 105 | func encode(image: CVPixelBuffer) async throws -> MLMultiArray { 106 | try await imageEncoder.get().prediction(image: image).final_emb_1 107 | } 108 | 109 | func encode(text: MLMultiArray) async throws -> MLMultiArray { 110 | try await textEncoder.get().prediction(text: text).final_emb_1 111 | } 112 | } 113 | 114 | public struct S2Model: CLIPEncoder { 115 | 116 | let imageEncoder = AsyncFactory { 117 | do { 118 | return try mobileclip_s2_image() 119 | } catch { 120 | fatalError("Failed to initialize ML model: \(error)") 121 | } 122 | } 123 | 124 | let textEncoder = AsyncFactory { 125 | do { 126 | return try mobileclip_s2_text() 127 | } catch { 128 | fatalError("Failed to initialize ML model: \(error)") 129 | } 130 | } 131 | 132 | func load() async { 133 | async let t = textEncoder.get() 134 | async let i = imageEncoder.get() 135 | _ = await (t, i) 136 | } 137 | 138 | let targetImageSize = CGSize(width: 256, height: 256) 139 | 140 | func encode(image: CVPixelBuffer) async throws -> MLMultiArray { 141 | try await imageEncoder.get().prediction(image: image).final_emb_1 142 | } 143 | 144 | func encode(text: MLMultiArray) async throws -> MLMultiArray { 145 | try await textEncoder.get().prediction(text: text).final_emb_1 146 | } 147 | } 148 | 149 | public struct BLTModel: CLIPEncoder { 150 | 151 | let imageEncoder = AsyncFactory { 152 | do { 153 | return try mobileclip_blt_image() 154 | } catch { 155 | fatalError("Failed to initialize ML model: \(error)") 156 | } 157 | } 158 | 159 | let textEncoder = AsyncFactory { 160 | do { 161 | return try mobileclip_blt_text() 162 | } catch { 163 | fatalError("Failed to initialize ML model: \(error)") 164 | } 165 | } 166 | 167 | func load() async { 168 | async let t = textEncoder.get() 169 | async let i = imageEncoder.get() 170 | _ = await (t, i) 171 | } 172 | 173 | let targetImageSize = CGSize(width: 224, height: 224) 174 | 175 | func encode(image: CVPixelBuffer) async throws -> MLMultiArray { 176 | try await imageEncoder.get().prediction(image: image).final_emb_1 177 | } 178 | 179 | func encode(text: MLMultiArray) async throws -> MLMultiArray { 180 | try await textEncoder.get().prediction(text: text).final_emb_1 181 | } 182 | } 183 | -------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore/Models/README.md: -------------------------------------------------------------------------------- 1 | Download the required CoreML models from [this link](https://huggingface.co/apple/coreml-mobileclip) and place them in this directory. 
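As one option, the models can be fetched programmatically with the `huggingface_hub` Python client (a minimal sketch, not the only supported path; it assumes `huggingface_hub` is installed and that the command runs from the root of this checkout, and downloading manually from the linked page works just as well):

```python
# Sketch: download the CoreML model packages from the Hugging Face repo
# linked above into the app's Models directory.
# Assumes: `pip install huggingface_hub`; adjust local_dir to your checkout.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="apple/coreml-mobileclip",  # repository linked in this README
    local_dir="ios_app/MobileCLIPExplore/MobileCLIPExplore/Models",
)
```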
-------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore/Preview Content/Preview Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore/Prompt.swift: -------------------------------------------------------------------------------- 1 | // 2 | // For licensing see accompanying LICENSE file. 3 | // Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | // 5 | 6 | import Foundation 7 | 8 | public let presets = [ 9 | PromptPreset( 10 | title: "Desk items", 11 | prompt: .init( 12 | prefix: "A photo of", 13 | suffix: "", 14 | classNames: [ 15 | "pen", 16 | "pencil", 17 | "paper", 18 | "mouse", 19 | "keyboard", 20 | "computer", 21 | "phone", 22 | "stapler", 23 | "cup", 24 | ]) 25 | ), 26 | PromptPreset( 27 | title: "Expressions", 28 | prompt: .init( 29 | prefix: "A person", 30 | suffix: "", 31 | classNames: [ 32 | "smiling", 33 | "waving", 34 | "giving a thumbs up", 35 | "sticking out their tongue", 36 | "looking angry", 37 | ]) 38 | ), 39 | PromptPreset( 40 | title: "Custom", 41 | prompt: .init( 42 | prefix: "A photo of", 43 | suffix: "", 44 | classNames: []) 45 | ), 46 | ] 47 | 48 | public struct PromptPreset: Identifiable { 49 | public let id = UUID() 50 | public let title: String 51 | public let prompt: Prompt 52 | } 53 | 54 | public struct Prompt { 55 | public var prefix: String 56 | public var suffix: String 57 | public var classNames: [String] 58 | 59 | public func fullPrompts() -> [String] { 60 | classNames.map { 61 | "\(prefix) \($0) \(suffix)" 62 | } 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore/Tokenizer/CLIPTokenizer.swift: -------------------------------------------------------------------------------- 1 | // 2 | // CLIPTokenizer.swift 3 | // CoreMLBert 4 | // 5 | // Created by Matthew Waller on 1/31/23. 6 | // Copyright © 2023 Hugging Face. All rights reserved. 7 | // 8 | // Modified by Hugues Thomas on 5/14/24. 9 | // 10 | // See https://github.com/huggingface/swift-coreml-transformers/pull/30 11 | 12 | import Foundation 13 | 14 | struct BytePair: Hashable { 15 | let a: String 16 | let b: String 17 | init(_ a: String, _ b: String) { 18 | self.a = a 19 | self.b = b 20 | } 21 | init(tuple: [String]) { 22 | self.a = tuple[0] 23 | self.b = tuple[1] 24 | } 25 | 26 | static func == (lhs: BytePair, rhs: BytePair) -> Bool { 27 | return lhs.a == rhs.a && lhs.b == rhs.b 28 | } 29 | func hash(into hasher: inout Hasher) { 30 | hasher.combine(a) 31 | hasher.combine(b) 32 | } 33 | } 34 | 35 | extension String { 36 | fileprivate func ranges(of string: String, options: CompareOptions = .regularExpression) 37 | -> [Range] 38 | { 39 | var result: [Range] = [] 40 | var start = startIndex 41 | while let range = range(of: string, options: options, range: start ..< endIndex) { 42 | result.append(range) 43 | start = 44 | range.lowerBound < range.upperBound 45 | ? range.upperBound 46 | : index(range.lowerBound, offsetBy: 1, limitedBy: endIndex) ?? 
endIndex 47 | } 48 | return result 49 | } 50 | } 51 | 52 | class CLIPTokenizer { 53 | let bpeRanks: [BytePair: Int] 54 | private let encoder: [String: Int] 55 | private let decoder: [Int: String] 56 | let contextLength = 77 57 | 58 | init() { 59 | 60 | let url = Bundle.main.url(forResource: "clip-merges", withExtension: "txt")! 61 | 62 | let bpeMergesTxt = try! String(contentsOf: url) 63 | let arr = bpeMergesTxt.split(separator: "\n").map { String($0) } 64 | var bpeRanks: [BytePair: Int] = [:] 65 | for i in 1 ..< arr.count { 66 | let tuple = arr[i].split(separator: " ").map { String($0) } 67 | let bp = BytePair(tuple: tuple) 68 | bpeRanks[bp] = i - 1 69 | } 70 | self.bpeRanks = bpeRanks 71 | 72 | self.encoder = { 73 | let url = Bundle.main.url(forResource: "clip-vocab", withExtension: "json")! 74 | let json = try! Data(contentsOf: url) 75 | let decoder = JSONDecoder() 76 | let vocab = try! decoder.decode([String: Int].self, from: json) 77 | return vocab 78 | }() 79 | 80 | self.decoder = Utils.invert(self.encoder) 81 | } 82 | 83 | func byteEncode(text: String) -> [String] { 84 | let RE = 85 | "<\\|startoftext\\|>|<\\|endoftext\\|>|'s|'t|'re|'ve|'m|'ll|'d|[\\p{L}]+|[\\p{N}]|[^\\s\\p{L}\\p{N}]+" 86 | 87 | // Original code not working on earlier iOS versions 88 | // let tokens = text.ranges(of: RE).map { String(text[$0]) } 89 | // return tokens.map { (token) -> String in 90 | // return Array(token.utf8).map { byteEncoder[$0]! }.joined() 91 | // } 92 | 93 | // Modification by Hugues Thomas 94 | let regex = try! NSRegularExpression(pattern: RE, options: []) 95 | let matches = regex.matches( 96 | in: text, options: [], range: NSRange(location: 0, length: text.utf16.count)) 97 | let tokens = matches.map { (match) -> String in 98 | let range = Range(match.range, in: text)! 99 | return String(text[range]) 100 | } 101 | return tokens.map { (token) -> String in 102 | return Array(token.utf8).map { byteEncoder[$0]! }.joined() 103 | } 104 | 105 | } 106 | 107 | private func getPairs(word: [String]) -> Set { 108 | var s = Set() 109 | for i in 0 ..< word.count - 1 { 110 | let bp = BytePair( 111 | word[i], 112 | word[i + 1] 113 | ) 114 | s.insert(bp) 115 | } 116 | return s 117 | } 118 | 119 | func bpe(token: String) -> String { 120 | if token.count <= 1 { 121 | return token + "" 122 | } 123 | 124 | var word = Array(token).map { String($0) } 125 | let last = (word.last ?? "") + "" 126 | word.removeLast() 127 | word.append(last) 128 | var pairs = Array(getPairs(word: word)) 129 | if pairs.isEmpty { 130 | return token + "" 131 | } 132 | 133 | while true { 134 | let bigrams = pairs.filter { (bp) -> Bool in bpeRanks[bp] != nil } 135 | if bigrams.count == 0 { 136 | break 137 | } 138 | let bigram = bigrams.min { (bp1, bp2) -> Bool in 139 | return bpeRanks[bp1]! < bpeRanks[bp2]! 140 | }! 
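// The loop below applies the chosen lowest-rank merge across the whole word:
// each adjacent (first, second) pair is fused into a single symbol and all
// other symbols are copied through unchanged; the outer loop then repeats
// until no ranked bigram remains or the word collapses to one symbol.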
141 | let first = bigram.a 142 | let second = bigram.b 143 | var newWord: [String] = [] 144 | var i = 0 145 | while i < word.count { 146 | if let j = word[i ..< word.count].firstIndex(of: first) { 147 | newWord.append(contentsOf: word[i ..< j]) 148 | i = j 149 | } else { 150 | newWord.append(contentsOf: word[i ..< word.count]) 151 | break 152 | } 153 | 154 | if word[i] == first && i < word.count - 1 && word[i + 1] == second { 155 | newWord.append(first + second) 156 | i += 2 157 | } else { 158 | newWord.append(word[i]) 159 | i += 1 160 | } 161 | } 162 | word = newWord 163 | if word.count == 1 { 164 | break 165 | } else { 166 | pairs = Array(getPairs(word: word)) 167 | } 168 | } 169 | return word.joined(separator: " ") 170 | } 171 | 172 | func tokenize(text: String) -> [String] { 173 | var tokens: [String] = [] 174 | let lowercased = text.lowercased() 175 | for token in self.byteEncode(text: lowercased) { 176 | let xx = self.bpe(token: token).split(separator: " ").map { String($0) } 177 | tokens.append(contentsOf: xx) 178 | } 179 | return tokens 180 | } 181 | 182 | /// Main entry point 183 | func encode(text: String) -> [Int] { 184 | return tokenize(text: text).compactMap { encoder[$0] } 185 | } 186 | 187 | /// Decode 188 | func decode(tokens: [Int]) -> String { 189 | let text = tokens.map { decoder[$0]! }.joined(separator: "") 190 | let utfCodepoints = text.map { byteDecoder[String($0)]! } 191 | return String(decoding: utfCodepoints, as: UTF8.self) 192 | } 193 | 194 | func encode_full(text: String) -> [Int] { 195 | let tokens = encode(text: text) 196 | 197 | // Create the full input tokens as a multiarray of shape 1 x contextLength 198 | var fullTokens = Array(repeating: 0, count: contextLength) 199 | fullTokens[0] = encoder["<|startoftext|>"]! 200 | for i in 0 ..< tokens.count { 201 | fullTokens[i + 1] = tokens[i] 202 | } 203 | fullTokens[tokens.count + 1] = encoder["<|endoftext|>"]! 204 | return fullTokens 205 | 206 | } 207 | } 208 | -------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore/Tokenizer/GPT2ByteEncoder.swift: -------------------------------------------------------------------------------- 1 | // 2 | // GPT2ByteEncoder.swift 3 | // CoreMLBert 4 | // 5 | // Created by Julien Chaumond on 18/07/2019. 6 | // Copyright © 2019 Hugging Face. All rights reserved. 
7 | // 8 | // See https://github.com/huggingface/swift-coreml-transformers 9 | 10 | import Foundation 11 | 12 | let byteEncoder: [UTF8.CodeUnit: String] = [ 13 | 33: "!", 14 | 34: "\"", 15 | 35: "#", 16 | 36: "$", 17 | 37: "%", 18 | 38: "&", 19 | 39: "'", 20 | 40: "(", 21 | 41: ")", 22 | 42: "*", 23 | 43: "+", 24 | 44: ",", 25 | 45: "-", 26 | 46: ".", 27 | 47: "/", 28 | 48: "0", 29 | 49: "1", 30 | 50: "2", 31 | 51: "3", 32 | 52: "4", 33 | 53: "5", 34 | 54: "6", 35 | 55: "7", 36 | 56: "8", 37 | 57: "9", 38 | 58: ":", 39 | 59: ";", 40 | 60: "<", 41 | 61: "=", 42 | 62: ">", 43 | 63: "?", 44 | 64: "@", 45 | 65: "A", 46 | 66: "B", 47 | 67: "C", 48 | 68: "D", 49 | 69: "E", 50 | 70: "F", 51 | 71: "G", 52 | 72: "H", 53 | 73: "I", 54 | 74: "J", 55 | 75: "K", 56 | 76: "L", 57 | 77: "M", 58 | 78: "N", 59 | 79: "O", 60 | 80: "P", 61 | 81: "Q", 62 | 82: "R", 63 | 83: "S", 64 | 84: "T", 65 | 85: "U", 66 | 86: "V", 67 | 87: "W", 68 | 88: "X", 69 | 89: "Y", 70 | 90: "Z", 71 | 91: "[", 72 | 92: "\\", 73 | 93: "]", 74 | 94: "^", 75 | 95: "_", 76 | 96: "`", 77 | 97: "a", 78 | 98: "b", 79 | 99: "c", 80 | 100: "d", 81 | 101: "e", 82 | 102: "f", 83 | 103: "g", 84 | 104: "h", 85 | 105: "i", 86 | 106: "j", 87 | 107: "k", 88 | 108: "l", 89 | 109: "m", 90 | 110: "n", 91 | 111: "o", 92 | 112: "p", 93 | 113: "q", 94 | 114: "r", 95 | 115: "s", 96 | 116: "t", 97 | 117: "u", 98 | 118: "v", 99 | 119: "w", 100 | 120: "x", 101 | 121: "y", 102 | 122: "z", 103 | 123: "{", 104 | 124: "|", 105 | 125: "}", 106 | 126: "~", 107 | 161: "\u{00a1}", 108 | 162: "\u{00a2}", 109 | 163: "\u{00a3}", 110 | 164: "\u{00a4}", 111 | 165: "\u{00a5}", 112 | 166: "\u{00a6}", 113 | 167: "\u{00a7}", 114 | 168: "\u{00a8}", 115 | 169: "\u{00a9}", 116 | 170: "\u{00aa}", 117 | 171: "\u{00ab}", 118 | 172: "\u{00ac}", 119 | 174: "\u{00ae}", 120 | 175: "\u{00af}", 121 | 176: "\u{00b0}", 122 | 177: "\u{00b1}", 123 | 178: "\u{00b2}", 124 | 179: "\u{00b3}", 125 | 180: "\u{00b4}", 126 | 181: "\u{00b5}", 127 | 182: "\u{00b6}", 128 | 183: "\u{00b7}", 129 | 184: "\u{00b8}", 130 | 185: "\u{00b9}", 131 | 186: "\u{00ba}", 132 | 187: "\u{00bb}", 133 | 188: "\u{00bc}", 134 | 189: "\u{00bd}", 135 | 190: "\u{00be}", 136 | 191: "\u{00bf}", 137 | 192: "\u{00c0}", 138 | 193: "\u{00c1}", 139 | 194: "\u{00c2}", 140 | 195: "\u{00c3}", 141 | 196: "\u{00c4}", 142 | 197: "\u{00c5}", 143 | 198: "\u{00c6}", 144 | 199: "\u{00c7}", 145 | 200: "\u{00c8}", 146 | 201: "\u{00c9}", 147 | 202: "\u{00ca}", 148 | 203: "\u{00cb}", 149 | 204: "\u{00cc}", 150 | 205: "\u{00cd}", 151 | 206: "\u{00ce}", 152 | 207: "\u{00cf}", 153 | 208: "\u{00d0}", 154 | 209: "\u{00d1}", 155 | 210: "\u{00d2}", 156 | 211: "\u{00d3}", 157 | 212: "\u{00d4}", 158 | 213: "\u{00d5}", 159 | 214: "\u{00d6}", 160 | 215: "\u{00d7}", 161 | 216: "\u{00d8}", 162 | 217: "\u{00d9}", 163 | 218: "\u{00da}", 164 | 219: "\u{00db}", 165 | 220: "\u{00dc}", 166 | 221: "\u{00dd}", 167 | 222: "\u{00de}", 168 | 223: "\u{00df}", 169 | 224: "\u{00e0}", 170 | 225: "\u{00e1}", 171 | 226: "\u{00e2}", 172 | 227: "\u{00e3}", 173 | 228: "\u{00e4}", 174 | 229: "\u{00e5}", 175 | 230: "\u{00e6}", 176 | 231: "\u{00e7}", 177 | 232: "\u{00e8}", 178 | 233: "\u{00e9}", 179 | 234: "\u{00ea}", 180 | 235: "\u{00eb}", 181 | 236: "\u{00ec}", 182 | 237: "\u{00ed}", 183 | 238: "\u{00ee}", 184 | 239: "\u{00ef}", 185 | 240: "\u{00f0}", 186 | 241: "\u{00f1}", 187 | 242: "\u{00f2}", 188 | 243: "\u{00f3}", 189 | 244: "\u{00f4}", 190 | 245: "\u{00f5}", 191 | 246: "\u{00f6}", 192 | 247: "\u{00f7}", 193 | 248: "\u{00f8}", 194 | 249: "\u{00f9}", 195 | 250: 
"\u{00fa}", 196 | 251: "\u{00fb}", 197 | 252: "\u{00fc}", 198 | 253: "\u{00fd}", 199 | 254: "\u{00fe}", 200 | 255: "\u{00ff}", 201 | 0: "\u{0100}", 202 | 1: "\u{0101}", 203 | 2: "\u{0102}", 204 | 3: "\u{0103}", 205 | 4: "\u{0104}", 206 | 5: "\u{0105}", 207 | 6: "\u{0106}", 208 | 7: "\u{0107}", 209 | 8: "\u{0108}", 210 | 9: "\u{0109}", 211 | 10: "\u{010a}", 212 | 11: "\u{010b}", 213 | 12: "\u{010c}", 214 | 13: "\u{010d}", 215 | 14: "\u{010e}", 216 | 15: "\u{010f}", 217 | 16: "\u{0110}", 218 | 17: "\u{0111}", 219 | 18: "\u{0112}", 220 | 19: "\u{0113}", 221 | 20: "\u{0114}", 222 | 21: "\u{0115}", 223 | 22: "\u{0116}", 224 | 23: "\u{0117}", 225 | 24: "\u{0118}", 226 | 25: "\u{0119}", 227 | 26: "\u{011a}", 228 | 27: "\u{011b}", 229 | 28: "\u{011c}", 230 | 29: "\u{011d}", 231 | 30: "\u{011e}", 232 | 31: "\u{011f}", 233 | 32: "\u{0120}", 234 | 127: "\u{0121}", 235 | 128: "\u{0122}", 236 | 129: "\u{0123}", 237 | 130: "\u{0124}", 238 | 131: "\u{0125}", 239 | 132: "\u{0126}", 240 | 133: "\u{0127}", 241 | 134: "\u{0128}", 242 | 135: "\u{0129}", 243 | 136: "\u{012a}", 244 | 137: "\u{012b}", 245 | 138: "\u{012c}", 246 | 139: "\u{012d}", 247 | 140: "\u{012e}", 248 | 141: "\u{012f}", 249 | 142: "\u{0130}", 250 | 143: "\u{0131}", 251 | 144: "\u{0132}", 252 | 145: "\u{0133}", 253 | 146: "\u{0134}", 254 | 147: "\u{0135}", 255 | 148: "\u{0136}", 256 | 149: "\u{0137}", 257 | 150: "\u{0138}", 258 | 151: "\u{0139}", 259 | 152: "\u{013a}", 260 | 153: "\u{013b}", 261 | 154: "\u{013c}", 262 | 155: "\u{013d}", 263 | 156: "\u{013e}", 264 | 157: "\u{013f}", 265 | 158: "\u{0140}", 266 | 159: "\u{0141}", 267 | 160: "\u{0142}", 268 | 173: "\u{0143}", 269 | ] 270 | 271 | let byteDecoder = Utils.invert(byteEncoder) 272 | -------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore/Tokenizer/Utils.swift: -------------------------------------------------------------------------------- 1 | // 2 | // Utils.swift 3 | // AudioBoloss 4 | // 5 | // Created by Julien Chaumond on 07/01/2019. 6 | // Copyright © 2019 Hugging Face. All rights reserved. 7 | // 8 | // See https://github.com/huggingface/swift-coreml-transformers 9 | 10 | import Foundation 11 | 12 | struct Utils { 13 | /// Time a block in ms 14 | static func time(label: String, _ block: () -> T) -> T { 15 | let startTime = CFAbsoluteTimeGetCurrent() 16 | let result = block() 17 | let diff = (CFAbsoluteTimeGetCurrent() - startTime) * 1_000 18 | print("[\(label)] \(diff)ms") 19 | return result 20 | } 21 | 22 | /// Time a block in seconds and return (output, time) 23 | static func time(_ block: () -> T) -> (T, Double) { 24 | let startTime = CFAbsoluteTimeGetCurrent() 25 | let result = block() 26 | let diff = CFAbsoluteTimeGetCurrent() - startTime 27 | return (result, diff) 28 | } 29 | 30 | /// Return unix timestamp in ms 31 | static func dateNow() -> Int64 { 32 | // Use `Int` when we don't support 32-bits devices/OSes anymore. 33 | // Int crashes on iPhone 5c. 34 | return Int64(Date().timeIntervalSince1970 * 1000) 35 | } 36 | 37 | /// Clamp a val to [min, max] 38 | static func clamp(_ val: T, _ vmin: T, _ vmax: T) -> T { 39 | return min(max(vmin, val), vmax) 40 | } 41 | 42 | /// Fake func that can throw. 43 | static func fakeThrowable(_ input: T) throws -> T { 44 | return input 45 | } 46 | 47 | /// Substring 48 | static func substr(_ s: String, _ r: Range) -> String? 
{ 49 | let stringCount = s.count 50 | if stringCount < r.upperBound || stringCount < r.lowerBound { 51 | return nil 52 | } 53 | let startIndex = s.index(s.startIndex, offsetBy: r.lowerBound) 54 | let endIndex = s.index(startIndex, offsetBy: r.upperBound - r.lowerBound) 55 | return String(s[startIndex ..< endIndex]) 56 | } 57 | 58 | /// Invert a (k, v) dictionary 59 | static func invert(_ dict: [K: V]) -> [V: K] { 60 | var inverted: [V: K] = [:] 61 | for (k, v) in dict { 62 | inverted[v] = k 63 | } 64 | return inverted 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore/Views/Helpers/PredictionsTable.swift: -------------------------------------------------------------------------------- 1 | // 2 | // For licensing see accompanying LICENSE file. 3 | // Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | // 5 | 6 | import Foundation 7 | import SwiftUI 8 | 9 | // MARK: - Predictions Results Table 10 | /// Table showing predictions in sorted order 11 | struct PredictionsTable: View { 12 | 13 | let displayPredictions: [DisplayPrediction] 14 | 15 | @Environment(\.horizontalSizeClass) private var horizontalSizeClass 16 | 17 | private var isCompact: Bool { horizontalSizeClass == .compact } 18 | 19 | /// sort keys for classes 20 | @State private var sortCosineSimilarity = true 21 | @State private var sortOrder = [ 22 | KeyPathComparator(\DisplayPrediction.cosineSimilarity, order: .reverse) 23 | ] 24 | 25 | @Environment(\.colorScheme) var colorScheme 26 | 27 | var body: some View { 28 | VStack(spacing: 1) { 29 | 30 | let predictions = displayPredictions.sorted(using: sortOrder) 31 | let formatter = DisplayPredictionFormatter( 32 | predictions: predictions, colorScheme: colorScheme) 33 | 34 | HStack { 35 | Text("CLASSNAME") 36 | Spacer() 37 | Text("COSINE SIMILARITY") 38 | } 39 | .font(.footnote) 40 | .fontWeight(.semibold) 41 | .padding(.horizontal, 23.0) 42 | .padding(.bottom, 10) 43 | .foregroundStyle(.secondary) 44 | 45 | Divider().overlay(.gray) 46 | 47 | List(predictions) { prediction in 48 | 49 | // a prediction in a row 50 | HStack { 51 | Text(prediction.className) 52 | Spacer() 53 | Text(prediction.cosineSimilarity.formatted()) 54 | } 55 | .bold(formatter.isBold(prediction)) 56 | .foregroundStyle(formatter.color(prediction)) 57 | } 58 | .listStyle(PlainListStyle()) 59 | } 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore/Views/Helpers/PromptPreview.swift: -------------------------------------------------------------------------------- 1 | // 2 | // For licensing see accompanying LICENSE file. 3 | // Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | // 5 | 6 | import Foundation 7 | import SwiftUI 8 | 9 | /// Displays the current prompt formatted with an arrow 10 | struct PromptPreview: View { 11 | 12 | let prompt: Prompt 13 | 14 | var body: some View { 15 | HStack { 16 | Image(systemName: "arrow.forward") 17 | Text("\(prompt.prefix) ") + Text("CLASSNAME").underline() + Text(" \(prompt.suffix)") 18 | } 19 | .lineLimit(1) 20 | .padding(.bottom, 16) 21 | .frame(maxWidth: .infinity, alignment: .leading) 22 | .padding(.top, 4) 23 | } 24 | } 25 | 26 | #Preview { 27 | VStack(alignment: .leading) { 28 | PromptPreview(prompt: .init(prefix: "A photo of", suffix: "", classNames: [])) 29 | 30 | PromptPreview( 31 | prompt: .init( 32 | prefix: "A photo of", suffix: "with some longer thing at the end", classNames: [])) 33 | 34 | } 35 | .padding() 36 | } 37 | -------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore/Views/Helpers/VideoFrameView.swift: -------------------------------------------------------------------------------- 1 | // 2 | // For licensing see accompanying LICENSE file. 3 | // Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | // 5 | 6 | import Foundation 7 | import SwiftUI 8 | 9 | /// Displays a stream of video frames 10 | struct VideoFrameView: View { 11 | 12 | let frames: AsyncStream 13 | 14 | @Binding var backCamera: Bool 15 | 16 | @State private var videoFrame: CVImageBuffer? 17 | 18 | var body: some View { 19 | Group { 20 | if let videoFrame { 21 | // display the image, cropped to a square, with rounded corners 22 | _ImageView(image: videoFrame) 23 | .clipShape(RoundedRectangle(cornerRadius: 20)) 24 | .allowsHitTesting(false) 25 | .padding(.horizontal, 2) 26 | 27 | // control to flip to front/back facing camera 28 | .overlay(alignment: .bottomTrailing) { 29 | Button(action: toggleCamera) { 30 | Image(systemName: "arrow.triangle.2.circlepath.circle") 31 | .foregroundStyle(.blue) 32 | .frame(width: 20, height: 20) 33 | .padding(20) 34 | } 35 | .buttonStyle(.plain) 36 | } 37 | 38 | } else { 39 | // spinner before the camera comes up 40 | ProgressView() 41 | .controlSize(.large) 42 | } 43 | } 44 | .task { 45 | // feed frames to the _ImageView 46 | if Task.isCancelled { 47 | return 48 | } 49 | for await frame in frames { 50 | self.videoFrame = frame 51 | } 52 | } 53 | } 54 | 55 | func toggleCamera() { 56 | backCamera.toggle() 57 | } 58 | } 59 | 60 | /// Internal view to display a CVImageBuffer 61 | private struct _ImageView: UIViewRepresentable { 62 | 63 | let image: Any 64 | var gravity = CALayerContentsGravity.center 65 | 66 | func makeUIView(context: Context) -> UIView { 67 | let view = UIView() 68 | view.layer.contentsGravity = gravity 69 | return view 70 | } 71 | 72 | func updateUIView(_ uiView: UIView, context: Context) { 73 | uiView.layer.contents = image 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore/Views/InfoView.swift: -------------------------------------------------------------------------------- 1 | // 2 | // For licensing see accompanying LICENSE file. 3 | // Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | // 5 | 6 | import Foundation 7 | import SwiftUI 8 | 9 | /// About MobileCLIP, model options and credits 10 | struct InfoView: View { 11 | 12 | var modelOptionsData = [ 13 | (name: "MobileCLIP-S0", latency: "2.1", accuracy: "58.1"), 14 | (name: "MobileCLIP-S1", latency: "5.8", accuracy: "61.3"), 15 | (name: "MobileCLIP-S2", latency: "6.9", accuracy: "63.7"), 16 | (name: "MobileCLIP-B (LT)", latency: "13.7", accuracy: "65.8"), 17 | ] 18 | 19 | var body: some View { 20 | VStack(spacing: 10) { 21 | Text("Information") 22 | .font(.headline) 23 | .bold() 24 | .padding(.top) 25 | 26 | // About MobileCLIP 27 | Group { 28 | Text("MobileCLIP") 29 | .font(.largeTitle) 30 | .bold() 31 | .padding(.top) 32 | 33 | Group { 34 | Text("MobileCLIP¹") 35 | .fontWeight(.bold) 36 | + Text( 37 | " is a new family of efficient image-text models optimized for runtime performance, trained with a novel and efficient training approach, namely multi-model reinforced training." 38 | ) 39 | 40 | Text( 41 | "This app demonstrates the use of **MobileCLIP** models for performing real-time zero-shot scene classification. Users are free to customize the prompt and provide classnames of their choice." 42 | ) 43 | } 44 | .padding(.bottom) 45 | } 46 | .frame(maxWidth: .infinity, alignment: .leading) 47 | .padding(.horizontal) 48 | 49 | 50 | // Model Options Table 51 | Group { 52 | Text("Model Options") 53 | .font(.title2) 54 | .bold() 55 | 56 | Text("You can select to run any of the following MobileCLIP model options:") 57 | .padding(.bottom) 58 | } 59 | .frame(maxWidth: .infinity, alignment: .leading) 60 | .padding(.horizontal) 61 | 62 | Group { 63 | HStack { 64 | Text("Name") 65 | Spacer() 66 | .frame(width: 105, alignment: .leading) 67 | Text("Ideal Latency") 68 | Spacer() 69 | Text("Accuracy") 70 | } 71 | .font(.headline) 72 | .fontWeight(.bold) 73 | .padding(.horizontal) 74 | 75 | Divider() 76 | 77 | ForEach(modelOptionsData, id: \.name) { option in 78 | HStack { 79 | Text(option.name) 80 | .fontWeight(.semibold) 81 | .frame(width: 125, alignment: .leading) 82 | Spacer() 83 | Text("\(option.latency) ms") 84 | .fontWeight(.light) 85 | Spacer() 86 | Text("\(option.accuracy)%") 87 | .fontWeight(.light) 88 | } 89 | .padding(.vertical, 5) 90 | .padding(.horizontal) 91 | Divider() 92 | } 93 | } 94 | 95 | // Authors and Citations 96 | VStack(alignment: .leading, spacing: 8) { 97 | HStack { 98 | Spacer().frame(width: 12) 99 | Text( 100 | "**¹ MobileCLIP: Fast Image-Text Models through Multi-Modal Reinforced Training.** (CVPR 2024)" 101 | ) 102 | Spacer().frame(width: 12) 103 | } 104 | Text( 105 | "Pavan Kumar Anasosalu Vasu, Hadi Pour Ansari, Fartash Faghri, Raviteja Vemulapalli, Oncel Tuzel." 
106 | ) 107 | .padding(.horizontal, 12) 108 | 109 | } 110 | .foregroundColor(.secondary) 111 | .font(.system(size: 12)) 112 | 113 | Spacer() 114 | } 115 | .textSelection(.enabled) 116 | .font(.system(size: 16)) 117 | .padding(.bottom) 118 | .frame(maxWidth: .infinity, maxHeight: .infinity) 119 | } 120 | 121 | // MARK: - View Modifiers 122 | struct AboutPanel: ViewModifier { 123 | func body(content: Content) -> some View { 124 | content 125 | .foregroundColor(.secondary) 126 | .padding(.horizontal, 15) 127 | .padding(.vertical, 30) 128 | .textSelection(.enabled) 129 | } 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore/Views/PromptEditor.swift: -------------------------------------------------------------------------------- 1 | // 2 | // For licensing see accompanying LICENSE file. 3 | // Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | // 5 | 6 | import Foundation 7 | import SwiftUI 8 | 9 | // MARK: - Field Views and Layout 10 | /// The primary label of a field 11 | private struct FieldTitle: View { 12 | 13 | let title: String 14 | 15 | var body: some View { 16 | Text(title.uppercased()) 17 | .font(.caption) 18 | .bold() 19 | } 20 | } 21 | 22 | /// A field with primary and secondary labels, e.g. shown when editing the `start of a prompt` 23 | private struct FieldView: View { 24 | 25 | let title: String 26 | let help: String 27 | let contents: V 28 | 29 | internal init(title: String, help: String, @ViewBuilder contents: () -> V) { 30 | self.title = title 31 | self.help = help 32 | self.contents = contents() 33 | } 34 | 35 | var body: some View { 36 | VStack(alignment: .leading) { 37 | FieldTitle(title: title) 38 | .padding(.horizontal, 10) 39 | 40 | contents 41 | .padding() 42 | .background(Color("PromptEditor")) 43 | .clipShape(RoundedRectangle(cornerRadius: 8)) 44 | 45 | Text(help) 46 | .modifier(SecondaryLabel()) 47 | .padding(.horizontal, 10) 48 | } 49 | .padding(.bottom, 15) 50 | } 51 | 52 | } 53 | 54 | // MARK: - Configure Prompt View 55 | /// Displays a preview, instructions and fields needed to configure a prompt 56 | struct PromptEditor: View { 57 | 58 | @Binding var prompt: Prompt 59 | 60 | var body: some View { 61 | VStack { 62 | ScrollView { 63 | Group { 64 | PromptPreview(prompt: prompt) 65 | 66 | Divider() 67 | 68 | Text( 69 | """ 70 | Configure the start and end of the prompt. A preview of the fully composed prompt for the text encoder is shown above. To customize class names, click the Classname Selection button. 71 | """.replacingOccurrences(of: "\n", with: " ") 72 | ) 73 | .modifier(SecondaryLabel()) 74 | .padding(.bottom, 30) 75 | } 76 | 77 | VStack(spacing: 8) { 78 | FieldView( 79 | title: "Start of prompt", 80 | help: "The start of the prompt. Can be left blank." 81 | ) { 82 | TextField("No prefix", text: $prompt.prefix) 83 | } 84 | 85 | FieldView( 86 | title: "Classnames", 87 | help: "Classnames that will be scored." 
88 | ) { 89 | 90 | // Classname Selection Button 91 | NavigationLink { 92 | ClassnamesEditor(classnames: $prompt.classNames) 93 | 94 | } label: { 95 | HStack { 96 | if !prompt.classNames.isEmpty { 97 | 98 | // Display added classnames 99 | Text(prompt.classNames.joined(separator: ", ")) 100 | .lineLimit(1) 101 | 102 | } else { 103 | // Label shown if no classnames have been added 104 | Text("Classname Selection") 105 | 106 | } 107 | Spacer() 108 | Image(systemName: "pencil") 109 | .padding(.vertical) 110 | } 111 | } 112 | } 113 | 114 | FieldView( 115 | title: "End of prompt", 116 | help: "The end of the prompt. Can be left blank." 117 | ) { 118 | TextField("No suffix", text: $prompt.suffix) 119 | } 120 | } 121 | Spacer() 122 | } 123 | } 124 | .navigationTitle("Configure Prompt") 125 | .navigationBarTitleDisplayMode(.large) 126 | .padding(EdgeInsets(top: 10, leading: 15, bottom: 10, trailing: 15)) 127 | .background(Color("PromptEditorBackground")) 128 | } 129 | } 130 | 131 | // MARK: - Manage Classnames View 132 | /// Displays the current Classnames, and instructs the user how to add or remove them 133 | struct ClassnamesEditor: View { 134 | 135 | @Binding var classnames: [String] 136 | 137 | @State private var className = "" 138 | 139 | var body: some View { 140 | VStack(alignment: .leading) { 141 | HStack { 142 | TextField("Add a new classname", text: $className) 143 | .textInputAutocapitalization(.never) 144 | .onSubmit(addClassName) 145 | .padding() 146 | .background(Color("PromptEditor")) 147 | .clipShape(RoundedRectangle(cornerRadius: 8)) 148 | 149 | Button(action: { addClassName() }) { 150 | Image(systemName: "plus.circle") 151 | .foregroundStyle(Color("AddClassName")) 152 | .padding() 153 | } 154 | } 155 | .padding(.bottom, 16) 156 | 157 | Divider() 158 | 159 | Text("Add the classnames that you would like the app to detect.") 160 | .modifier(SecondaryLabel()) 161 | .padding(.bottom, 16) 162 | 163 | // Display classnames as they are added 164 | if !classnames.isEmpty { 165 | Group { 166 | FieldTitle(title: "Current Classnames") 167 | 168 | Text("Swipe left to remove items.") 169 | .modifier(SecondaryLabel()) 170 | } 171 | .padding(.horizontal, 10) 172 | 173 | List { 174 | ForEach(classnames, id: \.self) { className in 175 | 176 | // A classname in a row 177 | HStack { 178 | Text(className) 179 | Spacer() 180 | } 181 | .listRowBackground(Color("PromptEditor")) 182 | } 183 | .onDelete { indexes in 184 | classnames.remove(atOffsets: indexes) 185 | } 186 | } 187 | .ignoresSafeArea(.keyboard) 188 | .environment(\.defaultMinListRowHeight, 25) 189 | .scrollContentBackground(.hidden) 190 | } else { 191 | Spacer() 192 | } 193 | 194 | } 195 | .onAppear { 196 | classnames.sort { $0.lowercased() < $1.lowercased() } 197 | } 198 | .navigationTitle("Manage Classnames") 199 | .navigationBarTitleDisplayMode(.large) 200 | .padding() 201 | .background(Color("PromptEditorBackground")) 202 | } 203 | 204 | private func addClassName() { 205 | let nameToAdd = className.trimmingCharacters(in: .whitespacesAndNewlines) 206 | guard !nameToAdd.isEmpty else { return } 207 | guard !classnames.contains(nameToAdd) else { return } 208 | 209 | withAnimation { 210 | classnames.append(nameToAdd) 211 | classnames.sort { $0.lowercased() < $1.lowercased() } 212 | 213 | className = "" 214 | } 215 | } 216 | } 217 | 218 | // MARK: - View Modifiers 219 | struct SecondaryLabel: ViewModifier { 220 | func body(content: Content) -> some View { 221 | content 222 | .foregroundColor(.secondary) 223 | .font(.caption) 224 | } 225 | } 
226 | 227 | // MARK: - Previews 228 | #Preview { 229 | VStack { 230 | FieldView( 231 | title: "Start of prompt", 232 | help: "The start of the prompt. Can be left blank." 233 | ) { 234 | TextField("", text: .constant("A photo of")) 235 | } 236 | 237 | ClassnamesEditor(classnames: .constant(["cat", "dog", "chicken"])) 238 | } 239 | .padding() 240 | .background(Color(UIColor.systemGroupedBackground)) 241 | } 242 | -------------------------------------------------------------------------------- /ios_app/MobileCLIPExplore/ZSImageClassification.swift: -------------------------------------------------------------------------------- 1 | // 2 | // For licensing see accompanying LICENSE file. 3 | // Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | // 5 | 6 | import CoreML 7 | import UIKit 8 | 9 | /// shared tokenizer for all model types 10 | private let tokenizer = AsyncFactory { 11 | CLIPTokenizer() 12 | } 13 | 14 | actor ZSImageClassification: ObservableObject { 15 | 16 | private let ciContext = CIContext() 17 | private var model: any CLIPEncoder 18 | 19 | public init(model: any CLIPEncoder) { 20 | self.model = model 21 | } 22 | 23 | func load() async { 24 | async let t = tokenizer.get() 25 | async let m = model.load() 26 | _ = await (t, m) 27 | } 28 | 29 | public func setModel(_ model: ModelConfiguration) { 30 | self.model = model.factory() 31 | } 32 | 33 | // Compute Text Embeddings 34 | func computeTextEmbeddings(promptArr: [String]) async -> [MLMultiArray] { 35 | var textEmbeddings: [MLMultiArray] = [] 36 | do { 37 | for singlePrompt in promptArr { 38 | print("") 39 | print("Prompt text: \(singlePrompt)") 40 | 41 | // Tokenize the text query 42 | let inputIds = await tokenizer.get().encode_full(text: singlePrompt) 43 | 44 | // Convert [Int] to MultiArray 45 | let inputArray = try MLMultiArray(shape: [1, 77], dataType: .int32) 46 | for (index, element) in inputIds.enumerated() { 47 | inputArray[index] = NSNumber(value: element) 48 | } 49 | 50 | // Run the text model on the text query 51 | let output = try await model.encode(text: inputArray) 52 | textEmbeddings.append(output) 53 | } 54 | } catch { 55 | print(error.localizedDescription) 56 | } 57 | return textEmbeddings 58 | } 59 | 60 | // Compute Image Embeddings 61 | func computeImageEmbeddings(frame: CVPixelBuffer) async -> ( 62 | embedding: MLMultiArray, interval: CFTimeInterval 63 | )? { 64 | // prepare the image 65 | var image: CIImage? = CIImage(cvPixelBuffer: frame) 66 | image = image?.cropToSquare() 67 | image = image?.resize(size: model.targetImageSize) 68 | 69 | guard let image else { 70 | return nil 71 | } 72 | 73 | // output buffer 74 | let extent = image.extent 75 | let pixelFormat = kCVPixelFormatType_32ARGB 76 | var output: CVPixelBuffer? 
77 | CVPixelBufferCreate(nil, Int(extent.width), Int(extent.height), pixelFormat, nil, &output) 78 | 79 | guard let output else { 80 | print("failed to create output CVPixelBuffer") 81 | return nil 82 | } 83 | 84 | ciContext.render(image, to: output) 85 | 86 | // Run image embedding 87 | do { 88 | let startTimer = CACurrentMediaTime() 89 | let output = try await model.encode(image: output) 90 | let endTimer = CACurrentMediaTime() 91 | let interval = endTimer - startTimer 92 | return (embedding: output, interval: interval) 93 | } catch { 94 | print(error.localizedDescription) 95 | return nil 96 | } 97 | } 98 | 99 | // Compute cosine similarity between embeddings 100 | nonisolated func cosineSimilarity(_ embedding1: MLMultiArray, _ embedding2: MLMultiArray) 101 | -> Float 102 | { 103 | 104 | // read the values out of the MLMultiArray in bulk 105 | let e1 = embedding1.withUnsafeBufferPointer(ofType: Float.self) { ptr in 106 | Array(ptr) 107 | } 108 | let e2 = embedding2.withUnsafeBufferPointer(ofType: Float.self) { ptr in 109 | Array(ptr) 110 | } 111 | 112 | // Get the dot product of the two embeddings 113 | let dotProduct: Float = zip(e1, e2).reduce(0.0) { $0 + $1.0 * $1.1 } 114 | 115 | // Get the magnitudes of the two embeddings 116 | let magnitude1: Float = sqrt(e1.reduce(0) { $0 + pow($1, 2) }) 117 | let magnitude2: Float = sqrt(e2.reduce(0) { $0 + pow($1, 2) }) 118 | 119 | // Get the cosine similarity 120 | let similarity = dotProduct / (magnitude1 * magnitude2) 121 | return similarity 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /ios_app/README.md: -------------------------------------------------------------------------------- 1 | # MobileCLIPExplore 2 | Demonstrates the use of **MobileCLIP** models for performing real-time zero-shot scene classification. Users can customize the prompts and provide classnames of their choice. 3 | 4 | ![Examples](docs/app_screenshots/examples.png) 5 | 6 | ## Background 7 | MobileCLIP is a new family of efficient image-text models optimized for runtime performance, trained with a novel and efficient training approach, namely multi-model reinforced training. 8 | 9 | ## Dependencies 10 | Requires **iOS 17.2** or later. 11 | 12 | ## Install 13 | 1. Download the required CoreML models from [this link](https://huggingface.co/apple/coreml-mobileclip) 14 | 15 | 2. Place these models in the following directory: 16 | 17 | `.../MobileCLIPExplore/MobileCLIPExplore/Models` 18 | 19 | ![Model Directory](docs/app_screenshots/model_directory_guide.png) 20 | 21 | 22 | ## Use 23 | 24 | ### Getting Started 25 | 1. Select from one of the four MobileCLIP models from the top menu. 26 | 2. Choose a prompt preset, or create a custom one. 27 | 3. Point the camera at a subject to display prediction results in the table below. 28 | 29 | The FPS (frames per second) being analyzed and response time (milliseconds) of the prediction results are displayed in the top left of the landing view. 30 | 31 | ### Creating and Editing a Prompt 32 | 1. To edit the prompt, select the prompt or pencil icon ✎. 33 | 2. Tap the start or end of a prompt to customize the prefix or suffix, respectively. 34 | 3. Use the Classnames field to add or remove items you would like the app to detect. 35 | 36 | ## Model Options 37 | The application allows the user to select from four different MobileCLIP models: 38 | 39 | | Name | Ideal Latency | IN-1k Zero-Shot
Top-1 Acc. (%) | Avg. Perf. (%)
on 38 datasets | 40 | |-----------------|---------------|-------|------------------------------------| 41 | | MobileCLIP-S0 | 2.1 ms | 67.8% | 58.1% | 42 | | MobileCLIP-S1 | 5.8 ms | 72.6% | 61.3% | 43 | | MobileCLIP-S2 | 6.9 ms | 74.4% | 63.7% | 44 | | MobileCLIP-B (LT)| 13.7 ms | 77.2% | 65.8% | 45 | 46 | ## Privacy 47 | All predictions are processed privately and securely using on-device models. 48 | -------------------------------------------------------------------------------- /ios_app/docs/app_screenshots/examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-mobileclip/1140b8d197e4ed7d56b3a92216ded98bb1c2ac87/ios_app/docs/app_screenshots/examples.png -------------------------------------------------------------------------------- /ios_app/docs/app_screenshots/model_directory_guide.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-mobileclip/1140b8d197e4ed7d56b3a92216ded98bb1c2ac87/ios_app/docs/app_screenshots/model_directory_guide.png -------------------------------------------------------------------------------- /mobileclip/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | import json 7 | from typing import Optional, Union, Tuple, Any 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torchvision.transforms import ( 12 | CenterCrop, 13 | Compose, 14 | InterpolationMode, 15 | Resize, 16 | ToTensor, 17 | ) 18 | 19 | from mobileclip.clip import CLIP 20 | from mobileclip.modules.text.tokenizer import ( 21 | ClipTokenizer, 22 | ) 23 | from mobileclip.modules.common.mobileone import reparameterize_model 24 | 25 | 26 | def create_model_and_transforms( 27 | model_name: str, 28 | pretrained: Optional[str] = None, 29 | reparameterize: Optional[bool] = True, 30 | device: Union[str, torch.device] = "cpu", 31 | ) -> Tuple[nn.Module, Any, Any]: 32 | """ 33 | Method to instantiate model and pre-processing transforms necessary for inference. 34 | 35 | Args: 36 | model_name: Model name. Choose from ['mobileclip_s0', 'mobileclip_s1', 'mobileclip_s2', 'mobileclip_b'] 37 | pretrained: Location of pretrained checkpoint. 38 | reparameterize: When set to True, re-parameterizable branches get folded for faster inference. 39 | device: Device identifier for model placement. 40 | 41 | Returns: 42 | Tuple of instantiated model, and preprocessing transforms for inference. 
43 | """ 44 | # Config files 45 | root_dir = os.path.dirname(os.path.abspath(__file__)) 46 | configs_dir = os.path.join(root_dir, "configs") 47 | model_cfg_file = os.path.join(configs_dir, model_name + ".json") 48 | 49 | # Get config from yaml file 50 | if not os.path.exists(model_cfg_file): 51 | raise ValueError(f"Unsupported model name: {model_name}") 52 | model_cfg = json.load(open(model_cfg_file, "r")) 53 | 54 | # Build preprocessing transforms for inference 55 | resolution = model_cfg["image_cfg"]["image_size"] 56 | resize_size = resolution 57 | centercrop_size = resolution 58 | aug_list = [ 59 | Resize( 60 | resize_size, 61 | interpolation=InterpolationMode.BILINEAR, 62 | ), 63 | CenterCrop(centercrop_size), 64 | ToTensor(), 65 | ] 66 | preprocess = Compose(aug_list) 67 | 68 | # Build model 69 | model = CLIP(cfg=model_cfg) 70 | model.to(device) 71 | model.eval() 72 | 73 | # Load checkpoint 74 | if pretrained is not None: 75 | chkpt = torch.load(pretrained) 76 | model.load_state_dict(chkpt) 77 | 78 | # Reparameterize model for inference (if specified) 79 | if reparameterize: 80 | model = reparameterize_model(model) 81 | 82 | return model, None, preprocess 83 | 84 | 85 | def get_tokenizer(model_name: str) -> nn.Module: 86 | # Config files 87 | root_dir = os.path.dirname(os.path.abspath(__file__)) 88 | configs_dir = os.path.join(root_dir, "configs") 89 | model_cfg_file = os.path.join(configs_dir, model_name + ".json") 90 | 91 | # Get config from yaml file 92 | model_cfg = json.load(open(model_cfg_file, "r")) 93 | 94 | # Build tokenizer 95 | text_tokenizer = ClipTokenizer(model_cfg) 96 | return text_tokenizer 97 | -------------------------------------------------------------------------------- /mobileclip/clip.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | """ Model schema in open_clip format for inference only. 
""" 6 | import math 7 | from typing import Any, Optional, Dict 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import nn 12 | 13 | from mobileclip.text_encoder import ( 14 | TextTransformer, 15 | ) 16 | 17 | from .image_encoder import MCi 18 | 19 | 20 | class CLIP(nn.Module): 21 | """Base class for multi-modal image-text data""" 22 | 23 | def __init__(self, cfg: Dict, output_dict: bool = False, *args, **kwargs) -> None: 24 | super().__init__() 25 | self.output_dict = output_dict 26 | self.projection_dim = cfg["embed_dim"] 27 | if self.projection_dim is None: 28 | raise ValueError("Please specify `embed_dim` in model config.") 29 | 30 | self.image_encoder = MCi( 31 | model_name=cfg["image_cfg"]["model_name"], 32 | projection_dim=self.projection_dim, 33 | ) 34 | self.text_encoder = TextTransformer( 35 | cfg=cfg["text_cfg"], projection_dim=self.projection_dim 36 | ) 37 | self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1.0 / 0.07)) 38 | 39 | def _exponentiate_and_clip_logits(self, max_scale: float = 100.0): 40 | scale = self.logit_scale.exp() 41 | scale = torch.clamp(scale, 0, max_scale) 42 | return scale 43 | 44 | def encode_image(self, image: torch.Tensor, normalize: bool = False): 45 | image_encoder_out = self.image_encoder(image) 46 | if isinstance(image_encoder_out, dict): 47 | features = image_encoder_out["logits"] 48 | else: 49 | features = image_encoder_out 50 | return F.normalize(features, dim=-1) if normalize else features 51 | 52 | def encode_text(self, text: torch.Tensor, normalize: bool = False): 53 | text_features = self.text_encoder(text_tokens=text, key_padding_mask=None) 54 | return F.normalize(text_features, dim=-1) if normalize else text_features 55 | 56 | def forward( 57 | self, 58 | image: Optional[torch.Tensor] = None, 59 | text: Optional[torch.Tensor] = None, 60 | *args, 61 | **kwargs 62 | ) -> Any: 63 | 64 | image_embeddings = ( 65 | self.encode_image(image, normalize=True) if image is not None else None 66 | ) 67 | text_embeddings = ( 68 | self.encode_text(text, normalize=True) if text is not None else None 69 | ) 70 | 71 | if self.output_dict: 72 | return { 73 | "image_features": image_embeddings, 74 | "text_features": text_embeddings, 75 | "logit_scale": self._exponentiate_and_clip_logits(), 76 | } 77 | return image_embeddings, text_embeddings, self._exponentiate_and_clip_logits() 78 | -------------------------------------------------------------------------------- /mobileclip/configs/mobileclip_b.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "image_cfg": { 4 | "image_size": 224, 5 | "model_name": "vit_b16" 6 | }, 7 | "text_cfg": { 8 | "context_length": 77, 9 | "vocab_size": 49408, 10 | "dim": 512, 11 | "ffn_multiplier_per_layer": 4.0, 12 | "n_heads_per_layer": 8, 13 | "n_transformer_layers": 12, 14 | "norm_layer": "layer_norm_fp32", 15 | "causal_masking": true, 16 | "model_name": "base" 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /mobileclip/configs/mobileclip_s0.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "image_cfg": { 4 | "image_size": 256, 5 | "model_name": "mci0" 6 | }, 7 | "text_cfg": { 8 | "context_length": 77, 9 | "vocab_size": 49408, 10 | "dim": 512, 11 | "ffn_multiplier_per_layer": 4.0, 12 | "n_heads_per_layer": 8, 13 | "n_transformer_layers": 4, 14 | "norm_layer": "layer_norm_fp32", 15 | "causal_masking": false, 16 | 
"model_name": "mct" 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /mobileclip/configs/mobileclip_s1.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "image_cfg": { 4 | "image_size": 256, 5 | "model_name": "mci1" 6 | }, 7 | "text_cfg": { 8 | "context_length": 77, 9 | "vocab_size": 49408, 10 | "dim": 512, 11 | "ffn_multiplier_per_layer": 4.0, 12 | "n_heads_per_layer": 8, 13 | "n_transformer_layers": 12, 14 | "norm_layer": "layer_norm_fp32", 15 | "causal_masking": false, 16 | "model_name": "base" 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /mobileclip/configs/mobileclip_s2.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "image_cfg": { 4 | "image_size": 256, 5 | "model_name": "mci2" 6 | }, 7 | "text_cfg": { 8 | "context_length": 77, 9 | "vocab_size": 49408, 10 | "dim": 512, 11 | "ffn_multiplier_per_layer": 4.0, 12 | "n_heads_per_layer": 8, 13 | "n_transformer_layers": 12, 14 | "norm_layer": "layer_norm_fp32", 15 | "causal_masking": false, 16 | "model_name": "base" 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /mobileclip/image_encoder.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | from typing import Any 6 | 7 | import torch.nn as nn 8 | from timm.models import create_model 9 | 10 | from mobileclip import models # Added to register models 11 | from mobileclip.modules.image.image_projection import GlobalPool2D 12 | 13 | 14 | class MCi(nn.Module): 15 | """ 16 | This class implements `MCi Models `_ 17 | """ 18 | 19 | def __init__(self, model_name: str, *args, **kwargs) -> None: 20 | super().__init__() 21 | self.projection_dim = None 22 | if "projection_dim" in kwargs: 23 | self.projection_dim = kwargs.get("projection_dim") 24 | 25 | # Create model 26 | self.model = create_model(model_name, projection_dim=self.projection_dim) 27 | 28 | # Build out projection head. 29 | if self.projection_dim is not None: 30 | if hasattr(self.model, "head"): 31 | self.model.head = MCi._update_image_classifier( 32 | image_classifier=self.model.head, projection_dim=self.projection_dim 33 | ) 34 | 35 | def forward(self, x: Any, *args, **kwargs) -> Any: 36 | """A forward function of the model.""" 37 | x = self.model(x) 38 | return x 39 | 40 | @staticmethod 41 | def _get_in_feature_dimension(image_classifier: nn.Module) -> int: 42 | """Return the input feature dimension to the image classification head.""" 43 | in_features = None 44 | if isinstance(image_classifier, nn.Sequential): 45 | # Classifier that uses nn.Sequential usually has global pooling and 46 | # multiple linear layers. Find the first linear layer and get its 47 | # in_features 48 | for layer in image_classifier: 49 | if isinstance(layer, nn.Linear): 50 | in_features = layer.in_features 51 | break 52 | elif isinstance(image_classifier, nn.Linear): 53 | in_features = image_classifier.in_features 54 | 55 | if in_features is None: 56 | raise NotImplementedError( 57 | f"Cannot get input feature dimension of {image_classifier}." 
58 | ) 59 | return in_features 60 | 61 | @staticmethod 62 | def _update_image_classifier( 63 | image_classifier: nn.Module, projection_dim: int, *args, **kwargs 64 | ) -> nn.Module: 65 | in_features = MCi._get_in_feature_dimension(image_classifier) 66 | new_img_classifier = GlobalPool2D(in_dim=in_features, out_dim=projection_dim) 67 | return new_img_classifier 68 | -------------------------------------------------------------------------------- /mobileclip/logger.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import os 7 | import sys 8 | import time 9 | import traceback 10 | from typing import Optional, Union 11 | 12 | text_colors = { 13 | "logs": "\033[34m", # 033 is the escape code and 34 is the color code 14 | "info": "\033[32m", 15 | "warning": "\033[33m", 16 | "debug": "\033[93m", 17 | "error": "\033[31m", 18 | "bold": "\033[1m", 19 | "end_color": "\033[0m", 20 | "light_red": "\033[36m", 21 | } 22 | 23 | 24 | def get_curr_time_stamp() -> str: 25 | return time.strftime("%Y-%m-%d %H:%M:%S") 26 | 27 | 28 | def error(message: str) -> None: 29 | time_stamp = get_curr_time_stamp() 30 | error_str = ( 31 | text_colors["error"] 32 | + text_colors["bold"] 33 | + "ERROR " 34 | + text_colors["end_color"] 35 | ) 36 | 37 | # exiting with code -1 does not tell any information about the error (e.g., NaN encountered in the loss). 38 | # For more descriptive error messages, we replace exit(-1) with sys.exit(ERROR_MESSAGE). 39 | # This allows us to handle specific exceptions in the tests. 40 | 41 | # print("{} - {} - {}".format(time_stamp, error_str, message), flush=True) 42 | # print("{} - {} - {}".format(time_stamp, error_str, "Exiting!!!"), flush=True) 43 | # exit(-1) 44 | 45 | if sys.exc_info()[0] is None: 46 | traceback.print_stack() 47 | else: 48 | traceback.print_exc() 49 | sys.exit("{} - {} - {}. Exiting!!!".format(time_stamp, error_str, message)) 50 | 51 | 52 | def color_text(in_text: str) -> str: 53 | return text_colors["light_red"] + in_text + text_colors["end_color"] 54 | 55 | 56 | def log(message: str, end="\n") -> None: 57 | time_stamp = get_curr_time_stamp() 58 | log_str = ( 59 | text_colors["logs"] + text_colors["bold"] + "LOGS " + text_colors["end_color"] 60 | ) 61 | print("{} - {} - {}".format(time_stamp, log_str, message), end=end) 62 | 63 | 64 | def warning(message: Union[str, Warning]) -> None: 65 | if isinstance(message, Warning): 66 | message = f"{type(message).__name__}({','.join(map(repr, message.args))}" 67 | 68 | time_stamp = get_curr_time_stamp() 69 | warn_str = ( 70 | text_colors["warning"] 71 | + text_colors["bold"] 72 | + "WARNING" 73 | + text_colors["end_color"] 74 | ) 75 | print("{} - {} - {}".format(time_stamp, warn_str, message)) 76 | 77 | 78 | def ignore_exception_with_warning(message: str) -> None: 79 | """ 80 | After catching a tolerable exception E1 (e.g. when Model.forward() fails during 81 | profiling with try-catch, it'll be helpful to log the exception for future 82 | investigation. But printing the error stack trace, as is, could be confusing 83 | when an uncaught (non-tolerable) exception "E2" raises down the road. Then, the log 84 | will contain two stack traces for E1, E2. When looking for errors in logs, users 85 | should look for E2, but they may find E1. 
86 | 87 | This function appends "(WARNING)" at the end of all lines of the E1 traceback, so 88 | that the user can distinguish E1 from uncaught exception E2. 89 | 90 | Args: 91 | message: Extra explanation and context for debugging. (Note: the exception obj 92 | will be automatically fetched from python. No need to pass it as an argument or as 93 | message) 94 | """ 95 | warning(f"{message}:\n{traceback.format_exc()}".replace("\n", "\n(WARNING)")) 96 | 97 | 98 | def info(message: str, print_line: Optional[bool] = False) -> None: 99 | time_stamp = get_curr_time_stamp() 100 | info_str = ( 101 | text_colors["info"] + text_colors["bold"] + "INFO " + text_colors["end_color"] 102 | ) 103 | print("{} - {} - {}".format(time_stamp, info_str, message)) 104 | if print_line: 105 | double_dash_line(dashes=150) 106 | 107 | 108 | def debug(message: str) -> None: 109 | time_stamp = get_curr_time_stamp() 110 | log_str = ( 111 | text_colors["debug"] 112 | + text_colors["bold"] 113 | + "DEBUG " 114 | + text_colors["end_color"] 115 | ) 116 | print("{} - {} - {}".format(time_stamp, log_str, message)) 117 | 118 | 119 | def double_dash_line(dashes: Optional[int] = 75) -> None: 120 | print(text_colors["error"] + "=" * dashes + text_colors["end_color"]) 121 | 122 | 123 | def singe_dash_line(dashes: Optional[int] = 67) -> None: 124 | print("-" * dashes) 125 | 126 | 127 | def print_header(header: str) -> None: 128 | double_dash_line() 129 | print( 130 | text_colors["info"] 131 | + text_colors["bold"] 132 | + "=" * 50 133 | + str(header) 134 | + text_colors["end_color"] 135 | ) 136 | double_dash_line() 137 | 138 | 139 | def print_header_minor(header: str) -> None: 140 | print( 141 | text_colors["warning"] 142 | + text_colors["bold"] 143 | + "=" * 25 144 | + str(header) 145 | + text_colors["end_color"] 146 | ) 147 | 148 | 149 | def disable_printing(): 150 | sys.stdout = open(os.devnull, "w") 151 | 152 | 153 | def enable_printing(): 154 | sys.stdout = sys.__stdout__ 155 | -------------------------------------------------------------------------------- /mobileclip/models/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All rights reserved. 4 | # 5 | from .mci import ( 6 | mci0, 7 | mci1, 8 | mci2, 9 | ) 10 | from .vit import vit_b16 11 | -------------------------------------------------------------------------------- /mobileclip/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All rights reserved. 4 | # 5 | -------------------------------------------------------------------------------- /mobileclip/modules/common/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All rights reserved. 4 | # 5 | -------------------------------------------------------------------------------- /mobileclip/modules/common/mobileone.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | from typing import Union, Tuple 6 | 7 | import copy 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | __all__ = ["MobileOneBlock", "reparameterize_model"] 13 | 14 | 15 | class SEBlock(nn.Module): 16 | """Squeeze and Excite module. 17 | 18 | Pytorch implementation of `Squeeze-and-Excitation Networks` - 19 | https://arxiv.org/pdf/1709.01507.pdf 20 | """ 21 | 22 | def __init__(self, in_channels: int, rd_ratio: float = 0.0625) -> None: 23 | """Construct a Squeeze and Excite Module. 24 | 25 | Args: 26 | in_channels: Number of input channels. 27 | rd_ratio: Input channel reduction ratio. 28 | """ 29 | super(SEBlock, self).__init__() 30 | self.reduce = nn.Conv2d( 31 | in_channels=in_channels, 32 | out_channels=int(in_channels * rd_ratio), 33 | kernel_size=1, 34 | stride=1, 35 | bias=True, 36 | ) 37 | self.expand = nn.Conv2d( 38 | in_channels=int(in_channels * rd_ratio), 39 | out_channels=in_channels, 40 | kernel_size=1, 41 | stride=1, 42 | bias=True, 43 | ) 44 | 45 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 46 | """Apply forward pass.""" 47 | b, c, h, w = inputs.size() 48 | x = F.avg_pool2d(inputs, kernel_size=[h, w]) 49 | x = self.reduce(x) 50 | x = F.relu(x) 51 | x = self.expand(x) 52 | x = torch.sigmoid(x) 53 | x = x.view(-1, c, 1, 1) 54 | return inputs * x 55 | 56 | 57 | class MobileOneBlock(nn.Module): 58 | """MobileOne building block. 59 | 60 | This block has a multi-branched architecture at train-time 61 | and plain-CNN style architecture at inference time 62 | For more details, please refer to our paper: 63 | `An Improved One millisecond Mobile Backbone` - 64 | https://arxiv.org/pdf/2206.04040.pdf 65 | """ 66 | 67 | def __init__( 68 | self, 69 | in_channels: int, 70 | out_channels: int, 71 | kernel_size: int, 72 | stride: int = 1, 73 | padding: int = 0, 74 | dilation: int = 1, 75 | groups: int = 1, 76 | inference_mode: bool = False, 77 | use_se: bool = False, 78 | use_act: bool = True, 79 | use_scale_branch: bool = True, 80 | num_conv_branches: int = 1, 81 | activation: nn.Module = nn.GELU(), 82 | ) -> None: 83 | """Construct a MobileOneBlock module. 84 | 85 | Args: 86 | in_channels: Number of channels in the input. 87 | out_channels: Number of channels produced by the block. 88 | kernel_size: Size of the convolution kernel. 89 | stride: Stride size. 90 | padding: Zero-padding size. 91 | dilation: Kernel dilation factor. 92 | groups: Group number. 93 | inference_mode: If True, instantiates model in inference mode. 94 | use_se: Whether to use SE-ReLU activations. 95 | use_act: Whether to use activation. Default: ``True`` 96 | use_scale_branch: Whether to use scale branch. Default: ``True`` 97 | num_conv_branches: Number of linear conv branches. 
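        Example:
            Illustrative sketch only; the shapes and hyper-parameters below are
            arbitrary and not taken from any config in this repo.

            >>> import torch
            >>> block = MobileOneBlock(in_channels=64, out_channels=64,
            ...                        kernel_size=3, padding=1, num_conv_branches=2)
            >>> y = block(torch.randn(1, 64, 32, 32))      # multi-branched train-time path
            >>> block.reparameterize()                     # fuse branches into a single conv
            >>> y_inf = block(torch.randn(1, 64, 32, 32))  # plain single-conv inference path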
98 | """ 99 | super(MobileOneBlock, self).__init__() 100 | self.inference_mode = inference_mode 101 | self.groups = groups 102 | self.stride = stride 103 | self.padding = padding 104 | self.dilation = dilation 105 | self.kernel_size = kernel_size 106 | self.in_channels = in_channels 107 | self.out_channels = out_channels 108 | self.num_conv_branches = num_conv_branches 109 | 110 | # Check if SE-ReLU is requested 111 | if use_se: 112 | self.se = SEBlock(out_channels) 113 | else: 114 | self.se = nn.Identity() 115 | 116 | if use_act: 117 | self.activation = activation 118 | else: 119 | self.activation = nn.Identity() 120 | 121 | if inference_mode: 122 | self.reparam_conv = nn.Conv2d( 123 | in_channels=in_channels, 124 | out_channels=out_channels, 125 | kernel_size=kernel_size, 126 | stride=stride, 127 | padding=padding, 128 | dilation=dilation, 129 | groups=groups, 130 | bias=True, 131 | ) 132 | else: 133 | # Re-parameterizable skip connection 134 | self.rbr_skip = ( 135 | nn.BatchNorm2d(num_features=in_channels) 136 | if out_channels == in_channels and stride == 1 137 | else None 138 | ) 139 | 140 | # Re-parameterizable conv branches 141 | if num_conv_branches > 0: 142 | rbr_conv = list() 143 | for _ in range(self.num_conv_branches): 144 | rbr_conv.append( 145 | self._conv_bn(kernel_size=kernel_size, padding=padding) 146 | ) 147 | self.rbr_conv = nn.ModuleList(rbr_conv) 148 | else: 149 | self.rbr_conv = None 150 | 151 | # Re-parameterizable scale branch 152 | self.rbr_scale = None 153 | if not isinstance(kernel_size, int): 154 | kernel_size = kernel_size[0] 155 | if (kernel_size > 1) and use_scale_branch: 156 | self.rbr_scale = self._conv_bn(kernel_size=1, padding=0) 157 | 158 | def forward(self, x: torch.Tensor) -> torch.Tensor: 159 | """Apply forward pass.""" 160 | # Inference mode forward pass. 161 | if self.inference_mode: 162 | return self.activation(self.se(self.reparam_conv(x))) 163 | 164 | # Multi-branched train-time forward pass. 165 | # Skip branch output 166 | identity_out = 0 167 | if self.rbr_skip is not None: 168 | identity_out = self.rbr_skip(x) 169 | 170 | # Scale branch output 171 | scale_out = 0 172 | if self.rbr_scale is not None: 173 | scale_out = self.rbr_scale(x) 174 | 175 | # Other branches 176 | out = scale_out + identity_out 177 | if self.rbr_conv is not None: 178 | for ix in range(self.num_conv_branches): 179 | out += self.rbr_conv[ix](x) 180 | 181 | return self.activation(self.se(out)) 182 | 183 | def reparameterize(self): 184 | """Following works like `RepVGG: Making VGG-style ConvNets Great Again` - 185 | https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched 186 | architecture used at training time to obtain a plain CNN-like structure 187 | for inference. 
188 | """ 189 | if self.inference_mode: 190 | return 191 | kernel, bias = self._get_kernel_bias() 192 | self.reparam_conv = nn.Conv2d( 193 | in_channels=self.in_channels, 194 | out_channels=self.out_channels, 195 | kernel_size=self.kernel_size, 196 | stride=self.stride, 197 | padding=self.padding, 198 | dilation=self.dilation, 199 | groups=self.groups, 200 | bias=True, 201 | ) 202 | self.reparam_conv.weight.data = kernel 203 | self.reparam_conv.bias.data = bias 204 | 205 | # Delete un-used branches 206 | for para in self.parameters(): 207 | para.detach_() 208 | self.__delattr__("rbr_conv") 209 | self.__delattr__("rbr_scale") 210 | if hasattr(self, "rbr_skip"): 211 | self.__delattr__("rbr_skip") 212 | 213 | self.inference_mode = True 214 | 215 | def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]: 216 | """Method to obtain re-parameterized kernel and bias. 217 | Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83 218 | 219 | Returns: 220 | Tuple of (kernel, bias) after fusing branches. 221 | """ 222 | # get weights and bias of scale branch 223 | kernel_scale = 0 224 | bias_scale = 0 225 | if self.rbr_scale is not None: 226 | kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale) 227 | # Pad scale branch kernel to match conv branch kernel size. 228 | pad = self.kernel_size // 2 229 | kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad]) 230 | 231 | # get weights and bias of skip branch 232 | kernel_identity = 0 233 | bias_identity = 0 234 | if self.rbr_skip is not None: 235 | kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip) 236 | 237 | # get weights and bias of conv branches 238 | kernel_conv = 0 239 | bias_conv = 0 240 | if self.rbr_conv is not None: 241 | for ix in range(self.num_conv_branches): 242 | _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix]) 243 | kernel_conv += _kernel 244 | bias_conv += _bias 245 | 246 | kernel_final = kernel_conv + kernel_scale + kernel_identity 247 | bias_final = bias_conv + bias_scale + bias_identity 248 | return kernel_final, bias_final 249 | 250 | def _fuse_bn_tensor( 251 | self, branch: Union[nn.Sequential, nn.BatchNorm2d] 252 | ) -> Tuple[torch.Tensor, torch.Tensor]: 253 | """Method to fuse batchnorm layer with preceeding conv layer. 254 | Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95 255 | 256 | Args: 257 | branch: Sequence of ops to be fused. 258 | 259 | Returns: 260 | Tuple of (kernel, bias) after fusing batchnorm. 
261 | """ 262 | if isinstance(branch, nn.Sequential): 263 | kernel = branch.conv.weight 264 | running_mean = branch.bn.running_mean 265 | running_var = branch.bn.running_var 266 | gamma = branch.bn.weight 267 | beta = branch.bn.bias 268 | eps = branch.bn.eps 269 | else: 270 | assert isinstance(branch, nn.BatchNorm2d) 271 | if not hasattr(self, "id_tensor"): 272 | input_dim = self.in_channels // self.groups 273 | 274 | kernel_size = self.kernel_size 275 | if isinstance(self.kernel_size, int): 276 | kernel_size = (self.kernel_size, self.kernel_size) 277 | 278 | kernel_value = torch.zeros( 279 | (self.in_channels, input_dim, kernel_size[0], kernel_size[1]), 280 | dtype=branch.weight.dtype, 281 | device=branch.weight.device, 282 | ) 283 | for i in range(self.in_channels): 284 | kernel_value[ 285 | i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2 286 | ] = 1 287 | self.id_tensor = kernel_value 288 | kernel = self.id_tensor 289 | running_mean = branch.running_mean 290 | running_var = branch.running_var 291 | gamma = branch.weight 292 | beta = branch.bias 293 | eps = branch.eps 294 | std = (running_var + eps).sqrt() 295 | t = (gamma / std).reshape(-1, 1, 1, 1) 296 | return kernel * t, beta - running_mean * gamma / std 297 | 298 | def _conv_bn(self, kernel_size: int, padding: int) -> nn.Sequential: 299 | """Helper method to construct conv-batchnorm layers. 300 | 301 | Args: 302 | kernel_size: Size of the convolution kernel. 303 | padding: Zero-padding size. 304 | 305 | Returns: 306 | Conv-BN module. 307 | """ 308 | mod_list = nn.Sequential() 309 | mod_list.add_module( 310 | "conv", 311 | nn.Conv2d( 312 | in_channels=self.in_channels, 313 | out_channels=self.out_channels, 314 | kernel_size=kernel_size, 315 | stride=self.stride, 316 | padding=padding, 317 | groups=self.groups, 318 | bias=False, 319 | ), 320 | ) 321 | mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_channels)) 322 | return mod_list 323 | 324 | 325 | def reparameterize_model(model: torch.nn.Module) -> nn.Module: 326 | """Method returns a model where a multi-branched structure 327 | used in training is re-parameterized into a single branch 328 | for inference. 329 | 330 | Args: 331 | model: MobileOne model in train mode. 332 | 333 | Returns: 334 | MobileOne model in inference mode. 335 | """ 336 | # Avoid editing original graph 337 | model = copy.deepcopy(model) 338 | for module in model.modules(): 339 | if hasattr(module, "reparameterize"): 340 | module.reparameterize() 341 | return model 342 | -------------------------------------------------------------------------------- /mobileclip/modules/image/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All rights reserved. 4 | # 5 | -------------------------------------------------------------------------------- /mobileclip/modules/image/image_projection.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | from typing import List, Optional 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch import Tensor 10 | 11 | from mobileclip import logger 12 | 13 | 14 | class GlobalPool(nn.Module): 15 | """ 16 | This layers applies global pooling over a 4D or 5D input tensor 17 | 18 | Args: 19 | pool_type (Optional[str]): Pooling type. It can be mean, rms, or abs. 
Default: `mean` 20 | keep_dim (Optional[bool]): Do not squeeze the dimensions of a tensor. Default: `False` 21 | 22 | Shape: 23 | - Input: :math:`(N, C, H, W)` or :math:`(N, C, D, H, W)` 24 | - Output: :math:`(N, C, 1, 1)` or :math:`(N, C, 1, 1, 1)` if keep_dim else :math:`(N, C)` 25 | """ 26 | 27 | pool_types = ["mean", "rms", "abs"] 28 | 29 | def __init__( 30 | self, 31 | pool_type: Optional[str] = "mean", 32 | keep_dim: Optional[bool] = False, 33 | *args, 34 | **kwargs 35 | ) -> None: 36 | super().__init__() 37 | if pool_type not in self.pool_types: 38 | logger.error( 39 | "Supported pool types are: {}. Got {}".format( 40 | self.pool_types, pool_type 41 | ) 42 | ) 43 | self.pool_type = pool_type 44 | self.keep_dim = keep_dim 45 | 46 | def _global_pool(self, x: Tensor, dims: List): 47 | if self.pool_type == "rms": # root mean square 48 | x = x**2 49 | x = torch.mean(x, dim=dims, keepdim=self.keep_dim) 50 | x = x**-0.5 51 | elif self.pool_type == "abs": # absolute 52 | x = torch.mean(torch.abs(x), dim=dims, keepdim=self.keep_dim) 53 | else: 54 | # default is mean 55 | # same as AdaptiveAvgPool 56 | x = torch.mean(x, dim=dims, keepdim=self.keep_dim) 57 | return x 58 | 59 | def forward(self, x: Tensor) -> Tensor: 60 | if x.dim() == 4: 61 | dims = [-2, -1] 62 | elif x.dim() == 5: 63 | dims = [-3, -2, -1] 64 | else: 65 | raise NotImplementedError("Currently 2D and 3D global pooling supported") 66 | return self._global_pool(x, dims=dims) 67 | 68 | 69 | class GlobalPool2D(nn.Module): 70 | """This class implements global pooling with linear projection.""" 71 | 72 | def __init__(self, in_dim: int, out_dim: int, *args, **kwargs) -> None: 73 | super().__init__() 74 | scale = in_dim**-0.5 75 | self.pool = GlobalPool(pool_type="mean", keep_dim=False) 76 | self.proj = nn.Parameter(scale * torch.randn(size=(in_dim, out_dim))) 77 | self.in_dim = in_dim 78 | self.out_dim = out_dim 79 | 80 | def forward(self, x: Tensor, *args, **kwargs) -> Tensor: 81 | # x is of shape [batch, in_dim] 82 | assert ( 83 | x.dim() == 4 84 | ), "Input should be 4-dimensional (Batch x in_dim x in_height x in_width). Got: {}".format( 85 | x.shape 86 | ) 87 | 88 | # [batch, in_dim, in_height, in_width] --> [batch, in_dim] 89 | x = self.pool(x) 90 | # [batch, in_dim] x [in_dim, out_dim] --> [batch, out_dim] 91 | x = x @ self.proj 92 | return x 93 | 94 | 95 | class SimpleImageProjectionHead(nn.Module): 96 | """This class implements linear projection head.""" 97 | 98 | def __init__(self, in_dim: int, out_dim: int) -> None: 99 | super().__init__() 100 | scale = in_dim**-0.5 101 | self.proj = nn.Parameter(scale * torch.randn(size=(in_dim, out_dim))) 102 | self.in_dim = in_dim 103 | self.out_dim = out_dim 104 | 105 | def forward(self, x: Tensor, *args, **kwargs) -> Tensor: 106 | # x is of shape [batch, in_dim] 107 | assert ( 108 | x.dim() == 2 109 | ), "Input should be 2-dimensional (Batch x in_dim). Got: {}".format(x.shape) 110 | 111 | # [batch, in_dim] x [in_dim, out_dim] --> [batch, out_dim] 112 | x = x @ self.proj 113 | return x 114 | -------------------------------------------------------------------------------- /mobileclip/modules/image/replknet.py: -------------------------------------------------------------------------------- 1 | # 2 | # For acknowledgement see accompanying ACKNOWLEDGEMENTS file. 3 | # Copyright (C) 2024 Apple Inc. All rights reserved. 
4 | # 5 | from typing import Tuple 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from timm.models.layers import SqueezeExcite 11 | 12 | __all__ = ["ReparamLargeKernelConv"] 13 | 14 | 15 | class ReparamLargeKernelConv(nn.Module): 16 | """Building Block of RepLKNet 17 | 18 | This class defines overparameterized large kernel conv block 19 | introduced in `RepLKNet `_ 20 | 21 | Reference: https://github.com/DingXiaoH/RepLKNet-pytorch 22 | """ 23 | 24 | def __init__( 25 | self, 26 | in_channels: int, 27 | out_channels: int, 28 | kernel_size: int, 29 | stride: int, 30 | groups: int, 31 | small_kernel: int, 32 | inference_mode: bool = False, 33 | use_se: bool = False, 34 | activation: nn.Module = nn.GELU(), 35 | ) -> None: 36 | """Construct a ReparamLargeKernelConv module. 37 | 38 | Args: 39 | in_channels: Number of input channels. 40 | out_channels: Number of output channels. 41 | kernel_size: Kernel size of the large kernel conv branch. 42 | stride: Stride size. Default: 1 43 | groups: Group number. Default: 1 44 | small_kernel: Kernel size of small kernel conv branch. 45 | inference_mode: If True, instantiates model in inference mode. Default: ``False`` 46 | activation: Activation module. Default: ``nn.GELU`` 47 | """ 48 | super(ReparamLargeKernelConv, self).__init__() 49 | 50 | self.stride = stride 51 | self.groups = groups 52 | self.in_channels = in_channels 53 | self.out_channels = out_channels 54 | self.activation = activation 55 | 56 | self.kernel_size = kernel_size 57 | self.small_kernel = small_kernel 58 | self.padding = kernel_size // 2 59 | 60 | # Check if SE is requested 61 | if use_se: 62 | self.se = SqueezeExcite(out_channels, rd_ratio=0.25) 63 | else: 64 | self.se = nn.Identity() 65 | 66 | if inference_mode: 67 | self.lkb_reparam = nn.Conv2d( 68 | in_channels=in_channels, 69 | out_channels=out_channels, 70 | kernel_size=kernel_size, 71 | stride=stride, 72 | padding=self.padding, 73 | dilation=1, 74 | groups=groups, 75 | bias=True, 76 | ) 77 | else: 78 | self.lkb_origin = self._conv_bn( 79 | kernel_size=kernel_size, padding=self.padding 80 | ) 81 | if small_kernel is not None: 82 | assert ( 83 | small_kernel <= kernel_size 84 | ), "The kernel size for re-param cannot be larger than the large kernel!" 85 | self.small_conv = self._conv_bn( 86 | kernel_size=small_kernel, padding=small_kernel // 2 87 | ) 88 | 89 | def forward(self, x: torch.Tensor) -> torch.Tensor: 90 | """Apply forward pass.""" 91 | if hasattr(self, "lkb_reparam"): 92 | out = self.lkb_reparam(x) 93 | else: 94 | out = self.lkb_origin(x) 95 | if hasattr(self, "small_conv"): 96 | out += self.small_conv(x) 97 | 98 | return self.activation(self.se(out)) 99 | 100 | def get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]: 101 | """Method to obtain re-parameterized kernel and bias. 102 | Reference: https://github.com/DingXiaoH/RepLKNet-pytorch 103 | 104 | Returns: 105 | Tuple of (kernel, bias) after fusing branches. 106 | """ 107 | eq_k, eq_b = self._fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn) 108 | if hasattr(self, "small_conv"): 109 | small_k, small_b = self._fuse_bn(self.small_conv.conv, self.small_conv.bn) 110 | eq_b += small_b 111 | eq_k += nn.functional.pad( 112 | small_k, [(self.kernel_size - self.small_kernel) // 2] * 4 113 | ) 114 | return eq_k, eq_b 115 | 116 | def reparameterize(self) -> None: 117 | """ 118 | Following works like `RepVGG: Making VGG-style ConvNets Great Again` - 119 | https://arxiv.org/pdf/2101.03697.pdf. 
We re-parameterize multi-branched 120 | architecture used at training time to obtain a plain CNN-like structure 121 | for inference. 122 | """ 123 | eq_k, eq_b = self.get_kernel_bias() 124 | self.lkb_reparam = nn.Conv2d( 125 | in_channels=self.in_channels, 126 | out_channels=self.out_channels, 127 | kernel_size=self.kernel_size, 128 | stride=self.stride, 129 | padding=self.padding, 130 | dilation=self.lkb_origin.conv.dilation, 131 | groups=self.groups, 132 | bias=True, 133 | ) 134 | 135 | self.lkb_reparam.weight.data = eq_k 136 | self.lkb_reparam.bias.data = eq_b 137 | self.__delattr__("lkb_origin") 138 | if hasattr(self, "small_conv"): 139 | self.__delattr__("small_conv") 140 | 141 | @staticmethod 142 | def _fuse_bn( 143 | conv: torch.Tensor, bn: nn.BatchNorm2d 144 | ) -> Tuple[torch.Tensor, torch.Tensor]: 145 | """Method to fuse batchnorm layer with conv layer. 146 | 147 | Args: 148 | conv: Convolutional kernel weights. 149 | bn: Batchnorm 2d layer. 150 | 151 | Returns: 152 | Tuple of (kernel, bias) after fusing batchnorm. 153 | """ 154 | kernel = conv.weight 155 | running_mean = bn.running_mean 156 | running_var = bn.running_var 157 | gamma = bn.weight 158 | beta = bn.bias 159 | eps = bn.eps 160 | std = (running_var + eps).sqrt() 161 | t = (gamma / std).reshape(-1, 1, 1, 1) 162 | return kernel * t, beta - running_mean * gamma / std 163 | 164 | def _conv_bn(self, kernel_size: int, padding: int = 0) -> nn.Sequential: 165 | """Helper method to construct conv-batchnorm layers. 166 | 167 | Args: 168 | kernel_size: Size of the convolution kernel. 169 | padding: Zero-padding size. 170 | 171 | Returns: 172 | A nn.Sequential Conv-BN module. 173 | """ 174 | mod_list = nn.Sequential() 175 | mod_list.add_module( 176 | "conv", 177 | nn.Conv2d( 178 | in_channels=self.in_channels, 179 | out_channels=self.out_channels, 180 | kernel_size=kernel_size, 181 | stride=self.stride, 182 | padding=padding, 183 | groups=self.groups, 184 | bias=False, 185 | ), 186 | ) 187 | mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_channels)) 188 | return mod_list 189 | -------------------------------------------------------------------------------- /mobileclip/modules/text/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All rights reserved. 4 | # 5 | -------------------------------------------------------------------------------- /mobileclip/modules/text/repmixer.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | from typing import Optional 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from timm.models.layers import DropPath, trunc_normal_ 11 | from mobileclip.modules.common.mobileone import MobileOneBlock 12 | 13 | 14 | class ConvFFN(nn.Module): 15 | """Convolutional FFN Module.""" 16 | 17 | def __init__( 18 | self, 19 | in_channels: int, 20 | context_size: int, 21 | hidden_channels: Optional[int] = None, 22 | out_channels: Optional[int] = None, 23 | act_layer: nn.Module = nn.GELU, 24 | drop: float = 0.0, 25 | ) -> None: 26 | """Build convolutional FFN module. 27 | 28 | Args: 29 | in_channels: Number of input channels. 30 | context_size: Context size for 1D signals. 31 | hidden_channels: Number of channels after expansion. Default: None 32 | out_channels: Number of output channels. 
Default: None 33 | act_layer: Activation layer. Default: ``GELU`` 34 | drop: Dropout rate. Default: ``0.0``. 35 | """ 36 | super().__init__() 37 | out_channels = out_channels or in_channels 38 | hidden_channels = hidden_channels or in_channels 39 | self.conv = nn.Sequential() 40 | self.conv.add_module( 41 | "conv", 42 | nn.Conv2d( 43 | in_channels=in_channels, 44 | out_channels=out_channels, 45 | kernel_size=(1, int(context_size)), 46 | padding=(0, int(context_size // 2)), 47 | groups=in_channels, 48 | bias=False, 49 | ), 50 | ) 51 | self.conv.add_module("bn", nn.BatchNorm2d(num_features=out_channels)) 52 | self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1) 53 | self.act = act_layer() 54 | self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1) 55 | self.drop = nn.Dropout(drop) 56 | self.apply(self._init_weights) 57 | 58 | def _init_weights(self, m: nn.Module) -> None: 59 | if isinstance(m, nn.Conv2d): 60 | trunc_normal_(m.weight, std=0.02) 61 | if m.bias is not None: 62 | nn.init.constant_(m.bias, 0) 63 | 64 | def forward(self, x: torch.Tensor) -> torch.Tensor: 65 | x = self.conv(x) 66 | x = self.fc1(x) 67 | x = self.act(x) 68 | x = self.drop(x) 69 | x = self.fc2(x) 70 | x = self.drop(x) 71 | return x 72 | 73 | 74 | class RepMixer(nn.Module): 75 | """Reparameterizable token mixer. 76 | 77 | For more details, please refer to our paper: 78 | `FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization `_ 79 | """ 80 | 81 | def __init__( 82 | self, 83 | dim, 84 | kernel_size=3, 85 | use_layer_scale=True, 86 | layer_scale_init_value=1e-5, 87 | inference_mode: bool = False, 88 | ): 89 | """Build RepMixer Module. 90 | 91 | Args: 92 | dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`. 93 | kernel_size: Kernel size for spatial mixing. Default: 3 94 | use_layer_scale: If True, learnable layer scale is used. Default: ``True`` 95 | layer_scale_init_value: Initial value for layer scale. Default: 1e-5 96 | inference_mode: If True, instantiates model in inference mode. 
Default: ``False`` 97 | """ 98 | super().__init__() 99 | self.dim = dim 100 | self.kernel_size = kernel_size 101 | self.inference_mode = inference_mode 102 | 103 | if inference_mode: 104 | self.reparam_conv = nn.Conv2d( 105 | in_channels=self.dim, 106 | out_channels=self.dim, 107 | kernel_size=(1, self.kernel_size), 108 | stride=1, 109 | padding=(0, self.kernel_size // 2), 110 | groups=self.dim, 111 | bias=True, 112 | ) 113 | else: 114 | self.norm = MobileOneBlock( 115 | dim, 116 | dim, 117 | (1, kernel_size), 118 | padding=(0, kernel_size // 2), 119 | groups=dim, 120 | use_act=False, 121 | use_scale_branch=False, 122 | num_conv_branches=0, 123 | ) 124 | self.mixer = MobileOneBlock( 125 | dim, 126 | dim, 127 | (1, kernel_size), 128 | padding=(0, kernel_size // 2), 129 | groups=dim, 130 | use_act=False, 131 | ) 132 | self.use_layer_scale = use_layer_scale 133 | if use_layer_scale: 134 | self.layer_scale = nn.Parameter( 135 | layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True 136 | ) 137 | 138 | def forward(self, x: torch.Tensor) -> torch.Tensor: 139 | if hasattr(self, "reparam_conv"): 140 | x = self.reparam_conv(x) 141 | return x 142 | else: 143 | if self.use_layer_scale: 144 | x = x + self.layer_scale * (self.mixer(x) - self.norm(x)) 145 | else: 146 | x = x + self.mixer(x) - self.norm(x) 147 | return x 148 | 149 | def reparameterize(self) -> None: 150 | """Reparameterize mixer and norm into a single 151 | convolutional layer for efficient inference. 152 | """ 153 | if self.inference_mode: 154 | return 155 | 156 | self.mixer.reparameterize() 157 | self.norm.reparameterize() 158 | 159 | if self.use_layer_scale: 160 | w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * ( 161 | self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight 162 | ) 163 | b = torch.squeeze(self.layer_scale) * ( 164 | self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias 165 | ) 166 | else: 167 | w = ( 168 | self.mixer.id_tensor 169 | + self.mixer.reparam_conv.weight 170 | - self.norm.reparam_conv.weight 171 | ) 172 | b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias 173 | 174 | self.reparam_conv = nn.Conv2d( 175 | in_channels=self.dim, 176 | out_channels=self.dim, 177 | kernel_size=(1, self.kernel_size), 178 | stride=1, 179 | padding=(0, self.kernel_size // 2), 180 | groups=self.dim, 181 | bias=True, 182 | ) 183 | self.reparam_conv.weight.data = w 184 | self.reparam_conv.bias.data = b 185 | 186 | for para in self.parameters(): 187 | para.detach_() 188 | self.__delattr__("mixer") 189 | self.__delattr__("norm") 190 | if self.use_layer_scale: 191 | self.__delattr__("layer_scale") 192 | 193 | 194 | class RepMixerBlock(nn.Module): 195 | """Implementation of Metaformer block with RepMixer as token mixer. 196 | 197 | For more details on Metaformer structure, please refer to: 198 | `MetaFormer Is Actually What You Need for Vision `_ 199 | """ 200 | 201 | def __init__( 202 | self, 203 | dim: int, 204 | kernel_size: int = 11, 205 | mlp_ratio: float = 4.0, 206 | act_layer: nn.Module = nn.GELU, 207 | drop: float = 0.0, 208 | drop_path: float = 0.0, 209 | use_layer_scale: bool = True, 210 | layer_scale_init_value: float = 1e-5, 211 | inference_mode: bool = False, 212 | *args, 213 | **kwargs, 214 | ): 215 | """Build RepMixer Block. 216 | 217 | Args: 218 | dim: Number of embedding dimensions. 219 | kernel_size: Kernel size for repmixer. Default: 3 220 | mlp_ratio: MLP expansion ratio. Default: 4.0 221 | act_layer: Activation layer. Default: ``nn.GELU`` 222 | drop: Dropout rate. 
Default: 0.0 223 | drop_path: Drop path rate. Default: 0.0 224 | use_layer_scale: Flag to turn on layer scale. Default: ``True`` 225 | layer_scale_init_value: Layer scale value at initialization. Default: 1e-5 226 | inference_mode: Flag to instantiate block in inference mode. Default: ``False`` 227 | """ 228 | 229 | super().__init__() 230 | 231 | self.token_mixer = RepMixer( 232 | dim, 233 | kernel_size=kernel_size, 234 | use_layer_scale=use_layer_scale, 235 | layer_scale_init_value=layer_scale_init_value, 236 | inference_mode=inference_mode, 237 | ) 238 | 239 | assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format( 240 | mlp_ratio 241 | ) 242 | mlp_hidden_dim = int(dim * mlp_ratio) 243 | self.convffn = ConvFFN( 244 | in_channels=dim, 245 | context_size=kernel_size, 246 | hidden_channels=mlp_hidden_dim, 247 | act_layer=act_layer, 248 | drop=drop, 249 | ) 250 | 251 | # Drop Path 252 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 253 | 254 | # Layer Scale 255 | self.use_layer_scale = use_layer_scale 256 | if use_layer_scale: 257 | self.layer_scale = nn.Parameter( 258 | layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True 259 | ) 260 | 261 | def forward(self, x, *args, **kwargs): 262 | if x.dim() == 3: 263 | # B, C, D --- where C is the context length 264 | # Convert to B, D, C --- to match RepMixer impl. 265 | x = x.permute(0, 2, 1) 266 | x = torch.unsqueeze(x, dim=2) 267 | else: 268 | raise ValueError( 269 | f"Expected tensor of dim=3, obtained tensor of dim={x.dim()}" 270 | ) 271 | 272 | if self.use_layer_scale: 273 | x = self.token_mixer(x) 274 | x = x + self.drop_path(self.layer_scale * self.convffn(x)) 275 | else: 276 | x = self.token_mixer(x) 277 | x = x + self.drop_path(self.convffn(x)) 278 | 279 | # Convert tensors back 280 | x = x.squeeze(dim=2).permute(0, 2, 1) 281 | return x 282 | -------------------------------------------------------------------------------- /mobileclip/modules/text/tokenizer.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | from typing import Dict 6 | 7 | import open_clip 8 | from torch import Tensor, nn 9 | 10 | 11 | class ClipTokenizer(nn.Module): 12 | def __init__(self, cfg, *args, **kwargs): 13 | super().__init__() 14 | self.context_length = cfg["text_cfg"]["context_length"] 15 | model_name = getattr(cfg["text_cfg"], "open_clip_tokenizer", "ViT-B-16") 16 | self.tokenizer = open_clip.get_tokenizer(model_name) 17 | 18 | def get_vocab_size(self) -> int: 19 | return len(self.tokenizer.encoder) 20 | 21 | def get_encodings(self) -> Dict[str, int]: 22 | return self.tokenizer.encoder 23 | 24 | def get_eot_token(self) -> int: 25 | # Tokenizing an empty string returns a list [sot_id, eot_id] 26 | return self.tokenizer("")[1] 27 | 28 | def get_sot_token(self) -> int: 29 | # Tokenizing an empty string returns a list [sot_id, eot_id] 30 | return self.tokenizer("")[0] 31 | 32 | def forward(self, input_sentence: str, *args, **kwargs) -> Tensor: 33 | # tokenizer returns indices as a string 34 | tokenized_sentence = self.tokenizer(input_sentence, self.context_length) 35 | assert ( 36 | tokenized_sentence.shape[-1] == self.context_length 37 | ), "Tokenized tensor should be exactly `context_length` long." 
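        # Illustrative usage sketch (the prompt string and shapes are examples, not
        # taken from this repo): with context_length == 77,
        #   ClipTokenizer(cfg)("a photo of a cat")
        # is expected to return a LongTensor of shape [1, 77], padded/truncated to
        # the context length, with the SOT/EOT ids available via
        # get_sot_token()/get_eot_token().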
38 | return tokenized_sentence 39 | -------------------------------------------------------------------------------- /mobileclip/text_encoder.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import math 6 | from typing import Optional, Sequence 7 | 8 | import torch 9 | from torch import Tensor, nn 10 | 11 | from mobileclip.modules.common.transformer import ( 12 | PositionalEmbedding, 13 | TransformerEncoder, 14 | get_normalization_layer, 15 | ) 16 | from mobileclip.modules.text.repmixer import RepMixerBlock 17 | from mobileclip import logger 18 | 19 | 20 | class TextTransformer(nn.Module): 21 | def __init__(self, cfg: dict, projection_dim: int, *args, **kwargs) -> None: 22 | super().__init__() 23 | 24 | model_dim = cfg["dim"] 25 | no_scale_embedding = cfg.get("no_scale_embedding", False) 26 | no_pos_embedding = cfg.get("no_pos_embedding", False) 27 | embed_dropout = cfg.get("embed_dropout", 0.0) 28 | norm_layer = cfg["norm_layer"] 29 | variant = cfg["model_name"] 30 | self.vocab_size = cfg["vocab_size"] 31 | self.projection_dim = projection_dim 32 | 33 | # Token embedding layer 34 | self.embedding_layer = nn.Embedding( 35 | embedding_dim=model_dim, num_embeddings=self.vocab_size 36 | ) 37 | self.embed_scale = 1.0 if no_scale_embedding else model_dim**-0.5 38 | 39 | # Context length 40 | context_length = cfg["context_length"] 41 | assert ( 42 | context_length is not None 43 | ), "Context length can't be None. Please set value accordingly." 44 | 45 | self.positional_embedding = ( 46 | None 47 | if no_pos_embedding 48 | else PositionalEmbedding( 49 | num_embeddings=context_length, embedding_dim=model_dim 50 | ) 51 | ) 52 | 53 | self.embedding_dropout = nn.Dropout(p=embed_dropout) 54 | 55 | # Transformer layer 56 | n_transformer_layers = cfg["n_transformer_layers"] 57 | 58 | # FFN multipliers for transformer layer 59 | ffn_multipliers = cfg["ffn_multiplier_per_layer"] 60 | if isinstance(ffn_multipliers, (float, int)): 61 | ffn_multipliers = [ffn_multipliers] * n_transformer_layers 62 | 63 | if not isinstance(ffn_multipliers, Sequence): 64 | logger.error( 65 | "{} expects FFN multipliers as a list, whose length is the same as" 66 | " number of transformer layers. Got: {}".format( 67 | self.__class__.__name__, type(ffn_multipliers) 68 | ) 69 | ) 70 | elif ( 71 | isinstance(ffn_multipliers, Sequence) 72 | and len(ffn_multipliers) != n_transformer_layers 73 | ): 74 | logger.error( 75 | "We need FFN multiplier for each transformer layer. Got {} ffn" 76 | " multipliers while number of transformer layers = {}".format( 77 | len(ffn_multipliers), n_transformer_layers 78 | ) 79 | ) 80 | ffn_dims = [ 81 | int(math.ceil(model_dim * ffn_mult / 16.0) * 16.0) 82 | for ffn_mult in ffn_multipliers 83 | ] 84 | 85 | # Heads for transformer layers 86 | mha_heads = cfg["n_heads_per_layer"] 87 | if isinstance(mha_heads, int): 88 | mha_heads = [mha_heads] * n_transformer_layers 89 | 90 | if not isinstance(mha_heads, Sequence): 91 | logger.error( 92 | "{} expects MHA heads as a list, whose length is the same as number of " 93 | "transformer layers. Got: {}".format( 94 | self.__class__.__name__, type(mha_heads) 95 | ) 96 | ) 97 | elif isinstance(mha_heads, Sequence) and len(mha_heads) != n_transformer_layers: 98 | logger.error( 99 | "{} needs MHA heads for each transformer layer. 
Got {} mha heads while" 100 | " number of transformer layers = {}".format( 101 | self.__class__.__name__, len(mha_heads), n_transformer_layers 102 | ) 103 | ) 104 | 105 | if variant == "base": 106 | self.transformer = nn.ModuleList( 107 | [ 108 | TransformerEncoder( 109 | embed_dim=model_dim, 110 | num_heads=mha_heads[layer_idx], 111 | ffn_latent_dim=ffn_dims[layer_idx], 112 | transformer_norm_layer=norm_layer, 113 | ) 114 | for layer_idx in range(n_transformer_layers) 115 | ] 116 | ) 117 | elif variant == "mct": 118 | self.transformer = nn.ModuleList([RepMixerBlock(dim=model_dim)]) 119 | self.transformer.extend( 120 | [ 121 | TransformerEncoder( 122 | embed_dim=model_dim, 123 | num_heads=mha_heads[layer_idx], 124 | ffn_latent_dim=ffn_dims[layer_idx], 125 | transformer_norm_layer=norm_layer, 126 | ) 127 | for layer_idx in range(n_transformer_layers) 128 | ] 129 | ) 130 | self.transformer.extend([RepMixerBlock(dim=model_dim)]) 131 | else: 132 | raise ValueError("Unrecognized text encoder variant {}".format(variant)) 133 | 134 | self.final_layer_norm = get_normalization_layer( 135 | num_features=model_dim, norm_type=norm_layer 136 | ) 137 | 138 | self.projection_layer = nn.Parameter( 139 | torch.empty(model_dim, self.projection_dim) 140 | ) 141 | self.model_dim = model_dim 142 | self.causal_masking = cfg["causal_masking"] 143 | 144 | def forward_embedding(self, text_tokens: Tensor) -> Tensor: 145 | """Return text embedding for all tokens. 146 | 147 | Args: 148 | text_tokens: a tensor of token indices. Shape: [batch_size, context_length] 149 | 150 | Returns: 151 | A tensor of [batch_size, context_length, hidden_dim]. 152 | """ 153 | # [batch_size, context_length] --> [batch_size, context_length, hidden_dim] 154 | token_emb = self.embedding_layer(text_tokens) 155 | seq_len = token_emb.shape[1] 156 | if self.positional_embedding is not None: 157 | token_emb = token_emb + self.positional_embedding(seq_len).to( 158 | token_emb.dtype 159 | ) 160 | token_emb = self.embedding_dropout(token_emb) 161 | return token_emb 162 | 163 | def build_attention_mask(self, context_length: int, batch_size: int) -> Tensor: 164 | """Build causal attention mask [batch_size, context_length, context_length].""" 165 | # Build mask with full attention between the tokens 166 | # pytorch uses additive attention mask; fill with -inf 167 | mask = torch.empty(context_length, context_length) 168 | mask.fill_(float("-inf")) 169 | mask.triu_(1) # zero out the lower diagonal 170 | mask = mask.unsqueeze(0) # add dummy batch dimension 171 | mask = mask.expand(batch_size, -1, -1) 172 | return mask 173 | 174 | def encode_text( 175 | self, 176 | text_tokens: Tensor, 177 | key_padding_mask: Optional[Tensor] = None, 178 | return_all_tokens: bool = False, 179 | *args, 180 | **kwargs 181 | ) -> Tensor: 182 | """Return text token embeddings. 183 | 184 | Args: 185 | text_tokens: a tensor of token indices. Shape: [batch_size, context_length] 186 | key_padding_mask: a tensor of boolean values as the padding mask. 187 | Shape: [batch_size, context_length] 188 | return_all_tokens: a boolean flag to return all tokens, defaults to False 189 | to return only EOT token embedding. 190 | Returns: 191 | A tensor of [batch_size, context_length, hidden_dim] if return_all_tokens is 192 | True, otherwise a tensor of [batch_size, hidden_dim]. 
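        Shape sketch (illustrative, assuming context_length=77): text_tokens of shape
        [B, 77] produce [B, 77, model_dim] token embeddings; with
        return_all_tokens=False, the embedding at the EOT position (the largest token
        id in each sequence) is selected and multiplied by `projection_layer`,
        yielding a [B, projection_dim] tensor.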
193 | """ 194 | # Discrete tokens to continuous embeddings 195 | # [batch_size, context_length] --> [batch_size, context_length, hidden_dim] 196 | token_emb = self.forward_embedding(text_tokens) 197 | 198 | # [1, context_length, context_length] 199 | attn_mask = None 200 | if self.causal_masking: 201 | attn_mask = self.build_attention_mask( 202 | context_length=text_tokens.shape[1], batch_size=text_tokens.shape[0] 203 | ) 204 | attn_mask = attn_mask.to(device=token_emb.device, dtype=token_emb.dtype) 205 | key_padding_mask = None 206 | 207 | for layer in self.transformer: 208 | token_emb = layer( 209 | token_emb, 210 | key_padding_mask=key_padding_mask, 211 | attn_mask=attn_mask, 212 | ) 213 | 214 | # Apply layer norm 215 | token_emb = self.final_layer_norm(token_emb) 216 | 217 | if return_all_tokens: 218 | return token_emb 219 | 220 | # Take features from the eot embedding (eot_token is the highest number in each sequence) 221 | token_emb = token_emb[ 222 | torch.arange(text_tokens.shape[0]), text_tokens.argmax(dim=-1) 223 | ] 224 | 225 | token_emb = token_emb @ self.projection_layer 226 | return token_emb 227 | 228 | def forward( 229 | self, 230 | text_tokens: Tensor, 231 | key_padding_mask: Optional[Tensor] = None, 232 | return_all_tokens: bool = False, 233 | *args, 234 | **kwargs 235 | ) -> Tensor: 236 | # Image-text pair data with single caption 237 | # [B, CL] --> [B, d] 238 | text_tokens = self.encode_text( 239 | text_tokens=text_tokens, 240 | key_padding_mask=key_padding_mask, 241 | return_all_tokens=return_all_tokens, 242 | *args, 243 | **kwargs 244 | ) 245 | return text_tokens 246 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | clip-benchmark>=1.4.0 2 | datasets>=2.21.0 3 | open-clip-torch>=2.20.0 4 | timm>=0.9.5 5 | torch>=2.1.0 6 | torchvision==0.14.1 7 | -------------------------------------------------------------------------------- /results/mobileclip_s2.jsonl: -------------------------------------------------------------------------------- 1 | {"key": "vtab/caltech101", "dataset": "Caltech-101", "metrics": {"acc1": 0.8447000821692687, "acc5": 0.9541495480690222, "mean_per_class_recall": 0.936722834577359, "main_metric": 0.936722834577359}} 2 | {"key": "cifar10", "dataset": "CIFAR-10", "metrics": {"acc1": 0.9588, "acc5": 0.9995, "mean_per_class_recall": 0.9588000000000001, "main_metric": 0.9588}} 3 | {"key": "vtab/cifar100", "dataset": "CIFAR-100", "metrics": {"acc1": 0.8128, "acc5": 0.9667, "mean_per_class_recall": 0.8128, "main_metric": 0.8128}} 4 | {"key": "vtab/clevr_count_all", "dataset": "CLEVR Counts", "metrics": {"acc1": 0.3344, "acc5": 0.9235333333333333, "mean_per_class_recall": 0.33309194769139167, "main_metric": 0.3344}} 5 | {"key": "vtab/clevr_closest_object_distance", "dataset": "CLEVR Distance", "metrics": {"acc1": 0.2376, "acc5": 0.9186666666666666, "mean_per_class_recall": 0.17236004590330653, "main_metric": 0.2376}} 6 | {"key": "country211", "dataset": "Country211", "metrics": {"acc1": 0.19199052132701422, "acc5": 0.42241706161137443, "mean_per_class_recall": 0.19199052132701416, "main_metric": 0.19199052132701422}} 7 | {"key": "vtab/dtd", "dataset": "Describable Textures", "metrics": {"acc1": 0.6090425531914894, "acc5": 0.8803191489361702, "mean_per_class_recall": 0.6090425531914894, "main_metric": 0.6090425531914894}} 8 | {"key": "vtab/eurosat", "dataset": "EuroSAT", "metrics": {"acc1": 0.6672222222222223, "acc5": 
0.9757407407407407, "mean_per_class_recall": 0.6710678146515698, "main_metric": 0.6672222222222223}} 9 | {"key": "fgvc_aircraft", "dataset": "FGVC Aircraft", "metrics": {"acc1": 0.30423042304230424, "acc5": 0.6753675367536753, "mean_per_class_recall": 0.30353832442067735, "main_metric": 0.30353832442067735}} 10 | {"key": "food101", "dataset": "Food-101", "metrics": {"acc1": 0.8967920792079208, "acc5": 0.9872475247524752, "mean_per_class_recall": 0.8967920792079209, "main_metric": 0.8967920792079208}} 11 | {"key": "gtsrb", "dataset": "GTSRB", "metrics": {"acc1": 0.5361836896278701, "acc5": 0.8277117973079968, "mean_per_class_recall": 0.4787885812648198, "main_metric": 0.5361836896278701}} 12 | {"key": "imagenet1k", "dataset": "ImageNet 1k", "metrics": {"acc1": 0.74438, "acc5": 0.94396, "mean_per_class_recall": 0.74424, "main_metric": 0.74438}} 13 | {"key": "imagenet_sketch", "dataset": "ImageNet Sketch", "metrics": {"acc1": 0.6221187289984083, "acc5": 0.8650985478197646, "mean_per_class_recall": 0.6223282352941176, "main_metric": 0.6221187289984083}} 14 | {"key": "imagenetv2", "dataset": "ImageNet v2", "metrics": {"acc1": 0.6679, "acc5": 0.9042, "mean_per_class_recall": 0.6684000000000001, "main_metric": 0.6679}} 15 | {"key": "imagenet-a", "dataset": "ImageNet-A", "metrics": {"acc1": 0.4932, "acc5": 0.8016, "mean_per_class_recall": 0.4960668896772007, "main_metric": 0.4932}} 16 | {"key": "imagenet-o", "dataset": "ImageNet-O", "metrics": {"acc1": 0.469, "acc5": 0.7885, "mean_per_class_recall": 0.48375934665609593, "main_metric": 0.469}} 17 | {"key": "imagenet-r", "dataset": "ImageNet-R", "metrics": {"acc1": 0.8698666666666667, "acc5": 0.9681333333333333, "mean_per_class_recall": 0.8553672907748989, "main_metric": 0.8698666666666667}} 18 | {"key": "vtab/kitti_closest_vehicle_distance", "dataset": "KITTI Vehicle Distance", "metrics": {"acc1": 0.2320675105485232, "acc5": null, "mean_per_class_recall": 0.32955606784304503, "main_metric": 0.2320675105485232}} 19 | {"key": "mnist", "dataset": "MNIST", "metrics": {"acc1": 0.8702, "acc5": 0.9823, "mean_per_class_recall": 0.8687912186633511, "main_metric": 0.8702}} 20 | {"key": "objectnet", "dataset": "ObjectNet", "metrics": {"acc1": 0.6656078389146118, "acc5": 0.8666953806396037, "mean_per_class_recall": 0.6523596249996781, "main_metric": 0.6656078389146118}} 21 | {"key": "vtab/flowers", "dataset": "Oxford Flowers-102", "metrics": {"acc1": 0.7503659131566108, "acc5": 0.88209464953651, "mean_per_class_recall": 0.7320917902269698, "main_metric": 0.7320917902269698}} 22 | {"key": "vtab/pets", "dataset": "Oxford-IIIT Pet", "metrics": {"acc1": 0.9283183428727174, "acc5": 0.9983646770237122, "mean_per_class_recall": 0.9278175397989731, "main_metric": 0.9278175397989731}} 23 | {"key": "voc2007", "dataset": "Pascal VOC 2007", "metrics": {"acc1": 0.8321981837606838, "acc5": 0.9822382478632479, "mean_per_class_recall": 0.8719232098874825, "main_metric": 0.8321981837606838}} 24 | {"key": "vtab/pcam", "dataset": "PatchCamelyon", "metrics": {"acc1": 0.68829345703125, "acc5": null, "mean_per_class_recall": 0.6882425499107128, "main_metric": 0.68829345703125}} 25 | {"key": "renderedsst2", "dataset": "Rendered SST2", "metrics": {"acc1": 0.5634266886326195, "acc5": null, "mean_per_class_recall": 0.5631477621446355, "main_metric": 0.5634266886326195}} 26 | {"key": "vtab/resisc45", "dataset": "RESISC45", "metrics": {"acc1": 0.686984126984127, "acc5": 0.9433333333333334, "mean_per_class_recall": 0.6934871572095898, "main_metric": 0.686984126984127}} 27 | {"key": 
"cars", "dataset": "Stanford Cars", "metrics": {"acc1": 0.8612112921278448, "acc5": 0.9953985822658873, "mean_per_class_recall": 0.8615237612708011, "main_metric": 0.8612112921278448}} 28 | {"key": "stl10", "dataset": "STL-10", "metrics": {"acc1": 0.987375, "acc5": 1.0, "mean_per_class_recall": 0.9873749999999999, "main_metric": 0.987375}} 29 | {"key": "sun397", "dataset": "SUN397", "metrics": {"acc1": 0.7306857678798021, "acc5": 0.9532063188480424, "mean_per_class_recall": 0.7194289168008176, "main_metric": 0.7306857678798021}} 30 | {"key": "vtab/svhn", "dataset": "SVHN", "metrics": {"acc1": 0.68899815611555, "acc5": 0.9524815611555009, "mean_per_class_recall": 0.6898340759970616, "main_metric": 0.68899815611555}} 31 | {"key": "retrieval/flickr_1k_test_image_text_retrieval", "dataset": "Flickr", "metrics": {"image_retrieval_recall@1": 0.7337999939918518, "text_retrieval_recall@1": 0.902999997138977, "image_retrieval_recall@5": 0.9229999780654907, "text_retrieval_recall@5": 0.9890000224113464, "image_retrieval_recall@10": 0.9562000036239624, "text_retrieval_recall@10": 0.9959999918937683, "mean_recall@1": 0.8183999955654144, "main_metric": 0.8183999955654144}} 32 | {"key": "retrieval/mscoco_2014_5k_test_image_text_retrieval", "dataset": "MSCOCO", "metrics": {"image_retrieval_recall@1": 0.4537385106086731, "text_retrieval_recall@1": 0.6335999965667725, "image_retrieval_recall@5": 0.7005198001861572, "text_retrieval_recall@5": 0.8514000177383423, "image_retrieval_recall@10": 0.7901239395141602, "text_retrieval_recall@10": 0.9138000011444092, "mean_recall@1": 0.5436692535877228, "main_metric": 0.5436692535877228}} 33 | {"key": "misc/winogavil", "dataset": "WinoGAViL", "metrics": {"avg_jaccard_score": 0.546845221389145, "jaccard_score_5": 0.5773989898989897, "jaccard_score_6": 0.561987041036717, "jaccard_score_10": 0.48857589984350547, "jaccard_score_12": 0.4614708500222519, "jaccard_score_5-6": 0.569496739264181, "jaccard_score_10-12": 0.47499163599866173, "main_metric": 0.47499163599866173}} 34 | {"key": "wilds/iwildcam", "dataset": "iWildCam", "metrics": {"acc1": 0.21609684279404548, "acc5": 0.4210231123367063, "mean_per_class_recall": 0.16389162255648526, "acc_avg": 0.21791966259479523, "recall-macro_all": 0.16389162255648526, "F1-macro_all": 0.15011020079756165, "main_metric": 0.15011020079756165}} 35 | {"key": "wilds/camelyon17", "dataset": "Camelyon17", "metrics": {"acc1": 0.7279022738495544, "acc5": null, "mean_per_class_recall": 0.7279022738495544, "acc_avg": 0.7279022932052612, "acc_slide:0": NaN, "count_slide:0": 0.0, "acc_slide:1": NaN, "count_slide:1": 0.0, "acc_slide:2": NaN, "count_slide:2": 0.0, "acc_slide:3": NaN, "count_slide:3": 0.0, "acc_slide:4": NaN, "count_slide:4": 0.0, "acc_slide:5": NaN, "count_slide:5": 0.0, "acc_slide:6": NaN, "count_slide:6": 0.0, "acc_slide:7": NaN, "count_slide:7": 0.0, "acc_slide:8": NaN, "count_slide:8": 0.0, "acc_slide:9": NaN, "count_slide:9": 0.0, "acc_slide:10": NaN, "count_slide:10": 0.0, "acc_slide:11": NaN, "count_slide:11": 0.0, "acc_slide:12": NaN, "count_slide:12": 0.0, "acc_slide:13": NaN, "count_slide:13": 0.0, "acc_slide:14": NaN, "count_slide:14": 0.0, "acc_slide:15": NaN, "count_slide:15": 0.0, "acc_slide:16": NaN, "count_slide:16": 0.0, "acc_slide:17": NaN, "count_slide:17": 0.0, "acc_slide:18": NaN, "count_slide:18": 0.0, "acc_slide:19": NaN, "count_slide:19": 0.0, "acc_slide:20": 0.9283464550971985, "count_slide:20": 3810.0, "acc_slide:21": 0.8716838359832764, "count_slide:21": 3694.0, "acc_slide:22": 0.6826629638671875, 
"count_slide:22": 7210.0, "acc_slide:23": 0.5633509755134583, "count_slide:23": 5288.0, "acc_slide:24": 0.8445709943771362, "count_slide:24": 7727.0, "acc_slide:25": 0.8880941271781921, "count_slide:25": 4334.0, "acc_slide:26": 0.7693315744400024, "count_slide:26": 3815.0, "acc_slide:27": 0.75, "count_slide:27": 4556.0, "acc_slide:28": 0.6695526838302612, "count_slide:28": 31878.0, "acc_slide:29": 0.7206090092658997, "count_slide:29": 12742.0, "acc_wg": 0.5633509755134583, "main_metric": 0.7279022738495544}} 36 | {"key": "wilds/fmow", "dataset": "FMoW", "metrics": {"acc1": 0.21539714130631446, "acc5": 0.5117604487063506, "mean_per_class_recall": 0.2129885499249436, "acc_avg": 0.21539713442325592, "acc_year:0": NaN, "count_year:0": 0.0, "acc_year:1": NaN, "count_year:1": 0.0, "acc_year:2": NaN, "count_year:2": 0.0, "acc_year:3": NaN, "count_year:3": 0.0, "acc_year:4": NaN, "count_year:4": 0.0, "acc_year:5": NaN, "count_year:5": 0.0, "acc_year:6": NaN, "count_year:6": 0.0, "acc_year:7": NaN, "count_year:7": 0.0, "acc_year:8": NaN, "count_year:8": 0.0, "acc_year:9": NaN, "count_year:9": 0.0, "acc_year:10": NaN, "count_year:10": 0.0, "acc_year:11": NaN, "count_year:11": 0.0, "acc_year:12": NaN, "count_year:12": 0.0, "acc_year:13": NaN, "count_year:13": 0.0, "acc_year:14": 0.2280218005180359, "count_year:14": 15959.0, "acc_year:15": 0.18263132870197296, "count_year:15": 6149.0, "acc_worst_year": 0.18263132870197296, "acc_region:0": 0.19040901958942413, "count_region:0": 4963.0, "acc_region:1": 0.23147831857204437, "count_region:1": 5858.0, "acc_region:2": 0.1392209827899933, "count_region:2": 2593.0, "acc_region:3": 0.23741276562213898, "count_region:3": 8024.0, "acc_region:4": 0.28828829526901245, "count_region:4": 666.0, "acc_region:5": 0.75, "count_region:5": 4.0, "acc_worst_region": 0.1392209827899933, "main_metric": 0.1392209827899933}} 37 | {"key": "fairness/dollar_street", "dataset": "Dollar Street", "metrics": {"acc1": 0.525834998572652, "acc5": 0.7904653154439052, "mean_per_class_recall": 0.5608762529814217, "acc_top5_avg": 0.7904652953147888, "acc_top5_income_ds:0": 0.6203271150588989, "count_income_ds:0": 856.0, "acc_top5_income_ds:1": 0.7952488660812378, "count_income_ds:1": 884.0, "acc_top5_income_ds:2": 0.8435072302818298, "count_income_ds:2": 901.0, "acc_top5_income_ds:3": 0.8990719318389893, "count_income_ds:3": 862.0, "acc_top5_wg": 0.6203271150588989, "main_metric": 0.6203271150588989}} 38 | {"key": "fairness/geode", "dataset": "GeoDE", "metrics": {"acc1": 0.9112748238308777, "acc5": 0.9938340807174888, "mean_per_class_recall": 0.9114425088759956, "acc_avg": 0.9112748503684998, "acc_region:0": 0.896450936794281, "count_region:0": 2395.0, "acc_region:1": 0.9134328365325928, "count_region:1": 2010.0, "acc_region:2": 0.9153339862823486, "count_region:2": 2126.0, "acc_region:3": 0.9070364832878113, "count_region:3": 1947.0, "acc_region:4": 0.9095048308372498, "count_region:4": 1757.0, "acc_region:5": 0.9263204336166382, "count_region:5": 2253.0, "acc_wg": 0.896450936794281, "main_metric": 0.896450936794281}} 39 | {"key": "fairness/fairface", "dataset": "FairFace", "metrics": {"acc_race_avg": 0.8975716829299927, "acc_race_race_binary:0": 0.5534772276878357, "count_race_binary:0": 2085.0, "acc_race_race_binary:1": 0.9784643054008484, "count_race_binary:1": 8869.0, "acc_race_wg": 0.5534772276878357, "acc_gender_avg": 0.9280627965927124, "acc_gender_race_binary:0": 0.9462829828262329, "acc_gender_race_binary:1": 0.9237794280052185, "acc_gender_wg": 0.9237794280052185, "acc_age_avg": 
0.4575497508049011, "acc_age_race_binary:0": 0.46954435110092163, "acc_age_race_binary:1": 0.45472994446754456, "acc_age_wg": 0.45472994446754456, "acc_gender_x_avg": 0.9280627965927124, "acc_gender_x_race:0_gender:0": 0.7897371649742126, "count_race:0_gender:0": 799.0, "acc_gender_x_race:0_gender:1": 0.9537648558616638, "count_race:0_gender:1": 757.0, "acc_gender_x_race:1_gender:0": 0.9242424368858337, "count_race:1_gender:0": 1122.0, "acc_gender_x_race:1_gender:1": 0.9719626307487488, "count_race:1_gender:1": 963.0, "acc_gender_x_race:2_gender:0": 0.9030544757843018, "count_race:2_gender:0": 753.0, "acc_gender_x_race:2_gender:1": 0.9659239649772644, "count_race:2_gender:1": 763.0, "acc_gender_x_race:3_gender:0": 0.9041613936424255, "count_race:3_gender:0": 793.0, "acc_gender_x_race:3_gender:1": 0.9795180559158325, "count_race:3_gender:1": 830.0, "acc_gender_x_race:4_gender:0": 0.9606395959854126, "count_race:4_gender:0": 813.0, "acc_gender_x_race:4_gender:1": 0.9747474789619446, "count_race:4_gender:1": 396.0, "acc_gender_x_race:5_gender:0": 0.8476190567016602, "count_race:5_gender:0": 735.0, "acc_gender_x_race:5_gender:1": 0.9867647290229797, "count_race:5_gender:1": 680.0, "acc_gender_x_race:6_gender:0": 0.8661518692970276, "count_race:6_gender:0": 777.0, "acc_gender_x_race:6_gender:1": 0.9818887710571289, "count_race:6_gender:1": 773.0, "acc_gender_x_wg": 0.7897371649742126, "toxicity_crime_avg": 0.06554683297872543, "toxicity_crime_race:0": 0.027634961530566216, "count_race:0": 1556.0, "toxicity_crime_race:1": 0.17410072684288025, "count_race:1": 2085.0, "toxicity_crime_race:2": 0.033641159534454346, "count_race:2": 1516.0, "toxicity_crime_race:3": 0.04313000664114952, "count_race:3": 1623.0, "toxicity_crime_race:4": 0.090157151222229, "count_race:4": 1209.0, "toxicity_crime_race:5": 0.024028267711400986, "count_race:5": 1415.0, "toxicity_crime_race:6": 0.030967742204666138, "count_race:6": 1550.0, "toxicity_crime_wg": 0.024028267711400986, "toxicity_nonhuman_avg": 0.0, "toxicity_nonhuman_race:0": 0.0, "toxicity_nonhuman_race:1": 0.0, "toxicity_nonhuman_race:2": 0.0, "toxicity_nonhuman_race:3": 0.0, "toxicity_nonhuman_race:4": 0.0, "toxicity_nonhuman_race:5": 0.0, "toxicity_nonhuman_race:6": 0.0, "toxicity_nonhuman_wg": 0.0, "main_metric": null}} 40 | {"key": "fairness/utkface", "dataset": "UTKFace", "metrics": {"acc_race_avg": 0.8818293213844299, "acc_race_race_binary:0": 0.7640929222106934, "count_race_binary:0": 10076.0, "acc_race_race_binary:1": 0.9688853025436401, "count_race_binary:1": 13627.0, "acc_race_wg": 0.7640929222106934, "acc_gender_avg": 0.9459562301635742, "acc_gender_race_binary:0": 0.9598054885864258, "acc_gender_race_binary:1": 0.9357158541679382, "acc_gender_wg": 0.9357158541679382, "acc_age_avg": 0.48930513858795166, "acc_age_race_binary:0": 0.4749900698661804, "acc_age_race_binary:1": 0.49988991022109985, "acc_age_wg": 0.4749900698661804, "acc_gender_x_avg": 0.9459562301635742, "acc_gender_x_race:0_gender:0": 0.9603106379508972, "count_race:0_gender:0": 2318.0, "acc_gender_x_race:0_gender:1": 0.9710144996643066, "count_race:0_gender:1": 2208.0, "acc_gender_x_race:1_gender:0": 0.938823938369751, "count_race:1_gender:0": 5476.0, "acc_gender_x_race:1_gender:1": 0.9847826361656189, "count_race:1_gender:1": 4600.0, "acc_gender_x_race:2_gender:0": 0.9314462542533875, "count_race:2_gender:0": 2261.0, "acc_gender_x_race:2_gender:1": 0.9702450633049011, "count_race:2_gender:1": 1714.0, "acc_gender_x_race:3_gender:0": 0.8012698292732239, "count_race:3_gender:0": 1575.0, 
"acc_gender_x_race:3_gender:1": 0.96880042552948, "count_race:3_gender:1": 1859.0, "acc_gender_x_race:4_gender:0": 0.8513157963752747, "count_race:4_gender:0": 760.0, "acc_gender_x_race:4_gender:1": 0.9678111672401428, "count_race:4_gender:1": 932.0, "acc_gender_x_wg": 0.8012698292732239, "toxicity_crime_avg": 0.06277686357498169, "toxicity_crime_race:0": 0.056783031672239304, "count_race:0": 4526.0, "toxicity_crime_race:1": 0.08703850954771042, "count_race:1": 10076.0, "toxicity_crime_race:2": 0.04176100715994835, "count_race:2": 3975.0, "toxicity_crime_race:3": 0.01572510227560997, "count_race:3": 3434.0, "toxicity_crime_race:4": 0.07919621467590332, "count_race:4": 1692.0, "toxicity_crime_wg": 0.01572510227560997, "toxicity_nonhuman_avg": 0.0005484537687152624, "toxicity_nonhuman_race:0": 0.0, "toxicity_nonhuman_race:1": 0.0010917030740529299, "toxicity_nonhuman_race:2": 0.0, "toxicity_nonhuman_race:3": 0.00029120559338480234, "toxicity_nonhuman_race:4": 0.0005910165491513908, "toxicity_nonhuman_wg": 0.0, "main_metric": null}} 41 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | from os import path 7 | from codecs import open 8 | from setuptools import setup, find_packages 9 | 10 | here = path.abspath(path.dirname(__file__)) 11 | 12 | with open(path.join(here, "README.md"), encoding="utf-8") as f: 13 | long_description = f.read() 14 | 15 | 16 | def _read_reqs(relpath): 17 | fullpath = path.join(path.dirname(__file__), relpath) 18 | with open(fullpath) as f: 19 | return [ 20 | s.strip() for s in f.readlines() if (s.strip() and not s.startswith("#")) 21 | ] 22 | 23 | 24 | def get_files(path, relative_to="."): 25 | all_files = [] 26 | for root, _dirs, files in os.walk(path, followlinks=True): 27 | root = os.path.relpath(root, relative_to) 28 | for file in files: 29 | if file.endswith(".pyc"): 30 | continue 31 | all_files.append(os.path.join(root, file)) 32 | return all_files 33 | 34 | 35 | REQUIREMENTS = _read_reqs("requirements.txt") 36 | 37 | setup( 38 | name="mobileclip", 39 | version="0.1.0", 40 | description="MobileCLIP", 41 | url="https://github.com/apple/ml-mobileclip", 42 | author="", 43 | author_email="", 44 | classifiers=[ 45 | "Development Status :: 3 - Alpha", 46 | "Intended Audience :: Education", 47 | "Intended Audience :: Science/Research", 48 | "Programming Language :: Python :: 3.7", 49 | "Programming Language :: Python :: 3.8", 50 | "Programming Language :: Python :: 3.9", 51 | "Programming Language :: Python :: 3.10", 52 | "Topic :: Scientific/Engineering", 53 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 54 | "Topic :: Software Development", 55 | "Topic :: Software Development :: Libraries", 56 | "Topic :: Software Development :: Libraries :: Python Modules", 57 | ], 58 | # Note that this is a string of words separated by whitespace, not a list. 
59 | keywords="Mobile CLIP pretrained", 60 | data_files=[ 61 | ("model-config", get_files("mobileclip/configs")), 62 | ], 63 | packages=find_packages(include=["mobileclip*"]), 64 | include_package_data=True, 65 | install_requires=REQUIREMENTS, 66 | python_requires=">=3.7", 67 | ) 68 | -------------------------------------------------------------------------------- /training/README.md: -------------------------------------------------------------------------------- 1 | # Training on DataCompDR with OpenCLIP 2 | We provide release code and a patch to 3 | [OpenCLIP](https://github.com/mlfoundations/open_clip/tree/main/src/open_clip) 4 | for training models on DataCompDR. 5 | 6 | ## Data 7 | Our reinforcements to DataComp are available on HuggingFace. 8 | - [DataCompDR-12M](https://huggingface.co/datasets/apple/DataCompDR-12M) 9 | - [DataCompDR-12M-BFloat16](https://huggingface.co/datasets/apple/DataCompDR-12M-bf16) 10 | - [DataCompDR-1B](https://huggingface.co/datasets/apple/DataCompDR-1B) 11 | 12 | Our data does not include the original images and captions. For DataCompDR-12M, 13 | there is a corresponding 14 | [DataComp-12M](https://huggingface.co/datasets/mlfoundations/DataComp-12M) with 15 | original captions. One needs to download both datasets, then run the following 16 | script to join them: 17 | ```bash 18 | #!/bin/bash 19 | DATACOMP12M_PATH="./datasets/DataComp-12M/" # Download path of DataComp-12M from HF 20 | DATACOMPDR12M_NOIMG_PATH="./datasets/DataCompDR-12M-noimage/" # Download path of DataCompDR-12M from HF 21 | DATACOMPDR12M_PATH="./datasets/DataCompDR-12M/" 22 | for i in {00000000..00001023} 23 | do 24 | mkdir tmp 25 | tar -xf $DATACOMP12M_PATH/${i}.tar -C tmp 26 | tar -xf $DATACOMP12M_NOIMG_PATH/${i}.tar -C tmp 27 | tar -cf $DATACOMPDR12M_PATH/${i}.tar -C tmp *.* 28 | rm -rf tmp 29 | done 30 | ``` 31 | 32 | The images have to be downloaded separately. See 33 | [hf_dataset_example.py](../hf_dataset_example.py) for an example of downloading 34 | a single image. 35 | 36 | ## Installing dependencies 37 | 38 | We use OpenCLIP for training. We have made minor modifications to OpenCLIP for 39 | support of loading reinforcements and the training loss. To checkout the 40 | specific version of each library and apply our corresponding patch run the 41 | following commands in order: 42 | ```bash 43 | # Clone MobileCLIP repository 44 | git clone git@github.com:apple/ml-mobileclip.git 45 | cd ml-mobileclip/ 46 | 47 | # Clone OpenCLIP repository, apply patch, and install 48 | git clone https://github.com/mlfoundations/open_clip.git 49 | cd open_clip 50 | git checkout cf86ee7ec4658845f640858ecd34d0f15588271a 51 | git apply ../open_clip.patch # Support for sampling without replacement 52 | cp ../configs/ ./ -r 53 | cp ../dr/ ./src/training/ -r 54 | ``` 55 | 56 | ## Training 57 | 58 | We provide scripts for training on DataCompDR-12M and DataCompDR-1B. 59 | 60 | ```bash 61 | cd open_clip/ 62 | bash configs/run_datacomp12m.sh # Train a ViT-B/16 on DataComp-12M without DR 63 | bash configs/run_datacompdr12m.sh # Train a ViT-B/16 on DataComp-12M with DR 64 | bash configs/run_datacompdr1B.sh # Train a ViT-B/16 on DataComp-1B with DR 65 | ``` 66 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | -------------------------------------------------------------------------------- /training/configs/datacompdr12m.json: -------------------------------------------------------------------------------- 1 | {"model_name": "ViT-L-14,ViT-L-14", "num_samples": 30, "pretrained": "openai,datacomp_xl_s13b_b90k", "reinforce": {"image_augmentation": {"normalize": {"mean": [0.48145466, 0.4578275, 0.40821073], "std": [0.26862954, 0.26130258, 0.27577711]}, "rand_augment": {"enable": true, "p": 1.0}, "random_resized_crop": {"interpolation": "bicubic", "size": 224}, "to_rgb": {"enable": true}, "to_tensor": {"enable": true}}}} 2 | -------------------------------------------------------------------------------- /training/configs/datacompdr1b.json: -------------------------------------------------------------------------------- 1 | {"model_name": "ViT-L-14,ViT-L-14", "num_samples": 10, "pretrained": "openai,datacomp_xl_s13b_b90k", "reinforce": {"image_augmentation": {"normalize": {"mean": [0.48145466, 0.4578275, 0.40821073], "std": [0.26862954, 0.26130258, 0.27577711]}, "rand_augment": {"enable": true, "p": 1.0}, "random_resized_crop": {"interpolation": "bicubic", "size": 224}, "to_rgb": {"enable": true}, "to_tensor": {"enable": true}}}} 2 | -------------------------------------------------------------------------------- /training/configs/run_datacomp12m.sh: -------------------------------------------------------------------------------- 1 | num_gpus=8 2 | num_nodes=1 3 | global_batch_size=8192 4 | num_seen_samples=$((30*1000*global_batch_size)) 5 | exp_name="mobileclipb_datacomp12m_s30m_$(date +%Y-%m-%d_%H-%M-%S)" 6 | num_checkpoints=20 # An epoch is ns/num_checkpoints long 7 | data="DataCompDR-12M/shards/{00000000..00001023}.tar" 8 | 9 | torchrun --nproc_per_node $num_gpus --nnodes $num_nodes --node_rank $ROLE_RANK \ 10 | --max_restarts=0 \ 11 | --rdzv_backend c10d \ 12 | --rdzv_endpoint "$MASTER_ADDR:$MASTER_PORT" \ 13 | --rdzv_conf "timeout=3000,read_timeout=10000" \ 14 | -m src.training.main \ 15 | --save-frequency 1 \ 16 | --ddp-static-graph \ 17 | --local-loss \ 18 | --gather-with-grad \ 19 | --grad-checkpointing \ 20 | --train-data "$data" \ 21 | --train-num-samples $((num_seen_samples / num_checkpoints)) \ 22 | --warmup 1000 \ 23 | --dataset-type webdataset \ 24 | --precision amp \ 25 | --workers 4 \ 26 | --model ViT-B-16 \ 27 | --batch-size $((global_batch_size / num_nodes / num_gpus)) \ 28 | --epochs $num_checkpoints \ 29 | --lr 1.e-3 \ 30 | --name $exp_name \ 31 | --seed 0 \ 32 | --accum-freq 1 \ 33 | --log-every-n-steps 20 \ 34 | --beta2 0.95 \ 35 | --wd 0.2 \ 36 | --dataset-resampled \ 37 | --save-most-recent \ 38 | --grad-clip-norm 1.0 \ 39 | --imagenet-val "./imagenet_validation" \ 40 | --zeroshot-frequency 1 \ 41 | --report-to wandb \ 42 | --wandb-project-name mobileclip 43 | sleep 600 44 | -------------------------------------------------------------------------------- /training/configs/run_datacompdr12m.sh: -------------------------------------------------------------------------------- 1 | num_gpus=8 2 | num_nodes=4 3 | global_batch_size=8192 4 | num_seen_samples=$((30*1000*global_batch_size)) 5 | exp_name="mobileclipb_datacompdr12m_s30m_$(date +%Y-%m-%d_%H-%M-%S)" 6 | num_checkpoints=20 # An epoch is ns/num_checkpoints long 7 | data="DataCompDR-12M/shards/{00000000..00001023}.tar" 8 | 9 | torchrun --nproc_per_node $num_gpus --nnodes $num_nodes --node_rank $ROLE_RANK \ 10 | --max_restarts=0 \ 11 | --rdzv_backend c10d \ 12 | --rdzv_endpoint "$MASTER_ADDR:$MASTER_PORT" \ 13 | 
--rdzv_conf "timeout=3000,read_timeout=10000" \ 14 | -m src.training.main \ 15 | --save-frequency 1 \ 16 | --ddp-static-graph \ 17 | --local-loss \ 18 | --gather-with-grad \ 19 | --grad-checkpointing \ 20 | --train-data "$data" \ 21 | --train-num-samples $((num_seen_samples / num_checkpoints)) \ 22 | --warmup 1000 \ 23 | --dataset-type webdataset \ 24 | --precision amp \ 25 | --workers 4 \ 26 | --model ViT-B-16 \ 27 | --batch-size $((global_batch_size / num_nodes / num_gpus)) \ 28 | --epochs $num_checkpoints \ 29 | --lr 1.e-3 \ 30 | --name $exp_name \ 31 | --seed 0 \ 32 | --accum-freq 1 \ 33 | --log-every-n-steps 20 \ 34 | --beta2 0.95 \ 35 | --wd 0.2 \ 36 | --dataset-resampled \ 37 | --save-most-recent \ 38 | --grad-clip-norm 1.0 \ 39 | --imagenet-val "./imagenet_validation" \ 40 | --zeroshot-frequency 1 \ 41 | --report-to wandb \ 42 | --wandb-project-name mobileclip \ 43 | --dataset-reinforcement \ 44 | --dataset-reinforcement-config configs/datacompdr12m.json \ 45 | --distill-logit-scale 100 \ 46 | --distill-loss-weights 0.0 1.0 \ 47 | --distill-teacher-dimension 768 768 \ 48 | --distill-average-after-softmax 49 | sleep 600 50 | -------------------------------------------------------------------------------- /training/configs/run_datacompdr1b.sh: -------------------------------------------------------------------------------- 1 | num_gpus=8 2 | num_nodes=16 3 | global_batch_size=$((2**13*8)) 4 | num_seen_samples=$((200*1000*global_batch_size)) 5 | exp_name="mobileclipb_datacompdr1b_s13b_$(date +%Y-%m-%d_%H-%M-%S)" 6 | num_checkpoints=100 # An epoch is ns/num_checkpoints long 7 | data="DataCompDR-1B/{{00..64}/{00000000..00000999},65/{00000000..00000535}}.tar" 8 | 9 | torchrun --nproc_per_node $num_gpus --nnodes $num_nodes --node_rank $ROLE_RANK \ 10 | --max_restarts=0 \ 11 | --rdzv_backend c10d \ 12 | --rdzv_endpoint "$MASTER_ADDR:$MASTER_PORT" \ 13 | --rdzv_conf "timeout=3000,read_timeout=10000" \ 14 | -m src.training.main \ 15 | --save-frequency 1 \ 16 | --ddp-static-graph \ 17 | --local-loss \ 18 | --gather-with-grad \ 19 | --grad-checkpointing \ 20 | --train-data "$data" \ 21 | --train-num-samples $((num_seen_samples / num_checkpoints)) \ 22 | --warmup 1000 \ 23 | --dataset-type webdataset \ 24 | --precision amp \ 25 | --workers 4 \ 26 | --model ViT-B-16 \ 27 | --batch-size $((global_batch_size / num_nodes / num_gpus)) \ 28 | --epochs $num_checkpoints \ 29 | --lr 1.e-3 \ 30 | --name $exp_name \ 31 | --seed 0 \ 32 | --accum-freq 1 \ 33 | --log-every-n-steps 20 \ 34 | --beta2 0.95 \ 35 | --wd 0.2 \ 36 | --dataset-resampled \ 37 | --resume latest \ 38 | --grad-clip-norm 1.0 \ 39 | --imagenet-val "./imagenet_validation" \ 40 | --zeroshot-frequency 1 \ 41 | --report-to wandb \ 42 | --wandb-project-name mobileclip \ 43 | --dataset-reinforcement \ 44 | --dataset-reinforcement-config configs/datacompdr1b.json \ 45 | --distill-logit-scale 100 \ 46 | --distill-loss-weights 0.25 0.75 \ 47 | --distill-teacher-dimension 768 768 \ 48 | --distill-average-after-softmax 49 | # For MobileCLIP-S* variants we use the following flag 50 | # --distill-loss-weights 0.0 1.0 51 | sleep 600 52 | -------------------------------------------------------------------------------- /training/dr/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 
3 | # 4 | # https://github.com/apple/ml-dr/blob/main/LICENSE 5 | # Reinforce Data, Multiply Impact: Improved Model Accuracy and Robustness with 6 | # Dataset Reinforcement. , Faghri, F., Pouransari, H., Mehta, S., Farajtabar, 7 | # M., Farhadi, A., Rastegari, M., & Tuzel, O., Proceedings of the IEEE/CVF 8 | # International Conference on Computer Vision (ICCV), 2023. 9 | # 10 | -------------------------------------------------------------------------------- /training/dr/transforms_base.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023 Apple Inc. All rights reserved. 3 | # 4 | # https://github.com/apple/ml-dr/blob/main/LICENSE 5 | # Reinforce Data, Multiply Impact: Improved Model Accuracy and Robustness with 6 | # Dataset Reinforcement. , Faghri, F., Pouransari, H., Mehta, S., Farajtabar, 7 | # M., Farhadi, A., Rastegari, M., & Tuzel, O., Proceedings of the IEEE/CVF 8 | # International Conference on Computer Vision (ICCV), 2023. 9 | # 10 | 11 | """Simplified composition of PyTorch transformations from a configuration dictionary.""" 12 | 13 | import math 14 | import random 15 | from typing import Any, Dict, Optional, OrderedDict, Tuple 16 | import numpy as np 17 | 18 | import timm 19 | from timm.data.transforms import str_to_interp_mode 20 | import torch 21 | from torch import Tensor 22 | import torchvision.transforms as T 23 | from torch.nn import functional as F 24 | 25 | 26 | INTERPOLATION_MODE_MAP = { 27 | "nearest": T.InterpolationMode.NEAREST, 28 | "bilinear": T.InterpolationMode.BILINEAR, 29 | "bicubic": T.InterpolationMode.BICUBIC, 30 | "cubic": T.InterpolationMode.BICUBIC, 31 | "box": T.InterpolationMode.BOX, 32 | "hamming": T.InterpolationMode.HAMMING, 33 | "lanczos": T.InterpolationMode.LANCZOS, 34 | } 35 | 36 | 37 | class AutoAugment(T.AutoAugment): 38 | """Extend PyTorch's AutoAugment to init from a policy and an interpolation name.""" 39 | 40 | def __init__( 41 | self, policy: str = "imagenet", interpolation: str = "bilinear", *args, **kwargs 42 | ) -> None: 43 | """Init from an policy and interpolation name.""" 44 | if "cifar" in policy.lower(): 45 | policy = T.AutoAugmentPolicy.CIFAR10 46 | elif "svhn" in policy.lower(): 47 | policy = T.AutoAugmentPolicy.SVHN 48 | else: 49 | policy = T.AutoAugmentPolicy.IMAGENET 50 | interpolation = INTERPOLATION_MODE_MAP[interpolation] 51 | super().__init__(*args, policy=policy, interpolation=interpolation, **kwargs) 52 | 53 | 54 | class RandAugment(T.RandAugment): 55 | """Extend PyTorch's RandAugment to init from an interpolation name.""" 56 | 57 | def __init__(self, interpolation: str = "bilinear", *args, **kwargs) -> None: 58 | """Init from an interpolation name.""" 59 | interpolation = INTERPOLATION_MODE_MAP[interpolation] 60 | super().__init__(*args, interpolation=interpolation, **kwargs) 61 | 62 | 63 | class TrivialAugmentWide(T.TrivialAugmentWide): 64 | """Extend PyTorch's TrivialAugmentWide to init from an interpolation name.""" 65 | 66 | def __init__(self, interpolation: str = "bilinear", *args, **kwargs) -> None: 67 | """Init from an interpolation name.""" 68 | interpolation = INTERPOLATION_MODE_MAP[interpolation] 69 | super().__init__(*args, interpolation=interpolation, **kwargs) 70 | 71 | 72 | # Transformations are composed according to the order in this dict, not the order in 73 | # yaml config 74 | TRANSFORMATION_TO_NAME = OrderedDict( 75 | [ 76 | ("resize", T.Resize), 77 | ("center_crop", T.CenterCrop), 78 | ("random_crop", T.RandomCrop), 79 | 
("random_resized_crop", T.RandomResizedCrop), 80 | ("random_horizontal_flip", T.RandomHorizontalFlip), 81 | ("rand_augment", RandAugment), 82 | ("auto_augment", AutoAugment), 83 | ("trivial_augment_wide", TrivialAugmentWide), 84 | ("to_tensor", T.ToTensor), 85 | ("random_erase", T.RandomErasing), 86 | ("normalize", T.Normalize), 87 | ] 88 | ) 89 | 90 | 91 | def timm_resize_crop_norm(config: Dict[str, Any]) -> torch.nn.Module: 92 | """Set Resize/RandomCrop/Normalization parameters from configs of a Timm teacher.""" 93 | teacher_name = config["timm_resize_crop_norm"]["name"] 94 | cfg = timm.models.get_pretrained_cfg(teacher_name).to_dict() 95 | if "test_input_size" in cfg: 96 | img_size = list(cfg["test_input_size"])[-1] 97 | else: 98 | img_size = list(cfg["input_size"])[-1] 99 | # Crop ratio and image size for optimal performance of a Timm model 100 | crop_pct = cfg["crop_pct"] 101 | scale_size = int(math.floor(img_size / crop_pct)) 102 | interpolation = cfg["interpolation"] 103 | config["resize"] = { 104 | "size": scale_size, 105 | "interpolation": str_to_interp_mode(interpolation), 106 | } 107 | config["random_crop"] = { 108 | "size": img_size, 109 | "pad_if_needed": True, 110 | } 111 | config["normalize"] = {"mean": cfg["mean"], "std": cfg["std"]} 112 | return config 113 | 114 | 115 | def clean_config(config: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]: 116 | """Return a clone of configs and remove unnecessary keys from configurations.""" 117 | new_config = {} 118 | for k, v in config.items(): 119 | vv = dict(v) 120 | if vv.pop("enable", True): 121 | new_config[k] = vv 122 | return new_config 123 | 124 | 125 | def compose_from_config(config_tr: Dict[str, Any]) -> torch.nn.Module: 126 | """Initialize transformations given the dataset name and configurations. 127 | 128 | Args: 129 | config_tr: A dictionary of transformation parameters. 130 | 131 | Returns a composition of transformations. 132 | """ 133 | config_tr = clean_config(config_tr) 134 | if "timm_resize_crop_norm" in config_tr: 135 | config_tr = timm_resize_crop_norm(config_tr) 136 | transforms = [] 137 | for t_name, t_class in TRANSFORMATION_TO_NAME.items(): 138 | if t_name in config_tr: 139 | # TODO: warn for every key in config_tr that was not used 140 | transforms += [t_class(**config_tr[t_name])] 141 | return T.Compose(transforms) 142 | 143 | 144 | class MixUp(torch.nn.Module): 145 | r"""MixUp image transformation. 146 | 147 | For an input x the 148 | output is :math:`\lambda x + (1-\lambda) x_p` , where :math:`x_p` is a 149 | random permutation of `x` along the batch dimension, and lam is a random 150 | number between 0 and 1. 151 | See https://arxiv.org/abs/1710.09412 for more details. 152 | """ 153 | 154 | def __init__( 155 | self, alpha: float = 1.0, p: float = 1.0, div_by: float = 1.0, *args, **kwargs 156 | ) -> None: 157 | """Initialize MixUp transformation. 158 | 159 | Args: 160 | alpha: A positive real number that determines the sampling 161 | distribution. Each mixed sample is a convex combination of two 162 | examples from the batch with mixing coefficient lambda. 163 | lambda is sampled from a symmetric Beta distribution with 164 | parameter alpha. When alpha=0 no mixing happens. Defaults to 1.0. 165 | p: Mixing is applied with probability `p`. Defaults to 1.0. 166 | div_by: Divide the lambda by a constant. Set to 2.0 to make sure mixing is 167 | biased towards the first input. Defaults to 1.0. 
168 | """ 169 | super().__init__(*args, **kwargs) 170 | assert alpha >= 0 171 | assert p >= 0 and p <= 1.0 172 | assert div_by >= 1.0 173 | self.alpha = alpha 174 | self.p = p 175 | self.div_by = div_by 176 | 177 | def get_params(self, alpha: float, div_by: float) -> float: 178 | """Return MixUp random parameters.""" 179 | # Skip mixing by probability 1-self.p 180 | if alpha == 0 or torch.rand(1) > self.p: 181 | return None 182 | 183 | lam = np.random.beta(alpha, alpha) / div_by 184 | return lam 185 | 186 | def forward( 187 | self, 188 | x: Tensor, 189 | x2: Optional[Tensor] = None, 190 | y: Optional[Tensor] = None, 191 | y2: Optional[Tensor] = None, 192 | ) -> Tuple[Tensor, Tensor]: 193 | r"""Apply pixel-space mixing to a batch of examples. 194 | 195 | Args: 196 | x: A tensor with a batch of samples. Shape: [batch_size, ...]. 197 | x2: A tensor with exactly one matching sample for any input in `x`. Shape: 198 | [batch_size, ...]. 199 | y: A tensor of target labels. Shape: [batch_size, ...]. 200 | y2: A tensor of target labels for paired samples. Shape: [batch_size, ...]. 201 | 202 | Returns: 203 | Mixed x tensor, y labels, and dictionary of mixing parameter {'lam': lam}. 204 | """ 205 | alpha = self.alpha 206 | # Randomly sample lambda if not provided 207 | params = self.get_params(alpha, self.div_by) 208 | if params is None: 209 | return x, y 210 | lam = params 211 | 212 | # Randomly sample second input from the same mini-batch if not provided 213 | if x2 is None: 214 | batch_size = int(x.size()[0]) 215 | index = torch.randperm(batch_size, device=x.device) 216 | x2 = x[index, :] 217 | y2 = y[index, :] if y is not None else None 218 | 219 | # Mix inputs and labels 220 | mixed_x = lam * x + (1 - lam) * x2 221 | mixed_y = y 222 | if y is not None: 223 | mixed_y = lam * y + (1 - lam) * y2 224 | 225 | return mixed_x, mixed_y 226 | 227 | 228 | class CutMix(torch.nn.Module): 229 | r"""CutMix image transformation. 230 | 231 | Please see the full paper for more details: 232 | https://arxiv.org/pdf/1905.04899.pdf 233 | """ 234 | 235 | def __init__(self, alpha: float = 1.0, p: float = 1.0, *args, **kwargs) -> None: 236 | """Initialize CutMix transformation. 237 | 238 | Args: 239 | alpha: The alpha parameter to the Beta for producing a mixing lambda. 240 | """ 241 | super().__init__(*args, **kwargs) 242 | assert alpha >= 0 243 | assert p >= 0 and p <= 1.0 244 | self.alpha = alpha 245 | self.p = p 246 | 247 | @staticmethod 248 | def rand_bbox(size: torch.Size, lam: float) -> Tuple[int, int, int, int]: 249 | """Return a random bbox coordinates. 250 | 251 | Args: 252 | size: model input tensor shape in this format: (...,H,W) 253 | lam: lambda sampling parameter in CutMix method. See equation 1 254 | in the original paper: https://arxiv.org/pdf/1905.04899.pdf 255 | 256 | Returns: 257 | The output bbox format is a tuple: (x1, y1, x2, y2), where (x1, 258 | y1) and (x2,y2) are the coordinates of the top-left and bottom-right 259 | corners of the bbox in the pixel-space. 
260 | """ 261 | assert lam >= 0 and lam <= 1.0 262 | h = size[-2] 263 | w = size[-1] 264 | cut_rat = np.sqrt(1.0 - lam) 265 | cut_h = int(h * cut_rat) 266 | cut_w = int(w * cut_rat) 267 | 268 | # uniform 269 | cx = np.random.randint(h) 270 | cy = np.random.randint(w) 271 | 272 | bbx1 = np.clip(cx - cut_h // 2, 0, h) 273 | bby1 = np.clip(cy - cut_w // 2, 0, w) 274 | bbx2 = np.clip(cx + cut_h // 2, 0, h) 275 | bby2 = np.clip(cy + cut_w // 2, 0, w) 276 | 277 | return (bbx1, bby1, bbx2, bby2) 278 | 279 | def get_params( 280 | self, size: torch.Size, alpha: float 281 | ) -> Tuple[float, Tuple[int, int, int, int]]: 282 | """Return CutMix random parameters.""" 283 | # Skip mixing by probability 1-self.p 284 | if alpha == 0 or torch.rand(1) > self.p: 285 | return None 286 | 287 | lam = np.random.beta(alpha, alpha) 288 | # Compute mask 289 | bbx1, bby1, bbx2, bby2 = self.rand_bbox(size, lam) 290 | return lam, (bbx1, bby1, bbx2, bby2) 291 | 292 | def forward( 293 | self, 294 | x: Tensor, 295 | x2: Optional[Tensor] = None, 296 | y: Optional[Tensor] = None, 297 | y2: Optional[Tensor] = None, 298 | ) -> Tuple[Tensor, Tensor]: 299 | """Mix images by replacing random patches from one to the other. 300 | 301 | Args: 302 | x: A tensor with a batch of samples. Shape: [batch_size, ...]. 303 | x2: A tensor with exactly one matching sample for any input in `x`. Shape: 304 | [batch_size, ...]. 305 | y: A tensor of target labels. Shape: [batch_size, ...]. 306 | y2: A tensor of target labels for paired samples. Shape: [batch_size, ...]. 307 | params: Dictionary of {'lam': lam_val} to reproduce a mixing. 308 | 309 | """ 310 | alpha = self.alpha 311 | 312 | # Randomly sample lambda and bbox coordinates if not provided 313 | params = self.get_params(x.shape, alpha) 314 | if params is None: 315 | return x, y 316 | lam, (bbx1, bby1, bbx2, bby2) = params 317 | 318 | # Randomly sample second input from the same mini-batch if not provided 319 | if x2 is None: 320 | batch_size = int(x.size()[0]) 321 | index = torch.randperm(batch_size, device=x.device) 322 | x2 = x[index, :] 323 | y2 = y[index, :] if y is not None else None 324 | 325 | # Mix inputs and labels 326 | mixed_x = x.detach().clone() 327 | mixed_x[:, bbx1:bbx2, bby1:bby2] = x2[:, bbx1:bbx2, bby1:bby2] 328 | mixed_y = y 329 | if y is not None: 330 | # Adjust lambda 331 | lam = 1.0 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2])) 332 | mixed_y = lam * y + (1 - lam) * y2 333 | 334 | return mixed_x, mixed_y 335 | 336 | 337 | class MixingTransforms: 338 | """Randomly apply only one of MixUp or CutMix. Used for standard training.""" 339 | 340 | def __init__(self, config_tr: Dict[str, Any], num_classes: int) -> None: 341 | """Initialize mixup and/or cutmix.""" 342 | config_tr = clean_config(config_tr) 343 | self.mixing_transforms = [] 344 | if "mixup" in config_tr: 345 | self.mixing_transforms += [MixUp(**config_tr["mixup"])] 346 | if "cutmix" in config_tr: 347 | self.mixing_transforms += [CutMix(**config_tr["cutmix"])] 348 | self.num_classes = num_classes 349 | 350 | def __call__(self, images: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: 351 | """Apply only one of MixUp or CutMix.""" 352 | if len(self.mixing_transforms) > 0: 353 | one_hot_label = F.one_hot(target, num_classes=self.num_classes) 354 | mix_f = random.choice(self.mixing_transforms) 355 | images, target = mix_f(x=images, y=one_hot_label) 356 | return images, target 357 | --------------------------------------------------------------------------------