├── README.md ├── datas ├── miniImagenet │ ├── create_miniImagenet.py │ ├── proc_images.py │ ├── test.csv │ ├── train.csv │ ├── trainval.csv │ └── val.csv └── omniglot_28x28.zip ├── docs └── sosn.png ├── miniimagenet ├── miniimagenet_test_few_shot_SoSN.py ├── miniimagenet_train_few_shot_SoSN.py ├── models.py ├── task_generator.py └── task_generator_test.py ├── omniglot ├── models.py ├── omniglot_test_few_shot_SoSN.py ├── omniglot_train_few_shot_SoSN.py └── task_generator.py └── openmic └── p1-p2 ├── models.py ├── openmic_test_few_shot_SoSN.py ├── openmic_train_few_shot_SoSN.py ├── task_generator.py └── task_generator.py~ /README.md: -------------------------------------------------------------------------------- 1 | # Power Normalizing Second-order Similarity Network for Few-shot Learning 2 | 3 | Pytorch Implementation of IEEE WACV2019 "[Power Normalizing Second-order Similarity Network for Few-shot Learning](https://arxiv.org/pdf/1811.04167.pdf)". 4 | This is based on the code of Relation Net. 5 | 6 | Download formatted miniImagenet:
7 | https://drive.google.com/file/d/1QhUs2uwEbVqCVig6B6cQmI0WTdYy2C9t/view?usp=sharing
8 | 9 | Download the pre-processed OpenMIC dataset via the following link (a request form is required):
10 | http://users.cecs.anu.edu.au/~koniusz/openmic-dataset/ 11 | 12 | Decompress the downloaded datasets into the `datas/` directory at the root of this repository.
13 | If you have any problems with the code, please contact hongguang.zhang@anu.edu.au.
14 | 15 | ![Pipline of SoSN](docs/sosn.png) 16 | 17 | __Requires.__ 18 | ``` 19 | pytorch-0.4.1 20 | numpy 21 | scipy 22 | ``` 23 | 24 | __For miniImagenet training and testing, run following commands.__ 25 | 26 | ``` 27 | cd ./miniimagenet 28 | python miniimagenet_train_few_shot_SoSN.py -w 5 -s 1 -sigma 100 29 | python miniimagenet_test_few_shot_SoSN.py -w 5 -s 1 -sigma 100 30 | ``` 31 | 32 | ## Citation 33 | ``` 34 | @inproceedings{zhang2019power, 35 | title={Power normalizing second-order similarity network for few-shot learning}, 36 | author={Zhang, Hongguang and Koniusz, Piotr}, 37 | booktitle={2019 IEEE Winter Conference on Applications of Computer Vision (WACV)}, 38 | pages={1185--1193}, 39 | year={2019}, 40 | organization={IEEE} 41 | } 42 | ``` 43 | -------------------------------------------------------------------------------- /datas/miniImagenet/create_miniImagenet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This code creates the MiniImagenet dataset. Following the partitions given 3 | by Sachin Ravi and Hugo Larochelle in 4 | https://github.com/twitter/meta-learning-lstm/tree/master/data/miniImagenet 5 | ''' 6 | 7 | import numpy as np 8 | import csv 9 | import glob, os 10 | from shutil import copyfile 11 | #import cv2 12 | from tqdm import tqdm 13 | 14 | pathImageNet = '/flush1/zha230/ILSVRC/Data/CLS-LOC/train' 15 | pathminiImageNet = '/data/zha230/LearningToCompare_FSL-master/datas/miniImagenet/' 16 | pathImages = os.path.join(pathminiImageNet,'images/') 17 | filesCSVSachinRavi = [os.path.join(pathminiImageNet,'train.csv'), 18 | os.path.join(pathminiImageNet,'val.csv'), 19 | os.path.join(pathminiImageNet,'test.csv')] 20 | 21 | # Check if the folder of images exist. If not create it. 
22 | if not os.path.exists(pathImages): 23 | os.makedirs(pathImages) 24 | 25 | for filename in filesCSVSachinRavi: 26 | with open(filename) as csvfile: 27 | csv_reader = csv.reader(csvfile, delimiter=',') 28 | next(csv_reader, None) 29 | images = {} 30 | print('Reading IDs....') 31 | for row in tqdm(csv_reader): 32 | if row[1] in images.keys(): 33 | images[row[1]].append(row[0]) 34 | else: 35 | images[row[1]] = [row[0]] 36 | print(images.keys()[0]) 37 | print((os.path.join(pathImageNet, images.keys()[0]))) 38 | 39 | print('Writing photos....') 40 | for c in tqdm(images.keys()): # Iterate over all the classes 41 | lst_files = [] 42 | for file in glob.glob(pathImageNet + "/*"+c+"*"): 43 | lst_files.append(file) 44 | # TODO: Sort by name of by index number of the image??? 45 | # I sort by the number of the image 46 | lst_index = [int(i[i.rfind('_')+1:i.rfind('.')]) for i in lst_files] 47 | index_sorted = sorted(range(len(lst_index)), key=lst_index.__getitem__) 48 | 49 | # Now iterate 50 | index_selected = [int(i[i.index('.') - 4:i.index('.')]) for i in images[c]] 51 | selected_images = np.array(index_sorted)[np.array(index_selected) - 1] 52 | for i in np.arange(len(selected_images)): 53 | # read file and resize to 84x84x3 54 | #im = cv2.imread(os.path.join(pathImageNet,lst_files[selected_images[i]])) 55 | #im_resized = cv2.resize(im, (84, 84), interpolation=cv2.INTER_AREA) 56 | #cv2.imwrite(os.path.join(pathImages, images[c][i]),im_resized) 57 | copyfile(os.path.join(pathImageNet,lst_files[selected_images[i]]),os.path.join(pathImages, images[c][i])) 58 | -------------------------------------------------------------------------------- /datas/miniImagenet/proc_images.py: -------------------------------------------------------------------------------- 1 | """ 2 | code copied from https://github.com/cbfinn/maml/blob/master/data/miniImagenet/proc_images.py 3 | Script for converting from csv file datafiles to a directory for each image (which is how it is loaded by MAML 
code) 4 | 5 | Acquire miniImagenet from Ravi & Larochelle '17, along with the train, val, and test csv files. Put the 6 | csv files in the miniImagenet directory and put the images in the directory 'miniImagenet/images/'. 7 | Then run this script from the miniImagenet directory: 8 | cd data/miniImagenet/ 9 | python proc_images.py 10 | """ 11 | from __future__ import print_function 12 | import csv 13 | import glob 14 | import os 15 | 16 | from PIL import Image 17 | 18 | path_to_images = 'images/' 19 | 20 | all_images = glob.glob(path_to_images + '*') 21 | 22 | # Resize images 23 | 24 | for i, image_file in enumerate(all_images): 25 | im = Image.open(image_file) 26 | im = im.resize((84, 84), resample=Image.LANCZOS) 27 | im.save(image_file) 28 | if i % 500 == 0: 29 | print(i) 30 | 31 | # Put in correct directory 32 | for datatype in ['train', 'val', 'test']: 33 | os.system('mkdir ' + datatype) 34 | 35 | with open(datatype + '.csv', 'r') as f: 36 | reader = csv.reader(f, delimiter=',') 37 | last_label = '' 38 | for i, row in enumerate(reader): 39 | if i == 0: # skip the headers 40 | continue 41 | label = row[1] 42 | image_name = row[0] 43 | if label != last_label: 44 | cur_dir = datatype + '/' + label + '/' 45 | os.system('mkdir ' + cur_dir) 46 | last_label = label 47 | os.system('cp images/' + image_name + ' ' + cur_dir) 48 | -------------------------------------------------------------------------------- /datas/omniglot_28x28.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HongguangZhang/SoSN-wacv19-master/f108b96007981388ca9f7af86ee7f3226adad874/datas/omniglot_28x28.zip -------------------------------------------------------------------------------- /docs/sosn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HongguangZhang/SoSN-wacv19-master/f108b96007981388ca9f7af86ee7f3226adad874/docs/sosn.png 
-------------------------------------------------------------------------------- /miniimagenet/miniimagenet_test_few_shot_SoSN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from torch.optim.lr_scheduler import StepLR 6 | import numpy as np 7 | import task_generator_test as tg 8 | import os 9 | import math 10 | import argparse 11 | import scipy as sp 12 | import scipy.stats 13 | import time 14 | import models 15 | 16 | parser = argparse.ArgumentParser(description="One Shot Visual Recognition") 17 | parser.add_argument("-f","--feature_dim",type = int, default = 64) 18 | parser.add_argument("-r","--relation_dim",type = int, default = 8) 19 | parser.add_argument("-w","--class_num",type = int, default = 5) 20 | parser.add_argument("-s","--support_num_per_class",type = int, default = 5) 21 | parser.add_argument("-b","--query_num_per_class",type = int, default = 15) 22 | parser.add_argument("-e","--episode",type = int, default= 100) 23 | parser.add_argument("-t","--test_episode", type = int, default = 600) 24 | parser.add_argument("-l","--learning_rate", type = float, default = 0.001) 25 | parser.add_argument("-g","--gpu",type=int, default=0) 26 | parser.add_argument("-u","--hidden_unit",type=int,default=10) 27 | parser.add_argument("-sigma","--sigma",type=float,default=100) 28 | args = parser.parse_args() 29 | 30 | 31 | # Hyper Parameters 32 | METHOD = "SoSN_Logit" + str(args.sigma) + "_Models" 33 | FEATURE_DIM = args.feature_dim 34 | RELATION_DIM = args.relation_dim 35 | CLASS_NUM = args.class_num 36 | SUPPORT_NUM_PER_CLASS = args.support_num_per_class 37 | QUERY_NUM_PER_CLASS = args.query_num_per_class 38 | EPISODE = args.episode 39 | TEST_EPISODE = args.test_episode 40 | LEARNING_RATE = args.learning_rate 41 | GPU = args.gpu 42 | HIDDEN_UNIT = args.hidden_unit 43 | SIGMA = args.sigma 44 | 45 | def 
power_norm(x, SIGMA): 46 | out = 2/(1 + torch.exp(-SIGMA*x)) - 1 47 | return out 48 | 49 | def mean_confidence_interval(data, confidence=0.95): 50 | a = 1.0*np.array(data) 51 | n = len(a) 52 | m = np.mean(a) 53 | s = scipy.stats.sem(a) 54 | h = s * sp.stats.t._ppf((1+confidence)/2., n-1) 55 | return m,h 56 | 57 | def weights_init(m): 58 | classname = m.__class__.__name__ 59 | if classname.find('Conv') != -1: 60 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 61 | m.weight.data.normal_(0, math.sqrt(2. / n)) 62 | if m.bias is not None: 63 | m.bias.data.zero_() 64 | elif classname.find('queryNorm') != -1: 65 | m.weight.data.fill_(1) 66 | m.bias.data.zero_() 67 | elif classname.find('Linear') != -1: 68 | n = m.weight.size(1) 69 | m.weight.data.normal_(0, 0.01) 70 | m.bias.data = torch.ones(m.bias.data.size()) 71 | 72 | def main(): 73 | metatrain_folders,metatest_folders = tg.mini_imagenet_folders() 74 | 75 | print("init neural networks") 76 | feature_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU) 77 | relation_network = models.SimilarityNetwork(FEATURE_DIM,RELATION_DIM).apply(weights_init).cuda(GPU) 78 | 79 | feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),lr=LEARNING_RATE) 80 | feature_encoder_scheduler = StepLR(feature_encoder_optim,step_size=50000,gamma=0.5) 81 | relation_network_optim = torch.optim.Adam(relation_network.parameters(),lr=LEARNING_RATE) 82 | relation_network_scheduler = StepLR(relation_network_optim,step_size=50000,gamma=0.5) 83 | 84 | if os.path.exists(str(METHOD + "/miniImagenet_feature_encoder_" + str(CLASS_NUM) +"way_" + str(SUPPORT_NUM_PER_CLASS) +"shot.pkl")): 85 | feature_encoder.load_state_dict(torch.load(str(METHOD + "/miniImagenet_feature_encoder_" + str(CLASS_NUM) +"way_" + str(SUPPORT_NUM_PER_CLASS) +"shot.pkl"))) 86 | print("load feature encoder success") 87 | if os.path.exists(str(METHOD + "/miniImagenet_relation_network_"+ str(CLASS_NUM) +"way_" + str(SUPPORT_NUM_PER_CLASS) 
+"shot.pkl")): 88 | relation_network.load_state_dict(torch.load(str(METHOD + "/miniImagenet_relation_network_"+ str(CLASS_NUM) +"way_" + str(SUPPORT_NUM_PER_CLASS) +"shot.pkl"))) 89 | print("load relation network success") 90 | if os.path.exists(METHOD) == False: 91 | os.system('mkdir ' + METHOD) 92 | 93 | # Step 3: build graph 94 | print("Training...") 95 | 96 | best_accuracy = 0.0 97 | best_h = 0.0 98 | start = time.time() 99 | 100 | for episode in range(EPISODE): 101 | with torch.no_grad(): 102 | # test 103 | print("Testing...") 104 | accuracies = [] 105 | for i in range(TEST_EPISODE): 106 | total_rewards = 0 107 | counter = 0 108 | task = tg.MiniImagenetTask(metatest_folders,CLASS_NUM,SUPPORT_NUM_PER_CLASS,15) 109 | support_dataloader = tg.get_mini_imagenet_data_loader(task,num_per_class=SUPPORT_NUM_PER_CLASS,split="train",shuffle=False) 110 | num_per_class = 5 111 | query_dataloader = tg.get_mini_imagenet_data_loader(task,num_per_class=num_per_class,split="test",shuffle=False) 112 | 113 | support_images,support_labels = support_dataloader.__iter__().next() 114 | for query_images,query_labels in query_dataloader: 115 | query_size = query_labels.shape[0] 116 | 117 | support_features = feature_encoder(Variable(support_images).cuda(GPU)) 118 | support_features = support_features.view(CLASS_NUM,SUPPORT_NUM_PER_CLASS,FEATURE_DIM,19*19).sum(1) 119 | query_features = feature_encoder(Variable(query_images).cuda(GPU)).view(num_per_class*CLASS_NUM,64,19*19) 120 | H_support_features = Variable(torch.Tensor(CLASS_NUM, 1, 64, 64)).cuda(GPU) 121 | H_query_features = Variable(torch.Tensor(num_per_class*CLASS_NUM, 1, 64, 64)).cuda(GPU) 122 | 123 | for d in range(support_features.size()[0]): 124 | s = support_features[d,:,:].squeeze(0) 125 | s = (1.0 / support_features.size()[2]) * s.mm(s.transpose(0,1)) 126 | H_support_features[d,:,:,:] = power_norm(s / s.trace(), SIGMA) 127 | for d in range(query_features.size()[0]): 128 | s = query_features[d,:,:].squeeze(0) 129 | s = (1.0 / 
query_features.size()[2]) * s.mm(s.transpose(0,1)) 130 | H_query_features[d,:,:,:] = power_norm(s / s.trace(), SIGMA) 131 | 132 | 133 | support_features_ext = H_support_features.unsqueeze(0).repeat(query_size,1,1,1,1) 134 | 135 | query_features_ext = H_query_features.unsqueeze(0).repeat(1*CLASS_NUM,1,1,1,1) 136 | query_features_ext = torch.transpose(query_features_ext,0,1) 137 | relation_pairs = torch.cat((support_features_ext, query_features_ext),2).view(-1,2,64,64) 138 | relations = relation_network(relation_pairs).view(-1,CLASS_NUM) 139 | 140 | _,predict_labels = torch.max(relations.data,1) 141 | 142 | rewards = [1 if predict_labels[j]==query_labels[j].cuda(GPU) else 0 for j in range(query_size)] 143 | 144 | total_rewards += np.sum(rewards) 145 | counter += query_size 146 | 147 | accuracy = total_rewards/1.0/counter 148 | accuracies.append(accuracy) 149 | 150 | 151 | test_accuracy,h = mean_confidence_interval(accuracies) 152 | 153 | print("test accuracy:",test_accuracy,"h:",h) 154 | print("best accuracy:",best_accuracy,"h:",best_h) 155 | 156 | if test_accuracy > best_accuracy: 157 | # save networks 158 | torch.save(feature_encoder.state_dict(),str(METHOD + "/miniImagenet_feature_encoder_" + str(CLASS_NUM) +"way_" + str(SUPPORT_NUM_PER_CLASS) +"shot.pkl")) 159 | torch.save(relation_network.state_dict(),str(METHOD + "/miniImagenet_relation_network_"+ str(CLASS_NUM) +"way_" + str(SUPPORT_NUM_PER_CLASS) +"shot.pkl")) 160 | print("save networks for episode:",episode) 161 | 162 | best_accuracy = test_accuracy 163 | best_h = h 164 | 165 | 166 | 167 | if __name__ == '__main__': 168 | main() 169 | -------------------------------------------------------------------------------- /miniimagenet/miniimagenet_train_few_shot_SoSN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from torch.optim.lr_scheduler import StepLR 
6 | import numpy as np 7 | import task_generator as tg 8 | import os 9 | import math 10 | import argparse 11 | import scipy as sp 12 | import scipy.stats 13 | import time 14 | import models 15 | 16 | parser = argparse.ArgumentParser(description="One Shot Visual Recognition") 17 | parser.add_argument("-f","--feature_dim",type = int, default = 64) 18 | parser.add_argument("-r","--relation_dim",type = int, default = 8) 19 | parser.add_argument("-w","--class_num",type = int, default = 5) 20 | parser.add_argument("-s","--support_num_per_class",type = int, default = 5) 21 | parser.add_argument("-b","--query_num_per_class",type = int, default = 15) 22 | parser.add_argument("-e","--episode",type = int, default= 500000) 23 | parser.add_argument("-t","--test_episode", type = int, default = 600) 24 | parser.add_argument("-l","--learning_rate", type = float, default = 0.001) 25 | parser.add_argument("-g","--gpu",type=int, default=0) 26 | parser.add_argument("-u","--hidden_unit",type=int,default=10) 27 | parser.add_argument("-sigma","--sigma",type=float,default=100) 28 | args = parser.parse_args() 29 | 30 | 31 | # Hyper Parameters 32 | METHOD = "SoSN_Logit" + str(args.sigma) + "_Models" 33 | FEATURE_DIM = args.feature_dim 34 | RELATION_DIM = args.relation_dim 35 | CLASS_NUM = args.class_num 36 | SUPPORT_NUM_PER_CLASS = args.support_num_per_class 37 | QUERY_NUM_PER_CLASS = args.query_num_per_class 38 | EPISODE = args.episode 39 | TEST_EPISODE = args.test_episode 40 | LEARNING_RATE = args.learning_rate 41 | GPU = args.gpu 42 | HIDDEN_UNIT = args.hidden_unit 43 | SIGMA = args.sigma 44 | 45 | def power_norm(x, SIGMA): 46 | out = 2/(1 + torch.exp(-SIGMA*x)) - 1 47 | return out 48 | 49 | def mean_confidence_interval(data, confidence=0.95): 50 | a = 1.0*np.array(data) 51 | n = len(a) 52 | m = np.mean(a) 53 | s = scipy.stats.sem(a) 54 | h = s * sp.stats.t._ppf((1+confidence)/2., n-1) 55 | return m,h 56 | 57 | def weights_init(m): 58 | classname = m.__class__.__name__ 59 | if 
classname.find('Conv') != -1: 60 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 61 | m.weight.data.normal_(0, math.sqrt(2. / n)) 62 | if m.bias is not None: 63 | m.bias.data.zero_() 64 | elif classname.find('queryNorm') != -1: 65 | m.weight.data.fill_(1) 66 | m.bias.data.zero_() 67 | elif classname.find('Linear') != -1: 68 | n = m.weight.size(1) 69 | m.weight.data.normal_(0, 0.01) 70 | m.bias.data = torch.ones(m.bias.data.size()) 71 | 72 | def main(): 73 | metatrain_folders,metatest_folders = tg.mini_imagenet_folders() 74 | 75 | print("init neural networks") 76 | feature_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU) 77 | relation_network = models.SimilarityNetwork(FEATURE_DIM,RELATION_DIM).apply(weights_init).cuda(GPU) 78 | 79 | feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),lr=LEARNING_RATE) 80 | feature_encoder_scheduler = StepLR(feature_encoder_optim,step_size=50000,gamma=0.5) 81 | relation_network_optim = torch.optim.Adam(relation_network.parameters(),lr=LEARNING_RATE) 82 | relation_network_scheduler = StepLR(relation_network_optim,step_size=50000,gamma=0.5) 83 | 84 | if os.path.exists(str(METHOD + "/miniImagenet_feature_encoder_" + str(CLASS_NUM) +"way_" + str(SUPPORT_NUM_PER_CLASS) +"shot.pkl")): 85 | feature_encoder.load_state_dict(torch.load(str(METHOD + "/miniImagenet_feature_encoder_" + str(CLASS_NUM) +"way_" + str(SUPPORT_NUM_PER_CLASS) +"shot.pkl"))) 86 | print("load feature encoder success") 87 | if os.path.exists(str(METHOD + "/miniImagenet_relation_network_"+ str(CLASS_NUM) +"way_" + str(SUPPORT_NUM_PER_CLASS) +"shot.pkl")): 88 | relation_network.load_state_dict(torch.load(str(METHOD + "/miniImagenet_relation_network_"+ str(CLASS_NUM) +"way_" + str(SUPPORT_NUM_PER_CLASS) +"shot.pkl"))) 89 | print("load relation network success") 90 | if os.path.exists(METHOD) == False: 91 | os.system('mkdir ' + METHOD) 92 | 93 | # Step 3: build graph 94 | print("Training...") 95 | 96 | best_accuracy = 0.0 97 | 
best_h = 0.0 98 | start = time.time() 99 | 100 | for episode in range(EPISODE): 101 | feature_encoder_scheduler.step(episode) 102 | relation_network_scheduler.step(episode) 103 | 104 | # init dataset 105 | task = tg.MiniImagenetTask(metatrain_folders,CLASS_NUM,SUPPORT_NUM_PER_CLASS,QUERY_NUM_PER_CLASS) 106 | support_dataloader = tg.get_mini_imagenet_data_loader(task,num_per_class=SUPPORT_NUM_PER_CLASS,split="train",shuffle=False) 107 | query_dataloader = tg.get_mini_imagenet_data_loader(task,num_per_class=QUERY_NUM_PER_CLASS,split="test",shuffle=True) 108 | 109 | # generate support data and query_data 110 | supports,support_labels = support_dataloader.__iter__().next() #25*3*84*84 111 | queries,query_labels = query_dataloader.__iter__().next() 112 | 113 | # generate features 114 | support_features = feature_encoder(Variable(supports).cuda(GPU)) # 25*64*19*19 115 | support_features = support_features.view(CLASS_NUM,SUPPORT_NUM_PER_CLASS,FEATURE_DIM,19*19).sum(1) # size: CLASS_NUMx64x19x19 116 | query_features = feature_encoder(Variable(queries).cuda(GPU)).view(QUERY_NUM_PER_CLASS*CLASS_NUM,64,19*19) # size: QUERY_NUM_PER_CLASSx64x19x19 117 | 118 | # init second-order representations 119 | H_support_features = Variable(torch.Tensor(CLASS_NUM, 1, 64, 64)).cuda(GPU) 120 | H_query_features = Variable(torch.Tensor(QUERY_NUM_PER_CLASS*CLASS_NUM, 1, 64, 64)).cuda(GPU) 121 | # second-order pooling 122 | for d in range(support_features.size()[0]): 123 | s = support_features[d,:,:].squeeze(0) 124 | s = (1.0 / support_features.size()[2]) * s.mm(s.transpose(0,1)) 125 | H_support_features[d,:,:,:] = power_norm(s / s.trace(), SIGMA) 126 | for d in range(query_features.size()[0]): 127 | s = query_features[d,:,:].squeeze(0) 128 | s = (1.0 / query_features.size()[2]) * s.mm(s.transpose(0,1)) 129 | H_query_features[d,:,:,:] = power_norm(s / s.trace(), SIGMA) 130 | 131 | # form the QURY_NUM_PER_CLASSxCLASS_NUM relation pairs 132 | support_features_ext = 
H_support_features.unsqueeze(0).repeat(QUERY_NUM_PER_CLASS*CLASS_NUM,1,1,1,1) 133 | query_features_ext = H_query_features.unsqueeze(0).repeat(CLASS_NUM,1,1,1,1) 134 | query_features_ext = torch.transpose(query_features_ext,0,1) 135 | relation_pairs = torch.cat((support_features_ext, query_features_ext),2).view(-1,2,64,64) 136 | # calculate relation scores 137 | relations = relation_network(relation_pairs).view(-1,CLASS_NUM) 138 | 139 | # define the loss function 140 | mse = nn.MSELoss().cuda(GPU) 141 | one_hot_labels = Variable(torch.zeros(QUERY_NUM_PER_CLASS*CLASS_NUM, CLASS_NUM).scatter_(1, query_labels.view(-1,1), 1).cuda(GPU)) 142 | loss = mse(relations,one_hot_labels) 143 | 144 | 145 | # training 146 | feature_encoder.zero_grad() 147 | relation_network.zero_grad() 148 | 149 | loss.backward() 150 | 151 | feature_encoder_optim.step() 152 | relation_network_optim.step() 153 | 154 | if (episode+1)%100 == 0: 155 | print("episode:",episode+1,"loss",loss.data[0]) 156 | 157 | if (episode+1)%2500 == 0: 158 | # test 159 | print("Testing...") 160 | accuracies = [] 161 | for i in range(TEST_EPISODE): 162 | total_rewards = 0 163 | counter = 0 164 | task = tg.MiniImagenetTask(metatest_folders,CLASS_NUM,SUPPORT_NUM_PER_CLASS,15) 165 | support_dataloader = tg.get_mini_imagenet_data_loader(task,num_per_class=SUPPORT_NUM_PER_CLASS,split="train",shuffle=False) 166 | num_per_class = 5 167 | query_dataloader = tg.get_mini_imagenet_data_loader(task,num_per_class=num_per_class,split="test",shuffle=False) 168 | 169 | support_images,support_labels = support_dataloader.__iter__().next() 170 | for query_images,query_labels in query_dataloader: 171 | query_size = query_labels.shape[0] 172 | 173 | support_features = feature_encoder(Variable(support_images).cuda(GPU)) 174 | support_features = support_features.view(CLASS_NUM,SUPPORT_NUM_PER_CLASS,FEATURE_DIM,19*19).sum(1) 175 | query_features = feature_encoder(Variable(query_images).cuda(GPU)).view(num_per_class*CLASS_NUM,64,19*19) 176 | 
H_support_features = Variable(torch.Tensor(CLASS_NUM, 1, 64, 64)).cuda(GPU) 177 | H_query_features = Variable(torch.Tensor(num_per_class*CLASS_NUM, 1, 64, 64)).cuda(GPU) 178 | 179 | for d in range(support_features.size()[0]): 180 | s = support_features[d,:,:].squeeze(0) 181 | s = (1.0 / support_features.size()[2]) * s.mm(s.transpose(0,1)) 182 | H_support_features[d,:,:,:] = power_norm(s / s.trace(), SIGMA) 183 | for d in range(query_features.size()[0]): 184 | s = query_features[d,:,:].squeeze(0) 185 | s = (1.0 / query_features.size()[2]) * s.mm(s.transpose(0,1)) 186 | H_query_features[d,:,:,:] = power_norm(s / s.trace(), SIGMA) 187 | 188 | 189 | support_features_ext = H_support_features.unsqueeze(0).repeat(query_size,1,1,1,1) 190 | 191 | query_features_ext = H_query_features.unsqueeze(0).repeat(1*CLASS_NUM,1,1,1,1) 192 | query_features_ext = torch.transpose(query_features_ext,0,1) 193 | relation_pairs = torch.cat((support_features_ext, query_features_ext),2).view(-1,2,64,64) 194 | relations = relation_network(relation_pairs).view(-1,CLASS_NUM) 195 | 196 | _,predict_labels = torch.max(relations.data,1) 197 | 198 | rewards = [1 if predict_labels[j]==query_labels[j].cuda(GPU) else 0 for j in range(query_size)] 199 | 200 | total_rewards += np.sum(rewards) 201 | counter += query_size 202 | 203 | accuracy = total_rewards/1.0/counter 204 | accuracies.append(accuracy) 205 | 206 | 207 | test_accuracy,h = mean_confidence_interval(accuracies) 208 | 209 | print("test accuracy:",test_accuracy,"h:",h) 210 | print("best accuracy:",best_accuracy,"h:",best_h) 211 | 212 | if test_accuracy > best_accuracy: 213 | # save networks 214 | torch.save(feature_encoder.state_dict(),str(METHOD + "/miniImagenet_feature_encoder_" + str(CLASS_NUM) +"way_" + str(SUPPORT_NUM_PER_CLASS) +"shot.pkl")) 215 | torch.save(relation_network.state_dict(),str(METHOD + "/miniImagenet_relation_network_"+ str(CLASS_NUM) +"way_" + str(SUPPORT_NUM_PER_CLASS) +"shot.pkl")) 216 | print("save networks for 
episode:",episode) 217 | 218 | best_accuracy = test_accuracy 219 | best_h = h 220 | 221 | 222 | 223 | if __name__ == '__main__': 224 | main() 225 | -------------------------------------------------------------------------------- /miniimagenet/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class FeatureEncoder(nn.Module): 6 | """docstring for ClassName""" 7 | def __init__(self): 8 | super(FeatureEncoder, self).__init__() 9 | self.layer1 = nn.Sequential( 10 | nn.Conv2d(3,64,kernel_size=3,padding=0), 11 | nn.BatchNorm2d(64, momentum=1, affine=True), 12 | nn.ReLU(), 13 | nn.MaxPool2d(2)) 14 | self.layer2 = nn.Sequential( 15 | nn.Conv2d(64,64,kernel_size=3,padding=0), 16 | nn.BatchNorm2d(64, momentum=1, affine=True), 17 | nn.ReLU(), 18 | nn.MaxPool2d(2)) 19 | self.layer3 = nn.Sequential( 20 | nn.Conv2d(64,64,kernel_size=3,padding=1), 21 | nn.BatchNorm2d(64, momentum=1, affine=True), 22 | nn.ReLU()) 23 | self.layer4 = nn.Sequential( 24 | nn.Conv2d(64,64,kernel_size=3,padding=1), 25 | nn.BatchNorm2d(64, momentum=1, affine=True), 26 | nn.ReLU()) 27 | 28 | def forward(self,x): 29 | out = self.layer1(x) 30 | out = self.layer2(out) 31 | out = self.layer3(out) 32 | out = self.layer4(out) 33 | return out 34 | 35 | class SimilarityNetwork(nn.Module): 36 | """docstring for RelationNetwork""" 37 | def __init__(self,input_size,hidden_size): 38 | super(SimilarityNetwork, self).__init__() 39 | self.layer1 = nn.Sequential( 40 | nn.Conv2d(2,64,kernel_size=3,padding=0), 41 | nn.BatchNorm2d(64, momentum=1, affine=True), 42 | nn.ReLU(), 43 | nn.MaxPool2d(2)) #Nx64x31x31 44 | self.layer2 = nn.Sequential( 45 | nn.Conv2d(64,64,kernel_size=3,padding=0), 46 | nn.BatchNorm2d(64, momentum=1, affine=True), 47 | nn.ReLU(), 48 | nn.MaxPool2d(2)) #Nx64x14x14 49 | self.layer3 = nn.Sequential( 50 | nn.Conv2d(64,64,kernel_size=3,padding=0), 51 | nn.BatchNorm2d(64, 
momentum=1, affine=True), 52 | nn.ReLU(), 53 | nn.MaxPool2d(2)) #Nx64x6x6 54 | self.layer4 = nn.Sequential( 55 | nn.Conv2d(64,64,kernel_size=3,padding=0), 56 | nn.BatchNorm2d(64, momentum=1, affine=True), 57 | nn.ReLU(), 58 | nn.MaxPool2d(2)) #Nx64x2x2 59 | self.fc1 = nn.Linear(input_size*4,hidden_size) 60 | self.fc2 = nn.Linear(hidden_size,1) 61 | 62 | def forward(self,x): 63 | out = self.layer1(x) 64 | out = self.layer2(out) 65 | out = self.layer3(out) 66 | out = self.layer4(out) 67 | out = out.view(out.size(0),-1) 68 | out = F.relu(self.fc1(out)) 69 | out = F.sigmoid(self.fc2(out)) 70 | return out 71 | -------------------------------------------------------------------------------- /miniimagenet/task_generator.py: -------------------------------------------------------------------------------- 1 | # code is based on https://github.com/katerakelly/pytorch-maml 2 | import torchvision 3 | import torchvision.datasets as dset 4 | import torchvision.transforms as transforms 5 | import torch 6 | from torch.utils.data import DataLoader,Dataset 7 | import random 8 | import os 9 | from PIL import Image 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | from torch.utils.data.sampler import Sampler 13 | 14 | def imshow(img): 15 | npimg = img.numpy() 16 | plt.axis("off") 17 | plt.imshow(np.transpose(npimg,(1,2,0))) 18 | plt.show() 19 | 20 | class Rotate(object): 21 | def __init__(self, angle): 22 | self.angle = angle 23 | def __call__(self, x, mode="reflect"): 24 | x = x.rotate(self.angle) 25 | return x 26 | 27 | def mini_imagenet_folders(): 28 | train_folder = '../datas/miniImagenet/train' 29 | test_folder = '../datas/miniImagenet/val' 30 | 31 | metatrain_folders = [os.path.join(train_folder, label) \ 32 | for label in os.listdir(train_folder) \ 33 | if os.path.isdir(os.path.join(train_folder, label)) \ 34 | ] 35 | metatest_folders = [os.path.join(test_folder, label) \ 36 | for label in os.listdir(test_folder) \ 37 | if os.path.isdir(os.path.join(test_folder, 
label)) \ 38 | ] 39 | 40 | random.seed(1) 41 | random.shuffle(metatrain_folders) 42 | random.shuffle(metatest_folders) 43 | 44 | return metatrain_folders,metatest_folders 45 | 46 | class MiniImagenetTask(object): 47 | 48 | def __init__(self, character_folders, num_classes, train_num,test_num): 49 | 50 | self.character_folders = character_folders 51 | self.num_classes = num_classes 52 | self.train_num = train_num 53 | self.test_num = test_num 54 | 55 | class_folders = random.sample(self.character_folders,self.num_classes) 56 | labels = np.array(range(len(class_folders))) 57 | labels = dict(zip(class_folders, labels)) 58 | samples = dict() 59 | 60 | self.train_roots = [] 61 | self.test_roots = [] 62 | for c in class_folders: 63 | 64 | temp = [os.path.join(c, x) for x in os.listdir(c)] 65 | samples[c] = random.sample(temp, len(temp)) 66 | random.shuffle(samples[c]) 67 | 68 | self.train_roots += samples[c][:train_num] 69 | self.test_roots += samples[c][train_num:train_num+test_num] 70 | 71 | self.train_labels = [labels[self.get_class(x)] for x in self.train_roots] 72 | self.test_labels = [labels[self.get_class(x)] for x in self.test_roots] 73 | 74 | def get_class(self, sample): 75 | return os.path.join(*sample.split('/')[:-1]) 76 | 77 | 78 | class FewShotDataset(Dataset): 79 | 80 | def __init__(self, task, split='train', transform=None, target_transform=None): 81 | self.transform = transform # Torch operations on the input image 82 | self.target_transform = target_transform 83 | self.task = task 84 | self.split = split 85 | self.image_roots = self.task.train_roots if self.split == 'train' else self.task.test_roots 86 | self.labels = self.task.train_labels if self.split == 'train' else self.task.test_labels 87 | 88 | def __len__(self): 89 | return len(self.image_roots) 90 | 91 | def __getitem__(self, idx): 92 | raise NotImplementedError("This is an abstract class. 
Subclass this class for your particular dataset.") 93 | 94 | class MiniImagenet(FewShotDataset): 95 | 96 | def __init__(self, *args, **kwargs): 97 | super(MiniImagenet, self).__init__(*args, **kwargs) 98 | 99 | def __getitem__(self, idx): 100 | image_root = self.image_roots[idx] 101 | image = Image.open(image_root) 102 | image = image.convert('RGB') 103 | if self.transform is not None: 104 | image = self.transform(image) 105 | label = self.labels[idx] 106 | if self.target_transform is not None: 107 | label = self.target_transform(label) 108 | return image, label 109 | 110 | 111 | class ClassBalancedSampler(Sampler): 112 | ''' Samples 'num_inst' examples each from 'num_cl' pools 113 | of examples of size 'num_per_class' ''' 114 | 115 | def __init__(self, num_per_class, num_cl, num_inst,shuffle=True): 116 | self.num_per_class = num_per_class 117 | self.num_cl = num_cl 118 | self.num_inst = num_inst 119 | self.shuffle = shuffle 120 | 121 | def __iter__(self): 122 | # return a single list of indices, assuming that items will be grouped by class 123 | if self.shuffle: 124 | batch = [[i+j*self.num_inst for i in torch.randperm(self.num_inst)[:self.num_per_class]] for j in range(self.num_cl)] 125 | else: 126 | batch = [[i+j*self.num_inst for i in range(self.num_inst)[:self.num_per_class]] for j in range(self.num_cl)] 127 | batch = [item for sublist in batch for item in sublist] 128 | 129 | if self.shuffle: 130 | random.shuffle(batch) 131 | return iter(batch) 132 | 133 | def __len__(self): 134 | return 1 135 | 136 | 137 | def get_mini_imagenet_data_loader(task, num_per_class=1, split='train',shuffle = False): 138 | normalize = transforms.Normalize(mean=[0.92206, 0.92206, 0.92206], std=[0.08426, 0.08426, 0.08426]) 139 | 140 | dataset = MiniImagenet(task,split=split,transform=transforms.Compose([transforms.ToTensor(),normalize])) 141 | 142 | if split == 'train': 143 | sampler = ClassBalancedSampler(num_per_class, task.num_classes, task.train_num,shuffle=shuffle) 144 | else: 
def mini_imagenet_folders(train_folder='../datas/miniImagenet/train',
                          test_folder='../datas/miniImagenet/test'):
    """Return the shuffled class folders of the meta-train and meta-test splits.

    Args:
        train_folder: directory holding one sub-directory per training class.
        test_folder: directory holding one sub-directory per test class.

    Returns:
        ``(metatrain_folders, metatest_folders)``: two lists of class
        directory paths. Plain files inside the roots are ignored.

    Side effect: re-seeds the global ``random`` generator with 1.
    """

    def _class_dirs(root):
        # os.listdir gives no ordering guarantee; sort first so that the
        # seeded shuffle below is reproducible across machines.
        return [os.path.join(root, label)
                for label in sorted(os.listdir(root))
                if os.path.isdir(os.path.join(root, label))]

    metatrain_folders = _class_dirs(train_folder)
    metatest_folders = _class_dirs(test_folder)

    random.seed(1)
    random.shuffle(metatrain_folders)
    random.shuffle(metatest_folders)

    return metatrain_folders, metatest_folders
self.num_classes = num_classes 52 | self.train_num = train_num 53 | self.test_num = test_num 54 | 55 | class_folders = random.sample(self.character_folders,self.num_classes) 56 | labels = np.array(range(len(class_folders))) 57 | labels = dict(zip(class_folders, labels)) 58 | samples = dict() 59 | 60 | self.train_roots = [] 61 | self.test_roots = [] 62 | for c in class_folders: 63 | 64 | temp = [os.path.join(c, x) for x in os.listdir(c)] 65 | samples[c] = random.sample(temp, len(temp)) 66 | random.shuffle(samples[c]) 67 | 68 | self.train_roots += samples[c][:train_num] 69 | self.test_roots += samples[c][train_num:train_num+test_num] 70 | 71 | self.train_labels = [labels[self.get_class(x)] for x in self.train_roots] 72 | self.test_labels = [labels[self.get_class(x)] for x in self.test_roots] 73 | 74 | def get_class(self, sample): 75 | return os.path.join(*sample.split('/')[:-1]) 76 | 77 | class FewShotDataset(Dataset): 78 | 79 | def __init__(self, task, split='train', transform=None, target_transform=None): 80 | self.transform = transform # Torch operations on the input image 81 | self.target_transform = target_transform 82 | self.task = task 83 | self.split = split 84 | self.image_roots = self.task.train_roots if self.split == 'train' else self.task.test_roots 85 | self.labels = self.task.train_labels if self.split == 'train' else self.task.test_labels 86 | 87 | def __len__(self): 88 | return len(self.image_roots) 89 | 90 | def __getitem__(self, idx): 91 | raise NotImplementedError("This is an abstract class. 
class ClassBalancedSampler(Sampler):
    """Visit every item once, in class-balanced groups of ``num_cl`` indices.

    The dataset is assumed to be laid out class by class, every class
    occupying a contiguous run of ``num_inst`` items. The produced index
    sequence is a concatenation of ``num_inst`` groups; each group holds
    exactly one index from every class.

    Args:
        num_cl: number of classes in the dataset.
        num_inst: number of items stored per class.
        shuffle: if True, randomise which item of a class lands in which
            group, the order of the groups, and the order inside each group.
    """

    def __init__(self, num_cl, num_inst, shuffle=True):
        self.num_cl = num_cl
        self.num_inst = num_inst
        self.shuffle = shuffle

    def __iter__(self):
        # One (optionally permuted) index list per class.
        per_class = []
        for j in range(self.num_cl):
            if self.shuffle:
                # .tolist() yields plain ints rather than 0-dim tensors.
                order = torch.randperm(self.num_inst).tolist()
            else:
                order = range(self.num_inst)
            per_class.append([i + j * self.num_inst for i in order])

        # Transpose: group i takes the i-th entry of every class.
        batches = [[per_class[j][i] for j in range(self.num_cl)]
                   for i in range(self.num_inst)]

        if self.shuffle:
            random.shuffle(batches)
            for sublist in batches:
                random.shuffle(sublist)
        return iter([item for sublist in batches for item in sublist])

    def __len__(self):
        # Number of indices __iter__ produces (was hard-coded to 1).
        return self.num_cl * self.num_inst
def get_mini_imagenet_data_loader(task, num_per_class=1, split='train',shuffle = False):
    """Build a DataLoader over the images of one few-shot task.

    Args:
        task: episode object providing ``num_classes``, ``train_num`` and
            ``test_num`` (plus whatever ``MiniImagenet`` reads from it).
        num_per_class: images per class in one batch; the batch size is
            ``num_per_class * task.num_classes``.
        split: ``'train'`` selects the support images; any other value
            selects the test images.
        shuffle: forwarded to the sampler.

    The train split uses ``ClassBalancedSamplerOld`` (a subset of
    ``num_per_class`` items per class); every other split uses
    ``ClassBalancedSampler`` (all items, in class-balanced groups).
    """
    # NOTE(review): one grey-level statistic is replicated over the three
    # colour channels - confirm these constants are intended for this dataset.
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.92206, 0.92206, 0.92206],
                             std=[0.08426, 0.08426, 0.08426]),
    ])
    dataset = MiniImagenet(task, split=split, transform=preprocessing)

    if split != 'train':
        sampler = ClassBalancedSampler(task.num_classes, task.test_num, shuffle=shuffle)
    else:
        sampler = ClassBalancedSamplerOld(num_per_class, task.num_classes,
                                          task.train_num, shuffle=shuffle)

    return DataLoader(dataset,
                      batch_size=num_per_class * task.num_classes,
                      sampler=sampler)
class SimilarityNetwork(nn.Module):
    """Score a stacked pair of feature maps with a value in [0, 1].

    The input has 2 channels (the two maps being compared). Four
    conv / batch-norm / ReLU / max-pool stages reduce it spatially; the
    result is flattened and passed through two linear layers, the last one
    followed by a sigmoid.

    Args:
        input_size: channel count expected after the last stage; ``fc1``
            takes ``input_size * 4`` inputs, i.e. a 2x2 spatial map.
        hidden_size: width of the hidden linear layer.

    The shape comments below assume a 64x64 spatial input.
    """

    def __init__(self, input_size, hidden_size):
        super(SimilarityNetwork, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(2, 64, kernel_size=3, padding=0),
            nn.BatchNorm2d(64, momentum=1, affine=True),
            nn.ReLU(),
            nn.MaxPool2d(2))  # Nx64x31x31
        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=0),
            nn.BatchNorm2d(64, momentum=1, affine=True),
            nn.ReLU(),
            nn.MaxPool2d(2))  # Nx64x14x14
        self.layer3 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=0),
            nn.BatchNorm2d(64, momentum=1, affine=True),
            nn.ReLU(),
            nn.MaxPool2d(2))  # Nx64x6x6
        self.layer4 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=0),
            nn.BatchNorm2d(64, momentum=1, affine=True),
            nn.ReLU(),
            nn.MaxPool2d(2))  # Nx64x2x2
        self.fc1 = nn.Linear(input_size * 4, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Return an (N, 1) tensor of similarity scores for the N pairs in ``x``."""
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        # torch.sigmoid replaces the deprecated F.sigmoid; same values.
        out = torch.sigmoid(self.fc2(out))
        return out
