├── .github └── workflows │ ├── ci.yml │ └── python-publish.yml ├── .gitignore ├── CITATION.cff ├── HISTORY.md ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── docs ├── CLIP.png ├── Interacting_with_open_clip.ipynb ├── clip_conceptual_captions.md ├── clip_loss.png ├── clip_recall.png ├── clip_val_loss.png ├── clip_zeroshot.png ├── effective_robustness.png ├── laion2b_clip_zeroshot_b32.png ├── laion_clip_zeroshot.png ├── laion_clip_zeroshot_b16.png ├── laion_clip_zeroshot_b16_plus_240.png ├── laion_clip_zeroshot_l14.png ├── laion_openai_compare_b32.jpg └── scaling.png ├── requirements-test.txt ├── requirements-training.txt ├── requirements.txt ├── setup.py ├── src ├── data │ └── gather_cc.py ├── open_clip │ ├── __init__.py │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── factory.py │ ├── loss.py │ ├── model.py │ ├── model_configs │ │ ├── RN101-quickgelu.json │ │ ├── RN101.json │ │ ├── RN50-quickgelu.json │ │ ├── RN50.json │ │ ├── RN50x16.json │ │ ├── RN50x4.json │ │ ├── ViT-B-16-plus-240.json │ │ ├── ViT-B-16-plus.json │ │ ├── ViT-B-16.json │ │ ├── ViT-B-32-plus-256.json │ │ ├── ViT-B-32-quickgelu.json │ │ ├── ViT-B-32.json │ │ ├── ViT-H-14.json │ │ ├── ViT-H-16.json │ │ ├── ViT-L-14-280.json │ │ ├── ViT-L-14-336.json │ │ ├── ViT-L-14.json │ │ ├── ViT-L-16-320.json │ │ ├── ViT-L-16.json │ │ ├── ViT-g-14.json │ │ ├── timm-efficientnetv2_rw_s.json │ │ ├── timm-resnet50d.json │ │ ├── timm-resnetaa50d.json │ │ ├── timm-resnetblur50.json │ │ ├── timm-swin_base_patch4_window7_224.json │ │ ├── timm-vit_base_patch16_224.json │ │ ├── timm-vit_base_patch32_224.json │ │ └── timm-vit_small_patch16_224.json │ ├── openai.py │ ├── pretrained.py │ ├── timm_model.py │ ├── tokenizer.py │ ├── transform.py │ ├── utils.py │ └── version.py └── training │ ├── .gitignore │ ├── __init__.py │ ├── data.py │ ├── distributed.py │ ├── imagenet_zeroshot_data.py │ ├── logger.py │ ├── main.py │ ├── params.py │ ├── scheduler.py │ ├── train.py │ └── zero_shot.py └── tests └── test_simple.py /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Continuous integration 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | tests: 13 | runs-on: ubuntu-latest 14 | strategy: 15 | matrix: 16 | python-version: [3.8] 17 | 18 | steps: 19 | - uses: actions/checkout@v2 20 | - name: Set up Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v2 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | - name: Install 25 | run: | 26 | python3 -m venv .env 27 | source .env/bin/activate 28 | make install 29 | make install-dev 30 | - name: Unit tests 31 | run: | 32 | source .env/bin/activate 33 | make test 34 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - uses: actions-ecosystem/action-regex-match@v2 13 | id: regex-match 14 | with: 15 | text: ${{ github.event.head_commit.message }} 16 | regex: '^Release ([^ ]+)' 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.8' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Release 26 | if: ${{ 
steps.regex-match.outputs.match != '' }} 27 | uses: softprops/action-gh-release@v1 28 | with: 29 | tag_name: v${{ steps.regex-match.outputs.group1 }} 30 | - name: Build and publish 31 | if: ${{ steps.regex-match.outputs.match != '' }} 32 | env: 33 | TWINE_USERNAME: __token__ 34 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 35 | run: | 36 | python setup.py sdist bdist_wheel 37 | twine upload dist/* 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | logs/ 2 | wandb/ 3 | models/ 4 | features/ 5 | results/ 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | sync.sh 137 | gpu1sync.sh 138 | .idea 139 | *.pdf 140 | **/._* 141 | **/*DS_* 142 | **.jsonl 143 | src/sbatch 144 | src/misc 145 | .vscode 146 | src/debug 147 | core.* 148 | 149 | # Allow 150 | !src/evaluation/misc/results_dbs/* -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.1.0 2 | message: If you use this software, please cite it as below. 3 | authors: 4 | - family-names: Ilharco 5 | given-names: Gabriel 6 | - family-names: Wortsman 7 | given-names: Mitchell 8 | - family-names: Wightman 9 | given-names: Ross 10 | - family-names: Gordon 11 | given-names: Cade 12 | - family-names: Carlini 13 | given-names: Nicholas 14 | - family-names: Taori 15 | given-names: Rohan 16 | - family-names: Dave 17 | given-names: Achal 18 | - family-names: Shankar 19 | given-names: Vaishaal 20 | - family-names: Namkoong 21 | given-names: Hongseok 22 | - family-names: Miller 23 | given-names: John 24 | - family-names: Hajishirzi 25 | given-names: Hannaneh 26 | - family-names: Farhadi 27 | given-names: Ali 28 | - family-names: Schmidt 29 | given-names: Ludwig 30 | title: OpenCLIP 31 | version: v0.1 32 | doi: 10.5281/zenodo.5143773 33 | date-released: 2021-07-28 34 | -------------------------------------------------------------------------------- /HISTORY.md: -------------------------------------------------------------------------------- 1 | ## 1.2.0 2 | 3 | * ViT-B/32 trained on Laion2B-en 4 | * add missing openai RN50x64 model 5 | 6 | ## 1.1.1 7 | 8 | * ViT-B/16+ 9 | * Add grad checkpointing support 10 | * more robust data loader 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2012-2021 Gabriel Ilharco, Mitchell Wortsman, 2 | Nicholas Carlini, Rohan Taori, Achal Dave, Vaishaal Shankar, 3 | John Miller, Hongseok Namkoong, Hannaneh Hajishirzi, Ali Farhadi, 4 | Ludwig Schmidt 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining 7 | a copy of this software and associated documentation files (the 8 | "Software"), to deal in the Software without restriction, including 9 | without limitation the rights to use, copy, modify, merge, publish, 10 | distribute, sublicense, and/or sell copies of the Software, and to 11 | permit persons to whom the Software is furnished to do so, subject to 12 | the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be 15 | included in all copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 18 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 19 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 20 | NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 21 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 22 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 23 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 24 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include src/open_clip/bpe_simple_vocab_16e6.txt.gz 2 | include src/open_clip/model_configs/*.json 3 | 4 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | install: ## [Local development] Upgrade pip, install requirements, install package. 2 | python -m pip install -U pip 3 | python -m pip install -e . 4 | 5 | install-dev: ## [Local development] Install test requirements 6 | python -m pip install -r requirements-test.txt 7 | 8 | test: ## [Local development] Run unit tests 9 | python -m pytest -x -s -v tests 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # When and why vision-language models behave like bags-of-words, and what to do about it? (ICLR 2023 Oral) 2 | 3 | **Note**: This code will not work with the distributed/multi-gpu setting as it is currently implemented. 4 | 5 | ## NegCLIP Implementation 6 | 7 | NegCLIP introduces a few simple edits to the original OpenCLIP base. To ease the code-reading phase, 8 | I'll point out the main edits here; if you are familiar with how OpenCLIP works, they should be easy to read and 9 | modify. 10 | 11 | **Dataset** 12 | 13 | The dataset now requires loading hard captions (provided as a list) and hard image negatives. Hard captions and hard images 14 | are chosen at random at each epoch. 15 | 16 | ```python 17 | df = pd.read_csv(input_filename, sep=sep, converters={"neg_caption":ast.literal_eval, "neg_image":ast.literal_eval}) 18 | 19 | self.images = df[img_key].tolist() 20 | self.captions = df[caption_key].tolist() 21 | self.hard_captions = df[hard_captions_key].tolist() 22 | self.hard_images = df["neg_image"].tolist() 23 | self.transforms = transforms 24 | 25 | [...] 26 | 27 | # example of random selection of a hard caption 28 | chosen_caption = random.choice(self.hard_captions[idx]) 29 | hard_captions = tokenize([str(chosen_caption)])[0] 30 | ``` 31 | 32 | **Forward Pass** 33 | 34 | To reduce the number of edits we need to apply to the contrastive loss, we concatenate negative images and negative 35 | captions together. Once this is done, we let the model run the forward pass on this data. 36 | 37 | ```python 38 | images = torch.cat([images, hard_images]) # we concatenate images and hard images 39 | 40 | texts = torch.cat([texts, texts_hard_images]) # we concatenate texts with the text of the hard images 41 | texts = torch.cat([texts, hard_captions]) # we concatenate texts with the hard captions 42 | texts = torch.cat([texts, hard_captions_of_hard_images]) # we concatenate texts with the hard caption of the hard images 43 | 44 | # Note: this operation leaves us with tensors of different sizes. We will have 2x as many texts as images (because of the hard negatives). 45 | # This will require us to fix how we compute the loss (see next section).
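# (Added clarification, not part of the original snippet.) With a per-GPU batch size B, the tensors above end up as:
#   images: 2B items -> B original images + B hard image negatives
#   texts:  4B items -> B captions + B captions of the hard images + B hard captions + B hard captions of the hard images
# so the similarity matrix built by the loss is (2B) x (4B), e.g. 512 x 1024 for B = 256.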
46 | 47 | with autocast(): 48 | image_features, text_features, logit_scale = model(images, texts) 49 | ``` 50 | 51 | **Loss** 52 | 53 | Finally, we have the loss. In the **Forward Pass** section we built text and image batches of different lengths. 54 | Starting from a batch size of 256, you end up with a contrastive matrix of 512x1024 (for the image part we have 55 | 256 images + 256 hard images, for the text part we have 256 captions + 256 captions from the hard images + 256 hard captions + 256 hard captions from the hard images). 56 | 57 | So we need to change the loss a bit to avoid computing it on the wrong items (see the paper). 58 | 59 | ```python 60 | total_loss = ( 61 | F.cross_entropy(logits_per_image, labels) + 62 | F.cross_entropy(logits_per_text[:len(logits_per_image)], labels) 63 | ) / 2 64 | ``` 65 | 66 | 67 | # Citation 68 | If you use this code or data, please consider citing our paper: 69 | 70 | ``` 71 | @inproceedings{ 72 | yuksekgonul2023when, 73 | title={When and why Vision-Language Models behave like Bags-of-Words, and what to do about it?}, 74 | author={Mert Yuksekgonul and Federico Bianchi and Pratyusha Kalluri and Dan Jurafsky and James Zou}, 75 | booktitle={International Conference on Learning Representations}, 76 | year={2023}, 77 | url={https://openreview.net/forum?id=KRLUvxh8uaX} 78 | } 79 | ``` 80 | 81 | What follows from here is the original OpenCLIP readme. 82 | 83 | # Original OpenCLIP 84 | 85 | [[Paper]](https://arxiv.org/abs/2109.01903) [[Colab]](https://colab.research.google.com/github/mlfoundations/open_clip/blob/master/docs/Interacting_with_open_clip.ipynb) 86 | 87 | Welcome to an open source implementation of OpenAI's [CLIP](https://arxiv.org/abs/2103.00020) (Contrastive Language-Image Pre-training). 88 | 89 | The goal of this repository is to enable training models with contrastive image-text supervision, and to investigate their properties such as robustness to distribution shift. Our starting point is an implementation of CLIP that matches the accuracy of the original CLIP models when trained on the same dataset. 90 | Specifically, a ResNet-50 model trained with our codebase on OpenAI's [15 million image subset of YFCC](https://github.com/openai/CLIP/blob/main/data/yfcc100m.md) achieves **32.7%** top-1 accuracy on ImageNet. OpenAI's CLIP model reaches **31.3%** when trained on the same subset of YFCC. For ease of experimentation, we also provide code for training on the 3 million images in the [Conceptual Captions](https://ai.google.com/research/ConceptualCaptions/download) dataset, where a ResNet-50x4 trained with our codebase reaches 22.2% top-1 ImageNet accuracy. 91 | 92 | We further this with a replication study on a dataset of comparable size to OpenAI's.
Using [LAION-400M](https://arxiv.org/abs/2111.02114), we train CLIP with a 93 | * ViT-B/32 and achieve an accuracy of **62.9%**, comparable to OpenAI's **63.2%**, zero-shot top-1 on ImageNet1k 94 | * ViT-B/16 and achieve an accuracy of **67.1%**, comparable to OpenAI's **68.3%** (as measured here, 68.6% in paper) 95 | * ViT-B/16+ 240x240 (~50% more FLOPS than B/16 224x224) and achieve an accuracy of **69.2%** 96 | * ViT-L/14 and achieve an accuracy of **72.77%**, vs OpenAI's **75.5%** (as measured here, 75.3% in paper) 97 | 98 | As we describe in more detail [below](#why-are-low-accuracy-clip-models-interesting), CLIP models in a medium accuracy regime already allow us to draw conclusions about the robustness of larger CLIP models since the models follow [reliable scaling laws](https://arxiv.org/abs/2107.04649). 99 | 100 | This codebase is work in progress, and we invite all to contribute in making it more acessible and useful. In the future, we plan to add support for TPU training and release larger models. We hope this codebase facilitates and promotes further research in contrastive image-text learning. Please submit an issue or send an email if you have any other requests or suggestions. 101 | 102 | Note that portions of `src/open_clip/` modelling and tokenizer code are adaptations of OpenAI's official [repository](https://github.com/openai/CLIP). 103 | 104 | ## Approach 105 | 106 | | ![CLIP](https://raw.githubusercontent.com/mlfoundations/open_clip/main/docs/CLIP.png) | 107 | |:--:| 108 | | Image Credit: https://github.com/openai/CLIP | 109 | 110 | ## Usage 111 | 112 | ``` 113 | pip install open_clip_torch 114 | ``` 115 | 116 | ```python 117 | import torch 118 | from PIL import Image 119 | import open_clip 120 | 121 | model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32-quickgelu', pretrained='laion400m_e32') 122 | 123 | image = preprocess(Image.open("CLIP.png")).unsqueeze(0) 124 | text = open_clip.tokenize(["a diagram", "a dog", "a cat"]) 125 | 126 | with torch.no_grad(): 127 | image_features = model.encode_image(image) 128 | text_features = model.encode_text(text) 129 | image_features /= image_features.norm(dim=-1, keepdim=True) 130 | text_features /= text_features.norm(dim=-1, keepdim=True) 131 | 132 | text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) 133 | 134 | print("Label probs:", text_probs) # prints: [[1., 0., 0.]] 135 | ``` 136 | 137 | To compute billions of embeddings efficiently, you can use [clip-retrieval](https://github.com/rom1504/clip-retrieval) which has openclip support. 138 | 139 | ## Fine-tuning on classification tasks 140 | 141 | This repository is focused on training CLIP models. To fine-tune a *trained* zero-shot model on a downstream classification task such as ImageNet, please see [our other repository: WiSE-FT](https://github.com/mlfoundations/wise-ft). The [WiSE-FT repository](https://github.com/mlfoundations/wise-ft) contains code for our paper on [Robust Fine-tuning of Zero-shot Models](https://arxiv.org/abs/2109.01903), in which we introduce a technique for fine-tuning zero-shot models while preserving robustness under distribution shift. 142 | 143 | ## Data 144 | 145 | 146 | ### Conceptual Captions 147 | 148 | OpenCLIP reads a CSV file with two columns: a path to an image, and a text caption. The names of the columns are passed as an argument to `main.py`. 149 | 150 | The script `src/data/gather_cc.py` will collect the Conceptual Captions images. 
First, download the [Conceptual Captions URLs](https://ai.google.com/research/ConceptualCaptions/download) and then run the script from our repository: 151 | 152 | ```bash 153 | python3 src/data/gather_cc.py path/to/Train_GCC-training.tsv path/to/Validation_GCC-1.1.0-Validation.tsv 154 | ``` 155 | 156 | Our training set contains 2.89M images, and our validation set contains 13K images. 157 | 158 | 159 | ### YFCC and other datasets 160 | 161 | In addition to specifying the training data via CSV files as mentioned above, our codebase also supports [webdataset](https://github.com/webdataset/webdataset), which is recommended for larger scale datasets. The expected format is a series of `.tar` files. Each of these `.tar` files should contain two files for each training example, one for the image and one for the corresponding text. Both files should have the same name but different extensions. For instance, `shard_001.tar` could contain files such as `abc.jpg` and `abc.txt`. You can learn more about `webdataset` at [https://github.com/webdataset/webdataset](https://github.com/webdataset/webdataset). We use `.tar` files with 1,000 data points each, which we create using [tarp](https://github.com/webdataset/tarp). 162 | 163 | You can download the YFCC dataset from [Multimedia Commons](http://mmcommons.org/). 164 | Similar to OpenAI, we used a subset of YFCC to reach the aforementioned accuracy numbers. 165 | The indices of images in this subset are in [OpenAI's CLIP repository](https://github.com/openai/CLIP/blob/main/data/yfcc100m.md). 166 | 167 | 168 | ## Training CLIP 169 | 170 | ### Setup Environment and Install dependencies 171 | 172 | #### Conda 173 | 174 | ```bash 175 | # Create a conda environment (heavily recommended) 176 | conda create -n open_clip python=3.10 177 | conda activate open_clip 178 | ``` 179 | 180 | Install PyTorch with conda as per https://pytorch.org/get-started/locally/ 181 | 182 | #### Virtualenv 183 | 184 | OpenCLIP can also be used with virtualenv with these lines: 185 | ``` 186 | python3 -m venv .env 187 | source .env/bin/activate 188 | pip install -U pip 189 | make install 190 | ``` 191 | 192 | Install PyTorch with pip as per https://pytorch.org/get-started/locally/ 193 | 194 | Tests can be run with `make install-dev` then `make test` 195 | 196 | #### Other dependencies 197 | 198 | Install the open_clip package and remaining dependencies: 199 | 200 | ```bash 201 | cd open_clip 202 | python setup.py install 203 | ``` 204 | 205 | If you want to train models, you will also need to install the packages 206 | from `requirements-training.txt`. 207 | 208 | ### Sample single-process running code: 209 | 210 | ```bash 211 | python -m training.main \ 212 | --save-frequency 1 \ 213 | --zeroshot-frequency 1 \ 214 | --report-to tensorboard \ 215 | --train-data="/path/to/train_data.csv" \ 216 | --val-data="/path/to/validation_data.csv" \ 217 | --csv-img-key filepath \ 218 | --csv-caption-key title \ 219 | --imagenet-val=/path/to/imagenet/root/val/ \ 220 | --warmup 10000 \ 221 | --batch-size=128 \ 222 | --lr=1e-3 \ 223 | --wd=0.1 \ 224 | --epochs=30 \ 225 | --workers=8 \ 226 | --model RN50 227 | ``` 228 | 229 | Note: `imagenet-val` is the path to the *validation* set of ImageNet for zero-shot evaluation, not the training set! 230 | You can remove this argument if you do not want to perform zero-shot evaluation on ImageNet throughout training. Note that the `val` folder should contain subfolders.
If it does not, please use [this script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh). 231 | 232 | ### Multi-GPU and Beyond 233 | 234 | This code has been battle tested up to 1024 A100s and offers a variety of solutions 235 | for distributed training. We include native support for SLURM clusters. 236 | 237 | As the number of devices used to train increases, so does the space complexity of 238 | the logit matrix. Using a naïve all-gather scheme, space complexity will be 239 | `O(n^2)`. Instead, complexity may become effectively linear if the flags 240 | `--gather-with-grad` and `--local-loss` are used. This alteration produces numerical results 241 | identical to the naïve method. 242 | 243 | #### Single-Node 244 | 245 | We make use of `torchrun` to launch distributed jobs. The following launches 246 | a job on a node of 4 GPUs: 247 | 248 | ```bash 249 | cd open_clip/src 250 | torchrun --nproc_per_node 4 -m training.main \ 251 | --train-data '/data/cc12m/cc12m-train-{0000..2175}.tar' \ 252 | --train-num-samples 10968539 \ 253 | --dataset-type webdataset \ 254 | --batch-size 320 \ 255 | --precision amp \ 256 | --workers 4 \ 257 | --imagenet-val /data/imagenet/validation/ 258 | ``` 259 | 260 | #### Multi-Node 261 | 262 | The same script above works, so long as users include information about the number 263 | of nodes and host node. 264 | 265 | ```bash 266 | cd open_clip/src 267 | torchrun --nproc_per_node=4 \ 268 | --rdzv_endpoint=$HOST_NODE_ADDR \ 269 | -m training.main \ 270 | --train-data '/data/cc12m/cc12m-train-{0000..2175}.tar' \ 271 | --train-num-samples 10968539 \ 272 | --dataset-type webdataset \ 273 | --batch-size 320 \ 274 | --precision amp \ 275 | --workers 4 \ 276 | --imagenet-val /data/imagenet/validation/ 277 | ``` 278 | 279 | #### SLURM 280 | 281 | This is likely the easiest solution to utilize.
The following script was used to 282 | train our largest models: 283 | 284 | ```bash 285 | #!/bin/bash -x 286 | #SBATCH --nodes=32 287 | #SBATCH --gres=gpu:4 288 | #SBATCH --ntasks-per-node=4 289 | #SBATCH --cpus-per-task=6 290 | #SBATCH --wait-all-nodes=1 291 | #SBATCH --job-name=open_clip 292 | #SBATCH --account=ACCOUNT_NAME 293 | #SBATCH --partition PARTITION_NAME 294 | 295 | eval "$(/path/to/conda/bin/conda shell.bash hook)" # init conda 296 | conda activate open_clip 297 | export CUDA_VISIBLE_DEVICES=0,1,2,3 298 | export MASTER_PORT=12802 299 | 300 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 301 | export MASTER_ADDR=$master_addr 302 | 303 | cd /shared/open_clip 304 | export PYTHONPATH="$PYTHONPATH:$PWD/src" 305 | srun --cpu_bind=v --accel-bind=gn python -u src/training/main.py \ 306 | --save-frequency 1 \ 307 | --report-to tensorboard \ 308 | --train-data="/data/LAION-400M/{00000..41455}.tar" \ 309 | --warmup 2000 \ 310 | --batch-size=256 \ 311 | --epochs=32 \ 312 | --workers=8 \ 313 | --model ViT-B-32 \ 314 | --name "ViT-B-32-Vanilla" \ 315 | --seed 0 \ 316 | --local-loss \ 317 | --gather-with-grad 318 | ``` 319 | 320 | ### Resuming from a checkpoint: 321 | 322 | ```bash 323 | python -m training.main \ 324 | --train-data="/path/to/train_data.csv" \ 325 | --val-data="/path/to/validation_data.csv" \ 326 | --resume /path/to/checkpoints/epoch_K.pt 327 | ``` 328 | 329 | ### Loss Curves 330 | 331 | When run on a machine with 8 GPUs the command should produce the following training curve for Conceptual Captions: 332 | 333 | ![CLIP zero shot training curve](https://raw.githubusercontent.com/mlfoundations/open_clip/main/docs/clip_zeroshot.png) 334 | 335 | More detailed curves for Conceptual Captions are given at [/docs/clip_conceptual_captions.md](/docs/clip_conceptual_captions.md). 336 | 337 | When training a RN50 on YFCC the same hyperparameters as above are used, with the exception of `lr=5e-4` and `epochs=32`. 338 | 339 | Note that to use another model, like `ViT-B/32` or `RN50x4` or `RN50x16` or `ViT-B/16`, specify with `--model RN50x4`. 340 | 341 | ### Launch tensorboard: 342 | ```bash 343 | tensorboard --logdir=logs/tensorboard/ --port=7777 344 | ``` 345 | 346 | ## Evaluation / Zero-Shot 347 | 348 | ### Evaluating local checkpoint: 349 | 350 | ```bash 351 | python -m training.main \ 352 | --val-data="/path/to/validation_data.csv" \ 353 | --model RN101 \ 354 | --pretrained /path/to/checkpoints/epoch_K.pt 355 | ``` 356 | 357 | ### Evaluating hosted pretrained checkpoint on ImageNet zero-shot prediction: 358 | 359 | ```bash 360 | python -m training.main \ 361 | --imagenet-val /path/to/imagenet/validation \ 362 | --model ViT-B-32-quickgelu \ 363 | --pretrained laion400m_e32 364 | ``` 365 | 366 | ## Pretrained model details 367 | 368 | ### LAION-400M - https://laion.ai/laion-400-open-dataset 369 | 370 | We are working on reproducing OpenAI's ViT results with the comparably sized (and open) LAION-400M dataset. Trained 371 | weights may be found in release [v0.2](https://github.com/mlfoundations/open_clip/releases/tag/v0.2-weights). 372 | 373 | The LAION400M weights have been trained on the JUWELS supercomputer (see acknowledgements section below). 374 | 375 | #### ViT-B/32 224x224 376 | 377 | We replicate OpenAI's results on ViT-B/32, reaching a top-1 ImageNet-1k zero-shot accuracy of 62.96%. 
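For reference, the ImageNet zero-shot numbers in this section are computed by building a text classifier from class names and prompt templates. The following is a minimal, illustrative sketch of that procedure using the `open_clip` API; the class names and templates below are placeholders (the full ImageNet-1k lists used by the training code live in `src/training/imagenet_zeroshot_data.py`), and the model/tag follow the evaluation example above:

```python
import torch
import open_clip

model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32-quickgelu', pretrained='laion400m_e32')
model.eval()

# placeholder subset; the real evaluation uses all 1000 ImageNet classes and the full prompt-template list
classnames = ['goldfish', 'tabby cat', 'golden retriever']
templates = ['a photo of a {}.', 'a low resolution photo of a {}.']

with torch.no_grad():
    weights = []
    for name in classnames:
        texts = open_clip.tokenize([t.format(name) for t in templates])
        emb = model.encode_text(texts)               # (num_templates, embed_dim)
        emb = emb / emb.norm(dim=-1, keepdim=True)
        emb = emb.mean(dim=0)                        # average over templates
        weights.append(emb / emb.norm())
    classifier = torch.stack(weights, dim=1)         # (embed_dim, num_classes)

# for a batch of preprocessed images, class logits are then:
#   image_features = model.encode_image(images)
#   image_features = image_features / image_features.norm(dim=-1, keepdim=True)
#   logits = 100. * image_features @ classifier
```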
378 | 379 | 380 | 381 | __Zero-shot comparison (courtesy of Andreas Fürst)__ 382 | 383 | 384 | ViT-B/32 was trained with 128 A100 (40 GB) GPUs for ~36 hours, 4600 GPU-hours. The per-GPU batch size was 256 for a global batch size of 32768. 256 is much lower than it could have been (~320-384) due to being sized initially before moving to 'local' contrastive loss. 385 | 386 | #### ViT-B/16 224x224 387 | 388 | The B/16 LAION400M training reached a top-1 ImageNet-1k zero-shot validation score of 67.07. 389 | 390 | 391 | 392 | This was the first major train session using the updated webdataset 0.2.x code. A bug was found that prevented shards from being shuffled properly between nodes/workers each epoch. This was fixed part way through training (epoch 26) but likely had an impact. 393 | 394 | ViT-B/16 was trained with 176 A100 (40 GB) GPUS for ~61 hours, 10700 GPU-hours. Batch size per GPU was 192 for a global batch size of 33792. 395 | 396 | #### ViT-B/16+ 240x240 397 | 398 | The B/16+ 240x240 LAION400M training reached a top-1 ImageNet-1k zero-shot validation score of 69.21. 399 | 400 | This model is the same depth as the B/16, but increases the 401 | * vision width from 768 -> 896 402 | * text width from 512 -> 640 403 | * the resolution 224x224 -> 240x240 (196 -> 225 tokens) 404 | 405 | 406 | 407 | Unlike the B/16 run above, this model was a clean run with no dataset shuffling issues. 408 | 409 | ViT-B/16+ was trained with 224 A100 (40 GB) GPUS for ~61 hours, 13620 GPU-hours. Batch size per GPU was 160 for a global batch size of 35840. 410 | 411 | #### ViT-L/14 224x224 412 | 413 | The L/14 LAION-400M training reached a top-1 ImageNet-1k zero-shot validation score of 72.77. 414 | 415 | 416 | 417 | ViT-L/14 was trained with 400 A100 (40 GB) GPUS for ~127 hours, 50800 GPU-hours. Batch size per GPU was 96 for a global batch size of 38400. Grad checkpointing was enabled. 418 | 419 | ### LAION-2B (en) - https://laion.ai/laion-5b-a-new-era-of-open-large-scale-multi-modal-datasets/ 420 | 421 | A ~2B sample subset of LAION-5B with english captions (https://huggingface.co/datasets/laion/laion2B-en) 422 | 423 | #### ViT-B/32 224x224 424 | A ViT-B/32 trained on LAION-2B, reaching a top-1 ImageNet-1k zero-shot accuracy of 65.62%. 425 | 426 | 427 | 428 | ViT-B/32 was trained with 112 A100 (40 GB) GPUs. The per-GPU batch size was 416 for a global batch size of 46592. Compute generously provided by [stability.ai](https://stability.ai/). 429 | 430 | #### YFCC-15M 431 | 432 | Below are checkpoints of models trained on YFCC-15M, along with their zero-shot top-1 accuracies on ImageNet and ImageNetV2. These models were trained using 8 GPUs and the same hyperparameters described in the "Sample running code" section, with the exception of `lr=5e-4` and `epochs=32`. 433 | 434 | * [ResNet-50](https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt) (32.7% / 27.9%) 435 | * [ResNet-101](https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt) (34.8% / 30.0%) 436 | 437 | #### CC12M - https://github.com/google-research-datasets/conceptual-12m 438 | 439 | * [ResNet-50](https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt) (36.45%) 440 | 441 | ### Pretrained Model Interface 442 | 443 | We offer a simple model interface to instantiate both pre-trained and untrained models. 
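As a quick illustration (a minimal sketch, not part of the original README), an untrained model with randomly initialized weights is obtained by simply omitting the `pretrained` tag; the factory functions used below are the ones exported from `src/open_clip/__init__.py`:

```python
import open_clip

# architectures are enumerated from the JSON files in src/open_clip/model_configs/
print(open_clip.list_models())  # e.g. ['RN50', 'RN101', ..., 'ViT-B-32', ...]

# no `pretrained` tag -> randomly initialized weights for the chosen architecture
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('ViT-B-32')
```

The pre-trained counterparts are listed and loaded as shown in the snippet below.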
444 | 445 | NOTE: Many existing checkpoints use the QuickGELU activation from the original OpenAI models. This activation is actually less efficient than native torch.nn.GELU in recent versions of PyTorch. The model defaults are now nn.GELU, so one should use model definitions with `-quickgelu` postfix for the OpenCLIP pretrained weights. All OpenAI pretrained weights will always default to QuickGELU. One can also use the non `-quickgelu` model definitions with pretrained weights that used QuickGELU, but there will be an accuracy drop; for fine-tuning, that drop will likely vanish over longer runs. 446 | 447 | Future trained models will use nn.GELU. 448 | 449 | ```python 450 | >>> import open_clip 451 | >>> open_clip.list_pretrained() 452 | [('RN50', 'openai'), 453 | ('RN50', 'yfcc15m'), 454 | ('RN50', 'cc12m'), 455 | ('RN50-quickgelu', 'openai'), 456 | ('RN50-quickgelu', 'yfcc15m'), 457 | ('RN50-quickgelu', 'cc12m'), 458 | ('RN101', 'openai'), 459 | ('RN101', 'yfcc15m'), 460 | ('RN101-quickgelu', 'openai'), 461 | ('RN101-quickgelu', 'yfcc15m'), 462 | ('RN50x4', 'openai'), 463 | ('RN50x16', 'openai'), 464 | ('RN50x64', 'openai'), 465 | ('ViT-B-32', 'openai'), 466 | ('ViT-B-32', 'laion2b_e16'), 467 | ('ViT-B-32', 'laion400m_e31'), 468 | ('ViT-B-32', 'laion400m_e32'), 469 | ('ViT-B-32-quickgelu', 'openai'), 470 | ('ViT-B-32-quickgelu', 'laion400m_e31'), 471 | ('ViT-B-32-quickgelu', 'laion400m_e32'), 472 | ('ViT-B-16', 'openai'), 473 | ('ViT-B-16', 'laion400m_e31'), 474 | ('ViT-B-16', 'laion400m_e32'), 475 | ('ViT-B-16-plus-240', 'laion400m_e31'), 476 | ('ViT-B-16-plus-240', 'laion400m_e32'), 477 | ('ViT-L-14', 'openai'), 478 | ('ViT-L-14-336', 'openai')] 479 | 480 | >>> model, train_transform, eval_transform = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_e16') 481 | ``` 482 | 483 | ## Scaling trends 484 | 485 | The plot below shows how zero-shot performance of CLIP models varies as we scale the number of samples used for training. Zero-shot performance increases steadily for both ImageNet and [ImageNetV2](https://arxiv.org/abs/1902.10811), and is far from saturated at ~15M samples. 486 | 487 | 488 | 489 | ## Why are low-accuracy CLIP models interesting? 490 | 491 | **TL;DR:** CLIP models have high effective robustness, even at small scales. 492 | 493 | CLIP models are particularly intriguing because they are more robust to natural distribution shifts (see Section 3.3 in the [CLIP paper](https://arxiv.org/abs/2103.00020)). 494 | This phenomenon is illustrated by the figure below, with ImageNet accuracy on the x-axis 495 | and [ImageNetV2](https://arxiv.org/abs/1902.10811) (a reproduction of the ImageNet validation set with distribution shift) accuracy on the y-axis. 496 | Standard training denotes training on the ImageNet train set and the CLIP zero-shot models 497 | are shown as stars. 498 | 499 | ![CLIP scatter plot](https://raw.githubusercontent.com/mlfoundations/open_clip/main/docs/effective_robustness.png) 500 | 501 | As observed by [Taori et al., 2020](https://arxiv.org/abs/2007.00644) and [Miller et al., 2021](https://arxiv.org/abs/2107.04649), the in-distribution 502 | and out-of-distribution accuracies of models trained on ImageNet follow a predictable linear trend (the red line in the above plot). *Effective robustness* 503 | quantifies robustness as accuracy beyond this baseline, i.e., how far a model lies above the red line.
Ideally a model would not suffer from distribution shift and fall on the y = x line ([trained human labelers are within a percentage point of the y = x line](http://proceedings.mlr.press/v119/shankar20c.html)). 504 | 505 | Even though the CLIP models trained with 506 | this codebase achieve much lower accuracy than those trained by OpenAI, our models still lie on the same 507 | trend of improved effective robustness (the purple line). Therefore, we can study what makes 508 | CLIP robust without requiring industrial-scale compute. 509 | 510 | For more information on effective robustness, please see: 511 | 512 | - [Recht et al., 2019](https://arxiv.org/abs/1902.10811). 513 | - [Taori et al., 2020](https://arxiv.org/abs/2007.00644). 514 | - [Miller et al., 2021](https://arxiv.org/abs/2107.04649). 515 | 516 | To know more about the factors that contribute to CLIP's robustness refer to [Fang et al., 2022](https://arxiv.org/abs/2205.01397). 517 | 518 | ## Acknowledgments 519 | 520 | We gratefully acknowledge the Gauss Centre for Supercomputing e.V. (www.gauss-centre.eu) for funding this part of work by providing computing time through the John von Neumann Institute for Computing (NIC) on the GCS Supercomputer JUWELS Booster at Jülich Supercomputing Centre (JSC). 521 | 522 | ## The Team 523 | 524 | Current development of this repository is led by [Ross Wightman](https://rwightman.com/), [Cade Gordon](http://cadegordon.io/), and [Vaishaal Shankar](http://vaishaal.com/). 525 | 526 | The original version of this repository is from a group of researchers at UW, Google, Stanford, Amazon, Columbia, and Berkeley. 527 | 528 | [Gabriel Ilharco*](http://gabrielilharco.com/), [Mitchell Wortsman*](https://mitchellnw.github.io/), [Nicholas Carlini](https://nicholas.carlini.com/), [Rohan Taori](https://www.rohantaori.com/), [Achal Dave](http://www.achaldave.com/), [Vaishaal Shankar](http://vaishaal.com/), [John Miller](https://people.eecs.berkeley.edu/~miller_john/), [Hongseok Namkoong](https://hsnamkoong.github.io/), [Hannaneh Hajishirzi](https://homes.cs.washington.edu/~hannaneh/), [Ali Farhadi](https://homes.cs.washington.edu/~ali/), [Ludwig Schmidt](https://people.csail.mit.edu/ludwigs/) 529 | 530 | Special thanks to [Jong Wook Kim](https://jongwook.kim/) and [Alec Radford](https://github.com/Newmu) for help with reproducing CLIP! 531 | 532 | ## Citing 533 | 534 | If you found this repository useful, please consider citing: 535 | ```bibtex 536 | @software{ilharco_gabriel_2021_5143773, 537 | author = {Ilharco, Gabriel and 538 | Wortsman, Mitchell and 539 | Wightman, Ross and 540 | Gordon, Cade and 541 | Carlini, Nicholas and 542 | Taori, Rohan and 543 | Dave, Achal and 544 | Shankar, Vaishaal and 545 | Namkoong, Hongseok and 546 | Miller, John and 547 | Hajishirzi, Hannaneh and 548 | Farhadi, Ali and 549 | Schmidt, Ludwig}, 550 | title = {OpenCLIP}, 551 | month = jul, 552 | year = 2021, 553 | note = {If you use this software, please cite it as below.}, 554 | publisher = {Zenodo}, 555 | version = {0.1}, 556 | doi = {10.5281/zenodo.5143773}, 557 | url = {https://doi.org/10.5281/zenodo.5143773} 558 | } 559 | ``` 560 | 561 | ```bibtex 562 | @inproceedings{Radford2021LearningTV, 563 | title={Learning Transferable Visual Models From Natural Language Supervision}, 564 | author={Alec Radford and Jong Wook Kim and Chris Hallacy and A. 
Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever}, 565 | booktitle={ICML}, 566 | year={2021} 567 | } 568 | ``` 569 | 570 | [![DOI](https://zenodo.org/badge/390536799.svg)](https://zenodo.org/badge/latestdoi/390536799) 571 | -------------------------------------------------------------------------------- /docs/CLIP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vinid/neg_clip/e680f8e44ab9f081f62094232d8ee3b77f8574f7/docs/CLIP.png -------------------------------------------------------------------------------- /docs/clip_conceptual_captions.md: -------------------------------------------------------------------------------- 1 | ## Additional training curves for CLIP on Conceptual Captions 2 | 3 | # Zero shot accuracy 4 | ![](/docs/clip_zeroshot.png) 5 | 6 | # Training loss curve 7 | ![](/docs/clip_loss.png) 8 | 9 | # Validation loss curve 10 | ![](/docs/clip_val_loss.png) 11 | 12 | # Validation recall 13 | ![](/docs/clip_recall.png) -------------------------------------------------------------------------------- /docs/clip_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vinid/neg_clip/e680f8e44ab9f081f62094232d8ee3b77f8574f7/docs/clip_loss.png -------------------------------------------------------------------------------- /docs/clip_recall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vinid/neg_clip/e680f8e44ab9f081f62094232d8ee3b77f8574f7/docs/clip_recall.png -------------------------------------------------------------------------------- /docs/clip_val_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vinid/neg_clip/e680f8e44ab9f081f62094232d8ee3b77f8574f7/docs/clip_val_loss.png -------------------------------------------------------------------------------- /docs/clip_zeroshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vinid/neg_clip/e680f8e44ab9f081f62094232d8ee3b77f8574f7/docs/clip_zeroshot.png -------------------------------------------------------------------------------- /docs/effective_robustness.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vinid/neg_clip/e680f8e44ab9f081f62094232d8ee3b77f8574f7/docs/effective_robustness.png -------------------------------------------------------------------------------- /docs/laion2b_clip_zeroshot_b32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vinid/neg_clip/e680f8e44ab9f081f62094232d8ee3b77f8574f7/docs/laion2b_clip_zeroshot_b32.png -------------------------------------------------------------------------------- /docs/laion_clip_zeroshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vinid/neg_clip/e680f8e44ab9f081f62094232d8ee3b77f8574f7/docs/laion_clip_zeroshot.png -------------------------------------------------------------------------------- /docs/laion_clip_zeroshot_b16.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vinid/neg_clip/e680f8e44ab9f081f62094232d8ee3b77f8574f7/docs/laion_clip_zeroshot_b16.png -------------------------------------------------------------------------------- /docs/laion_clip_zeroshot_b16_plus_240.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vinid/neg_clip/e680f8e44ab9f081f62094232d8ee3b77f8574f7/docs/laion_clip_zeroshot_b16_plus_240.png -------------------------------------------------------------------------------- /docs/laion_clip_zeroshot_l14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vinid/neg_clip/e680f8e44ab9f081f62094232d8ee3b77f8574f7/docs/laion_clip_zeroshot_l14.png -------------------------------------------------------------------------------- /docs/laion_openai_compare_b32.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vinid/neg_clip/e680f8e44ab9f081f62094232d8ee3b77f8574f7/docs/laion_openai_compare_b32.jpg -------------------------------------------------------------------------------- /docs/scaling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vinid/neg_clip/e680f8e44ab9f081f62094232d8ee3b77f8574f7/docs/scaling.png -------------------------------------------------------------------------------- /requirements-test.txt: -------------------------------------------------------------------------------- 1 | pytest-xdist==2.5.0 2 | pytest==7.0.1 -------------------------------------------------------------------------------- /requirements-training.txt: -------------------------------------------------------------------------------- 1 | torch>=1.9.0 2 | torchvision 3 | webdataset>=0.2.5 4 | regex 5 | ftfy 6 | tqdm 7 | pandas 8 | braceexpand 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.9.0 2 | torchvision 3 | regex 4 | ftfy 5 | tqdm 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ Setup 2 | """ 3 | from setuptools import setup, find_packages 4 | from codecs import open 5 | from os import path 6 | 7 | here = path.abspath(path.dirname(__file__)) 8 | 9 | # Get the long description from the README file 10 | with open(path.join(here, 'README.md'), encoding='utf-8') as f: 11 | long_description = f.read() 12 | 13 | exec(open('src/open_clip/version.py').read()) 14 | setup( 15 | name='open_clip_torch', 16 | version=__version__, 17 | description='OpenCLIP', 18 | long_description=long_description, 19 | long_description_content_type='text/markdown', 20 | url='https://github.com/mlfoundations/open_clip', 21 | author='', 22 | author_email='', 23 | classifiers=[ 24 | # How mature is this project? 
Common values are 25 | # 3 - Alpha 26 | # 4 - Beta 27 | # 5 - Production/Stable 28 | 'Development Status :: 3 - Alpha', 29 | 'Intended Audience :: Education', 30 | 'Intended Audience :: Science/Research', 31 | 'License :: OSI Approved :: Apache Software License', 32 | 'Programming Language :: Python :: 3.7', 33 | 'Programming Language :: Python :: 3.8', 34 | 'Programming Language :: Python :: 3.9', 35 | 'Programming Language :: Python :: 3.10', 36 | 'Topic :: Scientific/Engineering', 37 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 38 | 'Topic :: Software Development', 39 | 'Topic :: Software Development :: Libraries', 40 | 'Topic :: Software Development :: Libraries :: Python Modules', 41 | ], 42 | 43 | # Note that this is a string of words separated by whitespace, not a list. 44 | keywords='CLIP pretrained', 45 | package_dir={'': 'src'}, 46 | packages=find_packages(where='src', exclude=['training']), 47 | include_package_data=True, 48 | install_requires=[ 49 | 'torch >= 1.9', 50 | 'torchvision', 51 | 'ftfy', 52 | 'regex', 53 | 'tqdm', 54 | ], 55 | python_requires='>=3.7', 56 | ) 57 | -------------------------------------------------------------------------------- /src/data/gather_cc.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import os 3 | import multiprocessing as mp 4 | from io import BytesIO 5 | import numpy as np 6 | import PIL 7 | from PIL import Image 8 | import pickle 9 | import sys 10 | 11 | 12 | def grab(line): 13 | """ 14 | Download a single image from the TSV. 15 | """ 16 | uid, split, line = line 17 | try: 18 | caption, url = line.split("\t")[:2] 19 | except: 20 | print("Parse error") 21 | return 22 | 23 | if os.path.exists(ROOT+"/%s/%d/%d.jpg"%(split,uid%1000,uid)): 24 | print("Finished", uid) 25 | return uid, caption, url 26 | 27 | # Let's not crash if anythign weird happens 28 | try: 29 | dat = requests.get(url, timeout=20) 30 | if dat.status_code != 200: 31 | print("404 file", url) 32 | return 33 | 34 | # Try to parse this as an Image file, we'll fail out if not 35 | im = Image.open(BytesIO(dat.content)) 36 | im.thumbnail((512, 512), PIL.Image.BICUBIC) 37 | if min(*im.size) < max(*im.size)/3: 38 | print("Too small", url) 39 | return 40 | 41 | im.save(ROOT+"/%s/%d/%d.jpg"%(split,uid%1000,uid)) 42 | 43 | # Another try/catch just because sometimes saving and re-loading 44 | # the image is different than loading it once. 
45 | try: 46 | o = Image.open(ROOT+"/%s/%d/%d.jpg"%(split,uid%1000,uid)) 47 | o = np.array(o) 48 | 49 | print("Success", o.shape, uid, url) 50 | return uid, caption, url 51 | except: 52 | print("Failed", uid, url) 53 | 54 | except Exception as e: 55 | print("Unknown error", e) 56 | pass 57 | 58 | if __name__ == "__main__": 59 | ROOT = "cc_data" 60 | 61 | if not os.path.exists(ROOT): 62 | os.mkdir(ROOT) 63 | os.mkdir(os.path.join(ROOT,"train")) 64 | os.mkdir(os.path.join(ROOT,"val")) 65 | for i in range(1000): 66 | os.mkdir(os.path.join(ROOT,"train", str(i))) 67 | os.mkdir(os.path.join(ROOT,"val", str(i))) 68 | 69 | 70 | p = mp.Pool(300) 71 | 72 | for tsv in sys.argv[1:]: 73 | print("Processing file", tsv) 74 | assert 'val' in tsv.lower() or 'train' in tsv.lower() 75 | split = 'val' if 'val' in tsv.lower() else 'train' 76 | results = p.map(grab, 77 | [(i,split,x) for i,x in enumerate(open(tsv).read().split("\n"))]) 78 | 79 | out = open(tsv.replace(".tsv","_output.csv"),"w") 80 | out.write("title\tfilepath\n") 81 | 82 | for row in results: 83 | if row is None: continue 84 | id, caption, url = row 85 | fp = os.path.join(ROOT, split, str(id % 1000), str(id) + ".jpg") 86 | if os.path.exists(fp): 87 | out.write("%s\t%s\n"%(caption,fp)) 88 | else: 89 | print("Drop", id) 90 | out.close() 91 | 92 | p.close() 93 | 94 | -------------------------------------------------------------------------------- /src/open_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .factory import list_models, create_model, create_model_and_transforms, add_model_config 2 | from .loss import ClipLoss 3 | from .model import CLIP, CLIPTextCfg, CLIPVisionCfg, convert_weights_to_fp16, trace_model 4 | from .openai import load_openai_model, list_openai_models 5 | from .pretrained import list_pretrained, list_pretrained_tag_models, list_pretrained_model_tags,\ 6 | get_pretrained_url, download_pretrained 7 | from .tokenizer import SimpleTokenizer, tokenize 8 | from .transform import image_transform 9 | -------------------------------------------------------------------------------- /src/open_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vinid/neg_clip/e680f8e44ab9f081f62094232d8ee3b77f8574f7/src/open_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /src/open_clip/factory.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import pathlib 5 | import re 6 | from copy import deepcopy 7 | from pathlib import Path 8 | from typing import Optional, Tuple 9 | 10 | import torch 11 | 12 | from .model import CLIP, convert_weights_to_fp16, resize_pos_embed 13 | from .openai import load_openai_model 14 | from .pretrained import get_pretrained_url, download_pretrained 15 | from .transform import image_transform 16 | 17 | 18 | _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] 19 | _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs 20 | 21 | 22 | def _natural_key(string_): 23 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] 24 | 25 | 26 | def _rescan_model_configs(): 27 | global _MODEL_CONFIGS 28 | 29 | config_ext = ('.json',) 30 | config_files = [] 31 | for config_path in _MODEL_CONFIG_PATHS: 32 | if config_path.is_file() and config_path.suffix in 
config_ext: 33 | config_files.append(config_path) 34 | elif config_path.is_dir(): 35 | for ext in config_ext: 36 | config_files.extend(config_path.glob(f'*{ext}')) 37 | 38 | for cf in config_files: 39 | with open(cf, 'r') as f: 40 | model_cfg = json.load(f) 41 | if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')): 42 | _MODEL_CONFIGS[cf.stem] = model_cfg 43 | 44 | _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))} 45 | 46 | 47 | _rescan_model_configs() # initial populate of model config registry 48 | 49 | 50 | def load_state_dict(checkpoint_path: str, map_location='cpu'): 51 | checkpoint = torch.load(checkpoint_path, map_location=map_location) 52 | if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: 53 | state_dict = checkpoint['state_dict'] 54 | else: 55 | state_dict = checkpoint 56 | if next(iter(state_dict.items()))[0].startswith('module'): 57 | state_dict = {k[7:]: v for k, v in state_dict.items()} 58 | return state_dict 59 | 60 | 61 | def load_checkpoint(model, checkpoint_path, strict=True): 62 | state_dict = load_state_dict(checkpoint_path) 63 | resize_pos_embed(state_dict, model) 64 | incompatible_keys = model.load_state_dict(state_dict, strict=strict) 65 | return incompatible_keys 66 | 67 | 68 | def create_model( 69 | model_name: str, 70 | pretrained: str = '', 71 | precision: str = 'fp32', 72 | device: torch.device = torch.device('cpu'), 73 | jit: bool = False, 74 | force_quick_gelu: bool = False, 75 | pretrained_image: bool = False, 76 | ): 77 | model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names 78 | 79 | if pretrained.lower() == 'openai': 80 | logging.info(f'Loading pretrained {model_name} from OpenAI.') 81 | model = load_openai_model(model_name, device=device, jit=jit) 82 | # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372 83 | if precision == "amp" or precision == "fp32": 84 | model = model.float() 85 | else: 86 | if model_name in _MODEL_CONFIGS: 87 | logging.info(f'Loading {model_name} model config.') 88 | model_cfg = deepcopy(_MODEL_CONFIGS[model_name]) 89 | else: 90 | logging.error(f'Model config for {model_name} not found; available models {list_models()}.') 91 | raise RuntimeError(f'Model config for {model_name} not found.') 92 | 93 | if force_quick_gelu: 94 | # override for use of QuickGELU on non-OpenAI transformer models 95 | model_cfg["quick_gelu"] = True 96 | 97 | if pretrained_image: 98 | if 'timm_model_name' in model_cfg.get('vision_cfg', {}): 99 | # pretrained weight loading for timm models set via vision_cfg 100 | model_cfg['vision_cfg']['timm_model_pretrained'] = True 101 | else: 102 | assert False, 'pretrained image towers currently only supported for timm models' 103 | 104 | model = CLIP(**model_cfg) 105 | 106 | if pretrained: 107 | checkpoint_path = '' 108 | url = get_pretrained_url(model_name, pretrained) 109 | if url: 110 | checkpoint_path = download_pretrained(url) 111 | elif os.path.exists(pretrained): 112 | checkpoint_path = pretrained 113 | 114 | if checkpoint_path: 115 | logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') 116 | load_checkpoint(model, checkpoint_path) 117 | else: 118 | logging.warning(f'Pretrained weights ({pretrained}) not found for model {model_name}.') 119 | raise RuntimeError(f'Pretrained weights ({pretrained}) not found for model {model_name}.') 120 | 121 | model.to(device=device) 122 | if precision == "fp16": 123 | assert device.type != 'cpu' 124 | 
convert_weights_to_fp16(model) 125 | 126 | if jit: 127 | model = torch.jit.script(model) 128 | 129 | return model 130 | 131 | 132 | def create_model_and_transforms( 133 | model_name: str, 134 | pretrained: str = '', 135 | precision: str = 'fp32', 136 | device: torch.device = torch.device('cpu'), 137 | jit: bool = False, 138 | force_quick_gelu: bool = False, 139 | pretrained_image: bool = False, 140 | mean: Optional[Tuple[float, ...]] = None, 141 | std: Optional[Tuple[float, ...]] = None, 142 | ): 143 | model = create_model( 144 | model_name, pretrained, precision, device, jit, 145 | force_quick_gelu=force_quick_gelu, 146 | pretrained_image=pretrained_image) 147 | preprocess_train = image_transform(model.visual.image_size, is_train=True, mean=mean, std=std) 148 | preprocess_val = image_transform(model.visual.image_size, is_train=False, mean=mean, std=std) 149 | return model, preprocess_train, preprocess_val 150 | 151 | 152 | def list_models(): 153 | """ enumerate available model architectures based on config files """ 154 | return list(_MODEL_CONFIGS.keys()) 155 | 156 | 157 | def add_model_config(path): 158 | """ add model config path or file and update registry """ 159 | if not isinstance(path, Path): 160 | path = Path(path) 161 | _MODEL_CONFIG_PATHS.append(path) 162 | _rescan_model_configs() 163 | -------------------------------------------------------------------------------- /src/open_clip/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | try: 6 | import torch.distributed.nn 7 | from torch import distributed as dist 8 | has_distributed = True 9 | except ImportError: 10 | has_distributed = False 11 | 12 | try: 13 | import horovod.torch as hvd 14 | except ImportError: 15 | hvd = None 16 | 17 | 18 | def gather_features( 19 | image_features, 20 | text_features, 21 | local_loss=False, 22 | gather_with_grad=False, 23 | rank=0, 24 | world_size=1, 25 | use_horovod=False 26 | ): 27 | assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.' 
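    # Added note: the branches below all-gather image/text features from every rank so each process can
    # form the full similarity matrix. When gather_with_grad is False the gathered copies carry no gradient,
    # so (unless local_loss is used) the local rank's own tensors are re-inserted to keep gradients flowing
    # for its shard of the batch.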
28 | if use_horovod: 29 | assert hvd is not None, 'Please install horovod' 30 | if gather_with_grad: 31 | all_image_features = hvd.allgather(image_features) 32 | all_text_features = hvd.allgather(text_features) 33 | else: 34 | with torch.no_grad(): 35 | all_image_features = hvd.allgather(image_features) 36 | all_text_features = hvd.allgather(text_features) 37 | if not local_loss: 38 | # ensure grads for local rank when all_* features don't have a gradient 39 | gathered_image_features = list(all_image_features.chunk(world_size, dim=0)) 40 | gathered_text_features = list(all_text_features.chunk(world_size, dim=0)) 41 | gathered_image_features[rank] = image_features 42 | gathered_text_features[rank] = text_features 43 | all_image_features = torch.cat(gathered_image_features, dim=0) 44 | all_text_features = torch.cat(gathered_text_features, dim=0) 45 | else: 46 | # We gather tensors from all gpus 47 | if gather_with_grad: 48 | all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0) 49 | all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0) 50 | else: 51 | gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)] 52 | gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)] 53 | dist.all_gather(gathered_image_features, image_features) 54 | dist.all_gather(gathered_text_features, text_features) 55 | if not local_loss: 56 | # ensure grads for local rank when all_* features don't have a gradient 57 | gathered_image_features[rank] = image_features 58 | gathered_text_features[rank] = text_features 59 | all_image_features = torch.cat(gathered_image_features, dim=0) 60 | all_text_features = torch.cat(gathered_text_features, dim=0) 61 | 62 | return all_image_features, all_text_features 63 | 64 | 65 | class ClipLoss(nn.Module): 66 | 67 | def __init__( 68 | self, 69 | local_loss=False, 70 | gather_with_grad=False, 71 | cache_labels=False, 72 | rank=0, 73 | world_size=1, 74 | use_horovod=False, 75 | ): 76 | super().__init__() 77 | self.local_loss = local_loss 78 | self.gather_with_grad = gather_with_grad 79 | self.cache_labels = cache_labels 80 | self.rank = rank 81 | self.world_size = world_size 82 | self.use_horovod = use_horovod 83 | 84 | # cache state 85 | self.prev_num_logits = 0 86 | self.labels = {} 87 | 88 | def forward(self, image_features, text_features, logit_scale): 89 | device = image_features.device 90 | if self.world_size > 1: 91 | all_image_features, all_text_features = gather_features( 92 | image_features, text_features, 93 | self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) 94 | 95 | if self.local_loss: 96 | logits_per_image = logit_scale * image_features @ all_text_features.T 97 | logits_per_text = logit_scale * text_features @ all_image_features.T 98 | else: 99 | logits_per_image = logit_scale * all_image_features @ all_text_features.T 100 | logits_per_text = logits_per_image.T 101 | else: 102 | logits_per_image = logit_scale * image_features @ text_features.T 103 | logits_per_text = logit_scale * text_features @ image_features.T 104 | 105 | # calculated ground-truth and cache if enabled 106 | num_logits = logits_per_image.shape[0] 107 | if self.prev_num_logits != num_logits or device not in self.labels: 108 | labels = torch.arange(num_logits, device=device, dtype=torch.long) 109 | if self.world_size > 1 and self.local_loss: 110 | labels = labels + num_logits * self.rank 111 | if self.cache_labels: 112 | self.labels[device] = 
labels 113 | self.prev_num_logits = num_logits 114 | else: 115 | labels = self.labels[device] 116 | 117 | total_loss = ( 118 | F.cross_entropy(logits_per_image, labels) + 119 | F.cross_entropy(logits_per_text[:len(logits_per_image)], labels) 120 | ) / 2 121 | return total_loss 122 | -------------------------------------------------------------------------------- /src/open_clip/model.py: -------------------------------------------------------------------------------- 1 | """ CLIP Model 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | from collections import OrderedDict 6 | from dataclasses import dataclass 7 | import logging 8 | import math 9 | from typing import Tuple, Union, Callable, Optional 10 | 11 | import numpy as np 12 | import torch 13 | import torch.nn.functional as F 14 | from torch import nn 15 | from torch.utils.checkpoint import checkpoint 16 | 17 | from .timm_model import TimmModel 18 | from .utils import freeze_batch_norm_2d, to_2tuple 19 | 20 | 21 | class Bottleneck(nn.Module): 22 | expansion = 4 23 | 24 | def __init__(self, inplanes, planes, stride=1): 25 | super().__init__() 26 | 27 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 28 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 29 | self.bn1 = nn.BatchNorm2d(planes) 30 | self.relu1 = nn.ReLU(inplace=True) 31 | 32 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 33 | self.bn2 = nn.BatchNorm2d(planes) 34 | self.relu2 = nn.ReLU(inplace=True) 35 | 36 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 37 | 38 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 39 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 40 | self.relu3 = nn.ReLU(inplace=True) 41 | 42 | self.downsample = None 43 | self.stride = stride 44 | 45 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 46 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 47 | self.downsample = nn.Sequential(OrderedDict([ 48 | ("-1", nn.AvgPool2d(stride)), 49 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 50 | ("1", nn.BatchNorm2d(planes * self.expansion)) 51 | ])) 52 | 53 | def forward(self, x: torch.Tensor): 54 | identity = x 55 | 56 | out = self.relu1(self.bn1(self.conv1(x))) 57 | out = self.relu2(self.bn2(self.conv2(out))) 58 | out = self.avgpool(out) 59 | out = self.bn3(self.conv3(out)) 60 | 61 | if self.downsample is not None: 62 | identity = self.downsample(x) 63 | 64 | out += identity 65 | out = self.relu3(out) 66 | return out 67 | 68 | 69 | class AttentionPool2d(nn.Module): 70 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 71 | super().__init__() 72 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 73 | self.k_proj = nn.Linear(embed_dim, embed_dim) 74 | self.q_proj = nn.Linear(embed_dim, embed_dim) 75 | self.v_proj = nn.Linear(embed_dim, embed_dim) 76 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 77 | self.num_heads = num_heads 78 | 79 | def forward(self, x): 80 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 81 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 82 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 83 | x, _ = F.multi_head_attention_forward( 84 | query=x, key=x, 
value=x, 85 | embed_dim_to_check=x.shape[-1], 86 | num_heads=self.num_heads, 87 | q_proj_weight=self.q_proj.weight, 88 | k_proj_weight=self.k_proj.weight, 89 | v_proj_weight=self.v_proj.weight, 90 | in_proj_weight=None, 91 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 92 | bias_k=None, 93 | bias_v=None, 94 | add_zero_attn=False, 95 | dropout_p=0, 96 | out_proj_weight=self.c_proj.weight, 97 | out_proj_bias=self.c_proj.bias, 98 | use_separate_proj_weight=True, 99 | training=self.training, 100 | need_weights=False 101 | ) 102 | 103 | return x[0] 104 | 105 | 106 | class ModifiedResNet(nn.Module): 107 | """ 108 | A ResNet class that is similar to torchvision's but contains the following changes: 109 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 110 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 111 | - The final pooling layer is a QKV attention instead of an average pool 112 | """ 113 | 114 | def __init__(self, layers, output_dim, heads, image_size=224, width=64): 115 | super().__init__() 116 | self.output_dim = output_dim 117 | self.image_size = image_size 118 | 119 | # the 3-layer stem 120 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 121 | self.bn1 = nn.BatchNorm2d(width // 2) 122 | self.relu1 = nn.ReLU(inplace=True) 123 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 124 | self.bn2 = nn.BatchNorm2d(width // 2) 125 | self.relu2 = nn.ReLU(inplace=True) 126 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 127 | self.bn3 = nn.BatchNorm2d(width) 128 | self.relu3 = nn.ReLU(inplace=True) 129 | self.avgpool = nn.AvgPool2d(2) 130 | 131 | # residual layers 132 | self._inplanes = width # this is a *mutable* variable used during construction 133 | self.layer1 = self._make_layer(width, layers[0]) 134 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 135 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 136 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 137 | 138 | embed_dim = width * 32 # the ResNet feature dimension 139 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) 140 | 141 | self.init_parameters() 142 | 143 | def _make_layer(self, planes, blocks, stride=1): 144 | layers = [Bottleneck(self._inplanes, planes, stride)] 145 | 146 | self._inplanes = planes * Bottleneck.expansion 147 | for _ in range(1, blocks): 148 | layers.append(Bottleneck(self._inplanes, planes)) 149 | 150 | return nn.Sequential(*layers) 151 | 152 | def init_parameters(self): 153 | if self.attnpool is not None: 154 | std = self.attnpool.c_proj.in_features ** -0.5 155 | nn.init.normal_(self.attnpool.q_proj.weight, std=std) 156 | nn.init.normal_(self.attnpool.k_proj.weight, std=std) 157 | nn.init.normal_(self.attnpool.v_proj.weight, std=std) 158 | nn.init.normal_(self.attnpool.c_proj.weight, std=std) 159 | 160 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: 161 | for name, param in resnet_block.named_parameters(): 162 | if name.endswith("bn3.weight"): 163 | nn.init.zeros_(param) 164 | 165 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 166 | assert unlocked_groups == 0, 'partial locking not currently supported for this model' 167 | for param in self.parameters(): 168 | param.requires_grad = False 169 | if freeze_bn_stats: 170 | 
freeze_batch_norm_2d(self) 171 | 172 | @torch.jit.ignore 173 | def set_grad_checkpointing(self, enable=True): 174 | # FIXME support for non-transformer 175 | pass 176 | 177 | def stem(self, x): 178 | x = self.relu1(self.bn1(self.conv1(x))) 179 | x = self.relu2(self.bn2(self.conv2(x))) 180 | x = self.relu3(self.bn3(self.conv3(x))) 181 | x = self.avgpool(x) 182 | return x 183 | 184 | def forward(self, x): 185 | x = self.stem(x) 186 | x = self.layer1(x) 187 | x = self.layer2(x) 188 | x = self.layer3(x) 189 | x = self.layer4(x) 190 | x = self.attnpool(x) 191 | 192 | return x 193 | 194 | 195 | class LayerNorm(nn.LayerNorm): 196 | """Subclass torch's LayerNorm to handle fp16.""" 197 | 198 | def forward(self, x: torch.Tensor): 199 | orig_type = x.dtype 200 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 201 | return x.to(orig_type) 202 | 203 | 204 | class QuickGELU(nn.Module): 205 | # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory 206 | def forward(self, x: torch.Tensor): 207 | return x * torch.sigmoid(1.702 * x) 208 | 209 | 210 | class ResidualAttentionBlock(nn.Module): 211 | def __init__(self, d_model: int, n_head: int, mlp_ratio: float = 4.0, act_layer: Callable = nn.GELU): 212 | super().__init__() 213 | 214 | self.attn = nn.MultiheadAttention(d_model, n_head) 215 | self.ln_1 = LayerNorm(d_model) 216 | mlp_width = int(d_model * mlp_ratio) 217 | self.mlp = nn.Sequential(OrderedDict([ 218 | ("c_fc", nn.Linear(d_model, mlp_width)), 219 | ("gelu", act_layer()), 220 | ("c_proj", nn.Linear(mlp_width, d_model)) 221 | ])) 222 | self.ln_2 = LayerNorm(d_model) 223 | 224 | def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): 225 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] 226 | 227 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): 228 | x = x + self.attention(self.ln_1(x), attn_mask=attn_mask) 229 | x = x + self.mlp(self.ln_2(x)) 230 | return x 231 | 232 | 233 | class Transformer(nn.Module): 234 | def __init__(self, width: int, layers: int, heads: int, mlp_ratio: float = 4.0, act_layer: Callable = nn.GELU): 235 | super().__init__() 236 | self.width = width 237 | self.layers = layers 238 | self.grad_checkpointing = False 239 | 240 | self.resblocks = nn.ModuleList([ 241 | ResidualAttentionBlock(width, heads, mlp_ratio, act_layer=act_layer) 242 | for _ in range(layers) 243 | ]) 244 | 245 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): 246 | for r in self.resblocks: 247 | if self.grad_checkpointing and not torch.jit.is_scripting(): 248 | x = checkpoint(r, x, attn_mask) 249 | else: 250 | x = r(x, attn_mask=attn_mask) 251 | return x 252 | 253 | 254 | class VisualTransformer(nn.Module): 255 | def __init__( 256 | self, image_size: int, patch_size: int, width: int, layers: int, heads: int, mlp_ratio: float, 257 | output_dim: int, act_layer: Callable = nn.GELU): 258 | super().__init__() 259 | self.image_size = to_2tuple(image_size) 260 | self.patch_size = to_2tuple(patch_size) 261 | self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1]) 262 | self.output_dim = output_dim 263 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 264 | 265 | scale = width ** -0.5 266 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 267 | self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, 
width)) 268 | self.ln_pre = LayerNorm(width) 269 | 270 | self.transformer = Transformer(width, layers, heads, mlp_ratio, act_layer=act_layer) 271 | 272 | self.ln_post = LayerNorm(width) 273 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 274 | 275 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 276 | assert unlocked_groups == 0, 'partial locking not currently supported for this model' 277 | for param in self.parameters(): 278 | param.requires_grad = False 279 | 280 | @torch.jit.ignore 281 | def set_grad_checkpointing(self, enable=True): 282 | self.transformer.grad_checkpointing = enable 283 | 284 | def forward(self, x: torch.Tensor): 285 | x = self.conv1(x) # shape = [*, width, grid, grid] 286 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 287 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 288 | x = torch.cat( 289 | [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), 290 | x], dim=1) # shape = [*, grid ** 2 + 1, width] 291 | x = x + self.positional_embedding.to(x.dtype) 292 | x = self.ln_pre(x) 293 | 294 | x = x.permute(1, 0, 2) # NLD -> LND 295 | x = self.transformer(x) 296 | x = x.permute(1, 0, 2) # LND -> NLD 297 | 298 | x = self.ln_post(x[:, 0, :]) 299 | 300 | if self.proj is not None: 301 | x = x @ self.proj 302 | 303 | return x 304 | 305 | 306 | @dataclass 307 | class CLIPVisionCfg: 308 | layers: Union[Tuple[int, int, int, int], int] = 12 309 | width: int = 768 310 | head_width: int = 64 311 | mlp_ratio: float = 4.0 312 | patch_size: int = 16 313 | image_size: Union[Tuple[int, int], int] = 224 314 | timm_model_name: str = None # a valid model name overrides layers, width, patch_size 315 | timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model 316 | timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') 317 | timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') 318 | 319 | 320 | @dataclass 321 | class CLIPTextCfg: 322 | context_length: int = 77 323 | vocab_size: int = 49408 324 | width: int = 512 325 | heads: int = 8 326 | layers: int = 12 327 | 328 | 329 | class CLIP(nn.Module): 330 | def __init__( 331 | self, 332 | embed_dim: int, 333 | vision_cfg: CLIPVisionCfg, 334 | text_cfg: CLIPTextCfg, 335 | quick_gelu: bool = False, 336 | ): 337 | super().__init__() 338 | if isinstance(vision_cfg, dict): 339 | vision_cfg = CLIPVisionCfg(**vision_cfg) 340 | if isinstance(text_cfg, dict): 341 | text_cfg = CLIPTextCfg(**text_cfg) 342 | 343 | self.context_length = text_cfg.context_length 344 | 345 | # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more 346 | # memory efficient in recent PyTorch releases (>= 1.10). 347 | # NOTE: timm models always use native GELU regardless of quick_gelu flag. 
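# Editorial aside (hedged, illustrative only): QuickGELU is the sigmoid approximation
# x * torch.sigmoid(1.702 * x) that the original OpenAI checkpoints were trained with, which is
# why the *-quickgelu configs and build_model_from_openai_state_dict() set quick_gelu=True;
# substituting exact nn.GELU under those weights changes the activations slightly. Rough check:
#
#   x = torch.linspace(-4.0, 4.0, steps=9)
#   quick = x * torch.sigmoid(1.702 * x)      # QuickGELU.forward
#   exact = torch.nn.functional.gelu(x)       # what nn.GELU computes
#   # (quick - exact).abs().max() is on the order of a few hundredths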
348 | act_layer = QuickGELU if quick_gelu else nn.GELU 349 | 350 | if vision_cfg.timm_model_name: 351 | self.visual = TimmModel( 352 | vision_cfg.timm_model_name, 353 | pretrained=vision_cfg.timm_model_pretrained, 354 | pool=vision_cfg.timm_pool, 355 | proj=vision_cfg.timm_proj, 356 | embed_dim=embed_dim, 357 | image_size=vision_cfg.image_size 358 | ) 359 | act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models 360 | elif isinstance(vision_cfg.layers, (tuple, list)): 361 | vision_heads = vision_cfg.width * 32 // vision_cfg.head_width 362 | self.visual = ModifiedResNet( 363 | layers=vision_cfg.layers, 364 | output_dim=embed_dim, 365 | heads=vision_heads, 366 | image_size=vision_cfg.image_size, 367 | width=vision_cfg.width 368 | ) 369 | else: 370 | vision_heads = vision_cfg.width // vision_cfg.head_width 371 | self.visual = VisualTransformer( 372 | image_size=vision_cfg.image_size, 373 | patch_size=vision_cfg.patch_size, 374 | width=vision_cfg.width, 375 | layers=vision_cfg.layers, 376 | heads=vision_heads, 377 | mlp_ratio=vision_cfg.mlp_ratio, 378 | output_dim=embed_dim, 379 | act_layer=act_layer, 380 | ) 381 | 382 | self.transformer = Transformer( 383 | width=text_cfg.width, 384 | layers=text_cfg.layers, 385 | heads=text_cfg.heads, 386 | act_layer=act_layer, 387 | ) 388 | 389 | self.vocab_size = text_cfg.vocab_size 390 | self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width) 391 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, text_cfg.width)) 392 | self.ln_final = LayerNorm(text_cfg.width) 393 | 394 | self.text_projection = nn.Parameter(torch.empty(text_cfg.width, embed_dim)) 395 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 396 | self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) 397 | 398 | self.init_parameters() 399 | 400 | def init_parameters(self): 401 | nn.init.normal_(self.token_embedding.weight, std=0.02) 402 | nn.init.normal_(self.positional_embedding, std=0.01) 403 | nn.init.constant_(self.logit_scale, np.log(1 / 0.07)) 404 | 405 | if hasattr(self.visual, 'init_parameters'): 406 | self.visual.init_parameters() 407 | 408 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 409 | attn_std = self.transformer.width ** -0.5 410 | fc_std = (2 * self.transformer.width) ** -0.5 411 | for block in self.transformer.resblocks: 412 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 413 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 414 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 415 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 416 | 417 | if self.text_projection is not None: 418 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 419 | 420 | def build_attention_mask(self): 421 | # lazily create causal attention mask, with full attention between the vision tokens 422 | # pytorch uses additive attention mask; fill with -inf 423 | mask = torch.empty(self.context_length, self.context_length) 424 | mask.fill_(float("-inf")) 425 | mask.triu_(1) # zero out the lower diagonal 426 | return mask 427 | 428 | def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): 429 | # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 430 | self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) 431 | 432 | @torch.jit.ignore 433 | def set_grad_checkpointing(self, enable=True): 434 | self.visual.set_grad_checkpointing(enable) 435 | 
self.transformer.grad_checkpointing = enable 436 | 437 | def encode_image(self, image): 438 | return self.visual(image) 439 | 440 | def encode_text(self, text): 441 | x = self.token_embedding(text) # [batch_size, n_ctx, d_model] 442 | 443 | x = x + self.positional_embedding 444 | x = x.permute(1, 0, 2) # NLD -> LND 445 | x = self.transformer(x, attn_mask=self.attn_mask) 446 | x = x.permute(1, 0, 2) # LND -> NLD 447 | x = self.ln_final(x) 448 | 449 | # x.shape = [batch_size, n_ctx, transformer.width] 450 | # take features from the eot embedding (eot_token is the highest number in each sequence) 451 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 452 | 453 | return x 454 | 455 | def forward(self, image, text): 456 | if image is None: 457 | return self.encode_text(text) 458 | elif text is None: 459 | return self.encode_image(image) 460 | image_features = self.encode_image(image) 461 | image_features = F.normalize(image_features, dim=-1) 462 | 463 | text_features = self.encode_text(text) 464 | text_features = F.normalize(text_features, dim=-1) 465 | 466 | return image_features, text_features, self.logit_scale.exp() 467 | 468 | 469 | def convert_weights_to_fp16(model: nn.Module): 470 | """Convert applicable model parameters to fp16""" 471 | 472 | def _convert_weights_to_fp16(l): 473 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 474 | l.weight.data = l.weight.data.half() 475 | if l.bias is not None: 476 | l.bias.data = l.bias.data.half() 477 | 478 | if isinstance(l, nn.MultiheadAttention): 479 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 480 | tensor = getattr(l, attr) 481 | if tensor is not None: 482 | tensor.data = tensor.data.half() 483 | 484 | for name in ["text_projection", "proj"]: 485 | if hasattr(l, name): 486 | attr = getattr(l, name) 487 | if attr is not None: 488 | attr.data = attr.data.half() 489 | 490 | model.apply(_convert_weights_to_fp16) 491 | 492 | 493 | def build_model_from_openai_state_dict(state_dict: dict): 494 | vit = "visual.proj" in state_dict 495 | 496 | if vit: 497 | vision_width = state_dict["visual.conv1.weight"].shape[0] 498 | vision_layers = len( 499 | [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 500 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 501 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 502 | image_size = vision_patch_size * grid_size 503 | else: 504 | counts: list = [ 505 | len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 506 | vision_layers = tuple(counts) 507 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 508 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 509 | vision_patch_size = None 510 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 511 | image_size = output_width * 32 512 | 513 | embed_dim = state_dict["text_projection"].shape[1] 514 | context_length = state_dict["positional_embedding"].shape[0] 515 | vocab_size = state_dict["token_embedding.weight"].shape[0] 516 | transformer_width = state_dict["ln_final.weight"].shape[0] 517 | transformer_heads = transformer_width // 64 518 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 519 | 520 | vision_cfg = CLIPVisionCfg( 521 | layers=vision_layers, 522 | 
width=vision_width, 523 | patch_size=vision_patch_size, 524 | image_size=image_size, 525 | ) 526 | text_cfg = CLIPTextCfg( 527 | context_length=context_length, 528 | vocab_size=vocab_size, 529 | width=transformer_width, 530 | heads=transformer_heads, 531 | layers=transformer_layers 532 | ) 533 | model = CLIP( 534 | embed_dim, 535 | vision_cfg=vision_cfg, 536 | text_cfg=text_cfg, 537 | quick_gelu=True, # OpenAI models were trained with QuickGELU 538 | ) 539 | 540 | for key in ["input_resolution", "context_length", "vocab_size"]: 541 | state_dict.pop(key, None) 542 | 543 | convert_weights_to_fp16(model) 544 | model.load_state_dict(state_dict) 545 | return model.eval() 546 | 547 | 548 | def trace_model(model, batch_size=256, device=torch.device('cpu')): 549 | model.eval() 550 | image_size = model.visual.image_size 551 | example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) 552 | example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device) 553 | model = torch.jit.trace_module( 554 | model, 555 | inputs=dict( 556 | forward=(example_images, example_text), 557 | encode_text=(example_text,), 558 | encode_image=(example_images,) 559 | )) 560 | model.visual.image_size = image_size 561 | return model 562 | 563 | 564 | def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1): 565 | # Rescale the grid of position embeddings when loading from state_dict 566 | old_pos_embed = state_dict.get('visual.positional_embedding', None) 567 | if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): 568 | return 569 | grid_size = to_2tuple(model.visual.grid_size) 570 | extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) 571 | new_seq_len = grid_size[0] * grid_size[1] + extra_tokens 572 | if new_seq_len == old_pos_embed.shape[0]: 573 | return 574 | 575 | if extra_tokens: 576 | pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] 577 | else: 578 | pos_emb_tok, pos_emb_img = None, old_pos_embed 579 | old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) 580 | 581 | logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) 582 | pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) 583 | pos_emb_img = F.interpolate( 584 | pos_emb_img, 585 | size=grid_size, 586 | mode=interpolation, 587 | align_corners=True, 588 | ) 589 | pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] 590 | if pos_emb_tok is not None: 591 | new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) 592 | else: 593 | new_pos_embed = pos_emb_img 594 | state_dict['visual.positional_embedding'] = new_pos_embed 595 | -------------------------------------------------------------------------------- /src/open_clip/model_configs/RN101-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 23, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/RN101.json: -------------------------------------------------------------------------------- 1 | { 2 
| "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 23, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/RN50-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 6, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /src/open_clip/model_configs/RN50.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 6, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/RN50x16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 384, 5 | "layers": [ 6 | 6, 7 | 8, 8 | 18, 9 | 8 10 | ], 11 | "width": 96, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 768, 18 | "heads": 12, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/RN50x4.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 288, 5 | "layers": [ 6 | 4, 7 | 6, 8 | 10, 9 | 6 10 | ], 11 | "width": 80, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 640, 18 | "heads": 10, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-B-16-plus-240.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 240, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-B-16-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-B-16.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-B-32-plus-256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 256, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-B-32-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-H-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 16 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-L-14-280.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 280, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-L-14-336.json: -------------------------------------------------------------------------------- 1 | 
{ 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-L-16-320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 320, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-L-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1024, 15 | "heads": 16, 16 | "layers": 24 17 | } 18 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/timm-efficientnetv2_rw_s.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "timm_model_name": "efficientnetv2_rw_s", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "abs_attn", 7 | "timm_proj": "", 8 | "image_size": 288 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 768, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/timm-resnet50d.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "resnet50d", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "abs_attn", 7 | "timm_proj": "", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /src/open_clip/model_configs/timm-resnetaa50d.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "resnetaa50d", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "abs_attn", 7 | "timm_proj": "", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /src/open_clip/model_configs/timm-resnetblur50.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "resnetblur50", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "abs_attn", 7 | "timm_proj": "", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /src/open_clip/model_configs/timm-swin_base_patch4_window7_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "swin_base_patch4_window7_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/timm-vit_base_patch16_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_base_patch16_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/timm-vit_base_patch32_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_base_patch32_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/timm-vit_small_patch16_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_small_patch16_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from 
https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import Union, List 9 | 10 | import torch 11 | 12 | from .model import build_model_from_openai_state_dict 13 | from .pretrained import get_pretrained_url, list_pretrained_tag_models, download_pretrained 14 | 15 | __all__ = ["list_openai_models", "load_openai_model"] 16 | 17 | 18 | def list_openai_models() -> List[str]: 19 | """Returns the names of available CLIP models""" 20 | return list_pretrained_tag_models('openai') 21 | 22 | 23 | def load_openai_model( 24 | name: str, 25 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", 26 | jit=True, 27 | ): 28 | """Load a CLIP model 29 | 30 | Parameters 31 | ---------- 32 | name : str 33 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 34 | device : Union[str, torch.device] 35 | The device to put the loaded model 36 | jit : bool 37 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 38 | 39 | Returns 40 | ------- 41 | model : torch.nn.Module 42 | The CLIP model 43 | preprocess : Callable[[PIL.Image], torch.Tensor] 44 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 45 | """ 46 | if get_pretrained_url(name, 'openai'): 47 | model_path = download_pretrained(get_pretrained_url(name, 'openai')) 48 | elif os.path.isfile(name): 49 | model_path = name 50 | else: 51 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 52 | 53 | try: 54 | # loading JIT archive 55 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 56 | state_dict = None 57 | except RuntimeError: 58 | # loading saved state dict 59 | if jit: 60 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 61 | jit = False 62 | state_dict = torch.load(model_path, map_location="cpu") 63 | 64 | if not jit: 65 | try: 66 | model = build_model_from_openai_state_dict(state_dict or model.state_dict()).to(device) 67 | except KeyError: 68 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 69 | model = build_model_from_openai_state_dict(sd).to(device) 70 | 71 | if str(device) == "cpu": 72 | model.float() 73 | return model 74 | 75 | # patch the device names 76 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 77 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 78 | 79 | def patch_device(module): 80 | try: 81 | graphs = [module.graph] if hasattr(module, "graph") else [] 82 | except RuntimeError: 83 | graphs = [] 84 | 85 | if hasattr(module, "forward1"): 86 | graphs.append(module.forward1.graph) 87 | 88 | for graph in graphs: 89 | for node in graph.findAllNodes("prim::Constant"): 90 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 91 | node.copyAttributes(device_node) 92 | 93 | model.apply(patch_device) 94 | patch_device(model.encode_image) 95 | patch_device(model.encode_text) 96 | 97 | # patch dtype to float32 on CPU 98 | if str(device) == "cpu": 99 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 100 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 101 | float_node = float_input.node() 102 | 103 | def patch_float(module): 104 | try: 105 | graphs = [module.graph] if hasattr(module, "graph") else [] 106 | except RuntimeError: 107 | graphs = [] 108 | 109 | if hasattr(module, "forward1"): 110 | graphs.append(module.forward1.graph) 111 | 112 | for graph in graphs: 113 | for node in graph.findAllNodes("aten::to"): 114 | inputs = list(node.inputs()) 115 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 116 | if inputs[i].node()["value"] == 5: 117 | inputs[i].node().copyAttributes(float_node) 118 | 119 | model.apply(patch_float) 120 | patch_float(model.encode_image) 121 | patch_float(model.encode_text) 122 | model.float() 123 | 124 | # ensure image_size attr available at consistent location for both jit and non-jit 125 | model.visual.image_size = model.input_resolution.item() 126 | return model 127 | -------------------------------------------------------------------------------- /src/open_clip/pretrained.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | 6 | from tqdm import tqdm 7 | 8 | _RN50 = dict( 9 | openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 10 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", 11 | cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt" 12 | ) 13 | 14 | _RN50_quickgelu = dict( 15 | openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 16 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", 17 | cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt" 18 | ) 19 | 20 | _RN101 = dict( 21 | 
openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 22 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt" 23 | ) 24 | 25 | _RN101_quickgelu = dict( 26 | openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 27 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt" 28 | ) 29 | 30 | _RN50x4 = dict( 31 | openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 32 | ) 33 | 34 | _RN50x16 = dict( 35 | openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 36 | ) 37 | 38 | _RN50x64 = dict( 39 | openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 40 | ) 41 | 42 | _VITB32 = dict( 43 | openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 44 | laion2b_e16="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth", 45 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", 46 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", 47 | ) 48 | 49 | _VITB32_quickgelu = dict( 50 | openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 51 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", 52 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", 53 | ) 54 | 55 | _VITB16 = dict( 56 | openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 57 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt", 58 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt", 59 | ) 60 | 61 | _VITB16_PLUS_240 = dict( 62 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt", 63 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt", 64 | ) 65 | 66 | _VITL14 = dict( 67 | openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 68 | laion400m_e31='https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt', 69 | laion400m_e32='https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt', 70 | ) 71 | 72 | _VITL14_336 = dict( 73 | openai="https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt" 74 | ) 75 | 76 | _PRETRAINED = { 77 | "RN50": _RN50, 78 | 
"RN50-quickgelu": _RN50_quickgelu, 79 | "RN101": _RN101, 80 | "RN101-quickgelu": _RN101_quickgelu, 81 | "RN50x4": _RN50x4, 82 | "RN50x16": _RN50x16, 83 | "RN50x64": _RN50x64, 84 | "ViT-B-32": _VITB32, 85 | "ViT-B-32-quickgelu": _VITB32_quickgelu, 86 | "ViT-B-16": _VITB16, 87 | "ViT-B-16-plus-240": _VITB16_PLUS_240, 88 | "ViT-L-14": _VITL14, 89 | "ViT-L-14-336": _VITL14_336, 90 | } 91 | 92 | 93 | def list_pretrained(as_str: bool = False): 94 | """ returns list of pretrained models 95 | Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True 96 | """ 97 | return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] 98 | 99 | 100 | def list_pretrained_tag_models(tag: str): 101 | """ return all models having the specified pretrain tag """ 102 | models = [] 103 | for k in _PRETRAINED.keys(): 104 | if tag in _PRETRAINED[k]: 105 | models.append(k) 106 | return models 107 | 108 | 109 | def list_pretrained_model_tags(model: str): 110 | """ return all pretrain tags for the specified model architecture """ 111 | tags = [] 112 | if model in _PRETRAINED: 113 | tags.extend(_PRETRAINED[model].keys()) 114 | return tags 115 | 116 | 117 | def get_pretrained_url(model: str, tag: str): 118 | if model not in _PRETRAINED: 119 | return '' 120 | model_pretrained = _PRETRAINED[model] 121 | tag = tag.lower() 122 | if tag not in model_pretrained: 123 | return '' 124 | return model_pretrained[tag] 125 | 126 | 127 | def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")): 128 | os.makedirs(root, exist_ok=True) 129 | filename = os.path.basename(url) 130 | 131 | if 'openaipublic' in url: 132 | expected_sha256 = url.split("/")[-2] 133 | else: 134 | expected_sha256 = '' 135 | 136 | download_target = os.path.join(root, filename) 137 | 138 | if os.path.exists(download_target) and not os.path.isfile(download_target): 139 | raise RuntimeError(f"{download_target} exists and is not a regular file") 140 | 141 | if os.path.isfile(download_target): 142 | if expected_sha256: 143 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 144 | return download_target 145 | else: 146 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 147 | else: 148 | return download_target 149 | 150 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 151 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 152 | while True: 153 | buffer = source.read(8192) 154 | if not buffer: 155 | break 156 | 157 | output.write(buffer) 158 | loop.update(len(buffer)) 159 | 160 | if expected_sha256 and hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 161 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 162 | 163 | return download_target 164 | -------------------------------------------------------------------------------- /src/open_clip/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 
4 | """ 5 | from collections import OrderedDict 6 | 7 | import torch.nn as nn 8 | 9 | try: 10 | import timm 11 | from timm.models.layers import Mlp, to_2tuple 12 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 13 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d 14 | except ImportError as e: 15 | timm = None 16 | 17 | from .utils import freeze_batch_norm_2d 18 | 19 | 20 | class TimmModel(nn.Module): 21 | """ timm model adapter 22 | # FIXME this adapter is a work in progress, may change in ways that break weight compat 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model_name, 28 | embed_dim, 29 | image_size=224, 30 | pool='avg', 31 | proj='linear', 32 | drop=0., 33 | pretrained=False): 34 | super().__init__() 35 | if timm is None: 36 | raise RuntimeError("Please `pip install timm` to use timm models.") 37 | 38 | self.image_size = to_2tuple(image_size) 39 | self.trunk = timm.create_model(model_name, pretrained=pretrained) 40 | feat_size = self.trunk.default_cfg.get('pool_size', None) 41 | feature_ndim = 1 if not feat_size else 2 42 | if pool in ('abs_attn', 'rot_attn'): 43 | assert feature_ndim == 2 44 | # if attn pooling used, remove both classifier and default pool 45 | self.trunk.reset_classifier(0, global_pool='') 46 | else: 47 | # reset global pool if pool config set, otherwise leave as network default 48 | reset_kwargs = dict(global_pool=pool) if pool else {} 49 | self.trunk.reset_classifier(0, **reset_kwargs) 50 | prev_chs = self.trunk.num_features 51 | 52 | head_layers = OrderedDict() 53 | if pool == 'abs_attn': 54 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) 55 | prev_chs = embed_dim 56 | elif pool == 'rot_attn': 57 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 58 | prev_chs = embed_dim 59 | else: 60 | assert proj, 'projection layer needed if non-attention pooling is used.' 
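# Hedged editorial summary (not upstream code): the bundled timm-* model configs pair these
# options in two combinations:
#   timm_pool='abs_attn', timm_proj=''        (resnet50d, resnetaa50d, resnetblur50, efficientnetv2_rw_s)
#   timm_pool='',         timm_proj='linear'  (vit_*_patch*_224, swin_base_patch4_window7_224)
# so the head is either an AttentionPool2d that already projects to embed_dim, or the trunk's
# default pooling followed by the Dropout + Linear projection built just below.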
61 | 62 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 63 | if proj == 'linear': 64 | head_layers['drop'] = nn.Dropout(drop) 65 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim) 66 | elif proj == 'mlp': 67 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop) 68 | 69 | self.head = nn.Sequential(head_layers) 70 | 71 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 72 | """ lock modules 73 | Args: 74 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 75 | """ 76 | if not unlocked_groups: 77 | # lock full model 78 | for param in self.trunk.parameters(): 79 | param.requires_grad = False 80 | if freeze_bn_stats: 81 | freeze_batch_norm_2d(self.trunk) 82 | else: 83 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 84 | try: 85 | # FIXME import here until API stable and in an official release 86 | from timm.models.helpers import group_parameters, group_modules 87 | except ImportError: 88 | raise RuntimeError( 89 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') 90 | matcher = self.trunk.group_matcher() 91 | gparams = group_parameters(self.trunk, matcher) 92 | max_layer_id = max(gparams.keys()) 93 | max_layer_id = max_layer_id - unlocked_groups 94 | for group_idx in range(max_layer_id + 1): 95 | group = gparams[group_idx] 96 | for param in group: 97 | self.trunk.get_parameter(param).requires_grad = False 98 | if freeze_bn_stats: 99 | gmodules = group_modules(self.trunk, matcher, reverse=True) 100 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 101 | freeze_batch_norm_2d(self.trunk, gmodules) 102 | 103 | def forward(self, x): 104 | x = self.trunk(x) 105 | x = self.head(x) 106 | return x 107 | -------------------------------------------------------------------------------- /src/open_clip/tokenizer.py: -------------------------------------------------------------------------------- 1 | """ CLIP tokenizer 2 | 3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | import gzip 6 | import html 7 | import os 8 | from functools import lru_cache 9 | from typing import Union, List 10 | 11 | import ftfy 12 | import regex as re 13 | import torch 14 | 15 | 16 | @lru_cache() 17 | def default_bpe(): 18 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 19 | 20 | 21 | @lru_cache() 22 | def bytes_to_unicode(): 23 | """ 24 | Returns list of utf-8 byte and a corresponding list of unicode strings. 25 | The reversible bpe codes work on unicode strings. 26 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 27 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 28 | This is a signficant percentage of your normal, say, 32K bpe vocab. 29 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 30 | And avoids mapping to whitespace/control characters the bpe code barfs on. 31 | """ 32 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 33 | cs = bs[:] 34 | n = 0 35 | for b in range(2**8): 36 | if b not in bs: 37 | bs.append(b) 38 | cs.append(2**8+n) 39 | n += 1 40 | cs = [chr(n) for n in cs] 41 | return dict(zip(bs, cs)) 42 | 43 | 44 | def get_pairs(word): 45 | """Return set of symbol pairs in a word. 
46 | Word is represented as tuple of symbols (symbols being variable-length strings). 47 | """ 48 | pairs = set() 49 | prev_char = word[0] 50 | for char in word[1:]: 51 | pairs.add((prev_char, char)) 52 | prev_char = char 53 | return pairs 54 | 55 | 56 | def basic_clean(text): 57 | text = ftfy.fix_text(text) 58 | text = html.unescape(html.unescape(text)) 59 | return text.strip() 60 | 61 | 62 | def whitespace_clean(text): 63 | text = re.sub(r'\s+', ' ', text) 64 | text = text.strip() 65 | return text 66 | 67 | 68 | class SimpleTokenizer(object): 69 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): 70 | self.byte_encoder = bytes_to_unicode() 71 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 72 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 73 | merges = merges[1:49152-256-2+1] 74 | merges = [tuple(merge.split()) for merge in merges] 75 | vocab = list(bytes_to_unicode().values()) 76 | vocab = vocab + [v+'</w>' for v in vocab] 77 | for merge in merges: 78 | vocab.append(''.join(merge)) 79 | if not special_tokens: 80 | special_tokens = ['<start_of_text>', '<end_of_text>'] 81 | else: 82 | special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens 83 | vocab.extend(special_tokens) 84 | self.encoder = dict(zip(vocab, range(len(vocab)))) 85 | self.decoder = {v: k for k, v in self.encoder.items()} 86 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 87 | self.cache = {t:t for t in special_tokens} 88 | special = "|".join(special_tokens) 89 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 90 | 91 | self.vocab_size = len(self.encoder) 92 | self.all_special_ids = [self.encoder[t] for t in special_tokens] 93 | 94 | def bpe(self, token): 95 | if token in self.cache: 96 | return self.cache[token] 97 | word = tuple(token[:-1]) + ( token[-1] + '</w>',) 98 | pairs = get_pairs(word) 99 | 100 | if not pairs: 101 | return token+'</w>' 102 | 103 | while True: 104 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 105 | if bigram not in self.bpe_ranks: 106 | break 107 | first, second = bigram 108 | new_word = [] 109 | i = 0 110 | while i < len(word): 111 | try: 112 | j = word.index(first, i) 113 | new_word.extend(word[i:j]) 114 | i = j 115 | except: 116 | new_word.extend(word[i:]) 117 | break 118 | 119 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 120 | new_word.append(first+second) 121 | i += 2 122 | else: 123 | new_word.append(word[i]) 124 | i += 1 125 | new_word = tuple(new_word) 126 | word = new_word 127 | if len(word) == 1: 128 | break 129 | else: 130 | pairs = get_pairs(word) 131 | word = ' '.join(word) 132 | self.cache[token] = word 133 | return word 134 | 135 | def encode(self, text): 136 | bpe_tokens = [] 137 | text = whitespace_clean(basic_clean(text)).lower() 138 | for token in re.findall(self.pat, text): 139 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 140 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 141 | return bpe_tokens 142 | 143 | def decode(self, tokens): 144 | text = ''.join([self.decoder[token] for token in tokens]) 145 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 146 | return text 147 | 148 | 149 | _tokenizer = SimpleTokenizer() 150 | 151 | 152 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: 153 | """ 154 | Returns the tokenized representation of given input string(s) 155 | 156 | Parameters
157 |     ----------
158 |     texts : Union[str, List[str]]
159 |         An input string or a list of input strings to tokenize
160 |     context_length : int
161 |         The context length to use; all CLIP models use 77 as the context length
162 | 
163 |     Returns
164 |     -------
165 |     A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
166 |     """
167 |     if isinstance(texts, str):
168 |         texts = [texts]
169 | 
170 |     sot_token = _tokenizer.encoder["<start_of_text>"]
171 |     eot_token = _tokenizer.encoder["<end_of_text>"]
172 |     all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
173 |     result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
174 | 
175 |     for i, tokens in enumerate(all_tokens):
176 |         if len(tokens) > context_length:
177 |             tokens = tokens[:context_length]  # Truncate
178 |             tokens[-1] = eot_token
179 |         result[i, :len(tokens)] = torch.tensor(tokens)
180 | 
181 |     return result
182 | 
--------------------------------------------------------------------------------
/src/open_clip/transform.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Sequence, Tuple
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torchvision.transforms.functional as F
6 | 
7 | 
8 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
9 |     CenterCrop
10 | 
11 | 
12 | class ResizeMaxSize(nn.Module):
13 | 
14 |     def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
15 |         super().__init__()
16 |         if not isinstance(max_size, int):
17 |             raise TypeError(f"Size should be int. Got {type(max_size)}")
18 |         self.max_size = max_size
19 |         self.interpolation = interpolation
20 |         self.fn = min if fn == 'min' else max
21 |         self.fill = fill
22 | 
23 |     def forward(self, img):
24 |         if isinstance(img, torch.Tensor):
25 |             height, width = img.shape[:2]
26 |         else:
27 |             width, height = img.size
28 |         scale = self.max_size / float(max(height, width))
29 |         if scale != 1.0:
30 |             new_size = tuple(round(dim * scale) for dim in (height, width))
31 |             img = F.resize(img, new_size, self.interpolation)
32 |             pad_h = self.max_size - new_size[0]
33 |             pad_w = self.max_size - new_size[1]
34 |             img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
35 |         return img
36 | 
37 | 
38 | def _convert_to_rgb(image):
39 |     return image.convert('RGB')
40 | 
41 | 
42 | def image_transform(
43 |         image_size: int,
44 |         is_train: bool,
45 |         mean: Optional[Tuple[float, ...]] = None,
46 |         std: Optional[Tuple[float, ...]] = None,
47 |         resize_longest_max: bool = False,
48 |         fill_color: int = 0,
49 | ):
50 |     mean = mean or (0.48145466, 0.4578275, 0.40821073)  # OpenAI dataset mean
51 |     std = std or (0.26862954, 0.26130258, 0.27577711)  # OpenAI dataset std
52 |     if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
53 |         # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
54 |         image_size = image_size[0]
55 | 
56 |     normalize = Normalize(mean=mean, std=std)
57 |     if is_train:
58 |         return Compose([
59 |             RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
60 |             _convert_to_rgb,
61 |             ToTensor(),
62 |             normalize,
63 |         ])
64 |     else:
65 |         if resize_longest_max:
66 |             transforms = [
67 |                 ResizeMaxSize(image_size, fill=fill_color)
68 |             ]
69 |         else:
70 |             transforms = [
71 |                 Resize(image_size, interpolation=InterpolationMode.BICUBIC),
72 | 
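                # With an int size, Resize scales the shortest image edge to image_size while
                # keeping the aspect ratio; the CenterCrop just below then takes the central
                # image_size x image_size patch, mirroring OpenAI CLIP's eval preprocessing.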
CenterCrop(image_size), 73 | ] 74 | transforms.extend([ 75 | _convert_to_rgb, 76 | ToTensor(), 77 | normalize, 78 | ]) 79 | return Compose(transforms) 80 | -------------------------------------------------------------------------------- /src/open_clip/utils.py: -------------------------------------------------------------------------------- 1 | from itertools import repeat 2 | import collections.abc 3 | 4 | from torch import nn as nn 5 | from torchvision.ops.misc import FrozenBatchNorm2d 6 | 7 | 8 | def freeze_batch_norm_2d(module, module_match={}, name=''): 9 | """ 10 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is 11 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and 12 | returned. Otherwise, the module is walked recursively and submodules are converted in place. 13 | 14 | Args: 15 | module (torch.nn.Module): Any PyTorch module. 16 | module_match (dict): Dictionary of full module names to freeze (all if empty) 17 | name (str): Full module name (prefix) 18 | 19 | Returns: 20 | torch.nn.Module: Resulting module 21 | 22 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 23 | """ 24 | res = module 25 | is_match = True 26 | if module_match: 27 | is_match = name in module_match 28 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): 29 | res = FrozenBatchNorm2d(module.num_features) 30 | res.num_features = module.num_features 31 | res.affine = module.affine 32 | if module.affine: 33 | res.weight.data = module.weight.data.clone().detach() 34 | res.bias.data = module.bias.data.clone().detach() 35 | res.running_mean.data = module.running_mean.data 36 | res.running_var.data = module.running_var.data 37 | res.eps = module.eps 38 | else: 39 | for child_name, child in module.named_children(): 40 | full_child_name = '.'.join([name, child_name]) if name else child_name 41 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name) 42 | if new_child is not child: 43 | res.add_module(child_name, new_child) 44 | return res 45 | 46 | 47 | # From PyTorch internals 48 | def _ntuple(n): 49 | def parse(x): 50 | if isinstance(x, collections.abc.Iterable): 51 | return x 52 | return tuple(repeat(x, n)) 53 | return parse 54 | 55 | 56 | to_1tuple = _ntuple(1) 57 | to_2tuple = _ntuple(2) 58 | to_3tuple = _ntuple(3) 59 | to_4tuple = _ntuple(4) 60 | to_ntuple = lambda n, x: _ntuple(n)(x) 61 | -------------------------------------------------------------------------------- /src/open_clip/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '1.3.0' 2 | -------------------------------------------------------------------------------- /src/training/.gitignore: -------------------------------------------------------------------------------- 1 | logs/ 2 | -------------------------------------------------------------------------------- /src/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vinid/neg_clip/e680f8e44ab9f081f62094232d8ee3b77f8574f7/src/training/__init__.py -------------------------------------------------------------------------------- /src/training/data.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import json 3 | import logging 4 | import 
math 5 | import os 6 | import random 7 | import sys 8 | import time 9 | from dataclasses import dataclass 10 | from multiprocessing import Value 11 | import ast 12 | import random 13 | import braceexpand 14 | import numpy as np 15 | import pandas as pd 16 | import torch 17 | import torchvision.datasets as datasets 18 | import webdataset as wds 19 | from PIL import Image 20 | from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info 21 | from torch.utils.data.distributed import DistributedSampler 22 | from webdataset.filters import _shuffle 23 | from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample 24 | 25 | try: 26 | import horovod.torch as hvd 27 | except ImportError: 28 | hvd = None 29 | 30 | from open_clip import tokenize 31 | 32 | 33 | class CsvDataset(Dataset): 34 | def __init__(self, input_filename, transforms, img_key, caption_key, hard_captions_key, sep="\t"): 35 | logging.debug(f'Loading csv data from {input_filename}.') 36 | df = pd.read_csv(input_filename, sep=sep, converters={"neg_caption":ast.literal_eval, "neg_image":ast.literal_eval}) 37 | 38 | self.images = df[img_key].tolist() 39 | self.captions = df[caption_key].tolist() 40 | self.hard_captions = df[hard_captions_key].tolist() 41 | self.hard_images = df["neg_image"].tolist() 42 | self.transforms = transforms 43 | logging.debug('Done loading data.') 44 | 45 | def __len__(self): 46 | return len(self.captions) 47 | 48 | def __getitem__(self, idx): 49 | images = self.transforms(Image.open(str(self.images[idx]))) 50 | texts = tokenize([str(self.captions[idx])])[0] 51 | 52 | chosen_caption = random.choice(self.hard_captions[idx]) 53 | hard_captions = tokenize([str(chosen_caption)])[0] 54 | 55 | chose_image_index = random.choice(self.hard_images[idx]) 56 | 57 | new_images = self.transforms(Image.open(str(self.images[chose_image_index]))) 58 | new_texts = tokenize([str(self.captions[chose_image_index])])[0] 59 | 60 | chosen_caption = random.choice(self.hard_captions[chose_image_index]) 61 | new_hard = tokenize([str(chosen_caption)])[0] 62 | 63 | return images, new_images, texts, new_texts, hard_captions, new_hard 64 | 65 | 66 | 67 | class SharedEpoch: 68 | def __init__(self, epoch: int = 0): 69 | self.shared_epoch = Value('i', epoch) 70 | 71 | def set_value(self, epoch): 72 | self.shared_epoch.value = epoch 73 | 74 | def get_value(self): 75 | return self.shared_epoch.value 76 | 77 | 78 | @dataclass 79 | class DataInfo: 80 | dataloader: DataLoader 81 | sampler: DistributedSampler = None 82 | shared_epoch: SharedEpoch = None 83 | 84 | def set_epoch(self, epoch): 85 | if self.shared_epoch is not None: 86 | self.shared_epoch.set_value(epoch) 87 | if self.sampler is not None and isinstance(self.sampler, DistributedSampler): 88 | self.sampler.set_epoch(epoch) 89 | 90 | 91 | def preprocess_txt(text): 92 | return tokenize([str(text)])[0] 93 | 94 | 95 | def get_dataset_size(shards): 96 | shards_list = list(braceexpand.braceexpand(shards)) 97 | dir_path = os.path.dirname(shards) 98 | sizes_filename = os.path.join(dir_path, 'sizes.json') 99 | len_filename = os.path.join(dir_path, '__len__') 100 | if os.path.exists(sizes_filename): 101 | sizes = json.load(open(sizes_filename, 'r')) 102 | total_size = sum([int(sizes[os.path.basename(shard)]) for shard in shards_list]) 103 | elif os.path.exists(len_filename): 104 | # FIXME this used to be eval(open(...)) but that seemed rather unsafe 105 | total_size = ast.literal_eval(open(len_filename, 'r').read()) 106 | 
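    # Illustrative sketch of the size metadata this function looks for (the file
    # contents below are hypothetical): a sizes.json mapping shard basename to
    # sample count, e.g.
    #     {"shard-000000.tar": 10000, "shard-000001.tar": 9987}
    # or a plain __len__ file holding a single integer total. If neither exists the
    # size stays None and, for webdataset training, --train-num-samples must be given.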
else: 107 | total_size = None # num samples undefined 108 | # some common dataset sizes (at time of authors last download) 109 | # CC3M (train): 2905954 110 | # CC12M: 10968539 111 | # LAION-400M: 407332084 112 | # LAION-2B (english): 2170337258 113 | num_shards = len(shards_list) 114 | return total_size, num_shards 115 | 116 | 117 | def get_imagenet(args, preprocess_fns, split): 118 | assert split in ["train", "val", "v2"] 119 | is_train = split == "train" 120 | preprocess_train, preprocess_val = preprocess_fns 121 | 122 | if split == "v2": 123 | from imagenetv2_pytorch import ImageNetV2Dataset 124 | dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val) 125 | else: 126 | if is_train: 127 | data_path = args.imagenet_train 128 | preprocess_fn = preprocess_train 129 | else: 130 | data_path = args.imagenet_val 131 | preprocess_fn = preprocess_val 132 | assert data_path 133 | 134 | dataset = datasets.ImageFolder(data_path, transform=preprocess_fn) 135 | 136 | if is_train: 137 | idxs = np.zeros(len(dataset.targets)) 138 | target_array = np.array(dataset.targets) 139 | k = 50 140 | for c in range(1000): 141 | m = target_array == c 142 | n = len(idxs[m]) 143 | arr = np.zeros(n) 144 | arr[:k] = 1 145 | np.random.shuffle(arr) 146 | idxs[m] = arr 147 | 148 | idxs = idxs.astype('int') 149 | sampler = SubsetRandomSampler(np.where(idxs)[0]) 150 | else: 151 | sampler = None 152 | 153 | dataloader = torch.utils.data.DataLoader( 154 | dataset, 155 | batch_size=args.batch_size, 156 | num_workers=args.workers, 157 | sampler=sampler, 158 | ) 159 | 160 | return DataInfo(dataloader=dataloader, sampler=sampler) 161 | 162 | 163 | def count_samples(dataloader): 164 | os.environ["WDS_EPOCH"] = "0" 165 | n_elements, n_batches = 0, 0 166 | for images, texts in dataloader: 167 | n_batches += 1 168 | n_elements += len(images) 169 | assert len(images) == len(texts) 170 | return n_elements, n_batches 171 | 172 | 173 | def filter_no_caption(sample): 174 | return 'txt' in sample 175 | 176 | 177 | def log_and_continue(exn): 178 | """Call in an exception handler to ignore any exception, isssue a warning, and continue.""" 179 | logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.') 180 | return True 181 | 182 | 183 | def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None): 184 | """Return function over iterator that groups key, value pairs into samples. 
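    Unlike the stock webdataset group_by_keys, a repeated suffix (or a change of key)
    closes out the current sample instead of raising, so a tar whose trailing prefix
    matches the next tar's leading prefix does not crash the dataloader worker.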
185 | 186 | :param keys: function that splits the key into key and extension (base_plus_ext) 187 | :param lcase: convert suffixes to lower case (Default value = True) 188 | """ 189 | current_sample = None 190 | for filesample in data: 191 | assert isinstance(filesample, dict) 192 | fname, value = filesample["fname"], filesample["data"] 193 | prefix, suffix = keys(fname) 194 | if prefix is None: 195 | continue 196 | if lcase: 197 | suffix = suffix.lower() 198 | # FIXME webdataset version throws if suffix in current_sample, but we have a potential for 199 | # this happening in the current LAION400m dataset if a tar ends with same prefix as the next 200 | # begins, rare, but can happen since prefix aren't unique across tar files in that dataset 201 | if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample: 202 | if valid_sample(current_sample): 203 | yield current_sample 204 | current_sample = dict(__key__=prefix, __url__=filesample["__url__"]) 205 | if suffixes is None or suffix in suffixes: 206 | current_sample[suffix] = value 207 | if valid_sample(current_sample): 208 | yield current_sample 209 | 210 | 211 | def tarfile_to_samples_nothrow(src, handler=log_and_continue): 212 | # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw 213 | streams = url_opener(src, handler=handler) 214 | files = tar_file_expander(streams, handler=handler) 215 | samples = group_by_keys_nothrow(files, handler=handler) 216 | return samples 217 | 218 | 219 | def pytorch_worker_seed(): 220 | """get dataloader worker seed from pytorch""" 221 | worker_info = get_worker_info() 222 | if worker_info is not None: 223 | # favour the seed already created for pytorch dataloader workers if it exists 224 | return worker_info.seed 225 | # fallback to wds rank based seed 226 | return wds.utils.pytorch_worker_seed() 227 | 228 | 229 | _SHARD_SHUFFLE_SIZE = 2000 230 | _SHARD_SHUFFLE_INITIAL = 500 231 | _SAMPLE_SHUFFLE_SIZE = 5000 232 | _SAMPLE_SHUFFLE_INITIAL = 1000 233 | 234 | 235 | class detshuffle2(wds.PipelineStage): 236 | def __init__( 237 | self, 238 | bufsize=1000, 239 | initial=100, 240 | seed=0, 241 | epoch=-1, 242 | ): 243 | self.bufsize = bufsize 244 | self.initial = initial 245 | self.seed = seed 246 | self.epoch = epoch 247 | 248 | def run(self, src): 249 | if isinstance(self.epoch, SharedEpoch): 250 | epoch = self.epoch.get_value() 251 | else: 252 | # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train) 253 | # situation as different workers may wrap at different times (or not at all). 254 | self.epoch += 1 255 | epoch = self.epoch 256 | rng = random.Random() 257 | if self.seed < 0: 258 | seed = pytorch_worker_seed() + epoch 259 | else: 260 | seed = self.seed + epoch 261 | rng.seed(seed) 262 | return _shuffle(src, self.bufsize, self.initial, rng) 263 | 264 | 265 | class ResampledShards2(IterableDataset): 266 | """An iterable dataset yielding a list of urls.""" 267 | 268 | def __init__( 269 | self, 270 | urls, 271 | nshards=sys.maxsize, 272 | worker_seed=None, 273 | deterministic=False, 274 | epoch=-1, 275 | ): 276 | """Sample shards from the shard list with replacement. 
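        Each epoch draws nshards URLs uniformly at random with replacement, so a given
        shard may appear several times (or not at all) within an epoch; with
        deterministic=True the draw is re-seeded from the worker seed plus the current epoch.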
277 | 278 | :param urls: a list of URLs as a Python list or brace notation string 279 | """ 280 | super().__init__() 281 | urls = wds.shardlists.expand_urls(urls) 282 | self.urls = urls 283 | assert isinstance(self.urls[0], str) 284 | self.nshards = nshards 285 | self.rng = random.Random() 286 | self.worker_seed = pytorch_worker_seed if worker_seed is None else worker_seed 287 | self.deterministic = deterministic 288 | self.epoch = epoch 289 | 290 | def __iter__(self): 291 | """Return an iterator over the shards.""" 292 | if isinstance(self.epoch, SharedEpoch): 293 | epoch = self.epoch.get_value() 294 | else: 295 | # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train) 296 | # situation as different workers may wrap at different times (or not at all). 297 | self.epoch += 1 298 | epoch = self.epoch 299 | if self.deterministic: 300 | # reset seed w/ epoch if deterministic, worker seed should be deterministic due to arg.seed 301 | self.rng.seed(self.worker_seed() + epoch) 302 | for _ in range(self.nshards): 303 | yield dict(url=self.rng.choice(self.urls)) 304 | 305 | 306 | def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False): 307 | input_shards = args.train_data if is_train else args.val_data 308 | assert input_shards is not None 309 | resampled = getattr(args, 'dataset_resampled', False) and is_train 310 | 311 | num_samples, num_shards = get_dataset_size(input_shards) 312 | if not num_samples: 313 | if is_train: 314 | num_samples = args.train_num_samples 315 | if not num_samples: 316 | raise RuntimeError( 317 | 'Currently, number of dataset samples must be specified for training dataset. ' 318 | 'Please specify via `--train-num-samples` if no dataset length info present.') 319 | else: 320 | num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified 321 | 322 | shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc 323 | if resampled: 324 | pipeline = [ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch)] 325 | else: 326 | pipeline = [wds.SimpleShardList(input_shards)] 327 | 328 | # at this point we have an iterator over all the shards 329 | if is_train: 330 | if not resampled: 331 | pipeline.extend([ 332 | detshuffle2( 333 | bufsize=_SHARD_SHUFFLE_SIZE, 334 | initial=_SHARD_SHUFFLE_INITIAL, 335 | seed=args.seed, 336 | epoch=shared_epoch, 337 | ), 338 | wds.split_by_node, 339 | wds.split_by_worker, 340 | ]) 341 | pipeline.extend([ 342 | # at this point, we have an iterator over the shards assigned to each worker at each node 343 | tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), 344 | wds.shuffle( 345 | bufsize=_SAMPLE_SHUFFLE_SIZE, 346 | initial=_SAMPLE_SHUFFLE_INITIAL, 347 | ), 348 | ]) 349 | else: 350 | pipeline.extend([ 351 | wds.split_by_worker, 352 | # at this point, we have an iterator over the shards assigned to each worker 353 | wds.tarfile_to_samples(handler=log_and_continue), 354 | ]) 355 | pipeline.extend([ 356 | wds.select(filter_no_caption), 357 | wds.decode("pilrgb", handler=log_and_continue), 358 | wds.rename(image="jpg;png", text="txt"), 359 | wds.map_dict(image=preprocess_img, text=preprocess_txt), 360 | wds.to_tuple("image", "text"), 361 | wds.batched(args.batch_size, partial=not is_train), 362 | ]) 363 | 364 | dataset = wds.DataPipeline(*pipeline) 365 | if is_train: 366 | if not resampled: 367 | assert num_shards >= args.workers * args.world_size, 'number of shards must be >= 
total workers' 368 | # roll over and repeat a few samples to get same number of full batches on each node 369 | round_fn = math.floor if floor else math.ceil 370 | global_batch_size = args.batch_size * args.world_size 371 | num_batches = round_fn(num_samples / global_batch_size) 372 | num_workers = max(1, args.workers) 373 | num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker 374 | num_batches = num_worker_batches * num_workers 375 | num_samples = num_batches * global_batch_size 376 | dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this 377 | else: 378 | # last batches are partial, eval is done on single (master) node 379 | num_batches = math.ceil(num_samples / args.batch_size) 380 | 381 | dataloader = wds.WebLoader( 382 | dataset, 383 | batch_size=None, 384 | shuffle=False, 385 | num_workers=args.workers, 386 | persistent_workers=True, 387 | ) 388 | 389 | # FIXME not clear which approach is better, with_epoch before vs after dataloader? 390 | # hoping to resolve via https://github.com/webdataset/webdataset/issues/169 391 | # if is_train: 392 | # # roll over and repeat a few samples to get same number of full batches on each node 393 | # global_batch_size = args.batch_size * args.world_size 394 | # num_batches = math.ceil(num_samples / global_batch_size) 395 | # num_workers = max(1, args.workers) 396 | # num_batches = math.ceil(num_batches / num_workers) * num_workers 397 | # num_samples = num_batches * global_batch_size 398 | # dataloader = dataloader.with_epoch(num_batches) 399 | # else: 400 | # # last batches are partial, eval is done on single (master) node 401 | # num_batches = math.ceil(num_samples / args.batch_size) 402 | 403 | # add meta-data to dataloader instance for convenience 404 | dataloader.num_batches = num_batches 405 | dataloader.num_samples = num_samples 406 | 407 | return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) 408 | 409 | 410 | def get_csv_dataset(args, preprocess_fn, is_train, epoch=0): 411 | input_filename = args.train_data if is_train else args.val_data 412 | assert input_filename 413 | dataset = CsvDataset( 414 | input_filename, 415 | preprocess_fn, 416 | img_key=args.csv_img_key, 417 | caption_key=args.csv_caption_key, 418 | hard_captions_key=args.csv_hard_captions_key, 419 | sep=args.csv_separator) 420 | num_samples = len(dataset) 421 | sampler = DistributedSampler(dataset) if args.distributed and is_train else None 422 | shuffle = is_train and sampler is None 423 | 424 | dataloader = DataLoader( 425 | dataset, 426 | batch_size=args.batch_size, 427 | shuffle=shuffle, 428 | num_workers=args.workers, 429 | pin_memory=True, 430 | sampler=sampler, 431 | drop_last=is_train, 432 | ) 433 | dataloader.num_samples = num_samples 434 | dataloader.num_batches = len(dataloader) 435 | 436 | return DataInfo(dataloader, sampler) 437 | 438 | 439 | def get_dataset_fn(data_path, dataset_type): 440 | if dataset_type == "webdataset": 441 | return get_wds_dataset 442 | elif dataset_type == "csv": 443 | return get_csv_dataset 444 | elif dataset_type == "auto": 445 | ext = data_path.split('.')[-1] 446 | if ext in ['csv', 'tsv']: 447 | return get_csv_dataset 448 | elif ext in ['tar']: 449 | return get_wds_dataset 450 | else: 451 | raise ValueError( 452 | f"Tried to figure out dataset type, but failed for extention {ext}.") 453 | else: 454 | raise ValueError(f"Unsupported dataset type: {dataset_type}") 455 | 456 | 457 | def get_data(args, preprocess_fns, epoch=0): 458 | preprocess_train, preprocess_val = 
preprocess_fns 459 | data = {} 460 | 461 | if args.train_data: 462 | data["train"] = get_dataset_fn(args.train_data, args.dataset_type)( 463 | args, preprocess_train, is_train=True, epoch=epoch) 464 | 465 | if args.val_data: 466 | data["val"] = get_dataset_fn(args.val_data, args.dataset_type)( 467 | args, preprocess_val, is_train=False) 468 | 469 | if args.imagenet_val is not None: 470 | data["imagenet-val"] = get_imagenet(args, preprocess_fns, "val") 471 | 472 | if args.imagenet_v2 is not None: 473 | data["imagenet-v2"] = get_imagenet(args, preprocess_fns, "v2") 474 | 475 | return data 476 | -------------------------------------------------------------------------------- /src/training/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | try: 6 | import horovod.torch as hvd 7 | except ImportError: 8 | hvd = None 9 | 10 | 11 | def is_global_master(args): 12 | return args.rank == 0 13 | 14 | 15 | def is_local_master(args): 16 | return args.local_rank == 0 17 | 18 | 19 | def is_master(args, local=False): 20 | return is_local_master(args) if local else is_global_master(args) 21 | 22 | 23 | def is_using_horovod(): 24 | # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set 25 | # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... 26 | ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] 27 | pmi_vars = ["PMI_RANK", "PMI_SIZE"] 28 | if all([var in os.environ for var in ompi_vars]) or all([var in os.environ for var in pmi_vars]): 29 | return True 30 | else: 31 | return False 32 | 33 | 34 | def is_using_distributed(): 35 | if 'WORLD_SIZE' in os.environ: 36 | return int(os.environ['WORLD_SIZE']) > 1 37 | if 'SLURM_NTASKS' in os.environ: 38 | return int(os.environ['SLURM_NTASKS']) > 1 39 | return False 40 | 41 | 42 | def world_info_from_env(): 43 | local_rank = 0 44 | for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'): 45 | if v in os.environ: 46 | local_rank = int(os.environ[v]) 47 | break 48 | global_rank = 0 49 | for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'): 50 | if v in os.environ: 51 | global_rank = int(os.environ[v]) 52 | break 53 | world_size = 1 54 | for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'): 55 | if v in os.environ: 56 | world_size = int(os.environ[v]) 57 | break 58 | 59 | return local_rank, global_rank, world_size 60 | 61 | 62 | def init_distributed_device(args): 63 | # Distributed training = training on more than one GPU. 64 | # Works in both single and multi-node scenarios. 
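    # Rank / world size are resolved in this order: explicit Horovod (args.horovod),
    # then SLURM env vars (SLURM_PROCID etc.), then the variables set by torchrun /
    # torch.distributed.launch (LOCAL_RANK, RANK, WORLD_SIZE).
    # A minimal single-node launch sketch (data path and flag values are illustrative only):
    #   cd src && torchrun --nproc_per_node 4 -m training.main --train-data train.csv --dataset-type csv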
65 | args.distributed = False 66 | args.world_size = 1 67 | args.rank = 0 # global rank 68 | args.local_rank = 0 69 | if args.horovod: 70 | assert hvd is not None, "Horovod is not installed" 71 | hvd.init() 72 | args.local_rank = int(hvd.local_rank()) 73 | args.rank = hvd.rank() 74 | args.world_size = hvd.size() 75 | args.distributed = True 76 | os.environ['LOCAL_RANK'] = str(args.local_rank) 77 | os.environ['RANK'] = str(args.rank) 78 | os.environ['WORLD_SIZE'] = str(args.world_size) 79 | elif is_using_distributed(): 80 | if 'SLURM_PROCID' in os.environ: 81 | # DDP via SLURM 82 | args.local_rank, args.rank, args.world_size = world_info_from_env() 83 | # SLURM var -> torch.distributed vars in case needed 84 | os.environ['LOCAL_RANK'] = str(args.local_rank) 85 | os.environ['RANK'] = str(args.rank) 86 | os.environ['WORLD_SIZE'] = str(args.world_size) 87 | torch.distributed.init_process_group( 88 | backend=args.dist_backend, 89 | init_method=args.dist_url, 90 | world_size=args.world_size, 91 | rank=args.rank, 92 | ) 93 | else: 94 | # DDP via torchrun, torch.distributed.launch 95 | args.local_rank, _, _ = world_info_from_env() 96 | torch.distributed.init_process_group( 97 | backend=args.dist_backend, 98 | init_method=args.dist_url) 99 | args.world_size = torch.distributed.get_world_size() 100 | args.rank = torch.distributed.get_rank() 101 | args.distributed = True 102 | 103 | if torch.cuda.is_available(): 104 | if args.distributed and not args.no_set_device_rank: 105 | device = 'cuda:%d' % args.local_rank 106 | else: 107 | device = 'cuda:0' 108 | torch.cuda.set_device(device) 109 | else: 110 | device = 'cpu' 111 | args.device = device 112 | device = torch.device(device) 113 | return device 114 | -------------------------------------------------------------------------------- /src/training/imagenet_zeroshot_data.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | imagenet_classnames = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", 4 | "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", 5 | "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", 6 | "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", 7 | "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", 8 | "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", 9 | "box turtle", "banded gecko", "green iguana", "Carolina anole", 10 | "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", 11 | "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", 12 | "American alligator", "triceratops", "worm snake", "ring-necked snake", 13 | "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", 14 | "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", 15 | "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", 16 | "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", 17 | "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", 18 | "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", 19 | "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", 20 | "coucal", "bee 
eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", 21 | "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", 22 | "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", 23 | "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", 24 | "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", 25 | "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", 26 | "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", 27 | "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", 28 | "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", 29 | "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", 30 | "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", 31 | "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", 32 | "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", 33 | "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", 34 | "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", 35 | "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", 36 | "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", 37 | "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", 38 | "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", 39 | "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", 40 | "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", 41 | "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", 42 | "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", 43 | "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", 44 | "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", 45 | "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", 46 | "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", 47 | "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", 48 | "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", 49 | "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", 50 | "French Bulldog", "Great Dane", "St. 
Bernard", "husky", "Alaskan Malamute", "Siberian Husky", 51 | "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", 52 | "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", 53 | "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", 54 | "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", 55 | "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", 56 | "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", 57 | "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", 58 | "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", 59 | "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", 60 | "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", 61 | "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", 62 | "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", 63 | "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", 64 | "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", 65 | "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", 66 | "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", 67 | "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", 68 | "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", 69 | "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", 70 | "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", 71 | "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", 72 | "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", 73 | "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", 74 | "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", 75 | "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", 76 | "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", 77 | "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", 78 | "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", 79 | "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", 80 | "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", 81 | "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", 82 | "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", 83 | "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", 84 | "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", 85 | "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", 86 | "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", 87 | "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", 88 | "car 
wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", 89 | "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", 90 | "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", 91 | "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", 92 | "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", 93 | "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", 94 | "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", 95 | "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", 96 | "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", 97 | "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", 98 | "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", 99 | "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", 100 | "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", 101 | "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", 102 | "freight car", "French horn", "frying pan", "fur coat", "garbage truck", 103 | "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", 104 | "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", 105 | "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", 106 | "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", 107 | "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", 108 | "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", 109 | "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", 110 | "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", 111 | "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", 112 | "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", 113 | "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", 114 | "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", 115 | "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", 116 | "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", 117 | "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", 118 | "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", 119 | "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", 120 | "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", 121 | "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", 122 | "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", 123 | "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", 124 | "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", 
125 | "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", 126 | "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", 127 | "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", 128 | "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", 129 | "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", 130 | "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", 131 | "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", 132 | "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", 133 | "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", 134 | "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", 135 | "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", 136 | "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", 137 | "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", 138 | "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", 139 | "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", 140 | "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", 141 | "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", 142 | "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", 143 | "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", 144 | "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", 145 | "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", 146 | "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", 147 | "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", 148 | "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", 149 | "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", 150 | "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", 151 | "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", 152 | "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", 153 | "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", 154 | "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", 155 | "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", 156 | "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", 157 | "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", 158 | "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", 159 | "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", 160 | "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", 161 | "lemon", "fig", "pineapple", "banana", "jackfruit", 
"cherimoya (custard apple)", "pomegranate", 162 | "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", 163 | "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", 164 | "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", 165 | "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", 166 | "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", 167 | "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"] 168 | 169 | 170 | 171 | 172 | 173 | openai_imagenet_template = [ 174 | lambda c: f'a bad photo of a {c}.', 175 | lambda c: f'a photo of many {c}.', 176 | lambda c: f'a sculpture of a {c}.', 177 | lambda c: f'a photo of the hard to see {c}.', 178 | lambda c: f'a low resolution photo of the {c}.', 179 | lambda c: f'a rendering of a {c}.', 180 | lambda c: f'graffiti of a {c}.', 181 | lambda c: f'a bad photo of the {c}.', 182 | lambda c: f'a cropped photo of the {c}.', 183 | lambda c: f'a tattoo of a {c}.', 184 | lambda c: f'the embroidered {c}.', 185 | lambda c: f'a photo of a hard to see {c}.', 186 | lambda c: f'a bright photo of a {c}.', 187 | lambda c: f'a photo of a clean {c}.', 188 | lambda c: f'a photo of a dirty {c}.', 189 | lambda c: f'a dark photo of the {c}.', 190 | lambda c: f'a drawing of a {c}.', 191 | lambda c: f'a photo of my {c}.', 192 | lambda c: f'the plastic {c}.', 193 | lambda c: f'a photo of the cool {c}.', 194 | lambda c: f'a close-up photo of a {c}.', 195 | lambda c: f'a black and white photo of the {c}.', 196 | lambda c: f'a painting of the {c}.', 197 | lambda c: f'a painting of a {c}.', 198 | lambda c: f'a pixelated photo of the {c}.', 199 | lambda c: f'a sculpture of the {c}.', 200 | lambda c: f'a bright photo of the {c}.', 201 | lambda c: f'a cropped photo of a {c}.', 202 | lambda c: f'a plastic {c}.', 203 | lambda c: f'a photo of the dirty {c}.', 204 | lambda c: f'a jpeg corrupted photo of a {c}.', 205 | lambda c: f'a blurry photo of the {c}.', 206 | lambda c: f'a photo of the {c}.', 207 | lambda c: f'a good photo of the {c}.', 208 | lambda c: f'a rendering of the {c}.', 209 | lambda c: f'a {c} in a video game.', 210 | lambda c: f'a photo of one {c}.', 211 | lambda c: f'a doodle of a {c}.', 212 | lambda c: f'a close-up photo of the {c}.', 213 | lambda c: f'a photo of a {c}.', 214 | lambda c: f'the origami {c}.', 215 | lambda c: f'the {c} in a video game.', 216 | lambda c: f'a sketch of a {c}.', 217 | lambda c: f'a doodle of the {c}.', 218 | lambda c: f'a origami {c}.', 219 | lambda c: f'a low resolution photo of a {c}.', 220 | lambda c: f'the toy {c}.', 221 | lambda c: f'a rendition of the {c}.', 222 | lambda c: f'a photo of the clean {c}.', 223 | lambda c: f'a photo of a large {c}.', 224 | lambda c: f'a rendition of a {c}.', 225 | lambda c: f'a photo of a nice {c}.', 226 | lambda c: f'a photo of a weird {c}.', 227 | lambda c: f'a blurry photo of a {c}.', 228 | lambda c: f'a cartoon {c}.', 229 | lambda c: f'art of a {c}.', 230 | lambda c: f'a sketch of the {c}.', 231 | lambda c: f'a embroidered {c}.', 232 | lambda c: f'a pixelated photo of a {c}.', 233 | lambda c: f'itap of the {c}.', 234 | lambda c: f'a jpeg corrupted photo of the {c}.', 235 | lambda c: f'a good photo of a {c}.', 236 | lambda c: f'a plushie {c}.', 237 | lambda c: f'a photo of the nice {c}.', 238 | lambda c: f'a photo of the small {c}.', 239 | lambda c: f'a photo of the weird 
{c}.', 240 | lambda c: f'the cartoon {c}.', 241 | lambda c: f'art of the {c}.', 242 | lambda c: f'a drawing of the {c}.', 243 | lambda c: f'a photo of the large {c}.', 244 | lambda c: f'a black and white photo of a {c}.', 245 | lambda c: f'the plushie {c}.', 246 | lambda c: f'a dark photo of a {c}.', 247 | lambda c: f'itap of a {c}.', 248 | lambda c: f'graffiti of the {c}.', 249 | lambda c: f'a toy {c}.', 250 | lambda c: f'itap of my {c}.', 251 | lambda c: f'a photo of a cool {c}.', 252 | lambda c: f'a photo of a small {c}.', 253 | lambda c: f'a tattoo of the {c}.', 254 | ] 255 | -------------------------------------------------------------------------------- /src/training/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def setup_logging(log_file, level, include_host=False): 5 | if include_host: 6 | import socket 7 | hostname = socket.gethostname() 8 | formatter = logging.Formatter( 9 | f'%(asctime)s | {hostname} | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') 10 | else: 11 | formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') 12 | 13 | logging.root.setLevel(level) 14 | loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict] 15 | for logger in loggers: 16 | logger.setLevel(level) 17 | 18 | stream_handler = logging.StreamHandler() 19 | stream_handler.setFormatter(formatter) 20 | logging.root.addHandler(stream_handler) 21 | 22 | if log_file: 23 | file_handler = logging.FileHandler(filename=log_file) 24 | file_handler.setFormatter(formatter) 25 | logging.root.addHandler(file_handler) 26 | 27 | -------------------------------------------------------------------------------- /src/training/main.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | from datetime import datetime 5 | 6 | import numpy as np 7 | import torch 8 | from torch import optim 9 | from torch.cuda.amp import GradScaler 10 | 11 | try: 12 | import wandb 13 | except ImportError: 14 | wandb = None 15 | 16 | try: 17 | import torch.utils.tensorboard as tensorboard 18 | except ImportError: 19 | tensorboard = None 20 | 21 | try: 22 | import horovod.torch as hvd 23 | except ImportError: 24 | hvd = None 25 | 26 | from open_clip import create_model_and_transforms, trace_model 27 | from training.data import get_data 28 | from training.distributed import is_master, init_distributed_device, world_info_from_env 29 | from training.logger import setup_logging 30 | from training.params import parse_args 31 | from training.scheduler import cosine_lr 32 | from training.train import train_one_epoch, evaluate 33 | 34 | 35 | def random_seed(seed=42, rank=0): 36 | torch.manual_seed(seed + rank) 37 | np.random.seed(seed + rank) 38 | random.seed(seed + rank) 39 | 40 | 41 | def main(): 42 | args = parse_args() 43 | 44 | # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule? 
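    # e.g. a model name such as 'ViT-B/32' becomes 'ViT-B-32' before being embedded in log and checkpoint paths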
45 | args.model = args.model.replace('/', '-') 46 | 47 | # get the name of the experiments 48 | if args.name is None: 49 | args.name = '-'.join([ 50 | datetime.now().strftime("%Y_%m_%d-%H_%M_%S"), 51 | f"model_{args.model}", 52 | f"lr_{args.lr}", 53 | f"b_{args.batch_size}", 54 | f"j_{args.workers}", 55 | f"p_{args.precision}", 56 | ]) 57 | 58 | # discover initial world args early so we can log properly 59 | args.distributed = False 60 | args.local_rank, args.rank, args.world_size = world_info_from_env() 61 | 62 | args.log_path = None 63 | if is_master(args, local=args.log_local): 64 | log_base_path = os.path.join(args.logs, args.name) 65 | os.makedirs(log_base_path, exist_ok=True) 66 | log_filename = f'out-{args.rank}' if args.log_local else 'out.log' 67 | args.log_path = os.path.join(log_base_path, log_filename) 68 | if os.path.exists(args.log_path): 69 | print( 70 | "Error. Experiment already exists. Use --name {} to specify a new experiment." 71 | ) 72 | return -1 73 | 74 | # Set logger 75 | args.log_level = logging.DEBUG if args.debug else logging.INFO 76 | setup_logging(args.log_path, args.log_level) 77 | 78 | # fully initialize distributed device environment 79 | torch.backends.cudnn.benchmark = True 80 | torch.backends.cudnn.deterministic = False 81 | device = init_distributed_device(args) 82 | 83 | args.wandb = 'wandb' in args.report_to or 'all' in args.report_to 84 | args.tensorboard = 'tensorboard' in args.report_to or 'all' in args.report_to 85 | if is_master(args): 86 | args.tensorboard_path = os.path.join(args.logs, args.name, "tensorboard") if args.tensorboard else '' 87 | args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints") 88 | for dirname in [args.tensorboard_path, args.checkpoint_path]: 89 | if dirname: 90 | os.makedirs(dirname, exist_ok=True) 91 | else: 92 | args.tensorboard_path = '' 93 | args.checkpoint_path = '' 94 | 95 | if args.copy_codebase: 96 | copy_codebase(args) 97 | 98 | assert args.precision in ['amp', 'fp16', 'fp32'] 99 | if args.precision == 'fp16': 100 | logging.warning( 101 | 'It is recommended to use AMP mixed-precision instead of FP16. ' 102 | 'FP16 support needs further verification and tuning, especially for train.') 103 | 104 | if args.horovod: 105 | logging.info( 106 | f'Running in horovod mode with multiple processes / nodes. Device: {args.device}.' 107 | f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.') 108 | elif args.distributed: 109 | logging.info( 110 | f'Running in distributed mode with multiple processes. Device: {args.device}.' 111 | f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.') 112 | else: 113 | logging.info(f'Running with a single process. 
Device {args.device}.') 114 | 115 | random_seed(args.seed, 0) 116 | model, preprocess_train, preprocess_val = create_model_and_transforms( 117 | args.model, 118 | args.pretrained, 119 | precision=args.precision, 120 | device=device, 121 | jit=args.torchscript, 122 | force_quick_gelu=args.force_quick_gelu, 123 | pretrained_image=args.pretrained_image, 124 | ) 125 | random_seed(args.seed, args.rank) 126 | 127 | if args.trace: 128 | model = trace_model(model, batch_size=args.batch_size, device=device) 129 | 130 | if args.lock_image: 131 | # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 132 | model.lock_image_tower( 133 | unlocked_groups=args.lock_image_unlocked_groups, 134 | freeze_bn_stats=args.lock_image_freeze_bn_stats) 135 | 136 | if args.grad_checkpointing: 137 | model.set_grad_checkpointing() 138 | 139 | if is_master(args): 140 | logging.info("Model:") 141 | logging.info(f"{str(model)}") 142 | logging.info("Params:") 143 | params_file = os.path.join(args.logs, args.name, "params.txt") 144 | with open(params_file, "w") as f: 145 | for name in sorted(vars(args)): 146 | val = getattr(args, name) 147 | logging.info(f" {name}: {val}") 148 | f.write(f"{name}: {val}\n") 149 | 150 | if args.distributed and not args.horovod: 151 | if args.use_bn_sync: 152 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 153 | ddp_args = {} 154 | if args.ddp_static_graph: 155 | # this doesn't exist in older PyTorch, arg only added if enabled 156 | ddp_args['static_graph'] = True 157 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args) 158 | 159 | # create optimizer and scaler 160 | optimizer = None 161 | scaler = None 162 | if args.train_data: 163 | assert not args.trace, 'Cannot train with traced model' 164 | 165 | exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n 166 | include = lambda n, p: not exclude(n, p) 167 | 168 | named_parameters = list(model.named_parameters()) 169 | gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad] 170 | rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad] 171 | 172 | optimizer = optim.AdamW( 173 | [ 174 | {"params": gain_or_bias_params, "weight_decay": 0.}, 175 | {"params": rest_params, "weight_decay": args.wd}, 176 | ], 177 | lr=args.lr, 178 | betas=(args.beta1, args.beta2), 179 | eps=args.eps, 180 | ) 181 | if args.horovod: 182 | optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters()) 183 | hvd.broadcast_parameters(model.state_dict(), root_rank=0) 184 | hvd.broadcast_optimizer_state(optimizer, root_rank=0) 185 | 186 | scaler = GradScaler() if args.precision == "amp" else None 187 | 188 | # optionally resume from a checkpoint 189 | start_epoch = 0 190 | if args.resume is not None: 191 | if os.path.isfile(args.resume): 192 | checkpoint = torch.load(args.resume, map_location=device) 193 | if 'epoch' in checkpoint: 194 | # resuming a train checkpoint w/ epoch and optimizer state 195 | start_epoch = checkpoint["epoch"] 196 | sd = checkpoint["state_dict"] 197 | if not args.distributed and next(iter(sd.items()))[0].startswith('module'): 198 | sd = {k[len('module.'):]: v for k, v in sd.items()} 199 | model.load_state_dict(sd) 200 | if optimizer is not None: 201 | optimizer.load_state_dict(checkpoint["optimizer"]) 202 | if scaler is not None and 'scaler' in checkpoint: 203 | scaler.load_state_dict(checkpoint['scaler']) 204 | logging.info(f"=> resuming checkpoint 
'{args.resume}' (epoch {start_epoch})") 205 | else: 206 | # loading a bare (model only) checkpoint for fine-tune or evaluation 207 | model.load_state_dict(checkpoint) 208 | logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})") 209 | else: 210 | logging.info("=> no checkpoint found at '{}'".format(args.resume)) 211 | 212 | # initialize datasets 213 | data = get_data(args, (preprocess_train, preprocess_val), epoch=start_epoch) 214 | assert len(data), 'At least one train or eval dataset must be specified.' 215 | 216 | # create scheduler if train 217 | scheduler = None 218 | if 'train' in data and optimizer is not None: 219 | total_steps = data["train"].dataloader.num_batches * args.epochs 220 | scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps) 221 | 222 | # determine if this worker should save logs and checkpoints. only do so if it is rank == 0 223 | args.save_logs = args.logs and args.logs.lower() != 'none' and is_master(args) 224 | writer = None 225 | if args.save_logs and args.tensorboard: 226 | assert tensorboard is not None, "Please install tensorboard." 227 | writer = tensorboard.SummaryWriter(args.tensorboard_path) 228 | 229 | if args.wandb and is_master(args): 230 | assert wandb is not None, 'Please install wandb.' 231 | logging.debug('Starting wandb.') 232 | args.train_sz = data["train"].dataloader.num_samples 233 | if args.val_data is not None: 234 | args.val_sz = data["val"].dataloader.num_samples 235 | # you will have to configure this for your project! 236 | wandb.init( 237 | project="open-clip", 238 | notes=args.wandb_notes, 239 | tags=[], 240 | config=vars(args), 241 | ) 242 | if args.debug: 243 | wandb.watch(model, log='all') 244 | wandb.save(params_file) 245 | logging.debug('Finished loading wandb.') 246 | 247 | if 'train' not in data: 248 | evaluate(model, data, start_epoch, args, writer) 249 | return 250 | 251 | for epoch in range(start_epoch, args.epochs): 252 | if is_master(args): 253 | logging.info(f'Start epoch {epoch}') 254 | 255 | train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer) 256 | completed_epoch = epoch + 1 257 | 258 | if any(v in data for v in ('val', 'imagenet-val', 'imagenet-v2')): 259 | evaluate(model, data, completed_epoch, args, writer) 260 | 261 | # Saving checkpoints. 262 | if args.save_logs: 263 | checkpoint_dict = { 264 | "epoch": completed_epoch, 265 | "name": args.name, 266 | "state_dict": model.state_dict(), 267 | "optimizer": optimizer.state_dict(), 268 | } 269 | if scaler is not None: 270 | checkpoint_dict["scaler"] = scaler.state_dict() 271 | 272 | if completed_epoch == args.epochs or ( 273 | args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0 274 | ): 275 | torch.save( 276 | checkpoint_dict, 277 | os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"), 278 | ) 279 | if args.save_most_recent: 280 | torch.save( 281 | checkpoint_dict, 282 | os.path.join(args.checkpoint_path, f"epoch_latest.pt"), 283 | ) 284 | 285 | if args.wandb and is_master(args): 286 | wandb.finish() 287 | 288 | 289 | def copy_codebase(args): 290 | from shutil import copytree, ignore_patterns 291 | new_code_path = os.path.join(args.logs, args.name, "code") 292 | if os.path.exists(new_code_path): 293 | print( 294 | f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment." 
295 | ) 296 | return -1 297 | print(f"Copying codebase to {new_code_path}") 298 | current_code_path = os.path.realpath(__file__) 299 | for _ in range(3): 300 | current_code_path = os.path.dirname(current_code_path) 301 | copytree(current_code_path, new_code_path, ignore=ignore_patterns('log', 'logs', 'wandb')) 302 | print("Done copying code.") 303 | return 1 304 | 305 | 306 | if __name__ == "__main__": 307 | main() 308 | -------------------------------------------------------------------------------- /src/training/params.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_default_params(model_name): 5 | # Params from paper (https://arxiv.org/pdf/2103.00020.pdf) 6 | model_name = model_name.lower() 7 | if "vit" in model_name: 8 | return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6} 9 | else: 10 | return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8} 11 | 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument( 16 | "--train-data", 17 | type=str, 18 | default=None, 19 | help="Path to csv file with training data", 20 | ) 21 | parser.add_argument( 22 | "--val-data", 23 | type=str, 24 | default=None, 25 | help="Path to csv file with validation data", 26 | ) 27 | parser.add_argument( 28 | "--train-num-samples", 29 | type=int, 30 | default=None, 31 | help="Number of samples in dataset. Required for webdataset if not available in info file.", 32 | ) 33 | parser.add_argument( 34 | "--val-num-samples", 35 | type=int, 36 | default=None, 37 | help="Number of samples in dataset. Useful for webdataset if not available in info file.", 38 | ) 39 | parser.add_argument( 40 | "--dataset-type", 41 | choices=["webdataset", "csv", "auto"], 42 | default="auto", 43 | help="Which type of dataset to process." 44 | ) 45 | parser.add_argument( 46 | "--dataset-resampled", 47 | default=False, 48 | action="store_true", 49 | help="Whether to use sampling with replacement for webdataset shard selection." 50 | ) 51 | parser.add_argument( 52 | "--csv-separator", 53 | type=str, 54 | default="\t", 55 | help="For csv-like datasets, which separator to use." 56 | ) 57 | parser.add_argument( 58 | "--csv-img-key", 59 | type=str, 60 | default="filepath", 61 | help="For csv-like datasets, the name of the key for the image paths." 62 | ) 63 | parser.add_argument( 64 | "--csv-hard-captions-key", 65 | type=str, 66 | default="neg_caption", 67 | help="For csv-like datasets, the name of the key for the hard captions." 68 | ) 69 | parser.add_argument( 70 | "--csv-caption-key", 71 | type=str, 72 | default="title", 73 | help="For csv-like datasets, the name of the key for the captions." 74 | ) 75 | parser.add_argument( 76 | "--imagenet-val", 77 | type=str, 78 | default=None, 79 | help="Path to imagenet val set for conducting zero shot evaluation.", 80 | ) 81 | parser.add_argument( 82 | "--imagenet-v2", 83 | type=str, 84 | default=None, 85 | help="Path to imagenet v2 for conducting zero shot evaluation.", 86 | ) 87 | parser.add_argument( 88 | "--logs", 89 | type=str, 90 | default="./logs/", 91 | help="Where to store tensorboard logs. Use None to avoid storing logs.", 92 | ) 93 | parser.add_argument( 94 | "--log-local", 95 | action="store_true", 96 | default=False, 97 | help="log files on local master, otherwise global master only.", 98 | ) 99 | parser.add_argument( 100 | "--name", 101 | type=str, 102 | default=None, 103 | help="Optional identifier for the experiment when storing logs.
Otherwise use current time.", 104 | ) 105 | parser.add_argument( 106 | "--workers", type=int, default=1, help="Number of dataloader workers per GPU." 107 | ) 108 | parser.add_argument( 109 | "--batch-size", type=int, default=64, help="Batch size per GPU." 110 | ) 111 | parser.add_argument( 112 | "--epochs", type=int, default=32, help="Number of epochs to train for." 113 | ) 114 | parser.add_argument("--lr", type=float, default=None, help="Learning rate.") 115 | parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.") 116 | parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.") 117 | parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.") 118 | parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.") 119 | parser.add_argument( 120 | "--warmup", type=int, default=10000, help="Number of steps to warmup for." 121 | ) 122 | parser.add_argument( 123 | "--use-bn-sync", 124 | default=False, 125 | action="store_true", 126 | help="Whether to use batch norm sync.") 127 | parser.add_argument( 128 | "--skip-scheduler", 129 | action="store_true", 130 | default=False, 131 | help="Use this flag to skip the learning rate decay.", 132 | ) 133 | parser.add_argument( 134 | "--save-frequency", type=int, default=1, help="How often to save checkpoints." 135 | ) 136 | parser.add_argument( 137 | "--save-most-recent", 138 | action="store_true", 139 | default=False, 140 | help="Always save the most recent model trained to epoch_latest.pt.", 141 | ) 142 | parser.add_argument( 143 | "--zeroshot-frequency", type=int, default=2, help="How often to run zero shot." 144 | ) 145 | parser.add_argument( 146 | "--val-frequency", type=int, default=1, help="How often to run evaluation with val data." 147 | ) 148 | parser.add_argument( 149 | "--resume", 150 | default=None, 151 | type=str, 152 | help="path to latest checkpoint (default: none)", 153 | ) 154 | parser.add_argument( 155 | "--precision", 156 | choices=["amp", "fp16", "fp32"], 157 | default="amp", 158 | help="Floating point precision." 
159 | ) 160 | parser.add_argument( 161 | "--model", 162 | type=str, 163 | default="RN50", 164 | help="Name of the vision backbone to use.", 165 | ) 166 | parser.add_argument( 167 | "--pretrained", 168 | default='', 169 | type=str, 170 | help="Use pretrained CLIP model weights with the specified tag or file path.", 171 | ) 172 | parser.add_argument( 173 | "--pretrained-image", 174 | default=False, 175 | action='store_true', 176 | help="Load imagenet pretrained weights for image tower backbone if available.", 177 | ) 178 | parser.add_argument( 179 | "--lock-image", 180 | default=False, 181 | action='store_true', 182 | help="Lock full image tower by disabling gradients.", 183 | ) 184 | parser.add_argument( 185 | "--lock-image-unlocked-groups", 186 | type=int, 187 | default=0, 188 | help="Leave last n image tower layer groups unlocked.", 189 | ) 190 | parser.add_argument( 191 | "--lock-image-freeze-bn-stats", 192 | default=False, 193 | action='store_true', 194 | help="Freeze BatchNorm running stats in image tower for any locked layers.", 195 | ) 196 | parser.add_argument( 197 | "--grad-checkpointing", 198 | default=False, 199 | action='store_true', 200 | help="Enable gradient checkpointing.", 201 | ) 202 | parser.add_argument( 203 | "--local-loss", 204 | default=False, 205 | action="store_true", 206 | help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)" 207 | ) 208 | parser.add_argument( 209 | "--gather-with-grad", 210 | default=False, 211 | action="store_true", 212 | help="enable full distributed gradient for feature gather" 213 | ) 214 | parser.add_argument( 215 | "--force-quick-gelu", 216 | default=False, 217 | action='store_true', 218 | help="Force use of QuickGELU activation for non-OpenAI transformer models.", 219 | ) 220 | parser.add_argument( 221 | "--torchscript", 222 | default=False, 223 | action='store_true', 224 | help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'", 225 | ) 226 | parser.add_argument( 227 | "--trace", 228 | default=False, 229 | action='store_true', 230 | help="torch.jit.trace the model for inference / eval only", 231 | ) 232 | # arguments for distributed training 233 | parser.add_argument( 234 | "--dist-url", 235 | default="env://", 236 | type=str, 237 | help="url used to set up distributed training", 238 | ) 239 | parser.add_argument( 240 | "--dist-backend", default="nccl", type=str, help="distributed backend" 241 | ) 242 | parser.add_argument( 243 | "--report-to", 244 | default='', 245 | type=str, 246 | help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']" 247 | ) 248 | parser.add_argument( 249 | "--wandb-notes", 250 | default='', 251 | type=str, 252 | help="Notes if logging with wandb" 253 | ) 254 | parser.add_argument( 255 | "--debug", 256 | default=False, 257 | action="store_true", 258 | help="If true, more information is logged." 259 | ) 260 | parser.add_argument( 261 | "--copy-codebase", 262 | default=False, 263 | action="store_true", 264 | help="If true, copy the entire codebase to the log directory and execute from there." 265 | ) 266 | parser.add_argument( 267 | "--horovod", 268 | default=False, 269 | action="store_true", 270 | help="Use horovod for distributed training."
271 | ) 272 | parser.add_argument( 273 | "--ddp-static-graph", 274 | default=False, 275 | action='store_true', 276 | help="Enable static graph optimization for DDP in PyTorch >= 1.11.", 277 | ) 278 | parser.add_argument( 279 | "--no-set-device-rank", 280 | default=False, 281 | action="store_true", 282 | help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc)." 283 | ) 284 | parser.add_argument( 285 | "--seed", type=int, default=0, help="Default random seed." 286 | ) 287 | parser.add_argument( 288 | "--norm_gradient_clip", type=float, default=None, help="Gradient clip." 289 | ) 290 | args = parser.parse_args() 291 | 292 | # If some params are not passed, we use the default values based on model name. 293 | default_params = get_default_params(args.model) 294 | for name, val in default_params.items(): 295 | if getattr(args, name) is None: 296 | setattr(args, name, val) 297 | 298 | return args 299 | -------------------------------------------------------------------------------- /src/training/scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def assign_learning_rate(optimizer, new_lr): 5 | for param_group in optimizer.param_groups: 6 | param_group["lr"] = new_lr 7 | 8 | 9 | def _warmup_lr(base_lr, warmup_length, step): 10 | return base_lr * (step + 1) / warmup_length 11 | 12 | 13 | def cosine_lr(optimizer, base_lr, warmup_length, steps): 14 | def _lr_adjuster(step): 15 | if step < warmup_length: 16 | lr = _warmup_lr(base_lr, warmup_length, step) 17 | else: 18 | e = step - warmup_length 19 | es = steps - warmup_length 20 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr 21 | assign_learning_rate(optimizer, lr) 22 | return lr 23 | return _lr_adjuster -------------------------------------------------------------------------------- /src/training/train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import math 4 | import os 5 | import time 6 | from contextlib import suppress 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | try: 13 | import wandb 14 | except ImportError: 15 | wandb = None 16 | 17 | from open_clip import ClipLoss 18 | from .distributed import is_master 19 | from .zero_shot import zero_shot_eval 20 | 21 | 22 | class AverageMeter(object): 23 | """Computes and stores the average and current value""" 24 | def __init__(self): 25 | self.reset() 26 | 27 | def reset(self): 28 | self.val = 0 29 | self.avg = 0 30 | self.sum = 0 31 | self.count = 0 32 | 33 | def update(self, val, n=1): 34 | self.val = val 35 | self.sum += val * n 36 | self.count += n 37 | self.avg = self.sum / self.count 38 | 39 | 40 | def unwrap_model(model): 41 | if hasattr(model, 'module'): 42 | return model.module 43 | else: 44 | return model 45 | 46 | 47 | def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None): 48 | device = torch.device(args.device) 49 | autocast = torch.cuda.amp.autocast if args.precision == 'amp' else suppress 50 | 51 | model.train() 52 | loss = ClipLoss( 53 | local_loss=args.local_loss, 54 | gather_with_grad=args.gather_with_grad, 55 | cache_labels=True, 56 | rank=args.rank, 57 | world_size=args.world_size, 58 | use_horovod=args.horovod) 59 | 60 | data['train'].set_epoch(epoch) # set epoch in process safe manner via sampler or shared_epoch 61 | dataloader = data['train'].dataloader 62 | num_batches_per_epoch = 
dataloader.num_batches 63 | sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) 64 | 65 | loss_m = AverageMeter() 66 | batch_time_m = AverageMeter() 67 | data_time_m = AverageMeter() 68 | end = time.time() 69 | for i, batch in enumerate(dataloader): 70 | step = num_batches_per_epoch * epoch + i 71 | scheduler(step) 72 | 73 | images, new_images, texts, new_texts, hard_captions, new_hard = batch 74 | 75 | images = images.to(device=device, non_blocking=True) 76 | new_images = new_images.to(device=device, non_blocking=True) 77 | 78 | texts = texts.to(device=device, non_blocking=True) 79 | new_texts = new_texts.to(device=device, non_blocking=True) 80 | 81 | hard_captions = hard_captions.to(device=device, non_blocking=True) 82 | new_hard = new_hard.to(device=device, non_blocking=True) 83 | 84 | images = torch.cat([images, new_images]) 85 | 86 | texts = torch.cat([texts, new_texts]) 87 | texts = torch.cat([texts, hard_captions]) 88 | texts = torch.cat([texts, new_hard]) 89 | 90 | data_time_m.update(time.time() - end) 91 | optimizer.zero_grad() 92 | 93 | with autocast(): 94 | image_features, text_features, logit_scale = model(images, texts) 95 | total_loss = loss(image_features, text_features, logit_scale) 96 | 97 | if scaler is not None: 98 | scaler.scale(total_loss).backward() 99 | if args.horovod: 100 | optimizer.synchronize() 101 | scaler.unscale_(optimizer) 102 | if args.norm_gradient_clip is not None: 103 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.norm_gradient_clip, norm_type=2.0) 104 | with optimizer.skip_synchronize(): 105 | scaler.step(optimizer) 106 | else: 107 | if args.norm_gradient_clip is not None: 108 | scaler.unscale_(optimizer) 109 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.norm_gradient_clip, norm_type=2.0) 110 | scaler.step(optimizer) 111 | scaler.update() 112 | else: 113 | total_loss.backward() 114 | if args.norm_gradient_clip is not None: 115 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.norm_gradient_clip, norm_type=2.0) 116 | optimizer.step() 117 | 118 | # Note: we clamp to 4.6052 = ln(100), as in the original paper. 119 | with torch.no_grad(): 120 | unwrap_model(model).logit_scale.clamp_(0, math.log(100)) 121 | 122 | batch_time_m.update(time.time() - end) 123 | end = time.time() 124 | batch_count = i + 1 125 | if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch): 126 | batch_size = len(images) 127 | num_samples = batch_count * batch_size * args.world_size 128 | samples_per_epoch = dataloader.num_samples 129 | percent_complete = 100.0 * batch_count / num_batches_per_epoch 130 | 131 | # NOTE loss is coarsely sampled, just master node and per log update 132 | loss_m.update(total_loss.item(), batch_size) 133 | logit_scale_scalar = logit_scale.item() 134 | logging.info( 135 | f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " 136 | f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " 137 | f"Data (t): {data_time_m.avg:.3f} " 138 | f"Batch (t): {batch_time_m.avg:.3f}, {args.batch_size*args.world_size / batch_time_m.val:#g}/s " 139 | f"LR: {optimizer.param_groups[0]['lr']:5f} " 140 | f"Logit Scale: {logit_scale_scalar:.3f} - V4" 141 | ) 142 | 143 | # Save train loss / etc. 
Using non-avg meter values as loggers have their own smoothing 144 | log_data = { 145 | "loss": loss_m.val, 146 | "data_time": data_time_m.val, 147 | "batch_time": batch_time_m.val, 148 | "samples_per_second": args.batch_size*args.world_size / batch_time_m.val, 149 | "scale": logit_scale_scalar, 150 | "lr": optimizer.param_groups[0]["lr"] 151 | } 152 | for name, val in log_data.items(): 153 | name = "train/" + name 154 | if tb_writer is not None: 155 | tb_writer.add_scalar(name, val, step) 156 | if args.wandb: 157 | assert wandb is not None, 'Please install wandb.' 158 | wandb.log({name: val, 'step': step}) 159 | 160 | # resetting batch / data time meters per log window 161 | batch_time_m.reset() 162 | data_time_m.reset() 163 | # end for 164 | 165 | 166 | def evaluate(model, data, epoch, args, tb_writer=None): 167 | metrics = {} 168 | if not is_master(args): 169 | return metrics 170 | device = torch.device(args.device) 171 | model.eval() 172 | 173 | zero_shot_metrics = zero_shot_eval(model, data, epoch, args) 174 | metrics.update(zero_shot_metrics) 175 | 176 | autocast = torch.cuda.amp.autocast if args.precision == 'amp' else suppress 177 | if 'val' in data and (args.val_frequency and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)): 178 | dataloader = data['val'].dataloader 179 | num_samples = 0 180 | samples_per_val = dataloader.num_samples 181 | 182 | # FIXME this does not scale past small eval datasets 183 | # all_image_features @ all_text_features will blow up memory and compute very quickly 184 | cumulative_loss = 0.0 185 | all_image_features, all_text_features = [], [] 186 | with torch.no_grad(): 187 | for i, batch in enumerate(dataloader): 188 | images, hard_images, texts, texts_hard_images, hard_captions, hard_captions_of_hard_images = batch 189 | 190 | images = images.to(device=device, non_blocking=True) 191 | hard_images = hard_images.to(device=device, non_blocking=True) 192 | 193 | texts = texts.to(device=device, non_blocking=True) 194 | texts_hard_images = texts_hard_images.to(device=device, non_blocking=True) 195 | 196 | hard_captions = hard_captions.to(device=device, non_blocking=True) 197 | hard_captions_of_hard_images = hard_captions_of_hard_images.to(device=device, non_blocking=True) 198 | 199 | images = torch.cat([images, hard_images]) 200 | 201 | texts = torch.cat([texts, texts_hard_images]) 202 | texts = torch.cat([texts, hard_captions]) 203 | texts = torch.cat([texts, hard_captions_of_hard_images]) 204 | 205 | with autocast(): 206 | image_features, text_features, logit_scale = model(images, texts) 207 | # features are accumulated in CPU tensors, otherwise GPU memory exhausted quickly 208 | # however, system RAM is easily exceeded and compute time becomes problematic 209 | 210 | logit_scale = logit_scale.mean() 211 | logits_per_image = logit_scale * image_features @ text_features.t() 212 | logits_per_text = logits_per_image.t() 213 | 214 | all_image_features.append(image_features.cpu()) 215 | all_text_features.append(text_features[:len(logits_per_image)].cpu()) 216 | 217 | batch_size = images.shape[0] 218 | labels = torch.arange(batch_size, device=device).long() 219 | total_loss = ( 220 | F.cross_entropy(logits_per_image, labels) + 221 | F.cross_entropy(logits_per_text[:len(logits_per_image)], labels) 222 | ) / 2 223 | 224 | cumulative_loss += total_loss * batch_size 225 | num_samples += batch_size 226 | if is_master(args) and (i % 100) == 0: 227 | logging.info( 228 | f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]\t" 229 | f"Loss: {cumulative_loss
/ num_samples:.6f}\t") 230 | 231 | val_metrics = get_metrics( 232 | image_features=torch.cat(all_image_features), 233 | text_features=torch.cat(all_text_features), 234 | logit_scale=logit_scale.cpu(), 235 | ) 236 | loss = cumulative_loss / num_samples 237 | metrics.update( 238 | {**val_metrics, "val_loss": loss.item(), "epoch": epoch, "num_samples": num_samples} 239 | ) 240 | 241 | if not metrics: 242 | return metrics 243 | 244 | logging.info( 245 | f"Eval Epoch: {epoch} " 246 | + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()]) 247 | ) 248 | 249 | if args.save_logs: 250 | for name, val in metrics.items(): 251 | if tb_writer is not None: 252 | tb_writer.add_scalar(f"val/{name}", val, epoch) 253 | 254 | with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f: 255 | f.write(json.dumps(metrics)) 256 | f.write("\n") 257 | 258 | if args.wandb: 259 | assert wandb is not None, 'Please install wandb.' 260 | for name, val in metrics.items(): 261 | wandb.log({f"val/{name}": val, 'epoch': epoch}) 262 | 263 | return metrics 264 | 265 | 266 | def get_metrics(image_features, text_features, logit_scale): 267 | metrics = {} 268 | logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu() 269 | logits_per_text = logits_per_image.t().detach().cpu() 270 | 271 | logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text[:len(logits_per_image)]} 272 | ground_truth = torch.arange(len(text_features)).view(-1, 1) 273 | 274 | for name, logit in logits.items(): 275 | ranking = torch.argsort(logit, descending=True) 276 | preds = torch.where(ranking == ground_truth)[1] 277 | preds = preds.detach().cpu().numpy() 278 | metrics[f"{name}_mean_rank"] = preds.mean() + 1 279 | metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1 280 | for k in [1, 5, 10]: 281 | metrics[f"{name}_R@{k}"] = np.mean(preds < k) 282 | 283 | return metrics 284 | -------------------------------------------------------------------------------- /src/training/zero_shot.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from contextlib import suppress 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from tqdm import tqdm 7 | 8 | from open_clip import tokenize 9 | from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template 10 | 11 | 12 | def zero_shot_classifier(model, classnames, templates, args): 13 | with torch.no_grad(): 14 | zeroshot_weights = [] 15 | for classname in tqdm(classnames): 16 | texts = [template(classname) for template in templates] # format with class 17 | texts = tokenize(texts).to(args.device) # tokenize 18 | if args.distributed and not args.horovod: 19 | class_embeddings = model.module.encode_text(texts) 20 | else: 21 | class_embeddings = model.encode_text(texts) 22 | class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) 23 | class_embedding /= class_embedding.norm() 24 | zeroshot_weights.append(class_embedding) 25 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.device) 26 | return zeroshot_weights 27 | 28 | 29 | def accuracy(output, target, topk=(1,)): 30 | pred = output.topk(max(topk), 1, True, True)[1].t() 31 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 32 | return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] 33 | 34 | 35 | def run(model, classifier, dataloader, args): 36 | autocast = torch.cuda.amp.autocast if args.precision == 'amp' else suppress 37 | with 
torch.no_grad(): 38 | top1, top5, n = 0., 0., 0. 39 | for images, target in tqdm(dataloader, unit_scale=args.batch_size): 40 | images = images.to(args.device) 41 | target = target.to(args.device) 42 | 43 | with autocast(): 44 | # predict 45 | if args.distributed and not args.horovod: 46 | image_features = model.module.encode_image(images) 47 | else: 48 | image_features = model.encode_image(images) 49 | image_features = F.normalize(image_features, dim=-1) 50 | logits = 100. * image_features @ classifier 51 | 52 | # measure accuracy 53 | acc1, acc5 = accuracy(logits, target, topk=(1, 5)) 54 | top1 += acc1 55 | top5 += acc5 56 | n += images.size(0) 57 | 58 | top1 = (top1 / n) 59 | top5 = (top5 / n) 60 | return top1, top5 61 | 62 | 63 | def zero_shot_eval(model, data, epoch, args): 64 | if 'imagenet-val' not in data and 'imagenet-v2' not in data: 65 | return {} 66 | if args.zeroshot_frequency == 0: 67 | return {} 68 | if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: 69 | return {} 70 | 71 | logging.info('Starting zero-shot imagenet.') 72 | 73 | logging.info('Building zero-shot classifier') 74 | classifier = zero_shot_classifier(model, imagenet_classnames, openai_imagenet_template, args) 75 | 76 | logging.info('Using classifier') 77 | results = {} 78 | if 'imagenet-val' in data: 79 | top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args) 80 | results['imagenet-zeroshot-val-top1'] = top1 81 | results['imagenet-zeroshot-val-top5'] = top5 82 | if 'imagenet-v2' in data: 83 | top1, top5 = run(model, classifier, data['imagenet-v2'].dataloader, args) 84 | results['imagenetv2-zeroshot-val-top1'] = top1 85 | results['imagenetv2-zeroshot-val-top5'] = top5 86 | 87 | logging.info('Finished zero-shot imagenet.') 88 | 89 | return results 90 | -------------------------------------------------------------------------------- /tests/test_simple.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from PIL import Image 4 | from open_clip import tokenizer 5 | import open_clip 6 | import os 7 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 8 | 9 | def test_inference(): 10 | model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32-quickgelu', pretrained='laion400m_e32') 11 | 12 | current_dir = os.path.dirname(os.path.realpath(__file__)) 13 | 14 | image = preprocess(Image.open(current_dir + "/../docs/CLIP.png")).unsqueeze(0) 15 | text = tokenizer.tokenize(["a diagram", "a dog", "a cat"]) 16 | 17 | with torch.no_grad(): 18 | image_features = model.encode_image(image) 19 | text_features = model.encode_text(text) 20 | 21 | text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) 22 | 23 | assert text_probs.cpu().numpy()[0].tolist() == [1.0, 0.0, 0.0] --------------------------------------------------------------------------------
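The sketches below are illustrative usage examples written against the files above; they are not repository files. This first one shows how parse_args() in src/training/params.py backfills the optimizer hyperparameters from get_default_params() when --lr, --beta1, --beta2, and --eps are not given. It assumes src/ is on PYTHONPATH so that training.params is importable, and the command line itself is hypothetical.

import sys

from training.params import parse_args

# Hypothetical invocation: only the model and a training csv are specified,
# so lr/beta1/beta2/eps fall back to the ViT defaults in get_default_params().
sys.argv = ["main.py", "--model", "ViT-B-32", "--train-data", "train.csv"]
args = parse_args()
print(args.lr, args.beta1, args.beta2, args.eps)  # expected: 0.0005 0.9 0.98 1e-06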
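Next, a minimal sketch of the schedule produced by cosine_lr() in src/training/scheduler.py: linear warmup for warmup_length steps, then cosine decay toward zero over the remaining steps, with the new rate written into every optimizer param group. The tiny linear model, the optimizer, and the step counts are made up for illustration; the import again assumes src/ is on PYTHONPATH.

import torch

from training.scheduler import cosine_lr

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=5.0e-4)

# Hypothetical schedule: 5 warmup steps out of 20 total steps.
scheduler = cosine_lr(optimizer, base_lr=5.0e-4, warmup_length=5, steps=20)
for step in range(20):
    lr = scheduler(step)  # returns the lr and assigns it to optimizer.param_groups
    print(f"step {step:2d}: lr={lr:.6f}")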
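To make the retrieval metrics returned by get_metrics() in src/training/train.py concrete, here is a toy run on random, L2-normalized features; with random embeddings the recalls land near chance, which is the expected outcome for untrained features. The batch size and embedding dimension are arbitrary, and importing training.train assumes src/ is on PYTHONPATH with open_clip and its dependencies installed (train.py imports ClipLoss from open_clip).

import torch
import torch.nn.functional as F

from training.train import get_metrics

# Hypothetical batch of 32 paired image/text embeddings of dimension 512.
image_features = F.normalize(torch.randn(32, 512), dim=-1)
text_features = F.normalize(torch.randn(32, 512), dim=-1)

metrics = get_metrics(image_features, text_features, logit_scale=torch.tensor(100.0))
# Keys include image_to_text_R@1, image_to_text_mean_rank, text_to_image_R@5, ...
for name, val in sorted(metrics.items()):
    print(f"{name}: {val:.4f}")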
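Finally, a standalone sketch of the zero-shot recipe implemented by zero_shot_classifier() and run() in src/training/zero_shot.py: embed a few prompts per class, average and re-normalize them into a classifier matrix, then score a preprocessed image against it. The class names, prompt templates, image path, and pretrained tag are illustrative assumptions rather than values taken from the training code, and running it downloads the pretrained weights.

import torch
import torch.nn.functional as F
from PIL import Image

import open_clip
from open_clip import tokenize

model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32-quickgelu', pretrained='laion400m_e32')
classnames = ["dog", "cat", "diagram"]  # assumed class names
templates = [lambda c: f"a photo of a {c}.", lambda c: f"a picture of a {c}."]  # assumed prompts

with torch.no_grad():
    weights = []
    for classname in classnames:
        texts = tokenize([t(classname) for t in templates])
        class_embedding = F.normalize(model.encode_text(texts), dim=-1).mean(dim=0)
        weights.append(class_embedding / class_embedding.norm())
    classifier = torch.stack(weights, dim=1)  # (embed_dim, num_classes)

    image = preprocess(Image.open("docs/CLIP.png")).unsqueeze(0)  # path relative to repo root (assumption)
    image_features = F.normalize(model.encode_image(image), dim=-1)
    logits = 100. * image_features @ classifier
    print(logits.softmax(dim=-1))  # per-class probabilities for the single image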