├── README.md
├── datas
├── miniImagenet
│ ├── create_miniImagenet.py
│ ├── proc_images.py
│ ├── test.csv
│ ├── train.csv
│ ├── trainval.csv
│ └── val.csv
└── omniglot_28x28.zip
├── docs
└── sosn.png
├── miniimagenet
├── miniimagenet_test_few_shot_SoSN.py
├── miniimagenet_train_few_shot_SoSN.py
├── models.py
├── task_generator.py
└── task_generator_test.py
├── omniglot
├── models.py
├── omniglot_test_few_shot_SoSN.py
├── omniglot_train_few_shot_SoSN.py
└── task_generator.py
└── openmic
└── p1-p2
├── models.py
├── openmic_test_few_shot_SoSN.py
├── openmic_train_few_shot_SoSN.py
├── task_generator.py
└── task_generator.py~
/README.md:
--------------------------------------------------------------------------------
1 | # Power Normalizing Second-order Similarity Network for Few-shot Learning
2 |
3 | PyTorch implementation of the IEEE WACV 2019 paper "[Power Normalizing Second-order Similarity Network for Few-shot Learning](https://arxiv.org/pdf/1811.04167.pdf)".
4 | This is based on the code of Relation Net.
5 |
6 | Download formatted miniImagenet:
7 | https://drive.google.com/file/d/1QhUs2uwEbVqCVig6B6cQmI0WTdYy2C9t/view?usp=sharing
8 |
9 | Download the pre-processed OpenMIC dataset via the following link (a request form is required):
10 | http://users.cecs.anu.edu.au/~koniusz/openmic-dataset/
11 |
12 | Decompress the downloaded datasets into '/datas'.
13 | If you have any problems with the code, please contact hongguang.zhang@anu.edu.au.
14 |
15 | 
16 |
17 | __Requires.__
18 | ```
19 | pytorch-0.4.1
20 | numpy
21 | scipy
22 | ```
23 |
24 | __For miniImagenet training and testing, run the following commands.__
25 |
26 | ```
27 | cd ./miniimagenet
28 | python miniimagenet_train_few_shot_SoSN.py -w 5 -s 1 -sigma 100
29 | python miniimagenet_test_few_shot_SoSN.py -w 5 -s 1 -sigma 100
30 | ```
31 |
32 | ## Citation
33 | ```
34 | @inproceedings{zhang2019power,
35 | title={Power normalizing second-order similarity network for few-shot learning},
36 | author={Zhang, Hongguang and Koniusz, Piotr},
37 | booktitle={2019 IEEE Winter Conference on Applications of Computer Vision (WACV)},
38 | pages={1185--1193},
39 | year={2019},
40 | organization={IEEE}
41 | }
42 | ```
43 |
--------------------------------------------------------------------------------
/datas/miniImagenet/create_miniImagenet.py:
--------------------------------------------------------------------------------
1 | '''
2 | This code creates the MiniImagenet dataset. Following the partitions given
3 | by Sachin Ravi and Hugo Larochelle in
4 | https://github.com/twitter/meta-learning-lstm/tree/master/data/miniImagenet
5 | '''
6 |
7 | import numpy as np
8 | import csv
9 | import glob, os
10 | from shutil import copyfile
11 | #import cv2
12 | from tqdm import tqdm
13 |
# Source ImageNet training images and destination miniImagenet tree.
pathImageNet = '/flush1/zha230/ILSVRC/Data/CLS-LOC/train'
pathminiImageNet = '/data/zha230/LearningToCompare_FSL-master/datas/miniImagenet/'
pathImages = os.path.join(pathminiImageNet,'images/')
filesCSVSachinRavi = [os.path.join(pathminiImageNet,'train.csv'),
                      os.path.join(pathminiImageNet,'val.csv'),
                      os.path.join(pathminiImageNet,'test.csv')]

# Check if the folder of images exist. If not create it.
if not os.path.exists(pathImages):
    os.makedirs(pathImages)

for filename in filesCSVSachinRavi:
    # Group the target file names (column 0) by class id (column 1).
    images = {}
    with open(filename) as csvfile:
        csv_reader = csv.reader(csvfile, delimiter=',')
        next(csv_reader, None)  # skip the header row
        print('Reading IDs....')
        for row in tqdm(csv_reader):
            images.setdefault(row[1], []).append(row[0])

    # dict views are not subscriptable on Python 3, so take the first key
    # through an iterator (this also works on Python 2).
    first_class = next(iter(images))
    print(first_class)
    print((os.path.join(pathImageNet, first_class)))

    print('Writing photos....')
    for c in tqdm(images.keys()):  # Iterate over all the classes
        lst_files = glob.glob(pathImageNet + "/*" + c + "*")
        # Sort the source files by the number between the last '_' and the
        # last '.' of their name.
        lst_index = [int(i[i.rfind('_')+1:i.rfind('.')]) for i in lst_files]
        index_sorted = sorted(range(len(lst_index)), key=lst_index.__getitem__)

        # The four characters before the first '.' of each target name give
        # the 1-based rank of the source image within its class.
        index_selected = [int(i[i.index('.') - 4:i.index('.')]) for i in images[c]]
        selected_images = np.array(index_sorted)[np.array(index_selected) - 1]
        for target_name, source_index in zip(images[c], selected_images):
            # Images are copied unchanged; resizing to 84x84 is not done here.
            copyfile(os.path.join(pathImageNet, lst_files[source_index]),
                     os.path.join(pathImages, target_name))
58 |
--------------------------------------------------------------------------------
/datas/miniImagenet/proc_images.py:
--------------------------------------------------------------------------------
1 | """
2 | code copied from https://github.com/cbfinn/maml/blob/master/data/miniImagenet/proc_images.py
3 | Script for converting from csv file datafiles to a directory for each image (which is how it is loaded by MAML code)
4 |
5 | Acquire miniImagenet from Ravi & Larochelle '17, along with the train, val, and test csv files. Put the
6 | csv files in the miniImagenet directory and put the images in the directory 'miniImagenet/images/'.
7 | Then run this script from the miniImagenet directory:
8 | cd data/miniImagenet/
9 | python proc_images.py
10 | """
11 | from __future__ import print_function
12 | import csv
13 | import glob
14 | import os
15 |
16 | from PIL import Image
17 |
import shutil  # used instead of shelling out to `cp`

path_to_images = 'images/'

all_images = glob.glob(path_to_images + '*')

# Resize every image in place to 84x84.
for i, image_file in enumerate(all_images):
    im = Image.open(image_file)
    im = im.resize((84, 84), resample=Image.LANCZOS)
    im.save(image_file)
    if i % 500 == 0:
        print(i)

# Put in correct directory: <split>/<label>/<image_name>
for datatype in ['train', 'val', 'test']:
    if not os.path.isdir(datatype):
        os.makedirs(datatype)

    with open(datatype + '.csv', 'r') as f:
        reader = csv.reader(f, delimiter=',')
        next(reader, None)  # skip the headers
        last_label = ''
        for row in reader:
            image_name = row[0]
            label = row[1]
            if label != last_label:
                # A new label starts: make sure its directory exists.
                cur_dir = os.path.join(datatype, label)
                if not os.path.isdir(cur_dir):
                    os.makedirs(cur_dir)
                last_label = label
            # Unlike the former `os.system('cp ...')`, a missing source image
            # now raises instead of being silently skipped.
            shutil.copy(os.path.join(path_to_images, image_name), cur_dir)
48 |
--------------------------------------------------------------------------------
/datas/omniglot_28x28.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HongguangZhang/SoSN-wacv19-master/f108b96007981388ca9f7af86ee7f3226adad874/datas/omniglot_28x28.zip
--------------------------------------------------------------------------------
/docs/sosn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HongguangZhang/SoSN-wacv19-master/f108b96007981388ca9f7af86ee7f3226adad874/docs/sosn.png
--------------------------------------------------------------------------------
/miniimagenet/miniimagenet_test_few_shot_SoSN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | from torch.optim.lr_scheduler import StepLR
6 | import numpy as np
7 | import task_generator_test as tg
8 | import os
9 | import math
10 | import argparse
11 | import scipy as sp
12 | import scipy.stats
13 | import time
14 | import models
15 |
# Command-line options as (short flag, long flag, type, default).
_CLI_OPTIONS = (
    ("-f", "--feature_dim", int, 64),
    ("-r", "--relation_dim", int, 8),
    ("-w", "--class_num", int, 5),
    ("-s", "--support_num_per_class", int, 5),
    ("-b", "--query_num_per_class", int, 15),
    ("-e", "--episode", int, 100),
    ("-t", "--test_episode", int, 600),
    ("-l", "--learning_rate", float, 0.001),
    ("-g", "--gpu", int, 0),
    ("-u", "--hidden_unit", int, 10),
    # NOTE(review): the default is the int 100, so str(args.sigma) is "100"
    # when the flag is omitted but "100.0" when "-sigma 100" is passed; the
    # METHOD directory name below differs between the two cases.
    ("-sigma", "--sigma", float, 100),
)

parser = argparse.ArgumentParser(description="One Shot Visual Recognition")
for _short_flag, _long_flag, _value_type, _default_value in _CLI_OPTIONS:
    parser.add_argument(_short_flag, _long_flag, type=_value_type, default=_default_value)
args = parser.parse_args()


# Hyper Parameters
SIGMA = args.sigma
METHOD = "SoSN_Logit" + str(SIGMA) + "_Models"  # checkpoint directory
FEATURE_DIM = args.feature_dim
RELATION_DIM = args.relation_dim
CLASS_NUM = args.class_num
SUPPORT_NUM_PER_CLASS = args.support_num_per_class
QUERY_NUM_PER_CLASS = args.query_num_per_class
EPISODE = args.episode
TEST_EPISODE = args.test_episode
LEARNING_RATE = args.learning_rate
GPU = args.gpu
HIDDEN_UNIT = args.hidden_unit
44 |
def power_norm(x, SIGMA):
    """Apply the element-wise map 2 / (1 + exp(-SIGMA * x)) - 1 to tensor x.

    The result lies in [-1, 1] and is 0 where x is 0.
    """
    decay = torch.exp(-SIGMA * x)
    return 2 / (1 + decay) - 1
48 |
def mean_confidence_interval(data, confidence=0.95):
    """Return the mean of *data* and the half-width of its confidence interval.

    Args:
        data: sequence of numbers (per-episode accuracies in this script).
        confidence: two-sided confidence level, 0.95 by default.

    Returns:
        (mean, half_width), where half_width is the standard error of the
        mean times the Student-t quantile with len(data) - 1 degrees of
        freedom.
    """
    samples = 1.0 * np.array(data)
    n = len(samples)
    mean = np.mean(samples)
    std_err = scipy.stats.sem(samples)
    # Public `ppf` rather than the private `_ppf`, which skips argument checks.
    half_width = std_err * scipy.stats.t.ppf((1 + confidence) / 2., n - 1)
    return mean, half_width