Few-Shot Learning 3 | # Date: 2017.9.21 4 | # Author: Flood Sung 5 | # All Rights Reserved 6 | #------------------------------------- 7 | 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.autograd import Variable 13 | from torch.optim.lr_scheduler import StepLR 14 | import numpy as np 15 | import task_generator as tg 16 | import os 17 | import math 18 | import argparse 19 | import random 20 | import models 21 | 22 | parser = argparse.ArgumentParser(description="One Shot Visual Recognition") 23 | parser.add_argument("-f","--feature_dim",type = int, default = 64) 24 | parser.add_argument("-r","--relation_dim",type = int, default = 8) 25 | parser.add_argument("-w","--class_num",type = int, default = 5) 26 | parser.add_argument("-s","--support_num_per_class",type = int, default = 5) 27 | parser.add_argument("-b","--query_num_per_class",type = int, default = 2) 28 | parser.add_argument("-e","--episode",type = int, default= 100) 29 | parser.add_argument("-t","--query_episode", type = int, default = 1000) 30 | parser.add_argument("-l","--learning_rate", type = float, default = 0.001) 31 | parser.add_argument("-g","--gpu",type=int, default=0) 32 | parser.add_argument("-u","--hidden_unit",type=int,default=10) 33 | parser.add_argument("-sigma","--sigma", type = float, default = 1) 34 | parser.add_argument("-ts","--test_num_per_class",type = int, default = 5) 35 | args = parser.parse_args() 36 | 37 | 38 | # Hyper Parameters 39 | METHOD = "SoSN_LOGIT" + str(args.sigma) + "_Models" 40 | FEATURE_DIM = args.feature_dim 41 | RELATION_DIM = args.relation_dim 42 | CLASS_NUM = args.class_num 43 | SUPPORT_NUM_PER_CLASS = args.support_num_per_class 44 | QUERY_NUM_PER_CLASS = args.query_num_per_class 45 | TEST_NUM_PER_CLASS = args.test_num_per_class 46 | EPISODE = args.episode 47 | TEST_EPISODE = args.query_episode 48 | LEARNING_RATE = args.learning_rate 49 | GPU = args.gpu 50 | HIDDEN_UNIT = args.hidden_unit 51 | SIGMA = args.sigma 52 
def power_norm(x, SIGMA):
    """Squash ``x`` element-wise into (-1, 1) with a sigmoid-shaped curve.

    Computes ``2 / (1 + exp(-SIGMA * x)) - 1``; larger ``SIGMA`` gives a
    steeper curve around zero.
    """
    out = 2/(1 + torch.exp(-SIGMA*x)) - 1
    return out


