├── AutoMLDemos ├── demo1.py └── demo2.py ├── README.md ├── main.py ├── models.py └── plot-helpers.py /AutoMLDemos/demo1.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import tensorflow as tf 4 | import tensorflow.contrib.layers as layers 5 | import far_ho as far 6 | import far_ho.examples as far_ex 7 | 8 | sess = tf.InteractiveSession() 9 | 10 | 11 | def get_data(): 12 | # load a small portion of mnist data 13 | datasets = far_ex.mnist(folder=os.path.join(os.getcwd(), 'MNIST_DATA'), partitions=(.1, .1,)) 14 | return datasets.train, datasets.validation 15 | 16 | 17 | def g_logits(x,y): 18 | with tf.variable_scope('model'): 19 | h1 = layers.fully_connected(x, 300) 20 | logits = layers.fully_connected(h1, int(y.shape[1])) 21 | return logits 22 | 23 | 24 | x = tf.placeholder(tf.float32, shape=(None, 28**2), name='x') 25 | y = tf.placeholder(tf.float32, shape=(None, 10), name='y') 26 | logits = g_logits(x,y) 27 | train_set, validation_set = get_data() 28 | 29 | lambdas = far.get_hyperparameter('lambdas', tf.zeros(train_set.num_examples)) 30 | lr = far.get_hyperparameter('lr', initializer=0.01) 31 | 32 | ce = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits) 33 | L = tf.reduce_mean(tf.sigmoid(lambdas)*ce) 34 | E = tf.reduce_mean(ce) 35 | 36 | inner_optimizer = far.GradientDescentOptimizer(lr) 37 | outer_optimizer = tf.train.AdamOptimizer() 38 | hyper_step = far.HyperOptimizer().minimize(E, outer_optimizer, L, inner_optimizer) 39 | 40 | T = 200 # Number of inner iterations 41 | train_set_supplier = train_set.create_supplier(x, y) 42 | validation_set_supplier = validation_set.create_supplier(x, y) 43 | tf.global_variables_initializer().run() 44 | 45 | print('inner:', L.eval(train_set_supplier())) 46 | print('outer:', E.eval(validation_set_supplier())) 47 | # print('-'*50) 48 | n_hyper_iterations = 10 49 | for _ in range(n_hyper_iterations): 50 | hyper_step(T, 51 | inner_objective_feed_dicts=train_set_supplier, 52 | outer_objective_feed_dicts=validation_set_supplier) 53 | print('inner:', L.eval(train_set_supplier())) 54 | print('outer:', E.eval(validation_set_supplier())) 55 | print('learning rate', lr.eval()) 56 | print('norm of examples weight', tf.norm(lambdas).eval()) 57 | print('-'*50) 58 | 59 | # # plt.plot(tr_accs, label='training accuracy') 60 | # # plt.plot(val_accs, label='validation accuracy') 61 | # # plt.legend(loc=0, frameon=True) 62 | # # # plt.xlim(0, 19) 63 | -------------------------------------------------------------------------------- /AutoMLDemos/demo2.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, print_function, division 2 | from functools import reduce 3 | 4 | import tensorflow as tf 5 | import tensorflow.contrib.layers as tcl 6 | import far_ho as far 7 | from collections import defaultdict 8 | 9 | 10 | def hyper_conv_layer(x): 11 | hyper_coll = far.HYPERPARAMETERS_COLLECTIONS 12 | return tcl.conv2d(x, num_outputs=64, stride=2, 13 | kernel_size=3, 14 | # normalizer_fn= 15 | # lambda z: tcl.batch_norm(z, 16 | # variables_collections=hyper_coll, 17 | # trainable=False), 18 | trainable=False, 19 | variables_collections=hyper_coll) 20 | 21 | 22 | def build_hyper_representation(_x, auto_reuse=False): 23 | reuse = tf.AUTO_REUSE if auto_reuse else False 24 | with tf.variable_scope('HR', reuse=reuse): 25 | conv_out = reduce(lambda lp, k: hyper_conv_layer(lp), 26 | range(4), _x) 27 | return 
tf.reshape(conv_out, shape=(-1, 256)) 28 | 29 | 30 | def classifier(_x, _y): 31 | return tcl.fully_connected( 32 | _x, int(_y.shape[1]), activation_fn=None, 33 | weights_initializer=tf.zeros_initializer) 34 | 35 | 36 | def get_placeholders(): 37 | _x = tf.placeholder(tf.float32, (None, 28, 28, 1)) 38 | _y = tf.placeholder(tf.float32, (None, 5)) 39 | return _x, _y 40 | 41 | 42 | def get_data(): 43 | import experiment_manager.datasets.load as load 44 | return load.meta_omniglot( 45 | std_num_classes=5, 46 | std_num_examples=(5, 15*5)) 47 | 48 | 49 | def make_feed_dicts(tasks, mbd): 50 | train_fd, test_fd = {}, {} 51 | for task, _x, _y in zip(tasks, mbd['x'], mbd['y']): 52 | train_fd[_x] = task.train.data 53 | train_fd[_y] = task.train.target 54 | test_fd[_x] = task.test.data 55 | test_fd[_y] = task.test.target 56 | return train_fd, test_fd 57 | 58 | 59 | def accuracy(y_true, logits): 60 | return tf.reduce_mean(tf.cast( 61 | tf.equal(tf.argmax(y_true, 1), tf.argmax(logits, 1)), 62 | tf.float32)) 63 | 64 | 65 | def meta_test(meta_batches, mbd, opt, n_steps): 66 | ss = tf.get_default_session() 67 | ss.run(tf.variables_initializer(tf.trainable_variables())) 68 | sum_loss, sum_acc = 0., 0. 69 | n_tasks = len(mbd['err'])*len(meta_batches) 70 | for _tasks in meta_batches: 71 | _train_fd, _valid_fd = make_feed_dicts(_tasks, mbd) 72 | mb_err = tf.add_n(mbd['err']) 73 | mb_acc = tf.add_n(mbd['acc']) 74 | opt_step = opt.minimize(mb_err) 75 | for i in range(n_steps): 76 | ss.run(opt_step, feed_dict=_train_fd) 77 | 78 | mb_loss, mb_acc = ss.run([mb_err, mb_acc], feed_dict=_valid_fd) 79 | sum_loss += mb_loss 80 | sum_acc += mb_acc 81 | 82 | return sum_loss/n_tasks, sum_acc/n_tasks 83 | 84 | 85 | meta_batch_size = 16 # meta-batch size 86 | n_episodes_testing = 10 87 | mb_dict = defaultdict(list) # meta_batch dictionary 88 | meta_dataset = get_data() 89 | 90 | for _ in range(meta_batch_size): 91 | x, y = get_placeholders() 92 | mb_dict['x'].append(x) 93 | mb_dict['y'].append(y) 94 | hyper_repr = build_hyper_representation(x, auto_reuse=True) 95 | logits = classifier(hyper_repr, y) 96 | ce = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits( 97 | labels=y, logits=logits)) 98 | mb_dict['err'].append(ce) 99 | mb_dict['acc'].append(accuracy(y, logits)) 100 | 101 | L = tf.add_n(mb_dict['err']) 102 | E = L / meta_batch_size 103 | mean_acc = tf.add_n(mb_dict['acc'])/meta_batch_size 104 | 105 | inner_opt = far.GradientDescentOptimizer(learning_rate=0.1) 106 | outer_opt = tf.train.AdamOptimizer() 107 | 108 | hyper_step = far.HyperOptimizer().minimize( 109 | E, outer_opt, L, inner_opt) 110 | 111 | sess = tf.Session() 112 | n_hyper_steps = 100 113 | with sess.as_default(): 114 | tf.global_variables_initializer().run() 115 | T = 3 116 | for meta_batch in meta_dataset.train.generate(n_hyper_steps, batch_size=meta_batch_size): 117 | train_fd, valid_fd = make_feed_dicts(meta_batch, mb_dict) 118 | hyper_step(T, train_fd, valid_fd) 119 | 120 | test_optim = tf.train.GradientDescentOptimizer(0.1) 121 | test_mbs = [mb for mb in meta_dataset.test.generate(n_episodes_testing, batch_size=meta_batch_size, rand=0)] 122 | 123 | print('train_test (loss, acc)', sess.run([E, mean_acc], feed_dict=valid_fd)) 124 | print('test_test (loss, acc)', meta_test(test_mbs, mb_dict, test_optim, T)) 125 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Hyper-representation 2 | This is the official repo for the 
experiments in the paper "Bilevel Programming for Hyperparameter Optimization and Meta-Learning".
3 | 
4 | __WORK IN PROGRESS__
5 | 
6 | ## Prerequisites
7 | You need to install [TensorFlow](https://www.tensorflow.org/install/) >= 1.2.1,
8 | the package [Far-HO](https://github.com/lucfra/FAR-HO) for hyperparameter optimization,
9 | and [this experiment manager package](https://github.com/lucfra/ExperimentManager).
10 | 
11 | To make sure that the Far-HO version matches the one used to run these experiments, install Far-HO from the branch [final_ICML18](https://github.com/lucfra/FAR-HO/tree/final_ICML18).
12 | 
13 | ## Results Replication
14 | To replicate the experiments, run `python main.py` from the command line with the appropriate arguments
15 | described in the `main.py` file (for example, `python main.py -m train -d omniglot` meta-trains on Omniglot with the default settings).
16 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | import matplotlib
4 | matplotlib.use('Agg')
5 | 
6 | import tensorflow as tf
7 | import far_ho as far
8 | import experiment_manager as em
9 | import numpy as np
10 | import inspect, os, time
11 | #from hr_resnet import hr_res_net_tcml_v1_builder, hr_res_net_tcml_Omniglot_builder
12 | from shutil import copyfile
13 | from models import hr_res_net_tcml_v1_builder
14 | from threading import Thread
15 | import pickle
16 | 
17 | from tensorflow.python.platform import flags
18 | from far_ho.examples.hyper_representation import omniglot_model
19 | 
20 | import seaborn
21 | seaborn.set_style('whitegrid', {'figure.figsize': (30, 20)})
22 | 
23 | em.DATASET_FOLDER = 'datasets'
24 | 
25 | parser = argparse.ArgumentParser()
26 | 
27 | parser.add_argument('-m', '--mode', type=str, default="train", metavar='STRING',
28 |                     help='mode, can be train or test')
29 | 
30 | # GPU options
31 | parser.add_argument('-vg', '--visible-gpus', type=str, default="1", metavar='STRING',
32 |                     help="gpus that tensorflow will see")
33 | 
34 | # Dataset/method options
35 | parser.add_argument('-d', '--dataset', type=str, default='miniimagenet', metavar='STRING',
36 |                     help='omniglot or miniimagenet.')
37 | parser.add_argument('-nc', '--classes', type=int, default=5, metavar='NUMBER',
38 |                     help='number of classes used in classification (c for c-way classification).')
39 | parser.add_argument('-etr', '--examples_train', type=int, default=1, metavar='NUMBER',
40 |                     help='number of examples used for inner gradient update (k for k-shot learning).')
41 | parser.add_argument('-etes', '--examples_test', type=int, default=15, metavar='NUMBER',
42 |                     help='number of examples used for test sets')
43 | 
44 | # Training options
45 | parser.add_argument('-s', '--seed', type=int, default=0, metavar='NUMBER',
46 |                     help='seed for random number generators')
47 | parser.add_argument('-mbs', '--meta_batch_size', type=int, default=2, metavar='NUMBER',
48 |                     help='number of tasks sampled per meta-update')
49 | parser.add_argument('-nmi', '--n_meta_iterations', type=int, default=50000, metavar='NUMBER',
50 |                     help='number of metatraining iterations.')
51 | parser.add_argument('-T', '--T', type=int, default=5, metavar='NUMBER',
52 |                     help='number of inner updates during training.')
53 | parser.add_argument('-xi', '--xavier', type=bool, default=False, metavar='BOOLEAN',
54 |                     help='FFNN weights initializer')
55 | parser.add_argument('-bn', '--batch-norm', type=bool, default=False, metavar='BOOLEAN',
56 |                     help='Use batch normalization before classifier')
57 | parser.add_argument('-mlr', '--meta-lr', 
type=float, default=0.3, metavar='NUMBER', 58 | help='starting meta learning rate') 59 | parser.add_argument('-mlrdr', '--meta-lr-decay-rate', type=float, default=1.e-5, metavar='NUMBER', 60 | help='meta lr inverse time decay rate') 61 | parser.add_argument('-cv', '--clip-value', type=float, default=0., metavar='NUMBER', 62 | help='meta gradient clip value (0. for no clipping)') 63 | parser.add_argument('-lr', '--lr', type=float, default=0.4, metavar='NUMBER', 64 | help='starting learning rate') 65 | parser.add_argument('-lrl', '--learn-lr', type=bool, default=False, metavar='BOOLEAN', 66 | help='True if learning rate is an hyperparameter') 67 | 68 | # Logging, saving, and testing options 69 | parser.add_argument('-log', '--log', type=bool, default=False, metavar='BOOLEAN', 70 | help='if false, do not log summaries, for debugging code.') 71 | parser.add_argument('-ld', '--logdir', type=str, default='logs/', metavar='STRING', 72 | help='directory for summaries and checkpoints.') 73 | parser.add_argument('-res', '--resume', type=bool, default=True, metavar='BOOLEAN', 74 | help='resume training if there is a model available') 75 | parser.add_argument('-pi', '--print-interval', type=int, default=1, metavar='NUMBER', 76 | help='number of meta-train iterations before print') 77 | parser.add_argument('-si', '--save_interval', type=int, default=1, metavar='NUMBER', 78 | help='number of meta-train iterations before save') 79 | parser.add_argument('-te', '--test_episodes', type=int, default=600, metavar='NUMBER', 80 | help='number of episodes for testing') 81 | 82 | # Testing options (put parser.mode = 'test') 83 | parser.add_argument('-exd', '--exp-dir', type=str, default=None, metavar='STRING', 84 | help='directory of the experiment model files') 85 | parser.add_argument('-itt', '--iterations_to_test', type=str, default=[40000], metavar='STRING', 86 | help='meta_iteration to test (model file must be in "exp_dir")') 87 | 88 | args = parser.parse_args() 89 | 90 | available_devices = ('/gpu:0', '/gpu:1') 91 | os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_gpus 92 | 93 | exp_string = str(args.classes) + 'way_' + str(args.examples_train) + 'shot_' + str(args.meta_batch_size) + 'mbs' \ 94 | + str(args.T) + 'T' + str(args.clip_value) + 'cv' + str(args.meta_lr) + 'mlr' + str(args.lr)\ 95 | + str(args.learn_lr) + 'lr' 96 | 97 | dataset_load_dict = {'omniglot': em.load.meta_omniglot, 'miniimagenet': em.load.meta_mini_imagenet} 98 | model_dict = {'omniglot': omniglot_model, 'miniimagenet': hr_res_net_tcml_v1_builder()} 99 | 100 | 101 | def batch_producer(metadataset, batch_queue, n_batches, batch_size, rand=0): 102 | while True: 103 | batch_queue.put([d for d in metadataset.generate(n_batches, batch_size, rand)]) 104 | 105 | 106 | def start_batch_makers(number_of_workers, metadataset, batch_queue, n_batches, batch_size, rand=0): 107 | for w in range(number_of_workers): 108 | worker = Thread(target=batch_producer, args=(metadataset, batch_queue, n_batches, batch_size, rand)) 109 | worker.setDaemon(True) 110 | worker.start() 111 | 112 | 113 | # Class for debugging purposes for multi-thread issues (used now because it resolves rand issues) 114 | class BatchQueueMock: 115 | def __init__(self, metadataset, n_batches, batch_size, rand): 116 | self.metadataset = metadataset 117 | self.n_batches = n_batches 118 | self.batch_size = batch_size 119 | self.rand = rand 120 | 121 | def get(self): 122 | return [d for d in self.metadataset.generate(self.n_batches, self.batch_size, self.rand)] 123 | 124 | 125 | def 
save_obj(file_path, obj): 126 | with open(file_path, 'wb') as handle: 127 | pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL) 128 | 129 | 130 | def load_obj(file_path): 131 | with open(file_path, 'rb') as handle: 132 | b = pickle.load(handle) 133 | return b 134 | 135 | 136 | ''' Useful Functions ''' 137 | 138 | 139 | def feed_dicts(dat_lst, exs): 140 | dat_lst = em.as_list(dat_lst) 141 | train_fd = em.utils.merge_dicts( 142 | *[{_ex.x: dat.train.data, _ex.y: dat.train.target} 143 | for _ex, dat in zip(exs, dat_lst)]) 144 | valid_fd = em.utils.merge_dicts( 145 | *[{_ex.x: dat.test.data, _ex.y: dat.test.target} 146 | for _ex, dat in zip(exs, dat_lst)]) 147 | 148 | return train_fd, valid_fd 149 | 150 | 151 | def just_train_on_dataset(dat, exs, far_ho, sess, T): 152 | train_fd, valid_fd = feed_dicts(dat, exs) 153 | # print('train_feed:', train_fd) # DEBUG 154 | sess.run(far_ho.hypergradient.initialization) 155 | tr_acc, v_acc = [], [] 156 | for ex in exs: 157 | # ts = io_opt.minimize(ex.errors['training'], var_list=ex.model.var_list).ts 158 | # ts = tf.train.GradientDescentOptimizer(lr).minimize(ex.errors['training'], var_list=ex.model.var_list) 159 | [sess.run(ex.optimizers['ts'], feed_dict={ex.x: train_fd[ex.x], ex.y: train_fd[ex.y]}) for _ in range(T)] 160 | tr_acc.append(sess.run(ex.scores['accuracy'], feed_dict={ex.x: train_fd[ex.x], ex.y: train_fd[ex.y]})) 161 | v_acc.append(sess.run(ex.scores['accuracy'], feed_dict={ex.x: valid_fd[ex.x], ex.y: valid_fd[ex.y]})) 162 | return tr_acc, v_acc 163 | 164 | 165 | def accuracy_on(batch_queue, exs, far_ho, sess, T): 166 | tr_acc, v_acc = [], [] 167 | for d in batch_queue.get(): 168 | result = just_train_on_dataset(d, exs, far_ho, sess, T) 169 | tr_acc.extend(result[0]) 170 | v_acc.extend(result[1]) 171 | return tr_acc, v_acc 172 | 173 | 174 | def just_train_on_dataset_up_to_T(dat, exs, far_ho, sess, T): 175 | train_fd, valid_fd = feed_dicts(dat, exs) 176 | # print('train_feed:', train_fd) # DEBUG 177 | sess.run(far_ho.hypergradient.initialization) 178 | tr_acc, v_acc = [[] for _ in range(T)], [[] for _ in range(T)] 179 | for ex in exs: 180 | # ts = io_opt.minimize(ex.errors['training'], var_list=ex.model.var_list).ts 181 | # ts = tf.train.GradientDescentOptimizer(lr).minimize(ex.errors['training'], var_list=ex.model.var_list) 182 | for t in range(T): 183 | sess.run(ex.optimizers['ts'], feed_dict={ex.x: train_fd[ex.x], ex.y: train_fd[ex.y]}) 184 | tr_acc[t].append(sess.run(ex.scores['accuracy'], feed_dict={ex.x: train_fd[ex.x], ex.y: train_fd[ex.y]})) 185 | v_acc[t].append(sess.run(ex.scores['accuracy'], feed_dict={ex.x: valid_fd[ex.x], ex.y: valid_fd[ex.y]})) 186 | return tr_acc, v_acc 187 | 188 | 189 | def accuracy_on_up_to_T(batch_queue, exs, far_ho, sess, T): 190 | tr_acc, v_acc = [[] for _ in range(T)], [[] for _ in range(T)] 191 | for d in batch_queue.get(): 192 | result = just_train_on_dataset_up_to_T(d, exs, far_ho, sess, T) 193 | [tr_acc[T].extend(r) for T, r in enumerate(result[0])] 194 | [v_acc[T].extend(r) for T, r in enumerate(result[1])] 195 | 196 | return tr_acc, v_acc 197 | 198 | 199 | def build(metasets, hyper_model_builder, learn_lr, lr0, MBS, mlr0, mlr_decay, batch_norm_before_classifier, weights_initializer, 200 | process_fn=None): 201 | exs = [em.SLExperiment(metasets) for _ in range(MBS)] 202 | 203 | hyper_repr_model = hyper_model_builder(exs[0].x, 'HyperRepr') 204 | 205 | if learn_lr: 206 | lr = far.get_hyperparameter('lr', lr0) 207 | else: 208 | lr = tf.constant(lr0, name='lr') 209 | 210 | gs = 
tf.get_variable('global_step', initializer=0, trainable=False) 211 | meta_lr = tf.train.inverse_time_decay(mlr0, gs, decay_steps=1., decay_rate=mlr_decay) 212 | 213 | io_opt = far.GradientDescentOptimizer(lr) 214 | oo_opt = tf.train.AdamOptimizer(meta_lr) 215 | far_ho = far.HyperOptimizer() 216 | 217 | for k, ex in enumerate(exs): 218 | # print(k) # DEBUG 219 | with tf.device(available_devices[k % len(available_devices)]): 220 | repr_out = hyper_repr_model.for_input(ex.x).out 221 | 222 | other_train_vars = [] 223 | if batch_norm_before_classifier: 224 | batch_mean, batch_var = tf.nn.moments(repr_out, [0]) 225 | scale = tf.Variable(tf.ones_like(repr_out[0])) 226 | beta = tf.Variable(tf.zeros_like(repr_out[0])) 227 | other_train_vars.append(scale) 228 | other_train_vars.append(beta) 229 | repr_out = tf.nn.batch_normalization(repr_out, batch_mean, batch_var, beta, scale, 1e-3) 230 | 231 | ex.model = em.models.FeedForwardNet(repr_out, metasets.train.dim_target, 232 | output_weight_initializer=weights_initializer, name='Classifier_%s' % k) 233 | 234 | ex.errors['training'] = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=ex.y, 235 | logits=ex.model.out)) 236 | ex.errors['validation'] = ex.errors['training'] 237 | ex.scores['accuracy'] = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(ex.y, 1), tf.argmax(ex.model.out, 1)), 238 | tf.float32), name='accuracy') 239 | 240 | # simple training step used for testing (look 241 | ex.optimizers['ts'] = tf.train.GradientDescentOptimizer(lr).minimize(ex.errors['training'], 242 | var_list=ex.model.var_list) 243 | 244 | optim_dict = far_ho.inner_problem(ex.errors['training'], io_opt, 245 | var_list=ex.model.var_list + other_train_vars) 246 | far_ho.outer_problem(ex.errors['validation'], optim_dict, oo_opt, 247 | hyper_list=tf.get_collection(far.GraphKeys.HYPERPARAMETERS), global_step=gs) 248 | 249 | far_ho.finalize(process_fn=process_fn) 250 | saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES), max_to_keep=240) 251 | return exs, far_ho, saver 252 | 253 | 254 | def meta_train(exp_dir, metasets, exs, far_ho, saver, sess, n_test_episodes, MBS, seed, resume, T, 255 | n_meta_iterations, print_interval, save_interval): 256 | # use workers to fill the batches queues (is it worth it?) 
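# NOTE: the threaded batch_producer / start_batch_makers path defined above is not used here;
# the synchronous BatchQueueMock instances created below serve the episodes instead, which
# sidesteps the multi-thread issues with the shared random state noted in the class comment.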
257 | 258 | result_path = os.path.join(exp_dir, 'results.pickle') 259 | 260 | tf.global_variables_initializer().run(session=sess) 261 | 262 | n_test_batches = n_test_episodes // MBS 263 | rand = em.get_rand_state(seed) 264 | 265 | results = {'train_train': {'mean': [], 'std': []}, 'train_test': {'mean': [], 'std': []}, 266 | 'test_test': {'mean': [], 'std': []}, 'valid_test': {'mean': [], 'std': []}, 267 | 'outer_losses': {'mean':[], 'std': []}, 'learning_rate': [], 'iterations': [], 268 | 'episodes': [], 'time': []} 269 | 270 | start_time = time.time() 271 | 272 | resume_itr = 0 273 | if resume: 274 | model_file = tf.train.latest_checkpoint(exp_dir) 275 | if model_file: 276 | print("Restoring results from " + result_path) 277 | results = load_obj(result_path) 278 | start_time = results['time'][-1] 279 | 280 | ind1 = model_file.index('model') 281 | resume_itr = int(model_file[ind1 + 5:]) + 1 282 | print("Restoring model weights from " + model_file) 283 | saver.restore(sess, model_file) 284 | 285 | ''' Meta-Train ''' 286 | train_batches = BatchQueueMock(metasets.train, 1, MBS, rand) 287 | valid_batches = BatchQueueMock(metasets.validation, n_test_batches, MBS, rand) 288 | test_batches = BatchQueueMock(metasets.test, n_test_batches, MBS, rand) 289 | 290 | print('\nIteration quantities: train_train acc, train_test acc, valid_test, acc' 291 | ' test_test acc mean(std) over %d episodes' % n_test_episodes) 292 | with sess.as_default(): 293 | inner_losses = [] 294 | for meta_it in range(resume_itr, n_meta_iterations): 295 | tr_fd, v_fd = feed_dicts(train_batches.get()[0], exs) 296 | 297 | far_ho.run(T, tr_fd, v_fd) 298 | # inner_losses.append(far_ho.inner_losses) 299 | 300 | outer_losses = [sess.run(ex.errors['validation'], v_fd) for ex in exs] 301 | outer_losses_moments = (np.mean(outer_losses), np.std(outer_losses)) 302 | results['outer_losses']['mean'].append(outer_losses_moments[0]) 303 | results['outer_losses']['std'].append(outer_losses_moments[1]) 304 | 305 | # print('inner_losses: ', inner_losses[-1]) 306 | 307 | if meta_it % print_interval == 0 or meta_it == n_meta_iterations - 1: 308 | results['iterations'].append(meta_it) 309 | results['episodes'].append(meta_it * MBS) 310 | 311 | train_result = accuracy_on(train_batches, exs, far_ho, sess, T) 312 | test_result = accuracy_on(test_batches, exs, far_ho, sess, T) 313 | valid_result = accuracy_on(valid_batches, exs, far_ho, sess, T) 314 | 315 | train_train = (np.mean(train_result[0]), np.std(train_result[0])) 316 | train_test = (np.mean(train_result[1]), np.std(train_result[1])) 317 | valid_test = (np.mean(valid_result[1]), np.std(valid_result[1])) 318 | test_test = (np.mean(test_result[1]), np.std(test_result[1])) 319 | 320 | duration = time.time() - start_time 321 | results['time'].append(duration) 322 | 323 | results['train_train']['mean'].append(train_train[0]) 324 | results['train_test']['mean'].append(train_test[0]) 325 | results['valid_test']['mean'].append(valid_test[0]) 326 | results['test_test']['mean'].append(test_test[0]) 327 | 328 | results['train_train']['std'].append(train_train[1]) 329 | results['train_test']['std'].append(train_test[1]) 330 | results['valid_test']['std'].append(valid_test[1]) 331 | results['test_test']['std'].append(test_test[1]) 332 | 333 | results['inner_losses'] = inner_losses 334 | 335 | print('mean outer losses: {}'.format(outer_losses_moments[1])) 336 | 337 | print('it %d, ep %d (%.2fs): %.3f, %.3f, %.3f, %.3f' % (meta_it, meta_it * MBS, duration, train_train[0], 338 | train_test[0], valid_test[0], 
test_test[0])) 339 | 340 | lr = sess.run(["lr:0"])[0] 341 | print('lr: {}'.format(lr)) 342 | 343 | # do_plot(logdir, results) 344 | 345 | if meta_it % save_interval == 0 or meta_it == n_meta_iterations - 1: 346 | saver.save(sess, exp_dir + '/model' + str(meta_it)) 347 | save_obj(result_path, results) 348 | 349 | return results 350 | 351 | 352 | def meta_test(exp_dir, metasets, exs, far_ho, saver, sess, c_way, k_shot, lr, n_test_episodes, MBS, seed, T, 353 | iterations=list(range(10000))): 354 | 355 | meta_test_str = str(c_way) + 'way_' + str(k_shot) + 'shot_' \ 356 | + str(T) + 'T' + str(lr) + 'lr' + str(n_test_episodes) + 'ep' 357 | 358 | n_test_batches = n_test_episodes // MBS 359 | rand = em.get_rand_state(seed) 360 | 361 | valid_batches = BatchQueueMock(metasets.validation, n_test_batches, MBS, rand) 362 | test_batches = BatchQueueMock(metasets.test, n_test_batches, MBS, rand) 363 | 364 | print('\nMeta-testing {} (over {} eps)...'.format(meta_test_str, n_test_episodes)) 365 | 366 | test_results = {'test_test': {'mean': [], 'std': []}, 'valid_test': {'mean': [], 'std': []}, 367 | 'cp_numbers': [], 'time': [], 368 | 'n_test_episodes': n_test_episodes, 'episodes': [], 'iterations': []} 369 | 370 | test_result_path = os.path.join(exp_dir, meta_test_str + '_results.pickle') 371 | 372 | start_time = time.time() 373 | for i in iterations: 374 | model_file = os.path.join(exp_dir, 'model' + str(i)) 375 | if tf.train.checkpoint_exists(model_file): 376 | print("Restoring model weights from " + model_file) 377 | saver.restore(sess, model_file) 378 | 379 | test_results['iterations'].append(i) 380 | test_results['episodes'].append(i * MBS) 381 | 382 | valid_result = accuracy_on(valid_batches, exs, far_ho, sess, T) 383 | test_result = accuracy_on(test_batches, exs, far_ho, sess, T) 384 | 385 | duration = time.time() - start_time 386 | 387 | valid_test = (np.mean(valid_result[1]), np.std(valid_result[1])) 388 | test_test = (np.mean(test_result[1]), np.std(test_result[1])) 389 | 390 | test_results['time'].append(duration) 391 | 392 | test_results['valid_test']['mean'].append(valid_test[0]) 393 | test_results['test_test']['mean'].append(test_test[0]) 394 | 395 | test_results['valid_test']['std'].append(valid_test[1]) 396 | test_results['test_test']['std'].append(test_test[1]) 397 | 398 | print('valid-test_test acc (%d meta_it)(%.2fs): %.3f (%.3f), %.3f (%.3f)' % (i, duration, valid_test[0], 399 | valid_test[1],test_test[0], 400 | test_test[1])) 401 | 402 | save_obj(test_result_path, test_results) 403 | 404 | return test_results 405 | 406 | 407 | def meta_test_up_to_T(exp_dir, metasets, exs, far_ho, saver, sess, c_way, k_shot, lr, n_test_episodes, MBS, seed, T, 408 | iterations=list(range(10000))): 409 | meta_test_str = str(c_way) + 'way_' + str(k_shot) + 'shot_' + str(lr) + 'lr' + str(n_test_episodes) + 'ep' 410 | 411 | n_test_batches = n_test_episodes // MBS 412 | rand = em.get_rand_state(seed) 413 | 414 | valid_batches = BatchQueueMock(metasets.validation, n_test_batches, MBS, rand) 415 | test_batches = BatchQueueMock(metasets.test, n_test_batches, MBS, rand) 416 | train_batches = BatchQueueMock(metasets.train, n_test_batches, MBS, rand) 417 | 418 | print('\nMeta-testing {} (over {} eps)...'.format(meta_test_str, n_test_episodes)) 419 | 420 | test_results = {'valid_test': [], 'test_test': [], 'train_test': [], 'time': [], 'n_test_episodes': n_test_episodes, 421 | 'episodes': [], 'iterations': []} 422 | 423 | test_result_path = os.path.join(exp_dir, meta_test_str + 'noTrain_results.pickle') 424 | 425 | 
start_time = time.time() 426 | for i in iterations: 427 | model_file = os.path.join(exp_dir, 'model' + str(i)) 428 | if tf.train.checkpoint_exists(model_file): 429 | print("Restoring model weights from " + model_file) 430 | saver.restore(sess, model_file) 431 | 432 | test_results['iterations'].append(i) 433 | test_results['episodes'].append(i * MBS) 434 | 435 | valid_result = accuracy_on_up_to_T(valid_batches, exs, far_ho, sess, T) 436 | test_result = accuracy_on_up_to_T(test_batches, exs, far_ho, sess, T) 437 | train_result = accuracy_on_up_to_T(train_batches, exs, far_ho, sess, T) 438 | 439 | duration = time.time() - start_time 440 | 441 | test_results['time'].append(duration) 442 | 443 | for t in range(T): 444 | 445 | valid_test = (np.mean(valid_result[1][t]), np.std(valid_result[1][t])) 446 | test_test = (np.mean(test_result[1][t]), np.std(test_result[1][t])) 447 | train_test = (np.mean(train_result[1][t]), np.std(train_result[1][t])) 448 | 449 | if t >= len(test_results['valid_test']): 450 | test_results['valid_test'].append({'mean': [], 'std': []}) 451 | test_results['test_test'].append({'mean': [], 'std': []}) 452 | test_results['train_test'].append({'mean': [], 'std': []}) 453 | 454 | test_results['valid_test'][t]['mean'].append(valid_test[0]) 455 | test_results['test_test'][t]['mean'].append(test_test[0]) 456 | test_results['train_test'][t]['mean'].append(train_test[0]) 457 | 458 | test_results['valid_test'][t]['std'].append(valid_test[1]) 459 | test_results['test_test'][t]['std'].append(test_test[1]) 460 | test_results['train_test'][t]['std'].append(train_test[1]) 461 | 462 | print('valid-test_test acc T=%d (%d meta_it)(%.2fs): %.4f (%.4f), %.4f (%.4f),' 463 | ' %.4f (%.4f)' % (t+1, i, duration, train_test[0], train_test[1], valid_test[0], valid_test[1], 464 | test_test[0], test_test[1])) 465 | 466 | #print('valid-test_test acc T=%d (%d meta_it)(%.2fs): %.4f (%.4f),' 467 | # ' %.4f (%.4f)' % (t+1, i, duration, valid_test[0], valid_test[1], 468 | # test_test[0], test_test[1])) 469 | 470 | save_obj(test_result_path, test_results) 471 | 472 | return test_results 473 | 474 | 475 | # training and testing function 476 | def train_and_test(metasets, name_of_exp, hyper_model_builder, logdir='logs/', seed=None, lr0=0.04, learn_lr=False, mlr0=0.001, 477 | mlr_decay=1.e-5, T=5, resume=True, MBS=4, n_meta_iterations=5000, weights_initializer=tf.zeros_initializer, 478 | batch_norm_before_classifier=False, process_fn=None, save_interval=5000, print_interval=5000, 479 | n_test_episodes=1000): 480 | 481 | params = locals() 482 | print('params: {}'.format(params)) 483 | 484 | ''' Problem Setup ''' 485 | np.random.seed(seed) 486 | tf.set_random_seed(seed) 487 | 488 | exp_dir = logdir + '/' + name_of_exp 489 | print('\nExperiment directory:', exp_dir + '...') 490 | if not os.path.exists(exp_dir): 491 | os.makedirs(exp_dir) 492 | 493 | executing_file_path = inspect.getfile(inspect.currentframe()) 494 | print('copying {} into {}'.format(executing_file_path, exp_dir)) 495 | copyfile(executing_file_path, os.path.join(exp_dir, executing_file_path.split('/')[-1])) 496 | 497 | exs, far_ho, saver = build(metasets, hyper_model_builder,learn_lr, lr0, MBS, mlr0, mlr_decay, 498 | batch_norm_before_classifier, weights_initializer, process_fn) 499 | 500 | sess = tf.Session(config=em.utils.GPU_CONFIG()) 501 | 502 | meta_train(exp_dir, metasets, exs, far_ho, saver, sess, n_test_episodes, MBS, seed, resume, T, 503 | n_meta_iterations, print_interval, save_interval) 504 | 505 | meta_test(exp_dir, metasets, exs, 
far_ho, saver, sess, args.classes, args.examples_train, lr0, 506 | n_test_episodes, MBS, seed, T, list(range(n_meta_iterations))) 507 | 508 | 509 | # training and testing function 510 | def build_and_test(metasets, exp_dir, hyper_model_builder, seed=None, lr0=0.04, T=5, MBS=4, 511 | weights_initializer=tf.zeros_initializer, batch_norm_before_classifier=False, 512 | process_fn=None, n_test_episodes=600, iterations_to_test=list(range(100000))): 513 | 514 | params = locals() 515 | print('params: {}'.format(params)) 516 | 517 | mlr_decay = 1.e-5 518 | mlr0 = 0.001 519 | learn_lr = False 520 | 521 | ''' Problem Setup ''' 522 | np.random.seed(seed) 523 | tf.set_random_seed(seed) 524 | 525 | exs, far_ho, saver = build(metasets, hyper_model_builder,learn_lr, lr0, MBS, mlr0, mlr_decay, 526 | batch_norm_before_classifier, weights_initializer, process_fn) 527 | 528 | sess = tf.Session(config=em.utils.GPU_CONFIG()) 529 | 530 | meta_test_up_to_T(exp_dir, metasets, exs, far_ho, saver, sess, args.classes, args.examples_train, lr0, 531 | n_test_episodes, MBS, seed, T, iterations_to_test) 532 | 533 | 534 | def main(): 535 | print(args.__dict__) 536 | 537 | try: 538 | metasets = dataset_load_dict[args.dataset]( 539 | std_num_classes=args.classes, std_num_examples=(args.examples_train*args.classes, 540 | args.examples_test*args.classes)) 541 | except KeyError: 542 | raise ValueError('dataset FLAG must be omniglot or miniimagenet') 543 | 544 | weights_initializer = tf.contrib.layers.xavier_initializer() if args.xavier else tf.zeros_initializer 545 | 546 | if args.clip_value > 0.: 547 | def process_fn(t): 548 | return tf.clip_by_value(t, -args.clip_value, args.clip_value) 549 | else: 550 | process_fn = None 551 | 552 | logdir = args.logdir + args.dataset 553 | 554 | hyper_model_builder = model_dict[args.dataset] 555 | 556 | if args.mode == 'train': 557 | train_and_test(metasets, exp_string, hyper_model_builder, logdir, seed=args.seed, 558 | lr0=args.lr, 559 | learn_lr=args.learn_lr, mlr0=args.meta_lr, mlr_decay=args.meta_lr_decay_rate, T=args.T, 560 | resume=args.resume, MBS=args.meta_batch_size, n_meta_iterations=args.n_meta_iterations, 561 | weights_initializer=weights_initializer, batch_norm_before_classifier=args.batch_norm, 562 | process_fn=process_fn, save_interval=args.save_interval, print_interval=args.print_interval, 563 | n_test_episodes=args.test_episodes) 564 | 565 | elif args.mode == 'test': 566 | build_and_test(metasets, args.exp_dir, hyper_model_builder, seed=args.seed, lr0=args.lr, 567 | T=args.T, MBS=args.meta_batch_size, weights_initializer=weights_initializer, 568 | batch_norm_before_classifier=args.batch_norm, process_fn=process_fn, 569 | n_test_episodes=args.test_episodes, iterations_to_test=args.iterations_to_test) 570 | 571 | 572 | if __name__ == "__main__": 573 | main() 574 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import experiment_manager as em 2 | from experiment_manager import models 3 | import tensorflow as tf 4 | from tensorflow.contrib import layers as tcl 5 | import far_ho as far 6 | 7 | 8 | class TCML_ResNet(models.Network): 9 | def __init__(self, _input, name=None, deterministic_initialization=False, reuse=False): 10 | self.var_coll = far.HYPERPARAMETERS_COLLECTIONS 11 | super().__init__(_input, name, deterministic_initialization, reuse) 12 | 13 | 14 | self.betas = self.filter_vars('beta') 15 | self.moving_means = 
self.filter_vars('moving_mean') 16 | self.moving_variances = self.filter_vars('moving_variance') 17 | 18 | if not reuse: 19 | far.utils.remove_from_collection(far.GraphKeys.MODEL_VARIABLES, *self.moving_means, *self.moving_variances) 20 | 21 | far.utils.remove_from_collection(far.GraphKeys.HYPERPARAMETERS, *self.moving_means, *self.moving_variances) 22 | print(name, 'MODEL CREATED') 23 | 24 | def _build(self): 25 | 26 | def residual_block(x, n_filters): 27 | skip_c = tcl.conv2d(x, n_filters, 1, activation_fn=None) 28 | 29 | def conv_block(xx): 30 | out = tcl.conv2d(xx, n_filters, 3, activation_fn=None, normalizer_fn=tcl.batch_norm, 31 | variables_collections=self.var_coll) 32 | return em.utils.leaky_relu(out, 0.1) 33 | 34 | out = x 35 | for _ in range(3): 36 | out = conv_block(out) 37 | 38 | add = tf.add(skip_c, out) 39 | 40 | return tf.nn.max_pool(add, [1, 2, 2, 1], [1, 2, 2, 1], 'SAME') 41 | 42 | self + residual_block(self.out, 64) 43 | self + residual_block(self.out, 96) 44 | self + residual_block(self.out, 128) 45 | self + residual_block(self.out, 256) 46 | self + tcl.conv2d(self.out, 2048, 1, variables_collections=self.var_coll) 47 | self + tf.nn.avg_pool(self.out, [1, 6, 6, 1], [1, 6, 6, 1], 'VALID') 48 | self + tcl.conv2d(self.out, 512, 1, variables_collections=self.var_coll) 49 | self + tf.reshape(self.out, (-1, 512)) 50 | 51 | def for_input(self, new_input): 52 | return TCML_ResNet(new_input, self.name, self.deterministic_initialization, True) 53 | 54 | 55 | class TCML_ResNet_Omniglot(models.Network): 56 | def __init__(self, _input, name=None, deterministic_initialization=False, reuse=False): 57 | self.var_coll = far.HYPERPARAMETERS_COLLECTIONS 58 | super().__init__(_input, name, deterministic_initialization, reuse) 59 | 60 | 61 | self.betas = self.filter_vars('beta') 62 | self.moving_means = self.filter_vars('moving_mean') 63 | self.moving_variances = self.filter_vars('moving_variance') 64 | 65 | if not reuse: 66 | far.utils.remove_from_collection(far.GraphKeys.MODEL_VARIABLES, *self.moving_means, *self.moving_variances) 67 | 68 | far.utils.remove_from_collection(far.GraphKeys.HYPERPARAMETERS, *self.moving_means, *self.moving_variances) 69 | print(name, 'MODEL CREATED') 70 | 71 | def _build(self): 72 | 73 | def residual_block(x, n_filters): 74 | skip_c = tcl.conv2d(x, n_filters, 1, activation_fn=None) 75 | 76 | def conv_block(xx): 77 | out = tcl.conv2d(xx, n_filters, 3, activation_fn=None, normalizer_fn=tcl.batch_norm, 78 | variables_collections=self.var_coll) 79 | return em.utils.leaky_relu(out, 0.1) 80 | 81 | out = x 82 | for _ in range(3): 83 | out = conv_block(out) 84 | 85 | add = tf.add(skip_c, out) 86 | 87 | return tf.nn.max_pool(add, [1, 2, 2, 1], [1, 2, 2, 1], 'SAME') 88 | 89 | self + residual_block(self.out, 64) 90 | self + residual_block(self.out, 96) 91 | # self + residual_block(self.out, 128) 92 | # self + residual_block(self.out, 256) 93 | self + tcl.conv2d(self.out, 2048, 1, variables_collections=self.var_coll) 94 | self + tf.nn.avg_pool(self.out, [1, 6, 6, 1], [1, 6, 6, 1], 'VALID') 95 | self + tcl.conv2d(self.out, 512, 1, variables_collections=self.var_coll) 96 | self + tf.reshape(self.out, (-1, 512)) 97 | 98 | def for_input(self, new_input): 99 | return TCML_ResNet_Omniglot(new_input, self.name, self.deterministic_initialization, True) 100 | 101 | 102 | class TCML_ResNet_Omniglot_v2(models.Network): 103 | def __init__(self, _input, name=None, deterministic_initialization=False, reuse=False): 104 | self.var_coll = far.HYPERPARAMETERS_COLLECTIONS 105 | 
super().__init__(_input, name, deterministic_initialization, reuse) 106 | 107 | self.betas = self.filter_vars('beta') 108 | self.moving_means = self.filter_vars('moving_mean') 109 | self.moving_variances = self.filter_vars('moving_variance') 110 | 111 | if not reuse: 112 | far.utils.remove_from_collection(far.GraphKeys.MODEL_VARIABLES, *self.moving_means, *self.moving_variances) 113 | 114 | far.utils.remove_from_collection(far.GraphKeys.HYPERPARAMETERS, *self.moving_means, *self.moving_variances) 115 | print(name, 'MODEL CREATED') 116 | 117 | def _build(self): 118 | 119 | def residual_block(x, n_filters): 120 | skip_c = tcl.conv2d(x, n_filters, 1, activation_fn=None) 121 | 122 | def conv_block(xx): 123 | out = tcl.conv2d(xx, n_filters, 3, activation_fn=None, normalizer_fn=tcl.batch_norm, 124 | variables_collections=self.var_coll) 125 | return em.utils.leaky_relu(out, 0.1) 126 | 127 | out = x 128 | for _ in range(3): 129 | out = conv_block(out) 130 | 131 | add = tf.add(skip_c, out) 132 | 133 | return tf.nn.max_pool(add, [1, 2, 2, 1], [1, 2, 2, 1], 'SAME') 134 | 135 | self + residual_block(self.out, 64) 136 | self + residual_block(self.out, 96) 137 | self + residual_block(self.out, 128) 138 | self + residual_block(self.out, 256) 139 | self + tcl.conv2d(self.out, 2048, 1, variables_collections=self.var_coll) 140 | self + tf.nn.avg_pool(self.out, [1, 6, 6, 1], [1, 6, 6, 1], 'SAME') 141 | self + tcl.conv2d(self.out, 512, 1, variables_collections=self.var_coll) 142 | self + tf.reshape(self.out, (-1, 512)) 143 | 144 | def for_input(self, new_input): 145 | return TCML_ResNet_Omniglot_v2(new_input, self.name, self.deterministic_initialization, True) 146 | 147 | 148 | def hr_res_net_tcml_Omniglot_builder_v2(): 149 | return lambda x, name: TCML_ResNet_Omniglot_v2(x, name=name) 150 | 151 | 152 | 153 | 154 | def hr_res_net_tcml_v1_builder(): 155 | return lambda x, name: TCML_ResNet(x, name=name) 156 | 157 | 158 | def hr_res_net_tcml_Omniglot_builder(): 159 | return lambda x, name: TCML_ResNet_Omniglot(x, name=name) 160 | 161 | 162 | if __name__ == '__main__': 163 | inp = tf.placeholder(tf.float32, (None, 84, 84, 3)) 164 | net = TCML_ResNet(inp) 165 | print(net.out) 166 | print(far.hyperparameters()) 167 | 168 | 169 | -------------------------------------------------------------------------------- /plot-helpers.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | 4 | 5 | def do_plot(exp_path, results=None): 6 | import pickle, os 7 | import numpy as np 8 | import matplotlib 9 | matplotlib.use('Agg') 10 | import matplotlib.pyplot as plt 11 | 12 | if results is None: 13 | def load_obj(file_path): 14 | with open(file_path, 'rb') as handle: 15 | b = pickle.load(handle) 16 | return b 17 | results = load_obj(os.path.join(exp_path, 'results.pickle')) 18 | # results = load_obj(os.path.join(exp_path, 'test_results.pickle')) 19 | results = load_obj(os.path.join(exp_path, 'test5shot_results.pickle')) 20 | 21 | exp_name = exp_path.split('/')[-1] 22 | 23 | x_div = 1000 24 | x_string = 'episodes' # can be 'iterations' or 'episodes' 25 | x_values = np.array(results[x_string]) / x_div 26 | 27 | arg_max_valid = np.argmax(results['valid_test']['mean']) 28 | chosen_valid = results['valid_test']['mean'][arg_max_valid] 29 | chosen_test = results['test_test']['mean'][arg_max_valid] 30 | chosen_valid_ci = results['valid_test']['std'][arg_max_valid] 31 | chosen_test_ci = results['test_test']['std'][arg_max_valid] 32 | chosen_x = x_values[arg_max_valid] 33 | 34 
| fig = plt.figure(figsize=(8, 5)) 35 | plt.title(exp_name) 36 | #plt.plot(x_values, results['train_train']['mean'], label='train_train') 37 | #plt.plot(x_values, results['train_test']['mean'], label='train_test') 38 | plt.plot(x_values, results['valid_test']['mean'], label='valid_test (t: %.4f (%.4f), x: %d)' % (chosen_valid, 39 | chosen_valid_ci, 40 | chosen_x)) 41 | plt.plot(x_values, results['test_test']['mean'], label='test_test (t: %.4f (%.4f), x: %d)' % (chosen_test, 42 | chosen_test_ci, 43 | chosen_x)) 44 | plt.legend(loc=0) 45 | plt.xlabel(x_string + ' / %d' % x_div) 46 | plt.savefig(exp_path + '/5shot_accuracies.png') 47 | plt.close(fig) 48 | 49 | 50 | def load_results(exp_path): 51 | with open(os.path.join(exp_path, 'results.pickle'), 'rb') as handle: 52 | b = pickle.load(handle) 53 | return b --------------------------------------------------------------------------------
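Usage note for the plotting helpers above: a minimal sketch, assuming an experiment directory already produced by `meta_train` in `main.py` (i.e. one containing `results.pickle`). The path below is hypothetical, and because the file name `plot-helpers.py` contains a dash, the module is loaded via `importlib` rather than a plain `import`, with the repository root assumed to be on `sys.path`.

import importlib

plot_helpers = importlib.import_module('plot-helpers')  # load plot-helpers.py as a module

# hypothetical experiment directory written by a training run; adjust to an actual one
exp_path = 'logs/omniglot/5way_1shot_2mbs5T0.0cv0.3mlr0.4Falselr'

results = plot_helpers.load_results(exp_path)  # reads <exp_path>/results.pickle
plot_helpers.do_plot(exp_path, results)        # saves <exp_path>/5shot_accuracies.png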