├── architectures ├── __init__.py ├── GAttnClassifier.py ├── conv2d_mtl.py ├── NetworkPre.py ├── GNetworkPre.py ├── ResNetFeat.py └── AttnClassifier.py ├── pretrain ├── architectures │ ├── __init__.py │ ├── Network.py │ ├── LossFeat.py │ └── ResNetFeat.py ├── util.py ├── dataloader │ ├── dataloader.py │ ├── mini_imagenet.py │ ├── cifar.py │ └── tiered_imagenet.py ├── batch_process.py └── trainer │ ├── PreTrainer.py │ ├── MetaEval.py │ └── BaseTrainer.py ├── LICENSE ├── dataloader ├── dataloader.py ├── tiered_imagenet.py ├── cifar.py └── mini_imagenet.py ├── README.md ├── util.py ├── train.py ├── trainer ├── FSEval.py ├── GFSEval.py ├── MetaTrainer.py └── GMetaTrainer.py └── test.py /architectures/__init__.py: -------------------------------------------------------------------------------- 1 | # from .convnet import convnet4 2 | # from .resnet import resnet12 3 | # from .resnet_ssl import resnet12_ssl 4 | # from .resnet_sd import resnet12_sd 5 | # from .resnet_selfdist import multi_resnet12_kd 6 | # from .resnet import seresnet12 7 | # from .wresnet import wrn_28_10 8 | 9 | # from .resnet_new import resnet50 10 | 11 | # model_pool = [ 12 | # 'convnet4', 13 | # 'resnet12', 14 | # 'resnet12_ssl', 15 | # 'resnet12_kd', 16 | # 'resnet12_sd', 17 | # 'seresnet12', 18 | # 'wrn_28_10', 19 | # ] 20 | 21 | # model_dict = { 22 | # 'wrn_28_10': wrn_28_10, 23 | # 'convnet4': convnet4, 24 | # 'resnet12': resnet12, 25 | # 'resnet12_ssl': resnet12_ssl, 26 | # 'resnet12_kd': multi_resnet12_kd, 27 | # 'resnet12_sd': resnet12_sd, 28 | # 'seresnet12': seresnet12, 29 | # 'resnet50': resnet50, 30 | # } 31 | -------------------------------------------------------------------------------- /pretrain/architectures/__init__.py: -------------------------------------------------------------------------------- 1 | # from .convnet import convnet4 2 | # from .resnet import resnet12 3 | # from .resnet_ssl import resnet12_ssl 4 | # from .resnet_sd import resnet12_sd 5 | # from .resnet_selfdist 
import multi_resnet12_kd 6 | # from .resnet import seresnet12 7 | # from .wresnet import wrn_28_10 8 | 9 | # from .resnet_new import resnet50 10 | 11 | # model_pool = [ 12 | # 'convnet4', 13 | # 'resnet12', 14 | # 'resnet12_ssl', 15 | # 'resnet12_kd', 16 | # 'resnet12_sd', 17 | # 'seresnet12', 18 | # 'wrn_28_10', 19 | # ] 20 | 21 | # model_dict = { 22 | # 'wrn_28_10': wrn_28_10, 23 | # 'convnet4': convnet4, 24 | # 'resnet12': resnet12, 25 | # 'resnet12_ssl': resnet12_ssl, 26 | # 'resnet12_kd': multi_resnet12_kd, 27 | # 'resnet12_sd': resnet12_sd, 28 | # 'seresnet12': seresnet12, 29 | # 'resnet50': resnet50, 30 | # } 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 shiyuanh 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
# /pretrain/util.py
import torch
import numpy as np
import os
import sys