56 |
def weights_init(m):
    """Initialise one module in place; meant for use with ``nn.Module.apply``.

    The branch is chosen from the module's class name:
      * names containing 'Conv': weights ~ N(0, sqrt(2 / n)) with
        n = kernel_h * kernel_w * out_channels, bias set to 0;
      * names containing 'BatchNorm': weight set to 1, bias set to 0;
      * names containing 'Linear': weights ~ N(0, 0.01), bias set to 1.
    Other modules are left untouched.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
        if m.bias is not None:
            m.bias.data.zero_()
    elif 'BatchNorm' in classname:
        # The test used to look for 'queryNorm', which matches no module
        # class, so this branch never ran.
        if m.weight is not None:
            m.weight.data.fill_(1)
        if m.bias is not None:
            m.bias.data.zero_()
    elif 'Linear' in classname:
        m.weight.data.normal_(0, 0.01)
        if m.bias is not None:
            # Fill in place so the bias keeps its device and dtype.
            m.bias.data.fill_(1)
71 |
def _second_order_pool(features):
    """Turn per-item feature maps into power-normalised second-order matrices.

    features is indexed [item, channel, location]. For each item the
    channel-by-channel matrix F F^T / num_locations is divided by its trace
    and passed through power_norm. Returns a tensor indexed
    [item, 1, channel, channel].
    """
    num_locations = features.size(2)
    pooled = []
    for item in features:
        cov = (1.0 / num_locations) * item.mm(item.transpose(0, 1))
        pooled.append(power_norm(cov / cov.trace(), SIGMA))
    return torch.stack(pooled).unsqueeze(1)


def _relation_scores(relation_network, H_support, H_query):
    """Score every (query, class) pair; returns a [num_query, num_class] tensor."""
    num_class = H_support.size(0)
    num_query = H_query.size(0)
    side = H_support.size(-1)
    support_ext = H_support.unsqueeze(0).repeat(num_query, 1, 1, 1, 1)
    query_ext = H_query.unsqueeze(0).repeat(num_class, 1, 1, 1, 1)
    query_ext = torch.transpose(query_ext, 0, 1)
    # Stack the support and query matrices as the two input channels.
    relation_pairs = torch.cat((support_ext, query_ext), 2).view(-1, 2, side, side)
    return relation_network(relation_pairs).view(-1, num_class)


def _test_episode_accuracy(feature_encoder, relation_network, metatest_folders):
    """Sample one test task and return the fraction of queries classified correctly."""
    num_per_class = 5
    total_rewards = 0
    counter = 0
    task = tg.MiniImagenetTask(metatest_folders, CLASS_NUM, SUPPORT_NUM_PER_CLASS, 15)
    support_dataloader = tg.get_mini_imagenet_data_loader(task, num_per_class=SUPPORT_NUM_PER_CLASS, split="train", shuffle=False)
    query_dataloader = tg.get_mini_imagenet_data_loader(task, num_per_class=num_per_class, split="test", shuffle=False)

    # next(iter(...)) instead of `.__iter__().next()`, which only exists on
    # old DataLoader iterators.
    support_images, _ = next(iter(support_dataloader))
    for query_images, query_labels in query_dataloader:
        query_size = query_labels.shape[0]

        support_features = feature_encoder(support_images.cuda(GPU))
        # Sum the feature maps of the support images of each class.
        support_features = support_features.view(CLASS_NUM, SUPPORT_NUM_PER_CLASS, FEATURE_DIM, 19*19).sum(1)
        query_features = feature_encoder(query_images.cuda(GPU)).view(query_size, 64, 19*19)

        relations = _relation_scores(relation_network,
                                     _second_order_pool(support_features),
                                     _second_order_pool(query_features))
        _, predict_labels = torch.max(relations.data, 1)

        total_rewards += (predict_labels == query_labels.cuda(GPU)).sum().item()
        counter += query_size

    return total_rewards / 1.0 / counter


def main():
    """Evaluate the saved encoder and similarity network on test tasks.

    Runs EPISODE rounds of TEST_EPISODE tasks each, prints the mean accuracy
    and its confidence half-width per round, and saves the networks whenever
    a round beats the best accuracy seen so far.
    """
    _, metatest_folders = tg.mini_imagenet_folders()

    print("init neural networks")
    feature_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU)
    relation_network = models.SimilarityNetwork(FEATURE_DIM, RELATION_DIM).apply(weights_init).cuda(GPU)

    checkpoint_suffix = str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl"
    encoder_path = str(METHOD + "/miniImagenet_feature_encoder_" + checkpoint_suffix)
    relation_path = str(METHOD + "/miniImagenet_relation_network_" + checkpoint_suffix)

    # Checkpoints are optional: when a file is missing the freshly
    # initialised network is evaluated instead.
    if os.path.exists(encoder_path):
        feature_encoder.load_state_dict(torch.load(encoder_path))
        print("load feature encoder success")
    if os.path.exists(relation_path):
        relation_network.load_state_dict(torch.load(relation_path))
        print("load relation network success")
    if not os.path.exists(METHOD):
        os.makedirs(METHOD)

    # Step 3: build graph
    print("Training...")

    best_accuracy = 0.0
    best_h = 0.0

    for episode in range(EPISODE):
        with torch.no_grad():
            # test
            print("Testing...")
            accuracies = [
                _test_episode_accuracy(feature_encoder, relation_network, metatest_folders)
                for _ in range(TEST_EPISODE)
            ]

            test_accuracy, h = mean_confidence_interval(accuracies)

            print("test accuracy:", test_accuracy, "h:", h)
            print("best accuracy:", best_accuracy, "h:", best_h)

            if test_accuracy > best_accuracy:
                # save networks
                torch.save(feature_encoder.state_dict(), encoder_path)
                torch.save(relation_network.state_dict(), relation_path)
                print("save networks for episode:", episode)

                best_accuracy = test_accuracy
                best_h = h


if __name__ == '__main__':
    main()
169 |
--------------------------------------------------------------------------------
/miniimagenet/miniimagenet_train_few_shot_SoSN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | from torch.optim.lr_scheduler import StepLR
6 | import numpy as np
7 | import task_generator as tg
8 | import os
9 | import math
10 | import argparse
11 | import scipy as sp
12 | import scipy.stats
13 | import time
14 | import models
15 |
# Command-line options as (short flag, long flag, type, default).
_CLI_OPTIONS = (
    ("-f", "--feature_dim", int, 64),
    ("-r", "--relation_dim", int, 8),
    ("-w", "--class_num", int, 5),
    ("-s", "--support_num_per_class", int, 5),
    ("-b", "--query_num_per_class", int, 15),
    ("-e", "--episode", int, 500000),
    ("-t", "--test_episode", int, 600),
    ("-l", "--learning_rate", float, 0.001),
    ("-g", "--gpu", int, 0),
    ("-u", "--hidden_unit", int, 10),
    # NOTE(review): the default is the int 100, so str(args.sigma) is "100"
    # when the flag is omitted but "100.0" when "-sigma 100" is passed; the
    # METHOD directory name below differs between the two cases.
    ("-sigma", "--sigma", float, 100),
)

parser = argparse.ArgumentParser(description="One Shot Visual Recognition")
for _short_flag, _long_flag, _value_type, _default_value in _CLI_OPTIONS:
    parser.add_argument(_short_flag, _long_flag, type=_value_type, default=_default_value)
args = parser.parse_args()


# Hyper Parameters
SIGMA = args.sigma
METHOD = "SoSN_Logit" + str(SIGMA) + "_Models"  # checkpoint directory
FEATURE_DIM = args.feature_dim
RELATION_DIM = args.relation_dim
CLASS_NUM = args.class_num
SUPPORT_NUM_PER_CLASS = args.support_num_per_class
QUERY_NUM_PER_CLASS = args.query_num_per_class
EPISODE = args.episode
TEST_EPISODE = args.test_episode
LEARNING_RATE = args.learning_rate
GPU = args.gpu
HIDDEN_UNIT = args.hidden_unit
44 |
def power_norm(x, SIGMA):
    """Apply the element-wise map 2 / (1 + exp(-SIGMA * x)) - 1 to tensor x.

    The result lies in [-1, 1] and is 0 where x is 0.
    """
    decay = torch.exp(-SIGMA * x)
    return 2 / (1 + decay) - 1
48 |
def mean_confidence_interval(data, confidence=0.95):
    """Return the mean of *data* and the half-width of its confidence interval.

    Args:
        data: sequence of numbers (per-episode accuracies in this script).
        confidence: two-sided confidence level, 0.95 by default.

    Returns:
        (mean, half_width), where half_width is the standard error of the
        mean times the Student-t quantile with len(data) - 1 degrees of
        freedom.
    """
    samples = 1.0 * np.array(data)
    n = len(samples)
    mean = np.mean(samples)
    std_err = scipy.stats.sem(samples)
    # Public `ppf` rather than the private `_ppf`, which skips argument checks.
    half_width = std_err * scipy.stats.t.ppf((1 + confidence) / 2., n - 1)
    return mean, half_width
56 |
def weights_init(m):
    """Initialise one module in place; meant for use with ``nn.Module.apply``.

    The branch is chosen from the module's class name:
      * names containing 'Conv': weights ~ N(0, sqrt(2 / n)) with
        n = kernel_h * kernel_w * out_channels, bias set to 0;
      * names containing 'BatchNorm': weight set to 1, bias set to 0;
      * names containing 'Linear': weights ~ N(0, 0.01), bias set to 1.
    Other modules are left untouched.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
        if m.bias is not None:
            m.bias.data.zero_()
    elif 'BatchNorm' in classname:
        # The test used to look for 'queryNorm', which matches no module
        # class, so this branch never ran.
        if m.weight is not None:
            m.weight.data.fill_(1)
        if m.bias is not None:
            m.bias.data.zero_()
    elif 'Linear' in classname:
        m.weight.data.normal_(0, 0.01)
        if m.bias is not None:
            # Fill in place so the bias keeps its device and dtype.
            m.bias.data.fill_(1)
71 |
def _second_order_pool(features):
    """Turn per-item feature maps into power-normalised second-order matrices.

    features is indexed [item, channel, location]. For each item the
    channel-by-channel matrix F F^T / num_locations is divided by its trace
    and passed through power_norm. Returns a tensor indexed
    [item, 1, channel, channel].
    """
    num_locations = features.size(2)
    pooled = []
    for item in features:
        cov = (1.0 / num_locations) * item.mm(item.transpose(0, 1))
        pooled.append(power_norm(cov / cov.trace(), SIGMA))
    return torch.stack(pooled).unsqueeze(1)


def _relation_scores(relation_network, H_support, H_query):
    """Score every (query, class) pair; returns a [num_query, num_class] tensor."""
    num_class = H_support.size(0)
    num_query = H_query.size(0)
    side = H_support.size(-1)
    support_ext = H_support.unsqueeze(0).repeat(num_query, 1, 1, 1, 1)
    query_ext = H_query.unsqueeze(0).repeat(num_class, 1, 1, 1, 1)
    query_ext = torch.transpose(query_ext, 0, 1)
    # Stack the support and query matrices as the two input channels.
    relation_pairs = torch.cat((support_ext, query_ext), 2).view(-1, 2, side, side)
    return relation_network(relation_pairs).view(-1, num_class)


def _test_episode_accuracy(feature_encoder, relation_network, metatest_folders):
    """Sample one test task and return the fraction of queries classified correctly."""
    num_per_class = 5
    total_rewards = 0
    counter = 0
    task = tg.MiniImagenetTask(metatest_folders, CLASS_NUM, SUPPORT_NUM_PER_CLASS, 15)
    support_dataloader = tg.get_mini_imagenet_data_loader(task, num_per_class=SUPPORT_NUM_PER_CLASS, split="train", shuffle=False)
    query_dataloader = tg.get_mini_imagenet_data_loader(task, num_per_class=num_per_class, split="test", shuffle=False)

    support_images, _ = next(iter(support_dataloader))
    for query_images, query_labels in query_dataloader:
        query_size = query_labels.shape[0]

        support_features = feature_encoder(support_images.cuda(GPU))
        # Sum the feature maps of the support images of each class.
        support_features = support_features.view(CLASS_NUM, SUPPORT_NUM_PER_CLASS, FEATURE_DIM, 19*19).sum(1)
        query_features = feature_encoder(query_images.cuda(GPU)).view(query_size, 64, 19*19)

        relations = _relation_scores(relation_network,
                                     _second_order_pool(support_features),
                                     _second_order_pool(query_features))
        _, predict_labels = torch.max(relations.data, 1)

        total_rewards += (predict_labels == query_labels.cuda(GPU)).sum().item()
        counter += query_size

    return total_rewards / 1.0 / counter


def main():
    """Train the encoder and similarity network episodically on miniImagenet.

    Each episode samples one task, regresses the relation scores onto one-hot
    query labels with an MSE loss and takes one Adam step per network. Every
    2500 episodes TEST_EPISODE test tasks are evaluated and the networks are
    saved when the mean accuracy improves.
    """
    metatrain_folders, metatest_folders = tg.mini_imagenet_folders()

    print("init neural networks")
    feature_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU)
    relation_network = models.SimilarityNetwork(FEATURE_DIM, RELATION_DIM).apply(weights_init).cuda(GPU)

    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(), lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim, step_size=50000, gamma=0.5)
    relation_network_optim = torch.optim.Adam(relation_network.parameters(), lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim, step_size=50000, gamma=0.5)

    checkpoint_suffix = str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl"
    encoder_path = str(METHOD + "/miniImagenet_feature_encoder_" + checkpoint_suffix)
    relation_path = str(METHOD + "/miniImagenet_relation_network_" + checkpoint_suffix)

    # Resume from existing checkpoints when they are present.
    if os.path.exists(encoder_path):
        feature_encoder.load_state_dict(torch.load(encoder_path))
        print("load feature encoder success")
    if os.path.exists(relation_path):
        relation_network.load_state_dict(torch.load(relation_path))
        print("load relation network success")
    if not os.path.exists(METHOD):
        os.makedirs(METHOD)

    # Step 3: build graph
    print("Training...")

    best_accuracy = 0.0
    best_h = 0.0

    # The loss module has no per-episode state, so build it once.
    mse = nn.MSELoss().cuda(GPU)

    for episode in range(EPISODE):
        feature_encoder_scheduler.step(episode)
        relation_network_scheduler.step(episode)

        # init dataset
        task = tg.MiniImagenetTask(metatrain_folders, CLASS_NUM, SUPPORT_NUM_PER_CLASS, QUERY_NUM_PER_CLASS)
        support_dataloader = tg.get_mini_imagenet_data_loader(task, num_per_class=SUPPORT_NUM_PER_CLASS, split="train", shuffle=False)
        query_dataloader = tg.get_mini_imagenet_data_loader(task, num_per_class=QUERY_NUM_PER_CLASS, split="test", shuffle=True)

        # generate support data and query data
        supports, _ = next(iter(support_dataloader))
        queries, query_labels = next(iter(query_dataloader))

        # generate features; support maps are summed over the shots of a class
        support_features = feature_encoder(supports.cuda(GPU))
        support_features = support_features.view(CLASS_NUM, SUPPORT_NUM_PER_CLASS, FEATURE_DIM, 19*19).sum(1)
        query_features = feature_encoder(queries.cuda(GPU)).view(QUERY_NUM_PER_CLASS*CLASS_NUM, 64, 19*19)

        # second-order pooling, then one relation score per (query, class) pair
        relations = _relation_scores(relation_network,
                                     _second_order_pool(support_features),
                                     _second_order_pool(query_features))

        # regress the scores onto one-hot labels
        one_hot_labels = torch.zeros(QUERY_NUM_PER_CLASS*CLASS_NUM, CLASS_NUM).scatter_(1, query_labels.view(-1, 1), 1).cuda(GPU)
        loss = mse(relations, one_hot_labels)

        # training
        feature_encoder.zero_grad()
        relation_network.zero_grad()

        loss.backward()

        feature_encoder_optim.step()
        relation_network_optim.step()

        if (episode+1)%100 == 0:
            # .item() reads the scalar; indexing a 0-dim tensor with [0] is
            # rejected by PyTorch releases after 0.4.
            print("episode:", episode+1, "loss", loss.item())

        if (episode+1)%2500 == 0:
            # test
            print("Testing...")
            with torch.no_grad():  # evaluation needs no autograd graph
                accuracies = [
                    _test_episode_accuracy(feature_encoder, relation_network, metatest_folders)
                    for _ in range(TEST_EPISODE)
                ]

            test_accuracy, h = mean_confidence_interval(accuracies)

            print("test accuracy:", test_accuracy, "h:", h)
            print("best accuracy:", best_accuracy, "h:", best_h)

            if test_accuracy > best_accuracy:
                # save networks
                torch.save(feature_encoder.state_dict(), encoder_path)
                torch.save(relation_network.state_dict(), relation_path)
                print("save networks for episode:", episode)

                best_accuracy = test_accuracy
                best_h = h


if __name__ == '__main__':
    main()
225 |
--------------------------------------------------------------------------------
/miniimagenet/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
class FeatureEncoder(nn.Module):
    """Four-stage convolutional encoder for 3-channel images.

    Every stage is a 3x3 convolution to 64 channels, batch normalisation and
    ReLU. The first two stages use no padding and end with 2x2 max pooling;
    the last two use padding 1 and no pooling. An 84x84 input yields a
    64x19x19 feature map.
    """

    def __init__(self):
        super(FeatureEncoder, self).__init__()
        self.layer1 = self._stage(3, padding=0, pool=True)
        self.layer2 = self._stage(64, padding=0, pool=True)
        self.layer3 = self._stage(64, padding=1, pool=False)
        self.layer4 = self._stage(64, padding=1, pool=False)

    @staticmethod
    def _stage(in_channels, padding, pool):
        """Build one conv / batch-norm / ReLU stage, optionally followed by pooling."""
        modules = [
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=padding),
            nn.BatchNorm2d(64, momentum=1, affine=True),
            nn.ReLU(),
        ]
        if pool:
            modules.append(nn.MaxPool2d(2))
        return nn.Sequential(*modules)

    def forward(self, x):
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return x
34 |
class SimilarityNetwork(nn.Module):
    """Map a 2-channel pair of 64x64 matrices to a similarity score in [0, 1].

    Four conv / batch-norm / ReLU / max-pool stages reduce a 2x64x64 input to
    64x2x2, which is flattened and passed through two linear layers and a
    sigmoid.

    Args:
        input_size: channel count of the last conv stage; fc1 expects
            input_size * 4 features, which matches the 64x2x2 map only when
            input_size is 64.
        hidden_size: width of the hidden linear layer.
    """

    def __init__(self,input_size,hidden_size):
        super(SimilarityNetwork, self).__init__()
        self.layer1 = nn.Sequential(
                        nn.Conv2d(2,64,kernel_size=3,padding=0),
                        nn.BatchNorm2d(64, momentum=1, affine=True),
                        nn.ReLU(),
                        nn.MaxPool2d(2)) #Nx64x31x31
        self.layer2 = nn.Sequential(
                        nn.Conv2d(64,64,kernel_size=3,padding=0),
                        nn.BatchNorm2d(64, momentum=1, affine=True),
                        nn.ReLU(),
                        nn.MaxPool2d(2)) #Nx64x14x14
        self.layer3 = nn.Sequential(
                        nn.Conv2d(64,64,kernel_size=3,padding=0),
                        nn.BatchNorm2d(64, momentum=1, affine=True),
                        nn.ReLU(),
                        nn.MaxPool2d(2)) #Nx64x6x6
        self.layer4 = nn.Sequential(
                        nn.Conv2d(64,64,kernel_size=3,padding=0),
                        nn.BatchNorm2d(64, momentum=1, affine=True),
                        nn.ReLU(),
                        nn.MaxPool2d(2)) #Nx64x2x2
        self.fc1 = nn.Linear(input_size*4,hidden_size)
        self.fc2 = nn.Linear(hidden_size,1)

    def forward(self,x):
        """Return an (N, 1) tensor of scores for an (N, 2, 64, 64) input."""
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = out.view(out.size(0),-1)
        out = F.relu(self.fc1(out))
        # torch.sigmoid replaces the deprecated F.sigmoid.
        out = torch.sigmoid(self.fc2(out))
        return out
