├── CovaMNet_Test_5way1shot.py
├── CovaMNet_Test_5way5shot.py
├── CovaMNet_Train_5way1shot.py
├── CovaMNet_Train_5way5shot.py
├── LICENSE
├── README.md
├── dataset
    ├── CubBird
    │   └── CubBird_prepare_csv.py
    ├── StanfordCar
    │   └── StanforCar_prepare_csv.py
    ├── StanfordDog
    │   └── StanfordDog_prepare_csv.py
    └── datasets_csv.py
├── imgs
    ├── CovaMNet.bmp
    ├── result_finegrained.bmp
    └── results_miniImageNet.bmp
├── models
    └── network.py
└── results
    └── CovaMNet_miniImageNet_Conv64_5_Way_1_Shot
        ├── model_best.pth.tar
        └── opt_resutls.txt
/CovaMNet_Test_5way1shot.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | """
5 | Author: Wenbin Li (liwenbin.nju@gmail.com)
6 | Date: Jan. 14, 2019
7 | Version: V0
8 |
9 | Citation:
10 | @inproceedings{li2019CovaMNet,
11 | title={Distribution Consistency based Covariance Metric Networks for Few-shot Learning},
12 | author={Li, Wenbin and Xu, Jinglin and Huo, Jing and Wang, Lei and Gao Yang and Luo, Jiebo},
13 | booktitle={AAAI},
14 | year={2019}
15 | }
16 | """
17 |
18 |
19 | from __future__ import print_function
20 | import argparse
21 | import os
22 | import random
23 | import shutil
24 | import numpy as np
25 | import torch
26 | import torch.nn as nn
27 | import torch.nn.parallel
28 | import torch.backends.cudnn as cudnn
29 | import torch.optim as optim
30 | import torch.utils.data
31 | import torchvision.datasets as dset
32 | import torchvision.transforms as transforms
33 | import torchvision.utils as vutils
34 | from torch.autograd import grad
35 | import time
36 | from torch import autograd
37 | from PIL import ImageFile
38 | import scipy as sp
39 | import scipy.stats
40 | import sys
41 | sys.dont_write_bytecode = True
42 |
43 |
44 |
45 | # ============================ Data & Networks =====================================
46 | from dataset.datasets_csv import Imagefolder_csv
47 | import models.network as CovaNet
48 | # ==================================================================================
49 |
50 |
# Runtime environment: accept truncated image files and expose only GPU 0.
ImageFile.LOAD_TRUNCATED_IMAGES = True
os.environ.update({'CUDA_DEVICE_ORDER': 'PCI_BUS_ID', 'CUDA_VISIBLE_DEVICES': '0'})


def _build_parser():
    """Create the command-line parser of the 5-way 1-shot test script."""
    p = argparse.ArgumentParser()

    # Data, output and backbone options.
    p.add_argument('--dataset_dir', default='/Datasets/miniImageNet--ravi', help='the path of the data')
    p.add_argument('--data_name', default='miniImageNet', help='miniImageNet|StanfordDog|StanfordCar|CubBird')
    p.add_argument('--mode', default='test', help='train|val|test')
    p.add_argument('--outf', default='./results/CovaMNet')
    p.add_argument('--resume', default='', type=str, help='path to the lastest checkpoint (default: none)')
    p.add_argument('--basemodel', default='Conv64', help='Conv64')
    p.add_argument('--workers', type=int, default=8)

    # Few-shot parameters #
    p.add_argument('--imageSize', type=int, default=84)
    p.add_argument('--episodeSize', type=int, default=1, help='the mini-batch size of training')
    p.add_argument('--testepisodeSize', type=int, default=1, help='one episode is taken as a mini-batch')
    p.add_argument('--epochs', type=int, default=30, help='the total number of training epoch')
    p.add_argument('--episode_train_num', type=int, default=10000, help='the total number of training episodes')
    p.add_argument('--episode_val_num', type=int, default=1000, help='the total number of evaluation episodes')
    p.add_argument('--episode_test_num', type=int, default=600, help='the total number of testing episodes')
    p.add_argument('--way_num', type=int, default=5, help='the number of way/class')
    p.add_argument('--shot_num', type=int, default=1, help='the number of shot')
    p.add_argument('--query_num', type=int, default=15, help='the number of queries')

    # Optimiser, device and logging options.
    p.add_argument('--lr', type=float, default=0.005, help='learning rate, default=0.005')
    p.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
    p.add_argument('--cuda', action='store_true', default=True, help='enables cuda')
    p.add_argument('--ngpu', type=int, default=1, help='the number of gpus')
    p.add_argument('--nc', type=int, default=3, help='input image channels')
    p.add_argument('--clamp_lower', type=float, default=-0.01)
    p.add_argument('--clamp_upper', type=float, default=0.01)
    p.add_argument('--print_freq', '-p', default=100, type=int, metavar='N', help='print frequency (default: 100)')
    return p


parser = _build_parser()
opt = parser.parse_args()
# CUDA is forced on regardless of the command line.
opt.cuda = True
cudnn.benchmark = True
86 |
87 |
88 |
89 |
90 | # ======================================= Define functions =============================================
def validate(val_loader, model, criterion, epoch_index, F_txt):
    """Run ``model`` over every episode of ``val_loader`` and report top-1 accuracy.

    Args:
        val_loader: iterable yielding ``(query_images, query_targets,
            support_images, support_targets)`` for each episode.
        model: network invoked as ``model(queries, supports)``; it is switched
            to eval mode here.
        criterion: loss invoked as ``criterion(output, target)``.
        epoch_index: value echoed in the progress lines.
        F_txt: open text file that receives a copy of every printed line.

    Returns:
        ``(top1.avg, accuracies)``: the running top-1 average and the list of
        per-episode top-1 results produced by ``accuracy``.

    Note:
        Reads the module-level ``opt.print_freq`` and ``best_prec1``. All
        inputs are moved with ``.cuda()``, so a CUDA device is required.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()
    accuracies = []

    end = time.time()
    # No parameter is updated during evaluation, so autograd is disabled to
    # avoid building (and holding in memory) a graph for every episode.
    with torch.no_grad():
        for episode_index, (query_images, query_targets, support_images, support_targets) in enumerate(val_loader):

            # Convert query and support images
            query_images = torch.cat(query_images, 0)
            input_var1 = query_images.cuda()

            # One concatenated tensor per entry of support_images.
            input_var2 = [torch.cat(support, 0).cuda() for support in support_images]

            # Deal with the targets
            target = torch.cat(query_targets, 0).cuda()

            # Calculate the output
            output = model(input_var1, input_var2)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, _ = accuracy(output, target, topk=(1, 3))
            losses.update(loss.item(), query_images.size(0))
            top1.update(prec1[0], query_images.size(0))
            accuracies.append(prec1)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # ============== print the intermediate results ==============#
            if episode_index % opt.print_freq == 0 and episode_index != 0:
                progress = ('Test-({0}): [{1}/{2}]\t'
                            'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                            'Loss {loss.val:.3f} ({loss.avg:.3f})\t'
                            'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                                epoch_index, episode_index, len(val_loader),
                                batch_time=batch_time, loss=losses, top1=top1))
                print(progress)
                print(progress, file=F_txt)

    summary = ' * Prec@1 {top1.avg:.3f} Best_prec1 {best_prec1:.3f}'.format(top1=top1, best_prec1=best_prec1)
    print(summary)
    print(summary, file=F_txt)

    return top1.avg, accuracies
159 |
160 |
class AverageMeter(object):
    """Keep the latest value and the running weighted mean of a series."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
