├── .gitignore
├── README.md
├── main.py
├── meta.py
├── omniglot.py
└── omniglotNShot.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.ipynb_checkpoints
__pycache__
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MAML-Pytorch
PyTorch implementation of the supervised learning experiments from the paper
Model-Agnostic Meta-Learning (MAML): https://arxiv.org/abs/1703.03400

This implementation refers to these repositories: https://github.com/dragen1860/MAML-Pytorch and https://github.com/lim0606/MAML-Pytorch

# Platform

- Python: 3.7.3
- PyTorch: 1.2.0

# Omniglot

Run ```python main.py```; the program will download the Omniglot dataset automatically.
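
# Algorithm

Each episode performs MAML's two-level update: an inner loop adapts a copy `phi` of the shared initialization `theta` on a task's support set, then an outer step moves `theta` using the query-set gradients of the adapted copies (this implementation sums those gradients over the meta-batch and applies them to `theta` with Adam via gradient hooks; see `meta.py`). A minimal runnable sketch of the same scheme on a toy scalar task (the quadratic loss and the task targets are made up for illustration):

```python
import torch

theta = torch.tensor(0.0, requires_grad=True)   # shared initialization
alpha, beta, num_updates = 0.1, 1e-3, 5         # same roles as in main.py

def task_loss(w, target):                       # stand-in for one task's loss
    return (w - target) ** 2

meta_grad = 0.0
for target in [1.0, -1.0, 2.0]:                 # a tiny "meta-batch" of tasks
    phi = theta.detach().clone().requires_grad_(True)      # copy theta -> phi
    for _ in range(num_updates):                # inner loop: adapt phi
        g, = torch.autograd.grad(task_loss(phi, target), phi)
        phi = (phi - alpha * g).detach().requires_grad_(True)
    g, = torch.autograd.grad(task_loss(phi, target), phi)  # query-set gradient
    meta_grad = meta_grad + g                   # summed across tasks
with torch.no_grad():
    theta -= beta * meta_grad                   # outer update (SGD here; the repo uses Adam)
```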
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from omniglotNShot import OmniglotNShot
from meta import MetaLearner

import torch
import torch.nn as nn
import numpy as np


class Net(nn.Module):
    def __init__(self, n_way, img_size):
        super(Net, self).__init__()

        self.net = nn.Sequential(nn.Conv2d(1, 64, kernel_size=3),
                                 nn.AvgPool2d(kernel_size=2),
                                 nn.BatchNorm2d(64),
                                 nn.ReLU(inplace=True),

                                 nn.Conv2d(64, 64, kernel_size=3),
                                 nn.AvgPool2d(kernel_size=2),
                                 nn.BatchNorm2d(64),
                                 nn.ReLU(inplace=True),

                                 nn.Conv2d(64, 64, kernel_size=3),
                                 nn.BatchNorm2d(64),
                                 nn.ReLU(inplace=True),

                                 nn.Conv2d(64, 64, kernel_size=3),
                                 nn.BatchNorm2d(64),
                                 nn.ReLU(inplace=True))

        self.fc = nn.Sequential(nn.Linear(64, 64),
                                nn.ReLU(inplace=True),
                                nn.Linear(64, n_way))

        self.loss = nn.CrossEntropyLoss()

    def forward(self, x, target):
        # x: [5, 1, 28, 28] for 5-way 1-shot
        x = self.net(x)
        x = x.view(-1, 64)
        pred = self.fc(x)
        loss = self.loss(pred, target)

        return loss, pred
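
# Shape check for the stack above, assuming img_size=28 as in main() (this is
# why x.view(-1, 64) works): 28 -conv3-> 26 -pool2-> 13 -conv3-> 11 -pool2-> 5
# -conv3-> 3 -conv3-> 1, leaving a [N, 64, 1, 1] feature map. Note that Net
# never actually uses img_size; a different input size would break the
# hard-coded view(-1, 64).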

def main():
    meta_batch_size = 32
    n_way = 5
    k_shot = 1
    k_query = 1
    meta_lr = 1e-3
    num_updates = 5

    img_size = 28
    omni_data = OmniglotNShot('dataset', batch_size=meta_batch_size, n_way=n_way,
                              k_shot=k_shot, k_query=k_query, img_size=img_size)

    meta = MetaLearner(Net, (n_way, img_size), n_way=n_way, k_shot=k_shot, meta_batch_size=meta_batch_size,
                       alpha=0.1, beta=meta_lr, num_updates=num_updates).cuda()

    for episode_num in range(100):
        support_x, support_y, query_x, query_y = omni_data.get_batch('train')  # support/query sets for training
        # support_x: [32, 5, 1, 28, 28]
        support_x = torch.from_numpy(support_x).float().cuda()
        query_x = torch.from_numpy(query_x).float().cuda()
        support_y = torch.from_numpy(support_y).long().cuda()
        query_y = torch.from_numpy(query_y).long().cuda()
        accs = meta(support_x, support_y, query_x, query_y)
        train_acc = np.array(accs).mean()

        if episode_num % 30 == 0:
            test_accs = []

            support_x, support_y, query_x, query_y = omni_data.get_batch('test')  # support/query sets for testing
            support_x = torch.from_numpy(support_x).float().cuda()
            query_x = torch.from_numpy(query_x).float().cuda()
            support_y = torch.from_numpy(support_y).long().cuda()
            query_y = torch.from_numpy(query_y).long().cuda()

            test_acc = meta.pred(support_x, support_y, query_x, query_y)
            test_accs.append(test_acc)

            test_acc = np.array(test_accs).mean()
            print('episode:', episode_num, '\tfinetune acc:%.6f' % train_acc, '\t\ttest acc:%.6f' % test_acc)

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/meta.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch import optim
from torch import autograd
import numpy as np

class Learner(nn.Module):

    def __init__(self, net, alpha, *args):
        super(Learner, self).__init__()
        self.alpha = alpha

        self.net_theta = net(*args)  # theta: the shared prior / general initialization
        self.net_phi = net(*args)    # phi: the task-specific copy
        self.optimizer = optim.SGD(self.net_phi.parameters(), self.alpha)  # inner-loop optimizer (updates the task-specific phi)

    def forward(self, support_x, support_y, query_x, query_y, num_updates):
        # To get phi from the current theta (fine-tuning),
        # first copy theta's weights into phi.

        with torch.no_grad():
            for theta, phi in zip(self.net_theta.modules(), self.net_phi.modules()):
                if isinstance(phi, nn.Linear) or isinstance(phi, nn.Conv2d) or isinstance(phi, nn.BatchNorm2d):
                    phi.weight.data = theta.weight.clone()  # .clone() is required here
                    if phi.bias is not None:
                        phi.bias.data = theta.bias.clone()
        # clone() copies the data to separate memory without interfering with gradient backpropagation (cf. deepcopy)

        # support_x: [5, 1, 28, 28]
        for i in range(num_updates):
            loss, pred = self.net_phi(support_x, support_y)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        # Calculate the meta-gradient:
        # the adapted phi net's gradient on the query set, which the meta-learner uses to update theta
        loss, pred = self.net_phi(query_x, query_y)
        # pred: [dataset_size, n_way], here (5, 5)
        _, indices = torch.max(pred, dim=1)
        correct = torch.eq(indices, query_y).sum().item()
        acc = correct / query_y.size(0)

        # create_graph=True keeps the graph alive so backward can be called again after autograd.grad (needed for second derivatives / the Hessian)
        gradient_phi = autograd.grad(loss, self.net_phi.parameters(), create_graph=True)

        return loss, gradient_phi, acc

    def net_forward(self, support_x, support_y):
        # Forward pass through net_theta (the general net); the meta-learner uses
        # this to write the merged gradients into net_theta's parameters.
        loss, pred = self.net_theta(support_x, support_y)
        return loss, pred

class MetaLearner(nn.Module):
    # Collects the query-set losses/gradients that net_phi produces on the various
    # tasks and folds them into one general initialization parameter, theta.
    # theta is updated once per episode, using phi's gradients on the query (meta-test) set.

    def __init__(self, net, net_args, n_way, k_shot, meta_batch_size, alpha, beta, num_updates):
        super(MetaLearner, self).__init__()

        self.n_way = n_way
        self.k_shot = k_shot
        self.meta_batch_size = meta_batch_size
        self.beta = beta
        self.num_updates = num_updates

        self.learner = Learner(net, alpha, *net_args)
        self.optimizer = optim.Adam(self.learner.parameters(), lr=beta)

    def meta_update(self, dummy_loss, sum_grads_phi):
        # Update theta's parameters with the summed gradients.
        hooks = []
        for k, v in enumerate(self.learner.parameters()):
            def closure():
                key = k
                return lambda grad: sum_grads_phi[key]

            hooks.append(v.register_hook(closure()))
            # register_hook: if a hook returns a tensor, that tensor replaces the
            # computed gradient, so the optimizer updates the parameters with these
            # custom gradients instead of the ones backward produced.

        self.optimizer.zero_grad()
        dummy_loss.backward()  # dummy_loss only builds the graph; the hooks swap its gradients for sum_grads_phi (the summed phi gradients, applied to the general theta network)
        self.optimizer.step()

        for h in hooks:
            h.remove()

    def forward(self, support_x, support_y, query_x, query_y):
        # Run the Learner on every task in the episode, then combine the resulting
        # query-set gradients into one update for theta.

        sum_grads_phi = None
        meta_batch_size = support_y.size(0)  # 32

        accs = []
        for i in range(meta_batch_size):
            _, grad_phi, episode_acc = self.learner(support_x[i], support_y[i], query_x[i], query_y[i], self.num_updates)
            accs.append(episode_acc)
            if sum_grads_phi is None:
                sum_grads_phi = grad_phi
            else:
                sum_grads_phi = [torch.add(i, j) for i, j in zip(sum_grads_phi, grad_phi)]  # elementwise sum across tasks

        dummy_loss, _ = self.learner.net_forward(support_x[0], support_y[0])
        # support_x[0]: [5, 1, 28, 28]
        self.meta_update(dummy_loss, sum_grads_phi)

        return accs

    def pred(self, support_x, support_y, query_x, query_y):
        meta_batch_size = support_y.size(0)
        accs = []

        for i in range(meta_batch_size):
            _, _, episode_acc = self.learner(support_x[i], support_y[i], query_x[i], query_y[i], self.num_updates)
            accs.append(episode_acc)

        return np.array(accs).mean()
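
# A minimal, self-contained sketch of the register_hook mechanism that
# meta_update relies on (the parameter and gradient values below are made up
# for illustration): a hook that returns a tensor replaces the computed
# gradient before the optimizer ever sees it.
def _register_hook_demo():
    p = nn.Parameter(torch.ones(3))
    custom_grad = torch.tensor([10., 20., 30.])
    h = p.register_hook(lambda grad: custom_grad)
    p.sum().backward()                        # computed grad would be [1, 1, 1] ...
    assert torch.equal(p.grad, custom_grad)   # ... but the hook replaced it
    h.remove()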
--------------------------------------------------------------------------------
/omniglot.py:
--------------------------------------------------------------------------------
import torch.utils.data as data
import os
import os.path
import errno


class Omniglot(data.Dataset):
    urls = [
        'https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip',
        'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip'
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'

    '''
    The items are (filename, category). The index of each category can be found in self.idx_classes.
    Args:
    - root: the directory where the dataset will be stored
    - transform: how to transform the input
    - target_transform: how to transform the target
    - download: whether to download the dataset if it is not found
    '''

    def __init__(self, root, transform=None, target_transform=None, download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform

        if not self._check_exists():
            if download:
                self.download()
            else:
                raise RuntimeError('Dataset not found. You can use download=True to download it')

        self.all_items = find_classes(os.path.join(self.root, self.processed_folder))
        self.idx_classes = index_classes(self.all_items)

    def __getitem__(self, index):
        filename = self.all_items[index][0]
        img = str.join('/', [self.all_items[index][2], filename])

        target = self.idx_classes[self.all_items[index][1]]
        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.all_items)

    def _check_exists(self):
        return os.path.exists(os.path.join(self.root, self.processed_folder, "images_evaluation")) and \
               os.path.exists(os.path.join(self.root, self.processed_folder, "images_background"))

    def download(self):
        from six.moves import urllib
        import zipfile

        if self._check_exists():
            return

        # download files
        try:
            os.makedirs(os.path.join(self.root, self.raw_folder))
            os.makedirs(os.path.join(self.root, self.processed_folder))
        except OSError as e:
            if e.errno == errno.EEXIST:
                pass
            else:
                raise

        for url in self.urls:
            print('== Downloading ' + url)
            data = urllib.request.urlopen(url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(self.root, self.raw_folder, filename)
            with open(file_path, 'wb') as f:
                f.write(data.read())
            file_processed = os.path.join(self.root, self.processed_folder)
            print("== Unzip from " + file_path + " to " + file_processed)
            zip_ref = zipfile.ZipFile(file_path, 'r')
            zip_ref.extractall(file_processed)
            zip_ref.close()
        print("Download finished.")


def find_classes(root_dir):
    # Walk the extracted folders and collect every png as
    # (filename, "alphabet/character", full_dir_path).
    retour = []
    for (root, dirs, files) in os.walk(root_dir):
        for f in files:
            if f.endswith("png"):
                r = root.split('/')
                lr = len(r)
                retour.append((f, r[lr - 2] + "/" + r[lr - 1], root))
    print("== Found %d items " % len(retour))
    return retour


def index_classes(items):
    # Map each "alphabet/character" string to a unique integer class index.
    idx = {}
    for i in items:
        if i[1] not in idx:
            idx[i[1]] = len(idx)
    print("== Found %d classes" % len(idx))
    return idx
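
# Minimal usage sketch ('dataset' matches the root main.py uses; note that with
# no transform, __getitem__ returns the image *path* string, not pixel data):
#
#     ds = Omniglot('dataset', download=True)
#     img_path, target = ds[0]   # a png path and its integer class index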
--------------------------------------------------------------------------------
/omniglotNShot.py:
--------------------------------------------------------------------------------
from omniglot import Omniglot

import torchvision
import torchvision.transforms as transforms
from PIL import Image
import os.path
import numpy as np

class OmniglotNShot():
    def __init__(self, root, batch_size, n_way, k_shot, k_query, img_size):
        self.resize = img_size

        if not os.path.isfile(os.path.join(root, 'omni.npy')):
            self.x = Omniglot(root, download=True,
                              transform=transforms.Compose([lambda x: Image.open(x).convert('L'),
                                                            transforms.Resize(self.resize),
                                                            lambda x: np.reshape(x, (self.resize, self.resize, 1)),
                                                            lambda x: x / 255.,
                                                            lambda x: np.transpose(x, [2, 0, 1]),
                                                            ]))

            temp = dict()  # {label: [img1, img2, ...]}, 20 imgs per label, 1623 labels in total
            # len(self.x) == 32460 == 20 * 1623
            for (img, label) in self.x:
                if label in temp:
                    # img: [1, 28, 28]
                    temp[label].append(img)
                else:
                    temp[label] = [img]

            self.x = []
            for label, imgs in temp.items():
                self.x.append(np.array(imgs))

            self.x = np.array(self.x)  # [[20 imgs], ..., 1623 classes in total]

            temp = []  # free memory
            np.save(os.path.join(root, 'omni.npy'), self.x)

        else:
            self.x = np.load(os.path.join(root, 'omni.npy'))

        # x: [1623, 20, 1, 28, 28]
        np.random.shuffle(self.x)  # shuffle along the first dim (the 1623 classes)
        self.x_train, self.x_test = self.x[:1200], self.x[1200:]

        # normalization
        self.x_train = (self.x_train - np.mean(self.x_train)) / np.std(self.x_train)
        self.x_test = (self.x_test - np.mean(self.x_test)) / np.std(self.x_test)

        self.batch_size = batch_size
        self.n_class = self.x.shape[0]  # 1623
        self.n_way = n_way
        self.k_shot = k_shot
        self.k_query = k_query

        self.indexes = {"train": 0, "test": 0}
        self.datasets = {"train": self.x_train, "test": self.x_test}

        self.datasets_cache = {"train": self.load_data_cache(self.datasets["train"]),
                               "test": self.load_data_cache(self.datasets["test"])}
        # len(self.datasets_cache['train']): 100
        # self.datasets_cache['train'][0][0]: [32, 5, 1, 28, 28]
        # self.datasets_cache['train'][0][1]: [32, 5]

    def load_data_cache(self, data_pack):
        """
        Collects several batches of data for N-shot learning.
        data_pack: [class_num, 20, 1, 28, 28]  # class_num: 1200 for train, 423 for test
        return: a list of [support_x, support_y, query_x, query_y] episodes ready to be fed to the network
        """

        dataset_size = self.k_shot * self.n_way
        query_size = self.k_query * self.n_way
        data_cache = []

        for sample in range(100):  # number of episodes

            support_x = np.zeros((self.batch_size, dataset_size, 1, self.resize, self.resize))  # [32, 5, 1, 28, 28]
            support_y = np.zeros((self.batch_size, dataset_size), dtype=int)
            query_x = np.zeros((self.batch_size, query_size, 1, self.resize, self.resize))  # [32, 5, 1, 28, 28]
            query_y = np.zeros((self.batch_size, query_size), dtype=int)

            for i in range(self.batch_size):
                shuffle_idx = np.arange(self.n_way)       # e.g. [0, 1, 2, 3, 4]
                np.random.shuffle(shuffle_idx)            # e.g. [2, 4, 1, 0, 3]
                shuffle_idx_test = np.arange(self.n_way)  # e.g. [0, 1, 2, 3, 4]
                np.random.shuffle(shuffle_idx_test)       # e.g. [2, 0, 1, 4, 3]

                selected_cls = np.random.choice(data_pack.shape[0], self.n_way, replace=False)
                for j, cur_class in enumerate(selected_cls):
                    # cur_class indexes a row (character class) of data_pack
                    selected_imgs = np.random.choice(data_pack.shape[1], self.k_shot + self.k_query, replace=False)  # select k_shot + k_query out of the 20 images

                    # Divide the selected images into a support set and a query set
                    # support_set for meta-training
                    for offset, img in enumerate(selected_imgs[:self.k_shot]):
                        # i: batch index; shuffle_idx[j]: slot for class j within the n_way classes
                        support_x[i, shuffle_idx[j] * self.k_shot + offset, ...] = data_pack[cur_class][img]
                        support_y[i, shuffle_idx[j] * self.k_shot + offset] = j

                    # query_set for meta-testing
                    for offset, img in enumerate(selected_imgs[self.k_shot:]):
                        query_x[i, shuffle_idx_test[j] * self.k_query + offset, ...] = data_pack[cur_class][img]
                        query_y[i, shuffle_idx_test[j] * self.k_query + offset] = j

            data_cache.append([support_x, support_y, query_x, query_y])
        return data_cache
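
    # A worked example of the slot layout above (values illustrative), for
    # n_way=5, k_shot=1: if shuffle_idx = [2, 4, 1, 0, 3], then class j=0's
    # support image lands in slot 2 with label 0, class j=1's in slot 4 with
    # label 1, and so on. The shuffle scrambles positions so that label j never
    # correlates with a fixed slot in the support set (shuffle_idx_test does
    # the same, independently, for the query set).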

    def get_batch(self, mode):
        # mode: 'train' or 'test'
        # Get the next batch from the named dataset, refreshing the episode cache
        # once it has been exhausted.

        if self.indexes[mode] >= len(self.datasets_cache[mode]):
            self.indexes[mode] = 0
            self.datasets_cache[mode] = self.load_data_cache(self.datasets[mode])

        next_batch = self.datasets_cache[mode][self.indexes[mode]]
        self.indexes[mode] += 1

        return next_batch
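
# A quick smoke test of the batching pipeline (a sketch; Omniglot is downloaded
# and converted on first use):
#
#     db = OmniglotNShot('dataset', batch_size=32, n_way=5, k_shot=1, k_query=1, img_size=28)
#     support_x, support_y, query_x, query_y = db.get_batch('train')
#     print(support_x.shape, support_y.shape)   # (32, 5, 1, 28, 28) (32, 5)
--------------------------------------------------------------------------------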