71 |
--------------------------------------------------------------------------------
/miniimagenet/task_generator.py:
--------------------------------------------------------------------------------
1 | # code is based on https://github.com/katerakelly/pytorch-maml
2 | import torchvision
3 | import torchvision.datasets as dset
4 | import torchvision.transforms as transforms
5 | import torch
6 | from torch.utils.data import DataLoader,Dataset
7 | import random
8 | import os
9 | from PIL import Image
10 | import matplotlib.pyplot as plt
11 | import numpy as np
12 | from torch.utils.data.sampler import Sampler
13 |
def imshow(img):
    """Show a channel-first image tensor in a matplotlib window without axes."""
    # Move the channel axis last, the layout plt.imshow expects.
    channels_last = np.transpose(img.numpy(), (1, 2, 0))
    plt.axis("off")
    plt.imshow(channels_last)
    plt.show()
19 |
class Rotate(object):
    """Callable transform that rotates its input by a fixed angle.

    The input only needs a ``rotate(angle)`` method; its result is returned.
    """

    def __init__(self, angle):
        self.angle = angle

    def __call__(self, x, mode="reflect"):
        # `mode` is accepted for interface compatibility but is not used.
        return x.rotate(self.angle)
26 |
def _list_class_dirs(root):
    """Return the sub-directories of *root*, as joined paths, in sorted order."""
    return [os.path.join(root, label)
            for label in sorted(os.listdir(root))
            if os.path.isdir(os.path.join(root, label))]


def mini_imagenet_folders(train_folder='../datas/miniImagenet/train',
                          test_folder='../datas/miniImagenet/val'):
    """Collect the class folders of the meta-train and meta-test splits.

    Args:
        train_folder: directory with one sub-directory per training class.
        test_folder: directory with one sub-directory per evaluation class.

    Returns:
        (metatrain_folders, metatest_folders): two lists of class-directory
        paths, each shuffled after seeding ``random`` with 1.

    Note:
        Seeds the global ``random`` generator as a side effect.
    """
    # os.listdir order is platform dependent; sorting first makes the seeded
    # shuffle below produce the same order on every machine.
    metatrain_folders = _list_class_dirs(train_folder)
    metatest_folders = _list_class_dirs(test_folder)

    random.seed(1)
    random.shuffle(metatrain_folders)
    random.shuffle(metatest_folders)

    return metatrain_folders,metatest_folders
45 |
class MiniImagenetTask(object):
    """One few-shot task: a random set of classes with train/test image paths.

    Args:
        character_folders: list of class-directory paths to sample from.
        num_classes: number of classes drawn for the task.
        train_num: images per class placed in ``train_roots``.
        test_num: images per class placed in ``test_roots``.

    Attributes:
        train_roots / test_roots: image paths, grouped by class in sampling order.
        train_labels / test_labels: label (0 .. num_classes - 1) of each path.
    """

    def __init__(self, character_folders, num_classes, train_num,test_num):

        self.character_folders = character_folders
        self.num_classes = num_classes
        self.train_num = train_num
        self.test_num = test_num

        class_folders = random.sample(self.character_folders,self.num_classes)
        labels = dict(zip(class_folders, np.array(range(len(class_folders)))))

        self.train_roots = []
        self.test_roots = []
        self.train_labels = []
        self.test_labels = []
        for c in class_folders:
            temp = [os.path.join(c, x) for x in os.listdir(c)]
            # Both shuffling calls are kept so that the sequence of draws from
            # the global random generator is unchanged.
            shuffled = random.sample(temp, len(temp))
            random.shuffle(shuffled)

            train_part = shuffled[:train_num]
            test_part = shuffled[train_num:train_num+test_num]
            self.train_roots += train_part
            self.test_roots += test_part
            # Assign labels directly rather than parsing them back out of the
            # file paths.
            self.train_labels += [labels[c]] * len(train_part)
            self.test_labels += [labels[c]] * len(test_part)

    def get_class(self, sample):
        """Return the directory part of an image path (its class folder)."""
        # os.path.dirname keeps the leading separator of absolute paths, which
        # splitting on '/' and re-joining dropped.
        return os.path.dirname(sample)
76 |
77 |
class FewShotDataset(Dataset):
    """Base dataset exposing the 'train' or 'test' half of a task.

    Reads ``train_roots`` / ``train_labels`` from the task when split is
    'train', and ``test_roots`` / ``test_labels`` for any other split value.
    Subclasses must implement ``__getitem__``.
    """

    def __init__(self, task, split='train', transform=None, target_transform=None):
        self.transform = transform  # applied to the input image
        self.target_transform = target_transform
        self.task = task
        self.split = split
        if self.split == 'train':
            self.image_roots = self.task.train_roots
            self.labels = self.task.train_labels
        else:
            self.image_roots = self.task.test_roots
            self.labels = self.task.test_labels

    def __len__(self):
        return len(self.image_roots)

    def __getitem__(self, idx):
        raise NotImplementedError("This is an abstract class. Subclass this class for your particular dataset.")
93 |
class MiniImagenet(FewShotDataset):
    """Few-shot dataset that loads each image from disk as an RGB PIL image."""

    def __init__(self, *args, **kwargs):
        super(MiniImagenet, self).__init__(*args, **kwargs)

    def __getitem__(self, idx):
        """Return (image, label) for position idx, with the transforms applied."""
        image_root = self.image_roots[idx]
        # convert() returns a new image, so the file can be closed as soon as
        # the conversion is done instead of staying open until collection.
        with Image.open(image_root) as raw_image:
            image = raw_image.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        label = self.labels[idx]
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label
109 |
110 |
class ClassBalancedSampler(Sampler):
    ''' Draws 'num_per_class' indices from each of 'num_cl' consecutive pools
        of 'num_inst' items, assuming the dataset is grouped by class. '''

    def __init__(self, num_per_class, num_cl, num_inst,shuffle=True):
        self.shuffle = shuffle
        self.num_inst = num_inst
        self.num_cl = num_cl
        self.num_per_class = num_per_class

    def __iter__(self):
        # Build one flat list of dataset indices, class by class.
        indices = []
        for class_idx in range(self.num_cl):
            offset = class_idx * self.num_inst
            if self.shuffle:
                picks = torch.randperm(self.num_inst)[:self.num_per_class]
            else:
                picks = range(self.num_inst)[:self.num_per_class]
            indices.extend(i + offset for i in picks)

        # With shuffling enabled the classes are interleaved as well.
        if self.shuffle:
            random.shuffle(indices)
        return iter(indices)

    def __len__(self):
        return 1
135 |
136 |
def get_mini_imagenet_data_loader(task, num_per_class=1, split='train',shuffle = False):
    """Build a DataLoader over one split of ``task``.

    Images are converted to tensors and normalised; indices come from a
    ``ClassBalancedSampler`` whose pool size is ``task.train_num`` for the
    'train' split and ``task.test_num`` otherwise. The batch size is
    ``num_per_class * task.num_classes``.
    """
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.92206, 0.92206, 0.92206], std=[0.08426, 0.08426, 0.08426]),
    ])
    dataset = MiniImagenet(task, split=split, transform=preprocessing)

    pool_size = task.train_num if split == 'train' else task.test_num
    sampler = ClassBalancedSampler(num_per_class, task.num_classes, pool_size, shuffle=shuffle)

    return DataLoader(dataset, batch_size=num_per_class*task.num_classes, sampler=sampler)
150 |
151 |
--------------------------------------------------------------------------------
/miniimagenet/task_generator_test.py:
--------------------------------------------------------------------------------
1 | # code is based on https://github.com/katerakelly/pytorch-maml
2 | import torchvision
3 | import torchvision.datasets as dset
4 | import torchvision.transforms as transforms
5 | import torch
6 | from torch.utils.data import DataLoader,Dataset
7 | import random
8 | import os
9 | from PIL import Image
10 | import matplotlib.pyplot as plt
11 | import numpy as np
12 | from torch.utils.data.sampler import Sampler
13 |
def imshow(img):
    """Show ``img`` with matplotlib, axes hidden.

    ``img`` must provide ``numpy()``; its axes are reordered to (1, 2, 0)
    before display, i.e. axis 0 is moved to the end.
    """
    array = img.numpy()
    plt.axis("off")
    reordered = np.transpose(array, (1, 2, 0))
    plt.imshow(reordered)
    plt.show()
19 |
class Rotate(object):
    """Callable transform that rotates its input by a fixed angle.

    The input only needs a ``rotate(angle)`` method. ``mode`` is accepted
    for signature compatibility but is not used.
    """

    def __init__(self, angle):
        self.angle = angle

    def __call__(self, x, mode="reflect"):
        return x.rotate(self.angle)
26 |
def mini_imagenet_folders():
    """Return ``(metatrain_folders, metatest_folders)`` for miniImagenet.

    Each list holds the sub-directories of the corresponding split folder
    under '../datas/miniImagenet'. The global ``random`` generator is seeded
    with 1 before both lists are shuffled in place.
    """
    def class_dirs(split_folder):
        # Keep only directory entries; plain files in the split folder are skipped.
        found = []
        for label in os.listdir(split_folder):
            candidate = os.path.join(split_folder, label)
            if os.path.isdir(candidate):
                found.append(candidate)
        return found

    metatrain_folders = class_dirs('../datas/miniImagenet/train')
    metatest_folders = class_dirs('../datas/miniImagenet/test')

    random.seed(1)
    for folders in (metatrain_folders, metatest_folders):
        random.shuffle(folders)

    return metatrain_folders,metatest_folders
45 |
class MiniImagenetTask(object):
    """One few-shot episode drawn from a list of class folders.

    ``num_classes`` folders are sampled at random; the label of a class is
    its position in that sampled order. For every class the files are
    shuffled, the first ``train_num`` become ``train_roots`` and the next
    ``test_num`` become ``test_roots``. ``train_labels`` / ``test_labels``
    are aligned with those path lists.
    """

    def __init__(self, character_folders, num_classes, train_num,test_num):

        self.character_folders = character_folders
        self.num_classes = num_classes
        self.train_num = train_num
        self.test_num = test_num

        class_folders = random.sample(self.character_folders,self.num_classes)
        labels = np.array(range(len(class_folders)))
        labels = dict(zip(class_folders, labels))

        self.train_roots = []
        self.test_roots = []
        self.train_labels = []
        self.test_labels = []
        for c in class_folders:

            temp = [os.path.join(c, x) for x in os.listdir(c)]
            shuffled = random.sample(temp, len(temp))
            # The second shuffle is redundant but kept so that seeded runs
            # consume the random stream exactly as before.
            random.shuffle(shuffled)

            train_part = shuffled[:train_num]
            test_part = shuffled[train_num:train_num+test_num]
            self.train_roots += train_part
            self.test_roots += test_part
            # Label straight from the folder being processed rather than by
            # parsing it back out of each path, which raised KeyError for
            # absolute or trailing-slash folder names.
            self.train_labels += [labels[c]] * len(train_part)
            self.test_labels += [labels[c]] * len(test_part)

    def get_class(self, sample):
        """Return the parent directory of ``sample`` (its class folder)."""
        return os.path.dirname(sample)
76 |
class FewShotDataset(Dataset):
    """Abstract dataset exposing one split of a few-shot task.

    With ``split == 'train'`` the task's ``train_roots`` / ``train_labels``
    are served; any other value serves ``test_roots`` / ``test_labels``.
    Subclasses implement ``__getitem__`` to load an item.
    """

    def __init__(self, task, split='train', transform=None, target_transform=None):
        self.task = task
        self.split = split
        self.transform = transform  # stored for subclasses: applied to the input image
        self.target_transform = target_transform  # stored for subclasses: applied to the label
        if split == 'train':
            self.image_roots = task.train_roots
            self.labels = task.train_labels
        else:
            self.image_roots = task.test_roots
            self.labels = task.test_labels

    def __len__(self):
        # Number of items in the selected split.
        return len(self.image_roots)

    def __getitem__(self, idx):
        raise NotImplementedError("This is an abstract class. Subclass this class for your particular dataset.")
