├── utils.py ├── README.md ├── reclip.py ├── LICENSE ├── datasets.py ├── models.py └── prompts └── clip_prompts /utils.py: -------------------------------------------------------------------------------- 1 | 2 | def accuracy(output, target, topk=(1,)): 3 | pred = output.topk(max(topk), 1, True, True)[1].t() 4 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 5 | return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ReCLIP: Refine Contrastive Language Image Pre-Training with Source Free Domain Adaptation 2 | 3 | ## Overview 4 | 5 | This repository provides the official PyTorch implementation of our WACV 2024 (Oral) Paper [ReCLIP: Refine Contrastive Language Image Pre-Training with Source Free Domain Adaptation](https://arxiv.org/abs/2308.03793) 6 | 7 | ### Hardware 8 | We have evaluated our code on NVIDIA A100 GPU with 40GB GPU Memory with batch size of 64. Please use --parallel and smaller batch size for smaller memory GPU. 9 | 10 | ### Environment 11 | We tested our code with PyTorch 1.12.0. 12 | 13 | ### Model Weight 14 | We use [CLIP](https://github.com/openai/CLIP) ViT-L/14 as our main base model for adaptation. It is also possible to use other architecture by configing the --architecture option. Our code will automatically download the CLIP checkpoint from [link](https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt) and put it under the ./ckpt folder. 15 | 16 | ### License 17 | ReCLIP is released under the Apache 2.0 license. Please see the [LICENSE](LICENSE) file for more information. 18 | 19 | 20 | 21 | ## Citations 22 | 23 | @article{xuefeng2023reclip, 24 | title={ReCLIP: Refine Contrastive Language Image Pre-Training with Source Free Domain Adaptation}, 25 | author={Xuefeng, Hu and Ke, Zhang and Lu, Xia and Albert, Chen and Jiajia, Luo and Yuyin, Sun and Ken, Wang and Nan, Qiao and Xiao, Zeng and Min, Sun and others}, 26 | journal={2024 IEEE winter conference on applications of computer vision (WACV)}, 27 | year={2024}, 28 | organization={IEEE} 29 | } 30 | 31 | ## Acknowledgements 32 | This work is completed during Xuefeng's internship at Amazon. 33 | -------------------------------------------------------------------------------- /reclip.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import time 5 | import random 6 | import argparse 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | from torch.utils.data import DataLoader 13 | from torch.nn import DataParallel 14 | 15 | from datasets import get_dataset, accuracy 16 | from models import CLIP_LN_T, CLIP_LN_V, LabelPropagationCluster 17 | 18 | 19 | parser = argparse.ArgumentParser(description='CLIP Evaluation') 20 | parser.add_argument('--dataset', default='resisc45', help='dataset name') 21 | parser.add_argument('--architecture', default='ViT-L/14', help='architecture name', choices=['RN50','RN101','RN50x4','RN50x16','RN50x64','ViT-L/14','ViT-L/14@336px','ViT-B/32','ViT-B/16']) 22 | 23 | # training hpyter parameters 24 | parser.add_argument('--bs', '--batch-size', default=64, type=int, metavar='N', dest='bs') 25 | parser.add_argument('--epoch_max', '--epoch_max', default=100, type=int, metavar='N', dest='epoch_max') 26 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)', dest='weight_decay') 27 | parser.add_argument('--lrt', default=1e-3, type=float, metavar='LR', help='initial learning rate for text encoder', dest='lrt') 28 | parser.add_argument('--lrv', default=1e-3, type=float, metavar='LR', help='initial learning rate for visual encoder', dest='lrv') 29 | 30 | # parser.add_argument('--update_basis', action='store_true') 31 | # parser.add_argument('--mean_per_class', action='store_true') 32 | 33 | # label propagation parameters 34 | parser.add_argument('--neighbor_size', default=20, type=int) 35 | parser.add_argument('--alpha', default=0.99, type=float) 36 | parser.add_argument('--cut_dim', default=768, type=int) 37 | 38 | # training control 39 | parser.add_argument('--seed', default=0, type=int) 40 | parser.add_argument('--monitor', action='store_true') 41 | parser.add_argument('--parallel', action='store_true') 42 | parser.add_argument('--log', default='./logs/', type=str, help='log dir, should be something like ./logs/') 43 | 44 | 45 | # main function 46 | def main(): 47 | device = "cuda" if torch.cuda.is_available() else "cpu" 48 | args = parser.parse_args() 49 | 50 | torch.manual_seed(args.seed) 51 | random.seed(10) 52 | 53 | # load class names and prompts (provided by CLIP) 54 | with open("./prompts/clip_prompts", 'r') as filename: 55 | names_prompts = json.load(filename) 56 | class_names = names_prompts[args.dataset]["classes"] 57 | templates = names_prompts[args.dataset]["templates"] 58 | 59 | # load the ReCLIP-V model 60 | v_model = CLIP_LN_V(class_names=class_names, templates=templates, architecture=args.architecture, learnable_classifier=False) 61 | if torch.cuda.is_available(): 62 | if args.parallel: 63 | v_model = DataParallel(v_model) # data parallel 64 | v_model.to(device) 65 | 66 | # optimizer for ReCLIP-V visual-encoder layer-norm parameters 67 | if args.parallel: 68 | v_optimizer = torch.optim.SGD(v_model.module.learnable_params, args.lrv, weight_decay=args.weight_decay, momentum=0.9) 69 | else: 70 | v_optimizer = torch.optim.SGD(v_model.learnable_params, args.lrv, weight_decay=args.weight_decay, momentum=0.9) 71 | 72 | # load the ReCLIP-T model 73 | t_model = CLIP_LN_T(architecture=args.architecture, templates=templates) 74 | if torch.cuda.is_available(): 75 | if args.parallel: 76 | t_model = DataParallel(t_model) 77 | t_model.to(device) 78 | 79 | # optimizer for ReCLIP-T text-encoder layer-norm parameters 80 | if args.parallel: 81 | t_optimizer = torch.optim.SGD(t_model.module.learnable_params, args.lrt, weight_decay=args.weight_decay, momentum=0.9) 82 | else: 83 | t_optimizer = torch.optim.SGD(t_model.learnable_params, args.lrt, weight_decay=args.weight_decay, momentum=0.9) 84 | 85 | # obtain datasets with preprocess function provided by CLIP 86 | if args.parallel: 87 | test_dataset = get_dataset(dataset_name=args.dataset, preprocess=t_model.module.preprocess) 88 | else: 89 | test_dataset = get_dataset(dataset_name=args.dataset, preprocess=t_model.preprocess) 90 | 91 | # set max epoch 92 | args.epoch_max = min(args.epoch_max, int(len(test_dataset)/5000+2)) 93 | 94 | # label propagation module for ReCLIP-V and ReCLIP-T, initialize with classification weights (text embeddings from class names) from CLIP models 95 | if args.parallel: 96 | v_label_propagation = LabelPropagationCluster(v_model.module.classification_weight, len(test_dataset), k=args.neighbor_size, alpha=args.alpha, cut_dim=args.cut_dim) 97 | t_label_propagation = LabelPropagationCluster(v_model.module.classification_weight, len(test_dataset), k=args.neighbor_size, alpha=args.alpha, cut_dim=args.cut_dim) 98 | else: 99 | v_label_propagation = LabelPropagationCluster(v_model.classification_weight, len(test_dataset), k=args.neighbor_size, alpha=args.alpha, cut_dim=args.cut_dim) 100 | t_label_propagation = LabelPropagationCluster(v_model.classification_weight, len(test_dataset), k=args.neighbor_size, alpha=args.alpha, cut_dim=args.cut_dim) 101 | 102 | # loss function 103 | criterion = torch.nn.CrossEntropyLoss(reduction = 'none') 104 | 105 | # dataloaders 106 | dataset_size = len(test_dataset) 107 | test_loader = DataLoader(test_dataset, batch_size=args.bs, num_workers=32, drop_last=False, shuffle=False) 108 | train_loader = DataLoader(test_dataset, batch_size=args.bs, num_workers=32, drop_last=False, shuffle=True) 109 | 110 | # logs 111 | best_acc = 0 112 | timestr = time.strftime("%Y%m%d-%H%M%S") 113 | if(not os.path.isdir(args.log)): 114 | os.mkdir(args.log) # create folder if not exist 115 | output_file = open(args.log + args.dataset + timestr, 'a') 116 | json.dump(args.__dict__, output_file, indent=2) # store args for reference 117 | output_file.flush() 118 | args.file = output_file 119 | 120 | for epoch in range(args.epoch_max): 121 | # test phase: run forward passs over all test data for evalution for both ReCLIP-T and ReCLIP-V; prepare pseudo labels at the same time 122 | with torch.no_grad(): 123 | top1t, top1v, top1c, n = 0., 0., 0., 0 124 | 125 | # monitor setup 126 | if(args.monitor): 127 | pbar = tqdm(test_loader) 128 | else: 129 | pbar = test_loader 130 | 131 | # evaluation starts 132 | for images, idx, label in pbar: 133 | # inputs. abusolute idx of current example is also provided in order to record pseudo labels 134 | image_input = images.to(device) 135 | label = label.to(device).view(-1) 136 | 137 | # forward pass of ReCLIP-T 138 | t_logits, t_feature = t_model(image_input, class_names) 139 | t_acc = accuracy(t_logits, label, topk=(1,))[0] 140 | 141 | # forward pass of ReCLIP-V 142 | v_logits, v_feature = v_model(image_input) 143 | v_acc = accuracy(v_logits, label, topk=(1,))[0] 144 | 145 | # update the label propagation mododules with visual features collected from ReCLIP-V and ReCLIP-T 146 | t_label_propagation(t_feature, idx, label) 147 | v_label_propagation(v_feature, idx, label) 148 | 149 | # combined logits for prediction 150 | c_logits = 0.5 * (t_logits + v_logits) 151 | c_acc = accuracy(c_logits, label, topk=(1,))[0] 152 | 153 | # summary 154 | top1t += t_acc 155 | top1v += v_acc 156 | top1c += c_acc 157 | n += len(label) 158 | 159 | # update progress bar 160 | if(args.monitor): 161 | pbar.set_description(f"Epoch = {(epoch):d} Test Accuracy (C/T/V) = {(100*top1c/n):.2f}%, {(100*top1t/n):.2f}%, {(100*top1v/n):.2f}%") 162 | 163 | # end of evaluation, collected features from all samples, perform label propagation 164 | # label propagation function returns the accuracy of pseudo labels generated by ReCLIP-T 165 | pt_acc = t_label_propagation.perform_label_propagation(clear_cache=True) 166 | # label propagation function returns the accuracy of pseudo labels generated by ReCLIP-V, as well as the clustering centriods 167 | pv_acc, centriods = v_label_propagation.perform_label_propagation(clear_cache=True, cluster_centriod=True) 168 | 169 | # updates the ReCLIP-V classification weights with clustering centriods (ReCLIP-V uses cluster centriods for classification) 170 | if args.parallel: 171 | v_model.module.classification_weight = centriods.t() 172 | else: 173 | v_model.classification_weight = centriods.t() 174 | 175 | # logging: update best acc 176 | if(100 * top1c / n > best_acc): 177 | best_acc = 100 * top1c / n 178 | 179 | # logging: end of epoch summary 180 | if(args.monitor): 181 | print(f"Epoch = {(epoch):d} Best Accuracy = {best_acc:.2f}%, Pseudo Label Accuracy (T/V) = {100 * pt_acc:.2f}%, {100 * pv_acc:.2f}%") 182 | 183 | # logging 184 | args.file.write(f"Epoch = {(epoch):d} Test Accuracy (C/T/V) = {(100*top1c/n):.2f}%, {(100*top1t/n):.2f}%, {(100*top1v/n):.2f}%\n") 185 | args.file.write(f"Epoch = {(epoch):d} Best Accuracy = {best_acc:.2f}%, Pseudo Label Accuracy (T/V) = {100 * pt_acc:.2f}%, {100 * pv_acc:.2f}%\n") 186 | args.file.flush() 187 | 188 | 189 | # training phase: updates ReCLIP-T and ReCLIP-V parameters with pseudo labels 190 | top1t, top1v, top1c, n = 0., 0., 0., 0 191 | 192 | # monitor setup 193 | if(args.monitor): 194 | pbar = tqdm(train_loader) 195 | else: 196 | pbar = train_loader 197 | 198 | # training starts 199 | for images, idx, label in pbar: 200 | # inputs. abusolute idx of current example is also provided in order to lookup pseudo labels 201 | image_input = images.to(device) 202 | label = label.to(device).view(-1) 203 | 204 | # forward pass of ReCLIP-T 205 | t_logits, _ = t_model(image_input, class_names) 206 | 207 | # forward pass of ReCLIP-V 208 | v_logits, _ = v_model(image_input) 209 | 210 | # get pseudo labels for ReCLIP-T, based on current example idx 211 | t_pseudo_labels, _ = t_label_propagation.get_pseudo_label(idx) 212 | t_pseudo_labels = torch.LongTensor(t_pseudo_labels).to(device) 213 | 214 | # get pseudo labels for ReCLIP-V, based on current example idx 215 | v_pseudo_labels, _ = v_label_propagation.get_pseudo_label(idx) 216 | v_pseudo_labels = torch.LongTensor(v_pseudo_labels).to(device) 217 | 218 | # use commonly agreed pseudo labels for training 219 | confidence_map = (v_pseudo_labels == t_pseudo_labels) 220 | 221 | # if there is any commonly agreed labels, otherwise (unlikely) skip the current training 222 | if(torch.sum(confidence_map) > 0): 223 | 224 | # back propagation for ReCLIP-T, only updates the entry where both ReCLIP-T and ReCLIP-V agrees 225 | t_optimizer.zero_grad() 226 | t_loss = torch.sum(criterion(t_logits, t_pseudo_labels) * confidence_map) / torch.sum(confidence_map) 227 | t_loss.backward() 228 | t_optimizer.step() 229 | 230 | # back propagation for ReCLIP-V, only updates the entry where both ReCLIP-T and ReCLIP-V agrees 231 | v_optimizer.zero_grad() 232 | v_loss = torch.mean(criterion(v_logits, v_pseudo_labels) * confidence_map) / torch.sum(confidence_map) 233 | v_loss.backward() 234 | v_optimizer.step() 235 | 236 | # accuracy record 237 | c_logits = 0.5 * (t_logits + v_logits) 238 | v_acc = accuracy(v_logits, label, topk=(1,))[0] 239 | t_acc = accuracy(t_logits, label, topk=(1,))[0] 240 | c_acc = accuracy(c_logits, label, topk=(1,))[0] 241 | 242 | # summary 243 | top1t += t_acc 244 | top1v += v_acc 245 | top1c += c_acc 246 | n += len(label) 247 | 248 | # update progress bar 249 | if(args.monitor): 250 | pbar.set_description(f"Epoch = {(epoch):d} Training Accuracy (C/T/V) = {(100*top1c/n):.2f}%, {(100*top1t/n):.2f}%, {(100*top1v/n):.2f}%") 251 | 252 | # updates the projection matrix and the classification weight in ReCLIP-T 253 | # for ReCLIP-V, it uses clustering centriods for classification, therefore it does not require this update 254 | with torch.no_grad(): 255 | if args.parallel: 256 | classification_weight_t = t_model.module.encode_text(class_names, full_templates=True) 257 | else: 258 | classification_weight_t = t_model.encode_text(class_names, full_templates=True) 259 | 260 | t_label_propagation.update_projection(classification_weight_t) 261 | t_label_propagation.update_centriods(classification_weight_t.t()) 262 | 263 | if __name__ == "__main__": 264 | main() -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022 IDEA 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | 203 | 204 | 205 | Conditional DETR(https://github.com/Atten4Vis/ConditionalDETR) 206 | 207 | Copyright 2021 Microsoft. 208 | 209 | Licensed under the Apache License, Version 2.0 (the "License"); 210 | you may not use this file except in compliance with the License. 211 | You may obtain a copy of the License at 212 | 213 | http://www.apache.org/licenses/LICENSE-2.0 214 | 215 | Unless required by applicable law or agreed to in writing, software 216 | distributed under the License is distributed on an "AS IS" BASIS, 217 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 218 | See the License for the specific language governing permissions and 219 | limitations under the License. 220 | 221 | 222 | Deformable DETR(https://github.com/fundamentalvision/Deformable-DETR) 223 | 224 | Copyright 2020 SenseTime 225 | 226 | Licensed under the Apache License, Version 2.0 (the "License"); 227 | you may not use this file except in compliance with the License. 228 | You may obtain a copy of the License at 229 | 230 | http://www.apache.org/licenses/LICENSE-2.0 231 | 232 | Unless required by applicable law or agreed to in writing, software 233 | distributed under the License is distributed on an "AS IS" BASIS, 234 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 235 | See the License for the specific language governing permissions and 236 | limitations under the License. 237 | 238 | 239 | DETR(https://github.com/facebookresearch/detr) 240 | 241 | Copyright 2020 - present, Facebook, Inc 242 | 243 | Licensed under the Apache License, Version 2.0 (the "License"); 244 | you may not use this file except in compliance with the License. 245 | You may obtain a copy of the License at 246 | 247 | http://www.apache.org/licenses/LICENSE-2.0 248 | 249 | Unless required by applicable law or agreed to in writing, software 250 | distributed under the License is distributed on an "AS IS" BASIS, 251 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 252 | See the License for the specific language governing permissions and 253 | limitations under the License. 254 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import socket 4 | import torchvision 5 | import pandas 6 | import json 7 | import scipy.io 8 | import torch 9 | from os import path 10 | from PIL import Image 11 | from tqdm import tqdm 12 | import random 13 | import numpy as np 14 | from torchvision.datasets import Food101, CIFAR10, CIFAR100, StanfordCars 15 | from torchvision.datasets import FGVCAircraft, VOCDetection, DTD, OxfordIIITPet, Caltech101 16 | from torchvision.datasets import EuroSAT, GTSRB, Kitti, Country211, PCAM, Kinetics, RenderedSST2 17 | from torchvision.datasets import UCF101, FER2013, ImageNet, Flowers102, MNIST, STL10 18 | from torchvision.transforms import transforms 19 | from torchvision import transforms, datasets 20 | from torch.utils.data import Dataset 21 | from urllib.error import HTTPError 22 | 23 | 24 | # SETUP DATASET FOLDER HERE!!!! 25 | if(socket.gethostname()[-7:] == 'usc.edu'): 26 | data_dir = "/project/nevatia_174/xuefeng_files/clip_datasets" 27 | else: 28 | data_dir = "path/to/your/data/folder" 29 | 30 | 31 | # default transformation for debugging the code 32 | ImageNetTransform = transforms.Compose([ 33 | transforms.Resize((256, 256)), 34 | transforms.CenterCrop(224), 35 | transforms.ToTensor(), 36 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 37 | std=[0.229, 0.224, 0.225]) 38 | ]) 39 | 40 | pytorch_implemented_sets = [ 41 | 'food101', 42 | 'cifar10', 43 | 'cifar100', 44 | 'standford_cars', 45 | 'fgvc', 46 | 'dtd', 47 | 'oxford_pets', 48 | 'flowers102', 49 | 'mnist', 50 | 'stl10', 51 | 'eurosat', 52 | 'gtsrb', 53 | 'country211', 54 | 'pcam', 55 | 'renderedsst2', 56 | 'caltech101', 57 | # 'ucf101', 58 | # 'kitti', 59 | # 'k700', 60 | # 'voc2007', 61 | # 'cifar10_train', 62 | # 'cifar100_train', 63 | ] 64 | 65 | self_implemented_sets = [ 66 | 'fer2013', 67 | 'imagenet', 68 | 'birdsnap', 69 | 'resisc45', 70 | 'aid', 71 | 'sun397', 72 | 'office_d', 73 | 'office_a', 74 | 'office_w', 75 | 'office_ar', 76 | 'office_cl', 77 | 'office_pr', 78 | 'office_rw', 79 | # 'aid_train', 80 | # 'sun397_train', 81 | ] 82 | 83 | 84 | def accuracy(output, target, topk=(1,)): 85 | pred = output.topk(max(topk), 1, True, True)[1].t() 86 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 87 | return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] 88 | 89 | class pytorch_dataset_wrapper(Dataset): 90 | def __init__(self, name, preprocess=ImageNetTransform): 91 | assert(name in pytorch_implemented_sets) 92 | if(name == 'food101'): 93 | self.dataset = Food101(root=data_dir, split='test', transform=preprocess, download=True) 94 | elif(name == 'cifar10'): 95 | self.dataset = CIFAR10(root=data_dir, train=False, transform=preprocess, download=True) 96 | elif(name == 'cifar100'): 97 | self.dataset = CIFAR100(root=data_dir, train=False, transform=preprocess, download=True) 98 | elif(name == 'cifar10_train'): 99 | self.dataset = CIFAR10(root=data_dir, train=True, transform=preprocess, download=True) 100 | elif(name == 'cifar100_train'): 101 | self.dataset = CIFAR100(root=data_dir, train=True, transform=preprocess, download=True) 102 | elif(name == 'standford_cars'): 103 | self.dataset = StanfordCars(root=data_dir, split='test', transform=preprocess, download=True) 104 | elif(name == 'fgvc'): 105 | self.dataset = FGVCAircraft(root=data_dir, split='test', transform=preprocess, download=True) 106 | elif(name == 'dtd'): 107 | self.dataset = DTD(root=data_dir, split='test', transform=preprocess, download=True) 108 | elif(name == 'oxford_pets'): 109 | self.dataset = OxfordIIITPet(root=data_dir, split='test', transform=preprocess, download=True) 110 | elif(name == 'caltech101'): 111 | self.dataset = Caltech101(root=data_dir, transform=preprocess, download=True) # split 112 | elif(name == 'flowers102'): 113 | self.dataset = Flowers102(root=data_dir, split='test', transform=preprocess, download=True) 114 | elif(name == 'mnist'): 115 | self.dataset = MNIST(root=data_dir, train=False, transform=preprocess, download=True) 116 | elif(name == 'stl10'): 117 | self.dataset = STL10(root=data_dir, split='test', transform=preprocess, download=True) 118 | elif(name == 'eurosat'): 119 | self.dataset = EuroSAT(root=data_dir, transform=preprocess, download=True) # split 120 | elif(name == 'gtsrb'): 121 | self.dataset = GTSRB(root=data_dir, split='test', transform=preprocess, download=True) 122 | elif(name == 'country211'): 123 | self.dataset = Country211(root=data_dir, split='test', transform=preprocess, download=True) 124 | elif(name == 'pcam'): 125 | self.dataset = PCAM(root=data_dir, split='val', transform=preprocess, download=True) 126 | elif(name == 'renderedsst2'): 127 | self.dataset = RenderedSST2(root=data_dir, split='test', transform=preprocess, download=True) 128 | elif(name == 'imagenet'): 129 | self.dataset = ImageNet(root=data_dir, split='val', transform=preprocess) 130 | 131 | def __len__(self): 132 | return len(self.dataset) 133 | 134 | def __getitem__(self, idx): 135 | img, label = self.dataset[idx] 136 | return img, idx, label 137 | 138 | 139 | # self-implemented datasets 140 | class AID(Dataset): 141 | def __init__(self, transform=ImageNetTransform, root_dir=data_dir, train=False): 142 | self.transform = transform 143 | if(train): 144 | self.root_dir = root_dir + '/AID/AID_full/train/' 145 | self.label_file = pandas.read_feather(root_dir + '/AID/labels_full/labels_train.feather') 146 | else: 147 | self.root_dir = root_dir + '/AID/AID_full/test/' 148 | self.label_file = pandas.read_feather(root_dir + '/AID/labels_full/labels_test.feather') 149 | self.classes = sorted(list(set(self.label_file['class']))) 150 | self.files = self.label_file['id'] 151 | self.labels = self.label_file['class'] 152 | self.class2label = {self.classes[i]:i for i in range(len(self.classes))} 153 | 154 | def __len__(self): 155 | return(len(self.label_file['class'])) 156 | 157 | def __getitem__(self, idx): 158 | img = Image.open(self.root_dir + self.files[idx]).convert('RGB') 159 | label = self.class2label[self.labels[idx]] 160 | return self.transform(img), idx, label 161 | 162 | class resisc45(Dataset): 163 | def __init__(self, transform=ImageNetTransform, root_dir=data_dir): 164 | self.root_dir = root_dir + '/resisc45/RESISC45_full/test/' 165 | self.transform = transform 166 | self.label_file = pandas.read_feather(root_dir + '/resisc45/labels_full/labels_test.feather') 167 | self.classes = sorted(list(set(self.label_file['class']))) 168 | self.files = self.label_file['id'] 169 | self.labels = self.label_file['class'] 170 | self.class2label = {self.classes[i]:i for i in range(len(self.classes))} 171 | # print(self.classes) 172 | 173 | def __len__(self): 174 | return(len(self.label_file['class'])) 175 | 176 | def __getitem__(self, idx): 177 | img = Image.open(self.root_dir + self.files[idx]).convert('RGB') 178 | label = self.class2label[self.labels[idx]] 179 | return self.transform(img), idx, label 180 | 181 | class imagenet(Dataset): 182 | def __init__(self, transform=ImageNetTransform, root_dir=data_dir): 183 | self.root_dir = root_dir + '/imagenet/' 184 | self.transform = transform 185 | 186 | # imagenet class names 187 | name_table = json.load(open(self.root_dir + 'imagenet_categories.json','rb')) 188 | self.classes = [name_table[str(i)][1] for i in range(1000)] 189 | self.id2label = {name_table[str(i)][0]:i for i in range(1000)} 190 | 191 | # validation set labels 192 | mat = scipy.io.loadmat(os.path.join(self.root_dir, 'ILSVRC2012_devkit_t12/data/meta.mat'))['synsets'] 193 | label_file = open(os.path.join(self.root_dir, 'ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt')) 194 | self.labels = [self.id2label[mat[int(x)-1][0][1][0]] for x in label_file.readlines()] 195 | 196 | self.files = os.listdir(os.path.join(self.root_dir, 'validation')) 197 | self.files.sort() 198 | 199 | def __len__(self): 200 | return(len(self.labels)) 201 | 202 | def __getitem__(self, idx): 203 | fn = os.path.join(self.root_dir, 'validation', self.files[idx]) 204 | img = Image.open(fn).convert('RGB') 205 | label = self.labels[idx] 206 | return self.transform(img), idx, label 207 | 208 | class SUN397(Dataset): 209 | def __init__(self, transform=transforms.ToTensor(), root_dir=data_dir, train=False): 210 | self.root_dir = root_dir + '/SUN397' 211 | self.transform = transform 212 | 213 | with open(self.root_dir+'/ClassName.txt','r') as file: 214 | classnames = file.read().splitlines() 215 | self.class2id = {classnames[i]:i for i in range(len(classnames))} 216 | 217 | if(train): 218 | with open(self.root_dir+'/Training.txt','r') as file: 219 | test_files = file.read().splitlines() 220 | self.files = test_files 221 | else: 222 | with open(self.root_dir+'/Testing_01.txt','r') as file: 223 | test_files = file.read().splitlines() 224 | self.files = test_files 225 | 226 | self.labels = [] 227 | for i in range(len(self.files)): 228 | if(train): 229 | curr_class_name = '/'+'/'.join(self.files[i].split('/')[6:-1]) 230 | else: 231 | curr_class_name = '/'.join(self.files[i].split('/')[:-1]) 232 | curr_label = self.class2id[curr_class_name] 233 | self.labels.append(curr_label) 234 | 235 | def __len__(self): 236 | return(len(self.labels)) 237 | 238 | def __getitem__(self, idx): 239 | img = Image.open(self.root_dir+self.files[idx]).convert('RGB') 240 | label = self.labels[idx] 241 | return self.transform(img), idx, label 242 | 243 | class fer2013(Dataset): 244 | def __init__(self, transform=ImageNetTransform, root_dir=data_dir): 245 | self.root_dir = root_dir + '/fer2013' 246 | self.transform = transform 247 | 248 | self.images = [] 249 | self.labels = [] 250 | with open(self.root_dir + '/public_test.csv', 'r') as file: 251 | reader = csv.reader(file) 252 | for row in reader: 253 | label = int(row[0]) 254 | data = row[1].split(' ') 255 | data = np.array([int(x) for x in data]) 256 | data = np.reshape(data, (48,48)) 257 | data = Image.fromarray(np.uint8(data)).convert('RGB') 258 | self.images.append(data) 259 | self.labels.append(label) 260 | 261 | def __len__(self): 262 | return(len(self.labels)) 263 | 264 | def __getitem__(self, idx): 265 | img = self.images[idx] 266 | label = self.labels[idx] 267 | return self.transform(img), idx, label 268 | 269 | class birdsnap(Dataset): 270 | def __init__(self, transform=transforms.ToTensor(), root_dir=data_dir): 271 | self.root_dir = root_dir + '/birdsnap' 272 | self.transform = transform 273 | 274 | # available test images 275 | with open(root_dir + '/birdsnap/test_images.txt','r') as file: 276 | test_images = file.read().splitlines()[1:] 277 | self.available_files = [] 278 | for filenames in test_images: 279 | full_file_name = root_dir + '/birdsnap/download/images/' + filenames 280 | if(path.exists(full_file_name)): 281 | self.available_files.append(full_file_name) 282 | 283 | # get folders 284 | with open(root_dir + '/birdsnap/species.txt','r') as file: 285 | species = file.read().splitlines()[1:] 286 | 287 | names = [] 288 | folders = [] 289 | for line in species: 290 | line_split = line.split('\t') 291 | names.append(line_split[1]) 292 | folders.append(line_split[3]) 293 | 294 | new_order = np.argsort(names) 295 | self.names = [names[i] for i in new_order] 296 | self.folders = [folders[i] for i in new_order] 297 | 298 | self.folder2id = {self.folders[i]:i for i in range(len(folders))} 299 | 300 | # get labels 301 | self.labels = [] 302 | for file in self.available_files: 303 | folder_name = file.split('/')[-2] 304 | curr_id = self.folder2id[folder_name.lower()] 305 | self.labels.append(curr_id) 306 | 307 | def __len__(self): 308 | return(len(self.labels)) 309 | 310 | def __getitem__(self, idx): 311 | img = Image.open(self.available_files[idx]) 312 | label = self.labels[idx] 313 | return self.transform(img), idx, label 314 | 315 | 316 | class caltech101(Dataset): 317 | def __init__(self, transform=transforms.ToTensor(), root_dir=data_dir): 318 | self.root_dir = root_dir + '/caltech101/101_ObjectCategories' 319 | self.transform = transform 320 | 321 | self.categories = os.listdir(self.root_dir) 322 | self.categories.sort() 323 | 324 | self.images = [] 325 | self.labels = [] 326 | 327 | with open(root_dir+'/caltech101/caltech101_test','r') as file: 328 | self.test_dict = json.load(file) 329 | 330 | for i in range(len(self.categories)): 331 | files = self.test_dict[self.categories[i].lower()] 332 | for filename in files: 333 | full_file_name = os.path.join(self.root_dir, self.categories[i], filename) 334 | if os.path.exists(full_file_name): 335 | self.images.append(full_file_name) 336 | self.labels.append(i) 337 | else: 338 | print('missing file!') 339 | 340 | 341 | def __len__(self): 342 | return(len(self.labels)) 343 | 344 | def __getitem__(self, idx): 345 | img = Image.open(self.images[idx]).convert('RGB') 346 | label = self.labels[idx] 347 | return self.transform(img), idx, label 348 | 349 | class office(Dataset): 350 | def __init__(self, transform=transforms.ToTensor(), root_dir=data_dir, mode='dslr'): 351 | assert(mode in ['dslr','amazon','webcam']) 352 | self.root_dir = root_dir + '/office31/'+mode+'/images/' 353 | self.transform = transform 354 | 355 | self.categories = os.listdir(self.root_dir) 356 | self.categories.sort() 357 | 358 | self.images = [] 359 | self.labels = [] 360 | 361 | for i in range(len(self.categories)): 362 | files = os.listdir(path.join(self.root_dir,self.categories[i])) 363 | for file in files: 364 | self.images.append(path.join(self.root_dir,self.categories[i],file)) 365 | self.labels.append(i) 366 | 367 | def __len__(self): 368 | return(len(self.labels)) 369 | 370 | def __getitem__(self, idx): 371 | img = Image.open(self.images[idx]).convert('RGB') 372 | label = self.labels[idx] 373 | return self.transform(img), idx, label 374 | 375 | class officehome(Dataset): 376 | def __init__(self, transform=transforms.ToTensor(), root_dir=data_dir, mode='dslr'): 377 | assert(mode in ['Art','Clipart','Product', 'Real_World']) 378 | self.root_dir = root_dir + '/officehome/'+mode 379 | self.transform = transform 380 | 381 | self.categories = ['Drill', 'Exit_Sign', 'Bottle', 'Glasses', 'Computer', 'File_Cabinet', 'Shelf', 'Toys', 'Sink', 382 | 'Laptop', 'Kettle', 'Folder', 'Keyboard', 'Flipflops', 'Pencil', 'Bed', 'Hammer', 'ToothBrush', 'Couch', 383 | 'Bike', 'Postit_Notes', 'Mug', 'Webcam', 'Desk_Lamp', 'Telephone', 'Helmet', 'Mouse', 'Pen', 'Monitor', 384 | 'Mop', 'Sneakers', 'Notebook', 'Backpack', 'Alarm_Clock', 'Push_Pin', 'Paper_Clip', 'Batteries', 'Radio', 385 | 'Fan', 'Ruler', 'Pan', 'Screwdriver', 'Trash_Can', 'Printer', 'Speaker', 'Eraser', 'Bucket', 'Chair', 386 | 'Calendar', 'Calculator', 'Flowers', 'Lamp_Shade', 'Spoon', 'Candles', 'Clipboards', 'Scissors', 'TV', 387 | 'Curtains', 'Fork', 'Soda', 'Table', 'Knives', 'Oven', 'Refrigerator', 'Marker'] 388 | 389 | self.images = [] 390 | self.labels = [] 391 | 392 | for i in range(len(self.categories)): 393 | files = os.listdir(path.join(self.root_dir,self.categories[i])) 394 | for file in files: 395 | self.images.append(path.join(self.root_dir,self.categories[i],file)) 396 | self.labels.append(i) 397 | 398 | def __len__(self): 399 | return(len(self.labels)) 400 | 401 | def __getitem__(self, idx): 402 | img = Image.open(self.images[idx]).convert('RGB') 403 | label = self.labels[idx] 404 | return self.transform(img), idx, label 405 | 406 | 407 | def get_dataset(dataset_name, preprocess): 408 | print(dataset_name) 409 | # pytorch implemented dataset 410 | if(dataset_name in pytorch_implemented_sets): 411 | return pytorch_dataset_wrapper(dataset_name, preprocess) 412 | # other 413 | assert(dataset_name in self_implemented_sets) 414 | if(dataset_name == 'aid'): 415 | return AID(transform=preprocess) 416 | elif(dataset_name == 'aid_train'): 417 | return AID(transform=preprocess, train=True) 418 | elif(dataset_name == 'imagenet'): 419 | return imagenet(transform=preprocess) 420 | elif(dataset_name == 'fer2013'): 421 | return fer2013(transform=preprocess) 422 | elif(dataset_name == 'birdsnap'): 423 | return birdsnap(transform=preprocess) 424 | elif(dataset_name == 'resisc45'): 425 | return resisc45(transform=preprocess) 426 | elif(dataset_name == 'sun397'): 427 | return SUN397(transform=preprocess) 428 | elif(dataset_name == 'sun397_train'): 429 | return SUN397(transform=preprocess, train=True) 430 | elif(dataset_name == 'caltech101'): 431 | return caltech101(transform=preprocess) 432 | elif(dataset_name == 'office_d'): 433 | return office(transform=preprocess, mode='dslr') 434 | elif(dataset_name == 'office_a'): 435 | return office(transform=preprocess, mode='amazon') 436 | elif(dataset_name == 'office_w'): 437 | return office(transform=preprocess, mode='webcam') 438 | elif(dataset_name == 'office_ar'): 439 | return officehome(transform=preprocess, mode='Art') 440 | elif(dataset_name == 'office_cl'): 441 | return officehome(transform=preprocess, mode='Clipart') 442 | elif(dataset_name == 'office_pr'): 443 | return officehome(transform=preprocess, mode='Product') 444 | elif(dataset_name == 'office_rw'): 445 | return officehome(transform=preprocess, mode='Real_World') 446 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import clip 3 | import socket 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from scipy.sparse import eye 7 | from scipy.sparse import linalg as s_linalg 8 | import numpy as np 9 | 10 | # SETUP CHEKPOINT FOLDER HERE!!!! 11 | if(socket.gethostname()[-7:] == 'usc.edu'): 12 | ckpt_dir = '/home1/xuefengh/models_ckpt/' 13 | else: 14 | ckpt_dir = "path/to/your/checkpoint/folder" 15 | 16 | # default templates provided by CLIP for ImageNet 17 | clip_full_templates = [ 18 | 'a bad photo of a {}.', 19 | 'a photo of many {}.', 20 | 'a sculpture of a {}.', 21 | 'a photo of the hard to see {}.', 22 | 'a low resolution photo of the {}.', 23 | 'a rendering of a {}.', 24 | 'graffiti of a {}.', 25 | 'a bad photo of the {}.', 26 | 'a cropped photo of the {}.', 27 | 'a tattoo of a {}.', 28 | 'the embroidered {}.', 29 | 'a photo of a hard to see {}.', 30 | 'a bright photo of a {}.', 31 | 'a photo of a clean {}.', 32 | 'a photo of a dirty {}.', 33 | 'a dark photo of the {}.', 34 | 'a drawing of a {}.', 35 | 'a photo of my {}.', 36 | 'the plastic {}.', 37 | 'a photo of the cool {}.', 38 | 'a close-up photo of a {}.', 39 | 'a black and white photo of the {}.', 40 | 'a painting of the {}.', 41 | 'a painting of a {}.', 42 | 'a pixelated photo of the {}.', 43 | 'a sculpture of the {}.', 44 | 'a bright photo of the {}.', 45 | 'a cropped photo of a {}.', 46 | 'a plastic {}.', 47 | 'a photo of the dirty {}.', 48 | 'a jpeg corrupted photo of a {}.', 49 | 'a blurry photo of the {}.', 50 | 'a photo of the {}.', 51 | 'a good photo of the {}.', 52 | 'a rendering of the {}.', 53 | 'a {} in a video game.', 54 | 'a photo of one {}.', 55 | 'a doodle of a {}.', 56 | 'a close-up photo of the {}.', 57 | 'a photo of a {}.', 58 | 'the origami {}.', 59 | 'the {} in a video game.', 60 | 'a sketch of a {}.', 61 | 'a doodle of the {}.', 62 | 'a origami {}.', 63 | 'a low resolution photo of a {}.', 64 | 'the toy {}.', 65 | 'a rendition of the {}.', 66 | 'a photo of the clean {}.', 67 | 'a photo of a large {}.', 68 | 'a rendition of a {}.', 69 | 'a photo of a nice {}.', 70 | 'a photo of a weird {}.', 71 | 'a blurry photo of a {}.', 72 | 'a cartoon {}.', 73 | 'art of a {}.', 74 | 'a sketch of the {}.', 75 | 'a embroidered {}.', 76 | 'a pixelated photo of a {}.', 77 | 'itap of the {}.', 78 | 'a jpeg corrupted photo of the {}.', 79 | 'a good photo of a {}.', 80 | 'a plushie {}.', 81 | 'a photo of the nice {}.', 82 | 'a photo of the small {}.', 83 | 'a photo of the weird {}.', 84 | 'the cartoon {}.', 85 | 'art of the {}.', 86 | 'a drawing of the {}.', 87 | 'a photo of the large {}.', 88 | 'a black and white photo of a {}.', 89 | 'the plushie {}.', 90 | 'a dark photo of a {}.', 91 | 'itap of a {}.', 92 | 'graffiti of the {}.', 93 | 'a toy {}.', 94 | 'itap of my {}.', 95 | 'a photo of a cool {}.', 96 | 'a photo of a small {}.', 97 | 'a tattoo of the {}.', 98 | ] 99 | 100 | # default template provided by CLIP 101 | clip_small_templates = [ 102 | 'a photo of a {}.', 103 | ] 104 | 105 | # setup device 106 | if(torch.cuda.is_available()): 107 | device = torch.device('cuda') 108 | else: 109 | device = torch.device('cpu') 110 | 111 | # ReCLIP-V, only updates visual encoder layer norm 112 | class CLIP_LN_V(nn.Module): 113 | def __init__(self, class_names, architecture='ViT-L/14', templates=clip_full_templates, learnable_classifier=False): 114 | super().__init__() 115 | 116 | # load CLIP checkpoint 117 | self.base_model, self.preprocess = clip.load(download_root=ckpt_dir, name=architecture, device=device) 118 | 119 | # load the templates provided by CLIP 120 | self.templates = templates 121 | 122 | # produce text embeddings based on class names 123 | with torch.no_grad(): 124 | class_embeddings = self.encode_text(class_names).detach() 125 | 126 | # classification are set to not learnable by default, learnable_params is a dict of parameters for optimizer to know which parameter to update 127 | if(learnable_classifier): 128 | self.classification_weight = nn.Parameter(class_embeddings, requires_grad=True) 129 | self.learnable_params = [self.classification_weight] 130 | else: 131 | self.classification_weight = class_embeddings 132 | self.learnable_params = [] 133 | 134 | # setup parameteres for training 135 | self.setup_parapmeters() 136 | 137 | # make everything forzen except visual layer norm 138 | def setup_parapmeters(self): 139 | self.base_model.eval() 140 | self.base_model.requires_grad_(False) 141 | # visual related layer norms 142 | for m in self.base_model.visual.modules(): 143 | if isinstance(m, torch.nn.LayerNorm) or isinstance(m, torch.nn.BatchNorm2d): 144 | m.requires_grad_(True) 145 | self.learnable_params.append(m.weight) 146 | self.learnable_params.append(m.bias) 147 | 148 | def encode_text(self, classnames): 149 | # collect all prompts 150 | zeroshot_weights = [] # expected size: [class_num * template_num] 151 | num_class = len(classnames) 152 | for classname in classnames: 153 | if(isinstance(classname, list)): 154 | # prepare token then embedding 155 | all_prompts = [template.format(classna) for template in self.templates for classna in classname] 156 | else: 157 | # prepare token then embedding 158 | all_prompts = [template.format(classname) for template in self.templates] 159 | 160 | all_tokens = clip.tokenize(all_prompts).to(device) # [template_num, 77] 161 | class_embeddings = self.base_model.encode_text(all_tokens) # [num_prompts, 768] 162 | # normalize, average, normalize again 163 | class_embeddings = F.normalize(class_embeddings, p=2, dim=-1) 164 | class_embeddings = class_embeddings.mean(dim=0) # [num_class, 768] 165 | class_embeddings = F.normalize(class_embeddings, p=2, dim=-1) 166 | zeroshot_weights.append(class_embeddings) 167 | 168 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device) 169 | return zeroshot_weights 170 | 171 | def encode_image(self, image): 172 | image_features = self.base_model.encode_image(image) 173 | image_features = F.normalize(image_features, p=2, dim=1) 174 | return image_features 175 | 176 | def forward(self, image): 177 | image_features = self.encode_image(image) 178 | self.classification_weight = F.normalize(self.classification_weight, p=2, dim=0) 179 | logits = 100. * image_features @ self.classification_weight # use precomputed classification weights (cluster centriods after first evaluation) 180 | return logits, image_features 181 | 182 | # ReCLIP-T, only updates text encoder layer norm 183 | class CLIP_LN_T(nn.Module): 184 | def __init__(self, architecture='ViT-L/14', templates=clip_small_templates): 185 | super().__init__() 186 | 187 | # load CLIP checkpoint 188 | self.base_model, self.preprocess = clip.load(download_root=ckpt_dir, name=architecture, device=device) 189 | 190 | # load the templates provided by CLIP 191 | self.templates = templates 192 | self.short_templates = [self.templates[0]] # ReCLIP-T uses short template for training and use more templates for providing projection matarix and clustering centriods 193 | self.num_template = len(self.templates) 194 | 195 | # setup parameteres 196 | self.learnable_params = [] 197 | self.setup_parapmeters() 198 | 199 | # make everything forzen except text encoder layer norm 200 | def setup_parapmeters(self): 201 | self.base_model.eval() 202 | self.base_model.requires_grad_(False) 203 | self.learnable_params = [] 204 | # visual related layer norms 205 | for m in self.base_model.transformer.modules(): 206 | if isinstance(m, torch.nn.LayerNorm): 207 | m.requires_grad_(True) 208 | self.learnable_params.append(m.weight) 209 | self.learnable_params.append(m.bias) 210 | self.base_model.ln_final.requires_grad_(True) 211 | self.learnable_params.append(self.base_model.ln_final.weight) 212 | self.learnable_params.append(self.base_model.ln_final.bias) 213 | 214 | # by default, use short template. Unless full_templates=True (used for producing projection matarix and clustering centriods) 215 | def encode_text(self, classnames, full_templates=False): 216 | # select template 217 | if(full_templates): 218 | curr_template = self.templates 219 | # collect all prompts 220 | zeroshot_weights = [] # expected size: [class_num * template_num] 221 | num_class = len(classnames) 222 | for classname in classnames: 223 | if(isinstance(classname, list)): 224 | # prepare token then embedding 225 | all_prompts = [template.format(classna) for template in self.templates for classna in classname] 226 | else: 227 | # prepare token then embedding 228 | all_prompts = [template.format(classname) for template in self.templates] 229 | 230 | all_tokens = clip.tokenize(all_prompts).to(device) # [template_num, 77] 231 | class_embeddings = self.base_model.encode_text(all_tokens) # [num_prompts, 768] 232 | # normalize, average, normalize again 233 | class_embeddings = F.normalize(class_embeddings, p=2, dim=-1) 234 | class_embeddings = class_embeddings.mean(dim=0) # [num_class, 768] 235 | class_embeddings = F.normalize(class_embeddings, p=2, dim=-1) 236 | zeroshot_weights.append(class_embeddings) 237 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device) 238 | return zeroshot_weights 239 | else: 240 | curr_template = self.short_templates 241 | curr_num_template = len(curr_template) 242 | # collect all prompts 243 | all_prompts = [] # expected size: [class_num * template_num] 244 | num_class = len(classnames) 245 | for classname in classnames: 246 | all_prompts.extend([template.format(classname) for template in curr_template]) 247 | all_tokens = clip.tokenize(all_prompts).to(device) # [class_num * template_num, 77] 248 | 249 | # class embeddings 250 | class_embeddings = self.base_model.encode_text(all_tokens) # [num_prompts, 768] 251 | class_embeddings = class_embeddings.view(num_class, curr_num_template, -1) 252 | 253 | # normalize, average, normalize again 254 | class_embeddings = F.normalize(class_embeddings, p=2, dim=-1) 255 | class_embeddings = class_embeddings.mean(dim=1) # [num_class, 768] 256 | class_embeddings = F.normalize(class_embeddings, p=2, dim=-1) 257 | class_embeddings = class_embeddings.transpose(0,1) 258 | return class_embeddings 259 | 260 | def encode_image(self, image): 261 | image_features = self.base_model.encode_image(image) 262 | image_features = F.normalize(image_features, p=2, dim=1) 263 | return image_features 264 | 265 | def forward(self, image, class_names): 266 | image_features = self.encode_image(image) 267 | class_embedding = self.encode_text(class_names) # use generated text embeddings for classification 268 | logits = 100. * image_features @ class_embedding 269 | return logits, image_features 270 | 271 | # label propagation module producing pseudo labels 272 | class LabelPropagationCluster(nn.Module): 273 | def __init__(self, classification_weight, dataset_size, k=10, alpha=0.99, cut_dim=768): 274 | super().__init__() 275 | 276 | # text generated classification weight 277 | self.classification_weight = classification_weight 278 | 279 | # parameters 280 | self.feat_dim = classification_weight.size(0) 281 | self.num_class = classification_weight.size(1) 282 | self.num_neighbor = k 283 | self.dataset_size = dataset_size 284 | self.image_per_class = self.dataset_size // self.num_class 285 | self.alpha = alpha 286 | self.cut_dim = cut_dim # for datasets with number of class > 500, we use cut_dim=250 287 | 288 | # container for pseudo labels, features, etc 289 | self.all_feat = [] 290 | self.idx_map = [] 291 | self.all_labels = {} 292 | self.pseudo_labels = {i:0 for i in range(self.dataset_size)} 293 | self.confidence = {i:0 for i in range(self.dataset_size)} 294 | 295 | # build projection 296 | self.update_projection(classification_weight) # update projection matrix with current classification weights 297 | self.update_centriods(classification_weight.t()) # update clustering centriods with current classification weights 298 | 299 | def forward(self, x, idx, label): 300 | # update features into memory 301 | idx = list(idx.cpu().numpy()) 302 | label = list(label.cpu().numpy()) 303 | self.all_feat.append(x.detach()) 304 | self.idx_map.extend(idx) 305 | bs = len(label) 306 | for i in range(bs): 307 | self.all_labels[idx[i]] = label[i] 308 | 309 | # use svd to compute projection matrix 310 | def update_projection(self, classification_weight=None): 311 | # classification_weight [768, class_num] 312 | if(classification_weight is not None): 313 | classification_weight = classification_weight 314 | else: 315 | classification_weight = self.centriods.t() 316 | 317 | U, S, V = torch.svd(classification_weight.to(torch.float32)) # U [768, class] 318 | self.projection_matrix = nn.Parameter((U[:,1:self.cut_dim] @ U[:,1:self.cut_dim].t()).to(torch.float16), requires_grad=False) # [768, 768] 319 | 320 | # update clustering centriods 321 | def update_centriods(self, centriods): 322 | self.centriods = centriods 323 | 324 | # return pseudo labels for examples of given indices 325 | def get_pseudo_label(self, idx): 326 | idx = list(idx.cpu().numpy()) 327 | pseudo_labels = [self.pseudo_labels[i] for i in idx] 328 | pseudo_confdence = [self.confidence[i] for i in idx] 329 | return pseudo_labels, pseudo_confdence 330 | 331 | # calculate closed form solution for label propagation 332 | def cg_diffusion(self, qsims, Wn, alpha = 0.99, maxiter = 20, tol = 1e-6): 333 | Wnn = eye(Wn.shape[0]) - alpha * Wn 334 | out_sims = [] 335 | for i in range(qsims.shape[0]): 336 | f,inf = s_linalg.cg(Wnn, qsims[i,:], tol=tol, maxiter=maxiter) 337 | out_sims.append(f.reshape(-1,1)) 338 | out_sims = np.concatenate(out_sims, axis = 1) 339 | ranks = np.argsort(-out_sims, axis = 0) 340 | return ranks, out_sims 341 | 342 | # main function for label propagation 343 | def perform_label_propagation(self, clear_cache=True, cluster_centriod=False): 344 | # stack all feat 345 | self.all_feat_stack = torch.cat(self.all_feat, dim=0) # [all_feat] 346 | 347 | # assertion 348 | num_record = self.all_feat_stack.size(0) 349 | assert(len(self.idx_map) == num_record) 350 | assert(len(self.all_labels) == num_record) 351 | 352 | # prepare features 353 | all_points = torch.cat([self.centriods, self.all_feat_stack], dim=0).detach() 354 | all_points_original = all_points.cpu().to(torch.float32) 355 | all_points_project = all_points @ self.projection_matrix 356 | all_points_project = (F.normalize(all_points_project, p=2, dim=-1)).cpu().to(torch.float32) 357 | num_example = all_points_project.size(0) 358 | 359 | # affinty matrix 360 | A = (all_points_project @ all_points_project.t() + 1) / 2 361 | # remove diagonal 362 | A = A * (1 - torch.eye(num_example)) 363 | # only keep topk nearest neighbors 364 | topk_val, topk_idx = torch.topk(A, self.num_neighbor, dim=1) 365 | topA = torch.zeros_like(A).scatter_(1, topk_idx, topk_val) 366 | # symmentric 367 | W = (topA + topA.t()) / 2 368 | # normalize 369 | D = torch.diag(torch.diag((W @ torch.ones(num_example, num_example)) ** (-0.5))) 370 | nW = D @ (W @ D) 371 | # need to use scipy to solve linear system pW = Y 372 | nw = nW.numpy() 373 | 374 | # produce labels 375 | Y = np.zeros((self.num_class, num_example)) 376 | for i in range(self.num_class): 377 | Y[i,i] = 1 378 | 379 | # perform cg optimization 380 | _, out_sims = self.cg_diffusion(Y, nw, self.alpha) 381 | 382 | # prediction 383 | prediction = np.argmax(out_sims, axis=1) # [sample.] 384 | 385 | # entropy 386 | out_sims_normalized = (out_sims.T / out_sims.sum(axis=1)).T # row normaliz ranks 387 | entropy = - out_sims_normalized * np.log(out_sims_normalized) 388 | ent_confidence = 1 - entropy.sum(axis=1) / np.log(self.num_class) 389 | 390 | # calculate clustering centriods to update ReCLIP-V classification weights 391 | if(cluster_centriod): 392 | new_centriods = self.centriods.clone() 393 | for class_id in range(self.num_class): 394 | current_class = (prediction == class_id) 395 | num_current_class = np.sum(current_class) 396 | current_class_conf = ent_confidence[current_class] # [num_current_class] 397 | current_class_feat = all_points_original[current_class] # [num_current_class, 768] 398 | sample_order = np.argsort(current_class_conf)[::-1] 399 | sample_order = sample_order[:int(num_current_class)].copy() 400 | current_centriods = current_class_feat[sample_order] 401 | current_centriods = torch.mean(current_centriods, dim=0) 402 | current_centriods = F.normalize(current_centriods.to(torch.float32), p=2, dim=0) 403 | new_centriods[class_id,:] = current_centriods 404 | 405 | prediction = prediction[self.num_class:] # first num_class entries correspond to text embeddings of each class 406 | ent_confidence = ent_confidence[self.num_class:] # first num_class entries correspond to text embeddings of each class 407 | 408 | # save predictions & confidence 409 | for i in range(len(prediction)): 410 | self.pseudo_labels[self.idx_map[i]] = prediction[i] 411 | self.confidence[self.idx_map[i]] = ent_confidence[i] 412 | 413 | # pseudo label result 414 | pseudo_label_acc = np.mean([self.pseudo_labels[i] == self.all_labels[i] for i in self.idx_map]) 415 | 416 | # clean up the bucket for next round 417 | if(clear_cache): 418 | self.all_feat = [] 419 | self.idx_map = [] 420 | 421 | if(cluster_centriod): 422 | return pseudo_label_acc, new_centriods 423 | else: 424 | return pseudo_label_acc -------------------------------------------------------------------------------- /prompts/clip_prompts: -------------------------------------------------------------------------------- 1 | {"birdsnap": {"classes": ["Acadian Flycatcher", "Acorn Woodpecker", "Alder Flycatcher", "Allens Hummingbird", "Altamira Oriole", "American Avocet", "American Bittern", "American Black Duck", "American Coot", "American Crow", "American Dipper", "American Golden Plover", "American Goldfinch", "American Kestrel", "American Oystercatcher", "American Pipit", "American Redstart", "American Robin", "American Three toed Woodpecker", "American Tree Sparrow", "American White Pelican", "American Wigeon", "American Woodcock", "Anhinga", "Annas Hummingbird", "Arctic Tern", "Ash throated Flycatcher", "Audubons Oriole", "Bairds Sandpiper", "Bald Eagle", "Baltimore Oriole", "Band tailed Pigeon", "Barn Swallow", "Barred Owl", "Barrows Goldeneye", "Bay breasted Warbler", "Bells Vireo", "Belted Kingfisher", "Bewicks Wren", "Black Guillemot", "Black Oystercatcher", "Black Phoebe", "Black Rosy Finch", "Black Scoter", "Black Skimmer", "Black Tern", "Black Turnstone", "Black Vulture", "Black and white Warbler", "Black backed Woodpecker", "Black bellied Plover", "Black billed Cuckoo", "Black billed Magpie", "Black capped Chickadee", "Black chinned Hummingbird", "Black chinned Sparrow", "Black crested Titmouse", "Black crowned Night Heron", "Black headed Grosbeak", "Black legged Kittiwake", "Black necked Stilt", "Black throated Blue Warbler", "Black throated Gray Warbler", "Black throated Green Warbler", "Black throated Sparrow", "Blackburnian Warbler", "Blackpoll Warbler", "Blue Grosbeak", "Blue Jay", "Blue gray Gnatcatcher", "Blue headed Vireo", "Blue winged Teal", "Blue winged Warbler", "Boat tailed Grackle", "Bobolink", "Bohemian Waxwing", "Bonapartes Gull", "Boreal Chickadee", "Brandts Cormorant", "Brant", "Brewers Blackbird", "Brewers Sparrow", "Bridled Titmouse", "Broad billed Hummingbird", "Broad tailed Hummingbird", "Broad winged Hawk", "Bronzed Cowbird", "Brown Creeper", "Brown Pelican", "Brown Thrasher", "Brown capped Rosy Finch", "Brown crested Flycatcher", "Brown headed Cowbird", "Brown headed Nuthatch", "Bufflehead", "Bullocks Oriole", "Burrowing Owl", "Bushtit", "Cackling Goose", "Cactus Wren", "California Gull", "California Quail", "California Thrasher", "California Towhee", "Calliope Hummingbird", "Canada Goose", "Canada Warbler", "Canvasback", "Canyon Towhee", "Canyon Wren", "Cape May Warbler", "Carolina Chickadee", "Carolina Wren", "Caspian Tern", "Cassins Finch", "Cassins Kingbird", "Cassins Sparrow", "Cassins Vireo", "Cattle Egret", "Cave Swallow", "Cedar Waxwing", "Cerulean Warbler", "Chestnut backed Chickadee", "Chestnut collared Longspur", "Chestnut sided Warbler", "Chihuahuan Raven", "Chimney Swift", "Chipping Sparrow", "Cinnamon Teal", "Clapper Rail", "Clarks Grebe", "Clarks Nutcracker", "Clay colored Sparrow", "Cliff Swallow", "Common Black Hawk", "Common Eider", "Common Gallinule", "Common Goldeneye", "Common Grackle", "Common Ground Dove", "Common Loon", "Common Merganser", "Common Murre", "Common Nighthawk", "Common Raven", "Common Redpoll", "Common Tern", "Common Yellowthroat", "Connecticut Warbler", "Coopers Hawk", "Cordilleran Flycatcher", "Costas Hummingbird", "Couchs Kingbird", "Crested Caracara", "Curve billed Thrasher", "Dark eyed Junco", "Dickcissel", "Double crested Cormorant", "Downy Woodpecker", "Dunlin", "Dusky Flycatcher", "Dusky Grouse", "Eared Grebe", "Eastern Bluebird", "Eastern Kingbird", "Eastern Meadowlark", "Eastern Phoebe", "Eastern Screech Owl", "Eastern Towhee", "Eastern Wood Pewee", "Elegant Trogon", "Elf Owl", "Eurasian Collared Dove", "Eurasian Wigeon", "European Starling", "Evening Grosbeak", "Ferruginous Hawk", "Ferruginous Pygmy Owl", "Field Sparrow", "Fish Crow", "Florida Scrub Jay", "Forsters Tern", "Fox Sparrow", "Franklins Gull", "Fulvous Whistling Duck", "Gadwall", "Gambels Quail", "Gila Woodpecker", "Glaucous Gull", "Glaucous winged Gull", "Glossy Ibis", "Golden Eagle", "Golden crowned Kinglet", "Golden crowned Sparrow", "Golden fronted Woodpecker", "Golden winged Warbler", "Grasshopper Sparrow", "Gray Catbird", "Gray Flycatcher", "Gray Jay", "Gray Kingbird", "Gray cheeked Thrush", "Gray crowned Rosy Finch", "Great Black backed Gull", "Great Blue Heron", "Great Cormorant", "Great Crested Flycatcher", "Great Egret", "Great Gray Owl", "Great Horned Owl", "Great Kiskadee", "Great tailed Grackle", "Greater Prairie Chicken", "Greater Roadrunner", "Greater Sage Grouse", "Greater Scaup", "Greater White fronted Goose", "Greater Yellowlegs", "Green Jay", "Green tailed Towhee", "Green winged Teal", "Groove billed Ani", "Gull billed Tern", "Hairy Woodpecker", "Hammonds Flycatcher", "Harlequin Duck", "Harriss Hawk", "Harriss Sparrow", "Heermanns Gull", "Henslows Sparrow", "Hepatic Tanager", "Hermit Thrush", "Herring Gull", "Hoary Redpoll", "Hooded Merganser", "Hooded Oriole", "Hooded Warbler", "Horned Grebe", "Horned Lark", "House Finch", "House Sparrow", "House Wren", "Huttons Vireo", "Iceland Gull", "Inca Dove", "Indigo Bunting", "Killdeer", "King Rail", "Ladder backed Woodpecker", "Lapland Longspur", "Lark Bunting", "Lark Sparrow", "Laughing Gull", "Lazuli Bunting", "Le Contes Sparrow", "Least Bittern", "Least Flycatcher", "Least Grebe", "Least Sandpiper", "Least Tern", "Lesser Goldfinch", "Lesser Nighthawk", "Lesser Scaup", "Lesser Yellowlegs", "Lewiss Woodpecker", "Limpkin", "Lincolns Sparrow", "Little Blue Heron", "Loggerhead Shrike", "Long billed Curlew", "Long billed Dowitcher", "Long billed Thrasher", "Long eared Owl", "Long tailed Duck", "Louisiana Waterthrush", "Magnificent Frigatebird", "Magnolia Warbler", "Mallard", "Marbled Godwit", "Marsh Wren", "Merlin", "Mew Gull", "Mexican Jay", "Mississippi Kite", "Monk Parakeet", "Mottled Duck", "Mountain Bluebird", "Mountain Chickadee", "Mountain Plover", "Mourning Dove", "Mourning Warbler", "Muscovy Duck", "Mute Swan", "Nashville Warbler", "Nelsons Sparrow", "Neotropic Cormorant", "Northern Bobwhite", "Northern Cardinal", "Northern Flicker", "Northern Gannet", "Northern Goshawk", "Northern Harrier", "Northern Hawk Owl", "Northern Mockingbird", "Northern Parula", "Northern Pintail", "Northern Rough winged Swallow", "Northern Saw whet Owl", "Northern Shrike", "Northern Waterthrush", "Nuttalls Woodpecker", "Oak Titmouse", "Olive Sparrow", "Olive sided Flycatcher", "Orange crowned Warbler", "Orchard Oriole", "Osprey", "Ovenbird", "Pacific Golden Plover", "Pacific Loon", "Pacific Wren", "Pacific slope Flycatcher", "Painted Bunting", "Painted Redstart", "Palm Warbler", "Pectoral Sandpiper", "Peregrine Falcon", "Phainopepla", "Philadelphia Vireo", "Pied billed Grebe", "Pigeon Guillemot", "Pileated Woodpecker", "Pine Grosbeak", "Pine Siskin", "Pine Warbler", "Piping Plover", "Plumbeous Vireo", "Prairie Falcon", "Prairie Warbler", "Prothonotary Warbler", "Purple Finch", "Purple Gallinule", "Purple Martin", "Purple Sandpiper", "Pygmy Nuthatch", "Pyrrhuloxia", "Red Crossbill", "Red Knot", "Red Phalarope", "Red bellied Woodpecker", "Red breasted Merganser", "Red breasted Nuthatch", "Red breasted Sapsucker", "Red cockaded Woodpecker", "Red eyed Vireo", "Red headed Woodpecker", "Red naped Sapsucker", "Red necked Grebe", "Red necked Phalarope", "Red shouldered Hawk", "Red tailed Hawk", "Red throated Loon", "Red winged Blackbird", "Reddish Egret", "Redhead", "Ring billed Gull", "Ring necked Duck", "Ring necked Pheasant", "Rock Pigeon", "Rock Ptarmigan", "Rock Sandpiper", "Rock Wren", "Rose breasted Grosbeak", "Roseate Tern", "Rosss Goose", "Rough legged Hawk", "Royal Tern", "Ruby crowned Kinglet", "Ruby throated Hummingbird", "Ruddy Duck", "Ruddy Turnstone", "Ruffed Grouse", "Rufous Hummingbird", "Rufous crowned Sparrow", "Rusty Blackbird", "Sage Thrasher", "Saltmarsh Sparrow", "Sanderling", "Sandhill Crane", "Sandwich Tern", "Says Phoebe", "Scaled Quail", "Scarlet Tanager", "Scissor tailed Flycatcher", "Scotts Oriole", "Seaside Sparrow", "Sedge Wren", "Semipalmated Plover", "Semipalmated Sandpiper", "Sharp shinned Hawk", "Sharp tailed Grouse", "Short billed Dowitcher", "Short eared Owl", "Snail Kite", "Snow Bunting", "Snow Goose", "Snowy Egret", "Snowy Owl", "Snowy Plover", "Solitary Sandpiper", "Song Sparrow", "Sooty Grouse", "Sora", "Spotted Owl", "Spotted Sandpiper", "Spotted Towhee", "Spruce Grouse", "Stellers Jay", "Stilt Sandpiper", "Summer Tanager", "Surf Scoter", "Surfbird", "Swainsons Hawk", "Swainsons Thrush", "Swallow tailed Kite", "Swamp Sparrow", "Tennessee Warbler", "Thayers Gull", "Townsends Solitaire", "Townsends Warbler", "Tree Swallow", "Tricolored Heron", "Tropical Kingbird", "Trumpeter Swan", "Tufted Titmouse", "Tundra Swan", "Turkey Vulture", "Upland Sandpiper", "Varied Thrush", "Veery", "Verdin", "Vermilion Flycatcher", "Vesper Sparrow", "Violet green Swallow", "Virginia Rail", "Wandering Tattler", "Warbling Vireo", "Western Bluebird", "Western Grebe", "Western Gull", "Western Kingbird", "Western Meadowlark", "Western Sandpiper", "Western Screech Owl", "Western Scrub Jay", "Western Tanager", "Western Wood Pewee", "Whimbrel", "White Ibis", "White breasted Nuthatch", "White crowned Sparrow", "White eyed Vireo", "White faced Ibis", "White headed Woodpecker", "White rumped Sandpiper", "White tailed Hawk", "White tailed Kite", "White tailed Ptarmigan", "White throated Sparrow", "White throated Swift", "White winged Crossbill", "White winged Dove", "White winged Scoter", "Wild Turkey", "Willet", "Williamsons Sapsucker", "Willow Flycatcher", "Willow Ptarmigan", "Wilsons Phalarope", "Wilsons Plover", "Wilsons Snipe", "Wilsons Warbler", "Winter Wren", "Wood Stork", "Wood Thrush", "Worm eating Warbler", "Wrentit", "Yellow Warbler", "Yellow bellied Flycatcher", "Yellow bellied Sapsucker", "Yellow billed Cuckoo", "Yellow billed Magpie", "Yellow breasted Chat", "Yellow crowned Night Heron", "Yellow eyed Junco", "Yellow headed Blackbird", "Yellow rumped Warbler", "Yellow throated Vireo", "Yellow throated Warbler", "Zone tailed Hawk"], "templates": ["a photo of a {}, a type of bird."]}, "cifar10": {"classes": ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"], "templates": ["a photo of a {}.", "a blurry photo of a {}.", "a black and white photo of a {}.", "a low contrast photo of a {}.", "a high contrast photo of a {}.", "a bad photo of a {}.", "a good photo of a {}.", "a photo of a small {}.", "a photo of a big {}.", "a photo of the {}.", "a blurry photo of the {}.", "a black and white photo of the {}.", "a low contrast photo of the {}.", "a high contrast photo of the {}.", "a bad photo of the {}.", "a good photo of the {}.", "a photo of the small {}.", "a photo of the big {}."]}, "cifar100": {"classes": ["apple", "aquarium fish", "baby", "bear", "beaver", "bed", "bee", "beetle", "bicycle", "bottle", "bowl", "boy", "bridge", "bus", "butterfly", "camel", "can", "castle", "caterpillar", "cattle", "chair", "chimpanzee", "clock", "cloud", "cockroach", "couch", "crab", "crocodile", "cup", "dinosaur", "dolphin", "elephant", "flatfish", "forest", "fox", "girl", "hamster", "house", "kangaroo", "keyboard", "lamp", "lawn mower", "leopard", "lion", "lizard", "lobster", "man", "maple tree", "motorcycle", "mountain", "mouse", "mushroom", "oak tree", "orange", "orchid", "otter", "palm tree", "pear", "pickup truck", "pine tree", "plain", "plate", "poppy", "porcupine", "possum", "rabbit", "raccoon", "ray", "road", "rocket", "rose", "sea", "seal", "shark", "shrew", "skunk", "skyscraper", "snail", "snake", "spider", "squirrel", "streetcar", "sunflower", "sweet pepper", "table", "tank", "telephone", "television", "tiger", "tractor", "train", "trout", "tulip", "turtle", "wardrobe", "whale", "willow tree", "wolf", "woman", "worm"], "templates": ["a photo of a {}.", "a blurry photo of a {}.", "a black and white photo of a {}.", "a low contrast photo of a {}.", "a high contrast photo of a {}.", "a bad photo of a {}.", "a good photo of a {}.", "a photo of a small {}.", "a photo of a big {}.", "a photo of the {}.", "a blurry photo of the {}.", "a black and white photo of the {}.", "a low contrast photo of the {}.", "a high contrast photo of the {}.", "a bad photo of the {}.", "a good photo of the {}.", "a photo of the small {}.", "a photo of the big {}."]}, "clever": {"classes": ["10", "3", "4", "5", "6", "7", "8", "9"], "templates": ["a photo of {} objects."]}, "caltech101": {"classes": ["human face", "ID photo", "leopard", "motorbike", "accordion", "airplane", "anchor", "ant", "barrel", "bass", "beaver", "binocular", "bonsai", "brain", "brontosaurus", "buddha", "butterfly", "camera", "cannon", "side of a car", "ceiling fan", "cellphone", "chair", "chandelier", "body of a cougar cat", "face of a cougar cat", "crab", "crayfish", "crocodile", "head of a crocodile", "cup", "dalmatian", "dollar bill", "dolphin", "dragonfly", "electric guitar", "elephant", "emu", "euphonium", "ewer", "ferry", "flamingo", "head of a flamingo", "garfield", "gerenuk", "gramophone", "grand piano", "hawksbill", "headphone", "hedgehog", "helicopter", "ibis", "inline skate", "joshua tree", "kangaroo", "ketch", "lamp", "laptop", "llama", "lobster", "lotus", "mandolin", "mayfly", "menorah", "metronome", "minaret", "nautilus", "octopus", "okapi", "pagoda", "panda", "pigeon", "pizza", "platypus", "pyramid", "revolver", "rhino", "rooster", "saxophone", "schooner", "scissors", "scorpion", "sea horse", "snoopy (cartoon beagle)", "soccer ball", "stapler", "starfish", "stegosaurus", "stop sign", "strawberry", "sunflower", "tick", "trilobite", "umbrella", "watch", "water lilly", "wheelchair", "wild cat", "windsor chair", "wrench", "yin and yang symbol"], "templates": ["a photo of a {}.", "a painting of a {}.", "a plastic {}.", "a sculpture of a {}.", "a sketch of a {}.", "a tattoo of a {}.", "a toy {}.", "a rendition of a {}.", "a embroidered {}.", "a cartoon {}.", "a {} in a video game.", "a plushie {}.", "a origami {}.", "art of a {}.", "graffiti of a {}.", "a drawing of a {}.", "a doodle of a {}.", "a photo of the {}.", "a painting of the {}.", "the plastic {}.", "a sculpture of the {}.", "a sketch of the {}.", "a tattoo of the {}.", "the toy {}.", "a rendition of the {}.", "the embroidered {}.", "the cartoon {}.", "the {} in a video game.", "the plushie {}.", "the origami {}.", "art of the {}.", "graffiti of the {}.", "a drawing of the {}.", "a doodle of the {}."]}, "country211": {"classes": ["Andorra", "United Arab Emirates", "Afghanistan", "Antigua and Barbuda", "Anguilla", "Albania", "Armenia", "Angola", "Antarctica", "Argentina", "Austria", "Australia", "Aruba", "Aland Islands", "Azerbaijan", "Bosnia and Herzegovina", "Barbados", "Bangladesh", "Belgium", "Burkina Faso", "Bulgaria", "Bahrain", "Benin", "Bermuda", "Brunei Darussalam", "Bolivia", "Bonaire, Saint Eustatius and Saba", "Brazil", "Bahamas", "Bhutan", "Botswana", "Belarus", "Belize", "Canada", "DR Congo", "Central African Republic", "Switzerland", "Cote d'Ivoire", "Cook Islands", "Chile", "Cameroon", "China", "Colombia", "Costa Rica", "Cuba", "Cabo Verde", "Curacao", "Cyprus", "Czech Republic", "Germany", "Denmark", "Dominica", "Dominican Republic", "Algeria", "Ecuador", "Estonia", "Egypt", "Spain", "Ethiopia", "Finland", "Fiji", "Falkland Islands", "Faeroe Islands", "France", "Gabon", "United Kingdom", "Grenada", "Georgia", "French Guiana", "Guernsey", "Ghana", "Gibraltar", "Greenland", "Gambia", "Guadeloupe", "Greece", "South Georgia and South Sandwich Is.", "Guatemala", "Guam", "Guyana", "Hong Kong", "Honduras", "Croatia", "Haiti", "Hungary", "Indonesia", "Ireland", "Israel", "Isle of Man", "India", "Iraq", "Iran", "Iceland", "Italy", "Jersey", "Jamaica", "Jordan", "Japan", "Kenya", "Kyrgyz Republic", "Cambodia", "St. Kitts and Nevis", "North Korea", "South Korea", "Kuwait", "Cayman Islands", "Kazakhstan", "Laos", "Lebanon", "St. Lucia", "Liechtenstein", "Sri Lanka", "Liberia", "Lithuania", "Luxembourg", "Latvia", "Libya", "Morocco", "Monaco", "Moldova", "Montenegro", "Saint-Martin", "Madagascar", "Macedonia", "Mali", "Myanmar", "Mongolia", "Macau", "Martinique", "Mauritania", "Malta", "Mauritius", "Maldives", "Malawi", "Mexico", "Malaysia", "Mozambique", "Namibia", "New Caledonia", "Nigeria", "Nicaragua", "Netherlands", "Norway", "Nepal", "New Zealand", "Oman", "Panama", "Peru", "French Polynesia", "Papua New Guinea", "Philippines", "Pakistan", "Poland", "Puerto Rico", "Palestine", "Portugal", "Palau", "Paraguay", "Qatar", "Reunion", "Romania", "Serbia", "Russia", "Rwanda", "Saudi Arabia", "Solomon Islands", "Seychelles", "Sudan", "Sweden", "Singapore", "St. Helena", "Slovenia", "Svalbard and Jan Mayen Islands", "Slovakia", "Sierra Leone", "San Marino", "Senegal", "Somalia", "South Sudan", "El Salvador", "Sint Maarten", "Syria", "Eswatini", "Togo", "Thailand", "Tajikistan", "Timor-Leste", "Turkmenistan", "Tunisia", "Tonga", "Turkey", "Trinidad and Tobago", "Taiwan", "Tanzania", "Ukraine", "Uganda", "United States", "Uruguay", "Uzbekistan", "Vatican", "Venezuela", "British Virgin Islands", "United States Virgin Islands", "Vietnam", "Vanuatu", "Samoa", "Kosovo", "Yemen", "South Africa", "Zambia", "Zimbabwe"], "templates": ["a photo i took in {}.", "a photo i took while visiting {}.", "a photo from my home country of {}.", "a photo from my visit to {}.", "a photo showing the country of {}."]}, "dtd": {"classes": ["banded", "blotchy", "braided", "bubbly", "bumpy", "chequered", "cobwebbed", "cracked", "crosshatched", "crystalline", "dotted", "fibrous", "flecked", "freckled", "frilly", "gauzy", "grid", "grooved", "honeycombed", "interlaced", "knitted", "lacelike", "lined", "marbled", "matted", "meshed", "paisley", "perforated", "pitted", "pleated", "polka-dotted", "porous", "potholed", "scaly", "smeared", "spiralled", "sprinkled", "stained", "stratified", "striped", "studded", "swirly", "veined", "waffled", "woven", "wrinkled", "zigzagged"], "templates": ["a photo of a {} texture.", "a photo of a {} pattern.", "a photo of a {} thing.", "a photo of a {} object.", "a photo of the {} texture.", "a photo of the {} pattern.", "a photo of the {} thing.", "a photo of the {} object."]}, "eurosat": {"classes": ["annual crop land", "forest", "brushland or shrubland", "highway or road", "industrial buildings or commercial buildings", "pasture land", "permanent crop land", "residential buildings or homes or apartments", "river", "lake or sea"], "templates": ["a centered satellite photo of {}.", "a centered satellite photo of a {}.", "a centered satellite photo of the {}."]}, "fgvc": {"classes": ["707-320", "727-200", "737-200", "737-300", "737-400", "737-500", "737-600", "737-700", "737-800", "737-900", "747-100", "747-200", "747-300", "747-400", "757-200", "757-300", "767-200", "767-300", "767-400", "777-200", "777-300", "A300B4", "A310", "A318", "A319", "A320", "A321", "A330-200", "A330-300", "A340-200", "A340-300", "A340-500", "A340-600", "A380", "ATR-42", "ATR-72", "An-12", "BAE 146-200", "BAE 146-300", "BAE-125", "Beechcraft 1900", "Boeing 717", "C-130", "C-47", "CRJ-200", "CRJ-700", "CRJ-900", "Cessna 172", "Cessna 208", "Cessna 525", "Cessna 560", "Challenger 600", "DC-10", "DC-3", "DC-6", "DC-8", "DC-9-30", "DH-82", "DHC-1", "DHC-6", "DHC-8-100", "DHC-8-300", "DR-400", "Dornier 328", "E-170", "E-190", "E-195", "EMB-120", "ERJ 135", "ERJ 145", "Embraer Legacy 600", "Eurofighter Typhoon", "F-16A/B", "F/A-18", "Falcon 2000", "Falcon 900", "Fokker 100", "Fokker 50", "Fokker 70", "Global Express", "Gulfstream IV", "Gulfstream V", "Hawk T1", "Il-76", "L-1011", "MD-11", "MD-80", "MD-87", "MD-90", "Metroliner", "Model B200", "PA-28", "SR-20", "Saab 2000", "Saab 340", "Spitfire", "Tornado", "Tu-134", "Tu-154", "Yak-42"], "templates": ["a photo of a {}, a type of aircraft.", "a photo of the {}, a type of aircraft."]}, "fer2013": {"classes": [["angry"], ["disgusted"], ["fearful"], ["happy", "smiling"], ["sad", "depressed"], ["surprised", "shocked", "spooked"], ["neutral", "bored"]], "templates": ["a photo of a {} looking face.", "a photo of a face showing the emotion: {}.", "a photo of a face looking {}.", "a face that looks {}.", "they look {}.", "look at how {} they are."]}, "flowers102": {"classes": ["pink primrose", "hard-leaved pocket orchid", "canterbury bells", "sweet pea", "english marigold", "tiger lily", "moon orchid", "bird of paradise", "monkshood", "globe thistle", "snapdragon", "colt's foot", "king protea", "spear thistle", "yellow iris", "globe flower", "purple coneflower", "peruvian lily", "balloon flower", "giant white arum lily", "fire lily", "pincushion flower", "fritillary", "red ginger", "grape hyacinth", "corn poppy", "prince of wales feathers", "stemless gentian", "artichoke", "sweet william", "carnation", "garden phlox", "love in the mist", "mexican aster", "alpine sea holly", "ruby-lipped cattleya", "cape flower", "great masterwort", "siam tulip", "lenten rose", "barbeton daisy", "daffodil", "sword lily", "poinsettia", "bolero deep blue", "wallflower", "marigold", "buttercup", "oxeye daisy", "common dandelion", "petunia", "wild pansy", "primula", "sunflower", "pelargonium", "bishop of llandaff", "gaura", "geranium", "orange dahlia", "pink and yellow dahlia", "cautleya spicata", "japanese anemone", "black-eyed susan", "silverbush", "californian poppy", "osteospermum", "spring crocus", "bearded iris", "windflower", "tree poppy", "gazania", "azalea", "water lily", "rose", "thorn apple", "morning glory", "passion flower", "lotus", "toad lily", "anthurium", "frangipani", "clematis", "hibiscus", "columbine", "desert-rose", "tree mallow", "magnolia", "cyclamen", "watercress", "canna lily", "hippeastrum", "bee balm", "air plant", "foxglove", "bougainvillea", "camellia", "mallow", "mexican petunia", "bromelia", "blanket flower", "trumpet creeper", "blackberry lily"], "templates": ["a photo of a {}, a type of flower."]}, "food101": {"classes": ["apple pie", "baby back ribs", "baklava", "beef carpaccio", "beef tartare", "beet salad", "beignets", "bibimbap", "bread pudding", "breakfast burrito", "bruschetta", "caesar salad", "cannoli", "caprese salad", "carrot cake", "ceviche", "cheese plate", "cheesecake", "chicken curry", "chicken quesadilla", "chicken wings", "chocolate cake", "chocolate mousse", "churros", "clam chowder", "club sandwich", "crab cakes", "creme brulee", "croque madame", "cup cakes", "deviled eggs", "donuts", "dumplings", "edamame", "eggs benedict", "escargots", "falafel", "filet mignon", "fish and chips", "foie gras", "french fries", "french onion soup", "french toast", "fried calamari", "fried rice", "frozen yogurt", "garlic bread", "gnocchi", "greek salad", "grilled cheese sandwich", "grilled salmon", "guacamole", "gyoza", "hamburger", "hot and sour soup", "hot dog", "huevos rancheros", "hummus", "ice cream", "lasagna", "lobster bisque", "lobster roll sandwich", "macaroni and cheese", "macarons", "miso soup", "mussels", "nachos", "omelette", "onion rings", "oysters", "pad thai", "paella", "pancakes", "panna cotta", "peking duck", "pho", "pizza", "pork chop", "poutine", "prime rib", "pulled pork sandwich", "ramen", "ravioli", "red velvet cake", "risotto", "samosa", "sashimi", "scallops", "seaweed salad", "shrimp and grits", "spaghetti bolognese", "spaghetti carbonara", "spring rolls", "steak", "strawberry shortcake", "sushi", "tacos", "takoyaki", "tiramisu", "tuna tartare", "waffles"], "templates": ["a photo of {}, a type of food."]}, "gtsrb": {"classes": ["red and white circle 20 kph speed limit", "red and white circle 30 kph speed limit", "red and white circle 50 kph speed limit", "red and white circle 60 kph speed limit", "red and white circle 70 kph speed limit", "red and white circle 80 kph speed limit", "end / de-restriction of 80 kph speed limit", "red and white circle 100 kph speed limit", "red and white circle 120 kph speed limit", "red and white circle red car and black car no passing", "red and white circle red truck and black car no passing", "red and white triangle road intersection warning", "white and yellow diamond priority road", "red and white upside down triangle yield right-of-way", "stop", "empty red and white circle", "red and white circle no truck entry", "red circle with white horizonal stripe no entry", "red and white triangle with exclamation mark warning", "red and white triangle with black left curve approaching warning", "red and white triangle with black right curve approaching warning", "red and white triangle with black double curve approaching warning", "red and white triangle rough / bumpy road warning", "red and white triangle car skidding / slipping warning", "red and white triangle with merging / narrow lanes warning", "red and white triangle with person digging / construction / road work warning", "red and white triangle with traffic light approaching warning", "red and white triangle with person walking warning", "red and white triangle with child and person walking warning", "red and white triangle with bicyle warning", "red and white triangle with snowflake / ice warning", "red and white triangle with deer warning", "white circle with gray strike bar no speed limit", "blue circle with white right turn arrow mandatory", "blue circle with white left turn arrow mandatory", "blue circle with white forward arrow mandatory", "blue circle with white forward or right turn arrow mandatory", "blue circle with white forward or left turn arrow mandatory", "blue circle with white keep right arrow mandatory", "blue circle with white keep left arrow mandatory", "blue circle with white arrows indicating a traffic circle", "white circle with gray strike bar indicating no passing for cars has ended", "white circle with gray strike bar indicating no passing for trucks has ended"], "templates": ["a zoomed in photo of a \"{}\" traffic sign.", "a centered photo of a \"{}\" traffic sign.", "a close up photo of a \"{}\" traffic sign."]}, "hatefulmemes": {"classes": ["meme", "hatespeech meme"], "templates": ["a {}."]}, "kitti": {"classes": ["a photo i took of a car on my left or right side.", "a photo i took with a car nearby.", "a photo i took with a car in the distance.", "a photo i took with no car."], "templates": ["{}"]}, "k700": {"classes": ["abseiling", "acting in play", "adjusting glasses", "air drumming", "alligator wrestling", "answering questions", "applauding", "applying cream", "archaeological excavation", "archery", "arguing", "arm wrestling", "arranging flowers", "arresting", "assembling bicycle", "assembling computer", "attending conference", "auctioning", "baby waking up", "backflip (human)", "baking cookies", "bandaging", "barbequing", "bartending", "base jumping", "bathing dog", "battle rope training", "beatboxing", "bee keeping", "being excited", "being in zero gravity", "belly dancing", "bench pressing", "bending back", "bending metal", "biking through snow", "blasting sand", "blending fruit", "blowdrying hair", "blowing bubble gum", "blowing glass", "blowing leaves", "blowing nose", "blowing out candles", "bobsledding", "bodysurfing", "bookbinding", "bottling", "bouncing ball (not juggling)", "bouncing on bouncy castle", "bouncing on trampoline", "bowling", "braiding hair", "breading or breadcrumbing", "breakdancing", "breaking boards", "breaking glass", "breathing fire", "brush painting", "brushing floor", "brushing hair", "brushing teeth", "building cabinet", "building lego", "building sandcastle", "building shed", "bulldozing", "bungee jumping", "burping", "busking", "calculating", "calligraphy", "canoeing or kayaking", "capoeira", "capsizing", "card stacking", "card throwing", "carrying baby", "carrying weight", "cartwheeling", "carving ice", "carving marble", "carving pumpkin", "carving wood with a knife", "casting fishing line", "catching fish", "catching or throwing baseball", "catching or throwing frisbee", "catching or throwing softball", "celebrating", "changing gear in car", "changing oil", "changing wheel (not on bike)", "chasing", "checking tires", "checking watch", "cheerleading", "chewing gum", "chiseling stone", "chiseling wood", "chopping meat", "chopping wood", "clam digging", "clapping", "clay pottery making", "clean and jerk", "cleaning gutters", "cleaning pool", "cleaning shoes", "cleaning toilet", "cleaning windows", "climbing a rope", "climbing ladder", "climbing tree", "closing door", "coloring in", "combing hair", "contact juggling", "contorting", "cooking chicken", "cooking egg", "cooking on campfire", "cooking sausages (not on barbeque)", "cooking scallops", "cosplaying", "coughing", "counting money", "country line dancing", "cracking back", "cracking knuckles", "cracking neck", "crawling baby", "crocheting", "crossing eyes", "crossing river", "crying", "cumbia", "curling (sport)", "curling eyelashes", "curling hair", "cutting apple", "cutting cake", "cutting nails", "cutting orange", "cutting pineapple", "cutting watermelon", "dancing ballet", "dancing charleston", "dancing gangnam style", "dancing macarena", "deadlifting", "dealing cards", "decorating the christmas tree", "decoupage", "delivering mail", "digging", "dining", "directing traffic", "disc golfing", "diving cliff", "docking boat", "dodgeball", "doing aerobics", "doing jigsaw puzzle", "doing laundry", "doing nails", "doing sudoku", "drawing", "dribbling basketball", "drinking shots", "driving car", "driving tractor", "drooling", "drop kicking", "drumming fingers", "dumpster diving", "dunking basketball", "dyeing eyebrows", "dyeing hair", "eating burger", "eating cake", "eating carrots", "eating chips", "eating doughnuts", "eating hotdog", "eating ice cream", "eating nachos", "eating spaghetti", "eating watermelon", "egg hunting", "embroidering", "entering church", "exercising arm", "exercising with an exercise ball", "extinguishing fire", "faceplanting", "falling off bike", "falling off chair", "feeding birds", "feeding fish", "feeding goats", "fencing (sport)", "fidgeting", "filling cake", "filling eyebrows", "finger snapping", "fixing bicycle", "fixing hair", "flint knapping", "flipping bottle", "flipping pancake", "fly tying", "flying kite", "folding clothes", "folding napkins", "folding paper", "front raises", "frying vegetables", "gargling", "geocaching", "getting a haircut", "getting a piercing", "getting a tattoo", "giving or receiving award", "gold panning", "golf chipping", "golf driving", "golf putting", "gospel singing in church", "grinding meat", "grooming cat", "grooming dog", "grooming horse", "gymnastics tumbling", "hammer throw", "hand washing clothes", "head stand", "headbanging", "headbutting", "helmet diving", "herding cattle", "high fiving", "high jump", "high kick", "historical reenactment", "hitting baseball", "hockey stop", "holding snake", "home roasting coffee", "hopscotch", "hoverboarding", "huddling", "hugging (not baby)", "hugging baby", "hula hooping", "hurdling", "hurling (sport)", "ice climbing", "ice fishing", "ice skating", "ice swimming", "inflating balloons", "installing carpet", "ironing", "ironing hair", "javelin throw", "jaywalking", "jetskiing", "jogging", "juggling balls", "juggling fire", "juggling soccer ball", "jumping bicycle", "jumping into pool", "jumping jacks", "jumping sofa", "jumpstyle dancing", "karaoke", "kicking field goal", "kicking soccer ball", "kissing", "kitesurfing", "knitting", "krumping", "land sailing", "laughing", "lawn mower racing", "laying bricks", "laying concrete", "laying decking", "laying stone", "laying tiles", "leatherworking", "letting go of balloon", "licking", "lifting hat", "lighting candle", "lighting fire", "listening with headphones", "lock picking", "long jump", "longboarding", "looking at phone", "looking in mirror", "luge", "lunge", "making a cake", "making a sandwich", "making balloon shapes", "making bubbles", "making cheese", "making horseshoes", "making jewelry", "making latte art", "making paper aeroplanes", "making pizza", "making slime", "making snowman", "making sushi", "making tea", "making the bed", "marching", "marriage proposal", "massaging back", "massaging feet", "massaging legs", "massaging neck", "massaging person's head", "metal detecting", "milking cow", "milking goat", "mixing colours", "moon walking", "mopping floor", "mosh pit dancing", "motorcycling", "mountain climber (exercise)", "moving baby", "moving child", "moving furniture", "mowing lawn", "mushroom foraging", "needle felting", "news anchoring", "opening bottle (not wine)", "opening coconuts", "opening door", "opening present", "opening refrigerator", "opening wine bottle", "packing", "paragliding", "parasailing", "parkour", "passing American football (in game)", "passing American football (not in game)", "passing soccer ball", "peeling apples", "peeling banana", "peeling potatoes", "person collecting garbage", "petting animal (not cat)", "petting cat", "petting horse", "photobombing", "photocopying", "picking apples", "picking blueberries", "pillow fight", "pinching", "pirouetting", "planing wood", "planting trees", "plastering", "playing accordion", "playing american football", "playing badminton", "playing bagpipes", "playing basketball", "playing bass guitar", "playing beer pong", "playing billiards", "playing blackjack", "playing cards", "playing cello", "playing checkers", "playing chess", "playing clarinet", "playing controller", "playing cricket", "playing cymbals", "playing darts", "playing didgeridoo", "playing dominoes", "playing drums", "playing field hockey", "playing flute", "playing gong", "playing guitar", "playing hand clapping games", "playing harmonica", "playing harp", "playing ice hockey", "playing keyboard", "playing kickball", "playing laser tag", "playing lute", "playing mahjong", "playing maracas", "playing marbles", "playing monopoly", "playing netball", "playing nose flute", "playing oboe", "playing ocarina", "playing organ", "playing paintball", "playing pan pipes", "playing piano", "playing piccolo", "playing pinball", "playing ping pong", "playing poker", "playing polo", "playing recorder", "playing road hockey", "playing rounders", "playing rubiks cube", "playing saxophone", "playing scrabble", "playing shuffleboard", "playing slot machine", "playing squash or racquetball", "playing tennis", "playing trombone", "playing trumpet", "playing ukulele", "playing violin", "playing volleyball", "playing with trains", "playing xylophone", "poaching eggs", "poking bellybutton", "pole vault", "polishing furniture", "polishing metal", "popping balloons", "pouring beer", "pouring milk", "pouring wine", "preparing salad", "presenting weather forecast", "pretending to be a statue", "pull ups", "pulling espresso shot", "pulling rope (game)", "pumping fist", "pumping gas", "punching bag", "punching person (boxing)", "push up", "pushing car", "pushing cart", "pushing wheelbarrow", "pushing wheelchair", "putting in contact lenses", "putting on eyeliner", "putting on foundation", "putting on lipstick", "putting on mascara", "putting on sari", "putting on shoes", "putting wallpaper on wall", "raising eyebrows", "reading book", "reading newspaper", "recording music", "repairing puncture", "riding a bike", "riding camel", "riding elephant", "riding mechanical bull", "riding mule", "riding or walking with horse", "riding scooter", "riding snow blower", "riding unicycle", "ripping paper", "roasting marshmallows", "roasting pig", "robot dancing", "rock climbing", "rock scissors paper", "roller skating", "rolling eyes", "rolling pastry", "rope pushdown", "running on treadmill", "sailing", "salsa dancing", "saluting", "sanding floor", "sanding wood", "sausage making", "sawing wood", "scrambling eggs", "scrapbooking", "scrubbing face", "scuba diving", "seasoning food", "separating eggs", "setting table", "sewing", "shaking hands", "shaking head", "shaping bread dough", "sharpening knives", "sharpening pencil", "shaving head", "shaving legs", "shearing sheep", "shining flashlight", "shining shoes", "shoot dance", "shooting basketball", "shooting goal (soccer)", "shooting off fireworks", "shopping", "shot put", "shouting", "shoveling snow", "shredding paper", "shucking oysters", "shuffling cards", "shuffling feet", "side kick", "sieving", "sign language interpreting", "silent disco", "singing", "sipping cup", "situp", "skateboarding", "ski ballet", "ski jumping", "skiing crosscountry", "skiing mono", "skiing slalom", "skipping rope", "skipping stone", "skydiving", "slacklining", "slapping", "sled dog racing", "sleeping", "slicing onion", "smashing", "smelling feet", "smoking", "smoking hookah", "smoking pipe", "snatch weight lifting", "sneezing", "snorkeling", "snowboarding", "snowkiting", "snowmobiling", "somersaulting", "spelunking", "spinning plates", "spinning poi", "splashing water", "spray painting", "spraying", "springboard diving", "square dancing", "squat", "squeezing orange", "stacking cups", "stacking dice", "standing on hands", "staring", "steer roping", "steering car", "sticking tongue out", "stomping grapes", "stretching arm", "stretching leg", "sucking lolly", "surfing crowd", "surfing water", "surveying", "sweeping floor", "swimming backstroke", "swimming breast stroke", "swimming butterfly stroke", "swimming front crawl", "swimming with dolphins", "swimming with sharks", "swing dancing", "swinging baseball bat", "swinging on something", "sword fighting", "sword swallowing", "tackling", "tagging graffiti", "tai chi", "taking photo", "talking on cell phone", "tango dancing", "tap dancing", "tapping guitar", "tapping pen", "tasting beer", "tasting food", "tasting wine", "testifying", "texting", "threading needle", "throwing axe", "throwing ball (not baseball or American football)", "throwing discus", "throwing knife", "throwing snowballs", "throwing tantrum", "throwing water balloon", "tickling", "tie dying", "tightrope walking", "tiptoeing", "tobogganing", "tossing coin", "tossing salad", "training dog", "trapezing", "treating wood", "trimming or shaving beard", "trimming shrubs", "trimming trees", "triple jump", "twiddling fingers", "tying bow tie", "tying knot (not on a tie)", "tying necktie", "tying shoe laces", "unboxing", "uncorking champagne", "unloading truck", "using a microscope", "using a paint roller", "using a power drill", "using a sledge hammer", "using a wrench", "using atm", "using bagging machine", "using circular saw", "using inhaler", "using megaphone", "using puppets", "using remote controller (not gaming)", "using segway", "vacuuming car", "vacuuming floor", "visiting the zoo", "wading through mud", "wading through water", "waiting in line", "waking up", "walking on stilts", "walking the dog", "walking through snow", "walking with crutches", "washing dishes", "washing feet", "washing hair", "washing hands", "watching tv", "water skiing", "water sliding", "watering plants", "waving hand", "waxing armpits", "waxing back", "waxing chest", "waxing eyebrows", "waxing legs", "weaving basket", "weaving fabric", "welding", "whistling", "windsurfing", "winking", "wood burning (art)", "wrapping present", "wrestling", "writing", "yarn spinning", "yawning", "yoga", "zumba"], "templates": ["a photo of {}.", "a photo of a person {}.", "a photo of a person using {}.", "a photo of a person doing {}.", "a photo of a person during {}.", "a photo of a person performing {}.", "a photo of a person practicing {}.", "a video of {}.", "a video of a person {}.", "a video of a person using {}.", "a video of a person doing {}.", "a video of a person during {}.", "a video of a person performing {}.", "a video of a person practicing {}.", "a example of {}.", "a example of a person {}.", "a example of a person using {}.", "a example of a person doing {}.", "a example of a person during {}.", "a example of a person performing {}.", "a example of a person practicing {}.", "a demonstration of {}.", "a demonstration of a person {}.", "a demonstration of a person using {}.", "a demonstration of a person doing {}.", "a demonstration of a person during {}.", "a demonstration of a person performing {}.", "a demonstration of a person practicing {}."]}, "mnist": {"classes": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"], "templates": ["a photo of the number: \"{}\"."]}, "oxford_pets": {"classes": ["Abyssinian", "American Bulldog", "American Pit Bull Terrier", "Basset Hound", "Beagle", "Bengal", "Birman", "Bombay", "Boxer", "British Shorthair", "Chihuahua", "Egyptian Mau", "English Cocker Spaniel", "English Setter", "German Shorthaired", "Great Pyrenees", "Havanese", "Japanese Chin", "Keeshond", "Leonberger", "Maine Coon", "Miniature Pinscher", "Newfoundland", "Persian", "Pomeranian", "Pug", "Ragdoll", "Russian Blue", "Saint Bernard", "Samoyed", "Scottish Terrier", "Shiba Inu", "Siamese", "Sphynx", "Staffordshire Bull Terrier", "Wheaten Terrier", "Yorkshire Terrier"], "templates": ["a photo of a {}, a type of pet."]}, "voc2007": {"classes": ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "dog", "horse", "motorbike", "person", "sheep", "sofa", "diningtable", "pottedplant", "train", "tvmonitor"], "templates": ["a photo of a {}."]}, "pcam": {"classes": ["lymph node", "lymph node containing metastatic tumor tissue"], "templates": ["this is a medical image of {}"]}, "resisc45": {"classes": ["airplane", "airport", "baseball diamond", "basketball court", "beach", "bridge", "chaparral", "church", "circular farmland", "cloud", "commercial area", "dense residential", "desert", "forest", "freeway", "golf course", "ground track field", "harbor", "industrial area", "intersection", "island", "lake", "meadow", "medium residential", "mobile home park", "mountain", "overpass", "palace", "parking lot", "railway", "railway station", "rectangular farmland", "river", "roundabout", "runway", "sea ice", "ship", "snowberg", "sparse residential", "stadium", "storage tank", "tennis court", "terrace", "thermal power station", "wetland"], "templates": ["satellite imagery of {}.", "aerial imagery of {}.", "satellite photo of {}.", "aerial photo of {}.", "satellite view of {}.", "aerial view of {}.", "satellite imagery of a {}.", "aerial imagery of a {}.", "satellite photo of a {}.", "aerial photo of a {}.", "satellite view of a {}.", "aerial view of a {}.", "satellite imagery of the {}.", "aerial imagery of the {}.", "satellite photo of the {}.", "aerial photo of the {}.", "satellite view of the {}.", "aerial view of the {}."]}, "renderedsst2": {"classes": ["negative", "positive"], "templates": ["a {} review of a movie."]}, "stl10": {"classes": ["airplane", "bird", "car", "cat", "deer", "dog", "horse", "monkey", "ship", "truck"], "templates": ["a photo of a {}.", "a photo of the {}."]}, "sun397": {"classes": ["abbey", "airplane cabin", "airport terminal", "alley", "amphitheater", "amusement arcade", "amusement park", "anechoic chamber", "apartment building outdoor", "apse indoor", "aquarium", "aqueduct", "arch", "archive", "arrival gate outdoor", "art gallery", "art school", "art studio", "assembly line", "athletic field outdoor", "atrium public", "attic", "auditorium", "auto factory", "badlands", "badminton court indoor", "baggage claim", "bakery shop", "balcony exterior", "balcony interior", "ball pit", "ballroom", "bamboo forest", "banquet hall", "bar", "barn", "barndoor", "baseball field", "basement", "basilica", "basketball court outdoor", "bathroom", "batters box", "bayou", "bazaar indoor", "bazaar outdoor", "beach", "beauty salon", "bedroom", "berth", "biology laboratory", "bistro indoor", "boardwalk", "boat deck", "boathouse", "bookstore", "booth indoor", "botanical garden", "bow window indoor", "bow window outdoor", "bowling alley", "boxing ring", "brewery indoor", "bridge", "building facade", "bullring", "burial chamber", "bus interior", "butchers shop", "butte", "cabin outdoor", "cafeteria", "campsite", "campus", "canal natural", "canal urban", "candy store", "canyon", "car interior backseat", "car interior frontseat", "carrousel", "casino indoor", "castle", "catacomb", "cathedral indoor", "cathedral outdoor", "cavern indoor", "cemetery", "chalet", "cheese factory", "chemistry lab", "chicken coop indoor", "chicken coop outdoor", "childs room", "church indoor", "church outdoor", "classroom", "clean room", "cliff", "cloister indoor", "closet", "clothing store", "coast", "cockpit", "coffee shop", "computer room", "conference center", "conference room", "construction site", "control room", "control tower outdoor", "corn field", "corral", "corridor", "cottage garden", "courthouse", "courtroom", "courtyard", "covered bridge exterior", "creek", "crevasse", "crosswalk", "cubicle office", "dam", "delicatessen", "dentists office", "desert sand", "desert vegetation", "diner indoor", "diner outdoor", "dinette home", "dinette vehicle", "dining car", "dining room", "discotheque", "dock", "doorway outdoor", "dorm room", "driveway", "driving range outdoor", "drugstore", "electrical substation", "elevator door", "elevator interior", "elevator shaft", "engine room", "escalator indoor", "excavation", "factory indoor", "fairway", "fastfood restaurant", "field cultivated", "field wild", "fire escape", "fire station", "firing range indoor", "fishpond", "florist shop indoor", "food court", "forest broadleaf", "forest needleleaf", "forest path", "forest road", "formal garden", "fountain", "galley", "game room", "garage indoor", "garbage dump", "gas station", "gazebo exterior", "general store indoor", "general store outdoor", "gift shop", "golf course", "greenhouse indoor", "greenhouse outdoor", "gymnasium indoor", "hangar indoor", "hangar outdoor", "harbor", "hayfield", "heliport", "herb garden", "highway", "hill", "home office", "hospital", "hospital room", "hot spring", "hot tub outdoor", "hotel outdoor", "hotel room", "house", "hunting lodge outdoor", "ice cream parlor", "ice floe", "ice shelf", "ice skating rink indoor", "ice skating rink outdoor", "iceberg", "igloo", "industrial area", "inn outdoor", "islet", "jacuzzi indoor", "jail cell", "jail indoor", "jewelry shop", "kasbah", "kennel indoor", "kennel outdoor", "kindergarden classroom", "kitchen", "kitchenette", "labyrinth outdoor", "lake natural", "landfill", "landing deck", "laundromat", "lecture room", "library indoor", "library outdoor", "lido deck outdoor", "lift bridge", "lighthouse", "limousine interior", "living room", "lobby", "lock chamber", "locker room", "mansion", "manufactured home", "market indoor", "market outdoor", "marsh", "martial arts gym", "mausoleum", "medina", "moat water", "monastery outdoor", "mosque indoor", "mosque outdoor", "motel", "mountain", "mountain snowy", "movie theater indoor", "museum indoor", "music store", "music studio", "nuclear power plant outdoor", "nursery", "oast house", "observatory outdoor", "ocean", "office", "office building", "oil refinery outdoor", "oilrig", "operating room", "orchard", "outhouse outdoor", "pagoda", "palace", "pantry", "park", "parking garage indoor", "parking garage outdoor", "parking lot", "parlor", "pasture", "patio", "pavilion", "pharmacy", "phone booth", "physics laboratory", "picnic area", "pilothouse indoor", "planetarium outdoor", "playground", "playroom", "plaza", "podium indoor", "podium outdoor", "pond", "poolroom establishment", "poolroom home", "power plant outdoor", "promenade deck", "pub indoor", "pulpit", "putting green", "racecourse", "raceway", "raft", "railroad track", "rainforest", "reception", "recreation room", "residential neighborhood", "restaurant", "restaurant kitchen", "restaurant patio", "rice paddy", "riding arena", "river", "rock arch", "rope bridge", "ruin", "runway", "sandbar", "sandbox", "sauna", "schoolhouse", "sea cliff", "server room", "shed", "shoe shop", "shopfront", "shopping mall indoor", "shower", "skatepark", "ski lodge", "ski resort", "ski slope", "sky", "skyscraper", "slum", "snowfield", "squash court", "stable", "stadium baseball", "stadium football", "stage indoor", "staircase", "street", "subway interior", "subway station platform", "supermarket", "sushi bar", "swamp", "swimming pool indoor", "swimming pool outdoor", "synagogue indoor", "synagogue outdoor", "television studio", "temple east asia", "temple south asia", "tennis court indoor", "tennis court outdoor", "tent outdoor", "theater indoor procenium", "theater indoor seats", "thriftshop", "throne room", "ticket booth", "toll plaza", "topiary garden", "tower", "toyshop", "track outdoor", "train railway", "train station platform", "tree farm", "tree house", "trench", "underwater coral reef", "utility room", "valley", "van interior", "vegetable garden", "veranda", "veterinarians office", "viaduct", "videostore", "village", "vineyard", "volcano", "volleyball court indoor", "volleyball court outdoor", "waiting room", "warehouse indoor", "water tower", "waterfall block", "waterfall fan", "waterfall plunge", "watering hole", "wave", "wet bar", "wheat field", "wind farm", "windmill", "wine cellar barrel storage", "wine cellar bottle storage", "wrestling ring indoor", "yard", "youth hostel"], "templates": ["a photo of a {}.", "a photo of the {}."]}, "standford_cars": {"classes": ["AM General Hummer SUV 2000", "Acura RL Sedan 2012", "Acura TL Sedan 2012", "Acura TL Type-S 2008", "Acura TSX Sedan 2012", "Acura Integra Type R 2001", "Acura ZDX Hatchback 2012", "Aston Martin V8 Vantage Convertible 2012", "Aston Martin V8 Vantage Coupe 2012", "Aston Martin Virage Convertible 2012", "Aston Martin Virage Coupe 2012", "Audi RS 4 Convertible 2008", "Audi A5 Coupe 2012", "Audi TTS Coupe 2012", "Audi R8 Coupe 2012", "Audi V8 Sedan 1994", "Audi 100 Sedan 1994", "Audi 100 Wagon 1994", "Audi TT Hatchback 2011", "Audi S6 Sedan 2011", "Audi S5 Convertible 2012", "Audi S5 Coupe 2012", "Audi S4 Sedan 2012", "Audi S4 Sedan 2007", "Audi TT RS Coupe 2012", "BMW ActiveHybrid 5 Sedan 2012", "BMW 1 Series Convertible 2012", "BMW 1 Series Coupe 2012", "BMW 3 Series Sedan 2012", "BMW 3 Series Wagon 2012", "BMW 6 Series Convertible 2007", "BMW X5 SUV 2007", "BMW X6 SUV 2012", "BMW M3 Coupe 2012", "BMW M5 Sedan 2010", "BMW M6 Convertible 2010", "BMW X3 SUV 2012", "BMW Z4 Convertible 2012", "Bentley Continental Supersports Conv. Convertible 2012", "Bentley Arnage Sedan 2009", "Bentley Mulsanne Sedan 2011", "Bentley Continental GT Coupe 2012", "Bentley Continental GT Coupe 2007", "Bentley Continental Flying Spur Sedan 2007", "Bugatti Veyron 16.4 Convertible 2009", "Bugatti Veyron 16.4 Coupe 2009", "Buick Regal GS 2012", "Buick Rainier SUV 2007", "Buick Verano Sedan 2012", "Buick Enclave SUV 2012", "Cadillac CTS-V Sedan 2012", "Cadillac SRX SUV 2012", "Cadillac Escalade EXT Crew Cab 2007", "Chevrolet Silverado 1500 Hybrid Crew Cab 2012", "Chevrolet Corvette Convertible 2012", "Chevrolet Corvette ZR1 2012", "Chevrolet Corvette Ron Fellows Edition Z06 2007", "Chevrolet Traverse SUV 2012", "Chevrolet Camaro Convertible 2012", "Chevrolet HHR SS 2010", "Chevrolet Impala Sedan 2007", "Chevrolet Tahoe Hybrid SUV 2012", "Chevrolet Sonic Sedan 2012", "Chevrolet Express Cargo Van 2007", "Chevrolet Avalanche Crew Cab 2012", "Chevrolet Cobalt SS 2010", "Chevrolet Malibu Hybrid Sedan 2010", "Chevrolet TrailBlazer SS 2009", "Chevrolet Silverado 2500HD Regular Cab 2012", "Chevrolet Silverado 1500 Classic Extended Cab 2007", "Chevrolet Express Van 2007", "Chevrolet Monte Carlo Coupe 2007", "Chevrolet Malibu Sedan 2007", "Chevrolet Silverado 1500 Extended Cab 2012", "Chevrolet Silverado 1500 Regular Cab 2012", "Chrysler Aspen SUV 2009", "Chrysler Sebring Convertible 2010", "Chrysler Town and Country Minivan 2012", "Chrysler 300 SRT-8 2010", "Chrysler Crossfire Convertible 2008", "Chrysler PT Cruiser Convertible 2008", "Daewoo Nubira Wagon 2002", "Dodge Caliber Wagon 2012", "Dodge Caliber Wagon 2007", "Dodge Caravan Minivan 1997", "Dodge Ram Pickup 3500 Crew Cab 2010", "Dodge Ram Pickup 3500 Quad Cab 2009", "Dodge Sprinter Cargo Van 2009", "Dodge Journey SUV 2012", "Dodge Dakota Crew Cab 2010", "Dodge Dakota Club Cab 2007", "Dodge Magnum Wagon 2008", "Dodge Challenger SRT8 2011", "Dodge Durango SUV 2012", "Dodge Durango SUV 2007", "Dodge Charger Sedan 2012", "Dodge Charger SRT-8 2009", "Eagle Talon Hatchback 1998", "FIAT 500 Abarth 2012", "FIAT 500 Convertible 2012", "Ferrari FF Coupe 2012", "Ferrari California Convertible 2012", "Ferrari 458 Italia Convertible 2012", "Ferrari 458 Italia Coupe 2012", "Fisker Karma Sedan 2012", "Ford F-450 Super Duty Crew Cab 2012", "Ford Mustang Convertible 2007", "Ford Freestar Minivan 2007", "Ford Expedition EL SUV 2009", "Ford Edge SUV 2012", "Ford Ranger SuperCab 2011", "Ford GT Coupe 2006", "Ford F-150 Regular Cab 2012", "Ford F-150 Regular Cab 2007", "Ford Focus Sedan 2007", "Ford E-Series Wagon Van 2012", "Ford Fiesta Sedan 2012", "GMC Terrain SUV 2012", "GMC Savana Van 2012", "GMC Yukon Hybrid SUV 2012", "GMC Acadia SUV 2012", "GMC Canyon Extended Cab 2012", "Geo Metro Convertible 1993", "HUMMER H3T Crew Cab 2010", "HUMMER H2 SUT Crew Cab 2009", "Honda Odyssey Minivan 2012", "Honda Odyssey Minivan 2007", "Honda Accord Coupe 2012", "Honda Accord Sedan 2012", "Hyundai Veloster Hatchback 2012", "Hyundai Santa Fe SUV 2012", "Hyundai Tucson SUV 2012", "Hyundai Veracruz SUV 2012", "Hyundai Sonata Hybrid Sedan 2012", "Hyundai Elantra Sedan 2007", "Hyundai Accent Sedan 2012", "Hyundai Genesis Sedan 2012", "Hyundai Sonata Sedan 2012", "Hyundai Elantra Touring Hatchback 2012", "Hyundai Azera Sedan 2012", "Infiniti G Coupe IPL 2012", "Infiniti QX56 SUV 2011", "Isuzu Ascender SUV 2008", "Jaguar XK XKR 2012", "Jeep Patriot SUV 2012", "Jeep Wrangler SUV 2012", "Jeep Liberty SUV 2012", "Jeep Grand Cherokee SUV 2012", "Jeep Compass SUV 2012", "Lamborghini Reventon Coupe 2008", "Lamborghini Aventador Coupe 2012", "Lamborghini Gallardo LP 570-4 Superleggera 2012", "Lamborghini Diablo Coupe 2001", "Land Rover Range Rover SUV 2012", "Land Rover LR2 SUV 2012", "Lincoln Town Car Sedan 2011", "MINI Cooper Roadster Convertible 2012", "Maybach Landaulet Convertible 2012", "Mazda Tribute SUV 2011", "McLaren MP4-12C Coupe 2012", "Mercedes-Benz 300-Class Convertible 1993", "Mercedes-Benz C-Class Sedan 2012", "Mercedes-Benz SL-Class Coupe 2009", "Mercedes-Benz E-Class Sedan 2012", "Mercedes-Benz S-Class Sedan 2012", "Mercedes-Benz Sprinter Van 2012", "Mitsubishi Lancer Sedan 2012", "Nissan Leaf Hatchback 2012", "Nissan NV Passenger Van 2012", "Nissan Juke Hatchback 2012", "Nissan 240SX Coupe 1998", "Plymouth Neon Coupe 1999", "Porsche Panamera Sedan 2012", "Ram C/V Cargo Van Minivan 2012", "Rolls-Royce Phantom Drophead Coupe Convertible 2012", "Rolls-Royce Ghost Sedan 2012", "Rolls-Royce Phantom Sedan 2012", "Scion xD Hatchback 2012", "Spyker C8 Convertible 2009", "Spyker C8 Coupe 2009", "Suzuki Aerio Sedan 2007", "Suzuki Kizashi Sedan 2012", "Suzuki SX4 Hatchback 2012", "Suzuki SX4 Sedan 2012", "Tesla Model S Sedan 2012", "Toyota Sequoia SUV 2012", "Toyota Camry Sedan 2012", "Toyota Corolla Sedan 2012", "Toyota 4Runner SUV 2012", "Volkswagen Golf Hatchback 2012", "Volkswagen Golf Hatchback 1991", "Volkswagen Beetle Hatchback 2012", "Volvo C30 Hatchback 2012", "Volvo 240 Sedan 1993", "Volvo XC90 SUV 2007", "smart fortwo Convertible 2012"], "templates": ["a photo of a {}.", "a photo of the {}.", "a photo of my {}.", "i love my {}!", "a photo of my dirty {}.", "a photo of my clean {}.", "a photo of my new {}.", "a photo of my old {}."]}, "ucf101": {"classes": ["Apply Eye Makeup", "Apply Lipstick", "Archery", "Baby Crawling", "Balance Beam", "Band Marching", "Baseball Pitch", "Basketball", "Basketball Dunk", "Bench Press", "Biking", "Billiards", "Blow Dry Hair", "Blowing Candles", "Body Weight Squats", "Bowling", "Boxing Punching Bag", "Boxing Speed Bag", "Breast Stroke", "Brushing Teeth", "Clean And Jerk", "Cliff Diving", "Cricket Bowling", "Cricket Shot", "Cutting In Kitchen", "Diving", "Drumming", "Fencing", "Field Hockey Penalty", "Floor Gymnastics", "Frisbee Catch", "Front Crawl", "Golf Swing", "Haircut", "Hammer Throw", "Hammering", "Hand Stand Pushups", "Handstand Walking", "Head Massage", "High Jump", "Horse Race", "Horse Riding", "Hula Hoop", "Ice Dancing", "Javelin Throw", "Juggling Balls", "Jump Rope", "Jumping Jack", "Kayaking", "Knitting", "Long Jump", "Lunges", "Military Parade", "Mixing", "Mopping Floor", "Nunchucks", "Parallel Bars", "Pizza Tossing", "Playing Cello", "Playing Daf", "Playing Dhol", "Playing Flute", "Playing Guitar", "Playing Piano", "Playing Sitar", "Playing Tabla", "Playing Violin", "Pole Vault", "Pommel Horse", "Pull Ups", "Punch", "Push Ups", "Rafting", "Rock Climbing Indoor", "Rope Climbing", "Rowing", "Salsa Spin", "Shaving Beard", "Shotput", "Skate Boarding", "Skiing", "Skijet", "Sky Diving", "Soccer Juggling", "Soccer Penalty", "Still Rings", "Sumo Wrestling", "Surfing", "Swing", "Table Tennis Shot", "Tai Chi", "Tennis Swing", "Throw Discus", "Trampoline Jumping", "Typing", "Uneven Bars", "Volleyball Spiking", "Walking With Dog", "Wall Pushups", "Writing On Board", "Yo Yo"], "templates": ["a photo of a person {}.", "a video of a person {}.", "a example of a person {}.", "a demonstration of a person {}.", "a photo of the person {}.", "a video of the person {}.", "a example of the person {}.", "a demonstration of the person {}.", "a photo of a person using {}.", "a video of a person using {}.", "a example of a person using {}.", "a demonstration of a person using {}.", "a photo of the person using {}.", "a video of the person using {}.", "a example of the person using {}.", "a demonstration of the person using {}.", "a photo of a person doing {}.", "a video of a person doing {}.", "a example of a person doing {}.", "a demonstration of a person doing {}.", "a photo of the person doing {}.", "a video of the person doing {}.", "a example of the person doing {}.", "a demonstration of the person doing {}.", "a photo of a person during {}.", "a video of a person during {}.", "a example of a person during {}.", "a demonstration of a person during {}.", "a photo of the person during {}.", "a video of the person during {}.", "a example of the person during {}.", "a demonstration of the person during {}.", "a photo of a person performing {}.", "a video of a person performing {}.", "a example of a person performing {}.", "a demonstration of a person performing {}.", "a photo of the person performing {}.", "a video of the person performing {}.", "a example of the person performing {}.", "a demonstration of the person performing {}.", "a photo of a person practicing {}.", "a video of a person practicing {}.", "a example of a person practicing {}.", "a demonstration of a person practicing {}.", "a photo of the person practicing {}.", "a video of the person practicing {}.", "a example of the person practicing {}.", "a demonstration of the person practicing {}."]}, "aid": {"classes": ["Airport", "Bare Land", "Baseball Field", "Beach", "Bridge", "Center", "Church", "Commercial", "Dense Residential", "Desert", "Farmland", "Forest", "Industrial", "Meadow", "Medium Residential", "Mountain", "Park", "Parking", "Playground", "Pond", "Port", "Railway Station", "Resort", "River", "School", "SparseResidential", "Square", "Stadium", "StorageTanks", "Viaduct"], "templates": ["satellite imagery of {}.", "aerial imagery of {}.", "satellite photo of {}.", "aerial photo of {}.", "satellite view of {}.", "aerial view of {}.", "satellite imagery of a {}.", "aerial imagery of a {}.", "satellite photo of a {}.", "aerial photo of a {}.", "satellite view of a {}.", "aerial view of a {}.", "satellite imagery of the {}.", "aerial imagery of the {}.", "satellite photo of the {}.", "aerial photo of the {}.", "satellite view of the {}.", "aerial view of the {}."]}, "imagenet": {"classes": ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", "box turtle", "banded gecko", "green iguana", "Carolina anole", "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", "American alligator", "triceratops", "worm snake", "ring-necked snake", "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", "freight car", "French horn", "frying pan", "fur coat", "garbage truck", "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"], "templates": ["a photo of a {}.", "a bad photo of a {}.", "a photo of many {}.", "a sculpture of a {}.", "a photo of the hard to see {}.", "a low resolution photo of the {}.", "a rendering of a {}.", "graffiti of a {}.", "a bad photo of the {}.", "a cropped photo of the {}.", "a tattoo of a {}.", "the embroidered {}.", "a photo of a hard to see {}.", "a bright photo of a {}.", "a photo of a clean {}.", "a photo of a dirty {}.", "a dark photo of the {}.", "a drawing of a {}.", "a photo of my {}.", "the plastic {}.", "a photo of the cool {}.", "a close-up photo of a {}.", "a black and white photo of the {}.", "a painting of the {}.", "a painting of a {}.", "a pixelated photo of the {}.", "a sculpture of the {}.", "a bright photo of the {}.", "a cropped photo of a {}.", "a plastic {}.", "a photo of the dirty {}.", "a jpeg corrupted photo of a {}.", "a blurry photo of the {}.", "a photo of the {}.", "a good photo of the {}.", "a rendering of the {}.", "a {} in a video game.", "a photo of one {}.", "a doodle of a {}.", "a close-up photo of the {}.", "a photo of a {}.", "the origami {}.", "the {} in a video game.", "a sketch of a {}.", "a doodle of the {}.", "a origami {}.", "a low resolution photo of a {}.", "the toy {}.", "a rendition of the {}.", "a photo of the clean {}.", "a photo of a large {}.", "a rendition of a {}.", "a photo of a nice {}.", "a photo of a weird {}.", "a blurry photo of a {}.", "a cartoon {}.", "art of a {}.", "a sketch of the {}.", "a embroidered {}.", "a pixelated photo of a {}.", "itap of the {}.", "a jpeg corrupted photo of the {}.", "a good photo of a {}.", "a plushie {}.", "a photo of the nice {}.", "a photo of the small {}.", "a photo of the weird {}.", "the cartoon {}.", "art of the {}.", "a drawing of the {}.", "a photo of the large {}.", "a black and white photo of a {}.", "the plushie {}.", "a dark photo of a {}.", "itap of a {}.", "graffiti of the {}.", "a toy {}.", "itap of my {}.", "a photo of a cool {}.", "a photo of a small {}.", "a tattoo of the {}."]}, "office_d": {"classes": ["back_pack", "bike", "bike_helmet", "bookcase", "bottle", "calculator", "desk_chair", "desk_lamp", "desktop_computer", "file_cabinet", "headphones", "keyboard", "laptop_computer", "letter_tray", "mobile_phone", "monitor", "mouse", "mug", "paper_notebook", "pen", "phone", "printer", "projector", "punchers", "ring_binder", "ruler", "scissors", "speaker", "stapler", "tape_dispenser", "trash_can"], "templates": ["a photo of a {}."]}, "office_w": {"classes": ["back_pack", "bike", "bike_helmet", "bookcase", "bottle", "calculator", "desk_chair", "desk_lamp", "desktop_computer", "file_cabinet", "headphones", "keyboard", "laptop_computer", "letter_tray", "mobile_phone", "monitor", "mouse", "mug", "paper_notebook", "pen", "phone", "printer", "projector", "punchers", "ring_binder", "ruler", "scissors", "speaker", "stapler", "tape_dispenser", "trash_can"], "templates": ["a photo of a {}."]}, "office_a": {"classes": ["back_pack", "bike", "bike_helmet", "bookcase", "bottle", "calculator", "desk_chair", "desk_lamp", "desktop_computer", "file_cabinet", "headphones", "keyboard", "laptop_computer", "letter_tray", "mobile_phone", "monitor", "mouse", "mug", "paper_notebook", "pen", "phone", "printer", "projector", "punchers", "ring_binder", "ruler", "scissors", "speaker", "stapler", "tape_dispenser", "trash_can"], "templates": ["A photo of a {}."]}, "office_ar": {"classes": ["Drill", "Exit Sign", "Bottle", "Glasses", "Computer", "File Cabinet", "Shelf", "Toys", "Sink", "Laptop", "Kettle", "Folder", "Keyboard", "Flipflops", "Pencil", "Bed", "Hammer", "ToothBrush", "Couch", "Bike", "Postit Notes", "Mug", "Webcam", "Desk Lamp", "Telephone", "Helmet", "Mouse", "Pen", "Monitor", "Mop", "Sneakers", "Notebook", "Backpack", "Alarm Clock", "Push Pin", "Paper Clip", "Batteries", "Radio", "Fan", "Ruler", "Pan", "Screwdriver", "Trash Can", "Printer", "Speaker", "Eraser", "Bucket", "Chair", "Calendar", "Calculator", "Flowers", "Lamp Shade", "Spoon", "Candles", "Clipboards", "Scissors", "TV", "Curtains", "Fork", "Soda", "Table", "Knives", "Oven", "Refrigerator", "Marker"], "templates": ["A photo of a {}."]}, "office_cl": {"classes": ["Drill", "Exit Sign", "Bottle", "Glasses", "Computer", "File Cabinet", "Shelf", "Toys", "Sink", "Laptop", "Kettle", "Folder", "Keyboard", "Flipflops", "Pencil", "Bed", "Hammer", "ToothBrush", "Couch", "Bike", "Postit Notes", "Mug", "Webcam", "Desk Lamp", "Telephone", "Helmet", "Mouse", "Pen", "Monitor", "Mop", "Sneakers", "Notebook", "Backpack", "Alarm Clock", "Push Pin", "Paper Clip", "Batteries", "Radio", "Fan", "Ruler", "Pan", "Screwdriver", "Trash Can", "Printer", "Speaker", "Eraser", "Bucket", "Chair", "Calendar", "Calculator", "Flowers", "Lamp Shade", "Spoon", "Candles", "Clipboards", "Scissors", "TV", "Curtains", "Fork", "Soda", "Table", "Knives", "Oven", "Refrigerator", "Marker"], "templates": ["A photo of a {}."]}, "office_pr": {"classes": ["Drill", "Exit Sign", "Bottle", "Glasses", "Computer", "File Cabinet", "Shelf", "Toys", "Sink", "Laptop", "Kettle", "Folder", "Keyboard", "Flipflops", "Pencil", "Bed", "Hammer", "ToothBrush", "Couch", "Bike", "Postit Notes", "Mug", "Webcam", "Desk Lamp", "Telephone", "Helmet", "Mouse", "Pen", "Monitor", "Mop", "Sneakers", "Notebook", "Backpack", "Alarm Clock", "Push Pin", "Paper Clip", "Batteries", "Radio", "Fan", "Ruler", "Pan", "Screwdriver", "Trash Can", "Printer", "Speaker", "Eraser", "Bucket", "Chair", "Calendar", "Calculator", "Flowers", "Lamp Shade", "Spoon", "Candles", "Clipboards", "Scissors", "TV", "Curtains", "Fork", "Soda", "Table", "Knives", "Oven", "Refrigerator", "Marker"], "templates": ["A photo of a {}."]}, "office_rw": {"classes": ["Drill", "Exit Sign", "Bottle", "Glasses", "Computer", "File Cabinet", "Shelf", "Toys", "Sink", "Laptop", "Kettle", "Folder", "Keyboard", "Flipflops", "Pencil", "Bed", "Hammer", "ToothBrush", "Couch", "Bike", "Postit Notes", "Mug", "Webcam", "Desk Lamp", "Telephone", "Helmet", "Mouse", "Pen", "Monitor", "Mop", "Sneakers", "Notebook", "Backpack", "Alarm Clock", "Push Pin", "Paper Clip", "Batteries", "Radio", "Fan", "Ruler", "Pan", "Screwdriver", "Trash Can", "Printer", "Speaker", "Eraser", "Bucket", "Chair", "Calendar", "Calculator", "Flowers", "Lamp Shade", "Spoon", "Candles", "Clipboards", "Scissors", "TV", "Curtains", "Fork", "Soda", "Table", "Knives", "Oven", "Refrigerator", "Marker"], "templates": ["A photo of a {}."]}} --------------------------------------------------------------------------------