def weights_init(m):
    """Initialise one module in place; meant for ``nn.Module.apply``.

    Conv layers get zero-mean normal weights with std ``sqrt(2 / n)``
    (``n`` = kernel area * output channels) and zero bias, batch-norm layers
    get weight 1 / bias 0, linear layers get N(0, 0.01) weights and bias 1.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('BatchNorm') != -1:
        m.weight.data.fill_(1)
        m.bias.data.zero_()
    elif classname.find('Linear') != -1:
        m.weight.data.normal_(0, 0.01)
        m.bias.data = torch.ones(m.bias.data.size())


def _embed(feature_encoder, images):
    """Encode ``images`` on the GPU and flatten the spatial dims: (N, C, H*W)."""
    features = feature_encoder(images.cuda(GPU))
    return features.view(features.size(0), features.size(1), -1)


def _second_order_features(features, sigma):
    """Turn (N, C, M) features into (N, 1, C, C) power-normalised maps.

    Each sample becomes its C x C autocorrelation matrix averaged over the
    M locations, divided by its trace and passed through ``power_norm``.
    """
    maps = []
    for sample in features:
        s = (1.0 / features.size(2)) * sample.mm(sample.t())
        maps.append(power_norm(s / s.trace(), sigma))
    return torch.stack(maps).unsqueeze(1)


def _relation_scores(feature_encoder, relation_network, support_images, query_images):
    """Return a (num_queries, CLASS_NUM) score matrix for one episode.

    Support images must arrive grouped by class (loader built with
    ``shuffle=False``); their features are summed per class before the
    second-order pooling.
    """
    support = _embed(feature_encoder, support_images)
    support = support.view(CLASS_NUM, SUPPORT_NUM_PER_CLASS,
                           support.size(1), support.size(2)).sum(1)
    query = _embed(feature_encoder, query_images)

    H_support = _second_order_features(support, SIGMA)
    H_query = _second_order_features(query, SIGMA)

    # Pair every query with every class prototype: (Q, CLASS_NUM, 2, C, C).
    dim = H_support.size(-1)
    support_ext = H_support.unsqueeze(0).repeat(H_query.size(0), 1, 1, 1, 1)
    query_ext = H_query.unsqueeze(0).repeat(CLASS_NUM, 1, 1, 1, 1)
    query_ext = torch.transpose(query_ext, 0, 1)
    relation_pairs = torch.cat((support_ext, query_ext), 2).view(-1, 2, dim, dim)
    return relation_network(relation_pairs).view(-1, CLASS_NUM)


def main():
    """Load the saved networks (if present) and report accuracy on the held-out characters."""
    # Step 1: init data folders
    print("init data folders")
    # init character folders for dataset construction
    metatrain_character_folders,metaquery_character_folders = tg.omniglot_character_folders()

    # Step 2: init neural networks
    print("init neural networks")

    feature_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU)
    relation_network = models.SimilarityNetwork(FEATURE_DIM,RELATION_DIM).apply(weights_init).cuda(GPU)

    suffix = str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl"
    encoder_path = METHOD + "/omniglot_feature_encoder_" + suffix
    similarity_path = METHOD + "/omniglot_similarity_network_" + suffix

    # NOTE: when a checkpoint is missing the freshly initialised weights are
    # evaluated, exactly as before.
    if os.path.exists(encoder_path):
        feature_encoder.load_state_dict(torch.load(encoder_path))
        print("load feature encoder success")
    if os.path.exists(similarity_path):
        relation_network.load_state_dict(torch.load(similarity_path))
        print("load similarity network success")
    if not os.path.exists(METHOD):
        os.makedirs(METHOD)

    # Step 3: build graph
    print("Training...")

    best_accuracy = 0.0

    for episode in range(EPISODE):
        with torch.no_grad():
            # query
            print("Testing...")
            total_rewards = 0

            for i in range(TEST_EPISODE):
                degrees = random.choice([0,90,180,270])
                task = tg.OmniglotTask(metaquery_character_folders,CLASS_NUM,SUPPORT_NUM_PER_CLASS,TEST_NUM_PER_CLASS)
                support_dataloader = tg.get_data_loader(task,num_per_class=SUPPORT_NUM_PER_CLASS,split="train",shuffle=False,rotation=degrees)
                query_dataloader = tg.get_data_loader(task,num_per_class=TEST_NUM_PER_CLASS,split="query",shuffle=True,rotation=degrees)

                # next(iter(...)) works on every DataLoader version; the
                # iterator's .next() alias does not exist in later releases.
                support_images,support_labels = next(iter(support_dataloader))
                query_images,query_labels = next(iter(query_dataloader))

                relations = _relation_scores(feature_encoder, relation_network,
                                             support_images, query_images)
                _,predict_labels = torch.max(relations,1)

                total_rewards += (predict_labels.cpu() == query_labels).sum().item()

            test_accuracy = total_rewards/1.0/CLASS_NUM/TEST_NUM_PER_CLASS/TEST_EPISODE

            print("query accuracy:",test_accuracy)
            print("best accuracy:",best_accuracy)

            if test_accuracy > best_accuracy:
                best_accuracy = test_accuracy


if __name__ == '__main__':
    main()
def power_norm(x, SIGMA):
    """Squash ``x`` element-wise into (-1, 1) with a sigmoid-shaped curve.

    Computes ``2 / (1 + exp(-SIGMA * x)) - 1``; larger ``SIGMA`` gives a
    steeper curve around zero.
    """
    out = 2/(1 + torch.exp(-SIGMA*x)) - 1
    return out


def weights_init(m):
    """Initialise one module in place; meant for ``nn.Module.apply``.

    Conv layers get zero-mean normal weights with std ``sqrt(2 / n)``
    (``n`` = kernel area * output channels) and zero bias, batch-norm layers
    get weight 1 / bias 0, linear layers get N(0, 0.01) weights and bias 1.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('BatchNorm') != -1:
        m.weight.data.fill_(1)
        m.bias.data.zero_()
    elif classname.find('Linear') != -1:
        m.weight.data.normal_(0, 0.01)
        m.bias.data = torch.ones(m.bias.data.size())