92 |
class MiniImagenet(FewShotDataset):
    """Few-shot dataset whose items are RGB images read from ``image_roots``.

    Construction is inherited unchanged from ``FewShotDataset``.
    """

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx`` of the selected split.

        ``transform`` and ``target_transform`` are applied when they are set.
        """
        # Open inside a context manager so the underlying file is closed as
        # soon as convert() has produced its own copy of the pixel data,
        # instead of staying open until garbage collection.
        with Image.open(self.image_roots[idx]) as source:
            image = source.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        label = self.labels[idx]
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label
108 |
109 |
class ClassBalancedSampler(Sampler):
    ''' Yields every index of 'num_cl' class pools of size 'num_inst',
        arranged as 'num_inst' consecutive groups that each contain
        exactly one index per class '''

    def __init__(self, num_cl, num_inst,shuffle=True):

        self.num_cl = num_cl
        self.num_inst = num_inst
        self.shuffle = shuffle

    def __iter__(self):
        # Indices assume that items are grouped by class: class j owns
        # [j*num_inst, (j+1)*num_inst).
        if self.shuffle:
            # tolist() turns each permutation into plain Python ints so the
            # sampler never hands 0-dim tensors to the dataset.
            per_class = [[i+j*self.num_inst for i in torch.randperm(self.num_inst).tolist()] for j in range(self.num_cl)]
        else:
            per_class = [[i+j*self.num_inst for i in range(self.num_inst)] for j in range(self.num_cl)]
        # Regroup: group i takes the i-th pick of every class.
        groups = [[per_class[j][i] for j in range(self.num_cl)] for i in range(self.num_inst)]

        if self.shuffle:
            random.shuffle(groups)
            for group in groups:
                random.shuffle(group)
        flat = [item for group in groups for item in group]
        return iter(flat)

    def __len__(self):
        # Kept at 1 for backward compatibility; this is not the number of
        # indices produced by __iter__.
        return 1
137 |
class ClassBalancedSamplerOld(Sampler):
    ''' Samples 'num_per_class' examples from each of 'num_cl' pools
        of examples of size 'num_inst' '''

    def __init__(self, num_per_class, num_cl, num_inst,shuffle=True):
        self.num_per_class = num_per_class
        self.num_cl = num_cl
        self.num_inst = num_inst
        self.shuffle = shuffle

    def __iter__(self):
        # Returns a single flat list of indices, assuming that items are
        # grouped by class: class j owns [j*num_inst, (j+1)*num_inst).
        indices = []
        for j in range(self.num_cl):
            if self.shuffle:
                # tolist() turns the permutation into plain Python ints so the
                # sampler never hands 0-dim tensors to the dataset.
                picks = torch.randperm(self.num_inst)[:self.num_per_class].tolist()
            else:
                picks = list(range(self.num_inst))[:self.num_per_class]
            indices.extend(i + j*self.num_inst for i in picks)

        if self.shuffle:
            random.shuffle(indices)
        return iter(indices)

    def __len__(self):
        # Kept at 1 for backward compatibility; this is not the number of
        # indices produced by __iter__.
        return 1
162 |
163 |
def get_mini_imagenet_data_loader(task, num_per_class=1, split='train',shuffle = False):
    """Build a DataLoader over one split of ``task``.

    Images are converted to tensors and normalised. The 'train' split is
    sampled with ``ClassBalancedSamplerOld`` (``num_per_class`` items per
    class out of ``task.train_num``); any other split uses
    ``ClassBalancedSampler`` over ``task.test_num`` items per class. The
    batch size is ``num_per_class * task.num_classes``.
    """
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.92206, 0.92206, 0.92206], std=[0.08426, 0.08426, 0.08426]),
    ])
    dataset = MiniImagenet(task, split=split, transform=preprocessing)

    if split != 'train':
        sampler = ClassBalancedSampler(task.num_classes, task.test_num, shuffle=shuffle)
    else:
        sampler = ClassBalancedSamplerOld(num_per_class, task.num_classes, task.train_num, shuffle=shuffle)

    return DataLoader(dataset, batch_size=num_per_class*task.num_classes, sampler=sampler)
176 |
--------------------------------------------------------------------------------
/omniglot/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
class FeatureEncoder(nn.Module):
    """Four-stage convolutional encoder for single-channel images.

    Every stage is Conv2d(3x3, 64 filters) -> BatchNorm2d -> ReLU. Stages 1
    and 2 use no padding and end with 2x2 max-pooling; stages 3 and 4 use
    padding 1 and no pooling.
    """

    @staticmethod
    def _stage(in_channels, padding, pool):
        # One conv block; module order inside the Sequential is part of the
        # state_dict key layout and must not change.
        modules = [
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=padding),
            nn.BatchNorm2d(64, momentum=1, affine=True),
            nn.ReLU(),
        ]
        if pool:
            modules.append(nn.MaxPool2d(2))
        return nn.Sequential(*modules)

    def __init__(self):
        super(FeatureEncoder, self).__init__()
        self.layer1 = self._stage(1, padding=0, pool=True)
        self.layer2 = self._stage(64, padding=0, pool=True)
        self.layer3 = self._stage(64, padding=1, pool=False)
        self.layer4 = self._stage(64, padding=1, pool=False)

    def forward(self,x):
        out = x
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        return out
33 |
class SimilarityNetwork(nn.Module):
    """Scores a two-channel input with four conv stages and two linear layers.

    Every stage is Conv2d(3x3, no padding, 64 filters) -> BatchNorm2d ->
    ReLU -> MaxPool2d(2). The flattened result goes through ``fc1`` + ReLU
    and ``fc2`` + sigmoid, giving one score in [0, 1] per input.
    ``fc1`` expects ``input_size * 4`` features.
    """

    @staticmethod
    def _stage(in_channels):
        # Module order inside the Sequential is part of the state_dict key layout.
        return nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=0),
            nn.BatchNorm2d(64, momentum=1, affine=True),
            nn.ReLU(),
            nn.MaxPool2d(2))

    def __init__(self,input_size,hidden_size):
        super(SimilarityNetwork, self).__init__()
        # Spatial sizes below are for 64x64 inputs.
        self.layer1 = self._stage(2)   # Nx64x31x31
        self.layer2 = self._stage(64)  # Nx64x14x14
        self.layer3 = self._stage(64)  # Nx64x6x6
        self.layer4 = self._stage(64)  # Nx64x2x2
        self.fc1 = nn.Linear(input_size*4,hidden_size)
        self.fc2 = nn.Linear(hidden_size,1)

    def forward(self,x):
        out = x
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        flat = out.view(out.size(0),-1)
        hidden = F.relu(self.fc1(flat))
        return torch.sigmoid(self.fc2(hidden))
70 |
--------------------------------------------------------------------------------
/omniglot/omniglot_test_few_shot_SoSN.py:
--------------------------------------------------------------------------------
1 | #-------------------------------------
2 | # Project: Learning to Compare: Relation Network for Few-Shot Learning
3 | # Date: 2017.9.21
4 | # Author: Flood Sung
5 | # All Rights Reserved
6 | #-------------------------------------
7 |
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torch.autograd import Variable
13 | from torch.optim.lr_scheduler import StepLR
14 | import numpy as np
15 | import task_generator as tg
16 | import os
17 | import math
18 | import argparse
19 | import random
20 | import models
21 |
# Command-line options. Parsing happens at import time.
parser = argparse.ArgumentParser(description="One Shot Visual Recognition")
# -f: passed to SimilarityNetwork as input_size and used to reshape support features
parser.add_argument("-f", "--feature_dim", type=int, default=64)
# -r: passed to SimilarityNetwork as hidden_size
parser.add_argument("-r", "--relation_dim", type=int, default=8)
# -w: number of classes per episode
parser.add_argument("-w", "--class_num", type=int, default=5)
# -s: support images per class
parser.add_argument("-s", "--support_num_per_class", type=int, default=5)
# -b: stored as QUERY_NUM_PER_CLASS (not read by main() in this script)
parser.add_argument("-b", "--query_num_per_class", type=int, default=2)
# -e: number of evaluation rounds
parser.add_argument("-e", "--episode", type=int, default=100)
# -t: test episodes per evaluation round
parser.add_argument("-t", "--query_episode", type=int, default=1000)
# -l: stored as LEARNING_RATE
parser.add_argument("-l", "--learning_rate", type=float, default=0.001)
# -g: device index handed to .cuda()
parser.add_argument("-g", "--gpu", type=int, default=0)
# -u: stored as HIDDEN_UNIT (not read by main() in this script)
parser.add_argument("-u", "--hidden_unit", type=int, default=10)
# -sigma: scale used by power_norm; also part of the model directory name
parser.add_argument("-sigma", "--sigma", type=float, default=1)
# -ts: query images per class in each test episode
parser.add_argument("-ts", "--test_num_per_class", type=int, default=5)
args = parser.parse_args()


# Hyper Parameters
METHOD = "SoSN_LOGIT" + str(args.sigma) + "_Models"  # directory holding the checkpoints
FEATURE_DIM = args.feature_dim
RELATION_DIM = args.relation_dim
CLASS_NUM = args.class_num
SUPPORT_NUM_PER_CLASS = args.support_num_per_class
QUERY_NUM_PER_CLASS = args.query_num_per_class
TEST_NUM_PER_CLASS = args.test_num_per_class
EPISODE = args.episode
TEST_EPISODE = args.query_episode
LEARNING_RATE = args.learning_rate
GPU = args.gpu
HIDDEN_UNIT = args.hidden_unit
SIGMA = args.sigma
52 |
def power_norm(x, SIGMA):
    """Element-wise ``2 / (1 + exp(-SIGMA * x)) - 1``.

    The result is 0 at ``x == 0`` and lies between -1 and 1.
    """
    denominator = 1 + torch.exp(-SIGMA*x)
    return 2/denominator - 1
56 |
def weights_init(m):
    """Initialise module ``m`` in place, dispatching on its class name.

    * name contains 'Conv': weights ~ N(0, sqrt(2 / (kh * kw * out_channels))),
      bias (if present) set to 0;
    * name contains 'BatchNorm': weight set to 1, bias set to 0;
    * name contains 'Linear': weights ~ N(0, 0.01), bias replaced by ones.

    Other modules are left untouched. Intended for ``Module.apply``.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        fan = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / fan))
        if m.bias is not None:
            m.bias.data.zero_()
        return
    if 'BatchNorm' in kind:
        m.weight.data.fill_(1)
        m.bias.data.zero_()
        return
    if 'Linear' in kind:
        m.weight.data.normal_(0, 0.01)
        m.bias.data = torch.ones(m.bias.data.size())
71 |
def main():
    """Evaluate the omniglot SoSN models on random few-shot test episodes.

    The feature encoder and similarity network are created, and their
    checkpoints are loaded from the METHOD directory when the files exist
    (otherwise the freshly initialised weights are evaluated). EPISODE
    rounds of TEST_EPISODE episodes are then run without gradients; each
    round prints its accuracy and the best accuracy of the earlier rounds.
    """
    # Step 1: init data folders
    print("init data folders")
    # init character folders for dataset construction
    metatrain_character_folders,metaquery_character_folders = tg.omniglot_character_folders()

    # Step 2: init neural networks
    print("init neural networks")

    feature_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU)
    relation_network = models.SimilarityNetwork(FEATURE_DIM,RELATION_DIM).apply(weights_init).cuda(GPU)

    checkpoint_suffix = str(CLASS_NUM) +"way_" + str(SUPPORT_NUM_PER_CLASS) +"shot.pkl"
    feature_encoder_path = METHOD + "/omniglot_feature_encoder_" + checkpoint_suffix
    relation_network_path = METHOD + "/omniglot_similarity_network_" + checkpoint_suffix

    if os.path.exists(feature_encoder_path):
        feature_encoder.load_state_dict(torch.load(feature_encoder_path))
        print("load feature encoder success")
    if os.path.exists(relation_network_path):
        relation_network.load_state_dict(torch.load(relation_network_path))
        print("load similarity network success")
    # Create the directory through the os module rather than a shell command.
    if not os.path.exists(METHOD):
        os.makedirs(METHOD)

    # Step 3: build graph
    print("Training...")

    best_accuracy = 0.0

    for episode in range(EPISODE):
        with torch.no_grad():
            # query
            print("Testing...")
            total_rewards = 0

            for i in range(TEST_EPISODE):
                degrees = random.choice([0,90,180,270])
                task = tg.OmniglotTask(metaquery_character_folders,CLASS_NUM,SUPPORT_NUM_PER_CLASS,TEST_NUM_PER_CLASS,)
                support_dataloader = tg.get_data_loader(task,num_per_class=SUPPORT_NUM_PER_CLASS,split="train",shuffle=False,rotation=degrees)
                query_dataloader = tg.get_data_loader(task,num_per_class=TEST_NUM_PER_CLASS,split="query",shuffle=True,rotation=degrees)

                # next(iter(...)) works on every DataLoader iterator; the
                # Python-2 style .next() method is not always available.
                support_images,support_labels = next(iter(support_dataloader))
                query_images,query_labels = next(iter(query_dataloader))

                # calculate features
                support_features = feature_encoder(Variable(support_images).cuda(GPU))
                # sum the support features of each class over its shots
                support_features = support_features.view(CLASS_NUM,SUPPORT_NUM_PER_CLASS,FEATURE_DIM,25).sum(1)
                query_features = feature_encoder(Variable(query_images).cuda(GPU)).view(TEST_NUM_PER_CLASS*CLASS_NUM,64,25)

                H_support_features = Variable(torch.Tensor(CLASS_NUM, 1, 64, 64)).cuda(GPU)
                H_query_features = Variable(torch.Tensor(TEST_NUM_PER_CLASS*CLASS_NUM, 1, 64, 64)).cuda(GPU)
                # second-order features: s * s^T averaged over the last axis,
                # divided by its trace and passed through power_norm
                for d in range(support_features.size(0)):
                    s = support_features[d,:,:].squeeze(0)
                    s = (1.0 / support_features.size(2)) * s.mm(s.t())
                    H_support_features[d,:,:,:] = power_norm(s / s.trace(), SIGMA)
                for d in range(query_features.size(0)):
                    s = query_features[d,:,:].squeeze(0)
                    s = (1.0 / query_features.size(2)) * s.mm(s.t())
                    H_query_features[d,:,:,:] = power_norm(s / s.trace(), SIGMA)

                # calculate relations
                # pair every query with every class: each pair stacks the
                # support and the query matrix as two 64x64 channels
                support_features_ext = H_support_features.unsqueeze(0).repeat(TEST_NUM_PER_CLASS*CLASS_NUM,1,1,1,1)
                query_features_ext = H_query_features.unsqueeze(0).repeat(CLASS_NUM,1,1,1,1)
                query_features_ext = torch.transpose(query_features_ext,0,1)

                relation_pairs = torch.cat((support_features_ext,query_features_ext),2).view(-1,2,64,64)
                relations = relation_network(relation_pairs).view(-1,CLASS_NUM)

                _,predict_labels = torch.max(relations.data,1)

                # Move the labels to the device once and count the matches
                # in a single comparison instead of one transfer per element.
                query_labels = query_labels.cuda(GPU)
                total_rewards += (predict_labels == query_labels).sum().item()

            test_accuracy = total_rewards/1.0/CLASS_NUM/TEST_NUM_PER_CLASS/TEST_EPISODE

            print("query accuracy:",test_accuracy)
            print("best accuracy:",best_accuracy)

            if test_accuracy > best_accuracy:
                best_accuracy = test_accuracy
159 |
160 |
161 |
162 |
163 |
164 | if __name__ == '__main__':
165 | main()
166 |
--------------------------------------------------------------------------------
/omniglot/omniglot_train_few_shot_SoSN.py:
--------------------------------------------------------------------------------
1 | #-------------------------------------
2 | # Project: Learning to Compare: Relation Network for Few-Shot Learning
3 | # Date: 2017.9.21
4 | # Author: Flood Sung
5 | # All Rights Reserved
6 | #-------------------------------------
7 |
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torch.autograd import Variable
13 | from torch.optim.lr_scheduler import StepLR
14 | import numpy as np
15 | import task_generator as tg
16 | import os
17 | import math
18 | import argparse
19 | import random
20 | import models
21 |
# Command-line options. Parsing happens at import time.
parser = argparse.ArgumentParser(description="One Shot Visual Recognition")
# -f: passed to SimilarityNetwork as input_size and used to reshape support features
parser.add_argument("-f", "--feature_dim", type=int, default=64)
# -r: passed to SimilarityNetwork as hidden_size
parser.add_argument("-r", "--relation_dim", type=int, default=8)
# -w: number of classes per episode
parser.add_argument("-w", "--class_num", type=int, default=5)
# -s: support images per class
parser.add_argument("-s", "--support_num_per_class", type=int, default=5)
# -b: query images per class in each training episode
parser.add_argument("-b", "--query_num_per_class", type=int, default=2)
# -e: number of training episodes
parser.add_argument("-e", "--episode", type=int, default=1000000)
# -t: test episodes per evaluation
parser.add_argument("-t", "--query_episode", type=int, default=1000)
# -l: learning rate of both Adam optimizers
parser.add_argument("-l", "--learning_rate", type=float, default=0.001)
# -g: device index handed to .cuda()
parser.add_argument("-g", "--gpu", type=int, default=0)
# -u: stored as HIDDEN_UNIT (not read by main() in this script)
parser.add_argument("-u", "--hidden_unit", type=int, default=10)
# -sigma: scale used by power_norm; also part of the model directory name
parser.add_argument("-sigma", "--sigma", type=float, default=1)
# -ts: query images per class in each test episode
parser.add_argument("-ts", "--test_num_per_class", type=int, default=5)
args = parser.parse_args()


