├── README.md ├── classification ├── README.md ├── convnet.py ├── materials │ ├── test.csv │ ├── train.csv │ └── val.csv ├── miniimagenet.py ├── model │ ├── maml.py │ ├── maml_pcg.py │ └── pcg_module.py ├── modified_pytorchmodule.py ├── run_test_modgrad.py ├── samplers.py ├── test_modgrad.py ├── train_modgrad.py └── utils.py └── comparison_method.png /README.md: -------------------------------------------------------------------------------- 1 | # On Modulating the Gradient for Meta-Learning 2 | 3 | The repository contains the code for: 4 |
5 | [On Modulating the Gradient for Meta-Learning](http://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123530545.pdf) 6 |
7 | [Supp. Material](https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123530545-supp.pdf) 8 |
9 | ECCV 2020 10 | 11 | Comparison with prior methods: MAML, Meta-SGD, and Ours. 12 | 13 |
14 | 15 | 16 | 17 | 18 | 19 | Please refer to the folder for each task: classification, regression, and RL. 20 | 21 | 22 | ## Citation 23 | 24 | ```` 25 | @inproceedings{Christian2020ModGrad, 26 | author = {Simon, Christian and Koniusz, Piotr and Nock, Richard and Harandi, Mehrtash}, 27 | title = {On Modulating the Gradient for Meta-Learning}, 28 | booktitle = {The European Conference on Computer Vision}, 29 | year = {2020} 30 | } 31 | ```` 32 | 33 | 34 | ## Acknowledgement 35 | Thanks to the authors of the CAVIA code base: 36 | the regression and RL tasks are adapted from https://github.com/lmzintgraf/cavia 37 | 38 | -------------------------------------------------------------------------------- /classification/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | To download mini-ImageNet: 4 | use the Google Drive link [here](https://drive.google.com/file/d/1HkgrkAwukzEZA0TpO7010PkAOREb2Nuk/view) to directly 5 | download the `mini-imagenet.zip` file. This mini-ImageNet split follows https://github.com/Clarifai/few-shot-ctm. 6 | 7 | Set `--data-path` to the folder where the data is stored. 8 | 9 | Run train_modgrad.py for training and run_test_modgrad.py for testing. 
def conv_block(in_channels, out_channels):
    """Return a Conv(3x3, padding 1) -> BatchNorm -> ReLU -> MaxPool(2) stage."""
    layers = [
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2),
    ]
    return nn.Sequential(*layers)


def conv_block_fast(in_channels, out_channels):
    """Return the same stage as ``conv_block`` built from the ``*_fw`` layers.

    ``Conv2d_fw`` and ``BatchNorm2d_fw`` are imported from
    ``modified_pytorchmodule``.
    """
    layers = [
        Conv2d_fw(in_channels, out_channels, 3, padding=1),
        BatchNorm2d_fw(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2),
    ]
    return nn.Sequential(*layers)


class ConvNet_MAML(nn.Module):
    """Four-stage convolutional encoder that returns flattened features."""

    def __init__(self, x_dim=3, hid_dim=64, z_dim=64):
        """Stack four ``conv_block_fast`` stages: x_dim -> hid -> hid -> hid -> z_dim."""
        super().__init__()
        widths = [x_dim, hid_dim, hid_dim, hid_dim, z_dim]
        stages = [
            conv_block_fast(c_in, c_out)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]
        self.encoder = nn.Sequential(*stages)

        # Hard-coded flattened feature size. 1600 == 64 * 5 * 5, i.e. the
        # default z_dim with a 5x5 map (84 halved four times by MaxPool2d).
        # NOTE(review): not recomputed if z_dim or the input size changes.
        self.out_channels = 1600

    def forward(self, x):
        """Encode ``x`` and flatten everything but the batch dimension."""
        features = self.encoder(x)
        batch = features.size(0)
        return features.view(batch, -1)
