├── .gitignore
├── README.md
├── __init__.py
├── data_generator.py
├── main.py
├── maml.py
├── readme
│   ├── metatrain_Postupdate_accuracy__step_1.png
│   ├── metatrain_Postupdate_loss__step_1.png
│   ├── metaval_Postupdate_accuracy__step_1.png
│   ├── metaval_Postupdate_accuracy__step_1_time.png
│   └── metaval_Postupdate_loss__step_1.png
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | *.pyc
3 | data/
4 | logs/
5 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | An experiment with Meta-SGD ([Meta-SGD: Learning to Learn Quickly for Few-Shot Learning (Zhenguo Li et al.)](https://arxiv.org/abs/1707.09835)) on Omniglot classification, compared against MAML ([Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks (Finn et al., ICML 2017)](https://arxiv.org/abs/1703.03400)).
4 |
5 | Code based on [MAML](https://github.com/cbfinn/maml).
6 |
7 | Data from [Omniglot](https://github.com/brendenlake/omniglot).
8 |
9 | Note: this implementation differs slightly from the paper [Meta-SGD: Learning to Learn Quickly for Few-Shot Learning (Zhenguo Li et al.)](https://arxiv.org/abs/1707.09835) in that the meta-update data do not come from a separate dataset (inputa and inputb in main.py are slices of the same sampled task batch).
10 |
11 | ### Usage
12 |
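Train a 5-way, 1-shot Omniglot model: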
13 | ```
14 | python main.py --datasource=omniglot --metatrain_iterations=40000 --meta_batch_size=32 --update_batch_size=1 --update_lr=0.4 --num_updates=1 --logdir=logs/omniglot5way/
15 |
16 | ```
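Evaluate the trained model on the held-out test set (`--train=False --test_set=True`):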
17 | ```
18 | python main.py --datasource=omniglot --metatrain_iterations=40000 --meta_batch_size=32 --update_batch_size=1 --update_lr=0.4 --num_updates=1 --logdir=logs/omniglot5way/ --train=False --test_set=True
19 |
20 | ```
21 |
22 | ### Meta-SGD vs. MAML
23 |
24 | In all of the figures below, the x-axis is the meta-training iteration step.
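The curves are TensorBoard summaries written under `--logdir` during training; the PNG files in `readme/` correspond to the post-update accuracy/loss summaries at step 1.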
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 | Considering wall-clock time rather than the iteration count:
35 |
36 |
37 | - Meta-SGD converges faster and reaches better performance than MAML.
38 | - The conclusion is the same whether progress is measured in iterations or in wall-clock time.
39 | - Unlike MAML, Meta-SGD's performance does not degrade over long-term training.
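
For reference, here is a minimal sketch (plain Python on a toy 1-D quadratic task, not code from this repository) of the conceptual difference between the two inner-loop updates: MAML adapts with a fixed step size, whereas Meta-SGD treats the step size as a meta-learned parameter (`self.update_lr` in maml.py; this repo uses a single trainable scalar, while the paper learns a per-parameter step-size vector).

```
# Toy task: loss(theta) = 0.5 * (theta - target)**2, so grad = theta - target
def grad(theta, target):
    return theta - target

theta0 = 0.0   # meta-learned initialization (shared by both methods)
target = 3.0   # task-specific optimum

# MAML inner update: fixed step size (FLAGS.update_lr in this repo)
fixed_lr = 0.4
theta_maml = theta0 - fixed_lr * grad(theta0, target)

# Meta-SGD inner update: the step size itself is a trainable parameter,
# adjusted by the meta-optimizer; 0.9 is an illustrative value.
learned_lr = 0.9
theta_metasgd = theta0 - learned_lr * grad(theta0, target)

print(theta_maml, theta_metasgd)  # 1.2 2.7
```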
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foolyc/Meta-SGD/4922a8dab9bf6368654f174b9d3976dc77627012/__init__.py
--------------------------------------------------------------------------------
/data_generator.py:
--------------------------------------------------------------------------------
1 | """ Code for loading data. """
2 | import numpy as np
3 | import os
4 | import random
5 | import tensorflow as tf
6 |
7 | from tensorflow.python.platform import flags
8 | from utils import get_images
9 |
10 | FLAGS = flags.FLAGS
11 |
12 | class DataGenerator(object):
13 | """
14 | Data Generator capable of generating batches of sinusoid or Omniglot data.
15 | A "class" is considered a class of omniglot digits or a particular sinusoid function.
16 | """
17 | def __init__(self, num_samples_per_class, batch_size, config={}):
18 | """
19 | Args:
20 | num_samples_per_class: num samples to generate per class in one batch
21 | batch_size: size of meta batch size (e.g. number of functions)
22 | """
23 | self.batch_size = batch_size
24 | self.num_samples_per_class = num_samples_per_class
25 | self.num_classes = 1 # by default 1 (only relevant for classification problems)
26 |
27 | if FLAGS.datasource == 'sinusoid':
28 | self.generate = self.generate_sinusoid_batch
29 | self.amp_range = config.get('amp_range', [0.1, 5.0])
30 | self.phase_range = config.get('phase_range', [0, np.pi])
31 | self.input_range = config.get('input_range', [-5.0, 5.0])
32 | self.dim_input = 1
33 | self.dim_output = 1
34 | elif 'omniglot' in FLAGS.datasource:
35 | self.num_classes = config.get('num_classes', FLAGS.num_classes)
36 | self.img_size = config.get('img_size', (28, 28))
37 | self.dim_input = np.prod(self.img_size)
38 | self.dim_output = self.num_classes
39 | # data that is pre-resized using PIL with lanczos filter
40 | data_folder = config.get('data_folder', './data/omniglot_resized')
41 |
42 | character_folders = [os.path.join(data_folder, family, character) \
43 | for family in os.listdir(data_folder) \
44 | if os.path.isdir(os.path.join(data_folder, family)) \
45 | for character in os.listdir(os.path.join(data_folder, family))]
46 | random.seed(1)
47 | random.shuffle(character_folders)
48 | num_val = 100
49 | num_train = config.get('num_train', 1200) - num_val
50 | self.metatrain_character_folders = character_folders[:num_train]
51 | if FLAGS.test_set:
52 | self.metaval_character_folders = character_folders[num_train:num_train+num_val]
53 | else:
54 | self.metaval_character_folders = character_folders[num_train+num_val:]
55 | self.rotations = config.get('rotations', [0, 90, 180, 270])
56 | elif FLAGS.datasource == 'miniimagenet':
57 | self.num_classes = config.get('num_classes', FLAGS.num_classes)
58 | self.img_size = config.get('img_size', (84, 84))
59 | self.dim_input = np.prod(self.img_size)*3
60 | self.dim_output = self.num_classes
61 | metatrain_folder = config.get('metatrain_folder', './data/miniImagenet/train')
62 | if FLAGS.test_set:
63 | metaval_folder = config.get('metaval_folder', './data/miniImagenet/test')
64 | else:
65 | metaval_folder = config.get('metaval_folder', './data/miniImagenet/val')
66 |
67 | metatrain_folders = [os.path.join(metatrain_folder, label) \
68 | for label in os.listdir(metatrain_folder) \
69 | if os.path.isdir(os.path.join(metatrain_folder, label)) \
70 | ]
71 | metaval_folders = [os.path.join(metaval_folder, label) \
72 | for label in os.listdir(metaval_folder) \
73 | if os.path.isdir(os.path.join(metaval_folder, label)) \
74 | ]
75 | self.metatrain_character_folders = metatrain_folders
76 | self.metaval_character_folders = metaval_folders
77 | self.rotations = config.get('rotations', [0])
78 | else:
79 | raise ValueError('Unrecognized data source')
80 |
81 |
82 | def make_data_tensor(self, train=True):
83 | if train:
84 | folders = self.metatrain_character_folders
85 | # number of tasks, not number of meta-iterations. (divide by metabatch size to measure)
86 | num_total_batches = 200000
87 | else:
88 | folders = self.metaval_character_folders
89 | num_total_batches = 600
90 |
91 | # make list of files
92 | print('Generating filenames')
93 | all_filenames = []
94 | for _ in range(num_total_batches):
95 | sampled_character_folders = random.sample(folders, self.num_classes)
96 | random.shuffle(sampled_character_folders)
97 | labels_and_images = get_images(sampled_character_folders, range(self.num_classes), nb_samples=self.num_samples_per_class, shuffle=False)
98 | # make sure the above isn't randomized order
99 | labels = [li[0] for li in labels_and_images]
100 | filenames = [li[1] for li in labels_and_images]
101 | all_filenames.extend(filenames)
102 |
103 | # make queue for tensorflow to read from
104 | filename_queue = tf.train.string_input_producer(tf.convert_to_tensor(all_filenames), shuffle=False)
105 | print('Generating image processing ops')
106 | image_reader = tf.WholeFileReader()
107 | _, image_file = image_reader.read(filename_queue)
108 | if FLAGS.datasource == 'miniimagenet':
109 | image = tf.image.decode_jpeg(image_file)
110 | image.set_shape((self.img_size[0],self.img_size[1],3))
111 | image = tf.reshape(image, [self.dim_input])
112 | image = tf.cast(image, tf.float32) / 255.0
113 | else:
114 | image = tf.image.decode_png(image_file)
115 | image.set_shape((self.img_size[0],self.img_size[1],1))
116 | image = tf.reshape(image, [self.dim_input])
117 | image = tf.cast(image, tf.float32) / 255.0
118 | image = 1.0 - image # invert
119 | num_preprocess_threads = 1 # TODO - enable this to be set to >1
120 | min_queue_examples = 256
121 | examples_per_batch = self.num_classes * self.num_samples_per_class
122 | batch_image_size = self.batch_size * examples_per_batch
123 | print('Batching images')
124 | images = tf.train.batch(
125 | [image],
126 | batch_size = batch_image_size,
127 | num_threads=num_preprocess_threads,
128 | capacity=min_queue_examples + 3 * batch_image_size,
129 | )
130 | all_image_batches, all_label_batches = [], []
131 | print('Manipulating image data to be right shape')
132 | for i in range(self.batch_size):
133 | image_batch = images[i*examples_per_batch:(i+1)*examples_per_batch]
134 |
135 | if FLAGS.datasource == 'omniglot':
136 | # omniglot augments the dataset by rotating digits to create new classes
137 | # get rotation per class (e.g. 0,1,2,0,0 if there are 5 classes)
138 | rotations = tf.multinomial(tf.log([[1., 1.,1.,1.]]), self.num_classes)
139 | label_batch = tf.convert_to_tensor(labels)
140 | new_list, new_label_list = [], []
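            # Examples are regrouped by shot index k (one image per class for k=0, then k=1, ...),
            # so main.py can slice the first num_classes*update_batch_size examples as the
            # inner-update set (inputa) and the remainder as the meta-update set (inputb).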
141 | for k in range(self.num_samples_per_class):
142 | class_idxs = tf.range(0, self.num_classes)
143 | class_idxs = tf.random_shuffle(class_idxs)
144 |
145 | true_idxs = class_idxs*self.num_samples_per_class + k
146 | new_list.append(tf.gather(image_batch,true_idxs))
147 | if FLAGS.datasource == 'omniglot': # and FLAGS.train:
148 | new_list[-1] = tf.stack([tf.reshape(tf.image.rot90(
149 | tf.reshape(new_list[-1][ind], [self.img_size[0],self.img_size[1],1]),
150 | k=tf.cast(rotations[0,class_idxs[ind]], tf.int32)), (self.dim_input,))
151 | for ind in range(self.num_classes)])
152 | new_label_list.append(tf.gather(label_batch, true_idxs))
153 | new_list = tf.concat(new_list, 0) # has shape [self.num_classes*self.num_samples_per_class, self.dim_input]
154 | new_label_list = tf.concat(new_label_list, 0)
155 | all_image_batches.append(new_list)
156 | all_label_batches.append(new_label_list)
157 | all_image_batches = tf.stack(all_image_batches)
158 | all_label_batches = tf.stack(all_label_batches)
159 | all_label_batches = tf.one_hot(all_label_batches, self.num_classes)
160 | return all_image_batches, all_label_batches
161 |
162 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage Instructions:
3 | 10-shot sinusoid:
4 | python main.py --datasource=sinusoid --logdir=logs/sine/ --metatrain_iterations=70000 --norm=None --update_batch_size=10
5 |
6 | 10-shot sinusoid baselines:
7 | python main.py --datasource=sinusoid --logdir=logs/sine/ --pretrain_iterations=70000 --metatrain_iterations=0 --norm=None --update_batch_size=10 --baseline=oracle
8 | python main.py --datasource=sinusoid --logdir=logs/sine/ --pretrain_iterations=70000 --metatrain_iterations=0 --norm=None --update_batch_size=10
9 |
10 | 5-way, 1-shot omniglot:
11 | python main.py --datasource=omniglot --metatrain_iterations=40000 --meta_batch_size=32 --update_batch_size=1 --update_lr=0.4 --num_updates=1 --logdir=logs/omniglot5way/
12 |
13 | 20-way, 1-shot omniglot:
14 | python main.py --datasource=omniglot --metatrain_iterations=40000 --meta_batch_size=16 --update_batch_size=1 --num_classes=20 --update_lr=0.1 --num_updates=5 --logdir=logs/omniglot20way/
15 |
16 | 5-way 1-shot mini imagenet:
17 | python main.py --datasource=miniimagenet --metatrain_iterations=60000 --meta_batch_size=4 --update_batch_size=1 --update_lr=0.01 --num_updates=5 --num_classes=5 --logdir=logs/miniimagenet1shot/ --num_filters=32 --max_pool=True
18 |
19 | 5-way 5-shot mini imagenet:
20 | python main.py --datasource=miniimagenet --metatrain_iterations=60000 --meta_batch_size=4 --update_batch_size=5 --update_lr=0.01 --num_updates=5 --num_classes=5 --logdir=logs/miniimagenet5shot/ --num_filters=32 --max_pool=True
21 |
22 | To run evaluation, use the '--train=False' flag and the '--test_set=True' flag to use the test set.
23 |
24 | For omniglot and miniimagenet training, acquire the dataset online, put it in the corresponding data directory, and see the python script instructions in that directory to preprocess the data.
25 | """
26 | import csv
27 | import numpy as np
28 | import pickle
29 | import random
30 | import tensorflow as tf
31 |
32 | from data_generator import DataGenerator
33 | from maml import MAML
34 | from tensorflow.python.platform import flags
35 |
36 | FLAGS = flags.FLAGS
37 |
38 | ## Dataset/method options
39 | flags.DEFINE_string('datasource', 'sinusoid', 'sinusoid or omniglot or miniimagenet')
40 | flags.DEFINE_integer('num_classes', 5, 'number of classes used in classification (e.g. 5-way classification).')
41 | # oracle means task id is input (only suitable for sinusoid)
42 | flags.DEFINE_string('baseline', None, 'oracle, or None')
43 |
44 | ## Training options
45 | flags.DEFINE_integer('pretrain_iterations', 0, 'number of pre-training iterations.')
46 | flags.DEFINE_integer('metatrain_iterations', 15000, 'number of metatraining iterations.') # 15k for omniglot, 50k for sinusoid
47 | flags.DEFINE_integer('meta_batch_size', 25, 'number of tasks sampled per meta-update')
48 | flags.DEFINE_float('meta_lr', 0.001, 'the base learning rate of the generator')
49 | flags.DEFINE_integer('update_batch_size', 5, 'number of examples used for inner gradient update (K for K-shot learning).')
50 | flags.DEFINE_float('update_lr', 1e-3, 'step size alpha for inner gradient update.') # 0.1 for omniglot
51 | flags.DEFINE_integer('num_updates', 1, 'number of inner gradient updates during training.')
52 |
53 | ## Model options
54 | flags.DEFINE_string('norm', 'batch_norm', 'batch_norm, layer_norm, or None')
55 | flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.')
56 | flags.DEFINE_bool('conv', True, 'whether or not to use a convolutional network, only applicable in some cases')
57 | flags.DEFINE_bool('max_pool', False, 'Whether or not to use max pooling rather than strided convolutions')
58 | flags.DEFINE_bool('stop_grad', False, 'if True, do not use second derivatives in meta-optimization (for speed)')
59 |
60 | ## Logging, saving, and testing options
61 | flags.DEFINE_bool('log', True, 'if false, do not log summaries, for debugging code.')
62 | flags.DEFINE_string('logdir', '/tmp/data', 'directory for summaries and checkpoints.')
63 | flags.DEFINE_bool('resume', True, 'resume training if there is a model available')
64 | flags.DEFINE_bool('train', True, 'True to train, False to test.')
65 | flags.DEFINE_integer('test_iter', -1, 'iteration to load model (-1 for latest model)')
66 | flags.DEFINE_bool('test_set', False, 'Set to true to test on the test set, False for the validation set.')
67 | flags.DEFINE_integer('train_update_batch_size', -1, 'number of examples used for gradient update during training (use if you want to test with a different number).')
68 | flags.DEFINE_float('train_update_lr', -1, 'value of inner gradient step size during training. (use if you want to test with a different value)') # 0.1 for omniglot
69 |
70 | def train(model, saver, sess, exp_string, data_generator, resume_itr=0):
71 | SUMMARY_INTERVAL = 100
72 | SAVE_INTERVAL = 1000
73 | if FLAGS.datasource == 'sinusoid':
74 | PRINT_INTERVAL = 1000
75 | TEST_PRINT_INTERVAL = PRINT_INTERVAL*5
76 | else:
77 | PRINT_INTERVAL = 100
78 | TEST_PRINT_INTERVAL = PRINT_INTERVAL*5
79 |
80 | if FLAGS.log:
81 | train_writer = tf.summary.FileWriter(FLAGS.logdir + '/' + exp_string, sess.graph)
82 | print('Done initializing, starting training.')
83 | prelosses, postlosses = [], []
84 |
85 | num_classes = data_generator.num_classes # for classification, 1 otherwise
86 | multitask_weights, reg_weights = [], []
87 |
88 | for itr in range(resume_itr, FLAGS.pretrain_iterations + FLAGS.metatrain_iterations):
89 | feed_dict = {}
90 | if 'generate' in dir(data_generator):
91 | batch_x, batch_y, amp, phase = data_generator.generate()
92 |
93 | if FLAGS.baseline == 'oracle':
94 | batch_x = np.concatenate([batch_x, np.zeros([batch_x.shape[0], batch_x.shape[1], 2])], 2)
95 | for i in range(FLAGS.meta_batch_size):
96 | batch_x[i, :, 1] = amp[i]
97 | batch_x[i, :, 2] = phase[i]
98 |
99 | inputa = batch_x[:, :num_classes*FLAGS.update_batch_size, :]
100 | labela = batch_y[:, :num_classes*FLAGS.update_batch_size, :]
101 | inputb = batch_x[:, num_classes*FLAGS.update_batch_size:, :] # b used for testing
102 | labelb = batch_y[:, num_classes*FLAGS.update_batch_size:, :]
103 | feed_dict = {model.inputa: inputa, model.inputb: inputb, model.labela: labela, model.labelb: labelb}
104 |
105 | if itr < FLAGS.pretrain_iterations:
106 | input_tensors = [model.pretrain_op]
107 | else:
108 | input_tensors = [model.metatrain_op]
109 |
110 | if (itr % SUMMARY_INTERVAL == 0 or itr % PRINT_INTERVAL == 0):
111 | input_tensors.extend([model.summ_op, model.total_loss1, model.total_losses2[FLAGS.num_updates-1]])
112 | if model.classification:
113 | input_tensors.extend([model.total_accuracy1, model.total_accuracies2[FLAGS.num_updates-1]])
114 |
115 | result = sess.run(input_tensors, feed_dict)
116 |
117 | if itr % SUMMARY_INTERVAL == 0:
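            # For classification runs the accuracies were appended after the losses above, so
            # result[-2]/result[-1] (and the printout below) are pre/post-update accuracies rather than losses.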
118 | prelosses.append(result[-2])
119 | if FLAGS.log:
120 | train_writer.add_summary(result[1], itr)
121 | postlosses.append(result[-1])
122 |
123 | if (itr!=0) and itr % PRINT_INTERVAL == 0:
124 | if itr < FLAGS.pretrain_iterations:
125 | print_str = 'Pretrain Iteration ' + str(itr)
126 | else:
127 | print_str = 'Iteration ' + str(itr - FLAGS.pretrain_iterations)
128 | print_str += ': ' + str(np.mean(prelosses)) + ', ' + str(np.mean(postlosses))
129 | print(print_str)
130 | prelosses, postlosses = [], []
131 |
132 | if (itr!=0) and itr % SAVE_INTERVAL == 0:
133 | saver.save(sess, FLAGS.logdir + '/' + exp_string + '/model' + str(itr))
134 |
135 | # sinusoid is infinite data, so no need to test on meta-validation set.
136 | if (itr!=0) and itr % TEST_PRINT_INTERVAL == 0 and FLAGS.datasource !='sinusoid':
137 | if 'generate' not in dir(data_generator):
138 | feed_dict = {}
139 | if model.classification:
140 | input_tensors = [model.metaval_total_accuracy1, model.metaval_total_accuracies2[FLAGS.num_updates-1], model.summ_op]
141 | else:
142 | input_tensors = [model.metaval_total_loss1, model.metaval_total_losses2[FLAGS.num_updates-1], model.summ_op]
143 | else:
144 | batch_x, batch_y, amp, phase = data_generator.generate(train=False)
145 | inputa = batch_x[:, :num_classes*FLAGS.update_batch_size, :]
146 | inputb = batch_x[:, num_classes*FLAGS.update_batch_size:, :]
147 | labela = batch_y[:, :num_classes*FLAGS.update_batch_size, :]
148 | labelb = batch_y[:, num_classes*FLAGS.update_batch_size:, :]
149 | feed_dict = {model.inputa: inputa, model.inputb: inputb, model.labela: labela, model.labelb: labelb, model.meta_lr: 0.0}
150 | if model.classification:
151 | input_tensors = [model.total_accuracy1, model.total_accuracies2[FLAGS.num_updates-1]]
152 | else:
153 | input_tensors = [model.total_loss1, model.total_losses2[FLAGS.num_updates-1]]
154 |
155 | result = sess.run(input_tensors, feed_dict)
156 | print('Validation results: ' + str(result[0]) + ', ' + str(result[1]))
157 |
158 | saver.save(sess, FLAGS.logdir + '/' + exp_string + '/model' + str(itr))
159 |
160 | # calculated for omniglot
161 | NUM_TEST_POINTS = 600
162 |
163 | def test(model, saver, sess, exp_string, data_generator, test_num_updates=None):
164 | num_classes = data_generator.num_classes # for classification, 1 otherwise
165 |
166 | np.random.seed(1)
167 | random.seed(1)
168 |
169 | metaval_accuracies = []
170 |
171 | for _ in range(NUM_TEST_POINTS):
172 | if 'generate' not in dir(data_generator):
173 | feed_dict = {}
174 | feed_dict = {model.meta_lr : 0.0}
175 | else:
176 | batch_x, batch_y, amp, phase = data_generator.generate(train=False)
177 |
178 | if FLAGS.baseline == 'oracle': # NOTE - this flag is specific to sinusoid
179 | batch_x = np.concatenate([batch_x, np.zeros([batch_x.shape[0], batch_x.shape[1], 2])], 2)
180 | batch_x[0, :, 1] = amp[0]
181 | batch_x[0, :, 2] = phase[0]
182 |
183 | inputa = batch_x[:, :num_classes*FLAGS.update_batch_size, :]
184 | inputb = batch_x[:,num_classes*FLAGS.update_batch_size:, :]
185 | labela = batch_y[:, :num_classes*FLAGS.update_batch_size, :]
186 | labelb = batch_y[:,num_classes*FLAGS.update_batch_size:, :]
187 |
188 | feed_dict = {model.inputa: inputa, model.inputb: inputb, model.labela: labela, model.labelb: labelb, model.meta_lr: 0.0}
189 |
190 | if model.classification:
191 | result = sess.run([model.total_accuracy1] + model.total_accuracies2, feed_dict)
192 | else: # this is for sinusoid
193 | result = sess.run([model.total_loss1] + model.total_losses2, feed_dict)
194 | metaval_accuracies.append(result)
195 |
196 | metaval_accuracies = np.array(metaval_accuracies)
197 | means = np.mean(metaval_accuracies, 0)
198 | stds = np.std(metaval_accuracies, 0)
199 | ci95 = 1.96*stds/np.sqrt(NUM_TEST_POINTS)
200 |
201 | print('Mean validation accuracy/loss, stddev, and confidence intervals')
202 | print((means, stds, ci95))
203 |
204 | out_filename = FLAGS.logdir +'/'+ exp_string + '/' + 'test_ubs' + str(FLAGS.update_batch_size) + '_stepsize' + str(FLAGS.update_lr) + '.csv'
205 | out_pkl = FLAGS.logdir +'/'+ exp_string + '/' + 'test_ubs' + str(FLAGS.update_batch_size) + '_stepsize' + str(FLAGS.update_lr) + '.pkl'
206 |     with open(out_pkl, 'wb') as f:
207 |         pickle.dump({'mses': metaval_accuracies}, f)
208 | with open(out_filename, 'w') as f:
209 | writer = csv.writer(f, delimiter=',')
210 | writer.writerow(['update'+str(i) for i in range(len(means))])
211 | writer.writerow(means)
212 | writer.writerow(stds)
213 | writer.writerow(ci95)
214 |
215 | def main():
216 | if FLAGS.datasource == 'sinusoid':
217 | if FLAGS.train:
218 | test_num_updates = 5
219 | else:
220 | test_num_updates = 10
221 | else:
222 | if FLAGS.datasource == 'miniimagenet':
223 | if FLAGS.train == True:
224 | test_num_updates = 1 # eval on at least one update during training
225 | else:
226 | test_num_updates = 10
227 | else:
228 | test_num_updates = 10
229 |
230 | if FLAGS.train == False:
231 | orig_meta_batch_size = FLAGS.meta_batch_size
232 | # always use meta batch size of 1 when testing.
233 | FLAGS.meta_batch_size = 1
234 |
235 | if FLAGS.datasource == 'sinusoid':
236 | data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size)
237 | else:
238 | if FLAGS.metatrain_iterations == 0 and FLAGS.datasource == 'miniimagenet':
239 | assert FLAGS.meta_batch_size == 1
240 | assert FLAGS.update_batch_size == 1
241 | data_generator = DataGenerator(1, FLAGS.meta_batch_size) # only use one datapoint,
242 | else:
243 | if FLAGS.datasource == 'miniimagenet': # TODO - use 15 val examples for imagenet?
244 | if FLAGS.train:
245 |                     data_generator = DataGenerator(FLAGS.update_batch_size+15, FLAGS.meta_batch_size) # 15 additional examples per class are used for the meta-update (inputb)
246 | else:
247 | data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size) # only use one datapoint for testing to save memory
248 | else:
249 | data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size) # only use one datapoint for testing to save memory
250 |
251 |
252 | dim_output = data_generator.dim_output
253 | if FLAGS.baseline == 'oracle':
254 | assert FLAGS.datasource == 'sinusoid'
255 | dim_input = 3
256 | FLAGS.pretrain_iterations += FLAGS.metatrain_iterations
257 | FLAGS.metatrain_iterations = 0
258 | else:
259 | dim_input = data_generator.dim_input
260 |
261 | if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'omniglot':
262 | tf_data_load = True
263 | num_classes = data_generator.num_classes
264 |
265 | if FLAGS.train: # only construct training model if needed
266 | random.seed(5)
267 | image_tensor, label_tensor = data_generator.make_data_tensor()
268 | inputa = tf.slice(image_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1])
269 | inputb = tf.slice(image_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1])
270 | labela = tf.slice(label_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1])
271 | labelb = tf.slice(label_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1])
272 | input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb}
273 |
274 | random.seed(6)
275 | image_tensor, label_tensor = data_generator.make_data_tensor(train=False)
276 | inputa = tf.slice(image_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1])
277 | inputb = tf.slice(image_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1])
278 | labela = tf.slice(label_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1])
279 | labelb = tf.slice(label_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1])
280 | metaval_input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb}
281 | else:
282 | tf_data_load = False
283 | input_tensors = None
284 |
285 | model = MAML(dim_input, dim_output, test_num_updates=test_num_updates)
286 | if FLAGS.train or not tf_data_load:
287 | model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
288 | if tf_data_load:
289 | model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_')
290 | model.summ_op = tf.summary.merge_all()
291 |
292 | saver = loader = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10)
293 |
294 | sess = tf.InteractiveSession()
295 |
296 | if FLAGS.train == False:
297 | # change to original meta batch size when loading model.
298 | FLAGS.meta_batch_size = orig_meta_batch_size
299 |
300 | if FLAGS.train_update_batch_size == -1:
301 | FLAGS.train_update_batch_size = FLAGS.update_batch_size
302 | if FLAGS.train_update_lr == -1:
303 | FLAGS.train_update_lr = FLAGS.update_lr
304 |
305 | exp_string = 'cls_'+str(FLAGS.num_classes)+'.mbs_'+str(FLAGS.meta_batch_size) + '.ubs_' + str(FLAGS.train_update_batch_size) + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr)
306 |
307 | if FLAGS.num_filters != 64:
308 | exp_string += 'hidden' + str(FLAGS.num_filters)
309 | if FLAGS.max_pool:
310 | exp_string += 'maxpool'
311 | if FLAGS.stop_grad:
312 | exp_string += 'stopgrad'
313 | if FLAGS.baseline:
314 | exp_string += FLAGS.baseline
315 | if FLAGS.norm == 'batch_norm':
316 | exp_string += 'batchnorm'
317 | elif FLAGS.norm == 'layer_norm':
318 | exp_string += 'layernorm'
319 | elif FLAGS.norm == 'None':
320 | exp_string += 'nonorm'
321 | else:
322 | print('Norm setting not recognized.')
323 |
324 | resume_itr = 0
325 | model_file = None
326 |
327 | tf.global_variables_initializer().run()
328 | tf.train.start_queue_runners()
329 |
330 | if FLAGS.resume or not FLAGS.train:
331 | model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string)
332 | if FLAGS.test_iter > 0:
333 | model_file = model_file[:model_file.index('model')] + 'model' + str(FLAGS.test_iter)
334 | if model_file:
335 | ind1 = model_file.index('model')
336 | resume_itr = int(model_file[ind1+5:])
337 | print("Restoring model weights from " + model_file)
338 | saver.restore(sess, model_file)
339 |
340 | if FLAGS.train:
341 | train(model, saver, sess, exp_string, data_generator, resume_itr)
342 | else:
343 | test(model, saver, sess, exp_string, data_generator, test_num_updates)
344 |
345 | if __name__ == "__main__":
346 | main()
347 |
--------------------------------------------------------------------------------
/maml.py:
--------------------------------------------------------------------------------
1 | """ Code for the MAML algorithm and network definitions. """
2 | import numpy as np
3 | # import special_grads
4 | import tensorflow as tf
5 |
6 | from tensorflow.python.platform import flags
7 | from utils import mse, xent, conv_block, normalize
8 |
9 | FLAGS = flags.FLAGS
10 |
11 | class MAML:
12 | def __init__(self, dim_input=1, dim_output=1, test_num_updates=5):
13 | """ must call construct_model() after initializing MAML! """
14 | self.dim_input = dim_input
15 | self.dim_output = dim_output
16 | # self.update_lr = FLAGS.update_lr
17 | self.meta_lr = tf.placeholder_with_default(FLAGS.meta_lr, ())
18 | self.classification = False
19 | self.test_num_updates = test_num_updates
20 | if FLAGS.datasource == 'sinusoid':
21 | self.dim_hidden = [40, 40]
22 | self.loss_func = mse
23 | self.forward = self.forward_fc
24 | self.construct_weights = self.construct_fc_weights
25 | elif FLAGS.datasource == 'omniglot' or FLAGS.datasource == 'miniimagenet':
26 | self.loss_func = xent
27 | self.classification = True
28 | if FLAGS.conv:
29 | self.dim_hidden = FLAGS.num_filters
30 | self.forward = self.forward_conv
31 | self.construct_weights = self.construct_conv_weights
32 | else:
33 | self.dim_hidden = [256, 128, 64, 64]
34 | self.forward=self.forward_fc
35 | self.construct_weights = self.construct_fc_weights
36 | if FLAGS.datasource == 'miniimagenet':
37 | self.channels = 3
38 | else:
39 | self.channels = 1
40 | self.img_size = int(np.sqrt(self.dim_input/self.channels))
41 | else:
42 | raise ValueError('Unrecognized data source.')
43 |
44 | def construct_model(self, input_tensors=None, prefix='metatrain_'):
45 | # a: training data for inner gradient, b: test data for meta gradient
46 | if input_tensors is None:
47 | self.inputa = tf.placeholder(tf.float32)
48 | self.inputb = tf.placeholder(tf.float32)
49 | self.labela = tf.placeholder(tf.float32)
50 | self.labelb = tf.placeholder(tf.float32)
51 | else:
52 | self.inputa = input_tensors['inputa']
53 | self.inputb = input_tensors['inputb']
54 | self.labela = input_tensors['labela']
55 | self.labelb = input_tensors['labelb']
56 |
57 | with tf.variable_scope('model', reuse=None) as training_scope:
58 | if 'weights' in dir(self):
59 | training_scope.reuse_variables()
60 | weights = self.weights
61 | else:
62 | # Define the weights
63 | self.weights = weights = self.construct_weights()
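            # Meta-SGD modification: the inner-loop step size is a trainable tf.Variable, so the
            # Adam meta-optimizer below learns it together with the weights (plain MAML instead
            # uses the fixed FLAGS.update_lr). Note this is a single scalar, whereas the paper
            # learns a per-parameter step-size vector.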
64 |             self.update_lr = tf.Variable(0.001, name="updatelr")
65 |
66 |
67 | # outputbs[i] and lossesb[i] is the output and loss after i+1 gradient updates
68 | lossesa, outputas, lossesb, outputbs = [], [], [], []
69 | accuraciesa, accuraciesb = [], []
70 | num_updates = max(self.test_num_updates, FLAGS.num_updates)
71 | outputbs = [[]]*num_updates
72 | lossesb = [[]]*num_updates
73 | accuraciesb = [[]]*num_updates
74 |
75 | def task_metalearn(inp, reuse=True):
76 | """ Perform gradient descent for one task in the meta-batch. """
77 | inputa, inputb, labela, labelb = inp
78 | task_outputbs, task_lossesb = [], []
79 |
80 | if self.classification:
81 | task_accuraciesb = []
82 |
83 | task_outputa = self.forward(inputa, weights, reuse=reuse) # only reuse on the first iter
84 | task_lossa = self.loss_func(task_outputa, labela)
85 |
86 | grads = tf.gradients(task_lossa, list(weights.values()))
87 | if FLAGS.stop_grad:
88 | grads = [tf.stop_gradient(grad) for grad in grads]
89 | gradients = dict(zip(weights.keys(), grads))
90 | fast_weights = dict(zip(weights.keys(), [weights[key] - self.update_lr*gradients[key] for key in weights.keys()]))
91 | output = self.forward(inputb, fast_weights, reuse=True)
92 | task_outputbs.append(output)
93 | task_lossesb.append(self.loss_func(output, labelb))
94 |
95 | for j in range(num_updates - 1):
96 | loss = self.loss_func(self.forward(inputa, fast_weights, reuse=True), labela)
97 | grads = tf.gradients(loss, list(fast_weights.values()))
98 | if FLAGS.stop_grad:
99 | grads = [tf.stop_gradient(grad) for grad in grads]
100 | gradients = dict(zip(fast_weights.keys(), grads))
101 | fast_weights = dict(zip(fast_weights.keys(), [fast_weights[key] - self.update_lr*gradients[key] for key in fast_weights.keys()]))
102 | output = self.forward(inputb, fast_weights, reuse=True)
103 | task_outputbs.append(output)
104 | task_lossesb.append(self.loss_func(output, labelb))
105 |
106 | task_output = [task_outputa, task_outputbs, task_lossa, task_lossesb]
107 |
108 | if self.classification:
109 | task_accuracya = tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(task_outputa), 1), tf.argmax(labela, 1))
110 | for j in range(num_updates):
111 | task_accuraciesb.append(tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(task_outputbs[j]), 1), tf.argmax(labelb, 1)))
112 | task_output.extend([task_accuracya, task_accuraciesb])
113 |
114 | return task_output
115 |
116 |         if FLAGS.norm != 'None':
117 | # to initialize the batch norm vars, might want to combine this, and not run idx 0 twice.
118 | unused = task_metalearn((self.inputa[0], self.inputb[0], self.labela[0], self.labelb[0]), False)
119 |
120 | out_dtype = [tf.float32, [tf.float32]*num_updates, tf.float32, [tf.float32]*num_updates]
121 | if self.classification:
122 | out_dtype.extend([tf.float32, [tf.float32]*num_updates])
123 | result = tf.map_fn(task_metalearn, elems=(self.inputa, self.inputb, self.labela, self.labelb), dtype=out_dtype, parallel_iterations=FLAGS.meta_batch_size)
124 | if self.classification:
125 | outputas, outputbs, lossesa, lossesb, accuraciesa, accuraciesb = result
126 | else:
127 | outputas, outputbs, lossesa, lossesb = result
128 |
129 | ## Performance & Optimization
130 | if 'train' in prefix:
131 | self.total_loss1 = total_loss1 = tf.reduce_sum(lossesa) / tf.to_float(FLAGS.meta_batch_size)
132 | self.total_losses2 = total_losses2 = [tf.reduce_sum(lossesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)]
133 | # after the map_fn
134 | self.outputas, self.outputbs = outputas, outputbs
135 | if self.classification:
136 | self.total_accuracy1 = total_accuracy1 = tf.reduce_sum(accuraciesa) / tf.to_float(FLAGS.meta_batch_size)
137 | self.total_accuracies2 = total_accuracies2 = [tf.reduce_sum(accuraciesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)]
138 | self.pretrain_op = tf.train.AdamOptimizer(self.meta_lr).minimize(total_loss1)
139 |
140 | if FLAGS.metatrain_iterations > 0:
141 | optimizer = tf.train.AdamOptimizer(self.meta_lr)
142 | self.gvs = gvs = optimizer.compute_gradients(self.total_losses2[FLAGS.num_updates-1])
143 | if FLAGS.datasource == 'miniimagenet':
144 | gvs = [(tf.clip_by_value(grad, -10, 10), var) for grad, var in gvs]
145 | self.metatrain_op = optimizer.apply_gradients(gvs)
146 | else:
147 | self.metaval_total_loss1 = total_loss1 = tf.reduce_sum(lossesa) / tf.to_float(FLAGS.meta_batch_size)
148 | self.metaval_total_losses2 = total_losses2 = [tf.reduce_sum(lossesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)]
149 | if self.classification:
150 | self.metaval_total_accuracy1 = total_accuracy1 = tf.reduce_sum(accuraciesa) / tf.to_float(FLAGS.meta_batch_size)
151 | self.total_accuracy1 = self.metaval_total_accuracy1
152 | self.metaval_total_accuracies2 = total_accuracies2 =[tf.reduce_sum(accuraciesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)]
153 | self.total_accuracies2 = self.metaval_total_accuracies2
154 |
155 | ## Summaries
156 | tf.summary.scalar(prefix+'Pre-update loss', total_loss1)
157 | if self.classification:
158 | tf.summary.scalar(prefix+'Pre-update accuracy', total_accuracy1)
159 |
160 | for j in range(num_updates):
161 | tf.summary.scalar(prefix+'Post-update loss, step ' + str(j+1), total_losses2[j])
162 | if self.classification:
163 | tf.summary.scalar(prefix+'Post-update accuracy, step ' + str(j+1), total_accuracies2[j])
164 |
165 | ### Network construction functions (fc networks and conv networks)
166 | def construct_fc_weights(self):
167 | weights = {}
168 | weights['w1'] = tf.Variable(tf.truncated_normal([self.dim_input, self.dim_hidden[0]], stddev=0.01))
169 | weights['b1'] = tf.Variable(tf.zeros([self.dim_hidden[0]]))
170 | for i in range(1,len(self.dim_hidden)):
171 | weights['w'+str(i+1)] = tf.Variable(tf.truncated_normal([self.dim_hidden[i-1], self.dim_hidden[i]], stddev=0.01))
172 | weights['b'+str(i+1)] = tf.Variable(tf.zeros([self.dim_hidden[i]]))
173 | weights['w'+str(len(self.dim_hidden)+1)] = tf.Variable(tf.truncated_normal([self.dim_hidden[-1], self.dim_output], stddev=0.01))
174 | weights['b'+str(len(self.dim_hidden)+1)] = tf.Variable(tf.zeros([self.dim_output]))
175 | return weights
176 |
177 | def forward_fc(self, inp, weights, reuse=False):
178 | hidden = normalize(tf.matmul(inp, weights['w1']) + weights['b1'], activation=tf.nn.relu, reuse=reuse, scope='0')
179 | for i in range(1,len(self.dim_hidden)):
180 | hidden = normalize(tf.matmul(hidden, weights['w'+str(i+1)]) + weights['b'+str(i+1)], activation=tf.nn.relu, reuse=reuse, scope=str(i+1))
181 | return tf.matmul(hidden, weights['w'+str(len(self.dim_hidden)+1)]) + weights['b'+str(len(self.dim_hidden)+1)]
182 |
183 | def construct_conv_weights(self):
184 | weights = {}
185 |
186 | dtype = tf.float32
187 | conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype)
188 | fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype)
189 | k = 3
190 |
191 | weights['conv1'] = tf.get_variable('conv1', [k, k, self.channels, self.dim_hidden], initializer=conv_initializer, dtype=dtype)
192 | weights['b1'] = tf.Variable(tf.zeros([self.dim_hidden]))
193 | weights['conv2'] = tf.get_variable('conv2', [k, k, self.dim_hidden, self.dim_hidden], initializer=conv_initializer, dtype=dtype)
194 | weights['b2'] = tf.Variable(tf.zeros([self.dim_hidden]))
195 | weights['conv3'] = tf.get_variable('conv3', [k, k, self.dim_hidden, self.dim_hidden], initializer=conv_initializer, dtype=dtype)
196 | weights['b3'] = tf.Variable(tf.zeros([self.dim_hidden]))
197 | weights['conv4'] = tf.get_variable('conv4', [k, k, self.dim_hidden, self.dim_hidden], initializer=conv_initializer, dtype=dtype)
198 | weights['b4'] = tf.Variable(tf.zeros([self.dim_hidden]))
199 | if FLAGS.datasource == 'miniimagenet':
200 | # assumes max pooling
201 | weights['w5'] = tf.get_variable('w5', [self.dim_hidden*5*5, self.dim_output], initializer=fc_initializer)
202 | weights['b5'] = tf.Variable(tf.zeros([self.dim_output]), name='b5')
203 | else:
204 | weights['w5'] = tf.Variable(tf.random_normal([self.dim_hidden, self.dim_output]), name='w5')
205 | weights['b5'] = tf.Variable(tf.zeros([self.dim_output]), name='b5')
206 | return weights
207 |
208 | def forward_conv(self, inp, weights, reuse=False, scope=''):
209 | # reuse is for the normalization parameters.
210 | channels = self.channels
211 | inp = tf.reshape(inp, [-1, self.img_size, self.img_size, channels])
212 |
213 | hidden1 = conv_block(inp, weights['conv1'], weights['b1'], reuse, scope+'0')
214 | hidden2 = conv_block(hidden1, weights['conv2'], weights['b2'], reuse, scope+'1')
215 | hidden3 = conv_block(hidden2, weights['conv3'], weights['b3'], reuse, scope+'2')
216 | hidden4 = conv_block(hidden3, weights['conv4'], weights['b4'], reuse, scope+'3')
217 | if FLAGS.datasource == 'miniimagenet':
218 | # last hidden layer is 6x6x64-ish, reshape to a vector
219 | hidden4 = tf.reshape(hidden4, [-1, np.prod([int(dim) for dim in hidden4.get_shape()[1:]])])
220 | else:
221 | hidden4 = tf.reduce_mean(hidden4, [1, 2])
222 |
223 | return tf.matmul(hidden4, weights['w5']) + weights['b5']
224 |
225 |
226 |
--------------------------------------------------------------------------------
/readme/metatrain_Postupdate_accuracy__step_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foolyc/Meta-SGD/4922a8dab9bf6368654f174b9d3976dc77627012/readme/metatrain_Postupdate_accuracy__step_1.png
--------------------------------------------------------------------------------
/readme/metatrain_Postupdate_loss__step_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foolyc/Meta-SGD/4922a8dab9bf6368654f174b9d3976dc77627012/readme/metatrain_Postupdate_loss__step_1.png
--------------------------------------------------------------------------------
/readme/metaval_Postupdate_accuracy__step_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foolyc/Meta-SGD/4922a8dab9bf6368654f174b9d3976dc77627012/readme/metaval_Postupdate_accuracy__step_1.png
--------------------------------------------------------------------------------
/readme/metaval_Postupdate_accuracy__step_1_time.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foolyc/Meta-SGD/4922a8dab9bf6368654f174b9d3976dc77627012/readme/metaval_Postupdate_accuracy__step_1_time.png
--------------------------------------------------------------------------------
/readme/metaval_Postupdate_loss__step_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foolyc/Meta-SGD/4922a8dab9bf6368654f174b9d3976dc77627012/readme/metaval_Postupdate_loss__step_1.png
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | """ Utility functions. """
2 | import numpy as np
3 | import os
4 | import random
5 | import tensorflow as tf
6 |
7 | from tensorflow.contrib.layers.python import layers as tf_layers
8 | from tensorflow.python.platform import flags
9 |
10 | FLAGS = flags.FLAGS
11 |
12 | ## Image helper
13 | def get_images(paths, labels, nb_samples=None, shuffle=True):
14 | if nb_samples is not None:
15 | sampler = lambda x: random.sample(x, nb_samples)
16 | else:
17 | sampler = lambda x: x
18 | images = [(i, os.path.join(path, image)) \
19 | for i, path in zip(labels, paths) \
20 | for image in sampler(os.listdir(path))]
21 | if shuffle:
22 | random.shuffle(images)
23 | return images
24 |
25 | ## Network helpers
26 | def conv_block(inp, cweight, bweight, reuse, scope, activation=tf.nn.relu, max_pool_pad='VALID', residual=False):
27 |     """ Perform conv, batch norm, nonlinearity, and max pool """
28 | stride, no_stride = [1,2,2,1], [1,1,1,1]
29 |
30 | if FLAGS.max_pool:
31 | conv_output = tf.nn.conv2d(inp, cweight, no_stride, 'SAME') + bweight
32 | else:
33 | conv_output = tf.nn.conv2d(inp, cweight, stride, 'SAME') + bweight
34 | normed = normalize(conv_output, activation, reuse, scope)
35 | if FLAGS.max_pool:
36 | normed = tf.nn.max_pool(normed, stride, stride, max_pool_pad)
37 | return normed
38 |
39 | def normalize(inp, activation, reuse, scope):
40 | if FLAGS.norm == 'batch_norm':
41 | return tf_layers.batch_norm(inp, activation_fn=activation, reuse=reuse, scope=scope)
42 | elif FLAGS.norm == 'layer_norm':
43 | return tf_layers.layer_norm(inp, activation_fn=activation, reuse=reuse, scope=scope)
44 | elif FLAGS.norm == 'None':
45 | return activation(inp)
46 |
47 | ## Loss functions
48 | def mse(pred, label):
49 | pred = tf.reshape(pred, [-1])
50 | label = tf.reshape(label, [-1])
51 | return tf.reduce_mean(tf.square(pred-label))
52 |
53 | def xent(pred, label):
54 | # Note - with tf version <=0.12, this loss has incorrect 2nd derivatives
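    # Dividing by update_batch_size (K) normalizes the per-example losses, since callers in maml.py sum rather than average them.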
55 | return tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=label) / FLAGS.update_batch_size
56 |
57 |
58 |
--------------------------------------------------------------------------------