├── data ├── __init__.py └── mini_imagenet.py ├── models ├── __init__.py ├── attention.py ├── resnet12_2.py ├── diffusion_grad.py └── classification_heads.py ├── __init__.py ├── utils.py ├── README.md ├── metadiff_test.py ├── cal_optial_prototypes.py └── metadiff_train.py /data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # Implement your code here. 2 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import pprint 4 | import torch 5 | 6 | def set_gpu(x): 7 | os.environ['CUDA_VISIBLE_DEVICES'] = x 8 | print('using gpu:', x) 9 | 10 | def check_dir(path): 11 | ''' 12 | Create directory if it does not exist. 13 | path: Path of directory. 14 | ''' 15 | if not os.path.exists(path): 16 | os.mkdir(path) 17 | 18 | def count_accuracy(logits, label): 19 | pred = torch.argmax(logits, dim=1).view(-1) 20 | label = label.view(-1) 21 | accuracy = 100 * pred.eq(label).float().mean() 22 | return accuracy 23 | 24 | class Timer(): 25 | def __init__(self): 26 | self.o = time.time() 27 | 28 | def measure(self, p=1): 29 | x = (time.time() - self.o) / float(p) 30 | # x = int(x) 31 | if x >= 3600: 32 | return '{:.1f}h'.format(x / 3600) 33 | if x >= 60: 34 | return '{}m'.format(round(x / 60)) 35 | return '{}s'.format(x) 36 | 37 | def log(log_file_path, string): 38 | ''' 39 | Write one line of log into screen and file. 40 | log_file_path: Path of log file. 41 | string: String to write in log file. 42 | ''' 43 | with open(log_file_path, 'a+') as f: 44 | f.write(string + '\n') 45 | f.flush() 46 | print(string) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MetaDiff: Meta-Learning with Conditional Diffusion for Few-Shot Learning 2 | This repository contains the code for the paper: 3 |
4 | [**MetaDiff: Meta-Learning with Conditional Diffusion for Few-Shot Learning**](https://arxiv.org/pdf/2307.16424.pdf) 5 |
6 | Baoquan Zhang, Chuyao Luo, Demin Yu, Huiwei Lin, Xutao Li, Yunming Ye, Bowen Zhang 7 |
8 | AAAI 2024
9 |
10 | ### Abstract
11 |
12 | Equipping a deep model with the ability of few-shot learning, i.e., learning quickly from only a few examples, is a core challenge for artificial intelligence. Gradient-based meta-learning approaches address this challenge by learning how to learn novel tasks. Their key idea is to train a deep model in a bi-level optimization manner, where the outer-loop process learns a shared gradient descent algorithm (i.e., its hyperparameters), while the inner-loop process leverages it to optimize a task-specific model using only a few labeled examples. Although these methods have shown superior performance, the outer-loop process requires calculating second-order derivatives along the inner optimization path, which imposes a considerable memory burden and the risk of vanishing gradients. Drawing inspiration from recent progress in diffusion models, we find that the inner-loop gradient descent process can be viewed as a reverse (i.e., denoising) diffusion process in which the target of denoising is the model weights rather than the original data. Based on this observation, we propose to model the gradient descent optimizer as a diffusion model and present a novel task-conditional diffusion-based meta-learning method, called MetaDiff, which models the optimization of model weights from Gaussian noise to target weights in a denoising manner. Thanks to the training efficiency of diffusion models, MetaDiff does not need to differentiate through the inner-loop path, so the memory burden and the risk of vanishing gradients are effectively alleviated. Experimental results show that MetaDiff outperforms state-of-the-art gradient-based meta-learning methods on few-shot learning tasks. 
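To make the weights-as-denoising-target view concrete, below is a minimal, self-contained sketch (illustrative only — the toy `ToyWeightDenoiser`, the simplified update rule, and all tensor shapes are assumptions made here for exposition; the repository's actual conditional UNet and sampler live in `models/diffusion_grad.py`). Starting from Gaussian noise, task-specific classifier weights are produced by repeated denoising steps conditioned on the support set, with no inner-loop gradient descent to differentiate through.

```python
import torch
import torch.nn as nn

class ToyWeightDenoiser(nn.Module):
    """Toy noise predictor: (noisy class weight, class prototype, timestep) -> predicted noise."""
    def __init__(self, feat_dim=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2 * feat_dim + 1, 256),
            nn.ReLU(),
            nn.Linear(256, feat_dim),
        )

    def forward(self, w_t, t, prototypes):
        # w_t, prototypes: (n_way, feat_dim); t: float in (0, 1]
        t_col = torch.full((w_t.size(0), 1), float(t), device=w_t.device)
        return self.net(torch.cat([w_t, prototypes, t_col], dim=-1))


@torch.no_grad()
def sample_task_weights(denoiser, emb_support, labels_support, n_way, n_steps=50):
    """Reverse (denoising) process: Gaussian noise -> task-specific classifier weights."""
    feat_dim = emb_support.size(-1)
    # Class prototypes of the support set act as the task condition.
    prototypes = torch.stack(
        [emb_support[labels_support == c].mean(dim=0) for c in range(n_way)]
    )
    w = torch.randn(n_way, feat_dim, device=emb_support.device)  # w_T ~ N(0, I)
    for step in reversed(range(n_steps)):
        t = (step + 1) / n_steps
        eps_hat = denoiser(w, t, prototypes)  # predicted noise at this step
        w = w - eps_hat / n_steps             # simplified denoising update
    return w                                  # plays the role of the optimized weights


# Usage with a toy 5-way 1-shot episode of 512-d features:
denoiser = ToyWeightDenoiser(feat_dim=512)
emb_support = torch.randn(5, 512)       # one embedding per support image
labels_support = torch.arange(5)        # class labels 0..4
emb_query = torch.randn(75, 512)        # 15 queries per class
weights = sample_task_weights(denoiser, emb_support, labels_support, n_way=5)
logits = emb_query @ weights.t()        # (75, 5) query logits
```

In this repository, the analogous call is `cls_head.sample(emb_query, emb_support, labels_support, labels_query, way, shot)` in `metadiff_test.py`, which performs the conditioned reverse process and returns the query logits directly.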
13 | 14 | ### Citation 15 | 16 | If you use this code for your research, please cite our paper: 17 | ``` 18 | @inproceedings{zhang2022metadiff, 19 | author = {Zhang, Baoquan and Luo, Chuyao and Yu, Demin and Lin, Huiwei and Li, Xutao and Ye, Yunming and Zhang, Bowen}, 20 | title = {MetaDiff: Meta-Learning with Conditional Diffusion for Few-Shot Learning}, 21 | booktitle = {AAAI}, 22 | year = {2024}, 23 | } 24 | ``` 25 | -------------------------------------------------------------------------------- /models/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import math 5 | 6 | 7 | class SelfAttention(nn.Module): 8 | def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True): 9 | super().__init__() 10 | self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias) 11 | self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias) 12 | self.n_heads = n_heads 13 | self.d_head = d_embed // n_heads 14 | 15 | def forward(self, x, causal_mask=False): 16 | input_shape = x.shape 17 | batch_size, sequence_length, d_embed = input_shape 18 | interim_shape = (batch_size, sequence_length, self.n_heads, self.d_head) 19 | 20 | q, k, v = self.in_proj(x).chunk(3, dim=-1) 21 | 22 | q = q.view(interim_shape).transpose(1, 2) 23 | k = k.view(interim_shape).transpose(1, 2) 24 | v = v.view(interim_shape).transpose(1, 2) 25 | 26 | weight = q @ k.transpose(-1, -2) 27 | if causal_mask: 28 | mask = torch.ones_like(weight, dtype=torch.bool).triu(1) 29 | weight.masked_fill_(mask, -torch.inf) 30 | weight /= math.sqrt(self.d_head) 31 | weight = F.softmax(weight, dim=-1) 32 | 33 | output = weight @ v 34 | output = output.transpose(1, 2) 35 | output = output.reshape(input_shape) 36 | output = self.out_proj(output) 37 | return output 38 | 39 | class CrossAttention(nn.Module): 40 | def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True): 41 | super().__init__() 42 | self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias) 43 | self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias) 44 | self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias) 45 | self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias) 46 | self.n_heads = n_heads 47 | self.d_head = d_embed // n_heads 48 | 49 | def forward(self, x, y): 50 | input_shape = x.shape 51 | batch_size, sequence_length, d_embed = input_shape 52 | interim_shape = (batch_size, -1, self.n_heads, self.d_head) 53 | 54 | q = self.q_proj(x) 55 | k = self.k_proj(y) 56 | v = self.v_proj(y) 57 | 58 | q = q.view(interim_shape).transpose(1, 2) 59 | k = k.view(interim_shape).transpose(1, 2) 60 | v = v.view(interim_shape).transpose(1, 2) 61 | 62 | weight = q @ k.transpose(-1, -2) 63 | weight /= math.sqrt(self.d_head) 64 | weight = F.softmax(weight, dim=-1) 65 | 66 | output = weight @ v 67 | output = output.transpose(1, 2).contiguous() 68 | output = output.view(input_shape) 69 | output = self.out_proj(output) 70 | return output -------------------------------------------------------------------------------- /models/resnet12_2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | def conv3x3(in_planes, out_planes): 4 | return nn.Conv2d(in_planes, out_planes, 3, padding=1, bias=False) 5 | 6 | 7 | def conv1x1(in_planes, out_planes): 8 | return nn.Conv2d(in_planes, out_planes, 1, bias=False) 9 | 10 | 11 | def norm_layer(planes): 12 | 
return nn.BatchNorm2d(planes) 13 | 14 | 15 | class Block(nn.Module): 16 | 17 | def __init__(self, inplanes, planes, downsample, use_relu=True): 18 | super().__init__() 19 | 20 | self.use_relu = use_relu 21 | self.relu = nn.LeakyReLU(0.1) 22 | 23 | self.conv1 = conv3x3(inplanes, planes) 24 | self.bn1 = norm_layer(planes) 25 | self.conv2 = conv3x3(planes, planes) 26 | self.bn2 = norm_layer(planes) 27 | self.conv3 = conv3x3(planes, planes) 28 | self.bn3 = norm_layer(planes) 29 | 30 | self.downsample = downsample 31 | 32 | self.maxpool = nn.MaxPool2d(2) 33 | 34 | def forward(self, x): 35 | out = self.conv1(x) 36 | out = self.bn1(out) 37 | out = self.relu(out) 38 | 39 | out = self.conv2(out) 40 | out = self.bn2(out) 41 | out = self.relu(out) 42 | 43 | out = self.conv3(out) 44 | out = self.bn3(out) 45 | 46 | identity = self.downsample(x) 47 | 48 | out += identity 49 | if self.use_relu: 50 | out = self.relu(out) 51 | 52 | out = self.maxpool(out) 53 | 54 | return out 55 | 56 | 57 | class ResNet12(nn.Module): 58 | 59 | def __init__(self, channels): 60 | super().__init__() 61 | 62 | self.inplanes = 3 63 | 64 | self.layer1 = self._make_layer(channels[0]) 65 | self.layer2 = self._make_layer(channels[1]) 66 | self.layer3 = self._make_layer(channels[2]) 67 | self.layer4 = self._make_layer(channels[3], use_relu=False) 68 | 69 | self.out_dim = channels[3] 70 | 71 | for m in self.modules(): 72 | if isinstance(m, nn.Conv2d): 73 | nn.init.kaiming_normal_(m.weight, mode='fan_out', 74 | nonlinearity='leaky_relu') 75 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 76 | nn.init.constant_(m.weight, 1) 77 | nn.init.constant_(m.bias, 0) 78 | 79 | def _make_layer(self, planes, use_relu=True): 80 | downsample = nn.Sequential( 81 | conv1x1(self.inplanes, planes), 82 | norm_layer(planes), 83 | ) 84 | block = Block(self.inplanes, planes, downsample, use_relu=use_relu) 85 | self.inplanes = planes 86 | return block 87 | 88 | def forward(self, x, use_pool=True): 89 | x = self.layer1(x) 90 | x = self.layer2(x) 91 | x = self.layer3(x) 92 | x = self.layer4(x) 93 | if use_pool: 94 | x = x.view(x.shape[0], x.shape[1], -1).mean(dim=2) 95 | return x 96 | 97 | def resnet12(): 98 | return ResNet12([64, 128, 256, 512]) 99 | 100 | def resnet12_wide(): 101 | return ResNet12([64, 160, 320, 640]) 102 | -------------------------------------------------------------------------------- /metadiff_test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import argparse 4 | import random 5 | import numpy as np 6 | from tqdm import tqdm 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.utils.data import DataLoader 11 | from torch.autograd import Variable 12 | 13 | from models.resnet12_2 import resnet12 14 | from models.diffusion_grad import get_diffusion_from_args 15 | 16 | from utils import set_gpu, Timer, count_accuracy, check_dir, log 17 | 18 | def one_hot(indices, depth): 19 | """ 20 | Returns a one-hot tensor. 21 | This is a PyTorch equivalent of Tensorflow's tf.one_hot. 22 | 23 | Parameters: 24 | indices: a (n_batch, m) Tensor or (m) Tensor. 25 | depth: a scalar. Represents the depth of the one hot dimension. 26 | Returns: a (n_batch, m, depth) Tensor or (m, depth) Tensor. 
27 | """ 28 | 29 | encoded_indicies = torch.zeros(indices.size() + torch.Size([depth])).cuda() 30 | index = indices.view(indices.size()+torch.Size([1])) 31 | encoded_indicies = encoded_indicies.scatter_(1,index,1) 32 | 33 | return encoded_indicies 34 | 35 | def get_model(options): 36 | # Choose the embedding network 37 | if options.network == 'ResNet': 38 | network = resnet12().cuda() 39 | network = torch.nn.DataParallel(network) 40 | fea_dim = 512 41 | else: 42 | print ("Cannot recognize the network type") 43 | assert(False) 44 | # Choose the classification head 45 | cls_head = get_diffusion_from_args().cuda() 46 | return (network, cls_head) 47 | 48 | def get_dataset(options): 49 | # Choose the embedding network 50 | if options.dataset == 'miniImageNet': 51 | from data.mini_imagenet import MiniImageNet, FewShotDataloader 52 | dataset_test = MiniImageNet(phase='test') 53 | data_loader = FewShotDataloader 54 | else: 55 | print ("Cannot recognize the dataset type") 56 | assert(False) 57 | 58 | return (dataset_test, data_loader) 59 | 60 | def seed_torch(seed=21): 61 | os.environ['PYTHONHASHSEED'] = str(seed) 62 | random.seed(seed) 63 | np.random.seed(seed) 64 | torch.manual_seed(seed) 65 | torch.cuda.manual_seed(seed) 66 | torch.cuda.manual_seed_all(seed) 67 | # torch.backends.cudnn.deterministic = True 68 | torch.backends.cudnn.deterministic = True #cpu/gpu结果一致 69 | torch.backends.cudnn.benchmark = False #训练集变化不大时使训练加速 70 | 71 | 72 | def test(opt, dataset_test, data_loader): 73 | # Dataloader of Gidaris & Komodakis (CVPR 2018) 74 | dloader_test = data_loader( 75 | dataset=dataset_test, 76 | nKnovel=opt.test_way, 77 | nKbase=0, 78 | nExemplars=opt.val_shot, # num training examples per novel category 79 | nTestNovel=opt.val_query * opt.test_way, # num test examples for all the novel categories 80 | nTestBase=0, # num test examples for all the base categories 81 | batch_size=1, 82 | num_workers=0, 83 | epoch_size=1 * opt.val_episode, # num of batches per epoch 84 | ) 85 | 86 | set_gpu(opt.gpu) 87 | check_dir('./experiments/') 88 | check_dir(opt.save_path) 89 | 90 | log_file_path = os.path.join(opt.save_path, "train_log.txt") 91 | log(log_file_path, str(vars(opt))) 92 | 93 | (embedding_net, cls_head) = get_model(opt) 94 | # Load saved model checkpoints 95 | saved_models = torch.load(os.path.join(opt.save_path, 'model.pth')) 96 | embedding_net.load_state_dict(saved_models['embedding']) 97 | embedding_net.eval() 98 | cls_head.load_state_dict(saved_models['head']) 99 | cls_head.eval() 100 | 101 | max_val_acc = 0.0 102 | max_test_acc = 0.0 103 | 104 | timer = Timer() 105 | x_entropy = torch.nn.CrossEntropyLoss() 106 | 107 | # Evaluate on the validation split 108 | _, _ = [x.eval() for x in (cls_head, embedding_net)] 109 | 110 | test_accuracies = [] 111 | test_losses = [] 112 | for i, batch in enumerate(tqdm(dloader_test()), 1): 113 | data_support, labels_support, data_query, labels_query, _, _ = [x.cuda() for x in batch] 114 | 115 | test_n_support = opt.test_way * opt.val_shot 116 | test_n_query = opt.test_way * opt.val_query 117 | 118 | with torch.no_grad(): 119 | emb_support = embedding_net(data_support.reshape([-1] + list(data_support.shape[-3:]))) 120 | emb_support = emb_support.reshape(1, test_n_support, -1) 121 | # emb_support = F.normalize(emb_support, dim=-1) 122 | emb_query = embedding_net(data_query.reshape([-1] + list(data_query.shape[-3:]))) 123 | emb_query = emb_query.reshape(1, test_n_query, -1) 124 | # emb_query = F.normalize(emb_query, dim=-1) 125 | 126 | logit_query = 
cls_head.sample(emb_query, emb_support, labels_support, labels_query, opt.test_way, opt.val_shot) 127 | loss = x_entropy(logit_query.reshape(-1, opt.test_way), labels_query.reshape(-1)) 128 | acc = count_accuracy(logit_query.reshape(-1, opt.test_way), labels_query.reshape(-1)) 129 | 130 | test_accuracies.append(acc.item()) 131 | test_losses.append(loss.item()) 132 | 133 | test_acc_avg = np.mean(np.array(test_accuracies)) 134 | test_acc_ci95 = 1.96 * np.std(np.array(test_accuracies)) / np.sqrt(opt.val_episode) 135 | 136 | test_loss_avg = np.mean(np.array(test_losses)) 137 | 138 | if test_acc_avg > max_test_acc: 139 | max_test_acc = test_acc_avg 140 | log(log_file_path, 'Test Loss: {:.4f}\tAccuracy: {:.2f} ± {:.2f} % (Best)' \ 141 | .format(test_loss_avg, test_acc_avg, test_acc_ci95)) 142 | else: 143 | log(log_file_path, 'Test Loss: {:.4f}\tAccuracy: {:.2f} ± {:.2f} %' \ 144 | .format(test_loss_avg, test_acc_avg, test_acc_ci95)) 145 | 146 | if __name__ == '__main__': 147 | seed_torch(21) 148 | parser = argparse.ArgumentParser() 149 | parser.add_argument('--num-epoch', type=int, default=1000, 150 | help='number of training epochs') 151 | parser.add_argument('--val-shot', type=int, default=1, 152 | help='number of support examples per validation class') 153 | parser.add_argument('--val-episode', type=int, default=600, 154 | help='number of episodes per validation') 155 | parser.add_argument('--val-query', type=int, default=15, 156 | help='number of query examples per validation class') 157 | parser.add_argument('--test-way', type=int, default=5, 158 | help='number of classes in one test (or validation) episode') 159 | parser.add_argument('--save-path', default='./experiments/exp_1') 160 | parser.add_argument('--gpu', default='0') 161 | parser.add_argument('--network', type=str, default='ResNet', 162 | help='choose which embedding network to use. ProtoNet, R2D2, ResNet') 163 | parser.add_argument('--dataset', type=str, default='miniImageNet', 164 | help='choose which classification head to use. miniImageNet, tieredImageNet, CIFAR_FS, FC100') 165 | 166 | opt = parser.parse_args() 167 | 168 | (dataset_test, data_loader) = get_dataset(opt) 169 | 170 | test(opt, dataset_test, data_loader) -------------------------------------------------------------------------------- /cal_optial_prototypes.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import argparse 4 | import random 5 | import numpy as np 6 | from tqdm import tqdm 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.utils.data import DataLoader 11 | from torch.autograd import Variable 12 | 13 | from models.classification_heads import ClassificationHead 14 | from models.R2D2_embedding import R2D2Embedding 15 | from models.protonet_embedding import ProtoNetEmbedding 16 | from models.resnet12_2 import resnet12 17 | from utils import set_gpu, Timer, count_accuracy, check_dir, log 18 | import pickle 19 | 20 | def one_hot(indices, depth): 21 | """ 22 | Returns a one-hot tensor. 23 | This is a PyTorch equivalent of Tensorflow's tf.one_hot. 24 | 25 | Parameters: 26 | indices: a (n_batch, m) Tensor or (m) Tensor. 27 | depth: a scalar. Represents the depth of the one hot dimension. 28 | Returns: a (n_batch, m, depth) Tensor or (m, depth) Tensor. 
29 | """ 30 | 31 | encoded_indicies = torch.zeros(indices.size() + torch.Size([depth])).cuda() 32 | index = indices.view(indices.size()+torch.Size([1])) 33 | encoded_indicies = encoded_indicies.scatter_(1,index,1) 34 | 35 | return encoded_indicies 36 | 37 | def get_model(options): 38 | # Choose the embedding network 39 | if options.network == 'ProtoNet': 40 | network = ProtoNetEmbedding().cuda() 41 | fea_dim = 64 42 | elif options.network == 'R2D2': 43 | network = R2D2Embedding().cuda() 44 | elif options.network == 'ResNet': 45 | network = resnet12().cuda() 46 | network = torch.nn.DataParallel(network) 47 | fea_dim = 512 48 | else: 49 | print ("Cannot recognize the network type") 50 | assert(False) 51 | # info_max_layer = InfoMaxLayer().cuda() 52 | return network 53 | 54 | def get_dataset(options): 55 | # Choose the embedding network 56 | if options.dataset == 'miniImageNet': 57 | from data.mini_imagenet import MiniImageNet, FewShotDataloader 58 | dataset_train = MiniImageNet(phase='train') 59 | dataset_val = MiniImageNet(phase='val') 60 | dataset_test = MiniImageNet(phase='test') 61 | data_loader = FewShotDataloader 62 | else: 63 | print ("Cannot recognize the dataset type") 64 | assert(False) 65 | 66 | return (dataset_train, dataset_val, dataset_test, data_loader) 67 | 68 | if __name__ == '__main__': 69 | parser = argparse.ArgumentParser() 70 | parser.add_argument('--num-epoch', type=int, default=100, 71 | help='number of training epochs') 72 | parser.add_argument('--save-epoch', type=int, default=20, 73 | help='frequency of model saving') 74 | parser.add_argument('--train-shot', type=int, default=1, 75 | help='number of support examples per training class') 76 | parser.add_argument('--val-shot', type=int, default=1, 77 | help='number of support examples per validation class') 78 | parser.add_argument('--train-query', type=int, default=15, 79 | help='number of query examples per training class') 80 | parser.add_argument('--val-episode', type=int, default=600, 81 | help='number of episodes per validation') 82 | parser.add_argument('--val-query', type=int, default=15, 83 | help='number of query examples per validation class') 84 | parser.add_argument('--train-way', type=int, default=5, 85 | help='number of classes in one training episode') 86 | parser.add_argument('--test-way', type=int, default=5, 87 | help='number of classes in one test (or validation) episode') 88 | parser.add_argument('--save-path', default='./experiments/exp_1') 89 | parser.add_argument('--gpu', default='0') 90 | parser.add_argument('--network', type=str, default='ResNet', 91 | help='choose which embedding network to use. ProtoNet, R2D2, ResNet') 92 | parser.add_argument('--head', type=str, default='CosineNet', 93 | help='choose which classification head to use. ProtoNet, Ridge, R2D2, SVM') 94 | parser.add_argument('--pre_head', type=str, default='add_margin', 95 | help='choose which classification head to use. ProtoNet, Ridge, R2D2, SVM') 96 | parser.add_argument('--dataset', type=str, default='miniImageNet', 97 | help='choose which classification head to use. 
miniImageNet, tieredImageNet, CIFAR_FS, FC100') 98 | parser.add_argument('--episodes-per-batch', type=int, default=8, 99 | help='number of episodes per batch') 100 | parser.add_argument('--eps', type=float, default=0.0, 101 | help='epsilon of label smoothing') 102 | 103 | opt = parser.parse_args() 104 | 105 | (dataset_train, dataset_val, dataset_test, data_loader) = get_dataset(opt) 106 | 107 | data_loader_pre = torch.utils.data.DataLoader 108 | # Dataloader of Gidaris & Komodakis (CVPR 2018) 109 | dloader_train = data_loader_pre( 110 | dataset=dataset_train, 111 | batch_size=1, 112 | shuffle=False, 113 | num_workers=0 114 | ) 115 | 116 | set_gpu(opt.gpu) 117 | check_dir('./experiments/') 118 | check_dir(opt.save_path) 119 | 120 | log_file_path = os.path.join(opt.save_path, "train_log.txt") 121 | log(log_file_path, str(vars(opt))) 122 | 123 | embedding_net = get_model(opt) 124 | 125 | # Load saved model checkpoints 126 | saved_models = torch.load(os.path.join(opt.save_path, 'best_pretrain_model_resnet.pth')) 127 | embedding_net.load_state_dict(saved_models['embedding']) 128 | embedding_net.eval() 129 | 130 | embs = [] 131 | labels = [] 132 | for i, batch in enumerate(tqdm(dloader_train), 1): 133 | data, label = [x.cuda() for x in batch] 134 | with torch.no_grad(): 135 | emb = embedding_net(data) 136 | embs.append(emb) 137 | labels.append(label) 138 | embs = torch.cat(embs, dim=0).cpu() 139 | labels = list(torch.cat(labels, dim=0).cpu().numpy()) 140 | label2index = {} 141 | for k, v in enumerate(labels): 142 | v = int(v) 143 | if v not in label2index.keys(): 144 | label2index[v] = [] 145 | label2index[v].append(k) 146 | else: 147 | label2index[v].append(k) 148 | 149 | label2optimal_proto = {} 150 | optimal_prototype = torch.zeros(64, 512).type_as(embs) 151 | for k, v in label2index.items(): 152 | sub_embs = embs[v, :] 153 | label2optimal_proto[k] = torch.mean(sub_embs, dim=0).unsqueeze(dim=0) 154 | optimal_prototype[k] = torch.mean(sub_embs, dim=0) 155 | 156 | database = {'dict': label2optimal_proto, 'array': optimal_prototype} 157 | with open(os.path.join(opt.save_path, "mini_imagenet_optimal_prototype.pickle"), 'wb') as handle: 158 | pickle.dump(database, handle, protocol=pickle.HIGHEST_PROTOCOL) 159 | 160 | with open(os.path.join(opt.save_path, "mini_imagenet_optimal_prototype.pickle"), 'rb') as handle: 161 | part_prior = pickle.load(handle) 162 | print(1) -------------------------------------------------------------------------------- /metadiff_train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import argparse 4 | import random 5 | import numpy as np 6 | from tqdm import tqdm 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.utils.data import DataLoader 11 | from torch.autograd import Variable 12 | 13 | # from models.classification_heads import ClassificationHead 14 | from models.R2D2_embedding import R2D2Embedding 15 | from models.protonet_embedding import ProtoNetEmbedding 16 | from models.resnet12_2 import resnet12 17 | # from models.protonet_metanet1 import ProtonetMetaLearner 18 | # from models.classifier import LinearClassifier, NNClassifier, distLinear 19 | # from models.PredTrainHead import LinearRotateHead, DCLHead, DistRotateHead 20 | from models.diffusion import get_diffusion_from_args 21 | 22 | from utils import set_gpu, Timer, count_accuracy, check_dir, log 23 | 24 | def one_hot(indices, depth): 25 | """ 26 | Returns a one-hot tensor. 
27 | This is a PyTorch equivalent of Tensorflow's tf.one_hot. 28 | 29 | Parameters: 30 | indices: a (n_batch, m) Tensor or (m) Tensor. 31 | depth: a scalar. Represents the depth of the one hot dimension. 32 | Returns: a (n_batch, m, depth) Tensor or (m, depth) Tensor. 33 | """ 34 | 35 | encoded_indicies = torch.zeros(indices.size() + torch.Size([depth])).cuda() 36 | index = indices.view(indices.size()+torch.Size([1])) 37 | encoded_indicies = encoded_indicies.scatter_(1,index,1) 38 | 39 | return encoded_indicies 40 | 41 | def get_model(options): 42 | # Choose the embedding network 43 | if options.network == 'ProtoNet': 44 | network = ProtoNetEmbedding().cuda() 45 | fea_dim = 64 46 | elif options.network == 'R2D2': 47 | network = R2D2Embedding().cuda() 48 | elif options.network == 'ResNet': 49 | network = resnet12().cuda() 50 | network = torch.nn.DataParallel(network) 51 | fea_dim = 512 52 | else: 53 | print ("Cannot recognize the network type") 54 | assert(False) 55 | 56 | 57 | # Choose the classification head 58 | 59 | cls_head = get_diffusion_from_args().cuda() 60 | return (network, cls_head) 61 | 62 | def get_dataset(options): 63 | # Choose the embedding network 64 | if options.dataset == 'miniImageNet': 65 | from data.mini_imagenet import MiniImageNet, FewShotDataloader 66 | dataset_train = MiniImageNet(phase='train') 67 | dataset_val = MiniImageNet(phase='val') 68 | data_loader = FewShotDataloader 69 | else: 70 | print ("Cannot recognize the dataset type") 71 | assert(False) 72 | 73 | return (dataset_train, dataset_val, data_loader) 74 | 75 | def seed_torch(seed=21): 76 | os.environ['PYTHONHASHSEED'] = str(seed) 77 | random.seed(seed) 78 | np.random.seed(seed) 79 | torch.manual_seed(seed) 80 | torch.cuda.manual_seed(seed) 81 | torch.cuda.manual_seed_all(seed) 82 | # torch.backends.cudnn.deterministic = True 83 | torch.backends.cudnn.deterministic = True #cpu/gpu结果一致 84 | torch.backends.cudnn.benchmark = False #训练集变化不大时使训练加速 85 | 86 | def train(opt, dataset_train, dataset_val, data_loader): 87 | # Dataloader of Gidaris & Komodakis (CVPR 2018) 88 | dloader_train = data_loader( 89 | dataset=dataset_train, 90 | nKnovel=opt.train_way, 91 | nKbase=0, 92 | nExemplars=opt.train_shot, # num training examples per novel category 93 | nTestNovel=opt.train_way * opt.train_query, # num test examples for all the novel categories 94 | nTestBase=0, # num test examples for all the base categories 95 | batch_size=opt.episodes_per_batch, 96 | num_workers=4, 97 | epoch_size=opt.episodes_per_batch * 10000, # num of batches per epoch 98 | ) 99 | 100 | dloader_val = data_loader( 101 | dataset=dataset_val, 102 | nKnovel=opt.test_way, 103 | nKbase=0, 104 | nExemplars=opt.val_shot, # num training examples per novel category 105 | nTestNovel=opt.val_query * opt.test_way, # num test examples for all the novel categories 106 | nTestBase=0, # num test examples for all the base categories 107 | batch_size=1, 108 | num_workers=0, 109 | epoch_size=1 * opt.val_episode, # num of batches per epoch 110 | ) 111 | 112 | set_gpu(opt.gpu) 113 | check_dir('./experiments/') 114 | check_dir(opt.save_path) 115 | 116 | log_file_path = os.path.join(opt.save_path, "train_log.txt") 117 | log(log_file_path, str(vars(opt))) 118 | 119 | (embedding_net, cls_head) = get_model(opt) 120 | # Load saved model checkpoints 121 | saved_models = torch.load(os.path.join(opt.save_path, 'best_pretrain_model_resnet.pth')) 122 | # saved_models = torch.load(os.path.join(opt.save_path, 'best_model2.pth')) 123 | 
embedding_net.load_state_dict(saved_models['embedding']) 124 | embedding_net.eval() 125 | 126 | import pickle 127 | with open(os.path.join(opt.save_path, "mini_imagenet_optimal_prototype.pickle"), 'rb') as handle: 128 | prototype_ground_true = pickle.load(handle) 129 | prototype_ground_true_dict = prototype_ground_true['dict'] 130 | prototype_ground_true_array = prototype_ground_true['array'].cuda() 131 | 132 | optimizer = torch.optim.Adam(cls_head.parameters(), lr=0.0001, weight_decay=5e-4) 133 | 134 | lambda_epoch = lambda e: 1.0 if e < 100000 else (0.1 if e < 30 else 0.01 if e < 40 else (0.001)) 135 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_epoch, last_epoch=-1) 136 | 137 | max_val_acc = 0.0 138 | max_test_acc = 0.0 139 | 140 | timer = Timer() 141 | x_entropy = torch.nn.CrossEntropyLoss() 142 | 143 | for epoch in range(1, opt.num_epoch + 1): 144 | # Train on the training split 145 | lr_scheduler.step() 146 | 147 | # Fetch the current epoch's learning rate 148 | epoch_learning_rate = 0.1 149 | for param_group in optimizer.param_groups: 150 | epoch_learning_rate = param_group['lr'] 151 | 152 | log(log_file_path, 'Train Epoch: {}\tLearning Rate: {:.4f}'.format( 153 | epoch, epoch_learning_rate)) 154 | 155 | _ = [x.train() for x in (cls_head, )] 156 | # _, _ = [x.train() for x in (cls_head, embedding_net)] 157 | 158 | train_accuracies = [] 159 | train_losses = [] 160 | 161 | for i, batch in enumerate(tqdm(dloader_train(epoch)), 1): 162 | data_support, labels_support, data_query, labels_query, all_k, _ = [x.cuda() for x in batch] 163 | 164 | # prototype_ground_true = torch.gather(prototype_ground_true_array, 0, all_k) 165 | prototype_ground_true = prototype_ground_true_array[all_k, :] 166 | 167 | train_n_support = opt.train_way * opt.train_shot 168 | train_n_query = opt.train_way * opt.train_query 169 | 170 | with torch.no_grad(): 171 | emb_support = embedding_net(data_support.reshape([-1] + list(data_support.shape[-3:]))) 172 | emb_support = emb_support.reshape(opt.episodes_per_batch, train_n_support, -1) 173 | # emb_support = F.normalize(emb_support, dim=-1) 174 | 175 | emb_query = embedding_net(data_query.reshape([-1] + list(data_query.shape[-3:]))) 176 | emb_query = emb_query.reshape(opt.episodes_per_batch, train_n_query, -1) 177 | # emb_query = F.normalize(emb_query, dim=-1) 178 | 179 | loss = cls_head(prototype_ground_true, emb_query, emb_support, labels_support, labels_query, opt.train_way, opt.train_shot) 180 | 181 | 182 | 183 | if (i % 1000 == 0): 184 | logit_query = cls_head.sample(emb_query, emb_support, labels_support, labels_query, opt.train_way, 185 | opt.train_shot) 186 | # loss = x_entropy(logit_query.reshape(-1, opt.test_way), labels_query.reshape(-1)) 187 | acc = count_accuracy(logit_query.reshape(-1, opt.test_way), labels_query.reshape(-1)) 188 | 189 | train_accuracies.append(acc.item()) 190 | train_losses.append(loss.item()) 191 | train_acc_avg = np.mean(np.array(train_accuracies)) 192 | log(log_file_path, 'Train Epoch: {}\tBatch: [{}]\tLoss: {}\tAccuracy: {} % ({} %)'.format( 193 | epoch, i, loss.item(), train_acc_avg, acc)) 194 | 195 | optimizer.zero_grad() 196 | loss.backward() 197 | optimizer.step() 198 | 199 | # Evaluate on the validation split 200 | _, _ = [x.eval() for x in (cls_head, embedding_net)] 201 | 202 | val_accuracies = [] 203 | val_losses = [] 204 | 205 | for i, batch in enumerate(tqdm(dloader_val(epoch)), 1): 206 | data_support, labels_support, data_query, labels_query, _, _ = [x.cuda() for x in batch] 207 | 208 | 
test_n_support = opt.test_way * opt.val_shot 209 | test_n_query = opt.test_way * opt.val_query 210 | with torch.no_grad(): 211 | emb_support = embedding_net(data_support.reshape([-1] + list(data_support.shape[-3:]))) 212 | emb_support = emb_support.reshape(1, test_n_support, -1) 213 | # emb_support = F.normalize(emb_support, dim=-1) 214 | emb_query = embedding_net(data_query.reshape([-1] + list(data_query.shape[-3:]))) 215 | emb_query = emb_query.reshape(1, test_n_query, -1) 216 | # emb_query = F.normalize(emb_query, dim=-1) 217 | 218 | logit_query = cls_head.sample(emb_query, emb_support, labels_support, labels_query, opt.test_way, opt.val_shot) 219 | loss = x_entropy(logit_query.reshape(-1, opt.test_way), labels_query.reshape(-1)) 220 | acc = count_accuracy(logit_query.reshape(-1, opt.test_way), labels_query.reshape(-1)) 221 | 222 | val_accuracies.append(acc.item()) 223 | val_losses.append(loss.item()) 224 | 225 | val_acc_avg = np.mean(np.array(val_accuracies)) 226 | val_acc_ci95 = 1.96 * np.std(np.array(val_accuracies)) / np.sqrt(opt.val_episode) 227 | 228 | val_loss_avg = np.mean(np.array(val_losses)) 229 | 230 | if val_acc_avg > max_val_acc: 231 | max_val_acc = val_acc_avg 232 | # torch.save({'embedding': embedding_net.state_dict(), 'head': cls_head.state_dict()}, \ 233 | # os.path.join(opt.save_path, 'best_model_val_double_dir_opt_cub_{}_shot.pth'.format(opt.val_shot))) 234 | log(log_file_path, 'Validation Epoch: {}\t\t\tLoss: {:.4f}\tAccuracy: {:.2f} ± {:.2f} % (Best)' \ 235 | .format(epoch, val_loss_avg, val_acc_avg, val_acc_ci95)) 236 | else: 237 | log(log_file_path, 'Validation Epoch: {}\t\t\tLoss: {:.4f}\tAccuracy: {:.2f} ± {:.2f} %' \ 238 | .format(epoch, val_loss_avg, val_acc_avg, val_acc_ci95)) 239 | 240 | 241 | def test(opt, n_iter, dataset_train, dataset_val, dataset_test, data_loader): 242 | # Dataloader of Gidaris & Komodakis (CVPR 2018) 243 | dloader_test = data_loader( 244 | dataset=dataset_test, 245 | nKnovel=opt.test_way, 246 | nKbase=0, 247 | nExemplars=opt.val_shot, # num training examples per novel category 248 | nTestNovel=opt.val_query * opt.test_way, # num test examples for all the novel categories 249 | nTestBase=0, # num test examples for all the base categories 250 | batch_size=1, 251 | num_workers=0, 252 | epoch_size=1 * opt.val_episode, # num of batches per epoch 253 | ) 254 | 255 | set_gpu(opt.gpu) 256 | check_dir('./experiments/') 257 | check_dir(opt.save_path) 258 | 259 | log_file_path = os.path.join(opt.save_path, "train_log.txt") 260 | log(log_file_path, str(vars(opt))) 261 | 262 | (embedding_net, cls_head) = get_model(opt) 263 | # Load saved model checkpoints 264 | saved_models = torch.load(os.path.join(opt.save_path, 'best_model_val_double_dir_opt_cub_5_shot.pth')) 265 | # saved_models = torch.load(os.path.join(opt.save_path, 'best_model2.pth')) 266 | embedding_net.load_state_dict(saved_models['embedding']) 267 | embedding_net.eval() 268 | cls_head.load_state_dict(saved_models['head']) 269 | cls_head.eval() 270 | 271 | max_val_acc = 0.0 272 | max_test_acc = 0.0 273 | 274 | timer = Timer() 275 | x_entropy = torch.nn.CrossEntropyLoss() 276 | 277 | # Evaluate on the validation split 278 | _, _ = [x.eval() for x in (cls_head, embedding_net)] 279 | 280 | test_accuracies = [] 281 | test_losses = [] 282 | epoch = 9 283 | for i, batch in enumerate(tqdm(dloader_test(epoch)), 1): 284 | data_support, labels_support, data_query, labels_query, _, _ = [x.cuda() for x in batch] 285 | 286 | test_n_support = opt.test_way * opt.val_shot 287 | test_n_query = 
opt.test_way * opt.val_query 288 | 289 | with torch.no_grad(): 290 | emb_support = embedding_net(data_support.reshape([-1] + list(data_support.shape[-3:]))) 291 | emb_support = emb_support.reshape(1, test_n_support, -1) 292 | # emb_support = F.normalize(emb_support, dim=-1) 293 | emb_query = embedding_net(data_query.reshape([-1] + list(data_query.shape[-3:]))) 294 | emb_query = emb_query.reshape(1, test_n_query, -1) 295 | # emb_query = F.normalize(emb_query, dim=-1) 296 | 297 | logit_querys = cls_head(emb_query, emb_support, labels_support, labels_query, opt.test_way, opt.val_shot, is_train=False, update_step_test=n_iter) 298 | # logit_querys = logit_querys[-1:] 299 | for kk, logit_query in enumerate(logit_querys): 300 | loss = x_entropy(logit_query.reshape(-1, opt.test_way), labels_query.reshape(-1)) 301 | acc = count_accuracy(logit_query.reshape(-1, opt.test_way), labels_query.reshape(-1)) 302 | 303 | test_accuracies.append(acc.item()) 304 | test_losses.append(loss.item()) 305 | 306 | test_acc_avg = np.mean(np.array(test_accuracies)) 307 | test_acc_ci95 = 1.96 * np.std(np.array(test_accuracies)) / np.sqrt(opt.val_episode) 308 | 309 | test_loss_avg = np.mean(np.array(test_losses)) 310 | 311 | if test_acc_avg > max_test_acc: 312 | max_test_acc = test_acc_avg 313 | log(log_file_path, 'Test Loss: {:.4f}\tAccuracy: {:.2f} ± {:.2f} % (Best)' \ 314 | .format(test_loss_avg, test_acc_avg, test_acc_ci95)) 315 | else: 316 | log(log_file_path, 'Test Loss: {:.4f}\tAccuracy: {:.2f} ± {:.2f} %' \ 317 | .format(test_loss_avg, test_acc_avg, test_acc_ci95)) 318 | 319 | if __name__ == '__main__': 320 | seed_torch(21) 321 | parser = argparse.ArgumentParser() 322 | parser.add_argument('--num-epoch', type=int, default=1000, 323 | help='number of training epochs') 324 | parser.add_argument('--save-epoch', type=int, default=10, 325 | help='frequency of model saving') 326 | parser.add_argument('--train-shot', type=int, default=5, 327 | help='number of support examples per training class') 328 | parser.add_argument('--val-shot', type=int, default=5, 329 | help='number of support examples per validation class') 330 | parser.add_argument('--train-query', type=int, default=15, 331 | help='number of query examples per training class') 332 | parser.add_argument('--val-episode', type=int, default=600, 333 | help='number of episodes per validation') 334 | parser.add_argument('--val-query', type=int, default=15, 335 | help='number of query examples per validation class') 336 | parser.add_argument('--train-way', type=int, default=5, 337 | help='number of classes in one training episode') 338 | parser.add_argument('--test-way', type=int, default=5, 339 | help='number of classes in one test (or validation) episode') 340 | parser.add_argument('--save-path', default='./experiments/exp_1') 341 | parser.add_argument('--gpu', default='0') 342 | parser.add_argument('--network', type=str, default='ResNet', 343 | help='choose which embedding network to use. ProtoNet, R2D2, ResNet') 344 | parser.add_argument('--head', type=str, default='CosineNet', 345 | help='choose which classification head to use. ProtoNet, Ridge, R2D2, SVM') 346 | parser.add_argument('--pre_head', type=str, default='LinearRotateNet', 347 | help='choose which classification head to use. ProtoNet, Ridge, R2D2, SVM') 348 | parser.add_argument('--dataset', type=str, default='miniImageNet', 349 | help='choose which classification head to use. 
miniImageNet, tieredImageNet, CIFAR_FS, FC100') 350 | parser.add_argument('--episodes-per-batch', type=int, default=8, 351 | help='number of episodes per batch') 352 | parser.add_argument('--eps', type=float, default=0.0, 353 | help='epsilon of label smoothing') 354 | 355 | 356 | opt = parser.parse_args() 357 | 358 | (dataset_train, dataset_val, data_loader) = get_dataset(opt) 359 | 360 | train(opt, dataset_train, dataset_val, data_loader) -------------------------------------------------------------------------------- /data/mini_imagenet.py: -------------------------------------------------------------------------------- 1 | # Dataloader of Gidaris & Komodakis, CVPR 2018 2 | # Adapted from: 3 | # https://github.com/gidariss/FewShotWithoutForgetting/blob/master/dataloader.py 4 | from __future__ import print_function 5 | 6 | import os 7 | import os.path 8 | import numpy as np 9 | import random 10 | import pickle 11 | import json 12 | import math 13 | 14 | import torch 15 | import torch.utils.data as data 16 | import torchvision 17 | import torchvision.datasets as datasets 18 | import torchvision.transforms as transforms 19 | import torchnet as tnt 20 | 21 | import h5py 22 | 23 | from PIL import Image 24 | from PIL import ImageEnhance 25 | 26 | from pdb import set_trace as breakpoint 27 | 28 | 29 | # Set the appropriate paths of the datasets here. 30 | _MINI_IMAGENET_DATASET_DIR = '../datasets/few_shot_data/MiniImagenet' 31 | 32 | def buildLabelIndex(labels): 33 | label2inds = {} 34 | for idx, label in enumerate(labels): 35 | if label not in label2inds: 36 | label2inds[label] = [] 37 | label2inds[label].append(idx) 38 | 39 | return label2inds 40 | 41 | 42 | def load_data(file): 43 | try: 44 | with open(file, 'rb') as fo: 45 | data = pickle.load(fo) 46 | return data 47 | except: 48 | with open(file, 'rb') as f: 49 | u = pickle._Unpickler(f) 50 | u.encoding = 'latin1' 51 | data = u.load() 52 | return data 53 | 54 | class MiniImageNet(data.Dataset): 55 | def __init__(self, phase='train', do_not_use_random_transf=False): 56 | 57 | self.base_folder = 'miniImagenet' 58 | assert(phase=='train' or phase=='val' or phase=='test' or phase=='all_train') 59 | self.phase = phase 60 | self.name = 'MiniImageNet_' + phase 61 | 62 | print('Loading mini ImageNet dataset - phase {0}'.format(phase)) 63 | file_train_categories_train_phase = os.path.join( 64 | _MINI_IMAGENET_DATASET_DIR, 65 | 'miniImageNet_category_split_train_phase_train.pickle') 66 | file_train_categories_val_phase = os.path.join( 67 | _MINI_IMAGENET_DATASET_DIR, 68 | 'miniImageNet_category_split_train_phase_val.pickle') 69 | file_train_categories_test_phase = os.path.join( 70 | _MINI_IMAGENET_DATASET_DIR, 71 | 'miniImageNet_category_split_train_phase_test.pickle') 72 | file_val_categories_val_phase = os.path.join( 73 | _MINI_IMAGENET_DATASET_DIR, 74 | 'miniImageNet_category_split_val.pickle') 75 | file_test_categories_test_phase = os.path.join( 76 | _MINI_IMAGENET_DATASET_DIR, 77 | 'miniImageNet_category_split_test.pickle') 78 | 79 | if self.phase=='train': 80 | # During training phase we only load the training phase images 81 | # of the training categories (aka base categories). 
82 | data_train = load_data(file_train_categories_train_phase) 83 | self.data = data_train['data'] 84 | self.labels = data_train['labels'] 85 | 86 | self.label2ind = buildLabelIndex(self.labels) 87 | self.labelIds = sorted(self.label2ind.keys()) 88 | self.num_cats = len(self.labelIds) 89 | self.labelIds_base = self.labelIds 90 | self.num_cats_base = len(self.labelIds_base) 91 | 92 | elif self.phase=='val' or self.phase=='test': 93 | if self.phase=='test': 94 | # load data that will be used for evaluating the recognition 95 | # accuracy of the base categories. 96 | data_base = load_data(file_train_categories_test_phase) 97 | # load data that will be use for evaluating the few-shot recogniton 98 | # accuracy on the novel categories. 99 | data_novel = load_data(file_test_categories_test_phase) 100 | else: # phase=='val' 101 | # load data that will be used for evaluating the recognition 102 | # accuracy of the base categories. 103 | data_base = load_data(file_train_categories_val_phase) 104 | # load data that will be use for evaluating the few-shot recogniton 105 | # accuracy on the novel categories. 106 | data_novel = load_data(file_val_categories_val_phase) 107 | 108 | self.data = np.concatenate( 109 | [data_base['data'], data_novel['data']], axis=0) 110 | self.labels = data_base['labels'] + data_novel['labels'] 111 | 112 | self.label2ind = buildLabelIndex(self.labels) 113 | self.labelIds = sorted(self.label2ind.keys()) 114 | self.num_cats = len(self.labelIds) 115 | 116 | self.labelIds_base = buildLabelIndex(data_base['labels']).keys() 117 | self.labelIds_novel = buildLabelIndex(data_novel['labels']).keys() 118 | self.num_cats_base = len(self.labelIds_base) 119 | self.num_cats_novel = len(self.labelIds_novel) 120 | intersection = set(self.labelIds_base) & set(self.labelIds_novel) 121 | assert(len(intersection) == 0) 122 | elif self.phase=='all_train': 123 | data_train = load_data(file_train_categories_train_phase) 124 | data_test = load_data(file_test_categories_test_phase) 125 | data_val = load_data(file_val_categories_val_phase) 126 | 127 | self.data = np.concatenate( 128 | [data_train['data'], data_test['data'], data_val['data']], axis=0) 129 | self.labels = data_train['labels'] + data_test['labels']+ data_val['labels'] 130 | 131 | self.label2ind = buildLabelIndex(self.labels) 132 | self.labelIds = sorted(self.label2ind.keys()) 133 | self.num_cats = len(self.labelIds) 134 | self.labelIds_base = self.labelIds 135 | self.num_cats_base = len(self.labelIds_base) 136 | # assert(len(intersection) == 0) 137 | else: 138 | raise ValueError('Not valid phase {0}'.format(self.phase)) 139 | 140 | mean_pix = [x/255.0 for x in [120.39586422, 115.59361427, 104.54012653]] 141 | std_pix = [x/255.0 for x in [70.68188272, 68.27635443, 72.54505529]] 142 | normalize = transforms.Normalize(mean=mean_pix, std=std_pix) 143 | 144 | if (self.phase=='test' or self.phase=='val') or (do_not_use_random_transf==True): 145 | self.transform = transforms.Compose([ 146 | # transforms.RandomCrop(84, padding=8), 147 | lambda x: np.asarray(x), 148 | transforms.ToTensor(), 149 | normalize 150 | ]) 151 | else: 152 | self.transform = transforms.Compose([ 153 | transforms.RandomCrop(84, padding=8), 154 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 155 | transforms.RandomHorizontalFlip(), 156 | lambda x: np.asarray(x), 157 | transforms.ToTensor(), 158 | normalize 159 | ]) 160 | 161 | def __getitem__(self, index): 162 | img, label = self.data[index], self.labels[index] 163 | # doing this so that it is 
consistent with all other datasets 164 | # to return a PIL Image 165 | img = Image.fromarray(img) 166 | if self.transform is not None: 167 | img = self.transform(img) 168 | return img, label 169 | 170 | def __len__(self): 171 | return len(self.data) 172 | 173 | 174 | class FewShotDataloader(): 175 | def __init__(self, 176 | dataset, 177 | nKnovel=5, # number of novel categories. 178 | nKbase=-1, # number of base categories. 179 | nExemplars=1, # number of training examples per novel category. 180 | nTestNovel=15*5, # number of test examples for all the novel categories. 181 | nTestBase=15*5, # number of test examples for all the base categories. 182 | batch_size=1, # number of training episodes per batch. 183 | num_workers=4, 184 | epoch_size=2000, # number of batches per epoch. 185 | ): 186 | 187 | self.dataset = dataset 188 | self.phase = self.dataset.phase 189 | max_possible_nKnovel = (self.dataset.num_cats_base if self.phase=='train' 190 | else self.dataset.num_cats_novel) 191 | assert(nKnovel >= 0 and nKnovel < max_possible_nKnovel) 192 | self.nKnovel = nKnovel 193 | 194 | max_possible_nKbase = self.dataset.num_cats_base 195 | nKbase = nKbase if nKbase >= 0 else max_possible_nKbase 196 | if self.phase=='train' and nKbase > 0: 197 | nKbase -= self.nKnovel 198 | max_possible_nKbase -= self.nKnovel 199 | 200 | assert(nKbase >= 0 and nKbase <= max_possible_nKbase) 201 | self.nKbase = nKbase 202 | 203 | self.nExemplars = nExemplars 204 | self.nTestNovel = nTestNovel 205 | self.nTestBase = nTestBase 206 | self.batch_size = batch_size 207 | self.epoch_size = epoch_size 208 | self.num_workers = num_workers 209 | self.is_eval_mode = (self.phase=='test') or (self.phase=='val') 210 | 211 | def sampleImageIdsFrom(self, cat_id, sample_size=1): 212 | """ 213 | Samples `sample_size` number of unique image ids picked from the 214 | category `cat_id` (i.e., self.dataset.label2ind[cat_id]). 215 | 216 | Args: 217 | cat_id: a scalar with the id of the category from which images will 218 | be sampled. 219 | sample_size: number of images that will be sampled. 220 | 221 | Returns: 222 | image_ids: a list of length `sample_size` with unique image ids. 223 | """ 224 | assert(cat_id in self.dataset.label2ind) 225 | assert(len(self.dataset.label2ind[cat_id]) >= sample_size) 226 | # Note: random.sample samples elements without replacement. 227 | return random.sample(self.dataset.label2ind[cat_id], sample_size) 228 | 229 | def sampleCategories(self, cat_set, sample_size=1): 230 | """ 231 | Samples `sample_size` number of unique categories picked from the 232 | `cat_set` set of categories. `cat_set` can be either 'base' or 'novel'. 233 | 234 | Args: 235 | cat_set: string that specifies the set of categories from which 236 | categories will be sampled. 237 | sample_size: number of categories that will be sampled. 238 | 239 | Returns: 240 | cat_ids: a list of length `sample_size` with unique category ids. 241 | """ 242 | if cat_set=='base': 243 | labelIds = self.dataset.labelIds_base 244 | elif cat_set=='novel': 245 | labelIds = self.dataset.labelIds_novel 246 | else: 247 | raise ValueError('Not recognized category set {}'.format(cat_set)) 248 | 249 | assert(len(labelIds) >= sample_size) 250 | # return sample_size unique categories chosen from labelIds set of 251 | # categories (that can be either self.labelIds_base or self.labelIds_novel) 252 | # Note: random.sample samples elements without replacement. 
253 | return random.sample(labelIds, sample_size) 254 | 255 | def sample_base_and_novel_categories(self, nKbase, nKnovel): 256 | """ 257 | Samples `nKbase` number of base categories and `nKnovel` number of novel 258 | categories. 259 | 260 | Args: 261 | nKbase: number of base categories 262 | nKnovel: number of novel categories 263 | 264 | Returns: 265 | Kbase: a list of length 'nKbase' with the ids of the sampled base 266 | categories. 267 | Knovel: a list of lenght 'nKnovel' with the ids of the sampled novel 268 | categories. 269 | """ 270 | if self.is_eval_mode: 271 | assert(nKnovel <= self.dataset.num_cats_novel) 272 | # sample from the set of base categories 'nKbase' number of base 273 | # categories. 274 | Kbase = sorted(self.sampleCategories('base', nKbase)) 275 | # sample from the set of novel categories 'nKnovel' number of novel 276 | # categories. 277 | Knovel = sorted(self.sampleCategories('novel', nKnovel)) 278 | else: 279 | # sample from the set of base categories 'nKnovel' + 'nKbase' number 280 | # of categories. 281 | cats_ids = self.sampleCategories('base', nKnovel+nKbase) 282 | assert(len(cats_ids) == (nKnovel+nKbase)) 283 | # Randomly pick 'nKnovel' number of fake novel categories and keep 284 | # the rest as base categories. 285 | random.shuffle(cats_ids) 286 | Knovel = sorted(cats_ids[:nKnovel]) 287 | Kbase = sorted(cats_ids[nKnovel:]) 288 | 289 | return Kbase, Knovel 290 | 291 | def sample_test_examples_for_base_categories(self, Kbase, nTestBase): 292 | """ 293 | Sample `nTestBase` number of images from the `Kbase` categories. 294 | 295 | Args: 296 | Kbase: a list of length `nKbase` with the ids of the categories from 297 | where the images will be sampled. 298 | nTestBase: the total number of images that will be sampled. 299 | 300 | Returns: 301 | Tbase: a list of length `nTestBase` with 2-element tuples. The 1st 302 | element of each tuple is the image id that was sampled and the 303 | 2nd elemend is its category label (which is in the range 304 | [0, len(Kbase)-1]). 305 | """ 306 | Tbase = [] 307 | if len(Kbase) > 0: 308 | # Sample for each base category a number images such that the total 309 | # number sampled images of all categories to be equal to `nTestBase`. 310 | KbaseIndices = np.random.choice( 311 | np.arange(len(Kbase)), size=nTestBase, replace=True) 312 | KbaseIndices, NumImagesPerCategory = np.unique( 313 | KbaseIndices, return_counts=True) 314 | 315 | for Kbase_idx, NumImages in zip(KbaseIndices, NumImagesPerCategory): 316 | imd_ids = self.sampleImageIdsFrom( 317 | Kbase[Kbase_idx], sample_size=NumImages) 318 | Tbase += [(img_id, Kbase_idx) for img_id in imd_ids] 319 | 320 | assert(len(Tbase) == nTestBase) 321 | 322 | return Tbase 323 | 324 | def sample_train_and_test_examples_for_novel_categories( 325 | self, Knovel, nTestNovel, nExemplars, nKbase): 326 | """Samples train and test examples of the novel categories. 327 | 328 | Args: 329 | Knovel: a list with the ids of the novel categories. 330 | nTestNovel: the total number of test images that will be sampled 331 | from all the novel categories. 332 | nExemplars: the number of training examples per novel category that 333 | will be sampled. 334 | nKbase: the number of base categories. It is used as offset of the 335 | category index of each sampled image. 336 | 337 | Returns: 338 | Tnovel: a list of length `nTestNovel` with 2-element tuples. 
The 339 | 1st element of each tuple is the image id that was sampled and 340 | the 2nd element is its category label (which is in the range 341 | [nKbase, nKbase + len(Knovel) - 1]). 342 | Exemplars: a list of length len(Knovel) * nExemplars of 2-element 343 | tuples. The 1st element of each tuple is the image id that was 344 | sampled and the 2nd element is its category label (which is in 345 | the ragne [nKbase, nKbase + len(Knovel) - 1]). 346 | """ 347 | 348 | if len(Knovel) == 0: 349 | return [], [] 350 | 351 | nKnovel = len(Knovel) 352 | Tnovel = [] 353 | Exemplars = [] 354 | assert((nTestNovel % nKnovel) == 0) 355 | nEvalExamplesPerClass = int(nTestNovel / nKnovel) 356 | 357 | for Knovel_idx in range(len(Knovel)): 358 | imd_ids = self.sampleImageIdsFrom( 359 | Knovel[Knovel_idx], 360 | sample_size=(nEvalExamplesPerClass + nExemplars)) 361 | 362 | imds_tnovel = imd_ids[:nEvalExamplesPerClass] 363 | imds_ememplars = imd_ids[nEvalExamplesPerClass:] 364 | 365 | Tnovel += [(img_id, nKbase+Knovel_idx) for img_id in imds_tnovel] 366 | Exemplars += [(img_id, nKbase+Knovel_idx) for img_id in imds_ememplars] 367 | assert(len(Tnovel) == nTestNovel) 368 | assert(len(Exemplars) == len(Knovel) * nExemplars) 369 | random.shuffle(Exemplars) 370 | 371 | return Tnovel, Exemplars 372 | 373 | def sample_episode(self): 374 | """Samples a training episode.""" 375 | nKnovel = self.nKnovel 376 | nKbase = self.nKbase 377 | nTestNovel = self.nTestNovel 378 | nTestBase = self.nTestBase 379 | nExemplars = self.nExemplars 380 | 381 | Kbase, Knovel = self.sample_base_and_novel_categories(nKbase, nKnovel) 382 | Tbase = self.sample_test_examples_for_base_categories(Kbase, nTestBase) 383 | Tnovel, Exemplars = self.sample_train_and_test_examples_for_novel_categories( 384 | Knovel, nTestNovel, nExemplars, nKbase) 385 | 386 | # concatenate the base and novel category examples. 387 | Test = Tbase + Tnovel 388 | random.shuffle(Test) 389 | Kall = Kbase + Knovel 390 | 391 | return Exemplars, Test, Kall, nKbase 392 | 393 | def createExamplesTensorData(self, examples): 394 | """ 395 | Creates the examples image and label tensor data. 396 | 397 | Args: 398 | examples: a list of 2-element tuples, each representing a 399 | train or test example. The 1st element of each tuple 400 | is the image id of the example and 2nd element is the 401 | category label of the example, which is in the range 402 | [0, nK - 1], where nK is the total number of categories 403 | (both novel and base). 404 | 405 | Returns: 406 | images: a tensor of shape [nExamples, Height, Width, 3] with the 407 | example images, where nExamples is the number of examples 408 | (i.e., nExamples = len(examples)). 409 | labels: a tensor of shape [nExamples] with the category label 410 | of each example. 
411 | """ 412 | images = torch.stack( 413 | [self.dataset[img_idx][0] for img_idx, _ in examples], dim=0) 414 | labels = torch.LongTensor([label for _, label in examples]) 415 | return images, labels 416 | 417 | def get_iterator(self, ): 418 | rand_seed = 8 419 | random.seed(rand_seed) 420 | np.random.seed(rand_seed) 421 | def load_function(iter_idx): 422 | Exemplars, Test, Kall, nKbase = self.sample_episode() 423 | Xt, Yt = self.createExamplesTensorData(Test) 424 | Kall = torch.LongTensor(Kall) 425 | if len(Exemplars) > 0: 426 | Xe, Ye = self.createExamplesTensorData(Exemplars) 427 | return Xe, Ye, Xt, Yt, Kall, nKbase 428 | else: 429 | return Xt, Yt, Kall, nKbase 430 | 431 | tnt_dataset = tnt.dataset.ListDataset( 432 | elem_list=range(self.epoch_size), load=load_function) 433 | data_loader = tnt_dataset.parallel( 434 | batch_size=self.batch_size, 435 | num_workers=(0 if self.is_eval_mode else self.num_workers), 436 | shuffle=(False if self.is_eval_mode else True)) 437 | 438 | return data_loader 439 | 440 | def __call__(self, epoch=0): 441 | return self.get_iterator(epoch) 442 | 443 | def __len__(self): 444 | return (self.epoch_size / self.batch_size) 445 | -------------------------------------------------------------------------------- /models/diffusion_grad.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from functools import partial 7 | from copy import deepcopy 8 | 9 | import argparse 10 | import torchvision 11 | from .attention import SelfAttention, CrossAttention 12 | # from models.unet import UNet 13 | 14 | # from .ema import EMA 15 | # from .utils import extract 16 | 17 | class TimeEmbedding(nn.Module): 18 | def __init__(self, n_embd): 19 | super().__init__() 20 | self.linear_1 = nn.Linear(n_embd, 4 * n_embd) 21 | self.linear_2 = nn.Linear(4 * n_embd, 4 * n_embd) 22 | 23 | def forward(self, x): 24 | x = self.linear_1(x) 25 | x = F.silu(x) 26 | x = self.linear_2(x) 27 | return x 28 | 29 | 30 | class ResidualBlock(nn.Module): 31 | def __init__(self, in_channels, out_channels, n_time=320): 32 | super().__init__() 33 | self.groupnorm_feature = nn.GroupNorm(32, in_channels) 34 | self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) 35 | self.linear_time = nn.Linear(n_time, out_channels) 36 | 37 | self.groupnorm_merged = nn.GroupNorm(32, out_channels) 38 | self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) 39 | 40 | if in_channels == out_channels: 41 | self.residual_layer = nn.Identity() 42 | else: 43 | self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) 44 | 45 | def forward(self, feature, time): 46 | residue = feature 47 | 48 | feature = self.groupnorm_feature(feature) 49 | feature = F.silu(feature) 50 | feature = self.conv_feature(feature) 51 | 52 | time = F.silu(time) 53 | time = self.linear_time(time) 54 | 55 | merged = feature + time.unsqueeze(-1).unsqueeze(-1) 56 | merged = self.groupnorm_merged(merged) 57 | merged = F.silu(merged) 58 | merged = self.conv_merged(merged) 59 | 60 | return merged + self.residual_layer(residue) 61 | 62 | 63 | class AttentionBlock(nn.Module): 64 | def __init__(self, n_head: int, n_embd: int, d_context=512): 65 | super().__init__() 66 | channels = n_head * n_embd 67 | 68 | self.groupnorm = nn.GroupNorm(32, channels, eps=1e-6) 69 | self.conv_input = nn.Conv2d(channels, channels, kernel_size=1, padding=0) 70 | 71 | 
self.layernorm_1 = nn.LayerNorm(channels) 72 | self.attention_1 = SelfAttention(n_head, channels, in_proj_bias=False) 73 | self.layernorm_2 = nn.LayerNorm(channels) 74 | self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False) 75 | self.layernorm_3 = nn.LayerNorm(channels) 76 | self.linear_geglu_1 = nn.Linear(channels, 4 * channels * 2) 77 | self.linear_geglu_2 = nn.Linear(4 * channels, channels) 78 | 79 | self.conv_output = nn.Conv2d(channels, channels, kernel_size=1, padding=0) 80 | 81 | def forward(self, x, context): 82 | residue_long = x 83 | 84 | x = self.groupnorm(x) 85 | x = self.conv_input(x) 86 | 87 | n, c, h, w = x.shape 88 | x = x.view((n, c, h * w)) # (n, c, hw) 89 | x = x.transpose(-1, -2) # (n, hw, c) 90 | 91 | residue_short = x 92 | x = self.layernorm_1(x) 93 | x = self.attention_1(x) 94 | x += residue_short 95 | 96 | residue_short = x 97 | x = self.layernorm_2(x) 98 | x = self.attention_2(x, context) 99 | x += residue_short 100 | 101 | residue_short = x 102 | x = self.layernorm_3(x) 103 | x, gate = self.linear_geglu_1(x).chunk(2, dim=-1) 104 | x = x * F.gelu(gate) 105 | x = self.linear_geglu_2(x) 106 | x += residue_short 107 | 108 | x = x.transpose(-1, -2) # (n, c, hw) 109 | x = x.view((n, c, h, w)) # (n, c, h, w) 110 | 111 | return self.conv_output(x) + residue_long 112 | 113 | 114 | class Upsample(nn.Module): 115 | def __init__(self, channels): 116 | super().__init__() 117 | self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1) 118 | 119 | def forward(self, x): 120 | x = F.interpolate(x, scale_factor=2, mode='nearest') 121 | return self.conv(x) 122 | 123 | 124 | class SwitchSequential(nn.Sequential): 125 | def forward(self, x, context, time): 126 | for layer in self: 127 | if isinstance(layer, AttentionBlock): 128 | x = layer(x, context) 129 | elif isinstance(layer, ResidualBlock): 130 | x = layer(x, time) 131 | else: 132 | x = layer(x) 133 | return x 134 | 135 | 136 | class UNet(nn.Module): 137 | def __init__(self): 138 | super().__init__() 139 | self.encoders = nn.ModuleList([ 140 | SwitchSequential(ResidualBlock(512, 256), AttentionBlock(8, 32)), 141 | SwitchSequential(ResidualBlock(256, 128), AttentionBlock(8, 16)), 142 | ]) 143 | self.bottleneck = SwitchSequential( 144 | ResidualBlock(128, 128), 145 | AttentionBlock(8, 16), 146 | ResidualBlock(128, 128), 147 | ) 148 | self.decoders = nn.ModuleList([ 149 | SwitchSequential(ResidualBlock(128+128, 256), AttentionBlock(8, 32)), 150 | SwitchSequential(ResidualBlock(256+256, 512), AttentionBlock(8, 64)), 151 | ]) 152 | 153 | self.pos_emb = nn.Linear(5, 512, bias=False) 154 | 155 | def get_time_embedding(self, timestep): 156 | freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160).cuda() 157 | x = timestep[:, None] * freqs[None] 158 | return torch.cat([torch.cos(x), torch.sin(x)], dim=-1) 159 | 160 | def forward(self, x, time = None, emb_query = None, emb_support = None, labels_support = None, labels_query = None, n_way = None, k_shot = None): 161 | time = self.get_time_embedding(time) 162 | 163 | tasks_per_batch = emb_query.size(0) 164 | n_support = emb_support.size(1) 165 | 166 | 167 | support_labels_one_hot = one_hot(labels_support.reshape(-1), n_way) 168 | support_labels_one_hot = support_labels_one_hot.view(tasks_per_batch, n_support, n_way) 169 | 170 | pos_emb = self.pos_emb(support_labels_one_hot) 171 | context = emb_support + pos_emb 172 | 173 | prototype_labels = torch.LongTensor(list(range(n_way))).cuda() 174 | prototype_labels = 
one_hot(prototype_labels, n_way).view(1, n_way, n_way) 175 | prototype_pos_emb = self.pos_emb(prototype_labels) 176 | 177 | x = x + prototype_pos_emb.permute(0, 2, 1).unsqueeze(-1).repeat(tasks_per_batch, 1, 1, 1) 178 | 179 | 180 | skip_connections = [] 181 | for layers in self.encoders: 182 | x = layers(x, context, time) 183 | skip_connections.append(x) 184 | 185 | x = self.bottleneck(x, context, time) 186 | 187 | for layers in self.decoders: 188 | x = torch.cat((x, skip_connections.pop()), dim=1) 189 | x = layers(x, context, time) 190 | 191 | return x 192 | 193 | 194 | class FinalLayer(nn.Module): 195 | def __init__(self, in_channels, out_channels): 196 | super().__init__() 197 | self.groupnorm = nn.GroupNorm(32, in_channels) 198 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) 199 | 200 | def forward(self, x): 201 | x = self.groupnorm(x) 202 | x = F.silu(x) 203 | x = self.conv(x) 204 | return x 205 | 206 | def one_hot(indices, depth): 207 | """ 208 | Returns a one-hot tensor. 209 | This is a PyTorch equivalent of Tensorflow's tf.one_hot. 210 | 211 | Parameters: 212 | indices: a (n_batch, m) Tensor or (m) Tensor. 213 | depth: a scalar. Represents the depth of the one hot dimension. 214 | Returns: a (n_batch, m, depth) Tensor or (m, depth) Tensor. 215 | """ 216 | 217 | encoded_indicies = torch.zeros(indices.size() + torch.Size([depth])).cuda() 218 | index = indices.view(indices.size() + torch.Size([1])) 219 | encoded_indicies = encoded_indicies.scatter_(1, index, 1) 220 | 221 | return encoded_indicies 222 | 223 | 224 | def extract(a, t, x_shape): 225 | b, *_ = t.shape 226 | out = a.gather(-1, t) 227 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 228 | 229 | class EMA(): 230 | def __init__(self, decay): 231 | self.decay = decay 232 | 233 | def update_average(self, old, new): 234 | if old is None: 235 | return new 236 | return old * self.decay + (1 - self.decay) * new 237 | 238 | def update_model_average(self, ema_model, current_model): 239 | for current_params, ema_params in zip(current_model.parameters(), ema_model.parameters()): 240 | old, new = ema_params.data, current_params.data 241 | ema_params.data = self.update_average(old, new) 242 | 243 | class DMFunc(nn.Module): 244 | 245 | def __init__(self): 246 | super(DMFunc, self).__init__() 247 | self.nfe = 0 248 | self.scale_factor = 10 249 | self.time_scale = nn.Sequential( 250 | nn.Linear(320, 160), 251 | nn.ReLU(), 252 | nn.Linear(160, 1), 253 | nn.Sigmoid() 254 | ) 255 | # self.res = nn.Sequential( 256 | # nn.Linear(512+320, 256), 257 | # nn.ReLU(), 258 | # nn.Linear(256, 512), 259 | # ) 260 | 261 | self.time_emb1 = nn.Linear(320, 256) 262 | self.proto_inf1 = nn.Linear(512, 256) 263 | 264 | self.time_emb2 = nn.Linear(320, 128) 265 | self.proto_inf2 = nn.Linear(256, 128) 266 | 267 | self.time_emb3 = nn.Linear(320, 128) 268 | self.proto_inf3 = nn.Linear(128, 128) 269 | 270 | self.time_emb4 = nn.Linear(320, 256) 271 | self.proto_inf4 = nn.Linear(128, 256) 272 | 273 | self.time_emb5 = nn.Linear(320, 512) 274 | self.proto_inf5 = nn.Linear(256, 512) 275 | 276 | # self.time_emb1 = nn.Linear(320, 256//4) 277 | # self.proto_inf1 = nn.Linear(512//4, 256//4) 278 | # 279 | # self.time_emb2 = nn.Linear(320, 128//4) 280 | # self.proto_inf2 = nn.Linear(256//4, 128//4) 281 | # 282 | # self.time_emb3 = nn.Linear(320, 128//4) 283 | # self.proto_inf3 = nn.Linear(128//4, 128//4) 284 | # 285 | # self.time_emb4 = nn.Linear(320, 256//4) 286 | # self.proto_inf4 = nn.Linear(128//4, 256//4) 287 | # 288 | # self.time_emb5 = 
nn.Linear(320, 512//4) 289 | # self.proto_inf5 = nn.Linear(256//4, 512//4) 290 | 291 | def get_time_embedding(self, timestep): 292 | freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160).cuda() 293 | x = timestep[:, None] * freqs[None] 294 | return torch.cat([torch.cos(x), torch.sin(x)], dim=-1) 295 | 296 | def forward(self, x, time, emb_query, emb_support, labels_support, labels_query, n_way, k_shot): 297 | tasks_per_batch = emb_support.shape[0] 298 | n_support = emb_support.shape[1] 299 | support_labels = one_hot(labels_support.view(tasks_per_batch * n_support), n_way) 300 | support_labels = support_labels.view(tasks_per_batch, n_support, -1) 301 | 302 | # encoding time 303 | time = self.get_time_embedding(time) 304 | scale = self.time_scale(time) 305 | 306 | nb, nd, nw, _ = x.shape 307 | prototypes = x.detach()[:, :, :, 0].permute(0, 2, 1) 308 | 309 | # 310 | weights = prototypes 311 | weights.requires_grad = True 312 | 313 | # logits = 10 * F.cosine_similarity( 314 | # emb_support.unsqueeze(2).expand(-1, -1, n_way, -1), 315 | # weights.unsqueeze(1).expand(-1, emb_support.shape[1], -1, -1), 316 | # dim=-1) + torch.sum(weights*weights) * 5e-4 317 | # loss = nn.MSELoss()(F.softmax(logits.reshape(-1, n_way), dim=1), support_labels.reshape(-1, n_way)) 318 | # loss = nn.CrossEntropyLoss()(logits.reshape(-1, n_way), labels_support.reshape(-1)) 319 | 320 | diff = weights.unsqueeze(1).expand(-1, emb_support.shape[1], -1, -1) - emb_support.unsqueeze(2).expand(-1, -1, n_way, -1) 321 | loss = torch.sum(torch.sum(diff*diff, dim=-1) * support_labels) 322 | 323 | # logits = F.softmax(-1 * torch.sum(diff*diff, dim=-1) / 512, dim=-1) 324 | # # loss = nn.MSELoss()(F.softmax(logits.reshape(-1, n_way), dim=1), support_labels.reshape(-1, n_way)) 325 | # loss = nn.CrossEntropyLoss()(logits.reshape(-1, n_way), labels_support.reshape(-1)) 326 | 327 | # compute grad and update inner loop weights 328 | grads = torch.autograd.grad(loss, weights) 329 | x = grads[0].detach() 330 | 331 | # # prototypes_ = (prototypes+prototypes_fs)/2 332 | # prototypes_ = prototypes 333 | 334 | # query_labels = torch.nn.functional.cosine_similarity( 335 | # emb_query.unsqueeze(2).expand(-1, -1, prototypes_.shape[1], -1), 336 | # prototypes_.unsqueeze(1).expand(-1, emb_query.shape[1], -1, -1), 337 | # dim=-1) 338 | # query_labels = query_labels * self.scale_factor 339 | # query_labels = F.softmax(query_labels, dim=-1) 340 | # data_labels = torch.cat([support_labels, query_labels], dim=1).unsqueeze(dim=-1).permute(0, 2, 1, 3) 341 | 342 | # topk, indices = torch.topk(data_labels, k_shot, dim=2) 343 | # mask = torch.zeros_like(data_labels) 344 | # mask = mask.scatter(2, indices, 1) 345 | # data_labels = data_labels * mask 346 | 347 | # data_labels = support_labels.unsqueeze(dim=-1).permute(0, 2, 1, 3) 348 | # 349 | # diff_weights = data_labels / torch.sum(data_labels, dim=2, keepdim=True) 350 | # 351 | # # cal vector fild 352 | # # all_x = torch.cat([emb_support, emb_query], dim=1) 353 | # all_x = emb_support 354 | # x = prototypes 355 | # x_left = x.unsqueeze(dim=2).expand(-1, -1, all_x.shape[1], -1) 356 | # all_x_right = all_x.unsqueeze(dim=1).expand(-1, n_way, -1, -1) 357 | # diff = (x_left - all_x_right) 358 | # diff = scale.unsqueeze(dim=-1) * torch.sum((diff_weights * diff), dim=2) + (1-scale.unsqueeze(dim=-1))*self.res(torch.cat([prototypes, time.unsqueeze(dim=1).repeat(1,5,1)], dim=-1)) 359 | 360 | # x = torch.sum((diff_weights * diff), dim=2) 361 | 362 | time_ = 
self.time_emb1(time).unsqueeze(dim=1).repeat(1,5,1) 363 | x_1 = self.proto_inf1(x) 364 | x_1 = F.softplus(x_1 * time_) 365 | 366 | time_ = self.time_emb2(time).unsqueeze(dim=1).repeat(1,5,1) 367 | x_2 = self.proto_inf2(x_1) 368 | x_2 = F.softplus(x_2 * time_) 369 | 370 | time_ = self.time_emb3(time).unsqueeze(dim=1).repeat(1,5,1) 371 | x_3 = self.proto_inf3(x_2) 372 | x_3 = F.softplus(x_3 * time_) + x_2 373 | 374 | time_ = self.time_emb4(time).unsqueeze(dim=1).repeat(1,5,1) 375 | x_4 = self.proto_inf4(x_3) 376 | x_4 = F.softplus(x_4 * time_) + x_1 377 | 378 | time_ = self.time_emb5(time).unsqueeze(dim=1).repeat(1,5,1) 379 | x_5 = self.proto_inf5(x_4) 380 | x_5 = x_5 * time_ + x 381 | 382 | # diff = scale.unsqueeze(dim=-1) * torch.sum((diff_weights * diff), dim=2) + ( 383 | # 1 - scale.unsqueeze(dim=-1)) * x 384 | 385 | # diff = x_5 * scale.unsqueeze(dim=-1) 386 | 387 | diff = scale.unsqueeze(dim=-1) * x + ( 388 | 1 - scale.unsqueeze(dim=-1)) * x_5 389 | 390 | return diff.permute(0, 2, 1).unsqueeze(-1) 391 | 392 | 393 | class DMFunc1(nn.Module): 394 | 395 | def __init__(self): 396 | super(DMFunc1, self).__init__() 397 | self.nfe = 0 398 | self.scale_factor = 10 399 | 400 | self.time_emb1 = nn.Linear(320, 256) 401 | self.proto_inf1 = nn.Linear(512, 256) 402 | self.spt_inf1 = nn.Linear(512, 256) 403 | 404 | self.time_emb2 = nn.Linear(320, 128) 405 | self.proto_inf2 = nn.Linear(256, 128) 406 | self.spt_inf2 = nn.Linear(256, 128) 407 | 408 | self.time_emb3 = nn.Linear(320, 128) 409 | self.proto_inf3 = nn.Linear(128, 128) 410 | self.spt_inf3 = nn.Linear(128, 128) 411 | 412 | self.time_emb4 = nn.Linear(320, 256) 413 | self.proto_inf4 = nn.Linear(128, 256) 414 | self.spt_inf4 = nn.Linear(128, 256) 415 | 416 | self.time_emb5 = nn.Linear(320, 512) 417 | self.proto_inf5 = nn.Linear(256, 512) 418 | self.spt_inf5 = nn.Linear(256, 512) 419 | 420 | def get_time_embedding(self, timestep): 421 | freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160).cuda() 422 | x = timestep[:, None] * freqs[None] 423 | return torch.cat([torch.cos(x), torch.sin(x)], dim=-1) 424 | 425 | def forward(self, x, time, emb_query, emb_support, labels_support, labels_query, n_way, k_shot): 426 | tasks_per_batch = emb_support.shape[0] 427 | n_support = emb_support.shape[1] 428 | 429 | # encoding time 430 | time = self.get_time_embedding(time) 431 | 432 | nb, nd, nw, _ = x.shape 433 | x = x.permute(0, 2, 3, 1) 434 | 435 | sorted_emb_support = [] 436 | for way_i in range(n_way): 437 | temp = [] 438 | for bi in range(tasks_per_batch): 439 | temp.append(emb_support[bi, labels_support[bi, :]==way_i, :]) 440 | temp = torch.stack(temp, dim=0) 441 | sorted_emb_support.append(temp) 442 | sorted_emb_support = torch.stack(sorted_emb_support, dim=1) 443 | 444 | time_ = self.time_emb1(time).unsqueeze(dim=1).unsqueeze(dim=1) 445 | x = self.proto_inf1(x) 446 | spt_emb = self.spt_inf1(sorted_emb_support) 447 | x = F.softplus(x * time_ * spt_emb) 448 | x_1 = x 449 | time_ = self.time_emb2(time).unsqueeze(dim=1).unsqueeze(dim=1) 450 | x = self.proto_inf2(x) 451 | spt_emb = self.spt_inf2(spt_emb) 452 | x = F.softplus(x * time_ * spt_emb) 453 | x_2 = x 454 | time_ = self.time_emb3(time).unsqueeze(dim=1).unsqueeze(dim=1) 455 | x = self.proto_inf3(x) 456 | spt_emb = self.spt_inf3(spt_emb) 457 | x = F.softplus(x * time_ * spt_emb) + x_2 458 | 459 | time_ = self.time_emb4(time).unsqueeze(dim=1).unsqueeze(dim=1) 460 | x = self.proto_inf4(x) 461 | spt_emb = self.spt_inf4(spt_emb) 462 | x = F.softplus(x * time_ * spt_emb) + x_1 
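        # Last stage of this network, as the lines below show: proto_inf5 / spt_inf5 project back to the
        # 512-d feature space, the result is modulated by the time and support embeddings (softplus of the
        # elementwise product) with a residual connection, and the per-class support copies are then
        # averaged (mean over dim=-2) so each of the n_way classes ends up as a single 512-d vector.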
463 | time_ = self.time_emb5(time).unsqueeze(dim=1).unsqueeze(dim=1) 464 | x = self.proto_inf5(x) 465 | spt_emb = self.spt_inf5(spt_emb) 466 | x = F.softplus(x * time_ * spt_emb) + x 467 | 468 | x = x.mean(dim=-2) 469 | return x.permute(0, 2, 1).unsqueeze(-1) 470 | 471 | 472 | class GaussianDiffusion(nn.Module): 473 | __doc__ = r"""Gaussian Diffusion model. Forwarding through the module returns diffusion reversal scalar loss tensor. 474 | 475 | Input: 476 | x: tensor of shape (N, img_channels, *img_size) 477 | y: tensor of shape (N) 478 | Output: 479 | scalar loss tensor 480 | Args: 481 | model (nn.Module): model which estimates diffusion noise 482 | img_size (tuple): image size tuple (H, W) 483 | img_channels (int): number of image channels 484 | betas (np.ndarray): numpy array of diffusion betas 485 | loss_type (string): loss type, "l1" or "l2" 486 | ema_decay (float): model weights exponential moving average decay 487 | ema_start (int): number of steps before EMA 488 | ema_update_rate (int): number of steps before each EMA update 489 | """ 490 | 491 | def __init__( 492 | self, 493 | model, 494 | img_size, 495 | img_channels, 496 | num_classes, 497 | betas, 498 | loss_type="l2", 499 | ema_decay=0.9999, 500 | ema_start=5000, 501 | ema_update_rate=1, 502 | ): 503 | super().__init__() 504 | 505 | self.model = model 506 | self.ema_model = deepcopy(model) 507 | 508 | self.ema = EMA(ema_decay) 509 | self.ema_decay = ema_decay 510 | self.ema_start = ema_start 511 | self.ema_update_rate = ema_update_rate 512 | self.step = 0 513 | 514 | self.img_size = img_size 515 | self.img_channels = img_channels 516 | self.num_classes = num_classes 517 | 518 | if loss_type not in ["l1", "l2"]: 519 | raise ValueError("__init__() got unknown loss type") 520 | 521 | self.loss_type = loss_type 522 | self.num_timesteps = len(betas) 523 | 524 | alphas = 1.0 - betas 525 | alphas_cumprod = np.cumprod(alphas) 526 | 527 | to_torch = partial(torch.tensor, dtype=torch.float32) 528 | 529 | self.register_buffer("betas", to_torch(betas)) 530 | self.register_buffer("alphas", to_torch(alphas)) 531 | self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) 532 | 533 | self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) 534 | self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1 - alphas_cumprod))) 535 | self.register_buffer("reciprocal_sqrt_alphas", to_torch(np.sqrt(1 / alphas))) 536 | 537 | self.register_buffer("remove_noise_coeff", to_torch(betas / np.sqrt(1 - alphas_cumprod))) 538 | self.register_buffer("sigma", to_torch(np.sqrt(betas))) 539 | 540 | def update_ema(self): 541 | self.step += 1 542 | if self.step % self.ema_update_rate == 0: 543 | if self.step < self.ema_start: 544 | self.ema_model.load_state_dict(self.model.state_dict()) 545 | else: 546 | self.ema.update_model_average(self.ema_model, self.model) 547 | 548 | # @torch.no_grad() 549 | def remove_noise(self, x, t, emb_query, emb_support, labels_support, labels_query, n_way, k_shot, use_ema=True): 550 | if use_ema: 551 | return ( 552 | (x - extract(self.remove_noise_coeff, t, x.shape) * self.ema_model(x, t, emb_query, emb_support, labels_support, labels_query, n_way, k_shot)) * 553 | extract(self.reciprocal_sqrt_alphas, t, x.shape) 554 | ) 555 | else: 556 | return ( 557 | (x - extract(self.remove_noise_coeff, t, x.shape) * self.model(x, t, emb_query, emb_support, labels_support, labels_query, n_way, k_shot)) * 558 | extract(self.reciprocal_sqrt_alphas, t, x.shape) 559 | ) 560 | 561 | #@torch.no_grad() 562 | def 
sample(self, emb_query, emb_support, labels_support, labels_query, n_way, k_shot, use_ema=True): 563 | batch_size = emb_query.shape[0] 564 | device = emb_query.device 565 | x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device) 566 | 567 | for t in range(self.num_timesteps - 1, -1, -1): 568 | t_batch = torch.tensor([t], device=device).repeat(batch_size) 569 | x = self.remove_noise(x, t_batch, emb_query, emb_support, labels_support, labels_query, n_way, k_shot, use_ema) 570 | 571 | # if t > 0: 572 | # x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x) 573 | # if t > 0: 574 | # x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x) 575 | 576 | prototype = x.detach()[:, :, :, 0].permute(0, 2, 1) 577 | logits = 10 * F.cosine_similarity( 578 | emb_query.unsqueeze(2).expand(-1, -1, n_way, -1), 579 | prototype.unsqueeze(1).expand(-1, emb_query.shape[1], -1, -1), 580 | dim=-1) 581 | # logits = torch.sum(emb_query.unsqueeze(2).expand(-1, -1, n_way, -1)* 582 | # prototype.unsqueeze(1).expand(-1, emb_query.shape[1], -1, -1), dim=-1) 583 | return logits 584 | 585 | @torch.no_grad() 586 | def sample_diffusion_sequence(self, batch_size, device, y=None, use_ema=True): 587 | if y is not None and batch_size != len(y): 588 | raise ValueError("sample batch size different from length of given y") 589 | 590 | x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device) 591 | diffusion_sequence = [x.cpu().detach()] 592 | 593 | for t in range(self.num_timesteps - 1, -1, -1): 594 | t_batch = torch.tensor([t], device=device).repeat(batch_size) 595 | x = self.remove_noise(x, t_batch, y, use_ema) 596 | 597 | if t > 0: 598 | x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x) 599 | 600 | diffusion_sequence.append(x.cpu().detach()) 601 | 602 | return diffusion_sequence 603 | 604 | def perturb_x(self, x, t, noise): 605 | return ( 606 | extract(self.sqrt_alphas_cumprod, t, x.shape) * x + 607 | extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * noise 608 | ) 609 | 610 | def get_losses(self, x, t, emb_query, emb_support, labels_support, labels_query, n_way, k_shot): 611 | noise = torch.randn_like(x) 612 | 613 | perturbed_x = self.perturb_x(x, t, noise) 614 | estimated_noise = self.model(perturbed_x, t, emb_query, emb_support, labels_support, labels_query, n_way, k_shot) 615 | 616 | if self.loss_type == "l1": 617 | loss = F.l1_loss(estimated_noise, noise) 618 | elif self.loss_type == "l2": 619 | loss = F.mse_loss(estimated_noise, noise) 620 | 621 | # loss = F.l1_loss(estimated_noise, noise) 622 | 623 | return loss 624 | 625 | def forward(self, x, emb_query, emb_support, labels_support, labels_query, n_way, k_shot): 626 | x = x.permute(0, 2, 1).unsqueeze(-1) 627 | b, c, h, w = x.shape 628 | device = x.device 629 | t = torch.randint(0, self.num_timesteps, (b,), device=device) 630 | return self.get_losses(x, t, emb_query, emb_support, labels_support, labels_query, n_way, k_shot) 631 | 632 | 633 | def generate_cosine_schedule(T, s=0.008): 634 | def f(t, T): 635 | return (np.cos((t / T + s) / (1 + s) * np.pi / 2)) ** 2 636 | 637 | alphas = [] 638 | f0 = f(0, T) 639 | 640 | for t in range(T + 1): 641 | alphas.append(f(t, T) / f0) 642 | 643 | betas = [] 644 | 645 | for t in range(1, T + 1): 646 | betas.append(min(1 - alphas[t] / alphas[t - 1], 0.999)) 647 | 648 | return np.array(betas) 649 | 650 | 651 | def generate_linear_schedule(T, low, high): 652 | return np.linspace(low, high, T) 653 | 654 | 655 | def get_transform(): 656 | class 
RescaleChannels(object): 657 | def __call__(self, sample): 658 | return 2 * sample - 1 659 | 660 | return torchvision.transforms.Compose([ 661 | torchvision.transforms.ToTensor(), 662 | RescaleChannels(), 663 | ]) 664 | 665 | 666 | def str2bool(v): 667 | """ 668 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 669 | """ 670 | if isinstance(v, bool): 671 | return v 672 | if v.lower() in ("yes", "true", "t", "y", "1"): 673 | return True 674 | elif v.lower() in ("no", "false", "f", "n", "0"): 675 | return False 676 | else: 677 | raise argparse.ArgumentTypeError("boolean value expected") 678 | 679 | 680 | def add_dict_to_argparser(parser, default_dict): 681 | """ 682 | https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/script_util.py 683 | """ 684 | for k, v in default_dict.items(): 685 | v_type = type(v) 686 | if v is None: 687 | v_type = str 688 | elif isinstance(v, bool): 689 | v_type = str2bool 690 | parser.add_argument(f"--{k}", default=v, type=v_type) 691 | 692 | 693 | def get_diffusion_from_args(): 694 | num_timesteps = 1000 695 | schedule = "linear" 696 | loss_type = "l2" 697 | use_labels = False 698 | 699 | base_channels = 128 700 | channel_mults = (1, 2, 2, 2) 701 | num_res_blocks = 2 702 | time_emb_dim = 128 * 4 703 | norm = "gn" 704 | dropout = 0.1 705 | activation = "silu" 706 | attention_resolutions = (1,) 707 | 708 | ema_decay = 0.9999 709 | ema_update_rate = 1 710 | schedule_low = 1e-4 711 | schedule_high = 0.02 712 | 713 | activations = { 714 | "relu": F.relu, 715 | "mish": F.mish, 716 | "silu": F.silu, 717 | } 718 | 719 | model = DMFunc() 720 | 721 | if schedule == "cosine": 722 | betas = generate_cosine_schedule(num_timesteps) 723 | else: 724 | betas = generate_linear_schedule( 725 | num_timesteps, 726 | schedule_low * 1000 / num_timesteps, 727 | schedule_high * 1000 / num_timesteps, 728 | ) 729 | 730 | diffusion = GaussianDiffusion( 731 | model, (5, 1), 512, 10, 732 | betas, 733 | ema_decay=ema_decay, 734 | ema_update_rate=ema_update_rate, 735 | ema_start=2000, 736 | loss_type=loss_type, 737 | ) 738 | return diffusion -------------------------------------------------------------------------------- /models/classification_heads.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch 5 | from torch.autograd import Variable 6 | import torch.nn as nn 7 | from qpth.qp import QPFunction 8 | from models.local_distribute import * 9 | from qpth.qp import QPFunction 10 | import torch.nn.functional as F 11 | import cv2 12 | import torch.fft 13 | 14 | 15 | def computeGramMatrix(A, B): 16 | """ 17 | Constructs a linear kernel matrix between A and B. 18 | We assume that each row in A and B represents a d-dimensional feature vector. 19 | 20 | Parameters: 21 | A: a (n_batch, n, d) Tensor. 22 | B: a (n_batch, m, d) Tensor. 23 | Returns: a (n_batch, n, m) Tensor. 24 | """ 25 | 26 | assert(A.dim() == 3) 27 | assert(B.dim() == 3) 28 | assert(A.size(0) == B.size(0) and A.size(2) == B.size(2)) 29 | 30 | return torch.bmm(A, B.transpose(1,2)) 31 | 32 | 33 | def binv(b_mat): 34 | """ 35 | Computes an inverse of each matrix in the batch. 36 | Pytorch 0.4.1 does not support batched matrix inverse. 37 | Hence, we are solving AX=I. 38 | 39 | Parameters: 40 | b_mat: a (n_batch, n, n) Tensor. 41 | Returns: a (n_batch, n, n) Tensor. 
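    Note: torch.gesv has been removed from recent PyTorch releases; on newer versions,
    torch.linalg.solve(b_mat, id_matrix) performs the same batched solve of AX = I.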
42 | """ 43 | 44 | id_matrix = b_mat.new_ones(b_mat.size(-1)).diag().expand_as(b_mat).cuda() 45 | b_inv, _ = torch.gesv(id_matrix, b_mat) 46 | 47 | return b_inv 48 | 49 | 50 | def one_hot(indices, depth): 51 | """ 52 | Returns a one-hot tensor. 53 | This is a PyTorch equivalent of Tensorflow's tf.one_hot. 54 | 55 | Parameters: 56 | indices: a (n_batch, m) Tensor or (m) Tensor. 57 | depth: a scalar. Represents the depth of the one hot dimension. 58 | Returns: a (n_batch, m, depth) Tensor or (m, depth) Tensor. 59 | """ 60 | 61 | encoded_indicies = torch.zeros(indices.size() + torch.Size([depth])).cuda() 62 | index = indices.view(indices.size()+torch.Size([1])) 63 | encoded_indicies = encoded_indicies.scatter_(1,index,1) 64 | 65 | return encoded_indicies 66 | 67 | def batched_kronecker(matrix1, matrix2): 68 | matrix1_flatten = matrix1.reshape(matrix1.size()[0], -1) 69 | matrix2_flatten = matrix2.reshape(matrix2.size()[0], -1) 70 | return torch.bmm(matrix1_flatten.unsqueeze(2), matrix2_flatten.unsqueeze(1)).reshape([matrix1.size()[0]] + list(matrix1.size()[1:]) + list(matrix2.size()[1:])).permute([0, 1, 3, 2, 4]).reshape(matrix1.size(0), matrix1.size(1) * matrix2.size(1), matrix1.size(2) * matrix2.size(2)) 71 | 72 | 73 | def FFTHead(query, support, support_labels, n_way, n_shot, normalize=True): 74 | """ 75 | Constructs the prototype representation of each class(=mean of support vectors of each class) and 76 | returns the classification score (=L2 distance to each class prototype) on the query set. 77 | 78 | This model is the classification head described in: 79 | Prototypical Networks for Few-shot Learning 80 | (Snell et al., NIPS 2017). 81 | 82 | Parameters: 83 | query: a (tasks_per_batch, n_query, d) Tensor. 84 | support: a (tasks_per_batch, n_support, d) Tensor. 85 | support_labels: a (tasks_per_batch, n_support) Tensor. 86 | n_way: a scalar. Represents the number of classes in a few-shot classification task. 87 | n_shot: a scalar. Represents the number of support examples given per class. 88 | normalize: a boolean. Represents whether if we want to normalize the distances by the embedding dimension. 89 | Returns: a (tasks_per_batch, n_query, n_way) Tensor. 90 | """ 91 | support = F.normalize(support, dim=-3) 92 | query = F.normalize(query, dim=-3) 93 | 94 | tasks_per_batch = query.size(0) 95 | n_support = support.size(1) 96 | n_query = query.size(1) 97 | n_c = query.size(2) 98 | n_w = query.size(3) 99 | n_h = query.size(4) 100 | 101 | assert (n_support == n_way * n_shot) # n_support must equal to n_way * n_shot 102 | 103 | support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way) 104 | support_labels_one_hot = support_labels_one_hot.view(tasks_per_batch, n_support, n_way) 105 | 106 | # From: 107 | # https://github.com/gidariss/FewShotWithoutForgetting/blob/master/architectures/PrototypicalNetworksHead.py 108 | # ************************* Compute Prototypes ************************** 109 | labels_train_transposed = support_labels_one_hot.transpose(1, 2) 110 | # Batch matrix multiplication: 111 | # prototypes = labels_train_transposed * features_train ==> 112 | # [batch_size x nKnovel x num_channels] = 113 | # [batch_size x nKnovel x num_train_examples] * [batch_size * num_train_examples * num_channels] 114 | prototypes = torch.bmm(labels_train_transposed, support.reshape(tasks_per_batch, n_support, -1)) 115 | # Divide with the number of examples per novel category. 
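    # i.e. each prototype is the mean of its class's support features:
    # prototype_k = (1 / n_shot) * sum_{i : y_i = k} support_i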
116 | prototypes = prototypes.div( 117 | labels_train_transposed.sum(dim=2, keepdim=True).expand_as(prototypes) 118 | ) 119 | prototypes = prototypes.reshape(tasks_per_batch, n_way, n_c, n_w, n_h) 120 | # Distance Matrix Vectorization Trick 121 | 122 | prototypes_rfft2 = torch.fft.fftn(prototypes, dim=(3, 4)) 123 | prototypes_rfft2 = torch.roll(prototypes_rfft2, (n_w // 2, n_h // 2), dims=(3, 4)) 124 | query_rfft2 = torch.fft.fftn(query, dim=(3, 4)) 125 | query_rfft2 = torch.roll(query_rfft2, (n_w // 2, n_h // 2), dims=(3, 4)) 126 | prototypes_rfft2_a = torch.abs(prototypes_rfft2) 127 | # a = torch.max(prototypes_rfft2_a, keepdim=True, dim=-1)[0] 128 | # prototypes_rfft2_a = prototypes_rfft2_a / torch.max(torch.max(prototypes_rfft2_a, keepdim=True, dim=-1)[0], keepdim=True, dim=-2)[0] 129 | query_rfft2_a = torch.abs(query_rfft2) 130 | # query_rfft2_a = query_rfft2_a / torch.max(torch.max(query_rfft2_a, keepdim=True, dim=-1)[0], 131 | # keepdim=True, dim=-2)[0] 132 | 133 | prototypes = prototypes_rfft2_a.reshape(tasks_per_batch, n_way, -1) 134 | query = query_rfft2_a.reshape(tasks_per_batch, n_query, -1) 135 | 136 | AB = computeGramMatrix(query, prototypes) 137 | AA = (query * query).sum(dim=2, keepdim=True) 138 | BB = (prototypes * prototypes).sum(dim=2, keepdim=True).reshape(tasks_per_batch, 1, n_way) 139 | logits = AA.expand_as(AB) - 2 * AB + BB.expand_as(AB) 140 | logits = -logits 141 | 142 | if normalize: 143 | logits = logits / query.shape[-1] 144 | 145 | return logits 146 | 147 | 148 | def MetaOptNetHead_Ridge(query, support, support_labels, n_way, n_shot, lambda_reg=50.0, double_precision=False): 149 | """ 150 | Fits the support set with ridge regression and 151 | returns the classification score on the query set. 152 | 153 | Parameters: 154 | query: a (tasks_per_batch, n_query, d) Tensor. 155 | support: a (tasks_per_batch, n_support, d) Tensor. 156 | support_labels: a (tasks_per_batch, n_support) Tensor. 157 | n_way: a scalar. Represents the number of classes in a few-shot classification task. 158 | n_shot: a scalar. Represents the number of support examples given per class. 159 | lambda_reg: a scalar. Represents the strength of L2 regularization. 160 | Returns: a (tasks_per_batch, n_query, n_way) Tensor. 161 | """ 162 | 163 | tasks_per_batch = query.size(0) 164 | n_support = support.size(1) 165 | n_query = query.size(1) 166 | 167 | assert(query.dim() == 3) 168 | assert(support.dim() == 3) 169 | assert(query.size(0) == support.size(0) and query.size(2) == support.size(2)) 170 | assert(n_support == n_way * n_shot) # n_support must equal to n_way * n_shot 171 | 172 | #Here we solve the dual problem: 173 | #Note that the classes are indexed by m & samples are indexed by i. 
174 | #min_{\alpha} 0.5 \sum_m ||w_m(\alpha)||^2 + \sum_i \sum_m e^m_i alpha^m_i 175 | 176 | #where w_m(\alpha) = \sum_i \alpha^m_i x_i, 177 | 178 | #\alpha is an (n_support, n_way) matrix 179 | kernel_matrix = computeGramMatrix(support, support) 180 | kernel_matrix += lambda_reg * torch.eye(n_support).expand(tasks_per_batch, n_support, n_support).cuda() 181 | 182 | block_kernel_matrix = kernel_matrix.repeat(n_way, 1, 1) #(n_way * tasks_per_batch, n_support, n_support) 183 | 184 | support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way) # (tasks_per_batch * n_support, n_way) 185 | support_labels_one_hot = support_labels_one_hot.transpose(0, 1) # (n_way, tasks_per_batch * n_support) 186 | support_labels_one_hot = support_labels_one_hot.reshape(n_way * tasks_per_batch, n_support) # (n_way*tasks_per_batch, n_support) 187 | 188 | G = block_kernel_matrix 189 | e = -2.0 * support_labels_one_hot 190 | 191 | #This is a fake inequlity constraint as qpth does not support QP without an inequality constraint. 192 | id_matrix_1 = torch.zeros(tasks_per_batch*n_way, n_support, n_support) 193 | C = Variable(id_matrix_1) 194 | h = Variable(torch.zeros((tasks_per_batch*n_way, n_support))) 195 | dummy = Variable(torch.Tensor()).cuda() # We want to ignore the equality constraint. 196 | 197 | if double_precision: 198 | G, e, C, h = [x.double().cuda() for x in [G, e, C, h]] 199 | 200 | else: 201 | G, e, C, h = [x.float().cuda() for x in [G, e, C, h]] 202 | 203 | # Solve the following QP to fit SVM: 204 | # \hat z = argmin_z 1/2 z^T G z + e^T z 205 | # subject to Cz <= h 206 | # We use detach() to prevent backpropagation to fixed variables. 207 | qp_sol = QPFunction(verbose=False)(G, e.detach(), C.detach(), h.detach(), dummy.detach(), dummy.detach()) 208 | #qp_sol = QPFunction(verbose=False)(G, e.detach(), dummy.detach(), dummy.detach(), dummy.detach(), dummy.detach()) 209 | 210 | #qp_sol (n_way*tasks_per_batch, n_support) 211 | qp_sol = qp_sol.reshape(n_way, tasks_per_batch, n_support) 212 | #qp_sol (n_way, tasks_per_batch, n_support) 213 | qp_sol = qp_sol.permute(1, 2, 0) 214 | #qp_sol (tasks_per_batch, n_support, n_way) 215 | 216 | # Compute the classification score. 217 | compatibility = computeGramMatrix(support, query) 218 | compatibility = compatibility.float() 219 | compatibility = compatibility.unsqueeze(3).expand(tasks_per_batch, n_support, n_query, n_way) 220 | qp_sol = qp_sol.reshape(tasks_per_batch, n_support, n_way) 221 | logits = qp_sol.float().unsqueeze(2).expand(tasks_per_batch, n_support, n_query, n_way) 222 | logits = logits * compatibility 223 | logits = torch.sum(logits, 1) 224 | 225 | return logits 226 | 227 | def R2D2Head(query, support, support_labels, n_way, n_shot, l2_regularizer_lambda=50.0): 228 | """ 229 | Fits the support set with ridge regression and 230 | returns the classification score on the query set. 231 | 232 | This model is the classification head described in: 233 | Meta-learning with differentiable closed-form solvers 234 | (Bertinetto et al., in submission to NIPS 2018). 235 | 236 | Parameters: 237 | query: a (tasks_per_batch, n_query, d) Tensor. 238 | support: a (tasks_per_batch, n_support, d) Tensor. 239 | support_labels: a (tasks_per_batch, n_support) Tensor. 240 | n_way: a scalar. Represents the number of classes in a few-shot classification task. 241 | n_shot: a scalar. Represents the number of support examples given per class. 242 | l2_regularizer_lambda: a scalar. Represents the strength of L2 regularization. 
243 | Returns: a (tasks_per_batch, n_query, n_way) Tensor. 244 | """ 245 | 246 | tasks_per_batch = query.size(0) 247 | n_support = support.size(1) 248 | 249 | assert(query.dim() == 3) 250 | assert(support.dim() == 3) 251 | assert(query.size(0) == support.size(0) and query.size(2) == support.size(2)) 252 | assert(n_support == n_way * n_shot) # n_support must equal to n_way * n_shot 253 | 254 | support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way) 255 | support_labels_one_hot = support_labels_one_hot.view(tasks_per_batch, n_support, n_way) 256 | 257 | id_matrix = torch.eye(n_support).expand(tasks_per_batch, n_support, n_support).cuda() 258 | 259 | # Compute the dual form solution of the ridge regression. 260 | # W = X^T(X X^T - lambda * I)^(-1) Y 261 | ridge_sol = computeGramMatrix(support, support) + l2_regularizer_lambda * id_matrix 262 | ridge_sol = binv(ridge_sol) 263 | ridge_sol = torch.bmm(support.transpose(1,2), ridge_sol) 264 | ridge_sol = torch.bmm(ridge_sol, support_labels_one_hot) 265 | 266 | # Compute the classification score. 267 | # score = W X 268 | logits = torch.bmm(query, ridge_sol) 269 | 270 | return logits 271 | 272 | 273 | def MetaOptNetHead_SVM_He(query, support, support_labels, n_way, n_shot, C_reg=0.01, double_precision=False): 274 | """ 275 | Fits the support set with multi-class SVM and 276 | returns the classification score on the query set. 277 | 278 | This is the multi-class SVM presented in: 279 | A simplified multi-class support vector machine with reduced dual optimization 280 | (He et al., Pattern Recognition Letter 2012). 281 | 282 | This SVM is desirable because the dual variable of size is n_support 283 | (as opposed to n_way*n_support in the Weston&Watkins or Crammer&Singer multi-class SVM). 284 | This model is the classification head that we have initially used for our project. 285 | This was dropped since it turned out that it performs suboptimally on the meta-learning scenarios. 286 | 287 | Parameters: 288 | query: a (tasks_per_batch, n_query, d) Tensor. 289 | support: a (tasks_per_batch, n_support, d) Tensor. 290 | support_labels: a (tasks_per_batch, n_support) Tensor. 291 | n_way: a scalar. Represents the number of classes in a few-shot classification task. 292 | n_shot: a scalar. Represents the number of support examples given per class. 293 | C_reg: a scalar. Represents the cost parameter C in SVM. 294 | Returns: a (tasks_per_batch, n_query, n_way) Tensor. 295 | """ 296 | 297 | tasks_per_batch = query.size(0) 298 | n_support = support.size(1) 299 | n_query = query.size(1) 300 | 301 | assert(query.dim() == 3) 302 | assert(support.dim() == 3) 303 | assert(query.size(0) == support.size(0) and query.size(2) == support.size(2)) 304 | assert(n_support == n_way * n_shot) # n_support must equal to n_way * n_shot 305 | 306 | 307 | kernel_matrix = computeGramMatrix(support, support) 308 | 309 | V = (support_labels * n_way - torch.ones(tasks_per_batch, n_support, n_way).cuda()) / (n_way - 1) 310 | G = computeGramMatrix(V, V).detach() 311 | G = kernel_matrix * G 312 | 313 | e = Variable(-1.0 * torch.ones(tasks_per_batch, n_support)) 314 | id_matrix = torch.eye(n_support).expand(tasks_per_batch, n_support, n_support) 315 | C = Variable(torch.cat((id_matrix, -id_matrix), 1)) 316 | h = Variable(torch.cat((C_reg * torch.ones(tasks_per_batch, n_support), torch.zeros(tasks_per_batch, n_support)), 1)) 317 | dummy = Variable(torch.Tensor()).cuda() # We want to ignore the equality constraint. 
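    # Together, C and h encode the box constraint 0 <= z_i <= C_reg:
    # the id_matrix rows give z_i <= C_reg and the -id_matrix rows give -z_i <= 0 (i.e. z_i >= 0).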
318 | 319 | if double_precision: 320 | G, e, C, h = [x.double().cuda() for x in [G, e, C, h]] 321 | else: 322 | G, e, C, h = [x.cuda() for x in [G, e, C, h]] 323 | 324 | # Solve the following QP to fit SVM: 325 | # \hat z = argmin_z 1/2 z^T G z + e^T z 326 | # subject to Cz <= h 327 | # We use detach() to prevent backpropagation to fixed variables. 328 | qp_sol = QPFunction(verbose=False)(G, e.detach(), C.detach(), h.detach(), dummy.detach(), dummy.detach()) 329 | 330 | # Compute the classification score. 331 | compatibility = computeGramMatrix(query, support) 332 | compatibility = compatibility.float() 333 | 334 | logits = qp_sol.float().unsqueeze(1).expand(tasks_per_batch, n_query, n_support) 335 | logits = logits * compatibility 336 | logits = logits.view(tasks_per_batch, n_query, n_shot, n_way) 337 | logits = torch.sum(logits, 2) 338 | 339 | return logits 340 | 341 | 342 | def CosineNetHead(query, support, support_labels, n_way, n_shot, normalize=True): 343 | """ 344 | Constructs the prototype representation of each class(=mean of support vectors of each class) and 345 | returns the classification score (=L2 distance to each class prototype) on the query set. 346 | 347 | This model is the classification head described in: 348 | Prototypical Networks for Few-shot Learning 349 | (Snell et al., NIPS 2017). 350 | 351 | Parameters: 352 | query: a (tasks_per_batch, n_query, d) Tensor. 353 | support: a (tasks_per_batch, n_support, d) Tensor. 354 | support_labels: a (tasks_per_batch, n_support) Tensor. 355 | n_way: a scalar. Represents the number of classes in a few-shot classification task. 356 | n_shot: a scalar. Represents the number of support examples given per class. 357 | normalize: a boolean. Represents whether if we want to normalize the distances by the embedding dimension. 358 | Returns: a (tasks_per_batch, n_query, n_way) Tensor. 359 | """ 360 | 361 | tasks_per_batch = query.size(0) 362 | n_support = support.size(1) 363 | n_query = query.size(1) 364 | d = query.size(2) 365 | 366 | assert (query.dim() == 3) 367 | assert (support.dim() == 3) 368 | assert (query.size(0) == support.size(0) and query.size(2) == support.size(2)) 369 | assert (n_support == n_way * n_shot) # n_support must equal to n_way * n_shot 370 | 371 | support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way) 372 | support_labels_one_hot = support_labels_one_hot.view(tasks_per_batch, n_support, n_way) 373 | 374 | # From: 375 | # https://github.com/gidariss/FewShotWithoutForgetting/blob/master/architectures/PrototypicalNetworksHead.py 376 | # ************************* Compute Prototypes ************************** 377 | labels_train_transposed = support_labels_one_hot.transpose(1, 2) 378 | # Batch matrix multiplication: 379 | # prototypes = labels_train_transposed * features_train ==> 380 | # [batch_size x nKnovel x num_channels] = 381 | # [batch_size x nKnovel x num_train_examples] * [batch_size * num_train_examples * num_channels] 382 | prototypes = torch.bmm(labels_train_transposed, support) 383 | # Divide with the number of examples per novel category. 
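    # The division below turns the per-class sums into class means; each query is then scored by
    # plain cosine similarity to every class mean (no temperature is applied inside this head,
    # unlike the commented-out *10 variant further down).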
384 | prototypes = prototypes.div( 385 | labels_train_transposed.sum(dim=2, keepdim=True).expand_as(prototypes) 386 | ) 387 | 388 | # Distance Matrix Vectorization Trick 389 | logits = torch.nn.functional.cosine_similarity(query.unsqueeze(2).expand(-1, -1, prototypes.shape[1], -1), 390 | prototypes.unsqueeze(1).expand(-1, query.shape[1], -1, -1), dim=-1) 391 | 392 | # logits = torch.nn.CosineSimilarity(dim=-1)(query.unsqueeze(2).expand(-1, -1, prototypes.shape[1], -1), 393 | # prototypes.unsqueeze(1).expand(-1, query.shape[1], -1, -1))*10 394 | 395 | return logits 396 | 397 | 398 | def emd_inference_qpth(distance_matrix, weight1, weight2, form='QP', l2_strength=0.0001): 399 | """ 400 | to use the QP solver QPTH to derive EMD (LP problem), 401 | one can transform the LP problem to QP, 402 | or omit the QP term by multiplying it with a small value,i.e. l2_strngth. 403 | :param distance_matrix: nbatch * element_number * element_number 404 | :param weight1: nbatch * weight_number 405 | :param weight2: nbatch * weight_number 406 | :return: 407 | emd distance: nbatch*1 408 | flow : nbatch * weight_number *weight_number 409 | """ 410 | 411 | weight1 = (weight1 * weight1.shape[-1]) / weight1.sum(1).unsqueeze(1) 412 | weight2 = (weight2 * weight2.shape[-1]) / weight2.sum(1).unsqueeze(1) 413 | 414 | nbatch = distance_matrix.shape[0] 415 | nelement_distmatrix = distance_matrix.shape[1] * distance_matrix.shape[2] 416 | nelement_weight1 = weight1.shape[1] 417 | nelement_weight2 = weight2.shape[1] 418 | 419 | Q_1 = distance_matrix.view(-1, 1, nelement_distmatrix).double() 420 | 421 | if form == 'QP': 422 | # version: QTQ 423 | Q = torch.bmm(Q_1.transpose(2, 1), Q_1).double().cuda() + 1e-4 * torch.eye( 424 | nelement_distmatrix).double().cuda().unsqueeze(0).repeat(nbatch, 1, 1) # 0.00001 * 425 | p = torch.zeros(nbatch, nelement_distmatrix).double().cuda() 426 | elif form == 'L2': 427 | # version: regularizer 428 | Q = (l2_strength * torch.eye(nelement_distmatrix).double()).cuda().unsqueeze(0).repeat(nbatch, 1, 1) 429 | p = distance_matrix.view(nbatch, nelement_distmatrix).double() 430 | else: 431 | raise ValueError('Unkown form') 432 | 433 | h_1 = torch.zeros(nbatch, nelement_distmatrix).double().cuda() 434 | h_2 = torch.cat([weight1, weight2], 1).double() 435 | h = torch.cat((h_1, h_2), 1) 436 | 437 | G_1 = -torch.eye(nelement_distmatrix).double().cuda().unsqueeze(0).repeat(nbatch, 1, 1) 438 | G_2 = torch.zeros([nbatch, nelement_weight1 + nelement_weight2, nelement_distmatrix]).double().cuda() 439 | # sum_j(xij) = si 440 | for i in range(nelement_weight1): 441 | G_2[:, i, nelement_weight2 * i:nelement_weight2 * (i + 1)] = 1 442 | # sum_i(xij) = dj 443 | for j in range(nelement_weight2): 444 | G_2[:, nelement_weight1 + j, j::nelement_weight2] = 1 445 | #xij>=0, sum_j(xij) <= si,sum_i(xij) <= dj, sum_ij(x_ij) = min(sum(si), sum(dj)) 446 | G = torch.cat((G_1, G_2), 1) 447 | A = torch.ones(nbatch, 1, nelement_distmatrix).double().cuda() 448 | b = torch.min(torch.sum(weight1, 1), torch.sum(weight2, 1)).unsqueeze(1).double() 449 | flow = QPFunction(verbose=-1)(Q, p, G, h, A, b) 450 | 451 | emd_score = torch.sum((1 - Q_1).squeeze() * flow, 1) 452 | return emd_score, flow.view(-1, nelement_weight1, nelement_weight2) 453 | 454 | def emd_inference_opencv(cost_matrix, weight1, weight2): 455 | # cost matrix is a tensor of shape [N,N] 456 | cost_matrix = cost_matrix.detach().cpu().numpy() 457 | 458 | # weight1 = F.relu(weight1) + 1e-5 459 | # weight2 = F.relu(weight2) + 1e-5 460 | # 461 | # weight1 = (weight1 * 
(weight1.shape[0] / weight1.sum().item())).view(-1, 1).detach().cpu().numpy() 462 | # weight2 = (weight2 * (weight2.shape[0] / weight2.sum().item())).view(-1, 1).detach().cpu().numpy() 463 | weight1 = weight1.detach().cpu().numpy() 464 | weight2 = weight2.detach().cpu().numpy() 465 | # print(np.sum(weight1)) 466 | # print(np.sum(weight2)) 467 | cost, _, flow = cv2.EMD(weight1, weight2, cv2.DIST_USER, cost_matrix) 468 | return cost, flow 469 | 470 | def emd_inference_opencv_test(distance_matrix,weight1,weight2): 471 | distance_list = [] 472 | flow_list = [] 473 | 474 | for i in range (distance_matrix.shape[0]): 475 | cost,flow=emd_inference_opencv(distance_matrix[i],weight1[i],weight2[i]) 476 | distance_list.append(cost) 477 | flow_list.append(torch.from_numpy(flow)) 478 | 479 | emd_distance = torch.Tensor(distance_list).cuda().double() 480 | flow = torch.stack(flow_list, dim=0).cuda().double() 481 | 482 | return emd_distance,flow 483 | 484 | def LocalNetHead(query, support, support_labels, n_way, n_shot, normalize=True): 485 | """ 486 | Constructs the prototype representation of each class(=mean of support vectors of each class) and 487 | returns the classification score (=L2 distance to each class prototype) on the query set. 488 | 489 | This model is the classification head described in: 490 | Prototypical Networks for Few-shot Learning 491 | (Snell et al., NIPS 2017). 492 | 493 | Parameters: 494 | query: a (tasks_per_batch, n_query, d) Tensor. 495 | support: a (tasks_per_batch, n_support, d) Tensor. 496 | support_labels: a (tasks_per_batch, n_support) Tensor. 497 | n_way: a scalar. Represents the number of classes in a few-shot classification task. 498 | n_shot: a scalar. Represents the number of support examples given per class. 499 | normalize: a boolean. Represents whether if we want to normalize the distances by the embedding dimension. 500 | Returns: a (tasks_per_batch, n_query, n_way) Tensor. 501 | """ 502 | 503 | tasks_per_batch = query.size(0) 504 | n_support = support.size(1) 505 | n_query = query.size(1) 506 | d = query.size(2) 507 | 508 | assert (query.dim() == 3) 509 | assert (support.dim() == 3) 510 | assert (query.size(0) == support.size(0) and query.size(2) == support.size(2)) 511 | assert (n_support == n_way * n_shot) # n_support must equal to n_way * n_shot 512 | 513 | support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way) 514 | support_labels_one_hot = support_labels_one_hot.view(tasks_per_batch, n_support, n_way) 515 | 516 | # From: 517 | # https://github.com/gidariss/FewShotWithoutForgetting/blob/master/architectures/PrototypicalNetworksHead.py 518 | # ************************* Compute Prototypes ************************** 519 | labels_train_transposed = support_labels_one_hot.transpose(1, 2) 520 | # Batch matrix multiplication: 521 | # prototypes = labels_train_transposed * features_train ==> 522 | # [batch_size x nKnovel x num_channels] = 523 | # [batch_size x nKnovel x num_train_examples] * [batch_size * num_train_examples * num_channels] 524 | prototypes = torch.bmm(labels_train_transposed, support) 525 | # Divide with the number of examples per novel category. 
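    # The class means computed below are not used by the EMD-based scoring that follows, which
    # instead works on the concatenated support/query features and their local distributions
    # (x2p_torch) further down.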
526 | prototypes = prototypes.div( 527 | labels_train_transposed.sum(dim=2, keepdim=True).expand_as(prototypes) 528 | ) 529 | data = torch.cat([support, query], dim=1) 530 | data_local_dis = [] 531 | for i in range(tasks_per_batch): 532 | dis = x2p_torch(data[i], tol=1e-5, perplexity=10.0) 533 | data_local_dis.append(dis) 534 | data_local_dis = torch.stack(data_local_dis, dim=0) 535 | # Distance Matrix Vectorization Trick 536 | # AB = computeGramMatrix(query, prototypes) 537 | # AA = (query * query).sum(dim=2, keepdim=True) 538 | # BB = (prototypes * prototypes).sum(dim=2, keepdim=True).reshape(tasks_per_batch, 1, n_way) 539 | # logits = AA.expand_as(AB) - 2 * AB + BB.expand_as(AB) 540 | # logits = -logits 541 | # 542 | # if normalize: 543 | # logits = logits / d 544 | cosine_distance_matrix = 1 - torch.nn.functional.cosine_similarity(data.unsqueeze(2).expand(-1, -1, data.shape[1], -1), 545 | data.unsqueeze(1).expand(-1, data.shape[1], -1, -1), dim=-1) 546 | support_dist = data_local_dis[:, :n_support, :].unsqueeze(2).expand(-1, -1, n_query, -1).reshape(tasks_per_batch*n_support*n_query, -1) 547 | query_dist = data_local_dis[:, n_support:, :].unsqueeze(1).expand(-1, n_support, -1, -1).reshape(tasks_per_batch*n_support*n_query, -1) 548 | cost_matrix = cosine_distance_matrix.unsqueeze(1).expand(-1, n_support*n_query, -1, -1).reshape(tasks_per_batch*n_support*n_query, n_support+n_query, n_support+n_query) 549 | # emd_distance_qpth, qpth_flow = emd_inference_qpth(cost_matrix, support_dist, query_dist) 550 | emd_distance_qpth, qpth_flow = emd_inference_opencv_test(cost_matrix, support_dist, query_dist) 551 | logits = emd_distance_qpth.reshape(tasks_per_batch, n_support, n_query).permute(0, 2, 1).type_as(support) 552 | return logits 553 | 554 | def ProtoNetHead(query, support, support_labels, n_way, n_shot, normalize=True): 555 | """ 556 | Constructs the prototype representation of each class(=mean of support vectors of each class) and 557 | returns the classification score (=L2 distance to each class prototype) on the query set. 558 | 559 | This model is the classification head described in: 560 | Prototypical Networks for Few-shot Learning 561 | (Snell et al., NIPS 2017). 562 | 563 | Parameters: 564 | query: a (tasks_per_batch, n_query, d) Tensor. 565 | support: a (tasks_per_batch, n_support, d) Tensor. 566 | support_labels: a (tasks_per_batch, n_support) Tensor. 567 | n_way: a scalar. Represents the number of classes in a few-shot classification task. 568 | n_shot: a scalar. Represents the number of support examples given per class. 569 | normalize: a boolean. Represents whether if we want to normalize the distances by the embedding dimension. 570 | Returns: a (tasks_per_batch, n_query, n_way) Tensor. 
571 | """ 572 | 573 | tasks_per_batch = query.size(0) 574 | n_support = support.size(1) 575 | n_query = query.size(1) 576 | d = query.size(2) 577 | 578 | assert(query.dim() == 3) 579 | assert(support.dim() == 3) 580 | assert(query.size(0) == support.size(0) and query.size(2) == support.size(2)) 581 | assert(n_support == n_way * n_shot) # n_support must equal to n_way * n_shot 582 | 583 | support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way) 584 | support_labels_one_hot = support_labels_one_hot.view(tasks_per_batch, n_support, n_way) 585 | 586 | # From: 587 | # https://github.com/gidariss/FewShotWithoutForgetting/blob/master/architectures/PrototypicalNetworksHead.py 588 | #************************* Compute Prototypes ************************** 589 | labels_train_transposed = support_labels_one_hot.transpose(1,2) 590 | # Batch matrix multiplication: 591 | # prototypes = labels_train_transposed * features_train ==> 592 | # [batch_size x nKnovel x num_channels] = 593 | # [batch_size x nKnovel x num_train_examples] * [batch_size * num_train_examples * num_channels] 594 | prototypes = torch.bmm(labels_train_transposed, support) 595 | # Divide with the number of examples per novel category. 596 | prototypes = prototypes.div( 597 | labels_train_transposed.sum(dim=2, keepdim=True).expand_as(prototypes) 598 | ) 599 | 600 | # Distance Matrix Vectorization Trick 601 | AB = computeGramMatrix(query, prototypes) 602 | AA = (query * query).sum(dim=2, keepdim=True) 603 | BB = (prototypes * prototypes).sum(dim=2, keepdim=True).reshape(tasks_per_batch, 1, n_way) 604 | logits = AA.expand_as(AB) - 2 * AB + BB.expand_as(AB) 605 | logits = -logits 606 | 607 | if normalize: 608 | logits = logits / d 609 | 610 | return logits 611 | 612 | def MetaOptNetHead_SVM_CS(query, support, support_labels, n_way, n_shot, C_reg=0.1, double_precision=False, maxIter=15): 613 | """ 614 | Fits the support set with multi-class SVM and 615 | returns the classification score on the query set. 616 | 617 | This is the multi-class SVM presented in: 618 | On the Algorithmic Implementation of Multiclass Kernel-based Vector Machines 619 | (Crammer and Singer, Journal of Machine Learning Research 2001). 620 | 621 | This model is the classification head that we use for the final version. 622 | Parameters: 623 | query: a (tasks_per_batch, n_query, d) Tensor. 624 | support: a (tasks_per_batch, n_support, d) Tensor. 625 | support_labels: a (tasks_per_batch, n_support) Tensor. 626 | n_way: a scalar. Represents the number of classes in a few-shot classification task. 627 | n_shot: a scalar. Represents the number of support examples given per class. 628 | C_reg: a scalar. Represents the cost parameter C in SVM. 629 | Returns: a (tasks_per_batch, n_query, n_way) Tensor. 630 | """ 631 | 632 | tasks_per_batch = query.size(0) 633 | n_support = support.size(1) 634 | n_query = query.size(1) 635 | 636 | assert(query.dim() == 3) 637 | assert(support.dim() == 3) 638 | assert(query.size(0) == support.size(0) and query.size(2) == support.size(2)) 639 | assert(n_support == n_way * n_shot) # n_support must equal to n_way * n_shot 640 | 641 | #Here we solve the dual problem: 642 | #Note that the classes are indexed by m & samples are indexed by i. 643 | #min_{\alpha} 0.5 \sum_m ||w_m(\alpha)||^2 + \sum_i \sum_m e^m_i alpha^m_i 644 | #s.t. \alpha^m_i <= C^m_i \forall m,i , \sum_m \alpha^m_i=0 \forall i 645 | 646 | #where w_m(\alpha) = \sum_i \alpha^m_i x_i, 647 | #and C^m_i = C if m = y_i, 648 | #C^m_i = 0 if m != y_i. 
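    #(At the optimum, w_m(\alpha) = \sum_i \alpha^m_i x_i, so a query q is scored for class m by
    # w_m^T q = \sum_i \alpha^m_i <x_i, q>; this is what the qp_sol * compatibility product
    # at the end of this function computes.)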
649 | #This borrows the notation of liblinear. 650 | 651 | #\alpha is an (n_support, n_way) matrix 652 | kernel_matrix = computeGramMatrix(support, support) 653 | 654 | id_matrix_0 = torch.eye(n_way).expand(tasks_per_batch, n_way, n_way).cuda() 655 | block_kernel_matrix = batched_kronecker(kernel_matrix, id_matrix_0) 656 | #This seems to help avoid PSD error from the QP solver. 657 | block_kernel_matrix += 1.0 * torch.eye(n_way*n_support).expand(tasks_per_batch, n_way*n_support, n_way*n_support).cuda() 658 | 659 | support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way) # (tasks_per_batch * n_support, n_support) 660 | support_labels_one_hot = support_labels_one_hot.view(tasks_per_batch, n_support, n_way) 661 | support_labels_one_hot = support_labels_one_hot.reshape(tasks_per_batch, n_support * n_way) 662 | 663 | G = block_kernel_matrix 664 | e = -1.0 * support_labels_one_hot 665 | #print (G.size()) 666 | #This part is for the inequality constraints: 667 | #\alpha^m_i <= C^m_i \forall m,i 668 | #where C^m_i = C if m = y_i, 669 | #C^m_i = 0 if m != y_i. 670 | id_matrix_1 = torch.eye(n_way * n_support).expand(tasks_per_batch, n_way * n_support, n_way * n_support) 671 | C = Variable(id_matrix_1) 672 | h = Variable(C_reg * support_labels_one_hot) 673 | #print (C.size(), h.size()) 674 | #This part is for the equality constraints: 675 | #\sum_m \alpha^m_i=0 \forall i 676 | id_matrix_2 = torch.eye(n_support).expand(tasks_per_batch, n_support, n_support).cuda() 677 | 678 | A = Variable(batched_kronecker(id_matrix_2, torch.ones(tasks_per_batch, 1, n_way).cuda())) 679 | b = Variable(torch.zeros(tasks_per_batch, n_support)) 680 | #print (A.size(), b.size()) 681 | if double_precision: 682 | G, e, C, h, A, b = [x.double().cuda() for x in [G, e, C, h, A, b]] 683 | else: 684 | G, e, C, h, A, b = [x.float().cuda() for x in [G, e, C, h, A, b]] 685 | 686 | # Solve the following QP to fit SVM: 687 | # \hat z = argmin_z 1/2 z^T G z + e^T z 688 | # subject to Cz <= h 689 | # We use detach() to prevent backpropagation to fixed variables. 690 | qp_sol = QPFunction(verbose=False, maxIter=maxIter)(G, e.detach(), C.detach(), h.detach(), A.detach(), b.detach()) 691 | 692 | # Compute the classification score. 693 | compatibility = computeGramMatrix(support, query) 694 | compatibility = compatibility.float() 695 | compatibility = compatibility.unsqueeze(3).expand(tasks_per_batch, n_support, n_query, n_way) 696 | qp_sol = qp_sol.reshape(tasks_per_batch, n_support, n_way) 697 | logits = qp_sol.float().unsqueeze(2).expand(tasks_per_batch, n_support, n_query, n_way) 698 | logits = logits * compatibility 699 | logits = torch.sum(logits, 1) 700 | 701 | return logits 702 | 703 | def MetaOptNetHead_SVM_WW(query, support, support_labels, n_way, n_shot, C_reg=0.00001, double_precision=False): 704 | """ 705 | Fits the support set with multi-class SVM and 706 | returns the classification score on the query set. 707 | 708 | This is the multi-class SVM presented in: 709 | Support Vector Machines for Multi Class Pattern Recognition 710 | (Weston and Watkins, ESANN 1999). 711 | 712 | Parameters: 713 | query: a (tasks_per_batch, n_query, d) Tensor. 714 | support: a (tasks_per_batch, n_support, d) Tensor. 715 | support_labels: a (tasks_per_batch, n_support) Tensor. 716 | n_way: a scalar. Represents the number of classes in a few-shot classification task. 717 | n_shot: a scalar. Represents the number of support examples given per class. 718 | C_reg: a scalar. Represents the cost parameter C in SVM. 
719 | Returns: a (tasks_per_batch, n_query, n_way) Tensor.
720 | """
738 | tasks_per_batch = query.size(0)
739 | n_support = support.size(1)
740 | n_query = query.size(1)
741 | 
742 | assert(query.dim() == 3)
743 | assert(support.dim() == 3)
744 | assert(query.size(0) == support.size(0) and query.size(2) == support.size(2))
745 | assert(n_support == n_way * n_shot) # n_support must equal n_way * n_shot
746 | 
747 | #In theory, \alpha is an (n_support, n_way) matrix
748 | #NOTE: In this implementation, we solve for a flattened vector of size (n_way*n_support)
749 | #In order to turn it into a matrix, you must first reshape it into an (n_way, n_support) matrix
750 | #then transpose it, resulting in an (n_support, n_way) matrix
751 | kernel_matrix = computeGramMatrix(support, support) + torch.ones(tasks_per_batch, n_support, n_support).cuda()
752 | 
753 | id_matrix_0 = torch.eye(n_way).expand(tasks_per_batch, n_way, n_way).cuda()
754 | block_kernel_matrix = batched_kronecker(id_matrix_0, kernel_matrix)
755 | 
756 | kernel_matrix_mask_x = support_labels.reshape(tasks_per_batch, n_support, 1).expand(tasks_per_batch, n_support, n_support)
757 | kernel_matrix_mask_y = support_labels.reshape(tasks_per_batch, 1, n_support).expand(tasks_per_batch, n_support, n_support)
758 | kernel_matrix_mask = (kernel_matrix_mask_x == kernel_matrix_mask_y).float()
759 | 
760 | block_kernel_matrix_inter = kernel_matrix_mask * kernel_matrix
761 | block_kernel_matrix += block_kernel_matrix_inter.repeat(1, n_way, n_way)
762 | 
763 | kernel_matrix_mask_second_term = support_labels.reshape(tasks_per_batch, n_support, 1).expand(tasks_per_batch, n_support, n_support * n_way)
764 | kernel_matrix_mask_second_term = kernel_matrix_mask_second_term == torch.arange(n_way).long().repeat(n_support).reshape(n_support, n_way).transpose(1, 0).reshape(1, -1).repeat(n_support, 1).cuda()
765 | kernel_matrix_mask_second_term = kernel_matrix_mask_second_term.float()
766 | 
767 | block_kernel_matrix -= (2.0 - 1e-4) * (kernel_matrix_mask_second_term * kernel_matrix.repeat(1, 1, n_way)).repeat(1, n_way, 1)
768 | 
769 | Y_support = one_hot(support_labels.view(tasks_per_batch * n_support), n_way)
770 | Y_support = Y_support.view(tasks_per_batch, n_support, n_way)
771 | Y_support = Y_support.transpose(1, 2) # (tasks_per_batch, n_way, n_support)
772 | Y_support = Y_support.reshape(tasks_per_batch, n_way * n_support)
773 | 
774 | G = block_kernel_matrix
775 | 
776 | e = -2.0 * torch.ones(tasks_per_batch, n_way * n_support)
777 | id_matrix = torch.eye(n_way * n_support).expand(tasks_per_batch, n_way * n_support, n_way * n_support)
778 | 
779 | C_mat = C_reg * torch.ones(tasks_per_batch, n_way * n_support).cuda() - C_reg * Y_support
780 | 
781 | C = Variable(torch.cat((id_matrix, -id_matrix),
1)) 782 | #C = Variable(torch.cat((id_matrix_masked, -id_matrix_masked), 1)) 783 | zer = torch.zeros(tasks_per_batch, n_way * n_support).cuda() 784 | 785 | h = Variable(torch.cat((C_mat, zer), 1)) 786 | 787 | dummy = Variable(torch.Tensor()).cuda() # We want to ignore the equality constraint. 788 | 789 | if double_precision: 790 | G, e, C, h = [x.double().cuda() for x in [G, e, C, h]] 791 | else: 792 | G, e, C, h = [x.cuda() for x in [G, e, C, h]] 793 | 794 | # Solve the following QP to fit SVM: 795 | # \hat z = argmin_z 1/2 z^T G z + e^T z 796 | # subject to Cz <= h 797 | # We use detach() to prevent backpropagation to fixed variables. 798 | #qp_sol = QPFunction(verbose=False)(G, e.detach(), C.detach(), h.detach(), dummy.detach(), dummy.detach()) 799 | qp_sol = QPFunction(verbose=False)(G, e, C, h, dummy.detach(), dummy.detach()) 800 | 801 | # Compute the classification score. 802 | compatibility = computeGramMatrix(support, query) + torch.ones(tasks_per_batch, n_support, n_query).cuda() 803 | compatibility = compatibility.float() 804 | compatibility = compatibility.unsqueeze(1).expand(tasks_per_batch, n_way, n_support, n_query) 805 | qp_sol = qp_sol.float() 806 | qp_sol = qp_sol.reshape(tasks_per_batch, n_way, n_support) 807 | A_i = torch.sum(qp_sol, 1) # (tasks_per_batch, n_support) 808 | A_i = A_i.unsqueeze(1).expand(tasks_per_batch, n_way, n_support) 809 | qp_sol = qp_sol.float().unsqueeze(3).expand(tasks_per_batch, n_way, n_support, n_query) 810 | Y_support_reshaped = Y_support.reshape(tasks_per_batch, n_way, n_support) 811 | Y_support_reshaped = A_i * Y_support_reshaped 812 | Y_support_reshaped = Y_support_reshaped.unsqueeze(3).expand(tasks_per_batch, n_way, n_support, n_query) 813 | logits = (Y_support_reshaped - qp_sol) * compatibility 814 | 815 | logits = torch.sum(logits, 2) 816 | 817 | return logits.transpose(1, 2) 818 | 819 | class ClassificationHead(nn.Module): 820 | def __init__(self, base_learner='MetaOptNet', enable_scale=True): 821 | super(ClassificationHead, self).__init__() 822 | if ('SVM-CS' in base_learner): 823 | self.head = MetaOptNetHead_SVM_CS 824 | elif ('Ridge' in base_learner): 825 | self.head = MetaOptNetHead_Ridge 826 | elif ('R2D2' in base_learner): 827 | self.head = R2D2Head 828 | elif ('Proto' in base_learner): 829 | self.head = ProtoNetHead 830 | elif ('Local' in base_learner): 831 | self.head = LocalNetHead 832 | elif ('Cosine' in base_learner): 833 | self.head = CosineNetHead 834 | elif ('SVM-He' in base_learner): 835 | self.head = MetaOptNetHead_SVM_He 836 | elif ('SVM-WW' in base_learner): 837 | self.head = MetaOptNetHead_SVM_WW 838 | elif ('FFT' in base_learner): 839 | self.head = FFTHead 840 | else: 841 | print ("Cannot recognize the base learner type") 842 | assert(False) 843 | 844 | # Add a learnable scale 845 | self.enable_scale = enable_scale 846 | self.scale = nn.Parameter(torch.FloatTensor([1.0])) 847 | 848 | def forward(self, query, support, support_labels, n_way, n_shot, **kwargs): 849 | if self.enable_scale: 850 | return self.scale * self.head(query, support, support_labels, n_way, n_shot, **kwargs) 851 | else: 852 | return self.head(query, support, support_labels, n_way, n_shot, **kwargs) * 10 853 | --------------------------------------------------------------------------------
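Usage note (not part of the repository files above): the heads defined in classification_heads.py share one calling convention, and ClassificationHead simply dispatches to the selected head and applies the learnable scale. The sketch below is illustrative only; it assumes the file is importable as models.classification_heads (per the repository tree), that a CUDA device is available (several heads call .cuda() internally), and that the QP solver behind QPFunction is installed when an SVM head is chosen. Random tensors stand in for backbone features.

import torch
from models.classification_heads import ClassificationHead  # assumed import path

n_way, n_shot = 5, 1           # a 5-way 1-shot episode
n_query = 15                   # query examples per task
tasks_per_batch, d = 4, 640    # episodes per batch; d is the embedding size (e.g. 640 for a ResNet-12 backbone)

# Class index of each support example, laid out as n_shot consecutive examples per class.
support_labels = torch.arange(n_way).repeat_interleave(n_shot).repeat(tasks_per_batch, 1).cuda()
support = torch.randn(tasks_per_batch, n_way * n_shot, d).cuda()
query = torch.randn(tasks_per_batch, n_query, d).cuda()

# 'SVM-CS' selects MetaOptNetHead_SVM_CS; 'Proto', 'R2D2', 'SVM-WW', etc. select the other heads.
head = ClassificationHead(base_learner='SVM-CS', enable_scale=True).cuda()

logits = head(query, support, support_labels, n_way, n_shot)
print(logits.shape)  # expected: (tasks_per_batch, n_query, n_way)

With enable_scale=True the logits are multiplied by the learnable self.scale before the loss; with enable_scale=False the forward pass above instead applies a fixed factor of 10.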