├── data
│   ├── __init__.py
│   ├── datamgr.py
│   └── dataset.py
├── network
│   ├── __init__.py
│   └── resnet.py
├── .idea
│   ├── misc.xml
│   ├── vcs.xml
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   ├── FeatWalk.iml
│   └── deployment.xml
├── methods
│   ├── __init__.py
│   ├── stl_deepbdc.py
│   ├── bdc_module.py
│   ├── template.py
│   └── FeatWalk.py
├── run.sh
├── README.md
├── utils
│   ├── utils.py
│   └── loss.py
└── eval.py
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from . import datamgr
2 | from . import dataset
3 |
--------------------------------------------------------------------------------
/network/__init__.py:
--------------------------------------------------------------------------------
1 | from . import resnet
2 | # from . import convnet
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/methods/__init__.py:
--------------------------------------------------------------------------------
1 | from . import template
2 | # from . import protonet
3 | # from . import good_embed
4 | # from . import meta_deepbdc
5 | # from . import stl_deepbdc
6 | from . import bdc_module
7 |
8 |
9 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | gpuid=0
2 |
3 | python eval.py --gpu ${gpuid} --n_episodes 2000 --n_aug_support_samples 17 --n_shot 1 --distill_model mini/ResNet12_stl_deepbdc_distill/last_model.tar --test_times 5 --lr 0.5 --fix_seed --sfc_bs 3 --sim_temperature 32
4 | python eval.py --gpu ${gpuid} --n_episodes 2000 --n_aug_support_samples 17 --n_shot 5 --distill_model mini/ResNet12_stl_deepbdc_distill/last_model.tar --test_times 5 --lr 0.01 --fix_seed --sfc_bs 3 --sim_temperature 32
--------------------------------------------------------------------------------
/methods/stl_deepbdc.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | import torch.nn.functional as F
6 | from .template import MetaTemplate
7 | from sklearn.linear_model import LogisticRegression
8 | from .bdc_module import BDC
9 |
10 |
11 | class STLDeepBDC(MetaTemplate):
12 | def __init__(self, params, model_func, n_way, n_support):
13 | super(STLDeepBDC, self).__init__(params, model_func, n_way, n_support)
14 | self.loss_fn = nn.CrossEntropyLoss()
15 |
16 | reduce_dim = params.reduce_dim
17 | self.feat_dim = int(reduce_dim * (reduce_dim+1) / 2)
18 | self.dcov = BDC(is_vec=True, input_dim=self.feature.feat_dim, dimension_reduction=reduce_dim)
19 |
20 | self.C = params.penalty_C
21 | self.params = params
22 |
23 | def feature_forward(self, x):
24 | out = self.dcov(x)
25 | return out
26 |
27 | def set_forward(self, x, is_feature=True):
28 | # print(x.shape)
29 | with torch.no_grad():
30 | z_support, z_query = self.parse_feature(x, is_feature)
31 | # print(z_support.shape)
32 | z_support = z_support.detach()
33 | z_query = z_query.detach()
34 |
35 | z_support = z_support.contiguous().view(self.n_way * self.n_support, -1)
36 | z_query = z_query.contiguous().view(self.n_way * self.n_query, -1)
37 |
38 | qry_norm = torch.norm(z_query, p=2, dim=1).unsqueeze(1).expand_as(z_query)
39 | spt_norm = torch.norm(z_support, p=2, dim=1).unsqueeze(1).expand_as(z_support)
40 | qry_normalized = z_query.div(qry_norm + 1e-6)
41 | spt_normalized = z_support.div(spt_norm + 1e-6)
42 |
43 | z_query = qry_normalized.detach().cpu().numpy()
44 | z_support = spt_normalized.detach().cpu().numpy()
45 | y_support = np.repeat(range(self.n_way), self.n_support)
46 |
47 | clf = LogisticRegression(penalty='l2',
48 | random_state=0,
49 | C=self.C,
50 | solver='lbfgs',
51 | max_iter=1000,
52 | multi_class='multinomial')
53 | clf.fit(z_support, y_support)
54 | scores = clf.predict(z_query)
55 |
56 | return scores
57 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FeatWalk
2 |
3 | FeatWalk is a method tailored for few-shot learning. It mines local views to mitigate the interference caused by discriminative features in global-view pre-training, and by analyzing the correlation between local views and the different class prototypes it constructs a more comprehensive class-related representation. The method was accepted at AAAI 2024, and this repository serves as the official implementation for reference.
4 |
5 | ## Comparison with Baseline Methods
6 |
7 | The following table compares FeatWalk with the baseline method DeepBDC on MiniImageNet and TieredImageNet under several few-shot learning (FSL) settings. FeatWalk consistently outperforms DeepBDC across all of these settings.
8 |
9 | | Method | Embedding | Mini 5-way 1-shot | Mini 5-way 5-shot | Tiered 5-way 1-shot | Tiered 5-way 5-shot |
10 | |----------|-----------|------------------------|------------------------|--------------------------|--------------------------|
11 | | DeepBDC | BDC | 67.83 ± 0.43 | 85.45 ± 0.29 | 73.82 ± 0.47 | 89.00 ± 0.30 |
12 | | FeatWalk | BDC | 70.21 ± 0.44 | 87.38 ± 0.27 | 75.25 ± 0.48 | 89.92 ± 0.29 |
13 |
14 |
15 | ## Preparation Before Running
16 |
17 | Before starting with FeatWalk, please ensure the following preparations are made:
18 |
19 | 1. Place the pre-trained models in the `checkpoint` directory. The pre-trained models can be obtained through the corresponding baseline methods or accessed from the official [DeepBDC](https://github.com/Fei-Long121/DeepBDC) implementation.
20 | 2. Ensure that datasets (such as [MiniImageNet](https://drive.google.com/file/d/1aBxfcU5cn-htIlqriiOQCOXp_t9TOm9g/view?usp=sharing)) are located in the `filelist` directory.
21 |
22 | #### Dataset Structure:
23 | ```
24 | --FeatWalk
25 | |--filelist
26 | |--miniImageNet
27 | |--train
28 | |--val
29 | |--test
30 | ```
31 | ## Running Commands
32 |
33 | To run FeatWalk, use the following command:
34 |
35 | ```bash
36 | # 5-Way 1-shot/5-shot on MiniImageNet
37 | sh run.sh
38 | ```
39 |
40 | ## Acknowledgments
41 | We are grateful to the open-source implementations of [GoodEmbed](https://github.com/WangYueFt/rfs/) and [DeepBDC](https://github.com/Fei-Long121/DeepBDC). Our code builds on and was informed by these projects, and their contributions have been invaluable to this work.
42 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import shutil
4 | import time
5 | import pprint
6 | import torch
7 | import numpy as np
8 | import os.path as osp
9 | import random
10 | import torch.nn.functional as F
11 |
12 | def set_seed(seed):
13 | if seed == 0:
14 | print(' random seed')
15 | torch.backends.cudnn.benchmark = True
16 | else:
17 | print('manual seed:', seed)
18 | random.seed(seed)
19 | np.random.seed(seed)
20 | torch.manual_seed(seed)
21 | torch.cuda.manual_seed_all(seed)
22 | torch.backends.cudnn.deterministic = True
23 | torch.backends.cudnn.benchmark = False
24 |
25 | def load_model(model, dir):
26 | model_dict = model.state_dict()
27 | file_dict = torch.load(dir)['state']
28 | for k, v in file_dict.items():
29 | if k not in model_dict:
30 | print(k)
31 | file_dict = {k: v for k, v in file_dict.items() if k in model_dict}
32 | model_dict.update(file_dict)
33 | model.load_state_dict(model_dict)
34 | return model
35 |
36 | def compute_weight_local(feat_g,feat_ql,feat_sl,temperature=2.0):
37 | # feat_g : nk * dim
38 | # feat_l : nk * m * dim
39 | [_,k,m,dim] = feat_sl.shape
40 | [n,q,m,dim] = feat_ql.shape
41 |
42 | feat_g_expand = feat_g.unsqueeze(2).expand_as(feat_ql)
43 | sim_gl = torch.cosine_similarity(feat_g_expand,feat_ql,dim=-1)
44 | I_opp_m = (1 - torch.eye(m)).unsqueeze(0).to(sim_gl.device)
45 | sim_gl = -(torch.matmul(sim_gl, I_opp_m).unsqueeze(-2))/(m-1)
46 |
47 |
48 | return sim_gl
49 |
50 | # proto_walk: this second definition overrides the one above when the module is imported
51 | def compute_weight_local(feat_g,feat_ql,feat_sl,measure = "cosine"):
52 | # feat_g : nk * dim
53 | # feat_l : nk * m * dim
54 | [_,k,m,dim] = feat_sl.shape
55 | [n,q,m,dim] = feat_ql.shape
56 | # print(feat_ql.shape)
57 |
58 | feat_g_expand = torch.mean(feat_g,dim=1).unsqueeze(0).unsqueeze(1).unsqueeze(3)
59 | if measure == "cosine":
60 | sim_gl = torch.cosine_similarity(feat_g_expand,feat_ql.unsqueeze(2),dim=-1)
61 | else:
62 | sim_gl = -1 * 0.002 * torch.sum((feat_g_expand - feat_ql.unsqueeze(2)) ** 2, dim=-1)
63 |
64 | I_m = torch.eye(m).unsqueeze(0).unsqueeze(1).to(sim_gl.device)
65 | sim_gl = torch.matmul(sim_gl, I_m)
66 |
67 | return sim_gl
68 |
69 |
70 | if __name__ == '__main__':
71 | feat_g = torch.randn((5,15,64))
72 | # feat_g = torch.ones((5,3,64))
73 | feat_sl = torch.randn((5,3,6,64))
74 | feat_ql = torch.randn((5,15,6,64))
75 | # feat_l = torch.ones((5,3,6,64))
76 | compute_weight_local(feat_g,feat_ql,feat_sl)
77 | # print(compute_weight_local(feat_g,feat_ql,feat_sl)[0,0])
--------------------------------------------------------------------------------
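
The second `compute_weight_local` definition (the `proto_walk` variant) is the one that survives import. A minimal shape check, using purely illustrative tensors laid out like the `__main__` block above, shows what it returns:

```python
import torch
from utils.utils import compute_weight_local

# Illustrative tensors only: 5 classes, 3-shot support, 15 queries, m = 6 local views, dim = 64.
feat_g = torch.randn(5, 15, 64)       # global features; the mean over dim=1 forms one prototype per class
feat_sl = torch.randn(5, 3, 6, 64)    # support local-view features (n, k, m, dim)
feat_ql = torch.randn(5, 15, 6, 64)   # query local-view features (n, q, m, dim)

weights = compute_weight_local(feat_g, feat_ql, feat_sl, measure="cosine")
print(weights.shape)                  # torch.Size([5, 15, 5, 6]): per query, per prototype, per local view
```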
/methods/bdc_module.py:
--------------------------------------------------------------------------------
1 | '''
2 | @file: bdc_module.py
3 | @author: Fei Long
4 | @author: Jiaming Lv
5 | Please cite the paper below if you use the code:
6 |
7 | Jiangtao Xie, Fei Long, Jiaming Lv, Qilong Wang and Peihua Li. Joint Distribution Matters: Deep Brownian Distance Covariance for Few-Shot Classification. IEEE Int. Conf. on Computer Vision and Pattern Recognition (CVPR), 2022.
8 |
9 | Copyright (C) 2022 Fei Long and Jiaming Lv
10 |
11 | All rights reserved.
12 | '''
13 |
14 | import torch
15 | import torch.nn as nn
16 |
17 | class BDC(nn.Module):
18 | def __init__(self, is_vec=True, input_dim=640, dimension_reduction=None, activate='relu'):
19 | super(BDC, self).__init__()
20 | self.is_vec = is_vec
21 | self.dr = dimension_reduction
22 | self.activate = activate
23 | self.input_dim = input_dim[0]
24 | # self.input_dim = input_dim
25 | if self.dr is not None and self.dr != self.input_dim:
26 | if activate == 'relu':
27 | self.act = nn.ReLU(inplace=True)
28 | elif activate == 'leaky_relu':
29 | self.act = nn.LeakyReLU(0.1)
30 | else:
31 | self.act = nn.ReLU(inplace=True)
32 |
33 | self.conv_dr_block = nn.Sequential(
34 | nn.Conv2d(self.input_dim, self.dr, kernel_size=1, stride=1, bias=False),
35 | nn.BatchNorm2d(self.dr),
36 | self.act
37 | )
38 | output_dim = self.dr if self.dr else self.input_dim
39 | if self.is_vec:
40 | self.output_dim = int(output_dim*(output_dim+1)/2)
41 | else:
42 | self.output_dim = int(output_dim*output_dim)
43 |
44 | self.temperature = nn.Parameter(torch.log((1. / (2 * input_dim[1]*input_dim[2])) * torch.ones(1,1)), requires_grad=True)
45 |
46 | self._init_weight()
47 |
48 | def _init_weight(self):
49 | for m in self.modules():
50 | if isinstance(m, nn.Conv2d):
51 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out', nonlinearity='leaky_relu')
52 | elif isinstance(m, nn.BatchNorm2d):
53 | nn.init.constant_(m.weight, 1)
54 | nn.init.constant_(m.bias, 0)
55 |
56 | def forward(self, x):
57 | if self.dr is not None and self.dr != self.input_dim:
58 | x = self.conv_dr_block(x)
59 | x = BDCovpool(x, self.temperature)
60 | if self.is_vec:
61 | x = Triuvec(x)
62 | else:
63 | x = x.reshape(x.shape[0], -1)
64 | return x
65 |
66 | def BDCovpool(x, t):
67 | batchSize, dim, h, w = x.data.shape
68 | M = h * w
69 | x = x.reshape(batchSize, dim, M)
70 |
71 | I = torch.eye(dim, dim, device=x.device).view(1, dim, dim).repeat(batchSize, 1, 1).type(x.dtype)
72 | I_M = torch.ones(batchSize, dim, dim, device=x.device).type(x.dtype)
73 | x_pow2 = x.bmm(x.transpose(1, 2))
74 | dcov = I_M.bmm(x_pow2 * I) + (x_pow2 * I).bmm(I_M) - 2 * x_pow2
75 |
76 | dcov = torch.clamp(dcov, min=0.0)
77 | dcov = torch.exp(t)* dcov
78 | dcov = torch.sqrt(dcov + 1e-5)
79 | t = dcov - 1. / dim * dcov.bmm(I_M) - 1. / dim * I_M.bmm(dcov) + 1. / (dim * dim) * I_M.bmm(dcov).bmm(I_M)
80 |
81 | return t
82 |
83 |
84 | def Triuvec(x):
85 | batchSize, dim, dim = x.shape
86 | r = x.reshape(batchSize, dim * dim)
87 | I = torch.ones(dim, dim).triu().reshape(dim * dim)
88 | index = I.nonzero(as_tuple = False)
89 | y = torch.zeros(batchSize, int(dim * (dim + 1) / 2), device=x.device).type(x.dtype)
90 | y = r[:, index].squeeze()
91 | return y
92 |
93 | if __name__ == '__main__':
94 | x = torch.rand((3, 4, 5, 5))
95 | # bdc = BDC(input_dim=x.shape,dimension_reduction=4)
96 |
97 | t = torch.log((1. / (2 * 25)) * torch.ones(1,1))
98 | print(BDCovpool(x,t)[0,:,:])
--------------------------------------------------------------------------------
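
A minimal usage sketch for the `BDC` layer above (not part of the repository): it assumes `input_dim` is a `[C, H, W]` shape, as with `ResNet.feat_dim` in `network/resnet.py`, and that `dimension_reduction` mirrors the `--reduce_dim 128` default from `eval.py`; the feature-map size is purely illustrative.

```python
import torch
from methods.bdc_module import BDC

# Assumed backbone output: 640 channels on a 10x10 spatial grid (illustrative values).
feat_shape = [640, 10, 10]
bdc = BDC(is_vec=True, input_dim=feat_shape, dimension_reduction=128)

x = torch.randn(4, 640, 10, 10)   # a batch of backbone feature maps
y = bdc(x)                        # 1x1 conv reduction -> BDC pooling -> upper-triangular vectorization
print(y.shape)                    # torch.Size([4, 8256]); 8256 = 128 * 129 / 2
```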
/data/datamgr.py:
--------------------------------------------------------------------------------
1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate
2 |
3 | import torch
4 | from PIL import Image
5 | import numpy as np
6 | import torchvision.transforms as transforms
7 | from data.dataset import SetDataset_JSON, SimpleDataset, SetDataset, EpisodicBatchSampler, SimpleDataset_JSON
8 | from abc import abstractmethod
9 |
10 |
11 | class TransformLoader:
12 | def __init__(self, image_size):
13 | self.normalize_param = dict(mean=[0.472, 0.453, 0.410], std=[0.277, 0.268, 0.285])
14 |
15 | self.image_size = image_size
16 | if image_size == 84:
17 | self.resize_size = 92
18 | elif image_size == 128:
19 | self.resize_size = 140
20 | elif image_size == 224:
21 | self.resize_size = 256
22 |
23 | def get_composed_transform(self, aug=False):
24 | if aug:
25 | transform = transforms.Compose([
26 | transforms.RandomResizedCrop(self.image_size),
27 | transforms.RandomHorizontalFlip(),
28 | transforms.ColorJitter(0.4, 0.4, 0.4),
29 | transforms.ToTensor(),
30 | transforms.Normalize(**self.normalize_param)
31 | ])
32 | else:
33 | transform = transforms.Compose([
34 | transforms.Resize(self.resize_size),
35 | transforms.CenterCrop(self.image_size),
36 | transforms.ToTensor(),
37 | transforms.Normalize(**self.normalize_param)
38 | ])
39 | return transform
40 |
41 |
42 | class DataManager:
43 | @abstractmethod
44 | def get_data_loader(self, data_file, aug):
45 | pass
46 |
47 |
48 | class SimpleDataManager(DataManager):
49 | def __init__(self, data_path, image_size, batch_size, json_read=False):
50 | super(SimpleDataManager, self).__init__()
51 | self.batch_size = batch_size
52 | self.data_path = data_path
53 | self.trans_loader = TransformLoader(image_size)
54 | self.json_read = json_read
55 |
56 | def get_data_loader(self, data_file, aug): # parameters that would change on train/val set
57 | transform = self.trans_loader.get_composed_transform(aug)
58 | if self.json_read:
59 | dataset = SimpleDataset_JSON(self.data_path, data_file, transform)
60 | else:
61 | dataset = SimpleDataset(self.data_path, data_file, transform)
62 | data_loader_params = dict(batch_size=self.batch_size, shuffle=True, num_workers=12, pin_memory=True)
63 | data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params)
64 |
65 | return data_loader
66 |
67 |
68 | class SetDataManager(DataManager):
69 | def __init__(self, data_path, image_size, n_way, n_support, n_query, n_episode, json_read=False,aug_num = 0,args=None):
70 | super(SetDataManager, self).__init__()
71 | self.image_size = image_size
72 | self.n_way = n_way
73 | self.batch_size = n_support + n_query
74 | self.n_episode = n_episode
75 | self.data_path = data_path
76 | self.json_read = json_read
77 | self.aug_num = aug_num
78 | self.args = args
79 |
80 | self.trans_loader = TransformLoader(image_size)
81 |
82 | def get_data_loader(self, data_file, aug): # parameters that would change on train/val set
83 | transform = self.trans_loader.get_composed_transform(aug)
84 | if self.json_read:
85 | # print(self.aug_num)
86 | dataset = SetDataset_JSON(self.data_path, data_file, self.batch_size, transform,aug_num=self.aug_num, args=self.args)
87 | else:
88 | dataset = SetDataset(self.data_path, data_file, self.batch_size, transform,aug_num=self.aug_num, args=self.args)
89 | sampler = EpisodicBatchSampler(len(dataset), self.n_way, self.n_episode)
90 | data_loader_params = dict(batch_sampler=sampler, pin_memory=True)
91 | data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params)
92 | return data_loader
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
--------------------------------------------------------------------------------
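
A rough sketch of how `SetDataManager` is meant to be used (mirroring `eval.py` below); the paths, episode counts, and the minimal `Namespace` are assumptions, covering only the attributes `SubDataset` actually reads (`grid`, `img_size`, `local_scale`):

```python
from argparse import Namespace
from data.datamgr import SetDataManager

# Hypothetical 5-way 1-shot episodes on miniImageNet, 17 views per image (1 global + 16 local crops).
args = Namespace(grid=None, img_size=84, local_scale=0.2)
datamgr = SetDataManager('filelist/miniImageNet', image_size=84,
                         n_way=5, n_support=1, n_query=15, n_episode=2000,
                         json_read=False, aug_num=17, args=args)
loader = datamgr.get_data_loader('test', aug=False)

x, y = next(iter(loader))
# x: [n_way, n_support + n_query, aug_num, 3, 84, 84]; y: the sampled class labels per episode
```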
/eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import pprint
4 | import os
5 | import time
6 | from data.datamgr import SetDataManager
7 | from methods.FeatWalk import FeatWalk_Net
8 | from utils.utils import set_seed,load_model
9 |
10 | DATA_DIR = 'data'
11 |
12 | torch.set_num_threads(4)
13 | _utils_pp = pprint.PrettyPrinter()
14 | def pprint(x):
15 | _utils_pp.pprint(x)
16 |
17 | def parse_option():
18 | parser = argparse.ArgumentParser('arguments for evaluation')
19 | # about dataset and network
20 | parser.add_argument('--dataset', type=str, default='miniimagenet',
21 | choices=['miniimagenet', 'cub', 'tieredimagenet', 'fc100'])
22 | parser.add_argument('--data_root', type=str, default=DATA_DIR)
23 | parser.add_argument('--model', default='resnet12',choices=['resnet12', 'resnet18', 'resnet34', 'conv64'])
24 | parser.add_argument('--img_size', default=84, type=int, choices=[84,224])
25 |
26 | # about model :
27 | parser.add_argument('--drop_gama', default=0.5, type= float)
28 | parser.add_argument("--beta", default=0.01, type=float)
29 | parser.add_argument('--drop_rate', default=0.5, type=float)
30 | parser.add_argument('--reduce_dim', default=128, type=int)
31 |
32 | # about meta test
33 | parser.add_argument('--val_freq',default=5,type=int)
34 | parser.add_argument('--set', type=str, default='test', choices=['val', 'test'], help='the set for validation')
35 | parser.add_argument('--n_way', type=int, default=5)
36 | parser.add_argument('--n_shot', type=int, default=1)
37 | parser.add_argument('--n_aug_support_samples',type=int, default=1)
38 | parser.add_argument('--n_queries', type=int, default=15)
39 | parser.add_argument('--n_episodes', type=int, default=1000)
40 | parser.add_argument('--num_workers', default=0, type=int)
41 | parser.add_argument('--test_batch_size',default=1)
42 | parser.add_argument('--grid', default=None, nargs='+', type=int)  # e.g. --grid 2 3 for 2x2 and 3x3 patch grids
43 |
44 | # setting
45 | parser.add_argument('--gpu', default=0, type=int)
46 | parser.add_argument('--save_dir', default='checkpoint')
47 | parser.add_argument('--test_LR', default=False, action='store_true')
48 | parser.add_argument('--model_type',default='best',choices=['best','last'])
49 | parser.add_argument('--seed', default=1, type=int)
50 | parser.add_argument('--no_save_model', default=False, action='store_true')
51 | parser.add_argument('--method',default='local_proto',choices=['local_proto','good_metric','stl_deepbdc','confusion','WinSA'])
52 | parser.add_argument('--distill_model', default=None,type=str,help='about distillation model path')
53 | parser.add_argument('--penalty_c', default=1.0, type=float)
54 | parser.add_argument('--test_times', default=1, type=int)
55 |
56 | # confusion representation:
57 | parser.add_argument('--n_symmetry_aug', default=1, type=int)
58 | parser.add_argument('--embeding_way', default='BDC', choices=['BDC','GE','protonet','baseline++'])
59 | parser.add_argument('--wd_test', type=float, default=0.01)
60 | parser.add_argument('--LR', default=False,action='store_true')
61 | parser.add_argument('--lr', default=0.01, type=float)
62 | parser.add_argument('--optim', default='Adam',choices=['Adam', 'SGD'])
63 | parser.add_argument('--drop_few',default=0.5,type=float)
64 | parser.add_argument('--fix_seed', default=False, action='store_true')
65 | parser.add_argument('--local_scale', default=0.2 , type=float)
66 | parser.add_argument('--distill', default=False, action='store_true')
67 | parser.add_argument('--sfc_bs', default=16, type=int)
68 | parser.add_argument('--alpha', default=0.5 , type=float)
69 | parser.add_argument('--sim_temperature', default=64 , type=float)
70 | parser.add_argument('--measure', default='cosine', choices=['cosine','eudist'])
71 |
72 | args = parser.parse_args()
73 | args.n_symmetry_aug = args.n_aug_support_samples
74 |
75 | return args
76 |
77 |
78 | def model_load(args,model):
79 | # method = 'deep_emd' if args.deep_emd else 'local_match'
80 | method = args.method
81 | save_path = os.path.join(args.save_dir, args.dataset + "_" + method + "_resnet12_"+args.model_type
82 | + ("_" + str(args.model_id) if getattr(args, 'model_id', None) else "") + ".pth")
83 | if args.distill_model is not None:
84 | save_path = os.path.join(args.save_dir, args.distill_model)
85 | else:
86 | assert False, "model load failed: --distill_model is required"
87 | print('teacher model path: ' + save_path)
88 | state_dict = torch.load(save_path)['model']
89 | model.load_state_dict(state_dict)
90 | return model
91 |
92 |
93 | def main():
94 | args = parse_option()
95 | if args.img_size == 224 and getattr(args, 'transform', None) == 'B':
96 | args.transform = 'B224'
97 |
98 | if args.grid:
99 | args.n_aug_support_samples = 1
100 | for i in args.grid:
101 | args.n_aug_support_samples += i ** 2
102 | args.n_symmetry_aug = args.n_aug_support_samples
103 |
104 | pprint(args)
105 | if args.gpu:
106 | gpu_device = str(args.gpu)
107 | else:
108 | gpu_device = "0"
109 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_device
110 | if args.fix_seed:
111 | set_seed(args.seed)
112 |
113 | json_file_read = False
114 | if args.dataset == 'cub':
115 | novel_file = 'novel.json'
116 | json_file_read = True
117 | else:
118 | novel_file = 'test'
119 | if args.dataset == 'miniimagenet':
120 | novel_few_shot_params = dict(n_way=args.n_way, n_support=args.n_shot)
121 | novel_datamgr = SetDataManager('filelist/miniImageNet', args.img_size, n_query=args.n_queries,
122 | n_episode=args.n_episodes, json_read=json_file_read,aug_num=args.n_aug_support_samples,args=args,
123 | **novel_few_shot_params)
124 | novel_loader = novel_datamgr.get_data_loader(novel_file, aug=False)
125 | num_classes = 64
126 | elif args.dataset == 'cub':
127 | novel_few_shot_params = dict(n_way=args.n_way, n_support=args.n_shot)
128 | novel_datamgr = SetDataManager('filelist/CUB',args.img_size, n_query=args.n_queries,
129 | n_episode=args.n_episodes, json_read=json_file_read,aug_num=args.n_aug_support_samples,args=args,
130 | **novel_few_shot_params)
131 | novel_loader = novel_datamgr.get_data_loader(novel_file, aug=False)
132 | num_classes = 100
133 |
134 | model = FeatWalk_Net(args,num_classes=num_classes).cuda()
135 | model.eval()
136 | model = load_model(model,os.path.join(args.save_dir,args.distill_model))
137 |
138 | print("-"*20+" start meta test... "+"-"*20)
139 | acc_sum = 0
140 | confidence_sum = 0
141 | for t in range(args.test_times):
142 | with torch.no_grad():
143 | tic = time.time()
144 | mean, confidence = model.meta_test_loop(novel_loader)
145 | acc_sum += mean
146 | confidence_sum += confidence
147 | print()
148 | print("Time {} :meta_val acc: {:.2f} +- {:.2f} elapse: {:.2f} min".format(t,mean * 100, confidence * 100,
149 | (time.time() - tic) / 60))
150 |
151 | print("{} times \t acc: {:.2f} +- {:.2f}".format(args.test_times, acc_sum/args.test_times * 100, confidence_sum/args.test_times * 100, ))
152 |
153 | if __name__ == '__main__':
154 | main()
--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | import torch.nn as nn
5 |
6 | class DistillKL(nn.Module):
7 | """KL divergence for distillation"""
8 | def __init__(self, T):
9 | super(DistillKL, self).__init__()
10 | self.T = T
11 |
12 | def forward(self, y_s, y_t):
13 | p_s = F.log_softmax(y_s/self.T, dim=1)
14 | p_t = F.softmax(y_t/self.T, dim=1)
15 | loss = F.kl_div(p_s, p_t, reduction='sum') * (self.T**2) / y_s.shape[0]
16 | return loss
17 |
18 | def mask_loss(out,gama=0.5):
19 | # print(out.shape)
20 | crition = torch.nn.BCELoss()
21 | out = out.contiguous().view(out.shape[0],-1)
22 | avg_imp = torch.mean(out,dim=1).unsqueeze(1)
23 | rate_Sa = torch.mean(torch.where(out >= avg_imp, 1, 0).float(), dim=-1)
24 | imp_gama = 1 - rate_Sa * gama
25 |
26 | value, ind = torch.sort(out, dim=1, descending=True)
27 | drop_ind = torch.ceil((1 - imp_gama) * out.shape[-1])
28 | threshold = value[range(out.shape[0]), drop_ind.long()]
29 | threshold = threshold.unsqueeze(1).expand_as(out)
30 | fore_mask = torch.where(out >= threshold, 1, 0).float()
31 | loss_mask = crition(out,fore_mask)
32 | # print(loss_mask)
33 | return loss_mask
34 |
35 | def uniformity_loss(feat, const_feat,label=None,temp=0.5):
36 | sim_aa = torch.cosine_similarity(feat, const_feat, dim=-1)
37 | feat_expand = feat.unsqueeze(0).repeat(feat.shape[0],1,1)
38 | const_feat_expand = const_feat.unsqueeze(1).expand_as(feat_expand)
39 | sim_ab = torch.cosine_similarity(feat_expand, const_feat_expand,dim=-1)
40 | sim_a = torch.exp(sim_aa/temp)
41 | sim_b = torch.exp(sim_ab/temp)
42 | sim_tot = torch.sum(sim_b + 1e-6,dim=-1)
43 | if label is not None:
44 |
45 | sim_idx = torch.cat([torch.sum(sim_b[i,torch.where(label.squeeze(0) == label.squeeze(0)[i])[0]],dim=-1).unsqueeze(0)
46 | for i in range(sim_b.shape[0])],dim=0)
47 |
48 | p = sim_idx/sim_tot
49 |
50 | else:
51 | p = sim_a / sim_tot
52 | loss = torch.mean(-torch.log(p+1e-8))
53 | return loss
54 |
55 | def Distance_Correlation(latent, control):
56 | latent = F.normalize(latent)
57 | control = F.normalize(control)
58 |
59 | matrix_a = torch.sqrt(torch.sum(torch.square(latent.unsqueeze(0) - latent.unsqueeze(1)), dim=-1) + 1e-12)
60 | matrix_b = torch.sqrt(torch.sum(torch.square(control.unsqueeze(0) - control.unsqueeze(1)), dim=-1) + 1e-12)
61 |
62 | matrix_A = matrix_a - torch.mean(matrix_a, dim=0, keepdims=True) - torch.mean(matrix_a, dim=1,
63 | keepdims=True) + torch.mean(matrix_a)
64 | matrix_B = matrix_b - torch.mean(matrix_b, dim=0, keepdims=True) - torch.mean(matrix_b, dim=1,
65 | keepdims=True) + torch.mean(matrix_b)
66 |
67 | Gamma_XY = torch.sum(matrix_A * matrix_B) / (matrix_A.shape[0] * matrix_A.shape[1])
68 | Gamma_XX = torch.sum(matrix_A * matrix_A) / (matrix_A.shape[0] * matrix_A.shape[1])
69 | Gamma_YY = torch.sum(matrix_B * matrix_B) / (matrix_A.shape[0] * matrix_A.shape[1])
70 |
71 | correlation_r = Gamma_XY / torch.sqrt(Gamma_XX * Gamma_YY + 1e-9)
72 | return correlation_r
73 |
74 | def area_loss(out,gama=0.5):
75 | # print(out.shape)
76 | out = out.contiguous().view(out.shape[0],-1)
77 | y = torch.mean(out,-1)
78 | avg_imp = torch.mean(out,dim=-1).unsqueeze(1)
79 | rate_Sa = torch.mean(torch.where(out >= avg_imp, 1, 0).float(), dim=-1)
80 | imp_gama = rate_Sa * gama
81 | imp_gama = torch.cat([imp_gama.unsqueeze(1),1-imp_gama.unsqueeze(1)],dim=-1)
82 | y = torch.cat([y.unsqueeze(1), 1 - y.unsqueeze(1)], dim=-1)
83 | loss_area = F.kl_div(y.log(),imp_gama, reduction='batchmean')
84 | return loss_area
85 |
86 | def cosine_sim(out,lab):
87 |
88 | if len(lab.size()) == 1:
89 | label = torch.zeros((out.size(0),
90 | out.size(1))).long().cuda()
91 | label_range = torch.arange(0, out.size(0)).long()
92 | label[label_range, lab] = 1
93 | lab = label
94 |
95 | return torch.mean(torch.abs(out) * lab)
96 |
97 | def ce_loss(out, lab,temperature=1,is_softmax = True):
98 |
99 | if is_softmax:
100 | out = F.softmax(out*temperature, 1)
101 | if len(lab.size()) == 1:
102 | label = torch.zeros((out.size(0),
103 | out.size(1))).long().cuda()
104 | label_range = torch.arange(0, out.size(0)).long()
105 | label[label_range, lab] = 1
106 | lab = label
107 | loss = torch.mean(torch.sum(-lab*torch.log(out+1e-8),1))
108 |
109 | return loss
110 |
111 | # compute the entropy of the output distribution
112 | def entropy_loss(out):
113 | # crition = torch.nn.BCELoss()
114 | out = F.softmax(out, 1)
115 | # print(out)
116 | # pred = torch.ones_like(out)/out.shape[1]
117 | # loss = crition(pred,out)
118 | loss = -torch.mean(torch.sum(out*torch.log(out + 1e-8), 1))
119 | return loss
120 |
121 | def Few_loss(out,lab):
122 | # The intent seems to be a PolyLoss, but the implementation does not match it;
123 | # this loss is of little practical use
124 | out = F.softmax(out, 1)
125 | eps = 2
126 | n = 1
127 | poly_head = torch.zeros(out.size(0),out.size(1)).cuda()
128 | for i in range(n):
129 | poly_head += eps*1/(i+1)*torch.pow(1-out,(i+1))
130 | ce_loss = torch.sum(-lab * torch.log(out + 1e-8) - poly_head,1)
131 | loss = torch.mean(ce_loss)
132 | return loss
133 |
134 | def loc_loss(out_loc,lab):
135 |
136 | out_loc = F.sigmoid(out_loc)
137 | # print(out_loc)
138 | log_loc = -lab * torch.log(out_loc + 1e-8) - (1 - lab) * torch.log(1 - out_loc + 1e-8)
139 | # loss = torch.mean(torch.sum(log_loc, 1))
140 | loss = torch.mean(torch.mean(log_loc, 1))
141 |
142 | # out_loc = out_loc.view(out_loc.size(0),out_loc.size(1),-1,2)
143 | # out_loc = F.softmax(out_loc,dim=3)
144 |
145 | return loss
146 |
147 | def euclidean_dist(x, y):
148 | '''
149 | Compute euclidean distance between two tensors
150 | '''
151 | # x: N x D
152 | # y: M x D
153 | n = x.size(0)
154 | m = y.size(0)
155 | d = x.size(1)
156 | if d != y.size(1):
157 | raise Exception
158 |
159 | # unsqueeze expands the tensor along the given dim
160 | x = x.unsqueeze(1).expand(n, m, d)
161 | y = y.unsqueeze(0).expand(n, m, d)
162 |
163 | return torch.pow(x - y, 2).sum(2)
164 |
165 | def prototypical_Loss(feat_out,lab,prototypes,epoch,center=False,temperature = 1):
166 | temperature = 256
167 | def supp_idxs(c):
168 | # FIXME when torch will support where as np
169 | return label_cpu.eq(c).nonzero()[:].squeeze(1)
170 |
171 | feat_cpu = feat_out.cpu()
172 | label_cpu = lab.cpu()
173 | prototypes = prototypes.cpu()
174 |
175 | n_classes = prototypes.size(0)
176 | if len(label_cpu.size()) == 1:
177 | classes = np.unique(label_cpu)
178 | # map: call supp_idxs for each class in classes
179 | support_idxs = list(map(supp_idxs,classes))
180 | prototypes_update = torch.stack([feat_cpu[idx_list].mean(0) for idx_list in support_idxs])
181 | else:
182 | classes = range(n_classes)
183 | count = sum(label_cpu, 0)
184 | # feat_cpu dim : 64 * 640
185 | # label dim : 64 * 5
186 | prototypes_update = torch.matmul(feat_cpu.T,label_cpu.float())/torch.tensor(count).float()
187 | prototypes_update = prototypes_update.T
188 | # if epoch == 0 :
189 | # beta = 0.9
190 | prototypes[classes, :] = prototypes_update.detach()
191 |
192 | if len(lab.size()) == 1:
193 | label = torch.zeros((feat_cpu.size(0),
194 | n_classes)).long().cuda()
195 | label_range = torch.arange(0, feat_cpu.size(0)).long()
196 | label[label_range, lab] = 1
197 | lab = label
198 | dists = euclidean_dist(feat_cpu,prototypes)/temperature
199 | # print(dists.shape)
200 | log_p_y = F.log_softmax(-dists, dim=1)
201 | y = F.softmax(-dists,1)
202 |
203 | loss = torch.mean(torch.sum(-lab.cpu() * torch.log(y+1e-8),1))
204 | # print(loss)
205 | return loss,prototypes
206 |
207 | if __name__ == '__main__':
208 | # exp = torch.rand((3,5,5))
209 | # print(area_loss(torch.sigmoid(exp)))
210 | feat = torch.rand((5, 640,100))
211 | feat_cons = torch.rand((5,640,100))
212 | print(Distance_Correlation(feat,feat_cons))
213 | # print(uniformity_loss(feat,feat_cons))
--------------------------------------------------------------------------------
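
A small sketch of how `DistillKL` above would typically be applied between student and teacher logits during distillation; the batch size, class count, and temperature are assumptions, not values taken from this repository:

```python
import torch
from utils.loss import DistillKL

# Illustrative logits: a batch of 8 samples over 64 base classes.
student_logits = torch.randn(8, 64)
teacher_logits = torch.randn(8, 64)

criterion = DistillKL(T=4.0)   # T softens both distributions before the KL term
loss = criterion(student_logits, teacher_logits)
print(loss.item())
```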
/methods/template.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.autograd import Variable
6 | import numpy as np
7 | import torch.nn.functional as F
8 | from abc import abstractmethod
9 | from .bdc_module import *
10 |
11 |
12 | class BaselineTrain(nn.Module):
13 | def __init__(self, params, model_func, num_class):
14 | super(BaselineTrain, self).__init__()
15 | self.params = params
16 | self.feature = model_func()
17 | if params.method in ['stl_deepbdc', 'meta_deepbdc']:
18 | reduce_dim = params.reduce_dim
19 | self.feat_dim = int(reduce_dim * (reduce_dim+1) / 2)
20 | self.dcov = BDC(is_vec=True, input_dim=self.feature.feat_dim, dimension_reduction=reduce_dim)
21 | self.dropout = nn.Dropout(params.dropout_rate)
22 |
23 | elif params.method in ['protonet', 'good_embed']:
24 | self.feat_dim = self.feature.feat_dim[0]
25 | self.avgpool = nn.AdaptiveAvgPool2d(1)
26 |
27 | if params.method in ['stl_deepbdc', 'meta_deepbdc', 'protonet', 'good_embed']:
28 | self.classifier = nn.Linear(self.feat_dim, num_class)
29 | self.classifier.bias.data.fill_(0)
30 |
31 | self.num_class = num_class
32 | self.loss_fn = nn.CrossEntropyLoss()
33 |
34 | def feature_forward(self, x):
35 | out = self.feature.forward(x)
36 | if self.params.method in ['stl_deepbdc', 'meta_deepbdc']:
37 | out = self.dcov(out)
38 | out = self.dropout(out)
39 | elif self.params.method in ['protonet', 'good_embed']:
40 | out = self.avgpool(out).view(out.size(0), -1)
41 | return out
42 |
43 | def forward(self, x):
44 | x = Variable(x.cuda())
45 | out = self.feature_forward(x)
46 | scores = self.classifier.forward(out)
47 | return scores
48 |
49 | def forward_meta_val(self, x):
50 | x = Variable(x.cuda())
51 | x = x.contiguous().view(self.params.val_n_way * (self.params.n_shot + self.params.n_query), *x.size()[2:])
52 |
53 | out = self.feature_forward(x)
54 |
55 | z_all = out.view(self.params.val_n_way, self.params.n_shot + self.params.n_query, -1)
56 | z_support = z_all[:, :self.params.n_shot]
57 | z_query = z_all[:, self.params.n_shot:]
58 | z_proto = z_support.contiguous().view(self.params.val_n_way, self.params.n_shot, -1).mean(1)
59 | z_query = z_query.contiguous().view(self.params.val_n_way * self.params.n_query, -1)
60 |
61 | if self.params.method in ['meta_deepbdc']:
62 | scores = self.metric(z_query, z_proto)
63 | elif self.params.method in ['protonet']:
64 | scores = self.euclidean_dist(z_query, z_proto)
65 | return scores
66 |
67 | def forward_loss(self, x, y):
68 | scores = self.forward(x)
69 | y = Variable(y.cuda())
70 | return self.loss_fn(scores, y), scores
71 |
72 | def forward_meta_val_loss(self, x):
73 | y_query = torch.from_numpy(np.repeat(range(self.params.val_n_way), self.params.n_query))
74 | y_query = Variable(y_query.cuda())
75 | y_label = np.repeat(range(self.params.val_n_way), self.params.n_query)
76 | scores = self.forward_meta_val(x)
77 | topk_scores, topk_labels = scores.data.topk(1, 1, True, True)
78 | topk_ind = topk_labels.cpu().numpy()
79 | top1_correct = np.sum(topk_ind[:, 0] == y_label)
80 | return float(top1_correct), len(y_label), self.loss_fn(scores, y_query), scores
81 |
82 | def train_loop(self, epoch, train_loader, optimizer):
83 | print_freq = 200
84 | avg_loss = 0
85 | total_correct = 0
86 |
87 | iter_num = len(train_loader)
88 | total = len(train_loader) * self.params.batch_size
89 |
90 | for i, (x, y) in enumerate(train_loader):
91 | y = Variable(y.cuda())
92 | optimizer.zero_grad()
93 | loss, output = self.forward_loss(x, y)
94 | pred = output.data.max(1)[1]
95 | total_correct += pred.eq(y.data.view_as(pred)).sum()
96 | loss.backward()
97 | optimizer.step()
98 |
99 | avg_loss = avg_loss + loss.item()
100 |
101 | if i % print_freq == 0:
102 | print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f}'.format(epoch, i, len(train_loader), avg_loss / float(i + 1)))
103 | return avg_loss / iter_num, float(total_correct) / total * 100
104 |
105 | def test_loop(self, val_loader):
106 | total_correct = 0
107 | avg_loss = 0.0
108 | total = len(val_loader) * self.params.batch_size
109 | with torch.no_grad():
110 | for i, (x, y) in enumerate(val_loader):
111 | y = Variable(y.cuda())
112 | loss, output = self.forward_loss(x, y)
113 | avg_loss = avg_loss + loss.item()
114 | pred = output.data.max(1)[1]
115 | total_correct += pred.eq(y.data.view_as(pred)).sum()
116 | avg_loss /= len(val_loader)
117 | acc = float(total_correct) / total
118 | # print('Test Acc = %4.2f%%, loss is %.2f' % (acc * 100, avg_loss))
119 | return avg_loss, acc * 100
120 |
121 | def meta_test_loop(self, test_loader):
122 | acc_all = []
123 | avg_loss = 0
124 | iter_num = len(test_loader)
125 | with torch.no_grad():
126 | for i, (x, _) in enumerate(test_loader):
127 | correct_this, count_this, loss, _ = self.forward_meta_val_loss(x)
128 | acc_all.append(correct_this / count_this * 100)
129 | avg_loss = avg_loss + loss.item()
130 | acc_all = np.asarray(acc_all)
131 | acc_mean = np.mean(acc_all)
132 | acc_std = np.std(acc_all)
133 | print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
134 |
135 | return avg_loss / iter_num, acc_mean
136 |
137 | def metric(self, x, y):
138 | # x: N x D
139 | # y: M x D
140 | n = x.size(0)
141 | m = y.size(0)
142 | d = x.size(1)
143 | assert d == y.size(1)
144 |
145 | x = x.unsqueeze(1).expand(n, m, d)
146 | y = y.unsqueeze(0).expand(n, m, d)
147 |
148 | if self.params.n_shot > 1:
149 | dist = torch.pow(x - y, 2).sum(2)
150 | score = -dist
151 | else:
152 | score = (x * y).sum(2)
153 | return score
154 |
155 | def euclidean_dist(self, x, y):
156 | # x: N x D
157 | # y: M x D
158 | n = x.size(0)
159 | m = y.size(0)
160 | d = x.size(1)
161 | assert d == y.size(1)
162 |
163 | x = x.unsqueeze(1).expand(n, m, d)
164 | y = y.unsqueeze(0).expand(n, m, d)
165 |
166 | score = -torch.pow(x - y, 2).sum(2)
167 | return score
168 |
169 |
170 | class MetaTemplate(nn.Module):
171 | def __init__(self, params, model_func, n_way, n_support, change_way=True):
172 | super(MetaTemplate, self).__init__()
173 | self.n_way = n_way
174 | self.n_support = n_support
175 | self.n_query = params.n_query # (change depends on input)
176 | self.feature = model_func()
177 | self.change_way = change_way # some methods allow different_way classification during training and test
178 | self.params = params
179 |
180 | @abstractmethod
181 | def set_forward(self, x, is_feature):
182 | pass
183 |
184 | @abstractmethod
185 | def set_forward_loss(self, x):
186 | pass
187 |
188 | @abstractmethod
189 | def feature_forward(self, x):
190 | pass
191 |
192 | def forward(self, x):
193 | out = self.feature.forward(x)
194 | return out
195 |
196 | def parse_feature(self, x, is_feature):
197 | x = Variable(x.cuda())
198 | if is_feature:
199 | z_all = x
200 | else:
201 | x = x.contiguous().view(self.n_way * (self.n_support + self.n_query), *x.size()[2:])
202 | x = self.feature.forward(x)
203 | z_all = self.feature_forward(x)
204 | z_all = z_all.view(self.n_way, self.n_support + self.n_query, -1)
205 | z_support = z_all[:, :self.n_support]
206 |
207 | z_query = z_all[:, self.n_support:]
208 | # print(z_query.shape)
209 |
210 | return z_support, z_query
211 |
212 | def correct(self, x):
213 | scores = self.set_forward(x)
214 | y_query = np.repeat(range(self.n_way), self.n_query)
215 |
216 | topk_scores, topk_labels = scores.data.topk(1, 1, True, True)
217 | topk_ind = topk_labels.cpu().numpy()
218 | top1_correct = np.sum(topk_ind[:, 0] == y_query)
219 | return float(top1_correct), len(y_query)
220 |
221 | def train_loop(self, epoch, train_loader, optimizer):
222 | print_freq = 200
223 | avg_loss = 0
224 | acc_all = []
225 | iter_num = len(train_loader)
226 | for i, (x, _) in enumerate(train_loader):
227 | self.n_query = x.size(1) - self.n_support
228 | if self.change_way:
229 | self.n_way = x.size(0)
230 | optimizer.zero_grad()
231 | correct_this, count_this, loss, _ = self.set_forward_loss(x)
232 | acc_all.append(correct_this / count_this * 100)
233 | loss.backward()
234 | optimizer.step()
235 | avg_loss = avg_loss + loss.item()
236 |
237 | if i % print_freq == 0:
238 | print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f}'.format(epoch, i, len(train_loader),
239 | avg_loss / float(i + 1)))
240 | acc_all = np.asarray(acc_all)
241 | acc_mean = np.mean(acc_all)
242 | return avg_loss / iter_num, acc_mean
243 |
244 | def test_loop(self, test_loader, record=None):
245 | acc_all = []
246 | avg_loss = 0
247 | iter_num = len(test_loader)
248 | with torch.no_grad():
249 | for i, (x, _) in enumerate(test_loader):
250 | self.n_query = x.size(1) - self.n_support
251 | if self.change_way:
252 | self.n_way = x.size(0)
253 | correct_this, count_this, loss, _ = self.set_forward_loss(x)
254 | acc_all.append(correct_this / count_this * 100)
255 | avg_loss = avg_loss + loss.item()
256 | acc_all = np.asarray(acc_all)
257 | acc_mean = np.mean(acc_all)
258 | acc_std = np.std(acc_all)
259 | print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
260 |
261 | return avg_loss / iter_num, acc_mean
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate
2 |
3 | import torch
4 | from PIL import Image
5 | import json
6 | import numpy as np
7 | import torchvision.transforms as transforms
8 | import os
9 |
10 | identity = lambda x: x
11 |
12 |
13 | def get_grid_location(size, ratio, num_grid):
14 | '''
15 |
16 | :param size: size of the height/width
17 | :param ratio: generate grid size/ even divided grid size
18 | :param num_grid: number of grid
19 | :return: a list containing the coordinate of the grid
20 | '''
21 | raw_grid_size = int(size / num_grid)
22 | enlarged_grid_size = int(size / num_grid * ratio)
23 |
24 | center_location = raw_grid_size // 2
25 |
26 | location_list = []
27 | for i in range(num_grid):
28 | location_list.append((max(0, center_location - enlarged_grid_size // 2),
29 | min(size, center_location + enlarged_grid_size // 2)))
30 | center_location = center_location + raw_grid_size
31 |
32 | return location_list
33 |
34 |
35 | class SimpleDataset:
36 | def __init__(self, data_path, data_file_list, transform, target_transform=identity):
37 | label = []
38 | data = []
39 | k = 0
40 | data_dir_list = data_file_list.replace(" ","").split(',')
41 | for data_file in data_dir_list:
42 | img_dir = data_path + '/' + data_file
43 | for i in os.listdir(img_dir):
44 | file_dir = os.path.join(img_dir, i)
45 | for j in os.listdir(file_dir):
46 | data.append(file_dir + '/' + j)
47 | label.append(k)
48 | k += 1
49 | self.data = data
50 | self.label = label
51 | self.transform = transform
52 | self.target_transform = target_transform
53 |
54 | def __getitem__(self, i):
55 | image_path = os.path.join(self.data[i])
56 | img = Image.open(image_path).convert('RGB')
57 | img = self.transform(img)
58 | target = self.target_transform(self.label[i] - min(self.label))
59 | return img, target
60 |
61 | def __len__(self):
62 | return len(self.label)
63 |
64 |
65 | class SetDataset:
66 | def __init__(self, data_path, data_file_list, batch_size, transform,aug_num=0,args=None):
67 | label = []
68 | data = []
69 | k = 0
70 | data_dir_list = data_file_list.replace(" ","").split(',')
71 | for data_file in data_dir_list:
72 | img_dir = data_path + '/' + data_file
73 | for i in os.listdir(img_dir):
74 | file_dir = os.path.join(img_dir, i)
75 | for j in os.listdir(file_dir):
76 | data.append(file_dir + '/' + j)
77 | label.append(k)
78 | k += 1
79 | self.data = data
80 | self.label = label
81 | self.transform = transform
82 | self.cl_list = np.unique(self.label).tolist()
83 | self.args = args
84 |
85 | self.sub_meta = {}
86 | for cl in self.cl_list:
87 | self.sub_meta[cl] = []
88 |
89 | for x, y in zip(self.data, self.label):
90 | self.sub_meta[y].append(x)
91 |
92 | self.sub_dataloader = []
93 | sub_data_loader_params = dict(batch_size=batch_size,
94 | shuffle=True,
95 | num_workers=0, # use main thread only or may receive multiple batches
96 | pin_memory=False)
97 | self.cl_num = 0
98 | for cl in self.cl_list:
99 | if len(self.sub_meta[cl])>=25:
100 | self.cl_num += 1
101 | sub_dataset = SubDataset(self.sub_meta[cl], cl, transform=transform,aug_num=aug_num,args=self.args)
102 | self.sub_dataloader.append(torch.utils.data.DataLoader(sub_dataset, **sub_data_loader_params))
103 |
104 | def __getitem__(self, i):
105 | return next(iter(self.sub_dataloader[i]))
106 |
107 | def __len__(self):
108 | return self.cl_num
109 |
110 |
111 | class SubDataset:
112 | def __init__(self, sub_meta, cl, transform=transforms.ToTensor(), target_transform=identity,aug_num=0,args=None):
113 | self.sub_meta = sub_meta
114 | self.cl = cl
115 | self.transform = transform
116 | self.target_transform = target_transform
117 | self.aug_num = aug_num
118 | self.grid = args.grid
119 | self.transform_grid = transforms.Compose([
120 | transforms.Resize([args.img_size,args.img_size]),
121 | transforms.ToTensor(),
122 | transforms.Normalize(mean=[0.472, 0.453, 0.410], std=[0.277, 0.268, 0.285])
123 | ])
124 |
125 | self.transform_s = transforms.Compose([
126 | transforms.RandomResizedCrop(args.img_size, scale=(args.local_scale, args.local_scale)),
127 | transforms.RandomHorizontalFlip(),
128 | transforms.ToTensor(),
129 | transforms.Normalize(mean=[0.472, 0.453, 0.410], std=[0.277, 0.268, 0.285])
130 | ])
131 |
132 | def __getitem__(self, i):
133 | image_path = os.path.join(self.sub_meta[i])
134 | img = Image.open(image_path).convert('RGB')
135 | img_set = []
136 | img_w = self.transform(img)
137 | img_set.append(img_w.unsqueeze(0))
138 | if self.grid:
139 | for num_patch in self.grid:
140 | patches = self.get_pyramid(img, num_patch)
141 | # print(patches.shape)
142 | img_set.append(patches)
143 | else:
144 | for _ in range(self.aug_num - 1):
145 | img_s = self.transform_s(img)
146 | img_set.append(img_s.unsqueeze(0))
147 | # for item in img_set:
148 | # print(item.shape)
149 | img = torch.cat(img_set, dim=0)
150 | target = self.target_transform(self.cl)
151 | return img, target
152 |
153 | def get_pyramid(self, img, num_patch):
154 | num_grid = num_patch
155 | grid_ratio = 1
156 | w, h = img.size
157 | grid_locations_w = get_grid_location(w, grid_ratio, num_grid)
158 | grid_locations_h = get_grid_location(h, grid_ratio, num_grid)
159 |
160 | patches_list = []
161 | for i in range(num_grid):
162 | for j in range(num_grid):
163 | patch_location_w = grid_locations_w[j]
164 | patch_location_h = grid_locations_h[i]
165 | left_up_corner_w = patch_location_w[0]
166 | left_up_corner_h = patch_location_h[0]
167 | right_down_corner_w = patch_location_w[1]
168 | right_down_corner_h = patch_location_h[1]
169 | patch = img.crop((left_up_corner_w, left_up_corner_h, right_down_corner_w, right_down_corner_h))
170 | patch = self.transform_grid(patch)
171 | patches_list.append(patch.unsqueeze(0))
172 | return torch.cat(patches_list,dim=0)
173 |
174 | def __len__(self):
175 | return len(self.sub_meta)
176 |
177 |
178 | class SimpleDataset_JSON:
179 | def __init__(self, data_path, data_file, transform, target_transform=identity):
180 | data = data_path + '/' + data_file
181 | with open(data, 'r') as f:
182 | self.meta = json.load(f)
183 | self.transform = transform
184 | self.target_transform = target_transform
185 |
186 | def __getitem__(self, i):
187 | image_path = os.path.join(self.meta['image_names'][i])
188 | img = Image.open(image_path).convert('RGB')
189 | img = self.transform(img)
190 | target = self.target_transform(self.meta['image_labels'][i])
191 | return img, target
192 |
193 | def __len__(self):
194 | return len(self.meta['image_names'])
195 |
196 |
197 | class SetDataset_JSON:
198 | def __init__(self, data_path, data_file, batch_size, transform,aug_num=0,args=None):
199 | data = data_path + '/' + data_file
200 |
201 | print(transform.__dict__)
202 | with open(data, 'r') as f:
203 | self.meta = json.load(f)
204 |
205 | self.cl_list = np.unique(self.meta['image_labels']).tolist()
206 | self.args = args
207 |
208 | self.sub_meta = {}
209 | for cl in self.cl_list:
210 | self.sub_meta[cl] = []
211 |
212 | for x, y in zip(self.meta['image_names'], self.meta['image_labels']):
213 | self.sub_meta[y].append(x)
214 |
215 | self.sub_dataloader = []
216 | # print(len(self.cl_list))
217 | sub_data_loader_params = dict(batch_size=batch_size,
218 | shuffle=True,
219 | num_workers=0, # use main thread only or may receive multiple batches
220 | pin_memory=False)
221 | for cl in self.cl_list:
222 | sub_dataset = SubDataset_JSON(self.sub_meta[cl], cl, transform=transform,aug_num=aug_num,args=self.args)
223 | self.sub_dataloader.append(torch.utils.data.DataLoader(sub_dataset, **sub_data_loader_params))
224 |
225 | def __getitem__(self, i):
226 | return next(iter(self.sub_dataloader[i]))
227 |
228 | def __len__(self):
229 | return len(self.cl_list)
230 |
231 |
232 | class SubDataset_JSON:
233 | def __init__(self, sub_meta, cl, transform=transforms.ToTensor(), target_transform=identity,aug_num=0,args=None):
234 | self.sub_meta = sub_meta
235 | self.cl = cl
236 | self.transform = transform
237 | self.target_transform = target_transform
238 | self.grid = args.grid
239 | self.transform_grid = transforms.Compose([
240 | transforms.Resize([args.img_size, args.img_size]),
241 | transforms.ToTensor(),
242 | transforms.Normalize(mean=[0.472, 0.453, 0.410], std=[0.277, 0.268, 0.285])
243 | ])
244 |
245 | self.transform_s = transforms.Compose([
246 | # transforms.RandomResizedCrop(224, scale=(0.3, 0.7)),
247 | transforms.RandomResizedCrop(args.img_size, scale=(args.local_scale, args.local_scale)),
248 | # transforms.RandomResizedCrop(args.img_size),
249 | transforms.RandomHorizontalFlip(),
250 | # transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
251 | transforms.ToTensor(),
252 | transforms.Normalize(mean=[0.472, 0.453, 0.410], std=[0.277, 0.268, 0.285])
253 | ])
254 | # print(aug_num)
255 | self.aug_num =aug_num
256 |
257 | def __getitem__(self, i):
258 | # print( '%d -%d' %(self.cl,i))
259 | image_path = os.path.join(self.sub_meta[i])
260 | img = Image.open(image_path).convert('RGB')
261 | img_set = []
262 | img_w = self.transform(img)
263 | img_set.append(img_w.unsqueeze(0))
264 | if self.grid:
265 | for num_patch in self.grid:
266 | patches = self.get_pyramid(img, num_patch)
267 | img_set.append(patches)
268 | else:
269 | for _ in range(self.aug_num - 1):
270 | img_s = self.transform_s(img)
271 | img_set.append(img_s.unsqueeze(0))
272 | img = torch.cat(img_set,dim=0)
273 | target = self.target_transform(self.cl)
274 | return img, target
275 |
276 | def get_pyramid(self, img, num_patch):
277 | num_grid = num_patch
278 | grid_ratio = 1
279 | w, h = img.size
280 | grid_locations_w = get_grid_location(w, grid_ratio, num_grid)
281 | grid_locations_h = get_grid_location(h, grid_ratio, num_grid)
282 |
283 | patches_list = []
284 | for i in range(num_grid):
285 | for j in range(num_grid):
286 | patch_location_w = grid_locations_w[j]
287 | patch_location_h = grid_locations_h[i]
288 | left_up_corner_w = patch_location_w[0]
289 | left_up_corner_h = patch_location_h[0]
290 | right_down_corner_w = patch_location_w[1]
291 | right_down_corner_h = patch_location_h[1]
292 | patch = img.crop((left_up_corner_w, left_up_corner_h, right_down_corner_w, right_down_corner_h))
293 | patch = self.transform_grid(patch)
294 | patches_list.append(patch.unsqueeze(0))
295 | return torch.cat(patches_list, dim=0)
296 |
297 |
298 | def __len__(self):
299 | return len(self.sub_meta)
300 |
301 |
302 | class EpisodicBatchSampler(object):
303 | def __init__(self, n_classes, n_way, n_episodes):
304 | self.n_classes = n_classes
305 | self.n_way = n_way
306 | self.n_episodes = n_episodes
307 |
308 | def __len__(self):
309 | return self.n_episodes
310 |
311 | def __iter__(self):
312 | for i in range(self.n_episodes):
313 | yield torch.randperm(self.n_classes)[:self.n_way]
314 |
315 |
316 |
--------------------------------------------------------------------------------
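
For reference, a tiny sketch of what `EpisodicBatchSampler` yields; the class and way counts are illustrative:

```python
import torch
from data.dataset import EpisodicBatchSampler

# 64 classes, 5-way, 3 episodes: each iteration yields the indices of the sampled classes.
sampler = EpisodicBatchSampler(n_classes=64, n_way=5, n_episodes=3)
for episode in sampler:
    print(episode)   # e.g. tensor([12, 40,  3, 57, 21])
```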
/network/resnet.py:
--------------------------------------------------------------------------------
1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate
2 |
3 | import torch
4 | from torch.autograd import Variable
5 | import torch.nn as nn
6 | import math
7 | import numpy as np
8 | import torch.nn.functional as F
9 | from torch.nn.utils.weight_norm import WeightNorm
10 | from torch.distributions import Bernoulli
11 |
12 | ##############################################
13 | # Basic ResNet model #
14 | ##############################################
15 |
16 | def init_layer(L):
17 | # Initialization using fan-in
18 | if isinstance(L, nn.Conv2d):
19 | n = L.kernel_size[0] * L.kernel_size[1] * L.out_channels
20 | L.weight.data.normal_(0, math.sqrt(2.0 / float(n)))
21 | elif isinstance(L, nn.BatchNorm2d):
22 | L.weight.data.fill_(1)
23 | L.bias.data.fill_(0)
24 |
25 | class Flatten(nn.Module):
26 | def __init__(self):
27 | super(Flatten, self).__init__()
28 |
29 | def forward(self, x):
30 | return x.view(x.size(0), -1)
31 |
32 | # Simple ResNet Block
33 | class SimpleBlock(nn.Module):
34 | maml = False # Default
35 |
36 | def __init__(self, indim, outdim, half_res):
37 | super(SimpleBlock, self).__init__()
38 | self.indim = indim
39 | self.outdim = outdim
40 | self.C1 = nn.Conv2d(indim, outdim, kernel_size=3, stride=2 if half_res else 1, padding=1, bias=False)
41 | self.BN1 = nn.BatchNorm2d(outdim)
42 | self.C2 = nn.Conv2d(outdim, outdim, kernel_size=3, padding=1, bias=False)
43 | self.BN2 = nn.BatchNorm2d(outdim)
44 |
45 | self.relu1 = nn.ReLU(inplace=True)
46 | self.relu2 = nn.ReLU(inplace=True)
47 |
48 | self.parametrized_layers = [self.C1, self.C2, self.BN1, self.BN2]
49 |
50 | self.half_res = half_res
51 |
52 | # if the input number of channels is not equal to the output, then need a 1x1 convolution
53 | if indim != outdim:
54 |
55 | self.shortcut = nn.Conv2d(indim, outdim, 1, 2 if half_res else 1, bias=False)
56 | self.BNshortcut = nn.BatchNorm2d(outdim)
57 |
58 | self.parametrized_layers.append(self.shortcut)
59 | self.parametrized_layers.append(self.BNshortcut)
60 | self.shortcut_type = '1x1'
61 | else:
62 | self.shortcut_type = 'identity'
63 |
64 | for layer in self.parametrized_layers:
65 | init_layer(layer)
66 |
67 | def forward(self, x):
68 | out = self.C1(x)
69 | out = self.BN1(out)
70 | out = self.relu1(out)
71 | out = self.C2(out)
72 | out = self.BN2(out)
73 | short_out = x if self.shortcut_type == 'identity' else self.BNshortcut(self.shortcut(x))
74 | out = out + short_out
75 | out = self.relu2(out)
76 | return out
77 |
78 |
79 | # Bottleneck block
80 | class BottleneckBlock(nn.Module):
81 | maml = False # Default
82 |
83 | def __init__(self, indim, outdim, half_res):
84 | super(BottleneckBlock, self).__init__()
85 | bottleneckdim = int(outdim / 4)
86 | self.indim = indim
87 | self.outdim = outdim
88 | self.C1 = nn.Conv2d(indim, bottleneckdim, kernel_size=1, bias=False)
89 | self.BN1 = nn.BatchNorm2d(bottleneckdim)
90 | self.C2 = nn.Conv2d(bottleneckdim, bottleneckdim, kernel_size=3, stride=2 if half_res else 1, padding=1)
91 | self.BN2 = nn.BatchNorm2d(bottleneckdim)
92 | self.C3 = nn.Conv2d(bottleneckdim, outdim, kernel_size=1, bias=False)
93 | self.BN3 = nn.BatchNorm2d(outdim)
94 |
95 | self.relu = nn.ReLU()
96 | self.parametrized_layers = [self.C1, self.BN1, self.C2, self.BN2, self.C3, self.BN3]
97 | self.half_res = half_res
98 |
99 | # if the input number of channels is not equal to the output, then need a 1x1 convolution
100 | if indim != outdim:
101 | self.shortcut = nn.Conv2d(indim, outdim, 1, stride=2 if half_res else 1, bias=False)
102 |
103 | self.parametrized_layers.append(self.shortcut)
104 | self.shortcut_type = '1x1'
105 | else:
106 | self.shortcut_type = 'identity'
107 |
108 | for layer in self.parametrized_layers:
109 | init_layer(layer)
110 |
111 | def forward(self, x):
112 |
113 | short_out = x if self.shortcut_type == 'identity' else self.shortcut(x)
114 | out = self.C1(x)
115 | out = self.BN1(out)
116 | out = self.relu(out)
117 | out = self.C2(out)
118 | out = self.BN2(out)
119 | out = self.relu(out)
120 | out = self.C3(out)
121 | out = self.BN3(out)
122 | out = out + short_out
123 |
124 | out = self.relu(out)
125 | return out
126 |
127 |
128 |
129 | class ResNet(nn.Module):
130 | maml = False # Default
131 |
132 | def __init__(self, block, list_of_num_layers, list_of_out_dims, flatten=False):
133 | # list_of_num_layers specifies number of layers in each stage
134 | # list_of_out_dims specifies number of output channel for each stage
135 | super(ResNet, self).__init__()
136 | assert len(list_of_num_layers) == 4, 'Can have only four stages'
137 |
138 | conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
139 | bias=False)
140 | bn1 = nn.BatchNorm2d(64)
141 |
142 | relu = nn.ReLU()
143 | pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
144 |
145 | init_layer(conv1)
146 | init_layer(bn1)
147 | trunk = [conv1, bn1, relu, pool1]
148 |
149 | indim = 64
150 | for i in range(4):
151 | for j in range(list_of_num_layers[i]):
152 | half_res = (i >= 1) and (j == 0) and i != 3
153 | B = block(indim, list_of_out_dims[i], half_res)
154 | trunk.append(B)
155 | indim = list_of_out_dims[i]
156 |
157 | if flatten:
158 | avgpool = nn.AvgPool2d(7)
159 | trunk.append(avgpool)
160 | trunk.append(Flatten())
161 | # self.final_feat_dim = indim
162 |
163 | self.feat_dim = [512, 14, 14]
164 | self.trunk = nn.Sequential(*trunk)
165 |
166 | def forward(self, x):
167 | out = self.trunk(x)
168 | # out = out.view(out.size(0), -1)
169 | return out
170 |
171 |
172 | def ResNet10(flatten=True):
173 | return ResNet(SimpleBlock, [1, 1, 1, 1], [64, 128, 256, 512], flatten)
174 |
175 |
176 | def ResNet18(flatten=False):
177 | return ResNet(SimpleBlock, [2, 2, 2, 2], [64, 128, 256, 512], flatten)
178 |
179 |
180 | def ResNet34(flatten=True):
181 | return ResNet(SimpleBlock, [3, 4, 6, 3], [64, 128, 256, 512], flatten)
182 |
183 |
184 | def ResNet50(flatten=True):
185 | return ResNet(BottleneckBlock, [3, 4, 6, 3], [256, 512, 1024, 2048], flatten)
186 |
187 |
188 | def ResNet101(flatten=True):
189 | return ResNet(BottleneckBlock, [3, 4, 23, 3], [256, 512, 1024, 2048], flatten)
190 |
191 |
192 | ##############################################
193 | # a variant of ResNet model #
194 | ##############################################
195 |
196 | def conv3x3(in_planes, out_planes, stride=1):
197 | """3x3 convolution with padding"""
198 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
199 | padding=1, bias=False)
200 |
201 |
202 | class SELayer(nn.Module):
203 | def __init__(self, channel, reduction=16):
204 | super(SELayer, self).__init__()
205 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
206 | self.fc = nn.Sequential(
207 | nn.Linear(channel, channel // reduction),
208 | nn.ReLU(inplace=True),
209 | nn.Linear(channel // reduction, channel),
210 | nn.Sigmoid()
211 | )
212 |
213 | def forward(self, x):
214 | b, c, _, _ = x.size()
215 | y = self.avg_pool(x).view(b, c)
216 | y = self.fc(y).view(b, c, 1, 1)
217 | return x * y
218 |
219 |
220 | class DropBlock(nn.Module):
221 | def __init__(self, block_size):
222 | super(DropBlock, self).__init__()
223 |
224 | self.block_size = block_size
225 | #self.gamma = gamma
226 | #self.bernouli = Bernoulli(gamma)
227 |
228 | def forward(self, x, gamma):
229 | # shape: (bsize, channels, height, width)
230 |
231 | if self.training:
232 | batch_size, channels, height, width = x.shape
233 |
234 | bernoulli = Bernoulli(gamma)
235 | mask = bernoulli.sample((batch_size, channels, height - (self.block_size - 1), width - (self.block_size - 1))).cuda()
236 | block_mask = self._compute_block_mask(mask)
237 | countM = block_mask.size()[0] * block_mask.size()[1] * block_mask.size()[2] * block_mask.size()[3]
238 | count_ones = block_mask.sum()
239 |
240 | return block_mask * x * (countM / count_ones)
241 | else:
242 | return x
243 |
244 | def _compute_block_mask(self, mask):
245 | left_padding = int((self.block_size-1) / 2)
246 | right_padding = int(self.block_size / 2)
247 |
248 | batch_size, channels, height, width = mask.shape
249 | #print ("mask", mask[0][0])
250 | non_zero_idxs = mask.nonzero()
251 | nr_blocks = non_zero_idxs.shape[0]
252 |
253 | offsets = torch.stack(
254 | [
255 | torch.arange(self.block_size).view(-1, 1).expand(self.block_size, self.block_size).reshape(-1), # - left_padding,
256 | torch.arange(self.block_size).repeat(self.block_size), #- left_padding
257 | ]
258 | ).t().cuda()
259 | offsets = torch.cat((torch.zeros(self.block_size**2, 2).cuda().long(), offsets.long()), 1)
260 |
261 | if nr_blocks > 0:
262 | non_zero_idxs = non_zero_idxs.repeat(self.block_size ** 2, 1)
263 | offsets = offsets.repeat(nr_blocks, 1).view(-1, 4)
264 | offsets = offsets.long()
265 |
266 | block_idxs = non_zero_idxs + offsets
267 | #block_idxs += left_padding
268 | padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding))
269 | padded_mask[block_idxs[:, 0], block_idxs[:, 1], block_idxs[:, 2], block_idxs[:, 3]] = 1.
270 | else:
271 | padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding))
272 |
273 | block_mask = 1 - padded_mask#[:height, :width]
274 | return block_mask
275 |
276 |
277 | class BasicBlockVariant(nn.Module):
278 | expansion = 1
279 |
280 | def __init__(self, inplanes, planes, stride=1, downsample=None, drop_rate=0.0, drop_block=False,
281 | block_size=1, use_se=False):
282 | super(BasicBlockVariant, self).__init__()
283 | self.conv1 = conv3x3(inplanes, planes)
284 | self.bn1 = nn.BatchNorm2d(planes)
285 | self.relu = nn.LeakyReLU(0.1)
286 | self.conv2 = conv3x3(planes, planes)
287 | self.bn2 = nn.BatchNorm2d(planes)
288 | self.conv3 = conv3x3(planes, planes)
289 | self.bn3 = nn.BatchNorm2d(planes)
290 | self.maxpool = nn.MaxPool2d(stride)
291 | self.downsample = downsample
292 | self.stride = stride
293 | self.drop_rate = drop_rate
294 | self.num_batches_tracked = 0
295 | self.drop_block = drop_block
296 | self.block_size = block_size
297 | self.DropBlock = DropBlock(block_size=self.block_size)
298 | self.use_se = use_se
299 | if self.use_se:
300 | self.se = SELayer(planes, 4)
301 |
302 | def forward(self, x):
303 | self.num_batches_tracked += 1
304 |
305 | residual = x
306 |
307 | out = self.conv1(x)
308 | out = self.bn1(out)
309 | out = self.relu(out)
310 |
311 | out = self.conv2(out)
312 | out = self.bn2(out)
313 | out = self.relu(out)
314 |
315 | out = self.conv3(out)
316 | out = self.bn3(out)
317 | if self.use_se:
318 | out = self.se(out)
319 |
320 | if self.downsample is not None:
321 | residual = self.downsample(x)
322 | out += residual
323 | out = self.relu(out)
324 | out = self.maxpool(out)
325 |
326 | if self.drop_rate > 0:
327 | if self.drop_block == True:
328 | feat_size = out.size()[2]
329 | keep_rate = max(1.0 - self.drop_rate / (20*2000) * (self.num_batches_tracked), 1.0 - self.drop_rate)  # keep_rate anneals linearly from 1.0 down to 1 - drop_rate over the first 20*2000 batches
330 | gamma = (1 - keep_rate) / self.block_size**2 * feat_size**2 / (feat_size - self.block_size + 1)**2  # per-position drop probability that yields the target keep_rate
331 | out = self.DropBlock(out, gamma=gamma)
332 | else:
333 | out = F.dropout(out, p=self.drop_rate, training=self.training, inplace=True)
334 |
335 | return out
336 |
337 |
338 | class resnet(nn.Module):
339 |
340 | def __init__(self, block, n_blocks, keep_prob=1.0, avg_pool=False, drop_rate=0.0,
341 | dropblock_size=5, num_classes=-1, use_se=False):
342 | super(resnet, self).__init__()
343 |
344 | self.inplanes = 3
345 | self.use_se = use_se
346 | self.layer1 = self._make_layer(block, n_blocks[0], 64,
347 | stride=2, drop_rate=drop_rate)
348 | self.layer2 = self._make_layer(block, n_blocks[1], 160,
349 | stride=2, drop_rate=drop_rate)
350 | self.layer3 = self._make_layer(block, n_blocks[2], 320,
351 | stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size)
352 | self.layer4 = self._make_layer(block, n_blocks[3], 640,
353 | stride=1, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size)
354 | self.keep_prob = keep_prob
355 | self.keep_avg_pool = avg_pool
356 | self.dropout = nn.Dropout(p=1 - self.keep_prob, inplace=False)
357 | self.drop_rate = drop_rate
358 | self.feat_dim = [640, 10, 10]
359 |
360 | for m in self.modules():
361 | if isinstance(m, nn.Conv2d):
362 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
363 | elif isinstance(m, nn.BatchNorm2d):
364 | nn.init.constant_(m.weight, 1)
365 | nn.init.constant_(m.bias, 0)
366 |
367 | self.num_classes = num_classes
368 | if self.num_classes > 0:
369 | self.classifier = nn.Linear(640, self.num_classes)
370 |
371 | def _make_layer(self, block, n_block, planes, stride=1, drop_rate=0.0, drop_block=False, block_size=1):
372 | downsample = None
373 | if stride != 1 or self.inplanes != planes * block.expansion:
374 | downsample = nn.Sequential(
375 | nn.Conv2d(self.inplanes, planes * block.expansion,
376 | kernel_size=1, stride=1, bias=False),
377 | nn.BatchNorm2d(planes * block.expansion),
378 | )
379 |
380 | layers = []
381 | if n_block == 1:
382 | layer = block(self.inplanes, planes, stride, downsample, drop_rate, drop_block, block_size, self.use_se)
383 | else:
384 | layer = block(self.inplanes, planes, stride, downsample, drop_rate, use_se=self.use_se)
385 | layers.append(layer)
386 | self.inplanes = planes * block.expansion
387 |
388 | for i in range(1, n_block):
389 | if i == n_block - 1:
390 | layer = block(self.inplanes, planes, drop_rate=drop_rate, drop_block=drop_block,
391 | block_size=block_size, use_se=self.use_se)
392 | else:
393 | layer = block(self.inplanes, planes, drop_rate=drop_rate, use_se=self.use_se)
394 | layers.append(layer)
395 |
396 | return nn.Sequential(*layers)
397 |
398 | def forward(self, x, ):
399 | x = self.layer1(x)
400 | x = self.layer2(x)
401 | x = self.layer3(x)
402 | x = self.layer4(x)
403 | return x
404 |
405 | def ResNet12(keep_prob=1.0, avg_pool=True, **kwargs):
406 | """Constructs a ResNet-12 model.
407 | """
408 | model = resnet(BasicBlockVariant, [1, 1, 1, 1], keep_prob=keep_prob, avg_pool=avg_pool, **kwargs)
409 | return model
410 |
411 | def ResNet34s(keep_prob=1.0, avg_pool=False, **kwargs):
412 | """Constructs a ResNet-34s model.
413 | """
414 | model = resnet(BasicBlockVariant, [2, 3, 4, 2], keep_prob=keep_prob, avg_pool=avg_pool, **kwargs)
415 | return model
416 |
417 |
418 | if __name__ == '__main__':
419 | import argparse
420 |
421 | parser = argparse.ArgumentParser('argument for training')
422 | parser.add_argument('--model', type=str, default='resnet12',choices=['resnet12', 'resnet18', 'resnet24', 'resnet50', 'resnet101',
423 | 'seresnet12', 'seresnet18', 'seresnet24', 'seresnet50',
424 | 'seresnet101'])
425 | args = parser.parse_args()
426 |
427 | model_dict = {
428 | 'resnet12': ResNet12,
429 | }
430 |
431 | model = model_dict[args.model](avg_pool=True, drop_rate=0.1, dropblock_size=5, num_classes=64)
432 | data = torch.randn(2, 3, 84, 84)
433 | model = model.cuda()
434 | data = data.cuda()
435 | out = model(data)  # forward() above returns only the final feature map
436 | print(out.shape)  # torch.Size([2, 640, 10, 10]) for 84x84 inputs
--------------------------------------------------------------------------------
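A short usage sketch (not from the repository) for the DropBlock module defined in resnet.py above. It mirrors how BasicBlockVariant.forward drives it: gamma is derived from drop_rate, block_size and the feature-map size, and the block mask is only applied in training mode. It assumes a CUDA device, since DropBlock samples and pads its mask on the GPU, and that the repository root is on PYTHONPATH.

    import torch
    from network.resnet import DropBlock

    drop_rate, block_size = 0.1, 5
    feat = torch.randn(4, 640, 10, 10).cuda()   # a ResNet-12 style feature map
    feat_size = feat.size(2)

    # same gamma expression as in BasicBlockVariant.forward, with keep_rate fully annealed
    keep_rate = 1.0 - drop_rate
    gamma = (1 - keep_rate) / block_size ** 2 * feat_size ** 2 / (feat_size - block_size + 1) ** 2

    drop = DropBlock(block_size=block_size)
    drop.train()                                # DropBlock is an identity map in eval mode
    out = drop(feat, gamma=gamma)               # zeroes contiguous blocks, rescales the rest
    print(out.shape)                            # torch.Size([4, 640, 10, 10])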
/methods/FeatWalk.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 |
5 | import tqdm
6 | from sklearn.pipeline import make_pipeline
7 | from sklearn.preprocessing import StandardScaler
8 | from sklearn.svm import SVC, LinearSVC
9 | from methods.bdc_module import BDC
10 | import torch.nn.functional as F
11 |
12 | sys.path.append("..")
13 | import scipy
14 | from scipy.stats import t
15 | import network.resnet as resnet
16 | from utils.loss import *
17 | from sklearn.linear_model import LogisticRegression as LR
18 | from utils.loss import DistillKL
19 | from utils.utils import *
20 | import math
21 | from torch.nn.utils.weight_norm import WeightNorm
22 |
23 | import warnings
24 | warnings.filterwarnings("ignore")
25 |
26 |
27 | def mean_confidence_interval(data, confidence=0.95, multi=1):
28 | a = 1.0 * np.array(data)
29 | n = len(a)
30 | m, se = np.mean(a), scipy.stats.sem(a)
31 | h = se * t.ppf((1 + confidence) / 2., n - 1)  # public t.ppf instead of the private t._ppf
32 | return m * multi, h * multi
33 |
34 | def normalize(x):
35 | norm = x.pow(2).sum(1, keepdim=True).pow(1. / 2)
36 | out = x.div(norm)
37 | return out
38 |
39 | def random_sample(linspace, max_idx, num_sample=5):
40 | sample_idx = np.random.choice(range(linspace), num_sample)
41 | sample_idx += np.sort(random.sample(list(range(0, max_idx, linspace)),num_sample))
42 | return sample_idx
43 |
44 | def Triuvec(x,no_diag = False):
45 | batchSize, dim, dim = x.shape
46 | r = x.reshape(batchSize, dim * dim)
47 | I = torch.ones(dim, dim).triu()
48 | if no_diag:
49 | I -= torch.eye(dim,dim)
50 | I = I.reshape(dim * dim)
51 | index = I.nonzero(as_tuple = False)
52 | # y = torch.zeros(batchSize, int(dim * (dim + 1) / 2), device=x.device).type(x.dtype)
53 | y = r[:, index].squeeze()
54 | return y
55 |
56 | def Triumap(x,no_diag = False):
57 |
58 | batchSize, dim, dim, h, w = x.shape
59 | r = x.reshape(batchSize, dim * dim, h, w)
60 | I = torch.ones(dim, dim).triu()
61 | if no_diag:
62 | I -= torch.eye(dim,dim)
63 | I = I.reshape(dim * dim)
64 | index = I.nonzero(as_tuple = False)
65 | # y = torch.zeros(batchSize, int(dim * (dim + 1) / 2), device=x.device).type(x.dtype)
66 | y = r[:, index, :, :].squeeze()
67 | return y
68 |
69 | def Diagvec(x):
70 | batchSize, dim, dim = x.shape
71 | r = x.reshape(batchSize, dim * dim)
72 | I = torch.eye(dim, dim).triu().reshape(dim * dim)
73 | index = I.nonzero(as_tuple = False)
74 | y = r[:, index].squeeze()
75 | return y
76 |
77 | class FeatWalk_Net(nn.Module):
78 | def __init__(self,params,num_classes = 5,):
79 | super(FeatWalk_Net, self).__init__()
80 |
81 | self.params = params
82 |
83 | if params.model == 'resnet12':
84 | self.feature = resnet.ResNet12(avg_pool=True,num_classes=64)
85 | resnet_layer_dim = [64, 160, 320, 640]
86 | elif params.model == 'resnet18':
87 | self.feature = resnet.ResNet18()
88 | resnet_layer_dim = [64, 128, 256, 512]
89 |
90 | self.resnet_layer_dim = resnet_layer_dim
91 | self.reduce_dim = params.reduce_dim
92 | self.feat_dim = self.feature.feat_dim
93 | self.dim = int(self.reduce_dim * (self.reduce_dim+1)/2)
94 | if resnet_layer_dim[-1] != self.reduce_dim:
95 |
96 | self.Conv = nn.Sequential(
97 | nn.Conv2d(resnet_layer_dim[-1], self.reduce_dim, kernel_size=1, stride=1, bias=False),
98 | nn.BatchNorm2d(self.reduce_dim),
99 | nn.ReLU(inplace=True)
100 | )
101 | self._init_weight(self.Conv.modules())
102 |
103 | drop_rate = params.drop_rate
104 | if self.params.embeding_way in ['BDC']:
105 | self.SFC = nn.Linear(self.dim, num_classes)
106 | self.SFC.bias.data.fill_(0)
107 | elif self.params.embeding_way in ['baseline++']:
108 | self.SFC = nn.Linear(self.reduce_dim, num_classes, bias=False)
109 | WeightNorm.apply(self.SFC, 'weight', dim=0)
110 | else:
111 | self.SFC = nn.Linear(self.reduce_dim, num_classes)
112 |
113 | self.drop = nn.Dropout(drop_rate)
114 |
115 | self.temperature = nn.Parameter(torch.log((1. /(2 * self.feat_dim[1] * self.feat_dim[2])* torch.ones(1, 1))),
116 | requires_grad=True)
117 |
118 | self.dcov = BDC(is_vec=True, input_dim=[self.reduce_dim,self.feature.feat_dim[1],self.feature.feat_dim[2]], dimension_reduction=self.reduce_dim)
119 |
120 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
121 |
122 | if resnet_layer_dim[-1] != self.reduce_dim:
123 | self.dcov.conv_dr_block = self.Conv
124 |
125 | self.n_shot = params.n_shot
126 | self.n_way = params.n_way
127 | self.transform_aug = params.n_aug_support_samples
128 |
129 | def _init_weight(self,modules):
130 | for m in modules:
131 | if isinstance(m, nn.Conv2d):
132 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out', nonlinearity='leaky_relu')
133 | elif isinstance(m, nn.BatchNorm2d):
134 | nn.init.constant_(m.weight, 1)
135 | nn.init.constant_(m.bias, 0)
136 |
137 | def normalize(self,x):
138 | x = (x - torch.mean(x, dim=1).unsqueeze(1))
139 | return x
140 |
141 | def forward_feature(self, x):
142 | feat_map = self.feature(x, )
143 | if self.resnet_layer_dim[-1] != self.reduce_dim:
144 | feat_map = self.Conv(feat_map)
145 | out = feat_map
146 | return out
147 |
148 | def normalize_feature(self, x):
149 | if self.params.norm == 'center':
150 | x = x - x.mean(2).unsqueeze(2)
151 | return x
152 | else:
153 | return x
154 |
155 | def forward_pretrain(self, x):
156 | x = self.forward_feature(x)
157 | x = self.drop(x)
158 | return self.SFC(x)
159 |
160 | def train_loop(self,epoch,train_loader,optimizer):
161 | print_step = 100
162 | avg_loss = 0
163 | total_correct = 0
164 | iter_num = len(train_loader)
165 | total = 0
166 | loss_ce_fn = nn.CrossEntropyLoss()
167 | for i ,data in enumerate(train_loader):
168 | image , label = data
169 | image = image.cuda()
170 | label = label.cuda()
171 | out = self.forward_pretrain(image)
172 | loss = loss_ce_fn(out, label)
173 | avg_loss = avg_loss + loss.item()
174 | optimizer.zero_grad()
175 | loss.backward()
176 | optimizer.step()
177 | _, pred = torch.max(out, 1)
178 | correct = (pred == label).sum().item()
179 | total_correct += correct
180 | total += label.size(0)
181 | if i % print_step == 0:
182 | print('\rEpoch {:d} | Batch: {:d}/{:d} | Loss: {:.4f} | Acc_train: {:.2f}'.format(epoch, i, len(train_loader),
183 | avg_loss / float(i + 1),correct/label.shape[0]*100), end=' ')
184 | print()
185 |
186 | return avg_loss / iter_num, float(total_correct) / total * 100
187 |
188 | def meta_val_loop(self,val_loader):
189 | acc = []
190 | for i, data in enumerate(val_loader):
191 |
192 | support_xs, support_ys, query_xs, query_ys = data
193 | support_xs = support_xs.cuda()
194 | query_xs = query_xs.cuda()
195 | split_size = 128
196 | if support_xs.squeeze(0).shape[0] >= split_size:
197 | feat_sup_ = []
198 | for j in range(math.ceil(support_xs.squeeze(0).shape[0] / split_size)):
199 | feat_sup_item = self.forward_feature(
200 | support_xs.squeeze(0)[j * split_size:min((j + 1) * split_size, support_xs.shape[1]), :, :, :],)
201 | feat_sup_.append(feat_sup_item if len(feat_sup_item.shape) >= 1 else feat_sup_item.unsqueeze(0))
202 | feat_sup = torch.cat(feat_sup_, dim=0)
203 | else:
204 | feat_sup = self.forward_feature(support_xs.squeeze(0),)
205 | if query_xs.squeeze(0).shape[0] > split_size:
206 | feat_qry_ = []
207 | for j in range(math.ceil(query_xs.squeeze(0).shape[0] / split_size)):
208 | feat_qry_item = self.forward_feature(
209 | query_xs.squeeze(0)[j * split_size:min((j + 1) * split_size, query_xs.shape[1]), :, :, :],
210 | )
211 | feat_qry_.append(feat_qry_item if len(feat_qry_item.shape) > 1 else feat_qry_item.unsqueeze(0))
212 |
213 | feat_qry = torch.cat(feat_qry_, dim=0)
214 | else:
215 | feat_qry = self.forward_feature(query_xs.squeeze(0),)
216 | if self.params.LR:
217 | pred = self.LR(feat_sup, support_ys, feat_qry, query_ys)
218 | else:
219 | with torch.enable_grad():
220 | pred = self.softmax(feat_sup, support_ys, feat_qry, )
221 | _, pred = torch.max(pred, dim=-1)
222 | if self.params.n_symmetry_aug > 1:
223 | query_ys = query_ys.view(-1, self.params.n_symmetry_aug)
224 | query_ys = torch.mode(query_ys, dim=-1)[0]
225 | acc_epo = np.mean(pred.cpu().numpy() == query_ys.numpy())
226 | acc.append(acc_epo)
227 | return mean_confidence_interval(acc)
228 |
229 | def meta_test_loop(self,test_loader):
230 | acc = []
231 | for i, (x, _) in enumerate(test_loader):
232 | self.params.n_aug_support_samples = self.transform_aug
233 | tic = time.time()
234 | x = x.contiguous().view(self.n_way, (self.n_shot + self.params.n_queries), *x.size()[2:])
235 | support_xs = x[:, :self.n_shot].contiguous().view(
236 | self.n_way * self.n_shot * self.params.n_aug_support_samples, *x.size()[3:]).cuda()
237 | query_xs = x[:, self.n_shot:, 0:self.params.n_symmetry_aug].contiguous().view(
238 | self.n_way * self.params.n_queries * self.params.n_symmetry_aug, *x.size()[3:]).cuda()
239 |
240 | support_y = torch.from_numpy(np.repeat(range(self.params.n_way),self.n_shot*self.params.n_aug_support_samples)).unsqueeze(0)
241 | split_size = 128
242 | if support_xs.shape[0] >= split_size:
243 | feat_sup_ = []
244 | for j in range(math.ceil(support_xs.shape[0]/split_size)):
245 | feat_sup_item = self.forward_feature(support_xs[j*split_size:min((j+1)*split_size,support_xs.shape[0]),],)
246 | feat_sup_.append(feat_sup_item if len(feat_sup_item.shape)>=1 else feat_sup_item.unsqueeze(0))
247 | feat_sup = torch.cat(feat_sup_,dim=0)
248 | else:
249 | feat_sup = self.forward_feature(support_xs)
250 | if query_xs.shape[0] >= split_size:
251 | feat_qry_ = []
252 | for j in range(math.ceil(query_xs.shape[0]/split_size)):
253 | feat_qry_item = self.forward_feature(
254 | query_xs[j * split_size:min((j + 1) * split_size, query_xs.shape[0]), ],)
255 | feat_qry_.append(feat_qry_item if len(feat_qry_item.shape) > 1 else feat_qry_item.unsqueeze(0))
256 |
257 | feat_qry = torch.cat(feat_qry_,dim=0)
258 | else:
259 | feat_qry = self.forward_feature(query_xs,)
260 |
261 | if self.params.LR:
262 | pred = self.predict_wo_fc(feat_sup, support_y, feat_qry,)
263 |
264 | else:
265 | with torch.enable_grad():
266 | pred = self.softmax(feat_sup, support_y, feat_qry,)
267 | _,pred = torch.max(pred,dim=-1)
268 |
269 | query_ys = np.repeat(range(self.n_way), self.params.n_queries)
270 | pred = pred.view(-1)
271 | acc_epo = np.mean(pred.cpu().numpy() == query_ys)
272 | acc.append(acc_epo)
273 | print("\repisode {} acc: {:.2f} | avg_acc: {:.2f} +- {:.2f}, elapse : {:.2f}".format(i, acc_epo * 100,
274 | *mean_confidence_interval(
275 | acc, multi=100), (
276 | time.time() - tic) / 60),
277 | end='')
278 |
279 | return mean_confidence_interval(acc)
280 |
281 | def distillation(self,epoch,train_loader,optimizer,model_t):
282 | print_step = 100
283 | avg_loss = 0
284 | total_correct = 0
285 | iter_num = len(train_loader)
286 | total = 0
287 | loss_div_fn = DistillKL(4)
288 | loss_ce_fn = nn.CrossEntropyLoss()
289 | for i, data in enumerate(train_loader):
290 | image, label = data
291 | image = image.cuda()
292 | label = label.cuda()
293 | with torch.no_grad():
294 | out_t = model_t.forward_pretrain(image)
295 |
296 | out= self.forward_pretrain(image)
297 | loss_ce = loss_ce_fn(out, label)
298 | loss_div = loss_div_fn(out, out_t)
299 |
300 | loss = loss_ce * 0.5 + loss_div * 0.5
301 | avg_loss = avg_loss + loss.item()
302 | optimizer.zero_grad()
303 | loss.backward()
304 | optimizer.step()
305 |
306 | _, pred = torch.max(out, 1)
307 | correct = (pred == label).sum().item()
308 | total_correct += correct
309 | total += label.size(0)
310 | if i % print_step == 0:
311 | print('\rEpoch {:d} | Batch: {:d}/{:d} | Loss: {:.4f} | Acc_train: {:.2f}'.format(epoch, i,
312 | len(train_loader),
313 | avg_loss / float(
314 | i + 1),
315 | correct / label.shape[
316 | 0] * 100),
317 | end=' ')
318 | print()
319 | return avg_loss / iter_num, float(total_correct) / total * 100
320 |
321 | # new selective local fusion :
322 | def softmax(self,support_z,support_ys,query_z,):
323 | loss_ce_fn = nn.CrossEntropyLoss()
324 | batch_size = self.params.sfc_bs
325 | walk_times = 24
326 | alpha = self.params.alpha
327 | tempe = self.params.sim_temperature
328 | support_ys = support_ys.cuda()
329 |
330 | if self.params.embeding_way in ['BDC']:
331 | SFC = nn.Linear(self.dim, self.params.n_way).cuda()
332 | iter_num = 100
333 | optimizer = torch.optim.AdamW([{'params': SFC.parameters()}], lr=0.001,
334 | weight_decay=self.params.wd_test,eps=1e-4)
335 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, iter_num * math.ceil(self.n_way*self.n_shot/batch_size),eta_min=1e-3)
336 |
337 |
338 | else:
339 | tempe =16
340 |
341 | if self.params.embeding_way in ['baseline++']:
342 | SFC = nn.Linear(self.reduce_dim, self.params.n_way, bias=False).cuda()
343 | WeightNorm.apply(SFC, 'weight', dim=0)
344 | else:
345 | SFC = nn.Linear(self.reduce_dim, self.params.n_way).cuda()
346 |
347 | if self.params.optim in ['Adam']:
348 | # lr = 5e-3
349 | optimizer = torch.optim.AdamW([{'params': SFC.parameters()}], lr=0.005,
350 | weight_decay=self.params.wd_test, eps=5e-3)
351 |
352 | iter_num = 100
353 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, iter_num * math.ceil(
354 | self.n_way * self.n_shot / batch_size), eta_min=5e-3)
355 |
356 |
357 | else:
358 | optimizer = torch.optim.SGD([{'params': SFC.parameters()}],
359 | lr=self.params.lr, momentum=0.9, nesterov=True,
360 | weight_decay=self.params.wd_test)
361 |
362 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 150], gamma=0.1)
363 | iter_num = 180
364 |
365 |
366 | SFC.train()
367 |
368 | if self.params.embeding_way in ['BDC']:
369 | support_z = self.dcov(support_z)
370 | query_z = self.dcov(query_z)
371 |
372 | else:
373 | support_z = self.avg_pool(support_z).view(support_z.shape[0], -1)
374 | query_z = self.avg_pool(query_z).view(query_z.shape[0], -1)
375 |
376 | support_ys = support_ys.view(self.n_way * self.n_shot, self.params.n_aug_support_samples, -1)
377 | global_ys = support_ys[:, 0, :]
378 |
379 | support_z = support_z.reshape(self.n_way,self.n_shot,self.params.n_aug_support_samples,-1)
380 | query_z = query_z.reshape(self.n_way,self.params.n_queries,self.params.n_aug_support_samples,-1)
381 |
382 | # the first augmented view (index 0) serves as the global feature; the remaining views are local crops
383 | feat_q = query_z[:,:,0]
384 | feat_ql = query_z[:,:,1:]
385 | feat_g = support_z[:,:,0]
386 | feat_sl = support_z[:,:,1:]
387 | # w_local: n * k * n * m
388 | num_sample = self.n_way*self.n_shot
389 | global_ys = global_ys.view(self.n_way,self.n_shot,-1)
390 |
391 | feat_g = feat_g.detach()
392 | feat_sl = feat_sl.detach()
393 |
394 | # feat_sl: n * k * n * dim
395 | I = torch.eye(self.n_way,self.n_way,device=feat_g.device).unsqueeze(0).unsqueeze(1)
396 | proto_moving = torch.mean(feat_g, dim=1)
397 |
398 | # Each iteration: sample `walk_times` local crops, fuse them with the global support feature
399 | # (mixing weight alpha), update the class prototypes by EMA, and fit the SFC head on the fused features.
400 | for i in range(iter_num):
401 | weight = compute_weight_local(proto_moving.unsqueeze(1), feat_sl, feat_sl, self.params.measure)
402 | idx_walk = torch.randperm(self.params.n_aug_support_samples-1,)[:walk_times]
403 | w_local = F.softmax(weight[:,:,:,idx_walk] * tempe, dim=-1)
404 | feat_s = torch.sum((feat_sl[:,:,idx_walk,:].unsqueeze(-3)) * (w_local.detach().unsqueeze(-1)), dim=-2)
405 | support_x = alpha * feat_g.unsqueeze(-2) + (1- alpha) * feat_s
406 | proto_update = torch.sum(torch.matmul(torch.mean(support_x,dim=1).transpose(1,2),torch.eye(self.n_way,device=proto_moving.device).unsqueeze(0)),dim=-1)
407 | proto_moving = 0.9 * proto_moving + 0.1 * proto_update
408 | spt_norm = torch.norm(support_x, p=2, dim=-1).unsqueeze(-1).expand_as(support_x)
409 | support_x = support_x.div(spt_norm + 1e-6)
410 |
411 |
412 | SFC.train()
413 | sample_idxs = torch.randperm(num_sample)
414 | for j in range(math.ceil(num_sample/batch_size)):
415 | idxs = sample_idxs[j*batch_size:min((j+1)*batch_size,num_sample)]
416 | x = support_x[idxs//self.n_shot,idxs%self.n_shot]
417 | y = global_ys[idxs//self.n_shot,idxs%self.n_shot]
418 | x = self.drop(x)
419 | # out = torch.sum(SFC(x)*I,dim=-1).view(-1,self.n_way)
420 | out = torch.sum(x.mul(SFC.weight),dim=-1) + SFC.bias
421 | loss_ce = loss_ce_fn(out,y.long().view(-1))
422 | loss = loss_ce
423 | optimizer.zero_grad()
424 | loss.backward()
425 | optimizer.step()
426 | if lr_scheduler is not None:
427 | lr_scheduler.step()
428 |
429 | SFC.eval()
430 | # Inference: fuse each query's local crops with softmax weights and score the L2-normalized features with the trained SFC
431 | w_local = compute_weight_local(proto_moving.unsqueeze(1), feat_ql, feat_sl,self.params.measure)
432 | w_local = F.softmax(w_local * tempe, dim=-1)
433 |
434 | # feat_sl: n * k * n * dim
435 | feat_lq = torch.sum(feat_ql.unsqueeze(-3) * w_local.unsqueeze(-1), dim=-2)
436 | query_x = alpha * feat_q.unsqueeze(-2) + (1- alpha) * feat_lq
437 |
438 | spt_norm = torch.norm(query_x, p=2, dim=-1).unsqueeze(-1).expand_as(query_x)
439 | query_x = query_x.div(spt_norm + 1e-6)
440 |
441 | with torch.no_grad():
442 | # out = torch.sum(SFC(query_x)*I,dim=-1).view(-1,self.n_way)
443 | out = torch.sum(query_x.mul(SFC.weight), dim=-1) + SFC.bias
444 |
445 | return out
446 |
447 | def LR(self,support_z,support_ys,query_z,query_ys):
448 |
449 | clf = LR(penalty='l2',
450 | random_state=0,
451 | C=self.params.penalty_c,
452 | solver='lbfgs',
453 | max_iter=1000,
454 | multi_class='multinomial')
455 |
456 | spt_norm = torch.norm(support_z, p=2, dim=1).unsqueeze(1).expand_as(support_z)
457 | spt_normalized = support_z.div(spt_norm + 1e-6)
458 |
459 | qry_norm = torch.norm(query_z, p=2, dim=1).unsqueeze(1).expand_as(query_z)
460 | qry_normalized = query_z.div(qry_norm + 1e-6)
461 |
462 | z_support = spt_normalized.detach().cpu().numpy()
463 | z_query = qry_normalized.detach().cpu().numpy()
464 |
465 | y_support = np.repeat(range(self.params.n_way), self.n_shot)
466 |
467 | clf.fit(z_support, y_support)
468 |
469 | return torch.from_numpy(clf.predict(z_query))
470 |
--------------------------------------------------------------------------------
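For reference, a small sketch (not part of the repository) of the Triuvec helper defined near the top of FeatWalk.py. It shows why FeatWalk_Net sets self.dim = reduce_dim * (reduce_dim + 1) / 2: the BDC matrix is symmetric, so only its upper triangle (diagonal included) is kept as the feature vector. It assumes the repository root is on PYTHONPATH and that the repo's dependencies (scipy, scikit-learn) are installed, since importing methods.FeatWalk pulls them in.

    import torch
    from methods.FeatWalk import Triuvec

    d = 4
    x = torch.randn(2, d, d)
    x = x @ x.transpose(1, 2)     # symmetric d x d matrices, standing in for BDC outputs
    v = Triuvec(x)
    print(v.shape)                # torch.Size([2, 10]), i.e. (batch, d * (d + 1) // 2)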