# Project layout:
# ├── requirements.txt
# ├── core
# │   ├── dataset
# │   │   ├── __init__.py
# │   │   ├── dataset.py
# │   │   └── omniglot.py
# │   └── helper.py
# ├── args.py
# ├── train.py
# ├── README.md
# └── net
#     └── maml.py

# --------------------------------------------------------------------------------
# /requirements.txt
# --------------------------------------------------------------------------------
# torch==1.8.1
# torchvision==0.9.1
# opencv-python
# numpy

# --------------------------------------------------------------------------------
# /core/dataset/__init__.py
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @File : __init__.py
# @Author: Runist
# @Time : 2022/7/6 10:50
# @Software: PyCharm
# @Brief: Package entry point: re-exports the dataset classes.
from .dataset import MAMLDataset
from .omniglot import OmniglotDataset

__all__ = ['MAMLDataset', 'OmniglotDataset']

# --------------------------------------------------------------------------------
# /core/dataset/dataset.py
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @File : dataset.py
# @Author: Runist
# @Time : 2022/7/6 10:38
# @Software: PyCharm
# @Brief: Abstract base dataset for MAML task sampling.

from torch.utils.data.dataset import Dataset


class MAMLDataset(Dataset):
    """
    Abstract base class for MAML episodic datasets.

    Subclasses must implement `get_file_list` (enumerate the class folders)
    and `get_one_task_data` (sample one n-way/k-shot task).

    Args:
        data_path: root directory of the image data
        batch_size: number of tasks per batch (only used to size one epoch)
        n_way: number of classes per task
        k_shot: number of support images per class
        q_query: number of query images per class
    """

    def __init__(self, data_path, batch_size, n_way=10, k_shot=2, q_query=1):
        self.file_list = self.get_file_list(data_path)
        self.batch_size = batch_size
        self.n_way = n_way
        self.k_shot = k_shot
        self.q_query = q_query

    def get_file_list(self, data_path):
        """Return the list of class directories under `data_path`."""
        raise NotImplementedError('get_file_list function not implemented!')

    def get_one_task_data(self):
        """Sample and return the data of one n-way/k-shot task."""
        raise NotImplementedError('get_one_task_data function not implemented!')

    def __len__(self):
        # One epoch iterates len(file_list) // batch_size batches of tasks.
        return len(self.file_list) // self.batch_size

    def __getitem__(self, index):
        # The index is ignored on purpose: every access samples a fresh random task.
        return self.get_one_task_data()


# --------------------------------------------------------------------------------
# /core/helper.py:
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @File : helper.py
# @Author: Runist
# @Time : 2022/7/6 11:16
# @Software: PyCharm
# @Brief: Helper functions: model/dataset construction, seeding, directory setup.
from net.maml import Classifier
from core.dataset import OmniglotDataset

import os
import random
import shutil  # required by remove_dir_and_create_dir (was missing -> NameError)

import numpy as np
import torch
from torch import nn


def get_model(args, dev):
    """
    Get model.
    Args:
        args: ArgumentParser
        dev: torch device

    Returns: model

    """
    # Build the classifier and move it to the requested device.
    # NOTE: the old code called .cuda() unconditionally, which crashed on
    # CPU-only machines; .to(dev) alone handles both CPU and GPU.
    model = Classifier(1, args.n_way)
    model.to(dev)

    return model


def get_dataset(args):
    """
    Get maml dataset.
    Args:
        args: ArgumentParser

    Returns: train_dataset, val_dataset

    """
    train_dataset = OmniglotDataset(args.train_data_dir, args.task_num,
                                    n_way=args.n_way, k_shot=args.k_shot, q_query=args.q_query)
    val_dataset = OmniglotDataset(args.val_data_dir, args.val_task_num,
                                  n_way=args.n_way, k_shot=args.k_shot, q_query=args.q_query)

    return train_dataset, val_dataset


