├── README.md
├── main.py
├── model.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
# Matching Networks for One Shot Learning
TensorFlow implementation of [Matching Networks for One Shot Learning](https://arxiv.org/abs/1606.04080) by Vinyals et al.

## Prerequisites
- Python 2.7 (the code uses `xrange`, so Python 3 is not supported)
- [NumPy](http://www.numpy.org/)
- [SciPy](https://www.scipy.org/)
- [tqdm](https://pypi.python.org/pypi/tqdm)
- [TensorFlow r1.0+](https://www.tensorflow.org/install/)


## Data
- [Omniglot](https://github.com/brendenlake/omniglot)


## Preparation
1. Download and extract the Omniglot dataset, then point `omniglot_train` and `omniglot_eval` in `utils.py` to your local paths.

2. The first training run generates `omniglot.npy` in the working directory. Its shape should be _(1623, 80, 28, 28, 1)_: 1623 character classes; 80 samples per class (20 drawings, each rotated by 0, 90, 180, and 270 degrees); then height, width, and channel. 1200 classes are used for training and the remaining 423 for testing. See the sanity-check snippet at the end of this README.

## Train
```bash
python main.py --train
```
Train from a previous checkpoint saved at epoch X:
```bash
python main.py --train --modelpath=log/matchnet.ckpt-X
```
Show the tunable hyper-parameters:
```bash
python main.py
```

## Test
```bash
python main.py --eval
```

## Notes
- The model evaluates accuracy on the held-out classes after every training epoch.
- As the paper indicates, FCE does not improve results on Omniglot, but I implemented both FCE functions anyway (as far as I know, no repo fully implements them so far).
- The authors do not mention the number of time steps K in FCE_f; in the [cited paper](https://arxiv.org/abs/1511.06391), K is tested at 0, 1, 5, and 10 (see its Table 1).
- With the data generated by `utils.py`, evaluation accuracy at epoch 100 is around 82.00% (training accuracy 83.14%) without data augmentation.
- With the data provided by _zergylord_ in his [repo](https://github.com/zergylord/oneshot), this implementation reaches up to 96.61% evaluation accuracy (97.22% training) at epoch 100.
- Issues are welcome!

## Resources
- [The paper](https://arxiv.org/abs/1606.04080).
- Referred to [this repo](https://github.com/AntreasAntoniou/MatchingNetworks).
- [Karpathy's notes](https://github.com/karpathy/paper-notes/blob/master/matching_networks.md) help a lot.
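
## Sanity check
A minimal sketch for verifying the generated `omniglot.npy` (assuming it sits in the working directory); the dtype and value-range assertions mirror those in `Data_loader.__init__`:
```python
import numpy as np

data = np.load('omniglot.npy')
print(data.shape)  # expected: (1623, 80, 28, 28, 1)
assert data.dtype == np.float32
assert data.min() == 0.0 and data.max() == 1.0
```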

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import

import os
import argparse

import numpy as np
import tensorflow as tf

from model import Matching_Nets
from utils import Data_loader

def evaluate(args):

    loader = Data_loader(args.bsize, args.n_way, args.k_shot, train_mode=False)
    model = Matching_Nets(args.lr, args.n_way, args.k_shot, args.use_fce, args.bsize)

    model.build(model.support_set_image_ph, model.support_set_label_ph, model.example_image_ph)

    saver = tf.train.Saver()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    if args.modelpath is not None:
        print('Using model: {}'.format(args.modelpath))
        saver.restore(sess, args.modelpath)
    else:
        latest_ckpt = tf.train.latest_checkpoint('log')
        print('Using latest: {}'.format(latest_ckpt))
        saver.restore(sess, latest_ckpt)  # the original found the checkpoint but never restored it

    correct = 0
    for _ in xrange(loader.iters):
        x_set, y_set, x_hat, y_hat = loader.next_batch()
        feed_dict = {model.support_set_image_ph: x_set,
                     model.support_set_label_ph: y_set,
                     model.example_image_ph: x_hat}
        logits, prediction = sess.run([model.logits, model.pred], feed_dict=feed_dict)
        correct += np.sum(np.equal(prediction, y_hat))

    print('Evaluation accuracy: %.2f%%' % (correct * 100 / (loader.iters * args.bsize)))

def train(args):

    train_loader = Data_loader(args.bsize, args.n_way, args.k_shot)
    eval_loader = Data_loader(args.bsize, args.n_way, args.k_shot, train_mode=False)
    model = Matching_Nets(args.lr, args.n_way, args.k_shot, args.use_fce, args.bsize)

    model.build(model.support_set_image_ph, model.support_set_label_ph, model.example_image_ph)
    model.loss(model.example_label_ph)
    train_op = model.train()

    saver = tf.train.Saver()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    if args.modelpath is not None:
        print('From model: {}'.format(args.modelpath))
        saver.restore(sess, args.modelpath)

    print('Start training')
    print('batch size: %d, ep: %d, iter: %d, initial lr: %.3f' % (args.bsize, args.ep, train_loader.iters, args.lr))

    for ep in xrange(args.ep):
        # start training
        correct = []
        for step in xrange(train_loader.iters):
            x_set, y_set, x_hat, y_hat = train_loader.next_batch()
            feed_dict = {model.support_set_image_ph: x_set,
                         model.support_set_label_ph: y_set,
                         model.example_image_ph: x_hat,
                         model.example_label_ph: y_hat}
            logits, prediction, loss, _ = sess.run([model.logits, model.pred, model.loss_op, train_op], feed_dict=feed_dict)
            correct.append(np.equal(prediction, y_hat))

            if step % 100 == 0:
                print('ep: %3d, step: %3d, loss: %.3f, acc: %.2f%%' % (ep+1, step, loss, np.mean(np.equal(prediction, y_hat)) * 100))

        print('  Training accuracy: %.2f%%' % (np.mean(np.stack(correct)) * 100))
        checkpoint_path = os.path.join('log', 'matchnet.ckpt')
        saver.save(sess, checkpoint_path, global_step=ep+1)

        # one training epoch done, evaluate on the test set
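        # evaluation episodes are sampled from the 423 held-out classes
        # (Data_loader with train_mode=False); accuracy is averaged over
        # eval_loader.iters batches of size bsize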
        correct = []
        for step in xrange(eval_loader.iters):
            x_set, y_set, x_hat, y_hat = eval_loader.next_batch()
            feed_dict = {model.support_set_image_ph: x_set,
                         model.support_set_label_ph: y_set,
                         model.example_image_ph: x_hat}
            logits, prediction = sess.run([model.logits, model.pred], feed_dict=feed_dict)
            correct.append(np.equal(prediction, y_hat))

        print('Evaluation accuracy: %.2f%%' % (np.mean(np.stack(correct)) * 100))

    print('Done.')

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', action='store_true', help='set this to train.')
    parser.add_argument('--eval', action='store_true', help='set this to evaluate.')
    parser.add_argument('--lr', metavar='', type=float, default=1e-3, help='learning rate.')
    parser.add_argument('--ep', metavar='', type=int, default=100, help='number of epochs.')
    parser.add_argument('--bsize', metavar='', type=int, default=32, help='batch size.')
    parser.add_argument('--n-way', metavar='', type=int, default=20, help='number of classes per episode.')
    parser.add_argument('--k-shot', metavar='', type=int, default=1, help='number of support examples the model sees per class.')
    parser.add_argument('--use-fce', action='store_true', help='use the fully conditional embeddings.')  # was type=bool, which parses any non-empty string (including 'False') as True
    parser.add_argument('--modelpath', metavar='', type=str, default=None, help='trained tensorflow model path.')
    args, unparsed = parser.parse_known_args()
    if len(unparsed) != 0: raise SystemExit('Unknown argument: {}'.format(unparsed))
    if args.train:
        train(args)
    if args.eval:
        evaluate(args)
    if not args.train and not args.eval:
        parser.print_help()
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import tensorflow as tf

slim = tf.contrib.slim
rnn = tf.contrib.rnn

class Matching_Nets():

    def __init__(self, lr, n_way, k_shot, use_fce, batch_size=32):
        self.lr = lr
        self.n_way = n_way
        self.k_shot = k_shot
        self.use_fce = use_fce
        self.batch_size = batch_size
        self.processing_steps = 10  # the number of time steps K in FCE_f

        self.support_set_image_ph = tf.placeholder(tf.float32, [None, n_way * k_shot, 28, 28, 1])
        self.support_set_label_ph = tf.placeholder(tf.int32, [None, n_way * k_shot, ])
        self.example_image_ph = tf.placeholder(tf.float32, [None, 28, 28, 1])
        self.example_label_ph = tf.placeholder(tf.int32, [None, ])

    def image_encoder(self, image):
        """The embedding function for images (potentially f = g).

        For Omniglot this is a simple 4-layer ConvNet; for mini-ImageNet the
        paper uses a deeper network such as VGG or Inception.
        """
        with slim.arg_scope([slim.conv2d], num_outputs=64, kernel_size=3, normalizer_fn=slim.batch_norm):
            net = slim.conv2d(image)
            net = slim.max_pool2d(net, [2, 2])
            net = slim.conv2d(net)
            net = slim.max_pool2d(net, [2, 2])
            net = slim.conv2d(net)
            net = slim.max_pool2d(net, [2, 2])
            net = slim.conv2d(net)
            net = slim.max_pool2d(net, [2, 2])
        return tf.reshape(net, [-1, 1 * 1 * 64])
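
    # shape note on image_encoder above: with 3x3 SAME convolutions and 2x2
    # VALID max-pooling (the slim defaults), the spatial size shrinks
    # 28 -> 14 -> 7 -> 3 -> 1, so the reshape yields one 64-d vector per image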

    def fce_g(self, encoded_x_i):
        """The fully conditional embedding function g.

        A bidirectional LSTM: g(x_i, S) = h_i(forward) + h_i(backward) + g'(x_i),
        where g' is the image encoder. Not used for Omniglot.

        encoded_x_i: g'(x_i) in the equation; a length n * k list of
            (batch_size, 64) tensors.
        """
        fw_cell = rnn.BasicLSTMCell(32)  # 32 units per direction; the concatenated fw/bw output matches the 64-d CNN embedding
        bw_cell = rnn.BasicLSTMCell(32)
        outputs, state_fw, state_bw = rnn.static_bidirectional_rnn(fw_cell, bw_cell, encoded_x_i, dtype=tf.float32)

        return tf.add(tf.stack(encoded_x_i), tf.stack(outputs))

    def fce_f(self, encoded_x, g_embedding):
        """The fully conditional embedding function f.

        A vanilla LSTM with attention, where the input at each time step is
        constant and the hidden state is a function of the previous hidden
        state and a concatenated readout vector. Not used for Omniglot.

        encoded_x: f'(x_hat) in equation (3) of the paper's appendix A.1. (batch_size, 64)
        g_embedding: g(x_i) in equations (5) and (6) of appendix A.1. (n * k, batch_size, 64)
        """
        cell = rnn.BasicLSTMCell(64)
        prev_state = cell.zero_state(self.batch_size, tf.float32)  # state[0] is c, state[1] is h

        for step in xrange(self.processing_steps):
            output, state = cell(encoded_x, prev_state)  # output: (batch_size, 64)

            h_k = tf.add(output, encoded_x)  # (batch_size, 64)

            content_based_attention = tf.nn.softmax(tf.multiply(prev_state[1], g_embedding))  # (n * k, batch_size, 64)
            r_k = tf.reduce_sum(tf.multiply(content_based_attention, g_embedding), axis=0)  # (batch_size, 64)

            prev_state = rnn.LSTMStateTuple(state[0], tf.add(h_k, r_k))

        return output

    def cosine_similarity(self, target, support_set):
        """The c() function that computes the cosine similarity between the (embedded) support set and the (embedded) target.

        Note: the author uses one-sided cosine similarity, as zergylord noted in his repo (zergylord/oneshot).
        """
        #target_normed = tf.nn.l2_normalize(target, 1)  # (batch_size, 64)
        target_normed = target
        sup_similarity = []
        for i in tf.unstack(support_set):
            i_normed = tf.nn.l2_normalize(i, 1)  # (batch_size, 64)
            similarity = tf.matmul(tf.expand_dims(target_normed, 1), tf.expand_dims(i_normed, 2))  # (batch_size, 1, 1)
            sup_similarity.append(similarity)

        return tf.squeeze(tf.stack(sup_similarity, axis=1))  # (batch_size, n * k)

    def build(self, support_set_image, support_set_label, image):
        """The main graph of matching networks."""
        image_encoded = self.image_encoder(image)  # (batch_size, 64)
        support_set_image_encoded = [self.image_encoder(i) for i in tf.unstack(support_set_image, axis=1)]

        if self.use_fce:
            g_embedding = self.fce_g(support_set_image_encoded)  # (n * k, batch_size, 64)
            f_embedding = self.fce_f(image_encoded, g_embedding)  # (batch_size, 64)
        else:
            g_embedding = tf.stack(support_set_image_encoded)  # (n * k, batch_size, 64)
            f_embedding = image_encoded  # (batch_size, 64)

        # c(f(x_hat), g(x_i))
        embeddings_similarity = self.cosine_similarity(f_embedding, g_embedding)  # (batch_size, n * k)

        # compute a softmax over the similarities to get a(x_hat, x_i)
        attention = tf.nn.softmax(embeddings_similarity)

        # \hat{y} = \sum_{i=1}^{k} a(\hat{x}, x_i) y_i
        y_hat = tf.matmul(tf.expand_dims(attention, 1), tf.one_hot(support_set_label, self.n_way))
        self.logits = tf.squeeze(y_hat)  # (batch_size, n)

        self.pred = tf.argmax(self.logits, 1)
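
    # shape walk-through of the readout in build(), for the default n_way=20, k_shot=1:
    #   attention:                          (batch_size, 20)
    #   tf.expand_dims(attention, 1):       (batch_size, 1, 20)
    #   tf.one_hot(support_set_label, 20):  (batch_size, 20, 20)
    #   matmul then squeeze:                (batch_size, 20) class scores (self.logits)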

    def loss(self, label):
        self.loss_op = tf.losses.sparse_softmax_cross_entropy(label, self.logits)

    def train(self):
        return tf.train.AdamOptimizer(self.lr).minimize(self.loss_op)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import

import os
import glob

import scipy.misc
import numpy as np

from tqdm import tqdm

def read_omniglot():
    """Read the Omniglot dataset and save it to a single npy file."""
    omniglot_train = '/home/one-shot-dataset/omniglot/python/images_background'
    omniglot_eval = '/home/one-shot-dataset/omniglot/python/images_evaluation'

    data = []
    for r in [omniglot_train, omniglot_eval]:
        alphabets = glob.glob(r + '/*')
        for alphabet in tqdm(alphabets):
            characters = glob.glob(alphabet + '/*')
            for ch in characters:
                drawings = glob.glob(ch + '/*')
                raws = []
                for d in drawings:  # 20 drawings per character
                    raw = scipy.misc.imread(d)
                    raw = scipy.misc.imresize(raw, (28, 28))
                    for dg in [0, 90, 180, 270]:  # rotation augmentation
                        raw_rot = scipy.misc.imrotate(raw, dg)
                        raw_rot = raw_rot[:, :, np.newaxis]  # (28, 28, 1)
                        raw_rot = raw_rot.astype(np.float32) / 255.
                        raws.append(raw_rot)
                data.append(np.asarray(raws))  # (80, 28, 28, 1) per character class
    np.save('omniglot.npy', np.asarray(data))  # (1623, 80, 28, 28, 1)


class Data_loader():

    def __init__(self, batch_size, n_way=5, k_shot=1, train_mode=True):
        if not os.path.exists('omniglot.npy'):
            read_omniglot()

        self.batch_size = batch_size
        self.n_way = n_way    # 5 or 20: how many classes the model has to select from
        self.k_shot = k_shot  # 1 or 5: how many support examples the model sees per class

        omniglot = np.load('omniglot.npy')
        #omniglot = np.load('data_zergylord.npy')
        #omniglot = np.reshape(omniglot, [-1, 20, 28, 28, 1])
        np.random.shuffle(omniglot)
        assert omniglot.dtype == np.float32
        assert omniglot.max() == 1.0
        assert omniglot.min() == 0.0

        # keep 20 of the 80 stored samples per class; samples are stored in
        # (drawing, rotation) order, so this takes all four rotations of the
        # first five drawings
        if train_mode:
            self.images = omniglot[:1200, :20, :, :, :]
        else:
            self.images = omniglot[1200:, :20, :, :, :]
        self.num_classes = self.images.shape[0]
        self.num_samples = self.images.shape[1]

        self.iters = self.num_classes

    def next_batch(self):
        x_set_batch = []
        y_set_batch = []
        x_hat_batch = []
        y_hat_batch = []
        for _ in xrange(self.batch_size):
            x_set = []
            y_set = []
            classes = np.random.permutation(self.num_classes)[:self.n_way]
            target_class = np.random.randint(self.n_way)
            for i, c in enumerate(classes):
                samples = np.random.permutation(self.num_samples)[:self.k_shot+1]
                for s in samples[:-1]:
                    x_set.append(self.images[c][s])
                    y_set.append(i)

                if i == target_class:
                    x_hat_batch.append(self.images[c][samples[-1]])
                    y_hat_batch.append(i)

            x_set_batch.append(x_set)
            y_set_batch.append(y_set)

        return np.asarray(x_set_batch).astype(np.float32), np.asarray(y_set_batch).astype(np.int32), np.asarray(x_hat_batch).astype(np.float32), np.asarray(y_hat_batch).astype(np.int32)
--------------------------------------------------------------------------------