177 |
178 |
179 |
def accuracy(output, target, topk=(1,)):
    """Compute the precision@k for each k in ``topk``.

    Args:
        output: 2-D score tensor; the top-k is taken along dim 1.
        target: tensor of ground-truth class indices, one per row of ``output``.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one single-element tensor per k, holding the percentage
        (0-100) of rows whose target is among the k highest-scoring classes.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # After the transpose, pred is (maxk, batch) and non-contiguous.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape() rather than view(): the slice can inherit the
            # transposed, non-contiguous layout, on which view(-1) raises a
            # RuntimeError in recent PyTorch releases when k > 1.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
195 |
196 |
def mean_confidence_interval(data, confidence=0.95):
    """Return the mean of ``data`` and the half-width of its confidence interval.

    Args:
        data: sequence of tensors; each is moved to the CPU and converted to a
            NumPy array.
        confidence: two-sided confidence level of the Student-t interval.

    Returns:
        ``(m, h)``: ``m`` is the mean over all values; ``h`` is the standard
        error times the t quantile with ``len(data) - 1`` degrees of freedom.
        ``h`` keeps the shape of one element of ``data`` (a length-1 array for
        single-element tensors), so callers index it with ``h[0]``.
    """
    a = [1.0 * np.array(item.cpu()) for item in data]
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    # Use the public ppf() API; _ppf() is a private SciPy hook that skips
    # argument validation and is not guaranteed to stay stable.
    h = se * scipy.stats.t.ppf((1 + confidence) / 2., n - 1)
    return m, h
203 |
204 |
205 | # ======================================== Settings of path ============================================
206 | # save path
# Output directory: <outf>_<data_name>_<basemodel>_<way>Way_<shot>Shot
opt.outf = '_'.join([
    opt.outf,
    opt.data_name,
    str(opt.basemodel),
    str(opt.way_num) + 'Way',
    str(opt.shot_num) + 'Shot',
])

if not os.path.exists(opt.outf):
    os.makedirs(opt.outf)

if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

# The options and all results are appended to a text log in the output directory.
txt_save_path = os.path.join(opt.outf, 'Test_resutls.txt')
F_txt = open(txt_save_path, 'a+')
print(opt)
print(opt, file=F_txt)
220 |
221 |
222 |
223 | # ========================================== Model config ===============================================
ngpu = int(opt.ngpu)
global best_prec1, epoch_index
best_prec1, epoch_index = 0, 0
model = CovaNet.define_CovarianceNet(
    which_model=opt.basemodel,
    num_classes=opt.way_num,
    norm='batch',
    init_type='normal',
    use_gpu=opt.cuda,
)

# Loss function (criterion) and optimizer.
criterion = nn.CrossEntropyLoss().cuda()
optimizer = optim.Adam(model.parameters(), lr=opt.lr, betas=(opt.beta1, 0.9))

# Optionally restore model/optimizer state and bookkeeping from a checkpoint.
if opt.resume:
    if not os.path.isfile(opt.resume):
        print("=> no checkpoint found at '{}'".format(opt.resume))
    else:
        print("=> loading checkpoint '{}'".format(opt.resume))
        checkpoint = torch.load(opt.resume)
        epoch_index = checkpoint['epoch_index']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(opt.resume, checkpoint['epoch_index']))

if opt.ngpu > 1:
    model = nn.DataParallel(model, range(opt.ngpu))

# Print the architecture of the network to the console and to the log.
print(model)
print(model, file=F_txt)
255 |
256 |
257 |
258 |
259 | # ============================================ Testing phase ========================================
print('\n............Start testing............')
start_time = time.time()
repeat_num = 5       # repeat running the testing code several times

total_accuracy = 0.0
total_h = np.zeros(repeat_num)
total_accuracy_vector = []

# image transform & normalization (identical for every round, so built once)
ImgTransform = transforms.Compose([
    transforms.Resize((opt.imageSize, opt.imageSize)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# try/finally guarantees the log file is closed even if a round fails.
try:
    for r in range(repeat_num):
        round_banner = '===================================== Round %d =====================================' % r
        print(round_banner)
        print(round_banner, file=F_txt)

        # ======================================= Folder of Datasets =======================================
        # A new dataset object and loader are created for every round.
        testset = Imagefolder_csv(
            data_dir=opt.dataset_dir, mode=opt.mode, image_size=opt.imageSize, transform=ImgTransform,
            episode_num=opt.episode_test_num, way_num=opt.way_num, shot_num=opt.shot_num, query_num=opt.query_num
        )
        print('Testset: %d-------------%d' % (len(testset), r), file=F_txt)

        # ========================================== Load Datasets =========================================
        test_loader = torch.utils.data.DataLoader(
            testset, batch_size=opt.testepisodeSize, shuffle=True,
            num_workers=int(opt.workers), drop_last=True, pin_memory=True
        )

        # =========================================== Evaluation ==========================================
        prec1, accuracies = validate(test_loader, model, criterion, epoch_index, F_txt)

        test_accuracy, h = mean_confidence_interval(accuracies)
        print("Test accuracy", test_accuracy, "h", h[0])
        print("Test accuracy", test_accuracy, "h", h[0], file=F_txt)
        total_accuracy += test_accuracy
        total_accuracy_vector.extend(accuracies)
        # h is a length-1 array (it is indexed as h[0] above); store its scalar
        # explicitly, since implicit array-to-scalar conversion is deprecated
        # in NumPy.
        total_h[r] = h[0]

    aver_accuracy, _ = mean_confidence_interval(total_accuracy_vector)
    print("Aver_accuracy:", aver_accuracy, "Aver_h", total_h.mean())
    print("Aver_accuracy:", aver_accuracy, "Aver_h", total_h.mean(), file=F_txt)
finally:
    F_txt.close()
312 |
313 |
314 | # ============================================ Testing End ========================================
315 |
--------------------------------------------------------------------------------
/CovaMNet_Test_5way5shot.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | """
5 | Author: Wenbin Li (liwenbin.nju@gmail.com)
6 | Date: Jan. 14, 2019
7 | Version: V0
8 |
9 | Citation:
10 | @inproceedings{li2019CovaMNet,
11 | title={Distribution Consistency based Covariance Metric Networks for Few-shot Learning},
12 | author={Li, Wenbin and Xu, Jinglin and Huo, Jing and Wang, Lei and Gao Yang and Luo, Jiebo},
13 | booktitle={AAAI},
14 | year={2019}
15 | }
16 | """
17 |
18 |
19 | from __future__ import print_function
20 | import argparse
21 | import os
22 | import random
23 | import shutil
24 | import numpy as np
25 | import torch
26 | import torch.nn as nn
27 | import torch.nn.parallel
28 | import torch.backends.cudnn as cudnn
29 | import torch.optim as optim
30 | import torch.utils.data
31 | import torchvision.datasets as dset
32 | import torchvision.transforms as transforms
33 | import torchvision.utils as vutils
34 | from torch.autograd import grad
35 | import time
36 | from torch import autograd
37 | from PIL import ImageFile
38 | import scipy as sp
39 | import scipy.stats
40 | import sys
41 | sys.dont_write_bytecode = True
42 |
43 |
44 |
45 | # ============================ Data & Networks =====================================
46 | from dataset.datasets_csv import Imagefolder_csv
47 | import models.network as CovaNet
48 | # ==================================================================================
49 |
50 |
# Runtime environment: accept truncated image files and expose only GPU 0.
ImageFile.LOAD_TRUNCATED_IMAGES = True
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'


# Command-line options of the 5-way 5-shot test script.
parser = argparse.ArgumentParser()
# The dataset_dir default is a placeholder; pass --dataset_dir explicitly.
parser.add_argument('--dataset_dir', default=' ', help='the path of the data')
parser.add_argument('--data_name', default='miniImageNet', help='miniImageNet|StanfordDog|StanfordCar|CubBird')
parser.add_argument('--mode', default='test', help='train|val|test')
parser.add_argument('--outf', default='./results/CovaMNet')
# The default must be the empty string: the previous ' ' is truthy, so the
# script always entered the resume branch and reported a missing checkpoint.
parser.add_argument('--resume', default='', type=str, help='path to the lastest checkpoint (default: none)')
parser.add_argument('--basemodel', default='Conv64', help='Conv64')
parser.add_argument('--workers', type=int, default=8)
# Few-shot parameters #
parser.add_argument('--imageSize', type=int, default=84)
parser.add_argument('--episodeSize', type=int, default=1, help='the mini-batch size of training')
parser.add_argument('--testepisodeSize', type=int, default=1, help='one episode is taken as a mini-batch')
parser.add_argument('--epochs', type=int, default=30, help='the total number of training epoch')
parser.add_argument('--episode_train_num', type=int, default=10000, help='the total number of training episodes')
parser.add_argument('--episode_val_num', type=int, default=1000, help='the total number of evaluation episodes')
parser.add_argument('--episode_test_num', type=int, default=600, help='the total number of testing episodes')
parser.add_argument('--way_num', type=int, default=5, help='the number of way/class')
parser.add_argument('--shot_num', type=int, default=5, help='the number of shot')
parser.add_argument('--query_num', type=int, default=15, help='the number of queries')
parser.add_argument('--lr', type=float, default=0.005, help='learning rate, default=0.005')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--cuda', action='store_true', default=True, help='enables cuda')
parser.add_argument('--ngpu', type=int, default=1, help='the number of gpus')
parser.add_argument('--nc', type=int, default=3, help='input image channels')
parser.add_argument('--clamp_lower', type=float, default=-0.01)
parser.add_argument('--clamp_upper', type=float, default=0.01)
parser.add_argument('--print_freq', '-p', default=100, type=int, metavar='N', help='print frequency (default: 100)')
opt = parser.parse_args()
# CUDA is forced on regardless of the command line.
opt.cuda = True
cudnn.benchmark = True
86 |
87 |
88 |
89 |
90 | # ======================================= Define functions =============================================
def validate(val_loader, model, criterion, epoch_index, F_txt):
    """Evaluate ``model`` on all episodes of ``val_loader``.

    Args:
        val_loader: iterable yielding ``(query_images, query_targets,
            support_images, support_targets)`` for each episode.
        model: network invoked as ``model(queries, supports)``; it is switched
            to eval mode here.
        criterion: loss invoked as ``criterion(output, target)``.
        epoch_index: value echoed in the progress lines.
        F_txt: open text file that receives a copy of every printed line.

    Returns:
        ``(top1.avg, accuracies)``: the running top-1 average and the list of
        per-episode top-1 results produced by ``accuracy``.

    Note:
        Reads the module-level ``opt.print_freq`` and ``best_prec1``. All
        inputs are moved with ``.cuda()``, so a CUDA device is required.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()
    accuracies = []

    end = time.time()
    # Gradients are never used during evaluation; disabling autograd keeps
    # PyTorch from recording a graph for each episode.
    with torch.no_grad():
        for episode_index, (query_images, query_targets, support_images, support_targets) in enumerate(val_loader):

            # Convert query and support images
            query_images = torch.cat(query_images, 0)
            input_var1 = query_images.cuda()

            # One concatenated tensor per entry of support_images.
            input_var2 = [torch.cat(support, 0).cuda() for support in support_images]

            # Deal with the targets
            target = torch.cat(query_targets, 0).cuda()

            # Calculate the output
            output = model(input_var1, input_var2)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, _ = accuracy(output, target, topk=(1, 3))
            losses.update(loss.item(), query_images.size(0))
            top1.update(prec1[0], query_images.size(0))
            accuracies.append(prec1)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # ============== print the intermediate results ==============#
            if episode_index % opt.print_freq == 0 and episode_index != 0:
                progress = ('Test-({0}): [{1}/{2}]\t'
                            'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                            'Loss {loss.val:.3f} ({loss.avg:.3f})\t'
                            'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                                epoch_index, episode_index, len(val_loader),
                                batch_time=batch_time, loss=losses, top1=top1))
                print(progress)
                print(progress, file=F_txt)

    summary = ' * Prec@1 {top1.avg:.3f} Best_prec1 {best_prec1:.3f}'.format(top1=top1, best_prec1=best_prec1)
    print(summary)
    print(summary, file=F_txt)

    return top1.avg, accuracies
159 |
160 |
class AverageMeter(object):
    """Accumulate a weighted sum and expose the current value and mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the current value, mean, sum and sample count."""
        for field in ('val', 'avg', 'sum', 'count'):
            setattr(self, field, 0)

    def update(self, val, n=1):
        """Add ``val`` with weight ``n`` and recompute the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
177 |
178 |
179 |
def accuracy(output, target, topk=(1,)):
    """Compute the precision@k for each k in ``topk``.

    Args:
        output: 2-D score tensor; the top-k is taken along dim 1.
        target: tensor of ground-truth class indices, one per row of ``output``.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one single-element tensor per k, holding the percentage
        (0-100) of rows whose target is among the k highest-scoring classes.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # After the transpose, pred is (maxk, batch) and non-contiguous.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape() rather than view(): the slice can inherit the
            # transposed, non-contiguous layout, on which view(-1) raises a
            # RuntimeError in recent PyTorch releases when k > 1.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
195 |
196 |
def mean_confidence_interval(data, confidence=0.95):
    """Return the mean of ``data`` and the half-width of its confidence interval.

    Args:
        data: sequence of tensors; each is moved to the CPU and converted to a
            NumPy array.
        confidence: two-sided confidence level of the Student-t interval.

    Returns:
        ``(m, h)``: ``m`` is the mean over all values; ``h`` is the standard
        error times the t quantile with ``len(data) - 1`` degrees of freedom.
        ``h`` keeps the shape of one element of ``data`` (a length-1 array for
        single-element tensors), so callers index it with ``h[0]``.
    """
    a = [1.0 * np.array(item.cpu()) for item in data]
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    # Use the public ppf() API; _ppf() is a private SciPy hook that skips
    # argument validation and is not guaranteed to stay stable.
    h = se * scipy.stats.t.ppf((1 + confidence) / 2., n - 1)
    return m, h