class AverageMeter(object):
    """Track the most recent value and the running weighted mean of a metric.

    Attributes (plain numbers, all reset to 0 by ``reset``):
      val   -- value passed to the latest ``update`` call
      sum   -- accumulated ``val * n`` over all updates
      count -- accumulated ``n`` over all updates
      avg   -- ``sum / count``; stays at its previous value while ``count`` is 0
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every statistic back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed over ``n`` samples and refresh the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        # An update with n=0 on a fresh meter (e.g. an empty batch) must not
        # divide by zero; the average is simply left unchanged in that case.
        if self.count:
            self.avg = self.sum / self.count


def _decay_milestones(lr_decay_epochs):
    """Return the decay epochs as a 1-D float array.

    Accepts a single number, any sequence of numbers, or a comma-separated
    string such as ``"60,80"`` (the raw form of a ``type=str`` CLI option).
    """
    if isinstance(lr_decay_epochs, str):
        lr_decay_epochs = [float(tok) for tok in lr_decay_epochs.split(',') if tok.strip()]
    return np.atleast_1d(np.asarray(lr_decay_epochs, dtype=float))


def adjust_learning_rate(epoch, opt, optimizer):
    """Apply step decay to every parameter group of ``optimizer``.

    The learning rate becomes ``opt.learning_rate * opt.lr_decay_rate ** k``
    where ``k`` is the number of entries of ``opt.lr_decay_epochs`` that are
    strictly smaller than ``epoch``.  When ``k`` is 0 the optimizer is left
    untouched.

    Args:
        epoch: current epoch number.
        opt: object exposing ``learning_rate``, ``lr_decay_rate`` and
            ``lr_decay_epochs`` (number, sequence, or comma-separated string).
        optimizer: object exposing ``param_groups``, a list of dicts with an
            ``'lr'`` entry.
    """
    steps = np.sum(epoch > _decay_milestones(opt.lr_decay_epochs))
    if steps > 0:
        new_lr = opt.learning_rate * (opt.lr_decay_rate ** steps)
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr


def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy (in percent) for each ``k`` in ``topk``.

    Args:
        output: 2-D score tensor; the top-k search runs along dim 1.
        target: 1-D tensor of ground-truth class indices, one per row of
            ``output``.
        topk: iterable of ``k`` values to evaluate.

    Returns:
        A list with one single-element tensor per requested ``k``.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # Indices of the maxk largest scores per row, sorted, then transposed
        # so that row i holds every sample's i-th best guess.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape (rather than view) because the slice of the transposed
            # comparison result is not guaranteed to be contiguous.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def rot_aug(x):
    """Augment a batch with its 90/180/270-degree rotations.

    Dims 2 and 3 of ``x`` are the ones being rotated; because the rotated
    copies are concatenated with the original along dim 0, those two dims
    must have equal size.

    Returns:
        rot_data: ``x`` followed by the three rotated copies (4x the batch).
        rot_label: 1-D tensor of length ``4 * x.size(0)`` holding 0, 1, 2, 3
            for the respective copies.  It is created with the default dtype
            on the CPU, independent of ``x``.
    """
    bs = x.size(0)
    x_90 = x.transpose(2, 3).flip(2)
    x_180 = x.flip(2).flip(3)
    x_270 = x.flip(2).transpose(2, 3)
    rot_data = torch.cat((x, x_90, x_180, x_270), 0)
    rot_label = torch.cat((torch.zeros(bs), torch.ones(bs), 2 * torch.ones(bs), 3 * torch.ones(bs)))
    return rot_data, rot_label
-------------------------------------------------------------------------------- /pretrain/architectures/Network.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from torch.distributions import Bernoulli 6 | from torch.nn.utils.weight_norm import WeightNorm 7 | from scipy.special import gamma 8 | 9 | import numpy as np 10 | import math 11 | import pdb 12 | from architectures.ResNetFeat import create_feature_extractor 13 | from architectures.LossFeat import * 14 | 15 | DIM_CL = 128 16 | 17 | class ClassifierCombo(nn.Module): 18 | def __init__(self, in_dim, n_classes, c_type, temp=10.0): 19 | super().__init__() 20 | if c_type == 'cosine': 21 | self.classifier = nn.Linear(in_dim, n_classes, bias = False) 22 | WeightNorm.apply(self.classifier, 'weight', dim=0) #split the weight update component to direction and norm 23 | elif c_type == 'linear': 24 | self.classifier = nn.Linear(in_dim, n_classes, bias = True) 25 | elif c_type == 'mlp': 26 | self.classifier = [nn.Linear(in_dim, 1024),nn.Tanh(),nn.Linear(1024, n_classes)] 27 | self.classifier = nn.Sequential(*self.classifier) 28 | # https://github.com/wyharveychen/CloserLookFewShot/blob/e03aca8a2d01c9b5861a5a816cd5d3fdfc47cd45/backbone.py#L22 29 | # https://github.com/arjish/PreTrainedFullLibrary_FewShot/blob/main/classifier_full_library.py#L44 30 | 31 | self.c_type = c_type 32 | self.temp = nn.Parameter(torch.tensor(temp),requires_grad=False) 33 | 34 | def forward(self, feat): 35 | if self.c_type in ['linear','mlp']: 36 | return self.classifier(feat) 37 | else: 38 | return self.temp * self.classifier(F.normalize(feat,dim=-1)) 39 | 40 | class Backbone(nn.Module): 41 | def __init__(self,args,restype,n_class): 42 | super(Backbone,self).__init__() 43 | self.args = args 44 | self.restype= restype 45 | self.n_class = n_class 46 | self.featype = args.featype 47 | 48 | 
self.feature = create_feature_extractor(restype=restype,dataset=args.dataset) 49 | self.cls_classifier = ClassifierCombo(self.feature.out_dim, self.n_class, 'linear') 50 | 51 | def forward(self, left, right=None, need_cont=False): 52 | the_img = left if right is None else torch.cat([left,right],dim=0) 53 | resfeat = self.feature(the_img) 54 | cls_logit = self.cls_classifier(resfeat) 55 | return resfeat, cls_logit -------------------------------------------------------------------------------- /pretrain/dataloader/dataloader.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os,sys 4 | import numpy as np 5 | from tqdm import tqdm 6 | from torch.utils.data import DataLoader 7 | 8 | from .cifar import PreCIFAR, MetaCIFAR 9 | from .mini_imagenet import PreMini, MetaMini 10 | from .tiered_imagenet import PreTiered, MetaTiered 11 | 12 | def get_dataloaders(opt,mode='contrast'): 13 | # dataloader 14 | opt.data_root = '{}/{}'.format(opt.data_root, opt.dataset) 15 | 16 | if opt.dataset == 'miniImageNet': 17 | n_cls = 64 18 | meta_1shot_loader = DataLoader(MetaMini(opt,1,'test',False), batch_size=1, shuffle=False, drop_last=False,num_workers=opt.num_workers) 19 | meta_5shot_loader = DataLoader(MetaMini(opt,5,'test',False), batch_size=1, shuffle=False, drop_last=False,num_workers=opt.num_workers) 20 | meta_test_loader = (meta_1shot_loader,meta_5shot_loader) 21 | pre_train_loader = DataLoader(PreMini(opt,'train',True), batch_size=opt.batch_size, shuffle=True, drop_last=True, num_workers=opt.num_workers) 22 | return pre_train_loader, meta_test_loader, n_cls 23 | 24 | elif opt.dataset in ['CIFAR-FS','FC100']: 25 | 26 | n_cls = 64 if opt.dataset == 'CIFAR-FS' else 60 27 | meta_1shot_loader = DataLoader(MetaCIFAR(opt,1,'test',False), batch_size=1, shuffle=False, drop_last=False,num_workers=opt.num_workers) 28 | meta_5shot_loader = DataLoader(MetaCIFAR(opt,5,'test',False), batch_size=1, 
shuffle=False, drop_last=False,num_workers=opt.num_workers) 29 | meta_test_loader = (meta_1shot_loader,meta_5shot_loader) 30 | pre_train_loader = DataLoader(PreCIFAR(opt,'train',True), batch_size=opt.batch_size, shuffle=True, drop_last=True, num_workers=opt.num_workers) 31 | return pre_train_loader, meta_test_loader, n_cls 32 | 33 | elif opt.dataset == 'tieredImageNet': 34 | n_cls = 351 35 | meta_1shot_loader = DataLoader(MetaTiered(opt,1,'test',False), batch_size=1, shuffle=False, drop_last=False,num_workers=opt.num_workers) 36 | meta_5shot_loader = DataLoader(MetaTiered(opt,5,'test',False), batch_size=1, shuffle=False, drop_last=False,num_workers=opt.num_workers) 37 | meta_test_loader = (meta_1shot_loader,meta_5shot_loader) 38 | pre_train_loader = DataLoader(PreTiered(opt,'train',True), batch_size=opt.batch_size, shuffle=True, drop_last=True, num_workers=opt.num_workers) 39 | return pre_train_loader, meta_test_loader, n_cls 40 | 41 | else: 42 | raise ValueError('Dataset Not in Record, Pls check the CONFIGS') 43 | -------------------------------------------------------------------------------- /architectures/GAttnClassifier.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | import math 7 | import pdb 8 | from architectures.AttnClassifier import * 9 | 10 | class GClassifier(Classifier): 11 | def __init__(self, args, feat_dim, param_seman, train_weight_base=False): 12 | super(GClassifier, self).__init__(args, feat_dim, param_seman, train_weight_base) 13 | 14 | # Weight & Bias for Base 15 | self.train_weight_base = train_weight_base 16 | self.init_representation(param_seman) 17 | if train_weight_base: 18 | print('Enable training base class weights') 19 | 20 | self.calibrator = SupportCalibrator(nway=args.n_ways, feat_dim=feat_dim, n_head=1, base_seman_calib=args.base_seman_calib, neg_gen_type=args.neg_gen_type) 21 | 
self.open_generator = OpenSetGenerater(args.n_ways, feat_dim, n_head=1, neg_gen_type=args.neg_gen_type, agg=args.agg) 22 | self.metric = Metric_Cosine() 23 | 24 | def forward(self, features, cls_ids, test=False): 25 | ## bs: features[0].size(0) 26 | ## support_feat: bs*nway*nshot*D 27 | ## query_feat: bs*(nway*nquery)*D 28 | ## base_ids: bs*54 29 | (support_feat, query_feat, openset_feat, baseset_feat) = features 30 | 31 | (nb,nc,ns,ndim),nq = support_feat.size(),query_feat.size(1) 32 | (supp_ids, base_ids) = cls_ids 33 | 34 | base_weights,base_wgtmem,base_seman,support_seman = self.get_representation(supp_ids,base_ids) 35 | support_feat = torch.mean(support_feat, dim=2) 36 | 37 | supp_protos,support_attn = self.calibrator(support_feat, base_weights, support_seman, base_seman) 38 | 39 | fakeclass_protos, recip_unit = self.open_generator(supp_protos, base_weights, support_seman, base_seman) 40 | cls_protos = torch.cat([base_weights, supp_protos, fakeclass_protos], dim=1) 41 | 42 | 43 | query_funit_distance = 1.0- self.metric(recip_unit, query_feat) 44 | qopen_funit_distance = 1.0- self.metric(recip_unit, openset_feat) 45 | funit_distance = torch.cat([query_funit_distance,qopen_funit_distance],dim=1) 46 | 47 | query_cls_scores = self.metric(cls_protos, query_feat) 48 | openset_cls_scores = self.metric(cls_protos, openset_feat) 49 | baseset_cls_scores = self.metric(cls_protos, baseset_feat) 50 | 51 | test_cosine_scores = (baseset_cls_scores,query_cls_scores,openset_cls_scores) 52 | return test_cosine_scores, supp_protos, fakeclass_protos, (base_weights,base_wgtmem), funit_distance 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /dataloader/dataloader.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import argparse 5 | import socket 6 | import time 7 | import sys 8 | from tqdm import tqdm 9 | 10 | import torch 11 | 
import torch.optim as optim 12 | import torch.nn as nn 13 | import torch.backends.cudnn as cudnn 14 | from torch.utils.data import DataLoader 15 | import torch.nn.functional as F 16 | 17 | from .mini_imagenet import OpenMini, GenMini 18 | from .cifar import OpenCIFAR 19 | from .tiered_imagenet import OpenTiered 20 | 21 | import numpy as np 22 | 23 | 24 | def get_dataloaders(opt,mode='open'): 25 | # dataloader 26 | opt.data_root = '{}/{}'.format(opt.data_root, opt.dataset) 27 | 28 | if mode == 'gopenmeta': 29 | assert opt.dataset == 'miniImageNet' 30 | n_cls = 64 31 | open_train_loader = DataLoader(GenMini(opt,'train','episode', True), batch_size=opt.n_train_para, shuffle=False, num_workers=opt.num_workers) 32 | meta_test_loader = DataLoader(GenMini(opt,'test','episode', False), batch_size=1, shuffle=False, drop_last=False,num_workers=opt.num_workers) 33 | return open_train_loader, meta_test_loader, n_cls 34 | 35 | assert mode == 'openmeta' 36 | 37 | if opt.dataset == 'miniImageNet': 38 | n_cls = 64 39 | open_train_loader = DataLoader(OpenMini(opt,'train','episode', True), batch_size=opt.n_train_para, shuffle=False, num_workers=opt.num_workers) 40 | meta_test_loader = DataLoader(OpenMini(opt,'test','episode', False), batch_size=1, shuffle=False, drop_last=False,num_workers=opt.num_workers) 41 | return open_train_loader, meta_test_loader, n_cls 42 | elif opt.dataset in ['FC100','CIFAR-FS']: 43 | n_cls = 60 if opt.dataset=='FC100' else 64 44 | open_train_loader = DataLoader(OpenCIFAR(opt,'train','episode', True), batch_size=opt.n_train_para, shuffle=False, num_workers=opt.num_workers) 45 | meta_test_loader = DataLoader(OpenCIFAR(opt,'test','episode', False), batch_size=1, shuffle=False, drop_last=False,num_workers=opt.num_workers) 46 | return open_train_loader, meta_test_loader, n_cls 47 | elif opt.dataset in ['tieredImageNet', 'tieredImageNetWord']: 48 | n_cls = 351 49 | open_train_loader = DataLoader(OpenTiered(opt,'train','episode', True), 
batch_size=opt.n_train_para, shuffle=False, num_workers=opt.num_workers) 50 | meta_test_loader = DataLoader(OpenTiered(opt,'test','episode', False), batch_size=1, shuffle=False, drop_last=False,num_workers=opt.num_workers) 51 | return open_train_loader, meta_test_loader, n_cls 52 | else: 53 | raise NotImplementedError(opt.dataset) 54 | -------------------------------------------------------------------------------- /pretrain/batch_process.py: -------------------------------------------------------------------------------- 1 | import os,pdb 2 | import torch 3 | import argparse 4 | import numpy as np 5 | 6 | from trainer.PreTrainer import PreTrainer 7 | from dataloader.dataloader import get_dataloaders 8 | 9 | model_pool = ['ResNet12'] 10 | task_pool = ['Entropy','EntropyRot'] 11 | parser = argparse.ArgumentParser('argument for training') 12 | 13 | # General 14 | parser.add_argument('--batch_size', type=int, default=64, help='batch_size') 15 | parser.add_argument('--num_workers', type=int, default=2, help='num of workers to use') 16 | parser.add_argument('--epochs', type=int, default=90, help='number of training epochs') 17 | parser.add_argument('--eval', action='store_true', help='using cosine annealing') 18 | parser.add_argument('-t', '--trial', type=str, default='1', help='the experiment id') 19 | parser.add_argument('--featype', type=str, default='EntropyRot', choices=task_pool, help='number of training epochs') 20 | 21 | # dataset 22 | parser.add_argument('--restype', type=str, default='ResNet12', choices=model_pool) 23 | parser.add_argument('--dataset', type=str, default='miniImageNet', choices=['miniImageNet', 'tieredImageNet','CIFAR-FS', 'FC100']) 24 | parser.add_argument('--model_path', type=str, default='./logs/', help='path to save model') 25 | parser.add_argument('--data_root', type=str, default='/home/jiawei/DATA/', help='path to data root') 26 | 27 | # few-shot setting 28 | parser.add_argument('--n_episodes', type=int, default=1000, metavar='N', 
help='Number of test runs') 29 | parser.add_argument('--n_ways', type=int, default=5, metavar='N', help='Number of classes for doing each classification run') 30 | parser.add_argument('--n_shots', type=int, default=1, metavar='N', help='Number of shots in test') 31 | parser.add_argument('--n_queries', type=int, default=15, metavar='N', help='Number of query in test') 32 | parser.add_argument('--n_aug_support_samples', default=5, type=int, help='The number of augmented samples for each meta test sample') 33 | 34 | # optimization 35 | parser.add_argument('--learning_rate', type=float, default=0.05, help='learning rate') 36 | parser.add_argument('--lr_decay_epochs', type=str, default='60', help='where to decay lr, can be a list') 37 | parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate') 38 | parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay') 39 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum') 40 | 41 | #hyper parameters 42 | parser.add_argument('--rotangle', type=int, default=10,help='rotation angle of the weak augmentation') 43 | parser.add_argument('--temp', type=float, default=0.5, help='temperature of the contrastive loss') 44 | parser.add_argument('--use_bce', action='store_true') 45 | 46 | 47 | args = parser.parse_args() 48 | args.logroot = os.path.join(os.path.abspath('.'),'logs') 49 | if not os.path.isdir(args.logroot): 50 | os.makedirs(args.logroot) 51 | args.n_gpu = torch.cuda.device_count() 52 | 53 | if __name__ == "__main__": 54 | 55 | pre_train_loader, meta_test_loader, n_cls = get_dataloaders(args,'entropy') 56 | dataloader_trainer = (pre_train_loader, None, n_cls) 57 | trainer = PreTrainer(args,dataloader_trainer) 58 | trainer.train(meta_test_loader) 59 | -------------------------------------------------------------------------------- /pretrain/trainer/PreTrainer.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import pdb 5 | import numpy as np 6 | import time 7 | from tqdm import tqdm 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from .BaseTrainer import BaseTrainer 12 | from util import AverageMeter 13 | 14 | def rot_aug(x): 15 | bs = x.size(0) 16 | x_90 = x.transpose(2,3).flip(2) 17 | x_180 = x.flip(2).flip(3) 18 | x_270 = x.flip(2).transpose(2,3) 19 | rot_data = torch.cat((x, x_90, x_180, x_270),0) 20 | rot_label = torch.cat((torch.zeros(bs),torch.ones(bs),2*torch.ones(bs),3*torch.ones(bs))) 21 | return rot_data, rot_label 22 | 23 | class PreTrainer(BaseTrainer): 24 | def __init__(self, args, dataset_trainer): 25 | super(PreTrainer,self).__init__(args, dataset_trainer) 26 | 27 | def train_epoch(self, epoch, train_loader, model, criterion, optimizer, args): 28 | if args.featype == 'EntropyRot': 29 | return self.ce_rot_epoch(epoch, train_loader, model, criterion, optimizer) 30 | elif args.featype == 'Entropy': 31 | return self.ce_epoch(epoch, train_loader, model, criterion, optimizer) 32 | 33 | def ce_epoch(self, epoch, train_loader, model, criterion, optimizer): 34 | """One epoch training""" 35 | model.train() 36 | losses = AverageMeter() 37 | 38 | with tqdm(train_loader, total=len(train_loader), leave=False) as pbar: 39 | for idx, (image,target,_) in enumerate(pbar): 40 | 41 | batch_size = target.size()[0] 42 | # Forward 43 | _,cls_logits = model(image.cuda()) 44 | if self.args.use_bce: 45 | loss = criterion['logit'](cls_logits, F.one_hot(target,self.n_cls).float().cuda()) 46 | else: 47 | loss = criterion['logit'](cls_logits, target.cuda()) 48 | losses.update(loss.item(), batch_size) 49 | 50 | # ===================backward===================== 51 | optimizer.zero_grad() 52 | loss.backward() 53 | optimizer.step() 54 | pbar.set_postfix({"Epoch {} Loss".format(epoch) :'{0:.2f}'.format(losses.avg)}) 55 | 56 | 
message = 'Epoch {} Train_Loss {:.3f}'.format(epoch, losses.avg) 57 | return losses.avg, message 58 | 59 | def ce_rot_epoch(self, epoch, train_loader, model, criterion, optimizer): 60 | """One epoch training""" 61 | model.train() 62 | losses = AverageMeter() 63 | 64 | with tqdm(train_loader, total=len(train_loader), leave=False) as pbar: 65 | for idx, (image,target,_) in enumerate(pbar): 66 | 67 | batch_size = target.size()[0] 68 | image,_ = rot_aug(image) 69 | # Forward 70 | _,cls_logits = model(image.cuda()) 71 | loss = criterion['logit'](cls_logits, target.repeat(4).cuda()) 72 | losses.update(loss.item(), batch_size) 73 | 74 | # ===================backward===================== 75 | optimizer.zero_grad() 76 | loss.backward() 77 | optimizer.step() 78 | pbar.set_postfix({"Epoch {} Loss".format(epoch) :'{0:.2f}'.format(losses.avg)}) 79 | 80 | message = 'Epoch {} Train_Loss {:.3f}'.format(epoch, losses.avg) 81 | return losses.avg, message -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Task-Adaptive Negative Envision for Few-Shot Open-Set Recognition 2 | This is the code repository for ["Task-Adaptive Negative Envision for Few-Shot Open-Set Recognition"](https://openaccess.thecvf.com/content/CVPR2022/html/Huang_Task-Adaptive_Negative_Envision_for_Few-Shot_Open-Set_Recognition_CVPR_2022_paper.html) (accepted by CVPR 2022). 3 | 4 | 5 | ## Installation 6 | This repo is tested with Python 3.6, Pytorch 1.8, CUDA 10.1. More recent versions of Python and Pytorch with compatible CUDA versions should also support the code. 7 | 8 | 9 | ## Data Preparation 10 | MiniImageNet image data are provided by [RFS](https://github.com/WangYueFt/rfs), available at [DropBox](https://www.dropbox.com/sh/6yd1ygtyc3yd981/AABVeEqzC08YQv4UZk7lNHvya?dl=0). 
We also provide the word embeddings for the class names [here](https://drive.google.com/file/d/1CpF3M_qySCBhIWOSURIT_LpA1B61tsFb/view?usp=sharing). For TieredImageNet, we use the image data and word embeddings provided by [AM3](https://github.com/ServiceNow/am3), available at [GoogleDrive](https://drive.google.com/file/d/1Letu5U_kAjQfqJjNPWS_rdjJ7Fd46LbX/view). Download and put them under your <*data_dir*>. 11 | 12 | 13 | ## Pre-trained models 14 | We provide the pre-trained models for TieredImageNet and MiniImageNet, which can be downloaded [here](https://drive.google.com/drive/folders/1mj8j5ZChRFLcYMBWEsBBhst8uQTOz_WJ?usp=sharing). Save the pre-trained model to <*pretrained_model_path*>. 15 | 16 | ## Training 17 | An example training command for 5-way 1-shot FSOR: 18 | ``` 19 | python train.py --dataset <dataset> --logroot <log_dir> --data_root <data_dir> \ 20 | --n_ways 5 --n_shots 1 \ 21 | --pretrained_model_path <pretrained_model_path> \ 22 | --featype OpenMeta \ 23 | --learning_rate 0.03 \ 24 | --tunefeat 0.0001 \ 25 | --tune_part 4 \ 26 | --cosine \ 27 | --base_seman_calib 1 \ 28 | --train_weight_base 1 \ 29 | --neg_gen_type semang 30 | ``` 31 | 32 | ## Testing 33 | An example testing command for 5-way 1-shot FSOR: 34 | ``` 35 | python test.py --dataset <dataset> --data_root <data_dir> \ 36 | --n_ways 5 --n_shots 1 \ 37 | --pretrained_model_path <pretrained_model_path> \ 38 | --featype OpenMeta \ 39 | --test_model_path <test_model_path> \ 40 | --n_test_runs 1000 \ 41 | --seed <seed> 42 | ``` 43 | 44 | ## Pre-training 45 | We also provide the code for the pre-training stage under the `pretrain` folder. 
An example of running command for pre-training on miniImageNet: 46 | ``` 47 | python batch_process.py --featype EntropyRot --learning_rate 0.05 48 | ``` 49 | 50 | ## Citation 51 | If you find this repo useful for your research, please consider citing the paper: 52 | ``` 53 | @InProceedings{Huang_2022_CVPR, 54 | author = {Huang, Shiyuan and Ma, Jiawei and Han, Guangxing and Chang, Shih-Fu}, 55 | title = {Task-Adaptive Negative Envision for Few-Shot Open-Set Recognition}, 56 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 57 | month = {June}, 58 | year = {2022}, 59 | pages = {7171-7180} 60 | } 61 | ``` 62 | 63 | 64 | ## Acknowledgement 65 | Our code and data are based upon [RFS](https://github.com/WangYueFt/rfs) and [AW3](https://github.com/ServiceNow/am3). -------------------------------------------------------------------------------- /pretrain/trainer/MetaEval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import sys, os, pdb 4 | import numpy as np 5 | import scipy 6 | from scipy.stats import t 7 | from tqdm import tqdm 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | from sklearn import metrics 13 | from sklearn.linear_model import LogisticRegression 14 | 15 | def meta_evaluation(net, metaloader, type_aug='crop', type_classifier='LR'): 16 | if type_aug == 'crop': 17 | return meta_test(net, metaloader, type_classifier) 18 | 19 | def meta_test(net, metaloader, classifier='LR'): 20 | net = net.eval() 21 | acc = [] 22 | with torch.no_grad(): 23 | with tqdm(metaloader, total=len(metaloader), leave=False) as pbar: 24 | for idx, data in enumerate(pbar): 25 | # Data Preparation 26 | support_data, support_label, query_data, query_label = data 27 | support_data = support_data.cuda() 28 | query_data = query_data.cuda() 29 | # Data Reorganization 30 | _, _, height, width, channel = support_data.size() 31 | 
support_data = support_data.view(-1, height, width, channel) 32 | query_data = query_data.view(-1, height, width, channel) 33 | support_label = support_label.view(-1).numpy() 34 | query_label = query_label.view(-1).numpy() 35 | 36 | # Feature Extracdtion 37 | support_features = net(support_data)[0].view(support_data.size(0), -1) 38 | query_features = net(query_data)[0].view(query_data.size(0), -1) 39 | support_features = F.normalize(support_features,p=2,dim=-1).detach().cpu().numpy() 40 | query_features = F.normalize(query_features,p=2,dim=-1).detach().cpu().numpy() 41 | if classifier.lower() in ['lr','linearregression']: 42 | clf = LogisticRegression(random_state=0, solver='lbfgs', max_iter=1000, penalty='l2', multi_class='multinomial') 43 | clf.fit(support_features, support_label) 44 | query_pred = clf.predict(query_features) 45 | elif 'proto' in classifier.lower(): 46 | query_pred = Proto(support_features, support_label, query_features, query_label) 47 | else: 48 | raise NotImplementedError('classifier not supported: {}'.format(classifier)) 49 | 50 | acc.append(metrics.accuracy_score(query_label, query_pred)) 51 | pbar.set_postfix({"Few-Shot MetaEval Acc":'{0:.2f}'.format(acc[-1])}) 52 | return mean_confidence_interval(acc) 53 | 54 | def Proto(support, support_ys, query, query_label): 55 | proto_ys = sorted(np.unique(support_ys).tolist()) 56 | proto = [] 57 | for cls_id in proto_ys: 58 | the_feat = support[support_ys==cls_id].mean(axis=0) 59 | proto.append(the_feat) 60 | proto = np.stack(proto) 61 | 62 | proto_norm = np.linalg.norm(proto, axis=1, keepdims=True) 63 | proto = proto / proto_norm 64 | cosine_distance = query @ proto.transpose() 65 | 66 | max_idx = np.argmax(cosine_distance, axis=1) 67 | pred = [proto_ys[idx] for idx in max_idx] 68 | return pred 69 | 70 | def mean_confidence_interval(data, confidence=0.95): 71 | a = 100.0 * np.array(data) 72 | n = len(a) 73 | m, se = np.mean(a), scipy.stats.sem(a) 74 | h = se * t._ppf((1+confidence)/2., n-1) 75 | 
return m, h 76 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import os 7 | import sys 8 | from tqdm import tqdm 9 | 10 | 11 | class AverageMeter(object): 12 | """Computes and stores the average and current value""" 13 | def __init__(self): 14 | self.reset() 15 | 16 | def reset(self): 17 | self.val = 0 18 | self.avg = 0 19 | self.sum = 0 20 | self.count = 0 21 | 22 | def update(self, val, n=1): 23 | self.val = val 24 | self.sum += val * n 25 | self.count += n 26 | self.avg = self.sum / self.count 27 | 28 | 29 | def adjust_learning_rate(epoch, opt, optimizer, threshold=1e-6): 30 | """Sets the learning rate to the initial LR decayed by decay rate every steep step""" 31 | steps = np.sum(epoch > np.asarray(opt.lr_decay_epochs)) 32 | if steps > 0 and opt.learning_rate > threshold: 33 | new_lr = opt.learning_rate * (opt.lr_decay_rate ** steps) 34 | for param_group in optimizer.param_groups: 35 | param_group['lr'] = new_lr 36 | 37 | 38 | def accuracy(output, target, topk=(1,)): 39 | """Computes the accuracy over the k top predictions for the specified values of k""" 40 | with torch.no_grad(): 41 | maxk = max(topk) 42 | batch_size = target.size(0) 43 | 44 | _, pred = output.topk(maxk, 1, True, True) 45 | pred = pred.t() 46 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 47 | 48 | res = [] 49 | for k in topk: 50 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 51 | res.append(correct_k.mul_(100.0 / batch_size)) 52 | return res 53 | 54 | 55 | class Logger(object): 56 | '''Save training process to log file with simple plot function.''' 57 | def __init__(self, fpath, title=None, resume=False): 58 | self.file = None 59 | self.resume = resume 60 | self.title = '' if title == None else 
title 61 | if fpath is not None: 62 | if resume: 63 | self.file = open(fpath, 'r') 64 | name = self.file.readline() 65 | self.names = name.rstrip().split('\t') 66 | self.numbers = {} 67 | for _, name in enumerate(self.names): 68 | self.numbers[name] = [] 69 | 70 | for numbers in self.file: 71 | numbers = numbers.rstrip().split('\t') 72 | for i in range(0, len(numbers)): 73 | self.numbers[self.names[i]].append(numbers[i]) 74 | self.file.close() 75 | self.file = open(fpath, 'a') 76 | else: 77 | self.file = open(fpath, 'w') 78 | 79 | def set_names(self, names): 80 | if self.resume: 81 | pass 82 | # initialize numbers as empty list 83 | self.numbers = {} 84 | self.names = names 85 | for _, name in enumerate(self.names): 86 | self.file.write(name) 87 | self.file.write('\t') 88 | self.numbers[name] = [] 89 | self.file.write('\n') 90 | self.file.flush() 91 | 92 | 93 | def append(self, numbers): 94 | assert len(self.names) == len(numbers), 'Numbers do not match names' 95 | for index, num in enumerate(numbers): 96 | self.file.write("{0:.6f}".format(num)) 97 | self.file.write('\t') 98 | self.numbers[self.names[index]].append(num) 99 | self.file.write('\n') 100 | self.file.flush() 101 | 102 | def plot(self, names=None): 103 | names = self.names if names == None else names 104 | numbers = self.numbers 105 | for _, name in enumerate(names): 106 | x = np.arange(len(numbers[name])) 107 | plt.plot(x, np.asarray(numbers[name])) 108 | plt.legend([self.title + '(' + name + ')' for name in names]) 109 | plt.grid(True) 110 | 111 | 112 | def close(self): 113 | if self.file is not None: 114 | self.file.close() 115 | 116 | -------------------------------------------------------------------------------- /architectures/conv2d_mtl.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Yaoyao Liu 3 | ## Modified from: https://github.com/pytorch/pytorch 4 | ## 
class _ConvNdMtl(Module):
    """Base class of the meta-transfer convolutions.

    It owns a frozen convolution kernel (``weight``/``bias``) together with
    trainable per-channel-pair scaling (``mtl_weight``) and shifting
    (``mtl_bias``) parameters.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation, transposed, output_padding, groups, bias):
        super(_ConvNdMtl, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')

        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel_size, self.stride = kernel_size, stride
        self.padding, self.dilation = padding, dilation
        self.transposed, self.output_padding = transposed, output_padding
        self.groups = groups

        # The two leading kernel dimensions swap roles for transposed convolutions.
        if transposed:
            lead, per_group = in_channels, out_channels // groups
        else:
            lead, per_group = out_channels, in_channels // groups
        self.weight = Parameter(torch.Tensor(lead, per_group, *kernel_size))
        self.mtl_weight = Parameter(torch.ones(lead, per_group, 1, 1))
        # The base kernel is frozen; only the scaling is trainable.
        self.weight.requires_grad = False

        if not bias:
            self.register_parameter('bias', None)
            self.register_parameter('mtl_bias', None)
        else:
            self.bias = Parameter(torch.Tensor(out_channels))
            self.bias.requires_grad = False
            self.mtl_bias = Parameter(torch.zeros(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw the base kernel/bias uniformly and reset scale to 1 and shift to 0."""
        fan_in = self.in_channels
        for extent in self.kernel_size:
            fan_in *= extent
        bound = 1. / math.sqrt(fan_in)
        self.weight.data.uniform_(-bound, bound)
        self.mtl_weight.data.uniform_(1, 1)
        if self.bias is None:
            return
        self.bias.data.uniform_(-bound, bound)
        self.mtl_bias.data.uniform_(0, 0)

    def extra_repr(self):
        """Return the layer description; settings left at their defaults are omitted."""
        pieces = ['{in_channels}, {out_channels}, kernel_size={kernel_size}',
                  ', stride={stride}']
        if self.padding != (0,) * len(self.padding):
            pieces.append(', padding={padding}')
        if self.dilation != (1,) * len(self.dilation):
            pieces.append(', dilation={dilation}')
        if self.output_padding != (0,) * len(self.output_padding):
            pieces.append(', output_padding={output_padding}')
        if self.groups != 1:
            pieces.append(', groups={groups}')
        if self.bias is None:
            pieces.append(', bias=False')
        return ''.join(pieces).format(**self.__dict__)


class Conv2dMtl(_ConvNdMtl):
    """2-D meta-transfer convolution: convolves with ``weight * mtl_weight`` and ``bias + mtl_bias``."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        # Scalars are broadcast to 2-tuples; the layer is never transposed.
        super(Conv2dMtl, self).__init__(
            in_channels, out_channels, _pair(kernel_size), _pair(stride),
            _pair(padding), _pair(dilation), False, _pair(0), groups, bias)

    def forward(self, inp):
        """Convolve ``inp`` with the scaled kernel and the shifted bias."""
        scaled_weight = self.weight.mul(self.mtl_weight.expand(self.weight.shape))
        shifted_bias = None if self.bias is None else self.bias + self.mtl_bias
        return F.conv2d(inp, scaled_weight, shifted_bias, self.stride,
                        self.padding, self.dilation, self.groups)
os 2 | import torch 3 | import argparse 4 | import numpy as np 5 | # from models.base_model import Base_Model 6 | 7 | from trainer.MetaTrainer import MetaTrainer 8 | from trainer.GMetaTrainer import GMetaTrainer 9 | from dataloader.dataloader import get_dataloaders 10 | 11 | 12 | model_pool = ['ResNet18','ResNet12','WRN28'] 13 | parser = argparse.ArgumentParser('argument for training') 14 | 15 | # General Setting 16 | parser.add_argument('--batch_size', type=int, default=64, help='batch_size') 17 | parser.add_argument('--num_workers', type=int, default=1, help='num of workers to use') 18 | parser.add_argument('--epochs', type=int, default=65, help='number of training epochs') 19 | parser.add_argument('--featype', type=str, default='OpenMeta', choices=['OpenMeta', 'GOpenMeta'], help='type of task: OpenMeta -- FSOR, GOpenMeta --- GFSOR') 20 | parser.add_argument('--restype', type=str, default='ResNet12', choices=model_pool, help='Network Structure') 21 | parser.add_argument('--dataset', type=str, default='miniImageNet', choices=['miniImageNet', 'tieredImageNet']) 22 | parser.add_argument('--gpus', type=str, default='0') 23 | parser.add_argument('-t', '--trial', type=str, default='1', help='the experiment id') 24 | 25 | # Optimization 26 | parser.add_argument('--cosine', action='store_true', help='using cosine annealing') 27 | parser.add_argument('--learning_rate', type=float, default=0.05, help='learning rate') 28 | parser.add_argument('--lr_decay_epochs', type=str, default='30', help='where to decay lr, can be a list') 29 | parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate') 30 | parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay') 31 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum') 32 | parser.add_argument('--adam', action='store_true', help='use adam optimizer') 33 | parser.add_argument('--tunefeat', type=float, default=0.0, help='update feature 
parameter') 34 | 35 | # Specify folder 36 | parser.add_argument('--logroot', type=str, default='./logs/', help='path to save model') 37 | parser.add_argument('--data_root', type=str, default='data/', help='path to data root') 38 | parser.add_argument('--pretrained_model_path', type=str, default='miniImageNet_pre.pth', help='path to pretrained model') 39 | 40 | # Meta Setting 41 | parser.add_argument('--n_ways', type=int, default=5, metavar='N', help='Number of classes for doing each classification run') 42 | parser.add_argument('--n_shots', type=int, default=1, metavar='N', help='Number of shots in test') 43 | parser.add_argument('--n_queries', type=int, default=15, metavar='N', help='Number of query in test') 44 | parser.add_argument('--n_aug_support_samples', default=5, type=int, help='The number of augmented samples for each meta test sample') 45 | parser.add_argument('--n_train_para', type=int, default=2, metavar='train_batch_size', help='Size of training batch)') 46 | parser.add_argument('--n_train_runs', type=int, default=300, help='Number of training episodes') 47 | parser.add_argument('--n_test_runs', type=int, default=600, metavar='N', help='Number of test runs') 48 | 49 | # Meta Control 50 | parser.add_argument('--gamma', type=float, default=1.0, help='loss cofficient for mse loss') 51 | parser.add_argument('--train_weight_base', type=int, default=0, help='enable training base class weights') 52 | parser.add_argument('--neg_gen_type', type=str, default='semang', choices=['semang', 'attg', 'att', 'mlp']) 53 | parser.add_argument('--base_seman_calib',type=int, default=0, help='base semantics calibration') 54 | parser.add_argument('--agg', type=str, default='avg', choices=['avg', 'mlp']) 55 | 56 | parser.add_argument('--tune_part', type=int, default=2, choices=[1,2, 3, 4]) 57 | parser.add_argument('--base_size', default=-1, type=int) 58 | parser.add_argument('--n_open_ways', type=int, default=5, metavar='N', help='Number of classes for doing each 
def run_test_fsl(net, openloader, config, encoder=None, generator=None, n_ways=5, n_shots=1, scale=4):
    """Evaluate ``net`` on every episode of ``openloader`` and store the summary in ``config``.

    Args:
        net: the model; it is switched to eval mode.
        openloader: iterable (with ``len``) of open-set episodes.
        config: dict with key ``'auroc_type'`` listing the open-set scores to
            compute (see ``eval_fsl_nplus1``). It is updated in place.
        encoder, generator, n_ways, n_shots, scale: unused; kept so existing
            callers keep working.

    Returns:
        ``config`` with ``config['data']`` holding ``'acc'`` and one
        ``'auroc_<type>'`` entry per requested type, each being the result of
        ``mean_confidence_interval``.
    """
    net = net.eval()
    auroc_type = config['auroc_type']

    acc_trace = []
    auroc_trace = {key: [] for key in auroc_type}
    with tqdm(openloader, total=len(openloader), leave=False) as pbar:
        for data in pbar:
            feats, labels, probs = compute_feats(net, data)
            acc, auroc = eval_fsl_nplus1(feats, labels, probs, auroc_type)

            acc_trace.append(acc)
            for key in auroc_type:
                # Score types that were not produced are skipped, not a KeyError.
                score = auroc.get(key)
                if score is not None:
                    auroc_trace[key].append(score)

            postfix = {"OpenSet MetaEval Acc": '{0:.2f}'.format(acc)}
            if auroc_type and auroc.get(auroc_type[0]) is not None:
                postfix["AUROC-%s MetaEval:" % auroc_type[0]] = '{0:.2f}'.format(auroc[auroc_type[0]])
            pbar.set_postfix(postfix)

    config['data'] = {'acc': mean_confidence_interval(acc_trace)}
    for key in auroc_type:
        config['data']['auroc_%s' % key] = mean_confidence_interval(auroc_trace[key])
    return config


def eval_fsl_nplus1(feats, labels, probs, auroc_type=('prob',)):
    """Score one episode with an (N+1)-way classifier whose last column is the open class.

    Args:
        feats: feature tuple produced by ``compute_feats`` (not used here).
        labels: ``(support_label, query_label, open_label)`` as numpy arrays;
            ``open_label`` is 1 for closed-set query rows and 0 for open-set rows.
        probs: sequence of probability arrays that are concatenated along
            axis 0 (closed-set query rows first, then open-set rows).
        auroc_type: which open-set scores to compute: ``'prob'`` (AUROC of
            the open-class probability) and/or ``'fscore'`` (macro F1 over
            the predicted classes).

    Returns:
        ``(acc, auroc)``: the closed-set accuracy of the query rows and a
        dict with one entry per computed score type.
    """
    _, query_label, open_label = labels
    num_query = query_label.shape[0]
    all_probs = np.concatenate(probs, axis=0)
    # The number of closed-set classes follows from the score width instead of
    # being fixed to 5, so episodes with another ``n_ways`` are supported.
    num_way = all_probs.shape[-1] - 1

    auroc = dict()

    if 'prob' in auroc_type:
        # Open-set samples are the positives: flip the closed-set indicator.
        auroc['prob'] = metrics.roc_auc_score(1 - open_label, all_probs[:, -1])

    if 'fscore' in auroc_type:
        num_open = len(open_label) - len(query_label)
        # ``np.int`` was removed from NumPy; the builtin ``int`` is equivalent.
        all_labels = np.concatenate([query_label, num_way * np.ones(num_open)], -1).astype(int)
        ypred = np.argmax(all_probs, axis=-1)
        auroc['fscore'] = f1_score(all_labels, ypred, average='macro', labels=np.unique(ypred))

    # Closed-set accuracy ignores the open-class column.
    query_pred = np.argmax(all_probs[:num_query, :-1], axis=-1)
    acc = metrics.accuracy_score(query_label, query_pred)

    return acc, auroc


def compute_feats(net, data):
    """Run ``net`` on one episode and convert its outputs to numpy.

    Args:
        net: model exposing ``n_ways`` and ``feat_dim``; called as
            ``net(images, labels, indices, test=True)``.
        data: 10-tuple ``(support_data, support_label, query_data,
            query_label, suppopen_data, suppopen_label, openset_data,
            openset_label, supp_idx, open_idx)``. Only the first episode of
            the batch is evaluated.

    Returns:
        ``(cls_protos, query_feat, openset_feat)`` as L2-normalised numpy
        arrays, ``(support_label, query_label, open_label)`` as numpy arrays
        where ``open_label`` is 1 for query rows and 0 for open-set rows, and
        ``(query_probs, openset_probs)`` as numpy arrays.
    """
    with torch.no_grad():
        # Data Preparation
        (support_data, support_label, query_data, query_label,
         suppopen_data, suppopen_label, openset_data, openset_label,
         supp_idx, open_idx) = data

        # Data Conversion & Packaging
        support_data, support_label = support_data.float().cuda(), support_label.cuda().long()
        query_data, query_label = query_data.float().cuda(), query_label.cuda().long()
        suppopen_data, suppopen_label = suppopen_data.float().cuda(), suppopen_label.cuda().long()
        openset_data, openset_label = openset_data.float().cuda(), openset_label.cuda().long()
        supp_idx, open_idx = supp_idx.long(), open_idx.long()
        # Every open-set sample gets the extra class index ``n_ways``.
        openset_label = net.n_ways * torch.ones_like(openset_label)
        the_img = (support_data, query_data, suppopen_data, openset_data)
        the_label = (support_label, query_label, suppopen_label, openset_label)
        the_conj = (supp_idx, open_idx)

        # Tensor Input Preparation
        features, cls_protos, cosine_probs = net(the_img, the_label, the_conj, test=True)
        (supp_feat, query_feat, openset_feat) = features

        # Numpy Input Preparation
        cls_protos_numpy = F.normalize(cls_protos.view(-1, net.feat_dim), p=2, dim=-1).cpu().numpy()
        supplabel_numpy = support_label.view(supp_feat.shape[1:-1]).cpu().numpy()
        querylabel_numpy = query_label.view(-1).cpu().numpy()
        queryfeat_numpy = F.normalize(query_feat[0], p=2, dim=-1).cpu().numpy()
        openfeat_numpy = F.normalize(openset_feat[0], p=2, dim=-1).cpu().numpy()
        open_label = np.concatenate((np.ones(query_label.size(1)), np.zeros(openset_label.size(1))))

        # Numpy Probs Preparation
        query_cls_probs, openset_cls_probs = cosine_probs
        cosine_probs = (query_cls_probs[0].cpu().numpy(), openset_cls_probs[0].cpu().numpy())

        return ((cls_protos_numpy, queryfeat_numpy, openfeat_numpy),
                (supplabel_numpy, querylabel_numpy, open_label),
                cosine_probs)
def run_test_gfsl(net, genopenloader):
    """Evaluate ``net`` on generalized open-set few-shot episodes.

    Each episode provides base-class, novel-class (query) and open-set
    samples. The class scores returned by ``net`` are expected to list the
    base classes first, then the ``net.n_ways`` novel classes, with the
    open-set class in the last column.

    Args:
        net: model exposing ``n_ways``; called as
            ``net(images, labels, indices, None, True)``.
        genopenloader: iterable (with ``len``) of 12-tuples; batch size must be 1.

    Returns:
        Five ``mean_confidence_interval`` results, in this order: arithmetic
        mean of base/novel accuracy, harmonic mean of base/novel accuracy,
        delta between separate and joint accuracy, AUROC of the open-class
        probability, and macro F1 over all classes including the open class.
    """
    net = net.eval()
    acc_gen, acc_mean, acc_delta = [], [], []
    auroc_gen_prob, auroc_gen_diff, auroc_f1score = [], [], []
    n_ways = net.n_ways

    with torch.no_grad():
        with tqdm(genopenloader, total=len(genopenloader), leave=False) as pbar:
            for data in pbar:
                # Data Preparation
                (support_data, support_label, query_data, query_label,
                 suppopen_data, suppopen_label, openset_data, openset_label,
                 baseset_data, baseset_label, supp_idx, open_idx) = data

                num_query, num_open, num_base = query_label.size(1), openset_label.size(1), baseset_label.size(1)
                assert support_data.size(0) == 1
                num_base_cls, num_novel_cls = baseset_label.max().item() + 1, n_ways

                support_data, support_label = support_data.float().cuda(), support_label.cuda().long()
                query_data, query_label = query_data.float().cuda(), query_label.cuda().long()
                suppopen_data, suppopen_label = suppopen_data.float().cuda(), suppopen_label.cuda().long()
                openset_data, openset_label = openset_data.float().cuda(), openset_label.cuda().long()
                baseset_data, baseset_label = baseset_data.float().cuda(), baseset_label.cuda().long()
                supp_idx, open_idx = supp_idx.long().cuda(), open_idx.long().cuda()
                # Every open-set sample gets the extra class index ``n_ways``.
                openset_label = num_novel_cls * torch.ones_like(openset_label)

                the_img = (support_data, query_data, suppopen_data, openset_data, baseset_data)
                the_label = (support_label, query_label, suppopen_label, openset_label, baseset_label)
                num_baseclass = baseset_label.max() + 1
                # Class indices are shifted down by the number of base classes.
                the_conj = (supp_idx - num_baseclass, open_idx - num_baseclass)

                _, _, test_cls_scores = net(the_img, the_label, the_conj, None, True)
                (baseset_cls_scores, query_cls_scores, openset_cls_scores) = test_cls_scores

                # Rows: base samples, then novel queries, then open-set samples.
                scores_gen = torch.cat([baseset_cls_scores[0], query_cls_scores[0], openset_cls_scores[0]], dim=0)
                probs_gen_plus = F.softmax(scores_gen, dim=-1).cpu().numpy()
                probs_gen_max = F.softmax(scores_gen[:, :num_base_cls + num_novel_cls], dim=-1).cpu().numpy()

                novel_label = query_label.view(-1).cpu().numpy()
                base_label = baseset_label.view(-1).cpu().numpy()
                open_label_binary = np.concatenate((np.ones(num_base + num_query), np.zeros(num_open)))
                general_label = np.concatenate([base_label, novel_label + base_label.max() + 1], axis=0)

                base_rows = probs_gen_max[:num_base]
                novel_rows = probs_gen_max[num_base:num_base + num_query]

                # Joint accuracies: predictions compete over base + novel classes.
                ball = metrics.accuracy_score(general_label[:num_base], np.argmax(base_rows, -1))
                nall = metrics.accuracy_score(general_label[num_base:], np.argmax(novel_rows, -1))
                # Separate accuracies: each group only competes with its own classes.
                base_only = metrics.accuracy_score(base_label, np.argmax(base_rows[:, :-n_ways], -1))
                novel_only = metrics.accuracy_score(novel_label, np.argmax(novel_rows[:, -n_ways:], -1))

                pair_sum = ball + nall
                # Harmonic mean; defined as 0 when both accuracies are 0 to
                # avoid a division by zero.
                acc_gen.append(2 * ball * nall / pair_sum if pair_sum > 0 else 0.0)
                acc_mean.append(pair_sum / 2)  # arithmetic mean
                acc_delta.append(0.5 * (base_only + novel_only - ball - nall))

                # Open-set samples are the positives for both AUROC variants.
                open_prob = probs_gen_plus[:, -1]
                auroc_gen_prob.append(metrics.roc_auc_score(1 - open_label_binary, open_prob))
                auroc_gen_diff.append(metrics.roc_auc_score(
                    1 - open_label_binary, open_prob - probs_gen_plus[:, :-1].max(axis=-1)))
                # ``np.int`` was removed from NumPy; the builtin ``int`` is equivalent.
                all_labels = np.concatenate(
                    [general_label, (general_label.max() + 1) * np.ones(num_open)], -1).astype(int)
                ypred = np.argmax(probs_gen_plus, axis=-1)
                auroc_f1score.append(f1_score(all_labels, ypred, average='macro'))

                pbar.set_postfix({
                    "OpenSet MetaEval Acc": '{0:.2f}'.format(acc_gen[-1]),
                    "ROC": '{0:.2f}'.format(auroc_gen_diff[-1]),
                    "Gen Acc": '{0:.2f}'.format(acc_gen[-1])
                })

    return (mean_confidence_interval(acc_mean),
            mean_confidence_interval(acc_gen),
            mean_confidence_interval(acc_delta),
            mean_confidence_interval(auroc_gen_prob),
            mean_confidence_interval(auroc_f1score))


def mean_confidence_interval(data, confidence=0.95):
    """Return the mean and the confidence half-width of ``data``, both in percent.

    The values are scaled by 100, the half-width uses Student's t
    distribution with ``len(data) - 1`` degrees of freedom, and both results
    are rounded to three decimals. With fewer than two values the half-width
    is NaN.

    Args:
        data: sequence of numbers (fractions in [0, 1] in this file).
        confidence: two-sided confidence level.

    Returns:
        ``(mean, half_width)``.
    """
    a = 100.0 * np.array(data)
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    # Public ``ppf`` instead of the private ``_ppf`` helper.
    h = se * t.ppf((1 + confidence) / 2., n - 1)
    m = np.round(m, 3)
    h = np.round(h, 3)
    return m, h
-------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import numpy as np 5 | import torch.backends.cudnn as cudnn 6 | 7 | from dataloader.dataloader import get_dataloaders 8 | from architectures.NetworkPre import FeatureNet 9 | from architectures.GNetworkPre import GFeatureNet 10 | from trainer.FSEval import run_test_fsl 11 | from trainer.GFSEval import run_test_gfsl 12 | import pdb 13 | import logging 14 | 15 | 16 | 17 | model_pool = ['ResNet18','ResNet12','WRN28'] 18 | parser = argparse.ArgumentParser('argument for training') 19 | 20 | # General Setting 21 | parser.add_argument('--batch_size', type=int, default=64, help='batch_size') 22 | parser.add_argument('--num_workers', type=int, default=3, help='num of workers to use') 23 | parser.add_argument('--featype', type=str, default='OpenMeta', choices=['OpenMeta', 'GOpenMeta'], help='type of task: OpenMeta -- FSOR, GOpenMeta --- GFSOR') 24 | parser.add_argument('--restype', type=str, default='ResNet12', choices=model_pool, help='Network Structure') 25 | parser.add_argument('--dataset', type=str, default='miniImageNet', choices=['miniImageNet', 'tieredImageNet']) 26 | parser.add_argument('--gpus', type=str, default='0') 27 | 28 | # Specify folder 29 | parser.add_argument('--logroot', type=str, default='./logs/', help='path to save model') 30 | parser.add_argument('--data_root', type=str, default='data/', help='path to data root') 31 | parser.add_argument('--test_model_path', type=str, default='max_acc.pth') 32 | parser.add_argument('--pretrained_model_path', type=str, default='miniImageNet_pre.pth') 33 | 34 | # Meta Setting 35 | parser.add_argument('--n_ways', type=int, default=5, metavar='N', help='Number of classes for doing each classification run') 36 | parser.add_argument('--n_open_ways', type=int, default=5, metavar='N', help='Number of classes for doing each classification run') 37 | parser.add_argument('--n_shots', type=int, 
def eval(args, model, meta_test_loader, config):  # name shadows the builtin; kept for existing callers
    """Load the meta-trained weights into ``model``, run the evaluation and log the results.

    Args:
        args: parsed options; ``test_model_path``, ``featype`` and
            ``n_shots`` are read.
        model: network to evaluate; its weights are replaced by the
            ``'cls_params'`` entry of the checkpoint (strict matching).
        meta_test_loader: loader handed to the evaluation routine.
        config: evaluation config, only used when ``args.featype`` is
            ``'OpenMeta'``.

    Returns:
        The updated ``config`` for ``'OpenMeta'``, otherwise the 5-tuple
        returned by ``run_test_gfsl``.
    """
    # Load onto the CPU first so a checkpoint saved on a GPU can be read on any
    # host; ``load_state_dict`` then copies into the model's own device.
    checkpoint = torch.load(args.test_model_path, map_location='cpu')
    model.load_state_dict(checkpoint['cls_params'], strict=True)

    model.eval()
    logging.info('Loaded Model Weight from %s' % args.test_model_path)

    if args.featype == 'OpenMeta':
        config = run_test_fsl(model, meta_test_loader, config)
        logging.info('Result for %d-shot:' % (args.n_shots))
        for key, value in config.items():
            if key == 'data':
                # Metric summaries are logged one level deeper.
                for metric_name, metric_value in value.items():
                    logging.info('\t\t{}: {}'.format(metric_name, metric_value))
            else:
                logging.info('\t{}: {}'.format(key, value))
        return config

    result = run_test_gfsl(model, meta_test_loader)
    logging.info('Result for %d-shot:' % (args.n_shots))
    logging.info('\t Arithmetic Mean: {}'.format(result[0]))
    logging.info('\t Harmonic Mean: {}'.format(result[1]))
    logging.info('\t Delta: {}'.format(result[2]))
    logging.info('\t AUROC: {}'.format(result[3]))
    logging.info('\t F1: {}'.format(result[4]))
    return result
class FeatureNet(nn.Module):
    """Feature extractor plus open-set few-shot classifier.

    ``forward`` either embeds a plain image batch (``labels is None``) or
    processes a whole open-set episode (see ``open_forward``).
    """

    def __init__(self, args, restype, n_class, param_seman):
        super(FeatureNet, self).__init__()
        self.args = args
        self.restype = restype
        self.n_class = n_class
        self.featype = args.featype
        self.n_ways = args.n_ways
        self.tunefeat = args.tunefeat
        # Episode-local class indices 0..n_ways-1. They live on the GPU when
        # one is available, so construction no longer fails on CPU-only hosts.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.distance_label = torch.arange(self.n_ways, dtype=torch.long, device=device)
        self.metric = Metric_Cosine()

        self.feature = create_feature_extractor(restype, args.dataset)
        self.feat_dim = self.feature.out_dim

        # Only the open-set meta setting is supported by this class.
        assert 'OpenMeta' in self.featype
        self.cls_classifier = Classifier(args, self.feat_dim, param_seman, args.train_weight_base)

        if self.tunefeat == 0.0:
            # Backbone fully frozen.
            self._freeze(self.feature)
        else:
            # Freeze the earliest stages; a smaller ``tune_part`` freezes more.
            if args.tune_part <= 3:
                self._freeze(self.feature.layer1)
            if args.tune_part <= 2:
                self._freeze(self.feature.layer2)
            if args.tune_part <= 1:
                self._freeze(self.feature.layer3)

    @staticmethod
    def _freeze(module):
        """Disable gradients for every parameter of ``module``."""
        for param in module.parameters():
            param.requires_grad = False

    def forward(self, the_img, labels=None, conj_ids=None, base_ids=None, test=False):
        """Embed a 4-D image batch, or run a full episode when ``labels`` is given.

        Returns:
            ``(features, None)`` when ``labels`` is ``None``; otherwise the
            result of ``open_forward``.
        """
        if labels is None:
            assert the_img.dim() == 4
            return (self.feature(the_img), None)
        return self.open_forward(the_img, labels, conj_ids, base_ids, test)

    def open_forward(self, the_input, labels, conj_ids, base_ids, test):
        """Process one batch of open-set episodes.

        Args:
            the_input: 4-tuple of 5-D image tensors ``(support, query,
                support_open, openset)`` sharing all but their second dimension.
            labels: matching 4-tuple of label tensors.
            conj_ids: ``(supp_idx, open_idx)`` class-index tensors forwarded
                to the classifier.
            base_ids: forwarded to the classifier unchanged.
            test: when true, only the first task is run and no loss is returned.

        Returns:
            ``(test_feats, cls_protos, test_cls_probs)`` in test mode, where
            the support features are reshaped to ``(ne, n_ways, -1, feat_dim)``;
            otherwise ``(test_feats, cls_protos, test_cls_probs, loss)`` with
            ``loss = (classification, open_hinge, fake_unit)``.
        """
        # Hyper-parameter Preparation
        the_sizes = [part.size(1) for part in the_input]
        (ne, _, nc, nh, nw) = the_input[0].size()

        # Data Preparation: embed all images in a single backbone pass.
        combined_data = torch.cat(the_input, dim=1).view(-1, nc, nh, nw)
        if not self.tunefeat:
            with torch.no_grad():
                combined_feat = self.feature(combined_data).detach()
        else:
            combined_feat = self.feature(combined_data)
        support_feat, query_feat, supopen_feat, openset_feat = torch.split(
            combined_feat.view(ne, -1, self.feat_dim), the_sizes, dim=1)
        (_, query_label, _, openset_label) = labels
        (supp_idx, open_idx) = conj_ids
        cls_label = torch.cat([query_label, openset_label], dim=1)
        test_feats = (support_feat, query_feat, openset_feat)

        ### First Task
        support_feat = support_feat.view(ne, self.n_ways, -1, self.feat_dim)
        test_cosine_scores, supp_protos, fakeclass_protos, loss_cls, loss_funit = self.task_proto(
            (support_feat, query_feat, openset_feat), (supp_idx, base_ids), cls_label, test)
        cls_protos = torch.cat([supp_protos, fakeclass_protos], dim=1)
        test_cls_probs = self.task_pred(test_cosine_scores[0], test_cosine_scores[1])

        if test:
            test_feats = (support_feat, query_feat, openset_feat)
            return test_feats, cls_protos, test_cls_probs

        ## Second task: the open-set classes act as the known classes.
        # NOTE(review): the features are passed as (openset, query) while
        # ``cls_label`` stays ordered as (query, openset) - confirm this is intended.
        supopen_feat = supopen_feat.view(ne, self.n_ways, -1, self.feat_dim)
        _, _, _, loss_cls_aug, loss_funit_aug = self.task_proto(
            (supopen_feat, openset_feat, query_feat), (open_idx, base_ids), cls_label, test)

        # The hinge term is currently disabled and contributes a constant zero.
        loss_open_hinge = 0.0

        loss = (loss_cls + loss_cls_aug, loss_open_hinge, loss_funit + loss_funit_aug)
        return test_feats, cls_protos, test_cls_probs, loss

    def task_proto(self, features, cls_ids, cls_label, test=False):
        """Run the classifier on one task and compute its two losses.

        Returns:
            ``(test_cosine_scores, supp_protos, fakeclass_protos, loss_cls,
            fakeunit_loss)`` where ``test_cosine_scores`` is the
            ``(query, openset)`` score pair returned by the classifier.
        """
        test_cosine_scores, supp_protos, fakeclass_protos, _, funit_distance = self.cls_classifier(
            features, cls_ids, test)
        (query_cls_scores, openset_cls_scores) = test_cosine_scores
        cls_scores = torch.cat([query_cls_scores, openset_cls_scores], dim=1)
        fakeunit_loss = fakeunit_compare(funit_distance, self.n_ways, cls_label)
        # (n_ways + 1)-way cross entropy over query and open-set samples.
        loss_cls = F.cross_entropy(cls_scores.view(-1, self.n_ways + 1), cls_label.reshape(-1))
        return test_cosine_scores, supp_protos, fakeclass_protos, loss_cls, fakeunit_loss

    def task_pred(self, query_cls_scores, openset_cls_scores, many_cls_scores=None):
        """Turn detached class scores into probabilities with a softmax over the last dim.

        Returns:
            ``(query_probs, openset_probs)``; when ``many_cls_scores`` is
            given, ``(query_probs, openset_probs, many_probs, query_scores,
            openset_scores)``.
        """
        query_cls_probs = F.softmax(query_cls_scores.detach(), dim=-1)
        openset_cls_probs = F.softmax(openset_cls_scores.detach(), dim=-1)
        if many_cls_scores is None:
            return (query_cls_probs, openset_cls_probs)
        many_cls_probs = F.softmax(many_cls_scores.detach(), dim=-1)
        return (query_cls_probs, openset_cls_probs, many_cls_probs, query_cls_scores, openset_cls_scores)


class Metric_Cosine(nn.Module):
    """Cosine similarity between queries and class centers, scaled by a learnable temperature."""

    def __init__(self, temperature=10):
        super(Metric_Cosine, self).__init__()
        self.temp = nn.Parameter(torch.tensor(float(temperature)))

    def forward(self, supp_center, query_feature):
        """Return the scaled cosine logits.

        Args:
            supp_center: bs * nway * D class centers.
            query_feature: bs * (nway * nquery) * D query features.

        Returns:
            Logits shaped bs * (nway * nquery) * nway.
        """
        supp_center = F.normalize(supp_center, dim=-1)  # eps=1e-6 default 1e-12
        query_feature = F.normalize(query_feature, dim=-1)
        logits = torch.bmm(query_feature, supp_center.transpose(1, 2))
        return logits * self.temp


def fakeunit_compare(funit_distance, n_ways, cls_label):
    """Binary cross entropy between fake-unit logits and closed-set membership.

    Args:
        funit_distance: logits shaped ``(batch, samples, n_ways)``.
        n_ways: number of closed-set classes; label ``n_ways`` marks the open class.
        cls_label: integer labels shaped ``(batch, samples)`` with values in
            ``[0, n_ways]``.

    Returns:
        Scalar tensor: the mean element-wise loss. Open-class samples have an
        all-zero target row.
    """
    # Fix the one-hot width explicitly; inferring it from the largest label
    # produces the wrong target shape whenever the open class is absent.
    cls_label_binary = F.one_hot(cls_label, num_classes=n_ways + 1)[:, :, :-1].float()
    return F.binary_cross_entropy_with_logits(input=funit_distance, target=cls_label_binary)
transforms.ToTensor(), 33 | normalize 34 | ]) 35 | else: 36 | self.transform = transforms.Compose([ 37 | transforms.RandomCrop(84, padding=8), 38 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 39 | transforms.RandomHorizontalFlip(), 40 | transforms.ToTensor(), 41 | normalize 42 | ]) 43 | else: 44 | self.transform = transforms.Compose([transforms.ToTensor(),normalize]) 45 | 46 | filename = 'miniImageNet_category_split_train_phase_{}.pickle'.format(partition) 47 | self.data = {} 48 | with open(os.path.join(args.data_root, filename), 'rb') as f: 49 | pack = pickle.load(f, encoding='latin1') 50 | imgs = pack['data'].astype('uint8') 51 | labels = pack['labels'] 52 | self.imgs = [Image.fromarray(x) for x in imgs] 53 | min_label = min(labels) 54 | self.labels = [x - min_label for x in labels] 55 | print('Load {} Data of {} for miniImagenet in Pretraining Stage'.format(len(self.imgs), partition)) 56 | 57 | def __getitem__(self, item): 58 | img = self.transform(self.imgs[item]) 59 | target = self.labels[item] 60 | return img, target, item 61 | 62 | def __len__(self): 63 | return len(self.labels) 64 | 65 | def random_idx(self): 66 | self.rand_idx = {k:np.random.permutation(len(v)) for k,v in self.data.items()} 67 | 68 | 69 | class MetaMini(Dataset): 70 | def __init__(self, args, n_shots, partition='test', is_training=False, fix_seed=True): 71 | super(MetaMini, self).__init__() 72 | self.fix_seed = fix_seed 73 | self.n_ways = args.n_ways 74 | self.n_shots = n_shots 75 | self.n_queries = args.n_queries 76 | self.n_episodes = args.n_episodes 77 | self.n_aug_support_samples = args.n_aug_support_samples 78 | 79 | mean = [120.39586422 / 255.0, 115.59361427 / 255.0, 104.54012653 / 255.0] 80 | std = [70.68188272 / 255.0, 68.27635443 / 255.0, 72.54505529 / 255.0] 81 | normalize = transforms.Normalize(mean=mean, std=std) 82 | 83 | if is_training: 84 | self.train_transform = transforms.Compose([ 85 | transforms.RandomCrop(84, padding=8), 86 | 
transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 87 | transforms.RandomHorizontalFlip(), 88 | transforms.ToTensor(), 89 | normalize 90 | ]) 91 | else: 92 | self.train_transform = transforms.Compose([ 93 | transforms.RandomCrop(84, padding=8), 94 | transforms.RandomHorizontalFlip(), 95 | transforms.ToTensor(), 96 | normalize 97 | ]) 98 | 99 | self.test_transform = transforms.Compose([transforms.ToTensor(),normalize]) 100 | 101 | suffix = partition if partition in ['val','test'] else 'train_phase_train' 102 | filename = 'miniImageNet_category_split_{}.pickle'.format(suffix) 103 | self.data = {} 104 | with open(os.path.join(args.data_root, filename), 'rb') as f: 105 | pack = pickle.load(f, encoding='latin1') 106 | imgs = pack['data'].astype('uint8') 107 | labels = pack['labels'] 108 | self.imgs = [Image.fromarray(x) for x in imgs] 109 | min_label = min(labels) 110 | self.labels = [x - min_label for x in labels] 111 | print('Load {} Data of {} for miniImagenet in Meta-Learning Stage'.format(len(self.imgs), partition)) 112 | 113 | self.data = {} 114 | for idx in range(len(self.imgs)): 115 | if self.labels[idx] not in self.data: 116 | self.data[self.labels[idx]] = [] 117 | self.data[self.labels[idx]].append(self.imgs[idx]) 118 | self.classes = list(self.data.keys()) 119 | 120 | def __getitem__(self, item): 121 | if self.fix_seed: 122 | np.random.seed(item) 123 | cls_sampled = np.random.choice(self.classes, self.n_ways, False) 124 | support_xs = [] 125 | support_ys = [] 126 | query_xs = [] 127 | query_ys = [] 128 | for idx, the_cls in enumerate(cls_sampled): 129 | imgs = self.data[the_cls] 130 | support_xs_ids_sampled = np.random.choice(range(len(imgs)), self.n_shots, False) 131 | support_xs.extend([imgs[the_id] for the_id in support_xs_ids_sampled]) 132 | support_ys.extend([idx] * self.n_shots) 133 | query_xs_ids = np.setxor1d(np.arange(len(imgs)), support_xs_ids_sampled) 134 | query_xs_ids = np.random.choice(query_xs_ids, self.n_queries, False) 135 
| query_xs.extend([imgs[the_id] for the_id in query_xs_ids]) 136 | query_ys.extend([idx] * query_xs_ids.shape[0]) 137 | 138 | if self.n_aug_support_samples > 1: 139 | support_xs = support_xs * self.n_aug_support_samples 140 | support_ys = support_ys * self.n_aug_support_samples 141 | 142 | support_xs = torch.stack(list(map(lambda x: self.train_transform(x), support_xs))) 143 | query_xs = torch.stack(list(map(lambda x: self.test_transform(x), query_xs))) 144 | support_ys,query_ys = np.array(support_ys),np.array(query_ys) 145 | 146 | return support_xs, support_ys, query_xs, query_ys 147 | 148 | def __len__(self): 149 | return self.n_episodes -------------------------------------------------------------------------------- /architectures/GNetworkPre.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | import math 7 | import pdb 8 | from architectures.ResNetFeat import create_feature_extractor 9 | from architectures.GAttnClassifier import GClassifier 10 | 11 | 12 | 13 | class GFeatureNet(nn.Module): 14 | def __init__(self,args,restype,n_class,param_seman): 15 | super(GFeatureNet,self).__init__() 16 | self.args = args 17 | self.restype = restype 18 | self.n_class = n_class 19 | self.featype = args.featype 20 | self.n_ways = args.n_ways 21 | self.tunefeat = args.tunefeat 22 | self.distance_label = torch.Tensor([i for i in range(self.n_ways)]).cuda().long() 23 | self.metric = Metric_Cosine() 24 | 25 | self.feature = create_feature_extractor(restype,args.dataset) 26 | self.feat_dim = self.feature.out_dim 27 | 28 | 29 | self.cls_classifier = GClassifier(args, self.feat_dim, param_seman, args.train_weight_base) if 'GOpenMeta' in self.featype else nn.Linear(self.feat_dim, n_class) 30 | 31 | assert 'GOpenMeta' in self.featype 32 | if self.tunefeat == 0.0: 33 | for _,p in self.feature.named_parameters(): 34 | p.requires_grad=False 35 | 
def forward(self, the_img, labels=None, conj_ids=None, base_ids=None, test=False):
    """Dispatch between plain feature extraction and the episodic forward.

    Without ``labels`` the input must be a 4-D tensor; it is passed through
    the feature extractor and ``(features, None)`` is returned. With
    ``labels`` all arguments are forwarded unchanged to
    ``gen_open_forward`` and its result is returned.
    """
    if labels is not None:
        return self.gen_open_forward(the_img, labels, conj_ids, base_ids, test)
    assert the_img.dim() == 4
    feats = self.feature(the_img)
    return (feats, None)
