├── CovaMNet_Test_5way1shot.py
├── CovaMNet_Test_5way5shot.py
├── CovaMNet_Train_5way1shot.py
├── CovaMNet_Train_5way5shot.py
├── LICENSE
├── README.md
├── dataset
│   ├── CubBird
│   │   └── CubBird_prepare_csv.py
│   ├── StanfordCar
│   │   └── StanforCar_prepare_csv.py
│   ├── StanfordDog
│   │   └── StanfordDog_prepare_csv.py
│   └── datasets_csv.py
├── imgs
│   ├── CovaMNet.bmp
│   ├── result_finegrained.bmp
│   └── results_miniImageNet.bmp
├── models
│   └── network.py
└── results
    └── CovaMNet_miniImageNet_Conv64_5_Way_1_Shot
        ├── model_best.pth.tar
        └── opt_resutls.txt

/CovaMNet_Test_5way1shot.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | 
4 | """
5 | Author: Wenbin Li (liwenbin.nju@gmail.com)
6 | Date: Jan. 14, 2019
7 | Version: V0
8 | 
9 | Citation:
10 | @inproceedings{li2019CovaMNet,
11 |   title={Distribution Consistency based Covariance Metric Networks for Few-shot Learning},
12 |   author={Li, Wenbin and Xu, Jinglin and Huo, Jing and Wang, Lei and Gao, Yang and Luo, Jiebo},
13 |   booktitle={AAAI},
14 |   year={2019}
15 | }
16 | """
17 | 
18 | 
19 | from __future__ import print_function
20 | import argparse
21 | import os
22 | import random
23 | import shutil
24 | import numpy as np
25 | import torch
26 | import torch.nn as nn
27 | import torch.nn.parallel
28 | import torch.backends.cudnn as cudnn
29 | import torch.optim as optim
30 | import torch.utils.data
31 | import torchvision.datasets as dset
32 | import torchvision.transforms as transforms
33 | import torchvision.utils as vutils
34 | from torch.autograd import grad
35 | import time
36 | from torch import autograd
37 | from PIL import ImageFile
38 | import scipy as sp
39 | import scipy.stats
40 | import sys
41 | sys.dont_write_bytecode = True
42 | 
43 | 
44 | 
45 | # ============================ Data & Networks =====================================
46 | from dataset.datasets_csv import Imagefolder_csv
47 | import models.network as CovaNet
48 | # ==================================================================================
49 | 
50 | 
51 | ImageFile.LOAD_TRUNCATED_IMAGES = True
52 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
53 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
54 | 
55 | 
56 | parser = argparse.ArgumentParser()
57 | parser.add_argument('--dataset_dir', default='/Datasets/miniImageNet--ravi', help='the path of the data')
58 | parser.add_argument('--data_name', default='miniImageNet', help='miniImageNet|StanfordDog|StanfordCar|CubBird')
59 | parser.add_argument('--mode', default='test', help='train|val|test')
60 | parser.add_argument('--outf', default='./results/CovaMNet')
61 | parser.add_argument('--resume', default='', type=str, help='path to the latest checkpoint (default: none)')
62 | parser.add_argument('--basemodel', default='Conv64', help='Conv64')
63 | parser.add_argument('--workers', type=int, default=8)
64 | # Few-shot parameters #
65 | parser.add_argument('--imageSize', type=int, default=84)
66 | parser.add_argument('--episodeSize', type=int, default=1, help='the mini-batch size of training')
67 | parser.add_argument('--testepisodeSize', type=int, default=1, help='one episode is taken as a mini-batch')
68 | parser.add_argument('--epochs', type=int, default=30, help='the total number of training epochs')
69 | parser.add_argument('--episode_train_num', type=int, default=10000, help='the total number of training episodes')
70 | parser.add_argument('--episode_val_num', type=int, default=1000, help='the total number of evaluation episodes')
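# [Editor's note, not in the original source] Episode layout implied by the
# few-shot parameters here: each episode samples way_num classes with shot_num
# support and query_num query images per class, so the 5-way 1-shot defaults
# (query_num=15) give 5 support and 75 query images per episode.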
71 | parser.add_argument('--episode_test_num', type=int, default=600, help='the total number of testing episodes')
72 | parser.add_argument('--way_num', type=int, default=5, help='the number of way/class')
73 | parser.add_argument('--shot_num', type=int, default=1, help='the number of shot')
74 | parser.add_argument('--query_num', type=int, default=15, help='the number of queries')
75 | parser.add_argument('--lr', type=float, default=0.005, help='learning rate, default=0.005')
76 | parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
77 | parser.add_argument('--cuda', action='store_true', default=True, help='enables cuda')
78 | parser.add_argument('--ngpu', type=int, default=1, help='the number of gpus')
79 | parser.add_argument('--nc', type=int, default=3, help='input image channels')
80 | parser.add_argument('--clamp_lower', type=float, default=-0.01)
81 | parser.add_argument('--clamp_upper', type=float, default=0.01)
82 | parser.add_argument('--print_freq', '-p', default=100, type=int, metavar='N', help='print frequency (default: 100)')
83 | opt = parser.parse_args()
84 | opt.cuda = True
85 | cudnn.benchmark = True
86 | 
87 | 
88 | 
89 | 
90 | # ======================================= Define functions =============================================
91 | def validate(val_loader, model, criterion, epoch_index, F_txt):
92 |     batch_time = AverageMeter()
93 |     losses = AverageMeter()
94 |     top1 = AverageMeter()
95 | 
96 | 
97 |     # switch to evaluate mode
98 |     model.eval()
99 |     accuracies = []
100 | 
101 | 
102 |     end = time.time()
103 |     for episode_index, (query_images, query_targets, support_images, support_targets) in enumerate(val_loader):
104 | 
105 |         # Convert query and support images
106 |         query_images = torch.cat(query_images, 0)
107 |         input_var1 = query_images.cuda()
108 | 
109 | 
110 |         input_var2 = []
111 |         for i in range(len(support_images)):
112 |             temp_support = support_images[i]
113 |             temp_support = torch.cat(temp_support, 0)
114 |             temp_support = temp_support.cuda()
115 |             input_var2.append(temp_support)
116 | 
117 | 
118 |         # Deal with the targets
119 |         target = torch.cat(query_targets, 0)
120 |         target = target.cuda()
121 | 
122 |         # Calculate the output
123 |         output = model(input_var1, input_var2)
124 |         loss = criterion(output, target)
125 | 
126 | 
127 |         # measure accuracy and record loss
128 |         prec1, _ = accuracy(output, target, topk=(1, 3))
129 |         losses.update(loss.item(), query_images.size(0))
130 |         top1.update(prec1[0], query_images.size(0))
131 |         accuracies.append(prec1)
132 | 
133 | 
134 |         # measure elapsed time
135 |         batch_time.update(time.time() - end)
136 |         end = time.time()
137 | 
138 | 
139 |         #============== print the intermediate results ==============#
140 |         if episode_index % opt.print_freq == 0 and episode_index != 0:
141 | 
142 |             print('Test-({0}): [{1}/{2}]\t'
143 |                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
144 |                   'Loss {loss.val:.3f} ({loss.avg:.3f})\t'
145 |                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
146 |                       epoch_index, episode_index, len(val_loader), batch_time=batch_time, loss=losses, top1=top1))
147 | 
148 |             print('Test-({0}): [{1}/{2}]\t'
149 |                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
150 |                   'Loss {loss.val:.3f} ({loss.avg:.3f})\t'
151 |                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
152 |                       epoch_index, episode_index, len(val_loader), batch_time=batch_time, loss=losses, top1=top1), file=F_txt)
153 | 
154 | 
155 |     print(' * Prec@1 {top1.avg:.3f} Best_prec1 {best_prec1:.3f}'.format(top1=top1, best_prec1=best_prec1))
156 |     print(' * Prec@1 {top1.avg:.3f} Best_prec1 {best_prec1:.3f}'.format(top1=top1, best_prec1=best_prec1), file=F_txt)
157 | 
158 |     return top1.avg, accuracies
159 | 
160 | 
161 | class AverageMeter(object):
162 |     """Computes and stores the average and current value"""
163 |     def __init__(self):
164 |         self.reset()
165 | 
166 |     def reset(self):
167 |         self.val = 0
168 |         self.avg = 0
169 |         self.sum = 0
170 |         self.count = 0
171 | 
172 |     def update(self, val, n=1):
173 |         self.val = val
174 |         self.sum += val * n
175 |         self.count += n
176 |         self.avg = self.sum / self.count
177 | 
178 | 
179 | 
180 | def accuracy(output, target, topk=(1,)):
181 |     """Computes the precision@k for the specified values of k"""
182 |     with torch.no_grad():
183 |         maxk = max(topk)
184 |         batch_size = target.size(0)
185 | 
186 |         _, pred = output.topk(maxk, 1, True, True)
187 |         pred = pred.t()
188 |         correct = pred.eq(target.view(1, -1).expand_as(pred))
189 | 
190 |         res = []
191 |         for k in topk:
192 |             correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
193 |             res.append(correct_k.mul_(100.0 / batch_size))
194 |         return res
195 | 
196 | 
197 | def mean_confidence_interval(data, confidence=0.95):
198 |     a = [1.0*np.array(data[i].cpu()) for i in range(len(data))]
199 |     n = len(a)
200 |     m, se = np.mean(a), scipy.stats.sem(a)
201 |     h = se * scipy.stats.t.ppf((1+confidence)/2., n-1)
202 |     return m, h
203 | 
204 | 
205 | # ======================================== Settings of path ============================================
206 | # save path
207 | opt.outf = opt.outf+'_'+opt.data_name+'_'+str(opt.basemodel)+'_'+str(opt.way_num)+'Way_'+str(opt.shot_num)+'Shot'
208 | 
209 | if not os.path.exists(opt.outf):
210 |     os.makedirs(opt.outf)
211 | 
212 | if torch.cuda.is_available() and not opt.cuda:
213 |     print("WARNING: You have a CUDA device, so you should probably run with --cuda")
214 | 
215 | # save the opt and results to txt file
216 | txt_save_path = os.path.join(opt.outf, 'Test_resutls.txt')
217 | F_txt = open(txt_save_path, 'a+')
218 | print(opt)
219 | print(opt, file=F_txt)
220 | 
221 | 
222 | 
223 | # ========================================== Model config ===============================================
224 | ngpu = int(opt.ngpu)
225 | global best_prec1, epoch_index
226 | best_prec1 = 0
227 | epoch_index = 0
228 | model = CovaNet.define_CovarianceNet(which_model=opt.basemodel, num_classes=opt.way_num, norm='batch',
229 |                                      init_type='normal', use_gpu=opt.cuda)
230 | 
231 | # define loss function (criterion) and optimizer
232 | criterion = nn.CrossEntropyLoss().cuda()
233 | optimizer = optim.Adam(model.parameters(), lr=opt.lr, betas=(opt.beta1, 0.9))
234 | 
235 | 
236 | 
237 | # optionally resume from a checkpoint
238 | if opt.resume:
239 |     if os.path.isfile(opt.resume):
240 |         print("=> loading checkpoint '{}'".format(opt.resume))
241 |         checkpoint = torch.load(opt.resume)
242 |         epoch_index = checkpoint['epoch_index']
243 |         best_prec1 = checkpoint['best_prec1']
244 |         model.load_state_dict(checkpoint['state_dict'])
245 |         optimizer.load_state_dict(checkpoint['optimizer'])
246 |         print("=> loaded checkpoint '{}' (epoch {})".format(opt.resume, checkpoint['epoch_index']))
247 |     else:
248 |         print("=> no checkpoint found at '{}'".format(opt.resume))
249 | 
250 | if opt.ngpu > 1:
251 |     model = nn.DataParallel(model, range(opt.ngpu))
252 | 
253 | print(model)
254 | print(model, file=F_txt)  # print the architecture of the network
255 | 
256 | 
257 | 
258 | 
259 | # ============================================ Testing phase ========================================
260 | 
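# [Editor's note, not in the original source] Evaluation protocol implemented
# below: testing is repeated for repeat_num=5 rounds, each sampling
# episode_test_num=600 fresh episodes; every round reports its mean accuracy with
# a 95% confidence interval from mean_confidence_interval() above
# (h = t.ppf(0.975, n-1) * SEM over per-episode accuracies), and the final
# numbers average the five rounds.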
print('\n............Start testing............') 261 | start_time = time.time() 262 | repeat_num = 5 # repeat running the testing code several times 263 | 264 | 265 | total_accuracy = 0.0 266 | total_h = np.zeros(repeat_num) 267 | total_accuracy_vector = [] 268 | for r in range(repeat_num): 269 | print('===================================== Round %d =====================================' %r) 270 | print('===================================== Round %d =====================================' %r, file=F_txt) 271 | 272 | # ======================================= Folder of Datasets ======================================= 273 | 274 | # image transform & normalization 275 | ImgTransform = transforms.Compose([ 276 | transforms.Resize((opt.imageSize, opt.imageSize)), 277 | transforms.ToTensor(), 278 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 279 | ]) 280 | 281 | testset = Imagefolder_csv( 282 | data_dir=opt.dataset_dir, mode=opt.mode, image_size=opt.imageSize, transform=ImgTransform, 283 | episode_num=opt.episode_test_num, way_num=opt.way_num, shot_num=opt.shot_num, query_num=opt.query_num 284 | ) 285 | print('Testset: %d-------------%d' %(len(testset), r), file=F_txt) 286 | 287 | 288 | 289 | # ========================================== Load Datasets ========================================= 290 | test_loader = torch.utils.data.DataLoader( 291 | testset, batch_size=opt.testepisodeSize, shuffle=True, 292 | num_workers=int(opt.workers), drop_last=True, pin_memory=True 293 | ) 294 | 295 | 296 | # =========================================== Evaluation ========================================== 297 | prec1, accuracies = validate(test_loader, model, criterion, epoch_index, F_txt) 298 | 299 | 300 | test_accuracy, h = mean_confidence_interval(accuracies) 301 | print("Test accuracy", test_accuracy, "h", h[0]) 302 | print("Test accuracy", test_accuracy, "h", h[0], file=F_txt) 303 | total_accuracy += test_accuracy 304 | total_accuracy_vector.extend(accuracies) 305 | total_h[r] = h 306 | 307 | 308 | aver_accuracy, _ = mean_confidence_interval(total_accuracy_vector) 309 | print("Aver_accuracy:", aver_accuracy, "Aver_h", total_h.mean()) 310 | print("Aver_accuracy:", aver_accuracy, "Aver_h", total_h.mean(), file=F_txt) 311 | F_txt.close() 312 | 313 | 314 | # ============================================ Testing End ======================================== 315 | -------------------------------------------------------------------------------- /CovaMNet_Test_5way5shot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | Author: Wenbin Li (liwenbin.nju@gmail.com) 6 | Date: Jan. 
14, 2019 7 | Version: V0 8 | 9 | Citation: 10 | @inproceedings{li2019CovaMNet, 11 | title={Distribution Consistency based Covariance Metric Networks for Few-shot Learning}, 12 | author={Li, Wenbin and Xu, Jinglin and Huo, Jing and Wang, Lei and Gao Yang and Luo, Jiebo}, 13 | booktitle={AAAI}, 14 | year={2019} 15 | } 16 | """ 17 | 18 | 19 | from __future__ import print_function 20 | import argparse 21 | import os 22 | import random 23 | import shutil 24 | import numpy as np 25 | import torch 26 | import torch.nn as nn 27 | import torch.nn.parallel 28 | import torch.backends.cudnn as cudnn 29 | import torch.optim as optim 30 | import torch.utils.data 31 | import torchvision.datasets as dset 32 | import torchvision.transforms as transforms 33 | import torchvision.utils as vutils 34 | from torch.autograd import grad 35 | import time 36 | from torch import autograd 37 | from PIL import ImageFile 38 | import scipy as sp 39 | import scipy.stats 40 | import sys 41 | sys.dont_write_bytecode = True 42 | 43 | 44 | 45 | # ============================ Data & Networks ===================================== 46 | from dataset.datasets_csv import Imagefolder_csv 47 | import models.network as CovaNet 48 | # ================================================================================== 49 | 50 | 51 | ImageFile.LOAD_TRUNCATED_IMAGES = True 52 | os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID' 53 | os.environ['CUDA_VISIBLE_DEVICES']='0' 54 | 55 | 56 | parser = argparse.ArgumentParser() 57 | parser.add_argument('--dataset_dir', default=' ', help='the path of the data') 58 | parser.add_argument('--data_name', default='miniImageNet', help='miniImageNet|StanfordDog|StanfordCar|CubBird') 59 | parser.add_argument('--mode', default='test', help='train|val|test') 60 | parser.add_argument('--outf', default='./results/CovaMNet') 61 | parser.add_argument('--resume', default=' ', type=str, help='path to the lastest checkpoint (default: none)') 62 | parser.add_argument('--basemodel', default='Conv64', help='Conv64') 63 | parser.add_argument('--workers', type=int, default=8) 64 | # Few-shot parameters # 65 | parser.add_argument('--imageSize', type=int, default=84) 66 | parser.add_argument('--episodeSize', type=int, default=1, help='the mini-batch size of training') 67 | parser.add_argument('--testepisodeSize', type=int, default=1, help='one episode is taken as a mini-batch') 68 | parser.add_argument('--epochs', type=int, default=30, help='the total number of training epoch') 69 | parser.add_argument('--episode_train_num', type=int, default=10000, help='the total number of training episodes') 70 | parser.add_argument('--episode_val_num', type=int, default=1000, help='the total number of evaluation episodes') 71 | parser.add_argument('--episode_test_num', type=int, default=600, help='the total number of testing episodes') 72 | parser.add_argument('--way_num', type=int, default=5, help='the number of way/class') 73 | parser.add_argument('--shot_num', type=int, default=5, help='the number of shot') 74 | parser.add_argument('--query_num', type=int, default=15, help='the number of queries') 75 | parser.add_argument('--lr', type=float, default=0.005, help='learning rate, default=0.005') 76 | parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. 
default=0.5') 77 | parser.add_argument('--cuda', action='store_true', default=True, help='enables cuda') 78 | parser.add_argument('--ngpu', type=int, default=1, help='the number of gpus') 79 | parser.add_argument('--nc', type=int, default=3, help='input image channels') 80 | parser.add_argument('--clamp_lower', type=float, default=-0.01) 81 | parser.add_argument('--clamp_upper', type=float, default=0.01) 82 | parser.add_argument('--print_freq', '-p', default=100, type=int, metavar='N', help='print frequency (default: 100)') 83 | opt = parser.parse_args() 84 | opt.cuda = True 85 | cudnn.benchmark = True 86 | 87 | 88 | 89 | 90 | # ======================================= Define functions ============================================= 91 | def validate(val_loader, model, criterion, epoch_index, F_txt): 92 | batch_time = AverageMeter() 93 | losses = AverageMeter() 94 | top1 = AverageMeter() 95 | 96 | 97 | # switch to evaluate mode 98 | model.eval() 99 | accuracies = [] 100 | 101 | 102 | end = time.time() 103 | for episode_index, (query_images, query_targets, support_images, support_targets) in enumerate(val_loader): 104 | 105 | # Convert query and support images 106 | query_images = torch.cat(query_images, 0) 107 | input_var1 = query_images.cuda() 108 | 109 | 110 | input_var2 = [] 111 | for i in range(len(support_images)): 112 | temp_support = support_images[i] 113 | temp_support = torch.cat(temp_support, 0) 114 | temp_support = temp_support.cuda() 115 | input_var2.append(temp_support) 116 | 117 | 118 | # Deal with the targets 119 | target = torch.cat(query_targets, 0) 120 | target = target.cuda() 121 | 122 | # Calculate the output 123 | output = model(input_var1, input_var2) 124 | loss = criterion(output, target) 125 | 126 | 127 | # measure accuracy and record loss 128 | prec1, _ = accuracy(output, target, topk=(1, 3)) 129 | losses.update(loss.item(), query_images.size(0)) 130 | top1.update(prec1[0], query_images.size(0)) 131 | accuracies.append(prec1) 132 | 133 | 134 | # measure elapsed time 135 | batch_time.update(time.time() - end) 136 | end = time.time() 137 | 138 | 139 | #============== print the intermediate results ==============# 140 | if episode_index % opt.print_freq == 0 and episode_index != 0: 141 | 142 | print('Test-({0}): [{1}/{2}]\t' 143 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 144 | 'Loss {loss.val:.3f} ({loss.avg:.3f})\t' 145 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 146 | epoch_index, episode_index, len(val_loader), batch_time=batch_time, loss=losses, top1=top1)) 147 | 148 | print('Test-({0}): [{1}/{2}]\t' 149 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 150 | 'Loss {loss.val:.3f} ({loss.avg:.3f})\t' 151 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 152 | epoch_index, episode_index, len(val_loader), batch_time=batch_time, loss=losses, top1=top1), file=F_txt) 153 | 154 | 155 | print(' * Prec@1 {top1.avg:.3f} Best_prec1 {best_prec1:.3f}'.format(top1=top1, best_prec1=best_prec1)) 156 | print(' * Prec@1 {top1.avg:.3f} Best_prec1 {best_prec1:.3f}'.format(top1=top1, best_prec1=best_prec1), file=F_txt) 157 | 158 | return top1.avg, accuracies 159 | 160 | 161 | class AverageMeter(object): 162 | """Computes and stores the average and current value""" 163 | def __init__(self): 164 | self.reset() 165 | 166 | def reset(self): 167 | self.val = 0 168 | self.avg = 0 169 | self.sum = 0 170 | self.count = 0 171 | 172 | def update(self, val, n=1): 173 | self.val = val 174 | self.sum += val * n 175 | self.count += n 176 | self.avg = self.sum / self.count 177 
| 178 | 179 | 180 | def accuracy(output, target, topk=(1,)): 181 | """Computes the precision@k for the specified values of k""" 182 | with torch.no_grad(): 183 | maxk = max(topk) 184 | batch_size = target.size(0) 185 | 186 | _, pred = output.topk(maxk, 1, True, True) 187 | pred = pred.t() 188 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 189 | 190 | res = [] 191 | for k in topk: 192 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 193 | res.append(correct_k.mul_(100.0 / batch_size)) 194 | return res 195 | 196 | 197 | def mean_confidence_interval(data, confidence=0.95): 198 | a = [1.0*np.array(data[i].cpu()) for i in range(len(data))] 199 | n = len(a) 200 | m, se = np.mean(a), scipy.stats.sem(a) 201 | h = se * sp.stats.t._ppf((1+confidence)/2., n-1) 202 | return m,h 203 | 204 | 205 | # ======================================== Settings of path ============================================ 206 | # save path 207 | opt.outf = opt.outf+'_'+opt.data_name+'_'+str(opt.basemodel)+'_'+str(opt.way_num)+'Way_'+str(opt.shot_num)+'Shot' 208 | 209 | if not os.path.exists(opt.outf): 210 | os.makedirs(opt.outf) 211 | 212 | if torch.cuda.is_available() and not opt.cuda: 213 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 214 | 215 | # save the opt and results to txt file 216 | txt_save_path = os.path.join(opt.outf, 'Test_resutls.txt') 217 | F_txt = open(txt_save_path, 'a+') 218 | print(opt) 219 | print(opt, file=F_txt) 220 | 221 | 222 | 223 | # ========================================== Model config =============================================== 224 | ngpu = int(opt.ngpu) 225 | global best_prec1, epoch_index 226 | best_prec1 = 0 227 | epoch_index = 0 228 | model = CovaNet.define_CovarianceNet(which_model=opt.basemodel, num_classes=opt.way_num, norm='batch', 229 | init_type='normal', use_gpu=opt.cuda) 230 | 231 | # define loss function (criterion) and optimizer 232 | criterion = nn.CrossEntropyLoss().cuda() 233 | optimizer = optim.Adam(model.parameters(), lr=opt.lr, betas=(opt.beta1, 0.9)) 234 | 235 | 236 | 237 | # optionally resume from a checkpoint 238 | if opt.resume: 239 | if os.path.isfile(opt.resume): 240 | print("=> loading checkpoint '{}'".format(opt.resume)) 241 | checkpoint = torch.load(opt.resume) 242 | epoch_index = checkpoint['epoch_index'] 243 | best_prec1 = checkpoint['best_prec1'] 244 | model.load_state_dict(checkpoint['state_dict']) 245 | optimizer.load_state_dict(checkpoint['optimizer']) 246 | print("=> loaded checkpoint '{}' (epoch {})".format(opt.resume, checkpoint['epoch_index'])) 247 | else: 248 | print("=> no checkpoint found at '{}'".format(opt.resume)) 249 | 250 | if opt.ngpu > 1: 251 | model = nn.DataParallel(model, range(opt.ngpu)) 252 | 253 | print(model) 254 | print(model, file=F_txt) # print the architecture of the network 255 | 256 | 257 | 258 | 259 | # ============================================ Testing phase ======================================== 260 | print('\n............Start testing............') 261 | start_time = time.time() 262 | repeat_num = 5 # repeat running the testing code several times 263 | 264 | 265 | total_accuracy = 0.0 266 | total_h = np.zeros(repeat_num) 267 | total_accuracy_vector = [] 268 | for r in range(repeat_num): 269 | print('===================================== Round %d =====================================' %r) 270 | print('===================================== Round %d =====================================' %r, file=F_txt) 271 | 272 | # ======================================= Folder of 
Datasets ======================================= 273 | 274 | # image transform & normalization 275 | ImgTransform = transforms.Compose([ 276 | transforms.Resize((opt.imageSize, opt.imageSize)), 277 | transforms.ToTensor(), 278 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 279 | ]) 280 | 281 | testset = Imagefolder_csv( 282 | data_dir=opt.dataset_dir, mode=opt.mode, image_size=opt.imageSize, transform=ImgTransform, 283 | episode_num=opt.episode_test_num, way_num=opt.way_num, shot_num=opt.shot_num, query_num=opt.query_num 284 | ) 285 | print('Testset: %d-------------%d' %(len(testset), r), file=F_txt) 286 | 287 | 288 | 289 | # ========================================== Load Datasets ========================================= 290 | test_loader = torch.utils.data.DataLoader( 291 | testset, batch_size=opt.testepisodeSize, shuffle=True, 292 | num_workers=int(opt.workers), drop_last=True, pin_memory=True 293 | ) 294 | 295 | 296 | # =========================================== Evaluation ========================================== 297 | prec1, accuracies = validate(test_loader, model, criterion, epoch_index, F_txt) 298 | 299 | 300 | test_accuracy, h = mean_confidence_interval(accuracies) 301 | print("Test accuracy", test_accuracy, "h", h[0]) 302 | print("Test accuracy", test_accuracy, "h", h[0], file=F_txt) 303 | total_accuracy += test_accuracy 304 | total_accuracy_vector.extend(accuracies) 305 | total_h[r] = h 306 | 307 | 308 | aver_accuracy, _ = mean_confidence_interval(total_accuracy_vector) 309 | print("Aver_accuracy:", aver_accuracy, "Aver_h", total_h.mean()) 310 | print("Aver_accuracy:", aver_accuracy, "Aver_h", total_h.mean(), file=F_txt) 311 | F_txt.close() 312 | 313 | 314 | # ============================================ Testing End ======================================== 315 | -------------------------------------------------------------------------------- /CovaMNet_Train_5way1shot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | Author: Wenbin Li (liwenbin.nju@gmail.com) 6 | Date: Jan. 
14, 2019 7 | Version: V0 8 | 9 | Citation: 10 | @inproceedings{li2019CovaMNet, 11 | title={Distribution Consistency based Covariance Metric Networks for Few-shot Learning}, 12 | author={Li, Wenbin and Xu, Jinglin and Huo, Jing and Wang, Lei and Gao Yang and Luo, Jiebo}, 13 | booktitle={AAAI}, 14 | year={2019} 15 | } 16 | """ 17 | 18 | 19 | 20 | from __future__ import print_function 21 | import argparse 22 | import os 23 | import random 24 | import shutil 25 | import numpy as np 26 | import torch 27 | import torch.nn as nn 28 | import torch.nn.parallel 29 | import torch.backends.cudnn as cudnn 30 | import torch.optim as optim 31 | import torch.utils.data 32 | import torchvision.datasets as dset 33 | import torchvision.transforms as transforms 34 | import torchvision.utils as vutils 35 | from torch.autograd import grad 36 | import time 37 | from torch import autograd 38 | from PIL import ImageFile 39 | import pdb 40 | import sys 41 | sys.dont_write_bytecode = True 42 | 43 | 44 | # ============================ Data & Networks ===================================== 45 | from dataset.datasets_csv import Imagefolder_csv 46 | import models.network as CovaNet 47 | # ================================================================================== 48 | 49 | 50 | ImageFile.LOAD_TRUNCATED_IMAGES = True 51 | os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID' 52 | os.environ['CUDA_VISIBLE_DEVICES']='0' 53 | 54 | 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument('--dataset_dir', default='', help='/Datasets/miniImageNet/') 57 | parser.add_argument('--data_name', default='miniImageNet', help='miniImageNet|StanfordDog|StanfordCar|CubBird') 58 | parser.add_argument('--mode', default='train', help='train|val|test') 59 | parser.add_argument('--outf', default='./results/CovaMNet') 60 | parser.add_argument('--resume', default='', type=str, help='path to the lastest checkpoint (default: none)') 61 | parser.add_argument('--basemodel', default='Conv64', help='Conv64') 62 | parser.add_argument('--workers', type=int, default=8) 63 | # Few-shot parameters # 64 | parser.add_argument('--imageSize', type=int, default=84) 65 | parser.add_argument('--episodeSize', type=int, default=1, help='the mini-batch size of training') 66 | parser.add_argument('--testepisodeSize', type=int, default=1, help='one episode is taken as a mini-batch') 67 | parser.add_argument('--epochs', type=int, default=40, help='the total number of training epoch') 68 | parser.add_argument('--episode_train_num', type=int, default=10000, help='the total number of training episodes') 69 | parser.add_argument('--episode_val_num', type=int, default=10000, help='the total number of evaluation episodes') 70 | parser.add_argument('--episode_test_num', type=int, default=1000, help='the total number of testing episodes') 71 | parser.add_argument('--way_num', type=int, default=5, help='the number of way/class') 72 | parser.add_argument('--shot_num', type=int, default=1, help='the number of shot') 73 | parser.add_argument('--query_num', type=int, default=15, help='the number of queries') 74 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate, default=0.005') 75 | parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. 
default=0.5') 76 | parser.add_argument('--cuda', action='store_true', default=True, help='enables cuda') 77 | parser.add_argument('--ngpu', type=int, default=1, help='the number of gpus') 78 | parser.add_argument('--nc', type=int, default=3, help='input image channels') 79 | parser.add_argument('--clamp_lower', type=float, default=-0.01) 80 | parser.add_argument('--clamp_upper', type=float, default=0.01) 81 | parser.add_argument('--print_freq', '-p', default=100, type=int, metavar='N', help='print frequency (default: 100)') 82 | opt = parser.parse_args() 83 | opt.cuda = True 84 | cudnn.benchmark = True 85 | 86 | 87 | 88 | # ======================================= Define functions ============================================= 89 | def adjust_learning_rate(optimizer, epoch_num): 90 | """Sets the learning rate to the initial LR decayed by 0.05 every 10 epochs""" 91 | lr = opt.lr * (0.05 ** (epoch_num // 10)) 92 | for param_group in optimizer.param_groups: 93 | param_group['lr'] = lr 94 | 95 | 96 | def train(train_loader, model, criterion, optimizer, epoch_index, F_txt): 97 | batch_time = AverageMeter() 98 | data_time = AverageMeter() 99 | losses = AverageMeter() 100 | top1 = AverageMeter() 101 | 102 | 103 | end = time.time() 104 | for episode_index, (query_images, query_targets, support_images, support_targets) in enumerate(train_loader): 105 | 106 | # Measure data loading time 107 | data_time.update(time.time() - end) 108 | 109 | # Convert query and support images 110 | query_images = torch.cat(query_images, 0) 111 | input_var1 = query_images.cuda() 112 | 113 | input_var2 = [] 114 | for i in range(len(support_images)): 115 | temp_support = support_images[i] 116 | temp_support = torch.cat(temp_support, 0) 117 | temp_support = temp_support.cuda() 118 | input_var2.append(temp_support) 119 | 120 | # Deal with the targets 121 | target = torch.cat(query_targets, 0) 122 | target = target.cuda() 123 | 124 | # Calculate the output 125 | output = model(input_var1, input_var2) 126 | loss = criterion(output, target) 127 | 128 | # Compute gradients and do SGD step 129 | optimizer.zero_grad() 130 | loss.backward() 131 | optimizer.step() 132 | 133 | 134 | # Measure accuracy and record loss 135 | prec1, _ = accuracy(output, target, topk=(1,3)) 136 | losses.update(loss.item(), query_images.size(0)) 137 | top1.update(prec1[0], query_images.size(0)) 138 | 139 | 140 | # Measure elapsed time 141 | batch_time.update(time.time() - end) 142 | end = time.time() 143 | 144 | 145 | #============== print the intermediate results ==============# 146 | if episode_index % opt.print_freq == 0 and episode_index != 0: 147 | 148 | print('Eposide-({0}): [{1}/{2}]\t' 149 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 150 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 151 | 'Loss {loss.val:.3f} ({loss.avg:.3f})\t' 152 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 153 | epoch_index, episode_index, len(train_loader), batch_time=batch_time, data_time=data_time, loss=losses, top1=top1)) 154 | 155 | print('Eposide-({0}): [{1}/{2}]\t' 156 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 157 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 158 | 'Loss {loss.val:.3f} ({loss.avg:.3f})\t' 159 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 160 | epoch_index, episode_index, len(train_loader), batch_time=batch_time, data_time=data_time, loss=losses, top1=top1), file=F_txt) 161 | 162 | 163 | 164 | def validate(val_loader, model, criterion, epoch_index, best_prec1, F_txt): 165 | batch_time = AverageMeter() 166 | 
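# [Editor's note, not in the original source] In the training scripts, validate()
# runs twice per epoch: on the val split, whose Prec@1 drives best-checkpoint
# selection, and again on the test split purely for monitoring.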
losses = AverageMeter() 167 | top1 = AverageMeter() 168 | 169 | 170 | # switch to evaluate mode 171 | model.eval() 172 | accuracies = [] 173 | 174 | 175 | end = time.time() 176 | for episode_index, (query_images, query_targets, support_images, support_targets) in enumerate(val_loader): 177 | 178 | # Convert query and support images 179 | query_images = torch.cat(query_images, 0) 180 | input_var1 = query_images.cuda() 181 | 182 | 183 | input_var2 = [] 184 | for i in range(len(support_images)): 185 | temp_support = support_images[i] 186 | temp_support = torch.cat(temp_support, 0) 187 | temp_support = temp_support.cuda() 188 | input_var2.append(temp_support) 189 | 190 | 191 | # Deal with the targets 192 | target = torch.cat(query_targets, 0) 193 | target = target.cuda() 194 | 195 | # Calculate the output 196 | output = model(input_var1, input_var2) 197 | loss = criterion(output, target) 198 | 199 | 200 | # measure accuracy and record loss 201 | prec1, _ = accuracy(output, target, topk=(1, 3)) 202 | losses.update(loss.item(), query_images.size(0)) 203 | top1.update(prec1[0], query_images.size(0)) 204 | accuracies.append(prec1) 205 | 206 | 207 | # measure elapsed time 208 | batch_time.update(time.time() - end) 209 | end = time.time() 210 | 211 | 212 | #============== print the intermediate results ==============# 213 | if episode_index % opt.print_freq == 0 and episode_index != 0: 214 | 215 | print('Test-({0}): [{1}/{2}]\t' 216 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 217 | 'Loss {loss.val:.3f} ({loss.avg:.3f})\t' 218 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 219 | epoch_index, episode_index, len(val_loader), batch_time=batch_time, loss=losses, top1=top1)) 220 | 221 | print('Test-({0}): [{1}/{2}]\t' 222 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 223 | 'Loss {loss.val:.3f} ({loss.avg:.3f})\t' 224 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 225 | epoch_index, episode_index, len(val_loader), batch_time=batch_time, loss=losses, top1=top1), file=F_txt) 226 | 227 | 228 | print(' * Prec@1 {top1.avg:.3f} Best_prec1 {best_prec1:.3f}'.format(top1=top1, best_prec1=best_prec1)) 229 | print(' * Prec@1 {top1.avg:.3f} Best_prec1 {best_prec1:.3f}'.format(top1=top1, best_prec1=best_prec1), file=F_txt) 230 | 231 | return top1.avg, accuracies 232 | 233 | 234 | def save_checkpoint(state, filename='checkpoint.pth.tar'): 235 | torch.save(state, filename) 236 | 237 | 238 | 239 | class AverageMeter(object): 240 | """Computes and stores the average and current value""" 241 | def __init__(self): 242 | self.reset() 243 | 244 | def reset(self): 245 | self.val = 0 246 | self.avg = 0 247 | self.sum = 0 248 | self.count = 0 249 | 250 | def update(self, val, n=1): 251 | self.val = val 252 | self.sum += val * n 253 | self.count += n 254 | self.avg = self.sum / self.count 255 | 256 | 257 | def accuracy(output, target, topk=(1,)): 258 | """Computes the precision@k for the specified values of k""" 259 | with torch.no_grad(): 260 | maxk = max(topk) 261 | batch_size = target.size(0) 262 | 263 | _, pred = output.topk(maxk, 1, True, True) 264 | pred = pred.t() 265 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 266 | 267 | res = [] 268 | for k in topk: 269 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 270 | res.append(correct_k.mul_(100.0 / batch_size)) 271 | return res 272 | 273 | 274 | # ======================================== Settings of path ============================================ 275 | # saving path 276 | opt.outf = 
opt.outf+'_'+opt.data_name+'_'+str(opt.basemodel)+'_'+str(opt.way_num)+'Way_'+str(opt.shot_num)+'Shot' 277 | 278 | if not os.path.exists(opt.outf): 279 | os.makedirs(opt.outf) 280 | 281 | if torch.cuda.is_available() and not opt.cuda: 282 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 283 | 284 | # save the opt and results to a txt file 285 | txt_save_path = os.path.join(opt.outf, 'opt_resutls.txt') 286 | F_txt = open(txt_save_path, 'a+') 287 | print(opt) 288 | print(opt, file=F_txt) 289 | 290 | 291 | 292 | # ========================================== Model Config =============================================== 293 | ngpu = int(opt.ngpu) 294 | global best_prec1, epoch_index 295 | best_prec1 = 0 296 | epoch_index = 0 297 | model = CovaNet.define_CovarianceNet(which_model=opt.basemodel, num_classes=opt.way_num, norm='batch', 298 | init_type='normal', use_gpu=opt.cuda) 299 | 300 | 301 | # define loss function (criterion) and optimizer 302 | criterion = nn.CrossEntropyLoss().cuda() 303 | optimizer = optim.Adam(model.parameters(), lr=opt.lr, betas=(opt.beta1, 0.9)) 304 | 305 | 306 | # optionally resume from a checkpoint 307 | if opt.resume: 308 | if os.path.isfile(opt.resume): 309 | print("=> loading checkpoint '{}'".format(opt.resume)) 310 | checkpoint = torch.load(opt.resume) 311 | epoch_index = checkpoint['epoch_index'] 312 | best_prec1 = checkpoint['best_prec1'] 313 | model.load_state_dict(checkpoint['state_dict']) 314 | optimizer.load_state_dict(checkpoint['optimizer']) 315 | print("=> loaded checkpoint '{}' (epoch {})".format(opt.resume, checkpoint['epoch_index'])) 316 | else: 317 | print("=> no checkpoint found at '{}'".format(opt.resume)) 318 | 319 | if opt.ngpu > 1: 320 | model = nn.DataParallel(model, range(opt.ngpu)) 321 | 322 | print(model) 323 | print(model, file=F_txt) # print the architecture of the network 324 | 325 | 326 | 327 | 328 | # ============================================ Training phase ======================================== 329 | print('\n............Start training............\n') 330 | start_time = time.time() 331 | 332 | 333 | for epoch_item in range(opt.epochs): 334 | print('===================================== Epoch %d =====================================' %epoch_item) 335 | print('===================================== Epoch %d =====================================' %epoch_item, file=F_txt) 336 | adjust_learning_rate(optimizer, epoch_item) 337 | 338 | 339 | # ======================================= Folder of Datasets ======================================= 340 | # image transform & normalization 341 | ImgTransform = transforms.Compose([ 342 | transforms.Resize((opt.imageSize, opt.imageSize)), 343 | transforms.ToTensor(), 344 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 345 | ]) 346 | 347 | trainset = Imagefolder_csv( 348 | data_dir=opt.dataset_dir, mode=opt.mode, image_size=opt.imageSize, transform=ImgTransform, 349 | episode_num=opt.episode_train_num, way_num=opt.way_num, shot_num=opt.shot_num, query_num=opt.query_num 350 | ) 351 | valset = Imagefolder_csv( 352 | data_dir=opt.dataset_dir, mode='val', image_size=opt.imageSize, transform=ImgTransform, 353 | episode_num=opt.episode_val_num, way_num=opt.way_num, shot_num=opt.shot_num, query_num=opt.query_num 354 | ) 355 | testset = Imagefolder_csv( 356 | data_dir=opt.dataset_dir, mode='test', image_size=opt.imageSize, transform=ImgTransform, 357 | episode_num=opt.episode_test_num, way_num=opt.way_num, shot_num=opt.shot_num, query_num=opt.query_num 358 | ) 359 | 360 
| print('Trainset: %d' %len(trainset)) 361 | print('Valset: %d' %len(valset)) 362 | print('Testset: %d' %len(testset)) 363 | print('Trainset: %d' %len(trainset), file=F_txt) 364 | print('Valset: %d' %len(valset), file=F_txt) 365 | print('Testset: %d' %len(testset), file=F_txt) 366 | 367 | 368 | 369 | # ========================================== Load Datasets ========================================= 370 | train_loader = torch.utils.data.DataLoader( 371 | trainset, batch_size=opt.episodeSize, shuffle=True, 372 | num_workers=int(opt.workers), drop_last=True, pin_memory=True 373 | ) 374 | val_loader = torch.utils.data.DataLoader( 375 | valset, batch_size=opt.testepisodeSize, shuffle=True, 376 | num_workers=int(opt.workers), drop_last=True, pin_memory=True 377 | ) 378 | test_loader = torch.utils.data.DataLoader( 379 | testset, batch_size=opt.testepisodeSize, shuffle=True, 380 | num_workers=int(opt.workers), drop_last=True, pin_memory=True 381 | ) 382 | 383 | 384 | # ============================================ Training =========================================== 385 | # Fix the parameters of Batch Normalization after 10000 episodes (1 epoch) 386 | if epoch_item < 1: 387 | model.train() 388 | else: 389 | model.eval() 390 | 391 | # Train for 10000 episodes in each epoch 392 | train(train_loader, model, criterion, optimizer, epoch_item, F_txt) 393 | 394 | 395 | # =========================================== Evaluation ========================================== 396 | print('============ Validation on the val set ============') 397 | print('============ validation on the val set ============', file=F_txt) 398 | prec1, _ = validate(val_loader, model, criterion, epoch_item, best_prec1, F_txt) 399 | 400 | 401 | # record the best prec@1 and save checkpoint 402 | is_best = prec1 > best_prec1 403 | best_prec1 = max(prec1, best_prec1) 404 | 405 | # save the checkpoint 406 | if is_best: 407 | save_checkpoint( 408 | { 409 | 'epoch_index': epoch_item, 410 | 'arch': opt.basemodel, 411 | 'state_dict': model.state_dict(), 412 | 'best_prec1': best_prec1, 413 | 'optimizer' : optimizer.state_dict(), 414 | }, os.path.join(opt.outf, 'model_best.pth.tar')) 415 | 416 | 417 | if epoch_item % 10 == 0: 418 | filename = os.path.join(opt.outf, 'epoch_%d.pth.tar' %epoch_item) 419 | save_checkpoint( 420 | { 421 | 'epoch_index': epoch_item, 422 | 'arch': opt.basemodel, 423 | 'state_dict': model.state_dict(), 424 | 'best_prec1': best_prec1, 425 | 'optimizer' : optimizer.state_dict(), 426 | }, filename) 427 | 428 | 429 | # Testing Prase 430 | print('============ Testing on the test set ============') 431 | print('============ Testing on the test set ============', file=F_txt) 432 | prec1, _ = validate(test_loader, model, criterion, epoch_item, best_prec1, F_txt) 433 | 434 | 435 | F_txt.close() 436 | print('............Training is end............') 437 | 438 | # ============================================ Training End ========================================== 439 | 440 | 441 | 442 | 443 | -------------------------------------------------------------------------------- /CovaMNet_Train_5way5shot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | Author: Wenbin Li (liwenbin.nju@gmail.com) 6 | Date: Jan. 
14, 2019 7 | Version: V0 8 | 9 | Citation: 10 | @inproceedings{li2019CovaMNet, 11 | title={Distribution Consistency based Covariance Metric Networks for Few-shot Learning}, 12 | author={Li, Wenbin and Xu, Jinglin and Huo, Jing and Wang, Lei and Gao Yang and Luo, Jiebo}, 13 | booktitle={AAAI}, 14 | year={2019} 15 | } 16 | """ 17 | 18 | 19 | 20 | from __future__ import print_function 21 | import argparse 22 | import os 23 | import random 24 | import shutil 25 | import numpy as np 26 | import torch 27 | import torch.nn as nn 28 | import torch.nn.parallel 29 | import torch.backends.cudnn as cudnn 30 | import torch.optim as optim 31 | import torch.utils.data 32 | import torchvision.datasets as dset 33 | import torchvision.transforms as transforms 34 | import torchvision.utils as vutils 35 | from torch.autograd import grad 36 | import time 37 | from torch import autograd 38 | from PIL import ImageFile 39 | import pdb 40 | import sys 41 | sys.dont_write_bytecode = True 42 | 43 | 44 | 45 | # ============================ Data & Networks ===================================== 46 | from dataset.datasets_csv import Imagefolder_csv 47 | import models.network as CovaNet 48 | # ================================================================================== 49 | 50 | 51 | ImageFile.LOAD_TRUNCATED_IMAGES = True 52 | os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID' 53 | os.environ['CUDA_VISIBLE_DEVICES']='0' 54 | 55 | 56 | parser = argparse.ArgumentParser() 57 | parser.add_argument('--dataset_dir', default='', help='/Datasets/miniImageNet/') 58 | parser.add_argument('--data_name', default='miniImageNet', help='miniImageNet|StanfordDog|StanfordCar|CubBird') 59 | parser.add_argument('--mode', default='train', help='train|val|test') 60 | parser.add_argument('--outf', default='./results/CovaMNet') 61 | parser.add_argument('--resume', default='', type=str, help='path to the lastest checkpoint (default: none)') 62 | parser.add_argument('--basemodel', default='Conv64', help='Conv64') 63 | parser.add_argument('--workers', type=int, default=8) 64 | # Few-shot parameters # 65 | parser.add_argument('--imageSize', type=int, default=84) 66 | parser.add_argument('--episodeSize', type=int, default=1, help='the mini-batch size of training') 67 | parser.add_argument('--testepisodeSize', type=int, default=1, help='one episode is taken as a mini-batch') 68 | parser.add_argument('--epochs', type=int, default=40, help='the total number of training epoch') 69 | parser.add_argument('--episode_train_num', type=int, default=10000, help='the total number of training episodes') 70 | parser.add_argument('--episode_val_num', type=int, default=10000, help='the total number of evaluation episodes') 71 | parser.add_argument('--episode_test_num', type=int, default=1000, help='the total number of testing episodes') 72 | parser.add_argument('--way_num', type=int, default=5, help='the number of way/class') 73 | parser.add_argument('--shot_num', type=int, default=5, help='the number of shot') 74 | parser.add_argument('--query_num', type=int, default=10, help='the number of queries') 75 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate, default=0.005') 76 | parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. 
default=0.5') 77 | parser.add_argument('--cuda', action='store_true', default=True, help='enables cuda') 78 | parser.add_argument('--ngpu', type=int, default=1, help='the number of gpus') 79 | parser.add_argument('--nc', type=int, default=3, help='input image channels') 80 | parser.add_argument('--clamp_lower', type=float, default=-0.01) 81 | parser.add_argument('--clamp_upper', type=float, default=0.01) 82 | parser.add_argument('--print_freq', '-p', default=100, type=int, metavar='N', help='print frequency (default: 100)') 83 | opt = parser.parse_args() 84 | opt.cuda = True 85 | cudnn.benchmark = True 86 | 87 | 88 | 89 | # ======================================= Define functions ============================================= 90 | def adjust_learning_rate(optimizer, epoch_num): 91 | """Sets the learning rate to the initial LR decayed by 0.05 every 10 epochs""" 92 | lr = opt.lr * (0.05 ** (epoch_num // 10)) 93 | for param_group in optimizer.param_groups: 94 | param_group['lr'] = lr 95 | 96 | 97 | def train(train_loader, model, criterion, optimizer, epoch_index, F_txt): 98 | batch_time = AverageMeter() 99 | data_time = AverageMeter() 100 | losses = AverageMeter() 101 | top1 = AverageMeter() 102 | 103 | 104 | end = time.time() 105 | for episode_index, (query_images, query_targets, support_images, support_targets) in enumerate(train_loader): 106 | 107 | # Measure data loading time 108 | data_time.update(time.time() - end) 109 | 110 | # Convert query and support images 111 | query_images = torch.cat(query_images, 0) 112 | input_var1 = query_images.cuda() 113 | 114 | input_var2 = [] 115 | for i in range(len(support_images)): 116 | temp_support = support_images[i] 117 | temp_support = torch.cat(temp_support, 0) 118 | temp_support = temp_support.cuda() 119 | input_var2.append(temp_support) 120 | 121 | # Deal with the targets 122 | target = torch.cat(query_targets, 0) 123 | target = target.cuda() 124 | 125 | # Calculate the output 126 | output = model(input_var1, input_var2) 127 | loss = criterion(output, target) 128 | 129 | # Compute gradients and do SGD step 130 | optimizer.zero_grad() 131 | loss.backward() 132 | optimizer.step() 133 | 134 | 135 | # Measure accuracy and record loss 136 | prec1, _ = accuracy(output, target, topk=(1,3)) 137 | losses.update(loss.item(), query_images.size(0)) 138 | top1.update(prec1[0], query_images.size(0)) 139 | 140 | 141 | # Measure elapsed time 142 | batch_time.update(time.time() - end) 143 | end = time.time() 144 | 145 | 146 | #============== print the intermediate results ==============# 147 | if episode_index % opt.print_freq == 0 and episode_index != 0: 148 | 149 | print('Eposide-({0}): [{1}/{2}]\t' 150 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 151 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 152 | 'Loss {loss.val:.3f} ({loss.avg:.3f})\t' 153 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 154 | epoch_index, episode_index, len(train_loader), batch_time=batch_time, data_time=data_time, loss=losses, top1=top1)) 155 | 156 | print('Eposide-({0}): [{1}/{2}]\t' 157 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 158 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 159 | 'Loss {loss.val:.3f} ({loss.avg:.3f})\t' 160 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 161 | epoch_index, episode_index, len(train_loader), batch_time=batch_time, data_time=data_time, loss=losses, top1=top1), file=F_txt) 162 | 163 | 164 | 165 | def validate(val_loader, model, criterion, epoch_index, best_prec1, F_txt): 166 | batch_time = AverageMeter() 167 
| losses = AverageMeter() 168 | top1 = AverageMeter() 169 | 170 | 171 | # switch to evaluate mode 172 | model.eval() 173 | accuracies = [] 174 | 175 | 176 | end = time.time() 177 | for episode_index, (query_images, query_targets, support_images, support_targets) in enumerate(val_loader): 178 | 179 | # Convert query and support images 180 | query_images = torch.cat(query_images, 0) 181 | input_var1 = query_images.cuda() 182 | 183 | 184 | input_var2 = [] 185 | for i in range(len(support_images)): 186 | temp_support = support_images[i] 187 | temp_support = torch.cat(temp_support, 0) 188 | temp_support = temp_support.cuda() 189 | input_var2.append(temp_support) 190 | 191 | 192 | # Deal with the targets 193 | target = torch.cat(query_targets, 0) 194 | target = target.cuda() 195 | 196 | # Calculate the output 197 | output = model(input_var1, input_var2) 198 | loss = criterion(output, target) 199 | 200 | 201 | # measure accuracy and record loss 202 | prec1, _ = accuracy(output, target, topk=(1, 3)) 203 | losses.update(loss.item(), query_images.size(0)) 204 | top1.update(prec1[0], query_images.size(0)) 205 | accuracies.append(prec1) 206 | 207 | 208 | # measure elapsed time 209 | batch_time.update(time.time() - end) 210 | end = time.time() 211 | 212 | 213 | #============== print the intermediate results ==============# 214 | if episode_index % opt.print_freq == 0 and episode_index != 0: 215 | 216 | print('Test-({0}): [{1}/{2}]\t' 217 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 218 | 'Loss {loss.val:.3f} ({loss.avg:.3f})\t' 219 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 220 | epoch_index, episode_index, len(val_loader), batch_time=batch_time, loss=losses, top1=top1)) 221 | 222 | print('Test-({0}): [{1}/{2}]\t' 223 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 224 | 'Loss {loss.val:.3f} ({loss.avg:.3f})\t' 225 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 226 | epoch_index, episode_index, len(val_loader), batch_time=batch_time, loss=losses, top1=top1), file=F_txt) 227 | 228 | 229 | print(' * Prec@1 {top1.avg:.3f} Best_prec1 {best_prec1:.3f}'.format(top1=top1, best_prec1=best_prec1)) 230 | print(' * Prec@1 {top1.avg:.3f} Best_prec1 {best_prec1:.3f}'.format(top1=top1, best_prec1=best_prec1), file=F_txt) 231 | 232 | return top1.avg, accuracies 233 | 234 | 235 | def save_checkpoint(state, filename='checkpoint.pth.tar'): 236 | torch.save(state, filename) 237 | 238 | 239 | 240 | class AverageMeter(object): 241 | """Computes and stores the average and current value""" 242 | def __init__(self): 243 | self.reset() 244 | 245 | def reset(self): 246 | self.val = 0 247 | self.avg = 0 248 | self.sum = 0 249 | self.count = 0 250 | 251 | def update(self, val, n=1): 252 | self.val = val 253 | self.sum += val * n 254 | self.count += n 255 | self.avg = self.sum / self.count 256 | 257 | 258 | def accuracy(output, target, topk=(1,)): 259 | """Computes the precision@k for the specified values of k""" 260 | with torch.no_grad(): 261 | maxk = max(topk) 262 | batch_size = target.size(0) 263 | 264 | _, pred = output.topk(maxk, 1, True, True) 265 | pred = pred.t() 266 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 267 | 268 | res = [] 269 | for k in topk: 270 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 271 | res.append(correct_k.mul_(100.0 / batch_size)) 272 | return res 273 | 274 | 275 | # ======================================== Settings of path ============================================ 276 | # saving path 277 | opt.outf = 
opt.outf+'_'+opt.data_name+'_'+str(opt.basemodel)+'_'+str(opt.way_num)+'Way_'+str(opt.shot_num)+'Shot' 278 | 279 | if not os.path.exists(opt.outf): 280 | os.makedirs(opt.outf) 281 | 282 | if torch.cuda.is_available() and not opt.cuda: 283 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 284 | 285 | # save the opt and results to a txt file 286 | txt_save_path = os.path.join(opt.outf, 'opt_resutls.txt') 287 | F_txt = open(txt_save_path, 'a+') 288 | print(opt) 289 | print(opt, file=F_txt) 290 | 291 | 292 | 293 | # ========================================== Model Config =============================================== 294 | ngpu = int(opt.ngpu) 295 | global best_prec1, epoch_index 296 | best_prec1 = 0 297 | epoch_index = 0 298 | model = CovaNet.define_CovarianceNet(which_model=opt.basemodel, num_classes=opt.way_num, norm='batch', 299 | init_type='normal', use_gpu=opt.cuda) 300 | 301 | 302 | # define loss function (criterion) and optimizer 303 | criterion = nn.CrossEntropyLoss().cuda() 304 | optimizer = optim.Adam(model.parameters(), lr=opt.lr, betas=(opt.beta1, 0.9)) 305 | 306 | 307 | # optionally resume from a checkpoint 308 | if opt.resume: 309 | if os.path.isfile(opt.resume): 310 | print("=> loading checkpoint '{}'".format(opt.resume)) 311 | checkpoint = torch.load(opt.resume) 312 | epoch_index = checkpoint['epoch_index'] 313 | best_prec1 = checkpoint['best_prec1'] 314 | model.load_state_dict(checkpoint['state_dict']) 315 | optimizer.load_state_dict(checkpoint['optimizer']) 316 | print("=> loaded checkpoint '{}' (epoch {})".format(opt.resume, checkpoint['epoch_index'])) 317 | else: 318 | print("=> no checkpoint found at '{}'".format(opt.resume)) 319 | 320 | if opt.ngpu > 1: 321 | model = nn.DataParallel(model, range(opt.ngpu)) 322 | 323 | print(model) 324 | print(model, file=F_txt) # print the architecture of the network 325 | 326 | 327 | 328 | 329 | # ============================================ Training phase ======================================== 330 | print('\n............Start training............\n') 331 | start_time = time.time() 332 | 333 | 334 | for epoch_item in range(opt.epochs): 335 | print('===================================== Epoch %d =====================================' %epoch_item) 336 | print('===================================== Epoch %d =====================================' %epoch_item, file=F_txt) 337 | adjust_learning_rate(optimizer, epoch_item) 338 | 339 | 340 | # ======================================= Folder of Datasets ======================================= 341 | # image transform & normalization 342 | ImgTransform = transforms.Compose([ 343 | transforms.Resize((opt.imageSize, opt.imageSize)), 344 | transforms.ToTensor(), 345 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 346 | ]) 347 | 348 | trainset = Imagefolder_csv( 349 | data_dir=opt.dataset_dir, mode=opt.mode, image_size=opt.imageSize, transform=ImgTransform, 350 | episode_num=opt.episode_train_num, way_num=opt.way_num, shot_num=opt.shot_num, query_num=opt.query_num 351 | ) 352 | valset = Imagefolder_csv( 353 | data_dir=opt.dataset_dir, mode='val', image_size=opt.imageSize, transform=ImgTransform, 354 | episode_num=opt.episode_val_num, way_num=opt.way_num, shot_num=opt.shot_num, query_num=opt.query_num 355 | ) 356 | testset = Imagefolder_csv( 357 | data_dir=opt.dataset_dir, mode='test', image_size=opt.imageSize, transform=ImgTransform, 358 | episode_num=opt.episode_test_num, way_num=opt.way_num, shot_num=opt.shot_num, query_num=opt.query_num 359 | ) 360 | 361 
| print('Trainset: %d' %len(trainset)) 362 | print('Valset: %d' %len(valset)) 363 | print('Testset: %d' %len(testset)) 364 | print('Trainset: %d' %len(trainset), file=F_txt) 365 | print('Valset: %d' %len(valset), file=F_txt) 366 | print('Testset: %d' %len(testset), file=F_txt) 367 | 368 | 369 | 370 | # ========================================== Load Datasets ========================================= 371 | train_loader = torch.utils.data.DataLoader( 372 | trainset, batch_size=opt.episodeSize, shuffle=True, 373 | num_workers=int(opt.workers), drop_last=True, pin_memory=True 374 | ) 375 | val_loader = torch.utils.data.DataLoader( 376 | valset, batch_size=opt.testepisodeSize, shuffle=True, 377 | num_workers=int(opt.workers), drop_last=True, pin_memory=True 378 | ) 379 | test_loader = torch.utils.data.DataLoader( 380 | testset, batch_size=opt.testepisodeSize, shuffle=True, 381 | num_workers=int(opt.workers), drop_last=True, pin_memory=True 382 | ) 383 | 384 | 385 | # ============================================ Training =========================================== 386 | # Fix the parameters of Batch Normalization after 10000 episodes (1 epoch) 387 | if epoch_item < 1: 388 | model.train() 389 | else: 390 | model.eval() 391 | 392 | # Train for 10000 episodes in each epoch 393 | train(train_loader, model, criterion, optimizer, epoch_item, F_txt) 394 | 395 | 396 | # =========================================== Evaluation ========================================== 397 | print('============ Validation on the val set ============') 398 | print('============ Validation on the val set ============', file=F_txt) 399 | prec1, _ = validate(val_loader, model, criterion, epoch_item, best_prec1, F_txt) 400 | 401 | 402 | # record the best prec@1 and save the checkpoint 403 | is_best = prec1 > best_prec1 404 | best_prec1 = max(prec1, best_prec1) 405 | 406 | # save the best checkpoint 407 | if is_best: 408 | save_checkpoint( 409 | { 410 | 'epoch_index': epoch_item, 411 | 'arch': opt.basemodel, 412 | 'state_dict': model.state_dict(), 413 | 'best_prec1': best_prec1, 414 | 'optimizer' : optimizer.state_dict(), 415 | }, os.path.join(opt.outf, 'model_best.pth.tar')) 416 | 417 | 418 | if epoch_item % 10 == 0: 419 | filename = os.path.join(opt.outf, 'epoch_%d.pth.tar' %epoch_item) 420 | save_checkpoint( 421 | { 422 | 'epoch_index': epoch_item, 423 | 'arch': opt.basemodel, 424 | 'state_dict': model.state_dict(), 425 | 'best_prec1': best_prec1, 426 | 'optimizer' : optimizer.state_dict(), 427 | }, filename) 428 | 429 | 430 | # Testing Phase 431 | print('============ Testing on the test set ============') 432 | print('============ Testing on the test set ============', file=F_txt) 433 | prec1, _ = validate(test_loader, model, criterion, epoch_item, best_prec1, F_txt) 434 | 435 | 436 | F_txt.close() 437 | print('............Training is finished............') 438 | 439 | # ============================================ Training End ========================================== 440 | 441 | 442 | 443 | 444 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019, Wenbin Li 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer.
9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | 25 | 26 | --------------------------- LICENSE FOR CovaMNet -------------------------------- 27 | BSD License 28 | 29 | For CovaMNet software 30 | Copyright (c) 2019, Wenbin Li 31 | All rights reserved. 32 | 33 | Redistribution and use in source and binary forms, with or without 34 | modification, are permitted provided that the following conditions are met: 35 | 36 | * Redistributions of source code must retain the above copyright notice, this 37 | list of conditions and the following disclaimer. 38 | 39 | * Redistributions in binary form must reproduce the above copyright notice, 40 | this list of conditions and the following disclaimer in the documentation 41 | and/or other materials provided with the distribution. 42 | 43 | 44 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CovaMNet in PyTorch 2 | 3 | We provide a PyTorch implementation of CovaMNet for few-shot learning. The code was written by [Wenbin Li](https://github.com/WenbinLee) [Homepage].
4 | 5 | If you use this code for your research, please cite: 6 | 7 | [Distribution Consistency based Covariance Metric Networks for Few-shot Learning](https://cs.nju.edu.cn/rl/people/liwb/AAAI19.pdf).
8 | [Wenbin Li](https://cs.nju.edu.cn/liwenbin/), Jinglin Xu, Jing Huo, Lei Wang, Yang Gao and Jiebo Luo. In AAAI 2019.
9 | 10 | 11 | 12 | ## Prerequisites 13 | - Linux 14 | - Python 3 15 | - PyTorch 0.4 16 | - GPU + CUDA + cuDNN 17 | 18 | ## Getting Started 19 | ### Installation 20 | 21 | - Clone this repo: 22 | ```bash 23 | git clone https://github.com/WenbinLee/CovaMNet 24 | cd CovaMNet 25 | ``` 26 | 27 | - Install [PyTorch](http://pytorch.org) 0.4 and other dependencies (e.g., torchvision). 28 | 29 | ### Datasets 30 | - [miniImageNet](https://drive.google.com/file/d/1fUBrpv8iutYwdL4xE1rX_R9ef6tyncX9/view). 31 | - [StanfordDog](http://vision.stanford.edu/aditya86/ImageNetDogs/). 32 | - [StanfordCar](https://ai.stanford.edu/~jkrause/cars/car_dataset.html). 33 | - [CUB-200](http://www.vision.caltech.edu/visipedia/CUB-200.html).
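Whichever dataset you choose, the loader in `dataset/datasets_csv.py` (and the output of the `*_prepare_csv.py` scripts) expects the same CSV-indexed layout, sketched below. The top-level folder name is an assumption; what is required is a flat `images` folder plus three CSV files whose first row is the `filename,label` header:

```
miniImageNet/            # passed as --dataset_dir
├── images/              # all images of all classes in one flat folder
├── train.csv            # header row: filename,label
├── val.csv
└── test.csv
```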
34 | Thanks [Victor Garcia](https://github.com/vgsatorras/few-shot-gnn) for providing the miniImageNet dataset. In our paper, we used the original CUB-200 dataset; note that a newer revision with more images is available as [Caltech-UCSD Birds-200-2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html). If you use these datasets, please cite the corresponding papers. 35 | 36 | 37 | ### miniImageNet Few-shot Classification 38 | - Train a 5-way 1-shot model: 39 | ```bash 40 | python CovaMNet_Train_5way1shot.py --dataset_dir ./datasets/miniImageNet --data_name miniImageNet 41 | ``` 42 | - Test the model (specify the dataset_dir and data_name first): 43 | ```bash 44 | python CovaMNet_Test_5way1shot.py --resume ./results/CovaMNet_miniImageNet_Conv64_5_Way_1_Shot/model_best.pth.tar 45 | ``` 46 | - The results on the miniImageNet dataset (see imgs/results_miniImageNet.bmp): 47 | 48 | 49 | 50 | ### Fine-grained Few-shot Classification 51 | - Data preprocessing (e.g., StanfordDog) 52 | - Specify the path of the dataset and the saving path. 53 | - Run the preprocessing script. 54 | ```bash 55 | #!./dataset/StanfordDog/StanfordDog_prepare_csv.py 56 | python ./dataset/StanfordDog/StanfordDog_prepare_csv.py 57 | ``` 58 | - Train a 5-way 1-shot model: 59 | ```bash 60 | python CovaMNet_Train_5way1shot.py --dataset_dir ./datasets/StanfordDog --data_name StanfordDog 61 | ``` 62 | - Test the model (specify the dataset_dir and data_name first): 63 | ```bash 64 | python CovaMNet_Test_5way1shot.py --resume ./results/CovaMNet_StanfordDog_Conv64_5_Way_1_Shot/model_best.pth.tar 65 | ``` 66 | - The results on the fine-grained datasets (see imgs/result_finegrained.bmp): 67 | 68 | 69 | 70 | 71 | ## Citation 72 | If you use this code for your research, please cite our paper. 73 | ``` 74 | @inproceedings{li2019CovaMNet, 75 | title={Distribution Consistency based Covariance Metric Networks for Few-shot Learning}, 76 | author={Li, Wenbin and Xu, Jinglin and Huo, Jing and Wang, Lei and Gao, Yang and Luo, Jiebo}, 77 | booktitle={AAAI}, 78 | year={2019} 79 | } 80 | 81 | ``` 82 | 83 | -------------------------------------------------------------------------------- /dataset/CubBird/CubBird_prepare_csv.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Wenbin Li 3 | ## Date: Dec.
16 2018 4 | ## 5 | ## Divide data into train/val/test in a csv version 6 | ## Output: train.csv, val.csv, test.csv 7 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 8 | 9 | import os 10 | import csv 11 | import numpy as np 12 | import random 13 | from PIL import Image 14 | import pdb 15 | 16 | 17 | data_dir = '/FewShot/Datasets/CUB_birds' # the path of the download dataset 18 | save_dir = '/FewShot/Datasets/CUB_birds/For_FewShot' # the saving path of the divided dataset 19 | 20 | 21 | if not os.path.exists(os.path.join(save_dir, 'images')): 22 | os.makedirs(os.path.join(save_dir, 'images')) 23 | 24 | images_dir = os.path.join(data_dir, 'images') 25 | train_class_num = 130 26 | val_class_num = 20 27 | test_class_num = 50 28 | 29 | 30 | 31 | # get all the dog classes 32 | classes_list = [class_name for class_name in os.listdir(images_dir) if os.path.isdir(os.path.join(images_dir, class_name))] 33 | 34 | 35 | # divide the train/val/test set 36 | random.seed(200) 37 | train_list = random.sample(classes_list, train_class_num) 38 | remain_list = [rem for rem in classes_list if rem not in train_list] 39 | val_list = random.sample(remain_list, val_class_num) 40 | test_list = [rem for rem in remain_list if rem not in val_list] 41 | 42 | 43 | # save data into csv file----- Train 44 | train_data = [] 45 | for class_name in train_list: 46 | images = [[i, class_name] for i in os.listdir(os.path.join(images_dir, class_name))] 47 | train_data.extend(images) 48 | print('Train----%s' %class_name) 49 | 50 | # read images and store these images 51 | img_paths = [os.path.join(images_dir, class_name, i) for i in os.listdir(os.path.join(images_dir, class_name))] 52 | for index, img_path in enumerate(img_paths): 53 | img = Image.open(img_path) 54 | img = img.convert('RGB') 55 | img.save(os.path.join(save_dir, 'images', images[index][0]), quality=100) 56 | 57 | 58 | with open(os.path.join(save_dir, 'train.csv'), 'w') as csvfile: 59 | writer = csv.writer(csvfile) 60 | 61 | writer.writerow(['filename', 'label']) 62 | writer.writerows(train_data) 63 | 64 | 65 | 66 | 67 | # save data into csv file----- Val 68 | val_data = [] 69 | for class_name in val_list: 70 | images = [[i, class_name] for i in os.listdir(os.path.join(images_dir, class_name))] 71 | val_data.extend(images) 72 | print('Val----%s' %class_name) 73 | 74 | # read images and store these images 75 | img_paths = [os.path.join(images_dir, class_name, i) for i in os.listdir(os.path.join(images_dir, class_name))] 76 | for index, img_path in enumerate(img_paths): 77 | img = Image.open(img_path) 78 | img = img.convert('RGB') 79 | img.save(os.path.join(save_dir, 'images', images[index][0]), quality=100) 80 | 81 | with open(os.path.join(save_dir, 'val.csv'), 'w') as csvfile: 82 | writer = csv.writer(csvfile) 83 | 84 | writer.writerow(['filename', 'label']) 85 | writer.writerows(val_data) 86 | 87 | 88 | 89 | 90 | # save data into csv file----- Test 91 | test_data = [] 92 | for class_name in test_list: 93 | images = [[i, class_name] for i in os.listdir(os.path.join(images_dir, class_name))] 94 | test_data.extend(images) 95 | print('Test----%s' %class_name) 96 | 97 | # read images and store these images 98 | img_paths = [os.path.join(images_dir, class_name, i) for i in os.listdir(os.path.join(images_dir, class_name))] 99 | for index, img_path in enumerate(img_paths): 100 | img = Image.open(img_path) 101 | img = img.convert('RGB') 102 | img.save(os.path.join(save_dir, 'images', images[index][0]), quality=100) 103 | 104 | 105 | 
with open(os.path.join(save_dir, 'test.csv'), 'w') as csvfile: 106 | writer = csv.writer(csvfile) 107 | 108 | writer.writerow(['filename', 'label']) 109 | writer.writerows(test_data) 110 | -------------------------------------------------------------------------------- /dataset/StanfordCar/StanforCar_prepare_csv.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Wenbin Li 3 | ## Date: Dec. 16 2018 4 | ## 5 | ## Divide data into train/val/test in a csv version 6 | ## Output: train.csv, val.csv, test.csv 7 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 8 | 9 | import os 10 | import csv 11 | import numpy as np 12 | import random 13 | from PIL import Image 14 | import pdb 15 | 16 | 17 | data_dir = '/FewShot/Datasets/Stanford_cars' # the path of the download dataset 18 | save_dir = '/FewShot/Datasets/Stanford_cars/For_FewShot' # the saving path of the divided dataset 19 | 20 | 21 | if not os.path.exists(os.path.join(save_dir, 'images')): 22 | os.makedirs(os.path.join(save_dir, 'images')) 23 | 24 | images_dir = os.path.join(data_dir, 'images') 25 | train_class_num = 130 26 | val_class_num = 17 27 | test_class_num = 49 28 | 29 | 30 | 31 | # get all the dog classes 32 | classes_list = [class_name for class_name in os.listdir(images_dir) if os.path.isdir(os.path.join(images_dir, class_name))] 33 | 34 | 35 | # divide the train/val/test set 36 | random.seed(196) 37 | train_list = random.sample(classes_list, train_class_num) 38 | remain_list = [rem for rem in classes_list if rem not in train_list] 39 | val_list = random.sample(remain_list, val_class_num) 40 | test_list = [rem for rem in remain_list if rem not in val_list] 41 | 42 | 43 | # save data into csv file----- Train 44 | train_data = [] 45 | for class_name in train_list: 46 | images = [[i, class_name] for i in os.listdir(os.path.join(images_dir, class_name))] 47 | train_data.extend(images) 48 | print('Train----%s' %class_name) 49 | 50 | # read images and store these images 51 | img_paths = [os.path.join(images_dir, class_name, i) for i in os.listdir(os.path.join(images_dir, class_name))] 52 | for index, img_path in enumerate(img_paths): 53 | img = Image.open(img_path) 54 | img = img.convert('RGB') 55 | img.save(os.path.join(save_dir, 'images', images[index][0]), quality=100) 56 | 57 | 58 | with open(os.path.join(save_dir, 'train.csv'), 'w') as csvfile: 59 | writer = csv.writer(csvfile) 60 | 61 | writer.writerow(['filename', 'label']) 62 | writer.writerows(train_data) 63 | 64 | 65 | 66 | 67 | # save data into csv file----- Val 68 | val_data = [] 69 | for class_name in val_list: 70 | images = [[i, class_name] for i in os.listdir(os.path.join(images_dir, class_name))] 71 | val_data.extend(images) 72 | print('Val----%s' %class_name) 73 | 74 | # read images and store these images 75 | img_paths = [os.path.join(images_dir, class_name, i) for i in os.listdir(os.path.join(images_dir, class_name))] 76 | for index, img_path in enumerate(img_paths): 77 | img = Image.open(img_path) 78 | img = img.convert('RGB') 79 | img.save(os.path.join(save_dir, 'images', images[index][0]), quality=100) 80 | 81 | with open(os.path.join(save_dir, 'val.csv'), 'w') as csvfile: 82 | writer = csv.writer(csvfile) 83 | 84 | writer.writerow(['filename', 'label']) 85 | writer.writerows(val_data) 86 | 87 | 88 | 89 | 90 | # save data into csv file----- Test 91 | test_data = [] 92 | for class_name in 
test_list: 93 | images = [[i, class_name] for i in os.listdir(os.path.join(images_dir, class_name))] 94 | test_data.extend(images) 95 | print('Test----%s' %class_name) 96 | 97 | # read images and store these images 98 | img_paths = [os.path.join(images_dir, class_name, i) for i in os.listdir(os.path.join(images_dir, class_name))] 99 | for index, img_path in enumerate(img_paths): 100 | img = Image.open(img_path) 101 | img = img.convert('RGB') 102 | img.save(os.path.join(save_dir, 'images', images[index][0]), quality=100) 103 | 104 | 105 | with open(os.path.join(save_dir, 'test.csv'), 'w') as csvfile: 106 | writer = csv.writer(csvfile) 107 | 108 | writer.writerow(['filename', 'label']) 109 | writer.writerows(test_data) 110 | -------------------------------------------------------------------------------- /dataset/StanfordDog/StanfordDog_prepare_csv.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Wenbin Li 3 | ## Date: Dec. 16 2018 4 | ## 5 | ## Divide data into train/val/test in a csv version 6 | ## Output: train.csv, val.csv, test.csv 7 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 8 | 9 | import os 10 | import csv 11 | import numpy as np 12 | import random 13 | from PIL import Image 14 | import pdb 15 | 16 | 17 | data_dir = '/FewShot/Datasets/Stanford_dogs' # the path of the download dataset 18 | save_dir = '/FewShot/Datasets/Stanford_dogs/For_FewShot' # the saving path of the divided dataset 19 | 20 | 21 | if not os.path.exists(os.path.join(save_dir, 'images')): 22 | os.makedirs(os.path.join(save_dir, 'images')) 23 | 24 | images_dir = os.path.join(data_dir, 'Images') 25 | train_class_num = 70 26 | val_class_num = 20 27 | test_class_num = 30 28 | 29 | 30 | 31 | # get all the dog classes 32 | classes_list = [class_name for class_name in os.listdir(images_dir) if os.path.isdir(os.path.join(images_dir, class_name))] 33 | 34 | 35 | # divide the train/val/test set 36 | random.seed(120) 37 | train_list = random.sample(classes_list, train_class_num) 38 | remain_list = [rem for rem in classes_list if rem not in train_list] 39 | val_list = random.sample(remain_list, val_class_num) 40 | test_list = [rem for rem in remain_list if rem not in val_list] 41 | 42 | 43 | # save data into csv file----- Train 44 | train_data = [] 45 | for class_name in train_list: 46 | images = [[i, class_name] for i in os.listdir(os.path.join(images_dir, class_name))] 47 | train_data.extend(images) 48 | print('Train----%s' %class_name) 49 | 50 | # read images and store these images 51 | img_paths = [os.path.join(images_dir, class_name, i) for i in os.listdir(os.path.join(images_dir, class_name))] 52 | for index, img_path in enumerate(img_paths): 53 | img = Image.open(img_path) 54 | img = img.convert('RGB') 55 | img.save(os.path.join(save_dir, 'images', images[index][0]), quality=100) 56 | 57 | 58 | with open(os.path.join(save_dir, 'train.csv'), 'w') as csvfile: 59 | writer = csv.writer(csvfile) 60 | 61 | writer.writerow(['filename', 'label']) 62 | writer.writerows(train_data) 63 | 64 | 65 | 66 | 67 | # save data into csv file----- Val 68 | val_data = [] 69 | for class_name in val_list: 70 | images = [[i, class_name] for i in os.listdir(os.path.join(images_dir, class_name))] 71 | val_data.extend(images) 72 | print('Val----%s' %class_name) 73 | 74 | # read images and store these images 75 | img_paths = [os.path.join(images_dir, class_name, i) 
for i in os.listdir(os.path.join(images_dir, class_name))] 76 | for index, img_path in enumerate(img_paths): 77 | img = Image.open(img_path) 78 | img = img.convert('RGB') 79 | img.save(os.path.join(save_dir, 'images', images[index][0]), quality=100) 80 | 81 | with open(os.path.join(save_dir, 'val.csv'), 'w') as csvfile: 82 | writer = csv.writer(csvfile) 83 | 84 | writer.writerow(['filename', 'label']) 85 | writer.writerows(val_data) 86 | 87 | 88 | 89 | 90 | # save data into csv file----- Test 91 | test_data = [] 92 | for class_name in test_list: 93 | images = [[i, class_name] for i in os.listdir(os.path.join(images_dir, class_name))] 94 | test_data.extend(images) 95 | print('Test----%s' %class_name) 96 | 97 | # read images and store these images 98 | img_paths = [os.path.join(images_dir, class_name, i) for i in os.listdir(os.path.join(images_dir, class_name))] 99 | for index, img_path in enumerate(img_paths): 100 | img = Image.open(img_path) 101 | img = img.convert('RGB') 102 | img.save(os.path.join(save_dir, 'images', images[index][0]), quality=100) 103 | 104 | 105 | with open(os.path.join(save_dir, 'test.csv'), 'w') as csvfile: 106 | writer = csv.writer(csvfile) 107 | 108 | writer.writerow(['filename', 'label']) 109 | writer.writerows(test_data) 110 | -------------------------------------------------------------------------------- /dataset/datasets_csv.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as path 3 | import json 4 | import torch 5 | import torch.utils.data as data 6 | import numpy as np 7 | import random 8 | from PIL import Image 9 | import pdb 10 | import csv 11 | import sys 12 | sys.dont_write_bytecode = True 13 | 14 | 15 | 16 | def pil_loader(path): 17 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 18 | with open(path, 'rb') as f: 19 | with Image.open(f) as img: 20 | return img.convert('RGB') 21 | 22 | 23 | def accimage_loader(path): 24 | import accimage 25 | try: 26 | return accimage.Image(path) 27 | except IOError: 28 | # Potentially a decoding problem, fall back to PIL.Image 29 | return pil_loader(path) 30 | 31 | 32 | def gray_loader(path): 33 | with open(path, 'rb') as f: 34 | with Image.open(f) as img: 35 | return img.convert('P') 36 | 37 | 38 | def default_loader(path): 39 | from torchvision import get_image_backend 40 | if get_image_backend() == 'accimage': 41 | return accimage_loader(path) 42 | else: 43 | return pil_loader(path) 44 | 45 | 46 | def find_classes(dir): 47 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 48 | classes.sort() 49 | class_to_idx = {classes[i]: i for i in range(len(classes))} 50 | 51 | return classes, class_to_idx 52 | 53 | 54 | class Imagefolder_csv(object): 55 | """ 56 | Imagefolder for miniImageNet--ravi, StanfordDog, StanfordCar and CubBird datasets. 57 | Images are stored in the folder of "images"; 58 | Indexes are stored in the CSV files. 
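Each episode samples way_num classes; for every class, shot_num support images are drawn and up to query_num of the remaining images become queries. An episode is stored as a list of way_num dicts with the keys "query_img", "support_set" and "target" (see the episode construction code below).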
59 | """ 60 | 61 | def __init__(self, data_dir="", mode="train", image_size=84, data_name="miniImageNet", 62 | transform=None, loader=default_loader, gray_loader=gray_loader, 63 | episode_num=1000, way_num=5, shot_num=5, query_num=5): 64 | 65 | super(Imagefolder_csv, self).__init__() 66 | 67 | 68 | # set the paths of the csv files 69 | train_csv = os.path.join(data_dir, 'train.csv') 70 | val_csv = os.path.join(data_dir, 'val.csv') 71 | test_csv = os.path.join(data_dir, 'test.csv') 72 | 73 | 74 | data_list = [] 75 | e = 0 76 | if mode == "train": 77 | 78 | # store all the classes and images into a dict 79 | class_img_dict = {} 80 | with open(train_csv) as f_csv: 81 | f_train = csv.reader(f_csv, delimiter=',') 82 | for row in f_train: 83 | if f_train.line_num == 1: 84 | continue 85 | img_name, img_class = row 86 | 87 | if img_class in class_img_dict: 88 | class_img_dict[img_class].append(img_name) 89 | else: 90 | class_img_dict[img_class]=[] 91 | class_img_dict[img_class].append(img_name) 92 | f_csv.close() 93 | class_list = class_img_dict.keys() 94 | 95 | 96 | while e < episode_num: 97 | 98 | # construct each episode 99 | episode = [] 100 | e += 1 101 | temp_list = random.sample(class_list, way_num) 102 | label_num = -1 103 | 104 | for item in temp_list: 105 | label_num += 1 106 | imgs_set = class_img_dict[item] 107 | support_imgs = random.sample(imgs_set, shot_num) 108 | query_imgs = [val for val in imgs_set if val not in support_imgs] 109 | 110 | if query_num < len(query_imgs): 111 | query_imgs = random.sample(query_imgs, query_num) 112 | 113 | 114 | # the dir of support set 115 | query_dir = [path.join(data_dir, 'images', i) for i in query_imgs] 116 | support_dir = [path.join(data_dir, 'images', i) for i in support_imgs] 117 | 118 | 119 | data_files = { 120 | "query_img": query_dir, 121 | "support_set": support_dir, 122 | "target": label_num 123 | } 124 | episode.append(data_files) 125 | data_list.append(episode) 126 | 127 | 128 | elif mode == "val": 129 | 130 | # store all the classes and images into a dict 131 | class_img_dict = {} 132 | with open(val_csv) as f_csv: 133 | f_val = csv.reader(f_csv, delimiter=',') 134 | for row in f_val: 135 | if f_val.line_num == 1: 136 | continue 137 | img_name, img_class = row 138 | 139 | if img_class in class_img_dict: 140 | class_img_dict[img_class].append(img_name) 141 | else: 142 | class_img_dict[img_class]=[] 143 | class_img_dict[img_class].append(img_name) 144 | f_csv.close() 145 | class_list = class_img_dict.keys() 146 | 147 | 148 | 149 | while e < episode_num: # setting the episode number to 600 150 | 151 | # construct each episode 152 | episode = [] 153 | e += 1 154 | temp_list = random.sample(class_list, way_num) 155 | label_num = -1 156 | 157 | for item in temp_list: 158 | label_num += 1 159 | imgs_set = class_img_dict[item] 160 | support_imgs = random.sample(imgs_set, shot_num) 161 | query_imgs = [val for val in imgs_set if val not in support_imgs] 162 | 163 | if query_num Covariance metric layer --> Classification layer 134 | # Dataset: 84 x 84 x 3, for miniImageNet, StanfordDog, StanfordCar, CubBird 135 | # Filters: 64->64->64->64 136 | # Mapping Sizes: 84->42->21->21->21 137 | 138 | 139 | class CovarianceNet_64(nn.Module): 140 | def __init__(self, norm_layer=nn.BatchNorm2d, num_classes=5): 141 | super(CovarianceNet_64, self).__init__() 142 | 143 | if type(norm_layer) == functools.partial: 144 | use_bias = norm_layer.func == nn.InstanceNorm2d 145 | else: 146 | use_bias = norm_layer == nn.InstanceNorm2d 147 | 148 | self.features = 
nn.Sequential( # 3*84*84 149 | nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=use_bias), 150 | norm_layer(64), 151 | nn.LeakyReLU(0.2, True), 152 | nn.MaxPool2d(kernel_size=2, stride=2), # 64*42*42 153 | 154 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias), 155 | norm_layer(64), 156 | nn.LeakyReLU(0.2, True), 157 | nn.MaxPool2d(kernel_size=2, stride=2), # 64*21*21 158 | 159 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias), 160 | norm_layer(64), 161 | nn.LeakyReLU(0.2, True), # 64*21*21 162 | 163 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias), 164 | norm_layer(64), 165 | nn.LeakyReLU(0.2, True), # 64*21*21 166 | ) 167 | 168 | self.covariance = CovaBlock() # 1*(441*num_classes) 169 | 170 | self.classifier = nn.Sequential( 171 | nn.LeakyReLU(0.2, True), 172 | nn.Dropout(), 173 | nn.Conv1d(1, 1, kernel_size=441, stride=441, bias=use_bias), 174 | ) 175 | 176 | 177 | def forward(self, input1, input2): 178 | 179 | # extract features of input1--query image 180 | q = self.features(input1) 181 | 182 | # extract features of input2--support set 183 | S = [] 184 | for i in range(len(input2)): 185 | S.append(self.features(input2[i])) 186 | 187 | x = self.covariance(q, S) # get Batch*1*(h*w*num_classes) 188 | x = self.classifier(x) # get Batch*1*num_classes 189 | x = x.squeeze(1) # get Batch*num_classes 190 | 191 | return x 192 | 193 | 194 | 195 | #========================== Define a Covariance Metric layer ==========================# 196 | # Calculate the local covariance matrix of each category in the support set 197 | # Calculate the Covariance Metric between a query sample and a category 198 | 199 | 200 | class CovaBlock(nn.Module): 201 | def __init__(self): 202 | super(CovaBlock, self).__init__() 203 | 204 | 205 | # calculate the covariance matrix 206 | def cal_covariance(self, input): 207 | 208 | CovaMatrix_list = [] 209 | for i in range(len(input)): 210 | support_set_sam = input[i] 211 | B, C, h, w = support_set_sam.size() 212 | 213 | support_set_sam = support_set_sam.permute(1, 0, 2, 3) 214 | support_set_sam = support_set_sam.contiguous().view(C, -1) 215 | mean_support = torch.mean(support_set_sam, 1, True) 216 | support_set_sam = support_set_sam-mean_support 217 | 218 | covariance_matrix = support_set_sam@torch.transpose(support_set_sam, 0, 1) 219 | covariance_matrix = torch.div(covariance_matrix, h*w*B-1) 220 | CovaMatrix_list.append(covariance_matrix) 221 | 222 | return CovaMatrix_list 223 | 224 | 225 | # calculate the similarity 226 | def cal_similarity(self, input, CovaMatrix_list): 227 | 228 | B, C, h, w = input.size() 229 | Cova_Sim = [] 230 | 231 | for i in range(B): 232 | query_sam = input[i] 233 | query_sam = query_sam.view(C, -1) 234 | query_sam_norm = torch.norm(query_sam, 2, 1, True) 235 | query_sam = query_sam/query_sam_norm 236 | 237 | if torch.cuda.is_available(): 238 | mea_sim = torch.zeros(1, len(CovaMatrix_list)*h*w).cuda() 239 | 240 | for j in range(len(CovaMatrix_list)): 241 | temp_dis = torch.transpose(query_sam, 0, 1)@CovaMatrix_list[j]@query_sam 242 | mea_sim[0, j*h*w:(j+1)*h*w] = temp_dis.diag() 243 | 244 | Cova_Sim.append(mea_sim.unsqueeze(0)) 245 | 246 | Cova_Sim = torch.cat(Cova_Sim, 0) # get Batch*1*(h*w*num_classes) 247 | return Cova_Sim 248 | 249 | 250 | def forward(self, x1, x2): 251 | 252 | CovaMatrix_list = self.cal_covariance(x2) 253 | Cova_Sim = self.cal_similarity(x1, CovaMatrix_list) 254 | 255 | return Cova_Sim 256 | 
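The classifier above is a Conv1d with kernel_size=441 and stride=441, so it reduces each class's 441 (= 21*21) local similarities to a single score per class. Below is a minimal shape-check sketch, assuming the define_CovarianceNet factory used by the training script places the network on the GPU when use_gpu=True; a CUDA device is also assumed, since cal_similarity allocates mea_sim with .cuda() and would fail on a CPU-only machine:

```python
# Hypothetical smoke test for a 5-way 1-shot episode with 15 queries per class.
import torch
import models.network as CovaNet

model = CovaNet.define_CovarianceNet(which_model='Conv64', num_classes=5,
                                     norm='batch', init_type='normal', use_gpu=True)
model.eval()  # disable Dropout in the classifier

queries = torch.randn(75, 3, 84, 84).cuda()                      # 5 ways x 15 queries
support = [torch.randn(1, 3, 84, 84).cuda() for _ in range(5)]   # one tensor per way

with torch.no_grad():
    scores = model(queries, support)   # CovaBlock -> (75, 1, 2205), classifier -> (75, 1, 5)
print(scores.size())                   # expected: torch.Size([75, 5])
```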
-------------------------------------------------------------------------------- /results/CovaMNet_miniImageNet_Conv64_5_Way_1_Shot/model_best.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenbinLee/CovaMNet/d65d0bcc0f26bc8d742d75fe3387f89603c89185/results/CovaMNet_miniImageNet_Conv64_5_Way_1_Shot/model_best.pth.tar --------------------------------------------------------------------------------
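For completeness, here is a hedged sketch of restoring this released checkpoint outside the provided test script, mirroring the resume logic shown in the training code. The keys 'epoch_index', 'best_prec1' and 'state_dict' come from the save_checkpoint dict above; the path is the one in this repository's results folder:

```python
# Minimal sketch: restore the released 5-way 1-shot miniImageNet weights.
import torch
import models.network as CovaNet

ckpt_path = './results/CovaMNet_miniImageNet_Conv64_5_Way_1_Shot/model_best.pth.tar'
checkpoint = torch.load(ckpt_path)

model = CovaNet.define_CovarianceNet(which_model='Conv64', num_classes=5,
                                     norm='batch', init_type='normal', use_gpu=True)
model.load_state_dict(checkpoint['state_dict'])  # trained with ngpu=1, so no 'module.' prefixes
model.eval()
print('epoch_index:', checkpoint['epoch_index'], 'best_prec1:', checkpoint['best_prec1'])
```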