# Hyper Parameters
METHOD = "SoSN_LOGIT" + str(args.sigma) + "_Models"  # directory holding the checkpoints
FEATURE_DIM = args.feature_dim
RELATION_DIM = args.relation_dim
CLASS_NUM = args.class_num
SUPPORT_NUM_PER_CLASS = args.support_num_per_class
QUERY_NUM_PER_CLASS = args.query_num_per_class
TEST_NUM_PER_CLASS = args.test_num_per_class
EPISODE = args.episode
TEST_EPISODE = args.query_episode
LEARNING_RATE = args.learning_rate
GPU = args.gpu
HIDDEN_UNIT = args.hidden_unit
SIGMA = args.sigma
52 |
def power_norm(x, SIGMA):
    """Element-wise ``2 / (1 + exp(-SIGMA * x)) - 1``.

    The result is 0 at ``x == 0`` and lies between -1 and 1.
    """
    denominator = 1 + torch.exp(-SIGMA*x)
    return 2/denominator - 1
56 |
def weights_init(m):
    """Initialise module ``m`` in place, dispatching on its class name.

    * name contains 'Conv': weights ~ N(0, sqrt(2 / (kh * kw * out_channels))),
      bias (if present) set to 0;
    * name contains 'BatchNorm': weight set to 1, bias set to 0;
    * name contains 'Linear': weights ~ N(0, 0.01), bias replaced by ones.

    Other modules are left untouched. Intended for ``Module.apply``.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        fan = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / fan))
        if m.bias is not None:
            m.bias.data.zero_()
        return
    if 'BatchNorm' in kind:
        m.weight.data.fill_(1)
        m.bias.data.zero_()
        return
    if 'Linear' in kind:
        m.weight.data.normal_(0, 0.01)
        m.bias.data = torch.ones(m.bias.data.size())
71 |
def main():
    """Train the omniglot SoSN feature encoder and similarity network.

    Existing checkpoints in the METHOD directory are loaded when present.
    Each of the EPISODE training episodes samples a task, builds
    second-order support/query features, scores every query/class pair and
    minimises the MSE against one-hot labels. The loss is printed every 100
    episodes. Every 2500 episodes (starting with episode 0) TEST_EPISODE
    test episodes are evaluated without gradients, and both networks are
    saved whenever the accuracy improves on the best value so far.
    """
    # Step 1: init data folders
    print("init data folders")
    # init character folders for dataset construction
    metatrain_character_folders,metaquery_character_folders = tg.omniglot_character_folders()

    # Step 2: init neural networks
    print("init neural networks")

    feature_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU)
    relation_network = models.SimilarityNetwork(FEATURE_DIM,RELATION_DIM).apply(weights_init).cuda(GPU)

    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim,step_size=50000,gamma=0.1)
    relation_network_optim = torch.optim.Adam(relation_network.parameters(),lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim,step_size=50000,gamma=0.1)

    checkpoint_suffix = str(CLASS_NUM) +"way_" + str(SUPPORT_NUM_PER_CLASS) +"shot.pkl"
    feature_encoder_path = METHOD + "/omniglot_feature_encoder_" + checkpoint_suffix
    relation_network_path = METHOD + "/omniglot_similarity_network_" + checkpoint_suffix

    if os.path.exists(feature_encoder_path):
        feature_encoder.load_state_dict(torch.load(feature_encoder_path))
        print("load feature encoder success")
    if os.path.exists(relation_network_path):
        relation_network.load_state_dict(torch.load(relation_network_path))
        print("load similarity network success")
    # Create the checkpoint directory through the os module rather than a
    # shell command.
    if not os.path.exists(METHOD):
        os.makedirs(METHOD)

    # Step 3: build graph
    print("Training...")

    best_accuracy = 0.0

    # The loss module is stateless, so it is built once instead of per episode.
    mse = nn.MSELoss().cuda(GPU)

    for episode in range(EPISODE):

        feature_encoder_scheduler.step(episode)
        relation_network_scheduler.step(episode)

        # init dataset
        # support_dataloader is to obtain previous supports for compare
        # query_dataloader is to query supports for training
        degrees = random.choice([0,90,180,270])
        task = tg.OmniglotTask(metatrain_character_folders,CLASS_NUM,SUPPORT_NUM_PER_CLASS,QUERY_NUM_PER_CLASS)
        support_dataloader = tg.get_data_loader(task,num_per_class=SUPPORT_NUM_PER_CLASS,split="train",shuffle=False,rotation=degrees)
        query_dataloader = tg.get_data_loader(task,num_per_class=QUERY_NUM_PER_CLASS,split="test",shuffle=True,rotation=degrees)

        # next(iter(...)) works on every DataLoader iterator; the Python-2
        # style .next() method is not always available.
        supports,support_labels = next(iter(support_dataloader))
        queries,query_labels = next(iter(query_dataloader))

        # calculate features
        support_features = feature_encoder(Variable(supports).cuda(GPU))
        # sum the support features of each class over its shots
        support_features = support_features.view(CLASS_NUM,SUPPORT_NUM_PER_CLASS,FEATURE_DIM,25).sum(1)
        query_features = feature_encoder(Variable(queries).cuda(GPU)).view(QUERY_NUM_PER_CLASS*CLASS_NUM,64,25)

        H_support_features = Variable(torch.Tensor(CLASS_NUM, 1, 64, 64)).cuda(GPU)
        H_query_features = Variable(torch.Tensor(QUERY_NUM_PER_CLASS*CLASS_NUM, 1, 64, 64)).cuda(GPU)
        # second-order features: s * s^T averaged over the last axis,
        # divided by its trace and passed through power_norm
        for d in range(support_features.size(0)):
            s = support_features[d,:,:].squeeze(0)
            s = (1.0 / support_features.size(2)) * s.mm(s.t())
            H_support_features[d,:,:,:] = power_norm(s / s.trace(), SIGMA)
        for d in range(query_features.size(0)):
            s = query_features[d,:,:].squeeze(0)
            s = (1.0 / query_features.size(2)) * s.mm(s.t())
            H_query_features[d,:,:,:] = power_norm(s / s.trace(), SIGMA)

        # form the (QUERY_NUM_PER_CLASS*CLASS_NUM) x CLASS_NUM relation pairs
        support_features_ext = H_support_features.unsqueeze(0).repeat(QUERY_NUM_PER_CLASS*CLASS_NUM,1,1,1,1)
        query_features_ext = H_query_features.unsqueeze(0).repeat(CLASS_NUM,1,1,1,1)
        query_features_ext = torch.transpose(query_features_ext,0,1)
        relation_pairs = torch.cat((support_features_ext, query_features_ext),2).view(-1,2,64,64)
        # calculate relation scores
        relations = relation_network(relation_pairs).view(-1,CLASS_NUM)

        # loss: MSE between the relation scores and one-hot query labels
        one_hot_labels = Variable(torch.zeros(QUERY_NUM_PER_CLASS*CLASS_NUM, CLASS_NUM).scatter_(1, query_labels.view(-1,1), 1)).cuda(GPU)
        loss = mse(relations,one_hot_labels)

        # training
        feature_encoder.zero_grad()
        relation_network.zero_grad()

        loss.backward()

        feature_encoder_optim.step()
        relation_network_optim.step()

        if (episode+1)%100 == 0:
            # item() reads the scalar; indexing a 0-dim tensor with [0] fails.
            print("episode:",episode+1,"loss",loss.item())

        if (episode)%2500 == 0:
            with torch.no_grad():
                # query
                print("Testing...")
                total_rewards = 0

                for i in range(TEST_EPISODE):
                    degrees = random.choice([0,90,180,270])
                    task = tg.OmniglotTask(metaquery_character_folders,CLASS_NUM,SUPPORT_NUM_PER_CLASS,TEST_NUM_PER_CLASS,)
                    support_dataloader = tg.get_data_loader(task,num_per_class=SUPPORT_NUM_PER_CLASS,split="train",shuffle=False,rotation=degrees)
                    query_dataloader = tg.get_data_loader(task,num_per_class=TEST_NUM_PER_CLASS,split="test",shuffle=True,rotation=degrees)

                    support_images,support_labels = next(iter(support_dataloader))
                    query_images,query_labels = next(iter(query_dataloader))

                    # calculate features
                    support_features = feature_encoder(Variable(support_images).cuda(GPU))
                    support_features = support_features.view(CLASS_NUM,SUPPORT_NUM_PER_CLASS,FEATURE_DIM,25).sum(1)
                    query_features = feature_encoder(Variable(query_images).cuda(GPU)).view(TEST_NUM_PER_CLASS*CLASS_NUM,64,25)

                    H_support_features = Variable(torch.Tensor(CLASS_NUM, 1, 64, 64)).cuda(GPU)
                    H_query_features = Variable(torch.Tensor(TEST_NUM_PER_CLASS*CLASS_NUM, 1, 64, 64)).cuda(GPU)
                    # second-order features, as in training
                    for d in range(support_features.size(0)):
                        s = support_features[d,:,:].squeeze(0)
                        s = (1.0 / support_features.size(2)) * s.mm(s.t())
                        H_support_features[d,:,:,:] = power_norm(s / s.trace(), SIGMA)
                    for d in range(query_features.size(0)):
                        s = query_features[d,:,:].squeeze(0)
                        s = (1.0 / query_features.size(2)) * s.mm(s.t())
                        H_query_features[d,:,:,:] = power_norm(s / s.trace(), SIGMA)

                    # calculate relations
                    # pair every query with every class: each pair stacks the
                    # support and the query matrix as two 64x64 channels
                    support_features_ext = H_support_features.unsqueeze(0).repeat(TEST_NUM_PER_CLASS*CLASS_NUM,1,1,1,1)
                    query_features_ext = H_query_features.unsqueeze(0).repeat(CLASS_NUM,1,1,1,1)
                    query_features_ext = torch.transpose(query_features_ext,0,1)

                    relation_pairs = torch.cat((support_features_ext,query_features_ext),2).view(-1,2,64,64)
                    relations = relation_network(relation_pairs).view(-1,CLASS_NUM)

                    _,predict_labels = torch.max(relations.data,1)

                    # Move the labels to the device once and count the matches
                    # in a single comparison instead of one transfer per element.
                    query_labels = query_labels.cuda(GPU)
                    total_rewards += (predict_labels == query_labels).sum().item()

                test_accuracy = total_rewards/1.0/CLASS_NUM/TEST_NUM_PER_CLASS/TEST_EPISODE

                print("query accuracy:",test_accuracy)
                print("best accuracy:",best_accuracy)

                if test_accuracy > best_accuracy:
                    # save networks
                    torch.save(feature_encoder.state_dict(),feature_encoder_path)
                    torch.save(relation_network.state_dict(),relation_network_path)
                    print("save networks for episode:",episode)
                    best_accuracy = test_accuracy
223 |
224 |
225 |
226 |
227 |
228 | if __name__ == '__main__':
229 | main()
230 |
--------------------------------------------------------------------------------
/omniglot/task_generator.py:
--------------------------------------------------------------------------------
1 | # code is based on https://github.com/katerakelly/pytorch-maml
2 | import torchvision
3 | import torchvision.datasets as dset
4 | import torchvision.transforms as transforms
5 | import torch
6 | from torch.utils.data import DataLoader,Dataset
7 | import random
8 | import os
9 | from PIL import Image
10 | import matplotlib.pyplot as plt
11 | import numpy as np
12 | from torch.utils.data.sampler import Sampler
13 |
def imshow(img):
    """Show ``img`` with matplotlib, axes hidden.

    ``img`` must provide ``numpy()``; its axes are reordered to (1, 2, 0)
    before display, i.e. axis 0 is moved to the end.
    """
    array = img.numpy()
    plt.axis("off")
    reordered = np.transpose(array, (1, 2, 0))
    plt.imshow(reordered)
    plt.show()
19 |
class Rotate(object):
    """Callable transform that rotates its input by a fixed angle.

    The input only needs a ``rotate(angle)`` method. ``mode`` is accepted
    for signature compatibility but is not used.
    """

    def __init__(self, angle):
        self.angle = angle

    def __call__(self, x, mode="reflect"):
        return x.rotate(self.angle)
26 |
def omniglot_character_folders():
    """Return ``(metatrain_character_folders, metaval_character_folders)``.

    Every entry of every family directory under
    '../datas/omniglot_resized/' is collected. The global ``random``
    generator is seeded with 1, the list is shuffled, and it is split after
    the first 1200 entries.
    """
    data_folder = '../datas/omniglot_resized/'

    character_folders = []
    for family in os.listdir(data_folder):
        family_dir = os.path.join(data_folder, family)
        if not os.path.isdir(family_dir):
            continue  # only directories count as families
        for character in os.listdir(family_dir):
            character_folders.append(os.path.join(data_folder, family, character))

    random.seed(1)
    random.shuffle(character_folders)

    num_train = 1200
    return character_folders[:num_train], character_folders[num_train:]
42 |
class OmniglotTask(object):
    """One few-shot episode drawn from a list of character folders.

    ``num_classes`` folders are sampled at random; the label of a class is
    its position in that sampled order. For every class the files are
    shuffled, the first ``train_num`` become ``train_roots`` and the next
    ``test_num`` become ``test_roots``. ``train_labels`` / ``test_labels``
    are aligned with those path lists.
    """

    def __init__(self, character_folders, num_classes, train_num,test_num):

        self.character_folders = character_folders
        self.num_classes = num_classes
        self.train_num = train_num
        self.test_num = test_num

        class_folders = random.sample(self.character_folders,self.num_classes)
        labels = np.array(range(len(class_folders)))
        labels = dict(zip(class_folders, labels))

        self.train_roots = []
        self.test_roots = []
        self.train_labels = []
        self.test_labels = []
        for c in class_folders:

            temp = [os.path.join(c, x) for x in os.listdir(c)]
            shuffled = random.sample(temp, len(temp))

            train_part = shuffled[:train_num]
            test_part = shuffled[train_num:train_num+test_num]
            self.train_roots += train_part
            self.test_roots += test_part
            # Label straight from the folder being processed rather than by
            # parsing it back out of each path, which raised KeyError for
            # absolute or trailing-slash folder names.
            self.train_labels += [labels[c]] * len(train_part)
            self.test_labels += [labels[c]] * len(test_part)

    def get_class(self, sample):
        """Return the parent directory of ``sample`` (its class folder)."""
        return os.path.dirname(sample)
76 |
77 |
class FewShotDataset(Dataset):
    """Abstract dataset exposing one split of a few-shot task.

    With ``split == 'train'`` the task's ``train_roots`` / ``train_labels``
    are served; any other value serves ``test_roots`` / ``test_labels``.
    Subclasses implement ``__getitem__`` to load an item.
    """

    def __init__(self, task, split='train', transform=None, target_transform=None):
        self.task = task
        self.split = split
        self.transform = transform  # stored for subclasses: applied to the input image
        self.target_transform = target_transform  # stored for subclasses: applied to the label
        if split == 'train':
            self.image_roots = task.train_roots
            self.labels = task.train_labels
        else:
            self.image_roots = task.test_roots
            self.labels = task.test_labels

    def __len__(self):
        # Number of items in the selected split.
        return len(self.image_roots)

    def __getitem__(self, idx):
        raise NotImplementedError("This is an abstract class. Subclass this class for your particular dataset.")
93 |
94 |
class Omniglot(FewShotDataset):
    """FewShotDataset that loads each sample as a 28x28 grayscale PIL image."""

    def __init__(self, *args, **kwargs):
        super(Omniglot, self).__init__(*args, **kwargs)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for index ``idx`` with optional transforms applied."""
        picture = Image.open(self.image_roots[idx]).convert('L')
        # per Chelsea's implementation
        picture = picture.resize((28,28), resample=Image.LANCZOS)
        if self.transform is not None:
            picture = self.transform(picture)

        target = self.labels[idx]
        if self.target_transform is not None:
            target = self.target_transform(target)
        return picture, target