203 |
204 |
205 | # ======================================== Settings of path ============================================
206 | # save path
# Output directory: <outf>_<data_name>_<basemodel>_<way>Way_<shot>Shot
opt.outf = '{}_{}_{}_{}Way_{}Shot'.format(
    opt.outf, opt.data_name, str(opt.basemodel), str(opt.way_num), str(opt.shot_num))

if not os.path.exists(opt.outf):
    os.makedirs(opt.outf)

if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

# The options and all results are appended to a text log in the output directory.
txt_save_path = os.path.join(opt.outf, 'Test_resutls.txt')
F_txt = open(txt_save_path, 'a+')
print(opt)
print(opt, file=F_txt)
220 |
221 |
222 |
223 | # ========================================== Model config ===============================================
ngpu = int(opt.ngpu)
global best_prec1, epoch_index
best_prec1, epoch_index = 0, 0
model = CovaNet.define_CovarianceNet(
    which_model=opt.basemodel,
    num_classes=opt.way_num,
    norm='batch',
    init_type='normal',
    use_gpu=opt.cuda,
)

# Loss function (criterion) and optimizer.
criterion = nn.CrossEntropyLoss().cuda()
optimizer = optim.Adam(model.parameters(), lr=opt.lr, betas=(opt.beta1, 0.9))

# Optionally restore model/optimizer state and bookkeeping from a checkpoint.
if opt.resume:
    if not os.path.isfile(opt.resume):
        print("=> no checkpoint found at '{}'".format(opt.resume))
    else:
        print("=> loading checkpoint '{}'".format(opt.resume))
        checkpoint = torch.load(opt.resume)
        epoch_index = checkpoint['epoch_index']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(opt.resume, checkpoint['epoch_index']))

if opt.ngpu > 1:
    model = nn.DataParallel(model, range(opt.ngpu))

# Print the architecture of the network to the console and to the log.
print(model)
print(model, file=F_txt)
255 |
256 |
257 |
258 |
259 | # ============================================ Testing phase ========================================
print('\n............Start testing............')
start_time = time.time()
repeat_num = 5       # repeat running the testing code several times

total_accuracy = 0.0
total_h = np.zeros(repeat_num)
total_accuracy_vector = []