def _embed(feature_encoder, images):
    """Encode ``images`` on the GPU and flatten the spatial dims: (N, C, H*W)."""
    features = feature_encoder(images.cuda(GPU))
    return features.view(features.size(0), features.size(1), -1)


def _second_order_features(features, sigma):
    """Turn (N, C, M) features into (N, 1, C, C) power-normalised maps.

    Each sample becomes its C x C autocorrelation matrix averaged over the
    M locations, divided by its trace and passed through ``power_norm``.
    """
    maps = []
    for sample in features:
        s = (1.0 / features.size(2)) * sample.mm(sample.t())
        maps.append(power_norm(s / s.trace(), sigma))
    return torch.stack(maps).unsqueeze(1)


def _relation_scores(feature_encoder, relation_network, support_images, query_images):
    """Return a (num_queries, CLASS_NUM) score matrix for one episode.

    Support images must arrive grouped by class (loader built with
    ``shuffle=False``); their features are summed per class before the
    second-order pooling.
    """
    support = _embed(feature_encoder, support_images)
    support = support.view(CLASS_NUM, SUPPORT_NUM_PER_CLASS,
                           support.size(1), support.size(2)).sum(1)
    query = _embed(feature_encoder, query_images)

    H_support = _second_order_features(support, SIGMA)
    H_query = _second_order_features(query, SIGMA)

    # Pair every query with every class prototype: (Q, CLASS_NUM, 2, C, C).
    dim = H_support.size(-1)
    support_ext = H_support.unsqueeze(0).repeat(H_query.size(0), 1, 1, 1, 1)
    query_ext = H_query.unsqueeze(0).repeat(CLASS_NUM, 1, 1, 1, 1)
    query_ext = torch.transpose(query_ext, 0, 1)
    relation_pairs = torch.cat((support_ext, query_ext), 2).view(-1, 2, dim, dim)
    return relation_network(relation_pairs).view(-1, CLASS_NUM)


