├── model
│   ├── __init__.py
│   ├── trainer
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── helpers.py
│   │   └── fsl_trainer.py
│   ├── models
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── embedder.py
│   │   │   ├── stochastic_depth.py
│   │   │   ├── tokenizer.py
│   │   │   └── transformers.py
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── protonet.py
│   │   ├── fcanet.py
│   │   ├── INSTA_ProtoNet.py
│   │   └── INSTA.py
│   ├── networks
│   │   ├── __init__.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── embedder.py
│   │   │   ├── stochastic_depth.py
│   │   │   ├── tokenizer.py
│   │   │   └── transformers.py
│   │   ├── dropblock.py
│   │   ├── res12.py
│   │   ├── res18.py
│   │   └── res10.py
│   ├── logger.py
│   ├── dataloader
│   │   ├── split_cub.py
│   │   ├── samplers.py
│   │   ├── transforms.py
│   │   ├── mini_imagenet.py
│   │   ├── cub.py
│   │   └── tiered_imagenet.py
│   ├── data_parallel.py
│   └── utils.py
├── data
│   ├── cub
│   │   └── .gitignore
│   └── miniimagenet
│       ├── .gitignore
│       └── download.sh
├── visual
│   ├── concept.png
│   ├── heatmap.png
│   └── pipeline.png
├── .idea
│   ├── misc.xml
│   ├── vcs.xml
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   ├── profiles_settings.xml
│   │   └── Project_Default.xml
│   ├── modules.xml
│   ├── code.iml
│   ├── remote-mappings.xml
│   └── deployment.xml
├── train_fsl.py
├── LICENSE
└── README.md
/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/model/trainer/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/cub/.gitignore:
--------------------------------------------------------------------------------
1 | images
2 |
--------------------------------------------------------------------------------
/model/models/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/model/networks/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/model/networks/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/miniimagenet/.gitignore:
--------------------------------------------------------------------------------
1 | images
2 |
--------------------------------------------------------------------------------
/model/models/__init__.py:
--------------------------------------------------------------------------------
1 | from model.models.base import FewShotModel_1
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/visual/concept.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RongKaiWeskerMA/INSTA/HEAD/visual/concept.png
--------------------------------------------------------------------------------
/visual/heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RongKaiWeskerMA/INSTA/HEAD/visual/heatmap.png
--------------------------------------------------------------------------------
/visual/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RongKaiWeskerMA/INSTA/HEAD/visual/pipeline.png
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/data/miniimagenet/download.sh:
--------------------------------------------------------------------------------
1 |
2 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1BCxmqLANXHbBaWs8A7_jqfVUv8mydp5R' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1BCxmqLANXHbBaWs8A7_jqfVUv8mydp5R" -O miniimagenet.zip && rm -rf /tmp/cookies.txt
3 |
4 | unzip miniimagenet.zip -d miniimagenet/
5 |
--------------------------------------------------------------------------------
/train_fsl.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from model.trainer.fsl_trainer import FSLTrainer
4 | from model.utils import (
5 | pprint, set_gpu,
6 | get_command_line_parser,
7 | postprocess_args,
8 | )
9 | # from ipdb import launch_ipdb_on_exception
10 |
11 | if __name__ == '__main__':
12 | parser = get_command_line_parser()
13 | args = postprocess_args(parser.parse_args())
14 | # with launch_ipdb_on_exception():
15 | pprint(vars(args))
16 |
17 | set_gpu(args.gpu)
18 | trainer = FSLTrainer(args)
19 | trainer.train()
20 | trainer.evaluate_test()
21 | trainer.final_record()
22 | print(args.save_path)
23 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 RongKaiWeskerMA
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/model/models/utils/embedder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class Embedder(nn.Module):
5 | def __init__(self,
6 | word_embedding_dim=300,
7 | vocab_size=100000,
8 | padding_idx=1,
9 | pretrained_weight=None,
10 | embed_freeze=False,
11 | *args, **kwargs):
12 | super(Embedder, self).__init__()
13 | self.embeddings = nn.Embedding.from_pretrained(pretrained_weight, freeze=embed_freeze) \
14 | if pretrained_weight is not None else \
15 | nn.Embedding(vocab_size, word_embedding_dim, padding_idx=padding_idx)
16 | self.embeddings.weight.requires_grad = not embed_freeze
17 |
18 | def forward_mask(self, mask):
19 | bsz, seq_len = mask.shape
20 | new_mask = mask.view(bsz, seq_len, 1)
21 | new_mask = new_mask.sum(-1)
22 | new_mask = (new_mask > 0)
23 | return new_mask
24 |
25 | def forward(self, x, mask=None):
26 | embed = self.embeddings(x)
27 | embed = embed if mask is None else embed * self.forward_mask(mask).unsqueeze(-1).float()
28 | return embed, mask
29 |
30 | @staticmethod
31 | def init_weight(m):
32 | if isinstance(m, nn.Linear):
33 | nn.init.trunc_normal_(m.weight, std=.02)
34 | if isinstance(m, nn.Linear) and m.bias is not None:
35 | nn.init.constant_(m.bias, 0)
36 | else:
37 | nn.init.normal_(m.weight)
38 |
--------------------------------------------------------------------------------
/model/networks/utils/embedder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class Embedder(nn.Module):
5 | def __init__(self,
6 | word_embedding_dim=300,
7 | vocab_size=100000,
8 | padding_idx=1,
9 | pretrained_weight=None,
10 | embed_freeze=False,
11 | *args, **kwargs):
12 | super(Embedder, self).__init__()
13 | self.embeddings = nn.Embedding.from_pretrained(pretrained_weight, freeze=embed_freeze) \
14 | if pretrained_weight is not None else \
15 | nn.Embedding(vocab_size, word_embedding_dim, padding_idx=padding_idx)
16 | self.embeddings.weight.requires_grad = not embed_freeze
17 |
18 | def forward_mask(self, mask):
19 | bsz, seq_len = mask.shape
20 | new_mask = mask.view(bsz, seq_len, 1)
21 | new_mask = new_mask.sum(-1)
22 | new_mask = (new_mask > 0)
23 | return new_mask
24 |
25 | def forward(self, x, mask=None):
26 | embed = self.embeddings(x)
27 | embed = embed if mask is None else embed * self.forward_mask(mask).unsqueeze(-1).float()
28 | return embed, mask
29 |
30 | @staticmethod
31 | def init_weight(m):
32 | if isinstance(m, nn.Linear):
33 | nn.init.trunc_normal_(m.weight, std=.02)
34 | if isinstance(m, nn.Linear) and m.bias is not None:
35 | nn.init.constant_(m.bias, 0)
36 | else:
37 | nn.init.normal_(m.weight)
38 |
--------------------------------------------------------------------------------
/model/models/utils/stochastic_depth.py:
--------------------------------------------------------------------------------
1 | # Thanks to rwightman's timm package
2 | # github.com:rwightman/pytorch-image-models
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | def drop_path(x, drop_prob: float = 0., training: bool = False):
9 | """
10 | Obtained from: github.com:rwightman/pytorch-image-models
11 | Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
12 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
13 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
14 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
15 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
16 | 'survival rate' as the argument.
17 | """
18 | if drop_prob == 0. or not training:
19 | return x
20 | keep_prob = 1 - drop_prob
21 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
22 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
23 | random_tensor.floor_() # binarize
24 | output = x.div(keep_prob) * random_tensor
25 | return output
26 |
27 |
28 | class DropPath(nn.Module):
29 | """
30 | Obtained from: github.com:rwightman/pytorch-image-models
31 | Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
32 | """
33 |
34 | def __init__(self, drop_prob=None):
35 | super(DropPath, self).__init__()
36 | self.drop_prob = drop_prob
37 |
38 | def forward(self, x):
39 | return drop_path(x, self.drop_prob, self.training)
40 |
--------------------------------------------------------------------------------
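DropPath above drops the entire residual branch for a random subset of samples during training and rescales the kept branches by 1/keep_prob. A minimal sketch of the usual wiring (illustrative only; the toy module below is not part of this repository, and the repository root is assumed to be on PYTHONPATH):

import torch
import torch.nn as nn
from model.models.utils.stochastic_depth import DropPath

class ToyResidualBlock(nn.Module):
    # Toy residual block: the main path is stochastically dropped per sample while training.
    def __init__(self, dim, drop_prob=0.1):
        super().__init__()
        self.fc = nn.Linear(dim, dim)
        self.drop_path = DropPath(drop_prob)

    def forward(self, x):
        return x + self.drop_path(self.fc(x))

block = ToyResidualBlock(dim=8)
block.train()
out = block(torch.randn(4, 8))  # roughly drop_prob of the samples keep only the identity path
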
/model/networks/utils/stochastic_depth.py:
--------------------------------------------------------------------------------
1 | # Thanks to rwightman's timm package
2 | # github.com:rwightman/pytorch-image-models
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | def drop_path(x, drop_prob: float = 0., training: bool = False):
9 | """
10 | Obtained from: github.com:rwightman/pytorch-image-models
11 | Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
12 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
13 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
14 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
15 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
16 | 'survival rate' as the argument.
17 | """
18 | if drop_prob == 0. or not training:
19 | return x
20 | keep_prob = 1 - drop_prob
21 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
22 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
23 | random_tensor.floor_() # binarize
24 | output = x.div(keep_prob) * random_tensor
25 | return output
26 |
27 |
28 | class DropPath(nn.Module):
29 | """
30 | Obtained from: github.com:rwightman/pytorch-image-models
31 | Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
32 | """
33 |
34 | def __init__(self, drop_prob=None):
35 | super(DropPath, self).__init__()
36 | self.drop_prob = drop_prob
37 |
38 | def forward(self, x):
39 | return drop_path(x, self.drop_prob, self.training)
40 |
--------------------------------------------------------------------------------
/model/logger.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os.path as osp
3 | from enum import Enum
4 | from collections import defaultdict, OrderedDict
5 | from tensorboardX import SummaryWriter
6 |
7 | class ConfigEncoder(json.JSONEncoder):
8 | def default(self, o):
9 | if isinstance(o, type):
10 | return {'$class': o.__module__ + "." + o.__name__}
11 | elif isinstance(o, Enum):
12 | return {
13 | '$enum': o.__module__ + "." + o.__class__.__name__ + '.' + o.name
14 | }
15 | elif callable(o):
16 | return {
17 | '$function': o.__module__ + "." + o.__name__
18 | }
19 | return json.JSONEncoder.default(self, o)
20 |
21 | class Logger(object):
22 | def __init__(self, args, log_dir, **kwargs):
23 | self.logger_path = osp.join(log_dir, 'scalars.json')
24 | self.tb_logger = SummaryWriter(
25 | logdir=osp.join(log_dir, 'tflogger'),
26 | **kwargs,
27 | )
28 | self.log_config(vars(args))
29 |
30 | self.scalars = defaultdict(OrderedDict)
31 |
32 | def add_scalar(self, key, value, counter):
33 | assert self.scalars[key].get(counter, None) is None, 'counter should be distinct'
34 | self.scalars[key][counter] = value
35 | self.tb_logger.add_scalar(key, value, counter)
36 |
37 | def log_config(self, variant_data):
38 | config_filepath = osp.join(osp.dirname(self.logger_path), 'configs.json')
39 | with open(config_filepath, "w") as fd:
40 | json.dump(variant_data, fd, indent=2, sort_keys=True, cls=ConfigEncoder)
41 |
42 | def dump(self):
43 | with open(self.logger_path, 'w') as fd:
44 | json.dump(self.scalars, fd, indent=2)
--------------------------------------------------------------------------------
/model/dataloader/split_cub.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import argparse
4 | from os import listdir
5 | from os.path import isfile, isdir, join
6 | import random
7 |
8 | if __name__ == '__main__':
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--data', type=str, help='path to the data')
11 | parser.add_argument('--split', type=str, help='path to the split folder')
12 | args = parser.parse_args()
13 | dataset_list = ['train','val','test']
14 | #
15 | prex1 = args.data
16 | data_path = join(prex1,'images/')
17 | #
18 | folder_list = [f for f in listdir(data_path) if isdir(join(data_path, f))]
19 | folder_list.sort()
20 | label_dict = dict(zip(folder_list,range(0,len(folder_list))))
21 |
22 | classfile_list_all = []
23 |
24 | for i, folder in enumerate(folder_list):
25 | folder_path = join(data_path, folder)
26 | classfile_list_all.append( [join(folder,cf) for cf in listdir(folder_path) if (isfile(join(folder_path,cf)) and cf[0] != '.')])
27 | random.shuffle(classfile_list_all[i])
28 |
29 | if not os.path.isdir(args.split):
30 | os.makedirs(args.split)
31 |
32 |
33 | for dataset in dataset_list:
34 | file_list = []
35 | label_list = []
36 | for i, classfile_list in enumerate(classfile_list_all):
37 | if 'train' in dataset:
38 | if (i%2 == 0):
39 | file_list = file_list + classfile_list
40 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist()
41 |
42 | if 'val' in dataset:
43 | if (i%4 == 1):
44 | file_list = file_list + classfile_list
45 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist()
46 |
47 | if 'test' in dataset:
48 | if (i%4 == 3):
49 | file_list = file_list + classfile_list
50 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist()
--------------------------------------------------------------------------------
/model/models/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from sklearn.svm import LinearSVC
5 | from sklearn.linear_model import LogisticRegression
6 | from sklearn.model_selection import GridSearchCV
7 |
8 |
9 | class FewShotModel_1(nn.Module):
10 | def __init__(self, args):
11 | super().__init__()
12 | self.args = args
13 | # from model.models.ddf import DDF
14 | if args.backbone_class == 'Res12':
15 | hdim = 640
16 | from model.networks.res12 import ResNet
17 | self.encoder = ResNet()
18 | elif args.backbone_class == 'Res18':
19 | hdim = 512
20 | from model.networks.res18 import ResNet
21 | self.encoder = ResNet()
22 | else:
23 | raise ValueError('Unsupported backbone_class: {}'.format(args.backbone_class))
24 |
25 | def split_instances(self, data):
26 | args = self.args
27 | if self.training:
28 | return (torch.Tensor(np.arange(args.way*args.shot)).long().view(1, args.shot, args.way),
29 | torch.Tensor(np.arange(args.way*args.shot, args.way * (args.shot + args.query))).long().view(1, args.query, args.way))
30 | else:
31 | return (torch.Tensor(np.arange(args.eval_way*args.eval_shot)).long().view(1, args.eval_shot, args.eval_way),
32 | torch.Tensor(np.arange(args.eval_way*args.eval_shot, args.eval_way * (args.eval_shot + args.eval_query))).long().view(1, args.eval_query, args.eval_way))
33 |
34 |
35 | def forward(self, x, get_feature=False):
36 | if get_feature:
37 | # get feature with the provided embeddings
38 | return self.encoder(x)
39 | else:
40 | # feature extraction
41 | x = x.squeeze(0)
42 | instance_embs = self.encoder(x)
43 |
44 | support_idx, query_idx = self.split_instances(x)
45 | if self.training:
46 | logits, logits_reg = self._forward(instance_embs, support_idx, query_idx)
47 | return logits, logits_reg
48 | else:
49 | logits = self._forward(instance_embs, support_idx, query_idx)
50 | return logits
51 |
52 | def _forward(self, x, support_idx, query_idx):
53 | raise NotImplementedError('Suppose to be implemented by subclass')
--------------------------------------------------------------------------------
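To make split_instances concrete: it assumes each episode tensor is ordered support-first, and returns index grids shaped (1, shot, way) and (1, query, way). A small worked sketch using the same expressions as above, for a 5-way 1-shot episode with 15 queries per class (values are illustrative):

import numpy as np
import torch

way, shot, query = 5, 1, 15
support_idx = torch.Tensor(np.arange(way * shot)).long().view(1, shot, way)
query_idx = torch.Tensor(np.arange(way * shot, way * (shot + query))).long().view(1, query, way)
print(support_idx.shape, query_idx.shape)  # torch.Size([1, 1, 5]) torch.Size([1, 15, 5])
# support instances sit at positions 0..4 of the episode, queries at positions 5..79
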
/model/networks/dropblock.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 | from torch.distributions import Bernoulli
5 |
6 |
7 | class DropBlock(nn.Module):
8 | def __init__(self, block_size):
9 | super(DropBlock, self).__init__()
10 |
11 | self.block_size = block_size
12 |
13 | def forward(self, x, gamma):
14 | # shape: (bsize, channels, height, width)
15 |
16 | if self.training:
17 | batch_size, channels, height, width = x.shape
18 | bernoulli = Bernoulli(gamma)
19 | mask = bernoulli.sample((batch_size, channels, height - (self.block_size - 1), width - (self.block_size - 1)))
20 | if torch.cuda.is_available():
21 | mask = mask.cuda()
22 | block_mask = self._compute_block_mask(mask)
23 | countM = block_mask.size()[0] * block_mask.size()[1] * block_mask.size()[2] * block_mask.size()[3]
24 | count_ones = block_mask.sum()
25 |
26 | return block_mask * x * (countM / count_ones)
27 | else:
28 | return x
29 |
30 | def _compute_block_mask(self, mask):
31 | left_padding = int((self.block_size-1) / 2)
32 | right_padding = int(self.block_size / 2)
33 |
34 | batch_size, channels, height, width = mask.shape
35 | non_zero_idxs = mask.nonzero()
36 | nr_blocks = non_zero_idxs.shape[0]
37 |
38 | offsets = torch.stack(
39 | [
40 | torch.arange(self.block_size).view(-1, 1).expand(self.block_size, self.block_size).reshape(-1), # - left_padding,
41 | torch.arange(self.block_size).repeat(self.block_size), #- left_padding
42 | ]
43 | ).t()
44 | offsets = torch.cat((torch.zeros(self.block_size**2, 2).long(), offsets.long()), 1)
45 | if torch.cuda.is_available():
46 | offsets = offsets.cuda()
47 |
48 | if nr_blocks > 0:
49 | non_zero_idxs = non_zero_idxs.repeat(self.block_size ** 2, 1)
50 | offsets = offsets.repeat(nr_blocks, 1).view(-1, 4)
51 | offsets = offsets.long()
52 |
53 | block_idxs = non_zero_idxs + offsets
54 | #block_idxs += left_padding
55 | padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding))
56 | padded_mask[block_idxs[:, 0], block_idxs[:, 1], block_idxs[:, 2], block_idxs[:, 3]] = 1.
57 | else:
58 | padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding))
59 |
60 | block_mask = 1 - padded_mask#[:height, :width]
61 | return block_mask
62 |
--------------------------------------------------------------------------------
/model/dataloader/samplers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class CategoriesSampler():
6 |
7 | def __init__(self, label, n_batch, n_cls, n_per):
8 | self.n_batch = n_batch
9 | self.n_cls = n_cls
10 | self.n_per = n_per
11 |
12 | label = np.array(label)
13 | self.m_ind = []
14 | for i in range(max(label) + 1):
15 | ind = np.argwhere(label == i).reshape(-1)
16 | ind = torch.from_numpy(ind)
17 | self.m_ind.append(ind)
18 |
19 | def __len__(self):
20 | return self.n_batch
21 |
22 | def __iter__(self):
23 | for i_batch in range(self.n_batch):
24 | batch = []
25 | classes = torch.randperm(len(self.m_ind))[:self.n_cls]
26 | for c in classes:
27 | l = self.m_ind[c]
28 | pos = torch.randperm(len(l))[:self.n_per]
29 | batch.append(l[pos])
30 | batch = torch.stack(batch).t().reshape(-1)
31 | yield batch
32 |
33 |
34 | class RandomSampler():
35 |
36 | def __init__(self, label, n_batch, n_per):
37 | self.n_batch = n_batch
38 | self.n_per = n_per
39 | self.label = np.array(label)
40 | self.num_label = self.label.shape[0]
41 |
42 | def __len__(self):
43 | return self.n_batch
44 |
45 | def __iter__(self):
46 | for i_batch in range(self.n_batch):
47 | batch = torch.randperm(self.num_label)[:self.n_per]
48 | yield batch
49 |
50 |
51 | # sample for each class
52 | class ClassSampler():
53 |
54 | def __init__(self, label, n_per=None):
55 | self.n_per = n_per
56 | label = np.array(label)
57 | self.m_ind = []
58 | for i in range(max(label) + 1):
59 | ind = np.argwhere(label == i).reshape(-1)
60 | ind = torch.from_numpy(ind)
61 | self.m_ind.append(ind)
62 |
63 | def __len__(self):
64 | return len(self.m_ind)
65 |
66 | def __iter__(self):
67 | classes = torch.arange(len(self.m_ind))
68 | for c in classes:
69 | l = self.m_ind[int(c)]
70 | if self.n_per is None:
71 | pos = torch.randperm(len(l))
72 | else:
73 | pos = torch.randperm(len(l))[:self.n_per]
74 | yield l[pos]
75 |
76 |
77 | # for ResNet Fine-Tune, which output the same index of task examples several times
78 | class InSetSampler():
79 |
80 | def __init__(self, n_batch, n_sbatch, pool): # pool is a tensor
81 | self.n_batch = n_batch
82 | self.n_sbatch = n_sbatch
83 | self.pool = pool
84 | self.pool_size = pool.shape[0]
85 |
86 | def __len__(self):
87 | return self.n_batch
88 |
89 | def __iter__(self):
90 | for i_batch in range(self.n_batch):
91 | batch = self.pool[torch.randperm(self.pool_size)[:self.n_sbatch]]
92 | yield batch
--------------------------------------------------------------------------------
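CategoriesSampler yields one episode of dataset indices per iteration (n_cls classes with n_per instances each, interleaved class-wise by the final transpose), so it is meant to be used as a DataLoader batch_sampler. A hedged usage sketch with a made-up dataset standing in for the dataset classes under model/dataloader:

import torch
from torch.utils.data import DataLoader, TensorDataset
from model.dataloader.samplers import CategoriesSampler

labels = [i % 10 for i in range(400)]                    # 400 fake instances over 10 classes
dataset = TensorDataset(torch.randn(400, 3), torch.tensor(labels))

# 200 episodes, each with 5 classes x (1 support + 15 query) = 80 instances
sampler = CategoriesSampler(labels, n_batch=200, n_cls=5, n_per=16)
loader = DataLoader(dataset, batch_sampler=sampler, num_workers=0)

data, lbl = next(iter(loader))   # data: [80, 3]; every consecutive group of 5 indices covers all sampled classes
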
/model/dataloader/transforms.py:
--------------------------------------------------------------------------------
1 | # Credits to DeepVoltaire
2 | # github:DeepVoltaire/AutoAugment
3 |
4 | from PIL import Image, ImageEnhance, ImageOps
5 | import random
6 |
7 |
8 | class ShearX(object):
9 | def __init__(self, fillcolor=(128, 128, 128)):
10 | self.fillcolor = fillcolor
11 |
12 | def __call__(self, x, magnitude):
13 | return x.transform(
14 | x.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
15 | Image.BICUBIC, fillcolor=self.fillcolor)
16 |
17 |
18 | class ShearY(object):
19 | def __init__(self, fillcolor=(128, 128, 128)):
20 | self.fillcolor = fillcolor
21 |
22 | def __call__(self, x, magnitude):
23 | return x.transform(
24 | x.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
25 | Image.BICUBIC, fillcolor=self.fillcolor)
26 |
27 |
28 | class TranslateX(object):
29 | def __init__(self, fillcolor=(128, 128, 128)):
30 | self.fillcolor = fillcolor
31 |
32 | def __call__(self, x, magnitude):
33 | return x.transform(
34 | x.size, Image.AFFINE, (1, 0, magnitude * x.size[0] * random.choice([-1, 1]), 0, 1, 0),
35 | fillcolor=self.fillcolor)
36 |
37 |
38 | class TranslateY(object):
39 | def __init__(self, fillcolor=(128, 128, 128)):
40 | self.fillcolor = fillcolor
41 |
42 | def __call__(self, x, magnitude):
43 | return x.transform(
44 | x.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * x.size[1] * random.choice([-1, 1])),
45 | fillcolor=self.fillcolor)
46 |
47 |
48 | class Rotate(object):
49 | # from https://stackoverflow.com/questions/
50 | # 5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
51 | def __call__(self, x, magnitude):
52 | rot = x.convert("RGBA").rotate(magnitude)
53 | return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(x.mode)
54 |
55 |
56 | class Color(object):
57 | def __call__(self, x, magnitude):
58 | return ImageEnhance.Color(x).enhance(1 + magnitude * random.choice([-1, 1]))
59 |
60 |
61 | class Posterize(object):
62 | def __call__(self, x, magnitude):
63 | return ImageOps.posterize(x, magnitude)
64 |
65 |
66 | class Solarize(object):
67 | def __call__(self, x, magnitude):
68 | return ImageOps.solarize(x, magnitude)
69 |
70 |
71 | class Contrast(object):
72 | def __call__(self, x, magnitude):
73 | return ImageEnhance.Contrast(x).enhance(1 + magnitude * random.choice([-1, 1]))
74 |
75 |
76 | class Sharpness(object):
77 | def __call__(self, x, magnitude):
78 | return ImageEnhance.Sharpness(x).enhance(1 + magnitude * random.choice([-1, 1]))
79 |
80 |
81 | class Brightness(object):
82 | def __call__(self, x, magnitude):
83 | return ImageEnhance.Brightness(x).enhance(1 + magnitude * random.choice([-1, 1]))
84 |
85 |
86 | class AutoContrast(object):
87 | def __call__(self, x, magnitude):
88 | return ImageOps.autocontrast(x)
89 |
90 |
91 | class Equalize(object):
92 | def __call__(self, x, magnitude):
93 | return ImageOps.equalize(x)
94 |
95 |
96 | class Invert(object):
97 | def __call__(self, x, magnitude):
98 | return ImageOps.invert(x)
99 |
--------------------------------------------------------------------------------
/model/models/protonet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import torch.nn.functional as F
5 |
6 | from model.models import FewShotModel_1
7 |
8 | """
9 | The ProtoNet class inherits from FewShotModel_1, the base few-shot model in this repository, which builds the backbone encoder and splits each episode into support and query indices.
10 | This implementation targets the setting where the model must learn from a small number of labelled examples (support set)
11 | and generalize to new, unseen examples (query set).
12 | """
13 |
14 | class ProtoNet(FewShotModel_1):
15 | def __init__(self, args):
16 | """
17 | Initialize the ProtoNet with the given arguments.
18 | This constructor passes any arguments to the superclass FewShotModel_1, which might perform some initial setup.
19 | """
20 | super().__init__(args)
21 |
22 | def _forward(self, instance_embs, support_idx, query_idx):
23 | """
24 | Custom forward logic for processing instance embeddings and calculating the prototypes.
25 |
26 | Parameters:
27 | - instance_embs: Tensor containing embeddings for all instances.
28 | - support_idx: Indices of support examples within instance_embs.
29 | - query_idx: Indices of query examples within instance_embs.
30 |
31 | The method handles two cases:
32 | 1. If Grad-CAM is enabled, it returns the raw embeddings for visualization purposes.
33 | 2. Otherwise, it processes the embeddings to compute class prototypes and their distances to query examples.
34 | """
35 | if self.args.grad_cam:
36 | # Return embeddings directly for Grad-CAM visualization.
37 | return instance_embs
38 |
39 | else:
40 | # Extract the size of the last dimension, which represents the dimensionality of the embeddings.
41 | emb_dim = instance_embs.size(-1)
42 |
43 | # Organize support and query data by reshaping them according to their indices.
44 | support = instance_embs[support_idx.flatten()].view(*(support_idx.shape + (-1,)))
45 | query = instance_embs[query_idx.flatten()].view(*(query_idx.shape + (-1,)))
46 |
47 | # Compute the mean of the support embeddings to form the prototypes for each class.
48 | proto = support.mean(dim=1) # Ntask x NK x d
49 |
50 | # Prepare for distance calculation between queries and prototypes.
51 | num_batch = proto.shape[0]
52 | num_proto = proto.shape[1]
53 | num_query = np.prod(query_idx.shape[-2:])
54 |
55 | if True: # Placeholder for a boolean flag such as self.args.use_euclidean
56 | # Compute Euclidean distances
57 | query = query.view(-1, emb_dim).unsqueeze(1) # Reshape for broadcasting
58 | proto = proto.unsqueeze(1).expand(num_batch, num_query, num_proto, emb_dim)
59 | proto = proto.contiguous().view(num_batch * num_query, num_proto, emb_dim)
60 | logits = - torch.sum((proto - query) ** 2, 2) / self.args.temperature
61 | else:
62 | # Compute Cosine similarity
63 | proto = F.normalize(proto, dim=-1) # Normalize for cosine distance
64 | query = query.view(num_batch, -1, emb_dim) # Reshape for matrix multiplication
65 | logits = torch.bmm(query, proto.permute([0, 2, 1])) / self.args.temperature
66 | logits = logits.view(-1, num_proto)
67 |
68 | # Depending on the training state, return logits directly or with additional processing.
69 | if self.training:
70 | return logits, None
71 | else:
72 | return logits
73 |
--------------------------------------------------------------------------------
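A self-contained sketch of the prototype computation described above, with illustrative shapes and temperature (the real values come from args): support embeddings are averaged per class, and each query is scored by its negative squared Euclidean distance to every prototype.

import torch

way, shot, n_query, emb_dim, temperature = 5, 1, 15, 640, 64.0
support = torch.randn(1, shot, way, emb_dim)        # [task, shot, way, d]
query = torch.randn(1, n_query, way, emb_dim)       # [task, query, way, d]

proto = support.mean(dim=1)                         # [1, way, d] class prototypes
q = query.view(-1, emb_dim).unsqueeze(1)            # [75, 1, d]
p = proto.expand(q.shape[0], way, emb_dim)          # [75, way, d]

logits = -((p - q) ** 2).sum(dim=2) / temperature   # [75, way]
pred = logits.argmax(dim=1)                         # predicted class per query
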
/model/trainer/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import torch
3 | import os.path as osp
4 |
5 | from model.utils import (
6 | ensure_path,
7 | Averager, Timer, count_acc,
8 | compute_confidence_interval,
9 | )
10 | from model.logger import Logger
11 |
12 | class Trainer(object, metaclass=abc.ABCMeta):
13 | def __init__(self, args):
14 | self.args = args
15 | # ensure_path(
16 | # self.args.save_path,
17 | # scripts_to_save=['model/models', 'model/networks', __file__],
18 | # )
19 | self.logger = Logger(args, osp.join(args.save_path))
20 |
21 | self.train_step = 0
22 | self.train_epoch = 0
23 | self.max_steps = args.episodes_per_epoch * args.max_epoch
24 | self.dt, self.ft = Averager(), Averager()
25 | self.bt, self.ot = Averager(), Averager()
26 | self.timer = Timer()
27 |
28 | # train statistics
29 | self.trlog = {}
30 | self.trlog['max_acc'] = 0.0
31 | self.trlog['max_acc_epoch'] = 0
32 | self.trlog['max_acc_interval'] = 0.0
33 |
34 | @abc.abstractmethod
35 | def train(self):
36 | pass
37 |
38 | @abc.abstractmethod
39 | def evaluate(self, data_loader):
40 | pass
41 |
42 | @abc.abstractmethod
43 | def evaluate_test(self, data_loader):
44 | pass
45 |
46 | @abc.abstractmethod
47 | def final_record(self):
48 | pass
49 |
50 | def try_evaluate(self, epoch):
51 | args = self.args
52 | if self.train_epoch % args.eval_interval == 0:
53 | vl, va, vap = self.evaluate(self.val_loader)
54 | self.logger.add_scalar('val_loss', float(vl), self.train_epoch)
55 | self.logger.add_scalar('val_acc', float(va), self.train_epoch)
56 |
57 | print('epoch {}, val, loss={:.4f} acc={:.4f}+{:.4f}'.format(epoch, vl, va, vap))
58 |
59 | if va >= self.trlog['max_acc']:
60 | self.trlog['max_acc'] = va
61 | self.trlog['max_acc_interval'] = vap
62 | self.trlog['max_acc_epoch'] = self.train_epoch
63 | self.save_model('max_acc')
64 |
65 | def try_logging(self, tl1, tl2, ta, tg=None):
66 | args = self.args
67 | if self.train_step % args.log_interval == 0:
68 | print('epoch {}, train {:06g}/{:06g}, total loss={:.4f}, loss={:.4f} acc={:.4f}, lr={:.4g}'
69 | .format(self.train_epoch,
70 | self.train_step,
71 | self.max_steps,
72 | tl1.item(), tl2.item(), ta.item(),
73 | self.optimizer.param_groups[0]['lr']))
74 | self.logger.add_scalar('train_total_loss', tl1.item(), self.train_step)
75 | self.logger.add_scalar('train_loss', tl2.item(), self.train_step)
76 | self.logger.add_scalar('train_acc', ta.item(), self.train_step)
77 | if tg is not None:
78 | self.logger.add_scalar('grad_norm', tg.item(), self.train_step)
79 | print('data_timer: {:.2f} sec, ' \
80 | 'forward_timer: {:.2f} sec,' \
81 | 'backward_timer: {:.2f} sec, ' \
82 | 'optim_timer: {:.2f} sec'.format(
83 | self.dt.item(), self.ft.item(),
84 | self.bt.item(), self.ot.item())
85 | )
86 | self.logger.dump()
87 |
88 | def save_model(self, name):
89 | torch.save(
90 | dict(params=self.model.state_dict()),
91 | osp.join(self.args.save_path, name + '.pth')
92 | )
93 |
94 | def __str__(self):
95 | return "{}({})".format(
96 | self.__class__.__name__,
97 | self.model.__class__.__name__
98 | )
99 |
--------------------------------------------------------------------------------
/model/data_parallel.py:
--------------------------------------------------------------------------------
1 | from torch.nn.parallel import DataParallel
2 | import torch
3 | from torch.nn.parallel._functions import Scatter
4 | from torch.nn.parallel.parallel_apply import parallel_apply
5 |
6 | def scatter(inputs, target_gpus, chunk_sizes, dim=0):
7 | r"""
8 | Slices tensors into approximately equal chunks and
9 | distributes them across given GPUs. Duplicates
10 | references to objects that are not tensors.
11 | """
12 | def scatter_map(obj):
13 | if isinstance(obj, torch.Tensor):
14 | try:
15 | return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
16 | except:
17 | print('obj', obj.size())
18 | print('dim', dim)
19 | print('chunk_sizes', chunk_sizes)
20 | quit()
21 | if isinstance(obj, tuple) and len(obj) > 0:
22 | return list(zip(*map(scatter_map, obj)))
23 | if isinstance(obj, list) and len(obj) > 0:
24 | return list(map(list, zip(*map(scatter_map, obj))))
25 | if isinstance(obj, dict) and len(obj) > 0:
26 | return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
27 | return [obj for targets in target_gpus]
28 |
29 | # After scatter_map is called, a scatter_map cell will exist. This cell
30 | # has a reference to the actual function scatter_map, which has references
31 | # to a closure that has a reference to the scatter_map cell (because the
32 | # fn is recursive). To avoid this reference cycle, we set the function to
33 | # None, clearing the cell
34 | try:
35 | return scatter_map(inputs)
36 | finally:
37 | scatter_map = None
38 |
39 | def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):
40 | r"""Scatter with support for kwargs dictionary"""
41 | inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []
42 | kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []
43 | if len(inputs) < len(kwargs):
44 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
45 | elif len(kwargs) < len(inputs):
46 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
47 | inputs = tuple(inputs)
48 | kwargs = tuple(kwargs)
49 | return inputs, kwargs
50 |
51 | class BalancedDataParallel(DataParallel):
52 | def __init__(self, gpu0_bsz, *args, **kwargs):
53 | self.gpu0_bsz = gpu0_bsz
54 | super().__init__(*args, **kwargs)
55 |
56 | def forward(self, *inputs, **kwargs):
57 | if not self.device_ids:
58 | return self.module(*inputs, **kwargs)
59 | if self.gpu0_bsz == 0:
60 | device_ids = self.device_ids[1:]
61 | else:
62 | device_ids = self.device_ids
63 | inputs, kwargs = self.scatter(inputs, kwargs, device_ids)
64 | if len(self.device_ids) == 1:
65 | return self.module(*inputs[0], **kwargs[0])
66 | replicas = self.replicate(self.module, self.device_ids)
67 | if self.gpu0_bsz == 0:
68 | replicas = replicas[1:]
69 | outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
70 | return self.gather(outputs, self.output_device)
71 |
72 | def parallel_apply(self, replicas, device_ids, inputs, kwargs):
73 | return parallel_apply(replicas, inputs, kwargs, device_ids)
74 |
75 | def scatter(self, inputs, kwargs, device_ids):
76 | bsz = inputs[0].size(self.dim)
77 | num_dev = len(self.device_ids)
78 | gpu0_bsz = self.gpu0_bsz
79 | bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
80 | if gpu0_bsz < bsz_unit:
81 | chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
82 | delta = bsz - sum(chunk_sizes)
83 | for i in range(delta):
84 | chunk_sizes[i + 1] += 1
85 | if gpu0_bsz == 0:
86 | chunk_sizes = chunk_sizes[1:]
87 | else:
88 | return super().scatter(inputs, kwargs, device_ids)
89 | return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)
90 |
91 |
--------------------------------------------------------------------------------
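BalancedDataParallel gives GPU 0 a smaller share of each batch (useful when GPU 0 also holds optimizer state or gathered outputs) and splits the remainder evenly across the other devices. A usage sketch, assuming at least two CUDA devices and a toy module that is not part of this repository:

import torch
import torch.nn as nn
from model.data_parallel import BalancedDataParallel

if torch.cuda.device_count() >= 2:
    net = nn.Linear(128, 10).cuda()
    para_net = BalancedDataParallel(8, net, dim=0)   # GPU 0 receives 8 samples; the rest are shared by the other GPUs
    out = para_net(torch.randn(64, 128).cuda())      # out: [64, 10], gathered back on the output device
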
/model/models/utils/tokenizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class Tokenizer(nn.Module):
7 | def __init__(self,
8 | kernel_size, stride, padding,
9 | pooling_kernel_size=3, pooling_stride=2, pooling_padding=1,
10 | n_conv_layers=1,
11 | n_input_channels=3,
12 | n_output_channels=64,
13 | in_planes=64,
14 | activation=None,
15 | max_pool=True,
16 | conv_bias=False):
17 | super(Tokenizer, self).__init__()
18 |
19 | n_filter_list = [n_input_channels] + \
20 | [in_planes for _ in range(n_conv_layers - 1)] + \
21 | [n_output_channels]
22 |
23 | self.conv_layers = nn.Sequential(
24 | *[nn.Sequential(
25 | nn.Conv2d(n_filter_list[i], n_filter_list[i + 1],
26 | kernel_size=(kernel_size, kernel_size),
27 | stride=(stride, stride),
28 | padding=(padding, padding), bias=conv_bias),
29 | nn.Identity() if activation is None else activation(),
30 | nn.MaxPool2d(kernel_size=pooling_kernel_size,
31 | stride=pooling_stride,
32 | padding=pooling_padding) if max_pool else nn.Identity()
33 | )
34 | for i in range(n_conv_layers)
35 | ])
36 |
37 | self.flattener = nn.Flatten(2, 3)
38 | self.apply(self.init_weight)
39 |
40 | def sequence_length(self, n_channels=3, height=224, width=224):
41 | return self.forward(torch.zeros((1, n_channels, height, width))).shape[1]
42 |
43 | def forward(self, x):
44 | return self.flattener(self.conv_layers(x)).transpose(-2, -1)
45 |
46 | @staticmethod
47 | def init_weight(m):
48 | if isinstance(m, nn.Conv2d):
49 | nn.init.kaiming_normal_(m.weight)
50 |
51 |
52 | class TextTokenizer(nn.Module):
53 | def __init__(self,
54 | kernel_size, stride, padding,
55 | pooling_kernel_size=3, pooling_stride=2, pooling_padding=1,
56 | embedding_dim=300,
57 | n_output_channels=128,
58 | activation=None,
59 | max_pool=True,
60 | *args, **kwargs):
61 | super(TextTokenizer, self).__init__()
62 |
63 | self.max_pool = max_pool
64 | self.conv_layers = nn.Sequential(
65 | nn.Conv2d(1, n_output_channels,
66 | kernel_size=(kernel_size, embedding_dim),
67 | stride=(stride, 1),
68 | padding=(padding, 0), bias=False),
69 | nn.Identity() if activation is None else activation(),
70 | nn.MaxPool2d(
71 | kernel_size=(pooling_kernel_size, 1),
72 | stride=(pooling_stride, 1),
73 | padding=(pooling_padding, 0)
74 | ) if max_pool else nn.Identity()
75 | )
76 |
77 | self.apply(self.init_weight)
78 |
79 | def seq_len(self, seq_len=32, embed_dim=300):
80 | return self.forward(torch.zeros((1, seq_len, embed_dim)))[0].shape[1]
81 |
82 | def forward_mask(self, mask):
83 | new_mask = mask.unsqueeze(1).float()
84 | cnn_weight = torch.ones(
85 | (1, 1, self.conv_layers[0].kernel_size[0]),
86 | device=mask.device,
87 | dtype=torch.float)
88 | new_mask = F.conv1d(
89 | new_mask, cnn_weight, None,
90 | self.conv_layers[0].stride[0], self.conv_layers[0].padding[0], 1, 1)
91 | if self.max_pool:
92 | new_mask = F.max_pool1d(
93 | new_mask, self.conv_layers[2].kernel_size[0],
94 | self.conv_layers[2].stride[0], self.conv_layers[2].padding[0], 1, False, False)
95 | new_mask = new_mask.squeeze(1)
96 | new_mask = (new_mask > 0)
97 | return new_mask
98 |
99 | def forward(self, x, mask=None):
100 | x = x.unsqueeze(1)
101 | x = self.conv_layers(x)
102 | x = x.transpose(1, 3).squeeze(1)
103 | x = x if mask is None else x * self.forward_mask(mask).unsqueeze(-1).float()
104 | return x, mask
105 |
106 | @staticmethod
107 | def init_weight(m):
108 | if isinstance(m, nn.Conv2d):
109 | nn.init.kaiming_normal_(m.weight)
110 |
--------------------------------------------------------------------------------
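A small sketch of what Tokenizer produces (parameter values here are illustrative, not the repository's configuration): a convolution followed by optional max pooling, with the resulting spatial grid flattened into a token sequence.

import torch
from model.models.utils.tokenizer import Tokenizer

tok = Tokenizer(kernel_size=3, stride=1, padding=1, n_conv_layers=1)
tokens = tok(torch.zeros(2, 3, 84, 84))
print(tokens.shape)                     # [2, 1764, 64]: one 64-d token per pooled 42x42 spatial location
print(tok.sequence_length(3, 84, 84))   # 1764
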
/model/networks/utils/tokenizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class Tokenizer(nn.Module):
7 | def __init__(self,
8 | kernel_size, stride, padding,
9 | pooling_kernel_size=3, pooling_stride=2, pooling_padding=1,
10 | n_conv_layers=1,
11 | n_input_channels=3,
12 | n_output_channels=64,
13 | in_planes=64,
14 | activation=None,
15 | max_pool=True,
16 | conv_bias=False):
17 | super(Tokenizer, self).__init__()
18 |
19 | n_filter_list = [n_input_channels] + \
20 | [in_planes for _ in range(n_conv_layers - 1)] + \
21 | [n_output_channels]
22 |
23 | self.conv_layers = nn.Sequential(
24 | *[nn.Sequential(
25 | nn.Conv2d(n_filter_list[i], n_filter_list[i + 1],
26 | kernel_size=(kernel_size, kernel_size),
27 | stride=(stride, stride),
28 | padding=(padding, padding), bias=conv_bias),
29 | nn.Identity() if activation is None else activation(),
30 | nn.MaxPool2d(kernel_size=pooling_kernel_size,
31 | stride=pooling_stride,
32 | padding=pooling_padding) if max_pool else nn.Identity()
33 | )
34 | for i in range(n_conv_layers)
35 | ])
36 |
37 | self.flattener = nn.Flatten(2, 3)
38 | self.apply(self.init_weight)
39 |
40 | def sequence_length(self, n_channels=3, height=224, width=224):
41 | return self.forward(torch.zeros((1, n_channels, height, width))).shape[1]
42 |
43 | def forward(self, x):
44 | return self.flattener(self.conv_layers(x)).transpose(-2, -1)
45 |
46 | @staticmethod
47 | def init_weight(m):
48 | if isinstance(m, nn.Conv2d):
49 | nn.init.kaiming_normal_(m.weight)
50 |
51 |
52 | class TextTokenizer(nn.Module):
53 | def __init__(self,
54 | kernel_size, stride, padding,
55 | pooling_kernel_size=3, pooling_stride=2, pooling_padding=1,
56 | embedding_dim=300,
57 | n_output_channels=128,
58 | activation=None,
59 | max_pool=True,
60 | *args, **kwargs):
61 | super(TextTokenizer, self).__init__()
62 |
63 | self.max_pool = max_pool
64 | self.conv_layers = nn.Sequential(
65 | nn.Conv2d(1, n_output_channels,
66 | kernel_size=(kernel_size, embedding_dim),
67 | stride=(stride, 1),
68 | padding=(padding, 0), bias=False),
69 | nn.Identity() if activation is None else activation(),
70 | nn.MaxPool2d(
71 | kernel_size=(pooling_kernel_size, 1),
72 | stride=(pooling_stride, 1),
73 | padding=(pooling_padding, 0)
74 | ) if max_pool else nn.Identity()
75 | )
76 |
77 | self.apply(self.init_weight)
78 |
79 | def seq_len(self, seq_len=32, embed_dim=300):
80 | return self.forward(torch.zeros((1, seq_len, embed_dim)))[0].shape[1]
81 |
82 | def forward_mask(self, mask):
83 | new_mask = mask.unsqueeze(1).float()
84 | cnn_weight = torch.ones(
85 | (1, 1, self.conv_layers[0].kernel_size[0]),
86 | device=mask.device,
87 | dtype=torch.float)
88 | new_mask = F.conv1d(
89 | new_mask, cnn_weight, None,
90 | self.conv_layers[0].stride[0], self.conv_layers[0].padding[0], 1, 1)
91 | if self.max_pool:
92 | new_mask = F.max_pool1d(
93 | new_mask, self.conv_layers[2].kernel_size[0],
94 | self.conv_layers[2].stride[0], self.conv_layers[2].padding[0], 1, False, False)
95 | new_mask = new_mask.squeeze(1)
96 | new_mask = (new_mask > 0)
97 | return new_mask
98 |
99 | def forward(self, x, mask=None):
100 | x = x.unsqueeze(1)
101 | x = self.conv_layers(x)
102 | x = x.transpose(1, 3).squeeze(1)
103 | x = x if mask is None else x * self.forward_mask(mask).unsqueeze(-1).float()
104 | return x, mask
105 |
106 | @staticmethod
107 | def init_weight(m):
108 | if isinstance(m, nn.Conv2d):
109 | nn.init.kaiming_normal_(m.weight)
110 |
--------------------------------------------------------------------------------
/model/dataloader/mini_imagenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os.path as osp
3 | from PIL import Image
4 | from .transforms import *
5 | import PIL
6 |
7 | from torch.utils.data import Dataset
8 | from torchvision import transforms
9 | from tqdm import tqdm
10 | import numpy as np
11 |
12 | # Paths
13 | THIS_PATH = osp.dirname(__file__)
14 | ROOT_PATH = osp.abspath(osp.join(THIS_PATH, '..', '..'))
15 | ROOT_PATH2 = osp.abspath(osp.join(THIS_PATH, '..', '..', '..'))
16 | IMAGE_PATH1 = osp.join(ROOT_PATH, 'data/miniimagenet/images')
17 | SPLIT_PATH = osp.join(ROOT_PATH, 'data/miniimagenet/split')
18 | CACHE_PATH = osp.join(ROOT_PATH, '.cache/')
19 |
20 |
21 | def identity(x):
22 | """Identity function."""
23 | return x
24 |
25 | class MiniImageNet(Dataset):
26 | """Dataset class for MiniImageNet."""
27 | def __init__(self, setname, args, augment=False):
28 | """Initialize MiniImageNet dataset."""
29 | im_size = args.orig_imsize
30 | csv_path = osp.join(SPLIT_PATH, setname + '.csv')
31 | cache_path = osp.join( CACHE_PATH, "{}.{}.{}.pt".format(self.__class__.__name__, setname, im_size) )
32 | self.args = args
33 | self.use_im_cache = ( im_size != -1 )  # cache resized images unless im_size == -1
34 |
35 | # Check if using image cache
36 | if self.use_im_cache:
37 | if not osp.exists(cache_path):
38 | print('* Cache miss... Preprocessing {}...'.format(setname))
39 | resize_ = identity if im_size < 0 else transforms.Resize(im_size)
40 | data, label = self.parse_csv(csv_path, setname)
41 | self.data = [ resize_(Image.open(path).convert('RGB')) for path in data ]
42 | self.label = label
43 | print('* Dump cache to {}'.format(cache_path))
44 | torch.save({'data': self.data, 'label': self.label }, cache_path)
45 | else:
46 | print('* Load cache from {}'.format(cache_path))
47 | cache = torch.load(cache_path)
48 | self.data = cache['data']
49 | self.label = cache['label']
50 | else:
51 | self.data, self.label = self.parse_csv(csv_path, setname)
52 |
53 | self.num_class = len(set(self.label))
54 |
55 | image_size = 84
56 | if augment and setname == 'train':
57 | # Augmentation transforms
58 | transforms_list = [
59 | transforms.RandomResizedCrop(image_size),
60 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
61 | transforms.RandomHorizontalFlip(),
62 | transforms.ToTensor(),
63 | ]
64 | else:
65 | # Validation/Test transforms
66 | transforms_list = [
67 | transforms.Resize(92),
68 | transforms.CenterCrop(image_size),
69 | transforms.ToTensor(),
70 | ]
71 |
72 | # Transformation based on backbone class
73 | if args.backbone_class == 'Res12' :
74 | self.transform = transforms.Compose(
75 | transforms_list + [
76 | transforms.Normalize(np.array([x / 255.0 for x in [120.39586422, 115.59361427, 104.54012653]]),
77 | np.array([x / 255.0 for x in [70.68188272, 68.27635443, 72.54505529]]))
78 | ])
79 | elif args.backbone_class == 'Res18':
80 | self.transform = transforms.Compose(
81 | transforms_list + [
82 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
83 | std=[0.229, 0.224, 0.225])
84 | ])
85 | else:
86 | raise ValueError('Non-supported Network Types. Please Revise Data Pre-Processing Scripts.')
87 |
88 | def parse_csv(self, csv_path, setname):
89 | """Parse CSV file to get image paths and labels."""
90 | lines = [x.strip() for x in open(csv_path, 'r').readlines()][1:]
91 |
92 | data = []
93 | label = []
94 | lb = -1
95 |
96 | self.wnids = []
97 |
98 | for l in tqdm(lines, ncols=64):
99 | name, wnid = l.split(',')
100 | path = osp.join(IMAGE_PATH1, name)
101 | if wnid not in self.wnids:
102 | self.wnids.append(wnid)
103 | lb += 1
104 | data.append( path )
105 | label.append(lb)
106 |
107 | return data, label
108 |
109 | def __len__(self):
110 | """Get the length of the dataset."""
111 | return len(self.data)
112 |
113 | def __getitem__(self, i):
114 | """Get an item from the dataset."""
115 | data, label = self.data[i], self.label[i]
116 | if self.use_im_cache:
117 | image = self.transform(data)
118 | else:
119 | image = self.transform(Image.open(data).convert('RGB'))
120 |
121 | return image, label
122 |
--------------------------------------------------------------------------------
/model/networks/res12.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | from model.networks.dropblock import DropBlock
5 |
6 | # This ResNet network was designed following the practice of the following papers:
7 | # TADAM: Task dependent adaptive metric for improved few-shot learning (Oreshkin et al., in NIPS 2018) and
8 | # A Simple Neural Attentive Meta-Learner (Mishra et al., in ICLR 2018).
9 |
10 | def conv3x3(in_planes, out_planes, stride=1):
11 | """3x3 convolution with padding"""
12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
13 | padding=1, bias=False)
14 |
15 |
16 | class BasicBlock(nn.Module):
17 | expansion = 1
18 |
19 | def __init__(self, inplanes, planes, stride=1, downsample=None, drop_rate=0.0, drop_block=False, block_size=1):
20 | super(BasicBlock, self).__init__()
21 | self.conv1 = conv3x3(inplanes, planes)
22 | self.bn1 = nn.BatchNorm2d(planes)
23 | self.relu = nn.LeakyReLU(0.1)
24 | self.conv2 = conv3x3(planes, planes)
25 | self.bn2 = nn.BatchNorm2d(planes)
26 | self.conv3 = conv3x3(planes, planes)
27 | self.bn3 = nn.BatchNorm2d(planes)
28 | self.maxpool = nn.MaxPool2d(stride)
29 | self.downsample = downsample
30 | self.stride = stride
31 | self.drop_rate = drop_rate
32 | self.num_batches_tracked = 0
33 | self.drop_block = drop_block
34 | self.block_size = block_size
35 | self.DropBlock = DropBlock(block_size=self.block_size)
36 |
37 | def forward(self, x):
38 | self.num_batches_tracked += 1
39 |
40 | residual = x
41 |
42 | out = self.conv1(x)
43 | out = self.bn1(out)
44 | out = self.relu(out)
45 |
46 | out = self.conv2(out)
47 | out = self.bn2(out)
48 | out = self.relu(out)
49 |
50 | out = self.conv3(out)
51 | out = self.bn3(out)
52 |
53 | if self.downsample is not None:
54 | residual = self.downsample(x)
55 | out += residual
56 | out = self.relu(out)
57 | out = self.maxpool(out)
58 |
59 | if self.drop_rate > 0:
60 | if self.drop_block == True:
61 | feat_size = out.size()[2]
62 | keep_rate = max(1.0 - self.drop_rate / (20*2000) * (self.num_batches_tracked), 1.0 - self.drop_rate)
63 | gamma = (1 - keep_rate) / self.block_size**2 * feat_size**2 / (feat_size - self.block_size + 1)**2
64 | out = self.DropBlock(out, gamma=gamma)
65 | else:
66 | out = F.dropout(out, p=self.drop_rate, training=self.training, inplace=True)
67 |
68 | return out
69 |
70 |
71 | class ResNet(nn.Module):
72 |
73 | def __init__(self, block=BasicBlock, keep_prob=1.0, avg_pool=True, drop_rate=0.1, dropblock_size=5):
74 | self.inplanes = 3
75 | super(ResNet, self).__init__()
76 |
77 | self.layer1 = self._make_layer(block, 64, stride=2, drop_rate=drop_rate)
78 | self.layer2 = self._make_layer(block, 160, stride=2, drop_rate=drop_rate)
79 | self.layer3 = self._make_layer(block, 320, stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size)
80 | self.layer4 = self._make_layer(block, 640, stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size)
81 | if avg_pool:
82 | self.avgpool = nn.AvgPool2d(5, stride=1)
83 | self.keep_prob = keep_prob
84 | self.keep_avg_pool = avg_pool
85 | self.dropout = nn.Dropout(p=1 - self.keep_prob, inplace=False)
86 | self.drop_rate = drop_rate
87 |
88 | for m in self.modules():
89 | if isinstance(m, nn.Conv2d):
90 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
91 | elif isinstance(m, nn.BatchNorm2d):
92 | nn.init.constant_(m.weight, 1)
93 | nn.init.constant_(m.bias, 0)
94 |
95 | def _make_layer(self, block, planes, stride=1, drop_rate=0.0, drop_block=False, block_size=1):
96 | downsample = None
97 | if stride != 1 or self.inplanes != planes * block.expansion:
98 | downsample = nn.Sequential(
99 | nn.Conv2d(self.inplanes, planes * block.expansion,
100 | kernel_size=1, stride=1, bias=False),
101 | nn.BatchNorm2d(planes * block.expansion),
102 | )
103 |
104 | layers = []
105 | layers.append(block(self.inplanes, planes, stride, downsample, drop_rate, drop_block, block_size))
106 | self.inplanes = planes * block.expansion
107 |
108 | return nn.Sequential(*layers)
109 |
110 | def forward(self, x):
111 | x = self.layer1(x)
112 | x = self.layer2(x)
113 | x = self.layer3(x)
114 | x = self.layer4(x)
115 | # x = torch.nn.functional.interpolate(x, size=(7, 7), mode='bilinear')
116 | # if self.keep_avg_pool:
117 | # x = self.avgpool(x)
118 | # x = x.view(x.size(0), -1)
119 | return x
120 |
121 |
122 | def Res12(keep_prob=1.0, avg_pool=False, **kwargs):
123 | """Constructs a ResNet-12 model.
124 | """
125 | model = ResNet(BasicBlock, keep_prob=keep_prob, avg_pool=avg_pool, **kwargs)
126 | return model
127 |
--------------------------------------------------------------------------------
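For orientation, the forward above returns the un-pooled feature map (the average-pooling lines are commented out). A quick sketch, assuming it is run from the repository root, showing where the hdim = 640 used in model/models/base.py comes from for an 84x84 input:

import torch
from model.networks.res12 import ResNet

encoder = ResNet()
encoder.eval()
with torch.no_grad():
    feats = encoder(torch.zeros(2, 3, 84, 84))
print(feats.shape)   # torch.Size([2, 640, 5, 5]): 640 channels over a 5x5 spatial grid
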
/model/dataloader/cub.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import PIL
3 | from PIL import Image
4 |
5 | import numpy as np
6 | from torch.utils.data import Dataset
7 | from torchvision import transforms
8 | import torch
9 | THIS_PATH = osp.dirname(__file__)
10 | ROOT_PATH = osp.abspath(osp.join(THIS_PATH, '..', '..'))
11 | ROOT_PATH2 = osp.abspath(osp.join(THIS_PATH, '..', '..', '..'))
12 | IMAGE_PATH = 'data/CUB/CUB_200_2011/images'  # previously used data in feat/data/cub/images
13 |
14 | SPLIT_PATH = 'data/CUB/split'
15 | CACHE_PATH = osp.join(ROOT_PATH, '.cache/')
16 |
17 | # This is for the CUB dataset
18 | # Note that we assume the CUB images have been cropped using the provided bounding boxes.
19 | # The concept labels are based on the attribute values; they are kept for further use and are not used in this work.
20 | def identity(x): return x  # fallback transform when im_size < 0 (same helper as in mini_imagenet.py)
21 | class CUB(Dataset):
22 |
23 | def __init__(self, setname, args, augment=False):
24 | im_size = args.orig_imsize
25 | txt_path = osp.join(SPLIT_PATH, setname + '.csv')
26 | lines = [x.strip() for x in open(txt_path, 'r').readlines()][1:]
27 | cache_path = osp.join( CACHE_PATH, "{}.{}.{}.pt".format(self.__class__.__name__, setname, im_size) )
28 |
29 | self.use_im_cache = ( im_size != -1 )  # cache resized images unless im_size == -1
30 | if self.use_im_cache:
31 | if not osp.exists(cache_path):
32 | print('* Cache miss... Preprocessing {}...'.format(setname))
33 | resize_ = identity if im_size < 0 else transforms.Resize(im_size)
34 | data, label = self.parse_csv(txt_path)
35 | self.data = [ resize_(Image.open(path).convert('RGB')) for path in data ]
36 | self.label = label
37 | print('* Dump cache from {}'.format(cache_path))
38 | torch.save({'data': self.data, 'label': self.label }, cache_path)
39 | else:
40 | print('* Load cache from {}'.format(cache_path))
41 | cache = torch.load(cache_path)
42 | self.data = cache['data']
43 | self.label = cache['label']
44 | else:
45 | self.data, self.label = self.parse_csv(txt_path)
46 |
47 | self.num_class = np.unique(np.array(self.label)).shape[0]
48 | image_size = 84
49 |
50 | if augment and setname == 'train':
51 | transforms_list = [
52 | transforms.RandomResizedCrop(image_size),
53 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
54 | transforms.RandomHorizontalFlip(),
55 | transforms.ToTensor(),
56 | ]
57 | else:
58 | transforms_list = [
59 | transforms.Resize(92),
60 | transforms.CenterCrop(image_size),
61 | transforms.ToTensor(),
62 | ]
63 |
64 | # Transformation
65 | if args.backbone_class == 'ConvNet':
66 | self.transform = transforms.Compose(
67 | transforms_list + [
68 | transforms.Normalize(np.array([0.485, 0.456, 0.406]),
69 | np.array([0.229, 0.224, 0.225]))
70 | ])
71 | elif args.backbone_class == 'Res12':
72 | self.transform = transforms.Compose(
73 | transforms_list + [
74 | transforms.Normalize(np.array([x / 255.0 for x in [120.39586422, 115.59361427, 104.54012653]]),
75 | np.array([x / 255.0 for x in [70.68188272, 68.27635443, 72.54505529]]))
76 | ])
77 | elif args.backbone_class == 'Res18':
78 | self.transform = transforms.Compose(
79 | transforms_list + [
80 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
81 | std=[0.229, 0.224, 0.225])
82 | ])
83 | elif args.backbone_class == 'WRN':
84 | self.transform = transforms.Compose(
85 | transforms_list + [
86 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
87 | std=[0.229, 0.224, 0.225])
88 | ])
89 | else:
90 | raise ValueError('Non-supported Network Types. Please Revise Data Pre-Processing Scripts.')
91 |
92 | def parse_csv(self, txt_path):
93 | data = []
94 | label = []
95 | lb = -1
96 | self.wnids = []
97 | lines = [x.strip() for x in open(txt_path, 'r').readlines()][1:]
98 |
99 | for l in lines:
100 | context = l.split(',')
101 | name = context[0]
102 | wnid = context[1]
103 | path = osp.join(IMAGE_PATH, name)
104 | if wnid not in self.wnids:
105 | self.wnids.append(wnid)
106 | lb += 1
107 |
108 | data.append(path)
109 | label.append(lb)
110 |
111 | return data, label
112 |
113 |
114 | def __len__(self):
115 | return len(self.data)
116 |
117 | def __getitem__(self, i):
118 | data, label = self.data[i], self.label[i]
119 | if self.use_im_cache:
120 | image = self.transform(data)
121 | else:
122 | image = self.transform(Image.open(data).convert('RGB'))
123 | return image, label
--------------------------------------------------------------------------------
/model/dataloader/tiered_imagenet.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import os.path as osp
5 | import numpy as np
6 | import pickle
7 | import sys
8 | import torch
9 | import torch.utils.data as data
10 | import torchvision.transforms as transforms
11 | from PIL import Image
12 |
13 | # Set the appropriate paths of the datasets here.
14 | THIS_PATH = osp.dirname(__file__)
15 | ROOT_PATH1 = osp.abspath(osp.join(THIS_PATH, '..', '..', '..'))
16 | ROOT_PATH2 = osp.abspath(osp.join(THIS_PATH, '..', '..'))
17 | IMAGE_PATH = osp.join(ROOT_PATH2, 'data/tieredimagenet/')
18 | SPLIT_PATH = osp.join(ROOT_PATH2, 'data/miniimagenet/split')
19 |
20 | from .transforms import *
21 | import PIL
22 |
23 |
24 | def buildLabelIndex(labels):
25 | label2inds = {}
26 | for idx, label in enumerate(labels):
27 | if label not in label2inds:
28 | label2inds[label] = []
29 | label2inds[label].append(idx)
30 |
31 | return label2inds
32 |
33 |
34 | def load_data(file):
35 | try:
36 | with open(file, 'rb') as fo:
37 | data = pickle.load(fo)
38 | return data
39 | except:
40 | with open(file, 'rb') as f:
41 | u = pickle._Unpickler(f)
42 | u.encoding = 'latin1'
43 | data = u.load()
44 | return data
45 |
46 | file_path = {'train':[os.path.join(IMAGE_PATH, 'train_images.npz'), os.path.join(IMAGE_PATH, 'train_labels.pkl')],
47 | 'val':[os.path.join(IMAGE_PATH, 'val_images.npz'), os.path.join(IMAGE_PATH,'val_labels.pkl')],
48 | 'test':[os.path.join(IMAGE_PATH, 'test_images.npz'), os.path.join(IMAGE_PATH, 'test_labels.pkl')]}
49 |
50 | class tieredImageNet(data.Dataset):
51 | def __init__(self, setname, args, augment=False):
52 | assert(setname=='train' or setname=='val' or setname=='test')
53 | image_path = file_path[setname][0]
54 | label_path = file_path[setname][1]
55 |
56 | data_train = load_data(label_path)
57 | labels = data_train['labels']
58 | self.data = np.load(image_path)['images']
59 | label = []
60 | lb = -1
61 | self.wnids = []
62 | for wnid in labels:
63 | if wnid not in self.wnids:
64 | self.wnids.append(wnid)
65 | lb += 1
66 | label.append(lb)
67 |
68 | self.label = label
69 | self.num_class = len(set(label))
70 |
71 | if augment and setname == 'train':
72 | transforms_list = [
73 | transforms.RandomCrop(84, padding=8),
74 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
75 | transforms.RandomHorizontalFlip(),
76 | transforms.ToTensor(),
77 | ]
78 | elif args.backbone_class == 'Res10':
79 | transforms_list = [
80 |                 transforms.RandomResizedCrop(size=(224, 224), scale=(0.08, 1.0), ratio=(0.75, 1.3333),
81 |                                              interpolation=PIL.Image.BILINEAR),
82 | transforms.RandomHorizontalFlip(p=0.5),
83 | transforms.ToTensor()
84 | ]
85 |
86 |
87 | else:
88 | transforms_list = [
89 | transforms.ToTensor(),
90 | ]
91 |
92 | # Transformation
93 | if args.backbone_class == 'ConvNet':
94 | self.transform = transforms.Compose(
95 | transforms_list + [
96 | transforms.Normalize(np.array([0.485, 0.456, 0.406]),
97 | np.array([0.229, 0.224, 0.225]))
98 | ])
99 | elif args.backbone_class == 'ResNet':
100 | self.transform = transforms.Compose(
101 | transforms_list + [
102 | transforms.Normalize(np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]),
103 | np.array([x / 255.0 for x in [63.0, 62.1, 66.7]]))
104 | ])
105 | elif args.backbone_class == 'Res12' or args.backbone_class == 'Res10' :
106 | self.transform = transforms.Compose(
107 | transforms_list + [
108 | transforms.Normalize(np.array([x / 255.0 for x in [120.39586422, 115.59361427, 104.54012653]]),
109 | np.array([x / 255.0 for x in [70.68188272, 68.27635443, 72.54505529]]))
110 | ])
111 | elif args.backbone_class == 'Res18':
112 | self.transform = transforms.Compose(
113 | transforms_list + [
114 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
115 | std=[0.229, 0.224, 0.225])
116 | ])
117 | elif args.backbone_class == 'WRN':
118 | self.transform = transforms.Compose(
119 | transforms_list + [
120 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
121 | std=[0.229, 0.224, 0.225])
122 | ])
123 | else:
124 | raise ValueError('Non-supported Network Types. Please Revise Data Pre-Processing Scripts.')
125 |
126 |
127 | def __getitem__(self, index):
128 | img, label = self.data[index], self.label[index]
129 | img = self.transform(Image.fromarray(img))
130 | return img, label
131 |
132 | def __len__(self):
133 | return len(self.data)
134 |
--------------------------------------------------------------------------------
/model/models/fcanet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | def get_freq_indices(method):
7 | assert method in ['top1', 'top2', 'top4', 'top8', 'top16', 'top32',
8 | 'bot1', 'bot2', 'bot4', 'bot8', 'bot16', 'bot32',
9 | 'low1', 'low2', 'low4', 'low8', 'low16', 'low32']
10 | num_freq = int(method[3:])
11 | if 'top' in method:
12 | all_top_indices_x = [0, 0, 6, 0, 0, 1, 1, 4, 5, 1, 3, 0, 0, 0, 3, 2, 4, 6, 3, 5, 5, 2, 6, 5, 5, 3, 3, 4, 2, 2,
13 | 6, 1]
14 | all_top_indices_y = [0, 1, 0, 5, 2, 0, 2, 0, 0, 6, 0, 4, 6, 3, 5, 2, 6, 3, 3, 3, 5, 1, 1, 2, 4, 2, 1, 1, 3, 0,
15 | 5, 3]
16 | mapper_x = all_top_indices_x[:num_freq]
17 | mapper_y = all_top_indices_y[:num_freq]
18 | elif 'low' in method:
19 | all_low_indices_x = [0, 0, 1, 1, 0, 2, 2, 1, 2, 0, 3, 4, 0, 1, 3, 1, 5, 0, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6, 1, 2,
20 | 3, 4]
21 | all_low_indices_y = [0, 1, 0, 1, 2, 0, 1, 2, 2, 3, 0, 0, 4, 3, 1, 4, 0, 5, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5,
22 | 4, 3]
23 | mapper_x = all_low_indices_x[:num_freq]
24 | mapper_y = all_low_indices_y[:num_freq]
25 | elif 'bot' in method:
26 | all_bot_indices_x = [6, 1, 3, 3, 2, 4, 1, 2, 4, 4, 5, 1, 4, 6, 2, 5, 6, 1, 6, 2, 2, 4, 3, 3, 5, 5, 6, 2, 5, 5,
27 | 3, 6]
28 | all_bot_indices_y = [6, 4, 4, 6, 6, 3, 1, 4, 4, 5, 6, 5, 2, 2, 5, 1, 4, 3, 5, 0, 3, 1, 1, 2, 4, 2, 1, 1, 5, 3,
29 | 3, 3]
30 | mapper_x = all_bot_indices_x[:num_freq]
31 | mapper_y = all_bot_indices_y[:num_freq]
32 | else:
33 | raise NotImplementedError
34 | return mapper_x, mapper_y
35 |
36 |
37 | class MultiSpectralAttentionLayer(torch.nn.Module):
38 | def __init__(self, channel, dct_h, dct_w, sigma, k, freq_sel_method='top16'):
39 | super(MultiSpectralAttentionLayer, self).__init__()
40 | self.sigma = sigma
41 | self.k = k
42 | self.dct_h = dct_h
43 | self.dct_w = dct_w
44 |
45 | mapper_x, mapper_y = get_freq_indices(freq_sel_method)
46 | self.num_split = len(mapper_x)
47 | mapper_x = [temp_x * (dct_h // 5) for temp_x in mapper_x]
48 | mapper_y = [temp_y * (dct_w // 5) for temp_y in mapper_y]
49 |         # map frequencies selected at different feature-map sizes onto an identical 5x5 frequency space,
50 |         # e.g. (2, 2) in a 10x10 map is identical to (1, 1) in a 5x5 map
51 |
52 | self.dct_layer = MultiSpectralDCTLayer(dct_h, dct_w, mapper_x, mapper_y, channel)
53 | self.fc = nn.Sequential(
54 | nn.Linear(channel, int(channel*self.sigma), bias=False),
55 | nn.ReLU(inplace=True),
56 | nn.Linear(int(channel*self.sigma), channel*self.k**2, bias=False),
57 | nn.Sigmoid()
58 | )
59 |
60 | def forward(self, x):
61 | n, c, h, w = x.shape
62 | x_pooled = x
63 | if h != self.dct_h or w != self.dct_w:
64 | x_pooled = torch.nn.functional.adaptive_avg_pool2d(x, (self.dct_h, self.dct_w))
65 | # If you have concerns about one-line-change, don't worry. :)
66 | # In the ImageNet models, this line will never be triggered.
67 | # This is for compatibility in instance segmentation and object detection.
68 | y = self.dct_layer(x_pooled)
69 |
70 | y = self.fc(y).view(n, c, self.k, self.k)
71 | # return x * y.expand_as(x)
72 | return y
73 |
74 | class MultiSpectralDCTLayer(nn.Module):
75 | """
76 | Generate dct filters
77 | """
78 |
79 | def __init__(self, height, width, mapper_x, mapper_y, channel):
80 | super(MultiSpectralDCTLayer, self).__init__()
81 |
82 | assert len(mapper_x) == len(mapper_y)
83 | assert channel % len(mapper_x) == 0
84 |
85 | self.num_freq = len(mapper_x)
86 |
87 | # fixed DCT init
88 | self.register_buffer('weight', self.get_dct_filter(height, width, mapper_x, mapper_y, channel))
89 |
90 | # fixed random init
91 | # self.register_buffer('weight', torch.rand(channel, height, width))
92 |
93 | # learnable DCT init
94 | # self.register_parameter('weight', self.get_dct_filter(height, width, mapper_x, mapper_y, channel))
95 |
96 | # learnable random init
97 | # self.register_parameter('weight', torch.rand(channel, height, width))
98 |
99 | # num_freq, h, w
100 |
101 | def forward(self, x):
102 |         assert len(x.shape) == 4, 'x must be 4-dimensional, but got ' + str(len(x.shape))
103 | # n, c, h, w = x.shape
104 |
105 | x = x * self.weight
106 |
107 | result = torch.sum(x, dim=[2, 3])
108 | return result
109 |
110 | def build_filter(self, pos, freq, POS):
111 | result = math.cos(math.pi * freq * (pos + 0.5) / POS) / math.sqrt(POS)
112 | if freq == 0:
113 | return result
114 | else:
115 | return result * math.sqrt(2)
116 |
117 | def get_dct_filter(self, tile_size_x, tile_size_y, mapper_x, mapper_y, channel):
118 | dct_filter = torch.zeros(channel, tile_size_x, tile_size_y)
119 |
120 | c_part = channel // len(mapper_x)
121 |
122 | for i, (u_x, v_y) in enumerate(zip(mapper_x, mapper_y)):
123 | for t_x in range(tile_size_x):
124 | for t_y in range(tile_size_y):
125 |                     dct_filter[i * c_part: (i + 1) * c_part, t_x, t_y] = \
126 |                         self.build_filter(t_x, u_x, tile_size_x) * \
127 |                         self.build_filter(t_y, v_y, tile_size_y)
128 |
129 | return dct_filter
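130 |
131 |
132 | # --- Usage sketch (editorial addition, not part of the original file) ---
133 | # With the 640-channel 5x5 feature maps used elsewhere in this repo (and the same
134 | # sigma=0.2, k=3 that INSTA passes in), the attention layer maps each input to a
135 | # per-channel k x k kernel in [0, 1], i.e. a tensor of shape (n, c, k, k).
136 | if __name__ == '__main__':
137 |     att = MultiSpectralAttentionLayer(640, dct_h=5, dct_w=5, sigma=0.2, k=3, freq_sel_method='low16')
138 |     kernels = att(torch.randn(2, 640, 5, 5))
139 |     print(kernels.shape)  # expected: torch.Size([2, 640, 3, 3])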
--------------------------------------------------------------------------------
/model/models/INSTA_ProtoNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import torch.nn.functional as F
5 |
6 | from model.models import FewShotModel_1
7 | from model.models.INSTA import INSTA
8 |
9 | """
10 | The INSTA_ProtoNet class combines INSTA-based attention mechanisms with the prototypical networks approach.
11 | This hybrid model is designed for few-shot learning tasks where it's important to quickly adapt to new classes
12 | with very few examples per class.
13 | """
14 |
15 | class INSTA_ProtoNet(FewShotModel_1):
16 | def __init__(self, args):
17 | """
18 | Initializes the INSTA_ProtoNet with the given arguments.
19 |
20 | Parameters:
21 | - args: Configuration settings including hyperparameters for the network setup.
22 | """
23 | super().__init__(args)
24 | self.args = args
25 | # Instantiate the INSTA model with specific parameters.
26 | self.INSTA = INSTA(640, 5, 0.2, 3, args=args)
27 |
28 | def inner_loop(self, proto, support):
29 | """
30 | Performs an inner optimization loop to fine-tune prototypes on support sets during meta-training.
31 |
32 | Parameters:
33 | - proto: Initial prototypes, typically the mean of the support embeddings.
34 | - support: Support set embeddings used for fine-tuning the prototypes.
35 |
36 | Returns:
37 | - SFC: Updated (fine-tuned) prototypes.
38 | """
39 | # Clone and detach prototypes to prevent gradients from accumulating across episodes.
40 | SFC = proto.clone().detach()
41 | SFC = nn.Parameter(SFC, requires_grad=True)
42 |
43 | # Initialize an SGD optimizer specifically for this inner loop.
44 | optimizer = torch.optim.SGD([SFC], lr=0.6, momentum=0.9, dampening=0.9, weight_decay=0)
45 |
46 | # Create labels for the support set, used in cross-entropy loss during fine-tuning.
47 | label_shot = torch.arange(self.args.way).repeat(self.args.shot)
48 | label_shot = label_shot.type(torch.cuda.LongTensor)
49 |
50 | # Perform gradient steps to update the prototypes.
51 | with torch.enable_grad():
52 | for k in range(50): # Number of gradient steps.
53 | rand_id = torch.randperm(self.args.way * self.args.shot).cuda()
54 | for j in range(0, self.args.way * self.args.shot, 4):
55 | selected_id = rand_id[j: min(j + 4, self.args.way * self.args.shot)]
56 | batch_shot = support[selected_id, :]
57 | batch_label = label_shot[selected_id]
58 | optimizer.zero_grad()
59 | logits = self.classifier(batch_shot.detach(), SFC)
60 | if logits.dim() == 1:
61 | logits = logits.unsqueeze(0)
62 | loss = F.cross_entropy(logits, batch_label)
63 | loss.backward()
64 | optimizer.step()
65 | return SFC
66 |
67 | def classifier(self, query, proto):
68 | """
69 | Simple classifier that computes the negative squared Euclidean distance between query and prototype vectors,
70 | scaled by a temperature parameter for controlling the sharpness of the distribution.
71 |
72 | Parameters:
73 | - query: Query set embeddings.
74 | - proto: Prototype vectors.
75 |
76 | Returns:
77 | - logits: Logits representing similarity scores between each query and each prototype.
78 | """
79 | logits = -torch.sum((proto.unsqueeze(0) - query.unsqueeze(1)) ** 2, 2) / self.args.temperature
80 | return logits.squeeze()
81 |
82 | def _forward(self, instance_embs, support_idx, query_idx):
83 | """
84 | Forward pass of the model, processing both support and query data.
85 |
86 | Parameters:
87 | - instance_embs: Embeddings of all instances.
88 | - support_idx: Indices identifying support instances.
89 | - query_idx: Indices identifying query instances.
90 |
91 | Implements the forward pass, integrating both spatial and feature adaptation using the INSTA module.
92 | """
93 | emb_dim = instance_embs.size()[-3:]
94 | channel_dim = emb_dim[0]
95 |
96 | # Organize support and query data based on indices, and reshape accordingly.
97 | support = instance_embs[support_idx.flatten()].view(*(support_idx.shape + emb_dim))
98 | query = instance_embs[query_idx.flatten()].view(*(query_idx.shape + emb_dim))
99 | num_samples = support.shape[1]
100 | num_proto = support.shape[2]
101 | support = support.squeeze()
102 |
103 | # Adapt support features using the INSTA model and average to form adapted prototypes.
104 | adapted_s, task_kernel = self.INSTA(support.view(-1, *emb_dim))
105 | query = query.view(-1, *emb_dim)
106 | adapted_proto = adapted_s.view(num_samples, -1, *adapted_s.shape[1:]).mean(0)
107 | adapted_proto = nn.AdaptiveAvgPool2d(1)(adapted_proto).squeeze(-1).squeeze(-1)
108 |
109 | # Adapt query features using the INSTA unfolding and kernel multiplication approach.
110 | query_ = nn.AdaptiveAvgPool2d(1)((self.INSTA.unfold(query, int((task_kernel.shape[-1]+1)/2-1), task_kernel.shape[-1]) * task_kernel)).squeeze()
111 | query = query + query_
112 | adapted_q = nn.AdaptiveAvgPool2d(1)(query).squeeze(-1).squeeze(-1)
113 |
114 | # Optionally perform an inner loop optimization during testing.
115 | if self.args.testing:
116 | adapted_proto = self.inner_loop(adapted_proto, nn.AdaptiveAvgPool2d(1)(support).squeeze().view(num_proto*num_samples, channel_dim))
117 |
118 | # Classify using the adapted prototypes and query embeddings.
119 | logits = self.classifier(adapted_q, adapted_proto)
120 |
121 | if self.training:
122 | reg_logits = None
123 | return logits, reg_logits
124 | else:
125 | return logits
126 |
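127 | # --- Worked example (editorial addition, not part of the original file) ---
128 | # The `classifier` above scores a query by the negative squared Euclidean distance to
129 | # each prototype, divided by args.temperature; the largest (least negative) logit wins.
130 | # A tiny standalone check with made-up 2-D embeddings and temperature 1:
131 | #
132 | #   proto = torch.tensor([[0., 0.], [3., 4.]])   # two prototypes
133 | #   query = torch.tensor([[0., 1.]])             # one query embedding
134 | #   logits = -torch.sum((proto.unsqueeze(0) - query.unsqueeze(1)) ** 2, 2)
135 | #   # logits == tensor([[-1., -18.]])  -> the query is assigned to prototype 0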
--------------------------------------------------------------------------------
/model/trainer/helpers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import torch.optim as optim
5 | from torch.utils.data import DataLoader
6 | from model.dataloader.samplers import CategoriesSampler
7 | from model.models.protonet import ProtoNet
8 | from model.models.INSTA_ProtoNet import INSTA_ProtoNet
9 |
10 | class MultiGPUDataloader:
11 | def __init__(self, dataloader, num_device):
12 | self.dataloader = dataloader
13 | self.num_device = num_device
14 |
15 | def __len__(self):
16 | return len(self.dataloader) // self.num_device
17 |
18 | def __iter__(self):
19 | data_iter = iter(self.dataloader)
20 | done = False
21 |
22 | while not done:
23 | try:
24 | output_batch = ([], [])
25 | for _ in range(self.num_device):
26 | batch = next(data_iter)
27 | for i, v in enumerate(batch):
28 | output_batch[i].append(v[None])
29 |
30 | yield ( torch.cat(_, dim=0) for _ in output_batch )
31 | except StopIteration:
32 | done = True
33 | return
34 |
35 | def get_dataloader(args):
36 |
37 | if args.dataset == 'CUB':
38 | from model.dataloader.cub import CUB as Dataset
39 | elif args.dataset == 'TieredImageNet':
40 | from model.dataloader.tiered_imagenet import tieredImageNet as Dataset
41 | elif args.dataset == 'MiniImageNet':
42 | from model.dataloader.mini_imagenet import MiniImageNet as Dataset
43 | else:
44 | raise ValueError('Non-supported Dataset.')
45 |
46 | num_device = torch.cuda.device_count()
47 | num_episodes = args.episodes_per_epoch*num_device if args.multi_gpu else args.episodes_per_epoch
48 | num_workers=args.num_workers*num_device if args.multi_gpu else args.num_workers
49 | trainset = Dataset('train', args, augment=args.augment)
50 | args.num_class = trainset.num_class
51 | train_sampler = CategoriesSampler(trainset.label,
52 | num_episodes,
53 | max(args.way, args.num_classes),
54 | args.shot + args.query)
55 |
56 | train_loader = DataLoader(dataset=trainset,
57 | num_workers=num_workers,
58 | batch_sampler=train_sampler,
59 | pin_memory=True)
60 |
61 |
62 | valset = Dataset('val', args)
63 | val_sampler = CategoriesSampler(valset.label,
64 | args.num_eval_episodes,
65 | args.eval_way, args.eval_shot + args.eval_query)
66 | val_loader = DataLoader(dataset=valset,
67 | batch_sampler=val_sampler,
68 | num_workers=args.num_workers,
69 | pin_memory=True)
70 |
71 |
72 | testset = Dataset('test', args)
73 | test_sampler = CategoriesSampler(testset.label,
74 | 600, # args.num_eval_episodes,
75 | args.eval_way, args.eval_shot + args.eval_query)
76 | test_loader = DataLoader(dataset=testset,
77 | batch_sampler=test_sampler,
78 | num_workers=args.num_workers,
79 | pin_memory=True)
80 |
81 | return train_loader, val_loader, test_loader
82 |
83 | def prepare_model(args):
84 | model = eval(args.model_class)(args)
85 |
86 | # load pre-trained model (no FC weights)
87 | if args.init_weights is not None:
88 | model_dict = model.state_dict()
89 | pretrained_dict = torch.load(args.init_weights)['params']
90 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
91 | print(pretrained_dict.keys())
92 | model_dict.update(pretrained_dict)
93 | model.load_state_dict(model_dict)
94 |
95 | if torch.cuda.is_available():
96 | torch.backends.cudnn.benchmark = True
97 |
98 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
99 | model = model.to(device)
100 | if args.multi_gpu:
101 | model.encoder = nn.DataParallel(model.encoder, dim=0)
102 | para_model = model.to(device)
103 | else:
104 | para_model = model.to(device)
105 |
106 | return model, para_model
107 |
108 | def prepare_optimizer(model, args):
109 | top_para = [v for k, v in model.named_parameters() if 'encoder' not in k]
110 | if args.use_AdamW:
111 | optimizer = optim.AdamW(
112 | [{'params': model.encoder.parameters()},
113 | {'params': top_para, 'lr': args.lr * args.lr_mul}],
114 | lr=args.lr
115 |
116 | )
117 |
118 |
119 | else:
120 | optimizer = optim.SGD(
121 | [{'params': model.encoder.parameters()},
122 | {'params': top_para, 'lr': args.lr * args.lr_mul}],
123 | lr=args.lr,
124 | momentum=args.mom,
125 | nesterov=True,
126 | weight_decay=args.weight_decay
127 | )
128 |
129 |
130 |
131 | if args.lr_scheduler == 'step':
132 | lr_scheduler = optim.lr_scheduler.StepLR(
133 | optimizer,
134 | step_size=int(args.step_size),
135 | gamma=args.gamma
136 | )
137 | elif args.lr_scheduler == 'multistep':
138 | lr_scheduler = optim.lr_scheduler.MultiStepLR(
139 | optimizer,
140 | milestones=[int(_) for _ in args.step_size.split(',')],
141 | gamma=args.gamma,
142 | )
143 | elif args.lr_scheduler == 'cosine':
144 | lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
145 | optimizer,
146 | args.max_epoch,
147 | eta_min=0
148 | )
149 | else:
150 | raise ValueError('No Such Scheduler')
151 |
152 | return optimizer, lr_scheduler
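153 |
154 |
155 | # --- Usage sketch (editorial addition, not part of the original file) ---
156 | # `prepare_optimizer` builds two parameter groups: the encoder trains at the base rate
157 | # args.lr, while every non-encoder parameter (the INSTA / classifier head) trains at
158 | # args.lr * args.lr_mul. A hypothetical argparse.Namespace is enough to exercise it:
159 | #
160 | #   args = argparse.Namespace(lr=2e-4, lr_mul=25, mom=0.9, weight_decay=5e-4,
161 | #                             use_AdamW=False, lr_scheduler='cosine', max_epoch=200)
162 | #   optimizer, scheduler = prepare_optimizer(model, args)
163 | #   print([g['lr'] for g in optimizer.param_groups])  # [0.0002, 0.005]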
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # INSTA: Learning Instance and Task-Aware Dynamic Kernels for Few Shot Learning
2 |
3 |
4 |
5 |
6 |
7 | If this repository is helpful to you, please cite the following bib:
8 | ```bibtex
9 | @article{ma2021learning,
10 | title={Learning Instance and Task-Aware Dynamic Kernels for Few Shot Learning},
11 | author={Ma, Rongkai and Fang, Pengfei and Avraham, Gil and Zuo, Yan and Drummond, Tom and Harandi, Mehrtash},
12 | journal={arXiv preprint arXiv:2112.03494},
13 | year={2021}
14 | }
15 | ```
16 | This repository provides the implementation and demo of [**Learning Instance and Task-Aware Dynamic Kernels for Few Shot Learning**](https://arxiv.org/abs/2112.03494) on [Prototypical Network](https://arxiv.org/pdf/1703.05175.pdf). The dynamic environment of few-shot learning (FSL) requires a model capable of rapidly adapting to novel tasks. Moreover, given the low-data regime of FSL, the model must encode rich information for each data sample. To tackle this problem, we propose to learn a dynamic kernel that is both **ins**tance and **t**ask-**a**ware, **INSTA**, for each channel and spatial location of a feature map, given the task (episode) at hand. Beyond that, we further incorporate information from the frequency domain to generate our dynamic kernel.
17 |
18 |
19 |
20 |
21 | ## Prerequisites
22 | We use Anaconda to manage the virtual environment. Please install the following packages to run this repository. If a "No module" error occurs, please install the missing package indicated by the error message.
23 | * python 3.8
24 | * [pytorch 1.7.0](https://pytorch.org/get-started/previous-versions/)
25 | * torchvision 0.8.0
26 | * torchaudio 0.7.0
27 | * tqdm
28 | * tensorboardX
29 |
30 | ## Dataset
31 |
32 | ### Tiered-ImageNet
33 |
34 | Tiered-ImageNet is also a subset of ImageNet. This dataset consists of 608 classes from 34 categories and is split into 351 classes from 20 categories for training, 97 classes from 6 categories for validation, and 160 classes from 8 categories for testing. You can download the processed dataset from this [repository](https://github.com/icoz69/DeepEMD). Once the dataset is downloaded, please move it to the /data directory. Note that the images have been resized to 84x84.
35 |
36 | ### Mini-ImageNet
37 | ```Shell
38 | ├── data
39 | ├── Mini-ImageNet
40 | ├── split
41 | ├── train
42 | ├── validation
43 | ├── test
44 | ├── images
45 | ├── im_0.jpg
46 | ├── im_1.jpg
47 | .
48 | .
49 | .
50 | ├── im_n.jpg
51 | ```
52 |
53 | Mini-ImageNet is sampled from ImageNet. This dataset has 100 classes, each with 600 samples. We follow the standard protocol to split the dataset into 64 training, 16 validation, and 20 testing classes. To download the corresponding split and data files, please refer to [this repository](https://github.com/Sha-Lab/FEAT).
54 |
55 | ### CUB
56 |
57 | CUB is a fine-grained dataset consisting of 11,788 images from 200 different breeds of birds. We follow the standard settings, in which the dataset is split into 100/50/50 breeds for training, validation, and testing, respectively. For the ResNet-12 backbone, please refer to [this repository](https://github.com/icoz69/DeepEMD) to split the dataset, and for the ResNet-18 backbone, please refer to [this repository](https://github.com/imtiazziko/LaplacianShot).
58 |
59 | ### FC100
60 |
61 | FC100 is a variant of the standard CIFAR-100 dataset, which contains images from 100 classes with 600 samples per class. We follow the standard setting, where the dataset is split into 60/20/20 classes for training, validation, and testing, respectively. For downloading and splitting the data, please refer to the [DeepEMD repository](https://github.com/icoz69/DeepEMD).
62 |
63 | ## Training
64 |
65 | We provide example command lines for Tiered-ImageNet and Mini-ImageNet below:
66 | ```shell
67 | $ python train_fsl.py --max_epoch 200 --model_class INSTA_ProtoNet --backbone_class Res12 --dataset TieredImageNet --way 5 --eval_way 5 --shot 5 --eval_shot 5 --query 15 --eval_query 15 --temperature 32 --temperature2 64 --lr 0.0002 --lr_mul 100 --lr_scheduler cosine --gamma 0.5 --gpu 1 --init_weights ./saves/initialization/tieredimagenet/Res12-pre.pth --eval_interval 1 --use_euclidean
68 | ```
69 | ```shell
70 | $ python train_fsl.py --max_epoch 200 --model_class INSTA_ProtoNet --backbone_class Res12 --dataset TieredImageNet --way 5 --eval_way 5 --shot 1 --eval_shot 1 --query 15 --eval_query 15 --temperature 64 --temperature2 64 --lr 0.0002 --lr_mul 30 --lr_scheduler cosine --gamma 0.5 --gpu 0 --init_weights ./saves/initialization/tieredimagenet/Res12-pre.pth --eval_interval 1 --use_euclidean
71 | ```
72 | ```shell
73 | $ python train_fsl.py --max_epoch 200 --model_class INSTA_ProtoNet --backbone_class Res12 --dataset MiniImageNet --way 5 --eval_way 5 --shot 1 --eval_shot 1 --query 15 --eval_query 15 --temperature 64 --temperature2 64 --lr 0.0002 --lr_mul 25 --lr_scheduler cosine --gamma 0.5 --gpu 0 --init_weights ./saves/initialization/miniimagenet/Res12-pre.pth --eval_interval 1 --use_euclidean
74 | ```
75 | ```shell
76 | $ python train_fsl.py --max_epoch 200 --model_class INSTA_ProtoNet --backbone_class Res12 --dataset MiniImageNet --way 5 --eval_way 5 --shot 5 --eval_shot 5 --query 15 --eval_query 15 --balance_1 1 --temperature 24 --temperature2 32 --lr 0.0002 --lr_mul 25 --lr_scheduler cosine --gamma 0.5 --gpu 0 --init_weights ./saves/initialization/miniimagenet/Res12-pre.pth --eval_interval 1 --use_euclidean
77 | ```
78 | ## To Do
79 | *
80 | *
81 |
82 | ## Acknowledgements
83 | We acknowledge the following repositories for providing valuable insights for our code construction:
84 |
85 | * [FEAT](https://github.com/Sha-Lab/FEAT)
86 | * [DeepEMD](https://github.com/icoz69/DeepEMD)
87 | * [Chen *et al.*](https://github.com/wyharveychen/CloserLookFewShot)
88 | * [DeepBDC](https://github.com/Fei-Long121/DeepBDC)
89 | * [DDF](https://github.com/theFoxofSky/ddfnet)
90 | * [FCANet](https://github.com/cfzd/FcaNet)
91 | * [Fan *et al.*](https://github.com/fanq15/FSOD-code)
92 | * [simple-cnaps](https://github.com/peymanbateni/simple-cnaps)
93 |
--------------------------------------------------------------------------------
/model/networks/res18.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | __all__ = ['resnet10', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
4 | 'resnet152']
5 |
6 |
7 | def conv3x3(in_planes, out_planes, stride=1):
8 | """3x3 convolution with padding"""
9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
10 | padding=1, bias=False)
11 |
12 |
13 | def conv1x1(in_planes, out_planes, stride=1):
14 | """1x1 convolution"""
15 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
16 |
17 |
18 | class BasicBlock(nn.Module):
19 | expansion = 1
20 |
21 | def __init__(self, inplanes, planes, stride=1, downsample=None):
22 | super(BasicBlock, self).__init__()
23 | self.conv1 = conv3x3(inplanes, planes, stride)
24 | self.bn1 = nn.BatchNorm2d(planes)
25 | self.relu = nn.ReLU(inplace=True)
26 | self.conv2 = conv3x3(planes, planes)
27 | self.bn2 = nn.BatchNorm2d(planes)
28 | self.downsample = downsample
29 | self.stride = stride
30 |
31 | def forward(self, x):
32 | identity = x
33 |
34 | out = self.conv1(x)
35 | out = self.bn1(out)
36 | out = self.relu(out)
37 |
38 | out = self.conv2(out)
39 | out = self.bn2(out)
40 |
41 | if self.downsample is not None:
42 | identity = self.downsample(x)
43 |
44 | out += identity
45 | out = self.relu(out)
46 |
47 | return out
48 |
49 |
50 | class Bottleneck(nn.Module):
51 | expansion = 4
52 |
53 | def __init__(self, inplanes, planes, stride=1, downsample=None):
54 | super(Bottleneck, self).__init__()
55 | self.conv1 = conv1x1(inplanes, planes)
56 | self.bn1 = nn.BatchNorm2d(planes)
57 | self.conv2 = conv3x3(planes, planes, stride)
58 | self.bn2 = nn.BatchNorm2d(planes)
59 | self.conv3 = conv1x1(planes, planes * self.expansion)
60 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
61 | self.relu = nn.ReLU(inplace=True)
62 | self.downsample = downsample
63 | self.stride = stride
64 |
65 | def forward(self, x):
66 | identity = x
67 |
68 | out = self.conv1(x)
69 | out = self.bn1(out)
70 | out = self.relu(out)
71 |
72 | out = self.conv2(out)
73 | out = self.bn2(out)
74 | out = self.relu(out)
75 |
76 | out = self.conv3(out)
77 | out = self.bn3(out)
78 |
79 | if self.downsample is not None:
80 | identity = self.downsample(x)
81 |
82 | out += identity
83 | out = self.relu(out)
84 |
85 | return out
86 |
87 |
88 | class ResNet(nn.Module):
89 |
90 | def __init__(self, block=BasicBlock, layers=[2, 2, 2, 2], zero_init_residual=False):
91 | super(ResNet, self).__init__()
92 | self.inplanes = 64
93 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1,
94 | bias=False)
95 | self.bn1 = nn.BatchNorm2d(64)
96 | self.relu = nn.ReLU(inplace=True)
97 | self.layer1 = self._make_layer(block, 64, layers[0])
98 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
99 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
100 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
101 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
102 |
103 | for m in self.modules():
104 | if isinstance(m, nn.Conv2d):
105 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
106 | elif isinstance(m, nn.BatchNorm2d):
107 | nn.init.constant_(m.weight, 1)
108 | nn.init.constant_(m.bias, 0)
109 |
110 | # Zero-initialize the last BN in each residual branch,
111 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
112 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
113 | if zero_init_residual:
114 | for m in self.modules():
115 | if isinstance(m, Bottleneck):
116 | nn.init.constant_(m.bn3.weight, 0)
117 | elif isinstance(m, BasicBlock):
118 | nn.init.constant_(m.bn2.weight, 0)
119 |
120 | def _make_layer(self, block, planes, blocks, stride=1):
121 | downsample = None
122 | if stride != 1 or self.inplanes != planes * block.expansion:
123 | downsample = nn.Sequential(
124 | conv1x1(self.inplanes, planes * block.expansion, stride),
125 | nn.BatchNorm2d(planes * block.expansion),
126 | )
127 |
128 | layers = []
129 | layers.append(block(self.inplanes, planes, stride, downsample))
130 | self.inplanes = planes * block.expansion
131 | for _ in range(1, blocks):
132 | layers.append(block(self.inplanes, planes))
133 |
134 | return nn.Sequential(*layers)
135 |
136 | def forward(self, x):
137 | x = self.conv1(x)
138 | x = self.bn1(x)
139 | x = self.relu(x)
140 |
141 | x = self.layer1(x)
142 | x = self.layer2(x)
143 | x = self.layer3(x)
144 | x = self.layer4(x)
145 |
146 | # x = self.avgpool(x)
147 | # x = x.view(x.size(0), -1)
148 |
149 | return x
150 |
151 |
152 | def resnet10(**kwargs):
153 | """Constructs a ResNet-10 model.
154 | """
155 | model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs)
156 | return model
157 |
158 |
159 | def resnet18(**kwargs):
160 | """Constructs a ResNet-18 model.
161 | """
162 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
163 | return model
164 |
165 |
166 | def resnet34(**kwargs):
167 | """Constructs a ResNet-34 model.
168 | """
169 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
170 | return model
171 |
172 |
173 | def resnet50(**kwargs):
174 | """Constructs a ResNet-50 model.
175 | """
176 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
177 | return model
178 |
179 |
180 | def resnet101(**kwargs):
181 | """Constructs a ResNet-101 model.
182 | """
183 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
184 | return model
185 |
186 |
187 | def resnet152(**kwargs):
188 | """Constructs a ResNet-152 model.
189 | """
190 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
191 | return model
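192 |
193 |
194 | # --- Usage sketch (editorial addition, not part of the original file) ---
195 | # conv1 uses stride 1 and there is no initial max-pool, so for the 84x84 crops used in
196 | # this repo resnet18() should emit a 512 x 11 x 11 feature map (84 -> 42 -> 21 -> 11 over
197 | # the three stride-2 stages), matching the (512, 11) entry of c2wh in model/models/INSTA.py.
198 | if __name__ == '__main__':
199 |     import torch  # the module itself only imports torch.nn
200 |     feats = resnet18()(torch.randn(2, 3, 84, 84))
201 |     print(feats.shape)  # expected: torch.Size([2, 512, 11, 11])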
--------------------------------------------------------------------------------
/model/trainer/fsl_trainer.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os.path as osp
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | from model.trainer.base import Trainer
8 | from model.trainer.helpers import (
9 | get_dataloader, prepare_model, prepare_optimizer,
10 | )
11 | from model.utils import (
12 | pprint, ensure_path,
13 | Averager, Timer, count_acc,
14 | compute_confidence_interval
15 | )
16 |
17 | from tqdm import tqdm
18 |
19 |
20 | class FSLTrainer(Trainer):
21 | def __init__(self, args):
22 | super().__init__(args)
23 |
24 | self.train_loader, self.val_loader, self.test_loader = get_dataloader(args)
25 | self.model, self.para_model = prepare_model(args)
26 | self.optimizer, self.lr_scheduler = prepare_optimizer(self.model, args)
27 |
28 | def prepare_label(self):
29 | args = self.args
30 |
31 |         # prepare class labels for the query set and the auxiliary (shot + query) set
32 | label = torch.arange(args.way, dtype=torch.int16).repeat(args.query)
33 | label_aux = torch.arange(args.way, dtype=torch.int8).repeat(args.shot + args.query)
34 |
35 | label = label.type(torch.LongTensor)
36 | label_aux = label_aux.type(torch.LongTensor)
37 |
38 | if torch.cuda.is_available():
39 | label = label.cuda()
40 | label_aux = label_aux.cuda()
41 |
42 | return label, label_aux
43 |
44 | def train(self):
45 | args = self.args
46 | self.model.train()
47 | if self.args.fix_BN:
48 | self.model.encoder.eval()
49 |
50 | # start FSL training
51 | label, label_aux = self.prepare_label()
52 | for epoch in range(1, args.max_epoch + 1):
53 | self.train_epoch += 1
54 | self.model.train()
55 | if self.args.fix_BN:
56 | self.model.encoder.eval()
57 |
58 | tl1 = Averager()
59 | tl2 = Averager()
60 | ta = Averager()
61 |
62 | start_tm = time.time()
63 | for batch in self.train_loader:
64 | self.train_step += 1
65 |
66 | if torch.cuda.is_available():
67 | data, gt_label = [_.cuda() for _ in batch]
68 | else:
69 | data, gt_label = batch[0], batch[1]
70 |
71 | data_tm = time.time()
72 | self.dt.add(data_tm - start_tm)
73 |
74 |                 # forward pass of the episode through the model
75 | logits, reg_logits = self.para_model(data)
76 |
77 | if reg_logits is not None:
78 | loss = F.cross_entropy(logits, label)
79 | total_loss = args.balance_1*loss + args.balance_2 * F.cross_entropy(reg_logits, label_aux)
80 |
81 |
82 | else:
83 | loss = F.cross_entropy(logits, label)
84 | total_loss = F.cross_entropy(logits, label)
85 |
86 |                 tl2.add(loss.item())  # log a detached value so the graph is not retained across iterations
87 | forward_tm = time.time()
88 | self.ft.add(forward_tm - data_tm)
89 | acc = count_acc(logits, label)
90 | tl1.add(total_loss.item())
91 | ta.add(acc)
92 |
93 | self.optimizer.zero_grad()
94 | total_loss.backward()
95 | backward_tm = time.time()
96 | self.bt.add(backward_tm - forward_tm)
97 |
98 | self.optimizer.step()
99 | optimizer_tm = time.time()
100 | self.ot.add(optimizer_tm - backward_tm)
101 |
102 | # refresh start_tm
103 | start_tm = time.time()
104 |
105 | self.lr_scheduler.step()
106 | self.try_evaluate(epoch)
107 |
108 | print('ETA:{}/{}'.format(
109 | self.timer.measure(),
110 | self.timer.measure(self.train_epoch / args.max_epoch))
111 | )
112 |
113 | torch.save(self.trlog, osp.join(args.save_path, 'trlog'))
114 | self.save_model('epoch-last')
115 |
116 | def evaluate(self, data_loader):
117 | # restore model args
118 | args = self.args
119 | # evaluation mode
120 | self.model.eval()
121 | record = np.zeros((args.num_eval_episodes, 2)) # loss and acc
122 | label = torch.arange(args.eval_way, dtype=torch.int16).repeat(args.eval_query)
123 | label = label.type(torch.LongTensor)
124 | if torch.cuda.is_available():
125 | label = label.cuda()
126 | print('best epoch {}, best val acc={:.4f} + {:.4f}'.format(
127 | self.trlog['max_acc_epoch'],
128 | self.trlog['max_acc'],
129 | self.trlog['max_acc_interval']))
130 | with torch.no_grad():
131 | for i, batch in enumerate(data_loader, 1):
132 | if torch.cuda.is_available():
133 | data, _ = [_.cuda() for _ in batch]
134 | else:
135 | data = batch[0]
136 |
137 | logits = self.model(data)
138 | loss = F.cross_entropy(logits, label)
139 | acc = count_acc(logits, label)
140 | record[i-1, 0] = loss.item()
141 | record[i-1, 1] = acc
142 |
143 | assert(i == record.shape[0])
144 | vl, _ = compute_confidence_interval(record[:,0])
145 | va, vap = compute_confidence_interval(record[:,1])
146 |
147 | # train mode
148 | self.model.train()
149 | if self.args.fix_BN:
150 | self.model.encoder.eval()
151 |
152 | return vl, va, vap
153 |
154 |
155 | def evaluate_test(self):
156 | # restore model args
157 | args = self.args
158 | self.args.testing = True
159 | self.model.load_state_dict(torch.load(osp.join(self.args.save_path, 'max_acc.pth'))['params'])
160 | self.model.eval()
161 | record = np.zeros((600, 2)) # loss and acc
162 | label = torch.arange(args.eval_way, dtype=torch.int16).repeat(args.eval_query)
163 | label = label.type(torch.LongTensor)
164 | if torch.cuda.is_available():
165 | label = label.cuda()
166 | print('best epoch {}, best val acc={:.4f} + {:.4f}'.format(
167 | self.trlog['max_acc_epoch'],
168 | self.trlog['max_acc'],
169 | self.trlog['max_acc_interval']))
170 | with torch.no_grad():
171 | for i, batch in tqdm(enumerate(self.test_loader, 1)):
172 | if torch.cuda.is_available():
173 | data, _ = [_.cuda() for _ in batch]
174 | else:
175 | data = batch[0]
176 |
177 | logits = self.model(data)
178 |
179 | loss = F.cross_entropy(logits, label)
180 | acc = count_acc(logits, label)
181 | record[i-1, 0] = loss.item()
182 | record[i-1, 1] = acc
183 | assert(i == record.shape[0])
184 | vl, _ = compute_confidence_interval(record[:,0])
185 | va, vap = compute_confidence_interval(record[:,1])
186 |
187 | self.trlog['test_acc'] = va
188 | self.trlog['test_acc_interval'] = vap
189 | self.trlog['test_loss'] = vl
190 |
191 | print('best epoch {}, best val acc={:.4f} + {:.4f}\n'.format(
192 | self.trlog['max_acc_epoch'],
193 | self.trlog['max_acc'],
194 | self.trlog['max_acc_interval']))
195 | print('Test acc={:.4f} + {:.4f}\n'.format(
196 | self.trlog['test_acc'],
197 | self.trlog['test_acc_interval']))
198 |
199 | return vl, va, vap
200 |
201 | def final_record(self):
202 | # save the best performance in a txt file
203 |
204 | with open(osp.join(self.args.save_path, '{}+{}'.format(self.trlog['test_acc'], self.trlog['test_acc_interval'])), 'w') as f:
205 | f.write('best epoch {}, best val acc={:.4f} + {:.4f}\n'.format(
206 | self.trlog['max_acc_epoch'],
207 | self.trlog['max_acc'],
208 | self.trlog['max_acc_interval']))
209 | f.write('Test acc={:.4f} + {:.4f}\n'.format(
210 | self.trlog['test_acc'],
211 | self.trlog['test_acc_interval']))
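212 |
213 |
214 | # --- Note (editorial addition, not part of the original file) ---
215 | # The episode labels built in `prepare_label` assume the sampler yields queries
216 | # interleaved by class, so for a 5-way 15-query episode the 75 query labels are simply
217 | #
218 | #   label = torch.arange(5).repeat(15)   # 0, 1, 2, 3, 4, 0, 1, ...
219 | #
220 | # and the training objective in `train` is
221 | #
222 | #   total_loss = balance_1 * CE(logits, label) + balance_2 * CE(reg_logits, label_aux)
223 | #
224 | # which falls back to a single cross-entropy term when the model returns no reg_logits.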
--------------------------------------------------------------------------------
/model/models/INSTA.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch
4 | from model.models.fcanet import MultiSpectralAttentionLayer
5 |
6 | """
7 | The INSTA class inherits from nn.Module and implements an attention mechanism
8 | that involves both channel and spatial features. It's designed to work with feature maps
9 | and applies both a channel attention and a learned convolutional kernel for spatial attention.
10 | """
11 |
12 | class INSTA(nn.Module):
13 | def __init__(self, c, spatial_size, sigma, k, args):
14 | """
15 | Initialize the INSTA network module.
16 |
17 | Parameters:
18 | - c: Number of channels in the input feature map.
19 | - spatial_size: The height and width of the input feature map.
20 |         - sigma: Reduction ratio for the hidden layer of the channel-attention MLP (see MultiSpectralAttentionLayer).
21 |         - k: Kernel size of the dynamic kernels and the spatial attention.
22 |         - args: Experiment configuration (argparse namespace); stored for later use.
23 | """
24 | super().__init__()
25 | self.channel = c
26 | self.h1 = sigma
27 | self.h2 = k **2
28 | self.k = k
29 | # Standard 2D convolution for channel reduction or transformation.
30 | self.conv = nn.Conv2d(self.channel, self.h2, 1)
31 | # Batch normalization for the output of the spatial attention.
32 | self.fn_spatial = nn.BatchNorm2d(spatial_size**2)
33 | # Batch normalization for the output of the channel attention.
34 | self.fn_channel = nn.BatchNorm2d(self.channel)
35 | # Unfold operation for transforming feature map into patches.
36 | self.Unfold = nn.Unfold(kernel_size=self.k, padding=int((self.k+1)/2-1))
37 | self.spatial_size = spatial_size
38 | # Dictionary mapping channel numbers to width/height for MultiSpectralAttentionLayer.
39 | c2wh = dict([(512, 11), (640, self.spatial_size)])
40 | # MultiSpectralAttentionLayer for performing attention across spectral (frequency) components.
41 | self.channel_att = MultiSpectralAttentionLayer(c, c2wh[c], c2wh[c], sigma=self.h1, k=self.k, freq_sel_method='low16')
42 | self.args = args
43 | # Upper part of a Coordinate Learning Module (CLM), which modifies feature maps.
44 | self.CLM_upper = nn.Sequential(
45 | nn.Conv2d(c, c*2, 1),
46 | nn.BatchNorm2d(c*2),
47 | nn.ReLU(),
48 | nn.Conv2d(c*2, c*2, 1),
49 | nn.BatchNorm2d(c * 2),
50 | nn.ReLU()
51 | )
52 |
53 | # Lower part of CLM, transforming the features back to original channel dimensions and applying sigmoid.
54 | self.CLM_lower = nn.Sequential(
55 | nn.Conv2d(c*2, c*2, 1),
56 | nn.BatchNorm2d(c*2),
57 | nn.ReLU(),
58 | nn.Conv2d(c*2, c, 1),
59 | nn.BatchNorm2d(c),
60 | nn.Sigmoid() # Sigmoid activation to normalize the feature values between 0 and 1.
61 | )
62 |
63 | def CLM(self, featuremap):
64 | """
65 | The Coordinate Learning Module (CLM) that processes feature maps to adapt them spatially.
66 |
67 | Parameters:
68 | - featuremap: The input feature map to the CLM.
69 |
70 | Returns:
71 | - The adapted feature map processed through the CLM.
72 | """
73 | # Apply the upper CLM to modify and then aggregate features.
74 | adap = self.CLM_upper(featuremap)
75 | intermediate = adap.sum(dim=0) # Summing features across the batch dimension.
76 | adap_1 = self.CLM_lower(intermediate.unsqueeze(0)) # Applying the lower CLM.
77 | return adap_1
78 |
79 | def spatial_kernel_network(self, feature_map, conv):
80 | """
81 | Applies a convolution to the feature map to generate a spatial kernel,
82 | which will be used to modulate the spatial regions of the input features.
83 |
84 | Parameters:
85 | - feature_map: The feature map to process.
86 | - conv: The convolutional layer to apply.
87 |
88 | Returns:
89 | - The processed spatial kernel.
90 | """
91 | spatial_kernel = conv(feature_map)
92 | spatial_kernel = spatial_kernel.flatten(-2).transpose(-1, -2)
93 | size = spatial_kernel.size()
94 | spatial_kernel = spatial_kernel.view(size[0], -1, self.k, self.k)
95 | spatial_kernel = self.fn_spatial(spatial_kernel)
96 |
97 | spatial_kernel = spatial_kernel.flatten(-2)
98 | return spatial_kernel
99 |
100 | def channel_kernel_network(self, feature_map):
101 | """
102 | Processes the feature map through a channel attention mechanism to modulate the channels
103 | based on their importance.
104 |
105 | Parameters:
106 | - feature_map: The feature map to process.
107 |
108 | Returns:
109 | - The channel-modulated feature map.
110 | """
111 | channel_kernel = self.channel_att(feature_map)
112 | channel_kernel = self.fn_channel(channel_kernel)
113 | channel_kernel = channel_kernel.flatten(-2)
114 | channel_kernel = channel_kernel.squeeze().view(channel_kernel.shape[0], self.channel, -1)
115 | return channel_kernel
116 |
117 | def unfold(self, x, padding, k):
118 | """
119 | A manual implementation of the unfold operation, which extracts sliding local blocks from a batched input tensor.
120 |
121 | Parameters:
122 | - x: The input tensor.
123 | - padding: Padding to apply to the tensor.
124 | - k: Kernel size for the blocks to extract.
125 |
126 | Returns:
127 | - The unfolded tensor containing all local blocks.
128 | """
129 | x_padded = torch.cuda.FloatTensor(x.shape[0], x.shape[1], x.shape[2] + 2 * padding, x.shape[3] + 2 * padding).fill_(0)
130 | x_padded[:, :, padding:-padding, padding:-padding] = x
131 | x_unfolded = torch.cuda.FloatTensor(*x.shape, k, k).fill_(0)
132 | for i in range(int((self.k+1)/2-1), x.shape[2] + int((self.k+1)/2-1)):
133 | for j in range(int((self.k+1)/2-1), x.shape[3] + int((self.k+1)/2-1)):
134 | x_unfolded[:, :, i - int(((self.k+1)/2-1)), j - int(((self.k+1)/2-1)), :, :] = x_padded[:, :, i-int(((self.k+1)/2-1)):i + int((self.k+1)/2), j - int(((self.k+1)/2-1)):j + int(((self.k+1)/2))]
135 | return x_unfolded
136 |
137 | def forward(self, x):
138 | """
139 | The forward method of INSTA, which combines the spatial and channel kernels to adapt the feature map,
140 | along with performing the unfolding operation to facilitate local receptive processing.
141 |
142 | Parameters:
143 | - x: The input tensor to the network.
144 |
145 | Returns:
146 | - The adapted feature map and the task-specific kernel used for adaptation.
147 | """
148 | spatial_kernel = self.spatial_kernel_network(x, self.conv).unsqueeze(-3)
149 |         channel_kernel = self.channel_kernel_network(x).unsqueeze(-2)
150 |         kernel = spatial_kernel * channel_kernel  # Combine spatial and channel kernels
151 | # Resize kernel and apply to the unfolded feature map
152 | kernel_shape = kernel.size()
153 | feature_shape = x.size()
154 | instance_kernel = kernel.view(kernel_shape[0], kernel_shape[1], feature_shape[-2], feature_shape[-1], self.k, self.k)
155 | task_s = self.CLM(x) # Get task-specific representation
156 | spatial_kernel_task = self.spatial_kernel_network(task_s, self.conv).unsqueeze(-3)
157 |         channel_kernel_task = self.channel_kernel_network(task_s).unsqueeze(-2)
158 |         task_kernel = spatial_kernel_task * channel_kernel_task
159 | task_kernel_shape = task_kernel.size()
160 | task_kernel = task_kernel.view(task_kernel_shape[0], task_kernel_shape[1], feature_shape[-2], feature_shape[-1], self.k, self.k)
161 | kernel = task_kernel * instance_kernel
162 | unfold_feature = self.unfold(x, int((self.k+1)/2-1), self.k) # Perform a custom unfold operation
163 |         adapted_feature = (unfold_feature * kernel).mean(dim=(-1, -2)).squeeze(-1).squeeze(-1)
164 |         return adapted_feature + x, task_kernel  # Return the adapted features (residual connection) and the task-specific kernel
165 |
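166 | # --- Usage sketch (editorial addition, not part of the original file) ---
167 | # `unfold` above returns a tensor whose entry [n, c, i, j] is the zero-padded k x k
168 | # neighbourhood of x centred at (i, j); multiplying it by the combined instance/task
169 | # kernel and averaging over the last two dims applies a per-position, per-channel
170 | # dynamic filter. It should agree with torch's built-in unfold, reshaped accordingly
171 | # (a GPU is assumed, since the method allocates torch.cuda.FloatTensor buffers):
172 | #
173 | #   x = torch.randn(2, 640, 5, 5).cuda()
174 | #   insta = INSTA(640, 5, 0.2, 3, args=None).cuda()
175 | #   manual = insta.unfold(x, padding=1, k=3)
176 | #   builtin = F.unfold(x, kernel_size=3, padding=1).view(2, 640, 3, 3, 5, 5).permute(0, 1, 4, 5, 2, 3)
177 | #   assert torch.allclose(manual, builtin)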
--------------------------------------------------------------------------------
/model/networks/res10.py:
--------------------------------------------------------------------------------
1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate
2 |
3 | import torch
4 | from torch.autograd import Variable
5 | import torch.nn as nn
6 | import math
7 | import numpy as np
8 | import torch.nn.functional as F
9 | from torch.nn.utils.weight_norm import WeightNorm
10 |
11 |
12 | # Basic ResNet model
13 |
14 | def init_layer(L):
15 | # Initialization using fan-in
16 | if isinstance(L, nn.Conv2d):
17 | n = L.kernel_size[0] * L.kernel_size[1] * L.out_channels
18 | L.weight.data.normal_(0, math.sqrt(2.0 / float(n)))
19 | elif isinstance(L, nn.BatchNorm2d):
20 | L.weight.data.fill_(1)
21 | L.bias.data.fill_(0)
22 |
23 |
24 | class distLinear(nn.Module):
25 | def __init__(self, indim, outdim):
26 | super(distLinear, self).__init__()
27 | self.L = nn.Linear(indim, outdim, bias=False)
28 |         self.class_wise_learnable_norm = True  # See issues #4 and #8 on GitHub
29 | if self.class_wise_learnable_norm:
30 | WeightNorm.apply(self.L, 'weight', dim=0) # split the weight update component to direction and norm
31 |
32 | if outdim <= 200:
33 |             self.scale_factor = 2  # a fixed scale factor to turn the cosine output into a reasonably large softmax input; to reproduce the CUB result with ResNet10, use 4 (see issue #31 on GitHub)
34 | else:
35 |             self.scale_factor = 10  # in Omniglot, a larger scale factor is required to handle >1000 output classes
36 |
37 | def forward(self, x):
38 | x_norm = torch.norm(x, p=2, dim=1).unsqueeze(1).expand_as(x)
39 | x_normalized = x.div(x_norm + 0.00001)
40 | if not self.class_wise_learnable_norm:
41 | L_norm = torch.norm(self.L.weight.data, p=2, dim=1).unsqueeze(1).expand_as(self.L.weight.data)
42 | self.L.weight.data = self.L.weight.data.div(L_norm + 0.00001)
43 |         cos_dist = self.L(
44 |             x_normalized)  # matrix product via the forward function; when using WeightNorm, this also multiplies the cosine distance by a class-wise learnable norm (see issues #4 and #8 on GitHub)
45 | scores = self.scale_factor * (cos_dist)
46 |
47 | return scores
48 |
49 |
50 | class Flatten(nn.Module):
51 | def __init__(self):
52 | super(Flatten, self).__init__()
53 |
54 | def forward(self, x):
55 | return x.view(x.size(0), -1)
56 |
57 |
58 | class Linear_fw(nn.Linear): # used in MAML to forward input with fast weight
59 | def __init__(self, in_features, out_features):
60 | super(Linear_fw, self).__init__(in_features, out_features)
61 | self.weight.fast = None # Lazy hack to add fast weight link
62 | self.bias.fast = None
63 |
64 | def forward(self, x):
65 | if self.weight.fast is not None and self.bias.fast is not None:
66 | out = F.linear(x, self.weight.fast,
67 |                            self.bias.fast)  # weight.fast (fast weight) is the temporarily adapted weight
68 | else:
69 | out = super(Linear_fw, self).forward(x)
70 | return out
71 |
72 |
73 | class Conv2d_fw(nn.Conv2d): # used in MAML to forward input with fast weight
74 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
75 | super(Conv2d_fw, self).__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
76 | bias=bias)
77 | self.weight.fast = None
78 |         if self.bias is not None:
79 | self.bias.fast = None
80 |
81 | def forward(self, x):
82 | if self.bias is None:
83 | if self.weight.fast is not None:
84 | out = F.conv2d(x, self.weight.fast, None, stride=self.stride, padding=self.padding)
85 | else:
86 | out = super(Conv2d_fw, self).forward(x)
87 | else:
88 | if self.weight.fast is not None and self.bias.fast is not None:
89 | out = F.conv2d(x, self.weight.fast, self.bias.fast, stride=self.stride, padding=self.padding)
90 | else:
91 | out = super(Conv2d_fw, self).forward(x)
92 |
93 | return out
94 |
95 |
96 | class BatchNorm2d_fw(nn.BatchNorm2d): # used in MAML to forward input with fast weight
97 | def __init__(self, num_features):
98 | super(BatchNorm2d_fw, self).__init__(num_features)
99 | self.weight.fast = None
100 | self.bias.fast = None
101 |
102 | def forward(self, x):
103 | running_mean = torch.zeros(x.data.size()[1]).cuda()
104 | running_var = torch.ones(x.data.size()[1]).cuda()
105 | if self.weight.fast is not None and self.bias.fast is not None:
106 | out = F.batch_norm(x, running_mean, running_var, self.weight.fast, self.bias.fast, training=True,
107 | momentum=1)
108 | # batch_norm momentum hack: follow hack of Kate Rakelly in pytorch-maml/src/layers.py
109 | else:
110 | out = F.batch_norm(x, running_mean, running_var, self.weight, self.bias, training=True, momentum=1)
111 | return out
112 |
113 |
114 | # Simple Conv Block
115 |
116 |
117 | # Simple ResNet Block
118 | class SimpleBlock(nn.Module):
119 | maml = False # Default
120 |
121 | def __init__(self, indim, outdim, half_res):
122 | super(SimpleBlock, self).__init__()
123 | self.indim = indim
124 | self.outdim = outdim
125 | self.C1 = Conv2d_fw(indim, outdim, kernel_size=3, stride=2 if half_res else 1, padding=1, bias=False)
126 | self.BN1 = BatchNorm2d_fw(outdim)
127 | self.C2 = Conv2d_fw(outdim, outdim, kernel_size=3, padding=1, bias=False)
128 | self.BN2 = BatchNorm2d_fw(outdim)
129 |
130 | self.relu1 = nn.ReLU(inplace=True)
131 | self.relu2 = nn.ReLU(inplace=True)
132 |
133 | self.parametrized_layers = [self.C1, self.C2, self.BN1, self.BN2]
134 |
135 | self.half_res = half_res
136 |
137 | # if the input number of channels is not equal to the output, then need a 1x1 convolution
138 | if indim != outdim:
139 | if self.maml:
140 | self.shortcut = Conv2d_fw(indim, outdim, 1, 2 if half_res else 1, bias=False)
141 | self.BNshortcut = BatchNorm2d_fw(outdim)
142 | else:
143 | self.shortcut = nn.Conv2d(indim, outdim, 1, 2 if half_res else 1, bias=False)
144 | self.BNshortcut = nn.BatchNorm2d(outdim)
145 |
146 | self.parametrized_layers.append(self.shortcut)
147 | self.parametrized_layers.append(self.BNshortcut)
148 | self.shortcut_type = '1x1'
149 | else:
150 | self.shortcut_type = 'identity'
151 |
152 | for layer in self.parametrized_layers:
153 | init_layer(layer)
154 |
155 | def forward(self, x):
156 | out = self.C1(x)
157 | out = self.BN1(out)
158 | out = self.relu1(out)
159 | out = self.C2(out)
160 | out = self.BN2(out)
161 | short_out = x if self.shortcut_type == 'identity' else self.BNshortcut(self.shortcut(x))
162 | out = out + short_out
163 | out = self.relu2(out)
164 | return out
165 |
166 |
167 |
168 |
169 |
170 |
171 | class ResNet(nn.Module):
172 | maml = False # Default
173 |
174 | def __init__(self, block, list_of_num_layers, list_of_out_dims, flatten=True):
175 | # list_of_num_layers specifies number of layers in each stage
176 | # list_of_out_dims specifies number of output channel for each stage
177 | super(ResNet, self).__init__()
178 | assert len(list_of_num_layers) == 4, 'Can have only four stages'
179 |
180 | conv1 = Conv2d_fw(3, 64, kernel_size=7, stride=2, padding=3,
181 | bias=False)
182 | bn1 = BatchNorm2d_fw(64)
183 | relu = nn.ReLU()
184 | pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
185 |
186 | init_layer(conv1)
187 | init_layer(bn1)
188 |
189 | trunk = [conv1, bn1, relu, pool1]
190 |
191 | indim = 64
192 | for i in range(4):
193 |
194 | for j in range(list_of_num_layers[i]):
195 | half_res = (i >= 1) and (j == 0)
196 | B = block(indim, list_of_out_dims[i], half_res)
197 | trunk.append(B)
198 | indim = list_of_out_dims[i]
199 |
200 | if flatten:
201 | avgpool = nn.AvgPool2d(7)
202 | trunk.append(avgpool)
203 | trunk.append(Flatten())
204 | self.final_feat_dim = indim
205 | else:
206 | self.final_feat_dim = [indim, 7, 7]
207 |
208 | self.trunk = nn.Sequential(*trunk)
209 |
210 | def forward(self, x):
211 | out = self.trunk(x)
212 | return out
213 |
214 |
215 |
216 |
217 | def ResNet10(flatten=True):
218 | return ResNet(SimpleBlock, [1, 1, 1, 1], [64, 128, 256, 512], flatten)
219 |
220 |
221 |
222 |
223 |
--------------------------------------------------------------------------------
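
A minimal usage sketch for the ResNet10 backbone above (not part of the repository). It assumes the file's earlier init_layer and Flatten helpers are in scope and that inputs are 3x224x224 images, so the final 7x7 average pool matches the feature map:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
backbone = ResNet10(flatten=True).to(device)         # four stages of one SimpleBlock each
dummy = torch.randn(2, 3, 224, 224, device=device)   # illustrative batch of 224x224 RGB images
features = backbone(dummy)                           # expected shape: (2, 512)

Because the .fast weights start as None (as seen in BatchNorm2d_fw), these blocks behave like a plain ResNet unless a MAML-style inner loop assigns fast weights.
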
/model/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import time
4 | import pprint
5 | import torch
6 | import argparse
7 | import numpy as np
8 | import torch.nn as nn
9 | def one_hot(indices, depth):
10 | """
11 | Returns a one-hot tensor.
12 | This is a PyTorch equivalent of Tensorflow's tf.one_hot.
13 |
14 | Parameters:
15 | indices: a (n_batch, m) Tensor or (m) Tensor.
16 | depth: a scalar. Represents the depth of the one hot dimension.
17 | Returns: a (n_batch, m, depth) Tensor or (m, depth) Tensor.
18 | """
19 |
20 |     encoded_indices = torch.zeros(indices.size() + torch.Size([depth]))
21 |     if indices.is_cuda:
22 |         encoded_indices = encoded_indices.cuda()
23 |     index = indices.view(indices.size() + torch.Size([1]))
24 |     encoded_indices = encoded_indices.scatter_(1, index, 1)
25 |
26 | return encoded_indicies
27 |
28 | def set_gpu(x):
29 | os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
30 | os.environ['CUDA_VISIBLE_DEVICES'] = x
31 | print('using gpu:', x)
32 |
33 | def ensure_path(dir_path, scripts_to_save=None):
34 | if os.path.exists(dir_path):
35 | if input('{} exists, remove? ([y]/n)'.format(dir_path)) != 'n':
36 | shutil.rmtree(dir_path)
37 | os.mkdir(dir_path)
38 | else:
39 | os.mkdir(dir_path)
40 |
41 | print('Experiment dir : {}'.format(dir_path))
42 | if scripts_to_save is not None:
43 | script_path = os.path.join(dir_path, 'scripts')
44 | if not os.path.exists(script_path):
45 | os.makedirs(script_path)
46 | for src_file in scripts_to_save:
47 | dst_file = os.path.join(dir_path, 'scripts', os.path.basename(src_file))
48 | print('copy {} to {}'.format(src_file, dst_file))
49 | if os.path.isdir(src_file):
50 | shutil.copytree(src_file, dst_file)
51 | else:
52 | shutil.copyfile(src_file, dst_file)
53 |
54 | class Averager():
55 |
56 | def __init__(self):
57 | self.n = 0
58 | self.v = 0
59 |
60 | def add(self, x):
61 | self.v = (self.v * self.n + x) / (self.n + 1)
62 | self.n += 1
63 |
64 | def item(self):
65 | return self.v
66 |
67 |
68 | class CrossEntropyLoss(nn.Module):
69 | def __init__(self):
70 | super(CrossEntropyLoss, self).__init__()
71 | self.logsoftmax = nn.LogSoftmax(dim=1)
72 |
73 | def forward(self, inputs, targets):
74 | input_ = inputs
75 | input_ = input_.view(input_.size(0), input_.size(1), -1)
76 |
77 | log_probs = self.logsoftmax(input_)
78 | targets_ = torch.zeros(input_.size(0), input_.size(1)).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
79 | targets_ = targets_.unsqueeze(-1)
80 | targets_ = targets_.cuda()
81 | loss = (- targets_ * log_probs).mean(0).sum()
82 | return loss / input_.size(2)
83 |
84 |
85 |
86 | def c_acc(logits, labels_test):
87 | _, preds = torch.max(logits, 1)
88 | acc = (torch.sum(preds == labels_test)).type(torch.cuda.FloatTensor) / labels_test.size(0)
89 |
90 | return acc.item()
91 |
92 |
93 |
94 | def count_acc(logits, label):
95 | pred = torch.argmax(logits, dim=1)
96 | if torch.cuda.is_available():
97 | return (pred == label).type(torch.cuda.FloatTensor).mean().item()
98 | else:
99 | return (pred == label).type(torch.FloatTensor).mean().item()
100 |
101 | def euclidean_metric(a, b):
102 | n = a.shape[0]
103 | m = b.shape[0]
104 | a = a.unsqueeze(1).expand(n, m, -1)
105 | b = b.unsqueeze(0).expand(n, m, -1)
106 | logits = -((a - b)**2).sum(dim=2)
107 | return logits
108 |
109 | class Timer():
110 |
111 | def __init__(self):
112 | self.o = time.time()
113 |
114 | def measure(self, p=1):
115 | x = (time.time() - self.o) / p
116 | x = int(x)
117 | if x >= 3600:
118 | return '{:.1f}h'.format(x / 3600)
119 | if x >= 60:
120 | return '{}m'.format(round(x / 60))
121 | return '{}s'.format(x)
122 |
123 | _utils_pp = pprint.PrettyPrinter()
124 | def pprint(x):
125 | _utils_pp.pprint(x)
126 |
127 | def compute_confidence_interval(data):
128 | """
129 | Compute 95% confidence interval
130 | :param data: An array of mean accuracy (or mAP) across a number of sampled episodes.
131 | :return: the 95% confidence interval for this data.
132 | """
133 | a = 1.0 * np.array(data)
134 | m = np.mean(a)
135 | std = np.std(a)
136 | pm = 1.96 * (std / np.sqrt(len(a)))
137 | return m, pm
138 |
139 | def postprocess_args(args):
140 | args.num_classes = args.way
141 | save_path1 = '-'.join([args.dataset, args.model_class, args.backbone_class,
142 |                            '{:02d}w{:02d}s{:02d}q'.format(args.way, args.shot, args.query)])
143 |
144 |
145 | save_path2 = '_'.join([str('_'.join(args.step_size.split(','))), str(args.gamma),
146 | 'lr{:.2g}mul{:.2g}'.format(args.lr, args.lr_mul),
147 | str(args.lr_scheduler),
148 | 'T1{}T2{}'.format(args.temperature, args.temperature2),
149 | 'b{}'.format(args.balance_1),
150 | 'bsz{:03d}'.format( max(args.way, args.num_classes)*(args.shot+args.query) ),
151 | ])
152 | if args.init_weights is not None:
153 | save_path1 += '-Pre'
154 | if args.use_euclidean:
155 | save_path1 += '-DIS'
156 | else:
157 | save_path1 += '-SIM'
158 |
159 | if args.fix_BN:
160 | save_path2 += '-FBN'
161 | if not args.augment:
162 | save_path2 += '-NoAug'
163 |
164 | if not os.path.exists(os.path.join(args.save_dir, save_path1)):
165 | os.mkdir(os.path.join(args.save_dir, save_path1))
166 | args.save_path = os.path.join(args.save_dir, save_path1, save_path2)
167 | return args
168 |
169 | def get_command_line_parser():
170 | parser = argparse.ArgumentParser()
171 | parser.add_argument('--max_epoch', type=int, default=200)
172 | parser.add_argument('--episodes_per_epoch', type=int, default=100)
173 | parser.add_argument('--num_eval_episodes', type=int, default=600)
174 |     parser.add_argument('--model_class', type=str, default='INSTA_ProtoNet',
175 | choices=['INSTA_ProtoNet', 'ProtoNet'])
176 | parser.add_argument('--use_euclidean', action='store_true', default=False)
177 | parser.add_argument('--use_AdamW', action='store_true', default=False)
178 | parser.add_argument('--backbone_class', type=str, default='Res12',
179 | choices=['Res12', 'Res18'])
180 | parser.add_argument('--dataset', type=str, default='MiniImageNet',
181 | choices=['MiniImageNet', 'TieredImageNet', 'CUB', 'FC100'])
182 |
183 | parser.add_argument('--way', type=int, default=5)
184 | parser.add_argument('--eval_way', type=int, default=5)
185 | parser.add_argument('--shot', type=int, default=1)
186 | parser.add_argument('--eval_shot', type=int, default=1)
187 | parser.add_argument('--query', type=int, default=15)
188 | parser.add_argument('--eval_query', type=int, default=15)
189 | parser.add_argument('--balance_1', type=float, default=0)
190 | parser.add_argument('--balance_2', type=float, default=0)
191 | parser.add_argument('--temperature', type=float, default=1)
192 | parser.add_argument('--temperature2', type=float, default=1) # the temperature in the
193 |
194 | # optimization parameters
195 | parser.add_argument('--orig_imsize', type=int, default=-1) # -1 for no cache, and -2 for no resize, only for MiniImageNet and CUB
196 | parser.add_argument('--lr', type=float, default=0.0001)
197 | parser.add_argument('--lr_mul', type=float, default=10)
198 | parser.add_argument('--lr_scheduler', type=str, default='step', choices=['multistep', 'step', 'cosine'])
199 | parser.add_argument('--step_size', type=str, default='20')
200 | parser.add_argument('--gamma', type=float, default=0.2)
201 | parser.add_argument('--fix_BN', action='store_true', default=False) # means we do not update the running mean/var in BN, not to freeze BN
202 | parser.add_argument('--augment', action='store_true', default=False)
203 | parser.add_argument('--baseline', type=str, default='y')
204 | parser.add_argument('--multi_gpu', action='store_true', default=False)
205 | parser.add_argument('--gpu', default='0')
206 | parser.add_argument('--init_weights', type=str, default=None)
207 | parser.add_argument('--emb_adap', action='store_true', default=False)
208 | parser.add_argument('--testing', action='store_true', default=False)
209 |
210 | # usually untouched parameters
211 | parser.add_argument('--mom', type=float, default=0.9)
212 | parser.add_argument('--weight_decay', type=float, default=0.0005) # we find this weight decay value works the best
213 | parser.add_argument('--num_workers', type=int, default=8)
214 | parser.add_argument('--log_interval', type=int, default=50)
215 | parser.add_argument('--eval_interval', type=int, default=1)
216 | parser.add_argument('--save_dir', type=str, default='./checkpoints')
217 |
218 | return parser
--------------------------------------------------------------------------------
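
A short, hypothetical sketch (not in the repository) of how the metric helpers above combine in an episodic evaluation loop; the 5-way / 15-query shapes and the 640-dim embeddings are illustrative assumptions:

import torch

prototypes = torch.randn(5, 640)                   # one prototype per class (e.g. Res12 features)
queries = torch.randn(75, 640)                     # 5 classes x 15 query embeddings
labels = torch.arange(5).repeat_interleave(15)

logits = euclidean_metric(queries, prototypes)     # (75, 5) negative squared distances
episode_acc = count_acc(logits, labels)            # mean accuracy for this episode

# Collect one accuracy per sampled episode, then report mean +/- 95% interval:
mean_acc, interval = compute_confidence_interval([episode_acc] * 600)
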
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/model/models/utils/transformers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Module, ModuleList, Linear, Dropout, LayerNorm, Identity, Parameter, init
3 | import torch.nn.functional as F
4 | from .stochastic_depth import DropPath
5 |
6 |
7 | class Attention(Module):
8 | """
9 | Obtained from timm: github.com:rwightman/pytorch-image-models
10 | """
11 |
12 | def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
13 | super().__init__()
14 | self.num_heads = num_heads
15 | head_dim = dim // self.num_heads
16 | self.scale = head_dim ** -0.5
17 |
18 | self.qkv = Linear(dim, dim * 3, bias=False)
19 | self.attn_drop = Dropout(attention_dropout)
20 | self.proj = Linear(dim, dim)
21 | self.proj_drop = Dropout(projection_dropout)
22 |
23 | def forward(self, x):
24 | B, N, C = x.shape
25 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
26 | q, k, v = qkv[0], qkv[1], qkv[2]
27 |
28 | attn = (q @ k.transpose(-2, -1)) * self.scale
29 | attn = attn.softmax(dim=-1)
30 | attn = self.attn_drop(attn)
31 |
32 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
33 | x = self.proj(x)
34 | x = self.proj_drop(x)
35 | return x
36 |
37 |
38 | class MaskedAttention(Module):
39 | def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
40 | super().__init__()
41 | self.num_heads = num_heads
42 | head_dim = dim // self.num_heads
43 | self.scale = head_dim ** -0.5
44 |
45 | self.qkv = Linear(dim, dim * 3, bias=False)
46 | self.attn_drop = Dropout(attention_dropout)
47 | self.proj = Linear(dim, dim)
48 | self.proj_drop = Dropout(projection_dropout)
49 |
50 | def forward(self, x, mask=None):
51 | B, N, C = x.shape
52 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
53 | q, k, v = qkv[0], qkv[1], qkv[2]
54 |
55 | attn = (q @ k.transpose(-2, -1)) * self.scale
56 |
57 | if mask is not None:
58 | mask_value = -torch.finfo(attn.dtype).max
59 | assert mask.shape[-1] == attn.shape[-1], 'mask has incorrect dimensions'
60 | mask = mask[:, None, :] * mask[:, :, None]
61 | mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)
62 | attn.masked_fill_(~mask, mask_value)
63 |
64 | attn = attn.softmax(dim=-1)
65 | attn = self.attn_drop(attn)
66 |
67 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
68 | x = self.proj(x)
69 | x = self.proj_drop(x)
70 | return x
71 |
72 |
73 | class TransformerEncoderLayer(Module):
74 | """
75 | Inspired by torch.nn.TransformerEncoderLayer and timm.
76 | """
77 |
78 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
79 | attention_dropout=0.1, drop_path_rate=0.1):
80 | super(TransformerEncoderLayer, self).__init__()
81 | self.pre_norm = LayerNorm(d_model)
82 | self.self_attn = Attention(dim=d_model, num_heads=nhead,
83 | attention_dropout=attention_dropout, projection_dropout=dropout)
84 |
85 | self.linear1 = Linear(d_model, dim_feedforward)
86 | self.dropout1 = Dropout(dropout)
87 | self.norm1 = LayerNorm(d_model)
88 | self.linear2 = Linear(dim_feedforward, d_model)
89 | self.dropout2 = Dropout(dropout)
90 |
91 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else Identity()
92 |
93 | self.activation = F.gelu
94 |
95 | def forward(self, src: torch.Tensor, *args, **kwargs) -> torch.Tensor:
96 | src = src + self.drop_path(self.self_attn(self.pre_norm(src)))
97 | src = self.norm1(src)
98 | src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
99 | src = src + self.drop_path(self.dropout2(src2))
100 | return src
101 |
102 |
103 | class MaskedTransformerEncoderLayer(Module):
104 | """
105 | Inspired by torch.nn.TransformerEncoderLayer and timm.
106 | """
107 |
108 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
109 | attention_dropout=0.1, drop_path_rate=0.1):
110 | super(MaskedTransformerEncoderLayer, self).__init__()
111 | self.pre_norm = LayerNorm(d_model)
112 | self.self_attn = MaskedAttention(dim=d_model, num_heads=nhead,
113 | attention_dropout=attention_dropout, projection_dropout=dropout)
114 |
115 | self.linear1 = Linear(d_model, dim_feedforward)
116 | self.dropout1 = Dropout(dropout)
117 | self.norm1 = LayerNorm(d_model)
118 | self.linear2 = Linear(dim_feedforward, d_model)
119 | self.dropout2 = Dropout(dropout)
120 |
121 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else Identity()
122 |
123 | self.activation = F.gelu
124 |
125 | def forward(self, src: torch.Tensor, mask=None, *args, **kwargs) -> torch.Tensor:
126 | src = src + self.drop_path(self.self_attn(self.pre_norm(src), mask))
127 | src = self.norm1(src)
128 | src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
129 | src = src + self.drop_path(self.dropout2(src2))
130 | return src
131 |
132 |
133 | class TransformerClassifier(Module):
134 | def __init__(self,
135 | seq_pool=True,
136 | embedding_dim=768,
137 | num_layers=12,
138 | num_heads=12,
139 | mlp_ratio=4.0,
140 | num_classes=1000,
141 | dropout_rate=0.1,
142 | attention_dropout=0.1,
143 | stochastic_depth_rate=0.1,
144 | positional_embedding='sine',
145 | sequence_length=None,
146 | *args, **kwargs):
147 | super().__init__()
148 | positional_embedding = positional_embedding if \
149 | positional_embedding in ['sine', 'learnable', 'none'] else 'sine'
150 | dim_feedforward = int(embedding_dim * mlp_ratio)
151 | self.embedding_dim = embedding_dim
152 | self.sequence_length = sequence_length
153 | self.seq_pool = seq_pool
154 |
155 | assert sequence_length is not None or positional_embedding == 'none', \
156 | f"Positional embedding is set to {positional_embedding} and" \
157 | f" the sequence length was not specified."
158 |
159 | if not seq_pool:
160 | sequence_length += 1
161 | self.class_emb = Parameter(torch.zeros(1, 1, self.embedding_dim),
162 | requires_grad=True)
163 | else:
164 | self.attention_pool = Linear(self.embedding_dim, 1)
165 |
166 | if positional_embedding != 'none':
167 | if positional_embedding == 'learnable':
168 | self.positional_emb = Parameter(torch.zeros(1, sequence_length, embedding_dim),
169 | requires_grad=True)
170 | init.trunc_normal_(self.positional_emb, std=0.2)
171 | else:
172 | self.positional_emb = Parameter(self.sinusoidal_embedding(sequence_length, embedding_dim),
173 | requires_grad=False)
174 | else:
175 | self.positional_emb = None
176 |
177 | self.dropout = Dropout(p=dropout_rate)
178 | dpr = [x.item() for x in torch.linspace(0, stochastic_depth_rate, num_layers)]
179 | self.blocks = ModuleList([
180 | TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
181 | dim_feedforward=dim_feedforward, dropout=dropout_rate,
182 | attention_dropout=attention_dropout, drop_path_rate=dpr[i])
183 | for i in range(num_layers)])
184 | self.norm = LayerNorm(embedding_dim)
185 |
186 | self.fc = Linear(embedding_dim, num_classes)
187 | self.apply(self.init_weight)
188 |
189 | def forward(self, x):
190 | if self.positional_emb is None and x.size(1) < self.sequence_length:
191 |             x = F.pad(x, (0, 0, 0, self.sequence_length - x.size(1)), mode='constant', value=0)
192 |
193 | if not self.seq_pool:
194 | cls_token = self.class_emb.expand(x.shape[0], -1, -1)
195 | x = torch.cat((cls_token, x), dim=1)
196 |
197 | if self.positional_emb is not None:
198 | x += self.positional_emb
199 |
200 | x = self.dropout(x)
201 |
202 | for blk in self.blocks:
203 | x = blk(x)
204 | x = self.norm(x)
205 |
206 | if self.seq_pool:
207 | x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2)
208 | else:
209 | x = x[:, 0]
210 |
211 | x = self.fc(x)
212 | return x
213 |
214 | @staticmethod
215 | def init_weight(m):
216 | if isinstance(m, Linear):
217 | init.trunc_normal_(m.weight, std=.02)
218 | if isinstance(m, Linear) and m.bias is not None:
219 | init.constant_(m.bias, 0)
220 | elif isinstance(m, LayerNorm):
221 | init.constant_(m.bias, 0)
222 | init.constant_(m.weight, 1.0)
223 |
224 | @staticmethod
225 | def sinusoidal_embedding(n_channels, dim):
226 | pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
227 | for p in range(n_channels)])
228 | pe[:, 0::2] = torch.sin(pe[:, 0::2])
229 | pe[:, 1::2] = torch.cos(pe[:, 1::2])
230 | return pe.unsqueeze(0)
231 |
232 |
233 | class MaskedTransformerClassifier(Module):
234 | def __init__(self,
235 | seq_pool=True,
236 | embedding_dim=768,
237 | num_layers=12,
238 | num_heads=12,
239 | mlp_ratio=4.0,
240 | num_classes=1000,
241 | dropout_rate=0.1,
242 | attention_dropout=0.1,
243 | stochastic_depth_rate=0.1,
244 | positional_embedding='sine',
245 | seq_len=None,
246 | *args, **kwargs):
247 | super().__init__()
248 | positional_embedding = positional_embedding if \
249 | positional_embedding in ['sine', 'learnable', 'none'] else 'sine'
250 | dim_feedforward = int(embedding_dim * mlp_ratio)
251 | self.embedding_dim = embedding_dim
252 | self.seq_len = seq_len
253 | self.seq_pool = seq_pool
254 |
255 | assert seq_len is not None or positional_embedding == 'none', \
256 | f"Positional embedding is set to {positional_embedding} and" \
257 | f" the sequence length was not specified."
258 |
259 | if not seq_pool:
260 | seq_len += 1
261 | self.class_emb = Parameter(torch.zeros(1, 1, self.embedding_dim),
262 | requires_grad=True)
263 | else:
264 | self.attention_pool = Linear(self.embedding_dim, 1)
265 |
266 | if positional_embedding != 'none':
267 | if positional_embedding == 'learnable':
268 | seq_len += 1 # padding idx
269 | self.positional_emb = Parameter(torch.zeros(1, seq_len, embedding_dim),
270 | requires_grad=True)
271 | init.trunc_normal_(self.positional_emb, std=0.2)
272 | else:
273 | self.positional_emb = Parameter(self.sinusoidal_embedding(seq_len,
274 | embedding_dim,
275 | padding_idx=True),
276 | requires_grad=False)
277 | else:
278 | self.positional_emb = None
279 |
280 | self.dropout = Dropout(p=dropout_rate)
281 | dpr = [x.item() for x in torch.linspace(0, stochastic_depth_rate, num_layers)]
282 | self.blocks = ModuleList([
283 | MaskedTransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
284 | dim_feedforward=dim_feedforward, dropout=dropout_rate,
285 | attention_dropout=attention_dropout, drop_path_rate=dpr[i])
286 | for i in range(num_layers)])
287 | self.norm = LayerNorm(embedding_dim)
288 |
289 | self.fc = Linear(embedding_dim, num_classes)
290 | self.apply(self.init_weight)
291 |
292 | def forward(self, x, mask=None):
293 | if self.positional_emb is None and x.size(1) < self.seq_len:
294 |             x = F.pad(x, (0, 0, 0, self.seq_len - x.size(1)), mode='constant', value=0)
295 |
296 | if not self.seq_pool:
297 | cls_token = self.class_emb.expand(x.shape[0], -1, -1)
298 | x = torch.cat((cls_token, x), dim=1)
299 | if mask is not None:
300 | mask = torch.cat([torch.ones(size=(mask.shape[0], 1), device=mask.device), mask.float()], dim=1)
301 | mask = (mask > 0)
302 |
303 | if self.positional_emb is not None:
304 | x += self.positional_emb
305 |
306 | x = self.dropout(x)
307 |
308 | for blk in self.blocks:
309 | x = blk(x, mask=mask)
310 | x = self.norm(x)
311 |
312 | if self.seq_pool:
313 | x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2)
314 | else:
315 | x = x[:, 0]
316 |
317 | x = self.fc(x)
318 | return x
319 |
320 | @staticmethod
321 | def init_weight(m):
322 | if isinstance(m, Linear):
323 | init.trunc_normal_(m.weight, std=.02)
324 | if isinstance(m, Linear) and m.bias is not None:
325 | init.constant_(m.bias, 0)
326 | elif isinstance(m, LayerNorm):
327 | init.constant_(m.bias, 0)
328 | init.constant_(m.weight, 1.0)
329 |
330 | @staticmethod
331 | def sinusoidal_embedding(n_channels, dim, padding_idx=False):
332 | pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
333 | for p in range(n_channels)])
334 | pe[:, 0::2] = torch.sin(pe[:, 0::2])
335 | pe[:, 1::2] = torch.cos(pe[:, 1::2])
336 | pe = pe.unsqueeze(0)
337 | if padding_idx:
338 | return torch.cat([torch.zeros((1, 1, dim)), pe], dim=1)
339 | return pe
340 |
--------------------------------------------------------------------------------
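
A hypothetical sketch (not part of the repository) showing how the TransformerClassifier above consumes a batch of already-tokenized patch embeddings; the small layer/width settings are illustrative assumptions:

import torch

clf = TransformerClassifier(seq_pool=True, embedding_dim=64, num_layers=2,
                            num_heads=4, num_classes=5, sequence_length=49)
tokens = torch.randn(8, 49, 64)                    # (batch, sequence_length, embedding_dim)
logits = clf(tokens)                               # (8, 5): attention-pooled features -> fc head

With seq_pool=True the sequence is reduced by the learned attention_pool weighting instead of a class token, matching the branch taken in forward().
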
/model/networks/utils/transformers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Module, ModuleList, Linear, Dropout, LayerNorm, Identity, Parameter, init
3 | import torch.nn.functional as F
4 | from .stochastic_depth import DropPath
5 |
6 |
7 | class Attention(Module):
8 | """
9 | Obtained from timm: github.com:rwightman/pytorch-image-models
10 | """
11 |
12 | def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
13 | super().__init__()
14 | self.num_heads = num_heads
15 | head_dim = dim // self.num_heads
16 | self.scale = head_dim ** -0.5
17 |
18 | self.qkv = Linear(dim, dim * 3, bias=False)
19 | self.attn_drop = Dropout(attention_dropout)
20 | self.proj = Linear(dim, dim)
21 | self.proj_drop = Dropout(projection_dropout)
22 |
23 | def forward(self, x):
24 | B, N, C = x.shape
25 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
26 | q, k, v = qkv[0], qkv[1], qkv[2]
27 |
28 | attn = (q @ k.transpose(-2, -1)) * self.scale
29 | attn = attn.softmax(dim=-1)
30 | attn = self.attn_drop(attn)
31 |
32 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
33 | x = self.proj(x)
34 | x = self.proj_drop(x)
35 | return x
36 |
37 |
38 | class MaskedAttention(Module):
39 | def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
40 | super().__init__()
41 | self.num_heads = num_heads
42 | head_dim = dim // self.num_heads
43 | self.scale = head_dim ** -0.5
44 |
45 | self.qkv = Linear(dim, dim * 3, bias=False)
46 | self.attn_drop = Dropout(attention_dropout)
47 | self.proj = Linear(dim, dim)
48 | self.proj_drop = Dropout(projection_dropout)
49 |
50 | def forward(self, x, mask=None):
51 | B, N, C = x.shape
52 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
53 | q, k, v = qkv[0], qkv[1], qkv[2]
54 |
55 | attn = (q @ k.transpose(-2, -1)) * self.scale
56 |
57 | if mask is not None:
58 | mask_value = -torch.finfo(attn.dtype).max
59 | assert mask.shape[-1] == attn.shape[-1], 'mask has incorrect dimensions'
60 | mask = mask[:, None, :] * mask[:, :, None]
61 | mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)
62 | attn.masked_fill_(~mask, mask_value)
63 |
64 | attn = attn.softmax(dim=-1)
65 | attn = self.attn_drop(attn)
66 |
67 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
68 | x = self.proj(x)
69 | x = self.proj_drop(x)
70 | return x
71 |
72 |
73 | class TransformerEncoderLayer(Module):
74 | """
75 | Inspired by torch.nn.TransformerEncoderLayer and timm.
76 | """
77 |
78 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
79 | attention_dropout=0.1, drop_path_rate=0.1):
80 | super(TransformerEncoderLayer, self).__init__()
81 | self.pre_norm = LayerNorm(d_model)
82 | self.self_attn = Attention(dim=d_model, num_heads=nhead,
83 | attention_dropout=attention_dropout, projection_dropout=dropout)
84 |
85 | self.linear1 = Linear(d_model, dim_feedforward)
86 | self.dropout1 = Dropout(dropout)
87 | self.norm1 = LayerNorm(d_model)
88 | self.linear2 = Linear(dim_feedforward, d_model)
89 | self.dropout2 = Dropout(dropout)
90 |
91 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else Identity()
92 |
93 | self.activation = F.gelu
94 |
95 | def forward(self, src: torch.Tensor, *args, **kwargs) -> torch.Tensor:
96 | src = src + self.drop_path(self.self_attn(self.pre_norm(src)))
97 | src = self.norm1(src)
98 | src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
99 | src = src + self.drop_path(self.dropout2(src2))
100 | return src
101 |
102 |
103 | class MaskedTransformerEncoderLayer(Module):
104 | """
105 | Inspired by torch.nn.TransformerEncoderLayer and timm.
106 | """
107 |
108 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
109 | attention_dropout=0.1, drop_path_rate=0.1):
110 | super(MaskedTransformerEncoderLayer, self).__init__()
111 | self.pre_norm = LayerNorm(d_model)
112 | self.self_attn = MaskedAttention(dim=d_model, num_heads=nhead,
113 | attention_dropout=attention_dropout, projection_dropout=dropout)
114 |
115 | self.linear1 = Linear(d_model, dim_feedforward)
116 | self.dropout1 = Dropout(dropout)
117 | self.norm1 = LayerNorm(d_model)
118 | self.linear2 = Linear(dim_feedforward, d_model)
119 | self.dropout2 = Dropout(dropout)
120 |
121 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else Identity()
122 |
123 | self.activation = F.gelu
124 |
125 | def forward(self, src: torch.Tensor, mask=None, *args, **kwargs) -> torch.Tensor:
126 | src = src + self.drop_path(self.self_attn(self.pre_norm(src), mask))
127 | src = self.norm1(src)
128 | src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
129 | src = src + self.drop_path(self.dropout2(src2))
130 | return src
131 |
132 |
133 | class TransformerClassifier(Module):
134 | def __init__(self,
135 | seq_pool=True,
136 | embedding_dim=768,
137 | num_layers=12,
138 | num_heads=12,
139 | mlp_ratio=4.0,
140 | num_classes=1000,
141 | dropout_rate=0.1,
142 | attention_dropout=0.1,
143 | stochastic_depth_rate=0.1,
144 | positional_embedding='sine',
145 | sequence_length=None,
146 | *args, **kwargs):
147 | super().__init__()
148 | positional_embedding = positional_embedding if \
149 | positional_embedding in ['sine', 'learnable', 'none'] else 'sine'
150 | dim_feedforward = int(embedding_dim * mlp_ratio)
151 | self.embedding_dim = embedding_dim
152 | self.sequence_length = sequence_length
153 | self.seq_pool = seq_pool
154 |
155 | assert sequence_length is not None or positional_embedding == 'none', \
156 | f"Positional embedding is set to {positional_embedding} and" \
157 | f" the sequence length was not specified."
158 |
159 | if not seq_pool:
160 | sequence_length += 1
161 | self.class_emb = Parameter(torch.zeros(1, 1, self.embedding_dim),
162 | requires_grad=True)
163 | else:
164 | self.attention_pool = Linear(self.embedding_dim, 1)
165 |
166 | if positional_embedding != 'none':
167 | if positional_embedding == 'learnable':
168 | self.positional_emb = Parameter(torch.zeros(1, sequence_length, embedding_dim),
169 | requires_grad=True)
170 | init.trunc_normal_(self.positional_emb, std=0.2)
171 | else:
172 | self.positional_emb = Parameter(self.sinusoidal_embedding(sequence_length, embedding_dim),
173 | requires_grad=False)
174 | else:
175 | self.positional_emb = None
176 |
177 | self.dropout = Dropout(p=dropout_rate)
178 | dpr = [x.item() for x in torch.linspace(0, stochastic_depth_rate, num_layers)]
179 | self.blocks = ModuleList([
180 | TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
181 | dim_feedforward=dim_feedforward, dropout=dropout_rate,
182 | attention_dropout=attention_dropout, drop_path_rate=dpr[i])
183 | for i in range(num_layers)])
184 | self.norm = LayerNorm(embedding_dim)
185 |
186 | # self.fc = Linear(embedding_dim, num_classes)
187 | self.apply(self.init_weight)
188 |
189 | def forward(self, x):
190 | if self.positional_emb is None and x.size(1) < self.sequence_length:
191 |             x = F.pad(x, (0, 0, 0, self.sequence_length - x.size(1)), mode='constant', value=0)
192 |
193 | if not self.seq_pool:
194 | cls_token = self.class_emb.expand(x.shape[0], -1, -1)
195 | x = torch.cat((cls_token, x), dim=1)
196 |
197 | if self.positional_emb is not None:
198 | x += self.positional_emb
199 |
200 | x = self.dropout(x)
201 |
202 | for blk in self.blocks:
203 | x = blk(x)
204 | x = self.norm(x)
205 |
206 | if self.seq_pool:
207 | x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2)
208 | else:
209 | x = x[:, 0]
210 |
211 | # x = self.fc(x)
212 | return x
213 |
214 | @staticmethod
215 | def init_weight(m):
216 | if isinstance(m, Linear):
217 | init.trunc_normal_(m.weight, std=.02)
218 | if isinstance(m, Linear) and m.bias is not None:
219 | init.constant_(m.bias, 0)
220 | elif isinstance(m, LayerNorm):
221 | init.constant_(m.bias, 0)
222 | init.constant_(m.weight, 1.0)
223 |
224 | @staticmethod
225 | def sinusoidal_embedding(n_channels, dim):
226 | pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
227 | for p in range(n_channels)])
228 | pe[:, 0::2] = torch.sin(pe[:, 0::2])
229 | pe[:, 1::2] = torch.cos(pe[:, 1::2])
230 | return pe.unsqueeze(0)
231 |
232 |
233 | class MaskedTransformerClassifier(Module):
234 | def __init__(self,
235 | seq_pool=True,
236 | embedding_dim=768,
237 | num_layers=12,
238 | num_heads=12,
239 | mlp_ratio=4.0,
240 | num_classes=1000,
241 | dropout_rate=0.1,
242 | attention_dropout=0.1,
243 | stochastic_depth_rate=0.1,
244 | positional_embedding='sine',
245 | seq_len=None,
246 | *args, **kwargs):
247 | super().__init__()
248 | positional_embedding = positional_embedding if \
249 | positional_embedding in ['sine', 'learnable', 'none'] else 'sine'
250 | dim_feedforward = int(embedding_dim * mlp_ratio)
251 | self.embedding_dim = embedding_dim
252 | self.seq_len = seq_len
253 | self.seq_pool = seq_pool
254 |
255 | assert seq_len is not None or positional_embedding == 'none', \
256 | f"Positional embedding is set to {positional_embedding} and" \
257 | f" the sequence length was not specified."
258 |
259 | if not seq_pool:
260 | seq_len += 1
261 | self.class_emb = Parameter(torch.zeros(1, 1, self.embedding_dim),
262 | requires_grad=True)
263 | else:
264 | self.attention_pool = Linear(self.embedding_dim, 1)
265 |
266 | if positional_embedding != 'none':
267 | if positional_embedding == 'learnable':
268 | seq_len += 1 # padding idx
269 | self.positional_emb = Parameter(torch.zeros(1, seq_len, embedding_dim),
270 | requires_grad=True)
271 | init.trunc_normal_(self.positional_emb, std=0.2)
272 | else:
273 | self.positional_emb = Parameter(self.sinusoidal_embedding(seq_len,
274 | embedding_dim,
275 | padding_idx=True),
276 | requires_grad=False)
277 | else:
278 | self.positional_emb = None
279 |
280 | self.dropout = Dropout(p=dropout_rate)
281 | dpr = [x.item() for x in torch.linspace(0, stochastic_depth_rate, num_layers)]
282 | self.blocks = ModuleList([
283 | MaskedTransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
284 | dim_feedforward=dim_feedforward, dropout=dropout_rate,
285 | attention_dropout=attention_dropout, drop_path_rate=dpr[i])
286 | for i in range(num_layers)])
287 | self.norm = LayerNorm(embedding_dim)
288 |
289 | self.fc = Linear(embedding_dim, num_classes)
290 | self.apply(self.init_weight)
291 |
292 | def forward(self, x, mask=None):
293 | if self.positional_emb is None and x.size(1) < self.seq_len:
294 |             x = F.pad(x, (0, 0, 0, self.seq_len - x.size(1)), mode='constant', value=0)
295 |
296 | if not self.seq_pool:
297 | cls_token = self.class_emb.expand(x.shape[0], -1, -1)
298 | x = torch.cat((cls_token, x), dim=1)
299 | if mask is not None:
300 | mask = torch.cat([torch.ones(size=(mask.shape[0], 1), device=mask.device), mask.float()], dim=1)
301 | mask = (mask > 0)
302 |
303 | if self.positional_emb is not None:
304 | x += self.positional_emb
305 |
306 | x = self.dropout(x)
307 |
308 | for blk in self.blocks:
309 | x = blk(x, mask=mask)
310 | x = self.norm(x)
311 |
312 | if self.seq_pool:
313 | x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2)
314 | else:
315 | x = x[:, 0]
316 |
317 | x = self.fc(x)
318 | return x
319 |
320 | @staticmethod
321 | def init_weight(m):
322 | if isinstance(m, Linear):
323 | init.trunc_normal_(m.weight, std=.02)
324 | if isinstance(m, Linear) and m.bias is not None:
325 | init.constant_(m.bias, 0)
326 | elif isinstance(m, LayerNorm):
327 | init.constant_(m.bias, 0)
328 | init.constant_(m.weight, 1.0)
329 |
330 | @staticmethod
331 | def sinusoidal_embedding(n_channels, dim, padding_idx=False):
332 | pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
333 | for p in range(n_channels)])
334 | pe[:, 0::2] = torch.sin(pe[:, 0::2])
335 | pe[:, 1::2] = torch.cos(pe[:, 1::2])
336 | pe = pe.unsqueeze(0)
337 | if padding_idx:
338 | return torch.cat([torch.zeros((1, 1, dim)), pe], dim=1)
339 | return pe
340 |
--------------------------------------------------------------------------------
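
This networks/ copy of transformers.py matches the models/ copy above apart from the commented-out fc head in TransformerClassifier, so that class returns pooled features rather than logits. As a hypothetical sketch (shapes are illustrative assumptions), the boolean key-padding mask consumed by MaskedAttention can be used like this:

import torch

attn = MaskedAttention(dim=64, num_heads=4)
x = torch.randn(2, 10, 64)
mask = torch.ones(2, 10, dtype=torch.bool)
mask[:, 8:] = False                                # mark the last two positions as padding
out = attn(x, mask=mask)                           # (2, 10, 64); padded keys receive ~zero attention

Internally the 1-D mask is outer-multiplied into a (batch, N, N) matrix, repeated per head, and applied with masked_fill_ before the softmax.
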