├── .github
└── workflows
│ ├── ci.yml
│ └── python-publish.yml
├── .gitignore
├── CITATION.cff
├── HISTORY.md
├── LICENSE
├── MANIFEST.in
├── Makefile
├── README.md
├── docs
├── CLIP.png
├── Interacting_with_open_clip.ipynb
├── clip_conceptual_captions.md
├── clip_loss.png
├── clip_recall.png
├── clip_val_loss.png
├── clip_zeroshot.png
├── effective_robustness.png
├── laion2b_clip_zeroshot_b32.png
├── laion_clip_zeroshot.png
├── laion_clip_zeroshot_b16.png
├── laion_clip_zeroshot_b16_plus_240.png
├── laion_clip_zeroshot_l14.png
├── laion_openai_compare_b32.jpg
└── scaling.png
├── requirements-test.txt
├── requirements-training.txt
├── requirements.txt
├── setup.py
├── src
├── data
│ └── gather_cc.py
├── open_clip
│ ├── __init__.py
│ ├── bpe_simple_vocab_16e6.txt.gz
│ ├── factory.py
│ ├── loss.py
│ ├── model.py
│ ├── model_configs
│ │ ├── RN101-quickgelu.json
│ │ ├── RN101.json
│ │ ├── RN50-quickgelu.json
│ │ ├── RN50.json
│ │ ├── RN50x16.json
│ │ ├── RN50x4.json
│ │ ├── ViT-B-16-plus-240.json
│ │ ├── ViT-B-16-plus.json
│ │ ├── ViT-B-16.json
│ │ ├── ViT-B-32-plus-256.json
│ │ ├── ViT-B-32-quickgelu.json
│ │ ├── ViT-B-32.json
│ │ ├── ViT-H-14.json
│ │ ├── ViT-H-16.json
│ │ ├── ViT-L-14-280.json
│ │ ├── ViT-L-14-336.json
│ │ ├── ViT-L-14.json
│ │ ├── ViT-L-16-320.json
│ │ ├── ViT-L-16.json
│ │ ├── ViT-g-14.json
│ │ ├── timm-efficientnetv2_rw_s.json
│ │ ├── timm-resnet50d.json
│ │ ├── timm-resnetaa50d.json
│ │ ├── timm-resnetblur50.json
│ │ ├── timm-swin_base_patch4_window7_224.json
│ │ ├── timm-vit_base_patch16_224.json
│ │ ├── timm-vit_base_patch32_224.json
│ │ └── timm-vit_small_patch16_224.json
│ ├── openai.py
│ ├── pretrained.py
│ ├── timm_model.py
│ ├── tokenizer.py
│ ├── transform.py
│ ├── utils.py
│ └── version.py
└── training
│ ├── .gitignore
│ ├── __init__.py
│ ├── data.py
│ ├── distributed.py
│ ├── imagenet_zeroshot_data.py
│ ├── logger.py
│ ├── main.py
│ ├── params.py
│ ├── scheduler.py
│ ├── train.py
│ └── zero_shot.py
└── tests
└── test_simple.py
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: Continuous integration
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | pull_request:
8 | branches:
9 | - main
10 |
11 | jobs:
12 | tests:
13 | runs-on: ubuntu-latest
14 | strategy:
15 | matrix:
16 | python-version: [3.8]
17 |
18 | steps:
19 | - uses: actions/checkout@v2
20 | - name: Set up Python ${{ matrix.python-version }}
21 | uses: actions/setup-python@v2
22 | with:
23 | python-version: ${{ matrix.python-version }}
24 | - name: Install
25 | run: |
26 | python3 -m venv .env
27 | source .env/bin/activate
28 | make install
29 | make install-dev
30 | - name: Unit tests
31 | run: |
32 | source .env/bin/activate
33 | make test
34 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | name: Release
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | jobs:
8 | deploy:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - uses: actions/checkout@v2
12 | - uses: actions-ecosystem/action-regex-match@v2
13 | id: regex-match
14 | with:
15 | text: ${{ github.event.head_commit.message }}
16 | regex: '^Release ([^ ]+)'
17 | - name: Set up Python
18 | uses: actions/setup-python@v2
19 | with:
20 | python-version: '3.8'
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install setuptools wheel twine
25 | - name: Release
26 | if: ${{ steps.regex-match.outputs.match != '' }}
27 | uses: softprops/action-gh-release@v1
28 | with:
29 | tag_name: v${{ steps.regex-match.outputs.group1 }}
30 | - name: Build and publish
31 | if: ${{ steps.regex-match.outputs.match != '' }}
32 | env:
33 | TWINE_USERNAME: __token__
34 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
35 | run: |
36 | python setup.py sdist bdist_wheel
37 | twine upload dist/*
38 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | logs/
2 | wandb/
3 | models/
4 | features/
5 | results/
6 |
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | pip-wheel-metadata/
30 | share/python-wheels/
31 | *.egg-info/
32 | .installed.cfg
33 | *.egg
34 | MANIFEST
35 |
36 | # PyInstaller
37 | # Usually these files are written by a python script from a template
38 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
39 | *.manifest
40 | *.spec
41 |
42 | # Installer logs
43 | pip-log.txt
44 | pip-delete-this-directory.txt
45 |
46 | # Unit test / coverage reports
47 | htmlcov/
48 | .tox/
49 | .nox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | *.py,cover
57 | .hypothesis/
58 | .pytest_cache/
59 |
60 | # Translations
61 | *.mo
62 | *.pot
63 |
64 | # Django stuff:
65 | *.log
66 | local_settings.py
67 | db.sqlite3
68 | db.sqlite3-journal
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 |
80 | # PyBuilder
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
101 | __pypackages__/
102 |
103 | # Celery stuff
104 | celerybeat-schedule
105 | celerybeat.pid
106 |
107 | # SageMath parsed files
108 | *.sage.py
109 |
110 | # Environments
111 | .env
112 | .venv
113 | env/
114 | venv/
115 | ENV/
116 | env.bak/
117 | venv.bak/
118 |
119 | # Spyder project settings
120 | .spyderproject
121 | .spyproject
122 |
123 | # Rope project settings
124 | .ropeproject
125 |
126 | # mkdocs documentation
127 | /site
128 |
129 | # mypy
130 | .mypy_cache/
131 | .dmypy.json
132 | dmypy.json
133 |
134 | # Pyre type checker
135 | .pyre/
136 | sync.sh
137 | gpu1sync.sh
138 | .idea
139 | *.pdf
140 | **/._*
141 | **/*DS_*
142 | **.jsonl
143 | src/sbatch
144 | src/misc
145 | .vscode
146 | src/debug
147 | core.*
148 |
149 | # Allow
150 | !src/evaluation/misc/results_dbs/*
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.1.0
2 | message: If you use this software, please cite it as below.
3 | authors:
4 | - family-names: Ilharco
5 | given-names: Gabriel
6 | - family-names: Wortsman
7 | given-names: Mitchell
8 | - family-names: Wightman
9 | given-names: Ross
10 | - family-names: Gordon
11 | given-names: Cade
12 | - family-names: Carlini
13 | given-names: Nicholas
14 | - family-names: Taori
15 | given-names: Rohan
16 | - family-names: Dave
17 | given-names: Achal
18 | - family-names: Shankar
19 | given-names: Vaishaal
20 | - family-names: Namkoong
21 | given-names: Hongseok
22 | - family-names: Miller
23 | given-names: John
24 | - family-names: Hajishirzi
25 | given-names: Hannaneh
26 | - family-names: Farhadi
27 | given-names: Ali
28 | - family-names: Schmidt
29 | given-names: Ludwig
30 | title: OpenCLIP
31 | version: v0.1
32 | doi: 10.5281/zenodo.5143773
33 | date-released: 2021-07-28
34 |
--------------------------------------------------------------------------------
/HISTORY.md:
--------------------------------------------------------------------------------
1 | ## 1.2.0
2 |
3 | * ViT-B/32 trained on Laion2B-en
4 | * add missing openai RN50x64 model
5 |
6 | ## 1.1.1
7 |
8 | * ViT-B/16+
9 | * Add grad checkpointing support
10 | * more robust data loader
11 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2012-2021 Gabriel Ilharco, Mitchell Wortsman,
2 | Nicholas Carlini, Rohan Taori, Achal Dave, Vaishaal Shankar,
3 | John Miller, Hongseok Namkoong, Hannaneh Hajishirzi, Ali Farhadi,
4 | Ludwig Schmidt
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining
7 | a copy of this software and associated documentation files (the
8 | "Software"), to deal in the Software without restriction, including
9 | without limitation the rights to use, copy, modify, merge, publish,
10 | distribute, sublicense, and/or sell copies of the Software, and to
11 | permit persons to whom the Software is furnished to do so, subject to
12 | the following conditions:
13 |
14 | The above copyright notice and this permission notice shall be
15 | included in all copies or substantial portions of the Software.
16 |
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
19 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
20 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
21 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
22 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
23 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
24 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include src/open_clip/bpe_simple_vocab_16e6.txt.gz
2 | include src/open_clip/model_configs/*.json
3 |
4 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | install: ## [Local development] Upgrade pip, install requirements, install package.
2 | python -m pip install -U pip
3 | python -m pip install -e .
4 |
5 | install-dev: ## [Local development] Install test requirements
6 | python -m pip install -r requirements-test.txt
7 |
8 | test: ## [Local development] Run unit tests
9 | python -m pytest -x -s -v tests
10 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # When and why vision-language models behave like bags-of-words, and what to do about it? (ICLR 2023 Oral)
2 |
3 | **Note** This code will not work with the distributed/multi-gpu setting as it is currently implemented.
4 |
5 | ## NegCLIP Implementation
6 |
7 | NegCLIP introduces a few simple edits to the original OpenCLIP base. To ease the code-reading phase,
8 | I'll point out the main edits here; if you are familiar with how OpenCLIP works this should be easy to read/edit and
9 | modify.
10 |
11 | **Dataset**
12 |
13 | The dataset now requires loading hard captions (provided as a list) and hard image negatives. Hard captions and hard images
14 | are chosen at random at each epoch.
15 |
16 | ```python
17 | df = pd.read_csv(input_filename, sep=sep, converters={"neg_caption":ast.literal_eval, "neg_image":ast.literal_eval})
18 |
19 | self.images = df[img_key].tolist()
20 | self.captions = df[caption_key].tolist()
21 | self.hard_captions = df[hard_captions_key].tolist()
22 | self.hard_images = df["neg_image"].tolist()
23 | self.transforms = transforms
24 |
25 | [...]
26 |
27 | # example of random selection of an hard caption
28 | chosen_caption = random.choice(self.hard_captions[idx])
29 | hard_captions = tokenize([str(chosen_caption)])[0]
30 | ```
31 |
32 | **Forward Pass**
33 |
34 | To reduce the number of edits we need to apply to the contrastive loss, we concatenate negative images and negative
35 | captions together. Once this is done we will let the model run the forward pass on this data.
36 |
37 | ```python
38 | images = torch.cat([images, hard_images]) # we concatenate images and hard images
39 |
40 | texts = torch.cat([texts, texts_hard_images]) # we concatenate texts with the text of the hard images
41 | texts = torch.cat([texts, hard_captions]) # we concatenate text with the hard captions
42 | texts = torch.cat([texts, hard_captions_of_hard_images]) # we concatenate texts with the hard caption of the hard images
43 |
44 | # Note. This operation is going to leave us with different in sizes. We will have 2x texts than images (because of the hard negatives).
45 | # This will require us to fix how we compute the loss (see next section).
46 |
47 | with autocast():
48 | image_features, text_features, logit_scale = model(images, texts)
49 | ```
50 |
51 | **Loss**
52 |
53 | Finally, we have the loss. In the **Forward Pass** section we have built texts and images that have different `lenghts`.
54 | Basically starting from a batchsize of 256, you get to a contrastive matrix of 512x1024 (for the image part we have
55 | 256 images + 256 hard images, for the text part we have 256 captions + 256 captions from the hard images + 256 hard captions + 256 hard captions from the hard images).
56 |
57 | So we need to change the loss a bit to ignore computing the loss on the wrong items (see the paper).
58 |
59 | ```python
60 | total_loss = (
61 | F.cross_entropy(logits_per_image, labels) +
62 | F.cross_entropy(logits_per_text[:len(logits_per_image)], labels)
63 | ) / 2
64 | ```
65 |
66 |
67 | # Citation
68 | If you use this code or data, please consider citing our paper:
69 |
70 | ```
71 | @inproceedings{
72 | yuksekgonul2023when,
73 | title={When and why Vision-Language Models behave like Bags-of-Words, and what to do about it?},
74 | author={Mert Yuksekgonul and Federico Bianchi and Pratyusha Kalluri and Dan Jurafsky and James Zou},
75 | booktitle={International Conference on Learning Representations},
76 | year={2023},
77 | url={https://openreview.net/forum?id=KRLUvxh8uaX}
78 | }
79 | ```
80 |
81 | What follows from here is the original OpenCLIP readme.
82 |
83 | # Original OpenCLIP
84 |
85 | [[Paper]](https://arxiv.org/abs/2109.01903) [[Colab]](https://colab.research.google.com/github/mlfoundations/open_clip/blob/master/docs/Interacting_with_open_clip.ipynb)
86 |
87 | Welcome to an open source implementation of OpenAI's [CLIP](https://arxiv.org/abs/2103.00020) (Contrastive Language-Image Pre-training).
88 |
89 | The goal of this repository is to enable training models with contrastive image-text supervision, and to investigate their properties such as robustness to distribution shift. Our starting point is an implementation of CLIP that matches the accuracy of the original CLIP models when trained on the same dataset.
90 | Specifically, a ResNet-50 model trained with our codebase on OpenAI's [15 million image subset of YFCC](https://github.com/openai/CLIP/blob/main/data/yfcc100m.md) achieves **32.7%** top-1 accuracy on ImageNet. OpenAI's CLIP model reaches **31.3%** when trained on the same subset of YFCC. For ease of experimentation, we also provide code for training on the 3 million images in the [Conceptual Captions](https://ai.google.com/research/ConceptualCaptions/download) dataset, where a ResNet-50x4 trained with our codebase reaches 22.2% top-1 ImageNet accuracy.
91 |
92 | We further this with a replication study on a dataset of comparable size to OpenAI's. Using [LAION-400M](https://arxiv.org/abs/2111.02114), we train CLIP with a
93 | * ViT-B/32 and achieve an accuracy of **62.9%**, comparable to OpenAI's **63.2%**, zero-shot top-1 on ImageNet1k
94 | * ViT-B/16 and achieve an accuracy of **67.1%**, comparable to OpenAI's **68.3%** (as measured here, 68.6% in paper)
95 | * ViT-B/16+ 240x240 (~50% more FLOPS than B/16 224x224) and achieve an accuracy of **69.2%**
96 | * ViT-L/14 and achieve an accuracy of **72.77%**, vs OpenAI's **75.5%** (as measured here, 75.3% in paper)
97 |
98 | As we describe in more detail [below](#why-are-low-accuracy-clip-models-interesting), CLIP models in a medium accuracy regime already allow us to draw conclusions about the robustness of larger CLIP models since the models follow [reliable scaling laws](https://arxiv.org/abs/2107.04649).
99 |
100 | This codebase is work in progress, and we invite all to contribute in making it more acessible and useful. In the future, we plan to add support for TPU training and release larger models. We hope this codebase facilitates and promotes further research in contrastive image-text learning. Please submit an issue or send an email if you have any other requests or suggestions.
101 |
102 | Note that portions of `src/open_clip/` modelling and tokenizer code are adaptations of OpenAI's official [repository](https://github.com/openai/CLIP).
103 |
104 | ## Approach
105 |
106 | |  |
107 | |:--:|
108 | | Image Credit: https://github.com/openai/CLIP |
109 |
110 | ## Usage
111 |
112 | ```
113 | pip install open_clip_torch
114 | ```
115 |
116 | ```python
117 | import torch
118 | from PIL import Image
119 | import open_clip
120 |
121 | model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32-quickgelu', pretrained='laion400m_e32')
122 |
123 | image = preprocess(Image.open("CLIP.png")).unsqueeze(0)
124 | text = open_clip.tokenize(["a diagram", "a dog", "a cat"])
125 |
126 | with torch.no_grad():
127 | image_features = model.encode_image(image)
128 | text_features = model.encode_text(text)
129 | image_features /= image_features.norm(dim=-1, keepdim=True)
130 | text_features /= text_features.norm(dim=-1, keepdim=True)
131 |
132 | text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
133 |
134 | print("Label probs:", text_probs) # prints: [[1., 0., 0.]]
135 | ```
136 |
137 | To compute billions of embeddings efficiently, you can use [clip-retrieval](https://github.com/rom1504/clip-retrieval) which has openclip support.
138 |
139 | ## Fine-tuning on classification tasks
140 |
141 | This repository is focused on training CLIP models. To fine-tune a *trained* zero-shot model on a downstream classification task such as ImageNet, please see [our other repository: WiSE-FT](https://github.com/mlfoundations/wise-ft). The [WiSE-FT repository](https://github.com/mlfoundations/wise-ft) contains code for our paper on [Robust Fine-tuning of Zero-shot Models](https://arxiv.org/abs/2109.01903), in which we introduce a technique for fine-tuning zero-shot models while preserving robustness under distribution shift.
142 |
143 | ## Data
144 |
145 |
146 | ### Conceptual Captions
147 |
148 | OpenCLIP reads a CSV file with two columns: a path to an image, and a text caption. The names of the columns are passed as an argument to `main.py`.
149 |
150 | The script `src/data/gather_cc.py` will collect the Conceptual Captions images. First, download the [Conceptual Captions URLs](https://ai.google.com/research/ConceptualCaptions/download) and then run the script from our repository:
151 |
152 | ```bash
153 | python3 src/data/gather_cc.py path/to/Train_GCC-training.tsv path/to/Validation_GCC-1.1.0-Validation.tsv
154 | ```
155 |
156 | Our training set contains 2.89M images, and our validation set contains 13K images.
157 |
158 |
159 | ### YFCC and other datasets
160 |
161 | In addition to specifying the training data via CSV files as mentioned above, our codebase also supports [webdataset](https://github.com/webdataset/webdataset), which is recommended for larger scale datasets. The expected format is a series of `.tar` files. Each of these `.tar` files should contain two files for each training example, one for the image and one for the corresponding text. Both files should have the same name but different extensions. For instance, `shard_001.tar` could contain files such as `abc.jpg` and `abc.txt`. You can learn more about `webdataset` at [https://github.com/webdataset/webdataset](https://github.com/webdataset/webdataset). We use `.tar` files with 1,000 data points each, which we create using [tarp](https://github.com/webdataset/tarp).
162 |
163 | You can download the YFCC dataset from [Multimedia Commons](http://mmcommons.org/).
164 | Similar to OpenAI, we used a subset of YFCC to reach the aforementioned accuracy numbers.
165 | The indices of images in this subset are in [OpenAI's CLIP repository](https://github.com/openai/CLIP/blob/main/data/yfcc100m.md).
166 |
167 |
168 | ## Training CLIP
169 |
170 | ### Setup Environment and Install dependencies
171 |
172 | #### Conda
173 |
174 | ```bash
175 | # Create a conda environment (heavily recommended)
176 | conda create -n open_clip python=3.10
177 | conda activate open_clip
178 | ```
179 |
180 | Install conda PyTorch as per https://pytorch.org/get-started/locally/
181 |
182 | #### Virtualenv
183 |
184 | openclip also can be used with virtualenv with these lines:
185 | ```
186 | python3 -m venv .env
187 | source .env/bin/activate
188 | pip install -U pip
189 | make install
190 | ```
191 |
192 | Install pip PyTorch as per https://pytorch.org/get-started/locally/
193 |
194 | Test can be run with `make install-dev` then `make test`
195 |
196 | #### Other dependencies
197 |
198 | Install open_clip pacakge and remaining dependencies:
199 |
200 | ```bash
201 | cd open_clip
202 | python setup.py install
203 | ```
204 |
205 | If you want to train models, you will also need to install the packages
206 | from `requirements-training.txt`.
207 |
208 | ### Sample single-process running code:
209 |
210 | ```bash
211 | python -m training.main \
212 | --save-frequency 1 \
213 | --zeroshot-frequency 1 \
214 | --report-to tensorboard \
215 | --train-data="/path/to/train_data.csv" \
216 | --val-data="/path/to/validation_data.csv" \
217 | --csv-img-key filepath \
218 | --csv-caption-key title \
219 | --imagenet-val=/path/to/imagenet/root/val/ \
220 | --warmup 10000 \
221 | --batch-size=128 \
222 | --lr=1e-3 \
223 | --wd=0.1 \
224 | --epochs=30 \
225 | --workers=8 \
226 | --model RN50
227 | ```
228 |
229 | Note: `imagenet-val` is the path to the *validation* set of ImageNet for zero-shot evaluation, not the training set!
230 | You can remove this argument if you do not want to perform zero-shot evaluation on ImageNet throughout training. Note that the `val` folder should contain subfolders. If it doest not, please use [this script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh).
231 |
232 | ### Multi-GPU and Beyond
233 |
234 | This code has been battle tested up to 1024 A100s and offers a variety of solutions
235 | for distributed training. We include native support for SLURM clusters.
236 |
237 | As the number of devices used to train increases, so does the space complexity of
238 | the the logit matrix. Using a naïve all-gather scheme, space complexity will be
239 | `O(n^2)`. Instead, complexity may become effectively linear if the flags
240 | `--gather-with-grad` and `--local-loss` are used. This alteration results in one-to-one
241 | numerical results as the naïve method.
242 |
243 | #### Single-Node
244 |
245 | We make use of `torchrun` to launch distributed jobs. The following launches a
246 | a job on a node of 4 GPUs:
247 |
248 | ```bash
249 | cd open_clip/src
250 | torchrun --nproc_per_node 4 -m training.main \
251 | --train-data '/data/cc12m/cc12m-train-{0000..2175}.tar' \
252 | --train-num-samples 10968539 \
253 | --dataset-type webdataset \
254 | --batch-size 320 \
255 | --precision amp \
256 | --workers 4 \
257 | --imagenet-val /data/imagenet/validation/
258 | ```
259 |
260 | #### Multi-Node
261 |
262 | The same script above works, so long as users include information about the number
263 | of nodes and host node.
264 |
265 | ```bash
266 | cd open_clip/src
267 | torchrun --nproc_per_node=4 \
268 | --rdzv_endpoint=$HOSTE_NODE_ADDR \
269 | -m training.main \
270 | --train-data '/data/cc12m/cc12m-train-{0000..2175}.tar' \
271 | --train-num-samples 10968539 \
272 | --dataset-type webdataset \
273 | --batch-size 320 \
274 | --precision amp \
275 | --workers 4 \
276 | --imagenet-val /data/imagenet/validation/
277 | ```
278 |
279 | #### SLURM
280 |
281 | This is likely the easiest solution to utilize. The following script was used to
282 | train our largest models:
283 |
284 | ```bash
285 | #!/bin/bash -x
286 | #SBATCH --nodes=32
287 | #SBATCH --gres=gpu:4
288 | #SBATCH --ntasks-per-node=4
289 | #SBATCH --cpus-per-task=6
290 | #SBATCH --wait-all-nodes=1
291 | #SBATCH --job-name=open_clip
292 | #SBATCH --account=ACCOUNT_NAME
293 | #SBATCH --partition PARTITION_NAME
294 |
295 | eval "$(/path/to/conda/bin/conda shell.bash hook)" # init conda
296 | conda activate open_clip
297 | export CUDA_VISIBLE_DEVICES=0,1,2,3
298 | export MASTER_PORT=12802
299 |
300 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
301 | export MASTER_ADDR=$master_addr
302 |
303 | cd /shared/open_clip
304 | export PYTHONPATH="$PYTHONPATH:$PWD/src"
305 | srun --cpu_bind=v --accel-bind=gn python -u src/training/main.py \
306 | --save-frequency 1 \
307 | --report-to tensorboard \
308 | --train-data="/data/LAION-400M/{00000..41455}.tar" \
309 | --warmup 2000 \
310 | --batch-size=256 \
311 | --epochs=32 \
312 | --workers=8 \
313 | --model ViT-B-32 \
314 | --name "ViT-B-32-Vanilla" \
315 | --seed 0 \
316 | --local-loss \
317 | --gather-with-grad
318 | ```
319 |
320 | ### Resuming from a checkpoint:
321 |
322 | ```bash
323 | python -m training.main \
324 | --train-data="/path/to/train_data.csv" \
325 | --val-data="/path/to/validation_data.csv" \
326 | --resume /path/to/checkpoints/epoch_K.pt
327 | ```
328 |
329 | ### Loss Curves
330 |
331 | When run on a machine with 8 GPUs the command should produce the following training curve for Conceptual Captions:
332 |
333 | 
334 |
335 | More detailed curves for Conceptual Captions are given at [/docs/clip_conceptual_captions.md](/docs/clip_conceptual_captions.md).
336 |
337 | When training a RN50 on YFCC the same hyperparameters as above are used, with the exception of `lr=5e-4` and `epochs=32`.
338 |
339 | Note that to use another model, like `ViT-B/32` or `RN50x4` or `RN50x16` or `ViT-B/16`, specify with `--model RN50x4`.
340 |
341 | ### Launch tensorboard:
342 | ```bash
343 | tensorboard --logdir=logs/tensorboard/ --port=7777
344 | ```
345 |
346 | ## Evaluation / Zero-Shot
347 |
348 | ### Evaluating local checkpoint:
349 |
350 | ```bash
351 | python -m training.main \
352 | --val-data="/path/to/validation_data.csv" \
353 | --model RN101 \
354 | --pretrained /path/to/checkpoints/epoch_K.pt
355 | ```
356 |
357 | ### Evaluating hosted pretrained checkpoint on ImageNet zero-shot prediction:
358 |
359 | ```bash
360 | python -m training.main \
361 | --imagenet-val /path/to/imagenet/validation \
362 | --model ViT-B-32-quickgelu \
363 | --pretrained laion400m_e32
364 | ```
365 |
366 | ## Pretrained model details
367 |
368 | ### LAION-400M - https://laion.ai/laion-400-open-dataset
369 |
370 | We are working on reproducing OpenAI's ViT results with the comparably sized (and open) LAION-400M dataset. Trained
371 | weights may be found in release [v0.2](https://github.com/mlfoundations/open_clip/releases/tag/v0.2-weights).
372 |
373 | The LAION400M weights have been trained on the JUWELS supercomputer (see acknowledgements section below).
374 |
375 | #### ViT-B/32 224x224
376 |
377 | We replicate OpenAI's results on ViT-B/32, reaching a top-1 ImageNet-1k zero-shot accuracy of 62.96%.
378 |
379 |
380 |
381 | __Zero-shot comparison (courtesy of Andreas Fürst)__
382 |
383 |
384 | ViT-B/32 was trained with 128 A100 (40 GB) GPUs for ~36 hours, 4600 GPU-hours. The per-GPU batch size was 256 for a global batch size of 32768. 256 is much lower than it could have been (~320-384) due to being sized initially before moving to 'local' contrastive loss.
385 |
386 | #### ViT-B/16 224x224
387 |
388 | The B/16 LAION400M training reached a top-1 ImageNet-1k zero-shot validation score of 67.07.
389 |
390 |
391 |
392 | This was the first major train session using the updated webdataset 0.2.x code. A bug was found that prevented shards from being shuffled properly between nodes/workers each epoch. This was fixed part way through training (epoch 26) but likely had an impact.
393 |
394 | ViT-B/16 was trained with 176 A100 (40 GB) GPUS for ~61 hours, 10700 GPU-hours. Batch size per GPU was 192 for a global batch size of 33792.
395 |
396 | #### ViT-B/16+ 240x240
397 |
398 | The B/16+ 240x240 LAION400M training reached a top-1 ImageNet-1k zero-shot validation score of 69.21.
399 |
400 | This model is the same depth as the B/16, but increases the
401 | * vision width from 768 -> 896
402 | * text width from 512 -> 640
403 | * the resolution 224x224 -> 240x240 (196 -> 225 tokens)
404 |
405 |
406 |
407 | Unlike the B/16 run above, this model was a clean run with no dataset shuffling issues.
408 |
409 | ViT-B/16+ was trained with 224 A100 (40 GB) GPUS for ~61 hours, 13620 GPU-hours. Batch size per GPU was 160 for a global batch size of 35840.
410 |
411 | #### ViT-L/14 224x224
412 |
413 | The L/14 LAION-400M training reached a top-1 ImageNet-1k zero-shot validation score of 72.77.
414 |
415 |
416 |
417 | ViT-L/14 was trained with 400 A100 (40 GB) GPUS for ~127 hours, 50800 GPU-hours. Batch size per GPU was 96 for a global batch size of 38400. Grad checkpointing was enabled.
418 |
419 | ### LAION-2B (en) - https://laion.ai/laion-5b-a-new-era-of-open-large-scale-multi-modal-datasets/
420 |
421 | A ~2B sample subset of LAION-5B with english captions (https://huggingface.co/datasets/laion/laion2B-en)
422 |
423 | #### ViT-B/32 224x224
424 | A ViT-B/32 trained on LAION-2B, reaching a top-1 ImageNet-1k zero-shot accuracy of 65.62%.
425 |
426 |
427 |
428 | ViT-B/32 was trained with 112 A100 (40 GB) GPUs. The per-GPU batch size was 416 for a global batch size of 46592. Compute generously provided by [stability.ai](https://stability.ai/).
429 |
430 | #### YFCC-15M
431 |
432 | Below are checkpoints of models trained on YFCC-15M, along with their zero-shot top-1 accuracies on ImageNet and ImageNetV2. These models were trained using 8 GPUs and the same hyperparameters described in the "Sample running code" section, with the exception of `lr=5e-4` and `epochs=32`.
433 |
434 | * [ResNet-50](https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt) (32.7% / 27.9%)
435 | * [ResNet-101](https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt) (34.8% / 30.0%)
436 |
437 | #### CC12M - https://github.com/google-research-datasets/conceptual-12m
438 |
439 | * [ResNet-50](https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt) (36.45%)
440 |
441 | ### Pretrained Model Interface
442 |
443 | We offer a simple model interface to instantiate both pre-trained and untrained models.
444 |
445 | NOTE: Many existing checkpoints use the QuickGELU activation from the original OpenAI models. This activation is actually less efficient that native torch.nn.GELU in recent versions of PyTorch. The model defaults are now nn.GELU, so one should use model definitions with `-quickgelu` postfix for the OpenCLIP pretrained weights. All OpenAI pretrained weights will always default to QuickGELU. One can also use the non `-quickgelu` model definitions with pretrained weights using QuickGELU but there will be an accuracy drop, for fine-tune that will likely vanish for longer runs.
446 |
447 | Future trained models will use nn.GELU.
448 |
449 | ```python
450 | >>> import open_clip
451 | >>> open_clip.list_pretrained()
452 | [('RN50', 'openai'),
453 | ('RN50', 'yfcc15m'),
454 | ('RN50', 'cc12m'),
455 | ('RN50-quickgelu', 'openai'),
456 | ('RN50-quickgelu', 'yfcc15m'),
457 | ('RN50-quickgelu', 'cc12m'),
458 | ('RN101', 'openai'),
459 | ('RN101', 'yfcc15m'),
460 | ('RN101-quickgelu', 'openai'),
461 | ('RN101-quickgelu', 'yfcc15m'),
462 | ('RN50x4', 'openai'),
463 | ('RN50x16', 'openai'),
464 | ('RN50x64', 'openai'),
465 | ('ViT-B-32', 'openai'),
466 | ('ViT-B-32', 'laion2b_e16'),
467 | ('ViT-B-32', 'laion400m_e31'),
468 | ('ViT-B-32', 'laion400m_e32'),
469 | ('ViT-B-32-quickgelu', 'openai'),
470 | ('ViT-B-32-quickgelu', 'laion400m_e31'),
471 | ('ViT-B-32-quickgelu', 'laion400m_e32'),
472 | ('ViT-B-16', 'openai'),
473 | ('ViT-B-16', 'laion400m_e31'),
474 | ('ViT-B-16', 'laion400m_e32'),
475 | ('ViT-B-16-plus-240', 'laion400m_e31'),
476 | ('ViT-B-16-plus-240', 'laion400m_e32'),
477 | ('ViT-L-14', 'openai'),
478 | ('ViT-L-14-336', 'openai')]
479 |
480 | >>> model, train_transform, eval_transform = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_e16')
481 | ```
482 |
483 | ## Scaling trends
484 |
485 | The plot below shows how zero-shot performance of CLIP models varies as we scale the number of samples used for training. Zero-shot performance increases steadily for both ImageNet and [ImageNetV2](https://arxiv.org/abs/1902.10811), and is far from saturated at ~15M samples.
486 |
487 |
488 |
489 | ## Why are low-accuracy CLIP models interesting?
490 |
491 | **TL;DR:** CLIP models have high effective robustness, even at small scales.
492 |
493 | CLIP models are particularly intriguing because they are more robust to natural distribution shifts (see Section 3.3 in the [CLIP paper](https://arxiv.org/abs/2103.00020)).
494 | This phenomena is illustrated by the figure below, with ImageNet accuracy on the x-axis
495 | and [ImageNetV2](https://arxiv.org/abs/1902.10811) (a reproduction of the ImageNet validation set with distribution shift) accuracy on the y-axis.
496 | Standard training denotes training on the ImageNet train set and the CLIP zero-shot models
497 | are shown as stars.
498 |
499 | 
500 |
501 | As observed by [Taori et al., 2020](https://arxiv.org/abs/2007.00644) and [Miller et al., 2021](https://arxiv.org/abs/2107.04649), the in-distribution
502 | and out-of-distribution accuracies of models trained on ImageNet follow a predictable linear trend (the red line in the above plot). *Effective robustness*
503 | quantifies robustness as accuracy beyond this baseline, i.e., how far a model lies above the red line. Ideally a model would not suffer from distribution shift and fall on the y = x line ([trained human labelers are within a percentage point of the y = x line](http://proceedings.mlr.press/v119/shankar20c.html)).
504 |
505 | Even though the CLIP models trained with
506 | this codebase achieve much lower accuracy than those trained by OpenAI, our models still lie on the same
507 | trend of improved effective robustness (the purple line). Therefore, we can study what makes
508 | CLIP robust without requiring industrial-scale compute.
509 |
510 | For more information on effective robustness, please see:
511 |
512 | - [Recht et al., 2019](https://arxiv.org/abs/1902.10811).
513 | - [Taori et al., 2020](https://arxiv.org/abs/2007.00644).
514 | - [Miller et al., 2021](https://arxiv.org/abs/2107.04649).
515 |
516 | To know more about the factors that contribute to CLIP's robustness refer to [Fang et al., 2022](https://arxiv.org/abs/2205.01397).
517 |
518 | ## Acknowledgments
519 |
520 | We gratefully acknowledge the Gauss Centre for Supercomputing e.V. (www.gauss-centre.eu) for funding this part of work by providing computing time through the John von Neumann Institute for Computing (NIC) on the GCS Supercomputer JUWELS Booster at Jülich Supercomputing Centre (JSC).
521 |
522 | ## The Team
523 |
524 | Current development of this repository is led by [Ross Wightman](https://rwightman.com/), [Cade Gordon](http://cadegordon.io/), and [Vaishaal Shankar](http://vaishaal.com/).
525 |
526 | The original version of this repository is from a group of researchers at UW, Google, Stanford, Amazon, Columbia, and Berkeley.
527 |
528 | [Gabriel Ilharco*](http://gabrielilharco.com/), [Mitchell Wortsman*](https://mitchellnw.github.io/), [Nicholas Carlini](https://nicholas.carlini.com/), [Rohan Taori](https://www.rohantaori.com/), [Achal Dave](http://www.achaldave.com/), [Vaishaal Shankar](http://vaishaal.com/), [John Miller](https://people.eecs.berkeley.edu/~miller_john/), [Hongseok Namkoong](https://hsnamkoong.github.io/), [Hannaneh Hajishirzi](https://homes.cs.washington.edu/~hannaneh/), [Ali Farhadi](https://homes.cs.washington.edu/~ali/), [Ludwig Schmidt](https://people.csail.mit.edu/ludwigs/)
529 |
530 | Special thanks to [Jong Wook Kim](https://jongwook.kim/) and [Alec Radford](https://github.com/Newmu) for help with reproducing CLIP!
531 |
532 | ## Citing
533 |
534 | If you found this repository useful, please consider citing:
535 | ```bibtex
536 | @software{ilharco_gabriel_2021_5143773,
537 | author = {Ilharco, Gabriel and
538 | Wortsman, Mitchell and
539 | Wightman, Ross and
540 | Gordon, Cade and
541 | Carlini, Nicholas and
542 | Taori, Rohan and
543 | Dave, Achal and
544 | Shankar, Vaishaal and
545 | Namkoong, Hongseok and
546 | Miller, John and
547 | Hajishirzi, Hannaneh and
548 | Farhadi, Ali and
549 | Schmidt, Ludwig},
550 | title = {OpenCLIP},
551 | month = jul,
552 | year = 2021,
553 | note = {If you use this software, please cite it as below.},
554 | publisher = {Zenodo},
555 | version = {0.1},
556 | doi = {10.5281/zenodo.5143773},
557 | url = {https://doi.org/10.5281/zenodo.5143773}
558 | }
559 | ```
560 |
561 | ```bibtex
562 | @inproceedings{Radford2021LearningTV,
563 | title={Learning Transferable Visual Models From Natural Language Supervision},
564 | author={Alec Radford and Jong Wook Kim and Chris Hallacy and A. Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever},
565 | booktitle={ICML},
566 | year={2021}
567 | }
568 | ```
569 |
570 | [](https://zenodo.org/badge/latestdoi/390536799)
571 |
--------------------------------------------------------------------------------
/docs/CLIP.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vinid/neg_clip/e680f8e44ab9f081f62094232d8ee3b77f8574f7/docs/CLIP.png
--------------------------------------------------------------------------------
/docs/clip_conceptual_captions.md:
--------------------------------------------------------------------------------
1 | ## Additional training curves for CLIP on Conceptual Captions
2 |
3 | # Zero shot accuracy
4 | 
5 |
6 | # Training loss curve
7 | 
8 |
9 | # Validation loss curve
10 | 
11 |
12 | # Validation recall
13 | 
--------------------------------------------------------------------------------
/docs/clip_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vinid/neg_clip/e680f8e44ab9f081f62094232d8ee3b77f8574f7/docs/clip_loss.png
--------------------------------------------------------------------------------
/docs/clip_recall.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vinid/neg_clip/e680f8e44ab9f081f62094232d8ee3b77f8574f7/docs/clip_recall.png
--------------------------------------------------------------------------------
/docs/clip_val_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vinid/neg_clip/e680f8e44ab9f081f62094232d8ee3b77f8574f7/docs/clip_val_loss.png
--------------------------------------------------------------------------------
/docs/clip_zeroshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vinid/neg_clip/e680f8e44ab9f081f62094232d8ee3b77f8574f7/docs/clip_zeroshot.png
--------------------------------------------------------------------------------
/docs/effective_robustness.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vinid/neg_clip/e680f8e44ab9f081f62094232d8ee3b77f8574f7/docs/effective_robustness.png
--------------------------------------------------------------------------------
/docs/laion2b_clip_zeroshot_b32.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vinid/neg_clip/e680f8e44ab9f081f62094232d8ee3b77f8574f7/docs/laion2b_clip_zeroshot_b32.png
--------------------------------------------------------------------------------
/docs/laion_clip_zeroshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vinid/neg_clip/e680f8e44ab9f081f62094232d8ee3b77f8574f7/docs/laion_clip_zeroshot.png
--------------------------------------------------------------------------------
/docs/laion_clip_zeroshot_b16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vinid/neg_clip/e680f8e44ab9f081f62094232d8ee3b77f8574f7/docs/laion_clip_zeroshot_b16.png
--------------------------------------------------------------------------------
/docs/laion_clip_zeroshot_b16_plus_240.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vinid/neg_clip/e680f8e44ab9f081f62094232d8ee3b77f8574f7/docs/laion_clip_zeroshot_b16_plus_240.png
--------------------------------------------------------------------------------
/docs/laion_clip_zeroshot_l14.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vinid/neg_clip/e680f8e44ab9f081f62094232d8ee3b77f8574f7/docs/laion_clip_zeroshot_l14.png
--------------------------------------------------------------------------------
/docs/laion_openai_compare_b32.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vinid/neg_clip/e680f8e44ab9f081f62094232d8ee3b77f8574f7/docs/laion_openai_compare_b32.jpg
--------------------------------------------------------------------------------
/docs/scaling.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vinid/neg_clip/e680f8e44ab9f081f62094232d8ee3b77f8574f7/docs/scaling.png
--------------------------------------------------------------------------------
/requirements-test.txt:
--------------------------------------------------------------------------------
1 | pytest-xdist==2.5.0
2 | pytest==7.0.1
--------------------------------------------------------------------------------
/requirements-training.txt:
--------------------------------------------------------------------------------
1 | torch>=1.9.0
2 | torchvision
3 | webdataset>=0.2.5
4 | regex
5 | ftfy
6 | tqdm
7 | pandas
8 | braceexpand
9 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.9.0
2 | torchvision
3 | regex
4 | ftfy
5 | tqdm
6 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """ Setup
2 | """
3 | from setuptools import setup, find_packages
4 | from codecs import open
5 | from os import path
6 |
7 | here = path.abspath(path.dirname(__file__))
8 |
9 | # Get the long description from the README file
10 | with open(path.join(here, 'README.md'), encoding='utf-8') as f:
11 | long_description = f.read()
12 |
13 | exec(open('src/open_clip/version.py').read())
14 | setup(
15 | name='open_clip_torch',
16 | version=__version__,
17 | description='OpenCLIP',
18 | long_description=long_description,
19 | long_description_content_type='text/markdown',
20 | url='https://github.com/mlfoundations/open_clip',
21 | author='',
22 | author_email='',
23 | classifiers=[
24 | # How mature is this project? Common values are
25 | # 3 - Alpha
26 | # 4 - Beta
27 | # 5 - Production/Stable
28 | 'Development Status :: 3 - Alpha',
29 | 'Intended Audience :: Education',
30 | 'Intended Audience :: Science/Research',
31 | 'License :: OSI Approved :: Apache Software License',
32 | 'Programming Language :: Python :: 3.7',
33 | 'Programming Language :: Python :: 3.8',
34 | 'Programming Language :: Python :: 3.9',
35 | 'Programming Language :: Python :: 3.10',
36 | 'Topic :: Scientific/Engineering',
37 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
38 | 'Topic :: Software Development',
39 | 'Topic :: Software Development :: Libraries',
40 | 'Topic :: Software Development :: Libraries :: Python Modules',
41 | ],
42 |
43 | # Note that this is a string of words separated by whitespace, not a list.
44 | keywords='CLIP pretrained',
45 | package_dir={'': 'src'},
46 | packages=find_packages(where='src', exclude=['training']),
47 | include_package_data=True,
48 | install_requires=[
49 | 'torch >= 1.9',
50 | 'torchvision',
51 | 'ftfy',
52 | 'regex',
53 | 'tqdm',
54 | ],
55 | python_requires='>=3.7',
56 | )
57 |
--------------------------------------------------------------------------------
/src/data/gather_cc.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import os
3 | import multiprocessing as mp
4 | from io import BytesIO
5 | import numpy as np
6 | import PIL
7 | from PIL import Image
8 | import pickle
9 | import sys
10 |
11 |
12 | def grab(line):
13 | """
14 | Download a single image from the TSV.
15 | """
16 | uid, split, line = line
17 | try:
18 | caption, url = line.split("\t")[:2]
19 | except:
20 | print("Parse error")
21 | return
22 |
23 | if os.path.exists(ROOT+"/%s/%d/%d.jpg"%(split,uid%1000,uid)):
24 | print("Finished", uid)
25 | return uid, caption, url
26 |
27 | # Let's not crash if anythign weird happens
28 | try:
29 | dat = requests.get(url, timeout=20)
30 | if dat.status_code != 200:
31 | print("404 file", url)
32 | return
33 |
34 | # Try to parse this as an Image file, we'll fail out if not
35 | im = Image.open(BytesIO(dat.content))
36 | im.thumbnail((512, 512), PIL.Image.BICUBIC)
37 | if min(*im.size) < max(*im.size)/3:
38 | print("Too small", url)
39 | return
40 |
41 | im.save(ROOT+"/%s/%d/%d.jpg"%(split,uid%1000,uid))
42 |
43 | # Another try/catch just because sometimes saving and re-loading
44 | # the image is different than loading it once.
45 | try:
46 | o = Image.open(ROOT+"/%s/%d/%d.jpg"%(split,uid%1000,uid))
47 | o = np.array(o)
48 |
49 | print("Success", o.shape, uid, url)
50 | return uid, caption, url
51 | except:
52 | print("Failed", uid, url)
53 |
54 | except Exception as e:
55 | print("Unknown error", e)
56 | pass
57 |
58 | if __name__ == "__main__":
59 | ROOT = "cc_data"
60 |
61 | if not os.path.exists(ROOT):
62 | os.mkdir(ROOT)
63 | os.mkdir(os.path.join(ROOT,"train"))
64 | os.mkdir(os.path.join(ROOT,"val"))
65 | for i in range(1000):
66 | os.mkdir(os.path.join(ROOT,"train", str(i)))
67 | os.mkdir(os.path.join(ROOT,"val", str(i)))
68 |
69 |
70 | p = mp.Pool(300)
71 |
72 | for tsv in sys.argv[1:]:
73 | print("Processing file", tsv)
74 | assert 'val' in tsv.lower() or 'train' in tsv.lower()
75 | split = 'val' if 'val' in tsv.lower() else 'train'
76 | results = p.map(grab,
77 | [(i,split,x) for i,x in enumerate(open(tsv).read().split("\n"))])
78 |
79 | out = open(tsv.replace(".tsv","_output.csv"),"w")
80 | out.write("title\tfilepath\n")
81 |
82 | for row in results:
83 | if row is None: continue
84 | id, caption, url = row
85 | fp = os.path.join(ROOT, split, str(id % 1000), str(id) + ".jpg")
86 | if os.path.exists(fp):
87 | out.write("%s\t%s\n"%(caption,fp))
88 | else:
89 | print("Drop", id)
90 | out.close()
91 |
92 | p.close()
93 |
94 |
--------------------------------------------------------------------------------
/src/open_clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .factory import list_models, create_model, create_model_and_transforms, add_model_config
2 | from .loss import ClipLoss
3 | from .model import CLIP, CLIPTextCfg, CLIPVisionCfg, convert_weights_to_fp16, trace_model
4 | from .openai import load_openai_model, list_openai_models
5 | from .pretrained import list_pretrained, list_pretrained_tag_models, list_pretrained_model_tags,\
6 | get_pretrained_url, download_pretrained
7 | from .tokenizer import SimpleTokenizer, tokenize
8 | from .transform import image_transform
9 |
--------------------------------------------------------------------------------
/src/open_clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vinid/neg_clip/e680f8e44ab9f081f62094232d8ee3b77f8574f7/src/open_clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/src/open_clip/factory.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | import pathlib
5 | import re
6 | from copy import deepcopy
7 | from pathlib import Path
8 | from typing import Optional, Tuple
9 |
10 | import torch
11 |
12 | from .model import CLIP, convert_weights_to_fp16, resize_pos_embed
13 | from .openai import load_openai_model
14 | from .pretrained import get_pretrained_url, download_pretrained
15 | from .transform import image_transform
16 |
17 |
18 | _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
19 | _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
20 |
21 |
22 | def _natural_key(string_):
23 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
24 |
25 |
26 | def _rescan_model_configs():
27 | global _MODEL_CONFIGS
28 |
29 | config_ext = ('.json',)
30 | config_files = []
31 | for config_path in _MODEL_CONFIG_PATHS:
32 | if config_path.is_file() and config_path.suffix in config_ext:
33 | config_files.append(config_path)
34 | elif config_path.is_dir():
35 | for ext in config_ext:
36 | config_files.extend(config_path.glob(f'*{ext}'))
37 |
38 | for cf in config_files:
39 | with open(cf, 'r') as f:
40 | model_cfg = json.load(f)
41 | if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
42 | _MODEL_CONFIGS[cf.stem] = model_cfg
43 |
44 | _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
45 |
46 |
47 | _rescan_model_configs() # initial populate of model config registry
48 |
49 |
50 | def load_state_dict(checkpoint_path: str, map_location='cpu'):
51 | checkpoint = torch.load(checkpoint_path, map_location=map_location)
52 | if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
53 | state_dict = checkpoint['state_dict']
54 | else:
55 | state_dict = checkpoint
56 | if next(iter(state_dict.items()))[0].startswith('module'):
57 | state_dict = {k[7:]: v for k, v in state_dict.items()}
58 | return state_dict
59 |
60 |
61 | def load_checkpoint(model, checkpoint_path, strict=True):
62 | state_dict = load_state_dict(checkpoint_path)
63 | resize_pos_embed(state_dict, model)
64 | incompatible_keys = model.load_state_dict(state_dict, strict=strict)
65 | return incompatible_keys
66 |
67 |
68 | def create_model(
69 | model_name: str,
70 | pretrained: str = '',
71 | precision: str = 'fp32',
72 | device: torch.device = torch.device('cpu'),
73 | jit: bool = False,
74 | force_quick_gelu: bool = False,
75 | pretrained_image: bool = False,
76 | ):
77 | model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
78 |
79 | if pretrained.lower() == 'openai':
80 | logging.info(f'Loading pretrained {model_name} from OpenAI.')
81 | model = load_openai_model(model_name, device=device, jit=jit)
82 | # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
83 | if precision == "amp" or precision == "fp32":
84 | model = model.float()
85 | else:
86 | if model_name in _MODEL_CONFIGS:
87 | logging.info(f'Loading {model_name} model config.')
88 | model_cfg = deepcopy(_MODEL_CONFIGS[model_name])
89 | else:
90 | logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
91 | raise RuntimeError(f'Model config for {model_name} not found.')
92 |
93 | if force_quick_gelu:
94 | # override for use of QuickGELU on non-OpenAI transformer models
95 | model_cfg["quick_gelu"] = True
96 |
97 | if pretrained_image:
98 | if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
99 | # pretrained weight loading for timm models set via vision_cfg
100 | model_cfg['vision_cfg']['timm_model_pretrained'] = True
101 | else:
102 | assert False, 'pretrained image towers currently only supported for timm models'
103 |
104 | model = CLIP(**model_cfg)
105 |
106 | if pretrained:
107 | checkpoint_path = ''
108 | url = get_pretrained_url(model_name, pretrained)
109 | if url:
110 | checkpoint_path = download_pretrained(url)
111 | elif os.path.exists(pretrained):
112 | checkpoint_path = pretrained
113 |
114 | if checkpoint_path:
115 | logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
116 | load_checkpoint(model, checkpoint_path)
117 | else:
118 | logging.warning(f'Pretrained weights ({pretrained}) not found for model {model_name}.')
119 | raise RuntimeError(f'Pretrained weights ({pretrained}) not found for model {model_name}.')
120 |
121 | model.to(device=device)
122 | if precision == "fp16":
123 | assert device.type != 'cpu'
124 | convert_weights_to_fp16(model)
125 |
126 | if jit:
127 | model = torch.jit.script(model)
128 |
129 | return model
130 |
131 |
132 | def create_model_and_transforms(
133 | model_name: str,
134 | pretrained: str = '',
135 | precision: str = 'fp32',
136 | device: torch.device = torch.device('cpu'),
137 | jit: bool = False,
138 | force_quick_gelu: bool = False,
139 | pretrained_image: bool = False,
140 | mean: Optional[Tuple[float, ...]] = None,
141 | std: Optional[Tuple[float, ...]] = None,
142 | ):
143 | model = create_model(
144 | model_name, pretrained, precision, device, jit,
145 | force_quick_gelu=force_quick_gelu,
146 | pretrained_image=pretrained_image)
147 | preprocess_train = image_transform(model.visual.image_size, is_train=True, mean=mean, std=std)
148 | preprocess_val = image_transform(model.visual.image_size, is_train=False, mean=mean, std=std)
149 | return model, preprocess_train, preprocess_val
150 |
151 |
152 | def list_models():
153 | """ enumerate available model architectures based on config files """
154 | return list(_MODEL_CONFIGS.keys())
155 |
156 |
157 | def add_model_config(path):
158 | """ add model config path or file and update registry """
159 | if not isinstance(path, Path):
160 | path = Path(path)
161 | _MODEL_CONFIG_PATHS.append(path)
162 | _rescan_model_configs()
163 |
--------------------------------------------------------------------------------
/src/open_clip/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 |
5 | try:
6 | import torch.distributed.nn
7 | from torch import distributed as dist
8 | has_distributed = True
9 | except ImportError:
10 | has_distributed = False
11 |
12 | try:
13 | import horovod.torch as hvd
14 | except ImportError:
15 | hvd = None
16 |
17 |
18 | def gather_features(
19 | image_features,
20 | text_features,
21 | local_loss=False,
22 | gather_with_grad=False,
23 | rank=0,
24 | world_size=1,
25 | use_horovod=False
26 | ):
27 | assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
28 | if use_horovod:
29 | assert hvd is not None, 'Please install horovod'
30 | if gather_with_grad:
31 | all_image_features = hvd.allgather(image_features)
32 | all_text_features = hvd.allgather(text_features)
33 | else:
34 | with torch.no_grad():
35 | all_image_features = hvd.allgather(image_features)
36 | all_text_features = hvd.allgather(text_features)
37 | if not local_loss:
38 | # ensure grads for local rank when all_* features don't have a gradient
39 | gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
40 | gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
41 | gathered_image_features[rank] = image_features
42 | gathered_text_features[rank] = text_features
43 | all_image_features = torch.cat(gathered_image_features, dim=0)
44 | all_text_features = torch.cat(gathered_text_features, dim=0)
45 | else:
46 | # We gather tensors from all gpus
47 | if gather_with_grad:
48 | all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
49 | all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
50 | else:
51 | gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
52 | gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
53 | dist.all_gather(gathered_image_features, image_features)
54 | dist.all_gather(gathered_text_features, text_features)
55 | if not local_loss:
56 | # ensure grads for local rank when all_* features don't have a gradient
57 | gathered_image_features[rank] = image_features
58 | gathered_text_features[rank] = text_features
59 | all_image_features = torch.cat(gathered_image_features, dim=0)
60 | all_text_features = torch.cat(gathered_text_features, dim=0)
61 |
62 | return all_image_features, all_text_features
63 |
64 |
65 | class ClipLoss(nn.Module):
66 |
67 | def __init__(
68 | self,
69 | local_loss=False,
70 | gather_with_grad=False,
71 | cache_labels=False,
72 | rank=0,
73 | world_size=1,
74 | use_horovod=False,
75 | ):
76 | super().__init__()
77 | self.local_loss = local_loss
78 | self.gather_with_grad = gather_with_grad
79 | self.cache_labels = cache_labels
80 | self.rank = rank
81 | self.world_size = world_size
82 | self.use_horovod = use_horovod
83 |
84 | # cache state
85 | self.prev_num_logits = 0
86 | self.labels = {}
87 |
88 | def forward(self, image_features, text_features, logit_scale):
89 | device = image_features.device
90 | if self.world_size > 1:
91 | all_image_features, all_text_features = gather_features(
92 | image_features, text_features,
93 | self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
94 |
95 | if self.local_loss:
96 | logits_per_image = logit_scale * image_features @ all_text_features.T
97 | logits_per_text = logit_scale * text_features @ all_image_features.T
98 | else:
99 | logits_per_image = logit_scale * all_image_features @ all_text_features.T
100 | logits_per_text = logits_per_image.T
101 | else:
102 | logits_per_image = logit_scale * image_features @ text_features.T
103 | logits_per_text = logit_scale * text_features @ image_features.T
104 |
105 | # calculated ground-truth and cache if enabled
106 | num_logits = logits_per_image.shape[0]
107 | if self.prev_num_logits != num_logits or device not in self.labels:
108 | labels = torch.arange(num_logits, device=device, dtype=torch.long)
109 | if self.world_size > 1 and self.local_loss:
110 | labels = labels + num_logits * self.rank
111 | if self.cache_labels:
112 | self.labels[device] = labels
113 | self.prev_num_logits = num_logits
114 | else:
115 | labels = self.labels[device]
116 |
117 | total_loss = (
118 | F.cross_entropy(logits_per_image, labels) +
119 | F.cross_entropy(logits_per_text[:len(logits_per_image)], labels)
120 | ) / 2
121 | return total_loss
122 |
--------------------------------------------------------------------------------
/src/open_clip/model.py:
--------------------------------------------------------------------------------
1 | """ CLIP Model
2 |
3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4 | """
5 | from collections import OrderedDict
6 | from dataclasses import dataclass
7 | import logging
8 | import math
9 | from typing import Tuple, Union, Callable, Optional
10 |
11 | import numpy as np
12 | import torch
13 | import torch.nn.functional as F
14 | from torch import nn
15 | from torch.utils.checkpoint import checkpoint
16 |
17 | from .timm_model import TimmModel
18 | from .utils import freeze_batch_norm_2d, to_2tuple
19 |
20 |
21 | class Bottleneck(nn.Module):
22 | expansion = 4
23 |
24 | def __init__(self, inplanes, planes, stride=1):
25 | super().__init__()
26 |
27 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
28 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
29 | self.bn1 = nn.BatchNorm2d(planes)
30 | self.relu1 = nn.ReLU(inplace=True)
31 |
32 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
33 | self.bn2 = nn.BatchNorm2d(planes)
34 | self.relu2 = nn.ReLU(inplace=True)
35 |
36 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
37 |
38 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
39 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
40 | self.relu3 = nn.ReLU(inplace=True)
41 |
42 | self.downsample = None
43 | self.stride = stride
44 |
45 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
46 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
47 | self.downsample = nn.Sequential(OrderedDict([
48 | ("-1", nn.AvgPool2d(stride)),
49 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
50 | ("1", nn.BatchNorm2d(planes * self.expansion))
51 | ]))
52 |
53 | def forward(self, x: torch.Tensor):
54 | identity = x
55 |
56 | out = self.relu1(self.bn1(self.conv1(x)))
57 | out = self.relu2(self.bn2(self.conv2(out)))
58 | out = self.avgpool(out)
59 | out = self.bn3(self.conv3(out))
60 |
61 | if self.downsample is not None:
62 | identity = self.downsample(x)
63 |
64 | out += identity
65 | out = self.relu3(out)
66 | return out
67 |
68 |
69 | class AttentionPool2d(nn.Module):
70 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
71 | super().__init__()
72 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
73 | self.k_proj = nn.Linear(embed_dim, embed_dim)
74 | self.q_proj = nn.Linear(embed_dim, embed_dim)
75 | self.v_proj = nn.Linear(embed_dim, embed_dim)
76 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
77 | self.num_heads = num_heads
78 |
79 | def forward(self, x):
80 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
81 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
82 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
83 | x, _ = F.multi_head_attention_forward(
84 | query=x, key=x, value=x,
85 | embed_dim_to_check=x.shape[-1],
86 | num_heads=self.num_heads,
87 | q_proj_weight=self.q_proj.weight,
88 | k_proj_weight=self.k_proj.weight,
89 | v_proj_weight=self.v_proj.weight,
90 | in_proj_weight=None,
91 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
92 | bias_k=None,
93 | bias_v=None,
94 | add_zero_attn=False,
95 | dropout_p=0,
96 | out_proj_weight=self.c_proj.weight,
97 | out_proj_bias=self.c_proj.bias,
98 | use_separate_proj_weight=True,
99 | training=self.training,
100 | need_weights=False
101 | )
102 |
103 | return x[0]
104 |
105 |
106 | class ModifiedResNet(nn.Module):
107 | """
108 | A ResNet class that is similar to torchvision's but contains the following changes:
109 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
110 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
111 | - The final pooling layer is a QKV attention instead of an average pool
112 | """
113 |
114 | def __init__(self, layers, output_dim, heads, image_size=224, width=64):
115 | super().__init__()
116 | self.output_dim = output_dim
117 | self.image_size = image_size
118 |
119 | # the 3-layer stem
120 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
121 | self.bn1 = nn.BatchNorm2d(width // 2)
122 | self.relu1 = nn.ReLU(inplace=True)
123 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
124 | self.bn2 = nn.BatchNorm2d(width // 2)
125 | self.relu2 = nn.ReLU(inplace=True)
126 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
127 | self.bn3 = nn.BatchNorm2d(width)
128 | self.relu3 = nn.ReLU(inplace=True)
129 | self.avgpool = nn.AvgPool2d(2)
130 |
131 | # residual layers
132 | self._inplanes = width # this is a *mutable* variable used during construction
133 | self.layer1 = self._make_layer(width, layers[0])
134 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
135 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
136 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
137 |
138 | embed_dim = width * 32 # the ResNet feature dimension
139 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
140 |
141 | self.init_parameters()
142 |
143 | def _make_layer(self, planes, blocks, stride=1):
144 | layers = [Bottleneck(self._inplanes, planes, stride)]
145 |
146 | self._inplanes = planes * Bottleneck.expansion
147 | for _ in range(1, blocks):
148 | layers.append(Bottleneck(self._inplanes, planes))
149 |
150 | return nn.Sequential(*layers)
151 |
152 | def init_parameters(self):
153 | if self.attnpool is not None:
154 | std = self.attnpool.c_proj.in_features ** -0.5
155 | nn.init.normal_(self.attnpool.q_proj.weight, std=std)
156 | nn.init.normal_(self.attnpool.k_proj.weight, std=std)
157 | nn.init.normal_(self.attnpool.v_proj.weight, std=std)
158 | nn.init.normal_(self.attnpool.c_proj.weight, std=std)
159 |
160 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
161 | for name, param in resnet_block.named_parameters():
162 | if name.endswith("bn3.weight"):
163 | nn.init.zeros_(param)
164 |
165 | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
166 | assert unlocked_groups == 0, 'partial locking not currently supported for this model'
167 | for param in self.parameters():
168 | param.requires_grad = False
169 | if freeze_bn_stats:
170 | freeze_batch_norm_2d(self)
171 |
172 | @torch.jit.ignore
173 | def set_grad_checkpointing(self, enable=True):
174 | # FIXME support for non-transformer
175 | pass
176 |
177 | def stem(self, x):
178 | x = self.relu1(self.bn1(self.conv1(x)))
179 | x = self.relu2(self.bn2(self.conv2(x)))
180 | x = self.relu3(self.bn3(self.conv3(x)))
181 | x = self.avgpool(x)
182 | return x
183 |
184 | def forward(self, x):
185 | x = self.stem(x)
186 | x = self.layer1(x)
187 | x = self.layer2(x)
188 | x = self.layer3(x)
189 | x = self.layer4(x)
190 | x = self.attnpool(x)
191 |
192 | return x
193 |
194 |
195 | class LayerNorm(nn.LayerNorm):
196 | """Subclass torch's LayerNorm to handle fp16."""
197 |
198 | def forward(self, x: torch.Tensor):
199 | orig_type = x.dtype
200 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
201 | return x.to(orig_type)
202 |
203 |
204 | class QuickGELU(nn.Module):
205 | # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
206 | def forward(self, x: torch.Tensor):
207 | return x * torch.sigmoid(1.702 * x)
208 |
209 |
210 | class ResidualAttentionBlock(nn.Module):
211 | def __init__(self, d_model: int, n_head: int, mlp_ratio: float = 4.0, act_layer: Callable = nn.GELU):
212 | super().__init__()
213 |
214 | self.attn = nn.MultiheadAttention(d_model, n_head)
215 | self.ln_1 = LayerNorm(d_model)
216 | mlp_width = int(d_model * mlp_ratio)
217 | self.mlp = nn.Sequential(OrderedDict([
218 | ("c_fc", nn.Linear(d_model, mlp_width)),
219 | ("gelu", act_layer()),
220 | ("c_proj", nn.Linear(mlp_width, d_model))
221 | ]))
222 | self.ln_2 = LayerNorm(d_model)
223 |
224 | def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
225 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
226 |
227 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
228 | x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
229 | x = x + self.mlp(self.ln_2(x))
230 | return x
231 |
232 |
233 | class Transformer(nn.Module):
234 | def __init__(self, width: int, layers: int, heads: int, mlp_ratio: float = 4.0, act_layer: Callable = nn.GELU):
235 | super().__init__()
236 | self.width = width
237 | self.layers = layers
238 | self.grad_checkpointing = False
239 |
240 | self.resblocks = nn.ModuleList([
241 | ResidualAttentionBlock(width, heads, mlp_ratio, act_layer=act_layer)
242 | for _ in range(layers)
243 | ])
244 |
245 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
246 | for r in self.resblocks:
247 | if self.grad_checkpointing and not torch.jit.is_scripting():
248 | x = checkpoint(r, x, attn_mask)
249 | else:
250 | x = r(x, attn_mask=attn_mask)
251 | return x
252 |
253 |
254 | class VisualTransformer(nn.Module):
255 | def __init__(
256 | self, image_size: int, patch_size: int, width: int, layers: int, heads: int, mlp_ratio: float,
257 | output_dim: int, act_layer: Callable = nn.GELU):
258 | super().__init__()
259 | self.image_size = to_2tuple(image_size)
260 | self.patch_size = to_2tuple(patch_size)
261 | self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1])
262 | self.output_dim = output_dim
263 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
264 |
265 | scale = width ** -0.5
266 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
267 | self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
268 | self.ln_pre = LayerNorm(width)
269 |
270 | self.transformer = Transformer(width, layers, heads, mlp_ratio, act_layer=act_layer)
271 |
272 | self.ln_post = LayerNorm(width)
273 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
274 |
275 | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
276 | assert unlocked_groups == 0, 'partial locking not currently supported for this model'
277 | for param in self.parameters():
278 | param.requires_grad = False
279 |
280 | @torch.jit.ignore
281 | def set_grad_checkpointing(self, enable=True):
282 | self.transformer.grad_checkpointing = enable
283 |
284 | def forward(self, x: torch.Tensor):
285 | x = self.conv1(x) # shape = [*, width, grid, grid]
286 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
287 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
288 | x = torch.cat(
289 | [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
290 | x], dim=1) # shape = [*, grid ** 2 + 1, width]
291 | x = x + self.positional_embedding.to(x.dtype)
292 | x = self.ln_pre(x)
293 |
294 | x = x.permute(1, 0, 2) # NLD -> LND
295 | x = self.transformer(x)
296 | x = x.permute(1, 0, 2) # LND -> NLD
297 |
298 | x = self.ln_post(x[:, 0, :])
299 |
300 | if self.proj is not None:
301 | x = x @ self.proj
302 |
303 | return x
304 |
305 |
306 | @dataclass
307 | class CLIPVisionCfg:
308 | layers: Union[Tuple[int, int, int, int], int] = 12
309 | width: int = 768
310 | head_width: int = 64
311 | mlp_ratio: float = 4.0
312 | patch_size: int = 16
313 | image_size: Union[Tuple[int, int], int] = 224
314 | timm_model_name: str = None # a valid model name overrides layers, width, patch_size
315 | timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
316 | timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
317 | timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
318 |
319 |
320 | @dataclass
321 | class CLIPTextCfg:
322 | context_length: int = 77
323 | vocab_size: int = 49408
324 | width: int = 512
325 | heads: int = 8
326 | layers: int = 12
327 |
328 |
329 | class CLIP(nn.Module):
330 | def __init__(
331 | self,
332 | embed_dim: int,
333 | vision_cfg: CLIPVisionCfg,
334 | text_cfg: CLIPTextCfg,
335 | quick_gelu: bool = False,
336 | ):
337 | super().__init__()
338 | if isinstance(vision_cfg, dict):
339 | vision_cfg = CLIPVisionCfg(**vision_cfg)
340 | if isinstance(text_cfg, dict):
341 | text_cfg = CLIPTextCfg(**text_cfg)
342 |
343 | self.context_length = text_cfg.context_length
344 |
345 | # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
346 | # memory efficient in recent PyTorch releases (>= 1.10).
347 | # NOTE: timm models always use native GELU regardless of quick_gelu flag.
348 | act_layer = QuickGELU if quick_gelu else nn.GELU
349 |
350 | if vision_cfg.timm_model_name:
351 | self.visual = TimmModel(
352 | vision_cfg.timm_model_name,
353 | pretrained=vision_cfg.timm_model_pretrained,
354 | pool=vision_cfg.timm_pool,
355 | proj=vision_cfg.timm_proj,
356 | embed_dim=embed_dim,
357 | image_size=vision_cfg.image_size
358 | )
359 | act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
360 | elif isinstance(vision_cfg.layers, (tuple, list)):
361 | vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
362 | self.visual = ModifiedResNet(
363 | layers=vision_cfg.layers,
364 | output_dim=embed_dim,
365 | heads=vision_heads,
366 | image_size=vision_cfg.image_size,
367 | width=vision_cfg.width
368 | )
369 | else:
370 | vision_heads = vision_cfg.width // vision_cfg.head_width
371 | self.visual = VisualTransformer(
372 | image_size=vision_cfg.image_size,
373 | patch_size=vision_cfg.patch_size,
374 | width=vision_cfg.width,
375 | layers=vision_cfg.layers,
376 | heads=vision_heads,
377 | mlp_ratio=vision_cfg.mlp_ratio,
378 | output_dim=embed_dim,
379 | act_layer=act_layer,
380 | )
381 |
382 | self.transformer = Transformer(
383 | width=text_cfg.width,
384 | layers=text_cfg.layers,
385 | heads=text_cfg.heads,
386 | act_layer=act_layer,
387 | )
388 |
389 | self.vocab_size = text_cfg.vocab_size
390 | self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width)
391 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, text_cfg.width))
392 | self.ln_final = LayerNorm(text_cfg.width)
393 |
394 | self.text_projection = nn.Parameter(torch.empty(text_cfg.width, embed_dim))
395 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
396 | self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
397 |
398 | self.init_parameters()
399 |
400 | def init_parameters(self):
401 | nn.init.normal_(self.token_embedding.weight, std=0.02)
402 | nn.init.normal_(self.positional_embedding, std=0.01)
403 | nn.init.constant_(self.logit_scale, np.log(1 / 0.07))
404 |
405 | if hasattr(self.visual, 'init_parameters'):
406 | self.visual.init_parameters()
407 |
408 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
409 | attn_std = self.transformer.width ** -0.5
410 | fc_std = (2 * self.transformer.width) ** -0.5
411 | for block in self.transformer.resblocks:
412 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
413 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
414 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
415 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
416 |
417 | if self.text_projection is not None:
418 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
419 |
420 | def build_attention_mask(self):
421 | # lazily create causal attention mask, with full attention between the vision tokens
422 | # pytorch uses additive attention mask; fill with -inf
423 | mask = torch.empty(self.context_length, self.context_length)
424 | mask.fill_(float("-inf"))
425 | mask.triu_(1) # zero out the lower diagonal
426 | return mask
427 |
428 | def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
429 | # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
430 | self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
431 |
432 | @torch.jit.ignore
433 | def set_grad_checkpointing(self, enable=True):
434 | self.visual.set_grad_checkpointing(enable)
435 | self.transformer.grad_checkpointing = enable
436 |
437 | def encode_image(self, image):
438 | return self.visual(image)
439 |
440 | def encode_text(self, text):
441 | x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
442 |
443 | x = x + self.positional_embedding
444 | x = x.permute(1, 0, 2) # NLD -> LND
445 | x = self.transformer(x, attn_mask=self.attn_mask)
446 | x = x.permute(1, 0, 2) # LND -> NLD
447 | x = self.ln_final(x)
448 |
449 | # x.shape = [batch_size, n_ctx, transformer.width]
450 | # take features from the eot embedding (eot_token is the highest number in each sequence)
451 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
452 |
453 | return x
454 |
455 | def forward(self, image, text):
456 | if image is None:
457 | return self.encode_text(text)
458 | elif text is None:
459 | return self.encode_image(image)
460 | image_features = self.encode_image(image)
461 | image_features = F.normalize(image_features, dim=-1)
462 |
463 | text_features = self.encode_text(text)
464 | text_features = F.normalize(text_features, dim=-1)
465 |
466 | return image_features, text_features, self.logit_scale.exp()
467 |
468 |
469 | def convert_weights_to_fp16(model: nn.Module):
470 | """Convert applicable model parameters to fp16"""
471 |
472 | def _convert_weights_to_fp16(l):
473 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
474 | l.weight.data = l.weight.data.half()
475 | if l.bias is not None:
476 | l.bias.data = l.bias.data.half()
477 |
478 | if isinstance(l, nn.MultiheadAttention):
479 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
480 | tensor = getattr(l, attr)
481 | if tensor is not None:
482 | tensor.data = tensor.data.half()
483 |
484 | for name in ["text_projection", "proj"]:
485 | if hasattr(l, name):
486 | attr = getattr(l, name)
487 | if attr is not None:
488 | attr.data = attr.data.half()
489 |
490 | model.apply(_convert_weights_to_fp16)
491 |
492 |
493 | def build_model_from_openai_state_dict(state_dict: dict):
494 | vit = "visual.proj" in state_dict
495 |
496 | if vit:
497 | vision_width = state_dict["visual.conv1.weight"].shape[0]
498 | vision_layers = len(
499 | [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
500 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
501 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
502 | image_size = vision_patch_size * grid_size
503 | else:
504 | counts: list = [
505 | len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
506 | vision_layers = tuple(counts)
507 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
508 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
509 | vision_patch_size = None
510 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
511 | image_size = output_width * 32
512 |
513 | embed_dim = state_dict["text_projection"].shape[1]
514 | context_length = state_dict["positional_embedding"].shape[0]
515 | vocab_size = state_dict["token_embedding.weight"].shape[0]
516 | transformer_width = state_dict["ln_final.weight"].shape[0]
517 | transformer_heads = transformer_width // 64
518 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
519 |
520 | vision_cfg = CLIPVisionCfg(
521 | layers=vision_layers,
522 | width=vision_width,
523 | patch_size=vision_patch_size,
524 | image_size=image_size,
525 | )
526 | text_cfg = CLIPTextCfg(
527 | context_length=context_length,
528 | vocab_size=vocab_size,
529 | width=transformer_width,
530 | heads=transformer_heads,
531 | layers=transformer_layers
532 | )
533 | model = CLIP(
534 | embed_dim,
535 | vision_cfg=vision_cfg,
536 | text_cfg=text_cfg,
537 | quick_gelu=True, # OpenAI models were trained with QuickGELU
538 | )
539 |
540 | for key in ["input_resolution", "context_length", "vocab_size"]:
541 | state_dict.pop(key, None)
542 |
543 | convert_weights_to_fp16(model)
544 | model.load_state_dict(state_dict)
545 | return model.eval()
546 |
547 |
548 | def trace_model(model, batch_size=256, device=torch.device('cpu')):
549 | model.eval()
550 | image_size = model.visual.image_size
551 | example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
552 | example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
553 | model = torch.jit.trace_module(
554 | model,
555 | inputs=dict(
556 | forward=(example_images, example_text),
557 | encode_text=(example_text,),
558 | encode_image=(example_images,)
559 | ))
560 | model.visual.image_size = image_size
561 | return model
562 |
563 |
564 | def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
565 | # Rescale the grid of position embeddings when loading from state_dict
566 | old_pos_embed = state_dict.get('visual.positional_embedding', None)
567 | if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
568 | return
569 | grid_size = to_2tuple(model.visual.grid_size)
570 | extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
571 | new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
572 | if new_seq_len == old_pos_embed.shape[0]:
573 | return
574 |
575 | if extra_tokens:
576 | pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
577 | else:
578 | pos_emb_tok, pos_emb_img = None, old_pos_embed
579 | old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
580 |
581 | logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
582 | pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
583 | pos_emb_img = F.interpolate(
584 | pos_emb_img,
585 | size=grid_size,
586 | mode=interpolation,
587 | align_corners=True,
588 | )
589 | pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
590 | if pos_emb_tok is not None:
591 | new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
592 | else:
593 | new_pos_embed = pos_emb_img
594 | state_dict['visual.positional_embedding'] = new_pos_embed
595 |
--------------------------------------------------------------------------------
/src/open_clip/model_configs/RN101-quickgelu.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "quick_gelu": true,
4 | "vision_cfg": {
5 | "image_size": 224,
6 | "layers": [
7 | 3,
8 | 4,
9 | 23,
10 | 3
11 | ],
12 | "width": 64,
13 | "patch_size": null
14 | },
15 | "text_cfg": {
16 | "context_length": 77,
17 | "vocab_size": 49408,
18 | "width": 512,
19 | "heads": 8,
20 | "layers": 12
21 | }
22 | }
--------------------------------------------------------------------------------
/src/open_clip/model_configs/RN101.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": [
6 | 3,
7 | 4,
8 | 23,
9 | 3
10 | ],
11 | "width": 64,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 512,
18 | "heads": 8,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/src/open_clip/model_configs/RN50-quickgelu.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "quick_gelu": true,
4 | "vision_cfg": {
5 | "image_size": 224,
6 | "layers": [
7 | 3,
8 | 4,
9 | 6,
10 | 3
11 | ],
12 | "width": 64,
13 | "patch_size": null
14 | },
15 | "text_cfg": {
16 | "context_length": 77,
17 | "vocab_size": 49408,
18 | "width": 512,
19 | "heads": 8,
20 | "layers": 12
21 | }
22 | }
23 |
--------------------------------------------------------------------------------
/src/open_clip/model_configs/RN50.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": [
6 | 3,
7 | 4,
8 | 6,
9 | 3
10 | ],
11 | "width": 64,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 512,
18 | "heads": 8,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/src/open_clip/model_configs/RN50x16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 384,
5 | "layers": [
6 | 6,
7 | 8,
8 | 18,
9 | 8
10 | ],
11 | "width": 96,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 768,
18 | "heads": 12,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/src/open_clip/model_configs/RN50x4.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 640,
3 | "vision_cfg": {
4 | "image_size": 288,
5 | "layers": [
6 | 4,
7 | 6,
8 | 10,
9 | 6
10 | ],
11 | "width": 80,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 640,
18 | "heads": 10,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/src/open_clip/model_configs/ViT-B-16-plus-240.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 640,
3 | "vision_cfg": {
4 | "image_size": 240,
5 | "layers": 12,
6 | "width": 896,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 640,
13 | "heads": 10,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/src/open_clip/model_configs/ViT-B-16-plus.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 640,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 896,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 640,
13 | "heads": 10,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/src/open_clip/model_configs/ViT-B-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 512,
13 | "heads": 8,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/src/open_clip/model_configs/ViT-B-32-plus-256.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 640,
3 | "vision_cfg": {
4 | "image_size": 256,
5 | "layers": 12,
6 | "width": 896,
7 | "patch_size": 32
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 640,
13 | "heads": 10,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/src/open_clip/model_configs/ViT-B-32-quickgelu.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "quick_gelu": true,
4 | "vision_cfg": {
5 | "image_size": 224,
6 | "layers": 12,
7 | "width": 768,
8 | "patch_size": 32
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/src/open_clip/model_configs/ViT-B-32.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 32
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 512,
13 | "heads": 8,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/src/open_clip/model_configs/ViT-H-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 32,
6 | "width": 1280,
7 | "head_width": 80,
8 | "patch_size": 14
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 1024,
14 | "heads": 16,
15 | "layers": 24
16 | }
17 | }
--------------------------------------------------------------------------------
/src/open_clip/model_configs/ViT-H-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 32,
6 | "width": 1280,
7 | "head_width": 80,
8 | "patch_size": 16
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 1024,
14 | "heads": 16,
15 | "layers": 24
16 | }
17 | }
--------------------------------------------------------------------------------
/src/open_clip/model_configs/ViT-L-14-280.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 280,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 14
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 768,
13 | "heads": 12,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/src/open_clip/model_configs/ViT-L-14-336.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 336,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 14
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 768,
13 | "heads": 12,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/src/open_clip/model_configs/ViT-L-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 14
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 768,
13 | "heads": 12,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/src/open_clip/model_configs/ViT-L-16-320.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 320,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 768,
13 | "heads": 12,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/src/open_clip/model_configs/ViT-L-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 768,
13 | "heads": 12,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/src/open_clip/model_configs/ViT-g-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 40,
6 | "width": 1408,
7 | "head_width": 88,
8 | "mlp_ratio": 4.3637,
9 | "patch_size": 14
10 | },
11 | "text_cfg": {
12 | "context_length": 77,
13 | "vocab_size": 49408,
14 | "width": 1024,
15 | "heads": 16,
16 | "layers": 24
17 | }
18 | }
--------------------------------------------------------------------------------
/src/open_clip/model_configs/timm-efficientnetv2_rw_s.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "timm_model_name": "efficientnetv2_rw_s",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "abs_attn",
7 | "timm_proj": "",
8 | "image_size": 288
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 768,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/src/open_clip/model_configs/timm-resnet50d.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "timm_model_name": "resnet50d",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "abs_attn",
7 | "timm_proj": "",
8 | "image_size": 224
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
18 |
--------------------------------------------------------------------------------
/src/open_clip/model_configs/timm-resnetaa50d.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "timm_model_name": "resnetaa50d",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "abs_attn",
7 | "timm_proj": "",
8 | "image_size": 224
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
18 |
--------------------------------------------------------------------------------
/src/open_clip/model_configs/timm-resnetblur50.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "timm_model_name": "resnetblur50",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "abs_attn",
7 | "timm_proj": "",
8 | "image_size": 224
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
18 |
--------------------------------------------------------------------------------
/src/open_clip/model_configs/timm-swin_base_patch4_window7_224.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "timm_model_name": "swin_base_patch4_window7_224",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "image_size": 224
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/src/open_clip/model_configs/timm-vit_base_patch16_224.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "timm_model_name": "vit_base_patch16_224",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "image_size": 224
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/src/open_clip/model_configs/timm-vit_base_patch32_224.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "timm_model_name": "vit_base_patch32_224",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "image_size": 224
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/src/open_clip/model_configs/timm-vit_small_patch16_224.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "timm_model_name": "vit_small_patch16_224",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "image_size": 224
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/src/open_clip/openai.py:
--------------------------------------------------------------------------------
1 | """ OpenAI pretrained model functions
2 |
3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4 | """
5 |
6 | import os
7 | import warnings
8 | from typing import Union, List
9 |
10 | import torch
11 |
12 | from .model import build_model_from_openai_state_dict
13 | from .pretrained import get_pretrained_url, list_pretrained_tag_models, download_pretrained
14 |
15 | __all__ = ["list_openai_models", "load_openai_model"]
16 |
17 |
18 | def list_openai_models() -> List[str]:
19 | """Returns the names of available CLIP models"""
20 | return list_pretrained_tag_models('openai')
21 |
22 |
23 | def load_openai_model(
24 | name: str,
25 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
26 | jit=True,
27 | ):
28 | """Load a CLIP model
29 |
30 | Parameters
31 | ----------
32 | name : str
33 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
34 | device : Union[str, torch.device]
35 | The device to put the loaded model
36 | jit : bool
37 | Whether to load the optimized JIT model (default) or more hackable non-JIT model.
38 |
39 | Returns
40 | -------
41 | model : torch.nn.Module
42 | The CLIP model
43 | preprocess : Callable[[PIL.Image], torch.Tensor]
44 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
45 | """
46 | if get_pretrained_url(name, 'openai'):
47 | model_path = download_pretrained(get_pretrained_url(name, 'openai'))
48 | elif os.path.isfile(name):
49 | model_path = name
50 | else:
51 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
52 |
53 | try:
54 | # loading JIT archive
55 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
56 | state_dict = None
57 | except RuntimeError:
58 | # loading saved state dict
59 | if jit:
60 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
61 | jit = False
62 | state_dict = torch.load(model_path, map_location="cpu")
63 |
64 | if not jit:
65 | try:
66 | model = build_model_from_openai_state_dict(state_dict or model.state_dict()).to(device)
67 | except KeyError:
68 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
69 | model = build_model_from_openai_state_dict(sd).to(device)
70 |
71 | if str(device) == "cpu":
72 | model.float()
73 | return model
74 |
75 | # patch the device names
76 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
77 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
78 |
79 | def patch_device(module):
80 | try:
81 | graphs = [module.graph] if hasattr(module, "graph") else []
82 | except RuntimeError:
83 | graphs = []
84 |
85 | if hasattr(module, "forward1"):
86 | graphs.append(module.forward1.graph)
87 |
88 | for graph in graphs:
89 | for node in graph.findAllNodes("prim::Constant"):
90 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
91 | node.copyAttributes(device_node)
92 |
93 | model.apply(patch_device)
94 | patch_device(model.encode_image)
95 | patch_device(model.encode_text)
96 |
97 | # patch dtype to float32 on CPU
98 | if str(device) == "cpu":
99 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
100 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
101 | float_node = float_input.node()
102 |
103 | def patch_float(module):
104 | try:
105 | graphs = [module.graph] if hasattr(module, "graph") else []
106 | except RuntimeError:
107 | graphs = []
108 |
109 | if hasattr(module, "forward1"):
110 | graphs.append(module.forward1.graph)
111 |
112 | for graph in graphs:
113 | for node in graph.findAllNodes("aten::to"):
114 | inputs = list(node.inputs())
115 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
116 | if inputs[i].node()["value"] == 5:
117 | inputs[i].node().copyAttributes(float_node)
118 |
119 | model.apply(patch_float)
120 | patch_float(model.encode_image)
121 | patch_float(model.encode_text)
122 | model.float()
123 |
124 | # ensure image_size attr available at consistent location for both jit and non-jit
125 | model.visual.image_size = model.input_resolution.item()
126 | return model
127 |
--------------------------------------------------------------------------------
/src/open_clip/pretrained.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 |
6 | from tqdm import tqdm
7 |
8 | _RN50 = dict(
9 | openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
10 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
11 | cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"
12 | )
13 |
14 | _RN50_quickgelu = dict(
15 | openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
16 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
17 | cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"
18 | )
19 |
20 | _RN101 = dict(
21 | openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
22 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"
23 | )
24 |
25 | _RN101_quickgelu = dict(
26 | openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
27 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"
28 | )
29 |
30 | _RN50x4 = dict(
31 | openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
32 | )
33 |
34 | _RN50x16 = dict(
35 | openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
36 | )
37 |
38 | _RN50x64 = dict(
39 | openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
40 | )
41 |
42 | _VITB32 = dict(
43 | openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
44 | laion2b_e16="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth",
45 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
46 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
47 | )
48 |
49 | _VITB32_quickgelu = dict(
50 | openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
51 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
52 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
53 | )
54 |
55 | _VITB16 = dict(
56 | openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
57 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt",
58 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt",
59 | )
60 |
61 | _VITB16_PLUS_240 = dict(
62 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt",
63 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt",
64 | )
65 |
66 | _VITL14 = dict(
67 | openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
68 | laion400m_e31='https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt',
69 | laion400m_e32='https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt',
70 | )
71 |
72 | _VITL14_336 = dict(
73 | openai="https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"
74 | )
75 |
76 | _PRETRAINED = {
77 | "RN50": _RN50,
78 | "RN50-quickgelu": _RN50_quickgelu,
79 | "RN101": _RN101,
80 | "RN101-quickgelu": _RN101_quickgelu,
81 | "RN50x4": _RN50x4,
82 | "RN50x16": _RN50x16,
83 | "RN50x64": _RN50x64,
84 | "ViT-B-32": _VITB32,
85 | "ViT-B-32-quickgelu": _VITB32_quickgelu,
86 | "ViT-B-16": _VITB16,
87 | "ViT-B-16-plus-240": _VITB16_PLUS_240,
88 | "ViT-L-14": _VITL14,
89 | "ViT-L-14-336": _VITL14_336,
90 | }
91 |
92 |
93 | def list_pretrained(as_str: bool = False):
94 | """ returns list of pretrained models
95 | Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
96 | """
97 | return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
98 |
99 |
100 | def list_pretrained_tag_models(tag: str):
101 | """ return all models having the specified pretrain tag """
102 | models = []
103 | for k in _PRETRAINED.keys():
104 | if tag in _PRETRAINED[k]:
105 | models.append(k)
106 | return models
107 |
108 |
109 | def list_pretrained_model_tags(model: str):
110 | """ return all pretrain tags for the specified model architecture """
111 | tags = []
112 | if model in _PRETRAINED:
113 | tags.extend(_PRETRAINED[model].keys())
114 | return tags
115 |
116 |
117 | def get_pretrained_url(model: str, tag: str):
118 | if model not in _PRETRAINED:
119 | return ''
120 | model_pretrained = _PRETRAINED[model]
121 | tag = tag.lower()
122 | if tag not in model_pretrained:
123 | return ''
124 | return model_pretrained[tag]
125 |
126 |
127 | def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")):
128 | os.makedirs(root, exist_ok=True)
129 | filename = os.path.basename(url)
130 |
131 | if 'openaipublic' in url:
132 | expected_sha256 = url.split("/")[-2]
133 | else:
134 | expected_sha256 = ''
135 |
136 | download_target = os.path.join(root, filename)
137 |
138 | if os.path.exists(download_target) and not os.path.isfile(download_target):
139 | raise RuntimeError(f"{download_target} exists and is not a regular file")
140 |
141 | if os.path.isfile(download_target):
142 | if expected_sha256:
143 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
144 | return download_target
145 | else:
146 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
147 | else:
148 | return download_target
149 |
150 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
151 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
152 | while True:
153 | buffer = source.read(8192)
154 | if not buffer:
155 | break
156 |
157 | output.write(buffer)
158 | loop.update(len(buffer))
159 |
160 | if expected_sha256 and hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
161 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
162 |
163 | return download_target
164 |
--------------------------------------------------------------------------------
/src/open_clip/timm_model.py:
--------------------------------------------------------------------------------
1 | """ timm model adapter
2 |
3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
4 | """
5 | from collections import OrderedDict
6 |
7 | import torch.nn as nn
8 |
9 | try:
10 | import timm
11 | from timm.models.layers import Mlp, to_2tuple
12 | from timm.models.layers.attention_pool2d import RotAttentionPool2d
13 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
14 | except ImportError as e:
15 | timm = None
16 |
17 | from .utils import freeze_batch_norm_2d
18 |
19 |
20 | class TimmModel(nn.Module):
21 | """ timm model adapter
22 | # FIXME this adapter is a work in progress, may change in ways that break weight compat
23 | """
24 |
25 | def __init__(
26 | self,
27 | model_name,
28 | embed_dim,
29 | image_size=224,
30 | pool='avg',
31 | proj='linear',
32 | drop=0.,
33 | pretrained=False):
34 | super().__init__()
35 | if timm is None:
36 | raise RuntimeError("Please `pip install timm` to use timm models.")
37 |
38 | self.image_size = to_2tuple(image_size)
39 | self.trunk = timm.create_model(model_name, pretrained=pretrained)
40 | feat_size = self.trunk.default_cfg.get('pool_size', None)
41 | feature_ndim = 1 if not feat_size else 2
42 | if pool in ('abs_attn', 'rot_attn'):
43 | assert feature_ndim == 2
44 | # if attn pooling used, remove both classifier and default pool
45 | self.trunk.reset_classifier(0, global_pool='')
46 | else:
47 | # reset global pool if pool config set, otherwise leave as network default
48 | reset_kwargs = dict(global_pool=pool) if pool else {}
49 | self.trunk.reset_classifier(0, **reset_kwargs)
50 | prev_chs = self.trunk.num_features
51 |
52 | head_layers = OrderedDict()
53 | if pool == 'abs_attn':
54 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
55 | prev_chs = embed_dim
56 | elif pool == 'rot_attn':
57 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
58 | prev_chs = embed_dim
59 | else:
60 | assert proj, 'projection layer needed if non-attention pooling is used.'
61 |
62 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
63 | if proj == 'linear':
64 | head_layers['drop'] = nn.Dropout(drop)
65 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim)
66 | elif proj == 'mlp':
67 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop)
68 |
69 | self.head = nn.Sequential(head_layers)
70 |
71 | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
72 | """ lock modules
73 | Args:
74 | unlocked_groups (int): leave last n layer groups unlocked (default: 0)
75 | """
76 | if not unlocked_groups:
77 | # lock full model
78 | for param in self.trunk.parameters():
79 | param.requires_grad = False
80 | if freeze_bn_stats:
81 | freeze_batch_norm_2d(self.trunk)
82 | else:
83 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change
84 | try:
85 | # FIXME import here until API stable and in an official release
86 | from timm.models.helpers import group_parameters, group_modules
87 | except ImportError:
88 | raise RuntimeError(
89 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
90 | matcher = self.trunk.group_matcher()
91 | gparams = group_parameters(self.trunk, matcher)
92 | max_layer_id = max(gparams.keys())
93 | max_layer_id = max_layer_id - unlocked_groups
94 | for group_idx in range(max_layer_id + 1):
95 | group = gparams[group_idx]
96 | for param in group:
97 | self.trunk.get_parameter(param).requires_grad = False
98 | if freeze_bn_stats:
99 | gmodules = group_modules(self.trunk, matcher, reverse=True)
100 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
101 | freeze_batch_norm_2d(self.trunk, gmodules)
102 |
103 | def forward(self, x):
104 | x = self.trunk(x)
105 | x = self.head(x)
106 | return x
107 |
--------------------------------------------------------------------------------
/src/open_clip/tokenizer.py:
--------------------------------------------------------------------------------
1 | """ CLIP tokenizer
2 |
3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4 | """
5 | import gzip
6 | import html
7 | import os
8 | from functools import lru_cache
9 | from typing import Union, List
10 |
11 | import ftfy
12 | import regex as re
13 | import torch
14 |
15 |
16 | @lru_cache()
17 | def default_bpe():
18 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
19 |
20 |
21 | @lru_cache()
22 | def bytes_to_unicode():
23 | """
24 | Returns list of utf-8 byte and a corresponding list of unicode strings.
25 | The reversible bpe codes work on unicode strings.
26 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
27 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
28 | This is a signficant percentage of your normal, say, 32K bpe vocab.
29 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
30 | And avoids mapping to whitespace/control characters the bpe code barfs on.
31 | """
32 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
33 | cs = bs[:]
34 | n = 0
35 | for b in range(2**8):
36 | if b not in bs:
37 | bs.append(b)
38 | cs.append(2**8+n)
39 | n += 1
40 | cs = [chr(n) for n in cs]
41 | return dict(zip(bs, cs))
42 |
43 |
44 | def get_pairs(word):
45 | """Return set of symbol pairs in a word.
46 | Word is represented as tuple of symbols (symbols being variable-length strings).
47 | """
48 | pairs = set()
49 | prev_char = word[0]
50 | for char in word[1:]:
51 | pairs.add((prev_char, char))
52 | prev_char = char
53 | return pairs
54 |
55 |
56 | def basic_clean(text):
57 | text = ftfy.fix_text(text)
58 | text = html.unescape(html.unescape(text))
59 | return text.strip()
60 |
61 |
62 | def whitespace_clean(text):
63 | text = re.sub(r'\s+', ' ', text)
64 | text = text.strip()
65 | return text
66 |
67 |
68 | class SimpleTokenizer(object):
69 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
70 | self.byte_encoder = bytes_to_unicode()
71 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
72 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
73 | merges = merges[1:49152-256-2+1]
74 | merges = [tuple(merge.split()) for merge in merges]
75 | vocab = list(bytes_to_unicode().values())
76 | vocab = vocab + [v+'' for v in vocab]
77 | for merge in merges:
78 | vocab.append(''.join(merge))
79 | if not special_tokens:
80 | special_tokens = ['', '']
81 | else:
82 | special_tokens = ['', ''] + special_tokens
83 | vocab.extend(special_tokens)
84 | self.encoder = dict(zip(vocab, range(len(vocab))))
85 | self.decoder = {v: k for k, v in self.encoder.items()}
86 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
87 | self.cache = {t:t for t in special_tokens}
88 | special = "|".join(special_tokens)
89 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
90 |
91 | self.vocab_size = len(self.encoder)
92 | self.all_special_ids = [self.encoder[t] for t in special_tokens]
93 |
94 | def bpe(self, token):
95 | if token in self.cache:
96 | return self.cache[token]
97 | word = tuple(token[:-1]) + ( token[-1] + '',)
98 | pairs = get_pairs(word)
99 |
100 | if not pairs:
101 | return token+''
102 |
103 | while True:
104 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
105 | if bigram not in self.bpe_ranks:
106 | break
107 | first, second = bigram
108 | new_word = []
109 | i = 0
110 | while i < len(word):
111 | try:
112 | j = word.index(first, i)
113 | new_word.extend(word[i:j])
114 | i = j
115 | except:
116 | new_word.extend(word[i:])
117 | break
118 |
119 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
120 | new_word.append(first+second)
121 | i += 2
122 | else:
123 | new_word.append(word[i])
124 | i += 1
125 | new_word = tuple(new_word)
126 | word = new_word
127 | if len(word) == 1:
128 | break
129 | else:
130 | pairs = get_pairs(word)
131 | word = ' '.join(word)
132 | self.cache[token] = word
133 | return word
134 |
135 | def encode(self, text):
136 | bpe_tokens = []
137 | text = whitespace_clean(basic_clean(text)).lower()
138 | for token in re.findall(self.pat, text):
139 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
140 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
141 | return bpe_tokens
142 |
143 | def decode(self, tokens):
144 | text = ''.join([self.decoder[token] for token in tokens])
145 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ')
146 | return text
147 |
148 |
149 | _tokenizer = SimpleTokenizer()
150 |
151 |
152 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
153 | """
154 | Returns the tokenized representation of given input string(s)
155 |
156 | Parameters
157 | ----------
158 | texts : Union[str, List[str]]
159 | An input string or a list of input strings to tokenize
160 | context_length : int
161 | The context length to use; all CLIP models use 77 as the context length
162 |
163 | Returns
164 | -------
165 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
166 | """
167 | if isinstance(texts, str):
168 | texts = [texts]
169 |
170 | sot_token = _tokenizer.encoder[""]
171 | eot_token = _tokenizer.encoder[""]
172 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
173 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
174 |
175 | for i, tokens in enumerate(all_tokens):
176 | if len(tokens) > context_length:
177 | tokens = tokens[:context_length] # Truncate
178 | tokens[-1] = eot_token
179 | result[i, :len(tokens)] = torch.tensor(tokens)
180 |
181 | return result
182 |
--------------------------------------------------------------------------------
/src/open_clip/transform.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Sequence, Tuple
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torchvision.transforms.functional as F
6 |
7 |
8 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
9 | CenterCrop
10 |
11 |
12 | class ResizeMaxSize(nn.Module):
13 |
14 | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
15 | super().__init__()
16 | if not isinstance(max_size, int):
17 | raise TypeError(f"Size should be int. Got {type(max_size)}")
18 | self.max_size = max_size
19 | self.interpolation = interpolation
20 | self.fn = min if fn == 'min' else min
21 | self.fill = fill
22 |
23 | def forward(self, img):
24 | if isinstance(img, torch.Tensor):
25 | height, width = img.shape[:2]
26 | else:
27 | width, height = img.size
28 | scale = self.max_size / float(max(height, width))
29 | if scale != 1.0:
30 | new_size = tuple(round(dim * scale) for dim in (height, width))
31 | img = F.resize(img, new_size, self.interpolation)
32 | pad_h = self.max_size - new_size[0]
33 | pad_w = self.max_size - new_size[1]
34 | img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
35 | return img
36 |
37 |
38 | def _convert_to_rgb(image):
39 | return image.convert('RGB')
40 |
41 |
42 | def image_transform(
43 | image_size: int,
44 | is_train: bool,
45 | mean: Optional[Tuple[float, ...]] = None,
46 | std: Optional[Tuple[float, ...]] = None,
47 | resize_longest_max: bool = False,
48 | fill_color: int = 0,
49 | ):
50 | mean = mean or (0.48145466, 0.4578275, 0.40821073) # OpenAI dataset mean
51 | std = std or (0.26862954, 0.26130258, 0.27577711) # OpenAI dataset std
52 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
53 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
54 | image_size = image_size[0]
55 |
56 | normalize = Normalize(mean=mean, std=std)
57 | if is_train:
58 | return Compose([
59 | RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
60 | _convert_to_rgb,
61 | ToTensor(),
62 | normalize,
63 | ])
64 | else:
65 | if resize_longest_max:
66 | transforms = [
67 | ResizeMaxSize(image_size, fill=fill_color)
68 | ]
69 | else:
70 | transforms = [
71 | Resize(image_size, interpolation=InterpolationMode.BICUBIC),
72 | CenterCrop(image_size),
73 | ]
74 | transforms.extend([
75 | _convert_to_rgb,
76 | ToTensor(),
77 | normalize,
78 | ])
79 | return Compose(transforms)
80 |
--------------------------------------------------------------------------------
/src/open_clip/utils.py:
--------------------------------------------------------------------------------
1 | from itertools import repeat
2 | import collections.abc
3 |
4 | from torch import nn as nn
5 | from torchvision.ops.misc import FrozenBatchNorm2d
6 |
7 |
8 | def freeze_batch_norm_2d(module, module_match={}, name=''):
9 | """
10 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
11 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
12 | returned. Otherwise, the module is walked recursively and submodules are converted in place.
13 |
14 | Args:
15 | module (torch.nn.Module): Any PyTorch module.
16 | module_match (dict): Dictionary of full module names to freeze (all if empty)
17 | name (str): Full module name (prefix)
18 |
19 | Returns:
20 | torch.nn.Module: Resulting module
21 |
22 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
23 | """
24 | res = module
25 | is_match = True
26 | if module_match:
27 | is_match = name in module_match
28 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
29 | res = FrozenBatchNorm2d(module.num_features)
30 | res.num_features = module.num_features
31 | res.affine = module.affine
32 | if module.affine:
33 | res.weight.data = module.weight.data.clone().detach()
34 | res.bias.data = module.bias.data.clone().detach()
35 | res.running_mean.data = module.running_mean.data
36 | res.running_var.data = module.running_var.data
37 | res.eps = module.eps
38 | else:
39 | for child_name, child in module.named_children():
40 | full_child_name = '.'.join([name, child_name]) if name else child_name
41 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
42 | if new_child is not child:
43 | res.add_module(child_name, new_child)
44 | return res
45 |
46 |
47 | # From PyTorch internals
48 | def _ntuple(n):
49 | def parse(x):
50 | if isinstance(x, collections.abc.Iterable):
51 | return x
52 | return tuple(repeat(x, n))
53 | return parse
54 |
55 |
56 | to_1tuple = _ntuple(1)
57 | to_2tuple = _ntuple(2)
58 | to_3tuple = _ntuple(3)
59 | to_4tuple = _ntuple(4)
60 | to_ntuple = lambda n, x: _ntuple(n)(x)
61 |
--------------------------------------------------------------------------------
/src/open_clip/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '1.3.0'
2 |
--------------------------------------------------------------------------------
/src/training/.gitignore:
--------------------------------------------------------------------------------
1 | logs/
2 |
--------------------------------------------------------------------------------
/src/training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vinid/neg_clip/e680f8e44ab9f081f62094232d8ee3b77f8574f7/src/training/__init__.py
--------------------------------------------------------------------------------
/src/training/data.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import json
3 | import logging
4 | import math
5 | import os
6 | import random
7 | import sys
8 | import time
9 | from dataclasses import dataclass
10 | from multiprocessing import Value
11 | import ast
12 | import random
13 | import braceexpand
14 | import numpy as np
15 | import pandas as pd
16 | import torch
17 | import torchvision.datasets as datasets
18 | import webdataset as wds
19 | from PIL import Image
20 | from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info
21 | from torch.utils.data.distributed import DistributedSampler
22 | from webdataset.filters import _shuffle
23 | from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample
24 |
25 | try:
26 | import horovod.torch as hvd
27 | except ImportError:
28 | hvd = None
29 |
30 | from open_clip import tokenize
31 |
32 |
33 | class CsvDataset(Dataset):
34 | def __init__(self, input_filename, transforms, img_key, caption_key, hard_captions_key, sep="\t"):
35 | logging.debug(f'Loading csv data from {input_filename}.')
36 | df = pd.read_csv(input_filename, sep=sep, converters={"neg_caption":ast.literal_eval, "neg_image":ast.literal_eval})
37 |
38 | self.images = df[img_key].tolist()
39 | self.captions = df[caption_key].tolist()
40 | self.hard_captions = df[hard_captions_key].tolist()
41 | self.hard_images = df["neg_image"].tolist()
42 | self.transforms = transforms
43 | logging.debug('Done loading data.')
44 |
45 | def __len__(self):
46 | return len(self.captions)
47 |
48 | def __getitem__(self, idx):
49 | images = self.transforms(Image.open(str(self.images[idx])))
50 | texts = tokenize([str(self.captions[idx])])[0]
51 |
52 | chosen_caption = random.choice(self.hard_captions[idx])
53 | hard_captions = tokenize([str(chosen_caption)])[0]
54 |
55 | chose_image_index = random.choice(self.hard_images[idx])
56 |
57 | new_images = self.transforms(Image.open(str(self.images[chose_image_index])))
58 | new_texts = tokenize([str(self.captions[chose_image_index])])[0]
59 |
60 | chosen_caption = random.choice(self.hard_captions[chose_image_index])
61 | new_hard = tokenize([str(chosen_caption)])[0]
62 |
63 | return images, new_images, texts, new_texts, hard_captions, new_hard
64 |
65 |
66 |
67 | class SharedEpoch:
68 | def __init__(self, epoch: int = 0):
69 | self.shared_epoch = Value('i', epoch)
70 |
71 | def set_value(self, epoch):
72 | self.shared_epoch.value = epoch
73 |
74 | def get_value(self):
75 | return self.shared_epoch.value
76 |
77 |
78 | @dataclass
79 | class DataInfo:
80 | dataloader: DataLoader
81 | sampler: DistributedSampler = None
82 | shared_epoch: SharedEpoch = None
83 |
84 | def set_epoch(self, epoch):
85 | if self.shared_epoch is not None:
86 | self.shared_epoch.set_value(epoch)
87 | if self.sampler is not None and isinstance(self.sampler, DistributedSampler):
88 | self.sampler.set_epoch(epoch)
89 |
90 |
91 | def preprocess_txt(text):
92 | return tokenize([str(text)])[0]
93 |
94 |
95 | def get_dataset_size(shards):
96 | shards_list = list(braceexpand.braceexpand(shards))
97 | dir_path = os.path.dirname(shards)
98 | sizes_filename = os.path.join(dir_path, 'sizes.json')
99 | len_filename = os.path.join(dir_path, '__len__')
100 | if os.path.exists(sizes_filename):
101 | sizes = json.load(open(sizes_filename, 'r'))
102 | total_size = sum([int(sizes[os.path.basename(shard)]) for shard in shards_list])
103 | elif os.path.exists(len_filename):
104 | # FIXME this used to be eval(open(...)) but that seemed rather unsafe
105 | total_size = ast.literal_eval(open(len_filename, 'r').read())
106 | else:
107 | total_size = None # num samples undefined
108 | # some common dataset sizes (at time of authors last download)
109 | # CC3M (train): 2905954
110 | # CC12M: 10968539
111 | # LAION-400M: 407332084
112 | # LAION-2B (english): 2170337258
113 | num_shards = len(shards_list)
114 | return total_size, num_shards
115 |
116 |
117 | def get_imagenet(args, preprocess_fns, split):
118 | assert split in ["train", "val", "v2"]
119 | is_train = split == "train"
120 | preprocess_train, preprocess_val = preprocess_fns
121 |
122 | if split == "v2":
123 | from imagenetv2_pytorch import ImageNetV2Dataset
124 | dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val)
125 | else:
126 | if is_train:
127 | data_path = args.imagenet_train
128 | preprocess_fn = preprocess_train
129 | else:
130 | data_path = args.imagenet_val
131 | preprocess_fn = preprocess_val
132 | assert data_path
133 |
134 | dataset = datasets.ImageFolder(data_path, transform=preprocess_fn)
135 |
136 | if is_train:
137 | idxs = np.zeros(len(dataset.targets))
138 | target_array = np.array(dataset.targets)
139 | k = 50
140 | for c in range(1000):
141 | m = target_array == c
142 | n = len(idxs[m])
143 | arr = np.zeros(n)
144 | arr[:k] = 1
145 | np.random.shuffle(arr)
146 | idxs[m] = arr
147 |
148 | idxs = idxs.astype('int')
149 | sampler = SubsetRandomSampler(np.where(idxs)[0])
150 | else:
151 | sampler = None
152 |
153 | dataloader = torch.utils.data.DataLoader(
154 | dataset,
155 | batch_size=args.batch_size,
156 | num_workers=args.workers,
157 | sampler=sampler,
158 | )
159 |
160 | return DataInfo(dataloader=dataloader, sampler=sampler)
161 |
162 |
163 | def count_samples(dataloader):
164 | os.environ["WDS_EPOCH"] = "0"
165 | n_elements, n_batches = 0, 0
166 | for images, texts in dataloader:
167 | n_batches += 1
168 | n_elements += len(images)
169 | assert len(images) == len(texts)
170 | return n_elements, n_batches
171 |
172 |
173 | def filter_no_caption(sample):
174 | return 'txt' in sample
175 |
176 |
177 | def log_and_continue(exn):
178 | """Call in an exception handler to ignore any exception, isssue a warning, and continue."""
179 | logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
180 | return True
181 |
182 |
183 | def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
184 | """Return function over iterator that groups key, value pairs into samples.
185 |
186 | :param keys: function that splits the key into key and extension (base_plus_ext)
187 | :param lcase: convert suffixes to lower case (Default value = True)
188 | """
189 | current_sample = None
190 | for filesample in data:
191 | assert isinstance(filesample, dict)
192 | fname, value = filesample["fname"], filesample["data"]
193 | prefix, suffix = keys(fname)
194 | if prefix is None:
195 | continue
196 | if lcase:
197 | suffix = suffix.lower()
198 | # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
199 | # this happening in the current LAION400m dataset if a tar ends with same prefix as the next
200 | # begins, rare, but can happen since prefix aren't unique across tar files in that dataset
201 | if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
202 | if valid_sample(current_sample):
203 | yield current_sample
204 | current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
205 | if suffixes is None or suffix in suffixes:
206 | current_sample[suffix] = value
207 | if valid_sample(current_sample):
208 | yield current_sample
209 |
210 |
211 | def tarfile_to_samples_nothrow(src, handler=log_and_continue):
212 | # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
213 | streams = url_opener(src, handler=handler)
214 | files = tar_file_expander(streams, handler=handler)
215 | samples = group_by_keys_nothrow(files, handler=handler)
216 | return samples
217 |
218 |
219 | def pytorch_worker_seed():
220 | """get dataloader worker seed from pytorch"""
221 | worker_info = get_worker_info()
222 | if worker_info is not None:
223 | # favour the seed already created for pytorch dataloader workers if it exists
224 | return worker_info.seed
225 | # fallback to wds rank based seed
226 | return wds.utils.pytorch_worker_seed()
227 |
228 |
229 | _SHARD_SHUFFLE_SIZE = 2000
230 | _SHARD_SHUFFLE_INITIAL = 500
231 | _SAMPLE_SHUFFLE_SIZE = 5000
232 | _SAMPLE_SHUFFLE_INITIAL = 1000
233 |
234 |
235 | class detshuffle2(wds.PipelineStage):
236 | def __init__(
237 | self,
238 | bufsize=1000,
239 | initial=100,
240 | seed=0,
241 | epoch=-1,
242 | ):
243 | self.bufsize = bufsize
244 | self.initial = initial
245 | self.seed = seed
246 | self.epoch = epoch
247 |
248 | def run(self, src):
249 | if isinstance(self.epoch, SharedEpoch):
250 | epoch = self.epoch.get_value()
251 | else:
252 | # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train)
253 | # situation as different workers may wrap at different times (or not at all).
254 | self.epoch += 1
255 | epoch = self.epoch
256 | rng = random.Random()
257 | if self.seed < 0:
258 | seed = pytorch_worker_seed() + epoch
259 | else:
260 | seed = self.seed + epoch
261 | rng.seed(seed)
262 | return _shuffle(src, self.bufsize, self.initial, rng)
263 |
264 |
265 | class ResampledShards2(IterableDataset):
266 | """An iterable dataset yielding a list of urls."""
267 |
268 | def __init__(
269 | self,
270 | urls,
271 | nshards=sys.maxsize,
272 | worker_seed=None,
273 | deterministic=False,
274 | epoch=-1,
275 | ):
276 | """Sample shards from the shard list with replacement.
277 |
278 | :param urls: a list of URLs as a Python list or brace notation string
279 | """
280 | super().__init__()
281 | urls = wds.shardlists.expand_urls(urls)
282 | self.urls = urls
283 | assert isinstance(self.urls[0], str)
284 | self.nshards = nshards
285 | self.rng = random.Random()
286 | self.worker_seed = pytorch_worker_seed if worker_seed is None else worker_seed
287 | self.deterministic = deterministic
288 | self.epoch = epoch
289 |
290 | def __iter__(self):
291 | """Return an iterator over the shards."""
292 | if isinstance(self.epoch, SharedEpoch):
293 | epoch = self.epoch.get_value()
294 | else:
295 | # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train)
296 | # situation as different workers may wrap at different times (or not at all).
297 | self.epoch += 1
298 | epoch = self.epoch
299 | if self.deterministic:
300 | # reset seed w/ epoch if deterministic, worker seed should be deterministic due to arg.seed
301 | self.rng.seed(self.worker_seed() + epoch)
302 | for _ in range(self.nshards):
303 | yield dict(url=self.rng.choice(self.urls))
304 |
305 |
306 | def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False):
307 | input_shards = args.train_data if is_train else args.val_data
308 | assert input_shards is not None
309 | resampled = getattr(args, 'dataset_resampled', False) and is_train
310 |
311 | num_samples, num_shards = get_dataset_size(input_shards)
312 | if not num_samples:
313 | if is_train:
314 | num_samples = args.train_num_samples
315 | if not num_samples:
316 | raise RuntimeError(
317 | 'Currently, number of dataset samples must be specified for training dataset. '
318 | 'Please specify via `--train-num-samples` if no dataset length info present.')
319 | else:
320 | num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified
321 |
322 | shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc
323 | if resampled:
324 | pipeline = [ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch)]
325 | else:
326 | pipeline = [wds.SimpleShardList(input_shards)]
327 |
328 | # at this point we have an iterator over all the shards
329 | if is_train:
330 | if not resampled:
331 | pipeline.extend([
332 | detshuffle2(
333 | bufsize=_SHARD_SHUFFLE_SIZE,
334 | initial=_SHARD_SHUFFLE_INITIAL,
335 | seed=args.seed,
336 | epoch=shared_epoch,
337 | ),
338 | wds.split_by_node,
339 | wds.split_by_worker,
340 | ])
341 | pipeline.extend([
342 | # at this point, we have an iterator over the shards assigned to each worker at each node
343 | tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue),
344 | wds.shuffle(
345 | bufsize=_SAMPLE_SHUFFLE_SIZE,
346 | initial=_SAMPLE_SHUFFLE_INITIAL,
347 | ),
348 | ])
349 | else:
350 | pipeline.extend([
351 | wds.split_by_worker,
352 | # at this point, we have an iterator over the shards assigned to each worker
353 | wds.tarfile_to_samples(handler=log_and_continue),
354 | ])
355 | pipeline.extend([
356 | wds.select(filter_no_caption),
357 | wds.decode("pilrgb", handler=log_and_continue),
358 | wds.rename(image="jpg;png", text="txt"),
359 | wds.map_dict(image=preprocess_img, text=preprocess_txt),
360 | wds.to_tuple("image", "text"),
361 | wds.batched(args.batch_size, partial=not is_train),
362 | ])
363 |
364 | dataset = wds.DataPipeline(*pipeline)
365 | if is_train:
366 | if not resampled:
367 | assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers'
368 | # roll over and repeat a few samples to get same number of full batches on each node
369 | round_fn = math.floor if floor else math.ceil
370 | global_batch_size = args.batch_size * args.world_size
371 | num_batches = round_fn(num_samples / global_batch_size)
372 | num_workers = max(1, args.workers)
373 | num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker
374 | num_batches = num_worker_batches * num_workers
375 | num_samples = num_batches * global_batch_size
376 | dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this
377 | else:
378 | # last batches are partial, eval is done on single (master) node
379 | num_batches = math.ceil(num_samples / args.batch_size)
380 |
381 | dataloader = wds.WebLoader(
382 | dataset,
383 | batch_size=None,
384 | shuffle=False,
385 | num_workers=args.workers,
386 | persistent_workers=True,
387 | )
388 |
389 | # FIXME not clear which approach is better, with_epoch before vs after dataloader?
390 | # hoping to resolve via https://github.com/webdataset/webdataset/issues/169
391 | # if is_train:
392 | # # roll over and repeat a few samples to get same number of full batches on each node
393 | # global_batch_size = args.batch_size * args.world_size
394 | # num_batches = math.ceil(num_samples / global_batch_size)
395 | # num_workers = max(1, args.workers)
396 | # num_batches = math.ceil(num_batches / num_workers) * num_workers
397 | # num_samples = num_batches * global_batch_size
398 | # dataloader = dataloader.with_epoch(num_batches)
399 | # else:
400 | # # last batches are partial, eval is done on single (master) node
401 | # num_batches = math.ceil(num_samples / args.batch_size)
402 |
403 | # add meta-data to dataloader instance for convenience
404 | dataloader.num_batches = num_batches
405 | dataloader.num_samples = num_samples
406 |
407 | return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
408 |
409 |
410 | def get_csv_dataset(args, preprocess_fn, is_train, epoch=0):
411 | input_filename = args.train_data if is_train else args.val_data
412 | assert input_filename
413 | dataset = CsvDataset(
414 | input_filename,
415 | preprocess_fn,
416 | img_key=args.csv_img_key,
417 | caption_key=args.csv_caption_key,
418 | hard_captions_key=args.csv_hard_captions_key,
419 | sep=args.csv_separator)
420 | num_samples = len(dataset)
421 | sampler = DistributedSampler(dataset) if args.distributed and is_train else None
422 | shuffle = is_train and sampler is None
423 |
424 | dataloader = DataLoader(
425 | dataset,
426 | batch_size=args.batch_size,
427 | shuffle=shuffle,
428 | num_workers=args.workers,
429 | pin_memory=True,
430 | sampler=sampler,
431 | drop_last=is_train,
432 | )
433 | dataloader.num_samples = num_samples
434 | dataloader.num_batches = len(dataloader)
435 |
436 | return DataInfo(dataloader, sampler)
437 |
438 |
439 | def get_dataset_fn(data_path, dataset_type):
440 | if dataset_type == "webdataset":
441 | return get_wds_dataset
442 | elif dataset_type == "csv":
443 | return get_csv_dataset
444 | elif dataset_type == "auto":
445 | ext = data_path.split('.')[-1]
446 | if ext in ['csv', 'tsv']:
447 | return get_csv_dataset
448 | elif ext in ['tar']:
449 | return get_wds_dataset
450 | else:
451 | raise ValueError(
452 | f"Tried to figure out dataset type, but failed for extention {ext}.")
453 | else:
454 | raise ValueError(f"Unsupported dataset type: {dataset_type}")
455 |
456 |
457 | def get_data(args, preprocess_fns, epoch=0):
458 | preprocess_train, preprocess_val = preprocess_fns
459 | data = {}
460 |
461 | if args.train_data:
462 | data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
463 | args, preprocess_train, is_train=True, epoch=epoch)
464 |
465 | if args.val_data:
466 | data["val"] = get_dataset_fn(args.val_data, args.dataset_type)(
467 | args, preprocess_val, is_train=False)
468 |
469 | if args.imagenet_val is not None:
470 | data["imagenet-val"] = get_imagenet(args, preprocess_fns, "val")
471 |
472 | if args.imagenet_v2 is not None:
473 | data["imagenet-v2"] = get_imagenet(args, preprocess_fns, "v2")
474 |
475 | return data
476 |
--------------------------------------------------------------------------------
/src/training/distributed.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 |
5 | try:
6 | import horovod.torch as hvd
7 | except ImportError:
8 | hvd = None
9 |
10 |
11 | def is_global_master(args):
12 | return args.rank == 0
13 |
14 |
15 | def is_local_master(args):
16 | return args.local_rank == 0
17 |
18 |
19 | def is_master(args, local=False):
20 | return is_local_master(args) if local else is_global_master(args)
21 |
22 |
23 | def is_using_horovod():
24 | # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set
25 | # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required...
26 | ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
27 | pmi_vars = ["PMI_RANK", "PMI_SIZE"]
28 | if all([var in os.environ for var in ompi_vars]) or all([var in os.environ for var in pmi_vars]):
29 | return True
30 | else:
31 | return False
32 |
33 |
34 | def is_using_distributed():
35 | if 'WORLD_SIZE' in os.environ:
36 | return int(os.environ['WORLD_SIZE']) > 1
37 | if 'SLURM_NTASKS' in os.environ:
38 | return int(os.environ['SLURM_NTASKS']) > 1
39 | return False
40 |
41 |
42 | def world_info_from_env():
43 | local_rank = 0
44 | for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'):
45 | if v in os.environ:
46 | local_rank = int(os.environ[v])
47 | break
48 | global_rank = 0
49 | for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'):
50 | if v in os.environ:
51 | global_rank = int(os.environ[v])
52 | break
53 | world_size = 1
54 | for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'):
55 | if v in os.environ:
56 | world_size = int(os.environ[v])
57 | break
58 |
59 | return local_rank, global_rank, world_size
60 |
61 |
62 | def init_distributed_device(args):
63 | # Distributed training = training on more than one GPU.
64 | # Works in both single and multi-node scenarios.
65 | args.distributed = False
66 | args.world_size = 1
67 | args.rank = 0 # global rank
68 | args.local_rank = 0
69 | if args.horovod:
70 | assert hvd is not None, "Horovod is not installed"
71 | hvd.init()
72 | args.local_rank = int(hvd.local_rank())
73 | args.rank = hvd.rank()
74 | args.world_size = hvd.size()
75 | args.distributed = True
76 | os.environ['LOCAL_RANK'] = str(args.local_rank)
77 | os.environ['RANK'] = str(args.rank)
78 | os.environ['WORLD_SIZE'] = str(args.world_size)
79 | elif is_using_distributed():
80 | if 'SLURM_PROCID' in os.environ:
81 | # DDP via SLURM
82 | args.local_rank, args.rank, args.world_size = world_info_from_env()
83 | # SLURM var -> torch.distributed vars in case needed
84 | os.environ['LOCAL_RANK'] = str(args.local_rank)
85 | os.environ['RANK'] = str(args.rank)
86 | os.environ['WORLD_SIZE'] = str(args.world_size)
87 | torch.distributed.init_process_group(
88 | backend=args.dist_backend,
89 | init_method=args.dist_url,
90 | world_size=args.world_size,
91 | rank=args.rank,
92 | )
93 | else:
94 | # DDP via torchrun, torch.distributed.launch
95 | args.local_rank, _, _ = world_info_from_env()
96 | torch.distributed.init_process_group(
97 | backend=args.dist_backend,
98 | init_method=args.dist_url)
99 | args.world_size = torch.distributed.get_world_size()
100 | args.rank = torch.distributed.get_rank()
101 | args.distributed = True
102 |
103 | if torch.cuda.is_available():
104 | if args.distributed and not args.no_set_device_rank:
105 | device = 'cuda:%d' % args.local_rank
106 | else:
107 | device = 'cuda:0'
108 | torch.cuda.set_device(device)
109 | else:
110 | device = 'cpu'
111 | args.device = device
112 | device = torch.device(device)
113 | return device
114 |
--------------------------------------------------------------------------------
/src/training/imagenet_zeroshot_data.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | imagenet_classnames = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray",
4 | "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco",
5 | "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper",
6 | "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander",
7 | "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog",
8 | "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin",
9 | "box turtle", "banded gecko", "green iguana", "Carolina anole",
10 | "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard",
11 | "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile",
12 | "American alligator", "triceratops", "worm snake", "ring-necked snake",
13 | "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake",
14 | "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra",
15 | "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake",
16 | "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider",
17 | "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider",
18 | "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl",
19 | "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet",
20 | "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck",
21 | "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby",
22 | "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch",
23 | "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab",
24 | "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab",
25 | "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron",
26 | "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot",
27 | "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher",
28 | "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion",
29 | "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel",
30 | "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle",
31 | "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound",
32 | "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound",
33 | "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound",
34 | "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier",
35 | "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier",
36 | "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier",
37 | "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier",
38 | "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer",
39 | "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier",
40 | "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier",
41 | "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever",
42 | "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla",
43 | "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel",
44 | "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel",
45 | "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard",
46 | "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie",
47 | "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann",
48 | "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog",
49 | "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff",
50 | "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky",
51 | "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog",
52 | "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon",
53 | "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle",
54 | "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf",
55 | "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox",
56 | "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat",
57 | "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger",
58 | "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose",
59 | "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle",
60 | "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper",
61 | "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper",
62 | "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly",
63 | "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly",
64 | "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit",
65 | "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse",
66 | "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison",
67 | "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)",
68 | "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat",
69 | "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan",
70 | "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque",
71 | "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin",
72 | "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey",
73 | "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda",
74 | "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish",
75 | "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown",
76 | "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance",
77 | "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle",
78 | "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo",
79 | "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel",
80 | "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel",
81 | "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)",
82 | "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini",
83 | "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet",
84 | "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra",
85 | "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest",
86 | "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe",
87 | "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton",
88 | "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran",
89 | "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw",
90 | "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking",
91 | "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker",
92 | "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard",
93 | "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot",
94 | "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed",
95 | "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer",
96 | "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table",
97 | "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig",
98 | "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar",
99 | "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder",
100 | "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute",
101 | "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed",
102 | "freight car", "French horn", "frying pan", "fur coat", "garbage truck",
103 | "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola",
104 | "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine",
105 | "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer",
106 | "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet",
107 | "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar",
108 | "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep",
109 | "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat",
110 | "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library",
111 | "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion",
112 | "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag",
113 | "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask",
114 | "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone",
115 | "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile",
116 | "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor",
117 | "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa",
118 | "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail",
119 | "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina",
120 | "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart",
121 | "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush",
122 | "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench",
123 | "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case",
124 | "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube",
125 | "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball",
126 | "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag",
127 | "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho",
128 | "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug",
129 | "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill",
130 | "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel",
131 | "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator",
132 | "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser",
133 | "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal",
134 | "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard",
135 | "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store",
136 | "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap",
137 | "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door",
138 | "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock",
139 | "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater",
140 | "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight",
141 | "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf",
142 | "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa",
143 | "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge",
144 | "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe",
145 | "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball",
146 | "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof",
147 | "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store",
148 | "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod",
149 | "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard",
150 | "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling",
151 | "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball",
152 | "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink",
153 | "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle",
154 | "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing",
155 | "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website",
156 | "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu",
157 | "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette",
158 | "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli",
159 | "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber",
160 | "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange",
161 | "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate",
162 | "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito",
163 | "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef",
164 | "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player",
165 | "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn",
166 | "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom",
167 | "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"]
168 |
169 |
170 |
171 |
172 |
173 | openai_imagenet_template = [
174 | lambda c: f'a bad photo of a {c}.',
175 | lambda c: f'a photo of many {c}.',
176 | lambda c: f'a sculpture of a {c}.',
177 | lambda c: f'a photo of the hard to see {c}.',
178 | lambda c: f'a low resolution photo of the {c}.',
179 | lambda c: f'a rendering of a {c}.',
180 | lambda c: f'graffiti of a {c}.',
181 | lambda c: f'a bad photo of the {c}.',
182 | lambda c: f'a cropped photo of the {c}.',
183 | lambda c: f'a tattoo of a {c}.',
184 | lambda c: f'the embroidered {c}.',
185 | lambda c: f'a photo of a hard to see {c}.',
186 | lambda c: f'a bright photo of a {c}.',
187 | lambda c: f'a photo of a clean {c}.',
188 | lambda c: f'a photo of a dirty {c}.',
189 | lambda c: f'a dark photo of the {c}.',
190 | lambda c: f'a drawing of a {c}.',
191 | lambda c: f'a photo of my {c}.',
192 | lambda c: f'the plastic {c}.',
193 | lambda c: f'a photo of the cool {c}.',
194 | lambda c: f'a close-up photo of a {c}.',
195 | lambda c: f'a black and white photo of the {c}.',
196 | lambda c: f'a painting of the {c}.',
197 | lambda c: f'a painting of a {c}.',
198 | lambda c: f'a pixelated photo of the {c}.',
199 | lambda c: f'a sculpture of the {c}.',
200 | lambda c: f'a bright photo of the {c}.',
201 | lambda c: f'a cropped photo of a {c}.',
202 | lambda c: f'a plastic {c}.',
203 | lambda c: f'a photo of the dirty {c}.',
204 | lambda c: f'a jpeg corrupted photo of a {c}.',
205 | lambda c: f'a blurry photo of the {c}.',
206 | lambda c: f'a photo of the {c}.',
207 | lambda c: f'a good photo of the {c}.',
208 | lambda c: f'a rendering of the {c}.',
209 | lambda c: f'a {c} in a video game.',
210 | lambda c: f'a photo of one {c}.',
211 | lambda c: f'a doodle of a {c}.',
212 | lambda c: f'a close-up photo of the {c}.',
213 | lambda c: f'a photo of a {c}.',
214 | lambda c: f'the origami {c}.',
215 | lambda c: f'the {c} in a video game.',
216 | lambda c: f'a sketch of a {c}.',
217 | lambda c: f'a doodle of the {c}.',
218 | lambda c: f'a origami {c}.',
219 | lambda c: f'a low resolution photo of a {c}.',
220 | lambda c: f'the toy {c}.',
221 | lambda c: f'a rendition of the {c}.',
222 | lambda c: f'a photo of the clean {c}.',
223 | lambda c: f'a photo of a large {c}.',
224 | lambda c: f'a rendition of a {c}.',
225 | lambda c: f'a photo of a nice {c}.',
226 | lambda c: f'a photo of a weird {c}.',
227 | lambda c: f'a blurry photo of a {c}.',
228 | lambda c: f'a cartoon {c}.',
229 | lambda c: f'art of a {c}.',
230 | lambda c: f'a sketch of the {c}.',
231 | lambda c: f'a embroidered {c}.',
232 | lambda c: f'a pixelated photo of a {c}.',
233 | lambda c: f'itap of the {c}.',
234 | lambda c: f'a jpeg corrupted photo of the {c}.',
235 | lambda c: f'a good photo of a {c}.',
236 | lambda c: f'a plushie {c}.',
237 | lambda c: f'a photo of the nice {c}.',
238 | lambda c: f'a photo of the small {c}.',
239 | lambda c: f'a photo of the weird {c}.',
240 | lambda c: f'the cartoon {c}.',
241 | lambda c: f'art of the {c}.',
242 | lambda c: f'a drawing of the {c}.',
243 | lambda c: f'a photo of the large {c}.',
244 | lambda c: f'a black and white photo of a {c}.',
245 | lambda c: f'the plushie {c}.',
246 | lambda c: f'a dark photo of a {c}.',
247 | lambda c: f'itap of a {c}.',
248 | lambda c: f'graffiti of the {c}.',
249 | lambda c: f'a toy {c}.',
250 | lambda c: f'itap of my {c}.',
251 | lambda c: f'a photo of a cool {c}.',
252 | lambda c: f'a photo of a small {c}.',
253 | lambda c: f'a tattoo of the {c}.',
254 | ]
255 |
--------------------------------------------------------------------------------
/src/training/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 |
4 | def setup_logging(log_file, level, include_host=False):
5 | if include_host:
6 | import socket
7 | hostname = socket.gethostname()
8 | formatter = logging.Formatter(
9 | f'%(asctime)s | {hostname} | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S')
10 | else:
11 | formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S')
12 |
13 | logging.root.setLevel(level)
14 | loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
15 | for logger in loggers:
16 | logger.setLevel(level)
17 |
18 | stream_handler = logging.StreamHandler()
19 | stream_handler.setFormatter(formatter)
20 | logging.root.addHandler(stream_handler)
21 |
22 | if log_file:
23 | file_handler = logging.FileHandler(filename=log_file)
24 | file_handler.setFormatter(formatter)
25 | logging.root.addHandler(file_handler)
26 |
27 |
--------------------------------------------------------------------------------
/src/training/main.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import random
4 | from datetime import datetime
5 |
6 | import numpy as np
7 | import torch
8 | from torch import optim
9 | from torch.cuda.amp import GradScaler
10 |
11 | try:
12 | import wandb
13 | except ImportError:
14 | wandb = None
15 |
16 | try:
17 | import torch.utils.tensorboard as tensorboard
18 | except ImportError:
19 | tensorboard = None
20 |
21 | try:
22 | import horovod.torch as hvd
23 | except ImportError:
24 | hvd = None
25 |
26 | from open_clip import create_model_and_transforms, trace_model
27 | from training.data import get_data
28 | from training.distributed import is_master, init_distributed_device, world_info_from_env
29 | from training.logger import setup_logging
30 | from training.params import parse_args
31 | from training.scheduler import cosine_lr
32 | from training.train import train_one_epoch, evaluate
33 |
34 |
35 | def random_seed(seed=42, rank=0):
36 | torch.manual_seed(seed + rank)
37 | np.random.seed(seed + rank)
38 | random.seed(seed + rank)
39 |
40 |
41 | def main():
42 | args = parse_args()
43 |
44 | # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule?
45 | args.model = args.model.replace('/', '-')
46 |
47 | # get the name of the experiments
48 | if args.name is None:
49 | args.name = '-'.join([
50 | datetime.now().strftime("%Y_%m_%d-%H_%M_%S"),
51 | f"model_{args.model}",
52 | f"lr_{args.lr}",
53 | f"b_{args.batch_size}",
54 | f"j_{args.workers}",
55 | f"p_{args.precision}",
56 | ])
57 |
58 | # discover initial world args early so we can log properly
59 | args.distributed = False
60 | args.local_rank, args.rank, args.world_size = world_info_from_env()
61 |
62 | args.log_path = None
63 | if is_master(args, local=args.log_local):
64 | log_base_path = os.path.join(args.logs, args.name)
65 | os.makedirs(log_base_path, exist_ok=True)
66 | log_filename = f'out-{args.rank}' if args.log_local else 'out.log'
67 | args.log_path = os.path.join(log_base_path, log_filename)
68 | if os.path.exists(args.log_path):
69 | print(
70 | "Error. Experiment already exists. Use --name {} to specify a new experiment."
71 | )
72 | return -1
73 |
74 | # Set logger
75 | args.log_level = logging.DEBUG if args.debug else logging.INFO
76 | setup_logging(args.log_path, args.log_level)
77 |
78 | # fully initialize distributed device environment
79 | torch.backends.cudnn.benchmark = True
80 | torch.backends.cudnn.deterministic = False
81 | device = init_distributed_device(args)
82 |
83 | args.wandb = 'wandb' in args.report_to or 'all' in args.report_to
84 | args.tensorboard = 'tensorboard' in args.report_to or 'all' in args.report_to
85 | if is_master(args):
86 | args.tensorboard_path = os.path.join(args.logs, args.name, "tensorboard") if args.tensorboard else ''
87 | args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints")
88 | for dirname in [args.tensorboard_path, args.checkpoint_path]:
89 | if dirname:
90 | os.makedirs(dirname, exist_ok=True)
91 | else:
92 | args.tensorboard_path = ''
93 | args.checkpoint_path = ''
94 |
95 | if args.copy_codebase:
96 | copy_codebase(args)
97 |
98 | assert args.precision in ['amp', 'fp16', 'fp32']
99 | if args.precision == 'fp16':
100 | logging.warning(
101 | 'It is recommended to use AMP mixed-precision instead of FP16. '
102 | 'FP16 support needs further verification and tuning, especially for train.')
103 |
104 | if args.horovod:
105 | logging.info(
106 | f'Running in horovod mode with multiple processes / nodes. Device: {args.device}.'
107 | f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.')
108 | elif args.distributed:
109 | logging.info(
110 | f'Running in distributed mode with multiple processes. Device: {args.device}.'
111 | f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.')
112 | else:
113 | logging.info(f'Running with a single process. Device {args.device}.')
114 |
115 | random_seed(args.seed, 0)
116 | model, preprocess_train, preprocess_val = create_model_and_transforms(
117 | args.model,
118 | args.pretrained,
119 | precision=args.precision,
120 | device=device,
121 | jit=args.torchscript,
122 | force_quick_gelu=args.force_quick_gelu,
123 | pretrained_image=args.pretrained_image,
124 | )
125 | random_seed(args.seed, args.rank)
126 |
127 | if args.trace:
128 | model = trace_model(model, batch_size=args.batch_size, device=device)
129 |
130 | if args.lock_image:
131 | # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
132 | model.lock_image_tower(
133 | unlocked_groups=args.lock_image_unlocked_groups,
134 | freeze_bn_stats=args.lock_image_freeze_bn_stats)
135 |
136 | if args.grad_checkpointing:
137 | model.set_grad_checkpointing()
138 |
139 | if is_master(args):
140 | logging.info("Model:")
141 | logging.info(f"{str(model)}")
142 | logging.info("Params:")
143 | params_file = os.path.join(args.logs, args.name, "params.txt")
144 | with open(params_file, "w") as f:
145 | for name in sorted(vars(args)):
146 | val = getattr(args, name)
147 | logging.info(f" {name}: {val}")
148 | f.write(f"{name}: {val}\n")
149 |
150 | if args.distributed and not args.horovod:
151 | if args.use_bn_sync:
152 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
153 | ddp_args = {}
154 | if args.ddp_static_graph:
155 | # this doesn't exist in older PyTorch, arg only added if enabled
156 | ddp_args['static_graph'] = True
157 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args)
158 |
159 | # create optimizer and scaler
160 | optimizer = None
161 | scaler = None
162 | if args.train_data:
163 | assert not args.trace, 'Cannot train with traced model'
164 |
165 | exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n
166 | include = lambda n, p: not exclude(n, p)
167 |
168 | named_parameters = list(model.named_parameters())
169 | gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad]
170 | rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad]
171 |
172 | optimizer = optim.AdamW(
173 | [
174 | {"params": gain_or_bias_params, "weight_decay": 0.},
175 | {"params": rest_params, "weight_decay": args.wd},
176 | ],
177 | lr=args.lr,
178 | betas=(args.beta1, args.beta2),
179 | eps=args.eps,
180 | )
181 | if args.horovod:
182 | optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
183 | hvd.broadcast_parameters(model.state_dict(), root_rank=0)
184 | hvd.broadcast_optimizer_state(optimizer, root_rank=0)
185 |
186 | scaler = GradScaler() if args.precision == "amp" else None
187 |
188 | # optionally resume from a checkpoint
189 | start_epoch = 0
190 | if args.resume is not None:
191 | if os.path.isfile(args.resume):
192 | checkpoint = torch.load(args.resume, map_location=device)
193 | if 'epoch' in checkpoint:
194 | # resuming a train checkpoint w/ epoch and optimizer state
195 | start_epoch = checkpoint["epoch"]
196 | sd = checkpoint["state_dict"]
197 | if not args.distributed and next(iter(sd.items()))[0].startswith('module'):
198 | sd = {k[len('module.'):]: v for k, v in sd.items()}
199 | model.load_state_dict(sd)
200 | if optimizer is not None:
201 | optimizer.load_state_dict(checkpoint["optimizer"])
202 | if scaler is not None and 'scaler' in checkpoint:
203 | scaler.load_state_dict(checkpoint['scaler'])
204 | logging.info(f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})")
205 | else:
206 | # loading a bare (model only) checkpoint for fine-tune or evaluation
207 | model.load_state_dict(checkpoint)
208 | logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})")
209 | else:
210 | logging.info("=> no checkpoint found at '{}'".format(args.resume))
211 |
212 | # initialize datasets
213 | data = get_data(args, (preprocess_train, preprocess_val), epoch=start_epoch)
214 | assert len(data), 'At least one train or eval dataset must be specified.'
215 |
216 | # create scheduler if train
217 | scheduler = None
218 | if 'train' in data and optimizer is not None:
219 | total_steps = data["train"].dataloader.num_batches * args.epochs
220 | scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps)
221 |
222 | # determine if this worker should save logs and checkpoints. only do so if it is rank == 0
223 | args.save_logs = args.logs and args.logs.lower() != 'none' and is_master(args)
224 | writer = None
225 | if args.save_logs and args.tensorboard:
226 | assert tensorboard is not None, "Please install tensorboard."
227 | writer = tensorboard.SummaryWriter(args.tensorboard_path)
228 |
229 | if args.wandb and is_master(args):
230 | assert wandb is not None, 'Please install wandb.'
231 | logging.debug('Starting wandb.')
232 | args.train_sz = data["train"].dataloader.num_samples
233 | if args.val_data is not None:
234 | args.val_sz = data["val"].dataloader.num_samples
235 | # you will have to configure this for your project!
236 | wandb.init(
237 | project="open-clip",
238 | notes=args.wandb_notes,
239 | tags=[],
240 | config=vars(args),
241 | )
242 | if args.debug:
243 | wandb.watch(model, log='all')
244 | wandb.save(params_file)
245 | logging.debug('Finished loading wandb.')
246 |
247 | if 'train' not in data:
248 | evaluate(model, data, start_epoch, args, writer)
249 | return
250 |
251 | for epoch in range(start_epoch, args.epochs):
252 | if is_master(args):
253 | logging.info(f'Start epoch {epoch}')
254 |
255 | train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer)
256 | completed_epoch = epoch + 1
257 |
258 | if any(v in data for v in ('val', 'imagenet-val', 'imagenet-v2')):
259 | evaluate(model, data, completed_epoch, args, writer)
260 |
261 | # Saving checkpoints.
262 | if args.save_logs:
263 | checkpoint_dict = {
264 | "epoch": completed_epoch,
265 | "name": args.name,
266 | "state_dict": model.state_dict(),
267 | "optimizer": optimizer.state_dict(),
268 | }
269 | if scaler is not None:
270 | checkpoint_dict["scaler"] = scaler.state_dict()
271 |
272 | if completed_epoch == args.epochs or (
273 | args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0
274 | ):
275 | torch.save(
276 | checkpoint_dict,
277 | os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"),
278 | )
279 | if args.save_most_recent:
280 | torch.save(
281 | checkpoint_dict,
282 | os.path.join(args.checkpoint_path, f"epoch_latest.pt"),
283 | )
284 |
285 | if args.wandb and is_master(args):
286 | wandb.finish()
287 |
288 |
289 | def copy_codebase(args):
290 | from shutil import copytree, ignore_patterns
291 | new_code_path = os.path.join(args.logs, args.name, "code")
292 | if os.path.exists(new_code_path):
293 | print(
294 | f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment."
295 | )
296 | return -1
297 | print(f"Copying codebase to {new_code_path}")
298 | current_code_path = os.path.realpath(__file__)
299 | for _ in range(3):
300 | current_code_path = os.path.dirname(current_code_path)
301 | copytree(current_code_path, new_code_path, ignore=ignore_patterns('log', 'logs', 'wandb'))
302 | print("Done copying code.")
303 | return 1
304 |
305 |
306 | if __name__ == "__main__":
307 | main()
308 |
--------------------------------------------------------------------------------
/src/training/params.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def get_default_params(model_name):
5 | # Params from paper (https://arxiv.org/pdf/2103.00020.pdf)
6 | model_name = model_name.lower()
7 | if "vit" in model_name:
8 | return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6}
9 | else:
10 | return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8}
11 |
12 |
13 | def parse_args():
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument(
16 | "--train-data",
17 | type=str,
18 | default=None,
19 | help="Path to csv filewith training data",
20 | )
21 | parser.add_argument(
22 | "--val-data",
23 | type=str,
24 | default=None,
25 | help="Path to csv file with validation data",
26 | )
27 | parser.add_argument(
28 | "--train-num-samples",
29 | type=int,
30 | default=None,
31 | help="Number of samples in dataset. Required for webdataset if not available in info file.",
32 | )
33 | parser.add_argument(
34 | "--val-num-samples",
35 | type=int,
36 | default=None,
37 | help="Number of samples in dataset. Useful for webdataset if not available in info file.",
38 | )
39 | parser.add_argument(
40 | "--dataset-type",
41 | choices=["webdataset", "csv", "auto"],
42 | default="auto",
43 | help="Which type of dataset to process."
44 | )
45 | parser.add_argument(
46 | "--dataset-resampled",
47 | default=False,
48 | action="store_true",
49 | help="Whether to use sampling with replacement for webdataset shard selection."
50 | )
51 | parser.add_argument(
52 | "--csv-separator",
53 | type=str,
54 | default="\t",
55 | help="For csv-like datasets, which separator to use."
56 | )
57 | parser.add_argument(
58 | "--csv-img-key",
59 | type=str,
60 | default="filepath",
61 | help="For csv-like datasets, the name of the key for the image paths."
62 | )
63 | parser.add_argument(
64 | "--csv-hard-captions-key",
65 | type=str,
66 | default="neg_caption",
67 | help="For csv-like datasets, the name of the key for the hard captions."
68 | )
69 | parser.add_argument(
70 | "--csv-caption-key",
71 | type=str,
72 | default="title",
73 | help="For csv-like datasets, the name of the key for the captions."
74 | )
75 | parser.add_argument(
76 | "--imagenet-val",
77 | type=str,
78 | default=None,
79 | help="Path to imagenet val set for conducting zero shot evaluation.",
80 | )
81 | parser.add_argument(
82 | "--imagenet-v2",
83 | type=str,
84 | default=None,
85 | help="Path to imagenet v2 for conducting zero shot evaluation.",
86 | )
87 | parser.add_argument(
88 | "--logs",
89 | type=str,
90 | default="./logs/",
91 | help="Where to store tensorboard logs. Use None to avoid storing logs.",
92 | )
93 | parser.add_argument(
94 | "--log-local",
95 | action="store_true",
96 | default=False,
97 | help="log files on local master, otherwise global master only.",
98 | )
99 | parser.add_argument(
100 | "--name",
101 | type=str,
102 | default=None,
103 | help="Optional identifier for the experiment when storing logs. Otherwise use current time.",
104 | )
105 | parser.add_argument(
106 | "--workers", type=int, default=1, help="Number of dataloader workers per GPU."
107 | )
108 | parser.add_argument(
109 | "--batch-size", type=int, default=64, help="Batch size per GPU."
110 | )
111 | parser.add_argument(
112 | "--epochs", type=int, default=32, help="Number of epochs to train for."
113 | )
114 | parser.add_argument("--lr", type=float, default=None, help="Learning rate.")
115 | parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.")
116 | parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.")
117 | parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.")
118 | parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.")
119 | parser.add_argument(
120 | "--warmup", type=int, default=10000, help="Number of steps to warmup for."
121 | )
122 | parser.add_argument(
123 | "--use-bn-sync",
124 | default=False,
125 | action="store_true",
126 | help="Whether to use batch norm sync.")
127 | parser.add_argument(
128 | "--skip-scheduler",
129 | action="store_true",
130 | default=False,
131 | help="Use this flag to skip the learning rate decay.",
132 | )
133 | parser.add_argument(
134 | "--save-frequency", type=int, default=1, help="How often to save checkpoints."
135 | )
136 | parser.add_argument(
137 | "--save-most-recent",
138 | action="store_true",
139 | default=False,
140 | help="Always save the most recent model trained to epoch_latest.pt.",
141 | )
142 | parser.add_argument(
143 | "--zeroshot-frequency", type=int, default=2, help="How often to run zero shot."
144 | )
145 | parser.add_argument(
146 | "--val-frequency", type=int, default=1, help="How often to run evaluation with val data."
147 | )
148 | parser.add_argument(
149 | "--resume",
150 | default=None,
151 | type=str,
152 | help="path to latest checkpoint (default: none)",
153 | )
154 | parser.add_argument(
155 | "--precision",
156 | choices=["amp", "fp16", "fp32"],
157 | default="amp",
158 | help="Floating point precision."
159 | )
160 | parser.add_argument(
161 | "--model",
162 | type=str,
163 | default="RN50",
164 | help="Name of the vision backbone to use.",
165 | )
166 | parser.add_argument(
167 | "--pretrained",
168 | default='',
169 | type=str,
170 | help="Use a pretrained CLIP model weights with the specified tag or file path.",
171 | )
172 | parser.add_argument(
173 | "--pretrained-image",
174 | default=False,
175 | action='store_true',
176 | help="Load imagenet pretrained weights for image tower backbone if available.",
177 | )
178 | parser.add_argument(
179 | "--lock-image",
180 | default=False,
181 | action='store_true',
182 | help="Lock full image tower by disabling gradients.",
183 | )
184 | parser.add_argument(
185 | "--lock-image-unlocked-groups",
186 | type=int,
187 | default=0,
188 | help="Leave last n image tower layer groups unlocked.",
189 | )
190 | parser.add_argument(
191 | "--lock-image-freeze-bn-stats",
192 | default=False,
193 | action='store_true',
194 | help="Freeze BatchNorm running stats in image tower for any locked layers.",
195 | )
196 | parser.add_argument(
197 | "--grad-checkpointing",
198 | default=False,
199 | action='store_true',
200 | help="Enable gradient checkpointing.",
201 | )
202 | parser.add_argument(
203 | "--local-loss",
204 | default=False,
205 | action="store_true",
206 | help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)"
207 | )
208 | parser.add_argument(
209 | "--gather-with-grad",
210 | default=False,
211 | action="store_true",
212 | help="enable full distributed gradient for feature gather"
213 | )
214 | parser.add_argument(
215 | "--force-quick-gelu",
216 | default=False,
217 | action='store_true',
218 | help="Force use of QuickGELU activation for non-OpenAI transformer models.",
219 | )
220 | parser.add_argument(
221 | "--torchscript",
222 | default=False,
223 | action='store_true',
224 | help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'",
225 | )
226 | parser.add_argument(
227 | "--trace",
228 | default=False,
229 | action='store_true',
230 | help="torch.jit.trace the model for inference / eval only",
231 | )
232 | # arguments for distributed training
233 | parser.add_argument(
234 | "--dist-url",
235 | default="env://",
236 | type=str,
237 | help="url used to set up distributed training",
238 | )
239 | parser.add_argument(
240 | "--dist-backend", default="nccl", type=str, help="distributed backend"
241 | )
242 | parser.add_argument(
243 | "--report-to",
244 | default='',
245 | type=str,
246 | help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']"
247 | )
248 | parser.add_argument(
249 | "--wandb-notes",
250 | default='',
251 | type=str,
252 | help="Notes if logging with wandb"
253 | )
254 | parser.add_argument(
255 | "--debug",
256 | default=False,
257 | action="store_true",
258 | help="If true, more information is logged."
259 | )
260 | parser.add_argument(
261 | "--copy-codebase",
262 | default=False,
263 | action="store_true",
264 | help="If true, we copy the entire base on the log diretory, and execute from there."
265 | )
266 | parser.add_argument(
267 | "--horovod",
268 | default=False,
269 | action="store_true",
270 | help="Use horovod for distributed training."
271 | )
272 | parser.add_argument(
273 | "--ddp-static-graph",
274 | default=False,
275 | action='store_true',
276 | help="Enable static graph optimization for DDP in PyTorch >= 1.11.",
277 | )
278 | parser.add_argument(
279 | "--no-set-device-rank",
280 | default=False,
281 | action="store_true",
282 | help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc)."
283 | )
284 | parser.add_argument(
285 | "--seed", type=int, default=0, help="Default random seed."
286 | )
287 | parser.add_argument(
288 | "--norm_gradient_clip", type=float, default=None, help="Gradient clip."
289 | )
290 | args = parser.parse_args()
291 |
292 | # If some params are not passed, we use the default values based on model name.
293 | default_params = get_default_params(args.model)
294 | for name, val in default_params.items():
295 | if getattr(args, name) is None:
296 | setattr(args, name, val)
297 |
298 | return args
299 |
--------------------------------------------------------------------------------
/src/training/scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def assign_learning_rate(optimizer, new_lr):
5 | for param_group in optimizer.param_groups:
6 | param_group["lr"] = new_lr
7 |
8 |
9 | def _warmup_lr(base_lr, warmup_length, step):
10 | return base_lr * (step + 1) / warmup_length
11 |
12 |
13 | def cosine_lr(optimizer, base_lr, warmup_length, steps):
14 | def _lr_adjuster(step):
15 | if step < warmup_length:
16 | lr = _warmup_lr(base_lr, warmup_length, step)
17 | else:
18 | e = step - warmup_length
19 | es = steps - warmup_length
20 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
21 | assign_learning_rate(optimizer, lr)
22 | return lr
23 | return _lr_adjuster
--------------------------------------------------------------------------------
/src/training/train.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import math
4 | import os
5 | import time
6 | from contextlib import suppress
7 |
8 | import numpy as np
9 | import torch
10 | import torch.nn.functional as F
11 |
12 | try:
13 | import wandb
14 | except ImportError:
15 | wandb = None
16 |
17 | from open_clip import ClipLoss
18 | from .distributed import is_master
19 | from .zero_shot import zero_shot_eval
20 |
21 |
22 | class AverageMeter(object):
23 | """Computes and stores the average and current value"""
24 | def __init__(self):
25 | self.reset()
26 |
27 | def reset(self):
28 | self.val = 0
29 | self.avg = 0
30 | self.sum = 0
31 | self.count = 0
32 |
33 | def update(self, val, n=1):
34 | self.val = val
35 | self.sum += val * n
36 | self.count += n
37 | self.avg = self.sum / self.count
38 |
39 |
40 | def unwrap_model(model):
41 | if hasattr(model, 'module'):
42 | return model.module
43 | else:
44 | return model
45 |
46 |
47 | def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None):
48 | device = torch.device(args.device)
49 | autocast = torch.cuda.amp.autocast if args.precision == 'amp' else suppress
50 |
51 | model.train()
52 | loss = ClipLoss(
53 | local_loss=args.local_loss,
54 | gather_with_grad=args.gather_with_grad,
55 | cache_labels=True,
56 | rank=args.rank,
57 | world_size=args.world_size,
58 | use_horovod=args.horovod)
59 |
60 | data['train'].set_epoch(epoch) # set epoch in process safe manner via sampler or shared_epoch
61 | dataloader = data['train'].dataloader
62 | num_batches_per_epoch = dataloader.num_batches
63 | sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))
64 |
65 | loss_m = AverageMeter()
66 | batch_time_m = AverageMeter()
67 | data_time_m = AverageMeter()
68 | end = time.time()
69 | for i, batch in enumerate(dataloader):
70 | step = num_batches_per_epoch * epoch + i
71 | scheduler(step)
72 |
73 | images, new_images, texts, new_texts, hard_captions, new_hard = batch
74 |
75 | images = images.to(device=device, non_blocking=True)
76 | new_images = new_images.to(device=device, non_blocking=True)
77 |
78 | texts = texts.to(device=device, non_blocking=True)
79 | new_texts = new_texts.to(device=device, non_blocking=True)
80 |
81 | hard_captions = hard_captions.to(device=device, non_blocking=True)
82 | new_hard = new_hard.to(device=device, non_blocking=True)
83 |
84 | images = torch.cat([images, new_images])
85 |
86 | texts = torch.cat([texts, new_texts])
87 | texts = torch.cat([texts, hard_captions])
88 | texts = torch.cat([texts, new_hard])
89 |
90 | data_time_m.update(time.time() - end)
91 | optimizer.zero_grad()
92 |
93 | with autocast():
94 | image_features, text_features, logit_scale = model(images, texts)
95 | total_loss = loss(image_features, text_features, logit_scale)
96 |
97 | if scaler is not None:
98 | scaler.scale(total_loss).backward()
99 | if args.horovod:
100 | optimizer.synchronize()
101 | scaler.unscale_(optimizer)
102 | if args.norm_gradient_clip is not None:
103 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.norm_gradient_clip, norm_type=2.0)
104 | with optimizer.skip_synchronize():
105 | scaler.step(optimizer)
106 | else:
107 | if args.norm_gradient_clip is not None:
108 | scaler.unscale_(optimizer)
109 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.norm_gradient_clip, norm_type=2.0)
110 | scaler.step(optimizer)
111 | scaler.update()
112 | else:
113 | total_loss.backward()
114 | if args.norm_gradient_clip is not None:
115 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.norm_gradient_clip, norm_type=2.0)
116 | optimizer.step()
117 |
118 | # Note: we clamp to 4.6052 = ln(100), as in the original paper.
119 | with torch.no_grad():
120 | unwrap_model(model).logit_scale.clamp_(0, math.log(100))
121 |
122 | batch_time_m.update(time.time() - end)
123 | end = time.time()
124 | batch_count = i + 1
125 | if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch):
126 | batch_size = len(images)
127 | num_samples = batch_count * batch_size * args.world_size
128 | samples_per_epoch = dataloader.num_samples
129 | percent_complete = 100.0 * batch_count / num_batches_per_epoch
130 |
131 | # NOTE loss is coarsely sampled, just master node and per log update
132 | loss_m.update(total_loss.item(), batch_size)
133 | logit_scale_scalar = logit_scale.item()
134 | logging.info(
135 | f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
136 | f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
137 | f"Data (t): {data_time_m.avg:.3f} "
138 | f"Batch (t): {batch_time_m.avg:.3f}, {args.batch_size*args.world_size / batch_time_m.val:#g}/s "
139 | f"LR: {optimizer.param_groups[0]['lr']:5f} "
140 | f"Logit Scale: {logit_scale_scalar:.3f} - V4"
141 | )
142 |
143 | # Save train loss / etc. Using non avg meter values as loggers have their own smoothing
144 | log_data = {
145 | "loss": loss_m.val,
146 | "data_time": data_time_m.val,
147 | "batch_time": batch_time_m.val,
148 | "samples_per_scond": args.batch_size*args.world_size / batch_time_m.val,
149 | "scale": logit_scale_scalar,
150 | "lr": optimizer.param_groups[0]["lr"]
151 | }
152 | for name, val in log_data.items():
153 | name = "train/" + name
154 | if tb_writer is not None:
155 | tb_writer.add_scalar(name, val, step)
156 | if args.wandb:
157 | assert wandb is not None, 'Please install wandb.'
158 | wandb.log({name: val, 'step': step})
159 |
160 | # resetting batch / data time meters per log window
161 | batch_time_m.reset()
162 | data_time_m.reset()
163 | # end for
164 |
165 |
166 | def evaluate(model, data, epoch, args, tb_writer=None):
167 | metrics = {}
168 | if not is_master(args):
169 | return metrics
170 | device = torch.device(args.device)
171 | model.eval()
172 |
173 | zero_shot_metrics = zero_shot_eval(model, data, epoch, args)
174 | metrics.update(zero_shot_metrics)
175 |
176 | autocast = torch.cuda.amp.autocast if args.precision == 'amp' else suppress
177 | if 'val' in data and (args.val_frequency and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)):
178 | dataloader = data['val'].dataloader
179 | num_samples = 0
180 | samples_per_val = dataloader.num_samples
181 |
182 | # FIXME this does not scale past small eval datasets
183 | # all_image_features @ all_text_features will blow up memory and compute very quickly
184 | cumulative_loss = 0.0
185 | all_image_features, all_text_features = [], []
186 | with torch.no_grad():
187 | for i, batch in enumerate(dataloader):
188 | images, hard_images, texts, texts_hard_images, hard_captions, hard_captions_of_hard_images = batch
189 |
190 | images = images.to(device=device, non_blocking=True)
191 | hard_images = hard_images.to(device=device, non_blocking=True)
192 |
193 | texts = texts.to(device=device, non_blocking=True)
194 | texts_hard_images = texts_hard_images.to(device=device, non_blocking=True)
195 |
196 | hard_captions = hard_captions.to(device=device, non_blocking=True)
197 | hard_captions_of_hard_images = hard_captions_of_hard_images.to(device=device, non_blocking=True)
198 |
199 | images = torch.cat([images, hard_images])
200 |
201 | texts = torch.cat([texts, texts_hard_images])
202 | texts = torch.cat([texts, hard_captions])
203 | texts = torch.cat([texts, hard_captions_of_hard_images])
204 |
205 | with autocast():
206 | image_features, text_features, logit_scale = model(images, texts)
207 | # features are accumulated in CPU tensors, otherwise GPU memory exhausted quickly
208 | # however, system RAM is easily exceeded and compute time becomes problematic
209 |
210 | logit_scale = logit_scale.mean()
211 | logits_per_image = logit_scale * image_features @ text_features.t()
212 | logits_per_text = logits_per_image.t()
213 |
214 | all_image_features.append(image_features.cpu())
215 | all_text_features.append(text_features[:len(logits_per_image)].cpu())
216 |
217 | batch_size = images.shape[0]
218 | labels = torch.arange(batch_size, device=device).long()
219 | total_loss = (
220 | F.cross_entropy(logits_per_image, labels) +
221 | F.cross_entropy(logits_per_text[:len(logits_per_image)], labels)
222 | ) / 2
223 |
224 | cumulative_loss += total_loss * batch_size
225 | num_samples += batch_size
226 | if is_master(args) and (i % 100) == 0:
227 | logging.info(
228 | f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]\t"
229 | f"Loss: {cumulative_loss / num_samples:.6f}\t")
230 |
231 | val_metrics = get_metrics(
232 | image_features=torch.cat(all_image_features),
233 | text_features=torch.cat(all_text_features),
234 | logit_scale=logit_scale.cpu(),
235 | )
236 | loss = cumulative_loss / num_samples
237 | metrics.update(
238 | {**val_metrics, "val_loss": loss.item(), "epoch": epoch, "num_samples": num_samples}
239 | )
240 |
241 | if not metrics:
242 | return metrics
243 |
244 | logging.info(
245 | f"Eval Epoch: {epoch} "
246 | + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])
247 | )
248 |
249 | if args.save_logs:
250 | for name, val in metrics.items():
251 | if tb_writer is not None:
252 | tb_writer.add_scalar(f"val/{name}", val, epoch)
253 |
254 | with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
255 | f.write(json.dumps(metrics))
256 | f.write("\n")
257 |
258 | if args.wandb:
259 | assert wandb is not None, 'Please install wandb.'
260 | for name, val in metrics.items():
261 | wandb.log({f"val/{name}": val, 'epoch': epoch})
262 |
263 | return metrics
264 |
265 |
266 | def get_metrics(image_features, text_features, logit_scale):
267 | metrics = {}
268 | logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu()
269 | logits_per_text = logits_per_image.t().detach().cpu()
270 |
271 | logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text[:len(logits_per_image)]}
272 | ground_truth = torch.arange(len(text_features)).view(-1, 1)
273 |
274 | for name, logit in logits.items():
275 | ranking = torch.argsort(logit, descending=True)
276 | preds = torch.where(ranking == ground_truth)[1]
277 | preds = preds.detach().cpu().numpy()
278 | metrics[f"{name}_mean_rank"] = preds.mean() + 1
279 | metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1
280 | for k in [1, 5, 10]:
281 | metrics[f"{name}_R@{k}"] = np.mean(preds < k)
282 |
283 | return metrics
284 |
--------------------------------------------------------------------------------
/src/training/zero_shot.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from contextlib import suppress
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from tqdm import tqdm
7 |
8 | from open_clip import tokenize
9 | from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template
10 |
11 |
12 | def zero_shot_classifier(model, classnames, templates, args):
13 | with torch.no_grad():
14 | zeroshot_weights = []
15 | for classname in tqdm(classnames):
16 | texts = [template(classname) for template in templates] # format with class
17 | texts = tokenize(texts).to(args.device) # tokenize
18 | if args.distributed and not args.horovod:
19 | class_embeddings = model.module.encode_text(texts)
20 | else:
21 | class_embeddings = model.encode_text(texts)
22 | class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
23 | class_embedding /= class_embedding.norm()
24 | zeroshot_weights.append(class_embedding)
25 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.device)
26 | return zeroshot_weights
27 |
28 |
29 | def accuracy(output, target, topk=(1,)):
30 | pred = output.topk(max(topk), 1, True, True)[1].t()
31 | correct = pred.eq(target.view(1, -1).expand_as(pred))
32 | return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]
33 |
34 |
35 | def run(model, classifier, dataloader, args):
36 | autocast = torch.cuda.amp.autocast if args.precision == 'amp' else suppress
37 | with torch.no_grad():
38 | top1, top5, n = 0., 0., 0.
39 | for images, target in tqdm(dataloader, unit_scale=args.batch_size):
40 | images = images.to(args.device)
41 | target = target.to(args.device)
42 |
43 | with autocast():
44 | # predict
45 | if args.distributed and not args.horovod:
46 | image_features = model.module.encode_image(images)
47 | else:
48 | image_features = model.encode_image(images)
49 | image_features = F.normalize(image_features, dim=-1)
50 | logits = 100. * image_features @ classifier
51 |
52 | # measure accuracy
53 | acc1, acc5 = accuracy(logits, target, topk=(1, 5))
54 | top1 += acc1
55 | top5 += acc5
56 | n += images.size(0)
57 |
58 | top1 = (top1 / n)
59 | top5 = (top5 / n)
60 | return top1, top5
61 |
62 |
63 | def zero_shot_eval(model, data, epoch, args):
64 | if 'imagenet-val' not in data and 'imagenet-v2' not in data:
65 | return {}
66 | if args.zeroshot_frequency == 0:
67 | return {}
68 | if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs:
69 | return {}
70 |
71 | logging.info('Starting zero-shot imagenet.')
72 |
73 | logging.info('Building zero-shot classifier')
74 | classifier = zero_shot_classifier(model, imagenet_classnames, openai_imagenet_template, args)
75 |
76 | logging.info('Using classifier')
77 | results = {}
78 | if 'imagenet-val' in data:
79 | top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args)
80 | results['imagenet-zeroshot-val-top1'] = top1
81 | results['imagenet-zeroshot-val-top5'] = top5
82 | if 'imagenet-v2' in data:
83 | top1, top5 = run(model, classifier, data['imagenet-v2'].dataloader, args)
84 | results['imagenetv2-zeroshot-val-top1'] = top1
85 | results['imagenetv2-zeroshot-val-top5'] = top5
86 |
87 | logging.info('Finished zero-shot imagenet.')
88 |
89 | return results
90 |
--------------------------------------------------------------------------------
/tests/test_simple.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from PIL import Image
4 | from open_clip import tokenizer
5 | import open_clip
6 | import os
7 | os.environ["CUDA_VISIBLE_DEVICES"] = ""
8 |
9 | def test_inference():
10 | model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32-quickgelu', pretrained='laion400m_e32')
11 |
12 | current_dir = os.path.dirname(os.path.realpath(__file__))
13 |
14 | image = preprocess(Image.open(current_dir + "/../docs/CLIP.png")).unsqueeze(0)
15 | text = tokenizer.tokenize(["a diagram", "a dog", "a cat"])
16 |
17 | with torch.no_grad():
18 | image_features = model.encode_image(image)
19 | text_features = model.encode_text(text)
20 |
21 | text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
22 |
23 | assert text_probs.cpu().numpy()[0].tolist() == [1.0, 0.0, 0.0]
--------------------------------------------------------------------------------