def _evaluate(feature_encoder, relation_network, character_folders):
    """Return mean accuracy over TEST_EPISODE random tasks drawn from ``character_folders``."""
    total_rewards = 0
    with torch.no_grad():
        for _ in range(TEST_EPISODE):
            degrees = random.choice([0,90,180,270])
            task = tg.OmniglotTask(character_folders,CLASS_NUM,SUPPORT_NUM_PER_CLASS,TEST_NUM_PER_CLASS)
            support_dataloader = tg.get_data_loader(task,num_per_class=SUPPORT_NUM_PER_CLASS,split="train",shuffle=False,rotation=degrees)
            query_dataloader = tg.get_data_loader(task,num_per_class=TEST_NUM_PER_CLASS,split="test",shuffle=True,rotation=degrees)

            support_images,support_labels = next(iter(support_dataloader))
            query_images,query_labels = next(iter(query_dataloader))

            relations = _relation_scores(feature_encoder, relation_network,
                                         support_images, query_images)
            _,predict_labels = torch.max(relations,1)
            total_rewards += (predict_labels.cpu() == query_labels).sum().item()

    return total_rewards/1.0/CLASS_NUM/TEST_NUM_PER_CLASS/TEST_EPISODE


def main():
    """Train the encoder and similarity network episodically, checkpointing the best model."""
    # Step 1: init data folders
    print("init data folders")
    # init character folders for dataset construction
    metatrain_character_folders,metaquery_character_folders = tg.omniglot_character_folders()

    # Step 2: init neural networks
    print("init neural networks")

    feature_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU)
    relation_network = models.SimilarityNetwork(FEATURE_DIM,RELATION_DIM).apply(weights_init).cuda(GPU)

    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim,step_size=50000,gamma=0.1)
    relation_network_optim = torch.optim.Adam(relation_network.parameters(),lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim,step_size=50000,gamma=0.1)

    suffix = str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl"
    encoder_path = METHOD + "/omniglot_feature_encoder_" + suffix
    similarity_path = METHOD + "/omniglot_similarity_network_" + suffix

    # Resume from an earlier run when checkpoints exist.
    if os.path.exists(encoder_path):
        feature_encoder.load_state_dict(torch.load(encoder_path))
        print("load feature encoder success")
    if os.path.exists(similarity_path):
        relation_network.load_state_dict(torch.load(similarity_path))
        print("load similarity network success")
    if not os.path.exists(METHOD):
        os.makedirs(METHOD)

    # Step 3: build graph
    print("Training...")

    best_accuracy = 0.0
    mse = nn.MSELoss().cuda(GPU)  # stateless, so built once instead of per episode

    for episode in range(EPISODE):

        feature_encoder_scheduler.step(episode)
        relation_network_scheduler.step(episode)

        # init dataset
        # support_dataloader is to obtain previous supports for compare
        # query_dataloader is to query supports for training
        degrees = random.choice([0,90,180,270])
        task = tg.OmniglotTask(metatrain_character_folders,CLASS_NUM,SUPPORT_NUM_PER_CLASS,QUERY_NUM_PER_CLASS)
        support_dataloader = tg.get_data_loader(task,num_per_class=SUPPORT_NUM_PER_CLASS,split="train",shuffle=False,rotation=degrees)
        query_dataloader = tg.get_data_loader(task,num_per_class=QUERY_NUM_PER_CLASS,split="test",shuffle=True,rotation=degrees)

        # next(iter(...)) works on every DataLoader version; the iterator's
        # .next() alias does not exist in later releases.
        supports,support_labels = next(iter(support_dataloader))
        queries,query_labels = next(iter(query_dataloader))

        # relation scores for the QUERY_NUM_PER_CLASS*CLASS_NUM x CLASS_NUM pairs
        relations = _relation_scores(feature_encoder, relation_network, supports, queries)

        # regress the scores onto one-hot class targets
        one_hot_labels = torch.zeros(QUERY_NUM_PER_CLASS*CLASS_NUM, CLASS_NUM).scatter_(1, query_labels.view(-1,1), 1).cuda(GPU)
        loss = mse(relations,one_hot_labels)

        # training
        feature_encoder.zero_grad()
        relation_network.zero_grad()

        loss.backward()

        feature_encoder_optim.step()
        relation_network_optim.step()

        if (episode+1)%100 == 0:
            # loss is a 0-dim tensor: .item() is the supported accessor,
            # indexing it with [0] fails on releases after 0.4.
            print("episode:",episode+1,"loss",loss.item())

        if (episode)%2500 == 0:
            # query
            print("Testing...")
            test_accuracy = _evaluate(feature_encoder, relation_network,
                                      metaquery_character_folders)

            print("query accuracy:",test_accuracy)
            print("best accuracy:",best_accuracy)

            if test_accuracy > best_accuracy:
                # save networks
                torch.save(feature_encoder.state_dict(),encoder_path)
                torch.save(relation_network.state_dict(),similarity_path)
                print("save networks for episode:",episode)
                best_accuracy = test_accuracy