def gen_task_proto(self, features, cls_ids, cls_label, num_baseclass, test=False):
    """Run the classifier on one task and compute its two losses.

    Args:
        features: feature tuple handed to ``self.cls_classifier``.
        cls_ids: class-id tuple handed to ``self.cls_classifier``.
        cls_label: label tensor; after flattening it is the target of the
            cross-entropy over the concatenated scores.
        num_baseclass: number of base classes; offsets the labels and sets
            the width of the flattened score matrix
            (``num_baseclass + n_ways + 1``).
        test: forwarded to the classifier.

    Returns:
        A 6-tuple ``(test_cosine_scores, supp_protos, fakeclass_protos,
        base_centers, loss_cls, fakeunit_loss)``. When the classifier
        returns no ``fakeclass_protos`` the last four entries are ``None``.
    """
    test_cosine_scores, supp_protos, fakeclass_protos, base_weights, funit_distance = self.cls_classifier(features, cls_ids, test)
    if fakeclass_protos is None:
        # Keep the arity identical to the normal path; this branch used to
        # return four values while the caller unpacks six.
        return test_cosine_scores, supp_protos, None, None, None, None
    base_centers, _ = base_weights

    cls_scores = torch.cat(test_cosine_scores, dim=1).view(-1, num_baseclass + self.n_ways + 1)
    # Drop the leading columns covered by the first score block and shift
    # the remaining labels back by the number of base classes.
    novel_label = cls_label[:, test_cosine_scores[0].size(1):] - num_baseclass
    fakeunit_loss = fakeunit_compare(funit_distance, self.n_ways, novel_label)

    loss_cls = F.cross_entropy(cls_scores, cls_label.view(-1))
    return test_cosine_scores, supp_protos, fakeclass_protos, base_centers, loss_cls, fakeunit_loss