112 |
class ClassBalancedSampler(Sampler):
    ''' Samples up to 'num_per_class' examples from each of 'num_cl' pools
        of examples of size 'num_inst'.

        The dataset is assumed to be grouped by class, i.e. class ``j`` occupies
        indices ``j*num_inst .. (j+1)*num_inst - 1``. With ``shuffle=True`` the
        picks inside each class are random and the final index list is shuffled;
        otherwise the first ``num_per_class`` indices of each class are returned
        in class order. '''

    def __init__(self, num_per_class, num_cl, num_inst,shuffle=True):
        self.num_per_class = num_per_class
        self.num_cl = num_cl
        self.num_inst = num_inst
        self.shuffle = shuffle

    def __iter__(self):
        # return a single list of indices, assuming that items will be grouped by class
        batch = []
        for j in range(self.num_cl):
            if self.shuffle:
                # .tolist() yields plain Python ints instead of 0-dim tensors.
                picks = torch.randperm(self.num_inst)[:self.num_per_class].tolist()
            else:
                picks = list(range(self.num_inst))[:self.num_per_class]
            batch.extend(i + j * self.num_inst for i in picks)

        if self.shuffle:
            random.shuffle(batch)
        return iter(batch)

    def __len__(self):
        # Number of indices actually produced by __iter__ (previously hard-coded
        # to 1). A class cannot contribute more than num_inst indices.
        return min(self.num_per_class, self.num_inst) * self.num_cl
137 |
138 |
def get_data_loader(task, num_per_class=1, split='train',shuffle=True,rotation=0):
    """Build a DataLoader that yields one class-balanced batch for ``task``.

    Args:
        task: task object providing ``num_classes``, ``train_num``, ``test_num``
            and the path/label lists read by the dataset.
        num_per_class: number of instances drawn PER CLASS; the batch size is
            ``num_per_class * task.num_classes``.
        split: ``'train'`` selects the support split, anything else the test split.
        shuffle: forwarded to ``ClassBalancedSampler``.
        rotation: angle passed to ``Rotate`` before tensor conversion.

    Returns:
        A ``DataLoader`` over an ``Omniglot`` dataset for the chosen split.
    """
    # NOTE: batch size here is # instances PER CLASS
    # Omniglot images are converted to single-channel ('L'), so the statistics
    # must be single-channel too. The previous 3-value form only ever used its
    # first entry on old torchvision and raises a broadcast error on current
    # torchvision.
    normalize = transforms.Normalize(mean=[0.92206], std=[0.08426])

    dataset = Omniglot(task,split=split,transform=transforms.Compose([Rotate(rotation),transforms.ToTensor(),normalize]))

    # The pool size per class depends on which split is being sampled.
    pool_size = task.train_num if split == 'train' else task.test_num
    sampler = ClassBalancedSampler(num_per_class, task.num_classes, pool_size,shuffle=shuffle)
    loader = DataLoader(dataset, batch_size=num_per_class*task.num_classes, sampler=sampler)

    return loader
152 |
153 |
--------------------------------------------------------------------------------
/openmic/p1-p2/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
class FeatureEncoder(nn.Module):
    """Four-block convolutional encoder producing 64-channel feature maps.

    Blocks 1 and 2 use unpadded 3x3 convolutions followed by 2x2 max pooling;
    blocks 3 and 4 use padded 3x3 convolutions without pooling, so they keep
    the spatial size. The output is returned as a 4-D feature map (not flattened).
    """

    def __init__(self):
        super(FeatureEncoder, self).__init__()
        self.layer1 = self._conv_block(3, padding=0, pool=True)
        self.layer2 = self._conv_block(64, padding=0, pool=True)
        self.layer3 = self._conv_block(64, padding=1, pool=False)
        self.layer4 = self._conv_block(64, padding=1, pool=False)

    @staticmethod
    def _conv_block(in_channels, padding, pool):
        """Conv(3x3) -> BatchNorm -> ReLU, optionally followed by MaxPool(2)."""
        stages = [
            nn.Conv2d(in_channels,64,kernel_size=3,padding=padding),
            nn.BatchNorm2d(64, momentum=1, affine=True),
            nn.ReLU(),
        ]
        if pool:
            stages.append(nn.MaxPool2d(2))
        return nn.Sequential(*stages)

    def forward(self,x):
        """Run ``x`` through the four blocks in order and return the feature map."""
        for block in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = block(x)
        return x  # 64 channels
35 |
class SimilarityNetwork(nn.Module):
    """Scores a pair of stacked 2-channel maps with a value in [0, 1].

    Four Conv(3x3, no padding) -> BatchNorm -> ReLU -> MaxPool(2) blocks are
    followed by two fully connected layers and a sigmoid.

    Args:
        input_size: channel count used to size the first linear layer; the
            layer expects ``input_size * 4`` flattened features (64 channels
            at 2x2 for a 64x64 input).
        hidden_size: width of the hidden fully connected layer.
    """

    def __init__(self,input_size,hidden_size):
        super(SimilarityNetwork, self).__init__()
        # Shape comments assume a 64x64 input.
        self.layer1 = nn.Sequential(
                        nn.Conv2d(2,64,kernel_size=3,padding=0),
                        nn.BatchNorm2d(64, momentum=1, affine=True),
                        nn.ReLU(),
                        nn.MaxPool2d(2)) #Nx64x31x31
        self.layer2 = nn.Sequential(
                        nn.Conv2d(64,64,kernel_size=3,padding=0),
                        nn.BatchNorm2d(64, momentum=1, affine=True),
                        nn.ReLU(),
                        nn.MaxPool2d(2)) #Nx64x14x14
        self.layer3 = nn.Sequential(
                        nn.Conv2d(64,64,kernel_size=3,padding=0),
                        nn.BatchNorm2d(64, momentum=1, affine=True),
                        nn.ReLU(),
                        nn.MaxPool2d(2)) #Nx64x6x6
        self.layer4 = nn.Sequential(
                        nn.Conv2d(64,64,kernel_size=3,padding=0),
                        nn.BatchNorm2d(64, momentum=1, affine=True),
                        nn.ReLU(),
                        nn.MaxPool2d(2)) #Nx64x2x2
        self.fc1 = nn.Linear(input_size*4,hidden_size)
        self.fc2 = nn.Linear(hidden_size,1)

    def forward(self,x):
        """Return an ``(N, 1)`` tensor of similarity scores for ``x`` of shape ``(N, 2, H, W)``."""
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = out.view(out.size(0),-1)
        out = F.relu(self.fc1(out))
        # torch.sigmoid replaces the deprecated nn.functional.sigmoid.
        out = torch.sigmoid(self.fc2(out))
        return out
72 |
--------------------------------------------------------------------------------
/openmic/p1-p2/openmic_test_few_shot_SoSN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | from torch.optim.lr_scheduler import StepLR
6 | import numpy as np
7 | import task_generator as tg
8 | import os
9 | import math
10 | import argparse
11 | import scipy as sp
12 | import scipy.stats
13 | import models
14 |
# Command-line interface. Parsing happens at import time, so importing this
# module consumes sys.argv.
parser = argparse.ArgumentParser(description="One Shot Visual Recognition")
parser.add_argument("-f","--feature_dim",type = int, default = 64)
parser.add_argument("-r","--relation_dim",type = int, default = 8)
parser.add_argument("-w","--class_num",type = int, default = 5)
# NOTE(review): -s is parsed but not used; SUPPORT_NUM_PER_CLASS is fixed to 1 below.
parser.add_argument("-s","--support_num_per_class",type = int, default = 1)
parser.add_argument("-b","--query_num_per_class",type = int, default = 2)
parser.add_argument("-e","--episode",type = int, default = 100)
parser.add_argument("-t","--query_episode", type = int, default = 1000)
parser.add_argument("-l","--learning_rate", type = float, default = 0.0001)
parser.add_argument("-g","--gpu",type=int, default=0)
parser.add_argument("-u","--hidden_unit",type=int,default=10)
parser.add_argument("-sigma","--sigma", type = float, default = 1)
parser.add_argument("-lam","--lamb", type = float, default = 0)
args = parser.parse_args()

# Hyper Parameters
# METHOD doubles as the checkpoint directory name; it embeds the sigma value.
METHOD = "SoSN_Logit" + str(args.sigma) + "_Models"
FEATURE_DIM = args.feature_dim
RELATION_DIM = args.relation_dim
CLASS_NUM = args.class_num
SUPPORT_NUM_PER_CLASS = 1  # hard-coded; the -s option above is ignored
QUERY_NUM_PER_CLASS = args.query_num_per_class
EPISODE = args.episode            # number of outer evaluation rounds
TEST_EPISODE = args.query_episode # number of sampled tasks per round
LEARNING_RATE = args.learning_rate
GPU = args.gpu                    # CUDA device index passed to .cuda()
HIDDEN_UNIT = args.hidden_unit
SIGMA = args.sigma                # slope passed to power_norm
LAMBDA = args.lamb                # weight of the mean term subtracted from features
44 |
def power_norm(x, SIGMA):
    """Squash ``x`` element-wise into (-1, 1) with a sigmoid-shaped curve.

    Computes ``2 / (1 + exp(-SIGMA * x)) - 1``; ``SIGMA`` controls the slope
    around zero.
    """
    decay = torch.exp(-SIGMA*x)
    return 2/(1 + decay) - 1
48 |
def mean_confidence_interval(data, confidence=0.95):
    """Return the mean of ``data`` and the half-width of its confidence interval.

    The half-width is the standard error of the mean multiplied by the
    Student-t quantile at ``(1 + confidence) / 2`` with ``n - 1`` degrees of
    freedom.

    Args:
        data: sequence of numeric observations.
        confidence: two-sided confidence level in (0, 1).

    Returns:
        ``(mean, half_width)``.
    """
    a = 1.0*np.array(data)
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    # Use the public, argument-validating ppf instead of the private _ppf.
    h = se * scipy.stats.t.ppf((1+confidence)/2., n-1)
    return m,h
