├── README.md
├── config.py
├── dataReader.py
├── evaluate.py
├── net.py
├── requirements.txt
├── train.py
└── train_mnist.py

/README.md:
--------------------------------------------------------------------------------
# Keras - MAML

## Part 1. Introduction

As we all know, deep learning needs vast amounts of data. If you don't have enough, you can start from pre-trained weights. Pre-trained weights fit most data well, but there are still tasks that cannot converge to a good minimum from them. So, does there exist one set of weights from which every task can reach its best result?

Yes. This is the idea behind "Model-Agnostic Meta-Learning" (MAML). The biggest difference between MAML and pre-trained weights: pre-training minimizes only the loss of the original task, while MAML learns an initialization from which the loss of any task can be minimized within a few training steps.
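To make the bi-level update concrete, here is a toy, self-contained sketch (illustrative only, not this repo's code; the real training loop lives in **net.py**, which, like this sketch, takes the query gradient at the adapted weights, i.e. the first-order approximation):

```python
# Toy first-order MAML on scalar tasks: each task t is "minimize (w - c_t)^2".
# MAML looks for an initial w from which one inner step already does well.
import numpy as np

rng = np.random.default_rng(0)
w = 5.0                                # meta weights (theta)
inner_lr, outer_lr = 0.1, 0.01

for step in range(1000):
    centers = rng.uniform(-5, 5, size=8)          # a batch of tasks
    meta_grad = 0.0
    for c in centers:
        # Inner loop (support set): theta' = theta - a * d(loss)/d(theta)
        w_task = w - inner_lr * 2 * (w - c)
        # Outer loop (query set): task loss gradient evaluated at theta'
        # (first-order approximation: taken w.r.t. theta' instead of theta)
        meta_grad += 2 * (w_task - c)
    # Update the meta weights with the averaged query gradient
    w -= outer_lr * meta_grad / len(centers)

print(w)  # drifts toward the initialization that adapts best across tasks
```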
If this works for you, please give me a star; it is very important to me. 😊

## Part 2. Quick Start

1. Pull the repository.

   ```shell
   git clone https://github.com/Runist/MAML-keras.git
   ```

2. Install the dependencies.

   ```shell
   cd MAML-keras
   pip install -r requirements.txt
   ```

3. Download the *Omniglot* dataset and the MAML weights.

   ```shell
   wget https://github.com/Runist/MAML-keras/releases/download/v1.0/Omniglot.tar
   wget https://github.com/Runist/MAML-keras/releases/download/v1.0/maml.h5
   tar -xvf Omniglot.tar
   ```

4. Run **train_mnist.py**; after a few minutes you'll get the MNIST weights.

   ```shell
   python train_mnist.py
   ```

   ```
   Epoch 1/3
   235/235 [==============================] - 62s 133ms/step - loss: 0.3736 - sparse_categorical_accuracy: 0.8918
   Epoch 2/3
   235/235 [==============================] - 2s 9ms/step - loss: 0.0385 - sparse_categorical_accuracy: 0.9886
   Epoch 3/3
   235/235 [==============================] - 2s 9ms/step - loss: 0.0219 - sparse_categorical_accuracy: 0.9934
   313/313 [==============================] - 27s 48ms/step - loss: 0.0373 - sparse_categorical_accuracy: 0.9882
   ```

5. Run **evaluate.py**; you'll see the difference between the MAML and MNIST initialization weights.

   ```shell
   python evaluate.py
   ```

   ```
   Model with mnist initialize weight train for 3 step, val loss: 1.8765, accuracy: 0.3400.
   Model with mnist initialize weight train for 5 step, val loss: 1.5195, accuracy: 0.4600.
   Model with maml weight train for 3 step, val loss: 0.8904, accuracy: 0.6700.
   Model with maml weight train for 5 step, val loss: 0.5034, accuracy: 0.7800.
   ```

## Part 3. Train your own dataset

1. Set the corresponding parameters in **config.py**. You can find more detail in my [blog](https://blog.csdn.net/weixin_42392454/article/details/109891791?spm=1001.2014.3001.5501).

   ```python
   parser.add_argument('--n_way', type=int, default=10,
                       help='The number of classes in every task.')
   parser.add_argument('--k_shot', type=int, default=1,
                       help='The number of support set images for every task.')
   parser.add_argument('--q_query', type=int, default=1,
                       help='The number of query set images for every task.')
   parser.add_argument('--input_shape', type=tuple, default=(28, 28, 1),
                       help='The image shape of the model input.')
   ```

2. Start training.

   ```shell
   python train.py --n_way=5 --k_shot=1 --q_query=1
   ```

## Part 4. Paper and other implementations

- [Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks](https://arxiv.org/pdf/1703.03400.pdf)
- [cbfinn/*maml*](https://github.com/cbfinn/maml)
- [dragen1860/*MAML*-Pytorch](https://github.com/dragen1860/MAML-Pytorch)
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @File : config.py
# @Author: Runist
# @Time : 2020/7/8 16:54
# @Software: PyCharm
# @Brief: Configuration file
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=str, default='0', help='Select gpu device.')
parser.add_argument('--train_data_dir', type=str,
                    default="./Omniglot/images_background/",
                    help='The directory containing the train image data.')
parser.add_argument('--val_data_dir', type=str,
                    default="./Omniglot/images_evaluation/",
                    help='The directory containing the validation image data.')
parser.add_argument('--summary_path', type=str,
                    default="./summary",
                    help='The directory of the summary writer.')

parser.add_argument('--batch_size', type=int, default=32,
                    help='Number of tasks per train batch.')
parser.add_argument('--val_batch_size', type=int, default=16,
                    help='Number of tasks per test batch.')
parser.add_argument('--epochs', type=int, default=100,
                    help='The training epochs.')
parser.add_argument('--inner_lr', type=float, default=0.04,
                    help='The learning rate of the inner loop (support set).')
parser.add_argument('--outer_lr', type=float, default=0.001,
                    help='The learning rate of the outer loop (query set).')

parser.add_argument('--n_way', type=int, default=10,
                    help='The number of classes in every task.')
parser.add_argument('--k_shot', type=int, default=1,
                    help='The number of support set images for every task.')
parser.add_argument('--q_query', type=int, default=1,
                    help='The number of query set images for every task.')
parser.add_argument('--input_shape', type=tuple, default=(28, 28, 1),
                    help='The image shape of the model input.')

args = parser.parse_args()
--------------------------------------------------------------------------------
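Since `config.py` calls `parse_args()` at import time, every script in the repo picks up the same command-line flags simply by importing `args`. A quick way to check what a run will use (a throwaway snippet, not part of the repo):

```python
# check_config.py (hypothetical helper): print the effective configuration.
# Run e.g.: python check_config.py --n_way=5 --k_shot=1
from config import args

print("n_way      :", args.n_way)
print("k_shot     :", args.k_shot)
print("q_query    :", args.q_query)
print("inner_lr   :", args.inner_lr)
print("outer_lr   :", args.outer_lr)
print("input_shape:", args.input_shape)
```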
/dataReader.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @File : dataReader.py
# @Author: Runist
# @Time : 2020/7/7 10:06
# @Software: PyCharm
# @Brief: Data reading script

import random
import numpy as np
import glob
import cv2 as cv


class MAMLDataLoader:

    def __init__(self, data_path, batch_size, n_way=5, k_shot=1, q_query=1):
        """
        MAML data loader.
        :param data_path: data directory; it must contain one sub-folder per class
        :param batch_size: number of different tasks per batch
        :param n_way: number of classes in one task
        :param k_shot: number of images per class used for inner-loop training
        :param q_query: number of images per class used for outer-loop training
        """
        self.file_list = [f for f in glob.glob(data_path + "**/character*", recursive=True)]
        self.steps = len(self.file_list) // batch_size

        self.n_way = n_way
        self.k_shot = k_shot
        self.q_query = q_query
        self.meta_batch_size = batch_size

    def __len__(self):
        return self.steps

    def get_one_task_data(self):
        """
        Sample one task: n_way classes, with k_shot images per class for the
        inner loop (support set) and q_query images per class for the outer
        loop (query set).
        :return: support_image, support_label, query_image, query_label
        """
        img_dirs = random.sample(self.file_list, self.n_way)
        support_data = []
        query_data = []

        support_image = []
        support_label = []
        query_image = []
        query_label = []

        for label, img_dir in enumerate(img_dirs):
            img_list = [f for f in glob.glob(img_dir + "**/*.png", recursive=True)]
            images = random.sample(img_list, self.k_shot + self.q_query)

            # Read the support set
            for img_path in images[:self.k_shot]:
                image = cv.imread(img_path, cv.IMREAD_UNCHANGED)
                image = image / 255.
                image = np.expand_dims(image, axis=-1)
                support_data.append((image, label))

            # Read the query set
            for img_path in images[self.k_shot:]:
                image = cv.imread(img_path, cv.IMREAD_UNCHANGED)
                image = image / 255.
                image = np.expand_dims(image, axis=-1)
                query_data.append((image, label))

        # Shuffle the support set
        random.shuffle(support_data)
        for data in support_data:
            support_image.append(data[0])
            support_label.append(data[1])

        # Shuffle the query set
        random.shuffle(query_data)
        for data in query_data:
            query_image.append(data[0])
            query_label.append(data[1])

        return np.array(support_image), np.array(support_label), np.array(query_image), np.array(query_label)

    def get_one_batch(self):
        """
        Generate one batch of samples, where each element of the batch is a task.
        :return: batch of (support_image, support_label, query_image, query_label)
        """
        while True:
            batch_support_image = []
            batch_support_label = []
            batch_query_image = []
            batch_query_label = []

            for _ in range(self.meta_batch_size):
                support_image, support_label, query_image, query_label = self.get_one_task_data()
                batch_support_image.append(support_image)
                batch_support_label.append(support_label)
                batch_query_image.append(query_image)
                batch_query_label.append(query_label)

            yield np.array(batch_support_image), np.array(batch_support_label), \
                np.array(batch_query_image), np.array(batch_query_label)
--------------------------------------------------------------------------------
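As a sanity check on the loader's output, each yielded array is task-major: support images have shape `(batch_size, n_way * k_shot, H, W, 1)` and query images `(batch_size, n_way * q_query, H, W, 1)`. A throwaway snippet (assuming the Omniglot folders from Part 2 are in place and, as `--input_shape` suggests, contain 28×28 images):

```python
# Hypothetical smoke test for MAMLDataLoader (not part of the repo).
from dataReader import MAMLDataLoader

loader = MAMLDataLoader("./Omniglot/images_background/", batch_size=4,
                        n_way=5, k_shot=1, q_query=1)
support_x, support_y, query_x, query_y = next(loader.get_one_batch())

print(support_x.shape)  # expected: (4, 5, 28, 28, 1)  -> 5 = n_way * k_shot
print(support_y.shape)  # expected: (4, 5)
print(query_x.shape)    # expected: (4, 5, 28, 28, 1)  -> 5 = n_way * q_query
print(query_y.shape)    # expected: (4, 5)
```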
/evaluate.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @File : evaluate.py
# @Author: Runist
# @Time : 2021/4/26 15:42
# @Software: PyCharm
# @Brief: Evaluation script

from tensorflow.keras import optimizers
import tensorflow as tf

from dataReader import MAMLDataLoader
from net import MAML
from config import args
import os

if __name__ == '__main__':
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    gpus = tf.config.experimental.list_physical_devices("GPU")
    if gpus:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)

    val_data = MAMLDataLoader(args.val_data_dir, args.val_batch_size)

    mnist_model = MAML(args.input_shape, args.n_way)
    maml = MAML(args.input_shape, args.n_way)

    # Comparison test: fine-tune from each initialization for a few inner steps.
    # The weights are reloaded before every run so each run starts from the same
    # point, and a fresh Adam is created so no optimizer state carries over.

    # MNIST weights
    mnist_model.meta_model.load_weights("mnist.h5")
    optimizer = optimizers.Adam(args.inner_lr)
    val_loss, val_acc = mnist_model.train_on_batch(val_data.get_one_batch(), inner_optimizer=optimizer, inner_step=3)
    print("Model with mnist initialize weight train for 3 step, val loss: {:.4f}, accuracy: {:.4f}.".format(val_loss, val_acc))

    mnist_model.meta_model.load_weights("mnist.h5")
    optimizer = optimizers.Adam(args.inner_lr)
    val_loss, val_acc = mnist_model.train_on_batch(val_data.get_one_batch(), inner_optimizer=optimizer, inner_step=5)
    print("Model with mnist initialize weight train for 5 step, val loss: {:.4f}, accuracy: {:.4f}.".format(val_loss, val_acc))

    # MAML weights
    maml.meta_model.load_weights("maml.h5")
    optimizer = optimizers.Adam(args.inner_lr)
    val_loss, val_acc = maml.train_on_batch(val_data.get_one_batch(), inner_optimizer=optimizer, inner_step=3)
    print("Model with maml weight train for 3 step, val loss: {:.4f}, accuracy: {:.4f}.".format(val_loss, val_acc))

    maml.meta_model.load_weights("maml.h5")
    optimizer = optimizers.Adam(args.inner_lr)
    val_loss, val_acc = maml.train_on_batch(val_data.get_one_batch(), inner_optimizer=optimizer, inner_step=5)
    print("Model with maml weight train for 5 step, val loss: {:.4f}, accuracy: {:.4f}.".format(val_loss, val_acc))
--------------------------------------------------------------------------------
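The four near-identical blocks in evaluate.py could be folded into a small helper; a sketch (the `evaluate_weights` name is mine, not the repo's, and it assumes the surrounding scope of evaluate.py):

```python
# Hypothetical refactor of the repeated evaluation blocks above.
def evaluate_weights(model, weights_path, data, inner_step):
    """Reload the saved weights, adapt for `inner_step` steps on one batch
    of tasks, and return the query-set loss and accuracy."""
    model.meta_model.load_weights(weights_path)   # restart from the saved init
    optimizer = optimizers.Adam(args.inner_lr)    # fresh optimizer, no stale state
    return model.train_on_batch(data.get_one_batch(),
                                inner_optimizer=optimizer,
                                inner_step=inner_step)

for name, model, path in [("mnist initialize", mnist_model, "mnist.h5"),
                          ("maml", maml, "maml.h5")]:
    for steps in (3, 5):
        loss, acc = evaluate_weights(model, path, val_data, steps)
        print("Model with {} weight train for {} step, val loss: {:.4f}, "
              "accuracy: {:.4f}.".format(name, steps, loss, acc))
```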
/net.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @File : net.py
# @Author: Runist
# @Time : 2020/7/6 16:52
# @Software: PyCharm
# @Brief: The classification network. MAML is agnostic to the architecture;
#         the important part is the training procedure.

from tensorflow.keras import layers, models, losses
import tensorflow as tf


class MAML:
    def __init__(self, input_shape, num_classes):
        """
        MAML model class. Conceptually there are two sets of weights: the real
        meta weights θ that get updated, and the task-adapted weights θ'
        produced by the inner loop.
        :param input_shape: input shape of the model
        :param num_classes: number of classes
        """
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.meta_model = self.get_maml_model()

    def get_maml_model(self):
        """
        Build the MAML model.
        :return: maml model
        """
        model = models.Sequential([
            layers.Conv2D(filters=64, kernel_size=3, padding='same', activation="relu",
                          input_shape=self.input_shape),
            layers.BatchNormalization(),
            layers.MaxPool2D(pool_size=2, strides=2),

            layers.Conv2D(filters=64, kernel_size=3, padding='same', activation="relu"),
            layers.BatchNormalization(),
            layers.MaxPool2D(pool_size=2, strides=2),

            layers.Conv2D(filters=64, kernel_size=3, padding='same', activation="relu"),
            layers.BatchNormalization(),
            layers.MaxPool2D(pool_size=2, strides=2),

            layers.Conv2D(filters=64, kernel_size=3, padding="same", activation="relu"),
            layers.BatchNormalization(),
            layers.MaxPool2D(pool_size=2, strides=2),

            layers.Flatten(),
            layers.Dense(self.num_classes, activation='softmax'),
        ])

        return model

    def train_on_batch(self, train_data, inner_optimizer, inner_step, outer_optimizer=None):
        """
        One batch of MAML training.
        :param train_data: training data generator, one task per batch element
        :param inner_optimizer: optimizer for the support set
        :param inner_step: number of inner-loop update steps
        :param outer_optimizer: optimizer for the query set; if None, the meta
                                weights are not updated (evaluation mode)
        :return: mean query loss and accuracy over the batch
        """
        batch_acc = []
        batch_loss = []
        task_weights = []

        # Save the initial meta weights; they are the starting point of every task.
        meta_weights = self.meta_model.get_weights()

        meta_support_image, meta_support_label, meta_query_image, meta_query_label = next(train_data)
        for support_image, support_label in zip(meta_support_image, meta_support_label):

            # Every task must start from the original meta weights.
            self.meta_model.set_weights(meta_weights)
            for _ in range(inner_step):
                with tf.GradientTape() as tape:
                    logits = self.meta_model(support_image, training=True)
                    loss = losses.sparse_categorical_crossentropy(support_label, logits)
                    loss = tf.reduce_mean(loss)

                grads = tape.gradient(loss, self.meta_model.trainable_variables)
                inner_optimizer.apply_gradients(zip(grads, self.meta_model.trainable_variables))

            # Save the adapted weights of each task, so that the outer loop
            # evaluates each task's query set with the matching weights.
            task_weights.append(self.meta_model.get_weights())

        with tf.GradientTape() as tape:
            for i, (query_image, query_label) in enumerate(zip(meta_query_image, meta_query_label)):

                # Load the task's adapted weights for the forward pass.
                self.meta_model.set_weights(task_weights[i])

                logits = self.meta_model(query_image, training=True)
                loss = losses.sparse_categorical_crossentropy(query_label, logits)
                loss = tf.reduce_mean(loss)
                batch_loss.append(loss)

                acc = tf.cast(tf.argmax(logits, axis=-1) == query_label, tf.float32)
                acc = tf.reduce_mean(acc)
                batch_acc.append(acc)

            mean_acc = tf.reduce_mean(batch_acc)
            mean_loss = tf.reduce_mean(batch_loss)

        # Restore the meta weights whether or not we update, so that validation
        # runs never change the original weights. Note that the query gradients
        # are taken at the adapted weights and applied to the meta weights,
        # i.e. effectively the first-order MAML approximation.
        self.meta_model.set_weights(meta_weights)
        if outer_optimizer:
            grads = tape.gradient(mean_loss, self.meta_model.trainable_variables)
            outer_optimizer.apply_gradients(zip(grads, self.meta_model.trainable_variables))

        return mean_loss, mean_acc
--------------------------------------------------------------------------------
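A minimal smoke test of `MAML.train_on_batch` with random tensors (illustrative only; the shapes follow the loader's task-major layout, and `fake_batch` mimics `get_one_batch` without touching the dataset):

```python
# Hypothetical smoke test for MAML.train_on_batch with fake data.
import numpy as np
from tensorflow.keras import optimizers
from net import MAML

n_task, n_way, k_shot, q_query = 4, 5, 1, 1

def fake_batch():
    # Yields (support_x, support_y, query_x, query_y) forever, like get_one_batch.
    while True:
        yield (np.random.rand(n_task, n_way * k_shot, 28, 28, 1).astype("float32"),
               np.tile(np.arange(n_way), (n_task, k_shot)),
               np.random.rand(n_task, n_way * q_query, 28, 28, 1).astype("float32"),
               np.tile(np.arange(n_way), (n_task, q_query)))

maml = MAML(input_shape=(28, 28, 1), num_classes=n_way)
inner_opt = optimizers.Adam(0.04)
outer_opt = optimizers.Adam(0.001)

loss, acc = maml.train_on_batch(fake_batch(), inner_opt, inner_step=1,
                                outer_optimizer=outer_opt)
print("query loss: {:.4f}, query acc: {:.4f}".format(float(loss), float(acc)))
```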
/requirements.txt:
--------------------------------------------------------------------------------
tensorflow-gpu>=2.3.0
opencv-python
numpy
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @File : train.py
# @Author: Runist
# @Time : 2021/4/23 17:30
# @Software: PyCharm
# @Brief: Training script

from tensorflow.keras import optimizers, utils
import tensorflow as tf
import numpy as np

from dataReader import MAMLDataLoader
from net import MAML
from config import args
import os


if __name__ == '__main__':
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    gpus = tf.config.experimental.list_physical_devices("GPU")
    if gpus:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)

    train_data = MAMLDataLoader(args.train_data_dir, args.batch_size)
    val_data = MAMLDataLoader(args.val_data_dir, args.val_batch_size)

    inner_optimizer = optimizers.Adam(args.inner_lr)
    outer_optimizer = optimizers.Adam(args.outer_lr)

    maml = MAML(args.input_shape, args.n_way)

    # Fewer validation steps are enough; there is no need to run the full
    # validation set every epoch.
    val_data.steps = 10

    for e in range(args.epochs):

        train_progbar = utils.Progbar(train_data.steps)
        val_progbar = utils.Progbar(val_data.steps)
        print('\nEpoch {}/{}'.format(e + 1, args.epochs))

        train_meta_loss = []
        train_meta_acc = []
        val_meta_loss = []
        val_meta_acc = []

        for i in range(train_data.steps):
            batch_train_loss, acc = maml.train_on_batch(train_data.get_one_batch(),
                                                        inner_optimizer,
                                                        inner_step=1,
                                                        outer_optimizer=outer_optimizer)

            train_meta_loss.append(batch_train_loss)
            train_meta_acc.append(acc)
            train_progbar.update(i + 1, [('loss', np.mean(train_meta_loss)),
                                         ('accuracy', np.mean(train_meta_acc))])

        for i in range(val_data.steps):
            batch_val_loss, val_acc = maml.train_on_batch(val_data.get_one_batch(), inner_optimizer, inner_step=3)

            val_meta_loss.append(batch_val_loss)
            val_meta_acc.append(val_acc)
            val_progbar.update(i + 1, [('val_loss', np.mean(val_meta_loss)),
                                       ('val_accuracy', np.mean(val_meta_acc))])

        maml.meta_model.save_weights("maml.h5")
--------------------------------------------------------------------------------
/train_mnist.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @File : train_mnist.py
# @Author: Runist
# @Time : 2021/9/3 9:25
# @Software: PyCharm
# @Brief: Pre-train the same network on MNIST as the baseline initialization.
from net import MAML
from tensorflow.keras import datasets, losses, optimizers, metrics
from config import args
import numpy as np
import tensorflow as tf
import os


if __name__ == '__main__':
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    gpus = tf.config.experimental.list_physical_devices("GPU")
    if gpus:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)

    maml = MAML(args.input_shape, 10)
    model = maml.get_maml_model()

    (x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()

    # Normalize data
    x_train = x_train.astype("float32") / 255.0
    x_train = np.reshape(x_train, (-1, 28, 28, 1))

    x_test = x_test.astype("float32") / 255.0
    x_test = np.reshape(x_test, (-1, 28, 28, 1))

    # Train the baseline network on MNIST
    model.compile(
        optimizer=optimizers.Adam(),
        loss=losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=[metrics.SparseCategoricalAccuracy()],
    )

    model.fit(x_train, y_train, epochs=3, shuffle=True, batch_size=256)
    model.evaluate(x_test, y_test)
    model.save_weights("mnist.h5")
--------------------------------------------------------------------------------
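After train_mnist.py finishes, a quick way to sanity-check the saved weights is to reload them and classify a few test digits (a throwaway snippet, assuming mnist.h5 is in the working directory):

```python
# Hypothetical check: reload mnist.h5 and predict on a few MNIST test images.
import numpy as np
from tensorflow.keras import datasets
from net import MAML

model = MAML((28, 28, 1), 10).meta_model
model.load_weights("mnist.h5")

(_, _), (x_test, y_test) = datasets.mnist.load_data()
x_test = x_test.astype("float32")[..., np.newaxis] / 255.0

preds = model.predict(x_test[:8]).argmax(axis=-1)
print("predicted:", preds)        # should mostly match the labels below
print("expected :", y_test[:8])
```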