class MAML(nn.Module):
    """Few-shot classifier with a first-order, single-tensor inner loop.

    A ``ConvNet_MAML`` encoder feeds a ``Linear_fw`` classifier. During
    ``forward`` only the parameter at position ``self.idx`` (in
    ``self.parameters()`` order) receives gradient-descent updates on the
    support set; every other parameter's ``.fast`` aliases the parameter
    itself.
    """

    def __init__(self, n_way, n_shot, train_lr=0.1, noise_rate=0.):
        """
        Args:
            n_way: number of classes per episode (classifier output size).
            n_shot: number of support samples per class.
            train_lr: step size of the inner-loop update.
            noise_rate: std of the Gaussian noise drawn per parameter in
                ``forward`` (the noise is currently not applied).
        """
        super().__init__()
        self.cnn = ConvNet_MAML()
        self.classifier = Linear_fw(self.cnn.out_channels, n_way)
        self.train_lr = train_lr
        self.n_way = n_way
        self.n_shot = n_shot
        self.noise_rate = noise_rate
        # Index of the only parameter adapted in the inner loop.
        # NOTE(review): presumably the classifier weight (4 stages x 4
        # parameters precede it) - confirm against ConvNet_MAML.
        self.idx = 16

    def forward(self, input, query, inner_update_num=10):
        """Adapt on the support batch ``input`` and score ``query``.

        Args:
            input: support images. Labels are generated as
                ``arange(n_way).repeat(n_shot)``, so the batch is assumed to
                cycle through the classes in order - TODO confirm with the
                sampler.
            query: query images to classify after adaptation.
            inner_update_num: number of inner gradient steps.

        Returns:
            Classifier scores for ``query``.
        """
        fast_parameters = []
        for param in self.parameters():
            param.fast = None
            fast_parameters.append(param)
            # The noise is drawn but not used: the "+ noises[k]" term of the
            # update is disabled. The draw is kept so the global RNG stream
            # is consumed exactly as before.
            torch.zeros_like(param).normal_(0, self.noise_rate)

        # Support labels, created on the same device as the support images
        # instead of being forced onto CUDA.
        y_a_i = torch.arange(self.n_way, device=input.device).repeat(self.n_shot)

        for _ in range(inner_update_num):
            grad_support = self.run_inner_step(input, y_a_i, fast_parameters)
            fast_parameters = []
            for k, weight in enumerate(self.parameters()):
                if k == self.idx:  # remove this restriction for normal MAML
                    # "weight - lr * grad" builds a NEW tensor; it never
                    # modifies the meta-parameter in place.
                    base = weight if weight.fast is None else weight.fast
                    weight.fast = base - self.train_lr * grad_support[k]
                else:
                    weight.fast = weight  # not adapted: alias the parameter
                fast_parameters.append(weight.fast)

        query = self.cnn(query)
        scores = self.classifier(query)

        return scores

    def run_inner_step(self, input, label, fast_parameters):
        """Return detached gradients of the support loss w.r.t. ``fast_parameters``.

        ``create_graph=False`` plus the explicit ``detach`` make this a
        first-order approximation (no gradient of the gradient).
        """
        x = self.cnn(input)
        out = self.classifier(x)
        loss = F.cross_entropy(out, label)
        grad = torch.autograd.grad(loss, fast_parameters, create_graph=False, retain_graph=True)
        return [g.detach() for g in grad]
def forward(self, input, query, pcg, inner_update_num=2, train=False):
    """Adapt with PCG-modulated gradients on the support set, then score ``query``.

    Args:
        input: support images. Labels are generated as
            ``arange(n_way).repeat(n_shot)``, so the batch is assumed to cycle
            through the classes in order - TODO confirm with the sampler.
        query: query images to classify after adaptation.
        pcg: module called as ``pcg(pcg.context_params)``. It must return one
            modulation tensor per entry of ``self.idxs``, and expose
            ``reset()`` and ``context_params``.
        inner_update_num: the modulation loop runs ``2 * inner_update_num`` times.
        train: unused; kept for call-site compatibility.

    Returns:
        Classifier scores for ``query``.
    """
    fast_parameters = []
    noises = []
    for param in self.parameters():
        param.fast = None
        fast_parameters.append(param)
        noises.append(torch.zeros_like(param).normal_(0, self.noise_rate))

    # Support labels, created on the same device as the support images
    # instead of being forced onto CUDA.
    y_a_i = torch.arange(self.n_way, device=input.device).repeat(self.n_shot)

    # Support gradient w.r.t. the meta-parameters. It is computed once and
    # kept in the graph (create_graph=True, no detach) so the outer loss can
    # back-propagate through it.
    grad_support = self.run_inner_step(input, y_a_i, fast_parameters, create_graph=True, detach=False)

    pcg.reset()

    for _ in range(inner_update_num * 2):  # 2 forwards and backwards
        jj = 0
        precond = pcg(pcg.context_params)
        for k, weight in enumerate(self.parameters()):
            weight.fast = None

            if k in self.idxs:
                # Reshape the generated modulation to the parameter's shape
                # and scale the (noisy) support gradient element-wise.
                precond[jj] = precond[jj].view(-1).view(*weight.size())
                weight.fast = weight - self.train_lr * (grad_support[k] + noises[k]) * precond[jj]
                jj = jj + 1
            else:
                weight.fast = weight  # not modulated: alias the parameter

        # The context is replaced by the negative support-loss gradient
        # w.r.t. itself (not accumulated onto the previous value).
        grad_mask = self.run_inner_step(input, y_a_i, pcg.context_params, create_graph=True, detach=False)[0]
        pcg.context_params = -grad_mask

    query_f = self.cnn(query)
    scores = self.classifier(query_f)

    return scores
