├── README.md
├── classification
├── README.md
├── convnet.py
├── materials
│ ├── test.csv
│ ├── train.csv
│ └── val.csv
├── miniimagenet.py
├── model
│ ├── maml.py
│ ├── maml_pcg.py
│ └── pcg_module.py
├── modified_pytorchmodule.py
├── run_test_modgrad.py
├── samplers.py
├── test_modgrad.py
├── train_modgrad.py
└── utils.py
└── comparison_method.png
/README.md:
--------------------------------------------------------------------------------
1 | # On Modulating the Gradient for Meta-Learning
2 |
3 | The repository contains the code for:
4 |
5 | [On Modulating the Gradient for Meta-Learning](http://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123530545.pdf)
6 |
7 | [Supp. Material](https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123530545-supp.pdf)
8 |
9 | ECCV 2020
10 |
11 | Comparison with prior methods: MAML, Meta-SGD, and Ours.
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 | Please refer to each folder for different tasks: classification, regression, and RL.
20 |
21 |
22 | ## Citation
23 |
24 | ````
25 | @inproceedings{Christian2020ModGrad,
26 | author = {Simon, Christian and Koniusz, Piotr and Nock, Richard and Harandi, Mehrtash},
27 | title = {On Modulating the Gradient for Meta-Learning},
28 | booktitle = {The European Conference on Computer Vision},
29 | year = {2020}
30 | }
31 | ````
32 |
33 |
34 | ## Acknowledgement
35 | Thanks to the authors of CAVIA for their code:
36 | the regression and RL tasks are adapted from https://github.com/lmzintgraf/cavia
37 |
38 |
--------------------------------------------------------------------------------
/classification/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | To download mini-ImageNet:
4 | use the Google Drive file [here](https://drive.google.com/file/d/1HkgrkAwukzEZA0TpO7010PkAOREb2Nuk/view) to directly
5 | download the `mini-imagenet.zip` file. This mini-ImageNet split follows https://github.com/Clarifai/few-shot-ctm.
6 |
7 | Set `--data-path` to the folder where the data is stored.
8 |
9 | Please run train_modgrad.py for training and run_test_modgrad.py for testing.
10 |
--------------------------------------------------------------------------------
/classification/convnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from modified_pytorchmodule import Conv2d_fw, Linear_fw, BatchNorm2d_fw
3 |
def conv_block(in_channels, out_channels):
    """Return a plain conv stage: 3x3 conv (padding 1), BatchNorm, ReLU, 2x2 max-pool.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
    """
    stages = [
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2),
    ]
    return nn.Sequential(*stages)
11 |
def conv_block_fast(in_channels, out_channels):
    """Return a conv stage whose conv and BatchNorm layers honour MAML fast weights.

    Layer order matches ``conv_block``; the conv and BN layers are the
    ``Conv2d_fw`` / ``BatchNorm2d_fw`` variants, which read ``.fast`` tensors.
    """
    stages = [
        Conv2d_fw(in_channels, out_channels, 3, padding=1),
        BatchNorm2d_fw(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2),
    ]
    return nn.Sequential(*stages)
19 |
class ConvNet_MAML(nn.Module):
    """Four-stage, fast-weight-aware conv encoder used as the MAML backbone.

    Args:
        x_dim: number of input channels.
        hid_dim: output channels of the first three stages.
        z_dim: output channels of the last stage.

    ``out_channels`` is the length of the flattened feature vector that
    ``forward`` returns.
    """

    def __init__(self, x_dim=3, hid_dim=64, z_dim=64):
        super().__init__()
        self.encoder = nn.Sequential(
            conv_block_fast(x_dim, hid_dim),
            conv_block_fast(hid_dim, hid_dim),
            conv_block_fast(hid_dim, hid_dim),
            conv_block_fast(hid_dim, z_dim),
        )
        # Four 2x max-pools shrink an 84x84 input to 5x5 (84->42->21->10->5),
        # so the flattened size is z_dim * 5 * 5 (1600 for the default z_dim).
        # Derived from z_dim instead of hard-coding 1600 so other widths work.
        # NOTE(review): still assumes 84x84 inputs, the size the MiniImageNet
        # loader in this repo resizes to.
        self.out_channels = z_dim * 5 * 5

    def forward(self, x):
        """Encode ``x`` and flatten all but the batch dimension."""
        x = self.encoder(x)
        return x.view(x.size(0), -1)
37 |
38 |
39 |
40 |
41 |
--------------------------------------------------------------------------------
/classification/miniimagenet.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | from PIL import Image
3 |
4 | from torch.utils.data import Dataset
5 | from torchvision import transforms
6 | import numpy as np
7 | ROOT_PATH = './materials/'
8 |
9 |
10 |
class MiniImageNet(Dataset):
    """mini-ImageNet split indexed by ``ROOT_PATH/<setname>.csv``.

    The CSV's first line is skipped as a header. Each remaining non-empty
    line is ``<image file name>,<wnid>``. Integer labels are assigned to
    wnids in order of first appearance.

    Args:
        setname: split name, used as the CSV file stem (e.g. ``'train'``).
        img_path: dataset root; images are read from ``img_path/images/``.
    """

    def __init__(self, setname, img_path):
        csv_path = osp.join(ROOT_PATH, setname + '.csv')
        # Context manager so the index file is closed deterministically.
        with open(csv_path, 'r') as f:
            lines = [x.strip() for x in f.readlines()][1:]

        self.wnids = []
        label_of = {}  # wnid -> label; O(1) lookup instead of scanning self.wnids
        self.data = []
        self.label = []
        for line in lines:
            if not line:
                # Skip blank lines (e.g. a trailing empty line) rather than
                # failing on the unpacking below.
                continue
            name, wnid = line.split(',')
            if wnid not in label_of:
                label_of[wnid] = len(self.wnids)
                self.wnids.append(wnid)
            self.data.append(osp.join(img_path, 'images', name))
            self.label.append(label_of[wnid])

        # Every split uses the same pipeline (no augmentation): tensor
        # conversion followed by fixed per-channel mean/std normalisation.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        """Return ``(image_tensor, label)`` for item ``i`` (84x84 RGB before the transform)."""
        path, label = self.data[i], self.label[i]
        # Close the file once a decoded, resized copy exists.
        with Image.open(path) as raw:
            img = raw.convert('RGB').resize((84, 84))
        image = self.transform(img)
        return image, label
59 |
60 |
61 |
--------------------------------------------------------------------------------
/classification/model/maml.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from convnet import ConvNet_MAML, Linear_fw
4 | from torch.autograd import Variable
5 | import numpy as np
6 | import torch.nn.functional as F
7 | from networks.modified_pytorchmodule import DistLinear
8 |
9 |
class MAML(nn.Module):
    """First-order MAML baseline that adapts a single parameter tensor.

    Args:
        n_way: classes per episode (classifier output size).
        n_shot: support examples per class.
        train_lr: inner-loop step size.
        noise_rate: stored for interface compatibility; gradient noise is not
            applied in this model's inner loop.
    """

    def __init__(self, n_way, n_shot, train_lr=0.1, noise_rate=0.):
        super().__init__()
        self.cnn = ConvNet_MAML()
        self.classifier = Linear_fw(self.cnn.out_channels, n_way)
        # Alternative cosine classifier: DistLinear(self.cnn.out_channels, n_way)
        self.train_lr = train_lr
        self.n_way = n_way
        self.n_shot = n_shot
        self.noise_rate = noise_rate
        # Position (in self.parameters() order) of the only tensor adapted in
        # the inner loop. With the default ConvNet_MAML (4 stages x conv w/b +
        # BN w/b = 16 tensors registered first) this is the classifier weight.
        self.idx = 16

    def forward(self, input, query, inner_update_num=10):
        """Adapt on the support set, then return logits for ``query``.

        Args:
            input: support images; labels are assumed to follow
                ``arange(n_way).repeat(n_shot)`` (class-interleaved order).
            query: query images.
            inner_update_num: number of first-order inner gradient steps.

        Returns:
            Classifier scores for ``query`` computed with the adapted weights.
        """
        fast_parameters = []
        for param in self.parameters():
            param.fast = None
            fast_parameters.append(param)

        # Build support labels on the input's device instead of forcing CUDA.
        y_a_i = torch.arange(self.n_way, device=input.device).repeat(self.n_shot)

        # Adaptation needs autograd even if the caller runs under no_grad().
        with torch.enable_grad():
            for _ in range(inner_update_num):
                grad_support = self.run_inner_step(input, y_a_i, fast_parameters)
                fast_parameters = []
                for k, weight in enumerate(self.parameters()):
                    if k == self.idx:  #### REMOVE THIS FOR NORMAL MAML
                        base = weight if weight.fast is None else weight.fast
                        # Rebinding (not in-place) creates a new fast tensor.
                        weight.fast = base - self.train_lr * grad_support[k]
                    else:
                        weight.fast = weight  #### REMOVE THIS FOR NORMAL MAML
                    fast_parameters.append(weight.fast)

        query = self.cnn(query)
        scores = self.classifier(query)
        return scores

    def run_inner_step(self, input, label, fast_parameters):
        """Return detached gradients of the support cross-entropy w.r.t. ``fast_parameters``."""
        x = self.cnn(input)
        out = self.classifier(x)
        loss = F.cross_entropy(out, label)
        # First-order: no graph through the gradient itself.
        grad = torch.autograd.grad(loss, fast_parameters, create_graph=False, retain_graph=True)
        grad = [g.detach() for g in grad]
        return grad
--------------------------------------------------------------------------------
/classification/model/maml_pcg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from convnet import ConvNet_MAML, Linear_fw
4 | from torch.autograd import Variable
5 | import torch.nn.functional as F
6 |
7 |
class MAML_PCG(nn.Module):
    """MAML whose inner-loop gradient is scaled elementwise by a PCG module.

    Only the parameter tensors listed in ``self.idxs`` are adapted. Their
    support-set gradient is multiplied by preconditioners produced by the
    ``pcg`` module passed to ``forward``.

    Args:
        n_way: classes per episode (classifier output size).
        n_shot: support examples per class.
        train_lr: inner-loop step size.
        noise_rate: std of Gaussian noise added to the support gradient.
    """

    def __init__(self, n_way, n_shot, train_lr=0.1, noise_rate=0.):
        super().__init__()
        self.cnn = ConvNet_MAML()
        self.classifier = Linear_fw(self.cnn.out_channels, n_way)
        self.train_lr = train_lr
        self.n_way = n_way
        self.n_shot = n_shot
        # Positions in self.parameters() order of the adapted tensors. Each
        # conv_block_fast contributes conv w, conv b, BN w, BN b, so with the
        # default ConvNet_MAML these are the 3rd and 4th conv weights/biases.
        self.idxs = [8, 9, 12, 13]
        self.noise_rate = noise_rate

    def forward(self, input, query, pcg, inner_update_num=2, train=False):
        """Adapt with PCG-preconditioned gradients and return logits for ``query``.

        Args:
            input: support images; labels are assumed to follow
                ``arange(n_way).repeat(n_shot)``.
            query: query images.
            pcg: module mapping ``pcg.context_params`` to a list of
                preconditioners, one per entry of ``self.idxs``.
            inner_update_num: the context vector is updated
                ``2 * inner_update_num`` times.
            train: unused; kept for call-site compatibility.
        """
        fast_parameters = []
        noises = []
        for param in self.parameters():
            param.fast = None
            fast_parameters.append(param)
            noises.append(torch.zeros_like(param).normal_(0, self.noise_rate))

        y_a_i = torch.arange(self.n_way, device=input.device).repeat(self.n_shot)

        # The inner loop calls torch.autograd.grad, so it must record a graph
        # even when the caller is inside torch.no_grad() (as the validation and
        # test loops are); otherwise autograd.grad raises.
        with torch.enable_grad():
            # Support gradient of the un-adapted weights, kept differentiable
            # so the outer loss can backpropagate through it.
            grad_support = self.run_inner_step(input, y_a_i, fast_parameters,
                                               create_graph=True, detach=False)

            pcg.reset()

            for _ in range(inner_update_num * 2):  # 2 forwards and backwards
                precond = pcg(pcg.context_params)
                jj = 0
                for k, weight in enumerate(self.parameters()):
                    if k in self.idxs:
                        # Reshape the flat preconditioner to the weight's shape.
                        precond[jj] = precond[jj].view(-1).view(*weight.size())
                        weight.fast = weight - self.train_lr * (grad_support[k] + noises[k]) * precond[jj]
                        jj += 1
                    else:
                        weight.fast = weight

                # Gradient of the support loss w.r.t. the context vector; the
                # context moves one unit step against it.
                grad_mask = self.run_inner_step(input, y_a_i, pcg.context_params,
                                                create_graph=True, detach=False)[0]
                pcg.context_params = -grad_mask

        query_f = self.cnn(query)
        scores = self.classifier(query_f)
        return scores

    def run_inner_step(self, input, label, parameters, create_graph=False, detach=True):
        """Return gradients of the support cross-entropy w.r.t. ``parameters``.

        ``create_graph`` keeps the gradient differentiable; ``detach`` strips
        the returned gradients from the graph.
        """
        x = self.cnn(input)
        out = self.classifier(x)
        loss = F.cross_entropy(out, label)
        grad = torch.autograd.grad(loss, parameters, create_graph=create_graph, retain_graph=True)
        if detach:
            grad = [g.detach() for g in grad]
        return grad
--------------------------------------------------------------------------------
/classification/model/pcg_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 | from torch.autograd import Variable
6 |
7 |
class PCG(nn.Module):
    """Generate elementwise gradient preconditioners from a context vector.

    For each of two conv layers ("3" and "4"), three MLP heads map the
    ``num_plastic``-dim context to factors ``u``, ``v`` and a bias term ``b``.
    Each head emits ``num_mix`` leading mixture logits followed by
    ``num_mix`` stacked components. The weight preconditioner is
    ``relu(u v^T)`` with shape ``(num_filters*kernel_size,) * 2``, and the
    bias preconditioner is ``relu(b)`` with shape ``(num_filters,)``.
    """

    def __init__(self, num_filters=64, kernel_size=3, num_plastic=300, num_mix=5):
        super(PCG, self).__init__()
        self.num_filters = num_filters
        self.kernel_size = kernel_size
        self.num_plastic = num_plastic
        self.num_mix = num_mix

        self.uu_3 = nn.Sequential(nn.Linear(self.num_plastic, self.num_filters * self.kernel_size),
                                  nn.ReLU(),
                                  nn.Linear(self.num_filters * self.kernel_size,
                                            self.num_mix + self.num_mix * self.num_filters * self.kernel_size)
                                  )
        self.vv_3 = nn.Sequential(nn.Linear(self.num_plastic, self.num_filters * self.kernel_size),
                                  nn.ReLU(),
                                  nn.Linear(self.num_filters * self.kernel_size,
                                            self.num_mix + self.num_mix * self.num_filters * self.kernel_size)
                                  )
        self.bb_3 = nn.Sequential(nn.Linear(self.num_plastic, self.num_filters),
                                  nn.ReLU(),
                                  nn.Linear(self.num_filters, self.num_mix + self.num_mix * self.num_filters)
                                  )

        self.uu_4 = nn.Sequential(nn.Linear(self.num_plastic, self.num_filters * self.kernel_size),
                                  nn.ReLU(),
                                  nn.Linear(self.num_filters * self.kernel_size,
                                            self.num_mix + self.num_mix * self.num_filters * self.kernel_size)
                                  )
        self.vv_4 = nn.Sequential(nn.Linear(self.num_plastic, self.num_filters * self.kernel_size),
                                  nn.ReLU(),
                                  nn.Linear(self.num_filters * self.kernel_size,
                                            self.num_mix + self.num_mix * self.num_filters * self.kernel_size)
                                  )
        self.bb_4 = nn.Sequential(nn.Linear(self.num_plastic, self.num_filters),
                                  nn.ReLU(),
                                  nn.Linear(self.num_filters, self.num_mix + self.num_mix * self.num_filters)
                                  )

        # Context vector adapted in the inner loop. Placed on CUDA when it is
        # available (as before), otherwise on CPU instead of crashing.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.context_params = torch.zeros(size=[self.num_plastic], requires_grad=True, device=device)

        # NOTE(review): self.parameters() yields tensors, never modules, so no
        # isinstance branch in init_layer matches and this loop leaves the
        # PyTorch default initialisation in place. Iterating self.modules()
        # would apply the intended init but would change training behaviour,
        # so the original call is kept.
        for param in self.parameters():
            self.init_layer(param)

    def reset(self):
        """Reset ``context_params`` to a zero leaf tensor that requires grad.

        The tensor is created on the device of this module's parameters, so a
        module moved with ``.cuda()`` / ``.to()`` gets a matching context.
        """
        device = next(self.parameters()).device
        self.context_params = torch.zeros(self.num_plastic, device=device, requires_grad=True)

    def init_layer(self, L):
        """Initialise a conv / batch-norm / linear module in place.

        Conv2d weights are drawn from normal(0, sqrt(2 / (kh * kw * out_channels))).
        BatchNorm2d gets weight 1 and bias 0. Linear weights use Kaiming-uniform
        (nonlinearity='linear'). Any other argument is left untouched.
        """
        if isinstance(L, nn.Conv2d):
            n = L.kernel_size[0] * L.kernel_size[1] * L.out_channels
            L.weight.data.normal_(0, math.sqrt(2.0 / float(n)))
        elif isinstance(L, nn.BatchNorm2d):
            L.weight.data.fill_(1)
            L.bias.data.fill_(0)
        elif isinstance(L, nn.Linear):
            torch.nn.init.kaiming_uniform_(L.weight, nonlinearity='linear')

    def forward(self, context_params):
        """Return ``[conv3_uv, conv3_b, conv4_uv, conv4_b]`` preconditioners for a 1-D context."""
        if self.num_mix <= 1:
            conv3_uv, conv3_b = self.assemble_w_b(self.uu_3, self.vv_3, self.bb_3, context_params)
            conv4_uv, conv4_b = self.assemble_w_b(self.uu_4, self.vv_4, self.bb_4, context_params)
        else:
            conv3_uv, conv3_b = self.assemble_w_b_multi(self.uu_3, self.vv_3, self.bb_3, context_params)
            conv4_uv, conv4_b = self.assemble_w_b_multi(self.uu_4, self.vv_4, self.bb_4, context_params)

        return [conv3_uv, conv3_b, conv4_uv, conv4_b]

    def assemble_w_b(self, uu_func, vv_func, bb_func, lat):
        """Single-component variant (num_mix <= 1): ``relu(u v^T)`` and ``relu(b)``."""
        # Drop the leading num_mix mixture logits each head emits, so u, v and
        # b have the same lengths as in assemble_w_b_multi and u v^T can be
        # reshaped into the conv weight.
        uu = uu_func(lat)[self.num_mix:]
        vv = vv_func(lat)[self.num_mix:]
        bb = bb_func(lat)[self.num_mix:]

        conv_uv = torch.mm(uu.unsqueeze(-1), vv.unsqueeze(-1).transpose(0, 1))
        conv_b = bb
        return F.relu(conv_uv), F.relu(conv_b)

    def _mix(self, head_out):
        """Softmax-weighted sum of the ``num_mix`` components in one head's output."""
        coeff = F.softmax(head_out[:self.num_mix], dim=0)
        comps = head_out[self.num_mix:].view(self.num_mix, -1)
        return (comps * coeff.unsqueeze(-1)).sum(0)

    def assemble_w_b_multi(self, uu_func, vv_func, bb_func, lat):
        """Mixture variant: mix each head's components, then ``relu(u v^T)`` and ``relu(b)``."""
        uu = self._mix(uu_func(lat))
        vv = self._mix(vv_func(lat))
        bb = self._mix(bb_func(lat))

        conv_uv = torch.mm(uu.unsqueeze(-1), vv.unsqueeze(-1).transpose(0, 1))
        conv_b = bb
        return F.relu(conv_uv), F.relu(conv_b)
126 |
--------------------------------------------------------------------------------
/classification/modified_pytorchmodule.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn.functional as F
4 | import torch.nn as nn
5 |
6 | from torch.nn.utils.weight_norm import WeightNorm
7 |
class Linear_fw(nn.Linear):  # used in MAML to forward input with fast weight
    """``nn.Linear`` that uses ``weight.fast`` / ``bias.fast`` when they are set.

    With a bias, the fast path is taken only if both fast tensors are set;
    otherwise the layer's regular parameters are used.
    """

    def __init__(self, in_features, out_features, bias=True):
        super(Linear_fw, self).__init__(in_features, out_features, bias)
        self.bias_bool = bias
        self.weight.fast = None
        if bias:
            self.bias.fast = None

    def forward(self, x):
        w_fast = self.weight.fast
        if not self.bias_bool:
            if w_fast is None:
                return super(Linear_fw, self).forward(x)
            return F.linear(x, w_fast)
        b_fast = self.bias.fast
        if w_fast is None or b_fast is None:
            return super(Linear_fw, self).forward(x)
        return F.linear(x, w_fast, b_fast)
28 |
class Linear_fwNoBias(nn.Linear):
    """Bias-free ``nn.Linear`` that multiplies by ``weight.fast`` when it is set."""

    def __init__(self, in_features, out_features):
        super(Linear_fwNoBias, self).__init__(in_features, out_features, bias=False)
        self.weight.fast = None

    def forward(self, x):
        fast = self.weight.fast
        if fast is None:
            return super(Linear_fwNoBias, self).forward(x)
        return F.linear(x, fast, bias=None)
40 |
41 |
class DistLinear(nn.Linear):  # used in MAML to forward input with fast weight
    """Bias-free cosine-similarity classifier.

    Inputs and weight rows are L2-normalised along dim 1 (with a 1e-12
    epsilon), so each output is the cosine similarity between an input row
    and a weight row. ``forward`` writes the normalised weight back: into
    ``weight.fast`` when that is set, otherwise into ``weight.data``.
    """

    def __init__(self, in_features, out_features):
        super(DistLinear, self).__init__(in_features, out_features, bias=False)
        self.weight.fast = None
        self.weight.data = self._unit_rows(self.weight.data)

    @staticmethod
    def _unit_rows(t):
        """Divide ``t`` by its L2 norm taken along dim 1 (plus 1e-12)."""
        row_norm = torch.norm(t, p=2, dim=1).unsqueeze(1).expand_as(t)
        return t.div(row_norm + 1e-12)

    def forward(self, x):
        x_unit = self._unit_rows(x)
        if self.weight.fast is None:
            self.weight.data = self._unit_rows(self.weight.data)
            return super(DistLinear, self).forward(x_unit)
        self.weight.fast = self._unit_rows(self.weight.fast)
        return F.linear(x_unit, self.weight.fast, bias=None)
62 |
class Conv2d_fw(nn.Conv2d):  # used in MAML to forward input with fast weight
    """``nn.Conv2d`` that convolves with ``weight.fast`` / ``bias.fast`` when set.

    With a bias, both fast tensors must be set to take the fast path;
    otherwise the regular parameters are used. The fast path passes stride,
    padding and groups; dilation is not exposed by this constructor.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, groups=1):
        super(Conv2d_fw, self).__init__(in_channels, out_channels, kernel_size,
                                        stride=stride, padding=padding, bias=bias, groups=groups)
        self.weight.fast = None
        if self.bias is not None:
            self.bias.fast = None

    def forward(self, x):
        has_bias = self.bias is not None
        use_fast = self.weight.fast is not None and (not has_bias or self.bias.fast is not None)
        if not use_fast:
            return super(Conv2d_fw, self).forward(x)
        fast_bias = self.bias.fast if has_bias else None
        return F.conv2d(x, self.weight.fast, fast_bias, stride=self.stride,
                        padding=self.padding, groups=self.groups)
83 |
84 |
85 |
class BatchNorm2d_fw(nn.BatchNorm2d):  # used in MAML to forward input with fast weight
    """BatchNorm2d that always normalises with the current batch's statistics.

    ``weight.fast`` / ``bias.fast`` are used for the affine step when both
    are set. The module's own running buffers are never passed to
    ``F.batch_norm``; fresh scratch buffers are created on every call.
    """

    def __init__(self, num_features):
        super(BatchNorm2d_fw, self).__init__(num_features)
        self.weight.fast = None
        self.bias.fast = None

    def forward(self, x):
        num_channels = x.size(1)
        # Scratch running stats on the input's device (rather than forcing
        # .cuda()) so the layer also works on CPU. With training=True,
        # F.batch_norm normalises with batch statistics and only writes these.
        running_mean = torch.zeros(num_channels, device=x.device)
        running_var = torch.ones(num_channels, device=x.device)
        if self.weight.fast is not None and self.bias.fast is not None:
            weight, bias = self.weight.fast, self.bias.fast
        else:
            weight, bias = self.weight, self.bias
        return F.batch_norm(x, running_mean, running_var, weight, bias,
                            training=True, momentum=1)
--------------------------------------------------------------------------------
/classification/run_test_modgrad.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os.path as osp
3 | import os
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from torch.utils.data import DataLoader
8 |
9 | from miniimagenet import MiniImageNet
10 | from samplers import CategoriesSampler
11 | from model.maml_pcg import MAML_PCG
12 | from model.pcg_module import PCG
13 | from utils import pprint, set_gpu, ensure_path, Averager, Timer, count_acc, euclidean_metric
14 | from torch.nn.utils.clip_grad import clip_grad_norm_
15 | #from newfunc.labelsmoothing import LabelSmoothingLoss
16 |
17 | import time
18 |
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--max-epoch', type=int, default=1)
    parser.add_argument('--save-epoch', type=int, default=1000)
    parser.add_argument('--shot', type=int, default=5)
    parser.add_argument('--query', type=int, default=15)
    parser.add_argument('--train-way', type=int, default=5)
    parser.add_argument('--test-way', type=int, default=5)
    parser.add_argument('--inner-step', type=int, default=1)
    parser.add_argument('--noise-rate', type=float, default=0.0)
    parser.add_argument('--load-path', default='./results/pcg_maml/max-acc.pth')
    parser.add_argument('--load-path-pcg', default='./results/pcg_maml/max-acc-pcg.pth')
    parser.add_argument('--data-path', default='/scratch1/sim314/flush1/miniimagenet/ctm_images')
    parser.add_argument('--gpu', default='3')

    args = parser.parse_args()
    pprint(vars(args))

    set_gpu(args.gpu)

    # 1000 random test episodes: test_way classes x (shot + query) images each.
    valset = MiniImageNet('test', args.data_path)
    val_sampler = CategoriesSampler(valset.label, 1000,
                                    args.test_way, args.shot + args.query)
    val_loader = DataLoader(dataset=valset, batch_sampler=val_sampler,
                            num_workers=8, pin_memory=True)

    model = MAML_PCG(args.train_way, args.shot, noise_rate=args.noise_rate).cuda()
    model.load_state_dict(torch.load(args.load_path))

    pcg = PCG(num_plastic=300).cuda()
    pcg.load_state_dict(torch.load(args.load_path_pcg))

    trlog = {}
    trlog['args'] = vars(args)
    trlog['val_loss'] = []
    trlog['val_acc'] = []
    trlog['max_acc'] = 0.0

    timer = Timer()

    # Episode layout is fixed: n_support support images first, then queries
    # whose labels follow the sampler's class-interleaved order.
    n_support = args.shot * args.test_way
    label = torch.arange(args.test_way).repeat(args.query).type(torch.cuda.LongTensor)

    for epoch in range(1, args.max_epoch + 1):
        # Fresh averagers every epoch: they are turned into floats below, so
        # reusing them would crash on the second epoch.
        vl = Averager()
        va = Averager()

        for i, batch in enumerate(val_loader, 1):
            # No torch.no_grad() here: MAML_PCG adapts to the support set with
            # torch.autograd.grad, which needs a recorded graph.
            data, _ = [t.cuda() for t in batch]
            data_shot, data_query = data[:n_support], data[n_support:]

            logits = model(data_shot, data_query, pcg, inner_update_num=args.inner_step)
            loss = F.cross_entropy(logits, label)
            acc = count_acc(logits, label)

            # Record each episode exactly once.
            vl.add(loss.item())
            va.add(acc)
            pcg.reset()

        vl = vl.item()
        va = va.item()
        print('epoch {}, val, loss={:.4f} acc={:.4f} maxacc={:.4f}'.format(epoch, vl, va,trlog['max_acc']))
89 |
90 |
91 |
92 |
--------------------------------------------------------------------------------
/classification/samplers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
class CategoriesSampler():
    """Batch sampler that yields few-shot episodes as flat index tensors.

    Args:
        label: per-sample integer class labels (0 .. max label).
        n_batch: number of episodes per iteration.
        n_cls: classes per episode.
        n_per: samples per class per episode.

    Each yielded tensor has ``n_cls * n_per`` indices in sample-major order.
    The first ``n_cls`` indices are one sample from each chosen class, the
    next ``n_cls`` are a second sample in the same class order, and so on.
    Classes with 4 or fewer samples are never sampled.
    """

    def __init__(self, label, n_batch, n_cls, n_per):
        self.n_batch = n_batch
        self.n_cls = n_cls
        self.n_per = n_per

        label = np.array(label)
        self.m_ind = []
        if label.size == 0:
            return
        # Labels run from 0 to max(label) inclusive; range(max(label)) would
        # silently drop the last class.
        for c in range(int(label.max()) + 1):
            ind = torch.from_numpy(np.argwhere(label == c).reshape(-1))
            if len(ind) > 4:
                self.m_ind.append(ind)

    def __len__(self):
        return self.n_batch

    def __iter__(self):
        for _ in range(self.n_batch):
            classes = torch.randperm(len(self.m_ind))[:self.n_cls]
            per_class = []
            for c in classes:
                pool = self.m_ind[c]
                pos = torch.randperm(len(pool))[:self.n_per]
                per_class.append(pool[pos])
            # (n_cls, n_per) -> transpose -> flatten: sample-major layout.
            yield torch.stack(per_class).t().reshape(-1)
36 |
37 |
--------------------------------------------------------------------------------
/classification/test_modgrad.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os.path as osp
3 | import os
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from torch.utils.data import DataLoader
8 |
9 | from miniimagenet import MiniImageNet
10 | from samplers import CategoriesSampler
11 | from model.maml_pcg import MAML_PCG
12 | from model.pcg_module import PCG
13 | from utils import pprint, set_gpu, ensure_path, Averager, Timer, count_acc, euclidean_metric
14 | from torch.nn.utils.clip_grad import clip_grad_norm_
15 | #from newfunc.labelsmoothing import LabelSmoothingLoss
16 |
17 | import time
18 |
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--max-epoch', type=int, default=1)
    parser.add_argument('--save-epoch', type=int, default=1000)
    parser.add_argument('--shot', type=int, default=5)
    parser.add_argument('--query', type=int, default=15)
    parser.add_argument('--train-way', type=int, default=5)
    parser.add_argument('--test-way', type=int, default=5)
    parser.add_argument('--inner-step', type=int, default=1)
    parser.add_argument('--noise-rate', type=float, default=0.0)
    parser.add_argument('--load-path', default='./results/pcg_maml/max-acc.pth')
    # PCG weights are saved next to the model as '<name>-pcg.pth'; without
    # loading them the preconditioner would be randomly initialised.
    parser.add_argument('--load-path-pcg', default='./results/pcg_maml/max-acc-pcg.pth')
    # Output directory for save_model() and the trlog dump (previously
    # referenced but never defined). Kept separate from the training
    # directory so test runs do not overwrite training checkpoints.
    parser.add_argument('--save-path', default='./results/pcg_maml_test/')
    parser.add_argument('--data-path', default='yourdatapath')
    parser.add_argument('--gpu', default='3')

    args = parser.parse_args()
    pprint(vars(args))

    set_gpu(args.gpu)

    valset = MiniImageNet('test', args.data_path)
    val_sampler = CategoriesSampler(valset.label, 1000,
                                    args.test_way, args.shot + args.query)
    val_loader = DataLoader(dataset=valset, batch_sampler=val_sampler,
                            num_workers=8, pin_memory=True)

    model = MAML_PCG(args.train_way, args.shot, noise_rate=args.noise_rate).cuda()
    model.load_state_dict(torch.load(args.load_path))

    pcg = PCG(num_plastic=300).cuda()
    pcg.load_state_dict(torch.load(args.load_path_pcg))

    # makedirs also creates missing parents; done up front so the trlog dump
    # always has a target directory.
    os.makedirs(args.save_path, exist_ok=True)

    def save_model(name):
        """Save model and PCG state dicts as <name>.pth and <name>-pcg.pth."""
        os.makedirs(args.save_path, exist_ok=True)
        torch.save(model.state_dict(), osp.join(args.save_path, name + '.pth'))
        torch.save(pcg.state_dict(), osp.join(args.save_path, name + '-pcg.pth'))

    trlog = {}
    trlog['args'] = vars(args)
    trlog['val_loss'] = []
    trlog['val_acc'] = []
    trlog['max_acc'] = 0.0

    timer = Timer()

    n_support = args.shot * args.test_way
    label = torch.arange(args.test_way).repeat(args.query).type(torch.cuda.LongTensor)

    for epoch in range(1, args.max_epoch + 1):
        # Averagers were never created before; make fresh ones per epoch.
        vl = Averager()
        va = Averager()

        for i, batch in enumerate(val_loader, 1):
            # No torch.no_grad(): the model's inner loop needs autograd.
            data, _ = [t.cuda() for t in batch]
            data_shot, data_query = data[:n_support], data[n_support:]

            logits = model(data_shot, data_query, pcg, inner_update_num=args.inner_step)
            loss = F.cross_entropy(logits, label)
            acc = count_acc(logits, label)

            # Record each episode exactly once.
            vl.add(loss.item())
            va.add(acc)
            pcg.reset()

        vl = vl.item()
        va = va.item()
        print('epoch {}, val, loss={:.4f} acc={:.4f} maxacc={:.4f}'.format(epoch, vl, va,trlog['max_acc']))

        if va > trlog['max_acc']:
            trlog['max_acc'] = va
            save_model('max-acc')

        trlog['val_loss'].append(vl)
        trlog['val_acc'].append(va)

        torch.save(trlog, osp.join(args.save_path, 'trlog'))

        save_model('epoch-last')

        if epoch % args.save_epoch == 0:
            save_model('epoch-{}'.format(epoch))

        print('ETA:{}/{}'.format(timer.measure(), timer.measure(epoch / args.max_epoch)))
111 |
112 |
113 |
--------------------------------------------------------------------------------
/classification/train_modgrad.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os.path as osp
3 | import os
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from torch.utils.data import DataLoader
8 |
9 | from miniimagenet import MiniImageNet
10 | from samplers import CategoriesSampler
11 | from model.maml_pcg import MAML_PCG
12 | from model.pcg_module import PCG
13 | from utils import pprint, set_gpu, ensure_path, Averager, Timer, count_acc, euclidean_metric
14 | from torch.nn.utils.clip_grad import clip_grad_norm_
15 | #from newfunc.labelsmoothing import LabelSmoothingLoss
16 |
17 | import time
18 |
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--max-epoch', type=int, default=500)
    parser.add_argument('--save-epoch', type=int, default=1000)
    parser.add_argument('--shot', type=int, default=5)
    parser.add_argument('--query', type=int, default=15)
    parser.add_argument('--train-way', type=int, default=5)
    parser.add_argument('--test-way', type=int, default=5)
    parser.add_argument('--inner-step', type=int, default=1)
    parser.add_argument('--noise-rate', type=float, default=0.0)
    parser.add_argument('--save-path', default='./results/pcg_maml/')
    parser.add_argument('--data-path', default='yourdatapath')
    parser.add_argument('--gpu', default='1')

    args = parser.parse_args()
    pprint(vars(args))

    set_gpu(args.gpu)

    trainset = MiniImageNet('train', args.data_path)
    train_sampler = CategoriesSampler(trainset.label, 100,
                                      args.train_way, args.shot + args.query)
    train_loader = DataLoader(dataset=trainset, batch_sampler=train_sampler,
                              num_workers=8, pin_memory=True)

    valset = MiniImageNet('val', args.data_path)
    val_sampler = CategoriesSampler(valset.label, 400,
                                    args.test_way, args.shot + args.query)
    val_loader = DataLoader(dataset=valset, batch_sampler=val_sampler,
                            num_workers=8, pin_memory=True)

    model = MAML_PCG(args.train_way, args.shot, noise_rate=args.noise_rate).cuda()

    pcg = PCG(num_plastic=300).cuda()

    task_num = 3  # episodes whose losses are summed per meta-update
    lr_adjust_base = [200, 400]  # epochs at which the model LR is halved
    lr_adjust_pcg = [80, 160, 240, 320]  # epochs at which the PCG LR is halved

    optimizer = torch.optim.Adam(list(model.parameters()), lr=0.001, amsgrad=False)
    optimizer_pcg = torch.optim.Adam(list(pcg.parameters()), lr=0.001, amsgrad=False)

    # os.makedirs also creates missing parents such as ./results (os.mkdir
    # could not); done up front so the trlog dump always has a directory.
    os.makedirs(args.save_path, exist_ok=True)

    def save_model(name):
        """Save model and PCG state dicts as <name>.pth and <name>-pcg.pth."""
        os.makedirs(args.save_path, exist_ok=True)
        torch.save(model.state_dict(), osp.join(args.save_path, name + '.pth'))
        torch.save(pcg.state_dict(), osp.join(args.save_path, name + '-pcg.pth'))

    def halve_lr(opt):
        """Halve the learning rate of every param group of ``opt``."""
        for param_group in opt.param_groups:
            param_group['lr'] = param_group['lr'] * 0.5

    trlog = {}
    trlog['args'] = vars(args)
    trlog['train_loss'] = []
    trlog['val_loss'] = []
    trlog['train_acc'] = []
    trlog['val_acc'] = []
    trlog['max_acc'] = 0.0

    timer = Timer()

    # Episode layout is fixed: support images first, then queries whose
    # labels follow the sampler's class-interleaved order.
    n_train_support = args.shot * args.train_way
    n_val_support = args.shot * args.test_way
    train_label = torch.arange(args.train_way).repeat(args.query).type(torch.cuda.LongTensor)
    val_label = torch.arange(args.test_way).repeat(args.query).type(torch.cuda.LongTensor)

    for epoch in range(1, args.max_epoch + 1):

        if epoch in lr_adjust_base:
            halve_lr(optimizer)

        if epoch in lr_adjust_pcg:
            halve_lr(optimizer_pcg)

        model.train()
        pcg.train()

        tl = Averager()
        ta = Averager()
        loss_all = []

        for i, batch in enumerate(train_loader, start=1):
            data, _ = [t.cuda() for t in batch]
            data_shot, data_query = data[:n_train_support], data[n_train_support:]

            logits = model(data_shot, data_query, pcg, inner_update_num=args.inner_step, train=True)
            loss = F.cross_entropy(logits, train_label)
            loss_all.append(loss)

            # One meta-update per task_num episodes on the summed loss.
            # NOTE: when len(train_loader) is not a multiple of task_num, the
            # trailing episodes' losses are dropped (loss_all is reset at the
            # start of the next epoch).
            if i % task_num == 0:
                total_loss = torch.stack(loss_all).sum(0)
                optimizer.zero_grad()
                optimizer_pcg.zero_grad()
                total_loss.backward()
                optimizer.step()
                optimizer_pcg.step()
                loss_all = []

            pcg.reset()
            tl.add(loss.item())
            ta.add(count_acc(logits, train_label))

        # Store plain floats (not Averager objects) in the training log.
        tl = tl.item()
        ta = ta.item()
        print('epoch {} acc={:.4f}'.format(epoch, ta))

        # Validate every 30 epochs before epoch 400, every epoch afterwards.
        if epoch < 400 and epoch % 30 != 0:
            continue

        vl = Averager()
        va = Averager()

        for i, batch in enumerate(val_loader, 1):
            # No torch.no_grad(): the model's inner loop needs autograd.
            data, _ = [t.cuda() for t in batch]
            data_shot, data_query = data[:n_val_support], data[n_val_support:]

            logits = model(data_shot, data_query, pcg, inner_update_num=args.inner_step)
            loss = F.cross_entropy(logits, val_label)
            acc = count_acc(logits, val_label)

            # Validation metrics go only to the validation averagers.
            vl.add(loss.item())
            va.add(acc)
            pcg.reset()

        vl = vl.item()
        va = va.item()
        print('epoch {}, val, loss={:.4f} acc={:.4f} maxacc={:.4f}'.format(epoch, vl, va,trlog['max_acc']))

        if va > trlog['max_acc']:
            trlog['max_acc'] = va
            save_model('max-acc')

        trlog['train_loss'].append(tl)
        trlog['train_acc'].append(ta)
        trlog['val_loss'].append(vl)
        trlog['val_acc'].append(va)

        torch.save(trlog, osp.join(args.save_path, 'trlog'))

        save_model('epoch-last')

        if epoch % args.save_epoch == 0:
            save_model('epoch-{}'.format(epoch))

        print('ETA:{}/{}'.format(timer.measure(), timer.measure(epoch / args.max_epoch)))
177 |
178 |
179 |
--------------------------------------------------------------------------------
/classification/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import time
4 | import pprint
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.autograd.variable as Variable
9 |
10 | from math import sqrt
11 | from numpy.random import seed
12 | from numpy.random import randn
13 | from numpy import mean
14 | from scipy.stats import sem
15 | from scipy.stats import t
16 | import numpy as np
17 | from collections import OrderedDict
18 | import random
19 |
class GaussianNoise(nn.Module):
    """Layer that adds zero-mean Gaussian noise to its input.

    Args:
        batch_size (int): leading dimension of the ``self.noise`` tensor.
        input_shape (tuple): per-sample shape of ``self.noise``.
        std (float): stored as ``self.std``. NOTE(review): ``forward`` does not
            read it; it uses its own ``std`` argument (default 0.15). Confirm intent.
    """

    def __init__(self, batch_size, input_shape=(3, 84, 84), std=0.05):
        super(GaussianNoise, self).__init__()
        self.shape = (batch_size,) + tuple(input_shape)
        # `Variable` in this file is the module torch.autograd.variable (not
        # callable); plain tensors have carried autograd state since PyTorch 0.4.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.noise = torch.zeros(self.shape, device=device)
        self.std = std

    def forward(self, x, std=0.15):
        """Return ``x`` plus i.i.d. N(0, std**2) noise of the same shape.

        The noise is drawn on ``x``'s device in the default float dtype.
        """
        noise = torch.zeros(x.shape, device=x.device).normal_(0, std=std)
        return x + noise
32 |
33 |
def set_seed(seed, cudnn=True):
    """
    Seed Python's ``random``, NumPy and PyTorch (CPU and current CUDA device).

    Args:
        seed (int or None): seed value. ``None`` leaves every generator
            untouched (``torch.manual_seed`` does not accept ``None``).
        cudnn (bool): if true, also switch cuDNN to deterministic algorithms
            and disable its autotuner. This can slow the code down but is
            needed for run-to-run reproducibility on GPU.

    Note that gym environments might need additional seeding (env.seed(seed)),
    and num_workers needs to be set to 1.
    """
    if seed is None:
        return
    random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed seeds the default torch.random generator too, so a
    # separate torch.random.manual_seed call would be redundant.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    if cudnn:
        torch.backends.cudnn.deterministic = True
        # Benchmark mode selects kernels by timing them, which is nondeterministic.
        torch.backends.cudnn.benchmark = False
48 |
49 |
def set_gpu(x):
    """Expose only the CUDA device(s) listed in ``x`` (e.g. '0' or '0,1') and log it."""
    os.environ.update(CUDA_VISIBLE_DEVICES=x)
    print('using gpu:', x)
53 |
54 |
def clone(tensor):
    """Return a copy of ``tensor`` that stays attached to the autograd graph.

    ``Tensor.clone()`` copies the data into new memory but, unlike
    ``detach().clone()``, keeps the graph connection: gradients flowing into
    the copy propagate back to ``tensor``, and the copy requires grad whenever
    ``tensor`` does. Any ``.grad`` already stored on ``tensor`` is not copied.

    Arguments:
        tensor (torch.Tensor): tensor to clone.
    """
    return tensor.clone()
66 |
def clone_state_dict(state_dict):
    """Return a new OrderedDict holding a ``clone`` of every value in ``state_dict``.

    Keys and their order are preserved. Each copy is made with ``clone``
    (``tensor.clone()``), which keeps it attached to the autograd graph. For a
    ``torch.nn.Module``, pass ``module.state_dict(keep_vars=True)`` if the
    copies should stay connected to the live parameters.

    Arguments:
        state_dict (OrderedDict): mapping of names to tensors to copy.
    """
    copied = OrderedDict()
    for key, tensor in state_dict.items():
        copied[key] = clone(tensor)
    return copied
74 |
def ensure_path(path):
    """Make sure ``path`` exists as a directory, optionally starting fresh.

    If ``path`` already exists, the user is asked on stdin whether to remove
    it. Any answer other than ``'n'`` (including an empty answer) deletes the
    tree and recreates the directory empty; answering ``'n'`` leaves it as is.
    Missing parent directories are created as needed.
    """
    if os.path.exists(path):
        if input('{} exists, remove? ([y]/n)'.format(path)) != 'n':
            shutil.rmtree(path)
            os.makedirs(path)
    else:
        # makedirs (not mkdir) so nested save paths like './save/run-1' work
        # even when './save' does not exist yet.
        os.makedirs(path)
82 |
83 |
class Averager:
    """Running arithmetic mean of the values passed to ``add``.

    ``item()`` returns 0 before anything has been added.
    """

    def __init__(self):
        self.n, self.v = 0, 0

    def add(self, x):
        """Fold ``x`` into the running mean."""
        count = self.n
        self.v = (self.v * count + x) / (count + 1)
        self.n = count + 1

    def item(self):
        """Return the current mean."""
        return self.v
96 |
97 |
def count_acc(logits, label):
    """Return the fraction of rows whose argmax over dim 1 equals ``label``.

    Args:
        logits: score tensor with classes on dim 1 (presumably (N, C)).
        label: integer class index per row.

    Returns:
        float: accuracy in [0, 1].
    """
    pred = torch.argmax(logits, dim=1)
    # .float() keeps the result on the inputs' device; the previous
    # torch.cuda.FloatTensor cast required a GPU even for CPU tensors.
    return (pred == label).float().mean().item()
101 |
102 |
def dot_metric(a, b):
    """Return the pairwise inner products of the rows of ``a`` and ``b``.

    For 2-D inputs of shape (n, d) and (m, d), the result has shape (n, m)
    with entry [i, j] equal to a[i] . b[j].
    """
    return a.mm(b.t())
105 |
106 |
def count_accuracy(logits, label):
    """Return top-1 accuracy in percent as a 0-dim float tensor.

    Predictions are the argmax of ``logits`` over dim 1. Predictions and
    ``label`` are both flattened before they are compared.
    """
    predicted = logits.argmax(dim=1).view(-1)
    correct = predicted.eq(label.view(-1)).float()
    return 100 * correct.mean()
112 |
def euclidean_metric(a, b):
    """Return negative squared Euclidean distances between rows of ``a`` and ``b``.

    For 2-D inputs of shape (n, d) and (m, d), entry [i, j] of the (n, m)
    result is ``-||a[i] - b[j]||**2``, so larger values mean closer rows.
    """
    rows, cols = a.shape[0], b.shape[0]
    diff = a.unsqueeze(1).expand(rows, cols, -1) - b.unsqueeze(0).expand(rows, cols, -1)
    return -diff.pow(2).sum(dim=2)
121 |
122 |
class Timer:
    """Wall-clock stopwatch that starts when it is constructed."""

    def __init__(self):
        self.o = time.time()

    def measure(self, p=1):
        """Return elapsed seconds divided by ``p``, truncated and formatted.

        The format is ``'<s>s'`` below one minute, ``'<m>m'`` (rounded) below
        one hour, and ``'<h>h'`` with one decimal otherwise. Passing the
        fraction of work done as ``p`` gives a projected total time.
        """
        seconds = int((time.time() - self.o) / p)
        if seconds < 60:
            return '{}s'.format(seconds)
        if seconds < 3600:
            return '{}m'.format(round(seconds / 60))
        return '{:.1f}h'.format(seconds / 3600)
136 |
# Shared printer; built while `pprint` still names the stdlib module, i.e.
# before the function below rebinds that name.
_utils_pp = pprint.PrettyPrinter()


def pprint(x):
    """Pretty-print ``x`` with the module's shared PrettyPrinter.

    NOTE: this function shadows the ``pprint`` standard-library module imported
    at the top of the file, so that module is unreachable by name afterwards.
    """
    printer = _utils_pp
    printer.pprint(x)
140 |
141 |
def l2_loss(pred, label):
    """Return half the summed squared error divided by ``len(pred)``.

    ``len(pred)`` is the size of the first dimension, so for multi-dimensional
    inputs the sum runs over all elements but the divisor is the batch size.
    """
    residual = pred - label
    return (residual ** 2).sum() / len(pred) / 2
144 |
def set_protocol(data_path, protocol, test_protocol):
    """Map protocol ids to lists of crop-folder paths.

    The ten sub-sets are grouped into four protocols: 'p1' -> shn, hon, clv;
    'p2' -> clk, gls, scl; 'p3' -> sci, nat; 'p4' -> shx, rel. Each selected
    sub-set ``name`` becomes ``data_path + '/crops_' + name``. An unrecognised
    protocol id yields an empty list.

    Returns:
        tuple: ``(train, val)`` path lists for ``protocol`` and ``test_protocol``.
    """
    all_set = ['shn', 'hon', 'clv', 'clk', 'gls', 'scl', 'sci', 'nat', 'shx', 'rel']
    # (protocol id, start index, stop index) into all_set.
    splits = (('p1', 0, 3), ('p2', 3, 6), ('p3', 6, 8), ('p4', 8, 10))

    def _paths(proto):
        for key, start, stop in splits:
            if proto == key:
                return [data_path + '/crops_' + name for name in all_set[start:stop]]
        return []

    return _paths(protocol), _paths(test_protocol)
180 |
181 |
182 |
183 |
def independent_ttest(data1, data2, alpha):
    """Two-sample t-test for independent samples.

    Args:
        data1, data2: 1-D sequences of observations.
        alpha: significance level used for the critical value.

    Returns:
        tuple: ``(t_stat, df, cv, p)`` where ``t_stat`` is the t statistic for
        ``mean(data1) - mean(data2)``, ``df = len(data1) + len(data2) - 2``,
        ``cv = t.ppf(1 - alpha, df)`` is a one-tailed critical value, and ``p``
        is the two-tailed p-value.

    NOTE(review): the standard error is the unpooled sqrt(se1**2 + se2**2)
    while ``df`` is the pooled value. They coincide with Student's pooled test
    e.g. when both samples have the same size. Also, ``cv`` is one-tailed while
    ``p`` is two-tailed, so ``abs(t_stat) > cv`` and ``p < alpha`` can disagree.
    """
    mean1, mean2 = mean(data1), mean(data2)
    se1, se2 = sem(data1), sem(data2)
    # standard error of the difference between the sample means
    sed = sqrt(se1 ** 2.0 + se2 ** 2.0)
    t_stat = (mean1 - mean2) / sed
    df = len(data1) + len(data2) - 2
    cv = t.ppf(1.0 - alpha, df)
    # The survival function stays accurate in the far tail, where
    # 1.0 - t.cdf(...) rounds to exactly 0.0 for large |t_stat|.
    p = t.sf(abs(t_stat), df) * 2.0
    return t_stat, df, cv, p
201 |
202 |
def perturb(data):
    """Append one randomly chosen geometric transform of ``data`` along dim 0.

    Assumes a 4-D batch (presumably (N, C, H, W); TODO confirm). The transform
    is chosen by ``np.random.randint(0, 5)``:

    * 1: ``flip(3)``, a mirror along the last axis;
    * 2: ``flip(2)``, a mirror along axis 2 (a flip, not a 180-degree rotation);
    * 3: ``transpose(2, 3)``, swapping the last two axes (not a rotation);
    * 0 or 4: ``transpose(2, 3).flip(3)``, a 90-degree rotation.

    Because 0 and 4 share a branch, the rotation is picked with probability
    2/5 and each other transform with 1/5. NOTE(review): confirm intended.

    Returns:
        ``torch.cat((data, transformed), dim=0)``, twice as long along dim 0.
    """
    choice = np.random.randint(0, 5)
    if choice == 1:
        extra = data.flip(3)
    elif choice == 2:
        extra = data.flip(2)
    elif choice == 3:
        extra = data.transpose(2, 3)
    else:
        extra = data.transpose(2, 3).flip(3)
    return torch.cat((data, extra), dim=0)
--------------------------------------------------------------------------------
/comparison_method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chrysts/generative_preconditioner/724deb49abbd45f06d3fb3a003dca07ad9d84241/comparison_method.png
--------------------------------------------------------------------------------