├── README.md
├── configs.py
├── conft.sh
├── data
│   ├── __init__.py
│   ├── additional_transforms.py
│   ├── cssf_datamgr_custom_collate.py
│   └── datamgr.py
├── filelists
│   ├── cars
│   │   └── write_cars_filelist.py
│   ├── cub
│   │   └── write_cub_filelist.py
│   ├── miniImagenet
│   │   └── write_miniImagenet_filelist.py
│   ├── places
│   │   └── write_places_filelist.py
│   ├── plantae
│   │   └── write_plantae_filelist.py
│   └── process.py
├── finetune.py
├── methods
│   ├── __init__.py
│   ├── backbone.py
│   └── weight_imprint_based
│       ├── ConCE.py
│       └── __init__.py
├── mt_conft.sh
├── options.py
├── output
│   └── checkpoints
│       ├── download_encoder.py
│       ├── download_encoders.py
│       └── download_models.py
├── read_parallel_results.py
└── utils.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Contrastive-Finetuning
 2 | This repo is the official implementation of the following paper:
 3 | **"On the Importance of Distractors for Few-Shot Classification"** [Paper](https://arxiv.org/abs/2109.09883)
 4 | 
 5 | If you find this repo useful for your research, please consider citing the paper:
 6 | ```
 7 | @misc{das2021importance,
 8 |       title={On the Importance of Distractors for Few-Shot Classification},
 9 |       author={Rajshekhar Das and Yu-Xiong Wang and José M. F. Moura},
10 |       year={2021},
11 |       eprint={2109.09883},
12 |       archivePrefix={arXiv},
13 |       primaryClass={cs.CV}
14 | }
15 | ```
16 | # Dataset Download
17 | 
18 | To set up the datasets, follow the steps outlined [here](https://github.com/hytseng0509/CrossDomainFewShot#datasets).
19 | 
20 | # Pretrained Model
21 | 
22 | To download the pretrained backbone model, follow the steps outlined [here](https://github.com/hytseng0509/CrossDomainFewShot#feature-encoder-pre-training).
23 | 
24 | # Running
25 | 
26 | * To run contrastive finetuning on `cub` data (the default target domain) with the downloaded pretrained model, simply run ```bash conft.sh```
27 | * To run the multi-task variant on the same target domain, run ```bash mt_conft.sh```
28 | * To change the target domain or other hyperparameters, refer to `conft.sh` and `mt_conft.sh`
29 | 
30 | # Acknowledgements
31 | 
32 | Parts of the codebase, namely the dataloaders, have been adapted from [Cross-Domain Few-Shot Classification via Learned Feature-Wise Transformation](https://github.com/hytseng0509/CrossDomainFewShot#datasets).
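# Filelist Format

For reference, the writer scripts under `filelists/` (`write_cub_filelist.py` and friends) each emit `base.json`, `val.json` and `novel.json`: a single JSON object with three parallel fields, which the dataloaders in `data/` then consume. Below is a minimal sketch of the expected shape; the class and file names are hypothetical:

```json
{
  "label_names": ["001.Black_footed_Albatross", "002.Laysan_Albatross"],
  "image_names": ["source/CUB_200_2011/images/001.Black_footed_Albatross/img_0001.jpg",
                  "source/CUB_200_2011/images/002.Laysan_Albatross/img_0002.jpg"],
  "image_labels": [0, 1]
}
```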
33 | 
34 | 
--------------------------------------------------------------------------------
/configs.py:
--------------------------------------------------------------------------------
 1 | save_dir = './'
 2 | data_dir = {}
 3 | 
 4 | # LFT
 5 | data_dir['cub'] = './filelists/cub/'
 6 | data_dir['miniImagenet'] = './filelists/miniImagenet/'
 7 | data_dir['cars'] = './filelists/cars/'
 8 | data_dir['plantae'] = './filelists/plantae/'
 9 | data_dir['places'] = './filelists/places/'
--------------------------------------------------------------------------------
/conft.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | 
 3 | # Experimental Settings and Hyperparameters
 4 | MODE='ewn_lpan'
 5 | TgtSet=cub # options: [cub/cars/places/plantae]
 6 | DtracSet='./filelists/miniImagenet/base.json'
 7 | DSET=$(basename $(dirname $DtracSet)) # distractor-set tag used in the run name
 8 | DtracBsz=128
 9 | FTepoch=100
10 | TAU=0.05
11 | # -------- Run command ---------
12 | CUDA_VISIBLE_DEVICES=0 python finetune.py \
13 |     --ft_mode $MODE \
14 |     --targetset $TgtSet --is_tgt_aug \
15 |     --distractor_set $DtracSet \
16 |     --distractor_bsz $DtracBsz \
17 |     --stop_epoch $FTepoch --tau $TAU \
18 |     --name Mode-${MODE}/TgtSet-${TgtSet}_DSET-${DSET}/DtracBsz-${DtracBsz}_FTepoch-${FTepoch}_TAU-${TAU} \
19 |     --load-modelpath 'output/checkpoints/baseline/399.tar'
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | import ipdb
2 | import os
3 | 
4 | def get_datafiles(params, configs):
5 |     target_file = params.distractor_set
6 |     target_val_file = os.path.join(configs.data_dir[params.targetset], 'val.json')
7 |     target_novel_file = os.path.join(configs.data_dir[params.targetset], 'novel.json')
8 |     return target_file, target_val_file, target_novel_file
--------------------------------------------------------------------------------
/data/additional_transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | # This is the implementation from https://github.com/facebookresearch/low-shot-shrink-hallucinate.
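# Usage sketch for the ImageJitter transform defined below (an illustrative
# assumption, mirroring TransformLoader's defaults in data/datamgr.py):
#   jitter = ImageJitter(dict(Brightness=0.4, Contrast=0.4, Color=0.4))
#   out = jitter(pil_image)  # pil_image: any PIL RGB image
# For each listed property, __call__ draws a random enhancement factor
# uniformly from [1 - alpha, 1 + alpha] and applies the corresponding
# PIL.ImageEnhance transform.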
8 | 9 | import torch 10 | from PIL import ImageEnhance 11 | import ipdb 12 | 13 | transformtypedict=dict(Brightness=ImageEnhance.Brightness, Contrast=ImageEnhance.Contrast, Sharpness=ImageEnhance.Sharpness, Color=ImageEnhance.Color) 14 | 15 | class ImageJitter(object): 16 | def __init__(self, transformdict): 17 | self.transforms = [(transformtypedict[k], transformdict[k]) for k in transformdict] 18 | self.keys = [transformdict[k] for k in transformdict] 19 | 20 | def __call__(self, img): 21 | out = img 22 | randtensor = torch.rand(len(self.transforms)) 23 | 24 | for i, (transformer, alpha) in enumerate(self.transforms): 25 | r = alpha*(randtensor[i]*2.0 -1.0) + 1 26 | out = transformer(out).enhance(r).convert('RGB') 27 | 28 | return out 29 | 30 | def __repr__(self): 31 | return self.__class__.__name__ + str(self.keys) 32 | 33 | class ZeroPixel(object): 34 | def __init__(self, pcnt): 35 | self.pcnt = pcnt 36 | 37 | def __call__(self, img): 38 | rows, cols = torch.where(img[0]==1) 39 | numones = rows.shape[0] 40 | idx = torch.randperm(numones)[:int(self.pcnt*numones)] 41 | img[:,rows[idx],cols[idx]] = 0 42 | return img 43 | 44 | def __repr__(self): 45 | return self.__class__.__name__ + '(pcnt={0})'.format(str(self.pcnt)) 46 | -------------------------------------------------------------------------------- /data/cssf_datamgr_custom_collate.py: -------------------------------------------------------------------------------- 1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate 2 | import json 3 | import numpy as np 4 | import os 5 | from PIL import Image 6 | import ipdb 7 | import matplotlib.pyplot as plt 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | import torchvision.transforms as transforms 12 | import data.additional_transforms as add_transforms 13 | 14 | from abc import abstractmethod 15 | from data.datamgr import TransformLoader 16 | 17 | 18 | # novel 19 | nWorker_setDM = 0 20 | nWorker_labCB = 0 21 | # source 22 | nWorker_simpleDM = 0 23 | nWorker_unlabCB = 0 24 | identity = lambda x:x 25 | 26 | class DataManager: 27 | @abstractmethod 28 | def get_data_loader(self, data_file, aug): 29 | pass 30 | 31 | def list_collate(batch): 32 | xlist_batch = [] 33 | y_batch = torch.Tensor([]) 34 | for sample in batch: 35 | xlist_batch.append([sample[0]]) 36 | y = sample[1] 37 | if isinstance(y, int): 38 | y = torch.Tensor([y]) 39 | else: 40 | y = y.unsqueeze(0) 41 | y_batch = torch.cat([y_batch, y], dim=0) 42 | return xlist_batch, y_batch 43 | 44 | def list_set_collate(batch): 45 | x_batch = torch.Tensor([]) 46 | xlist_batch = [] 47 | y_batch = torch.Tensor([]) 48 | for sample in batch: 49 | x_batch = torch.cat([x_batch, sample[0].unsqueeze(0)], dim=0) 50 | xlist_batch.append(sample[1]) 51 | y = sample[2] 52 | if isinstance(y, int): 53 | y = torch.Tensor([y]) 54 | else: 55 | y = y.unsqueeze(0) 56 | y_batch = torch.cat([y_batch, y], dim=0) 57 | return x_batch, xlist_batch, y_batch 58 | 59 | ############################################################################## 60 | class SimpleDataManager(DataManager): 61 | def __init__(self, batch_size): 62 | super(SimpleDataManager, self).__init__() 63 | self.batch_size = batch_size 64 | 65 | def get_data_loader(self, data_file, shuffle=False): 66 | dataset = SimpleDataset(data_file) 67 | data_loader_params = dict(batch_size = self.batch_size, shuffle = shuffle, 68 | num_workers = nWorker_simpleDM, pin_memory = True, 69 | drop_last=True, collate_fn=list_collate) 70 | data_loader = 
torch.utils.data.DataLoader(dataset, **data_loader_params) 71 | return data_loader 72 | 73 | class SimpleDataset: 74 | def __init__(self, data_file): 75 | with open(data_file, 'r') as f: 76 | self.meta = json.load(f) 77 | 78 | def __getitem__(self,i): 79 | image_path = os.path.join(self.meta['image_names'][i]) 80 | img = Image.open(image_path).convert('RGB') 81 | target = self.meta['image_labels'][i] 82 | return img, target 83 | 84 | def __len__(self): 85 | return len(self.meta['image_names']) 86 | ############################################################################## 87 | 88 | class SetDataManager(DataManager): 89 | def __init__(self, image_size, num_aug, n_way, n_support, n_query, n_episode=100, no_color=False): 90 | super(SetDataManager, self).__init__() 91 | self.image_size = image_size 92 | self.n_way = n_way 93 | self.n_support = n_support 94 | self.batch_size = n_support + n_query 95 | self.n_episode = n_episode 96 | self.num_aug = num_aug 97 | self.trans_loader = TransformLoader(image_size) 98 | 99 | def get_data_loader(self, data_file): #parameters that would change on train/val set 100 | transform_test = self.trans_loader.get_composed_transform( aug=False) 101 | dataset = SetDataset( data_file , self.batch_size, self.num_aug, transform_test) 102 | sampler = EpisodicBatchSampler(len(dataset), self.n_way, self.n_episode ) 103 | data_loader_params = dict(batch_sampler = sampler, num_workers = nWorker_setDM, collate_fn=list_set_collate) 104 | data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params) 105 | return data_loader 106 | 107 | class SetDataset: 108 | def __init__(self, data_file, batch_size, num_aug, transform_test): 109 | self.data_file = data_file 110 | with open(data_file, 'r') as f: 111 | self.meta = json.load(f) 112 | 113 | self.cl_list = np.unique(self.meta['image_labels']).tolist() 114 | 115 | self.sub_meta = {} 116 | for cl in self.cl_list: 117 | self.sub_meta[cl] = [] 118 | 119 | for x,y in zip(self.meta['image_names'],self.meta['image_labels']): 120 | self.sub_meta[y].append(x) 121 | 122 | self.sub_dataloader = [] 123 | sub_data_loader_params = dict(batch_size = batch_size, 124 | shuffle = True, 125 | num_workers = 0, #use main thread only or may receive multiple batches 126 | pin_memory = False, collate_fn=list_set_collate) 127 | for cl in self.cl_list: 128 | sub_dataset = SubDataset(self.sub_meta[cl], cl, num_aug=num_aug, 129 | transform_test =transform_test ) 130 | self.sub_dataloader.append( torch.utils.data.DataLoader(sub_dataset, **sub_data_loader_params) ) 131 | 132 | def __getitem__(self,i): 133 | clswise_list = next(iter(self.sub_dataloader[i])) 134 | return clswise_list 135 | 136 | def __len__(self): 137 | return len(self.cl_list) 138 | 139 | class SubDataset: 140 | def __init__(self, sub_meta, cl, num_aug=100, transform_test=transforms.ToTensor(), 141 | target_transform=identity, min_size=50): 142 | self.sub_meta = sub_meta 143 | self.cl = cl 144 | # self.transform = ContrastiveWrapper(transform, num_aug) 145 | self.transform_test = transform_test 146 | self.target_transform = target_transform 147 | if len(self.sub_meta) < min_size: 148 | idxs = [i % len(self.sub_meta) for i in range(min_size)] 149 | self.sub_meta = np.array(self.sub_meta)[idxs].tolist() 150 | 151 | def __getitem__(self,i): 152 | image_path = self.sub_meta[i] 153 | img_raw = Image.open(image_path).convert('RGB') 154 | img = self.transform_test(img_raw) 155 | target = self.target_transform(self.cl) 156 | return img, img_raw, target 157 | 158 | 159 | def 
__len__(self): 160 | return len(self.sub_meta) 161 | 162 | class EpisodicBatchSampler(object): 163 | def __init__(self, n_classes, n_way, n_episodes): 164 | self.n_classes = n_classes 165 | self.n_way = n_way 166 | self.n_episodes = n_episodes 167 | 168 | def __len__(self): 169 | return self.n_episodes 170 | 171 | def __iter__(self): 172 | for i in range(self.n_episodes): 173 | yield torch.randperm(self.n_classes)[:self.n_way] 174 | 175 | ############################################################################## 176 | class Augmentator: 177 | def __init__(self, x, transform, y=None): 178 | self.x = [] 179 | for xi in x: 180 | self.x.extend(xi) 181 | self.transform = transform 182 | if y is not None: 183 | self.y = y.contiguous().view(-1) 184 | else: 185 | self.y = ['None']*len(self.x) 186 | 187 | def __len__(self): 188 | return len(self.x) 189 | 190 | def __getitem__(self, i): 191 | # img_raw = Image.open(self.x[i]).convert('RGB') 192 | # img = self.transform(img_raw) 193 | img = self.transform(self.x[i]) 194 | target = self.y[i] 195 | return img, target 196 | 197 | class UnsupCon_Augmentator (Augmentator): 198 | # for 1 shot experiments 199 | def __init__(self, x, transform): 200 | super(UnsupCon_Augmentator, self).__init__(x, transform) 201 | 202 | def __getitem__(self, i): 203 | img1 = self.transform(self.x[i]) 204 | img2 = self.transform(self.x[i]) 205 | img = torch.cat([img1.unsqueeze(0), img2.unsqueeze(0)], dim=0) 206 | target = self.y[i] 207 | return img, target 208 | 209 | class ContrastiveBatchifier: 210 | def __init__(self, n_way, n_support, image_size, augstrength='0'): 211 | self.n_way=n_way 212 | self.n_support=n_support 213 | jitter_param = dict(Brightness=0.4, Contrast=0.4, Color=0.4) 214 | self.transform = TransformLoader(image_size, jitter_param=jitter_param).get_composed_transform(aug=True) 215 | 216 | def get_loader(self, x): 217 | if self.n_support>1: 218 | dataset = Augmentator(x, self.transform) 219 | else: 220 | dataset = UnsupCon_Augmentator(x, self.transform) 221 | data_loader_params = dict( 222 | dataset=dataset, 223 | batch_size = len(dataset), 224 | shuffle = False,# batchify fn below expects this to be False 225 | num_workers = nWorker_labCB, 226 | pin_memory = True) 227 | data_loader = torch.utils.data.DataLoader(**data_loader_params) 228 | return data_loader 229 | 230 | def _batchify(self, x, n_way, n_support): 231 | # converts into contrastive batch 232 | x = x.contiguous().view(n_way, n_support,*x.shape[1:]) 233 | permuted_idx = torch.cat([torch.randperm(n_support).unsqueeze(0) for _ in range(n_way)], dim=0) 234 | shots_per_way = n_support if n_support % 2 == 0 else n_support - 1 235 | permuted_idx = permuted_idx[:, :shots_per_way] 236 | permuted_idx = permuted_idx.view(n_way, shots_per_way, 1, 1, 1).expand( 237 | n_way, shots_per_way, *x.shape[2:]).to(x.device) 238 | x = torch.gather(x, 1, permuted_idx) 239 | x = x.view(-1, *x.shape[2:]) 240 | bch = torch.split(x, 2) 241 | bch = torch.cat([b.unsqueeze(0) for b in bch], dim=0) 242 | return bch, shots_per_way 243 | 244 | def batchify(self, x): 245 | if self.n_support>1: 246 | return self._batchify(x, self.n_way, self.n_support) 247 | else: 248 | return x, 2 # gget_loader takes care of the form of the input 249 | 250 | def hpm_batchify(self, x): 251 | #hard positive mining 252 | featdim = x.shape[1] 253 | x = x.view(self.n_way, self.n_support, featdim) 254 | # x: [5,5,512] 255 | leftx = x.unsqueeze(2).expand(self.n_way, self.n_support, self.n_support, featdim) 256 | rightx = 
x.unsqueeze(1).expand(self.n_way, self.n_support, self.n_support, featdim)
257 |         alignment = F.cosine_similarity(leftx, rightx, dim=3) # [n_way, n_support, n_support]
258 |         farpos_idx = alignment.argmin(dim=2) # [n_way, n_support]
259 |         farpos_idx = farpos_idx.unsqueeze(2).unsqueeze(3).expand(self.n_way, self.n_support, 1, featdim)
260 |         farpos_sample = torch.gather(rightx, 2, farpos_idx)
261 |         bch = torch.cat([x.view(-1, featdim).unsqueeze(1), farpos_sample.view(-1, featdim).unsqueeze(1)], dim=1)
262 |         shots_per_way = self.n_support*2
263 |         return bch, shots_per_way
264 | 
265 | if __name__=="__main__":
266 |     # smoke test: batchify a dummy 5-way 5-shot episode of 224x224 images
267 |     batchifier = ContrastiveBatchifier(n_way=5, n_support=5, image_size=224)
268 |     x = torch.rand(5*5, 3, 224, 224)
269 |     bch, shots_per_way = batchifier._batchify(x, n_way=5, n_support=5)
--------------------------------------------------------------------------------
/data/datamgr.py:
--------------------------------------------------------------------------------
1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate
2 | 
3 | import json
4 | import numpy as np
5 | import os
6 | import random
7 | from PIL import Image
8 | import ipdb
9 | 
10 | import torch
11 | import torchvision.transforms as transforms
12 | import data.additional_transforms as add_transforms
13 | from abc import abstractmethod
14 | 
15 | NUM_WORKERS=2
16 | class TransformLoader:
17 |     def __init__(self, image_size,
18 |                  normalize_param = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
19 |                  jitter_param = dict(Brightness=0.4, Contrast=0.4, Color=0.4), rot_deg = 30):
20 |         self.image_size = image_size
21 |         self.normalize_param = normalize_param
22 |         self.jitter_param = jitter_param
23 |         self.rot_deg = rot_deg
24 | 
25 |     def parse_transform(self, transform_type):
26 |         if transform_type=='ImageJitter':
27 |             method = add_transforms.ImageJitter( self.jitter_param )
28 |             return method
29 |         method = getattr(transforms, transform_type)
30 | 
31 |         if transform_type=='Grayscale':
32 |             return method(3)
33 |         elif transform_type=='RandomResizedCrop':
34 |             # return method(self.image_size)
35 |             return method(self.image_size, scale=(0.2, 1.0))
36 |         elif transform_type=='CenterCrop':
37 |             return method(self.image_size)
38 |         elif transform_type=='Resize':
39 |             return method([int(self.image_size*1.15), int(self.image_size*1.15)])
40 |         elif transform_type=='RandomRotation':
41 |             return method(self.rot_deg)
42 |         elif transform_type=='Normalize':
43 |             return method(**self.normalize_param )
44 |         else:
45 |             return method()
46 | 
47 |     def get_composed_transform(self, aug = False):
48 |         if aug:
49 |             transform_list = ['RandomResizedCrop', 'ImageJitter', 'RandomHorizontalFlip', 'ToTensor', 'Normalize']
50 |         else:
51 |             transform_list = ['Resize','CenterCrop', 'ToTensor', 'Normalize']
52 | 
53 |         transform_funcs = []
54 |         for x in transform_list:
55 |             transform_funcs.append(self.parse_transform(x))
56 |         transform = transforms.Compose(transform_funcs)
57 |         return transform
58 | 
59 | #class DataManager:
60 | class DataManager(object):
61 |     @abstractmethod
62 |     def get_data_loader(self, data_file, aug):
63 |         pass
64 | 
65 | class SimpleDataManager(DataManager):
66 |     def __init__(self, image_size, batch_size, drop_last=False, is_shuffle=True):
67 |         super(SimpleDataManager, self).__init__()
68 |         self.batch_size = batch_size
69 |         self.trans_loader = TransformLoader(image_size)
70 |         self.drop_last = drop_last
71 |         self.is_shuffle = is_shuffle
72 | 
73 |     def get_data_loader(self, data_file, aug):
74 |         transform = self.trans_loader.get_composed_transform(aug)
75 |         dataset = SimpleDataset(data_file, transform)
76 |         data_loader_params = dict(batch_size = 
self.batch_size, shuffle = self.is_shuffle, 77 | num_workers = NUM_WORKERS, pin_memory =True, drop_last=self.drop_last) 78 | data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params) 79 | return data_loader 80 | 81 | ############# 82 | 83 | identity = lambda x:x 84 | 85 | support_label = 1 86 | query_label = 0 87 | class SimpleDataset: 88 | def __init__(self, data_file, transform, target_transform=identity): 89 | with open(data_file, 'r') as f: 90 | self.meta = json.load(f) 91 | 92 | classnames_filename = os.path.join(os.path.dirname(data_file), 'classnames.txt') 93 | if os.path.exists(classnames_filename): 94 | self.clsid2name = {} 95 | with open(classnames_filename) as f: 96 | lines = f.readlines() 97 | for line in lines: 98 | line = line.split('\n')[0] 99 | if '#' not in line and line!='': 100 | try: 101 | clsid, clsname = line.split(' ') 102 | except: 103 | ipdb.set_trace() 104 | self.clsid2name[clsid] = clsname 105 | 106 | self.transform = transform 107 | self.target_transform = target_transform 108 | 109 | def __getitem__(self,i): 110 | image_path = os.path.join(self.meta['image_names'][i]) 111 | img = Image.open(image_path).convert('RGB') 112 | img = self.transform(img) 113 | target = self.target_transform(self.meta['image_labels'][i]) 114 | return img, target 115 | 116 | def __len__(self): 117 | return len(self.meta['image_names']) 118 | 119 | def get_classname(self, filename): 120 | clsid = os.path.dirname(filename).split('/')[-1] 121 | return self.clsid2name[clsid] -------------------------------------------------------------------------------- /filelists/cars/write_cars_filelist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | #from os import listdir 3 | from os.path import join#isfile, isdir, join 4 | import os 5 | #import json 6 | import random 7 | from scipy.io import loadmat 8 | 9 | cwd = os.getcwd() 10 | data_path = join(cwd,'source/cars_train') 11 | savedir = './' 12 | dataset_list = ['base','val','novel'] 13 | 14 | data_list = np.array(loadmat('source/devkit/cars_train_annos.mat')['annotations'][0]) 15 | class_list = np.array(loadmat('source/devkit/cars_meta.mat')['class_names'][0]) 16 | classfile_list_all = [[] for i in range(len(class_list))] 17 | 18 | for i in range(len(data_list)): 19 | folder_path = join(data_path, data_list[i][-1][0]) 20 | classfile_list_all[data_list[i][-2][0][0] - 1].append(folder_path) 21 | 22 | for i in range(len(classfile_list_all)): 23 | random.shuffle(classfile_list_all[i]) 24 | 25 | '''folder_list = [f for f in listdir(data_path) if isdir(join(data_path, f))] 26 | folder_list.sort() 27 | label_dict = dict(zip(folder_list,range(0,len(folder_list)))) 28 | 29 | classfile_list_all = [] 30 | 31 | for i, folder in enumerate(folder_list): 32 | folder_path = join(data_path, folder) 33 | classfile_list_all.append( [ join(folder_path, cf) for cf in listdir(folder_path) if (isfile(join(folder_path,cf)) and cf[0] != '.')]) 34 | random.shuffle(classfile_list_all[i])''' 35 | 36 | for dataset in dataset_list: 37 | file_list = [] 38 | label_list = [] 39 | for i, classfile_list in enumerate(classfile_list_all): 40 | if 'base' in dataset: 41 | if (i%2 == 0): 42 | file_list = file_list + classfile_list 43 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist() 44 | if 'val' in dataset: 45 | if (i%4 == 1): 46 | file_list = file_list + classfile_list 47 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist() 48 | if 'novel' in dataset: 49 | if (i%4 == 3): 
50 | file_list = file_list + classfile_list 51 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist() 52 | 53 | fo = open(savedir + dataset + ".json", "w") 54 | fo.write('{"label_names": [') 55 | fo.writelines(['"%s",' % item[0] for item in class_list]) 56 | fo.seek(0, os.SEEK_END) 57 | fo.seek(fo.tell()-1, os.SEEK_SET) 58 | fo.write('],') 59 | 60 | fo.write('"image_names": [') 61 | fo.writelines(['"%s",' % item for item in file_list]) 62 | fo.seek(0, os.SEEK_END) 63 | fo.seek(fo.tell()-1, os.SEEK_SET) 64 | fo.write('],') 65 | 66 | fo.write('"image_labels": [') 67 | fo.writelines(['%d,' % item for item in label_list]) 68 | fo.seek(0, os.SEEK_END) 69 | fo.seek(fo.tell()-1, os.SEEK_SET) 70 | fo.write(']}') 71 | 72 | fo.close() 73 | print("%s -OK" %dataset) 74 | -------------------------------------------------------------------------------- /filelists/cub/write_cub_filelist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from os import listdir 3 | from os.path import isfile, isdir, join 4 | import os 5 | import random 6 | 7 | cwd = os.getcwd() 8 | data_path = join(cwd,'source/CUB_200_2011/images') 9 | savedir = './' 10 | dataset_list = ['base','val','novel'] 11 | 12 | #if not os.path.exists(savedir): 13 | # os.makedirs(savedir) 14 | 15 | folder_list = [f for f in listdir(data_path) if isdir(join(data_path, f))] 16 | folder_list.sort() 17 | label_dict = dict(zip(folder_list,range(0,len(folder_list)))) 18 | 19 | classfile_list_all = [] 20 | 21 | for i, folder in enumerate(folder_list): 22 | folder_path = join(data_path, folder) 23 | classfile_list_all.append( [ join(folder_path, cf) for cf in listdir(folder_path) if (isfile(join(folder_path,cf)) and cf[0] != '.')]) 24 | random.shuffle(classfile_list_all[i]) 25 | 26 | 27 | for dataset in dataset_list: 28 | file_list = [] 29 | label_list = [] 30 | for i, classfile_list in enumerate(classfile_list_all): 31 | if 'base' in dataset: 32 | if (i%2 == 0): 33 | file_list = file_list + classfile_list 34 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist() 35 | if 'val' in dataset: 36 | if (i%4 == 1): 37 | file_list = file_list + classfile_list 38 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist() 39 | if 'novel' in dataset: 40 | if (i%4 == 3): 41 | file_list = file_list + classfile_list 42 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist() 43 | 44 | fo = open(savedir + dataset + ".json", "w") 45 | fo.write('{"label_names": [') 46 | fo.writelines(['"%s",' % item for item in folder_list]) 47 | fo.seek(0, os.SEEK_END) 48 | fo.seek(fo.tell()-1, os.SEEK_SET) 49 | fo.write('],') 50 | 51 | fo.write('"image_names": [') 52 | fo.writelines(['"%s",' % item for item in file_list]) 53 | fo.seek(0, os.SEEK_END) 54 | fo.seek(fo.tell()-1, os.SEEK_SET) 55 | fo.write('],') 56 | 57 | fo.write('"image_labels": [') 58 | fo.writelines(['%d,' % item for item in label_list]) 59 | fo.seek(0, os.SEEK_END) 60 | fo.seek(fo.tell()-1, os.SEEK_SET) 61 | fo.write(']}') 62 | 63 | fo.close() 64 | print("%s -OK" %dataset) 65 | -------------------------------------------------------------------------------- /filelists/miniImagenet/write_miniImagenet_filelist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from os import listdir 3 | from os.path import isfile, isdir, join 4 | import os 5 | import random 6 | 7 | cwd = os.getcwd() 8 | root_path = join(cwd,'source/mini_imagenet_full_size') 9 | 
savedir = './' 10 | dataset_list = ['base','val','novel'] 11 | 12 | for dataset in dataset_list: 13 | if dataset == 'base': 14 | data_path = join(root_path , 'train') 15 | elif dataset == 'val': 16 | data_path = join(root_path , 'val') 17 | elif dataset == 'novel': 18 | data_path = join(root_path , 'test') 19 | else: 20 | raise Exception('no such dataset') 21 | 22 | folder_list = [f for f in listdir(data_path) if isdir(join(data_path, f))] 23 | print('{} dataset contains {} categories'.format(dataset, len(folder_list))) 24 | folder_list.sort() 25 | label_dict = dict(zip(folder_list,range(0,len(folder_list)))) 26 | 27 | classfile_list_all = [] 28 | 29 | for i, folder in enumerate(folder_list): 30 | folder_path = join(data_path, folder) 31 | classfile_list_all.append( [ join(folder_path, cf) for cf in listdir(folder_path) if (isfile(join(folder_path,cf)) and cf[0] != '.')]) 32 | random.shuffle(classfile_list_all[i]) 33 | assert(len(classfile_list_all[i]) == 600) 34 | 35 | file_list = [] 36 | label_list = [] 37 | for i, classfile_list in enumerate(classfile_list_all): 38 | file_list = file_list + classfile_list 39 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist() 40 | 41 | fo = open(savedir + dataset + ".json", "w") 42 | fo.write('{"label_names": [') 43 | fo.writelines(['"%s",' % item for item in folder_list]) 44 | fo.seek(0, os.SEEK_END) 45 | fo.seek(fo.tell()-1, os.SEEK_SET) 46 | fo.write('],') 47 | 48 | fo.write('"image_names": [') 49 | fo.writelines(['"%s",' % item for item in file_list]) 50 | fo.seek(0, os.SEEK_END) 51 | fo.seek(fo.tell()-1, os.SEEK_SET) 52 | fo.write('],') 53 | 54 | fo.write('"image_labels": [') 55 | fo.writelines(['%d,' % item for item in label_list]) 56 | fo.seek(0, os.SEEK_END) 57 | fo.seek(fo.tell()-1, os.SEEK_SET) 58 | fo.write(']}') 59 | 60 | fo.close() 61 | print("%s -OK" %dataset) 62 | -------------------------------------------------------------------------------- /filelists/places/write_places_filelist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from os import listdir 3 | from os.path import isfile, isdir, join 4 | import os 5 | import random 6 | 7 | cwd = os.getcwd() 8 | data_path = join(cwd,'source/places365_standard/train') 9 | savedir = './' 10 | dataset_list = ['base','val','novel'] 11 | 12 | 13 | folder_list = [f for f in listdir(data_path) if isdir(join(data_path, f))] 14 | folder_list.sort() 15 | label_dict = dict(zip(folder_list,range(0,len(folder_list)))) 16 | 17 | classfile_list_all = [] 18 | 19 | for i, folder in enumerate(folder_list): 20 | folder_path = join(data_path, folder) 21 | cfs = [cf for cf in listdir(folder_path) if (isfile(join(folder_path,cf)) and cf[0] != '.')] 22 | cfs.sort() 23 | cfs = cfs[:200] 24 | classfile_list_all.append([ join(folder_path, cf) for cf in cfs]) 25 | random.shuffle(classfile_list_all[i]) 26 | print(len(classfile_list_all)) 27 | 28 | for dataset in dataset_list: 29 | file_list = [] 30 | label_list = [] 31 | for i, classfile_list in enumerate(classfile_list_all): 32 | if 'base' in dataset: 33 | if (i%2 == 0): 34 | file_list = file_list + classfile_list 35 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist() 36 | if 'val' in dataset: 37 | if (i%4 == 1): 38 | file_list = file_list + classfile_list 39 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist() 40 | if 'novel' in dataset: 41 | if (i%4 == 3): 42 | file_list = file_list + classfile_list 43 | label_list = label_list + np.repeat(i, 
len(classfile_list)).tolist() 44 | 45 | fo = open(savedir + dataset + ".json", "w") 46 | fo.write('{"label_names": [') 47 | fo.writelines(['"%s",' % item for item in folder_list]) 48 | fo.seek(0, os.SEEK_END) 49 | fo.seek(fo.tell()-1, os.SEEK_SET) 50 | fo.write('],') 51 | 52 | fo.write('"image_names": [') 53 | fo.writelines(['"%s",' % item for item in file_list]) 54 | fo.seek(0, os.SEEK_END) 55 | fo.seek(fo.tell()-1, os.SEEK_SET) 56 | fo.write('],') 57 | 58 | fo.write('"image_labels": [') 59 | fo.writelines(['%d,' % item for item in label_list]) 60 | fo.seek(0, os.SEEK_END) 61 | fo.seek(fo.tell()-1, os.SEEK_SET) 62 | fo.write(']}') 63 | 64 | fo.close() 65 | print("%s -OK" %dataset) 66 | -------------------------------------------------------------------------------- /filelists/plantae/write_plantae_filelist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from os import listdir 3 | from os.path import isfile, isdir, join 4 | import os 5 | #import json 6 | import random 7 | from subprocess import call 8 | 9 | cwd = os.getcwd() 10 | source_path = join(cwd,'source/Plantae') 11 | data_path = join(cwd,'images') 12 | if not os.path.exists(data_path): 13 | os.makedirs(data_path) 14 | savedir = './' 15 | dataset_list = ['base','val','novel'] 16 | 17 | 18 | folder_list = [f for f in listdir(source_path) if isdir(join(source_path, f))] 19 | #folder_list.sort() 20 | folder_list_count = np.array([len(listdir(join(source_path, f))) for f in folder_list]) 21 | folder_list_idx = np.argsort(folder_list_count) 22 | folder_list = np.array(folder_list)[folder_list_idx[-200:]].tolist() 23 | label_dict = dict(zip(folder_list,range(0,len(folder_list)))) 24 | 25 | classfile_list_all = [] 26 | 27 | for i, folder in enumerate(folder_list): 28 | source_folder_path = join(source_path, folder) 29 | folder_path = join(data_path, folder) 30 | classfile_list_all.append( [ cf for cf in listdir(source_folder_path) if (isfile(join(source_folder_path,cf)) and cf[0] != '.')]) 31 | random.shuffle(classfile_list_all[i]) 32 | classfile_list_all[i] = classfile_list_all[i][:min(len(classfile_list_all[i]), 600)] 33 | 34 | call('mkdir ' + folder_path, shell=True) 35 | for cf in classfile_list_all[i]: 36 | call('cp ' + join(source_folder_path, cf) + ' ' + join(folder_path, cf), shell=True) 37 | classfile_list_all[i] = [join(folder_path, cf) for cf in classfile_list_all[i]] 38 | 39 | for dataset in dataset_list: 40 | file_list = [] 41 | label_list = [] 42 | for i, classfile_list in enumerate(classfile_list_all): 43 | if 'base' in dataset: 44 | if (i%2 == 0): 45 | file_list = file_list + classfile_list 46 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist() 47 | if 'val' in dataset: 48 | if (i%4 == 1): 49 | file_list = file_list + classfile_list 50 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist() 51 | if 'novel' in dataset: 52 | if (i%4 == 3): 53 | file_list = file_list + classfile_list 54 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist() 55 | 56 | fo = open(savedir + dataset + ".json", "w") 57 | fo.write('{"label_names": [') 58 | fo.writelines(['"%s",' % item for item in folder_list]) 59 | fo.seek(0, os.SEEK_END) 60 | fo.seek(fo.tell()-1, os.SEEK_SET) 61 | fo.write('],') 62 | 63 | fo.write('"image_names": [') 64 | fo.writelines(['"%s",' % item for item in file_list]) 65 | fo.seek(0, os.SEEK_END) 66 | fo.seek(fo.tell()-1, os.SEEK_SET) 67 | fo.write('],') 68 | 69 | fo.write('"image_labels": [') 70 | 
fo.writelines(['%d,' % item for item in label_list]) 71 | fo.seek(0, os.SEEK_END) 72 | fo.seek(fo.tell()-1, os.SEEK_SET) 73 | fo.write(']}') 74 | 75 | fo.close() 76 | print("%s -OK" %dataset) 77 | -------------------------------------------------------------------------------- /filelists/process.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from subprocess import call 4 | 5 | if len(sys.argv) != 2: 6 | raise Exception('Incorrect command! e.g., python3 process.py DATASET [cars, cub, places, miniImagenet, plantae]') 7 | dataset = sys.argv[1] 8 | 9 | print('--- process ' + dataset + ' dataset ---') 10 | if not os.path.exists(os.path.join(dataset, 'source')): 11 | os.makedirs(os.path.join(dataset, 'source')) 12 | os.chdir(os.path.join(dataset, 'source')) 13 | 14 | # download files 15 | if dataset == 'cars': 16 | call('wget http://imagenet.stanford.edu/internal/car196/cars_train.tgz', shell=True) 17 | call('tar -zxf cars_train.tgz', shell=True) 18 | call('wget https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz', shell=True) 19 | call('tar -zxf car_devkit.tgz', shell=True) 20 | elif dataset == 'cub': 21 | call('wget http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz', shell=True) 22 | call('tar -zxf CUB_200_2011.tgz', shell=True) 23 | elif dataset == 'places': 24 | call('wget http://data.csail.mit.edu/places/places365/places365standard_easyformat.tar', shell=True) 25 | call('tar -xf places365standard_easyformat.tar', shell=True) 26 | elif dataset == 'miniImagenet': 27 | # this file is from MAML++: https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch 28 | call('wget http://vllab.ucmerced.edu/ym41608/projects/CrossDomainFewShot/filelists/mini_imagenet_full_size.tar.bz2', shell=True) 29 | call('tar -xjf mini_imagenet_full_size.tar.bz2', shell=True) 30 | elif dataset == 'plantae': 31 | call('wget http://vllab.ucmerced.edu/ym41608/projects/CrossDomainFewShot/filelists/plantae.tar.gz', shell=True) 32 | call('tar -xzf plantae.tar.gz', shell=True) 33 | else: 34 | raise Exception('No such dataset!') 35 | 36 | # process file 37 | os.chdir('..') 38 | call('python3 write_' + dataset + '_filelist.py', shell=True) 39 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 3 | ''' 4 | import numpy as np 5 | import time 6 | import os 7 | import glob 8 | from itertools import combinations 9 | from tqdm import tqdm 10 | import matplotlib.pyplot as plt 11 | from shutil import copy 12 | import pickle 13 | import ipdb 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.optim 18 | import torch.nn.functional as F 19 | import torch.optim.lr_scheduler as lr_scheduler 20 | from torch.autograd import Variable 21 | 22 | from methods.backbone import model_dict 23 | from methods.weight_imprint_based import LinearEvaluator 24 | from methods.weight_imprint_based.ConCE import ConCeModel 25 | from data import get_datafiles 26 | from data.datamgr import SimpleDataManager 27 | from data.cssf_datamgr_custom_collate import SetDataManager, ContrastiveBatchifier 28 | 29 | import configs 30 | from options import parse_args, get_best_file, get_assigned_file 31 | from utils import * 32 | from read_parallel_results import consolidate_results 33 | 34 | EPS = 0 35 | eps = 0.0000001 36 | 37 | def record_each(q_type, quantity, results_path): 38 | for key in quantity.keys(): 39 | # for each 
epoch 40 | if len(quantity[key]) != 0: 41 | quantity_this = quantity[key] 42 | quantity_this = np.asarray(quantity_this) 43 | with open(results_path.replace('results.txt', '%s%s.pkl' % (q_type, key)), 'wb') as f: 44 | pickle.dump(quantity_this, f) 45 | 46 | 47 | def record(acc_all_le, acc_all, accDiff_all, cluster_support, cluster_query, results_path): 48 | acc_all_le = np.asarray(acc_all_le) 49 | 50 | with open(results_path.replace('results.txt', 'linev.pkl'), 'wb') as f: 51 | pickle.dump(acc_all_le, f) 52 | 53 | # accuracy 54 | record_each('wi_final', acc_all, results_path) 55 | record_each('wi_delta', accDiff_all, results_path) 56 | # support clusters 57 | record_each('wi_support_cspread', cluster_support['cspread'], results_path) 58 | record_each('wi_support_cspread_pcnt', cluster_support['cspread_pcnt'], results_path) 59 | record_each('wi_support_csep', cluster_support['csep'], results_path) 60 | record_each('wi_support_csep_pcnt', cluster_support['csep_pcnt'], results_path) 61 | # query clusters 62 | record_each('wi_query_cspread', cluster_query['cspread'], results_path) 63 | record_each('wi_query_cspread_pcnt', cluster_query['cspread_pcnt'], results_path) 64 | record_each('wi_query_csep', cluster_query['csep'], results_path) 65 | record_each('wi_query_csep_pcnt', cluster_query['csep_pcnt'], results_path) 66 | 67 | 68 | def epoch_wise_collection(model, n_way, n_support, n_query, x, epoch, 69 | acc_before, 70 | cspread_support_before, csep_support_before, cspread_query_before, csep_query_before, 71 | acc_all, accDiff_all, cluster_support, cluster_query): 72 | # accuracies 73 | acc_after, _, z_all_analysis = model.validate(n_way, n_support, n_query, x, epoch) 74 | acc_all[str(epoch + 1)].append(acc_after) 75 | improvement = acc_after - acc_before 76 | accDiff_all[str(epoch + 1)].append(improvement) 77 | 78 | # cluster distances 79 | cspread_support_after, csep_support_after, cspread_query_after, csep_query_after = \ 80 | model.get_episode_distances(n_way, n_support, n_query, z_all_analysis) 81 | # support spread 82 | delta = cspread_support_after - cspread_support_before 83 | cluster_support['cspread'][str(epoch + 1)].append(delta) 84 | cluster_support['cspread_pcnt'][str(epoch + 1)].append(delta / (cspread_support_before + eps)) 85 | # support separation 86 | delta = csep_support_after - csep_support_before 87 | cluster_support['csep'][str(epoch + 1)].append(delta) 88 | cluster_support['csep_pcnt'][str(epoch + 1)].append(delta / (csep_support_before + eps)) 89 | # query spread 90 | delta = cspread_query_after - cspread_query_before 91 | cluster_query['cspread'][str(epoch + 1)].append(delta) 92 | cluster_query['cspread_pcnt'][str(epoch + 1)].append(delta / (cspread_query_before + eps)) 93 | # query separation 94 | delta = csep_query_after - csep_query_before 95 | cluster_query['csep'][str(epoch + 1)].append(delta) 96 | cluster_query['csep_pcnt'][str(epoch + 1)].append(delta / (csep_query_before + eps)) 97 | # return acc_after, improvement, z_all_analysis, cspread_support_before, csep_support_before, cspread_query_before, \ 98 | # csep_query_before, acc_all, accDiff_all, cluster_support, cluster_query 99 | return acc_after, improvement, z_all_analysis, cspread_support_after, csep_support_after, cspread_query_after, \ 100 | csep_query_after, acc_all, accDiff_all, cluster_support, cluster_query 101 | 102 | 103 | def get_init_var(): 104 | acc_all_le = [] 105 | acc_all = {'0': [], '10': [], '25': [], '50': [], '75': [], '100': [], '200': [], '300': [], '400': [], '500': [], 106 | '600': [], 
'700': [], '800': [], '900': [], '1000': []} 107 | accDiff_all = {'0': [], '10': [], '25': [], '50': [], '75': [], '100': [], '200': [], '300': [], '400': [], '500': [], 108 | '600': [], '700': [], '800': [], '900': [], '1000': []} 109 | cluster_support = {} 110 | cluster_support['cspread'] = {'0': [], '10': [], '25': [], '50': [], '75': [], '100': [], '200': [], '300': [], 111 | '400': [], '500': [], '600': [], '700': [], '800': [], '900': [], '1000': []} 112 | cluster_support['cspread_pcnt'] = {'0': [], '10': [], '25': [], '50': [], '75': [], '100': [], '200': [], '300': [], 113 | '400': [], '500': [], '600': [], '700': [], '800': [], '900': [], '1000': []} 114 | cluster_support['csep'] = {'0': [], '10': [], '25': [], '50': [], '75': [], '100': [], '200': [], '300': [], 115 | '400': [], '500': [], '600': [], '700': [], '800': [], '900': [], '1000': []} 116 | cluster_support['csep_pcnt'] = {'0': [], '10': [], '25': [], '50': [], '75': [], '100': [], '200': [], '300': [], 117 | '400': [], '500': [], '600': [], '700': [], '800': [], '900': [], '1000': []} 118 | cluster_query = {} 119 | cluster_query['cspread'] = {'0': [], '10': [], '25': [], '50': [], '75': [], '100': [], '200': [], '300': [], 120 | '400': [], '500': [], '600': [], '700': [], '800': [], '900': [], '1000': []} 121 | cluster_query['cspread_pcnt'] = {'0': [], '10': [], '25': [], '50': [], '75': [], '100': [], '200': [], '300': [], 122 | '400': [], '500': [], '600': [], '700': [], '800': [], '900': [], '1000': []} 123 | cluster_query['csep'] = {'0': [], '10': [], '25': [], '50': [], '75': [], '100': [], '200': [], '300': [], '400': [], 124 | '500': [], '600': [], '700': [], '800': [], '900': [], '1000': []} 125 | cluster_query['csep_pcnt'] = {'0': [], '10': [], '25': [], '50': [], '75': [], '100': [], '200': [], '300': [], 126 | '400': [], '500': [], '600': [], '700': [], '800': [], '900': [], '1000': []} 127 | return acc_all_le, acc_all, accDiff_all, cluster_support, cluster_query 128 | 129 | 130 | def finetune(source_loader, novel_loader, total_epoch, model_params, dataloader_params, params): 131 | results_dir = params.results_dir 132 | n_way, n_support = dataloader_params['n_way'], dataloader_params['n_support'] 133 | image_size = dataloader_params['image_size'] 134 | acc_all_le, acc_all, accDiff_all, cluster_support, cluster_query = get_init_var() 135 | 136 | # model 137 | model = ConCeModel(n_way=n_way, projhead=params.projection_head, 138 | ft_mode=params.ft_mode, **model_params) 139 | progress = tqdm(novel_loader) 140 | for task_id, (x, x_ft, y) in enumerate(progress): 141 | if params.ufsl_dataset: 142 | supcon_datamgr = ContrastiveBatchifier(n_way=n_way, n_support=n_support, image_size=image_size, 143 | dataset=params.targetset) 144 | else: 145 | supcon_datamgr = ContrastiveBatchifier(n_way=n_way, n_support=n_support, image_size=image_size, 146 | augstrength=params.augstrength) 147 | 148 | if params.is_tgt_aug: 149 | supcon_dataloader = supcon_datamgr.get_loader([sample[:n_support] for sample in x_ft]) 150 | else: 151 | supcon_dataloader = None 152 | ############################################################################################### 153 | # load pretrained feature extractor 154 | model.refresh_from_chkpt() 155 | model.cuda() 156 | optimizer = torch.optim.Adam(model.parameters(), lr=params.lr) 157 | n_query = x.size(1) - n_support 158 | x = x.cuda() 159 | 160 | acc_before, _, z_all_analysis, cspread_support_before, csep_support_before, cspread_query_before, \ 161 | csep_query_before, acc_all, accDiff_all, 
cluster_support, cluster_query = epoch_wise_collection( 162 | model, n_way, n_support, n_query, x, -1, 163 | EPS, EPS, EPS, EPS, EPS, acc_all, accDiff_all, cluster_support, cluster_query) 164 | 165 | for epoch in range(total_epoch): 166 | model.train() 167 | 168 | # labelled contrastive batch 169 | if supcon_dataloader is not None: 170 | x_l_fewshot, _ = next(iter(supcon_dataloader)) 171 | x_l_fewshot = x_l_fewshot.cuda() 172 | else: 173 | # without any augmentation 174 | x_l_fewshot = x[:, :n_support, :, :, :] 175 | x_l_fewshot = x_l_fewshot.contiguous().view(x_l_fewshot.size(0) * x_l_fewshot.size(1), 176 | *x_l_fewshot.shape[2:]).cuda() 177 | 178 | # unlabelled batch 179 | if epoch % len(source_loader) == 0: 180 | src_iter = iter(source_loader) 181 | x_src, y_src = next(src_iter) 182 | x_src, y_src = x_src.cuda(), y_src.cuda() 183 | unlab_bsz = x_src.size(0) 184 | z_src = model.get_feature(x_src) 185 | 186 | # Batchify 187 | x_l, shots_per_way = supcon_datamgr.batchify(x_l_fewshot) 188 | x_l = x_l.view(x_l.size(0) * x_l.size(1), *x_l.size()[2:]) 189 | z_l = model.forward_this(x_l) 190 | # random paired positives 191 | z_u = model.forward_projection(z_src) 192 | z_batch = torch.cat([z_l, z_u], dim=0) 193 | loss_primary = model.cssf_loss(z_batch, shots_per_way, n_way, unlab_bsz, 194 | mode=params.ft_mode, alpha=params.alpha) 195 | 196 | if 'mtce' in params.ft_mode: 197 | loss_mt = model.CE_loss_source(z_src, y_src) 198 | loss = loss_primary + loss_mt 199 | else: 200 | # no Multitask supervised loss 201 | loss = loss_primary 202 | 203 | optimizer.zero_grad() 204 | loss.backward() 205 | optimizer.step() 206 | model.tf_writer.add_scalar('train/loss', loss.item(), epoch) 207 | if str(epoch + 1) in acc_all.keys(): 208 | acc_after, improvement, _, _, _, _, _, acc_all, accDiff_all, cluster_support, cluster_query = epoch_wise_collection( 209 | model, n_way, n_support, n_query, x, epoch, 210 | acc_before, 211 | cspread_support_before, csep_support_before, cspread_query_before, 212 | csep_query_before, 213 | acc_all, accDiff_all, cluster_support, cluster_query) 214 | 215 | #progress.set_description('improvement%d = %0.3f, finalacc%d = %0.3f' % (epoch + 1, improvement, epoch + 1, acc_after)) 216 | progress.set_description('epoch %d' % (epoch + 1)) 217 | 218 | acc_all_le.append(0) 219 | if (task_id + 1) % 10 == 0 or task_id == len(novel_loader) - 1: 220 | record(acc_all_le, acc_all, accDiff_all, cluster_support, cluster_query, 221 | results_path=os.path.join(results_dir, 'results.txt')) 222 | 223 | if __name__ == '__main__': 224 | np.random.seed(10) 225 | params = parse_args('train/ufsl/cssf/parallel') 226 | print(params) 227 | resultdir = '%s/%s' % (params.save_dir, params.name) 228 | params.name = os.path.join(params.name, str(params.run_id)) 229 | 230 | # output and tensorboard dir 231 | params.tf_dir = '%s/%s/log' % (params.save_dir, params.name) 232 | params.checkpoint_dir = params.tf_dir.replace('/log', '/chkpt') 233 | if not os.path.isdir(params.checkpoint_dir): 234 | os.makedirs(params.checkpoint_dir) 235 | params.results_dir = params.tf_dir.replace('/log', '/results') 236 | if not os.path.isdir(params.results_dir): 237 | os.makedirs(params.results_dir) 238 | ################################################################## 239 | if 'Conv' in params.model: 240 | image_size = 84 241 | else: 242 | image_size = 224 243 | num_tasks = params.num_tasks 244 | total_epoch = params.stop_epoch 245 | 246 | # dataloaders 247 | target_file, target_val_file, target_novel_file = get_datafiles(params, 
configs) 248 | if params.hyperparam_select: 249 | inference_file = target_val_file 250 | else: 251 | inference_file = target_novel_file 252 | 253 | dataloader_params = dict( 254 | image_size=image_size, 255 | num_aug=total_epoch, 256 | n_way=params.test_n_way, 257 | n_support=params.n_shot, 258 | n_episode=num_tasks, 259 | n_query=params.n_query) 260 | 261 | if target_file is not None: 262 | novel_loader = SetDataManager(**dataloader_params).get_data_loader(inference_file) 263 | source_loader = SimpleDataManager(image_size, params.distractor_bsz).get_data_loader( 264 | target_file, aug=params.is_src_aug) 265 | else: 266 | raise ValueError('must define "--distractor_set" flag') 267 | 268 | print('------------ experimental details --------------') 269 | print('finetuning mode : %s' % params.ft_mode) 270 | print('model : %s' % params.model) 271 | print('source file: %s' % (target_file)) 272 | print('novel file: %s' % (inference_file)) 273 | print('n_way: %d, n_shot: %d, n_query: %d' % ( 274 | dataloader_params['n_way'], dataloader_params['n_support'], dataloader_params['n_query'])) 275 | print('distractor sz: %s' % (params.distractor_bsz)) 276 | print('tau : %s' % (params.tau)) 277 | print('lr : %s' % (params.lr)) 278 | print('Source Aug : ', params.is_src_aug) 279 | print('Target Aug : ', params.is_tgt_aug) 280 | 281 | # model 282 | model_params = dict( 283 | model_func=model_dict[params.model], 284 | tau=params.tau, 285 | tf_path=params.tf_dir, 286 | loadpath=params.load_modelpath, 287 | is_distribute=torch.cuda.device_count() > 1, 288 | src_classes=64, 289 | cos_fac=params.cosine_fac, 290 | ) 291 | 292 | # contrastive finetune 293 | finetune(source_loader, novel_loader, total_epoch, model_params, dataloader_params, params) 294 | 295 | # save results 296 | consolidate_results(root=resultdir) 297 | -------------------------------------------------------------------------------- /methods/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quantacode/Contrastive-Finetuning/321e60f26644f4cd6ed7e362cbbddce2f753e7ad/methods/__init__.py -------------------------------------------------------------------------------- /methods/backbone.py: -------------------------------------------------------------------------------- 1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate 2 | 3 | import torch 4 | import torch.nn as nn 5 | import math 6 | import torch.nn.functional as F 7 | from torch.nn.utils import weight_norm 8 | from torch.distributions import Bernoulli 9 | import ipdb 10 | 11 | TRACKBN = True 12 | # --- gaussian initialize --- 13 | def init_layer(L): 14 | # Initialization using fan-in 15 | if isinstance(L, nn.Conv2d): 16 | n = L.kernel_size[0]*L.kernel_size[1]*L.out_channels 17 | L.weight.data.normal_(0,math.sqrt(2.0/float(n))) 18 | elif isinstance(L, nn.BatchNorm2d): 19 | L.weight.data.fill_(1) 20 | L.bias.data.fill_(0) 21 | 22 | class distLinear(nn.Module): 23 | def __init__(self, indim, outdim): 24 | super(distLinear, self).__init__() 25 | self.L = weight_norm(nn.Linear(indim, outdim, bias=False), name='weight', dim=0) 26 | self.relu = nn.ReLU() 27 | 28 | def forward(self, x): 29 | x_norm = torch.norm(x, p=2, dim =1).unsqueeze(1).expand_as(x) 30 | x_normalized = x.div(x_norm + 0.00001) 31 | L_norm = torch.norm(self.L.weight.data, p=2, dim =1).unsqueeze(1).expand_as(self.L.weight.data) 32 | self.L.weight.data = self.L.weight.data.div(L_norm + 0.00001) 33 | cos_dist = 
self.L(x_normalized) 34 | scores = 10 * cos_dist 35 | return scores 36 | 37 | # --- flatten tensor --- 38 | class Flatten(nn.Module): 39 | def __init__(self): 40 | super(Flatten, self).__init__() 41 | 42 | def forward(self, x): 43 | return x.view(x.size(0), -1) 44 | 45 | # --- conv3x3 tensor --- 46 | def conv3x3(in_planes, out_planes, stride=1): 47 | """3x3 convolution with padding""" 48 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 49 | padding=1, bias=False) 50 | 51 | class SELayer(nn.Module): 52 | def __init__(self, channel, reduction=16): 53 | super(SELayer, self).__init__() 54 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 55 | self.fc = nn.Sequential( 56 | nn.Linear(channel, channel // reduction), 57 | nn.ReLU(inplace=True), 58 | nn.Linear(channel // reduction, channel), 59 | nn.Sigmoid() 60 | ) 61 | 62 | def forward(self, x): 63 | b, c, _, _ = x.size() 64 | y = self.avg_pool(x).view(b, c) 65 | y = self.fc(y).view(b, c, 1, 1) 66 | return x * y 67 | 68 | class DropBlock(nn.Module): 69 | def __init__(self, block_size): 70 | super(DropBlock, self).__init__() 71 | 72 | self.block_size = block_size 73 | # self.gamma = gamma 74 | # self.bernouli = Bernoulli(gamma) 75 | 76 | def forward(self, x, gamma): 77 | # shape: (bsize, channels, height, width) 78 | 79 | if self.training: 80 | batch_size, channels, height, width = x.shape 81 | 82 | bernoulli = Bernoulli(gamma) 83 | mask = bernoulli.sample( 84 | (batch_size, channels, height - (self.block_size - 1), width - (self.block_size - 1))).cuda() 85 | block_mask = self._compute_block_mask(mask) 86 | countM = block_mask.size()[0] * block_mask.size()[1] * block_mask.size()[2] * block_mask.size()[3] 87 | count_ones = block_mask.sum() 88 | 89 | return block_mask * x * (countM / count_ones) 90 | else: 91 | return x 92 | 93 | def _compute_block_mask(self, mask): 94 | left_padding = int((self.block_size - 1) / 2) 95 | right_padding = int(self.block_size / 2) 96 | 97 | batch_size, channels, height, width = mask.shape 98 | # print ("mask", mask[0][0]) 99 | non_zero_idxs = mask.nonzero() 100 | nr_blocks = non_zero_idxs.shape[0] 101 | 102 | offsets = torch.stack( 103 | [ 104 | torch.arange(self.block_size).view(-1, 1).expand(self.block_size, self.block_size).reshape(-1), 105 | # - left_padding, 106 | torch.arange(self.block_size).repeat(self.block_size), # - left_padding 107 | ] 108 | ).t().cuda() 109 | offsets = torch.cat((torch.zeros(self.block_size ** 2, 2).cuda().long(), offsets.long()), 1) 110 | 111 | if nr_blocks > 0: 112 | non_zero_idxs = non_zero_idxs.repeat(self.block_size ** 2, 1) 113 | offsets = offsets.repeat(nr_blocks, 1).view(-1, 4) 114 | offsets = offsets.long() 115 | 116 | block_idxs = non_zero_idxs + offsets 117 | # block_idxs += left_padding 118 | padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding)) 119 | padded_mask[block_idxs[:, 0], block_idxs[:, 1], block_idxs[:, 2], block_idxs[:, 3]] = 1. 
120 |         else:
121 |             padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding))
122 | 
123 |         block_mask = 1 - padded_mask # [:height, :width]
124 |         return block_mask
125 | 
126 | 
127 | # --- LSTMCell module for matchingnet ---
128 | class LSTMCell(nn.Module):
129 |     maml = False
130 |     def __init__(self, input_size, hidden_size, bias=True):
131 |         super(LSTMCell, self).__init__()
132 |         self.input_size = input_size
133 |         self.hidden_size = hidden_size
134 |         self.bias = bias
135 |         if self.maml:
136 |             self.x2h = Linear_fw(input_size, 4 * hidden_size, bias=bias)
137 |             self.h2h = Linear_fw(hidden_size, 4 * hidden_size, bias=bias)
138 |         else:
139 |             self.x2h = nn.Linear(input_size, 4 * hidden_size, bias=bias)
140 |             self.h2h = nn.Linear(hidden_size, 4 * hidden_size, bias=bias)
141 |         self.reset_parameters()
142 | 
143 |     def reset_parameters(self):
144 |         std = 1.0 / math.sqrt(self.hidden_size)
145 |         for w in self.parameters():
146 |             w.data.uniform_(-std, std)
147 | 
148 |     def forward(self, x, hidden=None):
149 |         if hidden is None:
150 |             hx = torch.zeros_like(x)
151 |             cx = torch.zeros_like(x)
152 |         else:
153 |             hx, cx = hidden
154 | 
155 |         gates = self.x2h(x) + self.h2h(hx)
156 |         ingate, forgetgate, cellgate, outgate = torch.split(gates, self.hidden_size, dim=1)
157 | 
158 |         ingate = torch.sigmoid(ingate)
159 |         forgetgate = torch.sigmoid(forgetgate)
160 |         cellgate = torch.tanh(cellgate)
161 |         outgate = torch.sigmoid(outgate)
162 | 
163 |         cy = torch.mul(cx, forgetgate) + torch.mul(ingate, cellgate)
164 |         hy = torch.mul(outgate, torch.tanh(cy))
165 |         return (hy, cy)
166 | 
167 | # --- LSTM module for matchingnet ---
168 | class LSTM(nn.Module):
169 |     def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, bidirectional=False):
170 |         super(LSTM, self).__init__()
171 | 
172 |         self.input_size = input_size
173 |         self.hidden_size = hidden_size
174 |         self.num_layers = num_layers
175 |         self.bias = bias
176 |         self.batch_first = batch_first
177 |         self.num_directions = 2 if bidirectional else 1
178 |         assert(self.num_layers == 1)
179 | 
180 |         self.lstm = LSTMCell(input_size, hidden_size, self.bias)
181 | 
182 |     def forward(self, x, hidden=None):
183 |         # swap axis if batch first
184 |         if self.batch_first:
185 |             x = x.permute(1, 0, 2)
186 | 
187 |         # hidden state
188 |         if hidden is None:
189 |             h0 = torch.zeros(self.num_directions, x.size(1), self.hidden_size, dtype=x.dtype, device=x.device)
190 |             c0 = torch.zeros(self.num_directions, x.size(1), self.hidden_size, dtype=x.dtype, device=x.device)
191 |         else:
192 |             h0, c0 = hidden
193 | 
194 |         # forward
195 |         outs = []
196 |         hn = h0[0]
197 |         cn = c0[0]
198 |         for seq in range(x.size(0)):
199 |             hn, cn = self.lstm(x[seq], (hn, cn))
200 |             outs.append(hn.unsqueeze(0))
201 |         outs = torch.cat(outs, dim=0)
202 | 
203 |         # reverse forward
204 |         if self.num_directions == 2:
205 |             outs_reverse = []
206 |             hn = h0[1]
207 |             cn = c0[1]
208 |             for seq in range(x.size(0)):
209 |                 seq = x.size(0) - 1 - seq
210 |                 hn, cn = self.lstm(x[seq], (hn, cn))
211 |                 outs_reverse.append(hn.unsqueeze(0))
212 |             outs_reverse = torch.cat(outs_reverse, dim=0)
213 |             outs = torch.cat([outs, outs_reverse], dim=2)
214 | 
215 |         # swap axis if batch first
216 |         if self.batch_first:
217 |             outs = outs.permute(1, 0, 2)
218 |         return outs
219 | 
220 | # --- Linear module ---
221 | class Linear_fw(nn.Linear): # used in MAML to forward input with fast weight
222 |     def __init__(self, in_features, out_features, bias=True):
223 |         super(Linear_fw, self).__init__(in_features, out_features, 
bias=bias) 224 | self.weight.fast = None #Lazy hack to add fast weight link 225 | self.bias.fast = None 226 | 227 | def forward(self, x): 228 | if self.weight.fast is not None and self.bias.fast is not None: 229 | out = F.linear(x, self.weight.fast, self.bias.fast) 230 | else: 231 | out = super(Linear_fw, self).forward(x) 232 | return out 233 | 234 | # --- Conv2d module --- 235 | class Conv2d_fw(nn.Conv2d): #used in MAML to forward input with fast weight 236 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,padding=0, bias = True): 237 | super(Conv2d_fw, self).__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias) 238 | self.weight.fast = None 239 | if not self.bias is None: 240 | self.bias.fast = None 241 | 242 | def forward(self, x): 243 | if self.bias is None: 244 | if self.weight.fast is not None: 245 | out = F.conv2d(x, self.weight.fast, None, stride= self.stride, padding=self.padding) 246 | else: 247 | out = super(Conv2d_fw, self).forward(x) 248 | else: 249 | if self.weight.fast is not None and self.bias.fast is not None: 250 | out = F.conv2d(x, self.weight.fast, self.bias.fast, stride= self.stride, padding=self.padding) 251 | else: 252 | out = super(Conv2d_fw, self).forward(x) 253 | return out 254 | 255 | # --- softplus module --- 256 | def softplus(x): 257 | return torch.nn.functional.softplus(x, beta=100) 258 | 259 | # --- feature-wise transformation layer --- 260 | class FeatureWiseTransformation2d_fw(nn.BatchNorm2d): 261 | feature_augment = False 262 | def __init__(self, num_features, momentum=0.1, track_running_stats=True): 263 | super(FeatureWiseTransformation2d_fw, self).__init__(num_features, momentum=momentum, track_running_stats=track_running_stats) 264 | self.weight.fast = None 265 | self.bias.fast = None 266 | if self.track_running_stats: 267 | self.register_buffer('running_mean', torch.zeros(num_features)) 268 | self.register_buffer('running_var', torch.zeros(num_features)) 269 | if self.feature_augment: # initialize {gamma, beta} with {0.3, 0.5} 270 | self.gamma = torch.nn.Parameter(torch.ones(1, num_features, 1, 1)*0.3) 271 | self.beta = torch.nn.Parameter(torch.ones(1, num_features, 1, 1)*0.5) 272 | self.reset_parameters() 273 | 274 | def reset_running_stats(self): 275 | if self.track_running_stats: 276 | self.running_mean.zero_() 277 | self.running_var.fill_(1) 278 | 279 | def forward(self, x, step=0): 280 | if self.weight.fast is not None and self.bias.fast is not None: 281 | weight = self.weight.fast 282 | bias = self.bias.fast 283 | else: 284 | weight = self.weight 285 | bias = self.bias 286 | if self.track_running_stats: 287 | out = F.batch_norm(x, self.running_mean, self.running_var, weight, bias, training=self.training, momentum=self.momentum) 288 | else: 289 | out = F.batch_norm(x, torch.zeros_like(x), torch.ones_like(x), weight, bias, training=True, momentum=1) 290 | 291 | # apply feature-wise transformation 292 | if self.feature_augment and self.training: 293 | gamma = (1 + torch.randn(1, self.num_features, 1, 1, dtype=self.gamma.dtype, device=self.gamma.device)*softplus(self.gamma)).expand_as(out) 294 | beta = (torch.randn(1, self.num_features, 1, 1, dtype=self.beta.dtype, device=self.beta.device)*softplus(self.beta)).expand_as(out) 295 | out = gamma*out + beta 296 | return out 297 | 298 | # --- BatchNorm2d --- 299 | class BatchNorm2d_fw(nn.BatchNorm2d): 300 | def __init__(self, num_features, momentum=0.1, track_running_stats=True): 301 | super(BatchNorm2d_fw, self).__init__(num_features, 
momentum=momentum, track_running_stats=track_running_stats) 302 | self.weight.fast = None 303 | self.bias.fast = None 304 | if self.track_running_stats: 305 | self.register_buffer('running_mean', torch.zeros(num_features)) 306 | self.register_buffer('running_var', torch.zeros(num_features)) 307 | self.reset_parameters() 308 | 309 | def reset_running_stats(self): 310 | if self.track_running_stats: 311 | self.running_mean.zero_() 312 | self.running_var.fill_(1) 313 | 314 | def forward(self, x, step=0): 315 | if self.weight.fast is not None and self.bias.fast is not None: 316 | weight = self.weight.fast 317 | bias = self.bias.fast 318 | else: 319 | weight = self.weight 320 | bias = self.bias 321 | if self.track_running_stats: 322 | out = F.batch_norm(x, self.running_mean, self.running_var, weight, bias, training=self.training, momentum=self.momentum) 323 | else: 324 | out = F.batch_norm(x, torch.zeros(x.size(1), dtype=x.dtype, device=x.device), torch.ones(x.size(1), dtype=x.dtype, device=x.device), weight, bias, training=True, momentum=1) 325 | return out 326 | 327 | # --- BatchNorm1d --- 328 | class BatchNorm1d_fw(nn.BatchNorm1d): 329 | def __init__(self, num_features, momentum=0.1, track_running_stats=True): 330 | super(BatchNorm1d_fw, self).__init__(num_features, momentum=momentum, track_running_stats=track_running_stats) 331 | self.weight.fast = None 332 | self.bias.fast = None 333 | if self.track_running_stats: 334 | self.register_buffer('running_mean', torch.zeros(num_features)) 335 | self.register_buffer('running_var', torch.zeros(num_features)) 336 | self.reset_parameters() 337 | 338 | def reset_running_stats(self): 339 | if self.track_running_stats: 340 | self.running_mean.zero_() 341 | self.running_var.fill_(1) 342 | 343 | def forward(self, x, step=0): 344 | if self.weight.fast is not None and self.bias.fast is not None: 345 | weight = self.weight.fast 346 | bias = self.bias.fast 347 | else: 348 | weight = self.weight 349 | bias = self.bias 350 | if self.track_running_stats: 351 | out = F.batch_norm(x, self.running_mean, self.running_var, weight, bias, training=self.training, momentum=self.momentum) 352 | else: 353 | out = F.batch_norm(x, torch.zeros(x.size(1), dtype=x.dtype, device=x.device), torch.ones(x.size(1), dtype=x.dtype, device=x.device), weight, bias, training=True, momentum=1) 354 | return out 355 | 356 | # --- Simple Conv Block --- 357 | class ConvBlock(nn.Module): 358 | maml = False 359 | def __init__(self, indim, outdim, pool = True, padding = 1): 360 | super(ConvBlock, self).__init__() 361 | self.indim = indim 362 | self.outdim = outdim 363 | if self.maml: 364 | self.C = Conv2d_fw(indim, outdim, 3, padding = padding) 365 | self.BN = FeatureWiseTransformation2d_fw(outdim) 366 | else: 367 | self.C = nn.Conv2d(indim, outdim, 3, padding= padding) 368 | self.BN = nn.BatchNorm2d(outdim) 369 | self.relu = nn.ReLU(inplace=True) 370 | 371 | self.parametrized_layers = [self.C, self.BN, self.relu] 372 | if pool: 373 | self.pool = nn.MaxPool2d(2) 374 | self.parametrized_layers.append(self.pool) 375 | 376 | for layer in self.parametrized_layers: 377 | init_layer(layer) 378 | self.trunk = nn.Sequential(*self.parametrized_layers) 379 | 380 | def forward(self,x): 381 | out = self.trunk(x) 382 | return out 383 | 384 | # --- Basic ResNet Block (resnet12) --- 385 | class BasicBlock(nn.Module): 386 | expansion = 1 387 | def __init__(self, inplanes, planes, stride=1, downsample=None, drop_rate=0.0, drop_block=False, 388 | block_size=1, use_se=False): 389 | super(BasicBlock, 
self).__init__() 390 | self.conv1 = conv3x3(inplanes, planes) 391 | self.bn1 = nn.BatchNorm2d(planes) 392 | self.relu = nn.LeakyReLU(0.1) 393 | self.conv2 = conv3x3(planes, planes) 394 | self.bn2 = nn.BatchNorm2d(planes) 395 | self.conv3 = conv3x3(planes, planes) 396 | self.bn3 = nn.BatchNorm2d(planes) 397 | self.maxpool = nn.MaxPool2d(stride) 398 | self.downsample = downsample 399 | self.stride = stride 400 | self.drop_rate = drop_rate 401 | self.num_batches_tracked = 0 402 | self.drop_block = drop_block 403 | self.block_size = block_size 404 | self.DropBlock = DropBlock(block_size=self.block_size) 405 | self.use_se = use_se 406 | if self.use_se: 407 | self.se = SELayer(planes, 4) 408 | 409 | def forward(self, x): 410 | self.num_batches_tracked += 1 411 | 412 | residual = x 413 | 414 | out = self.conv1(x) 415 | out = self.bn1(out) 416 | out = self.relu(out) 417 | 418 | out = self.conv2(out) 419 | out = self.bn2(out) 420 | out = self.relu(out) 421 | 422 | out = self.conv3(out) 423 | out = self.bn3(out) 424 | if self.use_se: 425 | out = self.se(out) 426 | 427 | if self.downsample is not None: 428 | residual = self.downsample(x) 429 | out += residual 430 | out = self.relu(out) 431 | out = self.maxpool(out) 432 | 433 | if self.drop_rate > 0: 434 | if self.drop_block == True: 435 | feat_size = out.size()[2] 436 | keep_rate = max(1.0 - self.drop_rate / (20*2000) * (self.num_batches_tracked), 1.0 - self.drop_rate) 437 | gamma = (1 - keep_rate) / self.block_size**2 * feat_size**2 / (feat_size - self.block_size + 1)**2 438 | out = self.DropBlock(out, gamma=gamma) 439 | else: 440 | out = F.dropout(out, p=self.drop_rate, training=self.training, inplace=True) 441 | 442 | return out 443 | 444 | # --- Simple ResNet Block --- 445 | class SimpleBlock(nn.Module): 446 | maml = False 447 | def __init__(self, indim, outdim, half_res, leaky=False): 448 | super(SimpleBlock, self).__init__() 449 | self.indim = indim 450 | self.outdim = outdim 451 | if self.maml: 452 | self.C1 = Conv2d_fw(indim, outdim, kernel_size=3, stride=2 if half_res else 1, padding=1, bias=False) 453 | self.BN1 = BatchNorm2d_fw(outdim) 454 | self.C2 = Conv2d_fw(outdim, outdim,kernel_size=3, padding=1,bias=False) 455 | self.BN2 = FeatureWiseTransformation2d_fw(outdim) # feature-wise transformation at the end of each residual block 456 | else: 457 | self.C1 = nn.Conv2d(indim, outdim, kernel_size=3, stride=2 if half_res else 1, padding=1, bias=False) 458 | self.BN1 = nn.BatchNorm2d(outdim, track_running_stats=TRACKBN) 459 | self.C2 = nn.Conv2d(outdim, outdim,kernel_size=3, padding=1,bias=False) 460 | self.BN2 = nn.BatchNorm2d(outdim, track_running_stats=TRACKBN) 461 | self.relu1 = nn.ReLU(inplace=True) if not leaky else nn.LeakyReLU(0.2, inplace=True) 462 | self.relu2 = nn.ReLU(inplace=True) if not leaky else nn.LeakyReLU(0.2, inplace=True) 463 | 464 | self.parametrized_layers = [self.C1, self.C2, self.BN1, self.BN2] 465 | 466 | self.half_res = half_res 467 | 468 | # if the input number of channels is not equal to the output, then need a 1x1 convolution 469 | if indim!=outdim: 470 | if self.maml: 471 | self.shortcut = Conv2d_fw(indim, outdim, 1, 2 if half_res else 1, bias=False) 472 | self.BNshortcut = FeatureWiseTransformation2d_fw(outdim) 473 | else: 474 | self.shortcut = nn.Conv2d(indim, outdim, 1, 2 if half_res else 1, bias=False) 475 | self.BNshortcut = nn.BatchNorm2d(outdim, track_running_stats=TRACKBN) 476 | 477 | self.parametrized_layers.append(self.shortcut) 478 | self.parametrized_layers.append(self.BNshortcut) 479 | 
self.shortcut_type = '1x1' 480 | else: 481 | self.shortcut_type = 'identity' 482 | 483 | for layer in self.parametrized_layers: 484 | init_layer(layer) 485 | 486 | def forward(self, x): 487 | out = self.C1(x) 488 | out = self.BN1(out) 489 | out = self.relu1(out) 490 | out = self.C2(out) 491 | out = self.BN2(out) 492 | short_out = x if self.shortcut_type == 'identity' else self.BNshortcut(self.shortcut(x)) 493 | out = out + short_out 494 | out = self.relu2(out) 495 | return out 496 | 497 | # --- ConvNet module --- 498 | class ConvNet(nn.Module): 499 | def __init__(self, depth, flatten = True): 500 | super(ConvNet,self).__init__() 501 | self.grads = [] 502 | self.fmaps = [] 503 | trunk = [] 504 | for i in range(depth): 505 | indim = 3 if i == 0 else 64 506 | outdim = 64 507 | B = ConvBlock(indim, outdim, pool = ( i <4 ) ) # only pooling for the first 4 layers 508 | trunk.append(B) 509 | 510 | if flatten: 511 | trunk.append(Flatten()) 512 | 513 | self.trunk = nn.Sequential(*trunk) 514 | self.final_feat_dim = 1600 515 | 516 | def forward(self,x): 517 | out = self.trunk(x) 518 | return out 519 | 520 | class ConvNet_medium(nn.Module): 521 | def __init__(self, depth, flatten = True): 522 | super(ConvNet_medium,self).__init__() 523 | self.grads = [] 524 | self.fmaps = [] 525 | trunk = [] 526 | for i in range(depth): 527 | indim = 3 if i == 0 else 64 528 | if i==depth-1: 529 | outdim = 20 530 | self.final_feat_dim = outdim*25 531 | else: 532 | outdim = 64 533 | B = ConvBlock(indim, outdim, pool = ( i <4 ) ) # only pooling for the first 4 layers 534 | trunk.append(B) 535 | 536 | if flatten: 537 | trunk.append(Flatten()) 538 | 539 | self.trunk = nn.Sequential(*trunk) 540 | 541 | def forward(self,x): 542 | out = self.trunk(x) 543 | return out 544 | 545 | 546 | class ConvNet_small(nn.Module): 547 | def __init__(self, depth, flatten = True): 548 | super(ConvNet_small,self).__init__() 549 | self.grads = [] 550 | self.fmaps = [] 551 | trunk = [] 552 | for i in range(depth): 553 | indim = 3 if i == 0 else 64 554 | if i==depth-1: 555 | outdim = 6 556 | self.final_feat_dim = outdim*25 557 | else: 558 | outdim = 64 559 | B = ConvBlock(indim, outdim, pool = ( i <4 ) ) # only pooling for the first 4 layers 560 | trunk.append(B) 561 | 562 | if flatten: 563 | trunk.append(Flatten()) 564 | 565 | self.trunk = nn.Sequential(*trunk) 566 | 567 | def forward(self,x): 568 | out = self.trunk(x) 569 | return out 570 | 571 | # --- ConvNetNopool module --- 572 | class ConvNetNopool(nn.Module): # Relation net uses a 4-layer conv with pooling in only the first two layers, else no pooling 573 | def __init__(self, depth): 574 | super(ConvNetNopool,self).__init__() 575 | self.grads = [] 576 | self.fmaps = [] 577 | trunk = [] 578 | for i in range(depth): 579 | indim = 3 if i == 0 else 64 580 | outdim = 64 581 | B = ConvBlock(indim, outdim, pool = ( i in [0,1] ), padding = 0 if i in[0,1] else 1 ) # only the first two layers have pooling and no padding 582 | trunk.append(B) 583 | 584 | self.trunk = nn.Sequential(*trunk) 585 | self.final_feat_dim = [64,19,19] 586 | 587 | def forward(self,x): 588 | out = self.trunk(x) 589 | return out 590 | 591 | # --- ResNet module --- 592 | class ResNet(nn.Module): 593 | maml = False 594 | def __init__(self,block,list_of_num_layers, list_of_out_dims, flatten=True, leakyrelu=False): 595 | # list_of_num_layers specifies number of layers in each stage 596 | # list_of_out_dims specifies number of output channel for each stage 597 | super(ResNet,self).__init__() 598 | self.grads = [] 599 | self.fmaps = [] 600 | assert
len(list_of_num_layers)==4, 'Can have only four stages' 601 | if self.maml: 602 | conv1 = Conv2d_fw(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 603 | bn1 = BatchNorm2d_fw(64) 604 | else: 605 | conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 606 | bn1 = nn.BatchNorm2d(64, track_running_stats=TRACKBN) 607 | 608 | relu = nn.ReLU(inplace=True) if not leakyrelu else nn.LeakyReLU(0.2, inplace=True) 609 | pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 610 | 611 | init_layer(conv1) 612 | init_layer(bn1) 613 | 614 | trunk = [conv1, bn1, relu, pool1] 615 | 616 | indim = 64 617 | for i in range(4): 618 | for j in range(list_of_num_layers[i]): 619 | half_res = (i>=1) and (j==0) 620 | B = block(indim, list_of_out_dims[i], half_res, leaky=leakyrelu) 621 | trunk.append(B) 622 | indim = list_of_out_dims[i] 623 | self.last_block = B 624 | 625 | if flatten: 626 | avgpool = nn.AvgPool2d(7) 627 | trunk.append(avgpool) 628 | trunk.append(Flatten()) 629 | self.final_feat_dim = indim 630 | else: 631 | # self.final_feat_dim = [ indim, 7, 7] 632 | self.final_feat_dim = indim 633 | 634 | self.trunk = nn.Sequential(*trunk) 635 | 636 | def forward(self,x): 637 | out = self.trunk(x) 638 | return out 639 | 640 | # --- Conv networks --- 641 | def Conv4_small(flatten, leakyrelu): 642 | return ConvNet_small(4) 643 | def Conv4_medium(flatten, leakyrelu): 644 | return ConvNet_medium(4) 645 | def Conv4(flatten, leakyrelu): 646 | return ConvNet(4, flatten=flatten) 647 | def Conv6(flatten=True, leakyrelu=False): # accept the (flatten, leakyrelu) signature used when model_dict entries are called 648 | return ConvNet(6, flatten=flatten) 649 | def Conv4NP(): 650 | return ConvNetNopool(4) 651 | def Conv6NP(): 652 | return ConvNetNopool(6) 653 | 654 | # --- ResNet networks --- 655 | def ResNet10(flatten=True, leakyrelu=False): 656 | return ResNet(SimpleBlock, [1,1,1,1],[64,128,256,512], flatten, leakyrelu) 657 | def ResNet18(flatten=True, leakyrelu=False): 658 | return ResNet(SimpleBlock, [2,2,2,2],[64,128,256,512], flatten, leakyrelu) 659 | def ResNet34(flatten=True, leakyrelu=False): 660 | return ResNet(SimpleBlock, [3,4,6,3],[64,128,256,512], flatten, leakyrelu) 661 | 662 | model_dict = dict(Conv4 = Conv4, 663 | Conv4_small = Conv4_small, 664 | Conv4_medium = Conv4_medium, 665 | Conv6 = Conv6, 666 | ResNet10 = ResNet10, 667 | ResNet18 = ResNet18, 668 | ResNet34 = ResNet34) 669 | 670 | -------------------------------------------------------------------------------- /methods/weight_imprint_based/ConCE.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import ipdb 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.nn import functional as F 8 | from torch.nn.utils import weight_norm 9 | 10 | from methods.weight_imprint_based import WeightImprint 11 | from methods import backbone 12 | from tensorboardX import SummaryWriter 13 | 14 | from utils import * 15 | 16 | EPS=0.00001 17 | def weight_reset(m): 18 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear) or isinstance(m, nn.BatchNorm2d): 19 | m.reset_parameters() 20 | 21 | class ConCeModel(WeightImprint): 22 | maml = False 23 | 24 | def __init__(self, model_func, tau, n_way, tf_path=None, loadpath=None, projhead=False, 25 | is_distribute=False, src_classes=0, ft_mode=False, cos_fac=1.0): 26 | self.tau = tau 27 | self.n_way = n_way 28 | self.src_classes = src_classes 29 | self.cos_fac = cos_fac 30 | self.model_func = model_func 31 | self.tf_path = tf_path 32 | self.projhead = projhead 33 | 34 | # primary head 35 | self.ft_mode = ft_mode 36 | self.loadpath = loadpath
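# A minimal construction sketch: model_func is typically an entry of
# backbone.model_dict, and the hyperparameter values below mirror conft.sh
# (other combinations are possible):
#   from methods import backbone
#   model = ConCeModel(model_func=backbone.model_dict['ResNet10'], tau=0.05,
#                      n_way=5, ft_mode='ewn_lpan', projhead=False)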
37 | self.is_distribute = is_distribute 38 | self.init_parameters() 39 | self.init_model() 40 | 41 | def init_parameters(self): 42 | self.init_checkpoint=None 43 | if self.projhead: 44 | self.projected_feature_dim = 64 45 | print('--- with projection %d ---'%(self.projected_feature_dim)) 46 | else: 47 | print('--- No projection ---') 48 | 49 | if 'mtce' in self.ft_mode: 50 | print('--- Secondary classifier indim ---') 51 | 52 | def init_model(self): 53 | super(ConCeModel, self).__init__(model_func=self.model_func, tf_path=self.tf_path) 54 | if self.ft_mode == 'ce_mtce' or self.ft_mode == 'ce': 55 | self.L = weight_norm(nn.Linear(self.feat_dim, self.n_way, bias=False), name='weight', dim=0) 56 | self.projection_head = None 57 | else: 58 | if self.projhead: 59 | self.projection_head = nn.Sequential( 60 | nn.Linear(self.feat_dim, self.projected_feature_dim, bias=True) 61 | ) 62 | else: 63 | self.projection_head = None 64 | 65 | # secondary head 66 | if 'mtce' in self.ft_mode: 67 | self.source_L = weight_norm(nn.Linear(self.feat_dim, self.src_classes, bias=False), name='weight', dim=0) 68 | 69 | self.loss_fn = nn.CrossEntropyLoss() 70 | if self.loadpath != None: 71 | self.load_model() 72 | if self.is_distribute: 73 | self.distribute_model() 74 | 75 | def load_model(self): 76 | if self.init_checkpoint is not None: 77 | self.load_state_dict(self.init_checkpoint, strict=False) 78 | else: 79 | state = torch.load(self.loadpath) 80 | loadstate = {} 81 | if 'state' in state.keys(): 82 | state = state['state'] 83 | for key in state.keys(): 84 | if 'feature.module' in key: 85 | loadstate[key.replace('feature.module', 'feature')] = state[key] 86 | else: 87 | loadstate[key] = state[key] 88 | elif 'state_dict' in state.keys(): 89 | state = state['state_dict'] 90 | for key in state.keys(): 91 | if 'module.encoder_k' in key: 92 | loadstate[key.replace('module.encoder_k', 'feature')] = state[key] 93 | else: 94 | loadstate[key] = state[key] 95 | self.init_checkpoint = loadstate 96 | self.load_state_dict(loadstate, strict=False) 97 | return self 98 | 99 | 100 | def refresh_from_chkpt(self): 101 | self.init_model() 102 | 103 | def distribute_model(self): 104 | self.feature = nn.DataParallel(self.feature) 105 | return self 106 | 107 | def forward_projection(self, z): 108 | if self.projection_head is not None: 109 | return self.projection_head(z) 110 | else: 111 | return z 112 | 113 | def forward_this(self, x): 114 | return self.forward_projection(self.get_feature(x)) 115 | 116 | def ewn_contrastive_loss(self, z, mask_pos, mask_neg, mask_distract, n_s, alpha): 117 | # equally weighted task and distractor negative contrastive loss 118 | bsz, featdim = z.size() 119 | z_square = z.view(bsz, 1, featdim).repeat(1, bsz, 1) 120 | sim = nn.CosineSimilarity(dim=2)(z_square, z_square.transpose(1, 0)) 121 | Sv = torch.exp(sim / self.tau) 122 | neg = (Sv * mask_neg) 123 | neg = alpha*(1-mask_distract)*neg + (1-alpha)*mask_distract*neg 124 | neg = 2*neg 125 | neg = neg.sum(dim=1).unsqueeze(1).repeat(1, bsz) 126 | li = mask_pos * torch.log(Sv / (Sv + neg) + EPS) 127 | li = li - li.diag().diag() 128 | li = (1 / (n_s - 1)) * li.sum(dim=1) 129 | loss = -li[mask_pos.sum(dim=1) > 0].mean() 130 | return loss 131 | 132 | ################# MAIN ############## 133 | def cssf_loss(self, z, shots_per_way, n_way, n_ul, mode='lpan', alpha=None): 134 | # labelled positives and all negatives 135 | n_pos = 2 136 | n_l = n_way * shots_per_way 137 | # positive mask 138 | T1 = np.eye(int(n_l/n_pos)) 139 | T2 = np.ones((n_pos, n_pos)) 140 | 
mask_pos_lab = torch.FloatTensor(np.kron(T1, T2)) 141 | T3 = torch.cat([mask_pos_lab, torch.zeros(n_l, n_ul)], dim=1) 142 | T4 = torch.zeros(n_ul, n_l+n_ul) 143 | mask_pos = torch.cat([T3,T4], dim=0).to(z.device) 144 | # negative mask 145 | T1 = 1-np.eye(n_way) 146 | T2 = np.ones((shots_per_way, shots_per_way)) 147 | mask_neg_lab = torch.FloatTensor(np.kron(T1, T2)) 148 | T3 = torch.cat([mask_neg_lab, torch.ones(n_l, n_ul)], dim=1) 149 | T4 = torch.ones(n_ul, n_l + n_ul) # dummy 150 | mask_neg = torch.cat([T3,T4], dim=0).to(z.device) 151 | T3 = torch.cat([torch.zeros(n_l, n_l), torch.ones(n_l, n_ul)], dim=1) 152 | mask_distract = torch.cat([T3, T4], dim=0).to(z.device) 153 | alpha = n_ul/(n_ul + n_l - shots_per_way) 154 | return self.ewn_contrastive_loss(z, mask_pos, mask_neg, mask_distract, n_pos, alpha) 155 | 156 | ################# MAIN ############## 157 | 158 | def get_classification_scores(self, z, classifier): 159 | z_norm = torch.norm(z, p=2, dim=1).unsqueeze(1).expand_as(z) 160 | z_normalized = z.div(z_norm + EPS) 161 | L_norm = torch.norm(classifier.weight.data, p=2, dim=1).unsqueeze(1).expand_as(classifier.weight.data) 162 | classifier.weight.data = classifier.weight.data.div(L_norm + EPS) 163 | cos_dist = classifier(z_normalized) 164 | scores = self.cos_fac * cos_dist 165 | return scores 166 | 167 | def CE_loss(self, x, y): 168 | z = self.get_feature(x) 169 | scores = self.get_classification_scores(z, self.L) 170 | loss = self.loss_fn(scores, y) 171 | return loss 172 | 173 | def CE_loss_source(self, z, y): 174 | scores = self.get_classification_scores(z, self.source_L) 175 | loss = self.loss_fn(scores, y) 176 | return loss 177 | 178 | def get_linear_classification_scores(self, z, classifier): 179 | scores = classifier(z) 180 | return scores 181 | 182 | def LCE_loss(self, x, y): 183 | z = self.get_feature(x) 184 | scores = self.get_linear_classification_scores(z, self.L) 185 | loss = self.loss_fn(scores, y) 186 | return loss 187 | -------------------------------------------------------------------------------- /methods/weight_imprint_based/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import ipdb 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.nn.utils import weight_norm 8 | from tensorboardX import SummaryWriter 9 | 10 | class WeightImprint(nn.Module): 11 | def __init__(self, model_func, tf_path=None, loadpath=None, is_distribute=False, flatten=True, leakyrelu=False): 12 | super(WeightImprint, self).__init__() 13 | self.method = 'WeightImprint' 14 | self.model_func=model_func 15 | self.feature = model_func(flatten=flatten, leakyrelu=leakyrelu) 16 | self.feat_dim = self.feature.final_feat_dim 17 | self.tf_path = tf_path 18 | self.tf_writer = SummaryWriter(log_dir=self.tf_path) 19 | self.loss_fn = nn.CrossEntropyLoss() 20 | if loadpath != None: 21 | self.load_model(loadpath) 22 | if is_distribute: 23 | self.distribute_model() 24 | 25 | def load_model(self, loadpath): 26 | state = torch.load(loadpath) 27 | if 'state' in state.keys(): 28 | state = state['state'] 29 | loadstate = {} 30 | for key in state.keys(): 31 | if 'feature.module' in key: 32 | loadstate[key.replace('feature.module', 'feature')] = state[key] 33 | else: 34 | loadstate[key] = state[key] 35 | self.load_state_dict(loadstate, strict=False) 36 | return self 37 | 38 | def distribute_model(self): 39 | self.feature = nn.DataParallel(self.feature) 40 | return self 41 | 42 | def get_feature(self, 
x): 43 | return self.feature(x) 44 | 45 | def fewshot_task_loss(self, x, n_way, n_support, n_query): 46 | y_query = torch.from_numpy(np.repeat(range(n_way), n_query)) 47 | y_query = y_query.cuda() 48 | x = x.contiguous().view(n_way * (n_support + n_query), *x.size()[2:]) 49 | z_all_linearized = self.get_feature(x) 50 | z_all = z_all_linearized.view(n_way, n_support + n_query, -1) 51 | z_support = z_all[:, :n_support] 52 | z_query = z_all[:, n_support:] 53 | z_support = z_support.contiguous() 54 | z_proto = z_support.view(n_way, n_support, -1).mean(1) # the shape of z is [n_data, n_dim] 55 | z_query = z_query.contiguous().view(n_way * n_query, -1) 56 | 57 | # normalize 58 | z_proto = F.normalize(z_proto, dim=1) 59 | z_query = F.normalize(z_query, dim=1) 60 | 61 | scores = cosine_dist(z_query, z_proto) 62 | loss = self.loss_fn(scores, y_query) 63 | return scores, loss, z_all_linearized 64 | 65 | def validate(self, n_way, n_support, n_query, x, epoch): 66 | self.eval() 67 | scores, loss, z_all = self.fewshot_task_loss(x, n_way, n_support, n_query) 68 | y_query = np.repeat(range(n_way), n_query) 69 | topk_scores, topk_labels = scores.data.topk(1, 1, True, True) 70 | topk_ind = topk_labels.cpu().numpy() 71 | top1_correct = np.sum(topk_ind[:, 0] == y_query) 72 | correct_this, count_this = float(top1_correct), len(y_query) 73 | acc_after = correct_this / count_this * 100 74 | self.tf_writer.add_scalar('validation/acc_before_training', acc_after, epoch + 1) 75 | return acc_after, loss, z_all 76 | 77 | def _get_epi_dist(self, n_way, n_samples, z_all): 78 | # cluster spread and separation 79 | all_affinities = cosine_dist(z_all, z_all) 80 | # cluster spread 81 | T1 = np.eye(n_way) 82 | T2 = np.ones((n_samples, n_samples)) 83 | mask_pos = torch.FloatTensor(np.kron(T1, T2)).to(all_affinities.device) 84 | cluster_spread = (1-all_affinities)*mask_pos 85 | cluster_spread = cluster_spread.sum(dim=1).contiguous().view(n_way, n_samples) 86 | cluster_spread = cluster_spread.sum(dim=1)/2 87 | cluster_spread = cluster_spread.mean(dim=0) 88 | 89 | #cluster_sep 90 | mask_neg = 1-mask_pos 91 | cluster_sep = (1 - all_affinities) * mask_neg 92 | cluster_sep = cluster_sep.sum(dim=1).contiguous().view(n_way, n_samples) 93 | cluster_sep = cluster_sep.sum(dim=1).mean(dim=0) 94 | return cluster_spread.item(), cluster_sep.item() 95 | 96 | def get_episode_distances(self, n_way, n_support, n_query, z_all): 97 | z_all = z_all.detach().cpu() 98 | n_samples = n_support+n_query 99 | 100 | z_all = F.normalize(z_all, dim=1) 101 | z_all_reshaped = z_all.contiguous().view(n_way, n_samples, z_all.shape[-1]) 102 | z_support = z_all_reshaped[:,:n_support].contiguous().view(-1, z_all.shape[-1]) 103 | z_query = z_all_reshaped[:,n_support:].contiguous().view(-1, z_all.shape[-1]) 104 | cspread_support, csep_support = self._get_epi_dist(n_way, n_support, z_support) 105 | cspread_query, csep_query = self._get_epi_dist(n_way, n_query, z_query) 106 | if n_support==1: 107 | cspread_support = np.random.rand(1)[0]*0.00001 108 | 109 | return cspread_support, csep_support, cspread_query, csep_query 110 | 111 | 112 | ######################################################### 113 | class LinearEvaluator(nn.Module): 114 | def __init__(self, feature, outdim, train_size): 115 | super(LinearEvaluator, self).__init__() 116 | self.feature = feature 117 | self.L = weight_norm(nn.Linear(feature.feat_dim, outdim, bias=False), name='weight', dim=0) 118 | self.loss_fn = nn.CrossEntropyLoss() 119 | self.batch_size = 64 120 | self.train_size = train_size 121 | 
self.train_iters = 50 122 | 123 | def get_scores(self, x): 124 | x_norm = torch.norm(x, p=2, dim=1).unsqueeze(1).expand_as(x) 125 | x_normalized = x.div(x_norm + 0.00001) 126 | L_norm = torch.norm(self.L.weight.data, p=2, dim=1).unsqueeze(1).expand_as(self.L.weight.data) 127 | self.L.weight.data = self.L.weight.data.div(L_norm + 0.00001) 128 | cos_dist = self.L(x_normalized) 129 | scores = 10 * cos_dist 130 | return scores 131 | 132 | def train_classifier(self, input, target): 133 | classifier_opt = torch.optim.SGD(self.L.parameters(), 134 | lr=0.01, momentum=0.9, dampening=0.9, weight_decay=0.001) 135 | self.cuda() 136 | self.feature.eval() 137 | self.L.train() 138 | for epoch in range(self.train_iters): 139 | rand_id = np.random.permutation(self.train_size) 140 | for j in range(0, self.train_size, self.batch_size): 141 | classifier_opt.zero_grad() 142 | selected_id = torch.from_numpy(rand_id[j: min(j + self.batch_size, self.train_size)]).cuda() 143 | x_batch = input[selected_id] 144 | y_batch = target[selected_id] 145 | output = self.get_scores(self.feature.get_feature(x_batch)) 146 | loss = self.loss_fn(output, y_batch) 147 | loss.backward() 148 | classifier_opt.step() 149 | 150 | def test_classifier(self, input, target): 151 | self.feature.eval() 152 | self.L.eval() 153 | scores = self.get_scores(self.feature.get_feature(input)) 154 | topk_scores, topk_labels = scores.data.topk(1, 1, True, True) 155 | topk_ind = topk_labels.cpu().numpy() 156 | top1_correct = np.sum(topk_ind[:, 0] == target) 157 | correct_this, count_this = float(top1_correct), len(target) 158 | acc = correct_this / count_this * 100 159 | return acc 160 | 161 | def evaluate(self, xtrain, ytrain, xtest, ytest): 162 | self.train_classifier(xtrain, ytrain) 163 | task_acc = self.test_classifier(xtest, ytest) 164 | return task_acc 165 | 166 | 167 | 168 | ######################################################### 169 | def cosine_dist(x, y): 170 | # x: N x D 171 | # y: M x D 172 | n = x.size(0) 173 | m = y.size(0) 174 | d = x.size(1) 175 | assert d == y.size(1) 176 | 177 | x = x.unsqueeze(1).expand(n, m, d) 178 | y = y.unsqueeze(0).expand(n, m, d) 179 | alignment = nn.functional.cosine_similarity(x, y, dim=2) 180 | return alignment 181 | -------------------------------------------------------------------------------- /mt_conft.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Experimental Settings and Hyperparameters 4 | MODE='ewn_lpan_mtce' 5 | TgtSet=cub #options: [cub/cars/places/plantae] 6 | DtracSet='./filelists/miniImagenet/base.json' 7 | DtracBsz=64 8 | FTepoch=200 9 | TAU=0.1 10 | # -------- Run command --------- 11 | CUDA_VISIBLE_DEVICES=0 python finetune.py \ 12 | --ft_mode $MODE \ 13 | --targetset $TgtSet --is_tgt_aug \ 14 | --distractor_set $DtracSet \ 15 | --distractor_bsz $DtracBsz \ 16 | --stop_epoch $FTepoch --tau $TAU \ 17 | --name Mode-${MODE}/TgtSet-${TgtSet}_DSET-${DSET}/DtracBsz-${DtracBsz}_FTepoch-${FTepoch}_TAU-${TAU} \ 18 | --load-modelpath 'output/checkpoints/baseline/399.tar' \ 19 | -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import glob 4 | import torch 5 | import argparse 6 | import ipdb 7 | 8 | def parse_args(script): 9 | parser = argparse.ArgumentParser(description= 'few-shot script %s' %(script)) 10 | parser.add_argument('--dataset', default='multi',
help='miniImagenet/cub/cars/places/plantae, specify multi for training with multiple domains') 11 | parser.add_argument('--testset', default='cub', help='cub/cars/places/plantae, valid only when dataset=multi') 12 | parser.add_argument('--valset', default='cub', help='cub/cars/places/plantae, valid only when dataset=multi') 13 | parser.add_argument('--model', default='ResNet10', help='model: Conv{4|6} / ResNet{10|18|34}') 14 | parser.add_argument('--method-type', default='baseline',help='baseline/baseline++/protonet/matchingnet/relationnet{_softmax}/gnnnet') 15 | parser.add_argument('--train_n_way' , default=5, type=int, help='class num to classify for training') 16 | parser.add_argument('--test_n_way' , default=5, type=int, help='class num to classify for testing (validation) ') 17 | parser.add_argument('--n_shot' , default=5, type=int, help='number of labeled data in each class, same as n_support') 18 | parser.add_argument('--n_query' , default=15, type=int, help='num queries') 19 | parser.add_argument('--train_aug' , action='store_true', help='perform data augmentation or not during training ') 20 | parser.add_argument('--debug' , action='store_true', help='debug') 21 | parser.add_argument('--name' , default='tmp', type=str, help='') 22 | parser.add_argument('--save_dir' , default='output', type=str, help='directory for logs and checkpoints') 23 | parser.add_argument('--data-dir' , default='./filelists', type=str, help='') 24 | parser.add_argument('--image-size', default=84, type=int, help='input image size') 25 | parser.add_argument('--load-modelpath', default=None, type=str, help='') 26 | parser.add_argument('--target-datapath', default=None, help='if specific name for target path') 27 | parser.add_argument('--augstrength', default='0', type=str, help='level of augmentation') 28 | parser.add_argument('--num_classes', default=200, type=int, 29 | help='total number of classes in softmax, only used in baseline') 30 | parser.add_argument('--freeze_backbone', action='store_true', 31 | help='freeze the backbone during finetuning') 32 | if 'train' in script: 33 | parser.add_argument('--save_freq' , default=25, type=int, help='Save frequency') 34 | parser.add_argument('--start_epoch' , default=0, type=int,help ='Starting epoch') 35 | parser.add_argument('--stop_epoch' , default=100000000, type=int, help ='Stopping epoch') 36 | parser.add_argument('--lr', default=0.001, type=float, help='learning rate') 37 | parser.add_argument('--resume' , default='', type=str, help='continue from previous trained model with largest epoch') 38 | parser.add_argument('--resume_epoch', default=-1, type=int, help='') 39 | parser.add_argument('--warmup' , default='gg3b0', type=str, help='continue from baseline, neglected if resume is true') 40 | parser.add_argument('--Nmb', default=1, type=int, help='num episodes per batch') 41 | parser.add_argument('--hyperparam_select', action='store_true', help='hyperparameter selection using val split') 42 | parser.add_argument("--subpix_pcnt", type=float, default=0.2, help="top score pcnt for dense protonet.") 43 | if 'contrastive' in script: 44 | parser.add_argument('--targetset', default='cub', help='for adaptation') 45 | parser.add_argument('--temperature', default=0.5, type=float, help='contrastive loss temperature') 46 | parser.add_argument('--contrastive-wt', default=1.0, type=float, help='contrastive loss weight') 47 | parser.add_argument('--contrastive-batch-size', default=-1, type=int, help='feature batchsz') 48 | elif 'clusterReg' in script:
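# Note: option groups are switched on by substring-matching the script name
# passed to parse_args. A minimal sketch (the exact string used by finetune.py
# is an assumption here; it must contain both 'train' and 'cssf' for the flags
# in conft.sh to be defined):
#   params = parse_args('train_cssf')
#   print(params.tau, params.distractor_bsz, params.stop_epoch)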
49 | parser.add_argument('--targetset', default='cub', help='for adaptation') 50 | parser.add_argument('--cr-wt', default=1.0, type=float, help='cluster regularization loss weight') 51 | parser.add_argument('--cr-batch-size', default=-1, type=int, help='feature batchsz') 52 | parser.add_argument('--num-clusters', default=200, type=int, help=' num clusters') 53 | parser.add_argument('--sim-bias', default=0.5, type=float, help=' similarity bias') 54 | elif 'ufsl' in script: 55 | parser.add_argument('--targetset', default='cub', help='for adaptation') 56 | parser.add_argument('--ufsl-wt', default=1.0, type=float, help='ufsl loss weight') 57 | parser.add_argument('--dominant-id', default=0, type=int, help='redefine sampling based on gt classes, only for ' 58 | 'debugging') 59 | parser.add_argument('--dominant-p', default=0.9, type=float, help='only for debugging') 60 | parser.add_argument('--projection-head', action='store_true', help='proj head on top of features') 61 | if 'l1' in script: 62 | parser.add_argument('--l1reg', action='store_true', help='to include l1 reg') 63 | parser.add_argument('--tgt-batch-size', default=-1, type=int, help='target batchsz') 64 | parser.add_argument('--reg-wt', default=1.0, type=float, help='regularization wt') 65 | parser.add_argument('--tau-l1', default=1.0, type=float, help='similarity function param') 66 | elif ('mvcon' in script) or ('npcon' in script): 67 | parser.add_argument('--tau', default=0.1, type=float, help='temperature') 68 | parser.add_argument('--latentcls_prob', type=float, help='latent class prob') 69 | parser.add_argument('--modification', default='none', type=str, help='none/debiased/debiased_hpm') 70 | parser.add_argument('--pos-wt-type', default='uniform', type=str, help='weighting schemes for positives') 71 | parser.add_argument('--mvcon-type', default='vanila', type=str, help='vanila: supcon, plus: remove positives ' 72 | 'from denominator') 73 | parser.add_argument('--num_pos', default=-1, type=int, help='') 74 | parser.add_argument('--load-modelpath-aux', default=None, type=str, help='') 75 | parser.add_argument('--hpm_type', default=None, type=str, help='type of hardpositive mining') 76 | elif 'hpm' in script: 77 | parser.add_argument('--beta', default=0.5, type=float, help='hpm parameter') 78 | parser.add_argument('--tau', default=0.5, type=float, help='temperature') 79 | elif 'cssf' in script: 80 | parser.add_argument('--tau', default=0.5, type=float, help='temperature') 81 | parser.add_argument('--cosine_fac', default=1.0, type=float, help='temperature for cosine classifier') 82 | parser.add_argument('--alpha', default=0.5, type=float, help='hnm+align convex parameter') 83 | parser.add_argument('--bd_alpha', default=0.5, type=float, help='hnm+align convex parameter') 84 | parser.add_argument('--beta', default=1.0, type=float, help='hpm parameter') 85 | parser.add_argument('--clstau', default=1.0, type=float, help='dataset dependent parameter') 86 | parser.add_argument('--distractor_bsz', default=64, type=int, help='distractor batch size') 87 | parser.add_argument('--src_subset_sz', default=64, type=int, help='source batch size') 88 | parser.add_argument('--src_classes', default=64, type=int, help='source classifier size') 89 | parser.add_argument('--num_tasks', default=600, type=int, help='num_tasks') 90 | parser.add_argument('--num_ft_layers', default=0, type=int, help='# finetuning layers') 91 | parser.add_argument('--is_same_head', action='store_true', help='same projection head for OL and LPUN') 92 |
parser.add_argument('--is_src_aug', action='store_true', help='augmentation to source samples') 93 | parser.add_argument('--is_tgt_aug', action='store_true', help='augmentation to target (novel) samples') 94 | parser.add_argument('--ufsl_dataset', action='store_true', help='if ufsl experiments') 95 | parser.add_argument('--ceft', action='store_true', help='replacing contr. with ce classifier in conce') 96 | parser.add_argument('--ft_mode', default='preFT', type=str, help='cssf finetuning type') 97 | parser.add_argument('--distractor_set', default='./filelists/miniImagenet/base.json', type=str, help='distractor dataset') 98 | if 'parallel' in script: 99 | parser.add_argument('--run_id', default=0, type=int, help='parallel process id') 100 | 101 | 102 | elif 'transfer_mvcon' in script: 103 | parser.add_argument('--src_n_way', default=5, type=int, help='class num to classify for training') 104 | parser.add_argument('--src_n_query', default=5, type=int, help='num queries per source class') 105 | parser.add_argument('--targetset', default='cub', help='for adaptation') 106 | parser.add_argument('--featype', default='projection', help='feature ex for source fsl') 107 | parser.add_argument('--projtype', default='same', help='projection for supcon') 108 | parser.add_argument('--ufsl-wt', default=1.0, type=float, help='ufsl loss weight') 109 | parser.add_argument('--tau', default=0.5, type=float, help='temperature') 110 | parser.add_argument('--pos-wt-type', default='min', type=str, help='weighting schemes for positives') 111 | parser.add_argument('--batch_size', default=-1, type=int, help='target batchsz') 112 | parser.add_argument('--mvcon-type', default='vanila', type=str, help='vanila: supcon, plus: remove positives ' 113 | 'from denominator') 114 | elif ('transfer_Contrastive' in script) or ('transfer_simclr' in script): 115 | parser.add_argument('--targetset', default='cub', help='for adaptation') 116 | parser.add_argument('--ufsl-wt', default=1.0, type=float, help='ufsl loss weight') 117 | parser.add_argument('--tau', default=0.5, type=float, help='temperature') 118 | parser.add_argument('--batch-size', default=-1, type=int, help='target batchsz') 119 | parser.add_argument('--num_pos', default=-1, type=int, help='') 120 | parser.add_argument('--hpm_type', default=None, type=str, help='type of hardpositive mining') 121 | 122 | elif 'interp' in script: 123 | parser.add_argument('--sampling', default='random', help='support and query sampling') 124 | parser.add_argument('--targetset', default='cub', help='for adaptation') 125 | parser.add_argument('--ufsl-wt', default=1.0, type=float, help='ufsl loss weight') 126 | parser.add_argument('--num-aug', default=4, type=int, help='') 127 | elif script == 'test': 128 | parser.add_argument('--split' , default='novel', help='base/val/novel') 129 | parser.add_argument('--save_epoch', default=-1, type=int,help ='save feature from the model trained in x epoch, ' 130 | 'use the best model if x is -1') 131 | parser.add_argument('--batch-size', default=-1, type=int, help='feature batchsz') 132 | elif script == 'cluster': 133 | parser.add_argument('--ss-thresh', default=0.99, type=float, help='self similarity threshold') 134 | parser.add_argument('--min-shots', default=0, type=int, help='num shots per class') 135 | parser.add_argument('--split' , default='novel', help='base/val/novel') 136 | parser.add_argument('--num-classes' , default=200, type=int, help=' num clusters') 137 | parser.add_argument('--relabel',
action='store_true', help='relabel data using cluster assignments') 138 | parser.add_argument('--batch-size', default=-1, type=int, help='feature batchsz') 139 | parser.add_argument('--tau', default=0.1, type=float, help='temperature') 140 | parser.add_argument('--data-dir-save', default='./filelists', type=str, help='') 141 | parser.add_argument('--save_epoch', default=-1, type=int,help ='save feature from the model trained in x epoch, ' 142 | 'use the best model if x is -1') 143 | else: 144 | raise ValueError('Unknown script') 145 | 146 | return parser.parse_args() 147 | 148 | def get_assigned_file(checkpoint_dir,num): 149 | assign_file = os.path.join(checkpoint_dir, '{:d}.tar'.format(num)) 150 | return assign_file 151 | 152 | def get_resume_file(checkpoint_dir, resume_epoch=-1): 153 | filelist = glob.glob(os.path.join(checkpoint_dir, '*.tar')) 154 | if len(filelist) == 0: 155 | return None 156 | 157 | filelist = [ x for x in filelist if os.path.basename(x) != 'best_model.tar' ] 158 | epochs = np.array([int(os.path.splitext(os.path.basename(x))[0]) for x in filelist]) 159 | max_epoch = np.max(epochs) 160 | epoch = max_epoch if resume_epoch == -1 else resume_epoch 161 | resume_file = os.path.join(checkpoint_dir, '{:d}.tar'.format(epoch)) 162 | return resume_file 163 | 164 | def get_best_file(checkpoint_dir): 165 | best_file = os.path.join(checkpoint_dir, 'best_model.tar') 166 | if os.path.isfile(best_file): 167 | return best_file 168 | else: 169 | return get_resume_file(checkpoint_dir) 170 | 171 | def load_warmup_state(filename, method): 172 | print(' load pre-trained model file: {}'.format(filename)) 173 | # warmup_resume_file = get_resume_file(filename) 174 | warmup_resume_file = get_best_file(filename) 175 | tmp = torch.load(warmup_resume_file) 176 | if tmp is not None: 177 | state = tmp['state'] 178 | state_keys = list(state.keys()) 179 | for i, key in enumerate(state_keys): 180 | if 'relationnet' in method and "feature." in key: 181 | newkey = key.replace("feature.","") 182 | state[newkey] = state.pop(key) 183 | elif method == 'gnnnet': 184 | if 'feature.module.' in key: 185 | if (torch.cuda.device_count() > 1): 186 | newkey = key.replace("feature.module.", "module.") 187 | else: 188 | newkey = key.replace("feature.module.", "") 189 | state[newkey] = state.pop(key) 190 | elif 'feature.' in key: 191 | newkey = key.replace("feature.", "") 192 | state[newkey] = state.pop(key) 193 | elif method == 'matchingnet' and 'feature.' in key and '.7.'
not in key: 194 | newkey = key.replace("feature.","") 195 | state[newkey] = state.pop(key) 196 | else: 197 | state.pop(key) 198 | else: 199 | raise ValueError(' No pre-trained encoder file found!') 200 | return state 201 | 202 | -------------------------------------------------------------------------------- /output/checkpoints/download_encoder.py: -------------------------------------------------------------------------------- 1 | from subprocess import call 2 | 3 | call('wget http://vllab.ucmerced.edu/ym41608/projects/CrossDomainFewShot/checkpoints/baseline.tar.gz', shell=True) 4 | call('wget http://vllab.ucmerced.edu/ym41608/projects/CrossDomainFewShot/checkpoints/baseline++.tar.gz', shell=True) 5 | call('tar -zxf baseline.tar.gz', shell=True) 6 | call('tar -zxf baseline++.tar.gz', shell=True) 7 | call('rm baseline.tar.gz', shell=True) 8 | call('rm baseline++.tar.gz', shell=True) 9 | -------------------------------------------------------------------------------- /output/checkpoints/download_encoders.py: -------------------------------------------------------------------------------- 1 | from subprocess import call 2 | 3 | call('wget http://vllab.ucmerced.edu/ym41608/projects/CrossDomainFewShot/checkpoints/baseline.tar.gz', shell=True) 4 | call('wget http://vllab.ucmerced.edu/ym41608/projects/CrossDomainFewShot/checkpoints/baseline++.tar.gz', shell=True) 5 | call('tar -zxf baseline.tar.gz', shell=True) 6 | call('tar -zxf baseline++.tar.gz', shell=True) 7 | call('rm baseline.tar.gz', shell=True) 8 | call('rm baseline++.tar.gz', shell=True) 9 | -------------------------------------------------------------------------------- /output/checkpoints/download_models.py: -------------------------------------------------------------------------------- 1 | from subprocess import call 2 | import sys 3 | 4 | # current available models: 5 | 6 | if len(sys.argv) != 4: # script name plus the three arguments shown below 7 | raise Exception('Incorrect command!
e.g., python3 download_models.py [cub/cars/places/plantae] [1/5] [matchingnet/relationnet/gnnnet]') 8 | 9 | testset = sys.argv[1] 10 | shot = sys.argv[2] 11 | model = sys.argv[3] 12 | 13 | filename = 'multi_' + testset + '_' + shot + '_lft_' + model + '.tar.gz' 14 | 15 | call('wget http://vllab.ucmerced.edu/ym41608/projects/CrossDomainFewShot/checkpoints/' + filename, shell=True) 16 | call('tar -zxf ' + filename, shell=True) 17 | call('rm ' + filename, shell=True) 18 | -------------------------------------------------------------------------------- /read_parallel_results.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | import ipdb 5 | 6 | 7 | def append_data(name, quantity): 8 | with open(name, 'rb') as f: 9 | data = pickle.load(f) 10 | quantity.append(data) 11 | return quantity 12 | 13 | 14 | def consolidate_results(root, max_parallel_id=0): 15 | NE_list = range(0, 600, 10) 16 | with open(os.path.join(root, 'results.txt'), 'w') as f: 17 | for ine, num_epochs in enumerate(NE_list): 18 | acc_all_le = [] 19 | acc_all = [] 20 | accDiff_all = [] 21 | chklist = [] 22 | for run in range(0, max_parallel_id + 1): 23 | linev_name = os.path.join(root, str(run), 'results/linev.pkl') 24 | final_name = os.path.join(root, str(run), 'results/wi_final%d.pkl' % (num_epochs)) 25 | if not os.path.exists(final_name): 26 | chklist.append(run) 27 | continue 28 | delta_name = os.path.join(root, str(run), 'results/wi_delta%d.pkl' % (num_epochs)) 29 | 30 | acc_all_le = append_data(linev_name, acc_all_le) 31 | acc_all = append_data(final_name, acc_all) 32 | accDiff_all = append_data(delta_name, accDiff_all) 33 | 34 | if len(acc_all_le) == 0: 35 | continue 36 | acc_all_le = np.hstack(acc_all_le) 37 | acc_all = np.hstack(acc_all) 38 | 39 | nTasks = acc_all_le.shape[0] 40 | acc_mean = np.mean(acc_all) 41 | acc_std = 1.96 * np.std(acc_all) / np.sqrt(nTasks) 42 | 43 | strn = "accuracy = %4.2f +- %4.2f" % (acc_mean, acc_std) 44 | print(strn) 45 | f.write(strn + '\n') # persist the consolidated accuracy to results.txt 46 | 47 | 48 | if __name__ == "__main__": 49 | import sys 50 | root = sys.argv[1] # path to the parallel-run output directory 51 | consolidate_results(root) 52 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | import json 5 | import ipdb 6 | import math 7 | import pickle 8 | import matplotlib.pyplot as plt 9 | 10 | 11 | def one_hot(y, num_class): 12 | return torch.zeros((len(y), num_class)).scatter_(1, y.unsqueeze(1), 1) 13 | 14 | 15 | def DBindex(cl_data_file): 16 | class_list = cl_data_file.keys() 17 | cl_num = len(class_list) 18 | cl_means = [] 19 | stds = [] 20 | DBs = [] 21 | for cl in class_list: 22 | cl_means.append(np.mean(cl_data_file[cl], axis=0)) 23 | stds.append(np.sqrt(np.mean(np.sum(np.square(cl_data_file[cl] - cl_means[-1]), axis=1)))) 24 | 25 | mu_i = np.tile(np.expand_dims(np.array(cl_means), axis=0), (len(class_list), 1, 1)) 26 | mu_j = np.transpose(mu_i, (1, 0, 2)) 27 | mdists = np.sqrt(np.sum(np.square(mu_i - mu_j), axis=2)) 28 | 29 | for i in range(cl_num): 30 | DBs.append(np.max([(stds[i] + stds[j]) / mdists[i, j] for j in range(cl_num) if j != i])) 31 | return np.mean(DBs) 32 | 33 | 34 | def sparsity(cl_data_file): 35 | class_list = cl_data_file.keys() 36 | cl_sparsity = [] 37 | for cl in class_list: 38 | cl_sparsity.append(np.mean([np.sum(x != 0) for x in cl_data_file[cl]])) 39 | 40 | return np.mean(cl_sparsity) 41 | 42 | 43 | def createdir(savedir): 44 | if not
os.path.isdir(savedir): os.makedirs(savedir) 45 | return savedir 46 | 47 | 48 | def get_miniImagenet_labelnames(labelnames): 49 | with open('/home/rajshekd/projects/FSG/FSG_raj/cdfsl/filelists/miniImagenet/classnames.txt') as f: 50 | lines = f.readlines() 51 | classdict = {} 52 | for line in lines: 53 | line = line.split('\n')[0] 54 | if line == '#### Val ####': 55 | break 56 | key, name = line.split(' ') 57 | classdict[key] = name 58 | aliases = [] 59 | for lname in labelnames: 60 | aliases.append(classdict[lname]) 61 | return aliases 62 | 63 | 64 | def json_dump(obj, filename): 65 | with open(filename, 'w') as f: 66 | json.dump(obj, f) 67 | 68 | 69 | def pickle_dump(filename, obj): 70 | with open(filename, 'wb') as f: 71 | pickle.dump(obj, f) 72 | 73 | 74 | def json_load(filename): 75 | with open(filename) as f: 76 | obj = json.load(f) 77 | return obj 78 | 79 | 80 | def pickle_load(filename): 81 | with open(filename, 'rb') as f: 82 | obj = pickle.load(f) 83 | return obj 84 | 85 | 86 | def classwise_affinity_graph(affinity_graph, labels, c1, c2): 87 | uL = np.unique(labels) 88 | label1 = uL[c1] 89 | label2 = uL[c2] 90 | c1_indices = np.asarray([i for i, lab in enumerate(labels) if lab == label1]) 91 | c2_indices = np.asarray([i for i, lab in enumerate(labels) if lab == label2]) 92 | return affinity_graph[c1_indices[:, None], c2_indices] 93 | 94 | 95 | def aggregate_accuracy(test_logits_sample, test_labels): 96 | """ 97 | Compute classification accuracy. 98 | """ 99 | averaged_predictions = torch.logsumexp(test_logits_sample, dim=0) 100 | return torch.mean(torch.eq(test_labels, torch.argmax(averaged_predictions, dim=-1)).float()) 101 | 102 | 103 | def mat_sigmoid(mat): 104 | return 1 / (1 + np.exp(-mat)) 105 | 106 | 107 | def mat_normalize(mat): 108 | return (mat - mat.min()) / (mat.max() - mat.min()) 109 | 110 | 111 | def preprocess_image(x): 112 | img = x.permute(1, 2, 0).cpu().numpy() 113 | img = (img - img.min()) / (img.max() - img.min()) 114 | return img 115 | 116 | 117 | def save_image(x, name): 118 | plt.imsave(name, preprocess_image(x)) 119 | 120 | 121 | def chkpt_vis(model, n1, n2): 122 | # for i, (name, val) in enumerate(list(model.named_parameters())[n1:n2]): 123 | # print(name, ': ', val.mean().item()) 124 | # print('-------------------') 125 | for i, (name, val) in enumerate(list(model.state_dict().items())[n1:n2]): 126 | print(name, ': ', val.float().mean().item()) 127 | print('-------------------') 128 | --------------------------------------------------------------------------------
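A minimal sketch of how ConCeModel.cssf_loss (ConCE.py) assembles its contrastive masks, for a toy episode with n_way=2, shots_per_way=4 and n_ul=3 distractors; the names mirror the source, and the toy sizes are illustrative only:

import numpy as np
import torch

n_way, shots_per_way, n_ul, n_pos = 2, 4, 3, 2       # n_pos: two augmented views per labelled positive pair
n_l = n_way * shots_per_way                          # 8 labelled embeddings
# positives: block-diagonal over view pairs; distractor rows/columns stay zero
mask_pos_lab = torch.FloatTensor(np.kron(np.eye(n_l // n_pos), np.ones((n_pos, n_pos))))
mask_pos = torch.cat([torch.cat([mask_pos_lab, torch.zeros(n_l, n_ul)], dim=1),
                      torch.zeros(n_ul, n_l + n_ul)], dim=0)
# negatives: embeddings of other classes plus every distractor; the last n_ul rows are dummies
mask_neg_lab = torch.FloatTensor(np.kron(1 - np.eye(n_way), np.ones((shots_per_way, shots_per_way))))
mask_neg = torch.cat([torch.cat([mask_neg_lab, torch.ones(n_l, n_ul)], dim=1),
                      torch.ones(n_ul, n_l + n_ul)], dim=0)
alpha = n_ul / (n_ul + n_l - shots_per_way)          # distractor weighting from cssf_loss, ~0.43 here
print(mask_pos.shape, mask_neg.shape)                # torch.Size([11, 11]) twice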