├── model
├── __init__.py
├── trainer
│ ├── __init__.py
│ ├── base.py
│ ├── helpers.py
│ └── fsl_trainer.py
├── models
│ ├── utils
│ │ ├── __init__.py
│ │ ├── embedder.py
│ │ ├── stochastic_depth.py
│ │ ├── tokenizer.py
│ │ └── transformers.py
│ ├── __init__.py
│ ├── base.py
│ ├── protonet.py
│ ├── fcanet.py
│ ├── INSTA_ProtoNet.py
│ └── INSTA.py
├── networks
│ ├── __init__.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── embedder.py
│ │ ├── stochastic_depth.py
│ │ ├── tokenizer.py
│ │ └── transformers.py
│ ├── dropblock.py
│ ├── res12.py
│ ├── res18.py
│ └── res10.py
├── logger.py
├── dataloader
│ ├── split_cub.py
│ ├── samplers.py
│ ├── transforms.py
│ ├── mini_imagenet.py
│ ├── cub.py
│ └── tiered_imagenet.py
├── data_parallel.py
└── utils.py
├── data
├── cub
│ └── .gitignore
└── miniimagenet
│ ├── .gitignore
│ └── download.sh
├── visual
├── concept.png
├── heatmap.png
└── pipeline.png
├── .idea
├── misc.xml
├── vcs.xml
├── .gitignore
├── inspectionProfiles
│ ├── profiles_settings.xml
│ └── Project_Default.xml
├── modules.xml
├── code.iml
├── remote-mappings.xml
└── deployment.xml
├── train_fsl.py
├── LICENSE
└── README.md

/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/model/trainer/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/cub/.gitignore:
--------------------------------------------------------------------------------
1 | images
2 |
--------------------------------------------------------------------------------
/model/models/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/model/networks/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/model/networks/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/miniimagenet/.gitignore:
--------------------------------------------------------------------------------
1 | images
2 |
--------------------------------------------------------------------------------
/model/models/__init__.py:
--------------------------------------------------------------------------------
1 | from model.models.base import FewShotModel_1
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/visual/concept.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RongKaiWeskerMA/INSTA/HEAD/visual/concept.png
--------------------------------------------------------------------------------
/visual/heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RongKaiWeskerMA/INSTA/HEAD/visual/heatmap.png
--------------------------------------------------------------------------------
/visual/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RongKaiWeskerMA/INSTA/HEAD/visual/pipeline.png
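
Before the .idea project files and the remaining sources, a short illustrative sketch (not a file from this repository) of how a scoring head plugs into the layout above. model/models/base.py below defines FewShotModel_1, which builds the Res12/Res18 encoder from args.backbone_class, splits each episode into support and query indices, and delegates scoring to a subclass-implemented _forward that returns (logits, logits_reg) during training and plain logits otherwise. The class name ToyProtoHead and the cdist-based scoring are assumptions made for illustration; model/models/protonet.py and the INSTA variants are the actual implementations.

import torch

from model.models import FewShotModel_1


class ToyProtoHead(FewShotModel_1):
    """Illustrative subclass: nearest-prototype scoring over flattened features."""

    def _forward(self, instance_embs, support_idx, query_idx):
        # The encoder may return a spatial feature map (see res12.py), so flatten
        # each episode image into a single embedding vector first.
        instance_embs = instance_embs.flatten(1)
        emb_dim = instance_embs.size(-1)
        support = instance_embs[support_idx.flatten()].view(*(support_idx.shape + (-1,)))
        query = instance_embs[query_idx.flatten()].view(*(query_idx.shape + (-1,)))
        proto = support.mean(dim=1).squeeze(0)                # (way, emb_dim)
        dists = torch.cdist(query.view(-1, emb_dim), proto)   # (n_query, way)
        logits = -dists
        return (logits, None) if self.training else logits
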
-------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Datasource local storage ignored files 5 | /dataSources/ 6 | /dataSources.local.xml 7 | # Editor-based HTTP Client requests 8 | /httpRequests/ 9 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /data/miniimagenet/download.sh: -------------------------------------------------------------------------------- 1 | 2 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1BCxmqLANXHbBaWs8A7_jqfVUv8mydp5R' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1BCxmqLANXHbBaWs8A7_jqfVUv8mydp5R" -O miniimagenet.zip && rm -rf /tmp/cookies.txt 3 | 4 | unzip miniimagenet.zip miniimagenet/ 5 | -------------------------------------------------------------------------------- /.idea/code.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /.idea/remote-mappings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /train_fsl.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from model.trainer.fsl_trainer import FSLTrainer 4 | from model.utils import ( 5 | pprint, set_gpu, 6 | get_command_line_parser, 7 | postprocess_args, 8 | ) 9 | # from ipdb import launch_ipdb_on_exception 10 | 11 | if __name__ == '__main__': 12 | parser = get_command_line_parser() 13 | args = postprocess_args(parser.parse_args()) 14 | # with launch_ipdb_on_exception(): 15 | pprint(vars(args)) 16 | 17 | set_gpu(args.gpu) 18 | trainer = FSLTrainer(args) 19 | trainer.train() 20 | trainer.evaluate_test() 21 | trainer.final_record() 22 | print(args.save_path) 23 | -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 20 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 RongKaiWeskerMA 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /model/models/utils/embedder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class Embedder(nn.Module): 5 | def __init__(self, 6 | word_embedding_dim=300, 7 | vocab_size=100000, 8 | padding_idx=1, 9 | pretrained_weight=None, 10 | embed_freeze=False, 11 | *args, **kwargs): 12 | super(Embedder, self).__init__() 13 | self.embeddings = nn.Embedding.from_pretrained(pretrained_weight, freeze=embed_freeze) \ 14 | if pretrained_weight is not None else \ 15 | nn.Embedding(vocab_size, word_embedding_dim, padding_idx=padding_idx) 16 | self.embeddings.weight.requires_grad = not embed_freeze 17 | 18 | def forward_mask(self, mask): 19 | bsz, seq_len = mask.shape 20 | new_mask = mask.view(bsz, seq_len, 1) 21 | new_mask = new_mask.sum(-1) 22 | new_mask = (new_mask > 0) 23 | return new_mask 24 | 25 | def forward(self, x, mask=None): 26 | embed = self.embeddings(x) 27 | embed = embed if mask is None else embed * self.forward_mask(mask).unsqueeze(-1).float() 28 | return embed, mask 29 | 30 | @staticmethod 31 | def init_weight(m): 32 | if isinstance(m, nn.Linear): 33 | nn.init.trunc_normal_(m.weight, std=.02) 34 | if isinstance(m, nn.Linear) and m.bias is not None: 35 | nn.init.constant_(m.bias, 0) 36 | else: 37 | nn.init.normal_(m.weight) 38 | -------------------------------------------------------------------------------- /model/networks/utils/embedder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class Embedder(nn.Module): 5 | def __init__(self, 6 | word_embedding_dim=300, 7 | vocab_size=100000, 8 | padding_idx=1, 9 | pretrained_weight=None, 10 | embed_freeze=False, 11 | *args, **kwargs): 12 | super(Embedder, self).__init__() 13 | self.embeddings = nn.Embedding.from_pretrained(pretrained_weight, freeze=embed_freeze) \ 14 | if pretrained_weight is not None else \ 15 | nn.Embedding(vocab_size, word_embedding_dim, padding_idx=padding_idx) 16 | self.embeddings.weight.requires_grad = not embed_freeze 17 | 18 | def forward_mask(self, mask): 19 | bsz, seq_len = mask.shape 20 | new_mask = mask.view(bsz, seq_len, 1) 21 | new_mask = new_mask.sum(-1) 22 | 
new_mask = (new_mask > 0) 23 | return new_mask 24 | 25 | def forward(self, x, mask=None): 26 | embed = self.embeddings(x) 27 | embed = embed if mask is None else embed * self.forward_mask(mask).unsqueeze(-1).float() 28 | return embed, mask 29 | 30 | @staticmethod 31 | def init_weight(m): 32 | if isinstance(m, nn.Linear): 33 | nn.init.trunc_normal_(m.weight, std=.02) 34 | if isinstance(m, nn.Linear) and m.bias is not None: 35 | nn.init.constant_(m.bias, 0) 36 | else: 37 | nn.init.normal_(m.weight) 38 | -------------------------------------------------------------------------------- /model/models/utils/stochastic_depth.py: -------------------------------------------------------------------------------- 1 | # Thanks to rwightman's timm package 2 | # github.com:rwightman/pytorch-image-models 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | def drop_path(x, drop_prob: float = 0., training: bool = False): 9 | """ 10 | Obtained from: github.com:rwightman/pytorch-image-models 11 | Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 12 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 13 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 14 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 15 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 16 | 'survival rate' as the argument. 17 | """ 18 | if drop_prob == 0. or not training: 19 | return x 20 | keep_prob = 1 - drop_prob 21 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 22 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 23 | random_tensor.floor_() # binarize 24 | output = x.div(keep_prob) * random_tensor 25 | return output 26 | 27 | 28 | class DropPath(nn.Module): 29 | """ 30 | Obtained from: github.com:rwightman/pytorch-image-models 31 | Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 32 | """ 33 | 34 | def __init__(self, drop_prob=None): 35 | super(DropPath, self).__init__() 36 | self.drop_prob = drop_prob 37 | 38 | def forward(self, x): 39 | return drop_path(x, self.drop_prob, self.training) 40 | -------------------------------------------------------------------------------- /model/networks/utils/stochastic_depth.py: -------------------------------------------------------------------------------- 1 | # Thanks to rwightman's timm package 2 | # github.com:rwightman/pytorch-image-models 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | def drop_path(x, drop_prob: float = 0., training: bool = False): 9 | """ 10 | Obtained from: github.com:rwightman/pytorch-image-models 11 | Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 12 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 13 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 14 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 15 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 16 | 'survival rate' as the argument. 17 | """ 18 | if drop_prob == 0. 
or not training: 19 | return x 20 | keep_prob = 1 - drop_prob 21 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 22 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 23 | random_tensor.floor_() # binarize 24 | output = x.div(keep_prob) * random_tensor 25 | return output 26 | 27 | 28 | class DropPath(nn.Module): 29 | """ 30 | Obtained from: github.com:rwightman/pytorch-image-models 31 | Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 32 | """ 33 | 34 | def __init__(self, drop_prob=None): 35 | super(DropPath, self).__init__() 36 | self.drop_prob = drop_prob 37 | 38 | def forward(self, x): 39 | return drop_path(x, self.drop_prob, self.training) 40 | -------------------------------------------------------------------------------- /model/logger.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os.path as osp 3 | import numpy as np 4 | from collections import defaultdict, OrderedDict 5 | from tensorboardX import SummaryWriter 6 | 7 | class ConfigEncoder(json.JSONEncoder): 8 | def default(self, o): 9 | if isinstance(o, type): 10 | return {'$class': o.__module__ + "." + o.__name__} 11 | elif isinstance(o, Enum): 12 | return { 13 | '$enum': o.__module__ + "." + o.__class__.__name__ + '.' + o.name 14 | } 15 | elif callable(o): 16 | return { 17 | '$function': o.__module__ + "." + o.__name__ 18 | } 19 | return json.JSONEncoder.default(self, o) 20 | 21 | class Logger(object): 22 | def __init__(self, args, log_dir, **kwargs): 23 | self.logger_path = osp.join(log_dir, 'scalars.json') 24 | self.tb_logger = SummaryWriter( 25 | logdir=osp.join(log_dir, 'tflogger'), 26 | **kwargs, 27 | ) 28 | self.log_config(vars(args)) 29 | 30 | self.scalars = defaultdict(OrderedDict) 31 | 32 | def add_scalar(self, key, value, counter): 33 | assert self.scalars[key].get(counter, None) is None, 'counter should be distinct' 34 | self.scalars[key][counter] = value 35 | self.tb_logger.add_scalar(key, value, counter) 36 | 37 | def log_config(self, variant_data): 38 | config_filepath = osp.join(osp.dirname(self.logger_path), 'configs.json') 39 | with open(config_filepath, "w") as fd: 40 | json.dump(variant_data, fd, indent=2, sort_keys=True, cls=ConfigEncoder) 41 | 42 | def dump(self): 43 | with open(self.logger_path, 'w') as fd: 44 | json.dump(self.scalars, fd, indent=2) -------------------------------------------------------------------------------- /model/dataloader/split_cub.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import argparse 4 | from os import listdir 5 | from os.path import isfile, isdir, join 6 | import random 7 | 8 | if __name__ == '__main__': 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--data', type=str, help='path to the data') 11 | parser.add_argument('--split', type=str, help='path to the split folder') 12 | args = parser.parse_args() 13 | dataset_list = ['train','val','test'] 14 | # 15 | prex1 = args.data 16 | data_path = join(prex1,'images/') 17 | # 18 | folder_list = [f for f in listdir(data_path) if isdir(join(data_path, f))] 19 | folder_list.sort() 20 | label_dict = dict(zip(folder_list,range(0,len(folder_list)))) 21 | 22 | classfile_list_all = [] 23 | 24 | for i, folder in enumerate(folder_list): 25 | folder_path = join(data_path, folder) 26 | classfile_list_all.append( [join(folder,cf) for cf in listdir(folder_path) if 
(isfile(join(folder_path,cf)) and cf[0] != '.')]) 27 | random.shuffle(classfile_list_all[i]) 28 | 29 | if not os.path.isdir(args.split): 30 | os.makedirs(args.split) 31 | 32 | 33 | for dataset in dataset_list: 34 | file_list = [] 35 | label_list = [] 36 | for i, classfile_list in enumerate(classfile_list_all): 37 | if 'train' in dataset: 38 | if (i%2 == 0): 39 | file_list = file_list + classfile_list 40 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist() 41 | 42 | if 'val' in dataset: 43 | if (i%4 == 1): 44 | file_list = file_list + classfile_list 45 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist() 46 | 47 | if 'test' in dataset: 48 | if (i%4 == 3): 49 | file_list = file_list + classfile_list 50 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist() -------------------------------------------------------------------------------- /model/models/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from sklearn.svm import LinearSVC 5 | from sklearn.linear_model import LogisticRegression 6 | from sklearn.model_selection import GridSearchCV 7 | 8 | 9 | class FewShotModel_1(nn.Module): 10 | def __init__(self, args): 11 | super().__init__() 12 | self.args = args 13 | # from model.models.ddf import DDF 14 | if args.backbone_class == 'Res12': 15 | hdim = 640 16 | from model.networks.res12 import ResNet 17 | self.encoder = ResNet() 18 | elif args.backbone_class == 'Res18': 19 | hdim = 512 20 | from model.networks.res18 import ResNet 21 | self.encoder = ResNet() 22 | else: 23 | raise ValueError('') 24 | 25 | def split_instances(self, data): 26 | args = self.args 27 | if self.training: 28 | return (torch.Tensor(np.arange(args.way*args.shot)).long().view(1, args.shot, args.way), 29 | torch.Tensor(np.arange(args.way*args.shot, args.way * (args.shot + args.query))).long().view(1, args.query, args.way)) 30 | else: 31 | return (torch.Tensor(np.arange(args.eval_way*args.eval_shot)).long().view(1, args.eval_shot, args.eval_way), 32 | torch.Tensor(np.arange(args.eval_way*args.eval_shot, args.eval_way * (args.eval_shot + args.eval_query))).long().view(1, args.eval_query, args.eval_way)) 33 | 34 | 35 | def forward(self, x, get_feature=False): 36 | if get_feature: 37 | # get feature with the provided embeddings 38 | return self.encoder(x) 39 | else: 40 | # feature extraction 41 | x = x.squeeze(0) 42 | instance_embs = self.encoder(x) 43 | 44 | support_idx, query_idx = self.split_instances(x) 45 | if self.training: 46 | logits, logits_reg = self._forward(instance_embs, support_idx, query_idx) 47 | return logits, logits_reg 48 | else: 49 | logits = self._forward(instance_embs, support_idx, query_idx) 50 | return logits 51 | 52 | def _forward(self, x, support_idx, query_idx): 53 | raise NotImplementedError('Suppose to be implemented by subclass') -------------------------------------------------------------------------------- /model/networks/dropblock.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torch.distributions import Bernoulli 5 | 6 | 7 | class DropBlock(nn.Module): 8 | def __init__(self, block_size): 9 | super(DropBlock, self).__init__() 10 | 11 | self.block_size = block_size 12 | 13 | def forward(self, x, gamma): 14 | # shape: (bsize, channels, height, width) 15 | 16 | if self.training: 17 | batch_size, channels, height, 
width = x.shape 18 | bernoulli = Bernoulli(gamma) 19 | mask = bernoulli.sample((batch_size, channels, height - (self.block_size - 1), width - (self.block_size - 1))) 20 | if torch.cuda.is_available(): 21 | mask = mask.cuda() 22 | block_mask = self._compute_block_mask(mask) 23 | countM = block_mask.size()[0] * block_mask.size()[1] * block_mask.size()[2] * block_mask.size()[3] 24 | count_ones = block_mask.sum() 25 | 26 | return block_mask * x * (countM / count_ones) 27 | else: 28 | return x 29 | 30 | def _compute_block_mask(self, mask): 31 | left_padding = int((self.block_size-1) / 2) 32 | right_padding = int(self.block_size / 2) 33 | 34 | batch_size, channels, height, width = mask.shape 35 | non_zero_idxs = mask.nonzero() 36 | nr_blocks = non_zero_idxs.shape[0] 37 | 38 | offsets = torch.stack( 39 | [ 40 | torch.arange(self.block_size).view(-1, 1).expand(self.block_size, self.block_size).reshape(-1), # - left_padding, 41 | torch.arange(self.block_size).repeat(self.block_size), #- left_padding 42 | ] 43 | ).t() 44 | offsets = torch.cat((torch.zeros(self.block_size**2, 2).long(), offsets.long()), 1) 45 | if torch.cuda.is_available(): 46 | offsets = offsets.cuda() 47 | 48 | if nr_blocks > 0: 49 | non_zero_idxs = non_zero_idxs.repeat(self.block_size ** 2, 1) 50 | offsets = offsets.repeat(nr_blocks, 1).view(-1, 4) 51 | offsets = offsets.long() 52 | 53 | block_idxs = non_zero_idxs + offsets 54 | #block_idxs += left_padding 55 | padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding)) 56 | padded_mask[block_idxs[:, 0], block_idxs[:, 1], block_idxs[:, 2], block_idxs[:, 3]] = 1. 57 | else: 58 | padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding)) 59 | 60 | block_mask = 1 - padded_mask#[:height, :width] 61 | return block_mask 62 | -------------------------------------------------------------------------------- /model/dataloader/samplers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class CategoriesSampler(): 6 | 7 | def __init__(self, label, n_batch, n_cls, n_per): 8 | self.n_batch = n_batch 9 | self.n_cls = n_cls 10 | self.n_per = n_per 11 | 12 | label = np.array(label) 13 | self.m_ind = [] 14 | for i in range(max(label) + 1): 15 | ind = np.argwhere(label == i).reshape(-1) 16 | ind = torch.from_numpy(ind) 17 | self.m_ind.append(ind) 18 | 19 | def __len__(self): 20 | return self.n_batch 21 | 22 | def __iter__(self): 23 | for i_batch in range(self.n_batch): 24 | batch = [] 25 | classes = torch.randperm(len(self.m_ind))[:self.n_cls] 26 | for c in classes: 27 | l = self.m_ind[c] 28 | pos = torch.randperm(len(l))[:self.n_per] 29 | batch.append(l[pos]) 30 | batch = torch.stack(batch).t().reshape(-1) 31 | yield batch 32 | 33 | 34 | class RandomSampler(): 35 | 36 | def __init__(self, label, n_batch, n_per): 37 | self.n_batch = n_batch 38 | self.n_per = n_per 39 | self.label = np.array(label) 40 | self.num_label = self.label.shape[0] 41 | 42 | def __len__(self): 43 | return self.n_batch 44 | 45 | def __iter__(self): 46 | for i_batch in range(self.n_batch): 47 | batch = torch.randperm(self.num_label)[:self.n_per] 48 | yield batch 49 | 50 | 51 | # sample for each class 52 | class ClassSampler(): 53 | 54 | def __init__(self, label, n_per=None): 55 | self.n_per = n_per 56 | label = np.array(label) 57 | self.m_ind = [] 58 | for i in range(max(label) + 1): 59 | ind = np.argwhere(label == i).reshape(-1) 60 | ind = torch.from_numpy(ind) 61 | self.m_ind.append(ind) 62 
| 63 | def __len__(self): 64 | return len(self.m_ind) 65 | 66 | def __iter__(self): 67 | classes = torch.arange(len(self.m_ind)) 68 | for c in classes: 69 | l = self.m_ind[int(c)] 70 | if self.n_per is None: 71 | pos = torch.randperm(len(l)) 72 | else: 73 | pos = torch.randperm(len(l))[:self.n_per] 74 | yield l[pos] 75 | 76 | 77 | # for ResNet Fine-Tune, which output the same index of task examples several times 78 | class InSetSampler(): 79 | 80 | def __init__(self, n_batch, n_sbatch, pool): # pool is a tensor 81 | self.n_batch = n_batch 82 | self.n_sbatch = n_sbatch 83 | self.pool = pool 84 | self.pool_size = pool.shape[0] 85 | 86 | def __len__(self): 87 | return self.n_batch 88 | 89 | def __iter__(self): 90 | for i_batch in range(self.n_batch): 91 | batch = self.pool[torch.randperm(self.pool_size)[:self.n_sbatch]] 92 | yield batch -------------------------------------------------------------------------------- /model/dataloader/transforms.py: -------------------------------------------------------------------------------- 1 | # Credits to DeepVoltaire 2 | # github:DeepVoltaire/AutoAugment 3 | 4 | from PIL import Image, ImageEnhance, ImageOps 5 | import random 6 | 7 | 8 | class ShearX(object): 9 | def __init__(self, fillcolor=(128, 128, 128)): 10 | self.fillcolor = fillcolor 11 | 12 | def __call__(self, x, magnitude): 13 | return x.transform( 14 | x.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), 15 | Image.BICUBIC, fillcolor=self.fillcolor) 16 | 17 | 18 | class ShearY(object): 19 | def __init__(self, fillcolor=(128, 128, 128)): 20 | self.fillcolor = fillcolor 21 | 22 | def __call__(self, x, magnitude): 23 | return x.transform( 24 | x.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), 25 | Image.BICUBIC, fillcolor=self.fillcolor) 26 | 27 | 28 | class TranslateX(object): 29 | def __init__(self, fillcolor=(128, 128, 128)): 30 | self.fillcolor = fillcolor 31 | 32 | def __call__(self, x, magnitude): 33 | return x.transform( 34 | x.size, Image.AFFINE, (1, 0, magnitude * x.size[0] * random.choice([-1, 1]), 0, 1, 0), 35 | fillcolor=self.fillcolor) 36 | 37 | 38 | class TranslateY(object): 39 | def __init__(self, fillcolor=(128, 128, 128)): 40 | self.fillcolor = fillcolor 41 | 42 | def __call__(self, x, magnitude): 43 | return x.transform( 44 | x.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * x.size[1] * random.choice([-1, 1])), 45 | fillcolor=self.fillcolor) 46 | 47 | 48 | class Rotate(object): 49 | # from https://stackoverflow.com/questions/ 50 | # 5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand 51 | def __call__(self, x, magnitude): 52 | rot = x.convert("RGBA").rotate(magnitude) 53 | return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(x.mode) 54 | 55 | 56 | class Color(object): 57 | def __call__(self, x, magnitude): 58 | return ImageEnhance.Color(x).enhance(1 + magnitude * random.choice([-1, 1])) 59 | 60 | 61 | class Posterize(object): 62 | def __call__(self, x, magnitude): 63 | return ImageOps.posterize(x, magnitude) 64 | 65 | 66 | class Solarize(object): 67 | def __call__(self, x, magnitude): 68 | return ImageOps.solarize(x, magnitude) 69 | 70 | 71 | class Contrast(object): 72 | def __call__(self, x, magnitude): 73 | return ImageEnhance.Contrast(x).enhance(1 + magnitude * random.choice([-1, 1])) 74 | 75 | 76 | class Sharpness(object): 77 | def __call__(self, x, magnitude): 78 | return ImageEnhance.Sharpness(x).enhance(1 + magnitude * random.choice([-1, 1])) 79 | 80 | 81 | 
class Brightness(object): 82 | def __call__(self, x, magnitude): 83 | return ImageEnhance.Brightness(x).enhance(1 + magnitude * random.choice([-1, 1])) 84 | 85 | 86 | class AutoContrast(object): 87 | def __call__(self, x, magnitude): 88 | return ImageOps.autocontrast(x) 89 | 90 | 91 | class Equalize(object): 92 | def __call__(self, x, magnitude): 93 | return ImageOps.equalize(x) 94 | 95 | 96 | class Invert(object): 97 | def __call__(self, x, magnitude): 98 | return ImageOps.invert(x) 99 | -------------------------------------------------------------------------------- /model/models/protonet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | from model.models import FewShotModel_1 7 | 8 | """ 9 | The ProtoNet class inherits from FewShotModel_1, which is assumed to be tailored for few-shot learning scenarios. 10 | This implementation specifically targets the scenario where the model is expected to learn from a limited number of examples (support set) 11 | and generalize well to new, unseen examples (query set). 12 | """ 13 | 14 | class ProtoNet(FewShotModel_1): 15 | def __init__(self, args): 16 | """ 17 | Initialize the ProtoNet with the given arguments. 18 | This constructor passes any arguments to the superclass FewShotModel_1, which might perform some initial setup. 19 | """ 20 | super().__init__(args) 21 | 22 | def _forward(self, instance_embs, support_idx, query_idx): 23 | """ 24 | Custom forward logic for processing instance embeddings and calculating the prototypes. 25 | 26 | Parameters: 27 | - instance_embs: Tensor containing embeddings for all instances. 28 | - support_idx: Indices of support examples within instance_embs. 29 | - query_idx: Indices of query examples within instance_embs. 30 | 31 | The method handles two cases: 32 | 1. If Grad-CAM is enabled, it returns the raw embeddings for visualization purposes. 33 | 2. Otherwise, it processes the embeddings to compute class prototypes and their distances to query examples. 34 | """ 35 | if self.args.grad_cam: 36 | # Return embeddings directly for Grad-CAM visualization. 37 | return instance_embs 38 | 39 | else: 40 | # Extract the size of the last dimension, which represents the dimensionality of the embeddings. 41 | emb_dim = instance_embs.size(-1) 42 | 43 | # Organize support and query data by reshaping them according to their indices. 44 | support = instance_embs[support_idx.flatten()].view(*(support_idx.shape + (-1,))) 45 | query = instance_embs[query_idx.flatten()].view(*(query_idx.shape + (-1,))) 46 | 47 | # Compute the mean of the support embeddings to form the prototypes for each class. 48 | proto = support.mean(dim=1) # Ntask x NK x d 49 | 50 | # Prepare for distance calculation between queries and prototypes. 
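# Shape conventions in the branch below (d = instance_embs.size(-1), N = way,
# Q = queries per class):
#   proto  : (num_batch, N, d)  -- one mean prototype per class
#   query  : reshaped to (num_batch * N * Q, 1, d) for broadcasting
#   logits : (num_batch * N * Q, N) -- negative squared Euclidean distance to
#            each prototype, scaled by 1 / args.temperature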
51 | num_batch = proto.shape[0] 52 | num_proto = proto.shape[1] 53 | num_query = np.prod(query_idx.shape[-2:]) 54 | 55 | if True: # Placeholder for a boolean flag such as self.args.use_euclidean 56 | # Compute Euclidean distances 57 | query = query.view(-1, emb_dim).unsqueeze(1) # Reshape for broadcasting 58 | proto = proto.unsqueeze(1).expand(num_batch, num_query, num_proto, emb_dim) 59 | proto = proto.contiguous().view(num_batch * num_query, num_proto, emb_dim) 60 | logits = - torch.sum((proto - query) ** 2, 2) / self.args.temperature 61 | else: 62 | # Compute Cosine similarity 63 | proto = F.normalize(proto, dim=-1) # Normalize for cosine distance 64 | query = query.view(num_batch, -1, emb_dim) # Reshape for matrix multiplication 65 | logits = torch.bmm(query, proto.permute([0, 2, 1])) / self.args.temperature 66 | logits = logits.view(-1, num_proto) 67 | 68 | # Depending on the training state, return logits directly or with additional processing. 69 | if self.training: 70 | return logits, None 71 | else: 72 | return logits 73 | -------------------------------------------------------------------------------- /model/trainer/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import torch 3 | import os.path as osp 4 | 5 | from model.utils import ( 6 | ensure_path, 7 | Averager, Timer, count_acc, 8 | compute_confidence_interval, 9 | ) 10 | from model.logger import Logger 11 | 12 | class Trainer(object, metaclass=abc.ABCMeta): 13 | def __init__(self, args): 14 | self.args = args 15 | # ensure_path( 16 | # self.args.save_path, 17 | # scripts_to_save=['model/models', 'model/networks', __file__], 18 | # ) 19 | self.logger = Logger(args, osp.join(args.save_path)) 20 | 21 | self.train_step = 0 22 | self.train_epoch = 0 23 | self.max_steps = args.episodes_per_epoch * args.max_epoch 24 | self.dt, self.ft = Averager(), Averager() 25 | self.bt, self.ot = Averager(), Averager() 26 | self.timer = Timer() 27 | 28 | # train statistics 29 | self.trlog = {} 30 | self.trlog['max_acc'] = 0.0 31 | self.trlog['max_acc_epoch'] = 0 32 | self.trlog['max_acc_interval'] = 0.0 33 | 34 | @abc.abstractmethod 35 | def train(self): 36 | pass 37 | 38 | @abc.abstractmethod 39 | def evaluate(self, data_loader): 40 | pass 41 | 42 | @abc.abstractmethod 43 | def evaluate_test(self, data_loader): 44 | pass 45 | 46 | @abc.abstractmethod 47 | def final_record(self): 48 | pass 49 | 50 | def try_evaluate(self, epoch): 51 | args = self.args 52 | if self.train_epoch % args.eval_interval == 0: 53 | vl, va, vap = self.evaluate(self.val_loader) 54 | self.logger.add_scalar('val_loss', float(vl), self.train_epoch) 55 | self.logger.add_scalar('val_acc', float(va), self.train_epoch) 56 | 57 | print('epoch {}, val, loss={:.4f} acc={:.4f}+{:.4f}'.format(epoch, vl, va, vap)) 58 | 59 | if va >= self.trlog['max_acc']: 60 | self.trlog['max_acc'] = va 61 | self.trlog['max_acc_interval'] = vap 62 | self.trlog['max_acc_epoch'] = self.train_epoch 63 | self.save_model('max_acc') 64 | 65 | def try_logging(self, tl1, tl2, ta, tg=None): 66 | args = self.args 67 | if self.train_step % args.log_interval == 0: 68 | print('epoch {}, train {:06g}/{:06g}, total loss={:.4f}, loss={:.4f} acc={:.4f}, lr={:.4g}' 69 | .format(self.train_epoch, 70 | self.train_step, 71 | self.max_steps, 72 | tl1.item(), tl2.item(), ta.item(), 73 | self.optimizer.param_groups[0]['lr'])) 74 | self.logger.add_scalar('train_total_loss', tl1.item(), self.train_step) 75 | self.logger.add_scalar('train_loss', tl2.item(), 
self.train_step) 76 | self.logger.add_scalar('train_acc', ta.item(), self.train_step) 77 | if tg is not None: 78 | self.logger.add_scalar('grad_norm', tg.item(), self.train_step) 79 | print('data_timer: {:.2f} sec, ' \ 80 | 'forward_timer: {:.2f} sec,' \ 81 | 'backward_timer: {:.2f} sec, ' \ 82 | 'optim_timer: {:.2f} sec'.format( 83 | self.dt.item(), self.ft.item(), 84 | self.bt.item(), self.ot.item()) 85 | ) 86 | self.logger.dump() 87 | 88 | def save_model(self, name): 89 | torch.save( 90 | dict(params=self.model.state_dict()), 91 | osp.join(self.args.save_path, name + '.pth') 92 | ) 93 | 94 | def __str__(self): 95 | return "{}({})".format( 96 | self.__class__.__name__, 97 | self.model.__class__.__name__ 98 | ) 99 | -------------------------------------------------------------------------------- /model/data_parallel.py: -------------------------------------------------------------------------------- 1 | from torch.nn.parallel import DataParallel 2 | import torch 3 | from torch.nn.parallel._functions import Scatter 4 | from torch.nn.parallel.parallel_apply import parallel_apply 5 | 6 | def scatter(inputs, target_gpus, chunk_sizes, dim=0): 7 | r""" 8 | Slices tensors into approximately equal chunks and 9 | distributes them across given GPUs. Duplicates 10 | references to objects that are not tensors. 11 | """ 12 | def scatter_map(obj): 13 | if isinstance(obj, torch.Tensor): 14 | try: 15 | return Scatter.apply(target_gpus, chunk_sizes, dim, obj) 16 | except: 17 | print('obj', obj.size()) 18 | print('dim', dim) 19 | print('chunk_sizes', chunk_sizes) 20 | quit() 21 | if isinstance(obj, tuple) and len(obj) > 0: 22 | return list(zip(*map(scatter_map, obj))) 23 | if isinstance(obj, list) and len(obj) > 0: 24 | return list(map(list, zip(*map(scatter_map, obj)))) 25 | if isinstance(obj, dict) and len(obj) > 0: 26 | return list(map(type(obj), zip(*map(scatter_map, obj.items())))) 27 | return [obj for targets in target_gpus] 28 | 29 | # After scatter_map is called, a scatter_map cell will exist. This cell 30 | # has a reference to the actual function scatter_map, which has references 31 | # to a closure that has a reference to the scatter_map cell (because the 32 | # fn is recursive). 
To avoid this reference cycle, we set the function to 33 | # None, clearing the cell 34 | try: 35 | return scatter_map(inputs) 36 | finally: 37 | scatter_map = None 38 | 39 | def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0): 40 | r"""Scatter with support for kwargs dictionary""" 41 | inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else [] 42 | kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else [] 43 | if len(inputs) < len(kwargs): 44 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) 45 | elif len(kwargs) < len(inputs): 46 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) 47 | inputs = tuple(inputs) 48 | kwargs = tuple(kwargs) 49 | return inputs, kwargs 50 | 51 | class BalancedDataParallel(DataParallel): 52 | def __init__(self, gpu0_bsz, *args, **kwargs): 53 | self.gpu0_bsz = gpu0_bsz 54 | super().__init__(*args, **kwargs) 55 | 56 | def forward(self, *inputs, **kwargs): 57 | if not self.device_ids: 58 | return self.module(*inputs, **kwargs) 59 | if self.gpu0_bsz == 0: 60 | device_ids = self.device_ids[1:] 61 | else: 62 | device_ids = self.device_ids 63 | inputs, kwargs = self.scatter(inputs, kwargs, device_ids) 64 | if len(self.device_ids) == 1: 65 | return self.module(*inputs[0], **kwargs[0]) 66 | replicas = self.replicate(self.module, self.device_ids) 67 | if self.gpu0_bsz == 0: 68 | replicas = replicas[1:] 69 | outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs) 70 | return self.gather(outputs, self.output_device) 71 | 72 | def parallel_apply(self, replicas, device_ids, inputs, kwargs): 73 | return parallel_apply(replicas, inputs, kwargs, device_ids) 74 | 75 | def scatter(self, inputs, kwargs, device_ids): 76 | bsz = inputs[0].size(self.dim) 77 | num_dev = len(self.device_ids) 78 | gpu0_bsz = self.gpu0_bsz 79 | bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1) 80 | if gpu0_bsz < bsz_unit: 81 | chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1) 82 | delta = bsz - sum(chunk_sizes) 83 | for i in range(delta): 84 | chunk_sizes[i + 1] += 1 85 | if gpu0_bsz == 0: 86 | chunk_sizes = chunk_sizes[1:] 87 | else: 88 | return super().scatter(inputs, kwargs, device_ids) 89 | return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim) 90 | 91 | -------------------------------------------------------------------------------- /model/models/utils/tokenizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Tokenizer(nn.Module): 7 | def __init__(self, 8 | kernel_size, stride, padding, 9 | pooling_kernel_size=3, pooling_stride=2, pooling_padding=1, 10 | n_conv_layers=1, 11 | n_input_channels=3, 12 | n_output_channels=64, 13 | in_planes=64, 14 | activation=None, 15 | max_pool=True, 16 | conv_bias=False): 17 | super(Tokenizer, self).__init__() 18 | 19 | n_filter_list = [n_input_channels] + \ 20 | [in_planes for _ in range(n_conv_layers - 1)] + \ 21 | [n_output_channels] 22 | 23 | self.conv_layers = nn.Sequential( 24 | *[nn.Sequential( 25 | nn.Conv2d(n_filter_list[i], n_filter_list[i + 1], 26 | kernel_size=(kernel_size, kernel_size), 27 | stride=(stride, stride), 28 | padding=(padding, padding), bias=conv_bias), 29 | nn.Identity() if activation is None else activation(), 30 | nn.MaxPool2d(kernel_size=pooling_kernel_size, 31 | stride=pooling_stride, 32 | padding=pooling_padding) if max_pool else nn.Identity() 33 | ) 34 | for i in range(n_conv_layers) 35 | ]) 36 | 37 | 
self.flattener = nn.Flatten(2, 3) 38 | self.apply(self.init_weight) 39 | 40 | def sequence_length(self, n_channels=3, height=224, width=224): 41 | return self.forward(torch.zeros((1, n_channels, height, width))).shape[1] 42 | 43 | def forward(self, x): 44 | return self.flattener(self.conv_layers(x)).transpose(-2, -1) 45 | 46 | @staticmethod 47 | def init_weight(m): 48 | if isinstance(m, nn.Conv2d): 49 | nn.init.kaiming_normal_(m.weight) 50 | 51 | 52 | class TextTokenizer(nn.Module): 53 | def __init__(self, 54 | kernel_size, stride, padding, 55 | pooling_kernel_size=3, pooling_stride=2, pooling_padding=1, 56 | embedding_dim=300, 57 | n_output_channels=128, 58 | activation=None, 59 | max_pool=True, 60 | *args, **kwargs): 61 | super(TextTokenizer, self).__init__() 62 | 63 | self.max_pool = max_pool 64 | self.conv_layers = nn.Sequential( 65 | nn.Conv2d(1, n_output_channels, 66 | kernel_size=(kernel_size, embedding_dim), 67 | stride=(stride, 1), 68 | padding=(padding, 0), bias=False), 69 | nn.Identity() if activation is None else activation(), 70 | nn.MaxPool2d( 71 | kernel_size=(pooling_kernel_size, 1), 72 | stride=(pooling_stride, 1), 73 | padding=(pooling_padding, 0) 74 | ) if max_pool else nn.Identity() 75 | ) 76 | 77 | self.apply(self.init_weight) 78 | 79 | def seq_len(self, seq_len=32, embed_dim=300): 80 | return self.forward(torch.zeros((1, seq_len, embed_dim)))[0].shape[1] 81 | 82 | def forward_mask(self, mask): 83 | new_mask = mask.unsqueeze(1).float() 84 | cnn_weight = torch.ones( 85 | (1, 1, self.conv_layers[0].kernel_size[0]), 86 | device=mask.device, 87 | dtype=torch.float) 88 | new_mask = F.conv1d( 89 | new_mask, cnn_weight, None, 90 | self.conv_layers[0].stride[0], self.conv_layers[0].padding[0], 1, 1) 91 | if self.max_pool: 92 | new_mask = F.max_pool1d( 93 | new_mask, self.conv_layers[2].kernel_size[0], 94 | self.conv_layers[2].stride[0], self.conv_layers[2].padding[0], 1, False, False) 95 | new_mask = new_mask.squeeze(1) 96 | new_mask = (new_mask > 0) 97 | return new_mask 98 | 99 | def forward(self, x, mask=None): 100 | x = x.unsqueeze(1) 101 | x = self.conv_layers(x) 102 | x = x.transpose(1, 3).squeeze(1) 103 | x = x if mask is None else x * self.forward_mask(mask).unsqueeze(-1).float() 104 | return x, mask 105 | 106 | @staticmethod 107 | def init_weight(m): 108 | if isinstance(m, nn.Conv2d): 109 | nn.init.kaiming_normal_(m.weight) 110 | -------------------------------------------------------------------------------- /model/networks/utils/tokenizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Tokenizer(nn.Module): 7 | def __init__(self, 8 | kernel_size, stride, padding, 9 | pooling_kernel_size=3, pooling_stride=2, pooling_padding=1, 10 | n_conv_layers=1, 11 | n_input_channels=3, 12 | n_output_channels=64, 13 | in_planes=64, 14 | activation=None, 15 | max_pool=True, 16 | conv_bias=False): 17 | super(Tokenizer, self).__init__() 18 | 19 | n_filter_list = [n_input_channels] + \ 20 | [in_planes for _ in range(n_conv_layers - 1)] + \ 21 | [n_output_channels] 22 | 23 | self.conv_layers = nn.Sequential( 24 | *[nn.Sequential( 25 | nn.Conv2d(n_filter_list[i], n_filter_list[i + 1], 26 | kernel_size=(kernel_size, kernel_size), 27 | stride=(stride, stride), 28 | padding=(padding, padding), bias=conv_bias), 29 | nn.Identity() if activation is None else activation(), 30 | nn.MaxPool2d(kernel_size=pooling_kernel_size, 31 | stride=pooling_stride, 32 | 
padding=pooling_padding) if max_pool else nn.Identity() 33 | ) 34 | for i in range(n_conv_layers) 35 | ]) 36 | 37 | self.flattener = nn.Flatten(2, 3) 38 | self.apply(self.init_weight) 39 | 40 | def sequence_length(self, n_channels=3, height=224, width=224): 41 | return self.forward(torch.zeros((1, n_channels, height, width))).shape[1] 42 | 43 | def forward(self, x): 44 | return self.flattener(self.conv_layers(x)).transpose(-2, -1) 45 | 46 | @staticmethod 47 | def init_weight(m): 48 | if isinstance(m, nn.Conv2d): 49 | nn.init.kaiming_normal_(m.weight) 50 | 51 | 52 | class TextTokenizer(nn.Module): 53 | def __init__(self, 54 | kernel_size, stride, padding, 55 | pooling_kernel_size=3, pooling_stride=2, pooling_padding=1, 56 | embedding_dim=300, 57 | n_output_channels=128, 58 | activation=None, 59 | max_pool=True, 60 | *args, **kwargs): 61 | super(TextTokenizer, self).__init__() 62 | 63 | self.max_pool = max_pool 64 | self.conv_layers = nn.Sequential( 65 | nn.Conv2d(1, n_output_channels, 66 | kernel_size=(kernel_size, embedding_dim), 67 | stride=(stride, 1), 68 | padding=(padding, 0), bias=False), 69 | nn.Identity() if activation is None else activation(), 70 | nn.MaxPool2d( 71 | kernel_size=(pooling_kernel_size, 1), 72 | stride=(pooling_stride, 1), 73 | padding=(pooling_padding, 0) 74 | ) if max_pool else nn.Identity() 75 | ) 76 | 77 | self.apply(self.init_weight) 78 | 79 | def seq_len(self, seq_len=32, embed_dim=300): 80 | return self.forward(torch.zeros((1, seq_len, embed_dim)))[0].shape[1] 81 | 82 | def forward_mask(self, mask): 83 | new_mask = mask.unsqueeze(1).float() 84 | cnn_weight = torch.ones( 85 | (1, 1, self.conv_layers[0].kernel_size[0]), 86 | device=mask.device, 87 | dtype=torch.float) 88 | new_mask = F.conv1d( 89 | new_mask, cnn_weight, None, 90 | self.conv_layers[0].stride[0], self.conv_layers[0].padding[0], 1, 1) 91 | if self.max_pool: 92 | new_mask = F.max_pool1d( 93 | new_mask, self.conv_layers[2].kernel_size[0], 94 | self.conv_layers[2].stride[0], self.conv_layers[2].padding[0], 1, False, False) 95 | new_mask = new_mask.squeeze(1) 96 | new_mask = (new_mask > 0) 97 | return new_mask 98 | 99 | def forward(self, x, mask=None): 100 | x = x.unsqueeze(1) 101 | x = self.conv_layers(x) 102 | x = x.transpose(1, 3).squeeze(1) 103 | x = x if mask is None else x * self.forward_mask(mask).unsqueeze(-1).float() 104 | return x, mask 105 | 106 | @staticmethod 107 | def init_weight(m): 108 | if isinstance(m, nn.Conv2d): 109 | nn.init.kaiming_normal_(m.weight) 110 | -------------------------------------------------------------------------------- /model/dataloader/mini_imagenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os.path as osp 3 | from PIL import Image 4 | from .transforms import * 5 | import PIL 6 | 7 | from torch.utils.data import Dataset 8 | from torchvision import transforms 9 | from tqdm import tqdm 10 | import numpy as np 11 | 12 | # Paths 13 | THIS_PATH = osp.dirname(__file__) 14 | ROOT_PATH = osp.abspath(osp.join(THIS_PATH, '..', '..')) 15 | ROOT_PATH2 = osp.abspath(osp.join(THIS_PATH, '..', '..', '..')) 16 | IMAGE_PATH1 = osp.join(ROOT_PATH, 'data/miniimagenet/images') 17 | SPLIT_PATH = osp.join(ROOT_PATH, 'data/miniimagenet/split') 18 | CACHE_PATH = osp.join(ROOT_PATH, '.cache/') 19 | 20 | 21 | def identity(x): 22 | """Identity function.""" 23 | return x 24 | 25 | class MiniImageNet(Dataset): 26 | """Dataset class for MiniImageNet.""" 27 | def __init__(self, setname, args, augment=False): 28 | 
"""Initialize MiniImageNet dataset.""" 29 | im_size = args.orig_imsize 30 | csv_path = osp.join(SPLIT_PATH, setname + '.csv') 31 | cache_path = osp.join( CACHE_PATH, "{}.{}.{}.pt".format(self.__class__.__name__, setname, im_size) ) 32 | self.args = args 33 | self.use_im_cache = ( im_size != -1 ) # not using cache 34 | 35 | # Check if using image cache 36 | if self.use_im_cache: 37 | if not osp.exists(cache_path): 38 | print('* Cache miss... Preprocessing {}...'.format(setname)) 39 | resize_ = identity if im_size < 0 else transforms.Resize(im_size) 40 | data, label = self.parse_csv(csv_path, setname) 41 | self.data = [ resize_(Image.open(path).convert('RGB')) for path in data ] 42 | self.label = label 43 | print('* Dump cache from {}'.format(cache_path)) 44 | torch.save({'data': self.data, 'label': self.label }, cache_path) 45 | else: 46 | print('* Load cache from {}'.format(cache_path)) 47 | cache = torch.load(cache_path) 48 | self.data = cache['data'] 49 | self.label = cache['label'] 50 | else: 51 | self.data, self.label = self.parse_csv(csv_path, setname) 52 | 53 | self.num_class = len(set(self.label)) 54 | 55 | image_size = 84 56 | if augment and setname == 'train': 57 | # Augmentation transforms 58 | transforms_list = [ 59 | transforms.RandomResizedCrop(image_size), 60 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 61 | transforms.RandomHorizontalFlip(), 62 | transforms.ToTensor(), 63 | ] 64 | else: 65 | # Validation/Test transforms 66 | transforms_list = [ 67 | transforms.Resize(92), 68 | transforms.CenterCrop(image_size), 69 | transforms.ToTensor(), 70 | ] 71 | 72 | # Transformation based on backbone class 73 | if args.backbone_class == 'Res12' : 74 | self.transform = transforms.Compose( 75 | transforms_list + [ 76 | transforms.Normalize(np.array([x / 255.0 for x in [120.39586422, 115.59361427, 104.54012653]]), 77 | np.array([x / 255.0 for x in [70.68188272, 68.27635443, 72.54505529]])) 78 | ]) 79 | elif args.backbone_class == 'Res18': 80 | self.transform = transforms.Compose( 81 | transforms_list + [ 82 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 83 | std=[0.229, 0.224, 0.225]) 84 | ]) 85 | else: 86 | raise ValueError('Non-supported Network Types. 
Please Revise Data Pre-Processing Scripts.') 87 | 88 | def parse_csv(self, csv_path, setname): 89 | """Parse CSV file to get image paths and labels.""" 90 | lines = [x.strip() for x in open(csv_path, 'r').readlines()][1:] 91 | 92 | data = [] 93 | label = [] 94 | lb = -1 95 | 96 | self.wnids = [] 97 | 98 | for l in tqdm(lines, ncols=64): 99 | name, wnid = l.split(',') 100 | path = osp.join(IMAGE_PATH1, name) 101 | if wnid not in self.wnids: 102 | self.wnids.append(wnid) 103 | lb += 1 104 | data.append( path ) 105 | label.append(lb) 106 | 107 | return data, label 108 | 109 | def __len__(self): 110 | """Get the length of the dataset.""" 111 | return len(self.data) 112 | 113 | def __getitem__(self, i): 114 | """Get an item from the dataset.""" 115 | data, label = self.data[i], self.label[i] 116 | if self.use_im_cache: 117 | image = self.transform(data) 118 | else: 119 | image = self.transform(Image.open(data).convert('RGB')) 120 | 121 | return image, label 122 | -------------------------------------------------------------------------------- /model/networks/res12.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | from model.networks.dropblock import DropBlock 5 | 6 | # This ResNet network was designed following the practice of the following papers: 7 | # TADAM: Task dependent adaptive metric for improved few-shot learning (Oreshkin et al., in NIPS 2018) and 8 | # A Simple Neural Attentive Meta-Learner (Mishra et al., in ICLR 2018). 9 | 10 | def conv3x3(in_planes, out_planes, stride=1): 11 | """3x3 convolution with padding""" 12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | expansion = 1 18 | 19 | def __init__(self, inplanes, planes, stride=1, downsample=None, drop_rate=0.0, drop_block=False, block_size=1): 20 | super(BasicBlock, self).__init__() 21 | self.conv1 = conv3x3(inplanes, planes) 22 | self.bn1 = nn.BatchNorm2d(planes) 23 | self.relu = nn.LeakyReLU(0.1) 24 | self.conv2 = conv3x3(planes, planes) 25 | self.bn2 = nn.BatchNorm2d(planes) 26 | self.conv3 = conv3x3(planes, planes) 27 | self.bn3 = nn.BatchNorm2d(planes) 28 | self.maxpool = nn.MaxPool2d(stride) 29 | self.downsample = downsample 30 | self.stride = stride 31 | self.drop_rate = drop_rate 32 | self.num_batches_tracked = 0 33 | self.drop_block = drop_block 34 | self.block_size = block_size 35 | self.DropBlock = DropBlock(block_size=self.block_size) 36 | 37 | def forward(self, x): 38 | self.num_batches_tracked += 1 39 | 40 | residual = x 41 | 42 | out = self.conv1(x) 43 | out = self.bn1(out) 44 | out = self.relu(out) 45 | 46 | out = self.conv2(out) 47 | out = self.bn2(out) 48 | out = self.relu(out) 49 | 50 | out = self.conv3(out) 51 | out = self.bn3(out) 52 | 53 | if self.downsample is not None: 54 | residual = self.downsample(x) 55 | out += residual 56 | out = self.relu(out) 57 | out = self.maxpool(out) 58 | 59 | if self.drop_rate > 0: 60 | if self.drop_block == True: 61 | feat_size = out.size()[2] 62 | keep_rate = max(1.0 - self.drop_rate / (20*2000) * (self.num_batches_tracked), 1.0 - self.drop_rate) 63 | gamma = (1 - keep_rate) / self.block_size**2 * feat_size**2 / (feat_size - self.block_size + 1)**2 64 | out = self.DropBlock(out, gamma=gamma) 65 | else: 66 | out = F.dropout(out, p=self.drop_rate, training=self.training, inplace=True) 67 | 68 | return out 69 | 70 | 71 | class ResNet(nn.Module): 72 | 73 | def 
__init__(self, block=BasicBlock, keep_prob=1.0, avg_pool=True, drop_rate=0.1, dropblock_size=5): 74 | self.inplanes = 3 75 | super(ResNet, self).__init__() 76 | 77 | self.layer1 = self._make_layer(block, 64, stride=2, drop_rate=drop_rate) 78 | self.layer2 = self._make_layer(block, 160, stride=2, drop_rate=drop_rate) 79 | self.layer3 = self._make_layer(block, 320, stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size) 80 | self.layer4 = self._make_layer(block, 640, stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size) 81 | if avg_pool: 82 | self.avgpool = nn.AvgPool2d(5, stride=1) 83 | self.keep_prob = keep_prob 84 | self.keep_avg_pool = avg_pool 85 | self.dropout = nn.Dropout(p=1 - self.keep_prob, inplace=False) 86 | self.drop_rate = drop_rate 87 | 88 | for m in self.modules(): 89 | if isinstance(m, nn.Conv2d): 90 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') 91 | elif isinstance(m, nn.BatchNorm2d): 92 | nn.init.constant_(m.weight, 1) 93 | nn.init.constant_(m.bias, 0) 94 | 95 | def _make_layer(self, block, planes, stride=1, drop_rate=0.0, drop_block=False, block_size=1): 96 | downsample = None 97 | if stride != 1 or self.inplanes != planes * block.expansion: 98 | downsample = nn.Sequential( 99 | nn.Conv2d(self.inplanes, planes * block.expansion, 100 | kernel_size=1, stride=1, bias=False), 101 | nn.BatchNorm2d(planes * block.expansion), 102 | ) 103 | 104 | layers = [] 105 | layers.append(block(self.inplanes, planes, stride, downsample, drop_rate, drop_block, block_size)) 106 | self.inplanes = planes * block.expansion 107 | 108 | return nn.Sequential(*layers) 109 | 110 | def forward(self, x): 111 | x = self.layer1(x) 112 | x = self.layer2(x) 113 | x = self.layer3(x) 114 | x = self.layer4(x) 115 | # x = torch.nn.functional.interpolate(x, size=(7, 7), mode='bilinear') 116 | # if self.keep_avg_pool: 117 | # x = self.avgpool(x) 118 | # x = x.view(x.size(0), -1) 119 | return x 120 | 121 | 122 | def Res12(keep_prob=1.0, avg_pool=False, **kwargs): 123 | """Constructs a ResNet-12 model. 
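    Example (illustrative, not part of the repository): with the average pooling
    in forward() left commented out, an 84x84 input keeps a spatial feature map:

        model = Res12()
        feats = model(torch.randn(2, 3, 84, 84))  # torch.Size([2, 640, 5, 5])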
124 | """ 125 | model = ResNet(BasicBlock, keep_prob=keep_prob, avg_pool=avg_pool, **kwargs) 126 | return model 127 | -------------------------------------------------------------------------------- /model/dataloader/cub.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import PIL 3 | from PIL import Image 4 | 5 | import numpy as np 6 | from torch.utils.data import Dataset 7 | from torchvision import transforms 8 | import torch 9 | THIS_PATH = osp.dirname(__file__) 10 | ROOT_PATH = osp.abspath(osp.join(THIS_PATH, '..', '..')) 11 | ROOT_PATH2 = osp.abspath(osp.join(THIS_PATH, '..', '..', '..')) 12 | IMAGE_PATH = 'data/CUB/CUB_200_2011/images' ##previously use data in feat/data/cub/images 13 | 14 | SPLIT_PATH = 'data/CUB/split' 15 | CACHE_PATH = osp.join(ROOT_PATH, '.cache/') 16 | 17 | # This is for the CUB dataset 18 | # It is notable, we assume the cub images are cropped based on the given bounding boxes 19 | # The concept labels are based on the attribute value, which are for further use (and not used in this work) 20 | 21 | class CUB(Dataset): 22 | 23 | def __init__(self, setname, args, augment=False): 24 | im_size = args.orig_imsize 25 | txt_path = osp.join(SPLIT_PATH, setname + '.csv') 26 | lines = [x.strip() for x in open(txt_path, 'r').readlines()][1:] 27 | cache_path = osp.join( CACHE_PATH, "{}.{}.{}.pt".format(self.__class__.__name__, setname, im_size) ) 28 | 29 | self.use_im_cache = ( im_size != -1 ) # not using cache 30 | if self.use_im_cache: 31 | if not osp.exists(cache_path): 32 | print('* Cache miss... Preprocessing {}...'.format(setname)) 33 | resize_ = identity if im_size < 0 else transforms.Resize(im_size) 34 | data, label = self.parse_csv(txt_path) 35 | self.data = [ resize_(Image.open(path).convert('RGB')) for path in data ] 36 | self.label = label 37 | print('* Dump cache from {}'.format(cache_path)) 38 | torch.save({'data': self.data, 'label': self.label }, cache_path) 39 | else: 40 | print('* Load cache from {}'.format(cache_path)) 41 | cache = torch.load(cache_path) 42 | self.data = cache['data'] 43 | self.label = cache['label'] 44 | else: 45 | self.data, self.label = self.parse_csv(txt_path) 46 | 47 | self.num_class = np.unique(np.array(self.label)).shape[0] 48 | image_size = 84 49 | 50 | if augment and setname == 'train': 51 | transforms_list = [ 52 | transforms.RandomResizedCrop(image_size), 53 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 54 | transforms.RandomHorizontalFlip(), 55 | transforms.ToTensor(), 56 | ] 57 | else: 58 | transforms_list = [ 59 | transforms.Resize(92), 60 | transforms.CenterCrop(image_size), 61 | transforms.ToTensor(), 62 | ] 63 | 64 | # Transformation 65 | if args.backbone_class == 'ConvNet': 66 | self.transform = transforms.Compose( 67 | transforms_list + [ 68 | transforms.Normalize(np.array([0.485, 0.456, 0.406]), 69 | np.array([0.229, 0.224, 0.225])) 70 | ]) 71 | elif args.backbone_class == 'Res12': 72 | self.transform = transforms.Compose( 73 | transforms_list + [ 74 | transforms.Normalize(np.array([x / 255.0 for x in [120.39586422, 115.59361427, 104.54012653]]), 75 | np.array([x / 255.0 for x in [70.68188272, 68.27635443, 72.54505529]])) 76 | ]) 77 | elif args.backbone_class == 'Res18': 78 | self.transform = transforms.Compose( 79 | transforms_list + [ 80 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 81 | std=[0.229, 0.224, 0.225]) 82 | ]) 83 | elif args.backbone_class == 'WRN': 84 | self.transform = transforms.Compose( 85 | transforms_list 
+ [ 86 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 87 | std=[0.229, 0.224, 0.225]) 88 | ]) 89 | else: 90 | raise ValueError('Non-supported Network Types. Please Revise Data Pre-Processing Scripts.') 91 | 92 | def parse_csv(self, txt_path): 93 | data = [] 94 | label = [] 95 | lb = -1 96 | self.wnids = [] 97 | lines = [x.strip() for x in open(txt_path, 'r').readlines()][1:] 98 | 99 | for l in lines: 100 | context = l.split(',') 101 | name = context[0] 102 | wnid = context[1] 103 | path = osp.join(IMAGE_PATH, name) 104 | if wnid not in self.wnids: 105 | self.wnids.append(wnid) 106 | lb += 1 107 | 108 | data.append(path) 109 | label.append(lb) 110 | 111 | return data, label 112 | 113 | 114 | def __len__(self): 115 | return len(self.data) 116 | 117 | def __getitem__(self, i): 118 | data, label = self.data[i], self.label[i] 119 | if self.use_im_cache: 120 | image = self.transform(data) 121 | else: 122 | image = self.transform(Image.open(data).convert('RGB')) 123 | return image, label -------------------------------------------------------------------------------- /model/dataloader/tiered_imagenet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import os.path as osp 5 | import numpy as np 6 | import pickle 7 | import sys 8 | import torch 9 | import torch.utils.data as data 10 | import torchvision.transforms as transforms 11 | from PIL import Image 12 | 13 | # Set the appropriate paths of the datasets here. 14 | THIS_PATH = osp.dirname(__file__) 15 | ROOT_PATH1 = osp.abspath(osp.join(THIS_PATH, '..', '..', '..')) 16 | ROOT_PATH2 = osp.abspath(osp.join(THIS_PATH, '..', '..')) 17 | IMAGE_PATH = osp.join(ROOT_PATH2, 'data/tieredimagenet/') 18 | SPLIT_PATH = osp.join(ROOT_PATH2, 'data/miniimagenet/split') 19 | 20 | from .transforms import * 21 | import PIL 22 | 23 | 24 | def buildLabelIndex(labels): 25 | label2inds = {} 26 | for idx, label in enumerate(labels): 27 | if label not in label2inds: 28 | label2inds[label] = [] 29 | label2inds[label].append(idx) 30 | 31 | return label2inds 32 | 33 | 34 | def load_data(file): 35 | try: 36 | with open(file, 'rb') as fo: 37 | data = pickle.load(fo) 38 | return data 39 | except: 40 | with open(file, 'rb') as f: 41 | u = pickle._Unpickler(f) 42 | u.encoding = 'latin1' 43 | data = u.load() 44 | return data 45 | 46 | file_path = {'train':[os.path.join(IMAGE_PATH, 'train_images.npz'), os.path.join(IMAGE_PATH, 'train_labels.pkl')], 47 | 'val':[os.path.join(IMAGE_PATH, 'val_images.npz'), os.path.join(IMAGE_PATH,'val_labels.pkl')], 48 | 'test':[os.path.join(IMAGE_PATH, 'test_images.npz'), os.path.join(IMAGE_PATH, 'test_labels.pkl')]} 49 | 50 | class tieredImageNet(data.Dataset): 51 | def __init__(self, setname, args, augment=False): 52 | assert(setname=='train' or setname=='val' or setname=='test') 53 | image_path = file_path[setname][0] 54 | label_path = file_path[setname][1] 55 | 56 | data_train = load_data(label_path) 57 | labels = data_train['labels'] 58 | self.data = np.load(image_path)['images'] 59 | label = [] 60 | lb = -1 61 | self.wnids = [] 62 | for wnid in labels: 63 | if wnid not in self.wnids: 64 | self.wnids.append(wnid) 65 | lb += 1 66 | label.append(lb) 67 | 68 | self.label = label 69 | self.num_class = len(set(label)) 70 | 71 | if augment and setname == 'train': 72 | transforms_list = [ 73 | transforms.RandomCrop(84, padding=8), 74 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 75 | transforms.RandomHorizontalFlip(), 76 | 
transforms.ToTensor(), 77 | ] 78 | elif args.backbone_class == 'Res10': 79 | transforms_list = [ 80 | transforms.RandomSizedCrop(size=(224, 224), scale=(0.08, 1.0), ratio=(0.75, 1.3333), 81 | interpolation=PIL.Image.BILINEAR), 82 | transforms.RandomHorizontalFlip(p=0.5), 83 | transforms.ToTensor() 84 | ] 85 | 86 | 87 | else: 88 | transforms_list = [ 89 | transforms.ToTensor(), 90 | ] 91 | 92 | # Transformation 93 | if args.backbone_class == 'ConvNet': 94 | self.transform = transforms.Compose( 95 | transforms_list + [ 96 | transforms.Normalize(np.array([0.485, 0.456, 0.406]), 97 | np.array([0.229, 0.224, 0.225])) 98 | ]) 99 | elif args.backbone_class == 'ResNet': 100 | self.transform = transforms.Compose( 101 | transforms_list + [ 102 | transforms.Normalize(np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]), 103 | np.array([x / 255.0 for x in [63.0, 62.1, 66.7]])) 104 | ]) 105 | elif args.backbone_class == 'Res12' or args.backbone_class == 'Res10' : 106 | self.transform = transforms.Compose( 107 | transforms_list + [ 108 | transforms.Normalize(np.array([x / 255.0 for x in [120.39586422, 115.59361427, 104.54012653]]), 109 | np.array([x / 255.0 for x in [70.68188272, 68.27635443, 72.54505529]])) 110 | ]) 111 | elif args.backbone_class == 'Res18': 112 | self.transform = transforms.Compose( 113 | transforms_list + [ 114 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 115 | std=[0.229, 0.224, 0.225]) 116 | ]) 117 | elif args.backbone_class == 'WRN': 118 | self.transform = transforms.Compose( 119 | transforms_list + [ 120 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 121 | std=[0.229, 0.224, 0.225]) 122 | ]) 123 | else: 124 | raise ValueError('Non-supported Network Types. Please Revise Data Pre-Processing Scripts.') 125 | 126 | 127 | def __getitem__(self, index): 128 | img, label = self.data[index], self.label[index] 129 | img = self.transform(Image.fromarray(img)) 130 | return img, label 131 | 132 | def __len__(self): 133 | return len(self.data) 134 | -------------------------------------------------------------------------------- /model/models/fcanet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def get_freq_indices(method): 7 | assert method in ['top1', 'top2', 'top4', 'top8', 'top16', 'top32', 8 | 'bot1', 'bot2', 'bot4', 'bot8', 'bot16', 'bot32', 9 | 'low1', 'low2', 'low4', 'low8', 'low16', 'low32'] 10 | num_freq = int(method[3:]) 11 | if 'top' in method: 12 | all_top_indices_x = [0, 0, 6, 0, 0, 1, 1, 4, 5, 1, 3, 0, 0, 0, 3, 2, 4, 6, 3, 5, 5, 2, 6, 5, 5, 3, 3, 4, 2, 2, 13 | 6, 1] 14 | all_top_indices_y = [0, 1, 0, 5, 2, 0, 2, 0, 0, 6, 0, 4, 6, 3, 5, 2, 6, 3, 3, 3, 5, 1, 1, 2, 4, 2, 1, 1, 3, 0, 15 | 5, 3] 16 | mapper_x = all_top_indices_x[:num_freq] 17 | mapper_y = all_top_indices_y[:num_freq] 18 | elif 'low' in method: 19 | all_low_indices_x = [0, 0, 1, 1, 0, 2, 2, 1, 2, 0, 3, 4, 0, 1, 3, 1, 5, 0, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6, 1, 2, 20 | 3, 4] 21 | all_low_indices_y = [0, 1, 0, 1, 2, 0, 1, 2, 2, 3, 0, 0, 4, 3, 1, 4, 0, 5, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 22 | 4, 3] 23 | mapper_x = all_low_indices_x[:num_freq] 24 | mapper_y = all_low_indices_y[:num_freq] 25 | elif 'bot' in method: 26 | all_bot_indices_x = [6, 1, 3, 3, 2, 4, 1, 2, 4, 4, 5, 1, 4, 6, 2, 5, 6, 1, 6, 2, 2, 4, 3, 3, 5, 5, 6, 2, 5, 5, 27 | 3, 6] 28 | all_bot_indices_y = [6, 4, 4, 6, 6, 3, 1, 4, 4, 5, 6, 5, 2, 2, 5, 1, 4, 3, 5, 0, 3, 1, 1, 2, 4, 2, 1, 1, 5, 3, 29 | 3, 3] 30 | mapper_x = 
all_bot_indices_x[:num_freq] 31 | mapper_y = all_bot_indices_y[:num_freq] 32 | else: 33 | raise NotImplementedError 34 | return mapper_x, mapper_y 35 | 36 | 37 | class MultiSpectralAttentionLayer(torch.nn.Module): 38 | def __init__(self, channel, dct_h, dct_w, sigma, k, freq_sel_method='top16'): 39 | super(MultiSpectralAttentionLayer, self).__init__() 40 | self.sigma = sigma 41 | self.k = k 42 | self.dct_h = dct_h 43 | self.dct_w = dct_w 44 | 45 | mapper_x, mapper_y = get_freq_indices(freq_sel_method) 46 | self.num_split = len(mapper_x) 47 | mapper_x = [temp_x * (dct_h // 5) for temp_x in mapper_x] 48 | mapper_y = [temp_y * (dct_w // 5) for temp_y in mapper_y] 49 | # make the frequencies in different sizes are identical to a 5x5 frequency space 50 | # eg, (2,2) in 10x10 is identical to (1,1) in5x5 51 | 52 | self.dct_layer = MultiSpectralDCTLayer(dct_h, dct_w, mapper_x, mapper_y, channel) 53 | self.fc = nn.Sequential( 54 | nn.Linear(channel, int(channel*self.sigma), bias=False), 55 | nn.ReLU(inplace=True), 56 | nn.Linear(int(channel*self.sigma), channel*self.k**2, bias=False), 57 | nn.Sigmoid() 58 | ) 59 | 60 | def forward(self, x): 61 | n, c, h, w = x.shape 62 | x_pooled = x 63 | if h != self.dct_h or w != self.dct_w: 64 | x_pooled = torch.nn.functional.adaptive_avg_pool2d(x, (self.dct_h, self.dct_w)) 65 | # If you have concerns about one-line-change, don't worry. :) 66 | # In the ImageNet models, this line will never be triggered. 67 | # This is for compatibility in instance segmentation and object detection. 68 | y = self.dct_layer(x_pooled) 69 | 70 | y = self.fc(y).view(n, c, self.k, self.k) 71 | # return x * y.expand_as(x) 72 | return y 73 | 74 | class MultiSpectralDCTLayer(nn.Module): 75 | """ 76 | Generate dct filters 77 | """ 78 | 79 | def __init__(self, height, width, mapper_x, mapper_y, channel): 80 | super(MultiSpectralDCTLayer, self).__init__() 81 | 82 | assert len(mapper_x) == len(mapper_y) 83 | assert channel % len(mapper_x) == 0 84 | 85 | self.num_freq = len(mapper_x) 86 | 87 | # fixed DCT init 88 | self.register_buffer('weight', self.get_dct_filter(height, width, mapper_x, mapper_y, channel)) 89 | 90 | # fixed random init 91 | # self.register_buffer('weight', torch.rand(channel, height, width)) 92 | 93 | # learnable DCT init 94 | # self.register_parameter('weight', self.get_dct_filter(height, width, mapper_x, mapper_y, channel)) 95 | 96 | # learnable random init 97 | # self.register_parameter('weight', torch.rand(channel, height, width)) 98 | 99 | # num_freq, h, w 100 | 101 | def forward(self, x): 102 | assert len(x.shape) == 4, 'x must been 4 dimensions, but got ' + str(len(x.shape)) 103 | # n, c, h, w = x.shape 104 | 105 | x = x * self.weight 106 | 107 | result = torch.sum(x, dim=[2, 3]) 108 | return result 109 | 110 | def build_filter(self, pos, freq, POS): 111 | result = math.cos(math.pi * freq * (pos + 0.5) / POS) / math.sqrt(POS) 112 | if freq == 0: 113 | return result 114 | else: 115 | return result * math.sqrt(2) 116 | 117 | def get_dct_filter(self, tile_size_x, tile_size_y, mapper_x, mapper_y, channel): 118 | dct_filter = torch.zeros(channel, tile_size_x, tile_size_y) 119 | 120 | c_part = channel // len(mapper_x) 121 | 122 | for i, (u_x, v_y) in enumerate(zip(mapper_x, mapper_y)): 123 | for t_x in range(tile_size_x): 124 | for t_y in range(tile_size_y): 125 | dct_filter[i * c_part: (i + 1) * c_part, t_x, t_y] = self.build_filter(t_x, u_x, 126 | tile_size_x) * self.build_filter( 127 | t_y, v_y, tile_size_y) 128 | 129 | return dct_filter 
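# ----------------------------------------------------------------------------
# Editor's note: the quick self-test below is an illustrative addition and is
# not part of the original file. It mirrors how INSTA instantiates this layer
# for a Res12 backbone (640 channels, 5x5 feature maps, sigma=0.2, k=3); the
# batch size of 2 is arbitrary. Note that forward() returns a per-channel
# k x k kernel rather than a re-weighted feature map (see the commented-out
# `x * y.expand_as(x)` above).
if __name__ == '__main__':
    dummy = torch.randn(2, 640, 5, 5)  # (N, C, H, W) backbone feature maps
    att = MultiSpectralAttentionLayer(640, 5, 5, sigma=0.2, k=3, freq_sel_method='low16')
    kernel = att(dummy)
    print(kernel.shape)  # expected: torch.Size([2, 640, 3, 3])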
-------------------------------------------------------------------------------- /model/models/INSTA_ProtoNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | from model.models import FewShotModel_1 7 | from model.models.INSTA import INSTA 8 | 9 | """ 10 | The INSTA_ProtoNet class combines INSTA-based attention mechanisms with the prototypical networks approach. 11 | This hybrid model is designed for few-shot learning tasks where it's important to quickly adapt to new classes 12 | with very few examples per class. 13 | """ 14 | 15 | class INSTA_ProtoNet(FewShotModel_1): 16 | def __init__(self, args): 17 | """ 18 | Initializes the INSTA_ProtoNet with the given arguments. 19 | 20 | Parameters: 21 | - args: Configuration settings including hyperparameters for the network setup. 22 | """ 23 | super().__init__(args) 24 | self.args = args 25 | # Instantiate the INSTA model with specific parameters. 26 | self.INSTA = INSTA(640, 5, 0.2, 3, args=args) 27 | 28 | def inner_loop(self, proto, support): 29 | """ 30 | Performs an inner optimization loop to fine-tune prototypes on support sets during meta-training. 31 | 32 | Parameters: 33 | - proto: Initial prototypes, typically the mean of the support embeddings. 34 | - support: Support set embeddings used for fine-tuning the prototypes. 35 | 36 | Returns: 37 | - SFC: Updated (fine-tuned) prototypes. 38 | """ 39 | # Clone and detach prototypes to prevent gradients from accumulating across episodes. 40 | SFC = proto.clone().detach() 41 | SFC = nn.Parameter(SFC, requires_grad=True) 42 | 43 | # Initialize an SGD optimizer specifically for this inner loop. 44 | optimizer = torch.optim.SGD([SFC], lr=0.6, momentum=0.9, dampening=0.9, weight_decay=0) 45 | 46 | # Create labels for the support set, used in cross-entropy loss during fine-tuning. 47 | label_shot = torch.arange(self.args.way).repeat(self.args.shot) 48 | label_shot = label_shot.type(torch.cuda.LongTensor) 49 | 50 | # Perform gradient steps to update the prototypes. 51 | with torch.enable_grad(): 52 | for k in range(50): # Number of gradient steps. 53 | rand_id = torch.randperm(self.args.way * self.args.shot).cuda() 54 | for j in range(0, self.args.way * self.args.shot, 4): 55 | selected_id = rand_id[j: min(j + 4, self.args.way * self.args.shot)] 56 | batch_shot = support[selected_id, :] 57 | batch_label = label_shot[selected_id] 58 | optimizer.zero_grad() 59 | logits = self.classifier(batch_shot.detach(), SFC) 60 | if logits.dim() == 1: 61 | logits = logits.unsqueeze(0) 62 | loss = F.cross_entropy(logits, batch_label) 63 | loss.backward() 64 | optimizer.step() 65 | return SFC 66 | 67 | def classifier(self, query, proto): 68 | """ 69 | Simple classifier that computes the negative squared Euclidean distance between query and prototype vectors, 70 | scaled by a temperature parameter for controlling the sharpness of the distribution. 71 | 72 | Parameters: 73 | - query: Query set embeddings. 74 | - proto: Prototype vectors. 75 | 76 | Returns: 77 | - logits: Logits representing similarity scores between each query and each prototype. 78 | """ 79 | logits = -torch.sum((proto.unsqueeze(0) - query.unsqueeze(1)) ** 2, 2) / self.args.temperature 80 | return logits.squeeze() 81 | 82 | def _forward(self, instance_embs, support_idx, query_idx): 83 | """ 84 | Forward pass of the model, processing both support and query data. 
85 | 86 | Parameters: 87 | - instance_embs: Embeddings of all instances. 88 | - support_idx: Indices identifying support instances. 89 | - query_idx: Indices identifying query instances. 90 | 91 | Implements the forward pass, integrating both spatial and feature adaptation using the INSTA module. 92 | """ 93 | emb_dim = instance_embs.size()[-3:] 94 | channel_dim = emb_dim[0] 95 | 96 | # Organize support and query data based on indices, and reshape accordingly. 97 | support = instance_embs[support_idx.flatten()].view(*(support_idx.shape + emb_dim)) 98 | query = instance_embs[query_idx.flatten()].view(*(query_idx.shape + emb_dim)) 99 | num_samples = support.shape[1] 100 | num_proto = support.shape[2] 101 | support = support.squeeze() 102 | 103 | # Adapt support features using the INSTA model and average to form adapted prototypes. 104 | adapted_s, task_kernel = self.INSTA(support.view(-1, *emb_dim)) 105 | query = query.view(-1, *emb_dim) 106 | adapted_proto = adapted_s.view(num_samples, -1, *adapted_s.shape[1:]).mean(0) 107 | adapted_proto = nn.AdaptiveAvgPool2d(1)(adapted_proto).squeeze(-1).squeeze(-1) 108 | 109 | # Adapt query features using the INSTA unfolding and kernel multiplication approach. 110 | query_ = nn.AdaptiveAvgPool2d(1)((self.INSTA.unfold(query, int((task_kernel.shape[-1]+1)/2-1), task_kernel.shape[-1]) * task_kernel)).squeeze() 111 | query = query + query_ 112 | adapted_q = nn.AdaptiveAvgPool2d(1)(query).squeeze(-1).squeeze(-1) 113 | 114 | # Optionally perform an inner loop optimization during testing. 115 | if self.args.testing: 116 | adapted_proto = self.inner_loop(adapted_proto, nn.AdaptiveAvgPool2d(1)(support).squeeze().view(num_proto*num_samples, channel_dim)) 117 | 118 | # Classify using the adapted prototypes and query embeddings. 
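        # (Editorial comment, added for clarity; not part of the original source.)
        # classifier() scores each query by its negative squared Euclidean distance
        # to every adapted prototype, divided by args.temperature, so a larger
        # temperature yields a softer distribution over the `way` classes.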
119 | logits = self.classifier(adapted_q, adapted_proto) 120 | 121 | if self.training: 122 | reg_logits = None 123 | return logits, reg_logits 124 | else: 125 | return logits 126 | -------------------------------------------------------------------------------- /model/trainer/helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.optim as optim 5 | from torch.utils.data import DataLoader 6 | from model.dataloader.samplers import CategoriesSampler 7 | from model.models.protonet import ProtoNet 8 | from model.models.INSTA_ProtoNet import INSTA_ProtoNet 9 | 10 | class MultiGPUDataloader: 11 | def __init__(self, dataloader, num_device): 12 | self.dataloader = dataloader 13 | self.num_device = num_device 14 | 15 | def __len__(self): 16 | return len(self.dataloader) // self.num_device 17 | 18 | def __iter__(self): 19 | data_iter = iter(self.dataloader) 20 | done = False 21 | 22 | while not done: 23 | try: 24 | output_batch = ([], []) 25 | for _ in range(self.num_device): 26 | batch = next(data_iter) 27 | for i, v in enumerate(batch): 28 | output_batch[i].append(v[None]) 29 | 30 | yield ( torch.cat(_, dim=0) for _ in output_batch ) 31 | except StopIteration: 32 | done = True 33 | return 34 | 35 | def get_dataloader(args): 36 | 37 | if args.dataset == 'CUB': 38 | from model.dataloader.cub import CUB as Dataset 39 | elif args.dataset == 'TieredImageNet': 40 | from model.dataloader.tiered_imagenet import tieredImageNet as Dataset 41 | elif args.dataset == 'MiniImageNet': 42 | from model.dataloader.mini_imagenet import MiniImageNet as Dataset 43 | else: 44 | raise ValueError('Non-supported Dataset.') 45 | 46 | num_device = torch.cuda.device_count() 47 | num_episodes = args.episodes_per_epoch*num_device if args.multi_gpu else args.episodes_per_epoch 48 | num_workers=args.num_workers*num_device if args.multi_gpu else args.num_workers 49 | trainset = Dataset('train', args, augment=args.augment) 50 | args.num_class = trainset.num_class 51 | train_sampler = CategoriesSampler(trainset.label, 52 | num_episodes, 53 | max(args.way, args.num_classes), 54 | args.shot + args.query) 55 | 56 | train_loader = DataLoader(dataset=trainset, 57 | num_workers=num_workers, 58 | batch_sampler=train_sampler, 59 | pin_memory=True) 60 | 61 | 62 | valset = Dataset('val', args) 63 | val_sampler = CategoriesSampler(valset.label, 64 | args.num_eval_episodes, 65 | args.eval_way, args.eval_shot + args.eval_query) 66 | val_loader = DataLoader(dataset=valset, 67 | batch_sampler=val_sampler, 68 | num_workers=args.num_workers, 69 | pin_memory=True) 70 | 71 | 72 | testset = Dataset('test', args) 73 | test_sampler = CategoriesSampler(testset.label, 74 | 600, # args.num_eval_episodes, 75 | args.eval_way, args.eval_shot + args.eval_query) 76 | test_loader = DataLoader(dataset=testset, 77 | batch_sampler=test_sampler, 78 | num_workers=args.num_workers, 79 | pin_memory=True) 80 | 81 | return train_loader, val_loader, test_loader 82 | 83 | def prepare_model(args): 84 | model = eval(args.model_class)(args) 85 | 86 | # load pre-trained model (no FC weights) 87 | if args.init_weights is not None: 88 | model_dict = model.state_dict() 89 | pretrained_dict = torch.load(args.init_weights)['params'] 90 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 91 | print(pretrained_dict.keys()) 92 | model_dict.update(pretrained_dict) 93 | model.load_state_dict(model_dict) 94 | 95 | if 
torch.cuda.is_available(): 96 | torch.backends.cudnn.benchmark = True 97 | 98 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 99 | model = model.to(device) 100 | if args.multi_gpu: 101 | model.encoder = nn.DataParallel(model.encoder, dim=0) 102 | para_model = model.to(device) 103 | else: 104 | para_model = model.to(device) 105 | 106 | return model, para_model 107 | 108 | def prepare_optimizer(model, args): 109 | top_para = [v for k, v in model.named_parameters() if 'encoder' not in k] 110 | if args.use_AdamW: 111 | optimizer = optim.AdamW( 112 | [{'params': model.encoder.parameters()}, 113 | {'params': top_para, 'lr': args.lr * args.lr_mul}], 114 | lr=args.lr 115 | 116 | ) 117 | 118 | 119 | else: 120 | optimizer = optim.SGD( 121 | [{'params': model.encoder.parameters()}, 122 | {'params': top_para, 'lr': args.lr * args.lr_mul}], 123 | lr=args.lr, 124 | momentum=args.mom, 125 | nesterov=True, 126 | weight_decay=args.weight_decay 127 | ) 128 | 129 | 130 | 131 | if args.lr_scheduler == 'step': 132 | lr_scheduler = optim.lr_scheduler.StepLR( 133 | optimizer, 134 | step_size=int(args.step_size), 135 | gamma=args.gamma 136 | ) 137 | elif args.lr_scheduler == 'multistep': 138 | lr_scheduler = optim.lr_scheduler.MultiStepLR( 139 | optimizer, 140 | milestones=[int(_) for _ in args.step_size.split(',')], 141 | gamma=args.gamma, 142 | ) 143 | elif args.lr_scheduler == 'cosine': 144 | lr_scheduler = optim.lr_scheduler.CosineAnnealingLR( 145 | optimizer, 146 | args.max_epoch, 147 | eta_min=0 148 | ) 149 | else: 150 | raise ValueError('No Such Scheduler') 151 | 152 | return optimizer, lr_scheduler -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # INSTA: Learning Instance and Task-Aware Dynamic Kernels for Few Shot Learning 2 | 3 |

4 | 5 |

6 | 7 | If this repository is helpful to you, please cite the following BibTeX entry: 8 | ```bibtex 9 | @article{ma2021learning, 10 | title={Learning Instance and Task-Aware Dynamic Kernels for Few Shot Learning}, 11 | author={Ma, Rongkai and Fang, Pengfei and Avraham, Gil and Zuo, Yan and Drummond, Tom and Harandi, Mehrtash}, 12 | journal={arXiv preprint arXiv:2112.03494}, 13 | year={2021} 14 | } 15 | ``` 16 | This repository provides the implementation and demo of [**Learning Instance and Task-Aware Dynamic Kernels for Few Shot Learning**](https://arxiv.org/abs/2112.03494) on [Prototypical Network](https://arxiv.org/pdf/1703.05175.pdf). The dynamic environment of few-shot learning (FSL) requires a model capable of rapidly adapting to novel tasks. Moreover, given the low-data regime of FSL, the model must encode rich information from every data sample. To tackle this problem, we propose to learn a dynamic kernel that is both **ins**tance and **t**ask-**a**ware (**INSTA**) for each channel and spatial location of a feature map, given the task (episode) at hand. Beyond that, we further incorporate information from the frequency domain to generate our dynamic kernel. 17 |
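At a glance, the adaptation implemented in `model/models/INSTA.py` combines a per-sample (instance) kernel with an episode-level (task) kernel and applies the result to the local neighbourhood of every feature-map position. The shape-annotated pseudo-code below is a simplified reading aid for that forward pass, not runnable code: helper names such as `make_kernel` are informal stand-ins, and the shapes assume the Res12 setting used in this repository (`c=640`, `h=w=5`, `k=3`).

```python
# Simplified sketch of INSTA's forward pass (reading aid only).
spatial_k  = spatial_kernel_network(x)      # (N, h*w, k*k): one k x k kernel per spatial location
channel_k  = channel_kernel_network(x)      # (N, c,   k*k): one k x k kernel per channel, from DCT attention
instance_k = (spatial_k.unsqueeze(1) * channel_k.unsqueeze(2)).view(N, c, h, w, k, k)

task_rep = CLM(x)                           # (1, c, h, w): episode-level representation of the support features
task_k   = make_kernel(task_rep)            # (1, c, h, w, k, k): the same two branches applied to the task representation

kernel  = task_k * instance_k               # the instance- and task-aware dynamic kernel
patches = unfold(x, k)                      # (N, c, h, w, k, k): k x k neighbourhood around every position
adapted = (patches * kernel).mean(dim=(-1, -2)) + x   # residual adaptation of the feature map
```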

18 | 19 |

20 | 21 | ## Prerequisites 22 | We use Anaconda to manage the virtual environment. Please install the following packages to run this repository. If a "No module" error occurs, please install the missing package indicated by the error message. 23 | * python 3.8 24 | * [pytorch 1.7.0](https://pytorch.org/get-started/previous-versions/) 25 | * torchvision 0.8.0 26 | * torchaudio 0.7.0 27 | * tqdm 28 | * tensorboardX 29 | 30 | ## Dataset 31 | 32 | ### Tiered-ImageNet 33 | 34 | Tiered-ImageNet is a subset of ImageNet. It consists of 608 classes from 34 categories and is split into 351 classes from 20 categories for training, 97 classes from 6 categories for validation, and 160 classes from 8 categories for testing. You can download the processed dataset from this [repository](https://github.com/icoz69/DeepEMD). Once the dataset is downloaded, please move it to the /data directory. Note that the images have been resized to 84x84. 35 | 36 | ### Mini-ImageNet 37 | ```Shell 38 | ├── data 39 | ├── Mini-ImageNet 40 | ├── split 41 | ├── train 42 | ├── validation 43 | ├── test 44 | ├── images 45 | ├── im_0.jpg 46 | ├── im_1.jpg 47 | . 48 | . 49 | . 50 | ├── im_n.jpg 51 | ``` 52 | 53 | Mini-ImageNet is sampled from ImageNet. This dataset has 100 classes, each with 600 samples. We follow the standard protocol and split the dataset into 64 training, 16 validation, and 20 testing classes. For the corresponding split and data files, please refer to [this repository](https://github.com/Sha-Lab/FEAT). 54 | 55 | ### CUB 56 | 57 | CUB is a fine-grained dataset consisting of 11,788 images from 200 bird species. We follow the standard setting, in which the dataset is split into 100/50/50 species for training, validation, and testing, respectively. For the ResNet-12 backbone, please refer to [this repository](https://github.com/icoz69/DeepEMD) for the dataset split, and for the ResNet-18 backbone, please refer to [this repository](https://github.com/imtiazziko/LaplacianShot). 58 | 59 | ### FC100 60 | 61 | The FC100 dataset is a variant of the standard CIFAR-100 dataset, which contains images from 100 classes with 600 samples per class. We follow the standard setting, where the dataset is split into 60/20/20 classes for training, validation, and testing, respectively. For downloading and splitting the data, please refer to the [DeepEMD repository](https://github.com/icoz69/DeepEMD).
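For tiered-ImageNet in particular, the loader in `model/dataloader/tiered_imagenet.py` reads pre-packed image archives and label pickles directly from `data/tieredimagenet`, so after downloading you should end up with a layout like the sketch below (the file names are taken from the loader; adjust the folder name if your copy of the dataset differs):

```Shell
├── data
    ├── tieredimagenet
        ├── train_images.npz
        ├── train_labels.pkl
        ├── val_images.npz
        ├── val_labels.pkl
        ├── test_images.npz
        ├── test_labels.pkl
```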
62 | 63 | ## Training 64 | 65 | We provide example command lines for Tiered-ImageNet and Mini-ImageNet below:
66 | ```shell 67 | $ python train_fsl.py --max_epoch 200 --model_class INSTA_ProtoNet --backbone_class Res12 --dataset TieredImageNet --way 5 --eval_way 5 --shot 5 --eval_shot 5 --query 15 --eval_query 15 --temperature 32 --temperature2 64 --lr 0.0002 --lr_mul 100 --lr_scheduler cosine --gamma 0.5 --gpu 1 --init_weights ./saves/initialization/tieredimagenet/Res12-pre.pth --eval_interval 1 --use_euclidean 68 | ```
69 | ```shell 70 | $ python train_fsl.py --max_epoch 200 --model_class INSTA_ProtoNet --backbone_class Res12 --dataset TieredImageNet --way 5 --eval_way 5 --shot 1 --eval_shot 1 --query 15 --eval_query 15 --temperature 64 --temperature2 64 --lr 0.0002 --lr_mul 30 --lr_scheduler cosine --gamma 0.5 --gpu 0 --init_weights ./saves/initialization/tieredimagenet/Res12-pre.pth --eval_interval 1 --use_euclidean 71 | ```
72 | ```shell 73 | $ python train_fsl.py --max_epoch 200 --model_class INSTA_ProtoNet --backbone_class Res12 --dataset MiniImageNet --way 5 --eval_way 5 --shot 1 --eval_shot 1 --query 15 --eval_query 15 --temperature 64 --temperature2 64 --lr 0.0002 --lr_mul 25 --lr_scheduler cosine --gamma 0.5 --gpu 0 --init_weights ./saves/initialization/miniimagenet/Res12-pre.pth --eval_interval 1 --use_euclidean 74 | ```
75 | ```shell 76 | $ python train_fsl.py --max_epoch 200 --model_class INSTA_ProtoNet --backbone_class Res12 --dataset MiniImageNet --way 5 --eval_way 5 --shot 5 --eval_shot 5 --query 15 --eval_query 15 --balance_1 1 --temperature 24 --temperature2 32 --lr 0.0002 --lr_mul 25 --lr_scheduler cosine --gamma 0.5 --gpu 0 --init_weights ./saves/initialization/miniimagenet/Res12-pre.pth --eval_interval 1 --use_euclidean 77 | ```
78 | ## To Do 79 | * 80 | * 81 | 82 | ## Acknowledgements 83 | We acknowledge the following repositories, which provided valuable insights for the construction of our code: 84 | 85 | * [FEAT](https://github.com/Sha-Lab/FEAT) 86 | * [DeepEMD](https://github.com/icoz69/DeepEMD) 87 | * [Chen *et al.*](https://github.com/wyharveychen/CloserLookFewShot) 88 | * [DeepBDC](https://github.com/Fei-Long121/DeepBDC) 89 | * [DDF](https://github.com/theFoxofSky/ddfnet) 90 | * [FCANet](https://github.com/cfzd/FcaNet) 91 | * [Fan *et al.*](https://github.com/fanq15/FSOD-code) 92 | * [simple-cnaps](https://github.com/peymanbateni/simple-cnaps) 93 | -------------------------------------------------------------------------------- /model/networks/res18.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | __all__ = ['resnet10', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 4 | 'resnet152'] 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | """3x3 convolution with padding""" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1, bias=False) 11 | 12 | 13 | def conv1x1(in_planes, out_planes, stride=1): 14 | """1x1 convolution""" 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 16 | 17 | 18 | class BasicBlock(nn.Module): 19 | expansion = 1 20 | 21 | def __init__(self, inplanes, planes, stride=1, downsample=None): 22 | super(BasicBlock, self).__init__() 23 | self.conv1 = conv3x3(inplanes, planes, stride) 24 | self.bn1 = nn.BatchNorm2d(planes) 25 | self.relu = nn.ReLU(inplace=True) 26 | self.conv2 = conv3x3(planes, planes) 27 | self.bn2 = nn.BatchNorm2d(planes) 28 | self.downsample = downsample 29 | self.stride = stride 30 | 31 | def
forward(self, x): 32 | identity = x 33 | 34 | out = self.conv1(x) 35 | out = self.bn1(out) 36 | out = self.relu(out) 37 | 38 | out = self.conv2(out) 39 | out = self.bn2(out) 40 | 41 | if self.downsample is not None: 42 | identity = self.downsample(x) 43 | 44 | out += identity 45 | out = self.relu(out) 46 | 47 | return out 48 | 49 | 50 | class Bottleneck(nn.Module): 51 | expansion = 4 52 | 53 | def __init__(self, inplanes, planes, stride=1, downsample=None): 54 | super(Bottleneck, self).__init__() 55 | self.conv1 = conv1x1(inplanes, planes) 56 | self.bn1 = nn.BatchNorm2d(planes) 57 | self.conv2 = conv3x3(planes, planes, stride) 58 | self.bn2 = nn.BatchNorm2d(planes) 59 | self.conv3 = conv1x1(planes, planes * self.expansion) 60 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 61 | self.relu = nn.ReLU(inplace=True) 62 | self.downsample = downsample 63 | self.stride = stride 64 | 65 | def forward(self, x): 66 | identity = x 67 | 68 | out = self.conv1(x) 69 | out = self.bn1(out) 70 | out = self.relu(out) 71 | 72 | out = self.conv2(out) 73 | out = self.bn2(out) 74 | out = self.relu(out) 75 | 76 | out = self.conv3(out) 77 | out = self.bn3(out) 78 | 79 | if self.downsample is not None: 80 | identity = self.downsample(x) 81 | 82 | out += identity 83 | out = self.relu(out) 84 | 85 | return out 86 | 87 | 88 | class ResNet(nn.Module): 89 | 90 | def __init__(self, block=BasicBlock, layers=[2, 2, 2, 2], zero_init_residual=False): 91 | super(ResNet, self).__init__() 92 | self.inplanes = 64 93 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, 94 | bias=False) 95 | self.bn1 = nn.BatchNorm2d(64) 96 | self.relu = nn.ReLU(inplace=True) 97 | self.layer1 = self._make_layer(block, 64, layers[0]) 98 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 99 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 100 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 101 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 102 | 103 | for m in self.modules(): 104 | if isinstance(m, nn.Conv2d): 105 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 106 | elif isinstance(m, nn.BatchNorm2d): 107 | nn.init.constant_(m.weight, 1) 108 | nn.init.constant_(m.bias, 0) 109 | 110 | # Zero-initialize the last BN in each residual branch, 111 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
112 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 113 | if zero_init_residual: 114 | for m in self.modules(): 115 | if isinstance(m, Bottleneck): 116 | nn.init.constant_(m.bn3.weight, 0) 117 | elif isinstance(m, BasicBlock): 118 | nn.init.constant_(m.bn2.weight, 0) 119 | 120 | def _make_layer(self, block, planes, blocks, stride=1): 121 | downsample = None 122 | if stride != 1 or self.inplanes != planes * block.expansion: 123 | downsample = nn.Sequential( 124 | conv1x1(self.inplanes, planes * block.expansion, stride), 125 | nn.BatchNorm2d(planes * block.expansion), 126 | ) 127 | 128 | layers = [] 129 | layers.append(block(self.inplanes, planes, stride, downsample)) 130 | self.inplanes = planes * block.expansion 131 | for _ in range(1, blocks): 132 | layers.append(block(self.inplanes, planes)) 133 | 134 | return nn.Sequential(*layers) 135 | 136 | def forward(self, x): 137 | x = self.conv1(x) 138 | x = self.bn1(x) 139 | x = self.relu(x) 140 | 141 | x = self.layer1(x) 142 | x = self.layer2(x) 143 | x = self.layer3(x) 144 | x = self.layer4(x) 145 | 146 | # x = self.avgpool(x) 147 | # x = x.view(x.size(0), -1) 148 | 149 | return x 150 | 151 | 152 | def resnet10(**kwargs): 153 | """Constructs a ResNet-10 model. 154 | """ 155 | model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs) 156 | return model 157 | 158 | 159 | def resnet18(**kwargs): 160 | """Constructs a ResNet-18 model. 161 | """ 162 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 163 | return model 164 | 165 | 166 | def resnet34(**kwargs): 167 | """Constructs a ResNet-34 model. 168 | """ 169 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 170 | return model 171 | 172 | 173 | def resnet50(**kwargs): 174 | """Constructs a ResNet-50 model. 175 | """ 176 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 177 | return model 178 | 179 | 180 | def resnet101(**kwargs): 181 | """Constructs a ResNet-101 model. 182 | """ 183 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 184 | return model 185 | 186 | 187 | def resnet152(**kwargs): 188 | """Constructs a ResNet-152 model. 
189 | """ 190 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 191 | return model -------------------------------------------------------------------------------- /model/trainer/fsl_trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os.path as osp 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from model.trainer.base import Trainer 8 | from model.trainer.helpers import ( 9 | get_dataloader, prepare_model, prepare_optimizer, 10 | ) 11 | from model.utils import ( 12 | pprint, ensure_path, 13 | Averager, Timer, count_acc, 14 | compute_confidence_interval 15 | ) 16 | 17 | from tqdm import tqdm 18 | 19 | 20 | class FSLTrainer(Trainer): 21 | def __init__(self, args): 22 | super().__init__(args) 23 | 24 | self.train_loader, self.val_loader, self.test_loader = get_dataloader(args) 25 | self.model, self.para_model = prepare_model(args) 26 | self.optimizer, self.lr_scheduler = prepare_optimizer(self.model, args) 27 | 28 | def prepare_label(self): 29 | args = self.args 30 | 31 | # prepare one-hot label 32 | label = torch.arange(args.way, dtype=torch.int16).repeat(args.query) 33 | label_aux = torch.arange(args.way, dtype=torch.int8).repeat(args.shot + args.query) 34 | 35 | label = label.type(torch.LongTensor) 36 | label_aux = label_aux.type(torch.LongTensor) 37 | 38 | if torch.cuda.is_available(): 39 | label = label.cuda() 40 | label_aux = label_aux.cuda() 41 | 42 | return label, label_aux 43 | 44 | def train(self): 45 | args = self.args 46 | self.model.train() 47 | if self.args.fix_BN: 48 | self.model.encoder.eval() 49 | 50 | # start FSL training 51 | label, label_aux = self.prepare_label() 52 | for epoch in range(1, args.max_epoch + 1): 53 | self.train_epoch += 1 54 | self.model.train() 55 | if self.args.fix_BN: 56 | self.model.encoder.eval() 57 | 58 | tl1 = Averager() 59 | tl2 = Averager() 60 | ta = Averager() 61 | 62 | start_tm = time.time() 63 | for batch in self.train_loader: 64 | self.train_step += 1 65 | 66 | if torch.cuda.is_available(): 67 | data, gt_label = [_.cuda() for _ in batch] 68 | else: 69 | data, gt_label = batch[0], batch[1] 70 | 71 | data_tm = time.time() 72 | self.dt.add(data_tm - start_tm) 73 | 74 | # get saved centers 75 | logits, reg_logits = self.para_model(data) 76 | 77 | if reg_logits is not None: 78 | loss = F.cross_entropy(logits, label) 79 | total_loss = args.balance_1*loss + args.balance_2 * F.cross_entropy(reg_logits, label_aux) 80 | 81 | 82 | else: 83 | loss = F.cross_entropy(logits, label) 84 | total_loss = F.cross_entropy(logits, label) 85 | 86 | tl2.add(loss) 87 | forward_tm = time.time() 88 | self.ft.add(forward_tm - data_tm) 89 | acc = count_acc(logits, label) 90 | tl1.add(total_loss.item()) 91 | ta.add(acc) 92 | 93 | self.optimizer.zero_grad() 94 | total_loss.backward() 95 | backward_tm = time.time() 96 | self.bt.add(backward_tm - forward_tm) 97 | 98 | self.optimizer.step() 99 | optimizer_tm = time.time() 100 | self.ot.add(optimizer_tm - backward_tm) 101 | 102 | # refresh start_tm 103 | start_tm = time.time() 104 | 105 | self.lr_scheduler.step() 106 | self.try_evaluate(epoch) 107 | 108 | print('ETA:{}/{}'.format( 109 | self.timer.measure(), 110 | self.timer.measure(self.train_epoch / args.max_epoch)) 111 | ) 112 | 113 | torch.save(self.trlog, osp.join(args.save_path, 'trlog')) 114 | self.save_model('epoch-last') 115 | 116 | def evaluate(self, data_loader): 117 | # restore model args 118 | args = self.args 119 | # evaluation mode 120 | self.model.eval() 
121 | record = np.zeros((args.num_eval_episodes, 2)) # loss and acc 122 | label = torch.arange(args.eval_way, dtype=torch.int16).repeat(args.eval_query) 123 | label = label.type(torch.LongTensor) 124 | if torch.cuda.is_available(): 125 | label = label.cuda() 126 | print('best epoch {}, best val acc={:.4f} + {:.4f}'.format( 127 | self.trlog['max_acc_epoch'], 128 | self.trlog['max_acc'], 129 | self.trlog['max_acc_interval'])) 130 | with torch.no_grad(): 131 | for i, batch in enumerate(data_loader, 1): 132 | if torch.cuda.is_available(): 133 | data, _ = [_.cuda() for _ in batch] 134 | else: 135 | data = batch[0] 136 | 137 | logits = self.model(data) 138 | loss = F.cross_entropy(logits, label) 139 | acc = count_acc(logits, label) 140 | record[i-1, 0] = loss.item() 141 | record[i-1, 1] = acc 142 | 143 | assert(i == record.shape[0]) 144 | vl, _ = compute_confidence_interval(record[:,0]) 145 | va, vap = compute_confidence_interval(record[:,1]) 146 | 147 | # train mode 148 | self.model.train() 149 | if self.args.fix_BN: 150 | self.model.encoder.eval() 151 | 152 | return vl, va, vap 153 | 154 | 155 | def evaluate_test(self): 156 | # restore model args 157 | args = self.args 158 | self.args.testing = True 159 | self.model.load_state_dict(torch.load(osp.join(self.args.save_path, 'max_acc.pth'))['params']) 160 | self.model.eval() 161 | record = np.zeros((600, 2)) # loss and acc 162 | label = torch.arange(args.eval_way, dtype=torch.int16).repeat(args.eval_query) 163 | label = label.type(torch.LongTensor) 164 | if torch.cuda.is_available(): 165 | label = label.cuda() 166 | print('best epoch {}, best val acc={:.4f} + {:.4f}'.format( 167 | self.trlog['max_acc_epoch'], 168 | self.trlog['max_acc'], 169 | self.trlog['max_acc_interval'])) 170 | with torch.no_grad(): 171 | for i, batch in tqdm(enumerate(self.test_loader, 1)): 172 | if torch.cuda.is_available(): 173 | data, _ = [_.cuda() for _ in batch] 174 | else: 175 | data = batch[0] 176 | 177 | logits = self.model(data) 178 | 179 | loss = F.cross_entropy(logits, label) 180 | acc = count_acc(logits, label) 181 | record[i-1, 0] = loss.item() 182 | record[i-1, 1] = acc 183 | assert(i == record.shape[0]) 184 | vl, _ = compute_confidence_interval(record[:,0]) 185 | va, vap = compute_confidence_interval(record[:,1]) 186 | 187 | self.trlog['test_acc'] = va 188 | self.trlog['test_acc_interval'] = vap 189 | self.trlog['test_loss'] = vl 190 | 191 | print('best epoch {}, best val acc={:.4f} + {:.4f}\n'.format( 192 | self.trlog['max_acc_epoch'], 193 | self.trlog['max_acc'], 194 | self.trlog['max_acc_interval'])) 195 | print('Test acc={:.4f} + {:.4f}\n'.format( 196 | self.trlog['test_acc'], 197 | self.trlog['test_acc_interval'])) 198 | 199 | return vl, va, vap 200 | 201 | def final_record(self): 202 | # save the best performance in a txt file 203 | 204 | with open(osp.join(self.args.save_path, '{}+{}'.format(self.trlog['test_acc'], self.trlog['test_acc_interval'])), 'w') as f: 205 | f.write('best epoch {}, best val acc={:.4f} + {:.4f}\n'.format( 206 | self.trlog['max_acc_epoch'], 207 | self.trlog['max_acc'], 208 | self.trlog['max_acc_interval'])) 209 | f.write('Test acc={:.4f} + {:.4f}\n'.format( 210 | self.trlog['test_acc'], 211 | self.trlog['test_acc_interval'])) -------------------------------------------------------------------------------- /model/models/INSTA.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | from model.models.fcanet import 
MultiSpectralAttentionLayer 5 | 6 | """ 7 | The INSTA class inherits from nn.Module and implements an attention mechanism 8 | that involves both channel and spatial features. It's designed to work with feature maps 9 | and applies both a channel attention and a learned convolutional kernel for spatial attention. 10 | """ 11 | 12 | class INSTA(nn.Module): 13 | def __init__(self, c, spatial_size, sigma, k, args): 14 | """ 15 | Initialize the INSTA network module. 16 | 17 | Parameters: 18 | - c: Number of channels in the input feature map. 19 | - spatial_size: The height and width of the input feature map. 20 | - sigma: A parameter possibly used for normalization or a scale parameter in attention mechanisms. 21 | - k: Kernel size for convolution operations and spatial attention. 22 | - args: Additional arguments for setup, possibly including hyperparameters or configuration options. 23 | """ 24 | super().__init__() 25 | self.channel = c 26 | self.h1 = sigma 27 | self.h2 = k **2 28 | self.k = k 29 | # Standard 2D convolution for channel reduction or transformation. 30 | self.conv = nn.Conv2d(self.channel, self.h2, 1) 31 | # Batch normalization for the output of the spatial attention. 32 | self.fn_spatial = nn.BatchNorm2d(spatial_size**2) 33 | # Batch normalization for the output of the channel attention. 34 | self.fn_channel = nn.BatchNorm2d(self.channel) 35 | # Unfold operation for transforming feature map into patches. 36 | self.Unfold = nn.Unfold(kernel_size=self.k, padding=int((self.k+1)/2-1)) 37 | self.spatial_size = spatial_size 38 | # Dictionary mapping channel numbers to width/height for MultiSpectralAttentionLayer. 39 | c2wh = dict([(512, 11), (640, self.spatial_size)]) 40 | # MultiSpectralAttentionLayer for performing attention across spectral (frequency) components. 41 | self.channel_att = MultiSpectralAttentionLayer(c, c2wh[c], c2wh[c], sigma=self.h1, k=self.k, freq_sel_method='low16') 42 | self.args = args 43 | # Upper part of a Coordinate Learning Module (CLM), which modifies feature maps. 44 | self.CLM_upper = nn.Sequential( 45 | nn.Conv2d(c, c*2, 1), 46 | nn.BatchNorm2d(c*2), 47 | nn.ReLU(), 48 | nn.Conv2d(c*2, c*2, 1), 49 | nn.BatchNorm2d(c * 2), 50 | nn.ReLU() 51 | ) 52 | 53 | # Lower part of CLM, transforming the features back to original channel dimensions and applying sigmoid. 54 | self.CLM_lower = nn.Sequential( 55 | nn.Conv2d(c*2, c*2, 1), 56 | nn.BatchNorm2d(c*2), 57 | nn.ReLU(), 58 | nn.Conv2d(c*2, c, 1), 59 | nn.BatchNorm2d(c), 60 | nn.Sigmoid() # Sigmoid activation to normalize the feature values between 0 and 1. 61 | ) 62 | 63 | def CLM(self, featuremap): 64 | """ 65 | The Coordinate Learning Module (CLM) that processes feature maps to adapt them spatially. 66 | 67 | Parameters: 68 | - featuremap: The input feature map to the CLM. 69 | 70 | Returns: 71 | - The adapted feature map processed through the CLM. 72 | """ 73 | # Apply the upper CLM to modify and then aggregate features. 74 | adap = self.CLM_upper(featuremap) 75 | intermediate = adap.sum(dim=0) # Summing features across the batch dimension. 76 | adap_1 = self.CLM_lower(intermediate.unsqueeze(0)) # Applying the lower CLM. 77 | return adap_1 78 | 79 | def spatial_kernel_network(self, feature_map, conv): 80 | """ 81 | Applies a convolution to the feature map to generate a spatial kernel, 82 | which will be used to modulate the spatial regions of the input features. 83 | 84 | Parameters: 85 | - feature_map: The feature map to process. 86 | - conv: The convolutional layer to apply. 
87 | 88 | Returns: 89 | - The processed spatial kernel. 90 | """ 91 | spatial_kernel = conv(feature_map) 92 | spatial_kernel = spatial_kernel.flatten(-2).transpose(-1, -2) 93 | size = spatial_kernel.size() 94 | spatial_kernel = spatial_kernel.view(size[0], -1, self.k, self.k) 95 | spatial_kernel = self.fn_spatial(spatial_kernel) 96 | 97 | spatial_kernel = spatial_kernel.flatten(-2) 98 | return spatial_kernel 99 | 100 | def channel_kernel_network(self, feature_map): 101 | """ 102 | Processes the feature map through a channel attention mechanism to modulate the channels 103 | based on their importance. 104 | 105 | Parameters: 106 | - feature_map: The feature map to process. 107 | 108 | Returns: 109 | - The channel-modulated feature map. 110 | """ 111 | channel_kernel = self.channel_att(feature_map) 112 | channel_kernel = self.fn_channel(channel_kernel) 113 | channel_kernel = channel_kernel.flatten(-2) 114 | channel_kernel = channel_kernel.squeeze().view(channel_kernel.shape[0], self.channel, -1) 115 | return channel_kernel 116 | 117 | def unfold(self, x, padding, k): 118 | """ 119 | A manual implementation of the unfold operation, which extracts sliding local blocks from a batched input tensor. 120 | 121 | Parameters: 122 | - x: The input tensor. 123 | - padding: Padding to apply to the tensor. 124 | - k: Kernel size for the blocks to extract. 125 | 126 | Returns: 127 | - The unfolded tensor containing all local blocks. 128 | """ 129 | x_padded = torch.cuda.FloatTensor(x.shape[0], x.shape[1], x.shape[2] + 2 * padding, x.shape[3] + 2 * padding).fill_(0) 130 | x_padded[:, :, padding:-padding, padding:-padding] = x 131 | x_unfolded = torch.cuda.FloatTensor(*x.shape, k, k).fill_(0) 132 | for i in range(int((self.k+1)/2-1), x.shape[2] + int((self.k+1)/2-1)): 133 | for j in range(int((self.k+1)/2-1), x.shape[3] + int((self.k+1)/2-1)): 134 | x_unfolded[:, :, i - int(((self.k+1)/2-1)), j - int(((self.k+1)/2-1)), :, :] = x_padded[:, :, i-int(((self.k+1)/2-1)):i + int((self.k+1)/2), j - int(((self.k+1)/2-1)):j + int(((self.k+1)/2))] 135 | return x_unfolded 136 | 137 | def forward(self, x): 138 | """ 139 | The forward method of INSTA, which combines the spatial and channel kernels to adapt the feature map, 140 | along with performing the unfolding operation to facilitate local receptive processing. 141 | 142 | Parameters: 143 | - x: The input tensor to the network. 144 | 145 | Returns: 146 | - The adapted feature map and the task-specific kernel used for adaptation. 
147 | """ 148 | spatial_kernel = self.spatial_kernel_network(x, self.conv).unsqueeze(-3) 149 | channel_kernel = self.channel_kernel_network(x).unsqueeze(-2) 150 | kernel = spatial_kernel * channel_kernel # Combine spatial and channel kernels 151 | # Resize kernel and apply to the unfolded feature map 152 | kernel_shape = kernel.size() 153 | feature_shape = x.size() 154 | instance_kernel = kernel.view(kernel_shape[0], kernel_shape[1], feature_shape[-2], feature_shape[-1], self.k, self.k) 155 | task_s = self.CLM(x) # Get task-specific representation 156 | spatial_kernel_task = self.spatial_kernel_network(task_s, self.conv).unsqueeze(-3) 157 | channel_kernel_task = self.channel_kernel_network(task_s).unsqueeze(-2) 158 | task_kernel = spatial_kernel_task * channel_kernel_task 159 | task_kernel_shape = task_kernel.size() 160 | task_kernel = task_kernel.view(task_kernel_shape[0], task_kernel_shape[1], feature_shape[-2], feature_shape[-1], self.k, self.k) 161 | kernel = task_kernel * instance_kernel 162 | unfold_feature = self.unfold(x, int((self.k+1)/2-1), self.k) # Perform a custom unfold operation 163 | adapted_feature = (unfold_feature * kernel).mean(dim=(-1, -2)).squeeze(-1).squeeze(-1) 164 | return adapted_feature + x, task_kernel # Return the residually adapted features and the task-specific kernel 165 | -------------------------------------------------------------------------------- /model/networks/res10.py: -------------------------------------------------------------------------------- 1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate 2 | 3 | import torch 4 | from torch.autograd import Variable 5 | import torch.nn as nn 6 | import math 7 | import numpy as np 8 | import torch.nn.functional as F 9 | from torch.nn.utils.weight_norm import WeightNorm 10 | 11 | 12 | # Basic ResNet model 13 | 14 | def init_layer(L): 15 | # Initialization using fan-in 16 | if isinstance(L, nn.Conv2d): 17 | n = L.kernel_size[0] * L.kernel_size[1] * L.out_channels 18 | L.weight.data.normal_(0, math.sqrt(2.0 / float(n))) 19 | elif isinstance(L, nn.BatchNorm2d): 20 | L.weight.data.fill_(1) 21 | L.bias.data.fill_(0) 22 | 23 | 24 | class distLinear(nn.Module): 25 | def __init__(self, indim, outdim): 26 | super(distLinear, self).__init__() 27 | self.L = nn.Linear(indim, outdim, bias=False) 28 | self.class_wise_learnable_norm = True # see issues #4 and #8 on the upstream GitHub page 29 | if self.class_wise_learnable_norm: 30 | WeightNorm.apply(self.L, 'weight', dim=0) # split the weight update component into direction and norm 31 | 32 | if outdim <= 200: 33 | self.scale_factor = 2 # a fixed scale factor to turn the cosine values into a reasonably large input for softmax; to reproduce the CUB result with ResNet-10, use 4 (see issue #31 on the upstream GitHub page) 34 | else: 35 | self.scale_factor = 10 # in Omniglot, a larger scale factor is required to handle >1000 output classes.
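    # (Editorial note, added for this write-up; not in the original source.)
    # distLinear is a cosine-similarity classifier: forward() below L2-normalises
    # the input and multiplies it with the weight-normalised class weights, so each
    # score is roughly  scores[j] = scale_factor * g_j * cos(x, w_j),
    # where g_j is the learnable per-class norm introduced by WeightNorm above.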
36 | 37 | def forward(self, x): 38 | x_norm = torch.norm(x, p=2, dim=1).unsqueeze(1).expand_as(x) 39 | x_normalized = x.div(x_norm + 0.00001) 40 | if not self.class_wise_learnable_norm: 41 | L_norm = torch.norm(self.L.weight.data, p=2, dim=1).unsqueeze(1).expand_as(self.L.weight.data) 42 | self.L.weight.data = self.L.weight.data.div(L_norm + 0.00001) 43 | cos_dist = self.L( 44 | x_normalized) # matrix product by forward function, but when using WeightNorm, this also multiply the cosine distance by a class-wise learnable norm, see the issue#4&8 in the github 45 | scores = self.scale_factor * (cos_dist) 46 | 47 | return scores 48 | 49 | 50 | class Flatten(nn.Module): 51 | def __init__(self): 52 | super(Flatten, self).__init__() 53 | 54 | def forward(self, x): 55 | return x.view(x.size(0), -1) 56 | 57 | 58 | class Linear_fw(nn.Linear): # used in MAML to forward input with fast weight 59 | def __init__(self, in_features, out_features): 60 | super(Linear_fw, self).__init__(in_features, out_features) 61 | self.weight.fast = None # Lazy hack to add fast weight link 62 | self.bias.fast = None 63 | 64 | def forward(self, x): 65 | if self.weight.fast is not None and self.bias.fast is not None: 66 | out = F.linear(x, self.weight.fast, 67 | self.bias.fast) # weight.fast (fast weight) is the temporaily adapted weight 68 | else: 69 | out = super(Linear_fw, self).forward(x) 70 | return out 71 | 72 | 73 | class Conv2d_fw(nn.Conv2d): # used in MAML to forward input with fast weight 74 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True): 75 | super(Conv2d_fw, self).__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, 76 | bias=bias) 77 | self.weight.fast = None 78 | if not self.bias is None: 79 | self.bias.fast = None 80 | 81 | def forward(self, x): 82 | if self.bias is None: 83 | if self.weight.fast is not None: 84 | out = F.conv2d(x, self.weight.fast, None, stride=self.stride, padding=self.padding) 85 | else: 86 | out = super(Conv2d_fw, self).forward(x) 87 | else: 88 | if self.weight.fast is not None and self.bias.fast is not None: 89 | out = F.conv2d(x, self.weight.fast, self.bias.fast, stride=self.stride, padding=self.padding) 90 | else: 91 | out = super(Conv2d_fw, self).forward(x) 92 | 93 | return out 94 | 95 | 96 | class BatchNorm2d_fw(nn.BatchNorm2d): # used in MAML to forward input with fast weight 97 | def __init__(self, num_features): 98 | super(BatchNorm2d_fw, self).__init__(num_features) 99 | self.weight.fast = None 100 | self.bias.fast = None 101 | 102 | def forward(self, x): 103 | running_mean = torch.zeros(x.data.size()[1]).cuda() 104 | running_var = torch.ones(x.data.size()[1]).cuda() 105 | if self.weight.fast is not None and self.bias.fast is not None: 106 | out = F.batch_norm(x, running_mean, running_var, self.weight.fast, self.bias.fast, training=True, 107 | momentum=1) 108 | # batch_norm momentum hack: follow hack of Kate Rakelly in pytorch-maml/src/layers.py 109 | else: 110 | out = F.batch_norm(x, running_mean, running_var, self.weight, self.bias, training=True, momentum=1) 111 | return out 112 | 113 | 114 | # Simple Conv Block 115 | 116 | 117 | # Simple ResNet Block 118 | class SimpleBlock(nn.Module): 119 | maml = False # Default 120 | 121 | def __init__(self, indim, outdim, half_res): 122 | super(SimpleBlock, self).__init__() 123 | self.indim = indim 124 | self.outdim = outdim 125 | self.C1 = Conv2d_fw(indim, outdim, kernel_size=3, stride=2 if half_res else 1, padding=1, bias=False) 126 | self.BN1 = 
BatchNorm2d_fw(outdim) 127 | self.C2 = Conv2d_fw(outdim, outdim, kernel_size=3, padding=1, bias=False) 128 | self.BN2 = BatchNorm2d_fw(outdim) 129 | 130 | self.relu1 = nn.ReLU(inplace=True) 131 | self.relu2 = nn.ReLU(inplace=True) 132 | 133 | self.parametrized_layers = [self.C1, self.C2, self.BN1, self.BN2] 134 | 135 | self.half_res = half_res 136 | 137 | # if the input number of channels is not equal to the output, then need a 1x1 convolution 138 | if indim != outdim: 139 | if self.maml: 140 | self.shortcut = Conv2d_fw(indim, outdim, 1, 2 if half_res else 1, bias=False) 141 | self.BNshortcut = BatchNorm2d_fw(outdim) 142 | else: 143 | self.shortcut = nn.Conv2d(indim, outdim, 1, 2 if half_res else 1, bias=False) 144 | self.BNshortcut = nn.BatchNorm2d(outdim) 145 | 146 | self.parametrized_layers.append(self.shortcut) 147 | self.parametrized_layers.append(self.BNshortcut) 148 | self.shortcut_type = '1x1' 149 | else: 150 | self.shortcut_type = 'identity' 151 | 152 | for layer in self.parametrized_layers: 153 | init_layer(layer) 154 | 155 | def forward(self, x): 156 | out = self.C1(x) 157 | out = self.BN1(out) 158 | out = self.relu1(out) 159 | out = self.C2(out) 160 | out = self.BN2(out) 161 | short_out = x if self.shortcut_type == 'identity' else self.BNshortcut(self.shortcut(x)) 162 | out = out + short_out 163 | out = self.relu2(out) 164 | return out 165 | 166 | 167 | 168 | 169 | 170 | 171 | class ResNet(nn.Module): 172 | maml = False # Default 173 | 174 | def __init__(self, block, list_of_num_layers, list_of_out_dims, flatten=True): 175 | # list_of_num_layers specifies number of layers in each stage 176 | # list_of_out_dims specifies number of output channel for each stage 177 | super(ResNet, self).__init__() 178 | assert len(list_of_num_layers) == 4, 'Can have only four stages' 179 | 180 | conv1 = Conv2d_fw(3, 64, kernel_size=7, stride=2, padding=3, 181 | bias=False) 182 | bn1 = BatchNorm2d_fw(64) 183 | relu = nn.ReLU() 184 | pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 185 | 186 | init_layer(conv1) 187 | init_layer(bn1) 188 | 189 | trunk = [conv1, bn1, relu, pool1] 190 | 191 | indim = 64 192 | for i in range(4): 193 | 194 | for j in range(list_of_num_layers[i]): 195 | half_res = (i >= 1) and (j == 0) 196 | B = block(indim, list_of_out_dims[i], half_res) 197 | trunk.append(B) 198 | indim = list_of_out_dims[i] 199 | 200 | if flatten: 201 | avgpool = nn.AvgPool2d(7) 202 | trunk.append(avgpool) 203 | trunk.append(Flatten()) 204 | self.final_feat_dim = indim 205 | else: 206 | self.final_feat_dim = [indim, 7, 7] 207 | 208 | self.trunk = nn.Sequential(*trunk) 209 | 210 | def forward(self, x): 211 | out = self.trunk(x) 212 | return out 213 | 214 | 215 | 216 | 217 | def ResNet10(flatten=True): 218 | return ResNet(SimpleBlock, [1, 1, 1, 1], [64, 128, 256, 512], flatten) 219 | 220 | 221 | 222 | 223 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import time 4 | import pprint 5 | import torch 6 | import argparse 7 | import numpy as np 8 | import torch.nn as nn 9 | def one_hot(indices, depth): 10 | """ 11 | Returns a one-hot tensor. 12 | This is a PyTorch equivalent of Tensorflow's tf.one_hot. 13 | 14 | Parameters: 15 | indices: a (n_batch, m) Tensor or (m) Tensor. 16 | depth: a scalar. Represents the depth of the one hot dimension. 17 | Returns: a (n_batch, m, depth) Tensor or (m, depth) Tensor. 
18 | """ 19 | 20 | encoded_indicies = torch.zeros(indices.size() + torch.Size([depth])) 21 | if indices.is_cuda: 22 | encoded_indicies = encoded_indicies.cuda() 23 | index = indices.view(indices.size()+torch.Size([1])) 24 | encoded_indicies = encoded_indicies.scatter_(1,index,1) 25 | 26 | return encoded_indicies 27 | 28 | def set_gpu(x): 29 | os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID' 30 | os.environ['CUDA_VISIBLE_DEVICES'] = x 31 | print('using gpu:', x) 32 | 33 | def ensure_path(dir_path, scripts_to_save=None): 34 | if os.path.exists(dir_path): 35 | if input('{} exists, remove? ([y]/n)'.format(dir_path)) != 'n': 36 | shutil.rmtree(dir_path) 37 | os.mkdir(dir_path) 38 | else: 39 | os.mkdir(dir_path) 40 | 41 | print('Experiment dir : {}'.format(dir_path)) 42 | if scripts_to_save is not None: 43 | script_path = os.path.join(dir_path, 'scripts') 44 | if not os.path.exists(script_path): 45 | os.makedirs(script_path) 46 | for src_file in scripts_to_save: 47 | dst_file = os.path.join(dir_path, 'scripts', os.path.basename(src_file)) 48 | print('copy {} to {}'.format(src_file, dst_file)) 49 | if os.path.isdir(src_file): 50 | shutil.copytree(src_file, dst_file) 51 | else: 52 | shutil.copyfile(src_file, dst_file) 53 | 54 | class Averager(): 55 | 56 | def __init__(self): 57 | self.n = 0 58 | self.v = 0 59 | 60 | def add(self, x): 61 | self.v = (self.v * self.n + x) / (self.n + 1) 62 | self.n += 1 63 | 64 | def item(self): 65 | return self.v 66 | 67 | 68 | class CrossEntropyLoss(nn.Module): 69 | def __init__(self): 70 | super(CrossEntropyLoss, self).__init__() 71 | self.logsoftmax = nn.LogSoftmax(dim=1) 72 | 73 | def forward(self, inputs, targets): 74 | input_ = inputs 75 | input_ = input_.view(input_.size(0), input_.size(1), -1) 76 | 77 | log_probs = self.logsoftmax(input_) 78 | targets_ = torch.zeros(input_.size(0), input_.size(1)).scatter_(1, targets.unsqueeze(1).data.cpu(), 1) 79 | targets_ = targets_.unsqueeze(-1) 80 | targets_ = targets_.cuda() 81 | loss = (- targets_ * log_probs).mean(0).sum() 82 | return loss / input_.size(2) 83 | 84 | 85 | 86 | def c_acc(logits, labels_test): 87 | _, preds = torch.max(logits, 1) 88 | acc = (torch.sum(preds == labels_test)).type(torch.cuda.FloatTensor) / labels_test.size(0) 89 | 90 | return acc.item() 91 | 92 | 93 | 94 | def count_acc(logits, label): 95 | pred = torch.argmax(logits, dim=1) 96 | if torch.cuda.is_available(): 97 | return (pred == label).type(torch.cuda.FloatTensor).mean().item() 98 | else: 99 | return (pred == label).type(torch.FloatTensor).mean().item() 100 | 101 | def euclidean_metric(a, b): 102 | n = a.shape[0] 103 | m = b.shape[0] 104 | a = a.unsqueeze(1).expand(n, m, -1) 105 | b = b.unsqueeze(0).expand(n, m, -1) 106 | logits = -((a - b)**2).sum(dim=2) 107 | return logits 108 | 109 | class Timer(): 110 | 111 | def __init__(self): 112 | self.o = time.time() 113 | 114 | def measure(self, p=1): 115 | x = (time.time() - self.o) / p 116 | x = int(x) 117 | if x >= 3600: 118 | return '{:.1f}h'.format(x / 3600) 119 | if x >= 60: 120 | return '{}m'.format(round(x / 60)) 121 | return '{}s'.format(x) 122 | 123 | _utils_pp = pprint.PrettyPrinter() 124 | def pprint(x): 125 | _utils_pp.pprint(x) 126 | 127 | def compute_confidence_interval(data): 128 | """ 129 | Compute 95% confidence interval 130 | :param data: An array of mean accuracy (or mAP) across a number of sampled episodes. 131 | :return: the 95% confidence interval for this data. 
132 | """ 133 | a = 1.0 * np.array(data) 134 | m = np.mean(a) 135 | std = np.std(a) 136 | pm = 1.96 * (std / np.sqrt(len(a))) 137 | return m, pm 138 | 139 | def postprocess_args(args): 140 | args.num_classes = args.way 141 | save_path1 = '-'.join([args.dataset, args.model_class, args.backbone_class, 142 | '{:02d}w{:02d}s{:02}q'.format(args.way, args.shot, args.query)]) 143 | 144 | 145 | save_path2 = '_'.join([str('_'.join(args.step_size.split(','))), str(args.gamma), 146 | 'lr{:.2g}mul{:.2g}'.format(args.lr, args.lr_mul), 147 | str(args.lr_scheduler), 148 | 'T1{}T2{}'.format(args.temperature, args.temperature2), 149 | 'b{}'.format(args.balance_1), 150 | 'bsz{:03d}'.format( max(args.way, args.num_classes)*(args.shot+args.query) ), 151 | ]) 152 | if args.init_weights is not None: 153 | save_path1 += '-Pre' 154 | if args.use_euclidean: 155 | save_path1 += '-DIS' 156 | else: 157 | save_path1 += '-SIM' 158 | 159 | if args.fix_BN: 160 | save_path2 += '-FBN' 161 | if not args.augment: 162 | save_path2 += '-NoAug' 163 | 164 | if not os.path.exists(os.path.join(args.save_dir, save_path1)): 165 | os.mkdir(os.path.join(args.save_dir, save_path1)) 166 | args.save_path = os.path.join(args.save_dir, save_path1, save_path2) 167 | return args 168 | 169 | def get_command_line_parser(): 170 | parser = argparse.ArgumentParser() 171 | parser.add_argument('--max_epoch', type=int, default=200) 172 | parser.add_argument('--episodes_per_epoch', type=int, default=100) 173 | parser.add_argument('--num_eval_episodes', type=int, default=600) 174 | parser.add_argument('--model_class', type=str, default='INSTA_PorotNet', 175 | choices=['INSTA_ProtoNet', 'ProtoNet']) 176 | parser.add_argument('--use_euclidean', action='store_true', default=False) 177 | parser.add_argument('--use_AdamW', action='store_true', default=False) 178 | parser.add_argument('--backbone_class', type=str, default='Res12', 179 | choices=['Res12', 'Res18']) 180 | parser.add_argument('--dataset', type=str, default='MiniImageNet', 181 | choices=['MiniImageNet', 'TieredImageNet', 'CUB', 'FC100']) 182 | 183 | parser.add_argument('--way', type=int, default=5) 184 | parser.add_argument('--eval_way', type=int, default=5) 185 | parser.add_argument('--shot', type=int, default=1) 186 | parser.add_argument('--eval_shot', type=int, default=1) 187 | parser.add_argument('--query', type=int, default=15) 188 | parser.add_argument('--eval_query', type=int, default=15) 189 | parser.add_argument('--balance_1', type=float, default=0) 190 | parser.add_argument('--balance_2', type=float, default=0) 191 | parser.add_argument('--temperature', type=float, default=1) 192 | parser.add_argument('--temperature2', type=float, default=1) # the temperature in the 193 | 194 | # optimization parameters 195 | parser.add_argument('--orig_imsize', type=int, default=-1) # -1 for no cache, and -2 for no resize, only for MiniImageNet and CUB 196 | parser.add_argument('--lr', type=float, default=0.0001) 197 | parser.add_argument('--lr_mul', type=float, default=10) 198 | parser.add_argument('--lr_scheduler', type=str, default='step', choices=['multistep', 'step', 'cosine']) 199 | parser.add_argument('--step_size', type=str, default='20') 200 | parser.add_argument('--gamma', type=float, default=0.2) 201 | parser.add_argument('--fix_BN', action='store_true', default=False) # means we do not update the running mean/var in BN, not to freeze BN 202 | parser.add_argument('--augment', action='store_true', default=False) 203 | parser.add_argument('--baseline', type=str, default='y') 204 | 
parser.add_argument('--multi_gpu', action='store_true', default=False) 205 | parser.add_argument('--gpu', default='0') 206 | parser.add_argument('--init_weights', type=str, default=None) 207 | parser.add_argument('--emb_adap', action='store_true', default=False) 208 | parser.add_argument('--testing', action='store_true', default=False) 209 | 210 | # usually untouched parameters 211 | parser.add_argument('--mom', type=float, default=0.9) 212 | parser.add_argument('--weight_decay', type=float, default=0.0005) # we find this weight decay value works the best 213 | parser.add_argument('--num_workers', type=int, default=8) 214 | parser.add_argument('--log_interval', type=int, default=50) 215 | parser.add_argument('--eval_interval', type=int, default=1) 216 | parser.add_argument('--save_dir', type=str, default='./checkpoints') 217 | 218 | return parser -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 119 | -------------------------------------------------------------------------------- /model/models/utils/transformers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Module, ModuleList, Linear, Dropout, LayerNorm, Identity, Parameter, init 3 | import torch.nn.functional as F 4 | from .stochastic_depth import DropPath 5 | 6 | 7 | class Attention(Module): 8 | """ 9 | Obtained from timm: github.com:rwightman/pytorch-image-models 10 | """ 11 | 12 | def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1): 13 | super().__init__() 14 | self.num_heads = num_heads 15 | head_dim = dim // self.num_heads 16 | self.scale = head_dim ** -0.5 17 | 18 | self.qkv = Linear(dim, dim * 3, bias=False) 19 | self.attn_drop = Dropout(attention_dropout) 20 | self.proj = Linear(dim, dim) 21 | self.proj_drop = Dropout(projection_dropout) 22 | 23 | def forward(self, x): 24 | B, N, C = x.shape 25 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 26 | q, k, v = qkv[0], qkv[1], qkv[2] 27 | 28 | attn = (q @ k.transpose(-2, -1)) * self.scale 29 | attn = attn.softmax(dim=-1) 30 | attn = self.attn_drop(attn) 31 | 32 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 33 | x = self.proj(x) 34 | x = self.proj_drop(x) 35 | return x 36 | 37 | 38 | class MaskedAttention(Module): 39 | def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1): 40 | super().__init__() 41 | self.num_heads = num_heads 42 | head_dim = dim // self.num_heads 43 | self.scale = head_dim ** -0.5 44 | 45 | self.qkv = Linear(dim, dim * 3, bias=False) 46 | self.attn_drop = Dropout(attention_dropout) 47 | self.proj = Linear(dim, dim) 48 | self.proj_drop = Dropout(projection_dropout) 49 | 50 | def forward(self, x, mask=None): 51 | B, N, C = x.shape 52 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 53 | q, k, v = qkv[0], qkv[1], qkv[2] 54 | 55 | attn = (q @ k.transpose(-2, -1)) * self.scale 56 | 57 | if mask is not None: 58 | mask_value = -torch.finfo(attn.dtype).max 59 | assert mask.shape[-1] == attn.shape[-1], 'mask has incorrect dimensions' 60 | mask = mask[:, None, :] * mask[:, :, None] 61 | mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1) 62 | attn.masked_fill_(~mask, mask_value) 63 | 64 | attn = attn.softmax(dim=-1) 65 | attn = self.attn_drop(attn) 66 | 67 | x = 
(attn @ v).transpose(1, 2).reshape(B, N, C) 68 | x = self.proj(x) 69 | x = self.proj_drop(x) 70 | return x 71 | 72 | 73 | class TransformerEncoderLayer(Module): 74 | """ 75 | Inspired by torch.nn.TransformerEncoderLayer and timm. 76 | """ 77 | 78 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 79 | attention_dropout=0.1, drop_path_rate=0.1): 80 | super(TransformerEncoderLayer, self).__init__() 81 | self.pre_norm = LayerNorm(d_model) 82 | self.self_attn = Attention(dim=d_model, num_heads=nhead, 83 | attention_dropout=attention_dropout, projection_dropout=dropout) 84 | 85 | self.linear1 = Linear(d_model, dim_feedforward) 86 | self.dropout1 = Dropout(dropout) 87 | self.norm1 = LayerNorm(d_model) 88 | self.linear2 = Linear(dim_feedforward, d_model) 89 | self.dropout2 = Dropout(dropout) 90 | 91 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else Identity() 92 | 93 | self.activation = F.gelu 94 | 95 | def forward(self, src: torch.Tensor, *args, **kwargs) -> torch.Tensor: 96 | src = src + self.drop_path(self.self_attn(self.pre_norm(src))) 97 | src = self.norm1(src) 98 | src2 = self.linear2(self.dropout1(self.activation(self.linear1(src)))) 99 | src = src + self.drop_path(self.dropout2(src2)) 100 | return src 101 | 102 | 103 | class MaskedTransformerEncoderLayer(Module): 104 | """ 105 | Inspired by torch.nn.TransformerEncoderLayer and timm. 106 | """ 107 | 108 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 109 | attention_dropout=0.1, drop_path_rate=0.1): 110 | super(MaskedTransformerEncoderLayer, self).__init__() 111 | self.pre_norm = LayerNorm(d_model) 112 | self.self_attn = MaskedAttention(dim=d_model, num_heads=nhead, 113 | attention_dropout=attention_dropout, projection_dropout=dropout) 114 | 115 | self.linear1 = Linear(d_model, dim_feedforward) 116 | self.dropout1 = Dropout(dropout) 117 | self.norm1 = LayerNorm(d_model) 118 | self.linear2 = Linear(dim_feedforward, d_model) 119 | self.dropout2 = Dropout(dropout) 120 | 121 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else Identity() 122 | 123 | self.activation = F.gelu 124 | 125 | def forward(self, src: torch.Tensor, mask=None, *args, **kwargs) -> torch.Tensor: 126 | src = src + self.drop_path(self.self_attn(self.pre_norm(src), mask)) 127 | src = self.norm1(src) 128 | src2 = self.linear2(self.dropout1(self.activation(self.linear1(src)))) 129 | src = src + self.drop_path(self.dropout2(src2)) 130 | return src 131 | 132 | 133 | class TransformerClassifier(Module): 134 | def __init__(self, 135 | seq_pool=True, 136 | embedding_dim=768, 137 | num_layers=12, 138 | num_heads=12, 139 | mlp_ratio=4.0, 140 | num_classes=1000, 141 | dropout_rate=0.1, 142 | attention_dropout=0.1, 143 | stochastic_depth_rate=0.1, 144 | positional_embedding='sine', 145 | sequence_length=None, 146 | *args, **kwargs): 147 | super().__init__() 148 | positional_embedding = positional_embedding if \ 149 | positional_embedding in ['sine', 'learnable', 'none'] else 'sine' 150 | dim_feedforward = int(embedding_dim * mlp_ratio) 151 | self.embedding_dim = embedding_dim 152 | self.sequence_length = sequence_length 153 | self.seq_pool = seq_pool 154 | 155 | assert sequence_length is not None or positional_embedding == 'none', \ 156 | f"Positional embedding is set to {positional_embedding} and" \ 157 | f" the sequence length was not specified." 
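# With seq_pool=True (the default) no class token is created; forward() instead pools the sequence with the
# learned attention_pool, reducing (B, sequence_length, embedding_dim) to (B, embedding_dim) before the final fc.
# With seq_pool=False a learnable class token is prepended and its output embedding is used.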
158 | 159 | if not seq_pool: 160 | sequence_length += 1 161 | self.class_emb = Parameter(torch.zeros(1, 1, self.embedding_dim), 162 | requires_grad=True) 163 | else: 164 | self.attention_pool = Linear(self.embedding_dim, 1) 165 | 166 | if positional_embedding != 'none': 167 | if positional_embedding == 'learnable': 168 | self.positional_emb = Parameter(torch.zeros(1, sequence_length, embedding_dim), 169 | requires_grad=True) 170 | init.trunc_normal_(self.positional_emb, std=0.2) 171 | else: 172 | self.positional_emb = Parameter(self.sinusoidal_embedding(sequence_length, embedding_dim), 173 | requires_grad=False) 174 | else: 175 | self.positional_emb = None 176 | 177 | self.dropout = Dropout(p=dropout_rate) 178 | dpr = [x.item() for x in torch.linspace(0, stochastic_depth_rate, num_layers)] 179 | self.blocks = ModuleList([ 180 | TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads, 181 | dim_feedforward=dim_feedforward, dropout=dropout_rate, 182 | attention_dropout=attention_dropout, drop_path_rate=dpr[i]) 183 | for i in range(num_layers)]) 184 | self.norm = LayerNorm(embedding_dim) 185 | 186 | self.fc = Linear(embedding_dim, num_classes) 187 | self.apply(self.init_weight) 188 | 189 | def forward(self, x): 190 | if self.positional_emb is None and x.size(1) < self.sequence_length: 191 | x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0) 192 | 193 | if not self.seq_pool: 194 | cls_token = self.class_emb.expand(x.shape[0], -1, -1) 195 | x = torch.cat((cls_token, x), dim=1) 196 | 197 | if self.positional_emb is not None: 198 | x += self.positional_emb 199 | 200 | x = self.dropout(x) 201 | 202 | for blk in self.blocks: 203 | x = blk(x) 204 | x = self.norm(x) 205 | 206 | if self.seq_pool: 207 | x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2) 208 | else: 209 | x = x[:, 0] 210 | 211 | x = self.fc(x) 212 | return x 213 | 214 | @staticmethod 215 | def init_weight(m): 216 | if isinstance(m, Linear): 217 | init.trunc_normal_(m.weight, std=.02) 218 | if isinstance(m, Linear) and m.bias is not None: 219 | init.constant_(m.bias, 0) 220 | elif isinstance(m, LayerNorm): 221 | init.constant_(m.bias, 0) 222 | init.constant_(m.weight, 1.0) 223 | 224 | @staticmethod 225 | def sinusoidal_embedding(n_channels, dim): 226 | pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)] 227 | for p in range(n_channels)]) 228 | pe[:, 0::2] = torch.sin(pe[:, 0::2]) 229 | pe[:, 1::2] = torch.cos(pe[:, 1::2]) 230 | return pe.unsqueeze(0) 231 | 232 | 233 | class MaskedTransformerClassifier(Module): 234 | def __init__(self, 235 | seq_pool=True, 236 | embedding_dim=768, 237 | num_layers=12, 238 | num_heads=12, 239 | mlp_ratio=4.0, 240 | num_classes=1000, 241 | dropout_rate=0.1, 242 | attention_dropout=0.1, 243 | stochastic_depth_rate=0.1, 244 | positional_embedding='sine', 245 | seq_len=None, 246 | *args, **kwargs): 247 | super().__init__() 248 | positional_embedding = positional_embedding if \ 249 | positional_embedding in ['sine', 'learnable', 'none'] else 'sine' 250 | dim_feedforward = int(embedding_dim * mlp_ratio) 251 | self.embedding_dim = embedding_dim 252 | self.seq_len = seq_len 253 | self.seq_pool = seq_pool 254 | 255 | assert seq_len is not None or positional_embedding == 'none', \ 256 | f"Positional embedding is set to {positional_embedding} and" \ 257 | f" the sequence length was not specified." 
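# The optional mask passed to forward() is a per-token (B, seq_len) tensor of ones/zeros; when a class token
# is prepended it is extended with a leading one, and MaskedAttention expands it into a pairwise
# (B, seq_len, seq_len) attention mask.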
258 | 259 | if not seq_pool: 260 | seq_len += 1 261 | self.class_emb = Parameter(torch.zeros(1, 1, self.embedding_dim), 262 | requires_grad=True) 263 | else: 264 | self.attention_pool = Linear(self.embedding_dim, 1) 265 | 266 | if positional_embedding != 'none': 267 | if positional_embedding == 'learnable': 268 | seq_len += 1 # padding idx 269 | self.positional_emb = Parameter(torch.zeros(1, seq_len, embedding_dim), 270 | requires_grad=True) 271 | init.trunc_normal_(self.positional_emb, std=0.2) 272 | else: 273 | self.positional_emb = Parameter(self.sinusoidal_embedding(seq_len, 274 | embedding_dim, 275 | padding_idx=True), 276 | requires_grad=False) 277 | else: 278 | self.positional_emb = None 279 | 280 | self.dropout = Dropout(p=dropout_rate) 281 | dpr = [x.item() for x in torch.linspace(0, stochastic_depth_rate, num_layers)] 282 | self.blocks = ModuleList([ 283 | MaskedTransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads, 284 | dim_feedforward=dim_feedforward, dropout=dropout_rate, 285 | attention_dropout=attention_dropout, drop_path_rate=dpr[i]) 286 | for i in range(num_layers)]) 287 | self.norm = LayerNorm(embedding_dim) 288 | 289 | self.fc = Linear(embedding_dim, num_classes) 290 | self.apply(self.init_weight) 291 | 292 | def forward(self, x, mask=None): 293 | if self.positional_emb is None and x.size(1) < self.seq_len: 294 | x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0) 295 | 296 | if not self.seq_pool: 297 | cls_token = self.class_emb.expand(x.shape[0], -1, -1) 298 | x = torch.cat((cls_token, x), dim=1) 299 | if mask is not None: 300 | mask = torch.cat([torch.ones(size=(mask.shape[0], 1), device=mask.device), mask.float()], dim=1) 301 | mask = (mask > 0) 302 | 303 | if self.positional_emb is not None: 304 | x += self.positional_emb 305 | 306 | x = self.dropout(x) 307 | 308 | for blk in self.blocks: 309 | x = blk(x, mask=mask) 310 | x = self.norm(x) 311 | 312 | if self.seq_pool: 313 | x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2) 314 | else: 315 | x = x[:, 0] 316 | 317 | x = self.fc(x) 318 | return x 319 | 320 | @staticmethod 321 | def init_weight(m): 322 | if isinstance(m, Linear): 323 | init.trunc_normal_(m.weight, std=.02) 324 | if isinstance(m, Linear) and m.bias is not None: 325 | init.constant_(m.bias, 0) 326 | elif isinstance(m, LayerNorm): 327 | init.constant_(m.bias, 0) 328 | init.constant_(m.weight, 1.0) 329 | 330 | @staticmethod 331 | def sinusoidal_embedding(n_channels, dim, padding_idx=False): 332 | pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)] 333 | for p in range(n_channels)]) 334 | pe[:, 0::2] = torch.sin(pe[:, 0::2]) 335 | pe[:, 1::2] = torch.cos(pe[:, 1::2]) 336 | pe = pe.unsqueeze(0) 337 | if padding_idx: 338 | return torch.cat([torch.zeros((1, 1, dim)), pe], dim=1) 339 | return pe 340 | -------------------------------------------------------------------------------- /model/networks/utils/transformers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Module, ModuleList, Linear, Dropout, LayerNorm, Identity, Parameter, init 3 | import torch.nn.functional as F 4 | from .stochastic_depth import DropPath 5 | 6 | 7 | class Attention(Module): 8 | """ 9 | Obtained from timm: github.com:rwightman/pytorch-image-models 10 | """ 11 | 12 | def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1): 13 | super().__init__() 14 | self.num_heads = num_heads 
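# Standard multi-head scaled dot-product attention: forward() computes
# softmax(q @ k.transpose(-2, -1) * head_dim ** -0.5) @ v independently for each of the num_heads heads.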
15 | head_dim = dim // self.num_heads 16 | self.scale = head_dim ** -0.5 17 | 18 | self.qkv = Linear(dim, dim * 3, bias=False) 19 | self.attn_drop = Dropout(attention_dropout) 20 | self.proj = Linear(dim, dim) 21 | self.proj_drop = Dropout(projection_dropout) 22 | 23 | def forward(self, x): 24 | B, N, C = x.shape 25 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 26 | q, k, v = qkv[0], qkv[1], qkv[2] 27 | 28 | attn = (q @ k.transpose(-2, -1)) * self.scale 29 | attn = attn.softmax(dim=-1) 30 | attn = self.attn_drop(attn) 31 | 32 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 33 | x = self.proj(x) 34 | x = self.proj_drop(x) 35 | return x 36 | 37 | 38 | class MaskedAttention(Module): 39 | def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1): 40 | super().__init__() 41 | self.num_heads = num_heads 42 | head_dim = dim // self.num_heads 43 | self.scale = head_dim ** -0.5 44 | 45 | self.qkv = Linear(dim, dim * 3, bias=False) 46 | self.attn_drop = Dropout(attention_dropout) 47 | self.proj = Linear(dim, dim) 48 | self.proj_drop = Dropout(projection_dropout) 49 | 50 | def forward(self, x, mask=None): 51 | B, N, C = x.shape 52 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 53 | q, k, v = qkv[0], qkv[1], qkv[2] 54 | 55 | attn = (q @ k.transpose(-2, -1)) * self.scale 56 | 57 | if mask is not None: 58 | mask_value = -torch.finfo(attn.dtype).max 59 | assert mask.shape[-1] == attn.shape[-1], 'mask has incorrect dimensions' 60 | mask = mask[:, None, :] * mask[:, :, None] 61 | mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1) 62 | attn.masked_fill_(~mask, mask_value) 63 | 64 | attn = attn.softmax(dim=-1) 65 | attn = self.attn_drop(attn) 66 | 67 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 68 | x = self.proj(x) 69 | x = self.proj_drop(x) 70 | return x 71 | 72 | 73 | class TransformerEncoderLayer(Module): 74 | """ 75 | Inspired by torch.nn.TransformerEncoderLayer and timm. 76 | """ 77 | 78 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 79 | attention_dropout=0.1, drop_path_rate=0.1): 80 | super(TransformerEncoderLayer, self).__init__() 81 | self.pre_norm = LayerNorm(d_model) 82 | self.self_attn = Attention(dim=d_model, num_heads=nhead, 83 | attention_dropout=attention_dropout, projection_dropout=dropout) 84 | 85 | self.linear1 = Linear(d_model, dim_feedforward) 86 | self.dropout1 = Dropout(dropout) 87 | self.norm1 = LayerNorm(d_model) 88 | self.linear2 = Linear(dim_feedforward, d_model) 89 | self.dropout2 = Dropout(dropout) 90 | 91 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else Identity() 92 | 93 | self.activation = F.gelu 94 | 95 | def forward(self, src: torch.Tensor, *args, **kwargs) -> torch.Tensor: 96 | src = src + self.drop_path(self.self_attn(self.pre_norm(src))) 97 | src = self.norm1(src) 98 | src2 = self.linear2(self.dropout1(self.activation(self.linear1(src)))) 99 | src = src + self.drop_path(self.dropout2(src2)) 100 | return src 101 | 102 | 103 | class MaskedTransformerEncoderLayer(Module): 104 | """ 105 | Inspired by torch.nn.TransformerEncoderLayer and timm. 
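Residual structure: a pre-norm (masked) self-attention branch wrapped in DropPath, followed by LayerNorm and a GELU feed-forward branch, also wrapped in DropPath.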
106 | """ 107 | 108 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 109 | attention_dropout=0.1, drop_path_rate=0.1): 110 | super(MaskedTransformerEncoderLayer, self).__init__() 111 | self.pre_norm = LayerNorm(d_model) 112 | self.self_attn = MaskedAttention(dim=d_model, num_heads=nhead, 113 | attention_dropout=attention_dropout, projection_dropout=dropout) 114 | 115 | self.linear1 = Linear(d_model, dim_feedforward) 116 | self.dropout1 = Dropout(dropout) 117 | self.norm1 = LayerNorm(d_model) 118 | self.linear2 = Linear(dim_feedforward, d_model) 119 | self.dropout2 = Dropout(dropout) 120 | 121 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else Identity() 122 | 123 | self.activation = F.gelu 124 | 125 | def forward(self, src: torch.Tensor, mask=None, *args, **kwargs) -> torch.Tensor: 126 | src = src + self.drop_path(self.self_attn(self.pre_norm(src), mask)) 127 | src = self.norm1(src) 128 | src2 = self.linear2(self.dropout1(self.activation(self.linear1(src)))) 129 | src = src + self.drop_path(self.dropout2(src2)) 130 | return src 131 | 132 | 133 | class TransformerClassifier(Module): 134 | def __init__(self, 135 | seq_pool=True, 136 | embedding_dim=768, 137 | num_layers=12, 138 | num_heads=12, 139 | mlp_ratio=4.0, 140 | num_classes=1000, 141 | dropout_rate=0.1, 142 | attention_dropout=0.1, 143 | stochastic_depth_rate=0.1, 144 | positional_embedding='sine', 145 | sequence_length=None, 146 | *args, **kwargs): 147 | super().__init__() 148 | positional_embedding = positional_embedding if \ 149 | positional_embedding in ['sine', 'learnable', 'none'] else 'sine' 150 | dim_feedforward = int(embedding_dim * mlp_ratio) 151 | self.embedding_dim = embedding_dim 152 | self.sequence_length = sequence_length 153 | self.seq_pool = seq_pool 154 | 155 | assert sequence_length is not None or positional_embedding == 'none', \ 156 | f"Positional embedding is set to {positional_embedding} and" \ 157 | f" the sequence length was not specified." 
158 | 159 | if not seq_pool: 160 | sequence_length += 1 161 | self.class_emb = Parameter(torch.zeros(1, 1, self.embedding_dim), 162 | requires_grad=True) 163 | else: 164 | self.attention_pool = Linear(self.embedding_dim, 1) 165 | 166 | if positional_embedding != 'none': 167 | if positional_embedding == 'learnable': 168 | self.positional_emb = Parameter(torch.zeros(1, sequence_length, embedding_dim), 169 | requires_grad=True) 170 | init.trunc_normal_(self.positional_emb, std=0.2) 171 | else: 172 | self.positional_emb = Parameter(self.sinusoidal_embedding(sequence_length, embedding_dim), 173 | requires_grad=False) 174 | else: 175 | self.positional_emb = None 176 | 177 | self.dropout = Dropout(p=dropout_rate) 178 | dpr = [x.item() for x in torch.linspace(0, stochastic_depth_rate, num_layers)] 179 | self.blocks = ModuleList([ 180 | TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads, 181 | dim_feedforward=dim_feedforward, dropout=dropout_rate, 182 | attention_dropout=attention_dropout, drop_path_rate=dpr[i]) 183 | for i in range(num_layers)]) 184 | self.norm = LayerNorm(embedding_dim) 185 | 186 | # self.fc = Linear(embedding_dim, num_classes) 187 | self.apply(self.init_weight) 188 | 189 | def forward(self, x): 190 | if self.positional_emb is None and x.size(1) < self.sequence_length: 191 | x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0) 192 | 193 | if not self.seq_pool: 194 | cls_token = self.class_emb.expand(x.shape[0], -1, -1) 195 | x = torch.cat((cls_token, x), dim=1) 196 | 197 | if self.positional_emb is not None: 198 | x += self.positional_emb 199 | 200 | x = self.dropout(x) 201 | 202 | for blk in self.blocks: 203 | x = blk(x) 204 | x = self.norm(x) 205 | 206 | if self.seq_pool: 207 | x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2) 208 | else: 209 | x = x[:, 0] 210 | 211 | # x = self.fc(x) 212 | return x 213 | 214 | @staticmethod 215 | def init_weight(m): 216 | if isinstance(m, Linear): 217 | init.trunc_normal_(m.weight, std=.02) 218 | if isinstance(m, Linear) and m.bias is not None: 219 | init.constant_(m.bias, 0) 220 | elif isinstance(m, LayerNorm): 221 | init.constant_(m.bias, 0) 222 | init.constant_(m.weight, 1.0) 223 | 224 | @staticmethod 225 | def sinusoidal_embedding(n_channels, dim): 226 | pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)] 227 | for p in range(n_channels)]) 228 | pe[:, 0::2] = torch.sin(pe[:, 0::2]) 229 | pe[:, 1::2] = torch.cos(pe[:, 1::2]) 230 | return pe.unsqueeze(0) 231 | 232 | 233 | class MaskedTransformerClassifier(Module): 234 | def __init__(self, 235 | seq_pool=True, 236 | embedding_dim=768, 237 | num_layers=12, 238 | num_heads=12, 239 | mlp_ratio=4.0, 240 | num_classes=1000, 241 | dropout_rate=0.1, 242 | attention_dropout=0.1, 243 | stochastic_depth_rate=0.1, 244 | positional_embedding='sine', 245 | seq_len=None, 246 | *args, **kwargs): 247 | super().__init__() 248 | positional_embedding = positional_embedding if \ 249 | positional_embedding in ['sine', 'learnable', 'none'] else 'sine' 250 | dim_feedforward = int(embedding_dim * mlp_ratio) 251 | self.embedding_dim = embedding_dim 252 | self.seq_len = seq_len 253 | self.seq_pool = seq_pool 254 | 255 | assert seq_len is not None or positional_embedding == 'none', \ 256 | f"Positional embedding is set to {positional_embedding} and" \ 257 | f" the sequence length was not specified." 
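# For the learnable positional embedding, seq_len is grown by one below to reserve a padding position;
# the sinusoidal variant does the same by prepending an all-zero row (padding_idx=True).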
258 | 259 | if not seq_pool: 260 | seq_len += 1 261 | self.class_emb = Parameter(torch.zeros(1, 1, self.embedding_dim), 262 | requires_grad=True) 263 | else: 264 | self.attention_pool = Linear(self.embedding_dim, 1) 265 | 266 | if positional_embedding != 'none': 267 | if positional_embedding == 'learnable': 268 | seq_len += 1 # padding idx 269 | self.positional_emb = Parameter(torch.zeros(1, seq_len, embedding_dim), 270 | requires_grad=True) 271 | init.trunc_normal_(self.positional_emb, std=0.2) 272 | else: 273 | self.positional_emb = Parameter(self.sinusoidal_embedding(seq_len, 274 | embedding_dim, 275 | padding_idx=True), 276 | requires_grad=False) 277 | else: 278 | self.positional_emb = None 279 | 280 | self.dropout = Dropout(p=dropout_rate) 281 | dpr = [x.item() for x in torch.linspace(0, stochastic_depth_rate, num_layers)] 282 | self.blocks = ModuleList([ 283 | MaskedTransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads, 284 | dim_feedforward=dim_feedforward, dropout=dropout_rate, 285 | attention_dropout=attention_dropout, drop_path_rate=dpr[i]) 286 | for i in range(num_layers)]) 287 | self.norm = LayerNorm(embedding_dim) 288 | 289 | self.fc = Linear(embedding_dim, num_classes) 290 | self.apply(self.init_weight) 291 | 292 | def forward(self, x, mask=None): 293 | if self.positional_emb is None and x.size(1) < self.seq_len: 294 | x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0) 295 | 296 | if not self.seq_pool: 297 | cls_token = self.class_emb.expand(x.shape[0], -1, -1) 298 | x = torch.cat((cls_token, x), dim=1) 299 | if mask is not None: 300 | mask = torch.cat([torch.ones(size=(mask.shape[0], 1), device=mask.device), mask.float()], dim=1) 301 | mask = (mask > 0) 302 | 303 | if self.positional_emb is not None: 304 | x += self.positional_emb 305 | 306 | x = self.dropout(x) 307 | 308 | for blk in self.blocks: 309 | x = blk(x, mask=mask) 310 | x = self.norm(x) 311 | 312 | if self.seq_pool: 313 | x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2) 314 | else: 315 | x = x[:, 0] 316 | 317 | x = self.fc(x) 318 | return x 319 | 320 | @staticmethod 321 | def init_weight(m): 322 | if isinstance(m, Linear): 323 | init.trunc_normal_(m.weight, std=.02) 324 | if isinstance(m, Linear) and m.bias is not None: 325 | init.constant_(m.bias, 0) 326 | elif isinstance(m, LayerNorm): 327 | init.constant_(m.bias, 0) 328 | init.constant_(m.weight, 1.0) 329 | 330 | @staticmethod 331 | def sinusoidal_embedding(n_channels, dim, padding_idx=False): 332 | pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)] 333 | for p in range(n_channels)]) 334 | pe[:, 0::2] = torch.sin(pe[:, 0::2]) 335 | pe[:, 1::2] = torch.cos(pe[:, 1::2]) 336 | pe = pe.unsqueeze(0) 337 | if padding_idx: 338 | return torch.cat([torch.zeros((1, 1, dim)), pe], dim=1) 339 | return pe 340 | --------------------------------------------------------------------------------