├── data ├── __init__.py ├── datautils.py ├── hoi_dataset.py ├── fewshot_datasets.py ├── augmix_ops.py ├── imagnet_prompts.py ├── cls_to_names.py └── imagenet_variants.py ├── utils ├── __init__.py └── tools.py ├── clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── simple_tokenizer.py ├── clip.py ├── cocoop.py ├── custom_clip.py └── model.py ├── figures └── CTPT_ICLR2024_poster.png ├── scripts ├── test_tpt_fg.sh ├── test_baseline.sh ├── test_tpt_ds.sh ├── test_tpt_ctpt_fg.sh └── test_tpt_ctpt_ds.sh ├── LICENSE ├── README.md ├── environment.yml └── tpt_classification.py /data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .tools import * 2 | -------------------------------------------------------------------------------- /clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | from .custom_clip import * -------------------------------------------------------------------------------- /figures/CTPT_ICLR2024_poster.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hee-suk-yoon/C-TPT/HEAD/figures/CTPT_ICLR2024_poster.png -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hee-suk-yoon/C-TPT/HEAD/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /scripts/test_tpt_fg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | data_root='put your dataset path here' 4 | testsets=$1 5 | #arch=RN50 6 | arch=ViT-B/16 7 | bs=64 8 | ctx_init=a_photo_of_a 9 | run_type=tpt 10 | 11 | python ./tpt_classification.py ${data_root} --test_sets ${testsets} \ 12 | -a ${arch} -b ${bs} --gpu 0 \ 13 | --tpt --ctx_init ${ctx_init} --run_type ${run_type} \ 14 | -------------------------------------------------------------------------------- /scripts/test_baseline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | data_root='put your dataset path here' 4 | testsets=$1 5 | #arch=RN50 6 | arch=ViT-B/16 7 | bs=64 8 | 9 | ctx_init=a_photo_of_a 10 | run_type=baseline 11 | 12 | python ./tpt_classification.py ${data_root} --test_sets ${testsets} \ 13 | -a ${arch} -b ${bs} --gpu 0 \ 14 | --tpt --ctx_init ${ctx_init} --run_type ${run_type} \ -------------------------------------------------------------------------------- /scripts/test_tpt_ds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | data_root='put your dataset path here' 4 | testsets=$1 5 | #arch=RN50 6 | arch=ViT-B/16 7 | bs=64 8 | ctx_init=a_photo_of_a 9 | run_type=tpt 10 | 11 | python ./tpt_classification.py ${data_root} --test_sets ${testsets} \ 12 | -a ${arch} -b ${bs} --gpu 0 \ 13 | --tpt --ctx_init ${ctx_init} --run_type ${run_type} --I_augmix \ 14 | -------------------------------------------------------------------------------- /scripts/test_tpt_ctpt_fg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | data_root='put 
your dataset path here' 4 | testsets=$1 5 | #arch=RN50 6 | arch=ViT-B/16 7 | bs=64 8 | ctx_init=a_photo_of_a 9 | run_type=tpt_ctpt 10 | lambda_term=50 11 | 12 | python ./tpt_classification.py ${data_root} --test_sets ${testsets} \ 13 | -a ${arch} -b ${bs} --gpu 0 \ 14 | --tpt --ctx_init ${ctx_init} --run_type ${run_type} --lambda_term ${lambda_term} \ 15 | -------------------------------------------------------------------------------- /scripts/test_tpt_ctpt_ds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | data_root='put your dataset path here' 4 | testsets=$1 5 | #arch=RN50 6 | arch=ViT-B/16 7 | bs=64 8 | ctx_init=a_photo_of_a 9 | run_type=tpt_ctpt 10 | lambda_term=20 11 | 12 | python ./tpt_classification.py ${data_root} --test_sets ${testsets} \ 13 | -a ${arch} -b ${bs} --gpu 0 \ 14 | --tpt --ctx_init ${ctx_init} --run_type ${run_type} --I_augmix --lambda_term ${lambda_term} \ 15 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Hee Suk Yoon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # C-TPT: Calibrated Test-Time Prompt Tuning for Vision-Language Models via Text Feature Dispersion (ICLR 2024) 2 | 3 | This repository provides the official implementation of our ICLR 2024 paper: 4 | > C-TPT: Calibrated Test-Time Prompt Tuning for Vision-Language Models via Text Feature Dispersion 5 | > Authors: Hee Suk Yoon*, Eunseop Yoon*, Joshua Tian Jin Tee, Mark Hasegawa-Johnson, Yingzhen Li, Chang D. Yoo 6 | 7 | The implementation is built upon [TPT](https://github.com/azshue/TPT). 8 | 9 | [[Paper Link](https://openreview.net/forum?id=jzzEHTBFOT)] 10 | 11 | ![](figures/CTPT_ICLR2024_poster.png) 12 | 13 | ## Installation 14 | ```bash 15 | # Clone this repo 16 | git clone https://github.com/hee-suk-yoon/C-TPT 17 | cd C-TPT 18 | 19 | # Create a conda enviroment 20 | 1. conda env create -f environment.yml 21 | 2. 
conda activate ctpt 22 | ``` 23 | 24 | ## Datasets 25 | Our evaluation focuses on 26 | 27 | 1) fine-grained classification: ImageNet, Flower102, OxfordPets, SUN397, DTD, Food101, StanfordCars, Aircraft, UCF101, EuroSAT, Caltech101 28 | 29 | 2) natural distribution shift: ImageNet-V2, ImageNet-A, ImageNet-R, ImageNet-Sketch 30 | 31 | Prepare the datasets based on the following link https://github.com/azshue/TPT. 32 | 33 | ## Running Experiments 34 | 35 | In each of the .sh files, change the {data_root} accordingly. Additionally, you can change the CLIP architecture by modifying the {arch} parameter to either ‘RN50’ or ‘ViT-B/16’. 36 | 37 | 1. Baseline (standard CLIP) 38 | ```bash 39 | bash scripts/test_baseline.sh {dataset} 40 | ``` 41 | 42 | 2. Test-Time Prompt Tuning (TPT) 43 | ```bash 44 | #for Fine-grained classification 45 | bash scripts/test_tpt_fg.sh {dataset} 46 | 47 | #for natural distribution shift 48 | bash scripts/test_tpt_ds.sh {dataset} 49 | 50 | #for temperature scaling experiments, change the run_type to tpt_ts in the .sh file. 51 | ``` 52 | 53 | 3. Calibrated Test-Time Prompt Tuning (C-TPT) 54 | ```bash 55 | #for Fine-grained classification 56 | bash scripts/test_tpt_ctpt_fg.sh {dataset} 57 | 58 | #for natural distribution shift 59 | bash scripts/test_tpt_ctpt_ds.sh {dataset} 60 | ``` 61 | The command line argument {dataset} can be specified as follows: ‘I’, ‘DTD’, ‘Flower102’, ‘Food101’, ‘Cars’, ‘SUN397’, ‘Aircraft’, ‘Pets’, ‘Caltech101’, ‘UCF101’, or ‘eurosat’ for fine-grained classification datasets, and ‘V2’, ‘A’, ‘R’, or ‘K’ for datasets with natural distribution shifts. 62 | 63 | ## Acknowledgement 64 | This work was supported by Institute of Information & communications Technology Planning & Evaluation (IITP) grant funded by the Korea government(MSIT) (No.2022-0-00184, Development and Study of AI Technologies to Inexpensively Conform to Evolving Policy on Ethics), and Institute of Information & communications Technology Planning & Evaluation (IITP) grant funded by the Korea government(MSIT) (No. 2022-0-00951, Development of Uncertainty-Aware Agents Learning by Asking Questions). 65 | 66 | Also, we thank the authors of the [CoOp/CoCoOp](https://github.com/KaiyangZhou/CoOp) and [TPT](https://github.com/azshue/TPT) for their open-source contributions and their assistance with the data preparation. 67 | 68 | ## Citation 69 | If you find our work useful in your research, please cite: 70 | ``` 71 | @inproceedings{ 72 | yoon2024ctpt, 73 | title={C-{TPT}: Calibrated Test-Time Prompt Tuning for Vision-Language Models via Text Feature Dispersion}, 74 | author={Hee Suk Yoon and Eunseop Yoon and Joshua Tian Jin Tee and Mark A. Hasegawa-Johnson and Yingzhen Li and Chang D. 
Yoo}, 75 | booktitle={The Twelfth International Conference on Learning Representations}, 76 | year={2024}, 77 | url={https://openreview.net/forum?id=jzzEHTBFOT} 78 | } 79 | ``` 80 | 81 | ## Contact 82 | If you have any questions, please feel free to email hskyoon@kaist.ac.kr 83 | -------------------------------------------------------------------------------- /data/datautils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Tuple 3 | from PIL import Image 4 | import numpy as np 5 | 6 | import torch 7 | import torchvision.transforms as transforms 8 | import torchvision.datasets as datasets 9 | 10 | from data.hoi_dataset import BongardDataset 11 | try: 12 | from torchvision.transforms import InterpolationMode 13 | BICUBIC = InterpolationMode.BICUBIC 14 | except ImportError: 15 | BICUBIC = Image.BICUBIC 16 | 17 | from data.fewshot_datasets import * 18 | import data.augmix_ops as augmentations 19 | 20 | import ipdb 21 | 22 | ID_to_DIRNAME={ 23 | 'I': 'ImageNet', 24 | 'A': 'imagenet-a', 25 | 'K': 'ImageNet-Sketch', 26 | 'R': 'imagenet-r', 27 | 'V': 'imagenetv2-matched-frequency-format-val', 28 | 'flower102': 'Flower102', 29 | 'dtd': 'DTD', 30 | 'pets': 'OxfordPets', 31 | 'cars': 'StanfordCars', 32 | 'ucf101': 'UCF101', 33 | 'caltech101': 'Caltech101', 34 | 'food101': 'Food101', 35 | 'sun397': 'SUN397', 36 | 'aircraft': 'fgvc_aircraft', 37 | 'eurosat': 'eurosat' 38 | } 39 | 40 | distortions = ['gaussian_noise', 'shot_noise', 'impulse_noise', 41 | 'defocus_blur', 'glass_blur', 42 | 'zoom_blur', 'frost', 43 | 'brightness', 'contrast', 'elastic_transform', 44 | 'pixelate','fog','speckle_noise','saturate', 'spatter', 'gaussian_blur'] 45 | 46 | 47 | def build_dataset(set_id, transform, data_root, mode='test', n_shot=None, split="all", bongard_anno=False): 48 | if set_id == 'I': 49 | # ImageNet validation set 50 | testdir = os.path.join(os.path.join(data_root, ID_to_DIRNAME[set_id]), 'val') 51 | testset = datasets.ImageFolder(testdir, transform=transform) 52 | elif set_id in ['A', 'K', 'R', 'V']: 53 | testdir = os.path.join(data_root, ID_to_DIRNAME[set_id]) 54 | testset = datasets.ImageFolder(testdir, transform=transform) 55 | elif set_id in fewshot_datasets: 56 | if mode == 'train' and n_shot: 57 | testset = build_fewshot_dataset(set_id, os.path.join(data_root, ID_to_DIRNAME[set_id.lower()]), transform, mode=mode, n_shot=n_shot) 58 | else: 59 | testset = build_fewshot_dataset(set_id, os.path.join(data_root, ID_to_DIRNAME[set_id.lower()]), transform, mode=mode) 60 | 61 | elif set_id == 'bongard': 62 | assert isinstance(transform, Tuple) 63 | base_transform, query_transform = transform 64 | testset = BongardDataset(data_root, split, mode, base_transform, query_transform, bongard_anno) 65 | else: 66 | raise NotImplementedError 67 | 68 | return testset 69 | 70 | 71 | # AugMix Transforms 72 | def get_preaugment(): 73 | return transforms.Compose([ 74 | transforms.RandomResizedCrop(224), 75 | transforms.RandomHorizontalFlip(), 76 | ]) 77 | 78 | def augmix(image, preprocess, aug_list, severity=1): 79 | preaugment = get_preaugment() 80 | x_orig = preaugment(image) 81 | x_processed = preprocess(x_orig) 82 | if len(aug_list) == 0: 83 | return x_processed 84 | w = np.float32(np.random.dirichlet([1.0, 1.0, 1.0])) 85 | m = np.float32(np.random.beta(1.0, 1.0)) 86 | 87 | mix = torch.zeros_like(x_processed) 88 | for i in range(3): 89 | x_aug = x_orig.copy() 90 | for _ in range(np.random.randint(1, 4)): 91 | x_aug = 
np.random.choice(aug_list)(x_aug, severity) 92 | mix += w[i] * preprocess(x_aug) 93 | mix = m * x_processed + (1 - m) * mix 94 | return mix 95 | 96 | 97 | class AugMixAugmenter(object): 98 | def __init__(self, base_transform, preprocess, n_views=2, augmix=False, 99 | severity=1): 100 | self.base_transform = base_transform 101 | self.preprocess = preprocess 102 | self.n_views = n_views 103 | if augmix: 104 | self.aug_list = augmentations.augmentations 105 | else: 106 | self.aug_list = [] 107 | self.severity = severity 108 | 109 | def __call__(self, x): 110 | image = self.preprocess(self.base_transform(x)) 111 | views = [augmix(x, self.preprocess, self.aug_list, self.severity) for _ in range(self.n_views)] 112 | return [image] + views 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /data/hoi_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import csv 4 | import random 5 | import numpy as np 6 | import scipy.io as sio 7 | 8 | from PIL import Image 9 | from PIL import ImageFile 10 | ImageFile.LOAD_TRUNCATED_IMAGES = True 11 | 12 | import torch 13 | from torch.utils.data import Dataset 14 | 15 | # debug dataset 16 | import torchvision.transforms as transforms 17 | 18 | try: 19 | from torchvision.transforms import InterpolationMode 20 | BICUBIC = InterpolationMode.BICUBIC 21 | except ImportError: 22 | BICUBIC = Image.BICUBIC 23 | 24 | 25 | 26 | class BongardDataset(Dataset): 27 | def __init__(self, data_root, data_split='unseen_obj_unseen_act', mode='test', 28 | base_transform=None, query_transform=None, with_annotation=False): 29 | self.base_transform = base_transform 30 | if query_transform is None: 31 | self.query_transform = base_transform 32 | else: 33 | self.query_transform = query_transform 34 | self.data_root = data_root 35 | self.mode = mode 36 | self.with_annotation = with_annotation 37 | 38 | assert mode in ['val', 'test'] 39 | data_file = os.path.join("data/bongard_splits", "bongard_hoi_{}_{}.json".format(self.mode, data_split)) 40 | self.task_list = [] 41 | with open(data_file, "r") as fp: 42 | task_items = json.load(fp) 43 | for task in task_items: 44 | task_data = {} 45 | pos_samples = [] 46 | neg_samples = [] 47 | for sample in task[0]: 48 | neg_samples.append(sample['im_path']) 49 | for sample in task[1]: 50 | pos_samples.append(sample['im_path']) 51 | 52 | # random split samples into support and query images (6 vs. 
1 for both pos and neg samples) 53 | task_data['pos_samples'] = pos_samples 54 | task_data['neg_samples'] = neg_samples 55 | task_data['annotation'] = task[-1].replace("++", " ") 56 | self.task_list.append(task_data) 57 | 58 | def __len__(self): 59 | return len(self.task_list) 60 | 61 | def load_image(self, path, transform_type="base_transform"): 62 | im_path = os.path.join(self.data_root, path.replace("./", "")) 63 | if not os.path.isfile(im_path): 64 | print("file not exist: {}".format(im_path)) 65 | if '/pic/image/val' in im_path: 66 | im_path = im_path.replace('val', 'train') 67 | elif '/pic/image/train' in im_path: 68 | im_path = im_path.replace('train', 'val') 69 | try: 70 | image = Image.open(im_path).convert('RGB') 71 | except: 72 | print("File error: ", im_path) 73 | image = Image.open(im_path).convert('RGB') 74 | trans = getattr(self, transform_type) 75 | if trans is not None: 76 | image = trans(image) 77 | return image 78 | 79 | def __getitem__(self, idx): 80 | task = self.task_list[idx] 81 | pos_samples = task['pos_samples'] 82 | neg_samples = task['neg_samples'] 83 | 84 | random.seed(0) 85 | random.shuffle(pos_samples) 86 | random.shuffle(neg_samples) 87 | 88 | f_pos_support = pos_samples[:-1] 89 | f_neg_support = neg_samples[:-1] 90 | pos_images = [self.load_image(f, "base_transform") for f in f_pos_support] 91 | neg_images = [self.load_image(f, "base_transform") for f in f_neg_support] 92 | pos_support = torch.stack(pos_images, dim=0) 93 | neg_support = torch.stack(neg_images, dim=0) 94 | 95 | try: 96 | pos_query = torch.stack(self.load_image(pos_samples[-1], "query_transform"), dim=0) 97 | neg_query = torch.stack(self.load_image(neg_samples[-1], "query_transform"), dim=0) 98 | except: 99 | pos_query = torch.stack([self.load_image(pos_samples[-1], "query_transform")], dim=0) 100 | neg_query = torch.stack([self.load_image(neg_samples[-1], "query_transform")], dim=0) 101 | 102 | support_images = torch.cat((pos_support, neg_support), dim=0) 103 | support_labels = torch.Tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]).long() 104 | query_images = torch.stack([neg_query, pos_query], dim=0) 105 | query_labels = torch.Tensor([1, 0]).long() 106 | 107 | if self.with_annotation: 108 | annotation = task['annotation'] 109 | return support_images, query_images, support_labels, query_labels, annotation 110 | else: 111 | return support_images, query_images, support_labels, query_labels 112 | 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /data/fewshot_datasets.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | 4 | import json 5 | import random 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import Dataset 9 | import PIL 10 | from PIL import Image 11 | 12 | 13 | class BaseJsonDataset(Dataset): 14 | def __init__(self, image_path, json_path, mode='train', n_shot=None, transform=None): 15 | self.transform = transform 16 | self.image_path = image_path 17 | self.split_json = json_path 18 | self.mode = mode 19 | self.image_list = [] 20 | self.label_list = [] 21 | with open(self.split_json) as fp: 22 | splits = json.load(fp) 23 | samples = splits[self.mode] 24 | for s in samples: 25 | self.image_list.append(s[0]) 26 | self.label_list.append(s[1]) 27 | 28 | if n_shot is not None: 29 | few_shot_samples = [] 30 | c_range = max(self.label_list) + 1 31 | for c in range(c_range): 32 | c_idx = [idx for idx, lable in enumerate(self.label_list) if lable == c] 33 | 
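# note (descriptive comment): the seed is reset just before each class is sampled,
# so the n-shot subset drawn for every class is identical across runs, making the
# few-shot split deterministic.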
random.seed(0) 34 | few_shot_samples.extend(random.sample(c_idx, n_shot)) 35 | self.image_list = [self.image_list[i] for i in few_shot_samples] 36 | self.label_list = [self.label_list[i] for i in few_shot_samples] 37 | 38 | def __len__(self): 39 | return len(self.image_list) 40 | 41 | def __getitem__(self, idx): 42 | image_path = os.path.join(self.image_path, self.image_list[idx]) 43 | image = Image.open(image_path).convert('RGB') 44 | label = self.label_list[idx] 45 | if self.transform: 46 | image = self.transform(image) 47 | 48 | return image, torch.tensor(label).long() 49 | 50 | fewshot_datasets = ['DTD', 'Flower102', 'Food101', 'Cars', 'SUN397', 51 | 'Aircraft', 'Pets', 'Caltech101', 'UCF101', 'eurosat'] 52 | 53 | path_dict = { 54 | # dataset_name: ["image_dir", "json_split_file"] 55 | "flower102": ["jpg", "data/data_splits/split_zhou_OxfordFlowers.json"], 56 | "food101": ["images", "data/data_splits/split_zhou_Food101.json"], 57 | "dtd": ["images", "data/data_splits/split_zhou_DescribableTextures.json"], 58 | "pets": ["", "data/data_splits/split_zhou_OxfordPets.json"], 59 | "sun397": ["", "data/data_splits/split_zhou_SUN397.json"], 60 | "caltech101": ["", "data/data_splits/split_zhou_Caltech101.json"], 61 | "ucf101": ["", "data/data_splits/split_zhou_UCF101.json"], 62 | "cars": ["", "data/data_splits/split_zhou_StanfordCars.json"], 63 | "eurosat": ["", "data/data_splits/split_zhou_EuroSAT.json"] 64 | } 65 | 66 | def build_fewshot_dataset(set_id, root, transform, mode='train', n_shot=None): 67 | if set_id.lower() == 'aircraft': 68 | return Aircraft(root, mode, n_shot, transform) 69 | path_suffix, json_path = path_dict[set_id.lower()] 70 | image_path = os.path.join(root, path_suffix) 71 | return BaseJsonDataset(image_path, json_path, mode, n_shot, transform) 72 | 73 | 74 | class Aircraft(Dataset): 75 | """ FGVC Aircraft dataset """ 76 | def __init__(self, root, mode='train', n_shot=None, transform=None): 77 | self.transform = transform 78 | self.path = root 79 | self.mode = mode 80 | 81 | self.cname = [] 82 | with open(os.path.join(self.path, "variants.txt"), 'r') as fp: 83 | self.cname = [l.replace("\n", "") for l in fp.readlines()] 84 | 85 | self.image_list = [] 86 | self.label_list = [] 87 | with open(os.path.join(self.path, 'images_variant_{:s}.txt'.format(self.mode)), 'r') as fp: 88 | lines = [s.replace("\n", "") for s in fp.readlines()] 89 | for l in lines: 90 | ls = l.split(" ") 91 | img = ls[0] 92 | label = " ".join(ls[1:]) 93 | self.image_list.append("{}.jpg".format(img)) 94 | self.label_list.append(self.cname.index(label)) 95 | 96 | if n_shot is not None: 97 | few_shot_samples = [] 98 | c_range = max(self.label_list) + 1 99 | for c in range(c_range): 100 | c_idx = [idx for idx, lable in enumerate(self.label_list) if lable == c] 101 | random.seed(0) 102 | few_shot_samples.extend(random.sample(c_idx, n_shot)) 103 | self.image_list = [self.image_list[i] for i in few_shot_samples] 104 | self.label_list = [self.label_list[i] for i in few_shot_samples] 105 | 106 | def __len__(self): 107 | return len(self.image_list) 108 | 109 | def __getitem__(self, idx): 110 | image_path = os.path.join(self.path, 'images', self.image_list[idx]) 111 | image = Image.open(image_path).convert('RGB') 112 | label = self.label_list[idx] 113 | if self.transform: 114 | image = self.transform(image) 115 | 116 | return image, torch.tensor(label).long() 117 | 118 | -------------------------------------------------------------------------------- /environment.yml: 
-------------------------------------------------------------------------------- 1 | name: ctpt 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1 8 | - _openmp_mutex=5.1 9 | - blas=1.0 10 | - brotlipy=0.7.0 11 | - bzip2=1.0.8 12 | - ca-certificates=2023.01.10 13 | - certifi=2022.12.7 14 | - cffi=1.15.1 15 | - cryptography=39.0.1 16 | - cuda-cudart=11.8.89 17 | - cuda-cupti=11.8.87 18 | - cuda-libraries=11.8.0 19 | - cuda-nvrtc=11.8.89 20 | - cuda-nvtx=11.8.86 21 | - cuda-runtime=11.8.0 22 | - ffmpeg=4.3 23 | - filelock=3.9.0 24 | - flit-core=3.8.0 25 | - freetype=2.12.1 26 | - giflib=5.2.1 27 | - gmp=6.2.1 28 | - gmpy2=2.1.2 29 | - gnutls=3.6.15 30 | - idna=3.4 31 | - intel-openmp=2021.4.0 32 | - jinja2=3.1.2 33 | - jpeg=9e 34 | - lame=3.100 35 | - lcms2=2.12 36 | - ld_impl_linux-64=2.38 37 | - lerc=3.0 38 | - libcublas=11.11.3.6 39 | - libcufft=10.9.0.58 40 | - libcufile=1.6.0.25 41 | - libcurand=10.3.2.56 42 | - libcusolver=11.4.1.48 43 | - libcusparse=11.7.5.86 44 | - libdeflate=1.17 45 | - libffi=3.4.2 46 | - libgcc-ng=11.2.0 47 | - libgomp=11.2.0 48 | - libiconv=1.16 49 | - libidn2=2.3.2 50 | - libnpp=11.8.0.86 51 | - libnvjpeg=11.9.0.86 52 | - libpng=1.6.39 53 | - libstdcxx-ng=11.2.0 54 | - libtasn1=4.16.0 55 | - libtiff=4.5.0 56 | - libunistring=0.9.10 57 | - libuuid=1.41.5 58 | - libwebp=1.2.4 59 | - libwebp-base=1.2.4 60 | - lz4-c=1.9.4 61 | - markupsafe=2.1.1 62 | - mkl=2021.4.0 63 | - mkl-service=2.4.0 64 | - mkl_fft=1.3.1 65 | - mkl_random=1.2.2 66 | - mpc=1.1.0 67 | - mpfr=4.0.2 68 | - ncurses=6.4 69 | - nettle=3.7.3 70 | - networkx=2.8.4 71 | - numpy=1.23.5 72 | - numpy-base=1.23.5 73 | - openh264=2.1.1 74 | - openssl=1.1.1t 75 | - pillow=9.4.0 76 | - pip=23.0.1 77 | - pyopenssl=23.0.0 78 | - pysocks=1.7.1 79 | - python=3.10.10 80 | - pytorch=2.0.0 81 | - pytorch-cuda=11.8 82 | - pytorch-mutex=1.0 83 | - readline=8.2 84 | - requests=2.28.1 85 | - setuptools=65.6.3 86 | - sqlite=3.41.1 87 | - sympy=1.11.1 88 | - tk=8.6.12 89 | - torchaudio=2.0.0 90 | - torchtriton=2.0.0 91 | - torchvision=0.15.0 92 | - typing_extensions=4.4.0 93 | - tzdata=2022g 94 | - urllib3=1.26.15 95 | - wheel=0.38.4 96 | - xz=5.2.10 97 | - zlib=1.2.13 98 | - zstd=1.5.4 99 | - pip: 100 | - absl-py==1.4.0 101 | - accelerate==0.24.1 102 | - appdirs==1.4.4 103 | - asttokens==2.2.1 104 | - backcall==0.2.0 105 | - beautifulsoup4==4.12.0 106 | - blessed==1.20.0 107 | - cachetools==5.3.0 108 | - charset-normalizer==2.0.4 109 | - click==8.1.3 110 | - contourpy==1.0.7 111 | - cycler==0.11.0 112 | - decorator==5.1.1 113 | - diffusers==0.23.0 114 | - docker-pycreds==0.4.0 115 | - entrypoints==0.3 116 | - executing==1.2.0 117 | - flake8==3.7.9 118 | - fonttools==4.39.3 119 | - fsspec==2023.10.0 120 | - ftfy==6.1.1 121 | - future==0.18.3 122 | - gdown==4.7.1 123 | - gitdb==4.0.10 124 | - gitpython==3.1.31 125 | - google-auth==2.17.1 126 | - google-auth-oauthlib==1.0.0 127 | - gpustat==1.0.0 128 | - grpcio==1.53.0 129 | - huggingface-hub==0.17.3 130 | - importlib-metadata==6.8.0 131 | - ipdb==0.13.13 132 | - ipython==8.12.0 133 | - isort==4.3.21 134 | - jedi==0.18.2 135 | - joblib==1.2.0 136 | - kiwisolver==1.4.4 137 | - littleutils==0.2.2 138 | - markdown==3.4.3 139 | - matplotlib==3.7.1 140 | - matplotlib-inline==0.1.6 141 | - mccabe==0.6.1 142 | - mpmath==1.2.1 143 | - nvidia-ml-py==11.495.46 144 | - oauthlib==3.2.2 145 | - ogb==1.3.5 146 | - outdated==0.2.2 147 | - packaging==23.1 148 | - pandas==1.5.3 149 | - parso==0.8.3 150 | - pathtools==0.1.2 151 | - 
pexpect==4.8.0 152 | - pickleshare==0.7.5 153 | - prompt-toolkit==3.0.38 154 | - protobuf==4.22.1 155 | - psutil==5.9.4 156 | - ptyprocess==0.7.0 157 | - pure-eval==0.2.2 158 | - pyasn1==0.4.8 159 | - pyasn1-modules==0.2.8 160 | - pycodestyle==2.5.0 161 | - pycparser==2.21 162 | - pyflakes==2.1.1 163 | - pygments==2.14.0 164 | - pyparsing==3.0.9 165 | - python-dateutil==2.8.2 166 | - pytz==2023.3 167 | - pyyaml==6.0 168 | - regex==2023.3.23 169 | - requests-oauthlib==1.3.1 170 | - rsa==4.9 171 | - safetensors==0.4.0 172 | - scikit-learn==1.2.2 173 | - scipy==1.10.1 174 | - sentry-sdk==1.19.0 175 | - setproctitle==1.3.2 176 | - six==1.16.0 177 | - smmap==5.0.0 178 | - soupsieve==2.4 179 | - stack-data==0.6.2 180 | - tabulate==0.9.0 181 | - tb-nightly==2.13.0a20230401 182 | - tensorboard-data-server==0.7.0 183 | - tensorboard-plugin-wit==1.8.1 184 | - threadpoolctl==3.1.0 185 | - tokenizers==0.14.1 186 | - tomli==2.0.1 187 | - tqdm==4.65.0 188 | - traitlets==5.9.0 189 | - transformers==4.35.0 190 | - wandb==0.14.0 191 | - wcwidth==0.2.6 192 | - werkzeug==2.2.3 193 | - wilds==1.2.2 194 | - yacs==0.1.8 195 | - yapf==0.29.0 196 | - zipp==3.17.0 197 | 198 | -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /data/augmix_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Base augmentations operators.""" 16 | 17 | import numpy as np 18 | from PIL import Image, ImageOps, ImageEnhance 19 | 20 | # ImageNet code should change this value 21 | IMAGE_SIZE = 224 22 | 23 | 24 | def int_parameter(level, maxval): 25 | """Helper function to scale `val` between 0 and maxval . 26 | 27 | Args: 28 | level: Level of the operation that will be between [0, `PARAMETER_MAX`]. 29 | maxval: Maximum value that the operation can have. This will be scaled to 30 | level/PARAMETER_MAX. 31 | 32 | Returns: 33 | An int that results from scaling `maxval` according to `level`. 34 | """ 35 | return int(level * maxval / 10) 36 | 37 | 38 | def float_parameter(level, maxval): 39 | """Helper function to scale `val` between 0 and maxval. 40 | 41 | Args: 42 | level: Level of the operation that will be between [0, `PARAMETER_MAX`]. 43 | maxval: Maximum value that the operation can have. This will be scaled to 44 | level/PARAMETER_MAX. 45 | 46 | Returns: 47 | A float that results from scaling `maxval` according to `level`. 48 | """ 49 | return float(level) * maxval / 10. 50 | 51 | 52 | def sample_level(n): 53 | return np.random.uniform(low=0.1, high=n) 54 | 55 | 56 | def autocontrast(pil_img, _): 57 | return ImageOps.autocontrast(pil_img) 58 | 59 | 60 | def equalize(pil_img, _): 61 | return ImageOps.equalize(pil_img) 62 | 63 | 64 | def posterize(pil_img, level): 65 | level = int_parameter(sample_level(level), 4) 66 | return ImageOps.posterize(pil_img, 4 - level) 67 | 68 | 69 | def rotate(pil_img, level): 70 | degrees = int_parameter(sample_level(level), 30) 71 | if np.random.uniform() > 0.5: 72 | degrees = -degrees 73 | return pil_img.rotate(degrees, resample=Image.BILINEAR) 74 | 75 | 76 | def solarize(pil_img, level): 77 | level = int_parameter(sample_level(level), 256) 78 | return ImageOps.solarize(pil_img, 256 - level) 79 | 80 | 81 | def shear_x(pil_img, level): 82 | level = float_parameter(sample_level(level), 0.3) 83 | if np.random.uniform() > 0.5: 84 | level = -level 85 | return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), 86 | Image.AFFINE, (1, level, 0, 0, 1, 0), 87 | resample=Image.BILINEAR) 88 | 89 | 90 | def shear_y(pil_img, level): 91 | level = float_parameter(sample_level(level), 0.3) 92 | if np.random.uniform() > 0.5: 93 | level = -level 94 | return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), 95 | Image.AFFINE, (1, 0, 0, level, 1, 0), 96 | resample=Image.BILINEAR) 97 | 98 | 99 | def translate_x(pil_img, level): 100 | level = int_parameter(sample_level(level), IMAGE_SIZE / 3) 101 | if np.random.random() > 0.5: 102 | level = -level 103 | return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), 104 | Image.AFFINE, (1, 0, level, 0, 1, 0), 105 | resample=Image.BILINEAR) 106 | 107 | 108 | def translate_y(pil_img, level): 109 | level = int_parameter(sample_level(level), IMAGE_SIZE / 3) 110 | if np.random.random() > 0.5: 111 | level = -level 112 | return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), 113 | Image.AFFINE, (1, 0, 0, 0, 1, level), 114 | 
resample=Image.BILINEAR) 115 | 116 | 117 | # operation that overlaps with ImageNet-C's test set 118 | def color(pil_img, level): 119 | level = float_parameter(sample_level(level), 1.8) + 0.1 120 | return ImageEnhance.Color(pil_img).enhance(level) 121 | 122 | 123 | # operation that overlaps with ImageNet-C's test set 124 | def contrast(pil_img, level): 125 | level = float_parameter(sample_level(level), 1.8) + 0.1 126 | return ImageEnhance.Contrast(pil_img).enhance(level) 127 | 128 | 129 | # operation that overlaps with ImageNet-C's test set 130 | def brightness(pil_img, level): 131 | level = float_parameter(sample_level(level), 1.8) + 0.1 132 | return ImageEnhance.Brightness(pil_img).enhance(level) 133 | 134 | 135 | # operation that overlaps with ImageNet-C's test set 136 | def sharpness(pil_img, level): 137 | level = float_parameter(sample_level(level), 1.8) + 0.1 138 | return ImageEnhance.Sharpness(pil_img).enhance(level) 139 | 140 | 141 | augmentations = [ 142 | autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y, 143 | translate_x, translate_y 144 | ] 145 | 146 | augmentations_all = [ 147 | autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y, 148 | translate_x, translate_y, color, contrast, brightness, sharpness 149 | ] -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import random 4 | 5 | import numpy as np 6 | 7 | import shutil 8 | from enum import Enum 9 | 10 | import torch 11 | import torchvision.transforms as transforms 12 | 13 | 14 | def set_random_seed(seed): 15 | random.seed(seed) 16 | np.random.seed(seed) 17 | torch.manual_seed(seed) 18 | torch.cuda.manual_seed_all(seed) 19 | 20 | class Summary(Enum): 21 | NONE = 0 22 | AVERAGE = 1 23 | SUM = 2 24 | COUNT = 3 25 | 26 | class AverageMeter(object): 27 | """Computes and stores the average and current value""" 28 | def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE): 29 | self.name = name 30 | self.fmt = fmt 31 | self.summary_type = summary_type 32 | self.reset() 33 | 34 | def reset(self): 35 | self.val = 0 36 | self.avg = 0 37 | self.sum = 0 38 | self.count = 0 39 | 40 | def update(self, val, n=1): 41 | self.val = val 42 | self.sum += val * n 43 | self.count += n 44 | self.avg = self.sum / self.count 45 | 46 | def __str__(self): 47 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 48 | return fmtstr.format(**self.__dict__) 49 | 50 | def summary(self): 51 | fmtstr = '' 52 | if self.summary_type is Summary.NONE: 53 | fmtstr = '' 54 | elif self.summary_type is Summary.AVERAGE: 55 | fmtstr = '{name} {avg:.3f}' 56 | elif self.summary_type is Summary.SUM: 57 | fmtstr = '{name} {sum:.3f}' 58 | elif self.summary_type is Summary.COUNT: 59 | fmtstr = '{name} {count:.3f}' 60 | else: 61 | raise ValueError('invalid summary type %r' % self.summary_type) 62 | 63 | return fmtstr.format(**self.__dict__) 64 | 65 | 66 | class ProgressMeter(object): 67 | def __init__(self, num_batches, meters, prefix=""): 68 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 69 | self.meters = meters 70 | self.prefix = prefix 71 | 72 | def display(self, batch): 73 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 74 | entries += [str(meter) for meter in self.meters] 75 | print('\t'.join(entries)) 76 | 77 | def display_summary(self): 78 | entries = [" *"] 79 | entries += [meter.summary() for meter in self.meters] 80 | print(' 
'.join(entries)) 81 | 82 | def _get_batch_fmtstr(self, num_batches): 83 | num_digits = len(str(num_batches // 1)) 84 | fmt = '{:' + str(num_digits) + 'd}' 85 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 86 | 87 | 88 | def accuracy(output, target, topk=(1,)): 89 | """Computes the accuracy over the k top predictions for the specified values of k""" 90 | with torch.no_grad(): 91 | maxk = max(topk) 92 | batch_size = target.size(0) 93 | 94 | _, pred = output.topk(maxk, 1, True, True) 95 | pred = pred.t() 96 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 97 | 98 | res = [] 99 | for k in topk: 100 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 101 | res.append(correct_k.mul_(100.0 / batch_size)) 102 | return res 103 | 104 | 105 | def load_model_weight(load_path, model, device, args): 106 | if os.path.isfile(load_path): 107 | print("=> loading checkpoint '{}'".format(load_path)) 108 | checkpoint = torch.load(load_path, map_location=device) 109 | state_dict = checkpoint['state_dict'] 110 | # Ignore fixed token vectors 111 | if "token_prefix" in state_dict: 112 | del state_dict["token_prefix"] 113 | 114 | if "token_suffix" in state_dict: 115 | del state_dict["token_suffix"] 116 | 117 | args.start_epoch = checkpoint['epoch'] 118 | try: 119 | best_acc1 = checkpoint['best_acc1'] 120 | except: 121 | best_acc1 = torch.tensor(0) 122 | if device is not 'cpu': 123 | # best_acc1 may be from a checkpoint from a different GPU 124 | best_acc1 = best_acc1.to(device) 125 | try: 126 | model.load_state_dict(state_dict) 127 | except: 128 | # TODO: implement this method for the generator class 129 | model.prompt_generator.load_state_dict(state_dict, strict=False) 130 | print("=> loaded checkpoint '{}' (epoch {})" 131 | .format(load_path, checkpoint['epoch'])) 132 | del checkpoint 133 | torch.cuda.empty_cache() 134 | else: 135 | print("=> no checkpoint found at '{}'".format(load_path)) 136 | 137 | 138 | def validate(val_loader, model, criterion, args, output_mask=None): 139 | batch_time = AverageMeter('Time', ':6.3f', Summary.NONE) 140 | losses = AverageMeter('Loss', ':.4e', Summary.NONE) 141 | top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE) 142 | top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE) 143 | progress = ProgressMeter( 144 | len(val_loader), 145 | [batch_time, losses, top1, top5], 146 | prefix='Test: ') 147 | 148 | # switch to evaluate mode 149 | model.eval() 150 | 151 | with torch.no_grad(): 152 | end = time.time() 153 | for i, (images, target) in enumerate(val_loader): 154 | if args.gpu is not None: 155 | images = images.cuda(args.gpu, non_blocking=True) 156 | if torch.cuda.is_available(): 157 | target = target.cuda(args.gpu, non_blocking=True) 158 | 159 | # compute output 160 | with torch.cuda.amp.autocast(): 161 | output = model(images) 162 | if output_mask: 163 | output = output[:, output_mask] 164 | loss = criterion(output, target) 165 | 166 | # measure accuracy and record loss 167 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 168 | losses.update(loss.item(), images.size(0)) 169 | top1.update(acc1[0], images.size(0)) 170 | top5.update(acc5[0], images.size(0)) 171 | 172 | # measure elapsed time 173 | batch_time.update(time.time() - end) 174 | end = time.time() 175 | 176 | if i % args.print_freq == 0: 177 | progress.display(i) 178 | progress.display_summary() 179 | 180 | return top1.avg 181 | -------------------------------------------------------------------------------- /clip/clip.py: 
-------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from .model import build_model 14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | 23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 25 | 26 | 27 | __all__ = ["available_models", "load", "tokenize"] 28 | _tokenizer = _Tokenizer() 29 | 30 | _MODELS = { 31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 39 | } 40 | 41 | 42 | def _download(url: str, root: str): 43 | os.makedirs(root, exist_ok=True) 44 | filename = os.path.basename(url) 45 | 46 | expected_sha256 = url.split("/")[-2] 47 | download_target = os.path.join(root, filename) 48 | 49 | if os.path.exists(download_target) and not os.path.isfile(download_target): 50 | raise RuntimeError(f"{download_target} exists and is not a regular file") 51 | 52 | if os.path.isfile(download_target): 53 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 54 | return download_target 55 | else: 56 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 57 | 58 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 59 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 60 | while True: 61 | buffer = source.read(8192) 62 | if not buffer: 63 | break 64 | 65 | output.write(buffer) 66 | loop.update(len(buffer)) 67 | 68 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 69 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 70 | 71 | return download_target 72 | 73 | 74 | def _convert_image_to_rgb(image): 75 | return image.convert("RGB") 76 | 77 | 78 | def _transform(n_px): 79 | return Compose([ 80 | Resize(n_px, interpolation=BICUBIC), 81 | 
CenterCrop(n_px), 82 | _convert_image_to_rgb, 83 | ToTensor(), 84 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 85 | ]) 86 | 87 | 88 | def available_models() -> List[str]: 89 | """Returns the names of available CLIP models""" 90 | return list(_MODELS.keys()) 91 | 92 | 93 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 94 | """Load a CLIP model 95 | 96 | Parameters 97 | ---------- 98 | name : str 99 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 100 | 101 | device : Union[str, torch.device] 102 | The device to put the loaded model 103 | 104 | jit : bool 105 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 106 | 107 | download_root: str 108 | path to download the model files; by default, it uses "~/.cache/clip" 109 | 110 | Returns 111 | ------- 112 | model : torch.nn.Module 113 | The CLIP model 114 | 115 | preprocess : Callable[[PIL.Image], torch.Tensor] 116 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 117 | """ 118 | if name in _MODELS: 119 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 120 | elif os.path.isfile(name): 121 | model_path = name 122 | else: 123 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 124 | 125 | try: 126 | # loading JIT archive 127 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 128 | state_dict = None 129 | except RuntimeError: 130 | # loading saved state dict 131 | if jit: 132 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 133 | jit = False 134 | state_dict = torch.load(model_path, map_location="cpu") 135 | 136 | embed_dim = model.state_dict()["text_projection"].shape[1] 137 | if not jit: 138 | model = build_model(state_dict or model.state_dict()).to(device) 139 | if str(device) == "cpu": 140 | model.float() 141 | return model, embed_dim, _transform(model.visual.input_resolution) 142 | 143 | # patch the device names 144 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 145 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 146 | 147 | def patch_device(module): 148 | try: 149 | graphs = [module.graph] if hasattr(module, "graph") else [] 150 | except RuntimeError: 151 | graphs = [] 152 | 153 | if hasattr(module, "forward1"): 154 | graphs.append(module.forward1.graph) 155 | 156 | for graph in graphs: 157 | for node in graph.findAllNodes("prim::Constant"): 158 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 159 | node.copyAttributes(device_node) 160 | 161 | model.apply(patch_device) 162 | patch_device(model.encode_image) 163 | patch_device(model.encode_text) 164 | 165 | # patch dtype to float32 on CPU 166 | if str(device) == "cpu": 167 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 168 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 169 | float_node = float_input.node() 170 | 171 | def patch_float(module): 172 | try: 173 | graphs = [module.graph] if hasattr(module, "graph") else [] 174 | except RuntimeError: 175 | graphs = [] 176 | 177 | if hasattr(module, "forward1"): 178 | graphs.append(module.forward1.graph) 179 | 180 | for graph in graphs: 181 | for node in graph.findAllNodes("aten::to"): 182 | inputs = list(node.inputs()) 183 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 184 | if inputs[i].node()["value"] == 5: 185 | inputs[i].node().copyAttributes(float_node) 186 | 187 | model.apply(patch_float) 188 | patch_float(model.encode_image) 189 | patch_float(model.encode_text) 190 | 191 | model.float() 192 | 193 | return model, embed_dim, _transform(model.input_resolution.item()) 194 | 195 | 196 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 197 | """ 198 | Returns the tokenized representation of given input string(s) 199 | 200 | Parameters 201 | ---------- 202 | texts : Union[str, List[str]] 203 | An input string or a list of input strings to tokenize 204 | 205 | context_length : int 206 | The context length to use; all CLIP models use 77 as the context length 207 | 208 | truncate: bool 209 | Whether to truncate the text in case its encoding is longer than the context length 210 | 211 | Returns 212 | ------- 213 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 214 | """ 215 | if isinstance(texts, str): 216 | texts = [texts] 217 | 218 | sot_token = _tokenizer.encoder["<|startoftext|>"] 219 | eot_token = _tokenizer.encoder["<|endoftext|>"] 220 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 221 | 222 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 223 | 224 | for i, tokens in enumerate(all_tokens): 225 | if len(tokens) > context_length: 226 | if truncate: 227 | tokens = tokens[:context_length] 228 | tokens[-1] = eot_token 229 | else: 230 | raise 
RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 231 | result[i, :len(tokens)] = torch.tensor(tokens) 232 | 233 | return result 234 | -------------------------------------------------------------------------------- /clip/cocoop.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from clip import load, tokenize 9 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 10 | from .custom_clip import TextEncoder 11 | from data.imagnet_prompts import imagenet_classes 12 | from data.cls_to_names import * 13 | from data.fewshot_datasets import fewshot_datasets 14 | import ipdb 15 | 16 | 17 | _tokenizer = _Tokenizer() 18 | 19 | DOWNLOAD_ROOT='~/.cache/clip' 20 | 21 | class CoCoOpPromptLearner(nn.Module): 22 | def __init__(self, clip_model, classnames, n_ctx=4, ctx_init="a_photo_of_a", ctx_position='end'): 23 | super().__init__() 24 | n_cls = len(classnames) 25 | dtype = clip_model.dtype 26 | self.dtype = dtype 27 | self.device = clip_model.visual.conv1.weight.device 28 | ctx_dim = clip_model.ln_final.weight.shape[0] 29 | embed_dim = clip_model.text_projection.shape[1] 30 | self.ctx_dim = ctx_dim 31 | 32 | if ctx_init: 33 | # use given words to initialize context vectors 34 | print("Initializing the contect with given words: [{}]".format(ctx_init)) 35 | ctx_init = ctx_init.replace("_", " ") 36 | n_ctx = len(ctx_init.split(" ")) 37 | prompt = tokenize(ctx_init).to(self.device) 38 | with torch.no_grad(): 39 | embedding = clip_model.token_embedding(prompt).type(dtype) 40 | ctx_vectors = embedding[0, 1 : 1 + n_ctx, :] 41 | prompt_prefix = ctx_init 42 | 43 | else: 44 | print("Random initialization: initializing a generic context") 45 | ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype) 46 | nn.init.normal_(ctx_vectors, std=0.02) 47 | prompt_prefix = " ".join(["X"] * n_ctx) 48 | 49 | print(f'Initial context: "{prompt_prefix}"') 50 | print(f"Number of context words (tokens): {n_ctx}") 51 | self.prompt_prefix = prompt_prefix 52 | 53 | self.ctx = nn.Parameter(ctx_vectors) # to be optimized 54 | self.meta_net = nn.Sequential(OrderedDict([ 55 | ("linear1", nn.Linear(embed_dim, embed_dim // 16)), 56 | ("relu", nn.ReLU(inplace=True)), 57 | ("linear2", nn.Linear(embed_dim // 16, ctx_dim)) 58 | ])) 59 | 60 | classnames = [name.replace("_", " ") for name in classnames] 61 | name_lens = [len(_tokenizer.encode(name)) for name in classnames] 62 | prompts = [prompt_prefix + " " + name + "." 
for name in classnames] 63 | 64 | tokenized_prompts = torch.cat([tokenize(p) for p in prompts]).to(self.device) 65 | with torch.no_grad(): 66 | embedding = clip_model.token_embedding(tokenized_prompts).type(dtype) 67 | 68 | # These token vectors will be saved when in save_model(), 69 | # but they should be ignored in load_model() as we want to use 70 | # those computed using the current class names 71 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS 72 | self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :]) # CLS, EOS 73 | 74 | self.ctx_init = ctx_init 75 | self.tokenized_prompts = tokenized_prompts # torch.Tensor 76 | self.name_lens = name_lens 77 | self.class_token_position = ctx_position 78 | self.n_cls = n_cls 79 | self.n_ctx = n_ctx 80 | 81 | def construct_prompts(self, ctx, prefix, suffix, label=None): 82 | # dim0 is either batch_size (during training) or n_cls (during testing) 83 | # ctx: context tokens, with shape of (dim0, n_ctx, ctx_dim) 84 | # prefix: the sos token, with shape of (n_cls, 1, ctx_dim) 85 | # suffix: remaining tokens, with shape of (n_cls, *, ctx_dim) 86 | 87 | if label is not None: 88 | prefix = prefix[label] 89 | suffix = suffix[label] 90 | 91 | prompts = torch.cat( 92 | [ 93 | prefix, # (dim0, 1, dim) 94 | ctx, # (dim0, n_ctx, dim) 95 | suffix, # (dim0, *, dim) 96 | ], 97 | dim=1, 98 | ) 99 | 100 | return prompts 101 | 102 | def reset_classnames(self, classnames, arch): 103 | self.n_cls = len(classnames) 104 | classnames = [name.replace("_", " ") for name in classnames] 105 | name_lens = [len(_tokenizer.encode(name)) for name in classnames] 106 | prompts = [self.prompt_prefix + " " + name + "." for name in classnames] 107 | tokenized_prompts = torch.cat([tokenize(p) for p in prompts]).to(self.device) 108 | 109 | clip, _, _ = load(arch, device=self.device, download_root=DOWNLOAD_ROOT) 110 | 111 | with torch.no_grad(): 112 | embedding = clip.token_embedding(tokenized_prompts).type(self.dtype) 113 | 114 | self.token_prefix = embedding[:, :1, :] 115 | self.token_suffix = embedding[:, 1 + self.n_ctx :, :] # CLS, EOS 116 | 117 | self.name_lens = name_lens 118 | self.tokenized_prompts = tokenized_prompts 119 | 120 | def forward(self, im_features, ctx_only=False): 121 | prefix = self.token_prefix 122 | suffix = self.token_suffix 123 | ctx = self.ctx # (n_ctx, ctx_dim) 124 | bias = self.meta_net(im_features) # (batch, ctx_dim) 125 | bias = bias.unsqueeze(1) # (batch, 1, ctx_dim) 126 | ctx = ctx.unsqueeze(0) # (1, n_ctx, ctx_dim) 127 | ctx_shifted = ctx + bias # (batch, n_ctx, ctx_dim) 128 | if ctx_only: 129 | return ctx_shifted # don't expand to n_cls, optimize one ctx for all classes 130 | 131 | # Use instance-conditioned context tokens for all classes 132 | prompts = [] 133 | for ctx_shifted_i in ctx_shifted: 134 | ctx_i = ctx_shifted_i.unsqueeze(0).expand(self.n_cls, -1, -1) 135 | pts_i = self.construct_prompts(ctx_i, prefix, suffix) # (n_cls, n_tkn, ctx_dim) 136 | prompts.append(pts_i) 137 | prompts = torch.stack(prompts) 138 | 139 | return prompts 140 | 141 | class CoCoOpCLIP(nn.Module): 142 | def __init__(self, device, classnames, criterion='cosine', arch="ViT-L/14", 143 | n_ctx=16, ctx_init="a_photo_of_a", ctx_position='end'): 144 | super().__init__() 145 | clip, _, _ = load(arch, device=device, download_root=DOWNLOAD_ROOT) 146 | self.image_encoder = clip.visual 147 | self.text_encoder = TextEncoder(clip) 148 | self.logit_scale = clip.logit_scale.data 149 | # prompt tuning 150 | self.prompt_generator = CoCoOpPromptLearner(clip, 
classnames, n_ctx, ctx_init, ctx_position) 151 | self.tokenized_prompts = self.prompt_generator.tokenized_prompts 152 | self.criterion = criterion 153 | self.dtype = clip.dtype 154 | 155 | def inference(self, image, label=None): 156 | tokenized_prompts = self.prompt_generator.tokenized_prompts 157 | logit_scale = self.logit_scale.exp() 158 | 159 | image_features = self.image_encoder(image.type(self.dtype)) 160 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 161 | 162 | prompts = self.prompt_generator(image_features) 163 | 164 | logits = [] 165 | for pts_i, imf_i in zip(prompts, image_features): 166 | text_features = self.text_encoder(pts_i, tokenized_prompts) 167 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 168 | l_i = logit_scale * imf_i @ text_features.t() 169 | logits.append(l_i) 170 | 171 | logits = torch.stack(logits) 172 | return logits 173 | 174 | def gen_ctx(self, image, aug=False): 175 | with torch.no_grad(): 176 | with torch.cuda.amp.autocast(): 177 | image_features = self.image_encoder(image.type(self.dtype)) 178 | if aug: 179 | image_feature_avg = image_features[0].unsqueeze(0) 180 | else: 181 | image_feature_avg = image_features.mean(dim=0, keepdim=True) 182 | ctx = self.prompt_generator(image_feature_avg, ctx_only=True) 183 | 184 | return image_features, ctx.detach().clone() 185 | 186 | def forward_ctx(self, image_features, ctx): 187 | N = 1 188 | 189 | prefix = self.prompt_generator.token_prefix.expand(N, -1, -1, -1) # [N, n_cls, 1, dim] 190 | suffix = self.prompt_generator.token_suffix.expand(N, -1, -1, -1) 191 | # expand `ctx` n_cls times 192 | ctx = ctx.expand(self.prompt_generator.n_cls, -1, -1, -1) 193 | ctx = ctx.permute(1, 0, 2, 3) 194 | # ctx = ctx.reshape(N, self.prompt_generator.n_cls, -1, self.prompt_generator.ctx_dim) 195 | 196 | prompts = torch.cat([ 197 | prefix, 198 | ctx, 199 | suffix 200 | ], dim=-2) 201 | 202 | # full_n_ctx = prompts.size()[-2] 203 | 204 | prompts = prompts.reshape(N * self.prompt_generator.n_cls, -1, self.prompt_generator.ctx_dim) 205 | tokenized_prompts = self.prompt_generator.tokenized_prompts 206 | tokenized_prompts = tokenized_prompts.repeat(N, 1) 207 | text_features = self.text_encoder(prompts, tokenized_prompts) 208 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 209 | 210 | #[c-tpt] -------------------------------------------- 211 | if self.l2_norm_cal: 212 | prompt_mean = text_features.mean(0) 213 | feature_distance = text_features - prompt_mean 214 | l2_norm = torch.linalg.norm(feature_distance, dim=-1) 215 | l2_norm_mean = l2_norm.mean() 216 | 217 | #for saving to csv file 218 | self.l2_norm_mean = l2_norm_mean.item() 219 | 220 | #for training 221 | self.l2_norm_mean_training = l2_norm_mean 222 | #----------------------------------------------------- 223 | 224 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 225 | 226 | logit_scale = self.logit_scale.exp() 227 | 228 | text_features = text_features.squeeze(0) 229 | logits = logit_scale * image_features @ text_features.t() 230 | 231 | return logits 232 | 233 | def forward(self, input): 234 | if isinstance(input, Tuple): 235 | image_features, ctx = input 236 | return self.forward_ctx(image_features, ctx) 237 | else: 238 | return self.inference(input) 239 | 240 | def get_cocoop(clip_arch, test_set, device, n_ctx): 241 | if test_set in fewshot_datasets: 242 | classnames = eval("{}_classes".format(test_set.lower())) 243 | else: 244 | classnames = imagenet_classes 245 | 
246 | model = CoCoOpCLIP(device, classnames, arch=clip_arch, n_ctx=n_ctx) 247 | 248 | return model -------------------------------------------------------------------------------- /clip/custom_clip.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | from typing import List, Tuple 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from clip import load, tokenize 10 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 11 | from data.imagnet_prompts import imagenet_classes 12 | from data.fewshot_datasets import fewshot_datasets 13 | from data.cls_to_names import * 14 | 15 | import ipdb 16 | 17 | _tokenizer = _Tokenizer() 18 | 19 | DOWNLOAD_ROOT='~/.cache/clip' 20 | 21 | class ClipImageEncoder(nn.Module): 22 | def __init__(self, device, arch="ViT-L/14", image_resolution=224, n_class=1000): 23 | super(ClipImageEncoder, self).__init__() 24 | clip, embed_dim, _ = load(arch, device=device, download_root=DOWNLOAD_ROOT) 25 | self.encoder = clip.visual 26 | del clip.transformer 27 | torch.cuda.empty_cache() 28 | 29 | self.cls_head = nn.Linear(embed_dim, n_class) 30 | 31 | @property 32 | def dtype(self): 33 | return self.encoder.conv1.weight.dtype 34 | 35 | def forward(self, image): 36 | x = self.encoder(image.type(self.dtype)) 37 | output = self.cls_head(x) 38 | return output 39 | 40 | 41 | class TextEncoder(nn.Module): 42 | def __init__(self, clip_model): 43 | super().__init__() 44 | self.transformer = clip_model.transformer 45 | self.positional_embedding = clip_model.positional_embedding 46 | self.ln_final = clip_model.ln_final 47 | self.text_projection = clip_model.text_projection 48 | self.dtype = clip_model.dtype 49 | 50 | def forward(self, prompts, tokenized_prompts): 51 | x = prompts + self.positional_embedding.type(self.dtype) 52 | x = x.permute(1, 0, 2) # NLD -> LND 53 | x = self.transformer(x) 54 | x = x.permute(1, 0, 2) # LND -> NLD 55 | x = self.ln_final(x).type(self.dtype) 56 | 57 | # x.shape = [batch_size, n_ctx, transformer.width] 58 | # take features from the eot embedding (eot_token is the highest number in each sequence) 59 | x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection 60 | 61 | return x 62 | 63 | 64 | class PromptLearner(nn.Module): 65 | def __init__(self, clip_model, classnames, batch_size=None, n_ctx=16, ctx_init=None, ctx_position='end', learned_cls=False): 66 | super().__init__() 67 | n_cls = len(classnames) 68 | self.learned_cls = learned_cls 69 | dtype = clip_model.dtype 70 | self.dtype = dtype 71 | self.device = clip_model.visual.conv1.weight.device 72 | ctx_dim = clip_model.ln_final.weight.shape[0] 73 | self.ctx_dim = ctx_dim 74 | self.batch_size = batch_size 75 | 76 | # self.ctx, prompt_prefix = self.reset_prompt(ctx_dim, ctx_init, clip_model) 77 | 78 | if ctx_init: 79 | # use given words to initialize context vectors 80 | print("Initializing the context with given words: [{}]".format(ctx_init)) 81 | ctx_init = ctx_init.replace("_", " ") 82 | if '[CLS]' in ctx_init: 83 | ctx_list = ctx_init.split(" ") 84 | split_idx = ctx_list.index("[CLS]") 85 | ctx_init = ctx_init.replace("[CLS] ", "") 86 | ctx_position = "middle" 87 | else: 88 | split_idx = None 89 | self.split_idx = split_idx 90 | n_ctx = len(ctx_init.split(" ")) 91 | prompt = tokenize(ctx_init).to(self.device) 92 | with torch.no_grad(): 93 | embedding = clip_model.token_embedding(prompt).type(dtype) 94 | ctx_vectors = embedding[0, 1 : 1 + n_ctx, :] 95 | prompt_prefix = 
ctx_init 96 | else: 97 | print("Random initialization: initializing a generic context") 98 | ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype) 99 | nn.init.normal_(ctx_vectors, std=0.02) 100 | prompt_prefix = " ".join(["X"] * n_ctx) 101 | 102 | self.prompt_prefix = prompt_prefix 103 | 104 | print(f'Initial context: "{prompt_prefix}"') 105 | print(f"Number of context words (tokens): {n_ctx}") 106 | 107 | # batch-wise prompt tuning for test-time adaptation 108 | if self.batch_size is not None: 109 | ctx_vectors = ctx_vectors.repeat(batch_size, 1, 1) #(N, L, D) 110 | self.ctx_init_state = ctx_vectors.detach().clone() 111 | self.ctx = nn.Parameter(ctx_vectors) # to be optimized 112 | 113 | if not self.learned_cls: 114 | classnames = [name.replace("_", " ") for name in classnames] 115 | name_lens = [len(_tokenizer.encode(name)) for name in classnames] 116 | prompts = [prompt_prefix + " " + name + "." for name in classnames] 117 | else: 118 | print("Random initialization: initializing a learnable class token") 119 | cls_vectors = torch.empty(n_cls, 1, ctx_dim, dtype=dtype) # assume each learnable cls_token is only 1 word 120 | nn.init.normal_(cls_vectors, std=0.02) 121 | cls_token = "X" 122 | name_lens = [1 for _ in classnames] 123 | prompts = [prompt_prefix + " " + cls_token + "." for _ in classnames] 124 | 125 | self.cls_init_state = cls_vectors.detach().clone() 126 | self.cls = nn.Parameter(cls_vectors) # to be optimized 127 | 128 | tokenized_prompts = torch.cat([tokenize(p) for p in prompts]).to(self.device) 129 | with torch.no_grad(): 130 | embedding = clip_model.token_embedding(tokenized_prompts).type(dtype) 131 | 132 | # These token vectors will be saved when in save_model(), 133 | # but they should be ignored in load_model() as we want to use 134 | # those computed using the current class names 135 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS 136 | if self.learned_cls: 137 | self.register_buffer("token_suffix", embedding[:, 1 + n_ctx + 1:, :]) # ..., EOS 138 | else: 139 | self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :]) # CLS, EOS 140 | 141 | self.ctx_init = ctx_init 142 | self.tokenized_prompts = tokenized_prompts # torch.Tensor 143 | self.name_lens = name_lens 144 | self.class_token_position = ctx_position 145 | self.n_cls = n_cls 146 | self.n_ctx = n_ctx 147 | self.classnames = classnames 148 | 149 | def reset(self): 150 | ctx_vectors = self.ctx_init_state 151 | self.ctx.copy_(ctx_vectors) # to be optimized 152 | if self.learned_cls: 153 | cls_vectors = self.cls_init_state 154 | self.cls.copy_(cls_vectors) 155 | 156 | def reset_classnames(self, classnames, arch): 157 | self.n_cls = len(classnames) 158 | if not self.learned_cls: 159 | classnames = [name.replace("_", " ") for name in classnames] 160 | name_lens = [len(_tokenizer.encode(name)) for name in classnames] 161 | prompts = [self.prompt_prefix + " " + name + "." for name in classnames] 162 | else: 163 | cls_vectors = torch.empty(self.n_cls, 1, self.ctx_dim, dtype=self.dtype) # assume each learnable cls_token is only 1 word 164 | nn.init.normal_(cls_vectors, std=0.02) 165 | cls_token = "X" 166 | name_lens = [1 for _ in classnames] 167 | prompts = [self.prompt_prefix + " " + cls_token + "." 
for _ in classnames] 168 | # TODO: re-init the cls parameters 169 | # self.cls = nn.Parameter(cls_vectors) # to be optimized 170 | self.cls_init_state = cls_vectors.detach().clone() 171 | tokenized_prompts = torch.cat([tokenize(p) for p in prompts]).to(self.device) 172 | 173 | clip, _, _ = load(arch, device=self.device, download_root=DOWNLOAD_ROOT) 174 | 175 | with torch.no_grad(): 176 | embedding = clip.token_embedding(tokenized_prompts).type(self.dtype) 177 | 178 | self.token_prefix = embedding[:, :1, :] 179 | self.token_suffix = embedding[:, 1 + self.n_ctx :, :] # CLS, EOS 180 | 181 | self.name_lens = name_lens 182 | self.tokenized_prompts = tokenized_prompts 183 | self.classnames = classnames 184 | 185 | def forward(self, init=None): 186 | # the init will be used when computing CLIP directional loss 187 | if init is not None: 188 | ctx = init 189 | else: 190 | ctx = self.ctx 191 | if ctx.dim() == 2: 192 | ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1) 193 | elif not ctx.size()[0] == self.n_cls: 194 | ctx = ctx.unsqueeze(1).expand(-1, self.n_cls, -1, -1) 195 | 196 | prefix = self.token_prefix 197 | suffix = self.token_suffix 198 | if self.batch_size is not None: 199 | # This way only works for single-gpu setting (could pass batch size as an argument for forward()) 200 | prefix = prefix.repeat(self.batch_size, 1, 1, 1) 201 | suffix = suffix.repeat(self.batch_size, 1, 1, 1) 202 | 203 | if self.learned_cls: 204 | assert self.class_token_position == "end" 205 | if self.class_token_position == "end": 206 | if self.learned_cls: 207 | cls = self.cls 208 | prompts = torch.cat( 209 | [ 210 | prefix, # (n_cls, 1, dim) 211 | ctx, # (n_cls, n_ctx, dim) 212 | cls, # (n_cls, 1, dim) 213 | suffix, # (n_cls, *, dim) 214 | ], 215 | dim=-2, 216 | ) 217 | else: 218 | prompts = torch.cat( 219 | [ 220 | prefix, # (n_cls, 1, dim) 221 | ctx, # (n_cls, n_ctx, dim) 222 | suffix, # (n_cls, *, dim) 223 | ], 224 | dim=-2, 225 | ) 226 | elif self.class_token_position == "middle": 227 | # TODO: to work with a batch of prompts 228 | if self.split_idx is not None: 229 | half_n_ctx = self.split_idx # split the ctx at the position of [CLS] in `ctx_init` 230 | else: 231 | half_n_ctx = self.n_ctx // 2 232 | prompts = [] 233 | for i in range(self.n_cls): 234 | name_len = self.name_lens[i] 235 | prefix_i = prefix[i : i + 1, :, :] 236 | class_i = suffix[i : i + 1, :name_len, :] 237 | suffix_i = suffix[i : i + 1, name_len:, :] 238 | ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :] 239 | ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :] 240 | prompt = torch.cat( 241 | [ 242 | prefix_i, # (1, 1, dim) 243 | ctx_i_half1, # (1, n_ctx//2, dim) 244 | class_i, # (1, name_len, dim) 245 | ctx_i_half2, # (1, n_ctx//2, dim) 246 | suffix_i, # (1, *, dim) 247 | ], 248 | dim=1, 249 | ) 250 | prompts.append(prompt) 251 | prompts = torch.cat(prompts, dim=0) 252 | 253 | elif self.class_token_position == "front": 254 | prompts = [] 255 | for i in range(self.n_cls): 256 | name_len = self.name_lens[i] 257 | prefix_i = prefix[i : i + 1, :, :] 258 | class_i = suffix[i : i + 1, :name_len, :] 259 | suffix_i = suffix[i : i + 1, name_len:, :] 260 | ctx_i = ctx[i : i + 1, :, :] 261 | prompt = torch.cat( 262 | [ 263 | prefix_i, # (1, 1, dim) 264 | class_i, # (1, name_len, dim) 265 | ctx_i, # (1, n_ctx, dim) 266 | suffix_i, # (1, *, dim) 267 | ], 268 | dim=1, 269 | ) 270 | prompts.append(prompt) 271 | prompts = torch.cat(prompts, dim=0) 272 | 273 | else: 274 | raise ValueError 275 | 276 | return prompts 277 | 278 | 279 | class ClipTestTimeTuning(nn.Module): 280 
| def __init__(self, device, classnames, batch_size, criterion='cosine', arch="ViT-L/14", 281 | n_ctx=16, ctx_init=None, ctx_position='end', learned_cls=False): 282 | super(ClipTestTimeTuning, self).__init__() 283 | clip, _, _ = load(arch, device=device, download_root=DOWNLOAD_ROOT) 284 | self.image_encoder = clip.visual 285 | self.text_encoder = TextEncoder(clip) 286 | self.logit_scale = clip.logit_scale.data 287 | # prompt tuning 288 | self.prompt_learner = PromptLearner(clip, classnames, batch_size, n_ctx, ctx_init, ctx_position, learned_cls) 289 | self.criterion = criterion 290 | 291 | @property 292 | def dtype(self): 293 | return self.image_encoder.conv1.weight.dtype 294 | 295 | # restore the initial state of the prompt_learner (tunable prompt) 296 | def reset(self): 297 | self.prompt_learner.reset() 298 | 299 | def reset_classnames(self, classnames, arch): 300 | self.prompt_learner.reset_classnames(classnames, arch) 301 | 302 | def get_text_features(self): 303 | text_features = [] 304 | prompts = self.prompt_learner() 305 | tokenized_prompts = self.prompt_learner.tokenized_prompts 306 | t_features = self.text_encoder(prompts, tokenized_prompts) 307 | text_features.append(t_features / t_features.norm(dim=-1, keepdim=True)) 308 | text_features = torch.stack(text_features, dim=0) 309 | 310 | return torch.mean(text_features, dim=0) 311 | 312 | def inference(self, image): 313 | with torch.no_grad(): 314 | image_features = self.image_encoder(image.type(self.dtype)) 315 | 316 | text_features = self.get_text_features() 317 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 318 | 319 | #[c-tpt] -------------------------------------------- 320 | if self.l2_norm_cal: 321 | prompt_mean = text_features.mean(0) 322 | feature_distance = text_features - prompt_mean 323 | l2_norm = torch.linalg.norm(feature_distance, dim=-1) 324 | l2_norm_mean = l2_norm.mean() 325 | 326 | #for saving to csv file 327 | self.l2_norm_mean = l2_norm_mean.item() 328 | 329 | #for training 330 | self.l2_norm_mean_training = l2_norm_mean 331 | 332 | #----------------------------------------------------- 333 | 334 | logit_scale = self.logit_scale.exp() 335 | logits = logit_scale * image_features @ text_features.t() 336 | 337 | return logits 338 | 339 | def forward(self, input): 340 | if isinstance(input, Tuple): 341 | view_0, view_1, view_2 = input 342 | return self.contrast_prompt_tuning(view_0, view_1, view_2) 343 | elif len(input.size()) == 2: 344 | return self.directional_prompt_tuning(input) 345 | else: 346 | return self.inference(input) 347 | 348 | 349 | def get_coop(clip_arch, test_set, device, n_ctx, ctx_init, learned_cls=False): 350 | if test_set in fewshot_datasets: 351 | classnames = eval("{}_classes".format(test_set.lower())) 352 | elif test_set == 'bongard': 353 | if learned_cls: 354 | classnames = ['X', 'X'] 355 | else: 356 | classnames = ['True', 'False'] 357 | else: 358 | classnames = imagenet_classes 359 | 360 | model = ClipTestTimeTuning(device, classnames, None, arch=clip_arch, 361 | n_ctx=n_ctx, ctx_init=ctx_init, learned_cls=learned_cls) 362 | 363 | return model 364 | 365 | -------------------------------------------------------------------------------- /data/imagnet_prompts.py: -------------------------------------------------------------------------------- 1 | imagenet_classes = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", "indigo 
bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", "box turtle", "banded gecko", "green iguana", "Carolina anole", "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", "American alligator", "triceratops", "worm snake", "ring-necked snake", "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", "Irish Water Spaniel", "Kuvasz", "Schipperke", 
"Groenendael dog", "Malinois", "Briard", "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", "breakwater", 
"breastplate", "broom", "bucket", "buckle", "bulletproof vest", "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", "freight car", "French horn", "frying pan", "fur coat", "garbage truck", "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", "pencil sharpener", "perfume", "Petri dish", 
"photocopier", "plectrum", "Pickelhaube", "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", "red wine", "espresso", "tea cup", "eggnog", 
"mountain", "bubble", "cliff", "coral reef", "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"] 2 | 3 | imagenet_templates = [ 4 | 'a bad photo of a {}.', 5 | 'a photo of many {}.', 6 | 'a sculpture of a {}.', 7 | 'a photo of the hard to see {}.', 8 | 'a low resolution photo of the {}.', 9 | 'a rendering of a {}.', 10 | 'graffiti of a {}.', 11 | 'a bad photo of the {}.', 12 | 'a cropped photo of the {}.', 13 | 'a tattoo of a {}.', 14 | 'the embroidered {}.', 15 | 'a photo of a hard to see {}.', 16 | 'a bright photo of a {}.', 17 | 'a photo of a clean {}.', 18 | 'a photo of a dirty {}.', 19 | 'a dark photo of the {}.', 20 | 'a drawing of a {}.', 21 | 'a photo of my {}.', 22 | 'the plastic {}.', 23 | 'a photo of the cool {}.', 24 | 'a close-up photo of a {}.', 25 | 'a black and white photo of the {}.', 26 | 'a painting of the {}.', 27 | 'a painting of a {}.', 28 | 'a pixelated photo of the {}.', 29 | 'a sculpture of the {}.', 30 | 'a bright photo of the {}.', 31 | 'a cropped photo of a {}.', 32 | 'a plastic {}.', 33 | 'a photo of the dirty {}.', 34 | 'a jpeg corrupted photo of a {}.', 35 | 'a blurry photo of the {}.', 36 | 'a photo of the {}.', 37 | 'a good photo of the {}.', 38 | 'a rendering of the {}.', 39 | 'a {} in a video game.', 40 | 'a photo of one {}.', 41 | 'a doodle of a {}.', 42 | 'a close-up photo of the {}.', 43 | 'a photo of a {}.', 44 | 'the origami {}.', 45 | 'the {} in a video game.', 46 | 'a sketch of a {}.', 47 | 'a doodle of the {}.', 48 | 'a origami {}.', 49 | 'a low resolution photo of a {}.', 50 | 'the toy {}.', 51 | 'a rendition of the {}.', 52 | 'a photo of the clean {}.', 53 | 'a photo of a large {}.', 54 | 'a rendition of a {}.', 55 | 'a photo of a nice {}.', 56 | 'a photo of a weird {}.', 57 | 'a blurry photo of a {}.', 58 | 'a cartoon {}.', 59 | 'art of a {}.', 60 | 'a sketch of the {}.', 61 | 'a embroidered {}.', 62 | 'a pixelated photo of a {}.', 63 | 'itap of the {}.', 64 | 'a jpeg corrupted photo of the {}.', 65 | 'a good photo of a {}.', 66 | 'a plushie {}.', 67 | 'a photo of the nice {}.', 68 | 'a photo of the small {}.', 69 | 'a photo of the weird {}.', 70 | 'the cartoon {}.', 71 | 'art of the {}.', 72 | 'a drawing of the {}.', 73 | 'a photo of the large {}.', 74 | 'a black and white photo of a {}.', 75 | 'the plushie {}.', 76 | 'a dark photo of a {}.', 77 | 'itap of a {}.', 78 | 'graffiti of the {}.', 79 | 'a toy {}.', 80 | 'itap of my {}.', 81 | 'a photo of a cool {}.', 82 | 'a photo of a small {}.', 83 | 'a tattoo of the {}.', 84 | ] 85 | 86 | tip_imagenet_templates = [ 87 | 'a bad photo of the {}.', 88 | 'a {} in a video game.', 89 | 'a origami {}.', 90 | 'a photo of the small {}.', 91 | 'art of the {}.', 92 | 'a photo of the large {}.', 93 | 'itap of a {}.', 94 | ] 95 | 96 | tip_imagenet_templates_v0 = [ 97 | 'a bad photo of a {}.', 98 | 'a {} in a video game.', 99 | 'a origami of a {}.', 100 | 'a photo of the small {}.', 101 | 'art of the {}.', 102 | 'a photo of the large {}.', 103 | 'itap of a {}.', 104 | ] -------------------------------------------------------------------------------- /clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from 
typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.relu1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.relu2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.relu3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([ 37 | ("-1", nn.AvgPool2d(stride)), 38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 39 | ("1", nn.BatchNorm2d(planes * self.expansion)) 40 | ])) 41 | 42 | def forward(self, x: torch.Tensor): 43 | identity = x 44 | 45 | out = self.relu1(self.bn1(self.conv1(x))) 46 | out = self.relu2(self.bn2(self.conv2(out))) 47 | out = self.avgpool(out) 48 | out = self.bn3(self.conv3(out)) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.relu3(out) 55 | return out 56 | 57 | 58 | class AttentionPool2d(nn.Module): 59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 60 | super().__init__() 61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 62 | self.k_proj = nn.Linear(embed_dim, embed_dim) 63 | self.q_proj = nn.Linear(embed_dim, embed_dim) 64 | self.v_proj = nn.Linear(embed_dim, embed_dim) 65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 66 | self.num_heads = num_heads 67 | 68 | def forward(self, x): 69 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC 70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 72 | x, _ = F.multi_head_attention_forward( 73 | query=x[:1], key=x, value=x, 74 | embed_dim_to_check=x.shape[-1], 75 | num_heads=self.num_heads, 76 | q_proj_weight=self.q_proj.weight, 77 | k_proj_weight=self.k_proj.weight, 78 | v_proj_weight=self.v_proj.weight, 79 | in_proj_weight=None, 80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 81 | bias_k=None, 82 | bias_v=None, 83 | add_zero_attn=False, 84 | dropout_p=0, 85 | out_proj_weight=self.c_proj.weight, 86 | out_proj_bias=self.c_proj.bias, 87 | use_separate_proj_weight=True, 88 | training=self.training, 89 | need_weights=False 90 | ) 91 | return x.squeeze(0) 92 | 93 | 94 | class ModifiedResNet(nn.Module): 95 | """ 96 | A ResNet class that is similar to torchvision's but contains the following changes: 97 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
98 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 99 | - The final pooling layer is a QKV attention instead of an average pool 100 | """ 101 | 102 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 103 | super().__init__() 104 | self.output_dim = output_dim 105 | self.input_resolution = input_resolution 106 | 107 | # the 3-layer stem 108 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 109 | self.bn1 = nn.BatchNorm2d(width // 2) 110 | self.relu1 = nn.ReLU(inplace=True) 111 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 112 | self.bn2 = nn.BatchNorm2d(width // 2) 113 | self.relu2 = nn.ReLU(inplace=True) 114 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 115 | self.bn3 = nn.BatchNorm2d(width) 116 | self.relu3 = nn.ReLU(inplace=True) 117 | self.avgpool = nn.AvgPool2d(2) 118 | 119 | # residual layers 120 | self._inplanes = width # this is a *mutable* variable used during construction 121 | self.layer1 = self._make_layer(width, layers[0]) 122 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 123 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 124 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 125 | 126 | embed_dim = width * 32 # the ResNet feature dimension 127 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 128 | 129 | def _make_layer(self, planes, blocks, stride=1): 130 | layers = [Bottleneck(self._inplanes, planes, stride)] 131 | 132 | self._inplanes = planes * Bottleneck.expansion 133 | for _ in range(1, blocks): 134 | layers.append(Bottleneck(self._inplanes, planes)) 135 | 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x): 139 | def stem(x): 140 | x = self.relu1(self.bn1(self.conv1(x))) 141 | x = self.relu2(self.bn2(self.conv2(x))) 142 | x = self.relu3(self.bn3(self.conv3(x))) 143 | x = self.avgpool(x) 144 | return x 145 | 146 | x = x.type(self.conv1.weight.dtype) 147 | x = stem(x) 148 | x = self.layer1(x) 149 | x = self.layer2(x) 150 | x = self.layer3(x) 151 | x = self.layer4(x) 152 | x = self.attnpool(x) 153 | 154 | return x 155 | 156 | 157 | class LayerNorm(nn.LayerNorm): 158 | """Subclass torch's LayerNorm to handle fp16.""" 159 | 160 | def forward(self, x: torch.Tensor): 161 | orig_type = x.dtype 162 | ret = super().forward(x.type(torch.float32)) 163 | return ret.type(orig_type) 164 | 165 | 166 | class QuickGELU(nn.Module): 167 | def forward(self, x: torch.Tensor): 168 | return x * torch.sigmoid(1.702 * x) 169 | 170 | 171 | class ResidualAttentionBlock(nn.Module): 172 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 173 | super().__init__() 174 | 175 | self.attn = nn.MultiheadAttention(d_model, n_head) 176 | self.ln_1 = LayerNorm(d_model) 177 | self.mlp = nn.Sequential(OrderedDict([ 178 | ("c_fc", nn.Linear(d_model, d_model * 4)), 179 | ("gelu", QuickGELU()), 180 | ("c_proj", nn.Linear(d_model * 4, d_model)) 181 | ])) 182 | self.ln_2 = LayerNorm(d_model) 183 | self.attn_mask = attn_mask 184 | 185 | def attention(self, x: torch.Tensor): 186 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 187 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 188 | 189 | def forward(self, x: torch.Tensor): 190 | x = x + self.attention(self.ln_1(x)) 191 | x = x + 
self.mlp(self.ln_2(x)) 192 | return x 193 | 194 | 195 | class Transformer(nn.Module): 196 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 197 | super().__init__() 198 | self.width = width 199 | self.layers = layers 200 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 201 | 202 | def forward(self, x: torch.Tensor): 203 | return self.resblocks(x) 204 | 205 | 206 | class VisionTransformer(nn.Module): 207 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 208 | super().__init__() 209 | self.input_resolution = input_resolution 210 | self.output_dim = output_dim 211 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 212 | 213 | scale = width ** -0.5 214 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 215 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 216 | self.ln_pre = LayerNorm(width) 217 | 218 | self.transformer = Transformer(width, layers, heads) 219 | 220 | self.ln_post = LayerNorm(width) 221 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 222 | 223 | def forward(self, x: torch.Tensor): 224 | x = self.conv1(x) # shape = [*, width, grid, grid] 225 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 226 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 227 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 228 | x = x + self.positional_embedding.to(x.dtype) 229 | x = self.ln_pre(x) 230 | 231 | x = x.permute(1, 0, 2) # NLD -> LND 232 | x = self.transformer(x) 233 | x = x.permute(1, 0, 2) # LND -> NLD 234 | 235 | x = self.ln_post(x[:, 0, :]) 236 | 237 | if self.proj is not None: 238 | x = x @ self.proj 239 | 240 | return x 241 | 242 | 243 | class CLIP(nn.Module): 244 | def __init__(self, 245 | embed_dim: int, 246 | # vision 247 | image_resolution: int, 248 | vision_layers: Union[Tuple[int, int, int, int], int], 249 | vision_width: int, 250 | vision_patch_size: int, 251 | # text 252 | context_length: int, 253 | vocab_size: int, 254 | transformer_width: int, 255 | transformer_heads: int, 256 | transformer_layers: int 257 | ): 258 | super().__init__() 259 | 260 | self.context_length = context_length 261 | 262 | if isinstance(vision_layers, (tuple, list)): 263 | vision_heads = vision_width * 32 // 64 264 | self.visual = ModifiedResNet( 265 | layers=vision_layers, 266 | output_dim=embed_dim, 267 | heads=vision_heads, 268 | input_resolution=image_resolution, 269 | width=vision_width 270 | ) 271 | else: 272 | vision_heads = vision_width // 64 273 | self.visual = VisionTransformer( 274 | input_resolution=image_resolution, 275 | patch_size=vision_patch_size, 276 | width=vision_width, 277 | layers=vision_layers, 278 | heads=vision_heads, 279 | output_dim=embed_dim 280 | ) 281 | 282 | self.transformer = Transformer( 283 | width=transformer_width, 284 | layers=transformer_layers, 285 | heads=transformer_heads, 286 | attn_mask=self.build_attention_mask() 287 | ) 288 | 289 | self.vocab_size = vocab_size 290 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 291 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 292 | self.ln_final = LayerNorm(transformer_width) 293 | 
294 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 295 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 296 | 297 | self.initialize_parameters() 298 | 299 | def initialize_parameters(self): 300 | nn.init.normal_(self.token_embedding.weight, std=0.02) 301 | nn.init.normal_(self.positional_embedding, std=0.01) 302 | 303 | if isinstance(self.visual, ModifiedResNet): 304 | if self.visual.attnpool is not None: 305 | std = self.visual.attnpool.c_proj.in_features ** -0.5 306 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 307 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 308 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 309 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 310 | 311 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 312 | for name, param in resnet_block.named_parameters(): 313 | if name.endswith("bn3.weight"): 314 | nn.init.zeros_(param) 315 | 316 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 317 | attn_std = self.transformer.width ** -0.5 318 | fc_std = (2 * self.transformer.width) ** -0.5 319 | for block in self.transformer.resblocks: 320 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 321 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 322 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 323 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 324 | 325 | if self.text_projection is not None: 326 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 327 | 328 | def build_attention_mask(self): 329 | # lazily create causal attention mask, with full attention between the vision tokens 330 | # pytorch uses additive attention mask; fill with -inf 331 | mask = torch.empty(self.context_length, self.context_length) 332 | mask.fill_(float("-inf")) 333 | mask.triu_(1) # zero out the lower diagonal 334 | return mask 335 | 336 | @property 337 | def dtype(self): 338 | return self.visual.conv1.weight.dtype 339 | 340 | def encode_image(self, image): 341 | return self.visual(image.type(self.dtype)) 342 | 343 | def encode_text(self, text): 344 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 345 | 346 | x = x + self.positional_embedding.type(self.dtype) 347 | x = x.permute(1, 0, 2) # NLD -> LND 348 | x = self.transformer(x) 349 | x = x.permute(1, 0, 2) # LND -> NLD 350 | x = self.ln_final(x).type(self.dtype) 351 | 352 | # x.shape = [batch_size, n_ctx, transformer.width] 353 | # take features from the eot embedding (eot_token is the highest number in each sequence) 354 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 355 | 356 | return x 357 | 358 | def forward(self, image, text): 359 | image_features = self.encode_image(image) 360 | text_features = self.encode_text(text) 361 | 362 | # normalized features 363 | image_features = image_features / image_features.norm(dim=1, keepdim=True) 364 | text_features = text_features / text_features.norm(dim=1, keepdim=True) 365 | 366 | # cosine similarity as logits 367 | logit_scale = self.logit_scale.exp() 368 | logits_per_image = logit_scale * image_features @ text_features.t() 369 | logits_per_text = logits_per_image.t() 370 | 371 | # shape = [global_batch_size, global_batch_size] 372 | return logits_per_image, logits_per_text 373 | 374 | 375 | def convert_weights(model: nn.Module): 376 | """Convert applicable model parameters to 
fp16""" 377 | 378 | def _convert_weights_to_fp16(l): 379 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 380 | l.weight.data = l.weight.data.half() 381 | if l.bias is not None: 382 | l.bias.data = l.bias.data.half() 383 | 384 | if isinstance(l, nn.MultiheadAttention): 385 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 386 | tensor = getattr(l, attr) 387 | if tensor is not None: 388 | tensor.data = tensor.data.half() 389 | 390 | for name in ["text_projection", "proj"]: 391 | if hasattr(l, name): 392 | attr = getattr(l, name) 393 | if attr is not None: 394 | attr.data = attr.data.half() 395 | 396 | model.apply(_convert_weights_to_fp16) 397 | 398 | 399 | def build_model(state_dict: dict): 400 | vit = "visual.proj" in state_dict 401 | 402 | if vit: 403 | vision_width = state_dict["visual.conv1.weight"].shape[0] 404 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 405 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 406 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 407 | image_resolution = vision_patch_size * grid_size 408 | else: 409 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 410 | vision_layers = tuple(counts) 411 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 412 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 413 | vision_patch_size = None 414 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 415 | image_resolution = output_width * 32 416 | 417 | embed_dim = state_dict["text_projection"].shape[1] 418 | context_length = state_dict["positional_embedding"].shape[0] 419 | vocab_size = state_dict["token_embedding.weight"].shape[0] 420 | transformer_width = state_dict["ln_final.weight"].shape[0] 421 | transformer_heads = transformer_width // 64 422 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) 423 | 424 | model = CLIP( 425 | embed_dim, 426 | image_resolution, vision_layers, vision_width, vision_patch_size, 427 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 428 | ) 429 | 430 | for key in ["input_resolution", "context_length", "vocab_size"]: 431 | if key in state_dict: 432 | del state_dict[key] 433 | 434 | # convert_weights(model) 435 | model.load_state_dict(state_dict) 436 | del state_dict 437 | torch.cuda.empty_cache() 438 | return model.eval() -------------------------------------------------------------------------------- /data/cls_to_names.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | flower102_maps = {"21": "fire lily", "3": "canterbury bells", "45": "bolero deep blue", "1": "pink primrose", "34": "mexican aster", "27": "prince of wales feathers", "7": "moon orchid", "16": "globe-flower", "25": "grape hyacinth", "26": "corn poppy", "79": "toad lily", "39": "siam tulip", "24": "red ginger", "67": "spring crocus", "35": "alpine sea holly", "32": "garden phlox", "10": "globe thistle", "6": "tiger lily", "93": "ball moss", "33": "love in the mist", "9": "monkshood", "102": "blackberry lily", "14": "spear thistle", "19": "balloon flower", "100": "blanket flower", "13": "king protea", "49": "oxeye daisy", "15": "yellow iris", "61": "cautleya 
spicata", "31": "carnation", "64": "silverbush", "68": "bearded iris", "63": "black-eyed susan", "69": "windflower", "62": "japanese anemone", "20": "giant white arum lily", "38": "great masterwort", "4": "sweet pea", "86": "tree mallow", "101": "trumpet creeper", "42": "daffodil", "22": "pincushion flower", "2": "hard-leaved pocket orchid", "54": "sunflower", "66": "osteospermum", "70": "tree poppy", "85": "desert-rose", "99": "bromelia", "87": "magnolia", "5": "english marigold", "92": "bee balm", "28": "stemless gentian", "97": "mallow", "57": "gaura", "40": "lenten rose", "47": "marigold", "59": "orange dahlia", "48": "buttercup", "55": "pelargonium", "36": "ruby-lipped cattleya", "91": "hippeastrum", "29": "artichoke", "71": "gazania", "90": "canna lily", "18": "peruvian lily", "98": "mexican petunia", "8": "bird of paradise", "30": "sweet william", "17": "purple coneflower", "52": "wild pansy", "84": "columbine", "12": "colt's foot", "11": "snapdragon", "96": "camellia", "23": "fritillary", "50": "common dandelion", "44": "poinsettia", "53": "primula", "72": "azalea", "65": "californian poppy", "80": "anthurium", "76": "morning glory", "37": "cape flower", "56": "bishop of llandaff", "60": "pink-yellow dahlia", "82": "clematis", "58": "geranium", "75": "thorn apple", "41": "barbeton daisy", "95": "bougainvillea", "43": "sword lily", "83": "hibiscus", "78": "lotus", "88": "cyclamen", "94": "foxglove", "81": "frangipani", "74": "rose", "89": "watercress", "73": "water lily", "46": "wallflower", "77": "passion flower", "51": "petunia"} 5 | flower102_maps = {int(k)-1: v for k, v in flower102_maps.items()} 6 | flower102_maps = dict(sorted(flower102_maps.items())) 7 | flower102_classes = list(flower102_maps.values()) 8 | 9 | food101_classes = ['apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare', 'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito', 'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake', 'ceviche', 'cheesecake', 'cheese_plate', 'chicken_curry', 'chicken_quesadilla', 'chicken_wings', 'chocolate_cake', 'chocolate_mousse', 'churros', 'clam_chowder', 'club_sandwich', 'crab_cakes', 'creme_brulee', 'croque_madame', 'cup_cakes', 'deviled_eggs', 'donuts', 'dumplings', 'edamame', 'eggs_benedict', 'escargots', 'falafel', 'filet_mignon', 'fish_and_chips', 'foie_gras', 'french_fries', 'french_onion_soup', 'french_toast', 'fried_calamari', 'fried_rice', 'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad', 'grilled_cheese_sandwich', 'grilled_salmon', 'guacamole', 'gyoza', 'hamburger', 'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros', 'hummus', 'ice_cream', 'lasagna', 'lobster_bisque', 'lobster_roll_sandwich', 'macaroni_and_cheese', 'macarons', 'miso_soup', 'mussels', 'nachos', 'omelette', 'onion_rings', 'oysters', 'pad_thai', 'paella', 'pancakes', 'panna_cotta', 'peking_duck', 'pho', 'pizza', 'pork_chop', 'poutine', 'prime_rib', 'pulled_pork_sandwich', 'ramen', 'ravioli', 'red_velvet_cake', 'risotto', 'samosa', 'sashimi', 'scallops', 'seaweed_salad', 'shrimp_and_grits', 'spaghetti_bolognese', 'spaghetti_carbonara', 'spring_rolls', 'steak', 'strawberry_shortcake', 'sushi', 'tacos', 'takoyaki', 'tiramisu', 'tuna_tartare', 'waffles'] 10 | 11 | dtd_classes = ['banded', 'blotchy', 'braided', 'bubbly', 'bumpy', 'chequered', 'cobwebbed', 'cracked', 'crosshatched', 'crystalline', 'dotted', 'fibrous', 'flecked', 'freckled', 'frilly', 'gauzy', 'grid', 'grooved', 'honeycombed', 'interlaced', 'knitted', 'lacelike', 'lined', 
'marbled', 'matted', 'meshed', 'paisley', 'perforated', 'pitted', 'pleated', 'polka-dotted', 'porous', 'potholed', 'scaly', 'smeared', 'spiralled', 'sprinkled', 'stained', 'stratified', 'striped', 'studded', 'swirly', 'veined', 'waffled', 'woven', 'wrinkled', 'zigzagged'] 12 | 13 | pets_classes = ['abyssinian', 'american_bulldog', 'american_pit_bull_terrier', 'basset_hound', 'beagle', 'bengal', 'birman', 'bombay', 'boxer', 'british_shorthair', 'chihuahua', 'egyptian_mau', 'english_cocker_spaniel', 'english_setter', 'german_shorthaired', 'great_pyrenees', 'havanese', 'japanese_chin', 'keeshond', 'leonberger', 'maine_coon', 'miniature_pinscher', 'newfoundland', 'persian', 'pomeranian', 'pug', 'ragdoll', 'russian_blue', 'saint_bernard', 'samoyed', 'scottish_terrier', 'shiba_inu', 'siamese', 'sphynx', 'staffordshire_bull_terrier', 'wheaten_terrier', 'yorkshire_terrier'] 14 | 15 | sun397_classes = ['abbey', 'airplane_cabin', 'airport_terminal', 'alley', 'amphitheater', 'amusement_arcade', 'amusement_park', 'anechoic_chamber', 'outdoor apartment_building', 'indoor apse', 'aquarium', 'aqueduct', 'arch', 'archive', 'outdoor arrival_gate', 'art_gallery', 'art_school', 'art_studio', 'assembly_line', 'outdoor athletic_field', 'public atrium', 'attic', 'auditorium', 'auto_factory', 'badlands', 'indoor badminton_court', 'baggage_claim', 'shop bakery', 'exterior balcony', 'interior balcony', 'ball_pit', 'ballroom', 'bamboo_forest', 'banquet_hall', 'bar', 'barn', 'barndoor', 'baseball_field', 'basement', 'basilica', 'outdoor basketball_court', 'bathroom', 'batters_box', 'bayou', 'indoor bazaar', 'outdoor bazaar', 'beach', 'beauty_salon', 'bedroom', 'berth', 'biology_laboratory', 'indoor bistro', 'boardwalk', 'boat_deck', 'boathouse', 'bookstore', 'indoor booth', 'botanical_garden', 'indoor bow_window', 'outdoor bow_window', 'bowling_alley', 'boxing_ring', 'indoor brewery', 'bridge', 'building_facade', 'bullring', 'burial_chamber', 'bus_interior', 'butchers_shop', 'butte', 'outdoor cabin', 'cafeteria', 'campsite', 'campus', 'natural canal', 'urban canal', 'candy_store', 'canyon', 'backseat car_interior', 'frontseat car_interior', 'carrousel', 'indoor casino', 'castle', 'catacomb', 'indoor cathedral', 'outdoor cathedral', 'indoor cavern', 'cemetery', 'chalet', 'cheese_factory', 'chemistry_lab', 'indoor chicken_coop', 'outdoor chicken_coop', 'childs_room', 'indoor church', 'outdoor church', 'classroom', 'clean_room', 'cliff', 'indoor cloister', 'closet', 'clothing_store', 'coast', 'cockpit', 'coffee_shop', 'computer_room', 'conference_center', 'conference_room', 'construction_site', 'control_room', 'outdoor control_tower', 'corn_field', 'corral', 'corridor', 'cottage_garden', 'courthouse', 'courtroom', 'courtyard', 'exterior covered_bridge', 'creek', 'crevasse', 'crosswalk', 'office cubicle', 'dam', 'delicatessen', 'dentists_office', 'sand desert', 'vegetation desert', 'indoor diner', 'outdoor diner', 'home dinette', 'vehicle dinette', 'dining_car', 'dining_room', 'discotheque', 'dock', 'outdoor doorway', 'dorm_room', 'driveway', 'outdoor driving_range', 'drugstore', 'electrical_substation', 'door elevator', 'interior elevator', 'elevator_shaft', 'engine_room', 'indoor escalator', 'excavation', 'indoor factory', 'fairway', 'fastfood_restaurant', 'cultivated field', 'wild field', 'fire_escape', 'fire_station', 'indoor firing_range', 'fishpond', 'indoor florist_shop', 'food_court', 'broadleaf forest', 'needleleaf forest', 'forest_path', 'forest_road', 'formal_garden', 'fountain', 'galley', 'game_room', 
'indoor garage', 'garbage_dump', 'gas_station', 'exterior gazebo', 'indoor general_store', 'outdoor general_store', 'gift_shop', 'golf_course', 'indoor greenhouse', 'outdoor greenhouse', 'indoor gymnasium', 'indoor hangar', 'outdoor hangar', 'harbor', 'hayfield', 'heliport', 'herb_garden', 'highway', 'hill', 'home_office', 'hospital', 'hospital_room', 'hot_spring', 'outdoor hot_tub', 'outdoor hotel', 'hotel_room', 'house', 'outdoor hunting_lodge', 'ice_cream_parlor', 'ice_floe', 'ice_shelf', 'indoor ice_skating_rink', 'outdoor ice_skating_rink', 'iceberg', 'igloo', 'industrial_area', 'outdoor inn', 'islet', 'indoor jacuzzi', 'indoor jail', 'jail_cell', 'jewelry_shop', 'kasbah', 'indoor kennel', 'outdoor kennel', 'kindergarden_classroom', 'kitchen', 'kitchenette', 'outdoor labyrinth', 'natural lake', 'landfill', 'landing_deck', 'laundromat', 'lecture_room', 'indoor library', 'outdoor library', 'outdoor lido_deck', 'lift_bridge', 'lighthouse', 'limousine_interior', 'living_room', 'lobby', 'lock_chamber', 'locker_room', 'mansion', 'manufactured_home', 'indoor market', 'outdoor market', 'marsh', 'martial_arts_gym', 'mausoleum', 'medina', 'water moat', 'outdoor monastery', 'indoor mosque', 'outdoor mosque', 'motel', 'mountain', 'mountain_snowy', 'indoor movie_theater', 'indoor museum', 'music_store', 'music_studio', 'outdoor nuclear_power_plant', 'nursery', 'oast_house', 'outdoor observatory', 'ocean', 'office', 'office_building', 'outdoor oil_refinery', 'oilrig', 'operating_room', 'orchard', 'outdoor outhouse', 'pagoda', 'palace', 'pantry', 'park', 'indoor parking_garage', 'outdoor parking_garage', 'parking_lot', 'parlor', 'pasture', 'patio', 'pavilion', 'pharmacy', 'phone_booth', 'physics_laboratory', 'picnic_area', 'indoor pilothouse', 'outdoor planetarium', 'playground', 'playroom', 'plaza', 'indoor podium', 'outdoor podium', 'pond', 'establishment poolroom', 'home poolroom', 'outdoor power_plant', 'promenade_deck', 'indoor pub', 'pulpit', 'putting_green', 'racecourse', 'raceway', 'raft', 'railroad_track', 'rainforest', 'reception', 'recreation_room', 'residential_neighborhood', 'restaurant', 'restaurant_kitchen', 'restaurant_patio', 'rice_paddy', 'riding_arena', 'river', 'rock_arch', 'rope_bridge', 'ruin', 'runway', 'sandbar', 'sandbox', 'sauna', 'schoolhouse', 'sea_cliff', 'server_room', 'shed', 'shoe_shop', 'shopfront', 'indoor shopping_mall', 'shower', 'skatepark', 'ski_lodge', 'ski_resort', 'ski_slope', 'sky', 'skyscraper', 'slum', 'snowfield', 'squash_court', 'stable', 'baseball stadium', 'football stadium', 'indoor stage', 'staircase', 'street', 'subway_interior', 'platform subway_station', 'supermarket', 'sushi_bar', 'swamp', 'indoor swimming_pool', 'outdoor swimming_pool', 'indoor synagogue', 'outdoor synagogue', 'television_studio', 'east_asia temple', 'south_asia temple', 'indoor tennis_court', 'outdoor tennis_court', 'outdoor tent', 'indoor_procenium theater', 'indoor_seats theater', 'thriftshop', 'throne_room', 'ticket_booth', 'toll_plaza', 'topiary_garden', 'tower', 'toyshop', 'outdoor track', 'train_railway', 'platform train_station', 'tree_farm', 'tree_house', 'trench', 'coral_reef underwater', 'utility_room', 'valley', 'van_interior', 'vegetable_garden', 'veranda', 'veterinarians_office', 'viaduct', 'videostore', 'village', 'vineyard', 'volcano', 'indoor volleyball_court', 'outdoor volleyball_court', 'waiting_room', 'indoor warehouse', 'water_tower', 'block waterfall', 'fan waterfall', 'plunge waterfall', 'watering_hole', 'wave', 'wet_bar', 'wheat_field', 'wind_farm', 
'windmill', 'barrel_storage wine_cellar', 'bottle_storage wine_cellar', 'indoor wrestling_ring', 'yard', 'youth_hostel'] 16 | 17 | caltech101_classes = ['face', 'leopard', 'motorbike', 'accordion', 'airplane', 'anchor', 'ant', 'barrel', 'bass', 'beaver', 'binocular', 'bonsai', 'brain', 'brontosaurus', 'buddha', 'butterfly', 'camera', 'cannon', 'car_side', 'ceiling_fan', 'cellphone', 'chair', 'chandelier', 'cougar_body', 'cougar_face', 'crab', 'crayfish', 'crocodile', 'crocodile_head', 'cup', 'dalmatian', 'dollar_bill', 'dolphin', 'dragonfly', 'electric_guitar', 'elephant', 'emu', 'euphonium', 'ewer', 'ferry', 'flamingo', 'flamingo_head', 'garfield', 'gerenuk', 'gramophone', 'grand_piano', 'hawksbill', 'headphone', 'hedgehog', 'helicopter', 'ibis', 'inline_skate', 'joshua_tree', 'kangaroo', 'ketch', 'lamp', 'laptop', 'llama', 'lobster', 'lotus', 'mandolin', 'mayfly', 'menorah', 'metronome', 'minaret', 'nautilus', 'octopus', 'okapi', 'pagoda', 'panda', 'pigeon', 'pizza', 'platypus', 'pyramid', 'revolver', 'rhino', 'rooster', 'saxophone', 'schooner', 'scissors', 'scorpion', 'sea_horse', 'snoopy', 'soccer_ball', 'stapler', 'starfish', 'stegosaurus', 'stop_sign', 'strawberry', 'sunflower', 'tick', 'trilobite', 'umbrella', 'watch', 'water_lilly', 'wheelchair', 'wild_cat', 'windsor_chair', 'wrench', 'yin_yang'] 18 | 19 | cars_classes = ['2000 AM General Hummer SUV', '2012 Acura RL Sedan', '2012 Acura TL Sedan', '2008 Acura TL Type-S', '2012 Acura TSX Sedan', '2001 Acura Integra Type R', '2012 Acura ZDX Hatchback', '2012 Aston Martin V8 Vantage Convertible', '2012 Aston Martin V8 Vantage Coupe', '2012 Aston Martin Virage Convertible', '2012 Aston Martin Virage Coupe', '2008 Audi RS 4 Convertible', '2012 Audi A5 Coupe', '2012 Audi TTS Coupe', '2012 Audi R8 Coupe', '1994 Audi V8 Sedan', '1994 Audi 100 Sedan', '1994 Audi 100 Wagon', '2011 Audi TT Hatchback', '2011 Audi S6 Sedan', '2012 Audi S5 Convertible', '2012 Audi S5 Coupe', '2012 Audi S4 Sedan', '2007 Audi S4 Sedan', '2012 Audi TT RS Coupe', '2012 BMW ActiveHybrid 5 Sedan', '2012 BMW 1 Series Convertible', '2012 BMW 1 Series Coupe', '2012 BMW 3 Series Sedan', '2012 BMW 3 Series Wagon', '2007 BMW 6 Series Convertible', '2007 BMW X5 SUV', '2012 BMW X6 SUV', '2012 BMW M3 Coupe', '2010 BMW M5 Sedan', '2010 BMW M6 Convertible', '2012 BMW X3 SUV', '2012 BMW Z4 Convertible', '2012 Bentley Continental Supersports Conv. 
Convertible', '2009 Bentley Arnage Sedan', '2011 Bentley Mulsanne Sedan', '2012 Bentley Continental GT Coupe', '2007 Bentley Continental GT Coupe', '2007 Bentley Continental Flying Spur Sedan', '2009 Bugatti Veyron 16.4 Convertible', '2009 Bugatti Veyron 16.4 Coupe', '2012 Buick Regal GS', '2007 Buick Rainier SUV', '2012 Buick Verano Sedan', '2012 Buick Enclave SUV', '2012 Cadillac CTS-V Sedan', '2012 Cadillac SRX SUV', '2007 Cadillac Escalade EXT Crew Cab', '2012 Chevrolet Silverado 1500 Hybrid Crew Cab', '2012 Chevrolet Corvette Convertible', '2012 Chevrolet Corvette ZR1', '2007 Chevrolet Corvette Ron Fellows Edition Z06', '2012 Chevrolet Traverse SUV', '2012 Chevrolet Camaro Convertible', '2010 Chevrolet HHR SS', '2007 Chevrolet Impala Sedan', '2012 Chevrolet Tahoe Hybrid SUV', '2012 Chevrolet Sonic Sedan', '2007 Chevrolet Express Cargo Van', '2012 Chevrolet Avalanche Crew Cab', '2010 Chevrolet Cobalt SS', '2010 Chevrolet Malibu Hybrid Sedan', '2009 Chevrolet TrailBlazer SS', '2012 Chevrolet Silverado 2500HD Regular Cab', '2007 Chevrolet Silverado 1500 Classic Extended Cab', '2007 Chevrolet Express Van', '2007 Chevrolet Monte Carlo Coupe', '2007 Chevrolet Malibu Sedan', '2012 Chevrolet Silverado 1500 Extended Cab', '2012 Chevrolet Silverado 1500 Regular Cab', '2009 Chrysler Aspen SUV', '2010 Chrysler Sebring Convertible', '2012 Chrysler Town and Country Minivan', '2010 Chrysler 300 SRT-8', '2008 Chrysler Crossfire Convertible', '2008 Chrysler PT Cruiser Convertible', '2002 Daewoo Nubira Wagon', '2012 Dodge Caliber Wagon', '2007 Dodge Caliber Wagon', '1997 Dodge Caravan Minivan', '2010 Dodge Ram Pickup 3500 Crew Cab', '2009 Dodge Ram Pickup 3500 Quad Cab', '2009 Dodge Sprinter Cargo Van', '2012 Dodge Journey SUV', '2010 Dodge Dakota Crew Cab', '2007 Dodge Dakota Club Cab', '2008 Dodge Magnum Wagon', '2011 Dodge Challenger SRT8', '2012 Dodge Durango SUV', '2007 Dodge Durango SUV', '2012 Dodge Charger Sedan', '2009 Dodge Charger SRT-8', '1998 Eagle Talon Hatchback', '2012 FIAT 500 Abarth', '2012 FIAT 500 Convertible', '2012 Ferrari FF Coupe', '2012 Ferrari California Convertible', '2012 Ferrari 458 Italia Convertible', '2012 Ferrari 458 Italia Coupe', '2012 Fisker Karma Sedan', '2012 Ford F-450 Super Duty Crew Cab', '2007 Ford Mustang Convertible', '2007 Ford Freestar Minivan', '2009 Ford Expedition EL SUV', '2012 Ford Edge SUV', '2011 Ford Ranger SuperCab', '2006 Ford GT Coupe', '2012 Ford F-150 Regular Cab', '2007 Ford F-150 Regular Cab', '2007 Ford Focus Sedan', '2012 Ford E-Series Wagon Van', '2012 Ford Fiesta Sedan', '2012 GMC Terrain SUV', '2012 GMC Savana Van', '2012 GMC Yukon Hybrid SUV', '2012 GMC Acadia SUV', '2012 GMC Canyon Extended Cab', '1993 Geo Metro Convertible', '2010 HUMMER H3T Crew Cab', '2009 HUMMER H2 SUT Crew Cab', '2012 Honda Odyssey Minivan', '2007 Honda Odyssey Minivan', '2012 Honda Accord Coupe', '2012 Honda Accord Sedan', '2012 Hyundai Veloster Hatchback', '2012 Hyundai Santa Fe SUV', '2012 Hyundai Tucson SUV', '2012 Hyundai Veracruz SUV', '2012 Hyundai Sonata Hybrid Sedan', '2007 Hyundai Elantra Sedan', '2012 Hyundai Accent Sedan', '2012 Hyundai Genesis Sedan', '2012 Hyundai Sonata Sedan', '2012 Hyundai Elantra Touring Hatchback', '2012 Hyundai Azera Sedan', '2012 Infiniti G Coupe IPL', '2011 Infiniti QX56 SUV', '2008 Isuzu Ascender SUV', '2012 Jaguar XK XKR', '2012 Jeep Patriot SUV', '2012 Jeep Wrangler SUV', '2012 Jeep Liberty SUV', '2012 Jeep Grand Cherokee SUV', '2012 Jeep Compass SUV', '2008 Lamborghini Reventon Coupe', '2012 Lamborghini Aventador Coupe', 
'2012 Lamborghini Gallardo LP 570-4 Superleggera', '2001 Lamborghini Diablo Coupe', '2012 Land Rover Range Rover SUV', '2012 Land Rover LR2 SUV', '2011 Lincoln Town Car Sedan', '2012 MINI Cooper Roadster Convertible', '2012 Maybach Landaulet Convertible', '2011 Mazda Tribute SUV', '2012 McLaren MP4-12C Coupe', '1993 Mercedes-Benz 300-Class Convertible', '2012 Mercedes-Benz C-Class Sedan', '2009 Mercedes-Benz SL-Class Coupe', '2012 Mercedes-Benz E-Class Sedan', '2012 Mercedes-Benz S-Class Sedan', '2012 Mercedes-Benz Sprinter Van', '2012 Mitsubishi Lancer Sedan', '2012 Nissan Leaf Hatchback', '2012 Nissan NV Passenger Van', '2012 Nissan Juke Hatchback', '1998 Nissan 240SX Coupe', '1999 Plymouth Neon Coupe', '2012 Porsche Panamera Sedan', '2012 Ram C/V Cargo Van Minivan', '2012 Rolls-Royce Phantom Drophead Coupe Convertible', '2012 Rolls-Royce Ghost Sedan', '2012 Rolls-Royce Phantom Sedan', '2012 Scion xD Hatchback', '2009 Spyker C8 Convertible', '2009 Spyker C8 Coupe', '2007 Suzuki Aerio Sedan', '2012 Suzuki Kizashi Sedan', '2012 Suzuki SX4 Hatchback', '2012 Suzuki SX4 Sedan', '2012 Tesla Model S Sedan', '2012 Toyota Sequoia SUV', '2012 Toyota Camry Sedan', '2012 Toyota Corolla Sedan', '2012 Toyota 4Runner SUV', '2012 Volkswagen Golf Hatchback', '1991 Volkswagen Golf Hatchback', '2012 Volkswagen Beetle Hatchback', '2012 Volvo C30 Hatchback', '1993 Volvo 240 Sedan', '2007 Volvo XC90 SUV', '2012 smart fortwo Convertible'] 20 | 21 | ucf101_classes = ['Apply_Eye_Makeup', 'Apply_Lipstick', 'Archery', 'Baby_Crawling', 'Balance_Beam', 'Band_Marching', 'Baseball_Pitch', 'Basketball', 'Basketball_Dunk', 'Bench_Press', 'Biking', 'Billiards', 'Blow_Dry_Hair', 'Blowing_Candles', 'Body_Weight_Squats', 'Bowling', 'Boxing_Punching_Bag', 'Boxing_Speed_Bag', 'Breast_Stroke', 'Brushing_Teeth', 'Clean_And_Jerk', 'Cliff_Diving', 'Cricket_Bowling', 'Cricket_Shot', 'Cutting_In_Kitchen', 'Diving', 'Drumming', 'Fencing', 'Field_Hockey_Penalty', 'Floor_Gymnastics', 'Frisbee_Catch', 'Front_Crawl', 'Golf_Swing', 'Haircut', 'Hammering', 'Hammer_Throw', 'Handstand_Pushups', 'Handstand_Walking', 'Head_Massage', 'High_Jump', 'Horse_Race', 'Horse_Riding', 'Hula_Hoop', 'Ice_Dancing', 'Javelin_Throw', 'Juggling_Balls', 'Jumping_Jack', 'Jump_Rope', 'Kayaking', 'Knitting', 'Long_Jump', 'Lunges', 'Military_Parade', 'Mixing', 'Mopping_Floor', 'Nunchucks', 'Parallel_Bars', 'Pizza_Tossing', 'Playing_Cello', 'Playing_Daf', 'Playing_Dhol', 'Playing_Flute', 'Playing_Guitar', 'Playing_Piano', 'Playing_Sitar', 'Playing_Tabla', 'Playing_Violin', 'Pole_Vault', 'Pommel_Horse', 'Pull_Ups', 'Punch', 'Push_Ups', 'Rafting', 'Rock_Climbing_Indoor', 'Rope_Climbing', 'Rowing', 'Salsa_Spin', 'Shaving_Beard', 'Shotput', 'Skate_Boarding', 'Skiing', 'Skijet', 'Sky_Diving', 'Soccer_Juggling', 'Soccer_Penalty', 'Still_Rings', 'Sumo_Wrestling', 'Surfing', 'Swing', 'Table_Tennis_Shot', 'Tai_Chi', 'Tennis_Swing', 'Throw_Discus', 'Trampoline_Jumping', 'Typing', 'Uneven_Bars', 'Volleyball_Spiking', 'Walking_With_Dog', 'Wall_Pushups', 'Writing_On_Board', 'Yo_Yo'] 22 | 23 | aircraft_classes = ['707-320', '727-200', '737-200', '737-300', '737-400', '737-500', '737-600', '737-700', '737-800', '737-900', '747-100', '747-200', '747-300', '747-400', '757-200', '757-300', '767-200', '767-300', '767-400', '777-200', '777-300', 'A300B4', 'A310', 'A318', 'A319', 'A320', 'A321', 'A330-200', 'A330-300', 'A340-200', 'A340-300', 'A340-500', 'A340-600', 'A380', 'ATR-42', 'ATR-72', 'An-12', 'BAE 146-200', 'BAE 146-300', 'BAE-125', 'Beechcraft 1900', 'Boeing 717', 'C-130', 
'C-47', 'CRJ-200', 'CRJ-700', 'CRJ-900', 'Cessna 172', 'Cessna 208', 'Cessna 525', 'Cessna 560', 'Challenger 600', 'DC-10', 'DC-3', 'DC-6', 'DC-8', 'DC-9-30', 'DH-82', 'DHC-1', 'DHC-6', 'DHC-8-100', 'DHC-8-300', 'DR-400', 'Dornier 328', 'E-170', 'E-190', 'E-195', 'EMB-120', 'ERJ 135', 'ERJ 145', 'Embraer Legacy 600', 'Eurofighter Typhoon', 'F-16A/B', 'F/A-18', 'Falcon 2000', 'Falcon 900', 'Fokker 100', 'Fokker 50', 'Fokker 70', 'Global Express', 'Gulfstream IV', 'Gulfstream V', 'Hawk T1', 'Il-76', 'L-1011', 'MD-11', 'MD-80', 'MD-87', 'MD-90', 'Metroliner', 'Model B200', 'PA-28', 'SR-20', 'Saab 2000', 'Saab 340', 'Spitfire', 'Tornado', 'Tu-134', 'Tu-154', 'Yak-42'] 24 | 25 | eurosat_classes = ['Annual Crop Land', 'Forest', 'Herbaceous Vegetation Land', 'Highway or Road', 'Industrial Buildings', 'Pasture Land', 'Permanent Crop Land', 'Residential Buildings', 'River', 'Sea or Lake'] 26 | -------------------------------------------------------------------------------- /tpt_classification.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import time 4 | 5 | from copy import deepcopy 6 | 7 | from PIL import Image 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.optim 14 | import torch.utils.data 15 | import torch.utils.data.distributed 16 | import torchvision.transforms as transforms 17 | 18 | 19 | try: 20 | from torchvision.transforms import InterpolationMode 21 | BICUBIC = InterpolationMode.BICUBIC 22 | except ImportError: 23 | BICUBIC = Image.BICUBIC 24 | import torchvision.models as models 25 | 26 | from clip.custom_clip import get_coop 27 | from clip.cocoop import get_cocoop 28 | from data.imagnet_prompts import imagenet_classes 29 | from data.datautils import AugMixAugmenter, build_dataset 30 | from utils.tools import Summary, AverageMeter, ProgressMeter, accuracy, load_model_weight, set_random_seed 31 | from data.cls_to_names import * 32 | from data.fewshot_datasets import fewshot_datasets 33 | from data.imagenet_variants import thousand_k_to_200, imagenet_a_mask, imagenet_r_mask, imagenet_v_mask 34 | 35 | import ipdb 36 | import math 37 | import pickle 38 | 39 | model_names = sorted(name for name in models.__dict__ 40 | if name.islower() and not name.startswith("__") 41 | and callable(models.__dict__[name])) 42 | 43 | def ECE_Loss(num_bins, predictions, confidences, correct): 44 | #ipdb.set_trace() 45 | bin_boundaries = torch.linspace(0, 1, num_bins + 1) 46 | bin_lowers = bin_boundaries[:-1] 47 | bin_uppers = bin_boundaries[1:] 48 | bin_accuracy = [0]*num_bins 49 | bin_confidence = [0]*num_bins 50 | bin_num_sample = [0]*num_bins 51 | 52 | for idx in range(len(predictions)): 53 | #prediction = predictions[idx] 54 | confidence = confidences[idx] 55 | bin_idx = -1 56 | for bin_lower, bin_upper in zip(bin_lowers, bin_uppers): 57 | bin_idx += 1 58 | bin_lower = bin_lower.item() 59 | bin_upper = bin_upper.item() 60 | #if bin_lower <= confidence and confidence < bin_upper: 61 | if bin_lower < confidence and confidence <= bin_upper: 62 | bin_num_sample[bin_idx] += 1 63 | bin_accuracy[bin_idx] += correct[idx] 64 | bin_confidence[bin_idx] += confidences[idx] 65 | 66 | for idx in range(num_bins): 67 | if bin_num_sample[idx] != 0: 68 | bin_accuracy[idx] = bin_accuracy[idx]/bin_num_sample[idx] 69 | bin_confidence[idx] = bin_confidence[idx]/bin_num_sample[idx] 70 | 71 | ece_loss = 0.0 72 | for idx in range(num_bins): 73 | temp_abs = 
abs(bin_accuracy[idx]-bin_confidence[idx]) 74 | ece_loss += (temp_abs*bin_num_sample[idx])/len(predictions) 75 | 76 | return ece_loss, bin_accuracy, bin_confidence, bin_num_sample 77 | 78 | def Calculator(result_dict): 79 | 80 | list_max_confidence = result_dict['max_confidence'] 81 | list_prediction = result_dict['prediction'] 82 | list_label = result_dict['label'] 83 | 84 | torch_list_prediction = torch.tensor(list_prediction).int() 85 | torch_list_label = torch.tensor(list_label).int() 86 | 87 | torch_correct = (torch_list_prediction == torch_list_label) 88 | list_correct = torch_correct.tolist() 89 | 90 | 91 | ece_data = ECE_Loss(20, list_prediction, list_max_confidence, list_correct) 92 | acc = sum(list_correct)/len(list_correct) 93 | 94 | print('acc: ', acc*100) 95 | print('ece: ', ece_data[0]*100) 96 | 97 | return acc, ece_data[0] 98 | 99 | 100 | def select_confident_samples(logits, top): 101 | batch_entropy = -(logits.softmax(1) * logits.log_softmax(1)).sum(1) 102 | idx = torch.argsort(batch_entropy, descending=False)[:int(batch_entropy.size()[0] * top)] 103 | return logits[idx], idx 104 | 105 | def avg_entropy(outputs): 106 | logits = outputs - outputs.logsumexp(dim=-1, keepdim=True) # logits = outputs.log_softmax(dim=1) [N, 1000] 107 | avg_logits = logits.logsumexp(dim=0) - np.log(logits.shape[0]) # avg_logits = logits.mean(0) [1, 1000] 108 | min_real = torch.finfo(avg_logits.dtype).min 109 | avg_logits = torch.clamp(avg_logits, min=min_real) 110 | return -(avg_logits * torch.exp(avg_logits)).sum(dim=-1) 111 | 112 | 113 | def test_time_tuning(model, inputs, optimizer, scaler, args): 114 | output = None 115 | output2 = None 116 | single_output = None 117 | if args.cocoop: 118 | image_feature, pgen_ctx = inputs 119 | pgen_ctx.requires_grad = True 120 | optimizer = torch.optim.AdamW([pgen_ctx], args.lr) 121 | 122 | selected_idx = None 123 | for j in range(args.tta_steps): 124 | if 'tpt' in args.run_type: 125 | with torch.cuda.amp.autocast(): 126 | if args.cocoop: 127 | output = model((image_feature, pgen_ctx)) 128 | else: 129 | output = model(inputs) 130 | 131 | if selected_idx is not None: 132 | output = output[selected_idx] 133 | else: 134 | output, selected_idx = select_confident_samples(output, args.selection_p) 135 | 136 | loss = avg_entropy(output) 137 | else: 138 | loss = 0 139 | 140 | if args.two_step and 'tpt' in args.run_type: 141 | optimizer.zero_grad() 142 | # compute gradient and do SGD step 143 | scaler.scale(loss).backward(retain_graph=True) 144 | # Unscales the gradients of optimizer's assigned params in-place 145 | scaler.step(optimizer) 146 | scaler.update() 147 | loss = 0 148 | 149 | with torch.cuda.amp.autocast(): 150 | if args.cocoop: 151 | output2 = model((image_feature, pgen_ctx)) 152 | else: 153 | output2 = model(inputs) 154 | 155 | if 'ctpt' in args.run_type: 156 | if output is None and output2 is None: 157 | single_output = model(args.image) 158 | 159 | lambda_ = args.lambda_term 160 | loss += (-lambda_* model.l2_norm_mean_training) 161 | 162 | if args.run_type not in ['baseline', 'baseline_cocoop', 'baseline_coop', 'baseline_ts']: 163 | optimizer.zero_grad() 164 | # compute gradient and do SGD step 165 | scaler.scale(loss).backward() 166 | # Unscales the gradients of optimizer's assigned params in-place 167 | scaler.step(optimizer) 168 | scaler.update() 169 | 170 | if args.cocoop: 171 | return pgen_ctx 172 | 173 | return 174 | 175 | 176 | def main(args, result_dict): 177 | 178 | set_random_seed(args.seed) 179 | 180 | # This codebase has only been tested under the single
GPU setting 181 | assert args.gpu is not None 182 | main_worker(args.gpu, args, result_dict) 183 | 184 | 185 | def main_worker(gpu, args, result_dict): 186 | args.gpu = gpu 187 | set_random_seed(args.seed) 188 | print("Use GPU: {} for training".format(args.gpu)) 189 | 190 | # create model (zero-shot CLIP model with prompt tuning) 191 | if args.test_sets in fewshot_datasets: 192 | classnames = eval("{}_classes".format(args.test_sets.lower())) 193 | else: 194 | classnames = imagenet_classes 195 | if args.cocoop: 196 | model = get_cocoop(args.arch, args.test_sets, 'cpu', args.n_ctx) 197 | assert args.load is not None 198 | load_model_weight(args.load, model, 'cpu', args) # to load to cuda: device="cuda:{}".format(args.gpu) 199 | model_state = deepcopy(model.state_dict()) 200 | else: 201 | model = get_coop(args.arch, args.test_sets, args.gpu, args.n_ctx, args.ctx_init) 202 | if args.load is not None: 203 | print("Use pre-trained soft prompt (CoOp) as initialization") 204 | pretrained_ctx = torch.load(args.load)['state_dict']['ctx'] 205 | assert pretrained_ctx.size()[0] == args.n_ctx 206 | with torch.no_grad(): 207 | #model.prompt_learner[0].ctx.copy_(pretrained_ctx) 208 | #model.prompt_learner[0].ctx_init_state = pretrained_ctx 209 | 210 | model.prompt_learner.ctx.copy_(pretrained_ctx) 211 | model.prompt_learner.ctx_init_state = pretrained_ctx 212 | 213 | model_state = None 214 | 215 | for name, param in model.named_parameters(): 216 | if not args.cocoop: 217 | if "prompt_learner" not in name: 218 | param.requires_grad_(False) 219 | else: 220 | if "text_encoder" not in name: 221 | param.requires_grad_(False) 222 | 223 | print("=> Model created: visual backbone {}".format(args.arch)) 224 | 225 | if not torch.cuda.is_available(): 226 | print('using CPU, this will be slow') 227 | else: 228 | assert args.gpu is not None 229 | torch.cuda.set_device(args.gpu) 230 | model = model.cuda(args.gpu) 231 | 232 | # define optimizer 233 | if args.cocoop: 234 | optimizer = None 235 | optim_state = None 236 | else: 237 | trainable_param = model.prompt_learner.parameters() 238 | optimizer = torch.optim.AdamW(trainable_param, args.lr) 239 | optim_state = deepcopy(optimizer.state_dict()) 240 | 241 | # setup automatic mixed-precision (Amp) loss scaling 242 | scaler = torch.cuda.amp.GradScaler(init_scale=1000) 243 | 244 | print('=> Using native Torch AMP.
Training in mixed precision.') 245 | 246 | cudnn.benchmark = True 247 | 248 | # norm stats from clip.load() 249 | normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], 250 | std=[0.26862954, 0.26130258, 0.27577711]) 251 | 252 | 253 | # iterating through eval datasets 254 | datasets = args.test_sets.split("/") 255 | assert len(datasets) == 1 256 | results = {} 257 | for set_id in datasets: 258 | if args.tpt: 259 | base_transform = transforms.Compose([ 260 | transforms.Resize(args.resolution, interpolation=BICUBIC), 261 | transforms.CenterCrop(args.resolution)]) 262 | preprocess = transforms.Compose([ 263 | transforms.ToTensor(), 264 | normalize]) 265 | 266 | if args.I_augmix: 267 | data_transform = AugMixAugmenter(base_transform, preprocess, n_views=args.batch_size-1, 268 | augmix=len(set_id)>=1) 269 | else: 270 | data_transform = AugMixAugmenter(base_transform, preprocess, n_views=args.batch_size-1, 271 | augmix=len(set_id)>1) 272 | batchsize = 1 273 | else: 274 | data_transform = transforms.Compose([ 275 | transforms.Resize(args.resolution, interpolation=BICUBIC), 276 | transforms.CenterCrop(args.resolution), 277 | transforms.ToTensor(), 278 | normalize, 279 | ]) 280 | batchsize = args.batch_size 281 | 282 | print("evaluating: {}".format(set_id)) 283 | # reset the model 284 | # Reset classnames of custom CLIP model 285 | if len(set_id) > 1: 286 | # fine-grained classification datasets 287 | classnames = eval("{}_classes".format(set_id.lower())) 288 | else: 289 | assert set_id in ['A', 'R', 'K', 'V', 'I'] 290 | classnames_all = imagenet_classes 291 | classnames = [] 292 | if set_id in ['A', 'R', 'V']: 293 | label_mask = eval("imagenet_{}_mask".format(set_id.lower())) 294 | if set_id == 'R': 295 | for i, m in enumerate(label_mask): 296 | if m: 297 | classnames.append(classnames_all[i]) 298 | else: 299 | classnames = [classnames_all[i] for i in label_mask] 300 | 301 | else: 302 | classnames = classnames_all 303 | if args.cocoop: 304 | model.prompt_generator.reset_classnames(classnames, args.arch) 305 | model = model.cpu() 306 | model_state = model.state_dict() 307 | model = model.cuda(args.gpu) 308 | else: 309 | model.reset_classnames(classnames, args.arch) 310 | 311 | val_dataset = build_dataset(set_id, data_transform, args.data, mode=args.dataset_mode) 312 | print("number of test samples: {}".format(len(val_dataset))) 313 | val_loader = torch.utils.data.DataLoader( 314 | val_dataset, 315 | batch_size=batchsize, shuffle=True, 316 | num_workers=args.workers, pin_memory=True) 317 | 318 | results[set_id] = test_time_adapt_eval(val_loader, model, model_state, optimizer, optim_state, scaler, args, result_dict) 319 | del val_dataset, val_loader 320 | try: 321 | print("=> Acc. on testset [{}]: @1 {}/ @5 {}".format(set_id, results[set_id][0], results[set_id][1])) 322 | except: 323 | print("=> Acc. on testset [{}]: {}".format(set_id, results[set_id])) 324 | 325 | print("======== Result Summary ========") 326 | print("params: nstep lr bs") 327 | print("params: {} {} {}".format(args.tta_steps, args.lr, args.batch_size)) 328 | print("\t\t [set_id] \t\t Top-1 acc. 
\t\t Top-5 acc.") 329 | for id in results.keys(): 330 | print("{}".format(id), end=" ") 331 | print("\n") 332 | for id in results.keys(): 333 | print("{:.2f}".format(results[id][0]), end=" ") 334 | print("\n") 335 | 336 | 337 | def test_time_adapt_eval(val_loader, model, model_state, optimizer, optim_state, scaler, args, result_dict): 338 | batch_time = AverageMeter('Time', ':6.3f', Summary.NONE) 339 | top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE) 340 | top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE) 341 | 342 | progress = ProgressMeter( 343 | len(val_loader), 344 | [batch_time, top1, top5], 345 | prefix='Test: ') 346 | 347 | # reset model and switch to evaluate mode 348 | model.eval() 349 | if not args.cocoop: # no need to reset cocoop because it's fixed 350 | with torch.no_grad(): 351 | model.reset() 352 | end = time.time() 353 | 354 | #define a softmax layer 355 | softmax = torch.nn.Softmax(dim=1) 356 | 357 | if 'ctpt' in args.run_type: 358 | model.l2_norm_cal = True 359 | else: 360 | model.l2_norm_cal = False 361 | 362 | for i, (images, target) in enumerate(val_loader): 363 | assert args.gpu is not None 364 | if isinstance(images, list): 365 | for k in range(len(images)): 366 | images[k] = images[k].cuda(args.gpu, non_blocking=True) 367 | image = images[0] 368 | else: 369 | if len(images.size()) > 4: 370 | # when using ImageNet Sampler as the dataset 371 | assert images.size()[0] == 1 372 | images = images.squeeze(0) 373 | images = images.cuda(args.gpu, non_blocking=True) 374 | image = images 375 | target = target.cuda(args.gpu, non_blocking=True) 376 | if args.tpt: 377 | images = torch.cat(images, dim=0) 378 | 379 | if 'ctpt' in args.run_type: 380 | args.image = image 381 | 382 | # reset the tunable prompt to its initial state 383 | if not args.cocoop: # no need to reset cocoop because it's fixed 384 | if args.tta_steps > 0: 385 | with torch.no_grad(): 386 | model.reset() 387 | optimizer.load_state_dict(optim_state) 388 | 389 | test_time_tuning(model, images, optimizer, scaler, args) 390 | else: 391 | with torch.no_grad(): 392 | with torch.cuda.amp.autocast(): 393 | image_feature, pgen_ctx = model.gen_ctx(images, args.tpt) 394 | optimizer = None 395 | 396 | pgen_ctx = test_time_tuning(model, (image_feature, pgen_ctx), optimizer, scaler, args) 397 | 398 | # The actual inference goes here 399 | if args.tpt: 400 | if args.cocoop: 401 | image_feature = image_feature[0].unsqueeze(0) 402 | 403 | with torch.no_grad(): 404 | with torch.cuda.amp.autocast(): 405 | if args.cocoop: 406 | output = model((image_feature, pgen_ctx)) 407 | else: 408 | output = model(image) 409 | 410 | 411 | if 'ts' not in args.run_type: 412 | softmax_output = softmax(output) 413 | elif 'ts' in args.run_type: 414 | if 'ViT' in args.arch: 415 | softmax_output = softmax(output/temperature_value['ViT']) 416 | elif 'RN' in args.arch: 417 | softmax_output = softmax(output/temperature_value['RN']) 418 | else: 419 | ipdb.set_trace() 420 | 421 | #maximum confidence of the softmax_output and its index 422 | max_confidence, max_index = torch.max(softmax_output, 1) 423 | 424 | #save the max confidence, prediction, and label to the result_dict 425 | result_dict['max_confidence'].append(max_confidence.item()) 426 | result_dict['prediction'].append(max_index.item()) 427 | result_dict['label'].append(target.item()) 428 | 429 | # measure accuracy and record loss 430 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 431 | 432 | top1.update(acc1[0], image.size(0)) 433 | top5.update(acc5[0], image.size(0)) 434 | 435 | 
# measure elapsed time 436 | batch_time.update(time.time() - end) 437 | end = time.time() 438 | 439 | if (i+1) % args.print_freq == 0: 440 | progress.display(i) 441 | 442 | progress.display_summary() 443 | 444 | return [top1.avg, top5.avg] 445 | 446 | 447 | temperature_value = {'ViT': 1.16, 'RN': 1.15} #for temperature scaling experiments 448 | 449 | if __name__ == '__main__': 450 | parser = argparse.ArgumentParser(description='Test-time Prompt Tuning') 451 | parser.add_argument('data', metavar='DIR', help='path to dataset root') 452 | parser.add_argument('--test_sets', type=str, default='A/R/V/K/I', help='test dataset (one dataset per run)') 453 | parser.add_argument('--dataset_mode', type=str, default='test', help='which split to use: train/val/test') 454 | parser.add_argument('-a', '--arch', metavar='ARCH', default='RN50') 455 | parser.add_argument('--resolution', default=224, type=int, help='CLIP image resolution') 456 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 457 | help='number of data loading workers (default: 4)') 458 | parser.add_argument('-b', '--batch-size', default=64, type=int, metavar='N') 459 | parser.add_argument('--lr', '--learning-rate', default=5e-3, type=float, 460 | metavar='LR', help='initial learning rate', dest='lr') 461 | parser.add_argument('-p', '--print-freq', default=200, type=int, 462 | metavar='N', help='print frequency (default: 200)') 463 | parser.add_argument('--gpu', default=0, type=int, 464 | help='GPU id to use.') 465 | parser.add_argument('--tpt', action='store_true', default=False, help='run test-time prompt tuning') 466 | parser.add_argument('--selection_p', default=0.1, type=float, help='confidence selection percentile') 467 | parser.add_argument('--tta_steps', default=1, type=int, help='test-time-adapt steps') 468 | parser.add_argument('--n_ctx', default=4, type=int, help='number of tunable tokens') 469 | parser.add_argument('--ctx_init', default=None, type=str, help='init tunable prompts') 470 | parser.add_argument('--cocoop', action='store_true', default=False, help="use cocoop's output as prompt initialization") 471 | parser.add_argument('--load', default=None, type=str, help='path to a pre-trained coop/cocoop') 472 | parser.add_argument('--seed', type=int, default=0) 473 | 474 | # added args for c-tpt -------------------------------- 475 | parser.add_argument('--lambda_term', type=float, default=0.0, help='lambda for c-tpt') 476 | parser.add_argument('--run_type', type=str, default='tpt', choices=['baseline', 'tpt', 'tpt_ctpt', 'tpt_ts']) 477 | parser.add_argument('--two_step', action='store_true', default=False, help='two step training') 478 | parser.add_argument('--I_augmix', action='store_true', default=False, help='augmix for I') 479 | # ------------------------------------------------ 480 | 481 | args = parser.parse_args() 482 | 483 | if 'ctpt' not in args.run_type: 484 | args.lambda_term = 0.0 485 | 486 | result_dict = {'max_confidence': [], 'prediction': [], 'label': []} 487 | main(args, result_dict) 488 | acc, ece = Calculator(result_dict) 489 | 490 | 491 | 492 | 493 | 494 | 495 | -------------------------------------------------------------------------------- /data/imagenet_variants.py: -------------------------------------------------------------------------------- 1 | thousand_k_to_200 = {0: -1, 1: -1, 2: -1, 3: -1, 4: -1, 5: -1, 6: 0, 7: -1, 8: -1, 9: -1, 10: -1, 11: 1, 12: -1, 13: 2, 14: -1, 15: 3, 16: -1, 17: 4, 18: -1, 19: -1, 20: -1, 21: -1, 22: 5, 23: 6, 24: -1, 25: -1, 26:
-1, 27: 7, 28: -1, 29: -1, 30: 8, 31: -1, 32: -1, 33: -1, 34: -1, 35: -1, 36: -1, 37: 9, 38: -1, 39: 10, 40: -1, 41: -1, 42: 11, 43: -1, 44: -1, 45: -1, 46: -1, 47: 12, 48: -1, 49: -1, 50: 13, 51: -1, 52: -1, 53: -1, 54: -1, 55: -1, 56: -1, 57: 14, 58: -1, 59: -1, 60: -1, 61: -1, 62: -1, 63: -1, 64: -1, 65: -1, 66: -1, 67: -1, 68: -1, 69: -1, 70: 15, 71: 16, 72: -1, 73: -1, 74: -1, 75: -1, 76: 17, 77: -1, 78: -1, 79: 18, 80: -1, 81: -1, 82: -1, 83: -1, 84: -1, 85: -1, 86: -1, 87: -1, 88: -1, 89: 19, 90: 20, 91: -1, 92: -1, 93: -1, 94: 21, 95: -1, 96: 22, 97: 23, 98: -1, 99: 24, 100: -1, 101: -1, 102: -1, 103: -1, 104: -1, 105: 25, 106: -1, 107: 26, 108: 27, 109: -1, 110: 28, 111: -1, 112: -1, 113: 29, 114: -1, 115: -1, 116: -1, 117: -1, 118: -1, 119: -1, 120: -1, 121: -1, 122: -1, 123: -1, 124: 30, 125: 31, 126: -1, 127: -1, 128: -1, 129: -1, 130: 32, 131: -1, 132: 33, 133: -1, 134: -1, 135: -1, 136: -1, 137: -1, 138: -1, 139: -1, 140: -1, 141: -1, 142: -1, 143: 34, 144: 35, 145: -1, 146: -1, 147: -1, 148: -1, 149: -1, 150: 36, 151: 37, 152: -1, 153: -1, 154: -1, 155: -1, 156: -1, 157: -1, 158: -1, 159: -1, 160: -1, 161: -1, 162: -1, 163: -1, 164: -1, 165: -1, 166: -1, 167: -1, 168: -1, 169: -1, 170: -1, 171: -1, 172: -1, 173: -1, 174: -1, 175: -1, 176: -1, 177: -1, 178: -1, 179: -1, 180: -1, 181: -1, 182: -1, 183: -1, 184: -1, 185: -1, 186: -1, 187: -1, 188: -1, 189: -1, 190: -1, 191: -1, 192: -1, 193: -1, 194: -1, 195: -1, 196: -1, 197: -1, 198: -1, 199: -1, 200: -1, 201: -1, 202: -1, 203: -1, 204: -1, 205: -1, 206: -1, 207: 38, 208: -1, 209: -1, 210: -1, 211: -1, 212: -1, 213: -1, 214: -1, 215: -1, 216: -1, 217: -1, 218: -1, 219: -1, 220: -1, 221: -1, 222: -1, 223: -1, 224: -1, 225: -1, 226: -1, 227: -1, 228: -1, 229: -1, 230: -1, 231: -1, 232: -1, 233: -1, 234: 39, 235: 40, 236: -1, 237: -1, 238: -1, 239: -1, 240: -1, 241: -1, 242: -1, 243: -1, 244: -1, 245: -1, 246: -1, 247: -1, 248: -1, 249: -1, 250: -1, 251: -1, 252: -1, 253: -1, 254: 41, 255: -1, 256: -1, 257: -1, 258: -1, 259: -1, 260: -1, 261: -1, 262: -1, 263: -1, 264: -1, 265: -1, 266: -1, 267: -1, 268: -1, 269: -1, 270: -1, 271: -1, 272: -1, 273: -1, 274: -1, 275: -1, 276: -1, 277: 42, 278: -1, 279: -1, 280: -1, 281: -1, 282: -1, 283: 43, 284: -1, 285: -1, 286: -1, 287: 44, 288: -1, 289: -1, 290: -1, 291: 45, 292: -1, 293: -1, 294: -1, 295: 46, 296: -1, 297: -1, 298: 47, 299: -1, 300: -1, 301: 48, 302: -1, 303: -1, 304: -1, 305: -1, 306: 49, 307: 50, 308: 51, 309: 52, 310: 53, 311: 54, 312: -1, 313: 55, 314: 56, 315: 57, 316: -1, 317: 58, 318: -1, 319: 59, 320: -1, 321: -1, 322: -1, 323: 60, 324: 61, 325: -1, 326: 62, 327: 63, 328: -1, 329: -1, 330: 64, 331: -1, 332: -1, 333: -1, 334: 65, 335: 66, 336: 67, 337: -1, 338: -1, 339: -1, 340: -1, 341: -1, 342: -1, 343: -1, 344: -1, 345: -1, 346: -1, 347: 68, 348: -1, 349: -1, 350: -1, 351: -1, 352: -1, 353: -1, 354: -1, 355: -1, 356: -1, 357: -1, 358: -1, 359: -1, 360: -1, 361: 69, 362: -1, 363: 70, 364: -1, 365: -1, 366: -1, 367: -1, 368: -1, 369: -1, 370: -1, 371: -1, 372: 71, 373: -1, 374: -1, 375: -1, 376: -1, 377: -1, 378: 72, 379: -1, 380: -1, 381: -1, 382: -1, 383: -1, 384: -1, 385: -1, 386: 73, 387: -1, 388: -1, 389: -1, 390: -1, 391: -1, 392: -1, 393: -1, 394: -1, 395: -1, 396: -1, 397: 74, 398: -1, 399: -1, 400: 75, 401: 76, 402: 77, 403: -1, 404: 78, 405: -1, 406: -1, 407: 79, 408: -1, 409: -1, 410: -1, 411: 80, 412: -1, 413: -1, 414: -1, 415: -1, 416: 81, 417: 82, 418: -1, 419: -1, 420: 83, 421: -1, 422: -1, 423: -1, 424: -1, 425: 84, 426: -1, 427: -1, 428: 85, 429: -1, 
430: 86, 431: -1, 432: -1, 433: -1, 434: -1, 435: -1, 436: -1, 437: 87, 438: 88, 439: -1, 440: -1, 441: -1, 442: -1, 443: -1, 444: -1, 445: 89, 446: -1, 447: -1, 448: -1, 449: -1, 450: -1, 451: -1, 452: -1, 453: -1, 454: -1, 455: -1, 456: 90, 457: 91, 458: -1, 459: -1, 460: -1, 461: 92, 462: 93, 463: -1, 464: -1, 465: -1, 466: -1, 467: -1, 468: -1, 469: -1, 470: 94, 471: -1, 472: 95, 473: -1, 474: -1, 475: -1, 476: -1, 477: -1, 478: -1, 479: -1, 480: -1, 481: -1, 482: -1, 483: 96, 484: -1, 485: -1, 486: 97, 487: -1, 488: 98, 489: -1, 490: -1, 491: -1, 492: 99, 493: -1, 494: -1, 495: -1, 496: 100, 497: -1, 498: -1, 499: -1, 500: -1, 501: -1, 502: -1, 503: -1, 504: -1, 505: -1, 506: -1, 507: -1, 508: -1, 509: -1, 510: -1, 511: -1, 512: -1, 513: -1, 514: 101, 515: -1, 516: 102, 517: -1, 518: -1, 519: -1, 520: -1, 521: -1, 522: -1, 523: -1, 524: -1, 525: -1, 526: -1, 527: -1, 528: 103, 529: -1, 530: 104, 531: -1, 532: -1, 533: -1, 534: -1, 535: -1, 536: -1, 537: -1, 538: -1, 539: 105, 540: -1, 541: -1, 542: 106, 543: 107, 544: -1, 545: -1, 546: -1, 547: -1, 548: -1, 549: 108, 550: -1, 551: -1, 552: 109, 553: -1, 554: -1, 555: -1, 556: -1, 557: 110, 558: -1, 559: -1, 560: -1, 561: 111, 562: 112, 563: -1, 564: -1, 565: -1, 566: -1, 567: -1, 568: -1, 569: 113, 570: -1, 571: -1, 572: 114, 573: 115, 574: -1, 575: 116, 576: -1, 577: -1, 578: -1, 579: 117, 580: -1, 581: -1, 582: -1, 583: -1, 584: -1, 585: -1, 586: -1, 587: -1, 588: -1, 589: 118, 590: -1, 591: -1, 592: -1, 593: -1, 594: -1, 595: -1, 596: -1, 597: -1, 598: -1, 599: -1, 600: -1, 601: -1, 602: -1, 603: -1, 604: -1, 605: -1, 606: 119, 607: 120, 608: -1, 609: 121, 610: -1, 611: -1, 612: -1, 613: -1, 614: 122, 615: -1, 616: -1, 617: -1, 618: -1, 619: -1, 620: -1, 621: -1, 622: -1, 623: -1, 624: -1, 625: -1, 626: 123, 627: 124, 628: -1, 629: -1, 630: -1, 631: -1, 632: -1, 633: -1, 634: -1, 635: -1, 636: -1, 637: -1, 638: -1, 639: -1, 640: 125, 641: 126, 642: 127, 643: 128, 644: -1, 645: -1, 646: -1, 647: -1, 648: -1, 649: -1, 650: -1, 651: -1, 652: -1, 653: -1, 654: -1, 655: -1, 656: -1, 657: -1, 658: 129, 659: -1, 660: -1, 661: -1, 662: -1, 663: -1, 664: -1, 665: -1, 666: -1, 667: -1, 668: 130, 669: -1, 670: -1, 671: -1, 672: -1, 673: -1, 674: -1, 675: -1, 676: -1, 677: 131, 678: -1, 679: -1, 680: -1, 681: -1, 682: 132, 683: -1, 684: 133, 685: -1, 686: -1, 687: 134, 688: -1, 689: -1, 690: -1, 691: -1, 692: -1, 693: -1, 694: -1, 695: -1, 696: -1, 697: -1, 698: -1, 699: -1, 700: -1, 701: 135, 702: -1, 703: -1, 704: 136, 705: -1, 706: -1, 707: -1, 708: -1, 709: -1, 710: -1, 711: -1, 712: -1, 713: -1, 714: -1, 715: -1, 716: -1, 717: -1, 718: -1, 719: 137, 720: -1, 721: -1, 722: -1, 723: -1, 724: -1, 725: -1, 726: -1, 727: -1, 728: -1, 729: -1, 730: -1, 731: -1, 732: -1, 733: -1, 734: -1, 735: -1, 736: 138, 737: -1, 738: -1, 739: -1, 740: -1, 741: -1, 742: -1, 743: -1, 744: -1, 745: -1, 746: 139, 747: -1, 748: -1, 749: 140, 750: -1, 751: -1, 752: 141, 753: -1, 754: -1, 755: -1, 756: -1, 757: -1, 758: 142, 759: -1, 760: -1, 761: -1, 762: -1, 763: 143, 764: -1, 765: 144, 766: -1, 767: -1, 768: 145, 769: -1, 770: -1, 771: -1, 772: -1, 773: 146, 774: 147, 775: -1, 776: 148, 777: -1, 778: -1, 779: 149, 780: 150, 781: -1, 782: -1, 783: -1, 784: -1, 785: -1, 786: 151, 787: -1, 788: -1, 789: -1, 790: -1, 791: -1, 792: 152, 793: -1, 794: -1, 795: -1, 796: -1, 797: 153, 798: -1, 799: -1, 800: -1, 801: -1, 802: 154, 803: 155, 804: 156, 805: -1, 806: -1, 807: -1, 808: -1, 809: -1, 810: -1, 811: -1, 812: -1, 813: 157, 814: -1, 815: 158, 816: -1, 817: -1, 
818: -1, 819: -1, 820: 159, 821: -1, 822: -1, 823: 160, 824: -1, 825: -1, 826: -1, 827: -1, 828: -1, 829: -1, 830: -1, 831: 161, 832: -1, 833: 162, 834: -1, 835: 163, 836: -1, 837: -1, 838: -1, 839: 164, 840: -1, 841: -1, 842: -1, 843: -1, 844: -1, 845: 165, 846: -1, 847: 166, 848: -1, 849: -1, 850: 167, 851: -1, 852: -1, 853: -1, 854: -1, 855: -1, 856: -1, 857: -1, 858: -1, 859: 168, 860: -1, 861: -1, 862: 169, 863: -1, 864: -1, 865: -1, 866: -1, 867: -1, 868: -1, 869: -1, 870: 170, 871: -1, 872: -1, 873: -1, 874: -1, 875: -1, 876: -1, 877: -1, 878: -1, 879: 171, 880: 172, 881: -1, 882: -1, 883: -1, 884: -1, 885: -1, 886: -1, 887: -1, 888: 173, 889: -1, 890: 174, 891: -1, 892: -1, 893: -1, 894: -1, 895: -1, 896: -1, 897: 175, 898: -1, 899: -1, 900: 176, 901: -1, 902: -1, 903: -1, 904: -1, 905: -1, 906: -1, 907: 177, 908: -1, 909: -1, 910: -1, 911: -1, 912: -1, 913: 178, 914: -1, 915: -1, 916: -1, 917: -1, 918: -1, 919: -1, 920: -1, 921: -1, 922: -1, 923: -1, 924: 179, 925: -1, 926: -1, 927: -1, 928: -1, 929: -1, 930: -1, 931: -1, 932: 180, 933: 181, 934: 182, 935: -1, 936: -1, 937: 183, 938: -1, 939: -1, 940: -1, 941: -1, 942: -1, 943: 184, 944: -1, 945: 185, 946: -1, 947: 186, 948: -1, 949: -1, 950: -1, 951: 187, 952: -1, 953: -1, 954: 188, 955: -1, 956: 189, 957: 190, 958: -1, 959: 191, 960: -1, 961: -1, 962: -1, 963: -1, 964: -1, 965: -1, 966: -1, 967: -1, 968: -1, 969: -1, 970: -1, 971: 192, 972: 193, 973: -1, 974: -1, 975: -1, 976: -1, 977: -1, 978: -1, 979: -1, 980: 194, 981: 195, 982: -1, 983: -1, 984: 196, 985: -1, 986: 197, 987: 198, 988: 199, 989: -1, 990: -1, 991: -1, 992: -1, 993: -1, 994: -1, 995: -1, 996: -1, 997: -1, 998: -1, 999: -1} 2 | # For ImageNet-A 200 categories 3 | imagenet_a_mask = [k for k in thousand_k_to_200 if thousand_k_to_200[k] != -1] 4 | 5 | 6 | # For ImageNet-R 200 categories 7 | all_wnids = ['n01440764', 'n01443537', 'n01484850', 'n01491361', 'n01494475', 'n01496331', 'n01498041', 'n01514668', 'n01514859', 'n01518878', 'n01530575', 'n01531178', 'n01532829', 'n01534433', 'n01537544', 'n01558993', 'n01560419', 'n01580077', 'n01582220', 'n01592084', 'n01601694', 'n01608432', 'n01614925', 'n01616318', 'n01622779', 'n01629819', 'n01630670', 'n01631663', 'n01632458', 'n01632777', 'n01641577', 'n01644373', 'n01644900', 'n01664065', 'n01665541', 'n01667114', 'n01667778', 'n01669191', 'n01675722', 'n01677366', 'n01682714', 'n01685808', 'n01687978', 'n01688243', 'n01689811', 'n01692333', 'n01693334', 'n01694178', 'n01695060', 'n01697457', 'n01698640', 'n01704323', 'n01728572', 'n01728920', 'n01729322', 'n01729977', 'n01734418', 'n01735189', 'n01737021', 'n01739381', 'n01740131', 'n01742172', 'n01744401', 'n01748264', 'n01749939', 'n01751748', 'n01753488', 'n01755581', 'n01756291', 'n01768244', 'n01770081', 'n01770393', 'n01773157', 'n01773549', 'n01773797', 'n01774384', 'n01774750', 'n01775062', 'n01776313', 'n01784675', 'n01795545', 'n01796340', 'n01797886', 'n01798484', 'n01806143', 'n01806567', 'n01807496', 'n01817953', 'n01818515', 'n01819313', 'n01820546', 'n01824575', 'n01828970', 'n01829413', 'n01833805', 'n01843065', 'n01843383', 'n01847000', 'n01855032', 'n01855672', 'n01860187', 'n01871265', 'n01872401', 'n01873310', 'n01877812', 'n01882714', 'n01883070', 'n01910747', 'n01914609', 'n01917289', 'n01924916', 'n01930112', 'n01943899', 'n01944390', 'n01945685', 'n01950731', 'n01955084', 'n01968897', 'n01978287', 'n01978455', 'n01980166', 'n01981276', 'n01983481', 'n01984695', 'n01985128', 'n01986214', 'n01990800', 'n02002556', 'n02002724', 'n02006656', 
'n02007558', 'n02009229', 'n02009912', 'n02011460', 'n02012849', 'n02013706', 'n02017213', 'n02018207', 'n02018795', 'n02025239', 'n02027492', 'n02028035', 'n02033041', 'n02037110', 'n02051845', 'n02056570', 'n02058221', 'n02066245', 'n02071294', 'n02074367', 'n02077923', 'n02085620', 'n02085782', 'n02085936', 'n02086079', 'n02086240', 'n02086646', 'n02086910', 'n02087046', 'n02087394', 'n02088094', 'n02088238', 'n02088364', 'n02088466', 'n02088632', 'n02089078', 'n02089867', 'n02089973', 'n02090379', 'n02090622', 'n02090721', 'n02091032', 'n02091134', 'n02091244', 'n02091467', 'n02091635', 'n02091831', 'n02092002', 'n02092339', 'n02093256', 'n02093428', 'n02093647', 'n02093754', 'n02093859', 'n02093991', 'n02094114', 'n02094258', 'n02094433', 'n02095314', 'n02095570', 'n02095889', 'n02096051', 'n02096177', 'n02096294', 'n02096437', 'n02096585', 'n02097047', 'n02097130', 'n02097209', 'n02097298', 'n02097474', 'n02097658', 'n02098105', 'n02098286', 'n02098413', 'n02099267', 'n02099429', 'n02099601', 'n02099712', 'n02099849', 'n02100236', 'n02100583', 'n02100735', 'n02100877', 'n02101006', 'n02101388', 'n02101556', 'n02102040', 'n02102177', 'n02102318', 'n02102480', 'n02102973', 'n02104029', 'n02104365', 'n02105056', 'n02105162', 'n02105251', 'n02105412', 'n02105505', 'n02105641', 'n02105855', 'n02106030', 'n02106166', 'n02106382', 'n02106550', 'n02106662', 'n02107142', 'n02107312', 'n02107574', 'n02107683', 'n02107908', 'n02108000', 'n02108089', 'n02108422', 'n02108551', 'n02108915', 'n02109047', 'n02109525', 'n02109961', 'n02110063', 'n02110185', 'n02110341', 'n02110627', 'n02110806', 'n02110958', 'n02111129', 'n02111277', 'n02111500', 'n02111889', 'n02112018', 'n02112137', 'n02112350', 'n02112706', 'n02113023', 'n02113186', 'n02113624', 'n02113712', 'n02113799', 'n02113978', 'n02114367', 'n02114548', 'n02114712', 'n02114855', 'n02115641', 'n02115913', 'n02116738', 'n02117135', 'n02119022', 'n02119789', 'n02120079', 'n02120505', 'n02123045', 'n02123159', 'n02123394', 'n02123597', 'n02124075', 'n02125311', 'n02127052', 'n02128385', 'n02128757', 'n02128925', 'n02129165', 'n02129604', 'n02130308', 'n02132136', 'n02133161', 'n02134084', 'n02134418', 'n02137549', 'n02138441', 'n02165105', 'n02165456', 'n02167151', 'n02168699', 'n02169497', 'n02172182', 'n02174001', 'n02177972', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02229544', 'n02231487', 'n02233338', 'n02236044', 'n02256656', 'n02259212', 'n02264363', 'n02268443', 'n02268853', 'n02276258', 'n02277742', 'n02279972', 'n02280649', 'n02281406', 'n02281787', 'n02317335', 'n02319095', 'n02321529', 'n02325366', 'n02326432', 'n02328150', 'n02342885', 'n02346627', 'n02356798', 'n02361337', 'n02363005', 'n02364673', 'n02389026', 'n02391049', 'n02395406', 'n02396427', 'n02397096', 'n02398521', 'n02403003', 'n02408429', 'n02410509', 'n02412080', 'n02415577', 'n02417914', 'n02422106', 'n02422699', 'n02423022', 'n02437312', 'n02437616', 'n02441942', 'n02442845', 'n02443114', 'n02443484', 'n02444819', 'n02445715', 'n02447366', 'n02454379', 'n02457408', 'n02480495', 'n02480855', 'n02481823', 'n02483362', 'n02483708', 'n02484975', 'n02486261', 'n02486410', 'n02487347', 'n02488291', 'n02488702', 'n02489166', 'n02490219', 'n02492035', 'n02492660', 'n02493509', 'n02493793', 'n02494079', 'n02497673', 'n02500267', 'n02504013', 'n02504458', 'n02509815', 'n02510455', 'n02514041', 'n02526121', 'n02536864', 'n02606052', 'n02607072', 'n02640242', 'n02641379', 'n02643566', 'n02655020', 'n02666196', 'n02667093', 'n02669723', 'n02672831', 'n02676566', 
'n02687172', 'n02690373', 'n02692877', 'n02699494', 'n02701002', 'n02704792', 'n02708093', 'n02727426', 'n02730930', 'n02747177', 'n02749479', 'n02769748', 'n02776631', 'n02777292', 'n02782093', 'n02783161', 'n02786058', 'n02787622', 'n02788148', 'n02790996', 'n02791124', 'n02791270', 'n02793495', 'n02794156', 'n02795169', 'n02797295', 'n02799071', 'n02802426', 'n02804414', 'n02804610', 'n02807133', 'n02808304', 'n02808440', 'n02814533', 'n02814860', 'n02815834', 'n02817516', 'n02823428', 'n02823750', 'n02825657', 'n02834397', 'n02835271', 'n02837789', 'n02840245', 'n02841315', 'n02843684', 'n02859443', 'n02860847', 'n02865351', 'n02869837', 'n02870880', 'n02871525', 'n02877765', 'n02879718', 'n02883205', 'n02892201', 'n02892767', 'n02894605', 'n02895154', 'n02906734', 'n02909870', 'n02910353', 'n02916936', 'n02917067', 'n02927161', 'n02930766', 'n02939185', 'n02948072', 'n02950826', 'n02951358', 'n02951585', 'n02963159', 'n02965783', 'n02966193', 'n02966687', 'n02971356', 'n02974003', 'n02977058', 'n02978881', 'n02979186', 'n02980441', 'n02981792', 'n02988304', 'n02992211', 'n02992529', 'n02999410', 'n03000134', 'n03000247', 'n03000684', 'n03014705', 'n03016953', 'n03017168', 'n03018349', 'n03026506', 'n03028079', 'n03032252', 'n03041632', 'n03042490', 'n03045698', 'n03047690', 'n03062245', 'n03063599', 'n03063689', 'n03065424', 'n03075370', 'n03085013', 'n03089624', 'n03095699', 'n03100240', 'n03109150', 'n03110669', 'n03124043', 'n03124170', 'n03125729', 'n03126707', 'n03127747', 'n03127925', 'n03131574', 'n03133878', 'n03134739', 'n03141823', 'n03146219', 'n03160309', 'n03179701', 'n03180011', 'n03187595', 'n03188531', 'n03196217', 'n03197337', 'n03201208', 'n03207743', 'n03207941', 'n03208938', 'n03216828', 'n03218198', 'n03220513', 'n03223299', 'n03240683', 'n03249569', 'n03250847', 'n03255030', 'n03259280', 'n03271574', 'n03272010', 'n03272562', 'n03290653', 'n03291819', 'n03297495', 'n03314780', 'n03325584', 'n03337140', 'n03344393', 'n03345487', 'n03347037', 'n03355925', 'n03372029', 'n03376595', 'n03379051', 'n03384352', 'n03388043', 'n03388183', 'n03388549', 'n03393912', 'n03394916', 'n03400231', 'n03404251', 'n03417042', 'n03424325', 'n03425413', 'n03443371', 'n03444034', 'n03445777', 'n03445924', 'n03447447', 'n03447721', 'n03450230', 'n03452741', 'n03457902', 'n03459775', 'n03461385', 'n03467068', 'n03476684', 'n03476991', 'n03478589', 'n03481172', 'n03482405', 'n03483316', 'n03485407', 'n03485794', 'n03492542', 'n03494278', 'n03495258', 'n03496892', 'n03498962', 'n03527444', 'n03529860', 'n03530642', 'n03532672', 'n03534580', 'n03535780', 'n03538406', 'n03544143', 'n03584254', 'n03584829', 'n03590841', 'n03594734', 'n03594945', 'n03595614', 'n03598930', 'n03599486', 'n03602883', 'n03617480', 'n03623198', 'n03627232', 'n03630383', 'n03633091', 'n03637318', 'n03642806', 'n03649909', 'n03657121', 'n03658185', 'n03661043', 'n03662601', 'n03666591', 'n03670208', 'n03673027', 'n03676483', 'n03680355', 'n03690938', 'n03691459', 'n03692522', 'n03697007', 'n03706229', 'n03709823', 'n03710193', 'n03710637', 'n03710721', 'n03717622', 'n03720891', 'n03721384', 'n03724870', 'n03729826', 'n03733131', 'n03733281', 'n03733805', 'n03742115', 'n03743016', 'n03759954', 'n03761084', 'n03763968', 'n03764736', 'n03769881', 'n03770439', 'n03770679', 'n03773504', 'n03775071', 'n03775546', 'n03776460', 'n03777568', 'n03777754', 'n03781244', 'n03782006', 'n03785016', 'n03786901', 'n03787032', 'n03788195', 'n03788365', 'n03791053', 'n03792782', 'n03792972', 'n03793489', 'n03794056', 'n03796401', 
'n03803284', 'n03804744', 'n03814639', 'n03814906', 'n03825788', 'n03832673', 'n03837869', 'n03838899', 'n03840681', 'n03841143', 'n03843555', 'n03854065', 'n03857828', 'n03866082', 'n03868242', 'n03868863', 'n03871628', 'n03873416', 'n03874293', 'n03874599', 'n03876231', 'n03877472', 'n03877845', 'n03884397', 'n03887697', 'n03888257', 'n03888605', 'n03891251', 'n03891332', 'n03895866', 'n03899768', 'n03902125', 'n03903868', 'n03908618', 'n03908714', 'n03916031', 'n03920288', 'n03924679', 'n03929660', 'n03929855', 'n03930313', 'n03930630', 'n03933933', 'n03935335', 'n03937543', 'n03938244', 'n03942813', 'n03944341', 'n03947888', 'n03950228', 'n03954731', 'n03956157', 'n03958227', 'n03961711', 'n03967562', 'n03970156', 'n03976467', 'n03976657', 'n03977966', 'n03980874', 'n03982430', 'n03983396', 'n03991062', 'n03992509', 'n03995372', 'n03998194', 'n04004767', 'n04005630', 'n04008634', 'n04009552', 'n04019541', 'n04023962', 'n04026417', 'n04033901', 'n04033995', 'n04037443', 'n04039381', 'n04040759', 'n04041544', 'n04044716', 'n04049303', 'n04065272', 'n04067472', 'n04069434', 'n04070727', 'n04074963', 'n04081281', 'n04086273', 'n04090263', 'n04099969', 'n04111531', 'n04116512', 'n04118538', 'n04118776', 'n04120489', 'n04125021', 'n04127249', 'n04131690', 'n04133789', 'n04136333', 'n04141076', 'n04141327', 'n04141975', 'n04146614', 'n04147183', 'n04149813', 'n04152593', 'n04153751', 'n04154565', 'n04162706', 'n04179913', 'n04192698', 'n04200800', 'n04201297', 'n04204238', 'n04204347', 'n04208210', 'n04209133', 'n04209239', 'n04228054', 'n04229816', 'n04235860', 'n04238763', 'n04239074', 'n04243546', 'n04251144', 'n04252077', 'n04252225', 'n04254120', 'n04254680', 'n04254777', 'n04258138', 'n04259630', 'n04263257', 'n04264628', 'n04265275', 'n04266014', 'n04270147', 'n04273569', 'n04275548', 'n04277352', 'n04285008', 'n04286575', 'n04296562', 'n04310018', 'n04311004', 'n04311174', 'n04317175', 'n04325704', 'n04326547', 'n04328186', 'n04330267', 'n04332243', 'n04335435', 'n04336792', 'n04344873', 'n04346328', 'n04347754', 'n04350905', 'n04355338', 'n04355933', 'n04356056', 'n04357314', 'n04366367', 'n04367480', 'n04370456', 'n04371430', 'n04371774', 'n04372370', 'n04376876', 'n04380533', 'n04389033', 'n04392985', 'n04398044', 'n04399382', 'n04404412', 'n04409515', 'n04417672', 'n04418357', 'n04423845', 'n04428191', 'n04429376', 'n04435653', 'n04442312', 'n04443257', 'n04447861', 'n04456115', 'n04458633', 'n04461696', 'n04462240', 'n04465501', 'n04467665', 'n04476259', 'n04479046', 'n04482393', 'n04483307', 'n04485082', 'n04486054', 'n04487081', 'n04487394', 'n04493381', 'n04501370', 'n04505470', 'n04507155', 'n04509417', 'n04515003', 'n04517823', 'n04522168', 'n04523525', 'n04525038', 'n04525305', 'n04532106', 'n04532670', 'n04536866', 'n04540053', 'n04542943', 'n04548280', 'n04548362', 'n04550184', 'n04552348', 'n04553703', 'n04554684', 'n04557648', 'n04560804', 'n04562935', 'n04579145', 'n04579432', 'n04584207', 'n04589890', 'n04590129', 'n04591157', 'n04591713', 'n04592741', 'n04596742', 'n04597913', 'n04599235', 'n04604644', 'n04606251', 'n04612504', 'n04613696', 'n06359193', 'n06596364', 'n06785654', 'n06794110', 'n06874185', 'n07248320', 'n07565083', 'n07579787', 'n07583066', 'n07584110', 'n07590611', 'n07613480', 'n07614500', 'n07615774', 'n07684084', 'n07693725', 'n07695742', 'n07697313', 'n07697537', 'n07711569', 'n07714571', 'n07714990', 'n07715103', 'n07716358', 'n07716906', 'n07717410', 'n07717556', 'n07718472', 'n07718747', 'n07720875', 'n07730033', 'n07734744', 'n07742313', 
'n07745940', 'n07747607', 'n07749582', 'n07753113', 'n07753275', 'n07753592', 'n07754684', 'n07760859', 'n07768694', 'n07802026', 'n07831146', 'n07836838', 'n07860988', 'n07871810', 'n07873807', 'n07875152', 'n07880968', 'n07892512', 'n07920052', 'n07930864', 'n07932039', 'n09193705', 'n09229709', 'n09246464', 'n09256479', 'n09288635', 'n09332890', 'n09399592', 'n09421951', 'n09428293', 'n09468604', 'n09472597', 'n09835506', 'n10148035', 'n10565667', 'n11879895', 'n11939491', 'n12057211', 'n12144580', 'n12267677', 'n12620546', 'n12768682', 'n12985857', 'n12998815', 'n13037406', 'n13040303', 'n13044778', 'n13052670', 'n13054560', 'n13133613', 'n15075141'] 8 | 9 | imagenet_r_wnids = {'n01443537', 'n01484850', 'n01494475', 'n01498041', 'n01514859', 'n01518878', 'n01531178', 'n01534433', 'n01614925', 'n01616318', 'n01630670', 'n01632777', 'n01644373', 'n01677366', 'n01694178', 'n01748264', 'n01770393', 'n01774750', 'n01784675', 'n01806143', 'n01820546', 'n01833805', 'n01843383', 'n01847000', 'n01855672', 'n01860187', 'n01882714', 'n01910747', 'n01944390', 'n01983481', 'n01986214', 'n02007558', 'n02009912', 'n02051845', 'n02056570', 'n02066245', 'n02071294', 'n02077923', 'n02085620', 'n02086240', 'n02088094', 'n02088238', 'n02088364', 'n02088466', 'n02091032', 'n02091134', 'n02092339', 'n02094433', 'n02096585', 'n02097298', 'n02098286', 'n02099601', 'n02099712', 'n02102318', 'n02106030', 'n02106166', 'n02106550', 'n02106662', 'n02108089', 'n02108915', 'n02109525', 'n02110185', 'n02110341', 'n02110958', 'n02112018', 'n02112137', 'n02113023', 'n02113624', 'n02113799', 'n02114367', 'n02117135', 'n02119022', 'n02123045', 'n02128385', 'n02128757', 'n02129165', 'n02129604', 'n02130308', 'n02134084', 'n02138441', 'n02165456', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02233338', 'n02236044', 'n02268443', 'n02279972', 'n02317335', 'n02325366', 'n02346627', 'n02356798', 'n02363005', 'n02364673', 'n02391049', 'n02395406', 'n02398521', 'n02410509', 'n02423022', 'n02437616', 'n02445715', 'n02447366', 'n02480495', 'n02480855', 'n02481823', 'n02483362', 'n02486410', 'n02510455', 'n02526121', 'n02607072', 'n02655020', 'n02672831', 'n02701002', 'n02749479', 'n02769748', 'n02793495', 'n02797295', 'n02802426', 'n02808440', 'n02814860', 'n02823750', 'n02841315', 'n02843684', 'n02883205', 'n02906734', 'n02909870', 'n02939185', 'n02948072', 'n02950826', 'n02951358', 'n02966193', 'n02980441', 'n02992529', 'n03124170', 'n03272010', 'n03345487', 'n03372029', 'n03424325', 'n03452741', 'n03467068', 'n03481172', 'n03494278', 'n03495258', 'n03498962', 'n03594945', 'n03602883', 'n03630383', 'n03649909', 'n03676483', 'n03710193', 'n03773504', 'n03775071', 'n03888257', 'n03930630', 'n03947888', 'n04086273', 'n04118538', 'n04133789', 'n04141076', 'n04146614', 'n04147183', 'n04192698', 'n04254680', 'n04266014', 'n04275548', 'n04310018', 'n04325704', 'n04347754', 'n04389033', 'n04409515', 'n04465501', 'n04487394', 'n04522168', 'n04536866', 'n04552348', 'n04591713', 'n07614500', 'n07693725', 'n07695742', 'n07697313', 'n07697537', 'n07714571', 'n07714990', 'n07718472', 'n07720875', 'n07734744', 'n07742313', 'n07745940', 'n07749582', 'n07753275', 'n07753592', 'n07768694', 'n07873807', 'n07880968', 'n07920052', 'n09472597', 'n09835506', 'n10565667', 'n12267677'} 10 | 11 | imagenet_r_mask = [wnid in imagenet_r_wnids for wnid in all_wnids] 12 | 13 | imagenet_v_mask = [0, 1, 10, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 11,110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 12, 120, 121, 122,123, 124, 125, 126, 127, 
128, 129, 13, 130, 131, 132, 133, 134, 135,136, 137, 138, 139, 14, 140, 141, 142, 143, 144, 145, 146, 147, 148,149, 15, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 16, 160,161, 162, 163, 164, 165, 166, 167, 168, 169, 17, 170, 171, 172, 173,174, 175, 176, 177, 178, 179, 18, 180, 181, 182, 183, 184, 185, 186,187, 188, 189, 19, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 2, 20, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 21, 210,211, 212, 213, 214, 215, 216, 217, 218, 219, 22, 220, 221, 222, 223,224, 225, 226, 227, 228, 229, 23, 230, 231, 232, 233, 234, 235, 236,237, 238, 239, 24, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 25, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 26, 260, 261,262, 263, 264, 265, 266, 267, 268, 269, 27, 270, 271, 272, 273, 274,275, 276, 277, 278, 279, 28, 280, 281, 282, 283, 284, 285, 286, 287,288, 289, 29, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 3, 30, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 31, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 32, 320, 321, 322, 323, 324,325, 326, 327, 328, 329, 33, 330, 331, 332, 333, 334, 335, 336, 337,338, 339, 34, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 35,350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 36, 360, 361, 362,363, 364, 365, 366, 367, 368, 369, 37, 370, 371, 372, 373, 374, 375,376, 377, 378, 379, 38, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 39, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 4, 40,400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 41, 410, 411, 412,413, 414, 415, 416, 417, 418, 419, 42, 420, 421, 422, 423, 424, 425,426, 427, 428, 429, 43, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 44, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 45, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 46, 460, 461, 462, 463,464, 465, 466, 467, 468, 469, 47, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 48, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 49, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 5, 50, 500, 501, 502, 503, 504, 505, 506, 507, 508, 509, 51, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 52, 520, 521, 522, 523, 524, 525, 526, 527, 528, 529, 53, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 54, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 55, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 56, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 57, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 58, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, 59, 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 6, 60, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 61, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 62, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 63, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639, 64, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 65, 650, 651, 652, 653, 654, 655, 656, 657, 658, 659, 66, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 67, 670, 671, 672, 673, 674, 675, 676, 677, 678, 679, 68, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 69, 690, 691, 692, 693, 694, 695, 696, 697, 698, 699, 7, 70, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 71, 710, 711, 712, 713, 714, 715, 716, 717, 718, 719, 72, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 73, 730, 731, 732, 733, 734, 735, 736, 737, 738, 739, 74, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 75, 750, 751, 752, 753, 754, 755, 756, 757, 758, 759, 76, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 77, 770, 771, 772, 773, 774, 775, 776, 777, 778, 779, 78, 780, 781, 782, 783, 784, 785, 
786, 787, 788, 789, 79, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, 8, 80, 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 81, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 82, 820, 821, 822, 823, 824, 825, 826, 827, 828, 829, 83, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 84, 840, 841, 842, 843, 844, 845, 846, 847, 848, 849, 85, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 86, 860, 861, 862, 863, 864, 865, 866, 867, 868, 869, 87, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 88, 880, 881, 882, 883, 884, 885, 886, 887, 888, 889, 89, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 9, 90, 900, 901, 902, 903, 904, 905, 906, 907, 908, 909, 91, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 92, 920, 921, 922, 923, 924, 925, 926, 927, 928, 929, 93, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 94, 940, 941, 942, 943, 944, 945, 946, 947, 948, 949, 95, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 96, 960, 961, 962, 963, 964, 965, 966, 967, 968, 969, 97, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 98, 980, 981, 982, 983, 984, 985, 986, 987, 988, 989, 99, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999] --------------------------------------------------------------------------------
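The lists above are plain data: `imagenet_r_mask` is a length-1000 boolean mask marking which of the full ImageNet classes belong to ImageNet-R, and `imagenet_v_mask` lists the indices 0-999 in string-sorted order, presumably to align logits with that variant's folder ordering. Below is a minimal usage sketch, not part of the repository, showing how such masks are typically applied to 1000-way logits at evaluation time; the helper name `restrict_logits` is hypothetical, and the import assumes the file is importable as `data.imagenet_variants`.

```python
# Illustrative sketch only: restrict_logits is a hypothetical helper, and the
# import path assumes data/imagenet_variants.py is importable as a module.
import torch

from data.imagenet_variants import imagenet_r_mask, imagenet_v_mask


def restrict_logits(logits: torch.Tensor, test_set: str) -> torch.Tensor:
    """Restrict full 1000-way ImageNet logits to a variant's label space.

    logits: tensor of shape [batch, 1000] over the standard ImageNet classes.
    """
    if test_set == "R":
        # imagenet_r_mask is a length-1000 list of booleans; keep only the
        # columns whose WNID appears in imagenet_r_wnids (the ImageNet-R classes).
        mask = torch.tensor(imagenet_r_mask, dtype=torch.bool)
        return logits[:, mask]
    if test_set == "V":
        # imagenet_v_mask lists the indices 0..999 in string-sorted order;
        # indexing with it reorders the columns to that label ordering.
        index = torch.tensor(imagenet_v_mask, dtype=torch.long)
        return logits[:, index]
    return logits


if __name__ == "__main__":
    dummy = torch.randn(8, 1000)
    print(restrict_logits(dummy, "R").shape)  # e.g. torch.Size([8, 200])
    print(restrict_logits(dummy, "V").shape)  # torch.Size([8, 1000])
```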