├── DataFiltering ├── Download.py ├── FilterData.py └── TransFilter.py ├── LICENSE ├── Priming ├── Templates │ ├── FGVCAircraft.py │ ├── FGVC_Aircraft.txt │ ├── Flowers102.py │ ├── Flowers102.txt │ ├── Food101.py │ ├── Food101.txt │ ├── ImageNet.py │ ├── OxfordIIITPet.py │ ├── OxfordPets.txt │ ├── SUN397.py │ ├── StanfordCars.py │ ├── StanfordCars.txt │ └── Sun397.txt ├── args.py ├── data.py ├── imagenet.json ├── prime.py └── util.py ├── README.md ├── assets ├── Aircraft_FSL.jpeg ├── Cars_FSL.jpeg ├── Flowers_FSL.jpeg └── teaser.jpg ├── environment.yml └── requirements.txt /DataFiltering/Download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | import pandas as pd 4 | from requests.packages.urllib3.util.retry import Retry 5 | from requests.adapters import HTTPAdapter 6 | from requests.exceptions import SSLError 7 | import concurrent.futures 8 | import socket 9 | import argparse 10 | import time 11 | from tqdm import tqdm 12 | 13 | def download_dataset(root_folder_read, root_folder_write, max_imgs): 14 | # Download images given a pytorch folder of urls given in jsons 15 | master_df = pd.DataFrame(columns=['URL', 'Path', 'TEXT']) 16 | for class_name in os.listdir(root_folder_read): 17 | class_folder = os.path.join(root_folder_read, class_name.replace('/', 'or')) 18 | write_folder = os.path.join(root_folder_write, class_name.replace('/', 'or')) 19 | if not os.path.exists(write_folder): 20 | os.makedirs(write_folder) 21 | for chunks in os.listdir(class_folder): 22 | chunk_path = os.path.join(class_folder, chunks) 23 | df = pd.read_json(chunk_path) 24 | folder_path = os.path.join(root_folder_write, class_name) 25 | df_out = download_url_list(root_folder_write, df, class_name, max_imgs) 26 | master_df = pd.concat([master_df, df_out]) 27 | master_df.to_csv(root_folder_write + '.csv') 28 | 29 | 30 | def download_image(url, filename, df, urls, paths, captions, sims, j): 31 | #texts = caption 32 | retry_strategy = Retry( 33 | total=1, 34 | backoff_factor=1, 35 | ) 36 | adapter = HTTPAdapter(max_retries=retry_strategy) 37 | 38 | # Create a requests session with the retry mechanism 39 | session = requests.Session() 40 | session.mount('http://', adapter) 41 | session.mount('https://', adapter) # (connect timeout, read timeout) 42 | #print(url) 43 | # Set the socket timeout for the session 44 | socket.setdefaulttimeout(1) 45 | # Make a request using the session 46 | try: 47 | if not(os.path.exists(filename)): 48 | response = session.get(url, timeout=1) 49 | if response.status_code == 200 and response.history == [] and response.url == url: 50 | with open(filename, 'wb') as f: 51 | f.write(requests.get(url).content) 52 | print(f'Downloaded {url}') 53 | urls.append(url) 54 | paths.append(filename) 55 | captions.append(df.TEXT[j]) 56 | sims.append(df.similarity[j]) 57 | else: 58 | print(f'Failed to download {url} (status code: {response.status_code})') 59 | else: 60 | print('Already downloaded') 61 | except requests.exceptions.Timeout: 62 | print(f"Timeout occurred while downloading {url}") 63 | except SSLError: 64 | print(f"SSL error occurred while downloading {url}") 65 | except Exception as e: 66 | print(f"Error occurred while downloading {url}: {str(e)}") 67 | 68 | def download_url_list(root_folder, df, class_name, max_imgs): 69 | start_time = time.time() 70 | image_urls = df.URL 71 | texts = df.TEXT 72 | # Create a directory to save the downloaded images 73 | if not os.path.exists(root_folder): 74 | os.makedirs(root_folder) 75 | folder_path = os.path.join(root_folder, class_name) 76 | print(folder_path) 77 | if not os.path.exists(folder_path): 78 | os.makedirs(folder_path) 79 | 80 | # Create a thread pool to download the images 81 | with concurrent.futures.ThreadPoolExecutor(max_workers=256) as executor: 82 | # Loop through the image URLs and download each safe image 83 | urls = [] 84 | paths = [] 85 | captions = [] 86 | sims = [] 87 | futures = [] 88 | for j, url in tqdm(enumerate(image_urls[:max_imgs])): 89 | if 'png' or 'jpg' in url: 90 | filename = os.path.join(folder_path, os.path.basename(url)) 91 | future = executor.submit(download_image, url, filename, df, urls, paths, captions, sims, j) 92 | futures 93 | print(f'Downloaded {len(urls)} out of {len(image_urls[:max_imgs])} images.') 94 | print(print("Downloaded in --- %s seconds ---" % (time.time() - start_time))) 95 | print('Images per second: {}'.format(len(urls)/(time.time() - start_time))) 96 | 97 | if __name__ == "__main__": 98 | # Create argument parser 99 | parser = argparse.ArgumentParser( 100 | description="" 101 | ) 102 | 103 | # Add arguments 104 | parser.add_argument("--r", type=str, help="Path to SQLite database file.") 105 | parser.add_argument("--w", type=str, help="") 106 | parser.add_argument("--n", type=int, default=1000) 107 | 108 | # Parse arguments 109 | args = parser.parse_args() 110 | 111 | # Call main function with arguments 112 | download_dataset(args.r, args.w, args.n) -------------------------------------------------------------------------------- /DataFiltering/FilterData.py: -------------------------------------------------------------------------------- 1 | from concurrent import futures 2 | import time 3 | import sqlite3 4 | from pathlib import Path 5 | import argparse 6 | import subprocess 7 | import os 8 | import importlib 9 | import sys 10 | 11 | 12 | import pyarrow as pa 13 | import pyarrow.parquet as pq 14 | import itertools 15 | import pandas as pd 16 | import tqdm 17 | 18 | 19 | current_directory = os.path.dirname(os.path.abspath(os.getcwd())) 20 | sys.path.append(current_directory) 21 | 22 | def get_args(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("-d", "--dbs", nargs="+", default=None, help="Path to sqlite dbs") 25 | parser.add_argument( 26 | "-q", 27 | "--query", 28 | type=str, 29 | default=None, 30 | help="Pipe separated list of queries or newline separated query document", 31 | ) 32 | 33 | parser.add_argument( 34 | "--template", 35 | action='store_true', 36 | help="", 37 | ) 38 | parser.add_argument( 39 | "-n", 40 | "--quantity", 41 | type=int, 42 | default=None, 43 | help="Number of desired outputs (currently only functions with workers=1)", 44 | ) 45 | parser.add_argument( 46 | "-o", 47 | "--output", 48 | default="./", 49 | help="Full output path, will make parent directories if they don't exist", 50 | ) 51 | parser.add_argument( 52 | "-w", 53 | "--workers", 54 | type=int, 55 | default=1, 56 | ) 57 | parser.add_argument( 58 | "--field", default="TEXT", type=str, help="Field to search database" 59 | ) 60 | 61 | args = parser.parse_args() 62 | 63 | return args 64 | 65 | 66 | def main(): 67 | args = get_args() 68 | if args.template: 69 | print('Priming.Templates.' + args.query) 70 | dataset_obj = importlib.import_module('Priming.Templates.' + args.query) 71 | words = dataset_obj.classes 72 | else: 73 | if not os.path.exists(args.query): 74 | words = args.query.split("|") 75 | else: 76 | words = [l for l in Path(args.query).read_text().split("\n") if l] 77 | 78 | print( 79 | f"Searching {len(args.dbs)} dbs for {len(words)} needles:" 80 | ) 81 | out = search_sharded_database( 82 | args.dbs, 83 | words, 84 | workers=args.workers, 85 | max_results=args.quantity, 86 | field=args.field, 87 | ) 88 | 89 | fields = [ 90 | "SAMPLE_ID", 91 | "URL", 92 | "TEXT", 93 | "HEIGHT", 94 | "WIDTH", 95 | "LICENSE", 96 | "NSFW", 97 | "similarity", 98 | "QUERY", 99 | ] 100 | 101 | field_types = [ 102 | pa.int64(), 103 | pa.binary(), 104 | pa.binary(), 105 | pa.int32(), 106 | pa.int32(), 107 | pa.binary(), 108 | pa.binary(), 109 | pa.float64(), 110 | pa.binary(), 111 | ] 112 | 113 | schema = pa.schema( 114 | [pa.field(name, dtype) for name, dtype in zip(fields, field_types)] 115 | ) 116 | 117 | folder = Path(args.output) 118 | folder.mkdir(parents=True, exist_ok=True) 119 | 120 | for i, chunk in enumerate( 121 | chunk_iterator(row_iterator(out, fn=process_fields), chunk_size=500000) 122 | ): 123 | df = pd.DataFrame(chunk, columns=fields) 124 | df.to_json(folder / f"chunk_{i}.json", orient="records") 125 | # table = pa.Table.from_pandas(df, schema=schema) 126 | # pq.write_table(table, folder / f"chunk_{i}.parquet") 127 | 128 | 129 | def process_fields(key, row): 130 | sample_id, url, text, height, width, licence_, nsfw, similarity = row 131 | 132 | return ( 133 | int(float(sample_id)) if sample_id else None, 134 | bytes(url, "utf-8") if url else None, 135 | bytes(text, "utf-8") if text else None, 136 | int(float(height)) if height else None, 137 | int(float(width)) if width else None, 138 | bytes(licence_, "utf-8") if licence_ else None, 139 | bytes(nsfw, "utf-8") if nsfw else None, 140 | float(similarity) if similarity else None, 141 | bytes(key, "utf-8"), 142 | ) 143 | 144 | 145 | def chunk_iterator(iterator, chunk_size): 146 | """ 147 | Given an iterator, returns an iterator of iterators where each 148 | inner iterator has length `chunk_size` or less. 149 | """ 150 | while True: 151 | chunk = list(itertools.islice(iterator, chunk_size)) 152 | if not chunk: 153 | break 154 | yield chunk 155 | 156 | 157 | def row_iterator(in_dict, fn=lambda x: x): 158 | for key, values in in_dict.items(): 159 | for row in values: 160 | yield fn(key, row) 161 | 162 | 163 | def safe_dict_collate(dict_a, dict_b): 164 | set_keys = set(dict_a.keys()).union(set(dict_b.keys())) 165 | 166 | out = {} 167 | for k in set_keys: 168 | a_vals = dict_a.get(k, []) 169 | b_vals = dict_b.get(k, []) 170 | 171 | out[k] = a_vals + b_vals 172 | 173 | return out 174 | 175 | 176 | def search_sharded_database( 177 | dbs, words, max_results=None, workers=1, field="TEXT" 178 | ): 179 | items = [(i, db, words, field) for i, db in enumerate(dbs)] 180 | 181 | with futures.ThreadPoolExecutor(max_workers=workers) as executor: 182 | futures_to_results = { 183 | executor.submit(search_database, item): item for item in items 184 | } 185 | all_results = {} 186 | for future in futures.as_completed(futures_to_results): 187 | result = future.result() 188 | 189 | all_results = safe_dict_collate(all_results, result) 190 | 191 | if max_results is not None and all( 192 | [len(v) > max_results for v in all_results.values()] 193 | ): 194 | for future in futures_to_results: 195 | future.cancel() 196 | break 197 | 198 | return all_results 199 | 200 | 201 | def search_database(args): 202 | shard_idx, db, words, field = args 203 | word_to_results = {} 204 | start_time = time.time() 205 | total_results = 0 206 | if os.path.exists(db): 207 | conn = sqlite3.connect(db) 208 | c = conn.cursor() 209 | for i, word in tqdm.tqdm(enumerate(words), desc=f"Shard {shard_idx} ", total=len(words)): 210 | query = f"SELECT * FROM samples WHERE {field} MATCH '\"{word}\"'" 211 | c.execute(query) 212 | 213 | # Fetch results 214 | word_to_results[word] = list(c.fetchall()) 215 | total_results += len(word_to_results[word]) 216 | 217 | end_time = time.time() 218 | print( 219 | f"Search of shard {shard_idx} took {end_time - start_time:.4f} seconds for {len(words)} words," 220 | f" {total_results} results" 221 | ) 222 | 223 | conn.close() 224 | else: 225 | print("Skipping shard:{}".format(db)) 226 | return word_to_results 227 | 228 | 229 | if __name__ == "__main__": 230 | main() 231 | -------------------------------------------------------------------------------- /DataFiltering/TransFilter.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from numpy.lib.format import open_memmap 3 | from pathlib import Path 4 | from PIL import Image, ImageDraw 5 | from torchvision.datasets.folder import default_loader 6 | from torch.utils.data import DataLoader, Dataset 7 | from utils.dataset import SafeImageFolder 8 | 9 | import torchvision 10 | import clip 11 | import shutil 12 | import pickle 13 | 14 | import argparse 15 | import numpy as np 16 | import torch.nn as nn 17 | import torch 18 | import torch.optim as optim 19 | import torch.nn.functional as F 20 | import tqdm 21 | import math 22 | import os 23 | import json 24 | import random as r 25 | import open_clip 26 | 27 | 28 | def get_args(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument( 31 | "--retrieval-path", 32 | type=Path, 33 | default=None, 34 | required=True, 35 | help="Path to retrieval reservoir (clean subset of LAION)", 36 | ) 37 | parser.add_argument( 38 | "--transductive-path", 39 | type=Path, 40 | default=None, 41 | required=True, 42 | help="Path to transfer evaluation dataset (to be used in a transductive fashion)", 43 | ) 44 | parser.add_argument( 45 | "--k-shot", 46 | type=int, 47 | default=None, 48 | help="Number of shots per class, only used in train-time data augmentation", 49 | ) 50 | parser.add_argument( 51 | "--cache-path", 52 | type=Path, 53 | default=Path("/usr/aux/gen-datasets/cache"), 54 | help="Path to cache", 55 | ) 56 | parser.add_argument( 57 | "--out-dir", 58 | type=Path, 59 | default=None, 60 | required=True, 61 | help="Path to output directory (dataset in ImageFolder format)", 62 | ) 63 | parser.add_argument( 64 | "--retrievals-per-image", 65 | default=10, 66 | type=int, 67 | help="Number of retrievals per image", 68 | ) 69 | parser.add_argument( 70 | "--clip-filter", 71 | action="store_true", 72 | help="Filter using CLIP before transductive retrieval", 73 | ) 74 | parser.add_argument( 75 | "--prompt-file", 76 | type=Path, 77 | default=None, 78 | help="Path to prompt file, format classname to list of prompts", 79 | ) 80 | parser.add_argument( 81 | "--dataset-type", type=str, default="ImageFolder", help="Type of dataset" 82 | ) 83 | parser.add_argument( 84 | "--clip-score-filter", 85 | type=float, 86 | default=None, 87 | help="Filter using CLIP score, after clip classification filtering", 88 | ) 89 | parser.add_argument( 90 | "--split", 91 | type=str, 92 | default="val", 93 | help="Split to use, only applies to non-ImageFolder datasets", 94 | ) 95 | parser.add_argument( 96 | "--model", 97 | type=str, 98 | default="ViT-B-16", 99 | help="Model arch from open_clip to use for filtering" 100 | ) 101 | parser.add_argument( 102 | "--pretrained", 103 | type=str, 104 | default="laion2b_s34b_b88k", 105 | help="Pre-trained weights from open_clip to use for filtering. See open_clip repo for choices" 106 | ) 107 | args = parser.parse_args() 108 | 109 | return args 110 | 111 | 112 | def main(): 113 | # Assume data retrieval set is the same structure transfer dataset 114 | 115 | # 1. Get data retrieval set, both metadata CSV + paths, and transfer dataset 116 | # 1a. Filter transfer dataset by classes which exist in the retrieval dataset 117 | # 2. Precompute all features, store in mmap arrays 118 | # 3. Precompute similarity matrix between retrieval set + data retrieval set 119 | # 4. For each element of the val set, get _k_ relevant images of the retrieval set. 120 | # Hard label these retrieved images and "finetune" the zero-shot head. 121 | # 4a. Various fine-tuning options: 122 | # i. NCM with zeroshot head 123 | # ii. NCM + wise-ft zeroshot head + double-counting retrievals 124 | # iii. Linear probe (maybe mask the irrelevant images) 125 | # iv. Get _all_ images _k_ per class, then fine-tune 126 | # 4b. Test scale wrt _k_ 127 | 128 | # Dummy args for testing 129 | args = get_args() 130 | 131 | args.cache_path.mkdir(exist_ok=True, parents=True) 132 | 133 | # Getting model 134 | print("=> Acquiring model") 135 | model, _, preprocess = open_clip.create_model_and_transforms( 136 | args.model, args.pretrained, device="cuda" 137 | ) 138 | 139 | # Getting retrieval dataset/loader 140 | print("=> Getting retrieval set") 141 | retrieval_set = SafeImageFolder( 142 | args.retrieval_path, 143 | transform=preprocess, 144 | is_valid_file=lambda x: Path( 145 | x 146 | ).is_file(), # choose all subfiles, these get filtered later 147 | ) 148 | retrieval_loader = DataLoader( 149 | retrieval_set, batch_size=256, shuffle=False, num_workers=16 150 | ) 151 | print(f"---- Found {len(retrieval_set)} examples ----") 152 | 153 | tset_kwargs = dict(transform=preprocess) 154 | 155 | if args.dataset_type != "ImageFolder": 156 | tset_kwargs.update(**dict(split=args.split, download=True)) 157 | 158 | print("=> Getting transductive set") 159 | transductive_set = getattr(torchvision.datasets, args.dataset_type)( 160 | args.transductive_path, **tset_kwargs 161 | ) 162 | transductive_loader = DataLoader( 163 | transductive_set, batch_size=256, shuffle=False, num_workers=16 164 | ) 165 | print(f"---- Found {len(transductive_set)} examples ----") 166 | 167 | print("=> Renormalizing retrieval set labels") 168 | if args.dataset_type != "ImageFolder": 169 | import templates 170 | 171 | # Fixing the retrieval dataset 172 | class_list = [ 173 | c.replace("/", "or") for c in getattr(templates, args.dataset_type).classes 174 | ] 175 | class_to_idx = {cls: i for i, cls in enumerate(class_list)} 176 | else: 177 | class_to_idx = transductive_set.class_to_idx 178 | class_list = transductive_set.classes 179 | 180 | imgs = [] 181 | for path, label in retrieval_set.imgs: 182 | if retrieval_set.classes[label] not in class_to_idx: 183 | continue 184 | 185 | imgs.append((path, class_to_idx[retrieval_set.classes[label]])) 186 | 187 | retrieval_set.imgs = retrieval_set.samples = imgs 188 | retrieval_set.classes = class_list 189 | retrieval_set.class_to_idx = class_to_idx 190 | 191 | # Correcting datasets 192 | print("=> Filtering bad images from retrieval set") 193 | retrieval_set = filter_bad_images( 194 | retrieval_set, 195 | cache=args.cache_path / f"bad_{args.retrieval_path.stem}_images.npy", 196 | ) 197 | print(f"---- Now {len(retrieval_set)} examples ----") 198 | 199 | # Feature extraction 200 | print("=> Doing feature extraction") 201 | transductive_features = extract_features( 202 | transductive_loader, 203 | model=model, 204 | memmap_file=args.cache_path 205 | / f"cache_{args.transductive_path.stem}_features.npy", 206 | ) 207 | 208 | retrieval_features = extract_features( 209 | retrieval_loader, 210 | model=model, 211 | memmap_file=args.cache_path / f"cache_{args.retrieval_path.stem}_features.npy", 212 | ) 213 | 214 | # Applying k-shot filtering and clip filtering 215 | logit_max_probs = None 216 | if args.clip_filter: 217 | print("=> Performing clip filtering") 218 | retrieval_features, retrieval_set, logit_max_probs = clip_filter( 219 | model, 220 | retrieval_features=retrieval_features, 221 | retrieval_set=retrieval_set, 222 | class_prompt_dict=json.load(args.prompt_file.open("r")), 223 | clip_score_filter=args.clip_score_filter, 224 | ) 225 | print(f"=> Done clip filtering, {len(retrieval_set)} examples left") 226 | 227 | if args.k_shot is not None: 228 | print(f"=> Doing {args.k_shot}-shot filtering") 229 | if args.dataset_type != "ImageFolder": 230 | transductive_set, shot_indices = k_shot_generic( 231 | transductive_set, k=args.k_shot 232 | ) 233 | 234 | else: 235 | transductive_set, shot_indices = k_shot_imagefolder( 236 | transductive_set, k=args.k_shot 237 | ) 238 | 239 | transductive_features = transductive_features[shot_indices.astype(bool), :] 240 | 241 | # Computing sim matrix 242 | print("=> Computing batched inner products") 243 | sim_matrix = batched_inner_products( 244 | transductive_features, 245 | retrieval_features, 246 | batch_size=256, 247 | out=args.cache_path 248 | / f"cache_{args.retrieval_path.stem}_sim_shots={args.k_shot}_filter={args.clip_filter}_score={args.clip_score_filter}.npy", 249 | ) 250 | 251 | print("=> Getting closest retrievals") 252 | paths, paths_by_image = get_closest_retrievals( 253 | sim_matrix=sim_matrix, 254 | dataset=retrieval_set, 255 | transductive_set=transductive_set, 256 | k=args.retrievals_per_image, 257 | allow_image_labels=args.k_shot is not None, 258 | logit_max_probs=logit_max_probs, 259 | ) 260 | 261 | with open(args.cache_path / f"{args.retrieval_path.stem}_paths_by_image.pkl", "wb") as f: 262 | pickle.dump(paths_by_image, f) 263 | 264 | print(f"=> Copying images to output directory {args.out_dir}") 265 | 266 | for cname in class_list: 267 | (args.out_dir / cname).mkdir(exist_ok=True, parents=True) 268 | 269 | for path, label in tqdm.tqdm(paths): 270 | path = Path(path) 271 | 272 | class_dir = args.out_dir / class_list[label] 273 | out_path = class_dir / path.name 274 | shutil.copy(path, out_path) 275 | 276 | 277 | 278 | globals().update(locals()) 279 | 280 | 281 | def add_empty_folders(folder): 282 | folder = Path(folder) 283 | 284 | for subfolder in folder.iterdir(): 285 | subfolder.mkdir(exist_ok=True) 286 | 287 | raise NotImplementedError() 288 | 289 | 290 | @torch.no_grad() 291 | def clip_filter( 292 | model, retrieval_features, retrieval_set, class_prompt_dict, clip_score_filter=0.0 293 | ): 294 | zs_head = compute_zero_shot_head( 295 | model, class_prompt_dict, classnames=retrieval_set.classes, device=0 296 | ) 297 | 298 | imgs = [] 299 | acc = 0 300 | total = 0 301 | indices = [] 302 | logit_max_probs = [] 303 | for image_features, batch_imgs in tqdm.tqdm( 304 | zip( 305 | batchify(retrieval_features, batch_size=256), 306 | batchify(retrieval_set.imgs, batch_size=256), 307 | ), 308 | total=len(retrieval_set) // 256 + 1, 309 | desc="CLIP filtering", 310 | ): 311 | img_paths, labels = zip(*batch_imgs) 312 | image_feature = torch.tensor(image_features).to(0) 313 | logits = image_feature @ zs_head.T 314 | 315 | for logit, img_path, label in zip(logits, img_paths, labels): 316 | score = (100 * logit).squeeze().softmax(dim=-1)[label].item() 317 | if label == logit.squeeze().argmax().item() and score >= clip_score_filter: 318 | imgs.append((img_path, label)) 319 | acc += 1 320 | 321 | logit_max_probs.append(score) 322 | 323 | indices.append(total) 324 | 325 | total += 1 326 | 327 | retrieval_set.imgs = retrieval_set.samples = imgs 328 | 329 | print(f"Clip filter accuracy on retrieval set: {100*acc / total:0.2f})") 330 | 331 | return retrieval_features[np.array(indices)], retrieval_set, logit_max_probs 332 | 333 | 334 | def get_closest_retrievals( 335 | sim_matrix, 336 | dataset: SafeImageFolder, 337 | transductive_set: SafeImageFolder, 338 | k, 339 | allow_image_labels=False, 340 | logit_max_probs=None, 341 | ): 342 | if allow_image_labels: 343 | labels_to_indices = defaultdict(lambda: np.zeros(len(dataset), dtype=np.uint8)) 344 | 345 | for i, (_, label) in enumerate(dataset.imgs): 346 | labels_to_indices[label][i] = 1 347 | else: 348 | labels_to_indices = defaultdict(lambda: np.ones(len(dataset), dtype=np.uint8)) 349 | 350 | # sim matrix is transductive example size x retrieval set size 351 | outs = [] 352 | outs_by_image_path = {} 353 | 354 | print(sim_matrix.shape, len(transductive_set)) 355 | for i in tqdm.tqdm(range(sim_matrix.shape[0])): 356 | label = transductive_set[i][1] 357 | 358 | # since sims are between -1 and 1, we add 10 to make sure that the retrieval indices 359 | # are only from the relevant class 360 | retrieval_indices = np.argpartition( 361 | (sim_matrix[i] + 10) * labels_to_indices[label], -k 362 | )[-k:] 363 | 364 | paths = [dataset.imgs[j] for j in retrieval_indices] 365 | if logit_max_probs is not None: 366 | logit_prob = [logit_max_probs[j] for j in retrieval_indices] 367 | else: 368 | logit_prob = [-1.0] * len(paths) 369 | 370 | sim_scores = sim_matrix[i][retrieval_indices] 371 | 372 | if hasattr(transductive_set, "imgs"): 373 | outs_by_image_path[transductive_set.imgs[i]] = list( 374 | zip(*zip(*paths), logit_prob, sim_scores) 375 | ) 376 | 377 | outs.extend(paths) 378 | 379 | return set(list(outs)), outs_by_image_path 380 | 381 | 382 | def k_shot_imagefolder(dataset: SafeImageFolder, k=10): 383 | """Returns an ImageFolder dataset that only contains k images per class.""" 384 | old_imgs = dataset.imgs 385 | 386 | r.seed(0) 387 | classes_to_images = defaultdict(list) 388 | for i, (image, label) in enumerate(dataset.imgs): 389 | classes_to_images[label].append((image, label, i)) 390 | 391 | new_dataset = [] 392 | indices = np.zeros(len(old_imgs), dtype=np.uint8) 393 | for label, images in classes_to_images.items(): 394 | for image, _, i in r.sample(images, k): 395 | new_dataset.append((image, label)) 396 | indices[i] = 1 397 | 398 | dataset.imgs = dataset.samples = new_dataset 399 | 400 | return dataset, old_imgs, indices 401 | 402 | 403 | def k_shot_generic(dataset, k=10): 404 | # wayyyyy slower, fix later with dataloader 405 | r.seed(0) 406 | classes_to_images = defaultdict(list) 407 | 408 | for i, (image, label) in enumerate(dataset): 409 | classes_to_images[label].append(i) 410 | 411 | indices = np.zeros(len(dataset), dtype=np.uint8) 412 | for label, images in classes_to_images.items(): 413 | for i in r.sample(images, k): 414 | indices[i] = 1 415 | 416 | return SubsetDataset(dataset, np.where(indices == 1)), np.array(indices) 417 | 418 | 419 | def maybe_cache_output(fn, cache_path, **kwargs): 420 | cache_path = Path(cache_path) 421 | if cache_path.exists(): 422 | return torch.load(cache_path, map_location="cpu") 423 | else: 424 | os.makedirs(cache_path.parent, exist_ok=True) 425 | output = fn(**kwargs) 426 | torch.save(output, cache_path) 427 | 428 | return output 429 | 430 | 431 | def get_random_grid(image_list, transform, k=10): 432 | dataset = ImageListDataset(image_list, transform=transform) 433 | 434 | imgs = [] 435 | for i in range(len(dataset)): 436 | imgs.append(dataset[i][0]) 437 | 438 | imgs = torch.stack(imgs) 439 | grid = torchvision.utils.make_grid(imgs, nrow=int(math.sqrt(k)), normalize=True) 440 | 441 | return grid 442 | 443 | 444 | def get_grids_from_loader(loader, transductive_set, synsets): 445 | loader_iter = iter(loader) 446 | count = 0 447 | while True: 448 | batch, labels = next(loader_iter) 449 | texts = [synsets[transductive_set.classes[i]] for i in labels] 450 | 451 | k = int(math.sqrt(len(batch))) 452 | grid = torchvision.utils.make_grid(batch, nrow=k, normalize=True) 453 | pil_grid = add_text_to_grid(grid, texts, grid_width=k) 454 | pil_grid.save(f"test_grids/{count}.jpg") 455 | count += 1 456 | 457 | 458 | @torch.no_grad() 459 | def compute_zero_shot_head(model, class_prompts_dict, classnames, device=0): 460 | zero_shot_head = [] 461 | 462 | for i, classname in tqdm.tqdm( 463 | enumerate(classnames), total=len(classnames), desc="Computing zeroshot head" 464 | ): 465 | prompts = class_prompts_dict[classname] 466 | tokens = clip.tokenize(prompts).to(device) 467 | text_embeddings = model.encode_text(tokens).float() 468 | text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True) 469 | zero_shot_head.append(text_embeddings.mean(axis=0)) 470 | 471 | zero_shot_head = torch.stack(zero_shot_head, dim=0) 472 | zero_shot_head = zero_shot_head / zero_shot_head.norm(dim=-1, keepdim=True) 473 | 474 | return zero_shot_head 475 | 476 | 477 | def fine_tune_zero_shot_head( 478 | model, 479 | dataloader, 480 | val_loader, 481 | zero_shot_head, 482 | epochs, 483 | learning_rate, 484 | device="cuda:0", 485 | alpha=0.0, 486 | ): 487 | model.train() 488 | 489 | head = nn.Linear( 490 | in_features=zero_shot_head.shape[1], out_features=zero_shot_head.shape[0] 491 | ).to(device) 492 | 493 | zero_shot_head_copy = zero_shot_head.clone().cpu() 494 | head.weight.data = zero_shot_head.to(device).float() 495 | head.bias.data = torch.zeros_like(head.bias.data).to(device) 496 | 497 | criterion = torch.nn.CrossEntropyLoss() 498 | optimizer = optim.SGD( 499 | head.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-4 500 | ) 501 | 502 | # Add cosine annealing scheduler 503 | scheduler = optim.lr_scheduler.CosineAnnealingLR( 504 | optimizer, T_max=len(dataloader) * epochs, eta_min=0 505 | ) 506 | 507 | for epoch in range(epochs): 508 | for i, (images, labels) in enumerate(dataloader): 509 | images, labels = images.to(device), labels.to(device) 510 | 511 | optimizer.zero_grad() 512 | 513 | with torch.no_grad(): 514 | image_features = F.normalize( 515 | model.encode_image(images).float(), dim=1, p=2 516 | ) 517 | 518 | output = head(image_features) 519 | loss = criterion(output, labels) 520 | 521 | loss.backward() 522 | optimizer.step() 523 | 524 | # Update learning rate using the scheduler 525 | scheduler.step() 526 | 527 | print( 528 | f"""Epoch: {epoch+1}/{epochs}, 529 | Batch: {i+1}/{len(dataloader)}, 530 | Loss: {loss.item():0.4f}, 531 | LR: {scheduler.get_last_lr()[0]:.7f}""" 532 | ) 533 | 534 | # todo: ensemble head 535 | with torch.no_grad(): 536 | old_head_weight = head.weight.data.clone() 537 | old_head_bias = head.bias.data.clone() 538 | 539 | head.weight.data = ( 540 | alpha * zero_shot_head_copy.to(device).float() 541 | + (1 - alpha) * head.weight.data 542 | ) 543 | head.bias.data = head.bias.data * (1 - alpha) 544 | 545 | accuracy = compute_accuracy(model, head, val_loader, device) 546 | head.weight.data = old_head_weight 547 | head.bias.data = old_head_bias 548 | 549 | print(f"Epoch: {0}/{epochs}, Accuracy: {accuracy:.2f}%") 550 | 551 | return model 552 | 553 | 554 | @torch.no_grad() 555 | def compute_accuracy(model, head, dataloader, device): 556 | model.eval() 557 | correct = 0 558 | total = 0 559 | with torch.no_grad(): 560 | for images, labels in dataloader: 561 | images, labels = images.to(device), labels.to(device) 562 | image_features = F.normalize(model.encode_image(images).float(), dim=1, p=2) 563 | output = head(image_features) 564 | _, predicted = torch.max(output.data, 1) 565 | total += labels.size(0) 566 | correct += (predicted == labels).sum().item() 567 | model.train() 568 | return 100 * correct / total 569 | 570 | 571 | def image_grids_to_folders( 572 | paths_to_images, preprocess, transductive_set, retrieval_set, imagenet_to_classname 573 | ): 574 | save_path = Path("test_grids") 575 | save_path.mkdir(exist_ok=True) 576 | count = 0 577 | 578 | path_iter = iter(paths_to_images.items()) 579 | 580 | while True: 581 | count += 1 582 | (img, img_label), batch = next(path_iter) 583 | 584 | batch = [(img, img_label, 0.0, 0.0)] + batch 585 | 586 | grid = get_random_grid([(a, b) for a, b, _, _ in batch], preprocess) 587 | classnames = [ 588 | f"{imagenet_to_classname[transductive_set.classes[label]]} ({logit_prob:0.4f}) ({sim_score:0.4f}))" 589 | for _, label, logit_prob, sim_score in batch 590 | ] 591 | classnames[0] = f"GT: {imagenet_to_classname[retrieval_set.classes[img_label]]}" 592 | grid = add_text_to_grid(grid, classnames) 593 | 594 | (save_path / transductive_set.classes[img_label]).mkdir(exist_ok=True) 595 | 596 | grid.save(save_path / transductive_set.classes[img_label] / f"{count}.jpg") 597 | 598 | if count > 2000: 599 | break 600 | 601 | 602 | def add_text_to_grid(grid, text_list, grid_width=3): 603 | # convert the tensor grid to a PIL image 604 | pil_grid = torchvision.transforms.ToPILImage()(grid) 605 | 606 | # loop through the list of texts and draw each text on the corresponding image 607 | for i, text in enumerate(text_list): 608 | x = i % grid_width # calculate x coordinate based on index i 609 | y = i // grid_width # calculate y coordinate based on index i 610 | draw = ImageDraw.Draw(pil_grid) 611 | 612 | draw.text( 613 | (x * 224 + 5 * x, y * 224 + 5 * y), text 614 | ) # adjust the position of text to your liking 615 | 616 | return pil_grid 617 | 618 | 619 | def relabel_mismatch_classes( 620 | trans_set: SafeImageFolder, retrieval_set: SafeImageFolder 621 | ): 622 | # Assuming trans_set is the groundtruth, retrieval_set labels are recomputed wrt trans_set 623 | imgs = [] 624 | for path, label in trans_set.imgs: 625 | class_name = trans_set.classes[label] 626 | if class_name in retrieval_set.class_to_idx: 627 | imgs.append((path, retrieval_set.class_to_idx[class_name])) 628 | 629 | trans_set.classes = retrieval_set.classes 630 | trans_set.class_to_idx = retrieval_set.class_to_idx 631 | trans_set.imgs = trans_set.samples = imgs 632 | 633 | return trans_set 634 | 635 | 636 | def filter_bad_images(dataset: SafeImageFolder, cache=None): 637 | cache = Path(cache) 638 | 639 | if cache.exists(): 640 | preload_cache = np.load(cache) 641 | else: 642 | preload_cache = np.ones(len(dataset)) 643 | 644 | loader = DataLoader( 645 | BadImageCheckerDataset(dataset), 646 | batch_size=256, 647 | shuffle=False, 648 | num_workers=16, 649 | ) 650 | pointer = 0 651 | 652 | for check in tqdm.tqdm(loader, desc="Filtering bad images"): 653 | preload_cache[pointer : pointer + len(check)] = check.float().numpy() 654 | pointer += len(check) 655 | 656 | dataset.imgs = [dataset.imgs[i] for i in np.where(preload_cache == 1)[0]] 657 | dataset.samples = dataset.imgs 658 | 659 | if not cache.exists(): 660 | np.save(cache, preload_cache) 661 | 662 | return dataset 663 | 664 | 665 | def batched_inner_products(m1, m2, out, batch_size=256): 666 | if os.path.exists(out): 667 | return np.load(out, mmap_mode="r") 668 | 669 | feature_vectors = open_memmap( 670 | out, 671 | dtype="float32", 672 | mode="w+", 673 | shape=(m1.shape[0], m2.shape[0]), 674 | ) 675 | 676 | with torch.no_grad(): 677 | count_m1 = 0 678 | 679 | for batch_m1 in tqdm.tqdm( 680 | batchify(m1, batch_size=batch_size), total=m1.shape[0] // batch_size + 1 681 | ): 682 | count_m2 = 0 683 | 684 | for batch_m2 in batchify(m2, batch_size=batch_size): 685 | inner_product = torch.einsum( 686 | "ij,kj->ik", 687 | torch.tensor(batch_m1).cuda(), 688 | torch.tensor(batch_m2).cuda(), 689 | ) 690 | 691 | feature_vectors[ 692 | count_m1 : count_m1 + batch_m1.shape[0], 693 | count_m2 : count_m2 + batch_m2.shape[0], 694 | ] = inner_product.cpu().numpy() 695 | count_m2 += batch_m2.shape[0] 696 | 697 | count_m1 += batch_m1.shape[0] 698 | 699 | return feature_vectors 700 | 701 | 702 | def extract_features(loader, model, memmap_file): 703 | if os.path.exists(memmap_file): 704 | return np.load(memmap_file, mmap_mode="r") 705 | 706 | # Create a numpy memmap to store the feature vectors 707 | feature_size = model.visual.output_dim 708 | feature_vectors = open_memmap( 709 | memmap_file, 710 | dtype="float32", 711 | mode="w+", 712 | shape=(len(loader.dataset), feature_size), 713 | ) 714 | 715 | # Set the model to evaluation mode 716 | model.eval() 717 | 718 | # Iterate through the images and extract the feature vectors 719 | count = 0 720 | with torch.no_grad(): 721 | for i, batch in tqdm.tqdm( 722 | enumerate(loader), total=len(loader), ascii=True, desc="feature extraction" 723 | ): 724 | # Preprocess the image 725 | images = batch[0] 726 | images = images.to(0) 727 | 728 | # Pass the image through the model to get the feature vector 729 | feature_vector = ( 730 | F.normalize(model.encode_image(images), p=2, dim=1).cpu().numpy() 731 | ) 732 | 733 | # Store the feature vector in the memmap 734 | feature_vectors[count : count + len(images)] = feature_vector 735 | count += len(images) 736 | 737 | return feature_vectors 738 | 739 | 740 | def batchify(iterable, batch_size=256): 741 | num_batches = math.ceil(len(iterable) / batch_size) 742 | 743 | for i in range(num_batches): 744 | yield iterable[i * batch_size : i * batch_size + batch_size] 745 | 746 | 747 | class ImageListDataset(Dataset): 748 | def __init__(self, imgs, transform=None) -> None: 749 | self.imgs = imgs 750 | self.transform = transform 751 | 752 | def __len__(self): 753 | return len(self.imgs) 754 | 755 | def __getitem__(self, index): 756 | img_path, label = self.imgs[index] 757 | img = default_loader(img_path) 758 | 759 | if self.transform is not None: 760 | img = self.transform(img) 761 | 762 | return img, label 763 | 764 | 765 | class SubsetDataset(Dataset): 766 | def __init__(self, dataset, indices): 767 | super().__init__() 768 | self.indices = indices 769 | self.dataset = dataset 770 | 771 | def __len__(self): 772 | return len(self.indices) 773 | 774 | def __getitem__(self, index): 775 | return self.dataset[self.indices[index]] 776 | 777 | 778 | class BadImageCheckerDataset(Dataset): 779 | def __init__(self, dataset): 780 | super().__init__() 781 | self.dataset = dataset 782 | 783 | def __getitem__(self, index): 784 | try: 785 | entry = self.dataset[index] 786 | return True 787 | except: 788 | return False 789 | 790 | def __len__(self): 791 | return len(self.dataset) 792 | 793 | 794 | if __name__ == "__main__": 795 | main() 796 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 RAIVN Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Priming/Templates/FGVCAircraft.py: -------------------------------------------------------------------------------- 1 | classes = [ 2 | '707-320', 3 | '727-200', 4 | '737-200', 5 | '737-300', 6 | '737-400', 7 | '737-500', 8 | '737-600', 9 | '737-700', 10 | '737-800', 11 | '737-900', 12 | '747-100', 13 | '747-200', 14 | '747-300', 15 | '747-400', 16 | '757-200', 17 | '757-300', 18 | '767-200', 19 | '767-300', 20 | '767-400', 21 | '777-200', 22 | '777-300', 23 | 'A300B4', 24 | 'A310', 25 | 'A318', 26 | 'A319', 27 | 'A320', 28 | 'A321', 29 | 'A330-200', 30 | 'A330-300', 31 | 'A340-200', 32 | 'A340-300', 33 | 'A340-500', 34 | 'A340-600', 35 | 'A380', 36 | 'ATR-42', 37 | 'ATR-72', 38 | 'An-12', 39 | 'BAE 146-200', 40 | 'BAE 146-300', 41 | 'BAE-125', 42 | 'Beechcraft 1900', 43 | 'Boeing 717', 44 | 'C-130', 45 | 'C-47', 46 | 'CRJ-200', 47 | 'CRJ-700', 48 | 'CRJ-900', 49 | 'Cessna 172', 50 | 'Cessna 208', 51 | 'Cessna 525', 52 | 'Cessna 560', 53 | 'Challenger 600', 54 | 'DC-10', 55 | 'DC-3', 56 | 'DC-6', 57 | 'DC-8', 58 | 'DC-9-30', 59 | 'DH-82', 60 | 'DHC-1', 61 | 'DHC-6', 62 | 'DHC-8-100', 63 | 'DHC-8-300', 64 | 'DR-400', 65 | 'Dornier 328', 66 | 'E-170', 67 | 'E-190', 68 | 'E-195', 69 | 'EMB-120', 70 | 'ERJ 135', 71 | 'ERJ 145', 72 | 'Embraer Legacy 600', 73 | 'Eurofighter Typhoon', 74 | 'F-16A_B', 75 | 'F_A-18', 76 | 'Falcon 2000', 77 | 'Falcon 900', 78 | 'Fokker 100', 79 | 'Fokker 50', 80 | 'Fokker 70', 81 | 'Global Express', 82 | 'Gulfstream IV', 83 | 'Gulfstream V', 84 | 'Hawk T1', 85 | 'Il-76', 86 | 'L-1011', 87 | 'MD-11', 88 | 'MD-80', 89 | 'MD-87', 90 | 'MD-90', 91 | 'Metroliner', 92 | 'Model B200', 93 | 'PA-28', 94 | 'SR-20', 95 | 'Saab 2000', 96 | 'Saab 340', 97 | 'Spitfire', 98 | 'Tornado', 99 | 'Tu-134', 100 | 'Tu-154', 101 | 'Yak-42', 102 | ] 103 | 104 | templates = [ 105 | 'a photo of a {}, a type of aircraft.', 106 | 'a photo of the {}, a type of aircraft.', 107 | ] -------------------------------------------------------------------------------- /Priming/Templates/FGVC_Aircraft.txt: -------------------------------------------------------------------------------- 1 | 707-320 2 | 727-200 3 | 737-200 4 | 737-300 5 | 737-400 6 | 737-500 7 | 737-600 8 | 737-700 9 | 737-800 10 | 737-900 11 | 747-100 12 | 747-200 13 | 747-300 14 | 747-400 15 | 757-200 16 | 757-300 17 | 767-200 18 | 767-300 19 | 767-400 20 | 777-200 21 | 777-300 22 | A300B4 23 | A310 24 | A318 25 | A319 26 | A320 27 | A321 28 | A330-200 29 | A330-300 30 | A340-200 31 | A340-300 32 | A340-500 33 | A340-600 34 | A380 35 | ATR-42 36 | ATR-72 37 | An-12 38 | BAE 146-200 39 | BAE 146-300 40 | BAE-125 41 | Beechcraft 1900 42 | Boeing 717 43 | C-130 44 | C-47 45 | CRJ-200 46 | CRJ-700 47 | CRJ-900 48 | Cessna 172 49 | Cessna 208 50 | Cessna 525 51 | Cessna 560 52 | Challenger 600 53 | DC-10 54 | DC-3 55 | DC-6 56 | DC-8 57 | DC-9-30 58 | DH-82 59 | DHC-1 60 | DHC-6 61 | DHC-8-100 62 | DHC-8-300 63 | DR-400 64 | Dornier 328 65 | E-170 66 | E-190 67 | E-195 68 | EMB-120 69 | ERJ 135 70 | ERJ 145 71 | Embraer Legacy 600 72 | Eurofighter Typhoon 73 | F-16A/B 74 | F/A-18 75 | Falcon 2000 76 | Falcon 900 77 | Fokker 100 78 | Fokker 50 79 | Fokker 70 80 | Global Express 81 | Gulfstream IV 82 | Gulfstream V 83 | Hawk T1 84 | Il-76 85 | L-1011 86 | MD-11 87 | MD-80 88 | MD-87 89 | MD-90 90 | Metroliner 91 | Model B200 92 | PA-28 93 | SR-20 94 | Saab 2000 95 | Saab 340 96 | Spitfire 97 | Tornado 98 | Tu-134 99 | Tu-154 100 | Yak-42 101 | -------------------------------------------------------------------------------- /Priming/Templates/Flowers102.py: -------------------------------------------------------------------------------- 1 | classes = [ 2 | 'pink primrose', 3 | 'hard-leaved pocket orchid', 4 | 'canterbury bells', 5 | 'sweet pea', 6 | 'english marigold', 7 | 'tiger lily', 8 | 'moon orchid', 9 | 'bird of paradise', 10 | 'monkshood', 11 | 'globe thistle', 12 | 'snapdragon', 13 | "colt's foot", 14 | 'king protea', 15 | 'spear thistle', 16 | 'yellow iris', 17 | 'globe flower', 18 | 'purple coneflower', 19 | 'peruvian lily', 20 | 'balloon flower', 21 | 'giant white arum lily', 22 | 'fire lily', 23 | 'pincushion flower', 24 | 'fritillary', 25 | 'red ginger', 26 | 'grape hyacinth', 27 | 'corn poppy', 28 | 'prince of wales feathers', 29 | 'stemless gentian', 30 | 'artichoke', 31 | 'sweet william', 32 | 'carnation', 33 | 'garden phlox', 34 | 'love in the mist', 35 | 'mexican aster', 36 | 'alpine sea holly', 37 | 'ruby-lipped cattleya', 38 | 'cape flower', 39 | 'great masterwort', 40 | 'siam tulip', 41 | 'lenten rose', 42 | 'barbeton daisy', 43 | 'daffodil', 44 | 'sword lily', 45 | 'poinsettia', 46 | 'bolero deep blue', 47 | 'wallflower', 48 | 'marigold', 49 | 'buttercup', 50 | 'oxeye daisy', 51 | 'common dandelion', 52 | 'petunia', 53 | 'wild pansy', 54 | 'primula', 55 | 'sunflower', 56 | 'pelargonium', 57 | 'bishop of llandaff', 58 | 'gaura', 59 | 'geranium', 60 | 'orange dahlia', 61 | 'pink and yellow dahlia', 62 | 'cautleya spicata', 63 | 'japanese anemone', 64 | 'black-eyed susan', 65 | 'silverbush', 66 | 'californian poppy', 67 | 'osteospermum', 68 | 'spring crocus', 69 | 'bearded iris', 70 | 'windflower', 71 | 'tree poppy', 72 | 'gazania', 73 | 'azalea', 74 | 'water lily', 75 | 'rose', 76 | 'thorn apple', 77 | 'morning glory', 78 | 'passion flower', 79 | 'lotus', 80 | 'toad lily', 81 | 'anthurium', 82 | 'frangipani', 83 | 'clematis', 84 | 'hibiscus', 85 | 'columbine', 86 | 'desert-rose', 87 | 'tree mallow', 88 | 'magnolia', 89 | 'cyclamen', 90 | 'watercress', 91 | 'canna lily', 92 | 'hippeastrum', 93 | 'bee balm', 94 | 'air plant', 95 | 'foxglove', 96 | 'bougainvillea', 97 | 'camellia', 98 | 'mallow', 99 | 'mexican petunia', 100 | 'bromelia', 101 | 'blanket flower', 102 | 'trumpet creeper', 103 | 'blackberry lily', 104 | ] 105 | 106 | templates = [ 107 | 'a photo of a {}, a type of flower.', 108 | ] 109 | 110 | translate = dict({'wild pansy': 'pansy flower', 'sunflower': 'sun flower', 'poinsettia': 'pointsettia', 'black-eyed susan': 'black eyed susan', 111 | 'californian poppy': 'california poppy', 'camellia':'camelia', 'desert-rose': 'desert rose', 'cape flower': 'japanese spider lily'}) 112 | classes = [translate[x] if x in translate.keys() else x for x in classes] -------------------------------------------------------------------------------- /Priming/Templates/Flowers102.txt: -------------------------------------------------------------------------------- 1 | pink primrose 2 | hard-leaved pocket orchid 3 | canterbury bells 4 | sweet pea 5 | english marigold 6 | tiger lily 7 | moon orchid 8 | bird of paradise 9 | monkshood 10 | globe thistle 11 | snapdragon 12 | colt's foot 13 | king protea 14 | spear thistle 15 | yellow iris 16 | globe flower 17 | purple coneflower 18 | peruvian lily 19 | balloon flower 20 | giant white arum lily 21 | fire lily 22 | pincushion flower 23 | fritillary 24 | red ginger 25 | grape hyacinth 26 | corn poppy 27 | prince of wales feathers 28 | stemless gentian 29 | artichoke 30 | sweet william 31 | carnation 32 | garden phlox 33 | love in the mist 34 | mexican aster 35 | alpine sea holly 36 | ruby-lipped cattleya 37 | red spider lily 38 | great masterwort 39 | siam tulip 40 | lenten rose 41 | barbeton daisy 42 | daffodil 43 | sword lily 44 | pointsettia 45 | bolero deep blue 46 | wallflower 47 | marigold 48 | buttercup 49 | oxeye daisy 50 | common dandelion 51 | petunia 52 | pansy flower 53 | primula 54 | sun flower 55 | pelargonium 56 | bishop of llandaff 57 | gaura 58 | geranium 59 | orange dahlia 60 | pink and yellow dahlia 61 | cautleya spicata 62 | japanese anemone 63 | black eyed susan 64 | silverbush 65 | california poppy 66 | osteospermum 67 | spring crocus 68 | bearded iris 69 | windflower 70 | tree poppy 71 | gazania 72 | azalea 73 | water lily 74 | rose 75 | thorn apple 76 | morning glory 77 | passion flower 78 | lotus 79 | toad lily 80 | anthurium 81 | frangipani 82 | clematis 83 | hibiscus 84 | columbine 85 | desert rose 86 | tree mallow 87 | magnolia 88 | cyclamen 89 | watercress 90 | canna lily 91 | hippeastrum 92 | bee balm 93 | air plant 94 | foxglove 95 | bougainvillea 96 | camelia 97 | mallow 98 | mexican petunia 99 | bromelia 100 | blanket flower 101 | trumpet creeper 102 | blackberry lily 103 | -------------------------------------------------------------------------------- /Priming/Templates/Food101.py: -------------------------------------------------------------------------------- 1 | classes = [ 2 | 'apple pie', 3 | 'baby back ribs', 4 | 'baklava', 5 | 'beef carpaccio', 6 | 'beef tartare', 7 | 'beet salad', 8 | 'beignets', 9 | 'bibimbap', 10 | 'bread pudding', 11 | 'breakfast burrito', 12 | 'bruschetta', 13 | 'caesar salad', 14 | 'cannoli', 15 | 'caprese salad', 16 | 'carrot cake', 17 | 'ceviche', 18 | 'cheese plate', 19 | 'cheesecake', 20 | 'chicken curry', 21 | 'chicken quesadilla', 22 | 'chicken wings', 23 | 'chocolate cake', 24 | 'chocolate mousse', 25 | 'churros', 26 | 'clam chowder', 27 | 'club sandwich', 28 | 'crab cakes', 29 | 'creme brulee', 30 | 'croque madame', 31 | 'cup cakes', 32 | 'deviled eggs', 33 | 'donuts', 34 | 'dumplings', 35 | 'edamame', 36 | 'eggs benedict', 37 | 'escargots', 38 | 'falafel', 39 | 'filet mignon', 40 | 'fish and chips', 41 | 'foie gras', 42 | 'french fries', 43 | 'french onion soup', 44 | 'french toast', 45 | 'fried calamari', 46 | 'fried rice', 47 | 'frozen yogurt', 48 | 'garlic bread', 49 | 'gnocchi', 50 | 'greek salad', 51 | 'grilled cheese sandwich', 52 | 'grilled salmon', 53 | 'guacamole', 54 | 'gyoza', 55 | 'hamburger', 56 | 'hot and sour soup', 57 | 'hot dog', 58 | 'huevos rancheros', 59 | 'hummus', 60 | 'ice cream', 61 | 'lasagna', 62 | 'lobster bisque', 63 | 'lobster roll sandwich', 64 | 'macaroni and cheese', 65 | 'macarons', 66 | 'miso soup', 67 | 'mussels', 68 | 'nachos', 69 | 'omelette', 70 | 'onion rings', 71 | 'oysters', 72 | 'pad thai', 73 | 'paella', 74 | 'pancakes', 75 | 'panna cotta', 76 | 'peking duck', 77 | 'pho', 78 | 'pizza', 79 | 'pork chop', 80 | 'poutine', 81 | 'prime rib', 82 | 'pulled pork sandwich', 83 | 'ramen', 84 | 'ravioli', 85 | 'red velvet cake', 86 | 'risotto', 87 | 'samosa', 88 | 'sashimi', 89 | 'scallops', 90 | 'seaweed salad', 91 | 'shrimp and grits', 92 | 'spaghetti bolognese', 93 | 'spaghetti carbonara', 94 | 'spring rolls', 95 | 'steak', 96 | 'strawberry shortcake', 97 | 'sushi', 98 | 'tacos', 99 | 'takoyaki', 100 | 'tiramisu', 101 | 'tuna tartare', 102 | 'waffles', 103 | ] 104 | 105 | templates = [ 106 | 'a photo of {}, a type of food.', 107 | ] -------------------------------------------------------------------------------- /Priming/Templates/Food101.txt: -------------------------------------------------------------------------------- 1 | apple pie 2 | baby back ribs 3 | baklava 4 | beef carpaccio 5 | beef tartare 6 | beet salad 7 | beignets 8 | bibimbap 9 | bread pudding 10 | breakfast burrito 11 | bruschetta 12 | caesar salad 13 | cannoli 14 | caprese salad 15 | carrot cake 16 | ceviche 17 | cheese plate 18 | cheesecake 19 | chicken curry 20 | chicken quesadilla 21 | chicken wings 22 | chocolate cake 23 | chocolate mousse 24 | churros 25 | clam chowder 26 | club sandwich 27 | crab cakes 28 | creme brulee 29 | croque madame 30 | cup cakes 31 | deviled eggs 32 | donuts 33 | dumplings 34 | edamame 35 | eggs benedict 36 | escargots 37 | falafel 38 | filet mignon 39 | fish and chips 40 | foie gras 41 | french fries 42 | french onion soup 43 | french toast 44 | fried calamari 45 | fried rice 46 | frozen yogurt 47 | garlic bread 48 | gnocchi 49 | greek salad 50 | grilled cheese sandwich 51 | grilled salmon 52 | guacamole 53 | gyoza 54 | hamburger 55 | hot and sour soup 56 | hot dog 57 | huevos rancheros 58 | hummus 59 | ice cream 60 | lasagna 61 | lobster bisque 62 | lobster roll sandwich 63 | macaroni and cheese 64 | macarons 65 | miso soup 66 | mussels 67 | nachos 68 | omelette 69 | onion rings 70 | oysters 71 | pad thai 72 | paella 73 | pancakes 74 | panna cotta 75 | peking duck 76 | pho 77 | pizza 78 | pork chop 79 | poutine 80 | prime rib 81 | pulled pork sandwich 82 | ramen 83 | ravioli 84 | red velvet cake 85 | risotto 86 | samosa 87 | sashimi 88 | scallops 89 | seaweed salad 90 | shrimp and grits 91 | spaghetti bolognese 92 | spaghetti carbonara 93 | spring rolls 94 | steak 95 | strawberry shortcake 96 | sushi 97 | tacos 98 | takoyaki 99 | tiramisu 100 | tuna tartare 101 | waffles 102 | -------------------------------------------------------------------------------- /Priming/Templates/ImageNet.py: -------------------------------------------------------------------------------- 1 | #Subset of the 80 Open AI templates that approximates the full set. 2 | templates = [ 3 | 'itap of a {}.', 4 | 'a bad photo of the {}.', 5 | 'a origami {}.', 6 | 'a photo of the large {}.', 7 | 'a {} in a video game.', 8 | 'art of the {}.', 9 | 'a photo of the small {}'] 10 | 11 | templates_full = [ 12 | 'a bad photo of a {}.', 13 | 'a photo of many {}.', 14 | 'a sculpture of a {}.', 15 | 'a photo of the hard to see {}.', 16 | 'a low resolution photo of the {}.', 17 | 'a rendering of a {}.', 18 | 'graffiti of a {}.', 19 | 'a bad photo of the {}.', 20 | 'a cropped photo of the {}.', 21 | 'a tattoo of a {}.', 22 | 'the embroidered {}.', 23 | 'a photo of a hard to see {}.', 24 | 'a bright photo of a {}.', 25 | 'a photo of a clean {}.', 26 | 'a photo of a dirty {}.', 27 | 'a dark photo of the {}.', 28 | 'a drawing of a {}.', 29 | 'a photo of my {}.', 30 | 'the plastic {}.', 31 | 'a photo of the cool {}.', 32 | 'a close-up photo of a {}.', 33 | 'a black and white photo of the {}.', 34 | 'a painting of the {}.', 35 | 'a painting of a {}.', 36 | 'a pixelated photo of the {}.', 37 | 'a sculpture of the {}.', 38 | 'a bright photo of the {}.', 39 | 'a cropped photo of a {}.', 40 | 'a plastic {}.', 41 | 'a photo of the dirty {}.', 42 | 'a jpeg corrupted photo of a {}.', 43 | 'a blurry photo of the {}.', 44 | 'a photo of the {}.', 45 | 'a good photo of the {}.', 46 | 'a rendering of the {}.', 47 | 'a {} in a video game.', 48 | 'a photo of one {}.', 49 | 'a doodle of a {}.', 50 | 'a close-up photo of the {}.', 51 | 'a photo of a {}.', 52 | 'the origami {}.', 53 | 'the {} in a video game.', 54 | 'a sketch of a {}.', 55 | 'a doodle of the {}.', 56 | 'a origami {}.', 57 | 'a low resolution photo of a {}.', 58 | 'the toy {}.', 59 | 'a rendition of the {}.', 60 | 'a photo of the clean {}.', 61 | 'a photo of a large {}.', 62 | 'a rendition of a {}.', 63 | 'a photo of a nice {}.', 64 | 'a photo of a weird {}.', 65 | 'a blurry photo of a {}.', 66 | 'a cartoon {}.', 67 | 'art of a {}.', 68 | 'a sketch of the {}.', 69 | 'a embroidered {}.', 70 | 'a pixelated photo of a {}.', 71 | 'itap of the {}.', 72 | 'a jpeg corrupted photo of the {}.', 73 | 'a good photo of a {}.', 74 | 'a plushie {}.', 75 | 'a photo of the nice {}.', 76 | 'a photo of the small {}.', 77 | 'a photo of the weird {}.', 78 | 'the cartoon {}.', 79 | 'art of the {}.', 80 | 'a drawing of the {}.', 81 | 'a photo of the large {}.', 82 | 'a black and white photo of a {}.', 83 | 'the plushie {}.', 84 | 'a dark photo of a {}.', 85 | 'itap of a {}.', 86 | 'graffiti of the {}.', 87 | 'a toy {}.', 88 | 'itap of my {}.', 89 | 'a photo of a cool {}.', 90 | 'a photo of a small {}.', 91 | 'a tattoo of the {}.', 92 | ] 93 | -------------------------------------------------------------------------------- /Priming/Templates/OxfordIIITPet.py: -------------------------------------------------------------------------------- 1 | classes = ['Abyssinian', 2 | 'American Bulldog', 3 | 'American Pit Bull Terrier', 4 | 'Basset Hound', 5 | 'Beagle', 6 | 'Bengal', 7 | 'Birman', 8 | 'Bombay', 9 | 'Boxer', 10 | 'British Shorthair', 11 | 'Chihuahua', 12 | 'Egyptian Mau', 13 | 'English Cocker Spaniel', 14 | 'English Setter', 15 | 'German Shorthaired', 16 | 'Great Pyrenees', 17 | 'Havanese', 18 | 'Japanese Chin', 19 | 'Keeshond', 20 | 'Leonberger', 21 | 'Maine Coon', 22 | 'Miniature Pinscher', 'Newfoundland', 'Persian', 'Pomeranian', 'Pug', 'Ragdoll', 'Russian Blue', 'Saint Bernard', 'Samoyed', 'Scottish Terrier', 'Shiba Inu', 'Siamese', 'Sphynx', 'Staffordshire Bull Terrier', 'Wheaten Terrier', 'Yorkshire Terrier'] 23 | 24 | 25 | templates = [ 26 | 'a photo of a {}, a type of pet.', 27 | ] 28 | -------------------------------------------------------------------------------- /Priming/Templates/OxfordPets.txt: -------------------------------------------------------------------------------- 1 | Abyssinian 2 | Bengal 3 | Birman 4 | Bombay 5 | British Shorthair 6 | Egyptian Mau 7 | Maine Coon 8 | Persian 9 | Ragdoll 10 | Russian Blue 11 | Siamese 12 | Sphynx 13 | american bulldog 14 | american pit bull terrier 15 | basset hound 16 | beagle 17 | boxer 18 | chihuahua 19 | english cocker spaniel 20 | english setter 21 | german shorthaired 22 | great pyrenees 23 | havanese 24 | japanese chin 25 | keeshond 26 | leonberger 27 | miniature pinscher 28 | newfoundland 29 | pomeranian 30 | pug 31 | saint bernard 32 | samoyed 33 | scottish terrier 34 | shiba inu 35 | staffordshire bull terrier 36 | wheaten terrier 37 | yorkshire terrier 38 | -------------------------------------------------------------------------------- /Priming/Templates/SUN397.py: -------------------------------------------------------------------------------- 1 | classes = [ 2 | 'abbey', 3 | 'airplane cabin', 4 | 'airport terminal', 5 | 'alley', 6 | 'amphitheater', 7 | 'amusement arcade', 8 | 'amusement park', 9 | 'anechoic chamber', 10 | 'apartment building outdoor', 11 | 'apse indoor', 12 | 'aquarium', 13 | 'aqueduct', 14 | 'arch', 15 | 'archive', 16 | 'arrival gate outdoor', 17 | 'art gallery', 18 | 'art school', 19 | 'art studio', 20 | 'assembly line', 21 | 'athletic field outdoor', 22 | 'atrium public', 23 | 'attic', 24 | 'auditorium', 25 | 'auto factory', 26 | 'badlands', 27 | 'badminton court indoor', 28 | 'baggage claim', 29 | 'bakery shop', 30 | 'balcony exterior', 31 | 'balcony interior', 32 | 'ball pit', 33 | 'ballroom', 34 | 'bamboo forest', 35 | 'banquet hall', 36 | 'bar', 37 | 'barn', 38 | 'barndoor', 39 | 'baseball field', 40 | 'basement', 41 | 'basilica', 42 | 'basketball court outdoor', 43 | 'bathroom', 44 | 'batters box', 45 | 'bayou', 46 | 'bazaar indoor', 47 | 'bazaar outdoor', 48 | 'beach', 49 | 'beauty salon', 50 | 'bedroom', 51 | 'berth', 52 | 'biology laboratory', 53 | 'bistro indoor', 54 | 'boardwalk', 55 | 'boat deck', 56 | 'boathouse', 57 | 'bookstore', 58 | 'booth indoor', 59 | 'botanical garden', 60 | 'bow window indoor', 61 | 'bow window outdoor', 62 | 'bowling alley', 63 | 'boxing ring', 64 | 'brewery indoor', 65 | 'bridge', 66 | 'building facade', 67 | 'bullring', 68 | 'burial chamber', 69 | 'bus interior', 70 | 'butchers shop', 71 | 'butte', 72 | 'cabin outdoor', 73 | 'cafeteria', 74 | 'campsite', 75 | 'campus', 76 | 'canal natural', 77 | 'canal urban', 78 | 'candy store', 79 | 'canyon', 80 | 'car interior backseat', 81 | 'car interior frontseat', 82 | 'carrousel', 83 | 'casino indoor', 84 | 'castle', 85 | 'catacomb', 86 | 'cathedral indoor', 87 | 'cathedral outdoor', 88 | 'cavern indoor', 89 | 'cemetery', 90 | 'chalet', 91 | 'cheese factory', 92 | 'chemistry lab', 93 | 'chicken coop indoor', 94 | 'chicken coop outdoor', 95 | 'childs room', 96 | 'church indoor', 97 | 'church outdoor', 98 | 'classroom', 99 | 'clean room', 100 | 'cliff', 101 | 'cloister indoor', 102 | 'closet', 103 | 'clothing store', 104 | 'coast', 105 | 'cockpit', 106 | 'coffee shop', 107 | 'computer room', 108 | 'conference center', 109 | 'conference room', 110 | 'construction site', 111 | 'control room', 112 | 'control tower outdoor', 113 | 'corn field', 114 | 'corral', 115 | 'corridor', 116 | 'cottage garden', 117 | 'courthouse', 118 | 'courtroom', 119 | 'courtyard', 120 | 'covered bridge exterior', 121 | 'creek', 122 | 'crevasse', 123 | 'crosswalk', 124 | 'cubicle office', 125 | 'dam', 126 | 'delicatessen', 127 | 'dentists office', 128 | 'desert sand', 129 | 'desert vegetation', 130 | 'diner indoor', 131 | 'diner outdoor', 132 | 'dinette home', 133 | 'dinette vehicle', 134 | 'dining car', 135 | 'dining room', 136 | 'discotheque', 137 | 'dock', 138 | 'doorway outdoor', 139 | 'dorm room', 140 | 'driveway', 141 | 'driving range outdoor', 142 | 'drugstore', 143 | 'electrical substation', 144 | 'elevator door', 145 | 'elevator interior', 146 | 'elevator shaft', 147 | 'engine room', 148 | 'escalator indoor', 149 | 'excavation', 150 | 'factory indoor', 151 | 'fairway', 152 | 'fastfood restaurant', 153 | 'field cultivated', 154 | 'field wild', 155 | 'fire escape', 156 | 'fire station', 157 | 'firing range indoor', 158 | 'fishpond', 159 | 'florist shop indoor', 160 | 'food court', 161 | 'forest broadleaf', 162 | 'forest needleleaf', 163 | 'forest path', 164 | 'forest road', 165 | 'formal garden', 166 | 'fountain', 167 | 'galley', 168 | 'game room', 169 | 'garage indoor', 170 | 'garbage dump', 171 | 'gas station', 172 | 'gazebo exterior', 173 | 'general store indoor', 174 | 'general store outdoor', 175 | 'gift shop', 176 | 'golf course', 177 | 'greenhouse indoor', 178 | 'greenhouse outdoor', 179 | 'gymnasium indoor', 180 | 'hangar indoor', 181 | 'hangar outdoor', 182 | 'harbor', 183 | 'hayfield', 184 | 'heliport', 185 | 'herb garden', 186 | 'highway', 187 | 'hill', 188 | 'home office', 189 | 'hospital', 190 | 'hospital room', 191 | 'hot spring', 192 | 'hot tub outdoor', 193 | 'hotel outdoor', 194 | 'hotel room', 195 | 'house', 196 | 'hunting lodge outdoor', 197 | 'ice cream parlor', 198 | 'ice floe', 199 | 'ice shelf', 200 | 'ice skating rink indoor', 201 | 'ice skating rink outdoor', 202 | 'iceberg', 203 | 'igloo', 204 | 'industrial area', 205 | 'inn outdoor', 206 | 'islet', 207 | 'jacuzzi indoor', 208 | 'jail cell', 209 | 'jail indoor', 210 | 'jewelry shop', 211 | 'kasbah', 212 | 'kennel indoor', 213 | 'kennel outdoor', 214 | 'kindergarden classroom', 215 | 'kitchen', 216 | 'kitchenette', 217 | 'labyrinth outdoor', 218 | 'lake natural', 219 | 'landfill', 220 | 'landing deck', 221 | 'laundromat', 222 | 'lecture room', 223 | 'library indoor', 224 | 'library outdoor', 225 | 'lido deck outdoor', 226 | 'lift bridge', 227 | 'lighthouse', 228 | 'limousine interior', 229 | 'living room', 230 | 'lobby', 231 | 'lock chamber', 232 | 'locker room', 233 | 'mansion', 234 | 'manufactured home', 235 | 'market indoor', 236 | 'market outdoor', 237 | 'marsh', 238 | 'martial arts gym', 239 | 'mausoleum', 240 | 'medina', 241 | 'moat water', 242 | 'monastery outdoor', 243 | 'mosque indoor', 244 | 'mosque outdoor', 245 | 'motel', 246 | 'mountain', 247 | 'mountain snowy', 248 | 'movie theater indoor', 249 | 'museum indoor', 250 | 'music store', 251 | 'music studio', 252 | 'nuclear power plant outdoor', 253 | 'nursery', 254 | 'oast house', 255 | 'observatory outdoor', 256 | 'ocean', 257 | 'office', 258 | 'office building', 259 | 'oil refinery outdoor', 260 | 'oilrig', 261 | 'operating room', 262 | 'orchard', 263 | 'outhouse outdoor', 264 | 'pagoda', 265 | 'palace', 266 | 'pantry', 267 | 'park', 268 | 'parking garage indoor', 269 | 'parking garage outdoor', 270 | 'parking lot', 271 | 'parlor', 272 | 'pasture', 273 | 'patio', 274 | 'pavilion', 275 | 'pharmacy', 276 | 'phone booth', 277 | 'physics laboratory', 278 | 'picnic area', 279 | 'pilothouse indoor', 280 | 'planetarium outdoor', 281 | 'playground', 282 | 'playroom', 283 | 'plaza', 284 | 'podium indoor', 285 | 'podium outdoor', 286 | 'pond', 287 | 'poolroom establishment', 288 | 'poolroom home', 289 | 'power plant outdoor', 290 | 'promenade deck', 291 | 'pub indoor', 292 | 'pulpit', 293 | 'putting green', 294 | 'racecourse', 295 | 'raceway', 296 | 'raft', 297 | 'railroad track', 298 | 'rainforest', 299 | 'reception', 300 | 'recreation room', 301 | 'residential neighborhood', 302 | 'restaurant', 303 | 'restaurant kitchen', 304 | 'restaurant patio', 305 | 'rice paddy', 306 | 'riding arena', 307 | 'river', 308 | 'rock arch', 309 | 'rope bridge', 310 | 'ruin', 311 | 'runway', 312 | 'sandbar', 313 | 'sandbox', 314 | 'sauna', 315 | 'schoolhouse', 316 | 'sea cliff', 317 | 'server room', 318 | 'shed', 319 | 'shoe shop', 320 | 'shopfront', 321 | 'shopping mall indoor', 322 | 'shower', 323 | 'skatepark', 324 | 'ski lodge', 325 | 'ski resort', 326 | 'ski slope', 327 | 'sky', 328 | 'skyscraper', 329 | 'slum', 330 | 'snowfield', 331 | 'squash court', 332 | 'stable', 333 | 'stadium baseball', 334 | 'stadium football', 335 | 'stage indoor', 336 | 'staircase', 337 | 'street', 338 | 'subway interior', 339 | 'subway station platform', 340 | 'supermarket', 341 | 'sushi bar', 342 | 'swamp', 343 | 'swimming pool indoor', 344 | 'swimming pool outdoor', 345 | 'synagogue indoor', 346 | 'synagogue outdoor', 347 | 'television studio', 348 | 'temple east asia', 349 | 'temple south asia', 350 | 'tennis court indoor', 351 | 'tennis court outdoor', 352 | 'tent outdoor', 353 | 'theater indoor procenium', 354 | 'theater indoor seats', 355 | 'thriftshop', 356 | 'throne room', 357 | 'ticket booth', 358 | 'toll plaza', 359 | 'topiary garden', 360 | 'tower', 361 | 'toyshop', 362 | 'track outdoor', 363 | 'train railway', 364 | 'train station platform', 365 | 'tree farm', 366 | 'tree house', 367 | 'trench', 368 | 'underwater coral reef', 369 | 'utility room', 370 | 'valley', 371 | 'van interior', 372 | 'vegetable garden', 373 | 'veranda', 374 | 'veterinarians office', 375 | 'viaduct', 376 | 'videostore', 377 | 'village', 378 | 'vineyard', 379 | 'volcano', 380 | 'volleyball court indoor', 381 | 'volleyball court outdoor', 382 | 'waiting room', 383 | 'warehouse indoor', 384 | 'water tower', 385 | 'waterfall block', 386 | 'waterfall fan', 387 | 'waterfall plunge', 388 | 'watering hole', 389 | 'wave', 390 | 'wet bar', 391 | 'wheat field', 392 | 'wind farm', 393 | 'windmill', 394 | 'wine cellar barrel storage', 395 | 'wine cellar bottle storage', 396 | 'wrestling ring indoor', 397 | 'yard', 398 | 'youth hostel', 399 | ] 400 | 401 | templates = [ 402 | 'a photo of a {}.', 403 | 'a photo of the {}.', 404 | ] -------------------------------------------------------------------------------- /Priming/Templates/StanfordCars.py: -------------------------------------------------------------------------------- 1 | classes = [ 2 | 'AM General Hummer SUV 2000', 3 | 'Acura RL Sedan 2012', 4 | 'Acura TL Sedan 2012', 5 | 'Acura TL Type-S 2008', 6 | 'Acura TSX Sedan 2012', 7 | 'Acura Integra Type R 2001', 8 | 'Acura ZDX Hatchback 2012', 9 | 'Aston Martin V8 Vantage Convertible 2012', 10 | 'Aston Martin V8 Vantage Coupe 2012', 11 | 'Aston Martin Virage Convertible 2012', 12 | 'Aston Martin Virage Coupe 2012', 13 | 'Audi RS 4 Convertible 2008', 14 | 'Audi A5 Coupe 2012', 15 | 'Audi TTS Coupe 2012', 16 | 'Audi R8 Coupe 2012', 17 | 'Audi V8 Sedan 1994', 18 | 'Audi 100 Sedan 1994', 19 | 'Audi 100 Wagon 1994', 20 | 'Audi TT Hatchback 2011', 21 | 'Audi S6 Sedan 2011', 22 | 'Audi S5 Convertible 2012', 23 | 'Audi S5 Coupe 2012', 24 | 'Audi S4 Sedan 2012', 25 | 'Audi S4 Sedan 2007', 26 | 'Audi TT RS Coupe 2012', 27 | 'BMW ActiveHybrid 5 Sedan 2012', 28 | 'BMW 1 Series Convertible 2012', 29 | 'BMW 1 Series Coupe 2012', 30 | 'BMW 3 Series Sedan 2012', 31 | 'BMW 3 Series Wagon 2012', 32 | 'BMW 6 Series Convertible 2007', 33 | 'BMW X5 SUV 2007', 34 | 'BMW X6 SUV 2012', 35 | 'BMW M3 Coupe 2012', 36 | 'BMW M5 Sedan 2010', 37 | 'BMW M6 Convertible 2010', 38 | 'BMW X3 SUV 2012', 39 | 'BMW Z4 Convertible 2012', 40 | 'Bentley Continental Supersports Conv. Convertible 2012', 41 | 'Bentley Arnage Sedan 2009', 42 | 'Bentley Mulsanne Sedan 2011', 43 | 'Bentley Continental GT Coupe 2012', 44 | 'Bentley Continental GT Coupe 2007', 45 | 'Bentley Continental Flying Spur Sedan 2007', 46 | 'Bugatti Veyron 16.4 Convertible 2009', 47 | 'Bugatti Veyron 16.4 Coupe 2009', 48 | 'Buick Regal GS 2012', 49 | 'Buick Rainier SUV 2007', 50 | 'Buick Verano Sedan 2012', 51 | 'Buick Enclave SUV 2012', 52 | 'Cadillac CTS-V Sedan 2012', 53 | 'Cadillac SRX SUV 2012', 54 | 'Cadillac Escalade EXT Crew Cab 2007', 55 | 'Chevrolet Silverado 1500 Hybrid Crew Cab 2012', 56 | 'Chevrolet Corvette Convertible 2012', 57 | 'Chevrolet Corvette ZR1 2012', 58 | 'Chevrolet Corvette Ron Fellows Edition Z06 2007', 59 | 'Chevrolet Traverse SUV 2012', 60 | 'Chevrolet Camaro Convertible 2012', 61 | 'Chevrolet HHR SS 2010', 62 | 'Chevrolet Impala Sedan 2007', 63 | 'Chevrolet Tahoe Hybrid SUV 2012', 64 | 'Chevrolet Sonic Sedan 2012', 65 | 'Chevrolet Express Cargo Van 2007', 66 | 'Chevrolet Avalanche Crew Cab 2012', 67 | 'Chevrolet Cobalt SS 2010', 68 | 'Chevrolet Malibu Hybrid Sedan 2010', 69 | 'Chevrolet TrailBlazer SS 2009', 70 | 'Chevrolet Silverado 2500HD Regular Cab 2012', 71 | 'Chevrolet Silverado 1500 Classic Extended Cab 2007', 72 | 'Chevrolet Express Van 2007', 73 | 'Chevrolet Monte Carlo Coupe 2007', 74 | 'Chevrolet Malibu Sedan 2007', 75 | 'Chevrolet Silverado 1500 Extended Cab 2012', 76 | 'Chevrolet Silverado 1500 Regular Cab 2012', 77 | 'Chrysler Aspen SUV 2009', 78 | 'Chrysler Sebring Convertible 2010', 79 | 'Chrysler Town and Country Minivan 2012', 80 | 'Chrysler 300 SRT-8 2010', 81 | 'Chrysler Crossfire Convertible 2008', 82 | 'Chrysler PT Cruiser Convertible 2008', 83 | 'Daewoo Nubira Wagon 2002', 84 | 'Dodge Caliber Wagon 2012', 85 | 'Dodge Caliber Wagon 2007', 86 | 'Dodge Caravan Minivan 1997', 87 | 'Dodge Ram Pickup 3500 Crew Cab 2010', 88 | 'Dodge Ram Pickup 3500 Quad Cab 2009', 89 | 'Dodge Sprinter Cargo Van 2009', 90 | 'Dodge Journey SUV 2012', 91 | 'Dodge Dakota Crew Cab 2010', 92 | 'Dodge Dakota Club Cab 2007', 93 | 'Dodge Magnum Wagon 2008', 94 | 'Dodge Challenger SRT8 2011', 95 | 'Dodge Durango SUV 2012', 96 | 'Dodge Durango SUV 2007', 97 | 'Dodge Charger Sedan 2012', 98 | 'Dodge Charger SRT-8 2009', 99 | 'Eagle Talon Hatchback 1998', 100 | 'FIAT 500 Abarth 2012', 101 | 'FIAT 500 Convertible 2012', 102 | 'Ferrari FF Coupe 2012', 103 | 'Ferrari California Convertible 2012', 104 | 'Ferrari 458 Italia Convertible 2012', 105 | 'Ferrari 458 Italia Coupe 2012', 106 | 'Fisker Karma Sedan 2012', 107 | 'Ford F-450 Super Duty Crew Cab 2012', 108 | 'Ford Mustang Convertible 2007', 109 | 'Ford Freestar Minivan 2007', 110 | 'Ford Expedition EL SUV 2009', 111 | 'Ford Edge SUV 2012', 112 | 'Ford Ranger SuperCab 2011', 113 | 'Ford GT Coupe 2006', 114 | 'Ford F-150 Regular Cab 2012', 115 | 'Ford F-150 Regular Cab 2007', 116 | 'Ford Focus Sedan 2007', 117 | 'Ford E-Series Wagon Van 2012', 118 | 'Ford Fiesta Sedan 2012', 119 | 'GMC Terrain SUV 2012', 120 | 'GMC Savana Van 2012', 121 | 'GMC Yukon Hybrid SUV 2012', 122 | 'GMC Acadia SUV 2012', 123 | 'GMC Canyon Extended Cab 2012', 124 | 'Geo Metro Convertible 1993', 125 | 'HUMMER H3T Crew Cab 2010', 126 | 'HUMMER H2 SUT Crew Cab 2009', 127 | 'Honda Odyssey Minivan 2012', 128 | 'Honda Odyssey Minivan 2007', 129 | 'Honda Accord Coupe 2012', 130 | 'Honda Accord Sedan 2012', 131 | 'Hyundai Veloster Hatchback 2012', 132 | 'Hyundai Santa Fe SUV 2012', 133 | 'Hyundai Tucson SUV 2012', 134 | 'Hyundai Veracruz SUV 2012', 135 | 'Hyundai Sonata Hybrid Sedan 2012', 136 | 'Hyundai Elantra Sedan 2007', 137 | 'Hyundai Accent Sedan 2012', 138 | 'Hyundai Genesis Sedan 2012', 139 | 'Hyundai Sonata Sedan 2012', 140 | 'Hyundai Elantra Touring Hatchback 2012', 141 | 'Hyundai Azera Sedan 2012', 142 | 'Infiniti G Coupe IPL 2012', 143 | 'Infiniti QX56 SUV 2011', 144 | 'Isuzu Ascender SUV 2008', 145 | 'Jaguar XK XKR 2012', 146 | 'Jeep Patriot SUV 2012', 147 | 'Jeep Wrangler SUV 2012', 148 | 'Jeep Liberty SUV 2012', 149 | 'Jeep Grand Cherokee SUV 2012', 150 | 'Jeep Compass SUV 2012', 151 | 'Lamborghini Reventon Coupe 2008', 152 | 'Lamborghini Aventador Coupe 2012', 153 | 'Lamborghini Gallardo LP 570-4 Superleggera 2012', 154 | 'Lamborghini Diablo Coupe 2001', 155 | 'Land Rover Range Rover SUV 2012', 156 | 'Land Rover LR2 SUV 2012', 157 | 'Lincoln Town Car Sedan 2011', 158 | 'MINI Cooper Roadster Convertible 2012', 159 | 'Maybach Landaulet Convertible 2012', 160 | 'Mazda Tribute SUV 2011', 161 | 'McLaren MP4-12C Coupe 2012', 162 | 'Mercedes-Benz 300-Class Convertible 1993', 163 | 'Mercedes-Benz C-Class Sedan 2012', 164 | 'Mercedes-Benz SL-Class Coupe 2009', 165 | 'Mercedes-Benz E-Class Sedan 2012', 166 | 'Mercedes-Benz S-Class Sedan 2012', 167 | 'Mercedes-Benz Sprinter Van 2012', 168 | 'Mitsubishi Lancer Sedan 2012', 169 | 'Nissan Leaf Hatchback 2012', 170 | 'Nissan NV Passenger Van 2012', 171 | 'Nissan Juke Hatchback 2012', 172 | 'Nissan 240SX Coupe 1998', 173 | 'Plymouth Neon Coupe 1999', 174 | 'Porsche Panamera Sedan 2012', 175 | 'Ram CV Cargo Van Minivan 2012', 176 | 'Rolls-Royce Phantom Drophead Coupe Convertible 2012', 177 | 'Rolls-Royce Ghost Sedan 2012', 178 | 'Rolls-Royce Phantom Sedan 2012', 179 | 'Scion xD Hatchback 2012', 180 | 'Spyker C8 Convertible 2009', 181 | 'Spyker C8 Coupe 2009', 182 | 'Suzuki Aerio Sedan 2007', 183 | 'Suzuki Kizashi Sedan 2012', 184 | 'Suzuki SX4 Hatchback 2012', 185 | 'Suzuki SX4 Sedan 2012', 186 | 'Tesla Model S Sedan 2012', 187 | 'Toyota Sequoia SUV 2012', 188 | 'Toyota Camry Sedan 2012', 189 | 'Toyota Corolla Sedan 2012', 190 | 'Toyota 4Runner SUV 2012', 191 | 'Volkswagen Golf Hatchback 2012', 192 | 'Volkswagen Golf Hatchback 1991', 193 | 'Volkswagen Beetle Hatchback 2012', 194 | 'Volvo C30 Hatchback 2012', 195 | 'Volvo 240 Sedan 1993', 196 | 'Volvo XC90 SUV 2007', 197 | 'smart fortwo Convertible 2012', 198 | ] 199 | 200 | templates = [ 201 | 'a photo of a {}.', 202 | 'a photo of the {}.', 203 | 'a photo of my {}.', 204 | 'i love my {}!', 205 | 'a photo of my dirty {}.', 206 | 'a photo of my clean {}.', 207 | 'a photo of my new {}.', 208 | 'a photo of my old {}.', 209 | ] -------------------------------------------------------------------------------- /Priming/Templates/StanfordCars.txt: -------------------------------------------------------------------------------- 1 | AM General Hummer SUV 2000 2 | Acura RL Sedan 2012 3 | Acura TL Sedan 2012 4 | Acura TL Type-S 2008 5 | Acura TSX Sedan 2012 6 | Acura Integra Type R 2001 7 | Acura ZDX Hatchback 2012 8 | Aston Martin V8 Vantage Convertible 2012 9 | Aston Martin V8 Vantage Coupe 2012 10 | Aston Martin Virage Convertible 2012 11 | Aston Martin Virage Coupe 2012 12 | Audi RS 4 Convertible 2008 13 | Audi A5 Coupe 2012 14 | Audi TTS Coupe 2012 15 | Audi R8 Coupe 2012 16 | Audi V8 Sedan 1994 17 | Audi 100 Sedan 1994 18 | Audi 100 Wagon 1994 19 | Audi TT Hatchback 2011 20 | Audi S6 Sedan 2011 21 | Audi S5 Convertible 2012 22 | Audi S5 Coupe 2012 23 | Audi S4 Sedan 2012 24 | Audi S4 Sedan 2007 25 | Audi TT RS Coupe 2012 26 | BMW ActiveHybrid 5 Sedan 2012 27 | BMW 1 Series Convertible 2012 28 | BMW 1 Series Coupe 2012 29 | BMW 3 Series Sedan 2012 30 | BMW 3 Series Wagon 2012 31 | BMW 6 Series Convertible 2007 32 | BMW X5 SUV 2007 33 | BMW X6 SUV 2012 34 | BMW M3 Coupe 2012 35 | BMW M5 Sedan 2010 36 | BMW M6 Convertible 2010 37 | BMW X3 SUV 2012 38 | BMW Z4 Convertible 2012 39 | Bentley Continental Supersports Conv. Convertible 2012 40 | Bentley Arnage Sedan 2009 41 | Bentley Mulsanne Sedan 2011 42 | Bentley Continental GT Coupe 2012 43 | Bentley Continental GT Coupe 2007 44 | Bentley Continental Flying Spur Sedan 2007 45 | Bugatti Veyron 16.4 Convertible 2009 46 | Bugatti Veyron 16.4 Coupe 2009 47 | Buick Regal GS 2012 48 | Buick Rainier SUV 2007 49 | Buick Verano Sedan 2012 50 | Buick Enclave SUV 2012 51 | Cadillac CTS-V Sedan 2012 52 | Cadillac SRX SUV 2012 53 | Cadillac Escalade EXT Crew Cab 2007 54 | Chevrolet Silverado 1500 Hybrid Crew Cab 2012 55 | Chevrolet Corvette Convertible 2012 56 | Chevrolet Corvette ZR1 2012 57 | Chevrolet Corvette Ron Fellows Edition Z06 2007 58 | Chevrolet Traverse SUV 2012 59 | Chevrolet Camaro Convertible 2012 60 | Chevrolet HHR SS 2010 61 | Chevrolet Impala Sedan 2007 62 | Chevrolet Tahoe Hybrid SUV 2012 63 | Chevrolet Sonic Sedan 2012 64 | Chevrolet Express Cargo Van 2007 65 | Chevrolet Avalanche Crew Cab 2012 66 | Chevrolet Cobalt SS 2010 67 | Chevrolet Malibu Hybrid Sedan 2010 68 | Chevrolet TrailBlazer SS 2009 69 | Chevrolet Silverado 2500HD Regular Cab 2012 70 | Chevrolet Silverado 1500 Classic Extended Cab 2007 71 | Chevrolet Express Van 2007 72 | Chevrolet Monte Carlo Coupe 2007 73 | Chevrolet Malibu Sedan 2007 74 | Chevrolet Silverado 1500 Extended Cab 2012 75 | Chevrolet Silverado 1500 Regular Cab 2012 76 | Chrysler Aspen SUV 2009 77 | Chrysler Sebring Convertible 2010 78 | Chrysler Town and Country Minivan 2012 79 | Chrysler 300 SRT-8 2010 80 | Chrysler Crossfire Convertible 2008 81 | Chrysler PT Cruiser Convertible 2008 82 | Daewoo Nubira Wagon 2002 83 | Dodge Caliber Wagon 2012 84 | Dodge Caliber Wagon 2007 85 | Dodge Caravan Minivan 1997 86 | Dodge Ram Pickup 3500 Crew Cab 2010 87 | Dodge Ram Pickup 3500 Quad Cab 2009 88 | Dodge Sprinter Cargo Van 2009 89 | Dodge Journey SUV 2012 90 | Dodge Dakota Crew Cab 2010 91 | Dodge Dakota Club Cab 2007 92 | Dodge Magnum Wagon 2008 93 | Dodge Challenger SRT8 2011 94 | Dodge Durango SUV 2012 95 | Dodge Durango SUV 2007 96 | Dodge Charger Sedan 2012 97 | Dodge Charger SRT-8 2009 98 | Eagle Talon Hatchback 1998 99 | FIAT 500 Abarth 2012 100 | FIAT 500 Convertible 2012 101 | Ferrari FF Coupe 2012 102 | Ferrari California Convertible 2012 103 | Ferrari 458 Italia Convertible 2012 104 | Ferrari 458 Italia Coupe 2012 105 | Fisker Karma Sedan 2012 106 | Ford F-450 Super Duty Crew Cab 2012 107 | Ford Mustang Convertible 2007 108 | Ford Freestar Minivan 2007 109 | Ford Expedition EL SUV 2009 110 | Ford Edge SUV 2012 111 | Ford Ranger SuperCab 2011 112 | Ford GT Coupe 2006 113 | Ford F-150 Regular Cab 2012 114 | Ford F-150 Regular Cab 2007 115 | Ford Focus Sedan 2007 116 | Ford E-Series Wagon Van 2012 117 | Ford Fiesta Sedan 2012 118 | GMC Terrain SUV 2012 119 | GMC Savana Van 2012 120 | GMC Yukon Hybrid SUV 2012 121 | GMC Acadia SUV 2012 122 | GMC Canyon Extended Cab 2012 123 | Geo Metro Convertible 1993 124 | HUMMER H3T Crew Cab 2010 125 | HUMMER H2 SUT Crew Cab 2009 126 | Honda Odyssey Minivan 2012 127 | Honda Odyssey Minivan 2007 128 | Honda Accord Coupe 2012 129 | Honda Accord Sedan 2012 130 | Hyundai Veloster Hatchback 2012 131 | Hyundai Santa Fe SUV 2012 132 | Hyundai Tucson SUV 2012 133 | Hyundai Veracruz SUV 2012 134 | Hyundai Sonata Hybrid Sedan 2012 135 | Hyundai Elantra Sedan 2007 136 | Hyundai Accent Sedan 2012 137 | Hyundai Genesis Sedan 2012 138 | Hyundai Sonata Sedan 2012 139 | Hyundai Elantra Touring Hatchback 2012 140 | Hyundai Azera Sedan 2012 141 | Infiniti G Coupe IPL 2012 142 | Infiniti QX56 SUV 2011 143 | Isuzu Ascender SUV 2008 144 | Jaguar XK XKR 2012 145 | Jeep Patriot SUV 2012 146 | Jeep Wrangler SUV 2012 147 | Jeep Liberty SUV 2012 148 | Jeep Grand Cherokee SUV 2012 149 | Jeep Compass SUV 2012 150 | Lamborghini Reventon Coupe 2008 151 | Lamborghini Aventador Coupe 2012 152 | Lamborghini Gallardo LP 570-4 Superleggera 2012 153 | Lamborghini Diablo Coupe 2001 154 | Land Rover Range Rover SUV 2012 155 | Land Rover LR2 SUV 2012 156 | Lincoln Town Car Sedan 2011 157 | MINI Cooper Roadster Convertible 2012 158 | Maybach Landaulet Convertible 2012 159 | Mazda Tribute SUV 2011 160 | McLaren MP4-12C Coupe 2012 161 | Mercedes-Benz 300-Class Convertible 1993 162 | Mercedes-Benz C-Class Sedan 2012 163 | Mercedes-Benz SL-Class Coupe 2009 164 | Mercedes-Benz E-Class Sedan 2012 165 | Mercedes-Benz S-Class Sedan 2012 166 | Mercedes-Benz Sprinter Van 2012 167 | Mitsubishi Lancer Sedan 2012 168 | Nissan Leaf Hatchback 2012 169 | Nissan NV Passenger Van 2012 170 | Nissan Juke Hatchback 2012 171 | Nissan 240SX Coupe 1998 172 | Plymouth Neon Coupe 1999 173 | Porsche Panamera Sedan 2012 174 | Ram C/V Cargo Van Minivan 2012 175 | Rolls-Royce Phantom Drophead Coupe Convertible 2012 176 | Rolls-Royce Ghost Sedan 2012 177 | Rolls-Royce Phantom Sedan 2012 178 | Scion xD Hatchback 2012 179 | Spyker C8 Convertible 2009 180 | Spyker C8 Coupe 2009 181 | Suzuki Aerio Sedan 2007 182 | Suzuki Kizashi Sedan 2012 183 | Suzuki SX4 Hatchback 2012 184 | Suzuki SX4 Sedan 2012 185 | Tesla Model S Sedan 2012 186 | Toyota Sequoia SUV 2012 187 | Toyota Camry Sedan 2012 188 | Toyota Corolla Sedan 2012 189 | Toyota 4Runner SUV 2012 190 | Volkswagen Golf Hatchback 2012 191 | Volkswagen Golf Hatchback 1991 192 | Volkswagen Beetle Hatchback 2012 193 | Volvo C30 Hatchback 2012 194 | Volvo 240 Sedan 1993 195 | Volvo XC90 SUV 2007 196 | smart fortwo Convertible 2012 197 | -------------------------------------------------------------------------------- /Priming/Templates/Sun397.txt: -------------------------------------------------------------------------------- 1 | abbey 2 | airplane cabin 3 | airport terminal 4 | alley 5 | amphitheater 6 | amusement arcade 7 | amusement park 8 | anechoic chamber 9 | apartment building outdoor 10 | apse indoor 11 | aquarium 12 | aqueduct 13 | arch 14 | archive 15 | arrival gate outdoor 16 | art gallery 17 | art school 18 | art studio 19 | assembly line 20 | athletic field outdoor 21 | atrium public 22 | attic 23 | auditorium 24 | auto factory 25 | badlands 26 | badminton court indoor 27 | baggage claim 28 | bakery shop 29 | balcony exterior 30 | balcony interior 31 | ball pit 32 | ballroom 33 | bamboo forest 34 | banquet hall 35 | bar 36 | barn 37 | barndoor 38 | baseball field 39 | basement 40 | basilica 41 | basketball court outdoor 42 | bathroom 43 | batters box 44 | bayou 45 | bazaar indoor 46 | bazaar outdoor 47 | beach 48 | beauty salon 49 | bedroom 50 | berth 51 | biology laboratory 52 | bistro indoor 53 | boardwalk 54 | boat deck 55 | boathouse 56 | bookstore 57 | booth indoor 58 | botanical garden 59 | bow window indoor 60 | bow window outdoor 61 | bowling alley 62 | boxing ring 63 | brewery indoor 64 | bridge 65 | building facade 66 | bullring 67 | burial chamber 68 | bus interior 69 | butchers shop 70 | butte 71 | cabin outdoor 72 | cafeteria 73 | campsite 74 | campus 75 | canal natural 76 | canal urban 77 | candy store 78 | canyon 79 | car interior backseat 80 | car interior frontseat 81 | carrousel 82 | casino indoor 83 | castle 84 | catacomb 85 | cathedral indoor 86 | cathedral outdoor 87 | cavern indoor 88 | cemetery 89 | chalet 90 | cheese factory 91 | chemistry lab 92 | chicken coop indoor 93 | chicken coop outdoor 94 | childs room 95 | church indoor 96 | church outdoor 97 | classroom 98 | clean room 99 | cliff 100 | cloister indoor 101 | closet 102 | clothing store 103 | coast 104 | cockpit 105 | coffee shop 106 | computer room 107 | conference center 108 | conference room 109 | construction site 110 | control room 111 | control tower outdoor 112 | corn field 113 | corral 114 | corridor 115 | cottage garden 116 | courthouse 117 | courtroom 118 | courtyard 119 | covered bridge exterior 120 | creek 121 | crevasse 122 | crosswalk 123 | cubicle office 124 | dam 125 | delicatessen 126 | dentists office 127 | desert sand 128 | desert vegetation 129 | diner indoor 130 | diner outdoor 131 | dinette home 132 | dinette vehicle 133 | dining car 134 | dining room 135 | discotheque 136 | dock 137 | doorway outdoor 138 | dorm room 139 | driveway 140 | driving range outdoor 141 | drugstore 142 | electrical substation 143 | elevator door 144 | elevator interior 145 | elevator shaft 146 | engine room 147 | escalator indoor 148 | excavation 149 | factory indoor 150 | fairway 151 | fastfood restaurant 152 | field cultivated 153 | field wild 154 | fire escape 155 | fire station 156 | firing range indoor 157 | fishpond 158 | florist shop indoor 159 | food court 160 | forest broadleaf 161 | forest needleleaf 162 | forest path 163 | forest road 164 | formal garden 165 | fountain 166 | galley 167 | game room 168 | garage indoor 169 | garbage dump 170 | gas station 171 | gazebo exterior 172 | general store indoor 173 | general store outdoor 174 | gift shop 175 | golf course 176 | greenhouse indoor 177 | greenhouse outdoor 178 | gymnasium indoor 179 | hangar indoor 180 | hangar outdoor 181 | harbor 182 | hayfield 183 | heliport 184 | herb garden 185 | highway 186 | hill 187 | home office 188 | hospital 189 | hospital room 190 | hot spring 191 | hot tub outdoor 192 | hotel outdoor 193 | hotel room 194 | house 195 | hunting lodge outdoor 196 | ice cream parlor 197 | ice floe 198 | ice shelf 199 | ice skating rink indoor 200 | ice skating rink outdoor 201 | iceberg 202 | igloo 203 | industrial area 204 | inn outdoor 205 | islet 206 | jacuzzi indoor 207 | jail cell 208 | jail indoor 209 | jewelry shop 210 | kasbah 211 | kennel indoor 212 | kennel outdoor 213 | kindergarden classroom 214 | kitchen 215 | kitchenette 216 | labyrinth outdoor 217 | lake natural 218 | landfill 219 | landing deck 220 | laundromat 221 | lecture room 222 | library indoor 223 | library outdoor 224 | lido deck outdoor 225 | lift bridge 226 | lighthouse 227 | limousine interior 228 | living room 229 | lobby 230 | lock chamber 231 | locker room 232 | mansion 233 | manufactured home 234 | market indoor 235 | market outdoor 236 | marsh 237 | martial arts gym 238 | mausoleum 239 | medina 240 | moat water 241 | monastery outdoor 242 | mosque indoor 243 | mosque outdoor 244 | motel 245 | mountain 246 | mountain snowy 247 | movie theater indoor 248 | museum indoor 249 | music store 250 | music studio 251 | nuclear power plant outdoor 252 | nursery 253 | oast house 254 | observatory outdoor 255 | ocean 256 | office 257 | office building 258 | oil refinery outdoor 259 | oilrig 260 | operating room 261 | orchard 262 | outhouse outdoor 263 | pagoda 264 | palace 265 | pantry 266 | park 267 | parking garage indoor 268 | parking garage outdoor 269 | parking lot 270 | parlor 271 | pasture 272 | patio 273 | pavilion 274 | pharmacy 275 | phone booth 276 | physics laboratory 277 | picnic area 278 | pilothouse indoor 279 | planetarium outdoor 280 | playground 281 | playroom 282 | plaza 283 | podium indoor 284 | podium outdoor 285 | pond 286 | poolroom establishment 287 | poolroom home 288 | power plant outdoor 289 | promenade deck 290 | pub indoor 291 | pulpit 292 | putting green 293 | racecourse 294 | raceway 295 | raft 296 | railroad track 297 | rainforest 298 | reception 299 | recreation room 300 | residential neighborhood 301 | restaurant 302 | restaurant kitchen 303 | restaurant patio 304 | rice paddy 305 | riding arena 306 | river 307 | rock arch 308 | rope bridge 309 | ruin 310 | runway 311 | sandbar 312 | sandbox 313 | sauna 314 | schoolhouse 315 | sea cliff 316 | server room 317 | shed 318 | shoe shop 319 | shopfront 320 | shopping mall indoor 321 | shower 322 | skatepark 323 | ski lodge 324 | ski resort 325 | ski slope 326 | sky 327 | skyscraper 328 | slum 329 | snowfield 330 | squash court 331 | stable 332 | stadium baseball 333 | stadium football 334 | stage indoor 335 | staircase 336 | street 337 | subway interior 338 | subway station platform 339 | supermarket 340 | sushi bar 341 | swamp 342 | swimming pool indoor 343 | swimming pool outdoor 344 | synagogue indoor 345 | synagogue outdoor 346 | television studio 347 | temple east asia 348 | temple south asia 349 | tennis court indoor 350 | tennis court outdoor 351 | tent outdoor 352 | theater indoor procenium 353 | theater indoor seats 354 | thriftshop 355 | throne room 356 | ticket booth 357 | toll plaza 358 | topiary garden 359 | tower 360 | toyshop 361 | track outdoor 362 | train railway 363 | train station platform 364 | tree farm 365 | tree house 366 | trench 367 | underwater coral reef 368 | utility room 369 | valley 370 | van interior 371 | vegetable garden 372 | veranda 373 | veterinarians office 374 | viaduct 375 | videostore 376 | village 377 | vineyard 378 | volcano 379 | volleyball court indoor 380 | volleyball court outdoor 381 | waiting room 382 | warehouse indoor 383 | water tower 384 | waterfall block 385 | waterfall fan 386 | waterfall plunge 387 | watering hole 388 | wave 389 | wet bar 390 | wheat field 391 | wind farm 392 | windmill 393 | wine cellar barrel storage 394 | wine cellar bottle storage 395 | wrestling ring indoor 396 | yard 397 | youth hostel 398 | -------------------------------------------------------------------------------- /Priming/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | 5 | def parse_arguments(): 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--dataset', type=str, default='StanfordCars',help = 'Name of target dataset, i.e. Flowers102, StanfordCars. Ensure that it matches the name in the Templates Folder.') 8 | parser.add_argument('--subset_path', type=str, default=None,help = 'Path to priming set.') 9 | parser.add_argument('--base_dataset', type=str, default='laion2b_s34b_b88k',help = 'Pretraining Dataset') 10 | parser.add_argument('--test_path', type=str, default='',help = 'Path to test dataset') 11 | parser.add_argument('--model', type=str, default='ViT-B-16',help = 'Model') 12 | parser.add_argument('--shots', type=int, default=1,help = 'Number of examples') 13 | parser.add_argument('--test_batches', type=int, default=np.inf,help = 'Number of test batches') 14 | parser.add_argument('--alpha', type=float, default=.9,help = 'Number of examples') 15 | parser.add_argument('--prime', action='store_true', help='Use the priming data') 16 | parser.add_argument('--retrain', action='store_true', help='Recache the image features') 17 | parser.add_argument('--cupl', action='store_true', help='Use the CuPL prompts to initialize the text classifier') 18 | parser.add_argument('--root', type=str, default='./',help = 'Base directory.') 19 | parser.add_argument('--val_path', type=str, default=None,help = 'Directory for custom evaluation data sets. For torchvision datasets leave as none and they will be automatically loaded.') 20 | parser.add_argument('--train_path', type=str, default=None,help = 'Directory for train data. Needed for non-torch vision datasets.') 21 | parser.add_argument('--batch_size', type=int, default=16, help = 'Adjust to use memory appropriately') 22 | parser.add_argument('--cache', type=bool, default=True, help = 'Cache image features for faster iteration.') 23 | parser.add_argument('--cuda', type=bool, default=True, help = 'Cache image features for faster iteration.') 24 | parser.add_argument('--cache_path', type=str, default='./cache/',help = 'Base directory.') 25 | parser.add_argument('--results_path', type=str, default='./results/',help = 'Base directory.') 26 | parser.add_argument('--custom_data', action='store_true', help = 'Whether to use a custom loader. Use for ImageNetv2, and other ImageNet variants as well as SUN.') 27 | parser.add_argument('--num_workers', type=int, default=8, help = 'Number of workers for each dataloader.') 28 | parsed_args = parser.parse_args() 29 | return parsed_args 30 | -------------------------------------------------------------------------------- /Priming/data.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | import torchvision.datasets as datasets 3 | import numpy as np 4 | import os 5 | from PIL import Image 6 | 7 | class CustomDataset(Dataset): 8 | def __init__(self, root, transform=None): 9 | self.root = root 10 | self.transform = transform 11 | self.imgs = [] 12 | self.classes = [] 13 | self.class_to_idx = {} 14 | self.idx_to_class = {} 15 | for target in np.sort(os.listdir(self.root)): 16 | target_path = os.path.join(self.root, target) 17 | if os.path.isdir(target_path): 18 | class_idx = len(self.class_to_idx) 19 | self.class_to_idx[target] = class_idx 20 | self.idx_to_class[class_idx] = target 21 | for root, _, fnames in sorted(os.walk(target_path)): 22 | if len(fnames) > 0: 23 | for fname in fnames: 24 | path = os.path.join(root, fname) 25 | self.imgs.append((path, class_idx)) 26 | 27 | 28 | def __getitem__(self, index): 29 | path, target = self.imgs[index] 30 | try: 31 | img = Image.open(path).convert('RGB') 32 | except: 33 | print('Error loading image:', path) 34 | img = Image.new('RGB', (224, 224)) 35 | if self.transform is not None: 36 | img = self.transform(img) 37 | return img, target 38 | 39 | def __len__(self): 40 | return len(self.imgs) 41 | 42 | def get_class_name(self, class_idx): 43 | return self.idx_to_class[class_idx] 44 | 45 | def get_class_idx(self, class_name): 46 | return self.class_to_idx[class_name] 47 | 48 | class ImageNetV2(datasets.ImageFolder): 49 | def __getitem__(self, index): 50 | path, target = self.samples[index] 51 | # get the folder name from the path 52 | folder_name = os.path.basename(os.path.dirname(path)) 53 | # convert the folder name to an integer 54 | target = int(folder_name) 55 | # return the image and the correct label 56 | return super().__getitem__(index)[0], target 57 | -------------------------------------------------------------------------------- /Priming/imagenet.json: -------------------------------------------------------------------------------- 1 | { 2 | "n01440764": "tench", 3 | "n01443537": "goldfish", 4 | "n01484850": "great white shark", 5 | "n01491361": "tiger shark", 6 | "n01494475": "hammerhead shark", 7 | "n01496331": "electric ray", 8 | "n01498041": "stingray", 9 | "n01514668": "rooster", 10 | "n01514859": "hen", 11 | "n01518878": "ostrich", 12 | "n01530575": "brambling", 13 | "n01531178": "goldfinch", 14 | "n01532829": "house finch", 15 | "n01534433": "junco", 16 | "n01537544": "indigo bunting", 17 | "n01558993": "American robin", 18 | "n01560419": "bulbul", 19 | "n01580077": "jay", 20 | "n01582220": "magpie", 21 | "n01592084": "chickadee", 22 | "n01601694": "American dipper", 23 | "n01608432": "kite (bird of prey)", 24 | "n01614925": "bald eagle", 25 | "n01616318": "vulture", 26 | "n01622779": "great grey owl", 27 | "n01629819": "fire salamander", 28 | "n01630670": "smooth newt", 29 | "n01631663": "newt", 30 | "n01632458": "spotted salamander", 31 | "n01632777": "axolotl", 32 | "n01641577": "American bullfrog", 33 | "n01644373": "tree frog", 34 | "n01644900": "tailed frog", 35 | "n01664065": "loggerhead sea turtle", 36 | "n01665541": "leatherback sea turtle", 37 | "n01667114": "mud turtle", 38 | "n01667778": "terrapin", 39 | "n01669191": "box turtle", 40 | "n01675722": "banded gecko", 41 | "n01677366": "green iguana", 42 | "n01682714": "Carolina anole", 43 | "n01685808": "desert grassland whiptail lizard", 44 | "n01687978": "agama", 45 | "n01688243": "frilled-necked lizard", 46 | "n01689811": "alligator lizard", 47 | "n01692333": "Gila monster", 48 | "n01693334": "European green lizard", 49 | "n01694178": "chameleon", 50 | "n01695060": "Komodo dragon", 51 | "n01697457": "Nile crocodile", 52 | "n01698640": "American alligator", 53 | "n01704323": "triceratops", 54 | "n01728572": "worm snake", 55 | "n01728920": "ring-necked snake", 56 | "n01729322": "eastern hog-nosed snake", 57 | "n01729977": "smooth green snake", 58 | "n01734418": "kingsnake", 59 | "n01735189": "garter snake", 60 | "n01737021": "water snake", 61 | "n01739381": "vine snake", 62 | "n01740131": "night snake", 63 | "n01742172": "boa constrictor", 64 | "n01744401": "African rock python", 65 | "n01748264": "Indian cobra", 66 | "n01749939": "green mamba", 67 | "n01751748": "sea snake", 68 | "n01753488": "Saharan horned viper", 69 | "n01755581": "eastern diamondback rattlesnake", 70 | "n01756291": "sidewinder rattlesnake", 71 | "n01768244": "trilobite", 72 | "n01770081": "harvestman", 73 | "n01770393": "scorpion", 74 | "n01773157": "yellow garden spider", 75 | "n01773549": "barn spider", 76 | "n01773797": "European garden spider", 77 | "n01774384": "southern black widow", 78 | "n01774750": "tarantula", 79 | "n01775062": "wolf spider", 80 | "n01776313": "tick", 81 | "n01784675": "centipede", 82 | "n01795545": "black grouse", 83 | "n01796340": "ptarmigan", 84 | "n01797886": "ruffed grouse", 85 | "n01798484": "prairie grouse", 86 | "n01806143": "peafowl", 87 | "n01806567": "quail", 88 | "n01807496": "partridge", 89 | "n01817953": "african grey parrot", 90 | "n01818515": "macaw", 91 | "n01819313": "sulphur-crested cockatoo", 92 | "n01820546": "lorikeet", 93 | "n01824575": "coucal", 94 | "n01828970": "bee eater", 95 | "n01829413": "hornbill", 96 | "n01833805": "hummingbird", 97 | "n01843065": "jacamar", 98 | "n01843383": "toucan", 99 | "n01847000": "duck", 100 | "n01855032": "red-breasted merganser", 101 | "n01855672": "goose", 102 | "n01860187": "black swan", 103 | "n01871265": "tusker", 104 | "n01872401": "echidna", 105 | "n01873310": "platypus", 106 | "n01877812": "wallaby", 107 | "n01882714": "koala", 108 | "n01883070": "wombat", 109 | "n01910747": "jellyfish", 110 | "n01914609": "sea anemone", 111 | "n01917289": "brain coral", 112 | "n01924916": "flatworm", 113 | "n01930112": "nematode", 114 | "n01943899": "conch", 115 | "n01944390": "snail", 116 | "n01945685": "slug", 117 | "n01950731": "sea slug", 118 | "n01955084": "chiton", 119 | "n01968897": "chambered nautilus", 120 | "n01978287": "Dungeness crab", 121 | "n01978455": "rock crab", 122 | "n01980166": "fiddler crab", 123 | "n01981276": "red king crab", 124 | "n01983481": "American lobster", 125 | "n01984695": "spiny lobster", 126 | "n01985128": "crayfish", 127 | "n01986214": "hermit crab", 128 | "n01990800": "isopod", 129 | "n02002556": "white stork", 130 | "n02002724": "black stork", 131 | "n02006656": "spoonbill", 132 | "n02007558": "flamingo", 133 | "n02009229": "little blue heron", 134 | "n02009912": "great egret", 135 | "n02011460": "bittern bird", 136 | "n02012849": "crane bird", 137 | "n02013706": "limpkin", 138 | "n02017213": "common gallinule", 139 | "n02018207": "American coot", 140 | "n02018795": "bustard", 141 | "n02025239": "ruddy turnstone", 142 | "n02027492": "dunlin", 143 | "n02028035": "common redshank", 144 | "n02033041": "dowitcher", 145 | "n02037110": "oystercatcher", 146 | "n02051845": "pelican", 147 | "n02056570": "king penguin", 148 | "n02058221": "albatross", 149 | "n02066245": "grey whale", 150 | "n02071294": "killer whale", 151 | "n02074367": "dugong", 152 | "n02077923": "sea lion", 153 | "n02085620": "Chihuahua", 154 | "n02085782": "Japanese Chin", 155 | "n02085936": "Maltese", 156 | "n02086079": "Pekingese", 157 | "n02086240": "Shih Tzu", 158 | "n02086646": "King Charles Spaniel", 159 | "n02086910": "Papillon", 160 | "n02087046": "toy terrier", 161 | "n02087394": "Rhodesian Ridgeback", 162 | "n02088094": "Afghan Hound", 163 | "n02088238": "Basset Hound", 164 | "n02088364": "Beagle", 165 | "n02088466": "Bloodhound", 166 | "n02088632": "Bluetick Coonhound", 167 | "n02089078": "Black and Tan Coonhound", 168 | "n02089867": "Treeing Walker Coonhound", 169 | "n02089973": "English foxhound", 170 | "n02090379": "Redbone Coonhound", 171 | "n02090622": "borzoi", 172 | "n02090721": "Irish Wolfhound", 173 | "n02091032": "Italian Greyhound", 174 | "n02091134": "Whippet", 175 | "n02091244": "Ibizan Hound", 176 | "n02091467": "Norwegian Elkhound", 177 | "n02091635": "Otterhound", 178 | "n02091831": "Saluki", 179 | "n02092002": "Scottish Deerhound", 180 | "n02092339": "Weimaraner", 181 | "n02093256": "Staffordshire Bull Terrier", 182 | "n02093428": "American Staffordshire Terrier", 183 | "n02093647": "Bedlington Terrier", 184 | "n02093754": "Border Terrier", 185 | "n02093859": "Kerry Blue Terrier", 186 | "n02093991": "Irish Terrier", 187 | "n02094114": "Norfolk Terrier", 188 | "n02094258": "Norwich Terrier", 189 | "n02094433": "Yorkshire Terrier", 190 | "n02095314": "Wire Fox Terrier", 191 | "n02095570": "Lakeland Terrier", 192 | "n02095889": "Sealyham Terrier", 193 | "n02096051": "Airedale Terrier", 194 | "n02096177": "Cairn Terrier", 195 | "n02096294": "Australian Terrier", 196 | "n02096437": "Dandie Dinmont Terrier", 197 | "n02096585": "Boston Terrier", 198 | "n02097047": "Miniature Schnauzer", 199 | "n02097130": "Giant Schnauzer", 200 | "n02097209": "Standard Schnauzer", 201 | "n02097298": "Scottish Terrier", 202 | "n02097474": "Tibetan Terrier", 203 | "n02097658": "Australian Silky Terrier", 204 | "n02098105": "Soft-coated Wheaten Terrier", 205 | "n02098286": "West Highland White Terrier", 206 | "n02098413": "Lhasa Apso", 207 | "n02099267": "Flat-Coated Retriever", 208 | "n02099429": "Curly-coated Retriever", 209 | "n02099601": "Golden Retriever", 210 | "n02099712": "Labrador Retriever", 211 | "n02099849": "Chesapeake Bay Retriever", 212 | "n02100236": "German Shorthaired Pointer", 213 | "n02100583": "Vizsla", 214 | "n02100735": "English Setter", 215 | "n02100877": "Irish Setter", 216 | "n02101006": "Gordon Setter", 217 | "n02101388": "Brittany dog", 218 | "n02101556": "Clumber Spaniel", 219 | "n02102040": "English Springer Spaniel", 220 | "n02102177": "Welsh Springer Spaniel", 221 | "n02102318": "Cocker Spaniel", 222 | "n02102480": "Sussex Spaniel", 223 | "n02102973": "Irish Water Spaniel", 224 | "n02104029": "Kuvasz", 225 | "n02104365": "Schipperke", 226 | "n02105056": "Groenendael dog", 227 | "n02105162": "Malinois", 228 | "n02105251": "Briard", 229 | "n02105412": "Australian Kelpie", 230 | "n02105505": "Komondor", 231 | "n02105641": "Old English Sheepdog", 232 | "n02105855": "Shetland Sheepdog", 233 | "n02106030": "collie", 234 | "n02106166": "Border Collie", 235 | "n02106382": "Bouvier des Flandres dog", 236 | "n02106550": "Rottweiler", 237 | "n02106662": "German Shepherd Dog", 238 | "n02107142": "Dobermann", 239 | "n02107312": "Miniature Pinscher", 240 | "n02107574": "Greater Swiss Mountain Dog", 241 | "n02107683": "Bernese Mountain Dog", 242 | "n02107908": "Appenzeller Sennenhund", 243 | "n02108000": "Entlebucher Sennenhund", 244 | "n02108089": "Boxer", 245 | "n02108422": "Bullmastiff", 246 | "n02108551": "Tibetan Mastiff", 247 | "n02108915": "French Bulldog", 248 | "n02109047": "Great Dane", 249 | "n02109525": "St. Bernard", 250 | "n02109961": "husky", 251 | "n02110063": "Alaskan Malamute", 252 | "n02110185": "Siberian Husky", 253 | "n02110341": "Dalmatian", 254 | "n02110627": "Affenpinscher", 255 | "n02110806": "Basenji", 256 | "n02110958": "pug", 257 | "n02111129": "Leonberger", 258 | "n02111277": "Newfoundland dog", 259 | "n02111500": "Great Pyrenees dog", 260 | "n02111889": "Samoyed", 261 | "n02112018": "Pomeranian", 262 | "n02112137": "Chow Chow", 263 | "n02112350": "Keeshond", 264 | "n02112706": "brussels griffon", 265 | "n02113023": "Pembroke Welsh Corgi", 266 | "n02113186": "Cardigan Welsh Corgi", 267 | "n02113624": "Toy Poodle", 268 | "n02113712": "Miniature Poodle", 269 | "n02113799": "Standard Poodle", 270 | "n02113978": "Mexican hairless dog (xoloitzcuintli)", 271 | "n02114367": "grey wolf", 272 | "n02114548": "Alaskan tundra wolf", 273 | "n02114712": "red wolf or maned wolf", 274 | "n02114855": "coyote", 275 | "n02115641": "dingo", 276 | "n02115913": "dhole", 277 | "n02116738": "African wild dog", 278 | "n02117135": "hyena", 279 | "n02119022": "red fox", 280 | "n02119789": "kit fox", 281 | "n02120079": "Arctic fox", 282 | "n02120505": "grey fox", 283 | "n02123045": "tabby cat", 284 | "n02123159": "tiger cat", 285 | "n02123394": "Persian cat", 286 | "n02123597": "Siamese cat", 287 | "n02124075": "Egyptian Mau", 288 | "n02125311": "cougar", 289 | "n02127052": "lynx", 290 | "n02128385": "leopard", 291 | "n02128757": "snow leopard", 292 | "n02128925": "jaguar", 293 | "n02129165": "lion", 294 | "n02129604": "tiger", 295 | "n02130308": "cheetah", 296 | "n02132136": "brown bear", 297 | "n02133161": "American black bear", 298 | "n02134084": "polar bear", 299 | "n02134418": "sloth bear", 300 | "n02137549": "mongoose", 301 | "n02138441": "meerkat", 302 | "n02165105": "tiger beetle", 303 | "n02165456": "ladybug", 304 | "n02167151": "ground beetle", 305 | "n02168699": "longhorn beetle", 306 | "n02169497": "leaf beetle", 307 | "n02172182": "dung beetle", 308 | "n02174001": "rhinoceros beetle", 309 | "n02177972": "weevil", 310 | "n02190166": "fly", 311 | "n02206856": "bee", 312 | "n02219486": "ant", 313 | "n02226429": "grasshopper", 314 | "n02229544": "cricket insect", 315 | "n02231487": "stick insect", 316 | "n02233338": "cockroach", 317 | "n02236044": "praying mantis", 318 | "n02256656": "cicada", 319 | "n02259212": "leafhopper", 320 | "n02264363": "lacewing", 321 | "n02268443": "dragonfly", 322 | "n02268853": "damselfly", 323 | "n02276258": "red admiral butterfly", 324 | "n02277742": "ringlet butterfly", 325 | "n02279972": "monarch butterfly", 326 | "n02280649": "small white butterfly", 327 | "n02281406": "sulphur butterfly", 328 | "n02281787": "gossamer-winged butterfly", 329 | "n02317335": "starfish", 330 | "n02319095": "sea urchin", 331 | "n02321529": "sea cucumber", 332 | "n02325366": "cottontail rabbit", 333 | "n02326432": "hare", 334 | "n02328150": "Angora rabbit", 335 | "n02342885": "hamster", 336 | "n02346627": "porcupine", 337 | "n02356798": "fox squirrel", 338 | "n02361337": "marmot", 339 | "n02363005": "beaver", 340 | "n02364673": "guinea pig", 341 | "n02389026": "common sorrel horse", 342 | "n02391049": "zebra", 343 | "n02395406": "pig", 344 | "n02396427": "wild boar", 345 | "n02397096": "warthog", 346 | "n02398521": "hippopotamus", 347 | "n02403003": "ox", 348 | "n02408429": "water buffalo", 349 | "n02410509": "bison", 350 | "n02412080": "ram (adult male sheep)", 351 | "n02415577": "bighorn sheep", 352 | "n02417914": "Alpine ibex", 353 | "n02422106": "hartebeest", 354 | "n02422699": "impala (antelope)", 355 | "n02423022": "gazelle", 356 | "n02437312": "arabian camel", 357 | "n02437616": "llama", 358 | "n02441942": "weasel", 359 | "n02442845": "mink", 360 | "n02443114": "European polecat", 361 | "n02443484": "black-footed ferret", 362 | "n02444819": "otter", 363 | "n02445715": "skunk", 364 | "n02447366": "badger", 365 | "n02454379": "armadillo", 366 | "n02457408": "three-toed sloth", 367 | "n02480495": "orangutan", 368 | "n02480855": "gorilla", 369 | "n02481823": "chimpanzee", 370 | "n02483362": "gibbon", 371 | "n02483708": "siamang", 372 | "n02484975": "guenon", 373 | "n02486261": "patas monkey", 374 | "n02486410": "baboon", 375 | "n02487347": "macaque", 376 | "n02488291": "langur", 377 | "n02488702": "black-and-white colobus", 378 | "n02489166": "proboscis monkey", 379 | "n02490219": "marmoset", 380 | "n02492035": "white-headed capuchin", 381 | "n02492660": "howler monkey", 382 | "n02493509": "titi monkey", 383 | "n02493793": "Geoffroy's spider monkey", 384 | "n02494079": "common squirrel monkey", 385 | "n02497673": "ring-tailed lemur", 386 | "n02500267": "indri", 387 | "n02504013": "Asian elephant", 388 | "n02504458": "African bush elephant", 389 | "n02509815": "red panda", 390 | "n02510455": "giant panda", 391 | "n02514041": "snoek fish", 392 | "n02526121": "eel", 393 | "n02536864": "silver salmon", 394 | "n02606052": "rock beauty fish", 395 | "n02607072": "clownfish", 396 | "n02640242": "sturgeon", 397 | "n02641379": "gar fish", 398 | "n02643566": "lionfish", 399 | "n02655020": "pufferfish", 400 | "n02666196": "abacus", 401 | "n02667093": "abaya", 402 | "n02669723": "academic gown", 403 | "n02672831": "accordion", 404 | "n02676566": "acoustic guitar", 405 | "n02687172": "aircraft carrier", 406 | "n02690373": "airliner", 407 | "n02692877": "airship", 408 | "n02699494": "altar", 409 | "n02701002": "ambulance", 410 | "n02704792": "amphibious vehicle", 411 | "n02708093": "analog clock", 412 | "n02727426": "apiary", 413 | "n02730930": "apron", 414 | "n02747177": "trash can", 415 | "n02749479": "assault rifle", 416 | "n02769748": "backpack", 417 | "n02776631": "bakery", 418 | "n02777292": "balance beam", 419 | "n02782093": "balloon", 420 | "n02783161": "ballpoint pen", 421 | "n02786058": "Band-Aid", 422 | "n02787622": "banjo", 423 | "n02788148": "baluster / handrail", 424 | "n02790996": "barbell", 425 | "n02791124": "barber chair", 426 | "n02791270": "barbershop", 427 | "n02793495": "barn", 428 | "n02794156": "barometer", 429 | "n02795169": "barrel", 430 | "n02797295": "wheelbarrow", 431 | "n02799071": "baseball", 432 | "n02802426": "basketball", 433 | "n02804414": "bassinet", 434 | "n02804610": "bassoon", 435 | "n02807133": "swimming cap", 436 | "n02808304": "bath towel", 437 | "n02808440": "bathtub", 438 | "n02814533": "station wagon", 439 | "n02814860": "lighthouse", 440 | "n02815834": "beaker", 441 | "n02817516": "military hat (bearskin or shako)", 442 | "n02823428": "beer bottle", 443 | "n02823750": "beer glass", 444 | "n02825657": "bell tower", 445 | "n02834397": "baby bib", 446 | "n02835271": "tandem bicycle", 447 | "n02837789": "bikini", 448 | "n02840245": "ring binder", 449 | "n02841315": "binoculars", 450 | "n02843684": "birdhouse", 451 | "n02859443": "boathouse", 452 | "n02860847": "bobsleigh", 453 | "n02865351": "bolo tie", 454 | "n02869837": "poke bonnet", 455 | "n02870880": "bookcase", 456 | "n02871525": "bookstore", 457 | "n02877765": "bottle cap", 458 | "n02879718": "hunting bow", 459 | "n02883205": "bow tie", 460 | "n02892201": "brass memorial plaque", 461 | "n02892767": "bra", 462 | "n02894605": "breakwater", 463 | "n02895154": "breastplate", 464 | "n02906734": "broom", 465 | "n02909870": "bucket", 466 | "n02910353": "buckle", 467 | "n02916936": "bulletproof vest", 468 | "n02917067": "high-speed train", 469 | "n02927161": "butcher shop", 470 | "n02930766": "taxicab", 471 | "n02939185": "cauldron", 472 | "n02948072": "candle", 473 | "n02950826": "cannon", 474 | "n02951358": "canoe", 475 | "n02951585": "can opener", 476 | "n02963159": "cardigan", 477 | "n02965783": "car mirror", 478 | "n02966193": "carousel", 479 | "n02966687": "tool kit", 480 | "n02971356": "cardboard box / carton", 481 | "n02974003": "car wheel", 482 | "n02977058": "automated teller machine", 483 | "n02978881": "cassette", 484 | "n02979186": "cassette player", 485 | "n02980441": "castle", 486 | "n02981792": "catamaran", 487 | "n02988304": "CD player", 488 | "n02992211": "cello", 489 | "n02992529": "mobile phone", 490 | "n02999410": "chain", 491 | "n03000134": "chain-link fence", 492 | "n03000247": "chain mail", 493 | "n03000684": "chainsaw", 494 | "n03014705": "storage chest", 495 | "n03016953": "chiffonier", 496 | "n03017168": "bell or wind chime", 497 | "n03018349": "china cabinet", 498 | "n03026506": "Christmas stocking", 499 | "n03028079": "church", 500 | "n03032252": "movie theater", 501 | "n03041632": "cleaver", 502 | "n03042490": "cliff dwelling", 503 | "n03045698": "cloak", 504 | "n03047690": "clogs", 505 | "n03062245": "cocktail shaker", 506 | "n03063599": "coffee mug", 507 | "n03063689": "coffeemaker", 508 | "n03065424": "spiral or coil", 509 | "n03075370": "combination lock", 510 | "n03085013": "computer keyboard", 511 | "n03089624": "candy store", 512 | "n03095699": "container ship", 513 | "n03100240": "convertible", 514 | "n03109150": "corkscrew", 515 | "n03110669": "cornet", 516 | "n03124043": "cowboy boot", 517 | "n03124170": "cowboy hat", 518 | "n03125729": "cradle", 519 | "n03126707": "construction crane", 520 | "n03127747": "crash helmet", 521 | "n03127925": "crate", 522 | "n03131574": "infant bed", 523 | "n03133878": "Crock Pot", 524 | "n03134739": "croquet ball", 525 | "n03141823": "crutch", 526 | "n03146219": "cuirass", 527 | "n03160309": "dam", 528 | "n03179701": "desk", 529 | "n03180011": "desktop computer", 530 | "n03187595": "rotary dial telephone", 531 | "n03188531": "diaper", 532 | "n03196217": "digital clock", 533 | "n03197337": "digital watch", 534 | "n03201208": "dining table", 535 | "n03207743": "dishcloth", 536 | "n03207941": "dishwasher", 537 | "n03208938": "disc brake", 538 | "n03216828": "dock", 539 | "n03218198": "dog sled", 540 | "n03220513": "dome", 541 | "n03223299": "doormat", 542 | "n03240683": "drilling rig", 543 | "n03249569": "drum", 544 | "n03250847": "drumstick", 545 | "n03255030": "dumbbell", 546 | "n03259280": "Dutch oven", 547 | "n03271574": "electric fan", 548 | "n03272010": "electric guitar", 549 | "n03272562": "electric locomotive", 550 | "n03290653": "entertainment center", 551 | "n03291819": "envelope", 552 | "n03297495": "espresso machine", 553 | "n03314780": "face powder", 554 | "n03325584": "feather boa", 555 | "n03337140": "filing cabinet", 556 | "n03344393": "fireboat", 557 | "n03345487": "fire truck", 558 | "n03347037": "fire screen", 559 | "n03355925": "flagpole", 560 | "n03372029": "flute", 561 | "n03376595": "folding chair", 562 | "n03379051": "football helmet", 563 | "n03384352": "forklift", 564 | "n03388043": "fountain", 565 | "n03388183": "fountain pen", 566 | "n03388549": "four-poster bed", 567 | "n03393912": "freight car", 568 | "n03394916": "French horn", 569 | "n03400231": "frying pan", 570 | "n03404251": "fur coat", 571 | "n03417042": "garbage truck", 572 | "n03424325": "gas mask or respirator", 573 | "n03425413": "gas pump", 574 | "n03443371": "goblet", 575 | "n03444034": "go-kart", 576 | "n03445777": "golf ball", 577 | "n03445924": "golf cart", 578 | "n03447447": "gondola", 579 | "n03447721": "gong", 580 | "n03450230": "gown", 581 | "n03452741": "grand piano", 582 | "n03457902": "greenhouse", 583 | "n03459775": "radiator grille", 584 | "n03461385": "grocery store", 585 | "n03467068": "guillotine", 586 | "n03476684": "hair clip", 587 | "n03476991": "hair spray", 588 | "n03478589": "half-track", 589 | "n03481172": "hammer", 590 | "n03482405": "hamper", 591 | "n03483316": "hair dryer", 592 | "n03485407": "hand-held computer", 593 | "n03485794": "handkerchief", 594 | "n03492542": "hard disk drive", 595 | "n03494278": "harmonica", 596 | "n03495258": "harp", 597 | "n03496892": "combine harvester", 598 | "n03498962": "hatchet", 599 | "n03527444": "holster", 600 | "n03529860": "home theater", 601 | "n03530642": "honeycomb", 602 | "n03532672": "hook", 603 | "n03534580": "hoop skirt", 604 | "n03535780": "gymnastic horizontal bar", 605 | "n03538406": "horse-drawn vehicle", 606 | "n03544143": "hourglass", 607 | "n03584254": "iPod", 608 | "n03584829": "clothes iron", 609 | "n03590841": "carved pumpkin", 610 | "n03594734": "jeans", 611 | "n03594945": "jeep", 612 | "n03595614": "T-shirt", 613 | "n03598930": "jigsaw puzzle", 614 | "n03599486": "rickshaw", 615 | "n03602883": "joystick", 616 | "n03617480": "kimono", 617 | "n03623198": "knee pad", 618 | "n03627232": "knot", 619 | "n03630383": "lab coat", 620 | "n03633091": "ladle", 621 | "n03637318": "lampshade", 622 | "n03642806": "laptop computer", 623 | "n03649909": "lawn mower", 624 | "n03657121": "lens cap", 625 | "n03658185": "letter opener", 626 | "n03661043": "library", 627 | "n03662601": "lifeboat", 628 | "n03666591": "lighter", 629 | "n03670208": "limousine", 630 | "n03673027": "ocean liner", 631 | "n03676483": "lipstick", 632 | "n03680355": "slip-on shoe", 633 | "n03690938": "lotion", 634 | "n03691459": "music speaker", 635 | "n03692522": "loupe magnifying glass", 636 | "n03697007": "sawmill", 637 | "n03706229": "magnetic compass", 638 | "n03709823": "messenger bag", 639 | "n03710193": "mailbox", 640 | "n03710637": "tights", 641 | "n03710721": "one-piece bathing suit", 642 | "n03717622": "manhole cover", 643 | "n03720891": "maraca", 644 | "n03721384": "marimba", 645 | "n03724870": "mask", 646 | "n03729826": "matchstick", 647 | "n03733131": "maypole", 648 | "n03733281": "maze", 649 | "n03733805": "measuring cup", 650 | "n03742115": "medicine cabinet", 651 | "n03743016": "megalith", 652 | "n03759954": "microphone", 653 | "n03761084": "microwave oven", 654 | "n03763968": "military uniform", 655 | "n03764736": "milk can", 656 | "n03769881": "minibus", 657 | "n03770439": "miniskirt", 658 | "n03770679": "minivan", 659 | "n03773504": "missile", 660 | "n03775071": "mitten", 661 | "n03775546": "mixing bowl", 662 | "n03776460": "mobile home", 663 | "n03777568": "ford model t", 664 | "n03777754": "modem", 665 | "n03781244": "monastery", 666 | "n03782006": "monitor", 667 | "n03785016": "moped", 668 | "n03786901": "mortar and pestle", 669 | "n03787032": "graduation cap", 670 | "n03788195": "mosque", 671 | "n03788365": "mosquito net", 672 | "n03791053": "vespa", 673 | "n03792782": "mountain bike", 674 | "n03792972": "tent", 675 | "n03793489": "computer mouse", 676 | "n03794056": "mousetrap", 677 | "n03796401": "moving van", 678 | "n03803284": "muzzle", 679 | "n03804744": "metal nail", 680 | "n03814639": "neck brace", 681 | "n03814906": "necklace", 682 | "n03825788": "baby pacifier", 683 | "n03832673": "notebook computer", 684 | "n03837869": "obelisk", 685 | "n03838899": "oboe", 686 | "n03840681": "ocarina", 687 | "n03841143": "odometer", 688 | "n03843555": "oil filter", 689 | "n03854065": "pipe organ", 690 | "n03857828": "oscilloscope", 691 | "n03866082": "overskirt", 692 | "n03868242": "bullock cart", 693 | "n03868863": "oxygen mask", 694 | "n03871628": "product packet / packaging", 695 | "n03873416": "paddle", 696 | "n03874293": "paddle wheel", 697 | "n03874599": "padlock", 698 | "n03876231": "paintbrush", 699 | "n03877472": "pajamas", 700 | "n03877845": "palace", 701 | "n03884397": "pan flute", 702 | "n03887697": "paper towel", 703 | "n03888257": "parachute", 704 | "n03888605": "parallel bars", 705 | "n03891251": "park bench", 706 | "n03891332": "parking meter", 707 | "n03895866": "railroad car", 708 | "n03899768": "patio", 709 | "n03902125": "payphone", 710 | "n03903868": "pedestal", 711 | "n03908618": "pencil case", 712 | "n03908714": "pencil sharpener", 713 | "n03916031": "perfume", 714 | "n03920288": "Petri dish", 715 | "n03924679": "photocopier", 716 | "n03929660": "plectrum", 717 | "n03929855": "Pickelhaube", 718 | "n03930313": "picket fence", 719 | "n03930630": "pickup truck", 720 | "n03933933": "pier", 721 | "n03935335": "piggy bank", 722 | "n03937543": "pill bottle", 723 | "n03938244": "pillow", 724 | "n03942813": "ping-pong ball", 725 | "n03944341": "pinwheel", 726 | "n03947888": "pirate ship", 727 | "n03950228": "drink pitcher", 728 | "n03954731": "block plane", 729 | "n03956157": "planetarium", 730 | "n03958227": "plastic bag", 731 | "n03961711": "plate rack", 732 | "n03967562": "farm plow", 733 | "n03970156": "plunger", 734 | "n03976467": "Polaroid camera", 735 | "n03976657": "pole", 736 | "n03977966": "police van", 737 | "n03980874": "poncho", 738 | "n03982430": "pool table", 739 | "n03983396": "soda bottle", 740 | "n03991062": "plant pot", 741 | "n03992509": "potter's wheel", 742 | "n03995372": "power drill", 743 | "n03998194": "prayer rug", 744 | "n04004767": "printer", 745 | "n04005630": "prison", 746 | "n04008634": "missile", 747 | "n04009552": "projector", 748 | "n04019541": "hockey puck", 749 | "n04023962": "punching bag", 750 | "n04026417": "purse", 751 | "n04033901": "quill", 752 | "n04033995": "quilt", 753 | "n04037443": "race car", 754 | "n04039381": "racket", 755 | "n04040759": "radiator", 756 | "n04041544": "radio", 757 | "n04044716": "radio telescope", 758 | "n04049303": "rain barrel", 759 | "n04065272": "recreational vehicle", 760 | "n04067472": "fishing casting reel", 761 | "n04069434": "reflex camera", 762 | "n04070727": "refrigerator", 763 | "n04074963": "remote control", 764 | "n04081281": "restaurant", 765 | "n04086273": "revolver", 766 | "n04090263": "rifle", 767 | "n04099969": "rocking chair", 768 | "n04111531": "rotisserie", 769 | "n04116512": "eraser", 770 | "n04118538": "rugby ball", 771 | "n04118776": "ruler measuring stick", 772 | "n04120489": "sneaker", 773 | "n04125021": "safe", 774 | "n04127249": "safety pin", 775 | "n04131690": "salt shaker", 776 | "n04133789": "sandal", 777 | "n04136333": "sarong", 778 | "n04141076": "saxophone", 779 | "n04141327": "scabbard", 780 | "n04141975": "weighing scale", 781 | "n04146614": "school bus", 782 | "n04147183": "schooner", 783 | "n04149813": "scoreboard", 784 | "n04152593": "CRT monitor", 785 | "n04153751": "screw", 786 | "n04154565": "screwdriver", 787 | "n04162706": "seat belt", 788 | "n04179913": "sewing machine", 789 | "n04192698": "shield", 790 | "n04200800": "shoe store", 791 | "n04201297": "shoji screen / room divider", 792 | "n04204238": "shopping basket", 793 | "n04204347": "shopping cart", 794 | "n04208210": "shovel", 795 | "n04209133": "shower cap", 796 | "n04209239": "shower curtain", 797 | "n04228054": "ski", 798 | "n04229816": "balaclava ski mask", 799 | "n04235860": "sleeping bag", 800 | "n04238763": "slide rule", 801 | "n04239074": "sliding door", 802 | "n04243546": "slot machine", 803 | "n04251144": "snorkel", 804 | "n04252077": "snowmobile", 805 | "n04252225": "snowplow", 806 | "n04254120": "soap dispenser", 807 | "n04254680": "soccer ball", 808 | "n04254777": "sock", 809 | "n04258138": "solar thermal collector", 810 | "n04259630": "sombrero", 811 | "n04263257": "soup bowl", 812 | "n04264628": "keyboard space bar", 813 | "n04265275": "space heater", 814 | "n04266014": "space shuttle", 815 | "n04270147": "spatula", 816 | "n04273569": "motorboat", 817 | "n04275548": "spider web", 818 | "n04277352": "spindle", 819 | "n04285008": "sports car", 820 | "n04286575": "spotlight", 821 | "n04296562": "stage", 822 | "n04310018": "steam locomotive", 823 | "n04311004": "through arch bridge", 824 | "n04311174": "steel drum", 825 | "n04317175": "stethoscope", 826 | "n04325704": "scarf", 827 | "n04326547": "stone wall", 828 | "n04328186": "stopwatch", 829 | "n04330267": "stove", 830 | "n04332243": "strainer", 831 | "n04335435": "tram", 832 | "n04336792": "stretcher", 833 | "n04344873": "couch", 834 | "n04346328": "stupa", 835 | "n04347754": "submarine", 836 | "n04350905": "suit", 837 | "n04355338": "sundial", 838 | "n04355933": "sunglasses", 839 | "n04356056": "sunglasses", 840 | "n04357314": "sunscreen", 841 | "n04366367": "suspension bridge", 842 | "n04367480": "mop", 843 | "n04370456": "sweatshirt", 844 | "n04371430": "swim trunks / shorts", 845 | "n04371774": "swing", 846 | "n04372370": "electrical switch", 847 | "n04376876": "syringe", 848 | "n04380533": "table lamp", 849 | "n04389033": "tank", 850 | "n04392985": "tape player", 851 | "n04398044": "teapot", 852 | "n04399382": "teddy bear", 853 | "n04404412": "television", 854 | "n04409515": "tennis ball", 855 | "n04417672": "thatched roof", 856 | "n04418357": "front curtain", 857 | "n04423845": "thimble", 858 | "n04428191": "threshing machine", 859 | "n04429376": "throne", 860 | "n04435653": "tile roof", 861 | "n04442312": "toaster", 862 | "n04443257": "tobacco shop", 863 | "n04447861": "toilet seat", 864 | "n04456115": "torch", 865 | "n04458633": "totem pole", 866 | "n04461696": "tow truck", 867 | "n04462240": "toy store", 868 | "n04465501": "tractor", 869 | "n04467665": "semi-trailer truck", 870 | "n04476259": "tray", 871 | "n04479046": "trench coat", 872 | "n04482393": "tricycle", 873 | "n04483307": "trimaran", 874 | "n04485082": "tripod", 875 | "n04486054": "triumphal arch", 876 | "n04487081": "trolleybus", 877 | "n04487394": "trombone", 878 | "n04493381": "hot tub", 879 | "n04501370": "turnstile", 880 | "n04505470": "typewriter keyboard", 881 | "n04507155": "umbrella", 882 | "n04509417": "unicycle", 883 | "n04515003": "upright piano", 884 | "n04517823": "vacuum cleaner", 885 | "n04522168": "vase", 886 | "n04523525": "vaulted or arched ceiling", 887 | "n04525038": "velvet fabric", 888 | "n04525305": "vending machine", 889 | "n04532106": "vestment", 890 | "n04532670": "viaduct", 891 | "n04536866": "violin", 892 | "n04540053": "volleyball", 893 | "n04542943": "waffle iron", 894 | "n04548280": "wall clock", 895 | "n04548362": "wallet", 896 | "n04550184": "wardrobe", 897 | "n04552348": "military aircraft", 898 | "n04553703": "sink", 899 | "n04554684": "washing machine", 900 | "n04557648": "water bottle", 901 | "n04560804": "water jug", 902 | "n04562935": "water tower", 903 | "n04579145": "whiskey jug", 904 | "n04579432": "whistle", 905 | "n04584207": "hair wig", 906 | "n04589890": "window screen", 907 | "n04590129": "window shade", 908 | "n04591157": "Windsor tie", 909 | "n04591713": "wine bottle", 910 | "n04592741": "airplane wing", 911 | "n04596742": "wok", 912 | "n04597913": "wooden spoon", 913 | "n04599235": "wool", 914 | "n04604644": "split-rail fence", 915 | "n04606251": "shipwreck", 916 | "n04612504": "sailboat", 917 | "n04613696": "yurt", 918 | "n06359193": "website", 919 | "n06596364": "comic book", 920 | "n06785654": "crossword", 921 | "n06794110": "traffic or street sign", 922 | "n06874185": "traffic light", 923 | "n07248320": "dust jacket", 924 | "n07565083": "menu", 925 | "n07579787": "plate", 926 | "n07583066": "guacamole", 927 | "n07584110": "consomme", 928 | "n07590611": "hot pot", 929 | "n07613480": "trifle", 930 | "n07614500": "ice cream", 931 | "n07615774": "popsicle", 932 | "n07684084": "baguette", 933 | "n07693725": "bagel", 934 | "n07695742": "pretzel", 935 | "n07697313": "cheeseburger", 936 | "n07697537": "hot dog", 937 | "n07711569": "mashed potatoes", 938 | "n07714571": "cabbage", 939 | "n07714990": "broccoli", 940 | "n07715103": "cauliflower", 941 | "n07716358": "zucchini", 942 | "n07716906": "spaghetti squash", 943 | "n07717410": "acorn squash", 944 | "n07717556": "butternut squash", 945 | "n07718472": "cucumber", 946 | "n07718747": "artichoke", 947 | "n07720875": "bell pepper", 948 | "n07730033": "cardoon", 949 | "n07734744": "mushroom", 950 | "n07742313": "Granny Smith apple", 951 | "n07745940": "strawberry", 952 | "n07747607": "orange", 953 | "n07749582": "lemon", 954 | "n07753113": "fig", 955 | "n07753275": "pineapple", 956 | "n07753592": "banana", 957 | "n07754684": "jackfruit", 958 | "n07760859": "cherimoya (custard apple)", 959 | "n07768694": "pomegranate", 960 | "n07802026": "hay", 961 | "n07831146": "carbonara", 962 | "n07836838": "chocolate syrup", 963 | "n07860988": "dough", 964 | "n07871810": "meatloaf", 965 | "n07873807": "pizza", 966 | "n07875152": "pot pie", 967 | "n07880968": "burrito", 968 | "n07892512": "red wine", 969 | "n07920052": "espresso", 970 | "n07930864": "tea cup", 971 | "n07932039": "eggnog", 972 | "n09193705": "mountain", 973 | "n09229709": "bubble", 974 | "n09246464": "cliff", 975 | "n09256479": "coral reef", 976 | "n09288635": "geyser", 977 | "n09332890": "lakeshore", 978 | "n09399592": "promontory", 979 | "n09421951": "sandbar", 980 | "n09428293": "beach", 981 | "n09468604": "valley", 982 | "n09472597": "volcano", 983 | "n09835506": "baseball player", 984 | "n10148035": "bridegroom", 985 | "n10565667": "scuba diver", 986 | "n11879895": "rapeseed", 987 | "n11939491": "daisy", 988 | "n12057211": "yellow lady's slipper", 989 | "n12144580": "corn", 990 | "n12267677": "acorn", 991 | "n12620546": "rose hip", 992 | "n12768682": "horse chestnut seed", 993 | "n12985857": "coral fungus", 994 | "n12998815": "agaric", 995 | "n13037406": "gyromitra", 996 | "n13040303": "stinkhorn mushroom", 997 | "n13044778": "earth star fungus", 998 | "n13052670": "hen of the woods mushroom", 999 | "n13054560": "bolete", 1000 | "n13133613": "corn cob", 1001 | "n15075141": "toilet paper" 1002 | } -------------------------------------------------------------------------------- /Priming/prime.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | import open_clip 4 | import torchvision 5 | import numpy as np 6 | import importlib 7 | import os 8 | import json 9 | import torchvision.datasets as datasets 10 | from torch.utils.data import Dataset, DataLoader 11 | from data import ImageNetV2, CustomDataset 12 | from util import zeroshot_classifier_gpt, zeroshot_classifier, centroid, create_exp_ID 13 | from tqdm import tqdm 14 | from args import parse_arguments 15 | 16 | Shift_Datasets = ['ImageNet-V2', 'sketch','ImageNet-r', 'ImageNet-a'] 17 | 18 | #See args.py for list of arguments 19 | args = parse_arguments() 20 | template_folder = 'Templates' 21 | cache_path = args.cache_path 22 | dataset_name = args.dataset 23 | results_path = args.results_path 24 | experiment_ID = create_exp_ID(args) 25 | 26 | if not os.path.exists(results_path): 27 | os.makedirs(results_path) 28 | if not os.path.exists(cache_path): 29 | os.makedirs(cache_path) 30 | 31 | if dataset_name == 'ImageNet': 32 | f = open('imagenet.json',) 33 | data = json.load(f) 34 | labels = list(data.values()) 35 | class_names = list(data.values()) 36 | dataset_obj = importlib.import_module(template_folder + '.' + dataset_name) 37 | templates = dataset_obj.templates 38 | 39 | elif dataset_name in ['ImageNet-V2', 'sketch', 'ImageNet-r', 'ImageNet-a']: 40 | f = open('imagenet.json',) 41 | data = json.load(f) 42 | dataset_obj = importlib.import_module(template_folder + '.' + 'ImageNet') 43 | labels = list(data.keys()) 44 | class_names = list(data.values()) 45 | templates = dataset_obj.templates 46 | 47 | else: 48 | dataset_obj = importlib.import_module(template_folder + '.' + dataset_name) 49 | labels = dataset_obj.classes 50 | class_names = labels 51 | templates = dataset_obj.templates 52 | 53 | print('Loading model') 54 | dataset_type = args.base_dataset 55 | model_type = args.model 56 | model, _, preprocess = open_clip.create_model_and_transforms(model_type, pretrained=args.base_dataset) 57 | 58 | 59 | #Use torchvision dataloaders when possible for train/test splits. Otherwise use custom loaders. 60 | if dataset_name in Shift_Datasets + ['ImageNet']: 61 | split = 'val' 62 | test_set = CustomDataset(args.val_path, preprocess) 63 | train_set = CustomDataset(args.train_path, preprocess) 64 | 65 | elif dataset_name == 'ImageNet-V2': 66 | split = 'val' 67 | test_set = ImageNetV2(root, preprocess) 68 | train_set = ImageNetV2(root, preprocess) 69 | 70 | elif dataset_name == 'sketch': 71 | split = 'val' 72 | test_set = torchvision.datasets.ImageFolder(root, preprocess) 73 | train_set = torchvision.datasets.ImageFolder(root, preprocess) 74 | 75 | elif dataset_name == 'ImageNet-r': 76 | split = 'val' 77 | test_set = CustomDataset(root, preprocess) 78 | train_set = CustomDataset(root, preprocess) 79 | 80 | elif dataset_name == 'ImageNet-a': 81 | split = 'val' 82 | test_set = CustomDataset(root, preprocess) 83 | train_set = CustomDataset(root, preprocess) 84 | 85 | elif dataset_name == 'SUN397': 86 | dataset = torchvision.datasets.__getattribute__(dataset_name) 87 | test_set = dataset(args.val_path, transform = preprocess, download = True) 88 | train_set = test_set 89 | else: 90 | root = './' 91 | test_split = 'test' 92 | if args.dataset == 'OxfordIIITPet': 93 | train_split = 'trainval' 94 | else: 95 | train_split = 'train' 96 | dataset = torchvision.datasets.__getattribute__(dataset_name) 97 | test_set = dataset(root, split = test_split, transform = preprocess, download = True) 98 | train_set = dataset(root, split = train_split, transform = preprocess, download = True) 99 | 100 | if args.prime: 101 | #Use custom_data for SUN, ImageNet variants, and transductive datasets. 102 | if args.custom_data: 103 | subset = CustomDataset(args.subset_path, transform = preprocess) 104 | else: 105 | subset = torchvision.datasets.ImageFolder(args.subset_path, transform = preprocess) 106 | 107 | if args.prime: 108 | cti = subset.class_to_idx 109 | keys = cti.keys() 110 | vals = cti.values() 111 | if args.dataset =='OxfordIIITPet': 112 | keys = [x.lower() for x in keys] 113 | cti = dict(zip(keys,vals)) 114 | idx_map = dict() 115 | for i,j in enumerate(labels): 116 | if '/' in j: 117 | j = j.replace('/', 'or') 118 | try: 119 | index = cti[j] 120 | idx_map[index] = i 121 | except Exception as e: 122 | pass 123 | 124 | train_set = DataLoader(train_set, batch_size = args.batch_size, shuffle = True, num_workers = args.num_workers) 125 | test_set = DataLoader(test_set, batch_size = args.batch_size, shuffle = False, num_workers = args.num_workers) 126 | if args.prime: 127 | subset = DataLoader(subset, batch_size = args.batch_size, shuffle = False, num_workers = args.num_workers) 128 | 129 | 130 | 131 | 132 | print('creating OpenAI text classifier') 133 | model.eval() 134 | if args.cuda: 135 | model = model.cuda() 136 | tokenizer = open_clip.get_tokenizer(args.model) 137 | text = tokenizer(class_names) 138 | text_features = zeroshot_classifier(class_names, templates, model, tokenizer).cpu().numpy().T 139 | 140 | if args.cupl: 141 | print('creating CuPL classifier') 142 | tokenizer = open_clip.get_tokenizer(args.model) 143 | text_features_cupl = zeroshot_classifier_gpt(class_names, model, tokenizer, dataset_name).cpu().numpy().T 144 | #average cupl with OpenAI prompts 145 | text_features = (text_features + text_features_cupl)/2 146 | 147 | 148 | 149 | if args.retrain and dataset_name not in Shift_Datasets and args.shots > 0: 150 | print('Processing train image features') 151 | train_set_cpu = None 152 | train_labels = None 153 | for x,y in tqdm(train_set): 154 | if args.cuda: 155 | x = x.cuda() 156 | feats = model.encode_image(x).detach().cpu() 157 | feats /= feats.norm(dim=-1, keepdim=True) 158 | feats = feats.squeeze(0).numpy() 159 | 160 | if train_set_cpu is None: 161 | train_set_cpu = feats 162 | else: 163 | train_set_cpu = np.concatenate((train_set_cpu, feats), axis = 0) 164 | if train_labels is None: 165 | train_labels = y 166 | else: 167 | train_labels = np.concatenate((train_labels, y), axis = 0) 168 | if args.cache: 169 | np.save(cache_path + 'train_feats_' + experiment_ID, train_set_cpu) 170 | np.save(cache_path + 'train_labels_'+ experiment_ID, train_labels) 171 | 172 | if args.shots > 0 and args.cache: 173 | train_set_cpu = np.load(cache_path + 'train_feats_'+ experiment_ID+'.npy') 174 | train_labels = np.load(cache_path + 'train_labels_'+ experiment_ID+'.npy') 175 | else: 176 | train_set_cpu_sampled = [] 177 | train_labels_sampled = [] 178 | 179 | if args.retrain: 180 | print('Processing test image features') 181 | test_set_cpu = None 182 | test_labels = None 183 | i=0 184 | for x,y in tqdm(test_set): 185 | i+=1 186 | if i > args.test_batches: 187 | break 188 | if args.cuda: 189 | x = x.cuda() 190 | feats = model.encode_image(x).detach().cpu() 191 | feats /= feats.norm(dim=-1, keepdim=True) 192 | feats = feats.numpy() 193 | if test_set_cpu is None: 194 | test_set_cpu = feats 195 | else: 196 | test_set_cpu = np.concatenate((test_set_cpu, feats), axis = 0) 197 | if test_labels is None: 198 | test_labels = y 199 | else: 200 | test_labels = np.concatenate((test_labels, y), axis = 0) 201 | if args.cache: 202 | np.save(cache_path + 'test_feats_'+experiment_ID, test_set_cpu) 203 | np.save(cache_path + 'test_labels_'+experiment_ID, test_labels) 204 | if args.cache: 205 | test_set_cpu = np.load(cache_path + 'test_feats_'+experiment_ID+'.npy') 206 | test_labels = np.load(cache_path + 'test_labels_'+experiment_ID+'.npy') 207 | 208 | subset_cpu = [] 209 | subset_labels=[] 210 | print('Creating priming image features') 211 | if args.retrain and args.prime: 212 | subset_cpu = None 213 | subset_labels=None 214 | print('Parsing Subset, length {}'.format(len(subset))) 215 | 216 | for x,y in tqdm(subset): 217 | if args.cuda: 218 | x = x.cuda() 219 | feats = model.encode_image(x).detach().cpu() 220 | feats /= feats.norm(dim=-1, keepdim=True) 221 | feats = feats.numpy() 222 | 223 | if subset_cpu is None: 224 | subset_cpu = feats 225 | else: 226 | subset_cpu = np.concatenate((subset_cpu, feats), axis = 0) 227 | if subset_labels is None: 228 | subset_labels = y 229 | else: 230 | subset_labels = np.concatenate((subset_labels, y.detach().cpu()), axis = 0) 231 | if args.cache: 232 | np.save(cache_path + 'subset_feats_' + experiment_ID + os.path.split(args.subset_path)[-1], subset_cpu) 233 | np.save(cache_path + 'subset_labels_' + experiment_ID+os.path.split(args.subset_path)[-1], subset_labels) 234 | 235 | if args.prime and args.cache: 236 | subset_cpu = np.load(cache_path + 'subset_feats_'+experiment_ID+os.path.split(args.subset_path)[-1]+'.npy') 237 | subset_labels = np.load(cache_path + 'subset_labels_'+experiment_ID+os.path.split(args.subset_path)[-1]+'.npy') 238 | subset_labels = [idx_map[x] for x in subset_labels] 239 | 240 | indices = [] 241 | shot = args.shots 242 | 243 | means = [] 244 | if args.shots > 0: 245 | indices = [] 246 | for i in range(0,len(labels)): 247 | idx = np.random.choice(np.where(train_labels==i)[0], size = shot, replace = False) 248 | indices += list(idx) 249 | train_set_cpu_sampled = list(train_set_cpu[indices]) 250 | train_labels_sampled = list(train_labels[indices]) 251 | 252 | if args.prime: 253 | train_set_cpu_sampled += list(subset_cpu) 254 | train_labels_sampled += list(subset_labels) 255 | 256 | if args.shots > 0 or args.prime: 257 | cents = centroid(train_set_cpu_sampled, train_labels_sampled).numpy() 258 | cents = np.nan_to_num(cents) 259 | alpha = args.alpha 260 | text_features_new = (alpha)*text_features + (1.0-alpha)*cents 261 | else: 262 | text_features_new = text_features 263 | 264 | text_probs = (test_set_cpu @ text_features_new.T).argmax(axis=-1) 265 | acc = (test_labels == text_probs).mean() 266 | print("text zero-shot: {}".format(acc)) 267 | np.save(results_path + 'accuracy' + '_' + experiment_ID + '_' + str(args.prime), acc) 268 | 269 | 270 | 271 | 272 | -------------------------------------------------------------------------------- /Priming/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | from PIL import Image 5 | import json 6 | from tqdm import tqdm 7 | 8 | def zeroshot_classifier(classnames, templates, model, tokenizer): 9 | with torch.no_grad(): 10 | zeroshot_weights = [] 11 | for classname in tqdm(classnames): 12 | texts = [template.format(classname) for template in templates] #format with class 13 | texts = tokenizer(texts) 14 | if next(model.parameters()).is_cuda: 15 | texts = texts.cuda() 16 | class_embeddings = model.encode_text(texts) #embed with text encoder 17 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) 18 | class_embedding = class_embeddings.mean(dim=0) 19 | class_embedding /= class_embedding.norm() 20 | zeroshot_weights.append(class_embedding) 21 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1) 22 | if next(model.parameters()).is_cuda: 23 | zeroshot_weights = zeroshot_weights.cuda() 24 | return zeroshot_weights 25 | 26 | def zeroshot_classifier_gpt(classnames, model, tokenizer, dataset = None, templates=None, use_both=False): 27 | 28 | # keys = [x.replace('(', '').replace(')', '') for x in gpt3_prompts.keys()] 29 | #keys = [x.lower() for x in gpt3_prompts.keys()] 30 | # values = gpt3_prompts.values() 31 | # gpt3_prompts = dict(zip(keys, values)) 32 | 33 | if dataset in ['ImageNet', 'ImageNet-a','ImageNet-r', 'sketch', 'ImageNet-V2']: 34 | with open('CuPL_image_prompts.json') as f: 35 | gpt3_prompts = json.load(f) 36 | 37 | elif dataset == 'oxford-iiit-pet': 38 | with open('pets_prompts_full.json') as f: 39 | gpt3_prompts = json.load(f) 40 | 41 | elif dataset == 'SUN397': 42 | with open('sun_prompts_full.json') as f: 43 | gpt3_prompts = json.load(f) 44 | keys = [x.replace('(', '').replace(')', '') for x in gpt3_prompts.keys()] 45 | keys = [x.lower() for x in gpt3_prompts.keys()] 46 | values = gpt3_prompts.values() 47 | gpt3_prompts = dict(zip(keys, values)) 48 | 49 | elif dataset == 'Flowers102': 50 | with open('flower_prompts_full.json') as f: 51 | gpt3_prompts = json.load(f) 52 | elif dataset == 'StanfordCars': 53 | with open('cars_prompts_full.json') as f: 54 | gpt3_prompts = json.load(f) 55 | 56 | elif dataset == 'Food101': 57 | with open('descriptors_food101.json') as f: 58 | gpt3_prompts = json.load(f) 59 | 60 | classnames = [x.replace(' or ', ' / ') for x in classnames] 61 | with torch.no_grad(): 62 | zeroshot_weights = [] 63 | for i in tqdm(range(len(classnames))): 64 | if use_both: 65 | texts = [template.format(classnames[i]) for template in templates] 66 | else: 67 | texts = [] 68 | 69 | for t in gpt3_prompts[classnames[i]]: 70 | texts.append(t) 71 | texts = tokenizer(texts) #tokenize 72 | if next(model.parameters()).is_cuda: 73 | texts = texts.cuda() 74 | class_embeddings = model.encode_text(texts) #embed with text encoder 75 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) 76 | class_embedding = class_embeddings.mean(dim=0) 77 | class_embedding /= class_embedding.norm().cpu().detach() 78 | zeroshot_weights.append(class_embedding) 79 | 80 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1) 81 | return zeroshot_weights 82 | 83 | 84 | def centroid(embeddings, labels): 85 | labels = torch.tensor(labels) 86 | embeddings = torch.tensor(np.array(embeddings)) 87 | onehot = torch.zeros(labels.size(0), labels.max()+1) 88 | filled_onehot = onehot.scatter_(1, labels.unsqueeze(dim=1), 1) 89 | new_prototypes = torch.mm(filled_onehot.permute((1, 0)), embeddings) 90 | new_prototypes /= new_prototypes.norm(dim=-1, keepdim=True) 91 | return new_prototypes 92 | 93 | def create_exp_ID(args): 94 | return args.model+ '_' + args.base_dataset 95 | + '_' + os.path.split(args.subset_path)[-1] 96 | 97 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural Priming 2 | 3 | Pytorch implementation of [Neural Priming for Sample-Efficient Adaptation](https://arxiv.org/pdf/2306.10191.pdf). Neural Priming improves transfer learning accuracy and robustness by recalling relevant data from the pretraining dataset. 4 | 5 | 6 | 7 | 8 | 9 | ## Getting Started 10 | 11 | If you'd like a demo of the code see the [collab notebook](https://colab.research.google.com/drive/1FkwnkfHCwBjsxsfy_WdxF6W_m6US28ft?usp=sharing). 12 | 13 | ### Installation 14 | 15 | 1. Clone this repository to your local machine using the following command: 16 | ```bash 17 | git clone https://github.com/your-username/neural-priming.git 18 | ``` 19 | 3. Navigate to the project directory: 20 | ```bash 21 | cd neural-priming 22 | ``` 23 | 5. (Optional) - Create a conda environment: 24 | ```bash 25 | conda create -n neural-priming python=3.8 26 | conda activate neural-priming 27 | ``` 28 | 6. Install the required dependencies from the requirements.txt file: 29 | ```bash 30 | pip install -r requirements.txt 31 | ``` 32 | 33 | ### Downloading The Data 34 | 35 | #### Priming Data (LAION-2B) 36 | To get started quickly we provide the priming subsets of LAION-2B for each target task. The link to download the data from Google Drive can be found [here](https://drive.google.com/drive/folders/1yQfr6IYrG8_ViuQW7hOtHHr6yalrQ8d0?usp=sharing). If downloading to a headless server we recommend using [gdown](https://github.com/wkentaro/gdown). Once downloaded, unzip and place in the `/data` folder in the root directory. 37 | 38 | Alternatively, we provide code in the **Text Filtering and Downloading Data** section for creating your own priming subset. 39 | 40 | #### Evaluation Data 41 | 42 | - To download ImageNet-1k: Download from this Kaggle [link](https://www.kaggle.com/competitions/imagenet-object-localization-challenge/data?select=ILSVRC). The validation and train set should be in ImageFolder format which looks like the following. 43 | ``` 44 | ├── ImageNet_Train 45 | ├── n01440764 46 | ├── img_1.jpeg 47 | ... 48 | ├── img_N.jpeg 49 | 50 | ... 51 | 52 | ├── n01632777 53 | ├── img_1.jpeg 54 | ... 55 | ├── img_N.jpeg 56 | ``` 57 | 58 | 59 | - Torchvision Datasets: The 6 transfer learning datasets (StanfordCars, FGVC Aircraft, Flowers102, OxfordPets, and SUN397) will automatically be downloaded upon running the training code. 60 | 61 | - The other datasets (ImageNetV2, ImageNet-a, r, and sketch) can be found on their respective webpages. For ImageNetV2 we use the *ImageNetV2-matched-frequency* version. 62 | 63 | 64 | ## Train and Evaluate Models 65 | 66 | 67 | ### Zero-shot Priming 68 | Example commands for priming and evaluating the model: 69 | - ```python Priming/prime.py --dataset Flowers102 --shots 0 --alpha .7 --prime --subset_path ./data/Flowers102 --retrain``` 70 | - ```python Priming/prime.py --dataset StanfordCars --shots 0 --prime --subset_path ./data/StanfordCars --retrain``` 71 | 72 | Note: At this current time, StanfordCars is not available through torchvision. The dataset can be downloaded from [kaggle](https://www.kaggle.com/datasets/jessicali9530/stanford-cars-dataset). 73 | - ``` 74 | python Priming/prime.py --dataset ImageNet --shots 0 --prime --cupl --subset_path /data/ImageNet_subset --train_path ./data/ImageNet/train --val_path /data/ImageNet/val --retrain 75 | ``` 76 | To run the equivalent baselines, omit the `--prime` flag. For example: 77 | 78 | ```python priming/prime.py --dataset Flowers102 --shots 0 --subset_path ./data/Flowers102 --retrain``` 79 | 80 | Zero-shot Results: 81 | | | ImageNet | Stanford Cars | FGVC Aircraft | Flowers102 | Food101 | Oxford Pets | SUN397 | 82 | |-------------------------|----------|---------------|---------------|------------|---------|-------------|--------| 83 | | CLIP Baseline | 68.30 | 87.40 | 25.86 | 71.65 | 86.58 | 90.21 | 67.35 | 84 | | CuPL | 70.25 | 88.63 | 29.64 | 72.32 | 86.20 | 91.16 | 70.80 | 85 | | Priming (Ours) | 70.75 | 89.30 | 33.03 | 79.81 | 86.66 | 91.87 | 71.21 | 86 | | Priming + CuPL (Ours) | 71.38 | 90.23 | 36.00 | 80.04 | 86.86 | 91.85 | 72.35 | 87 | 88 | 89 | ### Few-shot Priming 90 | Example commands for reproducing few-shot results. Note that alpha depends on the number of shots used. 91 | 92 | ```bash 93 | python priming/prime.py --dataset Flowers102 --shots 2 --alpha .58 --prime --subset_path ./data/Flowers102 --retrain 94 | ``` 95 | 96 | ```bash 97 | python priming/prime.py --dataset FGVCAircraft --shots 3 --alpha .55 --prime --subset_path ./data/FGVCAircraft --retrain 98 | ``` 99 | 100 |
101 | FGVC_FSL 102 | Flowers_FSL 103 |
104 | 105 | 106 | ### Transductive Priming 107 | Example commands for reproducing transductive results on the distribution shift datasets: 108 | 109 | 110 | ```bash 111 | python prime.py --dataset ImageNet-V2 --shots 0 --prime --cupl --subset_path ./data/ImageNetv2 --val_path ./data/ImageNetV2-matched-frequency --custom_data --retrain 112 | ``` 113 | 114 | | | ImageNet-V2 | ImageNet-R | ImageNet Sketch | ImageNet-A | 115 | |--------------------------------------|-------------|------------|-----------------|------------| 116 | | CLIP | 62.61 | 64.57 | 57.05 | 35.95 | 117 | | Transduct. Priming (Ours) |64.23 | 79.37 | 59.97 | 38.20 | 118 | 119 | 120 | Command line options: 121 | 122 | - `--prime` Use the priming subset to condition the model. 123 | - `--retrain` Reprocess the image features from the train/val/subset datasets. If already cached, omit to avoid reprocessing. 124 | - `--text` Initialize the classifier with the text prompts from OpenAI for ensembling with image features. 125 | - `--cupl` Initialize the classifier with text prompts from CuPL and OpenAI. 126 | - `--cache` Whether to cache the image features of the train/test/priming subset. Set to true by default. Set to false if low on disk space. 127 | - `--alpha` Ensembling coefficient between text and image features. Depends on the size of the training/priming set. 128 | - `--shots` Number of examples to be used from the target training set (not to be confused with the priming subset). 129 | - `--model` Change the base model. The priming subsets provided above are from the B-16 model. 130 | - `--subset_path` Path to the priming subset. 131 | - `--val_path` Path to the evaluation dataset. Only required for ImageNet and the distribution shift datasets. 132 | 133 | 134 | For full list of command line options see `args.py`. 135 | 136 | 137 | ## Creating Subsets from LAION-2B 138 | 139 | ### Text filtering and Downloading Images 140 | To quickly filter through the LAION-2B dataset using text, we use SQLite in python. The data base parquets can be downloaded [here](https://drive.google.com/drive/folders/1yQfr6IYrG8_ViuQW7hOtHHr6yalrQ8d0?usp=sharing). We recommend [gdown](https://github.com/wkentaro/gdown) if downloading to a headless server. Once downloaded, place them in a `/parquets` folder. Each parquet is around 8 GB and all parquets are about 1 TB. If disk space is limited, you can download fewer parquets and filter on a subset of LAION-2B. Also note that placing the parquets on SSD will significantly improve search speed. 141 | You'll need the sqlite package which can be installed with 142 | ```pip install pysqlite3```. 143 | Given the class names for a dataset, the code will filter for LAION-2B entries which have captions containing the class name and write the meta data to a json. Example to command to filter for ImageNet classes: 144 | 145 | ```bash 146 | python ./DataFiltering/FilterData.py -o ./ImageNet_Filtered -q ImageNet \ 147 | -d /parquets/part-00{000..123}-5114fd87-297e-42b0-9d11-50f1df323dfa-c000.snappy.db --template 148 | ``` 149 | Adjust the line text `{000..123}` if using fewer parquets. To filter using with respect to your own custom dataset, places the class names in a .py file in templates and set it as the `--q` argument in the above command. 150 | 151 | Once the data is stored in the json, you can download the data from URLS using the following command: 152 | 153 | ```python DataFiltering/download_urls.py --r ./DataFiltering/ImageNet_Filtered/ --w ./data/ImageNet_filtered``` 154 | 155 | Note that links break over time. A signficantly smaller number of images may be actually downloaded. 156 | 157 | 158 | 159 | 160 | 161 | ### Transductive Filtering 162 | Once the **Text Filtering and Downloading Images** has been completed, transductive filtering can be performed on the text-filtered subset. An example command would be the following: 163 | 164 | Example command. Given the priming pool `ImageNet_filtered` and a path to a ground-truth dataset (`ImageNet`), this takes 10 dataset train images per class (`--k-shot=10`) and retrieves the 10 closest images for each of these from the priming pool (`--retrievals-per-image=10`): 165 | 166 | ```bash 167 | python ./DataFilering/TransFiltering.py \ 168 | --dataset-type ImageNet \ 169 | --retrieval-path "./data/ImageNet_filtered" \ 170 | --transductive-path="/usr/data" \ 171 | --cache-path="./data/kshot_cache_ImageNet" \ 172 | --out-dir="/ImageNet" \ 173 | --retrievals-per-image=10 \ 174 | --prompt-file=./data/ImageNet.json \ 175 | --k-shot=10 \ 176 | --split="train" 177 | ``` 178 | 179 | This returns a dataset in the `ImageFolder` format, where each retrieved image is labeled by its given class in the priming pool. If you use `--split="test"`, this becomes the transductive setting discussed in the paper. You can add the `--clip-filter` command to apply a CLIP classifier to this new pool to further refine the dataset. 180 | 181 | See the `DataFiltering/TransFiltering.py` file for more details. Below is a list of arguments and descriptions: 182 | 183 | ```bash 184 | usage: TransFiltering.py [-h] --retrieval-path RETRIEVAL_PATH --transductive-path TRANSDUCTIVE_PATH [--k-shot K_SHOT] 185 | [--cache-path CACHE_PATH] --out-dir OUT_DIR [--retrievals-per-image RETRIEVALS_PER_IMAGE] [--clip-filter] 186 | [--prompt-file PROMPT_FILE] [--dataset-type DATASET_TYPE] [--clip-score-filter CLIP_SCORE_FILTER] 187 | [--split SPLIT] [--model MODEL] [--pretrained PRETRAINED] 188 | 189 | optional arguments: 190 | -h, --help show this help message and exit 191 | --retrieval-path RETRIEVAL_PATH 192 | Path to retrieval reservoir (clean subset of LAION) 193 | --transductive-path TRANSDUCTIVE_PATH 194 | Path to transfer evaluation dataset (to be used in a transductive fashion) 195 | --k-shot K_SHOT Number of shots per class, only used in train-time data augmentation 196 | --cache-path CACHE_PATH 197 | Path to cache 198 | --out-dir OUT_DIR Path to output directory (dataset in ImageFolder format) 199 | --retrievals-per-image RETRIEVALS_PER_IMAGE 200 | Number of retrievals per image 201 | --clip-filter Filter using CLIP before transductive retrieval 202 | --prompt-file PROMPT_FILE 203 | Path to prompt file, format classname to list of prompts 204 | --dataset-type DATASET_TYPE 205 | Type of dataset 206 | --clip-score-filter CLIP_SCORE_FILTER 207 | Filter using CLIP score, after clip classification filtering 208 | --split SPLIT Split to use, only applies to non-ImageFolder datasets 209 | --model MODEL Model arch from open_clip to use for filtering 210 | --pretrained PRETRAINED 211 | Pre-trained weights from open_clip to use for filtering. See open_clip repo for choices 212 | ``` 213 | -------------------------------------------------------------------------------- /assets/Aircraft_FSL.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RAIVNLab/neural-priming/e9146cbb07663ea4aacf4d3c6e99c47bf12107f6/assets/Aircraft_FSL.jpeg -------------------------------------------------------------------------------- /assets/Cars_FSL.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RAIVNLab/neural-priming/e9146cbb07663ea4aacf4d3c6e99c47bf12107f6/assets/Cars_FSL.jpeg -------------------------------------------------------------------------------- /assets/Flowers_FSL.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RAIVNLab/neural-priming/e9146cbb07663ea4aacf4d3c6e99c47bf12107f6/assets/Flowers_FSL.jpeg -------------------------------------------------------------------------------- /assets/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RAIVNLab/neural-priming/e9146cbb07663ea4aacf4d3c6e99c47bf12107f6/assets/teaser.jpg -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: neural-priming 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2023.08.22=h06a4308_0 8 | - ld_impl_linux-64=2.38=h1181459_1 9 | - libffi=3.4.4=h6a678d5_0 10 | - libgcc-ng=11.2.0=h1234567_1 11 | - libgomp=11.2.0=h1234567_1 12 | - libstdcxx-ng=11.2.0=h1234567_1 13 | - ncurses=6.4=h6a678d5_0 14 | - openssl=3.0.11=h7f8727e_2 15 | - pip=23.2.1=py38h06a4308_0 16 | - python=3.8.18=h955ad1f_0 17 | - readline=8.2=h5eee18b_0 18 | - setuptools=68.0.0=py38h06a4308_0 19 | - sqlite=3.41.2=h5eee18b_0 20 | - tk=8.6.12=h1ccaba5_0 21 | - wheel=0.41.2=py38h06a4308_0 22 | - xz=5.4.2=h5eee18b_0 23 | - zlib=1.2.13=h5eee18b_0 24 | - pip: 25 | - certifi==2023.7.22 26 | - charset-normalizer==3.3.0 27 | - filelock==3.12.4 28 | - fsspec==2023.9.2 29 | - ftfy==6.1.1 30 | - huggingface-hub==0.18.0 31 | - idna==3.4 32 | - jinja2==3.1.2 33 | - markupsafe==2.1.3 34 | - mpmath==1.3.0 35 | - networkx==3.1 36 | - numpy==1.24.4 37 | - nvidia-cublas-cu12==12.1.3.1 38 | - nvidia-cuda-cupti-cu12==12.1.105 39 | - nvidia-cuda-nvrtc-cu12==12.1.105 40 | - nvidia-cuda-runtime-cu12==12.1.105 41 | - nvidia-cudnn-cu12==8.9.2.26 42 | - nvidia-cufft-cu12==11.0.2.54 43 | - nvidia-curand-cu12==10.3.2.106 44 | - nvidia-cusolver-cu12==11.4.5.107 45 | - nvidia-cusparse-cu12==12.1.0.106 46 | - nvidia-nccl-cu12==2.18.1 47 | - nvidia-nvjitlink-cu12==12.2.140 48 | - nvidia-nvtx-cu12==12.1.105 49 | - open-clip-torch==2.22.0 50 | - packaging==23.2 51 | - pillow==10.0.1 52 | - protobuf==3.20.3 53 | - pyyaml==6.0.1 54 | - regex==2023.10.3 55 | - requests==2.31.0 56 | - safetensors==0.4.0 57 | - scipy==1.10.1 58 | - sentencepiece==0.1.99 59 | - sympy==1.12 60 | - timm==0.9.7 61 | - torch==2.1.0 62 | - torchvision==0.16.0 63 | - tqdm==4.66.1 64 | - triton==2.1.0 65 | - typing-extensions==4.8.0 66 | - urllib3==2.0.6 67 | - wcwidth==0.2.8 68 | prefix: /home/matt/anaconda3/envs/neural-priming 69 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | open-clip-torch==2.22.0 2 | scipy==1.10.1 3 | --------------------------------------------------------------------------------