├── .gitignore ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── assets ├── overview.png ├── result_interleaved.png ├── result_main_multimodal.png └── result_main_transfer.png ├── requirements-test.txt ├── requirements-training.txt ├── requirements.txt ├── scripts ├── clipav1_vit_l16_i37_t8.sh ├── clipav2_vit_h14_i84_224_336_cl32_gap_datacomp1b.sh ├── h14_224_32_finetune.sh ├── h14_84_8_pretrain.sh ├── lcl_vit_b_16_laion.sh ├── lcl_vit_b_32_laion.sh └── lcl_vit_b_32_mmc4.sh ├── setup.py ├── src ├── open_clip │ ├── __init__.py │ ├── attentive.py │ ├── big_vision.py │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── coca_model.py │ ├── constants.py │ ├── factory.py │ ├── hf_configs.py │ ├── hf_model.py │ ├── lcl_model.py │ ├── loss.py │ ├── model.py │ ├── model_configs │ │ ├── EVA01-g-14-plus.json │ │ ├── EVA01-g-14.json │ │ ├── EVA02-B-16.json │ │ ├── EVA02-E-14-plus.json │ │ ├── EVA02-E-14.json │ │ ├── EVA02-L-14-336.json │ │ ├── EVA02-L-14.json │ │ ├── LCL_ViT-B-16_laion.json │ │ ├── LCL_ViT-B-32_laion.json │ │ ├── LCL_ViT-B-32_mmc4.json │ │ ├── RN101-quickgelu.json │ │ ├── RN101.json │ │ ├── RN50-quickgelu.json │ │ ├── RN50.json │ │ ├── RN50x16.json │ │ ├── RN50x4.json │ │ ├── RN50x64.json │ │ ├── ViT-B-16-SigLIP-256.json │ │ ├── ViT-B-16-SigLIP-384.json │ │ ├── ViT-B-16-SigLIP-512.json │ │ ├── ViT-B-16-SigLIP-i18n-256.json │ │ ├── ViT-B-16-SigLIP.json │ │ ├── ViT-B-16-plus-240.json │ │ ├── ViT-B-16-plus.json │ │ ├── ViT-B-16-quickgelu.json │ │ ├── ViT-B-16.json │ │ ├── ViT-B-32-256.json │ │ ├── ViT-B-32-plus-256.json │ │ ├── ViT-B-32-quickgelu.json │ │ ├── ViT-B-32.json │ │ ├── ViT-H-14-378-quickgelu.json │ │ ├── ViT-H-14-CLIPA-336.json │ │ ├── ViT-H-14-CLIPA.json │ │ ├── ViT-H-14-quickgelu.json │ │ ├── ViT-H-14.json │ │ ├── ViT-H-16.json │ │ ├── ViT-L-14-280.json │ │ ├── ViT-L-14-336.json │ │ ├── ViT-L-14-CLIPA-336.json │ │ ├── ViT-L-14-CLIPA.json │ │ ├── ViT-L-14-quickgelu.json │ │ ├── ViT-L-14.json │ │ ├── ViT-L-16-320.json │ │ ├── ViT-L-16-SigLIP-256.json │ │ ├── ViT-L-16-SigLIP-384.json │ │ ├── ViT-L-16.json │ │ ├── ViT-M-16-alt.json │ │ ├── ViT-M-16.json │ │ ├── ViT-M-32-alt.json │ │ ├── ViT-M-32.json │ │ ├── ViT-S-16-alt.json │ │ ├── ViT-S-16.json │ │ ├── ViT-S-32-alt.json │ │ ├── ViT-S-32.json │ │ ├── ViT-SO400M-14-SigLIP-384.json │ │ ├── ViT-SO400M-14-SigLIP.json │ │ ├── ViT-bigG-14-CLIPA-336.json │ │ ├── ViT-bigG-14-CLIPA.json │ │ ├── ViT-bigG-14.json │ │ ├── ViT-e-14.json │ │ ├── ViT-g-14.json │ │ ├── coca_ViT-B-32.json │ │ ├── coca_ViT-L-14.json │ │ ├── coca_base.json │ │ ├── coca_roberta-ViT-B-32.json │ │ ├── convnext_base.json │ │ ├── convnext_base_w.json │ │ ├── convnext_base_w_320.json │ │ ├── convnext_large.json │ │ ├── convnext_large_d.json │ │ ├── convnext_large_d_320.json │ │ ├── convnext_small.json │ │ ├── convnext_tiny.json │ │ ├── convnext_xlarge.json │ │ ├── convnext_xxlarge.json │ │ ├── convnext_xxlarge_320.json │ │ ├── mt5-base-ViT-B-32.json │ │ ├── mt5-xl-ViT-H-14.json │ │ ├── nllb-clip-base-siglip.json │ │ ├── nllb-clip-base.json │ │ ├── nllb-clip-large-siglip.json │ │ ├── nllb-clip-large.json │ │ ├── roberta-ViT-B-32.json │ │ ├── swin_base_patch4_window7_224.json │ │ ├── vit_medium_patch16_gap_256.json │ │ ├── vit_relpos_medium_patch16_cls_224.json │ │ ├── xlm-roberta-base-ViT-B-32.json │ │ └── xlm-roberta-large-ViT-H-14.json │ ├── modified_resnet.py │ ├── openai.py │ ├── pos_embed.py │ ├── pretrained.py │ ├── push_to_hf_hub.py │ ├── timm_model.py │ ├── tokenizer.py │ ├── transform.py │ ├── transformer.py │ ├── utils.py │ ├── version.py │ ├── 
zero_shot_classifier.py │ └── zero_shot_metadata.py └── training │ ├── .gitignore │ ├── __init__.py │ ├── data.py │ ├── distributed.py │ ├── file_utils.py │ ├── interleaved.py │ ├── logger.py │ ├── main.py │ ├── params.py │ ├── precision.py │ ├── profiler.py │ ├── scheduler.py │ ├── train.py │ └── zero_shot.py └── tests ├── test_download_pretrained.py ├── test_hf_model.py ├── test_inference.py ├── test_inference_simple.py ├── test_num_shards.py ├── test_training_simple.py ├── test_wds.py └── util_test.py /.gitignore: -------------------------------------------------------------------------------- 1 | logs/ 2 | wandb/ 3 | models/ 4 | features/ 5 | results/ 6 | 7 | tests/data/ 8 | *.pt 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | pip-wheel-metadata/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 104 | __pypackages__/ 105 | 106 | # Celery stuff 107 | celerybeat-schedule 108 | celerybeat.pid 109 | 110 | # SageMath parsed files 111 | *.sage.py 112 | 113 | # Environments 114 | .env 115 | .venv 116 | env/ 117 | venv/ 118 | ENV/ 119 | env.bak/ 120 | venv.bak/ 121 | 122 | # Spyder project settings 123 | .spyderproject 124 | .spyproject 125 | 126 | # Rope project settings 127 | .ropeproject 128 | 129 | # mkdocs documentation 130 | /site 131 | 132 | # mypy 133 | .mypy_cache/ 134 | .dmypy.json 135 | dmypy.json 136 | 137 | # Pyre type checker 138 | .pyre/ 139 | sync.sh 140 | gpu1sync.sh 141 | .idea 142 | *.pdf 143 | **/._* 144 | **/*DS_* 145 | **.jsonl 146 | src/sbatch 147 | src/misc 148 | .vscode 149 | src/debug 150 | core.* 151 | 152 | # Allow 153 | !src/evaluation/misc/results_dbs/* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 OpenGVLab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include src/open_clip/bpe_simple_vocab_16e6.txt.gz 2 | include src/open_clip/model_configs/*.json 3 | 4 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | install: ## [Local development] Upgrade pip, install requirements, install package. 2 | python -m pip install -U pip 3 | python -m pip install -e . 
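# Example invocations of the targets in this Makefile (illustrative usage only):
#
#   make install           # editable install of the package
#   make install-training  # add training requirements (torch, webdataset, ...)
#   make install-test      # add test requirements (pytest, ...)
#   make test              # run the unit tests under ./tests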
4 | 
5 | install-training:
6 | 	python -m pip install -r requirements-training.txt
7 | 
8 | install-test: ## [Local development] Install test requirements
9 | 	python -m pip install -r requirements-test.txt
10 | 
11 | test: ## [Local development] Run unit tests
12 | 	python -m pytest -x -s -v tests
13 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Latent Compression Learning (LCL)
2 | 
3 | ![Static Badge](https://img.shields.io/badge/NeurIPS-2024-red)
4 | [![Static Badge](https://img.shields.io/badge/arXiv-2406.07543-green)](https://arxiv.org/abs/2406.07543)
5 | 
6 | **[NeurIPS 2024]** [**Vision Model Pre-training on Interleaved Image-Text Data via Latent Compression Learning**](https://arxiv.org/abs/2406.07543)
7 | 
8 | We introduce Latent Compression Learning (LCL) to pre-train vision models from scratch on interleaved image-text data. Compared to existing methods (e.g., CLIP, auto-regressive text generation), our proposed LCL is the first to achieve both
9 | 
10 | * Learning vision models from scratch
11 | * Training on interleaved image-text data
12 | 
13 | ![overview](./assets/overview.png)
14 | 
15 | ## 📈 Results
16 | 
17 | ### Pre-training on the MMC4 Dataset
18 | 
19 | ![result_interleaved](./assets/result_interleaved.png)
20 | 
21 | Our LCL pre-training significantly outperforms all other methods on caption tasks and is on par with the best paired pre-training methods on classification and retrieval tasks.
22 | 
23 | ### Comparison with OpenCLIP
24 | 
25 | ![result_main_transfer](./assets/result_main_transfer.png)
26 | 
27 | ![result_main_multimodal](./assets/result_main_multimodal.png)
28 | 
29 | When both are trained on LAION-400M data, our LCL pre-training achieves similar performance to OpenCLIP. When MMC4 data is added, our LCL pre-training outperforms OpenCLIP, especially on caption and multi-modal dialogue tasks. For a fair comparison, the total number of images seen during pre-training is 13B.
30 | 
31 | ## 📦 Pre-trained Checkpoints
32 | 
33 | | model | data | # samples | download |
34 | | :---: | :---: | :---: | :---: |
35 | | ViT-B/16 | LAION-400M | 13B | [config](./src/open_clip/model_configs/LCL_ViT-B-16_laion.json) / [ckpt](https://huggingface.co/OpenGVLab/LCL-ViT-B-16-Laion) |
36 | 
37 | ## 🛠️ Usage
38 | 
39 | ### Install
40 | 
41 | This code is built upon [OpenCLIP](https://github.com/mlfoundations/open_clip); you can refer to their repository for setup.
42 | 
43 | ### Load Pre-trained Checkpoints
44 | 
45 | Here is example code for loading a pre-trained checkpoint:
46 | 
47 | ```python
48 | import open_clip
49 | 
50 | model_name = "LCL_ViT-B-16_laion"
51 | pretrained = "path to the `.pt` file"
52 | 
53 | model = open_clip.create_model(model_name, pretrained=pretrained)
54 | ```
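After loading a checkpoint, the vision tower can be used to extract image features. Below is a minimal, illustrative sketch: the checkpoint and image paths are placeholders, and the preprocessing simply uses OpenCLIP's default evaluation transform at this config's 224-pixel input size.

```python
import torch
from PIL import Image
import open_clip

pretrained = "/path/to/LCL-ViT-B-16-Laion.pt"  # placeholder path to the downloaded `.pt` file
model = open_clip.create_model("LCL_ViT-B-16_laion", pretrained=pretrained)
model.eval()

# Default OpenCLIP evaluation transform (224-px input for this config).
preprocess = open_clip.image_transform(224, is_train=False)

image = preprocess(Image.open("example.jpg")).unsqueeze(0)  # placeholder image path
with torch.no_grad():
    image_features = model.encode_image(image)  # pooled, L2-normalized image feature
```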
55 | 
56 | ### Train LCL
57 | 
58 | Example training scripts are provided in [`./scripts`](./scripts). You can refer to [OpenCLIP](https://github.com/mlfoundations/open_clip?tab=readme-ov-file#training-clip) for more ways to launch training.
59 | 
60 | **Training on LAION-400M.** Here is an example training script: [`./scripts/lcl_vit_b_32_laion.sh`](./scripts/lcl_vit_b_32_laion.sh). The corresponding model config is [here](./src/open_clip/model_configs/LCL_ViT-B-32_laion.json).
61 | 
62 | **Training on MMC4.** We provide a simple dataloader that supports the original [MMC4](https://github.com/allenai/mmc4) dataset. Organize the data folder as follows:
63 | 
64 | ```
65 | /path/to/mmc4/
66 | ├── images/
67 | │   └── ...
68 | └── data/
69 |     ├── docs_shard_0_v2.jsonl.zip
70 |     ├── docs_shard_1_v2.jsonl.zip
71 |     └── ...
72 | ```
73 | 
74 | Here is an example training script: [`./scripts/lcl_vit_b_32_mmc4.sh`](./scripts/lcl_vit_b_32_mmc4.sh). The corresponding model config is [here](./src/open_clip/model_configs/LCL_ViT-B-32_mmc4.json).
75 | 
76 | More training scripts can be found under [`./scripts`](./scripts).
77 | 
78 | **NOTE:** We conduct large-scale pre-training with efficient internal code, which will not be released for intellectual-property reasons. This released version has been verified and can reproduce the ViT-B/16 results on the LAION-400M dataset.
79 | 
80 | 
81 | ## 📅 Schedule
82 | 
83 | * [X] basic code of LCL
84 | * [ ] checkpoints of more models and datasets
85 | * [ ] transfer evaluation code
86 | 
87 | ## 🖊️ Citation
88 | 
89 | If you find this work helpful in your research, please consider citing:
90 | 
91 | ```bibtex
92 | @article{yang2024vision,
93 |   title={Vision Model Pre-training on Interleaved Image-Text Data via Latent Compression Learning},
94 |   author={Yang, Chenyu and Zhu, Xizhou and Zhu, Jinguo and Su, Weijie and Wang, Junjie and Dong, Xuan and Wang, Wenhai and Li, Bin and Zhou, Jie and Qiao, Yu and Dai, Jifeng},
95 |   journal={arXiv preprint arXiv:2406.07543},
96 |   year={2024}
97 | }
98 | ```
99 | 
100 | ## 📃 License
101 | 
102 | This project is released under the [MIT license](LICENSE). Parts of this project contain code and models from other sources, which are subject to their respective licenses.
103 | 
104 | ## 🙏 Acknowledgements
105 | 
106 | Our code is built with reference to the code of the following projects: [OpenCLIP](https://github.com/mlfoundations/open_clip).
107 | -------------------------------------------------------------------------------- /assets/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/LCL/b1b8e083fc48580438c6a60981590fdb08cc1a25/assets/overview.png -------------------------------------------------------------------------------- /assets/result_interleaved.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/LCL/b1b8e083fc48580438c6a60981590fdb08cc1a25/assets/result_interleaved.png -------------------------------------------------------------------------------- /assets/result_main_multimodal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/LCL/b1b8e083fc48580438c6a60981590fdb08cc1a25/assets/result_main_multimodal.png -------------------------------------------------------------------------------- /assets/result_main_transfer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/LCL/b1b8e083fc48580438c6a60981590fdb08cc1a25/assets/result_main_transfer.png -------------------------------------------------------------------------------- /requirements-test.txt: -------------------------------------------------------------------------------- 1 | pytest-split==0.8.0 2 | pytest==7.2.0 3 | transformers 4 | timm>=0.9.8 5 | -------------------------------------------------------------------------------- /requirements-training.txt: -------------------------------------------------------------------------------- 1 | torch>=1.9.0 2 | torchvision 3 | webdataset>=0.2.5 4 | regex 5 | ftfy 6 | tqdm 7 | pandas 8 | braceexpand 9 | huggingface_hub 10 | transformers 11 | timm>=0.9.8 12 | fsspec 13 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.9.0 2 | torchvision 3 | regex 4 | ftfy 5 | tqdm 6 | huggingface_hub 7 | sentencepiece 8 | protobuf 9 | timm 10 | -------------------------------------------------------------------------------- /scripts/clipav1_vit_l16_i37_t8.sh: -------------------------------------------------------------------------------- 1 | # eval on a single gpu 2 | CUDA_VISIBLE_DEVICES=2 TORCH_CUDNN_V8_API_ENABLED=1 TFDS_PREFETCH_SIZE=8192 python3 -m training.main \ 3 | --model ViT-L-16-CL32-GAP \ 4 | --pretrained "/path/to/clipa_vit_l16_i37_t8.pt" \ 5 | --seed 0 \ 6 | --imagenet-val '/path/to/ImageNet/val' -------------------------------------------------------------------------------- /scripts/clipav2_vit_h14_i84_224_336_cl32_gap_datacomp1b.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1 python3 -m training.main \ 2 | --model ViT-H-14-CL32-GAP-BigVision \ 3 | --pretrained "/path/to/vit_h14_i84_224_336_cl32_gap_datacomp1b.pt" \ 4 | --force-image-size 336 \ 5 | --square-resize-only \ 6 | --interpolation 'bilinear' \ 7 | --image-mean 0.485 0.456 0.406 \ 8 | --image-std 0.229 0.224 0.225 \ 9 | --seed 0 \ 10 | --imagenet-val '/path/to/ImageNet/val' 11 | -------------------------------------------------------------------------------- /scripts/h14_224_32_finetune.sh: -------------------------------------------------------------------------------- 1 | # 64k batchsize for 2.048e-3 lr 2 | TORCH_CUDNN_V8_API_ENABLED=1 
torchrun --nproc_per_node 8 -m training.main \ 3 | --save-frequency 1 \ 4 | --save-most-recent \ 5 | --zeroshot-frequency 1 \ 6 | --train-data '/path/to/laion' \ 7 | --dataset-type webdataset \ 8 | --lr "2.048e-3" \ 9 | --beta1 0.9 \ 10 | --beta2 0.95 \ 11 | --warmup 782 \ 12 | --wd 0.2 \ 13 | --batch-size 4096 \ 14 | --aug-cfg scale='(0.4, 1.0)' color_jitter='(0.32, 0.32, 0.32, 0.08)' color_jitter_prob=0.8 gray_scale_prob=0.2 \ 15 | --epochs=7 \ 16 | --workers=6 \ 17 | --model ViT-H-14-CL32-GAP \ 18 | --precision 'amp_bf16' \ 19 | --local-loss \ 20 | --gather-with-grad \ 21 | --force-image-size 224 \ 22 | --grad-checkpointing \ 23 | --log-every-n-steps 32 \ 24 | --seed 0 \ 25 | --logs ./logs/ \ 26 | --imagenet-val '/path/to/ImageNet/val' \ 27 | --name 'name' \ 28 | --report-to "wandb" \ 29 | --wandb-project-name "project_name" 30 | 31 | 32 | -------------------------------------------------------------------------------- /scripts/h14_84_8_pretrain.sh: -------------------------------------------------------------------------------- 1 | # 64k batchsize for 2.048e-3 lr 2 | TORCH_CUDNN_V8_API_ENABLED=1 torchrun --nproc_per_node 8 -m training.main \ 3 | --save-frequency 1 \ 4 | --save-most-recent \ 5 | --zeroshot-frequency 1 \ 6 | --train-data '/path/to/laion' \ 7 | --dataset-type webdataset \ 8 | --lr "2.048e-3" \ 9 | --beta1 0.9 \ 10 | --beta2 0.95 \ 11 | --warmup 782 \ 12 | --wd 0.2 \ 13 | --batch-size 4096 \ 14 | --aug-cfg scale='(0.4, 1.0)' color_jitter='(0.32, 0.32, 0.32, 0.08)' color_jitter_prob=0.8 gray_scale_prob=0.2 \ 15 | --epochs=7 \ 16 | --workers=6 \ 17 | --model ViT-H-14-CL8-SyntaxMask-GAP \ 18 | --precision 'amp_bf16' \ 19 | --local-loss \ 20 | --gather-with-grad \ 21 | --force-image-size 84 \ 22 | --grad-checkpointing \ 23 | --log-every-n-steps 32 \ 24 | --seed 0 \ 25 | --logs ./logs/ \ 26 | --imagenet-val '/path/to/ImageNet/val' \ 27 | --name 'name' \ 28 | --report-to "wandb" \ 29 | --wandb-project-name "project_name" 30 | 31 | 32 | -------------------------------------------------------------------------------- /scripts/lcl_vit_b_16_laion.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node $NPROC_PER_NODE -m training.main \ 2 | --save-frequency 1 \ 3 | --zeroshot-frequency 1 \ 4 | --report-to tensorboard \ 5 | --train-data '/path/to/laion' \ 6 | --dataset-type webdataset \ 7 | --imagenet-val '/path/to/ImageNet/val' \ 8 | --warmup 2000 \ 9 | --batch-size $BATCH_SIZE \ 10 | --epochs 32 \ 11 | --workers 10 \ 12 | --model LCL_ViT-B-16_laion \ 13 | --name LCL_ViT-B-16_laion \ 14 | --seed 0 \ 15 | --local-loss \ 16 | --grad-checkpointing \ 17 | --gather-with-grad \ 18 | --use-interleaved-wrapper \ 19 | --interleaved-context-length 280 \ 20 | --num-img-token 196 \ 21 | --img-first-prob 0.5 \ 22 | --lcl-generation-loss-weight 1.0 \ 23 | --lcl-contrastive-loss-weight 0.1 \ -------------------------------------------------------------------------------- /scripts/lcl_vit_b_32_laion.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node $NPROC_PER_NODE -m training.main \ 2 | --save-frequency 1 \ 3 | --zeroshot-frequency 1 \ 4 | --report-to tensorboard \ 5 | --train-data '/path/to/laion' \ 6 | --dataset-type webdataset \ 7 | --imagenet-val '/path/to/ImageNet/val' \ 8 | --warmup 2000 \ 9 | --batch-size $BATCH_SIZE \ 10 | --epochs 32 \ 11 | --workers 10 \ 12 | --model LCL_ViT-B-32_laion \ 13 | --name LCL_ViT-B-32_laion \ 14 | --seed 0 \ 15 | --local-loss \ 16 | 
--gather-with-grad \ 17 | --use-interleaved-wrapper \ 18 | --interleaved-context-length 128 \ 19 | --num-img-token 49 \ 20 | --img-first-prob 0.5 \ 21 | --lcl-generation-loss-weight 1.0 \ 22 | --lcl-contrastive-loss-weight 0.1 \ -------------------------------------------------------------------------------- /scripts/lcl_vit_b_32_mmc4.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node $NPROC_PER_NODE -m training.main \ 2 | --save-frequency 1 \ 3 | --zeroshot-frequency 1 \ 4 | --report-to tensorboard \ 5 | --train-data '/path/to/mmc4' \ 6 | --dataset-type mmc4 \ 7 | --imagenet-val '/path/to/ImageNet/val' \ 8 | --lr 3e-4 \ 9 | --warmup 2000 \ 10 | --batch-size $BATCH_SIZE \ 11 | --epochs 32 \ 12 | --workers 10 \ 13 | --model LCL_ViT-B-32_mmc4 \ 14 | --name LCL_ViT-B-32_mmc4 \ 15 | --seed 0 \ 16 | --grad-checkpointing \ 17 | --local-loss \ 18 | --gather-with-grad \ 19 | --interleaved-context-length 2048 \ 20 | --num-img-token 49 \ 21 | --img-first-prob 0.5 \ 22 | --data-global-distributed \ 23 | --lcl-generation-loss-weight 1.0 \ 24 | --lcl-contrastive-loss-weight 0.1 \ -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ Setup 2 | """ 3 | from setuptools import setup, find_packages 4 | from codecs import open 5 | from os import path 6 | 7 | here = path.abspath(path.dirname(__file__)) 8 | 9 | # Get the long description from the README file 10 | with open(path.join(here, 'README.md'), encoding='utf-8') as f: 11 | long_description = f.read() 12 | 13 | def _read_reqs(relpath): 14 | fullpath = path.join(path.dirname(__file__), relpath) 15 | with open(fullpath) as f: 16 | return [s.strip() for s in f.readlines() if (s.strip() and not s.startswith("#"))] 17 | 18 | REQUIREMENTS = _read_reqs("requirements.txt") 19 | TRAINING_REQUIREMENTS = _read_reqs("requirements-training.txt") 20 | 21 | exec(open('src/open_clip/version.py').read()) 22 | setup( 23 | name='open_clip_torch', 24 | version=__version__, 25 | description='OpenCLIP', 26 | license='MIT', 27 | long_description=long_description, 28 | long_description_content_type='text/markdown', 29 | url='https://github.com/mlfoundations/open_clip', 30 | author='', 31 | author_email='', 32 | classifiers=[ 33 | # How mature is this project? Common values are 34 | # 3 - Alpha 35 | # 4 - Beta 36 | # 5 - Production/Stable 37 | 'Development Status :: 3 - Alpha', 38 | 'Intended Audience :: Education', 39 | 'Intended Audience :: Science/Research', 40 | 'License :: OSI Approved :: Apache Software License', 41 | 'Programming Language :: Python :: 3.7', 42 | 'Programming Language :: Python :: 3.8', 43 | 'Programming Language :: Python :: 3.9', 44 | 'Programming Language :: Python :: 3.10', 45 | 'Topic :: Scientific/Engineering', 46 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 47 | 'Topic :: Software Development', 48 | 'Topic :: Software Development :: Libraries', 49 | 'Topic :: Software Development :: Libraries :: Python Modules', 50 | ], 51 | 52 | # Note that this is a string of words separated by whitespace, not a list. 
53 | keywords='CLIP pretrained', 54 | package_dir={'': 'src'}, 55 | packages=find_packages(where='src'), 56 | include_package_data=True, 57 | install_requires=REQUIREMENTS, 58 | extras_require={ 59 | "training": TRAINING_REQUIREMENTS, 60 | }, 61 | python_requires='>=3.7', 62 | ) 63 | -------------------------------------------------------------------------------- /src/open_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .coca_model import CoCa 2 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 3 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss 4 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint 5 | from .loss import ClipLoss, DistillClipLoss, CoCaLoss 6 | from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \ 7 | convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \ 8 | get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg 9 | from .openai import load_openai_model, list_openai_models 10 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \ 11 | get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained 12 | from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub 13 | from .tokenizer import SimpleTokenizer, tokenize, decode 14 | from .transform import image_transform, AugmentationCfg 15 | from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy 16 | from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES 17 | -------------------------------------------------------------------------------- /src/open_clip/attentive.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class CrossAttention(nn.Module): 8 | def __init__(self, 9 | dim, 10 | num_heads=8, 11 | qkv_bias=False, 12 | qk_scale=None, 13 | attn_head_dim=None, 14 | out_dim=None, 15 | out_bias=True 16 | ): 17 | super().__init__() 18 | if out_dim is None: 19 | out_dim = dim 20 | self.num_heads = num_heads 21 | head_dim = out_dim // num_heads 22 | if attn_head_dim is not None: 23 | head_dim = attn_head_dim 24 | all_head_dim = head_dim * self.num_heads 25 | self.scale = qk_scale or head_dim ** -0.5 26 | assert all_head_dim == out_dim 27 | 28 | self.q = nn.Linear(dim, all_head_dim, bias=False) 29 | self.k = nn.Linear(dim, all_head_dim, bias=False) 30 | self.v = nn.Linear(dim, all_head_dim, bias=False) 31 | 32 | if qkv_bias: 33 | self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) 34 | self.k_bias = nn.Parameter(torch.zeros(all_head_dim)) 35 | self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) 36 | else: 37 | self.q_bias = None 38 | self.k_bias = None 39 | self.v_bias = None 40 | 41 | self.proj = nn.Linear(all_head_dim, out_dim, bias=out_bias) 42 | 43 | def forward(self, x, k=None, v=None): 44 | B, N, C = x.shape 45 | N_k = k.shape[1] 46 | N_v = v.shape[1] 47 | 48 | q_bias, k_bias, v_bias = None, None, None 49 | if self.q_bias is not None: 50 | q_bias = self.q_bias 51 | k_bias = self.k_bias 52 | v_bias = self.v_bias 53 | 54 | q = F.linear(input=x, weight=self.q.weight, bias=q_bias) 55 | q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 
4).squeeze(0) # (B, N_head, N_q, dim) 56 | 57 | k = F.linear(input=k, weight=self.k.weight, bias=k_bias) 58 | k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) 59 | 60 | v = F.linear(input=v, weight=self.v.weight, bias=v_bias) 61 | v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) 62 | 63 | q = q * self.scale 64 | attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k) 65 | 66 | attn = attn.softmax(dim=-1) 67 | 68 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1) 69 | x = self.proj(x) 70 | 71 | return x 72 | 73 | 74 | class AttentiveBlock(nn.Module): 75 | 76 | def __init__(self, 77 | dim, 78 | num_heads, 79 | qkv_bias=False, 80 | qk_scale=None, 81 | norm_layer=nn.LayerNorm, 82 | attn_head_dim=None, 83 | out_dim=None, 84 | out_bias=True 85 | ): 86 | super().__init__() 87 | 88 | self.norm1_q = norm_layer(dim) 89 | self.norm1_k = norm_layer(dim) 90 | self.norm1_v = norm_layer(dim) 91 | self.cross_attn = CrossAttention( 92 | dim, 93 | num_heads=num_heads, 94 | qkv_bias=qkv_bias, 95 | qk_scale=qk_scale, 96 | attn_head_dim=attn_head_dim, 97 | out_dim=out_dim, 98 | out_bias=out_bias 99 | ) 100 | 101 | def forward(self, x_q, x_kv, pos_q, pos_k): 102 | x_q = self.norm1_q(x_q + pos_q) 103 | x_k = self.norm1_k(x_kv + pos_k) 104 | x_v = self.norm1_v(x_kv) 105 | x = self.cross_attn(x_q, k=x_k, v=x_v) 106 | return x 107 | 108 | 109 | class AttentivePoolingProjection(nn.Module): 110 | def __init__(self, 111 | input_dim, 112 | output_dim, 113 | num_query, 114 | num_heads=None, 115 | norm_layer=nn.LayerNorm, 116 | out_bias=False, 117 | ): 118 | super().__init__() 119 | if num_heads is None: 120 | num_heads = int(output_dim // 64) 121 | self.query_token = nn.Parameter(torch.randn(1, num_query, input_dim)) 122 | self.pooler = AttentiveBlock( 123 | dim=input_dim, 124 | out_dim=output_dim, 125 | num_heads=num_heads, 126 | qkv_bias=True, 127 | qk_scale=None, 128 | norm_layer=norm_layer, 129 | out_bias=out_bias, 130 | ) 131 | 132 | def forward_pool(self, x): 133 | query_tokens = self.query_token.expand(x.shape[0], -1, -1) 134 | query_tokens = self.pooler(query_tokens, x, 0, 0) 135 | return query_tokens.squeeze(1) 136 | 137 | def forward_project(self, x): 138 | x = self.pooler.norm1_v(x) 139 | x = F.linear(input=x, weight=self.pooler.cross_attn.v.weight, bias=self.pooler.cross_attn.v_bias) 140 | x = self.pooler.cross_attn.proj(x) 141 | return x 142 | 143 | def forward(self, x): 144 | pooled = self.forward_pool(x) 145 | projected = self.forward_project(x) 146 | return pooled, projected 147 | -------------------------------------------------------------------------------- /src/open_clip/big_vision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from .model import CustomTextCLIP 5 | from .transformer import TextTransformer, Transformer 6 | 7 | 8 | @torch.no_grad() 9 | def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str): 10 | """ Load weights from .npz checkpoints for official Google big_vision image-text models 11 | 12 | Currently the SigLIP source models are supported and a CustomTextCLIP destination model 13 | w/ timm image encoder. 
14 | """ 15 | from timm.layers import resample_patch_embed, resample_abs_pos_embed 16 | 17 | def _n2p(w, t=True): 18 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 19 | w = w.flatten() 20 | if t: 21 | if w.ndim == 4: 22 | w = w.transpose([3, 2, 0, 1]) 23 | elif w.ndim == 3: 24 | w = w.transpose([2, 0, 1]) 25 | elif w.ndim == 2: 26 | w = w.transpose([1, 0]) 27 | return torch.from_numpy(w) 28 | 29 | w = np.load(checkpoint_path) 30 | interpolation = 'bilinear' 31 | antialias = False 32 | 33 | def _convert_timm_img(module, prefix): 34 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 35 | if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]: 36 | embed_conv_w = resample_patch_embed( 37 | embed_conv_w, 38 | module.patch_embed.proj.weight.shape[-2:], 39 | interpolation=interpolation, 40 | antialias=antialias, 41 | verbose=True, 42 | ) 43 | module.patch_embed.proj.weight.copy_(embed_conv_w) 44 | module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 45 | 46 | if module.cls_token is not None: 47 | module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 48 | 49 | pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False) 50 | if pos_embed_w.shape != module.pos_embed.shape: 51 | assert False, f'{pos_embed_w.shape}, {module.pos_embed.shape}' 52 | num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1) 53 | pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights 54 | pos_embed_w, 55 | new_size=module.patch_embed.grid_size, 56 | num_prefix_tokens=num_prefix_tokens, 57 | interpolation=interpolation, 58 | antialias=antialias, 59 | verbose=True, 60 | ) 61 | module.pos_embed.copy_(pos_embed_w) 62 | 63 | mha_sub, b_sub, ln1_sub = (0, 0, 1) 64 | for i, block in enumerate(module.blocks.children()): 65 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 66 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/' 67 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 68 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 69 | block.attn.qkv.weight.copy_(torch.cat([ 70 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 71 | block.attn.qkv.bias.copy_(torch.cat([ 72 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 73 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 74 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 75 | for r in range(2): 76 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'])) 77 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'])) 78 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'])) 79 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'])) 80 | 81 | module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 82 | module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 83 | 84 | if module.attn_pool is not None: 85 | block_prefix = f'{prefix}MAPHead_0/' 86 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/' 87 | module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False)) 88 | module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T) 89 | module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], 
t=False).reshape(-1)) 90 | module.attn_pool.kv.weight.copy_(torch.cat([ 91 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')])) 92 | module.attn_pool.kv.bias.copy_(torch.cat([ 93 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')])) 94 | module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 95 | module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 96 | module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 97 | module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 98 | for r in range(2): 99 | getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel'])) 100 | getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias'])) 101 | 102 | def _convert_openclip_transformer(module: Transformer, prefix): 103 | for i, block in enumerate(module.resblocks.children()): 104 | block_prefix = f'{prefix}encoderblock_{i}/' 105 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/' 106 | block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 107 | block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 108 | block.attn.in_proj_weight.copy_(torch.cat([ 109 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 110 | block.attn.in_proj_bias.copy_(torch.cat([ 111 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 112 | block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 113 | block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 114 | block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale'])) 115 | block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias'])) 116 | block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel'])) 117 | block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias'])) 118 | block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel'])) 119 | block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias'])) 120 | 121 | def _convert_openclip_txt(module: TextTransformer, prefix): 122 | module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False)) 123 | pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0) 124 | module.positional_embedding.copy_(pos_embed_w) 125 | _convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/') 126 | module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale'])) 127 | module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias'])) 128 | module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 129 | module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 130 | 131 | _convert_timm_img(model.visual.trunk, 'params/img/') 132 | _convert_openclip_txt(model.text, 'params/txt/') 133 | model.logit_bias.copy_(_n2p(w['params/b'])[0]) 134 | model.logit_scale.copy_(_n2p(w['params/t'])[0]) 135 | 136 | 137 | -------------------------------------------------------------------------------- /src/open_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/LCL/b1b8e083fc48580438c6a60981590fdb08cc1a25/src/open_clip/bpe_simple_vocab_16e6.txt.gz 
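A usage sketch for `load_big_vision_weights` (defined in big_vision.py above). This is illustrative only: the `.npz` path is a placeholder, and the model name is one of the SigLIP configs bundled under model_configs.

import open_clip
from open_clip.big_vision import load_big_vision_weights

# Build a SigLIP-style CustomTextCLIP (custom_text + timm image tower), then copy
# weights from an official Google big_vision .npz checkpoint into it.
model = open_clip.create_model('ViT-B-16-SigLIP', pretrained=None)
load_big_vision_weights(model, '/path/to/siglip_vit_b16_224.npz')  # placeholder .npz path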
-------------------------------------------------------------------------------- /src/open_clip/constants.py: -------------------------------------------------------------------------------- 1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) 3 | IMAGENET_MEAN = (0.485, 0.456, 0.406) 4 | IMAGENET_STD = (0.229, 0.224, 0.225) 5 | INCEPTION_MEAN = (0.5, 0.5, 0.5) 6 | INCEPTION_STD = (0.5, 0.5, 0.5) 7 | -------------------------------------------------------------------------------- /src/open_clip/hf_configs.py: -------------------------------------------------------------------------------- 1 | # HF architecture dict: 2 | arch_dict = { 3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta 4 | "roberta": { 5 | "config_names": { 6 | "context_length": "max_position_embeddings", 7 | "vocab_size": "vocab_size", 8 | "width": "hidden_size", 9 | "heads": "num_attention_heads", 10 | "layers": "num_hidden_layers", 11 | "layer_attr": "layer", 12 | "token_embeddings_attr": "embeddings" 13 | }, 14 | "pooler": "mean_pooler", 15 | }, 16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig 17 | "xlm-roberta": { 18 | "config_names": { 19 | "context_length": "max_position_embeddings", 20 | "vocab_size": "vocab_size", 21 | "width": "hidden_size", 22 | "heads": "num_attention_heads", 23 | "layers": "num_hidden_layers", 24 | "layer_attr": "layer", 25 | "token_embeddings_attr": "embeddings" 26 | }, 27 | "pooler": "mean_pooler", 28 | }, 29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 30 | "mt5": { 31 | "config_names": { 32 | # unlimited seqlen 33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 35 | "context_length": "", 36 | "vocab_size": "vocab_size", 37 | "width": "d_model", 38 | "heads": "num_heads", 39 | "layers": "num_layers", 40 | "layer_attr": "block", 41 | "token_embeddings_attr": "embed_tokens" 42 | }, 43 | "pooler": "mean_pooler", 44 | }, 45 | # https://huggingface.co/docs/transformers/model_doc/bert 46 | "bert": { 47 | "config_names": { 48 | "context_length": "max_position_embeddings", 49 | "vocab_size": "vocab_size", 50 | "width": "hidden_size", 51 | "heads": "num_attention_heads", 52 | "layers": "num_hidden_layers", 53 | }, 54 | "pooler": "cls_pooler", 55 | }, 56 | # https://huggingface.co/docs/transformers/model_doc/m2m_100 57 | "m2m_100": { 58 | "config_names": { 59 | "context_length": "max_position_embeddings", 60 | "vocab_size": "vocab_size", 61 | "width": "d_model", 62 | "heads": "encoder_attention_heads", 63 | "layers": "encoder_layers", 64 | }, 65 | "pooler": "cls_pooler", 66 | }, 67 | } 68 | -------------------------------------------------------------------------------- /src/open_clip/hf_model.py: -------------------------------------------------------------------------------- 1 | """ huggingface model adapter 2 | 3 | Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. 
4 | """ 5 | import re 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch import TensorType 10 | 11 | try: 12 | import transformers 13 | from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig 14 | from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \ 15 | BaseModelOutputWithPoolingAndCrossAttentions 16 | except ImportError as e: 17 | transformers = None 18 | 19 | 20 | class BaseModelOutput: 21 | pass 22 | 23 | 24 | class PretrainedConfig: 25 | pass 26 | 27 | from .hf_configs import arch_dict 28 | 29 | 30 | # utils 31 | def _camel2snake(s): 32 | return re.sub(r'(?'] 85 | self.eot_id = special_token_ids[''] 86 | self.img_id = special_token_ids[''] 87 | self.soi_id = special_token_ids[''] 88 | self.eoi_id = special_token_ids[''] 89 | 90 | self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale) 91 | if init_logit_bias is not None: 92 | self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias) 93 | else: 94 | self.logit_bias = None 95 | 96 | @torch.jit.ignore 97 | def set_grad_checkpointing(self, enable=True): 98 | self.visual.set_grad_checkpointing(enable) 99 | self.text.set_grad_checkpointing(enable) 100 | 101 | def _encode_image(self, x: torch.Tensor): 102 | x = self.visual.conv1(x) # shape = [*, width, grid, grid] 103 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 104 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 105 | 106 | # class embeddings and positional embeddings 107 | x = torch.cat([_expand_token(self.visual.class_embedding, x.shape[0]).to(x.dtype), x], dim=1) 108 | # shape = [*, grid ** 2 + 1, width] 109 | x = x + self.visual.positional_embedding.to(x.dtype) 110 | 111 | x = self.visual.patch_dropout(x) 112 | x = self.visual.ln_pre(x) 113 | 114 | x = x.permute(1, 0, 2) # NLD -> LND 115 | x = self.visual.transformer(x) 116 | x = x.permute(1, 0, 2) # LND -> NLD 117 | 118 | x = self.visual.ln_post(x) 119 | x = x[:, 1:] # NOTE: only take patch tokens 120 | pooled, tokens = self.pool_project(x) # attn pool & project 121 | 122 | return pooled, tokens 123 | 124 | def _forward_text(self, x: torch.Tensor): 125 | cast_dtype = self.text.transformer.get_cast_dtype() 126 | x = x + self.text.positional_embedding.to(cast_dtype)[:x.shape[1]] 127 | # deal with variable seq length 128 | if x.shape[1] != self.text.attn_mask.shape[0]: 129 | attn_mask = _build_causal_mask(x.shape[1]).to(x.device) 130 | else: 131 | attn_mask = self.text.attn_mask 132 | x = x.permute(1, 0, 2) # NLD -> LND 133 | x = self.text.transformer(x, attn_mask=attn_mask) 134 | x = x.permute(1, 0, 2) # LND -> NLD 135 | return x 136 | 137 | def encode_image(self, images, normalize: bool = True): 138 | image_latent, _ = self._encode_image(images) 139 | image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent 140 | return image_latent 141 | 142 | def get_interleaved_embs(self, image_embs, text_ids): 143 | cast_dtype = self.text.transformer.get_cast_dtype() 144 | 145 | # image token mask 146 | img_mask = (text_ids == self.img_id) 147 | # text token embs 148 | text_embs = self.text.token_embedding(text_ids[~img_mask]).to(cast_dtype) 149 | # merge to interleaved embs 150 | input_embs = text_embs.new_zeros((*text_ids.shape, text_embs.shape[-1])) 151 | input_embs[~img_mask] = text_embs 152 | input_embs[img_mask] = image_embs.flatten(0, 1).to(cast_dtype) 153 | 154 | return input_embs 155 | 156 | def encode_text(self, text, normalize: bool = True, context_length=77): 157 | # only for getting text 
feature for retrieval 158 | eot_mask = (text == self.eot_id) 159 | # NOTE: replace with since we use text features at for contrastive learning 160 | text[eot_mask] = self.soi_id 161 | 162 | # truncation 163 | text = text[:, :context_length] 164 | num_soi = torch.sum((text == self.soi_id), dim=-1) 165 | text[:, -1][num_soi == 0] = self.soi_id # add if truncated 166 | 167 | cast_dtype = self.text.transformer.get_cast_dtype() 168 | x = self.text.token_embedding(text).to(cast_dtype) 169 | x = self._forward_text(x) 170 | 171 | x = x[text == self.soi_id] 172 | x = self.text.ln_final(x) 173 | if self.text.text_projection is not None: 174 | if isinstance(self.text.text_projection, nn.Linear): 175 | x = self.text.text_projection(x) 176 | else: 177 | x = x @ self.text.text_projection 178 | 179 | return F.normalize(x, dim=-1) if normalize else x 180 | 181 | def get_contrastive_features(self, image_features, text_outputs, text_ids): 182 | ind_matrix = torch.arange(text_ids.shape[1], device=text_ids.device)[None].repeat((text_ids.shape[0], 1)) 183 | # NOTE: do not take at the beginning 184 | soi_mask = (text_ids == self.soi_id) & (ind_matrix > 0) 185 | soi_batch_id, soi_seq_id = torch.nonzero(soi_mask, as_tuple=True) 186 | text_features = text_outputs[soi_batch_id, soi_seq_id] 187 | text_features = self.text.ln_final(text_features) 188 | if self.text.text_projection is not None: 189 | if isinstance(self.text.text_projection, nn.Linear): 190 | text_features = self.text.text_projection(text_features) 191 | else: 192 | text_features = text_features @ self.text.text_projection 193 | text_features = F.normalize(text_features.float(), dim=-1).to(dtype=text_features.dtype) 194 | 195 | # NOTE: do not take at the beginning 196 | soi_mask = (text_ids == self.soi_id) 197 | ignore_mask = (ind_matrix[soi_mask] == 0) 198 | image_features = image_features[~ignore_mask] 199 | image_features = F.normalize(image_features.float(), dim=-1).to(dtype=image_features.dtype) 200 | 201 | assert len(text_features) == len(image_features) 202 | return image_features, text_features 203 | 204 | def get_generation_logits_labels(self, text_outputs, text_ids): 205 | gen_mask = \ 206 | (text_ids[:, 1:] != self.img_id) & \ 207 | (text_ids[:, 1:] != self.eoi_id) & \ 208 | (text_ids[:, 1:] != self.pad_id) 209 | 210 | logits = self.lm_head(text_outputs[:, :-1][gen_mask]) 211 | labels = text_ids[:, 1:][gen_mask] 212 | 213 | return logits, labels 214 | 215 | def forward( 216 | self, 217 | image: Optional[torch.Tensor] = None, 218 | text: Optional[torch.Tensor] = None, 219 | ): 220 | image_latent, image_embs = self._encode_image(image) 221 | if text is None: 222 | image_latent = F.normalize(image_latent, dim=-1) 223 | return {"image_features": image_latent, "image_embs": image_embs} 224 | 225 | interleaved_embs = self.get_interleaved_embs(image_embs, text) 226 | text_outputs = self._forward_text(interleaved_embs) 227 | 228 | image_features, text_features = self.get_contrastive_features(image_latent, text_outputs, text) 229 | logits, labels = self.get_generation_logits_labels(text_outputs, text) 230 | 231 | out_dict = { 232 | "image_features": image_features, 233 | "text_features": text_features, 234 | "logits": logits, 235 | "labels": labels, 236 | "logit_scale": self.logit_scale.exp() 237 | } 238 | return out_dict -------------------------------------------------------------------------------- /src/open_clip/model_configs/EVA01-g-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva_giant_patch14_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/EVA01-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva_giant_patch14_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 768, 14 | "heads": 12, 15 | "layers": 12 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/EVA02-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva02_base_patch16_clip_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/EVA02-E-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva02_enormous_patch14_clip_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1280, 14 | "heads": 20, 15 | "layers": 32 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/EVA02-E-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva02_enormous_patch14_clip_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/EVA02-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "timm_model_name": "eva02_large_patch14_clip_336", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 768, 14 | "heads": 12, 15 | "layers": 12 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/EVA02-L-14.json: -------------------------------------------------------------------------------- 1 | { 
2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva02_large_patch14_clip_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 768, 14 | "heads": 12, 15 | "layers": 12 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/LCL_ViT-B-16_laion.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 280, 11 | "vocab_size": 49411, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12, 15 | "tokenizer_kwargs": { 16 | "additional_special_tokens": [ 17 | "", 18 | "", 19 | "" 20 | ] 21 | } 22 | }, 23 | "pool_project_cfg": { 24 | "pool_proj_type": "attn", 25 | "input_dim": 768, 26 | "output_dim": 768, 27 | "attn_num_heads": 12 28 | }, 29 | "special_token_ids": { 30 | "": 49406, 31 | "": 49407, 32 | "": 49408, 33 | "": 49409, 34 | "": 49410 35 | }, 36 | "custom_text": true 37 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/LCL_ViT-B-32_laion.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 128, 11 | "vocab_size": 49411, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12, 15 | "tokenizer_kwargs": { 16 | "additional_special_tokens": [ 17 | "", 18 | "", 19 | "" 20 | ] 21 | } 22 | }, 23 | "pool_project_cfg": { 24 | "pool_proj_type": "attn", 25 | "input_dim": 768, 26 | "output_dim": 512, 27 | "attn_num_heads": 8 28 | }, 29 | "special_token_ids": { 30 | "": 49406, 31 | "": 49407, 32 | "": 49408, 33 | "": 49409, 34 | "": 49410 35 | }, 36 | "custom_text": true 37 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/LCL_ViT-B-32_mmc4.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 2048, 11 | "vocab_size": 49411, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12, 15 | "tokenizer_kwargs": { 16 | "additional_special_tokens": [ 17 | "", 18 | "", 19 | "" 20 | ] 21 | } 22 | }, 23 | "pool_project_cfg": { 24 | "pool_proj_type": "attn", 25 | "input_dim": 768, 26 | "output_dim": 512, 27 | "attn_num_heads": 8 28 | }, 29 | "special_token_ids": { 30 | "": 49406, 31 | "": 49407, 32 | "": 49408, 33 | "": 49409, 34 | "": 49410 35 | }, 36 | "custom_text": true 37 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/RN101-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 23, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } 
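An illustrative cross-check of the LCL_ViT-B-* configs above against the training scripts in ./scripts: the number of placeholder tokens reserved per image in the interleaved sequence follows the ViT patch grid, which is what the `--num-img-token` flag encodes.

image_size, patch_size = 224, 16                 # LCL_ViT-B-16_laion vision_cfg
num_img_tokens = (image_size // patch_size) ** 2
assert num_img_tokens == 196                     # matches --num-img-token 196 in scripts/lcl_vit_b_16_laion.sh
# For the ViT-B/32 configs, (224 // 32) ** 2 == 49, matching --num-img-token 49.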
-------------------------------------------------------------------------------- /src/open_clip/model_configs/RN101.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 23, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/RN50-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 6, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /src/open_clip/model_configs/RN50.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 6, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/RN50x16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 384, 5 | "layers": [ 6 | 6, 7 | 8, 8 | 18, 9 | 8 10 | ], 11 | "width": 96, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 768, 18 | "heads": 12, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/RN50x4.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 288, 5 | "layers": [ 6 | 4, 7 | 6, 8 | 10, 9 | 6 10 | ], 11 | "width": 80, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 640, 18 | "heads": 10, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/RN50x64.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": [ 6 | 3, 7 | 15, 8 | 36, 9 | 10 10 | ], 11 | "width": 128, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 1024, 18 | "heads": 16, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-B-16-SigLIP-256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 256, 7 | "timm_model_name": "vit_base_patch16_siglip_256", 8 | 
"timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 64, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 768, 20 | "heads": 12, 21 | "layers": 12, 22 | "no_causal_mask": true, 23 | "proj_bias": true, 24 | "pool_type": "last", 25 | "norm_kwargs":{ 26 | "eps": 1e-6 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-B-16-SigLIP-384.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 384, 7 | "timm_model_name": "vit_base_patch16_siglip_384", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 64, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 768, 20 | "heads": 12, 21 | "layers": 12, 22 | "no_causal_mask": true, 23 | "proj_bias": true, 24 | "pool_type": "last", 25 | "norm_kwargs":{ 26 | "eps": 1e-6 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-B-16-SigLIP-512.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 512, 7 | "timm_model_name": "vit_base_patch16_siglip_512", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 64, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 768, 20 | "heads": 12, 21 | "layers": 12, 22 | "no_causal_mask": true, 23 | "proj_bias": true, 24 | "pool_type": "last", 25 | "norm_kwargs":{ 26 | "eps": 1e-6 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-B-16-SigLIP-i18n-256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 256, 7 | "timm_model_name": "vit_base_patch16_siglip_256", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 64, 14 | "vocab_size": 250000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP-i18n-256", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 768, 20 | "heads": 12, 21 | "layers": 12, 22 | "no_causal_mask": true, 23 | "proj_bias": true, 24 | "pool_type": "last", 25 | "norm_kwargs":{ 26 | "eps": 1e-6 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-B-16-SigLIP.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 224, 7 | "timm_model_name": "vit_base_patch16_siglip_224", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | 
"timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 64, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 768, 20 | "heads": 12, 21 | "layers": 12, 22 | "no_causal_mask": true, 23 | "proj_bias": true, 24 | "pool_type": "last", 25 | "norm_kwargs":{ 26 | "eps": 1e-6 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-B-16-plus-240.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 240, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-B-16-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-B-16-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 16 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-B-32-256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 256, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-B-32-plus-256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 256, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-B-32-quickgelu.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-H-14-378-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 378, 6 | "layers": 32, 7 | "width": 1280, 8 | "head_width": 80, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1024, 15 | "heads": 16, 16 | "layers": 24 17 | } 18 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-H-14-CLIPA-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14, 9 | "no_ln_pre": true, 10 | "pool_type": "avg", 11 | "final_ln_after_pool": true 12 | }, 13 | "text_cfg": { 14 | "context_length": 32, 15 | "vocab_size": 32000, 16 | "hf_tokenizer_name": "bert-base-uncased", 17 | "tokenizer_kwargs": { 18 | "strip_sep_token": true 19 | }, 20 | "width": 1024, 21 | "heads": 16, 22 | "layers": 24, 23 | "pool_type": "last", 24 | "no_causal_mask": true 25 | } 26 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-H-14-CLIPA.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14, 9 | "no_ln_pre": true, 10 | "pool_type": "avg", 11 | "final_ln_after_pool": true 12 | }, 13 | "text_cfg": { 14 | "context_length": 32, 15 | "vocab_size": 32000, 16 | "hf_tokenizer_name": "bert-base-uncased", 17 | "tokenizer_kwargs": { 18 | "strip_sep_token": true 19 | }, 20 | "width": 1024, 21 | "heads": 16, 22 | "layers": 24, 23 | "pool_type": "last", 24 | "no_causal_mask": true 25 | } 26 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-H-14-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 32, 7 | "width": 1280, 8 | "head_width": 80, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1024, 15 | "heads": 16, 16 | "layers": 24 17 | } 18 | } 
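The *-quickgelu configs set "quick_gelu": true so that checkpoints trained with OpenAI's original activation are evaluated with the same activation. A small illustrative sketch of the difference follows; the 1.702 constant is the sigmoid-based GELU approximation used by the original CLIP code, while the plain configs use the exact GELU.

import torch
from torch import nn

class QuickGELU(nn.Module):
    # Sigmoid-based GELU approximation from the original OpenAI CLIP code,
    # selected when a config sets "quick_gelu": true.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(1.702 * x)

x = torch.linspace(-3.0, 3.0, 7)
print(QuickGELU()(x))   # approximation used by the *-quickgelu configs
print(nn.GELU()(x))     # exact GELU used by the plain configs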
-------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-H-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 16 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-L-14-280.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 280, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-L-14-CLIPA-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14, 8 | "no_ln_pre": true, 9 | "pool_type": "avg", 10 | "final_ln_after_pool": true 11 | }, 12 | "text_cfg": { 13 | "context_length": 32, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "bert-base-uncased", 16 | "tokenizer_kwargs": { 17 | "strip_sep_token": true 18 | }, 19 | "width": 768, 20 | "heads": 12, 21 | "layers": 12, 22 | "pool_type": "last", 23 | "no_causal_mask": true 24 | } 25 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-L-14-CLIPA.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14, 8 | "no_ln_pre": true, 9 | "pool_type": "avg", 10 | "final_ln_after_pool": true 11 | }, 12 | "text_cfg": { 13 | "context_length": 32, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "bert-base-uncased", 16 | "tokenizer_kwargs": { 17 | "strip_sep_token": true 18 | }, 19 | "width": 768, 20 | "heads": 12, 21 | "layers": 12, 22 | "pool_type": "last", 23 | "no_causal_mask": true 24 | } 25 | } -------------------------------------------------------------------------------- 
/src/open_clip/model_configs/ViT-L-14-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 24, 7 | "width": 1024, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 768, 14 | "heads": 12, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-L-16-320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 320, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-L-16-SigLIP-256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 256, 7 | "timm_model_name": "vit_large_patch16_siglip_256", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 64, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 1024, 20 | "heads": 16, 21 | "layers": 24, 22 | "no_causal_mask": true, 23 | "proj_bias": true, 24 | "pool_type": "last", 25 | "norm_kwargs":{ 26 | "eps": 1e-6 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-L-16-SigLIP-384.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 384, 7 | "timm_model_name": "vit_large_patch16_siglip_384", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 64, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 1024, 20 | "heads": 16, 21 | "layers": 24, 22 | "no_causal_mask": true, 23 | "proj_bias": true, 24 | "pool_type": "last", 25 | "norm_kwargs":{ 26 | "eps": 1e-6 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-L-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 
49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-M-16-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 16, 8 | "ls_init_value": 1e-4 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 384, 14 | "heads": 6, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-M-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-M-32-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 384, 13 | "heads": 6, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-M-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-S-16-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 256, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 256, 13 | "heads": 4, 14 | "layers": 10 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-S-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 384, 13 | "heads": 6, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-S-32-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 256, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 256, 13 | "heads": 4, 14 | "layers": 10 15 | } 16 | } 
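The head counts in these configs follow a simple pattern: text towers list "heads" explicitly, while vision towers derive the head count from "width" and "head_width" (which, in upstream open_clip, defaults to 64 when absent). A short arithmetic sketch, illustrative only and computed from the configs above:

# Vision towers: heads = width // head_width (head_width defaults to 64).
def vision_heads(width: int, head_width: int = 64) -> int:
    return width // head_width

print(vision_heads(768))        # 12 -> ViT-B towers
print(vision_heads(1280, 80))   # 16 -> ViT-H towers, which set "head_width": 80

# Text towers: per-head dimension = width // heads (64 throughout these configs,
# including the *-alt variants).
print(512 // 8, 384 // 6, 256 // 4)   # 64 64 64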
-------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-S-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 384, 13 | "heads": 6, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-SO400M-14-SigLIP-384.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1152, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 384, 7 | "timm_model_name": "vit_so400m_patch14_siglip_384", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 64, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 1152, 20 | "heads": 16, 21 | "layers": 27, 22 | "mlp_ratio": 3.7362, 23 | "no_causal_mask": true, 24 | "proj_bias": true, 25 | "pool_type": "last", 26 | "norm_kwargs":{ 27 | "eps": 1e-6 28 | } 29 | } 30 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-SO400M-14-SigLIP.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1152, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 224, 7 | "timm_model_name": "vit_so400m_patch14_siglip_224", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 16, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 1152, 20 | "heads": 16, 21 | "layers": 27, 22 | "mlp_ratio": 3.7362, 23 | "no_causal_mask": true, 24 | "proj_bias": true, 25 | "pool_type": "last", 26 | "norm_kwargs":{ 27 | "eps": 1e-6 28 | } 29 | } 30 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-bigG-14-CLIPA-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 48, 6 | "width": 1664, 7 | "head_width": 104, 8 | "mlp_ratio": 4.9231, 9 | "patch_size": 14, 10 | "no_ln_pre": true, 11 | "pool_type": "avg", 12 | "final_ln_after_pool": true 13 | }, 14 | "text_cfg": { 15 | "context_length": 32, 16 | "vocab_size": 32000, 17 | "hf_tokenizer_name": "bert-base-uncased", 18 | "tokenizer_kwargs": { 19 | "strip_sep_token": true 20 | }, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "pool_type": "last", 25 | "no_causal_mask": true 26 | } 27 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-bigG-14-CLIPA.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 48, 6 | "width": 1664, 7 | "head_width": 104, 8 | "mlp_ratio": 4.9231, 9 | "patch_size": 14, 10 | "no_ln_pre": true, 11 | "pool_type": "avg", 12 | "final_ln_after_pool": true 
13 | }, 14 | "text_cfg": { 15 | "context_length": 32, 16 | "vocab_size": 32000, 17 | "hf_tokenizer_name": "bert-base-uncased", 18 | "tokenizer_kwargs": { 19 | "strip_sep_token": true 20 | }, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "pool_type": "last", 25 | "no_causal_mask": true 26 | } 27 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-bigG-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 48, 6 | "width": 1664, 7 | "head_width": 104, 8 | "mlp_ratio": 4.9231, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1280, 15 | "heads": 20, 16 | "layers": 32 17 | } 18 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-e-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 56, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.5715, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1280, 15 | "heads": 20, 16 | "layers": 36 17 | } 18 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1024, 15 | "heads": 16, 16 | "layers": 24 17 | } 18 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/coca_ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32, 8 | "attentional_pool": true, 9 | "attn_pooler_heads": 8, 10 | "output_tokens": true 11 | }, 12 | "text_cfg": { 13 | "context_length": 76, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12, 18 | "embed_cls": true, 19 | "output_tokens": true 20 | }, 21 | "multimodal_cfg": { 22 | "context_length": 76, 23 | "vocab_size": 49408, 24 | "width": 512, 25 | "heads": 8, 26 | "layers": 12, 27 | "attn_pooler_heads": 8 28 | }, 29 | "custom_text": true 30 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/coca_ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14, 8 | "attentional_pool": true, 9 | "attn_pooler_heads": 8, 10 | "output_tokens": true 11 | }, 12 | "text_cfg": { 13 | "context_length": 76, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 12, 18 | "embed_cls": true, 19 | "output_tokens": true 20 | }, 21 | "multimodal_cfg": { 22 | "context_length": 76, 23 | "vocab_size": 49408, 24 | "width": 768, 25 | "heads": 12, 26 | "layers": 12, 27 | "attn_pooler_heads": 12 28 | }, 29 | "custom_text": 
true 30 | } 31 | -------------------------------------------------------------------------------- /src/open_clip/model_configs/coca_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "multimodal_cfg": { 4 | "width": 768, 5 | "context_length": 76, 6 | "vocab_size": 64000, 7 | "mlp_ratio": 4, 8 | "layers": 12, 9 | "dim_head": 64, 10 | "heads": 12, 11 | "n_queries": 256, 12 | "attn_pooler_heads": 8 13 | }, 14 | "vision_cfg": { 15 | "image_size": 288, 16 | "layers": 12, 17 | "width": 768, 18 | "patch_size": 18, 19 | "output_tokens": true 20 | }, 21 | "text_cfg": { 22 | "context_length": 76, 23 | "vocab_size": 64000, 24 | "layers": 12, 25 | "heads": 12, 26 | "width": 768, 27 | "embed_cls": true, 28 | "output_tokens": true 29 | }, 30 | "custom_text": true 31 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/coca_roberta-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32, 8 | "output_tokens": true 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "roberta-base", 12 | "hf_tokenizer_name": "roberta-base", 13 | "hf_proj_type": "linear", 14 | "width": 768, 15 | "output_tokens": true 16 | }, 17 | "multimodal_cfg": { 18 | "context_length": 76, 19 | "width": 768, 20 | "heads": 8, 21 | "layers": 12 22 | }, 23 | "custom_text": true 24 | } 25 | -------------------------------------------------------------------------------- /src/open_clip/model_configs/convnext_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_base", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/convnext_base_w.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_base", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 640, 16 | "heads": 10, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/convnext_base_w_320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_base", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 320 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 640, 16 | "heads": 10, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/convnext_large.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_large", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/convnext_large_d.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_large", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "mlp", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 16 18 | } 19 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/convnext_large_d_320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_large", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "mlp", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 320 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 16 18 | } 19 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/convnext_small.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_small", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/convnext_tiny.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_tiny", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/convnext_xlarge.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_xlarge", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 1024, 16 | "heads": 16, 17 | "layers": 20 18 | } 19 | } 
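Configs that specify "timm_model_name" (the convnext_* and vit_*_siglip entries) delegate the image tower to a timm backbone via the TimmModel wrapper in timm_model.py. The sketch below is a loose approximation of what such a config builds, not the exact wrapper; it assumes timm is installed and that the backbone name exists in your timm version.

import timm
import torch

# Roughly: a headless timm backbone followed by the projection named in
# "timm_proj" (here "linear"), mapping pooled features to embed_dim.
backbone = timm.create_model("convnext_xlarge", pretrained=False, num_classes=0)
proj = torch.nn.Linear(backbone.num_features, 1024)   # embed_dim = 1024 above

x = torch.randn(1, 3, 256, 256)                       # "image_size": 256
print(proj(backbone(x)).shape)                        # torch.Size([1, 1024])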
-------------------------------------------------------------------------------- /src/open_clip/model_configs/convnext_xxlarge.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_xxlarge", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 1024, 16 | "heads": 16, 17 | "layers": 24 18 | } 19 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/convnext_xxlarge_320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_xxlarge", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 320 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 1024, 16 | "heads": 16, 17 | "layers": 24 18 | } 19 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/mt5-base-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "hf_model_name": "google/mt5-base", 11 | "hf_tokenizer_name": "google/mt5-base", 12 | "hf_pooler_type": "mean_pooler" 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /src/open_clip/model_configs/mt5-xl-ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "google/mt5-xl", 12 | "hf_tokenizer_name": "google/mt5-xl", 13 | "hf_pooler_type": "mean_pooler" 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /src/open_clip/model_configs/nllb-clip-base-siglip.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "custom_text": true, 4 | "init_logit_bias": -10, 5 | "vision_cfg": { 6 | "image_size": 384, 7 | "timm_model_name": "vit_base_patch16_siglip_384", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "hf_model_name": "facebook/nllb-200-distilled-600M", 14 | "hf_tokenizer_name": "facebook/nllb-200-distilled-600M", 15 | "hf_proj_type": "linear", 16 | "hf_pooler_type": "cls_pooler" 17 | } 18 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/nllb-clip-base.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "hf_model_name": "facebook/nllb-200-distilled-600M", 11 | "hf_tokenizer_name": "facebook/nllb-200-distilled-600M", 12 | "hf_proj_type": "linear", 13 | "hf_pooler_type": "cls_pooler" 14 | } 15 | } 
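The mt5-*, nllb-clip-*, roberta-*, and xlm-roberta-* configs replace the built-in text transformer with a Hugging Face encoder (see hf_model.py and hf_configs.py) selected by "hf_model_name", and tokenize with the matching "hf_tokenizer_name". Below is a minimal sketch of the "mean_pooler" behaviour using plain transformers; "roberta-base" is taken from the roberta-ViT-B-32 config, and the masked mean mirrors (but is not) open_clip's own pooler implementation.

from transformers import AutoModel, AutoTokenizer

name = "roberta-base"   # "hf_model_name" / "hf_tokenizer_name" in roberta-ViT-B-32.json
tokenizer = AutoTokenizer.from_pretrained(name)
encoder = AutoModel.from_pretrained(name)

batch = tokenizer(["a photo of a cat", "a diagram"], return_tensors="pt", padding=True)
hidden = encoder(**batch).last_hidden_state            # (batch, seq_len, 768)

# "hf_pooler_type": "mean_pooler" -> average token states under the attention mask.
mask = batch["attention_mask"].unsqueeze(-1)
pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1)
print(pooled.shape)                                    # torch.Size([2, 768])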
-------------------------------------------------------------------------------- /src/open_clip/model_configs/nllb-clip-large-siglip.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1152, 3 | "custom_text": true, 4 | "init_logit_bias": -10, 5 | "vision_cfg": { 6 | "image_size": 384, 7 | "timm_model_name": "vit_so400m_patch14_siglip_384", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "hf_model_name": "facebook/nllb-200-distilled-1.3B", 14 | "hf_tokenizer_name": "facebook/nllb-200-distilled-1.3B", 15 | "hf_proj_type": "linear", 16 | "hf_pooler_type": "cls_pooler" 17 | } 18 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/nllb-clip-large.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "facebook/nllb-200-distilled-1.3B", 12 | "hf_tokenizer_name": "facebook/nllb-200-distilled-1.3B", 13 | "hf_proj_type": "linear", 14 | "hf_pooler_type": "cls_pooler" 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/roberta-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "roberta-base", 12 | "hf_tokenizer_name": "roberta-base", 13 | "hf_pooler_type": "mean_pooler" 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /src/open_clip/model_configs/swin_base_patch4_window7_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "timm_model_name": "swin_base_patch4_window7_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 640, 14 | "heads": 10, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/vit_medium_patch16_gap_256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_medium_patch16_gap_256", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 256 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_relpos_medium_patch16_cls_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 
15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "hf_model_name": "xlm-roberta-base", 11 | "hf_tokenizer_name": "xlm-roberta-base", 12 | "hf_pooler_type": "mean_pooler" 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /src/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "xlm-roberta-large", 12 | "hf_tokenizer_name": "xlm-roberta-large", 13 | "hf_pooler_type": "mean_pooler" 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /src/open_clip/modified_resnet.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from open_clip.utils import freeze_batch_norm_2d 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.act1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.act2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.act3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([ 37 | ("-1", nn.AvgPool2d(stride)), 38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 39 | ("1", nn.BatchNorm2d(planes * self.expansion)) 40 | ])) 41 | 42 | def forward(self, x: torch.Tensor): 43 | identity = x 44 | 45 | out = self.act1(self.bn1(self.conv1(x))) 46 | out = self.act2(self.bn2(self.conv2(out))) 47 | out = self.avgpool(out) 48 | out = self.bn3(self.conv3(out)) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.act3(out) 55 | return out 56 | 57 | 58 | class AttentionPool2d(nn.Module): 59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 60 | super().__init__() 61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 62 | self.k_proj = nn.Linear(embed_dim, embed_dim) 63 | self.q_proj = nn.Linear(embed_dim, embed_dim) 64 | self.v_proj = nn.Linear(embed_dim, embed_dim) 65 
| self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 66 | self.num_heads = num_heads 67 | 68 | def forward(self, x): 69 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 72 | x, _ = F.multi_head_attention_forward( 73 | query=x, key=x, value=x, 74 | embed_dim_to_check=x.shape[-1], 75 | num_heads=self.num_heads, 76 | q_proj_weight=self.q_proj.weight, 77 | k_proj_weight=self.k_proj.weight, 78 | v_proj_weight=self.v_proj.weight, 79 | in_proj_weight=None, 80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 81 | bias_k=None, 82 | bias_v=None, 83 | add_zero_attn=False, 84 | dropout_p=0., 85 | out_proj_weight=self.c_proj.weight, 86 | out_proj_bias=self.c_proj.bias, 87 | use_separate_proj_weight=True, 88 | training=self.training, 89 | need_weights=False 90 | ) 91 | 92 | return x[0] 93 | 94 | 95 | class ModifiedResNet(nn.Module): 96 | """ 97 | A ResNet class that is similar to torchvision's but contains the following changes: 98 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 99 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 100 | - The final pooling layer is a QKV attention instead of an average pool 101 | """ 102 | 103 | def __init__(self, layers, output_dim, heads, image_size=224, width=64): 104 | super().__init__() 105 | self.output_dim = output_dim 106 | self.image_size = image_size 107 | 108 | # the 3-layer stem 109 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 110 | self.bn1 = nn.BatchNorm2d(width // 2) 111 | self.act1 = nn.ReLU(inplace=True) 112 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 113 | self.bn2 = nn.BatchNorm2d(width // 2) 114 | self.act2 = nn.ReLU(inplace=True) 115 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 116 | self.bn3 = nn.BatchNorm2d(width) 117 | self.act3 = nn.ReLU(inplace=True) 118 | self.avgpool = nn.AvgPool2d(2) 119 | 120 | # residual layers 121 | self._inplanes = width # this is a *mutable* variable used during construction 122 | self.layer1 = self._make_layer(width, layers[0]) 123 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 124 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 125 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 126 | 127 | embed_dim = width * 32 # the ResNet feature dimension 128 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) 129 | 130 | self.init_parameters() 131 | 132 | def _make_layer(self, planes, blocks, stride=1): 133 | layers = [Bottleneck(self._inplanes, planes, stride)] 134 | 135 | self._inplanes = planes * Bottleneck.expansion 136 | for _ in range(1, blocks): 137 | layers.append(Bottleneck(self._inplanes, planes)) 138 | 139 | return nn.Sequential(*layers) 140 | 141 | def init_parameters(self): 142 | if self.attnpool is not None: 143 | std = self.attnpool.c_proj.in_features ** -0.5 144 | nn.init.normal_(self.attnpool.q_proj.weight, std=std) 145 | nn.init.normal_(self.attnpool.k_proj.weight, std=std) 146 | nn.init.normal_(self.attnpool.v_proj.weight, std=std) 147 | nn.init.normal_(self.attnpool.c_proj.weight, std=std) 148 | 149 | for resnet_block in [self.layer1, self.layer2, self.layer3, 
self.layer4]: 150 | for name, param in resnet_block.named_parameters(): 151 | if name.endswith("bn3.weight"): 152 | nn.init.zeros_(param) 153 | 154 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 155 | assert unlocked_groups == 0, 'partial locking not currently supported for this model' 156 | for param in self.parameters(): 157 | param.requires_grad = False 158 | if freeze_bn_stats: 159 | freeze_batch_norm_2d(self) 160 | 161 | @torch.jit.ignore 162 | def set_grad_checkpointing(self, enable=True): 163 | # FIXME support for non-transformer 164 | pass 165 | 166 | def stem(self, x): 167 | x = self.act1(self.bn1(self.conv1(x))) 168 | x = self.act2(self.bn2(self.conv2(x))) 169 | x = self.act3(self.bn3(self.conv3(x))) 170 | x = self.avgpool(x) 171 | return x 172 | 173 | def forward(self, x): 174 | x = self.stem(x) 175 | x = self.layer1(x) 176 | x = self.layer2(x) 177 | x = self.layer3(x) 178 | x = self.layer4(x) 179 | x = self.attnpool(x) 180 | 181 | return x 182 | -------------------------------------------------------------------------------- /src/open_clip/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import List, Optional, Union 9 | 10 | import torch 11 | 12 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 13 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype 14 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url 15 | 16 | __all__ = ["list_openai_models", "load_openai_model"] 17 | 18 | 19 | def list_openai_models() -> List[str]: 20 | """Returns the names of available CLIP models""" 21 | return list_pretrained_models_by_tag('openai') 22 | 23 | 24 | def load_openai_model( 25 | name: str, 26 | precision: Optional[str] = None, 27 | device: Optional[Union[str, torch.device]] = None, 28 | cache_dir: Optional[str] = None, 29 | ): 30 | """Load a CLIP model 31 | 32 | Parameters 33 | ---------- 34 | name : str 35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 36 | precision: str 37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. 
38 | device : Union[str, torch.device] 39 | The device to put the loaded model 40 | cache_dir : Optional[str] 41 | The directory to cache the downloaded model weights 42 | 43 | Returns 44 | ------- 45 | model : torch.nn.Module 46 | The CLIP model 47 | preprocess : Callable[[PIL.Image], torch.Tensor] 48 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 49 | """ 50 | if device is None: 51 | device = "cuda" if torch.cuda.is_available() else "cpu" 52 | if precision is None: 53 | precision = 'fp32' if device == 'cpu' else 'fp16' 54 | 55 | if get_pretrained_url(name, 'openai'): 56 | model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) 57 | elif os.path.isfile(name): 58 | model_path = name 59 | else: 60 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 61 | 62 | try: 63 | # loading JIT archive 64 | model = torch.jit.load(model_path, map_location="cpu").eval() 65 | state_dict = None 66 | except RuntimeError: 67 | # loading saved state dict 68 | state_dict = torch.load(model_path, map_location="cpu") 69 | 70 | # Build a non-jit model from the OpenAI jitted model state dict 71 | cast_dtype = get_cast_dtype(precision) 72 | try: 73 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) 74 | except KeyError: 75 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 76 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) 77 | 78 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use 79 | model = model.to(device) 80 | # FIXME support pure fp16/bf16 precision modes 81 | if precision != 'fp16': 82 | model.float() 83 | if precision == 'bf16': 84 | # for bf16, convert back to low-precision 85 | convert_weights_to_lp(model, dtype=torch.bfloat16) 86 | 87 | # add mean / std attributes for consistency with OpenCLIP models 88 | model.visual.image_mean = OPENAI_DATASET_MEAN 89 | model.visual.image_std = OPENAI_DATASET_STD 90 | return model 91 | -------------------------------------------------------------------------------- /src/open_clip/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert embed_dim % 2 == 0 40 | 41 | # use half of dimensions to encode grid_h 42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 44 | 45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 46 | return emb 47 | 48 | 49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 50 | """ 51 | embed_dim: output dimension for each position 52 | pos: a list of positions to be encoded: size (M,) 53 | out: (M, D) 54 | """ 55 | assert embed_dim % 2 == 0 56 | omega = np.arange(embed_dim // 2, dtype=float) 57 | omega /= embed_dim / 2. 58 | omega = 1. 
/ 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 89 | # only the position tokens are interpolated 90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 92 | pos_tokens = torch.nn.functional.interpolate( 93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 96 | checkpoint_model['pos_embed'] = new_pos_embed 97 | -------------------------------------------------------------------------------- /src/open_clip/push_to_hf_hub.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from pathlib import Path 5 | from tempfile import TemporaryDirectory 6 | from typing import Optional, Tuple, Union 7 | 8 | import torch 9 | 10 | try: 11 | from huggingface_hub import ( 12 | create_repo, 13 | get_hf_file_metadata, 14 | hf_hub_download, 15 | hf_hub_url, 16 | repo_type_and_id_from_hf_id, 17 | upload_folder, 18 | list_repo_files, 19 | ) 20 | from huggingface_hub.utils import EntryNotFoundError 21 | _has_hf_hub = True 22 | except ImportError: 23 | _has_hf_hub = False 24 | 25 | try: 26 | import safetensors.torch 27 | _has_safetensors = True 28 | except ImportError: 29 | _has_safetensors = False 30 | 31 | from .factory import create_model_from_pretrained, get_model_config, get_tokenizer 32 | from .tokenizer import HFTokenizer 33 | 34 | # Default name for a weights file hosted on the Huggingface Hub. 
35 | HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl 36 | HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version 37 | HF_CONFIG_NAME = 'open_clip_config.json' 38 | 39 | 40 | def save_config_for_hf( 41 | model, 42 | config_path: str, 43 | model_config: Optional[dict] 44 | ): 45 | preprocess_cfg = { 46 | 'mean': model.visual.image_mean, 47 | 'std': model.visual.image_std, 48 | } 49 | other_pp = getattr(model.visual, 'preprocess_cfg', {}) 50 | if 'interpolation' in other_pp: 51 | preprocess_cfg['interpolation'] = other_pp['interpolation'] 52 | if 'resize_mode' in other_pp: 53 | preprocess_cfg['resize_mode'] = other_pp['resize_mode'] 54 | hf_config = { 55 | 'model_cfg': model_config, 56 | 'preprocess_cfg': preprocess_cfg, 57 | } 58 | 59 | with config_path.open('w') as f: 60 | json.dump(hf_config, f, indent=2) 61 | 62 | 63 | def save_for_hf( 64 | model, 65 | tokenizer: HFTokenizer, 66 | model_config: dict, 67 | save_directory: str, 68 | safe_serialization: Union[bool, str] = 'both', 69 | skip_weights : bool = False, 70 | ): 71 | config_filename = HF_CONFIG_NAME 72 | 73 | save_directory = Path(save_directory) 74 | save_directory.mkdir(exist_ok=True, parents=True) 75 | 76 | if not skip_weights: 77 | tensors = model.state_dict() 78 | if safe_serialization is True or safe_serialization == "both": 79 | assert _has_safetensors, "`pip install safetensors` to use .safetensors" 80 | safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME) 81 | if safe_serialization is False or safe_serialization == "both": 82 | torch.save(tensors, save_directory / HF_WEIGHTS_NAME) 83 | 84 | tokenizer.save_pretrained(save_directory) 85 | 86 | config_path = save_directory / config_filename 87 | save_config_for_hf(model, config_path, model_config=model_config) 88 | 89 | 90 | def push_to_hf_hub( 91 | model, 92 | tokenizer, 93 | model_config: Optional[dict], 94 | repo_id: str, 95 | commit_message: str = 'Add model', 96 | token: Optional[str] = None, 97 | revision: Optional[str] = None, 98 | private: bool = False, 99 | create_pr: bool = False, 100 | model_card: Optional[dict] = None, 101 | safe_serialization: Union[bool, str] = False, 102 | ): 103 | if not isinstance(tokenizer, HFTokenizer): 104 | # FIXME this makes it awkward to push models with new tokenizers, come up with better soln. 105 | # default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14 106 | tokenizer = HFTokenizer('openai/clip-vit-large-patch14') 107 | 108 | # Create repo if it doesn't exist yet 109 | repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True) 110 | 111 | # Infer complete repo_id from repo_url 112 | # Can be different from the input `repo_id` if repo_owner was implicit 113 | _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url) 114 | repo_id = f"{repo_owner}/{repo_name}" 115 | 116 | # Check if repo already exists and determine what needs updating 117 | repo_exists = False 118 | repo_files = {} 119 | try: 120 | repo_files = set(list_repo_files(repo_id)) 121 | repo_exists = True 122 | except Exception as e: 123 | print('Repo does not exist', e) 124 | 125 | try: 126 | get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision)) 127 | has_readme = True 128 | except EntryNotFoundError: 129 | has_readme = False 130 | 131 | # Dump model and push to Hub 132 | with TemporaryDirectory() as tmpdir: 133 | # Save model weights and config. 
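# After this call tmpdir holds the weights (open_clip_pytorch_model.bin and/or open_clip_model.safetensors,
# depending on safe_serialization), the tokenizer files, and open_clip_config.json; the whole folder is
# then pushed as a single commit by upload_folder below.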
134 | save_for_hf( 135 | model, 136 | tokenizer=tokenizer, 137 | model_config=model_config, 138 | save_directory=tmpdir, 139 | safe_serialization=safe_serialization, 140 | ) 141 | 142 | # Add readme if it does not exist 143 | if not has_readme: 144 | model_card = model_card or {} 145 | model_name = repo_id.split('/')[-1] 146 | readme_path = Path(tmpdir) / "README.md" 147 | readme_text = generate_readme(model_card, model_name) 148 | readme_path.write_text(readme_text) 149 | 150 | # Upload model and return 151 | return upload_folder( 152 | repo_id=repo_id, 153 | folder_path=tmpdir, 154 | revision=revision, 155 | create_pr=create_pr, 156 | commit_message=commit_message, 157 | ) 158 | 159 | 160 | def push_pretrained_to_hf_hub( 161 | model_name, 162 | pretrained: str, 163 | repo_id: str, 164 | precision: str = 'fp32', 165 | image_mean: Optional[Tuple[float, ...]] = None, 166 | image_std: Optional[Tuple[float, ...]] = None, 167 | image_interpolation: Optional[str] = None, 168 | image_resize_mode: Optional[str] = None, # only effective for inference 169 | commit_message: str = 'Add model', 170 | token: Optional[str] = None, 171 | revision: Optional[str] = None, 172 | private: bool = False, 173 | create_pr: bool = False, 174 | model_card: Optional[dict] = None, 175 | hf_tokenizer_self: bool = False, 176 | ): 177 | model, preprocess_eval = create_model_from_pretrained( 178 | model_name, 179 | pretrained=pretrained, 180 | precision=precision, 181 | image_mean=image_mean, 182 | image_std=image_std, 183 | image_interpolation=image_interpolation, 184 | image_resize_mode=image_resize_mode, 185 | ) 186 | model_config = get_model_config(model_name) 187 | assert model_config 188 | 189 | tokenizer = get_tokenizer(model_name) 190 | if hf_tokenizer_self: 191 | # make hf tokenizer config in the uploaded model point to self instead of original location 192 | model_config['text']['hf_tokenizer_name'] = repo_id 193 | 194 | push_to_hf_hub( 195 | model=model, 196 | tokenizer=tokenizer, 197 | model_config=model_config, 198 | repo_id=repo_id, 199 | commit_message=commit_message, 200 | token=token, 201 | revision=revision, 202 | private=private, 203 | create_pr=create_pr, 204 | model_card=model_card, 205 | safe_serialization='both', 206 | ) 207 | 208 | 209 | def generate_readme(model_card: dict, model_name: str): 210 | tags = model_card.pop('tags', ('clip',)) 211 | pipeline_tag = model_card.pop('pipeline_tag', 'zero-shot-image-classification') 212 | readme_text = "---\n" 213 | if tags: 214 | readme_text += "tags:\n" 215 | for t in tags: 216 | readme_text += f"- {t}\n" 217 | readme_text += "library_name: open_clip\n" 218 | readme_text += f"pipeline_tag: {pipeline_tag}\n" 219 | readme_text += f"license: {model_card.get('license', 'mit')}\n" 220 | if 'details' in model_card and 'Dataset' in model_card['details']: 221 | readme_text += 'datasets:\n' 222 | readme_text += f"- {model_card['details']['Dataset'].lower()}\n" 223 | readme_text += "---\n" 224 | readme_text += f"# Model card for {model_name}\n" 225 | if 'description' in model_card: 226 | readme_text += f"\n{model_card['description']}\n" 227 | if 'details' in model_card: 228 | readme_text += f"\n## Model Details\n" 229 | for k, v in model_card['details'].items(): 230 | if isinstance(v, (list, tuple)): 231 | readme_text += f"- **{k}:**\n" 232 | for vi in v: 233 | readme_text += f" - {vi}\n" 234 | elif isinstance(v, dict): 235 | readme_text += f"- **{k}:**\n" 236 | for ki, vi in v.items(): 237 | readme_text += f" - {ki}: {vi}\n" 238 | else: 239 | readme_text += f"- 
**{k}:** {v}\n" 240 | if 'usage' in model_card: 241 | readme_text += f"\n## Model Usage\n" 242 | readme_text += model_card['usage'] 243 | readme_text += '\n' 244 | 245 | if 'comparison' in model_card: 246 | readme_text += f"\n## Model Comparison\n" 247 | readme_text += model_card['comparison'] 248 | readme_text += '\n' 249 | 250 | if 'citation' in model_card: 251 | readme_text += f"\n## Citation\n" 252 | if not isinstance(model_card['citation'], (list, tuple)): 253 | citations = [model_card['citation']] 254 | else: 255 | citations = model_card['citation'] 256 | for c in citations: 257 | readme_text += f"```bibtex\n{c}\n```\n" 258 | 259 | return readme_text 260 | 261 | 262 | if __name__ == "__main__": 263 | parser = argparse.ArgumentParser(description="Push to Hugging Face Hub") 264 | parser.add_argument( 265 | "--model", type=str, help="Name of the model to use.", 266 | ) 267 | parser.add_argument( 268 | "--pretrained", type=str, 269 | help="Use a pretrained CLIP model weights with the specified tag or file path.", 270 | ) 271 | parser.add_argument( 272 | "--repo-id", type=str, 273 | help="Destination HF Hub repo-id ie 'organization/model_id'.", 274 | ) 275 | parser.add_argument( 276 | "--precision", type=str, default='fp32', 277 | ) 278 | parser.add_argument( 279 | '--image-mean', type=float, nargs='+', default=None, metavar='MEAN', 280 | help='Override default image mean value of dataset') 281 | parser.add_argument( 282 | '--image-std', type=float, nargs='+', default=None, metavar='STD', 283 | help='Override default image std deviation of of dataset') 284 | parser.add_argument( 285 | '--image-interpolation', 286 | default=None, type=str, choices=['bicubic', 'bilinear', 'random'], 287 | help="image resize interpolation" 288 | ) 289 | parser.add_argument( 290 | '--image-resize-mode', 291 | default=None, type=str, choices=['shortest', 'longest', 'squash'], 292 | help="image resize mode during inference" 293 | ) 294 | parser.add_argument( 295 | "--hf-tokenizer-self", 296 | default=False, 297 | action="store_true", 298 | help="make hf_tokenizer_name point in uploaded config point to itself" 299 | ) 300 | args = parser.parse_args() 301 | 302 | print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}') 303 | 304 | # FIXME add support to pass model_card json / template from file via cmd line 305 | 306 | push_pretrained_to_hf_hub( 307 | args.model, 308 | args.pretrained, 309 | args.repo_id, 310 | precision=args.precision, 311 | image_mean=args.image_mean, # override image mean/std if trained w/ non defaults 312 | image_std=args.image_std, 313 | image_interpolation=args.image_interpolation, 314 | image_resize_mode=args.image_resize_mode, 315 | ) 316 | 317 | print(f'{args.model} saved.') 318 | -------------------------------------------------------------------------------- /src/open_clip/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 
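The adapter pairs a timm trunk with an optional attention/average pooling stage and a linear or MLP
projection head so that the output matches the requested CLIP embedding width.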
4 | """ 5 | import logging 6 | from collections import OrderedDict 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | try: 12 | import timm 13 | from timm.models.layers import Mlp, to_2tuple 14 | try: 15 | # old timm imports < 0.8.1 16 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 17 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d 18 | except ImportError: 19 | # new timm imports >= 0.8.1 20 | from timm.layers import RotAttentionPool2d 21 | from timm.layers import AttentionPool2d as AbsAttentionPool2d 22 | except ImportError: 23 | timm = None 24 | 25 | from .utils import freeze_batch_norm_2d 26 | 27 | 28 | class TimmModel(nn.Module): 29 | """ timm model adapter 30 | """ 31 | 32 | def __init__( 33 | self, 34 | model_name, 35 | embed_dim, 36 | image_size=224, 37 | pool='avg', 38 | proj='linear', 39 | proj_bias=False, 40 | drop=0., 41 | drop_path=None, 42 | patch_drop=None, 43 | pretrained=False, 44 | ): 45 | super().__init__() 46 | if timm is None: 47 | raise RuntimeError("Please `pip install timm` to use timm models.") 48 | self.image_size = to_2tuple(image_size) 49 | 50 | # setup kwargs that may not be common across all models 51 | timm_kwargs = {} 52 | if drop_path is not None: 53 | timm_kwargs['drop_path_rate'] = drop_path 54 | if patch_drop is not None: 55 | timm_kwargs['patch_drop_rate'] = patch_drop 56 | 57 | custom_pool = pool in ('abs_attn', 'rot_attn') 58 | if proj: 59 | assert proj in ("linear", "mlp", "none") 60 | extra_proj = proj in ("linear", "mlp") 61 | if not extra_proj and not custom_pool: 62 | # use network classifier head as projection if no proj specified and no custom pooling used 63 | # if projection is explicitly set to "none" will be pass through from network trunk 64 | proj_dim = 0 if proj == 'none' else embed_dim 65 | self.trunk = timm.create_model( 66 | model_name, 67 | num_classes=proj_dim, 68 | global_pool=pool, 69 | pretrained=pretrained, 70 | **timm_kwargs, 71 | ) 72 | prev_chs = embed_dim 73 | else: 74 | self.trunk = timm.create_model( 75 | model_name, 76 | pretrained=pretrained, 77 | **timm_kwargs, 78 | ) 79 | feat_size = self.trunk.default_cfg.get('pool_size', None) 80 | feature_ndim = 1 if not feat_size else 2 81 | if custom_pool: 82 | assert feature_ndim == 2 83 | # if attn pooling used, remove both classifier and default pool 84 | self.trunk.reset_classifier(0, global_pool='') 85 | else: 86 | # reset global pool if pool config set, otherwise leave as network default 87 | reset_kwargs = dict(global_pool=pool) if pool else {} 88 | self.trunk.reset_classifier(0, **reset_kwargs) 89 | prev_chs = self.trunk.num_features 90 | 91 | head_layers = OrderedDict() 92 | 93 | # Add custom pooling to head 94 | if pool == 'abs_attn': 95 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) 96 | prev_chs = embed_dim 97 | elif pool == 'rot_attn': 98 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 99 | prev_chs = embed_dim 100 | 101 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 102 | if proj == 'linear': 103 | head_layers['drop'] = nn.Dropout(drop) 104 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) 105 | elif proj == 'mlp': 106 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias)) 107 | 108 | self.head = nn.Sequential(head_layers) 109 | 110 | def lock(self, unlocked_groups=0, 
freeze_bn_stats=False): 111 | """ lock modules 112 | Args: 113 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 114 | """ 115 | if not unlocked_groups: 116 | # lock full model 117 | for param in self.trunk.parameters(): 118 | param.requires_grad = False 119 | if freeze_bn_stats: 120 | freeze_batch_norm_2d(self.trunk) 121 | else: 122 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 123 | try: 124 | # FIXME import here until API stable and in an official release 125 | from timm.models.helpers import group_parameters, group_modules 126 | except ImportError: 127 | raise RuntimeError( 128 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') 129 | matcher = self.trunk.group_matcher() 130 | gparams = group_parameters(self.trunk, matcher) 131 | max_layer_id = max(gparams.keys()) 132 | max_layer_id = max_layer_id - unlocked_groups 133 | for group_idx in range(max_layer_id + 1): 134 | group = gparams[group_idx] 135 | for param in group: 136 | self.trunk.get_parameter(param).requires_grad = False 137 | if freeze_bn_stats: 138 | gmodules = group_modules(self.trunk, matcher, reverse=True) 139 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 140 | freeze_batch_norm_2d(self.trunk, gmodules) 141 | 142 | @torch.jit.ignore 143 | def set_grad_checkpointing(self, enable=True): 144 | try: 145 | self.trunk.set_grad_checkpointing(enable) 146 | except Exception as e: 147 | logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') 148 | 149 | def forward(self, x): 150 | x = self.trunk(x) 151 | x = self.head(x) 152 | return x 153 | -------------------------------------------------------------------------------- /src/open_clip/utils.py: -------------------------------------------------------------------------------- 1 | from itertools import repeat 2 | import collections.abc 3 | 4 | import torch 5 | from torch import nn as nn 6 | from torchvision.ops.misc import FrozenBatchNorm2d 7 | 8 | 9 | def freeze_batch_norm_2d(module, module_match={}, name=''): 10 | """ 11 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is 12 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and 13 | returned. Otherwise, the module is walked recursively and submodules are converted in place. 14 | 15 | Args: 16 | module (torch.nn.Module): Any PyTorch module. 
17 | module_match (dict): Dictionary of full module names to freeze (all if empty) 18 | name (str): Full module name (prefix) 19 | 20 | Returns: 21 | torch.nn.Module: Resulting module 22 | 23 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 24 | """ 25 | res = module 26 | is_match = True 27 | if module_match: 28 | is_match = name in module_match 29 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): 30 | res = FrozenBatchNorm2d(module.num_features) 31 | res.num_features = module.num_features 32 | res.affine = module.affine 33 | if module.affine: 34 | res.weight.data = module.weight.data.clone().detach() 35 | res.bias.data = module.bias.data.clone().detach() 36 | res.running_mean.data = module.running_mean.data 37 | res.running_var.data = module.running_var.data 38 | res.eps = module.eps 39 | else: 40 | for child_name, child in module.named_children(): 41 | full_child_name = '.'.join([name, child_name]) if name else child_name 42 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name) 43 | if new_child is not child: 44 | res.add_module(child_name, new_child) 45 | return res 46 | 47 | 48 | # From PyTorch internals 49 | def _ntuple(n): 50 | def parse(x): 51 | if isinstance(x, collections.abc.Iterable): 52 | return x 53 | return tuple(repeat(x, n)) 54 | return parse 55 | 56 | 57 | to_1tuple = _ntuple(1) 58 | to_2tuple = _ntuple(2) 59 | to_3tuple = _ntuple(3) 60 | to_4tuple = _ntuple(4) 61 | to_ntuple = lambda n, x: _ntuple(n)(x) 62 | 63 | # Replaces all linear layers with linear_replacement 64 | # TODO: add int8 support for other linear layers including attn and convnets 65 | def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True): 66 | for name, module in model.named_children(): 67 | if len(list(module.children())) > 0: 68 | replace_linear(module, linear_replacement, include_modules, copy_weights) 69 | 70 | if isinstance(module, torch.nn.Linear) and name in include_modules: 71 | old_module = model._modules[name] 72 | model._modules[name] = linear_replacement( 73 | module.in_features, 74 | module.out_features, 75 | module.bias is not None, 76 | ) 77 | if copy_weights: 78 | model._modules[name].weight.data.copy_(old_module.weight.data) 79 | if model._modules[name].bias is not None: 80 | model._modules[name].bias.data.copy_(old_module.bias) 81 | 82 | return model 83 | 84 | def convert_int8_model_to_inference_mode(model): 85 | for m in model.modules(): 86 | if hasattr(m, 'prepare_for_eval'): 87 | int8_original_dtype = m.weight.dtype 88 | m.prepare_for_eval() 89 | m.int8_original_dtype = int8_original_dtype -------------------------------------------------------------------------------- /src/open_clip/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '2.24.0' 2 | -------------------------------------------------------------------------------- /src/open_clip/zero_shot_classifier.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from itertools import islice 3 | from typing import Callable, List, Optional, Sequence, Union 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | 9 | def batched(iterable, n): 10 | """Batch data into lists of length *n*. The last batch may be shorter. 
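For example, list(batched('ABCDE', 2)) yields [['A', 'B'], ['C', 'D'], ['E']].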
11 | NOTE based on more-itertools impl, to be replaced by python 3.12 itertools.batched impl 12 | """ 13 | it = iter(iterable) 14 | while True: 15 | batch = list(islice(it, n)) 16 | if not batch: 17 | break 18 | yield batch 19 | 20 | 21 | def build_zero_shot_classifier( 22 | model, 23 | tokenizer, 24 | classnames: Sequence[str], 25 | templates: Sequence[Union[Callable, str]], 26 | num_classes_per_batch: Optional[int] = 10, 27 | device: Union[str, torch.device] = 'cpu', 28 | use_tqdm: bool = False, 29 | ): 30 | """ Build zero-shot classifier weights by iterating over class names in batches 31 | Args: 32 | model: CLIP model instance 33 | tokenizer: CLIP tokenizer instance 34 | classnames: A sequence of class (label) names 35 | templates: A sequence of callables or format() friendly strings to produce templates per class name 36 | num_classes_per_batch: The number of classes to batch together in each forward, all if None 37 | device: Device to use. 38 | use_tqdm: Enable TQDM progress bar. 39 | """ 40 | assert isinstance(templates, Sequence) and len(templates) > 0 41 | assert isinstance(classnames, Sequence) and len(classnames) > 0 42 | use_format = isinstance(templates[0], str) 43 | num_templates = len(templates) 44 | num_classes = len(classnames) 45 | if use_tqdm: 46 | import tqdm 47 | num_iter = 1 if num_classes_per_batch is None else ((num_classes - 1) // num_classes_per_batch + 1) 48 | iter_wrap = partial(tqdm.tqdm, total=num_iter, unit_scale=num_classes_per_batch) 49 | else: 50 | iter_wrap = iter 51 | 52 | def _process_batch(batch_classnames): 53 | num_batch_classes = len(batch_classnames) 54 | texts = [template.format(c) if use_format else template(c) for c in batch_classnames for template in templates] 55 | texts = tokenizer(texts).to(device) 56 | class_embeddings = model.encode_text(texts, normalize=True) 57 | class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1) 58 | class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True) 59 | class_embeddings = class_embeddings.T 60 | return class_embeddings 61 | 62 | with torch.no_grad(): 63 | if num_classes_per_batch: 64 | batched_embeds = [_process_batch(batch) for batch in iter_wrap(batched(classnames, num_classes_per_batch))] 65 | zeroshot_weights = torch.cat(batched_embeds, dim=1) 66 | else: 67 | zeroshot_weights = _process_batch(classnames) 68 | return zeroshot_weights 69 | 70 | 71 | def build_zero_shot_classifier_legacy( 72 | model, 73 | tokenizer, 74 | classnames: Sequence[str], 75 | templates: Sequence[Union[Callable, str]], 76 | device: Union[str, torch.device] = 'cpu', 77 | use_tqdm: bool = False, 78 | ): 79 | """ Build zero-shot classifier weights by iterating over class names 1 by 1 80 | Args: 81 | model: CLIP model instance 82 | tokenizer: CLIP tokenizer instance 83 | classnames: A sequence of class (label) names 84 | templates: A sequence of callables or format() friendly strings to produce templates per class name 85 | device: Device to use. 86 | use_tqdm: Enable TQDM progress bar. 
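Returns:
    Classifier weights as a tensor of shape (embed_dim, num_classes).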
87 | """ 88 | assert isinstance(templates, Sequence) and len(templates) > 0 89 | assert isinstance(classnames, Sequence) and len(classnames) > 0 90 | if use_tqdm: 91 | import tqdm 92 | iter_wrap = tqdm.tqdm 93 | else: 94 | iter_wrap = iter 95 | 96 | use_format = isinstance(templates[0], str) 97 | 98 | with torch.no_grad(): 99 | zeroshot_weights = [] 100 | for classname in iter_wrap(classnames): 101 | texts = [template.format(classname) if use_format else template(classname) for template in templates] 102 | texts = tokenizer(texts).to(device) # tokenize 103 | class_embeddings = model.encode_text(texts) 104 | class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) 105 | class_embedding /= class_embedding.norm() 106 | zeroshot_weights.append(class_embedding) 107 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device) 108 | 109 | return zeroshot_weights 110 | 111 | -------------------------------------------------------------------------------- /src/training/.gitignore: -------------------------------------------------------------------------------- 1 | logs/ 2 | -------------------------------------------------------------------------------- /src/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/LCL/b1b8e083fc48580438c6a60981590fdb08cc1a25/src/training/__init__.py -------------------------------------------------------------------------------- /src/training/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.distributed as dist 5 | from torch.utils.data.distributed import DistributedSampler 6 | 7 | try: 8 | import horovod.torch as hvd 9 | except ImportError: 10 | hvd = None 11 | 12 | 13 | def is_global_master(args): 14 | return args.rank == 0 15 | 16 | 17 | def is_local_master(args): 18 | return args.local_rank == 0 19 | 20 | 21 | def is_master(args, local=False): 22 | return is_local_master(args) if local else is_global_master(args) 23 | 24 | 25 | def is_using_horovod(): 26 | # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set 27 | # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... 28 | ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] 29 | pmi_vars = ["PMI_RANK", "PMI_SIZE"] 30 | if all([var in os.environ for var in ompi_vars]) or all([var in os.environ for var in pmi_vars]): 31 | return True 32 | else: 33 | return False 34 | 35 | 36 | def is_using_distributed(): 37 | if 'WORLD_SIZE' in os.environ: 38 | return int(os.environ['WORLD_SIZE']) > 1 39 | if 'SLURM_NTASKS' in os.environ: 40 | return int(os.environ['SLURM_NTASKS']) > 1 41 | return False 42 | 43 | 44 | def world_info_from_env(): 45 | local_rank = 0 46 | for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'): 47 | if v in os.environ: 48 | local_rank = int(os.environ[v]) 49 | break 50 | global_rank = 0 51 | for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'): 52 | if v in os.environ: 53 | global_rank = int(os.environ[v]) 54 | break 55 | world_size = 1 56 | for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'): 57 | if v in os.environ: 58 | world_size = int(os.environ[v]) 59 | break 60 | 61 | return local_rank, global_rank, world_size 62 | 63 | 64 | def init_distributed_device(args): 65 | # Distributed training = training on more than one GPU. 
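# Populates args.distributed, args.rank, args.local_rank, args.world_size and args.device from
# Horovod, SLURM environment variables, or torchrun/torch.distributed.launch.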
66 | # Works in both single and multi-node scenarios. 67 | args.distributed = False 68 | args.world_size = 1 69 | args.rank = 0 # global rank 70 | args.local_rank = 0 71 | if args.horovod: 72 | assert hvd is not None, "Horovod is not installed" 73 | hvd.init() 74 | args.local_rank = int(hvd.local_rank()) 75 | args.rank = hvd.rank() 76 | args.world_size = hvd.size() 77 | args.distributed = True 78 | os.environ['LOCAL_RANK'] = str(args.local_rank) 79 | os.environ['RANK'] = str(args.rank) 80 | os.environ['WORLD_SIZE'] = str(args.world_size) 81 | elif is_using_distributed(): 82 | if 'SLURM_PROCID' in os.environ: 83 | # DDP via SLURM 84 | args.local_rank, args.rank, args.world_size = world_info_from_env() 85 | # SLURM var -> torch.distributed vars in case needed 86 | os.environ['LOCAL_RANK'] = str(args.local_rank) 87 | os.environ['RANK'] = str(args.rank) 88 | os.environ['WORLD_SIZE'] = str(args.world_size) 89 | torch.distributed.init_process_group( 90 | backend=args.dist_backend, 91 | init_method=args.dist_url, 92 | world_size=args.world_size, 93 | rank=args.rank, 94 | ) 95 | else: 96 | # DDP via torchrun, torch.distributed.launch 97 | args.local_rank, _, _ = world_info_from_env() 98 | torch.distributed.init_process_group( 99 | backend=args.dist_backend, 100 | init_method=args.dist_url) 101 | args.world_size = torch.distributed.get_world_size() 102 | args.rank = torch.distributed.get_rank() 103 | args.distributed = True 104 | 105 | if torch.cuda.is_available(): 106 | if args.distributed and not args.no_set_device_rank: 107 | device = 'cuda:%d' % args.local_rank 108 | else: 109 | device = 'cuda:0' 110 | torch.cuda.set_device(device) 111 | else: 112 | device = 'cpu' 113 | args.device = device 114 | device = torch.device(device) 115 | return device 116 | 117 | 118 | def broadcast_object(args, obj, src=0): 119 | # broadcast a pickle-able python object from rank-0 to all ranks 120 | if args.horovod: 121 | return hvd.broadcast_object(obj, root_rank=src) 122 | else: 123 | if args.rank == src: 124 | objects = [obj] 125 | else: 126 | objects = [None] 127 | dist.broadcast_object_list(objects, src=src) 128 | return objects[0] 129 | 130 | 131 | def all_gather_object(args, obj, dst=0): 132 | # gather a pickle-able python object across all ranks 133 | if args.horovod: 134 | return hvd.allgather_object(obj) 135 | else: 136 | objects = [None for _ in range(args.world_size)] 137 | dist.all_gather_object(objects, obj) 138 | return objects 139 | 140 | 141 | class GlobalDistributedSampler(DistributedSampler): 142 | """ 143 | A modified class for global data distribution. 
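When shuffling, the stock DistributedSampler hands each rank an arbitrary subset of a permutation;
this sampler instead pins every sampled dataset index to its rank.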
144 | Ensure `index % world_size == rank` 145 | """ 146 | def __iter__(self): 147 | if self.shuffle: 148 | # deterministically shuffle based on epoch 149 | g = torch.Generator() 150 | g.manual_seed(self.seed + self.epoch) 151 | indices = torch.randperm(self.num_samples, generator=g).tolist() 152 | else: 153 | indices = torch.arange(self.num_samples).tolist() 154 | 155 | indices = [(i * self.num_replicas + self.rank) % len(self.dataset) for i in indices] 156 | assert len(indices) == self.num_samples 157 | 158 | return iter(indices) 159 | -------------------------------------------------------------------------------- /src/training/file_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import multiprocessing 4 | import subprocess 5 | import time 6 | import fsspec 7 | import torch 8 | from tqdm import tqdm 9 | 10 | def remote_sync_s3(local_dir, remote_dir): 11 | # skip epoch_latest which can change during sync. 12 | result = subprocess.run(["aws", "s3", "sync", local_dir, remote_dir, '--exclude', '*epoch_latest.pt'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) 13 | if result.returncode != 0: 14 | logging.error(f"Error: Failed to sync with S3 bucket {result.stderr.decode('utf-8')}") 15 | return False 16 | 17 | logging.info(f"Successfully synced with S3 bucket") 18 | return True 19 | 20 | def remote_sync_fsspec(local_dir, remote_dir): 21 | # FIXME currently this is slow and not recommended. Look into speeding up. 22 | a = fsspec.get_mapper(local_dir) 23 | b = fsspec.get_mapper(remote_dir) 24 | 25 | for k in a: 26 | # skip epoch_latest which can change during sync. 27 | if 'epoch_latest.pt' in k: 28 | continue 29 | 30 | logging.info(f'Attempting to sync {k}') 31 | if k in b and len(a[k]) == len(b[k]): 32 | logging.debug(f'Skipping remote sync for {k}.') 33 | continue 34 | 35 | try: 36 | logging.info(f'Successful sync for {k}.') 37 | b[k] = a[k] 38 | except Exception as e: 39 | logging.info(f'Error during remote sync for {k}: {e}') 40 | return False 41 | 42 | return True 43 | 44 | def remote_sync(local_dir, remote_dir, protocol): 45 | logging.info('Starting remote sync.') 46 | if protocol == 's3': 47 | return remote_sync_s3(local_dir, remote_dir) 48 | elif protocol == 'fsspec': 49 | return remote_sync_fsspec(local_dir, remote_dir) 50 | else: 51 | logging.error('Remote protocol not known') 52 | return False 53 | 54 | def keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol): 55 | while True: 56 | time.sleep(sync_every) 57 | remote_sync(local_dir, remote_dir, protocol) 58 | 59 | def start_sync_process(sync_every, local_dir, remote_dir, protocol): 60 | p = multiprocessing.Process(target=keep_running_remote_sync, args=(sync_every, local_dir, remote_dir, protocol)) 61 | return p 62 | 63 | # Note: we are not currently using this save function. 
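# As written it passes file_path rather than the opened fsspec handle f to torch.save, so the
# fsspec.open call has no effect on where the object is written.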
64 | def pt_save(pt_obj, file_path): 65 | of = fsspec.open(file_path, "wb") 66 | with of as f: 67 | torch.save(pt_obj, file_path) 68 | 69 | def pt_load(file_path, map_location=None): 70 | if file_path.startswith('s3'): 71 | logging.info('Loading remote checkpoint, which may take a bit.') 72 | of = fsspec.open(file_path, "rb") 73 | with of as f: 74 | out = torch.load(f, map_location=map_location) 75 | return out 76 | 77 | def check_exists(file_path): 78 | try: 79 | with fsspec.open(file_path): 80 | pass 81 | except FileNotFoundError: 82 | return False 83 | return True 84 | -------------------------------------------------------------------------------- /src/training/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def setup_logging(log_file, level, include_host=False): 5 | if include_host: 6 | import socket 7 | hostname = socket.gethostname() 8 | formatter = logging.Formatter( 9 | f'%(asctime)s | {hostname} | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') 10 | else: 11 | formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') 12 | 13 | logging.root.setLevel(level) 14 | loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict] 15 | for logger in loggers: 16 | logger.setLevel(level) 17 | 18 | stream_handler = logging.StreamHandler() 19 | stream_handler.setFormatter(formatter) 20 | logging.root.addHandler(stream_handler) 21 | 22 | if log_file: 23 | file_handler = logging.FileHandler(filename=log_file) 24 | file_handler.setFormatter(formatter) 25 | logging.root.addHandler(file_handler) 26 | 27 | -------------------------------------------------------------------------------- /src/training/precision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from contextlib import suppress 3 | 4 | 5 | def get_autocast(precision): 6 | if precision == 'amp': 7 | return torch.cuda.amp.autocast 8 | elif precision == 'amp_bfloat16' or precision == 'amp_bf16': 9 | # amp_bfloat16 is more stable than amp float16 for clip training 10 | return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16) 11 | else: 12 | return suppress 13 | -------------------------------------------------------------------------------- /src/training/profiler.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import open_clip 5 | import pandas as pd 6 | from torch.utils.flop_counter import FlopCounterMode 7 | try: 8 | import fvcore 9 | except: 10 | fvcore = None 11 | 12 | parser = argparse.ArgumentParser(description='OpenCLIP Profiler') 13 | 14 | # benchmark specific args 15 | parser.add_argument('--model', metavar='NAME', default='', 16 | help='model(s) to profile') 17 | parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', 18 | help='Output csv file for results') 19 | parser.add_argument('--profiler', default='torch', type=str, choices=['torch', 'fvcore']) 20 | parser.add_argument('--batch-size', default=1, type=int, help='Batch size for profiling') 21 | 22 | 23 | def profile_fvcore( 24 | model, 25 | image_input_size=(3, 224, 224), 26 | text_input_size=(77,), 27 | batch_size=1, 28 | detailed=False, 29 | force_cpu=False 30 | ): 31 | if force_cpu: 32 | model = model.to('cpu') 33 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 34 | example_image_input = torch.ones((batch_size,) + 
image_input_size, device=device, dtype=dtype) 35 | example_text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) 36 | fca = fvcore.nn.FlopCountAnalysis(model, (example_image_input, example_text_input)) 37 | aca = fvcore.nn.ActivationCountAnalysis(model, (example_image_input, example_text_input)) 38 | if detailed: 39 | fcs = fvcore.nn.flop_count_str(fca) 40 | print(fcs) 41 | return fca.total() / batch_size, aca.total() / batch_size 42 | 43 | 44 | def profile_fvcore_text( 45 | model, 46 | text_input_size=(77,), 47 | batch_size=1, 48 | detailed=False, 49 | force_cpu=False 50 | ): 51 | if force_cpu: 52 | model = model.to('cpu') 53 | device = next(model.parameters()).device 54 | example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) 55 | fca = fvcore.nn.FlopCountAnalysis(model, example_input) 56 | aca = fvcore.nn.ActivationCountAnalysis(model, example_input) 57 | if detailed: 58 | fcs = fvcore.nn.flop_count_str(fca) 59 | print(fcs) 60 | return fca.total() / batch_size, aca.total() / batch_size 61 | 62 | 63 | def profile_fvcore_image( 64 | model, 65 | image_input_size=(3, 224, 224), 66 | batch_size=1, 67 | detailed=False, 68 | force_cpu=False 69 | ): 70 | if force_cpu: 71 | model = model.to('cpu') 72 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 73 | example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype) 74 | fca = fvcore.nn.FlopCountAnalysis(model, example_input) 75 | aca = fvcore.nn.ActivationCountAnalysis(model, example_input) 76 | if detailed: 77 | fcs = fvcore.nn.flop_count_str(fca) 78 | print(fcs) 79 | return fca.total() / batch_size, aca.total() / batch_size 80 | 81 | 82 | def profile_torch_image(model, image_input_size, batch_size=1, force_cpu=False): 83 | """Profile the image encoder using torch.utils.flop_counter""" 84 | if force_cpu: 85 | model = model.to('cpu') 86 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 87 | example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype) 88 | 89 | flop_counter = FlopCounterMode() 90 | with flop_counter: 91 | model(example_input) 92 | total_flops = sum(flop_counter.get_flop_counts()['Global'].values()) 93 | return total_flops / batch_size 94 | 95 | 96 | def profile_torch_text(model, text_input_size, batch_size=1, force_cpu=False): 97 | """Profile the text encoder using torch.utils.flop_counter""" 98 | if force_cpu: 99 | model = model.to('cpu') 100 | device = next(model.parameters()).device 101 | example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) 102 | 103 | flop_counter = FlopCounterMode() 104 | with flop_counter: 105 | model(example_input) 106 | total_flops = sum(flop_counter.get_flop_counts()['Global'].values()) 107 | return total_flops / batch_size 108 | 109 | 110 | def profile_torch(model, text_input_size, image_input_size, batch_size=1, force_cpu=False): 111 | """Profile the full model using torch.utils.flop_counter""" 112 | if force_cpu: 113 | model = model.to('cpu') 114 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 115 | image_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype) 116 | text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) 117 | 118 | flop_counter = FlopCounterMode() 119 | with flop_counter: 120 | model(image_input, text_input) 121 | total_flops = 
sum(flop_counter.get_flop_counts()['Global'].values()) 122 | return total_flops / batch_size 123 | 124 | 125 | def count_params(model): 126 | return sum(m.numel() for m in model.parameters()) 127 | 128 | def profile_model(model_name, batch_size=1, profiler='torch'): 129 | assert profiler in ['torch', 'fvcore'], 'Only torch and fvcore profilers are supported' 130 | if profiler == 'fvcore': 131 | assert fvcore is not None, 'Please install fvcore.' 132 | model = open_clip.create_model(model_name, force_custom_text=True, pretrained_hf=False) 133 | model.eval() 134 | if torch.cuda.is_available(): 135 | model = model.cuda() 136 | 137 | if isinstance(model.visual.image_size, (tuple, list)): 138 | image_input_size = (3,) + tuple(model.visual.image_size[-2:]) 139 | else: 140 | image_input_size = (3, model.visual.image_size, model.visual.image_size) 141 | 142 | text_input_size = (77,) 143 | if hasattr(model, 'context_length') and model.context_length: 144 | text_input_size = (model.context_length,) 145 | 146 | results = {} 147 | results['model'] = model_name 148 | results['image_size'] = image_input_size[1] 149 | 150 | model_cfg = open_clip.get_model_config(model_name) 151 | if model_cfg: 152 | vision_cfg = open_clip.CLIPVisionCfg(**model_cfg['vision_cfg']) 153 | text_cfg = open_clip.CLIPTextCfg(**model_cfg['text_cfg']) 154 | results['image_width'] = int(vision_cfg.width) 155 | results['text_width'] = int(text_cfg.width) 156 | results['embed_dim'] = int(model_cfg['embed_dim']) 157 | else: 158 | results['image_width'] = 0 159 | results['text_width'] = 0 160 | results['embed_dim'] = 0 161 | 162 | retries = 2 163 | while retries: 164 | retries -= 1 165 | try: 166 | results['mparams'] = round(count_params(model) / 1e6, 2) 167 | results['image_mparams'] = round(count_params(model.visual) / 1e6, 2) 168 | results['text_mparams'] = round(count_params(model.text) / 1e6, 2) 169 | 170 | if profiler == 'fvcore': 171 | macs, acts = profile_fvcore( 172 | model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size) 173 | 174 | image_macs, image_acts = profile_fvcore_image( 175 | model.visual, image_input_size=image_input_size, force_cpu=not retries, batch_size=batch_size) 176 | 177 | text_macs, text_acts = profile_fvcore_text( 178 | model.text, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size) 179 | 180 | results['gmacs'] = round(macs / 1e9, 2) 181 | results['macts'] = round(acts / 1e6, 2) 182 | 183 | results['image_gmacs'] = round(image_macs / 1e9, 2) 184 | results['image_macts'] = round(image_acts / 1e6, 2) 185 | 186 | results['text_gmacs'] = round(text_macs / 1e9, 2) 187 | results['text_macts'] = round(text_acts / 1e6, 2) 188 | elif profiler == 'torch': 189 | image_flops = profile_torch_image( 190 | model.visual, image_input_size=image_input_size, force_cpu=not retries, batch_size=batch_size) 191 | text_flops = profile_torch_text( 192 | model.text, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size) 193 | total_flops = profile_torch( 194 | model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size) 195 | 196 | results['gflops'] = round(total_flops / 1e9, 2) 197 | results['image_gflops'] = round(image_flops / 1e9, 2) 198 | results['text_gflops'] = round(text_flops / 1e9, 2) 199 | 200 | except RuntimeError as e: 201 | pass 202 | return results 203 | 204 | 205 | def main(): 206 | args = parser.parse_args() 207 | 208 | # FIXME accept a text 
file name to allow lists of models in txt/csv 209 | if args.model == 'all': 210 | parsed_model = open_clip.list_models() 211 | else: 212 | parsed_model = args.model.split(',') 213 | 214 | results = [] 215 | models_with_errors = [] 216 | for m in parsed_model: 217 | print('='*100) 218 | print(f'Profiling {m}') 219 | try: 220 | row = profile_model(m, batch_size=args.batch_size, profiler=args.profiler) 221 | results.append(row) 222 | except Exception as e: 223 | print(f'Error profiling {m}: {e}') 224 | import traceback 225 | traceback.print_exc() 226 | models_with_errors.append(m) 227 | 228 | df = pd.DataFrame(results, columns=results[0].keys()) 229 | 230 | if 'gmacs' in df.columns: 231 | df = df.sort_values(by=['gmacs', 'mparams', 'model']) 232 | else: 233 | df = df.sort_values(by=['gflops', 'mparams', 'model']) 234 | 235 | print('='*100) 236 | print('Done.') 237 | print(df) 238 | if args.results_file: 239 | df.to_csv(args.results_file, index=False) 240 | 241 | if models_with_errors: 242 | print('Models with errors:', models_with_errors) 243 | 244 | 245 | if __name__ == '__main__': 246 | main() 247 | -------------------------------------------------------------------------------- /src/training/scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def assign_learning_rate(optimizer, new_lr): 5 | for param_group in optimizer.param_groups: 6 | param_group["lr"] = new_lr 7 | 8 | 9 | def _warmup_lr(base_lr, warmup_length, step): 10 | return base_lr * (step + 1) / warmup_length 11 | 12 | 13 | def const_lr(optimizer, base_lr, warmup_length, steps): 14 | def _lr_adjuster(step): 15 | if step < warmup_length: 16 | lr = _warmup_lr(base_lr, warmup_length, step) 17 | else: 18 | lr = base_lr 19 | assign_learning_rate(optimizer, lr) 20 | return lr 21 | return _lr_adjuster 22 | 23 | 24 | def const_lr_cooldown(optimizer, base_lr, warmup_length, steps, cooldown_steps, cooldown_power=1.0, cooldown_end_lr=0.): 25 | def _lr_adjuster(step): 26 | start_cooldown_step = steps - cooldown_steps 27 | if step < warmup_length: 28 | lr = _warmup_lr(base_lr, warmup_length, step) 29 | else: 30 | if step < start_cooldown_step: 31 | lr = base_lr 32 | else: 33 | e = step - start_cooldown_step 34 | es = steps - start_cooldown_step 35 | # linear decay if power == 1; polynomial decay otherwise; 36 | decay = (1 - (e/es)) ** cooldown_power 37 | lr = decay * (base_lr - cooldown_end_lr) + cooldown_end_lr 38 | assign_learning_rate(optimizer, lr) 39 | return lr 40 | return _lr_adjuster 41 | 42 | 43 | def cosine_lr(optimizer, base_lr, warmup_length, steps): 44 | def _lr_adjuster(step): 45 | if step < warmup_length: 46 | lr = _warmup_lr(base_lr, warmup_length, step) 47 | else: 48 | e = step - warmup_length 49 | es = steps - warmup_length 50 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr 51 | assign_learning_rate(optimizer, lr) 52 | return lr 53 | return _lr_adjuster 54 | -------------------------------------------------------------------------------- /src/training/zero_shot.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | from tqdm import tqdm 5 | 6 | from open_clip import get_input_dtype, get_tokenizer, build_zero_shot_classifier, \ 7 | IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES 8 | from .precision import get_autocast 9 | 10 | 11 | def accuracy(output, target, topk=(1,)): 12 | pred = output.topk(max(topk), 1, True, True)[1].t() 13 | correct = pred.eq(target.view(1, 
-1).expand_as(pred)) 14 | return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] 15 | 16 | 17 | def run(model, classifier, dataloader, args): 18 | autocast = get_autocast(args.precision) 19 | input_dtype = get_input_dtype(args.precision) 20 | 21 | with torch.no_grad(): 22 | top1, top5, n = 0., 0., 0. 23 | for images, target in tqdm(dataloader, unit_scale=args.batch_size): 24 | images = images.to(device=args.device, dtype=input_dtype) 25 | target = target.to(args.device) 26 | 27 | with autocast(): 28 | # predict 29 | output = model(image=images) 30 | image_features = output['image_features'] if isinstance(output, dict) else output[0] 31 | logits = 100. * image_features @ classifier 32 | 33 | # measure accuracy 34 | acc1, acc5 = accuracy(logits, target, topk=(1, 5)) 35 | top1 += acc1 36 | top5 += acc5 37 | n += images.size(0) 38 | 39 | top1 = (top1 / n) 40 | top5 = (top5 / n) 41 | return top1, top5 42 | 43 | 44 | def zero_shot_eval(model, data, epoch, args, tokenizer=None): 45 | if 'imagenet-val' not in data and 'imagenet-v2' not in data: 46 | return {} 47 | if args.zeroshot_frequency == 0: 48 | return {} 49 | if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: 50 | return {} 51 | if args.distributed and not args.horovod: 52 | model = model.module 53 | 54 | logging.info('Starting zero-shot imagenet.') 55 | if tokenizer is None: 56 | tokenizer = get_tokenizer(args.model) 57 | 58 | logging.info('Building zero-shot classifier') 59 | autocast = get_autocast(args.precision) 60 | with autocast(): 61 | classifier = build_zero_shot_classifier( 62 | model, 63 | tokenizer=tokenizer, 64 | classnames=IMAGENET_CLASSNAMES, 65 | templates=OPENAI_IMAGENET_TEMPLATES, 66 | num_classes_per_batch=10, 67 | device=args.device, 68 | use_tqdm=True, 69 | ) 70 | 71 | logging.info('Using classifier') 72 | results = {} 73 | if 'imagenet-val' in data: 74 | top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args) 75 | results['imagenet-zeroshot-val-top1'] = top1 76 | results['imagenet-zeroshot-val-top5'] = top5 77 | if 'imagenet-v2' in data: 78 | top1, top5 = run(model, classifier, data['imagenet-v2'].dataloader, args) 79 | results['imagenetv2-zeroshot-val-top1'] = top1 80 | results['imagenetv2-zeroshot-val-top5'] = top5 81 | 82 | logging.info('Finished zero-shot imagenet.') 83 | 84 | return results 85 | -------------------------------------------------------------------------------- /tests/test_download_pretrained.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import torch 3 | from PIL import Image 4 | import hashlib 5 | import tempfile 6 | import unittest 7 | from io import BytesIO 8 | from pathlib import Path 9 | from unittest.mock import patch 10 | 11 | from urllib3 import HTTPResponse 12 | from urllib3._collections import HTTPHeaderDict 13 | 14 | import open_clip 15 | from open_clip.pretrained import download_pretrained_from_url 16 | 17 | 18 | class DownloadPretrainedTests(unittest.TestCase): 19 | 20 | def create_response(self, data, status_code=200, content_type='application/octet-stream'): 21 | fp = BytesIO(data) 22 | headers = HTTPHeaderDict({ 23 | 'Content-Type': content_type, 24 | 'Content-Length': str(len(data)) 25 | }) 26 | raw = HTTPResponse(fp, preload_content=False, headers=headers, status=status_code) 27 | return raw 28 | 29 | @patch('open_clip.pretrained.urllib') 30 | def test_download_pretrained_from_url_from_openaipublic(self, urllib): 31 | file_contents = 
b'pretrained model weights' 32 | expected_hash = hashlib.sha256(file_contents).hexdigest() 33 | urllib.request.urlopen.return_value = self.create_response(file_contents) 34 | with tempfile.TemporaryDirectory() as root: 35 | url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt' 36 | download_pretrained_from_url(url, root) 37 | urllib.request.urlopen.assert_called_once() 38 | 39 | @patch('open_clip.pretrained.urllib') 40 | def test_download_pretrained_from_url_from_openaipublic_corrupted(self, urllib): 41 | file_contents = b'pretrained model weights' 42 | expected_hash = hashlib.sha256(file_contents).hexdigest() 43 | urllib.request.urlopen.return_value = self.create_response(b'corrupted pretrained model') 44 | with tempfile.TemporaryDirectory() as root: 45 | url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt' 46 | with self.assertRaisesRegex(RuntimeError, r'checksum does not not match'): 47 | download_pretrained_from_url(url, root) 48 | urllib.request.urlopen.assert_called_once() 49 | 50 | @patch('open_clip.pretrained.urllib') 51 | def test_download_pretrained_from_url_from_openaipublic_valid_cache(self, urllib): 52 | file_contents = b'pretrained model weights' 53 | expected_hash = hashlib.sha256(file_contents).hexdigest() 54 | urllib.request.urlopen.return_value = self.create_response(file_contents) 55 | with tempfile.TemporaryDirectory() as root: 56 | local_file = Path(root) / 'RN50.pt' 57 | local_file.write_bytes(file_contents) 58 | url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt' 59 | download_pretrained_from_url(url, root) 60 | urllib.request.urlopen.assert_not_called() 61 | 62 | @patch('open_clip.pretrained.urllib') 63 | def test_download_pretrained_from_url_from_openaipublic_corrupted_cache(self, urllib): 64 | file_contents = b'pretrained model weights' 65 | expected_hash = hashlib.sha256(file_contents).hexdigest() 66 | urllib.request.urlopen.return_value = self.create_response(file_contents) 67 | with tempfile.TemporaryDirectory() as root: 68 | local_file = Path(root) / 'RN50.pt' 69 | local_file.write_bytes(b'corrupted pretrained model') 70 | url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt' 71 | download_pretrained_from_url(url, root) 72 | urllib.request.urlopen.assert_called_once() 73 | 74 | @patch('open_clip.pretrained.urllib') 75 | def test_download_pretrained_from_url_from_mlfoundations(self, urllib): 76 | file_contents = b'pretrained model weights' 77 | expected_hash = hashlib.sha256(file_contents).hexdigest()[:8] 78 | urllib.request.urlopen.return_value = self.create_response(file_contents) 79 | with tempfile.TemporaryDirectory() as root: 80 | url = f'https://github.com/mlfoundations/download/v0.2-weights/rn50-quickgelu-{expected_hash}.pt' 81 | download_pretrained_from_url(url, root) 82 | urllib.request.urlopen.assert_called_once() 83 | 84 | @patch('open_clip.pretrained.urllib') 85 | def test_download_pretrained_from_url_from_mlfoundations_corrupted(self, urllib): 86 | file_contents = b'pretrained model weights' 87 | expected_hash = hashlib.sha256(file_contents).hexdigest()[:8] 88 | urllib.request.urlopen.return_value = self.create_response(b'corrupted pretrained model') 89 | with tempfile.TemporaryDirectory() as root: 90 | url = f'https://github.com/mlfoundations/download/v0.2-weights/rn50-quickgelu-{expected_hash}.pt' 91 | with self.assertRaisesRegex(RuntimeError, r'checksum does not not match'): 92 | download_pretrained_from_url(url, root) 93 | 
urllib.request.urlopen.assert_called_once() 94 | 95 | @patch('open_clip.pretrained.urllib') 96 | def test_download_pretrained_from_hfh(self, urllib): 97 | model, _, preprocess = open_clip.create_model_and_transforms('hf-hub:hf-internal-testing/tiny-open-clip-model') 98 | tokenizer = open_clip.get_tokenizer('hf-hub:hf-internal-testing/tiny-open-clip-model') 99 | img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/coco_sample.png" 100 | image = preprocess(Image.open(requests.get(img_url, stream=True).raw)).unsqueeze(0) 101 | text = tokenizer(["a diagram", "a dog", "a cat"]) 102 | 103 | with torch.no_grad(): 104 | image_features = model.encode_image(image) 105 | text_features = model.encode_text(text) 106 | image_features /= image_features.norm(dim=-1, keepdim=True) 107 | text_features /= text_features.norm(dim=-1, keepdim=True) 108 | 109 | text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) 110 | 111 | self.assertTrue(torch.allclose(text_probs, torch.tensor([[0.0597, 0.6349, 0.3053]]), 1e-3)) 112 | -------------------------------------------------------------------------------- /tests/test_hf_model.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import torch 4 | from open_clip.hf_model import _POOLERS, HFTextEncoder 5 | from transformers import AutoConfig 6 | from transformers.modeling_outputs import BaseModelOutput 7 | # test poolers 8 | def test_poolers(): 9 | bs, sl, d = 2, 10, 5 10 | h = torch.arange(sl).repeat(bs).reshape(bs, sl)[..., None] * torch.linspace(0.2, 1., d) 11 | mask = torch.ones(bs, sl, dtype=torch.bool) 12 | mask[:2, 6:] = False 13 | x = BaseModelOutput(h) 14 | for name, cls in _POOLERS.items(): 15 | pooler = cls() 16 | res = pooler(x, mask) 17 | assert res.shape == (bs, d), f"{name} returned wrong shape" 18 | 19 | # test HFTextEncoder 20 | @pytest.mark.parametrize("model_id", ["arampacha/roberta-tiny", "roberta-base", "xlm-roberta-base", "google/mt5-base"]) 21 | def test_pretrained_text_encoder(model_id): 22 | bs, sl, d = 2, 10, 64 23 | cfg = AutoConfig.from_pretrained(model_id) 24 | model = HFTextEncoder(model_id, d, proj_type='linear') 25 | x = torch.randint(0, cfg.vocab_size, (bs, sl)) 26 | with torch.no_grad(): 27 | emb = model(x) 28 | 29 | assert emb.shape == (bs, d) 30 | -------------------------------------------------------------------------------- /tests/test_inference.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import pytest 4 | import torch 5 | import open_clip 6 | import util_test 7 | 8 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 9 | 10 | if hasattr(torch._C, '_jit_set_profiling_executor'): 11 | # legacy executor is too slow to compile large models for unit tests 12 | # no need for the fusion performance here 13 | torch._C._jit_set_profiling_executor(True) 14 | torch._C._jit_set_profiling_mode(False) 15 | 16 | models_to_test = set(open_clip.list_models()) 17 | 18 | # testing excemptions 19 | models_to_test = models_to_test.difference({ 20 | # not available with timm yet 21 | # see https://github.com/mlfoundations/open_clip/issues/219 22 | 'convnext_xlarge', 23 | 'convnext_xxlarge', 24 | 'convnext_xxlarge_320', 25 | 'vit_medium_patch16_gap_256', 26 | # exceeds GH runner memory limit 27 | 'ViT-bigG-14', 28 | 'ViT-e-14', 29 | 'mt5-xl-ViT-H-14', 30 | 'coca_base', 31 | 'coca_ViT-B-32', 32 | 'coca_roberta-ViT-B-32' 33 | }) 34 | 35 | if 'OPEN_CLIP_TEST_REG_MODELS' in os.environ: 36 | 
external_model_list = os.environ['OPEN_CLIP_TEST_REG_MODELS'] 37 | with open(external_model_list, 'r') as f: 38 | models_to_test = set(f.read().splitlines()).intersection(models_to_test) 39 | print(f"Selected models from {external_model_list}: {models_to_test}") 40 | 41 | # TODO: add "coca_ViT-B-32" onece https://github.com/pytorch/pytorch/issues/92073 gets fixed 42 | models_to_test = list(models_to_test) 43 | models_to_test.sort() 44 | models_to_test = [(model_name, False) for model_name in models_to_test] 45 | 46 | models_to_jit_test = {"ViT-B-32"} 47 | models_to_jit_test = list(models_to_jit_test) 48 | models_to_jit_test = [(model_name, True) for model_name in models_to_jit_test] 49 | models_to_test_fully = models_to_test + models_to_jit_test 50 | 51 | 52 | @pytest.mark.regression_test 53 | @pytest.mark.parametrize("model_name,jit", models_to_test_fully) 54 | def test_inference_with_data( 55 | model_name, 56 | jit, 57 | pretrained = None, 58 | pretrained_hf = False, 59 | precision = 'fp32', 60 | force_quick_gelu = False, 61 | ): 62 | util_test.seed_all() 63 | model, _, preprocess_val = open_clip.create_model_and_transforms( 64 | model_name, 65 | pretrained = pretrained, 66 | precision = precision, 67 | jit = jit, 68 | force_quick_gelu = force_quick_gelu, 69 | pretrained_hf = pretrained_hf 70 | ) 71 | model_id = f'{model_name}_{pretrained or pretrained_hf}_{precision}' 72 | input_dir, output_dir = util_test.get_data_dirs() 73 | # text 74 | input_text_path = os.path.join(input_dir, 'random_text.pt') 75 | gt_text_path = os.path.join(output_dir, f'{model_id}_random_text.pt') 76 | if not os.path.isfile(input_text_path): 77 | pytest.skip(reason = f"missing test data, expected at {input_text_path}") 78 | if not os.path.isfile(gt_text_path): 79 | pytest.skip(reason = f"missing test data, expected at {gt_text_path}") 80 | input_text = torch.load(input_text_path) 81 | gt_text = torch.load(gt_text_path) 82 | y_text = util_test.inference_text(model, model_name, input_text) 83 | assert (y_text == gt_text).all(), f"text output differs @ {input_text_path}" 84 | # image 85 | image_size = model.visual.image_size 86 | if not isinstance(image_size, tuple): 87 | image_size = (image_size, image_size) 88 | input_image_path = os.path.join(input_dir, f'random_image_{image_size[0]}_{image_size[1]}.pt') 89 | gt_image_path = os.path.join(output_dir, f'{model_id}_random_image.pt') 90 | if not os.path.isfile(input_image_path): 91 | pytest.skip(reason = f"missing test data, expected at {input_image_path}") 92 | if not os.path.isfile(gt_image_path): 93 | pytest.skip(reason = f"missing test data, expected at {gt_image_path}") 94 | input_image = torch.load(input_image_path) 95 | gt_image = torch.load(gt_image_path) 96 | y_image = util_test.inference_image(model, preprocess_val, input_image) 97 | assert (y_image == gt_image).all(), f"image output differs @ {input_image_path}" 98 | 99 | if not jit: 100 | model.eval() 101 | model_out = util_test.forward_model(model, model_name, preprocess_val, input_image, input_text) 102 | if type(model) not in [open_clip.CLIP, open_clip.CustomTextCLIP]: 103 | assert type(model_out) == dict 104 | else: 105 | model.output_dict = True 106 | model_out_dict = util_test.forward_model(model, model_name, preprocess_val, input_image, input_text) 107 | assert (model_out_dict["image_features"] == model_out[0]).all() 108 | assert (model_out_dict["text_features"] == model_out[1]).all() 109 | assert (model_out_dict["logit_scale"] == model_out[2]).all() 110 | model.output_dict = None 111 | else: 112 | 
model, _, preprocess_val = open_clip.create_model_and_transforms( 113 | model_name, 114 | pretrained = pretrained, 115 | precision = precision, 116 | jit = False, 117 | force_quick_gelu = force_quick_gelu, 118 | pretrained_hf = pretrained_hf 119 | ) 120 | 121 | test_model = util_test.TestWrapper(model, model_name, output_dict=False) 122 | test_model = torch.jit.script(test_model) 123 | model_out = util_test.forward_model(test_model, model_name, preprocess_val, input_image, input_text) 124 | assert model_out["test_output"].shape[-1] == 2 125 | 126 | test_model = util_test.TestWrapper(model, model_name, output_dict=True) 127 | test_model = torch.jit.script(test_model) 128 | model_out = util_test.forward_model(test_model, model_name, preprocess_val, input_image, input_text) 129 | assert model_out["test_output"].shape[-1] == 2 130 | 131 | 132 | 133 | 134 | -------------------------------------------------------------------------------- /tests/test_inference_simple.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | from open_clip.factory import get_tokenizer 4 | import pytest 5 | import open_clip 6 | import os 7 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 8 | 9 | if hasattr(torch._C, '_jit_set_profiling_executor'): 10 | # legacy executor is too slow to compile large models for unit tests 11 | # no need for the fusion performance here 12 | torch._C._jit_set_profiling_executor(True) 13 | torch._C._jit_set_profiling_mode(False) 14 | 15 | 16 | test_simple_models = [ 17 | # model, pretrained, jit, force_custom_text 18 | ("ViT-B-32", "laion2b_s34b_b79k", False, False), 19 | ("ViT-B-32", "laion2b_s34b_b79k", True, False), 20 | ("ViT-B-32", "laion2b_s34b_b79k", True, True), 21 | ("roberta-ViT-B-32", "laion2b_s12b_b32k", False, False), 22 | ] 23 | 24 | 25 | @pytest.mark.parametrize("model_type,pretrained,jit,force_custom_text", test_simple_models) 26 | def test_inference_simple( 27 | model_type, 28 | pretrained, 29 | jit, 30 | force_custom_text, 31 | ): 32 | model, _, preprocess = open_clip.create_model_and_transforms( 33 | model_type, 34 | pretrained=pretrained, 35 | jit=jit, 36 | force_custom_text=force_custom_text, 37 | ) 38 | tokenizer = get_tokenizer(model_type) 39 | 40 | current_dir = os.path.dirname(os.path.realpath(__file__)) 41 | 42 | image = preprocess(Image.open(current_dir + "/../docs/CLIP.png")).unsqueeze(0) 43 | text = tokenizer(["a diagram", "a dog", "a cat"]) 44 | 45 | with torch.no_grad(): 46 | image_features = model.encode_image(image) 47 | text_features = model.encode_text(text) 48 | 49 | text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) 50 | 51 | assert text_probs.cpu().numpy()[0].tolist() == [1.0, 0.0, 0.0] 52 | -------------------------------------------------------------------------------- /tests/test_num_shards.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from training.data import get_dataset_size 4 | 5 | @pytest.mark.parametrize( 6 | "shards,expected_size", 7 | [ 8 | ('/path/to/shard.tar', 1), 9 | ('/path/to/shard_{000..000}.tar', 1), 10 | ('/path/to/shard_{000..009}.tar', 10), 11 | ('/path/to/shard_{000..009}_{000..009}.tar', 100), 12 | ('/path/to/shard.tar::/path/to/other_shard_{000..009}.tar', 11), 13 | ('/path/to/shard_{000..009}.tar::/path/to/other_shard_{000..009}.tar', 20), 14 | (['/path/to/shard.tar'], 1), 15 | (['/path/to/shard.tar', '/path/to/other_shard.tar'], 2), 16 | ] 17 | ) 18 | def 
test_num_shards(shards, expected_size): 19 | _, size = get_dataset_size(shards) 20 | assert size == expected_size, f'Expected {expected_size} for {shards} but found {size} instead.' 21 | -------------------------------------------------------------------------------- /tests/test_training_simple.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | import pytest 5 | from PIL import Image 6 | import torch 7 | from training.main import main 8 | 9 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 10 | 11 | if hasattr(torch._C, '_jit_set_profiling_executor'): 12 | # legacy executor is too slow to compile large models for unit tests 13 | # no need for the fusion performance here 14 | torch._C._jit_set_profiling_executor(True) 15 | torch._C._jit_set_profiling_mode(False) 16 | 17 | @pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals") 18 | def test_training(): 19 | main([ 20 | '--save-frequency', '1', 21 | '--zeroshot-frequency', '1', 22 | '--dataset-type', "synthetic", 23 | '--train-num-samples', '16', 24 | '--warmup', '1', 25 | '--batch-size', '4', 26 | '--lr', '1e-3', 27 | '--wd', '0.1', 28 | '--epochs', '1', 29 | '--workers', '2', 30 | '--model', 'RN50' 31 | ]) 32 | 33 | @pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals") 34 | def test_training_coca(): 35 | main([ 36 | '--save-frequency', '1', 37 | '--zeroshot-frequency', '1', 38 | '--dataset-type', "synthetic", 39 | '--train-num-samples', '16', 40 | '--warmup', '1', 41 | '--batch-size', '4', 42 | '--lr', '1e-3', 43 | '--wd', '0.1', 44 | '--epochs', '1', 45 | '--workers', '2', 46 | '--model', 'coca_ViT-B-32' 47 | ]) 48 | 49 | @pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals") 50 | def test_training_mt5(): 51 | main([ 52 | '--save-frequency', '1', 53 | '--zeroshot-frequency', '1', 54 | '--dataset-type', "synthetic", 55 | '--train-num-samples', '16', 56 | '--warmup', '1', 57 | '--batch-size', '4', 58 | '--lr', '1e-3', 59 | '--wd', '0.1', 60 | '--epochs', '1', 61 | '--workers', '2', 62 | '--model', 'mt5-base-ViT-B-32', 63 | '--lock-text', 64 | '--lock-text-unlocked-layers', '2' 65 | ]) 66 | 67 | 68 | 69 | @pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals") 70 | def test_training_unfreezing_vit(): 71 | main([ 72 | '--save-frequency', '1', 73 | '--zeroshot-frequency', '1', 74 | '--dataset-type', "synthetic", 75 | '--train-num-samples', '16', 76 | '--warmup', '1', 77 | '--batch-size', '4', 78 | '--lr', '1e-3', 79 | '--wd', '0.1', 80 | '--epochs', '1', 81 | '--workers', '2', 82 | '--model', 'ViT-B-32', 83 | '--lock-image', 84 | '--lock-image-unlocked-groups', '5', 85 | '--accum-freq', '2' 86 | ]) 87 | 88 | 89 | @pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals") 90 | def test_training_clip_with_jit(): 91 | main([ 92 | '--save-frequency', '1', 93 | '--zeroshot-frequency', '1', 94 | '--dataset-type', "synthetic", 95 | '--train-num-samples', '16', 96 | '--warmup', '1', 97 | '--batch-size', '4', 98 | '--lr', '1e-3', 99 | '--wd', '0.1', 100 | '--epochs', '1', 101 | '--workers', '2', 102 | '--model', 'ViT-B-32', 103 | '--torchscript' 104 | ]) 105 | -------------------------------------------------------------------------------- /tests/test_wds.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | import util_test 4 | import 
collections 5 | import tarfile 6 | import io 7 | from PIL import Image 8 | 9 | from training.data import get_wds_dataset 10 | from training.params import parse_args 11 | from training.main import random_seed 12 | 13 | TRAIN_NUM_SAMPLES = 10_000 14 | RTOL = 0.2 15 | 16 | # NOTE: we use two test tar files, which are created on the fly and saved to data/input. 17 | # 000.tar has 10 samples, and the captions are 000_0, 000_1, ..., 000_9 18 | # 001.tar has 5 samples, and the captions are 001_0, 001_1, ..., 001_4 19 | def build_inputs(test_name): 20 | base_input_dir, _ = util_test.get_data_dirs() 21 | input_dir = os.path.join(base_input_dir, test_name) 22 | os.makedirs(input_dir, exist_ok=True) 23 | 24 | def save_tar(idx, num_samples): 25 | filename = os.path.join(input_dir, f'test_data_{idx:03d}.tar') 26 | tar = tarfile.open(filename, 'w') 27 | 28 | for sample_idx in range(num_samples): 29 | # Image 30 | image = Image.new('RGB', (32, 32)) 31 | info = tarfile.TarInfo(f'{sample_idx}.png') 32 | bio = io.BytesIO() 33 | image.save(bio, format='png') 34 | size = bio.tell() 35 | bio.seek(0) 36 | info.size = size 37 | tar.addfile(info, bio) 38 | 39 | # Caption 40 | info = tarfile.TarInfo(f'{sample_idx}.txt') 41 | bio = io.BytesIO() 42 | bio.write(f'{idx:03d}_{sample_idx}'.encode('utf-8')) 43 | size = bio.tell() 44 | bio.seek(0) 45 | info.size = size 46 | tar.addfile(info, bio) 47 | 48 | tar.close() 49 | 50 | save_tar(0, 10) 51 | save_tar(1, 5) 52 | 53 | return input_dir 54 | 55 | 56 | def build_params(input_shards, seed=0): 57 | args = parse_args([]) 58 | args.train_data = input_shards 59 | args.train_num_samples = TRAIN_NUM_SAMPLES 60 | args.dataset_resampled = True 61 | args.seed = seed 62 | args.workers = 1 63 | args.world_size = 1 64 | args.batch_size = 1 65 | random_seed(seed) 66 | 67 | preprocess_img = lambda x: x 68 | tokenizer = lambda x: [x.strip()] 69 | 70 | return args, preprocess_img, tokenizer 71 | 72 | 73 | def get_dataloader(input_shards): 74 | args, preprocess_img, tokenizer = build_params(input_shards) 75 | dataset = get_wds_dataset(args, preprocess_img, is_train=True, tokenizer=tokenizer) 76 | dataloader = dataset.dataloader 77 | return dataloader 78 | 79 | 80 | def test_single_source(): 81 | """Test webdataset with a single tar file.""" 82 | input_dir = build_inputs('single_source') 83 | input_shards = os.path.join(input_dir, 'test_data_000.tar') 84 | dataloader = get_dataloader(input_shards) 85 | 86 | counts = collections.defaultdict(int) 87 | for sample in dataloader: 88 | txts = sample[1] 89 | for txt in txts: 90 | counts[txt] += 1 91 | 92 | for key, count in counts.items(): 93 | assert count == pytest.approx(TRAIN_NUM_SAMPLES / 10, RTOL) 94 | 95 | 96 | def test_two_sources(): 97 | """Test webdataset with a single two tar files.""" 98 | input_dir = build_inputs('two_sources') 99 | input_shards = os.path.join(input_dir, 'test_data_{000..001}.tar') 100 | dataloader = get_dataloader(input_shards) 101 | 102 | counts = collections.defaultdict(int) 103 | for sample in dataloader: 104 | txts = sample[1] 105 | for txt in txts: 106 | counts[txt] += 1 107 | 108 | for key, count in counts.items(): 109 | assert count == pytest.approx(TRAIN_NUM_SAMPLES / 15, RTOL), f'{key}, {count}' 110 | 111 | 112 | def test_two_sources_same_weights(): 113 | """Test webdataset with a two tar files, using --train-data-weights=1::1.""" 114 | input_dir = build_inputs('two_sources_same_weights') 115 | input_shards = f"{os.path.join(input_dir, 'test_data_000.tar')}::{os.path.join(input_dir, 'test_data_001.tar')}" 
116 | args, preprocess_img, tokenizer = build_params(input_shards) 117 | args.train_data_upsampling_factors = '1::1' 118 | dataset = get_wds_dataset(args, preprocess_img, is_train=True, tokenizer=tokenizer) 119 | dataloader = dataset.dataloader 120 | 121 | counts = collections.defaultdict(int) 122 | for sample in dataloader: 123 | txts = sample[1] 124 | for txt in txts: 125 | counts[txt] += 1 126 | 127 | for key, count in counts.items(): 128 | assert count == pytest.approx(TRAIN_NUM_SAMPLES / 15, RTOL), f'{key}, {count}' 129 | 130 | def test_two_sources_with_upsampling(): 131 | """Test webdataset with a two tar files with upsampling.""" 132 | input_dir = build_inputs('two_sources_with_upsampling') 133 | input_shards = f"{os.path.join(input_dir, 'test_data_000.tar')}::{os.path.join(input_dir, 'test_data_001.tar')}" 134 | args, preprocess_img, tokenizer = build_params(input_shards) 135 | args.train_data_upsampling_factors = '1::2' 136 | dataset = get_wds_dataset(args, preprocess_img, is_train=True, tokenizer=tokenizer) 137 | dataloader = dataset.dataloader 138 | 139 | counts = collections.defaultdict(int) 140 | for sample in dataloader: 141 | txts = sample[1] 142 | for txt in txts: 143 | counts[txt] += 1 144 | 145 | for key, count in counts.items(): 146 | if key.startswith('000'): 147 | assert count == pytest.approx(TRAIN_NUM_SAMPLES / 20, RTOL), f'{key}, {count}' 148 | else: 149 | assert count == pytest.approx(TRAIN_NUM_SAMPLES / 10, RTOL), f'{key}, {count}' 150 | -------------------------------------------------------------------------------- /tests/util_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | from PIL import Image 5 | import torch 6 | 7 | if __name__ != '__main__': 8 | import open_clip 9 | 10 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 11 | 12 | def seed_all(seed = 0): 13 | torch.backends.cudnn.deterministic = True 14 | torch.backends.cudnn.benchmark = False 15 | torch.use_deterministic_algorithms(True, warn_only=False) 16 | random.seed(seed) 17 | np.random.seed(seed) 18 | torch.manual_seed(seed) 19 | 20 | def inference_text(model, model_name, batches): 21 | y = [] 22 | tokenizer = open_clip.get_tokenizer(model_name) 23 | with torch.no_grad(): 24 | for x in batches: 25 | x = tokenizer(x) 26 | y.append(model.encode_text(x)) 27 | return torch.stack(y) 28 | 29 | def inference_image(model, preprocess_val, batches): 30 | y = [] 31 | with torch.no_grad(): 32 | for x in batches: 33 | x = torch.stack([preprocess_val(img) for img in x]) 34 | y.append(model.encode_image(x)) 35 | return torch.stack(y) 36 | 37 | def forward_model(model, model_name, preprocess_val, image_batch, text_batch): 38 | y = [] 39 | tokenizer = open_clip.get_tokenizer(model_name) 40 | with torch.no_grad(): 41 | for x_im, x_txt in zip(image_batch, text_batch): 42 | x_im = torch.stack([preprocess_val(im) for im in x_im]) 43 | x_txt = tokenizer(x_txt) 44 | y.append(model(x_im, x_txt)) 45 | if type(y[0]) == dict: 46 | out = {} 47 | for key in y[0].keys(): 48 | out[key] = torch.stack([batch_out[key] for batch_out in y]) 49 | else: 50 | out = [] 51 | for i in range(len(y[0])): 52 | out.append(torch.stack([batch_out[i] for batch_out in y])) 53 | return out 54 | 55 | def random_image_batch(batch_size, size): 56 | h, w = size 57 | data = np.random.randint(255, size = (batch_size, h, w, 3), dtype = np.uint8) 58 | return [ Image.fromarray(d) for d in data ] 59 | 60 | def random_text_batch(batch_size, min_length = 75, max_length = 
75): 61 | t = open_clip.tokenizer.SimpleTokenizer() 62 | # every token decoded as string, exclude SOT and EOT, replace EOW with space 63 | token_words = [ 64 | x[1].replace('', ' ') 65 | for x in t.decoder.items() 66 | if x[0] not in t.all_special_ids 67 | ] 68 | # strings of randomly chosen tokens 69 | return [ 70 | ''.join(random.choices( 71 | token_words, 72 | k = random.randint(min_length, max_length) 73 | )) 74 | for _ in range(batch_size) 75 | ] 76 | 77 | def create_random_text_data( 78 | path, 79 | min_length = 75, 80 | max_length = 75, 81 | batches = 1, 82 | batch_size = 1 83 | ): 84 | text_batches = [ 85 | random_text_batch(batch_size, min_length, max_length) 86 | for _ in range(batches) 87 | ] 88 | print(f"{path}") 89 | torch.save(text_batches, path) 90 | 91 | def create_random_image_data(path, size, batches = 1, batch_size = 1): 92 | image_batches = [ 93 | random_image_batch(batch_size, size) 94 | for _ in range(batches) 95 | ] 96 | print(f"{path}") 97 | torch.save(image_batches, path) 98 | 99 | def get_data_dirs(make_dir = True): 100 | data_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data') 101 | input_dir = os.path.join(data_dir, 'input') 102 | output_dir = os.path.join(data_dir, 'output') 103 | if make_dir: 104 | os.makedirs(input_dir, exist_ok = True) 105 | os.makedirs(output_dir, exist_ok = True) 106 | assert os.path.isdir(data_dir), f"data directory missing, expected at {input_dir}" 107 | assert os.path.isdir(data_dir), f"data directory missing, expected at {output_dir}" 108 | return input_dir, output_dir 109 | 110 | def create_test_data_for_model( 111 | model_name, 112 | pretrained = None, 113 | precision = 'fp32', 114 | jit = False, 115 | pretrained_hf = False, 116 | force_quick_gelu = False, 117 | create_missing_input_data = True, 118 | batches = 1, 119 | batch_size = 1, 120 | overwrite = False 121 | ): 122 | model_id = f'{model_name}_{pretrained or pretrained_hf}_{precision}' 123 | input_dir, output_dir = get_data_dirs() 124 | output_file_text = os.path.join(output_dir, f'{model_id}_random_text.pt') 125 | output_file_image = os.path.join(output_dir, f'{model_id}_random_image.pt') 126 | text_exists = os.path.exists(output_file_text) 127 | image_exists = os.path.exists(output_file_image) 128 | if not overwrite and text_exists and image_exists: 129 | return 130 | seed_all() 131 | model, _, preprocess_val = open_clip.create_model_and_transforms( 132 | model_name, 133 | pretrained = pretrained, 134 | precision = precision, 135 | jit = jit, 136 | force_quick_gelu = force_quick_gelu, 137 | pretrained_hf = pretrained_hf 138 | ) 139 | # text 140 | if overwrite or not text_exists: 141 | input_file_text = os.path.join(input_dir, 'random_text.pt') 142 | if create_missing_input_data and not os.path.exists(input_file_text): 143 | create_random_text_data( 144 | input_file_text, 145 | batches = batches, 146 | batch_size = batch_size 147 | ) 148 | assert os.path.isfile(input_file_text), f"missing input data, expected at {input_file_text}" 149 | input_data_text = torch.load(input_file_text) 150 | output_data_text = inference_text(model, model_name, input_data_text) 151 | print(f"{output_file_text}") 152 | torch.save(output_data_text, output_file_text) 153 | # image 154 | if overwrite or not image_exists: 155 | size = model.visual.image_size 156 | if not isinstance(size, tuple): 157 | size = (size, size) 158 | input_file_image = os.path.join(input_dir, f'random_image_{size[0]}_{size[1]}.pt') 159 | if create_missing_input_data and not os.path.exists(input_file_image): 
160 | create_random_image_data( 161 | input_file_image, 162 | size, 163 | batches = batches, 164 | batch_size = batch_size 165 | ) 166 | assert os.path.isfile(input_file_image), f"missing input data, expected at {input_file_image}" 167 | input_data_image = torch.load(input_file_image) 168 | output_data_image = inference_image(model, preprocess_val, input_data_image) 169 | print(f"{output_file_image}") 170 | torch.save(output_data_image, output_file_image) 171 | 172 | def create_test_data( 173 | models, 174 | batches = 1, 175 | batch_size = 1, 176 | overwrite = False 177 | ): 178 | models = list(set(models).difference({ 179 | # not available with timm 180 | # see https://github.com/mlfoundations/open_clip/issues/219 181 | 'timm-convnext_xlarge', 182 | 'timm-vit_medium_patch16_gap_256' 183 | }).intersection(open_clip.list_models())) 184 | models.sort() 185 | print(f"generating test data for:\n{models}") 186 | for model_name in models: 187 | print(model_name) 188 | create_test_data_for_model( 189 | model_name, 190 | batches = batches, 191 | batch_size = batch_size, 192 | overwrite = overwrite 193 | ) 194 | return models 195 | 196 | def _sytem_assert(string): 197 | assert os.system(string) == 0 198 | 199 | class TestWrapper(torch.nn.Module): 200 | output_dict: torch.jit.Final[bool] 201 | def __init__(self, model, model_name, output_dict=True) -> None: 202 | super().__init__() 203 | self.model = model 204 | self.output_dict = output_dict 205 | if type(model) in [open_clip.CLIP, open_clip.CustomTextCLIP]: 206 | self.model.output_dict = self.output_dict 207 | config = open_clip.get_model_config(model_name) 208 | self.head = torch.nn.Linear(config["embed_dim"], 2) 209 | 210 | def forward(self, image, text): 211 | x = self.model(image, text) 212 | x = x['image_features'] if self.output_dict else x[0] 213 | assert x is not None # remove Optional[], type refinement for torchscript 214 | out = self.head(x) 215 | return {"test_output": out} 216 | 217 | def main(args): 218 | global open_clip 219 | import importlib 220 | import shutil 221 | import subprocess 222 | import argparse 223 | parser = argparse.ArgumentParser(description = "Populate test data directory") 224 | parser.add_argument( 225 | '-a', '--all', 226 | action = 'store_true', 227 | help = "create test data for all models" 228 | ) 229 | parser.add_argument( 230 | '-m', '--model', 231 | type = str, 232 | default = [], 233 | nargs = '+', 234 | help = "model(s) to create test data for" 235 | ) 236 | parser.add_argument( 237 | '-f', '--model_list', 238 | type = str, 239 | help = "path to a text file containing a list of model names, one model per line" 240 | ) 241 | parser.add_argument( 242 | '-s', '--save_model_list', 243 | type = str, 244 | help = "path to save the list of models that data was generated for" 245 | ) 246 | parser.add_argument( 247 | '-g', '--git_revision', 248 | type = str, 249 | help = "git revision to generate test data for" 250 | ) 251 | parser.add_argument( 252 | '--overwrite', 253 | action = 'store_true', 254 | help = "overwrite existing output data" 255 | ) 256 | parser.add_argument( 257 | '-n', '--num_batches', 258 | default = 1, 259 | type = int, 260 | help = "amount of data batches to create (default: 1)" 261 | ) 262 | parser.add_argument( 263 | '-b', '--batch_size', 264 | default = 1, 265 | type = int, 266 | help = "test data batch size (default: 1)" 267 | ) 268 | args = parser.parse_args(args) 269 | model_list = [] 270 | if args.model_list is not None: 271 | with open(args.model_list, 'r') as f: 272 | model_list = 
f.read().splitlines() 273 | if not args.all and len(args.model) < 1 and len(model_list) < 1: 274 | print("error: at least one model name is required") 275 | parser.print_help() 276 | parser.exit(1) 277 | if args.git_revision is not None: 278 | stash_output = subprocess.check_output(['git', 'stash']).decode().splitlines() 279 | has_stash = len(stash_output) > 0 and stash_output[0] != 'No local changes to save' 280 | current_branch = subprocess.check_output(['git', 'branch', '--show-current']) 281 | if len(current_branch) < 1: 282 | # not on a branch -> detached head 283 | current_branch = subprocess.check_output(['git', 'rev-parse', 'HEAD']) 284 | current_branch = current_branch.splitlines()[0].decode() 285 | try: 286 | _sytem_assert(f'git checkout {args.git_revision}') 287 | except AssertionError as e: 288 | _sytem_assert(f'git checkout -f {current_branch}') 289 | if has_stash: 290 | os.system(f'git stash pop') 291 | raise e 292 | open_clip = importlib.import_module('open_clip') 293 | models = open_clip.list_models() if args.all else args.model + model_list 294 | try: 295 | models = create_test_data( 296 | models, 297 | batches = args.num_batches, 298 | batch_size = args.batch_size, 299 | overwrite = args.overwrite 300 | ) 301 | finally: 302 | if args.git_revision is not None: 303 | test_dir = os.path.join(os.path.dirname(__file__), 'data') 304 | test_dir_ref = os.path.join(os.path.dirname(__file__), 'data_ref') 305 | if os.path.exists(test_dir_ref): 306 | shutil.rmtree(test_dir_ref, ignore_errors = True) 307 | if os.path.exists(test_dir): 308 | os.rename(test_dir, test_dir_ref) 309 | _sytem_assert(f'git checkout {current_branch}') 310 | if has_stash: 311 | os.system(f'git stash pop') 312 | os.rename(test_dir_ref, test_dir) 313 | if args.save_model_list is not None: 314 | print(f"Saving model list as {args.save_model_list}") 315 | with open(args.save_model_list, 'w') as f: 316 | for m in models: 317 | print(m, file=f) 318 | 319 | 320 | if __name__ == '__main__': 321 | import sys 322 | main(sys.argv[1:]) 323 | 324 | --------------------------------------------------------------------------------