if __name__ == '__main__':
    main()
45 | # For meta training, we use all 20 samples without valid set (empty here). 46 | # For meta testing, we use 1 or 5 shot samples for training, while using the same number of samples for validation. 47 | # If set num_samples = 20 and chracter_folders = metatrain_character_folders, we generate tasks for meta training 48 | # If set num_samples = 1 or 5 and chracter_folders = metatest_chracter_folders, we generate tasks for meta testing 49 | def __init__(self, character_folders, num_classes, train_num,test_num): 50 | 51 | self.character_folders = character_folders 52 | self.num_classes = num_classes 53 | self.train_num = train_num 54 | self.test_num = test_num 55 | 56 | class_folders = random.sample(self.character_folders,self.num_classes) 57 | labels = np.array(range(len(class_folders))) 58 | labels = dict(zip(class_folders, labels)) 59 | samples = dict() 60 | 61 | self.train_roots = [] 62 | self.test_roots = [] 63 | for c in class_folders: 64 | 65 | temp = [os.path.join(c, x) for x in os.listdir(c)] 66 | samples[c] = random.sample(temp, len(temp)) 67 | 68 | self.train_roots += samples[c][:train_num] 69 | self.test_roots += samples[c][train_num:train_num+test_num] 70 | 71 | self.train_labels = [labels[self.get_class(x)] for x in self.train_roots] 72 | self.test_labels = [labels[self.get_class(x)] for x in self.test_roots] 73 | 74 | def get_class(self, sample): 75 | return os.path.join(*sample.split('/')[:-1]) 76 | 77 | 78 | class FewShotDataset(Dataset): 79 | 80 | def __init__(self, task, split='train', transform=None, target_transform=None): 81 | self.transform = transform # Torch operations on the input image 82 | self.target_transform = target_transform 83 | self.task = task 84 | self.split = split 85 | self.image_roots = self.task.train_roots if self.split == 'train' else self.task.test_roots 86 | self.labels = self.task.train_labels if self.split == 'train' else self.task.test_labels 87 | 88 | def __len__(self): 89 | return len(self.image_roots) 90 | 91 | def 
class ClassBalancedSampler(Sampler):
    """Pick ``num_per_class`` indices from each of ``num_cl`` classes.

    The dataset is assumed to be laid out class by class, every class
    occupying a contiguous run of ``num_inst`` items, so item ``i`` of
    class ``j`` lives at index ``i + j * num_inst``.

    Args:
        num_per_class: how many indices to draw from every class.
        num_cl: number of classes in the dataset.
        num_inst: number of items stored per class.
        shuffle: if True, draw random items per class and shuffle the final
            index list; otherwise take the first ``num_per_class`` items of
            each class in order.
    """

    def __init__(self, num_per_class, num_cl, num_inst, shuffle=True):
        self.num_per_class = num_per_class
        self.num_cl = num_cl
        self.num_inst = num_inst
        self.shuffle = shuffle

    def __iter__(self):
        # Returns one flat list of indices covering all classes.
        batch = []
        for j in range(self.num_cl):
            if self.shuffle:
                # .tolist() yields plain ints rather than 0-dim tensors.
                picks = torch.randperm(self.num_inst)[:self.num_per_class].tolist()
            else:
                picks = list(range(self.num_inst))[:self.num_per_class]
            batch.extend(i + j * self.num_inst for i in picks)

        if self.shuffle:
            random.shuffle(batch)
        return iter(batch)

    def __len__(self):
        # Number of indices __iter__ actually produces (was hard-coded to 1).
        return self.num_cl * min(self.num_per_class, self.num_inst)
