├── .gitignore
├── imgs
│   ├── LSTM.png
│   ├── mann.png
│   └── mann_1.png
├── mann
│   ├── utils
│   │   ├── tf_utils.py
│   │   ├── generators.py
│   │   └── images.py
│   ├── model.py
│   └── mann_cell.py
├── LICENSE
├── README.md
└── run_mann.py

/.gitignore:
--------------------------------------------------------------------------------
__pycache__
omniglot
.idea
--------------------------------------------------------------------------------
/imgs/LSTM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Leputa/MANN-meta-learning/HEAD/imgs/LSTM.png
--------------------------------------------------------------------------------
/imgs/mann.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Leputa/MANN-meta-learning/HEAD/imgs/mann.png
--------------------------------------------------------------------------------
/imgs/mann_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Leputa/MANN-meta-learning/HEAD/imgs/mann_1.png
--------------------------------------------------------------------------------
/mann/utils/tf_utils.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf

def variable_one_hot(shape, name=''):
    # Builds a constant one-hot tensor whose first column is 1; used by
    # MANNCell.zero_state for the initial read/usage weightings. `name` is unused.
    initial = np.zeros(shape)
    initial[..., 0] = 1
    return tf.constant(initial, dtype=tf.float32)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Laputa

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/mann/utils/generators.py:
--------------------------------------------------------------------------------
import os
import numpy as np
from PIL import Image

from .images import get_images_labels


class OmniglotGenerator(object):
    def __init__(self, data_folder, nb_classes=5, nb_samples_per_class=10, img_size=(20, 20)):
        self.nb_classes = nb_classes
        self.nb_samples_per_class = nb_samples_per_class
        self.img_size = img_size
        self.images = []
        # Every leaf directory under data_folder holds the drawings of one character class.
        for dirname, subdirname, filelist in os.walk(data_folder):
            if filelist:
                self.images.append(
                    [Image.open(os.path.join(dirname, filename)).copy() for filename in filelist]
                )
        # The first 1200 character classes (in os.walk order) are used for training,
        # the remainder for testing, following Santoro et al.
        num_train = 1200
        self.train_images = self.images[:num_train]
        self.test_images = self.images[num_train:]

    def sample(self, batch_type, batch_size, sample_strategy="random"):
        if batch_type == "train":
            data = self.train_images
        elif batch_type == "test":
            data = self.test_images
        else:
            raise ValueError("batch_type must be 'train' or 'test'")

        sampled_inputs = np.zeros((batch_size, self.nb_classes * self.nb_samples_per_class, np.prod(self.img_size)), dtype=np.float32)
        sampled_outputs = np.zeros((batch_size, self.nb_classes * self.nb_samples_per_class), dtype=np.int32)

        for i in range(batch_size):
            images, labels = get_images_labels(data, self.nb_classes, self.nb_samples_per_class, self.img_size, sample_strategy)
            sampled_inputs[i] = np.asarray(images, dtype=np.float32)
            sampled_outputs[i] = np.asarray(labels, dtype=np.int32)
        return sampled_inputs, sampled_outputs
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Meta-Learning with Memory-Augmented Neural Networks in Tensorflow

A concise alternative Tensorflow implementation of the paper *Santoro, Adam, et al. "[Meta-learning with memory-augmented neural networks.](http://proceedings.mlr.press/v48/santoro16.pdf)" International Conference on Machine Learning. 2016.*
The model is encapsulated in the class MANNCell, which can be used like a BasicRNNCell.
The code is inspired by the excellent implementations of [tristandeleu](https://github.com/tristandeleu/ntm-one-shot) and [snowkylin](https://github.com/snowkylin/ntm).


## Memory-Augmented Neural Networks
As described in the reference paper, MANNs (Memory-Augmented Neural Networks) are the class of networks equipped with external memory, such as NTMs (Neural Turing Machines).
![MANN](imgs/mann.png)
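The cell in this repository implements the paper's Least Recently Used Access (LRUA) module. For reference, the update rules coded in `mann/mann_cell.py` are (per read head, with key $k_t$, write vector $a_t$, and interpolation gate $\alpha$):

$$w_t^r = \mathrm{softmax}\big(K(k_t, M_{t-1})\big), \qquad K\big(k_t, M_{t-1}(i)\big) = \frac{k_t \cdot M_{t-1}(i)}{\lVert k_t \rVert \, \lVert M_{t-1}(i) \rVert}$$

$$w_t^w = \sigma(\alpha)\, w_{t-1}^r + \big(1 - \sigma(\alpha)\big)\, w_{t-1}^{lu}, \qquad w_t^u = \gamma\, w_{t-1}^u + w_t^r + w_t^w$$

$$M_t(i) = \tilde{M}_{t-1}(i) + w_t^w(i)\, a_t$$

where $w^{lu}$ marks the least-used memory locations and $\tilde{M}_{t-1}$ is the previous memory with its least-used slot zeroed before writing.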

## Dependencies
* Python 3.6
* Tensorflow==1.14
* numpy==1.16.4
* Pillow==7.1.1 (imported as PIL)

## Usage
### Omniglot Dataset
Download [images_background.zip](https://github.com/brendenlake/omniglot/blob/master/python/images_background.zip) (964 classes) and [images_evaluation.zip](https://github.com/brendenlake/omniglot/blob/master/python/images_evaluation.zip) (679 classes),
and place them in the [./omniglot](omniglot) folder.

### Running
* `python run_mann.py` : train the MANN model
* `python run_mann.py --mode test` : test the MANN model
* `python run_mann.py --model LSTM` : train the LSTM baseline
* `python run_mann.py --model LSTM --mode test` : test the LSTM baseline

### Class MANNCell()
```python
from mann.mann_cell import MANNCell

cell = MANNCell(
    lstm_size=200,
    memory_size=128,
    memory_dim=40,
    nb_reads=4,
    gamma=0.95,
)
# `input` has shape [batch_size, seq_length, input_dim]; tf.scan unrolls the
# cell over time (time-major) and threads its state dict through the steps.
state = cell.zero_state(batch_size, tf.float32)
output, state = tf.scan(
    lambda init, elem: cell(elem, init[1]),
    elems=tf.transpose(input, perm=[1, 0, 2]),
    initializer=(tf.zeros(shape=(batch_size, lstm_size + nb_reads * memory_dim)), state),
)
output = tf.transpose(output, perm=[1, 0, 2])
```


## Performance
Omniglot classification:

![LSTM](imgs/LSTM.png) | ![MANN](imgs/mann_1.png)
---|---


Test-set classification accuracies on the Omniglot dataset, using one-hot encodings of labels and five classes presented per episode. Columns give the accuracy on the k-th presentation of a class within an episode; *ref* rows quote the paper, *repo* rows were obtained with this implementation.

| Model | 1st | 2nd | 3rd | 4th | 5th | 10th|
| :--- | :---: | :---: | :---: | :---: | :---: | :---: |
| LSTM (ref) | 24.4% | 49.5% | 55.3% | 61.0% | 63.6% | 62.5% |
| LSTM (repo) | 30.4% | 77.9% | 85.3% | 87.5% | 88.8% | 91.6% |
| MANN (ref) | 36.4% | 82.8% | 91.0% | 92.6% | 94.9% | 98.1% |
| MANN (repo) | 35.4% | 89.2% | 95.2% | 96.3% | 96.9% | 97.8% |
--------------------------------------------------------------------------------
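The per-presentation accuracies above hinge on the episode encoding built in `mann/model.py`: at each time step the network receives the current image concatenated with the one-hot label of the *previous* step, so it must bind samples to labels via memory. A minimal numpy sketch of that offset encoding (illustrative only, not part of the repo; shapes use the default settings):

```python
import numpy as np

nb_classes, seq_len, input_size = 5, 50, 400   # 5 classes x 10 samples, 20x20 images
images = np.random.rand(seq_len, input_size).astype(np.float32)
labels = np.random.randint(0, nb_classes, seq_len)

one_hot = np.eye(nb_classes, dtype=np.float32)[labels]      # [seq_len, nb_classes]
offset = np.vstack([np.zeros((1, nb_classes), np.float32),  # zero label at t=0
                    one_hot[:-1]])                          # y_{t-1} at time t
episode_input = np.concatenate([images, offset], axis=1)    # [seq_len, 405]
```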
"uniform"): 35 | # sampled_characters = random.sample(character_folders, nb_classes) 36 | # if sample_stategy == "random": 37 | # images_labels = [(label, os.path.join(character, image_path)) \ 38 | # for label, character in zip(np.arange(nb_classes), sampled_characters) \ 39 | # for image_path in os.listdir(character)] 40 | # images_labels = random.sample(images_labels, nb_classes * nb_samples_per_class) 41 | # elif sample_stategy == "uniform": 42 | # sampler = lambda x: random.sample(x, nb_samples_per_class) 43 | # images_labels = [(i, os.path.join(path, image)) 44 | # for i, path in zip(np.arange(nb_classes), sampled_characters) 45 | # for image in sampler(os.listdir(path))] 46 | # random.shuffle(images_labels) 47 | # return images_labels 48 | 49 | # def load_transform(image_path, angle=0., size=(20, 20)): 50 | # # Load the image 51 | # original = imread(image_path, flatten=True) 52 | # # Rotate the image 53 | # rotated = np.maximum(np.minimum(rotate(original, angle=angle * 90 + (np.random.rand() - 0.5) * 22.5, cval=1.), 1.), 0.) 54 | # # Resize the image 55 | # resized = np.asarray(scipy.misc.imresize(rotated, size=size), dtype=np.float32) / 255. 56 | # # Invert the image 57 | # inverted = 1. - resized 58 | # max_value = np.max(inverted) 59 | # if max_value > 0.: 60 | # inverted /= max_value 61 | # return inverted 62 | 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /mann/model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from .mann_cell import MANNCell 4 | 5 | 6 | 7 | class MANN(): 8 | def __init__(self, learning_rate = 1e-3, input_size = 20 * 20, memory_size = 128, memory_dim = 40, 9 | controller_size = 200, nb_reads = 4, num_layers = 1, nb_classes = 5, nb_samples_per_class = 10, batch_size = 16, model="MANN"): 10 | self.learning_rate = learning_rate 11 | self.input_size = input_size 12 | self.memory_size = memory_size 13 | self.memory_dim = memory_dim 14 | self.controller_size = controller_size 15 | self.num_layers = num_layers 16 | self.nb_reads = nb_reads 17 | self.nb_classes = nb_classes 18 | self.nb_samples_per_class = nb_samples_per_class 19 | self.batch_size = batch_size 20 | self.model = model 21 | 22 | self.image = tf.placeholder(dtype=tf.float32, shape=[self.batch_size, self.nb_classes * self.nb_samples_per_class, self.input_size], name="input_var") 23 | self.label = tf.placeholder(dtype=tf.int32, shape=[self.batch_size, self.nb_classes * self.nb_samples_per_class], name="target_var") 24 | 25 | 26 | def build_model(self): 27 | input_var = self.image 28 | target_var = self.label 29 | 30 | one_hot_target = tf.one_hot(target_var, self.nb_classes, axis=-1) 31 | offset_target_var = tf.concat([tf.zeros_like(tf.expand_dims(one_hot_target[:, 0], 1)), one_hot_target[:,:-1]], axis=1) 32 | ntm_input = tf.concat([input_var, offset_target_var], axis=2) 33 | 34 | if self.model == "LSTM": 35 | def rnn_cell(rnn_size): 36 | return tf.nn.rnn_cell.BasicLSTMCell(rnn_size) 37 | cell = tf.nn.rnn_cell.MultiRNNCell([rnn_cell(self.controller_size) for _ in range(self.num_layers)]) 38 | hidden_dim = self.controller_size 39 | elif self.model == "MANN": 40 | cell = MANNCell(lstm_size=self.controller_size, memory_size=self.memory_size, memory_dim=self.memory_dim, nb_reads=self.nb_reads) 41 | hidden_dim = self.controller_size + self.nb_reads * self.memory_dim 42 | 43 | state = cell.zero_state(self.batch_size, tf.float32) 44 | output, cell_state = tf.scan(lambda init, elem: 
/mann/model.py:
--------------------------------------------------------------------------------
import tensorflow as tf

from .mann_cell import MANNCell


class MANN():
    def __init__(self, learning_rate=1e-3, input_size=20 * 20, memory_size=128, memory_dim=40,
                 controller_size=200, nb_reads=4, num_layers=1, nb_classes=5, nb_samples_per_class=10, batch_size=16, model="MANN"):
        self.learning_rate = learning_rate
        self.input_size = input_size
        self.memory_size = memory_size
        self.memory_dim = memory_dim
        self.controller_size = controller_size
        self.num_layers = num_layers
        self.nb_reads = nb_reads
        self.nb_classes = nb_classes
        self.nb_samples_per_class = nb_samples_per_class
        self.batch_size = batch_size
        self.model = model

        self.image = tf.placeholder(dtype=tf.float32, shape=[self.batch_size, self.nb_classes * self.nb_samples_per_class, self.input_size], name="input_var")
        self.label = tf.placeholder(dtype=tf.int32, shape=[self.batch_size, self.nb_classes * self.nb_samples_per_class], name="target_var")

    def build_model(self):
        input_var = self.image
        target_var = self.label

        # Offset the one-hot labels by one time step: at time t the network sees
        # image x_t together with the previous label y_{t-1} (zeros at t=0).
        one_hot_target = tf.one_hot(target_var, self.nb_classes, axis=-1)
        offset_target_var = tf.concat([tf.zeros_like(tf.expand_dims(one_hot_target[:, 0], 1)), one_hot_target[:, :-1]], axis=1)
        ntm_input = tf.concat([input_var, offset_target_var], axis=2)

        if self.model == "LSTM":
            def rnn_cell(rnn_size):
                return tf.nn.rnn_cell.BasicLSTMCell(rnn_size)
            cell = tf.nn.rnn_cell.MultiRNNCell([rnn_cell(self.controller_size) for _ in range(self.num_layers)])
            hidden_dim = self.controller_size
        elif self.model == "MANN":
            cell = MANNCell(lstm_size=self.controller_size, memory_size=self.memory_size, memory_dim=self.memory_dim, nb_reads=self.nb_reads)
            hidden_dim = self.controller_size + self.nb_reads * self.memory_dim

        # Unroll the cell over time with tf.scan (time-major), threading the
        # state through the iterations, then switch back to batch-major.
        state = cell.zero_state(self.batch_size, tf.float32)
        output, cell_state = tf.scan(lambda init, elem: cell(elem, init[1]), elems=tf.transpose(ntm_input, perm=[1, 0, 2]), initializer=(tf.zeros(shape=(self.batch_size, hidden_dim)), state))
        output = tf.transpose(output, perm=[1, 0, 2])

        with tf.variable_scope("o2o"):
            output = tf.layers.dense(
                inputs=output,
                units=self.nb_classes,
                kernel_initializer=tf.random_uniform_initializer(minval=-0.1, maxval=0.1),
                bias_initializer=tf.random_uniform_initializer(minval=-0.1, maxval=0.1),
            )

        self.output = tf.nn.softmax(output, axis=2)
        self.output = tf.reshape(self.output, shape=(self.batch_size, self.nb_classes * self.nb_samples_per_class, self.nb_classes))
        self.loss = tf.reduce_mean(tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(labels=one_hot_target, logits=output), axis=1))

        with tf.variable_scope("optimizer"):
            self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
            self.train_op = self.optimizer.minimize(self.loss)
--------------------------------------------------------------------------------
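`MANNCell` below keeps its state in a plain dict rather than implementing the `tf.nn.rnn_cell.RNNCell` interface, which is likely why `model.py` unrolls it with `tf.scan` instead of `tf.nn.dynamic_rnn`. For a fixed episode length, an explicit Python loop would build an equivalent graph (untested sketch using the names from `build_model`):

```python
outputs = []
state = cell.zero_state(self.batch_size, tf.float32)
for t in range(self.nb_classes * self.nb_samples_per_class):
    # MANNCell handles variable reuse internally via its step counter.
    out_t, state = cell(ntm_input[:, t, :], state)
    outputs.append(out_t)
output = tf.stack(outputs, axis=1)  # [batch_size, seq_len, hidden_dim]
```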
/mann/mann_cell.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
from mann.utils.tf_utils import variable_one_hot


class MANNCell():
    def __init__(self, lstm_size, memory_size, memory_dim, nb_reads,
                 gamma=0.95, reuse=False):
        self.lstm_size = lstm_size
        self.memory_size = memory_size
        self.memory_dim = memory_dim
        self.nb_reads = nb_reads
        self.reuse = reuse
        self.step = 0
        self.gamma = gamma
        self.controller = tf.nn.rnn_cell.BasicLSTMCell(self.lstm_size)

    def __call__(self, input, prev_state):
        M_prev, r_prev, controller_state_prev, wu_prev, wr_prev = \
            prev_state["M"], prev_state["read_vector"], prev_state["controller_state"], prev_state["wu"], prev_state["wr"]

        controller_input = tf.concat([input, wr_prev], axis=-1)
        with tf.variable_scope("controller", reuse=self.reuse):
            controller_hidden_t, controller_state_t = self.controller(controller_input, controller_state_prev)

        # Each head emits a key k and a write vector a (memory_dim each) plus a
        # scalar interpolation gate alpha.
        parameter_dim_per_head = self.memory_dim * 2 + 1
        parameter_total_dim = parameter_dim_per_head * self.nb_reads

        with tf.variable_scope("o2p", reuse=(self.step > 0) or self.reuse):
            parameter = tf.layers.dense(
                inputs=controller_hidden_t,
                units=parameter_total_dim,
                kernel_initializer=tf.random_uniform_initializer(minval=-0.1, maxval=0.1),
                bias_initializer=tf.random_uniform_initializer(minval=-0.1, maxval=0.1),
            )

        indices_prev, wlu_prev = self.least_used(wu_prev)

        k = tf.tanh(parameter[:, 0:self.nb_reads * self.memory_dim], name="k")
        a = tf.tanh(parameter[:, self.nb_reads * self.memory_dim: 2 * self.nb_reads * self.memory_dim], name="a")
        sig_alpha = tf.sigmoid(parameter[:, -self.nb_reads:], name="sig_alpha")

        wr_t = self.read_head_addressing(k, M_prev)
        ww_t = self.write_head_addressing(sig_alpha, wr_prev, wlu_prev)

        # Usage weights decay by gamma and accumulate this step's reads and writes.
        wu_t = self.gamma * wu_prev + tf.reduce_sum(wr_t, axis=1) + tf.reduce_sum(ww_t, axis=1)

        # "Prior to writing to memory, the least used memory location is set to zero."
        M_t = M_prev * tf.expand_dims(1. - tf.one_hot(indices_prev[:, -1], self.memory_size), axis=2)
        M_t = M_t + tf.matmul(tf.transpose(ww_t, perm=[0, 2, 1]), tf.reshape(a, shape=(a.get_shape()[0], self.nb_reads, self.memory_dim)))

        r_t = tf.reshape(tf.matmul(wr_t, M_t), shape=(r_prev.get_shape()[0], self.nb_reads * self.memory_dim))

        state = {
            "M": M_t,
            "read_vector": r_t,
            "controller_state": controller_state_t,
            "wu": wu_t,
            "wr": tf.reshape(wr_t, shape=(wr_t.get_shape()[0], self.nb_reads * self.memory_size)),
        }

        NTM_output = tf.concat([controller_hidden_t, r_t], axis=-1)

        self.step += 1
        return NTM_output, state

    def read_head_addressing(self, k, M_prev, eps=1e-8):
        # Cosine similarity between each read key and every memory row,
        # followed by a softmax over memory locations.
        with tf.variable_scope("read_head_addressing"):
            k = tf.reshape(k, shape=(k.get_shape()[0], self.nb_reads, self.memory_dim))
            inner_product = tf.matmul(k, tf.transpose(M_prev, [0, 2, 1]))

            k_norm = tf.sqrt(tf.expand_dims(tf.reduce_sum(tf.square(k), 2), 2))
            M_norm = tf.sqrt(tf.expand_dims(tf.reduce_sum(tf.square(M_prev), 2), 1))

            norm_product = k_norm * M_norm
            K = inner_product / (norm_product + eps)
            return tf.nn.softmax(K)

    def write_head_addressing(self, sig_alpha, wr_prev, wlu_prev):
        # Interpolate between the previous read weights and the least-used weights.
        with tf.variable_scope("write_head_addressing"):
            sig_alpha = tf.expand_dims(sig_alpha, axis=-1)
            wr_prev = tf.reshape(wr_prev, shape=(wr_prev.get_shape()[0], self.nb_reads, self.memory_size))
            return sig_alpha * wr_prev + (1. - sig_alpha) * tf.expand_dims(wlu_prev, axis=1)

    def least_used(self, w_u):
        # Sort usage weights in descending order; the last nb_reads indices are
        # the least-used locations, returned as a multi-hot weighting wlu.
        _, indices = tf.nn.top_k(w_u, k=self.memory_size)
        wlu = tf.cast(tf.slice(indices, [0, self.memory_size - self.nb_reads], [w_u.get_shape()[0], self.nb_reads]), dtype=tf.int32)
        wlu = tf.reduce_sum(tf.one_hot(wlu, self.memory_size), axis=1)
        return indices, wlu

    def zero_state(self, batch_size, dtype):
        with tf.variable_scope("init", reuse=self.reuse):
            # Memory starts near zero; read/usage weights start as one-hot vectors.
            M_0 = tf.constant(np.ones([batch_size, self.memory_size, self.memory_dim]) * 1e-6, dtype=tf.float32)
            r_0 = tf.zeros(shape=(batch_size, self.nb_reads * self.memory_dim))
            controller_state_0 = self.controller.zero_state(batch_size, dtype)
            wu_0 = variable_one_hot(shape=(batch_size, self.memory_size))
            wr_0 = variable_one_hot(shape=(batch_size, self.memory_size * self.nb_reads))

            state = {
                "M": M_0,
                "read_vector": r_0,
                "controller_state": controller_state_0,
                "wu": wu_0,
                "wr": wr_0,
            }

            return state
--------------------------------------------------------------------------------
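A worked numpy example (illustrative only) of what `MANNCell.least_used` computes: with `memory_size=4` and `nb_reads=2`, the two slots with the smallest usage weights are flagged for writing.

```python
import numpy as np

w_u = np.array([[0.7, 0.1, 0.9, 0.3]])  # usage weights for a batch of 1
order = np.argsort(-w_u, axis=1)        # descending order, like tf.nn.top_k
least = order[:, -2:]                   # indices of the 2 least-used slots: [3, 1]
w_lu = np.zeros_like(w_u)
w_lu[0, least[0]] = 1.0                 # -> [[0., 1., 0., 1.]]
```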
/run_mann.py:
--------------------------------------------------------------------------------
from mann.model import MANN
from mann.utils.generators import OmniglotGenerator

from argparse import ArgumentParser
import tensorflow as tf
import os
import numpy as np
import datetime


def build_argparser():
    parser = ArgumentParser()

    parser.add_argument('--mode', default="train", help='train or test')
    parser.add_argument('--restore_training', action='store_true',
                        help='Restore training from the latest checkpoint')
    parser.add_argument('--batch-size',
                        dest='_batch_size', help='Batch size (default: %(default)s)',
                        type=int, default=16)
    parser.add_argument('--num-classes',
                        dest='_nb_classes', help='Number of classes in each episode (default: %(default)s)',
                        type=int, default=5)
    parser.add_argument('--num-samples',
                        dest='_nb_samples_per_class', help='Number of samples per class in each episode (default: %(default)s)',
                        type=int, default=10)
    parser.add_argument('--input-height',
                        dest='_input_height', help='Input image height (default: %(default)s)',
                        type=int, default=20)
    parser.add_argument('--input-width',
                        dest='_input_width', help='Input image width (default: %(default)s)',
                        type=int, default=20)
    parser.add_argument('--num-reads',
                        dest='_nb_reads', help='Number of read heads (default: %(default)s)',
                        type=int, default=4)
    parser.add_argument('--controller-size',
                        dest='_controller_size', help='Number of hidden units in controller (default: %(default)s)',
                        type=int, default=200)
    parser.add_argument('--memory-locations',
                        dest='_memory_locations', help='Number of locations in the memory (default: %(default)s)',
                        type=int, default=128)
    parser.add_argument('--memory-word-size',
                        dest='_memory_word_size', help='Size of each word in memory (default: %(default)s)',
                        type=int, default=40)
    parser.add_argument('--num_layers',
                        dest='_num_layers', help='Number of stacked LSTM layers in the controller (default: %(default)s)',
                        type=int, default=1)
    parser.add_argument('--learning-rate',
                        dest='_learning_rate', help='Learning Rate (default: %(default)s)',
                        type=float, default=1e-3)
    parser.add_argument('--start_iterations',
                        dest='_start_iterations', type=int, default=0)
    parser.add_argument('--iterations',
                        dest='_iterations', help='Number of iterations for training (default: %(default)s)',
                        type=int, default=100000)
    parser.add_argument('--augment', default=True)  # NOTE: parsed but currently unused
    parser.add_argument('--save-dir', default='./ckpt/')
    parser.add_argument("--log-dir", default="./log/")
    parser.add_argument('--model', default="MANN", help='LSTM or MANN')

    return parser

def metric_accuracy(args, labels, outputs):
    # Accuracy bucketed by how many times a class has already been presented in
    # the episode (1st presentation, 2nd presentation, ...).
    seq_length = args._nb_classes * args._nb_samples_per_class
    outputs = np.argmax(outputs, axis=-1)
    correct = [0] * seq_length
    total = [0] * seq_length
    for i in range(np.shape(labels)[0]):
        label = labels[i]
        output = outputs[i]
        class_count = {}
        for j in range(seq_length):
            class_count[label[j]] = class_count.get(label[j], 0) + 1
            total[class_count[label[j]]] += 1
            if label[j] == output[j]:
                correct[class_count[label[j]]] += 1
    return [float(correct[i]) / total[i] if total[i] > 0. else 0.
            for i in range(1, args._nb_samples_per_class + 1)]
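# Worked example of the bookkeeping in metric_accuracy: with labels [3, 1, 3]
# and argmax predictions [3, 2, 3], the two 3s are the 1st and 2nd presentation
# of class 3 and the 1 is the 1st presentation of class 1, giving a
# 1st-presentation accuracy of 1/2 and a 2nd-presentation accuracy of 1/1.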


def train(model: MANN, data_generator: OmniglotGenerator, sess, saver, args):
    start_iter = args._start_iterations
    max_iter = args._iterations
    # Make sure the log and checkpoint directories exist before writing to them.
    os.makedirs(args.log_dir, exist_ok=True)
    os.makedirs(os.path.join(args.save_dir, args.model), exist_ok=True)
    csv_write_path = '{}/{}-{}-{}--{}.csv'.format(args.log_dir, args.model, args._nb_classes, args._nb_samples_per_class, datetime.datetime.now().strftime('%m-%d-%H-%M'))

    print(args)
    print("1st\t2nd\t3rd\t4th\t5th\t6th\t7th\t8th\t9th\t10th\tbatch\tloss")

    for ep in range(start_iter, max_iter):
        # Every 100 iterations, report per-presentation accuracy on a test batch.
        if ep % 100 == 0:
            image, label = data_generator.sample("test", args._batch_size)
            feed_dict = {model.image: image, model.label: label}
            output, loss = sess.run([model.output, model.loss], feed_dict=feed_dict)
            accuracy = metric_accuracy(args, label, output)
            for accu in accuracy:
                print('%.4f' % accu, end='\t')
            print('%d\t%.4f' % (ep, loss))

            with open(csv_write_path, 'a') as fh:
                fh.write(str(ep) + ", " + ", ".join(['%.4f' % accu for accu in accuracy]) + "\n")

        if ep % 5000 == 0 and ep > 0:
            saver.save(sess, os.path.join(args.save_dir, args.model) + "/model.", global_step=ep)

        image, label = data_generator.sample("train", args._batch_size)
        feed_dict = {model.image: image, model.label: label}

        sess.run(model.train_op, feed_dict=feed_dict)

def test(model: MANN, data_generator: OmniglotGenerator, sess, args):
    print("Test Result\n1st\t2nd\t3rd\t4th\t5th\t6th\t7th\t8th\t9th\t10th\tloss")
    label_list = []
    output_list = []
    loss_list = []
    for ep in range(100):
        image, label = data_generator.sample("test", args._batch_size)
        feed_dict = {model.image: image, model.label: label}
        output, loss = sess.run([model.output, model.loss], feed_dict=feed_dict)
        label_list.append(label)
        output_list.append(output)
        loss_list.append(loss)
    accuracy = metric_accuracy(args, np.concatenate(label_list, axis=0), np.concatenate(output_list, axis=0))
    for accu in accuracy:
        print('%.4f' % accu, end='\t')
    print('%.4f' % np.mean(loss_list))

if __name__ == "__main__":
    parser = build_argparser()
    args = parser.parse_args()

    batch_size = args._batch_size
    nb_classes = args._nb_classes
    nb_samples_per_class = args._nb_samples_per_class
    img_size = (args._input_height, args._input_width)
    input_size = args._input_height * args._input_width

    nb_reads = args._nb_reads
    controller_size = args._controller_size
    memory_size = args._memory_locations
    memory_dim = args._memory_word_size
    num_layers = args._num_layers

    learning_rate = args._learning_rate

    model = MANN(learning_rate, input_size, memory_size, memory_dim,
                 controller_size, nb_reads, num_layers, nb_classes, nb_samples_per_class, batch_size, args.model)
    model.build_model()

    data_generator = OmniglotGenerator(data_folder="./omniglot", nb_classes=nb_classes,
                                       nb_samples_per_class=nb_samples_per_class, img_size=img_size)

    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    sess = tf.InteractiveSession(config=tf_config)

    if args.restore_training or args.mode == "test":
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(os.path.join(args.save_dir, args.model))
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        saver = tf.train.Saver(tf.global_variables())
        sess.run(tf.global_variables_initializer())

    if args.mode == "train":
        train(model, data_generator, sess, saver, args)
    elif args.mode == "test":
        test(model, data_generator, sess, args)

    sess.close()
--------------------------------------------------------------------------------