# image transform & normalization (identical for every round, so built once)
ImgTransform = transforms.Compose([
    transforms.Resize((opt.imageSize, opt.imageSize)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# try/finally guarantees the log file is closed even if a round fails.
try:
    for r in range(repeat_num):
        round_banner = '===================================== Round %d =====================================' % r
        print(round_banner)
        print(round_banner, file=F_txt)

        # ======================================= Folder of Datasets =======================================
        # A new dataset object and loader are created for every round.
        testset = Imagefolder_csv(
            data_dir=opt.dataset_dir, mode=opt.mode, image_size=opt.imageSize, transform=ImgTransform,
            episode_num=opt.episode_test_num, way_num=opt.way_num, shot_num=opt.shot_num, query_num=opt.query_num
        )
        print('Testset: %d-------------%d' % (len(testset), r), file=F_txt)

        # ========================================== Load Datasets =========================================
        test_loader = torch.utils.data.DataLoader(
            testset, batch_size=opt.testepisodeSize, shuffle=True,
            num_workers=int(opt.workers), drop_last=True, pin_memory=True
        )

        # =========================================== Evaluation ==========================================
        prec1, accuracies = validate(test_loader, model, criterion, epoch_index, F_txt)

        test_accuracy, h = mean_confidence_interval(accuracies)
        print("Test accuracy", test_accuracy, "h", h[0])
        print("Test accuracy", test_accuracy, "h", h[0], file=F_txt)
        total_accuracy += test_accuracy
        total_accuracy_vector.extend(accuracies)
        # h is a length-1 array (it is indexed as h[0] above); store its scalar
        # explicitly, since implicit array-to-scalar conversion is deprecated
        # in NumPy.
        total_h[r] = h[0]

    aver_accuracy, _ = mean_confidence_interval(total_accuracy_vector)
    print("Aver_accuracy:", aver_accuracy, "Aver_h", total_h.mean())
    print("Aver_accuracy:", aver_accuracy, "Aver_h", total_h.mean(), file=F_txt)
finally:
    F_txt.close()
312 |
313 |
314 | # ============================================ Testing End ========================================
315 |
--------------------------------------------------------------------------------
/CovaMNet_Train_5way1shot.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | """
5 | Author: Wenbin Li (liwenbin.nju@gmail.com)
6 | Date: Jan. 14, 2019
7 | Version: V0
8 |
9 | Citation:
10 | @inproceedings{li2019CovaMNet,
11 | title={Distribution Consistency based Covariance Metric Networks for Few-shot Learning},
12 | author={Li, Wenbin and Xu, Jinglin and Huo, Jing and Wang, Lei and Gao Yang and Luo, Jiebo},
13 | booktitle={AAAI},
14 | year={2019}
15 | }
16 | """
17 |
18 |
19 |
20 | from __future__ import print_function
21 | import argparse
22 | import os
23 | import random
24 | import shutil
25 | import numpy as np
26 | import torch
27 | import torch.nn as nn
28 | import torch.nn.parallel
29 | import torch.backends.cudnn as cudnn
30 | import torch.optim as optim
31 | import torch.utils.data
32 | import torchvision.datasets as dset
33 | import torchvision.transforms as transforms
34 | import torchvision.utils as vutils
35 | from torch.autograd import grad
36 | import time
37 | from torch import autograd
38 | from PIL import ImageFile
39 | import pdb
40 | import sys
41 | sys.dont_write_bytecode = True
42 |
43 |
44 | # ============================ Data & Networks =====================================
45 | from dataset.datasets_csv import Imagefolder_csv
46 | import models.network as CovaNet
47 | # ==================================================================================
48 |
49 |
# Runtime environment: accept truncated image files and expose only GPU 0.
ImageFile.LOAD_TRUNCATED_IMAGES = True
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'


# Command-line options of the 5-way 1-shot training script.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_dir', default='', help='/Datasets/miniImageNet/')
parser.add_argument('--data_name', default='miniImageNet', help='miniImageNet|StanfordDog|StanfordCar|CubBird')
parser.add_argument('--mode', default='train', help='train|val|test')
parser.add_argument('--outf', default='./results/CovaMNet')
parser.add_argument('--resume', default='', type=str, help='path to the lastest checkpoint (default: none)')
parser.add_argument('--basemodel', default='Conv64', help='Conv64')
parser.add_argument('--workers', type=int, default=8)
# Few-shot parameters #
parser.add_argument('--imageSize', type=int, default=84)
parser.add_argument('--episodeSize', type=int, default=1, help='the mini-batch size of training')
parser.add_argument('--testepisodeSize', type=int, default=1, help='one episode is taken as a mini-batch')
parser.add_argument('--epochs', type=int, default=40, help='the total number of training epoch')
parser.add_argument('--episode_train_num', type=int, default=10000, help='the total number of training episodes')
parser.add_argument('--episode_val_num', type=int, default=10000, help='the total number of evaluation episodes')
parser.add_argument('--episode_test_num', type=int, default=1000, help='the total number of testing episodes')
parser.add_argument('--way_num', type=int, default=5, help='the number of way/class')
parser.add_argument('--shot_num', type=int, default=1, help='the number of shot')
parser.add_argument('--query_num', type=int, default=15, help='the number of queries')
# The help text now states the actual default (it previously said 0.005).
parser.add_argument('--lr', type=float, default=0.001, help='learning rate, default=0.001')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--cuda', action='store_true', default=True, help='enables cuda')
parser.add_argument('--ngpu', type=int, default=1, help='the number of gpus')
parser.add_argument('--nc', type=int, default=3, help='input image channels')
parser.add_argument('--clamp_lower', type=float, default=-0.01)
parser.add_argument('--clamp_upper', type=float, default=0.01)
parser.add_argument('--print_freq', '-p', default=100, type=int, metavar='N', help='print frequency (default: 100)')
opt = parser.parse_args()
# CUDA is forced on regardless of the command line.
opt.cuda = True
cudnn.benchmark = True
85 |
86 |
87 |
88 | # ======================================= Define functions =============================================
def adjust_learning_rate(optimizer, epoch_num):
    """Apply the step schedule: ``opt.lr`` multiplied by 0.05 for every full block of 10 epochs."""
    completed_steps = epoch_num // 10
    new_lr = opt.lr * (0.05 ** completed_steps)
    # Every parameter group receives the same learning rate.
    for group in optimizer.param_groups:
        group['lr'] = new_lr
94 |
95 |
def train(train_loader, model, criterion, optimizer, epoch_index, F_txt):
    """Run one pass over ``train_loader``, updating ``model`` after every episode.

    Each item of ``train_loader`` is unpacked as ``(query_images,
    query_targets, support_images, support_targets)``. Progress lines are
    printed and copied to ``F_txt`` every ``opt.print_freq`` episodes
    (episode 0 excluded). The function does not change the train/eval mode of
    ``model`` and returns nothing.
    """
    batch_time, data_time = AverageMeter(), AverageMeter()
    losses, top1 = AverageMeter(), AverageMeter()

    end = time.time()
    for episode_index, episode in enumerate(train_loader):
        query_images, query_targets, support_images, support_targets = episode

        # Time spent waiting for the data loader.
        data_time.update(time.time() - end)

        # Queries become one tensor; each support entry becomes its own tensor.
        query_images = torch.cat(query_images, 0)
        input_var1 = query_images.cuda()
        input_var2 = [torch.cat(support, 0).cuda() for support in support_images]

        # Ground-truth labels of the queries.
        target = torch.cat(query_targets, 0).cuda()

        # Forward pass and loss.
        output = model(input_var1, input_var2)
        loss = criterion(output, target)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Book-keeping of accuracy and loss for this episode.
        prec1, _ = accuracy(output, target, topk=(1, 3))
        num_queries = query_images.size(0)
        losses.update(loss.item(), num_queries)
        top1.update(prec1[0], num_queries)

        # Total time of this episode (loading + compute).
        batch_time.update(time.time() - end)
        end = time.time()

        # ============== print the intermediate results ==============#
        if episode_index == 0 or episode_index % opt.print_freq != 0:
            continue

        progress = ('Eposide-({0}): [{1}/{2}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    'Loss {loss.val:.3f} ({loss.avg:.3f})\t'
                    'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                        epoch_index, episode_index, len(train_loader),
                        batch_time=batch_time, data_time=data_time, loss=losses, top1=top1))
        print(progress)
        print(progress, file=F_txt)
161 |
162 |
163 |
def validate(val_loader, model, criterion, epoch_index, best_prec1, F_txt):
    """Evaluate ``model`` on every episode yielded by ``val_loader``.

    The model is switched to eval mode and the forward pass runs under
    ``torch.no_grad()``: no parameter is updated here, so building an
    autograd graph for each episode would only cost memory and time.

    Args:
        val_loader: iterable of
            ``(query_images, query_targets, support_images, support_targets)``
            episodes; images/targets are concatenated along dim 0.
        model: network called as ``model(queries, supports)``.
        criterion: loss called as ``criterion(output, target)``.
        epoch_index: only used in the progress messages.
        best_prec1: best Prec@1 so far, only used in the summary message.
        F_txt: open text file that mirrors everything printed to stdout.

    Returns:
        ``(top1.avg, accuracies)`` - the query-weighted mean Prec@1 and the
        list of per-episode Prec@1 values returned by ``accuracy``.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()
    accuracies = []

    # One template shared by the console and the log-file output.
    log_template = ('Test-({0}): [{1}/{2}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Loss {loss.val:.3f} ({loss.avg:.3f})\t'
                    'Prec@1 {top1.val:.3f} ({top1.avg:.3f})')

    end = time.time()
    for episode_index, (query_images, query_targets, support_images, support_targets) in enumerate(val_loader):

        # Convert query and support images
        query_images = torch.cat(query_images, 0)
        input_var1 = query_images.cuda()
        input_var2 = [torch.cat(class_support, 0).cuda() for class_support in support_images]

        # Deal with the targets
        target = torch.cat(query_targets, 0).cuda()

        # Calculate the output without recording gradients
        with torch.no_grad():
            output = model(input_var1, input_var2)
            loss = criterion(output, target)

        # measure accuracy and record loss
        prec1, _ = accuracy(output, target, topk=(1, 3))
        losses.update(loss.item(), query_images.size(0))
        top1.update(prec1[0], query_images.size(0))
        accuracies.append(prec1)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # ============== print the intermediate results ==============#
        if episode_index % opt.print_freq == 0 and episode_index != 0:
            message = log_template.format(
                epoch_index, episode_index, len(val_loader),
                batch_time=batch_time, loss=losses, top1=top1)
            print(message)
            print(message, file=F_txt)

    summary = ' * Prec@1 {top1.avg:.3f} Best_prec1 {best_prec1:.3f}'.format(top1=top1, best_prec1=best_prec1)
    print(summary)
    print(summary, file=F_txt)

    return top1.avg, accuracies
232 |
233 |
def save_checkpoint(state, filename='checkpoint.pth.tar'):
    """Serialise ``state`` to ``filename`` with ``torch.save``."""
    target_path = filename
    torch.save(state, target_path)
236 |
237 |
238 |
class AverageMeter(object):
    """Keeps the most recent value of a metric and its running weighted mean.

    Attributes:
        val: value passed to the latest ``update`` call.
        sum: accumulated ``val * n`` over all updates.
        count: accumulated ``n`` over all updates.
        avg: ``sum / count`` after the latest update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean."""
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count
255 |
256 |
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: 2-D score tensor; the top-k classes are taken along dim 1.
        target: 1-D tensor of ground-truth class indices, one per row of
            ``output``.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one 1-element tensor per k, holding the percentage
        (0-100) of rows whose target is among the k highest scores.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # Indices of the maxk best classes, transposed to (maxk, batch).
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # ``correct`` comes from a transposed tensor and may be
            # non-contiguous; ``view(-1)`` raises on such a slice for k > 1,
            # whereas ``reshape(-1)`` copies only when it has to.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
272 |
273 |
# ======================================== Settings of path ============================================
# Output directory name: <outf>_<data_name>_<basemodel>_<way>Way_<shot>Shot
opt.outf = '_'.join([
    opt.outf,
    opt.data_name,
    str(opt.basemodel),
    str(opt.way_num) + 'Way',
    str(opt.shot_num) + 'Shot',
])

# Create the output directory on first use
if not os.path.exists(opt.outf):
    os.makedirs(opt.outf)

# Remind the user when a GPU is present but CUDA was not requested
if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

# Log file (append mode) that receives the options and all later results
txt_save_path = os.path.join(opt.outf, 'opt_resutls.txt')
F_txt = open(txt_save_path, 'a+')
print(opt)
print(opt, file=F_txt)
289 |
290 |
291 |
# ========================================== Model Config ===============================================
# Builds the network, the loss and the optimizer, then optionally restores
# them from the checkpoint given by --resume.
ngpu = int(opt.ngpu)
# A ``global`` statement at module scope has no effect; these names are
# module-level variables either way.
global best_prec1, epoch_index
best_prec1 = 0
epoch_index = 0
model = CovaNet.define_CovarianceNet(which_model=opt.basemodel, num_classes=opt.way_num, norm='batch',
                                     init_type='normal', use_gpu=opt.cuda)


# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().cuda()
optimizer = optim.Adam(model.parameters(), lr=opt.lr, betas=(opt.beta1, 0.9))


# optionally resume from a checkpoint
# NOTE(review): epoch_index is restored here, but the training loop in this
# file iterates ``range(opt.epochs)`` from 0, so a resumed run restarts the
# epoch counter and the learning-rate schedule - confirm this is intended.
if opt.resume:
    if os.path.isfile(opt.resume):
        print("=> loading checkpoint '{}'".format(opt.resume))
        checkpoint = torch.load(opt.resume)
        epoch_index = checkpoint['epoch_index']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(opt.resume, checkpoint['epoch_index']))
    else:
        print("=> no checkpoint found at '{}'".format(opt.resume))

# Multi-GPU wrapping happens after the state dict has been loaded.
# NOTE(review): checkpoints written from a DataParallel model presumably
# carry 'module.'-prefixed keys and would not load above - verify.
if opt.ngpu > 1:
    model = nn.DataParallel(model, range(opt.ngpu))

print(model)
print(model, file=F_txt)  # print the architecture of the network
324 |
325 |
326 |
327 |
# ============================================ Training phase ========================================
# Main driver: for every epoch adjust the learning rate, rebuild the episodic
# datasets and loaders, train, validate, checkpoint, and evaluate on the test
# split. Everything printed is mirrored to F_txt.
print('\n............Start training............\n')
start_time = time.time()


for epoch_item in range(opt.epochs):
    print('===================================== Epoch %d =====================================' %epoch_item)
    print('===================================== Epoch %d =====================================' %epoch_item, file=F_txt)
    adjust_learning_rate(optimizer, epoch_item)


    # ======================================= Folder of Datasets =======================================
    # image transform & normalization: resize to imageSize x imageSize, then
    # (x - 0.5) / 0.5 on each of the three channels
    ImgTransform = transforms.Compose([
        transforms.Resize((opt.imageSize, opt.imageSize)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    # The three datasets are re-created on every pass through this loop.
    # NOTE(review): presumably Imagefolder_csv draws a fresh set of episodes
    # at construction time, which is why this is not hoisted - confirm.
    trainset = Imagefolder_csv(
        data_dir=opt.dataset_dir, mode=opt.mode, image_size=opt.imageSize, transform=ImgTransform,
        episode_num=opt.episode_train_num, way_num=opt.way_num, shot_num=opt.shot_num, query_num=opt.query_num
    )
    valset = Imagefolder_csv(
        data_dir=opt.dataset_dir, mode='val', image_size=opt.imageSize, transform=ImgTransform,
        episode_num=opt.episode_val_num, way_num=opt.way_num, shot_num=opt.shot_num, query_num=opt.query_num
    )
    testset = Imagefolder_csv(
        data_dir=opt.dataset_dir, mode='test', image_size=opt.imageSize, transform=ImgTransform,
        episode_num=opt.episode_test_num, way_num=opt.way_num, shot_num=opt.shot_num, query_num=opt.query_num
    )

    print('Trainset: %d' %len(trainset))
    print('Valset: %d' %len(valset))
    print('Testset: %d' %len(testset))
    print('Trainset: %d' %len(trainset), file=F_txt)
    print('Valset: %d' %len(valset), file=F_txt)
    print('Testset: %d' %len(testset), file=F_txt)



    # ========================================== Load Datasets =========================================
    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=opt.episodeSize, shuffle=True,
        num_workers=int(opt.workers), drop_last=True, pin_memory=True
        )
    val_loader = torch.utils.data.DataLoader(
        valset, batch_size=opt.testepisodeSize, shuffle=True,
        num_workers=int(opt.workers), drop_last=True, pin_memory=True
        )
    test_loader = torch.utils.data.DataLoader(
        testset, batch_size=opt.testepisodeSize, shuffle=True,
        num_workers=int(opt.workers), drop_last=True, pin_memory=True
        )


    # ============================================ Training ===========================================
    # Fix the parameters of Batch Normalization after 10000 episodes (1 epoch):
    # only epoch 0 trains in train mode, later epochs train in eval mode.
    if epoch_item < 1:
        model.train()
    else:
        model.eval()

    # Train for 10000 episodes in each epoch
    train(train_loader, model, criterion, optimizer, epoch_item, F_txt)


    # =========================================== Evaluation ==========================================
    print('============ Validation on the val set ============')
    print('============ validation on the val set ============', file=F_txt)
    prec1, _ = validate(val_loader, model, criterion, epoch_item, best_prec1, F_txt)


    # record the best prec@1 and save checkpoint
    is_best = prec1 > best_prec1
    best_prec1 = max(prec1, best_prec1)

    # save the checkpoint whenever validation accuracy improves
    if is_best:
        save_checkpoint(
            {
                'epoch_index': epoch_item,
                'arch': opt.basemodel,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer' : optimizer.state_dict(),
            }, os.path.join(opt.outf, 'model_best.pth.tar'))


    # additionally keep a snapshot every 10th epoch (0, 10, 20, ...)
    if epoch_item % 10 == 0:
        filename = os.path.join(opt.outf, 'epoch_%d.pth.tar' %epoch_item)
        save_checkpoint(
            {
                'epoch_index': epoch_item,
                'arch': opt.basemodel,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer' : optimizer.state_dict(),
            }, filename)


    # Testing phase: the test split is evaluated after every epoch; note that
    # this rebinds prec1 to the test accuracy.
    print('============ Testing on the test set ============')
    print('============ Testing on the test set ============', file=F_txt)
    prec1, _ = validate(test_loader, model, criterion, epoch_item, best_prec1, F_txt)


# The log file is closed only when the loop finishes normally.
F_txt.close()
print('............Training is end............')

# ============================================ Training End ==========================================
439 |
440 |
441 |
442 |
443 |
--------------------------------------------------------------------------------
/CovaMNet_Train_5way5shot.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | """
5 | Author: Wenbin Li (liwenbin.nju@gmail.com)
6 | Date: Jan. 14, 2019
7 | Version: V0
8 |
9 | Citation:
10 | @inproceedings{li2019CovaMNet,
11 | title={Distribution Consistency based Covariance Metric Networks for Few-shot Learning},
12 | author={Li, Wenbin and Xu, Jinglin and Huo, Jing and Wang, Lei and Gao Yang and Luo, Jiebo},
13 | booktitle={AAAI},
14 | year={2019}
15 | }
16 | """
17 |
18 |
19 |
20 | from __future__ import print_function
21 | import argparse
22 | import os
23 | import random
24 | import shutil
25 | import numpy as np
26 | import torch
27 | import torch.nn as nn
28 | import torch.nn.parallel
29 | import torch.backends.cudnn as cudnn
30 | import torch.optim as optim
31 | import torch.utils.data
32 | import torchvision.datasets as dset
33 | import torchvision.transforms as transforms
34 | import torchvision.utils as vutils
35 | from torch.autograd import grad
36 | import time
37 | from torch import autograd
38 | from PIL import ImageFile
39 | import pdb
40 | import sys
41 | sys.dont_write_bytecode = True
42 |
43 |
44 |
45 | # ============================ Data & Networks =====================================
46 | from dataset.datasets_csv import Imagefolder_csv
47 | import models.network as CovaNet
48 | # ==================================================================================
49 |
50 |
# Runtime configuration: PIL behaviour, GPU selection and command-line options.
# PIL is told to load image files that are truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Enumerate GPUs in PCI bus order and expose only device 0 to this process.
os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES']='0'


parser = argparse.ArgumentParser()
parser.add_argument('--dataset_dir', default='', help='/Datasets/miniImageNet/')
parser.add_argument('--data_name', default='miniImageNet', help='miniImageNet|StanfordDog|StanfordCar|CubBird')
parser.add_argument('--mode', default='train', help='train|val|test')
parser.add_argument('--outf', default='./results/CovaMNet')
parser.add_argument('--resume', default='', type=str, help='path to the lastest checkpoint (default: none)')
parser.add_argument('--basemodel', default='Conv64', help='Conv64')
parser.add_argument('--workers', type=int, default=8)
# Few-shot parameters #
parser.add_argument('--imageSize', type=int, default=84)
parser.add_argument('--episodeSize', type=int, default=1, help='the mini-batch size of training')
parser.add_argument('--testepisodeSize', type=int, default=1, help='one episode is taken as a mini-batch')
parser.add_argument('--epochs', type=int, default=40, help='the total number of training epoch')
parser.add_argument('--episode_train_num', type=int, default=10000, help='the total number of training episodes')
parser.add_argument('--episode_val_num', type=int, default=10000, help='the total number of evaluation episodes')
parser.add_argument('--episode_test_num', type=int, default=1000, help='the total number of testing episodes')
parser.add_argument('--way_num', type=int, default=5, help='the number of way/class')
parser.add_argument('--shot_num', type=int, default=5, help='the number of shot')
parser.add_argument('--query_num', type=int, default=10, help='the number of queries')
# The help text now states the real default (it previously said 0.005).
parser.add_argument('--lr', type=float, default=0.001, help='learning rate, default=0.001')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--cuda', action='store_true', default=True, help='enables cuda')
parser.add_argument('--ngpu', type=int, default=1, help='the number of gpus')
parser.add_argument('--nc', type=int, default=3, help='input image channels')
parser.add_argument('--clamp_lower', type=float, default=-0.01)
parser.add_argument('--clamp_upper', type=float, default=0.01)
parser.add_argument('--print_freq', '-p', default=100, type=int, metavar='N', help='print frequency (default: 100)')
opt = parser.parse_args()
# CUDA is forced on regardless of the --cuda flag.
opt.cuda = True
cudnn.benchmark = True
86 |
87 |
88 |
89 | # ======================================= Define functions =============================================
def adjust_learning_rate(optimizer, epoch_num):
    """Set every param group's LR to ``opt.lr * 0.05 ** (epoch_num // 10)``.

    That is, the initial learning rate (read from the module-level ``opt``)
    is multiplied by 0.05 once for each completed block of 10 epochs.
    """
    completed_blocks = epoch_num // 10
    new_lr = opt.lr * (0.05 ** completed_blocks)
    for group in optimizer.param_groups:
        group['lr'] = new_lr
95 |
96 |
def train(train_loader, model, criterion, optimizer, epoch_index, F_txt):
    """Run one optimisation step per episode yielded by ``train_loader``.

    Each item of ``train_loader`` is unpacked as
    ``(query_images, query_targets, support_images, support_targets)``.
    The query images/targets are concatenated along dim 0, every entry of
    ``support_images`` is concatenated along dim 0 as well, and all of them
    are moved to the GPU before the forward pass.

    Running loss / Prec@1 / timing statistics are printed to stdout and to
    ``F_txt`` every ``opt.print_freq`` episodes (``opt`` is the module-level
    argument namespace). Nothing is returned. The train/eval mode of
    ``model`` is not touched here; the caller decides it.
    """
    batch_time, data_time = AverageMeter(), AverageMeter()
    losses, top1 = AverageMeter(), AverageMeter()

    # One template shared by the console and the log-file output.
    log_template = ('Eposide-({0}): [{1}/{2}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    'Loss {loss.val:.3f} ({loss.avg:.3f})\t'
                    'Prec@1 {top1.val:.3f} ({top1.avg:.3f})')

    tick = time.time()
    for episode_index, episode in enumerate(train_loader):
        query_images, query_targets, support_images, support_targets = episode

        # Time spent waiting for the loader
        data_time.update(time.time() - tick)

        # Queries: one tensor on the GPU
        query_images = torch.cat(query_images, 0)
        queries = query_images.cuda()

        # Supports: one GPU tensor per entry of support_images
        supports = [torch.cat(class_support, 0).cuda() for class_support in support_images]

        # Query labels
        target = torch.cat(query_targets, 0).cuda()

        # Forward pass and loss
        output = model(queries, supports)
        loss = criterion(output, target)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Book-keeping, weighted by the number of query images
        prec1, _ = accuracy(output, target, topk=(1, 3))
        query_count = query_images.size(0)
        losses.update(loss.item(), query_count)
        top1.update(prec1[0], query_count)

        # Wall-clock time of the whole episode
        batch_time.update(time.time() - tick)
        tick = time.time()

        # Periodic progress report (episode 0 is skipped)
        if episode_index % opt.print_freq == 0 and episode_index != 0:
            message = log_template.format(
                epoch_index, episode_index, len(train_loader),
                batch_time=batch_time, data_time=data_time, loss=losses, top1=top1)
            print(message)
            print(message, file=F_txt)
162 |
163 |
164 |
def validate(val_loader, model, criterion, epoch_index, best_prec1, F_txt):
    """Evaluate ``model`` on every episode yielded by ``val_loader``.

    The model is switched to eval mode and the forward pass runs under
    ``torch.no_grad()``: no parameter is updated here, so building an
    autograd graph for each episode would only cost memory and time.

    Args:
        val_loader: iterable of
            ``(query_images, query_targets, support_images, support_targets)``
            episodes; images/targets are concatenated along dim 0.
        model: network called as ``model(queries, supports)``.
        criterion: loss called as ``criterion(output, target)``.
        epoch_index: only used in the progress messages.
        best_prec1: best Prec@1 so far, only used in the summary message.
        F_txt: open text file that mirrors everything printed to stdout.

    Returns:
        ``(top1.avg, accuracies)`` - the query-weighted mean Prec@1 and the
        list of per-episode Prec@1 values returned by ``accuracy``.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()
    accuracies = []

    # One template shared by the console and the log-file output.
    log_template = ('Test-({0}): [{1}/{2}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Loss {loss.val:.3f} ({loss.avg:.3f})\t'
                    'Prec@1 {top1.val:.3f} ({top1.avg:.3f})')

    end = time.time()
    for episode_index, (query_images, query_targets, support_images, support_targets) in enumerate(val_loader):

        # Convert query and support images
        query_images = torch.cat(query_images, 0)
        input_var1 = query_images.cuda()
        input_var2 = [torch.cat(class_support, 0).cuda() for class_support in support_images]

        # Deal with the targets
        target = torch.cat(query_targets, 0).cuda()

        # Calculate the output without recording gradients
        with torch.no_grad():
            output = model(input_var1, input_var2)
            loss = criterion(output, target)

        # measure accuracy and record loss
        prec1, _ = accuracy(output, target, topk=(1, 3))
        losses.update(loss.item(), query_images.size(0))
        top1.update(prec1[0], query_images.size(0))
        accuracies.append(prec1)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # ============== print the intermediate results ==============#
        if episode_index % opt.print_freq == 0 and episode_index != 0:
            message = log_template.format(
                epoch_index, episode_index, len(val_loader),
                batch_time=batch_time, loss=losses, top1=top1)
            print(message)
            print(message, file=F_txt)

    summary = ' * Prec@1 {top1.avg:.3f} Best_prec1 {best_prec1:.3f}'.format(top1=top1, best_prec1=best_prec1)
    print(summary)
    print(summary, file=F_txt)

    return top1.avg, accuracies
233 |
234 |
def save_checkpoint(state, filename='checkpoint.pth.tar'):
    """Serialise ``state`` to ``filename`` with ``torch.save``."""
    target_path = filename
    torch.save(state, target_path)
237 |
238 |
239 |
class AverageMeter(object):
    """Keeps the most recent value of a metric and its running weighted mean.

    Attributes:
        val: value passed to the latest ``update`` call.
        sum: accumulated ``val * n`` over all updates.
        count: accumulated ``n`` over all updates.
        avg: ``sum / count`` after the latest update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean."""
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count
256 |
257 |
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: 2-D score tensor; the top-k classes are taken along dim 1.
        target: 1-D tensor of ground-truth class indices, one per row of
            ``output``.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one 1-element tensor per k, holding the percentage
        (0-100) of rows whose target is among the k highest scores.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # Indices of the maxk best classes, transposed to (maxk, batch).
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # ``correct`` comes from a transposed tensor and may be
            # non-contiguous; ``view(-1)`` raises on such a slice for k > 1,
            # whereas ``reshape(-1)`` copies only when it has to.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
273 |
274 |
# ======================================== Settings of path ============================================
# Output directory name: <outf>_<data_name>_<basemodel>_<way>Way_<shot>Shot
opt.outf = '_'.join([
    opt.outf,
    opt.data_name,
    str(opt.basemodel),
    str(opt.way_num) + 'Way',
    str(opt.shot_num) + 'Shot',
])

# Create the output directory on first use
if not os.path.exists(opt.outf):
    os.makedirs(opt.outf)

# Remind the user when a GPU is present but CUDA was not requested
if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

# Log file (append mode) that receives the options and all later results
txt_save_path = os.path.join(opt.outf, 'opt_resutls.txt')
F_txt = open(txt_save_path, 'a+')
print(opt)
print(opt, file=F_txt)
290 |
291 |
292 |
# ========================================== Model Config ===============================================
# Builds the network, the loss and the optimizer, then optionally restores
# them from the checkpoint given by --resume.
ngpu = int(opt.ngpu)
# A ``global`` statement at module scope has no effect; these names are
# module-level variables either way.
global best_prec1, epoch_index
best_prec1 = 0
epoch_index = 0
model = CovaNet.define_CovarianceNet(which_model=opt.basemodel, num_classes=opt.way_num, norm='batch',
                                     init_type='normal', use_gpu=opt.cuda)


# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().cuda()
optimizer = optim.Adam(model.parameters(), lr=opt.lr, betas=(opt.beta1, 0.9))


# optionally resume from a checkpoint
# NOTE(review): epoch_index is restored here, but the training loop in this
# file iterates ``range(opt.epochs)`` from 0, so a resumed run restarts the
# epoch counter and the learning-rate schedule - confirm this is intended.
if opt.resume:
    if os.path.isfile(opt.resume):
        print("=> loading checkpoint '{}'".format(opt.resume))
        checkpoint = torch.load(opt.resume)
        epoch_index = checkpoint['epoch_index']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(opt.resume, checkpoint['epoch_index']))
    else:
        print("=> no checkpoint found at '{}'".format(opt.resume))

# Multi-GPU wrapping happens after the state dict has been loaded.
# NOTE(review): checkpoints written from a DataParallel model presumably
# carry 'module.'-prefixed keys and would not load above - verify.
if opt.ngpu > 1:
    model = nn.DataParallel(model, range(opt.ngpu))

print(model)
print(model, file=F_txt)  # print the architecture of the network
325 |
326 |
327 |
328 |
# ============================================ Training phase ========================================
# Main driver: for every epoch adjust the learning rate, rebuild the episodic
# datasets and loaders, train, validate, checkpoint, and evaluate on the test
# split. Everything printed is mirrored to F_txt.
print('\n............Start training............\n')
start_time = time.time()


for epoch_item in range(opt.epochs):
    print('===================================== Epoch %d =====================================' %epoch_item)
    print('===================================== Epoch %d =====================================' %epoch_item, file=F_txt)
    adjust_learning_rate(optimizer, epoch_item)


    # ======================================= Folder of Datasets =======================================
    # image transform & normalization: resize to imageSize x imageSize, then
    # (x - 0.5) / 0.5 on each of the three channels
    ImgTransform = transforms.Compose([
        transforms.Resize((opt.imageSize, opt.imageSize)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    # The three datasets are re-created on every pass through this loop.
    # NOTE(review): presumably Imagefolder_csv draws a fresh set of episodes
    # at construction time, which is why this is not hoisted - confirm.
    trainset = Imagefolder_csv(
        data_dir=opt.dataset_dir, mode=opt.mode, image_size=opt.imageSize, transform=ImgTransform,
        episode_num=opt.episode_train_num, way_num=opt.way_num, shot_num=opt.shot_num, query_num=opt.query_num
    )
    valset = Imagefolder_csv(
        data_dir=opt.dataset_dir, mode='val', image_size=opt.imageSize, transform=ImgTransform,
        episode_num=opt.episode_val_num, way_num=opt.way_num, shot_num=opt.shot_num, query_num=opt.query_num
    )
    testset = Imagefolder_csv(
        data_dir=opt.dataset_dir, mode='test', image_size=opt.imageSize, transform=ImgTransform,
        episode_num=opt.episode_test_num, way_num=opt.way_num, shot_num=opt.shot_num, query_num=opt.query_num
    )

    print('Trainset: %d' %len(trainset))
    print('Valset: %d' %len(valset))
    print('Testset: %d' %len(testset))
    print('Trainset: %d' %len(trainset), file=F_txt)
    print('Valset: %d' %len(valset), file=F_txt)
    print('Testset: %d' %len(testset), file=F_txt)



    # ========================================== Load Datasets =========================================
    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=opt.episodeSize, shuffle=True,
        num_workers=int(opt.workers), drop_last=True, pin_memory=True
        )
    val_loader = torch.utils.data.DataLoader(
        valset, batch_size=opt.testepisodeSize, shuffle=True,
        num_workers=int(opt.workers), drop_last=True, pin_memory=True
        )
    test_loader = torch.utils.data.DataLoader(
        testset, batch_size=opt.testepisodeSize, shuffle=True,
        num_workers=int(opt.workers), drop_last=True, pin_memory=True
        )


    # ============================================ Training ===========================================
    # Fix the parameters of Batch Normalization after 10000 episodes (1 epoch):
    # only epoch 0 trains in train mode, later epochs train in eval mode.
    if epoch_item < 1:
        model.train()
    else:
        model.eval()

    # Train for 10000 episodes in each epoch
    train(train_loader, model, criterion, optimizer, epoch_item, F_txt)


    # =========================================== Evaluation ==========================================
    print('============ Validation on the val set ============')
    print('============ validation on the val set ============', file=F_txt)
    prec1, _ = validate(val_loader, model, criterion, epoch_item, best_prec1, F_txt)


    # record the best prec@1 and save checkpoint
    is_best = prec1 > best_prec1
    best_prec1 = max(prec1, best_prec1)

    # save the checkpoint whenever validation accuracy improves
    if is_best:
        save_checkpoint(
            {
                'epoch_index': epoch_item,
                'arch': opt.basemodel,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer' : optimizer.state_dict(),
            }, os.path.join(opt.outf, 'model_best.pth.tar'))


    # additionally keep a snapshot every 10th epoch (0, 10, 20, ...)
    if epoch_item % 10 == 0:
        filename = os.path.join(opt.outf, 'epoch_%d.pth.tar' %epoch_item)
        save_checkpoint(
            {
                'epoch_index': epoch_item,
                'arch': opt.basemodel,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer' : optimizer.state_dict(),
            }, filename)


    # Testing phase: the test split is evaluated after every epoch; note that
    # this rebinds prec1 to the test accuracy.
    print('============ Testing on the test set ============')
    print('============ Testing on the test set ============', file=F_txt)
    prec1, _ = validate(test_loader, model, criterion, epoch_item, best_prec1, F_txt)


# The log file is closed only when the loop finishes normally.
F_txt.close()
print('............Training is end............')

# ============================================ Training End ==========================================
440 |
441 |
442 |
443 |
444 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2019, Wenbin Li
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | * Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | * Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 |
14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 |
25 |
26 | --------------------------- LICENSE FOR CovaMNet --------------------------------
27 | BSD License
28 |
29 | For CovaMNet software
30 | Copyright (c) 2019, Wenbin Li
31 | All rights reserved.
32 |
33 | Redistribution and use in source and binary forms, with or without
34 | modification, are permitted provided that the following conditions are met:
35 |
36 | * Redistributions of source code must retain the above copyright notice, this
37 | list of conditions and the following disclaimer.
38 |
39 | * Redistributions in binary form must reproduce the above copyright notice,
40 | this list of conditions and the following disclaimer in the documentation
41 | and/or other materials provided with the distribution.
42 |
43 |
44 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CovaMNet in PyTorch
2 |
3 | We provide a PyTorch implementation of CovaMNet for few-shot learning. The code was written by [Wenbin Li](https://github.com/WenbinLee) [Homepage].
4 |
5 | If you use this code for your research, please cite:
6 |
7 | [Distribution Consistency based Covariance Metric Networks for Few-shot Learning](https://cs.nju.edu.cn/rl/people/liwb/AAAI19.pdf).
8 | [Wenbin Li](https://cs.nju.edu.cn/liwenbin/), Jinglin Xu, Jing Huo, Lei Wang, Yang Gao and Jiebo Luo. In AAAI 2019.
9 |
10 |
11 |
12 | ## Prerequisites
13 | - Linux
14 | - Python 3
15 | - Pytorch 0.4
16 | - GPU + CUDA CuDNN
17 |
18 | ## Getting Started
19 | ### Installation
20 |
21 | - Clone this repo:
22 | ```bash
23 | git clone https://github.com/WenbinLee/CovaMNet
24 | cd CovaMNet
25 | ```
26 |
27 | - Install [PyTorch](http://pytorch.org) 0.4 and other dependencies (e.g., torchvision).
28 |
29 | ### Datasets
30 | - [miniImageNet](https://drive.google.com/file/d/1fUBrpv8iutYwdL4xE1rX_R9ef6tyncX9/view).
31 | - [StanfordDog](http://vision.stanford.edu/aditya86/ImageNetDogs/).
32 | - [StanfordCar](https://ai.stanford.edu/~jkrause/cars/car_dataset.html).
33 | - [CUB-200](http://www.vision.caltech.edu/visipedia/CUB-200.html).
34 | Thanks to [Victor Garcia](https://github.com/vgsatorras/few-shot-gnn) for providing the miniImageNet dataset. In our paper, we used only the original CUB-200 dataset; a newer revision of this dataset with more images is also available, see [Caltech-UCSD Birds-200-2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html). Note: if you use these datasets, please cite the corresponding papers.
35 |
36 |
37 | ### miniImageNet Few-shot Classification
38 | - Train a 5-way 1-shot model:
39 | ```bash
40 | python CovaMNet_Train_5way1shot.py --dataset_dir ./datasets/miniImageNet --data_name miniImageNet
41 | ```
42 | - Test the model (specify the dataset_dir and data_name first):
43 | ```bash
44 | python CovaMNet_Test_5way1shot.py --resume ./results/CovaMNet_miniImageNet_Conv64_5_Way_1_Shot/model_best.pth.tar
45 | ```
46 | - The results on the miniImageNet dataset:
47 |
48 | ![Results on the miniImageNet dataset](imgs/results_miniImageNet.bmp)
49 |
50 | ### Fine-grained Few-shot Classification
51 | - Data preprocessing (e.g., StanfordDog)
52 | - Specify the path of the dataset and the saving path.
53 | - Run the preprocessing script.
54 | ```bash
55 | #!./dataset/StanfordDog/StanfordDog_prepare_csv.py
56 | python ./dataset/StanfordDog/StanfordDog_prepare_csv.py
57 | ```
58 | - Train a 5-way 1-shot model:
59 | ```bash
60 | python CovaMNet_Train_5way1shot.py --dataset_dir ./datasets/StanfordDog --data_name StanfordDog
61 | ```
62 | - Test the model (specify the dataset_dir and data_name first):
63 | ```bash
64 | python CovaMNet_Test_5way1shot.py --resume ./results/CovaMNet_StanfordDog_Conv64_5_Way_1_Shot/model_best.pth.tar
65 | ```
66 | - The results on the fine-grained datasets:
67 |
68 | ![Results on the fine-grained datasets](imgs/result_finegrained.bmp)
69 |
70 |
71 | ## Citation
72 | If you use this code for your research, please cite our paper.
73 | ```
74 | @inproceedings{li2019CovaMNet,
75 | title={Distribution Consistency based Covariance Metric Networks for Few-shot Learning},
76 | author={Li, Wenbin and Xu, Jinglin and Huo, Jing and Wang, Lei and Gao, Yang and Luo, Jiebo},
77 | booktitle={AAAI},
78 | year={2019}
79 | }
80 |
81 | ```
82 |
83 |
--------------------------------------------------------------------------------
/dataset/CubBird/CubBird_prepare_csv.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Wenbin Li
3 | ## Date: Dec. 16 2018
4 | ##
5 | ## Divide data into train/val/test in a csv version
6 | ## Output: train.csv, val.csv, test.csv
7 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
8 |
9 | import os
10 | import csv
11 | import numpy as np
12 | import random
13 | from PIL import Image
14 | import pdb
15 |
16 |
data_dir = '/FewShot/Datasets/CUB_birds'                 # the path of the downloaded dataset
save_dir = '/FewShot/Datasets/CUB_birds/For_FewShot'     # the saving path of the divided dataset

# All converted images are written into one flat folder next to the csv files.
os.makedirs(os.path.join(save_dir, 'images'), exist_ok=True)

images_dir = os.path.join(data_dir, 'images')
train_class_num = 130
val_class_num = 20
test_class_num = 50   # not read below: the test split is every class left after train/val


def _export_split(split_name, class_names):
    """Convert one split's images to RGB and write its csv index.

    Every file found in ``images_dir/<class_name>`` is re-saved under its own
    file name in ``save_dir/images`` and recorded as a ``[filename, label]``
    row in ``save_dir/<split_name>.csv``.

    Args:
        split_name: 'train', 'val' or 'test'; used for the csv file name and,
            capitalised, for the progress message.
        class_names: class folder names that belong to this split.

    Returns:
        The list of ``[filename, class_name]`` rows written to the csv file.
    """
    rows = []
    for class_name in class_names:
        class_dir = os.path.join(images_dir, class_name)
        # List the folder once so csv rows and saved files cannot get out of step.
        file_names = os.listdir(class_dir)
        print('%s----%s' % (split_name.capitalize(), class_name))

        # read images and store these images
        for file_name in file_names:
            rows.append([file_name, class_name])
            # The context manager closes the source file handle after conversion.
            with Image.open(os.path.join(class_dir, file_name)) as img:
                img.convert('RGB').save(os.path.join(save_dir, 'images', file_name), quality=100)

    # newline='' stops the csv module from emitting blank rows on Windows.
    with open(os.path.join(save_dir, split_name + '.csv'), 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['filename', 'label'])
        writer.writerows(rows)

    return rows


# get all the bird classes (one sub-folder of images_dir per class)
classes_list = [class_name for class_name in os.listdir(images_dir) if os.path.isdir(os.path.join(images_dir, class_name))]


# divide the train/val/test set
# NOTE: os.listdir returns entries in arbitrary order, so the fixed seed only
# reproduces the same split when the directory listing order is the same.
random.seed(200)
train_list = random.sample(classes_list, train_class_num)
train_set = set(train_list)
remain_list = [rem for rem in classes_list if rem not in train_set]
val_list = random.sample(remain_list, val_class_num)
val_set = set(val_list)
test_list = [rem for rem in remain_list if rem not in val_set]


# save data into csv file----- Train / Val / Test
train_data = _export_split('train', train_list)
val_data = _export_split('val', val_list)
test_data = _export_split('test', test_list)
110 |
--------------------------------------------------------------------------------
/dataset/StanfordCar/StanforCar_prepare_csv.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Wenbin Li
3 | ## Date: Dec. 16 2018
4 | ##
5 | ## Divide data into train/val/test in a csv version
6 | ## Output: train.csv, val.csv, test.csv
7 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
8 |
9 | import os
10 | import csv
11 | import numpy as np
12 | import random
13 | from PIL import Image
14 | import pdb
15 |
16 |
data_dir = '/FewShot/Datasets/Stanford_cars'                 # the path of the downloaded dataset
save_dir = '/FewShot/Datasets/Stanford_cars/For_FewShot'     # the saving path of the divided dataset

# All converted images are written into one flat folder next to the csv files.
os.makedirs(os.path.join(save_dir, 'images'), exist_ok=True)

images_dir = os.path.join(data_dir, 'images')
train_class_num = 130
val_class_num = 17
test_class_num = 49   # not read below: the test split is every class left after train/val


def _export_split(split_name, class_names):
    """Convert one split's images to RGB and write its csv index.

    Every file found in ``images_dir/<class_name>`` is re-saved under its own
    file name in ``save_dir/images`` and recorded as a ``[filename, label]``
    row in ``save_dir/<split_name>.csv``.

    Args:
        split_name: 'train', 'val' or 'test'; used for the csv file name and,
            capitalised, for the progress message.
        class_names: class folder names that belong to this split.

    Returns:
        The list of ``[filename, class_name]`` rows written to the csv file.
    """
    rows = []
    for class_name in class_names:
        class_dir = os.path.join(images_dir, class_name)
        # List the folder once so csv rows and saved files cannot get out of step.
        file_names = os.listdir(class_dir)
        print('%s----%s' % (split_name.capitalize(), class_name))

        # read images and store these images
        for file_name in file_names:
            rows.append([file_name, class_name])
            # The context manager closes the source file handle after conversion.
            with Image.open(os.path.join(class_dir, file_name)) as img:
                img.convert('RGB').save(os.path.join(save_dir, 'images', file_name), quality=100)

    # newline='' stops the csv module from emitting blank rows on Windows.
    with open(os.path.join(save_dir, split_name + '.csv'), 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['filename', 'label'])
        writer.writerows(rows)

    return rows


# get all the car classes (one sub-folder of images_dir per class)
classes_list = [class_name for class_name in os.listdir(images_dir) if os.path.isdir(os.path.join(images_dir, class_name))]


# divide the train/val/test set
# NOTE: os.listdir returns entries in arbitrary order, so the fixed seed only
# reproduces the same split when the directory listing order is the same.
random.seed(196)
train_list = random.sample(classes_list, train_class_num)
train_set = set(train_list)
remain_list = [rem for rem in classes_list if rem not in train_set]
val_list = random.sample(remain_list, val_class_num)
val_set = set(val_list)
test_list = [rem for rem in remain_list if rem not in val_set]


# save data into csv file----- Train / Val / Test
train_data = _export_split('train', train_list)
val_data = _export_split('val', val_list)
test_data = _export_split('test', test_list)
110 |
--------------------------------------------------------------------------------
/dataset/StanfordDog/StanfordDog_prepare_csv.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Wenbin Li
3 | ## Date: Dec. 16 2018
4 | ##
5 | ## Divide data into train/val/test in a csv version
6 | ## Output: train.csv, val.csv, test.csv
7 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
8 |
9 | import os
10 | import csv
11 | import numpy as np
12 | import random
13 | from PIL import Image
14 | import pdb
15 |
16 |
data_dir = '/FewShot/Datasets/Stanford_dogs'                 # the path of the downloaded dataset
save_dir = '/FewShot/Datasets/Stanford_dogs/For_FewShot'     # the saving path of the divided dataset

# All converted images are written into one flat folder next to the csv files.
os.makedirs(os.path.join(save_dir, 'images'), exist_ok=True)

images_dir = os.path.join(data_dir, 'Images')
train_class_num = 70
val_class_num = 20
test_class_num = 30   # not read below: the test split is every class left after train/val


def _export_split(split_name, class_names):
    """Convert one split's images to RGB and write its csv index.

    Every file found in ``images_dir/<class_name>`` is re-saved under its own
    file name in ``save_dir/images`` and recorded as a ``[filename, label]``
    row in ``save_dir/<split_name>.csv``.

    Args:
        split_name: 'train', 'val' or 'test'; used for the csv file name and,
            capitalised, for the progress message.
        class_names: class folder names that belong to this split.

    Returns:
        The list of ``[filename, class_name]`` rows written to the csv file.
    """
    rows = []
    for class_name in class_names:
        class_dir = os.path.join(images_dir, class_name)
        # List the folder once so csv rows and saved files cannot get out of step.
        file_names = os.listdir(class_dir)
        print('%s----%s' % (split_name.capitalize(), class_name))

        # read images and store these images
        for file_name in file_names:
            rows.append([file_name, class_name])
            # The context manager closes the source file handle after conversion.
            with Image.open(os.path.join(class_dir, file_name)) as img:
                img.convert('RGB').save(os.path.join(save_dir, 'images', file_name), quality=100)

    # newline='' stops the csv module from emitting blank rows on Windows.
    with open(os.path.join(save_dir, split_name + '.csv'), 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['filename', 'label'])
        writer.writerows(rows)

    return rows


# get all the dog classes (one sub-folder of images_dir per class)
classes_list = [class_name for class_name in os.listdir(images_dir) if os.path.isdir(os.path.join(images_dir, class_name))]


# divide the train/val/test set
# NOTE: os.listdir returns entries in arbitrary order, so the fixed seed only
# reproduces the same split when the directory listing order is the same.
random.seed(120)
train_list = random.sample(classes_list, train_class_num)
train_set = set(train_list)
remain_list = [rem for rem in classes_list if rem not in train_set]
val_list = random.sample(remain_list, val_class_num)
val_set = set(val_list)
test_list = [rem for rem in remain_list if rem not in val_set]


# save data into csv file----- Train / Val / Test
train_data = _export_split('train', train_list)
val_data = _export_split('val', val_list)
test_data = _export_split('test', test_list)
110 |
--------------------------------------------------------------------------------
/dataset/datasets_csv.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as path
3 | import json
4 | import torch
5 | import torch.utils.data as data
6 | import numpy as np
7 | import random
8 | from PIL import Image
9 | import pdb
10 | import csv
11 | import sys
12 | sys.dont_write_bytecode = True
13 |
14 |
15 |
def pil_loader(path):
    """Read the image file at ``path`` with Pillow and return it converted to RGB."""
    # Pillow is handed an explicit file object so the handle is closed on exit,
    # which avoids a ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as stream, Image.open(stream) as picture:
        rgb_image = picture.convert('RGB')
        return rgb_image
21 |
22 |
def accimage_loader(path):
    """Decode ``path`` with accimage; on IOError fall back to ``pil_loader``."""
    import accimage
    try:
        decoded = accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        decoded = pil_loader(path)
    return decoded
30 |
31 |
def gray_loader(path):
    """Read the image file at ``path`` and return it converted to Pillow mode 'P'."""
    # NOTE(review): 'P' is Pillow's 8-bit palette mode, while single-channel
    # grayscale is mode 'L'. Behaviour is left unchanged -- confirm which mode
    # the callers of this loader actually expect.
    with open(path, 'rb') as stream, Image.open(stream) as picture:
        converted = picture.convert('P')
        return converted
36 |
37 |
def default_loader(path):
    """Load ``path`` with the loader that matches torchvision's image backend.

    ``accimage_loader`` is used when the backend is 'accimage'; every other
    backend goes through ``pil_loader``.
    """
    from torchvision import get_image_backend
    use_accimage = get_image_backend() == 'accimage'
    loader = accimage_loader if use_accimage else pil_loader
    return loader(path)
44 |
45 |
def find_classes(dir):
    """Return the sorted sub-directory names of ``dir`` and a name -> index map.

    Returns:
        (classes, class_to_idx): ``classes`` is the sorted list of directory
        names directly inside ``dir`` (plain files are ignored);
        ``class_to_idx`` maps each name to its position in that list.
    """
    classes = sorted(
        entry for entry in os.listdir(dir)
        if os.path.isdir(os.path.join(dir, entry))
    )
    class_to_idx = {name: index for index, name in enumerate(classes)}

    return classes, class_to_idx
52 |
53 |
54 | class Imagefolder_csv(object):
55 | """
56 | Imagefolder for miniImageNet--ravi, StanfordDog, StanfordCar and CubBird datasets.
57 | Images are stored in the folder of "images";
58 | Indexes are stored in the CSV files.
59 | """
60 |
61 | def __init__(self, data_dir="", mode="train", image_size=84, data_name="miniImageNet",
62 | transform=None, loader=default_loader, gray_loader=gray_loader,
63 | episode_num=1000, way_num=5, shot_num=5, query_num=5):
64 |
65 | super(Imagefolder_csv, self).__init__()
66 |
67 |
68 | # set the paths of the csv files
69 | train_csv = os.path.join(data_dir, 'train.csv')
70 | val_csv = os.path.join(data_dir, 'val.csv')
71 | test_csv = os.path.join(data_dir, 'test.csv')
72 |
73 |
74 | data_list = []
75 | e = 0
76 | if mode == "train":
77 |
78 | # store all the classes and images into a dict
79 | class_img_dict = {}
80 | with open(train_csv) as f_csv:
81 | f_train = csv.reader(f_csv, delimiter=',')
82 | for row in f_train:
83 | if f_train.line_num == 1:
84 | continue
85 | img_name, img_class = row
86 |
87 | if img_class in class_img_dict:
88 | class_img_dict[img_class].append(img_name)
89 | else:
90 | class_img_dict[img_class]=[]
91 | class_img_dict[img_class].append(img_name)
92 | f_csv.close()
93 | class_list = class_img_dict.keys()
94 |
95 |
96 | while e < episode_num:
97 |
98 | # construct each episode
99 | episode = []
100 | e += 1
101 | temp_list = random.sample(class_list, way_num)
102 | label_num = -1
103 |
104 | for item in temp_list:
105 | label_num += 1
106 | imgs_set = class_img_dict[item]
107 | support_imgs = random.sample(imgs_set, shot_num)
108 | query_imgs = [val for val in imgs_set if val not in support_imgs]
109 |
110 | if query_num < len(query_imgs):
111 | query_imgs = random.sample(query_imgs, query_num)
112 |
113 |
114 | # the dir of support set
115 | query_dir = [path.join(data_dir, 'images', i) for i in query_imgs]
116 | support_dir = [path.join(data_dir, 'images', i) for i in support_imgs]
117 |
118 |
119 | data_files = {
120 | "query_img": query_dir,
121 | "support_set": support_dir,
122 | "target": label_num
123 | }
124 | episode.append(data_files)
125 | data_list.append(episode)
126 |
127 |
128 | elif mode == "val":
129 |
130 | # store all the classes and images into a dict
131 | class_img_dict = {}
132 | with open(val_csv) as f_csv:
133 | f_val = csv.reader(f_csv, delimiter=',')
134 | for row in f_val:
135 | if f_val.line_num == 1:
136 | continue
137 | img_name, img_class = row
138 |
139 | if img_class in class_img_dict:
140 | class_img_dict[img_class].append(img_name)
141 | else:
142 | class_img_dict[img_class]=[]
143 | class_img_dict[img_class].append(img_name)
144 | f_csv.close()
145 | class_list = class_img_dict.keys()
146 |
147 |
148 |
149 | while e < episode_num: # setting the episode number to 600
150 |
151 | # construct each episode
152 | episode = []
153 | e += 1
154 | temp_list = random.sample(class_list, way_num)
155 | label_num = -1
156 |
157 | for item in temp_list:
158 | label_num += 1
159 | imgs_set = class_img_dict[item]
160 | support_imgs = random.sample(imgs_set, shot_num)
161 | query_imgs = [val for val in imgs_set if val not in support_imgs]
162 |
163 | if query_num Covariance metric layer --> Classification layer
134 | # Dataset: 84 x 84 x 3, for miniImageNet, StanfordDog, StanfordCar, CubBird
135 | # Filters: 64->64->64->64
136 | # Mapping Sizes: 84->42->21->21->21
137 |
138 |
class CovarianceNet_64(nn.Module):
    """Conv-64 embedding, covariance metric layer and 1-D conv classifier.

    Args:
        norm_layer: normalisation layer class (or a ``functools.partial``
            wrapping one) instantiated after every convolution.
        num_classes: kept for interface compatibility; it is not read by this
            class -- the number of classes follows from the length of the
            support list passed to ``forward``.
    """

    def __init__(self, norm_layer=nn.BatchNorm2d, num_classes=5):
        super(CovarianceNet_64, self).__init__()

        # Convolutions only get a bias term when the norm layer is InstanceNorm2d.
        # norm_layer may be wrapped in functools.partial, so compare the wrapped class.
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        self.features = nn.Sequential(                              # 3*84*84
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=use_bias),
            norm_layer(64),
            nn.LeakyReLU(0.2, True),
            nn.MaxPool2d(kernel_size=2, stride=2),                  # 64*42*42

            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias),
            norm_layer(64),
            nn.LeakyReLU(0.2, True),
            nn.MaxPool2d(kernel_size=2, stride=2),                  # 64*21*21

            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias),
            norm_layer(64),
            nn.LeakyReLU(0.2, True),                                # 64*21*21

            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias),
            norm_layer(64),
            nn.LeakyReLU(0.2, True),                                # 64*21*21
        )

        self.covariance = CovaBlock()                               # 1*(441*num_classes)

        # kernel_size == stride == 441 (= 21*21 positions of the 84x84 feature map):
        # each class's block of 441 similarities is reduced to a single score.
        self.classifier = nn.Sequential(
            nn.LeakyReLU(0.2, True),
            nn.Dropout(),
            nn.Conv1d(1, 1, kernel_size=441, stride=441, bias=use_bias),
        )


    def forward(self, input1, input2):
        """Score the query batch ``input1`` against every class in ``input2``.

        Args:
            input1: batch of query images.
            input2: sequence with one image batch (the support set) per class.

        Returns:
            Tensor of shape Batch*num_classes with one score per class.
        """
        # extract features of input1--query image
        q = self.features(input1)

        # extract features of input2--support set
        S = [self.features(support_set) for support_set in input2]

        x = self.covariance(q, S)   # get Batch*1*(h*w*num_classes)
        x = self.classifier(x)      # get Batch*1*num_classes
        x = x.squeeze(1)            # get Batch*num_classes

        return x
192 |
193 |
194 |
195 | #========================== Define a Covariance Metric layer ==========================#
196 | # Calculate the local covariance matrix of each category in the support set
197 | # Calculate the Covariance Metric between a query sample and a category
198 |
199 |
class CovaBlock(nn.Module):
    """Covariance metric layer.

    Builds one local covariance matrix per support class and measures how
    each spatial descriptor of a query feature map responds to it.
    """

    def __init__(self):
        super(CovaBlock, self).__init__()


    # calculate the covariance matrix
    def cal_covariance(self, input):
        """Return one C x C covariance matrix for every entry of ``input``.

        Args:
            input: sequence of support feature tensors, each of size (B, C, h, w).

        Returns:
            List of (C, C) tensors. All B*h*w local descriptors of an entry are
            pooled, centred per channel and divided by ``B*h*w - 1``.
        """
        CovaMatrix_list = []
        for support_set_sam in input:
            B, C, h, w = support_set_sam.size()

            # (B, C, h, w) -> (C, B*h*w): one row per channel
            support_set_sam = support_set_sam.permute(1, 0, 2, 3)
            support_set_sam = support_set_sam.contiguous().view(C, -1)
            mean_support = torch.mean(support_set_sam, 1, True)
            support_set_sam = support_set_sam - mean_support

            covariance_matrix = support_set_sam @ torch.transpose(support_set_sam, 0, 1)
            covariance_matrix = torch.div(covariance_matrix, h*w*B-1)
            CovaMatrix_list.append(covariance_matrix)

        return CovaMatrix_list


    # calculate the similarity
    def cal_similarity(self, input, CovaMatrix_list):
        """Compute the covariance metric between each query and each class.

        Args:
            input: query feature tensor of size (B, C, h, w).
            CovaMatrix_list: list of (C, C) covariance matrices, one per class.

        Returns:
            Tensor of size (B, 1, h*w*num_classes); entries
            ``[j*h*w:(j+1)*h*w]`` hold ``q_i^T Cov_j q_i`` for every spatial
            position ``i`` of the query.
        """
        B, C, h, w = input.size()
        Cova_Sim = []

        for i in range(B):
            query_sam = input[i]
            query_sam = query_sam.view(C, -1)
            # each channel is L2-normalised over its h*w spatial positions
            query_sam_norm = torch.norm(query_sam, 2, 1, True)
            query_sam = query_sam / query_sam_norm

            # Allocate on the query's own device/dtype: the buffer used to be
            # created only when CUDA was available, which failed on CPU.
            mea_sim = input.new_zeros(1, len(CovaMatrix_list)*h*w)

            for j, cova_matrix in enumerate(CovaMatrix_list):
                # q_i^T Cov q_i for every position i -- the diagonal of
                # Q^T Cov Q, without forming the full (h*w) x (h*w) product.
                temp_dis = (query_sam * (cova_matrix @ query_sam)).sum(0)
                mea_sim[0, j*h*w:(j+1)*h*w] = temp_dis

            Cova_Sim.append(mea_sim.unsqueeze(0))

        Cova_Sim = torch.cat(Cova_Sim, 0)   # get Batch*1*(h*w*num_classes)
        return Cova_Sim


    def forward(self, x1, x2):
        """Return the covariance similarities of queries ``x1`` w.r.t. supports ``x2``."""
        CovaMatrix_list = self.cal_covariance(x2)
        Cova_Sim = self.cal_similarity(x1, CovaMatrix_list)

        return Cova_Sim
256 |
--------------------------------------------------------------------------------
/results/CovaMNet_miniImageNet_Conv64_5_Way_1_Shot/model_best.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WenbinLee/CovaMNet/d65d0bcc0f26bc8d742d75fe3387f89603c89185/results/CovaMNet_miniImageNet_Conv64_5_Way_1_Shot/model_best.pth.tar
--------------------------------------------------------------------------------