class SimilarityNetwork(nn.Module):
    """Score 2-channel input maps with four conv stages and two linear layers.

    Every conv stage is Conv(3x3, no padding) -> BatchNorm -> ReLU ->
    MaxPool(2) with 64 output channels, so a 64x64 input shrinks to 2x2.
    The flattened result goes through ``fc1``, a ReLU, ``fc2`` and a sigmoid,
    giving one score in (0, 1) per input.

    Args:
        input_size: ``fc1`` expects ``input_size * 4`` features. The conv
            stages are hard-wired to 64 channels, so with 64x64 inputs this
            must be 64.
        hidden_size: width of the hidden fully connected layer.
    """
    def __init__(self, input_size, hidden_size):
        super(SimilarityNetwork, self).__init__()
        self.layer1 = nn.Sequential(
                        nn.Conv2d(2, 64, kernel_size=3, padding=0),
                        nn.BatchNorm2d(64, momentum=1, affine=True),
                        nn.ReLU(),
                        nn.MaxPool2d(2))  # Nx64x31x31
        self.layer2 = nn.Sequential(
                        nn.Conv2d(64, 64, kernel_size=3, padding=0),
                        nn.BatchNorm2d(64, momentum=1, affine=True),
                        nn.ReLU(),
                        nn.MaxPool2d(2))  # Nx64x14x14
        self.layer3 = nn.Sequential(
                        nn.Conv2d(64, 64, kernel_size=3, padding=0),
                        nn.BatchNorm2d(64, momentum=1, affine=True),
                        nn.ReLU(),
                        nn.MaxPool2d(2))  # Nx64x6x6
        self.layer4 = nn.Sequential(
                        nn.Conv2d(64, 64, kernel_size=3, padding=0),
                        nn.BatchNorm2d(64, momentum=1, affine=True),
                        nn.ReLU(),
                        nn.MaxPool2d(2))  # Nx64x2x2
        self.fc1 = nn.Linear(input_size*4, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Return an (N, 1) tensor of sigmoid scores for an (N, 2, H, W) input."""
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        # torch.sigmoid replaces the deprecated F.sigmoid; same values.
        out = torch.sigmoid(self.fc2(out))
        return out
def mean_confidence_interval(data, confidence=0.95):
    """Return the sample mean and the half-width of its confidence interval.

    Args:
        data: sequence of numbers.
        confidence: two-sided confidence level (0.95 by default).

    Returns:
        ``(m, h)``: ``m`` is the mean of ``data`` and ``h`` is the standard
        error of the mean times the Student-t quantile for ``len(data) - 1``
        degrees of freedom, i.e. the interval is ``m - h .. m + h``.
    """
    a = 1.0*np.array(data)
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    # Use the public ppf; the private _ppf skips argument checking and is
    # not part of scipy's supported API.
    h = se * scipy.stats.t.ppf((1+confidence)/2., n-1)
    return m, h
def _second_order_features(features):
    """Map (N, C, HW) feature maps to (N, 1, C, C) power-normalised second-order matrices."""
    pooled = []
    for d in range(features.size()[0]):
        s = features[d, :, :].squeeze(0)
        # NOTE(review): repeat(1, HW).view(...) does not broadcast the
        # per-row mean along each row; it only matters when LAMBDA != 0.
        # Kept as-is so existing checkpoints see the same computation.
        s = s - LAMBDA * s.mean(1).repeat(1, s.size()[1]).view(s.size())
        s = (1.0 / features.size()[2]) * s.mm(s.transpose(0, 1))
        pooled.append(power_norm(s / s.trace(), SIGMA))
    return torch.stack(pooled).unsqueeze(1)


def _relation_scores(relation_network, support_maps, query_maps):
    """Score every (query, support) pair; returns a (num_query, num_support) tensor."""
    num_query = query_maps.size(0)
    num_support = support_maps.size(0)
    support_ext = support_maps.unsqueeze(0).repeat(num_query, 1, 1, 1, 1)
    query_ext = query_maps.unsqueeze(0).repeat(num_support, 1, 1, 1, 1)
    query_ext = torch.transpose(query_ext, 0, 1)
    pairs = torch.cat((support_ext, query_ext), 2).view(-1, 2, 64, 64)
    return relation_network(pairs).view(-1, num_support)


def _evaluate_tasks(feature_encoder, relation_network, folders):
    """Return the accuracies of TEST_EPISODE tasks sampled from ``folders``."""
    accuracies = []
    num_per_class = 2
    with torch.no_grad():
        for _ in range(TEST_EPISODE):
            total_rewards = 0
            counter = 0
            task = tg.MiniImagenetTask(folders, CLASS_NUM, 1, num_per_class)
            support_dataloader = tg.get_mini_imagenet_data_loader(task, num_per_class=1, split="train", shuffle=False)
            query_dataloader = tg.get_mini_imagenet_data_loader(task, num_per_class=num_per_class, split="query", shuffle=True)
            # next(iter(...)) works on every Python 3 / PyTorch version;
            # the iterator's .next() method does not exist in newer releases.
            support_images, support_labels = next(iter(support_dataloader))
            for query_images, query_labels in query_dataloader:
                query_size = query_labels.shape[0]
                # calculate features
                support_features = feature_encoder(support_images.cuda(GPU)).view(CLASS_NUM, SUPPORT_NUM_PER_CLASS, 64, 19**2).sum(1)
                query_features = feature_encoder(query_images.cuda(GPU)).view(num_per_class*CLASS_NUM, 64, 19**2)

                relations = _relation_scores(
                    relation_network,
                    _second_order_features(support_features),
                    _second_order_features(query_features))

                _, predict_labels = torch.max(relations.data, 1)
                # one vectorised comparison instead of a per-element loop
                total_rewards += int((predict_labels == query_labels.cuda(GPU)).sum().item())
                counter += query_size
            accuracies.append(total_rewards/1.0/counter)
    return accuracies


def main():
    """Evaluate the saved SoSN networks on tasks drawn from the test folders.

    Builds the feature encoder and similarity network, loads their
    checkpoints from the METHOD directory when the files exist, then runs
    EPISODE evaluation rounds. Each round scores TEST_EPISODE tasks and
    prints the mean accuracy with its confidence half-width, along with the
    best round seen so far.
    """
    metatrain_folders, metaquery_folders = tg.mini_imagenet_folders()

    print("init neural networks")

    feature_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU)
    relation_network = models.SimilarityNetwork(FEATURE_DIM, RELATION_DIM).apply(weights_init).cuda(GPU)

    suffix = str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl"
    feature_encoder_path = METHOD + "/feature_encoder_" + suffix
    relation_network_path = METHOD + "/relation_network_" + suffix

    if os.path.exists(feature_encoder_path):
        feature_encoder.load_state_dict(torch.load(feature_encoder_path))
        print("load feature encoder success")
    if os.path.exists(relation_network_path):
        relation_network.load_state_dict(torch.load(relation_network_path))
        print("load relation network success")
    if not os.path.exists(METHOD):
        # os.makedirs avoids spawning a shell for a simple mkdir
        os.makedirs(METHOD)

    # Step 3: build graph
    print("Training...")

    best_accuracy = 0.0
    best_h = 0.0

    for episode in range(EPISODE):
        print("Testing...")
        accuracies = _evaluate_tasks(feature_encoder, relation_network, metaquery_folders)

        test_accuracy, h = mean_confidence_interval(accuracies)

        print("Test accuracy:", test_accuracy, "h:", h)
        print("Best accuracy: ", best_accuracy, "h:", best_h)

        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            best_h = h
def power_norm(x, SIGMA):
    """Element-wise ``2 / (1 + exp(-SIGMA * x)) - 1``.

    Maps 0 to 0 and squashes values into (-1, 1); ``SIGMA`` sets the slope.
    """
    exponent = -SIGMA*x
    denominator = 1 + torch.exp(exponent)
    return 2/denominator - 1
def _second_order_features(features):
    """Map (N, C, HW) feature maps to (N, 1, C, C) power-normalised second-order matrices."""
    pooled = []
    for d in range(features.size()[0]):
        s = features[d, :, :].squeeze(0)
        # NOTE(review): repeat(1, HW).view(...) does not broadcast the
        # per-row mean along each row; it only matters when LAMBDA != 0.
        # Kept as-is so existing checkpoints see the same computation.
        s = s - LAMBDA * s.mean(1).repeat(1, s.size()[1]).view(s.size())
        s = (1.0 / features.size()[2]) * s.mm(s.transpose(0, 1))
        pooled.append(power_norm(s / s.trace(), SIGMA))
    return torch.stack(pooled).unsqueeze(1)


def _relation_scores(relation_network, support_maps, query_maps):
    """Score every (query, support) pair; returns a (num_query, num_support) tensor."""
    num_query = query_maps.size(0)
    num_support = support_maps.size(0)
    support_ext = support_maps.unsqueeze(0).repeat(num_query, 1, 1, 1, 1)
    query_ext = query_maps.unsqueeze(0).repeat(num_support, 1, 1, 1, 1)
    query_ext = torch.transpose(query_ext, 0, 1)
    pairs = torch.cat((support_ext, query_ext), 2).view(-1, 2, 64, 64)
    return relation_network(pairs).view(-1, num_support)


def _evaluate_tasks(feature_encoder, relation_network, folders):
    """Return the accuracies of TEST_EPISODE tasks sampled from ``folders``."""
    accuracies = []
    num_per_class = 2
    with torch.no_grad():
        for _ in range(TEST_EPISODE):
            total_rewards = 0
            counter = 0
            task = tg.MiniImagenetTask(folders, CLASS_NUM, 1, num_per_class)
            support_dataloader = tg.get_mini_imagenet_data_loader(task, num_per_class=1, split="train", shuffle=False)
            query_dataloader = tg.get_mini_imagenet_data_loader(task, num_per_class=num_per_class, split="query", shuffle=True)
            support_images, support_labels = next(iter(support_dataloader))
            for query_images, query_labels in query_dataloader:
                query_size = query_labels.shape[0]
                # calculate features
                support_features = feature_encoder(support_images.cuda(GPU)).view(CLASS_NUM, SUPPORT_NUM_PER_CLASS, 64, 19**2).sum(1)
                query_features = feature_encoder(query_images.cuda(GPU)).view(num_per_class*CLASS_NUM, 64, 19**2)

                relations = _relation_scores(
                    relation_network,
                    _second_order_features(support_features),
                    _second_order_features(query_features))

                _, predict_labels = torch.max(relations.data, 1)
                # one vectorised comparison instead of a per-element loop
                total_rewards += int((predict_labels == query_labels.cuda(GPU)).sum().item())
                counter += query_size
            accuracies.append(total_rewards/1.0/counter)
    return accuracies


def main():
    """Train the SoSN feature encoder and similarity network episodically.

    Each episode samples one task from the training folders, scores every
    query image against every support class and minimises the MSE between
    those scores and the one-hot query labels. Every 500 episodes the
    networks are evaluated on TEST_EPISODE tasks from the test folders, and
    both state dicts are saved under METHOD whenever the mean accuracy
    improves.
    """
    # Step 1: init data folders
    print("init data folders")
    # init character folders for dataset construction
    metatrain_folders, metaquery_folders = tg.mini_imagenet_folders()

    # Step 2: init neural networks
    print("init neural networks")

    feature_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU)
    relation_network = models.SimilarityNetwork(FEATURE_DIM, RELATION_DIM).apply(weights_init).cuda(GPU)

    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(), lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim, step_size=50000, gamma=0.5)
    relation_network_optim = torch.optim.Adam(relation_network.parameters(), lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim, step_size=50000, gamma=0.5)

    suffix = str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl"
    feature_encoder_path = METHOD + "/feature_encoder_" + suffix
    relation_network_path = METHOD + "/relation_network_" + suffix

    if os.path.exists(feature_encoder_path):
        feature_encoder.load_state_dict(torch.load(feature_encoder_path))
        print("load feature encoder success")
    if os.path.exists(relation_network_path):
        relation_network.load_state_dict(torch.load(relation_network_path))
        print("load relation network success")
    if not os.path.exists(METHOD):
        # os.makedirs avoids spawning a shell for a simple mkdir
        os.makedirs(METHOD)

    # Step 3: build graph
    print("Training...")

    best_accuracy = 0.0
    best_h = 0.0

    # The loss module is stateless, so build it once instead of per episode.
    mse = nn.MSELoss().cuda(GPU)

    for episode in range(EPISODE):
        feature_encoder_scheduler.step(episode)
        relation_network_scheduler.step(episode)

        # init dataset
        # support_dataloader is to obtain previous supports for compare
        # query_dataloader is to query supports for training
        task = tg.MiniImagenetTask(metatrain_folders, CLASS_NUM, SUPPORT_NUM_PER_CLASS, QUERY_NUM_PER_CLASS)
        support_dataloader = tg.get_mini_imagenet_data_loader(task, num_per_class=SUPPORT_NUM_PER_CLASS, split="train", shuffle=False)
        query_dataloader = tg.get_mini_imagenet_data_loader(task, num_per_class=QUERY_NUM_PER_CLASS, split="test", shuffle=True)

        # support datas
        supports, support_labels = next(iter(support_dataloader))
        queries, query_labels = next(iter(query_dataloader))

        # calculate features
        support_features = feature_encoder(supports.cuda(GPU)).view(CLASS_NUM, SUPPORT_NUM_PER_CLASS, 64, 19**2).sum(1)
        query_features = feature_encoder(queries.cuda(GPU)).view(QUERY_NUM_PER_CLASS*CLASS_NUM, 64, 19**2)

        # second-order features, then relation scores for every pair
        relations = _relation_scores(
            relation_network,
            _second_order_features(support_features),
            _second_order_features(query_features))

        # define loss function
        one_hot_labels = torch.zeros(QUERY_NUM_PER_CLASS*CLASS_NUM, CLASS_NUM).scatter_(1, query_labels.view(-1, 1), 1).cuda(GPU)
        loss = mse(relations, one_hot_labels)

        # updating network parameters with their gradients
        feature_encoder.zero_grad()
        relation_network.zero_grad()

        loss.backward()

        feature_encoder_optim.step()
        relation_network_optim.step()

        if (episode+1) % 100 == 0:
            # .item() replaces loss.data[0], which raises on 0-dim tensors
            # in PyTorch >= 0.5.
            print("episode:", episode+1, "loss", loss.item())

        if episode % 500 == 0:
            # query
            print("Testing...")

            accuracies = _evaluate_tasks(feature_encoder, relation_network, metaquery_folders)

            test_accuracy, h = mean_confidence_interval(accuracies)

            print("Test accuracy:", test_accuracy, "h:", h)
            print("Best accuracy: ", best_accuracy, "h:", best_h)

            if test_accuracy > best_accuracy:
                # save networks
                torch.save(feature_encoder.state_dict(), feature_encoder_path)
                torch.save(relation_network.state_dict(), relation_network_path)
                print("save networks for episode:", episode)

                best_accuracy = test_accuracy
                best_h = h
def get_class(self, sample):
    """Return the class-folder path that contains the image path ``sample``.

    The result is used as the key into the folder -> label mapping, so it
    must equal the folder string the image path was built from.
    """
    # os.path.dirname keeps a leading '/' and the original separators;
    # splitting on '/' and re-joining dropped the root of absolute paths,
    # which made the label lookup fail with KeyError.
    return os.path.dirname(sample)
class ClassBalancedSampler(Sampler):
    """Draw ``num_per_class`` indices from each of ``num_cl`` class pools.

    The dataset is assumed to be grouped by class with ``num_inst``
    consecutive items per class, so class ``j`` owns the index range
    ``[j * num_inst, (j + 1) * num_inst)``.

    Args:
        num_per_class: number of indices drawn from every class.
        num_cl: number of classes (pools).
        num_inst: number of items stored per class.
        shuffle: when true, draw random items per class and shuffle the
            resulting index list; otherwise take the first
            ``num_per_class`` items of every class, in class order.
    """

    def __init__(self, num_per_class, num_cl, num_inst, shuffle=True):
        self.num_per_class = num_per_class
        self.num_cl = num_cl
        self.num_inst = num_inst
        self.shuffle = shuffle

    def __iter__(self):
        # Return a single flat list of indices, assuming that items are
        # grouped by class in the underlying dataset.
        batch = []
        for j in range(self.num_cl):
            if self.shuffle:
                # .tolist() gives plain ints; iterating the tensor directly
                # would put 0-dim tensors into the index list.
                offsets = torch.randperm(self.num_inst)[:self.num_per_class].tolist()
            else:
                offsets = range(self.num_inst)[:self.num_per_class]
            batch.extend(i + j * self.num_inst for i in offsets)

        if self.shuffle:
            random.shuffle(batch)
        return iter(batch)

    def __len__(self):
        # Reported length is fixed at 1, not the number of indices yielded.
        return 1
def mini_imagenet_folders(train_folder='../../datas/museum/gls/train',
                          test_folder='../../datas/museum/gls/test'):
    """Return the shuffled class-folder lists for meta-training and meta-testing.

    Args:
        train_folder: directory whose sub-directories are the training classes.
        test_folder: directory whose sub-directories are the test classes.

    Returns:
        ``(metatrain_folders, metatest_folders)``: two lists of sub-directory
        paths; plain files inside the two roots are ignored.

    Note:
        Re-seeds Python's global ``random`` generator with 1 before shuffling.
    """
    def class_dirs(root):
        # os.listdir order is arbitrary; sort first so the seeded shuffle
        # below gives the same order on every machine.
        return [os.path.join(root, label)
                for label in sorted(os.listdir(root))
                if os.path.isdir(os.path.join(root, label))]

    metatrain_folders = class_dirs(train_folder)
    metatest_folders = class_dirs(test_folder)

    random.seed(1)
    random.shuffle(metatrain_folders)
    random.shuffle(metatest_folders)

    return metatrain_folders, metatest_folders
class FewShotDataset(Dataset):
    """Abstract dataset over one split of a few-shot task.

    ``split == 'train'`` selects the task's ``train_roots``/``train_labels``;
    any other value selects ``test_roots``/``test_labels``. Subclasses
    provide ``__getitem__``.
    """

    def __init__(self, task, split='train', transform=None, target_transform=None):
        self.transform = transform  # Torch operations on the input image
        self.target_transform = target_transform
        self.task = task
        self.split = split
        if self.split == 'train':
            self.image_roots = self.task.train_roots
        else:
            self.image_roots = self.task.test_roots
        if self.split == 'train':
            self.labels = self.task.train_labels
        else:
            self.labels = self.task.test_labels

    def __len__(self):
        # number of image paths in the selected split
        return len(self.image_roots)

    def __getitem__(self, idx):
        raise NotImplementedError("This is an abstract class. Subclass this class for your particular dataset.")
sampler = ClassBalancedSampler(num_per_class, task.num_classes, task.test_num,shuffle=shuffle) 146 | 147 | loader = DataLoader(dataset, batch_size=num_per_class*task.num_classes, sampler=sampler) 148 | 149 | return loader 150 | 151 | --------------------------------------------------------------------------------