├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── data
│   ├── README.md
│   ├── coco
│   │   └── prepare_data.py
│   └── imgnet
│       ├── imgnet_real_query.txt
│       └── imgnet_targets.txt
├── env.sh
├── model
│   ├── clip.py
│   └── model.py
├── requirements.txt
├── setenv.sh
├── src
│   ├── data.py
│   ├── demo.py
│   ├── eval_retrieval.py
│   ├── eval_utils.py
│   ├── logger.py
│   ├── main.py
│   ├── params.py
│   ├── trainer.py
│   └── utils.py
├── third_party
│   └── open_clip
│       ├── LICENSE
│       ├── bpe_simple_vocab_16e6.txt.gz
│       ├── clip.py
│       ├── environment.yml
│       ├── model.py
│       ├── model_configs
│       │   ├── RN101.json
│       │   ├── RN50.json
│       │   ├── RN50_a2.json
│       │   ├── RN50_a2s.json
│       │   ├── RN50x16.json
│       │   ├── RN50x4.json
│       │   ├── ViT-B-16.json
│       │   └── ViT-B-32.json
│       ├── scheduler.py
│       └── simple_tokenizer.py
└── valprep.sh
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows [Google's Open Source Community
28 | Guidelines](https://opensource.google/conduct/).
29 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Pic2Word (CVPR2023)
2 |
3 | This is an open-source implementation of [Pic2Word](https://arxiv.org/pdf/2302.03084.pdf). It is not an
4 | officially supported Google product.
5 |
6 |
7 | ## Data
8 |
9 | ### Training Data
10 | We use [Conceptual Captions URLs](https://ai.google.com/research/ConceptualCaptions/download) to train the model.
11 | See [open_clip](https://github.com/mlfoundations/open_clip) for instructions on downloading the dataset.
12 |
13 | The training data directories must be placed in the root of this repo and structured as below.
14 | ```bash
15 | cc_data
16 | ├── train ## training image directories.
17 | └── val ## validation image directories.
18 | cc
19 | ├── Train_GCC-training_output.csv ## training data list
20 | └── Validation_GCC-1.1.0-Validation_output.csv ## validation data list
21 | ```
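If the Conceptual Captions images and CSV lists already live elsewhere on disk, one simple way to satisfy this layout is to symlink them into the repo root (a minimal sketch; replace the source paths with your own):
```bash
# Hypothetical source locations; adjust to wherever the data was downloaded.
ln -s /path/to/cc_data ./cc_data   # contains the train/ and val/ image directories
ln -s /path/to/cc ./cc             # contains the Train_GCC-*.csv and Validation_GCC-*.csv lists
```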
22 |
23 | ### Test Data
24 | See [README](data/README.md) to prepare the test datasets.
25 |
26 | ## Training
27 |
28 | ### Install dependencies
29 | See [open_clip](https://github.com/mlfoundations/open_clip) for installation details.
30 | The same environment should be usable in this repo.
31 | setenv.sh is the script we used to set up the environment in a virtualenv.
32 |
33 | Also run the following to add the source directory to PYTHONPATH:
34 | ```bash
35 | . env3/bin/activate
36 | export PYTHONPATH="$PYTHONPATH:$PWD/src"
37 | export PYTHONWARNINGS='ignore:semaphore_tracker:UserWarning'
38 | ```
39 | ### Pre-trained model
40 | The pre-trained model is available on [Google Drive](https://drive.google.com/file/d/1IxRi2Cj81RxMu0ViT4q4nkfyjbSHm1dF/view?usp=sharing).
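If you prefer the command line, the checkpoint can be fetched with the `gdown` package (a minimal sketch; `gdown` is not a dependency of this repo, and the output filename `pic2word_model.pt` is arbitrary). Pass the downloaded file to `--resume` in the commands below.
```bash
pip install gdown
# File id taken from the Google Drive link above.
gdown 1IxRi2Cj81RxMu0ViT4q4nkfyjbSHm1dF -O pic2word_model.pt
```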
41 |
42 | ### Sample training command:
43 |
44 | ```bash
45 | python -u src/main.py \
46 | --save-frequency 1 \
47 | --train-data="cc/Train_GCC-training_output.csv" \
48 | --warmup 10000 \
49 | --batch-size=128 \
50 | --lr=1e-4 \
51 | --wd=0.1 \
52 | --epochs=30 \
53 | --workers=8 \
54 | --openai-pretrained \
55 | --model ViT-L/14
56 | ```
57 |
58 | ### Sample command for evaluation only:
59 |
60 | Evaluation on COCO, ImageNet, or CIRR. Set `$data_name` to `coco`, `imgnet`, or `cirr`.
61 | ```bash
62 | python src/eval_retrieval.py \
63 |     --openai-pretrained \
64 |     --resume /path/to/checkpoints \
65 |     --eval-mode $data_name \
66 |     --gpu $gpu_id \
67 |     --model ViT-L/14
68 | ```
69 |
70 | Evaluation on Fashion-IQ. Set `$cloth_type` to `shirt`, `dress`, or `toptee`.
71 | ```bash
72 | python src/eval_retrieval.py \
73 |     --openai-pretrained \
74 |     --resume /path/to/checkpoints \
75 |     --eval-mode fashion \
76 |     --source $cloth_type \
77 |     --gpu $gpu_id \
78 |     --model ViT-L/14
79 | ```
80 |
81 | ### Demo:
82 |
83 | Run the demo on COCO, ImageNet, CIRR, or Fashion-IQ. Set `--retrieval-data` to one of `coco`, `imgnet`, `cirr`, `dress`, `shirt`, or `toptee`, pass the query images to `--query_file` as a comma-separated list, and pass comma-separated prompts to `--prompts`, using * to indicate the token to be replaced with the image token (e.g., "a sketch of *"). `--demo-out` is the directory in which the html file and image directory are generated.
84 |
85 | ```bash
86 | python src/demo.py \
87 |     --openai-pretrained \
88 |     --resume /path/to/checkpoints \
89 |     --retrieval-data $data_name \
90 |     --query_file "path_img1,path_img2,path_img3..." \
91 |     --prompts "prompt1,prompt2,..." \
92 |     --demo-out $path_demo \
93 |     --gpu $gpu_id \
94 |     --model ViT-L/14
95 | ```
96 | This demo generates a directory containing an HTML file and an image directory. Download the directory and open the HTML file to view the results.
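If the demo runs on a remote machine, an alternative to downloading the directory is to serve it over HTTP and browse it locally (a minimal sketch, assuming `$path_demo` is the `--demo-out` directory used above):
```bash
cd $path_demo
python -m http.server 8000   # then open http://<remote-host>:8000 in a local browser
```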
97 |
98 | ## Citing
99 |
100 | If you found this repository useful, please consider citing:
101 |
102 | ```bibtex
103 | @article{saito2023pic2word,
104 | title={Pic2Word: Mapping Pictures to Words for Zero-shot Composed Image Retrieval},
105 | author={Saito, Kuniaki and Sohn, Kihyuk and Zhang, Xiang and Li, Chun-Liang and Lee, Chen-Yu and Saenko, Kate and Pfister, Tomas},
106 | journal={CVPR},
107 | year={2023}
108 | }
109 |
110 | ```
111 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | ## Data
2 |
3 | Overall structure of this directory should be as follows.
4 | ```bash
5 | data
6 | ├── coco
7 | ├── imgnet
8 | ├── CIRR
9 | └── fashion-iq
10 | ```
11 |
12 | ### ImageNet
13 | ```bash
14 | imgnet
15 | ├── imagenet-r ## unzipped imagenet-r directories containing images. This folder should contain subfolders.
16 | │   ├── n01443537
17 | │   ├── .
18 | │   └── .
19 | │
20 | ├── imgnet_real_query.txt
21 | ├── imgnet_targets.txt
22 | └── real ## imagenet validation directories containing images. This folder should contain subfolders.
23 |     ├── n01440764
24 |     ├── .
25 |     └── .
26 | ```
27 | See [ImageNet-R](https://github.com/hendrycks/imagenet-r) to download the dataset.
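For reference, a minimal download sketch (the tarball URL is the one linked from the ImageNet-R repository; the `real` directory holds the standard ImageNet validation images, which must be obtained separately from image-net.org and sorted into class subfolders, e.g., with the `valprep.sh` script at the repo root, which appears to be for this):
```bash
cd data/imgnet
# Assumption: this is the ImageNet-R tarball linked from the repository above.
wget https://people.eecs.berkeley.edu/~hendrycks/imagenet-r.tar
tar -xf imagenet-r.tar   # should unpack into imagenet-r/<wnid>/ class folders
```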
28 |
29 | ### COCO
30 | ```bash
31 | coco
32 | ├── annotations/instances_val2017.json ## annotations for COCO validation images.
33 | ├── prepare_data.py ## code to generate query data.
34 | ├── coco_eval.csv ## this will be generated by running prepare_data.py
35 | ├── val2017 ## directory containing COCO validation images.
36 | └── val2017_masked ## running prepare_data.py will produce the directory.
37 | ```
38 | Download both instances_val2017.json and val2017.
39 | Run the command below to produce the val2017_masked directory.
40 | ```bash
41 | python prepare_data.py
42 | ```
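For reference, a minimal sketch for fetching the COCO 2017 validation images and annotations using the standard cocodataset.org download links (prepare_data.py additionally requires `pycocotools`):
```bash
cd data/coco
wget http://images.cocodataset.org/zips/val2017.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
unzip val2017.zip                    # -> val2017/
unzip annotations_trainval2017.zip   # -> annotations/instances_val2017.json (among others)
python prepare_data.py               # writes coco_eval.csv and val2017_masked/
```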
43 |
44 | ### CIRR
45 |
46 | ```
47 | cirr
48 | ├── captions
49 | │   └── cap.rc2.val.json
50 | ├── dev
51 | └── image_splits
52 |     └── split.rc2.val.json
53 | ```
54 | Download the images following the instructions on [CIRR](https://github.com/Cuberick-Orion/CIRR).
55 |
56 | ### Fashion-IQ
57 |
58 | ```
59 | fashion-iq
60 | ├── json
61 | │   ├── cap.dress.val.json
62 | │   ├── cap.shirt.val.json
63 | │   └── cap.toptee.val.json
64 | ├── image_splits
65 | │   ├── split.dress.val.json
66 | │   ├── split.shirt.val.json
67 | │   └── split.toptee.val.json
68 | └── images ## images under this directory.
69 | ```
70 | The JSON files are available at https://github.com/XiaoxiaoGuo/fashion-iq.
71 | The images can be downloaded from https://github.com/postBG/CosMo.pytorch.
72 |
73 |
74 |
75 |
--------------------------------------------------------------------------------
/data/coco/prepare_data.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from pycocotools.coco import COCO
16 | from collections import defaultdict
17 | import random
18 | import pandas as pd
19 | from PIL import Image
20 | import json
21 | import numpy as np
22 | import os
23 |
24 | coco = COCO(annotation_file='annotations/instances_val2017.json')
25 | cat_ids = coco.getCatIds()
26 | def convert_coco_json_to_csv(filename='./annotations/instances_val2017.json', root='./val2017'):
27 | s = json.load(open(filename, 'r'))
28 | out_file = 'coco_eval.csv'
29 | mask_dir = root+"_masked"
30 | if not os.path.exists(mask_dir):
31 | os.makedirs(mask_dir)
32 | out = open(out_file, 'w')
33 | out.write('id,query_regions,query_class,classes\n')
34 | all_ids = []
35 | dict_id2cat = {item['id']:item['name'] for item in s['categories']}
36 | for im in s['images']:
37 | all_ids.append(im['id'])
38 | all_ids_ann = []
39 | id2anns = defaultdict(list)
40 | for ann in s['annotations']:
41 | image_id = ann['image_id']
42 | all_ids_ann.append(image_id)
43 | x1 = ann['bbox'][0]
44 | x2 = ann['bbox'][0] + ann['bbox'][2]
45 | y1 = ann['bbox'][1]
46 | y2 = ann['bbox'][1] + ann['bbox'][3]
47 | label = dict_id2cat[ann['category_id']]
48 | tmp = [x1, y1, x2, y2, label, ann]
49 | id2anns[image_id].append(tmp)
50 | # Give query regions + classes not included in the query as a hint to retrieve images.
51 | class_count = 0
52 | for id_img in id2anns.keys():
53 | anns = id2anns[id_img]
54 | label_set = {}
55 | for ann in anns:
56 | label_set[ann[-2]] = label_set.get(ann[-2], 0) + 1
57 | label_set = list(label_set.keys())
58 | class_count += len(label_set)
59 | output = "%012d.jpg," %id_img
60 | image = Image.open(os.path.join(root, "%012d.jpg" %id_img))
61 | image = np.array(image)
62 | width, height = image.shape[0], image.shape[1]
63 | area_img = width * height
64 | cand_query = []
65 | for cand in anns:
66 | x1, y1, x2, y2 = map(lambda x: float(x), cand[:-2])
67 | area = (x2-x1) * (y2-y1)
68 | if 0.05 < area < 0.5 * area_img:
69 | cand_query.append(cand)
70 | if len(cand_query) >= 1:
71 | query_regions = random.sample(cand_query, k=1)
72 | for region in query_regions:
73 | query_label = region[-2]
74 | ann_region = region[-1]
75 |
76 | id_img = ann_region['image_id']
77 | filename = coco.imgs[id_img]['file_name']
78 | image = Image.open(os.path.join(root, filename))
79 | image = np.array(image)
80 | mask = coco.annToMask(ann_region)
81 | width, height = mask.shape
82 | mask = mask.reshape(width, height,1)
83 | if len(image.shape) == 2:
84 | image = image.reshape(width, height, 1)
85 | image_masked = image * mask + (1-mask) * 255
86 | try:
87 | im = Image.fromarray(image_masked)
88 | except:
89 | image_masked = np.squeeze(image_masked, axis=2)
90 | im = Image.fromarray(image_masked)
91 | im.save(os.path.join(mask_dir, filename))
92 |
93 | label_set.remove(query_label)
94 | output += ";".join(map(lambda x: str(x), region[:-2]))
95 | output += " "
96 | output += ","
97 | output += query_label
98 | output += ","
99 | output += ";".join(label_set)
100 | output += "\n"
101 | out.write(output)
102 | out.close()
103 | # Sort file by image id
104 | s1 = pd.read_csv(out_file)
105 | s1.sort_values('id', inplace=True)
106 | s1.to_csv(out_file, index=False)
107 |
108 | convert_coco_json_to_csv()
109 |
--------------------------------------------------------------------------------
/env.sh:
--------------------------------------------------------------------------------
1 | . env3/bin/activate
2 | export PYTHONPATH="$PYTHONPATH:$PWD/src"
3 | export PYTHONWARNINGS='ignore:semaphore_tracker:UserWarning'
4 |
5 |
--------------------------------------------------------------------------------
/model/clip.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Most code is from https://github.com/openai/CLIP
16 | import hashlib
17 | import os
18 | import urllib
19 | import warnings
20 | from typing import Union, List
21 | import torch
22 | from PIL import Image
23 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, RandomResizedCrop
24 | from tqdm import tqdm
25 | from model.model import build_model
26 | from third_party.open_clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
27 |
28 | from functools import *
29 | try:
30 | from huggingface_hub import hf_hub_download
31 | __version__ = '2.0.2'
32 | hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__)
33 | _has_hf_hub = True
34 | except ImportError:
35 | hf_hub_download = None
36 | _has_hf_hub = False
37 |
38 | __all__ = ["available_models", "load", "tokenize"]
39 | _tokenizer = _Tokenizer()
40 |
41 | _MODELS = {
42 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
43 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
44 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
45 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
46 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
47 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
48 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
49 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
50 | }
51 | _OPENAI = {
52 |     "ViT-H-14": 'laion/CLIP-ViT-H-14-laion2B-s32B-b79K/',
53 | }
54 |
55 |
56 | def has_hf_hub(necessary=False):
57 | if not _has_hf_hub and necessary:
58 | # if no HF Hub module installed, and it is necessary to continue, raise error
59 | raise RuntimeError(
60 | 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
61 | return _has_hf_hub
62 |
63 | def download_pretrained_from_hf(
64 | model_id: str,
65 | filename: str = 'open_clip_pytorch_model.bin',
66 | revision=None,
67 | cache_dir: Union[str, None] = None,
68 | ):
69 | has_hf_hub(True)
70 | cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
71 | return cached_file
72 |
73 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
74 | os.makedirs(root, exist_ok=True)
75 | filename = os.path.basename(url)
76 |
77 | expected_sha256 = url.split("/")[-2]
78 | download_target = os.path.join(root, filename)
79 |
80 | if os.path.exists(download_target) and not os.path.isfile(download_target):
81 | raise RuntimeError(f"{download_target} exists and is not a regular file")
82 |
83 | if os.path.isfile(download_target):
84 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
85 | return download_target
86 | else:
87 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
88 |
89 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
90 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
91 | while True:
92 | buffer = source.read(8192)
93 | if not buffer:
94 | break
95 |
96 | output.write(buffer)
97 | loop.update(len(buffer))
98 |
99 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
100 |         raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match")
101 |
102 | return download_target
103 |
104 | def _convert_to_rgb(image):
105 | return image.convert('RGB')
106 |
107 | def _transform(n_px: int, is_train: bool):
108 | normalize = Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
109 | if is_train:
110 | return Compose([
111 | RandomResizedCrop(n_px, scale=(0.9, 1.0), interpolation=Image.BICUBIC),
112 | _convert_to_rgb,
113 | ToTensor(),
114 | normalize,
115 | ])
116 | else:
117 | return Compose([
118 | Resize(n_px, interpolation=Image.BICUBIC),
119 | CenterCrop(n_px),
120 | _convert_to_rgb,
121 | ToTensor(),
122 | normalize,
123 | ])
124 |
125 |
126 |
127 | def available_models() -> List[str]:
128 | """Returns the names of available CLIP models"""
129 | return list(_MODELS.keys())
130 |
131 |
132 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True, is_train=False, pretrained=True):
133 | """Load a CLIP model
134 | Parameters
135 | ----------
136 | name : str
137 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
138 | device : Union[str, torch.device]
139 | The device to put the loaded model
140 | jit : bool
141 | Whether to load the optimized JIT model (default) or more hackable non-JIT model.
142 | Returns
143 | -------
144 | model : torch.nn.Module
145 | The CLIP model
146 | preprocess : Callable[[PIL.Image], torch.Tensor]
147 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
148 | """
149 | if name in _MODELS:
150 | model_path = _download(_MODELS[name])
151 | elif os.path.isfile(name):
152 | model_path = name
153 | elif name in _OPENAI:
154 | has_hf_hub(True)
155 | # we assume the hf_hub entries in pretrained config combine model_id + filename in
156 | # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
157 | # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
158 | model_id, filename = os.path.split(_OPENAI[name])
159 | if filename:
160 | model_path = download_pretrained_from_hf(model_id, filename=filename)
161 | else:
162 | model_path = download_pretrained_from_hf(model_id)
163 | else:
164 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
165 |
166 | try:
167 | # loading JIT archive
168 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
169 | state_dict = None
170 | except RuntimeError:
171 | # loading saved state dict
172 | if jit:
173 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
174 | jit = False
175 | state_dict = torch.load(model_path, map_location="cpu")
176 |
177 | if not jit:
178 | try:
179 | model = build_model(state_dict or model.state_dict()).to(device)
180 | except KeyError:
181 | sd = {k[7:]: v for k,v in state_dict["state_dict"].items()}
182 | model = build_model(sd).to(device)
183 |
184 | if str(device) == "cpu":
185 | model.float()
186 | return model, \
187 | _transform(model.visual.input_resolution, is_train=True), \
188 | _transform(model.visual.input_resolution, is_train=False)
189 |
190 | # patch the device names
191 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
192 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
193 |
194 | def patch_device(module):
195 | graphs = [module.graph] if hasattr(module, "graph") else []
196 | if hasattr(module, "forward1"):
197 | graphs.append(module.forward1.graph)
198 |
199 | for graph in graphs:
200 | for node in graph.findAllNodes("prim::Constant"):
201 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
202 | node.copyAttributes(device_node)
203 |
204 | model.apply(patch_device)
205 | patch_device(model.encode_image)
206 | patch_device(model.encode_text)
207 |
208 | # patch dtype to float32 on CPU
209 | if str(device) == "cpu":
210 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
211 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
212 | float_node = float_input.node()
213 |
214 | def patch_float(module):
215 | graphs = [module.graph] if hasattr(module, "graph") else []
216 | if hasattr(module, "forward1"):
217 | graphs.append(module.forward1.graph)
218 |
219 | for graph in graphs:
220 | for node in graph.findAllNodes("aten::to"):
221 | inputs = list(node.inputs())
222 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
223 | if inputs[i].node()["value"] == 5:
224 | inputs[i].node().copyAttributes(float_node)
225 |
226 | model.apply(patch_float)
227 | patch_float(model.encode_image)
228 | patch_float(model.encode_text)
229 |
230 | model.float()
231 |
232 | return model, \
233 | _transform(model.input_resolution.item(), is_train=True), \
234 | _transform(model.input_resolution.item(), is_train=False)
235 |
236 |
237 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
238 | """
239 | Returns the tokenized representation of given input string(s)
240 | Parameters
241 | ----------
242 | texts : Union[str, List[str]]
243 | An input string or a list of input strings to tokenize
244 | context_length : int
245 | The context length to use; all CLIP models use 77 as the context length
246 | Returns
247 | -------
248 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
249 | """
250 | if isinstance(texts, str):
251 | texts = [texts]
252 |
253 |     sot_token = _tokenizer.encoder["<|startoftext|>"]
254 |     eot_token = _tokenizer.encoder["<|endoftext|>"]
255 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
256 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
257 |
258 | for i, tokens in enumerate(all_tokens):
259 | if len(tokens) > context_length: # Truncate
260 | tokens = tokens[:context_length-1]
261 | tokens = tokens + [eot_token]
262 | result[i, :len(tokens)] = torch.tensor(tokens)
263 |
264 | return result
265 |
--------------------------------------------------------------------------------
/model/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from collections import OrderedDict
16 | from typing import Tuple, Union
17 |
18 | import os
19 | import json
20 | from copy import deepcopy
21 | import numpy as np
22 | import torch
23 | import torch.nn.functional as F
24 | from torch import nn
25 | import torch.distributed as dist
26 |
27 | class IM2TEXT(nn.Module):
28 | def __init__(self, embed_dim=512, middle_dim=512, output_dim=512, n_layer=2, dropout=0.1):
29 | super().__init__()
30 | self.fc_out = nn.Linear(middle_dim, output_dim)
31 | layers = []
32 | dim = embed_dim
33 | for _ in range(n_layer):
34 | block = []
35 | block.append(nn.Linear(dim, middle_dim))
36 | block.append(nn.Dropout(dropout))
37 | block.append(nn.ReLU())
38 | dim = middle_dim
39 | layers.append(nn.Sequential(*block))
40 | self.layers = nn.Sequential(*layers)
41 |
42 | def forward(self, x: torch.Tensor):
43 | for layer in self.layers:
44 | x = layer(x)
45 | return self.fc_out(x)
46 |
47 | class Bottleneck(nn.Module):
48 | expansion = 4
49 |
50 | def __init__(self, inplanes, planes, stride=1):
51 | super().__init__()
52 |
53 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
54 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
55 | self.bn1 = nn.BatchNorm2d(planes)
56 |
57 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
58 | self.bn2 = nn.BatchNorm2d(planes)
59 |
60 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
61 |
62 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
63 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
64 |
65 | self.relu = nn.ReLU(inplace=True)
66 | self.downsample = None
67 | self.stride = stride
68 |
69 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
70 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
71 | self.downsample = nn.Sequential(OrderedDict([
72 | ("-1", nn.AvgPool2d(stride)),
73 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
74 | ("1", nn.BatchNorm2d(planes * self.expansion))
75 | ]))
76 |
77 | def forward(self, x: torch.Tensor):
78 | identity = x
79 |
80 | out = self.relu(self.bn1(self.conv1(x)))
81 | out = self.relu(self.bn2(self.conv2(out)))
82 | out = self.avgpool(out)
83 | out = self.bn3(self.conv3(out))
84 |
85 | if self.downsample is not None:
86 | identity = self.downsample(x)
87 |
88 | out += identity
89 | out = self.relu(out)
90 | return out
91 |
92 |
93 | class AttentionPool2d(nn.Module):
94 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
95 | super().__init__()
96 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
97 | self.k_proj = nn.Linear(embed_dim, embed_dim)
98 | self.q_proj = nn.Linear(embed_dim, embed_dim)
99 | self.v_proj = nn.Linear(embed_dim, embed_dim)
100 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
101 | self.num_heads = num_heads
102 |
103 | def forward(self, x):
104 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
105 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
106 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
107 | x, _ = F.multi_head_attention_forward(
108 | query=x, key=x, value=x,
109 | embed_dim_to_check=x.shape[-1],
110 | num_heads=self.num_heads,
111 | q_proj_weight=self.q_proj.weight,
112 | k_proj_weight=self.k_proj.weight,
113 | v_proj_weight=self.v_proj.weight,
114 | in_proj_weight=None,
115 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
116 | bias_k=None,
117 | bias_v=None,
118 | add_zero_attn=False,
119 | dropout_p=0,
120 | out_proj_weight=self.c_proj.weight,
121 | out_proj_bias=self.c_proj.bias,
122 | use_separate_proj_weight=True,
123 | training=self.training,
124 | need_weights=False
125 | )
126 |
127 | return x[0]
128 |
129 |
130 | class ModifiedResNet(nn.Module):
131 | """
132 | A ResNet class that is similar to torchvision's but contains the following changes:
133 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
134 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
135 | - The final pooling layer is a QKV attention instead of an average pool
136 | """
137 |
138 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
139 | super().__init__()
140 | self.output_dim = output_dim
141 | self.input_resolution = input_resolution
142 |
143 | # the 3-layer stem
144 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
145 | self.bn1 = nn.BatchNorm2d(width // 2)
146 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
147 | self.bn2 = nn.BatchNorm2d(width // 2)
148 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
149 | self.bn3 = nn.BatchNorm2d(width)
150 | self.avgpool = nn.AvgPool2d(2)
151 | self.relu = nn.ReLU(inplace=True)
152 |
153 | # residual layers
154 | self._inplanes = width # this is a *mutable* variable used during construction
155 | self.layer1 = self._make_layer(width, layers[0])
156 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
157 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
158 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
159 |
160 | embed_dim = width * 32 # the ResNet feature dimension
161 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
162 |
163 | def _make_layer(self, planes, blocks, stride=1):
164 | layers = [Bottleneck(self._inplanes, planes, stride)]
165 |
166 | self._inplanes = planes * Bottleneck.expansion
167 | for _ in range(1, blocks):
168 | layers.append(Bottleneck(self._inplanes, planes))
169 |
170 | return nn.Sequential(*layers)
171 |
172 | def forward(self, x):
173 | def stem(x):
174 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
175 | x = self.relu(bn(conv(x)))
176 | x = self.avgpool(x)
177 | return x
178 |
179 | x = x.type(self.conv1.weight.dtype)
180 | x = stem(x)
181 | x = self.layer1(x)
182 | x = self.layer2(x)
183 | x = self.layer3(x)
184 | x = self.layer4(x)
185 | x = self.attnpool(x)
186 |
187 | return x
188 |
189 |
190 | class LayerNorm(nn.LayerNorm):
191 | """Subclass torch's LayerNorm to handle fp16."""
192 |
193 | def forward(self, x: torch.Tensor):
194 | orig_type = x.dtype
195 | ret = super().forward(x.type(torch.float32))
196 | return ret.type(orig_type)
197 |
198 |
199 | class QuickGELU(nn.Module):
200 | def forward(self, x: torch.Tensor):
201 | return x * torch.sigmoid(1.702 * x)
202 |
203 |
204 | class ResidualAttentionBlock(nn.Module):
205 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
206 | super().__init__()
207 |
208 | self.attn = nn.MultiheadAttention(d_model, n_head)
209 | self.ln_1 = LayerNorm(d_model)
210 | self.mlp = nn.Sequential(OrderedDict([
211 | ("c_fc", nn.Linear(d_model, d_model * 4)),
212 | ("gelu", QuickGELU()),
213 | ("c_proj", nn.Linear(d_model * 4, d_model))
214 | ]))
215 | self.ln_2 = LayerNorm(d_model)
216 | self.attn_mask = attn_mask
217 |
218 | def attention(self, x: torch.Tensor):
219 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
220 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
221 |
222 | def forward(self, x: torch.Tensor):
223 | x = x + self.attention(self.ln_1(x))
224 | x = x + self.mlp(self.ln_2(x))
225 | return x
226 |
227 |
228 | class Transformer(nn.Module):
229 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
230 | super().__init__()
231 | self.width = width
232 | self.layers = layers
233 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
234 |
235 | def forward(self, x: torch.Tensor):
236 | return self.resblocks(x)
237 |
238 |
239 | class VisualTransformer(nn.Module):
240 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
241 | super().__init__()
242 | self.input_resolution = input_resolution
243 | self.output_dim = output_dim
244 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
245 |
246 | scale = width ** -0.5
247 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
248 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
249 | self.ln_pre = LayerNorm(width)
250 |
251 | self.transformer = Transformer(width, layers, heads)
252 |
253 | self.ln_post = LayerNorm(width)
254 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
255 |
256 | def forward(self, x: torch.Tensor):
257 | x = self.conv1(x) # shape = [*, width, grid, grid]
258 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
259 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
260 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
261 | x = x + self.positional_embedding.to(x.dtype)
262 | x = self.ln_pre(x)
263 |
264 | x = x.permute(1, 0, 2) # NLD -> LND
265 | x = self.transformer(x)
266 | x = x.permute(1, 0, 2) # LND -> NLD
267 |
268 | x = self.ln_post(x[:, 0, :])
269 |
270 | if self.proj is not None:
271 | x = x @ self.proj
272 |
273 | return x
274 |
275 | def get_tokens(self, x: torch.Tensor):
276 | x = self.conv1(x) # shape = [*, width, grid, grid]
277 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
278 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
279 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
280 | x = x + self.positional_embedding.to(x.dtype)
281 | x = self.ln_pre(x)
282 | x = x.permute(1, 0, 2) # NLD -> LND
283 | x = self.transformer(x)
284 | x = x.permute(1, 0, 2) # LND -> NLD
285 | return x
286 |
287 |
288 | class CLIP(nn.Module):
289 | def __init__(self,
290 | embed_dim: int,
291 | # vision
292 | image_resolution: int,
293 | vision_layers: Union[Tuple[int, int, int, int], int],
294 | vision_width: int,
295 | vision_patch_size: int,
296 | # text
297 | context_length: int,
298 | vocab_size: int,
299 | transformer_width: int,
300 | transformer_heads: int,
301 | transformer_layers: int,
302 | extra_transformer_layers: int = 0,
303 | share_projection_layer: bool = True,
304 | ):
305 | super().__init__()
306 | self.embed_dim = embed_dim
307 | self.context_length = context_length
308 | self.share_projection_layer = share_projection_layer
309 | self.has_extra = True if extra_transformer_layers > 0 else False
310 |
311 | if isinstance(vision_layers, (tuple, list)):
312 | vision_heads = vision_width * 32 // 64
313 | self.visual = ModifiedResNet(
314 | layers=vision_layers,
315 | output_dim=embed_dim,
316 | heads=vision_heads,
317 | input_resolution=image_resolution,
318 | width=vision_width
319 | )
320 | else:
321 | vision_heads = vision_width // 64
322 | self.visual = VisualTransformer(
323 | input_resolution=image_resolution,
324 | patch_size=vision_patch_size,
325 | width=vision_width,
326 | layers=vision_layers,
327 | heads=vision_heads,
328 | output_dim=embed_dim
329 | )
330 | self.transformer_width = transformer_width
331 | self.transformer = Transformer(
332 | width=transformer_width,
333 | layers=transformer_layers,
334 | heads=transformer_heads,
335 | attn_mask=self.build_attention_mask()
336 | )
337 | if extra_transformer_layers > 0:
338 | self.extra_transformer = Transformer(
339 | width=transformer_width,
340 | layers=extra_transformer_layers,
341 | heads=transformer_heads,
342 | attn_mask=self.build_attention_mask()
343 | )
344 | self.extra_ln_final = LayerNorm(transformer_width)
345 |
346 | self.vocab_size = vocab_size
347 | self.end_id = self.vocab_size -1
348 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
349 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
350 | self.ln_final = LayerNorm(transformer_width)
351 |
352 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
353 | if not share_projection_layer:
354 | self.extra_text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
355 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
356 |
357 | self.initialize_parameters()
358 |
359 | def initialize_parameters(self):
360 | nn.init.normal_(self.token_embedding.weight, std=0.02)
361 | nn.init.normal_(self.positional_embedding, std=0.01)
362 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
363 |
364 | if isinstance(self.visual, ModifiedResNet):
365 | if self.visual.attnpool is not None:
366 | std = self.visual.attnpool.c_proj.in_features ** -0.5
367 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
368 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
369 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
370 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
371 |
372 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
373 | for name, param in resnet_block.named_parameters():
374 | if name.endswith("bn3.weight"):
375 | nn.init.zeros_(param)
376 |
377 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
378 | attn_std = self.transformer.width ** -0.5
379 | fc_std = (2 * self.transformer.width) ** -0.5
380 | for block in self.transformer.resblocks:
381 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
382 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
383 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
384 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
385 |
386 | if self.text_projection is not None:
387 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
388 | if hasattr(self, 'extra_text_projection'):
389 | nn.init.normal_(self.extra_text_projection, std=self.transformer.width ** -0.5)
390 |
391 | def build_attention_mask(self):
392 | # lazily create causal attention mask, with full attention between the vision tokens
393 | # pytorch uses additive attention mask; fill with -inf
394 | mask = torch.empty(self.context_length, self.context_length)
395 | mask.fill_(float("-inf"))
396 | mask.triu_(1) # zero out the lower diagonal
397 | return mask
398 |
399 | @property
400 | def dtype(self):
401 | return self.visual.conv1.weight.dtype
402 |
403 | def encode_image(self, image):
404 | return self.visual(image.type(self.dtype))
405 |
406 | def encode_text(self, text):
407 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
408 |
409 | x = x + self.positional_embedding.type(self.dtype)
410 | x = x.permute(1, 0, 2) # NLD -> LND
411 | x = self.transformer(x)
412 | x = x.permute(1, 0, 2) # LND -> NLD
413 | x = self.ln_final(x).type(self.dtype)
414 | # x.shape = [batch_size, n_ctx, transformer.width]
415 | # take features from the eot embedding (eot_token is the highest number in each sequence)
416 | collect_ind = text == self.end_id
417 | collect_ind = collect_ind.nonzero()[:, 1]
418 | x = x[torch.arange(x.size(0)), collect_ind] @ self.text_projection
419 | return x
420 |
421 |
422 | def encode_text_img(self, text, img_tokens):
423 | b_size = img_tokens.size(0)
424 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
425 | collect_ind = text == self.end_id
426 | collect_ind = collect_ind.nonzero()[:, 1]
427 | img_tokens = img_tokens.view(b_size, 1, -1)
428 | x = torch.cat([x[:, :collect_ind[0]], img_tokens, x[:, collect_ind[0]:-1]], dim=1)
429 | x = x + self.positional_embedding.type(self.dtype)
430 | x = x.permute(1, 0, 2) # NLD -> LND
431 | x = self.transformer(x)
432 | x = x.permute(1, 0, 2) # LND -> NLD
433 | x = self.ln_final(x).type(self.dtype)
434 | # x.shape = [batch_size, n_ctx, transformer.width]
435 | # take features from the eot embedding (eot_token is the highest number in each sequence)
436 | x = x[torch.arange(x.size(0)), collect_ind+1] @ self.text_projection
437 | return x
438 |
439 | def encode_text_img_vis(self, text, img_tokens, split_ind=4):
440 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
441 | collect_ind = text == self.end_id
442 | collect_ind = collect_ind.nonzero()[:, 1]
443 | new_x = []
444 | for i, sample in enumerate(x):
445 | ind_insert = text[i] == split_ind
446 | sample = sample.view(1, x.size(1), -1)
447 | if isinstance(img_tokens, tuple):
448 | indexes = ind_insert.nonzero()
449 | for i, index in enumerate(indexes):
450 | img = img_tokens[i].view(1, 1, -1)
451 | sample = torch.cat([sample[:, :index], img, sample[:, index+1:]], dim=1)
452 | else:
453 | img_tokens = img_tokens.view(1, 1, -1)
454 | ind_insert = ind_insert.nonzero()[0]
455 | sample = torch.cat([sample[:, :ind_insert], img_tokens, sample[:, ind_insert+1:]], dim=1)
456 | new_x.append(sample)
457 | x = torch.cat(new_x, dim=0)
458 | x = x + self.positional_embedding.type(self.dtype)
459 | x = x.permute(1, 0, 2) # NLD -> LND
460 | x = self.transformer(x)
461 | x = x.permute(1, 0, 2) # LND -> NLD
462 | x = self.ln_final(x).type(self.dtype)
463 | # x.shape = [batch_size, n_ctx, transformer.width]
464 | # take features from the eot embedding (eot_token is the highest number in each sequence)
465 | x = x[torch.arange(x.size(0)), collect_ind] @ self.text_projection
466 | return x
467 |
468 | def encode_text_img_retrieval(self, text, img_tokens, split_ind=4, repeat=True):
469 | # text.shape = [1, n_ctx]
470 | # img_tokens.shape = [batch_size, d_model]
471 | if isinstance(img_tokens, tuple):
472 | b_size = img_tokens[0].shape[0]
473 | else:
474 | b_size = img_tokens.shape[0]
475 | if repeat:
476 | text = text.repeat(b_size, 1)
477 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
478 | collect_ind = text == self.end_id
479 | collect_ind = collect_ind.nonzero()[:, 1]
480 | ind_insert = text[0] == split_ind
481 | if isinstance(img_tokens, tuple):
482 | indexes = ind_insert.nonzero()
483 | for i, index in enumerate(indexes):
484 | img = img_tokens[i].view(b_size, 1, -1)
485 | x = torch.cat([x[:, :index], img, x[:, index+1:]], dim=1)
486 | else:
487 | img_tokens = img_tokens.view(b_size, 1, -1)
488 | ind_insert = ind_insert.nonzero()[0]
489 | x = torch.cat([x[:, :ind_insert], img_tokens, x[:, ind_insert+1:]], dim=1)
490 | #x = torch.cat([x, torch.zeros_like(x).cuda()[:, :1, :]], dim=1)
491 | x = x + self.positional_embedding.type(self.dtype)
492 | x = x.permute(1, 0, 2) # NLD -> LND
493 | x = self.transformer(x)
494 | x = x.permute(1, 0, 2) # LND -> NLD
495 | x = self.ln_final(x).type(self.dtype)
496 | # x.shape = [batch_size, n_ctx, transformer.width]
497 | # take features from the eot embedding (eot_token is the highest number in each sequence)
498 | x = x[torch.arange(x.size(0)), collect_ind] @ self.text_projection
499 | return x
500 |
501 | def forward(self, image, text, extra=False):
502 | if image is None:
503 | if extra:
504 | return self.encode_text_extra(text)
505 | else:
506 | return self.encode_text(text)
507 | elif text is None:
508 | return self.encode_image(image)
509 | image_features = self.encode_image(image)
510 | if extra:
511 | text_features = self.encode_text_extra(text)
512 | else:
513 | text_features = self.encode_text(text)
514 |
515 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
516 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
517 |
518 | return image_features, text_features, self.logit_scale.exp()
519 |
520 |
521 | @torch.no_grad()
522 | def concat_all_gather(tensor):
523 | """
524 | Performs all_gather operation on the provided tensors.
525 | *** Warning ***: torch.distributed.all_gather has no gradient.
526 | """
527 | tensors_gather = [torch.ones_like(tensor)
528 | for _ in range(torch.distributed.get_world_size())]
529 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
530 |
531 | output = torch.cat(tensors_gather, dim=0)
532 | return output
533 |
534 | def convert_weights(model: nn.Module):
535 | """Convert applicable model parameters to fp16"""
536 |
537 | def _convert_weights_to_fp16(l):
538 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
539 | l.weight.data = l.weight.data.half()
540 | if l.bias is not None:
541 | l.bias.data = l.bias.data.half()
542 |
543 | if isinstance(l, nn.MultiheadAttention):
544 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
545 | tensor = getattr(l, attr)
546 | if tensor is not None:
547 | tensor.data = tensor.data.half()
548 |
549 | for name in ["text_projection", "proj"]:
550 | if hasattr(l, name):
551 | attr = getattr(l, name)
552 | if attr is not None:
553 | attr.data = attr.data.half()
554 |
555 | model.apply(_convert_weights_to_fp16)
556 |
557 |
558 | def build_model(state_dict: dict):
559 | vit = "visual.proj" in state_dict
560 |
561 | if vit:
562 | vision_width = state_dict["visual.conv1.weight"].shape[0]
563 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
564 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
565 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
566 | image_resolution = vision_patch_size * grid_size
567 | else:
568 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
569 | vision_layers = tuple(counts)
570 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
571 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
572 | vision_patch_size = None
573 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
574 | image_resolution = output_width * 32
575 |
576 | embed_dim = state_dict["text_projection"].shape[1]
577 | context_length = state_dict["positional_embedding"].shape[0]
578 | vocab_size = state_dict["token_embedding.weight"].shape[0]
579 | transformer_width = state_dict["ln_final.weight"].shape[0]
580 | transformer_heads = transformer_width // 64
581 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
582 |
583 | model = CLIP(
584 | embed_dim,
585 | image_resolution, vision_layers, vision_width, vision_patch_size,
586 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
587 | )
588 |
589 | for key in ["input_resolution", "context_length", "vocab_size"]:
590 | if key in state_dict:
591 | del state_dict[key]
592 |
593 | convert_weights(model)
594 | model.load_state_dict(state_dict)
595 | return model.eval()
--------------------------------------------------------------------------------
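A minimal sketch, not from the repository, of loading a checkpoint with the `build_model` helper defined above. The checkpoint path is hypothetical, and it assumes the file holds a plain state dict rather than a TorchScript archive (for OpenAI JIT checkpoints the state dict has to be extracted first).

```python
import torch

from model.model import build_model

# Hypothetical path; assumed to contain a plain CLIP state dict.
state_dict = torch.load("checkpoints/ViT-B-32_statedict.pt", map_location="cpu")
model = build_model(state_dict)  # architecture is inferred from the weight shapes
model.float()                    # build_model converts to fp16; cast back for CPU inference
```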
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | scikit-image
3 | scikit-learn
4 | torch
5 | torchvision
6 | tensorboard
7 | ase==3.21.1
8 | braceexpand==0.1.7
9 | cached-property==1.5.2
10 | configparser==5.0.2
11 | cycler==0.10.0
12 | decorator==4.4.2
13 | docker-pycreds==0.4.0
14 | gitdb==4.0.7
15 | gitpython==3.1.30
16 | googledrivedownloader==0.4
17 | h5py==3.1.0
18 | isodate==0.6.0
19 | jinja2==3.0.1
20 | kiwisolver==1.3.1
21 | littleutils==0.2.2
22 | llvmlite==0.36.0
23 | markupsafe==2.0.1
24 | matplotlib==3.3.4
25 | networkx==2.5.1
26 | numba==0.53.1
27 | ogb==1.3.1
28 | outdated==0.2.1
29 | pathtools==0.1.2
30 | promise==2.3
31 | psutil==5.8.0
32 | pyarrow==4.0.0
33 | pyparsing==2.4.7
34 | python-louvain==0.15
35 | pyyaml==5.4.1
36 | rdflib==5.0.0
37 | sentry-sdk==1.14.0
38 | shortuuid==1.0.1
39 | sklearn==0.0
40 | smmap==4.0.0
41 | subprocess32==3.5.4
42 | torch-geometric==1.7.0
43 | wandb==0.10.30
44 | wilds==1.1.0
45 | ftfy
46 | regex
47 | webdataset
48 | requests
49 | hydra-core
50 | omegaconf
51 | fairseq==0.10.0
52 | bitarray
--------------------------------------------------------------------------------
/setenv.sh:
--------------------------------------------------------------------------------
1 | sudo apt install python3-dev python3-virtualenv python3-tk imagemagick
2 | virtualenv -p python3 --system-site-packages env3
3 | . env3/bin/activate
4 | pip install -r requirements.txt
5 | deactivate
6 |
--------------------------------------------------------------------------------
/src/data.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import sys
17 | import math
18 | import logging
19 | import functools
20 | import braceexpand
21 | import random
22 | import pdb
23 | import json
24 |
25 | import pandas as pd
26 | import numpy as np
27 | import pyarrow as pa
28 | from PIL import Image
29 | Image.MAX_IMAGE_PIXELS = 1000000000
30 |
31 | from typing import Union
32 | from dataclasses import dataclass
33 | import torch
34 | import torch.distributed as dist
35 | from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
36 | from torch.utils.data.distributed import DistributedSampler
37 | import torchvision.datasets as datasets
38 | from torchvision.datasets.folder import DatasetFolder
39 | import torchvision.datasets as datasets
40 | import torchvision.transforms as T
41 | from third_party.open_clip.clip import tokenize
42 |
43 |
44 | ## Structure of dataset directory
45 | ## CIRR: under ./data/CIRR
46 | ## validation images ./dev/
47 | ## caption split ./captions/cap.rc2.val.json
48 | ## image split ./image_splits/split.rc2.val.json
49 | class CIRR(Dataset):
50 | def __init__(self, transforms, mode='caps',
51 | vis_mode=False, test=False, root='./data'):
52 | self.mode = mode
53 | self.transforms = transforms
54 | self.vis_mode = vis_mode
55 | ## mode to use test split of CIRR
56 | self.test = test
57 | self.root = os.path.join(root, 'CIRR')
58 | self.root_img = os.path.join(self.root, 'dev')
59 | if self.test:
60 | self.root_img = os.path.join(self.root, 'test1')
61 | if self.mode == 'caps':
62 | self.json = os.path.join(self.root , 'captions/cap.rc2.test1.json')
63 | else:
64 | self.json = os.path.join(self.root, 'image_splits/split.rc2.test1.json')
65 | else:
66 | if self.mode == 'caps':
67 | self.json = os.path.join(self.root, 'captions/cap.rc2.val.json')
68 | else:
69 | self.json = os.path.join(self.root, 'image_splits/split.rc2.val.json')
70 | logging.debug(f'Loading json data from {self.json}.')
71 | data = json.load(open(self.json, "r"))
72 | self.ref_imgs = []
73 | self.target_imgs = []
74 | self.target_caps = []
75 | if self.test:
76 | self.init_test(data)
77 | elif self.mode == 'caps':
78 | self.init_val(data)
79 | else:
80 | self.target_imgs = [key + ".png" for key in data.keys()]
81 | if self.vis_mode:
82 | self.target_imgs = list(set(self.target_imgs))
83 | logging.info("Use {} imgs".format(len(self.target_imgs)))
84 |
85 | def init_test(self, data):
86 | self.pairids = []
87 | if self.mode == 'caps':
88 | for d in data:
89 | ref_path = d['reference']+ ".png"
90 | self.ref_imgs.append(ref_path)
91 | self.target_caps.append(d['caption'])
92 | self.pairids.append(d['pairid'])
93 | self.target_imgs.append('dummy')
94 | else:
95 | self.target_imgs = [key + ".png" for key in data.keys()]
96 |
97 | def init_val(self, data):
98 | for d in data:
99 | ref_path = d['reference']+ ".png"
100 | tar_path = d['target_hard']+ ".png"
101 | self.ref_imgs.append(ref_path)
102 | self.target_imgs.append(tar_path)
103 | self.target_caps.append(d['caption'])
104 |
105 | def return_testdata(self, idx):
106 | if self.mode == 'caps':
107 | ref_path = str(self.ref_imgs[idx])
108 | img_path = os.path.join(self.root_img, ref_path)
109 | ref_images = self.transforms(Image.open(img_path))
110 | target_cap = self.target_caps[idx]
111 | text_with_blank_raw = 'a photo of * , {}'.format(target_cap)
112 | caption_only = tokenize(target_cap)[0]
113 | text_with_blank = tokenize(text_with_blank_raw)[0]
114 | return ref_images, text_with_blank, \
115 | caption_only, str(self.ref_imgs[idx]), \
116 | self.pairids[idx], text_with_blank_raw
117 | else:
118 | tar_path = str(self.target_imgs[idx])
119 | img_path = Image.open(os.path.join(self.root_img, tar_path))
120 | target_images = self.transforms(img_path)
121 | return target_images, tar_path
122 |
123 | def return_valdata(self, idx):
124 | if self.mode == 'caps' and not self.vis_mode:
125 | ref_path = str(self.ref_imgs[idx])
126 | img_path = os.path.join(self.root_img, ref_path)
127 | ref_images = self.transforms(Image.open(img_path))
128 | target_cap = self.target_caps[idx]
129 | text_with_blank = 'a photo of * , {}'.format(target_cap)
130 | caption_only = tokenize(target_cap)[0]
131 | ref_text_tokens = tokenize(text_with_blank)[0]
132 | return ref_images, ref_text_tokens, caption_only, \
133 | str(self.ref_imgs[idx]), str(self.target_imgs[idx]), \
134 | target_cap
135 | else:
136 | tar_path = str(self.target_imgs[idx])
137 | img_path = os.path.join(self.root_img, tar_path)
138 | target_images = self.transforms(Image.open(img_path))
139 | return target_images, img_path
140 |
141 | def __getitem__(self, idx):
142 | if self.test:
143 | return self.return_testdata(idx)
144 | else:
145 | return self.return_valdata(idx)
146 |
147 | def __len__(self):
148 | return len(self.target_imgs)
149 |
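A minimal sketch of feeding the CIRR validation split defined above to a DataLoader. The plain 224-pixel transform is a stand-in for the repository's own `_transform` preprocessing from model/clip.py, and the batch size is arbitrary.

```python
import torchvision.transforms as T
from torch.utils.data import DataLoader

# Stand-in preprocessing; the eval scripts use _transform(model.visual.input_resolution, ...).
preprocess = T.Compose([T.Resize(224), T.CenterCrop(224), T.ToTensor()])

val_set = CIRR(transforms=preprocess, mode='caps', root='./data')
val_loader = DataLoader(val_set, batch_size=32, shuffle=False, num_workers=4)

# return_valdata() yields: reference image, tokenized "a photo of * , <caption>",
# tokenized caption alone, reference path, target path, and the raw caption.
ref_imgs, ref_tokens, caption_tokens, ref_paths, target_paths, raw_caps = next(iter(val_loader))
```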
150 | ## Fashion-IQ: under ./data/fashion-iq
151 | ## validation images ./images
152 | ## caption split ./json/cap.{cloth_type}.val.json, cloth_type in [toptee, shirt, dress]
153 | ## image split ./image_splits/split.{cloth_type}.val.json, cloth_type in [toptee, shirt, dress]
154 | class FashionIQ(Dataset):
155 | def __init__(self, cloth, transforms, is_train=False, vis_mode=False, \
156 | mode='caps', is_return_target_path=False, root='./data'):
157 | root_iq = os.path.join(root, 'fashion-iq')
158 | self.root_img = os.path.join(root_iq, 'images')
159 | self.vis_mode = vis_mode
160 | self.mode = mode
161 | self.is_return_target_path = is_return_target_path
162 | self.transforms = transforms
163 | if mode == 'imgs':
164 | self.json_file = os.path.join(root_iq, 'image_splits', \
165 | 'split.{}.val.json'.format(cloth))
166 | else:
167 | self.json_file = os.path.join(root_iq, 'json', \
168 | 'cap.{}.val.json'.format(cloth))
169 | logging.debug(f'Loading json data from {self.json_file}.')
170 |
171 | self.ref_imgs = []
172 | self.target_imgs = []
173 | self.ref_caps = []
174 | self.target_caps = []
175 | if mode == 'imgs':
176 | self.init_imgs()
177 | logging.info("Use {} imgs".format(len(self.target_imgs)))
178 | else:
179 | self.init_data()
180 | logging.info("Use {} imgs".format(len(self.target_imgs)))
181 |
182 | def init_imgs(self):
183 | data = json.load(open(self.json_file, "r"))
184 | self.target_imgs = [key + ".png" for key in data]
185 |
186 | def init_data(self):
187 | def load_data(data):
188 | for d in data:
189 | ref_path = os.path.join(self.root_img, d['candidate']+ ".png")
190 | tar_path = os.path.join(self.root_img, d['target']+ ".png")
191 |                 try:
192 |                     Image.open(ref_path)
193 |                     Image.open(tar_path)
194 |                     self.ref_imgs.append(ref_path)
195 |                     self.target_imgs.append(tar_path)
196 |                     self.ref_caps.append((d['captions'][0], d['captions'][1]))
197 |                     #self.target_caps.append(d['captions'][1])
198 |                 except Exception:  # skip pairs whose images cannot be opened
199 |                     print('cannot load {}'.format(d['candidate']))
200 | if isinstance(self.json_file, str):
201 | data = json.load(open(self.json_file, "r"))
202 | load_data(data)
203 | elif isinstance(self.json_file, list):
204 | for filename in self.json_file:
205 | data = json.load(open(filename, "r"))
206 | load_data(data)
207 |
208 | def __len__(self):
209 | if self.mode == 'caps':
210 | return len(self.ref_imgs)
211 | else:
212 | return len(self.target_imgs)
213 |
214 | def return_imgs(self, idx):
215 | tar_path = str(self.target_imgs[idx])
216 | img_path = os.path.join(self.root_img, tar_path)
217 | target_images = self.transforms(Image.open(img_path))
218 | return target_images, os.path.join(self.root_img, tar_path)
219 |
220 | def return_all(self, idx):
221 | if self.vis_mode:
222 | tar_path = str(self.target_imgs[idx])
223 | target_images = self.transforms(Image.open(tar_path))
224 | return target_images, tar_path
225 | ref_images = self.transforms(Image.open(str(self.ref_imgs[idx])))
226 | target_images = self.transforms(Image.open(str(self.target_imgs[idx])))
227 | cap1, cap2 = self.ref_caps[idx]
228 | text_with_blank = 'a photo of * , {} and {}'.format(cap2, cap1)
229 | token_texts = tokenize(text_with_blank)[0]
230 | if self.is_return_target_path:
231 | return ref_images, target_images, token_texts, token_texts, \
232 | str(self.target_imgs[idx]), str(self.ref_imgs[idx]), \
233 | cap1
234 | else:
235 | return ref_images, target_images, text_with_blank
236 |
237 |
238 | def __getitem__(self, idx):
239 | if self.mode == 'imgs':
240 | return self.return_imgs(idx)
241 | else:
242 | return self.return_all(idx)
243 |
244 | ## COCO: under ./data/coco
245 | ## validation images ./val2017
246 | ## validation masked images ./val2017_masked
247 | ## validation csv file ./coco_eval.csv
248 | class CsvCOCO(Dataset):
249 | def __init__(self, transforms, transforms_region, sep=",",
250 | return_data_identifier=False, return_filename=False,
251 | root='./data'):
252 | self.transforms = transforms
253 | self.transforms_region = transforms_region
254 | self.root = os.path.join(root, 'coco')
255 | self.root_img = os.path.join(self.root, 'val2017')
256 | self.csv_file = os.path.join(self.root, 'coco_eval.csv')
257 | logging.debug(f'Loading csv data from {self.csv_file}.')
258 | df = pd.read_csv(self.csv_file, sep=sep)
259 | self.images = df['id'].tolist()
260 | ## query_region contains the box of query regions.
261 | regions = df['query_regions'].tolist()
262 | self.regions = []
263 | for region in regions:
264 | x1, y1, x2, y2 = map(lambda x: int(float(x)), region.split(";"))
265 | self.regions.append([x1, y1, x2, y2])
266 |
267 | ## query_classes contains the class of query region in the target.
268 | self.query_classes = df['query_class'].tolist()
269 | self.classes = []
270 | ## classes contains the list of classes in the target.
271 | for list_class in df['classes'].tolist():
272 | if isinstance(list_class, str):
273 | list_class = list_class.split(";")
274 | self.classes.append(list_class)
275 | else:
276 | self.classes.append([""])
277 | self.return_data_identifier = return_data_identifier
278 | logging.debug('Done loading data.')
279 | self.return_filename = return_filename
280 |
281 | def __len__(self):
282 | return len(self.images)
283 |
284 | def __getitem__(self, idx):
285 | img_path = os.path.join(self.root_img, str(self.images[idx]))
286 | image = Image.open(img_path)
287 | masked_path = os.path.join(self.root_img.replace('val2017', 'val2017_masked'), \
288 | str(self.images[idx]))
289 | image_masked = Image.open(masked_path)
290 |
291 | ## extract query region.
292 | x1, y1, x2, y2 = self.regions[idx]
293 | region_image = image_masked.crop((x1, y1, x2, y2))
294 |
295 | image = self.transforms(image)
296 | ## no cropping is applied to query region.
297 | region_image = self.transforms_region(region_image)
298 | query_class = self.query_classes[idx]
299 | other_classes = self.classes[idx]
300 | text_with_blank = 'a photo of * and {}'.format(" and ".join(other_classes))
301 | text_with_queryclass = 'a photo of * and {} and {}'.format(query_class, \
302 | " and ".join(other_classes))
303 | raw_text = text_with_queryclass
304 | text_full = 'a photo of {} and {}'.format(query_class, " and ".join(other_classes))
305 | text_with_blank = tokenize(text_with_blank)[0]
306 | text_with_queryclass = tokenize(text_with_queryclass)[0]
307 | text_full = tokenize(text_full)[0]
308 | return image, region_image, text_full, text_with_blank, \
309 | text_with_queryclass, str(self.images[idx]), raw_text
310 |
311 |
312 | class ImageList(Dataset):
313 | def __init__(self, input_filename, transforms, root=None,
314 | return_filename=False, is_labels=False):
315 | logging.debug(f'Loading txt data from {input_filename}.')
316 | with open(input_filename, 'r') as f:
317 | lines = f.readlines()
318 | if not is_labels:
319 | self.images = [line.strip() for line in lines]
320 | else:
321 | filenames = [line.strip() for line in lines]
322 | self.images = [name.split(" ")[0] for name in filenames]
323 | self.labels = [int(name.split(" ")[1]) for name in filenames]
324 | self.is_labels = is_labels
325 | self.transforms = transforms
326 | self.root = root
327 | logging.debug('Done loading data.')
328 | self.return_filename = return_filename
329 |
330 | def __len__(self):
331 | return len(self.images)
332 |
333 | def __getitem__(self, idx):
334 | if self.root is not None:
335 | img_path = os.path.join(self.root, str(self.images[idx]))
336 | else:
337 | img_path = str(self.images[idx])
338 | images = self.transforms(Image.open(img_path))
339 | if self.return_filename:
340 | return images, img_path
341 | elif self.is_labels:
342 | target = self.labels[idx]
343 | return images, target
344 | else:
345 | return images
346 |
347 |
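A minimal sketch of reading the labeled ImageNet query list shipped under data/imgnet with the `ImageList` class above, mirroring how eval_retrieval.py builds its loaders; the transform is again a stand-in for the repository's `_transform`.

```python
import torchvision.transforms as T
from torch.utils.data import DataLoader

preprocess = T.Compose([T.Resize(224), T.CenterCrop(224), T.ToTensor()])

# With is_labels=True each line is parsed as "<relative image path> <class label>".
queries = ImageList('data/imgnet/imgnet_real_query.txt',
                    transforms=preprocess,
                    root='data',
                    is_labels=True)
query_loader = DataLoader(queries, batch_size=64, shuffle=False, num_workers=4)
images, labels = next(iter(query_loader))
```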
348 | class CustomFolder(Dataset):
349 | def __init__(self, folder, transform):
350 | image_lists = os.listdir(folder)
351 | self.samples = [os.path.join(folder, name) for name in image_lists]
352 | self.transform = transform
353 |
354 | def __len__(self):
355 | return len(self.samples)
356 |
357 | def __getitem__(self, index: int):
358 | """
359 | Args:
360 | index (int): Index
361 |
362 | Returns:
363 | tuple: (sample, target) where target is class_index of the target class.
364 | """
365 | path = self.samples[index]
366 | sample = Image.open(str(path))
367 | if self.transform is not None:
368 | sample = self.transform(sample)
369 | return sample, path
370 |
371 |
372 | class CsvDataset(Dataset):
373 | def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t",
374 | return_data_identifier=False, return_filename=False):
375 | logging.debug(f'Loading csv data from {input_filename}.')
376 | df = pd.read_csv(input_filename, sep=sep)
377 | self.images = df[img_key].tolist()
378 | self.captions = df[caption_key].tolist()
379 | self.transforms = transforms
380 | self.return_data_identifier = return_data_identifier
381 | logging.debug('Done loading data of {} samples'.format(len(self.images)))
382 | self.return_filename = return_filename
383 |
384 | def __len__(self):
385 | return len(self.captions)
386 |
387 | def __getitem__(self, idx):
388 | images = self.transforms(Image.open(str(self.images[idx])))
389 | if self.return_filename:
390 | return images, str(self.images[idx])
391 | texts = tokenize([str(self.captions[idx])])[0]
392 |
393 | if self.return_data_identifier:
394 | return images, texts, 0
395 | return images, texts
396 |
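A minimal sketch of the `CsvDataset` above on a hypothetical tab-separated file; the file name and the column names here are assumptions (the real values come from args.csv_img_key, args.csv_caption_key, and args.csv_separator in params.py).

```python
import torchvision.transforms as T
from torch.utils.data import DataLoader

preprocess = T.Compose([T.Resize(224), T.CenterCrop(224), T.ToTensor()])

# "train_pairs.tsv", "filepath", and "title" are hypothetical.
train_set = CsvDataset('train_pairs.tsv', preprocess,
                       img_key='filepath', caption_key='title', sep='\t')
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=8)
images, texts = next(iter(train_loader))  # texts are CLIP token ids from tokenize()
```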
397 | @dataclass
398 | class DataInfo:
399 | dataloader: DataLoader
400 | sampler: DistributedSampler
401 |
402 | def preprocess_txt(text):
403 | return tokenize([str(text)])[0]
404 |
405 | def get_dataset_size(shards):
406 | shards_list = list(braceexpand.braceexpand(shards))
407 | dir_path = os.path.dirname(shards)
408 | sizes_filename = os.path.join(dir_path, 'sizes.json')
409 | sizes = json.load(open(sizes_filename, 'r'))
410 | total_size = sum(
411 | [int(sizes[os.path.basename(shard)]) for shard in shards_list])
412 | num_shards = len(shards_list)
413 | return total_size, num_shards
414 |
415 | def get_imagenet(args, preprocess_fns, split):
416 | assert split in ["train", "val", "v2"]
417 | is_train = split == "train"
418 | preprocess_train, preprocess_val = preprocess_fns
419 |
420 | if split == "v2":
421 | from imagenetv2_pytorch import ImageNetV2Dataset
422 | dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val)
423 | else:
424 | if is_train:
425 | data_path = args.imagenet_train
426 | preprocess_fn = preprocess_train
427 | else:
428 | data_path = args.imagenet_val
429 | preprocess_fn = preprocess_val
430 | assert data_path
431 |
432 | dataset = datasets.ImageFolder(data_path, transform=preprocess_fn)
433 |
434 | if is_train:
435 | idxs = np.zeros(len(dataset.targets))
436 | target_array = np.array(dataset.targets)
437 | k = 50
438 | for c in range(1000):
439 | m = target_array == c
440 | n = len(idxs[m])
441 | arr = np.zeros(n)
442 | arr[:k] = 1
443 | np.random.shuffle(arr)
444 | idxs[m] = arr
445 |
446 | idxs = idxs.astype('int')
447 | sampler = SubsetRandomSampler(np.where(idxs)[0])
448 | else:
449 | sampler = None
450 |
451 | dataloader = torch.utils.data.DataLoader(
452 | dataset,
453 | batch_size=args.batch_size,
454 | num_workers=args.workers,
455 | sampler=sampler,
456 | )
457 | return DataInfo(dataloader, sampler)
458 |
459 | def count_samples(dataloader):
460 | os.environ["WDS_EPOCH"] = "0"
461 | n_elements, n_batches = 0, 0
462 | for images, texts in dataloader:
463 | n_batches += 1
464 | n_elements += len(images)
465 | assert len(images) == len(texts)
466 | return n_elements, n_batches
467 |
468 | def get_csv_dataset(args, preprocess_fn, is_train, input_filename=None):
469 | if input_filename is None:
470 | input_filename = args.train_data if is_train else args.val_data
471 | assert input_filename
472 | dataset = CsvDataset(
473 | input_filename,
474 | preprocess_fn,
475 | img_key=args.csv_img_key,
476 | caption_key=args.csv_caption_key,
477 | sep=args.csv_separator)
478 |
479 | num_samples = len(dataset)
480 | sampler = DistributedSampler(dataset) if args.distributed and is_train else None
481 | shuffle = is_train and sampler is None
482 |
483 | dataloader = DataLoader(
484 | dataset,
485 | batch_size=args.batch_size,
486 | shuffle=shuffle,
487 | num_workers=args.workers,
488 | pin_memory=True,
489 | sampler=sampler,
490 | drop_last=is_train,
491 | )
492 | dataloader.num_samples = num_samples
493 | dataloader.num_batches = len(dataloader)
494 |
495 | return DataInfo(dataloader, sampler)
496 |
497 |
498 | #
499 | def get_imgnet_r(args, preprocess_fn, is_train, input_filename=None):
500 | if input_filename is None:
501 | input_filename = args.train_data if is_train else args.val_data
502 | assert input_filename
503 | path_data = os.path.join(args.root_data, 'imgnet/imagenet-r')
504 | dataset = CustomFolder(path_data, transform=preprocess_fn)
505 | num_samples = len(dataset)
506 | sampler = DistributedSampler(dataset) if args.distributed and is_train else None
507 | shuffle = is_train and sampler is None
508 | dataloader = DataLoader(
509 | dataset,
510 | batch_size=args.batch_size,
511 | shuffle=shuffle,
512 | num_workers=args.workers,
513 | pin_memory=True,
514 | sampler=sampler,
515 | drop_last=is_train,
516 | )
517 | dataloader.num_samples = num_samples
518 | dataloader.num_batches = len(dataloader)
519 | return DataInfo(dataloader, sampler)
520 |
521 |
522 | def get_directory_dataset(args, preprocess_fn, is_train, input_filename=None):
523 | if input_filename is None:
524 | input_filename = args.train_data if is_train else args.val_data
525 | assert input_filename
526 | dataset = CustomFolder(
527 | input_filename,
528 | transform=preprocess_fn)
529 | num_samples = len(dataset)
530 | sampler = DistributedSampler(dataset) if args.distributed and is_train else None
531 | shuffle = is_train and sampler is None
532 |
533 | dataloader = DataLoader(
534 | dataset,
535 | batch_size=args.batch_size,
536 | shuffle=shuffle,
537 | num_workers=args.workers,
538 | pin_memory=True,
539 | sampler=sampler,
540 | drop_last=is_train,
541 | )
542 | dataloader.num_samples = num_samples
543 | dataloader.num_batches = len(dataloader)
544 |
545 | return DataInfo(dataloader, sampler)
546 |
547 |
548 | def get_dataset_fn(data_path, dataset_type):
549 | if dataset_type == 'imgnet_r':
550 | return get_imgnet_r
551 | elif dataset_type == 'fashion-iq':
552 | return get_fashion_iq
553 | elif dataset_type == 'cirr':
554 | return get_cirr
555 | elif dataset_type == 'directory':
556 | return get_directory_dataset
557 | elif dataset_type == "csv":
558 | return get_csv_dataset
559 | elif dataset_type == "auto":
560 | ext = data_path.split('.')[-1]
561 | if ext in ['csv', 'tsv']:
562 | return get_csv_dataset
563 | else:
564 | raise ValueError(
565 |                 f"Tried to figure out dataset type, but failed for extension {ext}.")
566 | else:
567 | raise ValueError(f"Unsupported dataset type: {dataset_type}")
568 |
569 |
570 | def get_data(args, preprocess_fns):
571 | preprocess_train, preprocess_val = preprocess_fns
572 | data = {}
573 | dataset_type_val = getattr(args, 'dataset_type_val', args.dataset_type)
574 | if args.train_data:
575 | data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
576 | args, preprocess_train, is_train=True)
577 | if args.val_data:
578 | data["val"] = get_dataset_fn(args.val_data, dataset_type_val)(
579 | args, preprocess_val, is_train=False)
580 | if args.imagenet_val is not None:
581 | data["imagenet-val"] = get_imagenet(args, preprocess_fns, "val")
582 | if args.imagenet_v2 is not None:
583 | data["imagenet-v2"] = get_imagenet(args, preprocess_fns, "v2")
584 | return data
585 |
--------------------------------------------------------------------------------
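A minimal sketch of driving the `get_data` dispatcher in data.py above with a hand-built namespace. The attributes listed are the ones the csv path actually reads; the file name and column names are hypothetical.

```python
from types import SimpleNamespace

import torchvision.transforms as T

from data import get_data

preprocess = T.Compose([T.Resize(224), T.CenterCrop(224), T.ToTensor()])

args = SimpleNamespace(
    train_data='cc3m_train.tsv', val_data=None,   # hypothetical csv of image/caption pairs
    dataset_type='csv',
    csv_img_key='filepath', csv_caption_key='title', csv_separator='\t',
    batch_size=128, workers=8, distributed=False,
    imagenet_val=None, imagenet_v2=None,
)

data = get_data(args, (preprocess, preprocess))
train_loader = data['train'].dataloader   # DataInfo exposes .dataloader and .sampler
```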
/src/demo.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import os
15 | import time
16 | import logging
17 | from time import gmtime, strftime
18 | from pathlib import Path
19 | import json
20 | import torch
21 | import torch.distributed as dist
22 | import torch.multiprocessing as mp
23 | import torch.backends.cudnn as cudnn
24 | from torch.utils.tensorboard import SummaryWriter
25 | from torch.utils.data import DataLoader
26 | from model.clip import _transform, load
27 | from model.model import convert_weights, CLIP, IM2TEXT
28 | from eval_utils import visualize_results
29 | from data import get_data, CsvDataset, CustomFolder, CIRR, FashionIQ, ImageList
30 | from params import parse_args, get_project_root
31 | from logger import setup_primary_logging, setup_worker_logging
32 | from utils import is_master, convert_models_to_fp32, TargetPad
33 |
34 | def main_worker(gpu, ngpus_per_node, log_queue, args):
35 | args.gpu = gpu
36 | args.rank = gpu
37 | setup_worker_logging(args.rank, log_queue, args.log_level)
38 |
39 | # Log and save params.
40 | if is_master(args):
41 | logging.info("Params:")
42 | params_file = os.path.join(args.logs, args.name, "params.txt")
43 | with open(params_file, "w") as f:
44 | for name in sorted(vars(args)):
45 | val = getattr(args, name)
46 | logging.info(f"{name}: {val}")
47 | f.write(f"{name}: {val}\n")
48 |
49 | if args.distributed:
50 | dist.init_process_group(
51 | backend=args.dist_backend,
52 | init_method=args.dist_url,
53 | world_size=args.world_size,
54 | rank=args.rank,
55 | )
56 |
57 | if args.dp:
58 | args.batch_size *= args.world_size
59 |
60 | if args.gpu is not None:
61 | logging.info(f"Use GPU: {args.gpu} for training")
62 | torch.cuda.set_device(args.gpu)
63 |
64 |     # Do not use skip_reset unless you want to use one of the CLIP models
65 | if args.openai_pretrained:
66 | model, preprocess_train, preprocess_val = load(
67 | args.model,
68 | jit=False)
69 | else:
70 | model_config_file = Path(__file__).parent / f"model_configs/{args.model.replace('/', '-')}.json"
71 | print('Loading model from', model_config_file)
72 | assert os.path.exists(model_config_file)
73 | with open(model_config_file, 'r') as f:
74 | model_info = json.load(f)
75 | if args.use_prefix:
76 | model_info['vocab_size'] += 1
77 | model_info['use_prefix'] = True
78 | model = CLIP(**model_info)
79 | convert_weights(model)
80 | preprocess_train = _transform(model.visual.input_resolution, is_train=True)
81 | preprocess_val = _transform(model.visual.input_resolution, is_train=False)
82 | img2text = IM2TEXT(embed_dim=model.embed_dim, output_dim=model.token_embedding.weight.shape[1])
83 |
84 | # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
85 | if args.precision == "amp" or args.precision == "fp32" or args.gpu is None:
86 | convert_models_to_fp32(model)
87 |
88 | if not torch.cuda.is_available():
89 | model.float()
90 | img2text.float()
91 | logging.warning("using CPU, this will be slow")
92 | else:
93 | model.cuda(args.gpu)
94 | img2text.cuda(args.gpu)
95 | if args.precision == "fp16":
96 | convert_weights(model)
97 | convert_weights(img2text)
98 | # Previously batch size and workers were global and not per GPU.
99 | # args.batch_size = args.batch_size / ngpus_per_node)
100 | # args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
101 |
102 | if args.distributed and args.use_bn_sync:
103 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
104 | if args.distributed:
105 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=model.has_extra)
106 | img2text = torch.nn.parallel.DistributedDataParallel(img2text, device_ids=[args.gpu], find_unused_parameters=False)
107 | if args.dp:
108 | model = torch.nn.DataParallel(model, device_ids=args.multigpu)
109 | img2text = torch.nn.DataParallel(img2text, device_ids=args.multigpu)
110 |
111 | if args.precision == "fp16":
112 | convert_weights(model)
113 | convert_weights(img2text)
114 |
115 | data = get_data(args, (preprocess_train, preprocess_val))
116 | if args.resume == 'auto':
117 | checkpoint_list = os.listdir(args.checkpoint_path)
118 | checkpoint_list = [ckpt for ckpt in checkpoint_list if ckpt.startswith('epoch')]
119 | if checkpoint_list:
120 | latest_epoch = max([int(ckpt.split('_')[1].split('.')[0]) for ckpt in checkpoint_list])
121 | args.resume = os.path.join(args.checkpoint_path, f'epoch_{latest_epoch}.pt')
122 | else:
123 | args.resume = None
124 |
125 | if args.resume is not None:
126 | if os.path.isfile(args.resume):
127 | if args.gpu is None:
128 | checkpoint = torch.load(args.resume)
129 | else:
130 | # Map model to be loaded to specified single gpu.
131 | loc = "cuda:{}".format(args.gpu)
132 | checkpoint = torch.load(args.resume, map_location=loc)
133 | sd = checkpoint["state_dict"]
134 | sd_img2text = checkpoint["state_dict_img2text"]
135 | if not args.distributed and next(iter(sd.items()))[0].startswith('module'):
136 | sd = {k[len('module.'):]: v for k, v in sd.items()}
137 | if not args.distributed and next(iter(sd_img2text.items()))[0].startswith('module'):
138 | sd_img2text = {k[len('module.'):]: v for k, v in sd_img2text.items()}
139 | model.load_state_dict(sd)
140 | img2text.load_state_dict(sd_img2text)
141 | logging.info(
142 | f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})"
143 | )
144 | else:
145 | logging.info("=> no checkpoint found at '{}'".format(args.resume))
146 | cudnn.benchmark = True
147 | cudnn.deterministic = False
148 | prompt = args.prompts.split(",")
149 | root_project = os.path.join(get_project_root(), 'data')
150 | logging.info("root dir '{}'".format(root_project))
151 | logging.info("prompt list '{}'".format(prompt))
152 |
153 | if "csv" in args.retrieval_data:
154 | dataset = CsvDataset(
155 | args.retrieval_data,
156 | preprocess_val,
157 | img_key=args.csv_img_key,
158 | caption_key=args.csv_caption_key,
159 | sep=args.csv_separator,
160 | return_filename=True)
161 | elif args.retrieval_data == 'imgnet':
162 | target_path = os.path.join(root_project, "imgnet", "imgnet_targets.txt")
163 | dataset = ImageList(target_path, root=root_project, transforms=preprocess_val,
164 | is_labels=True, return_filename=True)
165 | elif args.retrieval_data == 'cirr':
166 | dataset = CIRR(
167 | transforms=preprocess_val,
168 | root=root_project,
169 | mode='caps',
170 | vis_mode=True,
171 | )
172 | elif args.retrieval_data in ['dress', 'shirt', 'toptee']:
173 | dataset = FashionIQ(cloth=args.retrieval_data,
174 | transforms=preprocess_val,
175 | root=root_project,
176 | mode='caps',
177 | vis_mode=True)
178 | elif args.retrieval_data == 'coco':
179 | dataset = CustomFolder(os.path.join(root_project, "coco/val2017"), transform=preprocess_val)
180 | else:
181 |         raise ValueError('Unknown retrieval data: {}'.format(args.retrieval_data))
182 | dataloader = DataLoader(
183 | dataset,
184 | batch_size=args.batch_size,
185 | shuffle=False,
186 | num_workers=args.workers,
187 | pin_memory=True,
188 | drop_last=False,
189 | )
190 | visualize_results(model, img2text, args, prompt, dataloader, )
191 |
192 |
193 | def main():
194 | args = parse_args()
195 |
196 | # get the name of the experiments
197 | if args.name is None:
198 | args.name = (f"lr={args.lr}_"
199 |             f"wd={args.wd}_"
200 |             f"agg={args.aggregate}_"
201 |             f"model={args.model}_"
202 |             f"batchsize={args.batch_size}_workers={args.workers}")
203 | if args.time_suffix:
204 | args.name += "_date=%Y-%m-%d-%H-%M-%S"
205 | args.name = strftime(args.name, gmtime())
206 |
207 | if args.copy_codebase:
208 | import sys, subprocess
209 | from shutil import copytree, ignore_patterns
210 | new_code_path = os.path.join(args.logs, args.name, "code")
211 | if os.path.exists(new_code_path):
212 | print(
213 | f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment."
214 | )
215 | return -1
216 | print(f"Copying codebase to {new_code_path}")
217 | current_code_path = os.path.realpath(__file__)
218 | for _ in range(3):
219 | current_code_path = os.path.dirname(current_code_path)
220 | copytree(current_code_path, new_code_path, ignore=ignore_patterns('log', 'logs', 'wandb'))
221 | print("Done copying code.")
222 |         os.environ["PYTHONPATH"] = f"{os.environ.get('PYTHONPATH', '')}:{os.path.join(new_code_path, 'src')}"
223 | main_file = os.path.join(new_code_path, "src", "training", "main.py")
224 | argv = sys.argv
225 | argv.remove('--copy-codebase')
226 | argv.extend(['--name', args.name])
227 | command = [sys.executable] + argv
228 | print("Executing command:", " ".join(command))
229 | subprocess.check_call(command)
230 | return 1
231 |
232 | args.log_path = os.path.join(args.logs, args.name, "out.log")
233 | if os.path.exists(args.log_path) and args.resume is None:
234 | print(
235 |             "Error. Experiment already exists. Use --name to specify a new experiment."
236 | )
237 | return -1
238 |
239 | assert args.precision in ['amp', 'fp16', 'fp32']
240 | #assert args.model in ['RN50', 'RN101', 'RN50x4', 'ViT-B/32'] or os.path.exists(args.model)
241 |
242 | args.ngpus_per_node = torch.cuda.device_count()
243 |
244 | args.wandb = 'wandb' in args.report_to or 'all' in args.report_to
245 | args.tensorboard = 'tensorboard' in args.report_to or 'all' in args.report_to
246 |
247 | args.tensorboard_path = os.path.join(args.logs, args.name, "tensorboard") if args.tensorboard else ''
248 | args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints")
249 | for dirname in [args.tensorboard_path, args.checkpoint_path]:
250 | if dirname:
251 | os.makedirs(dirname, exist_ok=True)
252 |
253 |
254 | # Set multiprocessing type to spawn.
255 | # This is important for logging to work with multiprocessing.
256 | torch.multiprocessing.set_start_method("spawn")
257 |
258 | # Set logger
259 | args.log_level = logging.DEBUG if args.debug else logging.INFO
260 | log_queue = setup_primary_logging(args.log_path, args.log_level)
261 |
262 | # Distributed training = training on more than one GPU.
263 | # Also easily possible to extend to multiple nodes & multiple GPUs.
264 | args.distributed = (args.gpu is None) and torch.cuda.is_available() and (not args.dp)
265 | if args.distributed:
266 | ngpus_per_node = torch.cuda.device_count()
267 | args.world_size = ngpus_per_node
268 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, log_queue, args))
269 | else:
270 | if args.dp:
271 | args.gpu = args.multigpu[0]
272 | args.world_size = len(args.multigpu)
273 | else:
274 | args.world_size = 1
275 | main_worker(args.gpu, None, log_queue, args)
276 |
277 |
278 | if __name__ == "__main__":
279 | main()
280 |
--------------------------------------------------------------------------------
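A minimal sketch, outside the training harness, of the loading path demo.py above uses for an OpenAI-pretrained backbone plus a trained IM2TEXT mapper. The model name and the checkpoint path are assumptions; the state-dict keys follow the resume logic above.

```python
import torch

from model.clip import load
from model.model import IM2TEXT

# 'ViT-B/32' and the checkpoint path are assumptions for illustration.
model, _, preprocess_val = load('ViT-B/32', jit=False)
img2text = IM2TEXT(embed_dim=model.embed_dim,
                   output_dim=model.token_embedding.weight.shape[1])

checkpoint = torch.load('logs/experiment/checkpoints/epoch_30.pt', map_location='cpu')
strip = lambda sd: {k[len('module.'):] if k.startswith('module.') else k: v
                    for k, v in sd.items()}
model.load_state_dict(strip(checkpoint['state_dict']))
img2text.load_state_dict(strip(checkpoint['state_dict_img2text']))
model.eval()
img2text.eval()
```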
/src/eval_retrieval.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import time
17 | import logging
18 | from time import gmtime, strftime
19 | from pathlib import Path
20 | import json
21 | from functools import partial
22 | import wandb
23 | import torch
24 | from torch import optim
25 | import torch.distributed as dist
26 | import torch.multiprocessing as mp
27 | import torch.backends.cudnn as cudnn
28 | from torch.utils.tensorboard import SummaryWriter
29 | from torch.cuda.amp import GradScaler
30 | from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
31 | import torchvision.datasets as datasets
32 | import torchvision.transforms as T
33 | from PIL import Image
34 |
35 | from model.clip import _transform, load
36 | from model.model import convert_weights, CLIP, IM2TEXT
37 | from eval_utils import evaluate_imgnet_retrieval, evaluate_coco, evaluate_fashion, evaluate_cirr, evaluate_cirr_test
38 | from data import CsvDataset, CustomFolder, ImageList, CsvCOCO, FashionIQ, CIRR
39 | from params import parse_args, get_project_root
40 | from logger import setup_primary_logging, setup_worker_logging
41 | from utils import is_master, convert_models_to_fp32, TargetPad
42 |
43 | def load_model(args):
44 | model, _, preprocess_val = load(
45 | args.model,
46 | jit=False)
47 | img2text = IM2TEXT(embed_dim=model.embed_dim,
48 | middle_dim=args.middle_dim,
49 | output_dim=model.token_embedding.weight.shape[1],
50 | n_layer=args.n_layer)
51 | # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
52 | if args.precision == "amp" or args.precision == "fp32" or args.gpu is None:
53 | convert_models_to_fp32(model)
54 |
55 | if not torch.cuda.is_available():
56 | model.float()
57 | img2text.float()
58 | logging.warning("using CPU, this will be slow")
59 | else:
60 | model.cuda(args.gpu)
61 | img2text.cuda(args.gpu)
62 | if args.precision == "fp16":
63 | convert_weights(model)
64 | convert_weights(img2text)
65 | # Previously batch size and workers were global and not per GPU.
66 | # args.batch_size = args.batch_size / ngpus_per_node)
67 | # args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
68 | if args.distributed and args.use_bn_sync:
69 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
70 | if args.distributed:
71 | model = torch.nn.parallel.DistributedDataParallel(model,
72 | device_ids=[args.gpu],
73 | find_unused_parameters=model.has_extra)
74 | img2text = torch.nn.parallel.DistributedDataParallel(img2text,
75 | device_ids=[args.gpu], find_unused_parameters=False)
76 | if args.dp:
77 | model = torch.nn.DataParallel(model, device_ids=args.multigpu)
78 | img2text = torch.nn.DataParallel(img2text, device_ids=args.multigpu)
79 |
80 | if args.precision == "fp16":
81 | convert_weights(model)
82 | convert_weights(img2text)
83 | if args.resume == 'auto':
84 | checkpoint_list = os.listdir(args.checkpoint_path)
85 | checkpoint_list = [ckpt for ckpt in checkpoint_list if ckpt.startswith('epoch')]
86 | if checkpoint_list:
87 | latest_epoch = max([int(ckpt.split('_')[1].split('.')[0]) for ckpt in checkpoint_list])
88 | args.resume = os.path.join(args.checkpoint_path, f'epoch_{latest_epoch}.pt')
89 | else:
90 | args.resume = None
91 |
92 | assert args.resume is not None
93 | if os.path.isfile(args.resume):
94 | if args.gpu is None:
95 | checkpoint = torch.load(args.resume)
96 | else:
97 | # Map model to be loaded to specified single gpu.
98 | loc = "cuda:{}".format(args.gpu)
99 | checkpoint = torch.load(args.resume, map_location=loc)
100 | sd = checkpoint["state_dict"]
101 | sd_img2text = checkpoint["state_dict_img2text"]
102 | if not args.distributed and next(iter(sd.items()))[0].startswith('module'):
103 | sd = {k[len('module.'):]: v for k, v in sd.items()}
104 | if not args.distributed and next(iter(sd_img2text.items()))[0].startswith('module'):
105 | sd_img2text = {k[len('module.'):]: v for k, v in sd_img2text.items()}
106 | model.load_state_dict(sd)
107 | img2text.load_state_dict(sd_img2text)
108 | logging.info(
109 | f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})"
110 | )
111 | else:
112 | logging.info("=> no checkpoint found at '{}'".format(args.resume))
113 | return model, img2text, preprocess_val
114 |
115 | def setup_log_save(args):
116 | if is_master(args):
117 | logging.info("Params:")
118 | params_file = os.path.join(args.logs, args.name, "params.txt")
119 | with open(params_file, "w") as f:
120 | for name in sorted(vars(args)):
121 | val = getattr(args, name)
122 | logging.info(f"{name}: {val}")
123 | f.write(f"{name}: {val}\n")
124 |
125 | if args.distributed:
126 | dist.init_process_group(
127 | backend=args.dist_backend,
128 | init_method=args.dist_url,
129 | world_size=args.world_size,
130 | rank=args.rank,
131 | )
132 | if args.dp:
133 | args.batch_size *= args.world_size
134 | if args.gpu is not None:
135 | logging.info(f"Use GPU: {args.gpu} for training")
136 | torch.cuda.set_device(args.gpu)
137 |
138 |
139 | def main_worker(gpu, ngpus_per_node, log_queue, args):
140 | args.gpu = gpu
141 | args.rank = gpu
142 | setup_worker_logging(args.rank, log_queue, args.log_level)
143 | # Log and save params.
144 | setup_log_save(args)
145 | # Load trained model
146 | model, img2text, preprocess_val = load_model(args)
147 | cudnn.benchmark = True
148 | cudnn.deterministic = False
149 | root_project = os.path.join(get_project_root(), 'data')
150 | ## Padding option
151 | if args.target_pad:
152 | trans_tmp = preprocess_val.transforms
153 | trans_tmp = [TargetPad(1.25)] + trans_tmp
154 | preprocess_train = T.Compose(trans_tmp)
155 | preprocess_val = preprocess_train
156 |
157 | ## Load data for each evaluation dataset and perform evaluation.
158 | if args.eval_mode == 'coco':
159 | trans_val = preprocess_val.transforms
160 | n_px = trans_val[1].size
161 | trans_val = [T.Resize(n_px, interpolation=Image.BICUBIC)] + trans_val[2:]
162 | preprocess_val_region = T.Compose(trans_val)
163 | source_dataset = CsvCOCO(transforms=preprocess_val,
164 | transforms_region=preprocess_val_region,
165 | root=root_project)
166 | source_dataloader = DataLoader(
167 | source_dataset,
168 | batch_size=args.batch_size,
169 | shuffle=False,
170 | num_workers=args.workers,
171 | pin_memory=True,
172 | drop_last=False)
173 | evaluate_coco(model, img2text, args, source_dataloader)
174 |
175 | elif args.eval_mode == 'cirr':
176 | source_dataset = CIRR(transforms=preprocess_val,
177 | root=root_project)
178 | target_dataset = CIRR(transforms=preprocess_val,
179 | root=root_project,
180 | mode='imgs')
181 | source_dataloader = DataLoader(
182 | source_dataset,
183 | batch_size=args.batch_size,
184 | shuffle=False,
185 | num_workers=args.workers,
186 | pin_memory=True,
187 | drop_last=False)
188 | target_dataloader = DataLoader(
189 | target_dataset,
190 | batch_size=args.batch_size,
191 | shuffle=False,
192 | num_workers=args.workers,
193 | pin_memory=True,
194 | drop_last=False)
195 | evaluate_cirr(model,
196 | img2text,
197 | args,
198 | source_dataloader,
199 | target_dataloader)
200 |
201 | elif args.eval_mode == 'cirr_test':
202 | source_dataset = CIRR(transforms=preprocess_val,
203 | root=root_project, test=True)
204 | target_dataset = CIRR(transforms=preprocess_val,
205 | root=root_project,
206 | mode='imgs',
207 | test=True)
208 | source_dataloader = DataLoader(
209 | source_dataset,
210 | batch_size=args.batch_size,
211 | shuffle=False,
212 | num_workers=args.workers,
213 | pin_memory=True,
214 | drop_last=False)
215 | target_dataloader = DataLoader(
216 | target_dataset,
217 | batch_size=args.batch_size,
218 | shuffle=False,
219 | num_workers=args.workers,
220 | pin_memory=True,
221 | drop_last=False)
222 | results = evaluate_cirr_test(model,
223 | img2text,
224 | args,
225 | source_dataloader,
226 | target_dataloader)
227 | for key, value in results.items():
228 | with open('res_cirr/' + key + '.json', 'w') as f:
229 | json.dump(value, f)
230 |
231 | elif args.eval_mode == 'fashion':
232 | assert args.source_data in ['dress', 'shirt', 'toptee']
233 | source_dataset = FashionIQ(cloth=args.source_data,
234 | transforms=preprocess_val,
235 | root=root_project,
236 | is_return_target_path=True)
237 | target_dataset = FashionIQ(cloth=args.source_data,
238 | transforms=preprocess_val,
239 | root=root_project,
240 | mode='imgs')
241 | source_dataloader = DataLoader(
242 | source_dataset,
243 | batch_size=args.batch_size,
244 | shuffle=False,
245 | num_workers=args.workers,
246 | pin_memory=True,
247 | drop_last=False)
248 | target_dataloader = DataLoader(
249 | target_dataset,
250 | batch_size=args.batch_size,
251 | shuffle=False,
252 | num_workers=args.workers,
253 | pin_memory=True,
254 | drop_last=False)
255 | evaluate_fashion(model, img2text, args, source_dataloader, target_dataloader)
256 | elif args.eval_mode == 'imgnet':
257 | domains = ['cartoon', 'origami', 'toy', 'sculpture']
258 | prompt = ["a {} of *".format(domain) for domain in domains]
259 | source_path = os.path.join(root_project, "imgnet", "imgnet_real_query.txt")
260 | target_path = os.path.join(root_project, "imgnet", "imgnet_targets.txt")
261 | source_dataset = ImageList(source_path, root=root_project, transforms=preprocess_val, is_labels=True)
262 | target_dataset = ImageList(target_path, root=root_project, transforms=preprocess_val, is_labels=True)
263 | eval_func = evaluate_imgnet_retrieval
264 | source_dataloader = DataLoader(
265 | source_dataset,
266 | batch_size=args.batch_size,
267 | shuffle=False,
268 | num_workers=args.workers,
269 | pin_memory=True,
270 | drop_last=False)
271 | target_dataloader = DataLoader(
272 | target_dataset,
273 | batch_size=args.batch_size,
274 | shuffle=False,
275 | num_workers=args.workers,
276 | pin_memory=True,
277 | drop_last=False)
278 | eval_func(model, img2text, args, prompt, source_dataloader, target_dataloader)
279 |
280 | def main():
281 | args = parse_args()
282 |
283 | # get the name of the experiments
284 | if args.name is None:
285 | args.name = (f"lr={args.lr}_"
286 |             f"wd={args.wd}_"
287 |             f"agg={args.aggregate}_"
288 |             f"model={args.model}_"
289 |             f"batchsize={args.batch_size}_workers={args.workers}")
290 | if args.time_suffix:
291 | args.name += "_date=%Y-%m-%d-%H-%M-%S"
292 | args.name = strftime(args.name, gmtime())
293 |
294 | if args.copy_codebase:
295 | import sys, subprocess
296 | from shutil import copytree, ignore_patterns
297 | new_code_path = os.path.join(args.logs, args.name, "code")
298 | if os.path.exists(new_code_path):
299 | print(
300 | f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment."
301 | )
302 | return -1
303 | print(f"Copying codebase to {new_code_path}")
304 | current_code_path = os.path.realpath(__file__)
305 | for _ in range(3):
306 | current_code_path = os.path.dirname(current_code_path)
307 | copytree(current_code_path, new_code_path, ignore=ignore_patterns('log', 'logs', 'wandb'))
308 | print("Done copying code.")
309 |         os.environ["PYTHONPATH"] = f"{os.environ.get('PYTHONPATH', '')}:{os.path.join(new_code_path, 'src')}"
310 | main_file = os.path.join(new_code_path, "src", "training", "main.py")
311 | argv = sys.argv
312 | argv.remove('--copy-codebase')
313 | argv.extend(['--name', args.name])
314 | command = [sys.executable] + argv
315 | print("Executing command:", " ".join(command))
316 | subprocess.check_call(command)
317 | return 1
318 |
319 | args.log_path = os.path.join(args.logs, args.name, "out.log")
320 | if os.path.exists(args.log_path) and args.resume is None:
321 | print(
322 |         "Error. Experiment already exists. Use --name to specify a new experiment."
323 | )
324 | return -1
325 |
326 | assert args.precision in ['amp', 'fp16', 'fp32']
327 | #assert args.model in ['RN50', 'RN101', 'RN50x4', 'ViT-B/32'] or os.path.exists(args.model)
328 |
329 | args.ngpus_per_node = torch.cuda.device_count()
330 |
331 | args.wandb = 'wandb' in args.report_to or 'all' in args.report_to
332 | args.tensorboard = 'tensorboard' in args.report_to or 'all' in args.report_to
333 |
334 | args.tensorboard_path = os.path.join(args.logs, args.name, "tensorboard") if args.tensorboard else ''
335 | args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints")
336 | for dirname in [args.tensorboard_path, args.checkpoint_path]:
337 | if dirname:
338 | os.makedirs(dirname, exist_ok=True)
339 |
340 |
341 | # Set multiprocessing type to spawn.
342 | # This is important for logging to work with multiprocessing.
343 | torch.multiprocessing.set_start_method("spawn")
344 |
345 | # Set logger
346 | args.log_level = logging.DEBUG if args.debug else logging.INFO
347 | log_queue = setup_primary_logging(args.log_path, args.log_level)
348 | args.world_size = 1
349 | try:
350 | main_worker(args.gpu, None, log_queue, args)
351 |     except Exception:  # let KeyboardInterrupt/SystemExit propagate
352 | print('evaluation done')
353 |
354 |
355 | if __name__ == "__main__":
356 | main()
357 |
--------------------------------------------------------------------------------
/src/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import argparse
16 | import logging
17 | from logging import Filter
18 | from logging.handlers import QueueHandler, QueueListener
19 |
20 | import torch
21 | import torch.distributed as dist
22 | import torch.multiprocessing as mp
23 | from torch.multiprocessing import Queue
24 |
25 |
26 | def setup_primary_logging(log_file, level):
27 | log_queue = Queue(-1)
28 |
29 | file_handler = logging.FileHandler(filename=log_file)
30 | stream_handler = logging.StreamHandler()
31 |
32 | formatter = logging.Formatter(
33 | '%(asctime)s | %(levelname)s | %(message)s',
34 | datefmt='%Y-%m-%d,%H:%M:%S')
35 |
36 | file_handler.setFormatter(formatter)
37 | stream_handler.setFormatter(formatter)
38 |
39 | file_handler.setLevel(level)
40 | stream_handler.setLevel(level)
41 |
42 | listener = QueueListener(log_queue, file_handler, stream_handler)
43 |
44 | listener.start()
45 |
46 | return log_queue
47 |
48 |
49 | class WorkerLogFilter(Filter):
50 | def __init__(self, rank=-1):
51 | super().__init__()
52 | self._rank = rank
53 |
54 | def filter(self, record):
55 | if self._rank != -1:
56 | record.msg = f"Rank {self._rank} | {record.msg}"
57 | return True
58 |
59 |
60 | def setup_worker_logging(rank, log_queue, level):
61 | queue_handler = QueueHandler(log_queue)
62 |
63 | worker_filter = WorkerLogFilter(rank)
64 | queue_handler.addFilter(worker_filter)
65 |
66 | queue_handler.setLevel(level)
67 |
68 | root_logger = logging.getLogger()
69 | root_logger.addHandler(queue_handler)
70 |
71 | root_logger.setLevel(level)
72 |
73 |
74 | def fake_worker(rank: int, world_size: int, log_queue: Queue):
75 | setup_worker_logging(rank, log_queue, logging.DEBUG)
76 | logging.info("Test worker log")
77 | logging.error("Test worker error log")
78 | torch.cuda.set_device(rank)
79 | dist.init_process_group(
80 | backend='nccl',
81 | init_method='tcp://127.0.0.1:6100',
82 | world_size=world_size,
83 | rank=rank,
84 | )
85 |
86 | if __name__ == "__main__":
87 | # Set multiprocessing type to spawn
88 | torch.multiprocessing.set_start_method("spawn")
89 |
90 | parser = argparse.ArgumentParser()
91 | parser.add_argument("-g", "--gpu-list", type=int, help="List of GPU IDs", nargs="+", required=True)
92 |
93 | args = parser.parse_args()
94 |
95 | world_size = len(args.gpu_list)
96 |
97 | # Initialize the primary logging handlers. Use the returned `log_queue`
98 |     # which the worker processes will use to push their messages.
99 | log_queue = setup_primary_logging("/usr/lusers/gamaga/out.log", logging.DEBUG)
100 |
101 | if world_size == 1:
102 |         fake_worker(0, world_size, log_queue)
103 | else:
104 | mp.spawn(fake_worker, args=(world_size, log_queue), nprocs=world_size)
--------------------------------------------------------------------------------
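A minimal single-process sketch of wiring the two logging helpers above together; the log file path is hypothetical.

```python
import logging

from logger import setup_primary_logging, setup_worker_logging

# Start the queue listener in the main process, then attach this process as "rank 0".
log_queue = setup_primary_logging('out.log', logging.INFO)
setup_worker_logging(rank=0, log_queue=log_queue, level=logging.INFO)

logging.info('messages now flow through the queue to both the log file and stderr')
```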
/src/main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import os
15 | import time
16 | import logging
17 | from time import gmtime, strftime
18 | from pathlib import Path
19 | import json
20 | import wandb
21 | import torch
22 | from torch import optim
23 | import torch.distributed as dist
24 | import torch.multiprocessing as mp
25 | import torch.backends.cudnn as cudnn
26 | from torch.utils.tensorboard import SummaryWriter
27 | from torch.cuda.amp import GradScaler
28 | from third_party.open_clip.scheduler import cosine_lr
29 | from model.clip import _transform, load
30 | from model.model import convert_weights, CLIP, IM2TEXT
31 | from trainer import train
32 | from data import get_data
33 | from params import parse_args
34 | from logger import setup_primary_logging, setup_worker_logging
35 | from utils import is_master, convert_models_to_fp32
36 | import torchvision.transforms as T
37 |
38 | def main_worker(gpu, ngpus_per_node, log_queue, args):
39 | args.gpu = gpu
40 | args.rank = gpu
41 | setup_worker_logging(args.rank, log_queue, args.log_level)
42 |
43 | # Log and save params.
44 | if is_master(args):
45 | logging.info("Params:")
46 | params_file = os.path.join(args.logs, args.name, "params.txt")
47 | with open(params_file, "w") as f:
48 | for name in sorted(vars(args)):
49 | val = getattr(args, name)
50 | logging.info(f"{name}: {val}")
51 | f.write(f"{name}: {val}\n")
52 |
53 | if args.distributed:
54 | dist.init_process_group(
55 | backend=args.dist_backend,
56 | init_method=args.dist_url,
57 | world_size=args.world_size,
58 | rank=args.rank,
59 | )
60 |
61 | if args.dp:
62 | args.batch_size *= args.world_size
63 |
64 | if args.gpu is not None:
65 | logging.info(f"Use GPU: {args.gpu} for training")
66 | torch.cuda.set_device(args.gpu)
67 |
68 |     # Do not use skip_reset unless you want to use one of the CLIP models
69 | if args.openai_pretrained:
70 | model, preprocess_train, preprocess_val = load(
71 | args.model,
72 | jit=False)
73 | else:
74 | model_config_file = Path(__file__).parent / f"model_configs/{args.model.replace('/', '-')}.json"
75 | print('Loading model from', model_config_file)
76 | assert os.path.exists(model_config_file)
77 | with open(model_config_file, 'r') as f:
78 | model_info = json.load(f)
79 | if args.use_prefix:
80 | model_info['vocab_size'] += 1
81 | model_info['use_prefix'] = True
82 | model = CLIP(**model_info)
83 | convert_weights(model)
84 | preprocess_train = _transform(model.visual.input_resolution, is_train=True)
85 | preprocess_val = _transform(model.visual.input_resolution, is_train=False)
86 | try:
87 | img2text = IM2TEXT(embed_dim=model.embed_dim,
88 | middle_dim=args.middle_dim,
89 | output_dim=model.token_embedding.weight.shape[1],
90 | n_layer=args.n_layer)
91 |     except Exception:
92 | img2text = IM2TEXT(embed_dim=1024, output_dim=1024,
93 | is_normalize=args.normalize_output, is_mlp=args.use_mlp, n_layer=args.n_layer)
94 |
95 | # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
96 | if args.precision == "amp" or args.precision == "fp32" or args.gpu is None:
97 | convert_models_to_fp32(model)
98 |
99 | if not torch.cuda.is_available():
100 | model.float()
101 | img2text.float()
102 | logging.warning("using CPU, this will be slow")
103 | else:
104 | model.cuda(args.gpu)
105 | img2text.cuda(args.gpu)
106 | if args.precision == "fp16":
107 | convert_weights(model)
108 | convert_weights(img2text)
109 | # Previously batch size and workers were global and not per GPU.
110 | # args.batch_size = args.batch_size / ngpus_per_node)
111 | # args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
112 |
113 | if args.distributed and args.use_bn_sync:
114 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
115 | if args.distributed:
116 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)
117 | img2text = torch.nn.parallel.DistributedDataParallel(img2text, device_ids=[args.gpu], find_unused_parameters=False)
118 | if args.dp:
119 | model = torch.nn.DataParallel(model, device_ids=args.multigpu)
120 | img2text = torch.nn.DataParallel(img2text, device_ids=args.multigpu)
121 |
122 | if args.precision == "fp16":
123 | convert_weights(model)
124 | convert_weights(img2text)
125 |
126 | data = get_data(args, (preprocess_train, preprocess_val))
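    |     # Weight-decay grouping: parameters whose names contain "bn", "ln", "bias",
    |     # or "logit_scale" are trained with weight_decay=0 below, while the remaining
    |     # img2text parameters use args.wd. Only img2text parameters are optimized here.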
127 | exclude = lambda n : "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n
128 | include = lambda n : not exclude(n)
129 | named_parameters = list(img2text.named_parameters())
130 | gain_or_bias_params = [p for n, p in named_parameters if exclude(n) and p.requires_grad]
131 | rest_params = [p for n, p in named_parameters if include(n) and p.requires_grad]
132 |
133 | if args.train_data is None:
134 | optimizer = None
135 | scheduler = None
136 | else:
137 | optimizer = optim.AdamW(
138 | [
139 | {"params": gain_or_bias_params, "weight_decay": 0.},
140 | {"params": rest_params, "weight_decay": args.wd},
141 | ],
142 | lr=args.lr,
143 | betas=(args.beta1, args.beta2),
144 | eps=args.eps,
145 | )
146 | total_steps = data["train"].dataloader.num_batches * args.epochs
147 | scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps)
148 |
149 | scaler = GradScaler() if args.precision == "amp" else None
150 |
151 | # optionally resume from a checkpoint
152 | start_epoch = 0
153 | if args.resume == 'auto':
154 | checkpoint_list = os.listdir(args.checkpoint_path)
155 | checkpoint_list = [ckpt for ckpt in checkpoint_list if ckpt.startswith('epoch')]
156 | if checkpoint_list:
157 | latest_epoch = max([int(ckpt.split('_')[1].split('.')[0]) for ckpt in checkpoint_list])
158 | args.resume = os.path.join(args.checkpoint_path, f'epoch_{latest_epoch}.pt')
159 | else:
160 | args.resume = None
161 |
162 | if args.resume is not None:
163 | if os.path.isfile(args.resume):
164 | if args.gpu is None:
165 | checkpoint = torch.load(args.resume)
166 | else:
167 | # Map model to be loaded to specified single gpu.
168 | loc = "cuda:{}".format(args.gpu)
169 | checkpoint = torch.load(args.resume, map_location=loc)
170 | start_epoch = checkpoint["epoch"]
171 | sd = checkpoint["state_dict"]
172 | sd_img2text = checkpoint["state_dict_img2text"]
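    |             # DDP checkpoints store parameters under a "module." prefix; strip it
    |             # when loading into a non-distributed (unwrapped) model.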
173 | if not args.distributed and next(iter(sd.items()))[0].startswith('module'):
174 | sd = {k[len('module.'):]: v for k, v in sd.items()}
175 | if not args.distributed and next(iter(sd_img2text.items()))[0].startswith('module'):
176 | sd_img2text = {k[len('module.'):]: v for k, v in sd_img2text.items()}
177 | model.load_state_dict(sd)
178 | img2text.load_state_dict(sd_img2text)
179 | if optimizer is not None:
180 | optimizer.load_state_dict(checkpoint["optimizer"])
181 | logging.info(
182 | f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})"
183 | )
184 | else:
185 | logging.info("=> no checkpoint found at '{}'".format(args.resume))
186 |
187 | cudnn.benchmark = True
188 | cudnn.deterministic = False
189 | # determine if this worker should save logs and checkpoints.
190 | # only do so if it is the 0th worker.
191 | args.save_logs = (args.logs is not None and args.logs != '' and args.logs.lower() != 'none') and (
192 | (not args.distributed) or args.gpu == 0
193 | )
194 | writer = None
195 | if args.save_logs and args.tensorboard:
196 | writer = SummaryWriter(args.tensorboard_path)
197 |
198 | if args.wandb and is_master(args):
199 | logging.debug('Starting wandb.')
200 | args.train_sz = data["train"].dataloader.num_samples
201 | if args.val_data is not None:
202 | args.val_sz = data["val"].dataloader.num_samples
203 | # you will have to configure this for your project!
204 | wandb.init(
205 | project="open-clip",
206 | notes=args.wandb_notes,
207 | tags=[],
208 | config=vars(args),
209 | )
210 | if args.debug:
211 | wandb.watch(model, log='all')
212 | wandb.save(params_file)
213 | logging.debug('Finished loading wandb.')
214 |
215 | for epoch in range(start_epoch, args.epochs):
216 | if args.gpu == 0:
217 | logging.info(f'Start epoch {epoch}')
218 | train(model, img2text, data, epoch, optimizer, scaler, scheduler, args, writer)
219 | steps = data["train"].dataloader.num_batches * (epoch + 1)
220 | # Saving checkpoints.
221 | if args.save_logs and (args.gpu == 0 or (not args.distributed)):
222 | if (epoch + 1) == args.epochs or (
223 | args.save_frequency > 0 and ((epoch + 1) % args.save_frequency) == 0
224 | ):
225 | torch.save(
226 | {
227 | "epoch": epoch + 1,
228 | "name": args.name,
229 | "state_dict": model.state_dict(),
230 | "state_dict_img2text": img2text.state_dict(),
231 | "optimizer": optimizer.state_dict(),
232 | },
233 | os.path.join(args.checkpoint_path, f"epoch_{epoch + 1}.pt"),
234 | )
235 | if args.save_most_recent:
236 | torch.save(
237 | {
238 | "epoch": epoch + 1,
239 | "name": args.name,
240 | "state_dict": model.state_dict(),
241 | "state_dict_img2text": img2text.state_dict(),
242 | "optimizer": optimizer.state_dict(),
243 | },
244 | os.path.join(args.checkpoint_path, "epoch_latest.pt"),
245 | )
246 |
247 | if args.wandb and (args.gpu == 0 or (not args.distributed)):
248 | wandb.finish()
249 |
250 |
251 | def main():
252 | args = parse_args()
253 |
254 |     # Get the name of the experiment.
255 | if args.name is None:
256 |         args.name = (f"lr={args.lr}_"
257 |                      f"wd={args.wd}_"
258 |                      f"agg={args.aggregate}_"
259 |                      f"model={args.model}_"
260 |                      f"batchsize={args.batch_size}_"
261 |                      f"workers={args.workers}")
262 |
263 | if args.time_suffix:
264 | args.name += "_date=%Y-%m-%d-%H-%M-%S"
265 | args.name = strftime(args.name, gmtime())
266 |
267 | if args.copy_codebase:
268 | import sys, subprocess
269 | from shutil import copytree, ignore_patterns
270 | new_code_path = os.path.join(args.logs, args.name, "code")
271 | if os.path.exists(new_code_path):
272 | print(
273 | f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment."
274 | )
275 | return -1
276 | print(f"Copying codebase to {new_code_path}")
277 | current_code_path = os.path.realpath(__file__)
278 | for _ in range(3):
279 | current_code_path = os.path.dirname(current_code_path)
280 | copytree(current_code_path, new_code_path, ignore=ignore_patterns('log', 'logs', 'wandb'))
281 | print("Done copying code.")
282 | os.environ["PYTHONPATH"] = f"{os.environ['PYTHONPATH']}:{os.path.join(new_code_path, 'src')}"
283 | main_file = os.path.join(new_code_path, "src", "training", "main.py")
284 | argv = sys.argv
285 | argv.remove('--copy-codebase')
286 | argv.extend(['--name', args.name])
287 | command = [sys.executable] + argv
288 | print("Executing command:", " ".join(command))
289 | subprocess.check_call(command)
290 | return 1
291 |
292 | args.log_path = os.path.join(args.logs, args.name, "out.log")
293 | if os.path.exists(args.log_path) and args.resume is None:
294 | print(
295 |             "Error. Experiment already exists. Use --name to specify a new experiment."
296 | )
297 | return -1
298 |
299 | assert args.precision in ['amp', 'fp16', 'fp32']
300 | #assert args.model in ['RN50', 'RN101', 'RN50x4', 'ViT-B/32'] or os.path.exists(args.model)
301 |
302 | args.ngpus_per_node = torch.cuda.device_count()
303 |
304 | args.wandb = 'wandb' in args.report_to or 'all' in args.report_to
305 | args.tensorboard = 'tensorboard' in args.report_to or 'all' in args.report_to
306 |
307 | args.tensorboard_path = os.path.join(args.logs, args.name, "tensorboard") if args.tensorboard else ''
308 | args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints")
309 | for dirname in [args.tensorboard_path, args.checkpoint_path]:
310 | if dirname:
311 | os.makedirs(dirname, exist_ok=True)
312 |
313 |
314 | # Set multiprocessing type to spawn.
315 | # This is important for logging to work with multiprocessing.
316 | torch.multiprocessing.set_start_method("spawn")
317 |
318 | # Set logger
319 | args.log_level = logging.DEBUG if args.debug else logging.INFO
320 | log_queue = setup_primary_logging(args.log_path, args.log_level)
321 |
322 | # Distributed training = training on more than one GPU.
323 | # Also easily possible to extend to multiple nodes & multiple GPUs.
324 | args.distributed = (args.gpu is None) and torch.cuda.is_available() and (not args.dp)
325 | if args.distributed:
326 | ngpus_per_node = torch.cuda.device_count()
327 | args.world_size = ngpus_per_node
328 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, log_queue, args))
329 | else:
330 | if args.dp:
331 | args.gpu = args.multigpu[0]
332 | args.world_size = len(args.multigpu)
333 | else:
334 | args.world_size = 1
335 | main_worker(args.gpu, None, log_queue, args)
336 |
337 |
338 | if __name__ == "__main__":
339 | main()
340 |
--------------------------------------------------------------------------------
/src/params.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import argparse
15 | from pathlib import Path
16 |
17 | def get_project_root():
18 | return Path(__file__).parent.parent
19 |
20 | def get_default_params(model_name):
21 | # Params from paper (https://arxiv.org/pdf/2103.00020.pdf)
22 | if model_name in ["RN50", "RN101", "RN50x4", "RN50x64", "RN50x16", "RN50_flat", "RN50_t1", "RN50_t2", "RN50_t3", "RN50_t4", "RN50_t5", "RN50_t6",
23 | "RN50_flat_ft", "RN50_t1_pos_ft", "RN50_t2_pos_ft", "RN50_t1_pos", "RN50_t2_pos",
24 | "RN50_flat_large", "RN50_t1_large", "RN50_t2_large",
25 | "RN50_a2", "RN50_a2s", "ViT-H-14"]:
26 | return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8}
27 | elif model_name in ["ViT-B/32", "ViT-L/14", "ViT-B/16"]:
28 | return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6}
29 | else:
30 | return {}
31 |
32 |
33 | def parse_args():
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument("--no-time-suffix",
36 | default=True,
37 | action="store_false",
38 | help="Whether to append current time in the suffix.",
39 | dest="time_suffix")
40 | parser.add_argument(
41 | "--train-data",
42 | type=str,
43 | default=None,
44 |         help="Path to csv file with training data",
45 | )
46 | parser.add_argument(
47 | "--val-data",
48 | type=str,
49 | default=None,
50 | help="Path to csv file with validation data",
51 | )
52 | parser.add_argument(
53 | "--prompts",
54 | type=str,
55 | default=None,
56 |         help="Comma-separated list of prompts",
57 | )
58 | parser.add_argument(
59 | "--retrieval-data",
60 | type=str,
61 | default=None,
62 | help="Path to csv file or folder of retrieval data",
63 | )
64 | parser.add_argument(
65 | "--demo-out",
66 | type=str,
67 | default="demo",
68 | help="Path to the output directory for visualization",
69 | )
70 | parser.add_argument(
71 | "--source-data",
72 | type=str,
73 | default=None,
74 | help="Path to txt file of retrieval data",
75 | )
76 | parser.add_argument(
77 | "--target-data",
78 | type=str,
79 | default=None,
80 |         help="Path to txt file of retrieval target data",
81 | )
82 | parser.add_argument(
83 | "--target-pad",
84 | action="store_true",
85 | default=False,
86 | help="Padding augmentation proposed by combiner.",
87 | )
88 | parser.add_argument(
89 | "--query_file",
90 | type=str,
91 | default=None,
92 | help="Path to query image file for retrieval visualization",
93 | )
94 | parser.add_argument("--eval-mode",
95 | type=str,
96 | choices=["coco", "cirr", "cirr_test", "fashion", "imgnet"],
97 | default="coco",
98 |                     help="Evaluation dataset / protocol to use.")
99 | parser.add_argument("--middle_dim",
100 | default=512,
101 | type=int,
102 | help="Number of hidden units in mapping network.")
103 | parser.add_argument("--droprate",
104 | default=0.1,
105 | type=float,
106 | help="Dropout rate.")
107 | parser.add_argument(
108 | "--n-layer", type=int, default=2, help="Number of layers in im2text"
109 | )
110 | parser.add_argument(
111 | "--dataset-type",
112 | choices=["webdataset", "csv", "inet", "auto", "inet,csv", "csv,inet", "directory", "fashion-iq", "cirr", "imgnet_r"],
113 | default="auto",
114 | help="Which type of dataset to process."
115 | )
116 | parser.add_argument(
117 | "--dataset-type-val",
118 | choices=["webdataset", "csv", "inet", "auto"],
119 | default="auto",
120 | help="Which type of dataset to process."
121 | )
122 | parser.add_argument(
123 | "--csv-separator",
124 | type=str,
125 | default="\t",
126 | help="For csv-like datasets, which separator to use."
127 | )
128 | parser.add_argument(
129 | "--csv-img-key",
130 | type=str,
131 | default="filepath",
132 | help="For csv-like datasets, the name of the key for the image paths."
133 | )
134 | parser.add_argument(
135 | "--csv-caption-key",
136 | type=str,
137 | default="title",
138 | help="For csv-like datasets, the name of the key for the captions."
139 | )
140 | parser.add_argument(
141 | "--imagenet-val",
142 | type=str,
143 | default=None,
144 | help="Path to imagenet val set for conducting zero shot evaluation.",
145 | )
146 | parser.add_argument(
147 | "--imagenet-v2",
148 | type=str,
149 | default=None,
150 | help="Path to imagenet v2 for conducting zero shot evaluation.",
151 | )
152 | parser.add_argument(
153 | "--logs",
154 | type=str,
155 | default="./logs/",
156 | help="Where to store tensorboard logs. Use None to avoid storing logs.",
157 | )
158 | parser.add_argument(
159 | "--name",
160 | type=str,
161 | default=None,
162 | help="Optional identifier for the experiment when storing logs. Otherwise use current time.",
163 | )
164 | parser.add_argument(
165 | "--workers", type=int, default=1, help="Number of workers per GPU."
166 | )
167 | parser.add_argument(
168 | "--batch-size", type=int, default=64, help="Batch size per GPU."
169 | )
170 | parser.add_argument(
171 | "--epochs", type=int, default=32, help="Number of epochs to train for."
172 | )
173 | parser.add_argument("--lr", type=float, default=None, help="Learning rate.")
174 | parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.")
175 | parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.")
176 | parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.")
177 | parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.")
178 | parser.add_argument(
179 | "--warmup", type=int, default=10000, help="Number of steps to warmup for."
180 | )
181 | parser.add_argument("--use-bn-sync",
182 | default=False,
183 | action="store_true",
184 | help="Whether to use batch norm sync.")
185 | parser.add_argument("--use-debiased-sampler",
186 | default=False,
187 | action="store_true",
188 |                     help="Whether to use the debiased data sampler.")
189 | parser.add_argument("--use-prefix",
190 | default=False,
191 | action="store_true",
192 |                     help="Whether to use prefix conditioning when training with an image classification dataset.")
193 | parser.add_argument(
194 | "--gpu",
195 | type=int,
196 | default=None,
197 |         help="Specify a single GPU to run the code on for debugging. "
198 | "Leave at None to use all available GPUs.",
199 | )
200 | parser.add_argument(
201 | "--skip-scheduler",
202 | action="store_true",
203 | default=False,
204 | help="Use this flag to skip the learning rate decay.",
205 | )
206 | parser.add_argument(
207 | "--save-frequency", type=int, default=1, help="How often to save checkpoints."
208 | )
209 | parser.add_argument(
210 | "--save-most-recent",
211 | action="store_true",
212 | default=False,
213 | help="Always save the most recent model trained to epoch_latest.pt.",
214 | )
215 | parser.add_argument(
216 | "--zeroshot-frequency", type=int, default=2, help="How often to run zero shot."
217 | )
218 | parser.add_argument(
219 |         "--regression-frequency", type=int, default=2, help="How often to run the regression evaluation."
220 | )
221 | parser.add_argument(
222 | "--resume",
223 | default=None,
224 | type=str,
225 | help="path to latest checkpoint (default: none)",
226 | )
227 | parser.add_argument(
228 | "--precision",
229 | choices=["amp", "fp16", "fp32"],
230 | default="amp",
231 |         help="Floating point precision."
232 | )
233 | parser.add_argument(
234 | "--model",
235 | choices=["RN50", "RN101", "RN50x4", "RN50x64", "RN50x16", "ViT-B/16", "ViT-B/32", "ViT-L/14", "ViT-H-14",
236 | "RN50_flat", "RN50_t1", "RN50_t2", "RN50_t3", "RN50_t4", "RN50_t5", "RN50_t6",
237 | "RN50_flat_ft", "RN50_t1_pos_ft", "RN50_t2_pos_ft", "RN50_t1_pos", "RN50_t2_pos",
238 | "RN50_flat_large", "RN50_t1_large", "RN50_t2_large",
239 | "RN50_a2", "RN50_a2s"],
240 | default="RN50",
241 | help="Name of the vision backbone to use.",
242 | )
243 | parser.add_argument(
244 | "--openai-pretrained",
245 | default=False,
246 | action='store_true',
247 | help="Use the openai pretrained models.",
248 | )
249 | # arguments for distributed training
250 | parser.add_argument(
251 | "--dist-url",
252 | default="tcp://127.0.0.1:6100",
253 | type=str,
254 | help="url used to set up distributed training",
255 | )
256 | parser.add_argument(
257 | "--dist-backend", default="nccl", type=str, help="distributed backend"
258 | )
259 | parser.add_argument(
260 | "--skip-aggregate",
261 | default=False,
262 | action="store_true",
263 | help="whether to aggregate features across gpus before computing the loss"
264 | )
265 | parser.add_argument(
266 | "--report-to",
267 | default='',
268 | type=str,
269 | help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']"
270 | )
271 | parser.add_argument(
272 | "--wandb-notes",
273 | default='',
274 | type=str,
275 | help="Notes if logging with wandb"
276 | )
277 | parser.add_argument(
278 | "--C", type=float, default=3.16, help="inverse regularizer for logistic reg."
279 | )
280 | parser.add_argument(
281 | "--debug",
282 | default=False,
283 | action="store_true",
284 | help="If true, more information is logged."
285 | )
286 | parser.add_argument(
287 | "--copy-codebase",
288 | default=False,
289 | action="store_true",
290 |         help="If true, we copy the entire codebase to the log directory and execute from there."
291 | )
292 | parser.add_argument(
293 | "--dp",
294 | default=False,
295 | action="store_true",
296 | help="Use DP instead of DDP."
297 | )
298 | parser.add_argument(
299 | "--multigpu",
300 | default=None,
301 | type=lambda x: [int(a) for a in x.split(",")],
302 | help="In DP, which GPUs to use for multigpu training",
303 | )
304 | args = parser.parse_args()
305 | args.aggregate = not args.skip_aggregate
306 |
307 | # If some params are not passed, we use the default values based on model name.
308 | default_params = get_default_params(args.model)
309 | for name, val in default_params.items():
310 | if getattr(args, name) is None:
311 | setattr(args, name, val)
312 |
313 | return args
314 |
--------------------------------------------------------------------------------
/src/trainer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import time
17 | import json
18 | import numpy as np
19 | import torch
20 | import torch.nn as nn
21 | from PIL import Image
22 |
23 | from torch.cuda.amp import autocast
24 | import torch.distributed as dist
25 | from tqdm import tqdm
26 | from torchvision.utils import save_image
27 | import sys
28 | import pdb
29 | import wandb
30 | import logging
31 | import torch.nn.functional as F
32 | from third_party.open_clip.clip import tokenize, _transform
33 | from third_party.open_clip.simple_tokenizer import SimpleTokenizer
34 | from utils import is_master
35 |
36 |
37 | def get_loss(model, images, texts, loss_img, loss_txt, args, data_identifier=-1):
38 | if data_identifier == 1:
39 | # ImageNet dataset
40 | image_features, text_features, logit_scale = model(images, texts, extra=True)
41 | else:
42 | image_features, text_features, logit_scale = model(images, texts)
43 | logit_scale = logit_scale.mean()
44 | if args.distributed and args.aggregate:
45 | world_size = dist.get_world_size()
46 | rank = dist.get_rank()
47 |
48 | # We gather tensors from all gpus to get more negatives to contrast with.
49 | gathered_image_features = [
50 | torch.zeros_like(image_features) for _ in range(world_size)
51 | ]
52 | gathered_text_features = [
53 | torch.zeros_like(text_features) for _ in range(world_size)
54 | ]
55 | dist.all_gather(gathered_image_features, image_features)
56 | dist.all_gather(gathered_text_features, text_features)
57 |
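    |         # Note: tensors returned by all_gather do not carry gradients, so the local
    |         # shard (image_features / text_features) is placed first, followed by the
    |         # gathered copies from the other ranks (this rank's gathered copy is skipped).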
58 | all_image_features = torch.cat(
59 | [image_features]
60 | + gathered_image_features[:rank]
61 | + gathered_image_features[rank + 1 :]
62 | )
63 | all_text_features = torch.cat(
64 | [text_features]
65 | + gathered_text_features[:rank]
66 | + gathered_text_features[rank + 1 :]
67 | )
68 |
69 | ground_truth = torch.arange(len(all_image_features)).long()
70 | if args.gpu is not None:
71 | ground_truth = ground_truth.cuda(args.gpu, non_blocking=True)
72 |
73 | # this is needed to send gradients back everywhere.
74 | # Image loss.
75 | logits_per_image = logit_scale * all_image_features @ all_text_features.t()
76 | loss_img_val = loss_img(logits_per_image, ground_truth)
77 | logits_per_text = logits_per_image.t()
78 | loss_txt_val = loss_txt(logits_per_text, ground_truth)
79 | else:
80 | ground_truth = torch.arange(len(image_features)).long()
81 | if args.gpu is not None:
82 | ground_truth = ground_truth.cuda(args.gpu, non_blocking=True)
83 |
84 | # Image loss.
85 | logits_per_image = logit_scale * image_features @ text_features.t()
86 | loss_img_val = loss_img(logits_per_image, ground_truth)
87 | logits_per_text = logit_scale * text_features @ image_features.t()
88 | loss_txt_val = loss_txt(logits_per_text, ground_truth)
89 |
90 | total_loss = (loss_img_val + loss_txt_val) / 2
91 | return total_loss
92 |
93 |
94 | def get_text_features(model, token_features, args):
95 | text = tokenize("a photo of")
96 | text = text.cuda(args.gpu, non_blocking=True)
97 | text = text.view(1, -1)
98 | text = text.repeat(token_features.size(0), 1)
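    |     # The prompt "a photo of" is tokenized once and tiled over the batch; the mapped
    |     # image token (token_features) is then passed to encode_text_img together with
    |     # the tiled prompt to produce the text-side features.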
99 | text_features = model.encode_text_img(text, token_features)
100 | return text_features
101 |
102 | def get_loss_img2text(model, img2text, images, loss_img, loss_txt, args, memory=None):
103 | with torch.no_grad():
104 | image_features = model.encode_image(images)
105 | token_features = img2text(image_features)
106 | text_features = get_text_features(model, token_features, args)
107 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
108 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
109 | logit_scale = model.logit_scale.exp()
110 | logit_scale = logit_scale.mean()
111 | if args.distributed and args.aggregate:
112 | world_size = dist.get_world_size()
113 | rank = dist.get_rank()
114 |
115 | # We gather tensors from all gpus to get more negatives to contrast with.
116 | gathered_image_features = [
117 | torch.zeros_like(image_features) for _ in range(world_size)
118 | ]
119 | gathered_text_features = [
120 | torch.zeros_like(text_features) for _ in range(world_size)
121 | ]
122 | dist.all_gather(gathered_image_features, image_features)
123 | dist.all_gather(gathered_text_features, text_features)
124 |
125 | all_image_features = torch.cat(
126 | [image_features]
127 | + gathered_image_features[:rank]
128 | + gathered_image_features[rank + 1 :]
129 | )
130 | all_text_features = torch.cat(
131 | [text_features]
132 | + gathered_text_features[:rank]
133 | + gathered_text_features[rank + 1 :]
134 | )
135 |
136 | ground_truth = torch.arange(len(all_image_features)).long()
137 | if args.gpu is not None:
138 | ground_truth = ground_truth.cuda(args.gpu, non_blocking=True)
139 |
140 | # this is needed to send gradients back everywhere.
141 | # Image loss.
142 | logits_per_image = logit_scale * all_image_features @ all_text_features.t()
143 | loss_img_val = loss_img(logits_per_image, ground_truth)
144 | logits_per_text = logits_per_image.t()
145 | loss_txt_val = loss_txt(logits_per_text, ground_truth)
146 | else:
147 | ground_truth = torch.arange(len(image_features)).long()
148 | if args.gpu is not None:
149 | ground_truth = ground_truth.cuda(args.gpu, non_blocking=True)
150 | # Image loss.
151 | logits_per_image = logit_scale * image_features @ text_features.t()
152 | loss_img_val = loss_img(logits_per_image, ground_truth)
153 | logits_per_text = logit_scale * text_features @ image_features.t()
154 | loss_txt_val = loss_txt(logits_per_text, ground_truth)
155 | total_loss = (loss_img_val + loss_txt_val) / 2
156 | return total_loss
157 |
158 |
159 | def train(model, img2text, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None):
160 | os.environ["WDS_EPOCH"] = str(epoch)
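    |     # The CLIP backbone is kept frozen (eval mode, and encode_image runs under
    |     # torch.no_grad() in get_loss_img2text); only img2text receives gradient
    |     # updates from the optimizer built in main.py.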
161 | model.eval()
162 | dataloader, sampler = data['train'].dataloader, data['train'].sampler
163 | loss_img = nn.CrossEntropyLoss()
164 | loss_txt = nn.CrossEntropyLoss()
165 | if args.gpu is not None:
166 | loss_img = loss_img.cuda(args.gpu)
167 | loss_txt = loss_txt.cuda(args.gpu)
168 |
169 | if args.distributed and sampler is not None:
170 | sampler.set_epoch(epoch)
171 |
172 | num_batches_per_epoch = dataloader.num_batches
173 |
174 | end = time.time()
175 | for i, batch in enumerate(dataloader):
176 | step = num_batches_per_epoch * epoch + i
177 | scheduler(step)
178 |
179 | optimizer.zero_grad()
180 |
181 | images, texts = batch[0], batch[1]
182 | if len(batch) == 3 and args.use_debiased_sampler:
183 | data_identifier = torch.unique(batch[2])[0].numpy()
184 | else:
185 | data_identifier = -1
186 | if args.gpu is not None:
187 | images = images.cuda(args.gpu, non_blocking=True)
188 |
189 | data_time = time.time() - end
190 |
191 | m = model.module if args.distributed or args.dp else model
192 |
193 | # with automatic mixed precision.
194 | if args.precision == "amp":
195 | with autocast():
196 | total_loss = get_loss_img2text(m, img2text, images, loss_img, loss_txt, args, data_identifier)
197 | scaler.scale(total_loss).backward()
198 | scaler.step(optimizer)
199 | scaler.update()
200 |
201 | else:
202 | total_loss = get_loss_img2text(m, img2text, images, loss_img, loss_txt, args, data_identifier)
203 | total_loss.backward()
204 | optimizer.step()
205 |
206 | # Note: we clamp to 4.6052 = ln(100), as in the original paper.
207 | #m.logit_scale.data = torch.clamp(m.logit_scale.data, 0, 4.6052)
208 |
209 | batch_time = time.time() - end
210 | end = time.time()
211 |
212 | if is_master(args) and (i % 100) == 0:
213 | num_samples = i * len(images) * args.world_size
214 | samples_per_epoch = dataloader.num_samples
215 | percent_complete = 100.0 * i / num_batches_per_epoch
216 | logging.info(
217 | f"Train Epoch: {epoch} [{num_samples}/{samples_per_epoch} ({percent_complete:.0f}%)]\t"
218 | f"Loss: {total_loss.item():.6f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}"
219 | f"\tLR: {optimizer.param_groups[0]['lr']:5f}\tlogit_scale {m.logit_scale.data:.3f}"
220 | )
221 | # save train loss / etc.
222 |
223 | timestep = epoch * num_batches_per_epoch + i
224 | log_data = {
225 | "loss": total_loss.item(),
226 | "data_time": data_time,
227 | "batch_time": batch_time,
228 | "scale": m.logit_scale.data.item(),
229 | "lr": optimizer.param_groups[0]["lr"]
230 | }
231 |
232 | for name, val in log_data.items():
233 | name = "train/" + name
234 | if tb_writer is not None:
235 | tb_writer.add_scalar(name, val, timestep)
236 | if args.wandb:
237 | wandb.log({name: val, 'step': timestep})
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import math
16 | import numpy as np
17 | import torch
18 | import torchvision.transforms.functional as F
19 |
20 | class TargetPad:
21 | """
22 |     If the image's aspect ratio is above the target ratio, pad it so that
23 |     the padded image matches the target ratio.
24 | """
25 |
26 | def __init__(self, target_ratio=1.25):
27 | """
28 | :param target_ratio: target ratio
29 | :param size: preprocessing output dimension
30 | """
31 | self.target_ratio = target_ratio
32 |
33 | def __call__(self, image):
34 | w, h = image.size
35 | actual_ratio = max(w, h) / min(w, h)
36 |         if actual_ratio < self.target_ratio:  # aspect ratio is below the target: no padding needed
37 | return image
38 | scaled_max_wh = max(w, h) / self.target_ratio # rescale the pad to match the target ratio
39 | hp = max(int((scaled_max_wh - w) / 2), 0)
40 | vp = max(int((scaled_max_wh - h) / 2), 0)
41 | padding = [hp, vp, hp, vp]
42 | return F.pad(image, padding, 0, 'constant')
43 |
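    | # Usage sketch (illustrative only, not part of the pipeline wiring shown in this
    | # file): TargetPad operates on PIL images, so it can be prepended to a standard
    | # torchvision pipeline, e.g.
    | #   preprocess = Compose([TargetPad(1.25), Resize(224), CenterCrop(224), ToTensor()])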
44 | def convert_models_to_fp32(model):
45 | for p in model.parameters():
46 | p.data = p.data.float()
47 |         if p.grad is not None:
48 | p.grad.data = p.grad.data.float()
49 |
50 | def is_master(args):
51 | return (not args.distributed) or args.gpu == 0 or args.dp
--------------------------------------------------------------------------------
/third_party/open_clip/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2012-2021 Gabriel Ilharco, Mitchell Wortsman,
2 | Nicholas Carlini, Rohan Taori, Achal Dave, Vaishaal Shankar,
3 | John Miller, Hongseok Namkoong, Hannaneh Hajishirzi, Ali Farhadi,
4 | Ludwig Schmidt
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining
7 | a copy of this software and associated documentation files (the
8 | "Software"), to deal in the Software without restriction, including
9 | without limitation the rights to use, copy, modify, merge, publish,
10 | distribute, sublicense, and/or sell copies of the Software, and to
11 | permit persons to whom the Software is furnished to do so, subject to
12 | the following conditions:
13 |
14 | The above copyright notice and this permission notice shall be
15 | included in all copies or substantial portions of the Software.
16 |
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
19 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
20 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
21 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
22 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
23 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
24 |
--------------------------------------------------------------------------------
/third_party/open_clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/composed_image_retrieval/8c053297c2fae9cd17ddcded48445a4f47208dbd/third_party/open_clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/third_party/open_clip/clip.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 | from typing import Any, Union, List
6 | from pkg_resources import packaging
7 |
8 | import torch
9 | from PIL import Image
10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11 | from tqdm import tqdm
12 |
13 | from .model import build_model
14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer
15 |
16 | try:
17 | from torchvision.transforms import InterpolationMode
18 | BICUBIC = InterpolationMode.BICUBIC
19 | except ImportError:
20 | BICUBIC = Image.BICUBIC
21 |
22 |
23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended")
25 |
26 |
27 | __all__ = ["available_models", "load", "tokenize"]
28 | _tokenizer = _Tokenizer()
29 |
30 | _MODELS = {
31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
39 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
40 | }
41 |
42 |
43 | def _download(url: str, root: str):
44 | os.makedirs(root, exist_ok=True)
45 | filename = os.path.basename(url)
46 |
47 | expected_sha256 = url.split("/")[-2]
48 | download_target = os.path.join(root, filename)
49 |
50 | if os.path.exists(download_target) and not os.path.isfile(download_target):
51 | raise RuntimeError(f"{download_target} exists and is not a regular file")
52 |
53 | if os.path.isfile(download_target):
54 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
55 | return download_target
56 | else:
57 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
58 |
59 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
60 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
61 | while True:
62 | buffer = source.read(8192)
63 | if not buffer:
64 | break
65 |
66 | output.write(buffer)
67 | loop.update(len(buffer))
68 |
69 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
70 |         raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
71 |
72 | return download_target
73 |
74 |
75 | def _convert_image_to_rgb(image):
76 | return image.convert("RGB")
77 |
78 |
79 | def _transform(n_px):
80 | return Compose([
81 | Resize(n_px, interpolation=BICUBIC),
82 | CenterCrop(n_px),
83 | _convert_image_to_rgb,
84 | ToTensor(),
85 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
86 | ])
87 |
88 |
89 | def available_models() -> List[str]:
90 | """Returns the names of available CLIP models"""
91 | return list(_MODELS.keys())
92 |
93 |
94 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
95 | """Load a CLIP model
96 | Parameters
97 | ----------
98 | name : str
99 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
100 | device : Union[str, torch.device]
101 | The device to put the loaded model
102 | jit : bool
103 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
104 | download_root: str
105 | path to download the model files; by default, it uses "~/.cache/clip"
106 | Returns
107 | -------
108 | model : torch.nn.Module
109 | The CLIP model
110 | preprocess : Callable[[PIL.Image], torch.Tensor]
111 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
112 | """
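    |     # Example (for illustration): model, preprocess = load("ViT-B/32", jit=False)
    |     # returns the model together with a matching preprocessing transform for PIL images.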
113 | if name in _MODELS:
114 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
115 | elif os.path.isfile(name):
116 | model_path = name
117 | else:
118 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
119 |
120 | with open(model_path, 'rb') as opened_file:
121 | try:
122 | # loading JIT archive
123 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
124 | state_dict = None
125 | except RuntimeError:
126 | # loading saved state dict
127 | if jit:
128 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
129 | jit = False
130 | state_dict = torch.load(opened_file, map_location="cpu")
131 |
132 | if not jit:
133 | model = build_model(state_dict or model.state_dict()).to(device)
134 | if str(device) == "cpu":
135 | model.float()
136 | return model, _transform(model.visual.input_resolution)
137 |
138 | # patch the device names
139 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
140 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
141 |
142 | def patch_device(module):
143 | try:
144 | graphs = [module.graph] if hasattr(module, "graph") else []
145 | except RuntimeError:
146 | graphs = []
147 |
148 | if hasattr(module, "forward1"):
149 | graphs.append(module.forward1.graph)
150 |
151 | for graph in graphs:
152 | for node in graph.findAllNodes("prim::Constant"):
153 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
154 | node.copyAttributes(device_node)
155 |
156 | model.apply(patch_device)
157 | patch_device(model.encode_image)
158 | patch_device(model.encode_text)
159 |
160 | # patch dtype to float32 on CPU
161 | if str(device) == "cpu":
162 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
163 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
164 | float_node = float_input.node()
165 |
166 | def patch_float(module):
167 | try:
168 | graphs = [module.graph] if hasattr(module, "graph") else []
169 | except RuntimeError:
170 | graphs = []
171 |
172 | if hasattr(module, "forward1"):
173 | graphs.append(module.forward1.graph)
174 |
175 | for graph in graphs:
176 | for node in graph.findAllNodes("aten::to"):
177 | inputs = list(node.inputs())
178 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
179 | if inputs[i].node()["value"] == 5:
180 | inputs[i].node().copyAttributes(float_node)
181 |
182 | model.apply(patch_float)
183 | patch_float(model.encode_image)
184 | patch_float(model.encode_text)
185 |
186 | model.float()
187 |
188 | return model, _transform(model.input_resolution.item())
189 |
190 |
191 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
192 | """
193 | Returns the tokenized representation of given input string(s)
194 | Parameters
195 | ----------
196 | texts : Union[str, List[str]]
197 | An input string or a list of input strings to tokenize
198 | context_length : int
199 | The context length to use; all CLIP models use 77 as the context length
200 | truncate: bool
201 | Whether to truncate the text in case its encoding is longer than the context length
202 | Returns
203 | -------
204 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
205 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
206 | """
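    |     # Example (for illustration): tokenize(["a diagram", "a dog"]) returns a tensor
    |     # of shape [2, context_length] with <|startoftext|>/<|endoftext|> markers and
    |     # zero padding after the end-of-text token.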
207 | if isinstance(texts, str):
208 | texts = [texts]
209 |
210 | sot_token = _tokenizer.encoder["<|startoftext|>"]
211 | eot_token = _tokenizer.encoder["<|endoftext|>"]
212 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
213 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
214 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
215 | else:
216 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
217 |
218 | for i, tokens in enumerate(all_tokens):
219 | if len(tokens) > context_length:
220 | if truncate:
221 | tokens = tokens[:context_length]
222 | tokens[-1] = eot_token
223 | else:
224 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
225 | result[i, :len(tokens)] = torch.tensor(tokens)
226 |
227 | return result
--------------------------------------------------------------------------------
/third_party/open_clip/environment.yml:
--------------------------------------------------------------------------------
1 | name: open_clip
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=main
8 | - absl-py=0.12.0=py36h06a4308_0
9 | - aiohttp=3.6.3=py36h7b6447c_0
10 | - async-timeout=3.0.1=py36h06a4308_0
11 | - attrs=20.3.0=pyhd3eb1b0_0
12 | - blas=1.0=mkl
13 | - blinker=1.4=py36h06a4308_0
14 | - brotlipy=0.7.0=py36h27cfd23_1003
15 | - c-ares=1.17.1=h27cfd23_0
16 | - ca-certificates=2020.12.5=ha878542_0
17 | - cachetools=4.2.1=pyhd3eb1b0_0
18 | - certifi=2020.12.5=py36h5fab9bb_1
19 | - cffi=1.14.5=py36h261ae71_0
20 | - chardet=3.0.4=py36h06a4308_1003
21 | - click=7.1.2=pyhd3eb1b0_0
22 | - coverage=5.5=py36h27cfd23_2
23 | - cryptography=3.4.7=py36hd23ed53_0
24 | - cudatoolkit=11.0.221=h6bb024c_0
25 | - cython=0.29.23=py36h2531618_0
26 | - dataclasses=0.8=pyh4f3eec9_6
27 | - faiss-gpu=1.4.0=py36_cuda8.0.61_1
28 | - freetype=2.10.4=h5ab3b9f_0
29 | - ftfy=5.8=py_0
30 | - google-auth=1.29.0=pyhd3eb1b0_0
31 | - google-auth-oauthlib=0.4.4=pyhd3eb1b0_0
32 | - grpcio=1.36.1=py36h2157cd5_1
33 | - idna=2.10=pyhd3eb1b0_0
34 | - idna_ssl=1.1.0=py36h06a4308_0
35 | - importlib-metadata=3.10.0=py36h06a4308_0
36 | - intel-openmp=2021.2.0=h06a4308_610
37 | - joblib=1.0.1=pyhd8ed1ab_0
38 | - jpeg=9b=h024ee3a_2
39 | - lcms2=2.12=h3be6417_0
40 | - ld_impl_linux-64=2.33.1=h53a641e_7
41 | - libblas=3.9.0=1_h6e990d7_netlib
42 | - libcblas=3.9.0=3_h893e4fe_netlib
43 | - libffi=3.3=he6710b0_2
44 | - libgcc=7.2.0=h69d50b8_2
45 | - libgcc-ng=9.1.0=hdf63c60_0
46 | - libgfortran-ng=7.5.0=h14aa051_19
47 | - libgfortran4=7.5.0=h14aa051_19
48 | - liblapack=3.9.0=3_h893e4fe_netlib
49 | - libpng=1.6.37=hbc83047_0
50 | - libprotobuf=3.14.0=h8c45485_0
51 | - libstdcxx-ng=9.1.0=hdf63c60_0
52 | - libtiff=4.1.0=h2733197_1
53 | - libuv=1.40.0=h7b6447c_0
54 | - lz4-c=1.9.3=h2531618_0
55 | - markdown=3.3.4=py36h06a4308_0
56 | - mkl=2020.2=256
57 | - mkl-service=2.3.0=py36he8ac12f_0
58 | - mkl_fft=1.3.0=py36h54f3939_0
59 | - mkl_random=1.1.1=py36h0573a6f_0
60 | - multidict=4.7.6=py36h7b6447c_1
61 | - ncurses=6.2=he6710b0_1
62 | - ninja=1.10.2=hff7bd54_1
63 | - numpy=1.19.2=py36h54aff64_0
64 | - numpy-base=1.19.2=py36hfa32c7d_0
65 | - oauthlib=3.1.0=py_0
66 | - olefile=0.46=py36_0
67 | - openssl=1.1.1k=h27cfd23_0
68 | - pandas=1.1.3=py36he6710b0_0
69 | - pillow=8.2.0=py36he98fc37_0
70 | - pip=21.0.1=py36h06a4308_0
71 | - protobuf=3.14.0=py36h2531618_1
72 | - pyasn1=0.4.8=py_0
73 | - pyasn1-modules=0.2.8=py_0
74 | - pycparser=2.20=py_2
75 | - pyjwt=1.7.1=py36_0
76 | - pyopenssl=20.0.1=pyhd3eb1b0_1
77 | - pysocks=1.7.1=py36h06a4308_0
78 | - python=3.6.13=hdb3f193_0
79 | - python-dateutil=2.8.1=pyhd3eb1b0_0
80 | - python_abi=3.6=1_cp36m
81 | - pytorch=1.7.1=py3.6_cuda11.0.221_cudnn8.0.5_0
82 | - pytz=2021.1=pyhd3eb1b0_0
83 | - readline=8.1=h27cfd23_0
84 | - regex=2021.4.4=py36h27cfd23_0
85 | - requests=2.25.1=pyhd3eb1b0_0
86 | - requests-oauthlib=1.3.0=py_0
87 | - rsa=4.7.2=pyhd3eb1b0_1
88 | - scikit-learn=0.23.2=py36hb6e6923_3
89 | - scipy=1.5.3=py36h976291a_0
90 | - setuptools=52.0.0=py36h06a4308_0
91 | - six=1.15.0=py36h06a4308_0
92 | - sqlite=3.35.4=hdfb4753_0
93 | - tensorboard=2.4.0=pyhc547734_0
94 | - tensorboard-plugin-wit=1.6.0=py_0
95 | - threadpoolctl=2.1.0=pyh5ca1d4c_0
96 | - tk=8.6.10=hbc83047_0
97 | - torchaudio=0.7.2=py36
98 | - torchvision=0.8.2=py36_cu110
99 | - tqdm=4.59.0=pyhd3eb1b0_1
100 | - typing_extensions=3.7.4.3=pyha847dfd_0
101 | - urllib3=1.26.4=pyhd3eb1b0_0
102 | - wcwidth=0.2.5=py_0
103 | - werkzeug=1.0.1=pyhd3eb1b0_0
104 | - wheel=0.36.2=pyhd3eb1b0_0
105 | - xz=5.2.5=h7b6447c_0
106 | - yarl=1.6.3=py36h27cfd23_0
107 | - zipp=3.4.1=pyhd3eb1b0_0
108 | - zlib=1.2.11=h7b6447c_3
109 | - zstd=1.4.9=haebb681_0
110 | - pip:
111 | - ase==3.21.1
112 | - braceexpand==0.1.7
113 | - cached-property==1.5.2
114 | - configparser==5.0.2
115 | - cycler==0.10.0
116 | - decorator==4.4.2
117 | - docker-pycreds==0.4.0
118 | - gitdb==4.0.7
119 | - gitpython==3.1.14
120 | - googledrivedownloader==0.4
121 | - h5py==3.1.0
122 | - isodate==0.6.0
123 | - jinja2==3.0.1
124 | - kiwisolver==1.3.1
125 | - littleutils==0.2.2
126 | - llvmlite==0.36.0
127 | - markupsafe==2.0.1
128 | - matplotlib==3.3.4
129 | - networkx==2.5.1
130 | - numba==0.53.1
131 | - ogb==1.3.1
132 | - outdated==0.2.1
133 | - pathtools==0.1.2
134 | - promise==2.3
135 | - psutil==5.8.0
136 | - pyarrow==4.0.0
137 | - pyparsing==2.4.7
138 | - python-louvain==0.15
139 | - pyyaml==5.4.1
140 | - rdflib==5.0.0
141 | - sentry-sdk==1.1.0
142 | - shortuuid==1.0.1
143 | - sklearn==0.0
144 | - smmap==4.0.0
145 | - subprocess32==3.5.4
146 | - torch-geometric==1.7.0
147 | - wandb==0.10.30
148 | - wilds==1.1.0
149 | - "--editable=git+https://github.com/tmbdev/webdataset.git@a4f3ec08551b42f20b20cdc1ba32d12536eabc15#egg=webdataset"
150 | - git+https://github.com/modestyachts/ImageNetV2_pytorch
151 | - https://pytorch-geometric.com/whl/torch-1.7.0+cu110/torch_scatter-2.0.6-cp36-cp36m-linux_x86_64.whl
152 | prefix: /home/gamaga/anaconda3/envs/open_clip
153 |
--------------------------------------------------------------------------------
/third_party/open_clip/model.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Tuple, Union
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn
8 |
9 |
10 | class Bottleneck(nn.Module):
11 | expansion = 4
12 |
13 | def __init__(self, inplanes, planes, stride=1):
14 | super().__init__()
15 |
16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18 | self.bn1 = nn.BatchNorm2d(planes)
19 | self.relu1 = nn.ReLU(inplace=True)
20 |
21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22 | self.bn2 = nn.BatchNorm2d(planes)
23 | self.relu2 = nn.ReLU(inplace=True)
24 |
25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26 |
27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29 | self.relu3 = nn.ReLU(inplace=True)
30 |
31 | self.downsample = None
32 | self.stride = stride
33 |
34 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36 | self.downsample = nn.Sequential(OrderedDict([
37 | ("-1", nn.AvgPool2d(stride)),
38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39 | ("1", nn.BatchNorm2d(planes * self.expansion))
40 | ]))
41 |
42 | def forward(self, x: torch.Tensor):
43 | identity = x
44 |
45 | out = self.relu1(self.bn1(self.conv1(x)))
46 | out = self.relu2(self.bn2(self.conv2(out)))
47 | out = self.avgpool(out)
48 | out = self.bn3(self.conv3(out))
49 |
50 | if self.downsample is not None:
51 | identity = self.downsample(x)
52 |
53 | out += identity
54 | out = self.relu3(out)
55 | return out
56 |
57 |
58 | class AttentionPool2d(nn.Module):
59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60 | super().__init__()
61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
62 | self.k_proj = nn.Linear(embed_dim, embed_dim)
63 | self.q_proj = nn.Linear(embed_dim, embed_dim)
64 | self.v_proj = nn.Linear(embed_dim, embed_dim)
65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66 | self.num_heads = num_heads
67 |
68 | def forward(self, x):
69 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
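    |         # The mean token prepended above serves as the single query (x[:1]); its
    |         # attention output over all spatial tokens is the pooled image feature.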
72 | x, _ = F.multi_head_attention_forward(
73 | query=x[:1], key=x, value=x,
74 | embed_dim_to_check=x.shape[-1],
75 | num_heads=self.num_heads,
76 | q_proj_weight=self.q_proj.weight,
77 | k_proj_weight=self.k_proj.weight,
78 | v_proj_weight=self.v_proj.weight,
79 | in_proj_weight=None,
80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
81 | bias_k=None,
82 | bias_v=None,
83 | add_zero_attn=False,
84 | dropout_p=0,
85 | out_proj_weight=self.c_proj.weight,
86 | out_proj_bias=self.c_proj.bias,
87 | use_separate_proj_weight=True,
88 | training=self.training,
89 | need_weights=False
90 | )
91 | return x.squeeze(0)
92 |
93 |
94 | class ModifiedResNet(nn.Module):
95 | """
96 | A ResNet class that is similar to torchvision's but contains the following changes:
97 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
98 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
99 | - The final pooling layer is a QKV attention instead of an average pool
100 | """
101 |
102 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
103 | super().__init__()
104 | self.output_dim = output_dim
105 | self.input_resolution = input_resolution
106 |
107 | # the 3-layer stem
108 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
109 | self.bn1 = nn.BatchNorm2d(width // 2)
110 | self.relu1 = nn.ReLU(inplace=True)
111 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
112 | self.bn2 = nn.BatchNorm2d(width // 2)
113 | self.relu2 = nn.ReLU(inplace=True)
114 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
115 | self.bn3 = nn.BatchNorm2d(width)
116 | self.relu3 = nn.ReLU(inplace=True)
117 | self.avgpool = nn.AvgPool2d(2)
118 |
119 | # residual layers
120 | self._inplanes = width # this is a *mutable* variable used during construction
121 | self.layer1 = self._make_layer(width, layers[0])
122 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
123 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
124 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
125 |
126 | embed_dim = width * 32 # the ResNet feature dimension
127 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
128 |
129 | def _make_layer(self, planes, blocks, stride=1):
130 | layers = [Bottleneck(self._inplanes, planes, stride)]
131 |
132 | self._inplanes = planes * Bottleneck.expansion
133 | for _ in range(1, blocks):
134 | layers.append(Bottleneck(self._inplanes, planes))
135 |
136 | return nn.Sequential(*layers)
137 |
138 | def forward(self, x):
139 | def stem(x):
140 | x = self.relu1(self.bn1(self.conv1(x)))
141 | x = self.relu2(self.bn2(self.conv2(x)))
142 | x = self.relu3(self.bn3(self.conv3(x)))
143 | x = self.avgpool(x)
144 | return x
145 |
146 | x = x.type(self.conv1.weight.dtype)
147 | x = stem(x)
148 | x = self.layer1(x)
149 | x = self.layer2(x)
150 | x = self.layer3(x)
151 | x = self.layer4(x)
152 | x = self.attnpool(x)
153 |
154 | return x
155 |
156 |
157 | class LayerNorm(nn.LayerNorm):
158 | """Subclass torch's LayerNorm to handle fp16."""
159 |
160 | def forward(self, x: torch.Tensor):
161 | orig_type = x.dtype
162 | ret = super().forward(x.type(torch.float32))
163 | return ret.type(orig_type)
164 |
165 |
166 | class QuickGELU(nn.Module):
167 | def forward(self, x: torch.Tensor):
168 | return x * torch.sigmoid(1.702 * x)
169 |
170 |
171 | class ResidualAttentionBlock(nn.Module):
172 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
173 | super().__init__()
174 |
175 | self.attn = nn.MultiheadAttention(d_model, n_head)
176 | self.ln_1 = LayerNorm(d_model)
177 | self.mlp = nn.Sequential(OrderedDict([
178 | ("c_fc", nn.Linear(d_model, d_model * 4)),
179 | ("gelu", QuickGELU()),
180 | ("c_proj", nn.Linear(d_model * 4, d_model))
181 | ]))
182 | self.ln_2 = LayerNorm(d_model)
183 | self.attn_mask = attn_mask
184 |
185 | def attention(self, x: torch.Tensor):
186 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
187 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
188 |
189 | def forward(self, x: torch.Tensor):
190 | x = x + self.attention(self.ln_1(x))
191 | x = x + self.mlp(self.ln_2(x))
192 | return x
193 |
194 |
195 | class Transformer(nn.Module):
196 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
197 | super().__init__()
198 | self.width = width
199 | self.layers = layers
200 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
201 |
202 | def forward(self, x: torch.Tensor):
203 | return self.resblocks(x)
204 |
205 |
206 | class VisionTransformer(nn.Module):
207 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
208 | super().__init__()
209 | self.input_resolution = input_resolution
210 | self.output_dim = output_dim
211 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
212 |
213 | scale = width ** -0.5
214 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
215 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
216 | self.ln_pre = LayerNorm(width)
217 |
218 | self.transformer = Transformer(width, layers, heads)
219 |
220 | self.ln_post = LayerNorm(width)
221 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
222 |
223 | def forward(self, x: torch.Tensor):
224 | x = self.conv1(x) # shape = [*, width, grid, grid]
225 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
226 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
227 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
228 | x = x + self.positional_embedding.to(x.dtype)
229 | x = self.ln_pre(x)
230 |
231 | x = x.permute(1, 0, 2) # NLD -> LND
232 | x = self.transformer(x)
233 | x = x.permute(1, 0, 2) # LND -> NLD
234 |
235 | x = self.ln_post(x[:, 0, :])
236 |
237 | if self.proj is not None:
238 | x = x @ self.proj
239 |
240 | return x
241 |
242 |
243 | class CLIP(nn.Module):
244 | def __init__(self,
245 | embed_dim: int,
246 | # vision
247 | image_resolution: int,
248 | vision_layers: Union[Tuple[int, int, int, int], int],
249 | vision_width: int,
250 | vision_patch_size: int,
251 | # text
252 | context_length: int,
253 | vocab_size: int,
254 | transformer_width: int,
255 | transformer_heads: int,
256 | transformer_layers: int
257 | ):
258 | super().__init__()
259 |
260 | self.context_length = context_length
261 |
262 | if isinstance(vision_layers, (tuple, list)):
263 | vision_heads = vision_width * 32 // 64
264 | self.visual = ModifiedResNet(
265 | layers=vision_layers,
266 | output_dim=embed_dim,
267 | heads=vision_heads,
268 | input_resolution=image_resolution,
269 | width=vision_width
270 | )
271 | else:
272 | vision_heads = vision_width // 64
273 | self.visual = VisionTransformer(
274 | input_resolution=image_resolution,
275 | patch_size=vision_patch_size,
276 | width=vision_width,
277 | layers=vision_layers,
278 | heads=vision_heads,
279 | output_dim=embed_dim
280 | )
281 |
282 | self.transformer = Transformer(
283 | width=transformer_width,
284 | layers=transformer_layers,
285 | heads=transformer_heads,
286 | attn_mask=self.build_attention_mask()
287 | )
288 |
289 | self.vocab_size = vocab_size
290 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
291 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
292 | self.ln_final = LayerNorm(transformer_width)
293 |
294 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
295 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
296 |
297 | self.initialize_parameters()
298 |
299 | def initialize_parameters(self):
300 | nn.init.normal_(self.token_embedding.weight, std=0.02)
301 | nn.init.normal_(self.positional_embedding, std=0.01)
302 |
303 | if isinstance(self.visual, ModifiedResNet):
304 | if self.visual.attnpool is not None:
305 | std = self.visual.attnpool.c_proj.in_features ** -0.5
306 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
307 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
308 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
309 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
310 |
311 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
312 | for name, param in resnet_block.named_parameters():
313 | if name.endswith("bn3.weight"):
314 | nn.init.zeros_(param)
315 |
316 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
317 | attn_std = self.transformer.width ** -0.5
318 | fc_std = (2 * self.transformer.width) ** -0.5
319 | for block in self.transformer.resblocks:
320 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
321 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
322 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
323 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
324 |
325 | if self.text_projection is not None:
326 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
327 |
328 | def build_attention_mask(self):
329 | # lazily create the causal attention mask used by the text transformer
330 | # pytorch uses additive attention mask; fill with -inf
331 | mask = torch.empty(self.context_length, self.context_length)
332 | mask.fill_(float("-inf"))
333 | mask.triu_(1) # zero out the lower diagonal
334 | return mask
335 |
336 | @property
337 | def dtype(self):
338 | return self.visual.conv1.weight.dtype
339 |
340 | def encode_image(self, image):
341 | return self.visual(image.type(self.dtype))
342 |
343 | def encode_text(self, text):
344 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
345 |
346 | x = x + self.positional_embedding.type(self.dtype)
347 | x = x.permute(1, 0, 2) # NLD -> LND
348 | x = self.transformer(x)
349 | x = x.permute(1, 0, 2) # LND -> NLD
350 | x = self.ln_final(x).type(self.dtype)
351 |
352 | # x.shape = [batch_size, n_ctx, transformer.width]
353 | # take features from the eot embedding (eot_token is the highest number in each sequence)
354 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
355 |
356 | return x
357 |
358 | def forward(self, image, text):
359 | image_features = self.encode_image(image)
360 | text_features = self.encode_text(text)
361 |
362 | # normalized features
363 | image_features = image_features / image_features.norm(dim=1, keepdim=True)
364 | text_features = text_features / text_features.norm(dim=1, keepdim=True)
365 |
366 | # cosine similarity as logits
367 | logit_scale = self.logit_scale.exp()
368 | logits_per_image = logit_scale * image_features @ text_features.t()
369 | logits_per_text = logits_per_image.t()
370 |
371 | # shape = [global_batch_size, global_batch_size]
372 | return logits_per_image, logits_per_text
373 |
374 |
375 | def convert_weights(model: nn.Module):
376 | """Convert applicable model parameters to fp16"""
377 |
378 | def _convert_weights_to_fp16(l):
379 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
380 | l.weight.data = l.weight.data.half()
381 | if l.bias is not None:
382 | l.bias.data = l.bias.data.half()
383 |
384 | if isinstance(l, nn.MultiheadAttention):
385 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
386 | tensor = getattr(l, attr)
387 | if tensor is not None:
388 | tensor.data = tensor.data.half()
389 |
390 | for name in ["text_projection", "proj"]:
391 | if hasattr(l, name):
392 | attr = getattr(l, name)
393 | if attr is not None:
394 | attr.data = attr.data.half()
395 |
396 | model.apply(_convert_weights_to_fp16)
397 |
398 |
399 | def build_model(state_dict: dict):
400 | vit = "visual.proj" in state_dict
401 |
402 | if vit:
403 | vision_width = state_dict["visual.conv1.weight"].shape[0]
404 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
405 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
406 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
407 | image_resolution = vision_patch_size * grid_size
408 | else:
409 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
410 | vision_layers = tuple(counts)
411 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
412 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
413 | vision_patch_size = None
414 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
415 | image_resolution = output_width * 32
416 |
417 | embed_dim = state_dict["text_projection"].shape[1]
418 | context_length = state_dict["positional_embedding"].shape[0]
419 | vocab_size = state_dict["token_embedding.weight"].shape[0]
420 | transformer_width = state_dict["ln_final.weight"].shape[0]
421 | transformer_heads = transformer_width // 64
422 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
423 |
424 | model = CLIP(
425 | embed_dim,
426 | image_resolution, vision_layers, vision_width, vision_patch_size,
427 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
428 | )
429 |
430 | for key in ["input_resolution", "context_length", "vocab_size"]:
431 | if key in state_dict:
432 | del state_dict[key]
433 |
434 | convert_weights(model)
435 | model.load_state_dict(state_dict)
436 | return model.eval()
--------------------------------------------------------------------------------
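
Note on usage: `build_model` above infers the architecture purely from the checkpoint's keys (a `visual.proj` entry means ViT, otherwise ModifiedResNet). The following is a minimal sketch of driving it from a downloaded OpenAI-style checkpoint; the checkpoint filename and the `third_party.open_clip.model` import path are assumptions for illustration, not necessarily this repository's own loading path.

```python
# Minimal sketch (not the repository's loader): construct a CLIP model from a
# checkpoint state_dict using build_model() defined above. The checkpoint path
# and the import path are hypothetical.
import torch

from third_party.open_clip.model import build_model  # assumed import path

ckpt_path = "ViT-B-32.pt"  # hypothetical local copy of an OpenAI CLIP checkpoint

# OpenAI checkpoints ship either as a TorchScript archive or as a plain state_dict.
try:
    state_dict = torch.jit.load(ckpt_path, map_location="cpu").state_dict()
except RuntimeError:
    state_dict = torch.load(ckpt_path, map_location="cpu")

model = build_model(state_dict)  # picks ViT vs. ModifiedResNet from the keys
model = model.float()            # undo convert_weights() fp16 for CPU inference

with torch.no_grad():
    res = model.visual.input_resolution
    image = torch.randn(1, 3, res, res)
    text = torch.randint(0, model.vocab_size, (1, model.context_length))
    logits_per_image, logits_per_text = model(image, text)
    print(logits_per_image.shape, logits_per_text.shape)  # [1, 1] each
```
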
/third_party/open_clip/model_configs/RN101.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "image_resolution": 224,
4 | "vision_layers": [
5 | 3,
6 | 4,
7 | 23,
8 | 3
9 | ],
10 | "vision_width": 64,
11 | "vision_patch_size": null,
12 | "context_length": 77,
13 | "vocab_size": 49408,
14 | "transformer_width": 512,
15 | "transformer_heads": 8,
16 | "transformer_layers": 12
17 | }
--------------------------------------------------------------------------------
/third_party/open_clip/model_configs/RN50.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "image_resolution": 224,
4 | "vision_layers": [
5 | 3,
6 | 4,
7 | 6,
8 | 3
9 | ],
10 | "vision_width": 64,
11 | "vision_patch_size": null,
12 | "context_length": 77,
13 | "vocab_size": 49408,
14 | "transformer_width": 512,
15 | "transformer_heads": 8,
16 | "transformer_layers": 12
17 | }
--------------------------------------------------------------------------------
/third_party/open_clip/model_configs/RN50_a2.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "image_resolution": 224,
4 | "vision_layers": [
5 | 3,
6 | 4,
7 | 6,
8 | 3
9 | ],
10 | "vision_width": 64,
11 | "vision_patch_size": null,
12 | "context_length": 77,
13 | "vocab_size": 49408,
14 | "transformer_width": 512,
15 | "transformer_heads": 8,
16 | "transformer_layers": 12,
17 | "extra_transformer_layers": 2,
18 | "share_projection_layer": false
19 | }
--------------------------------------------------------------------------------
/third_party/open_clip/model_configs/RN50_a2s.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "image_resolution": 224,
4 | "vision_layers": [
5 | 3,
6 | 4,
7 | 6,
8 | 3
9 | ],
10 | "vision_width": 64,
11 | "vision_patch_size": null,
12 | "context_length": 77,
13 | "vocab_size": 49408,
14 | "transformer_width": 512,
15 | "transformer_heads": 8,
16 | "transformer_layers": 12,
17 | "extra_transformer_layers": 2,
18 | "share_projection_layer": true
19 | }
--------------------------------------------------------------------------------
/third_party/open_clip/model_configs/RN50x16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "image_resolution": 384,
4 | "vision_layers": [
5 | 6,
6 | 8,
7 | 18,
8 | 8
9 | ],
10 | "vision_width": 96,
11 | "vision_patch_size": null,
12 | "context_length": 77,
13 | "vocab_size": 49408,
14 | "transformer_width": 768,
15 | "transformer_heads": 12,
16 | "transformer_layers": 12
17 | }
--------------------------------------------------------------------------------
/third_party/open_clip/model_configs/RN50x4.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 640,
3 | "image_resolution": 288,
4 | "vision_layers": [
5 | 4,
6 | 6,
7 | 10,
8 | 6
9 | ],
10 | "vision_width": 80,
11 | "vision_patch_size": null,
12 | "context_length": 77,
13 | "vocab_size": 49408,
14 | "transformer_width": 640,
15 | "transformer_heads": 10,
16 | "transformer_layers": 12
17 | }
--------------------------------------------------------------------------------
/third_party/open_clip/model_configs/ViT-B-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "image_resolution": 224,
4 | "vision_layers": 12,
5 | "vision_width": 768,
6 | "vision_patch_size": 16,
7 | "context_length": 77,
8 | "vocab_size": 49408,
9 | "transformer_width": 512,
10 | "transformer_heads": 8,
11 | "transformer_layers": 12
12 | }
--------------------------------------------------------------------------------
/third_party/open_clip/model_configs/ViT-B-32.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "image_resolution": 224,
4 | "vision_layers": 12,
5 | "vision_width": 768,
6 | "vision_patch_size": 32,
7 | "context_length": 77,
8 | "vocab_size": 49408,
9 | "transformer_width": 512,
10 | "transformer_heads": 8,
11 | "transformer_layers": 12
12 | }
--------------------------------------------------------------------------------
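
The JSON configs above mirror the keyword arguments of the `CLIP` constructor in `model.py`: `vision_layers` given as a 4-element list selects `ModifiedResNet`, while an integer selects `VisionTransformer`. The RN50_a2 / RN50_a2s variants additionally carry `extra_transformer_layers` and `share_projection_layer`, which the base `CLIP` class above does not accept (they are presumably consumed by this repository's own model code). Below is a minimal sketch of instantiating the base class from a config; the paths and the import are assumptions.

```python
# Minimal sketch: a model config JSON maps almost one-to-one onto the CLIP
# constructor defined in model.py above. Paths and import are hypothetical.
import inspect
import json

from third_party.open_clip.model import CLIP  # assumed import path

with open("third_party/open_clip/model_configs/RN50.json") as f:
    cfg = json.load(f)

# Keep only the parameters the base constructor declares, so configs with
# extra keys (RN50_a2, RN50_a2s) would also pass through this sketch.
accepted = set(inspect.signature(CLIP.__init__).parameters) - {"self"}
model = CLIP(**{k: v for k, v in cfg.items() if k in accepted})

print(type(model.visual).__name__)  # ModifiedResNet, because vision_layers is a list
```
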
/third_party/open_clip/scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def assign_learning_rate(optimizer, new_lr):
4 | for param_group in optimizer.param_groups:
5 | param_group["lr"] = new_lr
6 |
7 | def _warmup_lr(base_lr, warmup_length, step):
8 | return base_lr * (step + 1) / warmup_length
9 |
10 | def cosine_lr(optimizer, base_lr, warmup_length, steps):
11 | def _lr_adjuster(step):
12 | if step < warmup_length:
13 | lr = _warmup_lr(base_lr, warmup_length, step)
14 | else:
15 | e = step - warmup_length
16 | es = steps - warmup_length
17 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
18 | assign_learning_rate(optimizer, lr)
19 | return lr
20 | return _lr_adjuster
--------------------------------------------------------------------------------
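
`cosine_lr` above returns a closure that is meant to be called once per optimizer step: linear warmup for `warmup_length` steps, then cosine decay from `base_lr` toward zero at `steps`. A minimal usage sketch follows; the model, optimizer, and step counts are placeholders, not this repository's training configuration.

```python
# Minimal usage sketch for cosine_lr() above; model/optimizer/step counts are
# placeholders. The import path is an assumption.
import torch

from third_party.open_clip.scheduler import cosine_lr  # assumed import path

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

total_steps = 1000   # e.g. batches_per_epoch * epochs
warmup_steps = 100
scheduler = cosine_lr(optimizer, base_lr=1e-3, warmup_length=warmup_steps, steps=total_steps)

for step in range(total_steps):
    lr = scheduler(step)          # sets every param_group's "lr" and returns it
    optimizer.zero_grad()
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    optimizer.step()
```
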
/third_party/open_clip/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 | Returns a dict mapping utf-8 bytes to unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 | This is a significant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | This also avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe()):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 | vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 | self.encoder = dict(zip(vocab, range(len(vocab))))
75 | self.decoder = {v: k for k, v in self.encoder.items()}
76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 |
80 | def bpe(self, token):
81 | if token in self.cache:
82 | return self.cache[token]
83 | word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84 | pairs = get_pairs(word)
85 |
86 | if not pairs:
87 | return token+'</w>'
88 |
89 | while True:
90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 | if bigram not in self.bpe_ranks:
92 | break
93 | first, second = bigram
94 | new_word = []
95 | i = 0
96 | while i < len(word):
97 | try:
98 | j = word.index(first, i)
99 | new_word.extend(word[i:j])
100 | i = j
101 | except:
102 | new_word.extend(word[i:])
103 | break
104 |
105 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 | new_word.append(first+second)
107 | i += 2
108 | else:
109 | new_word.append(word[i])
110 | i += 1
111 | new_word = tuple(new_word)
112 | word = new_word
113 | if len(word) == 1:
114 | break
115 | else:
116 | pairs = get_pairs(word)
117 | word = ' '.join(word)
118 | self.cache[token] = word
119 | return word
120 |
121 | def encode(self, text):
122 | bpe_tokens = []
123 | text = whitespace_clean(basic_clean(text)).lower()
124 | for token in re.findall(self.pat, text):
125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 | return bpe_tokens
128 |
129 | def decode(self, tokens):
130 | text = ''.join([self.decoder[token] for token in tokens])
131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 | return text
--------------------------------------------------------------------------------
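
`SimpleTokenizer.encode` above returns a variable-length list of BPE ids with no special tokens; CLIP-style text inputs additionally need `<|startoftext|>` / `<|endoftext|>` markers and padding to the model's `context_length` (77 in the configs above). The wrapper below is an illustrative sketch of that step, not necessarily identical to the repository's own tokenization helper; the import path is an assumption.

```python
# Minimal sketch of driving SimpleTokenizer above to produce CLIP-style,
# fixed-length token sequences. The padding wrapper is illustrative only.
import torch

from third_party.open_clip.simple_tokenizer import SimpleTokenizer  # assumed import path

tokenizer = SimpleTokenizer()
sot = tokenizer.encoder["<|startoftext|>"]
eot = tokenizer.encoder["<|endoftext|>"]

def tokenize_batch(texts, context_length=77):
    result = torch.zeros(len(texts), context_length, dtype=torch.long)
    for i, text in enumerate(texts):
        tokens = [sot] + tokenizer.encode(text) + [eot]
        if len(tokens) > context_length:          # naive truncation, keep the eot token
            tokens = tokens[:context_length - 1] + [eot]
        result[i, :len(tokens)] = torch.tensor(tokens)
    return result

tokens = tokenize_batch(["a photo of a dog", "two cats on a sofa"])
print(tokens.shape)  # torch.Size([2, 77])
print(tokenizer.decode(tokenizer.encode("a photo of a dog")))  # "</w>" markers decode back to spaces
```
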