55 |
def weights_init(m):
    """Initialise one module in place; intended for ``nn.Module.apply``.

    * Conv layers: weights ~ N(0, sqrt(2 / (kh * kw * out_channels))), bias 0.
    * BatchNorm layers: weight 1, bias 0.
    * Linear layers: weights ~ N(0, 0.01), bias 1.

    Modules are matched by class name; others are left untouched. Missing
    (``None``) weights or biases are skipped.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
        if m.bias is not None:
            m.bias.data.zero_()
    elif 'BatchNorm' in classname:
        # Non-affine batch norm has no learnable weight/bias.
        if m.weight is not None:
            m.weight.data.fill_(1)
        if m.bias is not None:
            m.bias.data.zero_()
    elif 'Linear' in classname:
        m.weight.data.normal_(0, 0.01)
        if m.bias is not None:
            # Fill in place so the bias keeps its device and dtype (the old
            # code rebound .data to a fresh CPU float tensor).
            m.bias.data.fill_(1)
70 |
def _second_order_features(features):
    """Return power-normalised second-order maps, one per entry of ``features``.

    ``features`` is indexed as ``(item, 64, positions)``. For each item the
    64x64 matrix ``s @ s.T / positions`` is divided by its trace and passed
    through ``power_norm``. Output shape is ``(items, 1, 64, 64)`` on device GPU.
    """
    maps = Variable(torch.Tensor(features.size()[0], 1, 64, 64)).cuda(GPU)
    for d in range(features.size()[0]):
        s = features[d,:,:].squeeze(0)
        # NOTE(review): for LAMBDA != 0 the repeat/view layout below tiles the
        # 64 row-means across the flattened matrix, so row i is not reduced by
        # its own mean; kept as-is to stay consistent with saved models.
        # Confirm intent (s.mean(1, keepdim=True) would centre each row).
        s = s - LAMBDA * s.mean(1).repeat(1,s.size()[1]).view(s.size())
        s = (1.0 / features.size()[2]) * s.mm(s.transpose(0,1))
        maps[d,:,:,:] = power_norm(s / s.trace(), SIGMA)
    return maps


def _episode_accuracy(feature_encoder, relation_network, metaquery_folders):
    """Sample one CLASS_NUM-way, 1-shot task with 2 queries per class and return its accuracy."""
    total_rewards = 0
    counter = 0
    task = tg.MiniImagenetTask(metaquery_folders,CLASS_NUM,1,2)
    support_dataloader = tg.get_mini_imagenet_data_loader(task,num_per_class=1,split="train",shuffle=False)
    num_per_class = 2
    query_dataloader = tg.get_mini_imagenet_data_loader(task,num_per_class=num_per_class,split="query",shuffle=True)
    # next(iter(...)) works on every PyTorch version; the iterator's .next()
    # method no longer exists in current releases.
    support_images,support_labels = next(iter(support_dataloader))
    for query_images,query_labels in query_dataloader:
        query_size = query_labels.shape[0]
        # calculate features (the view assumes 19x19 encoder output maps)
        support_features = feature_encoder(Variable(support_images).cuda(GPU)).view(CLASS_NUM,SUPPORT_NUM_PER_CLASS,64,19**2).sum(1)
        query_features = feature_encoder(Variable(query_images).cuda(GPU)).view(num_per_class*CLASS_NUM,64,19**2)

        # HOP features
        H_support_features = _second_order_features(support_features)
        H_query_features = _second_order_features(query_features)

        # form relation pairs: every query against every class support map
        support_features_ext = H_support_features.unsqueeze(0).repeat(query_size,1,1,1,1)
        query_features_ext = H_query_features.unsqueeze(0).repeat(1*CLASS_NUM,1,1,1,1)
        query_features_ext = torch.transpose(query_features_ext,0,1)
        relation_pairs = torch.cat((support_features_ext,query_features_ext),2).view(-1,2,64,64)
        # calculate relation scores
        relations = relation_network(relation_pairs).view(-1,CLASS_NUM)

        _,predict_labels = torch.max(relations.data,1)

        rewards = [1 if predict_labels[j]==query_labels[j].cuda(GPU) else 0 for j in range(query_size)]

        total_rewards += np.sum(rewards)
        counter += query_size
    return total_rewards/1.0/counter


def main():
    """Evaluate saved SoSN networks on randomly sampled few-shot tasks.

    Loads the feature encoder and similarity network checkpoints from the
    METHOD directory when present, then runs EPISODE rounds of TEST_EPISODE
    tasks each, printing the mean accuracy and its 95% confidence half-width
    together with the best round seen so far.
    """
    _,metaquery_folders = tg.mini_imagenet_folders()

    print("init neural networks")

    feature_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU)
    relation_network = models.SimilarityNetwork(FEATURE_DIM,RELATION_DIM).apply(weights_init).cuda(GPU)

    encoder_path = str(METHOD + "/feature_encoder_" + str(CLASS_NUM) +"way_" + str(SUPPORT_NUM_PER_CLASS) +"shot.pkl")
    relation_path = str(METHOD + "/relation_network_"+ str(CLASS_NUM) +"way_" + str(SUPPORT_NUM_PER_CLASS) +"shot.pkl")

    if os.path.exists(encoder_path):
        feature_encoder.load_state_dict(torch.load(encoder_path))
        print("load feature encoder success")
    if os.path.exists(relation_path):
        relation_network.load_state_dict(torch.load(relation_path))
        print("load relation network success")
    if not os.path.exists(METHOD):
        # Portable replacement for shelling out to `mkdir`.
        os.makedirs(METHOD)

    # Step 3: build graph
    print("Training...")

    best_accuracy = 0.0
    best_h = 0.0

    for episode in range(EPISODE):
        with torch.no_grad():
            print("Testing...")
            accuracies = []
            for i in range(TEST_EPISODE):
                accuracies.append(_episode_accuracy(feature_encoder, relation_network, metaquery_folders))

            test_accuracy,h = mean_confidence_interval(accuracies)

            print("Test accuracy:", test_accuracy,"h:", h)
            print("Best accuracy: ", best_accuracy, "h:", best_h)

            if test_accuracy > best_accuracy:
                best_accuracy = test_accuracy
                best_h = h


if __name__ == '__main__':
    main()
163 |
--------------------------------------------------------------------------------
/openmic/p1-p2/openmic_train_few_shot_SoSN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | from torch.optim.lr_scheduler import StepLR
6 | import numpy as np
7 | import task_generator as tg
8 | import os
9 | import math
10 | import argparse
11 | import scipy as sp
12 | import scipy.stats
13 | import models
14 |
# Command-line interface. Parsing happens at import time, so importing this
# module consumes sys.argv.
parser = argparse.ArgumentParser(description="One Shot Visual Recognition")
parser.add_argument("-f","--feature_dim",type = int, default = 64)
parser.add_argument("-r","--relation_dim",type = int, default = 8)
parser.add_argument("-w","--class_num",type = int, default = 5)
# NOTE(review): -s is parsed but not used; SUPPORT_NUM_PER_CLASS is fixed to 1 below.
parser.add_argument("-s","--support_num_per_class",type = int, default = 1)
parser.add_argument("-b","--query_num_per_class",type = int, default = 2)
parser.add_argument("-e","--episode",type = int, default= 50000)
parser.add_argument("-t","--query_episode", type = int, default = 1000)
parser.add_argument("-l","--learning_rate", type = float, default = 0.0001)
parser.add_argument("-g","--gpu",type=int, default=0)
parser.add_argument("-u","--hidden_unit",type=int,default=10)
parser.add_argument("-sigma","--sigma", type = float, default = 1)
parser.add_argument("-lam","--lamb", type = float, default = 0)
args = parser.parse_args()

# Hyper Parameters
# METHOD doubles as the checkpoint directory name; it embeds the sigma value.
METHOD = "SoSN_Logit" + str(args.sigma) + "_Models"
FEATURE_DIM = args.feature_dim
RELATION_DIM = args.relation_dim
CLASS_NUM = args.class_num
SUPPORT_NUM_PER_CLASS = 1  # hard-coded; the -s option above is ignored
QUERY_NUM_PER_CLASS = args.query_num_per_class
EPISODE = args.episode            # number of training episodes
TEST_EPISODE = args.query_episode # number of sampled tasks per evaluation
LEARNING_RATE = args.learning_rate
GPU = args.gpu                    # CUDA device index passed to .cuda()
HIDDEN_UNIT = args.hidden_unit
SIGMA = args.sigma                # slope passed to power_norm
LAMBDA = args.lamb                # weight of the mean term subtracted from features
44 |
def power_norm(x, SIGMA):
    """Squash ``x`` element-wise into (-1, 1) with a sigmoid-shaped curve.

    Computes ``2 / (1 + exp(-SIGMA * x)) - 1``; ``SIGMA`` controls the slope
    around zero.
    """
    decay = torch.exp(-SIGMA*x)
    return 2/(1 + decay) - 1
48 |
def mean_confidence_interval(data, confidence=0.95):
    """Return the mean of ``data`` and the half-width of its confidence interval.

    The half-width is the standard error of the mean multiplied by the
    Student-t quantile at ``(1 + confidence) / 2`` with ``n - 1`` degrees of
    freedom.

    Args:
        data: sequence of numeric observations.
        confidence: two-sided confidence level in (0, 1).

    Returns:
        ``(mean, half_width)``.
    """
    a = 1.0*np.array(data)
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    # Use the public, argument-validating ppf instead of the private _ppf.
    h = se * scipy.stats.t.ppf((1+confidence)/2., n-1)
    return m,h
55 |
def weights_init(m):
    """Initialise one module in place; intended for ``nn.Module.apply``.

    * Conv layers: weights ~ N(0, sqrt(2 / (kh * kw * out_channels))), bias 0.
    * BatchNorm layers: weight 1, bias 0.
    * Linear layers: weights ~ N(0, 0.01), bias 1.

    Modules are matched by class name; others are left untouched. Missing
    (``None``) weights or biases are skipped.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
        if m.bias is not None:
            m.bias.data.zero_()
    elif 'BatchNorm' in classname:
        # Non-affine batch norm has no learnable weight/bias.
        if m.weight is not None:
            m.weight.data.fill_(1)
        if m.bias is not None:
            m.bias.data.zero_()
    elif 'Linear' in classname:
        m.weight.data.normal_(0, 0.01)
        if m.bias is not None:
            # Fill in place so the bias keeps its device and dtype (the old
            # code rebound .data to a fresh CPU float tensor).
            m.bias.data.fill_(1)
70 |
def _second_order_features(features):
    """Return power-normalised second-order maps, one per entry of ``features``.

    ``features`` is indexed as ``(item, 64, positions)``. For each item the
    64x64 matrix ``s @ s.T / positions`` is divided by its trace and passed
    through ``power_norm``. Output shape is ``(items, 1, 64, 64)`` on device GPU.
    """
    maps = Variable(torch.Tensor(features.size()[0], 1, 64, 64)).cuda(GPU)
    for d in range(features.size()[0]):
        s = features[d,:,:].squeeze(0)
        # NOTE(review): for LAMBDA != 0 the repeat/view layout below tiles the
        # 64 row-means across the flattened matrix, so row i is not reduced by
        # its own mean; kept as-is to stay consistent with saved models.
        # Confirm intent (s.mean(1, keepdim=True) would centre each row).
        s = s - LAMBDA * s.mean(1).repeat(1,s.size()[1]).view(s.size())
        s = (1.0 / features.size()[2]) * s.mm(s.transpose(0,1))
        maps[d,:,:,:] = power_norm(s / s.trace(), SIGMA)
    return maps


def _episode_accuracy(feature_encoder, relation_network, metaquery_folders):
    """Sample one CLASS_NUM-way, 1-shot task with 2 queries per class and return its accuracy."""
    total_rewards = 0
    counter = 0
    task = tg.MiniImagenetTask(metaquery_folders,CLASS_NUM,1,2)
    support_dataloader = tg.get_mini_imagenet_data_loader(task,num_per_class=1,split="train",shuffle=False)
    num_per_class = 2
    query_dataloader = tg.get_mini_imagenet_data_loader(task,num_per_class=num_per_class,split="query",shuffle=True)
    support_images,support_labels = next(iter(support_dataloader))
    for query_images,query_labels in query_dataloader:
        query_size = query_labels.shape[0]
        # calculate features (the view assumes 19x19 encoder output maps)
        support_features = feature_encoder(Variable(support_images).cuda(GPU)).view(CLASS_NUM,SUPPORT_NUM_PER_CLASS,64,19**2).sum(1)
        query_features = feature_encoder(Variable(query_images).cuda(GPU)).view(num_per_class*CLASS_NUM,64,19**2)

        # HOP features
        H_support_features = _second_order_features(support_features)
        H_query_features = _second_order_features(query_features)

        # form relation pairs: every query against every class support map
        support_features_ext = H_support_features.unsqueeze(0).repeat(query_size,1,1,1,1)
        query_features_ext = H_query_features.unsqueeze(0).repeat(1*CLASS_NUM,1,1,1,1)
        query_features_ext = torch.transpose(query_features_ext,0,1)
        relation_pairs = torch.cat((support_features_ext,query_features_ext),2).view(-1,2,64,64)
        # calculate relation scores
        relations = relation_network(relation_pairs).view(-1,CLASS_NUM)

        _,predict_labels = torch.max(relations.data,1)

        rewards = [1 if predict_labels[j]==query_labels[j].cuda(GPU) else 0 for j in range(query_size)]

        total_rewards += np.sum(rewards)
        counter += query_size
    return total_rewards/1.0/counter


def main():
    """Train the SoSN feature encoder and similarity network episodically.

    Each episode samples one CLASS_NUM-way task from the meta-train folders,
    scores every query against every class with the similarity network and
    minimises the MSE between scores and one-hot labels. Every 500 episodes
    the model is evaluated on TEST_EPISODE tasks from the meta-test folders;
    both networks are saved under METHOD whenever the mean accuracy improves.
    Existing checkpoints in METHOD are loaded at start-up.
    """
    # Step 1: init data folders
    print("init data folders")
    # init character folders for dataset construction
    metatrain_folders,metaquery_folders = tg.mini_imagenet_folders()

    # Step 2: init neural networks
    print("init neural networks")

    feature_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU)
    relation_network = models.SimilarityNetwork(FEATURE_DIM,RELATION_DIM).apply(weights_init).cuda(GPU)

    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim,step_size=50000,gamma=0.5)
    relation_network_optim = torch.optim.Adam(relation_network.parameters(),lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim,step_size=50000,gamma=0.5)

    encoder_path = str(METHOD + "/feature_encoder_" + str(CLASS_NUM) +"way_" + str(SUPPORT_NUM_PER_CLASS) +"shot.pkl")
    relation_path = str(METHOD + "/relation_network_"+ str(CLASS_NUM) +"way_" + str(SUPPORT_NUM_PER_CLASS) +"shot.pkl")

    if os.path.exists(encoder_path):
        feature_encoder.load_state_dict(torch.load(encoder_path))
        print("load feature encoder success")
    if os.path.exists(relation_path):
        relation_network.load_state_dict(torch.load(relation_path))
        print("load relation network success")
    if not os.path.exists(METHOD):
        # Portable replacement for shelling out to `mkdir`.
        os.makedirs(METHOD)

    # Step 3: build graph
    print("Training...")

    best_accuracy = 0.0
    best_h = 0.0

    # The loss module is stateless, so build it once instead of per episode.
    mse = nn.MSELoss().cuda(GPU)

    for episode in range(EPISODE):
        feature_encoder_scheduler.step(episode)
        relation_network_scheduler.step(episode)

        # init dataset
        # support_dataloader is to obtain previous supports for compare
        # query_dataloader is to query supports for training
        task = tg.MiniImagenetTask(metatrain_folders,CLASS_NUM,SUPPORT_NUM_PER_CLASS,QUERY_NUM_PER_CLASS)
        support_dataloader = tg.get_mini_imagenet_data_loader(task,num_per_class=SUPPORT_NUM_PER_CLASS,split="train",shuffle=False)
        query_dataloader = tg.get_mini_imagenet_data_loader(task,num_per_class=QUERY_NUM_PER_CLASS,split="test",shuffle=True)

        # support datas
        # next(iter(...)) works on every PyTorch version; the iterator's
        # .next() method no longer exists in current releases.
        supports,support_labels = next(iter(support_dataloader))
        queries,query_labels = next(iter(query_dataloader))

        # calculate features (the view assumes 19x19 encoder output maps)
        support_features = feature_encoder(Variable(supports).cuda(GPU)).view(CLASS_NUM,SUPPORT_NUM_PER_CLASS,64,19**2).sum(1)
        query_features = feature_encoder(Variable(queries).cuda(GPU)).view(QUERY_NUM_PER_CLASS*CLASS_NUM,64,19**2)

        # HOP features
        H_support_features = _second_order_features(support_features)
        H_query_features = _second_order_features(query_features)

        # form relation pairs
        support_features_ext = H_support_features.unsqueeze(0).repeat(QUERY_NUM_PER_CLASS*CLASS_NUM,1,1,1,1)
        query_features_ext = H_query_features.unsqueeze(0).repeat(CLASS_NUM,1,1,1,1)
        query_features_ext = torch.transpose(query_features_ext,0,1)
        relation_pairs = torch.cat((support_features_ext,query_features_ext),2).view(-1,2,64,64)
        # calculate relation scores
        relations = relation_network(relation_pairs).view(-1,CLASS_NUM*SUPPORT_NUM_PER_CLASS)

        # define loss function
        one_hot_labels = Variable(torch.zeros(QUERY_NUM_PER_CLASS*CLASS_NUM, CLASS_NUM).scatter_(1, query_labels.view(-1,1), 1)).cuda(GPU)
        loss = mse(relations,one_hot_labels)

        # updating network parameters with their gradients
        feature_encoder.zero_grad()
        relation_network.zero_grad()

        loss.backward()

        feature_encoder_optim.step()
        relation_network_optim.step()

        if (episode+1)%100 == 0:
            # loss is a 0-dim tensor; indexing it (loss.data[0]) raises on
            # current PyTorch, .item() is the supported accessor.
            print("episode:",episode+1,"loss",loss.item())

        if episode%500 == 0:
            # query
            print("Testing...")

            accuracies = []
            for i in range(TEST_EPISODE):
                with torch.no_grad():
                    accuracies.append(_episode_accuracy(feature_encoder, relation_network, metaquery_folders))

            test_accuracy,h = mean_confidence_interval(accuracies)

            print("Test accuracy:", test_accuracy,"h:",h)
            print("Best accuracy: ", best_accuracy, "h:", best_h)

            if test_accuracy > best_accuracy:
                # save networks
                torch.save(feature_encoder.state_dict(),encoder_path)
                torch.save(relation_network.state_dict(),relation_path)
                print("save networks for episode:",episode)

                best_accuracy = test_accuracy
                best_h = h


if __name__ == '__main__':
    main()
231 |
--------------------------------------------------------------------------------
/openmic/p1-p2/task_generator.py:
--------------------------------------------------------------------------------
1 | # code is based on https://github.com/katerakelly/pytorch-maml
2 | import torchvision
3 | import torchvision.datasets as dset
4 | import torchvision.transforms as transforms
5 | import torch
6 | from torch.utils.data import DataLoader,Dataset
7 | import random
8 | import os
9 | from PIL import Image
10 | import matplotlib.pyplot as plt
11 | import numpy as np
12 | from torch.utils.data.sampler import Sampler
13 |
def imshow(img):
    """Show ``img`` with matplotlib, axes hidden.

    ``img`` must provide ``.numpy()``; its first axis is moved to the end
    before display (e.g. channels-first to channels-last).
    """
    pixels = img.numpy()
    plt.axis("off")
    plt.imshow(pixels.transpose((1,2,0)))
    plt.show()
19 |
class Rotate(object):
    """Callable transform that rotates its input by a fixed angle via ``x.rotate``."""

    def __init__(self, angle):
        self.angle = angle

    def __call__(self, x, mode="reflect"):
        # `mode` is accepted for interface compatibility but is not used.
        return x.rotate(self.angle)
