├── .gitignore
├── ACKNOWLEDGEMENTS
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── LICENSE_weights_data
├── README.md
├── docs
│   └── fig_accuracy_latency.png
├── eval
│   └── zeroshot_imagenet.py
├── get_pretrained_models.sh
├── hf_dataset_example.py
├── ios_app
│   ├── ACKNOWLEDGEMENTS
│   ├── MobileCLIPExplore.xcodeproj
│   │   ├── project.pbxproj
│   │   ├── project.xcworkspace
│   │   │   ├── contents.xcworkspacedata
│   │   │   ├── xcshareddata
│   │   │   │   └── IDEWorkspaceChecks.plist
│   │   │   └── xcuserdata
│   │   │       └── pavankumaranasosaluvasu.xcuserdatad
│   │   │           └── UserInterfaceState.xcuserstate
│   │   └── xcuserdata
│   │       └── pavankumaranasosaluvasu.xcuserdatad
│   │           └── xcschemes
│   │               └── xcschememanagement.plist
│   ├── MobileCLIPExplore
│   │   ├── Assets.xcassets
│   │   │   ├── AccentColor.colorset
│   │   │   │   └── Contents.json
│   │   │   ├── AddClassName.colorset
│   │   │   │   └── Contents.json
│   │   │   ├── AppIcon.appiconset
│   │   │   │   └── Contents.json
│   │   │   ├── Contents.json
│   │   │   ├── Preset.colorset
│   │   │   │   └── Contents.json
│   │   │   ├── PresetDisabled.colorset
│   │   │   │   └── Contents.json
│   │   │   ├── PromptEditor.colorset
│   │   │   │   └── Contents.json
│   │   │   └── PromptEditorBackground.colorset
│   │   │       └── Contents.json
│   │   ├── AsyncFactory.swift
│   │   ├── CameraController.swift
│   │   ├── Configuration
│   │   │   └── Build.xcconfig
│   │   ├── ContentView.swift
│   │   ├── DisplayPrediction.swift
│   │   ├── Extensions.swift
│   │   ├── MobileCLIPExploreApp.swift
│   │   ├── Models.swift
│   │   ├── Models
│   │   │   └── README.md
│   │   ├── Preview Content
│   │   │   └── Preview Assets.xcassets
│   │   │       └── Contents.json
│   │   ├── Prompt.swift
│   │   ├── Resources
│   │   │   ├── clip-merges.txt
│   │   │   └── clip-vocab.json
│   │   ├── Tokenizer
│   │   │   ├── CLIPTokenizer.swift
│   │   │   ├── GPT2ByteEncoder.swift
│   │   │   └── Utils.swift
│   │   ├── Views
│   │   │   ├── Helpers
│   │   │   │   ├── PredictionsTable.swift
│   │   │   │   ├── PromptPreview.swift
│   │   │   │   └── VideoFrameView.swift
│   │   │   ├── InferenceView.swift
│   │   │   ├── InfoView.swift
│   │   │   └── PromptEditor.swift
│   │   └── ZSImageClassification.swift
│   ├── README.md
│   └── docs
│       └── app_screenshots
│           ├── examples.png
│           └── model_directory_guide.png
├── mobileclip
│   ├── __init__.py
│   ├── clip.py
│   ├── configs
│   │   ├── mobileclip_b.json
│   │   ├── mobileclip_s0.json
│   │   ├── mobileclip_s1.json
│   │   └── mobileclip_s2.json
│   ├── image_encoder.py
│   ├── logger.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── mci.py
│   │   └── vit.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── common
│   │   │   ├── __init__.py
│   │   │   ├── mobileone.py
│   │   │   └── transformer.py
│   │   ├── image
│   │   │   ├── __init__.py
│   │   │   ├── image_projection.py
│   │   │   └── replknet.py
│   │   └── text
│   │       ├── __init__.py
│   │       ├── repmixer.py
│   │       └── tokenizer.py
│   └── text_encoder.py
├── requirements.txt
├── results
│   ├── mobileclip_b.jsonl
│   ├── mobileclip_b_lt.jsonl
│   ├── mobileclip_s0.jsonl
│   ├── mobileclip_s1.jsonl
│   └── mobileclip_s2.jsonl
├── setup.py
└── training
    ├── README.md
    ├── __init__.py
    ├── configs
    │   ├── datacompdr12m.json
    │   ├── datacompdr1b.json
    │   ├── run_datacomp12m.sh
    │   ├── run_datacompdr12m.sh
    │   └── run_datacompdr1b.sh
    ├── dr
    │   ├── __init__.py
    │   ├── transforms.py
    │   └── transforms_base.py
    └── open_clip.patch
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # macOS
132 | **/.DS_Store
133 |
134 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to making participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies within all project spaces, and it also applies when
49 | an individual is representing the project or its community in public spaces.
50 | Examples of representing a project or community include using an official
51 | project e-mail address, posting via an official social media account, or acting
52 | as an appointed representative at an online or offline event. Representation of
53 | a project may be further defined and clarified by project maintainers.
54 |
55 | ## Enforcement
56 |
57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All
59 | complaints will be reviewed and investigated and will result in a response that
60 | is deemed necessary and appropriate to the circumstances. The project team is
61 | obligated to maintain confidentiality with regard to the reporter of an incident.
62 | Further details of specific enforcement policies may be posted separately.
63 |
64 | Project maintainers who do not follow or enforce the Code of Conduct in good
65 | faith may face temporary or permanent repercussions as determined by other
66 | members of the project's leadership.
67 |
68 | ## Attribution
69 |
70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4,
71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html)
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contribution Guide
2 |
3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository.
4 |
5 | While we welcome new pull requests and issues, please note that our responses may be limited. Forks and out-of-tree improvements are strongly encouraged.
6 |
7 | ## Before you get started
8 |
9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE).
10 |
11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md).
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright © 2024 Apple Inc.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
23 | -------------------------------------------------------------------------------
24 | SOFTWARE DISTRIBUTED WITH ML-MobileCLIP:
25 |
26 | The ML-MobileCLIP model weights and data copyright and license terms can be
27 | found in LICENSE_weights_data.
28 |
29 | The ML-MobileCLIP software includes a number of subcomponents with separate
30 | copyright notices and license terms - please see the file ACKNOWLEDGEMENTS.
31 | -------------------------------------------------------------------------------
32 |
--------------------------------------------------------------------------------
/LICENSE_weights_data:
--------------------------------------------------------------------------------
1 | ML-MobileCLIP Model Weights and Data
2 |
3 | Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 |
5 | IMPORTANT: This Apple software is supplied to you by Apple
6 | Inc. ("Apple") in consideration of your agreement to the following
7 | terms, and your use, installation, modification or redistribution of
8 | this Apple software constitutes acceptance of these terms. If you do
9 | not agree with these terms, please do not use, install, modify or
10 | redistribute this Apple software.
11 |
12 | In consideration of your agreement to abide by the following terms, and
13 | subject to these terms, Apple grants you a personal, non-exclusive
14 | license, under Apple's copyrights in this original Apple software (the
15 | "Apple Software"), to use, reproduce, modify and redistribute the Apple
16 | Software, with or without modifications, in source and/or binary forms;
17 | provided that if you redistribute the Apple Software in its entirety and
18 | without modifications, you must retain this notice and the following
19 | text and disclaimers in all such redistributions of the Apple Software.
20 | Neither the name, trademarks, service marks or logos of Apple Inc. may
21 | be used to endorse or promote products derived from the Apple Software
22 | without specific prior written permission from Apple. Except as
23 | expressly stated in this notice, no other rights or licenses, express or
24 | implied, are granted by Apple herein, including but not limited to any
25 | patent rights that may be infringed by your derivative works or by other
26 | works in which the Apple Software may be incorporated.
27 |
28 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE
29 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
30 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
31 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
32 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
33 |
34 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
35 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
36 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
37 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
38 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
39 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
40 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
41 | POSSIBILITY OF SUCH DAMAGE.
42 |
43 | -------------------------------------------------------------------------------
44 | SOFTWARE DISTRIBUTED WITH ML-MobileCLIP:
45 |
46 | The ML-MobileCLIP software copyright and license terms can be found in LICENSE.
47 |
48 | The ML-MobileCLIP software includes a number of subcomponents with separate
49 | copyright notices and license terms - please see the file ACKNOWLEDGEMENTS.
50 | -------------------------------------------------------------------------------
51 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MobileCLIP: Fast Image-Text Models through Multi-Modal Reinforced Training
2 |
3 | This is the official repository of
4 | **[MobileCLIP: Fast Image-Text Models through Multi-Modal Reinforced Training](https://arxiv.org/pdf/2311.17049.pdf). (CVPR 2024)**
5 | *Pavan Kumar Anasosalu Vasu, Hadi Pouransari, Fartash Faghri, Raviteja Vemulapalli, Oncel Tuzel.*
6 | The repository contains code for inference, training, and evaluation of MobileCLIP models trained on DataCompDR datasets.
7 |
8 | [//]: # ()
9 |
10 |
11 |
12 |
13 | - **Update 2024/11/22:** Releasing iOS app to demonstrate the use of our model for real-time zero-shot image classification. See [ios_app](./ios_app/).
14 | - **Update 2024/06/13:** Releasing the code and scripts to train using [OpenCLIP](https://github.com/mlfoundations/open_clip/tree/main/src/open_clip) on DataCompDR datasets. See [training/](./training/).
15 | - **Update 2024/06/13:** MobileCLIP models and DataCompDR datasets are now available on HuggingFace in [MobileCLIP/DataCompDR Collection](https://huggingface.co/collections/apple/mobileclip-models-datacompdr-data-665789776e1aa2b59f35f7c8).
16 |
17 | ### Highlights
18 | * Our smallest variant `MobileCLIP-S0` obtains zero-shot performance similar to [OpenAI](https://arxiv.org/abs/2103.00020)'s ViT-B/16 model while being 4.8x faster and 2.8x smaller.
19 | * `MobileCLIP-S2` obtains better average zero-shot performance than [SigLIP](https://arxiv.org/abs/2303.15343)'s ViT-B/16 model while being 2.3x faster, 2.1x smaller, and trained with 3x fewer seen samples.
20 | * `MobileCLIP-B` (LT) attains a zero-shot ImageNet accuracy of **77.2%**, which is significantly better than recent works with similar architectures such as [DFN](https://arxiv.org/abs/2309.17425) and [SigLIP](https://arxiv.org/abs/2303.15343), and even better than [OpenAI's ViT-L/14@336](https://arxiv.org/abs/2103.00020).
21 | * iOS app to demonstrate the superior performance of our model on a mobile device.
22 |
23 | ![Accuracy-vs-latency comparison of MobileCLIP models](docs/fig_accuracy_latency.png)
24 |
25 | ## Getting Started
26 |
27 | ### Setup
28 | ```bash
29 | conda create -n clipenv python=3.10
30 | conda activate clipenv
31 | pip install -e .
32 | ```
33 | To download the pretrained checkpoints, run the snippet below:
34 | ```bash
35 | source get_pretrained_models.sh # Files will be downloaded to `checkpoints` directory.
36 | ```
37 |
38 | ### Usage Example
39 | To use models from the official repo, follow the code snippet below:
40 | ```python
41 | import torch
42 | from PIL import Image
43 | import mobileclip
44 |
45 | model, _, preprocess = mobileclip.create_model_and_transforms('mobileclip_s0', pretrained='/path/to/mobileclip_s0.pt')
46 | tokenizer = mobileclip.get_tokenizer('mobileclip_s0')
47 |
48 | image = preprocess(Image.open("docs/fig_accuracy_latency.png").convert('RGB')).unsqueeze(0)
49 | text = tokenizer(["a diagram", "a dog", "a cat"])
50 |
51 | with torch.no_grad(), torch.cuda.amp.autocast():
52 | image_features = model.encode_image(image)
53 | text_features = model.encode_text(text)
54 | image_features /= image_features.norm(dim=-1, keepdim=True)
55 | text_features /= text_features.norm(dim=-1, keepdim=True)
56 |
57 | text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
58 |
59 | print("Label probs:", text_probs)
60 | ```
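For deployment or latency measurements, the model should be put into its inference form first. A minimal sketch, reusing the `reparameterize_model` helper that the OpenCLIP section below also imports (applied here to the natively loaded `model` from the snippet above):

```python
from mobileclip.modules.common.mobileone import reparameterize_model

# Fold train-time branches into their inference-time equivalents
# before benchmarking or exporting the model.
model.eval()
model = reparameterize_model(model)
```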
61 |
62 | For an example of loading the data from HuggingFace see
63 | [hf_dataset_example.py](./hf_dataset_example.py).
64 |
65 | ### OpenCLIP Support
66 | Our models are now natively supported in OpenCLIP. To use MobileCLIP models in OpenCLIP, set up your environment as shown below:
67 | ```bash
68 | conda create -n clipenv python=3.10
69 | conda activate clipenv
70 |
71 | pip install git+https://github.com/mlfoundations/open_clip
72 | pip install git+https://github.com/huggingface/pytorch-image-models
73 | ```
74 |
75 | To run inference, see the example below:
76 | ```python
77 | import open_clip
78 | from mobileclip.modules.common.mobileone import reparameterize_model
79 |
80 | model, _, preprocess = open_clip.create_model_and_transforms('MobileCLIP-S2', pretrained='datacompdr')
81 | tokenizer = open_clip.get_tokenizer('MobileCLIP-S2')
82 |
83 | # For inference/model exporting purposes, please reparameterize first
84 | model.eval()
85 | model = reparameterize_model(model)
86 |
87 | # ... follow examples in open_clip repo ...
88 | ```
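As a possible continuation of the snippet above (a sketch reusing the `model`, `preprocess`, and `tokenizer` objects it creates), zero-shot label probabilities can be computed the same way as in the native usage example earlier in this README:

```python
import torch
from PIL import Image

# Placeholder input image; substitute your own file.
image = preprocess(Image.open("docs/fig_accuracy_latency.png").convert("RGB")).unsqueeze(0)
text = tokenizer(["a diagram", "a dog", "a cat"])

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    # Normalize embeddings before computing cosine similarities.
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

print("Label probs:", text_probs)
```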
89 | Variants currently available in OpenCLIP:
90 | `[('MobileCLIP-S1', 'datacompdr'),
91 | ('MobileCLIP-S2', 'datacompdr'),
92 | ('MobileCLIP-B', 'datacompdr'),
93 | ('MobileCLIP-B', 'datacompdr_lt')]`
94 |
95 |
96 | ## Evaluation
97 | Please find the detailed evaluation results [here](./results).
98 | To reproduce the results, we provide a script for zero-shot evaluation on the ImageNet-1k dataset.
99 | To evaluate on all 38 datasets, please follow the instructions in [datacomp](https://github.com/mlfoundations/datacomp).
100 | ```bash
101 | # Run evaluation with single GPU
102 | python eval/zeroshot_imagenet.py --model-arch mobileclip_s0 --model-path /path/to/mobileclip_s0.pt
103 | ```
104 |
105 | Please refer to [Open CLIP Results](https://github.com/mlfoundations/open_clip/blob/main/docs/openclip_results.csv) to compare with other models.
106 |
107 | | Model | # Seen Samples (B) | # Params (M) (img + txt) | Latency (ms) (img + txt) | IN-1k Zero-Shot Top-1 Acc. (%) | Avg. Perf. (%) on 38 datasets | Pytorch Checkpoint (url) |
108 | |:------------------|:----------------------:|:-----------------------------:|:------------------------------:|:-----------------------------------:|:----------------------------------:|:--------------------------------------------------------------------------------------------------------------:|
109 | | MobileCLIP-S0 | 13 | 11.4 + 42.4 | 1.5 + 1.6 | 67.8 | 58.1 | [mobileclip_s0.pt](https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s0.pt) |
110 | | MobileCLIP-S1 | 13 | 21.5 + 63.4 | 2.5 + 3.3 | 72.6 | 61.3 | [mobileclip_s1.pt](https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s1.pt) |
111 | | MobileCLIP-S2 | 13 | 35.7 + 63.4 | 3.6 + 3.3 | 74.4 | 63.7 | [mobileclip_s2.pt](https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s2.pt) |
112 | | MobileCLIP-B | 13 | 86.3 + 63.4 | 10.4 + 3.3 | 76.8 | 65.2 | [mobileclip_b.pt](https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_b.pt) |
113 | | MobileCLIP-B (LT) | 36 | 86.3 + 63.4 | 10.4 + 3.3 | 77.2 | 65.8 | [mobileclip_blt.pt](https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_blt.pt) |
114 |
115 | Note: MobileCLIP-B (LT) is trained for 300k iterations with a constant learning rate schedule and 300k iterations with a cosine learning rate schedule.
116 |
117 | ## Citation
118 | If you found this code useful, please cite the following paper:
119 | ```
120 | @InProceedings{mobileclip2024,
121 | author = {Pavan Kumar Anasosalu Vasu and Hadi Pouransari and Fartash Faghri and Raviteja Vemulapalli and Oncel Tuzel},
122 | title = {MobileCLIP: Fast Image-Text Models through Multi-Modal Reinforced Training},
123 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
124 | month = {June},
125 | year = {2024},
126 | }
127 | ```
128 |
129 | ## Acknowledgements
130 | Our codebase is built using multiple open-source contributions; please see [ACKNOWLEDGEMENTS](ACKNOWLEDGEMENTS) for more details.
131 |
--------------------------------------------------------------------------------
/docs/fig_accuracy_latency.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-mobileclip/1140b8d197e4ed7d56b3a92216ded98bb1c2ac87/docs/fig_accuracy_latency.png
--------------------------------------------------------------------------------
/eval/zeroshot_imagenet.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 | """
6 | Evaluation code is borrowed from https://github.com/mlfoundations/datacomp/blob/main/eval_utils/wds_eval.py
7 | Licensed under MIT License, see ACKNOWLEDGEMENTS for details.
8 | """
9 |
10 | import os
11 | import argparse
12 |
13 | import mobileclip
14 | import torch
15 | from clip_benchmark.datasets.builder import build_dataset
16 | from clip_benchmark.metrics import zeroshot_classification as zsc
17 |
18 |
19 | def parse_args(parser):
20 | parser.add_argument(
21 | "--model-arch",
22 | type=str,
23 | required=True,
24 | help="Specify model arch from the available choices.",
25 | choices=['mobileclip_s0', 'mobileclip_s1', 'mobileclip_s2', 'mobileclip_b']
26 | )
27 | parser.add_argument(
28 | "--model-path",
29 | type=str,
30 | required=True,
31 | help="Specify location of model checkpoint.",
32 | )
33 | return parser
34 |
35 |
36 | def create_model(model_arch, model_path):
37 | device = "cuda" if torch.cuda.is_available() else "cpu"
38 | torch.manual_seed(0)
39 |
40 | model_path = str(model_path)
41 | model, _, transform = mobileclip.create_model_and_transforms(
42 | model_arch, pretrained=model_path
43 | )
44 | model.eval()
45 | model = model.to(device)
46 |
47 | return model, transform, device
48 |
49 |
50 | def create_webdataset(
51 | task, transform, data_root=None, dataset_len=None, batch_size=64, num_workers=4
52 | ):
53 | data_folder = f"wds_{task.replace('/','-')}_test"
54 | if data_root is None:
55 | data_root = f"https://huggingface.co/datasets/djghosh/{data_folder}/tree/main"
56 | else:
57 | data_root = os.path.join(data_root, data_folder)
58 | dataset = build_dataset(
59 | dataset_name=f"wds/{task}",
60 | root=data_root,
61 | transform=transform,
62 | split="test",
63 | download=False,
64 | )
65 | if dataset_len:
66 | dataset = dataset.with_length((dataset_len + batch_size - 1) // batch_size)
67 | dataloader = torch.utils.data.DataLoader(
68 | dataset.batched(batch_size),
69 | batch_size=None,
70 | shuffle=False,
71 | num_workers=num_workers,
72 | )
73 | return dataset, dataloader
74 |
75 |
76 | def evaluate_webdataset(
77 | task,
78 | model_arch,
79 | model_path,
80 | data_root=None,
81 | dataset_len=None,
82 | batch_size=64,
83 | num_workers=4,
84 | ):
85 | """Evaluate CLIP model on classification task."""
86 |
87 | # Create model
88 | model, transform, device = create_model(model_arch, model_path)
89 |
90 | # Load data
91 | dataset, dataloader = create_webdataset(
92 | task, transform, data_root, dataset_len, batch_size, num_workers
93 | )
94 |
95 | zeroshot_templates = dataset.templates if hasattr(dataset, "templates") else None
96 | classnames = dataset.classes if hasattr(dataset, "classes") else None
97 | assert (
98 | zeroshot_templates is not None and classnames is not None
99 | ), "Dataset does not support classification"
100 |
101 | # Evaluate
102 | classifier = zsc.zero_shot_classifier(
103 | model,
104 | mobileclip.get_tokenizer(model_arch),
105 | classnames,
106 | zeroshot_templates,
107 | device,
108 | )
109 | logits, target = zsc.run_classification(
110 | model, classifier, dataloader, device, amp=False
111 | )
112 |
113 | # Compute metrics
114 | acc1, acc5 = zsc.accuracy(logits, target, topk=(1, 5))
115 | metrics = {
116 | "acc1": acc1,
117 | "acc5": acc5,
118 | }
119 | return metrics
120 |
121 |
122 | if __name__ == "__main__":
123 | parser = argparse.ArgumentParser(description="Webdataset evaluation script.")
124 | parser = parse_args(parser)
125 | args = parser.parse_args()
126 |
127 | metric = evaluate_webdataset(
128 | task="imagenet1k", model_arch=args.model_arch, model_path=args.model_path
129 | )
130 | print(f"ImageNet1k Eval Metrics: {metric}")
131 |
--------------------------------------------------------------------------------
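A minimal sketch of driving the evaluation above from Python rather than the command line, using the `evaluate_webdataset` helper defined in this script (this assumes the repository root is on `PYTHONPATH`; the checkpoint path is a placeholder):

```python
# Hypothetical programmatic use of eval/zeroshot_imagenet.py.
from eval.zeroshot_imagenet import evaluate_webdataset

metrics = evaluate_webdataset(
    task="imagenet1k",
    model_arch="mobileclip_s0",
    model_path="checkpoints/mobileclip_s0.pt",  # placeholder checkpoint path
    batch_size=64,
    num_workers=4,
)
print(metrics)  # e.g. {'acc1': ..., 'acc5': ...}
```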
/get_pretrained_models.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | #
3 | # For licensing see accompanying LICENSE file.
4 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
5 | #
6 | mkdir -p checkpoints
7 | wget https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_b.pt -P checkpoints
8 | wget https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_blt.pt -P checkpoints
9 | wget https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s0.pt -P checkpoints
10 | wget https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s1.pt -P checkpoints
11 | wget https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s2.pt -P checkpoints
12 |
--------------------------------------------------------------------------------
/hf_dataset_example.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 | from PIL import Image
6 | import requests
7 | from io import BytesIO
8 | from datasets import load_dataset
9 | import json
10 | import numpy as np
11 | import torch
12 |
13 | from training.dr.transforms import compose_from_config
14 |
15 |
16 | if __name__ == '__main__':
17 | rconfig_aug = {
18 | "normalize": {
19 | "mean": [0.48145466, 0.4578275, 0.40821073],
20 | "std": [0.26862954, 0.26130258, 0.27577711]
21 | },
22 | "rand_augment": {"enable": True, "p": 1.0},
23 | "random_resized_crop": {"interpolation": "bicubic", "size": 224},
24 | "to_rgb": {"enable": True},
25 | "to_tensor": {"enable": True}
26 | }
27 | dr_transforms = compose_from_config(rconfig_aug)
28 |
29 | dataset = load_dataset("apple/DataCompDR-12M", split="train", streaming=True)
30 | sample = next(iter(dataset))
31 |
32 | # Load image from URL
33 | url = sample['url.txt']
34 | response = requests.get(url)
35 | img = Image.open(BytesIO(response.content))
36 | sample["image"] = img
37 |
38 | # Preprocess image
39 | # Sample an image augmentation
40 | param_augs = json.loads(sample["paug.json"]["param_aug"])
41 | aug_idx = np.random.randint(0, len(param_augs))
42 | params = param_augs[aug_idx]
43 | params = dr_transforms.decompress(params)
44 | image = sample["image"].convert('RGB')
45 | image, _ = dr_transforms.reapply(image, params)
46 |
47 | # Preprocess synthetic text
48 | scapi = np.random.randint(0, len(sample["syn.json"]["syn_text"]))
49 | syn_text = sample["syn.json"]["syn_text"][scapi]
50 |
51 | # Preprocess embeddings
52 | if "npz" in sample:
53 | image_emb = sample["npz"]["image_emb"][aug_idx]
54 | text_emb_all = sample["npz"]["text_emb"]
55 | elif "pth.gz" in sample:
56 | image_emb = sample["pth.gz"]["image_emb"][aug_idx]
57 | text_emb_all = sample["pth.gz"]["text_emb"]
58 | capi = 0
59 | text_emb = text_emb_all[capi]
60 | syn_text_emb = text_emb_all[1+scapi]
61 | if not isinstance(image_emb, torch.Tensor):
62 | image_emb = torch.tensor(image_emb)
63 | text_emb = torch.tensor(text_emb)
64 | syn_text_emb = torch.tensor(syn_text_emb)
65 | image_emb = image_emb.type(torch.float32)
66 | text_emb = text_emb.type(torch.float32)
67 | syn_text_emb = syn_text_emb.type(torch.float32)
68 |
69 | print(
70 | {
71 | 'image': image.shape,
72 | 'image_emb': image_emb.shape,
73 | 'text_emb': text_emb.shape,
74 | "syn_text": syn_text,
75 | 'syn_text_emb': syn_text_emb.shape,
76 | }
77 | )
78 |
--------------------------------------------------------------------------------
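The example above processes a single streamed sample; below is a small sketch of batching a few samples with the same reinforced augmentations. It reuses the `dataset` and `dr_transforms` objects defined in the script, relies on the streaming dataset's `take()` API, and simply skips URLs that fail to download:

```python
import torch

images = []
for s in dataset.take(8):  # first few streamed samples
    try:
        resp = requests.get(s["url.txt"], timeout=10)
        img = Image.open(BytesIO(resp.content)).convert("RGB")
    except Exception:
        continue  # skip unreachable or broken URLs
    # Reapply the first stored augmentation for this sample.
    params = dr_transforms.decompress(json.loads(s["paug.json"]["param_aug"])[0])
    img, _ = dr_transforms.reapply(img, params)
    images.append(img)

if images:
    batch = torch.stack(images)  # (N, 3, 224, 224)
    print(batch.shape)
```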
/ios_app/ACKNOWLEDGEMENTS:
--------------------------------------------------------------------------------
1 | Acknowledgements
2 | Portions of this Software may utilize the following copyrighted
3 | material, the use of which is hereby acknowledged.
4 |
5 | ------------------------------------------------
6 | swift-coreml-transformers (huggingface)
7 | https://github.com/huggingface/swift-coreml-transformers/tree/master
8 |
9 | Apache License
10 | Version 2.0, January 2004
11 | http://www.apache.org/licenses/
12 |
13 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
14 |
15 | 1. Definitions.
16 |
17 | "License" shall mean the terms and conditions for use, reproduction,
18 | and distribution as defined by Sections 1 through 9 of this document.
19 |
20 | "Licensor" shall mean the copyright owner or entity authorized by
21 | the copyright owner that is granting the License.
22 |
23 | "Legal Entity" shall mean the union of the acting entity and all
24 | other entities that control, are controlled by, or are under common
25 | control with that entity. For the purposes of this definition,
26 | "control" means (i) the power, direct or indirect, to cause the
27 | direction or management of such entity, whether by contract or
28 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
29 | outstanding shares, or (iii) beneficial ownership of such entity.
30 |
31 | "You" (or "Your") shall mean an individual or Legal Entity
32 | exercising permissions granted by this License.
33 |
34 | "Source" form shall mean the preferred form for making modifications,
35 | including but not limited to software source code, documentation
36 | source, and configuration files.
37 |
38 | "Object" form shall mean any form resulting from mechanical
39 | transformation or translation of a Source form, including but
40 | not limited to compiled object code, generated documentation,
41 | and conversions to other media types.
42 |
43 | "Work" shall mean the work of authorship, whether in Source or
44 | Object form, made available under the License, as indicated by a
45 | copyright notice that is included in or attached to the work
46 | (an example is provided in the Appendix below).
47 |
48 | "Derivative Works" shall mean any work, whether in Source or Object
49 | form, that is based on (or derived from) the Work and for which the
50 | editorial revisions, annotations, elaborations, or other modifications
51 | represent, as a whole, an original work of authorship. For the purposes
52 | of this License, Derivative Works shall not include works that remain
53 | separable from, or merely link (or bind by name) to the interfaces of,
54 | the Work and Derivative Works thereof.
55 |
56 | "Contribution" shall mean any work of authorship, including
57 | the original version of the Work and any modifications or additions
58 | to that Work or Derivative Works thereof, that is intentionally
59 | submitted to Licensor for inclusion in the Work by the copyright owner
60 | or by an individual or Legal Entity authorized to submit on behalf of
61 | the copyright owner. For the purposes of this definition, "submitted"
62 | means any form of electronic, verbal, or written communication sent
63 | to the Licensor or its representatives, including but not limited to
64 | communication on electronic mailing lists, source code control systems,
65 | and issue tracking systems that are managed by, or on behalf of, the
66 | Licensor for the purpose of discussing and improving the Work, but
67 | excluding communication that is conspicuously marked or otherwise
68 | designated in writing by the copyright owner as "Not a Contribution."
69 |
70 | "Contributor" shall mean Licensor and any individual or Legal Entity
71 | on behalf of whom a Contribution has been received by Licensor and
72 | subsequently incorporated within the Work.
73 |
74 | 2. Grant of Copyright License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | copyright license to reproduce, prepare Derivative Works of,
78 | publicly display, publicly perform, sublicense, and distribute the
79 | Work and such Derivative Works in Source or Object form.
80 |
81 | 3. Grant of Patent License. Subject to the terms and conditions of
82 | this License, each Contributor hereby grants to You a perpetual,
83 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
84 | (except as stated in this section) patent license to make, have made,
85 | use, offer to sell, sell, import, and otherwise transfer the Work,
86 | where such license applies only to those patent claims licensable
87 | by such Contributor that are necessarily infringed by their
88 | Contribution(s) alone or by combination of their Contribution(s)
89 | with the Work to which such Contribution(s) was submitted. If You
90 | institute patent litigation against any entity (including a
91 | cross-claim or counterclaim in a lawsuit) alleging that the Work
92 | or a Contribution incorporated within the Work constitutes direct
93 | or contributory patent infringement, then any patent licenses
94 | granted to You under this License for that Work shall terminate
95 | as of the date such litigation is filed.
96 |
97 | 4. Redistribution. You may reproduce and distribute copies of the
98 | Work or Derivative Works thereof in any medium, with or without
99 | modifications, and in Source or Object form, provided that You
100 | meet the following conditions:
101 |
102 | (a) You must give any other recipients of the Work or
103 | Derivative Works a copy of this License; and
104 |
105 | (b) You must cause any modified files to carry prominent notices
106 | stating that You changed the files; and
107 |
108 | (c) You must retain, in the Source form of any Derivative Works
109 | that You distribute, all copyright, patent, trademark, and
110 | attribution notices from the Source form of the Work,
111 | excluding those notices that do not pertain to any part of
112 | the Derivative Works; and
113 |
114 | (d) If the Work includes a "NOTICE" text file as part of its
115 | distribution, then any Derivative Works that You distribute must
116 | include a readable copy of the attribution notices contained
117 | within such NOTICE file, excluding those notices that do not
118 | pertain to any part of the Derivative Works, in at least one
119 | of the following places: within a NOTICE text file distributed
120 | as part of the Derivative Works; within the Source form or
121 | documentation, if provided along with the Derivative Works; or,
122 | within a display generated by the Derivative Works, if and
123 | wherever such third-party notices normally appear. The contents
124 | of the NOTICE file are for informational purposes only and
125 | do not modify the License. You may add Your own attribution
126 | notices within Derivative Works that You distribute, alongside
127 | or as an addendum to the NOTICE text from the Work, provided
128 | that such additional attribution notices cannot be construed
129 | as modifying the License.
130 |
131 | You may add Your own copyright statement to Your modifications and
132 | may provide additional or different license terms and conditions
133 | for use, reproduction, or distribution of Your modifications, or
134 | for any such Derivative Works as a whole, provided Your use,
135 | reproduction, and distribution of the Work otherwise complies with
136 | the conditions stated in this License.
137 |
138 | 5. Submission of Contributions. Unless You explicitly state otherwise,
139 | any Contribution intentionally submitted for inclusion in the Work
140 | by You to the Licensor shall be under the terms and conditions of
141 | this License, without any additional terms or conditions.
142 | Notwithstanding the above, nothing herein shall supersede or modify
143 | the terms of any separate license agreement you may have executed
144 | with Licensor regarding such Contributions.
145 |
146 | 6. Trademarks. This License does not grant permission to use the trade
147 | names, trademarks, service marks, or product names of the Licensor,
148 | except as required for reasonable and customary use in describing the
149 | origin of the Work and reproducing the content of the NOTICE file.
150 |
151 | 7. Disclaimer of Warranty. Unless required by applicable law or
152 | agreed to in writing, Licensor provides the Work (and each
153 | Contributor provides its Contributions) on an "AS IS" BASIS,
154 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
155 | implied, including, without limitation, any warranties or conditions
156 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
157 | PARTICULAR PURPOSE. You are solely responsible for determining the
158 | appropriateness of using or redistributing the Work and assume any
159 | risks associated with Your exercise of permissions under this License.
160 |
161 | 8. Limitation of Liability. In no event and under no legal theory,
162 | whether in tort (including negligence), contract, or otherwise,
163 | unless required by applicable law (such as deliberate and grossly
164 | negligent acts) or agreed to in writing, shall any Contributor be
165 | liable to You for damages, including any direct, indirect, special,
166 | incidental, or consequential damages of any character arising as a
167 | result of this License or out of the use or inability to use the
168 | Work (including but not limited to damages for loss of goodwill,
169 | work stoppage, computer failure or malfunction, or any and all
170 | other commercial damages or losses), even if such Contributor
171 | has been advised of the possibility of such damages.
172 |
173 | 9. Accepting Warranty or Additional Liability. While redistributing
174 | the Work or Derivative Works thereof, You may choose to offer,
175 | and charge a fee for, acceptance of support, warranty, indemnity,
176 | or other liability obligations and/or rights consistent with this
177 | License. However, in accepting such obligations, You may act only
178 | on Your own behalf and on Your sole responsibility, not on behalf
179 | of any other Contributor, and only if You agree to indemnify,
180 | defend, and hold each Contributor harmless for any liability
181 | incurred by, or claims asserted against, such Contributor by reason
182 | of your accepting any such warranty or additional liability.
183 |
184 | END OF TERMS AND CONDITIONS
185 |
186 | APPENDIX: How to apply the Apache License to your work.
187 |
188 | To apply the Apache License to your work, attach the following
189 | boilerplate notice, with the fields enclosed by brackets "[]"
190 | replaced with your own identifying information. (Don't include
191 | the brackets!) The text should be enclosed in the appropriate
192 | comment syntax for the file format. We also recommend that a
193 | file or class name and description of purpose be included on the
194 | same "printed page" as the copyright notice for easier
195 | identification within third-party archives.
196 |
197 | Copyright [yyyy] [name of copyright owner]
198 |
199 | Licensed under the Apache License, Version 2.0 (the "License");
200 | you may not use this file except in compliance with the License.
201 | You may obtain a copy of the License at
202 |
203 | http://www.apache.org/licenses/LICENSE-2.0
204 |
205 | Unless required by applicable law or agreed to in writing, software
206 | distributed under the License is distributed on an "AS IS" BASIS,
207 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
208 | See the License for the specific language governing permissions and
209 | limitations under the License.
210 |
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore.xcodeproj/project.xcworkspace/contents.xcworkspacedata:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <Workspace
3 |    version = "1.0">
4 |    <FileRef
5 |       location = "self:">
6 |    </FileRef>
7 | </Workspace>
8 |
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
3 | <plist version="1.0">
4 | <dict>
5 | 	<key>IDEDidComputeMac32BitWarning</key>
6 | 	<true/>
7 | </dict>
8 | </plist>
9 |
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore.xcodeproj/project.xcworkspace/xcuserdata/pavankumaranasosaluvasu.xcuserdatad/UserInterfaceState.xcuserstate:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-mobileclip/1140b8d197e4ed7d56b3a92216ded98bb1c2ac87/ios_app/MobileCLIPExplore.xcodeproj/project.xcworkspace/xcuserdata/pavankumaranasosaluvasu.xcuserdatad/UserInterfaceState.xcuserstate
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore.xcodeproj/xcuserdata/pavankumaranasosaluvasu.xcuserdatad/xcschemes/xcschememanagement.plist:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
3 | <plist version="1.0">
4 | <dict>
5 | 	<key>SchemeUserState</key>
6 | 	<dict>
7 | 		<key>MobileCLIPExplore.xcscheme_^#shared#^_</key>
8 | 		<dict>
9 | 			<key>orderHint</key>
10 | 			<integer>0</integer>
11 | 		</dict>
12 | 	</dict>
13 | </dict>
14 | </plist>
15 |
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore/Assets.xcassets/AccentColor.colorset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "colors" : [
3 | {
4 | "idiom" : "universal"
5 | }
6 | ],
7 | "info" : {
8 | "author" : "xcode",
9 | "version" : 1
10 | }
11 | }
12 |
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore/Assets.xcassets/AddClassName.colorset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "colors" : [
3 | {
4 | "color" : {
5 | "platform" : "universal",
6 | "reference" : "systemBlueColor"
7 | },
8 | "idiom" : "universal"
9 | },
10 | {
11 | "appearances" : [
12 | {
13 | "appearance" : "luminosity",
14 | "value" : "dark"
15 | }
16 | ],
17 | "color" : {
18 | "platform" : "universal",
19 | "reference" : "systemBlueColor"
20 | },
21 | "idiom" : "universal"
22 | }
23 | ],
24 | "info" : {
25 | "author" : "xcode",
26 | "version" : 1
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore/Assets.xcassets/AppIcon.appiconset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "images" : [
3 | {
4 | "idiom" : "universal",
5 | "platform" : "ios",
6 | "size" : "1024x1024"
7 | }
8 | ],
9 | "info" : {
10 | "author" : "xcode",
11 | "version" : 1
12 | }
13 | }
14 |
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore/Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "info" : {
3 | "author" : "xcode",
4 | "version" : 1
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore/Assets.xcassets/Preset.colorset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "colors" : [
3 | {
4 | "color" : {
5 | "platform" : "ios",
6 | "reference" : "linkColor"
7 | },
8 | "idiom" : "universal"
9 | },
10 | {
11 | "appearances" : [
12 | {
13 | "appearance" : "luminosity",
14 | "value" : "dark"
15 | }
16 | ],
17 | "color" : {
18 | "platform" : "ios",
19 | "reference" : "linkColor"
20 | },
21 | "idiom" : "universal"
22 | }
23 | ],
24 | "info" : {
25 | "author" : "xcode",
26 | "version" : 1
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore/Assets.xcassets/PresetDisabled.colorset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "colors" : [
3 | {
4 | "color" : {
5 | "color-space" : "display-p3",
6 | "components" : {
7 | "alpha" : "1.000",
8 | "blue" : "0.966",
9 | "green" : "0.661",
10 | "red" : "0.544"
11 | }
12 | },
13 | "idiom" : "universal"
14 | },
15 | {
16 | "appearances" : [
17 | {
18 | "appearance" : "luminosity",
19 | "value" : "dark"
20 | }
21 | ],
22 | "color" : {
23 | "color-space" : "display-p3",
24 | "components" : {
25 | "alpha" : "1.000",
26 | "blue" : "0.967",
27 | "green" : "0.692",
28 | "red" : "0.522"
29 | }
30 | },
31 | "idiom" : "universal"
32 | }
33 | ],
34 | "info" : {
35 | "author" : "xcode",
36 | "version" : 1
37 | }
38 | }
39 |
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore/Assets.xcassets/PromptEditor.colorset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "colors" : [
3 | {
4 | "color" : {
5 | "color-space" : "gray-gamma-22",
6 | "components" : {
7 | "alpha" : "1.000",
8 | "white" : "1.000"
9 | }
10 | },
11 | "idiom" : "universal"
12 | },
13 | {
14 | "appearances" : [
15 | {
16 | "appearance" : "luminosity",
17 | "value" : "dark"
18 | }
19 | ],
20 | "color" : {
21 | "color-space" : "gray-gamma-22",
22 | "components" : {
23 | "alpha" : "1.000",
24 | "white" : "27"
25 | }
26 | },
27 | "idiom" : "universal"
28 | }
29 | ],
30 | "info" : {
31 | "author" : "xcode",
32 | "version" : 1
33 | }
34 | }
35 |
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore/Assets.xcassets/PromptEditorBackground.colorset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "colors" : [
3 | {
4 | "color" : {
5 | "color-space" : "display-p3",
6 | "components" : {
7 | "alpha" : "1.000",
8 | "blue" : "247",
9 | "green" : "242",
10 | "red" : "242"
11 | }
12 | },
13 | "idiom" : "universal"
14 | },
15 | {
16 | "appearances" : [
17 | {
18 | "appearance" : "luminosity",
19 | "value" : "dark"
20 | }
21 | ],
22 | "color" : {
23 | "color-space" : "srgb",
24 | "components" : {
25 | "alpha" : "1.000",
26 | "blue" : "0.000",
27 | "green" : "0.000",
28 | "red" : "0.000"
29 | }
30 | },
31 | "idiom" : "universal"
32 | }
33 | ],
34 | "info" : {
35 | "author" : "xcode",
36 | "version" : 1
37 | }
38 | }
39 |
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore/AsyncFactory.swift:
--------------------------------------------------------------------------------
1 | //
2 | // For licensing see accompanying LICENSE file.
3 | // Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | //
5 |
6 | import Foundation
7 |
8 | /// Asynchronous factory for slow-to-load types.
9 | public actor AsyncFactory<T> {
10 |
11 | private enum State {
12 | case idle(() -> T)
13 | case initializing(Task<T, Never>)
14 | case initialized(T)
15 | }
16 |
17 | private var state: State
18 |
19 | public init(factory: @escaping () -> T) {
20 | self.state = .idle(factory)
21 | }
22 |
23 | public func get() async -> T {
24 | switch state {
25 | case .idle(let factory):
26 | let task = Task {
27 | factory()
28 | }
29 | self.state = .initializing(task)
30 | let value = await task.value
31 | self.state = .initialized(value)
32 | return value
33 |
34 | case .initializing(let task):
35 | return await task.value
36 |
37 | case .initialized(let v):
38 | return v
39 | }
40 | }
41 | }
42 |
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore/CameraController.swift:
--------------------------------------------------------------------------------
1 | //
2 | // For licensing see accompanying LICENSE file.
3 | // Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | //
5 |
6 | import AVFoundation
7 | import CoreImage
8 | import UIKit
9 |
10 | @Observable
11 | class CameraController: NSObject {
12 |
13 | private var framesContinuation: AsyncStream<CMSampleBuffer>.Continuation?
14 |
15 | public var backCamera = true {
16 | didSet {
17 | stop()
18 | start()
19 | }
20 | }
21 |
22 | private var permissionGranted = true
23 | private var captureSession: AVCaptureSession?
24 | private let sessionQueue = DispatchQueue(label: "sessionQueue")
25 |
26 | public func attach(continuation: AsyncStream<CMSampleBuffer>.Continuation) {
27 | sessionQueue.async {
28 | self.framesContinuation = continuation
29 | }
30 | }
31 |
32 | public func detatch() {
33 | sessionQueue.async {
34 | self.framesContinuation = nil
35 | }
36 | }
37 |
38 | public func stop() {
39 | sessionQueue.sync { [self] in
40 | captureSession?.stopRunning()
41 | captureSession = nil
42 | }
43 |
44 | }
45 |
46 | public func start() {
47 | sessionQueue.async { [self] in
48 | let captureSession = AVCaptureSession()
49 | self.captureSession = captureSession
50 |
51 | self.checkPermission()
52 | self.setupCaptureSession(position: backCamera ? .back : .front)
53 | captureSession.startRunning()
54 | }
55 | }
56 |
57 | private func setOrientation(_ orientation: UIDeviceOrientation) {
58 | guard let captureSession else { return }
59 |
60 | let angle: Double?
61 | switch orientation {
62 | case .unknown, .faceDown:
63 | angle = nil
64 | case .portrait, .faceUp:
65 | angle = 90
66 | case .portraitUpsideDown:
67 | angle = 270
68 | case .landscapeLeft:
69 | angle = 0
70 | case .landscapeRight:
71 | angle = 180
72 | @unknown default:
73 | angle = nil
74 | }
75 |
76 | if let angle {
77 | for output in captureSession.outputs {
78 | output.connection(with: .video)?.videoRotationAngle = angle
79 | }
80 | }
81 | }
82 |
83 | func checkPermission() {
84 | switch AVCaptureDevice.authorizationStatus(for: .video) {
85 | case .authorized:
86 | // The user has previously granted access to the camera.
87 | self.permissionGranted = true
88 |
89 | case .notDetermined:
90 | // The user has not yet been asked for camera access.
91 | self.requestPermission()
92 |
93 | // Combine the two other cases into the default case
94 | default:
95 | self.permissionGranted = false
96 | }
97 | }
98 |
99 | func requestPermission() {
100 | // Strong reference not a problem here but might become one in the future.
101 | AVCaptureDevice.requestAccess(for: .video) { [unowned self] granted in
102 | self.permissionGranted = granted
103 | }
104 | }
105 |
106 | func setupCaptureSession(position: AVCaptureDevice.Position) {
107 | guard let captureSession else { return }
108 |
109 | let videoOutput = AVCaptureVideoDataOutput()
110 |
111 | guard permissionGranted else {
112 | print("No permission for camera")
113 | return
114 | }
115 |
116 | let videoDeviceDiscoverySession = AVCaptureDevice.DiscoverySession(
117 | deviceTypes: [.builtInDualCamera, .builtInWideAngleCamera],
118 | mediaType: .video,
119 | position: position)
120 |
121 | guard
122 | let videoDevice = videoDeviceDiscoverySession.devices.first
123 | else {
124 | print("Unable to find video device")
125 | return
126 | }
127 | guard let videoDeviceInput = try? AVCaptureDeviceInput(device: videoDevice) else {
128 | print("Unable to create AVCaptureDeviceInput")
129 | return
130 | }
131 | guard captureSession.canAddInput(videoDeviceInput) else {
132 | print("Unable to add input")
133 | return
134 | }
135 | captureSession.addInput(videoDeviceInput)
136 |
137 | videoOutput.setSampleBufferDelegate(self, queue: DispatchQueue(label: "sampleBufferQueue"))
138 | captureSession.addOutput(videoOutput)
139 | captureSession.sessionPreset = AVCaptureSession.Preset.vga640x480
140 |
141 | if videoDevice.isContinuityCamera {
142 | setOrientation(.portrait)
143 | } else {
144 | setOrientation(UIDevice.current.orientation)
145 | }
146 | }
147 | }
148 |
149 | extension CameraController: AVCaptureVideoDataOutputSampleBufferDelegate {
150 | func captureOutput(
151 | _ output: AVCaptureOutput, didOutput sampleBuffer: CMSampleBuffer,
152 | from connection: AVCaptureConnection
153 | ) {
154 | if sampleBuffer.isValid && sampleBuffer.imageBuffer != nil {
155 | framesContinuation?.yield(sampleBuffer)
156 | }
157 | }
158 | }
159 |
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore/Configuration/Build.xcconfig:
--------------------------------------------------------------------------------
1 | // The `DISAMBIGUATOR` configuration is to make it easier to build
2 | // and run a sample code project. Once you set your project's development team,
3 | // you'll have a unique bundle identifier. This is because the bundle identifier
4 | // is derived based on the 'DISAMBIGUATOR' value. Do not use this
5 | // approach in your own projects—it's only useful for example projects because
6 | // they are frequently downloaded and don't have a development team set.
7 | DISAMBIGUATOR=${DEVELOPMENT_TEAM}
8 |
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore/ContentView.swift:
--------------------------------------------------------------------------------
1 | //
2 | // For licensing see accompanying LICENSE file.
3 | // Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | //
5 |
6 | import CoreML
7 | import SwiftUI
8 |
9 | struct ContentView: View {
10 | @State private var camera = CameraController()
11 |
12 | var body: some View {
13 | InferenceView(
14 | camera: camera,
15 | backCamera: $camera.backCamera
16 | )
17 | .ignoresSafeArea(edges: [.bottom])
18 | .task {
19 | camera.start()
20 | }
21 | }
22 | }
23 |
24 | struct ContentView_Previews: PreviewProvider {
25 | static var previews: some View {
26 | ContentView()
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore/DisplayPrediction.swift:
--------------------------------------------------------------------------------
1 | //
2 | // For licensing see accompanying LICENSE file.
3 | // Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | //
5 |
6 | import Foundation
7 | import SwiftUI
8 |
9 | struct DisplayPrediction: Identifiable {
10 | var id: String { className }
11 |
12 | var className: String
13 | var cosineSimilarity: Float
14 | }
15 |
16 | /// A prediction must be within this fraction of the maximum cosine similarity to be shown in bold
17 | private let boldPercent: Float = 0.05
18 | 
19 | /// Minimum cosine similarity a prediction must reach to be shown in bold
20 | private let boldThreshold: Float = 0.22
21 |
22 | /// Colors for interpolating based on color scheme
23 | private func colors(for colorScheme: ColorScheme) -> [Color] {
24 | var result = [Color]()
25 |
26 | let topGray: CGFloat = colorScheme == .dark ? 1.0 : 0.0
27 | let bottomGray: CGFloat = colorScheme == .dark ? 0.2 : 0.8
28 |
29 | for position in stride(from: 0, through: 1, by: 0.05) {
30 | let gray = (bottomGray + CGFloat(position) * (topGray - bottomGray))
31 |
32 | result.append(.init(cgColor: CGColor(gray: gray, alpha: 1.0)))
33 | }
34 |
35 | return result
36 | }
37 |
38 | /// Cosine-similarity range used to interpolate the prediction color
39 | private let maximumThreshold: Float = 0.22
40 | private let minimumThreshold: Float = 0.15
41 |
42 | struct DisplayPredictionFormatter {
43 |
44 | let actualMaximumCosineSimilarity: Float
45 | let colorScheme: ColorScheme
46 |
47 | init(predictions: [DisplayPrediction], colorScheme: ColorScheme) {
48 | self.actualMaximumCosineSimilarity = predictions.map { $0.cosineSimilarity }.max() ?? 0
49 | self.colorScheme = colorScheme
50 | }
51 |
52 | func isBold(_ prediction: DisplayPrediction) -> Bool {
53 | prediction.cosineSimilarity >= boldThreshold
54 | && prediction.cosineSimilarity >= actualMaximumCosineSimilarity * (1 - boldPercent)
55 | }
56 |
57 | func color(_ prediction: DisplayPrediction) -> Color {
58 | let position =
59 |             min(max(prediction.cosineSimilarity - minimumThreshold, 0), maximumThreshold)
60 |             / (maximumThreshold - minimumThreshold)
61 | let index = min(
62 | Int(round(Float(colors(for: colorScheme).count - 1) * position)),
63 | colors(for: colorScheme).count - 1)
64 | return colors(for: colorScheme)[index]
65 | }
66 | }
67 |
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore/Extensions.swift:
--------------------------------------------------------------------------------
1 | //
2 | // For licensing see accompanying LICENSE file.
3 | // Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | //
5 |
6 | import CoreImage
7 |
8 | extension CIImage {
9 |
10 |     /// Crop the image to a centered square
11 | /// - Returns: Cropped image
12 | func cropToSquare() -> CIImage? {
13 | let size = min(self.extent.width, self.extent.height)
14 | let x = round((self.extent.width - size) / 2)
15 | let y = round((self.extent.height - size) / 2)
16 |
17 | let cropRect = CGRect(
18 | x: x,
19 | y: y,
20 | width: size,
21 | height: size
22 | )
23 |
24 | let translate = CGAffineTransform(translationX: -x, y: -y)
25 |
26 | return
27 | self
28 | .cropped(to: cropRect)
29 | .transformed(by: translate)
30 | }
31 |
32 | /// Resize image
33 | /// - Parameter size: Size to resize to
34 | /// - Returns: Resized image
35 | func resize(size: CGSize) -> CIImage? {
36 | let scaleX = size.width / self.extent.width
37 | let scaleY = size.height / self.extent.height
38 | return self.transformed(by: CGAffineTransform(scaleX: scaleX, y: scaleY))
39 | }
40 | }
41 |
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore/MobileCLIPExploreApp.swift:
--------------------------------------------------------------------------------
1 | //
2 | // For licensing see accompanying LICENSE file.
3 | // Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | //
5 |
6 | import SwiftUI
7 |
8 | @main
9 | struct MobileCLIPExploreApp: App {
10 | var body: some Scene {
11 | WindowGroup {
12 | ContentView()
13 | }
14 | }
15 | }
16 |
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore/Models.swift:
--------------------------------------------------------------------------------
1 | //
2 | // For licensing see accompanying LICENSE file.
3 | // Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | //
5 |
6 | import CoreML
7 | import Foundation
8 |
9 | protocol CLIPEncoder {
10 |
11 | var targetImageSize: CGSize { get }
12 |
13 | func load() async
14 |
15 | func encode(image: CVPixelBuffer) async throws -> MLMultiArray
16 |
17 | func encode(text: MLMultiArray) async throws -> MLMultiArray
18 |
19 | }
20 |
21 | public struct ModelConfiguration: Identifiable, Hashable {
22 | public let name: String
23 | let factory: () -> CLIPEncoder
24 | public var id: String { name }
25 |
26 | public static func == (lhs: ModelConfiguration, rhs: ModelConfiguration) -> Bool {
27 | lhs.name == rhs.name
28 | }
29 |
30 | public func hash(into hasher: inout Hasher) {
31 | hasher.combine(name)
32 | }
33 | }
34 |
35 | public let models: [ModelConfiguration] = [
36 | .init(name: "MobileCLIP-S0", factory: { S0Model() }),
37 | .init(name: "MobileCLIP-S1", factory: { S1Model() }),
38 | .init(name: "MobileCLIP-S2", factory: { S2Model() }),
39 | .init(name: "MobileCLIP-BLT", factory: { BLTModel() }),
40 | ]
41 |
42 | public let defaultModel = ModelConfiguration(name: "MobileCLIP-S2", factory: { S2Model() })
43 |
44 | public struct S0Model: CLIPEncoder {
45 |
46 | let imageEncoder = AsyncFactory {
47 | do {
48 | return try mobileclip_s0_image()
49 | } catch {
50 | fatalError("Failed to initialize ML model: \(error)")
51 | }
52 | }
53 |
54 | let textEncoder = AsyncFactory {
55 | do {
56 | return try mobileclip_s0_text()
57 | } catch {
58 | fatalError("Failed to initialize ML model: \(error)")
59 | }
60 | }
61 |
62 | func load() async {
63 | async let t = textEncoder.get()
64 | async let i = imageEncoder.get()
65 | _ = await (t, i)
66 | }
67 |
68 | let targetImageSize = CGSize(width: 256, height: 256)
69 |
70 | func encode(image: CVPixelBuffer) async throws -> MLMultiArray {
71 | try await imageEncoder.get().prediction(image: image).final_emb_1
72 | }
73 |
74 | func encode(text: MLMultiArray) async throws -> MLMultiArray {
75 | try await textEncoder.get().prediction(text: text).final_emb_1
76 | }
77 | }
78 |
79 | public struct S1Model: CLIPEncoder {
80 |
81 | let imageEncoder = AsyncFactory {
82 | do {
83 | return try mobileclip_s1_image()
84 | } catch {
85 | fatalError("Failed to initialize ML model: \(error)")
86 | }
87 | }
88 |
89 | let textEncoder = AsyncFactory {
90 | do {
91 | return try mobileclip_s1_text()
92 | } catch {
93 | fatalError("Failed to initialize ML model: \(error)")
94 | }
95 | }
96 |
97 | func load() async {
98 | async let t = textEncoder.get()
99 | async let i = imageEncoder.get()
100 | _ = await (t, i)
101 | }
102 |
103 | let targetImageSize = CGSize(width: 256, height: 256)
104 |
105 | func encode(image: CVPixelBuffer) async throws -> MLMultiArray {
106 | try await imageEncoder.get().prediction(image: image).final_emb_1
107 | }
108 |
109 | func encode(text: MLMultiArray) async throws -> MLMultiArray {
110 | try await textEncoder.get().prediction(text: text).final_emb_1
111 | }
112 | }
113 |
114 | public struct S2Model: CLIPEncoder {
115 |
116 | let imageEncoder = AsyncFactory {
117 | do {
118 | return try mobileclip_s2_image()
119 | } catch {
120 | fatalError("Failed to initialize ML model: \(error)")
121 | }
122 | }
123 |
124 | let textEncoder = AsyncFactory {
125 | do {
126 | return try mobileclip_s2_text()
127 | } catch {
128 | fatalError("Failed to initialize ML model: \(error)")
129 | }
130 | }
131 |
132 | func load() async {
133 | async let t = textEncoder.get()
134 | async let i = imageEncoder.get()
135 | _ = await (t, i)
136 | }
137 |
138 | let targetImageSize = CGSize(width: 256, height: 256)
139 |
140 | func encode(image: CVPixelBuffer) async throws -> MLMultiArray {
141 | try await imageEncoder.get().prediction(image: image).final_emb_1
142 | }
143 |
144 | func encode(text: MLMultiArray) async throws -> MLMultiArray {
145 | try await textEncoder.get().prediction(text: text).final_emb_1
146 | }
147 | }
148 |
149 | public struct BLTModel: CLIPEncoder {
150 |
151 | let imageEncoder = AsyncFactory {
152 | do {
153 | return try mobileclip_blt_image()
154 | } catch {
155 | fatalError("Failed to initialize ML model: \(error)")
156 | }
157 | }
158 |
159 | let textEncoder = AsyncFactory {
160 | do {
161 | return try mobileclip_blt_text()
162 | } catch {
163 | fatalError("Failed to initialize ML model: \(error)")
164 | }
165 | }
166 |
167 | func load() async {
168 | async let t = textEncoder.get()
169 | async let i = imageEncoder.get()
170 | _ = await (t, i)
171 | }
172 |
173 | let targetImageSize = CGSize(width: 224, height: 224)
174 |
175 | func encode(image: CVPixelBuffer) async throws -> MLMultiArray {
176 | try await imageEncoder.get().prediction(image: image).final_emb_1
177 | }
178 |
179 | func encode(text: MLMultiArray) async throws -> MLMultiArray {
180 | try await textEncoder.get().prediction(text: text).final_emb_1
181 | }
182 | }
183 |
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore/Models/README.md:
--------------------------------------------------------------------------------
1 | Download the required CoreML models from [this link](https://huggingface.co/apple/coreml-mobileclip) and place them in this directory.
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore/Preview Content/Preview Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "info" : {
3 | "author" : "xcode",
4 | "version" : 1
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore/Prompt.swift:
--------------------------------------------------------------------------------
1 | //
2 | // For licensing see accompanying LICENSE file.
3 | // Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | //
5 |
6 | import Foundation
7 |
8 | public let presets = [
9 | PromptPreset(
10 | title: "Desk items",
11 | prompt: .init(
12 | prefix: "A photo of",
13 | suffix: "",
14 | classNames: [
15 | "pen",
16 | "pencil",
17 | "paper",
18 | "mouse",
19 | "keyboard",
20 | "computer",
21 | "phone",
22 | "stapler",
23 | "cup",
24 | ])
25 | ),
26 | PromptPreset(
27 | title: "Expressions",
28 | prompt: .init(
29 | prefix: "A person",
30 | suffix: "",
31 | classNames: [
32 | "smiling",
33 | "waving",
34 | "giving a thumbs up",
35 | "sticking out their tongue",
36 | "looking angry",
37 | ])
38 | ),
39 | PromptPreset(
40 | title: "Custom",
41 | prompt: .init(
42 | prefix: "A photo of",
43 | suffix: "",
44 | classNames: [])
45 | ),
46 | ]
47 |
48 | public struct PromptPreset: Identifiable {
49 | public let id = UUID()
50 | public let title: String
51 | public let prompt: Prompt
52 | }
53 |
54 | public struct Prompt {
55 | public var prefix: String
56 | public var suffix: String
57 | public var classNames: [String]
58 |
59 | public func fullPrompts() -> [String] {
60 | classNames.map {
61 | "\(prefix) \($0) \(suffix)"
62 | }
63 | }
64 | }
65 |
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore/Tokenizer/CLIPTokenizer.swift:
--------------------------------------------------------------------------------
1 | //
2 | // CLIPTokenizer.swift
3 | // CoreMLBert
4 | //
5 | // Created by Matthew Waller on 1/31/23.
6 | // Copyright © 2023 Hugging Face. All rights reserved.
7 | //
8 | // Modified by Hugues Thomas on 5/14/24.
9 | //
10 | // See https://github.com/huggingface/swift-coreml-transformers/pull/30
11 |
12 | import Foundation
13 |
14 | struct BytePair: Hashable {
15 | let a: String
16 | let b: String
17 | init(_ a: String, _ b: String) {
18 | self.a = a
19 | self.b = b
20 | }
21 | init(tuple: [String]) {
22 | self.a = tuple[0]
23 | self.b = tuple[1]
24 | }
25 |
26 | static func == (lhs: BytePair, rhs: BytePair) -> Bool {
27 | return lhs.a == rhs.a && lhs.b == rhs.b
28 | }
29 | func hash(into hasher: inout Hasher) {
30 | hasher.combine(a)
31 | hasher.combine(b)
32 | }
33 | }
34 |
35 | extension String {
36 | fileprivate func ranges(of string: String, options: CompareOptions = .regularExpression)
37 |         -> [Range<String.Index>]
38 | {
39 |         var result: [Range<String.Index>] = []
40 | var start = startIndex
41 | while let range = range(of: string, options: options, range: start ..< endIndex) {
42 | result.append(range)
43 | start =
44 | range.lowerBound < range.upperBound
45 | ? range.upperBound
46 | : index(range.lowerBound, offsetBy: 1, limitedBy: endIndex) ?? endIndex
47 | }
48 | return result
49 | }
50 | }
51 |
52 | class CLIPTokenizer {
53 | let bpeRanks: [BytePair: Int]
54 | private let encoder: [String: Int]
55 | private let decoder: [Int: String]
56 | let contextLength = 77
57 |
58 | init() {
59 |
60 | let url = Bundle.main.url(forResource: "clip-merges", withExtension: "txt")!
61 |
62 | let bpeMergesTxt = try! String(contentsOf: url)
63 | let arr = bpeMergesTxt.split(separator: "\n").map { String($0) }
64 | var bpeRanks: [BytePair: Int] = [:]
65 | for i in 1 ..< arr.count {
66 | let tuple = arr[i].split(separator: " ").map { String($0) }
67 | let bp = BytePair(tuple: tuple)
68 | bpeRanks[bp] = i - 1
69 | }
70 | self.bpeRanks = bpeRanks
71 |
72 | self.encoder = {
73 | let url = Bundle.main.url(forResource: "clip-vocab", withExtension: "json")!
74 | let json = try! Data(contentsOf: url)
75 | let decoder = JSONDecoder()
76 | let vocab = try! decoder.decode([String: Int].self, from: json)
77 | return vocab
78 | }()
79 |
80 | self.decoder = Utils.invert(self.encoder)
81 | }
82 |
83 | func byteEncode(text: String) -> [String] {
84 | let RE =
85 | "<\\|startoftext\\|>|<\\|endoftext\\|>|'s|'t|'re|'ve|'m|'ll|'d|[\\p{L}]+|[\\p{N}]|[^\\s\\p{L}\\p{N}]+"
86 |
87 | // Original code not working on earlier iOS versions
88 | // let tokens = text.ranges(of: RE).map { String(text[$0]) }
89 | // return tokens.map { (token) -> String in
90 | // return Array(token.utf8).map { byteEncoder[$0]! }.joined()
91 | // }
92 |
93 | // Modification by Hugues Thomas
94 | let regex = try! NSRegularExpression(pattern: RE, options: [])
95 | let matches = regex.matches(
96 | in: text, options: [], range: NSRange(location: 0, length: text.utf16.count))
97 | let tokens = matches.map { (match) -> String in
98 | let range = Range(match.range, in: text)!
99 | return String(text[range])
100 | }
101 | return tokens.map { (token) -> String in
102 | return Array(token.utf8).map { byteEncoder[$0]! }.joined()
103 | }
104 |
105 | }
106 |
107 |     private func getPairs(word: [String]) -> Set<BytePair> {
108 |         var s = Set<BytePair>()
109 | for i in 0 ..< word.count - 1 {
110 | let bp = BytePair(
111 | word[i],
112 | word[i + 1]
113 | )
114 | s.insert(bp)
115 | }
116 | return s
117 | }
118 |
119 | func bpe(token: String) -> String {
120 | if token.count <= 1 {
121 | return token + ""
122 | }
123 |
124 | var word = Array(token).map { String($0) }
125 | let last = (word.last ?? "") + ""
126 | word.removeLast()
127 | word.append(last)
128 | var pairs = Array(getPairs(word: word))
129 | if pairs.isEmpty {
130 | return token + ""
131 | }
132 |
133 | while true {
134 | let bigrams = pairs.filter { (bp) -> Bool in bpeRanks[bp] != nil }
135 | if bigrams.count == 0 {
136 | break
137 | }
138 | let bigram = bigrams.min { (bp1, bp2) -> Bool in
139 | return bpeRanks[bp1]! < bpeRanks[bp2]!
140 | }!
141 | let first = bigram.a
142 | let second = bigram.b
143 | var newWord: [String] = []
144 | var i = 0
145 | while i < word.count {
146 | if let j = word[i ..< word.count].firstIndex(of: first) {
147 | newWord.append(contentsOf: word[i ..< j])
148 | i = j
149 | } else {
150 | newWord.append(contentsOf: word[i ..< word.count])
151 | break
152 | }
153 |
154 | if word[i] == first && i < word.count - 1 && word[i + 1] == second {
155 | newWord.append(first + second)
156 | i += 2
157 | } else {
158 | newWord.append(word[i])
159 | i += 1
160 | }
161 | }
162 | word = newWord
163 | if word.count == 1 {
164 | break
165 | } else {
166 | pairs = Array(getPairs(word: word))
167 | }
168 | }
169 | return word.joined(separator: " ")
170 | }
171 |
172 | func tokenize(text: String) -> [String] {
173 | var tokens: [String] = []
174 | let lowercased = text.lowercased()
175 | for token in self.byteEncode(text: lowercased) {
176 | let xx = self.bpe(token: token).split(separator: " ").map { String($0) }
177 | tokens.append(contentsOf: xx)
178 | }
179 | return tokens
180 | }
181 |
182 | /// Main entry point
183 | func encode(text: String) -> [Int] {
184 | return tokenize(text: text).compactMap { encoder[$0] }
185 | }
186 |
187 | /// Decode
188 | func decode(tokens: [Int]) -> String {
189 | let text = tokens.map { decoder[$0]! }.joined(separator: "")
190 | let utfCodepoints = text.map { byteDecoder[String($0)]! }
191 | return String(decoding: utfCodepoints, as: UTF8.self)
192 | }
193 |
194 | func encode_full(text: String) -> [Int] {
195 | let tokens = encode(text: text)
196 |
197 | // Create the full input tokens as a multiarray of shape 1 x contextLength
198 | var fullTokens = Array(repeating: 0, count: contextLength)
199 | fullTokens[0] = encoder["<|startoftext|>"]!
200 | for i in 0 ..< tokens.count {
201 | fullTokens[i + 1] = tokens[i]
202 | }
203 | fullTokens[tokens.count + 1] = encoder["<|endoftext|>"]!
204 | return fullTokens
205 |
206 | }
207 | }
208 |
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore/Tokenizer/GPT2ByteEncoder.swift:
--------------------------------------------------------------------------------
1 | //
2 | // GPT2ByteEncoder.swift
3 | // CoreMLBert
4 | //
5 | // Created by Julien Chaumond on 18/07/2019.
6 | // Copyright © 2019 Hugging Face. All rights reserved.
7 | //
8 | // See https://github.com/huggingface/swift-coreml-transformers
9 |
10 | import Foundation
11 |
12 | let byteEncoder: [UTF8.CodeUnit: String] = [
13 | 33: "!",
14 | 34: "\"",
15 | 35: "#",
16 | 36: "$",
17 | 37: "%",
18 | 38: "&",
19 | 39: "'",
20 | 40: "(",
21 | 41: ")",
22 | 42: "*",
23 | 43: "+",
24 | 44: ",",
25 | 45: "-",
26 | 46: ".",
27 | 47: "/",
28 | 48: "0",
29 | 49: "1",
30 | 50: "2",
31 | 51: "3",
32 | 52: "4",
33 | 53: "5",
34 | 54: "6",
35 | 55: "7",
36 | 56: "8",
37 | 57: "9",
38 | 58: ":",
39 | 59: ";",
40 | 60: "<",
41 | 61: "=",
42 | 62: ">",
43 | 63: "?",
44 | 64: "@",
45 | 65: "A",
46 | 66: "B",
47 | 67: "C",
48 | 68: "D",
49 | 69: "E",
50 | 70: "F",
51 | 71: "G",
52 | 72: "H",
53 | 73: "I",
54 | 74: "J",
55 | 75: "K",
56 | 76: "L",
57 | 77: "M",
58 | 78: "N",
59 | 79: "O",
60 | 80: "P",
61 | 81: "Q",
62 | 82: "R",
63 | 83: "S",
64 | 84: "T",
65 | 85: "U",
66 | 86: "V",
67 | 87: "W",
68 | 88: "X",
69 | 89: "Y",
70 | 90: "Z",
71 | 91: "[",
72 | 92: "\\",
73 | 93: "]",
74 | 94: "^",
75 | 95: "_",
76 | 96: "`",
77 | 97: "a",
78 | 98: "b",
79 | 99: "c",
80 | 100: "d",
81 | 101: "e",
82 | 102: "f",
83 | 103: "g",
84 | 104: "h",
85 | 105: "i",
86 | 106: "j",
87 | 107: "k",
88 | 108: "l",
89 | 109: "m",
90 | 110: "n",
91 | 111: "o",
92 | 112: "p",
93 | 113: "q",
94 | 114: "r",
95 | 115: "s",
96 | 116: "t",
97 | 117: "u",
98 | 118: "v",
99 | 119: "w",
100 | 120: "x",
101 | 121: "y",
102 | 122: "z",
103 | 123: "{",
104 | 124: "|",
105 | 125: "}",
106 | 126: "~",
107 | 161: "\u{00a1}",
108 | 162: "\u{00a2}",
109 | 163: "\u{00a3}",
110 | 164: "\u{00a4}",
111 | 165: "\u{00a5}",
112 | 166: "\u{00a6}",
113 | 167: "\u{00a7}",
114 | 168: "\u{00a8}",
115 | 169: "\u{00a9}",
116 | 170: "\u{00aa}",
117 | 171: "\u{00ab}",
118 | 172: "\u{00ac}",
119 | 174: "\u{00ae}",
120 | 175: "\u{00af}",
121 | 176: "\u{00b0}",
122 | 177: "\u{00b1}",
123 | 178: "\u{00b2}",
124 | 179: "\u{00b3}",
125 | 180: "\u{00b4}",
126 | 181: "\u{00b5}",
127 | 182: "\u{00b6}",
128 | 183: "\u{00b7}",
129 | 184: "\u{00b8}",
130 | 185: "\u{00b9}",
131 | 186: "\u{00ba}",
132 | 187: "\u{00bb}",
133 | 188: "\u{00bc}",
134 | 189: "\u{00bd}",
135 | 190: "\u{00be}",
136 | 191: "\u{00bf}",
137 | 192: "\u{00c0}",
138 | 193: "\u{00c1}",
139 | 194: "\u{00c2}",
140 | 195: "\u{00c3}",
141 | 196: "\u{00c4}",
142 | 197: "\u{00c5}",
143 | 198: "\u{00c6}",
144 | 199: "\u{00c7}",
145 | 200: "\u{00c8}",
146 | 201: "\u{00c9}",
147 | 202: "\u{00ca}",
148 | 203: "\u{00cb}",
149 | 204: "\u{00cc}",
150 | 205: "\u{00cd}",
151 | 206: "\u{00ce}",
152 | 207: "\u{00cf}",
153 | 208: "\u{00d0}",
154 | 209: "\u{00d1}",
155 | 210: "\u{00d2}",
156 | 211: "\u{00d3}",
157 | 212: "\u{00d4}",
158 | 213: "\u{00d5}",
159 | 214: "\u{00d6}",
160 | 215: "\u{00d7}",
161 | 216: "\u{00d8}",
162 | 217: "\u{00d9}",
163 | 218: "\u{00da}",
164 | 219: "\u{00db}",
165 | 220: "\u{00dc}",
166 | 221: "\u{00dd}",
167 | 222: "\u{00de}",
168 | 223: "\u{00df}",
169 | 224: "\u{00e0}",
170 | 225: "\u{00e1}",
171 | 226: "\u{00e2}",
172 | 227: "\u{00e3}",
173 | 228: "\u{00e4}",
174 | 229: "\u{00e5}",
175 | 230: "\u{00e6}",
176 | 231: "\u{00e7}",
177 | 232: "\u{00e8}",
178 | 233: "\u{00e9}",
179 | 234: "\u{00ea}",
180 | 235: "\u{00eb}",
181 | 236: "\u{00ec}",
182 | 237: "\u{00ed}",
183 | 238: "\u{00ee}",
184 | 239: "\u{00ef}",
185 | 240: "\u{00f0}",
186 | 241: "\u{00f1}",
187 | 242: "\u{00f2}",
188 | 243: "\u{00f3}",
189 | 244: "\u{00f4}",
190 | 245: "\u{00f5}",
191 | 246: "\u{00f6}",
192 | 247: "\u{00f7}",
193 | 248: "\u{00f8}",
194 | 249: "\u{00f9}",
195 | 250: "\u{00fa}",
196 | 251: "\u{00fb}",
197 | 252: "\u{00fc}",
198 | 253: "\u{00fd}",
199 | 254: "\u{00fe}",
200 | 255: "\u{00ff}",
201 | 0: "\u{0100}",
202 | 1: "\u{0101}",
203 | 2: "\u{0102}",
204 | 3: "\u{0103}",
205 | 4: "\u{0104}",
206 | 5: "\u{0105}",
207 | 6: "\u{0106}",
208 | 7: "\u{0107}",
209 | 8: "\u{0108}",
210 | 9: "\u{0109}",
211 | 10: "\u{010a}",
212 | 11: "\u{010b}",
213 | 12: "\u{010c}",
214 | 13: "\u{010d}",
215 | 14: "\u{010e}",
216 | 15: "\u{010f}",
217 | 16: "\u{0110}",
218 | 17: "\u{0111}",
219 | 18: "\u{0112}",
220 | 19: "\u{0113}",
221 | 20: "\u{0114}",
222 | 21: "\u{0115}",
223 | 22: "\u{0116}",
224 | 23: "\u{0117}",
225 | 24: "\u{0118}",
226 | 25: "\u{0119}",
227 | 26: "\u{011a}",
228 | 27: "\u{011b}",
229 | 28: "\u{011c}",
230 | 29: "\u{011d}",
231 | 30: "\u{011e}",
232 | 31: "\u{011f}",
233 | 32: "\u{0120}",
234 | 127: "\u{0121}",
235 | 128: "\u{0122}",
236 | 129: "\u{0123}",
237 | 130: "\u{0124}",
238 | 131: "\u{0125}",
239 | 132: "\u{0126}",
240 | 133: "\u{0127}",
241 | 134: "\u{0128}",
242 | 135: "\u{0129}",
243 | 136: "\u{012a}",
244 | 137: "\u{012b}",
245 | 138: "\u{012c}",
246 | 139: "\u{012d}",
247 | 140: "\u{012e}",
248 | 141: "\u{012f}",
249 | 142: "\u{0130}",
250 | 143: "\u{0131}",
251 | 144: "\u{0132}",
252 | 145: "\u{0133}",
253 | 146: "\u{0134}",
254 | 147: "\u{0135}",
255 | 148: "\u{0136}",
256 | 149: "\u{0137}",
257 | 150: "\u{0138}",
258 | 151: "\u{0139}",
259 | 152: "\u{013a}",
260 | 153: "\u{013b}",
261 | 154: "\u{013c}",
262 | 155: "\u{013d}",
263 | 156: "\u{013e}",
264 | 157: "\u{013f}",
265 | 158: "\u{0140}",
266 | 159: "\u{0141}",
267 | 160: "\u{0142}",
268 | 173: "\u{0143}",
269 | ]
270 |
271 | let byteDecoder = Utils.invert(byteEncoder)
272 |
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore/Tokenizer/Utils.swift:
--------------------------------------------------------------------------------
1 | //
2 | // Utils.swift
3 | // AudioBoloss
4 | //
5 | // Created by Julien Chaumond on 07/01/2019.
6 | // Copyright © 2019 Hugging Face. All rights reserved.
7 | //
8 | // See https://github.com/huggingface/swift-coreml-transformers
9 |
10 | import Foundation
11 |
12 | struct Utils {
13 | /// Time a block in ms
14 |     static func time<T>(label: String, _ block: () -> T) -> T {
15 | let startTime = CFAbsoluteTimeGetCurrent()
16 | let result = block()
17 | let diff = (CFAbsoluteTimeGetCurrent() - startTime) * 1_000
18 | print("[\(label)] \(diff)ms")
19 | return result
20 | }
21 |
22 | /// Time a block in seconds and return (output, time)
23 |     static func time<T>(_ block: () -> T) -> (T, Double) {
24 | let startTime = CFAbsoluteTimeGetCurrent()
25 | let result = block()
26 | let diff = CFAbsoluteTimeGetCurrent() - startTime
27 | return (result, diff)
28 | }
29 |
30 | /// Return unix timestamp in ms
31 | static func dateNow() -> Int64 {
32 | // Use `Int` when we don't support 32-bits devices/OSes anymore.
33 | // Int crashes on iPhone 5c.
34 | return Int64(Date().timeIntervalSince1970 * 1000)
35 | }
36 |
37 | /// Clamp a val to [min, max]
38 |     static func clamp<T: Comparable>(_ val: T, _ vmin: T, _ vmax: T) -> T {
39 | return min(max(vmin, val), vmax)
40 | }
41 |
42 | /// Fake func that can throw.
43 |     static func fakeThrowable<T>(_ input: T) throws -> T {
44 | return input
45 | }
46 |
47 | /// Substring
48 |     static func substr(_ s: String, _ r: Range<Int>) -> String? {
49 | let stringCount = s.count
50 | if stringCount < r.upperBound || stringCount < r.lowerBound {
51 | return nil
52 | }
53 | let startIndex = s.index(s.startIndex, offsetBy: r.lowerBound)
54 | let endIndex = s.index(startIndex, offsetBy: r.upperBound - r.lowerBound)
55 | return String(s[startIndex ..< endIndex])
56 | }
57 |
58 | /// Invert a (k, v) dictionary
59 |     static func invert<K: Hashable, V: Hashable>(_ dict: [K: V]) -> [V: K] {
60 | var inverted: [V: K] = [:]
61 | for (k, v) in dict {
62 | inverted[v] = k
63 | }
64 | return inverted
65 | }
66 | }
67 |
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore/Views/Helpers/PredictionsTable.swift:
--------------------------------------------------------------------------------
1 | //
2 | // For licensing see accompanying LICENSE file.
3 | // Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | //
5 |
6 | import Foundation
7 | import SwiftUI
8 |
9 | // MARK: - Predictions Results Table
10 | /// Table showing predictions in sorted order
11 | struct PredictionsTable: View {
12 |
13 | let displayPredictions: [DisplayPrediction]
14 |
15 | @Environment(\.horizontalSizeClass) private var horizontalSizeClass
16 |
17 | private var isCompact: Bool { horizontalSizeClass == .compact }
18 |
19 | /// sort keys for classes
20 | @State private var sortCosineSimilarity = true
21 | @State private var sortOrder = [
22 | KeyPathComparator(\DisplayPrediction.cosineSimilarity, order: .reverse)
23 | ]
24 |
25 | @Environment(\.colorScheme) var colorScheme
26 |
27 | var body: some View {
28 | VStack(spacing: 1) {
29 |
30 | let predictions = displayPredictions.sorted(using: sortOrder)
31 | let formatter = DisplayPredictionFormatter(
32 | predictions: predictions, colorScheme: colorScheme)
33 |
34 | HStack {
35 | Text("CLASSNAME")
36 | Spacer()
37 | Text("COSINE SIMILARITY")
38 | }
39 | .font(.footnote)
40 | .fontWeight(.semibold)
41 | .padding(.horizontal, 23.0)
42 | .padding(.bottom, 10)
43 | .foregroundStyle(.secondary)
44 |
45 | Divider().overlay(.gray)
46 |
47 | List(predictions) { prediction in
48 |
49 | // a prediction in a row
50 | HStack {
51 | Text(prediction.className)
52 | Spacer()
53 | Text(prediction.cosineSimilarity.formatted())
54 | }
55 | .bold(formatter.isBold(prediction))
56 | .foregroundStyle(formatter.color(prediction))
57 | }
58 | .listStyle(PlainListStyle())
59 | }
60 | }
61 | }
62 |
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore/Views/Helpers/PromptPreview.swift:
--------------------------------------------------------------------------------
1 | //
2 | // For licensing see accompanying LICENSE file.
3 | // Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | //
5 |
6 | import Foundation
7 | import SwiftUI
8 |
9 | /// Displays the current prompt formatted with an arrow
10 | struct PromptPreview: View {
11 |
12 | let prompt: Prompt
13 |
14 | var body: some View {
15 | HStack {
16 | Image(systemName: "arrow.forward")
17 | Text("\(prompt.prefix) ") + Text("CLASSNAME").underline() + Text(" \(prompt.suffix)")
18 | }
19 | .lineLimit(1)
20 | .padding(.bottom, 16)
21 | .frame(maxWidth: .infinity, alignment: .leading)
22 | .padding(.top, 4)
23 | }
24 | }
25 |
26 | #Preview {
27 | VStack(alignment: .leading) {
28 | PromptPreview(prompt: .init(prefix: "A photo of", suffix: "", classNames: []))
29 |
30 | PromptPreview(
31 | prompt: .init(
32 | prefix: "A photo of", suffix: "with some longer thing at the end", classNames: []))
33 |
34 | }
35 | .padding()
36 | }
37 |
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore/Views/Helpers/VideoFrameView.swift:
--------------------------------------------------------------------------------
1 | //
2 | // For licensing see accompanying LICENSE file.
3 | // Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | //
5 |
6 | import Foundation
7 | import SwiftUI
8 |
9 | /// Displays a stream of video frames
10 | struct VideoFrameView: View {
11 |
12 |     let frames: AsyncStream<CVImageBuffer>
13 |
14 | @Binding var backCamera: Bool
15 |
16 | @State private var videoFrame: CVImageBuffer?
17 |
18 | var body: some View {
19 | Group {
20 | if let videoFrame {
21 | // display the image, cropped to a square, with rounded corners
22 | _ImageView(image: videoFrame)
23 | .clipShape(RoundedRectangle(cornerRadius: 20))
24 | .allowsHitTesting(false)
25 | .padding(.horizontal, 2)
26 |
27 | // control to flip to front/back facing camera
28 | .overlay(alignment: .bottomTrailing) {
29 | Button(action: toggleCamera) {
30 | Image(systemName: "arrow.triangle.2.circlepath.circle")
31 | .foregroundStyle(.blue)
32 | .frame(width: 20, height: 20)
33 | .padding(20)
34 | }
35 | .buttonStyle(.plain)
36 | }
37 |
38 | } else {
39 | // spinner before the camera comes up
40 | ProgressView()
41 | .controlSize(.large)
42 | }
43 | }
44 | .task {
45 | // feed frames to the _ImageView
46 | if Task.isCancelled {
47 | return
48 | }
49 | for await frame in frames {
50 | self.videoFrame = frame
51 | }
52 | }
53 | }
54 |
55 | func toggleCamera() {
56 | backCamera.toggle()
57 | }
58 | }
59 |
60 | /// Internal view to display a CVImageBuffer
61 | private struct _ImageView: UIViewRepresentable {
62 |
63 | let image: Any
64 | var gravity = CALayerContentsGravity.center
65 |
66 | func makeUIView(context: Context) -> UIView {
67 | let view = UIView()
68 | view.layer.contentsGravity = gravity
69 | return view
70 | }
71 |
72 | func updateUIView(_ uiView: UIView, context: Context) {
73 | uiView.layer.contents = image
74 | }
75 | }
76 |
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore/Views/InfoView.swift:
--------------------------------------------------------------------------------
1 | //
2 | // For licensing see accompanying LICENSE file.
3 | // Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | //
5 |
6 | import Foundation
7 | import SwiftUI
8 |
9 | /// About MobileCLIP, model options and credits
10 | struct InfoView: View {
11 |
12 | var modelOptionsData = [
13 | (name: "MobileCLIP-S0", latency: "2.1", accuracy: "58.1"),
14 | (name: "MobileCLIP-S1", latency: "5.8", accuracy: "61.3"),
15 | (name: "MobileCLIP-S2", latency: "6.9", accuracy: "63.7"),
16 | (name: "MobileCLIP-B (LT)", latency: "13.7", accuracy: "65.8"),
17 | ]
18 |
19 | var body: some View {
20 | VStack(spacing: 10) {
21 | Text("Information")
22 | .font(.headline)
23 | .bold()
24 | .padding(.top)
25 |
26 | // About MobileCLIP
27 | Group {
28 | Text("MobileCLIP")
29 | .font(.largeTitle)
30 | .bold()
31 | .padding(.top)
32 |
33 | Group {
34 | Text("MobileCLIP¹")
35 | .fontWeight(.bold)
36 | + Text(
37 | " is a new family of efficient image-text models optimized for runtime performance, trained with a novel and efficient training approach, namely multi-model reinforced training."
38 | )
39 |
40 | Text(
41 | "This app demonstrates the use of **MobileCLIP** models for performing real-time zero-shot scene classification. Users are free to customize the prompt and provide classnames of their choice."
42 | )
43 | }
44 | .padding(.bottom)
45 | }
46 | .frame(maxWidth: .infinity, alignment: .leading)
47 | .padding(.horizontal)
48 |
49 |
50 | // Model Options Table
51 | Group {
52 | Text("Model Options")
53 | .font(.title2)
54 | .bold()
55 |
56 | Text("You can select to run any of the following MobileCLIP model options:")
57 | .padding(.bottom)
58 | }
59 | .frame(maxWidth: .infinity, alignment: .leading)
60 | .padding(.horizontal)
61 |
62 | Group {
63 | HStack {
64 | Text("Name")
65 | Spacer()
66 | .frame(width: 105, alignment: .leading)
67 | Text("Ideal Latency")
68 | Spacer()
69 | Text("Accuracy")
70 | }
71 | .font(.headline)
72 | .fontWeight(.bold)
73 | .padding(.horizontal)
74 |
75 | Divider()
76 |
77 | ForEach(modelOptionsData, id: \.name) { option in
78 | HStack {
79 | Text(option.name)
80 | .fontWeight(.semibold)
81 | .frame(width: 125, alignment: .leading)
82 | Spacer()
83 | Text("\(option.latency) ms")
84 | .fontWeight(.light)
85 | Spacer()
86 | Text("\(option.accuracy)%")
87 | .fontWeight(.light)
88 | }
89 | .padding(.vertical, 5)
90 | .padding(.horizontal)
91 | Divider()
92 | }
93 | }
94 |
95 | // Authors and Citations
96 | VStack(alignment: .leading, spacing: 8) {
97 | HStack {
98 | Spacer().frame(width: 12)
99 | Text(
100 | "**¹ MobileCLIP: Fast Image-Text Models through Multi-Modal Reinforced Training.** (CVPR 2024)"
101 | )
102 | Spacer().frame(width: 12)
103 | }
104 | Text(
105 | "Pavan Kumar Anasosalu Vasu, Hadi Pour Ansari, Fartash Faghri, Raviteja Vemulapalli, Oncel Tuzel."
106 | )
107 | .padding(.horizontal, 12)
108 |
109 | }
110 | .foregroundColor(.secondary)
111 | .font(.system(size: 12))
112 |
113 | Spacer()
114 | }
115 | .textSelection(.enabled)
116 | .font(.system(size: 16))
117 | .padding(.bottom)
118 | .frame(maxWidth: .infinity, maxHeight: .infinity)
119 | }
120 |
121 | // MARK: - View Modifiers
122 | struct AboutPanel: ViewModifier {
123 | func body(content: Content) -> some View {
124 | content
125 | .foregroundColor(.secondary)
126 | .padding(.horizontal, 15)
127 | .padding(.vertical, 30)
128 | .textSelection(.enabled)
129 | }
130 | }
131 | }
132 |
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore/Views/PromptEditor.swift:
--------------------------------------------------------------------------------
1 | //
2 | // For licensing see accompanying LICENSE file.
3 | // Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | //
5 |
6 | import Foundation
7 | import SwiftUI
8 |
9 | // MARK: - Field Views and Layout
10 | /// The primary label of a field
11 | private struct FieldTitle: View {
12 |
13 | let title: String
14 |
15 | var body: some View {
16 | Text(title.uppercased())
17 | .font(.caption)
18 | .bold()
19 | }
20 | }
21 |
22 | /// A field with primary and secondary labels, e.g. shown when editing the `start of a prompt`
23 | private struct FieldView<V: View>: View {
24 |
25 | let title: String
26 | let help: String
27 | let contents: V
28 |
29 | internal init(title: String, help: String, @ViewBuilder contents: () -> V) {
30 | self.title = title
31 | self.help = help
32 | self.contents = contents()
33 | }
34 |
35 | var body: some View {
36 | VStack(alignment: .leading) {
37 | FieldTitle(title: title)
38 | .padding(.horizontal, 10)
39 |
40 | contents
41 | .padding()
42 | .background(Color("PromptEditor"))
43 | .clipShape(RoundedRectangle(cornerRadius: 8))
44 |
45 | Text(help)
46 | .modifier(SecondaryLabel())
47 | .padding(.horizontal, 10)
48 | }
49 | .padding(.bottom, 15)
50 | }
51 |
52 | }
53 |
54 | // MARK: - Configure Prompt View
55 | /// Displays a preview, instructions and fields needed to configure a prompt
56 | struct PromptEditor: View {
57 |
58 | @Binding var prompt: Prompt
59 |
60 | var body: some View {
61 | VStack {
62 | ScrollView {
63 | Group {
64 | PromptPreview(prompt: prompt)
65 |
66 | Divider()
67 |
68 | Text(
69 | """
70 |                         Configure the start and end of the prompt. A preview of the fully composed prompt for the text encoder is shown above. To customize classnames, tap the Classname Selection button.
71 | """.replacingOccurrences(of: "\n", with: " ")
72 | )
73 | .modifier(SecondaryLabel())
74 | .padding(.bottom, 30)
75 | }
76 |
77 | VStack(spacing: 8) {
78 | FieldView(
79 | title: "Start of prompt",
80 | help: "The start of the prompt. Can be left blank."
81 | ) {
82 | TextField("No prefix", text: $prompt.prefix)
83 | }
84 |
85 | FieldView(
86 | title: "Classnames",
87 | help: "Classnames that will be scored."
88 | ) {
89 |
90 | // Classname Selection Button
91 | NavigationLink {
92 | ClassnamesEditor(classnames: $prompt.classNames)
93 |
94 | } label: {
95 | HStack {
96 | if !prompt.classNames.isEmpty {
97 |
98 | // Display added classnames
99 | Text(prompt.classNames.joined(separator: ", "))
100 | .lineLimit(1)
101 |
102 | } else {
103 | // Label shown if no classnames have been added
104 | Text("Classname Selection")
105 |
106 | }
107 | Spacer()
108 | Image(systemName: "pencil")
109 | .padding(.vertical)
110 | }
111 | }
112 | }
113 |
114 | FieldView(
115 | title: "End of prompt",
116 | help: "The end of the prompt. Can be left blank."
117 | ) {
118 | TextField("No suffix", text: $prompt.suffix)
119 | }
120 | }
121 | Spacer()
122 | }
123 | }
124 | .navigationTitle("Configure Prompt")
125 | .navigationBarTitleDisplayMode(.large)
126 | .padding(EdgeInsets(top: 10, leading: 15, bottom: 10, trailing: 15))
127 | .background(Color("PromptEditorBackground"))
128 | }
129 | }
130 |
131 | // MARK: - Manage Classnames View
132 | /// Displays the current Classnames, and instructs the user how to add or remove them
133 | struct ClassnamesEditor: View {
134 |
135 | @Binding var classnames: [String]
136 |
137 | @State private var className = ""
138 |
139 | var body: some View {
140 | VStack(alignment: .leading) {
141 | HStack {
142 | TextField("Add a new classname", text: $className)
143 | .textInputAutocapitalization(.never)
144 | .onSubmit(addClassName)
145 | .padding()
146 | .background(Color("PromptEditor"))
147 | .clipShape(RoundedRectangle(cornerRadius: 8))
148 |
149 | Button(action: { addClassName() }) {
150 | Image(systemName: "plus.circle")
151 | .foregroundStyle(Color("AddClassName"))
152 | .padding()
153 | }
154 | }
155 | .padding(.bottom, 16)
156 |
157 | Divider()
158 |
159 | Text("Add the classnames that you would like the app to detect.")
160 | .modifier(SecondaryLabel())
161 | .padding(.bottom, 16)
162 |
163 | // Display classnames as they are added
164 | if !classnames.isEmpty {
165 | Group {
166 | FieldTitle(title: "Current Classnames")
167 |
168 | Text("Swipe left to remove items.")
169 | .modifier(SecondaryLabel())
170 | }
171 | .padding(.horizontal, 10)
172 |
173 | List {
174 | ForEach(classnames, id: \.self) { className in
175 |
176 | // A classname in a row
177 | HStack {
178 | Text(className)
179 | Spacer()
180 | }
181 | .listRowBackground(Color("PromptEditor"))
182 | }
183 | .onDelete { indexes in
184 | classnames.remove(atOffsets: indexes)
185 | }
186 | }
187 | .ignoresSafeArea(.keyboard)
188 | .environment(\.defaultMinListRowHeight, 25)
189 | .scrollContentBackground(.hidden)
190 | } else {
191 | Spacer()
192 | }
193 |
194 | }
195 | .onAppear {
196 | classnames.sort { $0.lowercased() < $1.lowercased() }
197 | }
198 | .navigationTitle("Manage Classnames")
199 | .navigationBarTitleDisplayMode(.large)
200 | .padding()
201 | .background(Color("PromptEditorBackground"))
202 | }
203 |
204 | private func addClassName() {
205 | let nameToAdd = className.trimmingCharacters(in: .whitespacesAndNewlines)
206 | guard !nameToAdd.isEmpty else { return }
207 | guard !classnames.contains(nameToAdd) else { return }
208 |
209 | withAnimation {
210 | classnames.append(nameToAdd)
211 | classnames.sort { $0.lowercased() < $1.lowercased() }
212 |
213 | className = ""
214 | }
215 | }
216 | }
217 |
218 | // MARK: - View Modifiers
219 | struct SecondaryLabel: ViewModifier {
220 | func body(content: Content) -> some View {
221 | content
222 | .foregroundColor(.secondary)
223 | .font(.caption)
224 | }
225 | }
226 |
227 | // MARK: - Previews
228 | #Preview {
229 | VStack {
230 | FieldView(
231 | title: "Start of prompt",
232 | help: "The start of the prompt. Can be left blank."
233 | ) {
234 | TextField("", text: .constant("A photo of"))
235 | }
236 |
237 | ClassnamesEditor(classnames: .constant(["cat", "dog", "chicken"]))
238 | }
239 | .padding()
240 | .background(Color(UIColor.systemGroupedBackground))
241 | }
242 |
--------------------------------------------------------------------------------
/ios_app/MobileCLIPExplore/ZSImageClassification.swift:
--------------------------------------------------------------------------------
1 | //
2 | // For licensing see accompanying LICENSE file.
3 | // Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | //
5 |
6 | import CoreML
7 | import UIKit
8 |
9 | /// shared tokenizer for all model types
10 | private let tokenizer = AsyncFactory {
11 | CLIPTokenizer()
12 | }
13 |
14 | actor ZSImageClassification: ObservableObject {
15 |
16 | private let ciContext = CIContext()
17 | private var model: any CLIPEncoder
18 |
19 | public init(model: any CLIPEncoder) {
20 | self.model = model
21 | }
22 |
23 | func load() async {
24 | async let t = tokenizer.get()
25 | async let m = model.load()
26 | _ = await (t, m)
27 | }
28 |
29 | public func setModel(_ model: ModelConfiguration) {
30 | self.model = model.factory()
31 | }
32 |
33 | // Compute Text Embeddings
34 | func computeTextEmbeddings(promptArr: [String]) async -> [MLMultiArray] {
35 | var textEmbeddings: [MLMultiArray] = []
36 | do {
37 | for singlePrompt in promptArr {
38 | print("")
39 | print("Prompt text: \(singlePrompt)")
40 |
41 | // Tokenize the text query
42 | let inputIds = await tokenizer.get().encode_full(text: singlePrompt)
43 |
44 | // Convert [Int] to MultiArray
45 | let inputArray = try MLMultiArray(shape: [1, 77], dataType: .int32)
46 | for (index, element) in inputIds.enumerated() {
47 | inputArray[index] = NSNumber(value: element)
48 | }
49 |
50 | // Run the text model on the text query
51 | let output = try await model.encode(text: inputArray)
52 | textEmbeddings.append(output)
53 | }
54 | } catch {
55 | print(error.localizedDescription)
56 | }
57 | return textEmbeddings
58 | }
59 |
60 | // Compute Image Embeddings
61 | func computeImageEmbeddings(frame: CVPixelBuffer) async -> (
62 | embedding: MLMultiArray, interval: CFTimeInterval
63 | )? {
64 | // prepare the image
65 | var image: CIImage? = CIImage(cvPixelBuffer: frame)
66 | image = image?.cropToSquare()
67 | image = image?.resize(size: model.targetImageSize)
68 |
69 | guard let image else {
70 | return nil
71 | }
72 |
73 | // output buffer
74 | let extent = image.extent
75 | let pixelFormat = kCVPixelFormatType_32ARGB
76 | var output: CVPixelBuffer?
77 | CVPixelBufferCreate(nil, Int(extent.width), Int(extent.height), pixelFormat, nil, &output)
78 |
79 | guard let output else {
80 | print("failed to create output CVPixelBuffer")
81 | return nil
82 | }
83 |
84 | ciContext.render(image, to: output)
85 |
86 | // Run image embedding
87 | do {
88 | let startTimer = CACurrentMediaTime()
89 | let output = try await model.encode(image: output)
90 | let endTimer = CACurrentMediaTime()
91 | let interval = endTimer - startTimer
92 | return (embedding: output, interval: interval)
93 | } catch {
94 | print(error.localizedDescription)
95 | return nil
96 | }
97 | }
98 |
99 | // Compute cosine similarity between embeddings
100 | nonisolated func cosineSimilarity(_ embedding1: MLMultiArray, _ embedding2: MLMultiArray)
101 | -> Float
102 | {
103 |
104 | // read the values out of the MLMultiArray in bulk
105 | let e1 = embedding1.withUnsafeBufferPointer(ofType: Float.self) { ptr in
106 | Array(ptr)
107 | }
108 | let e2 = embedding2.withUnsafeBufferPointer(ofType: Float.self) { ptr in
109 | Array(ptr)
110 | }
111 |
112 | // Get the dot product of the two embeddings
113 | let dotProduct: Float = zip(e1, e2).reduce(0.0) { $0 + $1.0 * $1.1 }
114 |
115 | // Get the magnitudes of the two embeddings
116 | let magnitude1: Float = sqrt(e1.reduce(0) { $0 + pow($1, 2) })
117 | let magnitude2: Float = sqrt(e2.reduce(0) { $0 + pow($1, 2) })
118 |
119 | // Get the cosine similarity
120 | let similarity = dotProduct / (magnitude1 * magnitude2)
121 | return similarity
122 | }
123 | }
124 |
--------------------------------------------------------------------------------
/ios_app/README.md:
--------------------------------------------------------------------------------
1 | # MobileCLIPExplore
2 | Demonstrates the use of **MobileCLIP** models for performing real-time zero-shot scene classification. Users can customize the prompts and provide classnames of their choice.
3 |
4 | 
5 |
6 | ## Background
7 | MobileCLIP is a new family of efficient image-text models optimized for runtime performance, trained with a novel and efficient training approach, namely multi-modal reinforced training.
8 |
9 | ## Dependencies
10 | Requires **iOS 17.2** or later.
11 |
12 | ## Install
13 | 1. Download the required CoreML models from [this link](https://huggingface.co/apple/coreml-mobileclip)
14 |
15 | 2. Place these models in the following directory:
16 |
17 | `.../MobileCLIPExplore/MobileCLIPExplore/Models`
18 |
19 | 
20 |
21 |
22 | ## Use
23 |
24 | ### Getting Started
25 | 1. Select one of the four MobileCLIP models from the top menu.
26 | 2. Choose a prompt preset, or create a custom one.
27 | 3. Point the camera at a subject to display prediction results in the table below.
28 |
29 | The analyzed frame rate (FPS) and the prediction response time (in milliseconds) are displayed in the top left of the landing view.
30 |
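To give a sense of what runs behind this flow, here is a minimal sketch (not part of the app's source) of the zero-shot pipeline, built from the `ZSImageClassification` actor and the `S2Model` encoder defined in this project; the camera `frame` and the classnames are assumed inputs:

```swift
import CoreML
import CoreVideo

func classify(frame: CVPixelBuffer) async -> [(className: String, score: Float)] {
    let classNames = ["pen", "keyboard", "cup"]
    let prompts = classNames.map { "A photo of \($0)" }

    // Wrap a MobileCLIP encoder (here MobileCLIP-S2) together with the shared tokenizer.
    let classifier = ZSImageClassification(model: S2Model())
    await classifier.load()

    // One text embedding per prompt, one image embedding for the current frame.
    let textEmbeddings = await classifier.computeTextEmbeddings(promptArr: prompts)
    guard let image = await classifier.computeImageEmbeddings(frame: frame) else { return [] }

    // Rank classnames by cosine similarity between the image and text embeddings.
    return zip(classNames, textEmbeddings)
        .map { pair in
            (className: pair.0, score: classifier.cosineSimilarity(image.embedding, pair.1))
        }
        .sorted { $0.score > $1.score }
}
```
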
31 | ### Creating and Editing a Prompt
32 | 1. To edit the prompt, select the prompt or pencil icon ✎.
33 | 2. Tap the start or end of a prompt to customize the prefix or suffix, respectively.
34 | 3. Use the Classnames field to add or remove items you would like the app to detect.
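
Internally, every classname is combined with the prompt's prefix and suffix before being tokenized and sent to the text encoder. A small sketch (not app code) using the `Prompt` type from `Prompt.swift`:

```swift
// A prompt similar to the built-in "Desk items" preset, trimmed to three classnames.
let prompt = Prompt(
    prefix: "A photo of",
    suffix: "",
    classNames: ["pen", "keyboard", "cup"]
)

// One full prompt per classname is produced for the text encoder, e.g.
// "A photo of pen ", "A photo of keyboard ", "A photo of cup "
let fullPrompts = prompt.fullPrompts()
```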
35 |
36 | ## Model Options
37 | The application allows the user to select from four different MobileCLIP models:
38 |
39 | | Name | Ideal Latency | IN-1k Zero-Shot Top-1 Acc. (%) | Avg. Perf. (%) on 38 datasets |
40 | |-----------------|---------------|--------------------------------|-------------------------------|
41 | | MobileCLIP-S0 | 2.1 ms | 67.8% | 58.1% |
42 | | MobileCLIP-S1 | 5.8 ms | 72.6% | 61.3% |
43 | | MobileCLIP-S2 | 6.9 ms | 74.4% | 63.7% |
44 | | MobileCLIP-B (LT)| 13.7 ms | 77.2% | 65.8% |
45 |
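In code, these options correspond to the `ModelConfiguration` entries in the app's `models` list (`Models.swift`). A rough sketch (not part of the app) of selecting one by name and preparing it for inference:

```swift
func makeClassifier(named name: String) async -> ZSImageClassification {
    // Look up the configuration, falling back to the default (MobileCLIP-S2).
    let configuration = models.first { $0.name == name } ?? defaultModel

    // Build the encoder pair and warm up both Core ML models (image + text).
    let classifier = ZSImageClassification(model: configuration.factory())
    await classifier.load()
    return classifier
}
```
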
46 | ## Privacy
47 | All predictions are processed privately and securely using on-device models.
48 |
--------------------------------------------------------------------------------
/ios_app/docs/app_screenshots/examples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-mobileclip/1140b8d197e4ed7d56b3a92216ded98bb1c2ac87/ios_app/docs/app_screenshots/examples.png
--------------------------------------------------------------------------------
/ios_app/docs/app_screenshots/model_directory_guide.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-mobileclip/1140b8d197e4ed7d56b3a92216ded98bb1c2ac87/ios_app/docs/app_screenshots/model_directory_guide.png
--------------------------------------------------------------------------------
/mobileclip/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 | import os
6 | import json
7 | from typing import Optional, Union, Tuple, Any
8 |
9 | import torch
10 | import torch.nn as nn
11 | from torchvision.transforms import (
12 | CenterCrop,
13 | Compose,
14 | InterpolationMode,
15 | Resize,
16 | ToTensor,
17 | )
18 |
19 | from mobileclip.clip import CLIP
20 | from mobileclip.modules.text.tokenizer import (
21 | ClipTokenizer,
22 | )
23 | from mobileclip.modules.common.mobileone import reparameterize_model
24 |
25 |
26 | def create_model_and_transforms(
27 | model_name: str,
28 | pretrained: Optional[str] = None,
29 | reparameterize: Optional[bool] = True,
30 | device: Union[str, torch.device] = "cpu",
31 | ) -> Tuple[nn.Module, Any, Any]:
32 | """
33 | Method to instantiate model and pre-processing transforms necessary for inference.
34 |
35 | Args:
36 | model_name: Model name. Choose from ['mobileclip_s0', 'mobileclip_s1', 'mobileclip_s2', 'mobileclip_b']
37 | pretrained: Location of pretrained checkpoint.
38 | reparameterize: When set to True, re-parameterizable branches get folded for faster inference.
39 | device: Device identifier for model placement.
40 |
41 | Returns:
42 | Tuple of instantiated model, and preprocessing transforms for inference.
43 | """
44 | # Config files
45 | root_dir = os.path.dirname(os.path.abspath(__file__))
46 | configs_dir = os.path.join(root_dir, "configs")
47 | model_cfg_file = os.path.join(configs_dir, model_name + ".json")
48 |
49 |     # Get config from JSON file
50 | if not os.path.exists(model_cfg_file):
51 | raise ValueError(f"Unsupported model name: {model_name}")
52 | model_cfg = json.load(open(model_cfg_file, "r"))
53 |
54 | # Build preprocessing transforms for inference
55 | resolution = model_cfg["image_cfg"]["image_size"]
56 | resize_size = resolution
57 | centercrop_size = resolution
58 | aug_list = [
59 | Resize(
60 | resize_size,
61 | interpolation=InterpolationMode.BILINEAR,
62 | ),
63 | CenterCrop(centercrop_size),
64 | ToTensor(),
65 | ]
66 | preprocess = Compose(aug_list)
67 |
68 | # Build model
69 | model = CLIP(cfg=model_cfg)
70 | model.to(device)
71 | model.eval()
72 |
73 | # Load checkpoint
74 | if pretrained is not None:
75 | chkpt = torch.load(pretrained)
76 | model.load_state_dict(chkpt)
77 |
78 | # Reparameterize model for inference (if specified)
79 | if reparameterize:
80 | model = reparameterize_model(model)
81 |
82 | return model, None, preprocess
83 |
84 |
85 | def get_tokenizer(model_name: str) -> nn.Module:
86 | # Config files
87 | root_dir = os.path.dirname(os.path.abspath(__file__))
88 | configs_dir = os.path.join(root_dir, "configs")
89 | model_cfg_file = os.path.join(configs_dir, model_name + ".json")
90 |
91 |     # Get config from JSON file
92 | model_cfg = json.load(open(model_cfg_file, "r"))
93 |
94 | # Build tokenizer
95 | text_tokenizer = ClipTokenizer(model_cfg)
96 | return text_tokenizer
97 |
--------------------------------------------------------------------------------
/mobileclip/clip.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 | """ Model schema in open_clip format for inference only. """
6 | import math
7 | from typing import Any, Optional, Dict
8 |
9 | import torch
10 | import torch.nn.functional as F
11 | from torch import nn
12 |
13 | from mobileclip.text_encoder import (
14 | TextTransformer,
15 | )
16 |
17 | from .image_encoder import MCi
18 |
19 |
20 | class CLIP(nn.Module):
21 | """Base class for multi-modal image-text data"""
22 |
23 | def __init__(self, cfg: Dict, output_dict: bool = False, *args, **kwargs) -> None:
24 | super().__init__()
25 | self.output_dict = output_dict
26 | self.projection_dim = cfg["embed_dim"]
27 | if self.projection_dim is None:
28 | raise ValueError("Please specify `embed_dim` in model config.")
29 |
30 | self.image_encoder = MCi(
31 | model_name=cfg["image_cfg"]["model_name"],
32 | projection_dim=self.projection_dim,
33 | )
34 | self.text_encoder = TextTransformer(
35 | cfg=cfg["text_cfg"], projection_dim=self.projection_dim
36 | )
37 | self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1.0 / 0.07))
38 |
39 | def _exponentiate_and_clip_logits(self, max_scale: float = 100.0):
40 | scale = self.logit_scale.exp()
41 | scale = torch.clamp(scale, 0, max_scale)
42 | return scale
43 |
44 | def encode_image(self, image: torch.Tensor, normalize: bool = False):
45 | image_encoder_out = self.image_encoder(image)
46 | if isinstance(image_encoder_out, dict):
47 | features = image_encoder_out["logits"]
48 | else:
49 | features = image_encoder_out
50 | return F.normalize(features, dim=-1) if normalize else features
51 |
52 | def encode_text(self, text: torch.Tensor, normalize: bool = False):
53 | text_features = self.text_encoder(text_tokens=text, key_padding_mask=None)
54 | return F.normalize(text_features, dim=-1) if normalize else text_features
55 |
56 | def forward(
57 | self,
58 | image: Optional[torch.Tensor] = None,
59 | text: Optional[torch.Tensor] = None,
60 | *args,
61 | **kwargs
62 | ) -> Any:
63 |
64 | image_embeddings = (
65 | self.encode_image(image, normalize=True) if image is not None else None
66 | )
67 | text_embeddings = (
68 | self.encode_text(text, normalize=True) if text is not None else None
69 | )
70 |
71 | if self.output_dict:
72 | return {
73 | "image_features": image_embeddings,
74 | "text_features": text_embeddings,
75 | "logit_scale": self._exponentiate_and_clip_logits(),
76 | }
77 | return image_embeddings, text_embeddings, self._exponentiate_and_clip_logits()
78 |
--------------------------------------------------------------------------------
/mobileclip/configs/mobileclip_b.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "image_cfg": {
4 | "image_size": 224,
5 | "model_name": "vit_b16"
6 | },
7 | "text_cfg": {
8 | "context_length": 77,
9 | "vocab_size": 49408,
10 | "dim": 512,
11 | "ffn_multiplier_per_layer": 4.0,
12 | "n_heads_per_layer": 8,
13 | "n_transformer_layers": 12,
14 | "norm_layer": "layer_norm_fp32",
15 | "causal_masking": true,
16 | "model_name": "base"
17 | }
18 | }
19 |
--------------------------------------------------------------------------------
/mobileclip/configs/mobileclip_s0.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "image_cfg": {
4 | "image_size": 256,
5 | "model_name": "mci0"
6 | },
7 | "text_cfg": {
8 | "context_length": 77,
9 | "vocab_size": 49408,
10 | "dim": 512,
11 | "ffn_multiplier_per_layer": 4.0,
12 | "n_heads_per_layer": 8,
13 | "n_transformer_layers": 4,
14 | "norm_layer": "layer_norm_fp32",
15 | "causal_masking": false,
16 | "model_name": "mct"
17 | }
18 | }
19 |
--------------------------------------------------------------------------------
/mobileclip/configs/mobileclip_s1.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "image_cfg": {
4 | "image_size": 256,
5 | "model_name": "mci1"
6 | },
7 | "text_cfg": {
8 | "context_length": 77,
9 | "vocab_size": 49408,
10 | "dim": 512,
11 | "ffn_multiplier_per_layer": 4.0,
12 | "n_heads_per_layer": 8,
13 | "n_transformer_layers": 12,
14 | "norm_layer": "layer_norm_fp32",
15 | "causal_masking": false,
16 | "model_name": "base"
17 | }
18 | }
19 |
--------------------------------------------------------------------------------
/mobileclip/configs/mobileclip_s2.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "image_cfg": {
4 | "image_size": 256,
5 | "model_name": "mci2"
6 | },
7 | "text_cfg": {
8 | "context_length": 77,
9 | "vocab_size": 49408,
10 | "dim": 512,
11 | "ffn_multiplier_per_layer": 4.0,
12 | "n_heads_per_layer": 8,
13 | "n_transformer_layers": 12,
14 | "norm_layer": "layer_norm_fp32",
15 | "causal_masking": false,
16 | "model_name": "base"
17 | }
18 | }
19 |
--------------------------------------------------------------------------------
/mobileclip/image_encoder.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 | from typing import Any
6 |
7 | import torch.nn as nn
8 | from timm.models import create_model
9 |
10 | from mobileclip import models # Added to register models
11 | from mobileclip.modules.image.image_projection import GlobalPool2D
12 |
13 |
14 | class MCi(nn.Module):
15 | """
16 |     This class implements the MCi family of image encoder models.
17 | """
18 |
19 | def __init__(self, model_name: str, *args, **kwargs) -> None:
20 | super().__init__()
21 | self.projection_dim = None
22 | if "projection_dim" in kwargs:
23 | self.projection_dim = kwargs.get("projection_dim")
24 |
25 | # Create model
26 | self.model = create_model(model_name, projection_dim=self.projection_dim)
27 |
28 | # Build out projection head.
29 | if self.projection_dim is not None:
30 | if hasattr(self.model, "head"):
31 | self.model.head = MCi._update_image_classifier(
32 | image_classifier=self.model.head, projection_dim=self.projection_dim
33 | )
34 |
35 | def forward(self, x: Any, *args, **kwargs) -> Any:
36 | """A forward function of the model."""
37 | x = self.model(x)
38 | return x
39 |
40 | @staticmethod
41 | def _get_in_feature_dimension(image_classifier: nn.Module) -> int:
42 | """Return the input feature dimension to the image classification head."""
43 | in_features = None
44 | if isinstance(image_classifier, nn.Sequential):
45 | # Classifier that uses nn.Sequential usually has global pooling and
46 | # multiple linear layers. Find the first linear layer and get its
47 | # in_features
48 | for layer in image_classifier:
49 | if isinstance(layer, nn.Linear):
50 | in_features = layer.in_features
51 | break
52 | elif isinstance(image_classifier, nn.Linear):
53 | in_features = image_classifier.in_features
54 |
55 | if in_features is None:
56 | raise NotImplementedError(
57 | f"Cannot get input feature dimension of {image_classifier}."
58 | )
59 | return in_features
60 |
61 | @staticmethod
62 | def _update_image_classifier(
63 | image_classifier: nn.Module, projection_dim: int, *args, **kwargs
64 | ) -> nn.Module:
65 | in_features = MCi._get_in_feature_dimension(image_classifier)
66 | new_img_classifier = GlobalPool2D(in_dim=in_features, out_dim=projection_dim)
67 | return new_img_classifier
68 |
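
`MCi` is a thin wrapper around `timm.models.create_model`: importing `mobileclip.models` registers the `mci0`/`mci1`/`mci2` backbones, and when `projection_dim` is given the classification head is replaced by a `GlobalPool2D` projection. A minimal usage sketch, assuming the registered `mci0` backbone builds at the 256×256 resolution used in the configs:

```python
import torch
from mobileclip.image_encoder import MCi

# Smallest image tower with a 512-d projection head (matches embed_dim in mobileclip_s0.json).
encoder = MCi(model_name="mci0", projection_dim=512)
encoder.eval()

with torch.no_grad():
    image = torch.randn(1, 3, 256, 256)  # dummy RGB batch at the config resolution
    embedding = encoder(image)

print(embedding.shape)  # expected: torch.Size([1, 512])
```
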
--------------------------------------------------------------------------------
/mobileclip/logger.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | import os
7 | import sys
8 | import time
9 | import traceback
10 | from typing import Optional, Union
11 |
12 | text_colors = {
13 | "logs": "\033[34m", # 033 is the escape code and 34 is the color code
14 | "info": "\033[32m",
15 | "warning": "\033[33m",
16 | "debug": "\033[93m",
17 | "error": "\033[31m",
18 | "bold": "\033[1m",
19 | "end_color": "\033[0m",
20 | "light_red": "\033[36m",
21 | }
22 |
23 |
24 | def get_curr_time_stamp() -> str:
25 | return time.strftime("%Y-%m-%d %H:%M:%S")
26 |
27 |
28 | def error(message: str) -> None:
29 | time_stamp = get_curr_time_stamp()
30 | error_str = (
31 | text_colors["error"]
32 | + text_colors["bold"]
33 | + "ERROR "
34 | + text_colors["end_color"]
35 | )
36 |
37 | # exiting with code -1 does not tell any information about the error (e.g., NaN encountered in the loss).
38 | # For more descriptive error messages, we replace exit(-1) with sys.exit(ERROR_MESSAGE).
39 | # This allows us to handle specific exceptions in the tests.
40 |
41 | # print("{} - {} - {}".format(time_stamp, error_str, message), flush=True)
42 | # print("{} - {} - {}".format(time_stamp, error_str, "Exiting!!!"), flush=True)
43 | # exit(-1)
44 |
45 | if sys.exc_info()[0] is None:
46 | traceback.print_stack()
47 | else:
48 | traceback.print_exc()
49 | sys.exit("{} - {} - {}. Exiting!!!".format(time_stamp, error_str, message))
50 |
51 |
52 | def color_text(in_text: str) -> str:
53 | return text_colors["light_red"] + in_text + text_colors["end_color"]
54 |
55 |
56 | def log(message: str, end="\n") -> None:
57 | time_stamp = get_curr_time_stamp()
58 | log_str = (
59 | text_colors["logs"] + text_colors["bold"] + "LOGS " + text_colors["end_color"]
60 | )
61 | print("{} - {} - {}".format(time_stamp, log_str, message), end=end)
62 |
63 |
64 | def warning(message: Union[str, Warning]) -> None:
65 | if isinstance(message, Warning):
66 |         message = f"{type(message).__name__}({','.join(map(repr, message.args))})"
67 |
68 | time_stamp = get_curr_time_stamp()
69 | warn_str = (
70 | text_colors["warning"]
71 | + text_colors["bold"]
72 | + "WARNING"
73 | + text_colors["end_color"]
74 | )
75 | print("{} - {} - {}".format(time_stamp, warn_str, message))
76 |
77 |
78 | def ignore_exception_with_warning(message: str) -> None:
79 | """
80 |     After catching a tolerable exception E1 (e.g. when Model.forward() fails inside a
81 |     try/except block during profiling), it is helpful to log the exception for future
82 |     investigation. But printing the error stack trace as-is could be confusing if an
83 |     uncaught (non-tolerable) exception E2 is raised down the road: the log would then
84 |     contain two stack traces, for E1 and E2. When looking for errors in logs, users
85 |     should look for E2, but they may find E1.
86 |
87 | This function appends "(WARNING)" at the end of all lines of the E1 traceback, so
88 | that the user can distinguish E1 from uncaught exception E2.
89 |
90 | Args:
91 |         message: Extra explanation and context for debugging. (Note: the exception
92 |             object is fetched automatically from the active exception; there is no need
93 |             to pass it as an argument or as part of the message.)
94 | """
95 | warning(f"{message}:\n{traceback.format_exc()}".replace("\n", "\n(WARNING)"))
96 |
97 |
98 | def info(message: str, print_line: Optional[bool] = False) -> None:
99 | time_stamp = get_curr_time_stamp()
100 | info_str = (
101 | text_colors["info"] + text_colors["bold"] + "INFO " + text_colors["end_color"]
102 | )
103 | print("{} - {} - {}".format(time_stamp, info_str, message))
104 | if print_line:
105 | double_dash_line(dashes=150)
106 |
107 |
108 | def debug(message: str) -> None:
109 | time_stamp = get_curr_time_stamp()
110 | log_str = (
111 | text_colors["debug"]
112 | + text_colors["bold"]
113 | + "DEBUG "
114 | + text_colors["end_color"]
115 | )
116 | print("{} - {} - {}".format(time_stamp, log_str, message))
117 |
118 |
119 | def double_dash_line(dashes: Optional[int] = 75) -> None:
120 | print(text_colors["error"] + "=" * dashes + text_colors["end_color"])
121 |
122 |
123 | def singe_dash_line(dashes: Optional[int] = 67) -> None:
124 | print("-" * dashes)
125 |
126 |
127 | def print_header(header: str) -> None:
128 | double_dash_line()
129 | print(
130 | text_colors["info"]
131 | + text_colors["bold"]
132 | + "=" * 50
133 | + str(header)
134 | + text_colors["end_color"]
135 | )
136 | double_dash_line()
137 |
138 |
139 | def print_header_minor(header: str) -> None:
140 | print(
141 | text_colors["warning"]
142 | + text_colors["bold"]
143 | + "=" * 25
144 | + str(header)
145 | + text_colors["end_color"]
146 | )
147 |
148 |
149 | def disable_printing():
150 | sys.stdout = open(os.devnull, "w")
151 |
152 |
153 | def enable_printing():
154 | sys.stdout = sys.__stdout__
155 |
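
The helpers above are thin wrappers around `print` with ANSI colour codes. The one non-obvious utility is `ignore_exception_with_warning`, which tags every line of a tolerated traceback with "(WARNING)" so it cannot be mistaken for a fatal failure later in the log. A small usage sketch:

```python
from mobileclip import logger

logger.info("Starting profiling run")

try:
    raise RuntimeError("forward pass failed on a dummy input")
except RuntimeError:
    # Prints the traceback with "(WARNING)" appended to every line, so it is easy
    # to tell apart from a real, uncaught failure later in the same log.
    logger.ignore_exception_with_warning("Tolerated profiling failure")

logger.warning("Continuing without profiling results")
```
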
--------------------------------------------------------------------------------
/mobileclip/models/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All rights reserved.
4 | #
5 | from .mci import (
6 | mci0,
7 | mci1,
8 | mci2,
9 | )
10 | from .vit import vit_b16
11 |
--------------------------------------------------------------------------------
/mobileclip/modules/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All rights reserved.
4 | #
5 |
--------------------------------------------------------------------------------
/mobileclip/modules/common/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All rights reserved.
4 | #
5 |
--------------------------------------------------------------------------------
/mobileclip/modules/common/mobileone.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 | from typing import Union, Tuple
6 |
7 | import copy
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | __all__ = ["MobileOneBlock", "reparameterize_model"]
13 |
14 |
15 | class SEBlock(nn.Module):
16 | """Squeeze and Excite module.
17 |
18 | Pytorch implementation of `Squeeze-and-Excitation Networks` -
19 | https://arxiv.org/pdf/1709.01507.pdf
20 | """
21 |
22 | def __init__(self, in_channels: int, rd_ratio: float = 0.0625) -> None:
23 | """Construct a Squeeze and Excite Module.
24 |
25 | Args:
26 | in_channels: Number of input channels.
27 | rd_ratio: Input channel reduction ratio.
28 | """
29 | super(SEBlock, self).__init__()
30 | self.reduce = nn.Conv2d(
31 | in_channels=in_channels,
32 | out_channels=int(in_channels * rd_ratio),
33 | kernel_size=1,
34 | stride=1,
35 | bias=True,
36 | )
37 | self.expand = nn.Conv2d(
38 | in_channels=int(in_channels * rd_ratio),
39 | out_channels=in_channels,
40 | kernel_size=1,
41 | stride=1,
42 | bias=True,
43 | )
44 |
45 | def forward(self, inputs: torch.Tensor) -> torch.Tensor:
46 | """Apply forward pass."""
47 | b, c, h, w = inputs.size()
48 | x = F.avg_pool2d(inputs, kernel_size=[h, w])
49 | x = self.reduce(x)
50 | x = F.relu(x)
51 | x = self.expand(x)
52 | x = torch.sigmoid(x)
53 | x = x.view(-1, c, 1, 1)
54 | return inputs * x
55 |
56 |
57 | class MobileOneBlock(nn.Module):
58 | """MobileOne building block.
59 |
60 | This block has a multi-branched architecture at train-time
61 | and plain-CNN style architecture at inference time
62 | For more details, please refer to our paper:
63 | `An Improved One millisecond Mobile Backbone` -
64 | https://arxiv.org/pdf/2206.04040.pdf
65 | """
66 |
67 | def __init__(
68 | self,
69 | in_channels: int,
70 | out_channels: int,
71 | kernel_size: int,
72 | stride: int = 1,
73 | padding: int = 0,
74 | dilation: int = 1,
75 | groups: int = 1,
76 | inference_mode: bool = False,
77 | use_se: bool = False,
78 | use_act: bool = True,
79 | use_scale_branch: bool = True,
80 | num_conv_branches: int = 1,
81 | activation: nn.Module = nn.GELU(),
82 | ) -> None:
83 | """Construct a MobileOneBlock module.
84 |
85 | Args:
86 | in_channels: Number of channels in the input.
87 | out_channels: Number of channels produced by the block.
88 | kernel_size: Size of the convolution kernel.
89 | stride: Stride size.
90 | padding: Zero-padding size.
91 | dilation: Kernel dilation factor.
92 | groups: Group number.
93 | inference_mode: If True, instantiates model in inference mode.
94 | use_se: Whether to use SE-ReLU activations.
95 | use_act: Whether to use activation. Default: ``True``
96 | use_scale_branch: Whether to use scale branch. Default: ``True``
97 | num_conv_branches: Number of linear conv branches.
98 | """
99 | super(MobileOneBlock, self).__init__()
100 | self.inference_mode = inference_mode
101 | self.groups = groups
102 | self.stride = stride
103 | self.padding = padding
104 | self.dilation = dilation
105 | self.kernel_size = kernel_size
106 | self.in_channels = in_channels
107 | self.out_channels = out_channels
108 | self.num_conv_branches = num_conv_branches
109 |
110 | # Check if SE-ReLU is requested
111 | if use_se:
112 | self.se = SEBlock(out_channels)
113 | else:
114 | self.se = nn.Identity()
115 |
116 | if use_act:
117 | self.activation = activation
118 | else:
119 | self.activation = nn.Identity()
120 |
121 | if inference_mode:
122 | self.reparam_conv = nn.Conv2d(
123 | in_channels=in_channels,
124 | out_channels=out_channels,
125 | kernel_size=kernel_size,
126 | stride=stride,
127 | padding=padding,
128 | dilation=dilation,
129 | groups=groups,
130 | bias=True,
131 | )
132 | else:
133 | # Re-parameterizable skip connection
134 | self.rbr_skip = (
135 | nn.BatchNorm2d(num_features=in_channels)
136 | if out_channels == in_channels and stride == 1
137 | else None
138 | )
139 |
140 | # Re-parameterizable conv branches
141 | if num_conv_branches > 0:
142 | rbr_conv = list()
143 | for _ in range(self.num_conv_branches):
144 | rbr_conv.append(
145 | self._conv_bn(kernel_size=kernel_size, padding=padding)
146 | )
147 | self.rbr_conv = nn.ModuleList(rbr_conv)
148 | else:
149 | self.rbr_conv = None
150 |
151 | # Re-parameterizable scale branch
152 | self.rbr_scale = None
153 | if not isinstance(kernel_size, int):
154 | kernel_size = kernel_size[0]
155 | if (kernel_size > 1) and use_scale_branch:
156 | self.rbr_scale = self._conv_bn(kernel_size=1, padding=0)
157 |
158 | def forward(self, x: torch.Tensor) -> torch.Tensor:
159 | """Apply forward pass."""
160 | # Inference mode forward pass.
161 | if self.inference_mode:
162 | return self.activation(self.se(self.reparam_conv(x)))
163 |
164 | # Multi-branched train-time forward pass.
165 | # Skip branch output
166 | identity_out = 0
167 | if self.rbr_skip is not None:
168 | identity_out = self.rbr_skip(x)
169 |
170 | # Scale branch output
171 | scale_out = 0
172 | if self.rbr_scale is not None:
173 | scale_out = self.rbr_scale(x)
174 |
175 | # Other branches
176 | out = scale_out + identity_out
177 | if self.rbr_conv is not None:
178 | for ix in range(self.num_conv_branches):
179 | out += self.rbr_conv[ix](x)
180 |
181 | return self.activation(self.se(out))
182 |
183 | def reparameterize(self):
184 | """Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
185 | https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
186 | architecture used at training time to obtain a plain CNN-like structure
187 | for inference.
188 | """
189 | if self.inference_mode:
190 | return
191 | kernel, bias = self._get_kernel_bias()
192 | self.reparam_conv = nn.Conv2d(
193 | in_channels=self.in_channels,
194 | out_channels=self.out_channels,
195 | kernel_size=self.kernel_size,
196 | stride=self.stride,
197 | padding=self.padding,
198 | dilation=self.dilation,
199 | groups=self.groups,
200 | bias=True,
201 | )
202 | self.reparam_conv.weight.data = kernel
203 | self.reparam_conv.bias.data = bias
204 |
205 | # Delete un-used branches
206 | for para in self.parameters():
207 | para.detach_()
208 | self.__delattr__("rbr_conv")
209 | self.__delattr__("rbr_scale")
210 | if hasattr(self, "rbr_skip"):
211 | self.__delattr__("rbr_skip")
212 |
213 | self.inference_mode = True
214 |
215 | def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
216 | """Method to obtain re-parameterized kernel and bias.
217 | Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83
218 |
219 | Returns:
220 | Tuple of (kernel, bias) after fusing branches.
221 | """
222 | # get weights and bias of scale branch
223 | kernel_scale = 0
224 | bias_scale = 0
225 | if self.rbr_scale is not None:
226 | kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
227 | # Pad scale branch kernel to match conv branch kernel size.
228 | pad = self.kernel_size // 2
229 | kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])
230 |
231 | # get weights and bias of skip branch
232 | kernel_identity = 0
233 | bias_identity = 0
234 | if self.rbr_skip is not None:
235 | kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)
236 |
237 | # get weights and bias of conv branches
238 | kernel_conv = 0
239 | bias_conv = 0
240 | if self.rbr_conv is not None:
241 | for ix in range(self.num_conv_branches):
242 | _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
243 | kernel_conv += _kernel
244 | bias_conv += _bias
245 |
246 | kernel_final = kernel_conv + kernel_scale + kernel_identity
247 | bias_final = bias_conv + bias_scale + bias_identity
248 | return kernel_final, bias_final
249 |
250 | def _fuse_bn_tensor(
251 | self, branch: Union[nn.Sequential, nn.BatchNorm2d]
252 | ) -> Tuple[torch.Tensor, torch.Tensor]:
253 | """Method to fuse batchnorm layer with preceeding conv layer.
254 | Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95
255 |
256 | Args:
257 | branch: Sequence of ops to be fused.
258 |
259 | Returns:
260 | Tuple of (kernel, bias) after fusing batchnorm.
261 | """
262 | if isinstance(branch, nn.Sequential):
263 | kernel = branch.conv.weight
264 | running_mean = branch.bn.running_mean
265 | running_var = branch.bn.running_var
266 | gamma = branch.bn.weight
267 | beta = branch.bn.bias
268 | eps = branch.bn.eps
269 | else:
270 | assert isinstance(branch, nn.BatchNorm2d)
271 | if not hasattr(self, "id_tensor"):
272 | input_dim = self.in_channels // self.groups
273 |
274 | kernel_size = self.kernel_size
275 | if isinstance(self.kernel_size, int):
276 | kernel_size = (self.kernel_size, self.kernel_size)
277 |
278 | kernel_value = torch.zeros(
279 | (self.in_channels, input_dim, kernel_size[0], kernel_size[1]),
280 | dtype=branch.weight.dtype,
281 | device=branch.weight.device,
282 | )
283 | for i in range(self.in_channels):
284 | kernel_value[
285 | i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2
286 | ] = 1
287 | self.id_tensor = kernel_value
288 | kernel = self.id_tensor
289 | running_mean = branch.running_mean
290 | running_var = branch.running_var
291 | gamma = branch.weight
292 | beta = branch.bias
293 | eps = branch.eps
294 | std = (running_var + eps).sqrt()
295 | t = (gamma / std).reshape(-1, 1, 1, 1)
296 | return kernel * t, beta - running_mean * gamma / std
297 |
298 | def _conv_bn(self, kernel_size: int, padding: int) -> nn.Sequential:
299 | """Helper method to construct conv-batchnorm layers.
300 |
301 | Args:
302 | kernel_size: Size of the convolution kernel.
303 | padding: Zero-padding size.
304 |
305 | Returns:
306 | Conv-BN module.
307 | """
308 | mod_list = nn.Sequential()
309 | mod_list.add_module(
310 | "conv",
311 | nn.Conv2d(
312 | in_channels=self.in_channels,
313 | out_channels=self.out_channels,
314 | kernel_size=kernel_size,
315 | stride=self.stride,
316 | padding=padding,
317 | groups=self.groups,
318 | bias=False,
319 | ),
320 | )
321 | mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_channels))
322 | return mod_list
323 |
324 |
325 | def reparameterize_model(model: torch.nn.Module) -> nn.Module:
326 | """Method returns a model where a multi-branched structure
327 | used in training is re-parameterized into a single branch
328 | for inference.
329 |
330 | Args:
331 | model: MobileOne model in train mode.
332 |
333 | Returns:
334 | MobileOne model in inference mode.
335 | """
336 | # Avoid editing original graph
337 | model = copy.deepcopy(model)
338 | for module in model.modules():
339 | if hasattr(module, "reparameterize"):
340 | module.reparameterize()
341 | return model
342 |
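
The whole point of `MobileOneBlock` is that the multi-branch training graph and the single fused convolution produced by `reparameterize()` are numerically equivalent in eval mode. A small sanity-check sketch (the block hyper-parameters here are illustrative, not taken from a MobileCLIP config):

```python
import torch
from mobileclip.modules.common.mobileone import MobileOneBlock, reparameterize_model

# A 3x3 block with a BatchNorm skip branch, a 1x1 scale branch, and two conv-BN branches.
block = MobileOneBlock(
    in_channels=16, out_channels=16, kernel_size=3, padding=1, num_conv_branches=2
)
block.eval()  # BatchNorm must use running statistics for the fusion to match exactly

x = torch.randn(2, 16, 32, 32)
with torch.no_grad():
    y_branches = block(x)
    fused = reparameterize_model(block)  # deep-copies the block, then fuses every branch
    y_fused = fused(x)

print(torch.allclose(y_branches, y_fused, atol=1e-5))  # True (up to floating-point error)
```
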
--------------------------------------------------------------------------------
/mobileclip/modules/image/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All rights reserved.
4 | #
5 |
--------------------------------------------------------------------------------
/mobileclip/modules/image/image_projection.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 | from typing import List, Optional
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch import Tensor
10 |
11 | from mobileclip import logger
12 |
13 |
14 | class GlobalPool(nn.Module):
15 | """
16 | This layers applies global pooling over a 4D or 5D input tensor
17 |
18 | Args:
19 | pool_type (Optional[str]): Pooling type. It can be mean, rms, or abs. Default: `mean`
20 | keep_dim (Optional[bool]): Do not squeeze the dimensions of a tensor. Default: `False`
21 |
22 | Shape:
23 | - Input: :math:`(N, C, H, W)` or :math:`(N, C, D, H, W)`
24 | - Output: :math:`(N, C, 1, 1)` or :math:`(N, C, 1, 1, 1)` if keep_dim else :math:`(N, C)`
25 | """
26 |
27 | pool_types = ["mean", "rms", "abs"]
28 |
29 | def __init__(
30 | self,
31 | pool_type: Optional[str] = "mean",
32 | keep_dim: Optional[bool] = False,
33 | *args,
34 | **kwargs
35 | ) -> None:
36 | super().__init__()
37 | if pool_type not in self.pool_types:
38 | logger.error(
39 | "Supported pool types are: {}. Got {}".format(
40 | self.pool_types, pool_type
41 | )
42 | )
43 | self.pool_type = pool_type
44 | self.keep_dim = keep_dim
45 |
46 | def _global_pool(self, x: Tensor, dims: List):
47 | if self.pool_type == "rms": # root mean square
48 | x = x**2
49 | x = torch.mean(x, dim=dims, keepdim=self.keep_dim)
50 |             x = x**0.5
51 | elif self.pool_type == "abs": # absolute
52 | x = torch.mean(torch.abs(x), dim=dims, keepdim=self.keep_dim)
53 | else:
54 | # default is mean
55 | # same as AdaptiveAvgPool
56 | x = torch.mean(x, dim=dims, keepdim=self.keep_dim)
57 | return x
58 |
59 | def forward(self, x: Tensor) -> Tensor:
60 | if x.dim() == 4:
61 | dims = [-2, -1]
62 | elif x.dim() == 5:
63 | dims = [-3, -2, -1]
64 | else:
65 | raise NotImplementedError("Currently 2D and 3D global pooling supported")
66 | return self._global_pool(x, dims=dims)
67 |
68 |
69 | class GlobalPool2D(nn.Module):
70 | """This class implements global pooling with linear projection."""
71 |
72 | def __init__(self, in_dim: int, out_dim: int, *args, **kwargs) -> None:
73 | super().__init__()
74 | scale = in_dim**-0.5
75 | self.pool = GlobalPool(pool_type="mean", keep_dim=False)
76 | self.proj = nn.Parameter(scale * torch.randn(size=(in_dim, out_dim)))
77 | self.in_dim = in_dim
78 | self.out_dim = out_dim
79 |
80 | def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
81 |         # x is of shape [batch, in_dim, in_height, in_width]
82 | assert (
83 | x.dim() == 4
84 | ), "Input should be 4-dimensional (Batch x in_dim x in_height x in_width). Got: {}".format(
85 | x.shape
86 | )
87 |
88 | # [batch, in_dim, in_height, in_width] --> [batch, in_dim]
89 | x = self.pool(x)
90 | # [batch, in_dim] x [in_dim, out_dim] --> [batch, out_dim]
91 | x = x @ self.proj
92 | return x
93 |
94 |
95 | class SimpleImageProjectionHead(nn.Module):
96 | """This class implements linear projection head."""
97 |
98 | def __init__(self, in_dim: int, out_dim: int) -> None:
99 | super().__init__()
100 | scale = in_dim**-0.5
101 | self.proj = nn.Parameter(scale * torch.randn(size=(in_dim, out_dim)))
102 | self.in_dim = in_dim
103 | self.out_dim = out_dim
104 |
105 | def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
106 | # x is of shape [batch, in_dim]
107 | assert (
108 | x.dim() == 2
109 | ), "Input should be 2-dimensional (Batch x in_dim). Got: {}".format(x.shape)
110 |
111 | # [batch, in_dim] x [in_dim, out_dim] --> [batch, out_dim]
112 | x = x @ self.proj
113 | return x
114 |
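
`GlobalPool2D` is exactly what `MCi._update_image_classifier` installs as the new head: it mean-pools a `[B, C, H, W]` feature map and applies a learned `C -> out_dim` projection. A minimal shape sketch (the feature width 768 and the 8x8 map are illustrative):

```python
import torch
from mobileclip.modules.image.image_projection import GlobalPool2D

head = GlobalPool2D(in_dim=768, out_dim=512)   # 768 is an illustrative feature width

features = torch.randn(4, 768, 8, 8)           # backbone feature map [B, C, H, W]
embeddings = head(features)                    # mean-pool over H and W, then project

print(embeddings.shape)  # torch.Size([4, 512])
```
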
--------------------------------------------------------------------------------
/mobileclip/modules/image/replknet.py:
--------------------------------------------------------------------------------
1 | #
2 | # For acknowledgement see accompanying ACKNOWLEDGEMENTS file.
3 | # Copyright (C) 2024 Apple Inc. All rights reserved.
4 | #
5 | from typing import Tuple
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | from timm.models.layers import SqueezeExcite
11 |
12 | __all__ = ["ReparamLargeKernelConv"]
13 |
14 |
15 | class ReparamLargeKernelConv(nn.Module):
16 | """Building Block of RepLKNet
17 |
18 | This class defines overparameterized large kernel conv block
19 |     introduced in RepLKNet.
20 |
21 | Reference: https://github.com/DingXiaoH/RepLKNet-pytorch
22 | """
23 |
24 | def __init__(
25 | self,
26 | in_channels: int,
27 | out_channels: int,
28 | kernel_size: int,
29 | stride: int,
30 | groups: int,
31 | small_kernel: int,
32 | inference_mode: bool = False,
33 | use_se: bool = False,
34 | activation: nn.Module = nn.GELU(),
35 | ) -> None:
36 | """Construct a ReparamLargeKernelConv module.
37 |
38 | Args:
39 | in_channels: Number of input channels.
40 | out_channels: Number of output channels.
41 | kernel_size: Kernel size of the large kernel conv branch.
42 | stride: Stride size. Default: 1
43 | groups: Group number. Default: 1
44 | small_kernel: Kernel size of small kernel conv branch.
45 | inference_mode: If True, instantiates model in inference mode. Default: ``False``
46 | activation: Activation module. Default: ``nn.GELU``
47 | """
48 | super(ReparamLargeKernelConv, self).__init__()
49 |
50 | self.stride = stride
51 | self.groups = groups
52 | self.in_channels = in_channels
53 | self.out_channels = out_channels
54 | self.activation = activation
55 |
56 | self.kernel_size = kernel_size
57 | self.small_kernel = small_kernel
58 | self.padding = kernel_size // 2
59 |
60 | # Check if SE is requested
61 | if use_se:
62 | self.se = SqueezeExcite(out_channels, rd_ratio=0.25)
63 | else:
64 | self.se = nn.Identity()
65 |
66 | if inference_mode:
67 | self.lkb_reparam = nn.Conv2d(
68 | in_channels=in_channels,
69 | out_channels=out_channels,
70 | kernel_size=kernel_size,
71 | stride=stride,
72 | padding=self.padding,
73 | dilation=1,
74 | groups=groups,
75 | bias=True,
76 | )
77 | else:
78 | self.lkb_origin = self._conv_bn(
79 | kernel_size=kernel_size, padding=self.padding
80 | )
81 | if small_kernel is not None:
82 | assert (
83 | small_kernel <= kernel_size
84 | ), "The kernel size for re-param cannot be larger than the large kernel!"
85 | self.small_conv = self._conv_bn(
86 | kernel_size=small_kernel, padding=small_kernel // 2
87 | )
88 |
89 | def forward(self, x: torch.Tensor) -> torch.Tensor:
90 | """Apply forward pass."""
91 | if hasattr(self, "lkb_reparam"):
92 | out = self.lkb_reparam(x)
93 | else:
94 | out = self.lkb_origin(x)
95 | if hasattr(self, "small_conv"):
96 | out += self.small_conv(x)
97 |
98 | return self.activation(self.se(out))
99 |
100 | def get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
101 | """Method to obtain re-parameterized kernel and bias.
102 | Reference: https://github.com/DingXiaoH/RepLKNet-pytorch
103 |
104 | Returns:
105 | Tuple of (kernel, bias) after fusing branches.
106 | """
107 | eq_k, eq_b = self._fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
108 | if hasattr(self, "small_conv"):
109 | small_k, small_b = self._fuse_bn(self.small_conv.conv, self.small_conv.bn)
110 | eq_b += small_b
111 | eq_k += nn.functional.pad(
112 | small_k, [(self.kernel_size - self.small_kernel) // 2] * 4
113 | )
114 | return eq_k, eq_b
115 |
116 | def reparameterize(self) -> None:
117 | """
118 | Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
119 | https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
120 | architecture used at training time to obtain a plain CNN-like structure
121 | for inference.
122 | """
123 | eq_k, eq_b = self.get_kernel_bias()
124 | self.lkb_reparam = nn.Conv2d(
125 | in_channels=self.in_channels,
126 | out_channels=self.out_channels,
127 | kernel_size=self.kernel_size,
128 | stride=self.stride,
129 | padding=self.padding,
130 | dilation=self.lkb_origin.conv.dilation,
131 | groups=self.groups,
132 | bias=True,
133 | )
134 |
135 | self.lkb_reparam.weight.data = eq_k
136 | self.lkb_reparam.bias.data = eq_b
137 | self.__delattr__("lkb_origin")
138 | if hasattr(self, "small_conv"):
139 | self.__delattr__("small_conv")
140 |
141 | @staticmethod
142 | def _fuse_bn(
143 |         conv: nn.Conv2d, bn: nn.BatchNorm2d
144 | ) -> Tuple[torch.Tensor, torch.Tensor]:
145 | """Method to fuse batchnorm layer with conv layer.
146 |
147 | Args:
148 |             conv: Convolutional layer whose kernel is to be fused.
149 | bn: Batchnorm 2d layer.
150 |
151 | Returns:
152 | Tuple of (kernel, bias) after fusing batchnorm.
153 | """
154 | kernel = conv.weight
155 | running_mean = bn.running_mean
156 | running_var = bn.running_var
157 | gamma = bn.weight
158 | beta = bn.bias
159 | eps = bn.eps
160 | std = (running_var + eps).sqrt()
161 | t = (gamma / std).reshape(-1, 1, 1, 1)
162 | return kernel * t, beta - running_mean * gamma / std
163 |
164 | def _conv_bn(self, kernel_size: int, padding: int = 0) -> nn.Sequential:
165 | """Helper method to construct conv-batchnorm layers.
166 |
167 | Args:
168 | kernel_size: Size of the convolution kernel.
169 | padding: Zero-padding size.
170 |
171 | Returns:
172 | A nn.Sequential Conv-BN module.
173 | """
174 | mod_list = nn.Sequential()
175 | mod_list.add_module(
176 | "conv",
177 | nn.Conv2d(
178 | in_channels=self.in_channels,
179 | out_channels=self.out_channels,
180 | kernel_size=kernel_size,
181 | stride=self.stride,
182 | padding=padding,
183 | groups=self.groups,
184 | bias=False,
185 | ),
186 | )
187 | mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_channels))
188 | return mod_list
189 |
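
`ReparamLargeKernelConv` follows the same train-time/inference-time split as `MobileOneBlock`: a large-kernel conv-BN branch plus a small-kernel conv-BN branch at training time, fused by `reparameterize()` into a single convolution. A quick equivalence sketch (the depthwise 7/3 kernel pair below is illustrative):

```python
import torch
from mobileclip.modules.image.replknet import ReparamLargeKernelConv

conv = ReparamLargeKernelConv(
    in_channels=32, out_channels=32, kernel_size=7,
    stride=1, groups=32, small_kernel=3,  # depthwise large kernel with a 3x3 helper branch
)
conv.eval()  # fuse against running BatchNorm statistics

x = torch.randn(1, 32, 28, 28)
with torch.no_grad():
    y_before = conv(x)
    conv.reparameterize()  # replaces lkb_origin / small_conv with a single lkb_reparam
    y_after = conv(x)

print(torch.allclose(y_before, y_after, atol=1e-5))  # True (up to floating-point error)
```
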
--------------------------------------------------------------------------------
/mobileclip/modules/text/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All rights reserved.
4 | #
5 |
--------------------------------------------------------------------------------
/mobileclip/modules/text/repmixer.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 | from typing import Optional
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | from timm.models.layers import DropPath, trunc_normal_
11 | from mobileclip.modules.common.mobileone import MobileOneBlock
12 |
13 |
14 | class ConvFFN(nn.Module):
15 | """Convolutional FFN Module."""
16 |
17 | def __init__(
18 | self,
19 | in_channels: int,
20 | context_size: int,
21 | hidden_channels: Optional[int] = None,
22 | out_channels: Optional[int] = None,
23 | act_layer: nn.Module = nn.GELU,
24 | drop: float = 0.0,
25 | ) -> None:
26 | """Build convolutional FFN module.
27 |
28 | Args:
29 | in_channels: Number of input channels.
30 | context_size: Context size for 1D signals.
31 | hidden_channels: Number of channels after expansion. Default: None
32 | out_channels: Number of output channels. Default: None
33 | act_layer: Activation layer. Default: ``GELU``
34 | drop: Dropout rate. Default: ``0.0``.
35 | """
36 | super().__init__()
37 | out_channels = out_channels or in_channels
38 | hidden_channels = hidden_channels or in_channels
39 | self.conv = nn.Sequential()
40 | self.conv.add_module(
41 | "conv",
42 | nn.Conv2d(
43 | in_channels=in_channels,
44 | out_channels=out_channels,
45 | kernel_size=(1, int(context_size)),
46 | padding=(0, int(context_size // 2)),
47 | groups=in_channels,
48 | bias=False,
49 | ),
50 | )
51 | self.conv.add_module("bn", nn.BatchNorm2d(num_features=out_channels))
52 | self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
53 | self.act = act_layer()
54 | self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)
55 | self.drop = nn.Dropout(drop)
56 | self.apply(self._init_weights)
57 |
58 | def _init_weights(self, m: nn.Module) -> None:
59 | if isinstance(m, nn.Conv2d):
60 | trunc_normal_(m.weight, std=0.02)
61 | if m.bias is not None:
62 | nn.init.constant_(m.bias, 0)
63 |
64 | def forward(self, x: torch.Tensor) -> torch.Tensor:
65 | x = self.conv(x)
66 | x = self.fc1(x)
67 | x = self.act(x)
68 | x = self.drop(x)
69 | x = self.fc2(x)
70 | x = self.drop(x)
71 | return x
72 |
73 |
74 | class RepMixer(nn.Module):
75 | """Reparameterizable token mixer.
76 |
77 | For more details, please refer to our paper:
78 | `FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization `_
79 | """
80 |
81 | def __init__(
82 | self,
83 | dim,
84 | kernel_size=3,
85 | use_layer_scale=True,
86 | layer_scale_init_value=1e-5,
87 | inference_mode: bool = False,
88 | ):
89 | """Build RepMixer Module.
90 |
91 | Args:
92 | dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`.
93 | kernel_size: Kernel size for spatial mixing. Default: 3
94 | use_layer_scale: If True, learnable layer scale is used. Default: ``True``
95 | layer_scale_init_value: Initial value for layer scale. Default: 1e-5
96 | inference_mode: If True, instantiates model in inference mode. Default: ``False``
97 | """
98 | super().__init__()
99 | self.dim = dim
100 | self.kernel_size = kernel_size
101 | self.inference_mode = inference_mode
102 |
103 | if inference_mode:
104 | self.reparam_conv = nn.Conv2d(
105 | in_channels=self.dim,
106 | out_channels=self.dim,
107 | kernel_size=(1, self.kernel_size),
108 | stride=1,
109 | padding=(0, self.kernel_size // 2),
110 | groups=self.dim,
111 | bias=True,
112 | )
113 | else:
114 | self.norm = MobileOneBlock(
115 | dim,
116 | dim,
117 | (1, kernel_size),
118 | padding=(0, kernel_size // 2),
119 | groups=dim,
120 | use_act=False,
121 | use_scale_branch=False,
122 | num_conv_branches=0,
123 | )
124 | self.mixer = MobileOneBlock(
125 | dim,
126 | dim,
127 | (1, kernel_size),
128 | padding=(0, kernel_size // 2),
129 | groups=dim,
130 | use_act=False,
131 | )
132 | self.use_layer_scale = use_layer_scale
133 | if use_layer_scale:
134 | self.layer_scale = nn.Parameter(
135 | layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
136 | )
137 |
138 | def forward(self, x: torch.Tensor) -> torch.Tensor:
139 | if hasattr(self, "reparam_conv"):
140 | x = self.reparam_conv(x)
141 | return x
142 | else:
143 | if self.use_layer_scale:
144 | x = x + self.layer_scale * (self.mixer(x) - self.norm(x))
145 | else:
146 | x = x + self.mixer(x) - self.norm(x)
147 | return x
148 |
149 | def reparameterize(self) -> None:
150 | """Reparameterize mixer and norm into a single
151 | convolutional layer for efficient inference.
152 | """
153 | if self.inference_mode:
154 | return
155 |
156 | self.mixer.reparameterize()
157 | self.norm.reparameterize()
158 |
159 | if self.use_layer_scale:
160 | w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * (
161 | self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
162 | )
163 | b = torch.squeeze(self.layer_scale) * (
164 | self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
165 | )
166 | else:
167 | w = (
168 | self.mixer.id_tensor
169 | + self.mixer.reparam_conv.weight
170 | - self.norm.reparam_conv.weight
171 | )
172 | b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
173 |
174 | self.reparam_conv = nn.Conv2d(
175 | in_channels=self.dim,
176 | out_channels=self.dim,
177 | kernel_size=(1, self.kernel_size),
178 | stride=1,
179 | padding=(0, self.kernel_size // 2),
180 | groups=self.dim,
181 | bias=True,
182 | )
183 | self.reparam_conv.weight.data = w
184 | self.reparam_conv.bias.data = b
185 |
186 | for para in self.parameters():
187 | para.detach_()
188 | self.__delattr__("mixer")
189 | self.__delattr__("norm")
190 | if self.use_layer_scale:
191 | self.__delattr__("layer_scale")
192 |
193 |
194 | class RepMixerBlock(nn.Module):
195 | """Implementation of Metaformer block with RepMixer as token mixer.
196 |
197 | For more details on Metaformer structure, please refer to:
198 | `MetaFormer Is Actually What You Need for Vision `_
199 | """
200 |
201 | def __init__(
202 | self,
203 | dim: int,
204 | kernel_size: int = 11,
205 | mlp_ratio: float = 4.0,
206 | act_layer: nn.Module = nn.GELU,
207 | drop: float = 0.0,
208 | drop_path: float = 0.0,
209 | use_layer_scale: bool = True,
210 | layer_scale_init_value: float = 1e-5,
211 | inference_mode: bool = False,
212 | *args,
213 | **kwargs,
214 | ):
215 | """Build RepMixer Block.
216 |
217 | Args:
218 | dim: Number of embedding dimensions.
219 |             kernel_size: Kernel size for repmixer. Default: 11
220 | mlp_ratio: MLP expansion ratio. Default: 4.0
221 | act_layer: Activation layer. Default: ``nn.GELU``
222 | drop: Dropout rate. Default: 0.0
223 | drop_path: Drop path rate. Default: 0.0
224 | use_layer_scale: Flag to turn on layer scale. Default: ``True``
225 | layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
226 | inference_mode: Flag to instantiate block in inference mode. Default: ``False``
227 | """
228 |
229 | super().__init__()
230 |
231 | self.token_mixer = RepMixer(
232 | dim,
233 | kernel_size=kernel_size,
234 | use_layer_scale=use_layer_scale,
235 | layer_scale_init_value=layer_scale_init_value,
236 | inference_mode=inference_mode,
237 | )
238 |
239 | assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
240 | mlp_ratio
241 | )
242 | mlp_hidden_dim = int(dim * mlp_ratio)
243 | self.convffn = ConvFFN(
244 | in_channels=dim,
245 | context_size=kernel_size,
246 | hidden_channels=mlp_hidden_dim,
247 | act_layer=act_layer,
248 | drop=drop,
249 | )
250 |
251 | # Drop Path
252 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
253 |
254 | # Layer Scale
255 | self.use_layer_scale = use_layer_scale
256 | if use_layer_scale:
257 | self.layer_scale = nn.Parameter(
258 | layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
259 | )
260 |
261 | def forward(self, x, *args, **kwargs):
262 | if x.dim() == 3:
263 | # B, C, D --- where C is the context length
264 | # Convert to B, D, C --- to match RepMixer impl.
265 | x = x.permute(0, 2, 1)
266 | x = torch.unsqueeze(x, dim=2)
267 | else:
268 | raise ValueError(
269 | f"Expected tensor of dim=3, obtained tensor of dim={x.dim()}"
270 | )
271 |
272 | if self.use_layer_scale:
273 | x = self.token_mixer(x)
274 | x = x + self.drop_path(self.layer_scale * self.convffn(x))
275 | else:
276 | x = self.token_mixer(x)
277 | x = x + self.drop_path(self.convffn(x))
278 |
279 | # Convert tensors back
280 | x = x.squeeze(dim=2).permute(0, 2, 1)
281 | return x
282 |
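
In the MobileCLIP text tower these blocks operate on 1D token sequences: `RepMixerBlock.forward` reshapes `[B, context_length, dim]` into a `[B, dim, 1, context_length]` map so that the `(1, k)` convolutions mix information along the sequence axis, then reshapes back. A small shape sketch (`dim=512` matches the configs above; `kernel_size` is left at its default of 11; batch size and context length are illustrative):

```python
import torch
from mobileclip.modules.text.repmixer import RepMixerBlock

block = RepMixerBlock(dim=512)  # kernel_size defaults to 11 along the sequence axis
block.eval()

tokens = torch.randn(2, 77, 512)  # [batch, context_length, dim]
with torch.no_grad():
    mixed = block(tokens)

print(mixed.shape)  # torch.Size([2, 77, 512]) -- the sequence shape is preserved
```
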
--------------------------------------------------------------------------------
/mobileclip/modules/text/tokenizer.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 | from typing import Dict
6 |
7 | import open_clip
8 | from torch import Tensor, nn
9 |
10 |
11 | class ClipTokenizer(nn.Module):
12 | def __init__(self, cfg, *args, **kwargs):
13 | super().__init__()
14 | self.context_length = cfg["text_cfg"]["context_length"]
15 |         model_name = cfg["text_cfg"].get("open_clip_tokenizer", "ViT-B-16")
16 | self.tokenizer = open_clip.get_tokenizer(model_name)
17 |
18 | def get_vocab_size(self) -> int:
19 | return len(self.tokenizer.encoder)
20 |
21 | def get_encodings(self) -> Dict[str, int]:
22 | return self.tokenizer.encoder
23 |
24 |     def get_eot_token(self) -> int:
25 |         # Tokenizing an empty string gives a [1, context_length] tensor starting with [sot_id, eot_id, ...]
26 |         return int(self.tokenizer("")[0, 1])
27 | 
28 |     def get_sot_token(self) -> int:
29 |         # Tokenizing an empty string gives a [1, context_length] tensor starting with [sot_id, eot_id, ...]
30 |         return int(self.tokenizer("")[0, 0])
31 |
32 | def forward(self, input_sentence: str, *args, **kwargs) -> Tensor:
33 |         # tokenizer returns token indices as a tensor padded/truncated to context_length
34 | tokenized_sentence = self.tokenizer(input_sentence, self.context_length)
35 | assert (
36 | tokenized_sentence.shape[-1] == self.context_length
37 | ), "Tokenized tensor should be exactly `context_length` long."
38 | return tokenized_sentence
39 |
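
`ClipTokenizer` defers the actual byte-pair encoding to `open_clip.get_tokenizer` and only enforces the configured context length. A minimal sketch, assuming `open-clip-torch` is installed as listed in requirements.txt (the `cfg` dict below contains only the field the tokenizer reads):

```python
from mobileclip.modules.text.tokenizer import ClipTokenizer

cfg = {"text_cfg": {"context_length": 77}}  # only the field the tokenizer reads
tokenizer = ClipTokenizer(cfg)

tokens = tokenizer("a photo of a dog")
print(tokens.shape)  # torch.Size([1, 77]) -- always padded/truncated to context_length
```
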
--------------------------------------------------------------------------------
/mobileclip/text_encoder.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 | import math
6 | from typing import Optional, Sequence
7 |
8 | import torch
9 | from torch import Tensor, nn
10 |
11 | from mobileclip.modules.common.transformer import (
12 | PositionalEmbedding,
13 | TransformerEncoder,
14 | get_normalization_layer,
15 | )
16 | from mobileclip.modules.text.repmixer import RepMixerBlock
17 | from mobileclip import logger
18 |
19 |
20 | class TextTransformer(nn.Module):
21 | def __init__(self, cfg: dict, projection_dim: int, *args, **kwargs) -> None:
22 | super().__init__()
23 |
24 | model_dim = cfg["dim"]
25 | no_scale_embedding = cfg.get("no_scale_embedding", False)
26 | no_pos_embedding = cfg.get("no_pos_embedding", False)
27 | embed_dropout = cfg.get("embed_dropout", 0.0)
28 | norm_layer = cfg["norm_layer"]
29 | variant = cfg["model_name"]
30 | self.vocab_size = cfg["vocab_size"]
31 | self.projection_dim = projection_dim
32 |
33 | # Token embedding layer
34 | self.embedding_layer = nn.Embedding(
35 | embedding_dim=model_dim, num_embeddings=self.vocab_size
36 | )
37 | self.embed_scale = 1.0 if no_scale_embedding else model_dim**-0.5
38 |
39 | # Context length
40 | context_length = cfg["context_length"]
41 | assert (
42 | context_length is not None
43 | ), "Context length can't be None. Please set value accordingly."
44 |
45 | self.positional_embedding = (
46 | None
47 | if no_pos_embedding
48 | else PositionalEmbedding(
49 | num_embeddings=context_length, embedding_dim=model_dim
50 | )
51 | )
52 |
53 | self.embedding_dropout = nn.Dropout(p=embed_dropout)
54 |
55 | # Transformer layer
56 | n_transformer_layers = cfg["n_transformer_layers"]
57 |
58 | # FFN multipliers for transformer layer
59 | ffn_multipliers = cfg["ffn_multiplier_per_layer"]
60 | if isinstance(ffn_multipliers, (float, int)):
61 | ffn_multipliers = [ffn_multipliers] * n_transformer_layers
62 |
63 | if not isinstance(ffn_multipliers, Sequence):
64 | logger.error(
65 | "{} expects FFN multipliers as a list, whose length is the same as"
66 | " number of transformer layers. Got: {}".format(
67 | self.__class__.__name__, type(ffn_multipliers)
68 | )
69 | )
70 | elif (
71 | isinstance(ffn_multipliers, Sequence)
72 | and len(ffn_multipliers) != n_transformer_layers
73 | ):
74 | logger.error(
75 | "We need FFN multiplier for each transformer layer. Got {} ffn"
76 | " multipliers while number of transformer layers = {}".format(
77 | len(ffn_multipliers), n_transformer_layers
78 | )
79 | )
80 | ffn_dims = [
81 | int(math.ceil(model_dim * ffn_mult / 16.0) * 16.0)
82 | for ffn_mult in ffn_multipliers
83 | ]
84 |
85 | # Heads for transformer layers
86 | mha_heads = cfg["n_heads_per_layer"]
87 | if isinstance(mha_heads, int):
88 | mha_heads = [mha_heads] * n_transformer_layers
89 |
90 | if not isinstance(mha_heads, Sequence):
91 | logger.error(
92 | "{} expects MHA heads as a list, whose length is the same as number of "
93 | "transformer layers. Got: {}".format(
94 | self.__class__.__name__, type(mha_heads)
95 | )
96 | )
97 | elif isinstance(mha_heads, Sequence) and len(mha_heads) != n_transformer_layers:
98 | logger.error(
99 | "{} needs MHA heads for each transformer layer. Got {} mha heads while"
100 | " number of transformer layers = {}".format(
101 | self.__class__.__name__, len(mha_heads), n_transformer_layers
102 | )
103 | )
104 |
105 | if variant == "base":
106 | self.transformer = nn.ModuleList(
107 | [
108 | TransformerEncoder(
109 | embed_dim=model_dim,
110 | num_heads=mha_heads[layer_idx],
111 | ffn_latent_dim=ffn_dims[layer_idx],
112 | transformer_norm_layer=norm_layer,
113 | )
114 | for layer_idx in range(n_transformer_layers)
115 | ]
116 | )
117 | elif variant == "mct":
118 | self.transformer = nn.ModuleList([RepMixerBlock(dim=model_dim)])
119 | self.transformer.extend(
120 | [
121 | TransformerEncoder(
122 | embed_dim=model_dim,
123 | num_heads=mha_heads[layer_idx],
124 | ffn_latent_dim=ffn_dims[layer_idx],
125 | transformer_norm_layer=norm_layer,
126 | )
127 | for layer_idx in range(n_transformer_layers)
128 | ]
129 | )
130 | self.transformer.extend([RepMixerBlock(dim=model_dim)])
131 | else:
132 | raise ValueError("Unrecognized text encoder variant {}".format(variant))
133 |
134 | self.final_layer_norm = get_normalization_layer(
135 | num_features=model_dim, norm_type=norm_layer
136 | )
137 |
138 | self.projection_layer = nn.Parameter(
139 | torch.empty(model_dim, self.projection_dim)
140 | )
141 | self.model_dim = model_dim
142 | self.causal_masking = cfg["causal_masking"]
143 |
144 | def forward_embedding(self, text_tokens: Tensor) -> Tensor:
145 | """Return text embedding for all tokens.
146 |
147 | Args:
148 | text_tokens: a tensor of token indices. Shape: [batch_size, context_length]
149 |
150 | Returns:
151 | A tensor of [batch_size, context_length, hidden_dim].
152 | """
153 | # [batch_size, context_length] --> [batch_size, context_length, hidden_dim]
154 | token_emb = self.embedding_layer(text_tokens)
155 | seq_len = token_emb.shape[1]
156 | if self.positional_embedding is not None:
157 | token_emb = token_emb + self.positional_embedding(seq_len).to(
158 | token_emb.dtype
159 | )
160 | token_emb = self.embedding_dropout(token_emb)
161 | return token_emb
162 |
163 | def build_attention_mask(self, context_length: int, batch_size: int) -> Tensor:
164 | """Build causal attention mask [batch_size, context_length, context_length]."""
165 |         # Build an additive causal mask so each token attends only to itself and earlier tokens
166 | # pytorch uses additive attention mask; fill with -inf
167 | mask = torch.empty(context_length, context_length)
168 | mask.fill_(float("-inf"))
169 |         mask.triu_(1)  # zero out the diagonal and the lower triangle
170 | mask = mask.unsqueeze(0) # add dummy batch dimension
171 | mask = mask.expand(batch_size, -1, -1)
172 | return mask
173 |
174 | def encode_text(
175 | self,
176 | text_tokens: Tensor,
177 | key_padding_mask: Optional[Tensor] = None,
178 | return_all_tokens: bool = False,
179 | *args,
180 | **kwargs
181 | ) -> Tensor:
182 | """Return text token embeddings.
183 |
184 | Args:
185 | text_tokens: a tensor of token indices. Shape: [batch_size, context_length]
186 | key_padding_mask: a tensor of boolean values as the padding mask.
187 | Shape: [batch_size, context_length]
188 | return_all_tokens: a boolean flag to return all tokens, defaults to False
189 | to return only EOT token embedding.
190 | Returns:
191 | A tensor of [batch_size, context_length, hidden_dim] if return_all_tokens is
192 | True, otherwise a tensor of [batch_size, hidden_dim].
193 | """
194 | # Discrete tokens to continuous embeddings
195 | # [batch_size, context_length] --> [batch_size, context_length, hidden_dim]
196 | token_emb = self.forward_embedding(text_tokens)
197 |
198 |         # Optional causal mask of shape [batch_size, context_length, context_length]
199 | attn_mask = None
200 | if self.causal_masking:
201 | attn_mask = self.build_attention_mask(
202 | context_length=text_tokens.shape[1], batch_size=text_tokens.shape[0]
203 | )
204 | attn_mask = attn_mask.to(device=token_emb.device, dtype=token_emb.dtype)
205 | key_padding_mask = None
206 |
207 | for layer in self.transformer:
208 | token_emb = layer(
209 | token_emb,
210 | key_padding_mask=key_padding_mask,
211 | attn_mask=attn_mask,
212 | )
213 |
214 | # Apply layer norm
215 | token_emb = self.final_layer_norm(token_emb)
216 |
217 | if return_all_tokens:
218 | return token_emb
219 |
220 | # Take features from the eot embedding (eot_token is the highest number in each sequence)
221 | token_emb = token_emb[
222 | torch.arange(text_tokens.shape[0]), text_tokens.argmax(dim=-1)
223 | ]
224 |
225 | token_emb = token_emb @ self.projection_layer
226 | return token_emb
227 |
228 | def forward(
229 | self,
230 | text_tokens: Tensor,
231 | key_padding_mask: Optional[Tensor] = None,
232 | return_all_tokens: bool = False,
233 | *args,
234 | **kwargs
235 | ) -> Tensor:
236 | # Image-text pair data with single caption
237 | # [B, CL] --> [B, d]
238 | text_tokens = self.encode_text(
239 | text_tokens=text_tokens,
240 | key_padding_mask=key_padding_mask,
241 | return_all_tokens=return_all_tokens,
242 | *args,
243 | **kwargs
244 | )
245 | return text_tokens
246 |
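
Putting the pieces together, `TextTransformer` consumes the `text_cfg` block of a config and returns one `embed_dim`-sized embedding per caption (the EOT-token feature, projected through `projection_layer`). A minimal end-to-end sketch with untrained weights (the projection layer is even left uninitialized), so only the tensor shapes are meaningful; it assumes the repo root is the working directory and that the companion modules under `mobileclip/modules/common/transformer.py` import as referenced above:

```python
import json
import torch
from mobileclip.text_encoder import TextTransformer

with open("mobileclip/configs/mobileclip_s0.json") as f:
    cfg = json.load(f)

text_encoder = TextTransformer(cfg=cfg["text_cfg"], projection_dim=cfg["embed_dim"])
text_encoder.eval()

# Dummy token ids; in practice these come from ClipTokenizer above.
tokens = torch.randint(
    0, cfg["text_cfg"]["vocab_size"], (2, cfg["text_cfg"]["context_length"])
)
with torch.no_grad():
    text_features = text_encoder(tokens)  # EOT-token embedding projected to embed_dim

print(text_features.shape)  # torch.Size([2, 512])
```
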
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | clip-benchmark>=1.4.0
2 | datasets>=2.21.0
3 | open-clip-torch>=2.20.0
4 | timm>=0.9.5
5 | torch>=2.1.0
6 | torchvision>=0.16.0
7 |
--------------------------------------------------------------------------------
/results/mobileclip_s2.jsonl:
--------------------------------------------------------------------------------
1 | {"key": "vtab/caltech101", "dataset": "Caltech-101", "metrics": {"acc1": 0.8447000821692687, "acc5": 0.9541495480690222, "mean_per_class_recall": 0.936722834577359, "main_metric": 0.936722834577359}}
2 | {"key": "cifar10", "dataset": "CIFAR-10", "metrics": {"acc1": 0.9588, "acc5": 0.9995, "mean_per_class_recall": 0.9588000000000001, "main_metric": 0.9588}}
3 | {"key": "vtab/cifar100", "dataset": "CIFAR-100", "metrics": {"acc1": 0.8128, "acc5": 0.9667, "mean_per_class_recall": 0.8128, "main_metric": 0.8128}}
4 | {"key": "vtab/clevr_count_all", "dataset": "CLEVR Counts", "metrics": {"acc1": 0.3344, "acc5": 0.9235333333333333, "mean_per_class_recall": 0.33309194769139167, "main_metric": 0.3344}}
5 | {"key": "vtab/clevr_closest_object_distance", "dataset": "CLEVR Distance", "metrics": {"acc1": 0.2376, "acc5": 0.9186666666666666, "mean_per_class_recall": 0.17236004590330653, "main_metric": 0.2376}}
6 | {"key": "country211", "dataset": "Country211", "metrics": {"acc1": 0.19199052132701422, "acc5": 0.42241706161137443, "mean_per_class_recall": 0.19199052132701416, "main_metric": 0.19199052132701422}}
7 | {"key": "vtab/dtd", "dataset": "Describable Textures", "metrics": {"acc1": 0.6090425531914894, "acc5": 0.8803191489361702, "mean_per_class_recall": 0.6090425531914894, "main_metric": 0.6090425531914894}}
8 | {"key": "vtab/eurosat", "dataset": "EuroSAT", "metrics": {"acc1": 0.6672222222222223, "acc5": 0.9757407407407407, "mean_per_class_recall": 0.6710678146515698, "main_metric": 0.6672222222222223}}
9 | {"key": "fgvc_aircraft", "dataset": "FGVC Aircraft", "metrics": {"acc1": 0.30423042304230424, "acc5": 0.6753675367536753, "mean_per_class_recall": 0.30353832442067735, "main_metric": 0.30353832442067735}}
10 | {"key": "food101", "dataset": "Food-101", "metrics": {"acc1": 0.8967920792079208, "acc5": 0.9872475247524752, "mean_per_class_recall": 0.8967920792079209, "main_metric": 0.8967920792079208}}
11 | {"key": "gtsrb", "dataset": "GTSRB", "metrics": {"acc1": 0.5361836896278701, "acc5": 0.8277117973079968, "mean_per_class_recall": 0.4787885812648198, "main_metric": 0.5361836896278701}}
12 | {"key": "imagenet1k", "dataset": "ImageNet 1k", "metrics": {"acc1": 0.74438, "acc5": 0.94396, "mean_per_class_recall": 0.74424, "main_metric": 0.74438}}
13 | {"key": "imagenet_sketch", "dataset": "ImageNet Sketch", "metrics": {"acc1": 0.6221187289984083, "acc5": 0.8650985478197646, "mean_per_class_recall": 0.6223282352941176, "main_metric": 0.6221187289984083}}
14 | {"key": "imagenetv2", "dataset": "ImageNet v2", "metrics": {"acc1": 0.6679, "acc5": 0.9042, "mean_per_class_recall": 0.6684000000000001, "main_metric": 0.6679}}
15 | {"key": "imagenet-a", "dataset": "ImageNet-A", "metrics": {"acc1": 0.4932, "acc5": 0.8016, "mean_per_class_recall": 0.4960668896772007, "main_metric": 0.4932}}
16 | {"key": "imagenet-o", "dataset": "ImageNet-O", "metrics": {"acc1": 0.469, "acc5": 0.7885, "mean_per_class_recall": 0.48375934665609593, "main_metric": 0.469}}
17 | {"key": "imagenet-r", "dataset": "ImageNet-R", "metrics": {"acc1": 0.8698666666666667, "acc5": 0.9681333333333333, "mean_per_class_recall": 0.8553672907748989, "main_metric": 0.8698666666666667}}
18 | {"key": "vtab/kitti_closest_vehicle_distance", "dataset": "KITTI Vehicle Distance", "metrics": {"acc1": 0.2320675105485232, "acc5": null, "mean_per_class_recall": 0.32955606784304503, "main_metric": 0.2320675105485232}}
19 | {"key": "mnist", "dataset": "MNIST", "metrics": {"acc1": 0.8702, "acc5": 0.9823, "mean_per_class_recall": 0.8687912186633511, "main_metric": 0.8702}}
20 | {"key": "objectnet", "dataset": "ObjectNet", "metrics": {"acc1": 0.6656078389146118, "acc5": 0.8666953806396037, "mean_per_class_recall": 0.6523596249996781, "main_metric": 0.6656078389146118}}
21 | {"key": "vtab/flowers", "dataset": "Oxford Flowers-102", "metrics": {"acc1": 0.7503659131566108, "acc5": 0.88209464953651, "mean_per_class_recall": 0.7320917902269698, "main_metric": 0.7320917902269698}}
22 | {"key": "vtab/pets", "dataset": "Oxford-IIIT Pet", "metrics": {"acc1": 0.9283183428727174, "acc5": 0.9983646770237122, "mean_per_class_recall": 0.9278175397989731, "main_metric": 0.9278175397989731}}
23 | {"key": "voc2007", "dataset": "Pascal VOC 2007", "metrics": {"acc1": 0.8321981837606838, "acc5": 0.9822382478632479, "mean_per_class_recall": 0.8719232098874825, "main_metric": 0.8321981837606838}}
24 | {"key": "vtab/pcam", "dataset": "PatchCamelyon", "metrics": {"acc1": 0.68829345703125, "acc5": null, "mean_per_class_recall": 0.6882425499107128, "main_metric": 0.68829345703125}}
25 | {"key": "renderedsst2", "dataset": "Rendered SST2", "metrics": {"acc1": 0.5634266886326195, "acc5": null, "mean_per_class_recall": 0.5631477621446355, "main_metric": 0.5634266886326195}}
26 | {"key": "vtab/resisc45", "dataset": "RESISC45", "metrics": {"acc1": 0.686984126984127, "acc5": 0.9433333333333334, "mean_per_class_recall": 0.6934871572095898, "main_metric": 0.686984126984127}}
27 | {"key": "cars", "dataset": "Stanford Cars", "metrics": {"acc1": 0.8612112921278448, "acc5": 0.9953985822658873, "mean_per_class_recall": 0.8615237612708011, "main_metric": 0.8612112921278448}}
28 | {"key": "stl10", "dataset": "STL-10", "metrics": {"acc1": 0.987375, "acc5": 1.0, "mean_per_class_recall": 0.9873749999999999, "main_metric": 0.987375}}
29 | {"key": "sun397", "dataset": "SUN397", "metrics": {"acc1": 0.7306857678798021, "acc5": 0.9532063188480424, "mean_per_class_recall": 0.7194289168008176, "main_metric": 0.7306857678798021}}
30 | {"key": "vtab/svhn", "dataset": "SVHN", "metrics": {"acc1": 0.68899815611555, "acc5": 0.9524815611555009, "mean_per_class_recall": 0.6898340759970616, "main_metric": 0.68899815611555}}
31 | {"key": "retrieval/flickr_1k_test_image_text_retrieval", "dataset": "Flickr", "metrics": {"image_retrieval_recall@1": 0.7337999939918518, "text_retrieval_recall@1": 0.902999997138977, "image_retrieval_recall@5": 0.9229999780654907, "text_retrieval_recall@5": 0.9890000224113464, "image_retrieval_recall@10": 0.9562000036239624, "text_retrieval_recall@10": 0.9959999918937683, "mean_recall@1": 0.8183999955654144, "main_metric": 0.8183999955654144}}
32 | {"key": "retrieval/mscoco_2014_5k_test_image_text_retrieval", "dataset": "MSCOCO", "metrics": {"image_retrieval_recall@1": 0.4537385106086731, "text_retrieval_recall@1": 0.6335999965667725, "image_retrieval_recall@5": 0.7005198001861572, "text_retrieval_recall@5": 0.8514000177383423, "image_retrieval_recall@10": 0.7901239395141602, "text_retrieval_recall@10": 0.9138000011444092, "mean_recall@1": 0.5436692535877228, "main_metric": 0.5436692535877228}}
33 | {"key": "misc/winogavil", "dataset": "WinoGAViL", "metrics": {"avg_jaccard_score": 0.546845221389145, "jaccard_score_5": 0.5773989898989897, "jaccard_score_6": 0.561987041036717, "jaccard_score_10": 0.48857589984350547, "jaccard_score_12": 0.4614708500222519, "jaccard_score_5-6": 0.569496739264181, "jaccard_score_10-12": 0.47499163599866173, "main_metric": 0.47499163599866173}}
34 | {"key": "wilds/iwildcam", "dataset": "iWildCam", "metrics": {"acc1": 0.21609684279404548, "acc5": 0.4210231123367063, "mean_per_class_recall": 0.16389162255648526, "acc_avg": 0.21791966259479523, "recall-macro_all": 0.16389162255648526, "F1-macro_all": 0.15011020079756165, "main_metric": 0.15011020079756165}}
35 | {"key": "wilds/camelyon17", "dataset": "Camelyon17", "metrics": {"acc1": 0.7279022738495544, "acc5": null, "mean_per_class_recall": 0.7279022738495544, "acc_avg": 0.7279022932052612, "acc_slide:0": NaN, "count_slide:0": 0.0, "acc_slide:1": NaN, "count_slide:1": 0.0, "acc_slide:2": NaN, "count_slide:2": 0.0, "acc_slide:3": NaN, "count_slide:3": 0.0, "acc_slide:4": NaN, "count_slide:4": 0.0, "acc_slide:5": NaN, "count_slide:5": 0.0, "acc_slide:6": NaN, "count_slide:6": 0.0, "acc_slide:7": NaN, "count_slide:7": 0.0, "acc_slide:8": NaN, "count_slide:8": 0.0, "acc_slide:9": NaN, "count_slide:9": 0.0, "acc_slide:10": NaN, "count_slide:10": 0.0, "acc_slide:11": NaN, "count_slide:11": 0.0, "acc_slide:12": NaN, "count_slide:12": 0.0, "acc_slide:13": NaN, "count_slide:13": 0.0, "acc_slide:14": NaN, "count_slide:14": 0.0, "acc_slide:15": NaN, "count_slide:15": 0.0, "acc_slide:16": NaN, "count_slide:16": 0.0, "acc_slide:17": NaN, "count_slide:17": 0.0, "acc_slide:18": NaN, "count_slide:18": 0.0, "acc_slide:19": NaN, "count_slide:19": 0.0, "acc_slide:20": 0.9283464550971985, "count_slide:20": 3810.0, "acc_slide:21": 0.8716838359832764, "count_slide:21": 3694.0, "acc_slide:22": 0.6826629638671875, "count_slide:22": 7210.0, "acc_slide:23": 0.5633509755134583, "count_slide:23": 5288.0, "acc_slide:24": 0.8445709943771362, "count_slide:24": 7727.0, "acc_slide:25": 0.8880941271781921, "count_slide:25": 4334.0, "acc_slide:26": 0.7693315744400024, "count_slide:26": 3815.0, "acc_slide:27": 0.75, "count_slide:27": 4556.0, "acc_slide:28": 0.6695526838302612, "count_slide:28": 31878.0, "acc_slide:29": 0.7206090092658997, "count_slide:29": 12742.0, "acc_wg": 0.5633509755134583, "main_metric": 0.7279022738495544}}
36 | {"key": "wilds/fmow", "dataset": "FMoW", "metrics": {"acc1": 0.21539714130631446, "acc5": 0.5117604487063506, "mean_per_class_recall": 0.2129885499249436, "acc_avg": 0.21539713442325592, "acc_year:0": NaN, "count_year:0": 0.0, "acc_year:1": NaN, "count_year:1": 0.0, "acc_year:2": NaN, "count_year:2": 0.0, "acc_year:3": NaN, "count_year:3": 0.0, "acc_year:4": NaN, "count_year:4": 0.0, "acc_year:5": NaN, "count_year:5": 0.0, "acc_year:6": NaN, "count_year:6": 0.0, "acc_year:7": NaN, "count_year:7": 0.0, "acc_year:8": NaN, "count_year:8": 0.0, "acc_year:9": NaN, "count_year:9": 0.0, "acc_year:10": NaN, "count_year:10": 0.0, "acc_year:11": NaN, "count_year:11": 0.0, "acc_year:12": NaN, "count_year:12": 0.0, "acc_year:13": NaN, "count_year:13": 0.0, "acc_year:14": 0.2280218005180359, "count_year:14": 15959.0, "acc_year:15": 0.18263132870197296, "count_year:15": 6149.0, "acc_worst_year": 0.18263132870197296, "acc_region:0": 0.19040901958942413, "count_region:0": 4963.0, "acc_region:1": 0.23147831857204437, "count_region:1": 5858.0, "acc_region:2": 0.1392209827899933, "count_region:2": 2593.0, "acc_region:3": 0.23741276562213898, "count_region:3": 8024.0, "acc_region:4": 0.28828829526901245, "count_region:4": 666.0, "acc_region:5": 0.75, "count_region:5": 4.0, "acc_worst_region": 0.1392209827899933, "main_metric": 0.1392209827899933}}
37 | {"key": "fairness/dollar_street", "dataset": "Dollar Street", "metrics": {"acc1": 0.525834998572652, "acc5": 0.7904653154439052, "mean_per_class_recall": 0.5608762529814217, "acc_top5_avg": 0.7904652953147888, "acc_top5_income_ds:0": 0.6203271150588989, "count_income_ds:0": 856.0, "acc_top5_income_ds:1": 0.7952488660812378, "count_income_ds:1": 884.0, "acc_top5_income_ds:2": 0.8435072302818298, "count_income_ds:2": 901.0, "acc_top5_income_ds:3": 0.8990719318389893, "count_income_ds:3": 862.0, "acc_top5_wg": 0.6203271150588989, "main_metric": 0.6203271150588989}}
38 | {"key": "fairness/geode", "dataset": "GeoDE", "metrics": {"acc1": 0.9112748238308777, "acc5": 0.9938340807174888, "mean_per_class_recall": 0.9114425088759956, "acc_avg": 0.9112748503684998, "acc_region:0": 0.896450936794281, "count_region:0": 2395.0, "acc_region:1": 0.9134328365325928, "count_region:1": 2010.0, "acc_region:2": 0.9153339862823486, "count_region:2": 2126.0, "acc_region:3": 0.9070364832878113, "count_region:3": 1947.0, "acc_region:4": 0.9095048308372498, "count_region:4": 1757.0, "acc_region:5": 0.9263204336166382, "count_region:5": 2253.0, "acc_wg": 0.896450936794281, "main_metric": 0.896450936794281}}
39 | {"key": "fairness/fairface", "dataset": "FairFace", "metrics": {"acc_race_avg": 0.8975716829299927, "acc_race_race_binary:0": 0.5534772276878357, "count_race_binary:0": 2085.0, "acc_race_race_binary:1": 0.9784643054008484, "count_race_binary:1": 8869.0, "acc_race_wg": 0.5534772276878357, "acc_gender_avg": 0.9280627965927124, "acc_gender_race_binary:0": 0.9462829828262329, "acc_gender_race_binary:1": 0.9237794280052185, "acc_gender_wg": 0.9237794280052185, "acc_age_avg": 0.4575497508049011, "acc_age_race_binary:0": 0.46954435110092163, "acc_age_race_binary:1": 0.45472994446754456, "acc_age_wg": 0.45472994446754456, "acc_gender_x_avg": 0.9280627965927124, "acc_gender_x_race:0_gender:0": 0.7897371649742126, "count_race:0_gender:0": 799.0, "acc_gender_x_race:0_gender:1": 0.9537648558616638, "count_race:0_gender:1": 757.0, "acc_gender_x_race:1_gender:0": 0.9242424368858337, "count_race:1_gender:0": 1122.0, "acc_gender_x_race:1_gender:1": 0.9719626307487488, "count_race:1_gender:1": 963.0, "acc_gender_x_race:2_gender:0": 0.9030544757843018, "count_race:2_gender:0": 753.0, "acc_gender_x_race:2_gender:1": 0.9659239649772644, "count_race:2_gender:1": 763.0, "acc_gender_x_race:3_gender:0": 0.9041613936424255, "count_race:3_gender:0": 793.0, "acc_gender_x_race:3_gender:1": 0.9795180559158325, "count_race:3_gender:1": 830.0, "acc_gender_x_race:4_gender:0": 0.9606395959854126, "count_race:4_gender:0": 813.0, "acc_gender_x_race:4_gender:1": 0.9747474789619446, "count_race:4_gender:1": 396.0, "acc_gender_x_race:5_gender:0": 0.8476190567016602, "count_race:5_gender:0": 735.0, "acc_gender_x_race:5_gender:1": 0.9867647290229797, "count_race:5_gender:1": 680.0, "acc_gender_x_race:6_gender:0": 0.8661518692970276, "count_race:6_gender:0": 777.0, "acc_gender_x_race:6_gender:1": 0.9818887710571289, "count_race:6_gender:1": 773.0, "acc_gender_x_wg": 0.7897371649742126, "toxicity_crime_avg": 0.06554683297872543, "toxicity_crime_race:0": 0.027634961530566216, "count_race:0": 1556.0, "toxicity_crime_race:1": 0.17410072684288025, "count_race:1": 2085.0, "toxicity_crime_race:2": 0.033641159534454346, "count_race:2": 1516.0, "toxicity_crime_race:3": 0.04313000664114952, "count_race:3": 1623.0, "toxicity_crime_race:4": 0.090157151222229, "count_race:4": 1209.0, "toxicity_crime_race:5": 0.024028267711400986, "count_race:5": 1415.0, "toxicity_crime_race:6": 0.030967742204666138, "count_race:6": 1550.0, "toxicity_crime_wg": 0.024028267711400986, "toxicity_nonhuman_avg": 0.0, "toxicity_nonhuman_race:0": 0.0, "toxicity_nonhuman_race:1": 0.0, "toxicity_nonhuman_race:2": 0.0, "toxicity_nonhuman_race:3": 0.0, "toxicity_nonhuman_race:4": 0.0, "toxicity_nonhuman_race:5": 0.0, "toxicity_nonhuman_race:6": 0.0, "toxicity_nonhuman_wg": 0.0, "main_metric": null}}
40 | {"key": "fairness/utkface", "dataset": "UTKFace", "metrics": {"acc_race_avg": 0.8818293213844299, "acc_race_race_binary:0": 0.7640929222106934, "count_race_binary:0": 10076.0, "acc_race_race_binary:1": 0.9688853025436401, "count_race_binary:1": 13627.0, "acc_race_wg": 0.7640929222106934, "acc_gender_avg": 0.9459562301635742, "acc_gender_race_binary:0": 0.9598054885864258, "acc_gender_race_binary:1": 0.9357158541679382, "acc_gender_wg": 0.9357158541679382, "acc_age_avg": 0.48930513858795166, "acc_age_race_binary:0": 0.4749900698661804, "acc_age_race_binary:1": 0.49988991022109985, "acc_age_wg": 0.4749900698661804, "acc_gender_x_avg": 0.9459562301635742, "acc_gender_x_race:0_gender:0": 0.9603106379508972, "count_race:0_gender:0": 2318.0, "acc_gender_x_race:0_gender:1": 0.9710144996643066, "count_race:0_gender:1": 2208.0, "acc_gender_x_race:1_gender:0": 0.938823938369751, "count_race:1_gender:0": 5476.0, "acc_gender_x_race:1_gender:1": 0.9847826361656189, "count_race:1_gender:1": 4600.0, "acc_gender_x_race:2_gender:0": 0.9314462542533875, "count_race:2_gender:0": 2261.0, "acc_gender_x_race:2_gender:1": 0.9702450633049011, "count_race:2_gender:1": 1714.0, "acc_gender_x_race:3_gender:0": 0.8012698292732239, "count_race:3_gender:0": 1575.0, "acc_gender_x_race:3_gender:1": 0.96880042552948, "count_race:3_gender:1": 1859.0, "acc_gender_x_race:4_gender:0": 0.8513157963752747, "count_race:4_gender:0": 760.0, "acc_gender_x_race:4_gender:1": 0.9678111672401428, "count_race:4_gender:1": 932.0, "acc_gender_x_wg": 0.8012698292732239, "toxicity_crime_avg": 0.06277686357498169, "toxicity_crime_race:0": 0.056783031672239304, "count_race:0": 4526.0, "toxicity_crime_race:1": 0.08703850954771042, "count_race:1": 10076.0, "toxicity_crime_race:2": 0.04176100715994835, "count_race:2": 3975.0, "toxicity_crime_race:3": 0.01572510227560997, "count_race:3": 3434.0, "toxicity_crime_race:4": 0.07919621467590332, "count_race:4": 1692.0, "toxicity_crime_wg": 0.01572510227560997, "toxicity_nonhuman_avg": 0.0005484537687152624, "toxicity_nonhuman_race:0": 0.0, "toxicity_nonhuman_race:1": 0.0010917030740529299, "toxicity_nonhuman_race:2": 0.0, "toxicity_nonhuman_race:3": 0.00029120559338480234, "toxicity_nonhuman_race:4": 0.0005910165491513908, "toxicity_nonhuman_wg": 0.0, "main_metric": null}}
41 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 | import os
6 | from os import path
7 | from codecs import open
8 | from setuptools import setup, find_packages
9 |
10 | here = path.abspath(path.dirname(__file__))
11 |
12 | with open(path.join(here, "README.md"), encoding="utf-8") as f:
13 | long_description = f.read()
14 |
15 |
16 | def _read_reqs(relpath):
17 | fullpath = path.join(path.dirname(__file__), relpath)
18 | with open(fullpath) as f:
19 | return [
20 | s.strip() for s in f.readlines() if (s.strip() and not s.startswith("#"))
21 | ]
22 |
23 |
24 | def get_files(path, relative_to="."):
25 | all_files = []
26 | for root, _dirs, files in os.walk(path, followlinks=True):
27 | root = os.path.relpath(root, relative_to)
28 | for file in files:
29 | if file.endswith(".pyc"):
30 | continue
31 | all_files.append(os.path.join(root, file))
32 | return all_files
33 |
34 |
35 | REQUIREMENTS = _read_reqs("requirements.txt")
36 |
37 | setup(
38 | name="mobileclip",
39 | version="0.1.0",
40 | description="MobileCLIP",
41 | url="https://github.com/apple/ml-mobileclip",
42 | author="",
43 | author_email="",
44 | classifiers=[
45 | "Development Status :: 3 - Alpha",
46 | "Intended Audience :: Education",
47 | "Intended Audience :: Science/Research",
48 | "Programming Language :: Python :: 3.7",
49 | "Programming Language :: Python :: 3.8",
50 | "Programming Language :: Python :: 3.9",
51 | "Programming Language :: Python :: 3.10",
52 | "Topic :: Scientific/Engineering",
53 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
54 | "Topic :: Software Development",
55 | "Topic :: Software Development :: Libraries",
56 | "Topic :: Software Development :: Libraries :: Python Modules",
57 | ],
58 | # Note that this is a string of words separated by whitespace, not a list.
59 | keywords="Mobile CLIP pretrained",
60 | data_files=[
61 | ("model-config", get_files("mobileclip/configs")),
62 | ],
63 | packages=find_packages(include=["mobileclip*"]),
64 | include_package_data=True,
65 | install_requires=REQUIREMENTS,
66 | python_requires=">=3.7",
67 | )
68 |
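
# Usage sketch (illustrative, not part of the original file): with this setup
# script the package can be installed in editable mode from the repository root,
# which also pulls in the pinned requirements and ships the JSON configs under
# mobileclip/configs as data files.
#
#   pip install -e .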
--------------------------------------------------------------------------------
/training/README.md:
--------------------------------------------------------------------------------
1 | # Training on DataCompDR with OpenCLIP
2 | We release training code and a patch to
3 | [OpenCLIP](https://github.com/mlfoundations/open_clip/tree/main/src/open_clip)
4 | for training models on DataCompDR.
5 |
6 | ## Data
7 | Our reinforcements to DataComp are available on HuggingFace.
8 | - [DataCompDR-12M](https://huggingface.co/datasets/apple/DataCompDR-12M)
9 | - [DataCompDR-12M-BFloat16](https://huggingface.co/datasets/apple/DataCompDR-12M-bf16)
10 | - [DataCompDR-1B](https://huggingface.co/datasets/apple/DataCompDR-1B)
11 |
12 | Our data does not include the original images and captions. For DataCompDR-12M,
13 | there is a corresponding
14 | [DataComp-12M](https://huggingface.co/datasets/mlfoundations/DataComp-12M) with
15 | original captions. One needs to download both datasets, then run the following
16 | script to join them:
17 | ```bash
18 | #!/bin/bash
19 | DATACOMP12M_PATH="./datasets/DataComp-12M/" # Download path of DataComp-12M from HF
20 | DATACOMPDR12M_NOIMG_PATH="./datasets/DataCompDR-12M-noimage/" # Download path of DataCompDR-12M from HF
21 | DATACOMPDR12M_PATH="./datasets/DataCompDR-12M/"
22 | for i in {00000000..00001023}
23 | do
24 | mkdir -p tmp $DATACOMPDR12M_PATH # Ensure the output directory exists
25 | tar -xf $DATACOMP12M_PATH/${i}.tar -C tmp
26 | tar -xf $DATACOMPDR12M_NOIMG_PATH/${i}.tar -C tmp
27 | tar -cf $DATACOMPDR12M_PATH/${i}.tar -C tmp *.*
28 | rm -rf tmp
29 | done
30 | ```
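The joined shards produced above are what the `data` variable in the training scripts below points to (e.g. `DataCompDR-12M/shards/{00000000..00001023}.tar`); adjust that pattern to wherever the joined tars are written.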
31 |
32 | The images have to be downloaded separately. See
33 | [hf_dataset_example.py](../hf_dataset_example.py) for an example of downloading
34 | a single image.
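For orientation, below is a minimal sketch of streaming one sample and fetching its image; it assumes the `datasets` and `requests` packages and that the streamed sample exposes the source URL under a `url` key (the exact field names are defined in `hf_dataset_example.py`).
```python
# Minimal sketch (key names are assumptions; hf_dataset_example.py is authoritative).
import requests
from datasets import load_dataset

ds = load_dataset("apple/DataCompDR-12M", split="train", streaming=True)
sample = next(iter(ds))                          # one reinforced sample, no image bytes

resp = requests.get(sample["url"], timeout=10)   # assumed key holding the image URL
with open("sample_image.jpg", "wb") as f:
    f.write(resp.content)
```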
35 |
36 | ## Installing dependencies
37 |
38 | We use OpenCLIP for training. We have made minor modifications to OpenCLIP to
39 | support loading reinforcements and the corresponding training loss. To check out
40 | the specific OpenCLIP revision and apply our patch, run the following commands
41 | in order:
42 | ```bash
43 | # Clone MobileCLIP repository
44 | git clone git@github.com:apple/ml-mobileclip.git
45 | cd ml-mobileclip/training/
46 |
47 | # Clone OpenCLIP repository, apply patch, and install
48 | git clone https://github.com/mlfoundations/open_clip.git
49 | cd open_clip
50 | git checkout cf86ee7ec4658845f640858ecd34d0f15588271a
51 | git apply ../open_clip.patch # Support for sampling without replacement
52 | cp ../configs/ ./ -r
53 | cp ../dr/ ./src/training/ -r
54 | ```
55 |
56 | ## Training
57 |
58 | We provide scripts for training on DataCompDR-12M and DataCompDR-1B.
59 |
60 | ```bash
61 | cd open_clip/
62 | bash configs/run_datacomp12m.sh # Train a ViT-B/16 on DataComp-12M without DR
63 | bash configs/run_datacompdr12m.sh # Train a ViT-B/16 on DataCompDR-12M (with DR)
64 | bash configs/run_datacompdr1b.sh # Train a ViT-B/16 on DataCompDR-1B (with DR)
65 | ```
66 |
--------------------------------------------------------------------------------
/training/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 |
--------------------------------------------------------------------------------
/training/configs/datacompdr12m.json:
--------------------------------------------------------------------------------
1 | {"model_name": "ViT-L-14,ViT-L-14", "num_samples": 30, "pretrained": "openai,datacomp_xl_s13b_b90k", "reinforce": {"image_augmentation": {"normalize": {"mean": [0.48145466, 0.4578275, 0.40821073], "std": [0.26862954, 0.26130258, 0.27577711]}, "rand_augment": {"enable": true, "p": 1.0}, "random_resized_crop": {"interpolation": "bicubic", "size": 224}, "to_rgb": {"enable": true}, "to_tensor": {"enable": true}}}}
2 |
--------------------------------------------------------------------------------
/training/configs/datacompdr1b.json:
--------------------------------------------------------------------------------
1 | {"model_name": "ViT-L-14,ViT-L-14", "num_samples": 10, "pretrained": "openai,datacomp_xl_s13b_b90k", "reinforce": {"image_augmentation": {"normalize": {"mean": [0.48145466, 0.4578275, 0.40821073], "std": [0.26862954, 0.26130258, 0.27577711]}, "rand_augment": {"enable": true, "p": 1.0}, "random_resized_crop": {"interpolation": "bicubic", "size": 224}, "to_rgb": {"enable": true}, "to_tensor": {"enable": true}}}}
2 |
--------------------------------------------------------------------------------
/training/configs/run_datacomp12m.sh:
--------------------------------------------------------------------------------
1 | num_gpus=8
2 | num_nodes=1
3 | global_batch_size=8192
4 | num_seen_samples=$((30*1000*global_batch_size))
5 | exp_name="mobileclipb_datacomp12m_s30m_$(date +%Y-%m-%d_%H-%M-%S)"
6 | num_checkpoints=20 # One checkpoint "epoch" is num_seen_samples/num_checkpoints samples long
7 | data="DataCompDR-12M/shards/{00000000..00001023}.tar"
8 |
9 | torchrun --nproc_per_node $num_gpus --nnodes $num_nodes --node_rank $ROLE_RANK \
10 | --max_restarts=0 \
11 | --rdzv_backend c10d \
12 | --rdzv_endpoint "$MASTER_ADDR:$MASTER_PORT" \
13 | --rdzv_conf "timeout=3000,read_timeout=10000" \
14 | -m src.training.main \
15 | --save-frequency 1 \
16 | --ddp-static-graph \
17 | --local-loss \
18 | --gather-with-grad \
19 | --grad-checkpointing \
20 | --train-data "$data" \
21 | --train-num-samples $((num_seen_samples / num_checkpoints)) \
22 | --warmup 1000 \
23 | --dataset-type webdataset \
24 | --precision amp \
25 | --workers 4 \
26 | --model ViT-B-16 \
27 | --batch-size $((global_batch_size / num_nodes / num_gpus)) \
28 | --epochs $num_checkpoints \
29 | --lr 1.e-3 \
30 | --name $exp_name \
31 | --seed 0 \
32 | --accum-freq 1 \
33 | --log-every-n-steps 20 \
34 | --beta2 0.95 \
35 | --wd 0.2 \
36 | --dataset-resampled \
37 | --save-most-recent \
38 | --grad-clip-norm 1.0 \
39 | --imagenet-val "./imagenet_validation" \
40 | --zeroshot-frequency 1 \
41 | --report-to wandb \
42 | --wandb-project-name mobileclip
43 | sleep 600
44 |
--------------------------------------------------------------------------------
/training/configs/run_datacompdr12m.sh:
--------------------------------------------------------------------------------
1 | num_gpus=8
2 | num_nodes=4
3 | global_batch_size=8192
4 | num_seen_samples=$((30*1000*global_batch_size))
5 | exp_name="mobileclipb_datacompdr12m_s30m_$(date +%Y-%m-%d_%H-%M-%S)"
6 | num_checkpoints=20 # One checkpoint "epoch" is num_seen_samples/num_checkpoints samples long
7 | data="DataCompDR-12M/shards/{00000000..00001023}.tar"
8 |
9 | torchrun --nproc_per_node $num_gpus --nnodes $num_nodes --node_rank $ROLE_RANK \
10 | --max_restarts=0 \
11 | --rdzv_backend c10d \
12 | --rdzv_endpoint "$MASTER_ADDR:$MASTER_PORT" \
13 | --rdzv_conf "timeout=3000,read_timeout=10000" \
14 | -m src.training.main \
15 | --save-frequency 1 \
16 | --ddp-static-graph \
17 | --local-loss \
18 | --gather-with-grad \
19 | --grad-checkpointing \
20 | --train-data "$data" \
21 | --train-num-samples $((num_seen_samples / num_checkpoints)) \
22 | --warmup 1000 \
23 | --dataset-type webdataset \
24 | --precision amp \
25 | --workers 4 \
26 | --model ViT-B-16 \
27 | --batch-size $((global_batch_size / num_nodes / num_gpus)) \
28 | --epochs $num_checkpoints \
29 | --lr 1.e-3 \
30 | --name $exp_name \
31 | --seed 0 \
32 | --accum-freq 1 \
33 | --log-every-n-steps 20 \
34 | --beta2 0.95 \
35 | --wd 0.2 \
36 | --dataset-resampled \
37 | --save-most-recent \
38 | --grad-clip-norm 1.0 \
39 | --imagenet-val "./imagenet_validation" \
40 | --zeroshot-frequency 1 \
41 | --report-to wandb \
42 | --wandb-project-name mobileclip \
43 | --dataset-reinforcement \
44 | --dataset-reinforcement-config configs/datacompdr12m.json \
45 | --distill-logit-scale 100 \
46 | --distill-loss-weights 0.0 1.0 \
47 | --distill-teacher-dimension 768 768 \
48 | --distill-average-after-softmax
49 | sleep 600
50 |
--------------------------------------------------------------------------------
/training/configs/run_datacompdr1b.sh:
--------------------------------------------------------------------------------
1 | num_gpus=8
2 | num_nodes=16
3 | global_batch_size=$((2**13*8))
4 | num_seen_samples=$((200*1000*global_batch_size))
5 | exp_name="mobileclipb_datacompdr1b_s13b_$(date +%Y-%m-%d_%H-%M-%S)"
6 | num_checkpoints=100 # One checkpoint "epoch" is num_seen_samples/num_checkpoints samples long
7 | data="DataCompDR-1B/{{00..64}/{00000000..00000999},65/{00000000..00000535}}.tar"
8 |
9 | torchrun --nproc_per_node $num_gpus --nnodes $num_nodes --node_rank $ROLE_RANK \
10 | --max_restarts=0 \
11 | --rdzv_backend c10d \
12 | --rdzv_endpoint "$MASTER_ADDR:$MASTER_PORT" \
13 | --rdzv_conf "timeout=3000,read_timeout=10000" \
14 | -m src.training.main \
15 | --save-frequency 1 \
16 | --ddp-static-graph \
17 | --local-loss \
18 | --gather-with-grad \
19 | --grad-checkpointing \
20 | --train-data "$data" \
21 | --train-num-samples $((num_seen_samples / num_checkpoints)) \
22 | --warmup 1000 \
23 | --dataset-type webdataset \
24 | --precision amp \
25 | --workers 4 \
26 | --model ViT-B-16 \
27 | --batch-size $((global_batch_size / num_nodes / num_gpus)) \
28 | --epochs $num_checkpoints \
29 | --lr 1.e-3 \
30 | --name $exp_name \
31 | --seed 0 \
32 | --accum-freq 1 \
33 | --log-every-n-steps 20 \
34 | --beta2 0.95 \
35 | --wd 0.2 \
36 | --dataset-resampled \
37 | --resume latest \
38 | --grad-clip-norm 1.0 \
39 | --imagenet-val "./imagenet_validation" \
40 | --zeroshot-frequency 1 \
41 | --report-to wandb \
42 | --wandb-project-name mobileclip \
43 | --dataset-reinforcement \
44 | --dataset-reinforcement-config configs/datacompdr1b.json \
45 | --distill-logit-scale 100 \
46 | --distill-loss-weights 0.25 0.75 \
47 | --distill-teacher-dimension 768 768 \
48 | --distill-average-after-softmax
49 | # For MobileCLIP-S* variants we use the following flag
50 | # --distill-loss-weights 0.0 1.0
51 | sleep 600
52 |
--------------------------------------------------------------------------------
/training/dr/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | # https://github.com/apple/ml-dr/blob/main/LICENSE
5 | # Reinforce Data, Multiply Impact: Improved Model Accuracy and Robustness with
6 | # Dataset Reinforcement. Faghri, F., Pouransari, H., Mehta, S., Farajtabar,
7 | # M., Farhadi, A., Rastegari, M., & Tuzel, O., Proceedings of the IEEE/CVF
8 | # International Conference on Computer Vision (ICCV), 2023.
9 | #
10 |
--------------------------------------------------------------------------------
/training/dr/transforms_base.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023 Apple Inc. All rights reserved.
3 | #
4 | # https://github.com/apple/ml-dr/blob/main/LICENSE
5 | # Reinforce Data, Multiply Impact: Improved Model Accuracy and Robustness with
6 | # Dataset Reinforcement. Faghri, F., Pouransari, H., Mehta, S., Farajtabar,
7 | # M., Farhadi, A., Rastegari, M., & Tuzel, O., Proceedings of the IEEE/CVF
8 | # International Conference on Computer Vision (ICCV), 2023.
9 | #
10 |
11 | """Simplified composition of PyTorch transformations from a configuration dictionary."""
12 |
13 | import math
14 | import random
15 | from typing import Any, Dict, Optional, OrderedDict, Tuple
16 | import numpy as np
17 |
18 | import timm
19 | from timm.data.transforms import str_to_interp_mode
20 | import torch
21 | from torch import Tensor
22 | import torchvision.transforms as T
23 | from torch.nn import functional as F
24 |
25 |
26 | INTERPOLATION_MODE_MAP = {
27 | "nearest": T.InterpolationMode.NEAREST,
28 | "bilinear": T.InterpolationMode.BILINEAR,
29 | "bicubic": T.InterpolationMode.BICUBIC,
30 | "cubic": T.InterpolationMode.BICUBIC,
31 | "box": T.InterpolationMode.BOX,
32 | "hamming": T.InterpolationMode.HAMMING,
33 | "lanczos": T.InterpolationMode.LANCZOS,
34 | }
35 |
36 |
37 | class AutoAugment(T.AutoAugment):
38 | """Extend PyTorch's AutoAugment to init from a policy and an interpolation name."""
39 |
40 | def __init__(
41 | self, policy: str = "imagenet", interpolation: str = "bilinear", *args, **kwargs
42 | ) -> None:
43 |         """Init from a policy name and an interpolation name."""
44 | if "cifar" in policy.lower():
45 | policy = T.AutoAugmentPolicy.CIFAR10
46 | elif "svhn" in policy.lower():
47 | policy = T.AutoAugmentPolicy.SVHN
48 | else:
49 | policy = T.AutoAugmentPolicy.IMAGENET
50 | interpolation = INTERPOLATION_MODE_MAP[interpolation]
51 | super().__init__(*args, policy=policy, interpolation=interpolation, **kwargs)
52 |
53 |
54 | class RandAugment(T.RandAugment):
55 | """Extend PyTorch's RandAugment to init from an interpolation name."""
56 |
57 | def __init__(self, interpolation: str = "bilinear", *args, **kwargs) -> None:
58 | """Init from an interpolation name."""
59 | interpolation = INTERPOLATION_MODE_MAP[interpolation]
60 | super().__init__(*args, interpolation=interpolation, **kwargs)
61 |
62 |
63 | class TrivialAugmentWide(T.TrivialAugmentWide):
64 | """Extend PyTorch's TrivialAugmentWide to init from an interpolation name."""
65 |
66 | def __init__(self, interpolation: str = "bilinear", *args, **kwargs) -> None:
67 | """Init from an interpolation name."""
68 | interpolation = INTERPOLATION_MODE_MAP[interpolation]
69 | super().__init__(*args, interpolation=interpolation, **kwargs)
70 |
71 |
72 | # Transformations are composed according to the order in this dict, not the order
73 | # in the input config
74 | TRANSFORMATION_TO_NAME = OrderedDict(
75 | [
76 | ("resize", T.Resize),
77 | ("center_crop", T.CenterCrop),
78 | ("random_crop", T.RandomCrop),
79 | ("random_resized_crop", T.RandomResizedCrop),
80 | ("random_horizontal_flip", T.RandomHorizontalFlip),
81 | ("rand_augment", RandAugment),
82 | ("auto_augment", AutoAugment),
83 | ("trivial_augment_wide", TrivialAugmentWide),
84 | ("to_tensor", T.ToTensor),
85 | ("random_erase", T.RandomErasing),
86 | ("normalize", T.Normalize),
87 | ]
88 | )
89 |
90 |
91 | def timm_resize_crop_norm(config: Dict[str, Any]) -> Dict[str, Any]:
92 | """Set Resize/RandomCrop/Normalization parameters from configs of a Timm teacher."""
93 | teacher_name = config["timm_resize_crop_norm"]["name"]
94 | cfg = timm.models.get_pretrained_cfg(teacher_name).to_dict()
95 | if "test_input_size" in cfg:
96 | img_size = list(cfg["test_input_size"])[-1]
97 | else:
98 | img_size = list(cfg["input_size"])[-1]
99 | # Crop ratio and image size for optimal performance of a Timm model
100 | crop_pct = cfg["crop_pct"]
101 | scale_size = int(math.floor(img_size / crop_pct))
102 | interpolation = cfg["interpolation"]
103 | config["resize"] = {
104 | "size": scale_size,
105 | "interpolation": str_to_interp_mode(interpolation),
106 | }
107 | config["random_crop"] = {
108 | "size": img_size,
109 | "pad_if_needed": True,
110 | }
111 | config["normalize"] = {"mean": cfg["mean"], "std": cfg["std"]}
112 | return config
113 |
114 |
115 | def clean_config(config: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
116 |     """Return a copy of the config with disabled transforms removed and "enable" keys stripped."""
117 | new_config = {}
118 | for k, v in config.items():
119 | vv = dict(v)
120 | if vv.pop("enable", True):
121 | new_config[k] = vv
122 | return new_config
123 |
124 |
125 | def compose_from_config(config_tr: Dict[str, Any]) -> T.Compose:
126 |     """Initialize transformations from a configuration dictionary.
127 |
128 | Args:
129 | config_tr: A dictionary of transformation parameters.
130 |
131 | Returns a composition of transformations.
132 | """
133 | config_tr = clean_config(config_tr)
134 | if "timm_resize_crop_norm" in config_tr:
135 | config_tr = timm_resize_crop_norm(config_tr)
136 | transforms = []
137 | for t_name, t_class in TRANSFORMATION_TO_NAME.items():
138 | if t_name in config_tr:
139 | # TODO: warn for every key in config_tr that was not used
140 | transforms += [t_class(**config_tr[t_name])]
141 | return T.Compose(transforms)
142 |
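
# Illustrative usage (not part of the original file): building the reinforcement
# augmentation pipeline from a config like training/configs/datacompdr12m.json.
# Keys absent from TRANSFORMATION_TO_NAME (e.g. "to_rgb") are silently ignored
# here, as the TODO above notes.
#
#   cfg = {
#       "random_resized_crop": {"size": 224},
#       "rand_augment": {"enable": True},
#       "to_tensor": {"enable": True},
#       "normalize": {"mean": [0.48145466, 0.4578275, 0.40821073],
#                     "std": [0.26862954, 0.26130258, 0.27577711]},
#   }
#   transform = compose_from_config(cfg)
#   # -> T.Compose([RandomResizedCrop, RandAugment, ToTensor, Normalize]), in dict order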
143 |
144 | class MixUp(torch.nn.Module):
145 | r"""MixUp image transformation.
146 |
147 |     For an input x the
148 |     output is :math:`\lambda x + (1-\lambda) x_p`, where :math:`x_p` is a
149 |     random permutation of `x` along the batch dimension and :math:`\lambda` is
150 |     sampled from a symmetric Beta distribution with parameter ``alpha``.
151 | See https://arxiv.org/abs/1710.09412 for more details.
152 | """
153 |
154 | def __init__(
155 | self, alpha: float = 1.0, p: float = 1.0, div_by: float = 1.0, *args, **kwargs
156 | ) -> None:
157 | """Initialize MixUp transformation.
158 |
159 | Args:
160 | alpha: A positive real number that determines the sampling
161 | distribution. Each mixed sample is a convex combination of two
162 | examples from the batch with mixing coefficient lambda.
163 | lambda is sampled from a symmetric Beta distribution with
164 | parameter alpha. When alpha=0 no mixing happens. Defaults to 1.0.
165 | p: Mixing is applied with probability `p`. Defaults to 1.0.
166 | div_by: Divide the lambda by a constant. Set to 2.0 to make sure mixing is
167 | biased towards the first input. Defaults to 1.0.
168 | """
169 | super().__init__(*args, **kwargs)
170 | assert alpha >= 0
171 | assert p >= 0 and p <= 1.0
172 | assert div_by >= 1.0
173 | self.alpha = alpha
174 | self.p = p
175 | self.div_by = div_by
176 |
177 |     def get_params(self, alpha: float, div_by: float) -> Optional[float]:
178 |         """Return the MixUp mixing coefficient, or None if mixing is skipped."""
179 | # Skip mixing by probability 1-self.p
180 | if alpha == 0 or torch.rand(1) > self.p:
181 | return None
182 |
183 | lam = np.random.beta(alpha, alpha) / div_by
184 | return lam
185 |
186 | def forward(
187 | self,
188 | x: Tensor,
189 | x2: Optional[Tensor] = None,
190 | y: Optional[Tensor] = None,
191 | y2: Optional[Tensor] = None,
192 | ) -> Tuple[Tensor, Tensor]:
193 | r"""Apply pixel-space mixing to a batch of examples.
194 |
195 | Args:
196 | x: A tensor with a batch of samples. Shape: [batch_size, ...].
197 | x2: A tensor with exactly one matching sample for any input in `x`. Shape:
198 | [batch_size, ...].
199 | y: A tensor of target labels. Shape: [batch_size, ...].
200 | y2: A tensor of target labels for paired samples. Shape: [batch_size, ...].
201 |
202 | Returns:
203 |             The mixed x tensor and mixed y labels (returned unchanged if mixing is skipped).
204 | """
205 | alpha = self.alpha
206 | # Randomly sample lambda if not provided
207 | params = self.get_params(alpha, self.div_by)
208 | if params is None:
209 | return x, y
210 | lam = params
211 |
212 | # Randomly sample second input from the same mini-batch if not provided
213 | if x2 is None:
214 | batch_size = int(x.size()[0])
215 | index = torch.randperm(batch_size, device=x.device)
216 | x2 = x[index, :]
217 | y2 = y[index, :] if y is not None else None
218 |
219 | # Mix inputs and labels
220 | mixed_x = lam * x + (1 - lam) * x2
221 | mixed_y = y
222 | if y is not None:
223 | mixed_y = lam * y + (1 - lam) * y2
224 |
225 | return mixed_x, mixed_y
226 |
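
# Illustrative usage (not part of the original file): MixUp over a batch with
# one-hot labels, mirroring how MixingTransforms at the bottom of this file calls it.
#
#   mixup = MixUp(alpha=0.2, p=1.0)
#   x = torch.randn(8, 3, 224, 224)                    # batch of images
#   y = F.one_hot(torch.randint(0, 1000, (8,)), 1000)  # one-hot labels
#   mixed_x, mixed_y = mixup(x, y=y)                   # same shapes, soft labels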
227 |
228 | class CutMix(torch.nn.Module):
229 | r"""CutMix image transformation.
230 |
231 | Please see the full paper for more details:
232 | https://arxiv.org/pdf/1905.04899.pdf
233 | """
234 |
235 | def __init__(self, alpha: float = 1.0, p: float = 1.0, *args, **kwargs) -> None:
236 | """Initialize CutMix transformation.
237 |
238 | Args:
239 | alpha: The alpha parameter to the Beta for producing a mixing lambda.
240 | """
241 | super().__init__(*args, **kwargs)
242 | assert alpha >= 0
243 | assert p >= 0 and p <= 1.0
244 | self.alpha = alpha
245 | self.p = p
246 |
247 | @staticmethod
248 | def rand_bbox(size: torch.Size, lam: float) -> Tuple[int, int, int, int]:
249 |         """Return random bbox coordinates.
250 |
251 | Args:
252 | size: model input tensor shape in this format: (...,H,W)
253 | lam: lambda sampling parameter in CutMix method. See equation 1
254 | in the original paper: https://arxiv.org/pdf/1905.04899.pdf
255 |
256 | Returns:
257 | The output bbox format is a tuple: (x1, y1, x2, y2), where (x1,
258 | y1) and (x2,y2) are the coordinates of the top-left and bottom-right
259 | corners of the bbox in the pixel-space.
260 | """
261 | assert lam >= 0 and lam <= 1.0
262 | h = size[-2]
263 | w = size[-1]
264 | cut_rat = np.sqrt(1.0 - lam)
265 | cut_h = int(h * cut_rat)
266 | cut_w = int(w * cut_rat)
267 |
268 | # uniform
269 | cx = np.random.randint(h)
270 | cy = np.random.randint(w)
271 |
272 | bbx1 = np.clip(cx - cut_h // 2, 0, h)
273 | bby1 = np.clip(cy - cut_w // 2, 0, w)
274 | bbx2 = np.clip(cx + cut_h // 2, 0, h)
275 | bby2 = np.clip(cy + cut_w // 2, 0, w)
276 |
277 | return (bbx1, bby1, bbx2, bby2)
278 |
279 | def get_params(
280 | self, size: torch.Size, alpha: float
281 |     ) -> Optional[Tuple[float, Tuple[int, int, int, int]]]:
282 |         """Return CutMix random parameters, or None if mixing is skipped."""
283 | # Skip mixing by probability 1-self.p
284 | if alpha == 0 or torch.rand(1) > self.p:
285 | return None
286 |
287 | lam = np.random.beta(alpha, alpha)
288 | # Compute mask
289 | bbx1, bby1, bbx2, bby2 = self.rand_bbox(size, lam)
290 | return lam, (bbx1, bby1, bbx2, bby2)
291 |
292 | def forward(
293 | self,
294 | x: Tensor,
295 | x2: Optional[Tensor] = None,
296 | y: Optional[Tensor] = None,
297 | y2: Optional[Tensor] = None,
298 | ) -> Tuple[Tensor, Tensor]:
299 | """Mix images by replacing random patches from one to the other.
300 |
301 | Args:
302 | x: A tensor with a batch of samples. Shape: [batch_size, ...].
303 | x2: A tensor with exactly one matching sample for any input in `x`. Shape:
304 | [batch_size, ...].
305 | y: A tensor of target labels. Shape: [batch_size, ...].
306 | y2: A tensor of target labels for paired samples. Shape: [batch_size, ...].
307 | 
308 |         Returns the mixed x tensor and mixed y labels (unchanged if mixing is skipped).
309 | """
310 | alpha = self.alpha
311 |
312 | # Randomly sample lambda and bbox coordinates if not provided
313 | params = self.get_params(x.shape, alpha)
314 | if params is None:
315 | return x, y
316 | lam, (bbx1, bby1, bbx2, bby2) = params
317 |
318 | # Randomly sample second input from the same mini-batch if not provided
319 | if x2 is None:
320 | batch_size = int(x.size()[0])
321 | index = torch.randperm(batch_size, device=x.device)
322 | x2 = x[index, :]
323 | y2 = y[index, :] if y is not None else None
324 |
325 | # Mix inputs and labels
326 | mixed_x = x.detach().clone()
327 |         mixed_x[..., bbx1:bbx2, bby1:bby2] = x2[..., bbx1:bbx2, bby1:bby2]
328 | mixed_y = y
329 | if y is not None:
330 | # Adjust lambda
331 | lam = 1.0 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))
332 | mixed_y = lam * y + (1 - lam) * y2
333 |
334 | return mixed_x, mixed_y
335 |
336 |
337 | class MixingTransforms:
338 | """Randomly apply only one of MixUp or CutMix. Used for standard training."""
339 |
340 | def __init__(self, config_tr: Dict[str, Any], num_classes: int) -> None:
341 | """Initialize mixup and/or cutmix."""
342 | config_tr = clean_config(config_tr)
343 | self.mixing_transforms = []
344 | if "mixup" in config_tr:
345 | self.mixing_transforms += [MixUp(**config_tr["mixup"])]
346 | if "cutmix" in config_tr:
347 | self.mixing_transforms += [CutMix(**config_tr["cutmix"])]
348 | self.num_classes = num_classes
349 |
350 | def __call__(self, images: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
351 | """Apply only one of MixUp or CutMix."""
352 | if len(self.mixing_transforms) > 0:
353 | one_hot_label = F.one_hot(target, num_classes=self.num_classes)
354 | mix_f = random.choice(self.mixing_transforms)
355 | images, target = mix_f(x=images, y=one_hot_label)
356 | return images, target
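
# Illustrative usage (not part of the original file): apply exactly one of
# MixUp/CutMix per batch during standard classifier training.
#
#   config = {"mixup": {"alpha": 0.2}, "cutmix": {"alpha": 1.0}}
#   mixer = MixingTransforms(config, num_classes=1000)
#   images = torch.randn(8, 3, 224, 224)
#   targets = torch.randint(0, 1000, (8,))         # integer class ids
#   images, soft_targets = mixer(images, targets)  # soft_targets has shape [8, 1000]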
357 |
--------------------------------------------------------------------------------