├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── data ├── README.md ├── coco │ └── prepare_data.py └── imgnet │ ├── imgnet_real_query.txt │ └── imgnet_targets.txt ├── env.sh ├── model ├── clip.py └── model.py ├── requirements.txt ├── setenv.sh ├── src ├── data.py ├── demo.py ├── eval_retrieval.py ├── eval_utils.py ├── logger.py ├── main.py ├── params.py ├── trainer.py └── utils.py ├── third_party └── open_clip │ ├── LICENSE │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── clip.py │ ├── environment.yml │ ├── model.py │ ├── model_configs │ ├── RN101.json │ ├── RN50.json │ ├── RN50_a2.json │ ├── RN50_a2s.json │ ├── RN50x16.json │ ├── RN50x4.json │ ├── ViT-B-16.json │ └── ViT-B-32.json │ ├── scheduler.py │ └── simple_tokenizer.py └── valprep.sh /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows [Google's Open Source Community 28 | Guidelines](https://opensource.google/conduct/). 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pic2Word (CVPR2023) 2 | 3 | This is an open source implementation of [Pic2Word](https://arxiv.org/pdf/2302.03084.pdf). This is not an 4 | officially supported Google product. 5 | 6 | 7 | ## Data 8 | 9 | ### Training Data 10 | We utilize [Conceptual Captions URLs](https://ai.google.com/research/ConceptualCaptions/download) to train a model. 11 | See [open_clip](https://github.com/mlfoundations/open_clip) to see the process of getting the dataset. 12 | 13 | The training data directory has to be in the root of this repo, and should be structured like below. 14 | ```bash 15 | cc_data 16 | ├── train ## training image diretories. 17 | └── val ## validation image directories. 18 | cc 19 | ├── Train_GCC-training_output.csv ## training data list 20 | └── Validation_GCC-1.1.0-Validation_output.csv ## validation data list 21 | ``` 22 | 23 | ### Test Data 24 | See [README](data/README.md) to prepare test dataset. 25 | 26 | ## Training 27 | 28 | ### Install dependencies 29 | See [open_clip](https://github.com/mlfoundations/open_clip) for the details of installation. 30 | The same environment should be usable in this repo. 31 | setenv.sh is the script we used to set-up the environment in virtualenv. 32 | 33 | Also run below to add directory to pythonpath: 34 | ```bash 35 | . env3/bin/activate 36 | export PYTHONPATH="$PYTHONPATH:$PWD/src" 37 | export PYTHONWARNINGS='ignore:semaphore_tracker:UserWarning' 38 | ``` 39 | ### Pre-trained model 40 | The model is available in [GoogleDrive](https://drive.google.com/file/d/1IxRi2Cj81RxMu0ViT4q4nkfyjbSHm1dF/view?usp=sharing). 41 | 42 | ### Sample running code for training: 43 | 44 | ```bash 45 | python -u src/main.py \ 46 | --save-frequency 1 \ 47 | --train-data="cc/Train_GCC-training_output.csv" \ 48 | --warmup 10000 \ 49 | --batch-size=128 \ 50 | --lr=1e-4 \ 51 | --wd=0.1 \ 52 | --epochs=30 \ 53 | --workers=8 \ 54 | --openai-pretrained \ 55 | --model ViT-L/14 56 | ``` 57 | 58 | ### Sample evaluation only: 59 | 60 | Evaluation on COCO, ImageNet, or CIRR. 61 | ```bash 62 | python src/eval_retrieval.py \ 63 | --openai-pretrained \ 64 | --resume /path/to/checkpoints \ 65 | --eval-mode $data_name \ ## replace with coco, imgnet, or cirr 66 | --gpu $gpu_id 67 | --model ViT-L/14 68 | ``` 69 | 70 | Evaluation on fashion-iq (shirt or dress or toptee) 71 | ```bash 72 | python src/eval_retrieval.py \ 73 | --openai-pretrained \ 74 | --resume /path/to/checkpoints \ 75 | --eval-mode fashion \ 76 | --source $cloth_type \ ## replace with shirt or dress or toptee 77 | --gpu $gpu_id 78 | --model ViT-L/14 79 | ``` 80 | 81 | ### Demo: 82 | 83 | Evaluation on COCO, ImageNet, or CIRR. 84 | 85 | ```bash 86 | python src/demo.py \ 87 | --openai-pretrained \ 88 | --resume /path/to/checkpoints \ 89 | --retrieval-data $data_name \ ## Choose from coco, imgnet, cirr, dress, shirt, toptee. 90 | --query_file "path_img1,path_img2,path_img3..." \ ## query images 91 | --prompts "prompt1,prompt2,..." \ #prompts. Use * to indicate the token to be replaced with an image token. e.g., "a sketch of *" 92 | --demo-out $path_demo \ # directory to generate html file and image directory. 93 | --gpu $gpu_id 94 | --model ViT-L/14 95 | ``` 96 | This demo will generate a directory which includes html file and an image directory. Download the directory and open html to see results. 97 | 98 | ## Citing 99 | 100 | If you found this repository useful, please consider citing: 101 | 102 | ```bibtex 103 | @article{saito2023pic2word, 104 | title={Pic2Word: Mapping Pictures to Words for Zero-shot Composed Image Retrieval}, 105 | author={Saito, Kuniaki and Sohn, Kihyuk and Zhang, Xiang and Li, Chun-Liang and Lee, Chen-Yu and Saenko, Kate and Pfister, Tomas}, 106 | journal={CVPR}, 107 | year={2023} 108 | } 109 | 110 | ``` 111 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | ## Data 2 | 3 | Overall structure of this directory should be as follows. 4 | ```bash 5 | data 6 |    ├── coco 7 |    ├── imgnet 8 | ├── CIRR 9 |    └── fashion-iq 10 | ``` 11 | 12 | ### ImageNet 13 | ```bash 14 | imgnet 15 | ├── imagenet-r ## unzipped imagenet-r directories containing images. This folder should contain subfolders. 16 | └──n01443537 17 | . 18 | . 19 | 20 | ├── imgnet_real_query.txt 21 | ├── imgnet_targets.txt 22 | └── real ## imagenet validation directories containing images. This folder should contain subfolders. 23 | └──n01440764 24 | . 25 | . 26 | ``` 27 | See [ImageNet-R](https://github.com/hendrycks/imagenet-r) to download the dataset. 28 | 29 | ### COCO 30 | ```bash 31 | coco 32 |    ├── annotations/instances_val2017.json ## annotations for COCO validation images. 33 |    ├── prepare_data.py ## code to generate query data. 34 | ├── coco_eval.csv ## this will be generated by running prepare_data.py 35 |    ├── val2017 ## directory containing COCO validation images. 36 |    └── val2017_masked ## running prepare_data.py will produce the directory. 37 | ``` 38 | Download both instances_val2017.json and val2017. 39 | Run the command to below to produce directory of val2017_masked. 40 | ```bash 41 | python prepare_data.py 42 | ``` 43 | 44 | ### CIRR 45 | 46 | ``` 47 | cirr 48 |   ├── captions 49 | └──cap.rc2.val.json 50 |   ├── dev 51 |   └── image_splits 52 | └──split.rc2.val.json 53 | ``` 54 | Download the images following instruction on [CIRR](https://github.com/Cuberick-Orion/CIRR). 55 | 56 | ### Fashion-IQ 57 | 58 | ``` 59 | fashion-iq 60 | ├── json 61 | ├── cap.dress.val.json 62 | ├── cap.shirt.val.json 63 | └── cap.toptee.val.json 64 | ├── image_splits 65 | ├── split.dress.val.json 66 | ├── split.shirt.val.json 67 | └── split.toptee.val.json 68 | └── images ## images under this directory. 69 | ``` 70 | Json files are available in https://github.com/XiaoxiaoGuo/fashion-iq. 71 | Images are downloaded from https://github.com/postBG/CosMo.pytorch. 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /data/coco/prepare_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from pycocotools.coco import COCO 16 | from collections import defaultdict 17 | import random 18 | import pandas as pd 19 | from PIL import Image 20 | import json 21 | import numpy as np 22 | import os 23 | 24 | coco = COCO(annotation_file='annotations/instances_val2017.json') 25 | cat_ids = coco.getCatIds() 26 | def convert_coco_json_to_csv(filename='./annotations/instances_val2017.json', root='./val2017'): 27 | s = json.load(open(filename, 'r')) 28 | out_file = 'coco_eval.csv' 29 | mask_dir = root+"_masked" 30 | if not os.path.exists(mask_dir): 31 | os.makedirs(mask_dir) 32 | out = open(out_file, 'w') 33 | out.write('id,query_regions,query_class,classes\n') 34 | all_ids = [] 35 | dict_id2cat = {item['id']:item['name'] for item in s['categories']} 36 | for im in s['images']: 37 | all_ids.append(im['id']) 38 | all_ids_ann = [] 39 | id2anns = defaultdict(list) 40 | for ann in s['annotations']: 41 | image_id = ann['image_id'] 42 | all_ids_ann.append(image_id) 43 | x1 = ann['bbox'][0] 44 | x2 = ann['bbox'][0] + ann['bbox'][2] 45 | y1 = ann['bbox'][1] 46 | y2 = ann['bbox'][1] + ann['bbox'][3] 47 | label = dict_id2cat[ann['category_id']] 48 | tmp = [x1, y1, x2, y2, label, ann] 49 | id2anns[image_id].append(tmp) 50 | # Give query regions + classes not included in the query as a hint to retrieve images. 51 | class_count = 0 52 | for id_img in id2anns.keys(): 53 | anns = id2anns[id_img] 54 | label_set = {} 55 | for ann in anns: 56 | label_set[ann[-2]] = label_set.get(ann[-2], 0) + 1 57 | label_set = list(label_set.keys()) 58 | class_count += len(label_set) 59 | output = "%012d.jpg," %id_img 60 | image = Image.open(os.path.join(root, "%012d.jpg" %id_img)) 61 | image = np.array(image) 62 | width, height = image.shape[0], image.shape[1] 63 | area_img = width * height 64 | cand_query = [] 65 | for cand in anns: 66 | x1, y1, x2, y2 = map(lambda x: float(x), cand[:-2]) 67 | area = (x2-x1) * (y2-y1) 68 | if 0.05 < area < 0.5 * area_img: 69 | cand_query.append(cand) 70 | if len(cand_query) >= 1: 71 | query_regions = random.sample(cand_query, k=1) 72 | for region in query_regions: 73 | query_label = region[-2] 74 | ann_region = region[-1] 75 | 76 | id_img = ann_region['image_id'] 77 | filename = coco.imgs[id_img]['file_name'] 78 | image = Image.open(os.path.join(root, filename)) 79 | image = np.array(image) 80 | mask = coco.annToMask(ann_region) 81 | width, height = mask.shape 82 | mask = mask.reshape(width, height,1) 83 | if len(image.shape) == 2: 84 | image = image.reshape(width, height, 1) 85 | image_masked = image * mask + (1-mask) * 255 86 | try: 87 | im = Image.fromarray(image_masked) 88 | except: 89 | image_masked = np.squeeze(image_masked, axis=2) 90 | im = Image.fromarray(image_masked) 91 | im.save(os.path.join(mask_dir, filename)) 92 | 93 | label_set.remove(query_label) 94 | output += ";".join(map(lambda x: str(x), region[:-2])) 95 | output += " " 96 | output += "," 97 | output += query_label 98 | output += "," 99 | output += ";".join(label_set) 100 | output += "\n" 101 | out.write(output) 102 | out.close() 103 | # Sort file by image id 104 | s1 = pd.read_csv(out_file) 105 | s1.sort_values('id', inplace=True) 106 | s1.to_csv(out_file, index=False) 107 | 108 | convert_coco_json_to_csv() 109 | -------------------------------------------------------------------------------- /env.sh: -------------------------------------------------------------------------------- 1 | . env3/bin/activate 2 | export PYTHONPATH="$PYTHONPATH:$PWD/src" 3 | export PYTHONWARNINGS='ignore:semaphore_tracker:UserWarning' 4 | 5 | -------------------------------------------------------------------------------- /model/clip.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Most code is from https://github.com/openai/CLIP 16 | import hashlib 17 | import os 18 | import urllib 19 | import warnings 20 | from typing import Union, List 21 | import torch 22 | from PIL import Image 23 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, RandomResizedCrop 24 | from tqdm import tqdm 25 | from model.model import build_model 26 | from third_party.open_clip.simple_tokenizer import SimpleTokenizer as _Tokenizer 27 | 28 | from functools import * 29 | try: 30 | from huggingface_hub import hf_hub_download 31 | __version__ = '2.0.2' 32 | hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__) 33 | _has_hf_hub = True 34 | except ImportError: 35 | hf_hub_download = None 36 | _has_hf_hub = False 37 | 38 | __all__ = ["available_models", "load", "tokenize"] 39 | _tokenizer = _Tokenizer() 40 | 41 | _MODELS = { 42 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 43 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 44 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 45 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 46 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 47 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 48 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 49 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 50 | } 51 | _OPENAI = { 52 | "ViT-H-14": 'laion/CLIP-ViT-H-14-laion2B-s32B-b79K/' 53 | , 54 | } 55 | 56 | def has_hf_hub(necessary=False): 57 | if not _has_hf_hub and necessary: 58 | # if no HF Hub module installed, and it is necessary to continue, raise error 59 | raise RuntimeError( 60 | 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') 61 | return _has_hf_hub 62 | 63 | def download_pretrained_from_hf( 64 | model_id: str, 65 | filename: str = 'open_clip_pytorch_model.bin', 66 | revision=None, 67 | cache_dir: Union[str, None] = None, 68 | ): 69 | has_hf_hub(True) 70 | cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir) 71 | return cached_file 72 | 73 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 74 | os.makedirs(root, exist_ok=True) 75 | filename = os.path.basename(url) 76 | 77 | expected_sha256 = url.split("/")[-2] 78 | download_target = os.path.join(root, filename) 79 | 80 | if os.path.exists(download_target) and not os.path.isfile(download_target): 81 | raise RuntimeError(f"{download_target} exists and is not a regular file") 82 | 83 | if os.path.isfile(download_target): 84 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 85 | return download_target 86 | else: 87 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 88 | 89 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 90 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 91 | while True: 92 | buffer = source.read(8192) 93 | if not buffer: 94 | break 95 | 96 | output.write(buffer) 97 | loop.update(len(buffer)) 98 | 99 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 100 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 101 | 102 | return download_target 103 | 104 | def _convert_to_rgb(image): 105 | return image.convert('RGB') 106 | 107 | def _transform(n_px: int, is_train: bool): 108 | normalize = Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 109 | if is_train: 110 | return Compose([ 111 | RandomResizedCrop(n_px, scale=(0.9, 1.0), interpolation=Image.BICUBIC), 112 | _convert_to_rgb, 113 | ToTensor(), 114 | normalize, 115 | ]) 116 | else: 117 | return Compose([ 118 | Resize(n_px, interpolation=Image.BICUBIC), 119 | CenterCrop(n_px), 120 | _convert_to_rgb, 121 | ToTensor(), 122 | normalize, 123 | ]) 124 | 125 | 126 | 127 | def available_models() -> List[str]: 128 | """Returns the names of available CLIP models""" 129 | return list(_MODELS.keys()) 130 | 131 | 132 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True, is_train=False, pretrained=True): 133 | """Load a CLIP model 134 | Parameters 135 | ---------- 136 | name : str 137 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 138 | device : Union[str, torch.device] 139 | The device to put the loaded model 140 | jit : bool 141 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 142 | Returns 143 | ------- 144 | model : torch.nn.Module 145 | The CLIP model 146 | preprocess : Callable[[PIL.Image], torch.Tensor] 147 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 148 | """ 149 | if name in _MODELS: 150 | model_path = _download(_MODELS[name]) 151 | elif os.path.isfile(name): 152 | model_path = name 153 | elif name in _OPENAI: 154 | has_hf_hub(True) 155 | # we assume the hf_hub entries in pretrained config combine model_id + filename in 156 | # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and 157 | # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'. 158 | model_id, filename = os.path.split(_OPENAI[name]) 159 | if filename: 160 | model_path = download_pretrained_from_hf(model_id, filename=filename) 161 | else: 162 | model_path = download_pretrained_from_hf(model_id) 163 | else: 164 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 165 | 166 | try: 167 | # loading JIT archive 168 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 169 | state_dict = None 170 | except RuntimeError: 171 | # loading saved state dict 172 | if jit: 173 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 174 | jit = False 175 | state_dict = torch.load(model_path, map_location="cpu") 176 | 177 | if not jit: 178 | try: 179 | model = build_model(state_dict or model.state_dict()).to(device) 180 | except KeyError: 181 | sd = {k[7:]: v for k,v in state_dict["state_dict"].items()} 182 | model = build_model(sd).to(device) 183 | 184 | if str(device) == "cpu": 185 | model.float() 186 | return model, \ 187 | _transform(model.visual.input_resolution, is_train=True), \ 188 | _transform(model.visual.input_resolution, is_train=False) 189 | 190 | # patch the device names 191 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 192 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 193 | 194 | def patch_device(module): 195 | graphs = [module.graph] if hasattr(module, "graph") else [] 196 | if hasattr(module, "forward1"): 197 | graphs.append(module.forward1.graph) 198 | 199 | for graph in graphs: 200 | for node in graph.findAllNodes("prim::Constant"): 201 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 202 | node.copyAttributes(device_node) 203 | 204 | model.apply(patch_device) 205 | patch_device(model.encode_image) 206 | patch_device(model.encode_text) 207 | 208 | # patch dtype to float32 on CPU 209 | if str(device) == "cpu": 210 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 211 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 212 | float_node = float_input.node() 213 | 214 | def patch_float(module): 215 | graphs = [module.graph] if hasattr(module, "graph") else [] 216 | if hasattr(module, "forward1"): 217 | graphs.append(module.forward1.graph) 218 | 219 | for graph in graphs: 220 | for node in graph.findAllNodes("aten::to"): 221 | inputs = list(node.inputs()) 222 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 223 | if inputs[i].node()["value"] == 5: 224 | inputs[i].node().copyAttributes(float_node) 225 | 226 | model.apply(patch_float) 227 | patch_float(model.encode_image) 228 | patch_float(model.encode_text) 229 | 230 | model.float() 231 | 232 | return model, \ 233 | _transform(model.input_resolution.item(), is_train=True), \ 234 | _transform(model.input_resolution.item(), is_train=False) 235 | 236 | 237 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: 238 | """ 239 | Returns the tokenized representation of given input string(s) 240 | Parameters 241 | ---------- 242 | texts : Union[str, List[str]] 243 | An input string or a list of input strings to tokenize 244 | context_length : int 245 | The context length to use; all CLIP models use 77 as the context length 246 | Returns 247 | ------- 248 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 249 | """ 250 | if isinstance(texts, str): 251 | texts = [texts] 252 | 253 | sot_token = _tokenizer.encoder[""] 254 | eot_token = _tokenizer.encoder[""] 255 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 256 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 257 | 258 | for i, tokens in enumerate(all_tokens): 259 | if len(tokens) > context_length: # Truncate 260 | tokens = tokens[:context_length-1] 261 | tokens = tokens + [eot_token] 262 | result[i, :len(tokens)] = torch.tensor(tokens) 263 | 264 | return result 265 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from collections import OrderedDict 16 | from typing import Tuple, Union 17 | 18 | import os 19 | import json 20 | from copy import deepcopy 21 | import numpy as np 22 | import torch 23 | import torch.nn.functional as F 24 | from torch import nn 25 | import torch.distributed as dist 26 | 27 | class IM2TEXT(nn.Module): 28 | def __init__(self, embed_dim=512, middle_dim=512, output_dim=512, n_layer=2, dropout=0.1): 29 | super().__init__() 30 | self.fc_out = nn.Linear(middle_dim, output_dim) 31 | layers = [] 32 | dim = embed_dim 33 | for _ in range(n_layer): 34 | block = [] 35 | block.append(nn.Linear(dim, middle_dim)) 36 | block.append(nn.Dropout(dropout)) 37 | block.append(nn.ReLU()) 38 | dim = middle_dim 39 | layers.append(nn.Sequential(*block)) 40 | self.layers = nn.Sequential(*layers) 41 | 42 | def forward(self, x: torch.Tensor): 43 | for layer in self.layers: 44 | x = layer(x) 45 | return self.fc_out(x) 46 | 47 | class Bottleneck(nn.Module): 48 | expansion = 4 49 | 50 | def __init__(self, inplanes, planes, stride=1): 51 | super().__init__() 52 | 53 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 54 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 55 | self.bn1 = nn.BatchNorm2d(planes) 56 | 57 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 58 | self.bn2 = nn.BatchNorm2d(planes) 59 | 60 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 61 | 62 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 63 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 64 | 65 | self.relu = nn.ReLU(inplace=True) 66 | self.downsample = None 67 | self.stride = stride 68 | 69 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 70 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 71 | self.downsample = nn.Sequential(OrderedDict([ 72 | ("-1", nn.AvgPool2d(stride)), 73 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 74 | ("1", nn.BatchNorm2d(planes * self.expansion)) 75 | ])) 76 | 77 | def forward(self, x: torch.Tensor): 78 | identity = x 79 | 80 | out = self.relu(self.bn1(self.conv1(x))) 81 | out = self.relu(self.bn2(self.conv2(out))) 82 | out = self.avgpool(out) 83 | out = self.bn3(self.conv3(out)) 84 | 85 | if self.downsample is not None: 86 | identity = self.downsample(x) 87 | 88 | out += identity 89 | out = self.relu(out) 90 | return out 91 | 92 | 93 | class AttentionPool2d(nn.Module): 94 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 95 | super().__init__() 96 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 97 | self.k_proj = nn.Linear(embed_dim, embed_dim) 98 | self.q_proj = nn.Linear(embed_dim, embed_dim) 99 | self.v_proj = nn.Linear(embed_dim, embed_dim) 100 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 101 | self.num_heads = num_heads 102 | 103 | def forward(self, x): 104 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 105 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 106 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 107 | x, _ = F.multi_head_attention_forward( 108 | query=x, key=x, value=x, 109 | embed_dim_to_check=x.shape[-1], 110 | num_heads=self.num_heads, 111 | q_proj_weight=self.q_proj.weight, 112 | k_proj_weight=self.k_proj.weight, 113 | v_proj_weight=self.v_proj.weight, 114 | in_proj_weight=None, 115 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 116 | bias_k=None, 117 | bias_v=None, 118 | add_zero_attn=False, 119 | dropout_p=0, 120 | out_proj_weight=self.c_proj.weight, 121 | out_proj_bias=self.c_proj.bias, 122 | use_separate_proj_weight=True, 123 | training=self.training, 124 | need_weights=False 125 | ) 126 | 127 | return x[0] 128 | 129 | 130 | class ModifiedResNet(nn.Module): 131 | """ 132 | A ResNet class that is similar to torchvision's but contains the following changes: 133 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 134 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 135 | - The final pooling layer is a QKV attention instead of an average pool 136 | """ 137 | 138 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 139 | super().__init__() 140 | self.output_dim = output_dim 141 | self.input_resolution = input_resolution 142 | 143 | # the 3-layer stem 144 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 145 | self.bn1 = nn.BatchNorm2d(width // 2) 146 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 147 | self.bn2 = nn.BatchNorm2d(width // 2) 148 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 149 | self.bn3 = nn.BatchNorm2d(width) 150 | self.avgpool = nn.AvgPool2d(2) 151 | self.relu = nn.ReLU(inplace=True) 152 | 153 | # residual layers 154 | self._inplanes = width # this is a *mutable* variable used during construction 155 | self.layer1 = self._make_layer(width, layers[0]) 156 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 157 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 158 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 159 | 160 | embed_dim = width * 32 # the ResNet feature dimension 161 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 162 | 163 | def _make_layer(self, planes, blocks, stride=1): 164 | layers = [Bottleneck(self._inplanes, planes, stride)] 165 | 166 | self._inplanes = planes * Bottleneck.expansion 167 | for _ in range(1, blocks): 168 | layers.append(Bottleneck(self._inplanes, planes)) 169 | 170 | return nn.Sequential(*layers) 171 | 172 | def forward(self, x): 173 | def stem(x): 174 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 175 | x = self.relu(bn(conv(x))) 176 | x = self.avgpool(x) 177 | return x 178 | 179 | x = x.type(self.conv1.weight.dtype) 180 | x = stem(x) 181 | x = self.layer1(x) 182 | x = self.layer2(x) 183 | x = self.layer3(x) 184 | x = self.layer4(x) 185 | x = self.attnpool(x) 186 | 187 | return x 188 | 189 | 190 | class LayerNorm(nn.LayerNorm): 191 | """Subclass torch's LayerNorm to handle fp16.""" 192 | 193 | def forward(self, x: torch.Tensor): 194 | orig_type = x.dtype 195 | ret = super().forward(x.type(torch.float32)) 196 | return ret.type(orig_type) 197 | 198 | 199 | class QuickGELU(nn.Module): 200 | def forward(self, x: torch.Tensor): 201 | return x * torch.sigmoid(1.702 * x) 202 | 203 | 204 | class ResidualAttentionBlock(nn.Module): 205 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 206 | super().__init__() 207 | 208 | self.attn = nn.MultiheadAttention(d_model, n_head) 209 | self.ln_1 = LayerNorm(d_model) 210 | self.mlp = nn.Sequential(OrderedDict([ 211 | ("c_fc", nn.Linear(d_model, d_model * 4)), 212 | ("gelu", QuickGELU()), 213 | ("c_proj", nn.Linear(d_model * 4, d_model)) 214 | ])) 215 | self.ln_2 = LayerNorm(d_model) 216 | self.attn_mask = attn_mask 217 | 218 | def attention(self, x: torch.Tensor): 219 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 220 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 221 | 222 | def forward(self, x: torch.Tensor): 223 | x = x + self.attention(self.ln_1(x)) 224 | x = x + self.mlp(self.ln_2(x)) 225 | return x 226 | 227 | 228 | class Transformer(nn.Module): 229 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 230 | super().__init__() 231 | self.width = width 232 | self.layers = layers 233 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 234 | 235 | def forward(self, x: torch.Tensor): 236 | return self.resblocks(x) 237 | 238 | 239 | class VisualTransformer(nn.Module): 240 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 241 | super().__init__() 242 | self.input_resolution = input_resolution 243 | self.output_dim = output_dim 244 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 245 | 246 | scale = width ** -0.5 247 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 248 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 249 | self.ln_pre = LayerNorm(width) 250 | 251 | self.transformer = Transformer(width, layers, heads) 252 | 253 | self.ln_post = LayerNorm(width) 254 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 255 | 256 | def forward(self, x: torch.Tensor): 257 | x = self.conv1(x) # shape = [*, width, grid, grid] 258 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 259 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 260 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 261 | x = x + self.positional_embedding.to(x.dtype) 262 | x = self.ln_pre(x) 263 | 264 | x = x.permute(1, 0, 2) # NLD -> LND 265 | x = self.transformer(x) 266 | x = x.permute(1, 0, 2) # LND -> NLD 267 | 268 | x = self.ln_post(x[:, 0, :]) 269 | 270 | if self.proj is not None: 271 | x = x @ self.proj 272 | 273 | return x 274 | 275 | def get_tokens(self, x: torch.Tensor): 276 | x = self.conv1(x) # shape = [*, width, grid, grid] 277 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 278 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 279 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 280 | x = x + self.positional_embedding.to(x.dtype) 281 | x = self.ln_pre(x) 282 | x = x.permute(1, 0, 2) # NLD -> LND 283 | x = self.transformer(x) 284 | x = x.permute(1, 0, 2) # LND -> NLD 285 | return x 286 | 287 | 288 | class CLIP(nn.Module): 289 | def __init__(self, 290 | embed_dim: int, 291 | # vision 292 | image_resolution: int, 293 | vision_layers: Union[Tuple[int, int, int, int], int], 294 | vision_width: int, 295 | vision_patch_size: int, 296 | # text 297 | context_length: int, 298 | vocab_size: int, 299 | transformer_width: int, 300 | transformer_heads: int, 301 | transformer_layers: int, 302 | extra_transformer_layers: int = 0, 303 | share_projection_layer: bool = True, 304 | ): 305 | super().__init__() 306 | self.embed_dim = embed_dim 307 | self.context_length = context_length 308 | self.share_projection_layer = share_projection_layer 309 | self.has_extra = True if extra_transformer_layers > 0 else False 310 | 311 | if isinstance(vision_layers, (tuple, list)): 312 | vision_heads = vision_width * 32 // 64 313 | self.visual = ModifiedResNet( 314 | layers=vision_layers, 315 | output_dim=embed_dim, 316 | heads=vision_heads, 317 | input_resolution=image_resolution, 318 | width=vision_width 319 | ) 320 | else: 321 | vision_heads = vision_width // 64 322 | self.visual = VisualTransformer( 323 | input_resolution=image_resolution, 324 | patch_size=vision_patch_size, 325 | width=vision_width, 326 | layers=vision_layers, 327 | heads=vision_heads, 328 | output_dim=embed_dim 329 | ) 330 | self.transformer_width = transformer_width 331 | self.transformer = Transformer( 332 | width=transformer_width, 333 | layers=transformer_layers, 334 | heads=transformer_heads, 335 | attn_mask=self.build_attention_mask() 336 | ) 337 | if extra_transformer_layers > 0: 338 | self.extra_transformer = Transformer( 339 | width=transformer_width, 340 | layers=extra_transformer_layers, 341 | heads=transformer_heads, 342 | attn_mask=self.build_attention_mask() 343 | ) 344 | self.extra_ln_final = LayerNorm(transformer_width) 345 | 346 | self.vocab_size = vocab_size 347 | self.end_id = self.vocab_size -1 348 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 349 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 350 | self.ln_final = LayerNorm(transformer_width) 351 | 352 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 353 | if not share_projection_layer: 354 | self.extra_text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 355 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 356 | 357 | self.initialize_parameters() 358 | 359 | def initialize_parameters(self): 360 | nn.init.normal_(self.token_embedding.weight, std=0.02) 361 | nn.init.normal_(self.positional_embedding, std=0.01) 362 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 363 | 364 | if isinstance(self.visual, ModifiedResNet): 365 | if self.visual.attnpool is not None: 366 | std = self.visual.attnpool.c_proj.in_features ** -0.5 367 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 368 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 369 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 370 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 371 | 372 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 373 | for name, param in resnet_block.named_parameters(): 374 | if name.endswith("bn3.weight"): 375 | nn.init.zeros_(param) 376 | 377 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 378 | attn_std = self.transformer.width ** -0.5 379 | fc_std = (2 * self.transformer.width) ** -0.5 380 | for block in self.transformer.resblocks: 381 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 382 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 383 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 384 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 385 | 386 | if self.text_projection is not None: 387 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 388 | if hasattr(self, 'extra_text_projection'): 389 | nn.init.normal_(self.extra_text_projection, std=self.transformer.width ** -0.5) 390 | 391 | def build_attention_mask(self): 392 | # lazily create causal attention mask, with full attention between the vision tokens 393 | # pytorch uses additive attention mask; fill with -inf 394 | mask = torch.empty(self.context_length, self.context_length) 395 | mask.fill_(float("-inf")) 396 | mask.triu_(1) # zero out the lower diagonal 397 | return mask 398 | 399 | @property 400 | def dtype(self): 401 | return self.visual.conv1.weight.dtype 402 | 403 | def encode_image(self, image): 404 | return self.visual(image.type(self.dtype)) 405 | 406 | def encode_text(self, text): 407 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 408 | 409 | x = x + self.positional_embedding.type(self.dtype) 410 | x = x.permute(1, 0, 2) # NLD -> LND 411 | x = self.transformer(x) 412 | x = x.permute(1, 0, 2) # LND -> NLD 413 | x = self.ln_final(x).type(self.dtype) 414 | # x.shape = [batch_size, n_ctx, transformer.width] 415 | # take features from the eot embedding (eot_token is the highest number in each sequence) 416 | collect_ind = text == self.end_id 417 | collect_ind = collect_ind.nonzero()[:, 1] 418 | x = x[torch.arange(x.size(0)), collect_ind] @ self.text_projection 419 | return x 420 | 421 | 422 | def encode_text_img(self, text, img_tokens): 423 | b_size = img_tokens.size(0) 424 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 425 | collect_ind = text == self.end_id 426 | collect_ind = collect_ind.nonzero()[:, 1] 427 | img_tokens = img_tokens.view(b_size, 1, -1) 428 | x = torch.cat([x[:, :collect_ind[0]], img_tokens, x[:, collect_ind[0]:-1]], dim=1) 429 | x = x + self.positional_embedding.type(self.dtype) 430 | x = x.permute(1, 0, 2) # NLD -> LND 431 | x = self.transformer(x) 432 | x = x.permute(1, 0, 2) # LND -> NLD 433 | x = self.ln_final(x).type(self.dtype) 434 | # x.shape = [batch_size, n_ctx, transformer.width] 435 | # take features from the eot embedding (eot_token is the highest number in each sequence) 436 | x = x[torch.arange(x.size(0)), collect_ind+1] @ self.text_projection 437 | return x 438 | 439 | def encode_text_img_vis(self, text, img_tokens, split_ind=4): 440 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 441 | collect_ind = text == self.end_id 442 | collect_ind = collect_ind.nonzero()[:, 1] 443 | new_x = [] 444 | for i, sample in enumerate(x): 445 | ind_insert = text[i] == split_ind 446 | sample = sample.view(1, x.size(1), -1) 447 | if isinstance(img_tokens, tuple): 448 | indexes = ind_insert.nonzero() 449 | for i, index in enumerate(indexes): 450 | img = img_tokens[i].view(1, 1, -1) 451 | sample = torch.cat([sample[:, :index], img, sample[:, index+1:]], dim=1) 452 | else: 453 | img_tokens = img_tokens.view(1, 1, -1) 454 | ind_insert = ind_insert.nonzero()[0] 455 | sample = torch.cat([sample[:, :ind_insert], img_tokens, sample[:, ind_insert+1:]], dim=1) 456 | new_x.append(sample) 457 | x = torch.cat(new_x, dim=0) 458 | x = x + self.positional_embedding.type(self.dtype) 459 | x = x.permute(1, 0, 2) # NLD -> LND 460 | x = self.transformer(x) 461 | x = x.permute(1, 0, 2) # LND -> NLD 462 | x = self.ln_final(x).type(self.dtype) 463 | # x.shape = [batch_size, n_ctx, transformer.width] 464 | # take features from the eot embedding (eot_token is the highest number in each sequence) 465 | x = x[torch.arange(x.size(0)), collect_ind] @ self.text_projection 466 | return x 467 | 468 | def encode_text_img_retrieval(self, text, img_tokens, split_ind=4, repeat=True): 469 | # text.shape = [1, n_ctx] 470 | # img_tokens.shape = [batch_size, d_model] 471 | if isinstance(img_tokens, tuple): 472 | b_size = img_tokens[0].shape[0] 473 | else: 474 | b_size = img_tokens.shape[0] 475 | if repeat: 476 | text = text.repeat(b_size, 1) 477 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 478 | collect_ind = text == self.end_id 479 | collect_ind = collect_ind.nonzero()[:, 1] 480 | ind_insert = text[0] == split_ind 481 | if isinstance(img_tokens, tuple): 482 | indexes = ind_insert.nonzero() 483 | for i, index in enumerate(indexes): 484 | img = img_tokens[i].view(b_size, 1, -1) 485 | x = torch.cat([x[:, :index], img, x[:, index+1:]], dim=1) 486 | else: 487 | img_tokens = img_tokens.view(b_size, 1, -1) 488 | ind_insert = ind_insert.nonzero()[0] 489 | x = torch.cat([x[:, :ind_insert], img_tokens, x[:, ind_insert+1:]], dim=1) 490 | #x = torch.cat([x, torch.zeros_like(x).cuda()[:, :1, :]], dim=1) 491 | x = x + self.positional_embedding.type(self.dtype) 492 | x = x.permute(1, 0, 2) # NLD -> LND 493 | x = self.transformer(x) 494 | x = x.permute(1, 0, 2) # LND -> NLD 495 | x = self.ln_final(x).type(self.dtype) 496 | # x.shape = [batch_size, n_ctx, transformer.width] 497 | # take features from the eot embedding (eot_token is the highest number in each sequence) 498 | x = x[torch.arange(x.size(0)), collect_ind] @ self.text_projection 499 | return x 500 | 501 | def forward(self, image, text, extra=False): 502 | if image is None: 503 | if extra: 504 | return self.encode_text_extra(text) 505 | else: 506 | return self.encode_text(text) 507 | elif text is None: 508 | return self.encode_image(image) 509 | image_features = self.encode_image(image) 510 | if extra: 511 | text_features = self.encode_text_extra(text) 512 | else: 513 | text_features = self.encode_text(text) 514 | 515 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 516 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 517 | 518 | return image_features, text_features, self.logit_scale.exp() 519 | 520 | 521 | @torch.no_grad() 522 | def concat_all_gather(tensor): 523 | """ 524 | Performs all_gather operation on the provided tensors. 525 | *** Warning ***: torch.distributed.all_gather has no gradient. 526 | """ 527 | tensors_gather = [torch.ones_like(tensor) 528 | for _ in range(torch.distributed.get_world_size())] 529 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 530 | 531 | output = torch.cat(tensors_gather, dim=0) 532 | return output 533 | 534 | def convert_weights(model: nn.Module): 535 | """Convert applicable model parameters to fp16""" 536 | 537 | def _convert_weights_to_fp16(l): 538 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 539 | l.weight.data = l.weight.data.half() 540 | if l.bias is not None: 541 | l.bias.data = l.bias.data.half() 542 | 543 | if isinstance(l, nn.MultiheadAttention): 544 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 545 | tensor = getattr(l, attr) 546 | if tensor is not None: 547 | tensor.data = tensor.data.half() 548 | 549 | for name in ["text_projection", "proj"]: 550 | if hasattr(l, name): 551 | attr = getattr(l, name) 552 | if attr is not None: 553 | attr.data = attr.data.half() 554 | 555 | model.apply(_convert_weights_to_fp16) 556 | 557 | 558 | def build_model(state_dict: dict): 559 | vit = "visual.proj" in state_dict 560 | 561 | if vit: 562 | vision_width = state_dict["visual.conv1.weight"].shape[0] 563 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 564 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 565 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 566 | image_resolution = vision_patch_size * grid_size 567 | else: 568 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 569 | vision_layers = tuple(counts) 570 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 571 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 572 | vision_patch_size = None 573 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 574 | image_resolution = output_width * 32 575 | 576 | embed_dim = state_dict["text_projection"].shape[1] 577 | context_length = state_dict["positional_embedding"].shape[0] 578 | vocab_size = state_dict["token_embedding.weight"].shape[0] 579 | transformer_width = state_dict["ln_final.weight"].shape[0] 580 | transformer_heads = transformer_width // 64 581 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 582 | 583 | model = CLIP( 584 | embed_dim, 585 | image_resolution, vision_layers, vision_width, vision_patch_size, 586 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 587 | ) 588 | 589 | for key in ["input_resolution", "context_length", "vocab_size"]: 590 | if key in state_dict: 591 | del state_dict[key] 592 | 593 | convert_weights(model) 594 | model.load_state_dict(state_dict) 595 | return model.eval() -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scikit-image 3 | scikit-learn 4 | torch 5 | torchvision 6 | tensorboard 7 | ase==3.21.1 8 | braceexpand==0.1.7 9 | cached-property==1.5.2 10 | configparser==5.0.2 11 | cycler==0.10.0 12 | decorator==4.4.2 13 | docker-pycreds==0.4.0 14 | gitdb==4.0.7 15 | gitpython==3.1.30 16 | googledrivedownloader==0.4 17 | h5py==3.1.0 18 | isodate==0.6.0 19 | jinja2==3.0.1 20 | kiwisolver==1.3.1 21 | littleutils==0.2.2 22 | llvmlite==0.36.0 23 | markupsafe==2.0.1 24 | matplotlib==3.3.4 25 | networkx==2.5.1 26 | numba==0.53.1 27 | ogb==1.3.1 28 | outdated==0.2.1 29 | pathtools==0.1.2 30 | promise==2.3 31 | psutil==5.8.0 32 | pyarrow==4.0.0 33 | pyparsing==2.4.7 34 | python-louvain==0.15 35 | pyyaml==5.4.1 36 | rdflib==5.0.0 37 | sentry-sdk==1.14.0 38 | shortuuid==1.0.1 39 | sklearn==0.0 40 | smmap==4.0.0 41 | subprocess32==3.5.4 42 | torch-geometric==1.7.0 43 | wandb==0.10.30 44 | wilds==1.1.0 45 | ftfy 46 | regex 47 | webdataset 48 | requests 49 | hydra-core 50 | omegaconf 51 | fairseq==0.10.0 52 | bitarray -------------------------------------------------------------------------------- /setenv.sh: -------------------------------------------------------------------------------- 1 | sudo apt install python3-dev python3-virtualenv python3-tk imagemagick 2 | virtualenv -p python3 --system-site-packages env3 3 | . env3/bin/activate 4 | pip install -r requirements.txt 5 | deactivate 6 | -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import sys 17 | import math 18 | import logging 19 | import functools 20 | import braceexpand 21 | import random 22 | import pdb 23 | import json 24 | 25 | import pandas as pd 26 | import numpy as np 27 | import pyarrow as pa 28 | from PIL import Image 29 | Image.MAX_IMAGE_PIXELS = 1000000000 30 | 31 | from typing import Union 32 | from dataclasses import dataclass 33 | import torch 34 | import torch.distributed as dist 35 | from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler 36 | from torch.utils.data.distributed import DistributedSampler 37 | import torchvision.datasets as datasets 38 | from torchvision.datasets.folder import DatasetFolder 39 | import torchvision.datasets as datasets 40 | import torchvision.transforms as T 41 | from third_party.open_clip.clip import tokenize 42 | 43 | 44 | ## Structure of dataset directory 45 | ## CIRR: under ./data/CIRR 46 | ## validation images ./dev/ 47 | ## caption split ./captions/cap.rc2.val.json 48 | ## image split ./image_splits/split.rc2.val.json 49 | class CIRR(Dataset): 50 | def __init__(self, transforms, mode='caps', 51 | vis_mode=False, test=False, root='./data'): 52 | self.mode = mode 53 | self.transforms = transforms 54 | self.vis_mode = vis_mode 55 | ## mode to use test split of CIRR 56 | self.test = test 57 | self.root = os.path.join(root, 'CIRR') 58 | self.root_img = os.path.join(self.root, 'dev') 59 | if self.test: 60 | self.root_img = os.path.join(self.root, 'test1') 61 | if self.mode == 'caps': 62 | self.json = os.path.join(self.root , 'captions/cap.rc2.test1.json') 63 | else: 64 | self.json = os.path.join(self.root, 'image_splits/split.rc2.test1.json') 65 | else: 66 | if self.mode == 'caps': 67 | self.json = os.path.join(self.root, 'captions/cap.rc2.val.json') 68 | else: 69 | self.json = os.path.join(self.root, 'image_splits/split.rc2.val.json') 70 | logging.debug(f'Loading json data from {self.json}.') 71 | data = json.load(open(self.json, "r")) 72 | self.ref_imgs = [] 73 | self.target_imgs = [] 74 | self.target_caps = [] 75 | if self.test: 76 | self.init_test(data) 77 | elif self.mode == 'caps': 78 | self.init_val(data) 79 | else: 80 | self.target_imgs = [key + ".png" for key in data.keys()] 81 | if self.vis_mode: 82 | self.target_imgs = list(set(self.target_imgs)) 83 | logging.info("Use {} imgs".format(len(self.target_imgs))) 84 | 85 | def init_test(self, data): 86 | self.pairids = [] 87 | if self.mode == 'caps': 88 | for d in data: 89 | ref_path = d['reference']+ ".png" 90 | self.ref_imgs.append(ref_path) 91 | self.target_caps.append(d['caption']) 92 | self.pairids.append(d['pairid']) 93 | self.target_imgs.append('dummy') 94 | else: 95 | self.target_imgs = [key + ".png" for key in data.keys()] 96 | 97 | def init_val(self, data): 98 | for d in data: 99 | ref_path = d['reference']+ ".png" 100 | tar_path = d['target_hard']+ ".png" 101 | self.ref_imgs.append(ref_path) 102 | self.target_imgs.append(tar_path) 103 | self.target_caps.append(d['caption']) 104 | 105 | def return_testdata(self, idx): 106 | if self.mode == 'caps': 107 | ref_path = str(self.ref_imgs[idx]) 108 | img_path = os.path.join(self.root_img, ref_path) 109 | ref_images = self.transforms(Image.open(img_path)) 110 | target_cap = self.target_caps[idx] 111 | text_with_blank_raw = 'a photo of * , {}'.format(target_cap) 112 | caption_only = tokenize(target_cap)[0] 113 | text_with_blank = tokenize(text_with_blank_raw)[0] 114 | return ref_images, text_with_blank, \ 115 | caption_only, str(self.ref_imgs[idx]), \ 116 | self.pairids[idx], text_with_blank_raw 117 | else: 118 | tar_path = str(self.target_imgs[idx]) 119 | img_path = Image.open(os.path.join(self.root_img, tar_path)) 120 | target_images = self.transforms(img_path) 121 | return target_images, tar_path 122 | 123 | def return_valdata(self, idx): 124 | if self.mode == 'caps' and not self.vis_mode: 125 | ref_path = str(self.ref_imgs[idx]) 126 | img_path = os.path.join(self.root_img, ref_path) 127 | ref_images = self.transforms(Image.open(img_path)) 128 | target_cap = self.target_caps[idx] 129 | text_with_blank = 'a photo of * , {}'.format(target_cap) 130 | caption_only = tokenize(target_cap)[0] 131 | ref_text_tokens = tokenize(text_with_blank)[0] 132 | return ref_images, ref_text_tokens, caption_only, \ 133 | str(self.ref_imgs[idx]), str(self.target_imgs[idx]), \ 134 | target_cap 135 | else: 136 | tar_path = str(self.target_imgs[idx]) 137 | img_path = os.path.join(self.root_img, tar_path) 138 | target_images = self.transforms(Image.open(img_path)) 139 | return target_images, img_path 140 | 141 | def __getitem__(self, idx): 142 | if self.test: 143 | return self.return_testdata(idx) 144 | else: 145 | return self.return_valdata(idx) 146 | 147 | def __len__(self): 148 | return len(self.target_imgs) 149 | 150 | ## Fashion-IQ: under ./data/fashion-iq 151 | ## validation images ./images 152 | ## caption split ./json/cap.{cloth_type}.val.json, cloth_type in [toptee, shirt, dress] 153 | ## image split ./image_splits/split.{cloth_type}.val.json, cloth_type in [toptee, shirt, dress] 154 | class FashionIQ(Dataset): 155 | def __init__(self, cloth, transforms, is_train=False, vis_mode=False, \ 156 | mode='caps', is_return_target_path=False, root='./data'): 157 | root_iq = os.path.join(root, 'fashion-iq') 158 | self.root_img = os.path.join(root_iq, 'images') 159 | self.vis_mode = vis_mode 160 | self.mode = mode 161 | self.is_return_target_path = is_return_target_path 162 | self.transforms = transforms 163 | if mode == 'imgs': 164 | self.json_file = os.path.join(root_iq, 'image_splits', \ 165 | 'split.{}.val.json'.format(cloth)) 166 | else: 167 | self.json_file = os.path.join(root_iq, 'json', \ 168 | 'cap.{}.val.json'.format(cloth)) 169 | logging.debug(f'Loading json data from {self.json_file}.') 170 | 171 | self.ref_imgs = [] 172 | self.target_imgs = [] 173 | self.ref_caps = [] 174 | self.target_caps = [] 175 | if mode == 'imgs': 176 | self.init_imgs() 177 | logging.info("Use {} imgs".format(len(self.target_imgs))) 178 | else: 179 | self.init_data() 180 | logging.info("Use {} imgs".format(len(self.target_imgs))) 181 | 182 | def init_imgs(self): 183 | data = json.load(open(self.json_file, "r")) 184 | self.target_imgs = [key + ".png" for key in data] 185 | 186 | def init_data(self): 187 | def load_data(data): 188 | for d in data: 189 | ref_path = os.path.join(self.root_img, d['candidate']+ ".png") 190 | tar_path = os.path.join(self.root_img, d['target']+ ".png") 191 | try: 192 | Image.open(ref_path) 193 | Image.open(tar_path) 194 | self.ref_imgs.append(ref_path) 195 | self.target_imgs.append(tar_path) 196 | self.ref_caps.append((d['captions'][0], d['captions'][1])) 197 | #self.target_caps.append(d['captions'][1]) 198 | except: 199 | print('cannot load {}'.format(d['candidate'])) 200 | if isinstance(self.json_file, str): 201 | data = json.load(open(self.json_file, "r")) 202 | load_data(data) 203 | elif isinstance(self.json_file, list): 204 | for filename in self.json_file: 205 | data = json.load(open(filename, "r")) 206 | load_data(data) 207 | 208 | def __len__(self): 209 | if self.mode == 'caps': 210 | return len(self.ref_imgs) 211 | else: 212 | return len(self.target_imgs) 213 | 214 | def return_imgs(self, idx): 215 | tar_path = str(self.target_imgs[idx]) 216 | img_path = os.path.join(self.root_img, tar_path) 217 | target_images = self.transforms(Image.open(img_path)) 218 | return target_images, os.path.join(self.root_img, tar_path) 219 | 220 | def return_all(self, idx): 221 | if self.vis_mode: 222 | tar_path = str(self.target_imgs[idx]) 223 | target_images = self.transforms(Image.open(tar_path)) 224 | return target_images, tar_path 225 | ref_images = self.transforms(Image.open(str(self.ref_imgs[idx]))) 226 | target_images = self.transforms(Image.open(str(self.target_imgs[idx]))) 227 | cap1, cap2 = self.ref_caps[idx] 228 | text_with_blank = 'a photo of * , {} and {}'.format(cap2, cap1) 229 | token_texts = tokenize(text_with_blank)[0] 230 | if self.is_return_target_path: 231 | return ref_images, target_images, token_texts, token_texts, \ 232 | str(self.target_imgs[idx]), str(self.ref_imgs[idx]), \ 233 | cap1 234 | else: 235 | return ref_images, target_images, text_with_blank 236 | 237 | 238 | def __getitem__(self, idx): 239 | if self.mode == 'imgs': 240 | return self.return_imgs(idx) 241 | else: 242 | return self.return_all(idx) 243 | 244 | ## COCO: under ./data/coco 245 | ## validation images ./val2017 246 | ## validation masked images ./val2017_masked 247 | ## validation csv file ./coco_eval.csv 248 | class CsvCOCO(Dataset): 249 | def __init__(self, transforms, transforms_region, sep=",", 250 | return_data_identifier=False, return_filename=False, 251 | root='./data'): 252 | self.transforms = transforms 253 | self.transforms_region = transforms_region 254 | self.root = os.path.join(root, 'coco') 255 | self.root_img = os.path.join(self.root, 'val2017') 256 | self.csv_file = os.path.join(self.root, 'coco_eval.csv') 257 | logging.debug(f'Loading csv data from {self.csv_file}.') 258 | df = pd.read_csv(self.csv_file, sep=sep) 259 | self.images = df['id'].tolist() 260 | ## query_region contains the box of query regions. 261 | regions = df['query_regions'].tolist() 262 | self.regions = [] 263 | for region in regions: 264 | x1, y1, x2, y2 = map(lambda x: int(float(x)), region.split(";")) 265 | self.regions.append([x1, y1, x2, y2]) 266 | 267 | ## query_classes contains the class of query region in the target. 268 | self.query_classes = df['query_class'].tolist() 269 | self.classes = [] 270 | ## classes contains the list of classes in the target. 271 | for list_class in df['classes'].tolist(): 272 | if isinstance(list_class, str): 273 | list_class = list_class.split(";") 274 | self.classes.append(list_class) 275 | else: 276 | self.classes.append([""]) 277 | self.return_data_identifier = return_data_identifier 278 | logging.debug('Done loading data.') 279 | self.return_filename = return_filename 280 | 281 | def __len__(self): 282 | return len(self.images) 283 | 284 | def __getitem__(self, idx): 285 | img_path = os.path.join(self.root_img, str(self.images[idx])) 286 | image = Image.open(img_path) 287 | masked_path = os.path.join(self.root_img.replace('val2017', 'val2017_masked'), \ 288 | str(self.images[idx])) 289 | image_masked = Image.open(masked_path) 290 | 291 | ## extract query region. 292 | x1, y1, x2, y2 = self.regions[idx] 293 | region_image = image_masked.crop((x1, y1, x2, y2)) 294 | 295 | image = self.transforms(image) 296 | ## no cropping is applied to query region. 297 | region_image = self.transforms_region(region_image) 298 | query_class = self.query_classes[idx] 299 | other_classes = self.classes[idx] 300 | text_with_blank = 'a photo of * and {}'.format(" and ".join(other_classes)) 301 | text_with_queryclass = 'a photo of * and {} and {}'.format(query_class, \ 302 | " and ".join(other_classes)) 303 | raw_text = text_with_queryclass 304 | text_full = 'a photo of {} and {}'.format(query_class, " and ".join(other_classes)) 305 | text_with_blank = tokenize(text_with_blank)[0] 306 | text_with_queryclass = tokenize(text_with_queryclass)[0] 307 | text_full = tokenize(text_full)[0] 308 | return image, region_image, text_full, text_with_blank, \ 309 | text_with_queryclass, str(self.images[idx]), raw_text 310 | 311 | 312 | class ImageList(Dataset): 313 | def __init__(self, input_filename, transforms, root=None, 314 | return_filename=False, is_labels=False): 315 | logging.debug(f'Loading txt data from {input_filename}.') 316 | with open(input_filename, 'r') as f: 317 | lines = f.readlines() 318 | if not is_labels: 319 | self.images = [line.strip() for line in lines] 320 | else: 321 | filenames = [line.strip() for line in lines] 322 | self.images = [name.split(" ")[0] for name in filenames] 323 | self.labels = [int(name.split(" ")[1]) for name in filenames] 324 | self.is_labels = is_labels 325 | self.transforms = transforms 326 | self.root = root 327 | logging.debug('Done loading data.') 328 | self.return_filename = return_filename 329 | 330 | def __len__(self): 331 | return len(self.images) 332 | 333 | def __getitem__(self, idx): 334 | if self.root is not None: 335 | img_path = os.path.join(self.root, str(self.images[idx])) 336 | else: 337 | img_path = str(self.images[idx]) 338 | images = self.transforms(Image.open(img_path)) 339 | if self.return_filename: 340 | return images, img_path 341 | elif self.is_labels: 342 | target = self.labels[idx] 343 | return images, target 344 | else: 345 | return images 346 | 347 | 348 | class CustomFolder(Dataset): 349 | def __init__(self, folder, transform): 350 | image_lists = os.listdir(folder) 351 | self.samples = [os.path.join(folder, name) for name in image_lists] 352 | self.transform = transform 353 | 354 | def __len__(self): 355 | return len(self.samples) 356 | 357 | def __getitem__(self, index: int): 358 | """ 359 | Args: 360 | index (int): Index 361 | 362 | Returns: 363 | tuple: (sample, target) where target is class_index of the target class. 364 | """ 365 | path = self.samples[index] 366 | sample = Image.open(str(path)) 367 | if self.transform is not None: 368 | sample = self.transform(sample) 369 | return sample, path 370 | 371 | 372 | class CsvDataset(Dataset): 373 | def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t", 374 | return_data_identifier=False, return_filename=False): 375 | logging.debug(f'Loading csv data from {input_filename}.') 376 | df = pd.read_csv(input_filename, sep=sep) 377 | self.images = df[img_key].tolist() 378 | self.captions = df[caption_key].tolist() 379 | self.transforms = transforms 380 | self.return_data_identifier = return_data_identifier 381 | logging.debug('Done loading data of {} samples'.format(len(self.images))) 382 | self.return_filename = return_filename 383 | 384 | def __len__(self): 385 | return len(self.captions) 386 | 387 | def __getitem__(self, idx): 388 | images = self.transforms(Image.open(str(self.images[idx]))) 389 | if self.return_filename: 390 | return images, str(self.images[idx]) 391 | texts = tokenize([str(self.captions[idx])])[0] 392 | 393 | if self.return_data_identifier: 394 | return images, texts, 0 395 | return images, texts 396 | 397 | @dataclass 398 | class DataInfo: 399 | dataloader: DataLoader 400 | sampler: DistributedSampler 401 | 402 | def preprocess_txt(text): 403 | return tokenize([str(text)])[0] 404 | 405 | def get_dataset_size(shards): 406 | shards_list = list(braceexpand.braceexpand(shards)) 407 | dir_path = os.path.dirname(shards) 408 | sizes_filename = os.path.join(dir_path, 'sizes.json') 409 | sizes = json.load(open(sizes_filename, 'r')) 410 | total_size = sum( 411 | [int(sizes[os.path.basename(shard)]) for shard in shards_list]) 412 | num_shards = len(shards_list) 413 | return total_size, num_shards 414 | 415 | def get_imagenet(args, preprocess_fns, split): 416 | assert split in ["train", "val", "v2"] 417 | is_train = split == "train" 418 | preprocess_train, preprocess_val = preprocess_fns 419 | 420 | if split == "v2": 421 | from imagenetv2_pytorch import ImageNetV2Dataset 422 | dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val) 423 | else: 424 | if is_train: 425 | data_path = args.imagenet_train 426 | preprocess_fn = preprocess_train 427 | else: 428 | data_path = args.imagenet_val 429 | preprocess_fn = preprocess_val 430 | assert data_path 431 | 432 | dataset = datasets.ImageFolder(data_path, transform=preprocess_fn) 433 | 434 | if is_train: 435 | idxs = np.zeros(len(dataset.targets)) 436 | target_array = np.array(dataset.targets) 437 | k = 50 438 | for c in range(1000): 439 | m = target_array == c 440 | n = len(idxs[m]) 441 | arr = np.zeros(n) 442 | arr[:k] = 1 443 | np.random.shuffle(arr) 444 | idxs[m] = arr 445 | 446 | idxs = idxs.astype('int') 447 | sampler = SubsetRandomSampler(np.where(idxs)[0]) 448 | else: 449 | sampler = None 450 | 451 | dataloader = torch.utils.data.DataLoader( 452 | dataset, 453 | batch_size=args.batch_size, 454 | num_workers=args.workers, 455 | sampler=sampler, 456 | ) 457 | return DataInfo(dataloader, sampler) 458 | 459 | def count_samples(dataloader): 460 | os.environ["WDS_EPOCH"] = "0" 461 | n_elements, n_batches = 0, 0 462 | for images, texts in dataloader: 463 | n_batches += 1 464 | n_elements += len(images) 465 | assert len(images) == len(texts) 466 | return n_elements, n_batches 467 | 468 | def get_csv_dataset(args, preprocess_fn, is_train, input_filename=None): 469 | if input_filename is None: 470 | input_filename = args.train_data if is_train else args.val_data 471 | assert input_filename 472 | dataset = CsvDataset( 473 | input_filename, 474 | preprocess_fn, 475 | img_key=args.csv_img_key, 476 | caption_key=args.csv_caption_key, 477 | sep=args.csv_separator) 478 | 479 | num_samples = len(dataset) 480 | sampler = DistributedSampler(dataset) if args.distributed and is_train else None 481 | shuffle = is_train and sampler is None 482 | 483 | dataloader = DataLoader( 484 | dataset, 485 | batch_size=args.batch_size, 486 | shuffle=shuffle, 487 | num_workers=args.workers, 488 | pin_memory=True, 489 | sampler=sampler, 490 | drop_last=is_train, 491 | ) 492 | dataloader.num_samples = num_samples 493 | dataloader.num_batches = len(dataloader) 494 | 495 | return DataInfo(dataloader, sampler) 496 | 497 | 498 | # 499 | def get_imgnet_r(args, preprocess_fn, is_train, input_filename=None): 500 | if input_filename is None: 501 | input_filename = args.train_data if is_train else args.val_data 502 | assert input_filename 503 | path_data = os.path.join(args.root_data, 'imgnet/imagenet-r') 504 | dataset = CustomFolder(path_data, transform=preprocess_fn) 505 | num_samples = len(dataset) 506 | sampler = DistributedSampler(dataset) if args.distributed and is_train else None 507 | shuffle = is_train and sampler is None 508 | dataloader = DataLoader( 509 | dataset, 510 | batch_size=args.batch_size, 511 | shuffle=shuffle, 512 | num_workers=args.workers, 513 | pin_memory=True, 514 | sampler=sampler, 515 | drop_last=is_train, 516 | ) 517 | dataloader.num_samples = num_samples 518 | dataloader.num_batches = len(dataloader) 519 | return DataInfo(dataloader, sampler) 520 | 521 | 522 | def get_directory_dataset(args, preprocess_fn, is_train, input_filename=None): 523 | if input_filename is None: 524 | input_filename = args.train_data if is_train else args.val_data 525 | assert input_filename 526 | dataset = CustomFolder( 527 | input_filename, 528 | transform=preprocess_fn) 529 | num_samples = len(dataset) 530 | sampler = DistributedSampler(dataset) if args.distributed and is_train else None 531 | shuffle = is_train and sampler is None 532 | 533 | dataloader = DataLoader( 534 | dataset, 535 | batch_size=args.batch_size, 536 | shuffle=shuffle, 537 | num_workers=args.workers, 538 | pin_memory=True, 539 | sampler=sampler, 540 | drop_last=is_train, 541 | ) 542 | dataloader.num_samples = num_samples 543 | dataloader.num_batches = len(dataloader) 544 | 545 | return DataInfo(dataloader, sampler) 546 | 547 | 548 | def get_dataset_fn(data_path, dataset_type): 549 | if dataset_type == 'imgnet_r': 550 | return get_imgnet_r 551 | elif dataset_type == 'fashion-iq': 552 | return get_fashion_iq 553 | elif dataset_type == 'cirr': 554 | return get_cirr 555 | elif dataset_type == 'directory': 556 | return get_directory_dataset 557 | elif dataset_type == "csv": 558 | return get_csv_dataset 559 | elif dataset_type == "auto": 560 | ext = data_path.split('.')[-1] 561 | if ext in ['csv', 'tsv']: 562 | return get_csv_dataset 563 | else: 564 | raise ValueError( 565 | f"Tried to figure out dataset type, but failed for extention {ext}.") 566 | else: 567 | raise ValueError(f"Unsupported dataset type: {dataset_type}") 568 | 569 | 570 | def get_data(args, preprocess_fns): 571 | preprocess_train, preprocess_val = preprocess_fns 572 | data = {} 573 | dataset_type_val = getattr(args, 'dataset_type_val', args.dataset_type) 574 | if args.train_data: 575 | data["train"] = get_dataset_fn(args.train_data, args.dataset_type)( 576 | args, preprocess_train, is_train=True) 577 | if args.val_data: 578 | data["val"] = get_dataset_fn(args.val_data, dataset_type_val)( 579 | args, preprocess_val, is_train=False) 580 | if args.imagenet_val is not None: 581 | data["imagenet-val"] = get_imagenet(args, preprocess_fns, "val") 582 | if args.imagenet_v2 is not None: 583 | data["imagenet-v2"] = get_imagenet(args, preprocess_fns, "v2") 584 | return data 585 | -------------------------------------------------------------------------------- /src/demo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import os 15 | import time 16 | import logging 17 | from time import gmtime, strftime 18 | from pathlib import Path 19 | import json 20 | import torch 21 | import torch.distributed as dist 22 | import torch.multiprocessing as mp 23 | import torch.backends.cudnn as cudnn 24 | from torch.utils.tensorboard import SummaryWriter 25 | from torch.utils.data import DataLoader 26 | from model.clip import _transform, load 27 | from model.model import convert_weights, CLIP, IM2TEXT 28 | from eval_utils import visualize_results 29 | from data import get_data, CsvDataset, CustomFolder, CIRR, FashionIQ, ImageList 30 | from params import parse_args, get_project_root 31 | from logger import setup_primary_logging, setup_worker_logging 32 | from utils import is_master, convert_models_to_fp32, TargetPad 33 | 34 | def main_worker(gpu, ngpus_per_node, log_queue, args): 35 | args.gpu = gpu 36 | args.rank = gpu 37 | setup_worker_logging(args.rank, log_queue, args.log_level) 38 | 39 | # Log and save params. 40 | if is_master(args): 41 | logging.info("Params:") 42 | params_file = os.path.join(args.logs, args.name, "params.txt") 43 | with open(params_file, "w") as f: 44 | for name in sorted(vars(args)): 45 | val = getattr(args, name) 46 | logging.info(f"{name}: {val}") 47 | f.write(f"{name}: {val}\n") 48 | 49 | if args.distributed: 50 | dist.init_process_group( 51 | backend=args.dist_backend, 52 | init_method=args.dist_url, 53 | world_size=args.world_size, 54 | rank=args.rank, 55 | ) 56 | 57 | if args.dp: 58 | args.batch_size *= args.world_size 59 | 60 | if args.gpu is not None: 61 | logging.info(f"Use GPU: {args.gpu} for training") 62 | torch.cuda.set_device(args.gpu) 63 | 64 | # Do not use skip_reset unless you want to use on of the CLIP model 65 | if args.openai_pretrained: 66 | model, preprocess_train, preprocess_val = load( 67 | args.model, 68 | jit=False) 69 | else: 70 | model_config_file = Path(__file__).parent / f"model_configs/{args.model.replace('/', '-')}.json" 71 | print('Loading model from', model_config_file) 72 | assert os.path.exists(model_config_file) 73 | with open(model_config_file, 'r') as f: 74 | model_info = json.load(f) 75 | if args.use_prefix: 76 | model_info['vocab_size'] += 1 77 | model_info['use_prefix'] = True 78 | model = CLIP(**model_info) 79 | convert_weights(model) 80 | preprocess_train = _transform(model.visual.input_resolution, is_train=True) 81 | preprocess_val = _transform(model.visual.input_resolution, is_train=False) 82 | img2text = IM2TEXT(embed_dim=model.embed_dim, output_dim=model.token_embedding.weight.shape[1]) 83 | 84 | # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372 85 | if args.precision == "amp" or args.precision == "fp32" or args.gpu is None: 86 | convert_models_to_fp32(model) 87 | 88 | if not torch.cuda.is_available(): 89 | model.float() 90 | img2text.float() 91 | logging.warning("using CPU, this will be slow") 92 | else: 93 | model.cuda(args.gpu) 94 | img2text.cuda(args.gpu) 95 | if args.precision == "fp16": 96 | convert_weights(model) 97 | convert_weights(img2text) 98 | # Previously batch size and workers were global and not per GPU. 99 | # args.batch_size = args.batch_size / ngpus_per_node) 100 | # args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 101 | 102 | if args.distributed and args.use_bn_sync: 103 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 104 | if args.distributed: 105 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=model.has_extra) 106 | img2text = torch.nn.parallel.DistributedDataParallel(img2text, device_ids=[args.gpu], find_unused_parameters=False) 107 | if args.dp: 108 | model = torch.nn.DataParallel(model, device_ids=args.multigpu) 109 | img2text = torch.nn.DataParallel(img2text, device_ids=args.multigpu) 110 | 111 | if args.precision == "fp16": 112 | convert_weights(model) 113 | convert_weights(img2text) 114 | 115 | data = get_data(args, (preprocess_train, preprocess_val)) 116 | if args.resume == 'auto': 117 | checkpoint_list = os.listdir(args.checkpoint_path) 118 | checkpoint_list = [ckpt for ckpt in checkpoint_list if ckpt.startswith('epoch')] 119 | if checkpoint_list: 120 | latest_epoch = max([int(ckpt.split('_')[1].split('.')[0]) for ckpt in checkpoint_list]) 121 | args.resume = os.path.join(args.checkpoint_path, f'epoch_{latest_epoch}.pt') 122 | else: 123 | args.resume = None 124 | 125 | if args.resume is not None: 126 | if os.path.isfile(args.resume): 127 | if args.gpu is None: 128 | checkpoint = torch.load(args.resume) 129 | else: 130 | # Map model to be loaded to specified single gpu. 131 | loc = "cuda:{}".format(args.gpu) 132 | checkpoint = torch.load(args.resume, map_location=loc) 133 | sd = checkpoint["state_dict"] 134 | sd_img2text = checkpoint["state_dict_img2text"] 135 | if not args.distributed and next(iter(sd.items()))[0].startswith('module'): 136 | sd = {k[len('module.'):]: v for k, v in sd.items()} 137 | if not args.distributed and next(iter(sd_img2text.items()))[0].startswith('module'): 138 | sd_img2text = {k[len('module.'):]: v for k, v in sd_img2text.items()} 139 | model.load_state_dict(sd) 140 | img2text.load_state_dict(sd_img2text) 141 | logging.info( 142 | f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})" 143 | ) 144 | else: 145 | logging.info("=> no checkpoint found at '{}'".format(args.resume)) 146 | cudnn.benchmark = True 147 | cudnn.deterministic = False 148 | prompt = args.prompts.split(",") 149 | root_project = os.path.join(get_project_root(), 'data') 150 | logging.info("root dir '{}'".format(root_project)) 151 | logging.info("prompt list '{}'".format(prompt)) 152 | 153 | if "csv" in args.retrieval_data: 154 | dataset = CsvDataset( 155 | args.retrieval_data, 156 | preprocess_val, 157 | img_key=args.csv_img_key, 158 | caption_key=args.csv_caption_key, 159 | sep=args.csv_separator, 160 | return_filename=True) 161 | elif args.retrieval_data == 'imgnet': 162 | target_path = os.path.join(root_project, "imgnet", "imgnet_targets.txt") 163 | dataset = ImageList(target_path, root=root_project, transforms=preprocess_val, 164 | is_labels=True, return_filename=True) 165 | elif args.retrieval_data == 'cirr': 166 | dataset = CIRR( 167 | transforms=preprocess_val, 168 | root=root_project, 169 | mode='caps', 170 | vis_mode=True, 171 | ) 172 | elif args.retrieval_data in ['dress', 'shirt', 'toptee']: 173 | dataset = FashionIQ(cloth=args.retrieval_data, 174 | transforms=preprocess_val, 175 | root=root_project, 176 | mode='caps', 177 | vis_mode=True) 178 | elif args.retrieval_data == 'coco': 179 | dataset = CustomFolder(os.path.join(root_project, "coco/val2017"), transform=preprocess_val) 180 | else: 181 | raise ValueError 182 | dataloader = DataLoader( 183 | dataset, 184 | batch_size=args.batch_size, 185 | shuffle=False, 186 | num_workers=args.workers, 187 | pin_memory=True, 188 | drop_last=False, 189 | ) 190 | visualize_results(model, img2text, args, prompt, dataloader, ) 191 | 192 | 193 | def main(): 194 | args = parse_args() 195 | 196 | # get the name of the experiments 197 | if args.name is None: 198 | args.name = (f"lr={args.lr}_" 199 | "wd={args.wd}_" 200 | "agg={args.aggregate}_" 201 | "model={args.model}_" 202 | "batchsize={args.batch_size}_workers={args.workers}") 203 | if args.time_suffix: 204 | args.name += "_date=%Y-%m-%d-%H-%M-%S" 205 | args.name = strftime(args.name, gmtime()) 206 | 207 | if args.copy_codebase: 208 | import sys, subprocess 209 | from shutil import copytree, ignore_patterns 210 | new_code_path = os.path.join(args.logs, args.name, "code") 211 | if os.path.exists(new_code_path): 212 | print( 213 | f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment." 214 | ) 215 | return -1 216 | print(f"Copying codebase to {new_code_path}") 217 | current_code_path = os.path.realpath(__file__) 218 | for _ in range(3): 219 | current_code_path = os.path.dirname(current_code_path) 220 | copytree(current_code_path, new_code_path, ignore=ignore_patterns('log', 'logs', 'wandb')) 221 | print("Done copying code.") 222 | os.environ["PYTHONPATH"] = f"{os.environ['PYTHONPATH']}:{os.path.join(new_code_path, 'src')}" 223 | main_file = os.path.join(new_code_path, "src", "training", "main.py") 224 | argv = sys.argv 225 | argv.remove('--copy-codebase') 226 | argv.extend(['--name', args.name]) 227 | command = [sys.executable] + argv 228 | print("Executing command:", " ".join(command)) 229 | subprocess.check_call(command) 230 | return 1 231 | 232 | args.log_path = os.path.join(args.logs, args.name, "out.log") 233 | if os.path.exists(args.log_path) and args.resume is None: 234 | print( 235 | "Error. Experiment already exists. Use --name {} to specify a new experiment." 236 | ) 237 | return -1 238 | 239 | assert args.precision in ['amp', 'fp16', 'fp32'] 240 | #assert args.model in ['RN50', 'RN101', 'RN50x4', 'ViT-B/32'] or os.path.exists(args.model) 241 | 242 | args.ngpus_per_node = torch.cuda.device_count() 243 | 244 | args.wandb = 'wandb' in args.report_to or 'all' in args.report_to 245 | args.tensorboard = 'tensorboard' in args.report_to or 'all' in args.report_to 246 | 247 | args.tensorboard_path = os.path.join(args.logs, args.name, "tensorboard") if args.tensorboard else '' 248 | args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints") 249 | for dirname in [args.tensorboard_path, args.checkpoint_path]: 250 | if dirname: 251 | os.makedirs(dirname, exist_ok=True) 252 | 253 | 254 | # Set multiprocessing type to spawn. 255 | # This is important for logging to work with multiprocessing. 256 | torch.multiprocessing.set_start_method("spawn") 257 | 258 | # Set logger 259 | args.log_level = logging.DEBUG if args.debug else logging.INFO 260 | log_queue = setup_primary_logging(args.log_path, args.log_level) 261 | 262 | # Distributed training = training on more than one GPU. 263 | # Also easily possible to extend to multiple nodes & multiple GPUs. 264 | args.distributed = (args.gpu is None) and torch.cuda.is_available() and (not args.dp) 265 | if args.distributed: 266 | ngpus_per_node = torch.cuda.device_count() 267 | args.world_size = ngpus_per_node 268 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, log_queue, args)) 269 | else: 270 | if args.dp: 271 | args.gpu = args.multigpu[0] 272 | args.world_size = len(args.multigpu) 273 | else: 274 | args.world_size = 1 275 | main_worker(args.gpu, None, log_queue, args) 276 | 277 | 278 | if __name__ == "__main__": 279 | main() 280 | -------------------------------------------------------------------------------- /src/eval_retrieval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import time 17 | import logging 18 | from time import gmtime, strftime 19 | from pathlib import Path 20 | import json 21 | from functools import partial 22 | import wandb 23 | import torch 24 | from torch import optim 25 | import torch.distributed as dist 26 | import torch.multiprocessing as mp 27 | import torch.backends.cudnn as cudnn 28 | from torch.utils.tensorboard import SummaryWriter 29 | from torch.cuda.amp import GradScaler 30 | from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler 31 | import torchvision.datasets as datasets 32 | import torchvision.transforms as T 33 | from PIL import Image 34 | 35 | from model.clip import _transform, load 36 | from model.model import convert_weights, CLIP, IM2TEXT 37 | from eval_utils import evaluate_imgnet_retrieval, evaluate_coco, evaluate_fashion, evaluate_cirr, evaluate_cirr_test 38 | from data import CsvDataset, CustomFolder, ImageList, CsvCOCO, FashionIQ, CIRR 39 | from params import parse_args, get_project_root 40 | from logger import setup_primary_logging, setup_worker_logging 41 | from utils import is_master, convert_models_to_fp32, TargetPad 42 | 43 | def load_model(args): 44 | model, _, preprocess_val = load( 45 | args.model, 46 | jit=False) 47 | img2text = IM2TEXT(embed_dim=model.embed_dim, 48 | middle_dim=args.middle_dim, 49 | output_dim=model.token_embedding.weight.shape[1], 50 | n_layer=args.n_layer) 51 | # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372 52 | if args.precision == "amp" or args.precision == "fp32" or args.gpu is None: 53 | convert_models_to_fp32(model) 54 | 55 | if not torch.cuda.is_available(): 56 | model.float() 57 | img2text.float() 58 | logging.warning("using CPU, this will be slow") 59 | else: 60 | model.cuda(args.gpu) 61 | img2text.cuda(args.gpu) 62 | if args.precision == "fp16": 63 | convert_weights(model) 64 | convert_weights(img2text) 65 | # Previously batch size and workers were global and not per GPU. 66 | # args.batch_size = args.batch_size / ngpus_per_node) 67 | # args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 68 | if args.distributed and args.use_bn_sync: 69 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 70 | if args.distributed: 71 | model = torch.nn.parallel.DistributedDataParallel(model, 72 | device_ids=[args.gpu], 73 | find_unused_parameters=model.has_extra) 74 | img2text = torch.nn.parallel.DistributedDataParallel(img2text, 75 | device_ids=[args.gpu], find_unused_parameters=False) 76 | if args.dp: 77 | model = torch.nn.DataParallel(model, device_ids=args.multigpu) 78 | img2text = torch.nn.DataParallel(img2text, device_ids=args.multigpu) 79 | 80 | if args.precision == "fp16": 81 | convert_weights(model) 82 | convert_weights(img2text) 83 | if args.resume == 'auto': 84 | checkpoint_list = os.listdir(args.checkpoint_path) 85 | checkpoint_list = [ckpt for ckpt in checkpoint_list if ckpt.startswith('epoch')] 86 | if checkpoint_list: 87 | latest_epoch = max([int(ckpt.split('_')[1].split('.')[0]) for ckpt in checkpoint_list]) 88 | args.resume = os.path.join(args.checkpoint_path, f'epoch_{latest_epoch}.pt') 89 | else: 90 | args.resume = None 91 | 92 | assert args.resume is not None 93 | if os.path.isfile(args.resume): 94 | if args.gpu is None: 95 | checkpoint = torch.load(args.resume) 96 | else: 97 | # Map model to be loaded to specified single gpu. 98 | loc = "cuda:{}".format(args.gpu) 99 | checkpoint = torch.load(args.resume, map_location=loc) 100 | sd = checkpoint["state_dict"] 101 | sd_img2text = checkpoint["state_dict_img2text"] 102 | if not args.distributed and next(iter(sd.items()))[0].startswith('module'): 103 | sd = {k[len('module.'):]: v for k, v in sd.items()} 104 | if not args.distributed and next(iter(sd_img2text.items()))[0].startswith('module'): 105 | sd_img2text = {k[len('module.'):]: v for k, v in sd_img2text.items()} 106 | model.load_state_dict(sd) 107 | img2text.load_state_dict(sd_img2text) 108 | logging.info( 109 | f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})" 110 | ) 111 | else: 112 | logging.info("=> no checkpoint found at '{}'".format(args.resume)) 113 | return model, img2text, preprocess_val 114 | 115 | def setup_log_save(args): 116 | if is_master(args): 117 | logging.info("Params:") 118 | params_file = os.path.join(args.logs, args.name, "params.txt") 119 | with open(params_file, "w") as f: 120 | for name in sorted(vars(args)): 121 | val = getattr(args, name) 122 | logging.info(f"{name}: {val}") 123 | f.write(f"{name}: {val}\n") 124 | 125 | if args.distributed: 126 | dist.init_process_group( 127 | backend=args.dist_backend, 128 | init_method=args.dist_url, 129 | world_size=args.world_size, 130 | rank=args.rank, 131 | ) 132 | if args.dp: 133 | args.batch_size *= args.world_size 134 | if args.gpu is not None: 135 | logging.info(f"Use GPU: {args.gpu} for training") 136 | torch.cuda.set_device(args.gpu) 137 | 138 | 139 | def main_worker(gpu, ngpus_per_node, log_queue, args): 140 | args.gpu = gpu 141 | args.rank = gpu 142 | setup_worker_logging(args.rank, log_queue, args.log_level) 143 | # Log and save params. 144 | setup_log_save(args) 145 | # Load trained model 146 | model, img2text, preprocess_val = load_model(args) 147 | cudnn.benchmark = True 148 | cudnn.deterministic = False 149 | root_project = os.path.join(get_project_root(), 'data') 150 | ## Padding option 151 | if args.target_pad: 152 | trans_tmp = preprocess_val.transforms 153 | trans_tmp = [TargetPad(1.25)] + trans_tmp 154 | preprocess_train = T.Compose(trans_tmp) 155 | preprocess_val = preprocess_train 156 | 157 | ## Load data for each evaluation dataset and perform evaluation. 158 | if args.eval_mode == 'coco': 159 | trans_val = preprocess_val.transforms 160 | n_px = trans_val[1].size 161 | trans_val = [T.Resize(n_px, interpolation=Image.BICUBIC)] + trans_val[2:] 162 | preprocess_val_region = T.Compose(trans_val) 163 | source_dataset = CsvCOCO(transforms=preprocess_val, 164 | transforms_region=preprocess_val_region, 165 | root=root_project) 166 | source_dataloader = DataLoader( 167 | source_dataset, 168 | batch_size=args.batch_size, 169 | shuffle=False, 170 | num_workers=args.workers, 171 | pin_memory=True, 172 | drop_last=False) 173 | evaluate_coco(model, img2text, args, source_dataloader) 174 | 175 | elif args.eval_mode == 'cirr': 176 | source_dataset = CIRR(transforms=preprocess_val, 177 | root=root_project) 178 | target_dataset = CIRR(transforms=preprocess_val, 179 | root=root_project, 180 | mode='imgs') 181 | source_dataloader = DataLoader( 182 | source_dataset, 183 | batch_size=args.batch_size, 184 | shuffle=False, 185 | num_workers=args.workers, 186 | pin_memory=True, 187 | drop_last=False) 188 | target_dataloader = DataLoader( 189 | target_dataset, 190 | batch_size=args.batch_size, 191 | shuffle=False, 192 | num_workers=args.workers, 193 | pin_memory=True, 194 | drop_last=False) 195 | evaluate_cirr(model, 196 | img2text, 197 | args, 198 | source_dataloader, 199 | target_dataloader) 200 | 201 | elif args.eval_mode == 'cirr_test': 202 | source_dataset = CIRR(transforms=preprocess_val, 203 | root=root_project, test=True) 204 | target_dataset = CIRR(transforms=preprocess_val, 205 | root=root_project, 206 | mode='imgs', 207 | test=True) 208 | source_dataloader = DataLoader( 209 | source_dataset, 210 | batch_size=args.batch_size, 211 | shuffle=False, 212 | num_workers=args.workers, 213 | pin_memory=True, 214 | drop_last=False) 215 | target_dataloader = DataLoader( 216 | target_dataset, 217 | batch_size=args.batch_size, 218 | shuffle=False, 219 | num_workers=args.workers, 220 | pin_memory=True, 221 | drop_last=False) 222 | results = evaluate_cirr_test(model, 223 | img2text, 224 | args, 225 | source_dataloader, 226 | target_dataloader) 227 | for key, value in results.items(): 228 | with open('res_cirr/' + key + '.json', 'w') as f: 229 | json.dump(value, f) 230 | 231 | elif args.eval_mode == 'fashion': 232 | assert args.source_data in ['dress', 'shirt', 'toptee'] 233 | source_dataset = FashionIQ(cloth=args.source_data, 234 | transforms=preprocess_val, 235 | root=root_project, 236 | is_return_target_path=True) 237 | target_dataset = FashionIQ(cloth=args.source_data, 238 | transforms=preprocess_val, 239 | root=root_project, 240 | mode='imgs') 241 | source_dataloader = DataLoader( 242 | source_dataset, 243 | batch_size=args.batch_size, 244 | shuffle=False, 245 | num_workers=args.workers, 246 | pin_memory=True, 247 | drop_last=False) 248 | target_dataloader = DataLoader( 249 | target_dataset, 250 | batch_size=args.batch_size, 251 | shuffle=False, 252 | num_workers=args.workers, 253 | pin_memory=True, 254 | drop_last=False) 255 | evaluate_fashion(model, img2text, args, source_dataloader, target_dataloader) 256 | elif args.eval_mode == 'imgnet': 257 | domains = ['cartoon', 'origami', 'toy', 'sculpture'] 258 | prompt = ["a {} of *".format(domain) for domain in domains] 259 | source_path = os.path.join(root_project, "imgnet", "imgnet_real_query.txt") 260 | target_path = os.path.join(root_project, "imgnet", "imgnet_targets.txt") 261 | source_dataset = ImageList(source_path, root=root_project, transforms=preprocess_val, is_labels=True) 262 | target_dataset = ImageList(target_path, root=root_project, transforms=preprocess_val, is_labels=True) 263 | eval_func = evaluate_imgnet_retrieval 264 | source_dataloader = DataLoader( 265 | source_dataset, 266 | batch_size=args.batch_size, 267 | shuffle=False, 268 | num_workers=args.workers, 269 | pin_memory=True, 270 | drop_last=False) 271 | target_dataloader = DataLoader( 272 | target_dataset, 273 | batch_size=args.batch_size, 274 | shuffle=False, 275 | num_workers=args.workers, 276 | pin_memory=True, 277 | drop_last=False) 278 | eval_func(model, img2text, args, prompt, source_dataloader, target_dataloader) 279 | 280 | def main(): 281 | args = parse_args() 282 | 283 | # get the name of the experiments 284 | if args.name is None: 285 | args.name = (f"lr={args.lr}_" 286 | "wd={args.wd}_" 287 | "agg={args.aggregate}_" 288 | "model={args.model}_" 289 | "batchsize={args.batch_size}_workers={args.workers}") 290 | if args.time_suffix: 291 | args.name += "_date=%Y-%m-%d-%H-%M-%S" 292 | args.name = strftime(args.name, gmtime()) 293 | 294 | if args.copy_codebase: 295 | import sys, subprocess 296 | from shutil import copytree, ignore_patterns 297 | new_code_path = os.path.join(args.logs, args.name, "code") 298 | if os.path.exists(new_code_path): 299 | print( 300 | f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment." 301 | ) 302 | return -1 303 | print(f"Copying codebase to {new_code_path}") 304 | current_code_path = os.path.realpath(__file__) 305 | for _ in range(3): 306 | current_code_path = os.path.dirname(current_code_path) 307 | copytree(current_code_path, new_code_path, ignore=ignore_patterns('log', 'logs', 'wandb')) 308 | print("Done copying code.") 309 | os.environ["PYTHONPATH"] = f"{os.environ['PYTHONPATH']}:{os.path.join(new_code_path, 'src')}" 310 | main_file = os.path.join(new_code_path, "src", "training", "main.py") 311 | argv = sys.argv 312 | argv.remove('--copy-codebase') 313 | argv.extend(['--name', args.name]) 314 | command = [sys.executable] + argv 315 | print("Executing command:", " ".join(command)) 316 | subprocess.check_call(command) 317 | return 1 318 | 319 | args.log_path = os.path.join(args.logs, args.name, "out.log") 320 | if os.path.exists(args.log_path) and args.resume is None: 321 | print( 322 | "Error. Experiment already exists. Use --name {} to specify a new experiment." 323 | ) 324 | return -1 325 | 326 | assert args.precision in ['amp', 'fp16', 'fp32'] 327 | #assert args.model in ['RN50', 'RN101', 'RN50x4', 'ViT-B/32'] or os.path.exists(args.model) 328 | 329 | args.ngpus_per_node = torch.cuda.device_count() 330 | 331 | args.wandb = 'wandb' in args.report_to or 'all' in args.report_to 332 | args.tensorboard = 'tensorboard' in args.report_to or 'all' in args.report_to 333 | 334 | args.tensorboard_path = os.path.join(args.logs, args.name, "tensorboard") if args.tensorboard else '' 335 | args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints") 336 | for dirname in [args.tensorboard_path, args.checkpoint_path]: 337 | if dirname: 338 | os.makedirs(dirname, exist_ok=True) 339 | 340 | 341 | # Set multiprocessing type to spawn. 342 | # This is important for logging to work with multiprocessing. 343 | torch.multiprocessing.set_start_method("spawn") 344 | 345 | # Set logger 346 | args.log_level = logging.DEBUG if args.debug else logging.INFO 347 | log_queue = setup_primary_logging(args.log_path, args.log_level) 348 | args.world_size = 1 349 | try: 350 | main_worker(args.gpu, None, log_queue, args) 351 | except: 352 | print('evaluation done') 353 | 354 | 355 | if __name__ == "__main__": 356 | main() 357 | -------------------------------------------------------------------------------- /src/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import argparse 16 | import logging 17 | from logging import Filter 18 | from logging.handlers import QueueHandler, QueueListener 19 | 20 | import torch 21 | import torch.distributed as dist 22 | import torch.multiprocessing as mp 23 | from torch.multiprocessing import Queue 24 | 25 | 26 | def setup_primary_logging(log_file, level): 27 | log_queue = Queue(-1) 28 | 29 | file_handler = logging.FileHandler(filename=log_file) 30 | stream_handler = logging.StreamHandler() 31 | 32 | formatter = logging.Formatter( 33 | '%(asctime)s | %(levelname)s | %(message)s', 34 | datefmt='%Y-%m-%d,%H:%M:%S') 35 | 36 | file_handler.setFormatter(formatter) 37 | stream_handler.setFormatter(formatter) 38 | 39 | file_handler.setLevel(level) 40 | stream_handler.setLevel(level) 41 | 42 | listener = QueueListener(log_queue, file_handler, stream_handler) 43 | 44 | listener.start() 45 | 46 | return log_queue 47 | 48 | 49 | class WorkerLogFilter(Filter): 50 | def __init__(self, rank=-1): 51 | super().__init__() 52 | self._rank = rank 53 | 54 | def filter(self, record): 55 | if self._rank != -1: 56 | record.msg = f"Rank {self._rank} | {record.msg}" 57 | return True 58 | 59 | 60 | def setup_worker_logging(rank, log_queue, level): 61 | queue_handler = QueueHandler(log_queue) 62 | 63 | worker_filter = WorkerLogFilter(rank) 64 | queue_handler.addFilter(worker_filter) 65 | 66 | queue_handler.setLevel(level) 67 | 68 | root_logger = logging.getLogger() 69 | root_logger.addHandler(queue_handler) 70 | 71 | root_logger.setLevel(level) 72 | 73 | 74 | def fake_worker(rank: int, world_size: int, log_queue: Queue): 75 | setup_worker_logging(rank, log_queue, logging.DEBUG) 76 | logging.info("Test worker log") 77 | logging.error("Test worker error log") 78 | torch.cuda.set_device(rank) 79 | dist.init_process_group( 80 | backend='nccl', 81 | init_method='tcp://127.0.0.1:6100', 82 | world_size=world_size, 83 | rank=rank, 84 | ) 85 | 86 | if __name__ == "__main__": 87 | # Set multiprocessing type to spawn 88 | torch.multiprocessing.set_start_method("spawn") 89 | 90 | parser = argparse.ArgumentParser() 91 | parser.add_argument("-g", "--gpu-list", type=int, help="List of GPU IDs", nargs="+", required=True) 92 | 93 | args = parser.parse_args() 94 | 95 | world_size = len(args.gpu_list) 96 | 97 | # Initialize the primary logging handlers. Use the returned `log_queue` 98 | # to which the worker processes would use to push their messages 99 | log_queue = setup_primary_logging("/usr/lusers/gamaga/out.log", logging.DEBUG) 100 | 101 | if world_size == 1: 102 | worker(0, world_size, log_queue) 103 | else: 104 | mp.spawn(fake_worker, args=(world_size, log_queue), nprocs=world_size) -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import os 15 | import time 16 | import logging 17 | from time import gmtime, strftime 18 | from pathlib import Path 19 | import json 20 | import wandb 21 | import torch 22 | from torch import optim 23 | import torch.distributed as dist 24 | import torch.multiprocessing as mp 25 | import torch.backends.cudnn as cudnn 26 | from torch.utils.tensorboard import SummaryWriter 27 | from torch.cuda.amp import GradScaler 28 | from third_party.open_clip.scheduler import cosine_lr 29 | from model.clip import _transform, load 30 | from model.model import convert_weights, CLIP, IM2TEXT 31 | from trainer import train 32 | from data import get_data 33 | from params import parse_args 34 | from logger import setup_primary_logging, setup_worker_logging 35 | from utils import is_master, convert_models_to_fp32 36 | import torchvision.transforms as T 37 | 38 | def main_worker(gpu, ngpus_per_node, log_queue, args): 39 | args.gpu = gpu 40 | args.rank = gpu 41 | setup_worker_logging(args.rank, log_queue, args.log_level) 42 | 43 | # Log and save params. 44 | if is_master(args): 45 | logging.info("Params:") 46 | params_file = os.path.join(args.logs, args.name, "params.txt") 47 | with open(params_file, "w") as f: 48 | for name in sorted(vars(args)): 49 | val = getattr(args, name) 50 | logging.info(f"{name}: {val}") 51 | f.write(f"{name}: {val}\n") 52 | 53 | if args.distributed: 54 | dist.init_process_group( 55 | backend=args.dist_backend, 56 | init_method=args.dist_url, 57 | world_size=args.world_size, 58 | rank=args.rank, 59 | ) 60 | 61 | if args.dp: 62 | args.batch_size *= args.world_size 63 | 64 | if args.gpu is not None: 65 | logging.info(f"Use GPU: {args.gpu} for training") 66 | torch.cuda.set_device(args.gpu) 67 | 68 | # Do not use skip_reset unless you want to use on of the CLIP model 69 | if args.openai_pretrained: 70 | model, preprocess_train, preprocess_val = load( 71 | args.model, 72 | jit=False) 73 | else: 74 | model_config_file = Path(__file__).parent / f"model_configs/{args.model.replace('/', '-')}.json" 75 | print('Loading model from', model_config_file) 76 | assert os.path.exists(model_config_file) 77 | with open(model_config_file, 'r') as f: 78 | model_info = json.load(f) 79 | if args.use_prefix: 80 | model_info['vocab_size'] += 1 81 | model_info['use_prefix'] = True 82 | model = CLIP(**model_info) 83 | convert_weights(model) 84 | preprocess_train = _transform(model.visual.input_resolution, is_train=True) 85 | preprocess_val = _transform(model.visual.input_resolution, is_train=False) 86 | try: 87 | img2text = IM2TEXT(embed_dim=model.embed_dim, 88 | middle_dim=args.middle_dim, 89 | output_dim=model.token_embedding.weight.shape[1], 90 | n_layer=args.n_layer) 91 | except: 92 | img2text = IM2TEXT(embed_dim=1024, output_dim=1024, 93 | is_normalize=args.normalize_output, is_mlp=args.use_mlp, n_layer=args.n_layer) 94 | 95 | # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372 96 | if args.precision == "amp" or args.precision == "fp32" or args.gpu is None: 97 | convert_models_to_fp32(model) 98 | 99 | if not torch.cuda.is_available(): 100 | model.float() 101 | img2text.float() 102 | logging.warning("using CPU, this will be slow") 103 | else: 104 | model.cuda(args.gpu) 105 | img2text.cuda(args.gpu) 106 | if args.precision == "fp16": 107 | convert_weights(model) 108 | convert_weights(img2text) 109 | # Previously batch size and workers were global and not per GPU. 110 | # args.batch_size = args.batch_size / ngpus_per_node) 111 | # args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 112 | 113 | if args.distributed and args.use_bn_sync: 114 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 115 | if args.distributed: 116 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False) 117 | img2text = torch.nn.parallel.DistributedDataParallel(img2text, device_ids=[args.gpu], find_unused_parameters=False) 118 | if args.dp: 119 | model = torch.nn.DataParallel(model, device_ids=args.multigpu) 120 | img2text = torch.nn.DataParallel(img2text, device_ids=args.multigpu) 121 | 122 | if args.precision == "fp16": 123 | convert_weights(model) 124 | convert_weights(img2text) 125 | 126 | data = get_data(args, (preprocess_train, preprocess_val)) 127 | exclude = lambda n : "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n 128 | include = lambda n : not exclude(n) 129 | named_parameters = list(img2text.named_parameters()) 130 | gain_or_bias_params = [p for n, p in named_parameters if exclude(n) and p.requires_grad] 131 | rest_params = [p for n, p in named_parameters if include(n) and p.requires_grad] 132 | 133 | if args.train_data is None: 134 | optimizer = None 135 | scheduler = None 136 | else: 137 | optimizer = optim.AdamW( 138 | [ 139 | {"params": gain_or_bias_params, "weight_decay": 0.}, 140 | {"params": rest_params, "weight_decay": args.wd}, 141 | ], 142 | lr=args.lr, 143 | betas=(args.beta1, args.beta2), 144 | eps=args.eps, 145 | ) 146 | total_steps = data["train"].dataloader.num_batches * args.epochs 147 | scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps) 148 | 149 | scaler = GradScaler() if args.precision == "amp" else None 150 | 151 | # optionally resume from a checkpoint 152 | start_epoch = 0 153 | if args.resume == 'auto': 154 | checkpoint_list = os.listdir(args.checkpoint_path) 155 | checkpoint_list = [ckpt for ckpt in checkpoint_list if ckpt.startswith('epoch')] 156 | if checkpoint_list: 157 | latest_epoch = max([int(ckpt.split('_')[1].split('.')[0]) for ckpt in checkpoint_list]) 158 | args.resume = os.path.join(args.checkpoint_path, f'epoch_{latest_epoch}.pt') 159 | else: 160 | args.resume = None 161 | 162 | if args.resume is not None: 163 | if os.path.isfile(args.resume): 164 | if args.gpu is None: 165 | checkpoint = torch.load(args.resume) 166 | else: 167 | # Map model to be loaded to specified single gpu. 168 | loc = "cuda:{}".format(args.gpu) 169 | checkpoint = torch.load(args.resume, map_location=loc) 170 | start_epoch = checkpoint["epoch"] 171 | sd = checkpoint["state_dict"] 172 | sd_img2text = checkpoint["state_dict_img2text"] 173 | if not args.distributed and next(iter(sd.items()))[0].startswith('module'): 174 | sd = {k[len('module.'):]: v for k, v in sd.items()} 175 | if not args.distributed and next(iter(sd_img2text.items()))[0].startswith('module'): 176 | sd_img2text = {k[len('module.'):]: v for k, v in sd_img2text.items()} 177 | model.load_state_dict(sd) 178 | img2text.load_state_dict(sd_img2text) 179 | if optimizer is not None: 180 | optimizer.load_state_dict(checkpoint["optimizer"]) 181 | logging.info( 182 | f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})" 183 | ) 184 | else: 185 | logging.info("=> no checkpoint found at '{}'".format(args.resume)) 186 | 187 | cudnn.benchmark = True 188 | cudnn.deterministic = False 189 | # determine if this worker should save logs and checkpoints. 190 | # only do so if it is the 0th worker. 191 | args.save_logs = (args.logs is not None and args.logs != '' and args.logs.lower() != 'none') and ( 192 | (not args.distributed) or args.gpu == 0 193 | ) 194 | writer = None 195 | if args.save_logs and args.tensorboard: 196 | writer = SummaryWriter(args.tensorboard_path) 197 | 198 | if args.wandb and is_master(args): 199 | logging.debug('Starting wandb.') 200 | args.train_sz = data["train"].dataloader.num_samples 201 | if args.val_data is not None: 202 | args.val_sz = data["val"].dataloader.num_samples 203 | # you will have to configure this for your project! 204 | wandb.init( 205 | project="open-clip", 206 | notes=args.wandb_notes, 207 | tags=[], 208 | config=vars(args), 209 | ) 210 | if args.debug: 211 | wandb.watch(model, log='all') 212 | wandb.save(params_file) 213 | logging.debug('Finished loading wandb.') 214 | 215 | for epoch in range(start_epoch, args.epochs): 216 | if args.gpu == 0: 217 | logging.info(f'Start epoch {epoch}') 218 | train(model, img2text, data, epoch, optimizer, scaler, scheduler, args, writer) 219 | steps = data["train"].dataloader.num_batches * (epoch + 1) 220 | # Saving checkpoints. 221 | if args.save_logs and (args.gpu == 0 or (not args.distributed)): 222 | if (epoch + 1) == args.epochs or ( 223 | args.save_frequency > 0 and ((epoch + 1) % args.save_frequency) == 0 224 | ): 225 | torch.save( 226 | { 227 | "epoch": epoch + 1, 228 | "name": args.name, 229 | "state_dict": model.state_dict(), 230 | "state_dict_img2text": img2text.state_dict(), 231 | "optimizer": optimizer.state_dict(), 232 | }, 233 | os.path.join(args.checkpoint_path, f"epoch_{epoch + 1}.pt"), 234 | ) 235 | if args.save_most_recent: 236 | torch.save( 237 | { 238 | "epoch": epoch + 1, 239 | "name": args.name, 240 | "state_dict": model.state_dict(), 241 | "state_dict_img2text": img2text.state_dict(), 242 | "optimizer": optimizer.state_dict(), 243 | }, 244 | os.path.join(args.checkpoint_path, "epoch_latest.pt"), 245 | ) 246 | 247 | if args.wandb and (args.gpu == 0 or (not args.distributed)): 248 | wandb.finish() 249 | 250 | 251 | def main(): 252 | args = parse_args() 253 | 254 | # get the name of the experiments 255 | if args.name is None: 256 | args.name = (f"lr={args.lr}_" 257 | "wd={args.wd}_" 258 | "agg={args.aggregate}_" 259 | "model={args.model}_" 260 | "batchsize={args.batch_size}_workers={args.workers}") 261 | import pdb 262 | pdb.set_trace 263 | if args.time_suffix: 264 | args.name += "_date=%Y-%m-%d-%H-%M-%S" 265 | args.name = strftime(args.name, gmtime()) 266 | 267 | if args.copy_codebase: 268 | import sys, subprocess 269 | from shutil import copytree, ignore_patterns 270 | new_code_path = os.path.join(args.logs, args.name, "code") 271 | if os.path.exists(new_code_path): 272 | print( 273 | f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment." 274 | ) 275 | return -1 276 | print(f"Copying codebase to {new_code_path}") 277 | current_code_path = os.path.realpath(__file__) 278 | for _ in range(3): 279 | current_code_path = os.path.dirname(current_code_path) 280 | copytree(current_code_path, new_code_path, ignore=ignore_patterns('log', 'logs', 'wandb')) 281 | print("Done copying code.") 282 | os.environ["PYTHONPATH"] = f"{os.environ['PYTHONPATH']}:{os.path.join(new_code_path, 'src')}" 283 | main_file = os.path.join(new_code_path, "src", "training", "main.py") 284 | argv = sys.argv 285 | argv.remove('--copy-codebase') 286 | argv.extend(['--name', args.name]) 287 | command = [sys.executable] + argv 288 | print("Executing command:", " ".join(command)) 289 | subprocess.check_call(command) 290 | return 1 291 | 292 | args.log_path = os.path.join(args.logs, args.name, "out.log") 293 | if os.path.exists(args.log_path) and args.resume is None: 294 | print( 295 | "Error. Experiment already exists. Use --name {} to specify a new experiment." 296 | ) 297 | return -1 298 | 299 | assert args.precision in ['amp', 'fp16', 'fp32'] 300 | #assert args.model in ['RN50', 'RN101', 'RN50x4', 'ViT-B/32'] or os.path.exists(args.model) 301 | 302 | args.ngpus_per_node = torch.cuda.device_count() 303 | 304 | args.wandb = 'wandb' in args.report_to or 'all' in args.report_to 305 | args.tensorboard = 'tensorboard' in args.report_to or 'all' in args.report_to 306 | 307 | args.tensorboard_path = os.path.join(args.logs, args.name, "tensorboard") if args.tensorboard else '' 308 | args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints") 309 | for dirname in [args.tensorboard_path, args.checkpoint_path]: 310 | if dirname: 311 | os.makedirs(dirname, exist_ok=True) 312 | 313 | 314 | # Set multiprocessing type to spawn. 315 | # This is important for logging to work with multiprocessing. 316 | torch.multiprocessing.set_start_method("spawn") 317 | 318 | # Set logger 319 | args.log_level = logging.DEBUG if args.debug else logging.INFO 320 | log_queue = setup_primary_logging(args.log_path, args.log_level) 321 | 322 | # Distributed training = training on more than one GPU. 323 | # Also easily possible to extend to multiple nodes & multiple GPUs. 324 | args.distributed = (args.gpu is None) and torch.cuda.is_available() and (not args.dp) 325 | if args.distributed: 326 | ngpus_per_node = torch.cuda.device_count() 327 | args.world_size = ngpus_per_node 328 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, log_queue, args)) 329 | else: 330 | if args.dp: 331 | args.gpu = args.multigpu[0] 332 | args.world_size = len(args.multigpu) 333 | else: 334 | args.world_size = 1 335 | main_worker(args.gpu, None, log_queue, args) 336 | 337 | 338 | if __name__ == "__main__": 339 | main() 340 | -------------------------------------------------------------------------------- /src/params.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import argparse 15 | from pathlib import Path 16 | 17 | def get_project_root(): 18 | return Path(__file__).parent.parent 19 | 20 | def get_default_params(model_name): 21 | # Params from paper (https://arxiv.org/pdf/2103.00020.pdf) 22 | if model_name in ["RN50", "RN101", "RN50x4", "RN50x64", "RN50x16", "RN50_flat", "RN50_t1", "RN50_t2", "RN50_t3", "RN50_t4", "RN50_t5", "RN50_t6", 23 | "RN50_flat_ft", "RN50_t1_pos_ft", "RN50_t2_pos_ft", "RN50_t1_pos", "RN50_t2_pos", 24 | "RN50_flat_large", "RN50_t1_large", "RN50_t2_large", 25 | "RN50_a2", "RN50_a2s", "ViT-H-14"]: 26 | return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8} 27 | elif model_name in ["ViT-B/32", "ViT-L/14", "ViT-B/16"]: 28 | return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6} 29 | else: 30 | return {} 31 | 32 | 33 | def parse_args(): 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument("--no-time-suffix", 36 | default=True, 37 | action="store_false", 38 | help="Whether to append current time in the suffix.", 39 | dest="time_suffix") 40 | parser.add_argument( 41 | "--train-data", 42 | type=str, 43 | default=None, 44 | help="Path to csv filewith training data", 45 | ) 46 | parser.add_argument( 47 | "--val-data", 48 | type=str, 49 | default=None, 50 | help="Path to csv file with validation data", 51 | ) 52 | parser.add_argument( 53 | "--prompts", 54 | type=str, 55 | default=None, 56 | help="list of prompts split with ,", 57 | ) 58 | parser.add_argument( 59 | "--retrieval-data", 60 | type=str, 61 | default=None, 62 | help="Path to csv file or folder of retrieval data", 63 | ) 64 | parser.add_argument( 65 | "--demo-out", 66 | type=str, 67 | default="demo", 68 | help="Path to the output directory for visualization", 69 | ) 70 | parser.add_argument( 71 | "--source-data", 72 | type=str, 73 | default=None, 74 | help="Path to txt file of retrieval data", 75 | ) 76 | parser.add_argument( 77 | "--target-data", 78 | type=str, 79 | default=None, 80 | help="Path to txt file of retrieval data", 81 | ) 82 | parser.add_argument( 83 | "--target-pad", 84 | action="store_true", 85 | default=False, 86 | help="Padding augmentation proposed by combiner.", 87 | ) 88 | parser.add_argument( 89 | "--query_file", 90 | type=str, 91 | default=None, 92 | help="Path to query image file for retrieval visualization", 93 | ) 94 | parser.add_argument("--eval-mode", 95 | type=str, 96 | choices=["coco", "cirr", "cirr_test", "fashion", "imgnet"], 97 | default="coco", 98 | help="Evaluate Pacs") 99 | parser.add_argument("--middle_dim", 100 | default=512, 101 | type=int, 102 | help="Number of hidden units in mapping network.") 103 | parser.add_argument("--droprate", 104 | default=0.1, 105 | type=float, 106 | help="Dropout rate.") 107 | parser.add_argument( 108 | "--n-layer", type=int, default=2, help="Number of layers in im2text" 109 | ) 110 | parser.add_argument( 111 | "--dataset-type", 112 | choices=["webdataset", "csv", "inet", "auto", "inet,csv", "csv,inet", "directory", "fashion-iq", "cirr", "imgnet_r"], 113 | default="auto", 114 | help="Which type of dataset to process." 115 | ) 116 | parser.add_argument( 117 | "--dataset-type-val", 118 | choices=["webdataset", "csv", "inet", "auto"], 119 | default="auto", 120 | help="Which type of dataset to process." 121 | ) 122 | parser.add_argument( 123 | "--csv-separator", 124 | type=str, 125 | default="\t", 126 | help="For csv-like datasets, which separator to use." 127 | ) 128 | parser.add_argument( 129 | "--csv-img-key", 130 | type=str, 131 | default="filepath", 132 | help="For csv-like datasets, the name of the key for the image paths." 133 | ) 134 | parser.add_argument( 135 | "--csv-caption-key", 136 | type=str, 137 | default="title", 138 | help="For csv-like datasets, the name of the key for the captions." 139 | ) 140 | parser.add_argument( 141 | "--imagenet-val", 142 | type=str, 143 | default=None, 144 | help="Path to imagenet val set for conducting zero shot evaluation.", 145 | ) 146 | parser.add_argument( 147 | "--imagenet-v2", 148 | type=str, 149 | default=None, 150 | help="Path to imagenet v2 for conducting zero shot evaluation.", 151 | ) 152 | parser.add_argument( 153 | "--logs", 154 | type=str, 155 | default="./logs/", 156 | help="Where to store tensorboard logs. Use None to avoid storing logs.", 157 | ) 158 | parser.add_argument( 159 | "--name", 160 | type=str, 161 | default=None, 162 | help="Optional identifier for the experiment when storing logs. Otherwise use current time.", 163 | ) 164 | parser.add_argument( 165 | "--workers", type=int, default=1, help="Number of workers per GPU." 166 | ) 167 | parser.add_argument( 168 | "--batch-size", type=int, default=64, help="Batch size per GPU." 169 | ) 170 | parser.add_argument( 171 | "--epochs", type=int, default=32, help="Number of epochs to train for." 172 | ) 173 | parser.add_argument("--lr", type=float, default=None, help="Learning rate.") 174 | parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.") 175 | parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.") 176 | parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.") 177 | parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.") 178 | parser.add_argument( 179 | "--warmup", type=int, default=10000, help="Number of steps to warmup for." 180 | ) 181 | parser.add_argument("--use-bn-sync", 182 | default=False, 183 | action="store_true", 184 | help="Whether to use batch norm sync.") 185 | parser.add_argument("--use-debiased-sampler", 186 | default=False, 187 | action="store_true", 188 | help="Whether to use batch norm sync.") 189 | parser.add_argument("--use-prefix", 190 | default=False, 191 | action="store_true", 192 | help="Whether to use prefix conditioning in using image classification dataset.") 193 | parser.add_argument( 194 | "--gpu", 195 | type=int, 196 | default=None, 197 | help="Specify a single GPU to run the code on for debugging." 198 | "Leave at None to use all available GPUs.", 199 | ) 200 | parser.add_argument( 201 | "--skip-scheduler", 202 | action="store_true", 203 | default=False, 204 | help="Use this flag to skip the learning rate decay.", 205 | ) 206 | parser.add_argument( 207 | "--save-frequency", type=int, default=1, help="How often to save checkpoints." 208 | ) 209 | parser.add_argument( 210 | "--save-most-recent", 211 | action="store_true", 212 | default=False, 213 | help="Always save the most recent model trained to epoch_latest.pt.", 214 | ) 215 | parser.add_argument( 216 | "--zeroshot-frequency", type=int, default=2, help="How often to run zero shot." 217 | ) 218 | parser.add_argument( 219 | "--regression-frequency", type=int, default=2, help="How often to run zero shot." 220 | ) 221 | parser.add_argument( 222 | "--resume", 223 | default=None, 224 | type=str, 225 | help="path to latest checkpoint (default: none)", 226 | ) 227 | parser.add_argument( 228 | "--precision", 229 | choices=["amp", "fp16", "fp32"], 230 | default="amp", 231 | help="Floating point precition." 232 | ) 233 | parser.add_argument( 234 | "--model", 235 | choices=["RN50", "RN101", "RN50x4", "RN50x64", "RN50x16", "ViT-B/16", "ViT-B/32", "ViT-L/14", "ViT-H-14", 236 | "RN50_flat", "RN50_t1", "RN50_t2", "RN50_t3", "RN50_t4", "RN50_t5", "RN50_t6", 237 | "RN50_flat_ft", "RN50_t1_pos_ft", "RN50_t2_pos_ft", "RN50_t1_pos", "RN50_t2_pos", 238 | "RN50_flat_large", "RN50_t1_large", "RN50_t2_large", 239 | "RN50_a2", "RN50_a2s"], 240 | default="RN50", 241 | help="Name of the vision backbone to use.", 242 | ) 243 | parser.add_argument( 244 | "--openai-pretrained", 245 | default=False, 246 | action='store_true', 247 | help="Use the openai pretrained models.", 248 | ) 249 | # arguments for distributed training 250 | parser.add_argument( 251 | "--dist-url", 252 | default="tcp://127.0.0.1:6100", 253 | type=str, 254 | help="url used to set up distributed training", 255 | ) 256 | parser.add_argument( 257 | "--dist-backend", default="nccl", type=str, help="distributed backend" 258 | ) 259 | parser.add_argument( 260 | "--skip-aggregate", 261 | default=False, 262 | action="store_true", 263 | help="whether to aggregate features across gpus before computing the loss" 264 | ) 265 | parser.add_argument( 266 | "--report-to", 267 | default='', 268 | type=str, 269 | help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']" 270 | ) 271 | parser.add_argument( 272 | "--wandb-notes", 273 | default='', 274 | type=str, 275 | help="Notes if logging with wandb" 276 | ) 277 | parser.add_argument( 278 | "--C", type=float, default=3.16, help="inverse regularizer for logistic reg." 279 | ) 280 | parser.add_argument( 281 | "--debug", 282 | default=False, 283 | action="store_true", 284 | help="If true, more information is logged." 285 | ) 286 | parser.add_argument( 287 | "--copy-codebase", 288 | default=False, 289 | action="store_true", 290 | help="If true, we copy the entire base on the log diretory, and execute from there." 291 | ) 292 | parser.add_argument( 293 | "--dp", 294 | default=False, 295 | action="store_true", 296 | help="Use DP instead of DDP." 297 | ) 298 | parser.add_argument( 299 | "--multigpu", 300 | default=None, 301 | type=lambda x: [int(a) for a in x.split(",")], 302 | help="In DP, which GPUs to use for multigpu training", 303 | ) 304 | args = parser.parse_args() 305 | args.aggregate = not args.skip_aggregate 306 | 307 | # If some params are not passed, we use the default values based on model name. 308 | default_params = get_default_params(args.model) 309 | for name, val in default_params.items(): 310 | if getattr(args, name) is None: 311 | setattr(args, name, val) 312 | 313 | return args 314 | -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import time 17 | import json 18 | import numpy as np 19 | import torch 20 | import torch.nn as nn 21 | from PIL import Image 22 | 23 | from torch.cuda.amp import autocast 24 | import torch.distributed as dist 25 | from tqdm import tqdm 26 | from torchvision.utils import save_image 27 | import sys 28 | import pdb 29 | import wandb 30 | import logging 31 | import torch.nn.functional as F 32 | from third_party.open_clip.clip import tokenize, _transform 33 | from third_party.open_clip.simple_tokenizer import SimpleTokenizer 34 | from utils import is_master 35 | 36 | 37 | def get_loss(model, images, texts, loss_img, loss_txt, args, data_identifier=-1): 38 | if data_identifier == 1: 39 | # ImageNet dataset 40 | image_features, text_features, logit_scale = model(images, texts, extra=True) 41 | else: 42 | image_features, text_features, logit_scale = model(images, texts) 43 | logit_scale = logit_scale.mean() 44 | if args.distributed and args.aggregate: 45 | world_size = dist.get_world_size() 46 | rank = dist.get_rank() 47 | 48 | # We gather tensors from all gpus to get more negatives to contrast with. 49 | gathered_image_features = [ 50 | torch.zeros_like(image_features) for _ in range(world_size) 51 | ] 52 | gathered_text_features = [ 53 | torch.zeros_like(text_features) for _ in range(world_size) 54 | ] 55 | dist.all_gather(gathered_image_features, image_features) 56 | dist.all_gather(gathered_text_features, text_features) 57 | 58 | all_image_features = torch.cat( 59 | [image_features] 60 | + gathered_image_features[:rank] 61 | + gathered_image_features[rank + 1 :] 62 | ) 63 | all_text_features = torch.cat( 64 | [text_features] 65 | + gathered_text_features[:rank] 66 | + gathered_text_features[rank + 1 :] 67 | ) 68 | 69 | ground_truth = torch.arange(len(all_image_features)).long() 70 | if args.gpu is not None: 71 | ground_truth = ground_truth.cuda(args.gpu, non_blocking=True) 72 | 73 | # this is needed to send gradients back everywhere. 74 | # Image loss. 75 | logits_per_image = logit_scale * all_image_features @ all_text_features.t() 76 | loss_img_val = loss_img(logits_per_image, ground_truth) 77 | logits_per_text = logits_per_image.t() 78 | loss_txt_val = loss_txt(logits_per_text, ground_truth) 79 | else: 80 | ground_truth = torch.arange(len(image_features)).long() 81 | if args.gpu is not None: 82 | ground_truth = ground_truth.cuda(args.gpu, non_blocking=True) 83 | 84 | # Image loss. 85 | logits_per_image = logit_scale * image_features @ text_features.t() 86 | loss_img_val = loss_img(logits_per_image, ground_truth) 87 | logits_per_text = logit_scale * text_features @ image_features.t() 88 | loss_txt_val = loss_txt(logits_per_text, ground_truth) 89 | 90 | total_loss = (loss_img_val + loss_txt_val) / 2 91 | return total_loss 92 | 93 | 94 | def get_text_features(model, token_features, args): 95 | text = tokenize("a photo of") 96 | text = text.cuda(args.gpu, non_blocking=True) 97 | text = text.view(1, -1) 98 | text = text.repeat(token_features.size(0), 1) 99 | text_features = model.encode_text_img(text, token_features) 100 | return text_features 101 | 102 | def get_loss_img2text(model, img2text, images, loss_img, loss_txt, args, memory=None): 103 | with torch.no_grad(): 104 | image_features = model.encode_image(images) 105 | token_features = img2text(image_features) 106 | text_features = get_text_features(model, token_features, args) 107 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 108 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 109 | logit_scale = model.logit_scale.exp() 110 | logit_scale = logit_scale.mean() 111 | if args.distributed and args.aggregate: 112 | world_size = dist.get_world_size() 113 | rank = dist.get_rank() 114 | 115 | # We gather tensors from all gpus to get more negatives to contrast with. 116 | gathered_image_features = [ 117 | torch.zeros_like(image_features) for _ in range(world_size) 118 | ] 119 | gathered_text_features = [ 120 | torch.zeros_like(text_features) for _ in range(world_size) 121 | ] 122 | dist.all_gather(gathered_image_features, image_features) 123 | dist.all_gather(gathered_text_features, text_features) 124 | 125 | all_image_features = torch.cat( 126 | [image_features] 127 | + gathered_image_features[:rank] 128 | + gathered_image_features[rank + 1 :] 129 | ) 130 | all_text_features = torch.cat( 131 | [text_features] 132 | + gathered_text_features[:rank] 133 | + gathered_text_features[rank + 1 :] 134 | ) 135 | 136 | ground_truth = torch.arange(len(all_image_features)).long() 137 | if args.gpu is not None: 138 | ground_truth = ground_truth.cuda(args.gpu, non_blocking=True) 139 | 140 | # this is needed to send gradients back everywhere. 141 | # Image loss. 142 | logits_per_image = logit_scale * all_image_features @ all_text_features.t() 143 | loss_img_val = loss_img(logits_per_image, ground_truth) 144 | logits_per_text = logits_per_image.t() 145 | loss_txt_val = loss_txt(logits_per_text, ground_truth) 146 | else: 147 | ground_truth = torch.arange(len(image_features)).long() 148 | if args.gpu is not None: 149 | ground_truth = ground_truth.cuda(args.gpu, non_blocking=True) 150 | # Image loss. 151 | logits_per_image = logit_scale * image_features @ text_features.t() 152 | loss_img_val = loss_img(logits_per_image, ground_truth) 153 | logits_per_text = logit_scale * text_features @ image_features.t() 154 | loss_txt_val = loss_txt(logits_per_text, ground_truth) 155 | total_loss = (loss_img_val + loss_txt_val) / 2 156 | return total_loss 157 | 158 | 159 | def train(model, img2text, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None): 160 | os.environ["WDS_EPOCH"] = str(epoch) 161 | model.eval() 162 | dataloader, sampler = data['train'].dataloader, data['train'].sampler 163 | loss_img = nn.CrossEntropyLoss() 164 | loss_txt = nn.CrossEntropyLoss() 165 | if args.gpu is not None: 166 | loss_img = loss_img.cuda(args.gpu) 167 | loss_txt = loss_txt.cuda(args.gpu) 168 | 169 | if args.distributed and sampler is not None: 170 | sampler.set_epoch(epoch) 171 | 172 | num_batches_per_epoch = dataloader.num_batches 173 | 174 | end = time.time() 175 | for i, batch in enumerate(dataloader): 176 | step = num_batches_per_epoch * epoch + i 177 | scheduler(step) 178 | 179 | optimizer.zero_grad() 180 | 181 | images, texts = batch[0], batch[1] 182 | if len(batch) == 3 and args.use_debiased_sampler: 183 | data_identifier = torch.unique(batch[2])[0].numpy() 184 | else: 185 | data_identifier = -1 186 | if args.gpu is not None: 187 | images = images.cuda(args.gpu, non_blocking=True) 188 | 189 | data_time = time.time() - end 190 | 191 | m = model.module if args.distributed or args.dp else model 192 | 193 | # with automatic mixed precision. 194 | if args.precision == "amp": 195 | with autocast(): 196 | total_loss = get_loss_img2text(m, img2text, images, loss_img, loss_txt, args, data_identifier) 197 | scaler.scale(total_loss).backward() 198 | scaler.step(optimizer) 199 | scaler.update() 200 | 201 | else: 202 | total_loss = get_loss_img2text(m, img2text, images, loss_img, loss_txt, args, data_identifier) 203 | total_loss.backward() 204 | optimizer.step() 205 | 206 | # Note: we clamp to 4.6052 = ln(100), as in the original paper. 207 | #m.logit_scale.data = torch.clamp(m.logit_scale.data, 0, 4.6052) 208 | 209 | batch_time = time.time() - end 210 | end = time.time() 211 | 212 | if is_master(args) and (i % 100) == 0: 213 | num_samples = i * len(images) * args.world_size 214 | samples_per_epoch = dataloader.num_samples 215 | percent_complete = 100.0 * i / num_batches_per_epoch 216 | logging.info( 217 | f"Train Epoch: {epoch} [{num_samples}/{samples_per_epoch} ({percent_complete:.0f}%)]\t" 218 | f"Loss: {total_loss.item():.6f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}" 219 | f"\tLR: {optimizer.param_groups[0]['lr']:5f}\tlogit_scale {m.logit_scale.data:.3f}" 220 | ) 221 | # save train loss / etc. 222 | 223 | timestep = epoch * num_batches_per_epoch + i 224 | log_data = { 225 | "loss": total_loss.item(), 226 | "data_time": data_time, 227 | "batch_time": batch_time, 228 | "scale": m.logit_scale.data.item(), 229 | "lr": optimizer.param_groups[0]["lr"] 230 | } 231 | 232 | for name, val in log_data.items(): 233 | name = "train/" + name 234 | if tb_writer is not None: 235 | tb_writer.add_scalar(name, val, timestep) 236 | if args.wandb: 237 | wandb.log({name: val, 'step': timestep}) -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | import numpy as np 17 | import torch 18 | import torchvision.transforms.functional as F 19 | 20 | class TargetPad: 21 | """ 22 | Pad the image if its aspect ratio is above a target ratio. 23 | Pad the image to match such target ratio 24 | """ 25 | 26 | def __init__(self, target_ratio=1.25): 27 | """ 28 | :param target_ratio: target ratio 29 | :param size: preprocessing output dimension 30 | """ 31 | self.target_ratio = target_ratio 32 | 33 | def __call__(self, image): 34 | w, h = image.size 35 | actual_ratio = max(w, h) / min(w, h) 36 | if actual_ratio < self.target_ratio: # check if the ratio is above or below the target ratio 37 | return image 38 | scaled_max_wh = max(w, h) / self.target_ratio # rescale the pad to match the target ratio 39 | hp = max(int((scaled_max_wh - w) / 2), 0) 40 | vp = max(int((scaled_max_wh - h) / 2), 0) 41 | padding = [hp, vp, hp, vp] 42 | return F.pad(image, padding, 0, 'constant') 43 | 44 | def convert_models_to_fp32(model): 45 | for p in model.parameters(): 46 | p.data = p.data.float() 47 | if p.grad: 48 | p.grad.data = p.grad.data.float() 49 | 50 | def is_master(args): 51 | return (not args.distributed) or args.gpu == 0 or args.dp -------------------------------------------------------------------------------- /third_party/open_clip/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2012-2021 Gabriel Ilharco, Mitchell Wortsman, 2 | Nicholas Carlini, Rohan Taori, Achal Dave, Vaishaal Shankar, 3 | John Miller, Hongseok Namkoong, Hannaneh Hajishirzi, Ali Farhadi, 4 | Ludwig Schmidt 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining 7 | a copy of this software and associated documentation files (the 8 | "Software"), to deal in the Software without restriction, including 9 | without limitation the rights to use, copy, modify, merge, publish, 10 | distribute, sublicense, and/or sell copies of the Software, and to 11 | permit persons to whom the Software is furnished to do so, subject to 12 | the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be 15 | included in all copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 18 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 19 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 20 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 21 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 22 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 23 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 24 | -------------------------------------------------------------------------------- /third_party/open_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/composed_image_retrieval/8c053297c2fae9cd17ddcded48445a4f47208dbd/third_party/open_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /third_party/open_clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from .model import build_model 14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | 23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 25 | 26 | 27 | __all__ = ["available_models", "load", "tokenize"] 28 | _tokenizer = _Tokenizer() 29 | 30 | _MODELS = { 31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 39 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 40 | } 41 | 42 | 43 | def _download(url: str, root: str): 44 | os.makedirs(root, exist_ok=True) 45 | filename = os.path.basename(url) 46 | 47 | expected_sha256 = url.split("/")[-2] 48 | download_target = os.path.join(root, filename) 49 | 50 | if os.path.exists(download_target) and not os.path.isfile(download_target): 51 | raise RuntimeError(f"{download_target} exists and is not a regular file") 52 | 53 | if os.path.isfile(download_target): 54 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 55 | return download_target 56 | else: 57 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 58 | 59 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 60 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 61 | while True: 62 | buffer = source.read(8192) 63 | if not buffer: 64 | break 65 | 66 | output.write(buffer) 67 | loop.update(len(buffer)) 68 | 69 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 70 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") 71 | 72 | return download_target 73 | 74 | 75 | def _convert_image_to_rgb(image): 76 | return image.convert("RGB") 77 | 78 | 79 | def _transform(n_px): 80 | return Compose([ 81 | Resize(n_px, interpolation=BICUBIC), 82 | CenterCrop(n_px), 83 | _convert_image_to_rgb, 84 | ToTensor(), 85 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 86 | ]) 87 | 88 | 89 | def available_models() -> List[str]: 90 | """Returns the names of available CLIP models""" 91 | return list(_MODELS.keys()) 92 | 93 | 94 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 95 | """Load a CLIP model 96 | Parameters 97 | ---------- 98 | name : str 99 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 100 | device : Union[str, torch.device] 101 | The device to put the loaded model 102 | jit : bool 103 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 104 | download_root: str 105 | path to download the model files; by default, it uses "~/.cache/clip" 106 | Returns 107 | ------- 108 | model : torch.nn.Module 109 | The CLIP model 110 | preprocess : Callable[[PIL.Image], torch.Tensor] 111 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 112 | """ 113 | if name in _MODELS: 114 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 115 | elif os.path.isfile(name): 116 | model_path = name 117 | else: 118 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 119 | 120 | with open(model_path, 'rb') as opened_file: 121 | try: 122 | # loading JIT archive 123 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() 124 | state_dict = None 125 | except RuntimeError: 126 | # loading saved state dict 127 | if jit: 128 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 129 | jit = False 130 | state_dict = torch.load(opened_file, map_location="cpu") 131 | 132 | if not jit: 133 | model = build_model(state_dict or model.state_dict()).to(device) 134 | if str(device) == "cpu": 135 | model.float() 136 | return model, _transform(model.visual.input_resolution) 137 | 138 | # patch the device names 139 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 140 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 141 | 142 | def patch_device(module): 143 | try: 144 | graphs = [module.graph] if hasattr(module, "graph") else [] 145 | except RuntimeError: 146 | graphs = [] 147 | 148 | if hasattr(module, "forward1"): 149 | graphs.append(module.forward1.graph) 150 | 151 | for graph in graphs: 152 | for node in graph.findAllNodes("prim::Constant"): 153 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 154 | node.copyAttributes(device_node) 155 | 156 | model.apply(patch_device) 157 | patch_device(model.encode_image) 158 | patch_device(model.encode_text) 159 | 160 | # patch dtype to float32 on CPU 161 | if str(device) == "cpu": 162 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 163 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 164 | float_node = float_input.node() 165 | 166 | def patch_float(module): 167 | try: 168 | graphs = [module.graph] if hasattr(module, "graph") else [] 169 | except RuntimeError: 170 | graphs = [] 171 | 172 | if hasattr(module, "forward1"): 173 | graphs.append(module.forward1.graph) 174 | 175 | for graph in graphs: 176 | for node in graph.findAllNodes("aten::to"): 177 | inputs = list(node.inputs()) 178 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 179 | if inputs[i].node()["value"] == 5: 180 | inputs[i].node().copyAttributes(float_node) 181 | 182 | model.apply(patch_float) 183 | patch_float(model.encode_image) 184 | patch_float(model.encode_text) 185 | 186 | model.float() 187 | 188 | return model, _transform(model.input_resolution.item()) 189 | 190 | 191 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: 192 | """ 193 | Returns the tokenized representation of given input string(s) 194 | Parameters 195 | ---------- 196 | texts : Union[str, List[str]] 197 | An input string or a list of input strings to tokenize 198 | context_length : int 199 | The context length to use; all CLIP models use 77 as the context length 200 | truncate: bool 201 | Whether to truncate the text in case its encoding is longer than the context length 202 | Returns 203 | ------- 204 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. 205 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 206 | """ 207 | if isinstance(texts, str): 208 | texts = [texts] 209 | 210 | sot_token = _tokenizer.encoder["<|startoftext|>"] 211 | eot_token = _tokenizer.encoder["<|endoftext|>"] 212 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 213 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): 214 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 215 | else: 216 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 217 | 218 | for i, tokens in enumerate(all_tokens): 219 | if len(tokens) > context_length: 220 | if truncate: 221 | tokens = tokens[:context_length] 222 | tokens[-1] = eot_token 223 | else: 224 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 225 | result[i, :len(tokens)] = torch.tensor(tokens) 226 | 227 | return result -------------------------------------------------------------------------------- /third_party/open_clip/environment.yml: -------------------------------------------------------------------------------- 1 | name: open_clip 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - absl-py=0.12.0=py36h06a4308_0 9 | - aiohttp=3.6.3=py36h7b6447c_0 10 | - async-timeout=3.0.1=py36h06a4308_0 11 | - attrs=20.3.0=pyhd3eb1b0_0 12 | - blas=1.0=mkl 13 | - blinker=1.4=py36h06a4308_0 14 | - brotlipy=0.7.0=py36h27cfd23_1003 15 | - c-ares=1.17.1=h27cfd23_0 16 | - ca-certificates=2020.12.5=ha878542_0 17 | - cachetools=4.2.1=pyhd3eb1b0_0 18 | - certifi=2020.12.5=py36h5fab9bb_1 19 | - cffi=1.14.5=py36h261ae71_0 20 | - chardet=3.0.4=py36h06a4308_1003 21 | - click=7.1.2=pyhd3eb1b0_0 22 | - coverage=5.5=py36h27cfd23_2 23 | - cryptography=3.4.7=py36hd23ed53_0 24 | - cudatoolkit=11.0.221=h6bb024c_0 25 | - cython=0.29.23=py36h2531618_0 26 | - dataclasses=0.8=pyh4f3eec9_6 27 | - faiss-gpu=1.4.0=py36_cuda8.0.61_1 28 | - freetype=2.10.4=h5ab3b9f_0 29 | - ftfy=5.8=py_0 30 | - google-auth=1.29.0=pyhd3eb1b0_0 31 | - google-auth-oauthlib=0.4.4=pyhd3eb1b0_0 32 | - grpcio=1.36.1=py36h2157cd5_1 33 | - idna=2.10=pyhd3eb1b0_0 34 | - idna_ssl=1.1.0=py36h06a4308_0 35 | - importlib-metadata=3.10.0=py36h06a4308_0 36 | - intel-openmp=2021.2.0=h06a4308_610 37 | - joblib=1.0.1=pyhd8ed1ab_0 38 | - jpeg=9b=h024ee3a_2 39 | - lcms2=2.12=h3be6417_0 40 | - ld_impl_linux-64=2.33.1=h53a641e_7 41 | - libblas=3.9.0=1_h6e990d7_netlib 42 | - libcblas=3.9.0=3_h893e4fe_netlib 43 | - libffi=3.3=he6710b0_2 44 | - libgcc=7.2.0=h69d50b8_2 45 | - libgcc-ng=9.1.0=hdf63c60_0 46 | - libgfortran-ng=7.5.0=h14aa051_19 47 | - libgfortran4=7.5.0=h14aa051_19 48 | - liblapack=3.9.0=3_h893e4fe_netlib 49 | - libpng=1.6.37=hbc83047_0 50 | - libprotobuf=3.14.0=h8c45485_0 51 | - libstdcxx-ng=9.1.0=hdf63c60_0 52 | - libtiff=4.1.0=h2733197_1 53 | - libuv=1.40.0=h7b6447c_0 54 | - lz4-c=1.9.3=h2531618_0 55 | - markdown=3.3.4=py36h06a4308_0 56 | - mkl=2020.2=256 57 | - mkl-service=2.3.0=py36he8ac12f_0 58 | - mkl_fft=1.3.0=py36h54f3939_0 59 | - mkl_random=1.1.1=py36h0573a6f_0 60 | - multidict=4.7.6=py36h7b6447c_1 61 | - ncurses=6.2=he6710b0_1 62 | - ninja=1.10.2=hff7bd54_1 63 | - numpy=1.19.2=py36h54aff64_0 64 | - numpy-base=1.19.2=py36hfa32c7d_0 65 | - oauthlib=3.1.0=py_0 66 | - olefile=0.46=py36_0 67 | - openssl=1.1.1k=h27cfd23_0 68 | - pandas=1.1.3=py36he6710b0_0 69 | - pillow=8.2.0=py36he98fc37_0 70 | - pip=21.0.1=py36h06a4308_0 71 | - protobuf=3.14.0=py36h2531618_1 72 | - pyasn1=0.4.8=py_0 73 | - pyasn1-modules=0.2.8=py_0 74 | - pycparser=2.20=py_2 75 | - pyjwt=1.7.1=py36_0 76 | - pyopenssl=20.0.1=pyhd3eb1b0_1 77 | - pysocks=1.7.1=py36h06a4308_0 78 | - python=3.6.13=hdb3f193_0 79 | - python-dateutil=2.8.1=pyhd3eb1b0_0 80 | - python_abi=3.6=1_cp36m 81 | - pytorch=1.7.1=py3.6_cuda11.0.221_cudnn8.0.5_0 82 | - pytz=2021.1=pyhd3eb1b0_0 83 | - readline=8.1=h27cfd23_0 84 | - regex=2021.4.4=py36h27cfd23_0 85 | - requests=2.25.1=pyhd3eb1b0_0 86 | - requests-oauthlib=1.3.0=py_0 87 | - rsa=4.7.2=pyhd3eb1b0_1 88 | - scikit-learn=0.23.2=py36hb6e6923_3 89 | - scipy=1.5.3=py36h976291a_0 90 | - setuptools=52.0.0=py36h06a4308_0 91 | - six=1.15.0=py36h06a4308_0 92 | - sqlite=3.35.4=hdfb4753_0 93 | - tensorboard=2.4.0=pyhc547734_0 94 | - tensorboard-plugin-wit=1.6.0=py_0 95 | - threadpoolctl=2.1.0=pyh5ca1d4c_0 96 | - tk=8.6.10=hbc83047_0 97 | - torchaudio=0.7.2=py36 98 | - torchvision=0.8.2=py36_cu110 99 | - tqdm=4.59.0=pyhd3eb1b0_1 100 | - typing_extensions=3.7.4.3=pyha847dfd_0 101 | - urllib3=1.26.4=pyhd3eb1b0_0 102 | - wcwidth=0.2.5=py_0 103 | - werkzeug=1.0.1=pyhd3eb1b0_0 104 | - wheel=0.36.2=pyhd3eb1b0_0 105 | - xz=5.2.5=h7b6447c_0 106 | - yarl=1.6.3=py36h27cfd23_0 107 | - zipp=3.4.1=pyhd3eb1b0_0 108 | - zlib=1.2.11=h7b6447c_3 109 | - zstd=1.4.9=haebb681_0 110 | - pip: 111 | - ase==3.21.1 112 | - braceexpand==0.1.7 113 | - cached-property==1.5.2 114 | - configparser==5.0.2 115 | - cycler==0.10.0 116 | - decorator==4.4.2 117 | - docker-pycreds==0.4.0 118 | - gitdb==4.0.7 119 | - gitpython==3.1.14 120 | - googledrivedownloader==0.4 121 | - h5py==3.1.0 122 | - isodate==0.6.0 123 | - jinja2==3.0.1 124 | - kiwisolver==1.3.1 125 | - littleutils==0.2.2 126 | - llvmlite==0.36.0 127 | - markupsafe==2.0.1 128 | - matplotlib==3.3.4 129 | - networkx==2.5.1 130 | - numba==0.53.1 131 | - ogb==1.3.1 132 | - outdated==0.2.1 133 | - pathtools==0.1.2 134 | - promise==2.3 135 | - psutil==5.8.0 136 | - pyarrow==4.0.0 137 | - pyparsing==2.4.7 138 | - python-louvain==0.15 139 | - pyyaml==5.4.1 140 | - rdflib==5.0.0 141 | - sentry-sdk==1.1.0 142 | - shortuuid==1.0.1 143 | - sklearn==0.0 144 | - smmap==4.0.0 145 | - subprocess32==3.5.4 146 | - torch-geometric==1.7.0 147 | - wandb==0.10.30 148 | - wilds==1.1.0 149 | - "--editable=git+https://github.com/tmbdev/webdataset.git@a4f3ec08551b42f20b20cdc1ba32d12536eabc15#egg=webdataset" 150 | - git+https://github.com/modestyachts/ImageNetV2_pytorch 151 | - https://pytorch-geometric.com/whl/torch-1.7.0+cu110/torch_scatter-2.0.6-cp36-cp36m-linux_x86_64.whl 152 | prefix: /home/gamaga/anaconda3/envs/open_clip 153 | -------------------------------------------------------------------------------- /third_party/open_clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.relu1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.relu2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.relu3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([ 37 | ("-1", nn.AvgPool2d(stride)), 38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 39 | ("1", nn.BatchNorm2d(planes * self.expansion)) 40 | ])) 41 | 42 | def forward(self, x: torch.Tensor): 43 | identity = x 44 | 45 | out = self.relu1(self.bn1(self.conv1(x))) 46 | out = self.relu2(self.bn2(self.conv2(out))) 47 | out = self.avgpool(out) 48 | out = self.bn3(self.conv3(out)) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.relu3(out) 55 | return out 56 | 57 | 58 | class AttentionPool2d(nn.Module): 59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 60 | super().__init__() 61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 62 | self.k_proj = nn.Linear(embed_dim, embed_dim) 63 | self.q_proj = nn.Linear(embed_dim, embed_dim) 64 | self.v_proj = nn.Linear(embed_dim, embed_dim) 65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 66 | self.num_heads = num_heads 67 | 68 | def forward(self, x): 69 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC 70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 72 | x, _ = F.multi_head_attention_forward( 73 | query=x[:1], key=x, value=x, 74 | embed_dim_to_check=x.shape[-1], 75 | num_heads=self.num_heads, 76 | q_proj_weight=self.q_proj.weight, 77 | k_proj_weight=self.k_proj.weight, 78 | v_proj_weight=self.v_proj.weight, 79 | in_proj_weight=None, 80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 81 | bias_k=None, 82 | bias_v=None, 83 | add_zero_attn=False, 84 | dropout_p=0, 85 | out_proj_weight=self.c_proj.weight, 86 | out_proj_bias=self.c_proj.bias, 87 | use_separate_proj_weight=True, 88 | training=self.training, 89 | need_weights=False 90 | ) 91 | return x.squeeze(0) 92 | 93 | 94 | class ModifiedResNet(nn.Module): 95 | """ 96 | A ResNet class that is similar to torchvision's but contains the following changes: 97 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 98 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 99 | - The final pooling layer is a QKV attention instead of an average pool 100 | """ 101 | 102 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 103 | super().__init__() 104 | self.output_dim = output_dim 105 | self.input_resolution = input_resolution 106 | 107 | # the 3-layer stem 108 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 109 | self.bn1 = nn.BatchNorm2d(width // 2) 110 | self.relu1 = nn.ReLU(inplace=True) 111 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 112 | self.bn2 = nn.BatchNorm2d(width // 2) 113 | self.relu2 = nn.ReLU(inplace=True) 114 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 115 | self.bn3 = nn.BatchNorm2d(width) 116 | self.relu3 = nn.ReLU(inplace=True) 117 | self.avgpool = nn.AvgPool2d(2) 118 | 119 | # residual layers 120 | self._inplanes = width # this is a *mutable* variable used during construction 121 | self.layer1 = self._make_layer(width, layers[0]) 122 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 123 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 124 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 125 | 126 | embed_dim = width * 32 # the ResNet feature dimension 127 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 128 | 129 | def _make_layer(self, planes, blocks, stride=1): 130 | layers = [Bottleneck(self._inplanes, planes, stride)] 131 | 132 | self._inplanes = planes * Bottleneck.expansion 133 | for _ in range(1, blocks): 134 | layers.append(Bottleneck(self._inplanes, planes)) 135 | 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x): 139 | def stem(x): 140 | x = self.relu1(self.bn1(self.conv1(x))) 141 | x = self.relu2(self.bn2(self.conv2(x))) 142 | x = self.relu3(self.bn3(self.conv3(x))) 143 | x = self.avgpool(x) 144 | return x 145 | 146 | x = x.type(self.conv1.weight.dtype) 147 | x = stem(x) 148 | x = self.layer1(x) 149 | x = self.layer2(x) 150 | x = self.layer3(x) 151 | x = self.layer4(x) 152 | x = self.attnpool(x) 153 | 154 | return x 155 | 156 | 157 | class LayerNorm(nn.LayerNorm): 158 | """Subclass torch's LayerNorm to handle fp16.""" 159 | 160 | def forward(self, x: torch.Tensor): 161 | orig_type = x.dtype 162 | ret = super().forward(x.type(torch.float32)) 163 | return ret.type(orig_type) 164 | 165 | 166 | class QuickGELU(nn.Module): 167 | def forward(self, x: torch.Tensor): 168 | return x * torch.sigmoid(1.702 * x) 169 | 170 | 171 | class ResidualAttentionBlock(nn.Module): 172 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 173 | super().__init__() 174 | 175 | self.attn = nn.MultiheadAttention(d_model, n_head) 176 | self.ln_1 = LayerNorm(d_model) 177 | self.mlp = nn.Sequential(OrderedDict([ 178 | ("c_fc", nn.Linear(d_model, d_model * 4)), 179 | ("gelu", QuickGELU()), 180 | ("c_proj", nn.Linear(d_model * 4, d_model)) 181 | ])) 182 | self.ln_2 = LayerNorm(d_model) 183 | self.attn_mask = attn_mask 184 | 185 | def attention(self, x: torch.Tensor): 186 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 187 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 188 | 189 | def forward(self, x: torch.Tensor): 190 | x = x + self.attention(self.ln_1(x)) 191 | x = x + self.mlp(self.ln_2(x)) 192 | return x 193 | 194 | 195 | class Transformer(nn.Module): 196 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 197 | super().__init__() 198 | self.width = width 199 | self.layers = layers 200 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 201 | 202 | def forward(self, x: torch.Tensor): 203 | return self.resblocks(x) 204 | 205 | 206 | class VisionTransformer(nn.Module): 207 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 208 | super().__init__() 209 | self.input_resolution = input_resolution 210 | self.output_dim = output_dim 211 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 212 | 213 | scale = width ** -0.5 214 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 215 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 216 | self.ln_pre = LayerNorm(width) 217 | 218 | self.transformer = Transformer(width, layers, heads) 219 | 220 | self.ln_post = LayerNorm(width) 221 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 222 | 223 | def forward(self, x: torch.Tensor): 224 | x = self.conv1(x) # shape = [*, width, grid, grid] 225 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 226 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 227 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 228 | x = x + self.positional_embedding.to(x.dtype) 229 | x = self.ln_pre(x) 230 | 231 | x = x.permute(1, 0, 2) # NLD -> LND 232 | x = self.transformer(x) 233 | x = x.permute(1, 0, 2) # LND -> NLD 234 | 235 | x = self.ln_post(x[:, 0, :]) 236 | 237 | if self.proj is not None: 238 | x = x @ self.proj 239 | 240 | return x 241 | 242 | 243 | class CLIP(nn.Module): 244 | def __init__(self, 245 | embed_dim: int, 246 | # vision 247 | image_resolution: int, 248 | vision_layers: Union[Tuple[int, int, int, int], int], 249 | vision_width: int, 250 | vision_patch_size: int, 251 | # text 252 | context_length: int, 253 | vocab_size: int, 254 | transformer_width: int, 255 | transformer_heads: int, 256 | transformer_layers: int 257 | ): 258 | super().__init__() 259 | 260 | self.context_length = context_length 261 | 262 | if isinstance(vision_layers, (tuple, list)): 263 | vision_heads = vision_width * 32 // 64 264 | self.visual = ModifiedResNet( 265 | layers=vision_layers, 266 | output_dim=embed_dim, 267 | heads=vision_heads, 268 | input_resolution=image_resolution, 269 | width=vision_width 270 | ) 271 | else: 272 | vision_heads = vision_width // 64 273 | self.visual = VisionTransformer( 274 | input_resolution=image_resolution, 275 | patch_size=vision_patch_size, 276 | width=vision_width, 277 | layers=vision_layers, 278 | heads=vision_heads, 279 | output_dim=embed_dim 280 | ) 281 | 282 | self.transformer = Transformer( 283 | width=transformer_width, 284 | layers=transformer_layers, 285 | heads=transformer_heads, 286 | attn_mask=self.build_attention_mask() 287 | ) 288 | 289 | self.vocab_size = vocab_size 290 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 291 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 292 | self.ln_final = LayerNorm(transformer_width) 293 | 294 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 295 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 296 | 297 | self.initialize_parameters() 298 | 299 | def initialize_parameters(self): 300 | nn.init.normal_(self.token_embedding.weight, std=0.02) 301 | nn.init.normal_(self.positional_embedding, std=0.01) 302 | 303 | if isinstance(self.visual, ModifiedResNet): 304 | if self.visual.attnpool is not None: 305 | std = self.visual.attnpool.c_proj.in_features ** -0.5 306 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 307 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 308 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 309 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 310 | 311 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 312 | for name, param in resnet_block.named_parameters(): 313 | if name.endswith("bn3.weight"): 314 | nn.init.zeros_(param) 315 | 316 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 317 | attn_std = self.transformer.width ** -0.5 318 | fc_std = (2 * self.transformer.width) ** -0.5 319 | for block in self.transformer.resblocks: 320 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 321 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 322 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 323 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 324 | 325 | if self.text_projection is not None: 326 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 327 | 328 | def build_attention_mask(self): 329 | # lazily create causal attention mask, with full attention between the vision tokens 330 | # pytorch uses additive attention mask; fill with -inf 331 | mask = torch.empty(self.context_length, self.context_length) 332 | mask.fill_(float("-inf")) 333 | mask.triu_(1) # zero out the lower diagonal 334 | return mask 335 | 336 | @property 337 | def dtype(self): 338 | return self.visual.conv1.weight.dtype 339 | 340 | def encode_image(self, image): 341 | return self.visual(image.type(self.dtype)) 342 | 343 | def encode_text(self, text): 344 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 345 | 346 | x = x + self.positional_embedding.type(self.dtype) 347 | x = x.permute(1, 0, 2) # NLD -> LND 348 | x = self.transformer(x) 349 | x = x.permute(1, 0, 2) # LND -> NLD 350 | x = self.ln_final(x).type(self.dtype) 351 | 352 | # x.shape = [batch_size, n_ctx, transformer.width] 353 | # take features from the eot embedding (eot_token is the highest number in each sequence) 354 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 355 | 356 | return x 357 | 358 | def forward(self, image, text): 359 | image_features = self.encode_image(image) 360 | text_features = self.encode_text(text) 361 | 362 | # normalized features 363 | image_features = image_features / image_features.norm(dim=1, keepdim=True) 364 | text_features = text_features / text_features.norm(dim=1, keepdim=True) 365 | 366 | # cosine similarity as logits 367 | logit_scale = self.logit_scale.exp() 368 | logits_per_image = logit_scale * image_features @ text_features.t() 369 | logits_per_text = logits_per_image.t() 370 | 371 | # shape = [global_batch_size, global_batch_size] 372 | return logits_per_image, logits_per_text 373 | 374 | 375 | def convert_weights(model: nn.Module): 376 | """Convert applicable model parameters to fp16""" 377 | 378 | def _convert_weights_to_fp16(l): 379 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 380 | l.weight.data = l.weight.data.half() 381 | if l.bias is not None: 382 | l.bias.data = l.bias.data.half() 383 | 384 | if isinstance(l, nn.MultiheadAttention): 385 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 386 | tensor = getattr(l, attr) 387 | if tensor is not None: 388 | tensor.data = tensor.data.half() 389 | 390 | for name in ["text_projection", "proj"]: 391 | if hasattr(l, name): 392 | attr = getattr(l, name) 393 | if attr is not None: 394 | attr.data = attr.data.half() 395 | 396 | model.apply(_convert_weights_to_fp16) 397 | 398 | 399 | def build_model(state_dict: dict): 400 | vit = "visual.proj" in state_dict 401 | 402 | if vit: 403 | vision_width = state_dict["visual.conv1.weight"].shape[0] 404 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 405 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 406 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 407 | image_resolution = vision_patch_size * grid_size 408 | else: 409 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 410 | vision_layers = tuple(counts) 411 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 412 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 413 | vision_patch_size = None 414 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 415 | image_resolution = output_width * 32 416 | 417 | embed_dim = state_dict["text_projection"].shape[1] 418 | context_length = state_dict["positional_embedding"].shape[0] 419 | vocab_size = state_dict["token_embedding.weight"].shape[0] 420 | transformer_width = state_dict["ln_final.weight"].shape[0] 421 | transformer_heads = transformer_width // 64 422 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) 423 | 424 | model = CLIP( 425 | embed_dim, 426 | image_resolution, vision_layers, vision_width, vision_patch_size, 427 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 428 | ) 429 | 430 | for key in ["input_resolution", "context_length", "vocab_size"]: 431 | if key in state_dict: 432 | del state_dict[key] 433 | 434 | convert_weights(model) 435 | model.load_state_dict(state_dict) 436 | return model.eval() -------------------------------------------------------------------------------- /third_party/open_clip/model_configs/RN101.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "image_resolution": 224, 4 | "vision_layers": [ 5 | 3, 6 | 4, 7 | 23, 8 | 3 9 | ], 10 | "vision_width": 64, 11 | "vision_patch_size": null, 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "transformer_width": 512, 15 | "transformer_heads": 8, 16 | "transformer_layers": 12 17 | } -------------------------------------------------------------------------------- /third_party/open_clip/model_configs/RN50.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "image_resolution": 224, 4 | "vision_layers": [ 5 | 3, 6 | 4, 7 | 6, 8 | 3 9 | ], 10 | "vision_width": 64, 11 | "vision_patch_size": null, 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "transformer_width": 512, 15 | "transformer_heads": 8, 16 | "transformer_layers": 12 17 | } -------------------------------------------------------------------------------- /third_party/open_clip/model_configs/RN50_a2.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "image_resolution": 224, 4 | "vision_layers": [ 5 | 3, 6 | 4, 7 | 6, 8 | 3 9 | ], 10 | "vision_width": 64, 11 | "vision_patch_size": null, 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "transformer_width": 512, 15 | "transformer_heads": 8, 16 | "transformer_layers": 12, 17 | "extra_transformer_layers": 2, 18 | "share_projection_layer": false 19 | } -------------------------------------------------------------------------------- /third_party/open_clip/model_configs/RN50_a2s.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "image_resolution": 224, 4 | "vision_layers": [ 5 | 3, 6 | 4, 7 | 6, 8 | 3 9 | ], 10 | "vision_width": 64, 11 | "vision_patch_size": null, 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "transformer_width": 512, 15 | "transformer_heads": 8, 16 | "transformer_layers": 12, 17 | "extra_transformer_layers": 2, 18 | "share_projection_layer": true 19 | } -------------------------------------------------------------------------------- /third_party/open_clip/model_configs/RN50x16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "image_resolution": 384, 4 | "vision_layers": [ 5 | 6, 6 | 8, 7 | 18, 8 | 8 9 | ], 10 | "vision_width": 96, 11 | "vision_patch_size": null, 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "transformer_width": 768, 15 | "transformer_heads": 12, 16 | "transformer_layers": 12 17 | } -------------------------------------------------------------------------------- /third_party/open_clip/model_configs/RN50x4.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "image_resolution": 288, 4 | "vision_layers": [ 5 | 4, 6 | 6, 7 | 10, 8 | 6 9 | ], 10 | "vision_width": 80, 11 | "vision_patch_size": null, 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "transformer_width": 640, 15 | "transformer_heads": 10, 16 | "transformer_layers": 12 17 | } -------------------------------------------------------------------------------- /third_party/open_clip/model_configs/ViT-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "image_resolution": 224, 4 | "vision_layers": 12, 5 | "vision_width": 768, 6 | "vision_patch_size": 16, 7 | "context_length": 77, 8 | "vocab_size": 49408, 9 | "transformer_width": 512, 10 | "transformer_heads": 8, 11 | "transformer_layers": 12 12 | } -------------------------------------------------------------------------------- /third_party/open_clip/model_configs/ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "image_resolution": 224, 4 | "vision_layers": 12, 5 | "vision_width": 768, 6 | "vision_patch_size": 32, 7 | "context_length": 77, 8 | "vocab_size": 49408, 9 | "transformer_width": 512, 10 | "transformer_heads": 8, 11 | "transformer_layers": 12 12 | } -------------------------------------------------------------------------------- /third_party/open_clip/scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def assign_learning_rate(optimizer, new_lr): 4 | for param_group in optimizer.param_groups: 5 | param_group["lr"] = new_lr 6 | 7 | def _warmup_lr(base_lr, warmup_length, step): 8 | return base_lr * (step + 1) / warmup_length 9 | 10 | def cosine_lr(optimizer, base_lr, warmup_length, steps): 11 | def _lr_adjuster(step): 12 | if step < warmup_length: 13 | lr = _warmup_lr(base_lr, warmup_length, step) 14 | else: 15 | e = step - warmup_length 16 | es = steps - warmup_length 17 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr 18 | assign_learning_rate(optimizer, lr) 19 | return lr 20 | return _lr_adjuster -------------------------------------------------------------------------------- /third_party/open_clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text --------------------------------------------------------------------------------