├── LICENSE ├── README.md ├── assets └── teaser.png ├── opencls ├── configs │ ├── .DS_Store │ └── cls_schedule │ │ ├── .DS_Store │ │ ├── cls_vit_b16_s1.28B_bs16k.yaml │ │ ├── cls_vit_b16_s512m_bs16k.yaml │ │ ├── cls_vit_l14_224_s12.8B_bs90k.yaml │ │ ├── cls_vit_l14_s1.28B_bs16k.yaml │ │ ├── cls_vit_l16_s512m_bs16k.yaml │ │ ├── lit_vit_b16_s1.28B_bs16k.yaml │ │ ├── lit_vit_b16_s512m_bs16k.yaml │ │ ├── lit_vit_l14_224_s12.8B_bs90k.yaml │ │ ├── lit_vit_l14_s1.28B_bs16k.yaml │ │ └── lit_vit_l16_s512m_bs16k.yaml ├── open_clip │ ├── .DS_Store │ ├── __init__.py │ ├── big_vision.py │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── cls_model.py │ ├── coca_model.py │ ├── constants.py │ ├── factory.py │ ├── hf_configs.py │ ├── hf_model.py │ ├── loss.py │ ├── model.py │ ├── model_configs │ │ ├── .DS_Store │ │ ├── CLS-ViT-B-16.json │ │ ├── CLS-ViT-L-14.json │ │ ├── CLS-ViT-L-16.json │ │ ├── EVA01-g-14-plus.json │ │ ├── EVA01-g-14.json │ │ ├── EVA02-B-16.json │ │ ├── EVA02-E-14-plus.json │ │ ├── EVA02-E-14.json │ │ ├── EVA02-L-14-336.json │ │ ├── EVA02-L-14.json │ │ ├── RN101-quickgelu.json │ │ ├── RN101.json │ │ ├── RN50-quickgelu.json │ │ ├── RN50.json │ │ ├── RN50x16.json │ │ ├── RN50x4.json │ │ ├── RN50x64.json │ │ ├── ViT-B-16-SigLIP-256.json │ │ ├── ViT-B-16-SigLIP-384.json │ │ ├── ViT-B-16-SigLIP-512.json │ │ ├── ViT-B-16-SigLIP-i18n-256.json │ │ ├── ViT-B-16-SigLIP.json │ │ ├── ViT-B-16-avg.json │ │ ├── ViT-B-16-plus-240.json │ │ ├── ViT-B-16-plus.json │ │ ├── ViT-B-16-quickgelu.json │ │ ├── ViT-B-16.json │ │ ├── ViT-B-32-256.json │ │ ├── ViT-B-32-plus-256.json │ │ ├── ViT-B-32-quickgelu.json │ │ ├── ViT-B-32.json │ │ ├── ViT-H-14-378-quickgelu.json │ │ ├── ViT-H-14-CLIPA-336.json │ │ ├── ViT-H-14-CLIPA.json │ │ ├── ViT-H-14-quickgelu.json │ │ ├── ViT-H-14.json │ │ ├── ViT-H-16.json │ │ ├── ViT-L-14-280.json │ │ ├── ViT-L-14-336.json │ │ ├── ViT-L-14-CLIPA-336.json │ │ ├── ViT-L-14-CLIPA.json │ │ ├── ViT-L-14-avg.json │ │ ├── ViT-L-14-quickgelu.json │ │ ├── ViT-L-14.json │ │ ├── ViT-L-16-320.json │ │ ├── ViT-L-16-SigLIP-256.json │ │ ├── ViT-L-16-SigLIP-384.json │ │ ├── ViT-L-16-avg.json │ │ ├── ViT-L-16.json │ │ ├── ViT-M-16-alt.json │ │ ├── ViT-M-16.json │ │ ├── ViT-M-32-alt.json │ │ ├── ViT-M-32.json │ │ ├── ViT-S-16-alt.json │ │ ├── ViT-S-16.json │ │ ├── ViT-S-32-alt.json │ │ ├── ViT-S-32.json │ │ ├── ViT-SO400M-14-SigLIP-384.json │ │ ├── ViT-SO400M-14-SigLIP.json │ │ ├── ViT-bigG-14-CLIPA-336.json │ │ ├── ViT-bigG-14-CLIPA.json │ │ ├── ViT-bigG-14.json │ │ ├── ViT-e-14.json │ │ ├── ViT-g-14.json │ │ ├── coca_ViT-B-32.json │ │ ├── coca_ViT-L-14.json │ │ ├── coca_base.json │ │ ├── coca_roberta-ViT-B-32.json │ │ ├── convnext_base.json │ │ ├── convnext_base_w.json │ │ ├── convnext_base_w_320.json │ │ ├── convnext_large.json │ │ ├── convnext_large_d.json │ │ ├── convnext_large_d_320.json │ │ ├── convnext_small.json │ │ ├── convnext_tiny.json │ │ ├── convnext_xlarge.json │ │ ├── convnext_xxlarge.json │ │ ├── convnext_xxlarge_320.json │ │ ├── mt5-base-ViT-B-32.json │ │ ├── mt5-xl-ViT-H-14.json │ │ ├── nllb-clip-base-siglip.json │ │ ├── nllb-clip-base.json │ │ ├── nllb-clip-large-siglip.json │ │ ├── nllb-clip-large.json │ │ ├── roberta-ViT-B-32.json │ │ ├── swin_base_patch4_window7_224.json │ │ ├── vit_medium_patch16_gap_256.json │ │ ├── vit_relpos_medium_patch16_cls_224.json │ │ ├── xlm-roberta-base-ViT-B-32.json │ │ └── xlm-roberta-large-ViT-H-14.json │ ├── modified_resnet.py │ ├── openai.py │ ├── pos_embed.py │ ├── pretrained.py │ ├── push_to_hf_hub.py │ ├── timm_model.py │ ├── tokenizer.py │ ├── 
transform.py │ ├── transformer.py │ ├── utils.py │ ├── version.py │ ├── zero_shot_classifier.py │ └── zero_shot_metadata.py └── training │ ├── .DS_Store │ ├── .gitignore │ ├── __init__.py │ ├── data.py │ ├── distributed.py │ ├── file_utils.py │ ├── logger.py │ ├── main.py │ ├── params.py │ ├── precision.py │ ├── profiler0.py │ ├── scheduler.py │ ├── train.py │ └── zero_shot.py ├── requirements.txt ├── train.sh └── train_combo.sh /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

SuperClass: Classification Done Right for Vision-Language Pre-Training

3 | 4 | [**Zilong Huang**](http://speedinghzl.github.io/) · [**Qinghao Ye**](https://scholar.google.com/citations?user=ZYOhaGwAAAAJ&hl=zh-CN) · [**Bingyi Kang**](https://bingykang.github.io/) · [**Jiashi Feng**](https://sites.google.com/site/jshfeng/) · [**Haoqi Fan**](https://scholar.google.com/citations?user=76B8lrgAAAAJ&hl=en) 5 | 6 | Bytedance Seed 7 | 8 | Paper PDF 9 | 10 | 11 |
12 | 13 | This work presents SuperClass, a super simple classification method for vision-language pre-training. Our method does **not require a text encoder** when pre-training on image-text data. Instead, it uses **tokenized raw text** directly as **supervised classification labels**, without the need for additional text filtering or selection (see the short sketch below). 14 | 15 | 
16 | teaser 17 |
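For intuition, here is a minimal, self-contained sketch of the labeling scheme described above. The token ids and logits are made up for illustration; the real pipeline uses the BPE tokenizer and the `Classifier`/`ClsLoss` classes under `opencls/open_clip`, which additionally reweight targets by subword frequency.

```python
import torch
import torch.nn.functional as F

# Toy stand-ins: in the real pipeline the caption is tokenized by the CLIP BPE
# tokenizer (vocab_size 49408 in the CLS-ViT configs) and the resulting subword
# ids are used directly as classification labels -- no text encoder involved.
vocab_size = 49408
token_ids = torch.tensor([[320, 1125, 539, 320, 1929, 269]])  # made-up ids for one caption

# Scatter the subword ids into a multi-hot target over the vocabulary.
targets = torch.zeros(token_ids.shape[0], vocab_size)
targets.scatter_(dim=1, index=token_ids, value=1.0)

# The vision tower predicts logits over the same vocabulary; faked here.
logits = torch.randn(token_ids.shape[0], vocab_size)

# Softmax cross-entropy against the L1-normalized bag-of-words target,
# mirroring ClsLoss.loss (frequency reweighting omitted in this sketch).
soft_targets = F.normalize(targets, p=1, dim=1)
loss = -(F.log_softmax(logits, dim=1) * soft_targets).sum(dim=1).mean()
print(loss.item())
```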
18 | 19 | 20 | ## News 21 | 22 | - **2024-11-06:** Paper & code are all released. 23 | - **2024-10-02:** SuperClass is accepted by NeurIPS 2024. 24 | 25 | 26 | ## Usage 27 | 28 | ### Preparation 29 | 30 | ```bash 31 | git clone https://github.com/x-cls/superclass 32 | cd superclass 33 | pip install -r requirements.txt 34 | ``` 35 | 36 | Download the datasets [Datacomp-1B](https://github.com/mlfoundations/datacomp) and [ImageNet-1K](https://www.image-net.org/download.php). You can also use [other image-text pair datasets](https://github.com/rom1504/img2dataset/tree/main?tab=readme-ov-file#examples) for training. 37 | 38 | Modify the **DATA_PATH** and **VAL_DATA_PATH** in the training scripts **train.sh** and **train_combo.sh** to your local paths to Datacomp-1B and ImageNet-1K. 39 | 40 | 41 | ### CLIP Training & SuperClass Training 42 | 43 | To start CLIP training or SuperClass training, use the following command: 44 | 45 | ```bash 46 | bash train.sh opencls 47 | ``` 48 | 49 | This script will navigate to the opencls directory and execute the training. 50 | 51 | If you want to include the LiT training phase, use the following command: 52 | 53 | ```bash 54 | bash train_combo.sh opencls 55 | ``` 56 | 57 | The CLS training configs are located in `opencls/configs/cls_schedule`. 58 | 59 | 60 | For example: 61 | ```bash 62 | bash train.sh configs/cls_schedule/cls_vit_b16_s1.28B_bs16k.yaml opencls 63 | ``` 64 | 65 | Please note that the default **precision** during training is set to **amp_bfloat16**. If your GPU (e.g., V100) does not support bf16, please change it to **fp16** or **amp**. 66 | 67 | 68 | 69 | 70 | 71 | ## Acknowledgement 72 | Our codebase is built upon [OpenCLIP](https://github.com/mlfoundations/open_clip) and [ViTamin](https://github.com/Beckschen/ViTamin). 73 | 74 | We thank the [OpenCLIP](https://github.com/mlfoundations/open_clip) and [ViTamin](https://github.com/Beckschen/ViTamin) teams for contributing such impressive code and models to the community. 75 | 76 | 77 | ## LICENSE 78 | 79 | The models & code of SuperClass are released under the Apache-2.0 license. 
80 | 81 | 82 | ## Citation 83 | 84 | If you find this project useful, please consider citing: 85 | 86 | ```bibtex 87 | @inproceedings{superclass_huang, 88 | title={Classification Done Right for Vision-Language Pre-Training}, 89 | author={Huang, Zilong and Ye, Qinghao and Kang, Bingyi and Feng, Jiashi and Fan, Haoqi}, 90 | booktitle={NeurIPS}, 91 | year={2024} 92 | } 93 | ``` 94 | -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/x-cls/superclass/90b761f41177400caa178d3845b39eba771e5bf1/assets/teaser.png -------------------------------------------------------------------------------- /opencls/configs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/x-cls/superclass/90b761f41177400caa178d3845b39eba771e5bf1/opencls/configs/.DS_Store -------------------------------------------------------------------------------- /opencls/configs/cls_schedule/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/x-cls/superclass/90b761f41177400caa178d3845b39eba771e5bf1/opencls/configs/cls_schedule/.DS_Store -------------------------------------------------------------------------------- /opencls/configs/cls_schedule/cls_vit_b16_s1.28B_bs16k.yaml: -------------------------------------------------------------------------------- 1 | save_frequency: 1 2 | name: "cls_vit_b16_s1.28B_bs16k" 3 | train_data: '/datasets/datacomp_1b/data/{00000..128000}.tar' # please modify to your own path 4 | train_num_samples: 1_280_000_000 5 | dataset_type: webdataset 6 | precision: 'amp_bfloat16' 7 | warmup: 500 8 | global_batch_size: 16384 9 | batch_size: 0 10 | epochs: 1 11 | lr: 5e-4 12 | beta1: 0.9 13 | beta2: 0.98 14 | eps: 1.0e-6 15 | workers: 6 16 | model: "CLS-ViT-B-16" 17 | seed: 0 18 | ddp_static_graph: true 19 | local_loss: true 20 | gather_with_grad: true 21 | force_image_size: 224 22 | grad_checkpointing: true 23 | 24 | logs: './logs' 25 | imagenet_val: './imagenet1k/val' # please modify to your own path 26 | 27 | report_to: "tensorboard" 28 | log_every_n_steps: 128 29 | zeroshot_steps: 0 30 | val_steps: 0 31 | zeroshot_frequency: 0 32 | val_frequency: 0 33 | save_every_n_steps: 6104 34 | delete_prev_step_ckpt: true 35 | 36 | resume: latest -------------------------------------------------------------------------------- /opencls/configs/cls_schedule/cls_vit_b16_s512m_bs16k.yaml: -------------------------------------------------------------------------------- 1 | save_frequency: 1 2 | name: "cls_vit_b16_s512m_bs16k" 3 | train_data: '/datasets/datacomp_1b/data/{00000..128000}.tar' # please modify to your own path 4 | train_num_samples: 512_000_000 5 | dataset_type: webdataset 6 | precision: 'amp_bfloat16' 7 | warmup: 500 8 | global_batch_size: 16384 9 | batch_size: 0 10 | epochs: 1 11 | lr: 5e-4 12 | beta1: 0.9 13 | beta2: 0.98 14 | eps: 1.0e-6 15 | workers: 6 16 | model: "CLS-ViT-B-16" 17 | seed: 0 18 | ddp_static_graph: true 19 | local_loss: true 20 | gather_with_grad: true 21 | force_image_size: 224 22 | grad_checkpointing: true 23 | 24 | logs: './logs' 25 | imagenet_val: './imagenet1k/val' # please modify to your own path 26 | 27 | report_to: "tensorboard" 28 | log_every_n_steps: 128 29 | zeroshot_steps: 0 30 | val_steps: 0 31 | zeroshot_frequency: 0 32 | val_frequency: 0 33 | save_every_n_steps: 6104 34 | 
delete_prev_step_ckpt: true 35 | 36 | resume: latest -------------------------------------------------------------------------------- /opencls/configs/cls_schedule/cls_vit_l14_224_s12.8B_bs90k.yaml: -------------------------------------------------------------------------------- 1 | save_frequency: 1 2 | name: "cls_vit_l14_224_s12.8B_bs90k" 3 | train_data: '/datasets/datacomp_1b/data/{00000..140089}.tar' # please modify to your own path 4 | train_num_samples: 1_280_000_000 5 | dataset_type: webdataset 6 | precision: 'amp_bfloat16' 7 | warmup: 10000 8 | global_batch_size: 90112 9 | batch_size: 0 10 | epochs: 10 11 | lr: 1.0e-3 12 | beta1: 0.9 13 | beta2: 0.95 14 | eps: 1.0e-6 15 | workers: 6 16 | model: "CLS-ViT-L-14" 17 | seed: 0 18 | ddp_static_graph: true 19 | local_loss: true 20 | gather_with_grad: true 21 | force_image_size: 224 22 | grad_checkpointing: true 23 | 24 | logs: './logs' 25 | imagenet_val: './imagenet1k/val' # please modify to your own path 26 | 27 | report_to: "tensorboard" 28 | log_every_n_steps: 32 29 | zeroshot_steps: 0 30 | val_steps: 0 31 | zeroshot_frequency: 0 32 | val_frequency: 0 33 | save_every_n_steps: 3052 34 | delete_prev_step_ckpt: true 35 | aug_cfg: {'scale': [0.4, 1.0], 'color_jitter': [0.32, 0.32, 0.32, 0.08], 'color_jitter_prob': 0.8, 'gray_scale_prob': 0.2} 36 | resume: latest -------------------------------------------------------------------------------- /opencls/configs/cls_schedule/cls_vit_l14_s1.28B_bs16k.yaml: -------------------------------------------------------------------------------- 1 | save_frequency: 1 2 | name: "cls_vit_l14_s1.28B_bs16k" 3 | train_data: '/datasets/datacomp_1b/data/{00000..128000}.tar' # please modify to your own path 4 | train_num_samples: 1_280_000_000 5 | dataset_type: webdataset 6 | precision: 'amp_bfloat16' 7 | warmup: 500 8 | global_batch_size: 16384 9 | batch_size: 0 10 | epochs: 1 11 | lr: 5e-4 12 | beta1: 0.9 13 | beta2: 0.98 14 | eps: 1.0e-6 15 | workers: 10 16 | model: "CLS-ViT-L-14" 17 | seed: 0 18 | ddp_static_graph: true 19 | local_loss: true 20 | gather_with_grad: true 21 | force_image_size: 224 22 | grad_checkpointing: true 23 | 24 | logs: './logs' 25 | imagenet_val: './imagenet1k/val' # please modify to your own path 26 | 27 | report_to: "tensorboard" 28 | log_every_n_steps: 128 29 | zeroshot_steps: 0 30 | val_steps: 0 31 | zeroshot_frequency: 0 32 | val_frequency: 0 33 | save_every_n_steps: 6104 34 | delete_prev_step_ckpt: true 35 | 36 | resume: latest -------------------------------------------------------------------------------- /opencls/configs/cls_schedule/cls_vit_l16_s512m_bs16k.yaml: -------------------------------------------------------------------------------- 1 | save_frequency: 1 2 | name: "cls_vit_l16_s512m_bs16k" 3 | train_data: '/datasets/datacomp_1b/data/{00000..128000}.tar' # please modify to your own path 4 | train_num_samples: 512_000_000 5 | dataset_type: webdataset 6 | precision: 'amp_bfloat16' 7 | warmup: 500 8 | global_batch_size: 16384 9 | batch_size: 0 10 | epochs: 1 11 | lr: 5e-4 12 | beta1: 0.9 13 | beta2: 0.98 14 | eps: 1.0e-6 15 | workers: 6 16 | model: "CLS-ViT-L-16" 17 | seed: 0 18 | ddp_static_graph: true 19 | local_loss: true 20 | gather_with_grad: true 21 | force_image_size: 224 22 | grad_checkpointing: true 23 | 24 | logs: './logs' 25 | imagenet_val: './imagenet1k/val' # please modify to your own path 26 | 27 | report_to: "tensorboard" 28 | log_every_n_steps: 128 29 | zeroshot_steps: 0 30 | val_steps: 0 31 | zeroshot_frequency: 0 32 | val_frequency: 0 33 | 
save_every_n_steps: 6104 34 | delete_prev_step_ckpt: true 35 | 36 | resume: latest -------------------------------------------------------------------------------- /opencls/configs/cls_schedule/lit_vit_b16_s1.28B_bs16k.yaml: -------------------------------------------------------------------------------- 1 | save_frequency: 1 2 | name: "lit_cls_vit_b16_s1.28B_bs16k" 3 | train_data: '/datasets/datacomp_1b/data/{00000..51200}.tar' 4 | train_num_samples: 1_280_000_000 5 | dataset_type: webdataset 6 | precision: 'amp_bfloat16' 7 | warmup: 500 8 | global_batch_size: 16384 9 | batch_size: 0 10 | epochs: 1 11 | lr: 5e-4 12 | beta1: 0.9 13 | beta2: 0.98 14 | eps: 1.0e-6 15 | workers: 6 16 | model: "ViT-B-16" 17 | seed: 0 18 | ddp_static_graph: true 19 | local_loss: true 20 | gather_with_grad: true 21 | force_image_size: 224 22 | grad_checkpointing: true 23 | 24 | lock_image: true 25 | lock_image_unlocked_groups: 1 26 | pretrained_image: './logs/cls_vit_b16_s1.28B_bs16k/checkpoints/epoch_1.pt' 27 | 28 | logs: './logs' 29 | imagenet_val: './imagenet1k/val' # please modify to your own path 30 | 31 | report_to: "tensorboard" 32 | log_every_n_steps: 128 33 | zeroshot_steps: 6104 34 | val_steps: 6104 35 | save_every_n_steps: 6104 36 | delete_prev_step_ckpt: true 37 | 38 | resume: latest -------------------------------------------------------------------------------- /opencls/configs/cls_schedule/lit_vit_b16_s512m_bs16k.yaml: -------------------------------------------------------------------------------- 1 | save_frequency: 1 2 | name: "lit_cls_vit_b16_s512m_bs16k" 3 | train_data: '/datasets/datacomp_1b/data/{00000..51200}.tar' 4 | train_num_samples: 512_000_000 5 | dataset_type: webdataset 6 | precision: 'amp_bfloat16' 7 | warmup: 500 8 | global_batch_size: 16384 9 | batch_size: 0 10 | epochs: 1 11 | lr: 5e-4 12 | beta1: 0.9 13 | beta2: 0.98 14 | eps: 1.0e-6 15 | workers: 6 16 | model: "ViT-B-16-avg" 17 | seed: 0 18 | ddp_static_graph: true 19 | local_loss: true 20 | gather_with_grad: true 21 | force_image_size: 224 22 | grad_checkpointing: true 23 | 24 | lock_image: true 25 | lock_image_unlocked_groups: 1 26 | pretrained_image: './logs/cls_vit_b16_s512m_bs16k/checkpoints/epoch_1.pt' 27 | 28 | logs: './logs' 29 | imagenet_val: './imagenet1k/val' # please modify to your own path 30 | 31 | report_to: "tensorboard" 32 | log_every_n_steps: 128 33 | zeroshot_steps: 6104 34 | val_steps: 6104 35 | save_every_n_steps: 6104 36 | delete_prev_step_ckpt: true 37 | 38 | resume: latest -------------------------------------------------------------------------------- /opencls/configs/cls_schedule/lit_vit_l14_224_s12.8B_bs90k.yaml: -------------------------------------------------------------------------------- 1 | save_frequency: 1 2 | name: "lit_cls_vit_l14_224_s12.8B_bs90k" 3 | train_data: '/datasets/datacomp_1b/data/{00000..140089}.tar' # please modify to your own path 4 | train_num_samples: 1_280_000_000 5 | dataset_type: webdataset 6 | precision: 'amp_bfloat16' 7 | warmup: 10000 8 | global_batch_size: 90112 9 | batch_size: 0 10 | epochs: 10 11 | lr: 1.0e-3 12 | beta1: 0.9 13 | beta2: 0.95 14 | eps: 1.0e-6 15 | workers: 6 16 | model: "ViT-L-14-avg" 17 | seed: 0 18 | ddp_static_graph: true 19 | local_loss: true 20 | gather_with_grad: true 21 | force_image_size: 224 22 | grad_checkpointing: true 23 | 24 | lock_image: true 25 | lock_image_unlocked_groups: 1 26 | pretrained_image: './logs/cls_vit_l14_224_s12.8B_bs90k/checkpoints/epoch_10.pt' 27 | 28 | logs: './logs' 29 | imagenet_val: './imagenet1k/val' # please 
modify to your own path 30 | 31 | report_to: "tensorboard" 32 | log_every_n_steps: 32 33 | zeroshot_steps: 3052 34 | val_steps: 3052 35 | zeroshot_frequency: 1 36 | save_every_n_steps: 3052 37 | delete_prev_step_ckpt: true 38 | aug_cfg: {'scale': [0.4, 1.0], 'color_jitter': [0.32, 0.32, 0.32, 0.08], 'color_jitter_prob': 0.8, 'gray_scale_prob': 0.2} 39 | resume: latest -------------------------------------------------------------------------------- /opencls/configs/cls_schedule/lit_vit_l14_s1.28B_bs16k.yaml: -------------------------------------------------------------------------------- 1 | save_frequency: 1 2 | name: "lit_cls_vit_l14_s1.28B_bs16k" 3 | train_data: '/datasets/datacomp_1b/data/{00000..128000}.tar' # please modify to your own path 4 | train_num_samples: 1_280_000_000 5 | dataset_type: webdataset 6 | precision: 'amp_bfloat16' 7 | warmup: 500 8 | global_batch_size: 16384 9 | batch_size: 0 10 | epochs: 1 11 | lr: 5e-4 12 | beta1: 0.9 13 | beta2: 0.98 14 | eps: 1.0e-6 15 | workers: 6 16 | model: "ViT-L-14" 17 | seed: 0 18 | ddp_static_graph: true 19 | local_loss: true 20 | gather_with_grad: true 21 | force_image_size: 224 22 | grad_checkpointing: true 23 | 24 | lock_image: true 25 | lock_image_unlocked_groups: 1 26 | pretrained_image: './logs/cls_vit_l14_s1.28B_bs16k/checkpoints/epoch_1.pt' 27 | 28 | logs: './logs' 29 | imagenet_val: './imagenet1k/val' # please modify to your own path 30 | 31 | report_to: "tensorboard" 32 | log_every_n_steps: 128 33 | zeroshot_steps: 6104 34 | val_steps: 6104 35 | save_every_n_steps: 6104 36 | delete_prev_step_ckpt: true 37 | 38 | resume: latest -------------------------------------------------------------------------------- /opencls/configs/cls_schedule/lit_vit_l16_s512m_bs16k.yaml: -------------------------------------------------------------------------------- 1 | save_frequency: 1 2 | name: "lit_clip_vit_l16_s512m_bs16k" 3 | train_data: '/datasets/datacomp_1b/data/{00000..51200}.tar' 4 | train_num_samples: 512_000_000 5 | dataset_type: webdataset 6 | precision: 'amp_bfloat16' 7 | warmup: 500 8 | global_batch_size: 16384 9 | batch_size: 0 10 | epochs: 1 11 | lr: 5e-4 12 | beta1: 0.9 13 | beta2: 0.98 14 | eps: 1.0e-6 15 | workers: 6 16 | model: "ViT-L-16-avg" 17 | seed: 0 18 | ddp_static_graph: true 19 | local_loss: true 20 | gather_with_grad: true 21 | force_image_size: 224 22 | grad_checkpointing: true 23 | 24 | lock_image: true 25 | lock_image_unlocked_groups: 1 26 | pretrained_image: './logs/cls_vit_l16_s512m_bs16k/checkpoints/epoch_1.pt' 27 | 28 | logs: './logs' 29 | imagenet_val: './imagenet1k/val' # please modify to your own path 30 | 31 | report_to: "tensorboard" 32 | log_every_n_steps: 128 33 | zeroshot_steps: 6104 34 | val_steps: 6104 35 | save_every_n_steps: 6104 36 | delete_prev_step_ckpt: true 37 | 38 | resume: latest -------------------------------------------------------------------------------- /opencls/open_clip/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/x-cls/superclass/90b761f41177400caa178d3845b39eba771e5bf1/opencls/open_clip/.DS_Store -------------------------------------------------------------------------------- /opencls/open_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .coca_model import CoCa 2 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 3 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, 
create_loss 4 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint 5 | from .loss import ClipLoss, DistillClipLoss, CoCaLoss 6 | from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \ 7 | convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \ 8 | get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg 9 | from .openai import load_openai_model, list_openai_models 10 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \ 11 | get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained 12 | from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub 13 | from .tokenizer import SimpleTokenizer, tokenize, decode 14 | from .transform import image_transform, AugmentationCfg 15 | from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy 16 | from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES 17 | -------------------------------------------------------------------------------- /opencls/open_clip/big_vision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from .model import CustomTextCLIP 5 | from .transformer import TextTransformer, Transformer 6 | 7 | 8 | @torch.no_grad() 9 | def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str): 10 | """ Load weights from .npz checkpoints for official Google big_vision image-text models 11 | 12 | Currently the SigLIP source models are supported and a CustomTextCLIP destination model 13 | w/ timm image encoder. 14 | """ 15 | from timm.layers import resample_patch_embed, resample_abs_pos_embed 16 | 17 | def _n2p(w, t=True): 18 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 19 | w = w.flatten() 20 | if t: 21 | if w.ndim == 4: 22 | w = w.transpose([3, 2, 0, 1]) 23 | elif w.ndim == 3: 24 | w = w.transpose([2, 0, 1]) 25 | elif w.ndim == 2: 26 | w = w.transpose([1, 0]) 27 | return torch.from_numpy(w) 28 | 29 | w = np.load(checkpoint_path) 30 | interpolation = 'bilinear' 31 | antialias = False 32 | 33 | def _convert_timm_img(module, prefix): 34 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 35 | if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]: 36 | embed_conv_w = resample_patch_embed( 37 | embed_conv_w, 38 | module.patch_embed.proj.weight.shape[-2:], 39 | interpolation=interpolation, 40 | antialias=antialias, 41 | verbose=True, 42 | ) 43 | module.patch_embed.proj.weight.copy_(embed_conv_w) 44 | module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 45 | 46 | if module.cls_token is not None: 47 | module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 48 | 49 | pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False) 50 | if pos_embed_w.shape != module.pos_embed.shape: 51 | assert False, f'{pos_embed_w.shape}, {module.pos_embed.shape}' 52 | num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1) 53 | pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights 54 | pos_embed_w, 55 | new_size=module.patch_embed.grid_size, 56 | num_prefix_tokens=num_prefix_tokens, 57 | interpolation=interpolation, 58 | antialias=antialias, 59 | verbose=True, 60 | ) 61 | module.pos_embed.copy_(pos_embed_w) 62 | 63 | 
mha_sub, b_sub, ln1_sub = (0, 0, 1) 64 | for i, block in enumerate(module.blocks.children()): 65 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 66 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/' 67 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 68 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 69 | block.attn.qkv.weight.copy_(torch.cat([ 70 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 71 | block.attn.qkv.bias.copy_(torch.cat([ 72 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 73 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 74 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 75 | for r in range(2): 76 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'])) 77 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'])) 78 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'])) 79 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'])) 80 | 81 | module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 82 | module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 83 | 84 | if module.attn_pool is not None: 85 | block_prefix = f'{prefix}MAPHead_0/' 86 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/' 87 | module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False)) 88 | module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T) 89 | module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1)) 90 | module.attn_pool.kv.weight.copy_(torch.cat([ 91 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')])) 92 | module.attn_pool.kv.bias.copy_(torch.cat([ 93 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')])) 94 | module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 95 | module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 96 | module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 97 | module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 98 | for r in range(2): 99 | getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel'])) 100 | getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias'])) 101 | 102 | def _convert_openclip_transformer(module: Transformer, prefix): 103 | for i, block in enumerate(module.resblocks.children()): 104 | block_prefix = f'{prefix}encoderblock_{i}/' 105 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/' 106 | block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 107 | block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 108 | block.attn.in_proj_weight.copy_(torch.cat([ 109 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 110 | block.attn.in_proj_bias.copy_(torch.cat([ 111 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 112 | block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 113 | block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 114 | 
block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale'])) 115 | block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias'])) 116 | block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel'])) 117 | block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias'])) 118 | block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel'])) 119 | block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias'])) 120 | 121 | def _convert_openclip_txt(module: TextTransformer, prefix): 122 | module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False)) 123 | pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0) 124 | module.positional_embedding.copy_(pos_embed_w) 125 | _convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/') 126 | module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale'])) 127 | module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias'])) 128 | module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 129 | module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 130 | 131 | _convert_timm_img(model.visual.trunk, 'params/img/') 132 | _convert_openclip_txt(model.text, 'params/txt/') 133 | model.logit_bias.copy_(_n2p(w['params/b'])[0]) 134 | model.logit_scale.copy_(_n2p(w['params/t'])[0]) 135 | 136 | 137 | -------------------------------------------------------------------------------- /opencls/open_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/x-cls/superclass/90b761f41177400caa178d3845b39eba771e5bf1/opencls/open_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /opencls/open_clip/cls_model.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Copyright (2024) Bytedance Ltd. 
and/or its affiliates 3 | # Licensed under the Apache License, Version 2.0 (the "License") 4 | # SuperClass Project 5 | # Written by Zilong Huang 6 | # -------------------------------------------------------- 7 | from typing import Optional 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as F 12 | import numpy as np 13 | from dataclasses import dataclass 14 | 15 | from .transformer import ( 16 | LayerNormFp32, 17 | LayerNorm, 18 | QuickGELU, 19 | MultimodalTransformer, 20 | MixClsHead, 21 | ) 22 | from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower 23 | 24 | 25 | @dataclass 26 | class ClassHeadCfg(CLIPTextCfg): 27 | mlp_ratio: int = 4 28 | layers: int = 1 29 | 30 | 31 | def _build_cls_head( 32 | width, 33 | clshead_cfg, 34 | quick_gelu: bool = False, 35 | cast_dtype: Optional[torch.dtype] = None, 36 | ): 37 | clshead_cfg = ClassHeadCfg(**clshead_cfg) if isinstance(clshead_cfg, dict) else clshead_cfg 38 | act_layer = QuickGELU if quick_gelu else nn.GELU 39 | norm_layer = ( 40 | LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm 41 | ) 42 | 43 | head = MixClsHead( 44 | width=width, 45 | layers=clshead_cfg.layers, 46 | mlp_ratio=clshead_cfg.mlp_ratio, 47 | act_layer=act_layer, 48 | norm_layer=norm_layer, 49 | output_dim=clshead_cfg.vocab_size, 50 | ) 51 | 52 | return head 53 | 54 | 55 | class Classifier(nn.Module): 56 | def __init__( 57 | self, 58 | embed_dim, 59 | text_cfg: CLIPTextCfg, 60 | vision_cfg: CLIPVisionCfg, 61 | quick_gelu: bool = False, 62 | cast_dtype: Optional[torch.dtype] = None, 63 | ): 64 | super().__init__() 65 | clshead_cfg = ClassHeadCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg 66 | vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg 67 | 68 | vocab_size = clshead_cfg.vocab_size 69 | 70 | self.visual = _build_vision_tower( 71 | embed_dim=embed_dim, 72 | vision_cfg=vision_cfg, 73 | quick_gelu=quick_gelu, 74 | cast_dtype=cast_dtype, 75 | ) 76 | 77 | self.text_decoder = _build_cls_head( 78 | embed_dim if embed_dim else vision_cfg.width, 79 | clshead_cfg=clshead_cfg, 80 | quick_gelu=quick_gelu, 81 | cast_dtype=cast_dtype, 82 | ) 83 | 84 | self.register_buffer("cap_fq", torch.zeros([1, vocab_size], dtype=torch.float64)) 85 | self.register_buffer("num_samples", torch.zeros([1, 1], dtype=torch.float64)) 86 | 87 | @torch.jit.ignore 88 | def set_grad_checkpointing(self, enable=True): 89 | self.visual.set_grad_checkpointing(enable) 90 | # self.text.set_grad_checkpointing(enable) 91 | # self.text_decoder.set_grad_checkpointing(enable) 92 | 93 | def forward(self, image, text, image_embs=None): 94 | if image_embs is None: 95 | image_embs = self.visual(image) 96 | 97 | logits = self.text_decoder(image_embs) 98 | labels = text.clone() 99 | 100 | return { 101 | "cap_fq": self.cap_fq, 102 | "num_samples": self.num_samples, 103 | "logits": logits, 104 | "labels": labels, 105 | "logit_scale": torch.ones([1]), 106 | } -------------------------------------------------------------------------------- /opencls/open_clip/constants.py: -------------------------------------------------------------------------------- 1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) 3 | IMAGENET_MEAN = (0.485, 0.456, 0.406) 4 | IMAGENET_STD = (0.229, 0.224, 0.225) 5 | INCEPTION_MEAN = (0.5, 0.5, 0.5) 6 | INCEPTION_STD = (0.5, 0.5, 0.5) 7 | 
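The `cap_fq` and `num_samples` buffers registered by `Classifier` above are consumed by `ClsLoss` (see `loss.py` below) to down-weight very frequent subwords, which is what lets SuperClass skip explicit text filtering. A rough single-process sketch of that reweighting, with the distributed `all_reduce` calls omitted and made-up targets:

```python
import torch

vocab_size = 49408
batch_size = 4

# Multi-hot targets for a batch of captions (random ids, for illustration only).
token_ids = torch.randint(0, vocab_size, (batch_size, 6))
targets = torch.zeros(batch_size, vocab_size)
targets.scatter_(1, token_ids, 1.0)

# Running statistics kept in the Classifier buffers: average per-sample subword
# frequency (cap_fq) and number of batches seen (num_samples). In ClsLoss these
# are additionally averaged across workers via dist.all_reduce.
cap_fq = torch.zeros(1, vocab_size, dtype=torch.float64)
num_samples = torch.zeros(1, 1, dtype=torch.float64)

cap_fq += targets.sum(dim=0, keepdim=True) / targets.shape[0]
num_samples += 1

# Inverse-frequency reweighting: common subwords (e.g. stopwords) get a small
# weight, rare content words a large one (world_size assumed to be 1 here).
eps = 1.0 / batch_size
reweighted = targets * torch.log((num_samples + eps) / (cap_fq + eps)).to(targets.dtype)
print(reweighted.shape)  # torch.Size([4, 49408])
```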
-------------------------------------------------------------------------------- /opencls/open_clip/hf_configs.py: -------------------------------------------------------------------------------- 1 | # HF architecture dict: 2 | arch_dict = { 3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta 4 | "roberta": { 5 | "config_names": { 6 | "context_length": "max_position_embeddings", 7 | "vocab_size": "vocab_size", 8 | "width": "hidden_size", 9 | "heads": "num_attention_heads", 10 | "layers": "num_hidden_layers", 11 | "layer_attr": "layer", 12 | "token_embeddings_attr": "embeddings" 13 | }, 14 | "pooler": "mean_pooler", 15 | }, 16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig 17 | "xlm-roberta": { 18 | "config_names": { 19 | "context_length": "max_position_embeddings", 20 | "vocab_size": "vocab_size", 21 | "width": "hidden_size", 22 | "heads": "num_attention_heads", 23 | "layers": "num_hidden_layers", 24 | "layer_attr": "layer", 25 | "token_embeddings_attr": "embeddings" 26 | }, 27 | "pooler": "mean_pooler", 28 | }, 29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 30 | "mt5": { 31 | "config_names": { 32 | # unlimited seqlen 33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 35 | "context_length": "", 36 | "vocab_size": "vocab_size", 37 | "width": "d_model", 38 | "heads": "num_heads", 39 | "layers": "num_layers", 40 | "layer_attr": "block", 41 | "token_embeddings_attr": "embed_tokens" 42 | }, 43 | "pooler": "mean_pooler", 44 | }, 45 | # https://huggingface.co/docs/transformers/model_doc/bert 46 | "bert": { 47 | "config_names": { 48 | "context_length": "max_position_embeddings", 49 | "vocab_size": "vocab_size", 50 | "width": "hidden_size", 51 | "heads": "num_attention_heads", 52 | "layers": "num_hidden_layers", 53 | }, 54 | "pooler": "cls_pooler", 55 | }, 56 | # https://huggingface.co/docs/transformers/model_doc/m2m_100 57 | "m2m_100": { 58 | "config_names": { 59 | "context_length": "max_position_embeddings", 60 | "vocab_size": "vocab_size", 61 | "width": "d_model", 62 | "heads": "encoder_attention_heads", 63 | "layers": "encoder_layers", 64 | }, 65 | "pooler": "cls_pooler", 66 | }, 67 | } 68 | -------------------------------------------------------------------------------- /opencls/open_clip/hf_model.py: -------------------------------------------------------------------------------- 1 | """ huggingface model adapter 2 | 3 | Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. 4 | """ 5 | import re 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch import TensorType 10 | 11 | try: 12 | import transformers 13 | from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig 14 | from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \ 15 | BaseModelOutputWithPoolingAndCrossAttentions 16 | except ImportError as e: 17 | transformers = None 18 | 19 | 20 | class BaseModelOutput: 21 | pass 22 | 23 | 24 | class PretrainedConfig: 25 | pass 26 | 27 | from .hf_configs import arch_dict 28 | 29 | 30 | # utils 31 | def _camel2snake(s): 32 | return re.sub(r'(? 
torch.Tensor: 96 | # calculated ground-truth and cache if enabled 97 | if self.prev_num_logits != num_logits or device not in self.labels: 98 | labels = torch.arange(num_logits, device=device, dtype=torch.long) 99 | if self.world_size > 1 and self.local_loss: 100 | labels = labels + num_logits * self.rank 101 | if self.cache_labels: 102 | self.labels[device] = labels 103 | self.prev_num_logits = num_logits 104 | else: 105 | labels = self.labels[device] 106 | return labels 107 | 108 | def get_logits(self, image_features, text_features, logit_scale): 109 | if self.world_size > 1: 110 | all_image_features, all_text_features = gather_features( 111 | image_features, text_features, 112 | self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) 113 | 114 | if self.local_loss: 115 | logits_per_image = logit_scale * image_features @ all_text_features.T 116 | logits_per_text = logit_scale * text_features @ all_image_features.T 117 | else: 118 | logits_per_image = logit_scale * all_image_features @ all_text_features.T 119 | logits_per_text = logits_per_image.T 120 | else: 121 | logits_per_image = logit_scale * image_features @ text_features.T 122 | logits_per_text = logit_scale * text_features @ image_features.T 123 | 124 | return logits_per_image, logits_per_text 125 | 126 | def forward(self, image_features, text_features, logit_scale, output_dict=False): 127 | device = image_features.device 128 | logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale) 129 | 130 | labels = self.get_ground_truth(device, logits_per_image.shape[0]) 131 | 132 | total_loss = ( 133 | F.cross_entropy(logits_per_image, labels) + 134 | F.cross_entropy(logits_per_text, labels) 135 | ) / 2 136 | 137 | return {"contrastive_loss": total_loss} if output_dict else total_loss 138 | 139 | 140 | class CoCaLoss(ClipLoss): 141 | def __init__( 142 | self, 143 | caption_loss_weight, 144 | clip_loss_weight, 145 | pad_id=0, # pad_token for open_clip custom tokenizer 146 | local_loss=False, 147 | gather_with_grad=False, 148 | cache_labels=False, 149 | rank=0, 150 | world_size=1, 151 | use_horovod=False, 152 | ): 153 | super().__init__( 154 | local_loss=local_loss, 155 | gather_with_grad=gather_with_grad, 156 | cache_labels=cache_labels, 157 | rank=rank, 158 | world_size=world_size, 159 | use_horovod=use_horovod 160 | ) 161 | 162 | self.clip_loss_weight = clip_loss_weight 163 | self.caption_loss_weight = caption_loss_weight 164 | self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id) 165 | 166 | def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False): 167 | 168 | clip_loss = torch.tensor(0) 169 | 170 | if self.clip_loss_weight: 171 | clip_loss = super().forward(image_features, text_features, logit_scale) 172 | clip_loss = self.clip_loss_weight * clip_loss 173 | 174 | caption_loss = self.caption_loss( 175 | logits.permute(0, 2, 1), 176 | labels, 177 | ) 178 | caption_loss = caption_loss * self.caption_loss_weight 179 | 180 | if output_dict: 181 | return {"contrastive_loss": clip_loss, "caption_loss": caption_loss} 182 | 183 | return clip_loss, caption_loss 184 | 185 | 186 | class DistillClipLoss(ClipLoss): 187 | 188 | def dist_loss(self, teacher_logits, student_logits): 189 | return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0) 190 | 191 | def forward( 192 | self, 193 | image_features, 194 | text_features, 195 | logit_scale, 196 | dist_image_features, 197 | dist_text_features, 198 | 
dist_logit_scale, 199 | output_dict=False, 200 | ): 201 | logits_per_image, logits_per_text = \ 202 | self.get_logits(image_features, text_features, logit_scale) 203 | 204 | dist_logits_per_image, dist_logits_per_text = \ 205 | self.get_logits(dist_image_features, dist_text_features, dist_logit_scale) 206 | 207 | labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0]) 208 | 209 | contrastive_loss = ( 210 | F.cross_entropy(logits_per_image, labels) + 211 | F.cross_entropy(logits_per_text, labels) 212 | ) / 2 213 | 214 | distill_loss = ( 215 | self.dist_loss(dist_logits_per_image, logits_per_image) + 216 | self.dist_loss(dist_logits_per_text, logits_per_text) 217 | ) / 2 218 | 219 | if output_dict: 220 | return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss} 221 | 222 | return contrastive_loss, distill_loss 223 | 224 | 225 | def neighbour_exchange(from_rank, to_rank, tensor, group=None): 226 | tensor_recv = torch.zeros_like(tensor) 227 | send_op = torch.distributed.P2POp( 228 | torch.distributed.isend, 229 | tensor, 230 | to_rank, 231 | group=group, 232 | ) 233 | recv_op = torch.distributed.P2POp( 234 | torch.distributed.irecv, 235 | tensor_recv, 236 | from_rank, 237 | group=group, 238 | ) 239 | reqs = torch.distributed.batch_isend_irecv([send_op, recv_op]) 240 | for req in reqs: 241 | req.wait() 242 | return tensor_recv 243 | 244 | 245 | def neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None): 246 | tensor_from_left = torch.zeros_like(tensor_to_right) 247 | tensor_from_right = torch.zeros_like(tensor_to_left) 248 | send_op_left = torch.distributed.P2POp( 249 | torch.distributed.isend, 250 | tensor_to_left, 251 | left_rank, 252 | group=group, 253 | ) 254 | send_op_right = torch.distributed.P2POp( 255 | torch.distributed.isend, 256 | tensor_to_right, 257 | right_rank, 258 | group=group, 259 | ) 260 | recv_op_left = torch.distributed.P2POp( 261 | torch.distributed.irecv, 262 | tensor_from_left, 263 | left_rank, 264 | group=group, 265 | ) 266 | recv_op_right = torch.distributed.P2POp( 267 | torch.distributed.irecv, 268 | tensor_from_right, 269 | right_rank, 270 | group=group, 271 | ) 272 | reqs = torch.distributed.batch_isend_irecv([send_op_right, send_op_left, recv_op_right, recv_op_left]) 273 | for req in reqs: 274 | req.wait() 275 | return tensor_from_right, tensor_from_left 276 | 277 | 278 | class NeighbourExchange(torch.autograd.Function): 279 | @staticmethod 280 | def forward(ctx, from_rank, to_rank, group, tensor): 281 | ctx.group = group 282 | ctx.from_rank = from_rank 283 | ctx.to_rank = to_rank 284 | return neighbour_exchange(from_rank, to_rank, tensor, group=group) 285 | 286 | @staticmethod 287 | def backward(ctx, grad_output): 288 | return (None, None, None) + (NeighbourExchange.apply(ctx.to_rank, ctx.from_rank, ctx.group, grad_output),) 289 | 290 | 291 | def neighbour_exchange_with_grad(from_rank, to_rank, tensor, group=None): 292 | return NeighbourExchange.apply(from_rank, to_rank, group, tensor) 293 | 294 | 295 | class NeighbourExchangeBidir(torch.autograd.Function): 296 | @staticmethod 297 | def forward(ctx, left_rank, right_rank, group, tensor_to_left, tensor_to_right): 298 | ctx.group = group 299 | ctx.left_rank = left_rank 300 | ctx.right_rank = right_rank 301 | return neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=group) 302 | 303 | @staticmethod 304 | def backward(ctx, *grad_outputs): 305 | return (None, None, None) + \ 306 | 
NeighbourExchangeBidir.apply(ctx.right_rank, ctx.left_rank, ctx.group, *grad_outputs) 307 | 308 | 309 | def neighbour_exchange_bidir_with_grad(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None): 310 | return NeighbourExchangeBidir.apply(left_rank, right_rank, group, tensor_to_left, tensor_to_right) 311 | 312 | 313 | class SigLipLoss(nn.Module): 314 | """ Sigmoid Loss for Language Image Pre-Training (SigLIP) - https://arxiv.org/abs/2303.15343 315 | 316 | @article{zhai2023sigmoid, 317 | title={Sigmoid loss for language image pre-training}, 318 | author={Zhai, Xiaohua and Mustafa, Basil and Kolesnikov, Alexander and Beyer, Lucas}, 319 | journal={arXiv preprint arXiv:2303.15343}, 320 | year={2023} 321 | } 322 | """ 323 | def __init__( 324 | self, 325 | cache_labels=False, 326 | rank=0, 327 | world_size=1, 328 | bidir=True, 329 | use_horovod=False, 330 | ): 331 | super().__init__() 332 | self.cache_labels = cache_labels 333 | self.rank = rank 334 | self.world_size = world_size 335 | assert not use_horovod # FIXME need to look at hvd ops for ring transfers 336 | self.use_horovod = use_horovod 337 | self.bidir = bidir 338 | 339 | # cache state FIXME cache not currently used, worthwhile? 340 | self.prev_num_logits = 0 341 | self.labels = {} 342 | 343 | def get_ground_truth(self, device, dtype, num_logits, negative_only=False) -> torch.Tensor: 344 | labels = -torch.ones((num_logits, num_logits), device=device, dtype=dtype) 345 | if not negative_only: 346 | labels = 2 * torch.eye(num_logits, device=device, dtype=dtype) + labels 347 | return labels 348 | 349 | def get_logits(self, image_features, text_features, logit_scale, logit_bias=None): 350 | logits = logit_scale * image_features @ text_features.T 351 | if logit_bias is not None: 352 | logits += logit_bias 353 | return logits 354 | 355 | def _loss(self, image_features, text_features, logit_scale, logit_bias=None, negative_only=False): 356 | logits = self.get_logits(image_features, text_features, logit_scale, logit_bias) 357 | labels = self.get_ground_truth( 358 | image_features.device, 359 | image_features.dtype, 360 | image_features.shape[0], 361 | negative_only=negative_only, 362 | ) 363 | loss = -F.logsigmoid(labels * logits).sum() / image_features.shape[0] 364 | return loss 365 | 366 | def forward(self, image_features, text_features, logit_scale, logit_bias, output_dict=False): 367 | loss = self._loss(image_features, text_features, logit_scale, logit_bias) 368 | 369 | if self.world_size > 1: 370 | # exchange text features w/ neighbour world_size - 1 times 371 | right_rank = (self.rank + 1) % self.world_size 372 | left_rank = (self.rank - 1 + self.world_size) % self.world_size 373 | if self.bidir: 374 | text_features_to_right = text_features_to_left = text_features 375 | num_bidir, remainder = divmod(self.world_size - 1, 2) 376 | for i in range(num_bidir): 377 | text_features_recv = neighbour_exchange_bidir_with_grad( 378 | left_rank, 379 | right_rank, 380 | text_features_to_left, 381 | text_features_to_right, 382 | ) 383 | 384 | for f in text_features_recv: 385 | loss += self._loss( 386 | image_features, 387 | f, 388 | logit_scale, 389 | logit_bias, 390 | negative_only=True, 391 | ) 392 | text_features_to_left, text_features_to_right = text_features_recv 393 | 394 | if remainder: 395 | text_features_recv = neighbour_exchange_with_grad( 396 | left_rank, right_rank, text_features_to_right) 397 | 398 | loss += self._loss( 399 | image_features, 400 | text_features_recv, 401 | logit_scale, 402 | logit_bias, 403 | 
negative_only=True, 404 | ) 405 | else: 406 | text_features_to_right = text_features 407 | for i in range(self.world_size - 1): 408 | text_features_from_left = neighbour_exchange_with_grad( 409 | left_rank, right_rank, text_features_to_right) 410 | 411 | loss += self._loss( 412 | image_features, 413 | text_features_from_left, 414 | logit_scale, 415 | logit_bias, 416 | negative_only=True, 417 | ) 418 | text_features_to_right = text_features_from_left 419 | 420 | return {"contrastive_loss": loss} if output_dict else loss 421 | 422 | 423 | 424 | class ClsLoss(nn.Module): 425 | def __init__( 426 | self, 427 | world_size, 428 | pad_id=0, # pad_token for open_clip custom tokenizer 429 | ): 430 | super().__init__() 431 | 432 | self.pad_id = pad_id 433 | self.world_size = world_size 434 | print('loss ignore id ', pad_id) 435 | 436 | def loss(self, logits, targets): 437 | norm_item = F.normalize(targets, p=1, dim=1) 438 | loss = -(F.log_softmax(logits, dim=1) * norm_item).sum(dim=1).mean() 439 | return loss 440 | 441 | def reweight_targets(self, cap_fq, num_samples, targets): 442 | cap_fq += targets.sum(dim=0, keepdim=True) / targets.shape[0] 443 | num_samples += 1 444 | dist.all_reduce(cap_fq, op=dist.ReduceOp.AVG) 445 | dist.all_reduce(num_samples, op=dist.ReduceOp.AVG) 446 | all_batch_size = self.world_size * targets.shape[0] 447 | targets = targets * torch.log((num_samples+1.0/all_batch_size) / (cap_fq+1.0/all_batch_size)).to(dtype=targets.dtype) 448 | return targets 449 | 450 | def forward(self, cap_fq, num_samples, logits, labels, logit_scale, output_dict=False): 451 | B, C = logits.shape 452 | 453 | targets = torch.zeros(B, C, dtype=torch.float32).to(labels.device) 454 | # scatter labels to one-hot 455 | targets.scatter_(dim=1, index=labels, value=1.0) 456 | 457 | targets = self.reweight_targets(cap_fq, num_samples, targets) 458 | class_loss = self.loss(logits, targets) 459 | 460 | if output_dict: 461 | return {"class_loss": class_loss} 462 | 463 | return class_loss -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/x-cls/superclass/90b761f41177400caa178d3845b39eba771e5bf1/opencls/open_clip/model_configs/.DS_Store -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/CLS-ViT-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 0, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16, 8 | "pool_type": "avg" 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "layers": 0, 14 | "mlp_ratio": 4 15 | } 16 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/CLS-ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 0, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14, 8 | "pool_type": "avg" 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "layers": 0, 14 | "mlp_ratio": 4 15 | } 16 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/CLS-ViT-L-16.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"embed_dim": 0, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16, 8 | "pool_type": "avg" 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "layers": 0, 14 | "mlp_ratio": 4 15 | } 16 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/EVA01-g-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva_giant_patch14_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/EVA01-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva_giant_patch14_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 768, 14 | "heads": 12, 15 | "layers": 12 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/EVA02-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva02_base_patch16_clip_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/EVA02-E-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva02_enormous_patch14_clip_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1280, 14 | "heads": 20, 15 | "layers": 32 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/EVA02-E-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva02_enormous_patch14_clip_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/EVA02-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | 
"timm_model_name": "eva02_large_patch14_clip_336", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 768, 14 | "heads": 12, 15 | "layers": 12 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/EVA02-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva02_large_patch14_clip_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 768, 14 | "heads": 12, 15 | "layers": 12 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/RN101-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 23, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/RN101.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 23, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/RN50-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 6, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/RN50.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 6, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/RN50x16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 384, 5 | "layers": [ 6 | 6, 7 | 8, 8 | 18, 9 | 8 10 | ], 11 | "width": 96, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 768, 18 | "heads": 12, 19 | 
"layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/RN50x4.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 288, 5 | "layers": [ 6 | 4, 7 | 6, 8 | 10, 9 | 6 10 | ], 11 | "width": 80, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 640, 18 | "heads": 10, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/RN50x64.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": [ 6 | 3, 7 | 15, 8 | 36, 9 | 10 10 | ], 11 | "width": 128, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 1024, 18 | "heads": 16, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-B-16-SigLIP-256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 256, 7 | "timm_model_name": "vit_base_patch16_siglip_256", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 64, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 768, 20 | "heads": 12, 21 | "layers": 12, 22 | "no_causal_mask": true, 23 | "proj_bias": true, 24 | "pool_type": "last", 25 | "norm_kwargs":{ 26 | "eps": 1e-6 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-B-16-SigLIP-384.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 384, 7 | "timm_model_name": "vit_base_patch16_siglip_384", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 64, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 768, 20 | "heads": 12, 21 | "layers": 12, 22 | "no_causal_mask": true, 23 | "proj_bias": true, 24 | "pool_type": "last", 25 | "norm_kwargs":{ 26 | "eps": 1e-6 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-B-16-SigLIP-512.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 512, 7 | "timm_model_name": "vit_base_patch16_siglip_512", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 64, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 768, 20 | "heads": 12, 21 | "layers": 12, 
22 | "no_causal_mask": true, 23 | "proj_bias": true, 24 | "pool_type": "last", 25 | "norm_kwargs":{ 26 | "eps": 1e-6 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-B-16-SigLIP-i18n-256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 256, 7 | "timm_model_name": "vit_base_patch16_siglip_256", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 64, 14 | "vocab_size": 250000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP-i18n-256", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 768, 20 | "heads": 12, 21 | "layers": 12, 22 | "no_causal_mask": true, 23 | "proj_bias": true, 24 | "pool_type": "last", 25 | "norm_kwargs":{ 26 | "eps": 1e-6 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-B-16-SigLIP.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 224, 7 | "timm_model_name": "vit_base_patch16_siglip_224", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 64, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 768, 20 | "heads": 12, 21 | "layers": 12, 22 | "no_causal_mask": true, 23 | "proj_bias": true, 24 | "pool_type": "last", 25 | "norm_kwargs":{ 26 | "eps": 1e-6 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-B-16-avg.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16, 8 | "pool_type": "avg" 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-B-16-plus-240.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 240, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-B-16-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-B-16-quickgelu.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 16 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-B-32-256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 256, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-B-32-plus-256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 256, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-B-32-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-H-14-378-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 378, 6 | "layers": 32, 7 | "width": 1280, 8 | "head_width": 80, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1024, 15 | "heads": 16, 16 | "layers": 24 17 | } 18 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-H-14-CLIPA-336.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14, 9 | "no_ln_pre": true, 10 | "pool_type": "avg", 11 | "final_ln_after_pool": true 12 | }, 13 | "text_cfg": { 14 | "context_length": 32, 15 | "vocab_size": 32000, 16 | "hf_tokenizer_name": "bert-base-uncased", 17 | "tokenizer_kwargs": { 18 | "strip_sep_token": true 19 | }, 20 | "width": 1024, 21 | "heads": 16, 22 | "layers": 24, 23 | "pool_type": "last", 24 | "no_causal_mask": true 25 | } 26 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-H-14-CLIPA.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14, 9 | "no_ln_pre": true, 10 | "pool_type": "avg", 11 | "final_ln_after_pool": true 12 | }, 13 | "text_cfg": { 14 | "context_length": 32, 15 | "vocab_size": 32000, 16 | "hf_tokenizer_name": "bert-base-uncased", 17 | "tokenizer_kwargs": { 18 | "strip_sep_token": true 19 | }, 20 | "width": 1024, 21 | "heads": 16, 22 | "layers": 24, 23 | "pool_type": "last", 24 | "no_causal_mask": true 25 | } 26 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-H-14-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 32, 7 | "width": 1280, 8 | "head_width": 80, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1024, 15 | "heads": 16, 16 | "layers": 24 17 | } 18 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | } 17 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-H-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 16 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | } 17 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-L-14-280.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 280, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- 
/opencls/open_clip/model_configs/ViT-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-L-14-CLIPA-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14, 8 | "no_ln_pre": true, 9 | "pool_type": "avg", 10 | "final_ln_after_pool": true 11 | }, 12 | "text_cfg": { 13 | "context_length": 32, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "bert-base-uncased", 16 | "tokenizer_kwargs": { 17 | "strip_sep_token": true 18 | }, 19 | "width": 768, 20 | "heads": 12, 21 | "layers": 12, 22 | "pool_type": "last", 23 | "no_causal_mask": true 24 | } 25 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-L-14-CLIPA.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14, 8 | "no_ln_pre": true, 9 | "pool_type": "avg", 10 | "final_ln_after_pool": true 11 | }, 12 | "text_cfg": { 13 | "context_length": 32, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "bert-base-uncased", 16 | "tokenizer_kwargs": { 17 | "strip_sep_token": true 18 | }, 19 | "width": 768, 20 | "heads": 12, 21 | "layers": 12, 22 | "pool_type": "last", 23 | "no_causal_mask": true 24 | } 25 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-L-14-avg.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14, 8 | "pool_type": "avg" 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 768, 14 | "heads": 12, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-L-14-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 24, 7 | "width": 1024, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 768, 14 | "heads": 12, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-L-16-320.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 320, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-L-16-SigLIP-256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 256, 7 | "timm_model_name": "vit_large_patch16_siglip_256", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 64, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 1024, 20 | "heads": 16, 21 | "layers": 24, 22 | "no_causal_mask": true, 23 | "proj_bias": true, 24 | "pool_type": "last", 25 | "norm_kwargs":{ 26 | "eps": 1e-6 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-L-16-SigLIP-384.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 384, 7 | "timm_model_name": "vit_large_patch16_siglip_384", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 64, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 1024, 20 | "heads": 16, 21 | "layers": 24, 22 | "no_causal_mask": true, 23 | "proj_bias": true, 24 | "pool_type": "last", 25 | "norm_kwargs":{ 26 | "eps": 1e-6 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-L-16-avg.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16, 8 | "pool_type": "avg" 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 768, 14 | "heads": 12, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-L-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-M-16-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 16, 8 | "ls_init_value": 1e-4 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 
49408, 13 | "width": 384, 14 | "heads": 6, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-M-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-M-32-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 384, 13 | "heads": 6, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-M-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-S-16-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 256, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 256, 13 | "heads": 4, 14 | "layers": 10 15 | } 16 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-S-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 384, 13 | "heads": 6, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-S-32-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 256, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 256, 13 | "heads": 4, 14 | "layers": 10 15 | } 16 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-S-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 384, 13 | "heads": 6, 14 | "layers": 12 15 | } 16 | } 
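The JSON files under opencls/open_clip/model_configs/ follow the usual open_clip convention: the file stem doubles as the model name, and each config declares a shared "embed_dim" together with a "vision_cfg"/"text_cfg" pair that factory.py resolves by name when assembling the image and text towers. A minimal sketch for inspecting one of these configs with only the standard library, assuming a local checkout of the repository (the chosen file name and path are illustrative, not part of the repository's API):

import json
from pathlib import Path

# Assumed local checkout path; adjust to where the repo lives.
cfg_path = Path("opencls/open_clip/model_configs/ViT-B-16.json")

with cfg_path.open() as f:
    cfg = json.load(f)

# Shared embedding width plus per-tower settings, as declared in the JSON.
print("embed_dim :", cfg["embed_dim"])
print("vision_cfg:", cfg["vision_cfg"])
print("text_cfg  :", cfg["text_cfg"])

Note that the CLS-ViT-*.json variants earlier in this directory set "embed_dim" to 0 and the text "layers" to 0, which is consistent with the classification-style models in this repository training only a vision tower (no text encoder or joint projection); the cls_* and lit_* schedules under opencls/configs/cls_schedule select these configs by name.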
-------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-SO400M-14-SigLIP-384.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1152, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 384, 7 | "timm_model_name": "vit_so400m_patch14_siglip_384", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 64, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 1152, 20 | "heads": 16, 21 | "layers": 27, 22 | "mlp_ratio": 3.7362, 23 | "no_causal_mask": true, 24 | "proj_bias": true, 25 | "pool_type": "last", 26 | "norm_kwargs":{ 27 | "eps": 1e-6 28 | } 29 | } 30 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-SO400M-14-SigLIP.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1152, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 224, 7 | "timm_model_name": "vit_so400m_patch14_siglip_224", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 16, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 1152, 20 | "heads": 16, 21 | "layers": 27, 22 | "mlp_ratio": 3.7362, 23 | "no_causal_mask": true, 24 | "proj_bias": true, 25 | "pool_type": "last", 26 | "norm_kwargs":{ 27 | "eps": 1e-6 28 | } 29 | } 30 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-bigG-14-CLIPA-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 48, 6 | "width": 1664, 7 | "head_width": 104, 8 | "mlp_ratio": 4.9231, 9 | "patch_size": 14, 10 | "no_ln_pre": true, 11 | "pool_type": "avg", 12 | "final_ln_after_pool": true 13 | }, 14 | "text_cfg": { 15 | "context_length": 32, 16 | "vocab_size": 32000, 17 | "hf_tokenizer_name": "bert-base-uncased", 18 | "tokenizer_kwargs": { 19 | "strip_sep_token": true 20 | }, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "pool_type": "last", 25 | "no_causal_mask": true 26 | } 27 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-bigG-14-CLIPA.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 48, 6 | "width": 1664, 7 | "head_width": 104, 8 | "mlp_ratio": 4.9231, 9 | "patch_size": 14, 10 | "no_ln_pre": true, 11 | "pool_type": "avg", 12 | "final_ln_after_pool": true 13 | }, 14 | "text_cfg": { 15 | "context_length": 32, 16 | "vocab_size": 32000, 17 | "hf_tokenizer_name": "bert-base-uncased", 18 | "tokenizer_kwargs": { 19 | "strip_sep_token": true 20 | }, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "pool_type": "last", 25 | "no_causal_mask": true 26 | } 27 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-bigG-14.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 48, 6 | "width": 1664, 7 | "head_width": 104, 8 | "mlp_ratio": 4.9231, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1280, 15 | "heads": 20, 16 | "layers": 32 17 | } 18 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-e-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 56, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.5715, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1280, 15 | "heads": 20, 16 | "layers": 36 17 | } 18 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/ViT-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1024, 15 | "heads": 16, 16 | "layers": 24 17 | } 18 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/coca_ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32, 8 | "attentional_pool": true, 9 | "attn_pooler_heads": 8, 10 | "output_tokens": true 11 | }, 12 | "text_cfg": { 13 | "context_length": 76, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12, 18 | "embed_cls": true, 19 | "output_tokens": true 20 | }, 21 | "multimodal_cfg": { 22 | "context_length": 76, 23 | "vocab_size": 49408, 24 | "width": 512, 25 | "heads": 8, 26 | "layers": 12, 27 | "attn_pooler_heads": 8 28 | }, 29 | "custom_text": true 30 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/coca_ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14, 8 | "attentional_pool": true, 9 | "attn_pooler_heads": 8, 10 | "output_tokens": true 11 | }, 12 | "text_cfg": { 13 | "context_length": 76, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 12, 18 | "embed_cls": true, 19 | "output_tokens": true 20 | }, 21 | "multimodal_cfg": { 22 | "context_length": 76, 23 | "vocab_size": 49408, 24 | "width": 768, 25 | "heads": 12, 26 | "layers": 12, 27 | "attn_pooler_heads": 12 28 | }, 29 | "custom_text": true 30 | } 31 | -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/coca_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "multimodal_cfg": { 4 | "width": 768, 5 | "context_length": 76, 6 | "vocab_size": 64000, 7 | "mlp_ratio": 4, 8 | "layers": 12, 9 | "dim_head": 64, 10 | "heads": 12, 11 | 
"n_queries": 256, 12 | "attn_pooler_heads": 8 13 | }, 14 | "vision_cfg": { 15 | "image_size": 288, 16 | "layers": 12, 17 | "width": 768, 18 | "patch_size": 18, 19 | "output_tokens": true 20 | }, 21 | "text_cfg": { 22 | "context_length": 76, 23 | "vocab_size": 64000, 24 | "layers": 12, 25 | "heads": 12, 26 | "width": 768, 27 | "embed_cls": true, 28 | "output_tokens": true 29 | }, 30 | "custom_text": true 31 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/coca_roberta-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32, 8 | "output_tokens": true 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "roberta-base", 12 | "hf_tokenizer_name": "roberta-base", 13 | "hf_proj_type": "linear", 14 | "width": 768, 15 | "output_tokens": true 16 | }, 17 | "multimodal_cfg": { 18 | "context_length": 76, 19 | "width": 768, 20 | "heads": 8, 21 | "layers": 12 22 | }, 23 | "custom_text": true 24 | } 25 | -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/convnext_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_base", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/convnext_base_w.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_base", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 640, 16 | "heads": 10, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/convnext_base_w_320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_base", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 320 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 640, 16 | "heads": 10, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/convnext_large.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_large", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | 
"layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/convnext_large_d.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_large", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "mlp", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 16 18 | } 19 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/convnext_large_d_320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_large", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "mlp", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 320 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 16 18 | } 19 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/convnext_small.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_small", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/convnext_tiny.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_tiny", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/convnext_xlarge.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_xlarge", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 1024, 16 | "heads": 16, 17 | "layers": 20 18 | } 19 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/convnext_xxlarge.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_xxlarge", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 
12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 1024, 16 | "heads": 16, 17 | "layers": 24 18 | } 19 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/convnext_xxlarge_320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_xxlarge", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 320 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 1024, 16 | "heads": 16, 17 | "layers": 24 18 | } 19 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/mt5-base-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "hf_model_name": "google/mt5-base", 11 | "hf_tokenizer_name": "google/mt5-base", 12 | "hf_pooler_type": "mean_pooler" 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/mt5-xl-ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "google/mt5-xl", 12 | "hf_tokenizer_name": "google/mt5-xl", 13 | "hf_pooler_type": "mean_pooler" 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/nllb-clip-base-siglip.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "custom_text": true, 4 | "init_logit_bias": -10, 5 | "vision_cfg": { 6 | "image_size": 384, 7 | "timm_model_name": "vit_base_patch16_siglip_384", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "hf_model_name": "facebook/nllb-200-distilled-600M", 14 | "hf_tokenizer_name": "facebook/nllb-200-distilled-600M", 15 | "hf_proj_type": "linear", 16 | "hf_pooler_type": "cls_pooler" 17 | } 18 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/nllb-clip-base.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "hf_model_name": "facebook/nllb-200-distilled-600M", 11 | "hf_tokenizer_name": "facebook/nllb-200-distilled-600M", 12 | "hf_proj_type": "linear", 13 | "hf_pooler_type": "cls_pooler" 14 | } 15 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/nllb-clip-large-siglip.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1152, 3 | "custom_text": true, 4 | "init_logit_bias": -10, 5 | "vision_cfg": { 6 | "image_size": 384, 7 | "timm_model_name": "vit_so400m_patch14_siglip_384", 8 | "timm_model_pretrained": false, 9 | "timm_pool": 
"map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "hf_model_name": "facebook/nllb-200-distilled-1.3B", 14 | "hf_tokenizer_name": "facebook/nllb-200-distilled-1.3B", 15 | "hf_proj_type": "linear", 16 | "hf_pooler_type": "cls_pooler" 17 | } 18 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/nllb-clip-large.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "facebook/nllb-200-distilled-1.3B", 12 | "hf_tokenizer_name": "facebook/nllb-200-distilled-1.3B", 13 | "hf_proj_type": "linear", 14 | "hf_pooler_type": "cls_pooler" 15 | } 16 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/roberta-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "roberta-base", 12 | "hf_tokenizer_name": "roberta-base", 13 | "hf_pooler_type": "mean_pooler" 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/swin_base_patch4_window7_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "timm_model_name": "swin_base_patch4_window7_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 640, 14 | "heads": 10, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/vit_medium_patch16_gap_256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_medium_patch16_gap_256", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 256 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_relpos_medium_patch16_cls_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "hf_model_name": 
"xlm-roberta-base", 11 | "hf_tokenizer_name": "xlm-roberta-base", 12 | "hf_pooler_type": "mean_pooler" 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /opencls/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "xlm-roberta-large", 12 | "hf_tokenizer_name": "xlm-roberta-large", 13 | "hf_pooler_type": "mean_pooler" 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /opencls/open_clip/modified_resnet.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from open_clip.utils import freeze_batch_norm_2d 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.act1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.act2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.act3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([ 37 | ("-1", nn.AvgPool2d(stride)), 38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 39 | ("1", nn.BatchNorm2d(planes * self.expansion)) 40 | ])) 41 | 42 | def forward(self, x: torch.Tensor): 43 | identity = x 44 | 45 | out = self.act1(self.bn1(self.conv1(x))) 46 | out = self.act2(self.bn2(self.conv2(out))) 47 | out = self.avgpool(out) 48 | out = self.bn3(self.conv3(out)) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.act3(out) 55 | return out 56 | 57 | 58 | class AttentionPool2d(nn.Module): 59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 60 | super().__init__() 61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 62 | self.k_proj = nn.Linear(embed_dim, embed_dim) 63 | self.q_proj = nn.Linear(embed_dim, embed_dim) 64 | self.v_proj = nn.Linear(embed_dim, embed_dim) 65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 66 | self.num_heads = num_heads 67 | 68 | def forward(self, x): 69 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 72 | x, _ = F.multi_head_attention_forward( 73 | 
query=x, key=x, value=x, 74 | embed_dim_to_check=x.shape[-1], 75 | num_heads=self.num_heads, 76 | q_proj_weight=self.q_proj.weight, 77 | k_proj_weight=self.k_proj.weight, 78 | v_proj_weight=self.v_proj.weight, 79 | in_proj_weight=None, 80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 81 | bias_k=None, 82 | bias_v=None, 83 | add_zero_attn=False, 84 | dropout_p=0., 85 | out_proj_weight=self.c_proj.weight, 86 | out_proj_bias=self.c_proj.bias, 87 | use_separate_proj_weight=True, 88 | training=self.training, 89 | need_weights=False 90 | ) 91 | 92 | return x[0] 93 | 94 | 95 | class ModifiedResNet(nn.Module): 96 | """ 97 | A ResNet class that is similar to torchvision's but contains the following changes: 98 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 99 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 100 | - The final pooling layer is a QKV attention instead of an average pool 101 | """ 102 | 103 | def __init__(self, layers, output_dim, heads, image_size=224, width=64): 104 | super().__init__() 105 | self.output_dim = output_dim 106 | self.image_size = image_size 107 | 108 | # the 3-layer stem 109 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 110 | self.bn1 = nn.BatchNorm2d(width // 2) 111 | self.act1 = nn.ReLU(inplace=True) 112 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 113 | self.bn2 = nn.BatchNorm2d(width // 2) 114 | self.act2 = nn.ReLU(inplace=True) 115 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 116 | self.bn3 = nn.BatchNorm2d(width) 117 | self.act3 = nn.ReLU(inplace=True) 118 | self.avgpool = nn.AvgPool2d(2) 119 | 120 | # residual layers 121 | self._inplanes = width # this is a *mutable* variable used during construction 122 | self.layer1 = self._make_layer(width, layers[0]) 123 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 124 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 125 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 126 | 127 | embed_dim = width * 32 # the ResNet feature dimension 128 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) 129 | 130 | self.init_parameters() 131 | 132 | def _make_layer(self, planes, blocks, stride=1): 133 | layers = [Bottleneck(self._inplanes, planes, stride)] 134 | 135 | self._inplanes = planes * Bottleneck.expansion 136 | for _ in range(1, blocks): 137 | layers.append(Bottleneck(self._inplanes, planes)) 138 | 139 | return nn.Sequential(*layers) 140 | 141 | def init_parameters(self): 142 | if self.attnpool is not None: 143 | std = self.attnpool.c_proj.in_features ** -0.5 144 | nn.init.normal_(self.attnpool.q_proj.weight, std=std) 145 | nn.init.normal_(self.attnpool.k_proj.weight, std=std) 146 | nn.init.normal_(self.attnpool.v_proj.weight, std=std) 147 | nn.init.normal_(self.attnpool.c_proj.weight, std=std) 148 | 149 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: 150 | for name, param in resnet_block.named_parameters(): 151 | if name.endswith("bn3.weight"): 152 | nn.init.zeros_(param) 153 | 154 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 155 | assert unlocked_groups == 0, 'partial locking not currently supported for this model' 156 | for param in self.parameters(): 157 | param.requires_grad = False 158 | if freeze_bn_stats: 159 | 
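# Note: freeze_batch_norm_2d (imported from open_clip.utils, defined later in this repo) swaps
# BatchNorm2d / SyncBatchNorm modules for FrozenBatchNorm2d so that running statistics stop
# updating while the trunk is locked.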
freeze_batch_norm_2d(self) 160 | 161 | @torch.jit.ignore 162 | def set_grad_checkpointing(self, enable=True): 163 | # FIXME support for non-transformer 164 | pass 165 | 166 | def stem(self, x): 167 | x = self.act1(self.bn1(self.conv1(x))) 168 | x = self.act2(self.bn2(self.conv2(x))) 169 | x = self.act3(self.bn3(self.conv3(x))) 170 | x = self.avgpool(x) 171 | return x 172 | 173 | def forward(self, x): 174 | x = self.stem(x) 175 | x = self.layer1(x) 176 | x = self.layer2(x) 177 | x = self.layer3(x) 178 | x = self.layer4(x) 179 | x = self.attnpool(x) 180 | 181 | return x 182 | -------------------------------------------------------------------------------- /opencls/open_clip/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import List, Optional, Union 9 | 10 | import torch 11 | 12 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 13 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype 14 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url 15 | 16 | __all__ = ["list_openai_models", "load_openai_model"] 17 | 18 | 19 | def list_openai_models() -> List[str]: 20 | """Returns the names of available CLIP models""" 21 | return list_pretrained_models_by_tag('openai') 22 | 23 | 24 | def load_openai_model( 25 | name: str, 26 | precision: Optional[str] = None, 27 | device: Optional[Union[str, torch.device]] = None, 28 | cache_dir: Optional[str] = None, 29 | ): 30 | """Load a CLIP model 31 | 32 | Parameters 33 | ---------- 34 | name : str 35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 36 | precision: str 37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. 
38 | device : Union[str, torch.device] 39 | The device to put the loaded model 40 | cache_dir : Optional[str] 41 | The directory to cache the downloaded model weights 42 | 43 | Returns 44 | ------- 45 | model : torch.nn.Module 46 | The CLIP model 47 | preprocess : Callable[[PIL.Image], torch.Tensor] 48 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 49 | """ 50 | if device is None: 51 | device = "cuda" if torch.cuda.is_available() else "cpu" 52 | if precision is None: 53 | precision = 'fp32' if device == 'cpu' else 'fp16' 54 | 55 | if get_pretrained_url(name, 'openai'): 56 | model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) 57 | elif os.path.isfile(name): 58 | model_path = name 59 | else: 60 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 61 | 62 | try: 63 | # loading JIT archive 64 | model = torch.jit.load(model_path, map_location="cpu").eval() 65 | state_dict = None 66 | except RuntimeError: 67 | # loading saved state dict 68 | state_dict = torch.load(model_path, map_location="cpu") 69 | 70 | # Build a non-jit model from the OpenAI jitted model state dict 71 | cast_dtype = get_cast_dtype(precision) 72 | try: 73 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) 74 | except KeyError: 75 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 76 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) 77 | 78 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use 79 | model = model.to(device) 80 | # FIXME support pure fp16/bf16 precision modes 81 | if precision != 'fp16': 82 | model.float() 83 | if precision == 'bf16': 84 | # for bf16, convert back to low-precision 85 | convert_weights_to_lp(model, dtype=torch.bfloat16) 86 | 87 | # add mean / std attributes for consistency with OpenCLIP models 88 | model.visual.image_mean = OPENAI_DATASET_MEAN 89 | model.visual.image_std = OPENAI_DATASET_STD 90 | return model 91 | -------------------------------------------------------------------------------- /opencls/open_clip/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
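# A minimal usage sketch for load_openai_model from openai.py above; the model tag and import
# path are illustrative, and any name returned by list_openai_models() should work:
#
#   import torch
#   from open_clip.openai import load_openai_model
#
#   model = load_openai_model('ViT-B-32', precision='fp32', device='cpu')
#   with torch.no_grad():
#       image_features = model.encode_image(torch.randn(1, 3, 224, 224))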
6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert embed_dim % 2 == 0 40 | 41 | # use half of dimensions to encode grid_h 42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 44 | 45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 46 | return emb 47 | 48 | 49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 50 | """ 51 | embed_dim: output dimension for each position 52 | pos: a list of positions to be encoded: size (M,) 53 | out: (M, D) 54 | """ 55 | assert embed_dim % 2 == 0 56 | omega = np.arange(embed_dim // 2, dtype=float) 57 | omega /= embed_dim / 2. 58 | omega = 1. 
/ 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 89 | # only the position tokens are interpolated 90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 92 | pos_tokens = torch.nn.functional.interpolate( 93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 96 | checkpoint_model['pos_embed'] = new_pos_embed 97 | -------------------------------------------------------------------------------- /opencls/open_clip/push_to_hf_hub.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from pathlib import Path 5 | from tempfile import TemporaryDirectory 6 | from typing import Optional, Tuple, Union 7 | 8 | import torch 9 | 10 | try: 11 | from huggingface_hub import ( 12 | create_repo, 13 | get_hf_file_metadata, 14 | hf_hub_download, 15 | hf_hub_url, 16 | repo_type_and_id_from_hf_id, 17 | upload_folder, 18 | list_repo_files, 19 | ) 20 | from huggingface_hub.utils import EntryNotFoundError 21 | _has_hf_hub = True 22 | except ImportError: 23 | _has_hf_hub = False 24 | 25 | try: 26 | import safetensors.torch 27 | _has_safetensors = True 28 | except ImportError: 29 | _has_safetensors = False 30 | 31 | from .factory import create_model_from_pretrained, get_model_config, get_tokenizer 32 | from .tokenizer import HFTokenizer 33 | 34 | # Default name for a weights file hosted on the Huggingface Hub. 
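# A quick shape check for get_2d_sincos_pos_embed from pos_embed.py above (values illustrative):
#
#   pe = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
#   # pe.shape == (1 + 14 * 14, 768); the prepended cls-token row is all zeros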
35 | HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl 36 | HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version 37 | HF_CONFIG_NAME = 'open_clip_config.json' 38 | 39 | 40 | def save_config_for_hf( 41 | model, 42 | config_path: str, 43 | model_config: Optional[dict] 44 | ): 45 | preprocess_cfg = { 46 | 'mean': model.visual.image_mean, 47 | 'std': model.visual.image_std, 48 | } 49 | other_pp = getattr(model.visual, 'preprocess_cfg', {}) 50 | if 'interpolation' in other_pp: 51 | preprocess_cfg['interpolation'] = other_pp['interpolation'] 52 | if 'resize_mode' in other_pp: 53 | preprocess_cfg['resize_mode'] = other_pp['resize_mode'] 54 | hf_config = { 55 | 'model_cfg': model_config, 56 | 'preprocess_cfg': preprocess_cfg, 57 | } 58 | 59 | with config_path.open('w') as f: 60 | json.dump(hf_config, f, indent=2) 61 | 62 | 63 | def save_for_hf( 64 | model, 65 | tokenizer: HFTokenizer, 66 | model_config: dict, 67 | save_directory: str, 68 | safe_serialization: Union[bool, str] = 'both', 69 | skip_weights : bool = False, 70 | ): 71 | config_filename = HF_CONFIG_NAME 72 | 73 | save_directory = Path(save_directory) 74 | save_directory.mkdir(exist_ok=True, parents=True) 75 | 76 | if not skip_weights: 77 | tensors = model.state_dict() 78 | if safe_serialization is True or safe_serialization == "both": 79 | assert _has_safetensors, "`pip install safetensors` to use .safetensors" 80 | safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME) 81 | if safe_serialization is False or safe_serialization == "both": 82 | torch.save(tensors, save_directory / HF_WEIGHTS_NAME) 83 | 84 | tokenizer.save_pretrained(save_directory) 85 | 86 | config_path = save_directory / config_filename 87 | save_config_for_hf(model, config_path, model_config=model_config) 88 | 89 | 90 | def push_to_hf_hub( 91 | model, 92 | tokenizer, 93 | model_config: Optional[dict], 94 | repo_id: str, 95 | commit_message: str = 'Add model', 96 | token: Optional[str] = None, 97 | revision: Optional[str] = None, 98 | private: bool = False, 99 | create_pr: bool = False, 100 | model_card: Optional[dict] = None, 101 | safe_serialization: Union[bool, str] = False, 102 | ): 103 | if not isinstance(tokenizer, HFTokenizer): 104 | # FIXME this makes it awkward to push models with new tokenizers, come up with better soln. 105 | # default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14 106 | tokenizer = HFTokenizer('openai/clip-vit-large-patch14') 107 | 108 | # Create repo if it doesn't exist yet 109 | repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True) 110 | 111 | # Infer complete repo_id from repo_url 112 | # Can be different from the input `repo_id` if repo_owner was implicit 113 | _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url) 114 | repo_id = f"{repo_owner}/{repo_name}" 115 | 116 | # Check if repo already exists and determine what needs updating 117 | repo_exists = False 118 | repo_files = {} 119 | try: 120 | repo_files = set(list_repo_files(repo_id)) 121 | repo_exists = True 122 | except Exception as e: 123 | print('Repo does not exist', e) 124 | 125 | try: 126 | get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision)) 127 | has_readme = True 128 | except EntryNotFoundError: 129 | has_readme = False 130 | 131 | # Dump model and push to Hub 132 | with TemporaryDirectory() as tmpdir: 133 | # Save model weights and config. 
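# At this point save_for_hf (defined above) writes the weights (open_clip_pytorch_model.bin
# and/or open_clip_model.safetensors, depending on safe_serialization), the tokenizer files, and
# open_clip_config.json into the temporary directory before it is uploaded.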
134 | save_for_hf( 135 | model, 136 | tokenizer=tokenizer, 137 | model_config=model_config, 138 | save_directory=tmpdir, 139 | safe_serialization=safe_serialization, 140 | ) 141 | 142 | # Add readme if it does not exist 143 | if not has_readme: 144 | model_card = model_card or {} 145 | model_name = repo_id.split('/')[-1] 146 | readme_path = Path(tmpdir) / "README.md" 147 | readme_text = generate_readme(model_card, model_name) 148 | readme_path.write_text(readme_text) 149 | 150 | # Upload model and return 151 | return upload_folder( 152 | repo_id=repo_id, 153 | folder_path=tmpdir, 154 | revision=revision, 155 | create_pr=create_pr, 156 | commit_message=commit_message, 157 | ) 158 | 159 | 160 | def push_pretrained_to_hf_hub( 161 | model_name, 162 | pretrained: str, 163 | repo_id: str, 164 | precision: str = 'fp32', 165 | image_mean: Optional[Tuple[float, ...]] = None, 166 | image_std: Optional[Tuple[float, ...]] = None, 167 | image_interpolation: Optional[str] = None, 168 | image_resize_mode: Optional[str] = None, # only effective for inference 169 | commit_message: str = 'Add model', 170 | token: Optional[str] = None, 171 | revision: Optional[str] = None, 172 | private: bool = False, 173 | create_pr: bool = False, 174 | model_card: Optional[dict] = None, 175 | hf_tokenizer_self: bool = False, 176 | ): 177 | model, preprocess_eval = create_model_from_pretrained( 178 | model_name, 179 | pretrained=pretrained, 180 | precision=precision, 181 | image_mean=image_mean, 182 | image_std=image_std, 183 | image_interpolation=image_interpolation, 184 | image_resize_mode=image_resize_mode, 185 | ) 186 | model_config = get_model_config(model_name) 187 | assert model_config 188 | 189 | tokenizer = get_tokenizer(model_name) 190 | if hf_tokenizer_self: 191 | # make hf tokenizer config in the uploaded model point to self instead of original location 192 | model_config['text']['hf_tokenizer_name'] = repo_id 193 | 194 | push_to_hf_hub( 195 | model=model, 196 | tokenizer=tokenizer, 197 | model_config=model_config, 198 | repo_id=repo_id, 199 | commit_message=commit_message, 200 | token=token, 201 | revision=revision, 202 | private=private, 203 | create_pr=create_pr, 204 | model_card=model_card, 205 | safe_serialization='both', 206 | ) 207 | 208 | 209 | def generate_readme(model_card: dict, model_name: str): 210 | tags = model_card.pop('tags', ('clip',)) 211 | pipeline_tag = model_card.pop('pipeline_tag', 'zero-shot-image-classification') 212 | readme_text = "---\n" 213 | if tags: 214 | readme_text += "tags:\n" 215 | for t in tags: 216 | readme_text += f"- {t}\n" 217 | readme_text += "library_name: open_clip\n" 218 | readme_text += f"pipeline_tag: {pipeline_tag}\n" 219 | readme_text += f"license: {model_card.get('license', 'mit')}\n" 220 | if 'details' in model_card and 'Dataset' in model_card['details']: 221 | readme_text += 'datasets:\n' 222 | readme_text += f"- {model_card['details']['Dataset'].lower()}\n" 223 | readme_text += "---\n" 224 | readme_text += f"# Model card for {model_name}\n" 225 | if 'description' in model_card: 226 | readme_text += f"\n{model_card['description']}\n" 227 | if 'details' in model_card: 228 | readme_text += f"\n## Model Details\n" 229 | for k, v in model_card['details'].items(): 230 | if isinstance(v, (list, tuple)): 231 | readme_text += f"- **{k}:**\n" 232 | for vi in v: 233 | readme_text += f" - {vi}\n" 234 | elif isinstance(v, dict): 235 | readme_text += f"- **{k}:**\n" 236 | for ki, vi in v.items(): 237 | readme_text += f" - {ki}: {vi}\n" 238 | else: 239 | readme_text += f"- 
**{k}:** {v}\n" 240 | if 'usage' in model_card: 241 | readme_text += f"\n## Model Usage\n" 242 | readme_text += model_card['usage'] 243 | readme_text += '\n' 244 | 245 | if 'comparison' in model_card: 246 | readme_text += f"\n## Model Comparison\n" 247 | readme_text += model_card['comparison'] 248 | readme_text += '\n' 249 | 250 | if 'citation' in model_card: 251 | readme_text += f"\n## Citation\n" 252 | if not isinstance(model_card['citation'], (list, tuple)): 253 | citations = [model_card['citation']] 254 | else: 255 | citations = model_card['citation'] 256 | for c in citations: 257 | readme_text += f"```bibtex\n{c}\n```\n" 258 | 259 | return readme_text 260 | 261 | 262 | if __name__ == "__main__": 263 | parser = argparse.ArgumentParser(description="Push to Hugging Face Hub") 264 | parser.add_argument( 265 | "--model", type=str, help="Name of the model to use.", 266 | ) 267 | parser.add_argument( 268 | "--pretrained", type=str, 269 | help="Use a pretrained CLIP model weights with the specified tag or file path.", 270 | ) 271 | parser.add_argument( 272 | "--repo-id", type=str, 273 | help="Destination HF Hub repo-id ie 'organization/model_id'.", 274 | ) 275 | parser.add_argument( 276 | "--precision", type=str, default='fp32', 277 | ) 278 | parser.add_argument( 279 | '--image-mean', type=float, nargs='+', default=None, metavar='MEAN', 280 | help='Override default image mean value of dataset') 281 | parser.add_argument( 282 | '--image-std', type=float, nargs='+', default=None, metavar='STD', 283 | help='Override default image std deviation of of dataset') 284 | parser.add_argument( 285 | '--image-interpolation', 286 | default=None, type=str, choices=['bicubic', 'bilinear', 'random'], 287 | help="image resize interpolation" 288 | ) 289 | parser.add_argument( 290 | '--image-resize-mode', 291 | default=None, type=str, choices=['shortest', 'longest', 'squash'], 292 | help="image resize mode during inference" 293 | ) 294 | parser.add_argument( 295 | "--hf-tokenizer-self", 296 | default=False, 297 | action="store_true", 298 | help="make hf_tokenizer_name point in uploaded config point to itself" 299 | ) 300 | args = parser.parse_args() 301 | 302 | print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}') 303 | 304 | # FIXME add support to pass model_card json / template from file via cmd line 305 | 306 | push_pretrained_to_hf_hub( 307 | args.model, 308 | args.pretrained, 309 | args.repo_id, 310 | precision=args.precision, 311 | image_mean=args.image_mean, # override image mean/std if trained w/ non defaults 312 | image_std=args.image_std, 313 | image_interpolation=args.image_interpolation, 314 | image_resize_mode=args.image_resize_mode, 315 | ) 316 | 317 | print(f'{args.model} saved.') 318 | -------------------------------------------------------------------------------- /opencls/open_clip/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 
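# An illustrative invocation of push_to_hf_hub.py above; the module path, pretrained tag, and
# repo id are placeholders, not taken from this repo:
#
#   python -m open_clip.push_to_hf_hub --model ViT-B-32 --pretrained laion2b_s34b_b79k \
#       --repo-id your-org/your-clip-model --precision fp32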
4 | """ 5 | import logging 6 | from collections import OrderedDict 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | try: 12 | import timm 13 | from timm.models.layers import Mlp, to_2tuple 14 | try: 15 | # old timm imports < 0.8.1 16 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 17 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d 18 | except ImportError: 19 | # new timm imports >= 0.8.1 20 | from timm.layers import RotAttentionPool2d 21 | from timm.layers import AttentionPool2d as AbsAttentionPool2d 22 | except ImportError: 23 | timm = None 24 | 25 | from .utils import freeze_batch_norm_2d 26 | 27 | 28 | class TimmModel(nn.Module): 29 | """ timm model adapter 30 | """ 31 | 32 | def __init__( 33 | self, 34 | model_name, 35 | embed_dim, 36 | image_size=224, 37 | pool='avg', 38 | proj='linear', 39 | proj_bias=False, 40 | drop=0., 41 | drop_path=None, 42 | patch_drop=None, 43 | pretrained=False, 44 | ): 45 | super().__init__() 46 | if timm is None: 47 | raise RuntimeError("Please `pip install timm` to use timm models.") 48 | self.image_size = to_2tuple(image_size) 49 | 50 | # setup kwargs that may not be common across all models 51 | timm_kwargs = {} 52 | if drop_path is not None: 53 | timm_kwargs['drop_path_rate'] = drop_path 54 | if patch_drop is not None: 55 | timm_kwargs['patch_drop_rate'] = patch_drop 56 | 57 | custom_pool = pool in ('abs_attn', 'rot_attn') 58 | if proj: 59 | assert proj in ("linear", "mlp", "none") 60 | extra_proj = proj in ("linear", "mlp") 61 | if not extra_proj and not custom_pool: 62 | # use network classifier head as projection if no proj specified and no custom pooling used 63 | # if projection is explicitly set to "none" will be pass through from network trunk 64 | proj_dim = 0 if proj == 'none' else embed_dim 65 | self.trunk = timm.create_model( 66 | model_name, 67 | num_classes=proj_dim, 68 | global_pool=pool, 69 | pretrained=pretrained, 70 | **timm_kwargs, 71 | ) 72 | prev_chs = embed_dim 73 | else: 74 | self.trunk = timm.create_model( 75 | model_name, 76 | pretrained=pretrained, 77 | **timm_kwargs, 78 | ) 79 | feat_size = self.trunk.default_cfg.get('pool_size', None) 80 | feature_ndim = 1 if not feat_size else 2 81 | if custom_pool: 82 | assert feature_ndim == 2 83 | # if attn pooling used, remove both classifier and default pool 84 | self.trunk.reset_classifier(0, global_pool='') 85 | else: 86 | # reset global pool if pool config set, otherwise leave as network default 87 | reset_kwargs = dict(global_pool=pool) if pool else {} 88 | self.trunk.reset_classifier(0, **reset_kwargs) 89 | prev_chs = self.trunk.num_features 90 | 91 | head_layers = OrderedDict() 92 | 93 | # Add custom pooling to head 94 | if pool == 'abs_attn': 95 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) 96 | prev_chs = embed_dim 97 | elif pool == 'rot_attn': 98 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 99 | prev_chs = embed_dim 100 | 101 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 102 | if proj == 'linear': 103 | head_layers['drop'] = nn.Dropout(drop) 104 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) 105 | elif proj == 'mlp': 106 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias)) 107 | 108 | self.head = nn.Sequential(head_layers) 109 | 110 | def lock(self, unlocked_groups=0, 
freeze_bn_stats=False): 111 | """ lock modules 112 | Args: 113 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 114 | """ 115 | if not unlocked_groups: 116 | # lock full model 117 | for param in self.trunk.parameters(): 118 | param.requires_grad = False 119 | if freeze_bn_stats: 120 | freeze_batch_norm_2d(self.trunk) 121 | else: 122 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 123 | try: 124 | # FIXME import here until API stable and in an official release 125 | from timm.models.helpers import group_parameters, group_modules 126 | except ImportError: 127 | raise RuntimeError( 128 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') 129 | matcher = self.trunk.group_matcher() 130 | gparams = group_parameters(self.trunk, matcher) 131 | max_layer_id = max(gparams.keys()) 132 | max_layer_id = max_layer_id - unlocked_groups 133 | for group_idx in range(max_layer_id + 1): 134 | group = gparams[group_idx] 135 | for param in group: 136 | self.trunk.get_parameter(param).requires_grad = False 137 | if freeze_bn_stats: 138 | gmodules = group_modules(self.trunk, matcher, reverse=True) 139 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 140 | freeze_batch_norm_2d(self.trunk, gmodules) 141 | 142 | @torch.jit.ignore 143 | def set_grad_checkpointing(self, enable=True): 144 | try: 145 | self.trunk.set_grad_checkpointing(enable) 146 | except Exception as e: 147 | logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') 148 | 149 | def forward(self, x): 150 | x = self.trunk(x) 151 | x = self.head(x) 152 | return x 153 | -------------------------------------------------------------------------------- /opencls/open_clip/transform.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import random 3 | import warnings 4 | from dataclasses import dataclass, asdict 5 | from typing import Any, Dict, List, Optional, Sequence, Tuple, Union 6 | 7 | import torch 8 | import torchvision.transforms.functional as F 9 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ 10 | CenterCrop, ColorJitter, Grayscale 11 | 12 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 13 | from .utils import to_2tuple 14 | 15 | 16 | @dataclass 17 | class PreprocessCfg: 18 | size: Union[int, Tuple[int, int]] = 224 19 | mode: str = 'RGB' 20 | mean: Tuple[float, ...] = OPENAI_DATASET_MEAN 21 | std: Tuple[float, ...] = OPENAI_DATASET_STD 22 | interpolation: str = 'bicubic' 23 | resize_mode: str = 'shortest' 24 | fill_color: int = 0 25 | 26 | def __post_init__(self): 27 | assert self.mode in ('RGB',) 28 | 29 | @property 30 | def num_channels(self): 31 | return 3 32 | 33 | @property 34 | def input_size(self): 35 | return (self.num_channels,) + to_2tuple(self.size) 36 | 37 | _PREPROCESS_KEYS = set(asdict(PreprocessCfg()).keys()) 38 | 39 | 40 | def merge_preprocess_dict( 41 | base: Union[PreprocessCfg, Dict], 42 | overlay: Dict, 43 | ): 44 | """ Merge overlay key-value pairs on top of base preprocess cfg or dict. 45 | Input dicts are filtered based on PreprocessCfg fields. 
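# A minimal construction sketch for TimmModel from timm_model.py above; the timm model name is
# an example and `pip install timm` is required:
#
#   tower = TimmModel('convnext_base', embed_dim=512, image_size=224, pool='avg', proj='linear')
#   feats = tower(torch.randn(2, 3, 224, 224))   # -> shape (2, 512)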
46 | """ 47 | if isinstance(base, PreprocessCfg): 48 | base_clean = asdict(base) 49 | else: 50 | base_clean = {k: v for k, v in base.items() if k in _PREPROCESS_KEYS} 51 | if overlay: 52 | overlay_clean = {k: v for k, v in overlay.items() if k in _PREPROCESS_KEYS and v is not None} 53 | base_clean.update(overlay_clean) 54 | return base_clean 55 | 56 | 57 | def merge_preprocess_kwargs(base: PreprocessCfg, **kwargs): 58 | return merge_preprocess_dict(base, kwargs) 59 | 60 | 61 | @dataclass 62 | class AugmentationCfg: 63 | scale: Tuple[float, float] = (0.9, 1.0) 64 | ratio: Optional[Tuple[float, float]] = None 65 | color_jitter: Optional[Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]] = None 66 | re_prob: Optional[float] = None 67 | re_count: Optional[int] = None 68 | use_timm: bool = False 69 | 70 | # params for simclr_jitter_gray 71 | color_jitter_prob: float = None 72 | gray_scale_prob: float = None 73 | 74 | 75 | def _setup_size(size, error_msg): 76 | if isinstance(size, numbers.Number): 77 | return int(size), int(size) 78 | 79 | if isinstance(size, Sequence) and len(size) == 1: 80 | return size[0], size[0] 81 | 82 | if len(size) != 2: 83 | raise ValueError(error_msg) 84 | 85 | return size 86 | 87 | 88 | class ResizeKeepRatio: 89 | """ Resize and Keep Ratio 90 | 91 | Copy & paste from `timm` 92 | """ 93 | 94 | def __init__( 95 | self, 96 | size, 97 | longest=0., 98 | interpolation=InterpolationMode.BICUBIC, 99 | random_scale_prob=0., 100 | random_scale_range=(0.85, 1.05), 101 | random_aspect_prob=0., 102 | random_aspect_range=(0.9, 1.11) 103 | ): 104 | if isinstance(size, (list, tuple)): 105 | self.size = tuple(size) 106 | else: 107 | self.size = (size, size) 108 | self.interpolation = interpolation 109 | self.longest = float(longest) # [0, 1] where 0 == shortest edge, 1 == longest 110 | self.random_scale_prob = random_scale_prob 111 | self.random_scale_range = random_scale_range 112 | self.random_aspect_prob = random_aspect_prob 113 | self.random_aspect_range = random_aspect_range 114 | 115 | @staticmethod 116 | def get_params( 117 | img, 118 | target_size, 119 | longest, 120 | random_scale_prob=0., 121 | random_scale_range=(0.85, 1.05), 122 | random_aspect_prob=0., 123 | random_aspect_range=(0.9, 1.11) 124 | ): 125 | """Get parameters 126 | """ 127 | source_size = img.size[::-1] # h, w 128 | h, w = source_size 129 | target_h, target_w = target_size 130 | ratio_h = h / target_h 131 | ratio_w = w / target_w 132 | ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest) 133 | if random_scale_prob > 0 and random.random() < random_scale_prob: 134 | ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1]) 135 | ratio_factor = (ratio_factor, ratio_factor) 136 | else: 137 | ratio_factor = (1., 1.) 138 | if random_aspect_prob > 0 and random.random() < random_aspect_prob: 139 | aspect_factor = random.uniform(random_aspect_range[0], random_aspect_range[1]) 140 | ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor) 141 | size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)] 142 | return size 143 | 144 | def __call__(self, img): 145 | """ 146 | Args: 147 | img (PIL Image): Image to be cropped and resized. 
148 | 149 | Returns: 150 | PIL Image: Resized, padded to at least target size, possibly cropped to exactly target size 151 | """ 152 | size = self.get_params( 153 | img, self.size, self.longest, 154 | self.random_scale_prob, self.random_scale_range, 155 | self.random_aspect_prob, self.random_aspect_range 156 | ) 157 | img = F.resize(img, size, self.interpolation) 158 | return img 159 | 160 | def __repr__(self): 161 | format_string = self.__class__.__name__ + '(size={0}'.format(self.size) 162 | format_string += f', interpolation={self.interpolation})' 163 | format_string += f', longest={self.longest:.3f})' 164 | return format_string 165 | 166 | 167 | def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor: 168 | """Center crops and/or pads the given image. 169 | If the image is torch Tensor, it is expected 170 | to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. 171 | If image size is smaller than output size along any edge, image is padded with 0 and then center cropped. 172 | 173 | Args: 174 | img (PIL Image or Tensor): Image to be cropped. 175 | output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int, 176 | it is used for both directions. 177 | fill (int, Tuple[int]): Padding color 178 | 179 | Returns: 180 | PIL Image or Tensor: Cropped image. 181 | """ 182 | if isinstance(output_size, numbers.Number): 183 | output_size = (int(output_size), int(output_size)) 184 | elif isinstance(output_size, (tuple, list)) and len(output_size) == 1: 185 | output_size = (output_size[0], output_size[0]) 186 | 187 | _, image_height, image_width = F.get_dimensions(img) 188 | crop_height, crop_width = output_size 189 | 190 | if crop_width > image_width or crop_height > image_height: 191 | padding_ltrb = [ 192 | (crop_width - image_width) // 2 if crop_width > image_width else 0, 193 | (crop_height - image_height) // 2 if crop_height > image_height else 0, 194 | (crop_width - image_width + 1) // 2 if crop_width > image_width else 0, 195 | (crop_height - image_height + 1) // 2 if crop_height > image_height else 0, 196 | ] 197 | img = F.pad(img, padding_ltrb, fill=fill) 198 | _, image_height, image_width = F.get_dimensions(img) 199 | if crop_width == image_width and crop_height == image_height: 200 | return img 201 | 202 | crop_top = int(round((image_height - crop_height) / 2.0)) 203 | crop_left = int(round((image_width - crop_width) / 2.0)) 204 | return F.crop(img, crop_top, crop_left, crop_height, crop_width) 205 | 206 | 207 | class CenterCropOrPad(torch.nn.Module): 208 | """Crops the given image at the center. 209 | If the image is torch Tensor, it is expected 210 | to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. 211 | If image size is smaller than output size along any edge, image is padded with 0 and then center cropped. 212 | 213 | Args: 214 | size (sequence or int): Desired output size of the crop. If size is an 215 | int instead of sequence like (h, w), a square crop (size, size) is 216 | made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). 217 | """ 218 | 219 | def __init__(self, size, fill=0): 220 | super().__init__() 221 | self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") 222 | self.fill = fill 223 | 224 | def forward(self, img): 225 | """ 226 | Args: 227 | img (PIL Image or Tensor): Image to be cropped. 228 | 229 | Returns: 230 | PIL Image or Tensor: Cropped image. 
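# Worked example for center_crop_or_pad above (sizes illustrative): a 3x200x180 tensor with
# output_size=(224, 224) is padded by (22, 12, 22, 12) (left, top, right, bottom, using `fill`)
# to 3x224x224 and then returned as-is, since it already matches the crop size.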
231 | """ 232 | return center_crop_or_pad(img, self.size, fill=self.fill) 233 | 234 | def __repr__(self) -> str: 235 | return f"{self.__class__.__name__}(size={self.size})" 236 | 237 | 238 | def _convert_to_rgb(image): 239 | return image.convert('RGB') 240 | 241 | 242 | class color_jitter(object): 243 | """ 244 | Apply Color Jitter to the PIL image with a specified probability. 245 | """ 246 | def __init__(self, brightness=0., contrast=0., saturation=0., hue=0., p=0.8): 247 | assert 0. <= p <= 1. 248 | self.p = p 249 | self.transf = ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue) 250 | 251 | def __call__(self, img): 252 | if random.random() < self.p: 253 | return self.transf(img) 254 | else: 255 | return img 256 | 257 | 258 | class gray_scale(object): 259 | """ 260 | Apply Gray Scale to the PIL image with a specified probability. 261 | """ 262 | def __init__(self, p=0.2): 263 | assert 0. <= p <= 1. 264 | self.p = p 265 | self.transf = Grayscale(num_output_channels=3) 266 | 267 | def __call__(self, img): 268 | if random.random() < self.p: 269 | return self.transf(img) 270 | else: 271 | return img 272 | 273 | 274 | def image_transform( 275 | image_size: Union[int, Tuple[int, int]], 276 | is_train: bool, 277 | mean: Optional[Tuple[float, ...]] = None, 278 | std: Optional[Tuple[float, ...]] = None, 279 | resize_mode: Optional[str] = None, 280 | interpolation: Optional[str] = None, 281 | fill_color: int = 0, 282 | aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, 283 | ): 284 | mean = mean or OPENAI_DATASET_MEAN 285 | if not isinstance(mean, (list, tuple)): 286 | mean = (mean,) * 3 287 | 288 | std = std or OPENAI_DATASET_STD 289 | if not isinstance(std, (list, tuple)): 290 | std = (std,) * 3 291 | 292 | interpolation = interpolation or 'bicubic' 293 | assert interpolation in ['bicubic', 'bilinear', 'random'] 294 | # NOTE random is ignored for interpolation_mode, so defaults to BICUBIC for inference if set 295 | interpolation_mode = InterpolationMode.BILINEAR if interpolation == 'bilinear' else InterpolationMode.BICUBIC 296 | 297 | resize_mode = resize_mode or 'shortest' 298 | assert resize_mode in ('shortest', 'longest', 'squash') 299 | 300 | if isinstance(aug_cfg, dict): 301 | aug_cfg = AugmentationCfg(**aug_cfg) 302 | else: 303 | aug_cfg = aug_cfg or AugmentationCfg() 304 | 305 | normalize = Normalize(mean=mean, std=std) 306 | 307 | if is_train: 308 | aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None} 309 | use_timm = aug_cfg_dict.pop('use_timm', False) 310 | if use_timm: 311 | from timm.data import create_transform # timm can still be optional 312 | if isinstance(image_size, (tuple, list)): 313 | assert len(image_size) >= 2 314 | input_size = (3,) + image_size[-2:] 315 | else: 316 | input_size = (3, image_size, image_size) 317 | 318 | aug_cfg_dict.setdefault('color_jitter', None) # disable by default 319 | # drop extra non-timm items 320 | aug_cfg_dict.pop('color_jitter_prob', None) 321 | aug_cfg_dict.pop('gray_scale_prob', None) 322 | 323 | train_transform = create_transform( 324 | input_size=input_size, 325 | is_training=True, 326 | hflip=0., 327 | mean=mean, 328 | std=std, 329 | re_mode='pixel', 330 | interpolation=interpolation, 331 | **aug_cfg_dict, 332 | ) 333 | else: 334 | train_transform = [ 335 | RandomResizedCrop( 336 | image_size, 337 | scale=aug_cfg_dict.pop('scale'), 338 | interpolation=InterpolationMode.BICUBIC, 339 | ), 340 | _convert_to_rgb, 341 | ] 342 | if aug_cfg.color_jitter_prob: 343 | assert 
aug_cfg.color_jitter is not None and len(aug_cfg.color_jitter) == 4 344 | train_transform.extend([ 345 | color_jitter(*aug_cfg.color_jitter, p=aug_cfg.color_jitter_prob) 346 | ]) 347 | if aug_cfg.gray_scale_prob: 348 | train_transform.extend([ 349 | gray_scale(aug_cfg.gray_scale_prob) 350 | ]) 351 | train_transform.extend([ 352 | ToTensor(), 353 | normalize, 354 | ]) 355 | train_transform = Compose(train_transform) 356 | if aug_cfg_dict: 357 | warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).') 358 | return train_transform 359 | else: 360 | if resize_mode == 'longest': 361 | transforms = [ 362 | ResizeKeepRatio(image_size, interpolation=interpolation_mode, longest=1), 363 | CenterCropOrPad(image_size, fill=fill_color) 364 | ] 365 | elif resize_mode == 'squash': 366 | if isinstance(image_size, int): 367 | image_size = (image_size, image_size) 368 | transforms = [ 369 | Resize(image_size, interpolation=interpolation_mode), 370 | ] 371 | else: 372 | assert resize_mode == 'shortest' 373 | if not isinstance(image_size, (tuple, list)): 374 | image_size = (image_size, image_size) 375 | if image_size[0] == image_size[1]: 376 | # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg) 377 | transforms = [ 378 | Resize(image_size[0], interpolation=interpolation_mode) 379 | ] 380 | else: 381 | # resize shortest edge to matching target dim for non-square target 382 | transforms = [ResizeKeepRatio(image_size)] 383 | transforms += [CenterCrop(image_size)] 384 | 385 | transforms.extend([ 386 | _convert_to_rgb, 387 | ToTensor(), 388 | normalize, 389 | ]) 390 | return Compose(transforms) 391 | 392 | 393 | def image_transform_v2( 394 | cfg: PreprocessCfg, 395 | is_train: bool, 396 | aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, 397 | ): 398 | return image_transform( 399 | image_size=cfg.size, 400 | is_train=is_train, 401 | mean=cfg.mean, 402 | std=cfg.std, 403 | interpolation=cfg.interpolation, 404 | resize_mode=cfg.resize_mode, 405 | fill_color=cfg.fill_color, 406 | aug_cfg=aug_cfg, 407 | ) 408 | -------------------------------------------------------------------------------- /opencls/open_clip/utils.py: -------------------------------------------------------------------------------- 1 | from itertools import repeat 2 | import collections.abc 3 | 4 | import torch 5 | from torch import nn as nn 6 | from torchvision.ops.misc import FrozenBatchNorm2d 7 | 8 | 9 | def freeze_batch_norm_2d(module, module_match={}, name=''): 10 | """ 11 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is 12 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and 13 | returned. Otherwise, the module is walked recursively and submodules are converted in place. 14 | 15 | Args: 16 | module (torch.nn.Module): Any PyTorch module. 
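# A minimal eval-preprocessing sketch for image_transform_v2 from transform.py above; the values
# shown are the PreprocessCfg defaults, so this mirrors standard CLIP-style preprocessing:
#
#   preprocess = image_transform_v2(PreprocessCfg(size=224, interpolation='bicubic',
#                                                 resize_mode='shortest'), is_train=False)
#   tensor = preprocess(pil_image)   # -> torch.Tensor of shape (3, 224, 224)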
17 | module_match (dict): Dictionary of full module names to freeze (all if empty) 18 | name (str): Full module name (prefix) 19 | 20 | Returns: 21 | torch.nn.Module: Resulting module 22 | 23 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 24 | """ 25 | res = module 26 | is_match = True 27 | if module_match: 28 | is_match = name in module_match 29 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): 30 | res = FrozenBatchNorm2d(module.num_features) 31 | res.num_features = module.num_features 32 | res.affine = module.affine 33 | if module.affine: 34 | res.weight.data = module.weight.data.clone().detach() 35 | res.bias.data = module.bias.data.clone().detach() 36 | res.running_mean.data = module.running_mean.data 37 | res.running_var.data = module.running_var.data 38 | res.eps = module.eps 39 | else: 40 | for child_name, child in module.named_children(): 41 | full_child_name = '.'.join([name, child_name]) if name else child_name 42 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name) 43 | if new_child is not child: 44 | res.add_module(child_name, new_child) 45 | return res 46 | 47 | 48 | # From PyTorch internals 49 | def _ntuple(n): 50 | def parse(x): 51 | if isinstance(x, collections.abc.Iterable): 52 | return x 53 | return tuple(repeat(x, n)) 54 | return parse 55 | 56 | 57 | to_1tuple = _ntuple(1) 58 | to_2tuple = _ntuple(2) 59 | to_3tuple = _ntuple(3) 60 | to_4tuple = _ntuple(4) 61 | to_ntuple = lambda n, x: _ntuple(n)(x) 62 | 63 | # Replaces all linear layers with linear_replacement 64 | # TODO: add int8 support for other linear layers including attn and convnets 65 | def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True): 66 | for name, module in model.named_children(): 67 | if len(list(module.children())) > 0: 68 | replace_linear(module, linear_replacement, include_modules, copy_weights) 69 | 70 | if isinstance(module, torch.nn.Linear) and name in include_modules: 71 | old_module = model._modules[name] 72 | model._modules[name] = linear_replacement( 73 | module.in_features, 74 | module.out_features, 75 | module.bias is not None, 76 | ) 77 | if copy_weights: 78 | model._modules[name].weight.data.copy_(old_module.weight.data) 79 | if model._modules[name].bias is not None: 80 | model._modules[name].bias.data.copy_(old_module.bias) 81 | 82 | return model 83 | 84 | def convert_int8_model_to_inference_mode(model): 85 | for m in model.modules(): 86 | if hasattr(m, 'prepare_for_eval'): 87 | int8_original_dtype = m.weight.dtype 88 | m.prepare_for_eval() 89 | m.int8_original_dtype = int8_original_dtype -------------------------------------------------------------------------------- /opencls/open_clip/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '2.24.0' 2 | -------------------------------------------------------------------------------- /opencls/open_clip/zero_shot_classifier.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from itertools import islice 3 | from typing import Callable, List, Optional, Sequence, Union 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | 9 | def batched(iterable, n): 10 | """Batch data into lists of length *n*. The last batch may be shorter. 
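# A quick illustration of batched above:
#   list(batched(range(5), 2)) -> [[0, 1], [2, 3], [4]]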
11 | NOTE based on more-itertools impl, to be replaced by python 3.12 itertools.batched impl 12 | """ 13 | it = iter(iterable) 14 | while True: 15 | batch = list(islice(it, n)) 16 | if not batch: 17 | break 18 | yield batch 19 | 20 | 21 | def build_zero_shot_classifier( 22 | model, 23 | tokenizer, 24 | classnames: Sequence[str], 25 | templates: Sequence[Union[Callable, str]], 26 | num_classes_per_batch: Optional[int] = 10, 27 | device: Union[str, torch.device] = 'cpu', 28 | use_tqdm: bool = False, 29 | ): 30 | """ Build zero-shot classifier weights by iterating over class names in batches 31 | Args: 32 | model: CLIP model instance 33 | tokenizer: CLIP tokenizer instance 34 | classnames: A sequence of class (label) names 35 | templates: A sequence of callables or format() friendly strings to produce templates per class name 36 | num_classes_per_batch: The number of classes to batch together in each forward, all if None 37 | device: Device to use. 38 | use_tqdm: Enable TQDM progress bar. 39 | """ 40 | assert isinstance(templates, Sequence) and len(templates) > 0 41 | assert isinstance(classnames, Sequence) and len(classnames) > 0 42 | use_format = isinstance(templates[0], str) 43 | num_templates = len(templates) 44 | num_classes = len(classnames) 45 | if use_tqdm: 46 | import tqdm 47 | num_iter = 1 if num_classes_per_batch is None else ((num_classes - 1) // num_classes_per_batch + 1) 48 | iter_wrap = partial(tqdm.tqdm, total=num_iter, unit_scale=num_classes_per_batch) 49 | else: 50 | iter_wrap = iter 51 | 52 | def _process_batch(batch_classnames): 53 | num_batch_classes = len(batch_classnames) 54 | texts = [template.format(c) if use_format else template(c) for c in batch_classnames for template in templates] 55 | texts = tokenizer(texts).to(device) 56 | class_embeddings = model.encode_text(texts, normalize=True) 57 | class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1) 58 | class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True) 59 | class_embeddings = class_embeddings.T 60 | return class_embeddings 61 | 62 | with torch.no_grad(): 63 | if num_classes_per_batch: 64 | batched_embeds = [_process_batch(batch) for batch in iter_wrap(batched(classnames, num_classes_per_batch))] 65 | zeroshot_weights = torch.cat(batched_embeds, dim=1) 66 | else: 67 | zeroshot_weights = _process_batch(classnames) 68 | return zeroshot_weights 69 | 70 | 71 | def build_zero_shot_classifier_legacy( 72 | model, 73 | tokenizer, 74 | classnames: Sequence[str], 75 | templates: Sequence[Union[Callable, str]], 76 | device: Union[str, torch.device] = 'cpu', 77 | use_tqdm: bool = False, 78 | ): 79 | """ Build zero-shot classifier weights by iterating over class names 1 by 1 80 | Args: 81 | model: CLIP model instance 82 | tokenizer: CLIP tokenizer instance 83 | classnames: A sequence of class (label) names 84 | templates: A sequence of callables or format() friendly strings to produce templates per class name 85 | device: Device to use. 86 | use_tqdm: Enable TQDM progress bar. 
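# A minimal zero-shot evaluation sketch using build_zero_shot_classifier above; creating `model`
# and `tokenizer` via the open_clip factory functions is assumed and the class names are
# illustrative:
#
#   classifier = build_zero_shot_classifier(
#       model, tokenizer,
#       classnames=['cat', 'dog', 'bird'],
#       templates=['a photo of a {}.', 'a blurry photo of a {}.'],
#       num_classes_per_batch=10, device='cuda')
#   # classifier: [embed_dim, num_classes]
#   logits = model.encode_image(images, normalize=True) @ classifier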
87 | """ 88 | assert isinstance(templates, Sequence) and len(templates) > 0 89 | assert isinstance(classnames, Sequence) and len(classnames) > 0 90 | if use_tqdm: 91 | import tqdm 92 | iter_wrap = tqdm.tqdm 93 | else: 94 | iter_wrap = iter 95 | 96 | use_format = isinstance(templates[0], str) 97 | 98 | with torch.no_grad(): 99 | zeroshot_weights = [] 100 | for classname in iter_wrap(classnames): 101 | texts = [template.format(classname) if use_format else template(classname) for template in templates] 102 | texts = tokenizer(texts).to(device) # tokenize 103 | class_embeddings = model.encode_text(texts) 104 | class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) 105 | class_embedding /= class_embedding.norm() 106 | zeroshot_weights.append(class_embedding) 107 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device) 108 | 109 | return zeroshot_weights 110 | 111 | -------------------------------------------------------------------------------- /opencls/training/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/x-cls/superclass/90b761f41177400caa178d3845b39eba771e5bf1/opencls/training/.DS_Store -------------------------------------------------------------------------------- /opencls/training/.gitignore: -------------------------------------------------------------------------------- 1 | logs/ 2 | -------------------------------------------------------------------------------- /opencls/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/x-cls/superclass/90b761f41177400caa178d3845b39eba771e5bf1/opencls/training/__init__.py -------------------------------------------------------------------------------- /opencls/training/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.distributed as dist 5 | 6 | try: 7 | import horovod.torch as hvd 8 | except ImportError: 9 | hvd = None 10 | 11 | 12 | def is_global_master(args): 13 | return args.rank == 0 14 | 15 | 16 | def is_local_master(args): 17 | return args.local_rank == 0 18 | 19 | 20 | def is_master(args, local=False): 21 | return is_local_master(args) if local else is_global_master(args) 22 | 23 | 24 | def is_using_horovod(): 25 | # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set 26 | # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... 
27 | ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] 28 | pmi_vars = ["PMI_RANK", "PMI_SIZE"] 29 | if all([var in os.environ for var in ompi_vars]) or all([var in os.environ for var in pmi_vars]): 30 | return True 31 | else: 32 | return False 33 | 34 | 35 | def is_using_distributed(): 36 | if 'WORLD_SIZE' in os.environ: 37 | return int(os.environ['WORLD_SIZE']) > 1 38 | if 'SLURM_NTASKS' in os.environ: 39 | return int(os.environ['SLURM_NTASKS']) > 1 40 | return False 41 | 42 | 43 | def world_info_from_env(): 44 | local_rank = 0 45 | for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'): 46 | if v in os.environ: 47 | local_rank = int(os.environ[v]) 48 | break 49 | global_rank = 0 50 | for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'): 51 | if v in os.environ: 52 | global_rank = int(os.environ[v]) 53 | break 54 | world_size = 1 55 | for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'): 56 | if v in os.environ: 57 | world_size = int(os.environ[v]) 58 | break 59 | 60 | return local_rank, global_rank, world_size 61 | 62 | 63 | def init_distributed_device(args): 64 | # Distributed training = training on more than one GPU. 65 | # Works in both single and multi-node scenarios. 66 | args.distributed = False 67 | args.world_size = 1 68 | args.rank = 0 # global rank 69 | args.local_rank = 0 70 | if args.horovod: 71 | assert hvd is not None, "Horovod is not installed" 72 | hvd.init() 73 | args.local_rank = int(hvd.local_rank()) 74 | args.rank = hvd.rank() 75 | args.world_size = hvd.size() 76 | args.distributed = True 77 | os.environ['LOCAL_RANK'] = str(args.local_rank) 78 | os.environ['RANK'] = str(args.rank) 79 | os.environ['WORLD_SIZE'] = str(args.world_size) 80 | elif is_using_distributed(): 81 | if 'SLURM_PROCID' in os.environ: 82 | # DDP via SLURM 83 | args.local_rank, args.rank, args.world_size = world_info_from_env() 84 | # SLURM var -> torch.distributed vars in case needed 85 | os.environ['LOCAL_RANK'] = str(args.local_rank) 86 | os.environ['RANK'] = str(args.rank) 87 | os.environ['WORLD_SIZE'] = str(args.world_size) 88 | torch.distributed.init_process_group( 89 | backend=args.dist_backend, 90 | init_method=args.dist_url, 91 | world_size=args.world_size, 92 | rank=args.rank, 93 | ) 94 | else: 95 | # DDP via torchrun, torch.distributed.launch 96 | args.local_rank, _, _ = world_info_from_env() 97 | torch.distributed.init_process_group( 98 | backend=args.dist_backend, 99 | init_method=args.dist_url) 100 | args.world_size = torch.distributed.get_world_size() 101 | args.rank = torch.distributed.get_rank() 102 | args.distributed = True 103 | 104 | if torch.cuda.is_available(): 105 | if args.distributed and not args.no_set_device_rank: 106 | device = 'cuda:%d' % args.local_rank 107 | else: 108 | device = 'cuda:0' 109 | torch.cuda.set_device(device) 110 | else: 111 | device = 'cpu' 112 | args.device = device 113 | device = torch.device(device) 114 | return device 115 | 116 | 117 | def broadcast_object(args, obj, src=0): 118 | # broadcast a pickle-able python object from rank-0 to all ranks 119 | if args.horovod: 120 | return hvd.broadcast_object(obj, root_rank=src) 121 | else: 122 | if args.rank == src: 123 | objects = [obj] 124 | else: 125 | objects = [None] 126 | dist.broadcast_object_list(objects, src=src) 127 | return objects[0] 128 | 129 | 130 | def all_gather_object(args, obj, dst=0): 131 | # gather a pickle-able python object across all ranks 132 | if args.horovod: 133 | return hvd.allgather_object(obj) 
134 | else: 135 | objects = [None for _ in range(args.world_size)] 136 | dist.all_gather_object(objects, obj) 137 | return objects 138 | -------------------------------------------------------------------------------- /opencls/training/file_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import multiprocessing 4 | import subprocess 5 | import time 6 | import fsspec 7 | import torch 8 | from tqdm import tqdm 9 | 10 | def remote_sync_s3(local_dir, remote_dir): 11 | # skip epoch_latest which can change during sync. 12 | result = subprocess.run(["aws", "s3", "sync", local_dir, remote_dir, '--exclude', '*epoch_latest.pt'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) 13 | if result.returncode != 0: 14 | logging.error(f"Error: Failed to sync with S3 bucket {result.stderr.decode('utf-8')}") 15 | return False 16 | 17 | logging.info(f"Successfully synced with S3 bucket") 18 | return True 19 | 20 | def remote_sync_fsspec(local_dir, remote_dir): 21 | # FIXME currently this is slow and not recommended. Look into speeding up. 22 | a = fsspec.get_mapper(local_dir) 23 | b = fsspec.get_mapper(remote_dir) 24 | 25 | for k in a: 26 | # skip epoch_latest which can change during sync. 27 | if 'epoch_latest.pt' in k: 28 | continue 29 | 30 | logging.info(f'Attempting to sync {k}') 31 | if k in b and len(a[k]) == len(b[k]): 32 | logging.debug(f'Skipping remote sync for {k}.') 33 | continue 34 | 35 | try: 36 | logging.info(f'Successful sync for {k}.') 37 | b[k] = a[k] 38 | except Exception as e: 39 | logging.info(f'Error during remote sync for {k}: {e}') 40 | return False 41 | 42 | return True 43 | 44 | def remote_sync(local_dir, remote_dir, protocol): 45 | logging.info('Starting remote sync.') 46 | if protocol == 's3': 47 | return remote_sync_s3(local_dir, remote_dir) 48 | elif protocol == 'fsspec': 49 | return remote_sync_fsspec(local_dir, remote_dir) 50 | else: 51 | logging.error('Remote protocol not known') 52 | return False 53 | 54 | def keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol): 55 | while True: 56 | time.sleep(sync_every) 57 | remote_sync(local_dir, remote_dir, protocol) 58 | 59 | def start_sync_process(sync_every, local_dir, remote_dir, protocol): 60 | p = multiprocessing.Process(target=keep_running_remote_sync, args=(sync_every, local_dir, remote_dir, protocol)) 61 | return p 62 | 63 | # Note: we are not currently using this save function. 
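# A minimal sketch of the remote checkpoint sync helpers defined above; the bucket path and
# interval are placeholders:
#
#   p = start_sync_process(sync_every=300, local_dir='./logs/exp', remote_dir='s3://bucket/exp',
#                          protocol='s3')
#   p.start()   # re-syncs every 300 s in the background, skipping epoch_latest.pt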
64 | def pt_save(pt_obj, file_path): 65 | of = fsspec.open(file_path, "wb") 66 | with of as f: 67 | torch.save(pt_obj, f) 68 | 69 | def pt_load(file_path, map_location=None): 70 | if file_path.startswith('s3'): 71 | logging.info('Loading remote checkpoint, which may take a bit.') 72 | of = fsspec.open(file_path, "rb") 73 | with of as f: 74 | out = torch.load(f, map_location=map_location) 75 | return out 76 | 77 | def check_exists(file_path): 78 | try: 79 | with fsspec.open(file_path): 80 | pass 81 | except FileNotFoundError: 82 | return False 83 | return True 84 | -------------------------------------------------------------------------------- /opencls/training/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def setup_logging(log_file, level, include_host=False): 5 | if include_host: 6 | import socket 7 | hostname = socket.gethostname() 8 | formatter = logging.Formatter( 9 | f'%(asctime)s | {hostname} | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') 10 | else: 11 | formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') 12 | 13 | logging.root.setLevel(level) 14 | loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict] 15 | for logger in loggers: 16 | logger.setLevel(level) 17 | 18 | stream_handler = logging.StreamHandler() 19 | stream_handler.setFormatter(formatter) 20 | logging.root.addHandler(stream_handler) 21 | 22 | if log_file: 23 | file_handler = logging.FileHandler(filename=log_file) 24 | file_handler.setFormatter(formatter) 25 | logging.root.addHandler(file_handler) 26 | 27 | -------------------------------------------------------------------------------- /opencls/training/precision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from contextlib import suppress 3 | 4 | 5 | def get_autocast(precision): 6 | if precision == 'amp': 7 | return torch.cuda.amp.autocast 8 | elif precision == 'amp_bfloat16' or precision == 'amp_bf16': 9 | # amp_bfloat16 is more stable than amp float16 for clip training 10 | return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16) 11 | else: 12 | return suppress 13 | -------------------------------------------------------------------------------- /opencls/training/profiler0.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import open_clip 5 | import pandas as pd 6 | from torch.utils.flop_counter import FlopCounterMode 7 | try: 8 | import fvcore 9 | except ImportError: 10 | fvcore = None 11 | 12 | parser = argparse.ArgumentParser(description='OpenCLIP Profiler') 13 | 14 | # benchmark specific args 15 | parser.add_argument('--model', metavar='NAME', default='', 16 | help='model(s) to profile') 17 | parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', 18 | help='Output csv file for results') 19 | parser.add_argument('--profiler', default='torch', type=str, choices=['torch', 'fvcore']) 20 | parser.add_argument('--batch-size', default=1, type=int, help='Batch size for profiling') 21 | 22 | 23 | def profile_fvcore( 24 | model, 25 | image_input_size=(3, 224, 224), 26 | text_input_size=(77,), 27 | batch_size=1, 28 | detailed=False, 29 | force_cpu=False 30 | ): 31 | if force_cpu: 32 | model = model.to('cpu') 33 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 34 | example_image_input = torch.ones((batch_size,) +
image_input_size, device=device, dtype=dtype) 35 | example_text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) 36 | fca = fvcore.nn.FlopCountAnalysis(model, (example_image_input, example_text_input)) 37 | aca = fvcore.nn.ActivationCountAnalysis(model, (example_image_input, example_text_input)) 38 | if detailed: 39 | fcs = fvcore.nn.flop_count_str(fca) 40 | print(fcs) 41 | return fca.total() / batch_size, aca.total() / batch_size 42 | 43 | 44 | def profile_fvcore_text( 45 | model, 46 | text_input_size=(77,), 47 | batch_size=1, 48 | detailed=False, 49 | force_cpu=False 50 | ): 51 | if force_cpu: 52 | model = model.to('cpu') 53 | device = next(model.parameters()).device 54 | example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) 55 | fca = fvcore.nn.FlopCountAnalysis(model, example_input) 56 | aca = fvcore.nn.ActivationCountAnalysis(model, example_input) 57 | if detailed: 58 | fcs = fvcore.nn.flop_count_str(fca) 59 | print(fcs) 60 | return fca.total() / batch_size, aca.total() / batch_size 61 | 62 | 63 | def profile_fvcore_image( 64 | model, 65 | image_input_size=(3, 224, 224), 66 | batch_size=1, 67 | detailed=False, 68 | force_cpu=False 69 | ): 70 | if force_cpu: 71 | model = model.to('cpu') 72 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 73 | example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype) 74 | fca = fvcore.nn.FlopCountAnalysis(model, example_input) 75 | aca = fvcore.nn.ActivationCountAnalysis(model, example_input) 76 | if detailed: 77 | fcs = fvcore.nn.flop_count_str(fca) 78 | print(fcs) 79 | return fca.total() / batch_size, aca.total() / batch_size 80 | 81 | 82 | def profile_torch_image(model, image_input_size, batch_size=1, force_cpu=False): 83 | """Profile the image encoder using torch.utils.flop_counter""" 84 | if force_cpu: 85 | model = model.to('cpu') 86 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 87 | example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype) 88 | 89 | flop_counter = FlopCounterMode() 90 | with flop_counter: 91 | model(example_input) 92 | total_flops = sum(flop_counter.get_flop_counts()['Global'].values()) 93 | return total_flops / batch_size 94 | 95 | 96 | def profile_torch_text(model, text_input_size, batch_size=1, force_cpu=False): 97 | """Profile the text encoder using torch.utils.flop_counter""" 98 | if force_cpu: 99 | model = model.to('cpu') 100 | device = next(model.parameters()).device 101 | example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) 102 | 103 | flop_counter = FlopCounterMode() 104 | with flop_counter: 105 | model(example_input) 106 | total_flops = sum(flop_counter.get_flop_counts()['Global'].values()) 107 | return total_flops / batch_size 108 | 109 | 110 | def profile_torch(model, text_input_size, image_input_size, batch_size=1, force_cpu=False): 111 | """Profile the full model using torch.utils.flop_counter""" 112 | if force_cpu: 113 | model = model.to('cpu') 114 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 115 | image_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype) 116 | text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) 117 | 118 | flop_counter = FlopCounterMode() 119 | with flop_counter: 120 | model(image_input, text_input) 121 | total_flops = 
sum(flop_counter.get_flop_counts()['Global'].values()) 122 | return total_flops / batch_size 123 | 124 | 125 | def count_params(model): 126 | return sum(m.numel() for m in model.parameters()) 127 | 128 | def profile_model(model_name, batch_size=1, profiler='torch'): 129 | assert profiler in ['torch', 'fvcore'], 'Only torch and fvcore profilers are supported' 130 | if profiler == 'fvcore': 131 | assert fvcore is not None, 'Please install fvcore.' 132 | model = open_clip.create_model(model_name, force_custom_text=True, pretrained_hf=False) 133 | model.eval() 134 | if torch.cuda.is_available(): 135 | model = model.cuda() 136 | 137 | if isinstance(model.visual.image_size, (tuple, list)): 138 | image_input_size = (3,) + tuple(model.visual.image_size[-2:]) 139 | else: 140 | image_input_size = (3, model.visual.image_size, model.visual.image_size) 141 | 142 | text_input_size = (77,) 143 | if hasattr(model, 'context_length') and model.context_length: 144 | text_input_size = (model.context_length,) 145 | 146 | results = {} 147 | results['model'] = model_name 148 | results['image_size'] = image_input_size[1] 149 | 150 | model_cfg = open_clip.get_model_config(model_name) 151 | if model_cfg: 152 | vision_cfg = open_clip.CLIPVisionCfg(**model_cfg['vision_cfg']) 153 | text_cfg = open_clip.CLIPTextCfg(**model_cfg['text_cfg']) 154 | results['image_width'] = int(vision_cfg.width) 155 | results['text_width'] = int(text_cfg.width) 156 | results['embed_dim'] = int(model_cfg['embed_dim']) 157 | else: 158 | results['image_width'] = 0 159 | results['text_width'] = 0 160 | results['embed_dim'] = 0 161 | 162 | retries = 2 163 | while retries: 164 | retries -= 1 165 | try: 166 | results['mparams'] = round(count_params(model) / 1e6, 2) 167 | results['image_mparams'] = round(count_params(model.visual) / 1e6, 2) 168 | results['text_mparams'] = round(count_params(model.text) / 1e6, 2) 169 | 170 | if profiler == 'fvcore': 171 | macs, acts = profile_fvcore( 172 | model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size) 173 | 174 | image_macs, image_acts = profile_fvcore_image( 175 | model.visual, image_input_size=image_input_size, force_cpu=not retries, batch_size=batch_size) 176 | 177 | text_macs, text_acts = profile_fvcore_text( 178 | model.text, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size) 179 | 180 | results['gmacs'] = round(macs / 1e9, 2) 181 | results['macts'] = round(acts / 1e6, 2) 182 | 183 | results['image_gmacs'] = round(image_macs / 1e9, 2) 184 | results['image_macts'] = round(image_acts / 1e6, 2) 185 | 186 | results['text_gmacs'] = round(text_macs / 1e9, 2) 187 | results['text_macts'] = round(text_acts / 1e6, 2) 188 | elif profiler == 'torch': 189 | image_flops = profile_torch_image( 190 | model.visual, image_input_size=image_input_size, force_cpu=not retries, batch_size=batch_size) 191 | text_flops = profile_torch_text( 192 | model.text, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size) 193 | total_flops = profile_torch( 194 | model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size) 195 | 196 | results['gflops'] = round(total_flops / 1e9, 2) 197 | results['image_gflops'] = round(image_flops / 1e9, 2) 198 | results['text_gflops'] = round(text_flops / 1e9, 2) 199 | 200 | except RuntimeError as e: 201 | pass 202 | return results 203 | 204 | 205 | def main(): 206 | args = parser.parse_args() 207 | 208 | # FIXME accept a text 
file name to allow lists of models in txt/csv 209 | if args.model == 'all': 210 | parsed_model = open_clip.list_models() 211 | else: 212 | parsed_model = args.model.split(',') 213 | 214 | results = [] 215 | models_with_errors = [] 216 | for m in parsed_model: 217 | print('='*100) 218 | print(f'Profiling {m}') 219 | try: 220 | row = profile_model(m, batch_size=args.batch_size, profiler=args.profiler) 221 | results.append(row) 222 | except Exception as e: 223 | print(f'Error profiling {m}: {e}') 224 | import traceback 225 | traceback.print_exc() 226 | models_with_errors.append(m) 227 | 228 | df = pd.DataFrame(results, columns=results[0].keys()) 229 | 230 | if 'gmacs' in df.columns: 231 | df = df.sort_values(by=['gmacs', 'mparams', 'model']) 232 | else: 233 | df = df.sort_values(by=['gflops', 'mparams', 'model']) 234 | 235 | print('='*100) 236 | print('Done.') 237 | print(df) 238 | if args.results_file: 239 | df.to_csv(args.results_file, index=False) 240 | 241 | if models_with_errors: 242 | print('Models with errors:', models_with_errors) 243 | 244 | 245 | if __name__ == '__main__': 246 | main() 247 | -------------------------------------------------------------------------------- /opencls/training/scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def assign_learning_rate(optimizer, new_lr): 5 | for param_group in optimizer.param_groups: 6 | param_group["lr"] = new_lr 7 | 8 | 9 | def _warmup_lr(base_lr, warmup_length, step): 10 | return base_lr * (step + 1) / warmup_length 11 | 12 | 13 | def const_lr(optimizer, base_lr, warmup_length, steps): 14 | def _lr_adjuster(step): 15 | if step < warmup_length: 16 | lr = _warmup_lr(base_lr, warmup_length, step) 17 | else: 18 | lr = base_lr 19 | assign_learning_rate(optimizer, lr) 20 | return lr 21 | return _lr_adjuster 22 | 23 | 24 | def const_lr_cooldown(optimizer, base_lr, warmup_length, steps, cooldown_steps, cooldown_power=1.0, cooldown_end_lr=0.): 25 | def _lr_adjuster(step): 26 | start_cooldown_step = steps - cooldown_steps 27 | if step < warmup_length: 28 | lr = _warmup_lr(base_lr, warmup_length, step) 29 | else: 30 | if step < start_cooldown_step: 31 | lr = base_lr 32 | else: 33 | e = step - start_cooldown_step 34 | es = steps - start_cooldown_step 35 | # linear decay if power == 1; polynomial decay otherwise; 36 | decay = (1 - (e/es)) ** cooldown_power 37 | lr = decay * (base_lr - cooldown_end_lr) + cooldown_end_lr 38 | assign_learning_rate(optimizer, lr) 39 | return lr 40 | return _lr_adjuster 41 | 42 | 43 | def cosine_lr(optimizer, base_lr, warmup_length, steps): 44 | def _lr_adjuster(step): 45 | if step < warmup_length: 46 | lr = _warmup_lr(base_lr, warmup_length, step) 47 | else: 48 | e = step - warmup_length 49 | es = steps - warmup_length 50 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr 51 | assign_learning_rate(optimizer, lr) 52 | return lr 53 | return _lr_adjuster 54 | -------------------------------------------------------------------------------- /opencls/training/zero_shot.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | from tqdm import tqdm 5 | 6 | from open_clip import get_input_dtype, get_tokenizer, build_zero_shot_classifier, \ 7 | IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES 8 | from .precision import get_autocast 9 | 10 | 11 | def accuracy(output, target, topk=(1,)): 12 | pred = output.topk(max(topk), 1, True, True)[1].t() 13 | correct = pred.eq(target.view(1, 
-1).expand_as(pred)) 14 | return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] 15 | 16 | 17 | def run(model, classifier, dataloader, args): 18 | autocast = get_autocast(args.precision) 19 | input_dtype = get_input_dtype(args.precision) 20 | 21 | with torch.no_grad(): 22 | top1, top5, n = 0., 0., 0. 23 | for images, target in tqdm(dataloader, unit_scale=args.batch_size): 24 | images = images.to(device=args.device, dtype=input_dtype) 25 | target = target.to(args.device) 26 | 27 | with autocast(): 28 | # predict 29 | output = model(image=images) 30 | image_features = output['image_features'] if isinstance(output, dict) else output[0] 31 | logits = 100. * image_features @ classifier 32 | 33 | # measure accuracy 34 | acc1, acc5 = accuracy(logits, target, topk=(1, 5)) 35 | top1 += acc1 36 | top5 += acc5 37 | n += images.size(0) 38 | 39 | top1 = (top1 / n) 40 | top5 = (top5 / n) 41 | return top1, top5 42 | 43 | 44 | def zero_shot_eval(model, data, epoch, args, tokenizer=None, should_zero_eval=False): 45 | if 'imagenet-val' not in data and 'imagenet-v2' not in data: 46 | return {} 47 | # if args.zeroshot_frequency == 0: 48 | # return {} 49 | # if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: 50 | # return {} 51 | if not should_zero_eval: 52 | return {} 53 | if args.distributed and not args.horovod: 54 | model = model.module 55 | 56 | logging.info('Starting zero-shot imagenet.') 57 | if tokenizer is None: 58 | tokenizer = get_tokenizer(args.model) 59 | 60 | logging.info('Building zero-shot classifier') 61 | autocast = get_autocast(args.precision) 62 | with autocast(): 63 | classifier = build_zero_shot_classifier( 64 | model, 65 | tokenizer=tokenizer, 66 | classnames=IMAGENET_CLASSNAMES, 67 | templates=OPENAI_IMAGENET_TEMPLATES, 68 | num_classes_per_batch=10, 69 | device=args.device, 70 | use_tqdm=True, 71 | ) 72 | 73 | logging.info('Using classifier') 74 | results = {} 75 | if 'imagenet-val' in data: 76 | top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args) 77 | results['imagenet-zeroshot-val-top1'] = top1 78 | results['imagenet-zeroshot-val-top5'] = top5 79 | if 'imagenet-v2' in data: 80 | top1, top5 = run(model, classifier, data['imagenet-v2'].dataloader, args) 81 | results['imagenetv2-zeroshot-val-top1'] = top1 82 | results['imagenetv2-zeroshot-val-top5'] = top5 83 | 84 | logging.info('Finished zero-shot imagenet.') 85 | 86 | return results 87 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.9.0 2 | torchvision 3 | regex 4 | ftfy 5 | tqdm 6 | huggingface_hub 7 | safetensors 8 | timm -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | CONFIG=$1 2 | DIR=$2 3 | NV_GPUS=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) 4 | GPUS=${GPUS:-${NV_GPUS}} 5 | NNODES=${NNODES:-1} 6 | NODE_RANK=${NODE_RANK:-0} 7 | PORT=${PORT:-55565} 8 | 9 | DIR=${DIR:-"opencls"} 10 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 11 | HOST_NODE_ADDR=${MASTER_ADDR}:${PORT} 12 | TIMESTAMP=$(date +%Y_%m_%d-%H_%M_%S) 13 | export TOKENIZERS_PARALLELISM=true 14 | 15 | DATA_PATH="{your_path_to_datacomp-1b-webdataset}/{000000..140146}.tar" 16 | VAL_DATA_PATH="{your_path_to_imagenet1k}/ILSVRC/Data/CLS-LOC/val" 17 | 18 | echo "$DIR" 19 | cd $DIR 20 | 21 | torchrun
--nproc_per_node=$GPUS \ 22 | --rdzv_endpoint=$HOST_NODE_ADDR \ 23 | --nnodes=$NNODES --node_rank=$NODE_RANK \ 24 | -m training.main \ 25 | --config="${CONFIG}" \ 26 | --train-data $DATA_PATH \ 27 | --imagenet-val $VAL_DATA_PATH -------------------------------------------------------------------------------- /train_combo.sh: -------------------------------------------------------------------------------- 1 | CLS_CONFIG=$1 2 | LIT_CONFIG=$2 3 | DIR=$3 4 | NV_GPUS=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) 5 | GPUS=${GPUS:-${NV_GPUS}} 6 | NNODES=${NNODES:-1} 7 | NODE_RANK=${NODE_RANK:-0} 8 | PORT=${PORT:-55565} 9 | DIR=${DIR:-"opencls"} 10 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 11 | HOST_NODE_ADDR=${MASTER_ADDR}:${PORT} 12 | TIMESTAMP=$(date +%Y_%m_%d-%H_%M_%S) 13 | export TOKENIZERS_PARALLELISM=true 14 | 15 | DATA_PATH="{your_path_to_datacomp-1b-webdataset}/{000000..140146}.tar" 16 | VAL_DATA_PATH="{your_path_to_imagenet1k}/ILSVRC/Data/CLS-LOC/val" 17 | 18 | echo "$DIR" 19 | cd $DIR 20 | 21 | torchrun --nproc_per_node=$GPUS \ 22 | --rdzv_endpoint=$HOST_NODE_ADDR \ 23 | --nnodes=$NNODES --node_rank=$NODE_RANK \ 24 | -m training.main \ 25 | --config="${CLS_CONFIG}" \ 26 | --train-data $DATA_PATH \ 27 | --imagenet-val $VAL_DATA_PATH 28 | 29 | 30 | torchrun --nproc_per_node=$GPUS \ 31 | --rdzv_endpoint=$HOST_NODE_ADDR \ 32 | --nnodes=$NNODES --node_rank=$NODE_RANK \ 33 | -m training.main \ 34 | --config="${LIT_CONFIG}" \ 35 | --train-data $DATA_PATH \ 36 | --imagenet-val $VAL_DATA_PATH --------------------------------------------------------------------------------
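Example invocation of the launch scripts above (an illustrative sketch, not taken from the repository: the config file names are placeholders, the DATA_PATH / VAL_DATA_PATH values inside the scripts must be edited first, and GPUS, NNODES, NODE_RANK, MASTER_ADDR, and PORT fall back to the defaults shown in the scripts but can be overridden via the environment; because each script cd's into the code directory before launching, relative config paths are resolved from there):

    # single-stage run: first argument is the training config, second is the code directory
    GPUS=8 bash train.sh configs/cls_schedule/<cls_config>.yaml opencls

    # two-stage run: CLS config, then LiT config, then the code directory
    GPUS=8 NNODES=1 bash train_combo.sh configs/cls_schedule/<cls_config>.yaml configs/cls_schedule/<lit_config>.yaml opencls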