def fakeunit_compare(funit_distance, n_ways, cls_label):
    """Binary cross-entropy between per-class logits and one-hot labels.

    ``cls_label`` is one-hot encoded over ``n_ways + 1`` classes and the
    last column (label ``n_ways``) is dropped, so samples carrying that
    label get an all-zero target row.

    Args:
        funit_distance: logits whose shape must equal ``cls_label.shape +
            (n_ways,)``.
        n_ways: number of closed-set classes in the task.
        cls_label: integer label tensor with values in ``0 .. n_ways``.

    Returns:
        Scalar loss tensor (mean binary cross-entropy over all elements).
    """
    # Fix the one-hot width explicitly. Letting one_hot infer it from the
    # batch maximum gives too few columns when no sample has label n_ways.
    cls_label_binary = F.one_hot(cls_label, num_classes=n_ways + 1)[..., :-1].float()
    loss = torch.sum(F.binary_cross_entropy_with_logits(input=funit_distance, target=cls_label_binary))
    return loss
class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""

    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        Raises:
            ValueError: if `features` has fewer than 3 dimensions, if both
                `labels` and `mask` are given, if the number of labels does
                not match the batch size, or if `contrast_mode` is unknown.
        """
        # Build helper tensors on the device that actually holds `features`.
        device = features.device

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(torch.matmul(anchor_feature, contrast_feature.T), self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive
        # An anchor without any positive has an all-zero mask row; divide by
        # 1 there so it contributes 0 instead of 0/0 = NaN.
        pos_per_anchor = mask.sum(1)
        pos_per_anchor = torch.where(pos_per_anchor < 1e-6,
                                     torch.ones_like(pos_per_anchor),
                                     pos_per_anchor)
        mean_log_prob_pos = (mask * log_prob).sum(1) / pos_per_anchor

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss
class PreCIFAR(Dataset):
    """CIFAR-style dataset for the pre-training stage.

    Reads ``{partition}.pickle`` from ``args.data_root``, remaps the raw
    labels to consecutive integers in order of first appearance, and serves
    ``(img, target, index)`` tuples, or ``(left, right, target, index)``
    when two-view mode is active (``is_training and is_contrast``).
    """

    def __init__(self, args, partition='train', is_training=True, is_contrast=False):
        super(PreCIFAR, self).__init__()
        normalize = transforms.Normalize(mean=[0.5071, 0.4867, 0.4408],
                                         std=[0.2675, 0.2565, 0.2761])
        self.is_contrast = is_training and is_contrast

        if not is_training:
            # Evaluation: no augmentation at all.
            self.transform = transforms.Compose([transforms.ToTensor(), normalize])
        elif is_contrast:
            # Two independent augmentation pipelines, one per view.
            self.transform_left = transforms.Compose([
                transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
                transforms.RandomGrayscale(p=0.2),
                transforms.ToTensor(),
                normalize,
            ])
            self.transform_right = transforms.Compose([
                transforms.RandomRotation(5),
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
        else:
            self.transform = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])

        self.data = {}
        pickle_path = os.path.join(args.data_root, '{}.pickle'.format(partition))
        with open(pickle_path, 'rb') as f:
            pack = pickle.load(f, encoding='latin1')
        raw_imgs = pack['data']
        raw_labels = pack['labels']

        # Map every raw label to the order in which it first appears.
        remap = {}
        for raw in raw_labels:
            if raw not in remap:
                remap[raw] = len(remap)
        self.labels = [remap[raw] for raw in raw_labels]

        self.imgs = [Image.fromarray(arr) for arr in raw_imgs]
        print('Load {} Data of {} for {} in Pretraining Stage'.format(len(self.imgs), partition, args.dataset))

    def __getitem__(self, item):
        """Return one sample; two views of it when two-view mode is active."""
        if not self.is_contrast:
            img = self.transform(self.imgs[item])
            return img, self.labels[item], item
        source = self.imgs[item]
        left = self.transform_left(source)
        right = self.transform_right(source)
        return left, right, self.labels[item], item

    def __len__(self):
        return len(self.labels)
113 | self.data = {} 114 | with open(os.path.join(args.data_root, filename), 'rb') as f: 115 | pack = pickle.load(f, encoding='latin1') 116 | self.imgs = pack['data'] 117 | labels = pack['labels'] 118 | 119 | cur_class = 0 120 | label2label = {} 121 | for _, label in enumerate(labels): 122 | if label not in label2label: 123 | label2label[label] = cur_class 124 | cur_class += 1 125 | new_labels = [] 126 | for idx, label in enumerate(labels): 127 | new_labels.append(label2label[label]) 128 | self.labels = new_labels 129 | 130 | self.imgs = [Image.fromarray(x) for x in self.imgs] 131 | print('Load {} Data of {} for {} in Meta-Learning Stage'.format(len(self.imgs), partition, args.dataset)) 132 | 133 | self.data = {} 134 | for idx in range(len(self.imgs)): 135 | if self.labels[idx] not in self.data: 136 | self.data[self.labels[idx]] = [] 137 | self.data[self.labels[idx]].append(self.imgs[idx]) 138 | self.classes = list(self.data.keys()) 139 | 140 | def __getitem__(self, item): 141 | if self.fix_seed: 142 | np.random.seed(item) 143 | cls_sampled = np.random.choice(self.classes, self.n_ways, False) 144 | support_xs = [] 145 | support_ys = [] 146 | query_xs = [] 147 | query_ys = [] 148 | for idx, the_cls in enumerate(cls_sampled): 149 | imgs = self.data[the_cls] 150 | support_xs_ids_sampled = np.random.choice(range(len(imgs)), self.n_shots, False) 151 | support_xs.extend([imgs[the_id] for the_id in support_xs_ids_sampled]) 152 | support_ys.extend([idx] * self.n_shots) 153 | query_xs_ids = np.setxor1d(np.arange(len(imgs)), support_xs_ids_sampled) 154 | query_xs_ids = np.random.choice(query_xs_ids, self.n_queries, False) 155 | query_xs.extend([imgs[the_id] for the_id in query_xs_ids]) 156 | query_ys.extend([idx] * query_xs_ids.shape[0]) 157 | 158 | if self.n_aug_support_samples > 1: 159 | support_xs = support_xs * self.n_aug_support_samples 160 | support_ys = support_ys * self.n_aug_support_samples 161 | 162 | support_xs = torch.stack(list(map(lambda x: 
def load_labels(file):
    """Load a pickled object, falling back to ``latin1`` decoding.

    The file is first read with pickle's default settings. If that fails
    with ``UnicodeDecodeError``, it is read again with
    ``encoding='latin1'``.

    NOTE: pickle can execute arbitrary code while loading; only use this on
    trusted files.

    Args:
        file: path of the pickle file.

    Returns:
        The unpickled object.

    Raises:
        OSError: if the file cannot be opened.
    """
    try:
        with open(file, 'rb') as fo:
            return pickle.load(fo)
    except UnicodeDecodeError:
        # Only a decoding failure is retried; other errors (missing file,
        # corrupt data) are not fixed by a second attempt and propagate.
        with open(file, 'rb') as fo:
            return pickle.load(fo, encoding='latin1')
class MetaTiered(Dataset):
    """Episodic tieredImageNet dataset for the meta-learning stage.

    Loads ``{partition}_images.npz`` and ``{partition}_labels.pkl`` from
    ``args.data_root``, groups the images by label (shifted so the smallest
    label is 0) and serves ``args.n_episodes`` few-shot episodes.
    """

    def __init__(self, args, n_shots, partition='test', is_training=False, fix_seed=True):
        super(MetaTiered, self).__init__()
        self.fix_seed = fix_seed
        self.n_ways = args.n_ways
        self.n_shots = n_shots
        self.n_queries = args.n_queries
        self.n_episodes = args.n_episodes
        self.n_aug_support_samples = args.n_aug_support_samples

        normalize = transforms.Normalize(
            mean=[120.39586422 / 255.0, 115.59361427 / 255.0, 104.54012653 / 255.0],
            std=[70.68188272 / 255.0, 68.27635443 / 255.0, 72.54505529 / 255.0],
        )

        # Support images are always cropped and flipped; colour jitter is
        # added only when training.
        support_steps = [transforms.RandomCrop(84, padding=8)]
        if is_training:
            support_steps.append(
                transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4))
        support_steps += [transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize]
        self.train_transform = transforms.Compose(support_steps)
        self.test_transform = transforms.Compose([transforms.ToTensor(), normalize])

        # modified code to load tieredImageNet
        image_path = os.path.join(args.data_root, '{}_images.npz'.format(partition))
        label_path = os.path.join(args.data_root, '{}_labels.pkl'.format(partition))
        raw_imgs = np.load(image_path)['images']
        raw_labels = load_labels(label_path)['labels']

        self.imgs = [Image.fromarray(arr) for arr in raw_imgs]
        offset = min(raw_labels)
        self.labels = [lbl - offset for lbl in raw_labels]
        print('Load {} Data of {} for tieredImageNet in Meta-Learning Stage'.format(len(self.imgs), partition))

        # Group images by their shifted label.
        self.data = {}
        for pos, img in enumerate(self.imgs):
            self.data.setdefault(self.labels[pos], []).append(img)
        self.classes = list(self.data.keys())

    def __getitem__(self, item):
        """Sample one episode.

        With ``fix_seed`` set, NumPy's global RNG is re-seeded with ``item``
        before sampling.

        Returns:
            ``(support_xs, support_ys, query_xs, query_ys)``: stacked image
            tensors and NumPy label arrays whose values are each sample's
            position among the sampled classes.
        """
        if self.fix_seed:
            np.random.seed(item)
        episode_classes = np.random.choice(self.classes, self.n_ways, False)

        shot_imgs, shot_labels = [], []
        qry_imgs, qry_labels = [], []
        for way, cls_id in enumerate(episode_classes):
            pool = self.data[cls_id]
            shot_ids = np.random.choice(range(len(pool)), self.n_shots, False)
            shot_imgs += [pool[i] for i in shot_ids]
            shot_labels += [way] * self.n_shots
            # Queries come from the images not picked as support.
            leftover_ids = np.setxor1d(np.arange(len(pool)), shot_ids)
            qry_ids = np.random.choice(leftover_ids, self.n_queries, False)
            qry_imgs += [pool[i] for i in qry_ids]
            qry_labels += [way] * qry_ids.shape[0]

        repeats = self.n_aug_support_samples
        if repeats > 1:
            shot_imgs = shot_imgs * repeats
            shot_labels = shot_labels * repeats

        support_xs = torch.stack([self.train_transform(img) for img in shot_imgs])
        query_xs = torch.stack([self.test_transform(img) for img in qry_imgs])
        return support_xs, np.array(shot_labels), query_xs, np.array(qry_labels)

    def __len__(self):
        return self.n_episodes
def get_episode(self, item):
    """Sample one open-set few-shot episode.

    ``n_ways`` closed-set classes and ``n_open_ways`` disjoint open-set
    classes are drawn. For every class ``n_shots`` support images and
    ``n_queries`` query images are sampled without replacement; support
    images go through ``train_transform`` and queries through
    ``test_transform``.

    Args:
        item: Episode index. When ``self.fix_seed`` is set it seeds NumPy's
            global RNG, so the same index yields the same episode.

    Returns:
        ``(support_xs, support_ys, query_xs, query_ys, suppopen_xs,
        suppopen_ys, openset_xs, openset_ys, cls_sampled, cls_open_ids)``,
        followed by ``base_ids`` (sorted ids of all remaining classes) when
        ``self.partition == 'train'``. ``*_xs`` are stacked tensors, the
        rest are NumPy arrays. Closed/open support labels are episode-local
        indices; ``openset_ys`` holds the sampled class ids.
    """
    if self.fix_seed:
        np.random.seed(item)

    def repeat_per_class(samples):
        """Repeat each class's run of ``n_shots`` entries ``n_aug`` times,
        keeping the entries of one class contiguous."""
        if self.n_aug_support_samples <= 1:
            return samples
        repeated = []
        for start in range(0, len(samples), self.n_shots):
            repeated.extend(samples[start:start + self.n_shots] * self.n_aug_support_samples)
        return repeated

    cls_sampled = np.random.choice(self.classes, self.n_ways, False)
    support_xs, support_ys = [], []
    suppopen_xs, suppopen_ys = [], []
    query_xs, query_ys = [], []
    openset_xs, openset_ys = [], []

    # Closed set: queries are drawn from the images not used as support.
    for idx, the_cls in enumerate(cls_sampled):
        imgs = self.data[the_cls]
        support_ids = np.random.choice(range(len(imgs)), self.n_shots, False)
        support_xs.extend([imgs[the_id] for the_id in support_ids])
        support_ys.extend([idx] * self.n_shots)
        query_pool = np.setxor1d(np.arange(len(imgs)), support_ids)
        query_ids = np.random.choice(query_pool, self.n_queries, False)
        query_xs.extend([imgs[the_id] for the_id in query_ids])
        query_ys.extend([idx] * self.n_queries)

    # Open set: classes not sampled above.
    # NOTE(review): arange(len(classes)) assumes class ids are 0..C-1.
    cls_open_ids = np.setxor1d(np.arange(len(self.classes)), cls_sampled)
    cls_open_ids = np.random.choice(cls_open_ids, self.n_open_ways, False)
    for idx, the_cls in enumerate(cls_open_ids):
        imgs = self.data[the_cls]
        suppopen_ids = np.random.choice(range(len(imgs)), self.n_shots, False)
        suppopen_xs.extend([imgs[the_id] for the_id in suppopen_ids])
        suppopen_ys.extend([idx] * self.n_shots)
        # Fix: sample open queries from the pool that excludes the open
        # support images (the pool used to be computed and then ignored).
        # NOTE(review): this changes the RNG stream, so fixed-seed episodes
        # differ from those produced before the fix.
        openset_pool = np.setxor1d(np.arange(len(imgs)), suppopen_ids)
        openset_ids = np.random.choice(openset_pool, self.n_queries, False)
        openset_xs.extend([imgs[the_id] for the_id in openset_ids])
        openset_ys.extend([the_cls] * self.n_queries)

    if self.partition == 'train':
        base_ids = np.setxor1d(np.arange(len(self.classes)), np.concatenate([cls_sampled, cls_open_ids]))
        assert len(set(base_ids).union(set(cls_open_ids)).union(set(cls_sampled))) == len(self.classes)
        base_ids = np.array(sorted(base_ids))

    # Fix: the open-support lists are now chunked by their own length; the
    # old code iterated over the (already augmented) closed-support length
    # and dropped open classes when n_open_ways > n_ways * n_aug.
    support_xs, support_ys = repeat_per_class(support_xs), repeat_per_class(support_ys)
    suppopen_xs, suppopen_ys = repeat_per_class(suppopen_xs), repeat_per_class(suppopen_ys)

    support_xs = torch.stack([self.train_transform(x) for x in support_xs])
    suppopen_xs = torch.stack([self.train_transform(x) for x in suppopen_xs])
    query_xs = torch.stack([self.test_transform(x) for x in query_xs])
    openset_xs = torch.stack([self.test_transform(x) for x in openset_xs])
    support_ys, query_ys, openset_ys = np.array(support_ys), np.array(query_ys), np.array(openset_ys)
    suppopen_ys = np.array(suppopen_ys)
    cls_sampled, cls_open_ids = np.array(cls_sampled), np.array(cls_open_ids)

    if self.partition == 'train':
        return support_xs, support_ys, query_xs, query_ys, suppopen_xs, suppopen_ys, openset_xs, openset_ys, cls_sampled, cls_open_ids, base_ids
    return support_xs, support_ys, query_xs, query_ys, suppopen_xs, suppopen_ys, openset_xs, openset_ys, cls_sampled, cls_open_ids
def __init__(self, args, dataset_trainer):
    """Prepare output directories, the backbone, its optimizer and the losses.

    Args:
        args: Options namespace. It is modified in place: ``logroot`` gets
            ``featype`` appended, ``lr_decay_epochs`` becomes a list of ints
            and ``model_name`` is added.
        dataset_trainer: ``(train_loader, val_loader, n_cls)`` triple.
    """
    args.logroot = os.path.join(args.logroot, args.featype)
    # exist_ok removes the check-then-create race when several jobs start
    # at the same time.
    os.makedirs(args.logroot, exist_ok=True)

    # Parse the comma-separated decay schedule. An already-parsed sequence is
    # accepted too, so building a second trainer from the same ``args`` does
    # not crash on ``.split``.
    if isinstance(args.lr_decay_epochs, str):
        args.lr_decay_epochs = [int(it) for it in args.lr_decay_epochs.split(',')]
    else:
        args.lr_decay_epochs = [int(it) for it in args.lr_decay_epochs]

    args.model_name = model_name(args)
    self.save_path = os.path.join(args.logroot, args.model_name)
    os.makedirs(self.save_path, exist_ok=True)

    self.args = args
    self.train_loader, self.val_loader, self.n_cls = dataset_trainer

    # model & optimizer
    self.model = Backbone(args, args.restype, self.n_cls)
    if self.args.restype in ['ViT', 'Swin']:
        # NOTE(review): transformer backbones use a fixed AdamW setup and
        # ignore args.learning_rate / args.weight_decay — confirm intended.
        self.optimizer = optim.AdamW(self.model.parameters(), eps=1e-8, betas=(0.9, 0.999), lr=5e-4, weight_decay=0.05)
    else:
        self.optimizer = optim.SGD(self.model.parameters(), lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=False)
    self.criterion = {
        'feat': SupConLoss(temperature=args.temp),
        'logit': nn.BCEWithLogitsLoss() if self.args.use_bce else nn.CrossEntropyLoss(),
    }

    if torch.cuda.is_available():
        if args.n_gpu > 1:
            self.model = nn.DataParallel(self.model)
        self.model = self.model.cuda()
        self.criterion = {name: loss.cuda() for name, loss in self.criterion.items()}
        cudnn.benchmark = True
(epoch % 10 == 0 or epoch > 55): 98 | start = time.time() 99 | eval_1shot_loader,eval_5shot_loader = eval_loader 100 | meta_1shot_acc, meta_1shot_std = meta_evaluation(self.model, eval_1shot_loader) 101 | meta_5shot_acc, meta_5shot_std = meta_evaluation(self.model, eval_5shot_loader) 102 | test_time = time.time() - start 103 | writer.add_scalar('MetaAcc/1shot', float(meta_1shot_acc), epoch) 104 | writer.add_scalar('MetaStd/1shot', float(meta_1shot_std), epoch) 105 | writer.add_scalar('MetaAcc/5shot', float(meta_5shot_acc), epoch) 106 | writer.add_scalar('MetaStd/5shot', float(meta_5shot_std), epoch) 107 | meta_msg = 'Meta Test Acc: 1-shot {:.4f} 5-shot {:.4f}, Meta Test std: {:.4f} {:.4f}, Time: {:.1f}'.format(meta_1shot_acc, meta_5shot_acc, meta_1shot_std, meta_5shot_std, test_time) 108 | train_msg = train_msg + ' | ' + meta_msg 109 | if trlog['max_1shot_meta'] < meta_1shot_acc: 110 | trlog['max_1shot_meta'] = meta_1shot_acc 111 | trlog['max_1shot_epoch'] = epoch 112 | self.save_model(epoch,'max_meta') 113 | if trlog['max_5shot_meta'] < meta_5shot_acc: 114 | trlog['max_5shot_meta'] = meta_5shot_acc 115 | trlog['max_5shot_epoch'] = epoch 116 | self.save_model(epoch,'max_meta_5shot') # will not use 117 | 118 | print(train_msg) 119 | if epoch % 10 == 0 or epoch==self.args.epochs: 120 | self.save_model(epoch,'last') 121 | print('The Best Meta 1(5)-shot Acc {:.4f}({:.4f}) in Epoch {}({})'.format(trlog['max_1shot_meta'],trlog['max_5shot_meta'],trlog['max_1shot_epoch'],trlog['max_5shot_epoch'])) 122 | torch.save(trlog, os.path.join(self.save_path, 'trlog')) 123 | 124 | def train_epoch(self, epoch, train_loader, model, criterion, optimizer, args): 125 | """One epoch training""" 126 | return 0,'to be updated' 127 | 128 | def save_model(self, epoch, name=None): 129 | state = { 130 | 'epoch': epoch, 131 | 'params': self.model.state_dict() 132 | } 133 | file_name = '{}.pth'.format('epoch_'+str(epoch) if name is None else name) 134 | print('==> Saving', file_name) 135 | 
class DropBlock(nn.Module):
    """DropBlock regularizer: zeroes square ``block_size`` regions of a feature map.

    Seeds are sampled from ``Bernoulli(gamma)`` on a grid shrunk by
    ``block_size - 1`` in height and width; each seed removes the
    ``block_size x block_size`` window whose top-left corner it marks.
    The surviving activations are rescaled so the total mask weight is
    preserved. The module is an identity in eval mode.
    """

    def __init__(self, block_size):
        super(DropBlock, self).__init__()
        self.block_size = block_size

    def forward(self, x, gamma):
        """Apply DropBlock to ``x`` with seed probability ``gamma``.

        Args:
            x: Input of shape (bsize, channels, height, width).
            gamma: Probability that a grid position seeds a dropped block.

        Returns:
            ``x`` itself in eval mode, otherwise the masked, rescaled tensor.
        """
        if not self.training:
            return x

        batch_size, channels, height, width = x.shape
        bernoulli = Bernoulli(gamma)
        seed_shape = (batch_size, channels, height - (self.block_size - 1), width - (self.block_size - 1))
        # Fix: follow the input's device instead of a hard-coded .cuda(), so
        # the layer works on CPU and on non-default GPUs.
        mask = bernoulli.sample(seed_shape).to(x.device)
        block_mask = self._compute_block_mask(mask)
        countM = block_mask.numel()
        count_ones = block_mask.sum()
        # NOTE(review): if every position is dropped count_ones is 0 and the
        # result is NaN, as in the original implementation.
        return block_mask * x * (countM / count_ones)

    def _compute_block_mask(self, mask):
        """Turn a seed mask into a keep mask (1 = keep, 0 = dropped).

        Args:
            mask: Seed tensor of shape (bsize, channels, h, w) with 1 at seeds.

        Returns:
            Tensor of shape (bsize, channels, h + block_size - 1,
            w + block_size - 1).
        """
        # Fix: the previous index arithmetic paired seeds and offsets with
        # two plain ``repeat`` calls, which only covers every (seed, offset)
        # combination when the seed count and block_size**2 are coprime, so
        # blocks were dropped only partially. A stride-1 max-pool over the
        # zero-padded seed mask marks the full window of every seed.
        span = self.block_size - 1
        padded = F.pad(mask, (span, span, span, span))
        dropped = F.max_pool2d(padded, kernel_size=self.block_size, stride=1)
        return 1 - dropped
__init__(self, block, n_blocks, keep_prob=1.0, drop_rate=0.0, dropblock_size=5, num_classes=-1): 127 | super(ResNet, self).__init__() 128 | channels = [64,160,320,640] 129 | 130 | self.inplanes = 3 131 | self.layer1 = self._make_layer(block, n_blocks[0], channels[0],drop_rate=drop_rate) 132 | self.layer2 = self._make_layer(block, n_blocks[1], channels[1],drop_rate=drop_rate) 133 | self.layer3 = self._make_layer(block, n_blocks[2], channels[2],drop_rate=drop_rate, drop_block=True, block_size=dropblock_size) 134 | self.layer4 = self._make_layer(block, n_blocks[3], channels[3],drop_rate=drop_rate, drop_block=True, block_size=dropblock_size) 135 | self.avgpool = nn.AdaptiveAvgPool2d(1) 136 | self.keep_prob = keep_prob 137 | self.dropout = nn.Dropout(p=1 - self.keep_prob, inplace=False) 138 | self.drop_rate = drop_rate 139 | self.out_dim = channels[-1] 140 | 141 | for m in self.modules(): 142 | if isinstance(m, nn.Conv2d): 143 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') 144 | elif isinstance(m, nn.BatchNorm2d): 145 | nn.init.constant_(m.weight, 1) 146 | nn.init.constant_(m.bias, 0) 147 | 148 | 149 | def _make_layer(self, block, n_block, planes, stride=2, drop_rate=0.0, drop_block=False, block_size=1): 150 | downsample = None 151 | if stride != 1 or self.inplanes != planes * block.expansion: 152 | downsample = nn.Sequential( 153 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=1, bias=False), 154 | nn.BatchNorm2d(planes * block.expansion), 155 | ) 156 | the_blk = block(self.inplanes, planes, stride, downsample, drop_rate, drop_block, block_size) 157 | self.inplanes = planes * block.expansion 158 | 159 | return the_blk 160 | 161 | def forward(self, x, is_feat=False): 162 | 163 | x = self.layer1(x) 164 | x = self.layer2(x) 165 | x = self.layer3(x) 166 | x = self.layer4(x) 167 | 168 | x = self.avgpool(x) 169 | resfeat = x.view(x.size(0), -1) 170 | 171 | return resfeat 172 | 173 | def 
def resnet12_ssl(keep_prob=1.0, avg_pool=False, **kwargs):
    """Constructs a ResNet-12 model.

    Args:
        keep_prob: Forwarded to ``ResNet``.
        avg_pool: Accepted for call compatibility; not used here.
        **kwargs: Extra keyword arguments forwarded to ``ResNet``.
    """
    # Fix: the comma before **kwargs was missing, so the call evaluated
    # ``keep_prob ** kwargs`` (float ** dict) and raised TypeError.
    model = ResNet(BasicBlock, [1, 1, 1, 1], keep_prob=keep_prob, **kwargs)
    return model


def resnet18(keep_prob=1.0, avg_pool=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        keep_prob: Forwarded to ``ResNet``.
        avg_pool: Accepted for call compatibility; not used here.
        **kwargs: Extra keyword arguments forwarded to ``ResNet``.
    """
    # Same missing-comma fix as in resnet12_ssl.
    model = ResNet(BasicBlock, [1, 1, 2, 2], keep_prob=keep_prob, **kwargs)
    return model
