├── __init__.py ├── readme.md ├── maml.py ├── data_loader.py ├── util.py ├── train.py └── data_generator.py /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | #### MAML 的粗糙版实现,可能很多地方实现的不好。请大家批评~ 2 | 3 | 4 | 5 | #### 知乎链接 https://zhuanlan.zhihu.com/p/84277881 , 有问题可以去知乎留言,github留言可能回复不及时。 6 | 7 | 8 | 9 | 10 | 11 | @author Miao_txy -------------------------------------------------------------------------------- /maml.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # from torchsummary import summary 5 | 6 | 7 | class Maml(nn.Module): 8 | def __init__(self): 9 | super(Maml, self).__init__() 10 | self.conv_layer = nn.Sequential( 11 | nn.BatchNorm2d(1), 12 | nn.ReLU(), 13 | nn.Conv2d(1, 64, 3, padding=1), 14 | nn.BatchNorm2d(64), 15 | nn.ReLU(), 16 | nn.Conv2d(64, 64, 3, padding=1), 17 | nn.BatchNorm2d(64), 18 | nn.ReLU(), 19 | nn.Conv2d(64, 64, 3, padding=1), 20 | nn.BatchNorm2d(64), 21 | nn.ReLU(), 22 | nn.Conv2d(64, 64, 3, padding=1), 23 | ) 24 | self.linear_layer = nn.Sequential( 25 | nn.Linear(64 * 28 * 28, 1024), 26 | nn.Sigmoid(), 27 | nn.Dropout(0.5), 28 | nn.Linear(1024, 5), 29 | ) 30 | 31 | def forward(self, x): 32 | x = self.conv_layer(x) 33 | x = x.view(-1, 64 * 28 * 28) 34 | x = self.linear_layer(x) 35 | return x 36 | 37 | 38 | if __name__ == "__main__": 39 | 40 | maml = Maml() 41 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 42 | 43 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from PIL import Image 4 | import torchvision.transforms as transforms 5 | 6 | 7 | class Data_loader(): 8 | 9 | def __init__(self, path, train, transform): 10 | self.path = path 11 | self.train = train 12 | self.transform = transform 13 | self.load_image() 14 | 15 | def _load_image(self): 16 | self.characters = [] 17 | for (root, dir, files) in os.walk(self.path): 18 | images = [] 19 | for file in files: 20 | if file.endswith(".png"): 21 | file = root + "\\" + file 22 | image = Image.open(file) 23 | image = self.transform(image) 24 | images.append(image) 25 | if len(images) != 0: 26 | self.characters.append(images) 27 | 28 | def load_image(self): 29 | file_name = self.path + [r"\val", r"\train"][self.train] + ".pickle" 30 | if self._exsit_image_pickle() and os.path.getsize(file_name) > 0: 31 | f = open(file_name, "rb") 32 | self.characters = pickle.load(f) 33 | else: 34 | f = open(file_name, "wb") 35 | self._load_image() 36 | pickle.dump(self.characters, f) 37 | 38 | def _exsit_image_pickle(self): 39 | file_name = self.path + [r"\val", r"\train"][self.train] + ".pickle" 40 | return os.path.exists(file_name) 41 | 42 | 43 | if __name__ == "__main__": 44 | 45 | path = r"C:\Users\Miao_\Desktop\my_maml\omniglot" 46 | transform = transforms.Compose([ 47 | transforms.Resize(28), 48 | transforms.ToTensor(), 49 | transforms.Normalize(mean=[0.5], std=[0.5]), 50 | ]) 51 | Ominiglot = Data_loader(path=path, train=True, transform=transform) 52 | characters = Ominiglot.characters 53 | print(len(characters)) 54 | print(len(characters[0])) 55 | for idx,i in enumerate(characters): 56 | if len(i) == 0: 57 | print(idx) 58 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch.nn as nn 3 | from collections import OrderedDict 4 | 5 | _internal_attrs = {'_backend', '_parameters', '_buffers', '_backward_hooks', '_forward_hooks', '_forward_pre_hooks', '_modules'} 6 | 7 | 8 | class Scope(object): 9 | def __init__(self): 10 | self._modules = OrderedDict() 11 | 12 | 13 | def _make_functional(module, params_box, params_offset): 14 | self = Scope() 15 | num_params = len(module._parameters) 16 | param_names = list(module._parameters.keys()) 17 | forward = type(module).forward 18 | if isinstance(module, nn.Conv2d): 19 | setattr(self, "conv2d_forward", module.conv2d_forward) 20 | if isinstance(module, nn.BatchNorm2d): 21 | setattr(self, "_check_input_dim", module._check_input_dim) 22 | setattr(self, "num_batches_tracked", module.num_batches_tracked) 23 | setattr(self, "running_mean", module.running_mean) 24 | setattr(self, "running_var", module.running_var) 25 | 26 | for name, attr in module.__dict__.items(): 27 | if name in _internal_attrs: 28 | continue 29 | setattr(self, name, attr) 30 | 31 | child_params_offset = params_offset + num_params 32 | for name, child in module.named_children(): 33 | child_params_offset, fchild = _make_functional(child, params_box, child_params_offset) 34 | self._modules[name] = fchild 35 | setattr(self, name, fchild) 36 | 37 | def fmodule(*args, **kwargs): 38 | for name, param in zip(param_names, params_box[0][params_offset:params_offset + num_params]): 39 | setattr(self, name, param) 40 | return forward(self, *args, **kwargs) 41 | 42 | return child_params_offset, fmodule 43 | 44 | 45 | def make_functional(module): 46 | params_box = [None] 47 | _, fmodule_internal = _make_functional(module, params_box, 0) 48 | 49 | def fmodule(*args, **kwargs): 50 | params_box[0] = kwargs.pop('params') 51 | return fmodule_internal(*args, **kwargs) 52 | 53 | return fmodule 54 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from numpy.random import choice as choice 4 | from maml import Maml 5 | from data_generator import Data_gernerator 6 | from numpy.random import shuffle 7 | from torch.autograd import grad 8 | from util import make_functional 9 | 10 | 11 | if __name__ == "__main__": 12 | 13 | config = {} 14 | config["path"] = r"C:\Users\Miao_\Desktop\my_maml\omniglot" 15 | config["num_epoches"] = 100000 16 | config["task_batch"] = 2 17 | config["support_num"] = 1 18 | config["query_num"] = 15 19 | config["nways"] = 5 20 | 21 | in_lr = 0.01 22 | meta_lr = 0.001 23 | 24 | data_gen = Data_gernerator(**config) 25 | 26 | def train(): 27 | 28 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 29 | model = Maml().to(device=device) 30 | 31 | parameters = list(model.parameters()) 32 | optimizer = torch.optim.Adam(parameters, lr=meta_lr) 33 | criterion = torch.nn.CrossEntropyLoss() 34 | 35 | for idx, batch_data in enumerate(data_gen.generator): 36 | 37 | loss = 0 38 | acc = 0 39 | 40 | for task_data in batch_data: 41 | 42 | s, s_y, q, q_y = task_data 43 | 44 | s_res = model(s) 45 | s_loss = criterion(s_res, s_y) 46 | s_grad = grad(s_loss, model.parameters(), create_graph=True) 47 | 48 | fast_weights = list(map(lambda p: p[1] - in_lr * p[0], zip(s_grad, parameters))) 49 | f_model = make_functional(model) 50 | 51 | q_res = f_model(q, params=fast_weights) 52 | q_loss = criterion(q_res, q_y) 53 | loss = q_loss if loss == 0 else loss + q_loss 54 | q_acc = (torch.argmax(q_res, dim=1) == q_y).sum().item() / len(q_y) 55 | acc = q_acc if acc == 0 else acc + q_acc 56 | 57 | if idx % 2: 58 | print("acc:{} loss:{}".format(acc / config["task_batch"], loss)) 59 | 60 | optimizer.zero_grad() 61 | loss.backward() 62 | optimizer.step() 63 | train() 64 | -------------------------------------------------------------------------------- /data_generator.py: -------------------------------------------------------------------------------- 1 | from data_loader import Data_loader 2 | import numpy as np 3 | from numpy.random import choice, shuffle 4 | from torchvision import transforms 5 | import torch 6 | 7 | 8 | class Data_gernerator(): 9 | def __init__(self, **config): 10 | self.path = config["path"] 11 | self.num_epoches = config["num_epoches"] 12 | self.task_batch = config["task_batch"] 13 | self.support_num = config["support_num"] 14 | self.query_num = config["query_num"] 15 | self.nways = config["nways"] 16 | 17 | self.characters = self.load() 18 | self.generator = self.sample_task() 19 | 20 | def load(self): 21 | transform = transforms.Compose([ 22 | transforms.Resize(28), 23 | transforms.ToTensor(), 24 | transforms.Normalize(mean=[0.5], std=[0.5]), 25 | ]) 26 | Ominiglot = Data_loader(path=self.path, 27 | train=True, 28 | transform=transform) 29 | characters = Ominiglot.characters 30 | return characters 31 | 32 | def sample_task(self): 33 | 34 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 35 | 36 | count = 0 37 | while count < self.num_epoches: 38 | 39 | batch_data = [] 40 | for i in range(self.task_batch): 41 | task_data = [] 42 | classes = choice(len(self.characters), self.nways) 43 | shuffle(classes) 44 | for j in range(self.nways): 45 | sel_imgs = choice(len(self.characters[classes[j]]), (self.query_num + self.support_num)) 46 | task_data.append([self.characters[classes[j]][_] for _ in sel_imgs]) 47 | 48 | s = [one_class[0] for one_class in task_data] 49 | s_y = [_ for _ in range(self.nways)] 50 | s_list = list(zip(s, s_y)) 51 | shuffle(s_list) 52 | s, s_y = zip(*s_list) 53 | s = torch.stack(s).to(device=device) 54 | s_y = torch.tensor(s_y).to(device=device) 55 | 56 | q = [_ for one_class in task_data for _ in one_class[1:]] 57 | q_y = [_ // self.query_num for _ in range(self.query_num * self.nways)] 58 | q_list = list(zip(q, q_y)) 59 | shuffle(q_list) 60 | q, q_y = zip(*q_list) 61 | q = torch.stack(q).to(device=device) 62 | q_y = torch.tensor(q_y).to(device=device) 63 | 64 | batch_data.append((s, s_y, q, q_y)) 65 | 66 | yield batch_data 67 | 68 | 69 | if __name__ == "__main__": 70 | para= {} 71 | para["path"]= r"C:\Users\Miao_\Desktop\my_maml\omniglot" 72 | para["num_epoches"]= 100000 73 | para["task_batch"]= 32 74 | para["support_num"]= 1 75 | para["query_num"]= 15 76 | para["nways"]= 5 77 | 78 | a= Data_gernerator(**para) 79 | 80 | for i in a.generator: 81 | print(len(i)) 82 | print(len(i[0])) 83 | print(len(i[0][0])) 84 | break 85 | --------------------------------------------------------------------------------