├── .gitignore
├── README.md
├── __init__.py
├── data_generator.py
├── main.py
├── maml.py
├── readme
│   ├── metatrain_Postupdate_accuracy__step_1.png
│   ├── metatrain_Postupdate_loss__step_1.png
│   ├── metaval_Postupdate_accuracy__step_1.png
│   ├── metaval_Postupdate_accuracy__step_1_time.png
│   └── metaval_Postupdate_loss__step_1.png
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | *.pyc
3 | data/
4 | logs/
5 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | A Meta-SGD ([Meta-SGD: Learning to Learn Quickly for Few Shot Learning (Zhenguo Li et al.)](https://arxiv.org/abs/1707.09835)) experiment on Omniglot classification, compared with MAML ([Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks (Finn et al., ICML 2017)](https://arxiv.org/abs/1703.03400)).
4 |
5 | Code adapted from [MAML](https://github.com/cbfinn/maml).
6 |
7 | Data from [Omniglot](https://github.com/brendenlake/omniglot).
8 |
9 | Note: this implementation differs slightly from the paper [Meta-SGD: Learning to Learn Quickly for Few Shot Learning (Zhenguo Li et al.)](https://arxiv.org/abs/1707.09835): the data for the meta-update are not drawn from a separate dataset.
10 |
11 | ### Usage
12 |
13 | Train (5-way, 1-shot Omniglot):
14 | ```
15 | python main.py --datasource=omniglot --metatrain_iterations=40000 --meta_batch_size=32 --update_batch_size=1 --update_lr=0.4 --num_updates=1 --logdir=logs/omniglot5way/
16 | ```
17 | Evaluate on the test set:
18 | ```
19 | python main.py --datasource=omniglot --metatrain_iterations=40000 --meta_batch_size=32 --update_batch_size=1 --update_lr=0.4 --num_updates=1 --logdir=logs/omniglot5way/ --train=False --test_set=True
20 | ```
21 |
22 | ### Meta-SGD vs. MAML
23 |
24 | The x-axis in all of the figures below is the iteration step.
25 |
26 | (Figures: meta-train and meta-val post-update accuracy and loss at step 1; see the images under readme/.)
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 | Considering wall-clock time rather than the iteration step (readme/metaval_Postupdate_accuracy__step_1_time.png):
35 |
36 |
37 | - Meta-SGD converges faster and reaches better performance than MAML.
38 | - The conclusion is the same whether progress is measured in iterations or in wall-clock time.
39 | - Unlike MAML, Meta-SGD's performance does not degrade over long training runs.
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foolyc/Meta-SGD/4922a8dab9bf6368654f174b9d3976dc77627012/__init__.py
--------------------------------------------------------------------------------
/data_generator.py:
--------------------------------------------------------------------------------
1 | """ Code for loading data. """
2 | import numpy as np
3 | import os
4 | import random
5 | import tensorflow as tf
6 |
7 | from tensorflow.python.platform import flags
8 | from utils import get_images
9 |
10 | FLAGS = flags.FLAGS
11 |
12 | class DataGenerator(object):
13 |     """
14 |     Data Generator capable of generating batches of sinusoid or Omniglot data.
15 |     A "class" is considered a class of omniglot digits or a particular sinusoid function.
16 |     """
17 |     def __init__(self, num_samples_per_class, batch_size, config={}):
18 |         """
19 |         Args:
20 |             num_samples_per_class: num samples to generate per class in one batch
21 |             batch_size: meta batch size (e.g.
number of functions) 22 | """ 23 | self.batch_size = batch_size 24 | self.num_samples_per_class = num_samples_per_class 25 | self.num_classes = 1 # by default 1 (only relevant for classification problems) 26 | 27 | if FLAGS.datasource == 'sinusoid': 28 | self.generate = self.generate_sinusoid_batch 29 | self.amp_range = config.get('amp_range', [0.1, 5.0]) 30 | self.phase_range = config.get('phase_range', [0, np.pi]) 31 | self.input_range = config.get('input_range', [-5.0, 5.0]) 32 | self.dim_input = 1 33 | self.dim_output = 1 34 | elif 'omniglot' in FLAGS.datasource: 35 | self.num_classes = config.get('num_classes', FLAGS.num_classes) 36 | self.img_size = config.get('img_size', (28, 28)) 37 | self.dim_input = np.prod(self.img_size) 38 | self.dim_output = self.num_classes 39 | # data that is pre-resized using PIL with lanczos filter 40 | data_folder = config.get('data_folder', './data/omniglot_resized') 41 | 42 | character_folders = [os.path.join(data_folder, family, character) \ 43 | for family in os.listdir(data_folder) \ 44 | if os.path.isdir(os.path.join(data_folder, family)) \ 45 | for character in os.listdir(os.path.join(data_folder, family))] 46 | random.seed(1) 47 | random.shuffle(character_folders) 48 | num_val = 100 49 | num_train = config.get('num_train', 1200) - num_val 50 | self.metatrain_character_folders = character_folders[:num_train] 51 | if FLAGS.test_set: 52 | self.metaval_character_folders = character_folders[num_train:num_train+num_val] 53 | else: 54 | self.metaval_character_folders = character_folders[num_train+num_val:] 55 | self.rotations = config.get('rotations', [0, 90, 180, 270]) 56 | elif FLAGS.datasource == 'miniimagenet': 57 | self.num_classes = config.get('num_classes', FLAGS.num_classes) 58 | self.img_size = config.get('img_size', (84, 84)) 59 | self.dim_input = np.prod(self.img_size)*3 60 | self.dim_output = self.num_classes 61 | metatrain_folder = config.get('metatrain_folder', './data/miniImagenet/train') 62 | if FLAGS.test_set: 63 | metaval_folder = config.get('metaval_folder', './data/miniImagenet/test') 64 | else: 65 | metaval_folder = config.get('metaval_folder', './data/miniImagenet/val') 66 | 67 | metatrain_folders = [os.path.join(metatrain_folder, label) \ 68 | for label in os.listdir(metatrain_folder) \ 69 | if os.path.isdir(os.path.join(metatrain_folder, label)) \ 70 | ] 71 | metaval_folders = [os.path.join(metaval_folder, label) \ 72 | for label in os.listdir(metaval_folder) \ 73 | if os.path.isdir(os.path.join(metaval_folder, label)) \ 74 | ] 75 | self.metatrain_character_folders = metatrain_folders 76 | self.metaval_character_folders = metaval_folders 77 | self.rotations = config.get('rotations', [0]) 78 | else: 79 | raise ValueError('Unrecognized data source') 80 | 81 | 82 | def make_data_tensor(self, train=True): 83 | if train: 84 | folders = self.metatrain_character_folders 85 | # number of tasks, not number of meta-iterations. 
(divide by metabatch size to measure) 86 | num_total_batches = 200000 87 | else: 88 | folders = self.metaval_character_folders 89 | num_total_batches = 600 90 | 91 | # make list of files 92 | print('Generating filenames') 93 | all_filenames = [] 94 | for _ in range(num_total_batches): 95 | sampled_character_folders = random.sample(folders, self.num_classes) 96 | random.shuffle(sampled_character_folders) 97 | labels_and_images = get_images(sampled_character_folders, range(self.num_classes), nb_samples=self.num_samples_per_class, shuffle=False) 98 | # make sure the above isn't randomized order 99 | labels = [li[0] for li in labels_and_images] 100 | filenames = [li[1] for li in labels_and_images] 101 | all_filenames.extend(filenames) 102 | 103 | # make queue for tensorflow to read from 104 | filename_queue = tf.train.string_input_producer(tf.convert_to_tensor(all_filenames), shuffle=False) 105 | print('Generating image processing ops') 106 | image_reader = tf.WholeFileReader() 107 | _, image_file = image_reader.read(filename_queue) 108 | if FLAGS.datasource == 'miniimagenet': 109 | image = tf.image.decode_jpeg(image_file) 110 | image.set_shape((self.img_size[0],self.img_size[1],3)) 111 | image = tf.reshape(image, [self.dim_input]) 112 | image = tf.cast(image, tf.float32) / 255.0 113 | else: 114 | image = tf.image.decode_png(image_file) 115 | image.set_shape((self.img_size[0],self.img_size[1],1)) 116 | image = tf.reshape(image, [self.dim_input]) 117 | image = tf.cast(image, tf.float32) / 255.0 118 | image = 1.0 - image # invert 119 | num_preprocess_threads = 1 # TODO - enable this to be set to >1 120 | min_queue_examples = 256 121 | examples_per_batch = self.num_classes * self.num_samples_per_class 122 | batch_image_size = self.batch_size * examples_per_batch 123 | print('Batching images') 124 | images = tf.train.batch( 125 | [image], 126 | batch_size = batch_image_size, 127 | num_threads=num_preprocess_threads, 128 | capacity=min_queue_examples + 3 * batch_image_size, 129 | ) 130 | all_image_batches, all_label_batches = [], [] 131 | print('Manipulating image data to be right shape') 132 | for i in range(self.batch_size): 133 | image_batch = images[i*examples_per_batch:(i+1)*examples_per_batch] 134 | 135 | if FLAGS.datasource == 'omniglot': 136 | # omniglot augments the dataset by rotating digits to create new classes 137 | # get rotation per class (e.g. 
0,1,2,0,0 if there are 5 classes) 138 | rotations = tf.multinomial(tf.log([[1., 1.,1.,1.]]), self.num_classes) 139 | label_batch = tf.convert_to_tensor(labels) 140 | new_list, new_label_list = [], [] 141 | for k in range(self.num_samples_per_class): 142 | class_idxs = tf.range(0, self.num_classes) 143 | class_idxs = tf.random_shuffle(class_idxs) 144 | 145 | true_idxs = class_idxs*self.num_samples_per_class + k 146 | new_list.append(tf.gather(image_batch,true_idxs)) 147 | if FLAGS.datasource == 'omniglot': # and FLAGS.train: 148 | new_list[-1] = tf.stack([tf.reshape(tf.image.rot90( 149 | tf.reshape(new_list[-1][ind], [self.img_size[0],self.img_size[1],1]), 150 | k=tf.cast(rotations[0,class_idxs[ind]], tf.int32)), (self.dim_input,)) 151 | for ind in range(self.num_classes)]) 152 | new_label_list.append(tf.gather(label_batch, true_idxs)) 153 | new_list = tf.concat(new_list, 0) # has shape [self.num_classes*self.num_samples_per_class, self.dim_input] 154 | new_label_list = tf.concat(new_label_list, 0) 155 | all_image_batches.append(new_list) 156 | all_label_batches.append(new_label_list) 157 | all_image_batches = tf.stack(all_image_batches) 158 | all_label_batches = tf.stack(all_label_batches) 159 | all_label_batches = tf.one_hot(all_label_batches, self.num_classes) 160 | return all_image_batches, all_label_batches 161 | 162 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage Instructions: 3 | 10-shot sinusoid: 4 | python main.py --datasource=sinusoid --logdir=logs/sine/ --metatrain_iterations=70000 --norm=None --update_batch_size=10 5 | 6 | 10-shot sinusoid baselines: 7 | python main.py --datasource=sinusoid --logdir=logs/sine/ --pretrain_iterations=70000 --metatrain_iterations=0 --norm=None --update_batch_size=10 --baseline=oracle 8 | python main.py --datasource=sinusoid --logdir=logs/sine/ --pretrain_iterations=70000 --metatrain_iterations=0 --norm=None --update_batch_size=10 9 | 10 | 5-way, 1-shot omniglot: 11 | python main.py --datasource=omniglot --metatrain_iterations=40000 --meta_batch_size=32 --update_batch_size=1 --update_lr=0.4 --num_updates=1 --logdir=logs/omniglot5way/ 12 | 13 | 20-way, 1-shot omniglot: 14 | python main.py --datasource=omniglot --metatrain_iterations=40000 --meta_batch_size=16 --update_batch_size=1 --num_classes=20 --update_lr=0.1 --num_updates=5 --logdir=logs/omniglot20way/ 15 | 16 | 5-way 1-shot mini imagenet: 17 | python main.py --datasource=miniimagenet --metatrain_iterations=60000 --meta_batch_size=4 --update_batch_size=1 --update_lr=0.01 --num_updates=5 --num_classes=5 --logdir=logs/miniimagenet1shot/ --num_filters=32 --max_pool=True 18 | 19 | 5-way 5-shot mini imagenet: 20 | python main.py --datasource=miniimagenet --metatrain_iterations=60000 --meta_batch_size=4 --update_batch_size=5 --update_lr=0.01 --num_updates=5 --num_classes=5 --logdir=logs/miniimagenet5shot/ --num_filters=32 --max_pool=True 21 | 22 | To run evaluation, use the '--train=False' flag and the '--test_set=True' flag to use the test set. 23 | 24 | For omniglot and miniimagenet training, acquire the dataset online, put it in the correspoding data directory, and see the python script instructions in that directory to preprocess the data. 
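    As a rough sketch of that preprocessing step (the repo's own script may differ), the
    Omniglot loader in data_generator.py expects 28x28 PNGs under ./data/omniglot_resized,
    "pre-resized using PIL with lanczos filter", e.g.:

        import glob
        from PIL import Image

        # Shrink every character image in place to 28x28 after copying the raw
        # Omniglot alphabet/character folders into ./data/omniglot_resized.
        for path in glob.glob('./data/omniglot_resized/*/*/*.png'):
            Image.open(path).resize((28, 28), resample=Image.LANCZOS).save(path)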
25 | """ 26 | import csv 27 | import numpy as np 28 | import pickle 29 | import random 30 | import tensorflow as tf 31 | 32 | from data_generator import DataGenerator 33 | from maml import MAML 34 | from tensorflow.python.platform import flags 35 | 36 | FLAGS = flags.FLAGS 37 | 38 | ## Dataset/method options 39 | flags.DEFINE_string('datasource', 'sinusoid', 'sinusoid or omniglot or miniimagenet') 40 | flags.DEFINE_integer('num_classes', 5, 'number of classes used in classification (e.g. 5-way classification).') 41 | # oracle means task id is input (only suitable for sinusoid) 42 | flags.DEFINE_string('baseline', None, 'oracle, or None') 43 | 44 | ## Training options 45 | flags.DEFINE_integer('pretrain_iterations', 0, 'number of pre-training iterations.') 46 | flags.DEFINE_integer('metatrain_iterations', 15000, 'number of metatraining iterations.') # 15k for omniglot, 50k for sinusoid 47 | flags.DEFINE_integer('meta_batch_size', 25, 'number of tasks sampled per meta-update') 48 | flags.DEFINE_float('meta_lr', 0.001, 'the base learning rate of the generator') 49 | flags.DEFINE_integer('update_batch_size', 5, 'number of examples used for inner gradient update (K for K-shot learning).') 50 | flags.DEFINE_float('update_lr', 1e-3, 'step size alpha for inner gradient update.') # 0.1 for omniglot 51 | flags.DEFINE_integer('num_updates', 1, 'number of inner gradient updates during training.') 52 | 53 | ## Model options 54 | flags.DEFINE_string('norm', 'batch_norm', 'batch_norm, layer_norm, or None') 55 | flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- 32 for miniimagenet, 64 for omiglot.') 56 | flags.DEFINE_bool('conv', True, 'whether or not to use a convolutional network, only applicable in some cases') 57 | flags.DEFINE_bool('max_pool', False, 'Whether or not to use max pooling rather than strided convolutions') 58 | flags.DEFINE_bool('stop_grad', False, 'if True, do not use second derivatives in meta-optimization (for speed)') 59 | 60 | ## Logging, saving, and testing options 61 | flags.DEFINE_bool('log', True, 'if false, do not log summaries, for debugging code.') 62 | flags.DEFINE_string('logdir', '/tmp/data', 'directory for summaries and checkpoints.') 63 | flags.DEFINE_bool('resume', True, 'resume training if there is a model available') 64 | flags.DEFINE_bool('train', True, 'True to train, False to test.') 65 | flags.DEFINE_integer('test_iter', -1, 'iteration to load model (-1 for latest model)') 66 | flags.DEFINE_bool('test_set', False, 'Set to true to test on the the test set, False for the validation set.') 67 | flags.DEFINE_integer('train_update_batch_size', -1, 'number of examples used for gradient update during training (use if you want to test with a different number).') 68 | flags.DEFINE_float('train_update_lr', -1, 'value of inner gradient step step during training. 
(use if you want to test with a different value)') # 0.1 for omniglot 69 | 70 | def train(model, saver, sess, exp_string, data_generator, resume_itr=0): 71 | SUMMARY_INTERVAL = 100 72 | SAVE_INTERVAL = 1000 73 | if FLAGS.datasource == 'sinusoid': 74 | PRINT_INTERVAL = 1000 75 | TEST_PRINT_INTERVAL = PRINT_INTERVAL*5 76 | else: 77 | PRINT_INTERVAL = 100 78 | TEST_PRINT_INTERVAL = PRINT_INTERVAL*5 79 | 80 | if FLAGS.log: 81 | train_writer = tf.summary.FileWriter(FLAGS.logdir + '/' + exp_string, sess.graph) 82 | print('Done initializing, starting training.') 83 | prelosses, postlosses = [], [] 84 | 85 | num_classes = data_generator.num_classes # for classification, 1 otherwise 86 | multitask_weights, reg_weights = [], [] 87 | 88 | for itr in range(resume_itr, FLAGS.pretrain_iterations + FLAGS.metatrain_iterations): 89 | feed_dict = {} 90 | if 'generate' in dir(data_generator): 91 | batch_x, batch_y, amp, phase = data_generator.generate() 92 | 93 | if FLAGS.baseline == 'oracle': 94 | batch_x = np.concatenate([batch_x, np.zeros([batch_x.shape[0], batch_x.shape[1], 2])], 2) 95 | for i in range(FLAGS.meta_batch_size): 96 | batch_x[i, :, 1] = amp[i] 97 | batch_x[i, :, 2] = phase[i] 98 | 99 | inputa = batch_x[:, :num_classes*FLAGS.update_batch_size, :] 100 | labela = batch_y[:, :num_classes*FLAGS.update_batch_size, :] 101 | inputb = batch_x[:, num_classes*FLAGS.update_batch_size:, :] # b used for testing 102 | labelb = batch_y[:, num_classes*FLAGS.update_batch_size:, :] 103 | feed_dict = {model.inputa: inputa, model.inputb: inputb, model.labela: labela, model.labelb: labelb} 104 | 105 | if itr < FLAGS.pretrain_iterations: 106 | input_tensors = [model.pretrain_op] 107 | else: 108 | input_tensors = [model.metatrain_op] 109 | 110 | if (itr % SUMMARY_INTERVAL == 0 or itr % PRINT_INTERVAL == 0): 111 | input_tensors.extend([model.summ_op, model.total_loss1, model.total_losses2[FLAGS.num_updates-1]]) 112 | if model.classification: 113 | input_tensors.extend([model.total_accuracy1, model.total_accuracies2[FLAGS.num_updates-1]]) 114 | 115 | result = sess.run(input_tensors, feed_dict) 116 | 117 | if itr % SUMMARY_INTERVAL == 0: 118 | prelosses.append(result[-2]) 119 | if FLAGS.log: 120 | train_writer.add_summary(result[1], itr) 121 | postlosses.append(result[-1]) 122 | 123 | if (itr!=0) and itr % PRINT_INTERVAL == 0: 124 | if itr < FLAGS.pretrain_iterations: 125 | print_str = 'Pretrain Iteration ' + str(itr) 126 | else: 127 | print_str = 'Iteration ' + str(itr - FLAGS.pretrain_iterations) 128 | print_str += ': ' + str(np.mean(prelosses)) + ', ' + str(np.mean(postlosses)) 129 | print(print_str) 130 | prelosses, postlosses = [], [] 131 | 132 | if (itr!=0) and itr % SAVE_INTERVAL == 0: 133 | saver.save(sess, FLAGS.logdir + '/' + exp_string + '/model' + str(itr)) 134 | 135 | # sinusoid is infinite data, so no need to test on meta-validation set. 
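        # For omniglot / miniimagenet the metaval_* tensors below read from their own
        # tf input queue, so only an empty feed_dict is needed; the else branch instead
        # feeds a freshly generated validation batch with meta_lr pinned to 0.0 so that
        # evaluation never triggers a meta-update.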
136 | if (itr!=0) and itr % TEST_PRINT_INTERVAL == 0 and FLAGS.datasource !='sinusoid': 137 | if 'generate' not in dir(data_generator): 138 | feed_dict = {} 139 | if model.classification: 140 | input_tensors = [model.metaval_total_accuracy1, model.metaval_total_accuracies2[FLAGS.num_updates-1], model.summ_op] 141 | else: 142 | input_tensors = [model.metaval_total_loss1, model.metaval_total_losses2[FLAGS.num_updates-1], model.summ_op] 143 | else: 144 | batch_x, batch_y, amp, phase = data_generator.generate(train=False) 145 | inputa = batch_x[:, :num_classes*FLAGS.update_batch_size, :] 146 | inputb = batch_x[:, num_classes*FLAGS.update_batch_size:, :] 147 | labela = batch_y[:, :num_classes*FLAGS.update_batch_size, :] 148 | labelb = batch_y[:, num_classes*FLAGS.update_batch_size:, :] 149 | feed_dict = {model.inputa: inputa, model.inputb: inputb, model.labela: labela, model.labelb: labelb, model.meta_lr: 0.0} 150 | if model.classification: 151 | input_tensors = [model.total_accuracy1, model.total_accuracies2[FLAGS.num_updates-1]] 152 | else: 153 | input_tensors = [model.total_loss1, model.total_losses2[FLAGS.num_updates-1]] 154 | 155 | result = sess.run(input_tensors, feed_dict) 156 | print('Validation results: ' + str(result[0]) + ', ' + str(result[1])) 157 | 158 | saver.save(sess, FLAGS.logdir + '/' + exp_string + '/model' + str(itr)) 159 | 160 | # calculated for omniglot 161 | NUM_TEST_POINTS = 600 162 | 163 | def test(model, saver, sess, exp_string, data_generator, test_num_updates=None): 164 | num_classes = data_generator.num_classes # for classification, 1 otherwise 165 | 166 | np.random.seed(1) 167 | random.seed(1) 168 | 169 | metaval_accuracies = [] 170 | 171 | for _ in range(NUM_TEST_POINTS): 172 | if 'generate' not in dir(data_generator): 173 | feed_dict = {} 174 | feed_dict = {model.meta_lr : 0.0} 175 | else: 176 | batch_x, batch_y, amp, phase = data_generator.generate(train=False) 177 | 178 | if FLAGS.baseline == 'oracle': # NOTE - this flag is specific to sinusoid 179 | batch_x = np.concatenate([batch_x, np.zeros([batch_x.shape[0], batch_x.shape[1], 2])], 2) 180 | batch_x[0, :, 1] = amp[0] 181 | batch_x[0, :, 2] = phase[0] 182 | 183 | inputa = batch_x[:, :num_classes*FLAGS.update_batch_size, :] 184 | inputb = batch_x[:,num_classes*FLAGS.update_batch_size:, :] 185 | labela = batch_y[:, :num_classes*FLAGS.update_batch_size, :] 186 | labelb = batch_y[:,num_classes*FLAGS.update_batch_size:, :] 187 | 188 | feed_dict = {model.inputa: inputa, model.inputb: inputb, model.labela: labela, model.labelb: labelb, model.meta_lr: 0.0} 189 | 190 | if model.classification: 191 | result = sess.run([model.total_accuracy1] + model.total_accuracies2, feed_dict) 192 | else: # this is for sinusoid 193 | result = sess.run([model.total_loss1] + model.total_losses2, feed_dict) 194 | metaval_accuracies.append(result) 195 | 196 | metaval_accuracies = np.array(metaval_accuracies) 197 | means = np.mean(metaval_accuracies, 0) 198 | stds = np.std(metaval_accuracies, 0) 199 | ci95 = 1.96*stds/np.sqrt(NUM_TEST_POINTS) 200 | 201 | print('Mean validation accuracy/loss, stddev, and confidence intervals') 202 | print((means, stds, ci95)) 203 | 204 | out_filename = FLAGS.logdir +'/'+ exp_string + '/' + 'test_ubs' + str(FLAGS.update_batch_size) + '_stepsize' + str(FLAGS.update_lr) + '.csv' 205 | out_pkl = FLAGS.logdir +'/'+ exp_string + '/' + 'test_ubs' + str(FLAGS.update_batch_size) + '_stepsize' + str(FLAGS.update_lr) + '.pkl' 206 | with open(out_pkl, 'w') as f: 207 | pickle.dump({'mses': metaval_accuracies}, f) 208 | 
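    # NOTE: pickle.dump writes bytes, so under Python 3 the pickle file above must be
    # opened in binary mode ('wb'); the text-mode 'w' only works under Python 2.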
with open(out_filename, 'w') as f: 209 | writer = csv.writer(f, delimiter=',') 210 | writer.writerow(['update'+str(i) for i in range(len(means))]) 211 | writer.writerow(means) 212 | writer.writerow(stds) 213 | writer.writerow(ci95) 214 | 215 | def main(): 216 | if FLAGS.datasource == 'sinusoid': 217 | if FLAGS.train: 218 | test_num_updates = 5 219 | else: 220 | test_num_updates = 10 221 | else: 222 | if FLAGS.datasource == 'miniimagenet': 223 | if FLAGS.train == True: 224 | test_num_updates = 1 # eval on at least one update during training 225 | else: 226 | test_num_updates = 10 227 | else: 228 | test_num_updates = 10 229 | 230 | if FLAGS.train == False: 231 | orig_meta_batch_size = FLAGS.meta_batch_size 232 | # always use meta batch size of 1 when testing. 233 | FLAGS.meta_batch_size = 1 234 | 235 | if FLAGS.datasource == 'sinusoid': 236 | data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size) 237 | else: 238 | if FLAGS.metatrain_iterations == 0 and FLAGS.datasource == 'miniimagenet': 239 | assert FLAGS.meta_batch_size == 1 240 | assert FLAGS.update_batch_size == 1 241 | data_generator = DataGenerator(1, FLAGS.meta_batch_size) # only use one datapoint, 242 | else: 243 | if FLAGS.datasource == 'miniimagenet': # TODO - use 15 val examples for imagenet? 244 | if FLAGS.train: 245 | data_generator = DataGenerator(FLAGS.update_batch_size+15, FLAGS.meta_batch_size) # only use one datapoint for testing to save memory 246 | else: 247 | data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size) # only use one datapoint for testing to save memory 248 | else: 249 | data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size) # only use one datapoint for testing to save memory 250 | 251 | 252 | dim_output = data_generator.dim_output 253 | if FLAGS.baseline == 'oracle': 254 | assert FLAGS.datasource == 'sinusoid' 255 | dim_input = 3 256 | FLAGS.pretrain_iterations += FLAGS.metatrain_iterations 257 | FLAGS.metatrain_iterations = 0 258 | else: 259 | dim_input = data_generator.dim_input 260 | 261 | if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'omniglot': 262 | tf_data_load = True 263 | num_classes = data_generator.num_classes 264 | 265 | if FLAGS.train: # only construct training model if needed 266 | random.seed(5) 267 | image_tensor, label_tensor = data_generator.make_data_tensor() 268 | inputa = tf.slice(image_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1]) 269 | inputb = tf.slice(image_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1]) 270 | labela = tf.slice(label_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1]) 271 | labelb = tf.slice(label_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1]) 272 | input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb} 273 | 274 | random.seed(6) 275 | image_tensor, label_tensor = data_generator.make_data_tensor(train=False) 276 | inputa = tf.slice(image_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1]) 277 | inputb = tf.slice(image_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1]) 278 | labela = tf.slice(label_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1]) 279 | labelb = tf.slice(label_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1]) 280 | metaval_input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb} 281 | else: 282 | tf_data_load = False 283 | input_tensors = None 284 | 285 | model = MAML(dim_input, 
dim_output, test_num_updates=test_num_updates) 286 | if FLAGS.train or not tf_data_load: 287 | model.construct_model(input_tensors=input_tensors, prefix='metatrain_') 288 | if tf_data_load: 289 | model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_') 290 | model.summ_op = tf.summary.merge_all() 291 | 292 | saver = loader = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10) 293 | 294 | sess = tf.InteractiveSession() 295 | 296 | if FLAGS.train == False: 297 | # change to original meta batch size when loading model. 298 | FLAGS.meta_batch_size = orig_meta_batch_size 299 | 300 | if FLAGS.train_update_batch_size == -1: 301 | FLAGS.train_update_batch_size = FLAGS.update_batch_size 302 | if FLAGS.train_update_lr == -1: 303 | FLAGS.train_update_lr = FLAGS.update_lr 304 | 305 | exp_string = 'cls_'+str(FLAGS.num_classes)+'.mbs_'+str(FLAGS.meta_batch_size) + '.ubs_' + str(FLAGS.train_update_batch_size) + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr) 306 | 307 | if FLAGS.num_filters != 64: 308 | exp_string += 'hidden' + str(FLAGS.num_filters) 309 | if FLAGS.max_pool: 310 | exp_string += 'maxpool' 311 | if FLAGS.stop_grad: 312 | exp_string += 'stopgrad' 313 | if FLAGS.baseline: 314 | exp_string += FLAGS.baseline 315 | if FLAGS.norm == 'batch_norm': 316 | exp_string += 'batchnorm' 317 | elif FLAGS.norm == 'layer_norm': 318 | exp_string += 'layernorm' 319 | elif FLAGS.norm == 'None': 320 | exp_string += 'nonorm' 321 | else: 322 | print('Norm setting not recognized.') 323 | 324 | resume_itr = 0 325 | model_file = None 326 | 327 | tf.global_variables_initializer().run() 328 | tf.train.start_queue_runners() 329 | 330 | if FLAGS.resume or not FLAGS.train: 331 | model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string) 332 | if FLAGS.test_iter > 0: 333 | model_file = model_file[:model_file.index('model')] + 'model' + str(FLAGS.test_iter) 334 | if model_file: 335 | ind1 = model_file.index('model') 336 | resume_itr = int(model_file[ind1+5:]) 337 | print("Restoring model weights from " + model_file) 338 | saver.restore(sess, model_file) 339 | 340 | if FLAGS.train: 341 | train(model, saver, sess, exp_string, data_generator, resume_itr) 342 | else: 343 | test(model, saver, sess, exp_string, data_generator, test_num_updates) 344 | 345 | if __name__ == "__main__": 346 | main() 347 | -------------------------------------------------------------------------------- /maml.py: -------------------------------------------------------------------------------- 1 | """ Code for the MAML algorithm and network definitions. """ 2 | import numpy as np 3 | # import special_grads 4 | import tensorflow as tf 5 | 6 | from tensorflow.python.platform import flags 7 | from utils import mse, xent, conv_block, normalize 8 | 9 | FLAGS = flags.FLAGS 10 | 11 | class MAML: 12 | def __init__(self, dim_input=1, dim_output=1, test_num_updates=5): 13 | """ must call construct_model() after initializing MAML! 
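        Despite the class name, construct_model() creates the inner-loop step size
        self.update_lr as a trainable tf.Variable, so the step size of the inner update
        theta' = theta - update_lr * grad(loss_a) is meta-learned together with the
        weights (the Meta-SGD modification; here a single scalar rather than the
        per-parameter vector of the paper). Typical call sequence, mirroring main.py:

            model = MAML(dim_input, dim_output, test_num_updates=test_num_updates)
            model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
            model.summ_op = tf.summary.merge_all()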
""" 14 | self.dim_input = dim_input 15 | self.dim_output = dim_output 16 | # self.update_lr = FLAGS.update_lr 17 | self.meta_lr = tf.placeholder_with_default(FLAGS.meta_lr, ()) 18 | self.classification = False 19 | self.test_num_updates = test_num_updates 20 | if FLAGS.datasource == 'sinusoid': 21 | self.dim_hidden = [40, 40] 22 | self.loss_func = mse 23 | self.forward = self.forward_fc 24 | self.construct_weights = self.construct_fc_weights 25 | elif FLAGS.datasource == 'omniglot' or FLAGS.datasource == 'miniimagenet': 26 | self.loss_func = xent 27 | self.classification = True 28 | if FLAGS.conv: 29 | self.dim_hidden = FLAGS.num_filters 30 | self.forward = self.forward_conv 31 | self.construct_weights = self.construct_conv_weights 32 | else: 33 | self.dim_hidden = [256, 128, 64, 64] 34 | self.forward=self.forward_fc 35 | self.construct_weights = self.construct_fc_weights 36 | if FLAGS.datasource == 'miniimagenet': 37 | self.channels = 3 38 | else: 39 | self.channels = 1 40 | self.img_size = int(np.sqrt(self.dim_input/self.channels)) 41 | else: 42 | raise ValueError('Unrecognized data source.') 43 | 44 | def construct_model(self, input_tensors=None, prefix='metatrain_'): 45 | # a: training data for inner gradient, b: test data for meta gradient 46 | if input_tensors is None: 47 | self.inputa = tf.placeholder(tf.float32) 48 | self.inputb = tf.placeholder(tf.float32) 49 | self.labela = tf.placeholder(tf.float32) 50 | self.labelb = tf.placeholder(tf.float32) 51 | else: 52 | self.inputa = input_tensors['inputa'] 53 | self.inputb = input_tensors['inputb'] 54 | self.labela = input_tensors['labela'] 55 | self.labelb = input_tensors['labelb'] 56 | 57 | with tf.variable_scope('model', reuse=None) as training_scope: 58 | if 'weights' in dir(self): 59 | training_scope.reuse_variables() 60 | weights = self.weights 61 | else: 62 | # Define the weights 63 | self.weights = weights = self.construct_weights() 64 | self.update_lr = tf.Variable(0.001, "updatelr") 65 | 66 | 67 | # outputbs[i] and lossesb[i] is the output and loss after i+1 gradient updates 68 | lossesa, outputas, lossesb, outputbs = [], [], [], [] 69 | accuraciesa, accuraciesb = [], [] 70 | num_updates = max(self.test_num_updates, FLAGS.num_updates) 71 | outputbs = [[]]*num_updates 72 | lossesb = [[]]*num_updates 73 | accuraciesb = [[]]*num_updates 74 | 75 | def task_metalearn(inp, reuse=True): 76 | """ Perform gradient descent for one task in the meta-batch. 
""" 77 | inputa, inputb, labela, labelb = inp 78 | task_outputbs, task_lossesb = [], [] 79 | 80 | if self.classification: 81 | task_accuraciesb = [] 82 | 83 | task_outputa = self.forward(inputa, weights, reuse=reuse) # only reuse on the first iter 84 | task_lossa = self.loss_func(task_outputa, labela) 85 | 86 | grads = tf.gradients(task_lossa, list(weights.values())) 87 | if FLAGS.stop_grad: 88 | grads = [tf.stop_gradient(grad) for grad in grads] 89 | gradients = dict(zip(weights.keys(), grads)) 90 | fast_weights = dict(zip(weights.keys(), [weights[key] - self.update_lr*gradients[key] for key in weights.keys()])) 91 | output = self.forward(inputb, fast_weights, reuse=True) 92 | task_outputbs.append(output) 93 | task_lossesb.append(self.loss_func(output, labelb)) 94 | 95 | for j in range(num_updates - 1): 96 | loss = self.loss_func(self.forward(inputa, fast_weights, reuse=True), labela) 97 | grads = tf.gradients(loss, list(fast_weights.values())) 98 | if FLAGS.stop_grad: 99 | grads = [tf.stop_gradient(grad) for grad in grads] 100 | gradients = dict(zip(fast_weights.keys(), grads)) 101 | fast_weights = dict(zip(fast_weights.keys(), [fast_weights[key] - self.update_lr*gradients[key] for key in fast_weights.keys()])) 102 | output = self.forward(inputb, fast_weights, reuse=True) 103 | task_outputbs.append(output) 104 | task_lossesb.append(self.loss_func(output, labelb)) 105 | 106 | task_output = [task_outputa, task_outputbs, task_lossa, task_lossesb] 107 | 108 | if self.classification: 109 | task_accuracya = tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(task_outputa), 1), tf.argmax(labela, 1)) 110 | for j in range(num_updates): 111 | task_accuraciesb.append(tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(task_outputbs[j]), 1), tf.argmax(labelb, 1))) 112 | task_output.extend([task_accuracya, task_accuraciesb]) 113 | 114 | return task_output 115 | 116 | if FLAGS.norm is not 'None': 117 | # to initialize the batch norm vars, might want to combine this, and not run idx 0 twice. 
118 | unused = task_metalearn((self.inputa[0], self.inputb[0], self.labela[0], self.labelb[0]), False) 119 | 120 | out_dtype = [tf.float32, [tf.float32]*num_updates, tf.float32, [tf.float32]*num_updates] 121 | if self.classification: 122 | out_dtype.extend([tf.float32, [tf.float32]*num_updates]) 123 | result = tf.map_fn(task_metalearn, elems=(self.inputa, self.inputb, self.labela, self.labelb), dtype=out_dtype, parallel_iterations=FLAGS.meta_batch_size) 124 | if self.classification: 125 | outputas, outputbs, lossesa, lossesb, accuraciesa, accuraciesb = result 126 | else: 127 | outputas, outputbs, lossesa, lossesb = result 128 | 129 | ## Performance & Optimization 130 | if 'train' in prefix: 131 | self.total_loss1 = total_loss1 = tf.reduce_sum(lossesa) / tf.to_float(FLAGS.meta_batch_size) 132 | self.total_losses2 = total_losses2 = [tf.reduce_sum(lossesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)] 133 | # after the map_fn 134 | self.outputas, self.outputbs = outputas, outputbs 135 | if self.classification: 136 | self.total_accuracy1 = total_accuracy1 = tf.reduce_sum(accuraciesa) / tf.to_float(FLAGS.meta_batch_size) 137 | self.total_accuracies2 = total_accuracies2 = [tf.reduce_sum(accuraciesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)] 138 | self.pretrain_op = tf.train.AdamOptimizer(self.meta_lr).minimize(total_loss1) 139 | 140 | if FLAGS.metatrain_iterations > 0: 141 | optimizer = tf.train.AdamOptimizer(self.meta_lr) 142 | self.gvs = gvs = optimizer.compute_gradients(self.total_losses2[FLAGS.num_updates-1]) 143 | if FLAGS.datasource == 'miniimagenet': 144 | gvs = [(tf.clip_by_value(grad, -10, 10), var) for grad, var in gvs] 145 | self.metatrain_op = optimizer.apply_gradients(gvs) 146 | else: 147 | self.metaval_total_loss1 = total_loss1 = tf.reduce_sum(lossesa) / tf.to_float(FLAGS.meta_batch_size) 148 | self.metaval_total_losses2 = total_losses2 = [tf.reduce_sum(lossesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)] 149 | if self.classification: 150 | self.metaval_total_accuracy1 = total_accuracy1 = tf.reduce_sum(accuraciesa) / tf.to_float(FLAGS.meta_batch_size) 151 | self.total_accuracy1 = self.metaval_total_accuracy1 152 | self.metaval_total_accuracies2 = total_accuracies2 =[tf.reduce_sum(accuraciesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)] 153 | self.total_accuracies2 = self.metaval_total_accuracies2 154 | 155 | ## Summaries 156 | tf.summary.scalar(prefix+'Pre-update loss', total_loss1) 157 | if self.classification: 158 | tf.summary.scalar(prefix+'Pre-update accuracy', total_accuracy1) 159 | 160 | for j in range(num_updates): 161 | tf.summary.scalar(prefix+'Post-update loss, step ' + str(j+1), total_losses2[j]) 162 | if self.classification: 163 | tf.summary.scalar(prefix+'Post-update accuracy, step ' + str(j+1), total_accuracies2[j]) 164 | 165 | ### Network construction functions (fc networks and conv networks) 166 | def construct_fc_weights(self): 167 | weights = {} 168 | weights['w1'] = tf.Variable(tf.truncated_normal([self.dim_input, self.dim_hidden[0]], stddev=0.01)) 169 | weights['b1'] = tf.Variable(tf.zeros([self.dim_hidden[0]])) 170 | for i in range(1,len(self.dim_hidden)): 171 | weights['w'+str(i+1)] = tf.Variable(tf.truncated_normal([self.dim_hidden[i-1], self.dim_hidden[i]], stddev=0.01)) 172 | weights['b'+str(i+1)] = tf.Variable(tf.zeros([self.dim_hidden[i]])) 173 | weights['w'+str(len(self.dim_hidden)+1)] = tf.Variable(tf.truncated_normal([self.dim_hidden[-1], 
self.dim_output], stddev=0.01)) 174 | weights['b'+str(len(self.dim_hidden)+1)] = tf.Variable(tf.zeros([self.dim_output])) 175 | return weights 176 | 177 | def forward_fc(self, inp, weights, reuse=False): 178 | hidden = normalize(tf.matmul(inp, weights['w1']) + weights['b1'], activation=tf.nn.relu, reuse=reuse, scope='0') 179 | for i in range(1,len(self.dim_hidden)): 180 | hidden = normalize(tf.matmul(hidden, weights['w'+str(i+1)]) + weights['b'+str(i+1)], activation=tf.nn.relu, reuse=reuse, scope=str(i+1)) 181 | return tf.matmul(hidden, weights['w'+str(len(self.dim_hidden)+1)]) + weights['b'+str(len(self.dim_hidden)+1)] 182 | 183 | def construct_conv_weights(self): 184 | weights = {} 185 | 186 | dtype = tf.float32 187 | conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype) 188 | fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype) 189 | k = 3 190 | 191 | weights['conv1'] = tf.get_variable('conv1', [k, k, self.channels, self.dim_hidden], initializer=conv_initializer, dtype=dtype) 192 | weights['b1'] = tf.Variable(tf.zeros([self.dim_hidden])) 193 | weights['conv2'] = tf.get_variable('conv2', [k, k, self.dim_hidden, self.dim_hidden], initializer=conv_initializer, dtype=dtype) 194 | weights['b2'] = tf.Variable(tf.zeros([self.dim_hidden])) 195 | weights['conv3'] = tf.get_variable('conv3', [k, k, self.dim_hidden, self.dim_hidden], initializer=conv_initializer, dtype=dtype) 196 | weights['b3'] = tf.Variable(tf.zeros([self.dim_hidden])) 197 | weights['conv4'] = tf.get_variable('conv4', [k, k, self.dim_hidden, self.dim_hidden], initializer=conv_initializer, dtype=dtype) 198 | weights['b4'] = tf.Variable(tf.zeros([self.dim_hidden])) 199 | if FLAGS.datasource == 'miniimagenet': 200 | # assumes max pooling 201 | weights['w5'] = tf.get_variable('w5', [self.dim_hidden*5*5, self.dim_output], initializer=fc_initializer) 202 | weights['b5'] = tf.Variable(tf.zeros([self.dim_output]), name='b5') 203 | else: 204 | weights['w5'] = tf.Variable(tf.random_normal([self.dim_hidden, self.dim_output]), name='w5') 205 | weights['b5'] = tf.Variable(tf.zeros([self.dim_output]), name='b5') 206 | return weights 207 | 208 | def forward_conv(self, inp, weights, reuse=False, scope=''): 209 | # reuse is for the normalization parameters. 
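        # Standard few-shot conv backbone: four 3x3 conv blocks (conv + norm + ReLU,
        # downsampled by stride-2 convolutions or max pooling, see utils.conv_block),
        # then global mean pooling (omniglot) or flattening (miniimagenet) feeding a
        # final linear layer w5/b5 that produces the logits.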
210 | channels = self.channels 211 | inp = tf.reshape(inp, [-1, self.img_size, self.img_size, channels]) 212 | 213 | hidden1 = conv_block(inp, weights['conv1'], weights['b1'], reuse, scope+'0') 214 | hidden2 = conv_block(hidden1, weights['conv2'], weights['b2'], reuse, scope+'1') 215 | hidden3 = conv_block(hidden2, weights['conv3'], weights['b3'], reuse, scope+'2') 216 | hidden4 = conv_block(hidden3, weights['conv4'], weights['b4'], reuse, scope+'3') 217 | if FLAGS.datasource == 'miniimagenet': 218 | # last hidden layer is 6x6x64-ish, reshape to a vector 219 | hidden4 = tf.reshape(hidden4, [-1, np.prod([int(dim) for dim in hidden4.get_shape()[1:]])]) 220 | else: 221 | hidden4 = tf.reduce_mean(hidden4, [1, 2]) 222 | 223 | return tf.matmul(hidden4, weights['w5']) + weights['b5'] 224 | 225 | 226 | -------------------------------------------------------------------------------- /readme/metatrain_Postupdate_accuracy__step_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foolyc/Meta-SGD/4922a8dab9bf6368654f174b9d3976dc77627012/readme/metatrain_Postupdate_accuracy__step_1.png -------------------------------------------------------------------------------- /readme/metatrain_Postupdate_loss__step_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foolyc/Meta-SGD/4922a8dab9bf6368654f174b9d3976dc77627012/readme/metatrain_Postupdate_loss__step_1.png -------------------------------------------------------------------------------- /readme/metaval_Postupdate_accuracy__step_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foolyc/Meta-SGD/4922a8dab9bf6368654f174b9d3976dc77627012/readme/metaval_Postupdate_accuracy__step_1.png -------------------------------------------------------------------------------- /readme/metaval_Postupdate_accuracy__step_1_time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foolyc/Meta-SGD/4922a8dab9bf6368654f174b9d3976dc77627012/readme/metaval_Postupdate_accuracy__step_1_time.png -------------------------------------------------------------------------------- /readme/metaval_Postupdate_loss__step_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foolyc/Meta-SGD/4922a8dab9bf6368654f174b9d3976dc77627012/readme/metaval_Postupdate_loss__step_1.png -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ Utility functions. 
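Provides get_images (per-class sampling of image paths), the conv_block and
normalize network helpers, and the mse / xent losses used by maml.py.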
""" 2 | import numpy as np 3 | import os 4 | import random 5 | import tensorflow as tf 6 | 7 | from tensorflow.contrib.layers.python import layers as tf_layers 8 | from tensorflow.python.platform import flags 9 | 10 | FLAGS = flags.FLAGS 11 | 12 | ## Image helper 13 | def get_images(paths, labels, nb_samples=None, shuffle=True): 14 | if nb_samples is not None: 15 | sampler = lambda x: random.sample(x, nb_samples) 16 | else: 17 | sampler = lambda x: x 18 | images = [(i, os.path.join(path, image)) \ 19 | for i, path in zip(labels, paths) \ 20 | for image in sampler(os.listdir(path))] 21 | if shuffle: 22 | random.shuffle(images) 23 | return images 24 | 25 | ## Network helpers 26 | def conv_block(inp, cweight, bweight, reuse, scope, activation=tf.nn.relu, max_pool_pad='VALID', residual=False): 27 | """ Perform, conv, batch norm, nonlinearity, and max pool """ 28 | stride, no_stride = [1,2,2,1], [1,1,1,1] 29 | 30 | if FLAGS.max_pool: 31 | conv_output = tf.nn.conv2d(inp, cweight, no_stride, 'SAME') + bweight 32 | else: 33 | conv_output = tf.nn.conv2d(inp, cweight, stride, 'SAME') + bweight 34 | normed = normalize(conv_output, activation, reuse, scope) 35 | if FLAGS.max_pool: 36 | normed = tf.nn.max_pool(normed, stride, stride, max_pool_pad) 37 | return normed 38 | 39 | def normalize(inp, activation, reuse, scope): 40 | if FLAGS.norm == 'batch_norm': 41 | return tf_layers.batch_norm(inp, activation_fn=activation, reuse=reuse, scope=scope) 42 | elif FLAGS.norm == 'layer_norm': 43 | return tf_layers.layer_norm(inp, activation_fn=activation, reuse=reuse, scope=scope) 44 | elif FLAGS.norm == 'None': 45 | return activation(inp) 46 | 47 | ## Loss functions 48 | def mse(pred, label): 49 | pred = tf.reshape(pred, [-1]) 50 | label = tf.reshape(label, [-1]) 51 | return tf.reduce_mean(tf.square(pred-label)) 52 | 53 | def xent(pred, label): 54 | # Note - with tf version <=0.12, this loss has incorrect 2nd derivatives 55 | return tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=label) / FLAGS.update_batch_size 56 | 57 | 58 | --------------------------------------------------------------------------------