├── data ├── __init__.py ├── datamgr.py └── dataset.py ├── network ├── __init__.py └── resnet.py ├── .idea ├── misc.xml ├── vcs.xml ├── .gitignore ├── inspectionProfiles │ └── profiles_settings.xml ├── modules.xml ├── FeatWalk.iml └── deployment.xml ├── methods ├── __init__.py ├── stl_deepbdc.py ├── bdc_module.py ├── template.py └── FeatWalk.py ├── run.sh ├── README.md ├── utils ├── utils.py └── loss.py └── eval.py /data/__init__.py: -------------------------------------------------------------------------------- 1 | from . import datamgr 2 | from . import dataset 3 | -------------------------------------------------------------------------------- /network/__init__.py: -------------------------------------------------------------------------------- 1 | from . import resnet 2 | # from . import convnet 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /methods/__init__.py: -------------------------------------------------------------------------------- 1 | from . import template 2 | # from . import protonet 3 | # from . import good_embed 4 | # from . import meta_deepbdc 5 | # from . import stl_deepbdc 6 | from . import bdc_module 7 | 8 | 9 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/FeatWalk.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | gpuid=0 2 | 3 | python eval.py --gpu ${gpuid} --n_episodes 2000 --n_aug_support_samples 17 --n_shot 1 --distill_model mini/ResNet12_stl_deepbdc_distill/last_model.tar --test_times 5 --lr 0.5 --fix_seed --sfc_bs 3 --sim_temperature 32 4 | python eval.py --gpu ${gpuid} --n_episodes 2000 --n_aug_support_samples 17 --n_shot 5 --distill_model mini/ResNet12_stl_deepbdc_distill/last_model.tar --test_times 5 --lr 0.01 --fix_seed --sfc_bs 3 --sim_temperature 32 -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 30 | 
-------------------------------------------------------------------------------- /methods/stl_deepbdc.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import torch.nn.functional as F 6 | from .template import MetaTemplate 7 | from sklearn.linear_model import LogisticRegression 8 | from .bdc_module import BDC 9 | 10 | 11 | class STLDeepBDC(MetaTemplate): 12 | def __init__(self, params, model_func, n_way, n_support): 13 | super(STLDeepBDC, self).__init__(params, model_func, n_way, n_support) 14 | self.loss_fn = nn.CrossEntropyLoss() 15 | 16 | reduce_dim = params.reduce_dim 17 | self.feat_dim = int(reduce_dim * (reduce_dim+1) / 2) 18 | self.dcov = BDC(is_vec=True, input_dim=self.feature.feat_dim, dimension_reduction=reduce_dim) 19 | 20 | self.C = params.penalty_C 21 | self.params = params 22 | 23 | def feature_forward(self, x): 24 | out = self.dcov(x) 25 | return out 26 | 27 | def set_forward(self, x, is_feature=True): 28 | # print(x.shape) 29 | with torch.no_grad(): 30 | z_support, z_query = self.parse_feature(x, is_feature) 31 | # print(z_support.shape) 32 | z_support = z_support.detach() 33 | z_query = z_query.detach() 34 | 35 | z_support = z_support.contiguous().view(self.n_way * self.n_support, -1) 36 | z_query = z_query.contiguous().view(self.n_way * self.n_query, -1) 37 | 38 | qry_norm = torch.norm(z_query, p=2, dim=1).unsqueeze(1).expand_as(z_query) 39 | spt_norm = torch.norm(z_support, p=2, dim=1).unsqueeze(1).expand_as(z_support) 40 | qry_normalized = z_query.div(qry_norm + 1e-6) 41 | spt_normalized = z_support.div(spt_norm + 1e-6) 42 | 43 | z_query = qry_normalized.detach().cpu().numpy() 44 | z_support = spt_normalized.detach().cpu().numpy() 45 | y_support = np.repeat(range(self.n_way), self.n_support) 46 | 47 | clf = LogisticRegression(penalty='l2', 48 | random_state=0, 49 | C=self.C, 50 | solver='lbfgs', 51 | max_iter=1000, 52 | multi_class='multinomial') 53 | clf.fit(z_support, y_support) 54 | scores = clf.predict(z_query) 55 | 56 | return scores 57 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FeatWalk 2 | 3 | FeatWalk is a method tailored for few-shot learning settings, focusing on effectively mining local views to mitigate the interference caused by discriminative features in global view pre-training. By analyzing the correlation of local views with different class prototypes, FeatWalk constructs a more comprehensive class-related representation. This method has been accepted by AAAI 2024, and this repository serves as the official implementation for reference. 4 | 5 | ## Comparison with Baseline Methods 6 | 7 | The following table demonstrates the performance of FeatWalk compared to the baseline method DeepBDC in various few-shot learning (FSL) scenarios on MiniImageNet and TieredImageNet. The results indicate that FeatWalk significantly outperforms DeepBDC in different FSL scenarios. 8 | 9 | | Method | Embedding | Mini
5-way 1-shot | Mini
5-way 5-shot | Tiered
5-way 1-shot | Tiered
5-way 5-shot | 10 | |----------|-----------|------------------------|------------------------|--------------------------|--------------------------| 11 | | DeepBDC | BDC | 67.83 ± 0.43 | 85.45 ± 0.29 | 73.82 ± 0.47 | 89.00 ± 0.30 | 12 | | FeatWalk | BDC | 70.21 ± 0.44 | 87.38 ± 0.27 | 75.25 ± 0.48 | 89.92 ± 0.29 | 13 | 14 | 15 | ## Preparation Before Running 16 | 17 | Before starting with FeatWalk, please ensure the following preparations are made: 18 | 19 | 1. Place the pre-trained models in the `checkpoint` directory. The pre-trained models can be obtained through the corresponding baseline methods or accessed from the official [DeepBDC](https://github.com/Fei-Long121/DeepBDC) implementation. 20 | 2. Ensure that datasets (such as [MiniImageNet](https://drive.google.com/file/d/1aBxfcU5cn-htIlqriiOQCOXp_t9TOm9g/view?usp=sharing)) are located in the `filelist` directory. 21 | 22 | #### Dataset Structure: 23 | ``` 24 | --FeatWalk 25 | |--filelist 26 | |--miniImageNet 27 | |--train 28 | |--val 29 | |--test 30 | ``` 31 | ## Running Commands 32 | 33 | To run FeatWalk, use the following command: 34 | 35 | ```bash 36 | # 5-Way 1-shot/5-shot on MiniImageNet 37 | sh run.sh 38 | ``` 39 | 40 | ## Acknowledgments 41 | We would like to express our heartfelt gratitude to the open-source methods [GoodEmbed](https://github.com/WangYueFt/rfs/) and [DeepBDC](https://github.com/Fei-Long121/DeepBDC). Our code for this paper was inspired and informed by these sources, and their contributions have been invaluable in supporting our work. 42 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import shutil 4 | import time 5 | import pprint 6 | import torch 7 | import numpy as np 8 | import os.path as osp 9 | import random 10 | import torch.nn.functional as F 11 | 12 | def set_seed(seed): 13 | if seed == 0: 14 | print(' random seed') 15 | torch.backends.cudnn.benchmark = True 16 | else: 17 | print('manual seed:', seed) 18 | random.seed(seed) 19 | np.random.seed(seed) 20 | torch.manual_seed(seed) 21 | torch.cuda.manual_seed_all(seed) 22 | torch.backends.cudnn.deterministic = True 23 | torch.backends.cudnn.benchmark = False 24 | 25 | def load_model(model, dir): 26 | model_dict = model.state_dict() 27 | file_dict = torch.load(dir)['state'] 28 | for k, v in file_dict.items(): 29 | if k not in model_dict: 30 | print(k) 31 | file_dict = {k: v for k, v in file_dict.items() if k in model_dict} 32 | model_dict.update(file_dict) 33 | model.load_state_dict(model_dict) 34 | return model 35 | 36 | def compute_weight_local(feat_g,feat_ql,feat_sl,temperature=2.0): 37 | # feat_g : nk * dim 38 | # feat_l : nk * m * dim 39 | [_,k,m,dim] = feat_sl.shape 40 | [n,q,m,dim] = feat_ql.shape 41 | 42 | feat_g_expand = feat_g.unsqueeze(2).expand_as(feat_ql) 43 | sim_gl = torch.cosine_similarity(feat_g_expand,feat_ql,dim=-1) 44 | I_opp_m = (1 - torch.eye(m)).unsqueeze(0).to(sim_gl.device) 45 | sim_gl = -(torch.matmul(sim_gl, I_opp_m).unsqueeze(-2))/(m-1) 46 | 47 | 48 | return sim_gl 49 | 50 | # proto_walk 51 | def compute_weight_local(feat_g,feat_ql,feat_sl,measure = "cosine"): 52 | # feat_g : nk * dim 53 | # feat_l : nk * m * dim 54 | [_,k,m,dim] = feat_sl.shape 55 | [n,q,m,dim] = feat_ql.shape 56 | # print(feat_ql.shape) 57 | 58 | feat_g_expand = torch.mean(feat_g,dim=1).unsqueeze(0).unsqueeze(1).unsqueeze(3) 59 | if measure == "cosine": 60 | sim_gl = 
torch.cosine_similarity(feat_g_expand,feat_ql.unsqueeze(2),dim=-1) 61 | else: 62 | sim_gl = -1 * 0.002 * torch.sum((feat_g_expand - feat_ql.unsqueeze(2)) ** 2, dim=-1) 63 | 64 | I_m = torch.eye(m).unsqueeze(0).unsqueeze(1).to(sim_gl.device) 65 | sim_gl = torch.matmul(sim_gl, I_m) 66 | 67 | return sim_gl 68 | 69 | 70 | if __name__ == '__main__': 71 | feat_g = torch.randn((5,15,64)) 72 | # feat_g = torch.ones((5,3,64)) 73 | feat_sl = torch.randn((5,3,6,64)) 74 | feat_ql = torch.randn((5,15,6,64)) 75 | # feat_l = torch.ones((5,3,6,64)) 76 | compute_weight_local(feat_g,feat_ql,feat_sl) 77 | # print(compute_weight_local(feat_g,feat_ql,feat_sl)[0,0]) -------------------------------------------------------------------------------- /methods/bdc_module.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @file: bdc_modele.py 3 | @author: Fei Long 4 | @author: Jiaming Lv 5 | Please cite the paper below if you use the code: 6 | 7 | Jiangtao Xie, Fei Long, Jiaming Lv, Qilong Wang and Peihua Li. Joint Distribution Matters: Deep Brownian Distance Covariance for Few-Shot Classification. IEEE Int. Conf. on Computer Vision and Pattern Recognition (CVPR), 2022. 8 | 9 | Copyright (C) 2022 Fei Long and Jiaming Lv 10 | 11 | All rights reserved. 12 | ''' 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | class BDC(nn.Module): 18 | def __init__(self, is_vec=True, input_dim=640, dimension_reduction=None, activate='relu'): 19 | super(BDC, self).__init__() 20 | self.is_vec = is_vec 21 | self.dr = dimension_reduction 22 | self.activate = activate 23 | self.input_dim = input_dim[0] 24 | # self.input_dim = input_dim 25 | if self.dr is not None and self.dr != self.input_dim: 26 | if activate == 'relu': 27 | self.act = nn.ReLU(inplace=True) 28 | elif activate == 'leaky_relu': 29 | self.act = nn.LeakyReLU(0.1) 30 | else: 31 | self.act = nn.ReLU(inplace=True) 32 | 33 | self.conv_dr_block = nn.Sequential( 34 | nn.Conv2d(self.input_dim, self.dr, kernel_size=1, stride=1, bias=False), 35 | nn.BatchNorm2d(self.dr), 36 | self.act 37 | ) 38 | output_dim = self.dr if self.dr else self.input_dim 39 | if self.is_vec: 40 | self.output_dim = int(output_dim*(output_dim+1)/2) 41 | else: 42 | self.output_dim = int(output_dim*output_dim) 43 | 44 | self.temperature = nn.Parameter(torch.log((1. / (2 * input_dim[1]*input_dim[2])) * torch.ones(1,1)), requires_grad=True) 45 | 46 | self._init_weight() 47 | 48 | def _init_weight(self): 49 | for m in self.modules(): 50 | if isinstance(m, nn.Conv2d): 51 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out', nonlinearity='leaky_relu') 52 | elif isinstance(m, nn.BatchNorm2d): 53 | nn.init.constant_(m.weight, 1) 54 | nn.init.constant_(m.bias, 0) 55 | 56 | def forward(self, x): 57 | if self.dr is not None and self.dr != self.input_dim: 58 | x = self.conv_dr_block(x) 59 | x = BDCovpool(x, self.temperature) 60 | if self.is_vec: 61 | x = Triuvec(x) 62 | else: 63 | x = x.reshape(x.shape[0], -1) 64 | return x 65 | 66 | def BDCovpool(x, t): 67 | batchSize, dim, h, w = x.data.shape 68 | M = h * w 69 | x = x.reshape(batchSize, dim, M) 70 | 71 | I = torch.eye(dim, dim, device=x.device).view(1, dim, dim).repeat(batchSize, 1, 1).type(x.dtype) 72 | I_M = torch.ones(batchSize, dim, dim, device=x.device).type(x.dtype) 73 | x_pow2 = x.bmm(x.transpose(1, 2)) 74 | dcov = I_M.bmm(x_pow2 * I) + (x_pow2 * I).bmm(I_M) - 2 * x_pow2 75 | 76 | dcov = torch.clamp(dcov, min=0.0) 77 | dcov = torch.exp(t)* dcov 78 | dcov = torch.sqrt(dcov + 1e-5) 79 | t = dcov - 1. 
/ dim * dcov.bmm(I_M) - 1. / dim * I_M.bmm(dcov) + 1. / (dim * dim) * I_M.bmm(dcov).bmm(I_M) 80 | 81 | return t 82 | 83 | 84 | def Triuvec(x): 85 | batchSize, dim, dim = x.shape 86 | r = x.reshape(batchSize, dim * dim) 87 | I = torch.ones(dim, dim).triu().reshape(dim * dim) 88 | index = I.nonzero(as_tuple = False) 89 | y = torch.zeros(batchSize, int(dim * (dim + 1) / 2), device=x.device).type(x.dtype) 90 | y = r[:, index].squeeze() 91 | return y 92 | 93 | if __name__ == '__main__': 94 | x = torch.rand((3, 4, 5, 5)) 95 | # bdc = BDC(input_dim=x.shape,dimension_reduction=4) 96 | 97 | t = torch.log((1. / (2 * 25)) * torch.ones(1,1)) 98 | print(BDCovpool(x,t)[0,:,:]) -------------------------------------------------------------------------------- /data/datamgr.py: -------------------------------------------------------------------------------- 1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate 2 | 3 | import torch 4 | from PIL import Image 5 | import numpy as np 6 | import torchvision.transforms as transforms 7 | from data.dataset import SetDataset_JSON, SimpleDataset, SetDataset, EpisodicBatchSampler, SimpleDataset_JSON 8 | from abc import abstractmethod 9 | 10 | 11 | class TransformLoader: 12 | def __init__(self, image_size): 13 | self.normalize_param = dict(mean=[0.472, 0.453, 0.410], std=[0.277, 0.268, 0.285]) 14 | 15 | self.image_size = image_size 16 | if image_size == 84: 17 | self.resize_size = 92 18 | elif image_size == 128: 19 | self.resize_size = 140 20 | elif image_size == 224: 21 | self.resize_size = 256 22 | 23 | def get_composed_transform(self, aug=False): 24 | if aug: 25 | transform = transforms.Compose([ 26 | transforms.RandomResizedCrop(self.image_size), 27 | transforms.RandomHorizontalFlip(), 28 | transforms.ColorJitter(0.4, 0.4, 0.4), 29 | transforms.ToTensor(), 30 | transforms.Normalize(**self.normalize_param) 31 | ]) 32 | else: 33 | transform = transforms.Compose([ 34 | transforms.Resize(self.resize_size), 35 | transforms.CenterCrop(self.image_size), 36 | transforms.ToTensor(), 37 | transforms.Normalize(**self.normalize_param) 38 | ]) 39 | return transform 40 | 41 | 42 | class DataManager: 43 | @abstractmethod 44 | def get_data_loader(self, data_file, aug): 45 | pass 46 | 47 | 48 | class SimpleDataManager(DataManager): 49 | def __init__(self, data_path, image_size, batch_size, json_read=False): 50 | super(SimpleDataManager, self).__init__() 51 | self.batch_size = batch_size 52 | self.data_path = data_path 53 | self.trans_loader = TransformLoader(image_size) 54 | self.json_read = json_read 55 | 56 | def get_data_loader(self, data_file, aug): # parameters that would change on train/val set 57 | transform = self.trans_loader.get_composed_transform(aug) 58 | if self.json_read: 59 | dataset = SimpleDataset_JSON(self.data_path, data_file, transform) 60 | else: 61 | dataset = SimpleDataset(self.data_path, data_file, transform) 62 | data_loader_params = dict(batch_size=self.batch_size, shuffle=True, num_workers=12, pin_memory=True) 63 | data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params) 64 | 65 | return data_loader 66 | 67 | 68 | class SetDataManager(DataManager): 69 | def __init__(self, data_path, image_size, n_way, n_support, n_query, n_episode, json_read=False,aug_num = 0,args=None): 70 | super(SetDataManager, self).__init__() 71 | self.image_size = image_size 72 | self.n_way = n_way 73 | self.batch_size = n_support + n_query 74 | self.n_episode = n_episode 75 | self.data_path = data_path 76 | self.json_read = 
json_read 77 | self.aug_num = aug_num 78 | self.args = args 79 | 80 | self.trans_loader = TransformLoader(image_size) 81 | 82 | def get_data_loader(self, data_file, aug): # parameters that would change on train/val set 83 | transform = self.trans_loader.get_composed_transform(aug) 84 | if self.json_read: 85 | # print(self.aug_num) 86 | dataset = SetDataset_JSON(self.data_path, data_file, self.batch_size, transform,aug_num=self.aug_num, args=self.args) 87 | else: 88 | dataset = SetDataset(self.data_path, data_file, self.batch_size, transform,aug_num=self.aug_num, args=self.args) 89 | sampler = EpisodicBatchSampler(len(dataset), self.n_way, self.n_episode) 90 | data_loader_params = dict(batch_sampler=sampler, pin_memory=True) 91 | data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params) 92 | return data_loader 93 | 94 | 95 | 96 | data_loader 97 | 98 | 99 | 100 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import pprint 4 | import os 5 | import time 6 | from data.datamgr import SetDataManager 7 | from methods.FeatWalk import FeatWalk_Net 8 | from utils.utils import set_seed,load_model 9 | 10 | DATA_DIR = 'data' 11 | 12 | torch.set_num_threads(4) 13 | _utils_pp = pprint.PrettyPrinter() 14 | def pprint(x): 15 | _utils_pp.pprint(x) 16 | 17 | def parse_option(): 18 | parser = argparse.ArgumentParser('arguments for model pre-train') 19 | # about dataset and network 20 | parser.add_argument('--dataset', type=str, default='miniimagenet', 21 | choices=['miniimagenet', 'cub', 'tieredimagenet', 'fc100']) 22 | parser.add_argument('--data_root', type=str, default=DATA_DIR) 23 | parser.add_argument('--model', default='resnet12',choices=['resnet12', 'resnet18', 'resnet34', 'conv64']) 24 | parser.add_argument('--img_size', default=84, type=int, choices=[84,224]) 25 | 26 | # about model : 27 | parser.add_argument('--drop_gama', default=0.5, type= float) 28 | parser.add_argument("--beta", default=0.01, type=float) 29 | parser.add_argument('--drop_rate', default=0.5, type=float) 30 | parser.add_argument('--reduce_dim', default=128, type=int) 31 | 32 | # about meta test 33 | parser.add_argument('--val_freq',default=5,type=int) 34 | parser.add_argument('--set', type=str, default='test', choices=['val', 'test'], help='the set for validation') 35 | parser.add_argument('--n_way', type=int, default=5) 36 | parser.add_argument('--n_shot', type=int, default=1) 37 | parser.add_argument('--n_aug_support_samples',type=int, default=1) 38 | parser.add_argument('--n_queries', type=int, default=15) 39 | parser.add_argument('--n_episodes', type=int, default=1000) 40 | parser.add_argument('--num_workers', default=0, type=int) 41 | parser.add_argument('--test_batch_size',default=1) 42 | parser.add_argument('--grid',default=None) 43 | 44 | # setting 45 | parser.add_argument('--gpu', default=0, type=int) 46 | parser.add_argument('--save_dir', default='checkpoint') 47 | parser.add_argument('--test_LR', default=False, action='store_true') 48 | parser.add_argument('--model_type',default='best',choices=['best','last']) 49 | parser.add_argument('--seed', default=1, type=int) 50 | parser.add_argument('--no_save_model', default=False, action='store_true') 51 | parser.add_argument('--method',default='local_proto',choices=['local_proto','good_metric','stl_deepbdc','confusion','WinSA']) 52 | parser.add_argument('--distill_model', 
default=None,type=str,help='about distillation model path') 53 | parser.add_argument('--penalty_c', default=1.0, type=float) 54 | parser.add_argument('--test_times', default=1, type=int) 55 | 56 | # confusion representation: 57 | parser.add_argument('--n_symmetry_aug', default=1, type=int) 58 | parser.add_argument('--embeding_way', default='BDC', choices=['BDC','GE','protonet','baseline++']) 59 | parser.add_argument('--wd_test', type=float, default=0.01) 60 | parser.add_argument('--LR', default=False,action='store_true') 61 | parser.add_argument('--lr', default=0.01, type=float) 62 | parser.add_argument('--optim', default='Adam',choices=['Adam', 'SGD']) 63 | parser.add_argument('--drop_few',default=0.5,type=float) 64 | parser.add_argument('--fix_seed', default=False, action='store_true') 65 | parser.add_argument('--local_scale', default=0.2 , type=float) 66 | parser.add_argument('--distill', default=False, action='store_true') 67 | parser.add_argument('--sfc_bs', default=16, type=int) 68 | parser.add_argument('--alpha', default=0.5 , type=float) 69 | parser.add_argument('--sim_temperature', default=64 , type=float) 70 | parser.add_argument('--measure', default='cosine', choices=['cosine','eudist']) 71 | 72 | args = parser.parse_args() 73 | args.n_symmetry_aug = args.n_aug_support_samples 74 | 75 | return args 76 | 77 | 78 | def model_load(args,model): 79 | # method = 'deep_emd' if args.deep_emd else 'local_match' 80 | method = args.method 81 | save_path = os.path.join(args.save_dir, args.dataset + "_" + method + "_resnet12_"+args.model_type 82 | + ("_"+str(args.model_id) if args.model_id else "") + ".pth") 83 | if args.distill_model is not None: 84 | save_path = os.path.join(args.save_dir, args.distill_model) 85 | else: 86 | assert "model load failed! 
" 87 | print('teacher model path: ' + save_path) 88 | state_dict = torch.load(save_path)['model'] 89 | model.load_state_dict(state_dict) 90 | return model 91 | 92 | 93 | def main(): 94 | args = parse_option() 95 | if args.img_size == 224 and args.transform == 'B': 96 | args.transform = 'B224' 97 | 98 | if args.grid: 99 | args.n_aug_support_samples = 1 100 | for i in args.grid: 101 | args.n_aug_support_samples += i ** 2 102 | args.n_symmetry_aug = args.n_aug_support_samples 103 | 104 | pprint(args) 105 | if args.gpu: 106 | gpu_device = str(args.gpu) 107 | else: 108 | gpu_device = "0" 109 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_device 110 | if args.fix_seed: 111 | set_seed(args.seed) 112 | 113 | json_file_read = False 114 | if args.dataset == 'cub': 115 | novel_file = 'novel.json' 116 | json_file_read = True 117 | else: 118 | novel_file = 'test' 119 | if args.dataset == 'miniimagenet': 120 | novel_few_shot_params = dict(n_way=args.n_way, n_support=args.n_shot) 121 | novel_datamgr = SetDataManager('filelist/miniImageNet', args.img_size, n_query=args.n_queries, 122 | n_episode=args.n_episodes, json_read=json_file_read,aug_num=args.n_aug_support_samples,args=args, 123 | **novel_few_shot_params) 124 | novel_loader = novel_datamgr.get_data_loader(novel_file, aug=False) 125 | num_classes = 64 126 | elif args.dataset == 'cub': 127 | novel_few_shot_params = dict(n_way=args.n_way, n_support=args.n_shot) 128 | novel_datamgr = SetDataManager('filelist/CUB',args.img_size, n_query=args.n_queries, 129 | n_episode=args.n_episodes, json_read=json_file_read,aug_num=args.n_aug_support_samples,args=args, 130 | **novel_few_shot_params) 131 | novel_loader = novel_datamgr.get_data_loader(novel_file, aug=False) 132 | num_classes = 100 133 | 134 | model = FeatWalk_Net(args,num_classes=num_classes).cuda() 135 | model.eval() 136 | model = load_model(model,os.path.join(args.save_dir,args.distill_model)) 137 | 138 | print("-"*20+" start meta test... 
"+"-"*20) 139 | acc_sum = 0 140 | confidence_sum = 0 141 | for t in range(args.test_times): 142 | with torch.no_grad(): 143 | tic = time.time() 144 | mean, confidence = model.meta_test_loop(novel_loader) 145 | acc_sum += mean 146 | confidence_sum += confidence 147 | print() 148 | print("Time {} :meta_val acc: {:.2f} +- {:.2f} elapse: {:.2f} min".format(t,mean * 100, confidence * 100, 149 | (time.time() - tic) / 60)) 150 | 151 | print("{} times \t acc: {:.2f} +- {:.2f}".format(args.test_times, acc_sum/args.test_times * 100, confidence_sum/args.test_times * 100, )) 152 | 153 | if __name__ == '__main__': 154 | main() -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | 6 | class DistillKL(nn.Module): 7 | """KL divergence for distillation""" 8 | def __init__(self, T): 9 | super(DistillKL, self).__init__() 10 | self.T = T 11 | 12 | def forward(self, y_s, y_t): 13 | p_s = F.log_softmax(y_s/self.T, dim=1) 14 | p_t = F.softmax(y_t/self.T, dim=1) 15 | loss = F.kl_div(p_s, p_t, size_average=False) * (self.T**2) / y_s.shape[0] 16 | return loss 17 | 18 | def mask_loss(out,gama=0.5): 19 | # print(out.shape) 20 | crition = torch.nn.BCELoss() 21 | out = out.contiguous().view(out.shape[0],-1) 22 | avg_imp = torch.mean(out,dim=1).unsqueeze(1) 23 | rate_Sa = torch.mean(torch.where(out >= avg_imp, 1, 0).float(), dim=-1) 24 | imp_gama = 1 - rate_Sa * gama 25 | 26 | value, ind = torch.sort(out, dim=1, descending=True) 27 | drop_ind = torch.ceil((1 - imp_gama) * out.shape[-1]) 28 | threshold = value[range(out.shape[0]), drop_ind.long()] 29 | threshold = threshold.unsqueeze(1).expand_as(out) 30 | fore_mask = torch.where(out >= threshold, 1, 0).float() 31 | loss_mask = crition(out,fore_mask) 32 | # print(loss_mask) 33 | return loss_mask 34 | 35 | def uniformity_loss(feat, const_feat,label=None,temp=0.5): 36 | sim_aa = torch.cosine_similarity(feat, const_feat, dim=-1) 37 | feat_expand = feat.unsqueeze(0).repeat(feat.shape[0],1,1) 38 | const_feat_expand = const_feat.unsqueeze(1).expand_as(feat_expand) 39 | sim_ab = torch.cosine_similarity(feat_expand, const_feat_expand,dim=-1) 40 | sim_a = torch.exp(sim_aa/temp) 41 | sim_b = torch.exp(sim_ab/temp) 42 | sim_tot = torch.sum(sim_b + 1e-6,dim=-1) 43 | if label is not None: 44 | 45 | sim_idx = torch.cat([torch.sum(sim_b[i,torch.where(label.squeeze(0) == label.squeeze(0)[i])[0]],dim=-1).unsqueeze(0) 46 | for i in range(sim_b.shape[0])],dim=0) 47 | 48 | p = sim_idx/sim_tot 49 | 50 | else: 51 | p = sim_a / sim_tot 52 | loss = torch.mean(-torch.log(p+1e-8)) 53 | return loss 54 | 55 | def Distance_Correlation(latent, control): 56 | latent = F.normalize(latent) 57 | control = F.normalize(control) 58 | 59 | matrix_a = torch.sqrt(torch.sum(torch.square(latent.unsqueeze(0) - latent.unsqueeze(1)), dim=-1) + 1e-12) 60 | matrix_b = torch.sqrt(torch.sum(torch.square(control.unsqueeze(0) - control.unsqueeze(1)), dim=-1) + 1e-12) 61 | 62 | matrix_A = matrix_a - torch.mean(matrix_a, dim=0, keepdims=True) - torch.mean(matrix_a, dim=1, 63 | keepdims=True) + torch.mean(matrix_a) 64 | matrix_B = matrix_b - torch.mean(matrix_b, dim=0, keepdims=True) - torch.mean(matrix_b, dim=1, 65 | keepdims=True) + torch.mean(matrix_b) 66 | 67 | Gamma_XY = torch.sum(matrix_A * matrix_B) / (matrix_A.shape[0] * matrix_A.shape[1]) 68 | Gamma_XX = torch.sum(matrix_A * matrix_A) / 
(matrix_A.shape[0] * matrix_A.shape[1]) 69 | Gamma_YY = torch.sum(matrix_B * matrix_B) / (matrix_A.shape[0] * matrix_A.shape[1]) 70 | 71 | correlation_r = Gamma_XY / torch.sqrt(Gamma_XX * Gamma_YY + 1e-9) 72 | return correlation_r 73 | 74 | def area_loss(out,gama=0.5): 75 | # print(out.shape) 76 | out = out.contiguous().view(out.shape[0],-1) 77 | y = torch.mean(out,-1) 78 | avg_imp = torch.mean(out,dim=-1).unsqueeze(1) 79 | rate_Sa = torch.mean(torch.where(out >= avg_imp, 1, 0).float(), dim=-1) 80 | imp_gama = rate_Sa * gama 81 | imp_gama = torch.cat([imp_gama.unsqueeze(1),1-imp_gama.unsqueeze(1)],dim=-1) 82 | y = torch.cat([y.unsqueeze(1), 1 - y.unsqueeze(1)], dim=-1) 83 | loss_area = F.kl_div(y.log(),imp_gama, reduction='batchmean') 84 | return loss_area 85 | 86 | def cosine_sim(out,lab): 87 | 88 | if len(lab.size()) == 1: 89 | label = torch.zeros((out.size(0), 90 | out.size(1))).long().cuda() 91 | label_range = torch.arange(0, out.size(0)).long() 92 | label[label_range, lab] = 1 93 | lab = label 94 | 95 | return torch.mean(torch.abs(out) * lab) 96 | 97 | def ce_loss(out, lab,temperature=1,is_softmax = True): 98 | 99 | if is_softmax: 100 | out = F.softmax(out*temperature, 1) 101 | if len(lab.size()) == 1: 102 | label = torch.zeros((out.size(0), 103 | out.size(1))).long().cuda() 104 | label_range = torch.arange(0, out.size(0)).long() 105 | label[label_range, lab] = 1 106 | lab = label 107 | loss = torch.mean(torch.sum(-lab*torch.log(out+1e-8),1)) 108 | 109 | return loss 110 | 111 | # 计算信息熵的大小 112 | def entropy_loss(out): 113 | # crition = torch.nn.BCELoss() 114 | out = F.softmax(out, 1) 115 | # print(out) 116 | # pred = torch.ones_like(out)/out.shape[1] 117 | # loss = crition(pred,out) 118 | loss = -torch.mean(torch.sum(out*torch.log(out + 1e-8), 1)) 119 | return loss 120 | 121 | def Few_loss(out,lab): 122 | # 目的似乎是实现poly loss,但实践过程中有误 123 | # 这个损失意义不大 124 | out = F.softmax(out, 1) 125 | eps = 2 126 | n = 1 127 | poly_head = torch.zeros(out.size(0),out.size(1)).cuda() 128 | for i in range(n): 129 | poly_head += eps*1/(i+1)*torch.pow(1-out,(i+1)) 130 | ce_loss = torch.sum(-lab * torch.log(out + 1e-8) - poly_head,1) 131 | loss = torch.mean(ce_loss) 132 | return loss 133 | 134 | def loc_loss(out_loc,lab): 135 | 136 | out_loc = F.sigmoid(out_loc) 137 | # print(out_loc) 138 | log_loc = (-lab) * torch.log(out_loc + 1e-8)-(1-lab)* torch.log(out_loc + 1e-8) 139 | # loss = torch.mean(torch.sum(log_loc, 1)) 140 | loss = torch.mean(torch.mean(log_loc, 1)) 141 | 142 | # out_loc = out_loc.view(out_loc.size(0),out_loc.size(1),-1,2) 143 | # out_loc = F.softmax(out_loc,dim=3) 144 | 145 | return loss 146 | 147 | def euclidean_dist(x, y): 148 | ''' 149 | Compute euclidean distance between two tensors 150 | ''' 151 | # x: N x D 152 | # y: M x D 153 | n = x.size(0) 154 | m = y.size(0) 155 | d = x.size(1) 156 | if d != y.size(1): 157 | raise Exception 158 | 159 | # unsqueeze 在dim维度进行扩展 160 | x = x.unsqueeze(1).expand(n, m, d) 161 | y = y.unsqueeze(0).expand(n, m, d) 162 | 163 | return torch.pow(x - y, 2).sum(2) 164 | 165 | def prototypical_Loss(feat_out,lab,prototypes,epoch,center=False,temperature = 1): 166 | temperature = 256 167 | def supp_idxs(c): 168 | # FIXME when torch will support where as np 169 | return label_cpu.eq(c).nonzero()[:].squeeze(1) 170 | 171 | feat_cpu = feat_out.cpu() 172 | label_cpu = lab.cpu() 173 | prototypes = prototypes.cpu() 174 | 175 | n_classes = prototypes.size(0) 176 | if len(label_cpu.size()) == 1: 177 | classes = np.unique(label_cpu) 178 | # map :调用函数supp_idsx classes作为参数列表 
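# (translation of the comment above: `map` calls supp_idxs on each class in `classes`, collecting the support-sample indices of that class)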
179 | support_idxs = list(map(supp_idxs,classes)) 180 | prototypes_update = torch.stack([feat_cpu[idx_list].mean(0) for idx_list in support_idxs]) 181 | else: 182 | classes = range(n_classes) 183 | count = sum(label_cpu, 0) 184 | # feat_cpu dim : 64 * 640 185 | # label dim : 64 * 5 186 | prototypes_update = torch.matmul(feat_cpu.T,label_cpu.float())/torch.tensor(count).float() 187 | prototypes_update = prototypes_update.T 188 | # if epoch == 0 : 189 | # beta = 0.9 190 | prototypes[classes, :] = prototypes_update.detach() 191 | 192 | if len(lab.size()) == 1: 193 | label = torch.zeros((feat_cpu.size(0), 194 | n_classes)).long().cuda() 195 | label_range = torch.arange(0, feat_cpu.size(0)).long() 196 | label[label_range, lab] = 1 197 | lab = label 198 | dists = euclidean_dist(feat_cpu,prototypes)/temperature 199 | # print(dists.shape) 200 | log_p_y = F.log_softmax(-dists, dim=1) 201 | y = F.softmax(-dists,1) 202 | 203 | loss = torch.mean(torch.sum(-lab.cpu() * torch.log(y+1e-8),1)) 204 | # print(loss) 205 | return loss,prototypes 206 | 207 | if __name__ == '__main__': 208 | # exp = torch.rand((3,5,5)) 209 | # print(area_loss(torch.sigmoid(exp))) 210 | feat = torch.rand((5, 640,100)) 211 | feat_cons = torch.rand((5,640,100)) 212 | print(Distance_Correlation(feat,feat_cons)) 213 | # print(uniformity_loss(feat,feat_cons)) -------------------------------------------------------------------------------- /methods/template.py: -------------------------------------------------------------------------------- 1 | import math 2 | from sqlite3 import paramstyle 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | import numpy as np 7 | import torch.nn.functional as F 8 | from abc import abstractmethod 9 | from .bdc_module import * 10 | 11 | 12 | class BaselineTrain(nn.Module): 13 | def __init__(self, params, model_func, num_class): 14 | super(BaselineTrain, self).__init__() 15 | self.params = params 16 | self.feature = model_func() 17 | if params.method in ['stl_deepbdc', 'meta_deepbdc']: 18 | reduce_dim = params.reduce_dim 19 | self.feat_dim = int(reduce_dim * (reduce_dim+1) / 2) 20 | self.dcov = BDC(is_vec=True, input_dim=self.feature.feat_dim, dimension_reduction=reduce_dim) 21 | self.dropout = nn.Dropout(params.dropout_rate) 22 | 23 | elif params.method in ['protonet', 'good_embed']: 24 | self.feat_dim = self.feature.feat_dim[0] 25 | self.avgpool = nn.AdaptiveAvgPool2d(1) 26 | 27 | if params.method in ['stl_deepbdc', 'meta_deepbdc', 'protonet', 'good_embed']: 28 | self.classifier = nn.Linear(self.feat_dim, num_class) 29 | self.classifier.bias.data.fill_(0) 30 | 31 | self.num_class = num_class 32 | self.loss_fn = nn.CrossEntropyLoss() 33 | 34 | def feature_forward(self, x): 35 | out = self.feature.forward(x) 36 | if self.params.method in ['stl_deepbdc', 'meta_deepbdc']: 37 | out = self.dcov(out) 38 | out = self.dropout(out) 39 | elif self.params.method in ['protonet', 'good_embed']: 40 | out = self.avgpool(out).view(out.size(0), -1) 41 | return out 42 | 43 | def forward(self, x): 44 | x = Variable(x.cuda()) 45 | out = self.feature_forward(x) 46 | scores = self.classifier.forward(out) 47 | return scores 48 | 49 | def forward_meta_val(self, x): 50 | x = Variable(x.cuda()) 51 | x = x.contiguous().view(self.params.val_n_way * (self.params.n_shot + self.params.n_query), *x.size()[2:]) 52 | 53 | out = self.feature_forward(x) 54 | 55 | z_all = out.view(self.params.val_n_way, self.params.n_shot + self.params.n_query, -1) 56 | z_support = z_all[:, :self.params.n_shot] 57 | 
z_query = z_all[:, self.params.n_shot:] 58 | z_proto = z_support.contiguous().view(self.params.val_n_way, self.params.n_shot, -1).mean(1) 59 | z_query = z_query.contiguous().view(self.params.val_n_way * self.params.n_query, -1) 60 | 61 | if self.params.method in ['meta_deepbdc']: 62 | scores = self.metric(z_query, z_proto) 63 | elif self.params.method in ['protonet']: 64 | scores = self.euclidean_dist(z_query, z_proto) 65 | return scores 66 | 67 | def forward_loss(self, x, y): 68 | scores = self.forward(x) 69 | y = Variable(y.cuda()) 70 | return self.loss_fn(scores, y), scores 71 | 72 | def forward_meta_val_loss(self, x): 73 | y_query = torch.from_numpy(np.repeat(range(self.params.val_n_way), self.params.n_query)) 74 | y_query = Variable(y_query.cuda()) 75 | y_label = np.repeat(range(self.params.val_n_way), self.params.n_query) 76 | scores = self.forward_meta_val(x) 77 | topk_scores, topk_labels = scores.data.topk(1, 1, True, True) 78 | topk_ind = topk_labels.cpu().numpy() 79 | top1_correct = np.sum(topk_ind[:, 0] == y_label) 80 | return float(top1_correct), len(y_label), self.loss_fn(scores, y_query), scores 81 | 82 | def train_loop(self, epoch, train_loader, optimizer): 83 | print_freq = 200 84 | avg_loss = 0 85 | total_correct = 0 86 | 87 | iter_num = len(train_loader) 88 | total = len(train_loader) * self.params.batch_size 89 | 90 | for i, (x, y) in enumerate(train_loader): 91 | y = Variable(y.cuda()) 92 | optimizer.zero_grad() 93 | loss, output = self.forward_loss(x, y) 94 | pred = output.data.max(1)[1] 95 | total_correct += pred.eq(y.data.view_as(pred)).sum() 96 | loss.backward() 97 | optimizer.step() 98 | 99 | avg_loss = avg_loss + loss.item() 100 | 101 | if i % print_freq == 0: 102 | print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f}'.format(epoch, i, len(train_loader), avg_loss / float(i + 1))) 103 | return avg_loss / iter_num, float(total_correct) / total * 100 104 | 105 | def test_loop(self, val_loader): 106 | total_correct = 0 107 | avg_loss = 0.0 108 | total = len(val_loader) * self.params.batch_size 109 | with torch.no_grad(): 110 | for i, (x, y) in enumerate(val_loader): 111 | y = Variable(y.cuda()) 112 | loss, output = self.forward_loss(x, y) 113 | avg_loss = avg_loss + loss.item() 114 | pred = output.data.max(1)[1] 115 | total_correct += pred.eq(y.data.view_as(pred)).sum() 116 | avg_loss /= len(val_loader) 117 | acc = float(total_correct) / total 118 | # print('Test Acc = %4.2f%%, loss is %.2f' % (acc * 100, avg_loss)) 119 | return avg_loss, acc * 100 120 | 121 | def meta_test_loop(self, test_loader): 122 | acc_all = [] 123 | avg_loss = 0 124 | iter_num = len(test_loader) 125 | with torch.no_grad(): 126 | for i, (x, _) in enumerate(test_loader): 127 | correct_this, count_this, loss, _ = self.forward_meta_val_loss(x) 128 | acc_all.append(correct_this / count_this * 100) 129 | avg_loss = avg_loss + loss.item() 130 | acc_all = np.asarray(acc_all) 131 | acc_mean = np.mean(acc_all) 132 | acc_std = np.std(acc_all) 133 | print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num))) 134 | 135 | return avg_loss / iter_num, acc_mean 136 | 137 | def metric(self, x, y): 138 | # x: N x D 139 | # y: M x D 140 | n = x.size(0) 141 | m = y.size(0) 142 | d = x.size(1) 143 | assert d == y.size(1) 144 | 145 | x = x.unsqueeze(1).expand(n, m, d) 146 | y = y.unsqueeze(0).expand(n, m, d) 147 | 148 | if self.params.n_shot > 1: 149 | dist = torch.pow(x - y, 2).sum(2) 150 | score = -dist 151 | else: 152 | score = (x * y).sum(2) 153 | return score 154 | 155 | def 
euclidean_dist(self, x, y): 156 | # x: N x D 157 | # y: M x D 158 | n = x.size(0) 159 | m = y.size(0) 160 | d = x.size(1) 161 | assert d == y.size(1) 162 | 163 | x = x.unsqueeze(1).expand(n, m, d) 164 | y = y.unsqueeze(0).expand(n, m, d) 165 | 166 | score = -torch.pow(x - y, 2).sum(2) 167 | return score 168 | 169 | 170 | class MetaTemplate(nn.Module): 171 | def __init__(self, params, model_func, n_way, n_support, change_way=True): 172 | super(MetaTemplate, self).__init__() 173 | self.n_way = n_way 174 | self.n_support = n_support 175 | self.n_query = params.n_query # (change depends on input) 176 | self.feature = model_func() 177 | self.change_way = change_way # some methods allow different_way classification during training and test 178 | self.params = params 179 | 180 | @abstractmethod 181 | def set_forward(self, x, is_feature): 182 | pass 183 | 184 | @abstractmethod 185 | def set_forward_loss(self, x): 186 | pass 187 | 188 | @abstractmethod 189 | def feature_forward(self, x): 190 | pass 191 | 192 | def forward(self, x): 193 | out = self.feature.forward(x) 194 | return out 195 | 196 | def parse_feature(self, x, is_feature): 197 | x = Variable(x.cuda()) 198 | if is_feature: 199 | z_all = x 200 | else: 201 | x = x.contiguous().view(self.n_way * (self.n_support + self.n_query), *x.size()[2:]) 202 | x = self.feature.forward(x) 203 | z_all = self.feature_forward(x) 204 | z_all = z_all.view(self.n_way, self.n_support + self.n_query, -1) 205 | z_support = z_all[:, :self.n_support] 206 | 207 | z_query = z_all[:, self.n_support:] 208 | # print(z_query.shape) 209 | 210 | return z_support, z_query 211 | 212 | def correct(self, x): 213 | scores = self.set_forward(x) 214 | y_query = np.repeat(range(self.n_way), self.n_query) 215 | 216 | topk_scores, topk_labels = scores.data.topk(1, 1, True, True) 217 | topk_ind = topk_labels.cpu().numpy() 218 | top1_correct = np.sum(topk_ind[:, 0] == y_query) 219 | return float(top1_correct), len(y_query) 220 | 221 | def train_loop(self, epoch, train_loader, optimizer): 222 | print_freq = 200 223 | avg_loss = 0 224 | acc_all = [] 225 | iter_num = len(train_loader) 226 | for i, (x, _) in enumerate(train_loader): 227 | self.n_query = x.size(1) - self.n_support 228 | if self.change_way: 229 | self.n_way = x.size(0) 230 | optimizer.zero_grad() 231 | correct_this, count_this, loss, _ = self.set_forward_loss(x) 232 | acc_all.append(correct_this / count_this * 100) 233 | loss.backward() 234 | optimizer.step() 235 | avg_loss = avg_loss + loss.item() 236 | 237 | if i % print_freq == 0: 238 | print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f}'.format(epoch, i, len(train_loader), 239 | avg_loss / float(i + 1))) 240 | acc_all = np.asarray(acc_all) 241 | acc_mean = np.mean(acc_all) 242 | return avg_loss / iter_num, acc_mean 243 | 244 | def test_loop(self, test_loader, record=None): 245 | acc_all = [] 246 | avg_loss = 0 247 | iter_num = len(test_loader) 248 | with torch.no_grad(): 249 | for i, (x, _) in enumerate(test_loader): 250 | self.n_query = x.size(1) - self.n_support 251 | if self.change_way: 252 | self.n_way = x.size(0) 253 | correct_this, count_this, loss, _ = self.set_forward_loss(x) 254 | acc_all.append(correct_this / count_this * 100) 255 | avg_loss = avg_loss + loss.item() 256 | acc_all = np.asarray(acc_all) 257 | acc_mean = np.mean(acc_all) 258 | acc_std = np.std(acc_all) 259 | print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num))) 260 | 261 | return avg_loss / iter_num, acc_mean 
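The scoring used by `BaselineTrain.forward_meta_val` above boils down to three steps: split the episode tensor into support and query features, average the supports of each way into a prototype, and rank every query by negative squared Euclidean distance to the prototypes (`euclidean_dist`). A minimal standalone sketch with random features, assuming a 5-way 1-shot episode with 15 queries per way and 640-dimensional embeddings (shapes chosen only for illustration; this snippet is not part of the repository):

```python
import torch

# Hypothetical episode: 5 ways, 1 support + 15 queries per way, 640-d features
n_way, n_shot, n_query, dim = 5, 1, 15, 640
z_all = torch.randn(n_way, n_shot + n_query, dim)

z_support = z_all[:, :n_shot]                        # (n_way, n_shot, dim)
z_query = z_all[:, n_shot:].reshape(-1, dim)         # (n_way * n_query, dim)
z_proto = z_support.mean(dim=1)                      # (n_way, dim): one prototype per way

# Negative squared Euclidean distance, as in euclidean_dist(): higher score = closer prototype
diff = z_query.unsqueeze(1) - z_proto.unsqueeze(0)   # (n_way * n_query, n_way, dim)
scores = -diff.pow(2).sum(dim=-1)
pred = scores.argmax(dim=1)                          # predicted way index for each query
```

Note that `BaselineTrain.metric` switches to a dot-product score `(x * y).sum(2)` when `n_shot == 1`, and keeps the negative squared distance otherwise.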
-------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate 2 | 3 | import torch 4 | from PIL import Image 5 | import json 6 | import numpy as np 7 | import torchvision.transforms as transforms 8 | import os 9 | 10 | identity = lambda x: x 11 | 12 | 13 | def get_grid_location(size, ratio, num_grid): 14 | ''' 15 | 16 | :param size: size of the height/width 17 | :param ratio: generate grid size/ even divided grid size 18 | :param num_grid: number of grid 19 | :return: a list containing the coordinate of the grid 20 | ''' 21 | raw_grid_size = int(size / num_grid) 22 | enlarged_grid_size = int(size / num_grid * ratio) 23 | 24 | center_location = raw_grid_size // 2 25 | 26 | location_list = [] 27 | for i in range(num_grid): 28 | location_list.append((max(0, center_location - enlarged_grid_size // 2), 29 | min(size, center_location + enlarged_grid_size // 2))) 30 | center_location = center_location + raw_grid_size 31 | 32 | return location_list 33 | 34 | 35 | class SimpleDataset: 36 | def __init__(self, data_path, data_file_list, transform, target_transform=identity): 37 | label = [] 38 | data = [] 39 | k = 0 40 | data_dir_list = data_file_list.replace(" ","").split(',') 41 | for data_file in data_dir_list: 42 | img_dir = data_path + '/' + data_file 43 | for i in os.listdir(img_dir): 44 | file_dir = os.path.join(img_dir, i) 45 | for j in os.listdir(file_dir): 46 | data.append(file_dir + '/' + j) 47 | label.append(k) 48 | k += 1 49 | self.data = data 50 | self.label = label 51 | self.transform = transform 52 | self.target_transform = target_transform 53 | 54 | def __getitem__(self, i): 55 | image_path = os.path.join(self.data[i]) 56 | img = Image.open(image_path).convert('RGB') 57 | img = self.transform(img) 58 | target = self.target_transform(self.label[i] - min(self.label)) 59 | return img, target 60 | 61 | def __len__(self): 62 | return len(self.label) 63 | 64 | 65 | class SetDataset: 66 | def __init__(self, data_path, data_file_list, batch_size, transform,aug_num=0,args=None): 67 | label = [] 68 | data = [] 69 | k = 0 70 | data_dir_list = data_file_list.replace(" ","").split(',') 71 | for data_file in data_dir_list: 72 | img_dir = data_path + '/' + data_file 73 | for i in os.listdir(img_dir): 74 | file_dir = os.path.join(img_dir, i) 75 | for j in os.listdir(file_dir): 76 | data.append(file_dir + '/' + j) 77 | label.append(k) 78 | k += 1 79 | self.data = data 80 | self.label = label 81 | self.transform = transform 82 | self.cl_list = np.unique(self.label).tolist() 83 | self.args = args 84 | 85 | self.sub_meta = {} 86 | for cl in self.cl_list: 87 | self.sub_meta[cl] = [] 88 | 89 | for x, y in zip(self.data, self.label): 90 | self.sub_meta[y].append(x) 91 | 92 | self.sub_dataloader = [] 93 | sub_data_loader_params = dict(batch_size=batch_size, 94 | shuffle=True, 95 | num_workers=0, # use main thread only or may receive multiple batches 96 | pin_memory=False) 97 | self.cl_num = 0 98 | for cl in self.cl_list: 99 | if len(self.sub_meta[cl])>=25: 100 | self.cl_num += 1 101 | sub_dataset = SubDataset(self.sub_meta[cl], cl, transform=transform,aug_num=aug_num,args=self.args) 102 | self.sub_dataloader.append(torch.utils.data.DataLoader(sub_dataset, **sub_data_loader_params)) 103 | 104 | def __getitem__(self, i): 105 | return next(iter(self.sub_dataloader[i])) 106 | 107 | def __len__(self): 
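# length is self.cl_num: only classes with at least 25 images get a sub-dataloader in __init__, so this can be smaller than len(self.cl_list)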
108 | return self.cl_num 109 | 110 | 111 | class SubDataset: 112 | def __init__(self, sub_meta, cl, transform=transforms.ToTensor(), target_transform=identity,aug_num=0,args=None): 113 | self.sub_meta = sub_meta 114 | self.cl = cl 115 | self.transform = transform 116 | self.target_transform = target_transform 117 | self.aug_num = aug_num 118 | self.grid = args.grid 119 | self.transform_grid = transforms.Compose([ 120 | transforms.Resize([args.img_size,args.img_size]), 121 | transforms.ToTensor(), 122 | transforms.Normalize(mean=[0.472, 0.453, 0.410], std=[0.277, 0.268, 0.285]) 123 | ]) 124 | 125 | self.transform_s = transforms.Compose([ 126 | transforms.RandomResizedCrop(args.img_size, scale=(args.local_scale, args.local_scale)), 127 | transforms.RandomHorizontalFlip(), 128 | transforms.ToTensor(), 129 | transforms.Normalize(mean=[0.472, 0.453, 0.410], std=[0.277, 0.268, 0.285]) 130 | ]) 131 | 132 | def __getitem__(self, i): 133 | image_path = os.path.join(self.sub_meta[i]) 134 | img = Image.open(image_path).convert('RGB') 135 | img_set = [] 136 | img_w = self.transform(img) 137 | img_set.append(img_w.unsqueeze(0)) 138 | if self.grid: 139 | for num_patch in self.grid: 140 | patches = self.get_pyramid(img, num_patch) 141 | # print(patches.shape) 142 | img_set.append(patches) 143 | else: 144 | for _ in range(self.aug_num - 1): 145 | img_s = self.transform_s(img) 146 | img_set.append(img_s.unsqueeze(0)) 147 | # for item in img_set: 148 | # print(item.shape) 149 | img = torch.cat(img_set, dim=0) 150 | target = self.target_transform(self.cl) 151 | return img, target 152 | 153 | def get_pyramid(self, img, num_patch): 154 | num_grid = num_patch 155 | grid_ratio = 1 156 | w, h = img.size 157 | grid_locations_w = get_grid_location(w, grid_ratio, num_grid) 158 | grid_locations_h = get_grid_location(h, grid_ratio, num_grid) 159 | 160 | patches_list = [] 161 | for i in range(num_grid): 162 | for j in range(num_grid): 163 | patch_location_w = grid_locations_w[j] 164 | patch_location_h = grid_locations_h[i] 165 | left_up_corner_w = patch_location_w[0] 166 | left_up_corner_h = patch_location_h[0] 167 | right_down_cornet_w = patch_location_w[1] 168 | right_down_cornet_h = patch_location_h[1] 169 | patch = img.crop((left_up_corner_w, left_up_corner_h, right_down_cornet_w, right_down_cornet_h)) 170 | patch = self.transform_grid(patch) 171 | patches_list.append(patch.unsqueeze(0)) 172 | return torch.cat(patches_list,dim=0) 173 | 174 | def __len__(self): 175 | return len(self.sub_meta) 176 | 177 | 178 | class SimpleDataset_JSON: 179 | def __init__(self, data_path, data_file, transform, target_transform=identity): 180 | data = data_path + '/' + data_file 181 | with open(data, 'r') as f: 182 | self.meta = json.load(f) 183 | self.transform = transform 184 | self.target_transform = target_transform 185 | 186 | def __getitem__(self, i): 187 | image_path = os.path.join(self.meta['image_names'][i]) 188 | img = Image.open(image_path).convert('RGB') 189 | img = self.transform(img) 190 | target = self.target_transform(self.meta['image_labels'][i]) 191 | return img, target 192 | 193 | def __len__(self): 194 | return len(self.meta['image_names']) 195 | 196 | 197 | class SetDataset_JSON: 198 | def __init__(self, data_path, data_file, batch_size, transform,aug_num=0,args=None): 199 | data = data_path + '/' + data_file 200 | 201 | print(transform.__dict__) 202 | with open(data, 'r') as f: 203 | self.meta = json.load(f) 204 | 205 | self.cl_list = np.unique(self.meta['image_labels']).tolist() 206 | self.args = args 207 | 208 
| self.sub_meta = {} 209 | for cl in self.cl_list: 210 | self.sub_meta[cl] = [] 211 | 212 | for x, y in zip(self.meta['image_names'], self.meta['image_labels']): 213 | self.sub_meta[y].append(x) 214 | 215 | self.sub_dataloader = [] 216 | # print(len(self.cl_list)) 217 | sub_data_loader_params = dict(batch_size=batch_size, 218 | shuffle=True, 219 | num_workers=0, # use main thread only or may receive multiple batches 220 | pin_memory=False) 221 | for cl in self.cl_list: 222 | sub_dataset = SubDataset_JSON(self.sub_meta[cl], cl, transform=transform,aug_num=aug_num,args=self.args) 223 | self.sub_dataloader.append(torch.utils.data.DataLoader(sub_dataset, **sub_data_loader_params)) 224 | 225 | def __getitem__(self, i): 226 | return next(iter(self.sub_dataloader[i])) 227 | 228 | def __len__(self): 229 | return len(self.cl_list) 230 | 231 | 232 | class SubDataset_JSON: 233 | def __init__(self, sub_meta, cl, transform=transforms.ToTensor(), target_transform=identity,aug_num=0,args=None): 234 | self.sub_meta = sub_meta 235 | self.cl = cl 236 | self.transform = transform 237 | self.target_transform = target_transform 238 | self.grid = args.grid 239 | self.transform_grid = transforms.Compose([ 240 | transforms.Resize([args.img_size, args.img_size]), 241 | transforms.ToTensor(), 242 | transforms.Normalize(mean=[0.472, 0.453, 0.410], std=[0.277, 0.268, 0.285]) 243 | ]) 244 | 245 | self.transform_s = transforms.Compose([ 246 | # transforms.RandomResizedCrop(224, scale=(0.3, 0.7)), 247 | transforms.RandomResizedCrop(args.img_size, scale=(args.local_scale, args.local_scale)), 248 | # transforms.RandomResizedCrop(args.img_size), 249 | transforms.RandomHorizontalFlip(), 250 | # transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 251 | transforms.ToTensor(), 252 | transforms.Normalize(mean=[0.472, 0.453, 0.410], std=[0.277, 0.268, 0.285]) 253 | ]) 254 | # print(aug_num) 255 | self.aug_num =aug_num 256 | 257 | def __getitem__(self, i): 258 | # print( '%d -%d' %(self.cl,i)) 259 | image_path = os.path.join(self.sub_meta[i]) 260 | img = Image.open(image_path).convert('RGB') 261 | img_set = [] 262 | img_w = self.transform(img) 263 | img_set.append(img_w.unsqueeze(0)) 264 | if self.grid: 265 | for num_patch in self.grid: 266 | patches = self.get_pyramid(img, num_patch) 267 | img_set.append(patches) 268 | else: 269 | for _ in range(self.aug_num - 1): 270 | img_s = self.transform_s(img) 271 | img_set.append(img_s.unsqueeze(0)) 272 | img = torch.cat(img_set,dim=0) 273 | target = self.target_transform(self.cl) 274 | return img, target 275 | 276 | def get_pyramid(self, img, num_patch): 277 | num_grid = num_patch 278 | grid_ratio = 1 279 | w, h = img.size 280 | grid_locations_w = get_grid_location(w, grid_ratio, num_grid) 281 | grid_locations_h = get_grid_location(h, grid_ratio, num_grid) 282 | 283 | patches_list = [] 284 | for i in range(num_grid): 285 | for j in range(num_grid): 286 | patch_location_w = grid_locations_w[j] 287 | patch_location_h = grid_locations_h[i] 288 | left_up_corner_w = patch_location_w[0] 289 | left_up_corner_h = patch_location_h[0] 290 | right_down_cornet_w = patch_location_w[1] 291 | right_down_cornet_h = patch_location_h[1] 292 | patch = img.crop((left_up_corner_w, left_up_corner_h, right_down_cornet_w, right_down_cornet_h)) 293 | patch = self.transform_grid(patch) 294 | patches_list.append(patch.unsqueeze(0)) 295 | return torch.cat(patches_list, dim=0) 296 | 297 | 298 | def __len__(self): 299 | return len(self.sub_meta) 300 | 301 | 302 | class EpisodicBatchSampler(object): 
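# Each iteration yields `n_way` class indices drawn without replacement (torch.randperm(n_classes)[:n_way]); SetDataset.__getitem__ then loads a batch of n_support + n_query images for each sampled class.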
303 | def __init__(self, n_classes, n_way, n_episodes): 304 | self.n_classes = n_classes 305 | self.n_way = n_way 306 | self.n_episodes = n_episodes 307 | 308 | def __len__(self): 309 | return self.n_episodes 310 | 311 | def __iter__(self): 312 | for i in range(self.n_episodes): 313 | yield torch.randperm(self.n_classes)[:self.n_way] 314 | 315 | 316 | -------------------------------------------------------------------------------- /network/resnet.py: -------------------------------------------------------------------------------- 1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate 2 | 3 | import torch 4 | from torch.autograd import Variable 5 | import torch.nn as nn 6 | import math 7 | import numpy as np 8 | import torch.nn.functional as F 9 | from torch.nn.utils.weight_norm import WeightNorm 10 | from torch.distributions import Bernoulli 11 | 12 | ############################################## 13 | # Basic ResNet model # 14 | ############################################## 15 | 16 | def init_layer(L): 17 | # Initialization using fan-in 18 | if isinstance(L, nn.Conv2d): 19 | n = L.kernel_size[0] * L.kernel_size[1] * L.out_channels 20 | L.weight.data.normal_(0, math.sqrt(2.0 / float(n))) 21 | elif isinstance(L, nn.BatchNorm2d): 22 | L.weight.data.fill_(1) 23 | L.bias.data.fill_(0) 24 | 25 | class Flatten(nn.Module): 26 | def __init__(self): 27 | super(Flatten, self).__init__() 28 | 29 | def forward(self, x): 30 | return x.view(x.size(0), -1) 31 | 32 | # Simple ResNet Block 33 | class SimpleBlock(nn.Module): 34 | maml = False # Default 35 | 36 | def __init__(self, indim, outdim, half_res): 37 | super(SimpleBlock, self).__init__() 38 | self.indim = indim 39 | self.outdim = outdim 40 | self.C1 = nn.Conv2d(indim, outdim, kernel_size=3, stride=2 if half_res else 1, padding=1, bias=False) 41 | self.BN1 = nn.BatchNorm2d(outdim) 42 | self.C2 = nn.Conv2d(outdim, outdim, kernel_size=3, padding=1, bias=False) 43 | self.BN2 = nn.BatchNorm2d(outdim) 44 | 45 | self.relu1 = nn.ReLU(inplace=True) 46 | self.relu2 = nn.ReLU(inplace=True) 47 | 48 | self.parametrized_layers = [self.C1, self.C2, self.BN1, self.BN2] 49 | 50 | self.half_res = half_res 51 | 52 | # if the input number of channels is not equal to the output, then need a 1x1 convolution 53 | if indim != outdim: 54 | 55 | self.shortcut = nn.Conv2d(indim, outdim, 1, 2 if half_res else 1, bias=False) 56 | self.BNshortcut = nn.BatchNorm2d(outdim) 57 | 58 | self.parametrized_layers.append(self.shortcut) 59 | self.parametrized_layers.append(self.BNshortcut) 60 | self.shortcut_type = '1x1' 61 | else: 62 | self.shortcut_type = 'identity' 63 | 64 | for layer in self.parametrized_layers: 65 | init_layer(layer) 66 | 67 | def forward(self, x): 68 | out = self.C1(x) 69 | out = self.BN1(out) 70 | out = self.relu1(out) 71 | out = self.C2(out) 72 | out = self.BN2(out) 73 | short_out = x if self.shortcut_type == 'identity' else self.BNshortcut(self.shortcut(x)) 74 | out = out + short_out 75 | out = self.relu2(out) 76 | return out 77 | 78 | 79 | # Bottleneck block 80 | class BottleneckBlock(nn.Module): 81 | maml = False # Default 82 | 83 | def __init__(self, indim, outdim, half_res): 84 | super(BottleneckBlock, self).__init__() 85 | bottleneckdim = int(outdim / 4) 86 | self.indim = indim 87 | self.outdim = outdim 88 | self.C1 = nn.Conv2d(indim, bottleneckdim, kernel_size=1, bias=False) 89 | self.BN1 = nn.BatchNorm2d(bottleneckdim) 90 | self.C2 = nn.Conv2d(bottleneckdim, bottleneckdim, kernel_size=3, stride=2 if half_res else 
1, padding=1) 91 | self.BN2 = nn.BatchNorm2d(bottleneckdim) 92 | self.C3 = nn.Conv2d(bottleneckdim, outdim, kernel_size=1, bias=False) 93 | self.BN3 = nn.BatchNorm2d(outdim) 94 | 95 | self.relu = nn.ReLU() 96 | self.parametrized_layers = [self.C1, self.BN1, self.C2, self.BN2, self.C3, self.BN3] 97 | self.half_res = half_res 98 | 99 | # if the input number of channels is not equal to the output, then need a 1x1 convolution 100 | if indim != outdim: 101 | self.shortcut = nn.Conv2d(indim, outdim, 1, stride=2 if half_res else 1, bias=False) 102 | 103 | self.parametrized_layers.append(self.shortcut) 104 | self.shortcut_type = '1x1' 105 | else: 106 | self.shortcut_type = 'identity' 107 | 108 | for layer in self.parametrized_layers: 109 | init_layer(layer) 110 | 111 | def forward(self, x): 112 | 113 | short_out = x if self.shortcut_type == 'identity' else self.shortcut(x) 114 | out = self.C1(x) 115 | out = self.BN1(out) 116 | out = self.relu(out) 117 | out = self.C2(out) 118 | out = self.BN2(out) 119 | out = self.relu(out) 120 | out = self.C3(out) 121 | out = self.BN3(out) 122 | out = out + short_out 123 | 124 | out = self.relu(out) 125 | return out 126 | 127 | 128 | 129 | class ResNet(nn.Module): 130 | maml = False # Default 131 | 132 | def __init__(self, block, list_of_num_layers, list_of_out_dims, flatten=False): 133 | # list_of_num_layers specifies number of layers in each stage 134 | # list_of_out_dims specifies number of output channel for each stage 135 | super(ResNet, self).__init__() 136 | assert len(list_of_num_layers) == 4, 'Can have only four stages' 137 | 138 | conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 139 | bias=False) 140 | bn1 = nn.BatchNorm2d(64) 141 | 142 | relu = nn.ReLU() 143 | pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 144 | 145 | init_layer(conv1) 146 | init_layer(bn1) 147 | trunk = [conv1, bn1, relu, pool1] 148 | 149 | indim = 64 150 | for i in range(4): 151 | for j in range(list_of_num_layers[i]): 152 | half_res = (i >= 1) and (j == 0) and i != 3 153 | B = block(indim, list_of_out_dims[i], half_res) 154 | trunk.append(B) 155 | indim = list_of_out_dims[i] 156 | 157 | if flatten: 158 | avgpool = nn.AvgPool2d(7) 159 | trunk.append(avgpool) 160 | trunk.append(Flatten()) 161 | # self.final_feat_dim = indim 162 | 163 | self.feat_dim = [512, 14, 14] 164 | self.trunk = nn.Sequential(*trunk) 165 | 166 | def forward(self, x): 167 | out = self.trunk(x) 168 | # out = out.view(out.size(0), -1) 169 | return out 170 | 171 | 172 | def ResNet10(flatten=True): 173 | return ResNet(SimpleBlock, [1, 1, 1, 1], [64, 128, 256, 512], flatten) 174 | 175 | 176 | def ResNet18(flatten=False): 177 | return ResNet(SimpleBlock, [2, 2, 2, 2], [64, 128, 256, 512], flatten) 178 | 179 | 180 | def ResNet34(flatten=True): 181 | return ResNet(SimpleBlock, [3, 4, 6, 3], [64, 128, 256, 512], flatten) 182 | 183 | 184 | def ResNet50(flatten=True): 185 | return ResNet(BottleneckBlock, [3, 4, 6, 3], [256, 512, 1024, 2048], flatten) 186 | 187 | 188 | def ResNet101(flatten=True): 189 | return ResNet(BottleneckBlock, [3, 4, 23, 3], [256, 512, 1024, 2048], flatten) 190 | 191 | 192 | ############################################## 193 | # a variant of ResNet model # 194 | ############################################## 195 | 196 | def conv3x3(in_planes, out_planes, stride=1): 197 | """3x3 convolution with padding""" 198 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 199 | padding=1, bias=False) 200 | 201 | 202 | class SELayer(nn.Module): 203 | def __init__(self, 
channel, reduction=16): 204 | super(SELayer, self).__init__() 205 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 206 | self.fc = nn.Sequential( 207 | nn.Linear(channel, channel // reduction), 208 | nn.ReLU(inplace=True), 209 | nn.Linear(channel // reduction, channel), 210 | nn.Sigmoid() 211 | ) 212 | 213 | def forward(self, x): 214 | b, c, _, _ = x.size() 215 | y = self.avg_pool(x).view(b, c) 216 | y = self.fc(y).view(b, c, 1, 1) 217 | return x * y 218 | 219 | 220 | class DropBlock(nn.Module): 221 | def __init__(self, block_size): 222 | super(DropBlock, self).__init__() 223 | 224 | self.block_size = block_size 225 | #self.gamma = gamma 226 | #self.bernouli = Bernoulli(gamma) 227 | 228 | def forward(self, x, gamma): 229 | # shape: (bsize, channels, height, width) 230 | 231 | if self.training: 232 | batch_size, channels, height, width = x.shape 233 | 234 | bernoulli = Bernoulli(gamma) 235 | mask = bernoulli.sample((batch_size, channels, height - (self.block_size - 1), width - (self.block_size - 1))).cuda() 236 | block_mask = self._compute_block_mask(mask) 237 | countM = block_mask.size()[0] * block_mask.size()[1] * block_mask.size()[2] * block_mask.size()[3] 238 | count_ones = block_mask.sum() 239 | 240 | return block_mask * x * (countM / count_ones) 241 | else: 242 | return x 243 | 244 | def _compute_block_mask(self, mask): 245 | left_padding = int((self.block_size-1) / 2) 246 | right_padding = int(self.block_size / 2) 247 | 248 | batch_size, channels, height, width = mask.shape 249 | #print ("mask", mask[0][0]) 250 | non_zero_idxs = mask.nonzero() 251 | nr_blocks = non_zero_idxs.shape[0] 252 | 253 | offsets = torch.stack( 254 | [ 255 | torch.arange(self.block_size).view(-1, 1).expand(self.block_size, self.block_size).reshape(-1), # - left_padding, 256 | torch.arange(self.block_size).repeat(self.block_size), #- left_padding 257 | ] 258 | ).t().cuda() 259 | offsets = torch.cat((torch.zeros(self.block_size**2, 2).cuda().long(), offsets.long()), 1) 260 | 261 | if nr_blocks > 0: 262 | non_zero_idxs = non_zero_idxs.repeat(self.block_size ** 2, 1) 263 | offsets = offsets.repeat(nr_blocks, 1).view(-1, 4) 264 | offsets = offsets.long() 265 | 266 | block_idxs = non_zero_idxs + offsets 267 | #block_idxs += left_padding 268 | padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding)) 269 | padded_mask[block_idxs[:, 0], block_idxs[:, 1], block_idxs[:, 2], block_idxs[:, 3]] = 1. 
270 | else: 271 | padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding)) 272 | 273 | block_mask = 1 - padded_mask#[:height, :width] 274 | return block_mask 275 | 276 | 277 | class BasicBlockVariant(nn.Module): 278 | expansion = 1 279 | 280 | def __init__(self, inplanes, planes, stride=1, downsample=None, drop_rate=0.0, drop_block=False, 281 | block_size=1, use_se=False): 282 | super(BasicBlockVariant, self).__init__() 283 | self.conv1 = conv3x3(inplanes, planes) 284 | self.bn1 = nn.BatchNorm2d(planes) 285 | self.relu = nn.LeakyReLU(0.1) 286 | self.conv2 = conv3x3(planes, planes) 287 | self.bn2 = nn.BatchNorm2d(planes) 288 | self.conv3 = conv3x3(planes, planes) 289 | self.bn3 = nn.BatchNorm2d(planes) 290 | self.maxpool = nn.MaxPool2d(stride) 291 | self.downsample = downsample 292 | self.stride = stride 293 | self.drop_rate = drop_rate 294 | self.num_batches_tracked = 0 295 | self.drop_block = drop_block 296 | self.block_size = block_size 297 | self.DropBlock = DropBlock(block_size=self.block_size) 298 | self.use_se = use_se 299 | if self.use_se: 300 | self.se = SELayer(planes, 4) 301 | 302 | def forward(self, x): 303 | self.num_batches_tracked += 1 304 | 305 | residual = x 306 | 307 | out = self.conv1(x) 308 | out = self.bn1(out) 309 | out = self.relu(out) 310 | 311 | out = self.conv2(out) 312 | out = self.bn2(out) 313 | out = self.relu(out) 314 | 315 | out = self.conv3(out) 316 | out = self.bn3(out) 317 | if self.use_se: 318 | out = self.se(out) 319 | 320 | if self.downsample is not None: 321 | residual = self.downsample(x) 322 | out += residual 323 | out = self.relu(out) 324 | out = self.maxpool(out) 325 | 326 | if self.drop_rate > 0: 327 | if self.drop_block == True: 328 | feat_size = out.size()[2] 329 | keep_rate = max(1.0 - self.drop_rate / (20*2000) * (self.num_batches_tracked), 1.0 - self.drop_rate) 330 | gamma = (1 - keep_rate) / self.block_size**2 * feat_size**2 / (feat_size - self.block_size + 1)**2 331 | out = self.DropBlock(out, gamma=gamma) 332 | else: 333 | out = F.dropout(out, p=self.drop_rate, training=self.training, inplace=True) 334 | 335 | return out 336 | 337 | 338 | class resnet(nn.Module): 339 | 340 | def __init__(self, block, n_blocks, keep_prob=1.0, avg_pool=False, drop_rate=0.0, 341 | dropblock_size=5, num_classes=-1, use_se=False): 342 | super(resnet, self).__init__() 343 | 344 | self.inplanes = 3 345 | self.use_se = use_se 346 | self.layer1 = self._make_layer(block, n_blocks[0], 64, 347 | stride=2, drop_rate=drop_rate) 348 | self.layer2 = self._make_layer(block, n_blocks[1], 160, 349 | stride=2, drop_rate=drop_rate) 350 | self.layer3 = self._make_layer(block, n_blocks[2], 320, 351 | stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size) 352 | self.layer4 = self._make_layer(block, n_blocks[3], 640, 353 | stride=1, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size) 354 | self.keep_prob = keep_prob 355 | self.keep_avg_pool = avg_pool 356 | self.dropout = nn.Dropout(p=1 - self.keep_prob, inplace=False) 357 | self.drop_rate = drop_rate 358 | self.feat_dim = [640, 10, 10] 359 | 360 | for m in self.modules(): 361 | if isinstance(m, nn.Conv2d): 362 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') 363 | elif isinstance(m, nn.BatchNorm2d): 364 | nn.init.constant_(m.weight, 1) 365 | nn.init.constant_(m.bias, 0) 366 | 367 | self.num_classes = num_classes 368 | if self.num_classes > 0: 369 | self.classifier = nn.Linear(640, self.num_classes) 370 | 371 | def _make_layer(self, block, 
n_block, planes, stride=1, drop_rate=0.0, drop_block=False, block_size=1): 372 | downsample = None 373 | if stride != 1 or self.inplanes != planes * block.expansion: 374 | downsample = nn.Sequential( 375 | nn.Conv2d(self.inplanes, planes * block.expansion, 376 | kernel_size=1, stride=1, bias=False), 377 | nn.BatchNorm2d(planes * block.expansion), 378 | ) 379 | 380 | layers = [] 381 | if n_block == 1: 382 | layer = block(self.inplanes, planes, stride, downsample, drop_rate, drop_block, block_size, self.use_se) 383 | else: 384 | layer = block(self.inplanes, planes, stride, downsample, drop_rate, use_se=self.use_se) # keyword so use_se is not misread as drop_block 385 | layers.append(layer) 386 | self.inplanes = planes * block.expansion 387 | 388 | for i in range(1, n_block): 389 | if i == n_block - 1: 390 | layer = block(self.inplanes, planes, drop_rate=drop_rate, drop_block=drop_block, 391 | block_size=block_size, use_se=self.use_se) 392 | else: 393 | layer = block(self.inplanes, planes, drop_rate=drop_rate, use_se=self.use_se) 394 | layers.append(layer) 395 | 396 | return nn.Sequential(*layers) 397 | 398 | def forward(self, x): 399 | x = self.layer1(x) 400 | x = self.layer2(x) 401 | x = self.layer3(x) 402 | x = self.layer4(x) 403 | return x 404 | 405 | def ResNet12(keep_prob=1.0, avg_pool=True, **kwargs): 406 | """Constructs a ResNet-12 model. 407 | """ 408 | model = resnet(BasicBlockVariant, [1, 1, 1, 1], keep_prob=keep_prob, avg_pool=avg_pool, **kwargs) 409 | return model 410 | 411 | def ResNet34s(keep_prob=1.0, avg_pool=False, **kwargs): 412 | """Constructs a ResNet-34s model. 413 | """ 414 | model = resnet(BasicBlockVariant, [2, 3, 4, 2], keep_prob=keep_prob, avg_pool=avg_pool, **kwargs) 415 | return model 416 | 417 | 418 | if __name__ == '__main__': 419 | import argparse 420 | 421 | parser = argparse.ArgumentParser('argument for training') 422 | parser.add_argument('--model', type=str, default='resnet12',choices=['resnet12', 'resnet18', 'resnet24', 'resnet50', 'resnet101', 423 | 'seresnet12', 'seresnet18', 'seresnet24', 'seresnet50', 424 | 'seresnet101']) 425 | args = parser.parse_args() 426 | 427 | model_dict = { 428 | 'resnet12': ResNet12, 429 | } 430 | 431 | model = model_dict[args.model](avg_pool=True, drop_rate=0.1, dropblock_size=5, num_classes=64) 432 | data = torch.randn(2, 3, 84, 84) 433 | model = model.cuda() 434 | data = data.cuda() 435 | feat = model(data) # this forward() returns only the final feature map 436 | print(feat.shape) 437 | print(model.classifier(feat.mean(dim=(2, 3))).shape) # logits from the auxiliary classifier head -------------------------------------------------------------------------------- /methods/FeatWalk.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | 5 | import tqdm 6 | from sklearn.pipeline import make_pipeline 7 | from sklearn.preprocessing import StandardScaler 8 | from sklearn.svm import SVC, LinearSVC 9 | from methods.bdc_module import BDC 10 | import torch.nn.functional as F 11 | 12 | sys.path.append("..") 13 | import scipy 14 | from scipy.stats import t 15 | import network.resnet as resnet 16 | from utils.loss import * 17 | from sklearn.linear_model import LogisticRegression as LR 18 | from utils.loss import DistillKL 19 | from utils.utils import * 20 | import math 21 | from torch.nn.utils.weight_norm import WeightNorm 22 | 23 | import warnings 24 | warnings.filterwarnings("ignore") 25 | 26 | 27 | def mean_confidence_interval(data, confidence=0.95,multi = 1): 28 | a = 1.0 * np.array(data) 29 | n = len(a) 30 | m, se = np.mean(a), scipy.stats.sem(a) 31 | h = se * t._ppf((1+confidence)/2., n-1) 32 |
return m * multi, h * multi 33 | 34 | def normalize(x): 35 | norm = x.pow(2).sum(1, keepdim=True).pow(1. / 2) 36 | out = x.div(norm) 37 | return out 38 | 39 | def random_sample(linspace, max_idx, num_sample=5): 40 | sample_idx = np.random.choice(range(linspace), num_sample) 41 | sample_idx += np.sort(random.sample(list(range(0, max_idx, linspace)),num_sample)) 42 | return sample_idx 43 | 44 | def Triuvec(x,no_diag = False): 45 | batchSize, dim, dim = x.shape 46 | r = x.reshape(batchSize, dim * dim) 47 | I = torch.ones(dim, dim).triu() 48 | if no_diag: 49 | I -= torch.eye(dim,dim) 50 | I = I.reshape(dim * dim) 51 | index = I.nonzero(as_tuple = False) 52 | # y = torch.zeros(batchSize, int(dim * (dim + 1) / 2), device=x.device).type(x.dtype) 53 | y = r[:, index].squeeze() 54 | return y 55 | 56 | def Triumap(x,no_diag = False): 57 | 58 | batchSize, dim, dim, h, w = x.shape 59 | r = x.reshape(batchSize, dim * dim, h, w) 60 | I = torch.ones(dim, dim).triu() 61 | if no_diag: 62 | I -= torch.eye(dim,dim) 63 | I = I.reshape(dim * dim) 64 | index = I.nonzero(as_tuple = False) 65 | # y = torch.zeros(batchSize, int(dim * (dim + 1) / 2), device=x.device).type(x.dtype) 66 | y = r[:, index, :, :].squeeze() 67 | return y 68 | 69 | def Diagvec(x): 70 | batchSize, dim, dim = x.shape 71 | r = x.reshape(batchSize, dim * dim) 72 | I = torch.eye(dim, dim).triu().reshape(dim * dim) 73 | index = I.nonzero(as_tuple = False) 74 | y = r[:, index].squeeze() 75 | return y 76 | 77 | class FeatWalk_Net(nn.Module): 78 | def __init__(self,params,num_classes = 5,): 79 | super(FeatWalk_Net, self).__init__() 80 | 81 | self.params = params 82 | 83 | if params.model == 'resnet12': 84 | self.feature = resnet.ResNet12(avg_pool=True,num_classes=64) 85 | resnet_layer_dim = [64, 160, 320, 640] 86 | elif params.model == 'resnet18': 87 | self.feature = resnet.ResNet18() 88 | resnet_layer_dim = [64, 128, 256, 512] 89 | 90 | self.resnet_layer_dim = resnet_layer_dim 91 | self.reduce_dim = params.reduce_dim 92 | self.feat_dim = self.feature.feat_dim 93 | self.dim = int(self.reduce_dim * (self.reduce_dim+1)/2) 94 | if resnet_layer_dim[-1] != self.reduce_dim: 95 | 96 | self.Conv = nn.Sequential( 97 | nn.Conv2d(resnet_layer_dim[-1], self.reduce_dim, kernel_size=1, stride=1, bias=False), 98 | nn.BatchNorm2d(self.reduce_dim), 99 | nn.ReLU(inplace=True) 100 | ) 101 | self._init_weight(self.Conv.modules()) 102 | 103 | drop_rate = params.drop_rate 104 | if self.params.embeding_way in ['BDC']: 105 | self.SFC = nn.Linear(self.dim, num_classes) 106 | self.SFC.bias.data.fill_(0) 107 | elif self.params.embeding_way in ['baseline++']: 108 | self.SFC = nn.Linear(self.reduce_dim, num_classes, bias=False) 109 | WeightNorm.apply(self.SFC, 'weight', dim=0) 110 | else: 111 | self.SFC = nn.Linear(self.reduce_dim, num_classes) 112 | 113 | self.drop = nn.Dropout(drop_rate) 114 | 115 | self.temperature = nn.Parameter(torch.log((1. 
/(2 * self.feat_dim[1] * self.feat_dim[2])* torch.ones(1, 1))), 116 | requires_grad=True) 117 | 118 | self.dcov = BDC(is_vec=True, input_dim=[self.reduce_dim,self.feature.feat_dim[1],self.feature.feat_dim[2]], dimension_reduction=self.reduce_dim) 119 | 120 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 121 | 122 | if resnet_layer_dim[-1] != self.reduce_dim: 123 | self.dcov.conv_dr_block = self.Conv 124 | 125 | self.n_shot = params.n_shot 126 | self.n_way = params.n_way 127 | self.transform_aug = params.n_aug_support_samples 128 | 129 | def _init_weight(self,modules): 130 | for m in modules: 131 | if isinstance(m, nn.Conv2d): 132 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out', nonlinearity='leaky_relu') 133 | elif isinstance(m, nn.BatchNorm2d): 134 | nn.init.constant_(m.weight, 1) 135 | nn.init.constant_(m.bias, 0) 136 | 137 | def normalize(self,x): 138 | x = (x - torch.mean(x, dim=1).unsqueeze(1)) 139 | return x 140 | 141 | def forward_feature(self, x): 142 | feat_map = self.feature(x, ) 143 | if self.resnet_layer_dim[-1] != self.reduce_dim: 144 | feat_map = self.Conv(feat_map) 145 | out = feat_map 146 | return out 147 | 148 | def normalize_feature(self, x): 149 | if self.params.norm == 'center': 150 | x = x - x.mean(2).unsqueeze(2) 151 | return x 152 | else: 153 | return x 154 | 155 | def forward_pretrain(self, x): 156 | x = self.forward_feature(x) 157 | x = self.drop(x) 158 | return self.SFC(x) 159 | 160 | def train_loop(self,epoch,train_loader,optimizer): 161 | print_step = 100 162 | avg_loss = 0 163 | total_correct = 0 164 | iter_num = len(train_loader) 165 | total = 0 166 | loss_ce_fn = nn.CrossEntropyLoss() 167 | for i ,data in enumerate(train_loader): 168 | image , label = data 169 | image = image.cuda() 170 | label = label.cuda() 171 | out = self.forward_pretrain(image) 172 | loss = loss_ce_fn(out, label) 173 | avg_loss = avg_loss + loss.item() 174 | optimizer.zero_grad() 175 | loss.backward() 176 | optimizer.step() 177 | _, pred = torch.max(out, 1) 178 | correct = (pred == label).sum().item() 179 | total_correct += correct 180 | total += label.size(0) 181 | if i % print_step == 0: 182 | print('\rEpoch {:d} | Batch: {:d}/{:d} | Loss: {:.4f} | Acc_train: {:.2f}'.format(epoch, i, len(train_loader), 183 | avg_loss / float(i + 1),correct/label.shape[0]*100), end=' ') 184 | print() 185 | 186 | return avg_loss / iter_num, float(total_correct) / total * 100 187 | 188 | def meta_val_loop(self,val_loader): 189 | acc = [] 190 | for i, data in enumerate(val_loader): 191 | 192 | support_xs, support_ys, query_xs, query_ys = data 193 | support_xs = support_xs.cuda() 194 | query_xs = query_xs.cuda() 195 | split_size = 128 196 | if support_xs.squeeze(0).shape[0] >= split_size: 197 | feat_sup_ = [] 198 | for j in range(math.ceil(support_xs.squeeze(0).shape[0] / split_size)): 199 | fest_sup_item = self.forward_feature( 200 | support_xs.squeeze(0)[j * split_size:min((j + 1) * split_size, support_xs.shape[1]), :, :, :],) 201 | feat_sup_.append(fest_sup_item if len(fest_sup_item.shape) >= 1 else fest_sup_item.unsqueeze(0)) 202 | feat_sup = torch.cat(feat_sup_, dim=0) 203 | else: 204 | feat_sup = self.forward_feature(support_xs.squeeze(0),) 205 | if query_xs.squeeze(0).shape[0] > split_size: 206 | feat_qry_ = [] 207 | for j in range(math.ceil(query_xs.squeeze(0).shape[0] / split_size)): 208 | feat_qry_item = self.forward_feature( 209 | query_xs.squeeze(0)[j * split_size:min((j + 1) * split_size, query_xs.shape[1]), :, :, :], 210 | ) 211 | feat_qry_.append(feat_qry_item if len(feat_qry_item.shape) > 
1 else feat_qry_item.unsqueeze(0)) 212 | 213 | feat_qry = torch.cat(feat_qry_, dim=0) 214 | else: 215 | feat_qry = self.forward_feature(query_xs.squeeze(0),) 216 | if self.params.LR: 217 | pred = self.LR(feat_sup, support_ys, feat_qry, query_ys) 218 | else: 219 | with torch.enable_grad(): 220 | pred = self.softmax(feat_sup, support_ys, feat_qry, ) 221 | _, pred = torch.max(pred, dim=-1) 222 | if self.params.n_symmetry_aug > 1: 223 | query_ys = query_ys.view(-1, self.params.n_symmetry_aug) 224 | query_ys = torch.mode(query_ys, dim=-1)[0] 225 | acc_epo = np.mean(pred.cpu().numpy() == query_ys.numpy()) 226 | acc.append(acc_epo) 227 | return mean_confidence_interval(acc) 228 | 229 | def meta_test_loop(self,test_loader): 230 | acc = [] 231 | for i, (x, _) in enumerate(test_loader): 232 | self.params.n_aug_support_samples = self.transform_aug 233 | tic = time.time() 234 | x = x.contiguous().view(self.n_way, (self.n_shot + self.params.n_queries), *x.size()[2:]) 235 | support_xs = x[:, :self.n_shot].contiguous().view( 236 | self.n_way * self.n_shot * self.params.n_aug_support_samples, *x.size()[3:]).cuda() 237 | query_xs = x[:, self.n_shot:, 0:self.params.n_symmetry_aug].contiguous().view( 238 | self.n_way * self.params.n_queries * self.params.n_symmetry_aug, *x.size()[3:]).cuda() 239 | 240 | support_y = torch.from_numpy(np.repeat(range(self.params.n_way),self.n_shot*self.params.n_aug_support_samples)).unsqueeze(0) 241 | split_size = 128 242 | if support_xs.shape[0] >= split_size: 243 | feat_sup_ = [] 244 | for j in range(math.ceil(support_xs.shape[0]/split_size)): 245 | fest_sup_item =self.forward_feature(support_xs[j*split_size:min((j+1)*split_size,support_xs.shape[0]),],) 246 | feat_sup_.append(fest_sup_item if len(fest_sup_item.shape)>=1 else fest_sup_item.unsqueeze(0)) 247 | feat_sup = torch.cat(feat_sup_,dim=0) 248 | else: 249 | feat_sup = self.forward_feature(support_xs) 250 | if query_xs.shape[0] >= split_size: 251 | feat_qry_ = [] 252 | for j in range(math.ceil(query_xs.shape[0]/split_size)): 253 | feat_qry_item = self.forward_feature( 254 | query_xs[j * split_size:min((j + 1) * split_size, query_xs.shape[0]), ],) 255 | feat_qry_.append(feat_qry_item if len(feat_qry_item.shape) > 1 else feat_qry_item.unsqueeze(0)) 256 | 257 | feat_qry = torch.cat(feat_qry_,dim=0) 258 | else: 259 | feat_qry = self.forward_feature(query_xs,) 260 | 261 | if self.params.LR: 262 | pred = self.predict_wo_fc(feat_sup, support_y, feat_qry,) 263 | 264 | else: 265 | with torch.enable_grad(): 266 | pred = self.softmax(feat_sup, support_y, feat_qry,) 267 | _,pred = torch.max(pred,dim=-1) 268 | 269 | query_ys = np.repeat(range(self.n_way), self.params.n_queries) 270 | pred = pred.view(-1) 271 | acc_epo = np.mean(pred.cpu().numpy() == query_ys) 272 | acc.append(acc_epo) 273 | print("\repisode {} acc: {:.2f} | avg_acc: {:.2f} +- {:.2f}, elapse : {:.2f}".format(i, acc_epo * 100, 274 | *mean_confidence_interval( 275 | acc, multi=100), ( 276 | time.time() - tic) / 60), 277 | end='') 278 | 279 | return mean_confidence_interval(acc) 280 | 281 | def distillation(self,epoch,train_loader,optimizer,model_t): 282 | print_step = 100 283 | avg_loss = 0 284 | total_correct = 0 285 | iter_num = len(train_loader) 286 | total = 0 287 | loss_div_fn = DistillKL(4) 288 | loss_ce_fn = nn.CrossEntropyLoss() 289 | for i, data in enumerate(train_loader): 290 | image, label = data 291 | image = image.cuda() 292 | label = label.cuda() 293 | with torch.no_grad(): 294 | out_t = model_t.forward_pretrain(image) 295 | 296 | out= 
self.forward_pretrain(image) 297 | loss_ce = loss_ce_fn(out, label) 298 | loss_div = loss_div_fn(out, out_t) 299 | 300 | loss = loss_ce * 0.5 + loss_div * 0.5 301 | avg_loss = avg_loss + loss.item() 302 | optimizer.zero_grad() 303 | loss.backward() 304 | optimizer.step() 305 | 306 | _, pred = torch.max(out, 1) 307 | correct = (pred == label).sum().item() 308 | total_correct += correct 309 | total += label.size(0) 310 | if i % print_step == 0: 311 | print('\rEpoch {:d} | Batch: {:d}/{:d} | Loss: {:.4f} | Acc_train: {:.2f}'.format(epoch, i, 312 | len(train_loader), 313 | avg_loss / float( 314 | i + 1), 315 | correct / label.shape[ 316 | 0] * 100), 317 | end=' ') 318 | print() 319 | return avg_loss / iter_num, float(total_correct) / total * 100 320 | 321 | # new selective local fusion : 322 | def softmax(self,support_z,support_ys,query_z,): 323 | loss_ce_fn = nn.CrossEntropyLoss() 324 | batch_size = self.params.sfc_bs 325 | walk_times = 24 326 | alpha = self.params.alpha 327 | tempe = self.params.sim_temperature 328 | support_ys = support_ys.cuda() 329 | 330 | if self.params.embeding_way in ['BDC']: 331 | SFC = nn.Linear(self.dim, self.params.n_way).cuda() 332 | iter_num = 100 333 | optimizer = torch.optim.AdamW([{'params': SFC.parameters()}], lr=0.001, 334 | weight_decay=self.params.wd_test,eps=1e-4) 335 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, iter_num * math.ceil(self.n_way*self.n_shot/batch_size),eta_min=1e-3) 336 | 337 | 338 | else: 339 | tempe =16 340 | 341 | if self.params.embeding_way in ['baseline++']: 342 | SFC = nn.Linear(self.reduce_dim, self.params.n_way, bias=False).cuda() 343 | WeightNorm.apply(SFC, 'weight', dim=0) 344 | else: 345 | SFC = nn.Linear(self.reduce_dim, self.params.n_way).cuda() 346 | 347 | if self.params.optim in ['Adam']: 348 | # lr = 5e-3 349 | optimizer = torch.optim.AdamW([{'params': SFC.parameters()}], lr=0.005, 350 | weight_decay=self.params.wd_test, eps=5e-3) 351 | 352 | iter_num = 100 353 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, iter_num * math.ceil( 354 | self.n_way * self.n_shot / batch_size), eta_min=5e-3) 355 | 356 | 357 | else: 358 | optimizer = torch.optim.SGD([{'params': SFC.parameters()}], 359 | lr=self.params.lr, momentum=0.9, nesterov=True, 360 | weight_decay=self.params.wd_test) 361 | 362 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 150], gamma=0.1) 363 | iter_num = 180 364 | 365 | 366 | SFC.train() 367 | 368 | if self.params.embeding_way in ['BDC']: 369 | support_z = self.dcov(support_z) 370 | query_z = self.dcov(query_z) 371 | 372 | else: 373 | support_z = self.avg_pool(support_z).view(support_z.shape[0], -1) 374 | query_z = self.avg_pool(query_z) 375 | 376 | support_ys = support_ys.view(self.n_way * self.n_shot, self.params.n_aug_support_samples, -1) 377 | global_ys = support_ys[:, 0, :] 378 | 379 | support_z = support_z.reshape(self.n_way,self.n_shot,self.params.n_aug_support_samples,-1) 380 | query_z = query_z.reshape(self.n_way,self.params.n_queries,self.params.n_aug_support_samples,-1) 381 | 382 | 383 | feat_q = query_z[:,:,0] 384 | feat_ql = query_z[:,:,1:] 385 | feat_g = support_z[:,:,0] 386 | feat_sl = support_z[:,:,1:] 387 | # w_local: n * k * n * m 388 | num_sample = self.n_way*self.n_shot 389 | global_ys = global_ys.view(self.n_way,self.n_shot,-1) 390 | 391 | feat_g = feat_g.detach() 392 | feat_sl = feat_sl.detach() 393 | 394 | # feat_sl: n * k * n * dim 395 | I = 
torch.eye(self.n_way,self.n_way,device=feat_g.device).unsqueeze(0).unsqueeze(1) 396 | proto_moving = torch.mean(feat_g, dim=1) 397 | 398 | 399 | 400 | for i in range(iter_num): 401 | weight = compute_weight_local(proto_moving.unsqueeze(1), feat_sl, feat_sl, self.params.measure) 402 | idx_walk = torch.randperm(self.params.n_aug_support_samples-1,)[:walk_times] 403 | w_local = F.softmax(weight[:,:,:,idx_walk] * tempe, dim=-1) 404 | feat_s = torch.sum((feat_sl[:,:,idx_walk,:].unsqueeze(-3)) * (w_local.detach().unsqueeze(-1)), dim=-2) 405 | support_x = alpha * feat_g.unsqueeze(-2) + (1- alpha) * feat_s 406 | proto_update = torch.sum(torch.matmul(torch.mean(support_x,dim=1).transpose(1,2),torch.eye(self.n_way,device=proto_moving.device).unsqueeze(0)),dim=-1) 407 | proto_moving = 0.9 * proto_moving + 0.1 * proto_update 408 | spt_norm = torch.norm(support_x, p=2, dim=-1).unsqueeze(-1).expand_as(support_x) 409 | support_x = support_x.div(spt_norm + 1e-6) 410 | 411 | 412 | SFC.train() 413 | sample_idxs = torch.randperm(num_sample) 414 | for j in range(math.ceil(num_sample/batch_size)): 415 | idxs = sample_idxs[j*batch_size:min((j+1)*batch_size,num_sample)] 416 | x = support_x[idxs//self.n_shot,idxs%self.n_shot] 417 | y = global_ys[idxs//self.n_shot,idxs%self.n_shot] 418 | x = self.drop(x) 419 | # out = torch.sum(SFC(x)*I,dim=-1).view(-1,self.n_way) 420 | out = torch.sum(x.mul(SFC.weight),dim=-1) + SFC.bias 421 | loss_ce = loss_ce_fn(out,y.long().view(-1)) 422 | loss = loss_ce 423 | optimizer.zero_grad() 424 | loss.backward() 425 | optimizer.step() 426 | if lr_scheduler is not None: 427 | lr_scheduler.step() 428 | 429 | SFC.eval() 430 | 431 | w_local = compute_weight_local(proto_moving.unsqueeze(1), feat_ql, feat_sl,self.params.measure) 432 | w_local = F.softmax(w_local * tempe, dim=-1) 433 | 434 | # feat_sl: n * k * n * dim 435 | feat_lq = torch.sum(feat_ql.unsqueeze(-3) * w_local.unsqueeze(-1), dim=-2) 436 | query_x = alpha * feat_q.unsqueeze(-2) + (1- alpha) * feat_lq 437 | 438 | spt_norm = torch.norm(query_x, p=2, dim=-1).unsqueeze(-1).expand_as(query_x) 439 | query_x = query_x.div(spt_norm + 1e-6) 440 | 441 | with torch.no_grad(): 442 | # out = torch.sum(SFC(query_x)*I,dim=-1).view(-1,self.n_way) 443 | out = torch.sum(query_x.mul(SFC.weight), dim=-1) + SFC.bias 444 | 445 | return out 446 | 447 | def LR(self,support_z,support_ys,query_z,query_ys): 448 | 449 | clf = LR(penalty='l2', 450 | random_state=0, 451 | C=self.params.penalty_c, 452 | solver='lbfgs', 453 | max_iter=1000, 454 | multi_class='multinomial') 455 | 456 | spt_norm = torch.norm(support_z, p=2, dim=1).unsqueeze(1).expand_as(support_z) 457 | spt_normalized = support_z.div(spt_norm + 1e-6) 458 | 459 | qry_norm = torch.norm(query_z, p=2, dim=1).unsqueeze(1).expand_as(query_z) 460 | qry_normalized = query_z.div(qry_norm + 1e-6) 461 | 462 | z_support = spt_normalized.detach().cpu().numpy() 463 | z_query = qry_normalized.detach().cpu().numpy() 464 | 465 | y_support = np.repeat(range(self.params.n_way), self.n_shot) 466 | 467 | clf.fit(z_support, y_support) 468 | 469 | return torch.from_numpy(clf.predict(z_query)) 470 | --------------------------------------------------------------------------------
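The `Triuvec` helper in `methods/FeatWalk.py` flattens each `dim x dim` BDC matrix and keeps only its upper-triangular entries (optionally dropping the diagonal), which is why the SFC input dimension is `reduce_dim * (reduce_dim + 1) / 2`. A minimal standalone sketch of the same indexing, assuming only PyTorch and made-up shapes:

```python
import torch

# four symmetric 6x6 matrices, standing in for BDC outputs (made-up data)
x = torch.randn(4, 6, 6)
x = 0.5 * (x + x.transpose(1, 2))

# same indexing as Triuvec: mask the upper triangle, keep those positions
mask = torch.ones(6, 6).triu().reshape(-1)
index = mask.nonzero(as_tuple=False).squeeze()
vec = x.reshape(4, 36)[:, index]

print(vec.shape)  # torch.Size([4, 21]), since 6 * 7 / 2 = 21
```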
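The heart of `FeatWalk_Net.softmax` is the selective local fusion step: each local view is scored against the (moving) class prototypes, the scores are sharpened by `sim_temperature`, and the weighted average of local views is mixed back into the global feature with coefficient `alpha` before L2 normalization. The sketch below reproduces that idea in isolation; it is a simplification under assumptions: `compute_weight_local` lives in `utils/utils.py` and is not shown here, so plain cosine similarity is used as the measure, and `fuse_local_views` plus all shapes are illustrative rather than the repository API.

```python
import torch
import torch.nn.functional as F

def fuse_local_views(feat_global, feat_local, prototypes, alpha=0.5, temperature=32.0):
    """feat_global: (N, d), feat_local: (N, m, d), prototypes: (n_way, d).
    Returns one fused, class-conditioned feature per image: (N, n_way, d)."""
    # similarity of every local view to every prototype -> (N, m, n_way)
    sim = F.cosine_similarity(feat_local.unsqueeze(2),
                              prototypes.unsqueeze(0).unsqueeze(0), dim=-1)
    # per-class softmax weights over the m local views -> (N, n_way, m)
    w = F.softmax(sim.transpose(1, 2) * temperature, dim=-1)
    # weighted average of the local views for each class -> (N, n_way, d)
    feat_local_agg = torch.einsum('ncm,nmd->ncd', w, feat_local)
    # convex combination with the global view, then L2-normalize
    fused = alpha * feat_global.unsqueeze(1) + (1 - alpha) * feat_local_agg
    return F.normalize(fused, p=2, dim=-1)

# toy 5-way 1-shot episode with 16 local views and 128-dim features
n_way, k, m, d = 5, 1, 16, 128
feats_g = torch.randn(n_way * k, d)
feats_l = torch.randn(n_way * k, m, d)
protos = feats_g.view(n_way, k, d).mean(dim=1)
print(fuse_local_views(feats_g, feats_l, protos).shape)  # torch.Size([5, 5, 128])
```

In the repository the fused support features are used to train the per-class `SFC` head, and query features are fused the same way before scoring.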
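`meta_test_loop` reports results as the mean episode accuracy with a 95% confidence half-width, computed by the `mean_confidence_interval` helper defined at the top of `methods/FeatWalk.py` (`multi=100` converts to percentages). A small usage sketch, assuming the repository root is on `PYTHONPATH`:

```python
from methods.FeatWalk import mean_confidence_interval

episode_accs = [0.71, 0.68, 0.74, 0.69, 0.72]  # made-up per-episode accuracies
mean_acc, half_width = mean_confidence_interval(episode_accs, confidence=0.95, multi=100)
print("acc: {:.2f} +- {:.2f}".format(mean_acc, half_width))
```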