├── LICENSE
├── README.md
├── assets
│   └── teaser.png
├── opencls
│   ├── configs
│   │   ├── .DS_Store
│   │   └── cls_schedule
│   │       ├── .DS_Store
│   │       ├── cls_vit_b16_s1.28B_bs16k.yaml
│   │       ├── cls_vit_b16_s512m_bs16k.yaml
│   │       ├── cls_vit_l14_224_s12.8B_bs90k.yaml
│   │       ├── cls_vit_l14_s1.28B_bs16k.yaml
│   │       ├── cls_vit_l16_s512m_bs16k.yaml
│   │       ├── lit_vit_b16_s1.28B_bs16k.yaml
│   │       ├── lit_vit_b16_s512m_bs16k.yaml
│   │       ├── lit_vit_l14_224_s12.8B_bs90k.yaml
│   │       ├── lit_vit_l14_s1.28B_bs16k.yaml
│   │       └── lit_vit_l16_s512m_bs16k.yaml
│   ├── open_clip
│   │   ├── .DS_Store
│   │   ├── __init__.py
│   │   ├── big_vision.py
│   │   ├── bpe_simple_vocab_16e6.txt.gz
│   │   ├── cls_model.py
│   │   ├── coca_model.py
│   │   ├── constants.py
│   │   ├── factory.py
│   │   ├── hf_configs.py
│   │   ├── hf_model.py
│   │   ├── loss.py
│   │   ├── model.py
│   │   ├── model_configs
│   │   │   ├── .DS_Store
│   │   │   ├── CLS-ViT-B-16.json
│   │   │   ├── CLS-ViT-L-14.json
│   │   │   ├── CLS-ViT-L-16.json
│   │   │   ├── EVA01-g-14-plus.json
│   │   │   ├── EVA01-g-14.json
│   │   │   ├── EVA02-B-16.json
│   │   │   ├── EVA02-E-14-plus.json
│   │   │   ├── EVA02-E-14.json
│   │   │   ├── EVA02-L-14-336.json
│   │   │   ├── EVA02-L-14.json
│   │   │   ├── RN101-quickgelu.json
│   │   │   ├── RN101.json
│   │   │   ├── RN50-quickgelu.json
│   │   │   ├── RN50.json
│   │   │   ├── RN50x16.json
│   │   │   ├── RN50x4.json
│   │   │   ├── RN50x64.json
│   │   │   ├── ViT-B-16-SigLIP-256.json
│   │   │   ├── ViT-B-16-SigLIP-384.json
│   │   │   ├── ViT-B-16-SigLIP-512.json
│   │   │   ├── ViT-B-16-SigLIP-i18n-256.json
│   │   │   ├── ViT-B-16-SigLIP.json
│   │   │   ├── ViT-B-16-avg.json
│   │   │   ├── ViT-B-16-plus-240.json
│   │   │   ├── ViT-B-16-plus.json
│   │   │   ├── ViT-B-16-quickgelu.json
│   │   │   ├── ViT-B-16.json
│   │   │   ├── ViT-B-32-256.json
│   │   │   ├── ViT-B-32-plus-256.json
│   │   │   ├── ViT-B-32-quickgelu.json
│   │   │   ├── ViT-B-32.json
│   │   │   ├── ViT-H-14-378-quickgelu.json
│   │   │   ├── ViT-H-14-CLIPA-336.json
│   │   │   ├── ViT-H-14-CLIPA.json
│   │   │   ├── ViT-H-14-quickgelu.json
│   │   │   ├── ViT-H-14.json
│   │   │   ├── ViT-H-16.json
│   │   │   ├── ViT-L-14-280.json
│   │   │   ├── ViT-L-14-336.json
│   │   │   ├── ViT-L-14-CLIPA-336.json
│   │   │   ├── ViT-L-14-CLIPA.json
│   │   │   ├── ViT-L-14-avg.json
│   │   │   ├── ViT-L-14-quickgelu.json
│   │   │   ├── ViT-L-14.json
│   │   │   ├── ViT-L-16-320.json
│   │   │   ├── ViT-L-16-SigLIP-256.json
│   │   │   ├── ViT-L-16-SigLIP-384.json
│   │   │   ├── ViT-L-16-avg.json
│   │   │   ├── ViT-L-16.json
│   │   │   ├── ViT-M-16-alt.json
│   │   │   ├── ViT-M-16.json
│   │   │   ├── ViT-M-32-alt.json
│   │   │   ├── ViT-M-32.json
│   │   │   ├── ViT-S-16-alt.json
│   │   │   ├── ViT-S-16.json
│   │   │   ├── ViT-S-32-alt.json
│   │   │   ├── ViT-S-32.json
│   │   │   ├── ViT-SO400M-14-SigLIP-384.json
│   │   │   ├── ViT-SO400M-14-SigLIP.json
│   │   │   ├── ViT-bigG-14-CLIPA-336.json
│   │   │   ├── ViT-bigG-14-CLIPA.json
│   │   │   ├── ViT-bigG-14.json
│   │   │   ├── ViT-e-14.json
│   │   │   ├── ViT-g-14.json
│   │   │   ├── coca_ViT-B-32.json
│   │   │   ├── coca_ViT-L-14.json
│   │   │   ├── coca_base.json
│   │   │   ├── coca_roberta-ViT-B-32.json
│   │   │   ├── convnext_base.json
│   │   │   ├── convnext_base_w.json
│   │   │   ├── convnext_base_w_320.json
│   │   │   ├── convnext_large.json
│   │   │   ├── convnext_large_d.json
│   │   │   ├── convnext_large_d_320.json
│   │   │   ├── convnext_small.json
│   │   │   ├── convnext_tiny.json
│   │   │   ├── convnext_xlarge.json
│   │   │   ├── convnext_xxlarge.json
│   │   │   ├── convnext_xxlarge_320.json
│   │   │   ├── mt5-base-ViT-B-32.json
│   │   │   ├── mt5-xl-ViT-H-14.json
│   │   │   ├── nllb-clip-base-siglip.json
│   │   │   ├── nllb-clip-base.json
│   │   │   ├── nllb-clip-large-siglip.json
│   │   │   ├── nllb-clip-large.json
│   │   │   ├── roberta-ViT-B-32.json
│   │   │   ├── swin_base_patch4_window7_224.json
│   │   │   ├── vit_medium_patch16_gap_256.json
│   │   │   ├── vit_relpos_medium_patch16_cls_224.json
│   │   │   ├── xlm-roberta-base-ViT-B-32.json
│   │   │   └── xlm-roberta-large-ViT-H-14.json
│   │   ├── modified_resnet.py
│   │   ├── openai.py
│   │   ├── pos_embed.py
│   │   ├── pretrained.py
│   │   ├── push_to_hf_hub.py
│   │   ├── timm_model.py
│   │   ├── tokenizer.py
│   │   ├── transform.py
│   │   ├── transformer.py
│   │   ├── utils.py
│   │   ├── version.py
│   │   ├── zero_shot_classifier.py
│   │   └── zero_shot_metadata.py
│   └── training
│       ├── .DS_Store
│       ├── .gitignore
│       ├── __init__.py
│       ├── data.py
│       ├── distributed.py
│       ├── file_utils.py
│       ├── logger.py
│       ├── main.py
│       ├── params.py
│       ├── precision.py
│       ├── profiler0.py
│       ├── scheduler.py
│       ├── train.py
│       └── zero_shot.py
├── requirements.txt
├── train.sh
└── train_combo.sh
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # SuperClass: Classification Done Right for Vision-Language Pre-Training
3 |
4 | [**Zilong Huang**](http://speedinghzl.github.io/) · [**Qinghao Ye**](https://scholar.google.com/citations?user=ZYOhaGwAAAAJ&hl=zh-CN) · [**Bingyi Kang**](https://bingykang.github.io/) · [**Jiashi Feng**](https://sites.google.com/site/jshfeng/) · [**Haoqi Fan**](https://scholar.google.com/citations?user=76B8lrgAAAAJ&hl=en)
5 |
6 | Bytedance Seed
7 |
8 | 
9 |
10 |
11 |
12 |
13 | This work presents SuperClass, a super simple classification method that performs vision-language pre-training. Our method does **not require a text encoder** to be pre-trained on image-text data. Instead, it utilizes **tokenized raw text** as **supervised classification labels**, without the need for additional text filtering or selection.
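Below is a minimal, hedged sketch of this idea (not the repo's implementation, which lives in `opencls/open_clip/cls_model.py` and `loss.py`): the tokenized caption becomes a bag-of-tokens target over the tokenizer vocabulary, and the vision encoder's logits are trained with a soft cross-entropy against the L1-normalized target, mirroring the `ClsLoss` shown later in `loss.py`.

```python
# Hedged sketch only: illustrates "tokenized raw text as classification labels".
import torch
import torch.nn.functional as F

def caption_to_target(token_ids: torch.Tensor, vocab_size: int, pad_id: int = 0) -> torch.Tensor:
    # Count how often each vocabulary token appears in the caption, ignoring padding.
    target = torch.zeros(vocab_size)
    ids = token_ids[token_ids != pad_id]
    target.scatter_add_(0, ids, torch.ones_like(ids, dtype=target.dtype))
    return target

def superclass_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # Soft cross-entropy against the L1-normalized token counts (cf. ClsLoss.loss in loss.py).
    norm_targets = F.normalize(targets, p=1, dim=1)
    return -(F.log_softmax(logits, dim=1) * norm_targets).sum(dim=1).mean()

# Toy example: a batch of 2 tokenized captions over a 10-token vocabulary.
tokens = torch.tensor([[3, 5, 5, 0], [7, 2, 0, 0]])
targets = torch.stack([caption_to_target(t, vocab_size=10) for t in tokens])
logits = torch.randn(2, 10)  # stand-in for the vision encoder's per-token logits
print(superclass_loss(logits, targets))
```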
14 |
15 |
16 | 
17 |
18 |
19 |
20 | ## News
21 |
22 | - **2024-11-06:** Paper & code are all released.
23 | - **2024-10-02:** SuperClass is accepted by NeurIPS 2024.
24 |
25 |
26 | ## Usage
27 |
28 | ### Preparation
29 |
30 | ```bash
31 | git clone https://github.com/x-cls/superclass
32 | cd superclass
33 | pip install -r requirements.txt
34 | ```
35 |
36 | Download the [DataComp-1B](https://github.com/mlfoundations/datacomp) and [ImageNet-1K](https://www.image-net.org/download.php) datasets. You can also use [other image-text pair datasets](https://github.com/rom1504/img2dataset/tree/main?tab=readme-ov-file#examples) for training.
37 |
38 | Modify **DATA_PATH** and **VAL_DATA_PATH** in the training scripts **train.sh** and **train_combo.sh** to point to your local copies of DataComp-1B and ImageNet-1K, for example as sketched below.
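```bash
# Hypothetical illustration only -- the exact variable layout inside train.sh /
# train_combo.sh may differ. The values mirror the `train_data` and
# `imagenet_val` entries used by the configs in this repo.
DATA_PATH='/datasets/datacomp_1b/data/{00000..128000}.tar'  # DataComp-1B webdataset shards
VAL_DATA_PATH='./imagenet1k/val'                            # ImageNet-1K validation set
```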
39 |
40 |
41 | ### CLIP Training & SuperClass Training
42 |
43 | To start CLIP training and SuperClass training, use the following command:
44 |
45 | ```bash
46 | bash train.sh opencls
47 | ```
48 |
49 | This script navigates to the `opencls` directory and launches training.
50 |
51 | If you want to include the LiT training phase, use the following command:
52 |
53 | ```bash
54 | bash train_combo.sh opencls
55 | ```
56 |
57 | The CLS training configs are located in `opencls/configs/cls_schedule`.
58 |
59 |
60 | For example:
61 | ```bash
62 | bash train.sh configs/cls_schedule/cls_vit_b16_s1.28B_bs16k.yaml opencls
63 | ```
64 |
65 | Please note that the default **precision** during training is set to **amp_bfloat16**. If your GPU (e.g., V100) does not support bf16, please change it to **fp16** or **amp**.
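For instance, the override looks like this in any config under `opencls/configs/cls_schedule/` (shown with `fp16`; `amp` works the same way):

```yaml
# e.g. in opencls/configs/cls_schedule/cls_vit_b16_s1.28B_bs16k.yaml
precision: 'fp16'   # default is 'amp_bfloat16', which requires bf16-capable GPUs
```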
66 |
67 |
68 |
69 |
70 |
71 | ## Acknowledgement
72 | Our codebase is built upon [OpenCLIP](https://github.com/mlfoundations/open_clip) and [ViTamin](https://github.com/Beckschen/ViTamin).
73 |
74 | We thank the [OpenCLIP](https://github.com/mlfoundations/open_clip) and [ViTamin](https://github.com/Beckschen/ViTamin) teams for contributing such impressive code and models to the community.
75 |
76 |
77 | ## LICENSE
78 |
79 | The models & code of SuperClass are released under the Apache-2.0 license.
80 |
81 |
82 | ## Citation
83 |
84 | If you find this project useful, please consider citing:
85 |
86 | ```bibtex
87 | @inproceedings{superclass_huang,
88 | title={Classification Done Right for Vision-Language Pre-Training},
89 | author={Huang, Zilong and Ye, Qinghao and Kang, Bingyi and Feng, Jiashi and Fan, Haoqi},
90 | booktitle={NeurIPS},
91 | year={2024}
92 | }
93 | ```
94 |
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/x-cls/superclass/90b761f41177400caa178d3845b39eba771e5bf1/assets/teaser.png
--------------------------------------------------------------------------------
/opencls/configs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/x-cls/superclass/90b761f41177400caa178d3845b39eba771e5bf1/opencls/configs/.DS_Store
--------------------------------------------------------------------------------
/opencls/configs/cls_schedule/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/x-cls/superclass/90b761f41177400caa178d3845b39eba771e5bf1/opencls/configs/cls_schedule/.DS_Store
--------------------------------------------------------------------------------
/opencls/configs/cls_schedule/cls_vit_b16_s1.28B_bs16k.yaml:
--------------------------------------------------------------------------------
1 | save_frequency: 1
2 | name: "cls_vit_b16_s1.28B_bs16k"
3 | train_data: '/datasets/datacomp_1b/data/{00000..128000}.tar' # please modify to your own path
4 | train_num_samples: 1_280_000_000
5 | dataset_type: webdataset
6 | precision: 'amp_bfloat16'
7 | warmup: 500
8 | global_batch_size: 16384
9 | batch_size: 0
10 | epochs: 1
11 | lr: 5e-4
12 | beta1: 0.9
13 | beta2: 0.98
14 | eps: 1.0e-6
15 | workers: 6
16 | model: "CLS-ViT-B-16"
17 | seed: 0
18 | ddp_static_graph: true
19 | local_loss: true
20 | gather_with_grad: true
21 | force_image_size: 224
22 | grad_checkpointing: true
23 |
24 | logs: './logs'
25 | imagenet_val: './imagenet1k/val' # please modify to your own path
26 |
27 | report_to: "tensorboard"
28 | log_every_n_steps: 128
29 | zeroshot_steps: 0
30 | val_steps: 0
31 | zeroshot_frequency: 0
32 | val_frequency: 0
33 | save_every_n_steps: 6104
34 | delete_prev_step_ckpt: true
35 |
36 | resume: latest
--------------------------------------------------------------------------------
/opencls/configs/cls_schedule/cls_vit_b16_s512m_bs16k.yaml:
--------------------------------------------------------------------------------
1 | save_frequency: 1
2 | name: "cls_vit_b16_s512m_bs16k"
3 | train_data: '/datasets/datacomp_1b/data/{00000..128000}.tar' # please modify to your own path
4 | train_num_samples: 512_000_000
5 | dataset_type: webdataset
6 | precision: 'amp_bfloat16'
7 | warmup: 500
8 | global_batch_size: 16384
9 | batch_size: 0
10 | epochs: 1
11 | lr: 5e-4
12 | beta1: 0.9
13 | beta2: 0.98
14 | eps: 1.0e-6
15 | workers: 6
16 | model: "CLS-ViT-B-16"
17 | seed: 0
18 | ddp_static_graph: true
19 | local_loss: true
20 | gather_with_grad: true
21 | force_image_size: 224
22 | grad_checkpointing: true
23 |
24 | logs: './logs'
25 | imagenet_val: './imagenet1k/val' # please modify to your own path
26 |
27 | report_to: "tensorboard"
28 | log_every_n_steps: 128
29 | zeroshot_steps: 0
30 | val_steps: 0
31 | zeroshot_frequency: 0
32 | val_frequency: 0
33 | save_every_n_steps: 6104
34 | delete_prev_step_ckpt: true
35 |
36 | resume: latest
--------------------------------------------------------------------------------
/opencls/configs/cls_schedule/cls_vit_l14_224_s12.8B_bs90k.yaml:
--------------------------------------------------------------------------------
1 | save_frequency: 1
2 | name: "cls_vit_l14_224_s12.8B_bs90k"
3 | train_data: '/datasets/datacomp_1b/data/{00000..140089}.tar' # please modify to your own path
4 | train_num_samples: 1_280_000_000
5 | dataset_type: webdataset
6 | precision: 'amp_bfloat16'
7 | warmup: 10000
8 | global_batch_size: 90112
9 | batch_size: 0
10 | epochs: 10
11 | lr: 1.0e-3
12 | beta1: 0.9
13 | beta2: 0.95
14 | eps: 1.0e-6
15 | workers: 6
16 | model: "CLS-ViT-L-14"
17 | seed: 0
18 | ddp_static_graph: true
19 | local_loss: true
20 | gather_with_grad: true
21 | force_image_size: 224
22 | grad_checkpointing: true
23 |
24 | logs: './logs'
25 | imagenet_val: './imagenet1k/val' # please modify to your own path
26 |
27 | report_to: "tensorboard"
28 | log_every_n_steps: 32
29 | zeroshot_steps: 0
30 | val_steps: 0
31 | zeroshot_frequency: 0
32 | val_frequency: 0
33 | save_every_n_steps: 3052
34 | delete_prev_step_ckpt: true
35 | aug_cfg: {'scale': [0.4, 1.0], 'color_jitter': [0.32, 0.32, 0.32, 0.08], 'color_jitter_prob': 0.8, 'gray_scale_prob': 0.2}
36 | resume: latest
--------------------------------------------------------------------------------
/opencls/configs/cls_schedule/cls_vit_l14_s1.28B_bs16k.yaml:
--------------------------------------------------------------------------------
1 | save_frequency: 1
2 | name: "cls_vit_l14_s1.28B_bs16k"
3 | train_data: '/datasets/datacomp_1b/data/{00000..128000}.tar' # please modify to your own path
4 | train_num_samples: 1_280_000_000
5 | dataset_type: webdataset
6 | precision: 'amp_bfloat16'
7 | warmup: 500
8 | global_batch_size: 16384
9 | batch_size: 0
10 | epochs: 1
11 | lr: 5e-4
12 | beta1: 0.9
13 | beta2: 0.98
14 | eps: 1.0e-6
15 | workers: 10
16 | model: "CLS-ViT-L-14"
17 | seed: 0
18 | ddp_static_graph: true
19 | local_loss: true
20 | gather_with_grad: true
21 | force_image_size: 224
22 | grad_checkpointing: true
23 |
24 | logs: './logs'
25 | imagenet_val: './imagenet1k/val' # please modify to your own path
26 |
27 | report_to: "tensorboard"
28 | log_every_n_steps: 128
29 | zeroshot_steps: 0
30 | val_steps: 0
31 | zeroshot_frequency: 0
32 | val_frequency: 0
33 | save_every_n_steps: 6104
34 | delete_prev_step_ckpt: true
35 |
36 | resume: latest
--------------------------------------------------------------------------------
/opencls/configs/cls_schedule/cls_vit_l16_s512m_bs16k.yaml:
--------------------------------------------------------------------------------
1 | save_frequency: 1
2 | name: "cls_vit_l16_s512m_bs16k"
3 | train_data: '/datasets/datacomp_1b/data/{00000..128000}.tar' # please modify to your own path
4 | train_num_samples: 512_000_000
5 | dataset_type: webdataset
6 | precision: 'amp_bfloat16'
7 | warmup: 500
8 | global_batch_size: 16384
9 | batch_size: 0
10 | epochs: 1
11 | lr: 5e-4
12 | beta1: 0.9
13 | beta2: 0.98
14 | eps: 1.0e-6
15 | workers: 6
16 | model: "CLS-ViT-L-16"
17 | seed: 0
18 | ddp_static_graph: true
19 | local_loss: true
20 | gather_with_grad: true
21 | force_image_size: 224
22 | grad_checkpointing: true
23 |
24 | logs: './logs'
25 | imagenet_val: './imagenet1k/val' # please modify to your own path
26 |
27 | report_to: "tensorboard"
28 | log_every_n_steps: 128
29 | zeroshot_steps: 0
30 | val_steps: 0
31 | zeroshot_frequency: 0
32 | val_frequency: 0
33 | save_every_n_steps: 6104
34 | delete_prev_step_ckpt: true
35 |
36 | resume: latest
--------------------------------------------------------------------------------
/opencls/configs/cls_schedule/lit_vit_b16_s1.28B_bs16k.yaml:
--------------------------------------------------------------------------------
1 | save_frequency: 1
2 | name: "lit_cls_vit_b16_s1.28B_bs16k"
3 | train_data: '/datasets/datacomp_1b/data/{00000..51200}.tar'
4 | train_num_samples: 1_280_000_000
5 | dataset_type: webdataset
6 | precision: 'amp_bfloat16'
7 | warmup: 500
8 | global_batch_size: 16384
9 | batch_size: 0
10 | epochs: 1
11 | lr: 5e-4
12 | beta1: 0.9
13 | beta2: 0.98
14 | eps: 1.0e-6
15 | workers: 6
16 | model: "ViT-B-16"
17 | seed: 0
18 | ddp_static_graph: true
19 | local_loss: true
20 | gather_with_grad: true
21 | force_image_size: 224
22 | grad_checkpointing: true
23 |
24 | lock_image: true
25 | lock_image_unlocked_groups: 1
26 | pretrained_image: './logs/cls_vit_b16_s1.28B_bs16k/checkpoints/epoch_1.pt'
27 |
28 | logs: './logs'
29 | imagenet_val: './imagenet1k/val' # please modify to your own path
30 |
31 | report_to: "tensorboard"
32 | log_every_n_steps: 128
33 | zeroshot_steps: 6104
34 | val_steps: 6104
35 | save_every_n_steps: 6104
36 | delete_prev_step_ckpt: true
37 |
38 | resume: latest
--------------------------------------------------------------------------------
/opencls/configs/cls_schedule/lit_vit_b16_s512m_bs16k.yaml:
--------------------------------------------------------------------------------
1 | save_frequency: 1
2 | name: "lit_cls_vit_b16_s512m_bs16k"
3 | train_data: '/datasets/datacomp_1b/data/{00000..51200}.tar'
4 | train_num_samples: 512_000_000
5 | dataset_type: webdataset
6 | precision: 'amp_bfloat16'
7 | warmup: 500
8 | global_batch_size: 16384
9 | batch_size: 0
10 | epochs: 1
11 | lr: 5e-4
12 | beta1: 0.9
13 | beta2: 0.98
14 | eps: 1.0e-6
15 | workers: 6
16 | model: "ViT-B-16-avg"
17 | seed: 0
18 | ddp_static_graph: true
19 | local_loss: true
20 | gather_with_grad: true
21 | force_image_size: 224
22 | grad_checkpointing: true
23 |
24 | lock_image: true
25 | lock_image_unlocked_groups: 1
26 | pretrained_image: './logs/cls_vit_b16_s512m_bs16k/checkpoints/epoch_1.pt'
27 |
28 | logs: './logs'
29 | imagenet_val: './imagenet1k/val' # please modify to your own path
30 |
31 | report_to: "tensorboard"
32 | log_every_n_steps: 128
33 | zeroshot_steps: 6104
34 | val_steps: 6104
35 | save_every_n_steps: 6104
36 | delete_prev_step_ckpt: true
37 |
38 | resume: latest
--------------------------------------------------------------------------------
/opencls/configs/cls_schedule/lit_vit_l14_224_s12.8B_bs90k.yaml:
--------------------------------------------------------------------------------
1 | save_frequency: 1
2 | name: "lit_cls_vit_l14_224_s12.8B_bs90k"
3 | train_data: '/datasets/datacomp_1b/data/{00000..140089}.tar' # please modify to your own path
4 | train_num_samples: 1_280_000_000
5 | dataset_type: webdataset
6 | precision: 'amp_bfloat16'
7 | warmup: 10000
8 | global_batch_size: 90112
9 | batch_size: 0
10 | epochs: 10
11 | lr: 1.0e-3
12 | beta1: 0.9
13 | beta2: 0.95
14 | eps: 1.0e-6
15 | workers: 6
16 | model: "ViT-L-14-avg"
17 | seed: 0
18 | ddp_static_graph: true
19 | local_loss: true
20 | gather_with_grad: true
21 | force_image_size: 224
22 | grad_checkpointing: true
23 |
24 | lock_image: true
25 | lock_image_unlocked_groups: 1
26 | pretrained_image: './logs/cls_vit_l14_224_s12.8B_bs90k/checkpoints/epoch_10.pt'
27 |
28 | logs: './logs'
29 | imagenet_val: './imagenet1k/val' # please modify to your own path
30 |
31 | report_to: "tensorboard"
32 | log_every_n_steps: 32
33 | zeroshot_steps: 3052
34 | val_steps: 3052
35 | zeroshot_frequency: 1
36 | save_every_n_steps: 3052
37 | delete_prev_step_ckpt: true
38 | aug_cfg: {'scale': [0.4, 1.0], 'color_jitter': [0.32, 0.32, 0.32, 0.08], 'color_jitter_prob': 0.8, 'gray_scale_prob': 0.2}
39 | resume: latest
--------------------------------------------------------------------------------
/opencls/configs/cls_schedule/lit_vit_l14_s1.28B_bs16k.yaml:
--------------------------------------------------------------------------------
1 | save_frequency: 1
2 | name: "lit_cls_vit_l14_s1.28B_bs16k"
3 | train_data: '/datasets/datacomp_1b/data/{00000..128000}.tar' # please modify to your own path
4 | train_num_samples: 1_280_000_000
5 | dataset_type: webdataset
6 | precision: 'amp_bfloat16'
7 | warmup: 500
8 | global_batch_size: 16384
9 | batch_size: 0
10 | epochs: 1
11 | lr: 5e-4
12 | beta1: 0.9
13 | beta2: 0.98
14 | eps: 1.0e-6
15 | workers: 6
16 | model: "ViT-L-14"
17 | seed: 0
18 | ddp_static_graph: true
19 | local_loss: true
20 | gather_with_grad: true
21 | force_image_size: 224
22 | grad_checkpointing: true
23 |
24 | lock_image: true
25 | lock_image_unlocked_groups: 1
26 | pretrained_image: './logs/cls_vit_l14_s1.28B_bs16k/checkpoints/epoch_1.pt'
27 |
28 | logs: './logs'
29 | imagenet_val: './imagenet1k/val' # please modify to your own path
30 |
31 | report_to: "tensorboard"
32 | log_every_n_steps: 128
33 | zeroshot_steps: 6104
34 | val_steps: 6104
35 | save_every_n_steps: 6104
36 | delete_prev_step_ckpt: true
37 |
38 | resume: latest
--------------------------------------------------------------------------------
/opencls/configs/cls_schedule/lit_vit_l16_s512m_bs16k.yaml:
--------------------------------------------------------------------------------
1 | save_frequency: 1
2 | name: "lit_clip_vit_l16_s512m_bs16k"
3 | train_data: '/datasets/datacomp_1b/data/{00000..51200}.tar'
4 | train_num_samples: 512_000_000
5 | dataset_type: webdataset
6 | precision: 'amp_bfloat16'
7 | warmup: 500
8 | global_batch_size: 16384
9 | batch_size: 0
10 | epochs: 1
11 | lr: 5e-4
12 | beta1: 0.9
13 | beta2: 0.98
14 | eps: 1.0e-6
15 | workers: 6
16 | model: "ViT-L-16-avg"
17 | seed: 0
18 | ddp_static_graph: true
19 | local_loss: true
20 | gather_with_grad: true
21 | force_image_size: 224
22 | grad_checkpointing: true
23 |
24 | lock_image: true
25 | lock_image_unlocked_groups: 1
26 | pretrained_image: './logs/cls_vit_l16_s512m_bs16k/checkpoints/epoch_1.pt'
27 |
28 | logs: './logs'
29 | imagenet_val: './imagenet1k/val' # please modify to your own path
30 |
31 | report_to: "tensorboard"
32 | log_every_n_steps: 128
33 | zeroshot_steps: 6104
34 | val_steps: 6104
35 | save_every_n_steps: 6104
36 | delete_prev_step_ckpt: true
37 |
38 | resume: latest
--------------------------------------------------------------------------------
/opencls/open_clip/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/x-cls/superclass/90b761f41177400caa178d3845b39eba771e5bf1/opencls/open_clip/.DS_Store
--------------------------------------------------------------------------------
/opencls/open_clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .coca_model import CoCa
2 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
3 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
4 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint
5 | from .loss import ClipLoss, DistillClipLoss, CoCaLoss
6 | from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
7 | convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \
8 | get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg
9 | from .openai import load_openai_model, list_openai_models
10 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
11 | get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
12 | from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
13 | from .tokenizer import SimpleTokenizer, tokenize, decode
14 | from .transform import image_transform, AugmentationCfg
15 | from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy
16 | from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES
17 |
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/opencls/open_clip/big_vision.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | from .model import CustomTextCLIP
5 | from .transformer import TextTransformer, Transformer
6 |
7 |
8 | @torch.no_grad()
9 | def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str):
10 | """ Load weights from .npz checkpoints for official Google big_vision image-text models
11 |
12 | Currently the SigLIP source models are supported and a CustomTextCLIP destination model
13 | w/ timm image encoder.
14 | """
15 | from timm.layers import resample_patch_embed, resample_abs_pos_embed
16 |
17 | def _n2p(w, t=True):
18 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
19 | w = w.flatten()
20 | if t:
21 | if w.ndim == 4:
22 | w = w.transpose([3, 2, 0, 1])
23 | elif w.ndim == 3:
24 | w = w.transpose([2, 0, 1])
25 | elif w.ndim == 2:
26 | w = w.transpose([1, 0])
27 | return torch.from_numpy(w)
28 |
29 | w = np.load(checkpoint_path)
30 | interpolation = 'bilinear'
31 | antialias = False
32 |
33 | def _convert_timm_img(module, prefix):
34 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
35 | if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]:
36 | embed_conv_w = resample_patch_embed(
37 | embed_conv_w,
38 | module.patch_embed.proj.weight.shape[-2:],
39 | interpolation=interpolation,
40 | antialias=antialias,
41 | verbose=True,
42 | )
43 | module.patch_embed.proj.weight.copy_(embed_conv_w)
44 | module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
45 |
46 | if module.cls_token is not None:
47 | module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
48 |
49 | pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False)
50 | if pos_embed_w.shape != module.pos_embed.shape:
51 | assert False, f'{pos_embed_w.shape}, {module.pos_embed.shape}'
52 | num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1)
53 | pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights
54 | pos_embed_w,
55 | new_size=module.patch_embed.grid_size,
56 | num_prefix_tokens=num_prefix_tokens,
57 | interpolation=interpolation,
58 | antialias=antialias,
59 | verbose=True,
60 | )
61 | module.pos_embed.copy_(pos_embed_w)
62 |
63 | mha_sub, b_sub, ln1_sub = (0, 0, 1)
64 | for i, block in enumerate(module.blocks.children()):
65 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
66 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
67 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
68 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
69 | block.attn.qkv.weight.copy_(torch.cat([
70 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
71 | block.attn.qkv.bias.copy_(torch.cat([
72 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
73 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
74 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
75 | for r in range(2):
76 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
77 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
78 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
79 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))
80 |
81 | module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
82 | module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
83 |
84 | if module.attn_pool is not None:
85 | block_prefix = f'{prefix}MAPHead_0/'
86 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
87 | module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False))
88 | module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T)
89 | module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1))
90 | module.attn_pool.kv.weight.copy_(torch.cat([
91 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')]))
92 | module.attn_pool.kv.bias.copy_(torch.cat([
93 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')]))
94 | module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
95 | module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
96 | module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
97 | module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
98 | for r in range(2):
99 | getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel']))
100 | getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias']))
101 |
102 | def _convert_openclip_transformer(module: Transformer, prefix):
103 | for i, block in enumerate(module.resblocks.children()):
104 | block_prefix = f'{prefix}encoderblock_{i}/'
105 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
106 | block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
107 | block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
108 | block.attn.in_proj_weight.copy_(torch.cat([
109 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
110 | block.attn.in_proj_bias.copy_(torch.cat([
111 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
112 | block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
113 | block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
114 | block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale']))
115 | block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias']))
116 | block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel']))
117 | block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias']))
118 | block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel']))
119 | block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias']))
120 |
121 | def _convert_openclip_txt(module: TextTransformer, prefix):
122 | module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False))
123 | pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0)
124 | module.positional_embedding.copy_(pos_embed_w)
125 | _convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/')
126 | module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale']))
127 | module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias']))
128 | module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
129 | module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias']))
130 |
131 | _convert_timm_img(model.visual.trunk, 'params/img/')
132 | _convert_openclip_txt(model.text, 'params/txt/')
133 | model.logit_bias.copy_(_n2p(w['params/b'])[0])
134 | model.logit_scale.copy_(_n2p(w['params/t'])[0])
135 |
136 |
137 |
--------------------------------------------------------------------------------
/opencls/open_clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/x-cls/superclass/90b761f41177400caa178d3845b39eba771e5bf1/opencls/open_clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/opencls/open_clip/cls_model.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Copyright (2024) Bytedance Ltd. and/or its affiliates
3 | # Licensed under the Apache License, Version 2.0 (the "License")
4 | # SuperClass Project
5 | # Written by Zilong Huang
6 | # --------------------------------------------------------
7 | from typing import Optional
8 |
9 | import torch
10 | from torch import nn
11 | from torch.nn import functional as F
12 | import numpy as np
13 | from dataclasses import dataclass
14 |
15 | from .transformer import (
16 | LayerNormFp32,
17 | LayerNorm,
18 | QuickGELU,
19 | MultimodalTransformer,
20 | MixClsHead,
21 | )
22 | from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
23 |
24 |
25 | @dataclass
26 | class ClassHeadCfg(CLIPTextCfg):
27 | mlp_ratio: int = 4
28 | layers: int = 1
29 |
30 |
31 | def _build_cls_head(
32 | width,
33 | clshead_cfg,
34 | quick_gelu: bool = False,
35 | cast_dtype: Optional[torch.dtype] = None,
36 | ):
37 | clshead_cfg = ClassHeadCfg(**clshead_cfg) if isinstance(clshead_cfg, dict) else clshead_cfg
38 | act_layer = QuickGELU if quick_gelu else nn.GELU
39 | norm_layer = (
40 | LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
41 | )
42 |
43 | head = MixClsHead(
44 | width=width,
45 | layers=clshead_cfg.layers,
46 | mlp_ratio=clshead_cfg.mlp_ratio,
47 | act_layer=act_layer,
48 | norm_layer=norm_layer,
49 | output_dim=clshead_cfg.vocab_size,
50 | )
51 |
52 | return head
53 |
54 |
55 | class Classifier(nn.Module):
56 | def __init__(
57 | self,
58 | embed_dim,
59 | text_cfg: CLIPTextCfg,
60 | vision_cfg: CLIPVisionCfg,
61 | quick_gelu: bool = False,
62 | cast_dtype: Optional[torch.dtype] = None,
63 | ):
64 | super().__init__()
65 | clshead_cfg = ClassHeadCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
66 | vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg
67 |
68 | vocab_size = clshead_cfg.vocab_size
69 |
70 | self.visual = _build_vision_tower(
71 | embed_dim=embed_dim,
72 | vision_cfg=vision_cfg,
73 | quick_gelu=quick_gelu,
74 | cast_dtype=cast_dtype,
75 | )
76 |
77 | self.text_decoder = _build_cls_head(
78 | embed_dim if embed_dim else vision_cfg.width,
79 | clshead_cfg=clshead_cfg,
80 | quick_gelu=quick_gelu,
81 | cast_dtype=cast_dtype,
82 | )
83 |
84 | self.register_buffer("cap_fq", torch.zeros([1, vocab_size], dtype=torch.float64))
85 | self.register_buffer("num_samples", torch.zeros([1, 1], dtype=torch.float64))
86 |
87 | @torch.jit.ignore
88 | def set_grad_checkpointing(self, enable=True):
89 | self.visual.set_grad_checkpointing(enable)
90 | # self.text.set_grad_checkpointing(enable)
91 | # self.text_decoder.set_grad_checkpointing(enable)
92 |
93 | def forward(self, image, text, image_embs=None):
94 | if image_embs is None:
95 | image_embs = self.visual(image)
96 |
97 | logits = self.text_decoder(image_embs)
98 | labels = text.clone()
99 |
100 | return {
101 | "cap_fq": self.cap_fq,
102 | "num_samples": self.num_samples,
103 | "logits": logits,
104 | "labels": labels,
105 | "logit_scale": torch.ones([1]),
106 | }
--------------------------------------------------------------------------------
/opencls/open_clip/constants.py:
--------------------------------------------------------------------------------
1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
3 | IMAGENET_MEAN = (0.485, 0.456, 0.406)
4 | IMAGENET_STD = (0.229, 0.224, 0.225)
5 | INCEPTION_MEAN = (0.5, 0.5, 0.5)
6 | INCEPTION_STD = (0.5, 0.5, 0.5)
7 |
--------------------------------------------------------------------------------
/opencls/open_clip/hf_configs.py:
--------------------------------------------------------------------------------
1 | # HF architecture dict:
2 | arch_dict = {
3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
4 | "roberta": {
5 | "config_names": {
6 | "context_length": "max_position_embeddings",
7 | "vocab_size": "vocab_size",
8 | "width": "hidden_size",
9 | "heads": "num_attention_heads",
10 | "layers": "num_hidden_layers",
11 | "layer_attr": "layer",
12 | "token_embeddings_attr": "embeddings"
13 | },
14 | "pooler": "mean_pooler",
15 | },
16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
17 | "xlm-roberta": {
18 | "config_names": {
19 | "context_length": "max_position_embeddings",
20 | "vocab_size": "vocab_size",
21 | "width": "hidden_size",
22 | "heads": "num_attention_heads",
23 | "layers": "num_hidden_layers",
24 | "layer_attr": "layer",
25 | "token_embeddings_attr": "embeddings"
26 | },
27 | "pooler": "mean_pooler",
28 | },
29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
30 | "mt5": {
31 | "config_names": {
32 | # unlimited seqlen
33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
35 | "context_length": "",
36 | "vocab_size": "vocab_size",
37 | "width": "d_model",
38 | "heads": "num_heads",
39 | "layers": "num_layers",
40 | "layer_attr": "block",
41 | "token_embeddings_attr": "embed_tokens"
42 | },
43 | "pooler": "mean_pooler",
44 | },
45 | # https://huggingface.co/docs/transformers/model_doc/bert
46 | "bert": {
47 | "config_names": {
48 | "context_length": "max_position_embeddings",
49 | "vocab_size": "vocab_size",
50 | "width": "hidden_size",
51 | "heads": "num_attention_heads",
52 | "layers": "num_hidden_layers",
53 | },
54 | "pooler": "cls_pooler",
55 | },
56 | # https://huggingface.co/docs/transformers/model_doc/m2m_100
57 | "m2m_100": {
58 | "config_names": {
59 | "context_length": "max_position_embeddings",
60 | "vocab_size": "vocab_size",
61 | "width": "d_model",
62 | "heads": "encoder_attention_heads",
63 | "layers": "encoder_layers",
64 | },
65 | "pooler": "cls_pooler",
66 | },
67 | }
68 |
--------------------------------------------------------------------------------
/opencls/open_clip/hf_model.py:
--------------------------------------------------------------------------------
1 | """ huggingface model adapter
2 |
3 | Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
4 | """
5 | import re
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch import TensorType
10 |
11 | try:
12 | import transformers
13 | from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
14 | from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
15 | BaseModelOutputWithPoolingAndCrossAttentions
16 | except ImportError as e:
17 | transformers = None
18 |
19 |
20 | class BaseModelOutput:
21 | pass
22 |
23 |
24 | class PretrainedConfig:
25 | pass
26 |
27 | from .hf_configs import arch_dict
28 |
29 |
30 | # utils
31 | def _camel2snake(s):
32 | return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
--------------------------------------------------------------------------------
/opencls/open_clip/loss.py:
--------------------------------------------------------------------------------
95 | def get_ground_truth(self, device, num_logits) -> torch.Tensor:
96 | # calculated ground-truth and cache if enabled
97 | if self.prev_num_logits != num_logits or device not in self.labels:
98 | labels = torch.arange(num_logits, device=device, dtype=torch.long)
99 | if self.world_size > 1 and self.local_loss:
100 | labels = labels + num_logits * self.rank
101 | if self.cache_labels:
102 | self.labels[device] = labels
103 | self.prev_num_logits = num_logits
104 | else:
105 | labels = self.labels[device]
106 | return labels
107 |
108 | def get_logits(self, image_features, text_features, logit_scale):
109 | if self.world_size > 1:
110 | all_image_features, all_text_features = gather_features(
111 | image_features, text_features,
112 | self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
113 |
114 | if self.local_loss:
115 | logits_per_image = logit_scale * image_features @ all_text_features.T
116 | logits_per_text = logit_scale * text_features @ all_image_features.T
117 | else:
118 | logits_per_image = logit_scale * all_image_features @ all_text_features.T
119 | logits_per_text = logits_per_image.T
120 | else:
121 | logits_per_image = logit_scale * image_features @ text_features.T
122 | logits_per_text = logit_scale * text_features @ image_features.T
123 |
124 | return logits_per_image, logits_per_text
125 |
126 | def forward(self, image_features, text_features, logit_scale, output_dict=False):
127 | device = image_features.device
128 | logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
129 |
130 | labels = self.get_ground_truth(device, logits_per_image.shape[0])
131 |
132 | total_loss = (
133 | F.cross_entropy(logits_per_image, labels) +
134 | F.cross_entropy(logits_per_text, labels)
135 | ) / 2
136 |
137 | return {"contrastive_loss": total_loss} if output_dict else total_loss
138 |
139 |
140 | class CoCaLoss(ClipLoss):
141 | def __init__(
142 | self,
143 | caption_loss_weight,
144 | clip_loss_weight,
145 | pad_id=0, # pad_token for open_clip custom tokenizer
146 | local_loss=False,
147 | gather_with_grad=False,
148 | cache_labels=False,
149 | rank=0,
150 | world_size=1,
151 | use_horovod=False,
152 | ):
153 | super().__init__(
154 | local_loss=local_loss,
155 | gather_with_grad=gather_with_grad,
156 | cache_labels=cache_labels,
157 | rank=rank,
158 | world_size=world_size,
159 | use_horovod=use_horovod
160 | )
161 |
162 | self.clip_loss_weight = clip_loss_weight
163 | self.caption_loss_weight = caption_loss_weight
164 | self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)
165 |
166 | def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
167 |
168 | clip_loss = torch.tensor(0)
169 |
170 | if self.clip_loss_weight:
171 | clip_loss = super().forward(image_features, text_features, logit_scale)
172 | clip_loss = self.clip_loss_weight * clip_loss
173 |
174 | caption_loss = self.caption_loss(
175 | logits.permute(0, 2, 1),
176 | labels,
177 | )
178 | caption_loss = caption_loss * self.caption_loss_weight
179 |
180 | if output_dict:
181 | return {"contrastive_loss": clip_loss, "caption_loss": caption_loss}
182 |
183 | return clip_loss, caption_loss
184 |
185 |
186 | class DistillClipLoss(ClipLoss):
187 |
188 | def dist_loss(self, teacher_logits, student_logits):
189 | return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0)
190 |
191 | def forward(
192 | self,
193 | image_features,
194 | text_features,
195 | logit_scale,
196 | dist_image_features,
197 | dist_text_features,
198 | dist_logit_scale,
199 | output_dict=False,
200 | ):
201 | logits_per_image, logits_per_text = \
202 | self.get_logits(image_features, text_features, logit_scale)
203 |
204 | dist_logits_per_image, dist_logits_per_text = \
205 | self.get_logits(dist_image_features, dist_text_features, dist_logit_scale)
206 |
207 | labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0])
208 |
209 | contrastive_loss = (
210 | F.cross_entropy(logits_per_image, labels) +
211 | F.cross_entropy(logits_per_text, labels)
212 | ) / 2
213 |
214 | distill_loss = (
215 | self.dist_loss(dist_logits_per_image, logits_per_image) +
216 | self.dist_loss(dist_logits_per_text, logits_per_text)
217 | ) / 2
218 |
219 | if output_dict:
220 | return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}
221 |
222 | return contrastive_loss, distill_loss
223 |
224 |
225 | def neighbour_exchange(from_rank, to_rank, tensor, group=None):
226 | tensor_recv = torch.zeros_like(tensor)
227 | send_op = torch.distributed.P2POp(
228 | torch.distributed.isend,
229 | tensor,
230 | to_rank,
231 | group=group,
232 | )
233 | recv_op = torch.distributed.P2POp(
234 | torch.distributed.irecv,
235 | tensor_recv,
236 | from_rank,
237 | group=group,
238 | )
239 | reqs = torch.distributed.batch_isend_irecv([send_op, recv_op])
240 | for req in reqs:
241 | req.wait()
242 | return tensor_recv
243 |
244 |
245 | def neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None):
246 | tensor_from_left = torch.zeros_like(tensor_to_right)
247 | tensor_from_right = torch.zeros_like(tensor_to_left)
248 | send_op_left = torch.distributed.P2POp(
249 | torch.distributed.isend,
250 | tensor_to_left,
251 | left_rank,
252 | group=group,
253 | )
254 | send_op_right = torch.distributed.P2POp(
255 | torch.distributed.isend,
256 | tensor_to_right,
257 | right_rank,
258 | group=group,
259 | )
260 | recv_op_left = torch.distributed.P2POp(
261 | torch.distributed.irecv,
262 | tensor_from_left,
263 | left_rank,
264 | group=group,
265 | )
266 | recv_op_right = torch.distributed.P2POp(
267 | torch.distributed.irecv,
268 | tensor_from_right,
269 | right_rank,
270 | group=group,
271 | )
272 | reqs = torch.distributed.batch_isend_irecv([send_op_right, send_op_left, recv_op_right, recv_op_left])
273 | for req in reqs:
274 | req.wait()
275 | return tensor_from_right, tensor_from_left
276 |
277 |
278 | class NeighbourExchange(torch.autograd.Function):
279 | @staticmethod
280 | def forward(ctx, from_rank, to_rank, group, tensor):
281 | ctx.group = group
282 | ctx.from_rank = from_rank
283 | ctx.to_rank = to_rank
284 | return neighbour_exchange(from_rank, to_rank, tensor, group=group)
285 |
286 | @staticmethod
287 | def backward(ctx, grad_output):
288 | return (None, None, None) + (NeighbourExchange.apply(ctx.to_rank, ctx.from_rank, ctx.group, grad_output),)
289 |
290 |
291 | def neighbour_exchange_with_grad(from_rank, to_rank, tensor, group=None):
292 | return NeighbourExchange.apply(from_rank, to_rank, group, tensor)
293 |
294 |
295 | class NeighbourExchangeBidir(torch.autograd.Function):
296 | @staticmethod
297 | def forward(ctx, left_rank, right_rank, group, tensor_to_left, tensor_to_right):
298 | ctx.group = group
299 | ctx.left_rank = left_rank
300 | ctx.right_rank = right_rank
301 | return neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=group)
302 |
303 | @staticmethod
304 | def backward(ctx, *grad_outputs):
305 | return (None, None, None) + \
306 | NeighbourExchangeBidir.apply(ctx.right_rank, ctx.left_rank, ctx.group, *grad_outputs)
307 |
308 |
309 | def neighbour_exchange_bidir_with_grad(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None):
310 | return NeighbourExchangeBidir.apply(left_rank, right_rank, group, tensor_to_left, tensor_to_right)
311 |
312 |
313 | class SigLipLoss(nn.Module):
314 | """ Sigmoid Loss for Language Image Pre-Training (SigLIP) - https://arxiv.org/abs/2303.15343
315 |
316 | @article{zhai2023sigmoid,
317 | title={Sigmoid loss for language image pre-training},
318 | author={Zhai, Xiaohua and Mustafa, Basil and Kolesnikov, Alexander and Beyer, Lucas},
319 | journal={arXiv preprint arXiv:2303.15343},
320 | year={2023}
321 | }
322 | """
323 | def __init__(
324 | self,
325 | cache_labels=False,
326 | rank=0,
327 | world_size=1,
328 | bidir=True,
329 | use_horovod=False,
330 | ):
331 | super().__init__()
332 | self.cache_labels = cache_labels
333 | self.rank = rank
334 | self.world_size = world_size
335 | assert not use_horovod # FIXME need to look at hvd ops for ring transfers
336 | self.use_horovod = use_horovod
337 | self.bidir = bidir
338 |
339 | # cache state FIXME cache not currently used, worthwhile?
340 | self.prev_num_logits = 0
341 | self.labels = {}
342 |
343 | def get_ground_truth(self, device, dtype, num_logits, negative_only=False) -> torch.Tensor:
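        # Pairwise sigmoid targets: +1 on the diagonal (matched image-text pairs),
        # -1 everywhere else. With negative_only=True every pair is a negative,
        # which is the case for text features received from other ranks.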
344 | labels = -torch.ones((num_logits, num_logits), device=device, dtype=dtype)
345 | if not negative_only:
346 | labels = 2 * torch.eye(num_logits, device=device, dtype=dtype) + labels
347 | return labels
348 |
349 | def get_logits(self, image_features, text_features, logit_scale, logit_bias=None):
350 | logits = logit_scale * image_features @ text_features.T
351 | if logit_bias is not None:
352 | logits += logit_bias
353 | return logits
354 |
355 | def _loss(self, image_features, text_features, logit_scale, logit_bias=None, negative_only=False):
356 | logits = self.get_logits(image_features, text_features, logit_scale, logit_bias)
357 | labels = self.get_ground_truth(
358 | image_features.device,
359 | image_features.dtype,
360 | image_features.shape[0],
361 | negative_only=negative_only,
362 | )
363 | loss = -F.logsigmoid(labels * logits).sum() / image_features.shape[0]
364 | return loss
365 |
366 | def forward(self, image_features, text_features, logit_scale, logit_bias, output_dict=False):
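        # Score the local image/text batch first; when distributed, text features
        # are then rotated around the rank ring so these images are also scored
        # against every other rank's texts, but only as negatives.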
367 | loss = self._loss(image_features, text_features, logit_scale, logit_bias)
368 |
369 | if self.world_size > 1:
370 | # exchange text features w/ neighbour world_size - 1 times
371 | right_rank = (self.rank + 1) % self.world_size
372 | left_rank = (self.rank - 1 + self.world_size) % self.world_size
373 | if self.bidir:
374 | text_features_to_right = text_features_to_left = text_features
375 | num_bidir, remainder = divmod(self.world_size - 1, 2)
376 | for i in range(num_bidir):
377 | text_features_recv = neighbour_exchange_bidir_with_grad(
378 | left_rank,
379 | right_rank,
380 | text_features_to_left,
381 | text_features_to_right,
382 | )
383 |
384 | for f in text_features_recv:
385 | loss += self._loss(
386 | image_features,
387 | f,
388 | logit_scale,
389 | logit_bias,
390 | negative_only=True,
391 | )
392 | text_features_to_left, text_features_to_right = text_features_recv
393 |
394 | if remainder:
395 | text_features_recv = neighbour_exchange_with_grad(
396 | left_rank, right_rank, text_features_to_right)
397 |
398 | loss += self._loss(
399 | image_features,
400 | text_features_recv,
401 | logit_scale,
402 | logit_bias,
403 | negative_only=True,
404 | )
405 | else:
406 | text_features_to_right = text_features
407 | for i in range(self.world_size - 1):
408 | text_features_from_left = neighbour_exchange_with_grad(
409 | left_rank, right_rank, text_features_to_right)
410 |
411 | loss += self._loss(
412 | image_features,
413 | text_features_from_left,
414 | logit_scale,
415 | logit_bias,
416 | negative_only=True,
417 | )
418 | text_features_to_right = text_features_from_left
419 |
420 | return {"contrastive_loss": loss} if output_dict else loss
421 |
422 |
423 |
424 | class ClsLoss(nn.Module):
425 | def __init__(
426 | self,
427 | world_size,
428 | pad_id=0, # pad_token for open_clip custom tokenizer
429 | ):
430 | super().__init__()
431 |
432 | self.pad_id = pad_id
433 | self.world_size = world_size
434 |         print('loss ignore id', pad_id)
435 |
436 | def loss(self, logits, targets):
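        # Soft cross-entropy: L1-normalize the (possibly multi-hot, reweighted)
        # targets over classes and take the expected negative log-softmax per sample.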
437 | norm_item = F.normalize(targets, p=1, dim=1)
438 | loss = -(F.log_softmax(logits, dim=1) * norm_item).sum(dim=1).mean()
439 | return loss
440 |
441 | def reweight_targets(self, cap_fq, num_samples, targets):
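        # Maintain a running, all-rank-averaged estimate of per-class frequency
        # (cap_fq) and scale targets by log(num_samples / cap_fq), smoothed by
        # 1/all_batch_size, i.e. up-weight classes seen less often so far.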
442 | cap_fq += targets.sum(dim=0, keepdim=True) / targets.shape[0]
443 | num_samples += 1
444 | dist.all_reduce(cap_fq, op=dist.ReduceOp.AVG)
445 | dist.all_reduce(num_samples, op=dist.ReduceOp.AVG)
446 | all_batch_size = self.world_size * targets.shape[0]
447 | targets = targets * torch.log((num_samples+1.0/all_batch_size) / (cap_fq+1.0/all_batch_size)).to(dtype=targets.dtype)
448 | return targets
449 |
450 | def forward(self, cap_fq, num_samples, logits, labels, logit_scale, output_dict=False):
451 | B, C = logits.shape
452 |
453 |         targets = torch.zeros(B, C, dtype=torch.float32, device=labels.device)
454 |         # scatter label indices into a multi-hot target matrix
455 | targets.scatter_(dim=1, index=labels, value=1.0)
456 |
457 | targets = self.reweight_targets(cap_fq, num_samples, targets)
458 | class_loss = self.loss(logits, targets)
459 |
460 | if output_dict:
461 | return {"class_loss": class_loss}
462 |
463 | return class_loss
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/CLS-ViT-B-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 0,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 16,
8 | "pool_type": "avg"
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "layers": 0,
14 | "mlp_ratio": 4
15 | }
16 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/CLS-ViT-L-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 0,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 14,
8 | "pool_type": "avg"
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "layers": 0,
14 | "mlp_ratio": 4
15 | }
16 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/CLS-ViT-L-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 0,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 16,
8 | "pool_type": "avg"
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "layers": 0,
14 | "mlp_ratio": 4
15 | }
16 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/EVA01-g-14-plus.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "timm_model_name": "eva_giant_patch14_224",
6 | "timm_model_pretrained": false,
7 | "timm_pool": "token",
8 | "timm_proj": null
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 1024,
14 | "heads": 16,
15 | "layers": 24
16 | },
17 | "custom_text": true
18 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/EVA01-g-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "timm_model_name": "eva_giant_patch14_224",
6 | "timm_model_pretrained": false,
7 | "timm_pool": "token",
8 | "timm_proj": null
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 768,
14 | "heads": 12,
15 | "layers": 12
16 | },
17 | "custom_text": true
18 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/EVA02-B-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "timm_model_name": "eva02_base_patch16_clip_224",
6 | "timm_model_pretrained": false,
7 | "timm_pool": "token",
8 | "timm_proj": null
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | },
17 | "custom_text": true
18 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/EVA02-E-14-plus.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "timm_model_name": "eva02_enormous_patch14_clip_224",
6 | "timm_model_pretrained": false,
7 | "timm_pool": "token",
8 | "timm_proj": null
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 1280,
14 | "heads": 20,
15 | "layers": 32
16 | },
17 | "custom_text": true
18 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/EVA02-E-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "timm_model_name": "eva02_enormous_patch14_clip_224",
6 | "timm_model_pretrained": false,
7 | "timm_pool": "token",
8 | "timm_proj": null
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 1024,
14 | "heads": 16,
15 | "layers": 24
16 | },
17 | "custom_text": true
18 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/EVA02-L-14-336.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 336,
5 | "timm_model_name": "eva02_large_patch14_clip_336",
6 | "timm_model_pretrained": false,
7 | "timm_pool": "token",
8 | "timm_proj": null
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 768,
14 | "heads": 12,
15 | "layers": 12
16 | },
17 | "custom_text": true
18 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/EVA02-L-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "timm_model_name": "eva02_large_patch14_clip_224",
6 | "timm_model_pretrained": false,
7 | "timm_pool": "token",
8 | "timm_proj": null
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 768,
14 | "heads": 12,
15 | "layers": 12
16 | },
17 | "custom_text": true
18 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/RN101-quickgelu.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "quick_gelu": true,
4 | "vision_cfg": {
5 | "image_size": 224,
6 | "layers": [
7 | 3,
8 | 4,
9 | 23,
10 | 3
11 | ],
12 | "width": 64,
13 | "patch_size": null
14 | },
15 | "text_cfg": {
16 | "context_length": 77,
17 | "vocab_size": 49408,
18 | "width": 512,
19 | "heads": 8,
20 | "layers": 12
21 | }
22 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/RN101.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": [
6 | 3,
7 | 4,
8 | 23,
9 | 3
10 | ],
11 | "width": 64,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 512,
18 | "heads": 8,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/RN50-quickgelu.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "quick_gelu": true,
4 | "vision_cfg": {
5 | "image_size": 224,
6 | "layers": [
7 | 3,
8 | 4,
9 | 6,
10 | 3
11 | ],
12 | "width": 64,
13 | "patch_size": null
14 | },
15 | "text_cfg": {
16 | "context_length": 77,
17 | "vocab_size": 49408,
18 | "width": 512,
19 | "heads": 8,
20 | "layers": 12
21 | }
22 | }
23 |
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/RN50.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": [
6 | 3,
7 | 4,
8 | 6,
9 | 3
10 | ],
11 | "width": 64,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 512,
18 | "heads": 8,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/RN50x16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 384,
5 | "layers": [
6 | 6,
7 | 8,
8 | 18,
9 | 8
10 | ],
11 | "width": 96,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 768,
18 | "heads": 12,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/RN50x4.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 640,
3 | "vision_cfg": {
4 | "image_size": 288,
5 | "layers": [
6 | 4,
7 | 6,
8 | 10,
9 | 6
10 | ],
11 | "width": 80,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 640,
18 | "heads": 10,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/RN50x64.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 448,
5 | "layers": [
6 | 3,
7 | 15,
8 | 36,
9 | 10
10 | ],
11 | "width": 128,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 1024,
18 | "heads": 16,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-B-16-SigLIP-256.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "init_logit_bias": -10,
4 | "custom_text": true,
5 | "vision_cfg": {
6 | "image_size": 256,
7 | "timm_model_name": "vit_base_patch16_siglip_256",
8 | "timm_model_pretrained": false,
9 | "timm_pool": "map",
10 | "timm_proj": "none"
11 | },
12 | "text_cfg": {
13 | "context_length": 64,
14 | "vocab_size": 32000,
15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
16 | "tokenizer_kwargs": {
17 | "clean": "canonicalize"
18 | },
19 | "width": 768,
20 | "heads": 12,
21 | "layers": 12,
22 | "no_causal_mask": true,
23 | "proj_bias": true,
24 | "pool_type": "last",
25 | "norm_kwargs":{
26 | "eps": 1e-6
27 | }
28 | }
29 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-B-16-SigLIP-384.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "init_logit_bias": -10,
4 | "custom_text": true,
5 | "vision_cfg": {
6 | "image_size": 384,
7 | "timm_model_name": "vit_base_patch16_siglip_384",
8 | "timm_model_pretrained": false,
9 | "timm_pool": "map",
10 | "timm_proj": "none"
11 | },
12 | "text_cfg": {
13 | "context_length": 64,
14 | "vocab_size": 32000,
15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
16 | "tokenizer_kwargs": {
17 | "clean": "canonicalize"
18 | },
19 | "width": 768,
20 | "heads": 12,
21 | "layers": 12,
22 | "no_causal_mask": true,
23 | "proj_bias": true,
24 | "pool_type": "last",
25 | "norm_kwargs":{
26 | "eps": 1e-6
27 | }
28 | }
29 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-B-16-SigLIP-512.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "init_logit_bias": -10,
4 | "custom_text": true,
5 | "vision_cfg": {
6 | "image_size": 512,
7 | "timm_model_name": "vit_base_patch16_siglip_512",
8 | "timm_model_pretrained": false,
9 | "timm_pool": "map",
10 | "timm_proj": "none"
11 | },
12 | "text_cfg": {
13 | "context_length": 64,
14 | "vocab_size": 32000,
15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
16 | "tokenizer_kwargs": {
17 | "clean": "canonicalize"
18 | },
19 | "width": 768,
20 | "heads": 12,
21 | "layers": 12,
22 | "no_causal_mask": true,
23 | "proj_bias": true,
24 | "pool_type": "last",
25 | "norm_kwargs":{
26 | "eps": 1e-6
27 | }
28 | }
29 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-B-16-SigLIP-i18n-256.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "init_logit_bias": -10,
4 | "custom_text": true,
5 | "vision_cfg": {
6 | "image_size": 256,
7 | "timm_model_name": "vit_base_patch16_siglip_256",
8 | "timm_model_pretrained": false,
9 | "timm_pool": "map",
10 | "timm_proj": "none"
11 | },
12 | "text_cfg": {
13 | "context_length": 64,
14 | "vocab_size": 250000,
15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP-i18n-256",
16 | "tokenizer_kwargs": {
17 | "clean": "canonicalize"
18 | },
19 | "width": 768,
20 | "heads": 12,
21 | "layers": 12,
22 | "no_causal_mask": true,
23 | "proj_bias": true,
24 | "pool_type": "last",
25 | "norm_kwargs":{
26 | "eps": 1e-6
27 | }
28 | }
29 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-B-16-SigLIP.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "init_logit_bias": -10,
4 | "custom_text": true,
5 | "vision_cfg": {
6 | "image_size": 224,
7 | "timm_model_name": "vit_base_patch16_siglip_224",
8 | "timm_model_pretrained": false,
9 | "timm_pool": "map",
10 | "timm_proj": "none"
11 | },
12 | "text_cfg": {
13 | "context_length": 64,
14 | "vocab_size": 32000,
15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
16 | "tokenizer_kwargs": {
17 | "clean": "canonicalize"
18 | },
19 | "width": 768,
20 | "heads": 12,
21 | "layers": 12,
22 | "no_causal_mask": true,
23 | "proj_bias": true,
24 | "pool_type": "last",
25 | "norm_kwargs":{
26 | "eps": 1e-6
27 | }
28 | }
29 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-B-16-avg.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 16,
8 | "pool_type": "avg"
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-B-16-plus-240.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 640,
3 | "vision_cfg": {
4 | "image_size": 240,
5 | "layers": 12,
6 | "width": 896,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 640,
13 | "heads": 10,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-B-16-plus.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 640,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 896,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 640,
13 | "heads": 10,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-B-16-quickgelu.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "quick_gelu": true,
4 | "vision_cfg": {
5 | "image_size": 224,
6 | "layers": 12,
7 | "width": 768,
8 | "patch_size": 16
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-B-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 512,
13 | "heads": 8,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-B-32-256.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 256,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 32
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 512,
13 | "heads": 8,
14 | "layers": 12
15 | }
16 | }
17 |
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-B-32-plus-256.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 640,
3 | "vision_cfg": {
4 | "image_size": 256,
5 | "layers": 12,
6 | "width": 896,
7 | "patch_size": 32
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 640,
13 | "heads": 10,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-B-32-quickgelu.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "quick_gelu": true,
4 | "vision_cfg": {
5 | "image_size": 224,
6 | "layers": 12,
7 | "width": 768,
8 | "patch_size": 32
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-B-32.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 32
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 512,
13 | "heads": 8,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-H-14-378-quickgelu.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "quick_gelu": true,
4 | "vision_cfg": {
5 | "image_size": 378,
6 | "layers": 32,
7 | "width": 1280,
8 | "head_width": 80,
9 | "patch_size": 14
10 | },
11 | "text_cfg": {
12 | "context_length": 77,
13 | "vocab_size": 49408,
14 | "width": 1024,
15 | "heads": 16,
16 | "layers": 24
17 | }
18 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-H-14-CLIPA-336.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 336,
5 | "layers": 32,
6 | "width": 1280,
7 | "head_width": 80,
8 | "patch_size": 14,
9 | "no_ln_pre": true,
10 | "pool_type": "avg",
11 | "final_ln_after_pool": true
12 | },
13 | "text_cfg": {
14 | "context_length": 32,
15 | "vocab_size": 32000,
16 | "hf_tokenizer_name": "bert-base-uncased",
17 | "tokenizer_kwargs": {
18 | "strip_sep_token": true
19 | },
20 | "width": 1024,
21 | "heads": 16,
22 | "layers": 24,
23 | "pool_type": "last",
24 | "no_causal_mask": true
25 | }
26 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-H-14-CLIPA.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 32,
6 | "width": 1280,
7 | "head_width": 80,
8 | "patch_size": 14,
9 | "no_ln_pre": true,
10 | "pool_type": "avg",
11 | "final_ln_after_pool": true
12 | },
13 | "text_cfg": {
14 | "context_length": 32,
15 | "vocab_size": 32000,
16 | "hf_tokenizer_name": "bert-base-uncased",
17 | "tokenizer_kwargs": {
18 | "strip_sep_token": true
19 | },
20 | "width": 1024,
21 | "heads": 16,
22 | "layers": 24,
23 | "pool_type": "last",
24 | "no_causal_mask": true
25 | }
26 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-H-14-quickgelu.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "quick_gelu": true,
4 | "vision_cfg": {
5 | "image_size": 224,
6 | "layers": 32,
7 | "width": 1280,
8 | "head_width": 80,
9 | "patch_size": 14
10 | },
11 | "text_cfg": {
12 | "context_length": 77,
13 | "vocab_size": 49408,
14 | "width": 1024,
15 | "heads": 16,
16 | "layers": 24
17 | }
18 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-H-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 32,
6 | "width": 1280,
7 | "head_width": 80,
8 | "patch_size": 14
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 1024,
14 | "heads": 16,
15 | "layers": 24
16 | }
17 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-H-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 32,
6 | "width": 1280,
7 | "head_width": 80,
8 | "patch_size": 16
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 1024,
14 | "heads": 16,
15 | "layers": 24
16 | }
17 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-L-14-280.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 280,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 14
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 768,
13 | "heads": 12,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-L-14-336.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 336,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 14
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 768,
13 | "heads": 12,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-L-14-CLIPA-336.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 336,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 14,
8 | "no_ln_pre": true,
9 | "pool_type": "avg",
10 | "final_ln_after_pool": true
11 | },
12 | "text_cfg": {
13 | "context_length": 32,
14 | "vocab_size": 32000,
15 | "hf_tokenizer_name": "bert-base-uncased",
16 | "tokenizer_kwargs": {
17 | "strip_sep_token": true
18 | },
19 | "width": 768,
20 | "heads": 12,
21 | "layers": 12,
22 | "pool_type": "last",
23 | "no_causal_mask": true
24 | }
25 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-L-14-CLIPA.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 14,
8 | "no_ln_pre": true,
9 | "pool_type": "avg",
10 | "final_ln_after_pool": true
11 | },
12 | "text_cfg": {
13 | "context_length": 32,
14 | "vocab_size": 32000,
15 | "hf_tokenizer_name": "bert-base-uncased",
16 | "tokenizer_kwargs": {
17 | "strip_sep_token": true
18 | },
19 | "width": 768,
20 | "heads": 12,
21 | "layers": 12,
22 | "pool_type": "last",
23 | "no_causal_mask": true
24 | }
25 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-L-14-avg.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 14,
8 | "pool_type": "avg"
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 768,
14 | "heads": 12,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-L-14-quickgelu.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "quick_gelu": true,
4 | "vision_cfg": {
5 | "image_size": 224,
6 | "layers": 24,
7 | "width": 1024,
8 | "patch_size": 14
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 768,
14 | "heads": 12,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-L-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 14
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 768,
13 | "heads": 12,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-L-16-320.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 320,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 768,
13 | "heads": 12,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-L-16-SigLIP-256.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "init_logit_bias": -10,
4 | "custom_text": true,
5 | "vision_cfg": {
6 | "image_size": 256,
7 | "timm_model_name": "vit_large_patch16_siglip_256",
8 | "timm_model_pretrained": false,
9 | "timm_pool": "map",
10 | "timm_proj": "none"
11 | },
12 | "text_cfg": {
13 | "context_length": 64,
14 | "vocab_size": 32000,
15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
16 | "tokenizer_kwargs": {
17 | "clean": "canonicalize"
18 | },
19 | "width": 1024,
20 | "heads": 16,
21 | "layers": 24,
22 | "no_causal_mask": true,
23 | "proj_bias": true,
24 | "pool_type": "last",
25 | "norm_kwargs":{
26 | "eps": 1e-6
27 | }
28 | }
29 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-L-16-SigLIP-384.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "init_logit_bias": -10,
4 | "custom_text": true,
5 | "vision_cfg": {
6 | "image_size": 384,
7 | "timm_model_name": "vit_large_patch16_siglip_384",
8 | "timm_model_pretrained": false,
9 | "timm_pool": "map",
10 | "timm_proj": "none"
11 | },
12 | "text_cfg": {
13 | "context_length": 64,
14 | "vocab_size": 32000,
15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
16 | "tokenizer_kwargs": {
17 | "clean": "canonicalize"
18 | },
19 | "width": 1024,
20 | "heads": 16,
21 | "layers": 24,
22 | "no_causal_mask": true,
23 | "proj_bias": true,
24 | "pool_type": "last",
25 | "norm_kwargs":{
26 | "eps": 1e-6
27 | }
28 | }
29 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-L-16-avg.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 16,
8 | "pool_type": "avg"
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 768,
14 | "heads": 12,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-L-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 768,
13 | "heads": 12,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-M-16-alt.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 384,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 512,
7 | "patch_size": 16,
8 | "ls_init_value": 1e-4
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 384,
14 | "heads": 6,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-M-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 512,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 512,
13 | "heads": 8,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-M-32-alt.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 384,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 512,
7 | "patch_size": 32
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 384,
13 | "heads": 6,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-M-32.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 512,
7 | "patch_size": 32
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 512,
13 | "heads": 8,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-S-16-alt.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 256,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 384,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 256,
13 | "heads": 4,
14 | "layers": 10
15 | }
16 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-S-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 384,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 384,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 384,
13 | "heads": 6,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-S-32-alt.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 256,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 384,
7 | "patch_size": 32
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 256,
13 | "heads": 4,
14 | "layers": 10
15 | }
16 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-S-32.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 384,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 384,
7 | "patch_size": 32
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 384,
13 | "heads": 6,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-SO400M-14-SigLIP-384.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1152,
3 | "init_logit_bias": -10,
4 | "custom_text": true,
5 | "vision_cfg": {
6 | "image_size": 384,
7 | "timm_model_name": "vit_so400m_patch14_siglip_384",
8 | "timm_model_pretrained": false,
9 | "timm_pool": "map",
10 | "timm_proj": "none"
11 | },
12 | "text_cfg": {
13 | "context_length": 64,
14 | "vocab_size": 32000,
15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
16 | "tokenizer_kwargs": {
17 | "clean": "canonicalize"
18 | },
19 | "width": 1152,
20 | "heads": 16,
21 | "layers": 27,
22 | "mlp_ratio": 3.7362,
23 | "no_causal_mask": true,
24 | "proj_bias": true,
25 | "pool_type": "last",
26 | "norm_kwargs":{
27 | "eps": 1e-6
28 | }
29 | }
30 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-SO400M-14-SigLIP.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1152,
3 | "init_logit_bias": -10,
4 | "custom_text": true,
5 | "vision_cfg": {
6 | "image_size": 224,
7 | "timm_model_name": "vit_so400m_patch14_siglip_224",
8 | "timm_model_pretrained": false,
9 | "timm_pool": "map",
10 | "timm_proj": "none"
11 | },
12 | "text_cfg": {
13 | "context_length": 16,
14 | "vocab_size": 32000,
15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
16 | "tokenizer_kwargs": {
17 | "clean": "canonicalize"
18 | },
19 | "width": 1152,
20 | "heads": 16,
21 | "layers": 27,
22 | "mlp_ratio": 3.7362,
23 | "no_causal_mask": true,
24 | "proj_bias": true,
25 | "pool_type": "last",
26 | "norm_kwargs":{
27 | "eps": 1e-6
28 | }
29 | }
30 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-bigG-14-CLIPA-336.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1280,
3 | "vision_cfg": {
4 | "image_size": 336,
5 | "layers": 48,
6 | "width": 1664,
7 | "head_width": 104,
8 | "mlp_ratio": 4.9231,
9 | "patch_size": 14,
10 | "no_ln_pre": true,
11 | "pool_type": "avg",
12 | "final_ln_after_pool": true
13 | },
14 | "text_cfg": {
15 | "context_length": 32,
16 | "vocab_size": 32000,
17 | "hf_tokenizer_name": "bert-base-uncased",
18 | "tokenizer_kwargs": {
19 | "strip_sep_token": true
20 | },
21 | "width": 1280,
22 | "heads": 20,
23 | "layers": 32,
24 | "pool_type": "last",
25 | "no_causal_mask": true
26 | }
27 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-bigG-14-CLIPA.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1280,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 48,
6 | "width": 1664,
7 | "head_width": 104,
8 | "mlp_ratio": 4.9231,
9 | "patch_size": 14,
10 | "no_ln_pre": true,
11 | "pool_type": "avg",
12 | "final_ln_after_pool": true
13 | },
14 | "text_cfg": {
15 | "context_length": 32,
16 | "vocab_size": 32000,
17 | "hf_tokenizer_name": "bert-base-uncased",
18 | "tokenizer_kwargs": {
19 | "strip_sep_token": true
20 | },
21 | "width": 1280,
22 | "heads": 20,
23 | "layers": 32,
24 | "pool_type": "last",
25 | "no_causal_mask": true
26 | }
27 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-bigG-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1280,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 48,
6 | "width": 1664,
7 | "head_width": 104,
8 | "mlp_ratio": 4.9231,
9 | "patch_size": 14
10 | },
11 | "text_cfg": {
12 | "context_length": 77,
13 | "vocab_size": 49408,
14 | "width": 1280,
15 | "heads": 20,
16 | "layers": 32
17 | }
18 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-e-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1280,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 56,
6 | "width": 1792,
7 | "head_width": 112,
8 | "mlp_ratio": 8.5715,
9 | "patch_size": 14
10 | },
11 | "text_cfg": {
12 | "context_length": 77,
13 | "vocab_size": 49408,
14 | "width": 1280,
15 | "heads": 20,
16 | "layers": 36
17 | }
18 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/ViT-g-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 40,
6 | "width": 1408,
7 | "head_width": 88,
8 | "mlp_ratio": 4.3637,
9 | "patch_size": 14
10 | },
11 | "text_cfg": {
12 | "context_length": 77,
13 | "vocab_size": 49408,
14 | "width": 1024,
15 | "heads": 16,
16 | "layers": 24
17 | }
18 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/coca_ViT-B-32.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 32,
8 | "attentional_pool": true,
9 | "attn_pooler_heads": 8,
10 | "output_tokens": true
11 | },
12 | "text_cfg": {
13 | "context_length": 76,
14 | "vocab_size": 49408,
15 | "width": 512,
16 | "heads": 8,
17 | "layers": 12,
18 | "embed_cls": true,
19 | "output_tokens": true
20 | },
21 | "multimodal_cfg": {
22 | "context_length": 76,
23 | "vocab_size": 49408,
24 | "width": 512,
25 | "heads": 8,
26 | "layers": 12,
27 | "attn_pooler_heads": 8
28 | },
29 | "custom_text": true
30 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/coca_ViT-L-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 14,
8 | "attentional_pool": true,
9 | "attn_pooler_heads": 8,
10 | "output_tokens": true
11 | },
12 | "text_cfg": {
13 | "context_length": 76,
14 | "vocab_size": 49408,
15 | "width": 768,
16 | "heads": 12,
17 | "layers": 12,
18 | "embed_cls": true,
19 | "output_tokens": true
20 | },
21 | "multimodal_cfg": {
22 | "context_length": 76,
23 | "vocab_size": 49408,
24 | "width": 768,
25 | "heads": 12,
26 | "layers": 12,
27 | "attn_pooler_heads": 12
28 | },
29 | "custom_text": true
30 | }
31 |
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/coca_base.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "multimodal_cfg": {
4 | "width": 768,
5 | "context_length": 76,
6 | "vocab_size": 64000,
7 | "mlp_ratio": 4,
8 | "layers": 12,
9 | "dim_head": 64,
10 | "heads": 12,
11 | "n_queries": 256,
12 | "attn_pooler_heads": 8
13 | },
14 | "vision_cfg": {
15 | "image_size": 288,
16 | "layers": 12,
17 | "width": 768,
18 | "patch_size": 18,
19 | "output_tokens": true
20 | },
21 | "text_cfg": {
22 | "context_length": 76,
23 | "vocab_size": 64000,
24 | "layers": 12,
25 | "heads": 12,
26 | "width": 768,
27 | "embed_cls": true,
28 | "output_tokens": true
29 | },
30 | "custom_text": true
31 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/coca_roberta-ViT-B-32.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 32,
8 | "output_tokens": true
9 | },
10 | "text_cfg": {
11 | "hf_model_name": "roberta-base",
12 | "hf_tokenizer_name": "roberta-base",
13 | "hf_proj_type": "linear",
14 | "width": 768,
15 | "output_tokens": true
16 | },
17 | "multimodal_cfg": {
18 | "context_length": 76,
19 | "width": 768,
20 | "heads": 8,
21 | "layers": 12
22 | },
23 | "custom_text": true
24 | }
25 |
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/convnext_base.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "timm_model_name": "convnext_base",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "timm_drop": 0.0,
9 | "timm_drop_path": 0.1,
10 | "image_size": 224
11 | },
12 | "text_cfg": {
13 | "context_length": 77,
14 | "vocab_size": 49408,
15 | "width": 512,
16 | "heads": 8,
17 | "layers": 12
18 | }
19 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/convnext_base_w.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 640,
3 | "vision_cfg": {
4 | "timm_model_name": "convnext_base",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "timm_drop": 0.0,
9 | "timm_drop_path": 0.1,
10 | "image_size": 256
11 | },
12 | "text_cfg": {
13 | "context_length": 77,
14 | "vocab_size": 49408,
15 | "width": 640,
16 | "heads": 10,
17 | "layers": 12
18 | }
19 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/convnext_base_w_320.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 640,
3 | "vision_cfg": {
4 | "timm_model_name": "convnext_base",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "timm_drop": 0.0,
9 | "timm_drop_path": 0.1,
10 | "image_size": 320
11 | },
12 | "text_cfg": {
13 | "context_length": 77,
14 | "vocab_size": 49408,
15 | "width": 640,
16 | "heads": 10,
17 | "layers": 12
18 | }
19 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/convnext_large.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "timm_model_name": "convnext_large",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "timm_drop": 0.0,
9 | "timm_drop_path": 0.1,
10 | "image_size": 224
11 | },
12 | "text_cfg": {
13 | "context_length": 77,
14 | "vocab_size": 49408,
15 | "width": 768,
16 | "heads": 12,
17 | "layers": 12
18 | }
19 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/convnext_large_d.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "timm_model_name": "convnext_large",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "mlp",
8 | "timm_drop": 0.0,
9 | "timm_drop_path": 0.1,
10 | "image_size": 256
11 | },
12 | "text_cfg": {
13 | "context_length": 77,
14 | "vocab_size": 49408,
15 | "width": 768,
16 | "heads": 12,
17 | "layers": 16
18 | }
19 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/convnext_large_d_320.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "timm_model_name": "convnext_large",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "mlp",
8 | "timm_drop": 0.0,
9 | "timm_drop_path": 0.1,
10 | "image_size": 320
11 | },
12 | "text_cfg": {
13 | "context_length": 77,
14 | "vocab_size": 49408,
15 | "width": 768,
16 | "heads": 12,
17 | "layers": 16
18 | }
19 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/convnext_small.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "timm_model_name": "convnext_small",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "timm_drop": 0.0,
9 | "timm_drop_path": 0.1,
10 | "image_size": 224
11 | },
12 | "text_cfg": {
13 | "context_length": 77,
14 | "vocab_size": 49408,
15 | "width": 512,
16 | "heads": 8,
17 | "layers": 12
18 | }
19 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/convnext_tiny.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "timm_model_name": "convnext_tiny",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "timm_drop": 0.0,
9 | "timm_drop_path": 0.1,
10 | "image_size": 224
11 | },
12 | "text_cfg": {
13 | "context_length": 77,
14 | "vocab_size": 49408,
15 | "width": 512,
16 | "heads": 8,
17 | "layers": 12
18 | }
19 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/convnext_xlarge.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "timm_model_name": "convnext_xlarge",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "timm_drop": 0.0,
9 | "timm_drop_path": 0.1,
10 | "image_size": 256
11 | },
12 | "text_cfg": {
13 | "context_length": 77,
14 | "vocab_size": 49408,
15 | "width": 1024,
16 | "heads": 16,
17 | "layers": 20
18 | }
19 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/convnext_xxlarge.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "timm_model_name": "convnext_xxlarge",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "timm_drop": 0.0,
9 | "timm_drop_path": 0.1,
10 | "image_size": 256
11 | },
12 | "text_cfg": {
13 | "context_length": 77,
14 | "vocab_size": 49408,
15 | "width": 1024,
16 | "heads": 16,
17 | "layers": 24
18 | }
19 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/convnext_xxlarge_320.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "timm_model_name": "convnext_xxlarge",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "timm_drop": 0.0,
9 | "timm_drop_path": 0.1,
10 | "image_size": 320
11 | },
12 | "text_cfg": {
13 | "context_length": 77,
14 | "vocab_size": 49408,
15 | "width": 1024,
16 | "heads": 16,
17 | "layers": 24
18 | }
19 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/mt5-base-ViT-B-32.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 32
8 | },
9 | "text_cfg": {
10 | "hf_model_name": "google/mt5-base",
11 | "hf_tokenizer_name": "google/mt5-base",
12 | "hf_pooler_type": "mean_pooler"
13 | }
14 | }
15 |
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/mt5-xl-ViT-H-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 32,
6 | "width": 1280,
7 | "head_width": 80,
8 | "patch_size": 14
9 | },
10 | "text_cfg": {
11 | "hf_model_name": "google/mt5-xl",
12 | "hf_tokenizer_name": "google/mt5-xl",
13 | "hf_pooler_type": "mean_pooler"
14 | }
15 | }
16 |
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/nllb-clip-base-siglip.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "custom_text": true,
4 | "init_logit_bias": -10,
5 | "vision_cfg": {
6 | "image_size": 384,
7 | "timm_model_name": "vit_base_patch16_siglip_384",
8 | "timm_model_pretrained": false,
9 | "timm_pool": "map",
10 | "timm_proj": "none"
11 | },
12 | "text_cfg": {
13 | "hf_model_name": "facebook/nllb-200-distilled-600M",
14 | "hf_tokenizer_name": "facebook/nllb-200-distilled-600M",
15 | "hf_proj_type": "linear",
16 | "hf_pooler_type": "cls_pooler"
17 | }
18 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/nllb-clip-base.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 32
8 | },
9 | "text_cfg": {
10 | "hf_model_name": "facebook/nllb-200-distilled-600M",
11 | "hf_tokenizer_name": "facebook/nllb-200-distilled-600M",
12 | "hf_proj_type": "linear",
13 | "hf_pooler_type": "cls_pooler"
14 | }
15 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/nllb-clip-large-siglip.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1152,
3 | "custom_text": true,
4 | "init_logit_bias": -10,
5 | "vision_cfg": {
6 | "image_size": 384,
7 | "timm_model_name": "vit_so400m_patch14_siglip_384",
8 | "timm_model_pretrained": false,
9 | "timm_pool": "map",
10 | "timm_proj": "none"
11 | },
12 | "text_cfg": {
13 | "hf_model_name": "facebook/nllb-200-distilled-1.3B",
14 | "hf_tokenizer_name": "facebook/nllb-200-distilled-1.3B",
15 | "hf_proj_type": "linear",
16 | "hf_pooler_type": "cls_pooler"
17 | }
18 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/nllb-clip-large.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 32,
6 | "width": 1280,
7 | "head_width": 80,
8 | "patch_size": 14
9 | },
10 | "text_cfg": {
11 | "hf_model_name": "facebook/nllb-200-distilled-1.3B",
12 | "hf_tokenizer_name": "facebook/nllb-200-distilled-1.3B",
13 | "hf_proj_type": "linear",
14 | "hf_pooler_type": "cls_pooler"
15 | }
16 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/roberta-ViT-B-32.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "quick_gelu": true,
4 | "vision_cfg": {
5 | "image_size": 224,
6 | "layers": 12,
7 | "width": 768,
8 | "patch_size": 32
9 | },
10 | "text_cfg": {
11 | "hf_model_name": "roberta-base",
12 | "hf_tokenizer_name": "roberta-base",
13 | "hf_pooler_type": "mean_pooler"
14 | }
15 | }
16 |
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/swin_base_patch4_window7_224.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 640,
3 | "vision_cfg": {
4 | "timm_model_name": "swin_base_patch4_window7_224",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "image_size": 224
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 640,
14 | "heads": 10,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/vit_medium_patch16_gap_256.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "timm_model_name": "vit_medium_patch16_gap_256",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "image_size": 256
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "timm_model_name": "vit_relpos_medium_patch16_cls_224",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "image_size": 224
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 32
8 | },
9 | "text_cfg": {
10 | "hf_model_name": "xlm-roberta-base",
11 | "hf_tokenizer_name": "xlm-roberta-base",
12 | "hf_pooler_type": "mean_pooler"
13 | }
14 | }
15 |
--------------------------------------------------------------------------------
/opencls/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 32,
6 | "width": 1280,
7 | "head_width": 80,
8 | "patch_size": 14
9 | },
10 | "text_cfg": {
11 | "hf_model_name": "xlm-roberta-large",
12 | "hf_tokenizer_name": "xlm-roberta-large",
13 | "hf_pooler_type": "mean_pooler"
14 | }
15 | }
16 |
--------------------------------------------------------------------------------
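The JSON files above are plain model configs consumed by the open_clip factory: the file stem is the model name, "vision_cfg" selects and parameterizes the image tower, and "text_cfg" either describes a CLIP text transformer or points at a Hugging Face text model. A minimal lookup sketch, assuming the opencls/open_clip package is importable as `open_clip` and mirrors open_clip's public factory API (instantiating the HF-text variants additionally downloads the referenced Hugging Face checkpoint):

import open_clip

# Resolve a config by its file stem (no .json extension).
cfg = open_clip.get_model_config('xlm-roberta-large-ViT-H-14')
print(cfg['embed_dim'])                    # 1024
print(cfg['text_cfg']['hf_model_name'])    # xlm-roberta-large

# The same name builds the full two-tower model (and its image transforms) from scratch.
model, _, preprocess_val = open_clip.create_model_and_transforms('xlm-roberta-large-ViT-H-14')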
/opencls/open_clip/modified_resnet.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 | from open_clip.utils import freeze_batch_norm_2d
8 |
9 |
10 | class Bottleneck(nn.Module):
11 | expansion = 4
12 |
13 | def __init__(self, inplanes, planes, stride=1):
14 | super().__init__()
15 |
16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18 | self.bn1 = nn.BatchNorm2d(planes)
19 | self.act1 = nn.ReLU(inplace=True)
20 |
21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22 | self.bn2 = nn.BatchNorm2d(planes)
23 | self.act2 = nn.ReLU(inplace=True)
24 |
25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26 |
27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29 | self.act3 = nn.ReLU(inplace=True)
30 |
31 | self.downsample = None
32 | self.stride = stride
33 |
34 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36 | self.downsample = nn.Sequential(OrderedDict([
37 | ("-1", nn.AvgPool2d(stride)),
38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39 | ("1", nn.BatchNorm2d(planes * self.expansion))
40 | ]))
41 |
42 | def forward(self, x: torch.Tensor):
43 | identity = x
44 |
45 | out = self.act1(self.bn1(self.conv1(x)))
46 | out = self.act2(self.bn2(self.conv2(out)))
47 | out = self.avgpool(out)
48 | out = self.bn3(self.conv3(out))
49 |
50 | if self.downsample is not None:
51 | identity = self.downsample(x)
52 |
53 | out += identity
54 | out = self.act3(out)
55 | return out
56 |
57 |
58 | class AttentionPool2d(nn.Module):
59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60 | super().__init__()
61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
62 | self.k_proj = nn.Linear(embed_dim, embed_dim)
63 | self.q_proj = nn.Linear(embed_dim, embed_dim)
64 | self.v_proj = nn.Linear(embed_dim, embed_dim)
65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66 | self.num_heads = num_heads
67 |
68 | def forward(self, x):
69 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
72 | x, _ = F.multi_head_attention_forward(
73 | query=x, key=x, value=x,
74 | embed_dim_to_check=x.shape[-1],
75 | num_heads=self.num_heads,
76 | q_proj_weight=self.q_proj.weight,
77 | k_proj_weight=self.k_proj.weight,
78 | v_proj_weight=self.v_proj.weight,
79 | in_proj_weight=None,
80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
81 | bias_k=None,
82 | bias_v=None,
83 | add_zero_attn=False,
84 | dropout_p=0.,
85 | out_proj_weight=self.c_proj.weight,
86 | out_proj_bias=self.c_proj.bias,
87 | use_separate_proj_weight=True,
88 | training=self.training,
89 | need_weights=False
90 | )
91 |
92 | return x[0]
93 |
94 |
95 | class ModifiedResNet(nn.Module):
96 | """
97 | A ResNet class that is similar to torchvision's but contains the following changes:
98 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
99 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
100 | - The final pooling layer is a QKV attention instead of an average pool
101 | """
102 |
103 | def __init__(self, layers, output_dim, heads, image_size=224, width=64):
104 | super().__init__()
105 | self.output_dim = output_dim
106 | self.image_size = image_size
107 |
108 | # the 3-layer stem
109 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
110 | self.bn1 = nn.BatchNorm2d(width // 2)
111 | self.act1 = nn.ReLU(inplace=True)
112 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
113 | self.bn2 = nn.BatchNorm2d(width // 2)
114 | self.act2 = nn.ReLU(inplace=True)
115 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
116 | self.bn3 = nn.BatchNorm2d(width)
117 | self.act3 = nn.ReLU(inplace=True)
118 | self.avgpool = nn.AvgPool2d(2)
119 |
120 | # residual layers
121 | self._inplanes = width # this is a *mutable* variable used during construction
122 | self.layer1 = self._make_layer(width, layers[0])
123 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
124 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
125 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
126 |
127 | embed_dim = width * 32 # the ResNet feature dimension
128 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
129 |
130 | self.init_parameters()
131 |
132 | def _make_layer(self, planes, blocks, stride=1):
133 | layers = [Bottleneck(self._inplanes, planes, stride)]
134 |
135 | self._inplanes = planes * Bottleneck.expansion
136 | for _ in range(1, blocks):
137 | layers.append(Bottleneck(self._inplanes, planes))
138 |
139 | return nn.Sequential(*layers)
140 |
141 | def init_parameters(self):
142 | if self.attnpool is not None:
143 | std = self.attnpool.c_proj.in_features ** -0.5
144 | nn.init.normal_(self.attnpool.q_proj.weight, std=std)
145 | nn.init.normal_(self.attnpool.k_proj.weight, std=std)
146 | nn.init.normal_(self.attnpool.v_proj.weight, std=std)
147 | nn.init.normal_(self.attnpool.c_proj.weight, std=std)
148 |
149 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
150 | for name, param in resnet_block.named_parameters():
151 | if name.endswith("bn3.weight"):
152 | nn.init.zeros_(param)
153 |
154 | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
155 | assert unlocked_groups == 0, 'partial locking not currently supported for this model'
156 | for param in self.parameters():
157 | param.requires_grad = False
158 | if freeze_bn_stats:
159 | freeze_batch_norm_2d(self)
160 |
161 | @torch.jit.ignore
162 | def set_grad_checkpointing(self, enable=True):
163 | # FIXME support for non-transformer
164 | pass
165 |
166 | def stem(self, x):
167 | x = self.act1(self.bn1(self.conv1(x)))
168 | x = self.act2(self.bn2(self.conv2(x)))
169 | x = self.act3(self.bn3(self.conv3(x)))
170 | x = self.avgpool(x)
171 | return x
172 |
173 | def forward(self, x):
174 | x = self.stem(x)
175 | x = self.layer1(x)
176 | x = self.layer2(x)
177 | x = self.layer3(x)
178 | x = self.layer4(x)
179 | x = self.attnpool(x)
180 |
181 | return x
182 |
--------------------------------------------------------------------------------
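A minimal shape-check sketch for ModifiedResNet, assuming this module is importable as open_clip.modified_resnet; the layer counts and head count below mirror OpenAI's RN50 (stage widths ending at 64 * 32 = 2048 channels, attention-pooled to a 1024-dim output):

import torch
from open_clip.modified_resnet import ModifiedResNet

model = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, heads=32, image_size=224, width=64)
model.eval()

with torch.no_grad():
    features = model(torch.randn(2, 3, 224, 224))
print(features.shape)  # torch.Size([2, 1024])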
/opencls/open_clip/openai.py:
--------------------------------------------------------------------------------
1 | """ OpenAI pretrained model functions
2 |
3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4 | """
5 |
6 | import os
7 | import warnings
8 | from typing import List, Optional, Union
9 |
10 | import torch
11 |
12 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
13 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
14 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
15 |
16 | __all__ = ["list_openai_models", "load_openai_model"]
17 |
18 |
19 | def list_openai_models() -> List[str]:
20 | """Returns the names of available CLIP models"""
21 | return list_pretrained_models_by_tag('openai')
22 |
23 |
24 | def load_openai_model(
25 | name: str,
26 | precision: Optional[str] = None,
27 | device: Optional[Union[str, torch.device]] = None,
28 | cache_dir: Optional[str] = None,
29 | ):
30 | """Load a CLIP model
31 |
32 | Parameters
33 | ----------
34 | name : str
35 |         A model name listed by `list_openai_models()`, or the path to a model checkpoint containing the state_dict

36 | precision: str
37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
38 | device : Union[str, torch.device]
39 | The device to put the loaded model
40 | cache_dir : Optional[str]
41 | The directory to cache the downloaded model weights
42 |
43 | Returns
44 | -------
45 | model : torch.nn.Module
46 | The CLIP model
47 | preprocess : Callable[[PIL.Image], torch.Tensor]
48 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
49 | """
50 | if device is None:
51 | device = "cuda" if torch.cuda.is_available() else "cpu"
52 | if precision is None:
53 | precision = 'fp32' if device == 'cpu' else 'fp16'
54 |
55 | if get_pretrained_url(name, 'openai'):
56 | model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
57 | elif os.path.isfile(name):
58 | model_path = name
59 | else:
60 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
61 |
62 | try:
63 | # loading JIT archive
64 | model = torch.jit.load(model_path, map_location="cpu").eval()
65 | state_dict = None
66 | except RuntimeError:
67 | # loading saved state dict
68 | state_dict = torch.load(model_path, map_location="cpu")
69 |
70 | # Build a non-jit model from the OpenAI jitted model state dict
71 | cast_dtype = get_cast_dtype(precision)
72 | try:
73 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
74 | except KeyError:
75 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
76 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
77 |
78 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
79 | model = model.to(device)
80 | # FIXME support pure fp16/bf16 precision modes
81 | if precision != 'fp16':
82 | model.float()
83 | if precision == 'bf16':
84 | # for bf16, convert back to low-precision
85 | convert_weights_to_lp(model, dtype=torch.bfloat16)
86 |
87 | # add mean / std attributes for consistency with OpenCLIP models
88 | model.visual.image_mean = OPENAI_DATASET_MEAN
89 | model.visual.image_std = OPENAI_DATASET_STD
90 | return model
91 |
--------------------------------------------------------------------------------
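A minimal usage sketch for load_openai_model, assuming network access to fetch the OpenAI weights on first use (they are cached afterwards):

import torch
from open_clip.openai import list_openai_models, load_openai_model

print(list_openai_models())  # e.g. ['RN50', 'RN101', ..., 'ViT-B-32', 'ViT-B-16', 'ViT-L-14', ...]

# Loads the JIT archive, rebuilds a non-JIT model, and keeps fp32 weights on CPU.
model = load_openai_model('ViT-B-32', precision='fp32', device='cpu')
model.eval()

with torch.no_grad():
    image_features = model.encode_image(torch.randn(1, 3, 224, 224))
print(image_features.shape)  # torch.Size([1, 512])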
/opencls/open_clip/pos_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # Position embedding utils
8 | # --------------------------------------------------------
9 |
10 | import numpy as np
11 |
12 | import torch
13 |
14 | # --------------------------------------------------------
15 | # 2D sine-cosine position embedding
16 | # References:
17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
18 | # MoCo v3: https://github.com/facebookresearch/moco-v3
19 | # --------------------------------------------------------
20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
21 | """
22 | grid_size: int of the grid height and width
23 | return:
24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
25 | """
26 | grid_h = np.arange(grid_size, dtype=np.float32)
27 | grid_w = np.arange(grid_size, dtype=np.float32)
28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
29 | grid = np.stack(grid, axis=0)
30 |
31 | grid = grid.reshape([2, 1, grid_size, grid_size])
32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
33 | if cls_token:
34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
35 | return pos_embed
36 |
37 |
38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
39 | assert embed_dim % 2 == 0
40 |
41 | # use half of dimensions to encode grid_h
42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
44 |
45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
46 | return emb
47 |
48 |
49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
50 | """
51 | embed_dim: output dimension for each position
52 | pos: a list of positions to be encoded: size (M,)
53 | out: (M, D)
54 | """
55 | assert embed_dim % 2 == 0
56 | omega = np.arange(embed_dim // 2, dtype=float)
57 | omega /= embed_dim / 2.
58 | omega = 1. / 10000**omega # (D/2,)
59 |
60 | pos = pos.reshape(-1) # (M,)
61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
62 |
63 | emb_sin = np.sin(out) # (M, D/2)
64 | emb_cos = np.cos(out) # (M, D/2)
65 |
66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
67 | return emb
68 |
69 |
70 | # --------------------------------------------------------
71 | # Interpolate position embeddings for high-resolution
72 | # References:
73 | # DeiT: https://github.com/facebookresearch/deit
74 | # --------------------------------------------------------
75 | def interpolate_pos_embed(model, checkpoint_model):
76 | if 'pos_embed' in checkpoint_model:
77 | pos_embed_checkpoint = checkpoint_model['pos_embed']
78 | embedding_size = pos_embed_checkpoint.shape[-1]
79 | num_patches = model.patch_embed.num_patches
80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
81 | # height (== width) for the checkpoint position embedding
82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
83 | # height (== width) for the new position embedding
84 | new_size = int(num_patches ** 0.5)
85 | # class_token and dist_token are kept unchanged
86 | if orig_size != new_size:
87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
89 | # only the position tokens are interpolated
90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
92 | pos_tokens = torch.nn.functional.interpolate(
93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
96 | checkpoint_model['pos_embed'] = new_pos_embed
97 |
--------------------------------------------------------------------------------
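A minimal sketch of the sine-cosine table these helpers produce, for a hypothetical 224px / patch-16 ViT (14x14 grid) with a class token:

import torch
from open_clip.pos_embed import get_2d_sincos_pos_embed

pos_embed = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
print(pos_embed.shape)  # (197, 768): 1 cls position + 14*14 patch positions

# The table is fixed (not learned) and can be copied into a ViT's position embedding:
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)  # (1, 197, 768)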
/opencls/open_clip/push_to_hf_hub.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | from pathlib import Path
5 | from tempfile import TemporaryDirectory
6 | from typing import Optional, Tuple, Union
7 |
8 | import torch
9 |
10 | try:
11 | from huggingface_hub import (
12 | create_repo,
13 | get_hf_file_metadata,
14 | hf_hub_download,
15 | hf_hub_url,
16 | repo_type_and_id_from_hf_id,
17 | upload_folder,
18 | list_repo_files,
19 | )
20 | from huggingface_hub.utils import EntryNotFoundError
21 | _has_hf_hub = True
22 | except ImportError:
23 | _has_hf_hub = False
24 |
25 | try:
26 | import safetensors.torch
27 | _has_safetensors = True
28 | except ImportError:
29 | _has_safetensors = False
30 |
31 | from .factory import create_model_from_pretrained, get_model_config, get_tokenizer
32 | from .tokenizer import HFTokenizer
33 |
34 | # Default name for a weights file hosted on the Huggingface Hub.
35 | HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl
36 | HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version
37 | HF_CONFIG_NAME = 'open_clip_config.json'
38 |
39 |
40 | def save_config_for_hf(
41 | model,
42 | config_path: str,
43 | model_config: Optional[dict]
44 | ):
45 | preprocess_cfg = {
46 | 'mean': model.visual.image_mean,
47 | 'std': model.visual.image_std,
48 | }
49 | other_pp = getattr(model.visual, 'preprocess_cfg', {})
50 | if 'interpolation' in other_pp:
51 | preprocess_cfg['interpolation'] = other_pp['interpolation']
52 | if 'resize_mode' in other_pp:
53 | preprocess_cfg['resize_mode'] = other_pp['resize_mode']
54 | hf_config = {
55 | 'model_cfg': model_config,
56 | 'preprocess_cfg': preprocess_cfg,
57 | }
58 |
59 | with config_path.open('w') as f:
60 | json.dump(hf_config, f, indent=2)
61 |
62 |
63 | def save_for_hf(
64 | model,
65 | tokenizer: HFTokenizer,
66 | model_config: dict,
67 | save_directory: str,
68 | safe_serialization: Union[bool, str] = 'both',
69 | skip_weights : bool = False,
70 | ):
71 | config_filename = HF_CONFIG_NAME
72 |
73 | save_directory = Path(save_directory)
74 | save_directory.mkdir(exist_ok=True, parents=True)
75 |
76 | if not skip_weights:
77 | tensors = model.state_dict()
78 | if safe_serialization is True or safe_serialization == "both":
79 | assert _has_safetensors, "`pip install safetensors` to use .safetensors"
80 | safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
81 | if safe_serialization is False or safe_serialization == "both":
82 | torch.save(tensors, save_directory / HF_WEIGHTS_NAME)
83 |
84 | tokenizer.save_pretrained(save_directory)
85 |
86 | config_path = save_directory / config_filename
87 | save_config_for_hf(model, config_path, model_config=model_config)
88 |
89 |
90 | def push_to_hf_hub(
91 | model,
92 | tokenizer,
93 | model_config: Optional[dict],
94 | repo_id: str,
95 | commit_message: str = 'Add model',
96 | token: Optional[str] = None,
97 | revision: Optional[str] = None,
98 | private: bool = False,
99 | create_pr: bool = False,
100 | model_card: Optional[dict] = None,
101 | safe_serialization: Union[bool, str] = False,
102 | ):
103 | if not isinstance(tokenizer, HFTokenizer):
104 | # FIXME this makes it awkward to push models with new tokenizers, come up with better soln.
105 | # default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14
106 | tokenizer = HFTokenizer('openai/clip-vit-large-patch14')
107 |
108 | # Create repo if it doesn't exist yet
109 | repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
110 |
111 | # Infer complete repo_id from repo_url
112 | # Can be different from the input `repo_id` if repo_owner was implicit
113 | _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
114 | repo_id = f"{repo_owner}/{repo_name}"
115 |
116 | # Check if repo already exists and determine what needs updating
117 | repo_exists = False
118 | repo_files = {}
119 | try:
120 | repo_files = set(list_repo_files(repo_id))
121 | repo_exists = True
122 | except Exception as e:
123 | print('Repo does not exist', e)
124 |
125 | try:
126 | get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
127 | has_readme = True
128 | except EntryNotFoundError:
129 | has_readme = False
130 |
131 | # Dump model and push to Hub
132 | with TemporaryDirectory() as tmpdir:
133 | # Save model weights and config.
134 | save_for_hf(
135 | model,
136 | tokenizer=tokenizer,
137 | model_config=model_config,
138 | save_directory=tmpdir,
139 | safe_serialization=safe_serialization,
140 | )
141 |
142 | # Add readme if it does not exist
143 | if not has_readme:
144 | model_card = model_card or {}
145 | model_name = repo_id.split('/')[-1]
146 | readme_path = Path(tmpdir) / "README.md"
147 | readme_text = generate_readme(model_card, model_name)
148 | readme_path.write_text(readme_text)
149 |
150 | # Upload model and return
151 | return upload_folder(
152 | repo_id=repo_id,
153 | folder_path=tmpdir,
154 | revision=revision,
155 | create_pr=create_pr,
156 | commit_message=commit_message,
157 | )
158 |
159 |
160 | def push_pretrained_to_hf_hub(
161 | model_name,
162 | pretrained: str,
163 | repo_id: str,
164 | precision: str = 'fp32',
165 | image_mean: Optional[Tuple[float, ...]] = None,
166 | image_std: Optional[Tuple[float, ...]] = None,
167 | image_interpolation: Optional[str] = None,
168 | image_resize_mode: Optional[str] = None, # only effective for inference
169 | commit_message: str = 'Add model',
170 | token: Optional[str] = None,
171 | revision: Optional[str] = None,
172 | private: bool = False,
173 | create_pr: bool = False,
174 | model_card: Optional[dict] = None,
175 | hf_tokenizer_self: bool = False,
176 | ):
177 | model, preprocess_eval = create_model_from_pretrained(
178 | model_name,
179 | pretrained=pretrained,
180 | precision=precision,
181 | image_mean=image_mean,
182 | image_std=image_std,
183 | image_interpolation=image_interpolation,
184 | image_resize_mode=image_resize_mode,
185 | )
186 | model_config = get_model_config(model_name)
187 | assert model_config
188 |
189 | tokenizer = get_tokenizer(model_name)
190 | if hf_tokenizer_self:
191 | # make hf tokenizer config in the uploaded model point to self instead of original location
192 | model_config['text']['hf_tokenizer_name'] = repo_id
193 |
194 | push_to_hf_hub(
195 | model=model,
196 | tokenizer=tokenizer,
197 | model_config=model_config,
198 | repo_id=repo_id,
199 | commit_message=commit_message,
200 | token=token,
201 | revision=revision,
202 | private=private,
203 | create_pr=create_pr,
204 | model_card=model_card,
205 | safe_serialization='both',
206 | )
207 |
208 |
209 | def generate_readme(model_card: dict, model_name: str):
210 | tags = model_card.pop('tags', ('clip',))
211 | pipeline_tag = model_card.pop('pipeline_tag', 'zero-shot-image-classification')
212 | readme_text = "---\n"
213 | if tags:
214 | readme_text += "tags:\n"
215 | for t in tags:
216 | readme_text += f"- {t}\n"
217 | readme_text += "library_name: open_clip\n"
218 | readme_text += f"pipeline_tag: {pipeline_tag}\n"
219 | readme_text += f"license: {model_card.get('license', 'mit')}\n"
220 | if 'details' in model_card and 'Dataset' in model_card['details']:
221 | readme_text += 'datasets:\n'
222 | readme_text += f"- {model_card['details']['Dataset'].lower()}\n"
223 | readme_text += "---\n"
224 | readme_text += f"# Model card for {model_name}\n"
225 | if 'description' in model_card:
226 | readme_text += f"\n{model_card['description']}\n"
227 | if 'details' in model_card:
228 | readme_text += f"\n## Model Details\n"
229 | for k, v in model_card['details'].items():
230 | if isinstance(v, (list, tuple)):
231 | readme_text += f"- **{k}:**\n"
232 | for vi in v:
233 | readme_text += f" - {vi}\n"
234 | elif isinstance(v, dict):
235 | readme_text += f"- **{k}:**\n"
236 | for ki, vi in v.items():
237 | readme_text += f" - {ki}: {vi}\n"
238 | else:
239 | readme_text += f"- **{k}:** {v}\n"
240 | if 'usage' in model_card:
241 | readme_text += f"\n## Model Usage\n"
242 | readme_text += model_card['usage']
243 | readme_text += '\n'
244 |
245 | if 'comparison' in model_card:
246 | readme_text += f"\n## Model Comparison\n"
247 | readme_text += model_card['comparison']
248 | readme_text += '\n'
249 |
250 | if 'citation' in model_card:
251 | readme_text += f"\n## Citation\n"
252 | if not isinstance(model_card['citation'], (list, tuple)):
253 | citations = [model_card['citation']]
254 | else:
255 | citations = model_card['citation']
256 | for c in citations:
257 | readme_text += f"```bibtex\n{c}\n```\n"
258 |
259 | return readme_text
260 |
261 |
262 | if __name__ == "__main__":
263 | parser = argparse.ArgumentParser(description="Push to Hugging Face Hub")
264 | parser.add_argument(
265 | "--model", type=str, help="Name of the model to use.",
266 | )
267 | parser.add_argument(
268 | "--pretrained", type=str,
269 | help="Use a pretrained CLIP model weights with the specified tag or file path.",
270 | )
271 | parser.add_argument(
272 | "--repo-id", type=str,
273 |         help="Destination HF Hub repo-id, i.e. 'organization/model_id'.",
274 | )
275 | parser.add_argument(
276 | "--precision", type=str, default='fp32',
277 | )
278 | parser.add_argument(
279 | '--image-mean', type=float, nargs='+', default=None, metavar='MEAN',
280 | help='Override default image mean value of dataset')
281 | parser.add_argument(
282 | '--image-std', type=float, nargs='+', default=None, metavar='STD',
283 |         help='Override default image std deviation of dataset')
284 | parser.add_argument(
285 | '--image-interpolation',
286 | default=None, type=str, choices=['bicubic', 'bilinear', 'random'],
287 | help="image resize interpolation"
288 | )
289 | parser.add_argument(
290 | '--image-resize-mode',
291 | default=None, type=str, choices=['shortest', 'longest', 'squash'],
292 | help="image resize mode during inference"
293 | )
294 | parser.add_argument(
295 | "--hf-tokenizer-self",
296 | default=False,
297 | action="store_true",
298 |         help="make hf_tokenizer_name in the uploaded config point to the uploaded repo itself"
299 | )
300 | args = parser.parse_args()
301 |
302 | print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}')
303 |
304 | # FIXME add support to pass model_card json / template from file via cmd line
305 |
306 | push_pretrained_to_hf_hub(
307 | args.model,
308 | args.pretrained,
309 | args.repo_id,
310 | precision=args.precision,
311 | image_mean=args.image_mean, # override image mean/std if trained w/ non defaults
312 | image_std=args.image_std,
313 | image_interpolation=args.image_interpolation,
314 | image_resize_mode=args.image_resize_mode,
315 | )
316 |
317 | print(f'{args.model} saved.')
318 |
--------------------------------------------------------------------------------
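A minimal sketch of pushing a pretrained checkpoint programmatically, assuming huggingface_hub is installed, a write token is configured (e.g. via `huggingface-cli login`), and network access is available; the destination repo id below is a placeholder:

from open_clip.push_to_hf_hub import push_pretrained_to_hf_hub

push_pretrained_to_hf_hub(
    model_name='ViT-B-32',
    pretrained='laion2b_s34b_b79k',      # an existing open_clip pretrained tag
    repo_id='my-org/my-clip-vit-b-32',   # hypothetical destination repo
    precision='fp32',
)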
/opencls/open_clip/timm_model.py:
--------------------------------------------------------------------------------
1 | """ timm model adapter
2 |
3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in a CLIP model.
4 | """
5 | import logging
6 | from collections import OrderedDict
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 | try:
12 | import timm
13 | from timm.models.layers import Mlp, to_2tuple
14 | try:
15 | # old timm imports < 0.8.1
16 | from timm.models.layers.attention_pool2d import RotAttentionPool2d
17 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
18 | except ImportError:
19 | # new timm imports >= 0.8.1
20 | from timm.layers import RotAttentionPool2d
21 | from timm.layers import AttentionPool2d as AbsAttentionPool2d
22 | except ImportError:
23 | timm = None
24 |
25 | from .utils import freeze_batch_norm_2d
26 |
27 |
28 | class TimmModel(nn.Module):
29 | """ timm model adapter
30 | """
31 |
32 | def __init__(
33 | self,
34 | model_name,
35 | embed_dim,
36 | image_size=224,
37 | pool='avg',
38 | proj='linear',
39 | proj_bias=False,
40 | drop=0.,
41 | drop_path=None,
42 | patch_drop=None,
43 | pretrained=False,
44 | ):
45 | super().__init__()
46 | if timm is None:
47 | raise RuntimeError("Please `pip install timm` to use timm models.")
48 | self.image_size = to_2tuple(image_size)
49 |
50 | # setup kwargs that may not be common across all models
51 | timm_kwargs = {}
52 | if drop_path is not None:
53 | timm_kwargs['drop_path_rate'] = drop_path
54 | if patch_drop is not None:
55 | timm_kwargs['patch_drop_rate'] = patch_drop
56 |
57 | custom_pool = pool in ('abs_attn', 'rot_attn')
58 | if proj:
59 | assert proj in ("linear", "mlp", "none")
60 | extra_proj = proj in ("linear", "mlp")
61 | if not extra_proj and not custom_pool:
62 | # use network classifier head as projection if no proj specified and no custom pooling used
63 |             # if projection is explicitly set to "none", the trunk output is passed through unchanged
64 | proj_dim = 0 if proj == 'none' else embed_dim
65 | self.trunk = timm.create_model(
66 | model_name,
67 | num_classes=proj_dim,
68 | global_pool=pool,
69 | pretrained=pretrained,
70 | **timm_kwargs,
71 | )
72 | prev_chs = embed_dim
73 | else:
74 | self.trunk = timm.create_model(
75 | model_name,
76 | pretrained=pretrained,
77 | **timm_kwargs,
78 | )
79 | feat_size = self.trunk.default_cfg.get('pool_size', None)
80 | feature_ndim = 1 if not feat_size else 2
81 | if custom_pool:
82 | assert feature_ndim == 2
83 | # if attn pooling used, remove both classifier and default pool
84 | self.trunk.reset_classifier(0, global_pool='')
85 | else:
86 | # reset global pool if pool config set, otherwise leave as network default
87 | reset_kwargs = dict(global_pool=pool) if pool else {}
88 | self.trunk.reset_classifier(0, **reset_kwargs)
89 | prev_chs = self.trunk.num_features
90 |
91 | head_layers = OrderedDict()
92 |
93 | # Add custom pooling to head
94 | if pool == 'abs_attn':
95 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
96 | prev_chs = embed_dim
97 | elif pool == 'rot_attn':
98 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
99 | prev_chs = embed_dim
100 |
101 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
102 | if proj == 'linear':
103 | head_layers['drop'] = nn.Dropout(drop)
104 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
105 | elif proj == 'mlp':
106 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias))
107 |
108 | self.head = nn.Sequential(head_layers)
109 |
110 | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
111 | """ lock modules
112 | Args:
113 | unlocked_groups (int): leave last n layer groups unlocked (default: 0)
114 | """
115 | if not unlocked_groups:
116 | # lock full model
117 | for param in self.trunk.parameters():
118 | param.requires_grad = False
119 | if freeze_bn_stats:
120 | freeze_batch_norm_2d(self.trunk)
121 | else:
122 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change
123 | try:
124 | # FIXME import here until API stable and in an official release
125 | from timm.models.helpers import group_parameters, group_modules
126 | except ImportError:
127 | raise RuntimeError(
128 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
129 | matcher = self.trunk.group_matcher()
130 | gparams = group_parameters(self.trunk, matcher)
131 | max_layer_id = max(gparams.keys())
132 | max_layer_id = max_layer_id - unlocked_groups
133 | for group_idx in range(max_layer_id + 1):
134 | group = gparams[group_idx]
135 | for param in group:
136 | self.trunk.get_parameter(param).requires_grad = False
137 | if freeze_bn_stats:
138 | gmodules = group_modules(self.trunk, matcher, reverse=True)
139 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
140 | freeze_batch_norm_2d(self.trunk, gmodules)
141 |
142 | @torch.jit.ignore
143 | def set_grad_checkpointing(self, enable=True):
144 | try:
145 | self.trunk.set_grad_checkpointing(enable)
146 | except Exception as e:
147 | logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')
148 |
149 | def forward(self, x):
150 | x = self.trunk(x)
151 | x = self.head(x)
152 | return x
153 |
--------------------------------------------------------------------------------
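A minimal sketch of wrapping a timm trunk as a CLIP image tower, assuming timm is installed; the model name below is a standard timm ViT and is used without pretrained weights:

import torch
from open_clip.timm_model import TimmModel

tower = TimmModel(
    'vit_base_patch16_224',
    embed_dim=512,
    image_size=224,
    pool='avg',        # reset trunk pooling to global average
    proj='linear',     # project the 768-dim trunk features to the CLIP embed dim
    pretrained=False,
)
tower.eval()

with torch.no_grad():
    out = tower(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 512])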
/opencls/open_clip/transform.py:
--------------------------------------------------------------------------------
1 | import numbers
2 | import random
3 | import warnings
4 | from dataclasses import dataclass, asdict
5 | from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
6 |
7 | import torch
8 | import torchvision.transforms.functional as F
9 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
10 | CenterCrop, ColorJitter, Grayscale
11 |
12 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
13 | from .utils import to_2tuple
14 |
15 |
16 | @dataclass
17 | class PreprocessCfg:
18 | size: Union[int, Tuple[int, int]] = 224
19 | mode: str = 'RGB'
20 | mean: Tuple[float, ...] = OPENAI_DATASET_MEAN
21 | std: Tuple[float, ...] = OPENAI_DATASET_STD
22 | interpolation: str = 'bicubic'
23 | resize_mode: str = 'shortest'
24 | fill_color: int = 0
25 |
26 | def __post_init__(self):
27 | assert self.mode in ('RGB',)
28 |
29 | @property
30 | def num_channels(self):
31 | return 3
32 |
33 | @property
34 | def input_size(self):
35 | return (self.num_channels,) + to_2tuple(self.size)
36 |
37 | _PREPROCESS_KEYS = set(asdict(PreprocessCfg()).keys())
38 |
39 |
40 | def merge_preprocess_dict(
41 | base: Union[PreprocessCfg, Dict],
42 | overlay: Dict,
43 | ):
44 | """ Merge overlay key-value pairs on top of base preprocess cfg or dict.
45 | Input dicts are filtered based on PreprocessCfg fields.
46 | """
47 | if isinstance(base, PreprocessCfg):
48 | base_clean = asdict(base)
49 | else:
50 | base_clean = {k: v for k, v in base.items() if k in _PREPROCESS_KEYS}
51 | if overlay:
52 | overlay_clean = {k: v for k, v in overlay.items() if k in _PREPROCESS_KEYS and v is not None}
53 | base_clean.update(overlay_clean)
54 | return base_clean
55 |
56 |
57 | def merge_preprocess_kwargs(base: PreprocessCfg, **kwargs):
58 | return merge_preprocess_dict(base, kwargs)
59 |
60 |
61 | @dataclass
62 | class AugmentationCfg:
63 | scale: Tuple[float, float] = (0.9, 1.0)
64 | ratio: Optional[Tuple[float, float]] = None
65 | color_jitter: Optional[Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]] = None
66 | re_prob: Optional[float] = None
67 | re_count: Optional[int] = None
68 | use_timm: bool = False
69 |
70 | # params for simclr_jitter_gray
71 |     color_jitter_prob: Optional[float] = None
72 |     gray_scale_prob: Optional[float] = None
73 |
74 |
75 | def _setup_size(size, error_msg):
76 | if isinstance(size, numbers.Number):
77 | return int(size), int(size)
78 |
79 | if isinstance(size, Sequence) and len(size) == 1:
80 | return size[0], size[0]
81 |
82 | if len(size) != 2:
83 | raise ValueError(error_msg)
84 |
85 | return size
86 |
87 |
88 | class ResizeKeepRatio:
89 | """ Resize and Keep Ratio
90 |
91 | Copy & paste from `timm`
92 | """
93 |
94 | def __init__(
95 | self,
96 | size,
97 | longest=0.,
98 | interpolation=InterpolationMode.BICUBIC,
99 | random_scale_prob=0.,
100 | random_scale_range=(0.85, 1.05),
101 | random_aspect_prob=0.,
102 | random_aspect_range=(0.9, 1.11)
103 | ):
104 | if isinstance(size, (list, tuple)):
105 | self.size = tuple(size)
106 | else:
107 | self.size = (size, size)
108 | self.interpolation = interpolation
109 | self.longest = float(longest) # [0, 1] where 0 == shortest edge, 1 == longest
110 | self.random_scale_prob = random_scale_prob
111 | self.random_scale_range = random_scale_range
112 | self.random_aspect_prob = random_aspect_prob
113 | self.random_aspect_range = random_aspect_range
114 |
115 | @staticmethod
116 | def get_params(
117 | img,
118 | target_size,
119 | longest,
120 | random_scale_prob=0.,
121 | random_scale_range=(0.85, 1.05),
122 | random_aspect_prob=0.,
123 | random_aspect_range=(0.9, 1.11)
124 | ):
125 | """Get parameters
126 | """
127 | source_size = img.size[::-1] # h, w
128 | h, w = source_size
129 | target_h, target_w = target_size
130 | ratio_h = h / target_h
131 | ratio_w = w / target_w
132 | ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)
133 | if random_scale_prob > 0 and random.random() < random_scale_prob:
134 | ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
135 | ratio_factor = (ratio_factor, ratio_factor)
136 | else:
137 | ratio_factor = (1., 1.)
138 | if random_aspect_prob > 0 and random.random() < random_aspect_prob:
139 | aspect_factor = random.uniform(random_aspect_range[0], random_aspect_range[1])
140 | ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor)
141 | size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)]
142 | return size
143 |
144 | def __call__(self, img):
145 | """
146 | Args:
147 |             img (PIL Image): Image to be resized.
148 |
149 | Returns:
150 |             PIL Image: Resized image; exact-size cropping/padding is handled by a subsequent transform (e.g. CenterCropOrPad)
151 | """
152 | size = self.get_params(
153 | img, self.size, self.longest,
154 | self.random_scale_prob, self.random_scale_range,
155 | self.random_aspect_prob, self.random_aspect_range
156 | )
157 | img = F.resize(img, size, self.interpolation)
158 | return img
159 |
160 | def __repr__(self):
161 | format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
162 |         format_string += f', interpolation={self.interpolation}'
163 | format_string += f', longest={self.longest:.3f})'
164 | return format_string
165 |
166 |
167 | def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor:
168 | """Center crops and/or pads the given image.
169 | If the image is torch Tensor, it is expected
170 | to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
171 | If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
172 |
173 | Args:
174 | img (PIL Image or Tensor): Image to be cropped.
175 | output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int,
176 | it is used for both directions.
177 | fill (int, Tuple[int]): Padding color
178 |
179 | Returns:
180 | PIL Image or Tensor: Cropped image.
181 | """
182 | if isinstance(output_size, numbers.Number):
183 | output_size = (int(output_size), int(output_size))
184 | elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
185 | output_size = (output_size[0], output_size[0])
186 |
187 | _, image_height, image_width = F.get_dimensions(img)
188 | crop_height, crop_width = output_size
189 |
190 | if crop_width > image_width or crop_height > image_height:
191 | padding_ltrb = [
192 | (crop_width - image_width) // 2 if crop_width > image_width else 0,
193 | (crop_height - image_height) // 2 if crop_height > image_height else 0,
194 | (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
195 | (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
196 | ]
197 | img = F.pad(img, padding_ltrb, fill=fill)
198 | _, image_height, image_width = F.get_dimensions(img)
199 | if crop_width == image_width and crop_height == image_height:
200 | return img
201 |
202 | crop_top = int(round((image_height - crop_height) / 2.0))
203 | crop_left = int(round((image_width - crop_width) / 2.0))
204 | return F.crop(img, crop_top, crop_left, crop_height, crop_width)
205 |
206 |
207 | class CenterCropOrPad(torch.nn.Module):
208 | """Crops the given image at the center.
209 | If the image is torch Tensor, it is expected
210 | to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
211 | If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
212 |
213 | Args:
214 | size (sequence or int): Desired output size of the crop. If size is an
215 | int instead of sequence like (h, w), a square crop (size, size) is
216 | made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
217 | """
218 |
219 | def __init__(self, size, fill=0):
220 | super().__init__()
221 | self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
222 | self.fill = fill
223 |
224 | def forward(self, img):
225 | """
226 | Args:
227 | img (PIL Image or Tensor): Image to be cropped.
228 |
229 | Returns:
230 | PIL Image or Tensor: Cropped image.
231 | """
232 | return center_crop_or_pad(img, self.size, fill=self.fill)
233 |
234 | def __repr__(self) -> str:
235 | return f"{self.__class__.__name__}(size={self.size})"
236 |
237 |
238 | def _convert_to_rgb(image):
239 | return image.convert('RGB')
240 |
241 |
242 | class color_jitter(object):
243 | """
244 | Apply Color Jitter to the PIL image with a specified probability.
245 | """
246 | def __init__(self, brightness=0., contrast=0., saturation=0., hue=0., p=0.8):
247 | assert 0. <= p <= 1.
248 | self.p = p
249 | self.transf = ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
250 |
251 | def __call__(self, img):
252 | if random.random() < self.p:
253 | return self.transf(img)
254 | else:
255 | return img
256 |
257 |
258 | class gray_scale(object):
259 | """
260 | Apply Gray Scale to the PIL image with a specified probability.
261 | """
262 | def __init__(self, p=0.2):
263 | assert 0. <= p <= 1.
264 | self.p = p
265 | self.transf = Grayscale(num_output_channels=3)
266 |
267 | def __call__(self, img):
268 | if random.random() < self.p:
269 | return self.transf(img)
270 | else:
271 | return img
272 |
273 |
274 | def image_transform(
275 | image_size: Union[int, Tuple[int, int]],
276 | is_train: bool,
277 | mean: Optional[Tuple[float, ...]] = None,
278 | std: Optional[Tuple[float, ...]] = None,
279 | resize_mode: Optional[str] = None,
280 | interpolation: Optional[str] = None,
281 | fill_color: int = 0,
282 | aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
283 | ):
284 | mean = mean or OPENAI_DATASET_MEAN
285 | if not isinstance(mean, (list, tuple)):
286 | mean = (mean,) * 3
287 |
288 | std = std or OPENAI_DATASET_STD
289 | if not isinstance(std, (list, tuple)):
290 | std = (std,) * 3
291 |
292 | interpolation = interpolation or 'bicubic'
293 | assert interpolation in ['bicubic', 'bilinear', 'random']
294 |     # NOTE 'random' interpolation only applies to training-time augmentation; interpolation_mode falls back to BICUBIC for inference
295 | interpolation_mode = InterpolationMode.BILINEAR if interpolation == 'bilinear' else InterpolationMode.BICUBIC
296 |
297 | resize_mode = resize_mode or 'shortest'
298 | assert resize_mode in ('shortest', 'longest', 'squash')
299 |
300 | if isinstance(aug_cfg, dict):
301 | aug_cfg = AugmentationCfg(**aug_cfg)
302 | else:
303 | aug_cfg = aug_cfg or AugmentationCfg()
304 |
305 | normalize = Normalize(mean=mean, std=std)
306 |
307 | if is_train:
308 | aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
309 | use_timm = aug_cfg_dict.pop('use_timm', False)
310 | if use_timm:
311 | from timm.data import create_transform # timm can still be optional
312 | if isinstance(image_size, (tuple, list)):
313 | assert len(image_size) >= 2
314 | input_size = (3,) + image_size[-2:]
315 | else:
316 | input_size = (3, image_size, image_size)
317 |
318 | aug_cfg_dict.setdefault('color_jitter', None) # disable by default
319 | # drop extra non-timm items
320 | aug_cfg_dict.pop('color_jitter_prob', None)
321 | aug_cfg_dict.pop('gray_scale_prob', None)
322 |
323 | train_transform = create_transform(
324 | input_size=input_size,
325 | is_training=True,
326 | hflip=0.,
327 | mean=mean,
328 | std=std,
329 | re_mode='pixel',
330 | interpolation=interpolation,
331 | **aug_cfg_dict,
332 | )
333 | else:
334 | train_transform = [
335 | RandomResizedCrop(
336 | image_size,
337 | scale=aug_cfg_dict.pop('scale'),
338 | interpolation=InterpolationMode.BICUBIC,
339 | ),
340 | _convert_to_rgb,
341 | ]
342 | if aug_cfg.color_jitter_prob:
343 | assert aug_cfg.color_jitter is not None and len(aug_cfg.color_jitter) == 4
344 | train_transform.extend([
345 | color_jitter(*aug_cfg.color_jitter, p=aug_cfg.color_jitter_prob)
346 | ])
347 | if aug_cfg.gray_scale_prob:
348 | train_transform.extend([
349 | gray_scale(aug_cfg.gray_scale_prob)
350 | ])
351 | train_transform.extend([
352 | ToTensor(),
353 | normalize,
354 | ])
355 | train_transform = Compose(train_transform)
356 | if aug_cfg_dict:
357 | warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')
358 | return train_transform
359 | else:
360 | if resize_mode == 'longest':
361 | transforms = [
362 | ResizeKeepRatio(image_size, interpolation=interpolation_mode, longest=1),
363 | CenterCropOrPad(image_size, fill=fill_color)
364 | ]
365 | elif resize_mode == 'squash':
366 | if isinstance(image_size, int):
367 | image_size = (image_size, image_size)
368 | transforms = [
369 | Resize(image_size, interpolation=interpolation_mode),
370 | ]
371 | else:
372 | assert resize_mode == 'shortest'
373 | if not isinstance(image_size, (tuple, list)):
374 | image_size = (image_size, image_size)
375 | if image_size[0] == image_size[1]:
376 | # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
377 | transforms = [
378 | Resize(image_size[0], interpolation=interpolation_mode)
379 | ]
380 | else:
381 | # resize shortest edge to matching target dim for non-square target
382 | transforms = [ResizeKeepRatio(image_size)]
383 | transforms += [CenterCrop(image_size)]
384 |
385 | transforms.extend([
386 | _convert_to_rgb,
387 | ToTensor(),
388 | normalize,
389 | ])
390 | return Compose(transforms)
391 |
392 |
393 | def image_transform_v2(
394 | cfg: PreprocessCfg,
395 | is_train: bool,
396 | aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
397 | ):
398 | return image_transform(
399 | image_size=cfg.size,
400 | is_train=is_train,
401 | mean=cfg.mean,
402 | std=cfg.std,
403 | interpolation=cfg.interpolation,
404 | resize_mode=cfg.resize_mode,
405 | fill_color=cfg.fill_color,
406 | aug_cfg=aug_cfg,
407 | )
408 |
--------------------------------------------------------------------------------
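A minimal inference-preprocessing sketch using PreprocessCfg and image_transform_v2; the image here is a synthetic stand-in:

from PIL import Image
from open_clip.transform import PreprocessCfg, image_transform_v2

cfg = PreprocessCfg(size=224, interpolation='bicubic', resize_mode='shortest')
preprocess_val = image_transform_v2(cfg, is_train=False)  # resize shortest edge, center crop, to-tensor, normalize

img = Image.new('RGB', (640, 480))
tensor = preprocess_val(img)
print(tensor.shape)  # torch.Size([3, 224, 224])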
/opencls/open_clip/utils.py:
--------------------------------------------------------------------------------
1 | from itertools import repeat
2 | import collections.abc
3 |
4 | import torch
5 | from torch import nn as nn
6 | from torchvision.ops.misc import FrozenBatchNorm2d
7 |
8 |
9 | def freeze_batch_norm_2d(module, module_match={}, name=''):
10 | """
11 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
12 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
13 | returned. Otherwise, the module is walked recursively and submodules are converted in place.
14 |
15 | Args:
16 | module (torch.nn.Module): Any PyTorch module.
17 | module_match (dict): Dictionary of full module names to freeze (all if empty)
18 | name (str): Full module name (prefix)
19 |
20 | Returns:
21 | torch.nn.Module: Resulting module
22 |
23 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
24 | """
25 | res = module
26 | is_match = True
27 | if module_match:
28 | is_match = name in module_match
29 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
30 | res = FrozenBatchNorm2d(module.num_features)
31 | res.num_features = module.num_features
32 | res.affine = module.affine
33 | if module.affine:
34 | res.weight.data = module.weight.data.clone().detach()
35 | res.bias.data = module.bias.data.clone().detach()
36 | res.running_mean.data = module.running_mean.data
37 | res.running_var.data = module.running_var.data
38 | res.eps = module.eps
39 | else:
40 | for child_name, child in module.named_children():
41 | full_child_name = '.'.join([name, child_name]) if name else child_name
42 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
43 | if new_child is not child:
44 | res.add_module(child_name, new_child)
45 | return res
46 |
47 |
48 | # From PyTorch internals
49 | def _ntuple(n):
50 | def parse(x):
51 | if isinstance(x, collections.abc.Iterable):
52 | return x
53 | return tuple(repeat(x, n))
54 | return parse
55 |
56 |
57 | to_1tuple = _ntuple(1)
58 | to_2tuple = _ntuple(2)
59 | to_3tuple = _ntuple(3)
60 | to_4tuple = _ntuple(4)
61 | to_ntuple = lambda n, x: _ntuple(n)(x)
62 |
63 | # Replaces all linear layers with linear_replacement
64 | # TODO: add int8 support for other linear layers including attn and convnets
65 | def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True):
66 | for name, module in model.named_children():
67 | if len(list(module.children())) > 0:
68 | replace_linear(module, linear_replacement, include_modules, copy_weights)
69 |
70 | if isinstance(module, torch.nn.Linear) and name in include_modules:
71 | old_module = model._modules[name]
72 | model._modules[name] = linear_replacement(
73 | module.in_features,
74 | module.out_features,
75 | module.bias is not None,
76 | )
77 | if copy_weights:
78 | model._modules[name].weight.data.copy_(old_module.weight.data)
79 | if model._modules[name].bias is not None:
80 | model._modules[name].bias.data.copy_(old_module.bias)
81 |
82 | return model
83 |
84 | def convert_int8_model_to_inference_mode(model):
85 | for m in model.modules():
86 | if hasattr(m, 'prepare_for_eval'):
87 | int8_original_dtype = m.weight.dtype
88 | m.prepare_for_eval()
89 | m.int8_original_dtype = int8_original_dtype
--------------------------------------------------------------------------------
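A minimal sketch of to_2tuple and freeze_batch_norm_2d from this module:

import torch
from torch import nn
from open_clip.utils import freeze_batch_norm_2d, to_2tuple

print(to_2tuple(224))  # (224, 224)

# Swap BatchNorm2d for FrozenBatchNorm2d in place so running stats and affine params stop updating.
net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
net = freeze_batch_norm_2d(net)
print(type(net[1]).__name__)  # FrozenBatchNorm2d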
/opencls/open_clip/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '2.24.0'
2 |
--------------------------------------------------------------------------------
/opencls/open_clip/zero_shot_classifier.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from itertools import islice
3 | from typing import Callable, List, Optional, Sequence, Union
4 |
5 | import torch
6 | import torch.nn.functional as F
7 |
8 |
9 | def batched(iterable, n):
10 | """Batch data into lists of length *n*. The last batch may be shorter.
11 | NOTE based on more-itertools impl, to be replaced by python 3.12 itertools.batched impl
12 | """
13 | it = iter(iterable)
14 | while True:
15 | batch = list(islice(it, n))
16 | if not batch:
17 | break
18 | yield batch
19 |
20 |
21 | def build_zero_shot_classifier(
22 | model,
23 | tokenizer,
24 | classnames: Sequence[str],
25 | templates: Sequence[Union[Callable, str]],
26 | num_classes_per_batch: Optional[int] = 10,
27 | device: Union[str, torch.device] = 'cpu',
28 | use_tqdm: bool = False,
29 | ):
30 | """ Build zero-shot classifier weights by iterating over class names in batches
31 | Args:
32 | model: CLIP model instance
33 | tokenizer: CLIP tokenizer instance
34 | classnames: A sequence of class (label) names
35 | templates: A sequence of callables or format() friendly strings to produce templates per class name
36 | num_classes_per_batch: The number of classes to batch together in each forward, all if None
37 | device: Device to use.
38 | use_tqdm: Enable TQDM progress bar.
39 | """
40 | assert isinstance(templates, Sequence) and len(templates) > 0
41 | assert isinstance(classnames, Sequence) and len(classnames) > 0
42 | use_format = isinstance(templates[0], str)
43 | num_templates = len(templates)
44 | num_classes = len(classnames)
45 | if use_tqdm:
46 | import tqdm
47 | num_iter = 1 if num_classes_per_batch is None else ((num_classes - 1) // num_classes_per_batch + 1)
48 | iter_wrap = partial(tqdm.tqdm, total=num_iter, unit_scale=num_classes_per_batch)
49 | else:
50 | iter_wrap = iter
51 |
52 | def _process_batch(batch_classnames):
53 | num_batch_classes = len(batch_classnames)
54 | texts = [template.format(c) if use_format else template(c) for c in batch_classnames for template in templates]
55 | texts = tokenizer(texts).to(device)
56 | class_embeddings = model.encode_text(texts, normalize=True)
57 | class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1)
58 | class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True)
59 | class_embeddings = class_embeddings.T
60 | return class_embeddings
61 |
62 | with torch.no_grad():
63 | if num_classes_per_batch:
64 | batched_embeds = [_process_batch(batch) for batch in iter_wrap(batched(classnames, num_classes_per_batch))]
65 | zeroshot_weights = torch.cat(batched_embeds, dim=1)
66 | else:
67 | zeroshot_weights = _process_batch(classnames)
68 | return zeroshot_weights
69 |
70 |
71 | def build_zero_shot_classifier_legacy(
72 | model,
73 | tokenizer,
74 | classnames: Sequence[str],
75 | templates: Sequence[Union[Callable, str]],
76 | device: Union[str, torch.device] = 'cpu',
77 | use_tqdm: bool = False,
78 | ):
79 | """ Build zero-shot classifier weights by iterating over class names 1 by 1
80 | Args:
81 | model: CLIP model instance
82 | tokenizer: CLIP tokenizer instance
83 | classnames: A sequence of class (label) names
84 | templates: A sequence of callables or format() friendly strings to produce templates per class name
85 | device: Device to use.
86 | use_tqdm: Enable TQDM progress bar.
87 | """
88 | assert isinstance(templates, Sequence) and len(templates) > 0
89 | assert isinstance(classnames, Sequence) and len(classnames) > 0
90 | if use_tqdm:
91 | import tqdm
92 | iter_wrap = tqdm.tqdm
93 | else:
94 | iter_wrap = iter
95 |
96 | use_format = isinstance(templates[0], str)
97 |
98 | with torch.no_grad():
99 | zeroshot_weights = []
100 | for classname in iter_wrap(classnames):
101 | texts = [template.format(classname) if use_format else template(classname) for template in templates]
102 | texts = tokenizer(texts).to(device) # tokenize
103 | class_embeddings = model.encode_text(texts)
104 | class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
105 | class_embedding /= class_embedding.norm()
106 | zeroshot_weights.append(class_embedding)
107 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
108 |
109 | return zeroshot_weights
110 |
111 |
--------------------------------------------------------------------------------
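A minimal zero-shot classification sketch, assuming an open_clip model and tokenizer are available (the pretrained tag below triggers a weight download); the class names and templates are illustrative:

import torch
import open_clip
from open_clip.zero_shot_classifier import build_zero_shot_classifier

model, _, preprocess_val = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
tokenizer = open_clip.get_tokenizer('ViT-B-32')
model.eval()

classifier = build_zero_shot_classifier(
    model,
    tokenizer,
    classnames=['cat', 'dog', 'bird'],
    templates=['a photo of a {}.', 'a blurry photo of a {}.'],
    num_classes_per_batch=2,
    device='cpu',
)
print(classifier.shape)  # torch.Size([512, 3]): embed_dim x num_classes

with torch.no_grad():
    image_features = model.encode_image(torch.randn(1, 3, 224, 224))
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    logits = 100.0 * image_features @ classifier  # per-class similarity scores, shape (1, 3)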
/opencls/training/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/x-cls/superclass/90b761f41177400caa178d3845b39eba771e5bf1/opencls/training/.DS_Store
--------------------------------------------------------------------------------
/opencls/training/.gitignore:
--------------------------------------------------------------------------------
1 | logs/
2 |
--------------------------------------------------------------------------------
/opencls/training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/x-cls/superclass/90b761f41177400caa178d3845b39eba771e5bf1/opencls/training/__init__.py
--------------------------------------------------------------------------------
/opencls/training/distributed.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torch.distributed as dist
5 |
6 | try:
7 | import horovod.torch as hvd
8 | except ImportError:
9 | hvd = None
10 |
11 |
12 | def is_global_master(args):
13 | return args.rank == 0
14 |
15 |
16 | def is_local_master(args):
17 | return args.local_rank == 0
18 |
19 |
20 | def is_master(args, local=False):
21 | return is_local_master(args) if local else is_global_master(args)
22 |
23 |
24 | def is_using_horovod():
25 | # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set
26 | # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required...
27 | ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
28 | pmi_vars = ["PMI_RANK", "PMI_SIZE"]
29 | if all([var in os.environ for var in ompi_vars]) or all([var in os.environ for var in pmi_vars]):
30 | return True
31 | else:
32 | return False
33 |
34 |
35 | def is_using_distributed():
36 | if 'WORLD_SIZE' in os.environ:
37 | return int(os.environ['WORLD_SIZE']) > 1
38 | if 'SLURM_NTASKS' in os.environ:
39 | return int(os.environ['SLURM_NTASKS']) > 1
40 | return False
41 |
42 |
43 | def world_info_from_env():
44 | local_rank = 0
45 | for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'):
46 | if v in os.environ:
47 | local_rank = int(os.environ[v])
48 | break
49 | global_rank = 0
50 | for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'):
51 | if v in os.environ:
52 | global_rank = int(os.environ[v])
53 | break
54 | world_size = 1
55 | for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'):
56 | if v in os.environ:
57 | world_size = int(os.environ[v])
58 | break
59 |
60 | return local_rank, global_rank, world_size
61 |
62 |
63 | def init_distributed_device(args):
64 | # Distributed training = training on more than one GPU.
65 | # Works in both single and multi-node scenarios.
66 | args.distributed = False
67 | args.world_size = 1
68 | args.rank = 0 # global rank
69 | args.local_rank = 0
70 | if args.horovod:
71 | assert hvd is not None, "Horovod is not installed"
72 | hvd.init()
73 | args.local_rank = int(hvd.local_rank())
74 | args.rank = hvd.rank()
75 | args.world_size = hvd.size()
76 | args.distributed = True
77 | os.environ['LOCAL_RANK'] = str(args.local_rank)
78 | os.environ['RANK'] = str(args.rank)
79 | os.environ['WORLD_SIZE'] = str(args.world_size)
80 | elif is_using_distributed():
81 | if 'SLURM_PROCID' in os.environ:
82 | # DDP via SLURM
83 | args.local_rank, args.rank, args.world_size = world_info_from_env()
84 | # SLURM var -> torch.distributed vars in case needed
85 | os.environ['LOCAL_RANK'] = str(args.local_rank)
86 | os.environ['RANK'] = str(args.rank)
87 | os.environ['WORLD_SIZE'] = str(args.world_size)
88 | torch.distributed.init_process_group(
89 | backend=args.dist_backend,
90 | init_method=args.dist_url,
91 | world_size=args.world_size,
92 | rank=args.rank,
93 | )
94 | else:
95 | # DDP via torchrun, torch.distributed.launch
96 | args.local_rank, _, _ = world_info_from_env()
97 | torch.distributed.init_process_group(
98 | backend=args.dist_backend,
99 | init_method=args.dist_url)
100 | args.world_size = torch.distributed.get_world_size()
101 | args.rank = torch.distributed.get_rank()
102 | args.distributed = True
103 |
104 | if torch.cuda.is_available():
105 | if args.distributed and not args.no_set_device_rank:
106 | device = 'cuda:%d' % args.local_rank
107 | else:
108 | device = 'cuda:0'
109 | torch.cuda.set_device(device)
110 | else:
111 | device = 'cpu'
112 | args.device = device
113 | device = torch.device(device)
114 | return device
115 |
116 |
117 | def broadcast_object(args, obj, src=0):
118 | # broadcast a pickle-able python object from rank-0 to all ranks
119 | if args.horovod:
120 | return hvd.broadcast_object(obj, root_rank=src)
121 | else:
122 | if args.rank == src:
123 | objects = [obj]
124 | else:
125 | objects = [None]
126 | dist.broadcast_object_list(objects, src=src)
127 | return objects[0]
128 |
129 |
130 | def all_gather_object(args, obj, dst=0):
131 | # gather a pickle-able python object across all ranks
132 | if args.horovod:
133 | return hvd.allgather_object(obj)
134 | else:
135 | objects = [None for _ in range(args.world_size)]
136 | dist.all_gather_object(objects, obj)
137 | return objects
138 |
--------------------------------------------------------------------------------
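A minimal single-process sketch of how `init_distributed_device` is driven. The `Namespace` fields are an assumption about the attributes `training.main` populates; only the ones this function actually reads are shown, and `training.distributed` is importable when running from inside `opencls/` as `train.sh` does.

```python
from argparse import Namespace

from training.distributed import init_distributed_device, is_master

args = Namespace(
    horovod=False,
    dist_backend='nccl',
    dist_url='env://',
    no_set_device_rank=False,
)
# Fills in args.distributed / args.rank / args.local_rank / args.world_size;
# with no WORLD_SIZE or SLURM env vars set this stays a single-process run.
device = init_distributed_device(args)
if is_master(args):
    print(f'rank {args.rank}/{args.world_size} using {device}')
```
--------------------------------------------------------------------------------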
/opencls/training/file_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import multiprocessing
4 | import subprocess
5 | import time
6 | import fsspec
7 | import torch
8 | from tqdm import tqdm
9 |
10 | def remote_sync_s3(local_dir, remote_dir):
11 | # skip epoch_latest which can change during sync.
12 | result = subprocess.run(["aws", "s3", "sync", local_dir, remote_dir, '--exclude', '*epoch_latest.pt'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
13 | if result.returncode != 0:
14 | logging.error(f"Error: Failed to sync with S3 bucket {result.stderr.decode('utf-8')}")
15 | return False
16 |
17 | logging.info(f"Successfully synced with S3 bucket")
18 | return True
19 |
20 | def remote_sync_fsspec(local_dir, remote_dir):
21 | # FIXME currently this is slow and not recommended. Look into speeding up.
22 | a = fsspec.get_mapper(local_dir)
23 | b = fsspec.get_mapper(remote_dir)
24 |
25 | for k in a:
26 | # skip epoch_latest which can change during sync.
27 | if 'epoch_latest.pt' in k:
28 | continue
29 |
30 | logging.info(f'Attempting to sync {k}')
31 | if k in b and len(a[k]) == len(b[k]):
32 | logging.debug(f'Skipping remote sync for {k}.')
33 | continue
34 |
35 | try:
36 |             b[k] = a[k]
37 |             logging.info(f'Successful sync for {k}.')
38 | except Exception as e:
39 | logging.info(f'Error during remote sync for {k}: {e}')
40 | return False
41 |
42 | return True
43 |
44 | def remote_sync(local_dir, remote_dir, protocol):
45 | logging.info('Starting remote sync.')
46 | if protocol == 's3':
47 | return remote_sync_s3(local_dir, remote_dir)
48 | elif protocol == 'fsspec':
49 | return remote_sync_fsspec(local_dir, remote_dir)
50 | else:
51 | logging.error('Remote protocol not known')
52 | return False
53 |
54 | def keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol):
55 | while True:
56 | time.sleep(sync_every)
57 | remote_sync(local_dir, remote_dir, protocol)
58 |
59 | def start_sync_process(sync_every, local_dir, remote_dir, protocol):
60 | p = multiprocessing.Process(target=keep_running_remote_sync, args=(sync_every, local_dir, remote_dir, protocol))
61 | return p
62 |
63 | # Note: we are not currently using this save function.
64 | def pt_save(pt_obj, file_path):
65 | of = fsspec.open(file_path, "wb")
66 | with of as f:
67 |         torch.save(pt_obj, f)  # write through the opened fsspec handle so remote paths work
68 |
69 | def pt_load(file_path, map_location=None):
70 | if file_path.startswith('s3'):
71 | logging.info('Loading remote checkpoint, which may take a bit.')
72 | of = fsspec.open(file_path, "rb")
73 | with of as f:
74 | out = torch.load(f, map_location=map_location)
75 | return out
76 |
77 | def check_exists(file_path):
78 | try:
79 | with fsspec.open(file_path):
80 | pass
81 | except FileNotFoundError:
82 | return False
83 | return True
84 |
--------------------------------------------------------------------------------
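A short sketch of the checkpoint-sync helpers above. The local and remote paths are placeholders, and the `s3` protocol additionally assumes the `aws` CLI is installed and configured.

```python
from training.file_utils import remote_sync, start_sync_process

# One-off sync of a local checkpoint directory (returns True on success).
ok = remote_sync('./logs/run1', 's3://my-bucket/run1', protocol='s3')

# Or mirror it every 300 seconds from a background process, as a training
# script might; start_sync_process only builds the process, start() launches it.
p = start_sync_process(300, './logs/run1', 's3://my-bucket/run1', protocol='s3')
p.start()
```
--------------------------------------------------------------------------------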
/opencls/training/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 |
4 | def setup_logging(log_file, level, include_host=False):
5 | if include_host:
6 | import socket
7 | hostname = socket.gethostname()
8 | formatter = logging.Formatter(
9 | f'%(asctime)s | {hostname} | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S')
10 | else:
11 | formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S')
12 |
13 | logging.root.setLevel(level)
14 | loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
15 | for logger in loggers:
16 | logger.setLevel(level)
17 |
18 | stream_handler = logging.StreamHandler()
19 | stream_handler.setFormatter(formatter)
20 | logging.root.addHandler(stream_handler)
21 |
22 | if log_file:
23 | file_handler = logging.FileHandler(filename=log_file)
24 | file_handler.setFormatter(formatter)
25 | logging.root.addHandler(file_handler)
26 |
27 |
--------------------------------------------------------------------------------
/opencls/training/precision.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from contextlib import suppress
3 |
4 |
5 | def get_autocast(precision):
6 | if precision == 'amp':
7 | return torch.cuda.amp.autocast
8 | elif precision == 'amp_bfloat16' or precision == 'amp_bf16':
9 | # amp_bfloat16 is more stable than amp float16 for clip training
10 | return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
11 | else:
12 | return suppress
13 |
--------------------------------------------------------------------------------
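A small sketch of how `get_autocast` is used around a forward pass; the toy `Linear` model and precision choice are illustrative only.

```python
import torch

from training.precision import get_autocast

precision = 'amp_bf16' if torch.cuda.is_available() else 'fp32'
autocast = get_autocast(precision)  # context-manager factory (falls back to contextlib.suppress)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.nn.Linear(8, 2).to(device)
x = torch.randn(4, 8, device=device)
with autocast():
    y = model(x)
print(y.dtype)  # torch.bfloat16 under amp_bf16 on CUDA, torch.float32 otherwise
```
--------------------------------------------------------------------------------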
/opencls/training/profiler0.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | import open_clip
5 | import pandas as pd
6 | from torch.utils.flop_counter import FlopCounterMode
7 | try:
8 | import fvcore
9 | except ImportError:
10 | fvcore = None
11 |
12 | parser = argparse.ArgumentParser(description='OpenCLIP Profiler')
13 |
14 | # benchmark specific args
15 | parser.add_argument('--model', metavar='NAME', default='',
16 | help='model(s) to profile')
17 | parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
18 | help='Output csv file for results')
19 | parser.add_argument('--profiler', default='torch', type=str, choices=['torch', 'fvcore'])
20 | parser.add_argument('--batch-size', default=1, type=int, help='Batch size for profiling')
21 |
22 |
23 | def profile_fvcore(
24 | model,
25 | image_input_size=(3, 224, 224),
26 | text_input_size=(77,),
27 | batch_size=1,
28 | detailed=False,
29 | force_cpu=False
30 | ):
31 | if force_cpu:
32 | model = model.to('cpu')
33 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
34 | example_image_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
35 | example_text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
36 | fca = fvcore.nn.FlopCountAnalysis(model, (example_image_input, example_text_input))
37 | aca = fvcore.nn.ActivationCountAnalysis(model, (example_image_input, example_text_input))
38 | if detailed:
39 | fcs = fvcore.nn.flop_count_str(fca)
40 | print(fcs)
41 | return fca.total() / batch_size, aca.total() / batch_size
42 |
43 |
44 | def profile_fvcore_text(
45 | model,
46 | text_input_size=(77,),
47 | batch_size=1,
48 | detailed=False,
49 | force_cpu=False
50 | ):
51 | if force_cpu:
52 | model = model.to('cpu')
53 | device = next(model.parameters()).device
54 | example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
55 | fca = fvcore.nn.FlopCountAnalysis(model, example_input)
56 | aca = fvcore.nn.ActivationCountAnalysis(model, example_input)
57 | if detailed:
58 | fcs = fvcore.nn.flop_count_str(fca)
59 | print(fcs)
60 | return fca.total() / batch_size, aca.total() / batch_size
61 |
62 |
63 | def profile_fvcore_image(
64 | model,
65 | image_input_size=(3, 224, 224),
66 | batch_size=1,
67 | detailed=False,
68 | force_cpu=False
69 | ):
70 | if force_cpu:
71 | model = model.to('cpu')
72 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
73 | example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
74 | fca = fvcore.nn.FlopCountAnalysis(model, example_input)
75 | aca = fvcore.nn.ActivationCountAnalysis(model, example_input)
76 | if detailed:
77 | fcs = fvcore.nn.flop_count_str(fca)
78 | print(fcs)
79 | return fca.total() / batch_size, aca.total() / batch_size
80 |
81 |
82 | def profile_torch_image(model, image_input_size, batch_size=1, force_cpu=False):
83 | """Profile the image encoder using torch.utils.flop_counter"""
84 | if force_cpu:
85 | model = model.to('cpu')
86 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
87 | example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
88 |
89 | flop_counter = FlopCounterMode()
90 | with flop_counter:
91 | model(example_input)
92 | total_flops = sum(flop_counter.get_flop_counts()['Global'].values())
93 | return total_flops / batch_size
94 |
95 |
96 | def profile_torch_text(model, text_input_size, batch_size=1, force_cpu=False):
97 | """Profile the text encoder using torch.utils.flop_counter"""
98 | if force_cpu:
99 | model = model.to('cpu')
100 | device = next(model.parameters()).device
101 | example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
102 |
103 | flop_counter = FlopCounterMode()
104 | with flop_counter:
105 | model(example_input)
106 | total_flops = sum(flop_counter.get_flop_counts()['Global'].values())
107 | return total_flops / batch_size
108 |
109 |
110 | def profile_torch(model, text_input_size, image_input_size, batch_size=1, force_cpu=False):
111 | """Profile the full model using torch.utils.flop_counter"""
112 | if force_cpu:
113 | model = model.to('cpu')
114 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
115 | image_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
116 | text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
117 |
118 | flop_counter = FlopCounterMode()
119 | with flop_counter:
120 | model(image_input, text_input)
121 | total_flops = sum(flop_counter.get_flop_counts()['Global'].values())
122 | return total_flops / batch_size
123 |
124 |
125 | def count_params(model):
126 | return sum(m.numel() for m in model.parameters())
127 |
128 | def profile_model(model_name, batch_size=1, profiler='torch'):
129 | assert profiler in ['torch', 'fvcore'], 'Only torch and fvcore profilers are supported'
130 | if profiler == 'fvcore':
131 | assert fvcore is not None, 'Please install fvcore.'
132 | model = open_clip.create_model(model_name, force_custom_text=True, pretrained_hf=False)
133 | model.eval()
134 | if torch.cuda.is_available():
135 | model = model.cuda()
136 |
137 | if isinstance(model.visual.image_size, (tuple, list)):
138 | image_input_size = (3,) + tuple(model.visual.image_size[-2:])
139 | else:
140 | image_input_size = (3, model.visual.image_size, model.visual.image_size)
141 |
142 | text_input_size = (77,)
143 | if hasattr(model, 'context_length') and model.context_length:
144 | text_input_size = (model.context_length,)
145 |
146 | results = {}
147 | results['model'] = model_name
148 | results['image_size'] = image_input_size[1]
149 |
150 | model_cfg = open_clip.get_model_config(model_name)
151 | if model_cfg:
152 | vision_cfg = open_clip.CLIPVisionCfg(**model_cfg['vision_cfg'])
153 | text_cfg = open_clip.CLIPTextCfg(**model_cfg['text_cfg'])
154 | results['image_width'] = int(vision_cfg.width)
155 | results['text_width'] = int(text_cfg.width)
156 | results['embed_dim'] = int(model_cfg['embed_dim'])
157 | else:
158 | results['image_width'] = 0
159 | results['text_width'] = 0
160 | results['embed_dim'] = 0
161 |
162 | retries = 2
163 | while retries:
164 | retries -= 1
165 | try:
166 | results['mparams'] = round(count_params(model) / 1e6, 2)
167 | results['image_mparams'] = round(count_params(model.visual) / 1e6, 2)
168 | results['text_mparams'] = round(count_params(model.text) / 1e6, 2)
169 |
170 | if profiler == 'fvcore':
171 | macs, acts = profile_fvcore(
172 | model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size)
173 |
174 | image_macs, image_acts = profile_fvcore_image(
175 | model.visual, image_input_size=image_input_size, force_cpu=not retries, batch_size=batch_size)
176 |
177 | text_macs, text_acts = profile_fvcore_text(
178 | model.text, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size)
179 |
180 | results['gmacs'] = round(macs / 1e9, 2)
181 | results['macts'] = round(acts / 1e6, 2)
182 |
183 | results['image_gmacs'] = round(image_macs / 1e9, 2)
184 | results['image_macts'] = round(image_acts / 1e6, 2)
185 |
186 | results['text_gmacs'] = round(text_macs / 1e9, 2)
187 | results['text_macts'] = round(text_acts / 1e6, 2)
188 | elif profiler == 'torch':
189 | image_flops = profile_torch_image(
190 | model.visual, image_input_size=image_input_size, force_cpu=not retries, batch_size=batch_size)
191 | text_flops = profile_torch_text(
192 | model.text, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size)
193 | total_flops = profile_torch(
194 | model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size)
195 |
196 | results['gflops'] = round(total_flops / 1e9, 2)
197 | results['image_gflops'] = round(image_flops / 1e9, 2)
198 | results['text_gflops'] = round(text_flops / 1e9, 2)
199 |             break  # success; skip the remaining retry pass
200 |         except RuntimeError as e:
201 |             pass  # e.g. CUDA OOM; the loop retries once more with force_cpu=True
202 | return results
203 |
204 |
205 | def main():
206 | args = parser.parse_args()
207 |
208 | # FIXME accept a text file name to allow lists of models in txt/csv
209 | if args.model == 'all':
210 | parsed_model = open_clip.list_models()
211 | else:
212 | parsed_model = args.model.split(',')
213 |
214 | results = []
215 | models_with_errors = []
216 | for m in parsed_model:
217 | print('='*100)
218 | print(f'Profiling {m}')
219 | try:
220 | row = profile_model(m, batch_size=args.batch_size, profiler=args.profiler)
221 | results.append(row)
222 | except Exception as e:
223 | print(f'Error profiling {m}: {e}')
224 | import traceback
225 | traceback.print_exc()
226 | models_with_errors.append(m)
227 |
228 | df = pd.DataFrame(results, columns=results[0].keys())
229 |
230 | if 'gmacs' in df.columns:
231 | df = df.sort_values(by=['gmacs', 'mparams', 'model'])
232 | else:
233 | df = df.sort_values(by=['gflops', 'mparams', 'model'])
234 |
235 | print('='*100)
236 | print('Done.')
237 | print(df)
238 | if args.results_file:
239 | df.to_csv(args.results_file, index=False)
240 |
241 | if models_with_errors:
242 | print('Models with errors:', models_with_errors)
243 |
244 |
245 | if __name__ == '__main__':
246 | main()
247 |
--------------------------------------------------------------------------------
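The profiler can be launched as a script or module, e.g. `python -m training.profiler0 --model ViT-B-32 --results-file profile.csv`, but `profile_model` can also be called directly; the model name below is just an example.

```python
from training.profiler0 import profile_model

# Any name from open_clip.list_models() should work the same way.
row = profile_model('ViT-B-32', batch_size=1, profiler='torch')
print(row['mparams'], row.get('gflops'))  # gflops is present when FLOP counting succeeded
```
--------------------------------------------------------------------------------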
/opencls/training/scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def assign_learning_rate(optimizer, new_lr):
5 | for param_group in optimizer.param_groups:
6 | param_group["lr"] = new_lr
7 |
8 |
9 | def _warmup_lr(base_lr, warmup_length, step):
10 | return base_lr * (step + 1) / warmup_length
11 |
12 |
13 | def const_lr(optimizer, base_lr, warmup_length, steps):
14 | def _lr_adjuster(step):
15 | if step < warmup_length:
16 | lr = _warmup_lr(base_lr, warmup_length, step)
17 | else:
18 | lr = base_lr
19 | assign_learning_rate(optimizer, lr)
20 | return lr
21 | return _lr_adjuster
22 |
23 |
24 | def const_lr_cooldown(optimizer, base_lr, warmup_length, steps, cooldown_steps, cooldown_power=1.0, cooldown_end_lr=0.):
25 | def _lr_adjuster(step):
26 | start_cooldown_step = steps - cooldown_steps
27 | if step < warmup_length:
28 | lr = _warmup_lr(base_lr, warmup_length, step)
29 | else:
30 | if step < start_cooldown_step:
31 | lr = base_lr
32 | else:
33 | e = step - start_cooldown_step
34 | es = steps - start_cooldown_step
35 | # linear decay if power == 1; polynomial decay otherwise;
36 | decay = (1 - (e/es)) ** cooldown_power
37 | lr = decay * (base_lr - cooldown_end_lr) + cooldown_end_lr
38 | assign_learning_rate(optimizer, lr)
39 | return lr
40 | return _lr_adjuster
41 |
42 |
43 | def cosine_lr(optimizer, base_lr, warmup_length, steps):
44 | def _lr_adjuster(step):
45 | if step < warmup_length:
46 | lr = _warmup_lr(base_lr, warmup_length, step)
47 | else:
48 | e = step - warmup_length
49 | es = steps - warmup_length
50 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
51 | assign_learning_rate(optimizer, lr)
52 | return lr
53 | return _lr_adjuster
54 |
--------------------------------------------------------------------------------
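A toy walk-through of `cosine_lr`: the learning rate warms up linearly for `warmup_length` steps and then decays as 0.5 * (1 + cos(pi * e / es)) * base_lr. The optimizer below is a stand-in.

```python
import torch

from training.scheduler import cosine_lr

opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.0)
scheduler = cosine_lr(opt, base_lr=1e-3, warmup_length=10, steps=100)

for step in range(100):
    lr = scheduler(step)  # assigns the lr to every param group and returns it

print(lr)  # last step: 0.5 * (1 + cos(pi * 89/90)) * 1e-3, roughly 3e-07
```
--------------------------------------------------------------------------------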
/opencls/training/zero_shot.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch
4 | from tqdm import tqdm
5 |
6 | from open_clip import get_input_dtype, get_tokenizer, build_zero_shot_classifier, \
7 | IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES
8 | from .precision import get_autocast
9 |
10 |
11 | def accuracy(output, target, topk=(1,)):
12 | pred = output.topk(max(topk), 1, True, True)[1].t()
13 | correct = pred.eq(target.view(1, -1).expand_as(pred))
14 | return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]
15 |
16 |
17 | def run(model, classifier, dataloader, args):
18 | autocast = get_autocast(args.precision)
19 | input_dtype = get_input_dtype(args.precision)
20 |
21 | with torch.no_grad():
22 | top1, top5, n = 0., 0., 0.
23 | for images, target in tqdm(dataloader, unit_scale=args.batch_size):
24 | images = images.to(device=args.device, dtype=input_dtype)
25 | target = target.to(args.device)
26 |
27 | with autocast():
28 | # predict
29 | output = model(image=images)
30 | image_features = output['image_features'] if isinstance(output, dict) else output[0]
31 | logits = 100. * image_features @ classifier
32 |
33 | # measure accuracy
34 | acc1, acc5 = accuracy(logits, target, topk=(1, 5))
35 | top1 += acc1
36 | top5 += acc5
37 | n += images.size(0)
38 |
39 | top1 = (top1 / n)
40 | top5 = (top5 / n)
41 | return top1, top5
42 |
43 |
44 | def zero_shot_eval(model, data, epoch, args, tokenizer=None, should_zero_eval=False):
45 | if 'imagenet-val' not in data and 'imagenet-v2' not in data:
46 | return {}
47 | # if args.zeroshot_frequency == 0:
48 | # return {}
49 | # if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs:
50 | # return {}
51 | if not should_zero_eval:
52 | return {}
53 | if args.distributed and not args.horovod:
54 | model = model.module
55 |
56 | logging.info('Starting zero-shot imagenet.')
57 | if tokenizer is None:
58 | tokenizer = get_tokenizer(args.model)
59 |
60 | logging.info('Building zero-shot classifier')
61 | autocast = get_autocast(args.precision)
62 | with autocast():
63 | classifier = build_zero_shot_classifier(
64 | model,
65 | tokenizer=tokenizer,
66 | classnames=IMAGENET_CLASSNAMES,
67 | templates=OPENAI_IMAGENET_TEMPLATES,
68 | num_classes_per_batch=10,
69 | device=args.device,
70 | use_tqdm=True,
71 | )
72 |
73 | logging.info('Using classifier')
74 | results = {}
75 | if 'imagenet-val' in data:
76 | top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args)
77 | results['imagenet-zeroshot-val-top1'] = top1
78 | results['imagenet-zeroshot-val-top5'] = top5
79 | if 'imagenet-v2' in data:
80 | top1, top5 = run(model, classifier, data['imagenet-v2'].dataloader, args)
81 | results['imagenetv2-zeroshot-val-top1'] = top1
82 | results['imagenetv2-zeroshot-val-top5'] = top5
83 |
84 | logging.info('Finished zero-shot imagenet.')
85 |
86 | return results
87 |
--------------------------------------------------------------------------------
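A tiny worked example of `accuracy`; note that it returns correct-prediction counts rather than rates, which is why `run` divides the accumulated totals by `n`.

```python
import torch

from training.zero_shot import accuracy

# Two samples, three classes: sample 0 is a top-1 hit, sample 1 only a top-2 hit.
logits = torch.tensor([[0.1, 0.9, 0.0],
                       [0.6, 0.3, 0.1]])
target = torch.tensor([1, 1])
top1, top2 = accuracy(logits, target, topk=(1, 2))
print(top1, top2)  # 1.0 2.0
```
--------------------------------------------------------------------------------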
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.9.0
2 | torchvision
3 | regex
4 | ftfy
5 | tqdm
6 | huggingface_hub
7 | safetensors
8 | timm
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | CONFIG=$1
2 | DIR=$2
3 | NV_GPUS=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
4 | GPUS=${GPUS:-${NV_GPUS}}
5 | NNODES=${NNODES:-1}
6 | NODE_RANK=${NODE_RANK:-0}
7 | PORT=${PORT:-55565}
8 | # the rendezvous port can be overridden via the environment, e.g. PORT=9007
9 | DIR=${DIR:-"opencls"}
10 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
11 | HOST_NODE_ADDR=${MASTER_ADDR}:${PORT}
12 | TIMESTAMP=$(date +%Y_%m_%d-%H_%M_%S)
13 | export TOKENIZERS_PARALLELISM=true
14 |
15 | DATA_PATH="{your_path_to_datacomp-1b-webdataset}/{000000..140146}.tar"
16 | VAL_DATA_PATH="{your_path_to_imagenet1k}/ILSVRC/Data/CLS-LOC/val"
17 |
18 | echo "$DIR"
19 | cd $DIR
20 |
21 | torchrun --nproc_per_node=$GPUS \
22 |     --rdzv_endpoint=$HOST_NODE_ADDR \
23 | --nnodes=$NNODES --node_rank=$NODE_RANK \
24 | -m training.main \
25 | --config="${CONFIG}" \
26 | --train-data $DATA_PATH \
27 | --imagenet-val $VAL_DATA_PATH
--------------------------------------------------------------------------------
/train_combo.sh:
--------------------------------------------------------------------------------
1 | CLS_CONFIG=$1
2 | LIT_CONFIG=$2
3 | DIR=$3
4 | NV_GPUS=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
5 | GPUS=${GPUS:-${NV_GPUS}}
6 | NNODES=${NNODES:-1}
7 | NODE_RANK=${NODE_RANK:-0}
8 | PORT=${PORT:-55565}
9 | DIR=${DIR:-"opencls"}
10 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
11 | HOST_NODE_ADDR=${MASTER_ADDR}:${PORT}
12 | TIMESTAMP=$(date +%Y_%m_%d-%H_%M_%S)
13 | export TOKENIZERS_PARALLELISM=true
14 |
15 | DATA_PATH="{your_path_to_datacomp-1b-webdataset}/{000000..140146}.tar"
16 | VAL_DATA_PATH="{your_path_to_imagenet1k}/ILSVRC/Data/CLS-LOC/val"
17 |
18 | echo "$DIR"
19 | cd $DIR
20 |
21 | torchrun --nproc_per_node=$GPUS \
22 |     --rdzv_endpoint=$HOST_NODE_ADDR \
23 | --nnodes=$NNODES --node_rank=$NODE_RANK \
24 | -m training.main \
25 | --config="${CLS_CONFIG}" \
26 | --train-data $DATA_PATH \
27 | --imagenet-val $VAL_DATA_PATH
28 |
29 |
30 | torchrun --nproc_per_node=$GPUS \
31 |     --rdzv_endpoint=$HOST_NODE_ADDR \
32 | --nnodes=$NNODES --node_rank=$NODE_RANK \
33 | -m training.main \
34 | --config="${LIT_CONFIG}" \
35 | --train-data $DATA_PATH \
36 | --imagenet-val $VAL_DATA_PATH
--------------------------------------------------------------------------------