class PCG(nn.Module):
    """Generate gradient-modulation tensors from a small context vector.

    Three generators (``uu``, ``vv``, ``bb``) exist for each of two target
    layers (suffixes ``_3`` and ``_4``). Each maps ``context_params`` to
    ``num_mix`` mixture logits followed by ``num_mix`` candidate vectors.
    The weight modulation is the ReLU of the outer product of the mixed ``uu``
    and ``vv`` vectors; the bias modulation is the ReLU of the mixed ``bb``
    vector.
    """

    def __init__(self, num_filters=64, kernel_size=3, num_plastic=300, num_mix=5, device=None):
        """
        Args:
            num_filters: width of the bias modulation; with ``kernel_size``
                it sets the side of the square weight modulation.
            kernel_size: see ``num_filters``.
            num_plastic: length of the context vector.
            num_mix: number of mixture components per generator.
            device: device of ``context_params``. Defaults to CUDA when
                available, otherwise CPU. ``context_params`` is a plain
                tensor attribute, so ``.cuda()`` / ``.to()`` on the module
                does not move it.
        """
        super(PCG, self).__init__()
        self.num_filters = num_filters
        self.kernel_size = kernel_size
        self.num_plastic = num_plastic
        self.num_mix = num_mix

        uv_width = self.num_filters * self.kernel_size
        # Creation order is kept (uu, vv, bb for layer 3, then layer 4) so
        # the random initialisation consumes the RNG exactly as before.
        self.uu_3 = self._make_generator(uv_width)
        self.vv_3 = self._make_generator(uv_width)
        self.bb_3 = self._make_generator(self.num_filters)

        self.uu_4 = self._make_generator(uv_width)
        self.vv_4 = self._make_generator(uv_width)
        self.bb_4 = self._make_generator(self.num_filters)

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.context_params = torch.zeros(size=[self.num_plastic], requires_grad=True, device=device)

        # NOTE(review): this iterates *parameters*, so no isinstance branch in
        # init_layer matches and PyTorch's default initialisation is what is
        # actually used. Iterating self.modules() would apply init_layer but
        # would change the initialisation (and training results); left as-is
        # on purpose.
        for param in self.parameters():
            self.init_layer(param)

    def _make_generator(self, width):
        """Build Linear(num_plastic, width) -> ReLU -> Linear(width, num_mix + num_mix * width)."""
        return nn.Sequential(
            nn.Linear(self.num_plastic, width),
            nn.ReLU(),
            nn.Linear(width, self.num_mix + self.num_mix * width),
        )

    def reset(self):
        """Replace ``context_params`` with a fresh zero leaf that requires grad.

        A new zero tensor is allocated rather than multiplying by zero, so a
        NaN/inf context from a diverged episode is really cleared.
        """
        self.context_params = torch.zeros_like(self.context_params.detach()).requires_grad_(True)

    def init_layer(self, L):
        """Initialise module ``L`` by type; any other object is left untouched."""
        if isinstance(L, nn.Conv2d):
            n = L.kernel_size[0] * L.kernel_size[1] * L.out_channels
            L.weight.data.normal_(0, math.sqrt(2.0 / float(n)))
        elif isinstance(L, nn.BatchNorm2d):
            L.weight.data.fill_(1)
            L.bias.data.fill_(0)
        elif isinstance(L, nn.Linear):
            torch.nn.init.kaiming_uniform_(L.weight, nonlinearity='linear')

    def forward(self, context_params):
        """Return ``[conv3_uv, conv3_b, conv4_uv, conv4_b]`` for ``context_params``."""
        assemble = self.assemble_w_b if self.num_mix <= 1 else self.assemble_w_b_multi
        conv3_uv, conv3_b = assemble(self.uu_3, self.vv_3, self.bb_3, context_params)
        conv4_uv, conv4_b = assemble(self.uu_4, self.vv_4, self.bb_4, context_params)

        return [conv3_uv, conv3_b, conv4_uv, conv4_b]

    def assemble_w_b(self, uu_func, vv_func, bb_func, lat):
        """Single-component version of ``assemble_w_b_multi``.

        The generators always emit ``num_mix`` leading mixture logits. They
        are dropped here so the result has the same shape as the mixture
        path (and equals it when ``num_mix == 1``).
        """
        uu = uu_func(lat)[self.num_mix:]
        vv = vv_func(lat)[self.num_mix:]
        bb = bb_func(lat)[self.num_mix:]

        wu_ext = uu.unsqueeze(-1)                    # (n, 1)
        wv_ext_t = vv.unsqueeze(-1).transpose(0, 1)  # (1, n)

        conv_uv = torch.mm(wu_ext, wv_ext_t)         # outer product, (n, n)
        conv_b = bb

        return F.relu(conv_uv), F.relu(conv_b)

    def assemble_w_b_multi(self, uu_func, vv_func, bb_func, lat):
        """Mix ``num_mix`` candidates per generator and form the modulations.

        The first ``num_mix`` outputs of each generator are softmax-normalised
        into mixture weights. The remainder is viewed as ``(num_mix, -1)`` and
        averaged with those weights.

        Returns:
            ``(relu(outer(uu, vv)), relu(bb))``.
        """
        def mix(raw):
            # raw is 1-D here, so dim=0 is the dimension F.softmax would
            # pick implicitly; naming it avoids the deprecation warning.
            coeff = F.softmax(raw[:self.num_mix], dim=0)
            candidates = raw[self.num_mix:].view(self.num_mix, -1)
            return (candidates * coeff.unsqueeze(-1)).sum(0)

        uu = mix(uu_func(lat))
        vv = mix(vv_func(lat))
        bb = mix(bb_func(lat))

        wu_ext = uu.unsqueeze(-1)                    # (n, 1)
        wv_ext_t = vv.unsqueeze(-1).transpose(0, 1)  # (1, n)

        conv_uv = torch.mm(wu_ext, wv_ext_t)         # outer product, (n, n)
        conv_b = bb

        return F.relu(conv_uv), F.relu(conv_b)
