├── MiniImagenet.py ├── README.md ├── main.py ├── meta.py ├── naive.py ├── omniglot.py ├── omniglotNShot.py └── res └── test-acc.png /MiniImagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import Dataset 4 | from torchvision.transforms import transforms 5 | import numpy as np 6 | import collections 7 | from PIL import Image 8 | import csv 9 | import random 10 | 11 | 12 | class MiniImagenet(Dataset): 13 | """ 14 | put mini-imagenet files as : 15 | root : 16 | |- images/*.jpg includes all imgeas 17 | |- train.csv 18 | |- test.csv 19 | |- val.csv 20 | NOTICE: meta-learning is different from general supervised learning, especially the concept of batch and set. 21 | batch: contains several sets 22 | sets: conains n_way * k_shot for meta-train set, n_way * n_query for meta-test set. 23 | """ 24 | 25 | def __init__(self, root, mode, batchsz, n_way, k_shot, k_query, resize, startidx=0): 26 | """ 27 | 28 | :param root: root path of mini-imagenet 29 | :param mode: train, val or test 30 | :param batchsz: batch size of sets, not batch of imgs 31 | :param n_way: 32 | :param k_shot: 33 | :param k_query: num of qeruy imgs per class 34 | :param resize: resize to 35 | :param startidx: start to index label from startidx 36 | """ 37 | 38 | self.batchsz = batchsz # batch of set, not batch of imgs 39 | self.n_way = n_way # n-way 40 | self.k_shot = k_shot # k-shot 41 | self.k_query = k_query # for evaluation 42 | self.setsz = self.n_way * self.k_shot # num of samples per set 43 | self.querysz = self.n_way * self.k_query # number of samples per set for evaluation 44 | self.resize = resize # resize to 45 | self.startidx = startidx # index label not from 0, but from startidx 46 | print('shuffle DB :%s, b:%d, %d-way, %d-shot, %d-query, resize:%d' % ( 47 | mode, batchsz, n_way, k_shot, k_query, resize)) 48 | 49 | if mode == 'train': 50 | self.transform = transforms.Compose([lambda x: Image.open(x).convert('RGB'), 51 | transforms.Resize((self.resize, self.resize)), 52 | # transforms.RandomHorizontalFlip(), 53 | # transforms.RandomRotation(5), 54 | transforms.ToTensor(), 55 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 56 | ]) 57 | else: 58 | self.transform = transforms.Compose([lambda x: Image.open(x).convert('RGB'), 59 | transforms.Resize((self.resize, self.resize)), 60 | transforms.ToTensor(), 61 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 62 | ]) 63 | 64 | self.path = os.path.join(root, 'images') # image path 65 | csvdata = self.loadCSV(os.path.join(root, mode + '.csv')) # csv path 66 | self.data = [] 67 | self.img2label = {} 68 | for i, (k, v) in enumerate(csvdata.items()): 69 | self.data.append(v) # [[img1, img2, ...], [img111, ...]] 70 | self.img2label[k] = i + self.startidx # {"img_name[:9]":label} 71 | self.cls_num = len(self.data) 72 | 73 | self.create_batch(self.batchsz) 74 | 75 | def loadCSV(self, csvf): 76 | """ 77 | return a dict saving the information of csv 78 | :param splitFile: csv file name 79 | :return: {label:[file1, file2 ...]} 80 | """ 81 | dictLabels = {} 82 | with open(csvf) as csvfile: 83 | csvreader = csv.reader(csvfile, delimiter=',') 84 | next(csvreader, None) # skip (filename, label) 85 | for i, row in enumerate(csvreader): 86 | filename = row[0] 87 | label = row[1] 88 | # append filename to current label 89 | if label in dictLabels.keys(): 90 | dictLabels[label].append(filename) 91 | else: 92 | dictLabels[label] = [filename] 93 | return dictLabels 94 | 95 | def create_batch(self, batchsz): 96 | """ 97 | create batch for meta-learning. 98 | ×episode× here means batch, and it means how many sets we want to retain. 99 | :param episodes: batch size 100 | :return: 101 | """ 102 | self.support_x_batch = [] # support set batch 103 | self.query_x_batch = [] # query set batch 104 | for b in range(batchsz): # for each batch 105 | # 1.select n_way classes randomly 106 | selected_cls = np.random.choice(self.cls_num, self.n_way, False) # no duplicate 107 | np.random.shuffle(selected_cls) 108 | support_x = [] 109 | query_x = [] 110 | for idx, cls in enumerate(selected_cls): 111 | # 2. select k_shot + k_query for each class 112 | selected_imgs_idx = np.random.choice(len(self.data[cls]), self.k_shot + self.k_query, False) 113 | np.random.shuffle(selected_imgs_idx) 114 | indexDtrain = np.array(selected_imgs_idx[:self.k_shot]) # idx for Dtrain 115 | indexDtest = np.array(selected_imgs_idx[self.k_shot:]) # idx for Dtest 116 | support_x.append( 117 | np.array(self.data[cls])[indexDtrain].tolist()) # get all images filename for current Dtrain 118 | query_x.append(np.array(self.data[cls])[indexDtest].tolist()) 119 | 120 | # shuffle the correponding relation between support set and query set 121 | random.shuffle(support_x) 122 | random.shuffle(query_x) 123 | 124 | self.support_x_batch.append(support_x) # append set to current sets 125 | self.query_x_batch.append(query_x) # append sets to current sets 126 | 127 | def __getitem__(self, index): 128 | """ 129 | index means index of sets, 0<= index <= batchsz-1 130 | :param index: 131 | :return: 132 | """ 133 | # [setsz, 3, resize, resize] 134 | support_x = torch.FloatTensor(self.setsz, 3, self.resize, self.resize) 135 | # [setsz] 136 | support_y = np.zeros((self.setsz), dtype=np.int) 137 | # [querysz, 3, resize, resize] 138 | query_x = torch.FloatTensor(self.querysz, 3, self.resize, self.resize) 139 | # [querysz] 140 | query_y = np.zeros((self.querysz), dtype=np.int) 141 | 142 | flatten_support_x = [os.path.join(self.path, item) 143 | for sublist in self.support_x_batch[index] for item in sublist] 144 | support_y = np.array( 145 | [self.img2label[item[:9]] # filename:n0153282900000005.jpg, the first 9 characters treated as label 146 | for sublist in self.support_x_batch[index] for item in sublist]) 147 | flatten_query_x = [os.path.join(self.path, item) 148 | for sublist in self.query_x_batch[index] for item in sublist] 149 | query_y = np.array([self.img2label[item[:9]] 150 | for sublist in self.query_x_batch[index] for item in sublist]) 151 | 152 | for i, path in enumerate(flatten_support_x): 153 | support_x[i] = self.transform(path) 154 | 155 | for i, path in enumerate(flatten_query_x): 156 | query_x[i] = self.transform(path) 157 | 158 | 159 | # now try to convert label from global indexing to local indexing. 160 | # support_y: [n_way * k_shot] 161 | # query_y: [n_way * k_query] 162 | support_y_idx = np.array([i for i in range(self.n_way) for j in range(self.k_shot)]) 163 | query_y_idx = [] 164 | # for each value in query_y 165 | for query_y_value in query_y: 166 | # find the equal value index in support_y 167 | for i in range(support_y.shape[0]): 168 | if query_y_value == support_y[i]: 169 | break 170 | # use the index as the label info. 171 | query_y_idx.append(support_y_idx[i]) 172 | query_y_idx = torch.from_numpy(np.array(query_y_idx)) 173 | support_y_idx = torch.from_numpy(np.array(support_y_idx)) 174 | 175 | # if np.random.randint(1000)< 1: 176 | # print('global indexing:') 177 | # print(support_y, query_y) 178 | # print('local indexing:') 179 | # print(support_y_idx.numpy(), query_y_idx.numpy()) 180 | 181 | return support_x, support_y_idx, query_x, query_y_idx 182 | 183 | def __len__(self): 184 | # as we have built up to batchsz of sets, you can sample some small batch size of sets. 185 | return self.batchsz 186 | 187 | 188 | if __name__ == '__main__': 189 | # the following episode is to view one set of images via tensorboard. 190 | from torchvision.utils import make_grid 191 | from matplotlib import pyplot as plt 192 | from tensorboardX import SummaryWriter 193 | import time 194 | 195 | plt.ion() 196 | 197 | tb = SummaryWriter('runs', 'mini-imagenet') 198 | mini = MiniImagenet('../mini-imagenet/', mode='train', n_way=5, k_shot=1, k_query=1, batchsz=1000, resize=168) 199 | 200 | for i, set_ in enumerate(mini): 201 | # support_x: [k_shot*n_way, 3, 84, 84] 202 | support_x, support_y, query_x, query_y = set_ 203 | 204 | support_x = make_grid(support_x, nrow=2) 205 | query_x = make_grid(query_x, nrow=2) 206 | 207 | plt.figure(1) 208 | plt.imshow(support_x.transpose(2, 0).numpy()) 209 | plt.pause(0.5) 210 | plt.figure(2) 211 | plt.imshow(query_x.transpose(2, 0).numpy()) 212 | plt.pause(0.5) 213 | 214 | tb.add_image('support_x', support_x) 215 | tb.add_image('query_x', query_x) 216 | 217 | time.sleep(5) 218 | 219 | tb.close() 220 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Reptile-Pytorch 2 | PyTorch implementation of the supervised learning experiments from the paper: 3 | Reptile: A Scalable Meta-Learning Algorithm: https://blog.openai.com/reptile/ 4 | , which is based on Model-Agnostic Meta-Learning (MAML): https://arxiv.org/abs/1703.03400 5 | 6 | 7 | 8 | # Ominiglot 9 | 10 | ## Howto 11 | change `dataset = 'omniglot' ` in `main.py` and just run `python main.py`, the program will download omniglot dataset automatically. 12 | modify the value of `meta_batchsz` to fit your GPU memory size. 13 | 14 | ## benchmark 15 | | Model | Fine Tune | 5-way Acc. | | 20-way Acc | | 16 | |------------------------------------- |----------- |--------------- |--------------- |-------------- |--------------- | 17 | | | | 1-shot | 5-shot | 1-shot | 5-shot | 18 | | MANN | N | 82.8% | 94.9% | - | - | 19 | | Matching Nets | N | 98.1% | 98.9% | 93.8% | 98.5% | 20 | | Matching Nets | Y | 97.9% | 98.7% | 93.5% | 98.7% | 21 | | MAML | Y | 98.7+-0.4% | 99.9+-0.1% | 95.8+-0.3% | 98.9+-0.2% | 22 | | **Ours** | Y | 98.62% | 99.52% | 96.09% | 98.24% | 23 | 24 | 25 | >5way 1shot episode: 11580\*512 finetune acc:0.990234 test acc:0.986250 26 | 27 | >5way 5shot episode: 27180\*128 finetune acc:0.995625 test acc:0.995219 28 | 29 | >20way 1shot episode: 23160\*128 finetune acc:0.960937 test acc:0.960898 30 | 31 | >20way 5shot episode: 11580\*32 finetune acc:0.985938 test acc:0.982437 32 | 33 | 34 | ## training curve 35 | ![test acc](res/test-acc.png) 36 | 37 | 38 | 39 | # ~~mini-Imagenet~~ (CAN NOT WORK!) 40 | 41 | > train `mini-imagenet` is extremely slow, since the code train task one by one squentially. 42 | 43 | ## Howto 44 | 45 | download `mini-imagenet` dataset and make it looks like: 46 | ```shell 47 | mini-imagenet/ 48 | ├── images 49 | ├── n0210891500001298.jpg 50 | ├── n0287152500001298.jpg 51 | ... 52 | ├── test.csv 53 | ├── val.csv 54 | └── train.csv 55 | 56 | MAML-Pytorch/ 57 | ├── main.py 58 | ├── meta.py 59 | ├── Readme.md 60 | ├── naive.md 61 | ... 62 | ``` 63 | 64 | change `dataset = 'mini-imagenet' ` in `main.py` and just run `python main.py`. 65 | 66 | ## benchmark 67 | 68 | | Model | Fine Tune | 5-way Acc. | | 20-way Acc | | 69 | |-------------------------------------|-----------|------------|--------|------------|--------| 70 | | | | 1-shot | 5-shot | 1-shot | 5-shot | 71 | | Matching Nets | N | 43.56% | 55.31% | 17.31% | 22.69% | 72 | | Meta-LSTM | | 43.44% | 60.60% | 16.70% | 26.06% | 73 | | MAML | Y | 48.7% | 63.11% | 16.49% | 19.29% | 74 | | **Ours** | Y | - | - | - | - | 75 | 76 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from omniglotNShot import OmniglotNShot 2 | from meta import MetaLearner 3 | from naive import Naive 4 | from MiniImagenet import MiniImagenet 5 | 6 | import torch 7 | from torch.autograd import Variable 8 | import numpy as np 9 | from tensorboardX import SummaryWriter 10 | from torch.utils.data import DataLoader 11 | 12 | 13 | def main(): 14 | meta_batchsz = 32 15 | n_way = 5 16 | k_shot = 1 17 | k_query = k_shot 18 | meta_lr = 1e-3 19 | num_updates = 5 20 | dataset = 'omniglot' 21 | 22 | 23 | 24 | if dataset == 'omniglot': 25 | imgsz = 28 26 | db = OmniglotNShot('dataset', batchsz=meta_batchsz, n_way=n_way, k_shot=k_shot, k_query=k_query, imgsz=imgsz) 27 | 28 | elif dataset == 'mini-imagenet': 29 | imgsz = 84 30 | # the dataset loaders are different from omniglot to mini-imagenet. for omniglot, it just has one loader to use 31 | # get_batch(train or test) to get different batch. 32 | # for mini-imagenet, it should have two dataloader, one is train_loader and another is test_loader. 33 | mini = MiniImagenet('../mini-imagenet/', mode='train', n_way=n_way, k_shot=k_shot, k_query=k_query, 34 | batchsz=10000, resize=imgsz) 35 | db = DataLoader(mini, meta_batchsz, shuffle=True, num_workers=4, pin_memory=True) 36 | mini_test = MiniImagenet('../mini-imagenet/', mode='test', n_way=n_way, k_shot=k_shot, k_query=k_query, 37 | batchsz=1000, resize=imgsz) 38 | db_test = DataLoader(mini_test, meta_batchsz, shuffle=True, num_workers=2, pin_memory=True) 39 | 40 | else: 41 | raise NotImplementedError 42 | 43 | 44 | meta = MetaLearner(Naive, (n_way, imgsz), n_way=n_way, k_shot=k_shot, meta_batchsz=meta_batchsz, beta=meta_lr, 45 | num_updates=num_updates).cuda() 46 | 47 | tb = SummaryWriter('runs') 48 | 49 | 50 | # main loop 51 | for episode_num in range(200000): 52 | 53 | # 1. train 54 | if dataset == 'omniglot': 55 | support_x, support_y, query_x, query_y = db.get_batch('test') 56 | support_x = Variable( torch.from_numpy(support_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda() 57 | query_x = Variable( torch.from_numpy(query_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda() 58 | support_y = Variable(torch.from_numpy(support_y).long()).cuda() 59 | query_y = Variable(torch.from_numpy(query_y).long()).cuda() 60 | elif dataset == 'mini-imagenet': 61 | try: 62 | batch_test = iter(db).next() 63 | except StopIteration as err: 64 | mini = MiniImagenet('../mini-imagenet/', mode='train', n_way=n_way, k_shot=k_shot, k_query=k_query, 65 | batchsz=10000, resize=imgsz) 66 | db = DataLoader(mini, meta_batchsz, shuffle=True, num_workers=4, pin_memory=True) 67 | 68 | support_x = Variable(batch_test[0]).cuda() 69 | support_y = Variable(batch_test[1]).cuda() 70 | query_x = Variable(batch_test[2]).cuda() 71 | query_y = Variable(batch_test[3]).cuda() 72 | 73 | # backprop has been embeded in forward func. 74 | accs = meta(support_x, support_y, query_x, query_y) 75 | train_acc = np.array(accs).mean() 76 | 77 | # 2. test 78 | if episode_num % 30 == 0: 79 | test_accs = [] 80 | for i in range(min(episode_num // 5000 + 3, 10)): # get average acc. 81 | if dataset == 'omniglot': 82 | support_x, support_y, query_x, query_y = db.get_batch('test') 83 | support_x = Variable( torch.from_numpy(support_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda() 84 | query_x = Variable( torch.from_numpy(query_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda() 85 | support_y = Variable(torch.from_numpy(support_y).long()).cuda() 86 | query_y = Variable(torch.from_numpy(query_y).long()).cuda() 87 | elif dataset == 'mini-imagenet': 88 | try: 89 | batch_test = iter(db_test).next() 90 | except StopIteration as err: 91 | mini_test = MiniImagenet('../mini-imagenet/', mode='test', n_way=n_way, k_shot=k_shot, 92 | k_query=k_query, 93 | batchsz=1000, resize=imgsz) 94 | db_test = DataLoader(mini_test, meta_batchsz, shuffle=True, num_workers=2, pin_memory=True) 95 | support_x = Variable(batch_test[0]).cuda() 96 | support_y = Variable(batch_test[1]).cuda() 97 | query_x = Variable(batch_test[2]).cuda() 98 | query_y = Variable(batch_test[3]).cuda() 99 | 100 | 101 | # get accuracy 102 | test_acc = meta.pred(support_x, support_y, query_x, query_y) 103 | test_accs.append(test_acc) 104 | 105 | test_acc = np.array(test_accs).mean() 106 | print('episode:', episode_num, '\tfinetune acc:%.6f' % train_acc, '\t\ttest acc:%.6f' % test_acc) 107 | tb.add_scalar('test-acc', test_acc) 108 | tb.add_scalar('finetune-acc', train_acc) 109 | 110 | 111 | if __name__ == '__main__': 112 | main() 113 | -------------------------------------------------------------------------------- /meta.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch import optim 4 | from torch import autograd 5 | from torch.autograd import Variable 6 | from torch.nn import functional as F 7 | import numpy as np 8 | 9 | 10 | class Learner(nn.Module): 11 | """ 12 | This is a learner class, which will accept a specific network module, such as OmniNet that define the network forward 13 | process. Learner class will create two same network, one as theta network and the other acts as theta_pi network. 14 | for each episode, the theta_pi network will copy its initial parameters from theta network and update several steps 15 | by meta-train set and then calculate its loss on meta-test set. All loss on meta-test set will be sumed together and 16 | then backprop on theta network, which should be done on metalaerner class. 17 | For learner class, it will be responsible for update for several steps on meta-train set and return with the loss on 18 | meta-test set. 19 | """ 20 | 21 | def __init__(self, net_cls, *args): 22 | """ 23 | It will receive a class: net_cls and its parameters: args for net_cls. 24 | :param net_cls: class, not instance 25 | :param args: the parameters for net_cls 26 | """ 27 | super(Learner, self).__init__() 28 | # pls make sure net_cls is a class but NOT an instance of class. 29 | assert net_cls.__class__ == type 30 | 31 | # we will create two class instance meanwhile and use one as theta network and the other as theta_pi network. 32 | self.net = net_cls(*args) 33 | # you must call create_pi_net to create pi network additionally 34 | self.net_pi = net_cls(*args) 35 | # update theta_pi = theta_pi - lr * grad 36 | # according to the paper, here we use naive version of SGD to update theta_pi 37 | # 0.1 here means the learner_lr 38 | self.optimizer = optim.SGD(self.net_pi.parameters(), 0.1) 39 | 40 | def parameters(self): 41 | """ 42 | Override this function to return only net parameters for MetaLearner's optimize 43 | it will ignore theta_pi network parameters. 44 | :return: 45 | """ 46 | return self.net.parameters() 47 | 48 | def update_pi(self): 49 | """ 50 | copy parameters from self.net -> self.net_pi 51 | :return: 52 | """ 53 | for m_from, m_to in zip(self.net.modules(), self.net_pi.modules()): 54 | if isinstance(m_to, nn.Linear) or isinstance(m_to, nn.Conv2d) or isinstance(m_to, nn.BatchNorm2d): 55 | m_to.weight.data = m_from.weight.data.clone() 56 | if m_to.bias is not None: 57 | m_to.bias.data = m_from.bias.data.clone() 58 | 59 | def forward(self, support_x, support_y, query_x, query_y, num_updates): 60 | """ 61 | learn on current episode meta-train: support_x & support_y and then calculate loss on meta-test set: query_x&y 62 | :param support_x: [setsz, c_, h, w] 63 | :param support_y: [setsz] 64 | :param query_x: [querysz, c_, h, w] 65 | :param query_y: [querysz] 66 | :param num_updates: 5 67 | :return: 68 | """ 69 | # now try to fine-tune from current $theta$ parameters -> $theta_pi$ 70 | # after num_updates of fine-tune, we will get a good theta_pi parameters so that it will retain satisfying 71 | # performance on specific task, that's, current episode. 72 | # firstly, copy theta_pi from theta network 73 | self.update_pi() 74 | 75 | # update for several steps 76 | for i in range(num_updates): 77 | # forward and backward to update net_pi grad. 78 | loss, pred = self.net_pi(support_x, support_y) 79 | self.optimizer.zero_grad() 80 | loss.backward() 81 | self.optimizer.step() 82 | 83 | # Compute the meta gradient and return it, the gradient is from one episode 84 | # in metalearner, it will merge all loss from different episode and sum over it. 85 | loss, pred = self.net_pi(query_x, query_y) 86 | # pred: [setsz, n_way], indices: [setsz] 87 | _, indices = torch.max(pred, dim=1) 88 | correct = torch.eq(indices, query_y).sum().data[0] 89 | acc = correct / query_y.size(0) 90 | 91 | # gradient for validation on theta_pi 92 | # after call autorad.grad, you can not call backward again except for setting create_graph = True 93 | # as we will use the loss as dummpy loss to conduct a dummy backprop to write our gradients to theta network, 94 | # here we set create_graph to true to support second time backward. 95 | grads_pi = autograd.grad(loss, self.net_pi.parameters(), create_graph=True) 96 | 97 | return loss, grads_pi, acc 98 | 99 | def net_forward(self, support_x, support_y): 100 | """ 101 | This function is purely for updating net network. In metalearner, we need the get the loss op from net network 102 | to write our merged gradients into net network, hence will call this function to get a dummy loss op. 103 | :param support_x: [setsz, c, h, w] 104 | :param support_y: [sessz, c, h, w] 105 | :return: dummy loss and dummy pred 106 | """ 107 | loss, pred = self.net(support_x, support_y) 108 | return loss, pred 109 | 110 | 111 | class MetaLearner(nn.Module): 112 | """ 113 | As we have mentioned in Learner class, the metalearner class will receive a series of loss on different tasks/episodes 114 | on theta_pi network, and it will merage all loss and then sum over it. The summed loss will be backproped on theta 115 | network to update theta parameters, which is the initialization point we want to find. 116 | """ 117 | 118 | def __init__(self, net_cls, net_cls_args, n_way, k_shot, meta_batchsz, beta, num_updates): 119 | """ 120 | 121 | :param net_cls: class, not instance. the class of specific Network for learner 122 | :param net_cls_args: tuple, args for net_cls, like (n_way, imgsz) 123 | :param n_way: 124 | :param k_shot: 125 | :param meta_batchsz: number of tasks/episode 126 | :param beta: learning rate for meta-learner 127 | :param num_updates: number of updates for learner 128 | """ 129 | super(MetaLearner, self).__init__() 130 | 131 | self.n_way = n_way 132 | self.k_shot = k_shot 133 | self.meta_batchsz = meta_batchsz 134 | self.beta = beta 135 | # self.alpha = alpha # set alpha in Learner.optimizer directly. 136 | self.num_updates = num_updates 137 | 138 | # it will contains a learner class to learn on episodes and gather the loss together. 139 | self.learner = Learner(net_cls, *net_cls_args) 140 | # the optimizer is to update theta parameters, not theta_pi parameters. 141 | self.optimizer = optim.Adam(self.learner.parameters(), lr=beta) 142 | 143 | def write_grads(self, dummy_loss, sum_grads_pi): 144 | """ 145 | write loss into learner.net, gradients come from sum_grads_pi. 146 | Since the gradients info is not calculated by general backward, we need this function to write the right gradients 147 | into theta network and update theta parameters as wished. 148 | :param dummy_loss: dummy loss, nothing but to write our gradients by hook 149 | :param sum_grads_pi: the summed gradients 150 | :return: 151 | """ 152 | 153 | # Register a hook on each parameter in the net that replaces the current dummy grad 154 | # with our grads accumulated across the meta-batch 155 | hooks = [] 156 | 157 | for i, v in enumerate(self.learner.parameters()): 158 | def closure(): 159 | ii = i 160 | return lambda grad: sum_grads_pi[ii] 161 | 162 | # if you write: hooks.append( v.register_hook(lambda grad : sum_grads_pi[i]) ) 163 | # it will pop an ERROR, i don't know why? 164 | hooks.append(v.register_hook(closure())) 165 | 166 | # use our sumed gradients_pi to update the theta/net network, 167 | # since our optimizer receive the self.net.parameters() only. 168 | self.optimizer.zero_grad() 169 | dummy_loss.backward() 170 | self.optimizer.step() 171 | 172 | # if you do NOT remove the hook, the GPU memory will expode!!! 173 | for h in hooks: 174 | h.remove() 175 | 176 | def forward(self, support_x, support_y, query_x, query_y): 177 | """ 178 | Here we receive a series of episode, each episode will be learned by learner and get a loss on parameters theta. 179 | we gather the loss and sum all the loss and then update theta network. 180 | setsz = n_way * k_shotf 181 | querysz = n_way * k_shot 182 | :param support_x: [meta_batchsz, setsz, c_, h, w] 183 | :param support_y: [meta_batchsz, setsz] 184 | :param query_x: [meta_batchsz, querysz, c_, h, w] 185 | :param query_y: [meta_batchsz, querysz] 186 | :return: 187 | """ 188 | sum_grads_pi = None 189 | meta_batchsz = support_y.size(0) 190 | 191 | # support_x[i]: [setsz, c_, h, w] 192 | # we do different learning task sequentially, not parallel. 193 | accs = [] 194 | # for each task/episode. 195 | for i in range(meta_batchsz): 196 | _, grad_pi, episode_acc = self.learner(support_x[i], support_y[i], query_x[i], query_y[i], self.num_updates) 197 | accs.append(episode_acc) 198 | if sum_grads_pi is None: 199 | sum_grads_pi = grad_pi 200 | else: # accumulate all gradients from different episode learner 201 | sum_grads_pi = [torch.add(i, j) for i, j in zip(sum_grads_pi, grad_pi)] 202 | 203 | # As we already have the grads to update 204 | # We use a dummy forward / backward pass to get the correct grads into self.net 205 | # the right grads will be updated by hook, ignoring backward. 206 | # use hook mechnism to write sumed gradient into network. 207 | # we need to update the theta/net network, we need a op from net network, so we call self.learner.net_forward 208 | # to get the op from net network, since the loss from self.learner.forward will return loss from net_pi network. 209 | dummy_loss, _ = self.learner.net_forward(support_x[0], support_y[0]) 210 | self.write_grads(dummy_loss, sum_grads_pi) 211 | 212 | return accs 213 | 214 | def pred(self, support_x, support_y, query_x, query_y): 215 | """ 216 | predict for query_x 217 | :param support_x: 218 | :param support_y: 219 | :param query_x: 220 | :param query_y: 221 | :return: 222 | """ 223 | meta_batchsz = support_y.size(0) 224 | 225 | accs = [] 226 | # for each task/episode. 227 | # the learner will copy parameters from current theta network and then fine-tune on support set. 228 | for i in range(meta_batchsz): 229 | _, _, episode_acc = self.learner(support_x[i], support_y[i], query_x[i], query_y[i], self.num_updates) 230 | accs.append(episode_acc) 231 | 232 | return np.array(accs).mean() 233 | -------------------------------------------------------------------------------- /naive.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch import optim 4 | from torch import autograd 5 | from torch.autograd import Variable 6 | from torch.nn import functional as F 7 | 8 | 9 | class Naive(nn.Module): 10 | """ 11 | Define your network here. 12 | """ 13 | def __init__(self, n_way, imgsz): 14 | super(Naive, self).__init__() 15 | 16 | if imgsz > 28: # for mini-imagenet 17 | self.net = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3), 18 | nn.AvgPool2d(kernel_size=2), 19 | nn.BatchNorm2d(64), 20 | nn.ReLU(inplace=True), 21 | 22 | nn.Conv2d(64, 64, kernel_size=3), 23 | nn.AvgPool2d(kernel_size=2), 24 | nn.BatchNorm2d(64), 25 | nn.ReLU(inplace=True), 26 | 27 | nn.Conv2d(64, 64, kernel_size=3), 28 | nn.BatchNorm2d(64), 29 | nn.ReLU(inplace=True), 30 | 31 | nn.Conv2d(64, 64, kernel_size=3), 32 | nn.BatchNorm2d(64), 33 | nn.ReLU(inplace=True), 34 | 35 | nn.MaxPool2d(3,2) 36 | 37 | ) 38 | else: # for omniglot 39 | self.net = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3), 40 | nn.AvgPool2d(kernel_size=2), 41 | nn.BatchNorm2d(64), 42 | nn.ReLU(inplace=True), 43 | 44 | nn.Conv2d(64, 64, kernel_size=3), 45 | nn.AvgPool2d(kernel_size=2), 46 | nn.BatchNorm2d(64), 47 | nn.ReLU(inplace=True), 48 | 49 | nn.Conv2d(64, 64, kernel_size=3), 50 | nn.BatchNorm2d(64), 51 | nn.ReLU(inplace=True), 52 | 53 | nn.Conv2d(64, 64, kernel_size=3), 54 | nn.BatchNorm2d(64), 55 | nn.ReLU(inplace=True) 56 | 57 | ) 58 | 59 | # dummy forward to get feature size 60 | dummy_img = Variable(torch.randn(2, 3, imgsz, imgsz)) 61 | repsz = self.net(dummy_img).size() 62 | _, c, h, w = repsz 63 | self.fc_dim = c * h * w 64 | 65 | self.fc = nn.Sequential(nn.Linear(self.fc_dim, 64), 66 | nn.ReLU(inplace=True), 67 | nn.Linear(64, n_way)) 68 | 69 | self.criteon = nn.CrossEntropyLoss() 70 | 71 | print(self) 72 | print('Naive repnet sz:', repsz) 73 | 74 | def forward(self, x, target): 75 | x = self.net(x) 76 | x = x.view(-1, self.fc_dim) 77 | pred = self.fc(x) 78 | loss = self.criteon(pred, target) 79 | 80 | return loss, pred 81 | -------------------------------------------------------------------------------- /omniglot.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import os 3 | import os.path 4 | import errno 5 | 6 | 7 | class Omniglot(data.Dataset): 8 | urls = [ 9 | 'https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip', 10 | 'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip' 11 | ] 12 | raw_folder = 'raw' 13 | processed_folder = 'processed' 14 | training_file = 'training.pt' 15 | test_file = 'test.pt' 16 | 17 | ''' 18 | The items are (filename,category). The index of all the categories can be found in self.idx_classes 19 | Args: 20 | - root: the directory where the dataset will be stored 21 | - transform: how to transform the input 22 | - target_transform: how to transform the target 23 | - download: need to download the dataset 24 | ''' 25 | 26 | def __init__(self, root, transform=None, target_transform=None, download=False): 27 | self.root = root 28 | self.transform = transform 29 | self.target_transform = target_transform 30 | 31 | if not self._check_exists(): 32 | if download: 33 | self.download() 34 | else: 35 | raise RuntimeError('Dataset not found.' + ' You can use download=True to download it') 36 | 37 | self.all_items = find_classes(os.path.join(self.root, self.processed_folder)) 38 | self.idx_classes = index_classes(self.all_items) 39 | 40 | def __getitem__(self, index): 41 | filename = self.all_items[index][0] 42 | img = str.join('/', [self.all_items[index][2], filename]) 43 | 44 | target = self.idx_classes[self.all_items[index][1]] 45 | if self.transform is not None: 46 | img = self.transform(img) 47 | if self.target_transform is not None: 48 | target = self.target_transform(target) 49 | 50 | return img, target 51 | 52 | def __len__(self): 53 | return len(self.all_items) 54 | 55 | def _check_exists(self): 56 | return os.path.exists(os.path.join(self.root, self.processed_folder, "images_evaluation")) and \ 57 | os.path.exists(os.path.join(self.root, self.processed_folder, "images_background")) 58 | 59 | def download(self): 60 | from six.moves import urllib 61 | import zipfile 62 | 63 | if self._check_exists(): 64 | return 65 | 66 | # download files 67 | try: 68 | os.makedirs(os.path.join(self.root, self.raw_folder)) 69 | os.makedirs(os.path.join(self.root, self.processed_folder)) 70 | except OSError as e: 71 | if e.errno == errno.EEXIST: 72 | pass 73 | else: 74 | raise 75 | 76 | for url in self.urls: 77 | print('== Downloading ' + url) 78 | data = urllib.request.urlopen(url) 79 | filename = url.rpartition('/')[2] 80 | file_path = os.path.join(self.root, self.raw_folder, filename) 81 | with open(file_path, 'wb') as f: 82 | f.write(data.read()) 83 | file_processed = os.path.join(self.root, self.processed_folder) 84 | print("== Unzip from " + file_path + " to " + file_processed) 85 | zip_ref = zipfile.ZipFile(file_path, 'r') 86 | zip_ref.extractall(file_processed) 87 | zip_ref.close() 88 | print("Download finished.") 89 | 90 | 91 | def find_classes(root_dir): 92 | retour = [] 93 | for (root, dirs, files) in os.walk(root_dir): 94 | for f in files: 95 | if (f.endswith("png")): 96 | r = root.split('/') 97 | lr = len(r) 98 | retour.append((f, r[lr - 2] + "/" + r[lr - 1], root)) 99 | print("== Found %d items " % len(retour)) 100 | return retour 101 | 102 | 103 | def index_classes(items): 104 | idx = {} 105 | for i in items: 106 | if i[1] not in idx: 107 | idx[i[1]] = len(idx) 108 | print("== Found %d classes" % len(idx)) 109 | return idx 110 | -------------------------------------------------------------------------------- /omniglotNShot.py: -------------------------------------------------------------------------------- 1 | from omniglot import Omniglot 2 | import torchvision.transforms as transforms 3 | from PIL import Image 4 | import os.path 5 | import numpy as np 6 | 7 | 8 | class OmniglotNShot(): 9 | def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz): 10 | """ 11 | 12 | :param dataroot: 13 | :param batch_size: 14 | :param n_way: 15 | :param k_shot: 16 | """ 17 | 18 | self.resize = imgsz 19 | if not os.path.isfile(os.path.join(root, 'omni.npy')): 20 | # if root/data.npy does not exist, just download it 21 | self.x = Omniglot(root, download=True, 22 | transform=transforms.Compose([lambda x: Image.open(x).convert('L'), 23 | transforms.Resize(self.resize), 24 | lambda x: np.reshape(x, (self.resize, self.resize, 1))])) 25 | 26 | temp = dict() # {label:img1, img2..., 20 imgs in total, 1623 label} 27 | for (img, label) in self.x: 28 | if label in temp: 29 | temp[label].append(img) 30 | else: 31 | temp[label] = [img] 32 | 33 | self.x = [] 34 | for label, imgs in temp.items(): # labels info deserted , each label contains 20imgs 35 | self.x.append(np.array(imgs)) 36 | 37 | # as different class may have different number of imgs 38 | self.x = np.array(self.x) # [[20 imgs],..., 1623 classes in total] 39 | # each character contains 20 imgs 40 | print('dataset shape:', self.x.shape) # [1623, 20, 84, 84, 1] 41 | temp = [] # Free memory 42 | # save all dataset into npy file. 43 | np.save(os.path.join(root, 'omni.npy'), self.x) 44 | else: 45 | # if data.npy exists, just load it. 46 | self.x = np.load(os.path.join(root, 'omni.npy')) 47 | 48 | self.x = self.x / 255 49 | # self.x: [1623, shuffled, 20 imgs, 84, 84, 1] 50 | np.random.shuffle(self.x) # shuffle on the first dim = 1623 cls 51 | 52 | self.x_train, self.x_test = self.x[:1200], self.x[1200:] 53 | self.normalization() 54 | 55 | self.batchsz = batchsz 56 | self.n_cls = self.x.shape[0] # 1623 57 | self.n_way = n_way # n way 58 | self.k_shot = k_shot # k shot 59 | self.k_query = k_query # k query 60 | 61 | # save pointer of current read batch in total cache 62 | self.indexes = {"train": 0, "test": 0} 63 | self.datasets = {"train": self.x_train, "test": self.x_test} # original data cached 64 | print("train_shape", self.x_train.shape, "test_shape", self.x_test.shape) 65 | 66 | self.datasets_cache = {"train": self.load_data_cache(self.datasets["train"]), # current epoch data cached 67 | "test": self.load_data_cache(self.datasets["test"])} 68 | 69 | def normalization(self): 70 | """ 71 | Normalizes our data, to have a mean of 0 and sdt of 1 72 | """ 73 | self.mean = np.mean(self.x_train) 74 | self.std = np.std(self.x_train) 75 | self.max = np.max(self.x_train) 76 | self.min = np.min(self.x_train) 77 | print("before norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std) 78 | self.x_train = (self.x_train - self.mean) / self.std 79 | self.x_test = (self.x_test - self.mean) / self.std 80 | 81 | self.mean = np.mean(self.x_train) 82 | self.std = np.std(self.x_train) 83 | self.max = np.max(self.x_train) 84 | self.min = np.min(self.x_train) 85 | print("after norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std) 86 | 87 | def load_data_cache(self, data_pack): 88 | """ 89 | Collects several batches data for N-shot learning 90 | :param data_pack: [cls_num, 20, 84, 84, 1] 91 | :return: A list with [support_set_x, support_set_y, target_x, target_y] ready to be fed to our networks 92 | """ 93 | # take 5 way 1 shot as example: 5 * 1 94 | setsz = self.k_shot * self.n_way 95 | querysz = self.k_query * self.n_way 96 | data_cache = [] 97 | 98 | # print('preload next 50 caches of batchsz of batch.') 99 | for sample in range(50): # num of episodes 100 | # (batch, setsz, imgs) 101 | support_x = np.zeros((self.batchsz, setsz, self.resize, self.resize, 1)) 102 | # (batch, setsz) 103 | support_y = np.zeros((self.batchsz, setsz), dtype=np.int) 104 | # (batch, querysz, imgs) 105 | query_x = np.zeros((self.batchsz, querysz, self.resize, self.resize, 1)) 106 | # (batch, querysz) 107 | query_y = np.zeros((self.batchsz, querysz), dtype=np.int) 108 | 109 | for i in range(self.batchsz): # one batch means one set 110 | shuffle_idx = np.arange(self.n_way) # [0,1,2,3,4] 111 | np.random.shuffle(shuffle_idx) # [2,4,1,0,3] 112 | shuffle_idx_test = np.arange(self.n_way) # [0,1,2,3,4] 113 | np.random.shuffle(shuffle_idx_test) # [2,0,1,4,3] 114 | selected_cls = np.random.choice(data_pack.shape[0], self.n_way, False) 115 | 116 | for j, cur_class in enumerate(selected_cls): # for each selected cls 117 | # Count number of times this class is inside the meta-test 118 | # [img1, img2 ,,, = k_shot + k_query ] 119 | selected_imgs = np.random.choice(data_pack.shape[1], self.k_shot + self.k_query, False) 120 | 121 | # meta-training, select the first k_shot imgs for each class as support imgs 122 | for offset, img in enumerate(selected_imgs[:self.k_shot]): 123 | # i: batch idx 124 | # cur_class: cls in n_way 125 | support_x[i, shuffle_idx[j] * self.k_shot + offset, ...] = data_pack[cur_class][img] 126 | support_y[i, shuffle_idx[j] * self.k_shot + offset] = j # relative indexing 127 | 128 | # meta-test, treat following k_query imgs as query imgs 129 | for offset, img in enumerate(selected_imgs[self.k_shot:]): 130 | query_x[i, shuffle_idx_test[j] * self.k_query + offset, ...] = data_pack[cur_class][img] 131 | query_y[i, shuffle_idx_test[j] * self.k_query + offset] = j # relative indexing 132 | 133 | data_cache.append([support_x, support_y, query_x, query_y]) 134 | return data_cache 135 | 136 | def __get_batch(self, mode): 137 | """ 138 | Gets next batch from the dataset with name. 139 | :param dataset_name: The name of the dataset (one of "train", "val", "test") 140 | :return: 141 | """ 142 | # update cache if indexes is larger cached num 143 | if self.indexes[mode] >= len(self.datasets_cache[mode]): 144 | self.indexes[mode] = 0 145 | self.datasets_cache[mode] = self.load_data_cache(self.datasets[mode]) 146 | 147 | next_batch = self.datasets_cache[mode][self.indexes[mode]] 148 | self.indexes[mode] += 1 149 | 150 | return next_batch 151 | 152 | def get_batch(self, mode): 153 | 154 | """ 155 | Get next batch 156 | :return: Next batch 157 | """ 158 | x_support_set, y_support_set, x_target, y_target = self.__get_batch(mode) 159 | 160 | k = int(np.random.uniform(low=0, high=4)) # 0 - 3 161 | # Iterate over the sequence. Extract batches. 162 | 163 | for i in np.arange(x_support_set.shape[0]): 164 | # batchsz, setsz, c, h, w 165 | x_support_set[i, :, :, :, :] = self.__rotate_batch(x_support_set[i, :, :, :, :], k) 166 | 167 | # Rotate all the batch of the target images 168 | for i in np.arange(x_target.shape[0]): 169 | x_target[i, :, :, :, :] = self.__rotate_batch(x_target[i, :, :, :, :], k) 170 | 171 | return x_support_set, y_support_set, x_target, y_target 172 | 173 | def __rotate_batch(self, batch_images, k): 174 | """ 175 | Rotates a whole image batch 176 | :param batch_images: A batch of images 177 | :param k: integer degree of rotation counter-clockwise 178 | :return: The rotated batch of images 179 | """ 180 | batch_size = len(batch_images) 181 | for i in np.arange(batch_size): 182 | batch_images[i] = np.rot90(batch_images[i], k) 183 | return batch_images 184 | 185 | 186 | if __name__ == '__main__': 187 | # the following episode is to view one set of images via tensorboard. 188 | from torchvision.utils import make_grid 189 | from matplotlib import pyplot as plt 190 | from tensorboardX import SummaryWriter 191 | import time 192 | import torch 193 | 194 | plt.ion() 195 | 196 | tb = SummaryWriter('runs', 'mini-imagenet') 197 | db = OmniglotNShot('dataset', batchsz=20, n_way=5, k_shot=5, k_query=2) 198 | 199 | set_ = db.get_batch('train') 200 | while set_ != None: 201 | # support_x: [k_shot*n_way, 3, 84, 84] 202 | support_x, support_y, query_x, query_y = set_ 203 | print(support_y[0]) 204 | print(query_y[0]) 205 | # [b, setsz, h, w, c] => [b, setsz, c, w, h] => [b, setsz, 3c, w, h] 206 | support_x = torch.from_numpy(support_x).float().transpose(2, 4).repeat(1, 1, 3, 1, 1) 207 | query_x = torch.from_numpy(query_x).float().transpose(2, 4).repeat(1, 1, 3, 1, 1) 208 | support_y = torch.from_numpy(support_y).float() # [batch, setsz, 1] 209 | query_y = torch.from_numpy(query_y).float() 210 | batchsz, setsz, c, h, w = support_x.size() 211 | 212 | support_x = make_grid(support_x[0], nrow=5) 213 | query_x = make_grid(query_x[0], nrow=2) 214 | 215 | plt.figure('support x') 216 | plt.imshow(support_x.transpose(2, 0).transpose(1, 0).numpy()) 217 | plt.pause(0.5) 218 | plt.figure('query x') 219 | plt.imshow(query_x.transpose(2, 0).transpose(1, 0).numpy()) 220 | plt.pause(0.5) 221 | 222 | tb.add_image('support_x', support_x) 223 | tb.add_image('query_x', query_x) 224 | 225 | set_ = db.get_batch('train') 226 | time.sleep(10) 227 | 228 | tb.close() 229 | -------------------------------------------------------------------------------- /res/test-acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dragen1860/Reptile-Pytorch/0bd5d777c167fa1aea66b5dc325a7fc4e56632ab/res/test-acc.png --------------------------------------------------------------------------------