def seed_torch(seed):
    """
    Set all random seeds so runs are reproducible.
    Args:
        seed: random seed

    Returns: None

    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cudnn autotuning for deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def remove_dir_and_create_dir(dir_name, is_remove=True):
    """
    Make a new folder; if the folder exists, remove it and create a new one
    (unless is_remove is False).
    Args:
        dir_name: path of folder
        is_remove: if true, it will remove old folder and create new folder

    Returns: None

    """
    if not os.path.exists(dir_name):
        os.makedirs(dir_name)
        print(dir_name, "create.")
    else:
        if is_remove:
            shutil.rmtree(dir_name)
            os.makedirs(dir_name)
            print(dir_name, "create.")
        else:
            print(dir_name, "is exist.")


# --------------------------------------------------------------------------------
# /args.py
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @File : args.py
# @Author: Runist
# @Time : 2022/3/29 9:46
# @Software: PyCharm
# @Brief: Code argument parser

import argparse
import warnings
import os
import torch
import sys
sys.path.append(os.getcwd())

from core.helper import seed_torch

warnings.filterwarnings("ignore")

parser = argparse.ArgumentParser()

parser.add_argument('--gpu', type=str, default='0', help='Select gpu device.')
parser.add_argument('--train_data_dir', type=str,
                    default="./data/Omniglot/images_background/",
                    help='The directory containing the train image data.')
parser.add_argument('--val_data_dir', type=str,
                    default="./data/Omniglot/images_evaluation/",
                    help='The directory containing the validation image data.')
parser.add_argument('--summary_path', type=str,
                    default="./summary",
                    help='The directory of the summary writer.')

parser.add_argument('--task_num', type=int, default=32,
                    help='Number of task per train batch.')
parser.add_argument('--val_task_num', type=int, default=16,
                    help='Number of task per test batch.')
parser.add_argument('--num_workers', type=int, default=12, help='The number of torch dataloader thread.')

parser.add_argument('--epochs', type=int, default=150,
                    help='The training epochs.')

parser.add_argument('--inner_lr', type=float, default=0.04, 41 | help='The learning rate of of the support set.') 42 | parser.add_argument('--outer_lr', type=float, default=0.001, 43 | help='The learning rate of of the query set.') 44 | 45 | parser.add_argument('--n_way', type=int, default=5, 46 | help='The number of class of every task.') 47 | parser.add_argument('--k_shot', type=int, default=1, 48 | help='The number of support set image for every task.') 49 | parser.add_argument('--q_query', type=int, default=1, 50 | help='The number of query set image for every task.') 51 | 52 | args = parser.parse_args() 53 | 54 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 55 | seed_torch(1206) 56 | 57 | dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 58 | -------------------------------------------------------------------------------- /core/dataset/omniglot.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : omniglot.py 3 | # @Author: Runist 4 | # @Time : 2022/7/6 10:41 5 | # @Software: PyCharm 6 | # @Brief: 7 | 8 | import random 9 | import numpy as np 10 | import glob 11 | from PIL import Image 12 | import torch.nn.functional as F 13 | import torch 14 | 15 | from core.dataset import MAMLDataset 16 | 17 | 18 | class OmniglotDataset(MAMLDataset): 19 | def get_file_list(self, data_path): 20 | """ 21 | Get all fonts list. 22 | Args: 23 | data_path: Omniglot Data path 24 | 25 | Returns: fonts list 26 | 27 | """ 28 | return [f for f in glob.glob(data_path + "**/character*", recursive=True)] 29 | 30 | def get_one_task_data(self): 31 | """ 32 | Get ones task maml data, include one batch support images and labels, one batch query images and labels. 
33 | Returns: support_data, query_data 34 | 35 | """ 36 | img_dirs = random.sample(self.file_list, self.n_way) 37 | support_data = [] 38 | query_data = [] 39 | 40 | support_image = [] 41 | support_label = [] 42 | query_image = [] 43 | query_label = [] 44 | 45 | for label, img_dir in enumerate(img_dirs): 46 | img_list = [f for f in glob.glob(img_dir + "**/*.png", recursive=True)] 47 | images = random.sample(img_list, self.k_shot + self.q_query) 48 | 49 | # Read support set 50 | for img_path in images[:self.k_shot]: 51 | image = Image.open(img_path) 52 | image = np.array(image) 53 | image = np.expand_dims(image / 255., axis=0) 54 | support_data.append((image, label)) 55 | 56 | # Read query set 57 | for img_path in images[self.k_shot:]: 58 | image = Image.open(img_path) 59 | image = np.array(image) 60 | image = np.expand_dims(image / 255., axis=0) 61 | query_data.append((image, label)) 62 | 63 | # shuffle support set 64 | random.shuffle(support_data) 65 | for data in support_data: 66 | support_image.append(data[0]) 67 | support_label.append(data[1]) 68 | 69 | # shuffle query set 70 | random.shuffle(query_data) 71 | for data in query_data: 72 | query_image.append(data[0]) 73 | query_label.append(data[1]) 74 | 75 | return np.array(support_image), np.array(support_label), np.array(query_image), np.array(query_label) 76 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : train.py 3 | # @Author: Runist 4 | # @Time : 2022/7/6 10:01 5 | # @Software: PyCharm 6 | # @Brief: 7 | 8 | import torch 9 | import numpy as np 10 | from tqdm import tqdm 11 | import torch.optim as optim 12 | from torch.utils.data import DataLoader 13 | 14 | from core.helper import get_model, get_dataset 15 | from net.maml import maml_train 16 | from args import args, dev 17 | 18 | 19 | if __name__ == '__main__': 20 | model = get_model(args, 
dev) 21 | train_dataset, val_dataset = get_dataset(args) 22 | 23 | train_loader = DataLoader(train_dataset, batch_size=args.task_num, shuffle=True, num_workers=args.num_workers) 24 | val_loader = DataLoader(val_dataset, batch_size=args.val_task_num, shuffle=False, num_workers=args.num_workers) 25 | 26 | params = [p for p in model.parameters() if p.requires_grad] 27 | optimizer = optim.Adam(params, args.outer_lr) 28 | best_acc = 0 29 | 30 | model.train() 31 | for epoch in range(args.epochs): 32 | train_acc = [] 33 | val_acc = [] 34 | train_loss = [] 35 | val_loss = [] 36 | 37 | train_bar = tqdm(train_loader) 38 | for support_images, support_labels, query_images, query_labels in train_bar: 39 | train_bar.set_description("epoch {}".format(epoch + 1)) 40 | # Get variables 41 | support_images = support_images.float().to(dev) 42 | support_labels = support_labels.long().to(dev) 43 | query_images = query_images.float().to(dev) 44 | query_labels = query_labels.long().to(dev) 45 | 46 | loss, acc = maml_train(model, support_images, support_labels, query_images, query_labels, 47 | 1, args, optimizer) 48 | 49 | train_loss.append(loss.item()) 50 | train_acc.append(acc) 51 | train_bar.set_postfix(loss="{:.4f}".format(loss.item())) 52 | 53 | for support_images, support_labels, query_images, query_labels in val_loader: 54 | 55 | # Get variables 56 | support_images = support_images.float().to(dev) 57 | support_labels = support_labels.long().to(dev) 58 | query_images = query_images.float().to(dev) 59 | query_labels = query_labels.long().to(dev) 60 | 61 | loss, acc = maml_train(model, support_images, support_labels, query_images, query_labels, 62 | 3, args, optimizer, is_train=False) 63 | 64 | # Must use .item() to add total loss, or will occur GPU memory leak. 65 | # Because dynamic graph is created during forward, collect in backward. 66 | val_loss.append(loss.item()) 67 | val_acc.append(acc) 68 | 69 | print("=> loss: {:.4f} acc: {:.4f} val_loss: {:.4f} val_acc: {:.4f}". 
              format(np.mean(train_loss), np.mean(train_acc), np.mean(val_loss), np.mean(val_acc)))

        if np.mean(val_acc) > best_acc:
            best_acc = np.mean(val_acc)
            torch.save(model, 'best.pt')

# --------------------------------------------------------------------------------
# /README.md
# --------------------------------------------------------------------------------
# PyTorch - MAML

## Part 1. Introduction

As we all know, deep learning needs vast data. If you don't have this condition, you can use pre-training weights. Most data can be fitted by pre-training weights, but there are still some data that can't converge to the global lowest point. So does there exist one set of weights that lets every task get the best result?

Yes, this is "Model-Agnostic Meta-Learning". The biggest difference between MAML and pre-training weights: pre-training weights minimize only the original task loss. MAML can minimize all task losses with a few steps of training.

If this works for you, please give me a star, this is very important to me.😊

## Part 2. Quick Start

1. Pull repository.

```shell
git clone https://github.com/Runist/torch_maml.git
```

2. You need to install some dependency packages.

```shell
cd torch_maml
pip install -r requirements.txt
```

3. Download the *Omniglot* dataset.

```shell
mkdir data
cd data
wget https://github.com/Runist/MAML-keras/releases/download/v1.0/Omniglot.tar
tar -xvf Omniglot.tar
```

4. Start training.
36 | 37 | ```shell 38 | python train.py 39 | ``` 40 | 41 | ``` 42 | epoch 1: 100%|█████████████████████████████████████| 4/4 [00:05<00:00, 1.45s/it, loss=1.2326] 43 | => loss: 1.2917 acc: 0.4990 val_loss: 0.8875 val_acc: 0.7963 44 | epoch 2: 100%|█████████████████████████████████████| 4/4 [00:05<00:00, 1.32s/it, loss=0.9818] 45 | => loss: 1.0714 acc: 0.6688 val_loss: 0.8573 val_acc: 0.7713 46 | epoch 3: 100%|█████████████████████████████████████| 4/4 [00:05<00:00, 1.34s/it, loss=0.9472] 47 | => loss: 0.9896 acc: 0.6922 val_loss: 0.8000 val_acc: 0.7773 48 | epoch 4: 100%|█████████████████████████████████████| 4/4 [00:05<00:00, 1.39s/it, loss=0.7929] 49 | => loss: 0.8258 acc: 0.7812 val_loss: 0.8071 val_acc: 0.7676 50 | epoch 5: 100%|█████████████████████████████████████| 4/4 [00:08<00:00, 2.14s/it, loss=0.6662] 51 | => loss: 0.7754 acc: 0.7646 val_loss: 0.7144 val_acc: 0.7833 52 | epoch 6: 100%|█████████████████████████████████████| 4/4 [00:04<00:00, 1.21s/it, loss=0.7490] 53 | => loss: 0.7565 acc: 0.7635 val_loss: 0.6317 val_acc: 0.8130 54 | epoch 7: 100%|█████████████████████████████████████| 4/4 [00:05<00:00, 1.25s/it, loss=0.5380] 55 | => loss: 0.5871 acc: 0.8333 val_loss: 0.5963 val_acc: 0.8255 56 | epoch 8: 100%|█████████████████████████████████████| 4/4 [00:05<00:00, 1.27s/it, loss=0.5144] 57 | => loss: 0.5786 acc: 0.8255 val_loss: 0.5652 val_acc: 0.8463 58 | epoch 9: 100%|█████████████████████████████████████| 4/4 [00:04<00:00, 1.18s/it, loss=0.4945] 59 | => loss: 0.5038 acc: 0.8510 val_loss: 0.6305 val_acc: 0.8005 60 | epoch 10: 100%|█████████████████████████████████████| 4/4 [00:06<00:00, 1.75s/it, loss=0.4634] 61 | => loss: 0.4523 acc: 0.8719 val_loss: 0.5285 val_acc: 0.8491 62 | ``` 63 | 64 | ## Part 3. Train your own dataset 65 | 1. You should set same parameters in **args.py**. More detail you can get in my [blog](https://blog.csdn.net/weixin_42392454/article/details/109891791?spm=1001.2014.3001.5501). 
# ```python
# parser.add_argument('--train_data_dir', type=str,
#                     default="./data/Omniglot/images_background/",
#                     help='The directory containing the train image data.')
# parser.add_argument('--val_data_dir', type=str,
#                     default="./data/Omniglot/images_evaluation/",
#                     help='The directory containing the validation image data.')
# parser.add_argument('--n_way', type=int, default=10,
#                     help='The number of class of every task.')
# parser.add_argument('--k_shot', type=int, default=1,
#                     help='The number of support set image for every task.')
# parser.add_argument('--q_query', type=int, default=1,
#                     help='The number of query set image for every task.')
# ```
#
# 2. Start training.
#
# ```shell
# python train.py --n_way=5 --k_shot=1 --q_query=1
# ```
#
# ## Part 4. Paper and other implement
#
# - [Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks](https://arxiv.org/pdf/1703.03400.pdf)
# - [cbfinn/maml](https://github.com/cbfinn/maml)
# - [dragen1860/MAML-Pytorch](https://github.com/dragen1860/MAML-Pytorch)
# - [Runist](https://github.com/Runist)/[MAML-keras](https://github.com/Runist/MAML-keras)

# --------------------------------------------------------------------------------
# /net/maml.py
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @File : maml.py
# @Author: Runist
# @Time : 2022/7/6 11:54
# @Software: PyCharm
# @Brief: MAML base network (4-conv classifier) and the meta-training step.

import torch
import torch.nn as nn
import collections
import torch.nn.functional as F
import numpy as np


class ConvBlock(nn.Module):
    """Conv3x3 -> BatchNorm -> ReLU -> MaxPool2x2 block."""

    def __init__(self, in_ch, out_ch):
        super(ConvBlock, self).__init__()
        self.conv2d = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.conv2d(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.max_pool(x)

        return x


def ConvBlockFunction(input, w, b, w_bn, b_bn):
    """Functional twin of ConvBlock so externally supplied (fast) weights can be used."""
    x = F.conv2d(input, w, b, padding=1)
    # running stats are None: always normalize with the current batch statistics.
    x = F.batch_norm(x, running_mean=None, running_var=None, weight=w_bn, bias=b_bn, training=True)
    x = F.relu(x)
    output = F.max_pool2d(x, kernel_size=2, stride=2)

    return output


class Classifier(nn.Module):
    """Four ConvBlocks followed by a linear head; the MAML base learner."""

    def __init__(self, in_ch, n_way):
        super(Classifier, self).__init__()
        self.conv1 = ConvBlock(in_ch, 64)
        self.conv2 = ConvBlock(64, 64)
        self.conv3 = ConvBlock(64, 64)
        self.conv4 = ConvBlock(64, 64)
        self.logits = nn.Linear(64, n_way)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = x.view(x.shape[0], -1)
        x = self.logits(x)

        return x

    def functional_forward(self, x, params):
        """Forward pass using the weights in `params` (the MAML fast weights)."""
        x = ConvBlockFunction(x, params['conv1.conv2d.weight'], params['conv1.conv2d.bias'],
                              params.get('conv1.bn.weight'), params.get('conv1.bn.bias'))
        x = ConvBlockFunction(x, params['conv2.conv2d.weight'], params['conv2.conv2d.bias'],
                              params.get('conv2.bn.weight'), params.get('conv2.bn.bias'))
        x = ConvBlockFunction(x, params['conv3.conv2d.weight'], params['conv3.conv2d.bias'],
                              params.get('conv3.bn.weight'), params.get('conv3.bn.bias'))
        x = ConvBlockFunction(x, params['conv4.conv2d.weight'], params['conv4.conv2d.bias'],
                              params.get('conv4.bn.weight'), params.get('conv4.bn.bias'))

        x = x.view(x.shape[0], -1)
        x = F.linear(x, params['logits.weight'], params['logits.bias'])

        return x


def maml_train(model, support_images, support_labels, query_images, query_labels, inner_step, args, optimizer, is_train=True):
    """
    Train the model using MAML method.
    Args:
        model: Any model
        support_images: several task support images
        support_labels: several support labels
        query_images: several query images
        query_labels: several query labels
        inner_step: support data training step
        args: ArgumentParser
        optimizer: optimizer
        is_train: whether train

    Returns: meta loss, meta accuracy

    """
    meta_loss = []
    meta_acc = []
    # Build the criterion once; it is device-agnostic, so the old per-iteration
    # nn.CrossEntropyLoss().cuda() (which crashed on CPU-only machines) is not needed.
    criterion = nn.CrossEntropyLoss()

    for support_image, support_label, query_image, query_label in zip(support_images, support_labels, query_images, query_labels):

        fast_weights = collections.OrderedDict(model.named_parameters())
        for _ in range(inner_step):
            # Inner loop: one SGD step on the support set.
            # create_graph is only needed when the second-order meta gradient
            # will be backpropagated, i.e. during training.
            support_logit = model.functional_forward(support_image, fast_weights)
            support_loss = criterion(support_logit, support_label)
            grads = torch.autograd.grad(support_loss, fast_weights.values(), create_graph=is_train)
            fast_weights = collections.OrderedDict((name, param - args.inner_lr * grad)
                                                   for ((name, param), grad) in zip(fast_weights.items(), grads))

        # Use trained weight to get query loss
        query_logit = model.functional_forward(query_image, fast_weights)
        query_prediction = torch.max(query_logit, dim=1)[1]

        query_loss = criterion(query_logit, query_label)
        query_acc = torch.eq(query_label, query_prediction).sum() / len(query_label)

        meta_loss.append(query_loss)
        meta_acc.append(query_acc.data.cpu().numpy())

    # Zero the gradient
    optimizer.zero_grad()
    meta_loss = torch.stack(meta_loss).mean()
    meta_acc = np.mean(meta_acc)

    if is_train:
        meta_loss.backward()
        optimizer.step()

    return meta_loss, meta_acc