class DistLinear(nn.Linear):  # used in MAML to forward input with fast weight
    """Bias-free linear layer on L2-normalised inputs and weight rows.

    Both the input rows and the weight rows are divided by their L2 norm
    (plus 1e-12) before the product. When ``weight.fast`` is set it is used
    instead of ``weight``.
    """

    def __init__(self, in_features, out_features):
        super(DistLinear, self).__init__(in_features, out_features, bias=False)
        self.weight.fast = None
        self.weight.data = self._unit_rows(self.weight.data)

    @staticmethod
    def _unit_rows(mat):
        """Divide every row of ``mat`` by its L2 norm (plus 1e-12 for stability)."""
        return mat.div(torch.norm(mat, p=2, dim=1, keepdim=True) + 1e-12)

    def forward(self, x):
        x_normalized = self._unit_rows(x)

        if self.weight.fast is not None:
            # The normalised fast weight is stored back on the parameter.
            self.weight.fast = self._unit_rows(self.weight.fast)
            return F.linear(x_normalized, self.weight.fast, bias=None)

        # No fast weight: re-normalise the stored weight in place, then
        # defer to nn.Linear.
        self.weight.data = self._unit_rows(self.weight.data)
        return super(DistLinear, self).forward(x_normalized)


class Conv2d_fw(nn.Conv2d):  # used in MAML to forward input with fast weight
    """``nn.Conv2d`` that convolves with ``weight.fast`` / ``bias.fast`` when set."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, groups=1):
        super(Conv2d_fw, self).__init__(in_channels, out_channels, kernel_size, stride=stride,
                                        padding=padding, bias=bias, groups=groups)
        self.weight.fast = None
        if self.bias is not None:
            self.bias.fast = None

    def forward(self, x):
        if self.bias is None:
            if self.weight.fast is not None:
                return F.conv2d(x, self.weight.fast, None, stride=self.stride,
                                padding=self.padding, groups=self.groups)
            return super(Conv2d_fw, self).forward(x)

        # With a bias, the fast path needs BOTH fast tensors; otherwise the
        # regular parameters are used.
        if self.weight.fast is not None and self.bias.fast is not None:
            return F.conv2d(x, self.weight.fast, self.bias.fast, stride=self.stride,
                            padding=self.padding, groups=self.groups)
        return super(Conv2d_fw, self).forward(x)


class BatchNorm2d_fw(nn.BatchNorm2d):  # used in MAML to forward input with fast weight
    """Batch norm that always uses the current batch's statistics.

    ``F.batch_norm`` is called with ``training=True`` and throw-away running
    buffers, so the module's own running statistics are never read or
    updated. ``weight.fast`` / ``bias.fast`` replace the affine parameters
    when both are set.
    """

    def __init__(self, num_features):
        super(BatchNorm2d_fw, self).__init__(num_features)
        self.weight.fast = None
        self.bias.fast = None

    def forward(self, x):
        channels = x.size(1)
        # Scratch buffers live on the input's device (previously a forced
        # .cuda(), which broke CPU use and non-default GPUs).
        running_mean = torch.zeros(channels, device=x.device)
        running_var = torch.ones(channels, device=x.device)

        use_fast = self.weight.fast is not None and self.bias.fast is not None
        weight = self.weight.fast if use_fast else self.weight
        bias = self.bias.fast if use_fast else self.bias
        return F.batch_norm(x, running_mean, running_var, weight, bias, training=True, momentum=1)
parser.add_argument('--train-way', type=int, default=5) 26 | parser.add_argument('--test-way', type=int, default=5) 27 | parser.add_argument('--inner-step', type=int, default=1) 28 | parser.add_argument('--noise-rate', type=float, default=0.0) 29 | parser.add_argument('--load-path', default='./results/pcg_maml/max-acc.pth') 30 | parser.add_argument('--load-path-pcg', default='./results/pcg_maml/max-acc-pcg.pth') 31 | parser.add_argument('--data-path', default='/scratch1/sim314/flush1/miniimagenet/ctm_images') 32 | parser.add_argument('--gpu', default='3') 33 | 34 | 35 | args = parser.parse_args() 36 | pprint(vars(args)) 37 | 38 | set_gpu(args.gpu) 39 | 40 | valset = MiniImageNet('test', args.data_path) 41 | val_sampler = CategoriesSampler(valset.label, 1000, 42 | args.test_way, args.shot + args.query) 43 | val_loader = DataLoader(dataset=valset, batch_sampler=val_sampler, 44 | num_workers=8, pin_memory=True) 45 | 46 | model = MAML_PCG(args.train_way, args.shot, noise_rate=args.noise_rate).cuda() 47 | model.load_state_dict(torch.load(args.load_path)) 48 | 49 | pcg = PCG(num_plastic=300).cuda() 50 | pcg.load_state_dict(torch.load(args.load_path_pcg)) 51 | 52 | 53 | trlog = {} 54 | trlog['args'] = vars(args) 55 | trlog['val_loss'] = [] 56 | trlog['val_acc'] = [] 57 | trlog['max_acc'] = 0.0 58 | 59 | timer = Timer() 60 | 61 | vl = Averager() 62 | va = Averager() 63 | 64 | for epoch in range(1, args.max_epoch + 1): 65 | 66 | for i, batch in enumerate(val_loader, 1): 67 | with torch.no_grad(): 68 | data, _ = [_.cuda() for _ in batch] 69 | p = args.shot * args.test_way 70 | data_shot, data_query = data[:p], data[p:] 71 | label = torch.arange(args.test_way).repeat(args.query) 72 | label = label.type(torch.cuda.LongTensor) 73 | 74 | logits = model(data_shot, data_query, pcg, inner_update_num=args.inner_step) 75 | loss = F.cross_entropy(logits, label) 76 | 77 | vl.add(loss.item()) 78 | 79 | acc = count_acc(logits, label) 80 | va.add(acc) 81 | 82 | vl.add(loss.item()) 83 | 
class CategoriesSampler():
    """Batch sampler that yields episodes of ``n_cls`` classes x ``n_per`` samples.

    Each yielded batch is a 1-D index tensor laid out sample-major: the first
    ``n_cls`` entries hold one sample of every chosen class, the next
    ``n_cls`` entries hold the second sample of each class, and so on.
    """

    def __init__(self, label, n_batch, n_cls, n_per):
        """
        Args:
            label: per-sample integer class ids (0-based).
            n_batch: number of episodes produced per iteration.
            n_cls: classes per episode.
            n_per: samples drawn per class.
        """
        self.n_batch = n_batch
        self.n_cls = n_cls
        self.n_per = n_per

        label = np.array(label)
        self.m_ind = []
        # Class ids run from 0 to max(label) inclusive. The previous
        # range(max(label)) silently dropped the last class.
        n_classes = int(label.max()) + 1 if label.size else 0
        for class_id in range(n_classes):
            ind = np.argwhere(label == class_id).reshape(-1)
            ind = torch.from_numpy(ind)
            # Classes with four or fewer samples are skipped.
            if len(ind) > 4:
                self.m_ind.append(ind)

    def __len__(self):
        return self.n_batch

    def __iter__(self):
        for _ in range(self.n_batch):
            batch = []
            classes = torch.randperm(len(self.m_ind))[:self.n_cls]
            for c in classes:
                members = self.m_ind[c]
                pos = torch.randperm(len(members))[:self.n_per]
                batch.append(members[pos])
            # (n_cls, n_per) -> (n_per, n_cls) -> flat: sample-major layout.
            yield torch.stack(batch).t().reshape(-1)
if __name__ == '__main__':
    # Evaluate a trained MAML_PCG model on the mini-ImageNet test split.
    parser = argparse.ArgumentParser()
    parser.add_argument('--max-epoch', type=int, default=1)
    parser.add_argument('--save-epoch', type=int, default=1000)
    parser.add_argument('--shot', type=int, default=5)
    parser.add_argument('--query', type=int, default=15)
    parser.add_argument('--train-way', type=int, default=5)
    parser.add_argument('--test-way', type=int, default=5)
    parser.add_argument('--inner-step', type=int, default=1)
    parser.add_argument('--noise-rate', type=float, default=0.0)
    parser.add_argument('--load-path', default='./results/pcg_maml/max-acc.pth')
    # Optional PCG checkpoint. When it is not given the PCG stays randomly
    # initialised, as in the original script.
    parser.add_argument('--load-path-pcg', default=None)
    # save_model() and the trlog dump below need a destination. The script
    # previously read args.save_path without declaring it. The default
    # differs from the training directory so that evaluation never
    # overwrites trained checkpoints.
    parser.add_argument('--save-path', default='./results/pcg_maml_test/')
    parser.add_argument('--data-path', default='yourdatapath')
    parser.add_argument('--gpu', default='3')

    args = parser.parse_args()
    pprint(vars(args))

    set_gpu(args.gpu)

    valset = MiniImageNet('test', args.data_path)
    val_sampler = CategoriesSampler(valset.label, 1000,
                                    args.test_way, args.shot + args.query)
    val_loader = DataLoader(dataset=valset, batch_sampler=val_sampler,
                            num_workers=8, pin_memory=True)

    model = MAML_PCG(args.train_way, args.shot, noise_rate=args.noise_rate).cuda()
    model.load_state_dict(torch.load(args.load_path))

    pcg = PCG(num_plastic=300).cuda()
    if args.load_path_pcg is not None:
        pcg.load_state_dict(torch.load(args.load_path_pcg))

    # The optimizers are never stepped in this script; they are kept so the
    # setup mirrors the original.
    optimizer = torch.optim.Adam(list(model.parameters()), lr=0.001, amsgrad=False)
    optimizer_pcg = torch.optim.Adam(list(pcg.parameters()), lr=0.001, amsgrad=False)

    def save_model(name):
        """Save the model and PCG weights as <name>.pth / <name>-pcg.pth."""
        os.makedirs(args.save_path, exist_ok=True)
        torch.save(model.state_dict(), osp.join(args.save_path, name + '.pth'))
        torch.save(pcg.state_dict(), osp.join(args.save_path, name + '-pcg.pth'))

    trlog = {}
    trlog['args'] = vars(args)
    trlog['val_loss'] = []
    trlog['val_acc'] = []
    trlog['max_acc'] = 0.0

    timer = Timer()

    for epoch in range(1, args.max_epoch + 1):
        # Fresh averagers every epoch (they were never created before, so
        # the first vl.add raised NameError).
        vl_avg = Averager()
        va_avg = Averager()

        for i, batch in enumerate(val_loader, 1):
            data, _ = [_.cuda() for _ in batch]
            p = args.shot * args.test_way
            data_shot, data_query = data[:p], data[p:]
            label = torch.arange(args.test_way).repeat(args.query)
            label = label.type(torch.cuda.LongTensor)

            # No torch.no_grad() here: the model's inner loop calls
            # torch.autograd.grad during the forward pass.
            logits = model(data_shot, data_query, pcg, inner_update_num=args.inner_step)
            loss = F.cross_entropy(logits, label)
            acc = count_acc(logits, label)

            # Each value is recorded once (it used to be added twice).
            vl_avg.add(loss.item())
            va_avg.add(acc)
            pcg.reset()

        vl = vl_avg.item()
        va = va_avg.item()
        print('epoch {}, val, loss={:.4f} acc={:.4f} maxacc={:.4f}'.format(epoch, vl, va, trlog['max_acc']))

        if va > trlog['max_acc']:
            trlog['max_acc'] = va
            save_model('max-acc')

        trlog['val_loss'].append(vl)
        trlog['val_acc'].append(va)

        os.makedirs(args.save_path, exist_ok=True)
        torch.save(trlog, osp.join(args.save_path, 'trlog'))

        save_model('epoch-last')

        if epoch % args.save_epoch == 0:
            save_model('epoch-{}'.format(epoch))

        print('ETA:{}/{}'.format(timer.measure(), timer.measure(epoch / args.max_epoch)))
if __name__ == '__main__':
    # Meta-train MAML_PCG together with its PCG module on mini-ImageNet.
    parser = argparse.ArgumentParser()
    parser.add_argument('--max-epoch', type=int, default=500)
    parser.add_argument('--save-epoch', type=int, default=1000)
    parser.add_argument('--shot', type=int, default=5)
    parser.add_argument('--query', type=int, default=15)
    parser.add_argument('--train-way', type=int, default=5)
    parser.add_argument('--test-way', type=int, default=5)
    parser.add_argument('--inner-step', type=int, default=1)
    parser.add_argument('--noise-rate', type=float, default=0.0)
    parser.add_argument('--save-path', default='./results/pcg_maml/')
    parser.add_argument('--data-path', default='yourdatapath')
    parser.add_argument('--gpu', default='1')

    args = parser.parse_args()
    pprint(vars(args))

    set_gpu(args.gpu)

    trainset = MiniImageNet('train', args.data_path)
    train_sampler = CategoriesSampler(trainset.label, 100,
                                      args.train_way, args.shot + args.query)
    train_loader = DataLoader(dataset=trainset, batch_sampler=train_sampler,
                              num_workers=8, pin_memory=True)

    valset = MiniImageNet('val', args.data_path)
    val_sampler = CategoriesSampler(valset.label, 400,
                                    args.test_way, args.shot + args.query)
    val_loader = DataLoader(dataset=valset, batch_sampler=val_sampler,
                            num_workers=8, pin_memory=True)

    model = MAML_PCG(args.train_way, args.shot, noise_rate=args.noise_rate).cuda()

    pcg = PCG(num_plastic=300).cuda()

    task_num = 3                          # episodes accumulated per optimizer step
    lr_adjust_base = [200, 400]           # epochs at which the model lr is halved
    lr_adjust_pcg = [80, 160, 240, 320]   # epochs at which the PCG lr is halved

    optimizer = torch.optim.Adam(list(model.parameters()), lr=0.001, amsgrad=False)
    optimizer_pcg = torch.optim.Adam(list(pcg.parameters()), lr=0.001, amsgrad=False)

    def save_model(name):
        """Save the model and PCG weights as <name>.pth / <name>-pcg.pth."""
        # makedirs: the default path is nested, which os.mkdir cannot create.
        os.makedirs(args.save_path, exist_ok=True)
        torch.save(model.state_dict(), osp.join(args.save_path, name + '.pth'))
        torch.save(pcg.state_dict(), osp.join(args.save_path, name + '-pcg.pth'))

    trlog = {}
    trlog['args'] = vars(args)
    trlog['train_loss'] = []
    trlog['val_loss'] = []
    trlog['train_acc'] = []
    trlog['val_acc'] = []
    trlog['max_acc'] = 0.0

    timer = Timer()

    for epoch in range(1, args.max_epoch + 1):

        if epoch in lr_adjust_base:
            for param_group in optimizer.param_groups:
                param_group['lr'] = param_group['lr'] * 0.5

        if epoch in lr_adjust_pcg:
            for param_group in optimizer_pcg.param_groups:
                param_group['lr'] = param_group['lr'] * 0.5

        model.train()
        pcg.train()

        tl = Averager()
        ta = Averager()
        loss_all = []

        for i, batch in enumerate(train_loader, start=1):
            data, _ = [_.cuda() for _ in batch]
            p = args.shot * args.train_way
            data_shot, data_query = data[:p], data[p:]
            label = torch.arange(args.train_way).repeat(args.query)
            label = label.type(torch.cuda.LongTensor)

            logits = model(data_shot, data_query, pcg, inner_update_num=args.inner_step, train=True)
            loss = F.cross_entropy(logits, label)
            loss_all.append(loss)

            # One optimizer step per task_num episodes.
            # NOTE(review): episodes left over at the end of an epoch
            # (100 % 3 == 1 with the defaults) are dropped without a step.
            # Left unchanged to keep the training dynamics.
            if i % task_num == 0:
                total_loss = torch.stack(loss_all).sum(0)
                optimizer.zero_grad()
                optimizer_pcg.zero_grad()
                total_loss.backward()
                optimizer.step()
                optimizer_pcg.step()
                loss_all = []

            pcg.reset()
            tl.add(loss.item())
            acc = count_acc(logits, label)
            ta.add(acc)

        print('epoch {} acc={:.4f}'.format(epoch, ta.item()))
        # Validate (and checkpoint) only every 30th epoch until epoch 400,
        # then every epoch.
        if epoch < 400 and epoch % 30 != 0:
            continue

        vl = Averager()
        va = Averager()

        for i, batch in enumerate(val_loader, 1):
            data, _ = [_.cuda() for _ in batch]
            p = args.shot * args.test_way
            data_shot, data_query = data[:p], data[p:]
            label = torch.arange(args.test_way).repeat(args.query)
            label = label.type(torch.cuda.LongTensor)

            # No torch.no_grad() here: the model's inner loop calls
            # torch.autograd.grad during the forward pass.
            logits = model(data_shot, data_query, pcg, inner_update_num=args.inner_step)
            loss = F.cross_entropy(logits, label)
            acc = count_acc(logits, label)

            # Validation numbers go to the validation averagers only. They
            # used to be added to tl/ta too, contaminating the logged
            # training statistics.
            vl.add(loss.item())
            va.add(acc)
            pcg.reset()

        vl = vl.item()
        va = va.item()
        print('epoch {}, val, loss={:.4f} acc={:.4f} maxacc={:.4f}'.format(epoch, vl, va, trlog['max_acc']))

        if va > trlog['max_acc']:
            trlog['max_acc'] = va
            save_model('max-acc')

        # Log plain numbers rather than the Averager objects themselves.
        trlog['train_loss'].append(tl.item())
        trlog['train_acc'].append(ta.item())
        trlog['val_loss'].append(vl)
        trlog['val_acc'].append(va)

        os.makedirs(args.save_path, exist_ok=True)
        torch.save(trlog, osp.join(args.save_path, 'trlog'))

        save_model('epoch-last')

        if epoch % args.save_epoch == 0:
            save_model('epoch-{}'.format(epoch))

        print('ETA:{}/{}'.format(timer.measure(), timer.measure(epoch / args.max_epoch)))
std=0.05): 23 | super(GaussianNoise, self).__init__() 24 | self.shape = (batch_size,) + input_shape 25 | self.noise = Variable(torch.zeros(self.shape).cuda()) 26 | self.std = std 27 | 28 | def forward(self, x, std=0.15): 29 | noise = Variable(torch.zeros(x.shape).cuda()) 30 | noise = noise.data.normal_(0, std=std) 31 | return x + noise 32 | 33 | 34 | def set_seed(seed, cudnn=True): 35 | """ 36 | Seed everything we can! 37 | Note that gym environments might need additional seeding (env.seed(seed)), 38 | and num_workers needs to be set to 1. 39 | """ 40 | random.seed(seed) 41 | np.random.seed(seed) 42 | torch.manual_seed(seed) 43 | torch.random.manual_seed(seed) 44 | torch.cuda.manual_seed(seed) 45 | # note: the below slows down the code but makes it reproducible 46 | if (seed is not None) and cudnn: 47 | torch.backends.cudnn.deterministic = True 48 | 49 | 50 | def set_gpu(x): 51 | os.environ['CUDA_VISIBLE_DEVICES'] = x 52 | print('using gpu:', x) 53 | 54 | 55 | def clone(tensor): 56 | """Detach and clone a tensor including the ``requires_grad`` attribute. 57 | 58 | Arguments: 59 | tensor (torch.Tensor): tensor to clone. 60 | """ 61 | cloned = tensor.clone()#tensor.detach().clone() 62 | # cloned.requires_grad = tensor.requires_grad 63 | # if tensor.grad is not None: 64 | # cloned.grad = clone(tensor.grad) 65 | return cloned 66 | 67 | def clone_state_dict(state_dict): 68 | """Clone a state_dict. If state_dict is from a ``torch.nn.Module``, use ``keep_vars=True``. 69 | 70 | Arguments: 71 | state_dict (OrderedDict): the state_dict to clone. Assumes state_dict is not detached from model state. 72 | """ 73 | return OrderedDict([(name, clone(param)) for name, param in state_dict.items()]) 74 | 75 | def ensure_path(path): 76 | if os.path.exists(path): 77 | if input('{} exists, remove? 
([y]/n)'.format(path)) != 'n': 78 | shutil.rmtree(path) 79 | os.mkdir(path) 80 | else: 81 | os.mkdir(path) 82 | 83 | 84 | class Averager(): 85 | 86 | def __init__(self): 87 | self.n = 0 88 | self.v = 0 89 | 90 | def add(self, x): 91 | self.v = (self.v * self.n + x) / (self.n + 1) 92 | self.n += 1 93 | 94 | def item(self): 95 | return self.v 96 | 97 | 98 | def count_acc(logits, label): 99 | pred = torch.argmax(logits, dim=1) 100 | return (pred == label).type(torch.cuda.FloatTensor).mean().item() 101 | 102 | 103 | def dot_metric(a, b): 104 | return torch.mm(a, b.t()) 105 | 106 | 107 | def count_accuracy(logits, label): 108 | pred = torch.argmax(logits, dim=1).view(-1) 109 | label = label.view(-1) 110 | accuracy = 100 * pred.eq(label).float().mean() 111 | return accuracy 112 | 113 | def euclidean_metric(a, b): 114 | n = a.shape[0] 115 | m = b.shape[0] 116 | a = a.unsqueeze(1).expand(n, m, -1) 117 | b = b.unsqueeze(0).expand(n, m, -1) 118 | #logits = -((a - b)**2).sum(dim=2) 119 | logits = -((a - b)**2).sum(dim=2) 120 | return logits 121 | 122 | 123 | class Timer(): 124 | 125 | def __init__(self): 126 | self.o = time.time() 127 | 128 | def measure(self, p=1): 129 | x = (time.time() - self.o) / p 130 | x = int(x) 131 | if x >= 3600: 132 | return '{:.1f}h'.format(x / 3600) 133 | if x >= 60: 134 | return '{}m'.format(round(x / 60)) 135 | return '{}s'.format(x) 136 | 137 | _utils_pp = pprint.PrettyPrinter() 138 | def pprint(x): 139 | _utils_pp.pprint(x) 140 | 141 | 142 | def l2_loss(pred, label): 143 | return ((pred - label)**2).sum() / len(pred) / 2 144 | 145 | def set_protocol(data_path, protocol, test_protocol): 146 | train = [] 147 | val = [] 148 | 149 | all_set = ['shn', 'hon', 'clv', 'clk', 'gls', 'scl', 'sci', 'nat', 'shx', 'rel'] 150 | 151 | if protocol == 'p1': 152 | for i in range(3): 153 | train.append(data_path + '/crops_' + all_set[i]) 154 | elif protocol == 'p2': 155 | for i in range(3, 6): 156 | train.append(data_path + '/crops_' + all_set[i]) 157 | elif 
protocol == 'p3': 158 | for i in range(6, 8): 159 | train.append(data_path + '/crops_' + all_set[i]) 160 | elif protocol == 'p4': 161 | for i in range(8, 10): 162 | train.append(data_path + '/crops_' + all_set[i]) 163 | 164 | if test_protocol == 'p1': 165 | for i in range(3): 166 | val.append(data_path + '/crops_' + all_set[i]) 167 | elif test_protocol == 'p2': 168 | for i in range(3, 6): 169 | val.append(data_path + '/crops_' + all_set[i]) 170 | elif test_protocol == 'p3': 171 | for i in range(6, 8): 172 | val.append(data_path + '/crops_' + all_set[i]) 173 | elif test_protocol == 'p4': 174 | for i in range(8, 10): 175 | val.append(data_path + '/crops_' + all_set[i]) 176 | 177 | 178 | 179 | return train, val 180 | 181 | 182 | 183 | 184 | def independent_ttest(data1, data2, alpha): 185 | # calculate means 186 | mean1, mean2 = mean(data1), mean(data2) 187 | # calculate standard errors 188 | se1, se2 = sem(data1), sem(data2) 189 | # standard error on the difference between the samples 190 | sed = sqrt(se1**2.0 + se2**2.0) 191 | # calculate the t statistic 192 | t_stat = (mean1 - mean2) / sed 193 | # degrees of freedom 194 | df = len(data1) + len(data2) - 2 195 | # calculate the critical value 196 | cv = t.ppf(1.0 - alpha, df) 197 | # calculate the p-value 198 | p = (1.0 - t.cdf(abs(t_stat), df)) * 2.0 199 | # return everything 200 | return t_stat, df, cv, p 201 | 202 | 203 | def perturb(data): 204 | 205 | randno = np.random.randint(0, 5) 206 | if randno == 1: 207 | return torch.cat((data, data.flip(3)), dim=0) 208 | elif randno == 2: #180 209 | return torch.cat((data, data.flip(2)), dim=0) 210 | elif randno == 3: #90 211 | return torch.cat((data, data.transpose(2,3)), dim=0) 212 | else: 213 | return torch.cat((data, data.transpose(2, 3).flip(3)), dim=0) -------------------------------------------------------------------------------- /comparison_method.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/chrysts/generative_preconditioner/724deb49abbd45f06d3fb3a003dca07ad9d84241/comparison_method.png --------------------------------------------------------------------------------