├── .DS_Store ├── LICENSE ├── README.md ├── docs ├── .DS_Store ├── README.md ├── index.html ├── resources │ ├── .DS_Store │ ├── LLaVA.png │ ├── ar.svg │ ├── block_mask.jpg │ ├── first_mask.jpg │ ├── gr.svg │ ├── hf_dataset.jpg │ ├── hg.svg │ ├── icon.png │ ├── mask_strategy.jpg │ ├── method.jpg │ ├── random_mask.jpg │ ├── retrieval.png │ ├── sota.png │ ├── subcaption_mask.jpg │ └── tw.png └── static │ ├── css │ ├── bulma-carousel.min.css │ ├── bulma-slider.min.css │ ├── bulma.css.map.txt │ ├── bulma.min.css │ ├── fontawesome.all.min.css │ └── index.css │ └── js │ ├── bulma-carousel.js │ ├── bulma-carousel.min.js │ ├── bulma-slider.js │ ├── bulma-slider.min.js │ ├── fontawesome.all.min.js │ └── index.js ├── inference.py ├── open_clip ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-311.pyc │ ├── coca_model.cpython-311.pyc │ ├── constants.cpython-311.pyc │ ├── convert.cpython-311.pyc │ ├── factory.cpython-311.pyc │ ├── hf_configs.cpython-311.pyc │ ├── hf_model.cpython-311.pyc │ ├── loss.cpython-311.pyc │ ├── model.cpython-311.pyc │ ├── modified_resnet.cpython-311.pyc │ ├── openai.cpython-311.pyc │ ├── pos_embed.cpython-311.pyc │ ├── pretrained.cpython-311.pyc │ ├── push_to_hf_hub.cpython-311.pyc │ ├── timm_model.cpython-311.pyc │ ├── tokenizer.cpython-311.pyc │ ├── transform.cpython-311.pyc │ ├── transformer.cpython-311.pyc │ ├── utils.cpython-311.pyc │ ├── version.cpython-311.pyc │ ├── zero_shot_classifier.cpython-311.pyc │ └── zero_shot_metadata.cpython-311.pyc ├── bpe_simple_vocab_16e6.txt.gz ├── coca_model.py ├── constants.py ├── convert.py ├── factory.py ├── hf_configs.py ├── hf_model.py ├── loss.py ├── model.py ├── model_configs │ ├── EVA01-g-14-plus.json │ ├── EVA01-g-14.json │ ├── EVA02-B-16.json │ ├── EVA02-E-14-plus.json │ ├── EVA02-E-14.json │ ├── EVA02-L-14-336.json │ ├── EVA02-L-14.json │ ├── MobileCLIP-B.json │ ├── MobileCLIP-S1.json │ ├── MobileCLIP-S2.json │ ├── RN101-quickgelu.json │ ├── RN101.json │ ├── RN50-quickgelu.json │ ├── RN50.json │ ├── RN50x16-quickgelu.json │ ├── RN50x16.json │ ├── RN50x4-quickgelu.json │ ├── RN50x4.json │ ├── RN50x64-quickgelu.json │ ├── RN50x64.json │ ├── ViT-B-16-SigLIP-256.json │ ├── ViT-B-16-SigLIP-384.json │ ├── ViT-B-16-SigLIP-512.json │ ├── ViT-B-16-SigLIP-i18n-256.json │ ├── ViT-B-16-SigLIP.json │ ├── ViT-B-16-SigLIP2-256.json │ ├── ViT-B-16-SigLIP2-384.json │ ├── ViT-B-16-SigLIP2-512.json │ ├── ViT-B-16-SigLIP2.json │ ├── ViT-B-16-plus-240.json │ ├── ViT-B-16-plus.json │ ├── ViT-B-16-quickgelu.json │ ├── ViT-B-16.json │ ├── ViT-B-32-256.json │ ├── ViT-B-32-SigLIP2-256.json │ ├── ViT-B-32-plus-256.json │ ├── ViT-B-32-quickgelu.json │ ├── ViT-B-32.json │ ├── ViT-H-14-378-quickgelu.json │ ├── ViT-H-14-378.json │ ├── ViT-H-14-CLIPA-336.json │ ├── ViT-H-14-CLIPA.json │ ├── ViT-H-14-CLIPS-224.json │ ├── ViT-H-14-quickgelu.json │ ├── ViT-H-14.json │ ├── ViT-H-16.json │ ├── ViT-L-14-280.json │ ├── ViT-L-14-336-quickgelu.json │ ├── ViT-L-14-336.json │ ├── ViT-L-14-CLIPA-336.json │ ├── ViT-L-14-CLIPA.json │ ├── ViT-L-14-CLIPS-224.json │ ├── ViT-L-14-CLIPS-336.json │ ├── ViT-L-14-quickgelu.json │ ├── ViT-L-14.json │ ├── ViT-L-16-320.json │ ├── ViT-L-16-SigLIP-256.json │ ├── ViT-L-16-SigLIP-384.json │ ├── ViT-L-16-SigLIP2-256.json │ ├── ViT-L-16-SigLIP2-384.json │ ├── ViT-L-16-SigLIP2-512.json │ ├── ViT-L-16.json │ ├── ViT-M-16-alt.json │ ├── ViT-M-16.json │ ├── ViT-M-32-alt.json │ ├── ViT-M-32.json │ ├── ViT-S-16-alt.json │ ├── ViT-S-16.json │ ├── ViT-S-32-alt.json │ ├── ViT-S-32.json │ ├── ViT-SO400M-14-SigLIP-378.json │ ├── 
ViT-SO400M-14-SigLIP-384.json │ ├── ViT-SO400M-14-SigLIP.json │ ├── ViT-SO400M-14-SigLIP2-378.json │ ├── ViT-SO400M-14-SigLIP2.json │ ├── ViT-SO400M-16-SigLIP-i18n-256.json │ ├── ViT-SO400M-16-SigLIP2-256.json │ ├── ViT-SO400M-16-SigLIP2-384.json │ ├── ViT-SO400M-16-SigLIP2-512.json │ ├── ViT-bigG-14-CLIPA-336.json │ ├── ViT-bigG-14-CLIPA.json │ ├── ViT-bigG-14-quickgelu.json │ ├── ViT-bigG-14.json │ ├── ViT-e-14.json │ ├── ViT-g-14.json │ ├── ViT-gopt-16-SigLIP2-256.json │ ├── ViT-gopt-16-SigLIP2-384.json │ ├── ViTamin-B-LTT.json │ ├── ViTamin-B.json │ ├── ViTamin-L-256.json │ ├── ViTamin-L-336.json │ ├── ViTamin-L-384.json │ ├── ViTamin-L.json │ ├── ViTamin-L2-256.json │ ├── ViTamin-L2-336.json │ ├── ViTamin-L2-384.json │ ├── ViTamin-L2.json │ ├── ViTamin-S-LTT.json │ ├── ViTamin-S.json │ ├── ViTamin-XL-256.json │ ├── ViTamin-XL-336.json │ ├── ViTamin-XL-384.json │ ├── coca_ViT-B-32.json │ ├── coca_ViT-L-14.json │ ├── coca_base.json │ ├── coca_roberta-ViT-B-32.json │ ├── convnext_base.json │ ├── convnext_base_w.json │ ├── convnext_base_w_320.json │ ├── convnext_large.json │ ├── convnext_large_d.json │ ├── convnext_large_d_320.json │ ├── convnext_small.json │ ├── convnext_tiny.json │ ├── convnext_xlarge.json │ ├── convnext_xxlarge.json │ ├── convnext_xxlarge_320.json │ ├── mt5-base-ViT-B-32.json │ ├── mt5-xl-ViT-H-14.json │ ├── nllb-clip-base-siglip.json │ ├── nllb-clip-base.json │ ├── nllb-clip-large-siglip.json │ ├── nllb-clip-large.json │ ├── roberta-ViT-B-32.json │ ├── swin_base_patch4_window7_224.json │ ├── vit_medium_patch16_gap_256.json │ ├── vit_relpos_medium_patch16_cls_224.json │ ├── xlm-roberta-base-ViT-B-32.json │ └── xlm-roberta-large-ViT-H-14.json ├── modified_resnet.py ├── openai.py ├── pos_embed.py ├── pretrained.py ├── push_to_hf_hub.py ├── timm_model.py ├── tokenizer.py ├── transform.py ├── transformer.py ├── utils.py ├── version.py ├── zero_shot_classifier.py └── zero_shot_metadata.py ├── open_clip_train ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-311.pyc │ ├── data.cpython-311.pyc │ ├── distributed.cpython-311.pyc │ ├── file_utils.cpython-311.pyc │ ├── logger.cpython-311.pyc │ ├── params.cpython-311.pyc │ ├── precision.cpython-311.pyc │ ├── scheduler.cpython-311.pyc │ ├── train.cpython-311.pyc │ └── zero_shot.cpython-311.pyc ├── data.py ├── distributed.py ├── file_utils.py ├── logger.py ├── main.py ├── params.py ├── precision.py ├── profiler.py ├── scheduler.py ├── train.py └── zero_shot.py ├── requirements.txt └── vocab.txt /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/.DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 VLAA@UCSC 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # **CLIPS** 2 | 3 | **Official implementation of the paper "[_CLIPS: An Enhanced CLIP Framework for Learning with Synthetic Captions_](https://arxiv.org/abs/2411.16828)".** 4 | 5 | 6 | ![Method Pipeline](./docs/resources/method.jpg) 7 | 8 | Previous work shows that noisy, web-crawled image-text pairs may limit vision-language pretraining frameworks such as CLIP, and proposes learning with synthetic captions as a promising alternative. Our work continues this effort, introducing two simple yet effective designs to better leverage richly described synthetic captions: 9 | 10 | 1. Observing a strong inverse effect with synthetic captions, we feed only **partial synthetic captions** to the text encoder, which yields significantly better performance. 11 | 2. We incorporate an **autoregressive captioner** that mimics the recaptioning process, predicting full-length synthetic captions conditioned on the image and the original web-crawled captions. 12 | 13 | Our method achieves **state-of-the-art (SOTA)** results in zero-shot image-text retrieval on MSCOCO and Flickr30K, while enhancing the visual capability of LLaVA. 14 | 15 | --- 16 | 17 | ## **Links** 18 | - [📄 Paper (arXiv)](https://arxiv.org/abs/2411.16828) 19 | - [🤗 Pretrained Model on HuggingFace](https://huggingface.co/UCSC-VLAA/ViT-L-14-CLIPS-Recap-DataComp-1B) 20 | - [🌐 Project Page](https://ucsc-vlaa.github.io/CLIPS/) 21 | 22 | --- 23 | 24 | ## **Key Results** 25 | 26 | ### **Inverse Effect with Synthetic Captions** 27 | ![Inverse Effect Visualization](./docs/resources/mask_strategy.jpg) 28 | 29 | Visualization of four different token-reduction strategies. Each improves the model's learning efficiency on synthetic captions to a varying degree; among them, the sub-caption and block-mask strategies perform best (an illustrative sketch follows at the end of this Key Results section). 30 | 31 | --- 32 | 33 | ### **Zero-Shot Cross-Modal Retrieval** 34 | ![Zero-Shot Retrieval Results](./docs/resources/retrieval.png) 35 | 36 | Our method consistently achieves superior performance across all benchmarks and model sizes, yielding significant improvements over the baselines. 37 | 38 | --- 39 | 40 | ### **Comparison with State-of-the-Art Methods** 41 | ![SOTA Comparison](./docs/resources/sota.png) 42 | 43 | With increased computational resources and scaling, our best model further achieves 76.4% and 96.6% R@1 text-retrieval performance on MSCOCO and Flickr30K, respectively, and 57.2% and 83.9% R@1 image-retrieval performance on the same datasets, setting new state-of-the-art (SOTA) results. 44 | 45 | --- 46 | 47 | ### **CLIPS in LLaVA** 48 | ![LLaVA Results](./docs/resources/LLaVA.png) 49 | 50 | Replacing OpenAI-CLIP with **CLIPS** significantly boosts LLaVA's performance across various benchmarks.
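Below is a minimal, illustrative sketch of the sub-caption and block-mask token-reduction ideas described above. The helper names, the sentence-level splitting heuristic, and the `num_keep` parameter are our own simplifications for exposition; the actual training pipeline lives in `open_clip_train` and may differ in detail.

```python
import random
from typing import List


def sample_subcaption(synthetic_caption: str) -> str:
    """Sub-caption strategy (sketch): keep one randomly chosen sentence
    of the full synthetic caption instead of the whole text."""
    sentences = [s.strip() for s in synthetic_caption.split('.') if s.strip()]
    if not sentences:
        return synthetic_caption
    return random.choice(sentences) + '.'


def sample_token_block(token_ids: List[int], num_keep: int) -> List[int]:
    """Block-mask strategy (sketch): keep one random contiguous block of
    `num_keep` tokens from the tokenized synthetic caption."""
    if len(token_ids) <= num_keep:
        return token_ids
    start = random.randint(0, len(token_ids) - num_keep)
    return token_ids[start:start + num_keep]
```

Either reduced caption is then tokenized (or truncated) as usual and fed to the text encoder in place of the full synthetic caption.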
51 | 52 | --- 53 | 54 | ## **Model Zoo** 55 | 56 | | Model | Link | 57 | |----------------|------------------------------------------------------------------------------------------| 58 | | CLIPS-Large-14-224 | [🤗 HuggingFace Model](https://huggingface.co/UCSC-VLAA/ViT-L-14-CLIPS-224-Recap-DataComp-1B) | 59 | | CLIPS-Large-14-336 | [🤗 HuggingFace Model](https://huggingface.co/UCSC-VLAA/ViT-L-14-CLIPS-Recap-DataComp-1B) | 60 | | CLIPS-Huge-14-224 | [🤗 HuggingFace Model](https://huggingface.co/UCSC-VLAA/ViT-H-14-CLIPS-224-Recap-DataComp-1B) | 61 | | CLIPS-Huge-14-336 | Coming Soon... | 62 | 63 | ## **Model Usage** 64 | ### **Environment** 65 | Install dependencies: 66 | ``` 67 | pip3 install -r requirements.txt 68 | ``` 69 | ### **With OpenCLIP** 70 | ```python 71 | import torch 72 | import torch.nn.functional as F 73 | from urllib.request import urlopen 74 | from PIL import Image 75 | from open_clip import create_model_from_pretrained, get_tokenizer 76 | 77 | model, preprocess = create_model_from_pretrained('hf-hub:UCSC-VLAA/ViT-L-14-CLIPS-Recap-DataComp-1B') 78 | tokenizer = get_tokenizer('hf-hub:UCSC-VLAA/ViT-L-14-CLIPS-Recap-DataComp-1B') 79 | 80 | image = Image.open(urlopen( 81 | 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png' 82 | )) 83 | image = preprocess(image).unsqueeze(0) 84 | 85 | text = tokenizer(["a diagram", "a dog", "a cat", "a beignet"], context_length=model.context_length) 86 | 87 | with torch.no_grad(), torch.cuda.amp.autocast(): 88 | image_features = model.encode_image(image) 89 | text_features = model.encode_text(text) 90 | image_features = F.normalize(image_features, dim=-1) 91 | text_features = F.normalize(text_features, dim=-1) 92 | 93 | text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) 94 | 95 | print("Label probs:", text_probs) # prints: [[0., 0., 0., 1.0]] 96 | ``` 97 | #### Note: We made modifications to the tokenizer implementation in `open_clip/tokenizer.py`. 98 | 99 | ## Acknowledgement 100 | 101 | This PyTorch repo is built on [OpenCLIP](https://github.com/mlfoundations/open_clip). 102 | Many thanks for the awesome work from the open-source community! 103 | 104 | We would like to thank the TPU Research Cloud (TRC) program, the Google Cloud Research Credits program, and the AWS Cloud Credit for Research program for supporting our computing needs.
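Beyond the retrieval-style example above, the bundled `open_clip` package also exports zero-shot classification helpers (see `open_clip/__init__.py`). The following is a hedged sketch of how they could be combined with a CLIPS checkpoint; it reuses the `model`, `tokenizer`, and preprocessed `image` from the snippet above, and the device/precision choices are illustrative rather than author recommendations.

```python
import torch
from open_clip import (
    build_zero_shot_classifier,
    IMAGENET_CLASSNAMES,
    OPENAI_IMAGENET_TEMPLATES,
)

# Assumes `model`, `tokenizer`, and a preprocessed image batch `image` of shape
# (N, 3, H, W) were created as in the usage snippet above.
with torch.no_grad(), torch.cuda.amp.autocast():
    # Build ImageNet classifier weights by encoding prompt-templated class names.
    classifier = build_zero_shot_classifier(
        model,
        tokenizer=tokenizer,
        classnames=IMAGENET_CLASSNAMES,
        templates=OPENAI_IMAGENET_TEMPLATES,
        num_classes_per_batch=10,
        device='cpu',      # switch to 'cuda' if the model was moved to GPU
        use_tqdm=True,
    )
    image_features = torch.nn.functional.normalize(model.encode_image(image), dim=-1)
    logits = 100.0 * image_features @ classifier   # (N, 1000) class logits
    top5 = logits.topk(5, dim=-1).indices          # indices into IMAGENET_CLASSNAMES
```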
105 | 106 | --- 107 | 108 | ## **Citation** 109 | 110 | If you use our work, please cite it: 111 | 112 | ```bibtex 113 | @article{liu2024clips, 114 | title={CLIPS: An Enhanced CLIP Framework for Learning with Synthetic Captions}, 115 | author={Liu, Yanqing and Li, Xianhang and Wang, Zeyu and Zhao, Bingchen and Xie, Cihang}, 116 | journal={arXiv preprint arXiv:2411.16828}, 117 | year={2024} 118 | } 119 | -------------------------------------------------------------------------------- /docs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/docs/.DS_Store -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # CLIPS 2 | [https://ucsc-vlaa.github.io/CLIPS/](https://ucsc-vlaa.github.io/CLIPS/) 3 | -------------------------------------------------------------------------------- /docs/resources/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/docs/resources/.DS_Store -------------------------------------------------------------------------------- /docs/resources/LLaVA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/docs/resources/LLaVA.png -------------------------------------------------------------------------------- /docs/resources/ar.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/resources/block_mask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/docs/resources/block_mask.jpg -------------------------------------------------------------------------------- /docs/resources/first_mask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/docs/resources/first_mask.jpg -------------------------------------------------------------------------------- /docs/resources/gr.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /docs/resources/hf_dataset.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/docs/resources/hf_dataset.jpg -------------------------------------------------------------------------------- /docs/resources/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/docs/resources/icon.png -------------------------------------------------------------------------------- /docs/resources/mask_strategy.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/docs/resources/mask_strategy.jpg -------------------------------------------------------------------------------- /docs/resources/method.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/docs/resources/method.jpg -------------------------------------------------------------------------------- /docs/resources/random_mask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/docs/resources/random_mask.jpg -------------------------------------------------------------------------------- /docs/resources/retrieval.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/docs/resources/retrieval.png -------------------------------------------------------------------------------- /docs/resources/sota.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/docs/resources/sota.png -------------------------------------------------------------------------------- /docs/resources/subcaption_mask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/docs/resources/subcaption_mask.jpg -------------------------------------------------------------------------------- /docs/resources/tw.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/docs/resources/tw.png -------------------------------------------------------------------------------- /docs/static/css/bulma-carousel.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes spinAround { 2 | from { 3 | -webkit-transform: rotate(0); 4 | transform: rotate(0) 5 | } 6 | 7 | to { 8 | -webkit-transform: rotate(359deg); 9 | transform: rotate(359deg) 10 | } 11 | } 12 | 13 | @keyframes spinAround { 14 | from { 15 | -webkit-transform: rotate(0); 16 | transform: rotate(0) 17 | } 18 | 19 | to { 20 | -webkit-transform: rotate(359deg); 21 | transform: rotate(359deg) 22 | } 23 | } 24 | 25 | .slider { 26 | position: relative; 27 | width: 100% 28 | } 29 | 30 | .slider-container { 31 | display: flex; 32 | flex-wrap: nowrap; 33 | flex-direction: row; 34 | overflow: hidden; 35 | -webkit-transform: translate3d(0, 0, 0); 36 | transform: translate3d(0, 0, 0); 37 | min-height: 100% 38 | } 39 | 40 | .slider-container.is-vertical { 41 | flex-direction: column 42 | } 43 | 44 | .slider-container .slider-item { 45 | flex: none 46 | } 47 | 48 | .slider-container .slider-item .image.is-covered img { 49 | -o-object-fit: cover; 50 | object-fit: cover; 51 | -o-object-position: center center; 52 | object-position: center center; 53 | height: 100%; 54 | width: 100% 55 | } 56 | 57 | .slider-container .slider-item .video-container { 58 | height: 0; 59 | padding-bottom: 0; 60 | padding-top: 56.25%; 61 | margin: 0; 62 | position: relative 63 | } 64 | 65 | .slider-container .slider-item .video-container.is-1by1, 66 | 
.slider-container .slider-item .video-container.is-square { 67 | padding-top: 100% 68 | } 69 | 70 | .slider-container .slider-item .video-container.is-4by3 { 71 | padding-top: 75% 72 | } 73 | 74 | .slider-container .slider-item .video-container.is-21by9 { 75 | padding-top: 42.857143% 76 | } 77 | 78 | .slider-container .slider-item .video-container embed, 79 | .slider-container .slider-item .video-container iframe, 80 | .slider-container .slider-item .video-container object { 81 | position: absolute; 82 | top: 0; 83 | left: 0; 84 | width: 100% !important; 85 | height: 100% !important 86 | } 87 | 88 | .slider-navigation-next, 89 | .slider-navigation-previous { 90 | display: flex; 91 | justify-content: center; 92 | align-items: center; 93 | position: absolute; 94 | width: 42px; 95 | height: 42px; 96 | background: #fff center center no-repeat; 97 | background-size: 20px 20px; 98 | border: 1px solid #fff; 99 | border-radius: 25091983px; 100 | box-shadow: 0 2px 5px #3232321a; 101 | top: 50%; 102 | margin-top: -20px; 103 | left: 0; 104 | cursor: pointer; 105 | transition: opacity .3s, -webkit-transform .3s; 106 | transition: transform .3s, opacity .3s; 107 | transition: transform .3s, opacity .3s, -webkit-transform .3s 108 | } 109 | 110 | .slider-navigation-next:hover, 111 | .slider-navigation-previous:hover { 112 | -webkit-transform: scale(1.2); 113 | transform: scale(1.2) 114 | } 115 | 116 | .slider-navigation-next.is-hidden, 117 | .slider-navigation-previous.is-hidden { 118 | display: none; 119 | opacity: 0 120 | } 121 | 122 | .slider-navigation-next svg, 123 | .slider-navigation-previous svg { 124 | width: 25% 125 | } 126 | 127 | .slider-navigation-next { 128 | left: auto; 129 | right: 0; 130 | background: #fff center center no-repeat; 131 | background-size: 20px 20px 132 | } 133 | 134 | .slider-pagination { 135 | display: none; 136 | justify-content: center; 137 | align-items: center; 138 | position: absolute; 139 | bottom: 0; 140 | left: 0; 141 | right: 0; 142 | padding: .5rem 1rem; 143 | text-align: center 144 | } 145 | 146 | .slider-pagination .slider-page { 147 | background: #fff; 148 | width: 10px; 149 | height: 10px; 150 | border-radius: 25091983px; 151 | display: inline-block; 152 | margin: 0 3px; 153 | box-shadow: 0 2px 5px #3232321a; 154 | transition: -webkit-transform .3s; 155 | transition: transform .3s; 156 | transition: transform .3s, -webkit-transform .3s; 157 | cursor: pointer 158 | } 159 | 160 | .slider-pagination .slider-page.is-active, 161 | .slider-pagination .slider-page:hover { 162 | -webkit-transform: scale(1.4); 163 | transform: scale(1.4) 164 | } 165 | 166 | @media screen and (min-width:800px) { 167 | .slider-pagination { 168 | display: flex 169 | } 170 | } 171 | 172 | .hero.has-carousel { 173 | position: relative 174 | } 175 | 176 | .hero.has-carousel+.hero-body, 177 | .hero.has-carousel+.hero-footer, 178 | .hero.has-carousel+.hero-head { 179 | z-index: 10; 180 | overflow: hidden 181 | } 182 | 183 | .hero.has-carousel .hero-carousel { 184 | position: absolute; 185 | top: 0; 186 | left: 0; 187 | bottom: 0; 188 | right: 0; 189 | height: auto; 190 | border: none; 191 | margin: auto; 192 | padding: 0; 193 | z-index: 0 194 | } 195 | 196 | .hero.has-carousel .hero-carousel .slider { 197 | width: 100%; 198 | max-width: 100%; 199 | overflow: hidden; 200 | height: 100% !important; 201 | max-height: 100%; 202 | z-index: 0 203 | } 204 | 205 | .hero.has-carousel .hero-carousel .slider .has-background { 206 | max-height: 100% 207 | } 208 | 209 | .hero.has-carousel .hero-carousel 
.slider .has-background .is-background { 210 | -o-object-fit: cover; 211 | object-fit: cover; 212 | -o-object-position: center center; 213 | object-position: center center; 214 | height: 100%; 215 | width: 100% 216 | } 217 | 218 | .hero.has-carousel .hero-body { 219 | margin: 0 3rem; 220 | z-index: 10 221 | } -------------------------------------------------------------------------------- /docs/static/css/index.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Noto Sans', sans-serif; 3 | } 4 | 5 | 6 | .footer .icon-link { 7 | font-size: 25px; 8 | color: #000; 9 | } 10 | 11 | .link-block a { 12 | margin-top: 5px; 13 | margin-bottom: 5px; 14 | } 15 | 16 | .dnerf { 17 | font-variant: small-caps; 18 | } 19 | 20 | 21 | .teaser .hero-body { 22 | padding-top: 0; 23 | padding-bottom: 0; 24 | } 25 | 26 | .teaser { 27 | font-family: 'Google Sans', sans-serif; 28 | } 29 | 30 | 31 | .publication-title { 32 | } 33 | 34 | .publication-banner { 35 | max-height: parent; 36 | 37 | } 38 | 39 | .publication-banner video { 40 | position: relative; 41 | left: auto; 42 | top: auto; 43 | transform: none; 44 | object-fit: fit; 45 | } 46 | 47 | .publication-header .hero-body { 48 | } 49 | 50 | .publication-title { 51 | font-family: 'Google Sans', sans-serif; 52 | } 53 | 54 | .publication-authors { 55 | font-family: 'Google Sans', sans-serif; 56 | } 57 | 58 | .publication-venue { 59 | color: #555; 60 | width: fit-content; 61 | font-weight: bold; 62 | } 63 | 64 | .publication-awards { 65 | color: #ff3860; 66 | width: fit-content; 67 | font-weight: bolder; 68 | } 69 | 70 | .publication-authors { 71 | } 72 | 73 | .publication-authors a { 74 | color: hsl(204, 86%, 53%) !important; 75 | } 76 | 77 | .publication-authors a:hover { 78 | text-decoration: underline; 79 | } 80 | 81 | .author-block { 82 | display: inline-block; 83 | } 84 | 85 | .publication-banner img { 86 | } 87 | 88 | .publication-authors { 89 | /*color: #4286f4;*/ 90 | } 91 | 92 | .publication-video { 93 | position: relative; 94 | width: 100%; 95 | height: 0; 96 | padding-bottom: 56.25%; 97 | 98 | overflow: hidden; 99 | border-radius: 10px !important; 100 | } 101 | 102 | .publication-video iframe { 103 | position: absolute; 104 | top: 0; 105 | left: 0; 106 | width: 100%; 107 | height: 100%; 108 | } 109 | 110 | .publication-body img { 111 | } 112 | 113 | .results-carousel { 114 | overflow: hidden; 115 | } 116 | 117 | .results-carousel .item { 118 | margin: 5px; 119 | overflow: hidden; 120 | border: 1px solid #bbb; 121 | border-radius: 10px; 122 | padding: 0; 123 | font-size: 0; 124 | } 125 | 126 | .results-carousel video { 127 | margin: 0; 128 | } 129 | 130 | 131 | .interpolation-panel { 132 | background: #f5f5f5; 133 | border-radius: 10px; 134 | } 135 | 136 | .interpolation-panel .interpolation-image { 137 | width: 100%; 138 | border-radius: 5px; 139 | } 140 | 141 | .interpolation-video-column { 142 | } 143 | 144 | .interpolation-panel .slider { 145 | margin: 0 !important; 146 | } 147 | 148 | .interpolation-panel .slider { 149 | margin: 0 !important; 150 | } 151 | 152 | #interpolation-image-wrapper { 153 | width: 100%; 154 | } 155 | #interpolation-image-wrapper img { 156 | border-radius: 5px; 157 | } 158 | -------------------------------------------------------------------------------- /docs/static/js/bulma-slider.min.js: -------------------------------------------------------------------------------- 1 | !function(t,e){"object"==typeof exports&&"object"==typeof 
module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.bulmaSlider=e():t.bulmaSlider=e()}("undefined"!=typeof self?self:this,function(){return function(n){var r={};function i(t){if(r[t])return r[t].exports;var e=r[t]={i:t,l:!1,exports:{}};return n[t].call(e.exports,e,e.exports,i),e.l=!0,e.exports}return i.m=n,i.c=r,i.d=function(t,e,n){i.o(t,e)||Object.defineProperty(t,e,{configurable:!1,enumerable:!0,get:n})},i.n=function(t){var e=t&&t.__esModule?function(){return t.default}:function(){return t};return i.d(e,"a",e),e},i.o=function(t,e){return Object.prototype.hasOwnProperty.call(t,e)},i.p="",i(i.s=0)}([function(t,e,n){"use strict";Object.defineProperty(e,"__esModule",{value:!0}),n.d(e,"isString",function(){return l});var r=n(1),i=Object.assign||function(t){for(var e=1;e=l.length&&(s=!0)):s=!0),s&&(t.once&&(u[e]=null),t.callback(r))});-1!==u.indexOf(null);)u.splice(u.indexOf(null),1)}}]),e}();e.a=i}]).default}); -------------------------------------------------------------------------------- /docs/static/js/index.js: -------------------------------------------------------------------------------- 1 | window.HELP_IMPROVE_VIDEOJS = false; 2 | 3 | var INTERP_BASE = "./static/interpolation/stacked"; 4 | var NUM_INTERP_FRAMES = 240; 5 | 6 | var interp_images = []; 7 | function preloadInterpolationImages() { 8 | for (var i = 0; i < NUM_INTERP_FRAMES; i++) { 9 | var path = INTERP_BASE + '/' + String(i).padStart(6, '0') + '.jpg'; 10 | interp_images[i] = new Image(); 11 | interp_images[i].src = path; 12 | } 13 | } 14 | 15 | function setInterpolationImage(i) { 16 | var image = interp_images[i]; 17 | image.ondragstart = function() { return false; }; 18 | image.oncontextmenu = function() { return false; }; 19 | $('#interpolation-image-wrapper').empty().append(image); 20 | } 21 | 22 | 23 | $(document).ready(function() { 24 | // Check for click events on the navbar burger icon 25 | $(".navbar-burger").click(function() { 26 | // Toggle the "is-active" class on both the "navbar-burger" and the "navbar-menu" 27 | $(".navbar-burger").toggleClass("is-active"); 28 | $(".navbar-menu").toggleClass("is-active"); 29 | 30 | }); 31 | 32 | var options = { 33 | slidesToScroll: 1, 34 | slidesToShow: 3, 35 | loop: true, 36 | infinite: true, 37 | autoplay: false, 38 | autoplaySpeed: 3000, 39 | } 40 | 41 | // Initialize all div with carousel class 42 | var carousels = bulmaCarousel.attach('.carousel', options); 43 | 44 | // Loop on each carousel initialized 45 | for(var i = 0; i < carousels.length; i++) { 46 | // Add listener to event 47 | carousels[i].on('before:show', state => { 48 | console.log(state); 49 | }); 50 | } 51 | 52 | // Access to bulmaCarousel instance of an element 53 | var element = document.querySelector('#my-element'); 54 | if (element && element.bulmaCarousel) { 55 | // bulmaCarousel instance is available as element.bulmaCarousel 56 | element.bulmaCarousel.on('before-show', function(state) { 57 | console.log(state); 58 | }); 59 | } 60 | 61 | /*var player = document.getElementById('interpolation-video'); 62 | player.addEventListener('loadedmetadata', function() { 63 | $('#interpolation-slider').on('input', function(event) { 64 | console.log(this.value, player.duration); 65 | player.currentTime = player.duration / 100 * this.value; 66 | }) 67 | }, false);*/ 68 | preloadInterpolationImages(); 69 | 70 | $('#interpolation-slider').on('input', function(event) { 71 | setInterpolationImage(this.value); 72 | }); 73 | setInterpolationImage(0); 74 | 
$('#interpolation-slider').prop('max', NUM_INTERP_FRAMES - 1); 75 | 76 | bulmaSlider.attach(); 77 | 78 | }) 79 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from urllib.request import urlopen 4 | from PIL import Image 5 | from open_clip import create_model_from_pretrained, get_tokenizer 6 | 7 | model, preprocess = create_model_from_pretrained('hf-hub:UCSC-VLAA/ViT-L-14-CLIPS-Recap-DataComp-1B') 8 | tokenizer = get_tokenizer('hf-hub:UCSC-VLAA/ViT-L-14-CLIPS-Recap-DataComp-1B') 9 | 10 | image = Image.open(urlopen( 11 | 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png' 12 | )) 13 | image = preprocess(image).unsqueeze(0) 14 | 15 | text = tokenizer(["a diagram", "a dog", "a cat", "a beignet"], context_length=model.context_length) 16 | 17 | with torch.no_grad(), torch.cuda.amp.autocast(): 18 | image_features = model.encode_image(image) 19 | text_features = model.encode_text(text) 20 | image_features = F.normalize(image_features, dim=-1) 21 | text_features = F.normalize(text_features, dim=-1) 22 | 23 | text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) 24 | 25 | print("Label probs:", text_probs) # prints: [[0., 0., 0., 1.0]] 26 | -------------------------------------------------------------------------------- /open_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import __version__ 2 | 3 | from .coca_model import CoCa 4 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 5 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss 6 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint 7 | from .loss import ClipLoss, DistillClipLoss, CoCaLoss 8 | from .model import CLIPS, CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \ 9 | convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \ 10 | get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg 11 | from .openai import load_openai_model, list_openai_models 12 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \ 13 | get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained 14 | from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub 15 | from .tokenizer import SimpleTokenizer, tokenize, decode 16 | from .transform import image_transform, AugmentationCfg 17 | from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy 18 | from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES 19 | -------------------------------------------------------------------------------- /open_clip/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/coca_model.cpython-311.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/coca_model.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/constants.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/constants.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/convert.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/convert.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/factory.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/factory.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/hf_configs.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/hf_configs.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/hf_model.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/hf_model.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/loss.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/loss.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/model.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/model.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/modified_resnet.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/modified_resnet.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/openai.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/openai.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/pos_embed.cpython-311.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/pos_embed.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/pretrained.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/pretrained.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/push_to_hf_hub.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/push_to_hf_hub.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/timm_model.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/timm_model.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/tokenizer.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/tokenizer.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/transform.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/transform.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/transformer.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/transformer.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/utils.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/version.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/version.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/zero_shot_classifier.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/zero_shot_classifier.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/zero_shot_metadata.cpython-311.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/zero_shot_metadata.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /open_clip/constants.py: -------------------------------------------------------------------------------- 1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) 3 | IMAGENET_MEAN = (0.485, 0.456, 0.406) 4 | IMAGENET_STD = (0.229, 0.224, 0.225) 5 | INCEPTION_MEAN = (0.5, 0.5, 0.5) 6 | INCEPTION_STD = (0.5, 0.5, 0.5) 7 | 8 | # Default name for a weights file hosted on the Huggingface Hub. 9 | HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl 10 | HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version 11 | HF_CONFIG_NAME = 'open_clip_config.json' 12 | -------------------------------------------------------------------------------- /open_clip/hf_configs.py: -------------------------------------------------------------------------------- 1 | # HF architecture dict: 2 | arch_dict = { 3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta 4 | "roberta": { 5 | "config_names": { 6 | "context_length": "max_position_embeddings", 7 | "vocab_size": "vocab_size", 8 | "width": "hidden_size", 9 | "heads": "num_attention_heads", 10 | "layers": "num_hidden_layers", 11 | "layer_attr": "layer", 12 | "token_embeddings_attr": "embeddings" 13 | }, 14 | "pooler": "mean_pooler", 15 | }, 16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig 17 | "xlm-roberta": { 18 | "config_names": { 19 | "context_length": "max_position_embeddings", 20 | "vocab_size": "vocab_size", 21 | "width": "hidden_size", 22 | "heads": "num_attention_heads", 23 | "layers": "num_hidden_layers", 24 | "layer_attr": "layer", 25 | "token_embeddings_attr": "embeddings" 26 | }, 27 | "pooler": "mean_pooler", 28 | }, 29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 30 | "mt5": { 31 | "config_names": { 32 | # unlimited seqlen 33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 35 | "context_length": "", 36 | "vocab_size": "vocab_size", 37 | "width": "d_model", 38 | "heads": "num_heads", 39 | "layers": "num_layers", 40 | "layer_attr": "block", 41 | "token_embeddings_attr": "embed_tokens" 42 | }, 43 | "pooler": "mean_pooler", 44 | }, 45 | # https://huggingface.co/docs/transformers/model_doc/bert 46 | "bert": { 47 | "config_names": { 48 | "context_length": "max_position_embeddings", 49 | "vocab_size": "vocab_size", 50 | "width": "hidden_size", 51 | "heads": "num_attention_heads", 52 | "layers": "num_hidden_layers", 53 | }, 54 | "pooler": "cls_pooler", 55 | }, 56 | # https://huggingface.co/docs/transformers/model_doc/m2m_100 57 | "m2m_100": { 58 | "config_names": { 59 | "context_length": "max_position_embeddings", 60 | "vocab_size": "vocab_size", 61 | "width": "d_model", 62 | "heads": "encoder_attention_heads", 63 | "layers": "encoder_layers", 64 | }, 65 | "pooler": 
"cls_pooler", 66 | }, 67 | } 68 | -------------------------------------------------------------------------------- /open_clip/hf_model.py: -------------------------------------------------------------------------------- 1 | """ huggingface model adapter 2 | 3 | Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. 4 | """ 5 | import re 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch import TensorType 10 | 11 | try: 12 | import transformers 13 | from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig 14 | from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \ 15 | BaseModelOutputWithPoolingAndCrossAttentions 16 | except ImportError as e: 17 | transformers = None 18 | 19 | 20 | class BaseModelOutput: 21 | pass 22 | 23 | 24 | class PretrainedConfig: 25 | pass 26 | 27 | from .hf_configs import arch_dict 28 | 29 | 30 | # utils 31 | def _camel2snake(s): 32 | return re.sub(r'(? 1 18 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.act1 = nn.ReLU(inplace=True) 21 | 22 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 23 | self.bn2 = nn.BatchNorm2d(planes) 24 | self.act2 = nn.ReLU(inplace=True) 25 | 26 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 27 | 28 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 29 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 30 | self.act3 = nn.ReLU(inplace=True) 31 | 32 | self.downsample = None 33 | self.stride = stride 34 | 35 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 36 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 37 | self.downsample = nn.Sequential(OrderedDict([ 38 | ("-1", nn.AvgPool2d(stride)), 39 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 40 | ("1", nn.BatchNorm2d(planes * self.expansion)) 41 | ])) 42 | 43 | def forward(self, x: torch.Tensor): 44 | identity = x 45 | 46 | out = self.act1(self.bn1(self.conv1(x))) 47 | out = self.act2(self.bn2(self.conv2(out))) 48 | out = self.avgpool(out) 49 | out = self.bn3(self.conv3(out)) 50 | 51 | if self.downsample is not None: 52 | identity = self.downsample(x) 53 | 54 | out += identity 55 | out = self.act3(out) 56 | return out 57 | 58 | 59 | class AttentionPool2d(nn.Module): 60 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 61 | super().__init__() 62 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 63 | self.k_proj = nn.Linear(embed_dim, embed_dim) 64 | self.q_proj = nn.Linear(embed_dim, embed_dim) 65 | self.v_proj = nn.Linear(embed_dim, embed_dim) 66 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 67 | self.num_heads = num_heads 68 | 69 | def forward(self, x): 70 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 71 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 72 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 73 | x, _ = F.multi_head_attention_forward( 74 | query=x, key=x, value=x, 75 | embed_dim_to_check=x.shape[-1], 76 | num_heads=self.num_heads, 77 | q_proj_weight=self.q_proj.weight, 78 | k_proj_weight=self.k_proj.weight, 79 | v_proj_weight=self.v_proj.weight, 80 | in_proj_weight=None, 81 | in_proj_bias=torch.cat([self.q_proj.bias, 
self.k_proj.bias, self.v_proj.bias]), 82 | bias_k=None, 83 | bias_v=None, 84 | add_zero_attn=False, 85 | dropout_p=0., 86 | out_proj_weight=self.c_proj.weight, 87 | out_proj_bias=self.c_proj.bias, 88 | use_separate_proj_weight=True, 89 | training=self.training, 90 | need_weights=False 91 | ) 92 | 93 | return x[0] 94 | 95 | 96 | class ModifiedResNet(nn.Module): 97 | """ 98 | A ResNet class that is similar to torchvision's but contains the following changes: 99 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 100 | - Performs antialiasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 101 | - The final pooling layer is a QKV attention instead of an average pool 102 | """ 103 | 104 | def __init__( 105 | self, 106 | layers: List[int], 107 | output_dim: int, 108 | heads: int, 109 | image_size: int = 224, 110 | width: int = 64, 111 | ): 112 | super().__init__() 113 | self.output_dim = output_dim 114 | self.image_size = image_size 115 | 116 | # the 3-layer stem 117 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 118 | self.bn1 = nn.BatchNorm2d(width // 2) 119 | self.act1 = nn.ReLU(inplace=True) 120 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 121 | self.bn2 = nn.BatchNorm2d(width // 2) 122 | self.act2 = nn.ReLU(inplace=True) 123 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 124 | self.bn3 = nn.BatchNorm2d(width) 125 | self.act3 = nn.ReLU(inplace=True) 126 | self.avgpool = nn.AvgPool2d(2) 127 | 128 | # residual layers 129 | self._inplanes = width # this is a *mutable* variable used during construction 130 | self.layer1 = self._make_layer(width, layers[0]) 131 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 132 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 133 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 134 | 135 | embed_dim = width * 32 # the ResNet feature dimension 136 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) 137 | 138 | self.init_parameters() 139 | 140 | def _make_layer(self, planes, blocks, stride=1): 141 | layers = [Bottleneck(self._inplanes, planes, stride)] 142 | 143 | self._inplanes = planes * Bottleneck.expansion 144 | for _ in range(1, blocks): 145 | layers.append(Bottleneck(self._inplanes, planes)) 146 | 147 | return nn.Sequential(*layers) 148 | 149 | def init_parameters(self): 150 | if self.attnpool is not None: 151 | std = self.attnpool.c_proj.in_features ** -0.5 152 | nn.init.normal_(self.attnpool.q_proj.weight, std=std) 153 | nn.init.normal_(self.attnpool.k_proj.weight, std=std) 154 | nn.init.normal_(self.attnpool.v_proj.weight, std=std) 155 | nn.init.normal_(self.attnpool.c_proj.weight, std=std) 156 | 157 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: 158 | for name, param in resnet_block.named_parameters(): 159 | if name.endswith("bn3.weight"): 160 | nn.init.zeros_(param) 161 | 162 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 163 | assert unlocked_groups == 0, 'partial locking not currently supported for this model' 164 | for param in self.parameters(): 165 | param.requires_grad = False 166 | if freeze_bn_stats: 167 | freeze_batch_norm_2d(self) 168 | 169 | @torch.jit.ignore 170 | def set_grad_checkpointing(self, enable=True): 171 | # FIXME support for non-transformer 172 | pass 173 | 174 | def stem(self, x): 175 | x = 
self.act1(self.bn1(self.conv1(x))) 176 | x = self.act2(self.bn2(self.conv2(x))) 177 | x = self.act3(self.bn3(self.conv3(x))) 178 | x = self.avgpool(x) 179 | return x 180 | 181 | def forward_intermediates( 182 | self, 183 | x: torch.Tensor, 184 | indices: Optional[Union[int, List[int]]] = None, 185 | stop_early: bool = False, 186 | normalize_intermediates: bool = False, 187 | intermediates_only: bool = False, 188 | output_fmt: str = 'NCHW', 189 | output_extra_tokens: bool = False, 190 | ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]: 191 | """ Forward features that returns intermediates. 192 | 193 | Args: 194 | x: Input image tensor 195 | indices: Take last n blocks if int, all if None, select matching indices if sequence 196 | stop_early: Stop iterating over blocks when last desired intermediate hit 197 | normalize_intermediates: Apply final norm layer to all intermediates 198 | intermediates_only: Only return intermediate features 199 | output_fmt: Shape of intermediate feature outputs 200 | output_extra_tokens: Return both extra class, eot tokens 201 | Returns: 202 | 203 | """ 204 | assert output_fmt in ('NCHW',), 'Output format must be == NCHW.' 205 | # NOTE normalize_intermediates and return_extra_tokens don't apply 206 | take_indices, max_index = feature_take_indices(5, indices) 207 | 208 | output = {} 209 | intermediates = [] 210 | blocks = [self.stem, self.layer1, self.layer2, self.layer3, self.layer4] 211 | if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript 212 | blocks = blocks[:max_index + 1] 213 | for i, blk in enumerate(blocks): 214 | x = blk(x) 215 | if i in take_indices: 216 | intermediates.append(x) 217 | 218 | output['image_intermediates'] = intermediates 219 | 220 | if intermediates_only: 221 | return output 222 | 223 | x = self.attnpool(x) 224 | output['image_features'] = x 225 | 226 | return output 227 | 228 | def forward(self, x): 229 | x = self.stem(x) 230 | x = self.layer1(x) 231 | x = self.layer2(x) 232 | x = self.layer3(x) 233 | x = self.layer4(x) 234 | x = self.attnpool(x) 235 | 236 | return x 237 | -------------------------------------------------------------------------------- /open_clip/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import List, Optional, Union 9 | 10 | import torch 11 | 12 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 13 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype 14 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url 15 | 16 | __all__ = ["list_openai_models", "load_openai_model"] 17 | 18 | 19 | def list_openai_models() -> List[str]: 20 | """Returns the names of available CLIP models""" 21 | return list_pretrained_models_by_tag('openai') 22 | 23 | 24 | def load_openai_model( 25 | name: str, 26 | precision: Optional[str] = None, 27 | device: Optional[Union[str, torch.device]] = None, 28 | cache_dir: Optional[str] = None, 29 | ): 30 | """Load a CLIP model 31 | 32 | Parameters 33 | ---------- 34 | name : str 35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 36 | precision: str 37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. 
38 | device : Union[str, torch.device] 39 | The device to put the loaded model 40 | cache_dir : Optional[str] 41 | The directory to cache the downloaded model weights 42 | 43 | Returns 44 | ------- 45 | model : torch.nn.Module 46 | The CLIP model 47 | preprocess : Callable[[PIL.Image], torch.Tensor] 48 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 49 | """ 50 | if device is None: 51 | device = "cuda" if torch.cuda.is_available() else "cpu" 52 | if precision is None: 53 | precision = 'fp32' if device == 'cpu' else 'fp16' 54 | 55 | if get_pretrained_url(name, 'openai'): 56 | model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) 57 | elif os.path.isfile(name): 58 | model_path = name 59 | else: 60 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 61 | 62 | try: 63 | # loading JIT archive 64 | model = torch.jit.load(model_path, map_location="cpu").eval() 65 | state_dict = None 66 | except RuntimeError: 67 | # loading saved state dict 68 | state_dict = torch.load(model_path, map_location="cpu") 69 | 70 | # Build a non-jit model from the OpenAI jitted model state dict 71 | cast_dtype = get_cast_dtype(precision) 72 | try: 73 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) 74 | except KeyError: 75 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 76 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) 77 | 78 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use 79 | model = model.to(device) 80 | # FIXME support pure fp16/bf16 precision modes 81 | if precision != 'fp16': 82 | model.float() 83 | if precision == 'bf16': 84 | # for bf16, convert back to low-precision 85 | convert_weights_to_lp(model, dtype=torch.bfloat16) 86 | 87 | # add mean / std attributes for consistency with OpenCLIP models 88 | model.visual.image_mean = OPENAI_DATASET_MEAN 89 | model.visual.image_std = OPENAI_DATASET_STD 90 | return model 91 | -------------------------------------------------------------------------------- /open_clip/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert embed_dim % 2 == 0 40 | 41 | # use half of dimensions to encode grid_h 42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 44 | 45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 46 | return emb 47 | 48 | 49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 50 | """ 51 | embed_dim: output dimension for each position 52 | pos: a list of positions to be encoded: size (M,) 53 | out: (M, D) 54 | """ 55 | assert embed_dim % 2 == 0 56 | omega = np.arange(embed_dim // 2, dtype=float) 57 | omega /= embed_dim / 2. 58 | omega = 1. 
/ 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 89 | # only the position tokens are interpolated 90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 92 | pos_tokens = torch.nn.functional.interpolate( 93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 96 | checkpoint_model['pos_embed'] = new_pos_embed 97 | -------------------------------------------------------------------------------- /open_clip/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 
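Example (editor's illustrative sketch; the timm model name and embed dim are assumptions,
not defaults of this repository, and timm must be installed):

    visual = TimmModel('vit_base_patch16_224', embed_dim=512, pool='', proj='linear')
    feats = visual(torch.randn(2, 3, 224, 224))   # -> tensor of shape (2, 512)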
4 | """ 5 | import logging 6 | from collections import OrderedDict 7 | from typing import Dict, List, Optional, Tuple, Union 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | try: 13 | import timm 14 | from timm.layers import RotAttentionPool2d 15 | from timm.layers import AttentionPool2d as AbsAttentionPool2d 16 | from timm.layers import Mlp, to_2tuple 17 | except ImportError: 18 | timm = None 19 | 20 | from .utils import freeze_batch_norm_2d 21 | 22 | 23 | class TimmModel(nn.Module): 24 | """ timm model adapter 25 | """ 26 | 27 | def __init__( 28 | self, 29 | model_name: str, 30 | embed_dim: int, 31 | image_size: Union[int, Tuple[int, int]] = 224, 32 | pool: str = 'avg', 33 | proj: str = 'linear', 34 | proj_bias: bool = False, 35 | drop: float = 0., 36 | drop_path: Optional[float] = None, 37 | patch_drop: Optional[float] = None, 38 | pretrained: bool = False, 39 | ): 40 | super().__init__() 41 | if timm is None: 42 | raise RuntimeError("Please install the latest timm (`pip install timm`) to use timm based models.") 43 | self.image_size = to_2tuple(image_size) 44 | 45 | # setup kwargs that may not be common across all models 46 | timm_kwargs = {} 47 | if drop_path is not None: 48 | timm_kwargs['drop_path_rate'] = drop_path 49 | if patch_drop is not None: 50 | timm_kwargs['patch_drop_rate'] = patch_drop 51 | 52 | custom_pool = pool in ('abs_attn', 'rot_attn') 53 | if proj: 54 | assert proj in ("linear", "mlp", "none") 55 | extra_proj = proj in ("linear", "mlp") 56 | if not extra_proj and not custom_pool: 57 | # use network classifier head as projection if no proj specified and no custom pooling used 58 | # if projection is explicitly set to "none" will be pass through from network trunk 59 | proj_dim = 0 if proj == 'none' else embed_dim 60 | self.trunk = timm.create_model( 61 | model_name, 62 | num_classes=proj_dim, 63 | global_pool=pool, 64 | pretrained=pretrained, 65 | **timm_kwargs, 66 | ) 67 | prev_chs = embed_dim 68 | else: 69 | self.trunk = timm.create_model( 70 | model_name, 71 | pretrained=pretrained, 72 | **timm_kwargs, 73 | ) 74 | feat_size = self.trunk.default_cfg.get('pool_size', None) 75 | feature_ndim = 1 if not feat_size else 2 76 | if custom_pool: 77 | assert feature_ndim == 2 78 | # if attn pooling used, remove both classifier and default pool 79 | self.trunk.reset_classifier(0, global_pool='') 80 | else: 81 | # reset global pool if pool config set, otherwise leave as network default 82 | reset_kwargs = dict(global_pool=pool) if pool else {} 83 | self.trunk.reset_classifier(0, **reset_kwargs) 84 | prev_chs = self.trunk.num_features 85 | 86 | head_layers = OrderedDict() 87 | 88 | # Add custom pooling to head 89 | if pool == 'abs_attn': 90 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) 91 | prev_chs = embed_dim 92 | elif pool == 'rot_attn': 93 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 94 | prev_chs = embed_dim 95 | 96 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 97 | if proj == 'linear': 98 | head_layers['drop'] = nn.Dropout(drop) 99 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) 100 | elif proj == 'mlp': 101 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias)) 102 | 103 | self.head = nn.Sequential(head_layers) 104 | 105 | def lock(self, unlocked_groups: int = 0, freeze_bn_stats: bool = False): 106 | """ lock modules 107 | Args: 108 | 
unlocked_groups (int): leave last n layer groups unlocked (default: 0) 109 | """ 110 | if not unlocked_groups: 111 | # lock full model 112 | for param in self.trunk.parameters(): 113 | param.requires_grad = False 114 | if freeze_bn_stats: 115 | freeze_batch_norm_2d(self.trunk) 116 | else: 117 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 118 | try: 119 | # FIXME import here until API stable and in an official release 120 | from timm.models.helpers import group_parameters, group_modules 121 | except ImportError: 122 | raise RuntimeError( 123 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') 124 | matcher = self.trunk.group_matcher() 125 | gparams = group_parameters(self.trunk, matcher) 126 | max_layer_id = max(gparams.keys()) 127 | max_layer_id = max_layer_id - unlocked_groups 128 | for group_idx in range(max_layer_id + 1): 129 | group = gparams[group_idx] 130 | for param in group: 131 | self.trunk.get_parameter(param).requires_grad = False 132 | if freeze_bn_stats: 133 | gmodules = group_modules(self.trunk, matcher, reverse=True) 134 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 135 | freeze_batch_norm_2d(self.trunk, gmodules) 136 | 137 | @torch.jit.ignore 138 | def set_grad_checkpointing(self, enable: bool = True): 139 | try: 140 | self.trunk.set_grad_checkpointing(enable) 141 | except Exception as e: 142 | logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') 143 | 144 | def forward_intermediates( 145 | self, 146 | x: torch.Tensor, 147 | indices: Optional[Union[int, List[int]]] = None, 148 | stop_early: bool = False, 149 | normalize_intermediates: bool = False, 150 | intermediates_only: bool = False, 151 | output_fmt: str = 'NCHW', 152 | output_extra_tokens: bool = False, 153 | ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]: 154 | """ Forward features that returns intermediates. 
155 | 156 | Args: 157 | x: Input image tensor 158 | indices: Take last n blocks if int, all if None, select matching indices if sequence 159 | stop_early: Stop iterating over blocks when last desired intermediate hit 160 | normalize_intermediates: Apply norm layer to all intermediates 161 | intermediates_only: Only return intermediate features 162 | output_fmt: Shape of intermediate feature outputs 163 | output_extra_tokens: Return both prefix and spatial intermediate tokens 164 | Returns: 165 | """ 166 | extra_args = {} 167 | if output_extra_tokens: 168 | extra_args['return_prefix_tokens'] = True 169 | trunk_output = self.trunk.forward_intermediates( 170 | x, 171 | indices=indices, 172 | intermediates_only=intermediates_only, 173 | norm=normalize_intermediates, 174 | stop_early=stop_early, 175 | output_fmt=output_fmt, 176 | **extra_args, 177 | ) 178 | 179 | return_dict = {} 180 | intermediates = trunk_output if intermediates_only else trunk_output[1] 181 | if output_extra_tokens and intermediates and isinstance(intermediates[0], tuple): 182 | intermediates_prefix = [xi[1] for xi in intermediates] 183 | intermediates = [xi[0] for xi in intermediates] 184 | return_dict['image_intermediates_prefix'] = intermediates_prefix 185 | 186 | return_dict['image_intermediates'] = intermediates 187 | if intermediates_only: 188 | return return_dict 189 | 190 | image_features = self.trunk.forward_head(trunk_output[0]) # run through timm pooling / projection 191 | image_features = self.head(image_features) # run through adapter pooling / projection 192 | return_dict['image_features'] = image_features 193 | return return_dict 194 | 195 | def forward(self, x): 196 | x = self.trunk(x) 197 | x = self.head(x) 198 | return x 199 | -------------------------------------------------------------------------------- /open_clip/utils.py: -------------------------------------------------------------------------------- 1 | import collections.abc 2 | from itertools import repeat 3 | from typing import List, Optional, Tuple, Union 4 | 5 | import torch 6 | from torch import nn as nn 7 | from torch import _assert 8 | from torchvision.ops.misc import FrozenBatchNorm2d 9 | 10 | 11 | def freeze_batch_norm_2d(module, module_match={}, name=''): 12 | """ 13 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is 14 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and 15 | returned. Otherwise, the module is walked recursively and submodules are converted in place. 16 | 17 | Args: 18 | module (torch.nn.Module): Any PyTorch module. 
19 | module_match (dict): Dictionary of full module names to freeze (all if empty) 20 | name (str): Full module name (prefix) 21 | 22 | Returns: 23 | torch.nn.Module: Resulting module 24 | 25 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 26 | """ 27 | res = module 28 | is_match = True 29 | if module_match: 30 | is_match = name in module_match 31 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): 32 | res = FrozenBatchNorm2d(module.num_features) 33 | res.num_features = module.num_features 34 | res.affine = module.affine 35 | if module.affine: 36 | res.weight.data = module.weight.data.clone().detach() 37 | res.bias.data = module.bias.data.clone().detach() 38 | res.running_mean.data = module.running_mean.data 39 | res.running_var.data = module.running_var.data 40 | res.eps = module.eps 41 | else: 42 | for child_name, child in module.named_children(): 43 | full_child_name = '.'.join([name, child_name]) if name else child_name 44 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name) 45 | if new_child is not child: 46 | res.add_module(child_name, new_child) 47 | return res 48 | 49 | 50 | # From PyTorch internals 51 | def _ntuple(n): 52 | def parse(x): 53 | if isinstance(x, collections.abc.Iterable): 54 | return x 55 | return tuple(repeat(x, n)) 56 | return parse 57 | 58 | 59 | to_1tuple = _ntuple(1) 60 | to_2tuple = _ntuple(2) 61 | to_3tuple = _ntuple(3) 62 | to_4tuple = _ntuple(4) 63 | to_ntuple = lambda n, x: _ntuple(n)(x) 64 | 65 | # Replaces all linear layers with linear_replacement 66 | # TODO: add int8 support for other linear layers including attn and convnets 67 | def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True): 68 | for name, module in model.named_children(): 69 | if len(list(module.children())) > 0: 70 | replace_linear(module, linear_replacement, include_modules, copy_weights) 71 | 72 | if isinstance(module, torch.nn.Linear) and name in include_modules: 73 | old_module = model._modules[name] 74 | model._modules[name] = linear_replacement( 75 | module.in_features, 76 | module.out_features, 77 | module.bias is not None, 78 | ) 79 | if copy_weights: 80 | model._modules[name].weight.data.copy_(old_module.weight.data) 81 | if model._modules[name].bias is not None: 82 | model._modules[name].bias.data.copy_(old_module.bias) 83 | 84 | return model 85 | 86 | def convert_int8_model_to_inference_mode(model): 87 | for m in model.modules(): 88 | if hasattr(m, 'prepare_for_eval'): 89 | int8_original_dtype = m.weight.dtype 90 | m.prepare_for_eval() 91 | m.int8_original_dtype = int8_original_dtype 92 | 93 | 94 | def feature_take_indices( 95 | num_features: int, 96 | indices: Optional[Union[int, List[int]]] = None, 97 | as_set: bool = False, 98 | ) -> Tuple[List[int], int]: 99 | """ Determine the absolute feature indices to 'take' from. 100 | 101 | Note: This function can be called in forward() so must be torchscript compatible, 102 | which requires some incomplete typing and workaround hacks. 
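    For example (editor's illustration of the selection rules listed below): with
    num_features=5, an int selects the last n indices and negative entries count
    from the end:

        feature_take_indices(5, 3)     # -> ([2, 3, 4], 4)
        feature_take_indices(5, [-1])  # -> ([4], 4)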
103 | 104 | Args: 105 | num_features: total number of features to select from 106 | indices: indices to select, 107 | None -> select all 108 | int -> select last n 109 | list/tuple of int -> return specified (-ve indices specify from end) 110 | as_set: return as a set 111 | 112 | Returns: 113 | List (or set) of absolute (from beginning) indices, Maximum index 114 | """ 115 | if indices is None: 116 | indices = num_features # all features if None 117 | 118 | if isinstance(indices, int): 119 | # convert int -> last n indices 120 | _assert(0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})') 121 | take_indices = [num_features - indices + i for i in range(indices)] 122 | else: 123 | take_indices: List[int] = [] 124 | for i in indices: 125 | idx = num_features + i if i < 0 else i 126 | _assert(0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})') 127 | take_indices.append(idx) 128 | 129 | if not torch.jit.is_scripting() and as_set: 130 | return set(take_indices), max(take_indices) 131 | 132 | return take_indices, max(take_indices) 133 | 134 | 135 | def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]: 136 | if isinstance(x, int): 137 | # if indices is an int, take last N features 138 | return tuple(range(-x, 0)) 139 | return tuple(x) 140 | -------------------------------------------------------------------------------- /open_clip/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '2.31.0' 2 | -------------------------------------------------------------------------------- /open_clip/zero_shot_classifier.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from itertools import islice 3 | from typing import Callable, List, Optional, Sequence, Union 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | 9 | def batched(iterable, n): 10 | """Batch data into lists of length *n*. The last batch may be shorter. 11 | NOTE based on more-itertools impl, to be replaced by python 3.12 itertools.batched impl 12 | """ 13 | it = iter(iterable) 14 | while True: 15 | batch = list(islice(it, n)) 16 | if not batch: 17 | break 18 | yield batch 19 | 20 | 21 | def build_zero_shot_classifier( 22 | model, 23 | tokenizer, 24 | classnames: Sequence[str], 25 | templates: Sequence[Union[Callable, str]], 26 | num_classes_per_batch: Optional[int] = 10, 27 | device: Union[str, torch.device] = 'cpu', 28 | use_tqdm: bool = False, 29 | ): 30 | """ Build zero-shot classifier weights by iterating over class names in batches 31 | Args: 32 | model: CLIP model instance 33 | tokenizer: CLIP tokenizer instance 34 | classnames: A sequence of class (label) names 35 | templates: A sequence of callables or format() friendly strings to produce templates per class name 36 | num_classes_per_batch: The number of classes to batch together in each forward, all if None 37 | device: Device to use. 38 | use_tqdm: Enable TQDM progress bar. 
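    Example (editor's illustrative sketch; the model name, pretrained tag, class names, and
    template are assumptions, not values used by this repository):

        import open_clip
        model, _, _ = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
        tokenizer = open_clip.get_tokenizer('ViT-B-32')
        classifier = build_zero_shot_classifier(
            model, tokenizer,
            classnames=['cat', 'dog'],
            templates=['a photo of a {}.'],
            num_classes_per_batch=None,
            device='cpu',
        )
        # classifier: (embed_dim, num_classes); typically used as `logits = image_features @ classifier`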
39 | """ 40 | assert isinstance(templates, Sequence) and len(templates) > 0 41 | assert isinstance(classnames, Sequence) and len(classnames) > 0 42 | use_format = isinstance(templates[0], str) 43 | num_templates = len(templates) 44 | num_classes = len(classnames) 45 | if use_tqdm: 46 | import tqdm 47 | num_iter = 1 if num_classes_per_batch is None else ((num_classes - 1) // num_classes_per_batch + 1) 48 | iter_wrap = partial(tqdm.tqdm, total=num_iter, unit_scale=num_classes_per_batch) 49 | else: 50 | iter_wrap = iter 51 | 52 | def _process_batch(batch_classnames): 53 | num_batch_classes = len(batch_classnames) 54 | texts = [template.format(c) if use_format else template(c) for c in batch_classnames for template in templates] 55 | texts = tokenizer(texts).to(device) 56 | class_embeddings = model.encode_text(texts, normalize=True) 57 | class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1) 58 | class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True) 59 | class_embeddings = class_embeddings.T 60 | return class_embeddings 61 | 62 | with torch.no_grad(): 63 | if num_classes_per_batch: 64 | batched_embeds = [_process_batch(batch) for batch in iter_wrap(batched(classnames, num_classes_per_batch))] 65 | zeroshot_weights = torch.cat(batched_embeds, dim=1) 66 | else: 67 | zeroshot_weights = _process_batch(classnames) 68 | return zeroshot_weights 69 | 70 | 71 | def build_zero_shot_classifier_legacy( 72 | model, 73 | tokenizer, 74 | classnames: Sequence[str], 75 | templates: Sequence[Union[Callable, str]], 76 | device: Union[str, torch.device] = 'cpu', 77 | use_tqdm: bool = False, 78 | ): 79 | """ Build zero-shot classifier weights by iterating over class names 1 by 1 80 | Args: 81 | model: CLIP model instance 82 | tokenizer: CLIP tokenizer instance 83 | classnames: A sequence of class (label) names 84 | templates: A sequence of callables or format() friendly strings to produce templates per class name 85 | device: Device to use. 86 | use_tqdm: Enable TQDM progress bar. 
87 | """ 88 | assert isinstance(templates, Sequence) and len(templates) > 0 89 | assert isinstance(classnames, Sequence) and len(classnames) > 0 90 | if use_tqdm: 91 | import tqdm 92 | iter_wrap = tqdm.tqdm 93 | else: 94 | iter_wrap = iter 95 | 96 | use_format = isinstance(templates[0], str) 97 | 98 | with torch.no_grad(): 99 | zeroshot_weights = [] 100 | for classname in iter_wrap(classnames): 101 | texts = [template.format(classname) if use_format else template(classname) for template in templates] 102 | texts = tokenizer(texts).to(device) # tokenize 103 | class_embeddings = model.encode_text(texts) 104 | class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) 105 | class_embedding /= class_embedding.norm() 106 | zeroshot_weights.append(class_embedding) 107 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device) 108 | 109 | return zeroshot_weights 110 | 111 | -------------------------------------------------------------------------------- /open_clip_train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip_train/__init__.py -------------------------------------------------------------------------------- /open_clip_train/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip_train/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip_train/__pycache__/data.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip_train/__pycache__/data.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip_train/__pycache__/distributed.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip_train/__pycache__/distributed.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip_train/__pycache__/file_utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip_train/__pycache__/file_utils.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip_train/__pycache__/logger.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip_train/__pycache__/logger.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip_train/__pycache__/params.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip_train/__pycache__/params.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip_train/__pycache__/precision.cpython-311.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip_train/__pycache__/precision.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip_train/__pycache__/scheduler.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip_train/__pycache__/scheduler.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip_train/__pycache__/train.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip_train/__pycache__/train.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip_train/__pycache__/zero_shot.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip_train/__pycache__/zero_shot.cpython-311.pyc -------------------------------------------------------------------------------- /open_clip_train/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | from typing import Optional 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | try: 9 | import horovod.torch as hvd 10 | except ImportError: 11 | hvd = None 12 | 13 | 14 | def is_global_master(args): 15 | return args.rank == 0 16 | 17 | 18 | def is_local_master(args): 19 | return args.local_rank == 0 20 | 21 | 22 | def is_master(args, local=False): 23 | return is_local_master(args) if local else is_global_master(args) 24 | 25 | 26 | def is_device_available(device): 27 | device_type = torch.device(device).type 28 | is_avail = False 29 | is_known = False 30 | if device_type == 'cuda': 31 | is_avail = torch.cuda.is_available() 32 | is_known = True 33 | elif device_type == 'npu': 34 | # NOTE autoload device extension needed for this not to error out on this check 35 | is_avail = torch.npu.is_available() 36 | is_known = True 37 | elif device_type == 'mps': 38 | is_avail = torch.backends.mps.is_available() 39 | is_known = True 40 | elif device_type == 'cpu': 41 | is_avail = True 42 | is_known = True 43 | 44 | return is_avail, is_known 45 | 46 | 47 | def set_device(device): 48 | if device.startswith('cuda:'): 49 | torch.cuda.set_device(device) 50 | elif device.startswith('npu:'): 51 | torch.npu.set_device(device) 52 | 53 | 54 | def is_using_horovod(): 55 | # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set 56 | # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... 
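# (Editor's note) Concretely: an OpenMPI/horovodrun launch exports OMPI_COMM_WORLD_RANK and
# OMPI_COMM_WORLD_SIZE, while SLURM's PMI exports PMI_RANK and PMI_SIZE; this helper returns
# True when either pair is fully present in the environment.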
57 | ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] 58 | pmi_vars = ["PMI_RANK", "PMI_SIZE"] 59 | if all([var in os.environ for var in ompi_vars]) or all([var in os.environ for var in pmi_vars]): 60 | return True 61 | else: 62 | return False 63 | 64 | 65 | def is_using_distributed(): 66 | if 'WORLD_SIZE' in os.environ: 67 | return int(os.environ['WORLD_SIZE']) > 1 68 | if 'SLURM_NTASKS' in os.environ: 69 | return int(os.environ['SLURM_NTASKS']) > 1 70 | return False 71 | 72 | 73 | def world_info_from_env(): 74 | local_rank = 0 75 | for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'): 76 | if v in os.environ: 77 | local_rank = int(os.environ[v]) 78 | break 79 | global_rank = 0 80 | for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'): 81 | if v in os.environ: 82 | global_rank = int(os.environ[v]) 83 | break 84 | world_size = 1 85 | for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'): 86 | if v in os.environ: 87 | world_size = int(os.environ[v]) 88 | break 89 | 90 | return local_rank, global_rank, world_size 91 | 92 | 93 | def init_distributed_device(args): 94 | # Distributed training = training on more than one GPU. 95 | # Works in both single and multi-node scenarios. 96 | args.distributed = False 97 | args.world_size = 1 98 | args.rank = 0 # global rank 99 | args.local_rank = 0 100 | result = init_distributed_device_so( 101 | device=getattr(args, 'device', 'cuda'), 102 | dist_backend=getattr(args, 'dist_backend', None), 103 | dist_url=getattr(args, 'dist_url', None), 104 | horovod=getattr(args, 'horovod', False), 105 | no_set_device_rank=getattr(args, 'no_set_device_rank', False), 106 | ) 107 | args.device = result['device'] 108 | args.world_size = result['world_size'] 109 | args.rank = result['global_rank'] 110 | args.local_rank = result['local_rank'] 111 | args.distributed = result['distributed'] 112 | device = torch.device(args.device) 113 | return device 114 | 115 | 116 | def init_distributed_device_so( 117 | device: str = 'cuda', 118 | dist_backend: Optional[str] = None, 119 | dist_url: Optional[str] = None, 120 | horovod: bool = False, 121 | no_set_device_rank: bool = False, 122 | ): 123 | # Distributed training = training on more than one GPU. 124 | # Works in both single and multi-node scenarios. 
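# (Editor's note) Under torchrun (e.g. `torchrun --nproc_per_node=4 -m open_clip_train.main ...`,
# command shown for illustration only) WORLD_SIZE/RANK/LOCAL_RANK are already exported and picked
# up below; under SLURM the SLURM_* variables are copied into those names before init_process_group.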
125 | distributed = False 126 | world_size = 1 127 | global_rank = 0 128 | local_rank = 0 129 | device_type, *device_idx = device.split(':', maxsplit=1) 130 | is_avail, is_known = is_device_available(device_type) 131 | if not is_known: 132 | warnings.warn(f"Device {device} was not known and checked for availability, trying anyways.") 133 | elif not is_avail: 134 | warnings.warn(f"Device {device} was not available, falling back to CPU.") 135 | device_type = device = 'cpu' 136 | 137 | if horovod: 138 | import horovod.torch as hvd 139 | assert hvd is not None, "Horovod is not installed" 140 | hvd.init() 141 | local_rank = int(hvd.local_rank()) 142 | global_rank = hvd.rank() 143 | world_size = hvd.size() 144 | distributed = True 145 | elif is_using_distributed(): 146 | if dist_backend is None: 147 | dist_backends = { 148 | "cuda": "nccl", 149 | "hpu": "hccl", 150 | "npu": "hccl", 151 | "xpu": "ccl", 152 | } 153 | dist_backend = dist_backends.get(device_type, 'gloo') 154 | 155 | dist_url = dist_url or 'env://' 156 | 157 | if 'SLURM_PROCID' in os.environ: 158 | # DDP via SLURM 159 | local_rank, global_rank, world_size = world_info_from_env() 160 | # SLURM var -> torch.distributed vars in case needed 161 | os.environ['LOCAL_RANK'] = str(local_rank) 162 | os.environ['RANK'] = str(global_rank) 163 | os.environ['WORLD_SIZE'] = str(world_size) 164 | torch.distributed.init_process_group( 165 | backend=dist_backend, 166 | init_method=dist_url, 167 | world_size=world_size, 168 | rank=global_rank, 169 | ) 170 | else: 171 | # DDP via torchrun, torch.distributed.launch 172 | local_rank, _, _ = world_info_from_env() 173 | torch.distributed.init_process_group( 174 | backend=dist_backend, 175 | init_method=dist_url, 176 | ) 177 | world_size = torch.distributed.get_world_size() 178 | global_rank = torch.distributed.get_rank() 179 | distributed = True 180 | 181 | if distributed and not no_set_device_rank and device_type not in ('cpu', 'mps'): 182 | # Ignore manually specified device index in distributed mode and 183 | # override with resolved local rank, fewer headaches in most setups. 
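# (Editor's note) e.g. a process started with device='cuda:0' on local rank 3 ends up on 'cuda:3'.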
184 | if device_idx: 185 | warnings.warn(f'device index {device_idx[0]} removed from specified ({device}).') 186 | device = f'{device_type}:{local_rank}' 187 | set_device(device) 188 | 189 | return dict( 190 | device=device, 191 | global_rank=global_rank, 192 | local_rank=local_rank, 193 | world_size=world_size, 194 | distributed=distributed, 195 | ) 196 | 197 | 198 | def broadcast_object(args, obj, src=0): 199 | # broadcast a pickle-able python object from rank-0 to all ranks 200 | if args.horovod: 201 | return hvd.broadcast_object(obj, root_rank=src) 202 | else: 203 | if args.rank == src: 204 | objects = [obj] 205 | else: 206 | objects = [None] 207 | dist.broadcast_object_list(objects, src=src) 208 | return objects[0] 209 | 210 | 211 | def all_gather_object(args, obj, dst=0): 212 | # gather a pickle-able python object across all ranks 213 | if args.horovod: 214 | return hvd.allgather_object(obj) 215 | else: 216 | objects = [None for _ in range(args.world_size)] 217 | dist.all_gather_object(objects, obj) 218 | return objects 219 | -------------------------------------------------------------------------------- /open_clip_train/file_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import multiprocessing 4 | import subprocess 5 | import time 6 | import fsspec 7 | import torch 8 | from tqdm import tqdm 9 | 10 | def remote_sync_s3(local_dir, remote_dir): 11 | # skip epoch_latest which can change during sync. 12 | result = subprocess.run(["aws", "s3", "sync", local_dir, remote_dir, '--exclude', '*epoch_latest.pt'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) 13 | if result.returncode != 0: 14 | logging.error(f"Error: Failed to sync with S3 bucket {result.stderr.decode('utf-8')}") 15 | return False 16 | 17 | logging.info(f"Successfully synced with S3 bucket") 18 | return True 19 | 20 | def remote_sync_fsspec(local_dir, remote_dir): 21 | # FIXME currently this is slow and not recommended. Look into speeding up. 22 | a = fsspec.get_mapper(local_dir) 23 | b = fsspec.get_mapper(remote_dir) 24 | 25 | for k in a: 26 | # skip epoch_latest which can change during sync. 27 | if 'epoch_latest.pt' in k: 28 | continue 29 | 30 | logging.info(f'Attempting to sync {k}') 31 | if k in b and len(a[k]) == len(b[k]): 32 | logging.debug(f'Skipping remote sync for {k}.') 33 | continue 34 | 35 | try: 36 | logging.info(f'Successful sync for {k}.') 37 | b[k] = a[k] 38 | except Exception as e: 39 | logging.info(f'Error during remote sync for {k}: {e}') 40 | return False 41 | 42 | return True 43 | 44 | def remote_sync(local_dir, remote_dir, protocol): 45 | logging.info('Starting remote sync.') 46 | if protocol == 's3': 47 | return remote_sync_s3(local_dir, remote_dir) 48 | elif protocol == 'fsspec': 49 | return remote_sync_fsspec(local_dir, remote_dir) 50 | else: 51 | logging.error('Remote protocol not known') 52 | return False 53 | 54 | def keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol): 55 | while True: 56 | time.sleep(sync_every) 57 | remote_sync(local_dir, remote_dir, protocol) 58 | 59 | def start_sync_process(sync_every, local_dir, remote_dir, protocol): 60 | p = multiprocessing.Process(target=keep_running_remote_sync, args=(sync_every, local_dir, remote_dir, protocol)) 61 | return p 62 | 63 | # Note: we are not currently using this save function. 
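# (Editor's note) pt_save below opens file_path through fsspec but still hands the raw path to
# torch.save, so the opened handle `f` goes unused; pt_load does read through the handle.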
64 | def pt_save(pt_obj, file_path): 65 | of = fsspec.open(file_path, "wb") 66 | with of as f: 67 | torch.save(pt_obj, file_path) 68 | 69 | def pt_load(file_path, map_location=None): 70 | if file_path.startswith('s3'): 71 | logging.info('Loading remote checkpoint, which may take a bit.') 72 | of = fsspec.open(file_path, "rb") 73 | with of as f: 74 | out = torch.load(f, map_location=map_location) 75 | return out 76 | 77 | def check_exists(file_path): 78 | try: 79 | with fsspec.open(file_path): 80 | pass 81 | except FileNotFoundError: 82 | return False 83 | return True 84 | -------------------------------------------------------------------------------- /open_clip_train/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def setup_logging(log_file, level, include_host=False): 5 | if include_host: 6 | import socket 7 | hostname = socket.gethostname() 8 | formatter = logging.Formatter( 9 | f'%(asctime)s | {hostname} | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') 10 | else: 11 | formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') 12 | 13 | logging.root.setLevel(level) 14 | loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict] 15 | for logger in loggers: 16 | logger.setLevel(level) 17 | 18 | stream_handler = logging.StreamHandler() 19 | stream_handler.setFormatter(formatter) 20 | logging.root.addHandler(stream_handler) 21 | 22 | if log_file: 23 | file_handler = logging.FileHandler(filename=log_file) 24 | file_handler.setFormatter(formatter) 25 | logging.root.addHandler(file_handler) 26 | 27 | -------------------------------------------------------------------------------- /open_clip_train/precision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from contextlib import suppress 3 | from functools import partial 4 | 5 | 6 | def get_autocast(precision, device_type='cuda'): 7 | if precision =='amp': 8 | amp_dtype = torch.float16 9 | elif precision == 'amp_bfloat16' or precision == 'amp_bf16': 10 | amp_dtype = torch.bfloat16 11 | else: 12 | return suppress 13 | 14 | return partial(torch.amp.autocast, device_type=device_type, dtype=amp_dtype) -------------------------------------------------------------------------------- /open_clip_train/profiler.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import open_clip 5 | import pandas as pd 6 | from torch.utils.flop_counter import FlopCounterMode 7 | try: 8 | import fvcore 9 | except: 10 | fvcore = None 11 | 12 | parser = argparse.ArgumentParser(description='OpenCLIP Profiler') 13 | 14 | # benchmark specific args 15 | parser.add_argument('--model', metavar='NAME', default='', 16 | help='model(s) to profile') 17 | parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', 18 | help='Output csv file for results') 19 | parser.add_argument('--profiler', default='torch', type=str, choices=['torch', 'fvcore']) 20 | parser.add_argument('--batch-size', default=1, type=int, help='Batch size for profiling') 21 | 22 | 23 | def profile_fvcore( 24 | model, 25 | image_input_size=(3, 224, 224), 26 | text_input_size=(77,), 27 | batch_size=1, 28 | detailed=False, 29 | force_cpu=False 30 | ): 31 | if force_cpu: 32 | model = model.to('cpu') 33 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 34 | example_image_input = 
torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype) 35 | example_text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) 36 | fca = fvcore.nn.FlopCountAnalysis(model, (example_image_input, example_text_input)) 37 | aca = fvcore.nn.ActivationCountAnalysis(model, (example_image_input, example_text_input)) 38 | if detailed: 39 | fcs = fvcore.nn.flop_count_str(fca) 40 | print(fcs) 41 | return fca.total() / batch_size, aca.total() / batch_size 42 | 43 | 44 | def profile_fvcore_text( 45 | model, 46 | text_input_size=(77,), 47 | batch_size=1, 48 | detailed=False, 49 | force_cpu=False 50 | ): 51 | if force_cpu: 52 | model = model.to('cpu') 53 | device = next(model.parameters()).device 54 | example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) 55 | fca = fvcore.nn.FlopCountAnalysis(model, example_input) 56 | aca = fvcore.nn.ActivationCountAnalysis(model, example_input) 57 | if detailed: 58 | fcs = fvcore.nn.flop_count_str(fca) 59 | print(fcs) 60 | return fca.total() / batch_size, aca.total() / batch_size 61 | 62 | 63 | def profile_fvcore_image( 64 | model, 65 | image_input_size=(3, 224, 224), 66 | batch_size=1, 67 | detailed=False, 68 | force_cpu=False 69 | ): 70 | if force_cpu: 71 | model = model.to('cpu') 72 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 73 | example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype) 74 | fca = fvcore.nn.FlopCountAnalysis(model, example_input) 75 | aca = fvcore.nn.ActivationCountAnalysis(model, example_input) 76 | if detailed: 77 | fcs = fvcore.nn.flop_count_str(fca) 78 | print(fcs) 79 | return fca.total() / batch_size, aca.total() / batch_size 80 | 81 | 82 | def profile_torch_image(model, image_input_size, batch_size=1, force_cpu=False): 83 | """Profile the image encoder using torch.utils.flop_counter""" 84 | if force_cpu: 85 | model = model.to('cpu') 86 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 87 | example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype) 88 | 89 | flop_counter = FlopCounterMode() 90 | with flop_counter: 91 | model(example_input) 92 | total_flops = sum(flop_counter.get_flop_counts()['Global'].values()) 93 | return total_flops / batch_size 94 | 95 | 96 | def profile_torch_text(model, text_input_size, batch_size=1, force_cpu=False): 97 | """Profile the text encoder using torch.utils.flop_counter""" 98 | if force_cpu: 99 | model = model.to('cpu') 100 | device = next(model.parameters()).device 101 | example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) 102 | 103 | flop_counter = FlopCounterMode() 104 | with flop_counter: 105 | model(example_input) 106 | total_flops = sum(flop_counter.get_flop_counts()['Global'].values()) 107 | return total_flops / batch_size 108 | 109 | 110 | def profile_torch(model, text_input_size, image_input_size, batch_size=1, force_cpu=False): 111 | """Profile the full model using torch.utils.flop_counter""" 112 | if force_cpu: 113 | model = model.to('cpu') 114 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 115 | image_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype) 116 | text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) 117 | 118 | flop_counter = FlopCounterMode() 119 | with flop_counter: 120 | model(image_input, text_input) 121 | total_flops = 
sum(flop_counter.get_flop_counts()['Global'].values()) 122 | return total_flops / batch_size 123 | 124 | 125 | def count_params(model): 126 | return sum(m.numel() for m in model.parameters()) 127 | 128 | def profile_model(model_name, batch_size=1, profiler='torch', device="cuda"): 129 | assert profiler in ['torch', 'fvcore'], 'Only torch and fvcore profilers are supported' 130 | if profiler == 'fvcore': 131 | assert fvcore is not None, 'Please install fvcore.' 132 | model = open_clip.create_model(model_name, force_custom_text=True, pretrained_hf=False) 133 | model.eval() 134 | 135 | if torch.cuda.is_available(): 136 | model = model.cuda() 137 | elif device == "npu" and torch.npu.is_available(): 138 | model = model.npu() 139 | 140 | if isinstance(model.visual.image_size, (tuple, list)): 141 | image_input_size = (3,) + tuple(model.visual.image_size[-2:]) 142 | else: 143 | image_input_size = (3, model.visual.image_size, model.visual.image_size) 144 | 145 | text_input_size = (77,) 146 | if hasattr(model, 'context_length') and model.context_length: 147 | text_input_size = (model.context_length,) 148 | 149 | results = {} 150 | results['model'] = model_name 151 | results['image_size'] = image_input_size[1] 152 | 153 | model_cfg = open_clip.get_model_config(model_name) 154 | if model_cfg: 155 | vision_cfg = open_clip.CLIPVisionCfg(**model_cfg['vision_cfg']) 156 | text_cfg = open_clip.CLIPTextCfg(**model_cfg['text_cfg']) 157 | results['image_width'] = int(vision_cfg.width) 158 | results['text_width'] = int(text_cfg.width) 159 | results['embed_dim'] = int(model_cfg['embed_dim']) 160 | else: 161 | results['image_width'] = 0 162 | results['text_width'] = 0 163 | results['embed_dim'] = 0 164 | 165 | retries = 2 166 | while retries: 167 | retries -= 1 168 | try: 169 | results['mparams'] = round(count_params(model) / 1e6, 2) 170 | results['image_mparams'] = round(count_params(model.visual) / 1e6, 2) 171 | results['text_mparams'] = round(count_params(model.text) / 1e6, 2) 172 | 173 | if profiler == 'fvcore': 174 | macs, acts = profile_fvcore( 175 | model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size) 176 | 177 | image_macs, image_acts = profile_fvcore_image( 178 | model.visual, image_input_size=image_input_size, force_cpu=not retries, batch_size=batch_size) 179 | 180 | text_macs, text_acts = profile_fvcore_text( 181 | model.text, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size) 182 | 183 | results['gmacs'] = round(macs / 1e9, 2) 184 | results['macts'] = round(acts / 1e6, 2) 185 | 186 | results['image_gmacs'] = round(image_macs / 1e9, 2) 187 | results['image_macts'] = round(image_acts / 1e6, 2) 188 | 189 | results['text_gmacs'] = round(text_macs / 1e9, 2) 190 | results['text_macts'] = round(text_acts / 1e6, 2) 191 | elif profiler == 'torch': 192 | image_flops = profile_torch_image( 193 | model.visual, image_input_size=image_input_size, force_cpu=not retries, batch_size=batch_size) 194 | text_flops = profile_torch_text( 195 | model.text, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size) 196 | total_flops = profile_torch( 197 | model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size) 198 | 199 | results['gflops'] = round(total_flops / 1e9, 2) 200 | results['image_gflops'] = round(image_flops / 1e9, 2) 201 | results['text_gflops'] = round(text_flops / 1e9, 2) 202 | 203 | except RuntimeError as e: 204 | pass 205 | return 
results 206 | 207 | 208 | def main(): 209 | args = parser.parse_args() 210 | 211 | # FIXME accept a text file name to allow lists of models in txt/csv 212 | if args.model == 'all': 213 | parsed_model = open_clip.list_models() 214 | else: 215 | parsed_model = args.model.split(',') 216 | 217 | results = [] 218 | models_with_errors = [] 219 | for m in parsed_model: 220 | print('='*100) 221 | print(f'Profiling {m}') 222 | try: 223 | row = profile_model(m, batch_size=args.batch_size, profiler=args.profiler, device=args.device) 224 | results.append(row) 225 | except Exception as e: 226 | print(f'Error profiling {m}: {e}') 227 | import traceback 228 | traceback.print_exc() 229 | models_with_errors.append(m) 230 | 231 | df = pd.DataFrame(results, columns=results[0].keys()) 232 | 233 | if 'gmacs' in df.columns: 234 | df = df.sort_values(by=['gmacs', 'mparams', 'model']) 235 | else: 236 | df = df.sort_values(by=['gflops', 'mparams', 'model']) 237 | 238 | print('='*100) 239 | print('Done.') 240 | print(df) 241 | if args.results_file: 242 | df.to_csv(args.results_file, index=False) 243 | 244 | if models_with_errors: 245 | print('Models with errors:', models_with_errors) 246 | 247 | 248 | if __name__ == '__main__': 249 | main() 250 | -------------------------------------------------------------------------------- /open_clip_train/scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def assign_learning_rate(optimizer, new_lr): 5 | for param_group in optimizer.param_groups: 6 | param_group["lr"] = new_lr 7 | 8 | 9 | def _warmup_lr(base_lr, warmup_length, step): 10 | return base_lr * (step + 1) / warmup_length 11 | 12 | 13 | def const_lr(optimizer, base_lr, warmup_length, steps): 14 | def _lr_adjuster(step): 15 | if step < warmup_length: 16 | lr = _warmup_lr(base_lr, warmup_length, step) 17 | else: 18 | lr = base_lr 19 | assign_learning_rate(optimizer, lr) 20 | return lr 21 | 22 | return _lr_adjuster 23 | 24 | 25 | def const_lr_cooldown(optimizer, base_lr, warmup_length, steps, cooldown_steps, cooldown_power=1.0, cooldown_end_lr=0.): 26 | def _lr_adjuster(step): 27 | start_cooldown_step = steps - cooldown_steps 28 | if step < warmup_length: 29 | lr = _warmup_lr(base_lr, warmup_length, step) 30 | else: 31 | if step < start_cooldown_step: 32 | lr = base_lr 33 | else: 34 | e = step - start_cooldown_step 35 | es = steps - start_cooldown_step 36 | # linear decay if power == 1; polynomial decay otherwise; 37 | decay = (1 - (e / es)) ** cooldown_power 38 | lr = decay * (base_lr - cooldown_end_lr) + cooldown_end_lr 39 | assign_learning_rate(optimizer, lr) 40 | return lr 41 | 42 | return _lr_adjuster 43 | 44 | 45 | def cosine_lr(optimizer, base_lr, warmup_length, steps): 46 | def _lr_adjuster(step): 47 | if step < warmup_length: 48 | lr = _warmup_lr(base_lr, warmup_length, step) 49 | else: 50 | e = step - warmup_length 51 | es = steps - warmup_length 52 | lr = 0.5 * (1 + math.cos(math.pi * e / es)) * base_lr 53 | assign_learning_rate(optimizer, lr) 54 | return lr 55 | 56 | return _lr_adjuster 57 | 58 | -------------------------------------------------------------------------------- /open_clip_train/zero_shot.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | from tqdm import tqdm 5 | 6 | from open_clip import get_input_dtype, get_tokenizer, build_zero_shot_classifier, \ 7 | IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES 8 | from open_clip_train.precision import 
get_autocast 9 | 10 | 11 | def accuracy(output, target, topk=(1,)): 12 | pred = output.topk(max(topk), 1, True, True)[1].t() 13 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 14 | return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] 15 | 16 | 17 | def run(model, classifier, dataloader, args): 18 | device = torch.device(args.device) 19 | autocast = get_autocast(args.precision, device_type=device.type) 20 | input_dtype = get_input_dtype(args.precision) 21 | 22 | with torch.inference_mode(): 23 | top1, top5, n = 0., 0., 0. 24 | for images, target in tqdm(dataloader, unit_scale=args.batch_size): 25 | images = images.to(device=device, dtype=input_dtype) 26 | target = target.to(device) 27 | 28 | with autocast(): 29 | # predict 30 | output = model(image=images) 31 | image_features = output['image_features'] if isinstance(output, dict) else output[0] 32 | logits = 100. * image_features @ classifier 33 | 34 | # measure accuracy 35 | acc1, acc5 = accuracy(logits, target, topk=(1, 5)) 36 | top1 += acc1 37 | top5 += acc5 38 | n += images.size(0) 39 | 40 | top1 = (top1 / n) 41 | top5 = (top5 / n) 42 | return top1, top5 43 | 44 | 45 | def zero_shot_eval(model, data, epoch, args, tokenizer=None): 46 | if 'imagenet-val' not in data and 'imagenet-v2' not in data: 47 | return {} 48 | if args.zeroshot_frequency == 0: 49 | return {} 50 | if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: 51 | return {} 52 | if args.distributed and not args.horovod: 53 | model = model.module 54 | 55 | logging.info('Starting zero-shot imagenet.') 56 | if tokenizer is None: 57 | tokenizer = get_tokenizer(args.model) 58 | 59 | logging.info('Building zero-shot classifier') 60 | device = torch.device(args.device) 61 | autocast = get_autocast(args.precision, device_type=device.type) 62 | with autocast(): 63 | classifier = build_zero_shot_classifier( 64 | model, 65 | tokenizer=tokenizer, 66 | classnames=IMAGENET_CLASSNAMES, 67 | templates=OPENAI_IMAGENET_TEMPLATES, 68 | num_classes_per_batch=10, 69 | device=device, 70 | use_tqdm=True, 71 | ) 72 | 73 | logging.info('Using classifier') 74 | results = {} 75 | if 'imagenet-val' in data: 76 | top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args) 77 | results['imagenet-zeroshot-val-top1'] = top1 78 | results['imagenet-zeroshot-val-top5'] = top5 79 | if 'imagenet-v2' in data: 80 | top1, top5 = run(model, classifier, data['imagenet-v2'].dataloader, args) 81 | results['imagenetv2-zeroshot-val-top1'] = top1 82 | results['imagenetv2-zeroshot-val-top5'] = top5 83 | 84 | logging.info('Finished zero-shot imagenet.') 85 | 86 | return results 87 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.9.0 2 | torchvision 3 | regex 4 | ftfy 5 | tqdm 6 | huggingface_hub 7 | safetensors 8 | timm 9 | transformers --------------------------------------------------------------------------------
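As a closing illustration of how the training utilities above fit together, here is a minimal
sketch wiring cosine_lr (open_clip_train/scheduler.py) and get_autocast (open_clip_train/precision.py)
into a toy loop. The model, optimizer, and step counts are stand-ins chosen for the example, not
values taken from this repository's training scripts:

    import torch

    from open_clip_train.precision import get_autocast
    from open_clip_train.scheduler import cosine_lr

    # toy stand-in model and optimizer (illustrative only)
    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    total_steps, warmup_steps = 1000, 100
    scheduler = cosine_lr(optimizer, base_lr=1e-3, warmup_length=warmup_steps, steps=total_steps)
    autocast = get_autocast('amp_bf16', device_type='cpu')  # bf16 autocast on CPU; pass 'amp' for fp16 on CUDA

    for step in range(total_steps):
        scheduler(step)  # applies warmup then cosine decay to every param group's lr
        with autocast():
            loss = model(torch.randn(4, 8)).square().mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()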