├── .gitignore
├── LICENSE
├── MiniImagenet.py
├── README.md
├── backup
│   ├── csmlv0.py
│   ├── mainv0.py
│   └── naive5_train.py
├── learner.py
├── meta.py
├── miniimagenet_train.py
├── omniglot.py
├── omniglotNShot.py
├── omniglot_train.py
├── res
│   ├── heart.gif
│   └── mini-screen.png
└── test.py

/.gitignore:
--------------------------------------------------------------------------------
1 | backup
2 | miniimagenet
3 | omniglot
4 | .idea
5 | __pycache__

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 Jackie Loong
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 

--------------------------------------------------------------------------------
/MiniImagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.utils.data import Dataset
4 | from torchvision.transforms import transforms
5 | import numpy as np
6 | import collections
7 | from PIL import Image
8 | import csv
9 | import random
10 | 
11 | 
12 | class MiniImagenet(Dataset):
13 |     """
14 |     put mini-imagenet files as :
15 |     root :
16 |         |- images/*.jpg includes all images
17 |         |- train.csv
18 |         |- test.csv
19 |         |- val.csv
20 |     NOTICE: meta-learning differs from general supervised learning, especially in the concepts of batch and set.
21 |     batch: contains several sets
22 |     sets: contains n_way * k_shot images for the meta-train (support) set and n_way * k_query images for the meta-test (query) set.
23 |     """
24 | 
25 |     def __init__(self, root, mode, batchsz, n_way, k_shot, k_query, resize, startidx=0):
26 |         """
27 | 
28 |         :param root: root path of mini-imagenet
29 |         :param mode: train, val or test
30 |         :param batchsz: batch size of sets, not batch of imgs
31 |         :param n_way:
32 |         :param k_shot:
33 |         :param k_query: num of query imgs per class
34 |         :param resize: resize to
35 |         :param startidx: start to index label from startidx
36 |         """
37 | 
38 |         self.batchsz = batchsz  # batch of set, not batch of imgs
39 |         self.n_way = n_way  # n-way
40 |         self.k_shot = k_shot  # k-shot
41 |         self.k_query = k_query  # for evaluation
42 |         self.setsz = self.n_way * self.k_shot  # num of samples per set
43 |         self.querysz = self.n_way * self.k_query  # number of samples per set for evaluation
44 |         self.resize = resize  # resize to
45 |         self.startidx = startidx  # index label not from 0, but from startidx
46 |         print('shuffle DB :%s, b:%d, %d-way, %d-shot, %d-query, resize:%d' % (
47 |             mode, batchsz, n_way, k_shot, k_query, resize))
48 | 
49 |         if mode == 'train':
50 |             self.transform = transforms.Compose([lambda x: Image.open(x).convert('RGB'),
51 |                                                  transforms.Resize((self.resize, self.resize)),
52 |                                                  # transforms.RandomHorizontalFlip(),
53 |                                                  # transforms.RandomRotation(5),
54 |                                                  transforms.ToTensor(),
55 |                                                  transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
56 |                                                  ])
57 |         else:
58 |             self.transform = transforms.Compose([lambda x: Image.open(x).convert('RGB'),
59 |                                                  transforms.Resize((self.resize, self.resize)),
60 |                                                  transforms.ToTensor(),
61 |                                                  transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
62 |                                                  ])
63 | 
64 |         self.path = os.path.join(root, 'images')  # image path
65 |         csvdata = self.loadCSV(os.path.join(root, mode + '.csv'))  # csv path
66 |         self.data = []
67 |         self.img2label = {}
68 |         for i, (k, v) in enumerate(csvdata.items()):
69 |             self.data.append(v)  # [[img1, img2, ...], [img111, ...]]
70 |             self.img2label[k] = i + self.startidx  # {"img_name[:9]":label}
71 |         self.cls_num = len(self.data)
72 | 
73 |         self.create_batch(self.batchsz)
74 | 
75 |     def loadCSV(self, csvf):
76 |         """
77 |         return a dict saving the information of csv
78 |         :param csvf: csv file name
79 |         :return: {label:[file1, file2 ...]}
80 |         """
81 |         dictLabels = {}
82 |         with open(csvf) as csvfile:
83 |             csvreader = csv.reader(csvfile, delimiter=',')
84 |             next(csvreader, None)  # skip (filename, label)
85 |             for i, row in enumerate(csvreader):
86 |                 filename = row[0]
87 |                 label = row[1]
88 |                 # append filename to current label
89 |                 if label in dictLabels.keys():
90 |                     dictLabels[label].append(filename)
91 |                 else:
92 |                     dictLabels[label] = [filename]
93 |         return dictLabels
94 | 
95 |     def create_batch(self, batchsz):
96 |         """
97 |         create batches for meta-learning.
98 |         `episode` here means one sampled set (one task); batchsz is how many episodes we want to keep.
99 |         :param batchsz: batch size, i.e. number of episodes
100 |         :return:
101 |         """
102 |         self.support_x_batch = []  # support set batch
103 |         self.query_x_batch = []  # query set batch
104 |         for b in range(batchsz):  # for each batch
105 |             # 1. select n_way classes randomly
106 |             selected_cls = np.random.choice(self.cls_num, self.n_way, False)  # no duplicate
107 |             np.random.shuffle(selected_cls)
108 |             support_x = []
109 |             query_x = []
110 |             for cls in selected_cls:
111 |                 # 2. select k_shot + k_query samples for each class
112 |                 selected_imgs_idx = np.random.choice(len(self.data[cls]), self.k_shot + self.k_query, False)
113 |                 np.random.shuffle(selected_imgs_idx)
114 |                 indexDtrain = np.array(selected_imgs_idx[:self.k_shot])  # idx for Dtrain
115 |                 indexDtest = np.array(selected_imgs_idx[self.k_shot:])  # idx for Dtest
116 |                 support_x.append(
117 |                     np.array(self.data[cls])[indexDtrain].tolist())  # get all image filenames for current Dtrain
118 |                 query_x.append(np.array(self.data[cls])[indexDtest].tolist())
119 | 
120 |             # shuffle the corresponding relation between support set and query set
121 |             random.shuffle(support_x)
122 |             random.shuffle(query_x)
123 | 
124 |             self.support_x_batch.append(support_x)  # append set to current sets
125 |             self.query_x_batch.append(query_x)  # append sets to current sets
126 | 
127 |     def __getitem__(self, index):
128 |         """
129 |         index means index of sets, 0 <= index <= batchsz - 1
130 |         :param index:
131 |         :return:
132 |         """
133 |         # [setsz, 3, resize, resize]
134 |         support_x = torch.FloatTensor(self.setsz, 3, self.resize, self.resize)
135 |         # [setsz]
136 |         support_y = np.zeros((self.setsz), dtype=int)  # np.int was removed in NumPy 1.24
137 |         # [querysz, 3, resize, resize]
138 |         query_x = torch.FloatTensor(self.querysz, 3, self.resize, self.resize)
139 |         # [querysz]
140 |         query_y = np.zeros((self.querysz), dtype=int)
141 | 
142 |         flatten_support_x = [os.path.join(self.path, item)
143 |                              for sublist in self.support_x_batch[index] for item in sublist]
144 |         support_y = np.array(
145 |             [self.img2label[item[:9]]  # filename:n0153282900000005.jpg, the first 9 characters are treated as the label
146 |              for sublist in self.support_x_batch[index] for item in sublist]).astype(np.int32)
147 | 
148 |         flatten_query_x = [os.path.join(self.path, item)
149 |                            for sublist in self.query_x_batch[index] for item in sublist]
150 |         query_y = np.array([self.img2label[item[:9]]
151 |                             for sublist in self.query_x_batch[index] for item in sublist]).astype(np.int32)
152 | 
153 |         # print('global:', support_y, query_y)
154 |         # support_y: [setsz]
155 |         # query_y: [querysz]
156 |         # unique: [n-way], sorted
157 |         unique = np.unique(support_y)
158 |         random.shuffle(unique)
159 |         # relative means the label ranges from 0 to n-way
160 |         support_y_relative = np.zeros(self.setsz)
161 |         query_y_relative = np.zeros(self.querysz)
162 |         for idx, l in enumerate(unique):
163 |             support_y_relative[support_y == l] = idx
164 |             query_y_relative[query_y == l] = idx
165 | 
166 |         # print('relative:', support_y_relative, query_y_relative)
167 | 
168 |         for i, path in enumerate(flatten_support_x):
169 |             support_x[i] = self.transform(path)
170 | 
171 |         for i, path in enumerate(flatten_query_x):
172 |             query_x[i] = self.transform(path)
173 |         # print(support_set_y)
174 |         # return support_x, torch.LongTensor(support_y), query_x, torch.LongTensor(query_y)
175 | 
176 |         return support_x, torch.LongTensor(support_y_relative), query_x, torch.LongTensor(query_y_relative)
177 | 
178 |     def __len__(self):
179 |         # as we have built batchsz episodes in total, a DataLoader can sample mini-batches of episodes from them.
180 |         return self.batchsz
181 | 
182 | 
183 | if __name__ == '__main__':
184 |     # the following demo views episodes of images via tensorboard.
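    # ------------------------------------------------------------------
    # Editorial aside (not part of the original demo): a minimal, self-
    # contained sketch of the relative-label remapping performed in
    # __getitem__ above. Global class ids are remapped to episode-local
    # ids in [0, n_way); the real code additionally shuffles `unique`
    # first so the mapping is randomized per episode. All names and
    # values below are hypothetical.
    _support_y = np.array([3, 3, 7, 7, 12, 12])  # fake global labels
    _relative = np.zeros_like(_support_y)
    for _i, _l in enumerate(np.unique(_support_y)):  # unique -> [3, 7, 12]
        _relative[_support_y == _l] = _i
    assert (_relative == np.array([0, 0, 1, 1, 2, 2])).all()
    # ------------------------------------------------------------------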
185 |     from torchvision.utils import make_grid
186 |     from matplotlib import pyplot as plt
187 |     from tensorboardX import SummaryWriter
188 |     import time
189 | 
190 |     plt.ion()
191 | 
192 |     tb = SummaryWriter('runs', 'mini-imagenet')
193 |     mini = MiniImagenet('../mini-imagenet/', mode='train', n_way=5, k_shot=1, k_query=1, batchsz=1000, resize=168)
194 | 
195 |     for i, set_ in enumerate(mini):
196 |         # support_x: [k_shot*n_way, 3, 84, 84]
197 |         support_x, support_y, query_x, query_y = set_
198 | 
199 |         support_x = make_grid(support_x, nrow=2)
200 |         query_x = make_grid(query_x, nrow=2)
201 | 
202 |         plt.figure(1)
203 |         plt.imshow(support_x.transpose(2, 0).numpy())
204 |         plt.pause(0.5)
205 |         plt.figure(2)
206 |         plt.imshow(query_x.transpose(2, 0).numpy())
207 |         plt.pause(0.5)
208 | 
209 |         tb.add_image('support_x', support_x)
210 |         tb.add_image('query_x', query_x)
211 | 
212 |         time.sleep(5)
213 | 
214 |     tb.close()
215 | 

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MAML-Pytorch
2 | PyTorch implementation of the supervised learning experiments from the paper:
3 | [Model-Agnostic Meta-Learning (MAML)](https://arxiv.org/abs/1703.03400).
4 | 
5 | > Version 1.0: Both `MiniImagenet` and `Omniglot` datasets are supported! Have fun~
6 | 
7 | > Version 2.0: Re-wrote the meta learner and basic learner, and solved some serious bugs in version 1.0.
8 | 
9 | For the TensorFlow implementation, please visit the official repo [HERE](https://github.com/cbfinn/maml) and a simpler version [HERE](https://github.com/dragen1860/MAML-TensorFlow).
10 | 
11 | For a first-order approximation implementation, namely Reptile, please visit [HERE](https://github.com/dragen1860/Reptile-Pytorch).
12 | 
13 | ![heart](res/heart.gif)
14 | 
15 | # Platform
16 | - python: 3.x
17 | - Pytorch: 0.4+
18 | 
19 | # MiniImagenet
20 | 
21 | 
22 | ## Howto
23 | 
24 | For the 5-way 1-shot experiment, it allocates nearly 6GB of GPU memory.
25 | 
26 | 1. download the `MiniImagenet` dataset from [here](https://github.com/dragen1860/LearningToCompare-Pytorch/issues/4) and the splits `train/val/test.csv` from [here](https://github.com/twitter/meta-learning-lstm/tree/master/data/miniImagenet).
27 | 2. extract it like:
28 | ```shell
29 | miniimagenet/
30 | ├── images
31 | 	├── n0210891500001298.jpg
32 | 	├── n0287152500001298.jpg
33 | 	...
34 | ├── test.csv
35 | ├── val.csv
36 | └── train.csv
37 | 
38 | 
39 | ```
40 | 3. modify the `path` in `miniimagenet_train.py`:
41 | ```python
42 | mini = MiniImagenet('miniimagenet/', mode='train', n_way=args.n_way, k_shot=args.k_spt,
43 |                     k_query=args.k_qry,
44 |                     batchsz=10000, resize=args.imgsz)
45 | ...
46 | mini_test = MiniImagenet('miniimagenet/', mode='test', n_way=args.n_way, k_shot=args.k_spt,
47 |                          k_query=args.k_qry,
48 |                          batchsz=100, resize=args.imgsz)
49 | ```
50 | to your actual data path.
51 | 
52 | 4. just run `python miniimagenet_train.py`; a screenshot of the run is as follows:
53 | ![screenshot-miniimagenet](res/mini-screen.png)
54 | 
55 | If your reproduced performance is not as good, you can increase the `training epoch` setting to train longer. MAML is notorious for being hard to train, so this implementation only provides a basic starting point for your research. The performance below was actually achieved on my machine.
56 | 
57 | ## Benchmark
58 | 
59 | 
60 | | Model                               | Fine Tune | 5-way Acc. |        | 20-way Acc.|        |
61 | |-------------------------------------|-----------|------------|--------|------------|--------|
62 | |                                     |           | 1-shot     | 5-shot | 1-shot     | 5-shot |
63 | | Matching Nets                       | N         | 43.56%     | 55.31% | 17.31%     | 22.69% |
64 | | Meta-LSTM                           |           | 43.44%     | 60.60% | 16.70%     | 26.06% |
65 | | MAML                                | Y         | 48.7%      | 63.11% | 16.49%     | 19.29% |
66 | | **Ours**                            | Y         | 46.2%      | 60.3%  | -          | -      |
67 | 
68 | 
69 | 
70 | # Omniglot
71 | 
72 | ## Howto
73 | run `python omniglot_train.py`; the program will download the `omniglot` dataset automatically.
74 | 
75 | decrease the value of `args.task_num` to fit your GPU memory capacity.
76 | 
77 | For the 5-way 1-shot experiment, it allocates nearly 3GB of GPU memory.
78 | 
79 | 
80 | # Citing this Repo
81 | ```
82 | @misc{MAML_Pytorch,
83 |   author = {Liangqu Long},
84 |   title = {MAML-Pytorch Implementation},
85 |   year = {2018},
86 |   publisher = {GitHub},
87 |   journal = {GitHub repository},
88 |   howpublished = {\url{https://github.com/dragen1860/MAML-Pytorch}},
89 |   commit = {master}
90 | }
91 | ```

--------------------------------------------------------------------------------
/backup/csmlv0.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch import optim
4 | from torch import autograd
5 | from torch import multiprocessing
6 | from torch.autograd import Variable
7 | from torch.nn import functional as F
8 | from torch.utils.data import TensorDataset
9 | import numpy as np
10 | import os
11 | 
12 | multiprocessing = multiprocessing.get_context('spawn')
13 | 
14 | 
15 | class Concept(nn.Module):
16 | 
17 |     def __init__(self):
18 |         super(Concept, self).__init__()
19 | 
20 |         self.net = nn.Sequential(
21 |             nn.Conv2d(3, 64, kernel_size=3, padding=0),
22 |             nn.BatchNorm2d(64, momentum=1),
23 |             nn.ReLU(inplace=True),
24 |             nn.MaxPool2d(kernel_size=2),
25 | 
26 |             nn.Conv2d(64, 64, kernel_size=3, padding=0),
27 |             nn.BatchNorm2d(64, momentum=1),
28 |             nn.ReLU(inplace=True),
29 |             nn.MaxPool2d(kernel_size=2),
30 | 
31 |             nn.Conv2d(64, 64, kernel_size=3, padding=1),
32 |             nn.BatchNorm2d(64, momentum=1),
33 |             nn.ReLU(inplace=True),
34 | 
35 |             nn.Conv2d(64, 64, kernel_size=3, padding=1),
36 |             nn.BatchNorm2d(64, momentum=1),
37 |             nn.ReLU(inplace=True),
38 |         )
39 | 
40 |     def load(self, src):
41 |         """
42 |         Load parameters from the central pool.
43 |         :param src:
44 |         :return:
45 |         """
46 |         self.load_state_dict(src.state_dict())
47 | 
48 |     def forward(self, x):
49 |         x = self.net(x)
50 | 
51 |         return x
52 | 
53 | 
54 | class Relation(nn.Module):
55 | 
56 |     def __init__(self):
57 |         super(Relation, self).__init__()
58 | 
59 |         self.g = nn.Sequential(
60 |             nn.Linear(2 * (64 + 2), 256),
61 |             nn.ReLU(inplace=True),
62 |             nn.Linear(256, 256),
63 |             nn.ReLU(inplace=True),
64 |             nn.Linear(256, 256),
65 |             nn.ReLU(inplace=True)
66 |         )
67 | 
68 |         self.f = nn.Sequential(
69 |             nn.Linear(256, 256),
70 |             nn.ReLU(inplace=True),
71 |             nn.Linear(256, 256),
72 |             nn.ReLU(inplace=True),
73 |             nn.Linear(256, 256),
74 |             nn.ReLU(inplace=True)
75 |         )
76 | 
77 |     def forward(self, x):
78 |         pass
79 | 
80 | 
81 | class OutLayer(nn.Module):
82 | 
83 |     def __init__(self):
84 |         super(OutLayer, self).__init__()
85 | 
86 |         self.net = nn.Sequential(
87 |             nn.Linear(64 * 3 * 3, 5)
88 |         )
89 | 
90 |     def forward(self, x):
91 |         # downsample
92 |         x = F.avg_pool2d(x, 5, 5)
93 |         # flatten
94 |         x = x.view(x.size(0), -1)
95 |         # print(x.size())
96 |         return self.net(x)
97 | 
98 | 
99 | def inner_train(K, gpuidx, support_x, support_y, query_x, query_y, concepts, Q):
100 |     """
101 |     inner-loop
train function. 102 | :param K: train iterations 103 | :param gpuidx: which gpu to train 104 | :param support_x: [b, setsz, c_, h, w] 105 | :param support_y: [] 106 | :param query_x: [b, querysz] 107 | :param query_y: 108 | :param concepts: concepts network 109 | :param Q: Queue to receive result 110 | :return: 111 | """ 112 | # 113 | assert support_x.size(0) == query_x.size(0) 114 | # move current tensor into working GPU card. 115 | support_x = support_x.cuda(gpuidx) 116 | support_y = support_y.cuda(gpuidx) 117 | query_x = query_x.cuda(gpuidx) 118 | query_y = query_y.cuda(gpuidx) 119 | 120 | support_db = TensorDataset(support_x, support_y) 121 | query_db = TensorDataset(query_x, query_y) 122 | 123 | # this is inner-loop, update for K steps for one task. 124 | outlayer = OutLayer().cuda(gpuidx) 125 | criteon = nn.CrossEntropyLoss().cuda(gpuidx) 126 | # this is inner-loop optimizer, and corresponding lr stands for update-lr 127 | optimizer = optim.Adam(outlayer.parameters(), lr=1e-3) 128 | 129 | # right = [0] on gpuidx 130 | right = Variable(torch.zeros(1).cuda(gpuidx)) 131 | # loss = [0] on gpuidx 132 | loss = Variable(torch.zeros(1).cuda(gpuidx)) 133 | for (support_xb, support_yb), (query_xb, query_yb) in zip(support_db, query_db): 134 | # support_xb: [setsz, c_, h, w] 135 | # support_yb: [setsz] 136 | # query_xb : [querysz, c_, h, w] 137 | # query_yb : [querysz] 138 | # 1. meta-train for K iterations on meta-train dataset 139 | for i in range(K): 140 | # get the representation from concept-network 141 | x = concepts[gpuidx](support_xb) 142 | # detach gradient backpropagation 143 | x = x.detach() 144 | # push to outlayer-network 145 | logits = outlayer(x) 146 | # compute loss 147 | loss = criteon(logits, support_yb) 148 | 149 | # backward 150 | outlayer.zero_grad() 151 | loss.backward() 152 | optimizer.step() 153 | 154 | # 2. meta-test on meta-test dataset 155 | x = concepts[gpuidx](query_xb) 156 | # [querysz, nway] [querysz] 157 | logits = outlayer(x) 158 | _, idx = logits.max(1) 159 | # convert ByteTensor to LongTensor 160 | pred = idx.long() 161 | 162 | # 3. accumulate all right num and loss 163 | # torch.eq() return with ByteTensor 164 | # we use logits to compute loss while use pred to calculate accuracy 165 | right += torch.eq(pred, query_yb).sum().float() 166 | loss += criteon(logits, query_yb) 167 | 168 | # compute accuracy 169 | accuracy = right.data[0] / np.array(query_y.size()).prod() 170 | 171 | print(gpuidx, loss.data[0], accuracy) 172 | # save meta-test-loss into Queue for current task 173 | # just save data, not TENSOR 174 | Q.put([gpuidx, loss.data[0], accuracy]) 175 | 176 | del outlayer, criteon 177 | print('removed outlayer and criteon.') 178 | 179 | 180 | class CSML: 181 | """ 182 | Concept-Sharing Meta-Learning 183 | """ 184 | 185 | def __init__(self): 186 | 187 | # num of task training in parallel 188 | self.N = 3 189 | # inner-loop update iteration 190 | self.K = 10 191 | 192 | # each task has individual concept and output network, we deploy them on distinct GPUs and 193 | # merge into a list. 194 | self.concepts = [] 195 | self.outlayers = [] 196 | self.optimizer = None 197 | 198 | # to save async multi-tasks' loss and accuracy 199 | self.Q = multiprocessing.Queue() 200 | 201 | print('please call deploy() func to deploy networks. 
DO NOT call cuda() explicitly.')
202 | 
203 |     def deploy(self):
204 |         # deploy N tasks on a distributed GPU cluster and
205 |         # append the instances into lists
206 |         for i in range(self.N):
207 |             concept = Concept().cuda(i)
208 |             outlayer = OutLayer().cuda(i)
209 |             self.concepts.append(concept)
210 |             self.outlayers.append(outlayer)
211 | 
212 |         # meta optimizer
213 |         self.optimizer = optim.Adam(self.concepts[0].parameters(), lr=1e-3)
214 |         print('deploy done.')
215 | 
216 |     def train(self, support_x, support_y, query_x, query_y, train=True):
217 |         """
218 |         This is the meta-train and meta-test function.
219 |         :param support_x: [batchsz, setsz, c_, h, w]
220 |         :param support_y: [batchsz, setsz]
221 |         :param query_x: [batchsz, querysz]
222 |         :param query_y: [batchsz, querysz]
223 |         :return:
224 |         """
225 |         # we need to split the single batch into several chunks to process asynchronously
226 |         batchsz = support_x.size(0)
227 |         support_xb = torch.chunk(support_x, self.N)
228 |         support_yb = torch.chunk(support_y, self.N)
229 |         query_xb = torch.chunk(query_x, self.N)
230 |         query_yb = torch.chunk(query_y, self.N)
231 | 
232 |         # 1. download the latest Concept Network weights from the central pool,
233 |         # here downloaded from GPU0
234 |         for i in range(1, self.N):
235 |             self.concepts[i].load(self.concepts[0])
236 | 
237 |         # 3. start training all tasks asynchronously
238 |         processes = []
239 |         for p in range(self.N):
240 |             p = multiprocessing.Process(target=inner_train,
241 |                                         args=(self.K, p, support_xb[p], support_yb[p], query_xb[p], query_yb[p],
242 |                                               self.concepts, self.Q))
243 |             p.start()
244 |             processes.append(p)
245 |         for p in processes:
246 |             p.join()
247 | 
248 |         print('join completed.')
249 |         # 4. merge results
250 |         # until here, we have executed all tasks on the GPU cluster in parallel.
251 |         data = [self.Q.get_nowait() for _ in range(self.N)]
252 |         accuracy = np.array([i[2] for i in data]).mean()
253 |         meta_train_loss = np.array([i[1] for i in data]).astype(np.float32).sum()
254 |         # meta_train_loss = Variable(torch.FloatTensor(meta_train_loss).cuda(0))
255 | 
256 |         print('acc:', accuracy, 'meta-loss:', meta_train_loss)
257 | 
258 |         if train:
259 |             # compute gradients
260 |             # autograd.grad()
261 |             dummy_x = support_x[0][:2].cuda(0)
262 |             # update concept network.
263 |             self.optimizer.zero_grad()
264 |             # [2, c_, h, w]
265 |             dummy_loss = self.concepts[0](dummy_x)
266 |             dummy_loss.backward(torch.FloatTensor([1]).cuda(0))
267 |             self.optimizer.step()
268 | 
269 |         return accuracy

--------------------------------------------------------------------------------
/backup/mainv0.py:
--------------------------------------------------------------------------------
1 | from omniglotNShot import OmniglotNShot
2 | from MiniImagenet import MiniImagenet
3 | from csml import CSML
4 | 
5 | import torch
6 | from torch.autograd import Variable
7 | import numpy as np
8 | from tensorboardX import SummaryWriter
9 | from torch.utils.data import DataLoader
10 | 
11 | 
12 | def main():
13 |     meta_batchsz = 32 * 3
14 |     n_way = 5
15 |     k_shot = 5
16 |     k_query = k_shot
17 |     meta_lr = 1e-3
18 |     num_updates = 5
19 |     dataset = 'mini-imagenet'
20 | 
21 |     if dataset == 'omniglot':
22 |         imgsz = 28
23 |         db = OmniglotNShot('dataset', batchsz=meta_batchsz, n_way=n_way, k_shot=k_shot, k_query=k_query, imgsz=imgsz)
24 | 
25 |     elif dataset == 'mini-imagenet':
26 |         imgsz = 84
27 |         # the dataset loaders differ between omniglot and mini-imagenet. for omniglot, there is just one loader; use
28 |         # get_batch(train or test) to get different batches (an illustrative loader-recycling sketch follows).
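        # ------------------------------------------------------------------
        # Editorial aside (not in the original file): the try/except
        # StopIteration blocks further below re-create the DataLoader once
        # its 10000-episode pool is exhausted. A hypothetical helper that
        # wraps this recycling pattern into an endless iterator:
        def _endless(loader_factory):
            # loader_factory: zero-argument callable returning a fresh DataLoader
            while True:
                for batch in loader_factory():
                    yield batch
        # usage sketch (not wired into the flow below):
        #     batches = _endless(lambda: DataLoader(mini, meta_batchsz, shuffle=True))
        #     batch_train = next(batches)
        # ------------------------------------------------------------------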
29 |         # for mini-imagenet, there should be two dataloaders: a train_loader and a test_loader.
30 |         mini = MiniImagenet('../../hdd1/meta/mini-imagenet/', mode='train', n_way=n_way, k_shot=k_shot, k_query=k_query,
31 |                             batchsz=10000, resize=imgsz)
32 |         db = DataLoader(mini, meta_batchsz, shuffle=True, num_workers=4, pin_memory=True)
33 |         mini_test = MiniImagenet('../../hdd1/meta/mini-imagenet/', mode='test', n_way=n_way, k_shot=k_shot,
34 |                                  k_query=k_query,
35 |                                  batchsz=1000, resize=imgsz)
36 |         db_test = DataLoader(mini_test, meta_batchsz, shuffle=True, num_workers=2, pin_memory=True)
37 | 
38 |     else:
39 |         raise NotImplementedError
40 | 
41 |     # do NOT call .cuda() explicitly; CSML.deploy() places the networks itself
42 |     net = CSML()
43 |     net.deploy()
44 | 
45 |     tb = SummaryWriter('runs')
46 | 
47 |     # main loop
48 |     for episode_num in range(200000):
49 | 
50 |         # 1. train
51 |         if dataset == 'omniglot':
52 |             support_x, support_y, query_x, query_y = db.get_batch('train')
53 |             support_x = Variable(
54 |                 torch.from_numpy(support_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
55 |             query_x = Variable(
56 |                 torch.from_numpy(query_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
57 |             support_y = Variable(torch.from_numpy(support_y).long()).cuda()
58 |             query_y = Variable(torch.from_numpy(query_y).long()).cuda()
59 | 
60 |         elif dataset == 'mini-imagenet':
61 |             try:
62 |                 batch_train = next(iter(db))
63 |             except StopIteration as err:
64 |                 mini = MiniImagenet('../../hdd1/meta/mini-imagenet/', mode='train', n_way=n_way, k_shot=k_shot,
65 |                                     k_query=k_query,
66 |                                     batchsz=10000, resize=imgsz)
67 |                 db = DataLoader(mini, meta_batchsz, shuffle=True, num_workers=4, pin_memory=True)
68 |                 batch_train = next(iter(db))
69 | 
70 |             support_x = Variable(batch_train[0])
71 |             support_y = Variable(batch_train[1])
72 |             query_x = Variable(batch_train[2])
73 |             query_y = Variable(batch_train[3])
74 |             print(support_x.size(), support_y.size())
75 | 
76 |         # backprop has been embedded in the forward func.
77 |         accs = net.train(support_x, support_y, query_x, query_y)
78 |         train_acc = np.array(accs).mean()
79 | 
80 |         # 2. test
81 |         if episode_num % 30 == 0:
82 |             test_accs = []
83 |             for i in range(min(episode_num // 5000 + 3, 10)):  # get average acc.
84 |                 if dataset == 'omniglot':
85 |                     support_x, support_y, query_x, query_y = db.get_batch('test')
86 |                     support_x = Variable(
87 |                         torch.from_numpy(support_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1,
88 |                                                                                                    1)).cuda()
89 |                     query_x = Variable(
90 |                         torch.from_numpy(query_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
91 |                     support_y = Variable(torch.from_numpy(support_y).long()).cuda()
92 |                     query_y = Variable(torch.from_numpy(query_y).long()).cuda()
93 | 
94 |                 elif dataset == 'mini-imagenet':
95 |                     try:
96 |                         batch_test = next(iter(db_test))
97 |                     except StopIteration as err:
98 |                         mini_test = MiniImagenet('../../hdd1/meta/mini-imagenet/', mode='test', n_way=n_way,
99 |                                                  k_shot=k_shot,
100 |                                                  k_query=k_query,
101 |                                                  batchsz=1000, resize=imgsz)
102 |                         db_test = DataLoader(mini_test, meta_batchsz, shuffle=True, num_workers=2, pin_memory=True)
103 |                         batch_test = next(iter(db_test))
104 | 
105 |                     support_x = Variable(batch_test[0])
106 |                     support_y = Variable(batch_test[1])
107 |                     query_x = Variable(batch_test[2])
108 |                     query_y = Variable(batch_test[3])
109 | 
110 |                 # get accuracy
111 |                 test_acc = net.train(support_x, support_y, query_x, query_y, train=False)
112 |                 test_accs.append(test_acc)
113 | 
114 |             test_acc = np.array(test_accs).mean()
115 |             print('episode:', episode_num, '\tfinetune acc:%.6f' % train_acc, '\t\ttest acc:%.6f' % test_acc)
116 |             tb.add_scalar('test-acc', test_acc)
117 |             tb.add_scalar('finetune-acc', train_acc)
118 | 
119 | 
120 | if __name__ == '__main__':
121 |     main()
122 | 

--------------------------------------------------------------------------------
/backup/naive5_train.py:
--------------------------------------------------------------------------------
1 | import torch, os
2 | import numpy as np
3 | from torch import optim
4 | from torch.autograd import Variable
5 | from MiniImagenet import MiniImagenet
6 | from naive5 import Naive5
7 | import scipy.stats
8 | from torch.utils.data import DataLoader
9 | from torch.optim import lr_scheduler
10 | import random, sys, pickle
11 | import argparse
12 | from torch import nn
13 | 
14 | global_train_acc_buff = 0
15 | global_train_loss_buff = 0
16 | global_test_acc_buff = 0
17 | global_test_loss_buff = 0
18 | global_buff = []
19 | 
20 | 
21 | def write2file(n_way, k_shot):
22 |     global_buff.append([global_train_loss_buff, global_train_acc_buff, global_test_loss_buff, global_test_acc_buff])
23 |     with open("mini%d%d.pkl" % (n_way, k_shot), "wb") as fp:
24 |         pickle.dump(global_buff, fp)
25 | 
26 | 
27 | def mean_confidence_interval(accs, confidence=0.95):
28 |     n = accs.shape[0]
29 |     m, se = np.mean(accs), scipy.stats.sem(accs)
30 |     h = se * scipy.stats.t._ppf((1 + confidence) / 2, n - 1)
31 |     return m, h
32 | 
33 | 
34 | # save best acc info, to save the best model to ckpt.
35 | best_accuracy = 0
36 | 
37 | 
38 | def evaluation(net, batchsz, n_way, k_shot, imgsz, episodesz, threhold, mdl_file):
39 |     """
40 |     Following the experiment settings of MAML and Learning2Compare, we randomly sample 600 episodes with 15 query images
41 |     per query set.
42 | :param net: 43 | :param batchsz: 44 | :return: 45 | """ 46 | k_query = 15 47 | mini_val = MiniImagenet('../mini-imagenet/', mode='test', n_way=n_way, k_shot=k_shot, k_query=k_query, 48 | batchsz=600, resize=imgsz) 49 | db_val = DataLoader(mini_val, batchsz, shuffle=True, num_workers=2, pin_memory=True) 50 | 51 | accs = [] 52 | episode_num = 0 # record tested num of episodes 53 | 54 | for batch_test in db_val: 55 | # [60, setsz, c_, h, w] 56 | # setsz = (5 + 15) * 5 57 | support_x = Variable(batch_test[0]).cuda() 58 | support_y = Variable(batch_test[1]).cuda() 59 | query_x = Variable(batch_test[2]).cuda() 60 | query_y = Variable(batch_test[3]).cuda() 61 | 62 | # we will split query set into 15 splits. 63 | # query_x : [batch, 15*way, c_, h, w] 64 | # query_x_b : tuple, 15 * [b, way, c_, h, w] 65 | query_x_b = torch.chunk(query_x, k_query, dim=1) 66 | # query_y : [batch, 15*way] 67 | # query_y_b: 15* [b, way] 68 | query_y_b = torch.chunk(query_y, k_query, dim=1) 69 | preds = [] 70 | net.eval() 71 | # we don't need the total acc on 600 episodes, but we need the acc per sets of 15*nway setsz. 72 | total_correct = 0 73 | total_num = 0 74 | total_loss = 0 75 | for query_x_mini, query_y_mini in zip(query_x_b, query_y_b): 76 | # print('query_x_mini', query_x_mini.size(), 'query_y_mini', query_y_mini.size()) 77 | loss, pred, correct = net(support_x, support_y, query_x_mini.contiguous(), query_y_mini, False) 78 | correct = correct.sum() # multi-gpu 79 | # pred: [b, nway] 80 | preds.append(pred) 81 | total_correct += correct.data[0] 82 | total_num += query_y_mini.size(0) * query_y_mini.size(1) 83 | 84 | total_loss += loss.data[0] 85 | 86 | # # 15 * [b, nway] => [b, 15*nway] 87 | # preds = torch.cat(preds, dim= 1) 88 | acc = total_correct / total_num 89 | print('%.5f,' % acc, end=' ') 90 | sys.stdout.flush() 91 | accs.append(acc) 92 | 93 | # update tested episode number 94 | episode_num += query_y.size(0) 95 | if episode_num > episodesz: 96 | # test current tested episodes acc. 97 | acc = np.array(accs).mean() 98 | if acc >= threhold: 99 | # if current acc is very high, we conduct all 600 episodes testing. 100 | continue 101 | else: 102 | # current acc is low, just conduct `episodesz` num of episodes. 103 | break 104 | 105 | # compute the distribution of 600/episodesz episodes acc. 
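    # ------------------------------------------------------------------
    # Editorial note (worked example, not in the original file):
    # mean_confidence_interval() above returns (m, h) where
    #     h = t_{(1+c)/2, n-1} * SEM,   SEM = std(accs) / sqrt(n).
    # For example, with c = 0.95 and n = 600 tested episodes,
    # t_{0.975, 599} ~ 1.964, so an episode-accuracy std of 0.10 gives
    #     h ~ 1.964 * 0.10 / sqrt(600) ~ 0.008,
    # i.e. an "accuracy +/- 0.8%" style confidence interval.
    # ------------------------------------------------------------------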
106 | global best_accuracy 107 | accs = np.array(accs) 108 | accuracy, sem = mean_confidence_interval(accs) 109 | print('\naccuracy:', accuracy, 'sem:', sem) 110 | print('<<<<<<<<< accuracy:', accuracy, 'best accuracy:', best_accuracy, '>>>>>>>>') 111 | 112 | if accuracy > best_accuracy: 113 | best_accuracy = accuracy 114 | torch.save(net.state_dict(), mdl_file) 115 | print('Saved to checkpoint:', mdl_file) 116 | 117 | # we only take the last one batch as avg_loss 118 | total_loss = total_loss / n_way / k_query 119 | 120 | global global_test_loss_buff, global_test_acc_buff 121 | global_test_loss_buff = total_loss 122 | global_test_acc_buff = accuracy 123 | write2file(n_way, k_shot) 124 | 125 | return accuracy, sem 126 | 127 | 128 | def main(): 129 | argparser = argparse.ArgumentParser() 130 | argparser.add_argument('-n', help='n way') 131 | argparser.add_argument('-k', help='k shot') 132 | argparser.add_argument('-b', help='batch size') 133 | argparser.add_argument('-l', help='learning rate', default=1e-3) 134 | args = argparser.parse_args() 135 | n_way = int(args.n) 136 | k_shot = int(args.k) 137 | batchsz = int(args.b) 138 | lr = float(args.l) 139 | 140 | k_query = 1 141 | imgsz = 224 142 | threhold = 0.699 if k_shot == 5 else 0.584 # threshold for when to test full version of episode 143 | mdl_file = 'ckpt/naive5_3x3%d%d.mdl' % (n_way, k_shot) 144 | print('mini-imagnet: %d-way %d-shot lr:%f, threshold:%f' % (n_way, k_shot, lr, threhold)) 145 | 146 | global global_buff 147 | if os.path.exists('mini%d%d.pkl' % (n_way, k_shot)): 148 | global_buff = pickle.load(open('mini%d%d.pkl' % (n_way, k_shot), 'rb')) 149 | print('load pkl buff:', len(global_buff)) 150 | 151 | net = nn.DataParallel(Naive5(n_way, k_shot, imgsz), device_ids=[0, 1, 2]).cuda() 152 | print(net) 153 | 154 | if os.path.exists(mdl_file): 155 | print('load from checkpoint ...', mdl_file) 156 | net.load_state_dict(torch.load(mdl_file)) 157 | else: 158 | print('training from scratch.') 159 | 160 | # whole parameters number 161 | model_parameters = filter(lambda p: p.requires_grad, net.parameters()) 162 | params = sum([np.prod(p.size()) for p in model_parameters]) 163 | print('Total params:', params) 164 | 165 | # build optimizer and lr scheduler 166 | optimizer = optim.Adam(net.parameters(), lr=lr) 167 | # optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, nesterov=True) 168 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'max', factor=0.5, patience=25, verbose=True) 169 | 170 | for epoch in range(1000): 171 | mini = MiniImagenet('../mini-imagenet/', mode='train', n_way=n_way, k_shot=k_shot, k_query=k_query, 172 | batchsz=10000, resize=imgsz) 173 | db = DataLoader(mini, batchsz, shuffle=True, num_workers=8, pin_memory=True) 174 | total_train_loss = 0 175 | total_train_correct = 0 176 | total_train_num = 0 177 | 178 | for step, batch in enumerate(db): 179 | # 1. test 180 | if step % 300 == 0: 181 | # evaluation(net, batchsz, n_way, k_shot, imgsz, episodesz, threhold, mdl_file): 182 | accuracy, sem = evaluation(net, batchsz, n_way, k_shot, imgsz, 600, threhold, mdl_file) 183 | scheduler.step(accuracy) 184 | 185 | # 2. 
train 186 | support_x = Variable(batch[0]).cuda() 187 | support_y = Variable(batch[1]).cuda() 188 | query_x = Variable(batch[2]).cuda() 189 | query_y = Variable(batch[3]).cuda() 190 | 191 | net.train() 192 | loss, pred, correct = net(support_x, support_y, query_x, query_y) 193 | loss = loss.sum() / support_x.size(0) # multi-gpu, divide by total batchsz 194 | total_train_loss += loss.data[0] 195 | total_train_correct += correct.data[0] 196 | total_train_num += support_y.size(0) * n_way # k_query = 1 197 | 198 | optimizer.zero_grad() 199 | loss.backward() 200 | optimizer.step() 201 | 202 | # 3. print 203 | if step % 20 == 0 and step != 0: 204 | acc = total_train_correct / total_train_num 205 | total_train_correct = 0 206 | total_train_num = 0 207 | 208 | print('%d-way %d-shot %d batch> epoch:%d step:%d, loss:%.4f, train acc:%.4f' % ( 209 | n_way, k_shot, batchsz, epoch, step, total_train_loss, acc)) 210 | total_train_loss = 0 211 | 212 | global global_train_loss_buff, global_train_acc_buff 213 | global_train_loss_buff = loss.data[0] / (n_way * k_shot) 214 | global_train_acc_buff = acc 215 | write2file(n_way, k_shot) 216 | 217 | 218 | if __name__ == '__main__': 219 | main() 220 | -------------------------------------------------------------------------------- /learner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | 6 | 7 | 8 | class Learner(nn.Module): 9 | """ 10 | 11 | """ 12 | 13 | def __init__(self, config, imgc, imgsz): 14 | """ 15 | 16 | :param config: network config file, type:list of (string, list) 17 | :param imgc: 1 or 3 18 | :param imgsz: 28 or 84 19 | """ 20 | super(Learner, self).__init__() 21 | 22 | 23 | self.config = config 24 | 25 | # this dict contains all tensors needed to be optimized 26 | self.vars = nn.ParameterList() 27 | # running_mean and running_var 28 | self.vars_bn = nn.ParameterList() 29 | 30 | for i, (name, param) in enumerate(self.config): 31 | if name is 'conv2d': 32 | # [ch_out, ch_in, kernelsz, kernelsz] 33 | w = nn.Parameter(torch.ones(*param[:4])) 34 | # gain=1 according to cbfin's implementation 35 | torch.nn.init.kaiming_normal_(w) 36 | self.vars.append(w) 37 | # [ch_out] 38 | self.vars.append(nn.Parameter(torch.zeros(param[0]))) 39 | 40 | elif name is 'convt2d': 41 | # [ch_in, ch_out, kernelsz, kernelsz, stride, padding] 42 | w = nn.Parameter(torch.ones(*param[:4])) 43 | # gain=1 according to cbfin's implementation 44 | torch.nn.init.kaiming_normal_(w) 45 | self.vars.append(w) 46 | # [ch_in, ch_out] 47 | self.vars.append(nn.Parameter(torch.zeros(param[1]))) 48 | 49 | elif name is 'linear': 50 | # [ch_out, ch_in] 51 | w = nn.Parameter(torch.ones(*param)) 52 | # gain=1 according to cbfinn's implementation 53 | torch.nn.init.kaiming_normal_(w) 54 | self.vars.append(w) 55 | # [ch_out] 56 | self.vars.append(nn.Parameter(torch.zeros(param[0]))) 57 | 58 | elif name is 'bn': 59 | # [ch_out] 60 | w = nn.Parameter(torch.ones(param[0])) 61 | self.vars.append(w) 62 | # [ch_out] 63 | self.vars.append(nn.Parameter(torch.zeros(param[0]))) 64 | 65 | # must set requires_grad=False 66 | running_mean = nn.Parameter(torch.zeros(param[0]), requires_grad=False) 67 | running_var = nn.Parameter(torch.ones(param[0]), requires_grad=False) 68 | self.vars_bn.extend([running_mean, running_var]) 69 | 70 | 71 | elif name in ['tanh', 'relu', 'upsample', 'avg_pool2d', 'max_pool2d', 72 | 'flatten', 'reshape', 'leakyrelu', 'sigmoid']: 73 | 
continue 74 | else: 75 | raise NotImplementedError 76 | 77 | 78 | 79 | 80 | 81 | 82 | def extra_repr(self): 83 | info = '' 84 | 85 | for name, param in self.config: 86 | if name is 'conv2d': 87 | tmp = 'conv2d:(ch_in:%d, ch_out:%d, k:%dx%d, stride:%d, padding:%d)'\ 88 | %(param[1], param[0], param[2], param[3], param[4], param[5],) 89 | info += tmp + '\n' 90 | 91 | elif name is 'convt2d': 92 | tmp = 'convTranspose2d:(ch_in:%d, ch_out:%d, k:%dx%d, stride:%d, padding:%d)'\ 93 | %(param[0], param[1], param[2], param[3], param[4], param[5],) 94 | info += tmp + '\n' 95 | 96 | elif name is 'linear': 97 | tmp = 'linear:(in:%d, out:%d)'%(param[1], param[0]) 98 | info += tmp + '\n' 99 | 100 | elif name is 'leakyrelu': 101 | tmp = 'leakyrelu:(slope:%f)'%(param[0]) 102 | info += tmp + '\n' 103 | 104 | 105 | elif name is 'avg_pool2d': 106 | tmp = 'avg_pool2d:(k:%d, stride:%d, padding:%d)'%(param[0], param[1], param[2]) 107 | info += tmp + '\n' 108 | elif name is 'max_pool2d': 109 | tmp = 'max_pool2d:(k:%d, stride:%d, padding:%d)'%(param[0], param[1], param[2]) 110 | info += tmp + '\n' 111 | elif name in ['flatten', 'tanh', 'relu', 'upsample', 'reshape', 'sigmoid', 'use_logits', 'bn']: 112 | tmp = name + ':' + str(tuple(param)) 113 | info += tmp + '\n' 114 | else: 115 | raise NotImplementedError 116 | 117 | return info 118 | 119 | 120 | 121 | def forward(self, x, vars=None, bn_training=True): 122 | """ 123 | This function can be called by finetunning, however, in finetunning, we dont wish to update 124 | running_mean/running_var. Thought weights/bias of bn is updated, it has been separated by fast_weights. 125 | Indeed, to not update running_mean/running_var, we need set update_bn_statistics=False 126 | but weight/bias will be updated and not dirty initial theta parameters via fast_weiths. 127 | :param x: [b, 1, 28, 28] 128 | :param vars: 129 | :param bn_training: set False to not update 130 | :return: x, loss, likelihood, kld 131 | """ 132 | 133 | if vars is None: 134 | vars = self.vars 135 | 136 | idx = 0 137 | bn_idx = 0 138 | 139 | for name, param in self.config: 140 | if name is 'conv2d': 141 | w, b = vars[idx], vars[idx + 1] 142 | # remember to keep synchrozied of forward_encoder and forward_decoder! 143 | x = F.conv2d(x, w, b, stride=param[4], padding=param[5]) 144 | idx += 2 145 | # print(name, param, '\tout:', x.shape) 146 | elif name is 'convt2d': 147 | w, b = vars[idx], vars[idx + 1] 148 | # remember to keep synchrozied of forward_encoder and forward_decoder! 
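                # ------------------------------------------------------
                # Editorial note (not in the original file): the purely
                # functional F.* calls in this forward pass are what make
                # the meta-learning work. Because weights arrive through
                # the explicit `vars` argument instead of being owned by
                # nn.Module layers, the caller can re-run this same
                # network with per-task "fast weights" (theta - lr*grad)
                # while the autograd graph stays connected to the
                # original theta, enabling second-order meta-gradients.
                # ------------------------------------------------------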
149 | x = F.conv_transpose2d(x, w, b, stride=param[4], padding=param[5]) 150 | idx += 2 151 | # print(name, param, '\tout:', x.shape) 152 | elif name is 'linear': 153 | w, b = vars[idx], vars[idx + 1] 154 | x = F.linear(x, w, b) 155 | idx += 2 156 | # print('forward:', idx, x.norm().item()) 157 | elif name is 'bn': 158 | w, b = vars[idx], vars[idx + 1] 159 | running_mean, running_var = self.vars_bn[bn_idx], self.vars_bn[bn_idx+1] 160 | x = F.batch_norm(x, running_mean, running_var, weight=w, bias=b, training=bn_training) 161 | idx += 2 162 | bn_idx += 2 163 | 164 | elif name is 'flatten': 165 | # print(x.shape) 166 | x = x.view(x.size(0), -1) 167 | elif name is 'reshape': 168 | # [b, 8] => [b, 2, 2, 2] 169 | x = x.view(x.size(0), *param) 170 | elif name is 'relu': 171 | x = F.relu(x, inplace=param[0]) 172 | elif name is 'leakyrelu': 173 | x = F.leaky_relu(x, negative_slope=param[0], inplace=param[1]) 174 | elif name is 'tanh': 175 | x = F.tanh(x) 176 | elif name is 'sigmoid': 177 | x = torch.sigmoid(x) 178 | elif name is 'upsample': 179 | x = F.upsample_nearest(x, scale_factor=param[0]) 180 | elif name is 'max_pool2d': 181 | x = F.max_pool2d(x, param[0], param[1], param[2]) 182 | elif name is 'avg_pool2d': 183 | x = F.avg_pool2d(x, param[0], param[1], param[2]) 184 | 185 | else: 186 | raise NotImplementedError 187 | 188 | # make sure variable is used properly 189 | assert idx == len(vars) 190 | assert bn_idx == len(self.vars_bn) 191 | 192 | 193 | return x 194 | 195 | 196 | def zero_grad(self, vars=None): 197 | """ 198 | 199 | :param vars: 200 | :return: 201 | """ 202 | with torch.no_grad(): 203 | if vars is None: 204 | for p in self.vars: 205 | if p.grad is not None: 206 | p.grad.zero_() 207 | else: 208 | for p in vars: 209 | if p.grad is not None: 210 | p.grad.zero_() 211 | 212 | def parameters(self): 213 | """ 214 | override this function since initial parameters will return with a generator. 215 | :return: 216 | """ 217 | return self.vars -------------------------------------------------------------------------------- /meta.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch import optim 4 | from torch.nn import functional as F 5 | from torch.utils.data import TensorDataset, DataLoader 6 | from torch import optim 7 | import numpy as np 8 | 9 | from learner import Learner 10 | from copy import deepcopy 11 | 12 | 13 | 14 | class Meta(nn.Module): 15 | """ 16 | Meta Learner 17 | """ 18 | def __init__(self, args, config): 19 | """ 20 | 21 | :param args: 22 | """ 23 | super(Meta, self).__init__() 24 | 25 | self.update_lr = args.update_lr 26 | self.meta_lr = args.meta_lr 27 | self.n_way = args.n_way 28 | self.k_spt = args.k_spt 29 | self.k_qry = args.k_qry 30 | self.task_num = args.task_num 31 | self.update_step = args.update_step 32 | self.update_step_test = args.update_step_test 33 | 34 | 35 | self.net = Learner(config, args.imgc, args.imgsz) 36 | self.meta_optim = optim.Adam(self.net.parameters(), lr=self.meta_lr) 37 | 38 | 39 | 40 | 41 | def clip_grad_by_norm_(self, grad, max_norm): 42 | """ 43 | in-place gradient clipping. 44 | :param grad: list of gradients 45 | :param max_norm: maximum norm allowable 46 | :return: 47 | """ 48 | 49 | total_norm = 0 50 | counter = 0 51 | for g in grad: 52 | param_norm = g.data.norm(2) 53 | total_norm += param_norm.item() ** 2 54 | counter += 1 55 | total_norm = total_norm ** (1. 
/ 2) 56 | 57 | clip_coef = max_norm / (total_norm + 1e-6) 58 | if clip_coef < 1: 59 | for g in grad: 60 | g.data.mul_(clip_coef) 61 | 62 | return total_norm/counter 63 | 64 | 65 | def forward(self, x_spt, y_spt, x_qry, y_qry): 66 | """ 67 | 68 | :param x_spt: [b, setsz, c_, h, w] 69 | :param y_spt: [b, setsz] 70 | :param x_qry: [b, querysz, c_, h, w] 71 | :param y_qry: [b, querysz] 72 | :return: 73 | """ 74 | task_num, setsz, c_, h, w = x_spt.size() 75 | querysz = x_qry.size(1) 76 | 77 | losses_q = [0 for _ in range(self.update_step + 1)] # losses_q[i] is the loss on step i 78 | corrects = [0 for _ in range(self.update_step + 1)] 79 | 80 | 81 | for i in range(task_num): 82 | 83 | # 1. run the i-th task and compute loss for k=0 84 | logits = self.net(x_spt[i], vars=None, bn_training=True) 85 | loss = F.cross_entropy(logits, y_spt[i]) 86 | grad = torch.autograd.grad(loss, self.net.parameters()) 87 | fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, self.net.parameters()))) 88 | 89 | # this is the loss and accuracy before first update 90 | with torch.no_grad(): 91 | # [setsz, nway] 92 | logits_q = self.net(x_qry[i], self.net.parameters(), bn_training=True) 93 | loss_q = F.cross_entropy(logits_q, y_qry[i]) 94 | losses_q[0] += loss_q 95 | 96 | pred_q = F.softmax(logits_q, dim=1).argmax(dim=1) 97 | correct = torch.eq(pred_q, y_qry[i]).sum().item() 98 | corrects[0] = corrects[0] + correct 99 | 100 | # this is the loss and accuracy after the first update 101 | with torch.no_grad(): 102 | # [setsz, nway] 103 | logits_q = self.net(x_qry[i], fast_weights, bn_training=True) 104 | loss_q = F.cross_entropy(logits_q, y_qry[i]) 105 | losses_q[1] += loss_q 106 | # [setsz] 107 | pred_q = F.softmax(logits_q, dim=1).argmax(dim=1) 108 | correct = torch.eq(pred_q, y_qry[i]).sum().item() 109 | corrects[1] = corrects[1] + correct 110 | 111 | for k in range(1, self.update_step): 112 | # 1. run the i-th task and compute loss for k=1~K-1 113 | logits = self.net(x_spt[i], fast_weights, bn_training=True) 114 | loss = F.cross_entropy(logits, y_spt[i]) 115 | # 2. compute grad on theta_pi 116 | grad = torch.autograd.grad(loss, fast_weights) 117 | # 3. theta_pi = theta_pi - train_lr * grad 118 | fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights))) 119 | 120 | logits_q = self.net(x_qry[i], fast_weights, bn_training=True) 121 | # loss_q will be overwritten and just keep the loss_q on last update step. 
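                # ------------------------------------------------------
                # Editorial note (not in the original file): this is the
                # MAML inner/outer structure. Inner loop, per task i:
                #     fast_weights <- fast_weights - update_lr * grad(L_spt)
                # Outer step (after all tasks, at the end of forward()):
                #     theta <- theta - meta_lr * grad(mean_i L_qry(fast_weights_i))
                # Only losses_q[-1] is backpropagated; the earlier entries
                # are kept just to report accuracy per update step.
                # ------------------------------------------------------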
122 | loss_q = F.cross_entropy(logits_q, y_qry[i]) 123 | losses_q[k + 1] += loss_q 124 | 125 | with torch.no_grad(): 126 | pred_q = F.softmax(logits_q, dim=1).argmax(dim=1) 127 | correct = torch.eq(pred_q, y_qry[i]).sum().item() # convert to numpy 128 | corrects[k + 1] = corrects[k + 1] + correct 129 | 130 | 131 | 132 | # end of all tasks 133 | # sum over all losses on query set across all tasks 134 | loss_q = losses_q[-1] / task_num 135 | 136 | # optimize theta parameters 137 | self.meta_optim.zero_grad() 138 | loss_q.backward() 139 | # print('meta update') 140 | # for p in self.net.parameters()[:5]: 141 | # print(torch.norm(p).item()) 142 | self.meta_optim.step() 143 | 144 | 145 | accs = np.array(corrects) / (querysz * task_num) 146 | 147 | return accs 148 | 149 | 150 | def finetunning(self, x_spt, y_spt, x_qry, y_qry): 151 | """ 152 | 153 | :param x_spt: [setsz, c_, h, w] 154 | :param y_spt: [setsz] 155 | :param x_qry: [querysz, c_, h, w] 156 | :param y_qry: [querysz] 157 | :return: 158 | """ 159 | assert len(x_spt.shape) == 4 160 | 161 | querysz = x_qry.size(0) 162 | 163 | corrects = [0 for _ in range(self.update_step_test + 1)] 164 | 165 | # in order to not ruin the state of running_mean/variance and bn_weight/bias 166 | # we finetunning on the copied model instead of self.net 167 | net = deepcopy(self.net) 168 | 169 | # 1. run the i-th task and compute loss for k=0 170 | logits = net(x_spt) 171 | loss = F.cross_entropy(logits, y_spt) 172 | grad = torch.autograd.grad(loss, net.parameters()) 173 | fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, net.parameters()))) 174 | 175 | # this is the loss and accuracy before first update 176 | with torch.no_grad(): 177 | # [setsz, nway] 178 | logits_q = net(x_qry, net.parameters(), bn_training=True) 179 | # [setsz] 180 | pred_q = F.softmax(logits_q, dim=1).argmax(dim=1) 181 | # scalar 182 | correct = torch.eq(pred_q, y_qry).sum().item() 183 | corrects[0] = corrects[0] + correct 184 | 185 | # this is the loss and accuracy after the first update 186 | with torch.no_grad(): 187 | # [setsz, nway] 188 | logits_q = net(x_qry, fast_weights, bn_training=True) 189 | # [setsz] 190 | pred_q = F.softmax(logits_q, dim=1).argmax(dim=1) 191 | # scalar 192 | correct = torch.eq(pred_q, y_qry).sum().item() 193 | corrects[1] = corrects[1] + correct 194 | 195 | for k in range(1, self.update_step_test): 196 | # 1. run the i-th task and compute loss for k=1~K-1 197 | logits = net(x_spt, fast_weights, bn_training=True) 198 | loss = F.cross_entropy(logits, y_spt) 199 | # 2. compute grad on theta_pi 200 | grad = torch.autograd.grad(loss, fast_weights) 201 | # 3. theta_pi = theta_pi - train_lr * grad 202 | fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights))) 203 | 204 | logits_q = net(x_qry, fast_weights, bn_training=True) 205 | # loss_q will be overwritten and just keep the loss_q on last update step. 
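            # ----------------------------------------------------------
            # Editorial note (not in the original file): finetunning()
            # operates on deepcopy(self.net) (see above), so these
            # test-time updates, including any drift of the BN running
            # statistics, never touch the meta-trained initialization.
            # Also note that loss_q is computed here but never
            # backpropagated, since evaluation needs no outer update.
            # ----------------------------------------------------------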
206 | loss_q = F.cross_entropy(logits_q, y_qry) 207 | 208 | with torch.no_grad(): 209 | pred_q = F.softmax(logits_q, dim=1).argmax(dim=1) 210 | correct = torch.eq(pred_q, y_qry).sum().item() # convert to numpy 211 | corrects[k + 1] = corrects[k + 1] + correct 212 | 213 | 214 | del net 215 | 216 | accs = np.array(corrects) / querysz 217 | 218 | return accs 219 | 220 | 221 | 222 | 223 | def main(): 224 | pass 225 | 226 | 227 | if __name__ == '__main__': 228 | main() 229 | -------------------------------------------------------------------------------- /miniimagenet_train.py: -------------------------------------------------------------------------------- 1 | import torch, os 2 | import numpy as np 3 | from MiniImagenet import MiniImagenet 4 | import scipy.stats 5 | from torch.utils.data import DataLoader 6 | from torch.optim import lr_scheduler 7 | import random, sys, pickle 8 | import argparse 9 | 10 | from meta import Meta 11 | 12 | 13 | def mean_confidence_interval(accs, confidence=0.95): 14 | n = accs.shape[0] 15 | m, se = np.mean(accs), scipy.stats.sem(accs) 16 | h = se * scipy.stats.t._ppf((1 + confidence) / 2, n - 1) 17 | return m, h 18 | 19 | 20 | def main(): 21 | 22 | torch.manual_seed(222) 23 | torch.cuda.manual_seed_all(222) 24 | np.random.seed(222) 25 | 26 | print(args) 27 | 28 | config = [ 29 | ('conv2d', [32, 3, 3, 3, 1, 0]), 30 | ('relu', [True]), 31 | ('bn', [32]), 32 | ('max_pool2d', [2, 2, 0]), 33 | ('conv2d', [32, 32, 3, 3, 1, 0]), 34 | ('relu', [True]), 35 | ('bn', [32]), 36 | ('max_pool2d', [2, 2, 0]), 37 | ('conv2d', [32, 32, 3, 3, 1, 0]), 38 | ('relu', [True]), 39 | ('bn', [32]), 40 | ('max_pool2d', [2, 2, 0]), 41 | ('conv2d', [32, 32, 3, 3, 1, 0]), 42 | ('relu', [True]), 43 | ('bn', [32]), 44 | ('max_pool2d', [2, 1, 0]), 45 | ('flatten', []), 46 | ('linear', [args.n_way, 32 * 5 * 5]) 47 | ] 48 | 49 | device = torch.device('cuda') 50 | maml = Meta(args, config).to(device) 51 | 52 | tmp = filter(lambda x: x.requires_grad, maml.parameters()) 53 | num = sum(map(lambda x: np.prod(x.shape), tmp)) 54 | print(maml) 55 | print('Total trainable tensors:', num) 56 | 57 | # batchsz here means total episode number 58 | mini = MiniImagenet('/home/i/tmp/MAML-Pytorch/miniimagenet/', mode='train', n_way=args.n_way, k_shot=args.k_spt, 59 | k_query=args.k_qry, 60 | batchsz=10000, resize=args.imgsz) 61 | mini_test = MiniImagenet('/home/i/tmp/MAML-Pytorch/miniimagenet/', mode='test', n_way=args.n_way, k_shot=args.k_spt, 62 | k_query=args.k_qry, 63 | batchsz=100, resize=args.imgsz) 64 | 65 | for epoch in range(args.epoch//10000): 66 | # fetch meta_batchsz num of episode each time 67 | db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1, pin_memory=True) 68 | 69 | for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db): 70 | 71 | x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device) 72 | 73 | accs = maml(x_spt, y_spt, x_qry, y_qry) 74 | 75 | if step % 30 == 0: 76 | print('step:', step, '\ttraining acc:', accs) 77 | 78 | if step % 500 == 0: # evaluation 79 | db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True) 80 | accs_all_test = [] 81 | 82 | for x_spt, y_spt, x_qry, y_qry in db_test: 83 | x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \ 84 | x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device) 85 | 86 | accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry) 87 | accs_all_test.append(accs) 88 | 89 | # [b, update_step+1] 90 | accs = 
np.array(accs_all_test).mean(axis=0).astype(np.float16) 91 | print('Test acc:', accs) 92 | 93 | 94 | if __name__ == '__main__': 95 | 96 | argparser = argparse.ArgumentParser() 97 | argparser.add_argument('--epoch', type=int, help='epoch number', default=60000) 98 | argparser.add_argument('--n_way', type=int, help='n way', default=5) 99 | argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=1) 100 | argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=15) 101 | argparser.add_argument('--imgsz', type=int, help='imgsz', default=84) 102 | argparser.add_argument('--imgc', type=int, help='imgc', default=3) 103 | argparser.add_argument('--task_num', type=int, help='meta batch size, namely task num', default=4) 104 | argparser.add_argument('--meta_lr', type=float, help='meta-level outer learning rate', default=1e-3) 105 | argparser.add_argument('--update_lr', type=float, help='task-level inner update learning rate', default=0.01) 106 | argparser.add_argument('--update_step', type=int, help='task-level inner update steps', default=5) 107 | argparser.add_argument('--update_step_test', type=int, help='update steps for finetunning', default=10) 108 | 109 | args = argparser.parse_args() 110 | 111 | main() 112 | -------------------------------------------------------------------------------- /omniglot.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import os 3 | import os.path 4 | import errno 5 | 6 | 7 | class Omniglot(data.Dataset): 8 | urls = [ 9 | 'https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip', 10 | 'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip' 11 | ] 12 | raw_folder = 'raw' 13 | processed_folder = 'processed' 14 | training_file = 'training.pt' 15 | test_file = 'test.pt' 16 | 17 | ''' 18 | The items are (filename,category). The index of all the categories can be found in self.idx_classes 19 | Args: 20 | - root: the directory where the dataset will be stored 21 | - transform: how to transform the input 22 | - target_transform: how to transform the target 23 | - download: need to download the dataset 24 | ''' 25 | 26 | def __init__(self, root, transform=None, target_transform=None, download=False): 27 | self.root = root 28 | self.transform = transform 29 | self.target_transform = target_transform 30 | 31 | if not self._check_exists(): 32 | if download: 33 | self.download() 34 | else: 35 | raise RuntimeError('Dataset not found.' 
+ ' You can use download=True to download it') 36 | 37 | self.all_items = find_classes(os.path.join(self.root, self.processed_folder)) 38 | self.idx_classes = index_classes(self.all_items) 39 | 40 | def __getitem__(self, index): 41 | filename = self.all_items[index][0] 42 | img = str.join('/', [self.all_items[index][2], filename]) 43 | 44 | target = self.idx_classes[self.all_items[index][1]] 45 | if self.transform is not None: 46 | img = self.transform(img) 47 | if self.target_transform is not None: 48 | target = self.target_transform(target) 49 | 50 | return img, target 51 | 52 | def __len__(self): 53 | return len(self.all_items) 54 | 55 | def _check_exists(self): 56 | return os.path.exists(os.path.join(self.root, self.processed_folder, "images_evaluation")) and \ 57 | os.path.exists(os.path.join(self.root, self.processed_folder, "images_background")) 58 | 59 | def download(self): 60 | from six.moves import urllib 61 | import zipfile 62 | 63 | if self._check_exists(): 64 | return 65 | 66 | # download files 67 | try: 68 | os.makedirs(os.path.join(self.root, self.raw_folder)) 69 | os.makedirs(os.path.join(self.root, self.processed_folder)) 70 | except OSError as e: 71 | if e.errno == errno.EEXIST: 72 | pass 73 | else: 74 | raise 75 | 76 | for url in self.urls: 77 | print('== Downloading ' + url) 78 | data = urllib.request.urlopen(url) 79 | filename = url.rpartition('/')[2] 80 | file_path = os.path.join(self.root, self.raw_folder, filename) 81 | with open(file_path, 'wb') as f: 82 | f.write(data.read()) 83 | file_processed = os.path.join(self.root, self.processed_folder) 84 | print("== Unzip from " + file_path + " to " + file_processed) 85 | zip_ref = zipfile.ZipFile(file_path, 'r') 86 | zip_ref.extractall(file_processed) 87 | zip_ref.close() 88 | print("Download finished.") 89 | 90 | 91 | def find_classes(root_dir): 92 | retour = [] 93 | for (root, dirs, files) in os.walk(root_dir): 94 | for f in files: 95 | if (f.endswith("png")): 96 | r = root.split('/') 97 | lr = len(r) 98 | retour.append((f, r[lr - 2] + "/" + r[lr - 1], root)) 99 | print("== Found %d items " % len(retour)) 100 | return retour 101 | 102 | 103 | def index_classes(items): 104 | idx = {} 105 | for i in items: 106 | if i[1] not in idx: 107 | idx[i[1]] = len(idx) 108 | print("== Found %d classes" % len(idx)) 109 | return idx 110 | -------------------------------------------------------------------------------- /omniglotNShot.py: -------------------------------------------------------------------------------- 1 | from omniglot import Omniglot 2 | import torchvision.transforms as transforms 3 | from PIL import Image 4 | import os.path 5 | import numpy as np 6 | 7 | 8 | class OmniglotNShot: 9 | 10 | def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz): 11 | """ 12 | Different from mnistNShot, the 13 | :param root: 14 | :param batchsz: task num 15 | :param n_way: 16 | :param k_shot: 17 | :param k_qry: 18 | :param imgsz: 19 | """ 20 | 21 | self.resize = imgsz 22 | if not os.path.isfile(os.path.join(root, 'omniglot.npy')): 23 | # if root/data.npy does not exist, just download it 24 | self.x = Omniglot(root, download=True, 25 | transform=transforms.Compose([lambda x: Image.open(x).convert('L'), 26 | lambda x: x.resize((imgsz, imgsz)), 27 | lambda x: np.reshape(x, (imgsz, imgsz, 1)), 28 | lambda x: np.transpose(x, [2, 0, 1]), 29 | lambda x: x/255.]) 30 | ) 31 | 32 | temp = dict() # {label:img1, img2..., 20 imgs, label2: img1, img2,... 
/omniglotNShot.py:
--------------------------------------------------------------------------------
1 | from omniglot import Omniglot
2 | import torchvision.transforms as transforms
3 | from PIL import Image
4 | import os.path
5 | import numpy as np
6 | 
7 | 
8 | class OmniglotNShot:
9 | 
10 |     def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz):
11 |         """
12 |         Episode sampler for Omniglot: pre-samples N-way K-shot tasks and caches them in memory.
13 |         :param root: root folder of the omniglot data
14 |         :param batchsz: task num, i.e. the meta batch size
15 |         :param n_way:
16 |         :param k_shot:
17 |         :param k_query:
18 |         :param imgsz:
19 |         """
20 | 
21 |         self.resize = imgsz
22 |         if not os.path.isfile(os.path.join(root, 'omniglot.npy')):
23 |             # if root/omniglot.npy does not exist, download the raw data and preprocess it
24 |             self.x = Omniglot(root, download=True,
25 |                               transform=transforms.Compose([lambda x: Image.open(x).convert('L'),
26 |                                                             lambda x: x.resize((imgsz, imgsz)),
27 |                                                             lambda x: np.reshape(x, (imgsz, imgsz, 1)),
28 |                                                             lambda x: np.transpose(x, [2, 0, 1]),
29 |                                                             lambda x: x / 255.])
30 |                               )
31 | 
32 |             temp = dict()  # {label: [img1, img2, ... (20 imgs)], ...}, 1623 labels in total
33 |             for (img, label) in self.x:
34 |                 if label in temp.keys():
35 |                     temp[label].append(img)
36 |                 else:
37 |                     temp[label] = [img]
38 | 
39 |             self.x = []
40 |             for label, imgs in temp.items():  # label info discarded; every label holds exactly 20 imgs
41 |                 self.x.append(np.array(imgs))
42 | 
43 |             # every class has exactly 20 imgs, so stacking yields a regular array
44 |             self.x = np.array(self.x).astype(np.float32)  # [1623 classes, 20 imgs each]; np.float was removed from NumPy
45 |             # each character contains 20 imgs
46 |             print('data shape:', self.x.shape)  # [1623, 20, imgsz, imgsz, 1]
47 |             temp = []  # free memory
48 |             # save the whole dataset into one npy file.
49 |             np.save(os.path.join(root, 'omniglot.npy'), self.x)
50 |             print('write into omniglot.npy.')
51 |         else:
52 |             # if omniglot.npy exists, just load it.
53 |             self.x = np.load(os.path.join(root, 'omniglot.npy'))
54 |             print('load from omniglot.npy.')
55 | 
56 |         # [1623, 20, imgsz, imgsz, 1]
57 |         # NOTE: do not shuffle here -- the split below must keep training and test classes disjoint!
58 |         self.x_train, self.x_test = self.x[:1200], self.x[1200:]
59 | 
60 |         # self.normalization()
61 | 
62 |         self.batchsz = batchsz
63 |         self.n_cls = self.x.shape[0]  # 1623
64 |         self.n_way = n_way  # n way
65 |         self.k_shot = k_shot  # k shot
66 |         self.k_query = k_query  # k query
67 |         assert (k_shot + k_query) <= 20  # each Omniglot character has only 20 samples
68 | 
69 |         # save pointer of current read batch in total cache
70 |         self.indexes = {"train": 0, "test": 0}
71 |         self.datasets = {"train": self.x_train, "test": self.x_test}  # original data cached
72 |         print("DB: train", self.x_train.shape, "test", self.x_test.shape)
73 | 
74 |         self.datasets_cache = {"train": self.load_data_cache(self.datasets["train"]),  # current epoch data cached
75 |                                "test": self.load_data_cache(self.datasets["test"])}
76 | 
77 |     def normalization(self):
78 |         """
79 |         Normalizes our data to have a mean of 0 and std of 1
80 |         """
81 |         self.mean = np.mean(self.x_train)
82 |         self.std = np.std(self.x_train)
83 |         self.max = np.max(self.x_train)
84 |         self.min = np.min(self.x_train)
85 |         # print("before norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)
86 |         self.x_train = (self.x_train - self.mean) / self.std
87 |         self.x_test = (self.x_test - self.mean) / self.std
88 | 
89 |         self.mean = np.mean(self.x_train)
90 |         self.std = np.std(self.x_train)
91 |         self.max = np.max(self.x_train)
92 |         self.min = np.min(self.x_train)
93 | 
94 |         # print("after norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)
95 | 
96 |     def load_data_cache(self, data_pack):
97 |         """
98 |         Collects several batches of data for N-shot learning
99 |         :param data_pack: [cls_num, 20, imgsz, imgsz, 1]
100 |         :return: A list with [support_set_x, support_set_y, target_x, target_y] ready to be fed to our networks
101 |         """
102 |         # take 5 way 1 shot as example: 5 * 1
103 |         setsz = self.k_shot * self.n_way
104 |         querysz = self.k_query * self.n_way
105 |         data_cache = []
106 | 
107 |         # preload 10 batches per refill; next() triggers another refill when they run out
108 |         for sample in range(10):  # num of cached batches
109 | 
110 |             x_spts, y_spts, x_qrys, y_qrys = [], [], [], []
111 |             for i in range(self.batchsz):  # each element of a batch is one task (episode)
112 | 
113 |                 x_spt, y_spt, x_qry, y_qry = [], [], [], []
114 |                 selected_cls = np.random.choice(data_pack.shape[0], self.n_way, False)
115 | 
116 |                 for j, cur_class in enumerate(selected_cls):
117 | 
118 |                     selected_img = np.random.choice(20, self.k_shot + self.k_query, False)
119 | 
120 |                     # meta-training and meta-test
121 |                     x_spt.append(data_pack[cur_class][selected_img[:self.k_shot]])
122 |                     x_qry.append(data_pack[cur_class][selected_img[self.k_shot:]])
123 |                     y_spt.append([j for _ in range(self.k_shot)])
124 |                     y_qry.append([j for _ in range(self.k_query)])
125 | 
126 |                 # shuffle inside each task
127 |                 perm = np.random.permutation(self.n_way * self.k_shot)
128 |                 x_spt = np.array(x_spt).reshape(self.n_way * self.k_shot, 1, self.resize, self.resize)[perm]
129 |                 y_spt = np.array(y_spt).reshape(self.n_way * self.k_shot)[perm]
130 |                 perm = np.random.permutation(self.n_way * self.k_query)
131 |                 x_qry = np.array(x_qry).reshape(self.n_way * self.k_query, 1, self.resize, self.resize)[perm]
132 |                 y_qry = np.array(y_qry).reshape(self.n_way * self.k_query)[perm]
133 | 
134 |                 # append [setsz, 1, imgsz, imgsz] => [b, setsz, 1, imgsz, imgsz]
135 |                 x_spts.append(x_spt)
136 |                 y_spts.append(y_spt)
137 |                 x_qrys.append(x_qry)
138 |                 y_qrys.append(y_qry)
139 | 
140 |             # [b, setsz, 1, imgsz, imgsz]
141 |             x_spts = np.array(x_spts).astype(np.float32).reshape(self.batchsz, setsz, 1, self.resize, self.resize)
142 |             y_spts = np.array(y_spts).astype(np.int64).reshape(self.batchsz, setsz)  # np.int was removed from NumPy
143 |             # [b, qrysz, 1, imgsz, imgsz]
144 |             x_qrys = np.array(x_qrys).astype(np.float32).reshape(self.batchsz, querysz, 1, self.resize, self.resize)
145 |             y_qrys = np.array(y_qrys).astype(np.int64).reshape(self.batchsz, querysz)
146 | 
147 |             data_cache.append([x_spts, y_spts, x_qrys, y_qrys])
148 | 
149 |         return data_cache
150 | 
151 |     def next(self, mode='train'):
152 |         """
153 |         Gets the next batch from the given split.
154 |         :param mode: the name of the split, one of "train" or "test"
155 |         :return:
156 |         """
157 |         # refill the cache once the pointer runs past the number of cached batches
158 |         if self.indexes[mode] >= len(self.datasets_cache[mode]):
159 |             self.indexes[mode] = 0
160 |             self.datasets_cache[mode] = self.load_data_cache(self.datasets[mode])
161 | 
162 |         next_batch = self.datasets_cache[mode][self.indexes[mode]]
163 |         self.indexes[mode] += 1
164 | 
165 |         return next_batch
166 | 
167 | 
168 | if __name__ == '__main__':
169 | 
170 |     import time
171 |     import torch
172 |     import visdom
173 | 
174 |     viz = visdom.Visdom(env='omniglot_view')
175 | 
176 |     db = OmniglotNShot('db/omniglot', batchsz=20, n_way=5, k_shot=5, k_query=15, imgsz=64)
177 | 
178 |     for i in range(1000):
179 |         x_spt, y_spt, x_qry, y_qry = db.next('train')
180 | 
181 |         # numpy float32 arrays of shape [b, setsz, 1, imgsz, imgsz]
182 |         x_spt = torch.from_numpy(x_spt)
183 |         x_qry = torch.from_numpy(x_qry)
184 |         y_spt = torch.from_numpy(y_spt)
185 |         y_qry = torch.from_numpy(y_qry)
186 |         batchsz, setsz, c, h, w = x_spt.size()
187 | 
188 |         viz.images(x_spt[0], nrow=5, win='x_spt', opts=dict(title='x_spt'))
189 |         viz.images(x_qry[0], nrow=15, win='x_qry', opts=dict(title='x_qry'))
190 |         viz.text(str(y_spt[0]), win='y_spt', opts=dict(title='y_spt'))
191 |         viz.text(str(y_qry[0]), win='y_qry', opts=dict(title='y_qry'))
192 | 
193 |         time.sleep(10)
194 | 
--------------------------------------------------------------------------------
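A quick shape check for the sampler above, assuming the raw data is available under db/omniglot; the batch sizes here are illustrative.

# Sketch: verify the shapes next() hands to the trainer.
from omniglotNShot import OmniglotNShot

db = OmniglotNShot('db/omniglot', batchsz=8, n_way=5, k_shot=1, k_query=15, imgsz=28)
x_spt, y_spt, x_qry, y_qry = db.next('train')
print(x_spt.shape)  # (8, 5, 1, 28, 28): [task_num, n_way * k_shot, 1, imgsz, imgsz]
print(y_qry.shape)  # (8, 75):           [task_num, n_way * k_query]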
/omniglot_train.py:
--------------------------------------------------------------------------------
1 | import torch, os
2 | import numpy as np
3 | from omniglotNShot import OmniglotNShot
4 | import argparse
5 | 
6 | from meta import Meta
7 | 
8 | def main(args):
9 | 
10 |     torch.manual_seed(222)
11 |     torch.cuda.manual_seed_all(222)
12 |     np.random.seed(222)
13 | 
14 |     print(args)
15 | 
16 |     config = [
17 |         ('conv2d', [64, 1, 3, 3, 2, 0]),  # [ch_out, ch_in, kernel, kernel, stride, padding]
18 |         ('relu', [True]),
19 |         ('bn', [64]),
20 |         ('conv2d', [64, 64, 3, 3, 2, 0]),
21 |         ('relu', [True]),
22 |         ('bn', [64]),
23 |         ('conv2d', [64, 64, 3, 3, 2, 0]),
24 |         ('relu', [True]),
25 |         ('bn', [64]),
26 |         ('conv2d', [64, 64, 2, 2, 1, 0]),
27 |         ('relu', [True]),
28 |         ('bn', [64]),
29 |         ('flatten', []),
30 |         ('linear', [args.n_way, 64])
31 |     ]
32 | 
33 |     device = torch.device('cuda')
34 |     maml = Meta(args, config).to(device)
35 | 
36 |     tmp = filter(lambda x: x.requires_grad, maml.parameters())
37 |     num = sum(map(lambda x: np.prod(x.shape), tmp))
38 |     print(maml)
39 |     print('Total trainable tensors:', num)
40 | 
41 |     db_train = OmniglotNShot('omniglot',
42 |                              batchsz=args.task_num,
43 |                              n_way=args.n_way,
44 |                              k_shot=args.k_spt,
45 |                              k_query=args.k_qry,
46 |                              imgsz=args.imgsz)
47 | 
48 |     for step in range(args.epoch):
49 | 
50 |         x_spt, y_spt, x_qry, y_qry = db_train.next()
51 |         x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
52 |                                      torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)
53 | 
54 |         # set training=True to update running_mean, running_variance, bn_weights, bn_bias
55 |         accs = maml(x_spt, y_spt, x_qry, y_qry)
56 | 
57 |         if step % 50 == 0:
58 |             print('step:', step, '\ttraining acc:', accs)
59 | 
60 |         if step % 500 == 0:
61 |             accs = []
62 |             for _ in range(1000 // args.task_num):
63 |                 # test
64 |                 x_spt, y_spt, x_qry, y_qry = db_train.next('test')
65 |                 x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
66 |                                              torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)
67 | 
68 |                 # split to a single task each time
69 |                 for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):
70 |                     test_acc = maml.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)
71 |                     accs.append(test_acc)
72 | 
73 |             # [b, update_step+1]
74 |             accs = np.array(accs).mean(axis=0).astype(np.float16)
75 |             print('Test acc:', accs)
76 | 
77 | 
78 | if __name__ == '__main__':
79 | 
80 |     argparser = argparse.ArgumentParser()
81 |     argparser.add_argument('--epoch', type=int, help='epoch number', default=40000)
82 |     argparser.add_argument('--n_way', type=int, help='n way', default=5)
83 |     argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=1)
84 |     argparser.add_argument('--k_qry', type=int, help='number of query images per class', default=15)
85 |     argparser.add_argument('--imgsz', type=int, help='image size', default=28)
86 |     argparser.add_argument('--imgc', type=int, help='image channels', default=1)
87 |     argparser.add_argument('--task_num', type=int, help='meta batch size, namely task num', default=32)
88 |     argparser.add_argument('--meta_lr', type=float, help='meta-level outer learning rate', default=1e-3)
89 |     argparser.add_argument('--update_lr', type=float, help='task-level inner update learning rate', default=0.4)
90 |     argparser.add_argument('--update_step', type=int, help='task-level inner update steps', default=5)
91 |     argparser.add_argument('--update_step_test', type=int, help='update steps for finetuning', default=10)
92 | 
93 |     args = argparser.parse_args()
94 | 
95 |     main(args)
96 | 
--------------------------------------------------------------------------------
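The final ('linear', [args.n_way, 64]) entry works out because the four conv layers shrink a 28x28 input down to 1x1 with 64 channels. A quick check of the arithmetic (out = (in + 2*pad - k) // stride + 1):

# Trace of the spatial size through the config above for a 28x28 Omniglot input.
def conv_out(size, k, stride, pad=0):
    return (size + 2 * pad - k) // stride + 1

s = 28
for k, stride in [(3, 2), (3, 2), (3, 2), (2, 1)]:  # the four conv2d layers
    s = conv_out(s, k, stride)                      # 28 -> 13 -> 6 -> 2 -> 1
print(s)  # 1, so flatten yields 64 * 1 * 1 = 64 features for the linear head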
/res/heart.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragen1860/MAML-Pytorch/98a00d41724c133bd29619a2fb2cc46dd128a368/res/heart.gif
--------------------------------------------------------------------------------
/res/mini-screen.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragen1860/MAML-Pytorch/98a00d41724c133bd29619a2fb2cc46dd128a368/res/mini-screen.png
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | 
4 | 
5 | class Net(nn.Module):
6 | 
7 |     def __init__(self):
8 |         super(Net, self).__init__()
9 | 
10 |         # parameters registered via ParameterList show up in .parameters()
11 |         self.a = nn.ParameterList([nn.Parameter(torch.zeros(3, 4))])
12 |         # buffers move with the module but are not returned by .parameters()
13 |         b = [torch.ones(2, 3), torch.ones(2, 3)]
14 |         for i in range(2):
15 |             self.register_buffer('b%d' % i, b[i])
16 | 
17 |     def forward(self, input):
18 |         return self.a[0]
19 | 
20 | 
21 | class MAML(nn.Module):
22 | 
23 |     def __init__(self):
24 |         super(MAML, self).__init__()
25 | 
26 |         self.net = Net()
27 | 
28 |     def forward(self, input):
29 |         return self.net(input)
30 | 
31 | 
32 | def main():
33 |     device = torch.device('cuda')
34 |     maml = MAML().to(device)
35 |     # both the parameter and the buffer should now live on the GPU
36 |     print(maml.net.a)
37 |     print(maml.net.b0)
38 | 
39 | 
40 | if __name__ == '__main__':
41 |     main()
42 | 
--------------------------------------------------------------------------------
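test.py is a minimal probe of the parameter-versus-buffer distinction that this kind of functional MAML code leans on: ParameterList entries are visible to the optimizer and receive gradients, while registered buffers (e.g. BN running statistics) follow .to(device) but are never trained. The same checks as assertions, sketched below; importing Net is safe because test.py guards its main() call behind __name__ == '__main__'.

# Sketch: assert the behaviour the probe above prints out.
import torch
from test import Net

net = Net()
assert len(list(net.parameters())) == 1  # only the ParameterList entry self.a[0]
assert not net.b0.requires_grad          # buffers receive no gradients
if torch.cuda.is_available():
    net = net.cuda()
print(net.b0.device)                     # buffers follow the module across devices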