26 |
def mini_imagenet_folders(train_folder='../../datas/openmic/part1',
                          test_folder='../../datas/openmic/part2'):
    """Collect the class sub-directories of the meta-train and meta-test roots.

    Args:
        train_folder: root directory holding one sub-directory per training class.
        test_folder: root directory holding one sub-directory per test class.

    Returns:
        ``(metatrain_folders, metatest_folders)``: two lists of directory
        paths, each shuffled with a fixed seed.

    Note:
        Reseeds Python's global ``random`` generator with 1 as a side effect.
    """
    def _class_dirs(root):
        # os.listdir order is unspecified; sort first so the seeded shuffle
        # below gives the same order on every machine.
        return [os.path.join(root, label)
                for label in sorted(os.listdir(root))
                if os.path.isdir(os.path.join(root, label))]

    metatrain_folders = _class_dirs(train_folder)
    metatest_folders = _class_dirs(test_folder)

    random.seed(1)
    random.shuffle(metatrain_folders)
    random.shuffle(metatest_folders)

    return metatrain_folders,metatest_folders
45 |
class MiniImagenetTask(object):
    """One randomly sampled few-shot task (episode).

    Args:
        character_folders: list of class directories to sample from.
        num_classes: number of classes drawn for this task.
        train_num: number of support (train) samples kept per class.
        test_num: number of query (test) samples kept per class.

    Attributes:
        train_roots / test_roots: image paths, grouped by class in sampling order.
        train_labels / test_labels: task-local labels (0 .. num_classes-1)
            aligned with the corresponding ``*_roots`` list.
    """

    def __init__(self, character_folders, num_classes, train_num,test_num):

        self.character_folders = character_folders
        self.num_classes = num_classes
        self.train_num = train_num
        self.test_num = test_num

        class_folders = random.sample(self.character_folders,self.num_classes)
        # Task-local labels keep the numpy integer type the original code produced.
        labels = np.array(range(len(class_folders)))

        self.train_roots = []
        self.test_roots = []
        self.train_labels = []
        self.test_labels = []
        for label, c in zip(labels, class_folders):
            temp = [os.path.join(c, x) for x in os.listdir(c)]
            # random.sample over the full length already returns a shuffled
            # copy, so the extra random.shuffle pass was redundant.
            shuffled = random.sample(temp, len(temp))

            train_part = shuffled[:train_num]
            test_part = shuffled[train_num:train_num+test_num]

            self.train_roots += train_part
            self.test_roots += test_part
            # Assign labels directly instead of re-deriving the class folder from
            # each path: splitting on '/' and re-joining drops the leading slash
            # of absolute paths, which made the label lookup raise KeyError.
            self.train_labels += [label] * len(train_part)
            self.test_labels += [label] * len(test_part)

    def get_class(self, sample):
        """Return the directory part of the path ``sample``."""
        return os.path.dirname(sample)
76 |
77 |
class FewShotDataset(Dataset):
    """Base dataset over one task's image paths and labels.

    ``split == 'train'`` selects the task's ``train_roots`` / ``train_labels``;
    any other value selects ``test_roots`` / ``test_labels``. Subclasses must
    implement ``__getitem__`` to actually load a sample.
    """

    def __init__(self, task, split='train', transform=None, target_transform=None):
        self.transform = transform  # applied to the input image by subclasses
        self.target_transform = target_transform
        self.task = task
        self.split = split
        if split == 'train':
            self.image_roots = task.train_roots
            self.labels = task.train_labels
        else:
            self.image_roots = task.test_roots
            self.labels = task.test_labels

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.image_roots)

    def __getitem__(self, idx):
        raise NotImplementedError("This is an abstract class. Subclass this class for your particular dataset.")
93 |
class MiniImagenet(FewShotDataset):
    """FewShotDataset that loads each sample as an RGB PIL image."""

    def __init__(self, *args, **kwargs):
        super(MiniImagenet, self).__init__(*args, **kwargs)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for index ``idx`` with optional transforms applied."""
        picture = Image.open(self.image_roots[idx]).convert('RGB')
        if self.transform is not None:
            picture = self.transform(picture)

        target = self.labels[idx]
        if self.target_transform is not None:
            target = self.target_transform(target)
        return picture, target
109 |
110 |
class ClassBalancedSampler(Sampler):
    ''' Samples up to 'num_per_class' examples from each of 'num_cl' pools
        of examples of size 'num_inst'.

        The dataset is assumed to be grouped by class, i.e. class ``j`` occupies
        indices ``j*num_inst .. (j+1)*num_inst - 1``. With ``shuffle=True`` the
        picks inside each class are random and the final index list is shuffled;
        otherwise the first ``num_per_class`` indices of each class are returned
        in class order. '''

    def __init__(self, num_per_class, num_cl, num_inst,shuffle=True):
        self.num_per_class = num_per_class
        self.num_cl = num_cl
        self.num_inst = num_inst
        self.shuffle = shuffle

    def __iter__(self):
        # return a single list of indices, assuming that items will be grouped by class
        batch = []
        for j in range(self.num_cl):
            if self.shuffle:
                # .tolist() yields plain Python ints instead of 0-dim tensors.
                picks = torch.randperm(self.num_inst)[:self.num_per_class].tolist()
            else:
                picks = list(range(self.num_inst))[:self.num_per_class]
            batch.extend(i + j * self.num_inst for i in picks)

        if self.shuffle:
            random.shuffle(batch)
        return iter(batch)

    def __len__(self):
        # Number of indices actually produced by __iter__ (previously hard-coded
        # to 1). A class cannot contribute more than num_inst indices.
        return min(self.num_per_class, self.num_inst) * self.num_cl
135 |
136 |
def get_mini_imagenet_data_loader(task, num_per_class=1, split='train',shuffle = False):
    """Build a DataLoader that yields one class-balanced batch for ``task``.

    Images are only converted to tensors; no normalisation is applied.

    Args:
        task: task object providing ``num_classes``, ``train_num``, ``test_num``
            and the path/label lists read by the dataset.
        num_per_class: number of instances drawn per class; the batch size is
            ``num_per_class * task.num_classes``.
        split: ``'train'`` selects the support split, anything else the test split.
        shuffle: forwarded to ``ClassBalancedSampler``.

    Returns:
        A ``DataLoader`` over a ``MiniImagenet`` dataset for the chosen split.
    """
    # The previously constructed Normalize object was never added to the
    # transform pipeline, so it has been dropped.
    dataset = MiniImagenet(task,split=split,transform=transforms.Compose([transforms.ToTensor()]))

    # The pool size per class depends on which split is being sampled.
    pool_size = task.train_num if split == 'train' else task.test_num
    sampler = ClassBalancedSampler(num_per_class, task.num_classes, pool_size,shuffle=shuffle)

    loader = DataLoader(dataset, batch_size=num_per_class*task.num_classes, sampler=sampler)

    return loader
150 |
151 |
--------------------------------------------------------------------------------
/openmic/p1-p2/task_generator.py~:
--------------------------------------------------------------------------------
1 | # code is based on https://github.com/katerakelly/pytorch-maml
2 | import torchvision
3 | import torchvision.datasets as dset
4 | import torchvision.transforms as transforms
5 | import torch
6 | from torch.utils.data import DataLoader,Dataset
7 | import random
8 | import os
9 | from PIL import Image
10 | import matplotlib.pyplot as plt
11 | import numpy as np
12 | from torch.utils.data.sampler import Sampler
13 |
def imshow(img):
    """Show ``img`` with matplotlib, axes hidden.

    ``img`` must provide ``.numpy()``; its first axis is moved to the end
    before display (e.g. channels-first to channels-last).
    """
    pixels = img.numpy()
    plt.axis("off")
    plt.imshow(pixels.transpose((1,2,0)))
    plt.show()
19 |
class Rotate(object):
    """Callable transform that rotates its input by a fixed angle via ``x.rotate``."""

    def __init__(self, angle):
        self.angle = angle

    def __call__(self, x, mode="reflect"):
        # `mode` is accepted for interface compatibility but is not used.
        return x.rotate(self.angle)
26 |
def mini_imagenet_folders(train_folder='../../datas/museum/gls/train',
                          test_folder='../../datas/museum/gls/test'):
    """Collect the class sub-directories of the meta-train and meta-test roots.

    Args:
        train_folder: root directory holding one sub-directory per training class.
        test_folder: root directory holding one sub-directory per test class.

    Returns:
        ``(metatrain_folders, metatest_folders)``: two lists of directory
        paths, each shuffled with a fixed seed.

    Note:
        Reseeds Python's global ``random`` generator with 1 as a side effect.
    """
    def _class_dirs(root):
        # os.listdir order is unspecified; sort first so the seeded shuffle
        # below gives the same order on every machine.
        return [os.path.join(root, label)
                for label in sorted(os.listdir(root))
                if os.path.isdir(os.path.join(root, label))]

    metatrain_folders = _class_dirs(train_folder)
    metatest_folders = _class_dirs(test_folder)

    random.seed(1)
    random.shuffle(metatrain_folders)
    random.shuffle(metatest_folders)

    return metatrain_folders,metatest_folders
45 |
class MiniImagenetTask(object):
    """One randomly sampled few-shot task (episode).

    Args:
        character_folders: list of class directories to sample from.
        num_classes: number of classes drawn for this task.
        train_num: number of support (train) samples kept per class.
        test_num: number of query (test) samples kept per class.

    Attributes:
        train_roots / test_roots: image paths, grouped by class in sampling order.
        train_labels / test_labels: task-local labels (0 .. num_classes-1)
            aligned with the corresponding ``*_roots`` list.
    """

    def __init__(self, character_folders, num_classes, train_num,test_num):

        self.character_folders = character_folders
        self.num_classes = num_classes
        self.train_num = train_num
        self.test_num = test_num

        class_folders = random.sample(self.character_folders,self.num_classes)
        # Task-local labels keep the numpy integer type the original code produced.
        labels = np.array(range(len(class_folders)))

        self.train_roots = []
        self.test_roots = []
        self.train_labels = []
        self.test_labels = []
        for label, c in zip(labels, class_folders):
            temp = [os.path.join(c, x) for x in os.listdir(c)]
            # random.sample over the full length already returns a shuffled
            # copy, so the extra random.shuffle pass was redundant.
            shuffled = random.sample(temp, len(temp))

            train_part = shuffled[:train_num]
            test_part = shuffled[train_num:train_num+test_num]

            self.train_roots += train_part
            self.test_roots += test_part
            # Assign labels directly instead of re-deriving the class folder from
            # each path: splitting on '/' and re-joining drops the leading slash
            # of absolute paths, which made the label lookup raise KeyError.
            self.train_labels += [label] * len(train_part)
            self.test_labels += [label] * len(test_part)

    def get_class(self, sample):
        """Return the directory part of the path ``sample``."""
        return os.path.dirname(sample)
76 |
77 |
class FewShotDataset(Dataset):
    """Base dataset over one task's image paths and labels.

    ``split == 'train'`` selects the task's ``train_roots`` / ``train_labels``;
    any other value selects ``test_roots`` / ``test_labels``. Subclasses must
    implement ``__getitem__`` to actually load a sample.
    """

    def __init__(self, task, split='train', transform=None, target_transform=None):
        self.transform = transform  # applied to the input image by subclasses
        self.target_transform = target_transform
        self.task = task
        self.split = split
        if split == 'train':
            self.image_roots = task.train_roots
            self.labels = task.train_labels
        else:
            self.image_roots = task.test_roots
            self.labels = task.test_labels

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.image_roots)

    def __getitem__(self, idx):
        raise NotImplementedError("This is an abstract class. Subclass this class for your particular dataset.")
93 |
class MiniImagenet(FewShotDataset):
    '''Few-shot dataset that loads its items from image files as RGB.'''

    def __init__(self, *args, **kwargs):
        super(MiniImagenet, self).__init__(*args, **kwargs)

    def __getitem__(self, idx):
        '''Return ``(image, label)`` for position ``idx``.

        The file at ``image_roots[idx]`` is opened and converted to RGB;
        ``transform`` and ``target_transform`` are applied when set.
        '''
        rgb = Image.open(self.image_roots[idx]).convert('RGB')
        image = rgb if self.transform is None else self.transform(rgb)

        target = self.labels[idx]
        if self.target_transform is None:
            return image, target
        return image, self.target_transform(target)
109 |
110 |
class ClassBalancedSampler(Sampler):
    '''Samples 'num_per_class' indices from each of 'num_cl' pools
    of examples of size 'num_inst'.

    Indices are computed as ``i + j * num_inst``, i.e. the dataset is
    expected to be grouped by class with ``num_inst`` consecutive items
    per class.

    Args:
        num_per_class: indices drawn from every class pool.
        num_cl: number of class pools.
        num_inst: size of each class pool.
        shuffle: when true, pick random members of each pool and shuffle
            the combined index list; otherwise take the first
            ``num_per_class`` members of each pool in order.
    '''

    def __init__(self, num_per_class, num_cl, num_inst, shuffle=True):
        self.num_per_class = num_per_class
        self.num_cl = num_cl
        self.num_inst = num_inst
        self.shuffle = shuffle

    def __iter__(self):
        '''Return an iterator over a single flat list of integer indices.'''
        if self.shuffle:
            # One randperm per class, in class order. ``tolist()`` turns
            # the 0-dim tensors into plain ints so both branches yield the
            # same element type.
            offsets_per_class = [
                torch.randperm(self.num_inst)[:self.num_per_class].tolist()
                for _ in range(self.num_cl)
            ]
        else:
            first = list(range(self.num_inst)[:self.num_per_class])
            offsets_per_class = [first for _ in range(self.num_cl)]

        batch = [
            offset + j * self.num_inst
            for j, offsets in enumerate(offsets_per_class)
            for offset in offsets
        ]

        if self.shuffle:
            random.shuffle(batch)
        return iter(batch)

    def __len__(self):
        # NOTE: this is not the number of indices produced by __iter__;
        # the value is retained unchanged for backward compatibility.
        return 1
135 |
136 |
def get_mini_imagenet_data_loader(task, num_per_class=1, split='train', shuffle=False):
    '''Build a DataLoader over one split of ``task``.

    Args:
        task: episode object providing ``num_classes``, ``train_num`` and
            ``test_num`` plus the path/label lists read by ``MiniImagenet``.
        num_per_class: images drawn per class by the sampler.
        split: ``'train'`` uses ``task.train_num`` as the per-class pool
            size; any other value uses ``task.test_num``.
        shuffle: forwarded to ``ClassBalancedSampler``.

    Returns:
        A ``DataLoader`` whose batch size is ``num_per_class *
        task.num_classes``, the same as the number of indices the sampler
        produces.

    Images are only converted with ``ToTensor``; no normalisation is
    applied.  (A previously constructed but never used ``Normalize``
    transform has been removed.)
    '''
    dataset = MiniImagenet(
        task,
        split=split,
        transform=transforms.Compose([transforms.ToTensor()]),
    )

    # Per-class pool size depends on which split is being served.
    num_inst = task.train_num if split == 'train' else task.test_num
    sampler = ClassBalancedSampler(num_per_class, task.num_classes, num_inst, shuffle=shuffle)

    return DataLoader(dataset, batch_size=num_per_class * task.num_classes, sampler=sampler)
150 |
151 |
--------------------------------------------------------------------------------