├── configs.py
├── filelists
│   ├── CUB
│   │   ├── download_CUB.sh
│   │   └── write_CUB_filelist.py
│   ├── NAB
│   │   └── write_NAB_filelist.py
│   └── DOG
│       └── write_DOG_filelist.py
├── data
│   ├── additional_transforms.py
│   ├── dataset.py
│   ├── feature_loader.py
│   └── datamgr.py
├── README.md
├── methods
│   ├── baselinefinetune.py
│   ├── meta_template.py
│   ├── baselinetrain.py
│   └── baselinevae.py
├── utils.py
├── io_utils.py
├── train_vae.py
├── finetune_sample.py
└── backbone.py

--------------------------------------------------------------------------------
/configs.py:
--------------------------------------------------------------------------------
data_dir = {}
data_dir['CUB'] = './filelists/CUB/'
data_dir['NAB'] = './filelists/NAB/'
data_dir['DOG'] = './filelists/DOG/'

--------------------------------------------------------------------------------
/filelists/CUB/download_CUB.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
wget http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz
tar -zxvf CUB_200_2011.tgz
#python write_CUB_filelist.py

--------------------------------------------------------------------------------
/data/additional_transforms.py:
--------------------------------------------------------------------------------
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from PIL import ImageEnhance

transformtypedict = dict(Brightness=ImageEnhance.Brightness, Contrast=ImageEnhance.Contrast,
                         Sharpness=ImageEnhance.Sharpness, Color=ImageEnhance.Color)


class ImageJitter(object):
    def __init__(self, transformdict):
        self.transforms = [(transformtypedict[k], transformdict[k]) for k in transformdict]

    def __call__(self, img):
        out = img
        randtensor = torch.rand(len(self.transforms))

        for i, (transformer, alpha) in enumerate(self.transforms):
            r = alpha * (randtensor[i] * 2.0 - 1.0) + 1
            out = transformer(out).enhance(r).convert('RGB')

        return out

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Variational_Transfer_Fewshot

Implementation of Variational Transfer Learning for Fine-grained Few-shot Visual Recognition.

## Environment
- Python 3
- PyTorch 1.4.0

## Datasets

### CUB
- Change directory to `filelists/CUB`
- Download the dataset from http://www.vision.caltech.edu/visipedia/CUB-200-2011.html
- Put it as `./CUB_200_2011`
- Run `python ./write_CUB_filelist.py`

### NAB
- Change directory to `filelists/NAB`
- Download the dataset from http://dl.allaboutbirds.org/nabirds
- Put it as `./nabirds`
- Run `python ./write_NAB_filelist.py`

### Stanford Dogs
- Change directory to `filelists/DOG`
- Download the dataset from http://vision.stanford.edu/aditya86/ImageNetDogs/
- Put it as `./stanforddogs`
- Run `python ./write_DOG_filelist.py`

## Running Experiments

### Run the training phase:
```bash
python train_vae.py
```

### Run the testing phase:
```bash
python finetune_sample.py
```

--------------------------------------------------------------------------------
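Both entry points read their options from `parse_args` in `io_utils.py` (shown further below). As a hedged illustration only — the flag values here are example choices, not values prescribed by the repository — a full CUB run might look like the following. Note that in `train_vae.py` the method name `baseline++` is what instantiates the VAE model (`DisentangleNet`):

```bash
# Hedged example: train the VAE model on the CUB base split, then run
# 5-way 1-shot evaluation with 2 augmented features per support sample.
# --save_dir is required by both scripts to locate checkpoints; all other
# values are illustrative.
python train_vae.py --dataset CUB --model ResNet12 --method baseline++ \
    --save_dir ./output --train_aug --stop_epoch 100 --kl_weight 1.0

python finetune_sample.py --dataset CUB --model ResNet12 --method baseline++ \
    --save_dir ./output --train_aug --n_shot 1 --aug_per_sample 2
```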
/filelists/NAB/write_NAB_filelist.py:
--------------------------------------------------------------------------------
import numpy as np
from os import listdir
from os.path import isfile, isdir, join
import os
import json
import random

cwd = os.getcwd()
data_path = join(cwd, 'nabirds/images')
savedir = './'
dataset_list = ['base']

#if not os.path.exists(savedir):
#    os.makedirs(savedir)

folder_list = [f for f in listdir(data_path) if isdir(join(data_path, f))]
folder_list.sort()
label_dict = dict(zip(folder_list, range(0, len(folder_list))))

classfile_list_all = []

for i, folder in enumerate(folder_list):
    folder_path = join(data_path, folder)
    classfile_list_all.append([join(folder_path, cf) for cf in listdir(folder_path)
                               if (isfile(join(folder_path, cf)) and cf[0] != '.')])
    random.shuffle(classfile_list_all[i])


for dataset in dataset_list:
    file_list = []
    label_list = []
    for i, classfile_list in enumerate(classfile_list_all):
        if 'base' in dataset:
            if (i % 2 == 0):
                file_list = file_list + classfile_list
                label_list = label_list + np.repeat(i, len(classfile_list)).tolist()
        if 'val' in dataset:
            if (i % 4 == 1):
                file_list = file_list + classfile_list
                label_list = label_list + np.repeat(i, len(classfile_list)).tolist()
        if 'novel' in dataset:
            if (i % 4 == 3):
                file_list = file_list + classfile_list
                label_list = label_list + np.repeat(i, len(classfile_list)).tolist()

    fo = open(savedir + dataset + "_100_10.json", "w")
    fo.write('{"label_names": [')
    fo.writelines(['"%s",' % item for item in folder_list])
    fo.seek(0, os.SEEK_END)
    fo.seek(fo.tell() - 1, os.SEEK_SET)
    fo.write('],')

    fo.write('"image_names": [')
    fo.writelines(['"%s",' % item for item in file_list])
    fo.seek(0, os.SEEK_END)
    fo.seek(fo.tell() - 1, os.SEEK_SET)
    fo.write('],')

    fo.write('"image_labels": [')
    fo.writelines(['%d,' % item for item in label_list])
    fo.seek(0, os.SEEK_END)
    fo.seek(fo.tell() - 1, os.SEEK_SET)
    fo.write(']}')

    fo.close()
    print("%s -OK" % dataset)
--------------------------------------------------------------------------------
/filelists/CUB/write_CUB_filelist.py:
--------------------------------------------------------------------------------
import numpy as np
from os import listdir
from os.path import isfile, isdir, join
import os
import json
import random

cwd = os.getcwd()
data_path = join(cwd, 'CUB_200_2011/images')
savedir = './'
dataset_list = ['novel', 'base', 'val']

#if not os.path.exists(savedir):
#    os.makedirs(savedir)

folder_list = [f for f in listdir(data_path) if isdir(join(data_path, f))]
folder_list.sort()
label_dict = dict(zip(folder_list, range(0, len(folder_list))))

classfile_list_all = []

for i, folder in enumerate(folder_list):
    folder_path = join(data_path, folder)
    classfile_list_all.append([join(folder_path, cf) for cf in listdir(folder_path)
                               if (isfile(join(folder_path, cf)) and cf[0] != '.')])
    random.shuffle(classfile_list_all[i])


for dataset in dataset_list:
    file_list = []
    label_list = []
    for i, classfile_list in enumerate(classfile_list_all):
        if 'all' in dataset:
            file_list = file_list + classfile_list
            label_list = label_list + np.repeat(i, len(classfile_list)).tolist()

        if 'base' in dataset:
            if i % 2 == 0:
                file_list = file_list + classfile_list
                label_list = label_list + np.repeat(i, len(classfile_list)).tolist()
        if 'val' in dataset:
            if i % 4 == 1 and i > 120:
                file_list = file_list + classfile_list
                label_list = label_list + np.repeat(i, len(classfile_list)).tolist()
        if 'novel' in dataset:
            if (i > 60 and i % 4 == 3) or (i > 60 and i % 4 == 1 and i < 118):
                file_list = file_list + classfile_list
                label_list = label_list + np.repeat(i, len(classfile_list)).tolist()

    fo = open(savedir + dataset + ".json", "w")
    fo.write('{"label_names": [')
    fo.writelines(['"%s",' % item for item in folder_list])
    fo.seek(0, os.SEEK_END)
    fo.seek(fo.tell() - 1, os.SEEK_SET)
    fo.write('],')

    fo.write('"image_names": [')
    fo.writelines(['"%s",' % item for item in file_list])
    fo.seek(0, os.SEEK_END)
    fo.seek(fo.tell() - 1, os.SEEK_SET)
    fo.write('],')

    fo.write('"image_labels": [')
    fo.writelines(['%d,' % item for item in label_list])
    fo.seek(0, os.SEEK_END)
    fo.seek(fo.tell() - 1, os.SEEK_SET)
    fo.write(']}')

    fo.close()
    print("%s -OK" % dataset)
--------------------------------------------------------------------------------
/filelists/DOG/write_DOG_filelist.py:
--------------------------------------------------------------------------------
import numpy as np
from os import listdir
from os.path import isfile, isdir, join
import os
import json
import random

cwd = os.getcwd()
data_path = join(cwd, 'stanforddogs/Images')
savedir = './'
dataset_list = ['base', 'val', 'novel']

#if not os.path.exists(savedir):
#    os.makedirs(savedir)

folder_list = [f for f in listdir(data_path) if isdir(join(data_path, f))]
folder_list.sort()
label_dict = dict(zip(folder_list, range(0, len(folder_list))))

classfile_list_all = []

for i, folder in enumerate(folder_list):
    folder_path = join(data_path, folder)
    classfile_list_all.append([join(folder_path, cf) for cf in listdir(folder_path)
                               if (isfile(join(folder_path, cf)) and cf[0] != '.')])
    random.shuffle(classfile_list_all[i])

longtail_classnum = 0

for dataset in dataset_list:
    file_list = []
    label_list = []
    for i, classfile_list in enumerate(classfile_list_all):

        if 'base' in dataset:
            if i % 2 == 0 or i < 20:
                #if i < 70 == 0:
                #if longtail_classnum < 60:
                #    classfile_list = classfile_list[:10]
                #longtail_classnum += 1
                file_list = file_list + classfile_list
                label_list = label_list + np.repeat(i, len(classfile_list)).tolist()
        if 'val' in dataset:
            if (i % 4 == 1 and i > 20 and i < 98):
                #if (i > 70 and i < 90):
                file_list = file_list + classfile_list
                label_list = label_list + np.repeat(i, len(classfile_list)).tolist()
        if 'novel' in dataset:
            if (i % 4 == 3 and i > 20 and i < 98) or (i > 98 and i % 2 == 1):
                #if (i > 90):
                file_list = file_list + classfile_list
                label_list = label_list + np.repeat(i, len(classfile_list)).tolist()

    fo = open(savedir + dataset + ".json", "w")
    fo.write('{"label_names": [')
    fo.writelines(['"%s",' % item for item in folder_list])
    fo.seek(0, os.SEEK_END)
    fo.seek(fo.tell() - 1, os.SEEK_SET)
    fo.write('],')

    fo.write('"image_names": [')
    fo.writelines(['"%s",' % item for item in file_list])
    fo.seek(0, os.SEEK_END)
    fo.seek(fo.tell() - 1, os.SEEK_SET)
    fo.write('],')

    fo.write('"image_labels": [')
    fo.writelines(['%d,' % item for item in label_list])
    fo.seek(0, os.SEEK_END)
    fo.seek(fo.tell() - 1, os.SEEK_SET)
    fo.write(']}')

    fo.close()
    print("%s -OK" % dataset)
--------------------------------------------------------------------------------
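All three writer scripts emit the same JSON layout: three parallel arrays, with the trailing-comma/seek dance trimming the last separator. A minimal sketch that inspects one generated filelist (the path assumes the writer was run from `filelists/CUB`, where `savedir = './'` places its output):

```python
# Sketch: inspect a generated filelist. Each file is a JSON object with
#   "label_names"  - one folder name per class index
#   "image_names"  - absolute image paths
#   "image_labels" - the class index of each image (parallel to image_names)
import json

with open('filelists/CUB/base.json') as f:
    meta = json.load(f)

assert len(meta['image_names']) == len(meta['image_labels'])
print(len(meta['label_names']), 'classes,', len(meta['image_names']), 'images')
print(meta['label_names'][meta['image_labels'][0]], '->', meta['image_names'][0])
```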
/data/dataset.py:
--------------------------------------------------------------------------------
# This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate

import torch
from PIL import Image
import json
import numpy as np
import torchvision.transforms as transforms
import os
import pdb

identity = lambda x: x

class SimpleDataset:
    def __init__(self, data_file, transform, target_transform=identity):
        with open(data_file, 'r') as f:
            self.meta = json.load(f)
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, i):
        image_path = os.path.join(self.meta['image_names'][i])
        img = Image.open(image_path).convert('RGB')
        img = self.transform(img)
        target = self.target_transform(self.meta['image_labels'][i])
        return img, target

    def __len__(self):
        return len(self.meta['image_names'])


class SetDataset:
    def __init__(self, data_file, batch_size, transform):
        with open(data_file, 'r') as f:
            self.meta = json.load(f)

        self.cl_list = np.unique(self.meta['image_labels']).tolist()

        self.sub_meta = {}
        for cl in self.cl_list:
            self.sub_meta[cl] = []

        for x, y in zip(self.meta['image_names'], self.meta['image_labels']):
            self.sub_meta[y].append(x)

        # pad classes that have fewer images than one episode needs
        for cl in self.cl_list:
            num = len(self.sub_meta[cl])
            if len(self.sub_meta[cl]) < batch_size:
                for i in range(batch_size - num):
                    self.sub_meta[cl].append(self.sub_meta[cl][i])

        self.sub_dataloader = []
        sub_data_loader_params = dict(batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=0,  # use main thread only or may receive multiple batches
                                      pin_memory=False)
        for cl in self.cl_list:
            sub_dataset = SubDataset(self.sub_meta[cl], cl, transform=transform)
            self.sub_dataloader.append(torch.utils.data.DataLoader(sub_dataset, **sub_data_loader_params))

    def __getitem__(self, i):
        return next(iter(self.sub_dataloader[i]))

    def __len__(self):
        return len(self.cl_list)

class SubDataset:
    def __init__(self, sub_meta, cl, transform=transforms.ToTensor(), target_transform=identity):
        self.sub_meta = sub_meta
        self.cl = cl
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, i):
        #print('%d -%d' % (self.cl, i))
        image_path = os.path.join(self.sub_meta[i])
        img = Image.open(image_path).convert('RGB')
        img = self.transform(img)
        target = self.target_transform(self.cl)
        return img, target

    def __len__(self):
        return len(self.sub_meta)

class EpisodicBatchSampler(object):
    def __init__(self, n_classes, n_way, n_episodes):
        self.n_classes = n_classes
        self.n_way = n_way
        self.n_episodes = n_episodes

    def __len__(self):
        return self.n_episodes

    def __iter__(self):
        for i in range(self.n_episodes):
            yield torch.randperm(self.n_classes)[:self.n_way]
--------------------------------------------------------------------------------
/data/feature_loader.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import h5py
import pdb

class SimpleHDF5Dataset:
    def __init__(self, file_handle=None):
        if file_handle == None:
            self.f = ''
            self.all_feats_dset = []
            self.all_labels = []
            self.total = 0
        else:
            self.f = file_handle
            self.all_feats_dset = self.f['all_feats'][...]
            self.all_labels = self.f['all_labels'][...]
            self.total = self.f['count'][0]
        # print('here')

    def __getitem__(self, i):
        return torch.Tensor(self.all_feats_dset[i, :]), int(self.all_labels[i])

    def __len__(self):
        return self.total

def init_img_loader(data_loader):

    img_data_file = {}

    for i, (x, y) in enumerate(data_loader):
        bs = x.size(0)
        for idx in range(bs):
            y_idx = y[idx].item()
            x_idx = np.array(x[idx])
            if y_idx not in img_data_file.keys():
                img_data_file[y_idx] = []
            img_data_file[y_idx].append(x_idx)

    return img_data_file

def init_loader(filename):
    with h5py.File(filename, 'r') as f:
        fileset = SimpleHDF5Dataset(f)

    #labels = [ l for l in fileset.all_labels if l != 0]
    feats = fileset.all_feats_dset
    labels = fileset.all_labels
    # drop zero-padded rows at the end of the feature array
    while np.sum(feats[-1]) == 0:
        feats = np.delete(feats, -1, axis=0)
        labels = np.delete(labels, -1, axis=0)

    class_list = np.unique(np.array(labels)).tolist()
    inds = range(len(labels))

    cl_data_file = {}
    for cl in class_list:
        cl_data_file[cl] = []
    for ind in inds:
        cl_data_file[labels[ind]].append(feats[ind])

    return cl_data_file


def get_classmap(dset):
    '''
    Creates a mapping between the serial number of a class in the provided
    dataset and the indices used for classification.
    Returns:
        2 dicts, 1 each for train and test classes
    '''
    class_names_file = '/home/jingyi/feature_generating_pytorch/datasets/%s/classes.txt' % dset

    with open(class_names_file) as fp:
        all_classes = fp.readlines()
    with open('/home/jingyi/feature_generating_pytorch/datasets/%s/testclasses.txt' % dset) as fp:
        test_class_names = [i.strip() for i in fp.readlines() if i != '']

    test_count = 0
    train_count = 0

    train_classmap = dict()
    test_classmap = dict()
    for line in all_classes:
        idx, name = [i.strip() for i in line.split(' ')]
        if name in test_class_names:
            test_classmap[int(idx)] = test_count
            test_count += 1
        else:
            train_classmap[int(idx)] = train_count
            train_count += 1
    return train_classmap, test_classmap

def load_feat(features, labels, train_classmap, test_classmap):
    cl_data_file = {}
    for feat, label in zip(features, labels):
        if label in test_classmap.keys():
            if label not in cl_data_file.keys():
                cl_data_file[label] = feat.reshape((1, -1))
            else:
                cl_data_file[label] = np.concatenate([cl_data_file[label], feat.reshape((1, -1))], 0)
    return cl_data_file
--------------------------------------------------------------------------------
/methods/baselinefinetune.py:
--------------------------------------------------------------------------------
import backbone
import torch
import torch.nn as nn
from torch.autograd import Variable
from sklearn.manifold import TSNE
import numpy as np
import torch.nn.functional as F
from methods.meta_template import MetaTemplate
import pdb

class BaselineFinetune(MetaTemplate):
    def __init__(self, model_func, n_way, n_support, loss_type="softmax"):
        super(BaselineFinetune, self).__init__(model_func, n_way, n_support)
        self.loss_type = loss_type

    def set_forward(self, x, is_feature=True, aug_per_sample=0):
        return self.set_forward_adaptation(x, is_feature, aug_per_sample=aug_per_sample)  # Baseline always does adaptation

    def set_forward_adaptation(self, x, is_feature=True, base_cl_data_file=None, aug_per_sample=0):
        assert is_feature == True, 'Baseline only supports testing with features'
        z_support, z_query = self.parse_feature(x, is_feature)
        if aug_per_sample > 0:
            z_support = self.aug_features(z_support, aug_per_sample=aug_per_sample)
        z_support_all = z_support.contiguous().view(self.n_way * (self.n_support + aug_per_sample), -1)
        z_query_all = z_query.contiguous().view(self.n_way * self.n_query, -1)
        y_support_all = torch.from_numpy(np.repeat(range(self.n_way), (self.n_support + aug_per_sample)))
        y_support_all = Variable(y_support_all.cuda())
        y_query_all = np.repeat(range(self.n_way), self.n_query)
        batch_size = 4

        if self.loss_type == 'softmax':
            linear_clf = nn.Linear(self.feat_dim, self.n_way)
        elif self.loss_type == 'dist':
            linear_clf = backbone.distLinear(self.feat_dim, self.n_way)
        linear_clf = linear_clf.cuda()

        set_optimizer = torch.optim.SGD(linear_clf.parameters(), lr=0.01, momentum=0.9, dampening=0.9, weight_decay=0.001)

        loss_function = nn.CrossEntropyLoss()
        loss_function = loss_function.cuda()

        #support_size = self.n_way * self.n_support
        support_size = z_support_all.shape[0]
        for epoch in range(100):
            rand_id = np.random.permutation(support_size)
            for i in range(0, support_size, batch_size):
                set_optimizer.zero_grad()
                selected_id = torch.from_numpy(rand_id[i: min(i + batch_size, support_size)]).cuda()
                z_batch = z_support_all[selected_id]
                y_batch = y_support_all[selected_id]
                scores = linear_clf(z_batch)
                loss = loss_function(scores, y_batch)
                loss.backward()
                set_optimizer.step()
            #scores = linear_clf(z_support_all)
            #pred = torch.argmax(scores, dim=1)
            #acc = np.mean(np.array((pred == y_support_all).cpu().data))
            #print('Epoch: %d, Acc: %f' % (epoch, acc))
        scores = linear_clf(z_query_all)
        #pdb.set_trace()
        #weight_embedded = TSNE(n_components=2).fit_transform(linear_clf.L.weight.cpu().data)
        #plot_tsne(z_embedded, np.array(y_support_all.cpu().data), weight=weight_embedded)
        return scores

    def aug_features(self, ori_features, aug_per_sample, n_way=5, n_shot=1, feature_dim=640):
        # augment each class's support features with Gaussian noise (sigma = 0.2)
        aug_features = torch.zeros((n_way, (aug_per_sample + n_shot), feature_dim))
        for cls in range(n_way):
            cls_feature = ori_features[cls, :, :]
            aug_cls_feature = cls_feature + torch.randn(aug_per_sample, feature_dim).cuda() * 0.2
            aug_cls_feature = torch.cat((aug_cls_feature, cls_feature))
            aug_features[cls, :, :] = aug_cls_feature
        return aug_features.cuda()

    def set_forward_loss(self, x):
        raise ValueError('Baseline predicts on pretrained features and does not support finetuning the backbone')
--------------------------------------------------------------------------------
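For completeness, a hedged sketch of consuming a saved feature file with `init_loader` (the HDF5 path is hypothetical; the file must contain the `all_feats` / `all_labels` / `count` datasets that `SimpleHDF5Dataset` reads, produced by a feature-saving step not shown in this section):

```python
# Sketch: load saved features grouped by class, then check shapes.
from data import feature_loader

cl_data_file = feature_loader.init_loader('features/CUB/novel.hdf5')  # hypothetical path
some_class = next(iter(cl_data_file))
print(len(cl_data_file), 'classes;', len(cl_data_file[some_class]),
      'feature vectors of dim', cl_data_file[some_class][0].shape[0])
```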
/data/datamgr.py:
--------------------------------------------------------------------------------
# This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate

import torch
from PIL import Image
import numpy as np
import pdb
import torchvision.transforms as transforms
import data.additional_transforms as add_transforms
from data.dataset import SimpleDataset, SetDataset, EpisodicBatchSampler
from abc import abstractmethod

class TransformLoader:
    # normalize_param is given a default here: SetDataManager constructs
    # TransformLoader(image_size) without passing it, which otherwise raises
    # a TypeError
    def __init__(self, image_size,
                 normalize_param=dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                 jitter_param=dict(Brightness=0.4, Contrast=0.4, Color=0.4)):
        self.image_size = image_size
        self.normalize_param = normalize_param
        self.jitter_param = jitter_param

    def parse_transform(self, transform_type):
        if transform_type == 'ImageJitter':
            method = add_transforms.ImageJitter(self.jitter_param)
            return method
        method = getattr(transforms, transform_type)
        if transform_type == 'RandomSizedCrop':
            return method(self.image_size)
        elif transform_type == 'CenterCrop':
            return method(self.image_size)
        elif transform_type == 'Scale':
            return method([int(self.image_size * 1.15), int(self.image_size * 1.15)])
        elif transform_type == 'Normalize':
            return method(**self.normalize_param)
        else:
            return method()

    def get_composed_transform(self, aug=False, finetune_aug=False):
        if aug:
            transform_list = ['RandomSizedCrop', 'ImageJitter', 'RandomHorizontalFlip', 'ToTensor', 'Normalize']
        else:
            if finetune_aug:
                transform_list = ['RandomSizedCrop', 'ImageJitter', 'ToTensor']
            else:
                transform_list = ['Scale', 'CenterCrop', 'ToTensor', 'Normalize']

        transform_funcs = [self.parse_transform(x) for x in transform_list]
        transform = transforms.Compose(transform_funcs)
        return transform

class DataManager:
    @abstractmethod
    def get_data_loader(self, data_file, aug):
        pass


class SimpleDataManager(DataManager):
    def __init__(self, image_size, batch_size, normalize_param=dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])):
        super(SimpleDataManager, self).__init__()
        self.batch_size = batch_size
        self.trans_loader = TransformLoader(image_size, normalize_param=normalize_param)

    def get_data_loader(self, data_file, aug, finetune_aug=False, shuffle=True):  # parameters that would change on train/val set
        transform = self.trans_loader.get_composed_transform(aug, finetune_aug)
        dataset = SimpleDataset(data_file, transform)
        #dataset = McDataset(data_file, transform, distributed=self.distributed)
        data_loader_params = dict(batch_size=self.batch_size, num_workers=12, pin_memory=True)
        data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params, shuffle=shuffle)

        return data_loader

class SetDataManager(DataManager):
    def __init__(self, image_size, n_way, n_support, n_query, n_eposide=100):
        super(SetDataManager, self).__init__()
        self.image_size = image_size
        self.n_way = n_way
        self.batch_size = n_support + n_query
        self.n_eposide = n_eposide

        self.trans_loader = TransformLoader(image_size)

    def get_data_loader(self, data_file, aug):  # parameters that would change on train/val set
        transform = self.trans_loader.get_composed_transform(aug)
        dataset = SetDataset(data_file, self.batch_size, transform)
        sampler = EpisodicBatchSampler(len(dataset), self.n_way, self.n_eposide)
        data_loader_params = dict(batch_sampler=sampler, num_workers=12, pin_memory=True)
        data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params)
        return data_loader
--------------------------------------------------------------------------------
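A minimal hedged usage sketch of the two managers (it assumes the CUB filelists generated earlier exist and that the recorded image paths are reachable; the sizes match the managers' defaults, not fixed requirements):

```python
# Sketch: the two kinds of loaders this repo builds.
from data.datamgr import SimpleDataManager, SetDataManager

# Flat batches for pretraining / feature extraction.
base_mgr = SimpleDataManager(image_size=84, batch_size=16)
base_loader = base_mgr.get_data_loader('filelists/CUB/base.json', aug=True)
x, y = next(iter(base_loader))     # x: (16, 3, 84, 84), y: (16,)

# Episodic batches: each item is one class's images, the sampler picks n_way classes.
episode_mgr = SetDataManager(image_size=84, n_way=5, n_support=1, n_query=15)
episode_loader = episode_mgr.get_data_loader('filelists/CUB/novel.json', aug=False)
x, y = next(iter(episode_loader))  # x: (5, 16, 3, 84, 84) = (n_way, n_support + n_query, C, H, W)
```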
/utils.py:
--------------------------------------------------------------------------------
import torch
import math
import numpy as np
from matplotlib import pyplot as plt
import torch.distributed as dist
import pdb

def matrix_log_density_gaussian(x, mu, logvar):
    """Calculates the log density of a Gaussian for all combinations of batch
    pairs of `x` and `mu`, i.e. returns a tensor of shape
    `(batch_size, batch_size, dim)` instead of the usual `(batch_size, dim)`.
    Parameters
    ----------
    x: torch.Tensor
        Value at which to compute the density. Shape: (batch_size, dim).
    mu: torch.Tensor
        Mean. Shape: (batch_size, dim).
    logvar: torch.Tensor
        Log variance. Shape: (batch_size, dim).
    """
    batch_size, dim = x.shape
    x = x.view(batch_size, 1, dim)
    mu = mu.view(1, batch_size, dim)
    logvar = logvar.view(1, batch_size, dim)
    return log_density_gaussian(x, mu, logvar)


def log_density_gaussian(x, mu, logvar):
    """Calculates the log density of a Gaussian.
    Parameters
    ----------
    x: torch.Tensor or np.ndarray or float
        Value at which to compute the density.
    mu: torch.Tensor or np.ndarray or float
        Mean.
    logvar: torch.Tensor or np.ndarray or float
        Log variance.
    """
    normalization = -0.5 * (math.log(2 * math.pi) + logvar)
    inv_var = torch.exp(-logvar)
    log_density = normalization - 0.5 * ((x - mu)**2 * inv_var)
    return log_density


def log_importance_weight_matrix(batch_size, dataset_size):
    """
    Calculates a log importance weight matrix.
    Parameters
    ----------
    batch_size: int
        number of training images in the batch
    dataset_size: int
        number of training images in the dataset
    """
    N = dataset_size
    M = batch_size - 1
    strat_weight = (N - M) / (N * M)
    W = torch.Tensor(batch_size, batch_size).fill_(1 / M)
    W.view(-1)[::M + 1] = 1 / N
    W.view(-1)[1::M + 1] = strat_weight
    W[M - 1, 0] = strat_weight
    return W.log()

def _get_log_pz_qz_prodzi_qzCx(latent_sample, latent_dist, n_data, is_mss=True):
    batch_size, hidden_dim = latent_sample.shape

    # calculate log q(z|x)
    log_q_zCx = log_density_gaussian(latent_sample, *latent_dist).sum(dim=1)

    # calculate log p(z)
    # mean and log var is 0
    zeros = torch.zeros_like(latent_sample)
    log_pz = log_density_gaussian(latent_sample, zeros, zeros).sum(1)

    mat_log_qz = matrix_log_density_gaussian(latent_sample, *latent_dist)

    if is_mss:
        # use stratification
        log_iw_mat = log_importance_weight_matrix(batch_size, n_data).to(latent_sample.device)
        mat_log_qz = mat_log_qz + log_iw_mat.view(batch_size, batch_size, 1)

    log_qz = torch.logsumexp(mat_log_qz.sum(2), dim=1, keepdim=False)
    log_prod_qzi = torch.logsumexp(mat_log_qz, dim=1, keepdim=False).sum(1)
    return log_pz, log_qz, log_prod_qzi, log_q_zCx


def l2_norm(input, axis=1):
    norm = torch.norm(input, 2, axis, True)
    output = torch.div(input, norm + 0.00001)
    return norm, output

def one_hot(y, num_class):
    return torch.zeros((len(y), num_class)).scatter_(1, y.unsqueeze(1), 1)

def DBindex(cl_data_file):
    class_list = cl_data_file.keys()
    cl_num = len(class_list)
    cl_means = []
    stds = []
    DBs = []
    for cl in class_list:
        cl_means.append(np.mean(cl_data_file[cl], axis=0))
        stds.append(np.sqrt(np.mean(np.sum(np.square(cl_data_file[cl] - cl_means[-1]), axis=1))))

    mu_i = np.tile(np.expand_dims(np.array(cl_means), axis=0), (len(class_list), 1, 1))
    mu_j = np.transpose(mu_i, (1, 0, 2))
    mdists = np.sqrt(np.sum(np.square(mu_i - mu_j), axis=2))

    for i in range(cl_num):
        DBs.append(np.max([(stds[i] + stds[j]) / mdists[i, j] for j in range(cl_num) if j != i]))
    return np.mean(DBs)

def sparsity(cl_data_file):
    class_list = cl_data_file.keys()
    cl_sparsity = []
    for cl in class_list:
        cl_sparsity.append(np.mean([np.sum(x != 0) for x in cl_data_file[cl]]))

    return np.mean(cl_sparsity)
--------------------------------------------------------------------------------
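A quick numeric sanity check of the density helpers above (pure arithmetic, no assumptions beyond the functions themselves): for a standard Gaussian, the log-density at the mean is -0.5·log(2π) ≈ -0.9189, and each row of the importance-weight matrix is (approximately) a distribution in log space.

```python
# Sanity check for log_density_gaussian and log_importance_weight_matrix.
import torch
from utils import log_density_gaussian, log_importance_weight_matrix

x = torch.zeros(1)
print(log_density_gaussian(x, torch.zeros(1), torch.zeros(1)))  # tensor([-0.9189])

# Exponentiating and summing each row gives ~1 when dataset_size >> batch_size.
W = log_importance_weight_matrix(batch_size=4, dataset_size=1000)
print(W.exp().sum(dim=1))  # all entries close to 1
```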
/methods/meta_template.py:
--------------------------------------------------------------------------------
import backbone
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import utils
from abc import abstractmethod
import pdb

class MetaTemplate(nn.Module):
    def __init__(self, model_func, n_way, n_support, change_way=True):
        super(MetaTemplate, self).__init__()
        self.n_way = n_way
        self.n_support = n_support
        self.n_query = -1  # (changes depending on input)
        self.feature = model_func(avg_pool=True)
        self.feat_dim = self.feature.final_feat_dim
        self.change_way = change_way  # some methods allow different-way classification during training and test

    @abstractmethod
    def set_forward(self, x, is_feature):
        pass

    @abstractmethod
    def set_forward_loss(self, x):
        pass

    def forward(self, x):
        out = self.feature.forward(x)
        return out

    def parse_feature(self, x, is_feature):
        x = Variable(x.cuda())
        if is_feature:
            z_all = x
        else:
            x = x.contiguous().view(self.n_way * (self.n_support + self.n_query), *x.size()[2:])
            z_all = self.feature.forward(x)
            z_all = z_all.view(self.n_way, self.n_support + self.n_query, -1)
        z_support = z_all[:, :self.n_support]
        z_query = z_all[:, self.n_support:]

        return z_support, z_query

    def correct(self, x):
        scores = self.set_forward(x)
        y_query = np.repeat(range(self.n_way), self.n_query)

        topk_scores, topk_labels = scores.data.topk(1, 1, True, True)
        topk_ind = topk_labels.cpu().numpy()
        top1_correct = np.sum(topk_ind[:, 0] == y_query)
        return float(top1_correct), len(y_query)

    def train_loop(self, epoch, train_loader, optimizer, tb_logger=None):
        print_freq = 10

        avg_loss = 0
        for i, (x, _) in enumerate(train_loader):
            self.n_query = x.size(1) - self.n_support
            if self.change_way:
                self.n_way = x.size(0)
            optimizer.zero_grad()
            loss = self.set_forward_loss(x)
            loss.backward()
            optimizer.step()
            avg_loss = avg_loss + loss.item()

            if i % print_freq == 0:
                #print(optimizer.state_dict()['param_groups'][0]['lr'])
                print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f}'.format(epoch, i, len(train_loader), avg_loss / float(i + 1)))

    def test_loop(self, test_loader, record=None):
        correct = 0
        count = 0
        acc_all = []

        iter_num = len(test_loader)
        for i, (x, _) in enumerate(test_loader):
            self.n_query = x.size(1) - self.n_support
            if self.change_way:
                self.n_way = x.size(0)
            correct_this, count_this = self.correct(x)
            acc_all.append(correct_this / count_this * 100)

        acc_all = np.asarray(acc_all)
        acc_mean = np.mean(acc_all)
        acc_std = np.std(acc_all)
        print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
        return acc_mean

    def set_forward_adaptation(self, x, is_feature=True):  # further adaptation; the default is fixing the feature and training a new softmax classifier
        assert is_feature == True, 'Feature is fixed in further adaptation'
        z_support, z_query = self.parse_feature(x, is_feature)

        z_support = z_support.contiguous().view(self.n_way * self.n_support, -1)
        z_query = z_query.contiguous().view(self.n_way * self.n_query, -1)

        y_support = torch.from_numpy(np.repeat(range(self.n_way), self.n_support))
        y_support = Variable(y_support.cuda())

        linear_clf = nn.Linear(self.feat_dim, self.n_way)
        linear_clf = linear_clf.cuda()

        set_optimizer = torch.optim.SGD(linear_clf.parameters(), lr=0.01, momentum=0.9, dampening=0.9, weight_decay=0.001)

        loss_function = nn.CrossEntropyLoss()
        loss_function = loss_function.cuda()

        batch_size = 4
        support_size = self.n_way * self.n_support
        for epoch in range(100):
            rand_id = np.random.permutation(support_size)
            for i in range(0, support_size, batch_size):
                set_optimizer.zero_grad()
                selected_id = torch.from_numpy(rand_id[i: min(i + batch_size, support_size)]).cuda()
                z_batch = z_support[selected_id]
                y_batch = y_support[selected_id]
                scores = linear_clf(z_batch)
                loss = loss_function(scores, y_batch)
                loss.backward()
                set_optimizer.step()

        scores = linear_clf(z_query)
        return scores
--------------------------------------------------------------------------------
/methods/baselinetrain.py:
--------------------------------------------------------------------------------
import backbone
import utils
import pdb

import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F

class BaselineTrain(nn.Module):
    # rank/world_size were referenced in the body but missing from the original
    # signature; they are added here with defaults (mirroring DisentangleNet)
    # so __init__ runs. res_model is unused but kept (with a default, so the
    # two-argument call in train_vae.py also works).
    def __init__(self, model_func, num_class, res_model=None, loss_type='softmax', feature_aug=False, ratio=0, radius=10, rank=0, world_size=0):
        super(BaselineTrain, self).__init__()
        self.feature = model_func(avg_pool=True)
        self.feature_aug = feature_aug
        self.ratio = ratio
        if loss_type == 'softmax':
            self.classifier = nn.Linear(self.feature.final_feat_dim, num_class)
            self.classifier.bias.data.fill_(0)
        elif loss_type == 'dist':  # Baseline++
            #self.classifier = backbone.distLinear(self.feature.final_feat_dim, num_class)
            self.classifier = backbone.distLinear(640, num_class)
        elif loss_type == 'norm':
            self.classifier = backbone.NormLinear(self.feature.final_feat_dim, num_class, radius=radius)
        self.loss_type = loss_type  # 'softmax' / 'dist' / 'norm'
        self.num_class = num_class
        self.loss_fn = nn.CrossEntropyLoss()
        self.rank = rank
        self.world_size = world_size
        self.DBval = False  # only set True for the CUB dataset, see issue #31

    def forward(self, x):
        x = Variable(x.cuda())
        out = self.feature.forward(x)
        return out

    def forward_loss(self, feature, y):
        scores = self.classifier.forward(feature)
        y = Variable(y.cuda())
        return self.loss_fn(scores, y)

    def train_loop(self, epoch, train_loader, optimizer, tb_logger):
        print_freq = 10
        ori_avg_loss = 0
        aug_avg_loss = 0

        ratio = self.ratio

        for i, (x, y) in enumerate(train_loader):
            ori_feature = self.forward(x)
            ori_loss = self.forward_loss(ori_feature, y)
            if self.feature_aug:
                aug_feature = ori_feature + torch.randn_like(ori_feature) * 0.5
                aug_loss = self.forward_loss(aug_feature, y)
                loss = ori_loss + aug_loss
            else:
                aug_loss = ori_loss
                loss = ori_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            ori_avg_loss = ori_avg_loss + ori_loss.item()
            aug_avg_loss = aug_avg_loss + aug_loss.item()

            bs = x.size(0)

            if i % print_freq == 0:
                #print(optimizer.state_dict()['param_groups'][0]['lr'])
                print('Epoch {:d} | Batch {:d}/{:d} | Ori Loss {:f} | Aug Loss {:f}'.format(epoch, i, len(train_loader), ori_avg_loss / float(i + 1), aug_avg_loss / float(i + 1)))
                curr_step = epoch * len(train_loader) + i
                tb_logger.add_scalar('Ori Loss', ori_avg_loss / float(i + 1), curr_step)

    def analysis_loop(self, val_loader, record=None):
        cls_class_file = {}
        #classifier = self.classifier.weight.data
        for i, (x, y) in enumerate(val_loader):
            x = x.cuda()
            x_var = Variable(x)
            feats = self.feature.forward(x_var)

            cls_feats = feats.data.cpu().numpy()
            labels = y.cpu().numpy()
            for f, l in zip(cls_feats, labels):
                if l not in cls_class_file.keys():
                    cls_class_file[l] = []
                cls_class_file[l].append(f)
        for cl in cls_class_file:
            cls_class_file[cl] = np.array(cls_class_file[cl])

        DB, intra_dist, inter_dist = DBindex(cls_class_file)
        #sum_dist = get_dist(classifier)
        print('DB index (cls) = %4.2f, intra_dist (cls) = %4.2f, inter_dist (cls) = %4.2f' % (DB, intra_dist, inter_dist))
        return 1 / DB  # DB index: the lower the better


def DBindex(cl_data_file):
    # For the definition of the Davies-Bouldin index (DBindex), see https://en.wikipedia.org/wiki/Davies%E2%80%93Bouldin_index
    # The DB index captures the intra-class variation of the data relative to inter-class separation.
    # As baseline/baseline++ do not train a few-shot classifier during training, this is an alternative metric to evaluate the validation set.
    # Empirically, this only works for the CUB dataset but not for the miniImagenet dataset.

    class_list = cl_data_file.keys()
    cl_num = len(class_list)
    cl_means = []
    stds = []
    DBs = []
    intra_dist = []
    inter_dist = []
    for cl in class_list:
        cl_means.append(np.mean(cl_data_file[cl], axis=0))
        stds.append(np.sqrt(np.mean(np.sum(np.square(cl_data_file[cl] - cl_means[-1]), axis=1))))

    mu_i = np.tile(np.expand_dims(np.array(cl_means), axis=0), (len(class_list), 1, 1))
    mu_j = np.transpose(mu_i, (1, 0, 2))
    mdists = np.sqrt(np.sum(np.square(mu_i - mu_j), axis=2))

    for i in range(cl_num):
        DBs.append(np.max([(stds[i] + stds[j]) / mdists[i, j] for j in range(cl_num) if j != i]))
        intra_dist.append(stds[i])
        inter_dist.append(np.mean([mdists[i, j] for j in range(cl_num) if j != i]))

    # return the mean of inter_dist (which excludes the zero diagonal) rather
    # than np.mean(mdists), which the original returned but which averages in
    # the i == j zeros and left inter_dist unused
    return np.mean(DBs), np.mean(intra_dist), np.mean(inter_dist)
--------------------------------------------------------------------------------
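Since the module-level `DBindex` doubles as the validation metric (via `analysis_loop`, which returns `1/DB`, so higher is better), a toy check on synthetic clusters illustrates its behavior; the data below is made up purely for illustration:

```python
# Toy check of DBindex: two tight, well-separated clusters give a small DB value.
import numpy as np
from methods.baselinetrain import DBindex

rng = np.random.RandomState(0)
cl_data_file = {
    0: rng.randn(50, 8) * 0.1,        # tight cluster at the origin
    1: rng.randn(50, 8) * 0.1 + 5.0,  # tight cluster far away
}
db, intra, inter = DBindex(cl_data_file)
print(db, intra, inter)  # db well below 1 for this geometry (~0.04)
```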
/io_utils.py:
--------------------------------------------------------------------------------
import numpy as np
import os
import glob
import argparse
import multiprocessing as mp
import torch
import torch.distributed as dist
import backbone

model_dict = dict(
    Conv4 = backbone.Conv4,
    ResNet12 = backbone.resnet12)

def parse_args(script):
    parser = argparse.ArgumentParser(description='few-shot script %s' % (script))
    parser.add_argument('--dataset', default='CUB', help='CUB/NAB/DOG')
    parser.add_argument('--model', default='Conv4', help='model: Conv4 / ResNet12')
    parser.add_argument('--method', default='baseline', help='baseline/baseline++')
    parser.add_argument('--train_n_way', default=5, type=int, help='class num to classify for training')  # baseline and baseline++ ignore this parameter
    parser.add_argument('--save_dir', help='directory to save the model')  # also required at test time so finetune_sample.py finds the model path correctly
    parser.add_argument('--save_iter', default=-1, type=int, help='load the feature saved from the model trained for x epochs; use the best model if x is -1')
    parser.add_argument('--assign_name', help='name of a specific checkpoint (without .tar) to load')
    parser.add_argument('--test_n_way', default=5, type=int, help='class num to classify for testing (validation)')  # baseline and baseline++ only use this parameter in finetuning
    parser.add_argument('--n_shot', default=5, type=int, help='number of labeled data in each class, same as n_support')  # baseline and baseline++ only use this parameter in finetuning
    parser.add_argument('--train_aug', action='store_true', help='perform data augmentation during training')  # also required at test time so the scripts find the model path correctly
    parser.add_argument('--feature_aug', action='store_true', help='perform feature-level augmentation (Gaussian noise) during training')
    parser.add_argument('--loss_type', default='norm', help='softmax/dist/norm')
    parser.add_argument('--num_classes', default=200, type=int, help='total number of classes in softmax, only used in baseline')  # make it larger than the maximum label value in the base class
    parser.add_argument('--bs', default=16, type=int, help='batch size')
    parser.add_argument('--kl_weight', default=1., type=float, help='weight of the KL divergence term')
    parser.add_argument('--split', default='base', help='base/val/novel')  # you can also test base/val class accuracy if you want

    if script == 'train':
        parser.add_argument('--save_freq', default=50, type=int, help='save frequency')
        parser.add_argument('--start_epoch', default=0, type=int, help='starting epoch')
        parser.add_argument('--stop_epoch', default=100, type=int, help='stopping epoch')  # the default epoch number is dataset dependent
        parser.add_argument('--lr_steps', nargs='+', type=int, help='learning rate decay steps')
        parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
        parser.add_argument('--aug_weight', default=0., type=float, help='weight of the feature-augmentation loss')
        parser.add_argument('--evaluate', default=None, help='checkpoint name to evaluate instead of training')
        parser.add_argument('--resume', action='store_true', help='continue from the previously trained model with the largest epoch')
        parser.add_argument('--resume_iter', default=0, type=int, help='epoch of the checkpoint to resume from; 0 picks the latest')
        parser.add_argument('--warmup', action='store_true', help='continue from baseline, neglected if resume is true')  # never used in the paper
        parser.add_argument('--warmup_file', help='name of the baseline checkpoint to warm up from')
    elif script == 'finetune_sample':
        parser.add_argument('--aug_per_sample', default=2, type=int, help='number of augmented features per support sample')
    else:
        raise ValueError('Unknown script')

    return parser.parse_args()


def get_assigned_file(checkpoint_dir, assign_name):
    assign_file = os.path.join(checkpoint_dir, '{:s}.tar'.format(assign_name))
    return assign_file

def get_resume_file(checkpoint_dir, resume_iter=0):
    filelist = glob.glob(os.path.join(checkpoint_dir, '*.tar'))

    if resume_iter > 0:
        resume_file = os.path.join(checkpoint_dir, '{:d}.tar'.format(resume_iter))
        return resume_file

    if len(filelist) == 0:
        return None

    filelist = [x for x in filelist if os.path.basename(x) != 'best_model.tar']
    epochs = np.array([int(os.path.splitext(os.path.basename(x))[0]) for x in filelist])
    max_epoch = np.max(epochs)
    resume_file = os.path.join(checkpoint_dir, '{:d}.tar'.format(max_epoch))
    return resume_file

def get_best_file(checkpoint_dir):
    best_file = os.path.join(checkpoint_dir, 'best_model.tar')
    if os.path.isfile(best_file):
        return best_file
    else:
        return get_resume_file(checkpoint_dir)
--------------------------------------------------------------------------------
/train_vae.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim
import torch.optim.lr_scheduler as lr_scheduler
from tensorboardX import SummaryWriter
import time
import os
import glob
import pdb

import configs
from data.datamgr import SimpleDataManager, SetDataManager
from methods.baselinetrain import BaselineTrain
from methods.baselinefinetune import BaselineFinetune
from methods.baselinevae import DisentangleNet

from io_utils import model_dict, parse_args, get_resume_file, get_assigned_file

def train(base_loader, val_loader, model, optimization, start_epoch, stop_epoch, params, tb_logger):
    # NOTE: the two-group optimizer below assumes the DisentangleNet layout
    # (model.backbone / model.classifier); plain BaselineTrain exposes
    # model.feature instead
    cls_params_ids = []
    params_ids1 = [id(p) for p in model.backbone.parameters()]
    cls_params_ids.extend(params_ids1)
    params_ids2 = [id(p) for p in model.classifier.parameters()]
    cls_params_ids.extend(params_ids2)

    cls_params = [p for p in model.parameters() if id(p) in cls_params_ids and p.requires_grad]
    g_params = [p for p in model.parameters() if id(p) not in cls_params_ids and p.requires_grad]
    # the backbone/classifier train with a larger lr than the generative branch
    optimizer = torch.optim.Adam([
        {'params': cls_params, 'lr': 0.001},
        {'params': g_params, 'lr': 0.0001}
    ])
    max_acc = 0
    for epoch in range(start_epoch, stop_epoch):
        model.train()
        if params.lr_steps is not None and epoch in params.lr_steps:
            for param_group in optimizer.param_groups:
                init_lr = param_group['lr']
                param_group['lr'] = init_lr * 0.1

        model.train_all(epoch, base_loader, optimizer, tb_logger, len(base_loader) * params.bs)

        model.eval()

        if not os.path.isdir(params.checkpoint_dir):
            os.makedirs(params.checkpoint_dir)

        acc = model.analysis_loop(val_loader)

        if acc > max_acc:  # for baseline and baseline++ we don't use validation by default and let acc = -1, but we allow the option to validate with the DB index
            print("best model! save...")
            max_acc = acc
            outfile = os.path.join(params.checkpoint_dir, 'best_model.tar')
            torch.save({'epoch': epoch, 'state': model.state_dict()}, outfile)

        if ((epoch % params.save_freq == 0) or (epoch == stop_epoch - 1)):
            outfile = os.path.join(params.checkpoint_dir, '{:d}.tar'.format(epoch))
            torch.save({'epoch': epoch, 'state': model.state_dict()}, outfile)

    return model

if __name__ == '__main__':

    np.random.seed(10)
    params = parse_args('train')

    base_file = configs.data_dir[params.dataset] + params.split + '.json'
    val_file = configs.data_dir[params.dataset] + 'val.json'

    if 'Conv' in params.model or 'ResNet12' in params.model:
        image_size = 84
    else:
        image_size = 224

    optimization = 'Adam'

    if params.method in ['baseline', 'baseline++']:
        base_datamgr = SimpleDataManager(image_size, batch_size=params.bs)
        base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug)
        val_datamgr = SimpleDataManager(image_size, batch_size=params.bs)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)

        if params.method == 'baseline':
            model = BaselineTrain(model_dict[params.model], params.num_classes)
        else:
            model = DisentangleNet(model_dict[params.model], params.num_classes, kl_weight=params.kl_weight, aug_weight=params.aug_weight, loss_type=params.loss_type)

    model = model.cuda()

    params.checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (params.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        params.checkpoint_dir += '_aug'
    if not params.method in ['baseline', 'baseline++']:
        params.checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)

    params.checkpoint_dir += '_' + params.split
    params.checkpoint_dir += '_%.2f' % (params.kl_weight)

    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)

    tb_logger = SummaryWriter('%s/events/%s' % (params.save_dir, time.time()))

    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch
    if params.method == 'maml' or params.method == 'maml_approx':
        stop_epoch = params.stop_epoch * model.n_task  # maml uses multiple tasks in one update

    if params.resume:
        resume_file = get_resume_file(params.checkpoint_dir, params.resume_iter)
        if resume_file is not None:
            tmp = torch.load(resume_file)
            start_epoch = tmp['epoch'] + 1
            state = tmp['state']
            state_keys = list(state.keys())
            model.load_state_dict(state, strict=True)
            keys1 = set([k for k, _ in model.named_parameters()])
            keys2 = set(tmp['state'].keys())
            not_loaded = keys2 - keys1
            for k in not_loaded:
                print('caution: {} not loaded'.format(k))

    elif params.warmup:  # we also support warmup from a pretrained baseline feature, but it was never used in our paper
        warmup_resume_file = get_assigned_file(params.checkpoint_dir, str(params.warmup_file))
        tmp = torch.load(warmup_resume_file)
        if tmp is not None:
            state = tmp['state']
            state_keys = list(state.keys())
            for i, key in enumerate(state_keys):
                if "feature." in key:
                    newkey = key.replace("feature.", "")  # the pretrained model stores the encoder under 'feature'; cast 'feature.trunk.xx' to 'trunk.xx' so it loads into the backbone
                    state[newkey] = state.pop(key)
                else:
                    state.pop(key)
            model.backbone.load_state_dict(state)
        else:
            raise ValueError('No warm_up file')

    if params.evaluate:
        eval_file = os.path.join(params.checkpoint_dir, params.evaluate)
        if eval_file is not None:
            tmp = torch.load(eval_file)
            model.load_state_dict(tmp['state'])
            model.eval()
            acc = model.get_cov(val_loader)
    else:
        print(params)
        model = train(base_loader, val_loader, model, optimization, start_epoch, stop_epoch, params, tb_logger)
--------------------------------------------------------------------------------
torch.randn_like(cls_feature) * 0.2 46 | aggr_feature = cls_feature + cls_invar_feature 47 | #recon_feature = model.forward_d(aggr_feature) 48 | #recon_loss = model.superloss(recon_feature, cls_feature_map) 49 | #cls_aug_feature = model.cls_fc(recon_feature) 50 | cls_aug_feature = aggr_feature 51 | aug_features[cls, :, :] = cls_aug_feature 52 | aug_features = aug_features.view(n_way*aug_per_sample*n_shot, -1) 53 | return aug_features.detach().cuda(), aug_y 54 | 55 | 56 | 57 | 58 | def finetune_backbone(model, img_data_file, n_way, n_support, feat_dim, aug=False, n_query=15): 59 | 60 | class_list = img_data_file.keys() 61 | select_class = random.sample(class_list, n_way) 62 | img_all = [] 63 | 64 | for cl in select_class: 65 | img_data = img_data_file[cl] 66 | img_data = np.array(img_data) 67 | perm_ids = np.random.permutation(len(img_data)).tolist() 68 | perm_ids = perm_ids + perm_ids 69 | img_all.append( [ np.squeeze( img_data[perm_ids[i]]) for i in range(n_support+n_query) ] ) # stack each batch 70 | 71 | # samples images for support set and query set 72 | img_all = torch.from_numpy(np.array(img_all)) 73 | img_all = Variable(img_all).cuda() 74 | [c, h, w] = img_all[0][0].shape 75 | 76 | x_support = img_all[:, :n_support,:,:,:] 77 | x_query = img_all[:,n_support:,:,:,:] 78 | 79 | x_support = x_support.contiguous().view(n_way* n_support, c, h, w) 80 | x_query = x_query.contiguous().view(n_way* n_query, c, h, w) 81 | 82 | y_support = torch.from_numpy(np.repeat(range( n_way ), n_support )) 83 | y_support = Variable(y_support).cuda() 84 | 85 | z_support, mu, logvar, feature_map = model(x_support) 86 | z_query, _, _, _ = model(x_query) 87 | #z_support = model(x_support).detach().cuda() 88 | z_support = z_support.view(n_way*n_support, -1).detach().cuda() 89 | z_query = z_query.view(n_way*n_query, -1).detach().cuda() 90 | #z_mean = torch.mean(z_support, dim=1) 91 | #_, z_mean = l2_norm(z_mean) 92 | 93 | #z_query = model(x_query).detach().cuda() 94 | #z_query = z_query.view(n_way, n_query, -1) 95 | 96 | #aug_z = recon_z.view(n_way*n_support, -1).detach().cuda() 97 | #aug_y = y_support.clone() 98 | #z_support_all = z_support.view(n_way*n_support, -1) 99 | feat_dim = 640 100 | y_query = np.repeat(range( n_way ), n_query ) 101 | aug_z, aug_y = aug_features(z_support, mu, logvar, feature_map, model, n_shot=n_support, aug_per_sample=params.aug_per_sample, feat_dim=feat_dim) 102 | #aug_z, aug_y = trans_features(aug_z, aug_y, z_query, z_support) 103 | #cls_invar = model.reparameterize(mu, logvar) 104 | #aug_z, aug_y = aug_beta_features(z_support, cls_invar, feature_map, model, n_shot=n_support) 105 | if aug_z is not None: 106 | z_support_all = torch.cat((z_support, aug_z)) 107 | y_support_all = torch.cat((y_support, aug_y)) 108 | else: 109 | z_support_all = z_support 110 | y_support_all = y_support 111 | #z_support_all = z_support 112 | #y_support_all = y_support 113 | # train classifier with augmened features 114 | linear_clf = backbone.distLinear(feat_dim, n_way).cuda() 115 | ## initialize weights for linear_clf 116 | set_optimizer = torch.optim.SGD(linear_clf.parameters(), lr = 0.01, momentum=0.9, dampening=0.9, weight_decay=0.001) 117 | 118 | loss_function = torch.nn.CrossEntropyLoss() 119 | loss_function = loss_function.cuda() 120 | # 121 | batch_size = 4 122 | support_size = z_support_all.shape[0] 123 | for epoch in range(100): 124 | rand_id = np.random.permutation(support_size) 125 | for i in range(0, support_size , batch_size): 126 | set_optimizer.zero_grad() 127 | selected_id = 
torch.from_numpy( rand_id[i: min(i+batch_size, support_size) ]).cuda() 128 | z_batch = z_support_all[selected_id] 129 | y_batch = y_support_all[selected_id] 130 | scores = linear_clf(z_batch) 131 | loss = loss_function(scores,y_batch) 132 | loss.backward() 133 | set_optimizer.step() 134 | # 135 | model.eval() 136 | scores = linear_clf(z_query) 137 | pred = torch.argmax(scores, 1) 138 | #acc = np.mean(np.array(pred.cpu().data) == y_query) 139 | scores = linear_clf(z_query) 140 | pred = torch.argmax(scores, 1) 141 | #pdb.set_trace() 142 | acc = np.mean(np.array(pred.cpu().data) == y_query)*100 143 | return acc 144 | 145 | 146 | if __name__ == '__main__': 147 | params = parse_args('finetune_sample') 148 | 149 | image_size = 84 150 | 151 | 152 | split = "novel" 153 | loadfile = configs.data_dir[params.dataset] + split + '.json' 154 | 155 | checkpoint_dir = '%s/checkpoints/%s/%s_%s' %(params.save_dir, params.dataset, params.model, params.method) 156 | if params.train_aug: 157 | checkpoint_dir += '_aug' 158 | checkpoint_dir += "_" + params.split 159 | checkpoint_dir += '_%.2f'%(params.kl_weight) 160 | if params.assign_name is not None: 161 | modelfile = get_assigned_file(checkpoint_dir,params.assign_name) 162 | # elif params.method in ['baseline', 'baseline++'] : 163 | # modelfile = get_resume_file(checkpoint_dir) #comment in 2019/08/03 updates as the validation of baseline/baseline++ is added 164 | else: 165 | modelfile = get_best_file(checkpoint_dir) 166 | if params.save_iter != -1: 167 | novel_file = os.path.join( checkpoint_dir.replace("checkpoints","features"), split + "_" + str(params.save_iter)+ ".hdf5") 168 | else: 169 | novel_file = os.path.join( checkpoint_dir.replace("checkpoints","features"), split + ".hdf5") 170 | 171 | datamgr = SimpleDataManager(image_size, batch_size = 64) 172 | data_loader = datamgr.get_data_loader(loadfile, aug = False, shuffle=False) 173 | img_data_file = feature_loader.init_img_loader(data_loader) 174 | 175 | #base_file = os.path.join( checkpoint_dir.replace("checkpoints","features"), "base.hdf5") #defaut split = novel, but you can also test base or val classes 176 | #base_data_file = feature_loader.init_loader(base_file) 177 | 178 | model = DisentangleNet( model_dict[params.model], params.num_classes, kl_weight=params.kl_weight, loss_type = params.loss_type) 179 | 180 | tmp = torch.load(modelfile) 181 | state = tmp['state'] 182 | state_keys = list(state.keys()) 183 | for i, key in enumerate(state_keys): 184 | if "module." 
in key: 185 | newkey = key.replace("module.","") # strip the 'module.' prefix that nn.DataParallel adds, so the checkpoint keys match the single-GPU model 186 | state[newkey] = state.pop(key) 187 | # else: 188 | # state.pop(key) 189 | 190 | model = model.cuda() 191 | 192 | dirname = os.path.dirname(novel_file) 193 | if not os.path.isdir(dirname): 194 | os.makedirs(dirname) 195 | 196 | iter_num = 600 197 | acc_all = [] 198 | model.load_state_dict(state, strict=False) 199 | #visualize_intra_cls_var(model) 200 | for i in range(iter_num): 201 | acc = finetune_backbone(model, img_data_file, params.train_n_way, params.n_shot, 512) # the 512 passed here is ignored; finetune_backbone overrides feat_dim to 640 202 | acc_all.append(acc) 203 | if i%10 == 0: 204 | print('Iter: %d, Acc : %f, Avg Acc: %f'% (i, acc, np.mean(np.array(acc_all)))) 205 | acc_all = np.asarray(acc_all) 206 | acc_mean = np.mean(acc_all) 207 | acc_std = np.std(acc_all) 208 | 209 | print('%d Test Acc = %4.2f%% +- %4.2f%%' %(iter_num, acc_mean, 1.96* acc_std/np.sqrt(iter_num))) 210 | 211 | 212 | with open('./record/results.txt' , 'a') as f: 213 | timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime()) 214 | aug_str = '-aug' if params.train_aug else '' 215 | if False : 216 | exp_setting = '%s-%s-%s-%s %sshot %sway_test' %(params.dataset, params.model, params.method, aug_str, params.n_shot, params.test_n_way ) 217 | else: 218 | exp_setting = '%s-%s-%s%s %sshot %sway_train %sway_test aug%s' %(params.dataset, params.model, params.method, aug_str , params.n_shot , params.train_n_way, params.test_n_way, str(params.aug_per_sample) ) 219 | acc_str = '%d Test Acc = %4.2f%% +- %4.2f%%' %(iter_num, acc_mean, 1.96* acc_std/np.sqrt(iter_num)) 220 | f.write( 'Time: %s, Setting: %s, Acc: %s \n' %(timestamp,exp_setting,acc_str) ) 221 | -------------------------------------------------------------------------------- /methods/baselinevae.py: -------------------------------------------------------------------------------- 1 | import backbone 2 | import utils 3 | import pdb 4 | import os 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.autograd import Variable 9 | from torchvision.utils import save_image 10 | import numpy as np 11 | import torch.nn.functional 12 | from utils import l2_norm, _get_log_pz_qz_prodzi_qzCx 13 | import torch.distributed as dist 14 | 15 | class Conv_block(nn.Module): 16 | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): 17 | super(Conv_block, self).__init__() 18 | self.conv = nn.Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups, stride=stride, padding=padding, bias=False) 19 | self.bn = nn.BatchNorm2d(out_c) 20 | self.prelu = nn.PReLU(out_c) 21 | def forward(self, x): 22 | x = self.conv(x) 23 | x = self.bn(x) 24 | x = self.prelu(x) 25 | return x 26 | 27 | 28 | 29 | 30 | class DisentangleNet(nn.Module): 31 | def __init__(self, model_func, num_class, kl_weight=1, loss_type = 'softmax', aug_weight=0, use_conv=False, rank=0, world_size=0, avg=True): 32 | super(DisentangleNet, self).__init__() 33 | self.backbone = model_func(flatten=False) 34 | self.aug_weight = aug_weight 35 | self.DBval = True 36 | self.rank = rank 37 | self.world_size = world_size 38 | self.use_conv = use_conv 39 | if not use_conv: 40 | channel = 640 41 | pool_size = 5 42 | feature_dim = 640 43 | self.cls_fc = nn.Sequential( 44 | nn.AvgPool2d(pool_size), 45 | backbone.Flatten() 46 | ) 47 | else: 48 | channel = 64 49 | pool_size = 5 50 | feature_dim = 1600 51 | self.cls_fc = nn.Sequential( 52 | backbone.Flatten() 53 | ) 54 | 55 | 
cls_feature_dim = feature_dim 56 | self.encoder = nn.Sequential( 57 | Conv_block(channel, channel, kernel=(3,3), padding=(1,1)), 58 | Conv_block(channel, channel, kernel=(3,3), padding=(1,1)), 59 | Conv_block(channel, channel, kernel=(3,3), padding=(1,1)), 60 | backbone.Flatten(), 61 | ) 62 | self.vae_mean = nn.Linear(channel*pool_size*pool_size, feature_dim) 63 | self.vae_var = nn.Linear(channel*pool_size*pool_size, feature_dim) # predicts the log-variance of q(z|x) 64 | self.decoder_fc = nn.Linear(feature_dim, channel*pool_size*pool_size) 65 | self.decoder = nn.Sequential( 66 | Conv_block(channel, channel, kernel=(3,3), padding=(1,1)), 67 | Conv_block(channel, channel, kernel=(3,3), padding=(1,1)), 68 | Conv_block(channel, channel, kernel=(3,3), padding=(1,1)), 69 | ) 70 | self.superloss = torch.nn.MSELoss() 71 | if loss_type == 'softmax': 72 | self.classifier = nn.Linear(cls_feature_dim, num_class) 73 | self.classifier.bias.data.fill_(0) 74 | elif loss_type == 'dist': #Baseline ++ 75 | self.classifier = backbone.distLinear(cls_feature_dim, num_class) 76 | else: 77 | self.classifier = backbone.NormLinear(cls_feature_dim, num_class, radius=10) 78 | self.smloss = torch.nn.CrossEntropyLoss() 79 | self.kl_weight = kl_weight 80 | 81 | 82 | def reparameterize(self, mu, logvar): 83 | std = torch.exp(0.5*logvar) 84 | eps = torch.randn_like(std) 85 | # reparameterization trick: z = mu + sigma * eps 86 | return mu + eps*std 87 | 88 | def forward_d(self, aggr_feature): 89 | if not self.use_conv: 90 | channel = 640 91 | pool_size = 5 92 | else: 93 | channel = 64 94 | pool_size = 5 95 | recon_feature = self.decoder_fc(aggr_feature).view(-1, channel, pool_size, pool_size) 96 | recon_feature = self.decoder(recon_feature) 97 | return recon_feature 98 | 99 | def forward(self, x): 100 | feature_map = self.backbone(x) 101 | cls_feature = self.cls_fc(feature_map) 102 | #if self.use_wide: 103 | # feature_map = self.downsample(feature_map) 104 | bs = x.size(0) 105 | if not self.use_conv: 106 | channel = 640 107 | else: 108 | channel = 64 109 | encoder_map = feature_map + cls_feature.view(bs, channel, 1, 1).detach() # condition the encoder on the (detached) class feature 110 | #encoder_map = feature_map 111 | encoder_feature = self.encoder(encoder_map) 112 | mu = self.vae_mean(encoder_feature) 113 | logvar = self.vae_var(encoder_feature) 114 | return cls_feature, mu, logvar, feature_map 115 | #return cls_feature 116 | 117 | 118 | def train_all(self, epoch, train_loader, optimizer, tb_logger, n_data=None): 119 | print_freq = 10 120 | cls_avg_loss = 0 121 | recon_avg_loss = 0 122 | kl_avg_loss = 0 123 | aug_avg_loss = 0 124 | beta = self.kl_weight 125 | 126 | for i, (x, y) in enumerate(train_loader): 127 | x = Variable(x.cuda()) 128 | bs = x.size(0) 129 | cls_feature, mu, logvar, feature_map = self.forward(x) 130 | scores = self.classifier.forward(cls_feature) 131 | y = Variable(y.cuda()) 132 | 133 | # classification loss 134 | cls_loss = self.smloss(scores, y) 135 | cls_invar_feature = self.reparameterize(mu, logvar) 136 | aggr_feature = cls_feature + cls_invar_feature.detach() 137 | 138 | # reconstruction loss 139 | recon_feature = self.forward_d(aggr_feature) 140 | recon_loss = self.superloss(recon_feature, feature_map.detach()) 141 | #recon_loss = self.superloss(recon_feature, x) 142 | # feature aug loss 143 | if self.aug_weight > 0: 144 | aug_cls_feature = cls_feature + cls_invar_feature 145 | aug_scores = self.classifier.forward(aug_cls_feature) 146 | aug_cls_loss = self.smloss(aug_scores, y) 147 | else: 148 | aug_cls_loss = torch.zeros(1).cuda() 149 | 150 | # kl_loss 151 | #kl_loss = 
-0.5*torch.sum(1+logvar-logvar.exp()-mu.pow(2)) / bs 152 | log_pz, log_qz, log_prod_qzi, log_q_zCx = _get_log_pz_qz_prodzi_qzCx(cls_invar_feature, (mu, logvar), n_data, is_mss=False) # defined in utils.py (not shown in this dump); a hedged sketch appears at the end of the dump 153 | #I[z;x] = KL[q(z,x)||q(x)q(z)] = E_x[KL[q(z|x)||q(z)]] 154 | mi_loss = (log_q_zCx - log_qz).mean() 155 | # TC[z] = KL[q(z)||\prod_i q(z_i)] 156 | tc_loss = (log_qz - log_prod_qzi).mean() 157 | # dw_kl_loss is the dimension-wise KL, sum_i KL[q(z_i)||p(z_i)], not the usual KL[q(z|x)||p(z)] 158 | dw_kl_loss = (log_prod_qzi - log_pz).mean() 159 | #loss = cls_loss + recon_loss + kl_loss * self.kl_weight + aug_cls_loss * self.aug_weight 160 | loss = cls_loss + recon_loss + mi_loss + dw_kl_loss + tc_loss * beta + aug_cls_loss * self.aug_weight 161 | #loss = cls_loss + recon_loss + kl_loss * self.kl_weight 162 | kl_loss = mi_loss + dw_kl_loss + tc_loss 163 | 164 | optimizer.zero_grad() 165 | loss.backward() 166 | optimizer.step() 167 | 168 | pred = torch.argmax(scores, dim=1) 169 | acc = torch.mean((pred == y).float()) 170 | 171 | cls_avg_loss = cls_avg_loss+cls_loss.item() 172 | recon_avg_loss = recon_avg_loss+recon_loss.item() 173 | kl_avg_loss = kl_avg_loss+kl_loss.item() 174 | aug_avg_loss = aug_avg_loss+aug_cls_loss.item() 175 | 176 | if i%print_freq==0 and self.rank == 0: 177 | print('Epoch {:d} | Batch {:d}/{:d} | Cl Loss {:f} | Recon Loss {:f} | Kl Loss {:f} | Aug Loss {:f}'.format(epoch, i, len(train_loader), cls_avg_loss/float(i+1), recon_avg_loss/float(i+1), kl_avg_loss/float(i+1), aug_avg_loss/float(i+1))) 178 | curr_step = epoch*len(train_loader) + i 179 | tb_logger.add_scalar('Cl Loss', cls_avg_loss/float(i+1), curr_step) 180 | tb_logger.add_scalar('Recon Loss', recon_avg_loss/float(i+1), curr_step) 181 | tb_logger.add_scalar('KL Loss', kl_avg_loss/float(i+1), curr_step) 182 | tb_logger.add_scalar('Aug Loss', aug_avg_loss/float(i+1), curr_step) 183 | 184 | 185 | 186 | 187 | def analysis_loop(self, val_loader, record = None): 188 | cls_class_file = {} 189 | #classifier = self.classifier.weight.data 190 | for i, (x,y) in enumerate(val_loader): 191 | x = x.cuda() 192 | x_var = Variable(x) 193 | feats = self.backbone.forward(x_var) 194 | feats = self.cls_fc(feats) 195 | 196 | cls_feats = feats.data.cpu().numpy() 197 | labels = y.cpu().numpy() 198 | for f, l in zip(cls_feats, labels): 199 | if l not in cls_class_file.keys(): 200 | cls_class_file[l] = [] 201 | cls_class_file[l].append(f) 202 | for cl in cls_class_file: 203 | cls_class_file[cl] = np.array(cls_class_file[cl]) 204 | 205 | DB, intra_dist, inter_dist = DBindex(cls_class_file) 206 | #sum_dist = get_dist(classifier) 207 | print('DB index (cls) = %4.2f, intra_dist (cls) = %4.2f, inter_dist (cls) = %4.2f' %(DB, intra_dist, inter_dist)) 208 | return 1/DB #DB index: the lower the better 209 | 210 | 211 | 212 | 213 | def DBindex(cls_data_file): 214 | #For the definition of the Davies-Bouldin index (DB index), see https://en.wikipedia.org/wiki/Davies%E2%80%93Bouldin_index 215 | #The DB index measures intra-class variation relative to inter-class separation 216 | #As baseline/baseline++ do not train a few-shot classifier during training, this is an alternative metric for evaluating the validation set 217 | #Empirically, this works for the CUB dataset but not for the miniImagenet dataset 218 | 219 | class_list = cls_data_file.keys() 220 | cls_num= len(class_list) 221 | cls_means = [] 222 | stds = [] 223 | DBs = [] 224 | intra_dist = [] 225 | inter_dist = [] 226 | for cl in class_list: 227 | cls_means.append( np.mean(cls_data_file[cl], axis = 0) ) 228 | stds.append( np.sqrt(np.mean( np.sum(np.square( cls_data_file[cl] - 
cls_means[-1]), axis = 1)))) 229 | 230 | mu_i = np.tile( np.expand_dims( np.array(cls_means), axis = 0), (len(class_list),1,1) ) 231 | mu_j = np.transpose(mu_i,(1,0,2)) 232 | mdists = np.sqrt(np.sum(np.square(mu_i - mu_j), axis = 2)) 233 | 234 | for i in range(cls_num): 235 | DBs.append( np.max([ (stds[i]+ stds[j])/mdists[i,j] for j in range(cls_num) if j != i ]) ) 236 | intra_dist.append(stds[i]) 237 | inter_dist.append(np.mean([mdists[i,j] for j in range(cls_num) if j != i])) 238 | 239 | return np.mean(DBs), np.mean(intra_dist), np.mean(inter_dist) # the inter-class mean excludes the zero diagonal of mdists 240 | -------------------------------------------------------------------------------- /backbone.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | import pdb 5 | import math 6 | from utils import l2_norm 7 | from torch.nn.utils.weight_norm import WeightNorm 8 | from torch.distributions import Bernoulli 9 | # This ResNet network was designed following the practice of the following papers: 10 | # TADAM: Task dependent adaptive metric for improved few-shot learning (Oreshkin et al., in NIPS 2018) and 11 | # A Simple Neural Attentive Meta-Learner (Mishra et al., in ICLR 2018). 12 | 13 | def conv3x3(in_planes, out_planes, stride=1): 14 | """3x3 convolution with padding""" 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 16 | padding=1, bias=False) 17 | 18 | class Flatten(nn.Module): 19 | def __init__(self): 20 | super(Flatten, self).__init__() 21 | 22 | def forward(self, x): 23 | return x.view(x.size(0), -1) 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, inplanes, planes, stride=1, downsample=None, drop_rate=0.0, drop_block=False, block_size=1): 29 | super(BasicBlock, self).__init__() 30 | self.conv1 = conv3x3(inplanes, planes) 31 | self.bn1 = nn.BatchNorm2d(planes) 32 | self.relu = nn.LeakyReLU(0.1) 33 | self.conv2 = conv3x3(planes, planes) 34 | self.bn2 = nn.BatchNorm2d(planes) 35 | self.conv3 = conv3x3(planes, planes) 36 | self.bn3 = nn.BatchNorm2d(planes) 37 | self.maxpool = nn.MaxPool2d(stride) 38 | self.downsample = downsample 39 | self.stride = stride 40 | self.drop_rate = drop_rate 41 | self.num_batches_tracked = 0 42 | self.drop_block = drop_block 43 | self.block_size = block_size 44 | self.DropBlock = DropBlock(block_size=self.block_size) 45 | 46 | def forward(self, x): 47 | self.num_batches_tracked += 1 48 | 49 | residual = x 50 | 51 | out = self.conv1(x) 52 | out = self.bn1(out) 53 | out = self.relu(out) 54 | 55 | out = self.conv2(out) 56 | out = self.bn2(out) 57 | out = self.relu(out) 58 | 59 | out = self.conv3(out) 60 | out = self.bn3(out) 61 | 62 | if self.downsample is not None: 63 | residual = self.downsample(x) 64 | out += residual 65 | out = self.relu(out) 66 | out = self.maxpool(out) 67 | 68 | if self.drop_rate > 0: 69 | if self.drop_block == True: 70 | feat_size = out.size()[2] 71 | keep_rate = max(1.0 - self.drop_rate / (20*2000) * (self.num_batches_tracked), 1.0 - self.drop_rate) 72 | gamma = (1 - keep_rate) / self.block_size**2 * feat_size**2 / (feat_size - self.block_size + 1)**2 73 | out = self.DropBlock(out, gamma=gamma) 74 | else: 75 | out = F.dropout(out, p=self.drop_rate, training=self.training, inplace=True) 76 | 77 | return out 78 | 79 | class DropBlock(nn.Module): 80 | def __init__(self, block_size): 81 | super(DropBlock, self).__init__() 82 | 83 | self.block_size = block_size 84 | #self.gamma = gamma 85 | #self.bernouli = Bernoulli(gamma) 
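# Note on the interaction with BasicBlock above: the `gamma` passed into forward() below is chosen there so that, once each sampled seed grows into a block_size x block_size square, the expected fraction of dropped activations matches 1 - keep_rate, with keep_rate annealed from 1.0 down to 1 - drop_rate over 20*2000 batches. _compute_block_mask() returns a mask that is 0 inside dropped blocks and 1 elsewhere, and forward() rescales the survivors by countM / count_ones so the expected activation magnitude is unchanged.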
86 | 87 | def forward(self, x, gamma): 88 | # shape: (bsize, channels, height, width) 89 | 90 | if self.training: 91 | batch_size, channels, height, width = x.shape 92 | 93 | bernoulli = Bernoulli(gamma) 94 | mask = bernoulli.sample((batch_size, channels, height - (self.block_size - 1), width - (self.block_size - 1))).cuda() 95 | #print((x.sample[-2], x.sample[-1])) 96 | block_mask = self._compute_block_mask(mask) 97 | #print (block_mask.size()) 98 | #print (x.size()) 99 | countM = block_mask.size()[0] * block_mask.size()[1] * block_mask.size()[2] * block_mask.size()[3] 100 | count_ones = block_mask.sum() 101 | 102 | return block_mask * x * (countM / count_ones) 103 | else: 104 | return x 105 | 106 | def _compute_block_mask(self, mask): 107 | left_padding = int((self.block_size-1) / 2) 108 | right_padding = int(self.block_size / 2) 109 | 110 | batch_size, channels, height, width = mask.shape 111 | #print ("mask", mask[0][0]) 112 | non_zero_idxs = mask.nonzero() 113 | nr_blocks = non_zero_idxs.shape[0] 114 | 115 | offsets = torch.stack( 116 | [ 117 | torch.arange(self.block_size).view(-1, 1).expand(self.block_size, self.block_size).reshape(-1), # - left_padding, 118 | torch.arange(self.block_size).repeat(self.block_size), #- left_padding 119 | ] 120 | ).t().cuda() 121 | offsets = torch.cat((torch.zeros(self.block_size**2, 2).cuda().long(), offsets.long()), 1) 122 | 123 | if nr_blocks > 0: 124 | non_zero_idxs = non_zero_idxs.repeat(self.block_size ** 2, 1) 125 | offsets = offsets.repeat(nr_blocks, 1).view(-1, 4) 126 | offsets = offsets.long() 127 | 128 | block_idxs = non_zero_idxs + offsets 129 | #block_idxs += left_padding 130 | padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding)) 131 | padded_mask[block_idxs[:, 0], block_idxs[:, 1], block_idxs[:, 2], block_idxs[:, 3]] = 1. 
132 | else: 133 | padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding)) 134 | 135 | block_mask = 1 - padded_mask#[:height, :width] 136 | return block_mask 137 | 138 | 139 | class ResNet(nn.Module): 140 | 141 | def __init__(self, block, keep_prob=1.0, avg_pool=False, drop_rate=0.0, dropblock_size=5, flatten=False): 142 | self.inplanes = 3 143 | super(ResNet, self).__init__() 144 | self.final_feat_dim = 640 145 | self.layer1 = self._make_layer(block, 64, stride=2, drop_rate=drop_rate) 146 | self.layer2 = self._make_layer(block, 160, stride=2, drop_rate=drop_rate) 147 | self.layer3 = self._make_layer(block, 320, stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size) 148 | self.layer4 = self._make_layer(block, 640, stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size) 149 | if avg_pool: 150 | self.avgpool = nn.Sequential( 151 | nn.AvgPool2d(5, stride=1), 152 | Flatten() 153 | ) 154 | self.keep_prob = keep_prob 155 | self.flatten = flatten 156 | self.keep_avg_pool = avg_pool 157 | self.dropout = nn.Dropout(p=1 - self.keep_prob, inplace=False) 158 | self.drop_rate = drop_rate 159 | 160 | for m in self.modules(): 161 | if isinstance(m, nn.Conv2d): 162 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') 163 | elif isinstance(m, nn.BatchNorm2d): 164 | nn.init.constant_(m.weight, 1) 165 | nn.init.constant_(m.bias, 0) 166 | 167 | def _make_layer(self, block, planes, stride=1, drop_rate=0.0, drop_block=False, block_size=1): 168 | downsample = None 169 | if stride != 1 or self.inplanes != planes * block.expansion: 170 | downsample = nn.Sequential( 171 | nn.Conv2d(self.inplanes, planes * block.expansion, 172 | kernel_size=1, stride=1, bias=False), 173 | nn.BatchNorm2d(planes * block.expansion), 174 | ) 175 | 176 | layers = [] 177 | layers.append(block(self.inplanes, planes, stride, downsample, drop_rate, drop_block, block_size)) 178 | self.inplanes = planes * block.expansion 179 | 180 | return nn.Sequential(*layers) 181 | 182 | def forward(self, x): 183 | x = self.layer1(x) 184 | x = self.layer2(x) 185 | x = self.layer3(x) 186 | x = self.layer4(x) 187 | if self.keep_avg_pool: 188 | x = self.avgpool(x) 189 | return x 190 | 191 | class ConvNet(nn.Module): 192 | def __init__(self, depth, flatten = True): 193 | super(ConvNet,self).__init__() 194 | trunk = [] 195 | for i in range(depth): 196 | indim = 3 if i == 0 else 64 197 | outdim = 64 198 | B = ConvBlock(indim, outdim, pool = ( i <4 ) ) #only pool in the first 4 layers 199 | trunk.append(B) 200 | 201 | if flatten: 202 | trunk.append(Flatten()) 203 | 204 | self.trunk = nn.Sequential(*trunk) 205 | self.final_feat_dim = 640 # NOTE: a flattened Conv4 map on 84x84 inputs is 64*5*5 = 1600-d (cf. feature_dim in DisentangleNet's conv branch) 206 | 207 | def forward(self,x): 208 | out = self.trunk(x) 209 | return out 210 | 211 | 212 | def resnet12(keep_prob=1.0, avg_pool=False, flatten=True, **kwargs): 213 | """Constructs a ResNet-12 model. 
214 | """ 215 | model = ResNet(BasicBlock, keep_prob=keep_prob, drop_rate=0.1, dropblock_size=2, avg_pool=avg_pool, flatten=flatten) 216 | return model 217 | 218 | 219 | def Conv4(avg_pool=True, flatten=True): 220 | return ConvNet(4, flatten=flatten) 221 | 222 | class distLinear(nn.Module): 223 | def __init__(self, indim, outdim): 224 | super(distLinear, self).__init__() 225 | self.L = nn.Linear( indim, outdim, bias = False) 226 | self.class_wise_learnable_norm = True #See the issue#4&8 in the github 227 | if self.class_wise_learnable_norm: 228 | WeightNorm.apply(self.L, 'weight', dim=0) #split the weight update component to direction and norm 229 | 230 | if outdim <=200: 231 | self.scale_factor = 2; #a fixed scale factor to scale the output of cos value into a reasonably large input for softmax, for to reproduce the result of CUB with ResNet10, use 4. see the issue#31 in the github 232 | else: 233 | self.scale_factor = 10; #in omniglot, a larger scale factor is required to handle >1000 output classes. 234 | 235 | def forward(self, x, y=None): 236 | x_norm = torch.norm(x, p=2, dim =1).unsqueeze(1).expand_as(x) 237 | x_normalized = x.div(x_norm+ 0.00001) 238 | if not self.class_wise_learnable_norm: 239 | L_norm = torch.norm(self.L.weight.data, p=2, dim =1).unsqueeze(1).expand_as(self.L.weight.data) 240 | self.L.weight.data = self.L.weight.data.div(L_norm + 0.00001) 241 | cos_dist = self.L(x_normalized) #matrix product by forward function, but when using WeightNorm, this also multiply the cosine distance by a class-wise learnable norm, see the issue#4&8 in the github 242 | scores = self.scale_factor* (cos_dist) 243 | 244 | return scores 245 | 246 | 247 | class NormLinear(nn.Module): 248 | # implementation of additive margin softmax loss in https://arxiv.org/abs/1801.05599 249 | def __init__(self, embedding_size=512, classnum=51332, radius=10, pretrained=None): 250 | super(NormLinear, self).__init__() 251 | self.classnum = classnum 252 | self.s = radius 253 | # initial kernel 254 | self.weight = nn.Parameter(torch.Tensor(embedding_size, classnum)) 255 | self.reset_parameters() 256 | 257 | def reset_parameters(self): 258 | stdv = 1./math.sqrt(self.weight.size(0)) 259 | self.weight.data.uniform_(-stdv, stdv) 260 | 261 | def forward(self, embeddings, y=None): 262 | # weights norm 263 | nB = len(embeddings) 264 | x_len, embeddings = l2_norm(embeddings) 265 | kernel_len, kernel_norm = l2_norm(self.weight,axis=0) 266 | # cos(theta+m) 267 | cos_theta = torch.mm(embeddings,kernel_norm) 268 | theta = embeddings.mm(kernel_norm.detach()) 269 | output = cos_theta * 1.0 # a little bit hacky way to prevent in_place operation on cos_theta 270 | 271 | if self.s: 272 | output*=self.s # scale up in order to make softmax work, first introduced in normface 273 | else: 274 | output*=x_len 275 | 276 | 277 | return output 278 | 279 | --------------------------------------------------------------------------------