├── data
│   ├── __init__.py
│   └── mini_imagenet.py
├── models
│   ├── __init__.py
│   ├── attention.py
│   ├── resnet12_2.py
│   ├── diffusion_grad.py
│   └── classification_heads.py
├── __init__.py
├── utils.py
├── README.md
├── metadiff_test.py
├── cal_optial_prototypes.py
└── metadiff_train.py
/data/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # Implement your code here.
2 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import pprint
4 | import torch
5 |
6 | def set_gpu(x):
7 | os.environ['CUDA_VISIBLE_DEVICES'] = x
8 | print('using gpu:', x)
9 |
10 | def check_dir(path):
11 | '''
12 | Create directory if it does not exist.
13 | path: Path of directory.
14 | '''
15 | if not os.path.exists(path):
16 | os.mkdir(path)
17 |
18 | def count_accuracy(logits, label):
19 | pred = torch.argmax(logits, dim=1).view(-1)
20 | label = label.view(-1)
21 | accuracy = 100 * pred.eq(label).float().mean()
22 | return accuracy
23 |
24 | class Timer():
25 | def __init__(self):
26 | self.o = time.time()
27 |
28 | def measure(self, p=1):
29 | x = (time.time() - self.o) / float(p)
30 | # x = int(x)
31 | if x >= 3600:
32 | return '{:.1f}h'.format(x / 3600)
33 | if x >= 60:
34 | return '{}m'.format(round(x / 60))
35 | return '{}s'.format(x)
36 |
37 | def log(log_file_path, string):
38 | '''
39 | Write one line of log into screen and file.
40 | log_file_path: Path of log file.
41 | string: String to write in log file.
42 | '''
43 | with open(log_file_path, 'a+') as f:
44 | f.write(string + '\n')
45 | f.flush()
46 | print(string)
--------------------------------------------------------------------------------
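A quick, illustrative usage of the helpers above (a minimal sketch, assuming the repository root is on the Python path; the printed values are examples):

```python
import torch
from utils import count_accuracy, Timer, log

logits = torch.tensor([[2.0, 0.1, 0.3],
                       [0.2, 1.5, 0.1]])
labels = torch.tensor([0, 1])
print(count_accuracy(logits, labels))   # tensor(100.) -- both predictions are correct

timer = Timer()
# ... run some work ...
log('train_log.txt', 'elapsed: {}'.format(timer.measure()))  # prints and appends to the log file
```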
/README.md:
--------------------------------------------------------------------------------
1 | # MetaDiff: Meta-Learning with Conditional Diffusion for Few-Shot Learning
2 | This repository contains the code for the paper:
3 |
4 | [**MetaDiff: Meta-Learning with Conditional Diffusion for Few-Shot Learning**](https://arxiv.org/pdf/2307.16424.pdf)
5 |
6 | Baoquan Zhang, Chuyao Luo, Demin Yu, Huiwei Lin, Xutao Li, Yunming Ye, Bowen Zhang
7 |
8 | AAAI 2024
9 |
10 | ### Abstract
11 |
12 | Equipping a deep model with the ability of few-shot learning, i.e., learning quickly from only a few examples, is a core challenge for artificial intelligence. Gradient-based meta-learning approaches effectively address this challenge by learning how to learn novel tasks. Their key idea is to learn a deep model in a bi-level optimization manner, where the outer-loop process learns a shared gradient descent algorithm (i.e., its hyperparameters), while the inner-loop process leverages it to optimize a task-specific model using only a few labeled examples. Although these existing methods have shown superior performance, the outer-loop process requires calculating second-order derivatives along the inner optimization path, which imposes a considerable memory burden and the risk of vanishing gradients. Drawing inspiration from recent progress in diffusion models, we find that the inner-loop gradient descent process can actually be viewed as a reverse (i.e., denoising) process of diffusion, where the target of denoising is the model weights rather than the original data. Based on this fact, in this paper, we propose to model the gradient descent optimizer as a diffusion model and present a novel task-conditional diffusion-based meta-learning method, called MetaDiff, that effectively models the optimization process of model weights from Gaussian noise to target weights in a denoising manner. Thanks to the training efficiency of diffusion models, our MetaDiff does not need to differentiate through the inner-loop path, so the memory burden and the risk of vanishing gradients are effectively alleviated. Experimental results show that our MetaDiff outperforms state-of-the-art gradient-based meta-learning methods on few-shot learning tasks.
13 |
14 | ### Citation
15 |
16 | If you use this code for your research, please cite our paper:
17 | ```
18 | @inproceedings{zhang2022metadiff,
19 | author = {Zhang, Baoquan and Luo, Chuyao and Yu, Demin and Lin, Huiwei and Li, Xutao and Ye, Yunming and Zhang, Bowen},
20 | title = {MetaDiff: Meta-Learning with Conditional Diffusion for Few-Shot Learning},
21 | booktitle = {AAAI},
22 | year = {2024},
23 | }
24 | ```
25 |
--------------------------------------------------------------------------------
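To make the abstract's core idea concrete, here is a minimal, illustrative sketch (not taken from this repository; all names below are hypothetical) contrasting a gradient-based inner loop, which the outer loop must differentiate through, with the diffusion view, where task-specific weights are denoised from Gaussian noise conditioned on the support set:

```python
import torch

def maml_style_inner_loop(w, support_x, support_y, loss_fn, steps=5, lr=0.01):
    # Gradient-based meta-learning: the outer loop must backpropagate through
    # this whole path, which requires second-order derivatives (w must require grad).
    for _ in range(steps):
        loss = loss_fn(support_x @ w, support_y)
        (grad,) = torch.autograd.grad(loss, w, create_graph=True)
        w = w - lr * grad
    return w

def diffusion_style_inner_loop(denoiser, support_emb, w_shape, n_steps=100):
    # MetaDiff's view: start from Gaussian noise and iteratively denoise the
    # classifier weights, conditioned on the task (support embeddings).
    # No gradients need to flow through this path during meta-training.
    w_t = torch.randn(w_shape)
    with torch.no_grad():
        for t in reversed(range(n_steps)):
            w_t = denoiser(w_t, t, support_emb)  # one reverse (denoising) step
    return w_t
```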
/models/attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | import math
5 |
6 |
7 | class SelfAttention(nn.Module):
8 | def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
9 | super().__init__()
10 | self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
11 | self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
12 | self.n_heads = n_heads
13 | self.d_head = d_embed // n_heads
14 |
15 | def forward(self, x, causal_mask=False):
16 | input_shape = x.shape
17 | batch_size, sequence_length, d_embed = input_shape
18 | interim_shape = (batch_size, sequence_length, self.n_heads, self.d_head)
19 |
20 | q, k, v = self.in_proj(x).chunk(3, dim=-1)
21 |
22 | q = q.view(interim_shape).transpose(1, 2)
23 | k = k.view(interim_shape).transpose(1, 2)
24 | v = v.view(interim_shape).transpose(1, 2)
25 |
26 | weight = q @ k.transpose(-1, -2)
27 | if causal_mask:
28 | mask = torch.ones_like(weight, dtype=torch.bool).triu(1)
29 | weight.masked_fill_(mask, -torch.inf)
30 | weight /= math.sqrt(self.d_head)
31 | weight = F.softmax(weight, dim=-1)
32 |
33 | output = weight @ v
34 | output = output.transpose(1, 2)
35 | output = output.reshape(input_shape)
36 | output = self.out_proj(output)
37 | return output
38 |
39 | class CrossAttention(nn.Module):
40 | def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
41 | super().__init__()
42 | self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
43 | self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
44 | self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
45 | self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
46 | self.n_heads = n_heads
47 | self.d_head = d_embed // n_heads
48 |
49 | def forward(self, x, y):
50 | input_shape = x.shape
51 | batch_size, sequence_length, d_embed = input_shape
52 | interim_shape = (batch_size, -1, self.n_heads, self.d_head)
53 |
54 | q = self.q_proj(x)
55 | k = self.k_proj(y)
56 | v = self.v_proj(y)
57 |
58 | q = q.view(interim_shape).transpose(1, 2)
59 | k = k.view(interim_shape).transpose(1, 2)
60 | v = v.view(interim_shape).transpose(1, 2)
61 |
62 | weight = q @ k.transpose(-1, -2)
63 | weight /= math.sqrt(self.d_head)
64 | weight = F.softmax(weight, dim=-1)
65 |
66 | output = weight @ v
67 | output = output.transpose(1, 2).contiguous()
68 | output = output.view(input_shape)
69 | output = self.out_proj(output)
70 | return output
--------------------------------------------------------------------------------
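A small shape check for the two attention modules above (illustrative only; assumes the script is run from the repository root):

```python
import torch
from models.attention import SelfAttention, CrossAttention

x = torch.randn(2, 16, 512)    # (batch, sequence length, d_embed)
ctx = torch.randn(2, 10, 512)  # (batch, context length, d_cross)

self_attn = SelfAttention(n_heads=8, d_embed=512)
cross_attn = CrossAttention(n_heads=8, d_embed=512, d_cross=512)

print(self_attn(x, causal_mask=True).shape)  # torch.Size([2, 16, 512])
print(cross_attn(x, ctx).shape)              # torch.Size([2, 16, 512])
```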
/models/resnet12_2.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | def conv3x3(in_planes, out_planes):
4 | return nn.Conv2d(in_planes, out_planes, 3, padding=1, bias=False)
5 |
6 |
7 | def conv1x1(in_planes, out_planes):
8 | return nn.Conv2d(in_planes, out_planes, 1, bias=False)
9 |
10 |
11 | def norm_layer(planes):
12 | return nn.BatchNorm2d(planes)
13 |
14 |
15 | class Block(nn.Module):
16 |
17 | def __init__(self, inplanes, planes, downsample, use_relu=True):
18 | super().__init__()
19 |
20 | self.use_relu = use_relu
21 | self.relu = nn.LeakyReLU(0.1)
22 |
23 | self.conv1 = conv3x3(inplanes, planes)
24 | self.bn1 = norm_layer(planes)
25 | self.conv2 = conv3x3(planes, planes)
26 | self.bn2 = norm_layer(planes)
27 | self.conv3 = conv3x3(planes, planes)
28 | self.bn3 = norm_layer(planes)
29 |
30 | self.downsample = downsample
31 |
32 | self.maxpool = nn.MaxPool2d(2)
33 |
34 | def forward(self, x):
35 | out = self.conv1(x)
36 | out = self.bn1(out)
37 | out = self.relu(out)
38 |
39 | out = self.conv2(out)
40 | out = self.bn2(out)
41 | out = self.relu(out)
42 |
43 | out = self.conv3(out)
44 | out = self.bn3(out)
45 |
46 | identity = self.downsample(x)
47 |
48 | out += identity
49 | if self.use_relu:
50 | out = self.relu(out)
51 |
52 | out = self.maxpool(out)
53 |
54 | return out
55 |
56 |
57 | class ResNet12(nn.Module):
58 |
59 | def __init__(self, channels):
60 | super().__init__()
61 |
62 | self.inplanes = 3
63 |
64 | self.layer1 = self._make_layer(channels[0])
65 | self.layer2 = self._make_layer(channels[1])
66 | self.layer3 = self._make_layer(channels[2])
67 | self.layer4 = self._make_layer(channels[3], use_relu=False)
68 |
69 | self.out_dim = channels[3]
70 |
71 | for m in self.modules():
72 | if isinstance(m, nn.Conv2d):
73 | nn.init.kaiming_normal_(m.weight, mode='fan_out',
74 | nonlinearity='leaky_relu')
75 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
76 | nn.init.constant_(m.weight, 1)
77 | nn.init.constant_(m.bias, 0)
78 |
79 | def _make_layer(self, planes, use_relu=True):
80 | downsample = nn.Sequential(
81 | conv1x1(self.inplanes, planes),
82 | norm_layer(planes),
83 | )
84 | block = Block(self.inplanes, planes, downsample, use_relu=use_relu)
85 | self.inplanes = planes
86 | return block
87 |
88 | def forward(self, x, use_pool=True):
89 | x = self.layer1(x)
90 | x = self.layer2(x)
91 | x = self.layer3(x)
92 | x = self.layer4(x)
93 | if use_pool:
94 | x = x.view(x.shape[0], x.shape[1], -1).mean(dim=2)
95 | return x
96 |
97 | def resnet12():
98 | return ResNet12([64, 128, 256, 512])
99 |
100 | def resnet12_wide():
101 | return ResNet12([64, 160, 320, 640])
102 |
--------------------------------------------------------------------------------
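A minimal smoke test for the backbone above (illustrative; the training scripts additionally move the network to GPU and wrap it in DataParallel):

```python
import torch
from models.resnet12_2 import resnet12

net = resnet12()
imgs = torch.randn(2, 3, 84, 84)  # miniImageNet images are 84x84
feats = net(imgs)                 # four max-pooled stages (84->42->21->10->5), then global average pooling
print(feats.shape)                # torch.Size([2, 512])
```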
/metadiff_test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import argparse
4 | import random
5 | import numpy as np
6 | from tqdm import tqdm
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torch.utils.data import DataLoader
11 | from torch.autograd import Variable
12 |
13 | from models.resnet12_2 import resnet12
14 | from models.diffusion_grad import get_diffusion_from_args
15 |
16 | from utils import set_gpu, Timer, count_accuracy, check_dir, log
17 |
18 | def one_hot(indices, depth):
19 | """
20 | Returns a one-hot tensor.
21 | This is a PyTorch equivalent of Tensorflow's tf.one_hot.
22 |
23 | Parameters:
24 | indices: a (n_batch, m) Tensor or (m) Tensor.
25 | depth: a scalar. Represents the depth of the one hot dimension.
26 | Returns: a (n_batch, m, depth) Tensor or (m, depth) Tensor.
27 | """
28 |
29 | encoded_indicies = torch.zeros(indices.size() + torch.Size([depth])).cuda()
30 | index = indices.view(indices.size()+torch.Size([1]))
31 |     encoded_indicies = encoded_indicies.scatter_(-1, index, 1)  # scatter along the class (last) dimension
32 |
33 | return encoded_indicies
34 |
35 | def get_model(options):
36 | # Choose the embedding network
37 | if options.network == 'ResNet':
38 | network = resnet12().cuda()
39 | network = torch.nn.DataParallel(network)
40 | fea_dim = 512
41 | else:
42 | print ("Cannot recognize the network type")
43 | assert(False)
44 | # Choose the classification head
45 | cls_head = get_diffusion_from_args().cuda()
46 | return (network, cls_head)
47 |
48 | def get_dataset(options):
49 |     # Choose the dataset
50 | if options.dataset == 'miniImageNet':
51 | from data.mini_imagenet import MiniImageNet, FewShotDataloader
52 | dataset_test = MiniImageNet(phase='test')
53 | data_loader = FewShotDataloader
54 | else:
55 | print ("Cannot recognize the dataset type")
56 | assert(False)
57 |
58 | return (dataset_test, data_loader)
59 |
60 | def seed_torch(seed=21):
61 | os.environ['PYTHONHASHSEED'] = str(seed)
62 | random.seed(seed)
63 | np.random.seed(seed)
64 | torch.manual_seed(seed)
65 | torch.cuda.manual_seed(seed)
66 | torch.cuda.manual_seed_all(seed)
67 | # torch.backends.cudnn.deterministic = True
68 |     torch.backends.cudnn.deterministic = True  # ensure consistent results across CPU/GPU runs
69 |     torch.backends.cudnn.benchmark = False  # benchmark speeds up training when input sizes barely vary; disabled here for determinism
70 |
71 |
72 | def test(opt, dataset_test, data_loader):
73 | # Dataloader of Gidaris & Komodakis (CVPR 2018)
74 | dloader_test = data_loader(
75 | dataset=dataset_test,
76 | nKnovel=opt.test_way,
77 | nKbase=0,
78 | nExemplars=opt.val_shot, # num training examples per novel category
79 | nTestNovel=opt.val_query * opt.test_way, # num test examples for all the novel categories
80 | nTestBase=0, # num test examples for all the base categories
81 | batch_size=1,
82 | num_workers=0,
83 | epoch_size=1 * opt.val_episode, # num of batches per epoch
84 | )
85 |
86 | set_gpu(opt.gpu)
87 | check_dir('./experiments/')
88 | check_dir(opt.save_path)
89 |
90 | log_file_path = os.path.join(opt.save_path, "train_log.txt")
91 | log(log_file_path, str(vars(opt)))
92 |
93 | (embedding_net, cls_head) = get_model(opt)
94 | # Load saved model checkpoints
95 | saved_models = torch.load(os.path.join(opt.save_path, 'model.pth'))
96 | embedding_net.load_state_dict(saved_models['embedding'])
97 | embedding_net.eval()
98 | cls_head.load_state_dict(saved_models['head'])
99 | cls_head.eval()
100 |
101 | max_val_acc = 0.0
102 | max_test_acc = 0.0
103 |
104 | timer = Timer()
105 | x_entropy = torch.nn.CrossEntropyLoss()
106 |
107 | # Evaluate on the validation split
108 | _, _ = [x.eval() for x in (cls_head, embedding_net)]
109 |
110 | test_accuracies = []
111 | test_losses = []
112 | for i, batch in enumerate(tqdm(dloader_test()), 1):
113 | data_support, labels_support, data_query, labels_query, _, _ = [x.cuda() for x in batch]
114 |
115 | test_n_support = opt.test_way * opt.val_shot
116 | test_n_query = opt.test_way * opt.val_query
117 |
118 | with torch.no_grad():
119 | emb_support = embedding_net(data_support.reshape([-1] + list(data_support.shape[-3:])))
120 | emb_support = emb_support.reshape(1, test_n_support, -1)
121 | # emb_support = F.normalize(emb_support, dim=-1)
122 | emb_query = embedding_net(data_query.reshape([-1] + list(data_query.shape[-3:])))
123 | emb_query = emb_query.reshape(1, test_n_query, -1)
124 | # emb_query = F.normalize(emb_query, dim=-1)
125 |
126 | logit_query = cls_head.sample(emb_query, emb_support, labels_support, labels_query, opt.test_way, opt.val_shot)
127 | loss = x_entropy(logit_query.reshape(-1, opt.test_way), labels_query.reshape(-1))
128 | acc = count_accuracy(logit_query.reshape(-1, opt.test_way), labels_query.reshape(-1))
129 |
130 | test_accuracies.append(acc.item())
131 | test_losses.append(loss.item())
132 |
133 | test_acc_avg = np.mean(np.array(test_accuracies))
134 | test_acc_ci95 = 1.96 * np.std(np.array(test_accuracies)) / np.sqrt(opt.val_episode)
135 |
136 | test_loss_avg = np.mean(np.array(test_losses))
137 |
138 | if test_acc_avg > max_test_acc:
139 | max_test_acc = test_acc_avg
140 | log(log_file_path, 'Test Loss: {:.4f}\tAccuracy: {:.2f} ± {:.2f} % (Best)' \
141 | .format(test_loss_avg, test_acc_avg, test_acc_ci95))
142 | else:
143 | log(log_file_path, 'Test Loss: {:.4f}\tAccuracy: {:.2f} ± {:.2f} %' \
144 | .format(test_loss_avg, test_acc_avg, test_acc_ci95))
145 |
146 | if __name__ == '__main__':
147 | seed_torch(21)
148 | parser = argparse.ArgumentParser()
149 | parser.add_argument('--num-epoch', type=int, default=1000,
150 | help='number of training epochs')
151 | parser.add_argument('--val-shot', type=int, default=1,
152 | help='number of support examples per validation class')
153 | parser.add_argument('--val-episode', type=int, default=600,
154 | help='number of episodes per validation')
155 | parser.add_argument('--val-query', type=int, default=15,
156 | help='number of query examples per validation class')
157 | parser.add_argument('--test-way', type=int, default=5,
158 | help='number of classes in one test (or validation) episode')
159 | parser.add_argument('--save-path', default='./experiments/exp_1')
160 | parser.add_argument('--gpu', default='0')
161 | parser.add_argument('--network', type=str, default='ResNet',
162 | help='choose which embedding network to use. ProtoNet, R2D2, ResNet')
163 | parser.add_argument('--dataset', type=str, default='miniImageNet',
164 |                         help='choose which dataset to use. miniImageNet, tieredImageNet, CIFAR_FS, FC100')
165 |
166 | opt = parser.parse_args()
167 |
168 | (dataset_test, data_loader) = get_dataset(opt)
169 |
170 | test(opt, dataset_test, data_loader)
--------------------------------------------------------------------------------
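The interval reported above is the usual normal-approximation 95% confidence interval over episodes; for reference, the same computation on toy numbers (here `len(accs)` equals the number of episodes, i.e. `opt.val_episode`):

```python
import numpy as np

accs = np.array([62.0, 58.5, 64.2, 61.1])        # per-episode accuracies in %
ci95 = 1.96 * np.std(accs) / np.sqrt(len(accs))  # half-width of the 95% confidence interval
print('{:.2f} ± {:.2f} %'.format(accs.mean(), ci95))
```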
/cal_optial_prototypes.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import argparse
4 | import random
5 | import numpy as np
6 | from tqdm import tqdm
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torch.utils.data import DataLoader
11 | from torch.autograd import Variable
12 |
13 | from models.classification_heads import ClassificationHead
14 | from models.R2D2_embedding import R2D2Embedding
15 | from models.protonet_embedding import ProtoNetEmbedding
16 | from models.resnet12_2 import resnet12
17 | from utils import set_gpu, Timer, count_accuracy, check_dir, log
18 | import pickle
19 |
20 | def one_hot(indices, depth):
21 | """
22 | Returns a one-hot tensor.
23 | This is a PyTorch equivalent of Tensorflow's tf.one_hot.
24 |
25 | Parameters:
26 | indices: a (n_batch, m) Tensor or (m) Tensor.
27 | depth: a scalar. Represents the depth of the one hot dimension.
28 | Returns: a (n_batch, m, depth) Tensor or (m, depth) Tensor.
29 | """
30 |
31 | encoded_indicies = torch.zeros(indices.size() + torch.Size([depth])).cuda()
32 | index = indices.view(indices.size()+torch.Size([1]))
33 |     encoded_indicies = encoded_indicies.scatter_(-1, index, 1)  # scatter along the class (last) dimension
34 |
35 | return encoded_indicies
36 |
37 | def get_model(options):
38 | # Choose the embedding network
39 | if options.network == 'ProtoNet':
40 | network = ProtoNetEmbedding().cuda()
41 | fea_dim = 64
42 | elif options.network == 'R2D2':
43 | network = R2D2Embedding().cuda()
44 | elif options.network == 'ResNet':
45 | network = resnet12().cuda()
46 | network = torch.nn.DataParallel(network)
47 | fea_dim = 512
48 | else:
49 | print ("Cannot recognize the network type")
50 | assert(False)
51 | # info_max_layer = InfoMaxLayer().cuda()
52 | return network
53 |
54 | def get_dataset(options):
55 |     # Choose the dataset
56 | if options.dataset == 'miniImageNet':
57 | from data.mini_imagenet import MiniImageNet, FewShotDataloader
58 | dataset_train = MiniImageNet(phase='train')
59 | dataset_val = MiniImageNet(phase='val')
60 | dataset_test = MiniImageNet(phase='test')
61 | data_loader = FewShotDataloader
62 | else:
63 | print ("Cannot recognize the dataset type")
64 | assert(False)
65 |
66 | return (dataset_train, dataset_val, dataset_test, data_loader)
67 |
68 | if __name__ == '__main__':
69 | parser = argparse.ArgumentParser()
70 | parser.add_argument('--num-epoch', type=int, default=100,
71 | help='number of training epochs')
72 | parser.add_argument('--save-epoch', type=int, default=20,
73 | help='frequency of model saving')
74 | parser.add_argument('--train-shot', type=int, default=1,
75 | help='number of support examples per training class')
76 | parser.add_argument('--val-shot', type=int, default=1,
77 | help='number of support examples per validation class')
78 | parser.add_argument('--train-query', type=int, default=15,
79 | help='number of query examples per training class')
80 | parser.add_argument('--val-episode', type=int, default=600,
81 | help='number of episodes per validation')
82 | parser.add_argument('--val-query', type=int, default=15,
83 | help='number of query examples per validation class')
84 | parser.add_argument('--train-way', type=int, default=5,
85 | help='number of classes in one training episode')
86 | parser.add_argument('--test-way', type=int, default=5,
87 | help='number of classes in one test (or validation) episode')
88 | parser.add_argument('--save-path', default='./experiments/exp_1')
89 | parser.add_argument('--gpu', default='0')
90 | parser.add_argument('--network', type=str, default='ResNet',
91 | help='choose which embedding network to use. ProtoNet, R2D2, ResNet')
92 | parser.add_argument('--head', type=str, default='CosineNet',
93 | help='choose which classification head to use. ProtoNet, Ridge, R2D2, SVM')
94 | parser.add_argument('--pre_head', type=str, default='add_margin',
95 | help='choose which classification head to use. ProtoNet, Ridge, R2D2, SVM')
96 | parser.add_argument('--dataset', type=str, default='miniImageNet',
97 |                         help='choose which dataset to use. miniImageNet, tieredImageNet, CIFAR_FS, FC100')
98 | parser.add_argument('--episodes-per-batch', type=int, default=8,
99 | help='number of episodes per batch')
100 | parser.add_argument('--eps', type=float, default=0.0,
101 | help='epsilon of label smoothing')
102 |
103 | opt = parser.parse_args()
104 |
105 | (dataset_train, dataset_val, dataset_test, data_loader) = get_dataset(opt)
106 |
107 | data_loader_pre = torch.utils.data.DataLoader
108 | # Dataloader of Gidaris & Komodakis (CVPR 2018)
109 | dloader_train = data_loader_pre(
110 | dataset=dataset_train,
111 | batch_size=1,
112 | shuffle=False,
113 | num_workers=0
114 | )
115 |
116 | set_gpu(opt.gpu)
117 | check_dir('./experiments/')
118 | check_dir(opt.save_path)
119 |
120 | log_file_path = os.path.join(opt.save_path, "train_log.txt")
121 | log(log_file_path, str(vars(opt)))
122 |
123 | embedding_net = get_model(opt)
124 |
125 | # Load saved model checkpoints
126 | saved_models = torch.load(os.path.join(opt.save_path, 'best_pretrain_model_resnet.pth'))
127 | embedding_net.load_state_dict(saved_models['embedding'])
128 | embedding_net.eval()
129 |
130 | embs = []
131 | labels = []
132 | for i, batch in enumerate(tqdm(dloader_train), 1):
133 | data, label = [x.cuda() for x in batch]
134 | with torch.no_grad():
135 | emb = embedding_net(data)
136 | embs.append(emb)
137 | labels.append(label)
138 | embs = torch.cat(embs, dim=0).cpu()
139 | labels = list(torch.cat(labels, dim=0).cpu().numpy())
140 | label2index = {}
141 | for k, v in enumerate(labels):
142 | v = int(v)
143 | if v not in label2index.keys():
144 | label2index[v] = []
145 | label2index[v].append(k)
146 | else:
147 | label2index[v].append(k)
148 |
149 | label2optimal_proto = {}
150 | optimal_prototype = torch.zeros(64, 512).type_as(embs)
151 | for k, v in label2index.items():
152 | sub_embs = embs[v, :]
153 | label2optimal_proto[k] = torch.mean(sub_embs, dim=0).unsqueeze(dim=0)
154 | optimal_prototype[k] = torch.mean(sub_embs, dim=0)
155 |
156 | database = {'dict': label2optimal_proto, 'array': optimal_prototype}
157 | with open(os.path.join(opt.save_path, "mini_imagenet_optimal_prototype.pickle"), 'wb') as handle:
158 | pickle.dump(database, handle, protocol=pickle.HIGHEST_PROTOCOL)
159 |
160 | with open(os.path.join(opt.save_path, "mini_imagenet_optimal_prototype.pickle"), 'rb') as handle:
161 | part_prior = pickle.load(handle)
162 |     print('Reloaded optimal prototypes for {} classes'.format(len(part_prior['dict'])))
--------------------------------------------------------------------------------
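The "optimal prototypes" computed above are simply class-mean embeddings of the frozen backbone's training-set features; a self-contained sketch of the same operation on dummy data:

```python
import torch

embs = torch.randn(10, 512)  # dummy 512-d embeddings
labels = torch.tensor([0, 0, 1, 1, 1, 2, 2, 0, 1, 2])

prototypes = torch.zeros(int(labels.max()) + 1, 512)
for c in labels.unique():
    prototypes[int(c)] = embs[labels == c].mean(dim=0)  # class-mean embedding
print(prototypes.shape)  # torch.Size([3, 512])
```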
/metadiff_train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import argparse
4 | import random
5 | import numpy as np
6 | from tqdm import tqdm
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torch.utils.data import DataLoader
11 | from torch.autograd import Variable
12 |
13 | # from models.classification_heads import ClassificationHead
14 | from models.R2D2_embedding import R2D2Embedding
15 | from models.protonet_embedding import ProtoNetEmbedding
16 | from models.resnet12_2 import resnet12
17 | # from models.protonet_metanet1 import ProtonetMetaLearner
18 | # from models.classifier import LinearClassifier, NNClassifier, distLinear
19 | # from models.PredTrainHead import LinearRotateHead, DCLHead, DistRotateHead
20 | from models.diffusion_grad import get_diffusion_from_args
21 |
22 | from utils import set_gpu, Timer, count_accuracy, check_dir, log
23 |
24 | def one_hot(indices, depth):
25 | """
26 | Returns a one-hot tensor.
27 | This is a PyTorch equivalent of Tensorflow's tf.one_hot.
28 |
29 | Parameters:
30 | indices: a (n_batch, m) Tensor or (m) Tensor.
31 | depth: a scalar. Represents the depth of the one hot dimension.
32 | Returns: a (n_batch, m, depth) Tensor or (m, depth) Tensor.
33 | """
34 |
35 | encoded_indicies = torch.zeros(indices.size() + torch.Size([depth])).cuda()
36 | index = indices.view(indices.size()+torch.Size([1]))
37 |     encoded_indicies = encoded_indicies.scatter_(-1, index, 1)  # scatter along the class (last) dimension
38 |
39 | return encoded_indicies
40 |
41 | def get_model(options):
42 | # Choose the embedding network
43 | if options.network == 'ProtoNet':
44 | network = ProtoNetEmbedding().cuda()
45 | fea_dim = 64
46 | elif options.network == 'R2D2':
47 | network = R2D2Embedding().cuda()
48 | elif options.network == 'ResNet':
49 | network = resnet12().cuda()
50 | network = torch.nn.DataParallel(network)
51 | fea_dim = 512
52 | else:
53 | print ("Cannot recognize the network type")
54 | assert(False)
55 |
56 |
57 | # Choose the classification head
58 |
59 | cls_head = get_diffusion_from_args().cuda()
60 | return (network, cls_head)
61 |
62 | def get_dataset(options):
63 |     # Choose the dataset
64 | if options.dataset == 'miniImageNet':
65 | from data.mini_imagenet import MiniImageNet, FewShotDataloader
66 | dataset_train = MiniImageNet(phase='train')
67 | dataset_val = MiniImageNet(phase='val')
68 | data_loader = FewShotDataloader
69 | else:
70 | print ("Cannot recognize the dataset type")
71 | assert(False)
72 |
73 | return (dataset_train, dataset_val, data_loader)
74 |
75 | def seed_torch(seed=21):
76 | os.environ['PYTHONHASHSEED'] = str(seed)
77 | random.seed(seed)
78 | np.random.seed(seed)
79 | torch.manual_seed(seed)
80 | torch.cuda.manual_seed(seed)
81 | torch.cuda.manual_seed_all(seed)
82 | # torch.backends.cudnn.deterministic = True
83 |     torch.backends.cudnn.deterministic = True  # ensure consistent results across CPU/GPU runs
84 |     torch.backends.cudnn.benchmark = False  # benchmark speeds up training when input sizes barely vary; disabled here for determinism
85 |
86 | def train(opt, dataset_train, dataset_val, data_loader):
87 | # Dataloader of Gidaris & Komodakis (CVPR 2018)
88 | dloader_train = data_loader(
89 | dataset=dataset_train,
90 | nKnovel=opt.train_way,
91 | nKbase=0,
92 | nExemplars=opt.train_shot, # num training examples per novel category
93 | nTestNovel=opt.train_way * opt.train_query, # num test examples for all the novel categories
94 | nTestBase=0, # num test examples for all the base categories
95 | batch_size=opt.episodes_per_batch,
96 | num_workers=4,
97 | epoch_size=opt.episodes_per_batch * 10000, # num of batches per epoch
98 | )
99 |
100 | dloader_val = data_loader(
101 | dataset=dataset_val,
102 | nKnovel=opt.test_way,
103 | nKbase=0,
104 | nExemplars=opt.val_shot, # num training examples per novel category
105 | nTestNovel=opt.val_query * opt.test_way, # num test examples for all the novel categories
106 | nTestBase=0, # num test examples for all the base categories
107 | batch_size=1,
108 | num_workers=0,
109 | epoch_size=1 * opt.val_episode, # num of batches per epoch
110 | )
111 |
112 | set_gpu(opt.gpu)
113 | check_dir('./experiments/')
114 | check_dir(opt.save_path)
115 |
116 | log_file_path = os.path.join(opt.save_path, "train_log.txt")
117 | log(log_file_path, str(vars(opt)))
118 |
119 | (embedding_net, cls_head) = get_model(opt)
120 | # Load saved model checkpoints
121 | saved_models = torch.load(os.path.join(opt.save_path, 'best_pretrain_model_resnet.pth'))
122 | # saved_models = torch.load(os.path.join(opt.save_path, 'best_model2.pth'))
123 | embedding_net.load_state_dict(saved_models['embedding'])
124 | embedding_net.eval()
125 |
126 | import pickle
127 | with open(os.path.join(opt.save_path, "mini_imagenet_optimal_prototype.pickle"), 'rb') as handle:
128 | prototype_ground_true = pickle.load(handle)
129 | prototype_ground_true_dict = prototype_ground_true['dict']
130 | prototype_ground_true_array = prototype_ground_true['array'].cuda()
131 |
132 | optimizer = torch.optim.Adam(cls_head.parameters(), lr=0.0001, weight_decay=5e-4)
133 |
134 | lambda_epoch = lambda e: 1.0 if e < 100000 else (0.1 if e < 30 else 0.01 if e < 40 else (0.001))
135 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_epoch, last_epoch=-1)
136 |
137 | max_val_acc = 0.0
138 | max_test_acc = 0.0
139 |
140 | timer = Timer()
141 | x_entropy = torch.nn.CrossEntropyLoss()
142 |
143 | for epoch in range(1, opt.num_epoch + 1):
144 | # Train on the training split
145 | lr_scheduler.step()
146 |
147 | # Fetch the current epoch's learning rate
148 | epoch_learning_rate = 0.1
149 | for param_group in optimizer.param_groups:
150 | epoch_learning_rate = param_group['lr']
151 |
152 | log(log_file_path, 'Train Epoch: {}\tLearning Rate: {:.4f}'.format(
153 | epoch, epoch_learning_rate))
154 |
155 | _ = [x.train() for x in (cls_head, )]
156 | # _, _ = [x.train() for x in (cls_head, embedding_net)]
157 |
158 | train_accuracies = []
159 | train_losses = []
160 |
161 | for i, batch in enumerate(tqdm(dloader_train(epoch)), 1):
162 | data_support, labels_support, data_query, labels_query, all_k, _ = [x.cuda() for x in batch]
163 |
164 | # prototype_ground_true = torch.gather(prototype_ground_true_array, 0, all_k)
165 | prototype_ground_true = prototype_ground_true_array[all_k, :]
166 |
167 | train_n_support = opt.train_way * opt.train_shot
168 | train_n_query = opt.train_way * opt.train_query
169 |
170 | with torch.no_grad():
171 | emb_support = embedding_net(data_support.reshape([-1] + list(data_support.shape[-3:])))
172 | emb_support = emb_support.reshape(opt.episodes_per_batch, train_n_support, -1)
173 | # emb_support = F.normalize(emb_support, dim=-1)
174 |
175 | emb_query = embedding_net(data_query.reshape([-1] + list(data_query.shape[-3:])))
176 | emb_query = emb_query.reshape(opt.episodes_per_batch, train_n_query, -1)
177 | # emb_query = F.normalize(emb_query, dim=-1)
178 |
179 | loss = cls_head(prototype_ground_true, emb_query, emb_support, labels_support, labels_query, opt.train_way, opt.train_shot)
180 |
181 |
182 |
183 | if (i % 1000 == 0):
184 | logit_query = cls_head.sample(emb_query, emb_support, labels_support, labels_query, opt.train_way,
185 | opt.train_shot)
186 | # loss = x_entropy(logit_query.reshape(-1, opt.test_way), labels_query.reshape(-1))
187 | acc = count_accuracy(logit_query.reshape(-1, opt.test_way), labels_query.reshape(-1))
188 |
189 | train_accuracies.append(acc.item())
190 | train_losses.append(loss.item())
191 | train_acc_avg = np.mean(np.array(train_accuracies))
192 | log(log_file_path, 'Train Epoch: {}\tBatch: [{}]\tLoss: {}\tAccuracy: {} % ({} %)'.format(
193 | epoch, i, loss.item(), train_acc_avg, acc))
194 |
195 | optimizer.zero_grad()
196 | loss.backward()
197 | optimizer.step()
198 |
199 | # Evaluate on the validation split
200 | _, _ = [x.eval() for x in (cls_head, embedding_net)]
201 |
202 | val_accuracies = []
203 | val_losses = []
204 |
205 | for i, batch in enumerate(tqdm(dloader_val(epoch)), 1):
206 | data_support, labels_support, data_query, labels_query, _, _ = [x.cuda() for x in batch]
207 |
208 | test_n_support = opt.test_way * opt.val_shot
209 | test_n_query = opt.test_way * opt.val_query
210 | with torch.no_grad():
211 | emb_support = embedding_net(data_support.reshape([-1] + list(data_support.shape[-3:])))
212 | emb_support = emb_support.reshape(1, test_n_support, -1)
213 | # emb_support = F.normalize(emb_support, dim=-1)
214 | emb_query = embedding_net(data_query.reshape([-1] + list(data_query.shape[-3:])))
215 | emb_query = emb_query.reshape(1, test_n_query, -1)
216 | # emb_query = F.normalize(emb_query, dim=-1)
217 |
218 | logit_query = cls_head.sample(emb_query, emb_support, labels_support, labels_query, opt.test_way, opt.val_shot)
219 | loss = x_entropy(logit_query.reshape(-1, opt.test_way), labels_query.reshape(-1))
220 | acc = count_accuracy(logit_query.reshape(-1, opt.test_way), labels_query.reshape(-1))
221 |
222 | val_accuracies.append(acc.item())
223 | val_losses.append(loss.item())
224 |
225 | val_acc_avg = np.mean(np.array(val_accuracies))
226 | val_acc_ci95 = 1.96 * np.std(np.array(val_accuracies)) / np.sqrt(opt.val_episode)
227 |
228 | val_loss_avg = np.mean(np.array(val_losses))
229 |
230 | if val_acc_avg > max_val_acc:
231 | max_val_acc = val_acc_avg
232 | # torch.save({'embedding': embedding_net.state_dict(), 'head': cls_head.state_dict()}, \
233 | # os.path.join(opt.save_path, 'best_model_val_double_dir_opt_cub_{}_shot.pth'.format(opt.val_shot)))
234 | log(log_file_path, 'Validation Epoch: {}\t\t\tLoss: {:.4f}\tAccuracy: {:.2f} ± {:.2f} % (Best)' \
235 | .format(epoch, val_loss_avg, val_acc_avg, val_acc_ci95))
236 | else:
237 | log(log_file_path, 'Validation Epoch: {}\t\t\tLoss: {:.4f}\tAccuracy: {:.2f} ± {:.2f} %' \
238 | .format(epoch, val_loss_avg, val_acc_avg, val_acc_ci95))
239 |
240 |
241 | def test(opt, n_iter, dataset_train, dataset_val, dataset_test, data_loader):
242 | # Dataloader of Gidaris & Komodakis (CVPR 2018)
243 | dloader_test = data_loader(
244 | dataset=dataset_test,
245 | nKnovel=opt.test_way,
246 | nKbase=0,
247 | nExemplars=opt.val_shot, # num training examples per novel category
248 | nTestNovel=opt.val_query * opt.test_way, # num test examples for all the novel categories
249 | nTestBase=0, # num test examples for all the base categories
250 | batch_size=1,
251 | num_workers=0,
252 | epoch_size=1 * opt.val_episode, # num of batches per epoch
253 | )
254 |
255 | set_gpu(opt.gpu)
256 | check_dir('./experiments/')
257 | check_dir(opt.save_path)
258 |
259 | log_file_path = os.path.join(opt.save_path, "train_log.txt")
260 | log(log_file_path, str(vars(opt)))
261 |
262 | (embedding_net, cls_head) = get_model(opt)
263 | # Load saved model checkpoints
264 | saved_models = torch.load(os.path.join(opt.save_path, 'best_model_val_double_dir_opt_cub_5_shot.pth'))
265 | # saved_models = torch.load(os.path.join(opt.save_path, 'best_model2.pth'))
266 | embedding_net.load_state_dict(saved_models['embedding'])
267 | embedding_net.eval()
268 | cls_head.load_state_dict(saved_models['head'])
269 | cls_head.eval()
270 |
271 | max_val_acc = 0.0
272 | max_test_acc = 0.0
273 |
274 | timer = Timer()
275 | x_entropy = torch.nn.CrossEntropyLoss()
276 |
277 | # Evaluate on the validation split
278 | _, _ = [x.eval() for x in (cls_head, embedding_net)]
279 |
280 | test_accuracies = []
281 | test_losses = []
282 | epoch = 9
283 | for i, batch in enumerate(tqdm(dloader_test(epoch)), 1):
284 | data_support, labels_support, data_query, labels_query, _, _ = [x.cuda() for x in batch]
285 |
286 | test_n_support = opt.test_way * opt.val_shot
287 | test_n_query = opt.test_way * opt.val_query
288 |
289 | with torch.no_grad():
290 | emb_support = embedding_net(data_support.reshape([-1] + list(data_support.shape[-3:])))
291 | emb_support = emb_support.reshape(1, test_n_support, -1)
292 | # emb_support = F.normalize(emb_support, dim=-1)
293 | emb_query = embedding_net(data_query.reshape([-1] + list(data_query.shape[-3:])))
294 | emb_query = emb_query.reshape(1, test_n_query, -1)
295 | # emb_query = F.normalize(emb_query, dim=-1)
296 |
297 | logit_querys = cls_head(emb_query, emb_support, labels_support, labels_query, opt.test_way, opt.val_shot, is_train=False, update_step_test=n_iter)
298 | # logit_querys = logit_querys[-1:]
299 | for kk, logit_query in enumerate(logit_querys):
300 | loss = x_entropy(logit_query.reshape(-1, opt.test_way), labels_query.reshape(-1))
301 | acc = count_accuracy(logit_query.reshape(-1, opt.test_way), labels_query.reshape(-1))
302 |
303 | test_accuracies.append(acc.item())
304 | test_losses.append(loss.item())
305 |
306 | test_acc_avg = np.mean(np.array(test_accuracies))
307 | test_acc_ci95 = 1.96 * np.std(np.array(test_accuracies)) / np.sqrt(opt.val_episode)
308 |
309 | test_loss_avg = np.mean(np.array(test_losses))
310 |
311 | if test_acc_avg > max_test_acc:
312 | max_test_acc = test_acc_avg
313 | log(log_file_path, 'Test Loss: {:.4f}\tAccuracy: {:.2f} ± {:.2f} % (Best)' \
314 | .format(test_loss_avg, test_acc_avg, test_acc_ci95))
315 | else:
316 | log(log_file_path, 'Test Loss: {:.4f}\tAccuracy: {:.2f} ± {:.2f} %' \
317 | .format(test_loss_avg, test_acc_avg, test_acc_ci95))
318 |
319 | if __name__ == '__main__':
320 | seed_torch(21)
321 | parser = argparse.ArgumentParser()
322 | parser.add_argument('--num-epoch', type=int, default=1000,
323 | help='number of training epochs')
324 | parser.add_argument('--save-epoch', type=int, default=10,
325 | help='frequency of model saving')
326 | parser.add_argument('--train-shot', type=int, default=5,
327 | help='number of support examples per training class')
328 | parser.add_argument('--val-shot', type=int, default=5,
329 | help='number of support examples per validation class')
330 | parser.add_argument('--train-query', type=int, default=15,
331 | help='number of query examples per training class')
332 | parser.add_argument('--val-episode', type=int, default=600,
333 | help='number of episodes per validation')
334 | parser.add_argument('--val-query', type=int, default=15,
335 | help='number of query examples per validation class')
336 | parser.add_argument('--train-way', type=int, default=5,
337 | help='number of classes in one training episode')
338 | parser.add_argument('--test-way', type=int, default=5,
339 | help='number of classes in one test (or validation) episode')
340 | parser.add_argument('--save-path', default='./experiments/exp_1')
341 | parser.add_argument('--gpu', default='0')
342 | parser.add_argument('--network', type=str, default='ResNet',
343 | help='choose which embedding network to use. ProtoNet, R2D2, ResNet')
344 | parser.add_argument('--head', type=str, default='CosineNet',
345 | help='choose which classification head to use. ProtoNet, Ridge, R2D2, SVM')
346 | parser.add_argument('--pre_head', type=str, default='LinearRotateNet',
347 | help='choose which classification head to use. ProtoNet, Ridge, R2D2, SVM')
348 | parser.add_argument('--dataset', type=str, default='miniImageNet',
349 |                         help='choose which dataset to use. miniImageNet, tieredImageNet, CIFAR_FS, FC100')
350 | parser.add_argument('--episodes-per-batch', type=int, default=8,
351 | help='number of episodes per batch')
352 | parser.add_argument('--eps', type=float, default=0.0,
353 | help='epsilon of label smoothing')
354 |
355 |
356 | opt = parser.parse_args()
357 |
358 | (dataset_train, dataset_val, data_loader) = get_dataset(opt)
359 |
360 | train(opt, dataset_train, dataset_val, data_loader)
--------------------------------------------------------------------------------
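Before this script can run, `--save-path` is expected to already contain a pretrained backbone checkpoint and the prototype pickle produced by `cal_optial_prototypes.py`. A rough sketch of what it loads (the paths are the defaults and are assumptions about your setup):

```python
import os
import pickle
import torch

save_path = './experiments/exp_1'

ckpt = torch.load(os.path.join(save_path, 'best_pretrain_model_resnet.pth'), map_location='cpu')
print(list(ckpt.keys()))      # expected to contain at least 'embedding'

with open(os.path.join(save_path, 'mini_imagenet_optimal_prototype.pickle'), 'rb') as f:
    protos = pickle.load(f)
print(protos['array'].shape)  # torch.Size([64, 512]): one prototype per training class
```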
/data/mini_imagenet.py:
--------------------------------------------------------------------------------
1 | # Dataloader of Gidaris & Komodakis, CVPR 2018
2 | # Adapted from:
3 | # https://github.com/gidariss/FewShotWithoutForgetting/blob/master/dataloader.py
4 | from __future__ import print_function
5 |
6 | import os
7 | import os.path
8 | import numpy as np
9 | import random
10 | import pickle
11 | import json
12 | import math
13 |
14 | import torch
15 | import torch.utils.data as data
16 | import torchvision
17 | import torchvision.datasets as datasets
18 | import torchvision.transforms as transforms
19 | import torchnet as tnt
20 |
21 | import h5py
22 |
23 | from PIL import Image
24 | from PIL import ImageEnhance
25 |
26 | from pdb import set_trace as breakpoint
27 |
28 |
29 | # Set the appropriate paths of the datasets here.
30 | _MINI_IMAGENET_DATASET_DIR = '../datasets/few_shot_data/MiniImagenet'
31 |
32 | def buildLabelIndex(labels):
33 | label2inds = {}
34 | for idx, label in enumerate(labels):
35 | if label not in label2inds:
36 | label2inds[label] = []
37 | label2inds[label].append(idx)
38 |
39 | return label2inds
40 |
41 |
42 | def load_data(file):
43 | try:
44 | with open(file, 'rb') as fo:
45 | data = pickle.load(fo)
46 | return data
47 | except:
48 | with open(file, 'rb') as f:
49 | u = pickle._Unpickler(f)
50 | u.encoding = 'latin1'
51 | data = u.load()
52 | return data
53 |
54 | class MiniImageNet(data.Dataset):
55 | def __init__(self, phase='train', do_not_use_random_transf=False):
56 |
57 | self.base_folder = 'miniImagenet'
58 | assert(phase=='train' or phase=='val' or phase=='test' or phase=='all_train')
59 | self.phase = phase
60 | self.name = 'MiniImageNet_' + phase
61 |
62 | print('Loading mini ImageNet dataset - phase {0}'.format(phase))
63 | file_train_categories_train_phase = os.path.join(
64 | _MINI_IMAGENET_DATASET_DIR,
65 | 'miniImageNet_category_split_train_phase_train.pickle')
66 | file_train_categories_val_phase = os.path.join(
67 | _MINI_IMAGENET_DATASET_DIR,
68 | 'miniImageNet_category_split_train_phase_val.pickle')
69 | file_train_categories_test_phase = os.path.join(
70 | _MINI_IMAGENET_DATASET_DIR,
71 | 'miniImageNet_category_split_train_phase_test.pickle')
72 | file_val_categories_val_phase = os.path.join(
73 | _MINI_IMAGENET_DATASET_DIR,
74 | 'miniImageNet_category_split_val.pickle')
75 | file_test_categories_test_phase = os.path.join(
76 | _MINI_IMAGENET_DATASET_DIR,
77 | 'miniImageNet_category_split_test.pickle')
78 |
79 | if self.phase=='train':
80 | # During training phase we only load the training phase images
81 | # of the training categories (aka base categories).
82 | data_train = load_data(file_train_categories_train_phase)
83 | self.data = data_train['data']
84 | self.labels = data_train['labels']
85 |
86 | self.label2ind = buildLabelIndex(self.labels)
87 | self.labelIds = sorted(self.label2ind.keys())
88 | self.num_cats = len(self.labelIds)
89 | self.labelIds_base = self.labelIds
90 | self.num_cats_base = len(self.labelIds_base)
91 |
92 | elif self.phase=='val' or self.phase=='test':
93 | if self.phase=='test':
94 | # load data that will be used for evaluating the recognition
95 | # accuracy of the base categories.
96 | data_base = load_data(file_train_categories_test_phase)
97 |                 # load data that will be used for evaluating the few-shot recognition
98 | # accuracy on the novel categories.
99 | data_novel = load_data(file_test_categories_test_phase)
100 | else: # phase=='val'
101 | # load data that will be used for evaluating the recognition
102 | # accuracy of the base categories.
103 | data_base = load_data(file_train_categories_val_phase)
104 |                 # load data that will be used for evaluating the few-shot recognition
105 | # accuracy on the novel categories.
106 | data_novel = load_data(file_val_categories_val_phase)
107 |
108 | self.data = np.concatenate(
109 | [data_base['data'], data_novel['data']], axis=0)
110 | self.labels = data_base['labels'] + data_novel['labels']
111 |
112 | self.label2ind = buildLabelIndex(self.labels)
113 | self.labelIds = sorted(self.label2ind.keys())
114 | self.num_cats = len(self.labelIds)
115 |
116 | self.labelIds_base = buildLabelIndex(data_base['labels']).keys()
117 | self.labelIds_novel = buildLabelIndex(data_novel['labels']).keys()
118 | self.num_cats_base = len(self.labelIds_base)
119 | self.num_cats_novel = len(self.labelIds_novel)
120 | intersection = set(self.labelIds_base) & set(self.labelIds_novel)
121 | assert(len(intersection) == 0)
122 | elif self.phase=='all_train':
123 | data_train = load_data(file_train_categories_train_phase)
124 | data_test = load_data(file_test_categories_test_phase)
125 | data_val = load_data(file_val_categories_val_phase)
126 |
127 | self.data = np.concatenate(
128 | [data_train['data'], data_test['data'], data_val['data']], axis=0)
129 | self.labels = data_train['labels'] + data_test['labels']+ data_val['labels']
130 |
131 | self.label2ind = buildLabelIndex(self.labels)
132 | self.labelIds = sorted(self.label2ind.keys())
133 | self.num_cats = len(self.labelIds)
134 | self.labelIds_base = self.labelIds
135 | self.num_cats_base = len(self.labelIds_base)
136 | # assert(len(intersection) == 0)
137 | else:
138 | raise ValueError('Not valid phase {0}'.format(self.phase))
139 |
140 | mean_pix = [x/255.0 for x in [120.39586422, 115.59361427, 104.54012653]]
141 | std_pix = [x/255.0 for x in [70.68188272, 68.27635443, 72.54505529]]
142 | normalize = transforms.Normalize(mean=mean_pix, std=std_pix)
143 |
144 | if (self.phase=='test' or self.phase=='val') or (do_not_use_random_transf==True):
145 | self.transform = transforms.Compose([
146 | # transforms.RandomCrop(84, padding=8),
147 | lambda x: np.asarray(x),
148 | transforms.ToTensor(),
149 | normalize
150 | ])
151 | else:
152 | self.transform = transforms.Compose([
153 | transforms.RandomCrop(84, padding=8),
154 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
155 | transforms.RandomHorizontalFlip(),
156 | lambda x: np.asarray(x),
157 | transforms.ToTensor(),
158 | normalize
159 | ])
160 |
161 | def __getitem__(self, index):
162 | img, label = self.data[index], self.labels[index]
163 | # doing this so that it is consistent with all other datasets
164 | # to return a PIL Image
165 | img = Image.fromarray(img)
166 | if self.transform is not None:
167 | img = self.transform(img)
168 | return img, label
169 |
170 | def __len__(self):
171 | return len(self.data)
172 |
173 |
174 | class FewShotDataloader():
175 | def __init__(self,
176 | dataset,
177 | nKnovel=5, # number of novel categories.
178 | nKbase=-1, # number of base categories.
179 | nExemplars=1, # number of training examples per novel category.
180 | nTestNovel=15*5, # number of test examples for all the novel categories.
181 | nTestBase=15*5, # number of test examples for all the base categories.
182 | batch_size=1, # number of training episodes per batch.
183 | num_workers=4,
184 | epoch_size=2000, # number of batches per epoch.
185 | ):
186 |
187 | self.dataset = dataset
188 | self.phase = self.dataset.phase
189 | max_possible_nKnovel = (self.dataset.num_cats_base if self.phase=='train'
190 | else self.dataset.num_cats_novel)
191 | assert(nKnovel >= 0 and nKnovel < max_possible_nKnovel)
192 | self.nKnovel = nKnovel
193 |
194 | max_possible_nKbase = self.dataset.num_cats_base
195 | nKbase = nKbase if nKbase >= 0 else max_possible_nKbase
196 | if self.phase=='train' and nKbase > 0:
197 | nKbase -= self.nKnovel
198 | max_possible_nKbase -= self.nKnovel
199 |
200 | assert(nKbase >= 0 and nKbase <= max_possible_nKbase)
201 | self.nKbase = nKbase
202 |
203 | self.nExemplars = nExemplars
204 | self.nTestNovel = nTestNovel
205 | self.nTestBase = nTestBase
206 | self.batch_size = batch_size
207 | self.epoch_size = epoch_size
208 | self.num_workers = num_workers
209 | self.is_eval_mode = (self.phase=='test') or (self.phase=='val')
210 |
211 | def sampleImageIdsFrom(self, cat_id, sample_size=1):
212 | """
213 | Samples `sample_size` number of unique image ids picked from the
214 | category `cat_id` (i.e., self.dataset.label2ind[cat_id]).
215 |
216 | Args:
217 | cat_id: a scalar with the id of the category from which images will
218 | be sampled.
219 | sample_size: number of images that will be sampled.
220 |
221 | Returns:
222 | image_ids: a list of length `sample_size` with unique image ids.
223 | """
224 | assert(cat_id in self.dataset.label2ind)
225 | assert(len(self.dataset.label2ind[cat_id]) >= sample_size)
226 | # Note: random.sample samples elements without replacement.
227 | return random.sample(self.dataset.label2ind[cat_id], sample_size)
228 |
229 | def sampleCategories(self, cat_set, sample_size=1):
230 | """
231 | Samples `sample_size` number of unique categories picked from the
232 | `cat_set` set of categories. `cat_set` can be either 'base' or 'novel'.
233 |
234 | Args:
235 | cat_set: string that specifies the set of categories from which
236 | categories will be sampled.
237 | sample_size: number of categories that will be sampled.
238 |
239 | Returns:
240 | cat_ids: a list of length `sample_size` with unique category ids.
241 | """
242 | if cat_set=='base':
243 | labelIds = self.dataset.labelIds_base
244 | elif cat_set=='novel':
245 | labelIds = self.dataset.labelIds_novel
246 | else:
247 | raise ValueError('Not recognized category set {}'.format(cat_set))
248 |
249 | assert(len(labelIds) >= sample_size)
250 | # return sample_size unique categories chosen from labelIds set of
251 | # categories (that can be either self.labelIds_base or self.labelIds_novel)
252 | # Note: random.sample samples elements without replacement.
253 | return random.sample(labelIds, sample_size)
254 |
255 | def sample_base_and_novel_categories(self, nKbase, nKnovel):
256 | """
257 | Samples `nKbase` number of base categories and `nKnovel` number of novel
258 | categories.
259 |
260 | Args:
261 | nKbase: number of base categories
262 | nKnovel: number of novel categories
263 |
264 | Returns:
265 | Kbase: a list of length 'nKbase' with the ids of the sampled base
266 | categories.
267 |             Knovel: a list of length 'nKnovel' with the ids of the sampled novel
268 | categories.
269 | """
270 | if self.is_eval_mode:
271 | assert(nKnovel <= self.dataset.num_cats_novel)
272 | # sample from the set of base categories 'nKbase' number of base
273 | # categories.
274 | Kbase = sorted(self.sampleCategories('base', nKbase))
275 | # sample from the set of novel categories 'nKnovel' number of novel
276 | # categories.
277 | Knovel = sorted(self.sampleCategories('novel', nKnovel))
278 | else:
279 | # sample from the set of base categories 'nKnovel' + 'nKbase' number
280 | # of categories.
281 | cats_ids = self.sampleCategories('base', nKnovel+nKbase)
282 | assert(len(cats_ids) == (nKnovel+nKbase))
283 | # Randomly pick 'nKnovel' number of fake novel categories and keep
284 | # the rest as base categories.
285 | random.shuffle(cats_ids)
286 | Knovel = sorted(cats_ids[:nKnovel])
287 | Kbase = sorted(cats_ids[nKnovel:])
288 |
289 | return Kbase, Knovel
290 |
291 | def sample_test_examples_for_base_categories(self, Kbase, nTestBase):
292 | """
293 | Sample `nTestBase` number of images from the `Kbase` categories.
294 |
295 | Args:
296 | Kbase: a list of length `nKbase` with the ids of the categories from
297 | where the images will be sampled.
298 | nTestBase: the total number of images that will be sampled.
299 |
300 | Returns:
301 | Tbase: a list of length `nTestBase` with 2-element tuples. The 1st
302 | element of each tuple is the image id that was sampled and the
303 |                 2nd element is its category label (which is in the range
304 | [0, len(Kbase)-1]).
305 | """
306 | Tbase = []
307 | if len(Kbase) > 0:
308 | # Sample for each base category a number images such that the total
309 | # number sampled images of all categories to be equal to `nTestBase`.
310 | KbaseIndices = np.random.choice(
311 | np.arange(len(Kbase)), size=nTestBase, replace=True)
312 | KbaseIndices, NumImagesPerCategory = np.unique(
313 | KbaseIndices, return_counts=True)
314 |
315 | for Kbase_idx, NumImages in zip(KbaseIndices, NumImagesPerCategory):
316 | imd_ids = self.sampleImageIdsFrom(
317 | Kbase[Kbase_idx], sample_size=NumImages)
318 | Tbase += [(img_id, Kbase_idx) for img_id in imd_ids]
319 |
320 | assert(len(Tbase) == nTestBase)
321 |
322 | return Tbase
323 |
324 | def sample_train_and_test_examples_for_novel_categories(
325 | self, Knovel, nTestNovel, nExemplars, nKbase):
326 | """Samples train and test examples of the novel categories.
327 |
328 | Args:
329 | Knovel: a list with the ids of the novel categories.
330 | nTestNovel: the total number of test images that will be sampled
331 | from all the novel categories.
332 | nExemplars: the number of training examples per novel category that
333 | will be sampled.
334 | nKbase: the number of base categories. It is used as offset of the
335 | category index of each sampled image.
336 |
337 | Returns:
338 | Tnovel: a list of length `nTestNovel` with 2-element tuples. The
339 | 1st element of each tuple is the image id that was sampled and
340 | the 2nd element is its category label (which is in the range
341 | [nKbase, nKbase + len(Knovel) - 1]).
342 | Exemplars: a list of length len(Knovel) * nExemplars of 2-element
343 | tuples. The 1st element of each tuple is the image id that was
344 | sampled and the 2nd element is its category label (which is in
345 |                 the range [nKbase, nKbase + len(Knovel) - 1]).
346 | """
347 |
348 | if len(Knovel) == 0:
349 | return [], []
350 |
351 | nKnovel = len(Knovel)
352 | Tnovel = []
353 | Exemplars = []
354 | assert((nTestNovel % nKnovel) == 0)
355 | nEvalExamplesPerClass = int(nTestNovel / nKnovel)
356 |
357 | for Knovel_idx in range(len(Knovel)):
358 | imd_ids = self.sampleImageIdsFrom(
359 | Knovel[Knovel_idx],
360 | sample_size=(nEvalExamplesPerClass + nExemplars))
361 |
362 | imds_tnovel = imd_ids[:nEvalExamplesPerClass]
363 | imds_ememplars = imd_ids[nEvalExamplesPerClass:]
364 |
365 | Tnovel += [(img_id, nKbase+Knovel_idx) for img_id in imds_tnovel]
366 | Exemplars += [(img_id, nKbase+Knovel_idx) for img_id in imds_ememplars]
367 | assert(len(Tnovel) == nTestNovel)
368 | assert(len(Exemplars) == len(Knovel) * nExemplars)
369 | random.shuffle(Exemplars)
370 |
371 | return Tnovel, Exemplars
372 |
373 | def sample_episode(self):
374 | """Samples a training episode."""
375 | nKnovel = self.nKnovel
376 | nKbase = self.nKbase
377 | nTestNovel = self.nTestNovel
378 | nTestBase = self.nTestBase
379 | nExemplars = self.nExemplars
380 |
381 | Kbase, Knovel = self.sample_base_and_novel_categories(nKbase, nKnovel)
382 | Tbase = self.sample_test_examples_for_base_categories(Kbase, nTestBase)
383 | Tnovel, Exemplars = self.sample_train_and_test_examples_for_novel_categories(
384 | Knovel, nTestNovel, nExemplars, nKbase)
385 |
386 | # concatenate the base and novel category examples.
387 | Test = Tbase + Tnovel
388 | random.shuffle(Test)
389 | Kall = Kbase + Knovel
390 |
391 | return Exemplars, Test, Kall, nKbase
392 |
393 | def createExamplesTensorData(self, examples):
394 | """
395 | Creates the examples image and label tensor data.
396 |
397 | Args:
398 | examples: a list of 2-element tuples, each representing a
399 | train or test example. The 1st element of each tuple
400 | is the image id of the example and 2nd element is the
401 | category label of the example, which is in the range
402 | [0, nK - 1], where nK is the total number of categories
403 | (both novel and base).
404 |
405 | Returns:
406 | images: a tensor of shape [nExamples, Height, Width, 3] with the
407 | example images, where nExamples is the number of examples
408 | (i.e., nExamples = len(examples)).
409 | labels: a tensor of shape [nExamples] with the category label
410 | of each example.
411 | """
412 | images = torch.stack(
413 | [self.dataset[img_idx][0] for img_idx, _ in examples], dim=0)
414 | labels = torch.LongTensor([label for _, label in examples])
415 | return images, labels
416 |
417 |     def get_iterator(self, epoch=0):
418 | rand_seed = 8
419 | random.seed(rand_seed)
420 | np.random.seed(rand_seed)
421 | def load_function(iter_idx):
422 | Exemplars, Test, Kall, nKbase = self.sample_episode()
423 | Xt, Yt = self.createExamplesTensorData(Test)
424 | Kall = torch.LongTensor(Kall)
425 | if len(Exemplars) > 0:
426 | Xe, Ye = self.createExamplesTensorData(Exemplars)
427 | return Xe, Ye, Xt, Yt, Kall, nKbase
428 | else:
429 | return Xt, Yt, Kall, nKbase
430 |
431 | tnt_dataset = tnt.dataset.ListDataset(
432 | elem_list=range(self.epoch_size), load=load_function)
433 | data_loader = tnt_dataset.parallel(
434 | batch_size=self.batch_size,
435 | num_workers=(0 if self.is_eval_mode else self.num_workers),
436 | shuffle=(False if self.is_eval_mode else True))
437 |
438 | return data_loader
439 |
440 | def __call__(self, epoch=0):
441 | return self.get_iterator(epoch)
442 |
443 | def __len__(self):
444 |         return self.epoch_size // self.batch_size
445 |
--------------------------------------------------------------------------------
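How the training and test scripts consume this dataloader, as a minimal sketch (assumes `_MINI_IMAGENET_DATASET_DIR` points at the prepared pickle files):

```python
from data.mini_imagenet import MiniImageNet, FewShotDataloader

dataset = MiniImageNet(phase='val')
loader = FewShotDataloader(dataset, nKnovel=5, nKbase=0, nExemplars=1,
                           nTestNovel=15 * 5, nTestBase=0,
                           batch_size=1, num_workers=0, epoch_size=10)

for Xe, Ye, Xt, Yt, Kall, nKbase in loader(epoch=0):
    # support images: [1, 5, 3, 84, 84], query images: [1, 75, 3, 84, 84]
    print(Xe.shape, Ye.shape, Xt.shape, Yt.shape)
    break
```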
/models/diffusion_grad.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from functools import partial
7 | from copy import deepcopy
8 |
9 | import argparse
10 | import torchvision
11 | from .attention import SelfAttention, CrossAttention
12 | # from models.unet import UNet
13 |
14 | # from .ema import EMA
15 | # from .utils import extract
16 |
17 | class TimeEmbedding(nn.Module):
18 | def __init__(self, n_embd):
19 | super().__init__()
20 | self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
21 | self.linear_2 = nn.Linear(4 * n_embd, 4 * n_embd)
22 |
23 | def forward(self, x):
24 | x = self.linear_1(x)
25 | x = F.silu(x)
26 | x = self.linear_2(x)
27 | return x
28 |
29 |
30 | class ResidualBlock(nn.Module):
31 | def __init__(self, in_channels, out_channels, n_time=320):
32 | super().__init__()
33 | self.groupnorm_feature = nn.GroupNorm(32, in_channels)
34 | self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
35 | self.linear_time = nn.Linear(n_time, out_channels)
36 |
37 | self.groupnorm_merged = nn.GroupNorm(32, out_channels)
38 | self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
39 |
40 | if in_channels == out_channels:
41 | self.residual_layer = nn.Identity()
42 | else:
43 | self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
44 |
45 | def forward(self, feature, time):
46 | residue = feature
47 |
48 | feature = self.groupnorm_feature(feature)
49 | feature = F.silu(feature)
50 | feature = self.conv_feature(feature)
51 |
52 | time = F.silu(time)
53 | time = self.linear_time(time)
54 |
55 | merged = feature + time.unsqueeze(-1).unsqueeze(-1)
56 | merged = self.groupnorm_merged(merged)
57 | merged = F.silu(merged)
58 | merged = self.conv_merged(merged)
59 |
60 | return merged + self.residual_layer(residue)
61 |
62 |
63 | class AttentionBlock(nn.Module):
64 | def __init__(self, n_head: int, n_embd: int, d_context=512):
65 | super().__init__()
66 | channels = n_head * n_embd
67 |
68 | self.groupnorm = nn.GroupNorm(32, channels, eps=1e-6)
69 | self.conv_input = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
70 |
71 | self.layernorm_1 = nn.LayerNorm(channels)
72 | self.attention_1 = SelfAttention(n_head, channels, in_proj_bias=False)
73 | self.layernorm_2 = nn.LayerNorm(channels)
74 | self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False)
75 | self.layernorm_3 = nn.LayerNorm(channels)
76 | self.linear_geglu_1 = nn.Linear(channels, 4 * channels * 2)
77 | self.linear_geglu_2 = nn.Linear(4 * channels, channels)
78 |
79 | self.conv_output = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
80 |
81 | def forward(self, x, context):
82 | residue_long = x
83 |
84 | x = self.groupnorm(x)
85 | x = self.conv_input(x)
86 |
87 | n, c, h, w = x.shape
88 | x = x.view((n, c, h * w)) # (n, c, hw)
89 | x = x.transpose(-1, -2) # (n, hw, c)
90 |
91 | residue_short = x
92 | x = self.layernorm_1(x)
93 | x = self.attention_1(x)
94 | x += residue_short
95 |
96 | residue_short = x
97 | x = self.layernorm_2(x)
98 | x = self.attention_2(x, context)
99 | x += residue_short
100 |
101 | residue_short = x
102 | x = self.layernorm_3(x)
103 | x, gate = self.linear_geglu_1(x).chunk(2, dim=-1)
104 | x = x * F.gelu(gate)
105 | x = self.linear_geglu_2(x)
106 | x += residue_short
107 |
108 | x = x.transpose(-1, -2) # (n, c, hw)
109 | x = x.view((n, c, h, w)) # (n, c, h, w)
110 |
111 | return self.conv_output(x) + residue_long
112 |
113 |
114 | class Upsample(nn.Module):
115 | def __init__(self, channels):
116 | super().__init__()
117 | self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
118 |
119 | def forward(self, x):
120 | x = F.interpolate(x, scale_factor=2, mode='nearest')
121 | return self.conv(x)
122 |
123 |
124 | class SwitchSequential(nn.Sequential):
125 | def forward(self, x, context, time):
126 | for layer in self:
127 | if isinstance(layer, AttentionBlock):
128 | x = layer(x, context)
129 | elif isinstance(layer, ResidualBlock):
130 | x = layer(x, time)
131 | else:
132 | x = layer(x)
133 | return x
134 |
135 |
136 | class UNet(nn.Module):
137 | def __init__(self):
138 | super().__init__()
139 | self.encoders = nn.ModuleList([
140 | SwitchSequential(ResidualBlock(512, 256), AttentionBlock(8, 32)),
141 | SwitchSequential(ResidualBlock(256, 128), AttentionBlock(8, 16)),
142 | ])
143 | self.bottleneck = SwitchSequential(
144 | ResidualBlock(128, 128),
145 | AttentionBlock(8, 16),
146 | ResidualBlock(128, 128),
147 | )
148 | self.decoders = nn.ModuleList([
149 | SwitchSequential(ResidualBlock(128+128, 256), AttentionBlock(8, 32)),
150 | SwitchSequential(ResidualBlock(256+256, 512), AttentionBlock(8, 64)),
151 | ])
152 |
153 | self.pos_emb = nn.Linear(5, 512, bias=False)
154 |
155 | def get_time_embedding(self, timestep):
156 | freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160).cuda()
157 | x = timestep[:, None] * freqs[None]
158 | return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
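    # For reference (added note): this is the standard sinusoidal timestep embedding.
    # With frequencies f_i = 10000^(-i/160), i = 0..159, it returns
    #   emb(t) = [cos(t * f_0), ..., cos(t * f_159), sin(t * f_0), ..., sin(t * f_159)],
    # a 320-d vector that ResidualBlock.linear_time later projects to each block's channel width.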
159 |
160 | def forward(self, x, time = None, emb_query = None, emb_support = None, labels_support = None, labels_query = None, n_way = None, k_shot = None):
161 | time = self.get_time_embedding(time)
162 |
163 | tasks_per_batch = emb_query.size(0)
164 | n_support = emb_support.size(1)
165 |
166 |
167 | support_labels_one_hot = one_hot(labels_support.reshape(-1), n_way)
168 | support_labels_one_hot = support_labels_one_hot.view(tasks_per_batch, n_support, n_way)
169 |
170 | pos_emb = self.pos_emb(support_labels_one_hot)
171 | context = emb_support + pos_emb
172 |
173 | prototype_labels = torch.LongTensor(list(range(n_way))).cuda()
174 | prototype_labels = one_hot(prototype_labels, n_way).view(1, n_way, n_way)
175 | prototype_pos_emb = self.pos_emb(prototype_labels)
176 |
177 | x = x + prototype_pos_emb.permute(0, 2, 1).unsqueeze(-1).repeat(tasks_per_batch, 1, 1, 1)
178 |
179 |
180 | skip_connections = []
181 | for layers in self.encoders:
182 | x = layers(x, context, time)
183 | skip_connections.append(x)
184 |
185 | x = self.bottleneck(x, context, time)
186 |
187 | for layers in self.decoders:
188 | x = torch.cat((x, skip_connections.pop()), dim=1)
189 | x = layers(x, context, time)
190 |
191 | return x
192 |
193 |
194 | class FinalLayer(nn.Module):
195 | def __init__(self, in_channels, out_channels):
196 | super().__init__()
197 | self.groupnorm = nn.GroupNorm(32, in_channels)
198 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
199 |
200 | def forward(self, x):
201 | x = self.groupnorm(x)
202 | x = F.silu(x)
203 | x = self.conv(x)
204 | return x
205 |
206 | def one_hot(indices, depth):
207 | """
208 | Returns a one-hot tensor.
209 | This is a PyTorch equivalent of Tensorflow's tf.one_hot.
210 |
211 | Parameters:
212 | indices: a (n_batch, m) Tensor or (m) Tensor.
213 | depth: a scalar. Represents the depth of the one hot dimension.
214 | Returns: a (n_batch, m, depth) Tensor or (m, depth) Tensor.
215 | """
216 |
217 |     encoded_indices = torch.zeros(indices.size() + torch.Size([depth])).cuda()
218 |     index = indices.view(indices.size() + torch.Size([1]))
219 |     encoded_indices = encoded_indices.scatter_(-1, index, 1)  # scatter on the last dim so both (m,) and (n_batch, m) inputs work
220 |
221 |     return encoded_indices
222 |
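# Illustrative example (added): one_hot(torch.tensor([2, 0]).cuda(), 5) returns
#   [[0., 0., 1., 0., 0.],
#    [1., 0., 0., 0., 0.]]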
223 |
224 | def extract(a, t, x_shape):
225 | b, *_ = t.shape
226 | out = a.gather(-1, t)
227 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
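# extract() gathers the per-sample schedule coefficient a[t] and reshapes it to
# (b, 1, 1, ...) so it broadcasts against x of shape x_shape; e.g. with x of shape
# (B, 512, 5, 1) the result has shape (B, 1, 1, 1).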
228 |
229 | class EMA():
230 | def __init__(self, decay):
231 | self.decay = decay
232 |
233 | def update_average(self, old, new):
234 | if old is None:
235 | return new
236 | return old * self.decay + (1 - self.decay) * new
237 |
238 | def update_model_average(self, ema_model, current_model):
239 | for current_params, ema_params in zip(current_model.parameters(), ema_model.parameters()):
240 | old, new = ema_params.data, current_params.data
241 | ema_params.data = self.update_average(old, new)
242 |
243 | class DMFunc(nn.Module):
244 |
245 | def __init__(self):
246 | super(DMFunc, self).__init__()
247 | self.nfe = 0
248 | self.scale_factor = 10
249 | self.time_scale = nn.Sequential(
250 | nn.Linear(320, 160),
251 | nn.ReLU(),
252 | nn.Linear(160, 1),
253 | nn.Sigmoid()
254 | )
255 | # self.res = nn.Sequential(
256 | # nn.Linear(512+320, 256),
257 | # nn.ReLU(),
258 | # nn.Linear(256, 512),
259 | # )
260 |
261 | self.time_emb1 = nn.Linear(320, 256)
262 | self.proto_inf1 = nn.Linear(512, 256)
263 |
264 | self.time_emb2 = nn.Linear(320, 128)
265 | self.proto_inf2 = nn.Linear(256, 128)
266 |
267 | self.time_emb3 = nn.Linear(320, 128)
268 | self.proto_inf3 = nn.Linear(128, 128)
269 |
270 | self.time_emb4 = nn.Linear(320, 256)
271 | self.proto_inf4 = nn.Linear(128, 256)
272 |
273 | self.time_emb5 = nn.Linear(320, 512)
274 | self.proto_inf5 = nn.Linear(256, 512)
275 |
276 | # self.time_emb1 = nn.Linear(320, 256//4)
277 | # self.proto_inf1 = nn.Linear(512//4, 256//4)
278 | #
279 | # self.time_emb2 = nn.Linear(320, 128//4)
280 | # self.proto_inf2 = nn.Linear(256//4, 128//4)
281 | #
282 | # self.time_emb3 = nn.Linear(320, 128//4)
283 | # self.proto_inf3 = nn.Linear(128//4, 128//4)
284 | #
285 | # self.time_emb4 = nn.Linear(320, 256//4)
286 | # self.proto_inf4 = nn.Linear(128//4, 256//4)
287 | #
288 | # self.time_emb5 = nn.Linear(320, 512//4)
289 | # self.proto_inf5 = nn.Linear(256//4, 512//4)
290 |
291 | def get_time_embedding(self, timestep):
292 | freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160).cuda()
293 | x = timestep[:, None] * freqs[None]
294 | return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
295 |
296 | def forward(self, x, time, emb_query, emb_support, labels_support, labels_query, n_way, k_shot):
297 | tasks_per_batch = emb_support.shape[0]
298 | n_support = emb_support.shape[1]
299 | support_labels = one_hot(labels_support.view(tasks_per_batch * n_support), n_way)
300 | support_labels = support_labels.view(tasks_per_batch, n_support, -1)
301 |
302 | # encoding time
303 | time = self.get_time_embedding(time)
304 | scale = self.time_scale(time)
305 |
306 | nb, nd, nw, _ = x.shape
307 | prototypes = x.detach()[:, :, :, 0].permute(0, 2, 1)
308 |
309 | #
310 | weights = prototypes
311 | weights.requires_grad = True
312 |
313 | # logits = 10 * F.cosine_similarity(
314 | # emb_support.unsqueeze(2).expand(-1, -1, n_way, -1),
315 | # weights.unsqueeze(1).expand(-1, emb_support.shape[1], -1, -1),
316 | # dim=-1) + torch.sum(weights*weights) * 5e-4
317 | # loss = nn.MSELoss()(F.softmax(logits.reshape(-1, n_way), dim=1), support_labels.reshape(-1, n_way))
318 | # loss = nn.CrossEntropyLoss()(logits.reshape(-1, n_way), labels_support.reshape(-1))
319 |
320 | diff = weights.unsqueeze(1).expand(-1, emb_support.shape[1], -1, -1) - emb_support.unsqueeze(2).expand(-1, -1, n_way, -1)
321 | loss = torch.sum(torch.sum(diff*diff, dim=-1) * support_labels)
322 |
323 | # logits = F.softmax(-1 * torch.sum(diff*diff, dim=-1) / 512, dim=-1)
324 | # # loss = nn.MSELoss()(F.softmax(logits.reshape(-1, n_way), dim=1), support_labels.reshape(-1, n_way))
325 | # loss = nn.CrossEntropyLoss()(logits.reshape(-1, n_way), labels_support.reshape(-1))
326 |
327 | # compute grad and update inner loop weights
328 | grads = torch.autograd.grad(loss, weights)
329 | x = grads[0].detach()
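        # Note (added): `loss` above is the squared distance between each prototype and
        # its own class's support embeddings, so the gradient w.r.t. prototype w_c is
        #   dL/dw_c = 2 * sum_{i : y_i = c} (w_c - e_i).
        # This detached gradient replaces the noisy prototypes as the input to the
        # time-conditioned MLP below; the returned estimate mixes the raw gradient and
        # the MLP output via the learned, time-dependent `scale`.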
330 |
331 | # # prototypes_ = (prototypes+prototypes_fs)/2
332 | # prototypes_ = prototypes
333 |
334 | # query_labels = torch.nn.functional.cosine_similarity(
335 | # emb_query.unsqueeze(2).expand(-1, -1, prototypes_.shape[1], -1),
336 | # prototypes_.unsqueeze(1).expand(-1, emb_query.shape[1], -1, -1),
337 | # dim=-1)
338 | # query_labels = query_labels * self.scale_factor
339 | # query_labels = F.softmax(query_labels, dim=-1)
340 | # data_labels = torch.cat([support_labels, query_labels], dim=1).unsqueeze(dim=-1).permute(0, 2, 1, 3)
341 |
342 | # topk, indices = torch.topk(data_labels, k_shot, dim=2)
343 | # mask = torch.zeros_like(data_labels)
344 | # mask = mask.scatter(2, indices, 1)
345 | # data_labels = data_labels * mask
346 |
347 | # data_labels = support_labels.unsqueeze(dim=-1).permute(0, 2, 1, 3)
348 | #
349 | # diff_weights = data_labels / torch.sum(data_labels, dim=2, keepdim=True)
350 | #
351 | # # cal vector fild
352 | # # all_x = torch.cat([emb_support, emb_query], dim=1)
353 | # all_x = emb_support
354 | # x = prototypes
355 | # x_left = x.unsqueeze(dim=2).expand(-1, -1, all_x.shape[1], -1)
356 | # all_x_right = all_x.unsqueeze(dim=1).expand(-1, n_way, -1, -1)
357 | # diff = (x_left - all_x_right)
358 | # diff = scale.unsqueeze(dim=-1) * torch.sum((diff_weights * diff), dim=2) + (1-scale.unsqueeze(dim=-1))*self.res(torch.cat([prototypes, time.unsqueeze(dim=1).repeat(1,5,1)], dim=-1))
359 |
360 | # x = torch.sum((diff_weights * diff), dim=2)
361 |
362 | time_ = self.time_emb1(time).unsqueeze(dim=1).repeat(1,5,1)
363 | x_1 = self.proto_inf1(x)
364 | x_1 = F.softplus(x_1 * time_)
365 |
366 | time_ = self.time_emb2(time).unsqueeze(dim=1).repeat(1,5,1)
367 | x_2 = self.proto_inf2(x_1)
368 | x_2 = F.softplus(x_2 * time_)
369 |
370 | time_ = self.time_emb3(time).unsqueeze(dim=1).repeat(1,5,1)
371 | x_3 = self.proto_inf3(x_2)
372 | x_3 = F.softplus(x_3 * time_) + x_2
373 |
374 | time_ = self.time_emb4(time).unsqueeze(dim=1).repeat(1,5,1)
375 | x_4 = self.proto_inf4(x_3)
376 | x_4 = F.softplus(x_4 * time_) + x_1
377 |
378 | time_ = self.time_emb5(time).unsqueeze(dim=1).repeat(1,5,1)
379 | x_5 = self.proto_inf5(x_4)
380 | x_5 = x_5 * time_ + x
381 |
382 | # diff = scale.unsqueeze(dim=-1) * torch.sum((diff_weights * diff), dim=2) + (
383 | # 1 - scale.unsqueeze(dim=-1)) * x
384 |
385 | # diff = x_5 * scale.unsqueeze(dim=-1)
386 |
387 | diff = scale.unsqueeze(dim=-1) * x + (
388 | 1 - scale.unsqueeze(dim=-1)) * x_5
389 |
390 | return diff.permute(0, 2, 1).unsqueeze(-1)
391 |
392 |
393 | class DMFunc1(nn.Module):
394 |
395 | def __init__(self):
396 | super(DMFunc1, self).__init__()
397 | self.nfe = 0
398 | self.scale_factor = 10
399 |
400 | self.time_emb1 = nn.Linear(320, 256)
401 | self.proto_inf1 = nn.Linear(512, 256)
402 | self.spt_inf1 = nn.Linear(512, 256)
403 |
404 | self.time_emb2 = nn.Linear(320, 128)
405 | self.proto_inf2 = nn.Linear(256, 128)
406 | self.spt_inf2 = nn.Linear(256, 128)
407 |
408 | self.time_emb3 = nn.Linear(320, 128)
409 | self.proto_inf3 = nn.Linear(128, 128)
410 | self.spt_inf3 = nn.Linear(128, 128)
411 |
412 | self.time_emb4 = nn.Linear(320, 256)
413 | self.proto_inf4 = nn.Linear(128, 256)
414 | self.spt_inf4 = nn.Linear(128, 256)
415 |
416 | self.time_emb5 = nn.Linear(320, 512)
417 | self.proto_inf5 = nn.Linear(256, 512)
418 | self.spt_inf5 = nn.Linear(256, 512)
419 |
420 | def get_time_embedding(self, timestep):
421 | freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160).cuda()
422 | x = timestep[:, None] * freqs[None]
423 | return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
424 |
425 | def forward(self, x, time, emb_query, emb_support, labels_support, labels_query, n_way, k_shot):
426 | tasks_per_batch = emb_support.shape[0]
427 | n_support = emb_support.shape[1]
428 |
429 | # encoding time
430 | time = self.get_time_embedding(time)
431 |
432 | nb, nd, nw, _ = x.shape
433 | x = x.permute(0, 2, 3, 1)
434 |
435 | sorted_emb_support = []
436 | for way_i in range(n_way):
437 | temp = []
438 | for bi in range(tasks_per_batch):
439 | temp.append(emb_support[bi, labels_support[bi, :]==way_i, :])
440 | temp = torch.stack(temp, dim=0)
441 | sorted_emb_support.append(temp)
442 | sorted_emb_support = torch.stack(sorted_emb_support, dim=1)
443 |
444 | time_ = self.time_emb1(time).unsqueeze(dim=1).unsqueeze(dim=1)
445 | x = self.proto_inf1(x)
446 | spt_emb = self.spt_inf1(sorted_emb_support)
447 | x = F.softplus(x * time_ * spt_emb)
448 | x_1 = x
449 | time_ = self.time_emb2(time).unsqueeze(dim=1).unsqueeze(dim=1)
450 | x = self.proto_inf2(x)
451 | spt_emb = self.spt_inf2(spt_emb)
452 | x = F.softplus(x * time_ * spt_emb)
453 | x_2 = x
454 | time_ = self.time_emb3(time).unsqueeze(dim=1).unsqueeze(dim=1)
455 | x = self.proto_inf3(x)
456 | spt_emb = self.spt_inf3(spt_emb)
457 | x = F.softplus(x * time_ * spt_emb) + x_2
458 |
459 | time_ = self.time_emb4(time).unsqueeze(dim=1).unsqueeze(dim=1)
460 | x = self.proto_inf4(x)
461 | spt_emb = self.spt_inf4(spt_emb)
462 | x = F.softplus(x * time_ * spt_emb) + x_1
463 | time_ = self.time_emb5(time).unsqueeze(dim=1).unsqueeze(dim=1)
464 | x = self.proto_inf5(x)
465 | spt_emb = self.spt_inf5(spt_emb)
466 | x = F.softplus(x * time_ * spt_emb) + x
467 |
468 | x = x.mean(dim=-2)
469 | return x.permute(0, 2, 1).unsqueeze(-1)
470 |
471 |
472 | class GaussianDiffusion(nn.Module):
473 | __doc__ = r"""Gaussian Diffusion model. Forwarding through the module returns diffusion reversal scalar loss tensor.
474 |
475 | Input:
476 | x: tensor of shape (N, img_channels, *img_size)
477 | y: tensor of shape (N)
478 | Output:
479 | scalar loss tensor
480 | Args:
481 | model (nn.Module): model which estimates diffusion noise
482 | img_size (tuple): image size tuple (H, W)
483 | img_channels (int): number of image channels
484 | betas (np.ndarray): numpy array of diffusion betas
485 | loss_type (string): loss type, "l1" or "l2"
486 | ema_decay (float): model weights exponential moving average decay
487 | ema_start (int): number of steps before EMA
488 | ema_update_rate (int): number of steps before each EMA update
489 | """
490 |
491 | def __init__(
492 | self,
493 | model,
494 | img_size,
495 | img_channels,
496 | num_classes,
497 | betas,
498 | loss_type="l2",
499 | ema_decay=0.9999,
500 | ema_start=5000,
501 | ema_update_rate=1,
502 | ):
503 | super().__init__()
504 |
505 | self.model = model
506 | self.ema_model = deepcopy(model)
507 |
508 | self.ema = EMA(ema_decay)
509 | self.ema_decay = ema_decay
510 | self.ema_start = ema_start
511 | self.ema_update_rate = ema_update_rate
512 | self.step = 0
513 |
514 | self.img_size = img_size
515 | self.img_channels = img_channels
516 | self.num_classes = num_classes
517 |
518 | if loss_type not in ["l1", "l2"]:
519 | raise ValueError("__init__() got unknown loss type")
520 |
521 | self.loss_type = loss_type
522 | self.num_timesteps = len(betas)
523 |
524 | alphas = 1.0 - betas
525 | alphas_cumprod = np.cumprod(alphas)
526 |
527 | to_torch = partial(torch.tensor, dtype=torch.float32)
528 |
529 | self.register_buffer("betas", to_torch(betas))
530 | self.register_buffer("alphas", to_torch(alphas))
531 | self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
532 |
533 | self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
534 | self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1 - alphas_cumprod)))
535 | self.register_buffer("reciprocal_sqrt_alphas", to_torch(np.sqrt(1 / alphas)))
536 |
537 | self.register_buffer("remove_noise_coeff", to_torch(betas / np.sqrt(1 - alphas_cumprod)))
538 | self.register_buffer("sigma", to_torch(np.sqrt(betas)))
539 |
540 | def update_ema(self):
541 | self.step += 1
542 | if self.step % self.ema_update_rate == 0:
543 | if self.step < self.ema_start:
544 | self.ema_model.load_state_dict(self.model.state_dict())
545 | else:
546 | self.ema.update_model_average(self.ema_model, self.model)
547 |
548 | # @torch.no_grad()
549 | def remove_noise(self, x, t, emb_query, emb_support, labels_support, labels_query, n_way, k_shot, use_ema=True):
550 | if use_ema:
551 | return (
552 | (x - extract(self.remove_noise_coeff, t, x.shape) * self.ema_model(x, t, emb_query, emb_support, labels_support, labels_query, n_way, k_shot)) *
553 | extract(self.reciprocal_sqrt_alphas, t, x.shape)
554 | )
555 | else:
556 | return (
557 | (x - extract(self.remove_noise_coeff, t, x.shape) * self.model(x, t, emb_query, emb_support, labels_support, labels_query, n_way, k_shot)) *
558 | extract(self.reciprocal_sqrt_alphas, t, x.shape)
559 | )
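    # For reference (added note): remove_noise() is the standard DDPM reverse-step mean,
    #   x_{t-1} = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alphabar_t) * eps_theta(x_t, t, task)),
    # using the buffers remove_noise_coeff = beta_t / sqrt(1 - alphabar_t) and
    # reciprocal_sqrt_alphas = 1 / sqrt(alpha_t). In sample() below the sigma * z noise
    # term is commented out, so sampling follows the posterior mean deterministically.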
560 |
561 | #@torch.no_grad()
562 | def sample(self, emb_query, emb_support, labels_support, labels_query, n_way, k_shot, use_ema=True):
563 | batch_size = emb_query.shape[0]
564 | device = emb_query.device
565 | x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device)
566 |
567 | for t in range(self.num_timesteps - 1, -1, -1):
568 | t_batch = torch.tensor([t], device=device).repeat(batch_size)
569 | x = self.remove_noise(x, t_batch, emb_query, emb_support, labels_support, labels_query, n_way, k_shot, use_ema)
570 |
571 | # if t > 0:
572 | # x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)
573 | # if t > 0:
574 | # x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)
575 |
576 | prototype = x.detach()[:, :, :, 0].permute(0, 2, 1)
577 | logits = 10 * F.cosine_similarity(
578 | emb_query.unsqueeze(2).expand(-1, -1, n_way, -1),
579 | prototype.unsqueeze(1).expand(-1, emb_query.shape[1], -1, -1),
580 | dim=-1)
581 | # logits = torch.sum(emb_query.unsqueeze(2).expand(-1, -1, n_way, -1)*
582 | # prototype.unsqueeze(1).expand(-1, emb_query.shape[1], -1, -1), dim=-1)
583 | return logits
584 |
585 |     # @torch.no_grad()  (kept off, as in sample(): DMFunc needs autograd for its inner-loop gradient)
586 |     def sample_diffusion_sequence(self, emb_query, emb_support, labels_support, labels_query, n_way, k_shot, use_ema=True):
587 |         # Signature updated to pass the task conditioning through to remove_noise().
588 |         batch_size = emb_query.shape[0]
589 |         device = emb_query.device
590 |         x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device)
591 |         diffusion_sequence = [x.cpu().detach()]
592 |
593 |         for t in range(self.num_timesteps - 1, -1, -1):
594 |             t_batch = torch.tensor([t], device=device).repeat(batch_size)
595 |             x = self.remove_noise(x, t_batch, emb_query, emb_support, labels_support, labels_query, n_way, k_shot, use_ema)
596 |
597 |             if t > 0:
598 |                 x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)
599 |
600 |             diffusion_sequence.append(x.cpu().detach())
601 |
602 |         return diffusion_sequence
603 |
604 | def perturb_x(self, x, t, noise):
605 | return (
606 | extract(self.sqrt_alphas_cumprod, t, x.shape) * x +
607 | extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * noise
608 | )
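    # For reference (added note): perturb_x() is the closed-form forward process
    #   q(x_t | x_0): x_t = sqrt(alphabar_t) * x_0 + sqrt(1 - alphabar_t) * noise,  noise ~ N(0, I),
    # which lets get_losses() train on a single random timestep per task.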
609 |
610 | def get_losses(self, x, t, emb_query, emb_support, labels_support, labels_query, n_way, k_shot):
611 | noise = torch.randn_like(x)
612 |
613 | perturbed_x = self.perturb_x(x, t, noise)
614 | estimated_noise = self.model(perturbed_x, t, emb_query, emb_support, labels_support, labels_query, n_way, k_shot)
615 |
616 | if self.loss_type == "l1":
617 | loss = F.l1_loss(estimated_noise, noise)
618 | elif self.loss_type == "l2":
619 | loss = F.mse_loss(estimated_noise, noise)
620 |
621 | # loss = F.l1_loss(estimated_noise, noise)
622 |
623 | return loss
624 |
625 | def forward(self, x, emb_query, emb_support, labels_support, labels_query, n_way, k_shot):
626 | x = x.permute(0, 2, 1).unsqueeze(-1)
627 | b, c, h, w = x.shape
628 | device = x.device
629 | t = torch.randint(0, self.num_timesteps, (b,), device=device)
630 | return self.get_losses(x, t, emb_query, emb_support, labels_support, labels_query, n_way, k_shot)
631 |
632 |
633 | def generate_cosine_schedule(T, s=0.008):
634 | def f(t, T):
635 | return (np.cos((t / T + s) / (1 + s) * np.pi / 2)) ** 2
636 |
637 | alphas = []
638 | f0 = f(0, T)
639 |
640 | for t in range(T + 1):
641 | alphas.append(f(t, T) / f0)
642 |
643 | betas = []
644 |
645 | for t in range(1, T + 1):
646 | betas.append(min(1 - alphas[t] / alphas[t - 1], 0.999))
647 |
648 | return np.array(betas)
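# The schedule above is the cosine schedule of Nichol & Dhariwal (Improved DDPM):
#   f(t) = cos^2(((t / T + s) / (1 + s)) * pi / 2),  alphabar_t = f(t) / f(0),
#   beta_t = 1 - alphabar_t / alphabar_{t-1}, clipped at 0.999.
# Note that get_diffusion_from_args() below defaults to the linear schedule.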
649 |
650 |
651 | def generate_linear_schedule(T, low, high):
652 | return np.linspace(low, high, T)
653 |
654 |
655 | def get_transform():
656 | class RescaleChannels(object):
657 | def __call__(self, sample):
658 | return 2 * sample - 1
659 |
660 | return torchvision.transforms.Compose([
661 | torchvision.transforms.ToTensor(),
662 | RescaleChannels(),
663 | ])
664 |
665 |
666 | def str2bool(v):
667 | """
668 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
669 | """
670 | if isinstance(v, bool):
671 | return v
672 | if v.lower() in ("yes", "true", "t", "y", "1"):
673 | return True
674 | elif v.lower() in ("no", "false", "f", "n", "0"):
675 | return False
676 | else:
677 | raise argparse.ArgumentTypeError("boolean value expected")
678 |
679 |
680 | def add_dict_to_argparser(parser, default_dict):
681 | """
682 | https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/script_util.py
683 | """
684 | for k, v in default_dict.items():
685 | v_type = type(v)
686 | if v is None:
687 | v_type = str
688 | elif isinstance(v, bool):
689 | v_type = str2bool
690 | parser.add_argument(f"--{k}", default=v, type=v_type)
691 |
692 |
693 | def get_diffusion_from_args():
694 | num_timesteps = 1000
695 | schedule = "linear"
696 | loss_type = "l2"
697 | use_labels = False
698 |
699 | base_channels = 128
700 | channel_mults = (1, 2, 2, 2)
701 | num_res_blocks = 2
702 | time_emb_dim = 128 * 4
703 | norm = "gn"
704 | dropout = 0.1
705 | activation = "silu"
706 | attention_resolutions = (1,)
707 |
708 | ema_decay = 0.9999
709 | ema_update_rate = 1
710 | schedule_low = 1e-4
711 | schedule_high = 0.02
712 |
713 | activations = {
714 | "relu": F.relu,
715 | "mish": F.mish,
716 | "silu": F.silu,
717 | }
718 |
719 | model = DMFunc()
720 |
721 | if schedule == "cosine":
722 | betas = generate_cosine_schedule(num_timesteps)
723 | else:
724 | betas = generate_linear_schedule(
725 | num_timesteps,
726 | schedule_low * 1000 / num_timesteps,
727 | schedule_high * 1000 / num_timesteps,
728 | )
729 |
730 | diffusion = GaussianDiffusion(
731 | model, (5, 1), 512, 10,
732 | betas,
733 | ema_decay=ema_decay,
734 | ema_update_rate=ema_update_rate,
735 | ema_start=2000,
736 | loss_type=loss_type,
737 | )
738 | return diffusion
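# Rough usage sketch (added; variable names and shapes are assumptions inferred from
# DMFunc/GaussianDiffusion above, not part of the original file). For 5-way tasks with
# 512-d features:
#
#   diffusion = get_diffusion_from_args()               # DMFunc wrapped in GaussianDiffusion
#   loss = diffusion(target_prototypes,                 # (tasks, 5, 512) target weights
#                    emb_query, emb_support,            # (tasks, n_query, 512), (tasks, n_support, 512)
#                    labels_support, labels_query,      # (tasks, n_support), (tasks, n_query)
#                    n_way=5, k_shot=1)
#   loss.backward(); diffusion.update_ema()
#
#   logits = diffusion.sample(emb_query, emb_support, labels_support, labels_query,
#                             n_way=5, k_shot=1)        # (tasks, n_query, 5) cosine logits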
--------------------------------------------------------------------------------
/models/classification_heads.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import torch
5 | from torch.autograd import Variable
6 | import torch.nn as nn
7 | from qpth.qp import QPFunction
8 | from models.local_distribute import *
10 | import torch.nn.functional as F
11 | import cv2
12 | import torch.fft
13 |
14 |
15 | def computeGramMatrix(A, B):
16 | """
17 | Constructs a linear kernel matrix between A and B.
18 | We assume that each row in A and B represents a d-dimensional feature vector.
19 |
20 | Parameters:
21 | A: a (n_batch, n, d) Tensor.
22 | B: a (n_batch, m, d) Tensor.
23 | Returns: a (n_batch, n, m) Tensor.
24 | """
25 |
26 | assert(A.dim() == 3)
27 | assert(B.dim() == 3)
28 | assert(A.size(0) == B.size(0) and A.size(2) == B.size(2))
29 |
30 | return torch.bmm(A, B.transpose(1,2))
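# e.g. for A of shape (B, n, d) and B of shape (B, m, d), the result G has shape
# (B, n, m) with G[b, i, j] = <A[b, i], B[b, j]>; the heads below use it both as a
# kernel matrix and inside the squared-distance expansion ||a - b||^2 = aa - 2ab + bb.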
31 |
32 |
33 | def binv(b_mat):
34 | """
35 | Computes an inverse of each matrix in the batch.
36 | Pytorch 0.4.1 does not support batched matrix inverse.
37 | Hence, we are solving AX=I.
38 |
39 | Parameters:
40 | b_mat: a (n_batch, n, n) Tensor.
41 | Returns: a (n_batch, n, n) Tensor.
42 | """
43 |
44 | id_matrix = b_mat.new_ones(b_mat.size(-1)).diag().expand_as(b_mat).cuda()
45 |     b_inv, _ = torch.gesv(id_matrix, b_mat)  # torch.gesv was removed in newer PyTorch; torch.linalg.solve(b_mat, id_matrix) is the modern equivalent
46 |
47 | return b_inv
48 |
49 |
50 | def one_hot(indices, depth):
51 | """
52 | Returns a one-hot tensor.
53 | This is a PyTorch equivalent of Tensorflow's tf.one_hot.
54 |
55 | Parameters:
56 | indices: a (n_batch, m) Tensor or (m) Tensor.
57 | depth: a scalar. Represents the depth of the one hot dimension.
58 | Returns: a (n_batch, m, depth) Tensor or (m, depth) Tensor.
59 | """
60 |
61 |     encoded_indices = torch.zeros(indices.size() + torch.Size([depth])).cuda()
62 |     index = indices.view(indices.size() + torch.Size([1]))
63 |     encoded_indices = encoded_indices.scatter_(-1, index, 1)  # scatter on the last dim so both (m,) and (n_batch, m) inputs work
64 |
65 |     return encoded_indices
66 |
67 | def batched_kronecker(matrix1, matrix2):
68 | matrix1_flatten = matrix1.reshape(matrix1.size()[0], -1)
69 | matrix2_flatten = matrix2.reshape(matrix2.size()[0], -1)
70 | return torch.bmm(matrix1_flatten.unsqueeze(2), matrix2_flatten.unsqueeze(1)).reshape([matrix1.size()[0]] + list(matrix1.size()[1:]) + list(matrix2.size()[1:])).permute([0, 1, 3, 2, 4]).reshape(matrix1.size(0), matrix1.size(1) * matrix2.size(1), matrix1.size(2) * matrix2.size(2))
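# batched_kronecker computes a per-task Kronecker product: for matrix1 of shape
# (B, n1, m1) and matrix2 of shape (B, n2, m2) it returns a (B, n1*n2, m1*m2) tensor.
# MetaOptNetHead_SVM_CS below uses it to expand the dual QP over (sample, class) pairs.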
71 |
72 |
73 | def FFTHead(query, support, support_labels, n_way, n_shot, normalize=True):
74 | """
75 |     Constructs the prototype representation of each class (= mean of the support
76 |     feature maps of each class) and returns the classification score (= negative
77 |     squared L2 distance between the FFT magnitudes of the query feature maps and
78 |     of each class prototype) on the query set.
79 |
80 |     This is a frequency-domain variant of Prototypical Networks (Snell et al., NIPS 2017).
81 |
82 |     Parameters:
83 |         query: a (tasks_per_batch, n_query, c, h, w) Tensor of feature maps.
84 |         support: a (tasks_per_batch, n_support, c, h, w) Tensor of feature maps.
85 |         support_labels: a (tasks_per_batch, n_support) Tensor.
86 |         n_way: a scalar. Represents the number of classes in a few-shot classification task.
87 |         n_shot: a scalar. Represents the number of support examples given per class.
88 |         normalize: a boolean. Represents whether we want to normalize the distances by the flattened feature dimension.
89 |     Returns: a (tasks_per_batch, n_query, n_way) Tensor.
90 | """
91 | support = F.normalize(support, dim=-3)
92 | query = F.normalize(query, dim=-3)
93 |
94 | tasks_per_batch = query.size(0)
95 | n_support = support.size(1)
96 | n_query = query.size(1)
97 | n_c = query.size(2)
98 | n_w = query.size(3)
99 | n_h = query.size(4)
100 |
101 | assert (n_support == n_way * n_shot) # n_support must equal to n_way * n_shot
102 |
103 | support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way)
104 | support_labels_one_hot = support_labels_one_hot.view(tasks_per_batch, n_support, n_way)
105 |
106 | # From:
107 | # https://github.com/gidariss/FewShotWithoutForgetting/blob/master/architectures/PrototypicalNetworksHead.py
108 | # ************************* Compute Prototypes **************************
109 | labels_train_transposed = support_labels_one_hot.transpose(1, 2)
110 | # Batch matrix multiplication:
111 | # prototypes = labels_train_transposed * features_train ==>
112 | # [batch_size x nKnovel x num_channels] =
113 | # [batch_size x nKnovel x num_train_examples] * [batch_size * num_train_examples * num_channels]
114 | prototypes = torch.bmm(labels_train_transposed, support.reshape(tasks_per_batch, n_support, -1))
115 | # Divide with the number of examples per novel category.
116 | prototypes = prototypes.div(
117 | labels_train_transposed.sum(dim=2, keepdim=True).expand_as(prototypes)
118 | )
119 | prototypes = prototypes.reshape(tasks_per_batch, n_way, n_c, n_w, n_h)
120 | # Distance Matrix Vectorization Trick
121 |
122 | prototypes_rfft2 = torch.fft.fftn(prototypes, dim=(3, 4))
123 | prototypes_rfft2 = torch.roll(prototypes_rfft2, (n_w // 2, n_h // 2), dims=(3, 4))
124 | query_rfft2 = torch.fft.fftn(query, dim=(3, 4))
125 | query_rfft2 = torch.roll(query_rfft2, (n_w // 2, n_h // 2), dims=(3, 4))
126 | prototypes_rfft2_a = torch.abs(prototypes_rfft2)
127 | # a = torch.max(prototypes_rfft2_a, keepdim=True, dim=-1)[0]
128 | # prototypes_rfft2_a = prototypes_rfft2_a / torch.max(torch.max(prototypes_rfft2_a, keepdim=True, dim=-1)[0], keepdim=True, dim=-2)[0]
129 | query_rfft2_a = torch.abs(query_rfft2)
130 | # query_rfft2_a = query_rfft2_a / torch.max(torch.max(query_rfft2_a, keepdim=True, dim=-1)[0],
131 | # keepdim=True, dim=-2)[0]
132 |
133 | prototypes = prototypes_rfft2_a.reshape(tasks_per_batch, n_way, -1)
134 | query = query_rfft2_a.reshape(tasks_per_batch, n_query, -1)
135 |
136 | AB = computeGramMatrix(query, prototypes)
137 | AA = (query * query).sum(dim=2, keepdim=True)
138 | BB = (prototypes * prototypes).sum(dim=2, keepdim=True).reshape(tasks_per_batch, 1, n_way)
139 | logits = AA.expand_as(AB) - 2 * AB + BB.expand_as(AB)
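    # Vectorization trick: ||q - p||^2 = ||q||^2 - 2 * <q, p> + ||p||^2, so the full
    # (n_query x n_way) squared-distance matrix is AA - 2 * AB + BB without explicit loops.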
140 | logits = -logits
141 |
142 | if normalize:
143 | logits = logits / query.shape[-1]
144 |
145 | return logits
146 |
147 |
148 | def MetaOptNetHead_Ridge(query, support, support_labels, n_way, n_shot, lambda_reg=50.0, double_precision=False):
149 | """
150 | Fits the support set with ridge regression and
151 | returns the classification score on the query set.
152 |
153 | Parameters:
154 | query: a (tasks_per_batch, n_query, d) Tensor.
155 | support: a (tasks_per_batch, n_support, d) Tensor.
156 | support_labels: a (tasks_per_batch, n_support) Tensor.
157 | n_way: a scalar. Represents the number of classes in a few-shot classification task.
158 | n_shot: a scalar. Represents the number of support examples given per class.
159 | lambda_reg: a scalar. Represents the strength of L2 regularization.
160 | Returns: a (tasks_per_batch, n_query, n_way) Tensor.
161 | """
162 |
163 | tasks_per_batch = query.size(0)
164 | n_support = support.size(1)
165 | n_query = query.size(1)
166 |
167 | assert(query.dim() == 3)
168 | assert(support.dim() == 3)
169 | assert(query.size(0) == support.size(0) and query.size(2) == support.size(2))
170 | assert(n_support == n_way * n_shot) # n_support must equal to n_way * n_shot
171 |
172 | #Here we solve the dual problem:
173 | #Note that the classes are indexed by m & samples are indexed by i.
174 | #min_{\alpha} 0.5 \sum_m ||w_m(\alpha)||^2 + \sum_i \sum_m e^m_i alpha^m_i
175 |
176 | #where w_m(\alpha) = \sum_i \alpha^m_i x_i,
177 |
178 | #\alpha is an (n_support, n_way) matrix
179 | kernel_matrix = computeGramMatrix(support, support)
180 | kernel_matrix += lambda_reg * torch.eye(n_support).expand(tasks_per_batch, n_support, n_support).cuda()
181 |
182 | block_kernel_matrix = kernel_matrix.repeat(n_way, 1, 1) #(n_way * tasks_per_batch, n_support, n_support)
183 |
184 | support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way) # (tasks_per_batch * n_support, n_way)
185 | support_labels_one_hot = support_labels_one_hot.transpose(0, 1) # (n_way, tasks_per_batch * n_support)
186 | support_labels_one_hot = support_labels_one_hot.reshape(n_way * tasks_per_batch, n_support) # (n_way*tasks_per_batch, n_support)
187 |
188 | G = block_kernel_matrix
189 | e = -2.0 * support_labels_one_hot
190 |
191 | #This is a fake inequlity constraint as qpth does not support QP without an inequality constraint.
192 | id_matrix_1 = torch.zeros(tasks_per_batch*n_way, n_support, n_support)
193 | C = Variable(id_matrix_1)
194 | h = Variable(torch.zeros((tasks_per_batch*n_way, n_support)))
195 | dummy = Variable(torch.Tensor()).cuda() # We want to ignore the equality constraint.
196 |
197 | if double_precision:
198 | G, e, C, h = [x.double().cuda() for x in [G, e, C, h]]
199 |
200 | else:
201 | G, e, C, h = [x.float().cuda() for x in [G, e, C, h]]
202 |
203 | # Solve the following QP to fit SVM:
204 | # \hat z = argmin_z 1/2 z^T G z + e^T z
205 | # subject to Cz <= h
206 | # We use detach() to prevent backpropagation to fixed variables.
207 | qp_sol = QPFunction(verbose=False)(G, e.detach(), C.detach(), h.detach(), dummy.detach(), dummy.detach())
208 | #qp_sol = QPFunction(verbose=False)(G, e.detach(), dummy.detach(), dummy.detach(), dummy.detach(), dummy.detach())
209 |
210 | #qp_sol (n_way*tasks_per_batch, n_support)
211 | qp_sol = qp_sol.reshape(n_way, tasks_per_batch, n_support)
212 | #qp_sol (n_way, tasks_per_batch, n_support)
213 | qp_sol = qp_sol.permute(1, 2, 0)
214 | #qp_sol (tasks_per_batch, n_support, n_way)
215 |
216 | # Compute the classification score.
217 | compatibility = computeGramMatrix(support, query)
218 | compatibility = compatibility.float()
219 | compatibility = compatibility.unsqueeze(3).expand(tasks_per_batch, n_support, n_query, n_way)
220 | qp_sol = qp_sol.reshape(tasks_per_batch, n_support, n_way)
221 | logits = qp_sol.float().unsqueeze(2).expand(tasks_per_batch, n_support, n_query, n_way)
222 | logits = logits * compatibility
223 | logits = torch.sum(logits, 1)
224 |
225 | return logits
226 |
227 | def R2D2Head(query, support, support_labels, n_way, n_shot, l2_regularizer_lambda=50.0):
228 | """
229 | Fits the support set with ridge regression and
230 | returns the classification score on the query set.
231 |
232 | This model is the classification head described in:
233 | Meta-learning with differentiable closed-form solvers
234 | (Bertinetto et al., in submission to NIPS 2018).
235 |
236 | Parameters:
237 | query: a (tasks_per_batch, n_query, d) Tensor.
238 | support: a (tasks_per_batch, n_support, d) Tensor.
239 | support_labels: a (tasks_per_batch, n_support) Tensor.
240 | n_way: a scalar. Represents the number of classes in a few-shot classification task.
241 | n_shot: a scalar. Represents the number of support examples given per class.
242 | l2_regularizer_lambda: a scalar. Represents the strength of L2 regularization.
243 | Returns: a (tasks_per_batch, n_query, n_way) Tensor.
244 | """
245 |
246 | tasks_per_batch = query.size(0)
247 | n_support = support.size(1)
248 |
249 | assert(query.dim() == 3)
250 | assert(support.dim() == 3)
251 | assert(query.size(0) == support.size(0) and query.size(2) == support.size(2))
252 | assert(n_support == n_way * n_shot) # n_support must equal to n_way * n_shot
253 |
254 | support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way)
255 | support_labels_one_hot = support_labels_one_hot.view(tasks_per_batch, n_support, n_way)
256 |
257 | id_matrix = torch.eye(n_support).expand(tasks_per_batch, n_support, n_support).cuda()
258 |
259 | # Compute the dual form solution of the ridge regression.
260 |     # W = X^T(X X^T + lambda * I)^(-1) Y
261 | ridge_sol = computeGramMatrix(support, support) + l2_regularizer_lambda * id_matrix
262 | ridge_sol = binv(ridge_sol)
263 | ridge_sol = torch.bmm(support.transpose(1,2), ridge_sol)
264 | ridge_sol = torch.bmm(ridge_sol, support_labels_one_hot)
265 |
266 | # Compute the classification score.
267 | # score = W X
268 | logits = torch.bmm(query, ridge_sol)
269 |
270 | return logits
271 |
272 |
273 | def MetaOptNetHead_SVM_He(query, support, support_labels, n_way, n_shot, C_reg=0.01, double_precision=False):
274 | """
275 | Fits the support set with multi-class SVM and
276 | returns the classification score on the query set.
277 |
278 | This is the multi-class SVM presented in:
279 | A simplified multi-class support vector machine with reduced dual optimization
280 | (He et al., Pattern Recognition Letter 2012).
281 |
282 | This SVM is desirable because the dual variable of size is n_support
283 | (as opposed to n_way*n_support in the Weston&Watkins or Crammer&Singer multi-class SVM).
284 | This model is the classification head that we have initially used for our project.
285 | This was dropped since it turned out that it performs suboptimally on the meta-learning scenarios.
286 |
287 | Parameters:
288 | query: a (tasks_per_batch, n_query, d) Tensor.
289 | support: a (tasks_per_batch, n_support, d) Tensor.
290 | support_labels: a (tasks_per_batch, n_support) Tensor.
291 | n_way: a scalar. Represents the number of classes in a few-shot classification task.
292 | n_shot: a scalar. Represents the number of support examples given per class.
293 | C_reg: a scalar. Represents the cost parameter C in SVM.
294 | Returns: a (tasks_per_batch, n_query, n_way) Tensor.
295 | """
296 |
297 | tasks_per_batch = query.size(0)
298 | n_support = support.size(1)
299 | n_query = query.size(1)
300 |
301 | assert(query.dim() == 3)
302 | assert(support.dim() == 3)
303 | assert(query.size(0) == support.size(0) and query.size(2) == support.size(2))
304 | assert(n_support == n_way * n_shot) # n_support must equal to n_way * n_shot
305 |
306 |     # Convert integer support labels to one-hot, as the other heads in this file do.
307 |     support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way).view(tasks_per_batch, n_support, n_way)
308 |     kernel_matrix = computeGramMatrix(support, support)
309 |     V = (support_labels_one_hot * n_way - torch.ones(tasks_per_batch, n_support, n_way).cuda()) / (n_way - 1)
310 | G = computeGramMatrix(V, V).detach()
311 | G = kernel_matrix * G
312 |
313 | e = Variable(-1.0 * torch.ones(tasks_per_batch, n_support))
314 | id_matrix = torch.eye(n_support).expand(tasks_per_batch, n_support, n_support)
315 | C = Variable(torch.cat((id_matrix, -id_matrix), 1))
316 | h = Variable(torch.cat((C_reg * torch.ones(tasks_per_batch, n_support), torch.zeros(tasks_per_batch, n_support)), 1))
317 | dummy = Variable(torch.Tensor()).cuda() # We want to ignore the equality constraint.
318 |
319 | if double_precision:
320 | G, e, C, h = [x.double().cuda() for x in [G, e, C, h]]
321 | else:
322 | G, e, C, h = [x.cuda() for x in [G, e, C, h]]
323 |
324 | # Solve the following QP to fit SVM:
325 | # \hat z = argmin_z 1/2 z^T G z + e^T z
326 | # subject to Cz <= h
327 | # We use detach() to prevent backpropagation to fixed variables.
328 | qp_sol = QPFunction(verbose=False)(G, e.detach(), C.detach(), h.detach(), dummy.detach(), dummy.detach())
329 |
330 | # Compute the classification score.
331 | compatibility = computeGramMatrix(query, support)
332 | compatibility = compatibility.float()
333 |
334 | logits = qp_sol.float().unsqueeze(1).expand(tasks_per_batch, n_query, n_support)
335 | logits = logits * compatibility
336 | logits = logits.view(tasks_per_batch, n_query, n_shot, n_way)
337 | logits = torch.sum(logits, 2)
338 |
339 | return logits
340 |
341 |
342 | def CosineNetHead(query, support, support_labels, n_way, n_shot, normalize=True):
343 | """
344 |     Constructs the prototype representation of each class (= mean of support vectors of each class) and
345 |     returns the classification score (= cosine similarity to each class prototype) on the query set.
346 |
347 | This model is the classification head described in:
348 | Prototypical Networks for Few-shot Learning
349 | (Snell et al., NIPS 2017).
350 |
351 | Parameters:
352 | query: a (tasks_per_batch, n_query, d) Tensor.
353 | support: a (tasks_per_batch, n_support, d) Tensor.
354 | support_labels: a (tasks_per_batch, n_support) Tensor.
355 | n_way: a scalar. Represents the number of classes in a few-shot classification task.
356 | n_shot: a scalar. Represents the number of support examples given per class.
357 |         normalize: a boolean. Unused by this head (cosine similarity is scale-invariant).
358 | Returns: a (tasks_per_batch, n_query, n_way) Tensor.
359 | """
360 |
361 | tasks_per_batch = query.size(0)
362 | n_support = support.size(1)
363 | n_query = query.size(1)
364 | d = query.size(2)
365 |
366 | assert (query.dim() == 3)
367 | assert (support.dim() == 3)
368 | assert (query.size(0) == support.size(0) and query.size(2) == support.size(2))
369 | assert (n_support == n_way * n_shot) # n_support must equal to n_way * n_shot
370 |
371 | support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way)
372 | support_labels_one_hot = support_labels_one_hot.view(tasks_per_batch, n_support, n_way)
373 |
374 | # From:
375 | # https://github.com/gidariss/FewShotWithoutForgetting/blob/master/architectures/PrototypicalNetworksHead.py
376 | # ************************* Compute Prototypes **************************
377 | labels_train_transposed = support_labels_one_hot.transpose(1, 2)
378 | # Batch matrix multiplication:
379 | # prototypes = labels_train_transposed * features_train ==>
380 | # [batch_size x nKnovel x num_channels] =
381 | # [batch_size x nKnovel x num_train_examples] * [batch_size * num_train_examples * num_channels]
382 | prototypes = torch.bmm(labels_train_transposed, support)
383 | # Divide with the number of examples per novel category.
384 | prototypes = prototypes.div(
385 | labels_train_transposed.sum(dim=2, keepdim=True).expand_as(prototypes)
386 | )
387 |
388 | # Distance Matrix Vectorization Trick
389 | logits = torch.nn.functional.cosine_similarity(query.unsqueeze(2).expand(-1, -1, prototypes.shape[1], -1),
390 | prototypes.unsqueeze(1).expand(-1, query.shape[1], -1, -1), dim=-1)
391 |
392 | # logits = torch.nn.CosineSimilarity(dim=-1)(query.unsqueeze(2).expand(-1, -1, prototypes.shape[1], -1),
393 | # prototypes.unsqueeze(1).expand(-1, query.shape[1], -1, -1))*10
394 |
395 | return logits
396 |
397 |
398 | def emd_inference_qpth(distance_matrix, weight1, weight2, form='QP', l2_strength=0.0001):
399 | """
400 | to use the QP solver QPTH to derive EMD (LP problem),
401 | one can transform the LP problem to QP,
402 | or omit the QP term by multiplying it with a small value,i.e. l2_strngth.
403 | :param distance_matrix: nbatch * element_number * element_number
404 | :param weight1: nbatch * weight_number
405 | :param weight2: nbatch * weight_number
406 | :return:
407 | emd distance: nbatch*1
408 | flow : nbatch * weight_number *weight_number
409 | """
410 |
411 | weight1 = (weight1 * weight1.shape[-1]) / weight1.sum(1).unsqueeze(1)
412 | weight2 = (weight2 * weight2.shape[-1]) / weight2.sum(1).unsqueeze(1)
413 |
414 | nbatch = distance_matrix.shape[0]
415 | nelement_distmatrix = distance_matrix.shape[1] * distance_matrix.shape[2]
416 | nelement_weight1 = weight1.shape[1]
417 | nelement_weight2 = weight2.shape[1]
418 |
419 | Q_1 = distance_matrix.view(-1, 1, nelement_distmatrix).double()
420 |
421 | if form == 'QP':
422 | # version: QTQ
423 | Q = torch.bmm(Q_1.transpose(2, 1), Q_1).double().cuda() + 1e-4 * torch.eye(
424 | nelement_distmatrix).double().cuda().unsqueeze(0).repeat(nbatch, 1, 1) # 0.00001 *
425 | p = torch.zeros(nbatch, nelement_distmatrix).double().cuda()
426 | elif form == 'L2':
427 | # version: regularizer
428 | Q = (l2_strength * torch.eye(nelement_distmatrix).double()).cuda().unsqueeze(0).repeat(nbatch, 1, 1)
429 | p = distance_matrix.view(nbatch, nelement_distmatrix).double()
430 | else:
431 | raise ValueError('Unkown form')
432 |
433 | h_1 = torch.zeros(nbatch, nelement_distmatrix).double().cuda()
434 | h_2 = torch.cat([weight1, weight2], 1).double()
435 | h = torch.cat((h_1, h_2), 1)
436 |
437 | G_1 = -torch.eye(nelement_distmatrix).double().cuda().unsqueeze(0).repeat(nbatch, 1, 1)
438 | G_2 = torch.zeros([nbatch, nelement_weight1 + nelement_weight2, nelement_distmatrix]).double().cuda()
439 | # sum_j(xij) = si
440 | for i in range(nelement_weight1):
441 | G_2[:, i, nelement_weight2 * i:nelement_weight2 * (i + 1)] = 1
442 | # sum_i(xij) = dj
443 | for j in range(nelement_weight2):
444 | G_2[:, nelement_weight1 + j, j::nelement_weight2] = 1
445 | #xij>=0, sum_j(xij) <= si,sum_i(xij) <= dj, sum_ij(x_ij) = min(sum(si), sum(dj))
446 | G = torch.cat((G_1, G_2), 1)
447 | A = torch.ones(nbatch, 1, nelement_distmatrix).double().cuda()
448 | b = torch.min(torch.sum(weight1, 1), torch.sum(weight2, 1)).unsqueeze(1).double()
449 | flow = QPFunction(verbose=-1)(Q, p, G, h, A, b)
450 |
451 | emd_score = torch.sum((1 - Q_1).squeeze() * flow, 1)
452 | return emd_score, flow.view(-1, nelement_weight1, nelement_weight2)
453 |
454 | def emd_inference_opencv(cost_matrix, weight1, weight2):
455 | # cost matrix is a tensor of shape [N,N]
456 | cost_matrix = cost_matrix.detach().cpu().numpy()
457 |
458 | # weight1 = F.relu(weight1) + 1e-5
459 | # weight2 = F.relu(weight2) + 1e-5
460 | #
461 | # weight1 = (weight1 * (weight1.shape[0] / weight1.sum().item())).view(-1, 1).detach().cpu().numpy()
462 | # weight2 = (weight2 * (weight2.shape[0] / weight2.sum().item())).view(-1, 1).detach().cpu().numpy()
463 | weight1 = weight1.detach().cpu().numpy()
464 | weight2 = weight2.detach().cpu().numpy()
465 | # print(np.sum(weight1))
466 | # print(np.sum(weight2))
467 | cost, _, flow = cv2.EMD(weight1, weight2, cv2.DIST_USER, cost_matrix)
468 | return cost, flow
469 |
470 | def emd_inference_opencv_test(distance_matrix,weight1,weight2):
471 | distance_list = []
472 | flow_list = []
473 |
474 | for i in range (distance_matrix.shape[0]):
475 | cost,flow=emd_inference_opencv(distance_matrix[i],weight1[i],weight2[i])
476 | distance_list.append(cost)
477 | flow_list.append(torch.from_numpy(flow))
478 |
479 | emd_distance = torch.Tensor(distance_list).cuda().double()
480 | flow = torch.stack(flow_list, dim=0).cuda().double()
481 |
482 | return emd_distance,flow
483 |
484 | def LocalNetHead(query, support, support_labels, n_way, n_shot, normalize=True):
485 | """
486 |     Computes, for every (support, query) pair, an Earth Mover's Distance between their
487 |     local similarity distributions (obtained via x2p_torch) under a cosine-distance cost
488 |     matrix, and returns the resulting scores on the query set.
489 |
490 |     Note: despite the shared docstring template, this head is EMD-based rather than the
491 |     prototypical-network head of Snell et al. (NIPS 2017).
492 |
493 | Parameters:
494 | query: a (tasks_per_batch, n_query, d) Tensor.
495 | support: a (tasks_per_batch, n_support, d) Tensor.
496 | support_labels: a (tasks_per_batch, n_support) Tensor.
497 | n_way: a scalar. Represents the number of classes in a few-shot classification task.
498 | n_shot: a scalar. Represents the number of support examples given per class.
499 |         normalize: a boolean. Unused by this head.
500 |     Returns: a (tasks_per_batch, n_query, n_support) Tensor of EMD scores.
501 | """
502 |
503 | tasks_per_batch = query.size(0)
504 | n_support = support.size(1)
505 | n_query = query.size(1)
506 | d = query.size(2)
507 |
508 | assert (query.dim() == 3)
509 | assert (support.dim() == 3)
510 | assert (query.size(0) == support.size(0) and query.size(2) == support.size(2))
511 | assert (n_support == n_way * n_shot) # n_support must equal to n_way * n_shot
512 |
513 | support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way)
514 | support_labels_one_hot = support_labels_one_hot.view(tasks_per_batch, n_support, n_way)
515 |
516 | # From:
517 | # https://github.com/gidariss/FewShotWithoutForgetting/blob/master/architectures/PrototypicalNetworksHead.py
518 | # ************************* Compute Prototypes **************************
519 | labels_train_transposed = support_labels_one_hot.transpose(1, 2)
520 | # Batch matrix multiplication:
521 | # prototypes = labels_train_transposed * features_train ==>
522 | # [batch_size x nKnovel x num_channels] =
523 | # [batch_size x nKnovel x num_train_examples] * [batch_size * num_train_examples * num_channels]
524 | prototypes = torch.bmm(labels_train_transposed, support)
525 | # Divide with the number of examples per novel category.
526 | prototypes = prototypes.div(
527 | labels_train_transposed.sum(dim=2, keepdim=True).expand_as(prototypes)
528 | )
529 | data = torch.cat([support, query], dim=1)
530 | data_local_dis = []
531 | for i in range(tasks_per_batch):
532 | dis = x2p_torch(data[i], tol=1e-5, perplexity=10.0)
533 | data_local_dis.append(dis)
534 | data_local_dis = torch.stack(data_local_dis, dim=0)
535 | # Distance Matrix Vectorization Trick
536 | # AB = computeGramMatrix(query, prototypes)
537 | # AA = (query * query).sum(dim=2, keepdim=True)
538 | # BB = (prototypes * prototypes).sum(dim=2, keepdim=True).reshape(tasks_per_batch, 1, n_way)
539 | # logits = AA.expand_as(AB) - 2 * AB + BB.expand_as(AB)
540 | # logits = -logits
541 | #
542 | # if normalize:
543 | # logits = logits / d
544 | cosine_distance_matrix = 1 - torch.nn.functional.cosine_similarity(data.unsqueeze(2).expand(-1, -1, data.shape[1], -1),
545 | data.unsqueeze(1).expand(-1, data.shape[1], -1, -1), dim=-1)
546 | support_dist = data_local_dis[:, :n_support, :].unsqueeze(2).expand(-1, -1, n_query, -1).reshape(tasks_per_batch*n_support*n_query, -1)
547 | query_dist = data_local_dis[:, n_support:, :].unsqueeze(1).expand(-1, n_support, -1, -1).reshape(tasks_per_batch*n_support*n_query, -1)
548 | cost_matrix = cosine_distance_matrix.unsqueeze(1).expand(-1, n_support*n_query, -1, -1).reshape(tasks_per_batch*n_support*n_query, n_support+n_query, n_support+n_query)
549 | # emd_distance_qpth, qpth_flow = emd_inference_qpth(cost_matrix, support_dist, query_dist)
550 | emd_distance_qpth, qpth_flow = emd_inference_opencv_test(cost_matrix, support_dist, query_dist)
551 | logits = emd_distance_qpth.reshape(tasks_per_batch, n_support, n_query).permute(0, 2, 1).type_as(support)
552 | return logits
553 |
554 | def ProtoNetHead(query, support, support_labels, n_way, n_shot, normalize=True):
555 | """
556 | Constructs the prototype representation of each class(=mean of support vectors of each class) and
557 | returns the classification score (=L2 distance to each class prototype) on the query set.
558 |
559 | This model is the classification head described in:
560 | Prototypical Networks for Few-shot Learning
561 | (Snell et al., NIPS 2017).
562 |
563 | Parameters:
564 | query: a (tasks_per_batch, n_query, d) Tensor.
565 | support: a (tasks_per_batch, n_support, d) Tensor.
566 | support_labels: a (tasks_per_batch, n_support) Tensor.
567 | n_way: a scalar. Represents the number of classes in a few-shot classification task.
568 | n_shot: a scalar. Represents the number of support examples given per class.
569 |         normalize: a boolean. Represents whether we want to normalize the distances by the embedding dimension.
570 | Returns: a (tasks_per_batch, n_query, n_way) Tensor.
571 | """
572 |
573 | tasks_per_batch = query.size(0)
574 | n_support = support.size(1)
575 | n_query = query.size(1)
576 | d = query.size(2)
577 |
578 | assert(query.dim() == 3)
579 | assert(support.dim() == 3)
580 | assert(query.size(0) == support.size(0) and query.size(2) == support.size(2))
581 | assert(n_support == n_way * n_shot) # n_support must equal to n_way * n_shot
582 |
583 | support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way)
584 | support_labels_one_hot = support_labels_one_hot.view(tasks_per_batch, n_support, n_way)
585 |
586 | # From:
587 | # https://github.com/gidariss/FewShotWithoutForgetting/blob/master/architectures/PrototypicalNetworksHead.py
588 | #************************* Compute Prototypes **************************
589 | labels_train_transposed = support_labels_one_hot.transpose(1,2)
590 | # Batch matrix multiplication:
591 | # prototypes = labels_train_transposed * features_train ==>
592 | # [batch_size x nKnovel x num_channels] =
593 | # [batch_size x nKnovel x num_train_examples] * [batch_size * num_train_examples * num_channels]
594 | prototypes = torch.bmm(labels_train_transposed, support)
595 | # Divide with the number of examples per novel category.
596 | prototypes = prototypes.div(
597 | labels_train_transposed.sum(dim=2, keepdim=True).expand_as(prototypes)
598 | )
599 |
600 | # Distance Matrix Vectorization Trick
601 | AB = computeGramMatrix(query, prototypes)
602 | AA = (query * query).sum(dim=2, keepdim=True)
603 | BB = (prototypes * prototypes).sum(dim=2, keepdim=True).reshape(tasks_per_batch, 1, n_way)
604 | logits = AA.expand_as(AB) - 2 * AB + BB.expand_as(AB)
605 | logits = -logits
606 |
607 | if normalize:
608 | logits = logits / d
609 |
610 | return logits
611 |
612 | def MetaOptNetHead_SVM_CS(query, support, support_labels, n_way, n_shot, C_reg=0.1, double_precision=False, maxIter=15):
613 | """
614 | Fits the support set with multi-class SVM and
615 | returns the classification score on the query set.
616 |
617 | This is the multi-class SVM presented in:
618 | On the Algorithmic Implementation of Multiclass Kernel-based Vector Machines
619 | (Crammer and Singer, Journal of Machine Learning Research 2001).
620 |
621 | This model is the classification head that we use for the final version.
622 | Parameters:
623 | query: a (tasks_per_batch, n_query, d) Tensor.
624 | support: a (tasks_per_batch, n_support, d) Tensor.
625 | support_labels: a (tasks_per_batch, n_support) Tensor.
626 | n_way: a scalar. Represents the number of classes in a few-shot classification task.
627 | n_shot: a scalar. Represents the number of support examples given per class.
628 | C_reg: a scalar. Represents the cost parameter C in SVM.
629 | Returns: a (tasks_per_batch, n_query, n_way) Tensor.
630 | """
631 |
632 | tasks_per_batch = query.size(0)
633 | n_support = support.size(1)
634 | n_query = query.size(1)
635 |
636 | assert(query.dim() == 3)
637 | assert(support.dim() == 3)
638 | assert(query.size(0) == support.size(0) and query.size(2) == support.size(2))
639 | assert(n_support == n_way * n_shot) # n_support must equal to n_way * n_shot
640 |
641 | #Here we solve the dual problem:
642 | #Note that the classes are indexed by m & samples are indexed by i.
643 | #min_{\alpha} 0.5 \sum_m ||w_m(\alpha)||^2 + \sum_i \sum_m e^m_i alpha^m_i
644 | #s.t. \alpha^m_i <= C^m_i \forall m,i , \sum_m \alpha^m_i=0 \forall i
645 |
646 | #where w_m(\alpha) = \sum_i \alpha^m_i x_i,
647 | #and C^m_i = C if m = y_i,
648 | #C^m_i = 0 if m != y_i.
649 | #This borrows the notation of liblinear.
650 |
651 | #\alpha is an (n_support, n_way) matrix
652 | kernel_matrix = computeGramMatrix(support, support)
653 |
654 | id_matrix_0 = torch.eye(n_way).expand(tasks_per_batch, n_way, n_way).cuda()
655 | block_kernel_matrix = batched_kronecker(kernel_matrix, id_matrix_0)
656 | #This seems to help avoid PSD error from the QP solver.
657 | block_kernel_matrix += 1.0 * torch.eye(n_way*n_support).expand(tasks_per_batch, n_way*n_support, n_way*n_support).cuda()
658 |
659 |     support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way) # (tasks_per_batch * n_support, n_way)
660 | support_labels_one_hot = support_labels_one_hot.view(tasks_per_batch, n_support, n_way)
661 | support_labels_one_hot = support_labels_one_hot.reshape(tasks_per_batch, n_support * n_way)
662 |
663 | G = block_kernel_matrix
664 | e = -1.0 * support_labels_one_hot
665 | #print (G.size())
666 | #This part is for the inequality constraints:
667 | #\alpha^m_i <= C^m_i \forall m,i
668 | #where C^m_i = C if m = y_i,
669 | #C^m_i = 0 if m != y_i.
670 | id_matrix_1 = torch.eye(n_way * n_support).expand(tasks_per_batch, n_way * n_support, n_way * n_support)
671 | C = Variable(id_matrix_1)
672 | h = Variable(C_reg * support_labels_one_hot)
673 | #print (C.size(), h.size())
674 | #This part is for the equality constraints:
675 | #\sum_m \alpha^m_i=0 \forall i
676 | id_matrix_2 = torch.eye(n_support).expand(tasks_per_batch, n_support, n_support).cuda()
677 |
678 | A = Variable(batched_kronecker(id_matrix_2, torch.ones(tasks_per_batch, 1, n_way).cuda()))
679 | b = Variable(torch.zeros(tasks_per_batch, n_support))
680 | #print (A.size(), b.size())
681 | if double_precision:
682 | G, e, C, h, A, b = [x.double().cuda() for x in [G, e, C, h, A, b]]
683 | else:
684 | G, e, C, h, A, b = [x.float().cuda() for x in [G, e, C, h, A, b]]
685 |
686 |     # Solve the following QP to fit SVM (a toy shape check of G and e follows this function):
687 | # \hat z = argmin_z 1/2 z^T G z + e^T z
688 | # subject to Cz <= h
689 | # We use detach() to prevent backpropagation to fixed variables.
690 | qp_sol = QPFunction(verbose=False, maxIter=maxIter)(G, e.detach(), C.detach(), h.detach(), A.detach(), b.detach())
691 |
692 | # Compute the classification score.
693 | compatibility = computeGramMatrix(support, query)
694 | compatibility = compatibility.float()
695 | compatibility = compatibility.unsqueeze(3).expand(tasks_per_batch, n_support, n_query, n_way)
696 | qp_sol = qp_sol.reshape(tasks_per_batch, n_support, n_way)
697 | logits = qp_sol.float().unsqueeze(2).expand(tasks_per_batch, n_support, n_query, n_way)
698 | logits = logits * compatibility
699 | logits = torch.sum(logits, 1)
700 |
701 | return logits
702 |
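# --- Illustrative sketch (added commentary, not part of the original model) ---
# A CPU toy check of the tensor shapes fed to the Crammer-Singer QP above.
# torch.kron is used to mimic the per-task Kronecker product that
# batched_kronecker (defined earlier in this file) is assumed to compute;
# the qpth solver itself is not invoked. The helper name is hypothetical.
def _demo_svm_cs_qp_shapes(tasks_per_batch=2, n_way=3, n_shot=2, d=4):
    import torch
    n_support = n_way * n_shot
    support = torch.randn(tasks_per_batch, n_support, d)
    K = torch.bmm(support, support.transpose(1, 2))                # Gram matrix of the support set
    I_way = torch.eye(n_way)
    # Per-task K kron I_way gives one (n_way * n_support)^2 block kernel per task.
    G = torch.stack([torch.kron(K[b], I_way) for b in range(tasks_per_batch)])
    labels = torch.arange(n_way).repeat(n_shot)                    # toy support labels for a single task layout
    e = -torch.nn.functional.one_hot(labels, n_way).float().reshape(1, -1).repeat(tasks_per_batch, 1)
    assert G.shape == (tasks_per_batch, n_way * n_support, n_way * n_support)
    assert e.shape == (tasks_per_batch, n_way * n_support)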
703 | def MetaOptNetHead_SVM_WW(query, support, support_labels, n_way, n_shot, C_reg=0.00001, double_precision=False):
704 | """
705 | Fits the support set with multi-class SVM and
706 | returns the classification score on the query set.
707 |
708 | This is the multi-class SVM presented in:
709 | Support Vector Machines for Multi Class Pattern Recognition
710 | (Weston and Watkins, ESANN 1999).
711 |
712 | Parameters:
713 | query: a (tasks_per_batch, n_query, d) Tensor.
714 | support: a (tasks_per_batch, n_support, d) Tensor.
715 | support_labels: a (tasks_per_batch, n_support) Tensor.
716 | n_way: a scalar. Represents the number of classes in a few-shot classification task.
717 | n_shot: a scalar. Represents the number of support examples given per class.
718 | C_reg: a scalar. Represents the cost parameter C in SVM.
719 | Returns: a (tasks_per_batch, n_query, n_way) Tensor.
720 | """
738 | tasks_per_batch = query.size(0)
739 | n_support = support.size(1)
740 | n_query = query.size(1)
741 |
742 | assert(query.dim() == 3)
743 | assert(support.dim() == 3)
744 | assert(query.size(0) == support.size(0) and query.size(2) == support.size(2))
745 |     assert(n_support == n_way * n_shot)      # n_support must equal n_way * n_shot
746 |
747 | #In theory, \alpha is an (n_support, n_way) matrix
748 | #NOTE: In this implementation, we solve for a flattened vector of size (n_way*n_support)
749 | #In order to turn it into a matrix, you must first reshape it into an (n_way, n_support) matrix
750 |     #then transpose it, resulting in an (n_support, n_way) matrix (a toy reshape check follows this function)
751 | kernel_matrix = computeGramMatrix(support, support) + torch.ones(tasks_per_batch, n_support, n_support).cuda()
752 |
753 | id_matrix_0 = torch.eye(n_way).expand(tasks_per_batch, n_way, n_way).cuda()
754 | block_kernel_matrix = batched_kronecker(id_matrix_0, kernel_matrix)
755 |
756 | kernel_matrix_mask_x = support_labels.reshape(tasks_per_batch, n_support, 1).expand(tasks_per_batch, n_support, n_support)
757 | kernel_matrix_mask_y = support_labels.reshape(tasks_per_batch, 1, n_support).expand(tasks_per_batch, n_support, n_support)
758 | kernel_matrix_mask = (kernel_matrix_mask_x == kernel_matrix_mask_y).float()
759 |
760 | block_kernel_matrix_inter = kernel_matrix_mask * kernel_matrix
761 | block_kernel_matrix += block_kernel_matrix_inter.repeat(1, n_way, n_way)
762 |
763 | kernel_matrix_mask_second_term = support_labels.reshape(tasks_per_batch, n_support, 1).expand(tasks_per_batch, n_support, n_support * n_way)
764 | kernel_matrix_mask_second_term = kernel_matrix_mask_second_term == torch.arange(n_way).long().repeat(n_support).reshape(n_support, n_way).transpose(1, 0).reshape(1, -1).repeat(n_support, 1).cuda()
765 | kernel_matrix_mask_second_term = kernel_matrix_mask_second_term.float()
766 |
767 | block_kernel_matrix -= (2.0 - 1e-4) * (kernel_matrix_mask_second_term * kernel_matrix.repeat(1, 1, n_way)).repeat(1, n_way, 1)
768 |
769 | Y_support = one_hot(support_labels.view(tasks_per_batch * n_support), n_way)
770 | Y_support = Y_support.view(tasks_per_batch, n_support, n_way)
771 | Y_support = Y_support.transpose(1, 2) # (tasks_per_batch, n_way, n_support)
772 | Y_support = Y_support.reshape(tasks_per_batch, n_way * n_support)
773 |
774 | G = block_kernel_matrix
775 |
776 | e = -2.0 * torch.ones(tasks_per_batch, n_way * n_support)
777 | id_matrix = torch.eye(n_way * n_support).expand(tasks_per_batch, n_way * n_support, n_way * n_support)
778 |
779 | C_mat = C_reg * torch.ones(tasks_per_batch, n_way * n_support).cuda() - C_reg * Y_support
780 |
781 | C = Variable(torch.cat((id_matrix, -id_matrix), 1))
782 | #C = Variable(torch.cat((id_matrix_masked, -id_matrix_masked), 1))
783 | zer = torch.zeros(tasks_per_batch, n_way * n_support).cuda()
784 |
785 | h = Variable(torch.cat((C_mat, zer), 1))
786 |
787 | dummy = Variable(torch.Tensor()).cuda() # We want to ignore the equality constraint.
788 |
789 | if double_precision:
790 | G, e, C, h = [x.double().cuda() for x in [G, e, C, h]]
791 | else:
792 | G, e, C, h = [x.cuda() for x in [G, e, C, h]]
793 |
794 | # Solve the following QP to fit SVM:
795 | # \hat z = argmin_z 1/2 z^T G z + e^T z
796 | # subject to Cz <= h
797 | # We use detach() to prevent backpropagation to fixed variables.
798 | #qp_sol = QPFunction(verbose=False)(G, e.detach(), C.detach(), h.detach(), dummy.detach(), dummy.detach())
799 | qp_sol = QPFunction(verbose=False)(G, e, C, h, dummy.detach(), dummy.detach())
800 |
801 | # Compute the classification score.
802 | compatibility = computeGramMatrix(support, query) + torch.ones(tasks_per_batch, n_support, n_query).cuda()
803 | compatibility = compatibility.float()
804 | compatibility = compatibility.unsqueeze(1).expand(tasks_per_batch, n_way, n_support, n_query)
805 | qp_sol = qp_sol.float()
806 | qp_sol = qp_sol.reshape(tasks_per_batch, n_way, n_support)
807 | A_i = torch.sum(qp_sol, 1) # (tasks_per_batch, n_support)
808 | A_i = A_i.unsqueeze(1).expand(tasks_per_batch, n_way, n_support)
809 | qp_sol = qp_sol.float().unsqueeze(3).expand(tasks_per_batch, n_way, n_support, n_query)
810 | Y_support_reshaped = Y_support.reshape(tasks_per_batch, n_way, n_support)
811 | Y_support_reshaped = A_i * Y_support_reshaped
812 | Y_support_reshaped = Y_support_reshaped.unsqueeze(3).expand(tasks_per_batch, n_way, n_support, n_query)
813 | logits = (Y_support_reshaped - qp_sol) * compatibility
814 |
815 | logits = torch.sum(logits, 2)
816 |
817 | return logits.transpose(1, 2)
818 |
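# --- Illustrative sketch (added commentary, not part of the original model) ---
# The NOTE in the WW head above says the flat (n_way * n_support) dual solution
# becomes an (n_support, n_way) matrix by reshaping to (n_way, n_support) and
# transposing. This toy check just exercises that class-major layout; the helper
# name is hypothetical and the function is never called by the training code.
def _demo_ww_alpha_reshape(n_way=3, n_shot=2):
    import torch
    n_support = n_way * n_shot
    flat_alpha = torch.arange(n_way * n_support, dtype=torch.float32)   # stand-in for one task's qp_sol
    alpha = flat_alpha.reshape(n_way, n_support).t()                    # (n_support, n_way)
    assert alpha.shape == (n_support, n_way)
    assert alpha[0, 1] == flat_alpha[1 * n_support + 0]                 # entry (support i=0, class m=1)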
819 | class ClassificationHead(nn.Module):
820 | def __init__(self, base_learner='MetaOptNet', enable_scale=True):
821 | super(ClassificationHead, self).__init__()
822 | if ('SVM-CS' in base_learner):
823 | self.head = MetaOptNetHead_SVM_CS
824 | elif ('Ridge' in base_learner):
825 | self.head = MetaOptNetHead_Ridge
826 | elif ('R2D2' in base_learner):
827 | self.head = R2D2Head
828 | elif ('Proto' in base_learner):
829 | self.head = ProtoNetHead
830 | elif ('Local' in base_learner):
831 | self.head = LocalNetHead
832 | elif ('Cosine' in base_learner):
833 | self.head = CosineNetHead
834 | elif ('SVM-He' in base_learner):
835 | self.head = MetaOptNetHead_SVM_He
836 | elif ('SVM-WW' in base_learner):
837 | self.head = MetaOptNetHead_SVM_WW
838 | elif ('FFT' in base_learner):
839 | self.head = FFTHead
840 | else:
841 | print ("Cannot recognize the base learner type")
842 | assert(False)
843 |
844 | # Add a learnable scale
845 | self.enable_scale = enable_scale
846 | self.scale = nn.Parameter(torch.FloatTensor([1.0]))
847 |
848 | def forward(self, query, support, support_labels, n_way, n_shot, **kwargs):
849 | if self.enable_scale:
850 | return self.scale * self.head(query, support, support_labels, n_way, n_shot, **kwargs)
851 | else:
852 |             return self.head(query, support, support_labels, n_way, n_shot, **kwargs) * 10  # fixed scale of 10 when the learnable scale is disabled
853 |
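# --- Usage sketch (added commentary, not part of the original model) ---
# How ClassificationHead is typically called, with shapes taken from the
# docstrings in this file. The SVM heads allocate CUDA tensors internally, so
# this sketch assumes a CUDA device and qpth are available; the helper name is
# hypothetical and the function is never called by the training code.
def _demo_classification_head(tasks_per_batch=2, n_way=5, n_shot=1, n_query=15, d=16):
    import torch
    n_support = n_way * n_shot
    head = ClassificationHead(base_learner='SVM-CS').cuda()
    query = torch.randn(tasks_per_batch, n_query, d).cuda()
    support = torch.randn(tasks_per_batch, n_support, d).cuda()
    support_labels = torch.arange(n_way).repeat(n_shot).unsqueeze(0).repeat(tasks_per_batch, 1).cuda()
    logits = head(query, support, support_labels, n_way, n_shot)        # (tasks_per_batch, n_query, n_way)
    assert logits.shape == (tasks_per_batch, n_query, n_way)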
--------------------------------------------------------------------------------