(self.block_size - 1), width - (self.block_size - 1))).cuda() 31 | block_mask = self._compute_block_mask(mask) 32 | countM = block_mask.size()[0] * block_mask.size()[1] * block_mask.size()[2] * block_mask.size()[3] 33 | count_ones = block_mask.sum() 34 | 35 | return block_mask * x * (countM / count_ones) 36 | else: 37 | return x 38 | 39 | def _compute_block_mask(self, mask): 40 | left_padding = int((self.block_size-1) / 2) 41 | right_padding = int(self.block_size / 2) 42 | 43 | batch_size, channels, height, width = mask.shape 44 | #print ("mask", mask[0][0]) 45 | non_zero_idxs = mask.nonzero() 46 | nr_blocks = non_zero_idxs.shape[0] 47 | 48 | offsets = torch.stack( 49 | [ 50 | torch.arange(self.block_size).view(-1, 1).expand(self.block_size, self.block_size).reshape(-1), # - left_padding, 51 | torch.arange(self.block_size).repeat(self.block_size), #- left_padding 52 | ] 53 | ).t().cuda() 54 | offsets = torch.cat((torch.zeros(self.block_size**2, 2).cuda().long(), offsets.long()), 1) 55 | 56 | if nr_blocks > 0: 57 | non_zero_idxs = non_zero_idxs.repeat(self.block_size ** 2, 1) 58 | offsets = offsets.repeat(nr_blocks, 1).view(-1, 4) 59 | offsets = offsets.long() 60 | 61 | block_idxs = non_zero_idxs + offsets 62 | #block_idxs += left_padding 63 | padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding)) 64 | padded_mask[block_idxs[:, 0], block_idxs[:, 1], block_idxs[:, 2], block_idxs[:, 3]] = 1. 
class BasicBlock(nn.Module):
    """Residual block of three 3x3 conv + BatchNorm stages.

    The shortcut (optionally passed through ``downsample``) is added before
    the last LeakyReLU; the result is max-pooled with kernel ``stride`` and
    then regularized with DropBlock or plain dropout when ``drop_rate > 0``.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=2, downsample=None, drop_rate=0.0, drop_block=False,block_size=1):
        super(BasicBlock, self).__init__()
        # Sub-modules are created in this fixed order (parameter
        # initialization consumes the RNG in creation order).
        self.conv1 = conv3x3(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.LeakyReLU(0.1)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv3x3(planes, planes)
        self.bn3 = nn.BatchNorm2d(planes)
        self.maxpool = nn.MaxPool2d(stride)
        self.downsample = downsample
        # Plain configuration values.
        self.stride = stride
        self.drop_rate = drop_rate
        self.num_batches_tracked = 0
        self.drop_block = drop_block
        self.block_size = block_size
        self.DropBlock = DropBlock(block_size=self.block_size)

    def forward(self, x):
        """Run the block on ``x`` and return the pooled, regularized output."""
        # Counts every forward call; drives the DropBlock keep-rate schedule.
        self.num_batches_tracked += 1

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        out = self.maxpool(self.relu(out))

        if not self.drop_rate > 0:
            return out

        if self.drop_block == True:
            feat_size = out.size()[2]
            # Keep rate decays linearly with the call count until it reaches
            # 1 - drop_rate.
            keep_rate = max(1.0 - self.drop_rate / (20*2000) * (self.num_batches_tracked), 1.0 - self.drop_rate)
            gamma = (1 - keep_rate) / self.block_size**2 * feat_size**2 / (feat_size - self.block_size + 1)**2
            return self.DropBlock(out, gamma=gamma)

        return F.dropout(out, p=self.drop_rate, training=self.training, inplace=True)
class ResNet(nn.Module):
    """Four-stage convolutional feature extractor.

    Each stage is a single ``block`` instance with widths 64, 160, 320 and
    640; the last two stages are built with ``drop_block=True``. ``forward``
    returns globally average-pooled features flattened to
    ``(batch, channels)``; ``out_dim`` stores the last stage width.
    """

    def __init__(self, block, n_blocks, keep_prob=1.0, drop_rate=0.0, dropblock_size=5, num_classes=-1):
        """Build the four stages and initialize conv / batch-norm weights.

        Args:
            block: Block class; called as ``block(inplanes, planes, stride,
                downsample, drop_rate, drop_block, block_size)`` and must
                expose an ``expansion`` attribute.
            n_blocks: Four per-stage block counts. Only one block per stage
                is built; see ``_make_layer``.
            keep_prob: Keep probability of ``self.dropout`` (not applied in
                ``forward``).
            drop_rate: Drop rate handed to every block.
            dropblock_size: Block size handed to stages 3 and 4.
            num_classes: Accepted for call compatibility; not used here.
        """
        super(ResNet, self).__init__()
        channels = [64,160,320,640]

        self.inplanes = 3
        self.layer1 = self._make_layer(block, n_blocks[0], channels[0],drop_rate=drop_rate)
        self.layer2 = self._make_layer(block, n_blocks[1], channels[1],drop_rate=drop_rate)
        self.layer3 = self._make_layer(block, n_blocks[2], channels[2],drop_rate=drop_rate, drop_block=True, block_size=dropblock_size)
        self.layer4 = self._make_layer(block, n_blocks[3], channels[3],drop_rate=drop_rate, drop_block=True, block_size=dropblock_size)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.keep_prob = keep_prob
        self.dropout = nn.Dropout(p=1 - self.keep_prob, inplace=False)
        self.drop_rate = drop_rate
        self.out_dim = channels[-1]

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, n_block, planes, stride=2, drop_rate=0.0, drop_block=False, block_size=1):
        """Create one stage (a single block) and advance ``self.inplanes``."""
        if n_block != 1:
            # The requested depth was silently ignored before; the
            # architecture is unchanged (so checkpoints keep loading) but
            # the mismatch is now reported.
            import warnings
            warnings.warn(
                'ResNet._make_layer builds exactly one block per stage; '
                'n_block={} is ignored'.format(n_block)
            )

        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # 1x1 projection so the shortcut matches the block's channel count.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        the_blk = block(self.inplanes, planes, stride, downsample, drop_rate, drop_block, block_size)
        self.inplanes = planes * block.expansion

        return the_blk

    def forward(self, x, rot=False,is_feat=False):
        """Return pooled features; ``rot`` and ``is_feat`` are accepted but unused."""
        x = self.layer1(x)
        x = self.layer2(x)

        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        resfeat = x.view(x.size(0), -1)

        return resfeat
def __init__(self, args, partition='test', mode='episode', is_training=False, fix_seed=True, held_out=False):
    """Configure episode sampling, transforms and class vectors, then load data.

    Args:
        args: Options namespace; reads ``n_ways``, ``n_shots``,
            ``n_open_ways``, ``n_queries``, ``n_train_runs``,
            ``n_test_runs``, ``n_aug_support_samples`` and ``data_root``.
        partition: Split name; ``'{partition}.pickle'`` is loaded from
            ``args.data_root`` by ``init_episode``.
        mode: Accepted for call compatibility; not used here.
        is_training: Adds colour jitter to the support transform.
        fix_seed: Stored for ``get_episode``.
        held_out: Stored for ``init_episode``.
    """
    super(OpenCIFAR, self).__init__()
    self.fix_seed = fix_seed
    self.n_ways = args.n_ways
    self.n_shots = args.n_shots
    self.n_open_ways = args.n_open_ways
    self.n_queries = args.n_queries
    self.n_episodes = args.n_train_runs if partition=='train' else args.n_test_runs
    # Support augmentation is disabled for the training partition.
    self.n_aug_support_samples = 1 if partition == 'train' else args.n_aug_support_samples
    self.partition = partition
    self.held_out = held_out

    mean = [0.5071, 0.4867, 0.4408]
    std = [0.2675, 0.2565, 0.2761]
    normalize = transforms.Normalize(mean=mean, std=std)

    # Support images are always randomly cropped and flipped; colour jitter
    # is added only while training.
    support_ops = [transforms.RandomCrop(32, padding=4)]
    if is_training:
        support_ops.append(transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4))
    support_ops.extend([transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize])
    self.train_transform = transforms.Compose(support_ops)

    with open(os.path.join(args.data_root,'category_vector.pickle'), 'rb') as f:
        pack = pickle.load(f, encoding='latin1')
    vectors = pack['vector']

    self.vector_array = {
        'base': vectors['train'],
        # Fix: the validation vectors were only reachable under the
        # misspelled key 'nove_val'; expose the correct key as well.
        'novel_val': vectors['val'],
        # Legacy misspelling kept so existing lookups keep working.
        'nove_val': vectors['val'],
        'novel_test': vectors['test'],
    }

    self.test_transform = transforms.Compose([transforms.ToTensor(),normalize])
    self.init_episode(args.data_root,partition)
def get_episode(self, item):
    """Sample one open-set few-shot episode.

    ``n_ways`` closed-set classes and ``n_open_ways`` disjoint open-set
    classes are drawn. For every class ``n_shots`` support images and
    ``n_queries`` query images are sampled without replacement; support
    images go through ``train_transform`` and queries through
    ``test_transform``.

    Args:
        item: Episode index. When ``self.fix_seed`` is set it seeds NumPy's
            global RNG, so the same index yields the same episode.

    Returns:
        ``(support_xs, support_ys, query_xs, query_ys, suppopen_xs,
        suppopen_ys, openset_xs, openset_ys, cls_sampled, cls_open_ids)``,
        followed by ``base_ids`` (sorted ids of all remaining classes) when
        ``self.partition == 'train'``. With ``INCLUDE_BASE`` enabled on other
        partitions, base-class queries are appended to ``openset_xs`` and
        ``openset_ys`` becomes an all-ones tensor.
    """
    if self.fix_seed:
        np.random.seed(item)

    def repeat_per_class(samples):
        """Repeat each class's run of ``n_shots`` entries ``n_aug`` times,
        keeping the entries of one class contiguous."""
        if self.n_aug_support_samples <= 1:
            return samples
        repeated = []
        for start in range(0, len(samples), self.n_shots):
            repeated.extend(samples[start:start + self.n_shots] * self.n_aug_support_samples)
        return repeated

    cls_sampled = np.random.choice(self.classes, self.n_ways, False)
    support_xs, support_ys = [], []
    suppopen_xs, suppopen_ys = [], []
    query_xs, query_ys = [], []
    openset_xs, openset_ys = [], []
    manyshot_xs, manyshot_ys = [], []

    # Close set preparation: queries come from images not used as support.
    for idx, the_cls in enumerate(cls_sampled):
        imgs = self.data[the_cls]
        support_ids = np.random.choice(range(len(imgs)), self.n_shots, False)
        support_xs.extend([imgs[the_id] for the_id in support_ids])
        support_ys.extend([idx] * self.n_shots)
        query_pool = np.setxor1d(np.arange(len(imgs)), support_ids)
        query_ids = np.random.choice(query_pool, self.n_queries, False)
        query_xs.extend([imgs[the_id] for the_id in query_ids])
        query_ys.extend([idx] * self.n_queries)

    # Open set preparation: classes not sampled above.
    # NOTE(review): arange(len(classes)) assumes class ids are 0..C-1.
    cls_open_ids = np.setxor1d(np.arange(len(self.classes)), cls_sampled)
    cls_open_ids = np.random.choice(cls_open_ids, self.n_open_ways, False)
    for idx, the_cls in enumerate(cls_open_ids):
        imgs = self.data[the_cls]
        suppopen_ids = np.random.choice(range(len(imgs)), self.n_shots, False)
        suppopen_xs.extend([imgs[the_id] for the_id in suppopen_ids])
        suppopen_ys.extend([idx] * self.n_shots)
        # Fix: sample open queries from the pool that excludes the open
        # support images (the pool used to be computed and then ignored).
        # NOTE(review): this changes the RNG stream, so fixed-seed episodes
        # differ from those produced before the fix.
        openset_pool = np.setxor1d(np.arange(len(imgs)), suppopen_ids)
        openset_ids = np.random.choice(openset_pool, self.n_queries, False)
        openset_xs.extend([imgs[the_id] for the_id in openset_ids])
        openset_ys.extend([the_cls] * self.n_queries)

    if self.partition == 'train':
        base_ids = np.setxor1d(np.arange(len(self.classes)), np.concatenate([cls_sampled, cls_open_ids]))
        assert len(set(base_ids).union(set(cls_open_ids)).union(set(cls_sampled))) == len(self.classes)
        base_ids = np.array(sorted(base_ids))
    elif INCLUDE_BASE:
        # Extra "many-shot" queries taken from the held-out base-class images.
        base_ids = sorted(self.base_classes)
        assert len(base_ids) > self.n_ways
        base_cls_sampled = list(np.random.choice(base_ids, self.n_ways, False))
        base_cls_sampled.sort()
        for idx, the_cls in enumerate(base_cls_sampled):
            imgs = self.base_data[the_cls]
            manyshot_ids = np.random.choice(range(len(imgs)), self.n_queries, False)
            manyshot_xs.extend([imgs[the_id] for the_id in manyshot_ids])
            manyshot_ys.extend([idx] * self.n_queries)

    # Fix: the open-support lists are now chunked by their own length; the
    # old code iterated over the (already augmented) closed-support length
    # and dropped open classes when n_open_ways > n_ways * n_aug.
    support_xs, support_ys = repeat_per_class(support_xs), repeat_per_class(support_ys)
    suppopen_xs, suppopen_ys = repeat_per_class(suppopen_xs), repeat_per_class(suppopen_ys)

    support_xs = torch.stack([self.train_transform(x) for x in support_xs])
    suppopen_xs = torch.stack([self.train_transform(x) for x in suppopen_xs])
    query_xs = torch.stack([self.test_transform(x) for x in query_xs])
    openset_xs = torch.stack([self.test_transform(x) for x in openset_xs])
    support_ys, query_ys, openset_ys = np.array(support_ys), np.array(query_ys), np.array(openset_ys)
    suppopen_ys = np.array(suppopen_ys)
    cls_sampled, cls_open_ids = np.array(cls_sampled), np.array(cls_open_ids)

    if self.partition == 'train':
        return support_xs, support_ys, query_xs, query_ys, suppopen_xs, suppopen_ys, openset_xs, openset_ys, cls_sampled, cls_open_ids, base_ids

    if INCLUDE_BASE:
        manyshot_xs = torch.stack([self.test_transform(x) for x in manyshot_xs])
        openset_xs = torch.cat([openset_xs, manyshot_xs])
        # NOTE(review): assumed to belong to the INCLUDE_BASE branch (labels
        # are replaced once base queries are mixed in) — confirm against the
        # original file's indentation.
        openset_ys = torch.ones(len(openset_xs))
    return support_xs, support_ys, query_xs, query_ys, suppopen_xs, suppopen_ys, openset_xs, openset_ys, cls_sampled, cls_open_ids
| from tensorboardX import SummaryWriter 20 | 21 | from architectures.NetworkPre import FeatureNet 22 | from trainer.FSEval import run_test_fsl 23 | from util import adjust_learning_rate, accuracy, AverageMeter 24 | from sklearn import metrics 25 | 26 | 27 | class MetaTrainer(object): 28 | def __init__(self, args, dataset_trainer, eval_loader=None, hard_path=None): 29 | os.environ["CUDA_VISIBLE_DEVICES"]=str(args.gpus) 30 | args.logroot = os.path.join(args.logroot, args.featype + '_' + args.dataset) 31 | if not os.path.isdir(args.logroot): 32 | os.makedirs(args.logroot) 33 | 34 | try: 35 | iterations = args.lr_decay_epochs.split(',') 36 | args.lr_decay_epochs = list([]) 37 | for it in iterations: 38 | args.lr_decay_epochs.append(int(it)) 39 | except: 40 | pass 41 | 42 | args.model_name = '{}_{}_shot_{}'.format(args.dataset, args.n_train_runs, args.n_shots) 43 | 44 | 45 | self.save_path = os.path.join(args.logroot, args.model_name) 46 | if not os.path.isdir(self.save_path): 47 | os.mkdir(self.save_path) 48 | 49 | assert args.pretrained_model_path is not None, 'Missing Pretrained Model' 50 | params = torch.load(args.pretrained_model_path)['params'] 51 | feat_params = {k: v for k, v in params.items() if 'feature' in k} 52 | cls_params = {k: v for k, v in params.items() if 'cls_classifier' in k} 53 | 54 | self.args = args 55 | self.train_loader, self.val_loader, n_cls = dataset_trainer 56 | self.model = FeatureNet(args, args.restype, n_cls, (cls_params,self.train_loader.dataset.vector_array)) 57 | 58 | 59 | ##### Load Pretrained Weights for Feature Extractor 60 | model_dict = self.model.state_dict() 61 | model_dict.update(feat_params) 62 | self.model.load_state_dict(model_dict) 63 | 64 | self.model.train() 65 | print('Loaded Pretrained Weight from %s' % args.pretrained_model_path) 66 | 67 | # optimizer 68 | if self.args.tunefeat == 0.0: 69 | optim_param = [{'params': self.model.cls_classifier.parameters()}] 70 | else: 71 | optim_param = [{'params': 
self.model.cls_classifier.parameters()},{'params': filter(lambda p: p.requires_grad, self.model.feature.parameters()),'lr': self.args.tunefeat}] 72 | 73 | self.optimizer = optim.SGD(optim_param, lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) 74 | 75 | if torch.cuda.is_available(): 76 | if args.n_gpu > 1: 77 | self.model = nn.DataParallel(self.model) 78 | self.model = self.model.cuda() 79 | cudnn.benchmark = True 80 | 81 | # set cosine annealing scheduler 82 | if args.cosine: 83 | print("==> training with plateau scheduler ...") 84 | self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'max') 85 | else: 86 | print("==> training with MultiStep scheduler ... gamma {} step {}".format(args.lr_decay_rate, args.lr_decay_epochs)) 87 | 88 | def train(self, eval_loader=None): 89 | 90 | trlog = {} 91 | trlog['args'] = vars(self.args) 92 | trlog['maxmeta_acc'] = 0.0 93 | trlog['maxmeta_acc_epoch'] = 0 94 | trlog['maxmeta_auroc'] = 0.0 95 | trlog['maxmeta_auroc_epoch'] = 0 96 | 97 | writer = SummaryWriter(self.save_path) 98 | 99 | criterion = nn.CrossEntropyLoss() 100 | criterion = criterion.cuda() 101 | 102 | 103 | for epoch in range(1, self.args.epochs + 1): 104 | if self.args.cosine: 105 | self.scheduler.step(trlog['maxmeta_acc']) 106 | else: 107 | adjust_learning_rate(epoch, self.args, self.optimizer, 0.0001) 108 | 109 | train_acc, train_auroc, train_loss, train_msg = self.train_episode(epoch, self.train_loader, self.model, criterion, self.optimizer, self.args) 110 | 111 | writer.add_scalar('train/acc', float(train_acc), epoch) 112 | writer.add_scalar('train/auroc', float(train_auroc), epoch) 113 | writer.add_scalar('train/loss_cls', float(train_loss[0]), epoch) 114 | writer.add_scalar('train/loss_funit', float(train_loss[1]), epoch) 115 | writer.add_scalar('lr', self.optimizer.param_groups[0]['lr'], epoch) 116 | 117 | self.model.eval() 118 | 119 | #evaluate 120 | if eval_loader is not None: 121 | start = 
time.time() 122 | assert self.args.featype == 'OpenMeta' 123 | config = {'auroc_type':['prob']} 124 | result = run_test_fsl(self.model, eval_loader, config) 125 | meta_test_acc = result['data']['acc'] 126 | open_score_auroc = result['data']['auroc_prob'] 127 | 128 | test_time = time.time() - start 129 | writer.add_scalar('meta/close_acc', float(meta_test_acc[0]), epoch) 130 | writer.add_scalar('meta/close_std', float(meta_test_acc[1]), epoch) 131 | writer.add_scalar('meta/open_auroc', float(open_score_auroc[0]), epoch) 132 | writer.add_scalar('meta/open_std', float(open_score_auroc[1]), epoch) 133 | 134 | meta_msg = 'Meta Test Acc: {:.4f}, Test std: {:.4f}, AUROC: {:.4f}, Time: {:.1f}'.format(meta_test_acc[0], meta_test_acc[1], open_score_auroc[0], test_time) 135 | train_msg = train_msg + ' | ' + meta_msg 136 | 137 | if trlog['maxmeta_acc'] < meta_test_acc[0]: 138 | trlog['maxmeta_acc'] = meta_test_acc[0] 139 | trlog['maxmeta_acc_epoch'] = epoch 140 | acc_auroc = (meta_test_acc[0], open_score_auroc[0]) 141 | self.save_model(epoch, 'max_acc', acc_auroc) 142 | if trlog['maxmeta_auroc'] < open_score_auroc[0]: 143 | trlog['maxmeta_auroc'] = open_score_auroc[0] 144 | trlog['maxmeta_auroc_epoch'] = epoch 145 | acc_auroc = (meta_test_acc[0], open_score_auroc[0]) 146 | self.save_model(epoch, 'max_auroc', acc_auroc) 147 | 148 | print(train_msg) 149 | 150 | # regular saving 151 | if epoch % 5 == 0: 152 | self.save_model(epoch,'last') 153 | print('The Best Meta Acc {:.4f} in Epoch {}, Best Meta AUROC {:.4f} in Epoch {}'.format(trlog['maxmeta_acc'],trlog['maxmeta_acc_epoch'],trlog['maxmeta_auroc'],trlog['maxmeta_auroc_epoch'])) 154 | 155 | 156 | def train_episode(self, epoch, train_loader, model, criterion, optimizer, args): 157 | """One epoch training""" 158 | model.train() 159 | if self.args.tunefeat==0: 160 | model.feature.eval() 161 | 162 | 163 | batch_time = AverageMeter() 164 | losses_cls = AverageMeter() 165 | losses_funit = AverageMeter() 166 | acc = AverageMeter() 167 
| auroc = AverageMeter() 168 | end = time.time() 169 | 170 | with tqdm(train_loader, total=len(train_loader), leave=False) as pbar: 171 | for idx, data in enumerate(pbar): 172 | support_data, support_label, query_data, query_label, suppopen_data, suppopen_label, openset_data, openset_label, supp_idx, open_idx, base_ids = data 173 | 174 | # Data Conversion & Packaging 175 | support_data,support_label = support_data.float().cuda(),support_label.cuda().long() 176 | query_data,query_label = query_data.float().cuda(),query_label.cuda().long() 177 | suppopen_data,suppopen_label = suppopen_data.float().cuda(),suppopen_label.cuda().long() 178 | openset_data,openset_label = openset_data.float().cuda(),openset_label.cuda().long() 179 | supp_idx, open_idx,base_ids = supp_idx.long(), open_idx.long(),base_ids.long() 180 | openset_label = self.args.n_ways * torch.ones_like(openset_label) 181 | the_img = (support_data, query_data, suppopen_data, openset_data) 182 | the_label = (support_label,query_label,suppopen_label,openset_label) 183 | the_conj = (supp_idx, open_idx) 184 | 185 | _, _, probs, loss = model(the_img,the_label,the_conj,base_ids) 186 | query_cls_probs, openset_cls_probs = probs 187 | (loss_cls, loss_open_hinge, loss_funit) = loss 188 | loss_open = args.gamma * loss_open_hinge + args.funit * loss_funit 189 | 190 | loss = loss_open + loss_cls 191 | 192 | ### Closed Set Accuracy 193 | close_pred = np.argmax(probs[0][:,:,:self.args.n_ways].view(-1,self.args.n_ways).cpu().numpy(),-1) 194 | close_label = query_label.view(-1).cpu().numpy() 195 | acc.update(metrics.accuracy_score(close_label, close_pred),1) 196 | 197 | ### Open Set AUROC 198 | open_label_binary = np.concatenate((np.ones(close_pred.shape),np.zeros(close_pred.shape))) 199 | query_cls_probs = query_cls_probs.view(-1, self.args.n_ways+1) 200 | openset_cls_probs = openset_cls_probs.view(-1, self.args.n_ways+1) 201 | open_scores = torch.cat([query_cls_probs,openset_cls_probs], dim=0).cpu().numpy()[:,-1] 202 | 
auroc.update(metrics.roc_auc_score(1-open_label_binary,open_scores),1) 203 | 204 | 205 | losses_cls.update(loss_cls.item(), 1) 206 | losses_funit.update(loss_funit.item(), 1) 207 | 208 | # ===================backward===================== 209 | optimizer.zero_grad() 210 | loss.backward() 211 | optimizer.step() 212 | 213 | # ===================meters===================== 214 | batch_time.update(time.time() - end) 215 | end = time.time() 216 | 217 | 218 | pbar.set_postfix({"Acc":'{0:.2f}'.format(acc.avg), 219 | "Auroc":'{0:.2f}'.format(auroc.avg), 220 | "cls_ce" :'{0:.2f}'.format(losses_cls.avg), 221 | "funit" :'{0:.4f}'.format(losses_funit.avg), 222 | }) 223 | 224 | message = 'Epoch {} Train_Acc {acc.avg:.3f} Train_Auroc {auroc.avg:.3f}'.format(epoch, acc=acc, auroc=auroc) 225 | 226 | return acc.avg, auroc.avg, (losses_cls.avg, losses_funit.avg), message 227 | 228 | 229 | def save_model(self, epoch, name=None, acc_auroc=None): 230 | state = { 231 | 'epoch': epoch, 232 | 'cls_params': self.model.state_dict() if self.args.n_gpu==1 else self.model.module.state_dict(), 233 | 'acc_auroc': acc_auroc 234 | } 235 | # 'optimizer': self.optimizer.state_dict()['param_groups'], 236 | 237 | file_name = 'epoch_'+str(epoch)+'.pth' if name is None else name + '.pth' 238 | print('==> Saving', file_name) 239 | torch.save(state, os.path.join(self.save_path, file_name)) 240 | 241 | 242 | -------------------------------------------------------------------------------- /trainer/GMetaTrainer.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import numpy as np 5 | import argparse 6 | import socket 7 | import time 8 | import sys 9 | from tqdm import tqdm 10 | import pdb 11 | 12 | import torch 13 | import torch.optim as optim 14 | import torch.nn as nn 15 | import torch.backends.cudnn as cudnn 16 | from torch.utils.data import DataLoader 17 | import torch.nn.functional as F 18 | from 
torch.autograd import Variable 19 | from tensorboardX import SummaryWriter 20 | 21 | from architectures.GNetworkPre import GFeatureNet 22 | from trainer.GFSEval import run_test_gfsl 23 | from util import adjust_learning_rate, accuracy, AverageMeter 24 | from sklearn import metrics 25 | 26 | 27 | class GMetaTrainer(object): 28 | def __init__(self, args, dataset_trainer, eval_loader=None, hard_path=None): 29 | os.environ["CUDA_VISIBLE_DEVICES"]=str(args.gpus) 30 | args.logroot = os.path.join(args.logroot, args.featype + '_' + args.dataset) 31 | if not os.path.isdir(args.logroot): 32 | os.makedirs(args.logroot) 33 | 34 | try: 35 | iterations = args.lr_decay_epochs.split(',') 36 | args.lr_decay_epochs = list([]) 37 | for it in iterations: 38 | args.lr_decay_epochs.append(int(it)) 39 | except: 40 | pass 41 | 42 | args.model_name = '{}_{}_shot_{}'.format(args.dataset, args.n_train_runs, args.n_shots) 43 | 44 | 45 | self.save_path = os.path.join(args.logroot, args.model_name) 46 | if not os.path.isdir(self.save_path): 47 | os.mkdir(self.save_path) 48 | 49 | assert args.pretrained_model_path is not None, 'Missing Pretrained Model' 50 | params = torch.load(args.pretrained_model_path)['params'] 51 | feat_params = {k: v for k, v in params.items() if 'feature' in k} 52 | cls_params = {k: v for k, v in params.items() if 'cls_classifier' in k} 53 | 54 | self.args = args 55 | self.train_loader, self.val_loader, n_cls = dataset_trainer 56 | self.model = GFeatureNet(args, args.restype, n_cls, (cls_params,self.train_loader.dataset.vector_array)) 57 | 58 | 59 | ##### Load Pretrained Weights for Feature Extractor 60 | model_dict = self.model.state_dict() 61 | model_dict.update(feat_params) 62 | self.model.load_state_dict(model_dict) 63 | 64 | self.model.train() 65 | print('Loaded Pretrained Weight from %s' % args.pretrained_model_path) 66 | 67 | # optimizer 68 | if self.args.tunefeat == 0.0: 69 | optim_param = [{'params': self.model.cls_classifier.parameters()}] 70 | else: 71 | 
optim_param = [{'params': self.model.cls_classifier.parameters()},{'params': filter(lambda p: p.requires_grad, self.model.feature.parameters()),'lr': self.args.tunefeat}] 72 | 73 | self.optimizer = optim.SGD(optim_param, lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) 74 | 75 | if torch.cuda.is_available(): 76 | if args.n_gpu > 1: 77 | self.model = nn.DataParallel(self.model) 78 | self.model = self.model.cuda() 79 | cudnn.benchmark = True 80 | 81 | # set cosine annealing scheduler 82 | if args.cosine: 83 | print("==> training with plateau scheduler ...") 84 | self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'max') 85 | else: 86 | print("==> training with MultiStep scheduler ... gamma {} step {}".format(args.lr_decay_rate, args.lr_decay_epochs)) 87 | 88 | def train(self, eval_loader=None): 89 | 90 | trlog = {} 91 | trlog['args'] = vars(self.args) 92 | trlog['maxmeta_acc'] = 0.0 93 | trlog['maxmeta_acc_epoch'] = 0 94 | trlog['maxmeta_auroc'] = 0.0 95 | trlog['maxmeta_auroc_epoch'] = 0 96 | trlog['maxmeta_all'] = 0.0 97 | trlog['maxmeta_all_epoch'] = 0 98 | 99 | writer = SummaryWriter(self.save_path) 100 | 101 | criterion = nn.CrossEntropyLoss() 102 | criterion = criterion.cuda() 103 | 104 | for epoch in range(1, self.args.epochs + 1): 105 | if self.args.tunefeat>0: 106 | del self.optimizer.param_groups[-1] 107 | if self.args.cosine: 108 | self.scheduler.step() 109 | else: 110 | adjust_learning_rate(epoch, self.args, self.optimizer, 0.0001) 111 | 112 | train_acc, train_auroc, train_loss, train_msg = self.train_episode_gen(epoch, self.train_loader, self.model, criterion, self.optimizer, self.args) 113 | writer.add_scalar('train/acc', float(train_acc), epoch) 114 | writer.add_scalar('train/auroc', float(train_auroc), epoch) 115 | writer.add_scalar('train/loss_cls', float(train_loss[0]), epoch) 116 | writer.add_scalar('train/loss_funit', float(train_loss[1]), epoch) 117 | writer.add_scalar('lr', 
self.optimizer.param_groups[0]['lr'], epoch) 118 | 119 | #evaluate 120 | if eval_loader is not None: 121 | start = time.time() 122 | result = run_test_gfsl(self.model, eval_loader) # epoch%5==0 123 | 124 | (arith_mean, harmonic_mean, delta, auroc, f1) = result 125 | 126 | test_time = time.time() - start 127 | writer.add_scalar('meta/mean_acc', float(arith_mean[0]), epoch) 128 | writer.add_scalar('meta/mean_std', float(arith_mean[1]), epoch) 129 | 130 | writer.add_scalar('meta/hmean_acc', float(harmonic_mean[0]), epoch) 131 | writer.add_scalar('meta_std/hmean_acc', float(harmonic_mean[1]), epoch) 132 | 133 | 134 | writer.add_scalar('meta/auroc', float(auroc[0]), epoch) 135 | writer.add_scalar('meta/auroc', float(auroc[1]), epoch) 136 | 137 | 138 | meta_msg = 'Meta Test Acc: {:.4f} H-Acc {:.4f} AUROC: {:.4f}, Time: {:.1f}'.format(arith_mean[0], harmonic_mean[0], auroc[0], test_time) 139 | train_msg = train_msg + ' | ' + meta_msg 140 | if trlog['maxmeta_acc'] < arith_mean[0]: 141 | trlog['maxmeta_acc'] = arith_mean[0] 142 | trlog['maxmeta_acc_epoch'] = epoch 143 | acc_auroc = (arith_mean[0], auroc[0]) 144 | self.save_model(epoch, 'max_acc', acc_auroc) 145 | if trlog['maxmeta_auroc'] < auroc[0]: 146 | trlog['maxmeta_auroc'] = auroc[0] 147 | trlog['maxmeta_auroc_epoch'] = epoch 148 | acc_auroc = (arith_mean[0], auroc[0]) 149 | self.save_model(epoch, 'max_auroc', acc_auroc) 150 | if trlog['maxmeta_all'] < (arith_mean[0]+auroc[0])/2: 151 | trlog['maxmeta_all'] = (arith_mean[0]+auroc[0])/2 152 | trlog['maxmeta_all_epoch'] = epoch 153 | acc_auroc = (arith_mean[0], auroc[0]) 154 | self.save_model(epoch, 'max_all', acc_auroc) 155 | print(train_msg) 156 | 157 | # regular saving 158 | if epoch % 5 == 0: 159 | self.save_model(epoch,'last') 160 | print('The Best Meta Acc {:.4f} in Epoch {}, Best Meta AUROC {:.4f} in Epoch {}, Meta ALL {:.4f} in Epoch 
{}'.format(trlog['maxmeta_acc'],trlog['maxmeta_acc_epoch'],trlog['maxmeta_auroc'],trlog['maxmeta_auroc_epoch'],trlog['maxmeta_all'],trlog['maxmeta_all_epoch'])) 161 | 162 | 163 | def train_episode_gen(self, epoch, train_loader, model, criterion, optimizer, args): 164 | """One epoch training""" 165 | model.train() 166 | if self.args.tunefeat == 0: 167 | model.feature.eval() 168 | 169 | batch_time = AverageMeter() 170 | losses_cls = AverageMeter() 171 | losses_funit = AverageMeter() 172 | acc = AverageMeter() 173 | auroc = AverageMeter() 174 | end = time.time() 175 | 176 | with tqdm(train_loader, total=len(train_loader), leave=False) as pbar: 177 | for idx, data in enumerate(pbar): 178 | support_data, support_label, query_data, query_label, suppopen_data, suppopen_label, openset_data, openset_label, baseset_data, baseset_label, supp_idx, open_idx, base_ids = data 179 | 180 | # Data Conversion & Packaging 181 | support_data,support_label = support_data.float().cuda(),support_label.cuda().long() 182 | query_data,query_label = query_data.float().cuda(),query_label.cuda().long() 183 | suppopen_data,suppopen_label = suppopen_data.float().cuda(),suppopen_label.cuda().long() 184 | openset_data,openset_label = openset_data.float().cuda(),openset_label.cuda().long() 185 | baseset_data,baseset_label = baseset_data.float().cuda(),baseset_label.cuda().long() 186 | openset_label = self.args.n_ways * torch.ones_like(openset_label) 187 | 188 | the_img = (support_data, query_data, suppopen_data, openset_data, baseset_data) 189 | the_label = (support_label,query_label,suppopen_label,openset_label,baseset_label) 190 | the_conj = (supp_idx, open_idx) 191 | num_baseclass = baseset_label.max()+1 192 | 193 | _, _, probs, loss = model(the_img,the_label,the_conj,base_ids) 194 | (query_cls_probs, openset_cls_probs, many_cls_probs, query_cls_scores, openset_cls_scores) = probs 195 | (loss_cls, loss_open_hinge, loss_funit) = loss 196 | loss_open = args.gamma * loss_open_hinge + 
args.funit*loss_funit 197 | loss = loss_open + loss_cls 198 | 199 | close_pred = np.argmax(query_cls_scores[:,:,num_baseclass:-1].contiguous().view(-1,self.args.n_ways).detach().cpu().numpy(),-1) 200 | close_label = query_label.view(-1).cpu().numpy() 201 | open_label_binary = np.concatenate((np.ones(close_pred.shape),np.zeros(close_pred.shape))) 202 | 203 | # The eval metric is based on FSOR 204 | query_cls_probs = F.softmax(query_cls_scores[:,:,num_baseclass:], dim=-1) 205 | openset_cls_probs = F.softmax(openset_cls_scores[:,:,num_baseclass:], dim=-1) 206 | query_cls_probs = query_cls_probs.view(-1, self.args.n_ways+1) 207 | openset_cls_probs = openset_cls_probs.view(-1, self.args.n_ways+1) 208 | open_scores = torch.cat([query_cls_probs,openset_cls_probs], dim=0).detach().cpu().numpy()[:,-1] 209 | acc.update(metrics.accuracy_score(close_label,close_pred),1) 210 | auroc.update(metrics.roc_auc_score(1-open_label_binary,open_scores),1) 211 | 212 | losses_cls.update(loss_cls.item(), 1) 213 | losses_funit.update(loss_funit.item(), 1) 214 | 215 | # ===================backward===================== 216 | optimizer.zero_grad() 217 | loss.backward() 218 | optimizer.step() 219 | 220 | # ===================meters===================== 221 | batch_time.update(time.time() - end) 222 | end = time.time() 223 | 224 | pbar.set_postfix({"Acc":'{0:.2f}'.format(acc.avg), 225 | "Auroc":'{0:.2f}'.format(auroc.avg), 226 | "cls_ce" :'{0:.2f}'.format(losses_cls.avg), 227 | "funit" :'{0:.4f}'.format(losses_funit.avg), 228 | }) 229 | 230 | message = 'Epoch {} Train_Acc {acc.avg:.3f} Train_Auroc {auroc.avg:.3f}'.format(epoch, acc=acc, auroc=auroc) 231 | return acc.avg, auroc.avg, (losses_cls.avg, losses_funit.avg), message 232 | 233 | def save_model(self, epoch, name=None, acc_auroc=None): 234 | state = { 235 | 'epoch': epoch, 236 | 'cls_params': self.model.state_dict() if self.args.n_gpu==1 else self.model.module.state_dict(), 237 | 'acc_auroc': acc_auroc 238 | } 239 | # 'optimizer': 
# ---- /architectures/AttnClassifier.py ----
import torch.nn as nn
import torch
import torch.nn.functional as F

import numpy as np
import math
import pdb


class Classifier(nn.Module):
    """Episode classifier built on calibrated prototypes plus one generated open-set prototype.

    Holds the base-class weights, a frozen copy of them and the per-split
    semantic vectors, and scores query features against the concatenation of
    the calibrated support prototypes and the generated "fake class" prototype.
    """

    def __init__(self, args, feat_dim, param_seman, train_weight_base=False):
        """Build the classifier.

        Args:
            args: options; ``n_ways``, ``base_seman_calib``, ``neg_gen_type``
                and ``agg`` are read.
            feat_dim: feature dimensionality handed to the sub-modules.
            param_seman: ``(params, seman_dict)`` as consumed by
                ``init_representation``.
            train_weight_base: whether the base weights/bias require grad.
        """
        super(Classifier, self).__init__()

        # Weight & Bias for Base
        self.train_weight_base = train_weight_base
        self.init_representation(param_seman)
        if train_weight_base:
            print('Enable training base class weights')

        self.calibrator = SupportCalibrator(
            nway=args.n_ways,
            feat_dim=feat_dim,
            n_head=1,
            base_seman_calib=args.base_seman_calib,
            neg_gen_type=args.neg_gen_type,
        )
        self.open_generator = OpenSetGenerater(
            args.n_ways, feat_dim, n_head=1, neg_gen_type=args.neg_gen_type, agg=args.agg)
        self.metric = Metric_Cosine()

    def forward(self, features, cls_ids, test=False):
        """Score query and open-set features for one batch of episodes.

        Args:
            features: ``(support_feat, query_feat, openset_feat)``;
                ``support_feat`` must be 4-D and is averaged over dim 2.
            cls_ids: ``(supp_ids, base_ids)`` forwarded to ``get_representation``.
            test: unused.

        Returns:
            ``((query_scores, openset_scores), supp_protos, fakeclass_protos,
            (base_weights, base_wgtmem), funit_distance)``.
        """
        ## bs: features[0].size(0)
        ## support_feat: bs*nway*nshot*D
        ## query_feat: bs*(nway*nquery)*D
        ## base_ids: bs*54
        support_feat, query_feat, openset_feat = features
        supp_ids, base_ids = cls_ids

        # Shape check only: the unpack fails unless support_feat is 4-D.
        _nb, _nc, _ns, _ndim = support_feat.size()
        _nq = query_feat.size(1)

        base_weights, base_wgtmem, base_seman, support_seman = self.get_representation(supp_ids, base_ids)
        shot_mean = torch.mean(support_feat, dim=2)

        supp_protos, support_attn = self.calibrator(shot_mean, base_weights, support_seman, base_seman)
        fakeclass_protos, recip_unit = self.open_generator(supp_protos, base_weights, support_seman, base_seman)
        cls_protos = torch.cat([supp_protos, fakeclass_protos], dim=1)

        test_cosine_scores = (
            self.metric(cls_protos, query_feat),
            self.metric(cls_protos, openset_feat),
        )

        # Distance (1 - similarity) of every query to the generator's per-way outputs.
        query_funit_distance = 1.0 - self.metric(recip_unit, query_feat)
        qopen_funit_distance = 1.0 - self.metric(recip_unit, openset_feat)
        funit_distance = torch.cat([query_funit_distance, qopen_funit_distance], dim=1)

        return test_cosine_scores, supp_protos, fakeclass_protos, (base_weights, base_wgtmem), funit_distance

    def init_representation(self, param_seman):
        """Register base weights/bias, a frozen weight copy and the semantic vectors.

        ``param_seman`` is ``(params, seman_dict)``: ``params`` provides
        ``'cls_classifier.weight'`` / ``'cls_classifier.bias'`` and
        ``seman_dict`` maps a split name to a NumPy array.
        """
        params, seman_dict = param_seman
        trainable = self.train_weight_base
        self.weight_base = nn.Parameter(params['cls_classifier.weight'], requires_grad=trainable)
        self.bias_base = nn.Parameter(params['cls_classifier.bias'], requires_grad=trainable)
        self.weight_mem = nn.Parameter(params['cls_classifier.weight'].clone(), requires_grad=False)

        # Semantic vectors live in a plain dict as float CUDA tensors, so they
        # are not part of the module's state_dict.
        seman = {}
        for split, vectors in seman_dict.items():
            seman[split] = nn.Parameter(torch.from_numpy(vectors), requires_grad=False).float().cuda()
        self.seman = seman

    def get_representation(self, cls_ids, base_ids, randpick=False):
        """Gather base weights and semantic vectors for the given episodes.

        With ``base_ids`` given, everything is indexed from the ``'base'``
        split.  With ``base_ids`` None, all base entries are repeated per
        episode and the support semantics come from ``'novel_test'``.

        Returns ``(base_weights, base_wgtmem, base_seman, supp_seman)``.
        """
        if base_ids is None:
            n_episodes = cls_ids.size(0)
            base_weights = self.weight_base.repeat(n_episodes, 1, 1)
            base_wgtmem = self.weight_mem.repeat(n_episodes, 1, 1)
            base_seman = self.seman['base'].repeat(n_episodes, 1, 1)
            supp_seman = self.seman['novel_test'][cls_ids, :]
        else:
            base_weights = self.weight_base[base_ids, :]  ## bs*54*D
            base_wgtmem = self.weight_mem[base_ids, :]
            base_seman = self.seman['base'][base_ids, :]
            supp_seman = self.seman['base'][cls_ids, :]

        if randpick:
            # NOTE(review): self.base_size is not assigned anywhere in the
            # visible code - confirm it is set before using randpick=True.
            num_base = base_weights.shape[1]
            base_size = self.base_size
            picked = np.random.choice(list(range(num_base)), size=base_size, replace=False)
            base_weights = base_weights[:, picked, :]
            base_seman = base_seman[:, picked, :]

        return base_weights, base_wgtmem, base_seman, supp_seman
class SupportCalibrator(nn.Module):
    """Calibrate support prototypes by attending over the base-class weights.

    ``neg_gen_type`` selects the memory the attention reads:
    ``'semang'`` (gated visual + semantic base memory), ``'attg'`` (visual
    base memory only), ``'att'`` (the support features themselves); any other
    value returns the support features unchanged.
    """

    def __init__(self, nway, feat_dim, n_head=1, base_seman_calib=True, neg_gen_type='semang'):
        super(SupportCalibrator, self).__init__()
        self.nway = nway
        self.feat_dim = feat_dim
        self.base_seman_calib = base_seman_calib

        # Semantic vectors are 300-d throughout this module.
        self.map_sem = nn.Sequential(
            nn.Linear(300, 300),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.1),
            nn.Linear(300, 300),
        )

        head_dim = feat_dim // n_head
        self.calibrator = MultiHeadAttention(head_dim, head_dim, (feat_dim, feat_dim))

        self.neg_gen_type = neg_gen_type
        if neg_gen_type == 'semang':
            # Task-level gates computed from the averaged [visual, semantic] memory.
            self.task_visfuse = nn.Linear(feat_dim + 300, feat_dim)
            self.task_semfuse = nn.Linear(feat_dim + 300, 300)

    def _seman_calib(self, seman):
        """Project semantic vectors through ``map_sem``."""
        return self.map_sem(seman)

    def forward(self, support_feat, base_weights, support_seman, base_seman):
        """Return ``(support_center, support_attn)``, each viewed as ``(bs, nway, -1)``.

        ``support_attn`` is None when ``neg_gen_type`` is not one of
        ``'semang'``, ``'attg'``, ``'att'``.
        """
        ## support_feat: bs*nway*640, base_weights: bs*num_base*640, support_seman: bs*nway*300, base_seman:bs*num_base*300
        n_bs, n_base_cls = base_weights.size()[:2]

        # One copy of the base memory per (episode, way) pair.
        base_weights = base_weights.unsqueeze(dim=1).repeat(1, self.nway, 1, 1)
        base_weights = base_weights.view(-1, n_base_cls, self.feat_dim)
        support_feat = support_feat.view(-1, 1, self.feat_dim)

        mode = self.neg_gen_type
        if mode not in ('semang', 'attg', 'att'):
            return support_feat.view(n_bs, self.nway, -1), None

        if mode == 'semang':
            support_seman = self._seman_calib(support_seman)
            if self.base_seman_calib:
                base_seman = self._seman_calib(base_seman)

            base_seman = base_seman.unsqueeze(dim=1).repeat(1, self.nway, 1, 1).view(-1, n_base_cls, 300)
            support_seman = support_seman.view(-1, 1, 300)

            # The ungated visual memory is what the attention uses as values.
            value_mem = base_weights
            task_summary = torch.mean(torch.cat([base_weights, base_seman], -1), 1, keepdim=True)

            # Gates lie in (1, 2): they only ever amplify a memory channel.
            gate_vis = torch.sigmoid(self.task_visfuse(task_summary)) + 1.0
            gate_sem = torch.sigmoid(self.task_semfuse(task_summary)) + 1.0

            base_weights = base_weights * gate_vis
            base_seman = base_seman * gate_sem
        elif mode == 'attg':
            value_mem = base_weights
            support_seman = None
            base_seman = None
        else:  # 'att': attend over the support features themselves
            base_weights = support_feat
            value_mem = support_feat
            support_seman = None
            base_seman = None

        support_center, _, support_attn, _ = self.calibrator(
            support_feat, base_weights, value_mem, support_seman, base_seman)

        support_center = support_center.view(n_bs, self.nway, -1)
        support_attn = support_attn.view(n_bs, self.nway, -1)
        return support_center, support_attn
torch.sigmoid(self.task_semfuse(avg_task_mem)) + 1.0 127 | 128 | base_weights = base_mem_vis * gate_vis 129 | base_seman = base_mem_seman * gate_sem 130 | 131 | elif self.neg_gen_type == 'attg': 132 | base_mem_vis = base_weights 133 | base_seman = None 134 | support_seman = None 135 | 136 | elif self.neg_gen_type == 'att': 137 | base_weights = support_feat 138 | base_mem_vis = support_feat 139 | support_seman = None 140 | base_seman = None 141 | 142 | else: 143 | return support_feat.view(n_bs,self.nway,-1), None 144 | 145 | support_center, _, support_attn, _ = self.calibrator(support_feat, base_weights, base_mem_vis, support_seman, base_seman) 146 | 147 | support_center = support_center.view(n_bs,self.nway,-1) 148 | support_attn = support_attn.view(n_bs,self.nway,-1) 149 | return support_center, support_attn 150 | 151 | 152 | class OpenSetGenerater(nn.Module): 153 | def __init__(self, nway, featdim, n_head=1, neg_gen_type='semang', agg='avg'): 154 | super(OpenSetGenerater, self).__init__() 155 | self.nway = nway 156 | self.att = MultiHeadAttention(featdim//n_head, featdim//n_head, (featdim,featdim)) 157 | self.featdim = featdim 158 | 159 | self.neg_gen_type = neg_gen_type 160 | if neg_gen_type == 'semang': 161 | self.task_visfuse = nn.Linear(featdim+300,featdim) 162 | self.task_semfuse = nn.Linear(featdim+300,300) 163 | 164 | 165 | self.agg = agg 166 | if agg == 'mlp': 167 | self.agg_func = nn.Sequential(nn.Linear(featdim,featdim),nn.LeakyReLU(0.5),nn.Dropout(0.5),nn.Linear(featdim,featdim)) 168 | 169 | self.map_sem = nn.Sequential(nn.Linear(300,300),nn.LeakyReLU(0.1),nn.Dropout(0.1),nn.Linear(300,300)) 170 | 171 | def _seman_calib(self, seman): 172 | ### feat: bs*d*feat_dim, seman: bs*d*300 173 | seman = self.map_sem(seman) 174 | return seman 175 | 176 | def forward(self, support_center, base_weights, support_seman=None, base_seman=None): 177 | ## support_center: bs*nway*D 178 | ## weight_base: bs*nbase*D 179 | bs = support_center.shape[0] 180 | n_bs, n_base_cls = 
base_weights.size()[:2] 181 | 182 | base_weights = base_weights.unsqueeze(dim=1).repeat(1,self.nway,1,1).view(-1, n_base_cls, self.featdim) 183 | support_center = support_center.view(-1, 1, self.featdim) 184 | 185 | if self.neg_gen_type=='semang': 186 | support_seman = self._seman_calib(support_seman) 187 | base_seman = base_seman.unsqueeze(dim=1).repeat(1,self.nway,1,1).view(-1, n_base_cls, 300) 188 | support_seman = support_seman.view(-1, 1, 300) 189 | 190 | base_mem_vis = base_weights 191 | task_mem_vis = base_weights 192 | 193 | base_mem_seman = base_seman 194 | task_mem_seman = base_seman 195 | avg_task_mem = torch.mean(torch.cat([task_mem_vis,task_mem_seman],-1), 1, keepdim=True) 196 | 197 | gate_vis = torch.sigmoid(self.task_visfuse(avg_task_mem)) + 1.0 198 | gate_sem = torch.sigmoid(self.task_semfuse(avg_task_mem)) + 1.0 199 | 200 | base_weights = base_mem_vis * gate_vis 201 | base_seman = base_mem_seman * gate_sem 202 | 203 | 204 | elif self.neg_gen_type == 'attg': 205 | base_mem_vis = base_weights 206 | support_seman = None 207 | base_seman = None 208 | 209 | elif self.neg_gen_type == 'att': 210 | base_weights = support_center 211 | base_mem_vis = support_center 212 | support_seman = None 213 | base_seman = None 214 | 215 | else: 216 | fakeclass_center = support_center.mean(dim=0, keepdim=True) 217 | if self.agg == 'mlp': 218 | fakeclass_center = self.agg_func(fakeclass_center) 219 | return fakeclass_center, support_center.view(bs, -1, self.featdim) 220 | 221 | 222 | output, attcoef, attn_score, value = self.att(support_center, base_weights, base_mem_vis, support_seman, base_seman) ## bs*nway*nbase 223 | 224 | output = output.view(bs, -1, self.featdim) 225 | fakeclass_center = output.mean(dim=1,keepdim=True) 226 | 227 | if self.agg == 'mlp': 228 | fakeclass_center = self.agg_func(fakeclass_center) 229 | 230 | return fakeclass_center, output 231 | 232 | 233 | class MultiHeadAttention(nn.Module): 234 | ''' Multi-Head Attention module ''' 235 | 236 | def 
class MultiHeadAttention(nn.Module):
    """Multi-head attention with an optional semantic query/key branch.

    ``d_model`` is a sequence of input widths: ``d_model[0]`` for queries,
    ``d_model[1]`` for keys and ``d_model[-1]`` for values. Semantic inputs
    are expected to have width 300.
    """

    def __init__(self, d_k, d_v, d_model, n_head=1, dropout=0.1):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        q_width, k_width, v_width = d_model[0], d_model[1], d_model[-1]

        # Visual projection heads. Creation and init order are kept as-is so
        # the random stream and the state_dict layout stay the same.
        self.w_qs = nn.Linear(q_width, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(k_width, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(v_width, n_head * d_v, bias=False)
        visual_heads = (
            (self.w_qs, q_width + d_k),
            (self.w_ks, k_width + d_k),
            (self.w_vs, v_width + d_v),
        )
        for layer, fan in visual_heads:
            nn.init.normal_(layer.weight, mean=0, std=np.sqrt(2.0 / fan))

        # Semantic projection heads (w_vs_sem is created but not used by forward).
        self.w_qs_sem = nn.Linear(300, n_head * d_k, bias=False)
        self.w_ks_sem = nn.Linear(300, n_head * d_k, bias=False)
        self.w_vs_sem = nn.Linear(300, n_head * d_k, bias=False)
        for layer in (self.w_qs_sem, self.w_ks_sem, self.w_vs_sem):
            nn.init.normal_(layer.weight, mean=0, std=np.sqrt(2.0 / 600))

        self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))

        self.fc = nn.Linear(n_head * d_v, d_model[0], bias=False)
        nn.init.xavier_normal_(self.fc.weight)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def _split_heads(projected, batch, length, n_head, width):
        """Reshape ``batch x length x (n_head*width)`` to ``(n_head*batch) x length x width``."""
        per_head = projected.view(batch, length, n_head, width)
        return per_head.permute(2, 0, 1, 3).contiguous().view(-1, length, width)

    def forward(self, q, k, v, q_sem=None, k_sem=None, mark_res=True):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q, k, v: 3-D tensors ``batch x length x width`` (q: bs*nway*D).
            q_sem, k_sem: optional semantic tensors for queries / keys; when
                ``q_sem`` is given both are projected and handed to the
                attention module together with the visual projections.
            mark_res: add the un-projected ``q`` as a residual when true.

        Returns:
            ``(output, attn, attn_score, v)`` where ``v`` is the projected,
            head-split value tensor.
        """
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, _ = q.size()
        sz_b, len_k, _ = k.size()
        sz_b, len_v, _ = v.size()

        shortcut = q

        q = self._split_heads(self.w_qs(q), sz_b, len_q, n_head, d_k)  # (n*b) x lq x dk
        k = self._split_heads(self.w_ks(k), sz_b, len_k, n_head, d_k)  # (n*b) x lk x dk
        v = self._split_heads(self.w_vs(v), sz_b, len_v, n_head, d_v)  # (n*b) x lv x dv

        if q_sem is not None:
            sz_b, len_q, _ = q_sem.size()
            sz_b, len_k, _ = k_sem.size()
            q_sem = self._split_heads(self.w_qs_sem(q_sem), sz_b, len_q, n_head, d_k)
            k_sem = self._split_heads(self.w_ks_sem(k_sem), sz_b, len_k, n_head, d_k)

        output, attn, attn_score = self.attention(q, k, v, q_sem, k_sem)

        # Merge the heads back: b x lq x (n*dv).
        output = output.view(n_head, sz_b, len_q, d_v)
        output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1)

        output = self.dropout(self.fc(output))
        if mark_res:
            output = output + shortcut

        return output, attn, attn_score, v
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention with optional additive semantic terms.

    When ``q_sem`` is given, the semantic projections are added to the visual
    queries and keys before the dot product.
    """

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, q_sem=None, k_sem=None):
        """Return ``(output, attn, attn_score)``.

        Args:
            q: queries, ``B x Lq x d``.
            k: keys, ``B x Lk x d``.
            v: values, ``B x Lk x dv``.
            q_sem, k_sem: optional tensors with the same shapes as ``q`` / ``k``.

        Returns:
            output: ``attn @ v``.
            attn: softmax over keys of the scaled scores, after dropout.
            attn_score: the scores divided by ``temperature`` (pre-softmax).
        """
        if q_sem is not None:
            # The previous code also computed a visual-only score and a
            # semantic-only score here and discarded both; only the fused
            # score below was ever used.
            q = q + q_sem
            k = k + k_sem

        attn_score = torch.bmm(q, k.transpose(1, 2))
        attn_score /= self.temperature  # in place is safe: freshly allocated by bmm

        attn = self.dropout(self.softmax(attn_score))
        output = torch.bmm(attn, v)
        return output, attn, attn_score
class Metric_Cosine(nn.Module):
    """Cosine-similarity logits scaled by a learnable temperature."""

    def __init__(self, temperature=10):
        super().__init__()
        # Learnable scale applied to the cosine similarities.
        self.temp = nn.Parameter(torch.tensor(float(temperature)))

    def forward(self, supp_center, query_feature):
        """Score every query against every class center.

        Args:
            supp_center: class centers, bs*nway*D.
            query_feature: query embeddings, bs*(nway*nquery)*D.

        Returns:
            Tensor bs*(nway*nquery)*nway of cosine similarities times ``temp``.
        """
        # L2-normalise along the feature axis (F.normalize default eps=1e-12).
        unit_centers = F.normalize(supp_center, dim=-1)
        unit_queries = F.normalize(query_feature, dim=-1)
        cosine = torch.bmm(unit_queries, unit_centers.transpose(1, 2))
        return cosine * self.temp
class OpenMini(Dataset):
    """Episodic miniImageNet dataset for open-set few-shot learning.

    Every item is one episode: ``n_ways`` closed-set classes with support and
    query images, plus ``n_open_ways`` disjoint open-set classes with their
    own support ("suppopen") and query ("openset") images.
    """

    def __init__(self, args, partition='test', mode='episode', is_training=False, fix_seed=True):
        """Read the episode configuration from ``args`` and load the split.

        ``args`` must provide ``n_ways``, ``n_open_ways``, ``n_shots``,
        ``n_queries``, ``n_test_runs``, ``n_train_runs``,
        ``n_aug_support_samples`` and ``data_root``.
        """
        super(OpenMini, self).__init__()
        self.mode = mode
        self.fix_seed = fix_seed
        self.n_ways = args.n_ways
        self.n_open_ways = args.n_open_ways
        self.n_shots = args.n_shots
        self.n_queries = args.n_queries
        self.n_episodes = args.n_test_runs if partition == 'test' else args.n_train_runs
        # Support-set replication is disabled for the 'train' partition.
        self.n_aug_support_samples = 1 if partition == 'train' else args.n_aug_support_samples
        self.partition = partition

        mean = [120.39586422 / 255.0, 115.59361427 / 255.0, 104.54012653 / 255.0]
        std = [70.68188272 / 255.0, 68.27635443 / 255.0, 72.54505529 / 255.0]
        normalize = transforms.Normalize(mean=mean, std=std)

        # The support transform stays random outside training as well, so the
        # replicated support copies differ from each other; only colour
        # jitter is training-specific.
        if is_training:
            self.train_transform = transforms.Compose([
                transforms.RandomCrop(84, padding=8),
                transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize
            ])
        else:
            self.train_transform = transforms.Compose([
                transforms.RandomCrop(84, padding=8),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize
            ])

        # NOTE: pickle can execute arbitrary code on load; only point
        # data_root at trusted dataset files.
        with open(os.path.join(args.data_root, 'miniImageNet_category_vector.pickle'), 'rb') as f:
            pack = pickle.load(f, encoding='latin1')
        vector_array = np.array([pack[i][1] for i in range(100)])  # Train 0~63, Val 64~79, Test 80~99
        # The 'nove_val' spelling is kept: code outside this class may use that key.
        self.vector_array = {'base': vector_array[:64], 'nove_val': vector_array[64:80], 'novel_test': vector_array[80:]}

        self.test_transform = transforms.Compose([transforms.ToTensor(), normalize])
        self.init_episode(args.data_root, partition)

    def __getitem__(self, item):
        return self.get_episode(item)

    def init_episode(self, data_root, partition):
        """Load the images of ``partition`` and group them by zero-based label."""
        suffix = partition if partition in ['val', 'test'] else 'train_phase_train'
        filename = 'miniImageNet_category_split_{}.pickle'.format(suffix)

        with open(os.path.join(data_root, filename), 'rb') as f:
            pack = pickle.load(f, encoding='latin1')
        imgs = pack['data'].astype('uint8')
        labels = pack['labels']
        self.imgs = [Image.fromarray(x) for x in imgs]
        min_label = min(labels)
        self.labels = [x - min_label for x in labels]
        print('Load {} Data of {} for miniImagenet in Meta-Learning Stage'.format(len(self.imgs), partition))

        self.data = {}
        for img, label in zip(self.imgs, self.labels):
            self.data.setdefault(label, []).append(img)
        self.classes = list(self.data.keys())

    @staticmethod
    def _replicate_shots(samples, n_shots, n_copies):
        """Repeat every consecutive group of ``n_shots`` entries ``n_copies`` times."""
        replicated = []
        for start in range(0, len(samples), n_shots):
            replicated.extend(samples[start:start + n_shots] * n_copies)
        return replicated

    def get_episode(self, item):
        """Sample one episode; deterministic per ``item`` when ``fix_seed`` is set.

        Returns:
            ``(support_xs, support_ys, query_xs, query_ys, suppopen_xs,
            suppopen_ys, openset_xs, openset_ys, cls_sampled, cls_open_ids)``
            and, for the 'train' partition only, a trailing ``base_ids`` array
            holding the classes that are in neither set.
            Closed-set and open-support labels are episode-local indices;
            ``openset_ys`` holds the dataset class ids.
        """
        if self.fix_seed:
            np.random.seed(item)
        cls_sampled = np.random.choice(self.classes, self.n_ways, False)
        support_xs, support_ys = [], []
        suppopen_xs, suppopen_ys = [], []
        query_xs, query_ys = [], []
        openset_xs, openset_ys = [], []

        # Closed set preparation
        for idx, the_cls in enumerate(cls_sampled):
            imgs = self.data[the_cls]
            support_ids = np.random.choice(range(len(imgs)), self.n_shots, False)
            support_xs.extend([imgs[the_id] for the_id in support_ids])
            support_ys.extend([idx] * self.n_shots)
            # Queries come from the images that were not used as support.
            query_pool = np.setdiff1d(np.arange(len(imgs)), support_ids)
            query_ids = np.random.choice(query_pool, self.n_queries, False)
            query_xs.extend([imgs[the_id] for the_id in query_ids])
            query_ys.extend([idx] * self.n_queries)

        # Open set preparation: classes disjoint from the closed set.
        cls_open_ids = np.setdiff1d(self.classes, cls_sampled)
        cls_open_ids = np.random.choice(cls_open_ids, self.n_open_ways, False)
        for idx, the_cls in enumerate(cls_open_ids):
            imgs = self.data[the_cls]
            suppopen_ids = np.random.choice(range(len(imgs)), self.n_shots, False)
            suppopen_xs.extend([imgs[the_id] for the_id in suppopen_ids])
            suppopen_ys.extend([idx] * self.n_shots)
            # Fix: draw the open queries from the images NOT used as open
            # support. The pool was computed before but then ignored, so
            # queries could duplicate the support images.
            openset_pool = np.setdiff1d(np.arange(len(imgs)), suppopen_ids)
            openset_ids = np.random.choice(openset_pool, self.n_queries, False)
            openset_xs.extend([imgs[the_id] for the_id in openset_ids])
            openset_ys.extend([the_cls] * self.n_queries)

        if self.partition == 'train':
            # Remaining classes; setdiff1d already returns them sorted.
            base_ids = np.setdiff1d(self.classes, np.concatenate([cls_sampled, cls_open_ids]))
            assert len(set(base_ids).union(set(cls_open_ids)).union(set(cls_sampled))) == len(self.classes)

        if self.n_aug_support_samples > 1:
            support_xs = self._replicate_shots(support_xs, self.n_shots, self.n_aug_support_samples)
            support_ys = self._replicate_shots(support_ys, self.n_shots, self.n_aug_support_samples)
            # Fix: iterate over the open-support list itself. The old bound was
            # the length of the (already replicated) closed support list, which
            # truncated the open support when n_open_ways exceeded
            # n_ways * n_aug_support_samples.
            suppopen_xs = self._replicate_shots(suppopen_xs, self.n_shots, self.n_aug_support_samples)
            suppopen_ys = self._replicate_shots(suppopen_ys, self.n_shots, self.n_aug_support_samples)

        support_xs = torch.stack([self.train_transform(x) for x in support_xs])
        suppopen_xs = torch.stack([self.train_transform(x) for x in suppopen_xs])
        query_xs = torch.stack([self.test_transform(x) for x in query_xs])
        openset_xs = torch.stack([self.test_transform(x) for x in openset_xs])
        support_ys, query_ys, openset_ys = np.array(support_ys), np.array(query_ys), np.array(openset_ys)
        suppopen_ys = np.array(suppopen_ys)
        cls_sampled, cls_open_ids = np.array(cls_sampled), np.array(cls_open_ids)

        episode = (support_xs, support_ys, query_xs, query_ys, suppopen_xs, suppopen_ys,
                   openset_xs, openset_ys, cls_sampled, cls_open_ids)
        if self.partition == 'train':
            return episode + (base_ids,)
        return episode

    def __len__(self):
        return self.n_episodes
class GenMini(OpenMini):
    """Generalized open-set episodes: novel-class tasks plus base-class queries.

    On top of the closed/open sets, every episode carries "many-shot" query
    images drawn from the base classes. For 'train' the novel classes are the
    base classes themselves; for 'test' the novel labels are shifted to start
    right after the base labels.
    """

    def __init__(self, args, partition='test', mode='episode', is_training=False, fix_seed=True):
        super(GenMini, self).__init__(args, partition, mode, is_training, fix_seed)

    def __getitem__(self, item):
        return self.get_episode(item)

    @staticmethod
    def _load_split(data_root, filename, label_offset=0):
        """Load one pickle split; return ``(imgs, labels, data_by_label, classes)``.

        Labels are shifted so the smallest one equals ``label_offset``.
        NOTE: pickle can execute arbitrary code on load; only use trusted files.
        """
        with open(os.path.join(data_root, filename), 'rb') as f:
            pack = pickle.load(f, encoding='latin1')
        imgs = [Image.fromarray(x) for x in pack['data'].astype('uint8')]
        raw_labels = pack['labels']
        min_label = min(raw_labels)
        labels = [x - min_label + label_offset for x in raw_labels]
        data = {}
        for img, label in zip(imgs, labels):
            data.setdefault(label, []).append(img)
        return imgs, labels, data, list(data.keys())

    @staticmethod
    def _replicate_shots(samples, n_shots, n_copies):
        """Repeat every consecutive group of ``n_shots`` entries ``n_copies`` times."""
        replicated = []
        for start in range(0, len(samples), n_shots):
            replicated.extend(samples[start:start + n_shots] * n_copies)
        return replicated

    def init_episode(self, data_root, partition):
        """Load base and novel data for ``partition`` ('train' or 'test').

        Raises:
            ValueError: for any other partition (previously this surfaced as
                an AttributeError on ``self.base_imgs`` further down).
        """
        if partition == 'train':
            base = self._load_split(data_root, 'miniImageNet_category_split_train_phase_train.pickle')
            # Training tasks are sampled from the base classes themselves;
            # the novel_* attributes alias the base_* objects.
            novel = base
        elif partition == 'test':
            base = self._load_split(data_root, 'miniImageNet_category_split_train_phase_test.pickle')
            novel = self._load_split(data_root, 'miniImageNet_category_split_test.pickle',
                                     label_offset=len(base[3]))
        else:
            raise ValueError("GenMini supports partition 'train' or 'test', got {!r}".format(partition))

        self.base_imgs, self.base_labels, self.base_data, self.base_classes = base
        self.novel_imgs, self.novel_labels, self.novel_data, self.novel_classes = novel

        print('Load {} Data of {} for miniImagenet in Meta-Learning Stage'.format(len(self.base_imgs), partition))
        print('Load {} Data of {} for miniImagenet in Meta-Learning Stage'.format(len(self.novel_imgs), partition))

    def get_episode(self, item):
        """Sample one episode; deterministic per ``item`` when ``fix_seed`` is set.

        Returns:
            ``(support_xs, support_ys, query_xs, query_ys, suppopen_xs,
            suppopen_ys, openset_xs, openset_ys, manyshot_xs, manyshot_ys,
            cls_sampled, cls_open_ids)`` and, for 'train' only, a trailing
            ``base_ids`` array.
        """
        if self.fix_seed:
            np.random.seed(item)
        cls_sampled = np.random.choice(self.novel_classes, self.n_ways, False)
        support_xs, support_ys = [], []
        query_xs, query_ys = [], []
        suppopen_xs, suppopen_ys = [], []
        openset_xs, openset_ys = [], []
        manyshot_xs, manyshot_ys = [], []

        # Closed set
        for idx, the_cls in enumerate(cls_sampled):
            imgs = self.novel_data[the_cls]
            support_ids = np.random.choice(range(len(imgs)), self.n_shots, False)
            support_xs.extend([imgs[the_id] for the_id in support_ids])
            support_ys.extend([idx] * self.n_shots)
            query_pool = np.setdiff1d(np.arange(len(imgs)), support_ids)
            query_ids = np.random.choice(query_pool, self.n_queries, False)
            query_xs.extend([imgs[the_id] for the_id in query_ids])
            query_ys.extend([idx] * self.n_queries)

        # Open set
        cls_open_ids = np.setdiff1d(self.novel_classes, cls_sampled)
        # NOTE(review): this draws n_ways open classes while OpenMini draws
        # n_open_ways; kept unchanged -- confirm which one is intended.
        cls_open_ids = np.random.choice(cls_open_ids, self.n_ways, False)
        for idx, the_cls in enumerate(cls_open_ids):
            imgs = self.novel_data[the_cls]
            suppopen_ids = np.random.choice(range(len(imgs)), self.n_shots, False)
            suppopen_xs.extend([imgs[the_id] for the_id in suppopen_ids])
            suppopen_ys.extend([idx] * self.n_shots)
            # Fix: open queries are drawn from the images NOT used as open
            # support (the pool was computed before but then ignored).
            openset_pool = np.setdiff1d(np.arange(len(imgs)), suppopen_ids)
            openset_ids = np.random.choice(openset_pool, self.n_queries, False)
            openset_xs.extend([imgs[the_id] for the_id in openset_ids])
            openset_ys.extend([the_cls] * self.n_queries)

        if self.partition == 'train':
            # Base classes that take part in neither the closed nor the open set.
            base_ids = np.setdiff1d(self.base_classes, np.concatenate([cls_sampled, cls_open_ids]))
            assert len(set(base_ids).union(set(cls_open_ids)).union(set(cls_sampled))) == len(self.base_classes)
            base_ids = sorted(base_ids)
        else:
            base_ids = sorted(self.base_classes)

        # Spread n_ways * n_queries many-shot queries over the base classes:
        # every class gets num_atleast, randomly chosen ones get one more.
        num_query = self.n_ways * self.n_queries
        assert num_query > len(base_ids)
        num_atleast = num_query // len(base_ids)
        extra_classes = set(np.random.choice(base_ids, num_query - len(base_ids) * num_atleast, False))
        num_samples = {}
        for the_cls in base_ids:
            num_samples[the_cls] = num_atleast + 1 if the_cls in extra_classes else num_atleast

        for idx, the_cls in enumerate(base_ids):
            imgs = self.base_data[the_cls]
            manyshot_ids = np.random.choice(range(len(imgs)), num_samples[the_cls], False)
            manyshot_xs.extend([imgs[the_id] for the_id in manyshot_ids])
            manyshot_ys.extend([idx] * num_samples[the_cls])

        if self.n_aug_support_samples > 1:
            support_xs = self._replicate_shots(support_xs, self.n_shots, self.n_aug_support_samples)
            support_ys = self._replicate_shots(support_ys, self.n_shots, self.n_aug_support_samples)
            # Fix: bound the loop by the open-support list itself rather than
            # by the (already replicated) closed support list.
            suppopen_xs = self._replicate_shots(suppopen_xs, self.n_shots, self.n_aug_support_samples)
            suppopen_ys = self._replicate_shots(suppopen_ys, self.n_shots, self.n_aug_support_samples)

        support_xs = torch.stack([self.train_transform(x) for x in support_xs])
        suppopen_xs = torch.stack([self.train_transform(x) for x in suppopen_xs])
        query_xs = torch.stack([self.test_transform(x) for x in query_xs])
        openset_xs = torch.stack([self.test_transform(x) for x in openset_xs])
        manyshot_xs = torch.stack([self.test_transform(x) for x in manyshot_xs])
        support_ys, query_ys, openset_ys = np.array(support_ys), np.array(query_ys), np.array(openset_ys)
        suppopen_ys, manyshot_ys = np.array(suppopen_ys), np.array(manyshot_ys)
        cls_sampled, cls_open_ids = np.array(cls_sampled), np.array(cls_open_ids)

        episode = (support_xs, support_ys, query_xs, query_ys, suppopen_xs, suppopen_ys,
                   openset_xs, openset_ys, manyshot_xs, manyshot_ys, cls_sampled, cls_open_ids)
        if self.partition == 'train':
            return episode + (np.array(base_ids),)
        return episode

    def __len__(self):
        return self.n_episodes