├── .gitignore ├── DATASETS.md ├── DMN ├── DMN_clip_wrapper.py ├── DMN_core.py ├── DMN_utils.py └── __init__.py ├── LICENSE ├── OGA ├── OGA_core.py └── __init__.py ├── README.md ├── TDA ├── TDA_core.py ├── TDA_utils.py └── __init__.py ├── __init__.py ├── compute_features.py ├── datasets ├── __init__.py ├── caltech101.py ├── dtd.py ├── eurosat.py ├── fgvc.py ├── food101.py ├── imagenet.py ├── imagenet_a.py ├── imagenet_r.py ├── imagenet_sketch.py ├── imagenet_v2.py ├── oxford_flowers.py ├── oxford_pets.py ├── prepare_tta_datasets.py ├── sampler.py ├── stanford_cars.py ├── sun397.py ├── tiny_imagenet.py ├── tiny_imagenet_c.py ├── ucf101.py ├── utils.py └── visda.py ├── images └── abstract_barplot_github_version.png ├── main.py ├── runner.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | -------------------------------------------------------------------------------- /DATASETS.md: -------------------------------------------------------------------------------- 1 | # How to install datasets 2 | 3 | All datasets should be under the same folder (say `$DATA`) and organized as follow to avoid modifying the source code. The file structure looks like 4 | 5 | ``` 6 | $DATA/ 7 | |–– caltech-101/ 8 | |–– oxford_pets/ 9 | |–– stanford_cars/ 10 | ``` 11 | 12 | If you have some datasets already installed somewhere else, you can create symbolic links in `$DATA/dataset_name` that point to the original data to avoid duplicate download. 13 | 14 | Datasets list: 15 | - [Caltech101](#caltech101) 16 | - [OxfordPets](#oxfordpets) 17 | - [StanfordCars](#stanfordcars) 18 | - [Flowers102](#flowers102) 19 | - [Food101](#food101) 20 | - [FGVCAircraft](#fgvcaircraft) 21 | - [SUN397](#sun397) 22 | - [DTD](#dtd) 23 | - [EuroSAT](#eurosat) 24 | - [UCF101](#ucf101) 25 | - [ImageNet](#imagenet) 26 | 27 | 28 | ### Caltech101 29 | - Create a folder named `caltech-101/` under `$DATA`. 30 | - Download `101_ObjectCategories.tar.gz` from http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz and extract the file under `$DATA/caltech-101`. 31 | - Download `split_zhou_Caltech101.json` from this [link](https://drive.google.com/file/d/1hyarUivQE36mY6jSomru6Fjd-JzwcCzN/view?usp=sharing) and put it under `$DATA/caltech-101`. 32 | 33 | The directory structure should look like 34 | ``` 35 | caltech-101/ 36 | |–– 101_ObjectCategories/ 37 | |–– split_zhou_Caltech101.json 38 | ``` 39 | 40 | ### OxfordPets 41 | - Create a folder named `oxford_pets/` under `$DATA`. 42 | - Download the images from https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz. 43 | - Download the annotations from https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz. 44 | - Download `split_zhou_OxfordPets.json` from this [link](https://drive.google.com/file/d/1501r8Ber4nNKvmlFVQZ8SeUHTcdTTEqs/view?usp=sharing). 45 | 46 | The directory structure should look like 47 | ``` 48 | oxford_pets/ 49 | |–– images/ 50 | |–– annotations/ 51 | |–– split_zhou_OxfordPets.json 52 | ``` 53 | 54 | ### StanfordCars 55 | - Create a folder named `stanford_cars/` under `$DATA`. 56 | - Download the train images http://ai.stanford.edu/~jkrause/car196/cars_train.tgz. 57 | - Download the test images http://ai.stanford.edu/~jkrause/car196/cars_test.tgz. 58 | - Download the train labels https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz. 59 | - Download the test labels http://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat. 
60 | - Download `split_zhou_StanfordCars.json` from this [link](https://drive.google.com/file/d/1ObCFbaAgVu0I-k_Au-gIUcefirdAuizT/view?usp=sharing). 61 | 62 | The directory structure should look like 63 | ``` 64 | stanford_cars/ 65 | |–– cars_test\ 66 | |–– cars_test_annos_withlabels.mat 67 | |–– cars_train\ 68 | |–– devkit\ 69 | |–– split_zhou_StanfordCars.json 70 | ``` 71 | 72 | ### Flowers102 73 | - Create a folder named `oxford_flowers/` under `$DATA`. 74 | - Download the images and labels from https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz and https://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat respectively. 75 | - Download `cat_to_name.json` from [here](https://drive.google.com/file/d/1AkcxCXeK_RCGCEC_GvmWxjcjaNhu-at0/view?usp=sharing). 76 | - Download `split_zhou_OxfordFlowers.json` from [here](https://drive.google.com/file/d/1Pp0sRXzZFZq15zVOzKjKBu4A9i01nozT/view?usp=sharing). 77 | 78 | The directory structure should look like 79 | ``` 80 | oxford_flowers/ 81 | |–– cat_to_name.json 82 | |–– imagelabels.mat 83 | |–– jpg/ 84 | |–– split_zhou_OxfordFlowers.json 85 | ``` 86 | 87 | ### Food101 88 | - Download the dataset from https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/ and extract the file `food-101.tar.gz` under `$DATA`, resulting in a folder named `$DATA/food-101/`. 89 | - Download `split_zhou_Food101.json` from [here](https://drive.google.com/file/d/1QK0tGi096I0Ba6kggatX1ee6dJFIcEJl/view?usp=sharing). 90 | 91 | The directory structure should look like 92 | ``` 93 | food-101/ 94 | |–– images/ 95 | |–– license_agreement.txt 96 | |–– meta/ 97 | |–– README.txt 98 | |–– split_zhou_Food101.json 99 | ``` 100 | 101 | ### FGVCAircraft 102 | - Download the data from https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz. 103 | - Extract `fgvc-aircraft-2013b.tar.gz` and keep only `data/`. 104 | - Move `data/` to `$DATA` and rename the folder to `fgvc_aircraft/`. 105 | 106 | The directory structure should look like 107 | ``` 108 | fgvc_aircraft/ 109 | |–– images/ 110 | |–– ... # a bunch of .txt files 111 | ``` 112 | 113 | ### SUN397 114 | - Create a folder named `sun397/` under `$DATA`. 115 | - Download the images http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz. 116 | - Download the partitions https://vision.princeton.edu/projects/2010/SUN/download/Partitions.zip. 117 | - Extract these files under `$DATA/sun397/`. 118 | - Download `split_zhou_SUN397.json` from this [link](https://drive.google.com/file/d/1y2RD81BYuiyvebdN-JymPfyWYcd8_MUq/view?usp=sharing). 119 | 120 | The directory structure should look like 121 | ``` 122 | sun397/ 123 | |–– SUN397/ 124 | |–– split_zhou_SUN397.json 125 | |–– ... # a bunch of .txt files 126 | ``` 127 | 128 | ### DTD 129 | - Download the dataset from https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz and extract it to `$DATA`. This should lead to `$DATA/dtd/`. 130 | - Download `split_zhou_DescribableTextures.json` from this [link](https://drive.google.com/file/d/1u3_QfB467jqHgNXC00UIzbLZRQCg2S7x/view?usp=sharing). 131 | 132 | The directory structure should look like 133 | ``` 134 | dtd/ 135 | |–– images/ 136 | |–– imdb/ 137 | |–– labels/ 138 | |–– split_zhou_DescribableTextures.json 139 | ``` 140 | 141 | ### EuroSAT 142 | - Create a folder named `eurosat/` under `$DATA`. 143 | - Download the dataset from http://madm.dfki.de/files/sentinel/EuroSAT.zip and extract it to `$DATA/eurosat/`. 
144 | - Download `split_zhou_EuroSAT.json` from [here](https://drive.google.com/file/d/1Ip7yaCWFi0eaOFUGga0lUdVi_DDQth1o/view?usp=sharing). 145 | 146 | The directory structure should look like 147 | ``` 148 | eurosat/ 149 | |–– 2750/ 150 | |–– split_zhou_EuroSAT.json 151 | ``` 152 | 153 | ### UCF101 154 | - Create a folder named `ucf101/` under `$DATA`. 155 | - Download the zip file `UCF-101-midframes.zip` from [here](https://drive.google.com/file/d/10Jqome3vtUA2keJkNanAiFpgbyC9Hc2O/view?usp=sharing) and extract it to `$DATA/ucf101/`. This zip file contains the extracted middle video frames. 156 | - Download `split_zhou_UCF101.json` from this [link](https://drive.google.com/file/d/1I0S0q91hJfsV9Gf4xDIjgDq4AqBNJb1y/view?usp=sharing). 157 | 158 | The directory structure should look like 159 | ``` 160 | ucf101/ 161 | |–– UCF-101-midframes/ 162 | |–– split_zhou_UCF101.json 163 | ``` 164 | 165 | ### ImageNet 166 | - Create a folder named `imagenet/` under `$DATA`. 167 | - Create `images/` under `imagenet/`. 168 | - Download the dataset from the [official website](https://image-net.org/index.php) and extract the training and validation sets to `$DATA/imagenet/images`. The directory structure should look like 169 | ``` 170 | imagenet/ 171 | |–– images/ 172 | | |–– train/ # contains 1,000 folders like n01440764, n01443537, etc. 173 | | |–– val/ 174 | ``` 175 | - If you had downloaded the ImageNet dataset before, you can create symbolic links to map the training and validation sets to `$DATA/imagenet/images`. 176 | - Download the `classnames.txt` to `$DATA/imagenet/` from this [link](https://drive.google.com/file/d/1-61f_ol79pViBFDG_IDlUQSwoLcn2XXF/view?usp=sharing). The class names are copied from [CLIP](https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb). 
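If you already have ImageNet (or any other dataset) stored elsewhere, the symbolic links mentioned above can be created with a few lines of Python. The snippet below is only an illustrative sketch with placeholder paths; it is not a script shipped with this repository:

```python
# Illustrative only: link an existing ImageNet copy into $DATA/imagenet/images.
# All paths below are placeholders; adjust them to your own setup.
import os

data_root = "/path/to/DATA"                       # your $DATA folder
existing_imagenet = "/path/to/existing/imagenet"  # folder already containing train/ and val/

os.makedirs(os.path.join(data_root, "imagenet", "images"), exist_ok=True)
for split in ("train", "val"):
    src = os.path.join(existing_imagenet, split)
    dst = os.path.join(data_root, "imagenet", "images", split)
    if not os.path.exists(dst):
        os.symlink(src, dst)  # symlink avoids duplicating the dataset on disk
```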
177 | -------------------------------------------------------------------------------- /DMN/DMN_clip_wrapper.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch.nn as nn 3 | from clip import load, tokenize 4 | import torch 5 | class DMNClipWrapper(nn.Module): 6 | def __init__(self, clip_model, transform, device, classnames, batch_size, criterion='cosine', arch="ViT-L/14", 7 | learned_cls=False, memory_size=10, text_prompt_type='custom'): 8 | super(DMNClipWrapper, self).__init__() 9 | self.clip = clip_model 10 | self.classnames = [name.replace("_", " ") for name in classnames] 11 | self.first_flag = True 12 | self.memory_size = memory_size 13 | self.return_local_feat = False 14 | if text_prompt_type != 'custom': 15 | raise RuntimeError('Only custom prompts are supported.') 16 | self.text_prompt_type = text_prompt_type 17 | 18 | self.logit_scale = self.clip.logit_scale.data 19 | self.text_feat = None 20 | self.few_shot_mem = False 21 | # self.n_cls = len(classnames) ## 200 22 | # self.image_encoder = clip.visual 23 | # # ipdb.set_trace() 24 | # self.text_encoder = TextEncoder(clip) 25 | # # prompt tuning 26 | # self.prompt_learner = PromptLearner(clip, classnames, batch_size, n_ctx, ctx_init, ctx_position, learned_cls) 27 | # self.criterion = criterion 28 | 29 | 30 | # @property 31 | # def dtype(self): 32 | # return self.image_encoder.conv1.weight.dtype 33 | 34 | # # restore the initial state of the prompt_learner (tunable prompt) 35 | # def reset(self): 36 | # self.prompt_learner.reset() 37 | # 38 | def reset_classnames(self, dataset): 39 | self.n_cls = len(dataset.classnames) ## 200 40 | self.classnames = [name.replace("_", " ") for name in dataset.classnames] 41 | self.text_prompt = dataset.template 42 | # ipdb.set_trace() 43 | # name_lens = [len(_tokenizer.encode(name)) for name in classnames] 44 | # prompts = [self.prompt_prefix + " " + name + "." 
for name in classnames] ## 200 45 | # tokenized_prompts = torch.cat([tokenize(p) for p in prompts]).to(self.device) ## torch.Size([200, 77]) 46 | # 47 | # clip, _, _ = load(arch, device=self.device, download_root=DOWNLOAD_ROOT) 48 | # 49 | # with torch.no_grad(): 50 | # embedding = clip.token_embedding(tokenized_prompts).type(self.dtype) ## torch.Size([200, 77, 512]) 51 | # 52 | # self.token_prefix = embedding[:, :1, :] ## 200*1*512 前缀 53 | # self.token_suffix = embedding[:, 1 + self.n_ctx :, :] # CLS, EOS ## torch.Size([200, 72, 512]) 后缀 54 | # 55 | # self.name_lens = name_lens 56 | # self.tokenized_prompts = tokenized_prompts ## torch.Size([200, 77]) 57 | # self.classnames = classnames 58 | self.first_flag = True 59 | 60 | def get_text_features(self): 61 | ## get the text feature only once, multiple class & multiple prompt 62 | text_feat = [] 63 | text_label = [] 64 | count = 0 65 | for name in self.classnames: 66 | text_prompts = [template.format(name) for template in self.text_prompt] # format with class 67 | if self.text_prompt_type =='tip_cupl': 68 | text_prompts += self.cupl_prompts[name] 69 | texts = tokenize(text_prompts).cuda() # tokenize 70 | class_embeddings = self.clip.encode_text(texts) # embed with text encoder 71 | class_embeddings_full = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True) 72 | class_embedding_mean = class_embeddings_full.mean(dim=0) 73 | class_embedding_mean /= class_embedding_mean.norm() 74 | text_feat.append(class_embedding_mean) ### 1024 75 | one_hot_target = torch.zeros(self.n_cls).to(class_embedding_mean.device) 76 | one_hot_target[count] = 1 77 | text_label.append(one_hot_target) ## 1 * d, turn it to one hot labels. 78 | count = count + 1 79 | self.text_feat = torch.stack(text_feat, dim=0).cuda() ## N*1024 80 | self.text_label = torch.stack(text_label, dim=0).cuda() ## N*N 81 | 82 | self.text_feat_full = self.text_feat ## not used. 83 | ######## 直接从这里找出 important text feat following APE. TO DO 84 | self.fixed_global_feat = self.text_feat.clone().unsqueeze(1) ## N*1*C 85 | self.fixed_local_feat = self.text_feat.clone().unsqueeze(1) ## N*1*C 86 | self.fixed_global_feat_vanilla = self.text_feat.clone().unsqueeze(1) ## N*1*C 87 | self.fixed_local_feat_vanilla = self.text_feat.clone().unsqueeze(1) ## N*1*C 88 | 89 | self.fixed_global_label = self.text_label.clone().unsqueeze(1) 90 | self.fixed_local_label = self.text_label.clone().unsqueeze(1) 91 | self.fixed_global_label_vanilla = self.text_label.clone().unsqueeze(1) 92 | self.fixed_local_label_vanilla = self.text_label.clone().unsqueeze(1) 93 | 94 | if self.first_flag: ## initlize 95 | self.image_feature_memory = torch.zeros(self.n_cls, self.memory_size, self.text_feat.shape[1]).to(self.text_feat.device) ## 如果满了,把entropy 最高的扔出去 96 | self.image_prediction_mem = torch.zeros(self.n_cls, self.memory_size, self.n_cls).to(self.text_feat.device) ## category prediction. 97 | self.image_entropy_mem = torch.zeros(self.n_cls, self.memory_size).to(self.text_feat.device) ## category prediction. 98 | self.image_feature_count = torch.zeros(self.n_cls, 1).long().to(self.text_feat.device) 99 | 100 | self.local_feature_memory = torch.zeros(self.n_cls, self.memory_size, self.text_feat.shape[1]).to(self.text_feat.device) 101 | self.local_prediction_mem = torch.zeros(self.n_cls, self.memory_size, self.n_cls).to(self.text_feat.device) ## category prediction. 102 | self.local_entropy_mem = torch.zeros(self.n_cls, self.memory_size).to(self.text_feat.device) ## category prediction. 
103 | self.local_feature_count = torch.zeros(self.n_cls, 1).long().to(self.text_feat.device) 104 | self.first_flag = False 105 | 106 | return self.text_feat, self.text_feat_full 107 | 108 | # text_features = [] 109 | # prompts = self.prompt_learner(with_std=True) ## torch.Size([1000, 77, 512]) 110 | # tokenized_prompts = self.prompt_learner.tokenized_prompts 111 | # t_features = self.text_encoder(prompts, tokenized_prompts) ## torch.Size([1000, 1024]) 112 | # text_features.append(t_features / t_features.norm(dim=-1, keepdim=True)) 113 | # self.num_class = t_features.size(0) 114 | # text_features = torch.stack(text_features, dim=0) 115 | # # return text_features 116 | # 117 | # return torch.mean(text_features, dim=0) 118 | 119 | def DMN_encode_image(self, x): 120 | x = self.clip.visual.conv1(x) # shape = [*, width, grid, grid] 121 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 122 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 123 | x = torch.cat([self.clip.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 124 | x = x + self.clip.visual.positional_embedding.to(x.dtype) 125 | x = self.clip.visual.ln_pre(x) 126 | 127 | x = x.permute(1, 0, 2) # NLD -> LND 128 | x = self.clip.visual.transformer(x) 129 | x = x.permute(1, 0, 2) # LND -> NLD ## torch.Size([128, 197, 768]) 130 | 131 | # ipdb.set_trace() 132 | # x = self.clip.visual.ln_post(x[:, 0, :]) ## 128*768 133 | x = self.clip.visual.ln_post(x) ## 128*197*768 134 | 135 | if self.clip.visual.proj is not None: 136 | x = x @ self.clip.visual.proj 137 | 138 | return x 139 | def get_image_features(self, image): 140 | # image_features_vanilla = self.image_encoder(image.type(self.dtype)) 141 | ## for Res50 128*1024 or 128*50*1024 [global feat; 7*7 local feature] 142 | ## for VIT, 128*512 or 128*197*512 [global feat; 14*14 local features] 143 | image_features = self.DMN_encode_image(image) 144 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 145 | image_features_local = image_features[:,1:,:] ## B*L*C 146 | image_features_global = image_features[:, 0, :] ## B*C 147 | 148 | self.image_features_local = None #image_features_local 149 | self.image_features_global = image_features_global 150 | 151 | return self.image_features_global, self.image_features_local 152 | 153 | # logit_scale = self.logit_scale.exp() 154 | # logits = logit_scale * image_features @ text_features.t() 155 | # return logits 156 | 157 | def forward(self, input): 158 | pass 159 | # if isinstance(input, Tuple): 160 | # view_0, view_1, view_2 = input 161 | # return self.contrast_prompt_tuning(view_0, view_1, view_2) 162 | # elif len(input.size()) == 2: 163 | # return self.directional_prompt_tuning(input) 164 | # else: 165 | # return self.inference(input) 166 | -------------------------------------------------------------------------------- /DMN/DMN_core.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from PIL import Image 4 | 5 | try: 6 | from torchvision.transforms import InterpolationMode 7 | BICUBIC = InterpolationMode.BICUBIC 8 | except ImportError: 9 | BICUBIC = Image.BICUBIC 10 | ## the main component. 
11 | class DualMem(nn.Module): 12 | def __init__(self, args=None, beta=5.5, feat_dim=1024, class_num=1000, mapping='bias'): 13 | super(DualMem, self).__init__() 14 | self.args = args 15 | self.indice = args.indice ## indice of important channels. 16 | self.beta = beta 17 | self.rank = 4 18 | self.init_pred = 0 19 | if args.shared_param: 20 | self.global_affine = nn.Parameter(torch.zeros((feat_dim, feat_dim))) 21 | self.global_bias = nn.Parameter(torch.zeros((class_num, feat_dim))) ## unknown use the category mean. 22 | self.global_bias_key = self.global_bias 23 | self.global_bias_value = self.global_bias 24 | 25 | self.global_ffn_affine = nn.Parameter(torch.zeros((feat_dim, feat_dim))) 26 | self.global_ffn_bias = nn.Parameter(torch.zeros((class_num, feat_dim))) ## unknown use the category mean. 27 | self.text_affine = self.global_ffn_affine 28 | self.text_bias = self.global_ffn_bias 29 | else: 30 | self.global_affine = nn.Parameter(torch.zeros((feat_dim, feat_dim))) 31 | self.global_bias = nn.Parameter(torch.zeros((class_num, feat_dim))) ## unknown use the category mean. 32 | self.global_bias_key = nn.Parameter(torch.zeros((class_num, feat_dim))) ## unknown use the category mean. 33 | self.global_bias_value = nn.Parameter(torch.zeros((class_num, feat_dim))) ## unknown use the category mean. 34 | 35 | self.global_ffn_affine = nn.Parameter(torch.zeros((feat_dim, feat_dim))) 36 | self.global_ffn_bias = nn.Parameter(torch.zeros((class_num, feat_dim))) ## unknown use the category mean. 37 | self.text_affine = nn.Parameter(torch.zeros((feat_dim, feat_dim))) 38 | self.text_bias = nn.Parameter(torch.zeros((class_num, feat_dim))) 39 | self.learnable_mapping = args.mapping ### bias | affine | all 40 | 41 | 42 | def update_memory_bank(self, model): 43 | # updating 44 | mean_prob = self.init_pred[0] 45 | value, indice = mean_prob.max(0) 46 | pseudo_label = indice.item() 47 | text_features = model.text_feat[pseudo_label] ## 512 48 | selected_image_features_global = model.image_features_global[:1] 49 | current_instance_entropy = -(mean_prob * (torch.log(mean_prob + 1e-8))).sum() 50 | if model.image_feature_count[pseudo_label] == model.memory_size: 51 | ###### if the new one is low entropy, find the sample with the max entropy, and replace it with the new one 52 | if (current_instance_entropy < model.image_entropy_mem[pseudo_label]).sum() == 0: 53 | pass ## the entropy of current test image is very large. 54 | else: 55 | _, indice = torch.sort(model.image_entropy_mem[pseudo_label]) 56 | to_replace_indice = indice[-1] ## with max entropy, ascending. 
57 | model.image_feature_memory[pseudo_label][to_replace_indice] = selected_image_features_global 58 | model.image_prediction_mem[pseudo_label][to_replace_indice] = mean_prob[0] 59 | model.image_entropy_mem[pseudo_label][to_replace_indice] = current_instance_entropy 60 | else: 61 | model.image_feature_memory[pseudo_label][model.image_feature_count[pseudo_label, 0].item()] = selected_image_features_global 62 | model.image_prediction_mem[pseudo_label][model.image_feature_count[pseudo_label, 0].item()] = mean_prob[0] 63 | model.image_entropy_mem[pseudo_label][model.image_feature_count[pseudo_label, 0].item()] = current_instance_entropy 64 | model.image_feature_count[pseudo_label] += 1 65 | 66 | 67 | def fast_get_image_pred(self, 68 | img_features, 69 | model, 70 | clip_prototypes, 71 | return_full=False, 72 | return_logit=True): 73 | # vectorized version of below function 74 | # only works when not using augmentations + training free settings 75 | assert return_logit 76 | assert self.args.position == 'qkv' or self.args.position == 'all' 77 | assert torch.all(torch.linalg.norm(self.global_bias, dim = -1)<1e-6) #global bias is initialized in dmn base code but should always zero in trainig free scenarios 78 | assert torch.all(torch.linalg.norm(self.global_bias_key, dim = -1)<1e-6) 79 | assert torch.all(torch.linalg.norm(self.global_bias_value, dim = -1)<1e-6) 80 | memorized_image_feat = torch.cat((model.image_feature_memory, model.fixed_global_feat_vanilla), dim=1) 81 | #memorized_image_feat = model.fixed_global_feat_vanilla 82 | batch_sim_mat = torch.sum(img_features[:,None,None,:] * memorized_image_feat[None,...], 83 | dim = -1) 84 | batch_sim_mat = torch.exp(-self.beta * (-batch_sim_mat + 1)) 85 | batch_adapt_image_feat = torch.sum(memorized_image_feat[None,...] * batch_sim_mat[...,None], 86 | dim = -2) 87 | batch_adapt_image_feat = batch_adapt_image_feat/torch.linalg.norm(batch_adapt_image_feat, 88 | dim = -1, 89 | keepdims = True) 90 | batch_logits = 100 * torch.sum(img_features[:,None,:] * batch_adapt_image_feat, dim = -1 ) 91 | return batch_logits 92 | 93 | 94 | def get_image_pred(self, model, return_full=False, return_logit=False): 95 | ## prediction with dynamic memory. 96 | img_feat = model.image_features_global[:1] # 1*1024 97 | count_image_feat = model.image_feature_count.clone() 98 | num_class = model.image_feature_memory.shape[0] 99 | image_classifier = 'similarity_weighted' ## category_center | entropy_weighted | similarity_weighted 100 | ### similarity_weighted achieves the best results. 101 | memorized_image_feat = torch.cat((model.image_feature_memory, model.fixed_global_feat_vanilla), dim=1) ## 200*11*1024 102 | # model.fixed_global_feat_vanilla is actually the text embeddings 103 | if image_classifier == 'similarity_weighted': ## this is an instance adaptative method. 104 | ## calculate the cos similarity betweeen image feature and memory feature, and then weighted the memorized features according to similarity. 
105 | ###################### 有一些memory 是空的,现在却往里面塞了一个self.global_bias, 这不合理,还要把它继续置空。 106 | img_feat_mappling = img_feat 107 | memorized_image_feat_K = memorized_image_feat 108 | memorized_image_feat_V = memorized_image_feat 109 | with torch.no_grad(): 110 | if self.args.position == 'query': 111 | img_feat_mappling = img_feat + self.global_bias.mean(0, keepdim=True) ## N*1024 112 | elif self.args.position == 'key': 113 | memorized_image_feat_K = memorized_image_feat + self.global_bias_key.unsqueeze(1) ## class*shot*1024 114 | elif self.args.position == 'value': 115 | memorized_image_feat_V = memorized_image_feat + self.global_bias_value.unsqueeze(1) ## class*shot*1024 116 | elif self.args.position == 'qkv' or self.args.position == 'all': 117 | img_feat_mappling = img_feat + self.global_bias.mean(0, keepdim=True) ## N*1024 118 | memorized_image_feat_K = memorized_image_feat + self.global_bias_key.unsqueeze(1) ## class*shot*1024 119 | memorized_image_feat_V = memorized_image_feat + self.global_bias_value.unsqueeze(1) ## class*shot*1024 120 | else: 121 | pass 122 | memorized_image_feat_K = memorized_image_feat_K / memorized_image_feat_K.norm(dim=-1, keepdim=True) 123 | ## some memorized_image_feat slots are empty before mapping, reseting them to empty. 124 | memorized_image_feat_K[memorized_image_feat.sum(-1) == 0] = 0 125 | memorized_image_feat_V = memorized_image_feat_V / memorized_image_feat_V.norm(dim=-1, keepdim=True) 126 | memorized_image_feat_V[memorized_image_feat.sum(-1) == 0] = 0 127 | img_feat_mappling = img_feat_mappling / img_feat_mappling.norm(dim=-1, keepdim=True) 128 | 129 | similarity_matrix = (img_feat_mappling * memorized_image_feat_K).sum(-1) ## 200*11 idealy [-1,1], practically [0.1, 0.2] 130 | similarity_matrix = torch.exp(-self.beta * (-similarity_matrix + 1)) 131 | ### weighting memoried features with similarity weights. 132 | adaptive_image_feat = (memorized_image_feat_V * similarity_matrix.unsqueeze(-1)).sum(1) 133 | ## torch.Size([1, class, dim]) 134 | adaptive_image_feat = adaptive_image_feat / adaptive_image_feat.norm(dim=-1, keepdim=True) 135 | if self.args.position == 'output' or self.args.position == 'all': 136 | adaptive_image_feat = adaptive_image_feat + self.global_ffn_bias.unsqueeze(0) ## class*shot*1024 137 | 138 | adaptive_image_feat = adaptive_image_feat / adaptive_image_feat.norm(dim=-1, keepdim=True) 139 | logit_scale = model.logit_scale.exp() 140 | # adaptive_image_feat: torch.Size([1, 102, 1024]) 141 | # img_feat: torch.Size([1, 1024]) 142 | logits = logit_scale * adaptive_image_feat @ img_feat.unsqueeze(-1) ## used feat is not update. 143 | logits = logits[:,:,0] 144 | if return_logit: 145 | return logits 146 | else: 147 | return logits.softmax(dim=1) 148 | else: 149 | raise NotImplementedError 150 | 151 | def get_image_pred_fewshot_global(self, model, return_full=False, return_logit=False): 152 | ## prediction with static memory. 153 | if return_full: 154 | img_feat = model.image_features_global # 1*1024 155 | else: 156 | img_feat = model.image_features_global[:1, :] # 1*1024 157 | num_class = model.image_feature_memory.shape[0] 158 | memorized_image_feat = model.fixed_global_feat ## 200*11*1024, few shot samples and text features. 
159 | img_feat_mappling = img_feat 160 | memorized_image_feat_K = memorized_image_feat 161 | memorized_image_feat_V = memorized_image_feat 162 | 163 | if self.args.position == 'query': 164 | img_feat_mappling = img_feat + self.global_bias.mean(0, keepdim=True) ## N*1024 165 | elif self.args.position == 'key': 166 | memorized_image_feat_K = memorized_image_feat + self.global_bias_key.unsqueeze(1) ## class*shot*1024 167 | elif self.args.position == 'value': 168 | memorized_image_feat_V = memorized_image_feat + self.global_bias_value.unsqueeze(1) ## class*shot*1024 169 | elif self.args.position == 'qkv' or self.args.position == 'all': 170 | img_feat_mappling = img_feat + self.global_bias.mean(0, keepdim=True) ## N*1024 171 | memorized_image_feat_K = memorized_image_feat + self.global_bias_key.unsqueeze(1) ## class*shot*1024 172 | memorized_image_feat_V = memorized_image_feat + self.global_bias_value.unsqueeze(1) ## class*shot*1024 173 | 174 | memorized_image_feat_K = memorized_image_feat_K / memorized_image_feat_K.norm(dim=-1, keepdim=True) 175 | memorized_image_feat_V = memorized_image_feat_V / memorized_image_feat_V.norm(dim=-1, keepdim=True) 176 | img_feat_mappling = img_feat_mappling / img_feat_mappling.norm(dim=-1, keepdim=True) 177 | ## calculate the cos similarity betweeen image feature and memory feature, and then weighted the memorized probability. 178 | ## 200*11*200; 179 | similarity_matrix = memorized_image_feat_K @ img_feat_mappling.T ## class*shot*Batch 180 | similarity_matrix = torch.exp(-self.beta * (-similarity_matrix + 1)) 181 | adaptive_image_feat = memorized_image_feat_V.transpose(1,2) @ similarity_matrix ## class * D * batch, 102*1024*204 182 | adaptive_image_feat = adaptive_image_feat / adaptive_image_feat.norm(dim=1, keepdim=True) 183 | logit_scale = model.logit_scale.exp() 184 | adaptive_image_feat = adaptive_image_feat.transpose(0,2).transpose(1,2) ## 204*102*1024 185 | if self.args.position == 'output' or self.args.position == 'all': 186 | adaptive_image_feat = adaptive_image_feat + self.global_ffn_bias.unsqueeze(0) ## class*shot*1024 187 | 188 | adaptive_image_feat = adaptive_image_feat / adaptive_image_feat.norm(dim=-1, keepdim=True) 189 | # ipdb.set_trace() 190 | # adaptive_image_feat: 1*102*1024 191 | # img_feat: 1*1024 192 | logits = logit_scale * adaptive_image_feat[..., self.args.indice] @ img_feat[..., self.args.indice].unsqueeze(-1) ## memoried features are not updated. 193 | if return_logit: 194 | return logits[:,:,0] 195 | else: 196 | return logits[:,:,0].softmax(dim=1) 197 | 198 | def get_text_prediction(self, model, return_full=True, return_logit=False): 199 | logit_scale = model.logit_scale.exp() 200 | if self.args.position == 'output' or self.args.position == 'all': 201 | text_feat = model.text_feat + self.text_bias 202 | else: 203 | text_feat = model.text_feat 204 | text_feat = text_feat / text_feat.norm(dim=1, keepdim=True) ## already filtered with indice. 
205 | img_text_logit = logit_scale * model.image_features_global @ text_feat.t() ## 128*200 206 | if return_full: 207 | pass 208 | else: 209 | img_text_logit = img_text_logit[:1] 210 | if return_logit: 211 | return img_text_logit 212 | else: 213 | return img_text_logit.softmax(-1) 214 | -------------------------------------------------------------------------------- /DMN/DMN_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import time 4 | import random 5 | 6 | import numpy as np 7 | 8 | import shutil 9 | from enum import Enum 10 | 11 | import torch 12 | import torchvision.transforms as transforms 13 | import torch.nn.functional as F 14 | import torch.nn as nn 15 | 16 | def set_random_seed(seed): 17 | random.seed(seed) 18 | np.random.seed(seed) 19 | torch.manual_seed(seed) 20 | torch.cuda.manual_seed_all(seed) 21 | 22 | class Summary(Enum): 23 | NONE = 0 24 | AVERAGE = 1 25 | SUM = 2 26 | COUNT = 3 27 | 28 | class AverageMeter(object): 29 | """Computes and stores the average and current value""" 30 | def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE): 31 | self.name = name 32 | self.fmt = fmt 33 | self.summary_type = summary_type 34 | self.reset() 35 | 36 | def reset(self): 37 | self.val = 0 38 | self.avg = 0 39 | self.sum = 0 40 | self.count = 0 41 | 42 | def update(self, val, n=1): 43 | self.val = val 44 | self.sum += val * n 45 | self.count += n 46 | self.avg = self.sum / self.count 47 | 48 | def __str__(self): 49 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 50 | return fmtstr.format(**self.__dict__) 51 | 52 | def summary(self): 53 | fmtstr = '' 54 | if self.summary_type is Summary.NONE: 55 | fmtstr = '' 56 | elif self.summary_type is Summary.AVERAGE: 57 | fmtstr = '{name} {avg:.3f}' 58 | elif self.summary_type is Summary.SUM: 59 | fmtstr = '{name} {sum:.3f}' 60 | elif self.summary_type is Summary.COUNT: 61 | fmtstr = '{name} {count:.3f}' 62 | else: 63 | raise ValueError('invalid summary type %r' % self.summary_type) 64 | 65 | return fmtstr.format(**self.__dict__) 66 | 67 | 68 | class ProgressMeter(object): 69 | def __init__(self, num_batches, meters, prefix=""): 70 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 71 | self.meters = meters 72 | self.prefix = prefix 73 | 74 | def display(self, batch): 75 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 76 | entries += [str(meter) for meter in self.meters] 77 | print('\t'.join(entries)) 78 | 79 | def display_summary(self): 80 | entries = [" *"] 81 | entries += [meter.summary() for meter in self.meters] 82 | print(' '.join(entries)) 83 | 84 | def _get_batch_fmtstr(self, num_batches): 85 | num_digits = len(str(num_batches // 1)) 86 | fmt = '{:' + str(num_digits) + 'd}' 87 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 88 | 89 | def accuracy(output, target, topk=(1,)): 90 | """Computes the accuracy over the k top predictions for the specified values of k""" 91 | with torch.no_grad(): 92 | maxk = max(topk) 93 | batch_size = target.size(0) 94 | _, pred = output.topk(maxk, 1, True, True) 95 | pred = pred.t() 96 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 97 | 98 | res = [] 99 | for k in topk: 100 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 101 | res.append(correct_k.mul_(100.0 / batch_size)) 102 | return res 103 | 104 | 105 | def load_model_weight(load_path, model, device, args): 106 | if os.path.isfile(load_path): 107 | print("=> loading checkpoint 
'{}'".format(load_path)) 108 | checkpoint = torch.load(load_path, map_location=device) 109 | state_dict = checkpoint['state_dict'] 110 | # Ignore fixed token vectors 111 | if "token_prefix" in state_dict: 112 | del state_dict["token_prefix"] 113 | 114 | if "token_suffix" in state_dict: 115 | del state_dict["token_suffix"] 116 | 117 | args.start_epoch = checkpoint['epoch'] 118 | try: 119 | best_acc1 = checkpoint['best_acc1'] 120 | except: 121 | best_acc1 = torch.tensor(0) 122 | if device != 'cpu': 123 | # best_acc1 may be from a checkpoint from a different GPU 124 | best_acc1 = best_acc1.to(device) 125 | try: 126 | model.load_state_dict(state_dict) 127 | except: 128 | # TODO: implement this method for the generator class 129 | model.prompt_generator.load_state_dict(state_dict, strict=False) 130 | print("=> loaded checkpoint '{}' (epoch {})" 131 | .format(load_path, checkpoint['epoch'])) 132 | del checkpoint 133 | torch.cuda.empty_cache() 134 | else: 135 | print("=> no checkpoint found at '{}'".format(load_path)) 136 | 137 | 138 | def validate(val_loader, model, criterion, args, output_mask=None): 139 | batch_time = AverageMeter('Time', ':6.3f', Summary.NONE) 140 | losses = AverageMeter('Loss', ':.4e', Summary.NONE) 141 | top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE) 142 | top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE) 143 | progress = ProgressMeter( 144 | len(val_loader), 145 | [batch_time, losses, top1, top5], 146 | prefix='Test: ') 147 | 148 | # switch to evaluate mode 149 | model.eval() 150 | 151 | with torch.no_grad(): 152 | end = time.time() 153 | for i, (images, target) in enumerate(val_loader): 154 | if args.gpu is not None: 155 | images = images.cuda(args.gpu, non_blocking=True) 156 | if torch.cuda.is_available(): 157 | target = target.cuda(args.gpu, non_blocking=True) 158 | 159 | # compute output 160 | with torch.cuda.amp.autocast(): 161 | output = model(images) 162 | if output_mask: 163 | output = output[:, output_mask] 164 | loss = criterion(output, target) 165 | 166 | # measure accuracy and record loss 167 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 168 | losses.update(loss.item(), images.size(0)) 169 | top1.update(acc1[0], images.size(0)) 170 | top5.update(acc5[0], images.size(0)) 171 | 172 | # measure elapsed time 173 | batch_time.update(time.time() - end) 174 | end = time.time() 175 | 176 | if i % args.print_freq == 0: 177 | progress.display(i) 178 | progress.display_summary() 179 | 180 | return top1.avg 181 | 182 | class SmoothCrossEntropy(nn.Module): 183 | def __init__(self, alpha=0.0): 184 | super(SmoothCrossEntropy, self).__init__() 185 | self.alpha = alpha 186 | 187 | def forward(self, logits, labels): 188 | num_classes = logits.shape[-1] 189 | alpha_div_k = self.alpha / num_classes 190 | target_probs = F.one_hot(labels, num_classes=num_classes).float() * \ 191 | (1. 
- self.alpha) + alpha_div_k 192 | loss = -(target_probs * torch.log_softmax(logits, dim=-1)).sum(dim=-1) 193 | return loss.mean() 194 | 195 | def select_confident_samples(prob, top): 196 | # ipdb.set_trace() 197 | batch_entropy = -(prob * torch.log(prob + 1e-6)).sum(1) 198 | idx = torch.argsort(batch_entropy, descending=False)[:int(batch_entropy.size()[0] * top)] ## pick the min entropy 199 | idx_confused = torch.argsort(batch_entropy, descending=False)[int(batch_entropy.size()[0] * top):] ## pick the max entropy 200 | return prob[idx], idx, prob[idx_confused], idx_confused 201 | 202 | 203 | def cls_acc(output, target, topk=1): 204 | pred = output.topk(topk, 1, True, True)[1].t() 205 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 206 | acc = float(correct[: topk].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) 207 | acc = 100 * acc / target.shape[0] 208 | return acc 209 | 210 | def init_image_memory(train_loader, model, args): 211 | model.eval() 212 | if model.first_flag: 213 | with torch.no_grad(): 214 | text_feat, text_feat_full = model.get_text_features() 215 | else: 216 | print('the text feat has already initilized, pass it here.') 217 | memorized_image_global_feat = [] ## N*[shot*aug]*C 218 | memorized_image_local_feat = [] ## N*[shot*aug]*C 219 | memorized_image_global_feat_vanilla = [] ## N*shot*C 220 | memorized_image_local_feat_vanilla = [] ## N*shot*C 221 | memorized_labels = [] 222 | 223 | for i in range(model.n_cls): 224 | memorized_image_global_feat.append([]) 225 | memorized_image_local_feat.append([]) 226 | memorized_image_global_feat_vanilla.append([]) 227 | memorized_image_local_feat_vanilla.append([]) 228 | memorized_labels.append([]) 229 | 230 | for i, (images, target) in enumerate(train_loader): 231 | assert args.gpu is not None 232 | if isinstance(images, list): ### augmix return, list 233 | images = torch.cat(images, dim=0) 234 | images = images.cuda(args.gpu, non_blocking=True) 235 | else: ## standard return, Tensor 236 | if len(images.size()) > 4: 237 | # when using ImageNet Sampler as the dataset 238 | assert images.size()[0] == 1 239 | images = images.squeeze(0) 240 | images = images.cuda(args.gpu, non_blocking=True) 241 | target = target.cuda(args.gpu, non_blocking=True) 242 | with torch.no_grad(): 243 | image_features_global, image_features_local = model.get_image_features(images) ## 4*1024; 4*49*1024. 244 | text_features = model.text_feat[target] ## 512 245 | ## only use the original ?? we should use all; however, only use the vanilla one in the dynamic memory. 246 | selected_image_features_local = model.image_features_local 247 | cos_sim = (selected_image_features_local * text_features).sum(-1) ## between 0.2-0.3, very close. 248 | weight_prob = (cos_sim * 100).softmax(-1) ## 1*197, following clip temperature. 249 | ######## 250 | attented_feat = (weight_prob.unsqueeze(-1) * selected_image_features_local).sum(1) ## 1*512 251 | attented_feat = attented_feat / attented_feat.norm(dim=-1, keepdim=True) ## 1*512 252 | memorized_image_global_feat[target].append(image_features_global) ## aug*C 253 | memorized_image_local_feat[target].append(attented_feat) # aug * C 254 | memorized_image_global_feat_vanilla[target].append(image_features_global[:1]) ## aug*C 255 | memorized_image_local_feat_vanilla[target].append(attented_feat[:1]) # aug * C 256 | one_hot_target = torch.zeros(1, model.n_cls).to(target.device) 257 | one_hot_target[0, target] = 1 258 | memorized_labels[target].append(one_hot_target) ## 1 * C, turn it to one hot labels. 
259 | 260 | for i in range(model.n_cls): 261 | memorized_image_global_feat[i] = torch.cat(memorized_image_global_feat[i], dim=0).unsqueeze(0) ## 1*augshot*C 262 | memorized_image_local_feat[i] = torch.cat(memorized_image_local_feat[i], dim=0).unsqueeze(0) 263 | memorized_image_global_feat_vanilla[i] = torch.cat(memorized_image_global_feat_vanilla[i], dim=0).unsqueeze(0) ## 1*shot*C 264 | memorized_image_local_feat_vanilla[i] = torch.cat(memorized_image_local_feat_vanilla[i], dim=0).unsqueeze(0) 265 | memorized_labels[i] = torch.cat(memorized_labels[i], dim=0).unsqueeze(0) 266 | 267 | memorized_image_global_feat = torch.cat(memorized_image_global_feat, dim=0) ## n*shot*c 268 | memorized_image_local_feat = torch.cat(memorized_image_local_feat, dim=0) 269 | memorized_image_global_feat_vanilla = torch.cat(memorized_image_global_feat_vanilla, dim=0) ## n*shot*c 270 | memorized_image_local_feat_vanilla = torch.cat(memorized_image_local_feat_vanilla, dim=0) 271 | memorized_labels = torch.cat(memorized_labels, dim=0) 272 | 273 | ######## memorized few shot features and labels. 274 | model.fewshot_image_global_feat = memorized_image_global_feat ## class*augshot*c 275 | model.fewshot_image_local_feat = memorized_image_local_feat 276 | model.fewshot_image_global_feat_vanilla = memorized_image_global_feat_vanilla ## class*shot*c 277 | model.fewshot_image_local_feat_vanilla = memorized_image_local_feat_vanilla 278 | model.fewshot_label = memorized_labels ## class*shot*c, one hot labels 279 | 280 | ############# add features of labeled data to the dynamic memory. This is important when there are more labeled data. 281 | model.fixed_global_feat_vanilla = torch.cat((model.fixed_global_feat, memorized_image_global_feat_vanilla), dim=1) ## N*1*C 282 | model.fixed_local_feat_vanilla = torch.cat((model.fixed_local_feat, memorized_image_local_feat_vanilla), dim=1) ## N*1*C 283 | 284 | ###################### for static memory, with text feature and augmented image feat 285 | model.fixed_global_feat = torch.cat((model.fixed_global_feat, memorized_image_global_feat), dim=1) ## N*1*C 286 | model.fixed_local_feat = torch.cat((model.fixed_local_feat, memorized_image_local_feat), dim=1) ## N*1*C 287 | 288 | print('appending the few shot image feature to fixed image memories.') 289 | 290 | #%% data utils 291 | from PIL import Image 292 | try: 293 | from torchvision.transforms import InterpolationMode 294 | BICUBIC = InterpolationMode.BICUBIC 295 | except ImportError: 296 | BICUBIC = Image.BICUBIC 297 | 298 | from . import DMN_augmix_ops as augmentations 299 | from . 
import DMN_randaugment as RandAugmentMC 300 | 301 | ID_to_DIRNAME={ 302 | 'I': 'ImageNet', 303 | 'A': 'imagenet-a', 304 | 'K': 'ImageNet-Sketch', 305 | 'R': 'imagenet-r', 306 | 'V': 'imagenetv2-matched-frequency-format-val', 307 | 'flower102': 'oxford_flowers', 308 | 'dtd': 'dtd', 309 | 'pets': 'oxford_pets/images', 310 | 'cars': 'stanford_cars', 311 | 'ucf101': 'ucf101/UCF-101-midframes', 312 | 'caltech101': 'caltech-101/caltech-101/101_ObjectCategories', 313 | 'food101': 'food-101', 314 | 'sun397': 'sun397/SUN397', 315 | 'aircraft': 'fgvc_aircraft', 316 | 'eurosat': 'eurosat/2750' 317 | } 318 | 319 | 320 | # AugMix Transforms 321 | def get_preaugment(): 322 | return transforms.Compose([ 323 | transforms.RandomResizedCrop(224, scale=(0.5, 1)), 324 | # transforms.Resize(256, interpolation=BICUBIC), 325 | # transforms.RandomCrop(224), 326 | transforms.RandomHorizontalFlip(), 327 | ]) 328 | 329 | def augmix(image, preprocess, aug_list, severity=1): 330 | preaugment = get_preaugment() 331 | x_orig = preaugment(image) 332 | x_processed = preprocess(x_orig) 333 | if len(aug_list) == 0: 334 | return x_processed 335 | w = np.float32(np.random.dirichlet([1.0, 1.0, 1.0])) 336 | m = np.float32(np.random.beta(1.0, 1.0)) 337 | 338 | mix = torch.zeros_like(x_processed) 339 | for i in range(3): 340 | x_aug = x_orig.copy() 341 | for _ in range(np.random.randint(1, 4)): 342 | x_aug = np.random.choice(aug_list)(x_aug, severity) 343 | mix += w[i] * preprocess(x_aug) 344 | mix = m * x_processed + (1 - m) * mix 345 | return mix 346 | 347 | 348 | class AugMixAugmenter(object): 349 | def __init__(self, base_transform, preprocess, n_views=2, augmix=False, 350 | severity=1): 351 | self.base_transform = base_transform 352 | self.preprocess = preprocess 353 | self.n_views = n_views 354 | if augmix: 355 | self.aug_list = augmentations.augmentations 356 | else: 357 | self.aug_list = [] 358 | self.severity = severity 359 | 360 | def __call__(self, x): 361 | image = self.preprocess(self.base_transform(x)) 362 | views = [augmix(x, self.preprocess, self.aug_list, self.severity) for _ in range(self.n_views)] 363 | return [image] + views 364 | 365 | ####################################### weak augmentation for memorized images. 
366 | def get_preaugment_augmem(): 367 | return transforms.Compose([ 368 | # transforms.Resize(230, interpolation=BICUBIC), 369 | # transforms.RandomCrop(224), 370 | transforms.RandomResizedCrop(224, scale=(0.5, 1)), 371 | transforms.RandomHorizontalFlip(), 372 | ]) 373 | 374 | 375 | def augmem(image, preprocess, aug_list, severity=1): 376 | preaugment = get_preaugment_augmem() 377 | x_orig = preaugment(image) 378 | x_processed = preprocess(x_orig) 379 | return x_processed 380 | 381 | class AugMemAugmenter(object): 382 | def __init__(self, base_transform, preprocess, n_views=2, augmix=False, 383 | severity=1): 384 | self.base_transform = base_transform 385 | self.preprocess = preprocess 386 | self.n_views = n_views 387 | if augmix: 388 | self.aug_list = augmentations.augmentations 389 | else: 390 | self.aug_list = [] 391 | self.severity = severity 392 | 393 | def __call__(self, x): 394 | image = self.preprocess(self.base_transform(x)) 395 | views = [augmem(x, self.preprocess, self.aug_list, self.severity) for _ in range(self.n_views)] 396 | return [image] + views 397 | 398 | def randaug(image, preprocess, strong_aug): 399 | preaugment = get_preaugment() 400 | x_orig = preaugment(image) 401 | x_orig = strong_aug(x_orig) 402 | x_processed = preprocess(x_orig) 403 | return x_processed 404 | 405 | class StrongAugmenter(object): 406 | def __init__(self, base_transform, preprocess, n_views=2, augmix=False, 407 | severity=1): 408 | self.base_transform = base_transform 409 | self.preprocess = preprocess 410 | self.n_views = n_views 411 | if augmix: 412 | self.aug_list = augmentations.augmentations 413 | else: 414 | self.aug_list = [] 415 | self.severity = severity 416 | self.strong_aug = RandAugmentMC(n=2, m=10) 417 | 418 | def __call__(self, x): 419 | preaugment = get_preaugment() 420 | x_orig = preaugment(x) 421 | image = self.preprocess(x_orig) 422 | 423 | # image = augmix(x, self.preprocess, self.aug_list, self.severity) 424 | # image = randaug(x, self.preprocess, self.strong_aug) 425 | 426 | return image 427 | 428 | class StrongAugmenterRand(object): 429 | def __init__(self, base_transform, preprocess, n_views=2, augmix=False, 430 | severity=1): 431 | self.base_transform = base_transform 432 | self.preprocess = preprocess 433 | self.n_views = n_views 434 | if augmix: 435 | self.aug_list = augmentations.augmentations 436 | else: 437 | self.aug_list = [] 438 | self.severity = severity 439 | self.strong_aug = RandAugmentMC(n=2, m=10) 440 | 441 | def __call__(self, x): 442 | rand_num = random.random() 443 | if rand_num < 0.5: 444 | preaugment = get_preaugment() 445 | x_orig = preaugment(x) 446 | image = self.preprocess(x_orig) 447 | else: 448 | image = augmix(x, self.preprocess, self.aug_list, self.severity) 449 | # else: 450 | # image = randaug(x, self.preprocess, self.strong_aug) 451 | 452 | return image 453 | -------------------------------------------------------------------------------- /DMN/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cfuchs2023/OGA/b91d449a28d0958849c43dc7f698405b9fcfe4b4/DMN/__init__.py -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Clement Fuchs 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the 
Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /OGA/OGA_core.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | class GaussAdapt(torch.nn.Module): 4 | def __init__(self, clip_prototypes, shot_capacity = 8, sig_type = 'RidgeMoorePenrose'): 5 | '''shot_capacity: maximum number of stored samples per class. 6 | sig_type: type of estimator for teh covariance. One of 'Ridge', 'MoorePenrose' or the recommended 'RidgeMoorePenrose'. 7 | The latter transitions from empirical Bayes Ridge (see https://doi.org/10.1016/j.jmva.2008.01.016) to inverse when more than 4d sampels are available. ''' 8 | super(GaussAdapt, self).__init__() 9 | assert sig_type in ['RidgeMoorePenrose', 'Ridge', 'MoorePenrose'] 10 | K,d = clip_prototypes.shape 11 | self.shot_capacity = shot_capacity 12 | self.K = K 13 | self.clip_prototypes = clip_prototypes #should be (K,d) 14 | self.mus = clip_prototypes.clone().type(torch.float32) 15 | self.temp = 100 16 | self.d = clip_prototypes.shape[-1] 17 | self.count = torch.nn.Parameter(torch.zeros(self.K), requires_grad = False) 18 | self.sig_type = sig_type 19 | self.Sig = torch.nn.Parameter(1/d * torch.eye(d, dtype = torch.float32), requires_grad = False) 20 | self.inv_Sig = torch.nn.Parameter(d * torch.eye(d, dtype = torch.float32), requires_grad = False) 21 | self.memory_state = torch.nn.Parameter(torch.zeros((K,shot_capacity), dtype = torch.bool), requires_grad = False) 22 | 23 | self.memory = torch.nn.Parameter(torch.zeros((K, shot_capacity, d), dtype = torch.float16), 24 | requires_grad = False) 25 | self.memory_soft_labels = torch.nn.Parameter(torch.zeros((K, shot_capacity), dtype = torch.float16), 26 | requires_grad = False) 27 | self.__init_entropy(prop_max = 1) 28 | 29 | return None 30 | 31 | def __init_entropy(self, prop_max = 1): 32 | max_entropy = -torch.log(torch.tensor(1/self.K)) 33 | init_val = prop_max * max_entropy 34 | self.memory_entropy = torch.nn.Parameter(init_val * torch.ones((self.K,self.shot_capacity), dtype = torch.float16, device = self.memory.device), 35 | requires_grad = False) 36 | return init_val 37 | 38 | 39 | def get_entropy(self, probs): 40 | sh_entropy = - torch.sum(torch.log(probs+1e-6)*probs, dim = -1) 41 | return sh_entropy 42 | 43 | def __update_memory_entropy(self, x, text_prob, entropy, pseudo_label, gauss_prob = None): 44 | updated = False 45 | if torch.any(entropy=2 # was >2 88 | self.mus[mask,:] = means[mask,:].type(torch.float32) 89 | if normalize_mu: 90 | self.mus[mask,:] = self.mus[mask,:] / 
torch.linalg.norm(self.mus[mask,:], dim = -1, keepdims = True) 91 | return None 92 | 93 | 94 | def __update_sigma(self, use_soft_labels = False): 95 | if 'Ridge' == self.sig_type: 96 | d = self.mus.shape[-1] 97 | x = self.memory.view((self.K*self.shot_capacity, d)) 98 | x_mem_state = self.memory_state.view((self.K*self.shot_capacity)) 99 | if torch.any(torch.sum(self.memory_state, dim = -1)>2): 100 | x_labels = torch.tensor([k for k in range(self.K) for _ in range(self.shot_capacity)], device = x.device) 101 | center_vecs = torch.cat([x[torch.logical_and(x_mem_state, x_labels == k)] - self.mus[k:k+1,:] for k in range(self.K)]) 102 | M = center_vecs.T.cov() 103 | trace = torch.sum(M[range(d), range(d)]) 104 | # shape 1 = d / shape 0 = n 105 | n,d = center_vecs.shape 106 | cov_inv = d * torch.linalg.pinv((n - 1) * M + trace * torch.eye(d, device = center_vecs.device)) 107 | self.Sig[...] = M 108 | self.inv_Sig[...] = cov_inv 109 | elif 'RidgeMoorePenrose' == self.sig_type: 110 | d = self.mus.shape[-1] 111 | n = torch.sum(self.memory_state) 112 | 113 | if torch.any(torch.sum(self.memory_state, dim = -1)>2): 114 | x = self.memory.view((self.K*self.shot_capacity, d)) 115 | x_labels = torch.tensor([k for k in range(self.K) for _ in range(self.shot_capacity)], device = x.device) 116 | x_mem_state = self.memory_state.view((self.K*self.shot_capacity)) 117 | 118 | class_probs = self.memory_soft_labels[self.memory_state] 119 | center_vecs = torch.cat([x[torch.logical_and(x_mem_state, x_labels == k)] - self.mus[k:k+1,:] for k in range(self.K)]) 120 | center_vec_mean = center_vecs.mean(dim=0) 121 | if use_soft_labels: 122 | #M = center_vecs.T.cov(correction=1) 123 | c_center_vecs = (center_vecs - center_vec_mean[None,:]) * class_probs[:,None] 124 | M = c_center_vecs.T @ c_center_vecs / torch.sum(class_probs) 125 | else: 126 | c_center_vecs = (center_vecs - center_vec_mean[None,:]) 127 | M = c_center_vecs.T @ c_center_vecs / (n-1) 128 | 129 | if n<=4*d: 130 | # use shrinkage 131 | 132 | trace = torch.sum(M[range(d), range(d)]) 133 | # shape 1 = d / shape 0 = n 134 | cov_inv = d * torch.linalg.pinv((n - 1) * M + trace * torch.eye(d, device = center_vecs.device)) 135 | self.Sig[...] = M 136 | self.inv_Sig[...] = cov_inv 137 | else: 138 | # Use pinv 139 | self.Sig[...] = M 140 | self.inv_Sig[...] = torch.linalg.pinv(M.type(torch.float32)) 141 | elif 'MoorePenrose' == self.sig_type: 142 | d = self.mus.shape[-1] 143 | x = self.memory.view((self.K*self.shot_capacity, d)) 144 | x_mem_state = self.memory_state.view((self.K*self.shot_capacity)) 145 | if torch.any(torch.sum(self.memory_state, dim = -1)>2): 146 | x_labels = torch.tensor([k for k in range(self.K) for _ in range(self.shot_capacity)], device = x.device) 147 | center_vecs = torch.cat([x[torch.logical_and(x_mem_state, x_labels == k)] - self.mus[k:k+1,:] for k in range(self.K)]) 148 | M = center_vecs.T.cov() 149 | self.Sig[...] = M 150 | self.inv_Sig[...] = torch.linalg.pinv(M.type(torch.float32)) 151 | 152 | return None 153 | 154 | def get_log_probs(self,x): 155 | W = torch.einsum('nd, dc -> cn', self.mus, self.inv_Sig) 156 | b = - torch.einsum('nd, dc, nc -> n', self.mus, self.inv_Sig, self.mus) / 2 157 | Q = - torch.einsum('nd, dc, nc -> n', x.float(), self.inv_Sig, x.float()) / 2 158 | log_probs = (x.float() @ W + b) 159 | log_probs += Q[:,None] 160 | return log_probs 161 | 162 | 163 | def get_MAP(self, y_hat, memory_logits, tau = 0.01, simplex_p = False): 164 | '''y_hat: zero shot soft labels. 
memory_logits: log probabilities obtained from the cached samples. ''' 165 | lambd = 1.0 166 | assert type(tau) is float or type(lambd) is float 167 | # Compute gaussian probs 168 | if type(tau) is float: 169 | if not simplex_p: 170 | p_ = torch.exp(tau * memory_logits) 171 | else: 172 | p_ = (tau*memory_logits).softmax(-1) 173 | else: 174 | if not simplex_p: 175 | p_ = torch.exp(tau[None,None,:] * memory_logits[...,None]) 176 | else: 177 | p_ = (tau[None,None,:] * memory_logits[...,None]).softmax(-1) 178 | 179 | # Compute MAP (only if y_hat is not None) 180 | if y_hat is None: 181 | z = None 182 | else: 183 | if type(lambd) is float: 184 | if len(p_.shape) == 2: 185 | z = (y_hat**lambd) * p_ 186 | z = z/torch.sum(z, dim = 1, keepdims = True) 187 | elif len(p_.shape) == 3: 188 | z = (y_hat**lambd)[...,None] * p_ 189 | z = z/torch.sum(z, dim = 1, keepdims = True) 190 | else: 191 | raise RuntimeError(f'Incompatible p_ shape {p_.shape}') 192 | 193 | else: 194 | if len(p_.shape) == 2: 195 | z = (y_hat[:,:,None]**lambd[None,None,:]) * p_[:,:,None] 196 | elif len(p_.shape) == 3: 197 | z = (y_hat[:,:,None]**lambd[None,None,:])[...,None] * p_[...,None] 198 | else: 199 | raise RuntimeError(f'Incompatible p_ shape {p_.shape}') 200 | return z, p_ -------------------------------------------------------------------------------- /OGA/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cfuchs2023/OGA/b91d449a28d0958849c43dc7f698405b9fcfe4b4/OGA/__init__.py -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Online Gaussian Adaptation of Vision-Language Models (OGA) 2 | 3 | 4 | 5 | 6 | The official repository of the paper [*Online Gaussian Adaptation of Vision-Language Model*](https://arxiv.org/abs/2501.04352). 7 | 8 | Authors: 9 | [Clément Fuchs*](https://scholar.google.com/citations?user=ZXWUJ4QAAAAJ&hl=fr&oi=ao), 10 | [Maxime Zanella*](https://scholar.google.com/citations?user=FIoE9YIAAAAJ&hl=fr&oi=ao), 11 | [Christophe De Vleeschouwer](https://scholar.google.com/citations?user=xb3Zc3cAAAAJ&hl=fr&oi=ao). 12 | 13 | *Denotes equal contribution 14 | 15 | ## Overview 16 | 17 | **OGA** is an online adaptation method which builds a cache of samples with low zero-shot entropy along a data stream. This cache is then used to build a multivariate Gaussian model of the class conditional likelihoods of the observed features, finally computing updated predictions using a pseudo-bayesian Maximum A Posteriori (MAP) estimator. Main results averaged over 11 datasets are summarized in the 2 Tables below. 
18 | 19 | | W/ standard prompts | **ViT-B/16** | **ViT-B/32** | **ViT-L/14** | **ResNet50** | **ResNet101** | 20 | |----------------------|--------------|--------------|--------------|--------------|---------------| 21 | | Zero-Shot | 65.3 | 61.9 | 72.6 | 58.7 | 59.5 | 22 | | TDA | 67.7 `↑2.4` | 62.3 `↑0.4` | 73.5 `↑0.9` | 59.3 `↑0.6` | 60.6 `↑1.1` | 23 | | DMN | 67.5 `↑2.2` | 61.8 `↓0.1` | 73.7 `↑1.1` | 58.6 `↓0.1` | 61.0 `↑1.5` | 24 | | **OGA (ours)** | **68.5 `↑3.2`** | **62.9 `↑1.0`** | **74.3 `↑1.7`** | **59.8 `↑1.1`** | **61.6 `↑2.1`** | 25 | 26 | 27 | | W/ custom prompts | **ViT-B/16** | **ViT-B/32** | **ViT-L/14** | **ResNet50** | **ResNet101** | 28 | |----------------------|--------------|--------------|--------------|--------------|---------------| 29 | | Zero-Shot | 65.6 | 61.4 | 72.2 | 57.4 | 59.0 | 30 | | TDA | 66.9 `↑1.3` | 62.3 `↑0.9` | 73.9 `↑1.7` | 58.1 `↑0.7` | 59.4 `↑0.4` | 31 | | DMN | 66.4 `↑0.8` | 61.6 `↑0.2` | 74.4 `↑2.2` | 57.2 `↓0.2` | 60.3 `↑1.3` | 32 | | **OGA (ours)** | **67.3 `↑1.7`** | **62.8 `↑1.4`** | **74.7 `↑2.5`** | **58.4 `↑1.0`** | **60.6 `↑1.6`** | 33 | 34 | Additionally, we advocate for more rigorous evaluation practices, including increasing the number of runs and considering additional quantitative metrics, such as our proposed Expected Tail Accuracy (ETA), calculated as the average accuracy in the worst 10% of runs. See illustration below. 35 | 36 |

37 | ![Bar plot](images/abstract_barplot_github_version.png) 38 |
39 | Figure 1. The presented results are averaged over 100 runs. We propose the Expected Tail Accuracy (ETA), i.e., the average accuracy over the 10% worst runs, shown as a solid red line. Our method, OGA, not only significantly outperforms competitors on average, but its ETA also exceeds the average accuracy of competitors on several datasets (e.g., ImageNet and Pets). See our paper: https://arxiv.org/abs/2501.04352 40 |

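For intuition, here is a minimal sketch of the update performed at test time: class-conditional Gaussian log-likelihoods with a shared covariance are fused with the zero-shot probabilities through a MAP-style product, mirroring `get_log_probs` and `get_MAP` in `OGA/OGA_core.py`. The tensors and the values of `tau` and `lambd` below are illustrative toys, not the settings or data pipeline used by `main.py`.

```python
import torch

torch.manual_seed(0)
K, d, n = 5, 16, 8                                              # toy sizes: classes, feature dim, batch
x = torch.nn.functional.normalize(torch.randn(n, d), dim=-1)    # query features
mus = torch.nn.functional.normalize(torch.randn(K, d), dim=-1)  # class means estimated from the cache
inv_sig = torch.eye(d)                                          # shared inverse covariance (identity here)
y_hat = torch.softmax(100. * x @ mus.T, dim=-1)                 # stand-in for the zero-shot probabilities

# Gaussian log-likelihood up to a constant:
# log p(x | k) = x^T S^-1 mu_k - 0.5 mu_k^T S^-1 mu_k - 0.5 x^T S^-1 x
W = torch.einsum('kd,dc->ck', mus, inv_sig)
b = -0.5 * torch.einsum('kd,dc,kc->k', mus, inv_sig, mus)
Q = -0.5 * torch.einsum('nd,dc,nc->n', x, inv_sig, x)
log_probs = x @ W + b + Q[:, None]

# MAP-style fusion of zero-shot probabilities and Gaussian likelihoods
tau, lambd = 0.01, 1.0                                          # illustrative values only
z = (y_hat ** lambd) * torch.exp(tau * log_probs)
z = z / z.sum(dim=1, keepdim=True)                              # updated class probabilities
print(z.argmax(dim=-1))
```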
41 | 42 | The repository also includes a lightweight implementation of [TDA](https://openaccess.thecvf.com/content/CVPR2024/html/Karmanov_Efficient_Test-Time_Adaptation_of_Vision-Language_Models_CVPR_2024_paper.html) and [DMN](https://openaccess.thecvf.com/content/CVPR2024/html/Zhang_Dual_Memory_Networks_A_Versatile_Adaptation_Approach_for_Vision-Language_Models_CVPR_2024_paper.html) for training free / zero-shot adaptation without test-time augmentations. 43 | 44 | 45 | ## Dependencies 46 | The repository is dependent on [PyTorch](https://pytorch.org/) and [openai-clip](https://pypi.org/project/openai-clip/). 47 | ## Datasets 48 | Please follow [DATASETS.md](DATASETS.md) to install the datasets. 49 | You will get a structure with the following dataset names: 50 | ``` 51 | $DATA/ 52 | |–– caltech-101/ 53 | |–– oxford_pets/ 54 | |–– stanford_cars/ 55 | |–– oxford_flowers/ 56 | |–– food-101/ 57 | |–– fgvc_aircraft/ 58 | |–– sun397/ 59 | |–– dtd/ 60 | |–– eurosat/ 61 | |–– ucf101/ 62 | |–– imagenet/ 63 | ``` 64 | ## Running benchmarks 65 | ### Computing and storing features 66 | The benchmarks are run using pre-computed features, as none of the available methods update the vision encoder. 67 | First, use compute_features.py to compute and store features and labels. 68 | Example : 69 | ```bash 70 | python compute_features.py --data_root_path "E:/DATA" --backbone "vit_b16" --datasets 'sun397' 'imagenet' 'fgvc_aircraft' 'eurosat' 'food101' 'caltech101' 'oxford_pets' 'oxford_flowers' 'stanford_cars' 'dtd' 'ucf101' 71 | ``` 72 | /!\ Warning: The above command line overwrites previous features for the current architecture. 73 | The features and targets are stored in a "cache" subfolder within each dataset folder. It should look like 74 | ``` 75 | $DATA/ 76 | |–– caltech-101/ 77 | |--cache/ 78 | |–– oxford_pets/ 79 | |--cache/ 80 | |–– stanford_cars/ 81 | |--cache/ 82 | ... 83 | ``` 84 | ### Using trained prompts 85 | In our paper, we present results obtained atop [TaskRes](https://openaccess.thecvf.com/content/CVPR2023/html/Yu_Task_Residual_for_Tuning_Vision-Language_Models_CVPR_2023_paper.html) and [CoOp](https://link.springer.com/article/10.1007/s11263-022-01653-1). To reproduce the relevant results, you need to download the pre-computed prototypes from [TransCLIP](https://github.com/MaxZanella/transduction-for-vlms). Go to "Pre-computed prototypes" and download the 'Few_shot' folder from the provided drive. Place it in $DATA/clip_tuned_prompts/. 86 | It should look like 87 | ``` 88 | $DATA/ 89 | |–– clip_tuned_prompts/ 90 | |--Few_shot/ 91 | ... 92 | ``` 93 | 94 | ### Benchmarks 95 | Results presented in our paper can be reproduced using main.py. Results are stored in a .json (for quantities such as average batch accuracy per dataset) and a .pickle (for detailed results such as accuracy per batch), at $DATA/results/. 96 | The randomness is controlled by the parameters --master_seed and --n_runs. For a same tuple of (master_seed, n_runs), the runs generated are always the same. Note that you may still observe slight variations in results depending on your CUDA and PyTorch versions or hardware specifications. 
97 | Example : 98 | ```bash 99 | python main.py --data_root_path "E:/DATA" --adapt_method_name "TDA" --datasets 'sun397' 'imagenet' 'fgvc_aircraft' 'eurosat' 'food101' 'caltech101' 'oxford_pets' 'oxford_flowers' 'stanford_cars' 'dtd' 'ucf101' 100 | ``` 101 | 102 | ## Citation 103 | 104 | If you find this repository useful, please consider citing our paper: 105 | ``` 106 | @article{fuchs2025online, 107 | title={Online Gaussian Test-Time Adaptation of Vision-Language Models}, 108 | author={Fuchs, Cl{\'e}ment and Zanella, Maxime and De Vleeschouwer, Christophe} 109 | journal={arXiv preprint arXiv:2501.04352}, 110 | year={2025} 111 | } 112 | ``` 113 | 114 | ## Contact 115 | 116 | For any inquiries, please contact us at [clement.fuchs@uclouvain.be](mailto:clement.fuchs@uclouvain.be) and [maxime.zanella@uclouvain.be](mailto:maxime.zanella@uclouvain.be) or feel free to [create an issue](https://github.com/cfuchs2023/OGA/issues). 117 | 118 | ## Acknowledgment 119 | This repository is mainly based on [CLIP](https://github.com/openai/CLIP) and [TransCLIP](https://github.com/MaxZanella/transduction-for-vlms). 120 | 121 | -------------------------------------------------------------------------------- /TDA/TDA_core.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn.functional as F 4 | import operator 5 | from .TDA_utils import * 6 | 7 | #%% 8 | class TDA_PosCache(torch.nn.Module): 9 | def __init__(self, K, d, shot_capacity = 8): 10 | super(TDA_PosCache, self).__init__() 11 | self.shot_capacity = shot_capacity 12 | self.memory = torch.nn.Parameter(torch.zeros((K,shot_capacity,d), dtype = torch.float16), 13 | requires_grad = False) 14 | self.memory_entropy = torch.nn.Parameter(1e3 * torch.ones((K,shot_capacity), dtype = torch.float16), 15 | requires_grad = False) 16 | self.memory_state = torch.nn.Parameter(torch.zeros((K,shot_capacity), dtype = torch.bool), requires_grad = False) 17 | self.K = K 18 | self.d = d 19 | self.shot_capacity = shot_capacity 20 | # self.memory_indexes = torch.nn.Parameter(torch.zeros((K, shot_capacity), dtype = torch.int32), 21 | # requires_grad = False) # for sanity checks 22 | self.init_entropy(prop_max = 1) 23 | return None 24 | 25 | 26 | def init_entropy(self, prop_max = 1): 27 | max_entropy = -torch.log(torch.tensor(1/self.K).to(self.memory_entropy.device)) 28 | init_val = prop_max * max_entropy 29 | self.memory_entropy = torch.nn.Parameter(init_val * torch.ones((self.K,self.shot_capacity), dtype = torch.float16, device = self.memory.device), 30 | requires_grad = False) 31 | return init_val 32 | def __update_memory(self, x, text_logit): 33 | text_prob = text_logit.softmax(-1) 34 | entropy = self.get_entropy(text_prob).item() 35 | text_label = torch.argmax(text_logit, dim = -1) 36 | updated = False 37 | if torch.any(entropy self.lower_probability_bound: 99 | entropy = self.get_entropy(text_prob).item() 100 | if entropy > self.lower_entropy_bound and entropy < self.upper_entropy_bound: 101 | if torch.any(entropy neg_mask_thresholds[0]) & (cache_values < neg_mask_thresholds[1])).type(torch.int8)).cuda().half() 161 | else: 162 | cache_values = (F.one_hot(torch.Tensor(cache_values).to(torch.int64), num_classes=clip_weights.size(1))).cuda().half() 163 | 164 | affinity = image_features @ cache_keys 165 | cache_logits = ((-1) * (beta - beta * affinity)).exp() @ cache_values 166 | return alpha * cache_logits 167 | 168 | def compute_tda_logits(pos_cfg, neg_cfg, query_features, 
clip_weights, 169 | pos_cache = {}, neg_cache = {}, do_update_cache = False): 170 | pos_enabled, neg_enabled = pos_cfg['enabled'], neg_cfg['enabled'] 171 | if pos_enabled: 172 | pos_params = {k: pos_cfg[k] for k in ['shot_capacity', 'alpha', 'beta']} 173 | if neg_enabled: 174 | neg_params = {k: neg_cfg[k] for k in ['shot_capacity', 'alpha', 'beta', 'entropy_threshold', 'mask_threshold']} 175 | with torch.no_grad(): 176 | clip_logits, loss, prob_map, pred = get_clip_logits(query_features, clip_weights) 177 | prop_entropy = get_entropy(loss, clip_weights) 178 | if do_update_cache: 179 | if pos_enabled: 180 | update_cache(pos_cache, pred, [query_features, loss], pos_params['shot_capacity']) 181 | 182 | if neg_enabled and neg_params['entropy_threshold']['lower'] < prop_entropy < neg_params['entropy_threshold']['upper']: 183 | update_cache(neg_cache, pred, [query_features, loss, prob_map], neg_params['shot_capacity'], True) 184 | 185 | final_logits = clip_logits.clone() 186 | if pos_enabled and pos_cache: 187 | final_logits += compute_cache_logits(query_features, pos_cache, pos_params['alpha'], pos_params['beta'], clip_weights) 188 | if neg_enabled and neg_cache: 189 | final_logits -= compute_cache_logits(query_features, neg_cache, neg_params['alpha'], neg_params['beta'], clip_weights, (neg_params['mask_threshold']['lower'], neg_params['mask_threshold']['upper'])) 190 | return final_logits, pos_cache, neg_cache 191 | 192 | 193 | def run_test_tda(pos_cfg, neg_cfg, query_features, query_labels, clip_weights, 194 | pos_cache = {}, neg_cache = {}): 195 | 196 | indices = [[i] for i in range(query_features.shape[0])] 197 | with torch.no_grad(): 198 | accuracies = [] 199 | 200 | #Test-time adaptation 201 | for i in range(query_features.shape[0]): 202 | indexes = indices[i] 203 | images_features = query_features[indexes] 204 | targets = query_labels[indexes] 205 | final_logits, pos_cache, neg_cache = compute_tda_logits(pos_cfg, neg_cfg, 206 | images_features, clip_weights, 207 | pos_cache, neg_cache, do_update_cache = True) 208 | acc = cls_acc(final_logits, targets) 209 | accuracies.append(acc) 210 | 211 | #print("---- TDA's test accuracy: {:.2f}. 
----\n".format(sum(accuracies)/len(accuracies))) 212 | avg_accuracy = sum(accuracies)/len(accuracies) 213 | return avg_accuracy, pos_cache, neg_cache 214 | -------------------------------------------------------------------------------- /TDA/TDA_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import yaml 4 | import torch 5 | import math 6 | import numpy as np 7 | import clip 8 | from PIL import Image 9 | 10 | try: 11 | from torchvision.transforms import InterpolationMode 12 | BICUBIC = InterpolationMode.BICUBIC 13 | except ImportError: 14 | BICUBIC = Image.BICUBIC 15 | 16 | def get_entropy(loss, clip_weights): 17 | max_entropy = math.log2(clip_weights.size(1)) 18 | return float(loss / max_entropy) 19 | 20 | 21 | def softmax_entropy(x): 22 | return -(x.softmax(1) * x.log_softmax(1)).sum(1) 23 | 24 | 25 | def avg_entropy(outputs): 26 | logits = outputs - outputs.logsumexp(dim=-1, keepdim=True) 27 | avg_logits = logits.logsumexp(dim=0) - np.log(logits.shape[0]) 28 | min_real = torch.finfo(avg_logits.dtype).min 29 | avg_logits = torch.clamp(avg_logits, min=min_real) 30 | return -(avg_logits * torch.exp(avg_logits)).sum(dim=-1) 31 | 32 | 33 | def cls_acc(output, target, topk=1): 34 | pred = output.topk(topk, 1, True, True)[1].t() 35 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 36 | acc = float(correct[: topk].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) 37 | acc = 100 * acc / target.shape[0] 38 | return acc 39 | 40 | 41 | def clip_classifier(classnames, template, clip_model): 42 | with torch.no_grad(): 43 | clip_weights = [] 44 | 45 | for classname in classnames: 46 | # Tokenize the prompts 47 | classname = classname.replace('_', ' ') 48 | texts = [t.format(classname) for t in template] 49 | texts = clip.tokenize(texts).cuda() 50 | # prompt ensemble for ImageNet 51 | class_embeddings = clip_model.encode_text(texts) 52 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) 53 | class_embedding = class_embeddings.mean(dim=0) 54 | class_embedding /= class_embedding.norm() 55 | clip_weights.append(class_embedding) 56 | 57 | clip_weights = torch.stack(clip_weights, dim=1).cuda() 58 | return clip_weights 59 | 60 | 61 | def get_clip_logits(image_features, clip_weights): 62 | with torch.no_grad(): 63 | clip_logits = 100. 
* image_features @ clip_weights 64 | 65 | if image_features.size(0) > 1: 66 | batch_entropy = softmax_entropy(clip_logits) 67 | selected_idx = torch.argsort(batch_entropy, descending=False)[:int(batch_entropy.size()[0] * 0.1)] 68 | output = clip_logits[selected_idx] 69 | image_features = image_features[selected_idx].mean(0).unsqueeze(0) 70 | clip_logits = output.mean(0).unsqueeze(0) 71 | 72 | loss = avg_entropy(output) 73 | prob_map = output.softmax(1).mean(0).unsqueeze(0) 74 | pred = int(output.mean(0).unsqueeze(0).topk(1, 1, True, True)[1].t()) 75 | else: 76 | loss = softmax_entropy(clip_logits) 77 | prob_map = clip_logits.softmax(1) 78 | pred = int(clip_logits.topk(1, 1, True, True)[1].t()[0]) 79 | 80 | return clip_logits, loss, prob_map, pred 81 | 82 | 83 | # def get_ood_preprocess(): 84 | # normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], 85 | # std=[0.26862954, 0.26130258, 0.27577711]) 86 | # base_transform = transforms.Compose([ 87 | # transforms.Resize(224, interpolation=BICUBIC), 88 | # transforms.CenterCrop(224)]) 89 | # preprocess = transforms.Compose([ 90 | # transforms.ToTensor(), 91 | # normalize]) 92 | # aug_preprocess = AugMixAugmenter(base_transform, preprocess, n_views=63, augmix=True) 93 | 94 | # return aug_preprocess 95 | 96 | 97 | def get_config_file(config_path, dataset_name): 98 | if dataset_name == "I": 99 | config_name = "imagenet.yaml" 100 | elif dataset_name in ["A", "V", "R", "S"]: 101 | config_name = f"imagenet_{dataset_name.lower()}.yaml" 102 | else: 103 | config_name = f"{dataset_name}.yaml" 104 | 105 | config_file = os.path.join(config_path, config_name) 106 | 107 | with open(config_file, 'r') as file: 108 | cfg = yaml.load(file, Loader=yaml.SafeLoader) 109 | 110 | if not os.path.exists(config_file): 111 | raise FileNotFoundError(f"The configuration file {config_file} was not found.") 112 | 113 | return cfg 114 | 115 | 116 | # def build_test_data_loader(dataset_name, root_path, preprocess): 117 | # if dataset_name == 'I': 118 | # dataset = ImageNet(root_path, preprocess) 119 | # test_loader = torch.utils.data.DataLoader(dataset.test, batch_size=1, num_workers=8, shuffle=True) 120 | 121 | # elif dataset_name in ['A','V','R','S']: 122 | # preprocess = get_ood_preprocess() 123 | # dataset = build_dataset(f"imagenet-{dataset_name.lower()}", root_path) 124 | # test_loader = build_data_loader(data_source=dataset.test, batch_size=1, is_train=False, tfm=preprocess, shuffle=True) 125 | 126 | # elif dataset_name in ['caltech101','dtd','eurosat','fgvc','food101','oxford_flowers','oxford_pets','stanford_cars','sun397','ucf101']: 127 | # dataset = build_dataset(dataset_name, root_path) 128 | # test_loader = build_data_loader(data_source=dataset.test, batch_size=1, is_train=False, tfm=preprocess, shuffle=True) 129 | 130 | # else: 131 | # raise "Dataset is not from the chosen list" 132 | 133 | # return test_loader, dataset.classnames, dataset.template 134 | -------------------------------------------------------------------------------- /TDA/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cfuchs2023/OGA/b91d449a28d0958849c43dc7f698405b9fcfe4b4/TDA/__init__.py -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cfuchs2023/OGA/b91d449a28d0958849c43dc7f698405b9fcfe4b4/__init__.py 
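As a rough, self-contained usage sketch (not code from this repository): the TDA helpers above can be driven with pre-computed features along the following lines. The `pos_cfg`/`neg_cfg` values only illustrate the keys read by `compute_tda_logits` and are not the hyperparameters reported in the TDA paper, the `.pt` paths are hypothetical placeholders, and a CUDA device is assumed because the cache logits are built with `.cuda()` calls.

```python
import torch
from TDA.TDA_core import run_test_tda

# Placeholder configs showing the expected structure, NOT tuned values.
pos_cfg = {'enabled': True, 'shot_capacity': 3, 'alpha': 2.0, 'beta': 5.0}
neg_cfg = {'enabled': True, 'shot_capacity': 2, 'alpha': 0.117, 'beta': 1.0,
           'entropy_threshold': {'lower': 0.2, 'upper': 0.5},
           'mask_threshold': {'lower': 0.03, 'upper': 1.0}}

# Assumed inputs: L2-normalized image features (N, d), labels (N,), and
# text prototypes (d, K) such as those produced by clip_classifier().
query_features = torch.load('features.pt').cuda().half()    # hypothetical path
query_labels = torch.load('targets.pt').cuda()               # hypothetical path
clip_weights = torch.load('clip_weights.pt').cuda().half()   # hypothetical path

acc, pos_cache, neg_cache = run_test_tda(pos_cfg, neg_cfg, query_features,
                                         query_labels, clip_weights)
print(f"TDA average accuracy: {acc:.2f}")
```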
-------------------------------------------------------------------------------- /compute_features.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import argparse 3 | import clip 4 | import os 5 | import utils as uti 6 | import datasets as dts 7 | def get_arguments(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--data_root_path', type=str) 10 | parser.add_argument('-d', '--datasets', default = ['dtd'], nargs="*", help = 'List of datasets for which to compute features.') 11 | parser.add_argument('--backbone', default='vit_b16', type=str, help = 'Name of the backbone to use. Examples : vit_b16 or rn101.') 12 | parser.add_argument('--root_cache_path', default = None, type = str, help = 'Path where the cached features and targets will be stored. Defaults to data_root_path/{dataset}/cache internally.') 13 | args = parser.parse_args() 14 | return args 15 | 16 | def main(): 17 | args = get_arguments() 18 | assert args.data_root_path is not None 19 | cfg = {} 20 | cfg['backbone'] = uti.backbones[args.backbone] 21 | print('========== Loading Clip Model') 22 | clip_model, preprocess = clip.load(cfg['backbone']) 23 | if args.root_cache_path is not None: 24 | base_cache_dir = args.root_cache_path 25 | else: 26 | base_cache_dir = args.data_root_path 27 | for dataset_name in args.datasets: 28 | print('\n******* dataset : ', dataset_name) 29 | if dataset_name == 'imagenet': 30 | cfg['load_cache'] = True 31 | cfg['dataset'] = uti.datasets[dataset_name] 32 | cfg['root_path'] = args.data_root_path 33 | cfg['shots'] = 0 34 | cfg['load_pre_feat'] = False 35 | cache_dir = os.path.join(base_cache_dir, uti.datasets[dataset_name], 'cache') 36 | os.makedirs(cache_dir, exist_ok=True) 37 | cfg['cache_dir'] = cache_dir 38 | print(cfg['cache_dir']) 39 | print('Computing Features...') 40 | train_loader, val_loader, test_loader, dataset = dts.get_all_dataloaders(cfg, preprocess, dirichlet=None) 41 | _ = uti.get_all_features( 42 | cfg, train_loader, val_loader, test_loader, dataset, clip_model) 43 | return None 44 | 45 | if __name__ == '__main__': 46 | main() 47 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | from .oxford_pets import OxfordPets 4 | from .eurosat import EuroSAT 5 | from .ucf101 import UCF101 6 | from .sun397 import SUN397 7 | from .caltech101 import Caltech101 8 | from .dtd import DescribableTextures 9 | from .fgvc import FGVCAircraft 10 | from .food101 import Food101 11 | from .oxford_flowers import OxfordFlowers 12 | from .stanford_cars import StanfordCars 13 | from .imagenet import ImageNet 14 | from .imagenet_a import ImageNetA 15 | from .imagenet_v2 import ImageNetV2 16 | from .imagenet_r import ImageNetR 17 | from .imagenet_sketch import ImageNetSketch 18 | from .utils import * 19 | from .sampler import LabelCorrelatedSampler 20 | from .prepare_tta_datasets import prepare_cifar 21 | 22 | dataset_list = { 23 | "oxford_pets": OxfordPets, 24 | "eurosat": EuroSAT, 25 | "ucf101": UCF101, 26 | "sun397": SUN397, 27 | "caltech101": Caltech101, 28 | "dtd": DescribableTextures, 29 | "fgvc": FGVCAircraft, 30 | "fgvc_aircraft": FGVCAircraft, 31 | "food101": Food101, 32 | "oxford_flowers": OxfordFlowers, 33 | "stanford_cars": StanfordCars, 34 | "imagenet": ImageNet, 35 | "imagenet_a": ImageNetA, 36 | "imagenet_v2": 
ImageNetV2, 37 | "imagenet_r": ImageNetR, 38 | "imagenet_sketch": ImageNetSketch, 39 | } 40 | 41 | 42 | 43 | def get_all_dataloaders(cfg, preprocess, dirichlet=None, batch_size = 64, num_workers = 8): 44 | dataset_name = cfg['dataset'] 45 | 46 | if dataset_name.startswith('imagenet'): 47 | if dirichlet == None: 48 | sampler = None 49 | else: 50 | sampler = LabelCorrelatedSampler(dataset.test, dirichlet, batch_size=batch_size) 51 | 52 | dataset = dataset_list[dataset_name](cfg['root_path'], cfg['shots'], preprocess=preprocess, train_preprocess=None, test_preprocess=None, load_cache=cfg['load_cache'], load_pre_feat=cfg['load_pre_feat']) 53 | test_loader = torch.utils.data.DataLoader(dataset.test, batch_size=batch_size, num_workers=num_workers, shuffle=False, sampler=sampler) 54 | train_loader = None 55 | val_loader = None 56 | if cfg['shots'] > 0: 57 | train_loader = torch.utils.data.DataLoader(dataset.train, batch_size=batch_size, num_workers=num_workers, shuffle=False) 58 | val_loader = torch.utils.data.DataLoader(dataset.val, batch_size=batch_size, num_workers=num_workers, shuffle=False) 59 | 60 | elif dataset_name.startswith('cifar'): 61 | sampler = None 62 | dataset = prepare_cifar(cfg['root_path'], dataset_name, preprocess=preprocess) 63 | train_loader, val_loader = None, None 64 | test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, sampler=sampler) 65 | else: 66 | dataset = dataset_list[dataset_name](cfg['root_path'], cfg['shots']) 67 | val_loader = build_data_loader(data_source=dataset.val, batch_size=batch_size, is_train=False, tfm=preprocess, 68 | shuffle=False) 69 | if dirichlet == None: 70 | sampler = None 71 | else: 72 | sampler = LabelCorrelatedSampler(dataset.test, dirichlet, batch_size=batch_size) 73 | test_loader = build_data_loader(data_source=dataset.test, batch_size=batch_size, is_train=False, tfm=preprocess, 74 | shuffle=False, sampler=sampler, num_workers = num_workers) 75 | train_loader = None 76 | if cfg['shots'] > 0: 77 | 78 | train_loader = build_data_loader(data_source=dataset.train_x, batch_size=batch_size, tfm=preprocess, 79 | is_train=False, shuffle=False, num_workers = num_workers) 80 | 81 | return train_loader, val_loader, test_loader, dataset 82 | 83 | 84 | -------------------------------------------------------------------------------- /datasets/caltech101.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import Datum, DatasetBase 4 | from .oxford_pets import OxfordPets 5 | 6 | 7 | template = ['a photo of a {}.'] 8 | 9 | 10 | class Caltech101(DatasetBase): 11 | 12 | dataset_dir = 'Caltech101' 13 | 14 | def __init__(self, root, num_shots): 15 | self.dataset_dir = os.path.join(root, self.dataset_dir) 16 | self.image_dir = os.path.join(self.dataset_dir, '101_ObjectCategories') 17 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_Caltech101.json') 18 | 19 | self.template = template 20 | 21 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 22 | n_shots_val = min(num_shots, 4) 23 | val = self.generate_fewshot_dataset(val, num_shots=n_shots_val) 24 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 25 | 26 | super().__init__(train_x=train, val=val, test=test) -------------------------------------------------------------------------------- /datasets/dtd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | from 
.utils import Datum, DatasetBase, listdir_nohidden 5 | from .oxford_pets import OxfordPets 6 | 7 | 8 | template = ['{} texture'] 9 | #template = ['a photo of a {}.'] 10 | 11 | class DescribableTextures(DatasetBase): 12 | 13 | dataset_dir = 'DTD' 14 | 15 | def __init__(self, root, num_shots): 16 | self.dataset_dir = os.path.join(root, self.dataset_dir) 17 | self.image_dir = os.path.join(self.dataset_dir, 'images') 18 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_DescribableTextures.json') 19 | 20 | self.template = template 21 | 22 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 23 | n_shots_val = min(num_shots, 4) 24 | val = self.generate_fewshot_dataset(val, num_shots=n_shots_val) 25 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 26 | 27 | super().__init__(train_x=train, val=val, test=test) 28 | 29 | @staticmethod 30 | def read_and_split_data( 31 | image_dir, 32 | p_trn=0.5, 33 | p_val=0.2, 34 | ignored=[], 35 | new_cnames=None 36 | ): 37 | # The data are supposed to be organized into the following structure 38 | # ============= 39 | # images/ 40 | # dog/ 41 | # cat/ 42 | # horse/ 43 | # ============= 44 | categories = listdir_nohidden(image_dir) 45 | categories = [c for c in categories if c not in ignored] 46 | categories.sort() 47 | 48 | p_tst = 1 - p_trn - p_val 49 | print(f'Splitting into {p_trn:.0%} train, {p_val:.0%} val, and {p_tst:.0%} test') 50 | 51 | def _collate(ims, y, c): 52 | items = [] 53 | for im in ims: 54 | item = Datum( 55 | impath=im, 56 | label=y, # is already 0-based 57 | classname=c 58 | ) 59 | items.append(item) 60 | return items 61 | 62 | train, val, test = [], [], [] 63 | for label, category in enumerate(categories): 64 | category_dir = os.path.join(image_dir, category) 65 | images = listdir_nohidden(category_dir) 66 | images = [os.path.join(category_dir, im) for im in images] 67 | random.shuffle(images) 68 | n_total = len(images) 69 | n_train = round(n_total * p_trn) 70 | n_val = round(n_total * p_val) 71 | n_test = n_total - n_train - n_val 72 | assert n_train > 0 and n_val > 0 and n_test > 0 73 | 74 | if new_cnames is not None and category in new_cnames: 75 | category = new_cnames[category] 76 | 77 | train.extend(_collate(images[:n_train], label, category)) 78 | val.extend(_collate(images[n_train:n_train+n_val], label, category)) 79 | test.extend(_collate(images[n_train+n_val:], label, category)) 80 | 81 | return train, val, test 82 | -------------------------------------------------------------------------------- /datasets/eurosat.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 4 | from .oxford_pets import OxfordPets 5 | 6 | 7 | template = ['a centered satellite photo of {}.'] 8 | 9 | NEW_CNAMES = { 10 | 'AnnualCrop': 'Annual Crop Land', 11 | 'Forest': 'Forest', 12 | 'HerbaceousVegetation': 'Herbaceous Vegetation Land', 13 | 'Highway': 'Highway or Road', 14 | 'Industrial': 'Industrial Buildings', 15 | 'Pasture': 'Pasture Land', 16 | 'PermanentCrop': 'Permanent Crop Land', 17 | 'Residential': 'Residential Buildings', 18 | 'River': 'River', 19 | 'SeaLake': 'Sea or Lake' 20 | } 21 | 22 | 23 | class EuroSAT(DatasetBase): 24 | 25 | dataset_dir = 'eurosat' 26 | 27 | def __init__(self, root, num_shots): 28 | self.dataset_dir = os.path.join(root, self.dataset_dir) 29 | self.image_dir = os.path.join(self.dataset_dir, '2750') 30 | self.split_path = 
os.path.join(self.dataset_dir, 'split_zhou_EuroSAT.json') 31 | 32 | self.template = template 33 | 34 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 35 | n_shots_val = min(num_shots, 4) 36 | val = self.generate_fewshot_dataset(val, num_shots=n_shots_val) 37 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 38 | 39 | super().__init__(train_x=train, val=val, test=test) 40 | 41 | def update_classname(self, dataset_old): 42 | dataset_new = [] 43 | for item_old in dataset_old: 44 | cname_old = item_old.classname 45 | cname_new = NEW_CLASSNAMES[cname_old] 46 | item_new = Datum( 47 | impath=item_old.impath, 48 | label=item_old.label, 49 | classname=cname_new 50 | ) 51 | dataset_new.append(item_new) 52 | return dataset_new 53 | -------------------------------------------------------------------------------- /datasets/fgvc.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 4 | 5 | 6 | template = ['a photo of a {}, a type of aircraft.'] 7 | 8 | 9 | class FGVCAircraft(DatasetBase): 10 | 11 | dataset_dir = 'fgvc_aircraft' 12 | 13 | def __init__(self, root, num_shots): 14 | 15 | self.dataset_dir = os.path.join(root, self.dataset_dir) 16 | self.image_dir = os.path.join(self.dataset_dir, 'images') 17 | 18 | self.template = template 19 | 20 | classnames = [] 21 | with open(os.path.join(self.dataset_dir, 'variants.txt'), 'r') as f: 22 | lines = f.readlines() 23 | for line in lines: 24 | classnames.append(line.strip()) 25 | cname2lab = {c: i for i, c in enumerate(classnames)} 26 | 27 | train = self.read_data(cname2lab, 'images_variant_train.txt') 28 | val = self.read_data(cname2lab, 'images_variant_val.txt') 29 | test = self.read_data(cname2lab, 'images_variant_test.txt') 30 | 31 | n_shots_val = min(num_shots, 4) 32 | val = self.generate_fewshot_dataset(val, num_shots=n_shots_val) 33 | 34 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 35 | 36 | super().__init__(train_x=train, val=val, test=test) 37 | 38 | def read_data(self, cname2lab, split_file): 39 | filepath = os.path.join(self.dataset_dir, split_file) 40 | items = [] 41 | 42 | with open(filepath, 'r') as f: 43 | lines = f.readlines() 44 | for line in lines: 45 | line = line.strip().split(' ') 46 | imname = line[0] + '.jpg' 47 | classname = ' '.join(line[1:]) 48 | impath = os.path.join(self.image_dir, imname) 49 | label = cname2lab[classname] 50 | item = Datum( 51 | impath=impath, 52 | label=label, 53 | classname=classname 54 | ) 55 | items.append(item) 56 | 57 | return items -------------------------------------------------------------------------------- /datasets/food101.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 4 | from .oxford_pets import OxfordPets 5 | 6 | 7 | template = ['a photo of {}, a type of food.'] 8 | 9 | 10 | class Food101(DatasetBase): 11 | 12 | dataset_dir = 'Food101' 13 | 14 | def __init__(self, root, num_shots): 15 | self.dataset_dir = os.path.join(root, self.dataset_dir) 16 | self.image_dir = os.path.join(self.dataset_dir, 'images') 17 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_Food101.json') 18 | 19 | self.template = template 20 | 21 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 22 | n_shots_val = min(num_shots, 4) 23 | val = 
self.generate_fewshot_dataset(val, num_shots=n_shots_val) 24 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 25 | 26 | super().__init__(train_x=train, val=val, test=test) -------------------------------------------------------------------------------- /datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import random 4 | from collections import defaultdict 5 | 6 | import torch 7 | import torchvision 8 | import torchvision.transforms as transforms 9 | 10 | import torchvision.datasets as datasets 11 | 12 | 13 | imagenet_classes = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", 14 | "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", 15 | "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", 16 | "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", 17 | "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", 18 | "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", 19 | "box turtle", "banded gecko", "green iguana", "Carolina anole", 20 | "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", 21 | "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", 22 | "American alligator", "triceratops", "worm snake", "ring-necked snake", 23 | "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", 24 | "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", 25 | "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", 26 | "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", 27 | "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", 28 | "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", 29 | "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", 30 | "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", 31 | "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", 32 | "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", 33 | "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", 34 | "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", 35 | "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", 36 | "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", 37 | "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", 38 | "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", 39 | "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", 40 | "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", 41 | "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", 42 | "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", 43 | "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish 
Deerhound", 44 | "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", 45 | "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", 46 | "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", 47 | "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", 48 | "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", 49 | "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", 50 | "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", 51 | "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", 52 | "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", 53 | "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", 54 | "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", 55 | "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", 56 | "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", 57 | "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", 58 | "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", 59 | "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", 60 | "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", 61 | "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", 62 | "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", 63 | "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", 64 | "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", 65 | "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", 66 | "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", 67 | "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", 68 | "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", 69 | "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", 70 | "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", 71 | "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", 72 | "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", 73 | "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", 74 | "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", 75 | "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", 76 | "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", 77 | "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", 78 | "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", 79 | "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", 80 | "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", 81 | "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", 82 | "howler 
monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", 83 | "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", 84 | "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", 85 | "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", 86 | "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", 87 | "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", 88 | "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", 89 | "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", 90 | "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", 91 | "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", 92 | "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", 93 | "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", 94 | "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", 95 | "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", 96 | "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", 97 | "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", 98 | "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", 99 | "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", 100 | "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", 101 | "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", 102 | "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", 103 | "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", 104 | "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", 105 | "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", 106 | "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", 107 | "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", 108 | "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", 109 | "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", 110 | "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", 111 | "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", 112 | "freight car", "French horn", "frying pan", "fur coat", "garbage truck", 113 | "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", 114 | "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", 115 | "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", 116 | "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", 117 | "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", 118 | "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", 119 | "T-shirt", "jigsaw puzzle", 
"rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", 120 | "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", 121 | "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", 122 | "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", 123 | "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", 124 | "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", 125 | "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", 126 | "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", 127 | "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", 128 | "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", 129 | "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", 130 | "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", 131 | "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", 132 | "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", 133 | "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", 134 | "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", 135 | "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", 136 | "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", 137 | "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", 138 | "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", 139 | "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", 140 | "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", 141 | "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", 142 | "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", 143 | "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", 144 | "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", 145 | "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", 146 | "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", 147 | "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", 148 | "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", 149 | "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", 150 | "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", 151 | "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", 152 | "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", 153 | "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", 154 | "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", 155 | "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", 
156 | "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", 157 | "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", 158 | "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", 159 | "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", 160 | "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", 161 | "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", 162 | "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", 163 | "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", 164 | "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", 165 | "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", 166 | "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", 167 | "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", 168 | "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", 169 | "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", 170 | "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", 171 | "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", 172 | "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", 173 | "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", 174 | "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", 175 | "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", 176 | "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", 177 | "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"] 178 | 179 | custom_templates = ["itap of a {}.", 180 | "a bad photo of the {}.", 181 | "a origami {}.", 182 | "a photo of the large {}.", 183 | "a {} in a video game.", 184 | "art of the {}.", 185 | "a photo of the small {}."] 186 | 187 | 188 | imagenet_templates = ["a photo of a {}."] 189 | 190 | class ImageNet(): 191 | 192 | dataset_dir = 'imagenet' 193 | 194 | def __init__(self, root, num_shots, preprocess, train_preprocess=None, test_preprocess=None, load_cache=False, load_pre_feat=False): 195 | 196 | self.dataset_dir = os.path.join(root, self.dataset_dir) 197 | self.image_dir = os.path.join(self.dataset_dir, 'images') 198 | 199 | if train_preprocess is None: 200 | train_preprocess = transforms.Compose([ 201 | transforms.RandomResizedCrop(size=224, scale=(0.5, 1), interpolation=transforms.InterpolationMode.BICUBIC), 202 | transforms.RandomHorizontalFlip(p=0.5), 203 | transforms.ToTensor(), 204 | transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) 205 | ]) 206 | 207 | 208 | if test_preprocess is None: 209 | test_preprocess = preprocess 210 | 211 | self.train, self.val, self.test = None, None, None 212 | 213 | if not load_cache and num_shots > 0: 214 | self.train = datasets.ImageFolder(os.path.join(os.path.join(self.dataset_dir, 'train')), transform=train_preprocess) 215 | 216 | if not load_pre_feat and num_shots > 0: 217 | self.val 
= datasets.ImageFolder(os.path.join(os.path.join(self.dataset_dir, 'train')), transform=preprocess) 218 | 219 | if not load_pre_feat: 220 | self.test = datasets.ImageFolder(os.path.join(os.path.join(self.dataset_dir, 'val')), transform=test_preprocess) 221 | 222 | num_shots_val = min(4, num_shots) 223 | 224 | self.template = imagenet_templates 225 | self.custom_templates = custom_templates 226 | self.classnames = imagenet_classes 227 | 228 | if not load_pre_feat and num_shots > 0: 229 | split_by_label_dict = defaultdict(list) 230 | for i in range(len(self.train.imgs)): 231 | split_by_label_dict[self.train.targets[i]].append(self.train.imgs[i]) 232 | 233 | imgs = [] 234 | targets = [] 235 | imgs_val = [] 236 | targets_val = [] 237 | 238 | for label, items in split_by_label_dict.items(): 239 | samples = random.sample(items, num_shots + num_shots_val) 240 | imgs = imgs + samples[0:num_shots] 241 | imgs_val = imgs_val + samples[num_shots:num_shots+num_shots_val] 242 | targets = targets + [label for i in range(num_shots)] 243 | targets_val = targets_val + [label for i in range(num_shots_val)] 244 | 245 | self.train.imgs = imgs 246 | self.train.targets = targets 247 | self.train.samples = imgs 248 | 249 | self.val.imgs = imgs_val 250 | self.val.targets = targets_val 251 | self.val.samples = imgs_val 252 | 253 | -------------------------------------------------------------------------------- /datasets/oxford_flowers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from scipy.io import loadmat 4 | from collections import defaultdict 5 | 6 | from .oxford_pets import OxfordPets 7 | from .utils import Datum, DatasetBase, read_json 8 | 9 | 10 | template = ['a photo of a {}, a type of flower.'] 11 | 12 | 13 | class OxfordFlowers(DatasetBase): 14 | 15 | dataset_dir = 'Flower102' 16 | 17 | def __init__(self, root, num_shots): 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, 'jpg') 20 | self.label_file = os.path.join(self.dataset_dir, 'imagelabels.mat') 21 | self.lab2cname_file = os.path.join(self.dataset_dir, 'cat_to_name.json') 22 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_OxfordFlowers.json') 23 | 24 | self.template = template 25 | 26 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 27 | n_shots_val = min(num_shots, 4) 28 | val = self.generate_fewshot_dataset(val, num_shots=n_shots_val) 29 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 30 | 31 | super().__init__(train_x=train, val=val, test=test) 32 | 33 | def read_data(self): 34 | tracker = defaultdict(list) 35 | label_file = loadmat(self.label_file)['labels'][0] 36 | for i, label in enumerate(label_file): 37 | imname = f'image_{str(i + 1).zfill(5)}.jpg' 38 | impath = os.path.join(self.image_dir, imname) 39 | label = int(label) 40 | tracker[label].append(impath) 41 | 42 | print('Splitting data into 50% train, 20% val, and 30% test') 43 | 44 | def _collate(ims, y, c): 45 | items = [] 46 | for im in ims: 47 | item = Datum( 48 | impath=im, 49 | label=y-1, # convert to 0-based label 50 | classname=c 51 | ) 52 | items.append(item) 53 | return items 54 | 55 | lab2cname = read_json(self.lab2cname_file) 56 | train, val, test = [], [], [] 57 | for label, impaths in tracker.items(): 58 | random.shuffle(impaths) 59 | n_total = len(impaths) 60 | n_train = round(n_total * 0.5) 61 | n_val = round(n_total * 0.2) 62 | n_test = n_total - n_train - n_val 63 | assert 
n_train > 0 and n_val > 0 and n_test > 0 64 | cname = lab2cname[str(label)] 65 | train.extend(_collate(impaths[:n_train], label, cname)) 66 | val.extend(_collate(impaths[n_train:n_train+n_val], label, cname)) 67 | test.extend(_collate(impaths[n_train+n_val:], label, cname)) 68 | 69 | return train, val, test -------------------------------------------------------------------------------- /datasets/oxford_pets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import random 4 | from collections import defaultdict 5 | 6 | import torchvision.transforms as transforms 7 | 8 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 9 | 10 | 11 | template = ['a photo of a {}, a type of pet.'] 12 | 13 | 14 | class OxfordPets(DatasetBase): 15 | 16 | dataset_dir = 'OxfordPets' 17 | 18 | def __init__(self, root, num_shots): 19 | self.dataset_dir = os.path.join(root, self.dataset_dir) 20 | self.image_dir = os.path.join(self.dataset_dir, 'images') 21 | self.anno_dir = os.path.join(self.dataset_dir, 'annotations') 22 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_OxfordPets.json') 23 | 24 | self.template = template 25 | 26 | train, val, test = self.read_split(self.split_path, self.image_dir) 27 | n_shots_val = min(num_shots, 4) 28 | val = self.generate_fewshot_dataset(val, num_shots=n_shots_val) 29 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 30 | 31 | super().__init__(train_x=train, val=val, test=test) 32 | 33 | def read_data(self, split_file): 34 | filepath = os.path.join(self.anno_dir, split_file) 35 | items = [] 36 | 37 | with open(filepath, 'r') as f: 38 | lines = f.readlines() 39 | for line in lines: 40 | line = line.strip() 41 | imname, label, species, _ = line.split(' ') 42 | breed = imname.split('_')[:-1] 43 | breed = '_'.join(breed) 44 | breed = breed.lower() 45 | imname += '.jpg' 46 | impath = os.path.join(self.image_dir, imname) 47 | label = int(label) - 1 # convert to 0-based index 48 | item = Datum( 49 | impath=impath, 50 | label=label, 51 | classname=breed 52 | ) 53 | items.append(item) 54 | 55 | return items 56 | 57 | @staticmethod 58 | def split_trainval(trainval, p_val=0.2): 59 | p_trn = 1 - p_val 60 | print(f'Splitting trainval into {p_trn:.0%} train and {p_val:.0%} val') 61 | tracker = defaultdict(list) 62 | for idx, item in enumerate(trainval): 63 | label = item.label 64 | tracker[label].append(idx) 65 | 66 | train, val = [], [] 67 | for label, idxs in tracker.items(): 68 | n_val = round(len(idxs) * p_val) 69 | assert n_val > 0 70 | random.shuffle(idxs) 71 | for n, idx in enumerate(idxs): 72 | item = trainval[idx] 73 | if n < n_val: 74 | val.append(item) 75 | else: 76 | train.append(item) 77 | 78 | return train, val 79 | 80 | @staticmethod 81 | def save_split(train, val, test, filepath, path_prefix): 82 | def _extract(items): 83 | out = [] 84 | for item in items: 85 | impath = item.impath 86 | label = item.label 87 | classname = item.classname 88 | impath = impath.replace(path_prefix, '') 89 | if impath.startswith('/'): 90 | impath = impath[1:] 91 | out.append((impath, label, classname)) 92 | return out 93 | 94 | train = _extract(train) 95 | val = _extract(val) 96 | test = _extract(test) 97 | 98 | split = { 99 | 'train': train, 100 | 'val': val, 101 | 'test': test 102 | } 103 | 104 | write_json(split, filepath) 105 | print(f'Saved split to {filepath}') 106 | 107 | @staticmethod 108 | def read_split(filepath, path_prefix): 109 | def _convert(items): 110 | 
out = [] 111 | for impath, label, classname in items: 112 | impath = os.path.join(path_prefix, impath) 113 | item = Datum( 114 | impath=impath, 115 | label=int(label), 116 | classname=classname 117 | ) 118 | out.append(item) 119 | return out 120 | 121 | split = read_json(filepath) 122 | train = _convert(split['train']) 123 | val = _convert(split['val']) 124 | test = _convert(split['test']) 125 | 126 | return train, val, test -------------------------------------------------------------------------------- /datasets/prepare_tta_datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | from torch.utils.data import random_split 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | from torchvision.datasets import ImageFolder 7 | import numpy as np 8 | from PIL import Image 9 | 10 | 11 | 12 | common_corruptions = ['gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur', 'glass_blur', 13 | 'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog', 14 | 'brightness', 'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression'] 15 | 16 | 17 | def prepare_cifar(root, dataset, preprocess, level=5): 18 | 19 | if dataset.startswith('cifar100'): 20 | root += '/CIFAR/' 21 | size = 10000 22 | 23 | l = dataset.split('-') 24 | if len(l)==1: 25 | corruption = 'original' 26 | else: 27 | corruption = l[1] 28 | 29 | if corruption == 'original': 30 | testset = torchvision.datasets.CIFAR100(root=root, 31 | train=False, download=True, transform=preprocess) 32 | elif corruption in common_corruptions: 33 | testset_raw = np.load(root + '/CIFAR-100-C/%s.npy' % (corruption)) 34 | testset_raw = testset_raw[(level - 1) * size: level * size] 35 | testset = torchvision.datasets.CIFAR100(root=root, 36 | train=False, download=False, transform=preprocess) 37 | 38 | testset.data = testset_raw 39 | elif dataset.startswith('cifar10'): 40 | root += '/CIFAR/' 41 | size = 10000 42 | 43 | l = dataset.split('-') 44 | if len(l)==1: 45 | corruption = 'original' 46 | else: 47 | corruption = l[1] 48 | if corruption == 'original': 49 | testset = torchvision.datasets.CIFAR10(root=root, 50 | train=False, download=False, transform=preprocess) 51 | elif corruption in common_corruptions: 52 | testset_raw = np.load(root + '/CIFAR-10-C/%s.npy' % (corruption)) 53 | testset_raw = testset_raw[(level - 1) * size: level * size] 54 | testset = torchvision.datasets.CIFAR10(root=root, 55 | train=False, download=False, transform=preprocess) 56 | testset.data = testset_raw 57 | 58 | elif args.corruption == 'cifar_new': 59 | from utils.cifar_new import CIFAR_New 60 | teset = CIFAR_New(root=args.dataroot + '/CIFAR-10.1/', transform=te_transforms) 61 | permute = False 62 | else: 63 | raise Exception('Corruption not found!') 64 | 65 | 66 | 67 | testset.classnames = testset.classes 68 | testset.template = ['a photo of a {}.'] 69 | return testset 70 | 71 | 72 | """ 73 | elif args.dataset == 'visda': 74 | teset = VisdaTest(args.dataroot, transforms=visda_val) 75 | 76 | elif args.dataset == 'tiny-imagenet': 77 | if not hasattr(args, 'corruption') or args.corruption == 'original': 78 | teset = TinyImageNetDataset(args.dataroot + '/tiny-imagenet-200/', mode='val', transform=te_transforms) 79 | elif args.corruption in common_corruptions: 80 | teset = TinyImageNetCDataset(args.dataroot + '/Tiny-ImageNet-C/', corruption = args.corruption, level = args.level, 81 | transform=te_transforms) 82 | else: 83 | raise Exception('Dataset not found!') 84 | """ 85 | 86 
| 87 | 88 | def prepare_val_data(args, transform=None): 89 | if args.dataset == 'visda': 90 | vset = ImageFolder(root=args.dataroot + 'validation/', transform=transform if transform is not None else visda_val) 91 | else: 92 | raise Exception('Dataset not found!') 93 | 94 | if args.distributed: 95 | v_sampler = torch.utils.data.distributed.DistributedSampler(vset) 96 | else: 97 | v_sampler = None 98 | if not hasattr(args, 'workers'): 99 | args.workers = 1 100 | vloader = torch.utils.data.DataLoader(vset, batch_size=args.batch_size, 101 | shuffle=(v_sampler is None), num_workers=args.workers, pin_memory=True, sampler=v_sampler, drop_last=True) 102 | return vloader, v_sampler 103 | 104 | def prepare_train_data(args, transform=None): 105 | if args.clip : 106 | tr_transforms = clip_transforms 107 | if args.dataset == 'cifar10': 108 | trset = torchvision.datasets.CIFAR10(root=args.dataroot, 109 | train=True, download=False, transform=tr_transforms) 110 | elif args.dataset == 'cifar100': 111 | trset = torchvision.datasets.CIFAR100(root=args.dataroot, train=True, download=False, transform=tr_transforms) 112 | elif args.dataset == 'visda': 113 | dataset = ImageFolder(root=args.dataroot + 'train/', transform=visda_train if transform is None else transform) 114 | trset, _ = random_split(dataset, [106678, 45719], generator=torch.Generator().manual_seed(args.seed)) 115 | elif args.dataset == 'tiny-imagenet': 116 | trset = TinyImageNetDataset(args.dataroot + '/tiny-imagenet-200/', transform=tinyimagenet_transforms) 117 | else: 118 | raise Exception('Dataset not found!') 119 | 120 | if args.distributed: 121 | tr_sampler = torch.utils.data.distributed.DistributedSampler(trset) 122 | else: 123 | tr_sampler = None 124 | 125 | if not hasattr(args, 'workers'): 126 | args.workers = 1 127 | trloader = torch.utils.data.DataLoader(trset, batch_size=args.batch_size, 128 | shuffle=(tr_sampler is None), num_workers=args.workers, pin_memory=True, sampler=tr_sampler) 129 | 130 | 131 | return trloader, tr_sampler, trset -------------------------------------------------------------------------------- /datasets/sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data.sampler import Sampler 3 | from typing import List 4 | from collections import defaultdict 5 | from numpy.random import dirichlet 6 | 7 | 8 | class LabelCorrelatedSampler(Sampler): 9 | def __init__(self, data_source, gamma, batch_size, slots=None): 10 | self.label_dict = defaultdict(list) 11 | self.classes = set() 12 | for i, item in enumerate(data_source): 13 | self.label_dict[item.label].append(i) 14 | self.classes.add(item.label) 15 | self.labels = list(self.label_dict.keys()) 16 | self.labels.sort() 17 | 18 | self.data_source = data_source 19 | self.gamma = gamma 20 | self.batch_size = batch_size 21 | self.num_class = len(self.classes) 22 | 23 | 24 | if slots is not None: 25 | self.num_slots = slots 26 | else: 27 | self.num_slots = self.num_class if self.num_class <= 100 else 100 28 | 29 | 30 | 31 | 32 | def __len__(self): 33 | return len(self.data_source) 34 | 35 | def __iter__(self): 36 | final_indices = [] 37 | label_distribution = dirichlet([self.gamma] * self.num_slots, self.num_class) 38 | 39 | for label in self.labels: 40 | indices = np.array(self.label_dict[label]) 41 | slot_indices = [[] for _ in range(self.num_slots)] 42 | 43 | partition = label_distribution[self.labels.index(label)] 44 | print('partition', partition) 45 | for s, ids in 
enumerate(np.split(indices, (np.cumsum(partition)[:-1] * len(indices)).astype(int))): 46 | print(s, ids) 47 | slot_indices[s].extend(ids) 48 | 49 | for s_ids in slot_indices: 50 | permutation = np.random.permutation(range(len(s_ids))) 51 | ids = [] 52 | for i in permutation: 53 | ids.extend(s_ids[i] if isinstance(s_ids[i], list) else [s_ids[i]]) 54 | final_indices.extend(ids) 55 | print(final_indices) 56 | return iter(final_indices) 57 | 58 | -------------------------------------------------------------------------------- /datasets/stanford_cars.py: -------------------------------------------------------------------------------- 1 | import os 2 | from scipy.io import loadmat 3 | 4 | from .oxford_pets import OxfordPets 5 | from .utils import Datum, DatasetBase 6 | 7 | 8 | template = ['a photo of a {}.'] 9 | 10 | 11 | class StanfordCars(DatasetBase): 12 | 13 | dataset_dir = 'StanfordCars' 14 | 15 | def __init__(self, root, num_shots): 16 | self.dataset_dir = os.path.join(root, self.dataset_dir) 17 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_StanfordCars.json') 18 | 19 | self.template = template 20 | 21 | train, val, test = OxfordPets.read_split(self.split_path, self.dataset_dir) 22 | n_shots_val = min(num_shots, 4) 23 | val = self.generate_fewshot_dataset(val, num_shots=n_shots_val) 24 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 25 | 26 | super().__init__(train_x=train, val=val, test=test) 27 | 28 | def read_data(self, image_dir, anno_file, meta_file): 29 | anno_file = loadmat(anno_file)['annotations'][0] 30 | meta_file = loadmat(meta_file)['class_names'][0] 31 | items = [] 32 | 33 | for i in range(len(anno_file)): 34 | imname = anno_file[i]['fname'][0] 35 | impath = os.path.join(self.dataset_dir, image_dir, imname) 36 | label = anno_file[i]['class'][0, 0] 37 | label = int(label) - 1 # convert to 0-based index 38 | classname = meta_file[label][0] 39 | names = classname.split(' ') 40 | year = names.pop(-1) 41 | names.insert(0, year) 42 | classname = ' '.join(names) 43 | item = Datum( 44 | impath=impath, 45 | label=label, 46 | classname=classname 47 | ) 48 | items.append(item) 49 | 50 | return items -------------------------------------------------------------------------------- /datasets/sun397.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 4 | 5 | from .oxford_pets import OxfordPets 6 | 7 | 8 | template = ['a photo of a {}.'] 9 | 10 | 11 | class SUN397(DatasetBase): 12 | 13 | dataset_dir = 'SUN397' 14 | 15 | def __init__(self, root, num_shots): 16 | self.dataset_dir = os.path.join(root, self.dataset_dir) 17 | self.image_dir = os.path.join(self.dataset_dir, 'SUN397') 18 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_SUN397.json') 19 | 20 | self.template = template 21 | 22 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 23 | n_shots_val = min(num_shots, 4) 24 | val = self.generate_fewshot_dataset(val, num_shots=n_shots_val) 25 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 26 | 27 | super().__init__(train_x=train, val=val, test=test) 28 | 29 | def read_data(self, cname2lab, text_file): 30 | text_file = os.path.join(self.dataset_dir, text_file) 31 | items = [] 32 | 33 | with open(text_file, 'r') as f: 34 | lines = f.readlines() 35 | for line in lines: 36 | imname = line.strip()[1:] # remove / 37 | classname = os.path.dirname(imname) 38 | label = 
cname2lab[classname] 39 | impath = os.path.join(self.image_dir, imname) 40 | 41 | names = classname.split('/')[1:] # remove 1st letter 42 | names = names[::-1] # put words like indoor/outdoor at first 43 | classname = ' '.join(names) 44 | 45 | item = Datum( 46 | impath=impath, 47 | label=label, 48 | classname=classname 49 | ) 50 | items.append(item) 51 | 52 | return items 53 | -------------------------------------------------------------------------------- /datasets/tiny_imagenet.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import os 4 | 5 | from collections import defaultdict 6 | from torch.utils.data import Dataset 7 | 8 | from tqdm.autonotebook import tqdm 9 | 10 | 11 | def _add_channels(img,): 12 | while len(img.shape) < 3: # third axis is the channels 13 | img = np.expand_dims(img, axis=-1) 14 | while(img.shape[-1]) < 3: 15 | img = np.concatenate([img, img[:, :, -1:]], axis=-1) 16 | return img 17 | 18 | 19 | class TinyImageNetPaths: 20 | def __init__(self, root_dir): 21 | train_path = os.path.join(root_dir, 'train') 22 | val_path = os.path.join(root_dir, 'val') 23 | test_path = os.path.join(root_dir, 'test') 24 | 25 | wnids_path = os.path.join(root_dir, 'wnids.txt') 26 | words_path = os.path.join(root_dir, 'words.txt') 27 | 28 | self._make_paths(train_path, val_path, test_path, 29 | wnids_path, words_path) 30 | 31 | def _make_paths(self, train_path, val_path, test_path, 32 | wnids_path, words_path): 33 | self.ids = [] 34 | with open(wnids_path, 'r') as idf: 35 | for nid in idf: 36 | nid = nid.strip() 37 | self.ids.append(nid) 38 | self.nid_to_words = defaultdict(list) 39 | with open(words_path, 'r') as wf: 40 | for line in wf: 41 | nid, labels = line.split('\t') 42 | labels = list(map(lambda x: x.strip(), labels.split(','))) 43 | self.nid_to_words[nid].extend(labels) 44 | 45 | self.paths = { 46 | 'train': [], # [img_path, id, nid, box] 47 | 'val': [], # [img_path, id, nid, box] 48 | 'test': [] # img_path 49 | } 50 | 51 | # Get the test paths 52 | self.paths['test'] = list(map(lambda x: os.path.join(test_path, x), 53 | os.listdir(test_path))) 54 | # Get the validation paths and labels 55 | with open(os.path.join(val_path, 'val_annotations.txt')) as valf: 56 | for line in valf: 57 | fname, nid, x0, y0, x1, y1 = line.split() 58 | fname = os.path.join(val_path, 'images', fname) 59 | bbox = int(x0), int(y0), int(x1), int(y1) 60 | label_id = self.ids.index(nid) 61 | self.paths['val'].append((fname, label_id, nid, bbox)) 62 | 63 | # Get the training paths 64 | train_nids = os.listdir(train_path) 65 | for nid in train_nids: 66 | anno_path = os.path.join(train_path, nid, nid+'_boxes.txt') 67 | imgs_path = os.path.join(train_path, nid, 'images') 68 | label_id = self.ids.index(nid) 69 | with open(anno_path, 'r') as annof: 70 | for line in annof: 71 | fname, x0, y0, x1, y1 = line.split() 72 | fname = os.path.join(imgs_path, fname) 73 | bbox = int(x0), int(y0), int(x1), int(y1) 74 | self.paths['train'].append((fname, label_id, nid, bbox)) 75 | 76 | 77 | class TinyImageNetDataset(Dataset): 78 | def __init__(self, root_dir, mode='train', transform=None, max_samples=None): 79 | tinp = TinyImageNetPaths(root_dir) 80 | self.mode = mode 81 | self.label_idx = 1 # from [image, id, nid, box] 82 | self.transform = transform 83 | self.transform_results = dict() 84 | 85 | self.IMAGE_SHAPE = (64, 64, 3) 86 | 87 | self.img_data = [] 88 | self.label_data = [] 89 | 90 | self.max_samples = max_samples 91 | self.samples = 
tinp.paths[mode] 92 | self.samples_num = len(self.samples) 93 | 94 | if self.max_samples is not None: 95 | self.samples_num = min(self.max_samples, self.samples_num) 96 | self.samples = np.random.permutation(self.samples)[:self.samples_num] 97 | 98 | def __len__(self): 99 | return self.samples_num 100 | 101 | def __getitem__(self, idx): 102 | s = self.samples[idx] 103 | img = Image.open(s[0]) 104 | img_array = np.array(img) 105 | if img_array.shape[-1] < 3 or len(img_array.shape) < 3: 106 | img_array = _add_channels(img_array) 107 | img = Image.fromarray(img_array) 108 | lbl = None if self.mode == 'test' else s[self.label_idx] 109 | 110 | if self.transform: 111 | sample = self.transform(img) 112 | return sample, lbl 113 | 114 | 115 | # dataroot = "/home/davidoso/Documents/Data/" 116 | # a = TinyImageNetDataset(dataroot + 'tiny-imagenet-200/', preload=False) -------------------------------------------------------------------------------- /datasets/tiny_imagenet_c.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import os 4 | 5 | from collections import defaultdict 6 | from torch.utils.data import Dataset 7 | 8 | from tqdm.autonotebook import tqdm 9 | 10 | 11 | def _add_channels(img,): 12 | while len(img.shape) < 3: # third axis is the channels 13 | img = np.expand_dims(img, axis=-1) 14 | while(img.shape[-1]) < 3: 15 | img = np.concatenate([img, img[:, :, -1:]], axis=-1) 16 | return img 17 | 18 | 19 | class TinyImageNetPaths: 20 | def __init__(self, root_dir, corruption = 'original', level = 5): 21 | val_path = os.path.join(root_dir, corruption, str(level)) 22 | 23 | root_dir_original = root_dir.replace('Tiny-ImageNet-C', 'tiny-imagenet-200') 24 | wnids_path = os.path.join(root_dir_original, 'wnids.txt') 25 | words_path = os.path.join(root_dir_original, 'words.txt') 26 | 27 | self._make_paths(val_path, wnids_path, words_path) 28 | 29 | def _make_paths(self, corrupt_path, wnids_path, words_path): 30 | self.ids = [] 31 | with open(wnids_path, 'r') as idf: 32 | for nid in idf: 33 | nid = nid.strip() 34 | self.ids.append(nid) 35 | self.nid_to_words = defaultdict(list) 36 | with open(words_path, 'r') as wf: 37 | for line in wf: 38 | nid, labels = line.split('\t') 39 | labels = list(map(lambda x: x.strip(), labels.split(','))) 40 | self.nid_to_words[nid].extend(labels) 41 | 42 | self.paths = { 43 | 'corrupt': [], # [img_path, id, nid, box] 44 | } 45 | 46 | # Get the corruption paths 47 | corrupt_nids = os.listdir(corrupt_path) 48 | for nid in corrupt_nids: 49 | label_id = self.ids.index(nid) 50 | path = os.path.join(corrupt_path, nid) 51 | corrupt_name = os.listdir(path) 52 | for imgname in corrupt_name: 53 | fname = os.path.join(path, imgname) 54 | self.paths['corrupt'].append((fname, label_id, nid)) 55 | 56 | 57 | class TinyImageNetCDataset(Dataset): 58 | def __init__(self, root_dir, mode='corrupt', transform=None, max_samples=None, corruption = 'snow', level = 5): 59 | tinp = TinyImageNetPaths(root_dir, corruption=corruption, level=level) 60 | self.mode = mode 61 | self.label_idx = 1 # from [image, id, nid, box] 62 | self.transform = transform 63 | self.transform_results = dict() 64 | 65 | self.IMAGE_SHAPE = (64, 64, 3) 66 | 67 | self.img_data = [] 68 | self.label_data = [] 69 | 70 | self.max_samples = max_samples 71 | self.samples = tinp.paths[mode] 72 | self.samples_num = len(self.samples) 73 | 74 | if self.max_samples is not None: 75 | self.samples_num = min(self.max_samples, self.samples_num) 76 | 
self.samples = np.random.permutation(self.samples)[:self.samples_num] 77 | 78 | def __len__(self): 79 | return self.samples_num 80 | 81 | def __getitem__(self, idx): 82 | s = self.samples[idx] 83 | img = Image.open(s[0]) 84 | img_array = np.array(img) 85 | if img_array.shape[-1] < 3 or len(img_array.shape) < 3: 86 | img_array = _add_channels(img_array) 87 | img = Image.fromarray(img_array) 88 | lbl = None if self.mode == 'test' else s[self.label_idx] 89 | 90 | if self.transform: 91 | sample = self.transform(img) 92 | return sample, lbl 93 | 94 | 95 | # dataroot = "/home/davidoso/Documents/Data/" 96 | # a = TinyImageNetDataset(dataroot + 'tiny-imagenet-200/', preload=False) -------------------------------------------------------------------------------- /datasets/ucf101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re  # needed by read_data below 3 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 4 | 5 | from .oxford_pets import OxfordPets 6 | 7 | 8 | template = ['a photo of a person doing {}.'] 9 | 10 | 11 | class UCF101(DatasetBase): 12 | 13 | dataset_dir = 'UCF101' 14 | 15 | def __init__(self, root, num_shots): 16 | self.dataset_dir = os.path.join(root, self.dataset_dir) 17 | self.image_dir = os.path.join(self.dataset_dir, 'UCF-101-midframes') 18 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_UCF101.json') 19 | 20 | self.template = template 21 | 22 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 23 | n_shots_val = min(num_shots, 4) 24 | val = self.generate_fewshot_dataset(val, num_shots=n_shots_val) 25 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 26 | 27 | super().__init__(train_x=train, val=val, test=test) 28 | 29 | def read_data(self, cname2lab, text_file): 30 | text_file = os.path.join(self.dataset_dir, text_file) 31 | items = [] 32 | 33 | with open(text_file, 'r') as f: 34 | lines = f.readlines() 35 | for line in lines: 36 | line = line.strip().split(' ')[0] # trainlist: filename, label 37 | action, filename = line.split('/') 38 | label = cname2lab[action] 39 | 40 | elements = re.findall('[A-Z][^A-Z]*', action) 41 | renamed_action = '_'.join(elements) 42 | 43 | filename = filename.replace('.avi', '.jpg') 44 | impath = os.path.join(self.image_dir, renamed_action, filename) 45 | 46 | item = Datum( 47 | impath=impath, 48 | label=label, 49 | classname=renamed_action 50 | ) 51 | items.append(item) 52 | 53 | return items 54 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import os.path as osp 4 | import tarfile 5 | import zipfile 6 | from collections import defaultdict 7 | import gdown 8 | import json 9 | import torch 10 | from torch.utils.data import Dataset as TorchDataset 11 | import torchvision.transforms as T 12 | from PIL import Image 13 | 14 | 15 | def read_json(fpath): 16 | """Read json file from a path.""" 17 | with open(fpath, 'r') as f: 18 | obj = json.load(f) 19 | return obj 20 | 21 | 22 | def write_json(obj, fpath): 23 | """Writes to a json file.""" 24 | if not osp.exists(osp.dirname(fpath)): 25 | os.makedirs(osp.dirname(fpath)) 26 | with open(fpath, 'w') as f: 27 | json.dump(obj, f, indent=4, separators=(',', ': ')) 28 | 29 | 30 | def read_image(path): 31 | """Read image from path using ``PIL.Image``. 32 | 33 | Args: 34 | path (str): path to an image. 
35 | 36 | Returns: 37 | PIL image 38 | """ 39 | if not osp.exists(path): 40 | raise IOError('No file exists at {}'.format(path)) 41 | 42 | while True: 43 | try: 44 | img = Image.open(path).convert('RGB') 45 | return img 46 | except IOError: 47 | print( 48 | 'Cannot read image from {}, ' 49 | 'probably due to heavy IO. Will re-try'.format(path) 50 | ) 51 | 52 | 53 | def listdir_nohidden(path, sort=False): 54 | """List non-hidden items in a directory. 55 | 56 | Args: 57 | path (str): directory path. 58 | sort (bool): sort the items. 59 | """ 60 | items = [f for f in os.listdir(path) if not f.startswith('.') and 'sh' not in f] 61 | if sort: 62 | items.sort() 63 | return items 64 | 65 | 66 | class Datum: 67 | """Data instance which defines the basic attributes. 68 | 69 | Args: 70 | impath (str): image path. 71 | label (int): class label. 72 | domain (int): domain label. 73 | classname (str): class name. 74 | """ 75 | 76 | def __init__(self, impath='', label=0, domain=-1, classname=''): 77 | assert isinstance(impath, str) 78 | assert isinstance(label, int) 79 | assert isinstance(domain, int) 80 | assert isinstance(classname, str) 81 | 82 | self._impath = impath 83 | self._label = label 84 | self._domain = domain 85 | self._classname = classname 86 | 87 | @property 88 | def impath(self): 89 | return self._impath 90 | 91 | @property 92 | def label(self): 93 | return self._label 94 | 95 | @property 96 | def domain(self): 97 | return self._domain 98 | 99 | @property 100 | def classname(self): 101 | return self._classname 102 | 103 | 104 | class DatasetBase: 105 | """A unified dataset class for 106 | 1) domain adaptation 107 | 2) domain generalization 108 | 3) semi-supervised learning 109 | """ 110 | dataset_dir = '' # the directory where the dataset is stored 111 | domains = [] # string names of all domains 112 | 113 | def __init__(self, train_x=None, train_u=None, val=None, test=None): 114 | self._train_x = train_x # labeled training data 115 | self._train_u = train_u # unlabeled training data (optional) 116 | self._val = val # validation data (optional) 117 | self._test = test # test data 118 | 119 | self._num_classes = self.get_num_classes(train_x) 120 | self._lab2cname, self._classnames = self.get_lab2cname(train_x) 121 | 122 | @property 123 | def train_x(self): 124 | return self._train_x 125 | 126 | @property 127 | def train_u(self): 128 | return self._train_u 129 | 130 | @property 131 | def val(self): 132 | return self._val 133 | 134 | @property 135 | def test(self): 136 | return self._test 137 | 138 | @property 139 | def lab2cname(self): 140 | return self._lab2cname 141 | 142 | @property 143 | def classnames(self): 144 | return self._classnames 145 | 146 | @property 147 | def num_classes(self): 148 | return self._num_classes 149 | 150 | def get_num_classes(self, data_source): 151 | """Count number of classes. 152 | 153 | Args: 154 | data_source (list): a list of Datum objects. 155 | """ 156 | label_set = set() 157 | for item in data_source: 158 | label_set.add(item.label) 159 | return max(label_set) + 1 160 | 161 | def get_lab2cname(self, data_source): 162 | """Get a label-to-classname mapping (dict). 163 | 164 | Args: 165 | data_source (list): a list of Datum objects. 
166 | """ 167 | container = set() 168 | for item in data_source: 169 | container.add((item.label, item.classname)) 170 | mapping = {label: classname for label, classname in container} 171 | labels = list(mapping.keys()) 172 | labels.sort() 173 | classnames = [mapping[label] for label in labels] 174 | return mapping, classnames 175 | 176 | def check_input_domains(self, source_domains, target_domains): 177 | self.is_input_domain_valid(source_domains) 178 | self.is_input_domain_valid(target_domains) 179 | 180 | def is_input_domain_valid(self, input_domains): 181 | for domain in input_domains: 182 | if domain not in self.domains: 183 | raise ValueError( 184 | 'Input domain must belong to {}, ' 185 | 'but got [{}]'.format(self.domains, domain) 186 | ) 187 | 188 | def download_data(self, url, dst, from_gdrive=True): 189 | if not osp.exists(osp.dirname(dst)): 190 | os.makedirs(osp.dirname(dst)) 191 | 192 | if from_gdrive: 193 | gdown.download(url, dst, quiet=False) 194 | else: 195 | raise NotImplementedError 196 | 197 | print('Extracting file ...') 198 | 199 | try: 200 | tar = tarfile.open(dst) 201 | tar.extractall(path=osp.dirname(dst)) 202 | tar.close() 203 | except: 204 | zip_ref = zipfile.ZipFile(dst, 'r') 205 | zip_ref.extractall(osp.dirname(dst)) 206 | zip_ref.close() 207 | 208 | print('File extracted to {}'.format(osp.dirname(dst))) 209 | 210 | def generate_fewshot_dataset( 211 | self, *data_sources, num_shots=-1, repeat=True 212 | ): 213 | """Generate a few-shot dataset (typically for the training set). 214 | 215 | This function is useful when one wants to evaluate a model 216 | in a few-shot learning setting where each class only contains 217 | a few number of images. 218 | 219 | Args: 220 | data_sources: each individual is a list containing Datum objects. 221 | num_shots (int): number of instances per class to sample. 222 | repeat (bool): repeat images if needed. 223 | """ 224 | if num_shots < 1: 225 | if len(data_sources) == 1: 226 | return data_sources[0] 227 | return data_sources 228 | 229 | output = [] 230 | 231 | for data_source in data_sources: 232 | tracker = self.split_dataset_by_label(data_source) 233 | dataset = [] 234 | 235 | for label, items in tracker.items(): 236 | if len(items) >= num_shots: 237 | sampled_items = random.sample(items, num_shots) 238 | else: 239 | if repeat: 240 | sampled_items = random.choices(items, k=num_shots) 241 | else: 242 | sampled_items = items 243 | dataset.extend(sampled_items) 244 | 245 | output.append(dataset) 246 | 247 | if len(output) == 1: 248 | return output[0] 249 | 250 | return output 251 | 252 | def split_dataset_by_label(self, data_source): 253 | """Split a dataset, i.e. a list of Datum objects, 254 | into class-specific groups stored in a dictionary. 255 | 256 | Args: 257 | data_source (list): a list of Datum objects. 258 | """ 259 | output = defaultdict(list) 260 | 261 | for item in data_source: 262 | output[item.label].append(item) 263 | 264 | return output 265 | 266 | def split_dataset_by_domain(self, data_source): 267 | """Split a dataset, i.e. a list of Datum objects, 268 | into domain-specific groups stored in a dictionary. 269 | 270 | Args: 271 | data_source (list): a list of Datum objects. 
272 | """ 273 | output = defaultdict(list) 274 | 275 | for item in data_source: 276 | output[item.domain].append(item) 277 | 278 | return output 279 | 280 | 281 | class DatasetWrapper(TorchDataset): 282 | def __init__(self, data_source, input_size, transform=None, is_train=False, 283 | return_img0=False, k_tfm=1): 284 | self.data_source = data_source 285 | self.transform = transform # accept list (tuple) as input 286 | self.is_train = is_train 287 | # Augmenting an image K>1 times is only allowed during training 288 | self.k_tfm = k_tfm if is_train else 1 289 | self.return_img0 = return_img0 290 | 291 | if self.k_tfm > 1 and transform is None: 292 | raise ValueError( 293 | 'Cannot augment the image {} times ' 294 | 'because transform is None'.format(self.k_tfm) 295 | ) 296 | 297 | # Build transform that doesn't apply any data augmentation 298 | interp_mode = T.InterpolationMode.BICUBIC 299 | to_tensor = [] 300 | to_tensor += [T.Resize(input_size, interpolation=interp_mode)] 301 | to_tensor += [T.ToTensor()] 302 | normalize = T.Normalize( 303 | mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711) 304 | ) 305 | to_tensor += [normalize] 306 | self.to_tensor = T.Compose(to_tensor) 307 | 308 | def __len__(self): 309 | return len(self.data_source) 310 | 311 | def __getitem__(self, idx): 312 | item = self.data_source[idx] 313 | 314 | output = { 315 | 'label': item.label, 316 | 'domain': item.domain, 317 | 'impath': item.impath 318 | } 319 | 320 | img0 = read_image(item.impath) 321 | 322 | if self.transform is not None: 323 | if isinstance(self.transform, (list, tuple)): 324 | for i, tfm in enumerate(self.transform): 325 | img = self._transform_image(tfm, img0) 326 | keyname = 'img' 327 | if (i + 1) > 1: 328 | keyname += str(i + 1) 329 | output[keyname] = img 330 | else: 331 | img = self._transform_image(self.transform, img0) 332 | output['img'] = img 333 | 334 | if self.return_img0: 335 | output['img0'] = self.to_tensor(img0) 336 | 337 | return output['img'], output['label'] 338 | 339 | def _transform_image(self, tfm, img0): 340 | img_list = [] 341 | 342 | for k in range(self.k_tfm): 343 | img_list.append(tfm(img0)) 344 | 345 | img = img_list 346 | if len(img) == 1: 347 | img = img[0] 348 | 349 | return img 350 | 351 | 352 | def build_data_loader( 353 | data_source=None, 354 | batch_size=64, 355 | input_size=224, 356 | tfm=None, 357 | is_train=True, 358 | shuffle=False, 359 | dataset_wrapper=None, 360 | sampler=None, 361 | num_workers = 8 362 | ): 363 | 364 | if dataset_wrapper is None: 365 | dataset_wrapper = DatasetWrapper 366 | 367 | # Build data loader 368 | data_loader = torch.utils.data.DataLoader( 369 | dataset_wrapper(data_source, input_size=input_size, transform=tfm, is_train=is_train), 370 | batch_size=batch_size, 371 | num_workers=num_workers, 372 | shuffle=shuffle, 373 | drop_last=False, 374 | pin_memory=(torch.cuda.is_available()), 375 | sampler=sampler, 376 | ) 377 | assert len(data_loader) > 0 378 | 379 | return data_loader 380 | -------------------------------------------------------------------------------- /datasets/visda.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from torch.utils.data import Dataset 4 | from typing import Callable, Optional 5 | 6 | 7 | class VisdaTest(Dataset): 8 | def __init__(self, root: str, transforms: Optional[Callable] = None): 9 | self.root = root 10 | self.transforms = transforms 11 | self.img_list = np.loadtxt(root + 
'image_list.txt', dtype=str) 12 | 13 | def __len__(self): 14 | return self.img_list.shape[0] 15 | 16 | def __getitem__(self, idx): 17 | name = self.img_list[idx][0] 18 | label = int(self.img_list[idx][1]) 19 | 20 | img = Image.open(self.root + 'test/' + name) 21 | if self.transforms is not None: 22 | img = self.transforms(img) 23 | 24 | return img, label -------------------------------------------------------------------------------- /images/abstract_barplot_github_version.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cfuchs2023/OGA/b91d449a28d0958849c43dc7f698405b9fcfe4b4/images/abstract_barplot_github_version.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import argparse 3 | import runner 4 | 5 | def get_arguments(): 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--data_root_path', type=str) 8 | parser.add_argument('-d', '--datasets', default = ['dtd'], nargs="*", help = 'List of datasets to adapt on.') 9 | parser.add_argument('--adapt_method_name', default='OGA', type=str, 10 | help='The name of the online adaptation method to use. One of TDA, DMN or OGA.') 11 | parser.add_argument('--backbone', default='vit_b16', type=str, help = 'Name of the backbone to use. Examples: vit_b16 or rn101.') 12 | parser.add_argument('--root_cache_path', default = None, type = str, help = 'Root path of cached features and targets. Features and targets should be located at root_cache_path/{dataset}/cache. Defaults to data_root_path internally.') 13 | parser.add_argument('--root_prompts_path', default = None, type = str, help = 'Path where learned prompts should be located when using coop or taskres.') #TODO: clarify 14 | parser.add_argument('--root_save_path', default = None, type = str, help = 'Path where results are stored. Defaults to data_root_path/results internally.') 15 | parser.add_argument('--run_name', default = None, type = str, help = 'Name of the results files. Has an internal default value and can be left to None. The code outputs two results files, {run_name}.json and {run_name}.pickle.') 16 | parser.add_argument('--n_runs', default = 100, type = int, help = 'Number of runs for each dataset.') 17 | parser.add_argument('--batch_size', default = 32, type = int) 18 | parser.add_argument('--shot_capacity', default = 8, type = int, help = 'Maximum size of the memory of the online adaptation method, expressed in number of shots per class.') 19 | parser.add_argument('--check_on_fullset', default = False, type = int, help = 'Whether to evaluate the accuracy of the adapted model on the complete test set at regular intervals. Results are stored in {run_name}.pickle.') 20 | parser.add_argument('--master_seed', default = 42, type = int, help = 'Master seed for generating identical runs. For a fixed tuple (n_runs, master_seed), the runs are always the same regardless of the method.') 21 | parser.add_argument('--prompts_types', default = 'standard', type = str, help = 'Type of prompts. 
One of \' standard \', ') #TODO 22 | parser.add_argument('--prompts_n_shots', default = None, type = int) 23 | parser.add_argument('--prompts_seed', default = None, type = int) 24 | 25 | # Parse initial arguments 26 | args, remaining_args = parser.parse_known_args() 27 | 28 | # Conditionally add named arguments for OGA 29 | if args.adapt_method_name == 'OGA': 30 | oga_parser = argparse.ArgumentParser() 31 | oga_parser.add_argument('--tau', default = 0.05, type=float, help="tau parameter of the MAP estimator.") 32 | oga_parser.add_argument('--sig_type', default = 'RidgeMoorePenrose',type=str, help="Type of the estimator of inverse matrix. One of Ridge (always use Bayes-Ridge estimator) , MoorePenrose (always use (pseudo-)inverse) or RidgeMoorePenrose (switch estimator based on the number of samples available for estimation). ") 33 | oga_parser.add_argument('--normalize_mu', default = False, type=bool, help="Wether to normalize mu after estimation.") 34 | oga_args = oga_parser.parse_args(remaining_args) 35 | 36 | # Add parsed OGA-specific parameters to the main args 37 | # for key, value in vars(oga_args).items(): 38 | # setattr(args, key, value) 39 | args.OGA_params = oga_args 40 | print(args) 41 | elif args.adapt_method_name == 'TDA': 42 | tda_parser = argparse.ArgumentParser() 43 | tda_parser.add_argument('--pos_cache_logits_scale', default = 2.0, type = float, help = "Scaling factor for logits obtained with the positive cache.") 44 | tda_parser.add_argument('--neg_cache_logits_scale', default = 0.117, type = float, help = 'Scaling factor for logits obtained with the negative cache.') 45 | tda_parser.add_argument('--neg_cache_capacity', default = 3, type = int, help = 'Maximum size of the negative cache in number of shots per class. The size of the positive cache is computed to reach the target shot_capacity.') 46 | 47 | tda_args = tda_parser.parse_args(remaining_args) 48 | args.TDA_params = tda_args 49 | 50 | elif args.adapt_method_name == 'DMN': 51 | dmn_parser = argparse.ArgumentParser() 52 | dmn_parser.add_argument('--DMN_prob_factor', default = 1.0, type = float) 53 | dmn_args = dmn_parser.parse_args(remaining_args) 54 | args.DMN_params = dmn_args 55 | 56 | return args 57 | 58 | 59 | def main(): 60 | args = get_arguments() 61 | assert args.data_root_path is not None 62 | runner.run(args) 63 | return None 64 | 65 | 66 | if __name__ == '__main__': 67 | main() -------------------------------------------------------------------------------- /runner.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def run(args): 4 | print('Beginning import') 5 | import torch 6 | import os 7 | from tqdm import tqdm 8 | import clip 9 | import utils as uti 10 | # from utils import datasets #Dict with all possible datasets 11 | # from utils import backbones #Dict with the possible backbones 12 | import pickle 13 | import json 14 | from OGA import OGA_core 15 | from TDA import TDA_core as tdac 16 | import numpy as np 17 | import DMN.DMN_clip_wrapper as DMN_clip_wrapper 18 | import DMN.DMN_core as DMN_core 19 | print('Import done') 20 | 21 | 22 | # ==================== Get arguments 23 | root_path = args.data_root_path 24 | 25 | if args.root_save_path is not None: 26 | base_save_path = args.root_save_path 27 | else: 28 | base_save_path = os.path.join(root_path, 'results') 29 | os.makedirs(base_save_path, exist_ok = True) 30 | 31 | if args.root_cache_path is not None: 32 | base_cache_dir = args.root_cache_path 33 | else: 34 | base_cache_dir = root_path 35 | 36 | if 
args.root_prompts_path is not None: 37 | root_prompts_path = args.root_prompts_path 38 | else: 39 | root_prompts_path = root_path 40 | 41 | model_name = args.backbone 42 | n_runs = args.n_runs #100 #only for random tasks 43 | 44 | prompts_types = args.prompts_types 45 | prompts_n_shots = args.prompts_n_shots 46 | prompts_seed = args.prompts_seed 47 | batch_size = args.batch_size 48 | shot_capacity = args.shot_capacity # size of the memory in shots per class 49 | 50 | check_on_fullset = args.check_on_fullset #Evaluate accuracy of the adapted model on the complete test set every X batches 51 | adapt_method_name = args.adapt_method_name 52 | 53 | # ==================== Generate tasks seeds from master seed 54 | rng = np.random.default_rng(args.master_seed) 55 | shuffles_seed = rng.choice(range(10000*n_runs), size = n_runs, replace = False) 56 | 57 | 58 | # ==================== Load clip model 59 | clip_model, preprocess = clip.load(uti.backbones[model_name]) 60 | 61 | # =================== Get run name for saving results 62 | if args.run_name is not None: 63 | run_name = args.run_name 64 | else: 65 | if prompts_types in ['coop-fewshot', 'taskres-fewshot']: 66 | pname = prompts_types.split('-')[0] 67 | prompts_suffix = f'prompts_{pname}_seed_{prompts_seed}_shots_{prompts_n_shots}' 68 | else: 69 | pname = prompts_types 70 | prompts_suffix = f'prompts_{pname}' 71 | 72 | if adapt_method_name == 'OGA': 73 | method_suffix = f'OGA_sig_{args.OGA_params.sig_type}_shot_capacity_{shot_capacity}' 74 | elif adapt_method_name == 'TDA': 75 | method_suffix = f'TDA_shot_capacity_{shot_capacity}_neg_capacity_{args.TDA_params.neg_cache_capacity}' 76 | elif adapt_method_name == 'DMN': 77 | method_suffix = f'DMN_shot_capacity_{shot_capacity}' 78 | 79 | 80 | run_name = f'{method_suffix}_{prompts_suffix}' 81 | 82 | # ==================== Prepare results dict with some arguments 83 | resu = {} #Dict storing all results 84 | if adapt_method_name == 'OGA': 85 | tau = args.OGA_params.tau 86 | if type(tau) is float: 87 | resu['tau'] = tau #.cpu().numpy().tolist() 88 | else: 89 | resu['tau'] = tau.cpu().numpy().tolist() 90 | 91 | # ==================== Loop over requested datasets 92 | for dataset_name in args.datasets: 93 | resu[dataset_name] = {} 94 | resu[dataset_name][adapt_method_name] = {} 95 | if check_on_fullset: 96 | if dataset_name == 'food101': 97 | check_on_fullsets_interval = 15 98 | elif dataset_name == 'stanford_cars': 99 | check_on_fullsets_interval = 10 100 | else: 101 | check_on_fullsets_interval = 5 102 | 103 | # ******* Load features and labels 104 | if base_cache_dir is None: 105 | cache_dir = os.path.join(root_path, uti.datasets[dataset_name], 'cache') 106 | else: 107 | cache_dir = os.path.join(base_cache_dir, uti.datasets[dataset_name], 'cache') 108 | 109 | train_loader, val_loader, test_loader, dataset,\ 110 | features_and_labels\ 111 | = uti.load_features(dataset_name, 112 | root_path, 113 | cache_dir, 114 | preprocess, 115 | clip_model, 116 | uti.backbones[model_name], 117 | splits = ['test']) 118 | test_features, test_labels = features_and_labels 119 | K = torch.max(test_labels)+1 120 | d = test_features.shape[-1] 121 | 122 | # ******* Load encoded textual prompts 123 | clip_prototypes = uti.load_clip_classifier(dataset_name, 124 | root_prompts_path, 125 | prompts_types = prompts_types, 126 | dataset = dataset, 127 | prompts_n_shots = prompts_n_shots, 128 | seed = prompts_seed, 129 | backbone = model_name, 130 | num_classes = K, 131 | model_dim = d, 132 | clip_model = clip_model, ) 133 | 134 | 
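# Note: clip_prototypes is expected to hold one text embedding per class as its columns (a (d, K) matrix),
# so test_features (N, d) @ clip_prototypes.squeeze() below yields the (N, K) zero-shot logits,
# scaled by the temperature temp defined just below.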
135 | # ******* Check zero-shot performance 136 | temp = 100 137 | zs_logits = torch.zeros((test_features.shape[0], K), dtype = torch.float64, device = test_features.device) 138 | zs_logits[...] = temp * test_features@clip_prototypes.squeeze() 139 | zs_pred = torch.argmax(zs_logits, dim = -1).cuda() 140 | zs_acc = torch.sum(zs_pred.cpu() == test_labels.cpu())/zs_pred.shape[0] 141 | print(f'Zero-shot accuracy : {zs_acc}') 142 | resu[dataset_name]['zero_shot'] = zs_acc.item() 143 | 144 | # ******* Cache zero-shot soft labels and zero shot entropies 145 | zs_probs = zs_logits.softmax(-1).cuda() 146 | zs_entropy = - torch.sum(torch.log(zs_probs+1e-9)*zs_probs, dim = -1) 147 | 148 | 149 | # ******* Prepare list of results 150 | all_acc_batchs_all_tasks_ = [] 151 | if check_on_fullset: # If evaluating on full test set at regular intervals is requested 152 | all_full_datasets_accuracies_ = [] 153 | 154 | # ******* Loop over tasks 155 | for n_ in tqdm(range(shuffles_seed.shape[0])): 156 | if check_on_fullset: 157 | all_full_datasets_accuracies_.append([]) 158 | rng = np.random.default_rng(seed = shuffles_seed[n_]) 159 | shuffle = torch.tensor(rng.choice(range(test_features.shape[0]), 160 | test_features.shape[0], 161 | replace = False)) 162 | 163 | # ******* Initialize online adaptation model 164 | if adapt_method_name == 'OGA': 165 | oga_model = OGA_core.GaussAdapt(clip_prototypes.squeeze().T, 166 | shot_capacity = shot_capacity, 167 | sig_type = args.OGA_params.sig_type).cuda() 168 | elif adapt_method_name == 'TDA': 169 | TDA_neg_cache = tdac.TDA_NegCache(K, d, shot_capacity=args.TDA_params.neg_cache_capacity).cuda() 170 | TDA_pos_cache = tdac.TDA_PosCache(K, d, shot_capacity = shot_capacity - args.TDA_params.neg_cache_capacity).cuda() 171 | TDA_pos_cache_logits_scale = args.TDA_params.pos_cache_logits_scale 172 | TDA_neg_cache_logits_scale = args.TDA_params.neg_cache_logits_scale 173 | 174 | elif adapt_method_name == 'DMN': 175 | dmn_args = uti.get_default_dmn_args() 176 | dmn_args.memory_size = shot_capacity 177 | DMN_clip = DMN_clip_wrapper.DMNClipWrapper(clip_model, 178 | preprocess, 179 | 'cuda', 180 | dataset.classnames, 181 | batch_size, 182 | arch = uti.backbones[model_name], 183 | memory_size = shot_capacity) 184 | DMN_clip.reset_classnames(dataset) 185 | dmn = DMN_core.DualMem(dmn_args, feat_dim = test_features.shape[-1], class_num = K) 186 | dmn = dmn.cuda() 187 | DMN_clip.eval() 188 | dmn.eval() 189 | with torch.autocast("cuda", dtype = torch.float16), torch.no_grad(): 190 | text_feat, text_feat_full = DMN_clip.get_text_features() 191 | # ******* Initialize task 192 | start = 0 # Start of current batch slice 193 | end = 0 # End of current batch slice 194 | num_batch = 1 # Current batch number 195 | all_acc_batches_single_task_ = [] # Per batch accuracy of current task 196 | last_batch = False 197 | 198 | # ******* Loop over batches 199 | while not(last_batch): 200 | 201 | # ******* Get indexes of current batch 202 | start = end 203 | end = min(start+batch_size, test_features.shape[0]) 204 | indices = shuffle[start:end] 205 | if end == test_features.shape[0]: 206 | last_batch = True 207 | 208 | # ******* Evaluate accuracy on full test set if need be 209 | if check_on_fullset and (not(num_batch % check_on_fullsets_interval) or last_batch): 210 | if adapt_method_name == 'OGA': 211 | log_probs = oga_model.get_log_probs(test_features) 212 | z, p = oga_model.get_MAP(zs_probs, log_probs, tau = tau, simplex_p = True) 213 | pred = torch.argmax(z, dim = 1) 214 | elif adapt_method_name == 
'TDA': 215 | pos_cache_logits = TDA_pos_cache.get_logits(test_features, beta = 5, alpha = 1) 216 | neg_cache_logits = TDA_neg_cache.get_logits(test_features, beta = 1) 217 | query_logits = zs_logits + TDA_pos_cache_logits_scale * pos_cache_logits + TDA_neg_cache_logits_scale * neg_cache_logits 218 | pred = torch.argmax(query_logits, dim = -1) 219 | 220 | acc = torch.sum(pred == test_labels, dim = 0)/test_labels.shape[0] 221 | all_full_datasets_accuracies_[-1].append([num_batch-1, acc.cpu().item()]) 222 | 223 | # ******* Get batch data 224 | batch_features = test_features[indices, :] 225 | batch_labels = test_labels[indices] 226 | batch_zs_pseudo_labels = zs_pred[indices] 227 | batch_zs_probs = zs_probs[indices] 228 | batch_zs_logits = zs_logits[indices, :] 229 | batch_zs_entropy = zs_entropy[indices] 230 | 231 | # ******* Update memory on batch 232 | if adapt_method_name == 'OGA': 233 | _ = oga_model.update_memory(batch_features, 234 | batch_zs_logits, 235 | batch_zs_probs, 236 | batch_zs_entropy, 237 | batch_zs_pseudo_labels, 238 | tau = args.OGA_params.tau, 239 | normalize_mu = args.OGA_params.normalize_mu) 240 | elif adapt_method_name == 'TDA': 241 | _ = TDA_pos_cache.update_memory(batch_features, batch_zs_logits) 242 | _ = TDA_neg_cache.update_memory(batch_features, batch_zs_logits) 243 | elif adapt_method_name == 'DMN': 244 | with torch.autocast("cuda"), torch.no_grad(): 245 | for ju,u in enumerate(indices): 246 | DMN_clip.image_features_global = batch_features[ju:ju+1,...] 247 | # We never use augmentations 248 | # confidence_prediction, selected_idx, confused_weak_output, confused_idx = select_confident_samples(img_text.softmax(1), 249 | # dmn_args.selection_p) 250 | dmn.init_pred = batch_zs_probs[ju:ju+1,:] 251 | dmn.update_memory_bank(DMN_clip) 252 | 253 | # ******* Predict on batch using updated memory 254 | if adapt_method_name == 'OGA': 255 | log_probs = oga_model.get_log_probs(batch_features) 256 | z, p = oga_model.get_MAP(batch_zs_probs, log_probs, tau = args.OGA_params.tau, simplex_p = True) 257 | pred = torch.argmax(z, dim = 1) 258 | elif adapt_method_name == 'TDA': 259 | pos_cache_logits = TDA_pos_cache.get_logits(batch_features, beta = 5, alpha = 1) 260 | neg_cache_logits = TDA_neg_cache.get_logits(batch_features, beta = 1) 261 | query_logits = batch_zs_logits + TDA_pos_cache_logits_scale * pos_cache_logits + TDA_neg_cache_logits_scale * neg_cache_logits 262 | pred = torch.argmax(query_logits, dim = -1) 263 | elif adapt_method_name == 'DMN': 264 | with torch.autocast("cuda"), torch.no_grad(): # 265 | all_img_logits = dmn.fast_get_image_pred(batch_features, DMN_clip, clip_prototypes) 266 | all_img_probs = all_img_logits.softmax(-1) 267 | final_probs = batch_zs_probs + args.DMN_params.DMN_prob_factor * all_img_probs 268 | pred = torch.argmax(final_probs, dim = -1) 269 | acc = torch.sum(pred == batch_labels, dim = 0)/batch_labels.shape[0] 270 | all_acc_batches_single_task_.append(acc.cpu()) 271 | acc = torch.sum(pred == batch_labels, dim = 0)/batch_labels.shape[0] 272 | 273 | # ******* Store accuracy on current batch 274 | all_acc_batches_single_task_.append(acc.cpu()) 275 | num_batch += 1 276 | 277 | 278 | all_acc_batches_single_task = torch.stack(all_acc_batches_single_task_) 279 | all_acc_batchs_all_tasks_.append(all_acc_batches_single_task) 280 | print(f'{adapt_method_name} average accuracy on task number {n_} : ', all_acc_batches_single_task.mean(dim=0).cpu()) 281 | 282 | # ******* Prepare results for disk writing 283 | all_acc_batchs_all_tasks = 
torch.stack(all_acc_batchs_all_tasks_) 284 | avg_per_task_accs = torch.mean(all_acc_batchs_all_tasks, dim = 1) 285 | avg_accs = torch.mean(avg_per_task_accs, dim = 0) 286 | resu[dataset_name][adapt_method_name]['avg_acc'] = avg_accs.cpu().numpy().tolist() 287 | if check_on_fullset: 288 | resu[dataset_name][adapt_method_name]['checks_on_full_dataset'] = torch.tensor(all_full_datasets_accuracies_) 289 | 290 | 291 | ignore_keys = ['all_acc_batchs_all_tasks', 292 | 'checks_on_full_dataset'] 293 | partial_resu = {} # Resu with only avg acc per dataset 294 | for dname in resu.keys(): 295 | if type(resu[dname]) is dict: 296 | partial_resu[dname] = {} 297 | partial_resu[dname]['zero_shot'] = resu[dname]['zero_shot'] 298 | for method_key in resu[dname].keys(): 299 | if type(resu[dname][method_key]) is dict: 300 | partial_resu[dname][method_key] = {} 301 | for key in resu[dname][method_key].keys(): 302 | if key not in ignore_keys: 303 | partial_resu[dname][method_key][key] = resu[dname][method_key][key] 304 | else: 305 | partial_resu[dname][method_key] = resu[dname][method_key] 306 | else: 307 | partial_resu[dname] = resu[dname] 308 | 309 | # ******* Store average acc / dataset in a json 310 | with open(os.path.join(base_save_path, run_name+'.json'), 'w') as f: 311 | json.dump(partial_resu,f) 312 | 313 | # ******* Store complete results in a pickle 314 | with open(os.path.join(base_save_path, run_name+'.pickle'), 'wb') as f: 315 | pickle.dump(resu,f) 316 | 317 | return None -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from tqdm import tqdm 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | import os 7 | import clip 8 | import json 9 | 10 | def clip_classifier(classnames, template, clip_model, reduce='mean', gpt=False, wordnet_dict=None): 11 | with torch.no_grad(): 12 | clip_weights = [] 13 | if wordnet_dict is not None: 14 | indices = [] 15 | i = 0 16 | for classname in classnames: 17 | allnames = [classname] + wordnet_dict[classname] 18 | for name in allnames: 19 | 20 | # Tokenize the prompts 21 | name = name.replace('_', ' ') 22 | 23 | texts = [t.format(name) for t in template] 24 | texts = clip.tokenize(texts).cuda() 25 | 26 | class_embeddings = clip_model.encode_text(texts) 27 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) 28 | if reduce=='mean': 29 | class_embedding = class_embeddings.mean(dim=0) 30 | class_embedding /= class_embedding.norm() 31 | clip_weights.append(class_embedding) 32 | if reduce is None: 33 | class_embeddings /= class_embeddings.norm(dim=1, keepdim=True) 34 | clip_weights.append(class_embeddings) 35 | i+=1 36 | indices.append(i) 37 | 38 | return clip_weights, indices 39 | else: 40 | 41 | for classname in classnames: 42 | 43 | # Tokenize the prompts 44 | classname = classname.replace('_', ' ') 45 | 46 | if gpt: 47 | texts = template[classname] 48 | else: 49 | texts = [t.format(classname) for t in template] 50 | texts = clip.tokenize(texts).cuda() 51 | 52 | class_embeddings = clip_model.encode_text(texts) 53 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) 54 | if reduce=='mean': 55 | class_embedding = class_embeddings.mean(dim=0) 56 | class_embedding /= class_embedding.norm() 57 | clip_weights.append(class_embedding) 58 | if reduce is None: 59 | class_embeddings /= class_embeddings.norm(dim=1, keepdim=True) 60 | clip_weights.append(class_embeddings) 61 | 62 | clip_weights = 
torch.stack(clip_weights, dim=-1).cuda() 63 | return clip_weights 64 | 65 | 66 | def get_samples_feature_and_labels(cache_dir, splits = ['test'], backbone_name = 'ViT-B/16', dataset_name = ''): 67 | model_cache_name = backbone_name.replace('/', '_') 68 | model_cache_name = model_cache_name.replace('-', '_') 69 | out = [] 70 | 71 | for spl in splits: 72 | features_path = os.path.join(cache_dir, f'{model_cache_name}_{spl}_features.pt') 73 | try: 74 | _features = torch.load(features_path).cuda() 75 | except FileNotFoundError: 76 | raise FileNotFoundError(f'Could not find cached features at {features_path}. Run compute_features.py or check the --root_cache_path argument. ') 77 | _labels = torch.load(os.path.join(cache_dir, f'{spl}_target.pt')).cuda() 78 | out.append(_features) 79 | out.append(_labels) 80 | 81 | return out 82 | 83 | 84 | 85 | 86 | def get_all_features(cfg, train_loader, val_loader, test_loader, dataset, clip_model): 87 | clip_prototypes = clip_classifier(dataset.classnames, dataset.template, clip_model, reduce=None) 88 | test_features, test_labels = pre_load_features(cfg, "test", clip_model, test_loader) 89 | 90 | if val_loader is not None: 91 | val_features, val_labels = pre_load_features(cfg, "val", clip_model, val_loader) 92 | else: 93 | val_features, val_labels = None, None 94 | 95 | shot_features = None 96 | shot_labels = None 97 | 98 | if cfg['shots'] > 0: 99 | shot_features, shot_labels = build_cache_model(cfg, clip_model, train_loader, n_views=0, 100 | reduce=None) 101 | #val_features, val_labels = pre_load_features(cfg, "val", clip_model, val_loader) 102 | return shot_features, shot_labels, val_features, val_labels, test_features, test_labels, clip_prototypes 103 | 104 | 105 | def build_cache_model(cfg, clip_model, train_loader_cache, n_views=0, reduce=None): 106 | print('... 
for shot samples from train split:') 107 | 108 | if cfg['load_cache'] == False: 109 | cache_keys = [] 110 | cache_values = [] 111 | if n_views == 0: 112 | n_epochs =1 113 | else: 114 | n_epochs = n_views 115 | with torch.no_grad(): 116 | # Data augmentation for the cache model 117 | for augment_idx in range(n_epochs): 118 | train_features = [] 119 | train_labels = [] 120 | for i, (images, target) in enumerate(tqdm(train_loader_cache)): 121 | images = images.cuda() 122 | image_features = clip_model.encode_image(images) 123 | train_features.append(image_features) 124 | 125 | if augment_idx == 0: 126 | target = target.cuda() 127 | cache_values.append(target) 128 | 129 | cache_keys.append(torch.cat(train_features, dim=0).unsqueeze(0)) 130 | 131 | 132 | 133 | if n_views == 1: 134 | cache_keys = torch.cat(cache_keys, dim=0).mean(dim=0) 135 | cache_keys /= cache_keys.norm(dim=-1, keepdim=True) 136 | #cache_keys = cache_keys.permute(1, 0) 137 | else: 138 | cache_keys = torch.cat(cache_keys, dim=0) # [n_views, n_classes, n_features] 139 | if reduce == 'mean': 140 | cache_keys = cache_keys.mean(0, keepdim=True) 141 | 142 | cache_keys /= cache_keys.norm(dim=-1, keepdim=True) 143 | cache_keys.permute(0, 2, 1) 144 | 145 | cache_values = F.one_hot(torch.cat(cache_values, dim=0)).half() 146 | 147 | torch.save(cache_keys, cfg['cache_dir'] + '/keys_' + str(cfg['shots']) + "shots.pt") 148 | torch.save(cache_values, cfg['cache_dir'] + '/values_' + str(cfg['shots']) + "shots.pt") 149 | 150 | else: 151 | cache_keys = torch.load(cfg['cache_dir'] + '/keys_' + str(cfg['shots']) + "shots.pt") 152 | cache_values = torch.load(cfg['cache_dir'] + '/values_' + str(cfg['shots']) + "shots.pt") 153 | 154 | return cache_keys, cache_values 155 | 156 | 157 | 158 | 159 | 160 | 161 | def pre_load_features(cfg, split, clip_model, loader, n_views=1, mode = 'new'): 162 | 163 | #print('... from {} split:'.format(split)) 164 | if cfg['load_pre_feat'] == False: 165 | if mode == 'new': 166 | batch_size = loader.batch_size 167 | num_samples = len(loader.dataset) 168 | features, labels = [], [] 169 | 170 | with torch.no_grad(): 171 | 172 | for view in range(n_views): 173 | length = 0 174 | i = 0 175 | #for i, (images, target) in enumerate(tqdm(loader)): 176 | pbar = tqdm(total = num_samples) 177 | while i < num_samples: 178 | j = min(i+batch_size, num_samples) 179 | images = torch.vstack([loader.dataset[k][0][None,...] for k in range(i,j)]) 180 | target = torch.cat([torch.tensor(loader.dataset[k][1], dtype = torch.int64)[None] for k in range(i,j)]) 181 | if len(images.shape)==3: 182 | images = images[None,...] 
183 | if n_views == 1: 184 | 185 | images, target = images.cuda(), target.cuda() 186 | 187 | 188 | image_features = clip_model.encode_image(images) 189 | 190 | image_features /= image_features.norm(dim=-1, keepdim=True) 191 | 192 | 193 | features.append(image_features.cpu()) 194 | labels.append(target.cpu()) 195 | else: 196 | images, target = images.cuda(), target.cuda() 197 | image_features = clip_model.encode_image(images) 198 | image_features /= image_features.norm(dim=-1, keepdim=True) 199 | if view == 0: 200 | labels.append(target.cpu()) 201 | if i ==0: 202 | mean_features = image_features 203 | else: 204 | mean_features = torch.cat((mean_features, image_features)) 205 | else: 206 | mean_features[length:length+image_features.size(0)] += image_features 207 | length += image_features.size(0) 208 | pbar.update(batch_size) 209 | i = j 210 | else: 211 | features, labels = [], [] 212 | with torch.no_grad(): 213 | for view in range(n_views): 214 | length = 0 215 | for i, (images, target) in enumerate(tqdm(loader)): 216 | if n_views == 1: 217 | 218 | images, target = images.cuda(), target.cuda() 219 | 220 | 221 | image_features = clip_model.encode_image(images) 222 | 223 | image_features /= image_features.norm(dim=-1, keepdim=True) 224 | 225 | 226 | features.append(image_features.cpu()) 227 | labels.append(target.cpu()) 228 | else: 229 | images, target = images.cuda(), target.cuda() 230 | image_features = clip_model.encode_image(images) 231 | image_features /= image_features.norm(dim=-1, keepdim=True) 232 | if view == 0: 233 | labels.append(target.cpu()) 234 | if i ==0: 235 | mean_features = image_features 236 | else: 237 | mean_features = torch.cat((mean_features, image_features)) 238 | else: 239 | mean_features[length:length+image_features.size(0)] += image_features 240 | length += image_features.size(0) 241 | 242 | if n_views > 1: 243 | mean_features = mean_features / n_views 244 | features = mean_features / mean_features.norm(dim=-1, keepdim=True) 245 | labels = torch.cat(labels) 246 | 247 | elif n_views==1: 248 | features = torch.cat(features) 249 | labels = torch.cat(labels) 250 | 251 | 252 | backbone_name = cfg['backbone'].replace('/', '_') 253 | backbone_name = backbone_name.replace('\\', '_') 254 | backbone_name = backbone_name.replace('-', '_') 255 | print() 256 | torch.save(features, cfg['cache_dir'] + "/" + f'{backbone_name}_{split}_features.pt') 257 | if not(os.path.exists(cfg['cache_dir'] + "/" + f'{split}_target.pt')): 258 | torch.save(labels, cfg['cache_dir'] + "/" + f'{split}_target.pt') 259 | 260 | else: 261 | print('LOADING FEATURES') 262 | try: 263 | features = torch.load(cfg['cache_dir'] + "/" + split + "_f.pt") 264 | labels = torch.load(cfg['cache_dir'] + "/" + split + "_l.pt") 265 | except FileNotFoundError: 266 | backbone_name = cfg['backbone'].replace('/', '_') 267 | backbone_name = backbone_name.replace('\\', '_') 268 | backbone_name = backbone_name.replace('-', '_') 269 | features = torch.load(os.path.join(cfg['cache_dir'], f'{backbone_name}_{split}_features.pt')) 270 | labels = torch.load(os.path.join(cfg['cache_dir'], f'{split}_target.pt')) 271 | 272 | return features, labels 273 | 274 | 275 | #%% 276 | 277 | datasets = { 278 | 'sun397':'SUN397', 279 | 'imagenet':'imagenet', 280 | 'fgvc_aircraft':'fgvc_aircraft', 281 | 'eurosat':'eurosat', 282 | 'food101':'Food101', 283 | 'caltech101':'Caltech101', 284 | 'oxford_pets':'OxfordPets', 285 | 'oxford_flowers':'Flower102', 286 | 'stanford_cars':'StanfordCars', 287 | 'dtd':'DTD', 288 | 'ucf101':'UCF101', 289 | } 290 | 
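# Note: the dictionary above maps CLI dataset names to the on-disk folder names used to build cache paths
# (e.g. runner.py joins datasets['dtd'] -> 'DTD' with 'cache'), while the backbones dictionary below maps
# backbone flags to the CLIP model names passed to clip.load (e.g. backbones['vit_b16'] -> 'ViT-B/16').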
291 | backbones = { 292 | 'vit_b16': 'ViT-B/16', 293 | 'rn50': 'RN50', 294 | 'vit_b32':'ViT-B/32', 295 | 'rn101': 'RN101', 296 | } 297 | 298 | #%% 299 | def get_ensemble_prompt(cname): 300 | 301 | texts = [ 302 | f'itap of a {cname}.', 303 | f'a bad photo of the {cname}.', 304 | f'a origami {cname}.', 305 | f'a photo of the large {cname}.', 306 | f'a {cname} in a video game.', 307 | f'art of the {cname}', 308 | f'a photo of the small {cname}' 309 | ] 310 | return texts 311 | 312 | #%% 313 | 314 | def load_clip_classifier(dataset_name, root_prompts_path, 315 | prompts_types = 'custom', 316 | dataset = None, prompts_n_shots = 4, seed = 1, backbone = 'vit_b16', 317 | clip_model = None, num_classes = None, model_dim = 512,): 318 | if prompts_types == 'standard': 319 | clip_prototypes = clip_classifier(dataset.classnames, dataset.template, clip_model) 320 | elif prompts_types == 'coop-fewshot': 321 | path = os.path.join(root_prompts_path, 'Few_shot', 'coop', backbone, f'{prompts_n_shots}shots', 322 | dataset_name, f'seed{seed}') 323 | clip_prototypes = torch.load(os.path.join(path, 'text_features.pt')).T 324 | elif prompts_types == 'taskres-fewshot': 325 | path = os.path.join(root_prompts_path, 'Few_shot', 'taskres', backbone, f'{prompts_n_shots}shots', 326 | dataset_name, f'seed{seed}') 327 | clip_prototypes = torch.load(os.path.join(path, 'text_features.pt')).T 328 | elif prompts_types == 'coop-16shotsimagenet': 329 | path = os.path.join(root_prompts_path, 'Cross-dataset', 'coop', backbone, f'{prompts_n_shots}shots', 330 | dataset_name, f'seed{seed}') 331 | clip_prototypes = torch.load(os.path.join(path, 'text_features.pt')).T 332 | elif prompts_types == 'cupl': 333 | # Load CuPL prompts 334 | prompt_datasetname = dataset_name.replace('_', '') 335 | if 'flowers' in dataset_name: 336 | prompt_datasetname = 'flowers102' 337 | prompt_path = os.path.join(root_prompts_path, 'cupl', f"CuPL_prompts_{prompt_datasetname}.json") 338 | with open(prompt_path, 'r') as f: 339 | cupl_prompts = json.load(f) 340 | 341 | # Encode cupl prompts 342 | tokenized_prompts = {} 343 | K = num_classes 344 | d = model_dim 345 | num_prompts = torch.zeros(K) 346 | avg_prompts = torch.zeros((K,d), dtype = torch.float16) 347 | 348 | for j,cname_ in enumerate(dataset.classnames): 349 | cname = cname_.replace('_', ' ') 350 | #print(f'class : {cname}, num_prompts : {len(cupl_prompts[cname])}') 351 | tokenized_prompts[cname] = clip.tokenize(cupl_prompts[cname]).cuda() 352 | num_prompts[j] = len(cupl_prompts[cname]) 353 | with torch.autocast("cuda"), torch.no_grad(): 354 | encoded_p_cname = clip_model.encode_text(tokenized_prompts[cname].cuda()) 355 | encoded_p_cname = encoded_p_cname/torch.linalg.norm(encoded_p_cname, dim = -1, keepdims = True) 356 | avg_prompts[j,...] 
= encoded_p_cname.mean(0) 357 | clip_prototypes = (avg_prompts/torch.linalg.norm(avg_prompts, dim = -1, keepdims = True)).T.cuda() 358 | elif prompts_types == 'custom_ensemble': 359 | # Encode prompts 360 | tokenized_prompts = {} 361 | K = num_classes 362 | d = model_dim 363 | num_prompts = torch.zeros(K) 364 | avg_prompts = torch.zeros((K,d), dtype = torch.float16) 365 | 366 | for j,cname_ in enumerate(dataset.classnames): 367 | cname = cname_.replace('_', ' ') 368 | #print(f'class : {cname}, num_prompts : {len(cupl_prompts[cname])}') 369 | txts = get_ensemble_prompt(cname) 370 | tokenized_prompts[cname] = clip.tokenize(txts).cuda() 371 | num_prompts[j] = len(txts) 372 | with torch.autocast("cuda"), torch.no_grad(): 373 | encoded_p_cname = clip_model.encode_text(tokenized_prompts[cname].cuda()) 374 | encoded_p_cname = encoded_p_cname/torch.linalg.norm(encoded_p_cname, dim = -1, keepdims = True) 375 | avg_prompts[j,...] = encoded_p_cname.mean(0) 376 | clip_prototypes = (avg_prompts/torch.linalg.norm(avg_prompts, dim = -1, keepdims = True)).T.cuda() 377 | return clip_prototypes.type(torch.float16) 378 | 379 | #%% 380 | import datasets as dts 381 | def load_features(dataset_name, 382 | root_path, 383 | cache_dir, 384 | preprocess, 385 | clip_model, 386 | backbone_name, 387 | splits = ['train', 'test'], 388 | load_loaders = True): 389 | cfg = {} 390 | print(f'============ DATASET : {dataset_name}') 391 | 392 | cfg['dataset'] = datasets[dataset_name] 393 | 394 | cfg['root_path'] = root_path 395 | cfg['shots'] = 0 396 | cfg['load_pre_feat'] = True 397 | cfg['cache_dir'] = cache_dir 398 | if dataset_name == 'imagenet': 399 | cfg['load_cache'] = False 400 | if load_loaders: 401 | train_loader, val_loader, test_loader, dataset = dts.get_all_dataloaders(cfg, preprocess, dirichlet=None) 402 | else: 403 | dataset = dts.dataset_list[dataset_name](cfg['root_path'], cfg['shots']) 404 | train_loader, val_loader, test_loader = None,None,None 405 | features_and_labels = get_samples_feature_and_labels(cache_dir, 406 | splits = splits, 407 | backbone_name = backbone_name, 408 | dataset_name = dataset_name) 409 | 410 | 411 | return train_loader, val_loader, test_loader, dataset, features_and_labels 412 | 413 | #%% 414 | def get_default_dmn_args(): 415 | class Args: 416 | def __init__(self): 417 | return None 418 | dmn_args = Args() 419 | dmn_args.indice = 0 420 | dmn_args.shared_param = None 421 | dmn_args.mapping = 'bias' 422 | dmn_args.position = 'all' 423 | dmn_args.n_shot = 0 #zero shot 424 | dmn_args.n_augments = 0 425 | dmn_args.selection_p = 0.1 426 | return dmn_args --------------------------------------------------------------------------------