├── .DS_Store
├── LICENSE
├── README.md
├── docs
    ├── .DS_Store
    ├── README.md
    ├── index.html
    ├── resources
    │   ├── .DS_Store
    │   ├── LLaVA.png
    │   ├── ar.svg
    │   ├── block_mask.jpg
    │   ├── first_mask.jpg
    │   ├── gr.svg
    │   ├── hf_dataset.jpg
    │   ├── hg.svg
    │   ├── icon.png
    │   ├── mask_strategy.jpg
    │   ├── method.jpg
    │   ├── random_mask.jpg
    │   ├── retrieval.png
    │   ├── sota.png
    │   ├── subcaption_mask.jpg
    │   └── tw.png
    └── static
        ├── css
        │   ├── bulma-carousel.min.css
        │   ├── bulma-slider.min.css
        │   ├── bulma.css.map.txt
        │   ├── bulma.min.css
        │   ├── fontawesome.all.min.css
        │   └── index.css
        └── js
            ├── bulma-carousel.js
            ├── bulma-carousel.min.js
            ├── bulma-slider.js
            ├── bulma-slider.min.js
            ├── fontawesome.all.min.js
            └── index.js
├── inference.py
├── open_clip
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-311.pyc
    │   ├── coca_model.cpython-311.pyc
    │   ├── constants.cpython-311.pyc
    │   ├── convert.cpython-311.pyc
    │   ├── factory.cpython-311.pyc
    │   ├── hf_configs.cpython-311.pyc
    │   ├── hf_model.cpython-311.pyc
    │   ├── loss.cpython-311.pyc
    │   ├── model.cpython-311.pyc
    │   ├── modified_resnet.cpython-311.pyc
    │   ├── openai.cpython-311.pyc
    │   ├── pos_embed.cpython-311.pyc
    │   ├── pretrained.cpython-311.pyc
    │   ├── push_to_hf_hub.cpython-311.pyc
    │   ├── timm_model.cpython-311.pyc
    │   ├── tokenizer.cpython-311.pyc
    │   ├── transform.cpython-311.pyc
    │   ├── transformer.cpython-311.pyc
    │   ├── utils.cpython-311.pyc
    │   ├── version.cpython-311.pyc
    │   ├── zero_shot_classifier.cpython-311.pyc
    │   └── zero_shot_metadata.cpython-311.pyc
    ├── bpe_simple_vocab_16e6.txt.gz
    ├── coca_model.py
    ├── constants.py
    ├── convert.py
    ├── factory.py
    ├── hf_configs.py
    ├── hf_model.py
    ├── loss.py
    ├── model.py
    ├── model_configs
    │   ├── EVA01-g-14-plus.json
    │   ├── EVA01-g-14.json
    │   ├── EVA02-B-16.json
    │   ├── EVA02-E-14-plus.json
    │   ├── EVA02-E-14.json
    │   ├── EVA02-L-14-336.json
    │   ├── EVA02-L-14.json
    │   ├── MobileCLIP-B.json
    │   ├── MobileCLIP-S1.json
    │   ├── MobileCLIP-S2.json
    │   ├── RN101-quickgelu.json
    │   ├── RN101.json
    │   ├── RN50-quickgelu.json
    │   ├── RN50.json
    │   ├── RN50x16-quickgelu.json
    │   ├── RN50x16.json
    │   ├── RN50x4-quickgelu.json
    │   ├── RN50x4.json
    │   ├── RN50x64-quickgelu.json
    │   ├── RN50x64.json
    │   ├── ViT-B-16-SigLIP-256.json
    │   ├── ViT-B-16-SigLIP-384.json
    │   ├── ViT-B-16-SigLIP-512.json
    │   ├── ViT-B-16-SigLIP-i18n-256.json
    │   ├── ViT-B-16-SigLIP.json
    │   ├── ViT-B-16-SigLIP2-256.json
    │   ├── ViT-B-16-SigLIP2-384.json
    │   ├── ViT-B-16-SigLIP2-512.json
    │   ├── ViT-B-16-SigLIP2.json
    │   ├── ViT-B-16-plus-240.json
    │   ├── ViT-B-16-plus.json
    │   ├── ViT-B-16-quickgelu.json
    │   ├── ViT-B-16.json
    │   ├── ViT-B-32-256.json
    │   ├── ViT-B-32-SigLIP2-256.json
    │   ├── ViT-B-32-plus-256.json
    │   ├── ViT-B-32-quickgelu.json
    │   ├── ViT-B-32.json
    │   ├── ViT-H-14-378-quickgelu.json
    │   ├── ViT-H-14-378.json
    │   ├── ViT-H-14-CLIPA-336.json
    │   ├── ViT-H-14-CLIPA.json
    │   ├── ViT-H-14-CLIPS-224.json
    │   ├── ViT-H-14-quickgelu.json
    │   ├── ViT-H-14.json
    │   ├── ViT-H-16.json
    │   ├── ViT-L-14-280.json
    │   ├── ViT-L-14-336-quickgelu.json
    │   ├── ViT-L-14-336.json
    │   ├── ViT-L-14-CLIPA-336.json
    │   ├── ViT-L-14-CLIPA.json
    │   ├── ViT-L-14-CLIPS-224.json
    │   ├── ViT-L-14-CLIPS-336.json
    │   ├── ViT-L-14-quickgelu.json
    │   ├── ViT-L-14.json
    │   ├── ViT-L-16-320.json
    │   ├── ViT-L-16-SigLIP-256.json
    │   ├── ViT-L-16-SigLIP-384.json
    │   ├── ViT-L-16-SigLIP2-256.json
    │   ├── ViT-L-16-SigLIP2-384.json
    │   ├── ViT-L-16-SigLIP2-512.json
    │   ├── ViT-L-16.json
    │   ├── ViT-M-16-alt.json
    │   ├── ViT-M-16.json
    │   ├── ViT-M-32-alt.json
    │   ├── ViT-M-32.json
    │   ├── ViT-S-16-alt.json
    │   ├── ViT-S-16.json
    │   ├── ViT-S-32-alt.json
    │   ├── ViT-S-32.json
    │   ├── ViT-SO400M-14-SigLIP-378.json
    │   ├── ViT-SO400M-14-SigLIP-384.json
    │   ├── ViT-SO400M-14-SigLIP.json
    │   ├── ViT-SO400M-14-SigLIP2-378.json
    │   ├── ViT-SO400M-14-SigLIP2.json
    │   ├── ViT-SO400M-16-SigLIP-i18n-256.json
    │   ├── ViT-SO400M-16-SigLIP2-256.json
    │   ├── ViT-SO400M-16-SigLIP2-384.json
    │   ├── ViT-SO400M-16-SigLIP2-512.json
    │   ├── ViT-bigG-14-CLIPA-336.json
    │   ├── ViT-bigG-14-CLIPA.json
    │   ├── ViT-bigG-14-quickgelu.json
    │   ├── ViT-bigG-14.json
    │   ├── ViT-e-14.json
    │   ├── ViT-g-14.json
    │   ├── ViT-gopt-16-SigLIP2-256.json
    │   ├── ViT-gopt-16-SigLIP2-384.json
    │   ├── ViTamin-B-LTT.json
    │   ├── ViTamin-B.json
    │   ├── ViTamin-L-256.json
    │   ├── ViTamin-L-336.json
    │   ├── ViTamin-L-384.json
    │   ├── ViTamin-L.json
    │   ├── ViTamin-L2-256.json
    │   ├── ViTamin-L2-336.json
    │   ├── ViTamin-L2-384.json
    │   ├── ViTamin-L2.json
    │   ├── ViTamin-S-LTT.json
    │   ├── ViTamin-S.json
    │   ├── ViTamin-XL-256.json
    │   ├── ViTamin-XL-336.json
    │   ├── ViTamin-XL-384.json
    │   ├── coca_ViT-B-32.json
    │   ├── coca_ViT-L-14.json
    │   ├── coca_base.json
    │   ├── coca_roberta-ViT-B-32.json
    │   ├── convnext_base.json
    │   ├── convnext_base_w.json
    │   ├── convnext_base_w_320.json
    │   ├── convnext_large.json
    │   ├── convnext_large_d.json
    │   ├── convnext_large_d_320.json
    │   ├── convnext_small.json
    │   ├── convnext_tiny.json
    │   ├── convnext_xlarge.json
    │   ├── convnext_xxlarge.json
    │   ├── convnext_xxlarge_320.json
    │   ├── mt5-base-ViT-B-32.json
    │   ├── mt5-xl-ViT-H-14.json
    │   ├── nllb-clip-base-siglip.json
    │   ├── nllb-clip-base.json
    │   ├── nllb-clip-large-siglip.json
    │   ├── nllb-clip-large.json
    │   ├── roberta-ViT-B-32.json
    │   ├── swin_base_patch4_window7_224.json
    │   ├── vit_medium_patch16_gap_256.json
    │   ├── vit_relpos_medium_patch16_cls_224.json
    │   ├── xlm-roberta-base-ViT-B-32.json
    │   └── xlm-roberta-large-ViT-H-14.json
    ├── modified_resnet.py
    ├── openai.py
    ├── pos_embed.py
    ├── pretrained.py
    ├── push_to_hf_hub.py
    ├── timm_model.py
    ├── tokenizer.py
    ├── transform.py
    ├── transformer.py
    ├── utils.py
    ├── version.py
    ├── zero_shot_classifier.py
    └── zero_shot_metadata.py
├── open_clip_train
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-311.pyc
    │   ├── data.cpython-311.pyc
    │   ├── distributed.cpython-311.pyc
    │   ├── file_utils.cpython-311.pyc
    │   ├── logger.cpython-311.pyc
    │   ├── params.cpython-311.pyc
    │   ├── precision.cpython-311.pyc
    │   ├── scheduler.cpython-311.pyc
    │   ├── train.cpython-311.pyc
    │   └── zero_shot.cpython-311.pyc
    ├── data.py
    ├── distributed.py
    ├── file_utils.py
    ├── logger.py
    ├── main.py
    ├── params.py
    ├── precision.py
    ├── profiler.py
    ├── scheduler.py
    ├── train.py
    └── zero_shot.py
├── requirements.txt
└── vocab.txt
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/.DS_Store
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2024 VLAA@UCSC
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | # **CLIPS**
  2 | 
  3 | **Official implementation of the paper "[_CLIPS: An Enhanced CLIP Framework for Learning with Synthetic Captions_](https://arxiv.org/abs/2411.16828)".**
  4 | 
  5 | 
  6 | 
  7 | 
  8 | Previous works show that noisy, web-crawled image-text pairs can limit vision-language pretraining such as CLIP, and they propose learning with synthetic captions as a promising alternative. Our work continues this effort, introducing two simple yet effective designs to better leverage richly described synthetic captions:
  9 | 
 10 | 1. Observing a strong inverse effect with synthetic captions, we feed only **partial synthetic captions** to the text encoder, achieving significantly better performance (see the sketch below).
 11 | 2. We incorporate an **autoregressive captioner** that mimics the recaptioning process, predicting full-length synthetic captions conditioned on the image and original web-crawled captions.
 12 | 
 13 | Our method achieves **state-of-the-art (SOTA)** results in zero-shot image-text retrieval on MSCOCO and Flickr30K, while enhancing the visual capability of LLaVA.
 14 | 
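The two designs translate into a joint training objective: a CLIP-style contrastive loss computed against the *partial* synthetic caption, plus an autoregressive captioning loss for reconstructing the *full-length* synthetic caption. Below is a minimal, hypothetical sketch of such a combined loss; the function name, signature, and the equal weighting of the two terms are illustrative assumptions, not the repo's actual training code (the repo's losses live in `open_clip/loss.py`).

```python
import torch
import torch.nn.functional as F

def clips_style_loss(image_feats, text_feats, caption_logits, caption_targets,
                     logit_scale, pad_id=0):
    """Hypothetical sketch: contrastive loss on (image, partial synthetic caption)
    pairs plus a captioning loss on the full-length synthetic caption."""
    # CLIP-style symmetric contrastive loss over the batch.
    image_feats = F.normalize(image_feats, dim=-1)
    text_feats = F.normalize(text_feats, dim=-1)
    logits = logit_scale * image_feats @ text_feats.T
    labels = torch.arange(logits.shape[0], device=logits.device)
    contrastive = (F.cross_entropy(logits, labels) +
                   F.cross_entropy(logits.T, labels)) / 2

    # Token-level cross-entropy for the autoregressive captioner's predictions
    # of the full synthetic caption (padding positions ignored).
    captioning = F.cross_entropy(caption_logits.flatten(0, 1),
                                 caption_targets.flatten(),
                                 ignore_index=pad_id)
    return contrastive + captioning
```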
 15 | ---
 16 | 
 17 | ## **Links**
 18 | - [📄 Paper (arXiv)](https://arxiv.org/abs/2411.16828)  
 19 | - [🤗 Pretrained Model on HuggingFace](https://huggingface.co/UCSC-VLAA/ViT-L-14-CLIPS-Recap-DataComp-1B)  
 20 | - [🌐 Project Page](https://ucsc-vlaa.github.io/CLIPS/)
 21 | 
 22 | ---
 23 | 
 24 | ## **Key Results**
 25 | 
 26 | ### **Inverse Effect with Synthetic Captions**
 27 | 
 28 | 
 29 | Visualization of four different token-reduction strategies. Each strategy improves the model's learning efficiency on synthetic captions to a varying degree; among them, the sub-caption and block masks perform best.
 30 | 
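For reference, here is a minimal sketch of how two of these strategies (sub-caption and block mask) could be applied to a synthetic caption before it is fed to the text encoder. The helper names and exact sampling choices are illustrative assumptions, not the repo's actual implementation.

```python
import random

def sub_caption(caption: str, sep: str = '.') -> str:
    """Keep a single, randomly chosen sub-caption (sentence) of the full recaption."""
    parts = [p.strip() for p in caption.split(sep) if p.strip()]
    return random.choice(parts) if parts else caption

def block_mask(tokens: list, keep_len: int) -> list:
    """Keep one contiguous block of `keep_len` tokens from the tokenized caption."""
    if len(tokens) <= keep_len:
        return tokens
    start = random.randint(0, len(tokens) - keep_len)
    return tokens[start:start + keep_len]
```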
 31 | ---
 32 | 
 33 | ### **Zero-Shot Cross-Modal Retrieval**
 34 | 
 35 | 
 36 | Our method consistently achieves superior performance across all benchmarks and model sizes, yielding significant improvements over the baselines.
 37 | 
 38 | ---
 39 | 
 40 | ### **Comparison with State-of-the-Art Methods**
 41 | 
 42 | 
 43 | With increased computational resources and scaling, our best model further achieves 76.4% and 96.6% R@1 text retrieval performance on MSCOCO and Flickr30K respectively, and 57.2% and 83.9% R@1 image retrieval performance on the same datasets, setting new state-of-the-art (SOTA) results.
 44 | 
 45 | ---
 46 | 
 47 | ### **CLIPS in LLaVA**
 48 | 
 49 | 
 50 | Replacing OpenAI-CLIP with **CLIPS** significantly boosts LLaVA's performance across various benchmarks.
 51 | 
 52 | ---
 53 | 
 54 | ## **Model Zoo**
 55 | 
 56 | | Model          | Link                                                                                     |
 57 | |----------------|------------------------------------------------------------------------------------------|
 58 | | CLIPS-Large-14-224 | [🤗 HuggingFace Model](https://huggingface.co/UCSC-VLAA/ViT-L-14-CLIPS-224-Recap-DataComp-1B) |
 59 | | CLIPS-Large-14-336 | [🤗 HuggingFace Model](https://huggingface.co/UCSC-VLAA/ViT-L-14-CLIPS-Recap-DataComp-1B) |
 60 | | CLIPS-Huge-14-224  | [🤗 HuggingFace Model](https://huggingface.co/UCSC-VLAA/ViT-H-14-CLIPS-224-Recap-DataComp-1B) |
 61 | | CLIPS-Huge-14-336  | Coming Soon...                                                                          |
 62 | 
 63 | ## **Model Usage**
 64 | ### **Environment**
 65 | Install dependencies:
 66 | ```bash
 67 | pip3 install -r requirements.txt
 68 | ```
 69 | ### **With OpenCLIP**
 70 | ```python
 71 | import torch
 72 | import torch.nn.functional as F
 73 | from urllib.request import urlopen
 74 | from PIL import Image
 75 | from open_clip import create_model_from_pretrained, get_tokenizer
 76 | 
 77 | model, preprocess = create_model_from_pretrained('hf-hub:UCSC-VLAA/ViT-L-14-CLIPS-Recap-DataComp-1B')
 78 | tokenizer = get_tokenizer('hf-hub:UCSC-VLAA/ViT-L-14-CLIPS-Recap-DataComp-1B')
 79 | 
 80 | image = Image.open(urlopen(
 81 |     'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
 82 | ))
 83 | image = preprocess(image).unsqueeze(0)
 84 | 
 85 | text = tokenizer(["a diagram", "a dog", "a cat", "a beignet"], context_length=model.context_length)
 86 | 
 87 | with torch.no_grad(), torch.cuda.amp.autocast():
 88 |     image_features = model.encode_image(image)
 89 |     text_features = model.encode_text(text)
 90 |     image_features = F.normalize(image_features, dim=-1)
 91 |     text_features = F.normalize(text_features, dim=-1)
 92 | 
 93 |     text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
 94 | 
 95 | print("Label probs:", text_probs)  # prints: [[0., 0., 0., 1.0]]
 96 | ```
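The other Model Zoo entries load the same way; only the hub id changes. For example (assuming the Huge model follows the same loading path):

```python
from open_clip import create_model_from_pretrained, get_tokenizer

# Hub id taken from the Model Zoo table above.
model, preprocess = create_model_from_pretrained('hf-hub:UCSC-VLAA/ViT-H-14-CLIPS-224-Recap-DataComp-1B')
tokenizer = get_tokenizer('hf-hub:UCSC-VLAA/ViT-H-14-CLIPS-224-Recap-DataComp-1B')
```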
 97 | #### Note: We made modifications to the tokenizer implementation in open_clip/tokenizer.py.
 98 | 
 99 | ## Acknowledgement
100 | 
101 | This PyTorch repo is built on [OpenCLIP](https://github.com/mlfoundations/open_clip).
102 | Many thanks to the open-source community for their awesome work!
103 | 
104 | We would like to thank the TPU Research Cloud (TRC) program, the Google Cloud Research Credits program, and the AWS Cloud Credit for Research program for supporting our computing needs.
105 | 
106 | ---
107 | 
108 | ## **Citation**
109 | 
110 | If you use our work, please cite it:
111 | 
112 | ```bibtex
113 | @article{liu2024clips,
114 |   title={CLIPS: An Enhanced CLIP Framework for Learning with Synthetic Captions},
115 |   author={Liu, Yanqing and Li, Xianhang and Wang, Zeyu and Zhao, Bingchen and Xie, Cihang},
116 |   journal={arXiv preprint arXiv:2411.16828},
117 |   year={2024}
118 | }
119 | ```
--------------------------------------------------------------------------------
/docs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/docs/.DS_Store
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # CLIPS
2 | [https://ucsc-vlaa.github.io/CLIPS/](https://ucsc-vlaa.github.io/CLIPS/)
3 | 
--------------------------------------------------------------------------------
/docs/resources/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/docs/resources/.DS_Store
--------------------------------------------------------------------------------
/docs/resources/LLaVA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/docs/resources/LLaVA.png
--------------------------------------------------------------------------------
/docs/resources/ar.svg:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/docs/resources/block_mask.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/docs/resources/block_mask.jpg
--------------------------------------------------------------------------------
/docs/resources/first_mask.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/docs/resources/first_mask.jpg
--------------------------------------------------------------------------------
/docs/resources/gr.svg:
--------------------------------------------------------------------------------
 1 | 
20 | 
--------------------------------------------------------------------------------
/docs/resources/hf_dataset.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/docs/resources/hf_dataset.jpg
--------------------------------------------------------------------------------
/docs/resources/icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/docs/resources/icon.png
--------------------------------------------------------------------------------
/docs/resources/mask_strategy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/docs/resources/mask_strategy.jpg
--------------------------------------------------------------------------------
/docs/resources/method.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/docs/resources/method.jpg
--------------------------------------------------------------------------------
/docs/resources/random_mask.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/docs/resources/random_mask.jpg
--------------------------------------------------------------------------------
/docs/resources/retrieval.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/docs/resources/retrieval.png
--------------------------------------------------------------------------------
/docs/resources/sota.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/docs/resources/sota.png
--------------------------------------------------------------------------------
/docs/resources/subcaption_mask.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/docs/resources/subcaption_mask.jpg
--------------------------------------------------------------------------------
/docs/resources/tw.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/docs/resources/tw.png
--------------------------------------------------------------------------------
/docs/static/css/bulma-carousel.min.css:
--------------------------------------------------------------------------------
  1 | @-webkit-keyframes spinAround {
  2 |     from {
  3 |         -webkit-transform: rotate(0);
  4 |         transform: rotate(0)
  5 |     }
  6 | 
  7 |     to {
  8 |         -webkit-transform: rotate(359deg);
  9 |         transform: rotate(359deg)
 10 |     }
 11 | }
 12 | 
 13 | @keyframes spinAround {
 14 |     from {
 15 |         -webkit-transform: rotate(0);
 16 |         transform: rotate(0)
 17 |     }
 18 | 
 19 |     to {
 20 |         -webkit-transform: rotate(359deg);
 21 |         transform: rotate(359deg)
 22 |     }
 23 | }
 24 | 
 25 | .slider {
 26 |     position: relative;
 27 |     width: 100%
 28 | }
 29 | 
 30 | .slider-container {
 31 |     display: flex;
 32 |     flex-wrap: nowrap;
 33 |     flex-direction: row;
 34 |     overflow: hidden;
 35 |     -webkit-transform: translate3d(0, 0, 0);
 36 |     transform: translate3d(0, 0, 0);
 37 |     min-height: 100%
 38 | }
 39 | 
 40 | .slider-container.is-vertical {
 41 |     flex-direction: column
 42 | }
 43 | 
 44 | .slider-container .slider-item {
 45 |     flex: none
 46 | }
 47 | 
 48 | .slider-container .slider-item .image.is-covered img {
 49 |     -o-object-fit: cover;
 50 |     object-fit: cover;
 51 |     -o-object-position: center center;
 52 |     object-position: center center;
 53 |     height: 100%;
 54 |     width: 100%
 55 | }
 56 | 
 57 | .slider-container .slider-item .video-container {
 58 |     height: 0;
 59 |     padding-bottom: 0;
 60 |     padding-top: 56.25%;
 61 |     margin: 0;
 62 |     position: relative
 63 | }
 64 | 
 65 | .slider-container .slider-item .video-container.is-1by1,
 66 | .slider-container .slider-item .video-container.is-square {
 67 |     padding-top: 100%
 68 | }
 69 | 
 70 | .slider-container .slider-item .video-container.is-4by3 {
 71 |     padding-top: 75%
 72 | }
 73 | 
 74 | .slider-container .slider-item .video-container.is-21by9 {
 75 |     padding-top: 42.857143%
 76 | }
 77 | 
 78 | .slider-container .slider-item .video-container embed,
 79 | .slider-container .slider-item .video-container iframe,
 80 | .slider-container .slider-item .video-container object {
 81 |     position: absolute;
 82 |     top: 0;
 83 |     left: 0;
 84 |     width: 100% !important;
 85 |     height: 100% !important
 86 | }
 87 | 
 88 | .slider-navigation-next,
 89 | .slider-navigation-previous {
 90 |     display: flex;
 91 |     justify-content: center;
 92 |     align-items: center;
 93 |     position: absolute;
 94 |     width: 42px;
 95 |     height: 42px;
 96 |     background: #fff center center no-repeat;
 97 |     background-size: 20px 20px;
 98 |     border: 1px solid #fff;
 99 |     border-radius: 25091983px;
100 |     box-shadow: 0 2px 5px #3232321a;
101 |     top: 50%;
102 |     margin-top: -20px;
103 |     left: 0;
104 |     cursor: pointer;
105 |     transition: opacity .3s, -webkit-transform .3s;
106 |     transition: transform .3s, opacity .3s;
107 |     transition: transform .3s, opacity .3s, -webkit-transform .3s
108 | }
109 | 
110 | .slider-navigation-next:hover,
111 | .slider-navigation-previous:hover {
112 |     -webkit-transform: scale(1.2);
113 |     transform: scale(1.2)
114 | }
115 | 
116 | .slider-navigation-next.is-hidden,
117 | .slider-navigation-previous.is-hidden {
118 |     display: none;
119 |     opacity: 0
120 | }
121 | 
122 | .slider-navigation-next svg,
123 | .slider-navigation-previous svg {
124 |     width: 25%
125 | }
126 | 
127 | .slider-navigation-next {
128 |     left: auto;
129 |     right: 0;
130 |     background: #fff center center no-repeat;
131 |     background-size: 20px 20px
132 | }
133 | 
134 | .slider-pagination {
135 |     display: none;
136 |     justify-content: center;
137 |     align-items: center;
138 |     position: absolute;
139 |     bottom: 0;
140 |     left: 0;
141 |     right: 0;
142 |     padding: .5rem 1rem;
143 |     text-align: center
144 | }
145 | 
146 | .slider-pagination .slider-page {
147 |     background: #fff;
148 |     width: 10px;
149 |     height: 10px;
150 |     border-radius: 25091983px;
151 |     display: inline-block;
152 |     margin: 0 3px;
153 |     box-shadow: 0 2px 5px #3232321a;
154 |     transition: -webkit-transform .3s;
155 |     transition: transform .3s;
156 |     transition: transform .3s, -webkit-transform .3s;
157 |     cursor: pointer
158 | }
159 | 
160 | .slider-pagination .slider-page.is-active,
161 | .slider-pagination .slider-page:hover {
162 |     -webkit-transform: scale(1.4);
163 |     transform: scale(1.4)
164 | }
165 | 
166 | @media screen and (min-width:800px) {
167 |     .slider-pagination {
168 |         display: flex
169 |     }
170 | }
171 | 
172 | .hero.has-carousel {
173 |     position: relative
174 | }
175 | 
176 | .hero.has-carousel+.hero-body,
177 | .hero.has-carousel+.hero-footer,
178 | .hero.has-carousel+.hero-head {
179 |     z-index: 10;
180 |     overflow: hidden
181 | }
182 | 
183 | .hero.has-carousel .hero-carousel {
184 |     position: absolute;
185 |     top: 0;
186 |     left: 0;
187 |     bottom: 0;
188 |     right: 0;
189 |     height: auto;
190 |     border: none;
191 |     margin: auto;
192 |     padding: 0;
193 |     z-index: 0
194 | }
195 | 
196 | .hero.has-carousel .hero-carousel .slider {
197 |     width: 100%;
198 |     max-width: 100%;
199 |     overflow: hidden;
200 |     height: 100% !important;
201 |     max-height: 100%;
202 |     z-index: 0
203 | }
204 | 
205 | .hero.has-carousel .hero-carousel .slider .has-background {
206 |     max-height: 100%
207 | }
208 | 
209 | .hero.has-carousel .hero-carousel .slider .has-background .is-background {
210 |     -o-object-fit: cover;
211 |     object-fit: cover;
212 |     -o-object-position: center center;
213 |     object-position: center center;
214 |     height: 100%;
215 |     width: 100%
216 | }
217 | 
218 | .hero.has-carousel .hero-body {
219 |     margin: 0 3rem;
220 |     z-index: 10
221 | }
--------------------------------------------------------------------------------
/docs/static/css/index.css:
--------------------------------------------------------------------------------
  1 | body {
  2 |   font-family: 'Noto Sans', sans-serif;
  3 | }
  4 | 
  5 | 
  6 | .footer .icon-link {
  7 |     font-size: 25px;
  8 |     color: #000;
  9 | }
 10 | 
 11 | .link-block a {
 12 |     margin-top: 5px;
 13 |     margin-bottom: 5px;
 14 | }
 15 | 
 16 | .dnerf {
 17 |   font-variant: small-caps;
 18 | }
 19 | 
 20 | 
 21 | .teaser .hero-body {
 22 |   padding-top: 0;
 23 |   padding-bottom: 0;
 24 | }
 25 | 
 26 | .teaser {
 27 |   font-family: 'Google Sans', sans-serif;
 28 | }
 29 | 
 30 | 
 31 | .publication-title {
 32 | }
 33 | 
 34 | .publication-banner {
 35 |   max-height: parent;
 36 | 
 37 | }
 38 | 
 39 | .publication-banner video {
 40 |   position: relative;
 41 |   left: auto;
 42 |   top: auto;
 43 |   transform: none;
 44 |   object-fit: fit;
 45 | }
 46 | 
 47 | .publication-header .hero-body {
 48 | }
 49 | 
 50 | .publication-title {
 51 |     font-family: 'Google Sans', sans-serif;
 52 | }
 53 | 
 54 | .publication-authors {
 55 |     font-family: 'Google Sans', sans-serif;
 56 | }
 57 | 
 58 | .publication-venue {
 59 |     color: #555;
 60 |     width: fit-content;
 61 |     font-weight: bold;
 62 | }
 63 | 
 64 | .publication-awards {
 65 |     color: #ff3860;
 66 |     width: fit-content;
 67 |     font-weight: bolder;
 68 | }
 69 | 
 70 | .publication-authors {
 71 | }
 72 | 
 73 | .publication-authors a {
 74 |    color: hsl(204, 86%, 53%) !important;
 75 | }
 76 | 
 77 | .publication-authors a:hover {
 78 |     text-decoration: underline;
 79 | }
 80 | 
 81 | .author-block {
 82 |   display: inline-block;
 83 | }
 84 | 
 85 | .publication-banner img {
 86 | }
 87 | 
 88 | .publication-authors {
 89 |   /*color: #4286f4;*/
 90 | }
 91 | 
 92 | .publication-video {
 93 |     position: relative;
 94 |     width: 100%;
 95 |     height: 0;
 96 |     padding-bottom: 56.25%;
 97 | 
 98 |     overflow: hidden;
 99 |     border-radius: 10px !important;
100 | }
101 | 
102 | .publication-video iframe {
103 |     position: absolute;
104 |     top: 0;
105 |     left: 0;
106 |     width: 100%;
107 |     height: 100%;
108 | }
109 | 
110 | .publication-body img {
111 | }
112 | 
113 | .results-carousel {
114 |   overflow: hidden;
115 | }
116 | 
117 | .results-carousel .item {
118 |   margin: 5px;
119 |   overflow: hidden;
120 |   border: 1px solid #bbb;
121 |   border-radius: 10px;
122 |   padding: 0;
123 |   font-size: 0;
124 | }
125 | 
126 | .results-carousel video {
127 |   margin: 0;
128 | }
129 | 
130 | 
131 | .interpolation-panel {
132 |   background: #f5f5f5;
133 |   border-radius: 10px;
134 | }
135 | 
136 | .interpolation-panel .interpolation-image {
137 |   width: 100%;
138 |   border-radius: 5px;
139 | }
140 | 
141 | .interpolation-video-column {
142 | }
143 | 
144 | .interpolation-panel .slider {
145 |   margin: 0 !important;
146 | }
147 | 
148 | .interpolation-panel .slider {
149 |   margin: 0 !important;
150 | }
151 | 
152 | #interpolation-image-wrapper {
153 |   width: 100%;
154 | }
155 | #interpolation-image-wrapper img {
156 |   border-radius: 5px;
157 | }
158 | 
--------------------------------------------------------------------------------
/docs/static/js/bulma-slider.min.js:
--------------------------------------------------------------------------------
1 | !function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.bulmaSlider=e():t.bulmaSlider=e()}("undefined"!=typeof self?self:this,function(){return function(n){var r={};function i(t){if(r[t])return r[t].exports;var e=r[t]={i:t,l:!1,exports:{}};return n[t].call(e.exports,e,e.exports,i),e.l=!0,e.exports}return i.m=n,i.c=r,i.d=function(t,e,n){i.o(t,e)||Object.defineProperty(t,e,{configurable:!1,enumerable:!0,get:n})},i.n=function(t){var e=t&&t.__esModule?function(){return t.default}:function(){return t};return i.d(e,"a",e),e},i.o=function(t,e){return Object.prototype.hasOwnProperty.call(t,e)},i.p="",i(i.s=0)}([function(t,e,n){"use strict";Object.defineProperty(e,"__esModule",{value:!0}),n.d(e,"isString",function(){return l});var r=n(1),i=Object.assign||function(t){for(var e=1;e=l.length&&(s=!0)):s=!0),s&&(t.once&&(u[e]=null),t.callback(r))});-1!==u.indexOf(null);)u.splice(u.indexOf(null),1)}}]),e}();e.a=i}]).default});
--------------------------------------------------------------------------------
/docs/static/js/index.js:
--------------------------------------------------------------------------------
 1 | window.HELP_IMPROVE_VIDEOJS = false;
 2 | 
 3 | var INTERP_BASE = "./static/interpolation/stacked";
 4 | var NUM_INTERP_FRAMES = 240;
 5 | 
 6 | var interp_images = [];
 7 | function preloadInterpolationImages() {
 8 |   for (var i = 0; i < NUM_INTERP_FRAMES; i++) {
 9 |     var path = INTERP_BASE + '/' + String(i).padStart(6, '0') + '.jpg';
10 |     interp_images[i] = new Image();
11 |     interp_images[i].src = path;
12 |   }
13 | }
14 | 
15 | function setInterpolationImage(i) {
16 |   var image = interp_images[i];
17 |   image.ondragstart = function() { return false; };
18 |   image.oncontextmenu = function() { return false; };
19 |   $('#interpolation-image-wrapper').empty().append(image);
20 | }
21 | 
22 | 
23 | $(document).ready(function() {
24 |     // Check for click events on the navbar burger icon
25 |     $(".navbar-burger").click(function() {
26 |       // Toggle the "is-active" class on both the "navbar-burger" and the "navbar-menu"
27 |       $(".navbar-burger").toggleClass("is-active");
28 |       $(".navbar-menu").toggleClass("is-active");
29 | 
30 |     });
31 | 
32 |     var options = {
33 | 			slidesToScroll: 1,
34 | 			slidesToShow: 3,
35 | 			loop: true,
36 | 			infinite: true,
37 | 			autoplay: false,
38 | 			autoplaySpeed: 3000,
39 |     }
40 | 
41 | 		// Initialize all div with carousel class
42 |     var carousels = bulmaCarousel.attach('.carousel', options);
43 | 
44 |     // Loop on each carousel initialized
45 |     for(var i = 0; i < carousels.length; i++) {
46 |     	// Add listener to  event
47 |     	carousels[i].on('before:show', state => {
48 |     		console.log(state);
49 |     	});
50 |     }
51 | 
52 |     // Access to bulmaCarousel instance of an element
53 |     var element = document.querySelector('#my-element');
54 |     if (element && element.bulmaCarousel) {
55 |     	// bulmaCarousel instance is available as element.bulmaCarousel
56 |     	element.bulmaCarousel.on('before-show', function(state) {
57 |     		console.log(state);
58 |     	});
59 |     }
60 | 
61 |     /*var player = document.getElementById('interpolation-video');
62 |     player.addEventListener('loadedmetadata', function() {
63 |       $('#interpolation-slider').on('input', function(event) {
64 |         console.log(this.value, player.duration);
65 |         player.currentTime = player.duration / 100 * this.value;
66 |       })
67 |     }, false);*/
68 |     preloadInterpolationImages();
69 | 
70 |     $('#interpolation-slider').on('input', function(event) {
71 |       setInterpolationImage(this.value);
72 |     });
73 |     setInterpolationImage(0);
74 |     $('#interpolation-slider').prop('max', NUM_INTERP_FRAMES - 1);
75 | 
76 |     bulmaSlider.attach();
77 | 
78 | })
79 | 
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn.functional as F
 3 | from urllib.request import urlopen
 4 | from PIL import Image
 5 | from open_clip import create_model_from_pretrained, get_tokenizer
 6 | 
 7 | model, preprocess = create_model_from_pretrained('hf-hub:UCSC-VLAA/ViT-L-14-CLIPS-Recap-DataComp-1B')
 8 | tokenizer = get_tokenizer('hf-hub:UCSC-VLAA/ViT-L-14-CLIPS-Recap-DataComp-1B')
 9 | 
10 | image = Image.open(urlopen(
11 |     'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
12 | ))
13 | image = preprocess(image).unsqueeze(0)
14 | 
15 | text = tokenizer(["a diagram", "a dog", "a cat", "a beignet"], context_length=model.context_length)
16 | 
17 | with torch.no_grad(), torch.cuda.amp.autocast():
18 |     image_features = model.encode_image(image)
19 |     text_features = model.encode_text(text)
20 |     image_features = F.normalize(image_features, dim=-1)
21 |     text_features = F.normalize(text_features, dim=-1)
22 | 
23 |     text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
24 | 
25 | print("Label probs:", text_probs)  # prints: [[0., 0., 0., 1.0]]
26 | 
--------------------------------------------------------------------------------
/open_clip/__init__.py:
--------------------------------------------------------------------------------
 1 | from .version import __version__
 2 | 
 3 | from .coca_model import CoCa
 4 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
 5 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
 6 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint
 7 | from .loss import ClipLoss, DistillClipLoss, CoCaLoss
 8 | from .model import CLIPS, CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
 9 |     convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \
10 |     get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg
11 | from .openai import load_openai_model, list_openai_models
12 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
13 |     get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
14 | from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
15 | from .tokenizer import SimpleTokenizer, tokenize, decode
16 | from .transform import image_transform, AugmentationCfg
17 | from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy
18 | from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES
19 | 
--------------------------------------------------------------------------------
/open_clip/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/coca_model.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/coca_model.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/constants.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/constants.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/convert.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/convert.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/factory.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/factory.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/hf_configs.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/hf_configs.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/hf_model.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/hf_model.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/loss.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/loss.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/model.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/model.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/modified_resnet.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/modified_resnet.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/openai.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/openai.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/pos_embed.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/pos_embed.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/pretrained.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/pretrained.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/push_to_hf_hub.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/push_to_hf_hub.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/timm_model.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/timm_model.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/tokenizer.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/tokenizer.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/transform.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/transform.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/transformer.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/transformer.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/utils.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/utils.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/version.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/version.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/zero_shot_classifier.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/zero_shot_classifier.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/zero_shot_metadata.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/__pycache__/zero_shot_metadata.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/open_clip/constants.py:
--------------------------------------------------------------------------------
 1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
 2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
 3 | IMAGENET_MEAN = (0.485, 0.456, 0.406)
 4 | IMAGENET_STD = (0.229, 0.224, 0.225)
 5 | INCEPTION_MEAN = (0.5, 0.5, 0.5)
 6 | INCEPTION_STD = (0.5, 0.5, 0.5)
 7 | 
 8 | # Default name for a weights file hosted on the Huggingface Hub.
 9 | HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin"  # default pytorch pkl
10 | HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors"  # safetensors version
11 | HF_CONFIG_NAME = 'open_clip_config.json'
12 | 
--------------------------------------------------------------------------------
/open_clip/hf_configs.py:
--------------------------------------------------------------------------------
 1 | # HF architecture dict:
 2 | arch_dict = {
 3 |     # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
 4 |     "roberta": {
 5 |         "config_names": {
 6 |             "context_length": "max_position_embeddings",
 7 |             "vocab_size": "vocab_size",
 8 |             "width": "hidden_size",
 9 |             "heads": "num_attention_heads",
10 |             "layers": "num_hidden_layers",
11 |             "layer_attr": "layer",
12 |             "token_embeddings_attr": "embeddings"
13 |         },
14 |         "pooler": "mean_pooler",
15 |     },
16 |     # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
17 |     "xlm-roberta": {
18 |         "config_names": {
19 |             "context_length": "max_position_embeddings",
20 |             "vocab_size": "vocab_size",
21 |             "width": "hidden_size",
22 |             "heads": "num_attention_heads",
23 |             "layers": "num_hidden_layers",
24 |             "layer_attr": "layer",
25 |             "token_embeddings_attr": "embeddings"
26 |         },
27 |         "pooler": "mean_pooler",
28 |     },
29 |     # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
30 |     "mt5": {
31 |         "config_names": {
32 |             # unlimited seqlen
33 |             # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
34 |             # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
35 |             "context_length": "",
36 |             "vocab_size": "vocab_size",
37 |             "width": "d_model",
38 |             "heads": "num_heads",
39 |             "layers": "num_layers",
40 |             "layer_attr": "block",
41 |             "token_embeddings_attr": "embed_tokens"
42 |         },
43 |         "pooler": "mean_pooler",
44 |     },
45 |     # https://huggingface.co/docs/transformers/model_doc/bert
46 |     "bert": {
47 |         "config_names": {
48 |             "context_length": "max_position_embeddings",
49 |             "vocab_size": "vocab_size",
50 |             "width": "hidden_size",
51 |             "heads": "num_attention_heads",
52 |             "layers": "num_hidden_layers",
53 |         },
54 |         "pooler": "cls_pooler",
55 |     },
56 |     # https://huggingface.co/docs/transformers/model_doc/m2m_100
57 |     "m2m_100": {
58 |         "config_names": {
59 |             "context_length": "max_position_embeddings",
60 |             "vocab_size": "vocab_size",
61 |             "width": "d_model",
62 |             "heads": "encoder_attention_heads",
63 |             "layers": "encoder_layers",
64 |         },
65 |         "pooler": "cls_pooler",
66 |     },
67 | }
68 | 
--------------------------------------------------------------------------------
/open_clip/hf_model.py:
--------------------------------------------------------------------------------
  1 | """ huggingface model adapter
  2 | 
  3 | Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
  4 | """
  5 | import re
  6 | 
  7 | import torch
  8 | import torch.nn as nn
  9 | from torch import TensorType
 10 | 
 11 | try:
 12 |     import transformers
 13 |     from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
 14 |     from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
 15 |         BaseModelOutputWithPoolingAndCrossAttentions
 16 | except ImportError as e:
 17 |     transformers = None
 18 | 
 19 | 
 20 |     class BaseModelOutput:
 21 |         pass
 22 | 
 23 | 
 24 |     class PretrainedConfig:
 25 |         pass
 26 | 
 27 | from .hf_configs import arch_dict
 28 | 
 29 | 
 30 | # utils
 31 | def _camel2snake(s):
 32 |     return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
--------------------------------------------------------------------------------
/open_clip/modified_resnet.py:
--------------------------------------------------------------------------------
 17 |         # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
 18 |         self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
 19 |         self.bn1 = nn.BatchNorm2d(planes)
 20 |         self.act1 = nn.ReLU(inplace=True)
 21 | 
 22 |         self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
 23 |         self.bn2 = nn.BatchNorm2d(planes)
 24 |         self.act2 = nn.ReLU(inplace=True)
 25 | 
 26 |         self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
 27 | 
 28 |         self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
 29 |         self.bn3 = nn.BatchNorm2d(planes * self.expansion)
 30 |         self.act3 = nn.ReLU(inplace=True)
 31 | 
 32 |         self.downsample = None
 33 |         self.stride = stride
 34 | 
 35 |         if stride > 1 or inplanes != planes * Bottleneck.expansion:
 36 |             # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
 37 |             self.downsample = nn.Sequential(OrderedDict([
 38 |                 ("-1", nn.AvgPool2d(stride)),
 39 |                 ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
 40 |                 ("1", nn.BatchNorm2d(planes * self.expansion))
 41 |             ]))
 42 | 
 43 |     def forward(self, x: torch.Tensor):
 44 |         identity = x
 45 | 
 46 |         out = self.act1(self.bn1(self.conv1(x)))
 47 |         out = self.act2(self.bn2(self.conv2(out)))
 48 |         out = self.avgpool(out)
 49 |         out = self.bn3(self.conv3(out))
 50 | 
 51 |         if self.downsample is not None:
 52 |             identity = self.downsample(x)
 53 | 
 54 |         out += identity
 55 |         out = self.act3(out)
 56 |         return out
 57 | 
 58 | 
 59 | class AttentionPool2d(nn.Module):
 60 |     def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
 61 |         super().__init__()
 62 |         self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
 63 |         self.k_proj = nn.Linear(embed_dim, embed_dim)
 64 |         self.q_proj = nn.Linear(embed_dim, embed_dim)
 65 |         self.v_proj = nn.Linear(embed_dim, embed_dim)
 66 |         self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
 67 |         self.num_heads = num_heads
 68 | 
 69 |     def forward(self, x):
 70 |         x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
 71 |         x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
 72 |         x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
 73 |         x, _ = F.multi_head_attention_forward(
 74 |             query=x, key=x, value=x,
 75 |             embed_dim_to_check=x.shape[-1],
 76 |             num_heads=self.num_heads,
 77 |             q_proj_weight=self.q_proj.weight,
 78 |             k_proj_weight=self.k_proj.weight,
 79 |             v_proj_weight=self.v_proj.weight,
 80 |             in_proj_weight=None,
 81 |             in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
 82 |             bias_k=None,
 83 |             bias_v=None,
 84 |             add_zero_attn=False,
 85 |             dropout_p=0.,
 86 |             out_proj_weight=self.c_proj.weight,
 87 |             out_proj_bias=self.c_proj.bias,
 88 |             use_separate_proj_weight=True,
 89 |             training=self.training,
 90 |             need_weights=False
 91 |         )
 92 | 
 93 |         return x[0]
 94 | 
 95 | 
 96 | class ModifiedResNet(nn.Module):
 97 |     """
 98 |     A ResNet class that is similar to torchvision's but contains the following changes:
 99 |     - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
100 |     - Performs antialiasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
101 |     - The final pooling layer is a QKV attention instead of an average pool
102 |     """
103 | 
104 |     def __init__(
105 |             self,
106 |             layers: List[int],
107 |             output_dim: int,
108 |             heads: int,
109 |             image_size: int = 224,
110 |             width: int = 64,
111 |     ):
112 |         super().__init__()
113 |         self.output_dim = output_dim
114 |         self.image_size = image_size
115 | 
116 |         # the 3-layer stem
117 |         self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
118 |         self.bn1 = nn.BatchNorm2d(width // 2)
119 |         self.act1 = nn.ReLU(inplace=True)
120 |         self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
121 |         self.bn2 = nn.BatchNorm2d(width // 2)
122 |         self.act2 = nn.ReLU(inplace=True)
123 |         self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
124 |         self.bn3 = nn.BatchNorm2d(width)
125 |         self.act3 = nn.ReLU(inplace=True)
126 |         self.avgpool = nn.AvgPool2d(2)
127 | 
128 |         # residual layers
129 |         self._inplanes = width  # this is a *mutable* variable used during construction
130 |         self.layer1 = self._make_layer(width, layers[0])
131 |         self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
132 |         self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
133 |         self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
134 | 
135 |         embed_dim = width * 32  # the ResNet feature dimension
136 |         self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
137 | 
138 |         self.init_parameters()
139 | 
140 |     def _make_layer(self, planes, blocks, stride=1):
141 |         layers = [Bottleneck(self._inplanes, planes, stride)]
142 | 
143 |         self._inplanes = planes * Bottleneck.expansion
144 |         for _ in range(1, blocks):
145 |             layers.append(Bottleneck(self._inplanes, planes))
146 | 
147 |         return nn.Sequential(*layers)
148 | 
149 |     def init_parameters(self):
150 |         if self.attnpool is not None:
151 |             std = self.attnpool.c_proj.in_features ** -0.5
152 |             nn.init.normal_(self.attnpool.q_proj.weight, std=std)
153 |             nn.init.normal_(self.attnpool.k_proj.weight, std=std)
154 |             nn.init.normal_(self.attnpool.v_proj.weight, std=std)
155 |             nn.init.normal_(self.attnpool.c_proj.weight, std=std)
156 | 
157 |         for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
158 |             for name, param in resnet_block.named_parameters():
159 |                 if name.endswith("bn3.weight"):
160 |                     nn.init.zeros_(param)
161 | 
162 |     def lock(self, unlocked_groups=0, freeze_bn_stats=False):
163 |         assert unlocked_groups == 0, 'partial locking not currently supported for this model'
164 |         for param in self.parameters():
165 |             param.requires_grad = False
166 |         if freeze_bn_stats:
167 |             freeze_batch_norm_2d(self)
168 | 
169 |     @torch.jit.ignore
170 |     def set_grad_checkpointing(self, enable=True):
171 |         # FIXME support for non-transformer
172 |         pass
173 | 
174 |     def stem(self, x):
175 |         x = self.act1(self.bn1(self.conv1(x)))
176 |         x = self.act2(self.bn2(self.conv2(x)))
177 |         x = self.act3(self.bn3(self.conv3(x)))
178 |         x = self.avgpool(x)
179 |         return x
180 | 
181 |     def forward_intermediates(
182 |             self,
183 |             x: torch.Tensor,
184 |             indices: Optional[Union[int, List[int]]] = None,
185 |             stop_early: bool = False,
186 |             normalize_intermediates: bool = False,
187 |             intermediates_only: bool = False,
188 |             output_fmt: str = 'NCHW',
189 |             output_extra_tokens: bool = False,
190 |     ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
191 |         """ Forward features that returns intermediates.
192 | 
193 |         Args:
194 |             x: Input image tensor
195 |             indices: Take last n blocks if int, all if None, select matching indices if sequence
196 |             stop_early: Stop iterating over blocks when last desired intermediate hit
197 |             normalize_intermediates: Apply final norm layer to all intermediates
198 |             intermediates_only: Only return intermediate features
199 |             output_fmt: Shape of intermediate feature outputs
200 |             output_extra_tokens: Return both extra class, eot tokens
201 |         Returns:
202 |             Dict with 'image_intermediates' (NCHW stage features) and, unless intermediates_only is set, 'image_features'.
203 |         """
204 |         assert output_fmt in ('NCHW',), 'Output format must be == NCHW.'
205 |         # NOTE normalize_intermediates and output_extra_tokens don't apply to this model
206 |         take_indices, max_index = feature_take_indices(5, indices)
207 | 
208 |         output = {}
209 |         intermediates = []
210 |         blocks = [self.stem, self.layer1, self.layer2, self.layer3, self.layer4]
211 |         if not torch.jit.is_scripting() and stop_early:  # can't slice blocks in torchscript
212 |             blocks = blocks[:max_index + 1]
213 |         for i, blk in enumerate(blocks):
214 |             x = blk(x)
215 |             if i in take_indices:
216 |                 intermediates.append(x)
217 | 
218 |         output['image_intermediates'] = intermediates
219 | 
220 |         if intermediates_only:
221 |             return output
222 | 
223 |         x = self.attnpool(x)
224 |         output['image_features'] = x
225 | 
226 |         return output
227 | 
228 |     def forward(self, x):
229 |         x = self.stem(x)
230 |         x = self.layer1(x)
231 |         x = self.layer2(x)
232 |         x = self.layer3(x)
233 |         x = self.layer4(x)
234 |         x = self.attnpool(x)
235 | 
236 |         return x
237 | 
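
The following is a small usage sketch (not part of the file) exercising forward_intermediates on a stand-alone ModifiedResNet; the constructor arguments are assumptions chosen to mirror a typical RN50 configuration.

import torch
from open_clip.modified_resnet import ModifiedResNet

# Assumed RN50-like configuration: width 64, 32 attention-pool heads, 1024-d output.
model = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, heads=32, image_size=224, width=64)
model.eval()

with torch.no_grad():
    out = model.forward_intermediates(
        torch.randn(1, 3, 224, 224),
        indices=[2, 4],            # stage outputs after layer2 and layer4
        intermediates_only=False,  # also run attnpool to get pooled 'image_features'
    )

print([t.shape for t in out['image_intermediates']])  # two NCHW feature maps
print(out['image_features'].shape)                    # (1, 1024)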
--------------------------------------------------------------------------------
/open_clip/openai.py:
--------------------------------------------------------------------------------
 1 | """ OpenAI pretrained model functions
 2 | 
 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
 4 | """
 5 | 
 6 | import os
 7 | import warnings
 8 | from typing import List, Optional, Union
 9 | 
10 | import torch
11 | 
12 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
13 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
14 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
15 | 
16 | __all__ = ["list_openai_models", "load_openai_model"]
17 | 
18 | 
19 | def list_openai_models() -> List[str]:
20 |     """Returns the names of available CLIP models"""
21 |     return list_pretrained_models_by_tag('openai')
22 | 
23 | 
24 | def load_openai_model(
25 |         name: str,
26 |         precision: Optional[str] = None,
27 |         device: Optional[Union[str, torch.device]] = None,
28 |         cache_dir: Optional[str] = None,
29 | ):
30 |     """Load a CLIP model
31 | 
32 |     Parameters
33 |     ----------
34 |     name : str
35 |         A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
36 |     precision: str
37 |         Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
38 |     device : Union[str, torch.device]
39 |         The device to put the loaded model
40 |     cache_dir : Optional[str]
41 |         The directory to cache the downloaded model weights
42 | 
43 |     Returns
44 |     -------
45 |     model : torch.nn.Module
46 |         The CLIP model. Note: unlike the original OpenAI clip package, only the model is
47 |         returned here; build the image preprocessing transform separately (the model's
48 |         visual.image_mean / visual.image_std attributes set below exist for that purpose).
49 |     """
50 |     if device is None:
51 |         device = "cuda" if torch.cuda.is_available() else "cpu"
52 |     if precision is None:
53 |         precision = 'fp32' if device == 'cpu' else 'fp16'
54 | 
55 |     if get_pretrained_url(name, 'openai'):
56 |         model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
57 |     elif os.path.isfile(name):
58 |         model_path = name
59 |     else:
60 |         raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
61 | 
62 |     try:
63 |         # loading JIT archive
64 |         model = torch.jit.load(model_path, map_location="cpu").eval()
65 |         state_dict = None
66 |     except RuntimeError:
67 |         # loading saved state dict
68 |         state_dict = torch.load(model_path, map_location="cpu")
69 | 
70 |     # Build a non-jit model from the OpenAI jitted model state dict
71 |     cast_dtype = get_cast_dtype(precision)
72 |     try:
73 |         model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
74 |     except KeyError:
75 |         sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
76 |         model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
77 | 
78 |     # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
79 |     model = model.to(device)
80 |     # FIXME support pure fp16/bf16 precision modes
81 |     if precision != 'fp16':
82 |         model.float()
83 |         if precision == 'bf16':
84 |             # for bf16, convert back to low-precision
85 |             convert_weights_to_lp(model, dtype=torch.bfloat16)
86 | 
87 |     # add mean / std attributes for consistency with OpenCLIP models
88 |     model.visual.image_mean = OPENAI_DATASET_MEAN
89 |     model.visual.image_std = OPENAI_DATASET_STD
90 |     return model
91 | 
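
A hedged usage sketch (not part of the file): loading an OpenAI-pretrained model by name. 'RN50' is assumed to be available under the 'openai' pretrained tag; its weights are downloaded to the cache directory on first use.

import torch
from open_clip.openai import list_openai_models, load_openai_model

print(list_openai_models())                       # names available under the 'openai' tag
model = load_openai_model('RN50', device='cpu')   # precision defaults to 'fp32' on cpu
with torch.no_grad():
    image_features = model.encode_image(torch.randn(1, 3, 224, 224))
print(image_features.shape)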
--------------------------------------------------------------------------------
/open_clip/pos_embed.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | # All rights reserved.
 3 | 
 4 | # This source code is licensed under the license found in the
 5 | # LICENSE file in the root directory of this source tree.
 6 | # --------------------------------------------------------
 7 | # Position embedding utils
 8 | # --------------------------------------------------------
 9 | 
10 | import numpy as np
11 | 
12 | import torch
13 | 
14 | # --------------------------------------------------------
15 | # 2D sine-cosine position embedding
16 | # References:
17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
18 | # MoCo v3: https://github.com/facebookresearch/moco-v3
19 | # --------------------------------------------------------
20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
21 |     """
22 |     grid_size: int of the grid height and width
23 |     return:
24 |     pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
25 |     """
26 |     grid_h = np.arange(grid_size, dtype=np.float32)
27 |     grid_w = np.arange(grid_size, dtype=np.float32)
28 |     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
29 |     grid = np.stack(grid, axis=0)
30 | 
31 |     grid = grid.reshape([2, 1, grid_size, grid_size])
32 |     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
33 |     if cls_token:
34 |         pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
35 |     return pos_embed
36 | 
37 | 
38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
39 |     assert embed_dim % 2 == 0
40 | 
41 |     # use half of dimensions to encode grid_h
42 |     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
43 |     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
44 | 
45 |     emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
46 |     return emb
47 | 
48 | 
49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
50 |     """
51 |     embed_dim: output dimension for each position
52 |     pos: a list of positions to be encoded: size (M,)
53 |     out: (M, D)
54 |     """
55 |     assert embed_dim % 2 == 0
56 |     omega = np.arange(embed_dim // 2, dtype=float)
57 |     omega /= embed_dim / 2.
58 |     omega = 1. / 10000**omega  # (D/2,)
59 | 
60 |     pos = pos.reshape(-1)  # (M,)
61 |     out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
62 | 
63 |     emb_sin = np.sin(out) # (M, D/2)
64 |     emb_cos = np.cos(out) # (M, D/2)
65 | 
66 |     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
67 |     return emb
68 | 
69 | 
70 | # --------------------------------------------------------
71 | # Interpolate position embeddings for high-resolution
72 | # References:
73 | # DeiT: https://github.com/facebookresearch/deit
74 | # --------------------------------------------------------
75 | def interpolate_pos_embed(model, checkpoint_model):
76 |     if 'pos_embed' in checkpoint_model:
77 |         pos_embed_checkpoint = checkpoint_model['pos_embed']
78 |         embedding_size = pos_embed_checkpoint.shape[-1]
79 |         num_patches = model.patch_embed.num_patches
80 |         num_extra_tokens = model.pos_embed.shape[-2] - num_patches
81 |         # height (== width) for the checkpoint position embedding
82 |         orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
83 |         # height (== width) for the new position embedding
84 |         new_size = int(num_patches ** 0.5)
85 |         # class_token and dist_token are kept unchanged
86 |         if orig_size != new_size:
87 |             print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
88 |             extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
89 |             # only the position tokens are interpolated
90 |             pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
91 |             pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
92 |             pos_tokens = torch.nn.functional.interpolate(
93 |                 pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
94 |             pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
95 |             new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
96 |             checkpoint_model['pos_embed'] = new_pos_embed
97 | 
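
A minimal sketch (added for illustration, not repo code) checking the shape and layout produced by the sincos helper above.

import numpy as np
from open_clip.pos_embed import get_2d_sincos_pos_embed

grid_size = 14                    # e.g. a 224px image with 16px patches
embed_dim = 768
pe = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=True)
print(pe.shape)                   # (1 + 14*14, 768)
assert np.allclose(pe[0], 0.0)    # the prepended cls-token slot is all zeros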
--------------------------------------------------------------------------------
/open_clip/timm_model.py:
--------------------------------------------------------------------------------
  1 | """ timm model adapter
  2 | 
  3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
  4 | """
  5 | import logging
  6 | from collections import OrderedDict
  7 | from typing import Dict, List, Optional, Tuple, Union
  8 | 
  9 | import torch
 10 | import torch.nn as nn
 11 | 
 12 | try:
 13 |     import timm
 14 |     from timm.layers import RotAttentionPool2d
 15 |     from timm.layers import AttentionPool2d as AbsAttentionPool2d
 16 |     from timm.layers import Mlp, to_2tuple
 17 | except ImportError:
 18 |     timm = None
 19 | 
 20 | from .utils import freeze_batch_norm_2d
 21 | 
 22 | 
 23 | class TimmModel(nn.Module):
 24 |     """ timm model adapter
 25 |     """
 26 | 
 27 |     def __init__(
 28 |             self,
 29 |             model_name: str,
 30 |             embed_dim: int,
 31 |             image_size: Union[int, Tuple[int, int]] = 224,
 32 |             pool: str = 'avg',
 33 |             proj: str = 'linear',
 34 |             proj_bias: bool = False,
 35 |             drop: float = 0.,
 36 |             drop_path: Optional[float] = None,
 37 |             patch_drop: Optional[float] = None,
 38 |             pretrained: bool = False,
 39 |     ):
 40 |         super().__init__()
 41 |         if timm is None:
 42 |             raise RuntimeError("Please install the latest timm (`pip install timm`) to use timm based models.")
 43 |         self.image_size = to_2tuple(image_size)
 44 | 
 45 |         # setup kwargs that may not be common across all models
 46 |         timm_kwargs = {}
 47 |         if drop_path is not None:
 48 |             timm_kwargs['drop_path_rate'] = drop_path
 49 |         if patch_drop is not None:
 50 |             timm_kwargs['patch_drop_rate'] = patch_drop
 51 | 
 52 |         custom_pool = pool in ('abs_attn', 'rot_attn')
 53 |         if proj:
 54 |             assert proj in ("linear", "mlp", "none")
 55 |         extra_proj = proj in ("linear", "mlp")
 56 |         if not extra_proj and not custom_pool:
 57 |             # use network classifier head as projection if no proj specified and no custom pooling used
 58 |             # if projection is explicitly set to "none", the trunk output is passed through unchanged
 59 |             proj_dim = 0 if proj == 'none' else embed_dim
 60 |             self.trunk = timm.create_model(
 61 |                 model_name,
 62 |                 num_classes=proj_dim,
 63 |                 global_pool=pool,
 64 |                 pretrained=pretrained,
 65 |                 **timm_kwargs,
 66 |             )
 67 |             prev_chs = embed_dim
 68 |         else:
 69 |             self.trunk = timm.create_model(
 70 |                 model_name,
 71 |                 pretrained=pretrained,
 72 |                 **timm_kwargs,
 73 |             )
 74 |             feat_size = self.trunk.default_cfg.get('pool_size', None)
 75 |             feature_ndim = 1 if not feat_size else 2
 76 |             if custom_pool:
 77 |                 assert feature_ndim == 2
 78 |                 # if attn pooling used, remove both classifier and default pool
 79 |                 self.trunk.reset_classifier(0, global_pool='')
 80 |             else:
 81 |                 # reset global pool if pool config set, otherwise leave as network default
 82 |                 reset_kwargs = dict(global_pool=pool) if pool else {}
 83 |                 self.trunk.reset_classifier(0, **reset_kwargs)
 84 |             prev_chs = self.trunk.num_features
 85 | 
 86 |         head_layers = OrderedDict()
 87 | 
 88 |         # Add custom pooling to head
 89 |         if pool == 'abs_attn':
 90 |             head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
 91 |             prev_chs = embed_dim
 92 |         elif pool == 'rot_attn':
 93 |             head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
 94 |             prev_chs = embed_dim
 95 | 
 96 |         # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
 97 |         if proj == 'linear':
 98 |             head_layers['drop'] = nn.Dropout(drop)
 99 |             head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
100 |         elif proj == 'mlp':
101 |             head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias))
102 | 
103 |         self.head = nn.Sequential(head_layers)
104 | 
105 |     def lock(self, unlocked_groups: int = 0, freeze_bn_stats: bool = False):
106 |         """ lock modules
107 |         Args:
108 |             unlocked_groups (int): leave last n layer groups unlocked (default: 0)
109 |         """
110 |         if not unlocked_groups:
111 |             # lock full model
112 |             for param in self.trunk.parameters():
113 |                 param.requires_grad = False
114 |             if freeze_bn_stats:
115 |                 freeze_batch_norm_2d(self.trunk)
116 |         else:
117 |             # NOTE: partial freeze requires latest timm (master) branch and is subject to change
118 |             try:
119 |                 # FIXME import here until API stable and in an official release
120 |                 from timm.models.helpers import group_parameters, group_modules
121 |             except ImportError:
122 |                 raise RuntimeError(
123 |                     'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
124 |             matcher = self.trunk.group_matcher()
125 |             gparams = group_parameters(self.trunk, matcher)
126 |             max_layer_id = max(gparams.keys())
127 |             max_layer_id = max_layer_id - unlocked_groups
128 |             for group_idx in range(max_layer_id + 1):
129 |                 group = gparams[group_idx]
130 |                 for param in group:
131 |                     self.trunk.get_parameter(param).requires_grad = False
132 |             if freeze_bn_stats:
133 |                 gmodules = group_modules(self.trunk, matcher, reverse=True)
134 |                 gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
135 |                 freeze_batch_norm_2d(self.trunk, gmodules)
136 | 
137 |     @torch.jit.ignore
138 |     def set_grad_checkpointing(self, enable: bool = True):
139 |         try:
140 |             self.trunk.set_grad_checkpointing(enable)
141 |         except Exception as e:
142 |             logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')
143 | 
144 |     def forward_intermediates(
145 |             self,
146 |             x: torch.Tensor,
147 |             indices: Optional[Union[int, List[int]]] = None,
148 |             stop_early: bool = False,
149 |             normalize_intermediates: bool = False,
150 |             intermediates_only: bool = False,
151 |             output_fmt: str = 'NCHW',
152 |             output_extra_tokens: bool = False,
153 |     ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
154 |         """ Forward features that returns intermediates.
155 | 
156 |         Args:
157 |             x: Input image tensor
158 |             indices: Take last n blocks if int, all if None, select matching indices if sequence
159 |             stop_early: Stop iterating over blocks when last desired intermediate hit
160 |             normalize_intermediates: Apply norm layer to all intermediates
161 |             intermediates_only: Only return intermediate features
162 |             output_fmt: Shape of intermediate feature outputs
163 |             output_extra_tokens: Return both prefix and spatial intermediate tokens
164 |         Returns: dict with 'image_intermediates', optional 'image_intermediates_prefix', and 'image_features' unless intermediates_only is set.
165 |         """
166 |         extra_args = {}
167 |         if output_extra_tokens:
168 |             extra_args['return_prefix_tokens'] = True
169 |         trunk_output = self.trunk.forward_intermediates(
170 |                 x,
171 |                 indices=indices,
172 |                 intermediates_only=intermediates_only,
173 |                 norm=normalize_intermediates,
174 |                 stop_early=stop_early,
175 |                 output_fmt=output_fmt,
176 |                 **extra_args,
177 |             )
178 | 
179 |         return_dict = {}
180 |         intermediates = trunk_output if intermediates_only else trunk_output[1]
181 |         if output_extra_tokens and intermediates and isinstance(intermediates[0], tuple):
182 |             intermediates_prefix = [xi[1] for xi in intermediates]
183 |             intermediates = [xi[0] for xi in intermediates]
184 |             return_dict['image_intermediates_prefix'] = intermediates_prefix
185 | 
186 |         return_dict['image_intermediates'] = intermediates
187 |         if intermediates_only:
188 |             return return_dict
189 | 
190 |         image_features = self.trunk.forward_head(trunk_output[0])  # run through timm pooling / projection
191 |         image_features = self.head(image_features) # run through adapter pooling / projection
192 |         return_dict['image_features'] = image_features
193 |         return return_dict
194 | 
195 |     def forward(self, x):
196 |         x = self.trunk(x)
197 |         x = self.head(x)
198 |         return x
199 | 
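
A hedged sketch (not part of the file) wrapping a timm trunk as a CLIP vision tower; it requires timm to be installed, and 'vit_base_patch16_224' is an assumed model name from the timm registry.

import torch
from open_clip.timm_model import TimmModel

tower = TimmModel('vit_base_patch16_224', embed_dim=512, pretrained=False)  # default pool='avg', proj='linear'
tower.lock(unlocked_groups=0, freeze_bn_stats=False)  # freeze the full trunk
with torch.no_grad():
    feats = tower(torch.randn(1, 3, 224, 224))
print(feats.shape)                                    # (1, 512)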
--------------------------------------------------------------------------------
/open_clip/utils.py:
--------------------------------------------------------------------------------
  1 | import collections.abc
  2 | from itertools import repeat
  3 | from typing import List, Optional, Tuple, Union
  4 | 
  5 | import torch
  6 | from torch import nn as nn
  7 | from torch import _assert
  8 | from torchvision.ops.misc import FrozenBatchNorm2d
  9 | 
 10 | 
 11 | def freeze_batch_norm_2d(module, module_match={}, name=''):
 12 |     """
 13 |     Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
 14 |     itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
 15 |     returned. Otherwise, the module is walked recursively and submodules are converted in place.
 16 | 
 17 |     Args:
 18 |         module (torch.nn.Module): Any PyTorch module.
 19 |         module_match (dict): Dictionary of full module names to freeze (all if empty)
 20 |         name (str): Full module name (prefix)
 21 | 
 22 |     Returns:
 23 |         torch.nn.Module: Resulting module
 24 | 
 25 |     Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
 26 |     """
 27 |     res = module
 28 |     is_match = True
 29 |     if module_match:
 30 |         is_match = name in module_match
 31 |     if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
 32 |         res = FrozenBatchNorm2d(module.num_features)
 33 |         res.num_features = module.num_features
 34 |         res.affine = module.affine
 35 |         if module.affine:
 36 |             res.weight.data = module.weight.data.clone().detach()
 37 |             res.bias.data = module.bias.data.clone().detach()
 38 |         res.running_mean.data = module.running_mean.data
 39 |         res.running_var.data = module.running_var.data
 40 |         res.eps = module.eps
 41 |     else:
 42 |         for child_name, child in module.named_children():
 43 |             full_child_name = '.'.join([name, child_name]) if name else child_name
 44 |             new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
 45 |             if new_child is not child:
 46 |                 res.add_module(child_name, new_child)
 47 |     return res
 48 | 
 49 | 
 50 | # From PyTorch internals
 51 | def _ntuple(n):
 52 |     def parse(x):
 53 |         if isinstance(x, collections.abc.Iterable):
 54 |             return x
 55 |         return tuple(repeat(x, n))
 56 |     return parse
 57 | 
 58 | 
 59 | to_1tuple = _ntuple(1)
 60 | to_2tuple = _ntuple(2)
 61 | to_3tuple = _ntuple(3)
 62 | to_4tuple = _ntuple(4)
 63 | to_ntuple = lambda n, x: _ntuple(n)(x)
 64 | 
 65 | # Replaces all linear layers with linear_replacement
 66 | # TODO: add int8 support for other linear layers including attn and convnets
 67 | def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True):
 68 |     for name, module in model.named_children():
 69 |         if len(list(module.children())) > 0:
 70 |             replace_linear(module, linear_replacement, include_modules, copy_weights)
 71 | 
 72 |         if isinstance(module, torch.nn.Linear) and name in include_modules:
 73 |             old_module = model._modules[name]
 74 |             model._modules[name] = linear_replacement(
 75 |                 module.in_features,
 76 |                 module.out_features,
 77 |                 module.bias is not None,
 78 |             )
 79 |             if copy_weights:
 80 |                 model._modules[name].weight.data.copy_(old_module.weight.data)
 81 |                 if model._modules[name].bias is not None:
 82 |                     model._modules[name].bias.data.copy_(old_module.bias)
 83 | 
 84 |     return model
 85 | 
 86 | def convert_int8_model_to_inference_mode(model):
 87 |     for m in model.modules():
 88 |         if hasattr(m, 'prepare_for_eval'):
 89 |             int8_original_dtype = m.weight.dtype
 90 |             m.prepare_for_eval()
 91 |             m.int8_original_dtype = int8_original_dtype
 92 | 
 93 | 
 94 | def feature_take_indices(
 95 |         num_features: int,
 96 |         indices: Optional[Union[int, List[int]]] = None,
 97 |         as_set: bool = False,
 98 | ) -> Tuple[List[int], int]:
 99 |     """ Determine the absolute feature indices to 'take' from.
100 | 
101 |     Note: This function can be called in forward() so must be torchscript compatible,
102 |     which requires some incomplete typing and workaround hacks.
103 | 
104 |     Args:
105 |         num_features: total number of features to select from
106 |         indices: indices to select,
107 |           None -> select all
108 |           int -> select last n
109 |           list/tuple of int -> return specified (-ve indices specify from end)
110 |         as_set: return as a set
111 | 
112 |     Returns:
113 |         List (or set) of absolute (from beginning) indices, Maximum index
114 |     """
115 |     if indices is None:
116 |         indices = num_features  # all features if None
117 | 
118 |     if isinstance(indices, int):
119 |         # convert int -> last n indices
120 |         _assert(0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})')
121 |         take_indices = [num_features - indices + i for i in range(indices)]
122 |     else:
123 |         take_indices: List[int] = []
124 |         for i in indices:
125 |             idx = num_features + i if i < 0 else i
126 |             _assert(0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})')
127 |             take_indices.append(idx)
128 | 
129 |     if not torch.jit.is_scripting() and as_set:
130 |         return set(take_indices), max(take_indices)
131 | 
132 |     return take_indices, max(take_indices)
133 | 
134 | 
135 | def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]:
136 |     if isinstance(x, int):
137 |         # if indices is an int, take last N features
138 |         return tuple(range(-x, 0))
139 |     return tuple(x)
140 | 
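
Two quick sketches (assumptions, not repo code) illustrating the helpers above: feature_take_indices index resolution and BatchNorm freezing.

import torch.nn as nn
from open_clip.utils import feature_take_indices, freeze_batch_norm_2d

print(feature_take_indices(5))           # ([0, 1, 2, 3, 4], 4) -> all features
print(feature_take_indices(5, 2))        # ([3, 4], 4)          -> last two
print(feature_take_indices(5, [-1, 1]))  # ([4, 1], 4)          -> negatives count from the end

frozen = freeze_batch_norm_2d(nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)))
print(type(frozen[1]).__name__)          # FrozenBatchNorm2d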
--------------------------------------------------------------------------------
/open_clip/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '2.31.0'
2 | 
--------------------------------------------------------------------------------
/open_clip/zero_shot_classifier.py:
--------------------------------------------------------------------------------
  1 | from functools import partial
  2 | from itertools import islice
  3 | from typing import Callable, List, Optional, Sequence, Union
  4 | 
  5 | import torch
  6 | import torch.nn.functional as F
  7 | 
  8 | 
  9 | def batched(iterable, n):
 10 |     """Batch data into lists of length *n*. The last batch may be shorter.
 11 |     NOTE based on more-itertools impl, to be replaced by python 3.12 itertools.batched impl
 12 |     """
 13 |     it = iter(iterable)
 14 |     while True:
 15 |         batch = list(islice(it, n))
 16 |         if not batch:
 17 |             break
 18 |         yield batch
 19 | 
 20 | 
 21 | def build_zero_shot_classifier(
 22 |         model,
 23 |         tokenizer,
 24 |         classnames: Sequence[str],
 25 |         templates: Sequence[Union[Callable, str]],
 26 |         num_classes_per_batch: Optional[int] = 10,
 27 |         device: Union[str, torch.device] = 'cpu',
 28 |         use_tqdm: bool = False,
 29 | ):
 30 |     """ Build zero-shot classifier weights by iterating over class names in batches
 31 |     Args:
 32 |         model: CLIP model instance
 33 |         tokenizer: CLIP tokenizer instance
 34 |         classnames: A sequence of class (label) names
 35 |         templates: A sequence of callables or format() friendly strings to produce templates per class name
 36 |         num_classes_per_batch: The number of classes to batch together in each forward, all if None
 37 |         device: Device to use.
 38 |         use_tqdm: Enable TQDM progress bar.
 39 |     """
 40 |     assert isinstance(templates, Sequence) and len(templates) > 0
 41 |     assert isinstance(classnames, Sequence) and len(classnames) > 0
 42 |     use_format = isinstance(templates[0], str)
 43 |     num_templates = len(templates)
 44 |     num_classes = len(classnames)
 45 |     if use_tqdm:
 46 |         import tqdm
 47 |         num_iter = 1 if num_classes_per_batch is None else ((num_classes - 1) // num_classes_per_batch + 1)
 48 |         iter_wrap = partial(tqdm.tqdm, total=num_iter, unit_scale=num_classes_per_batch)
 49 |     else:
 50 |         iter_wrap = iter
 51 | 
 52 |     def _process_batch(batch_classnames):
 53 |         num_batch_classes = len(batch_classnames)
 54 |         texts = [template.format(c) if use_format else template(c) for c in batch_classnames for template in templates]
 55 |         texts = tokenizer(texts).to(device)
 56 |         class_embeddings = model.encode_text(texts, normalize=True)
 57 |         class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1)
 58 |         class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True)
 59 |         class_embeddings = class_embeddings.T
 60 |         return class_embeddings
 61 | 
 62 |     with torch.no_grad():
 63 |         if num_classes_per_batch:
 64 |             batched_embeds = [_process_batch(batch) for batch in iter_wrap(batched(classnames, num_classes_per_batch))]
 65 |             zeroshot_weights = torch.cat(batched_embeds, dim=1)
 66 |         else:
 67 |             zeroshot_weights = _process_batch(classnames)
 68 |     return zeroshot_weights
 69 | 
 70 | 
 71 | def build_zero_shot_classifier_legacy(
 72 |         model,
 73 |         tokenizer,
 74 |         classnames: Sequence[str],
 75 |         templates: Sequence[Union[Callable, str]],
 76 |         device: Union[str, torch.device] = 'cpu',
 77 |         use_tqdm: bool = False,
 78 | ):
 79 |     """ Build zero-shot classifier weights by iterating over class names 1 by 1
 80 |     Args:
 81 |         model: CLIP model instance
 82 |         tokenizer: CLIP tokenizer instance
 83 |         classnames: A sequence of class (label) names
 84 |         templates: A sequence of callables or format() friendly strings to produce templates per class name
 85 |         device: Device to use.
 86 |         use_tqdm: Enable TQDM progress bar.
 87 |     """
 88 |     assert isinstance(templates, Sequence) and len(templates) > 0
 89 |     assert isinstance(classnames, Sequence) and len(classnames) > 0
 90 |     if use_tqdm:
 91 |         import tqdm
 92 |         iter_wrap = tqdm.tqdm
 93 |     else:
 94 |         iter_wrap = iter
 95 | 
 96 |     use_format = isinstance(templates[0], str)
 97 | 
 98 |     with torch.no_grad():
 99 |         zeroshot_weights = []
100 |         for classname in iter_wrap(classnames):
101 |             texts = [template.format(classname) if use_format else template(classname) for template in templates]
102 |             texts = tokenizer(texts).to(device)  # tokenize
103 |             class_embeddings = model.encode_text(texts)
104 |             class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
105 |             class_embedding /= class_embedding.norm()
106 |             zeroshot_weights.append(class_embedding)
107 |         zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
108 | 
109 |     return zeroshot_weights
110 | 
111 | 
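
A hedged sketch (not part of the file) building zero-shot weights for a toy label set; the model and tokenizer names are assumptions, and any open_clip model with a matching tokenizer would do.

import open_clip
from open_clip.zero_shot_classifier import build_zero_shot_classifier

model = open_clip.create_model('ViT-B-32', pretrained=None)  # random init, no download
tokenizer = open_clip.get_tokenizer('ViT-B-32')

classifier = build_zero_shot_classifier(
    model,
    tokenizer,
    classnames=['cat', 'dog', 'bird'],
    templates=['a photo of a {}.', 'a blurry photo of a {}.'],
    num_classes_per_batch=2,
    device='cpu',
)
print(classifier.shape)  # (embed_dim, num_classes) == (512, 3) for ViT-B-32

# image->class logits are then image_features @ classifier (optionally scaled by logit_scale)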
--------------------------------------------------------------------------------
/open_clip_train/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip_train/__init__.py
--------------------------------------------------------------------------------
/open_clip_train/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip_train/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip_train/__pycache__/data.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip_train/__pycache__/data.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip_train/__pycache__/distributed.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip_train/__pycache__/distributed.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip_train/__pycache__/file_utils.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip_train/__pycache__/file_utils.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip_train/__pycache__/logger.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip_train/__pycache__/logger.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip_train/__pycache__/params.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip_train/__pycache__/params.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip_train/__pycache__/precision.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip_train/__pycache__/precision.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip_train/__pycache__/scheduler.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip_train/__pycache__/scheduler.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip_train/__pycache__/train.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip_train/__pycache__/train.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip_train/__pycache__/zero_shot.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSC-VLAA/CLIPS/6b7d32d6994d0fb8ccdbc8263700363651ed5b64/open_clip_train/__pycache__/zero_shot.cpython-311.pyc
--------------------------------------------------------------------------------
/open_clip_train/distributed.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import warnings
  3 | from typing import Optional
  4 | 
  5 | import torch
  6 | import torch.distributed as dist
  7 | 
  8 | try:
  9 |     import horovod.torch as hvd
 10 | except ImportError:
 11 |     hvd = None
 12 | 
 13 | 
 14 | def is_global_master(args):
 15 |     return args.rank == 0
 16 | 
 17 | 
 18 | def is_local_master(args):
 19 |     return args.local_rank == 0
 20 | 
 21 | 
 22 | def is_master(args, local=False):
 23 |     return is_local_master(args) if local else is_global_master(args)
 24 | 
 25 | 
 26 | def is_device_available(device):
 27 |     device_type = torch.device(device).type
 28 |     is_avail = False
 29 |     is_known = False
 30 |     if device_type == 'cuda':
 31 |         is_avail = torch.cuda.is_available()
 32 |         is_known = True
 33 |     elif device_type == 'npu':
 34 |         # NOTE the NPU device extension must already be loaded for this check not to error out
 35 |         is_avail = torch.npu.is_available()
 36 |         is_known = True
 37 |     elif device_type == 'mps':
 38 |         is_avail = torch.backends.mps.is_available()
 39 |         is_known = True
 40 |     elif device_type == 'cpu':
 41 |         is_avail = True
 42 |         is_known = True
 43 | 
 44 |     return is_avail, is_known
 45 | 
 46 | 
 47 | def set_device(device):
 48 |     if device.startswith('cuda:'):
 49 |         torch.cuda.set_device(device)
 50 |     elif device.startswith('npu:'):
 51 |         torch.npu.set_device(device)
 52 | 
 53 | 
 54 | def is_using_horovod():
 55 |     # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set
 56 |     # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required...
 57 |     ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
 58 |     pmi_vars = ["PMI_RANK", "PMI_SIZE"]
 59 |     if all([var in os.environ for var in ompi_vars]) or all([var in os.environ for var in pmi_vars]):
 60 |         return True
 61 |     else:
 62 |         return False
 63 | 
 64 | 
 65 | def is_using_distributed():
 66 |     if 'WORLD_SIZE' in os.environ:
 67 |         return int(os.environ['WORLD_SIZE']) > 1
 68 |     if 'SLURM_NTASKS' in os.environ:
 69 |         return int(os.environ['SLURM_NTASKS']) > 1
 70 |     return False
 71 | 
 72 | 
 73 | def world_info_from_env():
 74 |     local_rank = 0
 75 |     for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'):
 76 |         if v in os.environ:
 77 |             local_rank = int(os.environ[v])
 78 |             break
 79 |     global_rank = 0
 80 |     for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'):
 81 |         if v in os.environ:
 82 |             global_rank = int(os.environ[v])
 83 |             break
 84 |     world_size = 1
 85 |     for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'):
 86 |         if v in os.environ:
 87 |             world_size = int(os.environ[v])
 88 |             break
 89 | 
 90 |     return local_rank, global_rank, world_size
 91 | 
 92 | 
 93 | def init_distributed_device(args):
 94 |     # Distributed training = training on more than one GPU.
 95 |     # Works in both single and multi-node scenarios.
 96 |     args.distributed = False
 97 |     args.world_size = 1
 98 |     args.rank = 0  # global rank
 99 |     args.local_rank = 0
100 |     result = init_distributed_device_so(
101 |         device=getattr(args, 'device', 'cuda'),
102 |         dist_backend=getattr(args, 'dist_backend', None),
103 |         dist_url=getattr(args, 'dist_url', None),
104 |         horovod=getattr(args, 'horovod', False),
105 |         no_set_device_rank=getattr(args, 'no_set_device_rank', False),
106 |     )
107 |     args.device = result['device']
108 |     args.world_size = result['world_size']
109 |     args.rank = result['global_rank']
110 |     args.local_rank = result['local_rank']
111 |     args.distributed = result['distributed']
112 |     device = torch.device(args.device)
113 |     return device
114 | 
115 | 
116 | def init_distributed_device_so(
117 |         device: str = 'cuda',
118 |         dist_backend: Optional[str] = None,
119 |         dist_url: Optional[str] = None,
120 |         horovod: bool = False,
121 |         no_set_device_rank: bool = False,
122 | ):
123 |     # Distributed training = training on more than one GPU.
124 |     # Works in both single and multi-node scenarios.
125 |     distributed = False
126 |     world_size = 1
127 |     global_rank = 0
128 |     local_rank = 0
129 |     device_type, *device_idx = device.split(':', maxsplit=1)
130 |     is_avail, is_known = is_device_available(device_type)
131 |     if not is_known:
132 |         warnings.warn(f"Device {device} is not a recognized type; availability was not checked, trying anyway.")
133 |     elif not is_avail:
134 |         warnings.warn(f"Device {device} was not available, falling back to CPU.")
135 |         device_type = device = 'cpu'
136 | 
137 |     if horovod:
138 |         import horovod.torch as hvd
139 |         assert hvd is not None, "Horovod is not installed"
140 |         hvd.init()
141 |         local_rank = int(hvd.local_rank())
142 |         global_rank = hvd.rank()
143 |         world_size = hvd.size()
144 |         distributed = True
145 |     elif is_using_distributed():
146 |         if dist_backend is None:
147 |             dist_backends = {
148 |                 "cuda": "nccl",
149 |                 "hpu": "hccl",
150 |                 "npu": "hccl",
151 |                 "xpu": "ccl",
152 |             }
153 |             dist_backend = dist_backends.get(device_type, 'gloo')
154 | 
155 |         dist_url = dist_url or 'env://'
156 | 
157 |         if 'SLURM_PROCID' in os.environ:
158 |             # DDP via SLURM
159 |             local_rank, global_rank, world_size = world_info_from_env()
160 |             # SLURM var -> torch.distributed vars in case needed
161 |             os.environ['LOCAL_RANK'] = str(local_rank)
162 |             os.environ['RANK'] = str(global_rank)
163 |             os.environ['WORLD_SIZE'] = str(world_size)
164 |             torch.distributed.init_process_group(
165 |                 backend=dist_backend,
166 |                 init_method=dist_url,
167 |                 world_size=world_size,
168 |                 rank=global_rank,
169 |             )
170 |         else:
171 |             # DDP via torchrun, torch.distributed.launch
172 |             local_rank, _, _ = world_info_from_env()
173 |             torch.distributed.init_process_group(
174 |                 backend=dist_backend,
175 |                 init_method=dist_url,
176 |             )
177 |             world_size = torch.distributed.get_world_size()
178 |             global_rank = torch.distributed.get_rank()
179 |         distributed = True
180 | 
181 |     if distributed and not no_set_device_rank and device_type not in ('cpu', 'mps'):
182 |         # Ignore manually specified device index in distributed mode and
183 |         # override with resolved local rank, fewer headaches in most setups.
184 |         if device_idx:
185 |             warnings.warn(f'device index {device_idx[0]} removed from specified ({device}).')
186 |         device = f'{device_type}:{local_rank}'
187 |         set_device(device)
188 | 
189 |     return dict(
190 |         device=device,
191 |         global_rank=global_rank,
192 |         local_rank=local_rank,
193 |         world_size=world_size,
194 |         distributed=distributed,
195 |     )
196 | 
197 | 
198 | def broadcast_object(args, obj, src=0):
199 |     # broadcast a pickle-able python object from rank-0 to all ranks
200 |     if args.horovod:
201 |         return hvd.broadcast_object(obj, root_rank=src)
202 |     else:
203 |         if args.rank == src:
204 |             objects = [obj]
205 |         else:
206 |             objects = [None]
207 |         dist.broadcast_object_list(objects, src=src)
208 |         return objects[0]
209 | 
210 | 
211 | def all_gather_object(args, obj, dst=0):
212 |     # gather a pickle-able python object across all ranks
213 |     if args.horovod:
214 |         return hvd.allgather_object(obj)
215 |     else:
216 |         objects = [None for _ in range(args.world_size)]
217 |         dist.all_gather_object(objects, obj)
218 |         return objects
219 | 
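
An illustrative sketch (assumption, not repo code): outside torchrun/SLURM no rank variables are set, so initialisation falls back to a single, non-distributed process.

from open_clip_train.distributed import init_distributed_device_so, world_info_from_env

local_rank, global_rank, world_size = world_info_from_env()
print(local_rank, global_rank, world_size)         # 0 0 1 when no launcher env vars are set

info = init_distributed_device_so(device='cpu')    # no WORLD_SIZE/SLURM_NTASKS -> not distributed
print(info['device'], info['distributed'])         # cpu False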
--------------------------------------------------------------------------------
/open_clip_train/file_utils.py:
--------------------------------------------------------------------------------
 1 | import logging
 2 | import os
 3 | import multiprocessing
 4 | import subprocess
 5 | import time
 6 | import fsspec
 7 | import torch
 8 | from tqdm import tqdm
 9 | 
10 | def remote_sync_s3(local_dir, remote_dir):
11 |     # skip epoch_latest which can change during sync.
12 |     result = subprocess.run(["aws", "s3", "sync", local_dir, remote_dir, '--exclude', '*epoch_latest.pt'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
13 |     if result.returncode != 0:
14 |         logging.error(f"Error: Failed to sync with S3 bucket {result.stderr.decode('utf-8')}")
15 |         return False
16 |         
17 |     logging.info(f"Successfully synced with S3 bucket")
18 |     return True
19 | 
20 | def remote_sync_fsspec(local_dir, remote_dir):
21 |     # FIXME currently this is slow and not recommended. Look into speeding up.
22 |     a = fsspec.get_mapper(local_dir)
23 |     b = fsspec.get_mapper(remote_dir)
24 | 
25 |     for k in a:
26 |         # skip epoch_latest which can change during sync.
27 |         if 'epoch_latest.pt' in k:
28 |             continue
29 | 
30 |         logging.info(f'Attempting to sync {k}')
31 |         if k in b and len(a[k]) == len(b[k]):
32 |             logging.debug(f'Skipping remote sync for {k}.')
33 |             continue
34 | 
35 |         try:
36 |             b[k] = a[k]
37 |             logging.info(f'Successful sync for {k}.')
38 |         except Exception as e:
39 |             logging.info(f'Error during remote sync for {k}: {e}')
40 |             return False
41 | 
42 |     return True
43 | 
44 | def remote_sync(local_dir, remote_dir, protocol):
45 |     logging.info('Starting remote sync.')
46 |     if protocol == 's3':
47 |         return remote_sync_s3(local_dir, remote_dir)
48 |     elif protocol == 'fsspec':
49 |         return remote_sync_fsspec(local_dir, remote_dir)
50 |     else:
51 |         logging.error('Remote protocol not known')
52 |         return False
53 | 
54 | def keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol):
55 |     while True:
56 |         time.sleep(sync_every)
57 |         remote_sync(local_dir, remote_dir, protocol)
58 | 
59 | def start_sync_process(sync_every, local_dir, remote_dir, protocol):
60 |     p = multiprocessing.Process(target=keep_running_remote_sync, args=(sync_every, local_dir, remote_dir, protocol))
61 |     return p
62 | 
63 | # Note: we are not currently using this save function.
64 | def pt_save(pt_obj, file_path):
65 |     of = fsspec.open(file_path, "wb")
66 |     with of as f:
67 |         torch.save(pt_obj, f)  # write through the opened fsspec handle so remote paths work
68 | 
69 | def pt_load(file_path, map_location=None):
70 |     if file_path.startswith('s3'):
71 |         logging.info('Loading remote checkpoint, which may take a bit.')
72 |     of = fsspec.open(file_path, "rb")
73 |     with of as f:
74 |         out = torch.load(f, map_location=map_location)
75 |     return out
76 | 
77 | def check_exists(file_path):
78 |     try:
79 |         with fsspec.open(file_path):
80 |             pass
81 |     except FileNotFoundError:
82 |         return False
83 |     return True
84 | 
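
A minimal sketch (not part of the file): the fsspec-backed helpers accept local paths as well as remote URLs such as s3://...; the checkpoint path below is an arbitrary example.

import torch
from open_clip_train.file_utils import check_exists, pt_save, pt_load

ckpt_path = '/tmp/demo_checkpoint.pt'
pt_save({'step': 0, 'weights': torch.zeros(2, 2)}, ckpt_path)
if check_exists(ckpt_path):
    state = pt_load(ckpt_path, map_location='cpu')
    print(state['step'], state['weights'].shape)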
--------------------------------------------------------------------------------
/open_clip_train/logger.py:
--------------------------------------------------------------------------------
 1 | import logging
 2 | 
 3 | 
 4 | def setup_logging(log_file, level, include_host=False):
 5 |     if include_host:
 6 |         import socket
 7 |         hostname = socket.gethostname()
 8 |         formatter = logging.Formatter(
 9 |             f'%(asctime)s |  {hostname} | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S')
10 |     else:
11 |         formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S')
12 | 
13 |     logging.root.setLevel(level)
14 |     loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
15 |     for logger in loggers:
16 |         logger.setLevel(level)
17 | 
18 |     stream_handler = logging.StreamHandler()
19 |     stream_handler.setFormatter(formatter)
20 |     logging.root.addHandler(stream_handler)
21 | 
22 |     if log_file:
23 |         file_handler = logging.FileHandler(filename=log_file)
24 |         file_handler.setFormatter(formatter)
25 |         logging.root.addHandler(file_handler)
26 | 
27 | 
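
A short usage sketch (assumption, not repo code): this is how a training script would typically configure the root logger before emitting any messages.

import logging
from open_clip_train.logger import setup_logging

setup_logging(log_file=None, level=logging.INFO, include_host=False)
logging.info('logger configured')   # -> "<timestamp> | INFO | logger configured"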
--------------------------------------------------------------------------------
/open_clip_train/precision.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | from contextlib import suppress
 3 | from functools import partial
 4 | 
 5 | 
 6 | def get_autocast(precision, device_type='cuda'):
 7 |     if precision == 'amp':
 8 |         amp_dtype = torch.float16
 9 |     elif precision == 'amp_bfloat16' or precision == 'amp_bf16':
10 |         amp_dtype = torch.bfloat16
11 |     else:
12 |         return suppress
13 | 
14 |     return partial(torch.amp.autocast, device_type=device_type, dtype=amp_dtype)
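
A hedged sketch (not repo code): the returned callable is used as a context manager around forward passes; for plain fp32 it degrades to contextlib.suppress, i.e. a no-op context.

import torch
from open_clip_train.precision import get_autocast

autocast = get_autocast('amp_bf16', device_type='cpu')  # bf16 autocast on CPU
with autocast():
    y = torch.randn(4, 4) @ torch.randn(4, 4)
print(y.dtype)  # torch.bfloat16 while autocast is active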
--------------------------------------------------------------------------------
/open_clip_train/profiler.py:
--------------------------------------------------------------------------------
  1 | import argparse
  2 | 
  3 | import torch
  4 | import open_clip
  5 | import pandas as pd
  6 | from torch.utils.flop_counter import FlopCounterMode
  7 | try:
  8 |     import fvcore
  9 | except ImportError:
 10 |     fvcore = None
 11 | 
 12 | parser = argparse.ArgumentParser(description='OpenCLIP Profiler')
 13 | 
 14 | # benchmark specific args
 15 | parser.add_argument('--model', metavar='NAME', default='',
 16 |                     help='model(s) to profile')
 17 | parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
 18 |                     help='Output csv file for results')
 19 | parser.add_argument('--profiler', default='torch', type=str, choices=['torch', 'fvcore'])
 20 | parser.add_argument('--batch-size', default=1, type=int, help='Batch size for profiling')
 21 | parser.add_argument('--device', default='cuda', type=str, help='Device to profile on (read as args.device in main)')
 22 | 
 23 | def profile_fvcore(
 24 |         model,
 25 |         image_input_size=(3, 224, 224),
 26 |         text_input_size=(77,),
 27 |         batch_size=1,
 28 |         detailed=False,
 29 |         force_cpu=False
 30 | ):
 31 |     if force_cpu:
 32 |         model = model.to('cpu')
 33 |     device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
 34 |     example_image_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
 35 |     example_text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
 36 |     fca = fvcore.nn.FlopCountAnalysis(model, (example_image_input, example_text_input))
 37 |     aca = fvcore.nn.ActivationCountAnalysis(model, (example_image_input, example_text_input))
 38 |     if detailed:
 39 |         fcs = fvcore.nn.flop_count_str(fca)
 40 |         print(fcs)
 41 |     return fca.total() / batch_size, aca.total() / batch_size
 42 | 
 43 | 
 44 | def profile_fvcore_text(
 45 |         model,
 46 |         text_input_size=(77,),
 47 |         batch_size=1,
 48 |         detailed=False,
 49 |         force_cpu=False
 50 | ):
 51 |     if force_cpu:
 52 |         model = model.to('cpu')
 53 |     device = next(model.parameters()).device
 54 |     example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
 55 |     fca = fvcore.nn.FlopCountAnalysis(model, example_input)
 56 |     aca = fvcore.nn.ActivationCountAnalysis(model, example_input)
 57 |     if detailed:
 58 |         fcs = fvcore.nn.flop_count_str(fca)
 59 |         print(fcs)
 60 |     return fca.total() / batch_size, aca.total() / batch_size
 61 | 
 62 | 
 63 | def profile_fvcore_image(
 64 |         model,
 65 |         image_input_size=(3, 224, 224),
 66 |         batch_size=1,
 67 |         detailed=False,
 68 |         force_cpu=False
 69 | ):
 70 |     if force_cpu:
 71 |         model = model.to('cpu')
 72 |     device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
 73 |     example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
 74 |     fca = fvcore.nn.FlopCountAnalysis(model, example_input)
 75 |     aca = fvcore.nn.ActivationCountAnalysis(model, example_input)
 76 |     if detailed:
 77 |         fcs = fvcore.nn.flop_count_str(fca)
 78 |         print(fcs)
 79 |     return fca.total() / batch_size, aca.total() / batch_size
 80 | 
 81 | 
 82 | def profile_torch_image(model, image_input_size, batch_size=1, force_cpu=False):
 83 |     """Profile the image encoder using torch.utils.flop_counter"""
 84 |     if force_cpu:
 85 |         model = model.to('cpu')
 86 |     device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
 87 |     example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
 88 | 
 89 |     flop_counter = FlopCounterMode()
 90 |     with flop_counter:
 91 |         model(example_input)
 92 |     total_flops = sum(flop_counter.get_flop_counts()['Global'].values())
 93 |     return total_flops / batch_size
 94 | 
 95 | 
 96 | def profile_torch_text(model, text_input_size, batch_size=1, force_cpu=False):
 97 |     """Profile the text encoder using torch.utils.flop_counter"""
 98 |     if force_cpu:
 99 |         model = model.to('cpu')
100 |     device = next(model.parameters()).device
101 |     example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
102 | 
103 |     flop_counter = FlopCounterMode()
104 |     with flop_counter:
105 |         model(example_input)
106 |     total_flops = sum(flop_counter.get_flop_counts()['Global'].values())
107 |     return total_flops / batch_size
108 | 
109 | 
110 | def profile_torch(model, text_input_size, image_input_size, batch_size=1, force_cpu=False):
111 |     """Profile the full model using torch.utils.flop_counter"""
112 |     if force_cpu:
113 |         model = model.to('cpu')
114 |     device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
115 |     image_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
116 |     text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
117 | 
118 |     flop_counter = FlopCounterMode()
119 |     with flop_counter:
120 |         model(image_input, text_input)
121 |     total_flops = sum(flop_counter.get_flop_counts()['Global'].values())
122 |     return total_flops / batch_size
123 | 
124 | 
125 | def count_params(model):
126 |     return sum(m.numel() for m in model.parameters())
127 | 
128 | def profile_model(model_name, batch_size=1, profiler='torch', device="cuda"):
129 |     assert profiler in ['torch', 'fvcore'], 'Only torch and fvcore profilers are supported'
130 |     if profiler == 'fvcore':
131 |         assert fvcore is not None, 'Please install fvcore.'
132 |     model = open_clip.create_model(model_name, force_custom_text=True, pretrained_hf=False)
133 |     model.eval()
134 | 
135 |     if torch.cuda.is_available():
136 |         model = model.cuda()
137 |     elif device == "npu" and torch.npu.is_available():
138 |         model = model.npu()
139 | 
140 |     if isinstance(model.visual.image_size, (tuple, list)):
141 |         image_input_size = (3,) + tuple(model.visual.image_size[-2:])
142 |     else:
143 |         image_input_size = (3, model.visual.image_size, model.visual.image_size)
144 | 
145 |     text_input_size = (77,)
146 |     if hasattr(model, 'context_length') and model.context_length:
147 |         text_input_size = (model.context_length,)
148 | 
149 |     results = {}
150 |     results['model'] = model_name
151 |     results['image_size'] = image_input_size[1]
152 | 
153 |     model_cfg = open_clip.get_model_config(model_name)
154 |     if model_cfg:
155 |         vision_cfg = open_clip.CLIPVisionCfg(**model_cfg['vision_cfg'])
156 |         text_cfg = open_clip.CLIPTextCfg(**model_cfg['text_cfg'])
157 |         results['image_width'] = int(vision_cfg.width)
158 |         results['text_width'] = int(text_cfg.width)
159 |         results['embed_dim'] = int(model_cfg['embed_dim'])
160 |     else:
161 |         results['image_width'] = 0
162 |         results['text_width'] = 0
163 |         results['embed_dim'] = 0
164 | 
165 |     retries = 2
166 |     while retries:
167 |         retries -= 1
168 |         try:
169 |             results['mparams'] = round(count_params(model) / 1e6, 2)
170 |             results['image_mparams'] = round(count_params(model.visual) / 1e6, 2)
171 |             results['text_mparams'] = round(count_params(model.text) / 1e6, 2)
172 | 
173 |             if profiler == 'fvcore':
174 |                 macs, acts = profile_fvcore(
175 |                     model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size)
176 | 
177 |                 image_macs, image_acts = profile_fvcore_image(
178 |                     model.visual, image_input_size=image_input_size, force_cpu=not retries, batch_size=batch_size)
179 | 
180 |                 text_macs, text_acts = profile_fvcore_text(
181 |                     model.text, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size)
182 | 
183 |                 results['gmacs'] = round(macs / 1e9, 2)
184 |                 results['macts'] = round(acts / 1e6, 2)
185 |
186 |                 results['image_gmacs'] = round(image_macs / 1e9, 2)
187 |                 results['image_macts'] = round(image_acts / 1e6, 2)
188 |
189 |                 results['text_gmacs'] = round(text_macs / 1e9, 2)
190 |                 results['text_macts'] = round(text_acts / 1e6, 2)
191 |             elif profiler == 'torch':
192 |                 image_flops = profile_torch_image(
193 |                     model.visual, image_input_size=image_input_size, force_cpu=not retries, batch_size=batch_size)
194 |                 text_flops = profile_torch_text(
195 |                     model.text, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size)
196 |                 total_flops = profile_torch(
197 |                     model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size)
198 | 
199 |                 results['gflops'] = round(total_flops / 1e9, 2)
200 |                 results['image_gflops'] = round(image_flops / 1e9, 2)
201 |                 results['text_gflops'] = round(text_flops / 1e9, 2)
202 |             break  # profiling succeeded; skip the forced-CPU retry
203 |         except RuntimeError as e:
204 |             pass  # accelerator profiling failed (e.g. unsupported op); retry once with force_cpu=True
205 |     return results
206 | 
207 | 
208 | def main():
209 |     args = parser.parse_args()
210 | 
211 |     # FIXME accept a text file name to allow lists of models in txt/csv
212 |     if args.model == 'all':
213 |         parsed_model = open_clip.list_models()
214 |     else:
215 |         parsed_model = args.model.split(',')
216 | 
217 |     results = []
218 |     models_with_errors = []
219 |     for m in parsed_model:
220 |         print('='*100)
221 |         print(f'Profiling {m}')
222 |         try:
223 |             row = profile_model(m, batch_size=args.batch_size, profiler=args.profiler, device=args.device)
224 |             results.append(row)
225 |         except Exception as e:
226 |             print(f'Error profiling {m}: {e}')
227 |             import traceback
228 |             traceback.print_exc()
229 |             models_with_errors.append(m)
230 | 
231 |     df = pd.DataFrame(results, columns=results[0].keys())
232 | 
233 |     if 'gmacs' in df.columns:
234 |         df = df.sort_values(by=['gmacs', 'mparams', 'model'])
235 |     else:
236 |         df = df.sort_values(by=['gflops', 'mparams', 'model'])
237 | 
238 |     print('='*100)
239 |     print('Done.')
240 |     print(df)
241 |     if args.results_file:
242 |         df.to_csv(args.results_file, index=False)
243 | 
244 |     if models_with_errors:
245 |         print('Models with errors:', models_with_errors)
246 | 
247 | 
248 | if __name__ == '__main__':
249 |     main()
250 | 
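A minimal sketch of driving the profiler programmatically instead of through the CLI. It assumes this file is importable as open_clip_train.profiler (the module path is an assumption) and that open_clip and pandas are installed; the model names are illustrative:

    # Sketch only, not part of the repository.
    import pandas as pd
    from open_clip_train.profiler import profile_model

    rows = []
    for name in ('ViT-B-32', 'ViT-B-16'):      # any names from open_clip.list_models()
        # profiler='torch' uses torch.utils.flop_counter; 'fvcore' reports MACs/activations instead
        rows.append(profile_model(name, batch_size=1, profiler='torch', device='cuda'))

    df = pd.DataFrame(rows)
    # the gflops columns are only present when torch profiling succeeded (see the retry logic above)
    print(df[['model', 'mparams', 'image_gflops', 'text_gflops', 'gflops']])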
--------------------------------------------------------------------------------
/open_clip_train/scheduler.py:
--------------------------------------------------------------------------------
 1 | import math
 2 | 
 3 | 
 4 | def assign_learning_rate(optimizer, new_lr):
 5 |     for param_group in optimizer.param_groups:
 6 |         param_group["lr"] = new_lr
 7 | 
 8 | 
 9 | def _warmup_lr(base_lr, warmup_length, step):
10 |     return base_lr * (step + 1) / warmup_length  # linear ramp; reaches base_lr at step == warmup_length - 1
11 | 
12 | 
13 | def const_lr(optimizer, base_lr, warmup_length, steps):
14 |     def _lr_adjuster(step):
15 |         if step < warmup_length:
16 |             lr = _warmup_lr(base_lr, warmup_length, step)
17 |         else:
18 |             lr = base_lr
19 |         assign_learning_rate(optimizer, lr)
20 |         return lr
21 | 
22 |     return _lr_adjuster
23 | 
24 | 
25 | def const_lr_cooldown(optimizer, base_lr, warmup_length, steps, cooldown_steps, cooldown_power=1.0, cooldown_end_lr=0.):
26 |     def _lr_adjuster(step):
27 |         start_cooldown_step = steps - cooldown_steps
28 |         if step < warmup_length:
29 |             lr = _warmup_lr(base_lr, warmup_length, step)
30 |         else:
31 |             if step < start_cooldown_step:
32 |                 lr = base_lr
33 |             else:
34 |                 e = step - start_cooldown_step
35 |                 es = steps - start_cooldown_step
36 |                 # linear decay if power == 1; polynomial decay otherwise;
37 |                 decay = (1 - (e / es)) ** cooldown_power
38 |                 lr = decay * (base_lr - cooldown_end_lr) + cooldown_end_lr
39 |         assign_learning_rate(optimizer, lr)
40 |         return lr
41 | 
42 |     return _lr_adjuster
43 | 
44 | 
45 | def cosine_lr(optimizer, base_lr, warmup_length, steps):
46 |     def _lr_adjuster(step):
47 |         if step < warmup_length:
48 |             lr = _warmup_lr(base_lr, warmup_length, step)
49 |         else:
50 |             e = step - warmup_length
51 |             es = steps - warmup_length
52 |             lr = 0.5 * (1 + math.cos(math.pi * e / es)) * base_lr
53 |         assign_learning_rate(optimizer, lr)
54 |         return lr
55 | 
56 |     return _lr_adjuster
57 | 
58 | 
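A usage sketch (not part of the repository) showing how these schedulers are meant to be driven: each factory returns a closure that is called once per optimizer step with the global step index, sets the learning rate on every param group, and returns it. warmup_length and steps are counted in optimizer steps, not epochs. The model, optimizer, and step counts below are illustrative assumptions:

    # Sketch only: wiring cosine_lr into a plain PyTorch training loop.
    import torch
    from open_clip_train.scheduler import cosine_lr

    model = torch.nn.Linear(16, 16)                  # stand-in model
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    total_steps = 1_000                              # e.g. epochs * batches_per_epoch
    scheduler = cosine_lr(optimizer, base_lr=1e-3, warmup_length=100, steps=total_steps)

    for step in range(total_steps):
        lr = scheduler(step)                         # linear warmup for 100 steps, then cosine decay to 0
        optimizer.zero_grad()
        loss = model(torch.randn(8, 16)).sum()
        loss.backward()
        optimizer.step()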
--------------------------------------------------------------------------------
/open_clip_train/zero_shot.py:
--------------------------------------------------------------------------------
 1 | import logging
 2 | 
 3 | import torch
 4 | from tqdm import tqdm
 5 | 
 6 | from open_clip import get_input_dtype, get_tokenizer, build_zero_shot_classifier, \
 7 |     IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES
 8 | from open_clip_train.precision import get_autocast
 9 | 
10 | 
11 | def accuracy(output, target, topk=(1,)):
12 |     pred = output.topk(max(topk), 1, True, True)[1].t()
13 |     correct = pred.eq(target.view(1, -1).expand_as(pred))
14 |     return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]
15 | 
16 | 
17 | def run(model, classifier, dataloader, args):
18 |     device = torch.device(args.device)
19 |     autocast = get_autocast(args.precision, device_type=device.type)
20 |     input_dtype = get_input_dtype(args.precision)
21 | 
22 |     with torch.inference_mode():
23 |         top1, top5, n = 0., 0., 0.
24 |         for images, target in tqdm(dataloader, unit_scale=args.batch_size):
25 |             images = images.to(device=device, dtype=input_dtype)
26 |             target = target.to(device)
27 | 
28 |             with autocast():
29 |                 # predict
30 |                 output = model(image=images)
31 |                 image_features = output['image_features'] if isinstance(output, dict) else output[0]
32 |                 logits = 100. * image_features @ classifier
33 | 
34 |             # measure accuracy
35 |             acc1, acc5 = accuracy(logits, target, topk=(1, 5))
36 |             top1 += acc1
37 |             top5 += acc5
38 |             n += images.size(0)
39 | 
40 |     top1 = (top1 / n)
41 |     top5 = (top5 / n)
42 |     return top1, top5
43 | 
44 | 
45 | def zero_shot_eval(model, data, epoch, args, tokenizer=None):
46 |     if 'imagenet-val' not in data and 'imagenet-v2' not in data:
47 |         return {}
48 |     if args.zeroshot_frequency == 0:
49 |         return {}
50 |     if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs:
51 |         return {}
52 |     if args.distributed and not args.horovod:
53 |         model = model.module
54 | 
55 |     logging.info('Starting zero-shot imagenet.')
56 |     if tokenizer is None:
57 |         tokenizer = get_tokenizer(args.model)
58 | 
59 |     logging.info('Building zero-shot classifier')
60 |     device = torch.device(args.device)
61 |     autocast = get_autocast(args.precision, device_type=device.type)
62 |     with autocast():
63 |         classifier = build_zero_shot_classifier(
64 |             model,
65 |             tokenizer=tokenizer,
66 |             classnames=IMAGENET_CLASSNAMES,
67 |             templates=OPENAI_IMAGENET_TEMPLATES,
68 |             num_classes_per_batch=10,
69 |             device=device,
70 |             use_tqdm=True,
71 |         )
72 | 
73 |     logging.info('Using classifier')
74 |     results = {}
75 |     if 'imagenet-val' in data:
76 |         top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args)
77 |         results['imagenet-zeroshot-val-top1'] = top1
78 |         results['imagenet-zeroshot-val-top5'] = top5
79 |     if 'imagenet-v2' in data:
80 |         top1, top5 = run(model, classifier, data['imagenet-v2'].dataloader, args)
81 |         results['imagenetv2-zeroshot-val-top1'] = top1
82 |         results['imagenetv2-zeroshot-val-top5'] = top5
83 | 
84 |     logging.info('Finished zero-shot imagenet.')
85 | 
86 |     return results
87 | 
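For context, a self-contained sketch (not part of the repository) of the scoring pattern run() relies on: build_zero_shot_classifier returns an (embed_dim, num_classes) matrix of prompt-ensembled, L2-normalized text embeddings, and images are classified by a scaled dot product against it. The model name, class names, and prompt template below are illustrative; the model is left with random weights so the example runs offline:

    # Sketch only: the zero-shot scoring pattern from run() above.
    import torch
    import open_clip
    from open_clip import build_zero_shot_classifier, get_tokenizer

    model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32')  # pass pretrained=... for real accuracy
    tokenizer = get_tokenizer('ViT-B-32')
    model.eval()

    classifier = build_zero_shot_classifier(
        model,
        tokenizer=tokenizer,
        classnames=['cat', 'dog'],                    # illustrative class names
        templates=[lambda c: f'a photo of a {c}.'],   # single template; OPENAI_IMAGENET_TEMPLATES ensembles many prompts
        device='cpu',
    )

    images = torch.randn(2, 3, 224, 224)              # stand-in for a preprocessed image batch
    with torch.inference_mode():
        image_features = model.encode_image(images, normalize=True)
        logits = 100. * image_features @ classifier   # (batch, num_classes)
        preds = logits.argmax(dim=-1)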
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.9.0
2 | torchvision
3 | regex
4 | ftfy
5 | tqdm
6 | huggingface_hub
7 | safetensors
8 | timm
9 | transformers
--------------------------------------------------------------------------------