├── __init__.py
├── mtnet-fig.png
├── mtnet-subspace.png
├── experiments
│   ├── sine.sh
│   ├── omniglot.sh
│   ├── polynomial.sh
│   └── miniimagenet.sh
├── data
│   ├── omniglot_resized
│   │   └── resize_images.py
│   └── miniImagenet
│       └── proc_images.py
├── special_grads.py
├── LICENSE
├── poly_generator.py
├── utils.py
├── README.md
├── data_generator.py
├── main.py
└── maml.py

/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/mtnet-fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoonholee/MT-net/HEAD/mtnet-fig.png
--------------------------------------------------------------------------------
/mtnet-subspace.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoonholee/MT-net/HEAD/mtnet-subspace.png
--------------------------------------------------------------------------------
/experiments/sine.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | for lr in .01 .04 .1
4 | do
5 |     python main.py \
6 |         --datasource=sinusoid --metatrain_iterations=60000 \
7 |         --meta_batch_size=4 --update_lr=$lr --norm=None --resume=True \
8 |         --update_batch_size=10 --use_T=True --use_M=True --share_M=True \
9 |         --logdir=logs/sine
10 | done
11 |
12 | # For example, to use T-net:
13 | # --use_T=True --use_M=False --share_M=False
14 | #
15 | # Original MAML is recovered by using:
16 | # --use_T=False --use_M=False --share_M=False
17 |
--------------------------------------------------------------------------------
/experiments/omniglot.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Omniglot 5-way with MT-net
4 | python main.py \
5 |     --datasource=omniglot --metatrain_iterations=40000 \
6 |     --meta_batch_size=32 --update_batch_size=1 \
7 |     --num_classes=5 --num_updates=1 --logdir=logs/omniglot5way \
8 |     --update_lr=.4 --use_T=True --use_M=True --share_M=True
9 |
10 | # Omniglot 20-way with MT-net
11 | python main.py \
12 |     --datasource=omniglot --metatrain_iterations=40000 \
13 |     --meta_batch_size=16 --update_batch_size=1 \
14 |     --num_classes=20 --num_updates=1 --logdir=logs/omniglot20way \
15 |     --update_lr=.1 --use_T=True --use_M=True --share_M=True
16 |
--------------------------------------------------------------------------------
/experiments/polynomial.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | python main.py \
4 |     --datasource=polynomial --metatrain_iterations=60000 --update_batch_size=10 \
5 |     --meta_batch_size=4 --norm=None --logdir=logs/poly --poly_order=0 \
6 |     --use_T=True --use_M=True --share_M=True
7 |
8 | python main.py \
9 |     --datasource=polynomial --metatrain_iterations=60000 --update_batch_size=10 \
10 |     --meta_batch_size=4 --norm=None --logdir=logs/poly --poly_order=1 \
11 |     --use_T=True --use_M=True --share_M=True
12 |
13 | python main.py \
14 |     --datasource=polynomial --metatrain_iterations=60000 --update_batch_size=10 \
15 |     --meta_batch_size=4 --norm=None --logdir=logs/poly --poly_order=2 \
16 |     --use_T=True --use_M=True --share_M=True
17 |
--------------------------------------------------------------------------------
/data/omniglot_resized/resize_images.py:
-------------------------------------------------------------------------------- 1 | """ 2 | Usage instructions: 3 | First download the omniglot dataset 4 | and put the contents of both images_background and images_evaluation in data/omniglot/ (without the root folder) 5 | 6 | Then, run the following: 7 | cd data/ 8 | cp -r omniglot/* omniglot_resized/ 9 | cd omniglot_resized/ 10 | python resize_images.py 11 | """ 12 | from PIL import Image 13 | import glob 14 | 15 | image_path = '*/*/' 16 | 17 | all_images = glob.glob(image_path + '*') 18 | 19 | i = 0 20 | 21 | for image_file in all_images: 22 | im = Image.open(image_file) 23 | im = im.resize((28,28), resample=Image.LANCZOS) 24 | im.save(image_file) 25 | i += 1 26 | 27 | if i % 200 == 0: 28 | print(i) 29 | 30 | -------------------------------------------------------------------------------- /special_grads.py: -------------------------------------------------------------------------------- 1 | """ Code for second derivatives not implemented in TensorFlow library. """ 2 | from tensorflow.python.framework import ops 3 | from tensorflow.python.ops import array_ops 4 | from tensorflow.python.ops import gen_nn_ops 5 | 6 | @ops.RegisterGradient("MaxPoolGrad") 7 | def _MaxPoolGradGrad(op, grad): 8 | gradient = gen_nn_ops._max_pool_grad(op.inputs[0], op.outputs[0], 9 | grad, op.get_attr("ksize"), op.get_attr("strides"), 10 | padding=op.get_attr("padding"), data_format=op.get_attr("data_format")) 11 | gradgrad1 = array_ops.zeros(shape = array_ops.shape(op.inputs[1]), dtype=gradient.dtype) 12 | gradgrad2 = array_ops.zeros(shape = array_ops.shape(op.inputs[2]), dtype=gradient.dtype) 13 | return (gradient, gradgrad1, gradgrad2) 14 | -------------------------------------------------------------------------------- /experiments/miniimagenet.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # miniImagenet with MT-nets and hyperparameters from MAML 4 | python main.py \ 5 | --datasource=miniimagenet --metatrain_iterations=60000 \ 6 | --meta_batch_size=4 --update_batch_size=1 \ 7 | --num_updates=5 --logdir=logs/miniimagenet5way \ 8 | --update_lr=.01 --resume=True --num_filters=32 --max_pool=True \ 9 | --use_T=True --use_M=True --share_M=True 10 | 11 | # works well even with single gradient step 12 | python main.py \ 13 | --datasource=miniimagenet --metatrain_iterations=60000 \ 14 | --meta_batch_size=4 --update_batch_size=1 \ 15 | --num_updates=1 --logdir=logs/miniimagenet5way \ 16 | --update_lr=.4 --resume=True --num_filters=32 --max_pool=True \ 17 | --use_T=True --use_M=True --share_M=True 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Chelsea Finn 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /data/miniImagenet/proc_images.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for converting from csv file datafiles to a directory for each image (which is how it is loaded by MAML code) 3 | 4 | Acquire miniImagenet from Ravi & Larochelle '17, along with the train, val, and test csv files. Put the 5 | csv files in the miniImagenet directory and put the images in the directory 'miniImagenet/images/'. 6 | Then run this script from the miniImagenet directory: 7 | cd data/miniImagenet/ 8 | python proc_images.py 9 | """ 10 | 11 | from __future__ import print_function 12 | import csv 13 | import glob 14 | import os 15 | 16 | from PIL import Image 17 | 18 | path_to_images = 'images/' 19 | 20 | all_images = glob.glob(path_to_images + '*') 21 | 22 | # Resize images 23 | for i, image_file in enumerate(all_images): 24 | im = Image.open(image_file) 25 | im = im.resize((84, 84), resample=Image.LANCZOS) 26 | im.save(image_file) 27 | if i % 500 == 0: 28 | print(i) 29 | 30 | # Put in correct directory 31 | for datatype in ['train', 'val', 'test']: 32 | os.system('mkdir ' + datatype) 33 | 34 | with open(datatype + '.csv', 'r') as f: 35 | reader = csv.reader(f, delimiter=',') 36 | last_label = '' 37 | for i, row in enumerate(reader): 38 | if i == 0: # skip the headers 39 | continue 40 | label = row[1] 41 | image_name = row[0] 42 | if label != last_label: 43 | cur_dir = datatype + '/' + label + '/' 44 | os.system('mkdir ' + cur_dir) 45 | last_label = label 46 | os.system('mv images/' + image_name + ' ' + cur_dir) 47 | -------------------------------------------------------------------------------- /poly_generator.py: -------------------------------------------------------------------------------- 1 | """ Code for generating polynomials. 
""" 2 | import numpy as np 3 | from tensorflow.python.platform import flags 4 | 5 | FLAGS = flags.FLAGS 6 | 7 | 8 | class PolyDataGenerator(object): 9 | def __init__(self, num_samples_per_class, batch_size, config={}): 10 | assert FLAGS.datasource == 'polynomial' 11 | self.batch_size = batch_size 12 | self.num_samples_per_class = num_samples_per_class 13 | self.num_classes = 1 # by default 1 (only relevant for classification problems) 14 | self.poly_order = FLAGS.poly_order 15 | 16 | self.generate = self.generate_polynomial_batch 17 | self.input_range = config.get('input_range', [-2.0, 2.0]) 18 | self.coeff_range = config.get('coeff_range', [-1.0, 1.0]) 19 | self.dim_input = 1 20 | self.dim_output = 1 21 | 22 | def generate_polynomial_batch(self): 23 | coeffs = np.random.uniform(self.coeff_range[0], self.coeff_range[1], [self.batch_size, self.poly_order+1]) 24 | outputs = np.zeros([self.batch_size, self.num_samples_per_class, self.dim_output]) 25 | init_inputs = np.zeros([self.batch_size, self.num_samples_per_class, self.dim_input]) 26 | polynomial = np.polynomial.polynomial.polyval 27 | 28 | for func in range(self.batch_size): 29 | init_inputs[func] = np.random.uniform( 30 | self.input_range[0], self.input_range[1], [self.num_samples_per_class, 1]) 31 | func_coeffs = coeffs[func] # [c0, c1,...,] 32 | for i in range(self.poly_order + 1): 33 | func_coeffs[i] /= (2 ** i) 34 | outputs[func] = polynomial(init_inputs[func], func_coeffs) 35 | 36 | return init_inputs, outputs 37 | 38 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ Utility functions. """ 2 | import numpy as np 3 | import os 4 | import random 5 | import tensorflow as tf 6 | 7 | from tensorflow.contrib.layers.python import layers as tf_layers 8 | from tensorflow.python.platform import flags 9 | 10 | FLAGS = flags.FLAGS 11 | 12 | ## Image helper 13 | def get_images(paths, labels, nb_samples=None, shuffle=True): 14 | if nb_samples is not None: 15 | sampler = lambda x: random.sample(x, nb_samples) 16 | else: 17 | sampler = lambda x: x 18 | images = [(i, os.path.join(path, image)) \ 19 | for i, path in zip(labels, paths) \ 20 | for image in sampler(os.listdir(path))] 21 | if shuffle: 22 | random.shuffle(images) 23 | return images 24 | 25 | ## Network helpers 26 | def conv_block(inp, cweight, bweight, reuse, scope, activation=tf.nn.relu, max_pool_pad='VALID', residual=False): 27 | """ Perform, conv, batch norm, nonlinearity, and max pool """ 28 | stride, no_stride = [1,2,2,1], [1,1,1,1] 29 | 30 | if FLAGS.max_pool: 31 | conv_output = tf.nn.conv2d(inp, cweight, no_stride, 'SAME') + bweight 32 | else: 33 | conv_output = tf.nn.conv2d(inp, cweight, stride, 'SAME') + bweight 34 | normed = normalize(conv_output, activation, reuse, scope) 35 | if FLAGS.max_pool: 36 | normed = tf.nn.max_pool(normed, stride, stride, max_pool_pad) 37 | return normed 38 | 39 | def normalize(inp, activation, reuse, scope): 40 | if FLAGS.norm == 'batch_norm': 41 | return tf_layers.batch_norm(inp, activation_fn=activation, reuse=reuse, scope=scope) 42 | elif FLAGS.norm == 'layer_norm': 43 | return tf_layers.layer_norm(inp, activation_fn=activation, reuse=reuse, scope=scope) 44 | elif FLAGS.norm == 'None': 45 | return activation(inp) 46 | 47 | ## Loss functions 48 | def mse(pred, label): 49 | pred = tf.reshape(pred, [-1]) 50 | label = tf.reshape(label, [-1]) 51 | return tf.reduce_mean(tf.square(pred-label)) 52 | 53 | def xent(pred, 
label):
54 |     # Note - with tf version <=0.12, this loss has incorrect 2nd derivatives
55 |     return tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=label) / FLAGS.update_batch_size
56 |
57 |
58 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MT-net
2 |
3 | Code accompanying the paper [Gradient-Based Meta-Learning with Learned Layerwise Metric and Subspace (Yoonho Lee and Seungjin Choi, ICML 2018)](https://arxiv.org/abs/1801.05558).
4 | It includes code for running the experiments in the paper (few-shot sine wave regression, Omniglot and miniImagenet few-shot classification).
5 |
6 | ## Abstract
7 |
8 |
9 | Gradient-based meta-learning methods leverage gradient descent to learn the commonalities among various tasks. While previous such methods have been successful in meta-learning tasks, they resort to simple gradient descent during meta-testing. Our primary contribution is the **MT-net**, which enables the meta-learner to learn on each layer's activation space a subspace that the task-specific learner performs gradient descent on. Additionally, a task-specific learner of an *MT-net* performs gradient descent with respect to a meta-learned distance metric, which warps the activation space to be more sensitive to task identity. We demonstrate that the dimension of this learned subspace reflects the complexity of the task-specific learner's adaptation task, and also that our model is less sensitive to the choice of initial learning rates than previous gradient-based meta-learning methods. Our method achieves state-of-the-art or comparable performance on few-shot classification and regression tasks.
10 |
11 | ### Data
12 | For the Omniglot and MiniImagenet data, see the usage instructions in `data/omniglot_resized/resize_images.py` and `data/miniImagenet/proc_images.py` respectively.
13 |
14 | ### Usage
15 | To run the code, see the usage instructions at the top of `main.py`.
16 |
17 | For MT-nets, set `use_T`, `use_M`, `share_M` to `True`.
18 |
19 | For T-nets, set `use_T` to `True` and `use_M` to `False`.
20 |
21 | ## Reference
22 |
23 | If you found the provided code useful, please cite our work.
24 |
25 | ```
26 | @inproceedings{lee2018gradient,
27 |   title={Gradient-based meta-learning with learned layerwise metric and subspace},
28 |   author={Lee, Yoonho and Choi, Seungjin},
29 |   booktitle={International Conference on Machine Learning},
30 |   pages={2933--2942},
31 |   year={2018}
32 | }
33 | ```
34 |
35 | ---
36 |
37 | This codebase is based on the repository for [MAML](https://github.com/cbfinn/maml).
38 |
--------------------------------------------------------------------------------
/data_generator.py:
--------------------------------------------------------------------------------
1 | """ Code for loading data. """
2 | import numpy as np
3 | import os
4 | import random
5 | import tensorflow as tf
6 |
7 | from tensorflow.python.platform import flags
8 | from utils import get_images
9 |
10 | FLAGS = flags.FLAGS
11 |
12 | class DataGenerator(object):
13 |     """
14 |     Data Generator capable of generating batches of sinusoid or Omniglot data.
15 |     A "class" is considered a class of omniglot digits or a particular sinusoid function.
16 |     """
17 |     def __init__(self, num_samples_per_class, batch_size, config={}):
18 |         """
19 |         Args:
20 |             num_samples_per_class: num samples to generate per class in one batch
21 |             batch_size: size of meta batch size (e.g.
number of functions) 22 | """ 23 | self.batch_size = batch_size 24 | self.num_samples_per_class = num_samples_per_class 25 | self.num_classes = 1 # by default 1 (only relevant for classification problems) 26 | 27 | if FLAGS.datasource == 'sinusoid': 28 | self.generate = self.generate_sinusoid_batch 29 | self.amp_range = config.get('amp_range', [0.1, 5.0]) 30 | self.phase_range = config.get('phase_range', [0, np.pi]) 31 | self.input_range = config.get('input_range', [-5.0, 5.0]) 32 | self.freq_range = config.get('freq_range', [0.8, 1.2]) 33 | self.dim_input = 1 34 | self.dim_output = 1 35 | elif 'omniglot' in FLAGS.datasource: 36 | self.num_classes = config.get('num_classes', FLAGS.num_classes) 37 | self.img_size = config.get('img_size', (28, 28)) 38 | self.dim_input = np.prod(self.img_size) 39 | self.dim_output = self.num_classes 40 | # data that is pre-resized using PIL with lanczos filter 41 | data_folder = config.get('data_folder', './data/omniglot_resized') 42 | 43 | character_folders = [os.path.join(data_folder, family, character) \ 44 | for family in os.listdir(data_folder) \ 45 | if os.path.isdir(os.path.join(data_folder, family)) \ 46 | for character in os.listdir(os.path.join(data_folder, family))] 47 | random.seed(1) 48 | random.shuffle(character_folders) 49 | num_val = 100 50 | num_train = config.get('num_train', 1200) - num_val 51 | self.metatrain_character_folders = character_folders[:num_train] 52 | if FLAGS.test_set: 53 | self.metaval_character_folders = character_folders[num_train:num_train+num_val] 54 | else: 55 | self.metaval_character_folders = character_folders[num_train+num_val:] 56 | self.rotations = config.get('rotations', [0, 90, 180, 270]) 57 | elif FLAGS.datasource == 'miniimagenet': 58 | self.num_classes = config.get('num_classes', FLAGS.num_classes) 59 | self.img_size = config.get('img_size', (84, 84)) 60 | self.dim_input = np.prod(self.img_size)*3 61 | self.dim_output = self.num_classes 62 | metatrain_folder = config.get('metatrain_folder', './data/miniImagenet/train') 63 | if FLAGS.test_set: 64 | metaval_folder = config.get('metaval_folder', './data/miniImagenet/test') 65 | else: 66 | metaval_folder = config.get('metaval_folder', './data/miniImagenet/val') 67 | 68 | metatrain_folders = [os.path.join(metatrain_folder, label) \ 69 | for label in os.listdir(metatrain_folder) \ 70 | if os.path.isdir(os.path.join(metatrain_folder, label)) \ 71 | ] 72 | metaval_folders = [os.path.join(metaval_folder, label) \ 73 | for label in os.listdir(metaval_folder) \ 74 | if os.path.isdir(os.path.join(metaval_folder, label)) \ 75 | ] 76 | self.metatrain_character_folders = metatrain_folders 77 | self.metaval_character_folders = metaval_folders 78 | self.rotations = config.get('rotations', [0]) 79 | else: 80 | raise ValueError('Unrecognized data source') 81 | 82 | 83 | def make_data_tensor(self, train=True): 84 | if train: 85 | folders = self.metatrain_character_folders 86 | folders = folders[:FLAGS.num_train_classes] 87 | # number of tasks, not number of meta-iterations. 
(divide by metabatch size to measure) 88 | num_total_batches = 200000 if not FLAGS.debug else 32 89 | else: 90 | folders = self.metaval_character_folders 91 | num_total_batches = 600 if not FLAGS.debug else 32 92 | 93 | # make list of files 94 | print('Generating filenames') 95 | all_filenames = [] 96 | for _ in range(num_total_batches): 97 | sampled_character_folders = random.sample(folders, self.num_classes) 98 | random.shuffle(sampled_character_folders) 99 | labels_and_images = get_images(sampled_character_folders, range(self.num_classes), nb_samples=self.num_samples_per_class, shuffle=False) 100 | # make sure the above isn't randomized order 101 | labels = [li[0] for li in labels_and_images] 102 | filenames = [li[1] for li in labels_and_images] 103 | all_filenames.extend(filenames) 104 | 105 | # make queue for tensorflow to read from 106 | filename_queue = tf.train.string_input_producer(tf.convert_to_tensor(all_filenames), shuffle=False) 107 | print('Generating image processing ops') 108 | image_reader = tf.WholeFileReader() 109 | _, image_file = image_reader.read(filename_queue) 110 | if FLAGS.datasource == 'miniimagenet': 111 | image = tf.image.decode_jpeg(image_file, channels=3) 112 | image.set_shape((self.img_size[0], self.img_size[1], 3)) 113 | image = tf.reshape(image, [self.dim_input]) 114 | image = tf.cast(image, tf.float32) / 255.0 115 | else: 116 | image = tf.image.decode_png(image_file) 117 | image.set_shape((self.img_size[0],self.img_size[1],1)) 118 | image = tf.reshape(image, [self.dim_input]) 119 | image = tf.cast(image, tf.float32) / 255.0 120 | image = 1.0 - image # invert 121 | num_preprocess_threads = 1 122 | # TODO: enable this to be set to >1 123 | min_queue_examples = 256 124 | examples_per_batch = self.num_classes * self.num_samples_per_class 125 | batch_image_size = self.batch_size * examples_per_batch 126 | print('Batching images') 127 | images = tf.train.batch( 128 | [image], 129 | batch_size=batch_image_size, 130 | num_threads=num_preprocess_threads, 131 | capacity=min_queue_examples + 3 * batch_image_size, 132 | ) 133 | all_image_batches, all_label_batches = [], [] 134 | print('Manipulating image data to be right shape') 135 | for i in range(self.batch_size): 136 | image_batch = images[i*examples_per_batch:(i+1)*examples_per_batch] 137 | 138 | if FLAGS.datasource == 'omniglot': 139 | # omniglot augments the dataset by rotating digits to create new classes 140 | # get rotation per class (e.g. 
0,1,2,0,0 if there are 5 classes) 141 | rotations = tf.multinomial(tf.log([[1., 1., 1., 1.]]), self.num_classes) 142 | label_batch = tf.convert_to_tensor(labels) 143 | new_list, new_label_list = [], [] 144 | for k in range(self.num_samples_per_class): 145 | class_idxs = tf.range(0, self.num_classes) 146 | class_idxs = tf.random_shuffle(class_idxs) 147 | 148 | true_idxs = class_idxs*self.num_samples_per_class + k 149 | new_list.append(tf.gather(image_batch,true_idxs)) 150 | if FLAGS.datasource == 'omniglot': # and FLAGS.train: 151 | new_list[-1] = tf.stack([tf.reshape(tf.image.rot90( 152 | tf.reshape(new_list[-1][ind], [self.img_size[0],self.img_size[1],1]), 153 | k=tf.cast(rotations[0,class_idxs[ind]], tf.int32)), (self.dim_input,)) 154 | for ind in range(self.num_classes)]) 155 | new_label_list.append(tf.gather(label_batch, true_idxs)) 156 | new_list = tf.concat(new_list, 0) # has shape [self.num_classes*self.num_samples_per_class, self.dim_input] 157 | new_label_list = tf.concat(new_label_list, 0) 158 | all_image_batches.append(new_list) 159 | all_label_batches.append(new_label_list) 160 | all_image_batches = tf.stack(all_image_batches) 161 | all_label_batches = tf.stack(all_label_batches) 162 | all_label_batches = tf.one_hot(all_label_batches, self.num_classes) 163 | return all_image_batches, all_label_batches 164 | 165 | def generate_sinusoid_batch(self, train=True, input_idx=None): 166 | # Note train arg is not used (but it is used for omniglot method. 167 | # input_idx is used during qualitative testing --the number of examples used for the grad update 168 | amp = np.random.uniform(self.amp_range[0], self.amp_range[1], [self.batch_size]) 169 | phase = np.random.uniform(self.phase_range[0], self.phase_range[1], [self.batch_size]) 170 | freq = np.random.uniform(self.freq_range[0], self.freq_range[1], [self.batch_size]) 171 | outputs = np.zeros([self.batch_size, self.num_samples_per_class, self.dim_output]) 172 | init_inputs = np.zeros([self.batch_size, self.num_samples_per_class, self.dim_input]) 173 | for func in range(self.batch_size): 174 | init_inputs[func] = np.random.uniform(self.input_range[0], self.input_range[1], [self.num_samples_per_class, 1]) 175 | if input_idx is not None: 176 | init_inputs[:, input_idx:, 0] = np.linspace( 177 | self.input_range[0], self.input_range[1], 178 | num=self.num_samples_per_class-input_idx, retstep=False) 179 | outputs[func] = amp[func] * np.sin(freq[func] * init_inputs[func]-phase[func]) 180 | return init_inputs, outputs, amp, phase 181 | 182 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage Instructions: 3 | Scripts with hyperparameters are in experiments/ 4 | 5 | To run evaluation, use the '--train=False' flag and the '--test_set=True' flag to use the test set. 6 | """ 7 | 8 | import csv 9 | import numpy as np 10 | import pickle 11 | import random 12 | import tensorflow as tf 13 | 14 | from data_generator import DataGenerator 15 | from poly_generator import PolyDataGenerator 16 | from maml import MAML 17 | from tensorflow.python.platform import flags 18 | 19 | FLAGS = flags.FLAGS 20 | 21 | ## Dataset/method options 22 | flags.DEFINE_string('datasource', 'sinusoid', 'sinusoid or omniglot or miniimagenet') 23 | flags.DEFINE_integer('num_classes', 5, 'number of classes used in classification (e.g. 
5-way classification).')
24 | flags.DEFINE_integer('num_train_classes', -1, 'number of classes to train on (-1 for all).')
25 | # oracle means task id is input (only suitable for sinusoid)
26 | flags.DEFINE_string('baseline', None, 'oracle, or None')
27 |
28 | ## Training options
29 | flags.DEFINE_integer('pretrain_iterations', 0, 'number of pre-training iterations.')
30 | flags.DEFINE_integer('metatrain_iterations', 15000, 'number of metatraining iterations.')  # 15k for omniglot, 50k for sinusoid
31 | flags.DEFINE_integer('meta_batch_size', 25, 'number of tasks sampled per meta-update')
32 | flags.DEFINE_float('meta_lr', 0.001, 'the base learning rate of the generator')
33 | flags.DEFINE_integer('update_batch_size', 5, 'number of examples used for inner gradient update (K for K-shot learning).')
34 | flags.DEFINE_float('update_lr', 1e-3, 'step size alpha for inner gradient update.')  # 0.1 for omniglot
35 | flags.DEFINE_integer('num_updates', 1, 'number of inner gradient updates during training.')
36 | flags.DEFINE_integer('poly_order', 1, 'order of polynomial to generate')
37 |
38 | ## Model options
39 | #flags.DEFINE_string('mod', '', 'modifications to original paper. None, split, both')
40 | flags.DEFINE_bool('use_T', False, 'whether or not to use transformation matrix T')
41 | flags.DEFINE_bool('use_M', False, 'whether or not to use mask M')
42 | flags.DEFINE_bool('share_M', False, 'only effective if use_M is true, whether or not to '
43 |                                     'share masks between weights '
44 |                                     'that contribute to the same activation')
45 | flags.DEFINE_float('temp', 1, 'temperature for gumbel-softmax')
46 | flags.DEFINE_float('logit_init', 0, 'initial logit')
47 | flags.DEFINE_string('norm', 'batch_norm', 'batch_norm, layer_norm, or None')
48 | flags.DEFINE_integer('dim_hidden', 40, 'dimension of fc layer')
49 | flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- use 32 for '
50 |                                         'miniimagenet, 64 for omniglot.')
51 | flags.DEFINE_bool('conv', True, 'whether or not to use a convolutional network, only applicable in some cases')
52 | flags.DEFINE_bool('max_pool', False, 'Whether or not to use max pooling rather than strided convolutions')
53 | flags.DEFINE_bool('stop_grad', False, 'if True, do not use second derivatives in meta-optimization (for speed)')
54 |
55 | ## Logging, saving, and testing options
56 | flags.DEFINE_bool('log', True, 'if false, do not log summaries, for debugging code.')
57 | flags.DEFINE_string('logdir', '/tmp/data', 'directory for summaries and checkpoints.')
58 | flags.DEFINE_bool('debug', False, 'debug mode. uses less data for fast evaluation.')
59 | flags.DEFINE_bool('resume', True, 'resume training if there is a model available')
60 | flags.DEFINE_bool('train', True, 'True to train, False to test.')
61 | flags.DEFINE_integer('test_iter', -1, 'iteration to load model (-1 for latest model)')
62 | flags.DEFINE_bool('test_set', False, 'Set to true to test on the test set, False for the validation set.')
63 | flags.DEFINE_integer('train_update_batch_size', -1, 'number of examples used for gradient update during training (use if you want to test with a different number).')
64 | flags.DEFINE_float('train_update_lr', -1, 'value of inner gradient step during training.
(use if you want to test with a different value)') # 0.1 for omniglot 65 | 66 | 67 | def train(model, saver, sess, exp_string, data_generator, resume_itr=0): 68 | SUMMARY_INTERVAL = 100 69 | SAVE_INTERVAL = 1000 70 | if FLAGS.debug: 71 | SUMMARY_INTERVAL = PRINT_INTERVAL = 10 72 | TEST_PRINT_INTERVAL = PRINT_INTERVAL*5 73 | elif FLAGS.datasource in ['sinusoid', 'polynomial']: 74 | PRINT_INTERVAL = 1000 75 | TEST_PRINT_INTERVAL = PRINT_INTERVAL*5 76 | else: 77 | PRINT_INTERVAL = 100 78 | TEST_PRINT_INTERVAL = PRINT_INTERVAL*5 79 | 80 | if FLAGS.log: 81 | train_writer = tf.summary.FileWriter(FLAGS.logdir + '/' + exp_string, sess.graph) 82 | print('Done initializing, starting training.') 83 | prelosses, postlosses = [], [] 84 | 85 | num_classes = data_generator.num_classes # for classification, 1 otherwise 86 | multitask_weights, reg_weights = [], [] 87 | 88 | for itr in range(resume_itr, FLAGS.pretrain_iterations + FLAGS.metatrain_iterations): 89 | feed_dict = {} 90 | if FLAGS.datasource == 'sinusoid': 91 | batch_x, batch_y, amp, phase = data_generator.generate() 92 | 93 | if FLAGS.baseline == 'oracle': 94 | batch_x = np.concatenate([batch_x, np.zeros([batch_x.shape[0], batch_x.shape[1], 2])], 2) 95 | for i in range(FLAGS.meta_batch_size): 96 | batch_x[i, :, 1] = amp[i] 97 | batch_x[i, :, 2] = phase[i] 98 | 99 | inputa = batch_x[:, :num_classes*FLAGS.update_batch_size, :] 100 | labela = batch_y[:, :num_classes*FLAGS.update_batch_size, :] 101 | inputb = batch_x[:, num_classes*FLAGS.update_batch_size:, :] # b used for testing 102 | labelb = batch_y[:, num_classes*FLAGS.update_batch_size:, :] 103 | feed_dict = {model.inputa: inputa, model.inputb: inputb, model.labela: labela, model.labelb: labelb} 104 | 105 | elif FLAGS.datasource == 'polynomial': 106 | batch_x, batch_y = data_generator.generate() 107 | inputa = batch_x[:, :num_classes*FLAGS.update_batch_size, :] 108 | labela = batch_y[:, :num_classes*FLAGS.update_batch_size, :] 109 | inputb = batch_x[:, num_classes*FLAGS.update_batch_size:, :] # b used for testing 110 | labelb = batch_y[:, num_classes*FLAGS.update_batch_size:, :] 111 | feed_dict = {model.inputa: inputa, model.inputb: inputb, model.labela: labela, model.labelb: labelb} 112 | 113 | 114 | if itr < FLAGS.pretrain_iterations: 115 | input_tensors = [model.pretrain_op] 116 | else: 117 | input_tensors = [model.metatrain_op] 118 | 119 | if itr % SUMMARY_INTERVAL == 0 or itr % PRINT_INTERVAL == 0: 120 | input_tensors.extend([model.summ_op, model.total_loss1, 121 | model.total_losses2[FLAGS.num_updates-1]]) 122 | if model.classification: 123 | input_tensors.extend([model.total_accuracy1, model.total_accuracies2[FLAGS.num_updates-1]]) 124 | 125 | result = sess.run(input_tensors, feed_dict) 126 | 127 | if itr % SUMMARY_INTERVAL == 0: 128 | prelosses.append(result[-2]) 129 | if FLAGS.log: 130 | train_writer.add_summary(result[1], itr) 131 | postlosses.append(result[-1]) 132 | 133 | if itr != 0 and itr % PRINT_INTERVAL == 0: 134 | if itr < FLAGS.pretrain_iterations: 135 | print_str = 'Pretrain Iteration ' + str(itr) 136 | else: 137 | print_str = 'Iteration ' + str(itr - FLAGS.pretrain_iterations) 138 | print_str += ': ' + str(np.mean(prelosses)) + ', ' + str(np.mean(postlosses)) 139 | print(print_str) 140 | #print sess.run(model.total_probs) 141 | prelosses, postlosses = [], [] 142 | 143 | if itr != 0 and itr % SAVE_INTERVAL == 0: 144 | saver.save(sess, FLAGS.logdir + '/' + exp_string + '/model' + str(itr)) 145 | 146 | # sinusoid is infinite data, so no need to test on meta-validation set. 
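        # Periodic meta-validation below: when the data generator has no generate() method
        # (Omniglot / miniImagenet), the model's pre-built metaval_* tensors are evaluated;
        # otherwise a fresh batch is generated and fed with meta_lr=0.0 so that the
        # evaluation never triggers a meta-update.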
147 | if itr != 0 and itr % TEST_PRINT_INTERVAL == 0 and FLAGS.datasource not in ['sinusoid', 'polynomial']: 148 | if 'generate' not in dir(data_generator): 149 | feed_dict = {} 150 | if model.classification: 151 | input_tensors = [model.metaval_total_accuracy1, 152 | model.metaval_total_accuracies2[FLAGS.num_updates-1], model.summ_op] 153 | else: 154 | input_tensors = [model.metaval_total_loss1, 155 | model.metaval_total_losses2[FLAGS.num_updates-1], model.summ_op] 156 | else: 157 | batch_x, batch_y, amp, phase = data_generator.generate(train=False) 158 | inputa = batch_x[:, :num_classes*FLAGS.update_batch_size, :] 159 | inputb = batch_x[:, num_classes*FLAGS.update_batch_size:, :] 160 | labela = batch_y[:, :num_classes*FLAGS.update_batch_size, :] 161 | labelb = batch_y[:, num_classes*FLAGS.update_batch_size:, :] 162 | feed_dict = {model.inputa: inputa, model.inputb: inputb, 163 | model.labela: labela, model.labelb: labelb, model.meta_lr: 0.0} 164 | if model.classification: 165 | input_tensors = [model.total_accuracy1, model.total_accuracies2[FLAGS.num_updates-1]] 166 | else: 167 | input_tensors = [model.total_loss1, model.total_losses2[FLAGS.num_updates-1]] 168 | 169 | result = sess.run(input_tensors, feed_dict) 170 | print('Validation results: ' + str(result[0]) + ', ' + str(result[1])) 171 | 172 | saver.save(sess, FLAGS.logdir + '/' + exp_string + '/model' + str(itr)) 173 | 174 | 175 | def test(model, saver, sess, exp_string, data_generator, test_num_updates=None): 176 | num_classes = data_generator.num_classes # for classification, 1 otherwise 177 | 178 | np.random.seed(1) 179 | random.seed(1) 180 | 181 | metaval_accuracies = [] 182 | 183 | if FLAGS.datasource == 'miniimagenet': 184 | NUM_TEST_POINTS = 4000 185 | elif FLAGS.datasource == 'polynomial': 186 | NUM_TEST_POINTS = 20 187 | else: 188 | NUM_TEST_POINTS = 600 189 | for point_n in range(NUM_TEST_POINTS): 190 | if 'generate' not in dir(data_generator): 191 | feed_dict = {model.meta_lr: 0.0} 192 | elif FLAGS.datasource == 'sinusoid': 193 | batch_x, batch_y, amp, phase = data_generator.generate(train=False) 194 | 195 | if FLAGS.baseline == 'oracle': # NOTE - this flag is specific to sinusoid 196 | batch_x = np.concatenate([batch_x, np.zeros([batch_x.shape[0], batch_x.shape[1], 2])], 2) 197 | batch_x[0, :, 1] = amp[0] 198 | batch_x[0, :, 2] = phase[0] 199 | 200 | inputa = batch_x[:, :num_classes*FLAGS.update_batch_size, :] 201 | inputb = batch_x[:, num_classes*FLAGS.update_batch_size:, :] 202 | labela = batch_y[:, :num_classes*FLAGS.update_batch_size, :] 203 | labelb = batch_y[:, num_classes*FLAGS.update_batch_size:, :] 204 | 205 | feed_dict = {model.inputa: inputa, model.inputb: inputb, 206 | model.labela: labela, model.labelb: labelb, model.meta_lr: 0.0} 207 | elif FLAGS.datasource == 'polynomial': 208 | batch_x, batch_y = data_generator.generate() 209 | inputa = batch_x[:, :num_classes*FLAGS.update_batch_size, :] 210 | inputb = batch_x[:, num_classes*FLAGS.update_batch_size:, :] 211 | labela = batch_y[:, :num_classes*FLAGS.update_batch_size, :] 212 | labelb = batch_y[:, num_classes*FLAGS.update_batch_size:, :] 213 | feed_dict = {model.inputa: inputa, model.inputb: inputb, 214 | model.labela: labela, model.labelb: labelb, model.meta_lr: 0.0} 215 | 216 | ########## plotting code 217 | import matplotlib.pyplot as plt 218 | from matplotlib import rc 219 | import matplotlib 220 | matplotlib.rcParams.update({'font.size': 25}) 221 | fig, ax = plt.subplots() 222 | fig.set_size_inches(15, 10) 223 | plt.plot(inputa.flatten(), 
labela.flatten(), 'ro') 224 | plt.plot(inputb.flatten(), labelb.flatten(), 'r,') 225 | outputbs = sess.run(model.outputbs, feed_dict) 226 | plt.plot(inputb.flatten(), outputbs[0].flatten(), color='#bfbfbf', marker=',', linestyle='None') 227 | plt.plot(inputb.flatten(), outputbs[1].flatten(), color='#666666', marker=',', linestyle='None') 228 | plt.plot(inputb.flatten(), outputbs[9].flatten(), color='#000000', marker=',', linestyle='None') 229 | plt.title('Polynomial order ' + str(FLAGS.poly_order)) 230 | plt.legend() 231 | axes = plt.gca() 232 | axes.set_xlim([-2, 2]) 233 | axes.set_ylim([-5.1, 5.1]) 234 | plt.savefig(FLAGS.logdir + '/' + exp_string + '/' + str(point_n) + '.png') 235 | #plt.savefig(str(point_n) + '.png') 236 | plt.cla() 237 | 238 | if model.classification: 239 | result = sess.run([model.metaval_total_accuracy1] + model.metaval_total_accuracies2, feed_dict) 240 | else: 241 | result = sess.run([model.total_loss1] + model.total_losses2, feed_dict) 242 | metaval_accuracies.append(result) 243 | 244 | metaval_accuracies = np.array(metaval_accuracies) 245 | means = np.mean(metaval_accuracies, 0) 246 | stds = np.std(metaval_accuracies, 0) 247 | ci95 = 1.96*stds/np.sqrt(NUM_TEST_POINTS) 248 | 249 | print('Mean validation accuracy/loss, stddev, and confidence intervals') 250 | print((means, stds, ci95)) 251 | filename = FLAGS.logdir + '/' + exp_string + '/' + 'test_ubs' + str(FLAGS.update_batch_size) + \ 252 | '_stepsize' + str(FLAGS.update_lr) + '_testiter' + str(FLAGS.test_iter) 253 | with open(filename + '.pkl', 'w') as f: 254 | pickle.dump({'mses': metaval_accuracies}, f) 255 | with open(filename + '.csv', 'w') as f: 256 | writer = csv.writer(f, delimiter=',') 257 | writer.writerow(['update'+str(i) for i in range(len(means))]) 258 | writer.writerow(means) 259 | writer.writerow(stds) 260 | writer.writerow(ci95) 261 | 262 | 263 | def main(): 264 | if FLAGS.datasource in ['sinusoid', 'polynomial']: 265 | if FLAGS.train: 266 | test_num_updates = 5 267 | else: 268 | test_num_updates = 10 269 | elif FLAGS.datasource == 'miniimagenet': 270 | if FLAGS.train: 271 | test_num_updates = 1 # eval on at least one update during training 272 | else: 273 | test_num_updates = 10 274 | else: 275 | test_num_updates = 10 276 | 277 | if not FLAGS.train: 278 | orig_meta_batch_size = FLAGS.meta_batch_size 279 | # always use meta batch size of 1 when testing. 280 | FLAGS.meta_batch_size = 1 281 | 282 | if FLAGS.datasource == 'sinusoid': 283 | #data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size) 284 | # Use 10 val samples (meta-SGD, 4.1 paragraph 2 first line) 285 | data_generator = DataGenerator(FLAGS.update_batch_size+10, FLAGS.meta_batch_size) 286 | elif FLAGS.datasource == 'polynomial': 287 | if FLAGS.train: 288 | data_generator = PolyDataGenerator(FLAGS.update_batch_size+10, FLAGS.meta_batch_size) 289 | else: 290 | data_generator = PolyDataGenerator(4000, FLAGS.meta_batch_size) 291 | elif FLAGS.metatrain_iterations == 0 and FLAGS.datasource == 'miniimagenet': 292 | assert FLAGS.meta_batch_size == 1 293 | assert FLAGS.update_batch_size == 1 294 | data_generator = DataGenerator(1, FLAGS.meta_batch_size) # only use one datapoint, 295 | elif FLAGS.datasource == 'miniimagenet': # TODO - use 15 val examples for imagenet? 
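        # Each miniImagenet task draws update_batch_size (K) support examples per class plus
        # 15 query examples per class during meta-training; at test time 2*K examples per
        # class are drawn instead.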
296 | if FLAGS.train: 297 | data_generator = DataGenerator(FLAGS.update_batch_size+15, FLAGS.meta_batch_size) # only use one datapoint for testing to save memory 298 | else: 299 | data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size) # only use one datapoint for testing to save memory 300 | else: 301 | assert FLAGS.datasource == 'omniglot' 302 | data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size) # only use one datapoint for testing to save memory 303 | 304 | dim_output = data_generator.dim_output 305 | if FLAGS.baseline == 'oracle': 306 | assert FLAGS.datasource == 'sinusoid' 307 | dim_input = 3 308 | FLAGS.pretrain_iterations += FLAGS.metatrain_iterations 309 | FLAGS.metatrain_iterations = 0 310 | else: 311 | dim_input = data_generator.dim_input 312 | 313 | if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'omniglot': 314 | tf_data_load = True 315 | num_classes = data_generator.num_classes 316 | 317 | if FLAGS.train: # only construct training model if needed 318 | random.seed(5) 319 | image_tensor, label_tensor = data_generator.make_data_tensor() 320 | inputa = tf.slice(image_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1]) 321 | inputb = tf.slice(image_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1]) 322 | labela = tf.slice(label_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1]) 323 | labelb = tf.slice(label_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1]) 324 | input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb} 325 | 326 | random.seed(6) 327 | image_tensor, label_tensor = data_generator.make_data_tensor(train=False) 328 | inputa = tf.slice(image_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1]) 329 | inputb = tf.slice(image_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1]) 330 | labela = tf.slice(label_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1]) 331 | labelb = tf.slice(label_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1]) 332 | metaval_input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb} 333 | else: 334 | input_tensors = None 335 | tf_data_load = False 336 | 337 | model = MAML(dim_input, dim_output, test_num_updates=test_num_updates) 338 | if FLAGS.train or not tf_data_load: 339 | model.construct_model(input_tensors=input_tensors, prefix='metatrain_') 340 | if tf_data_load: 341 | model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_') 342 | model.summ_op = tf.summary.merge_all() 343 | 344 | saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=3) 345 | 346 | sess = tf.InteractiveSession() 347 | 348 | if not FLAGS.train: 349 | # change to original meta batch size when loading model. 350 | FLAGS.meta_batch_size = orig_meta_batch_size 351 | 352 | if FLAGS.train_update_batch_size == -1: 353 | FLAGS.train_update_batch_size = FLAGS.update_batch_size 354 | if FLAGS.train_update_lr == -1: 355 | FLAGS.train_update_lr = FLAGS.update_lr 356 | 357 | exp_string = 'cls_'+str(FLAGS.num_classes)+\ 358 | '.mbs_'+str(FLAGS.meta_batch_size) + \ 359 | '.ubs_' + str(FLAGS.train_update_batch_size) + \ 360 | '.numstep' + str(FLAGS.num_updates) + \ 361 | '.updatelr' + str(FLAGS.train_update_lr) + \ 362 | '.temp' + str(FLAGS.temp) 363 | 364 | if FLAGS.debug: 365 | exp_string += '!DEBUG!' 
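    # The suffixes appended below record the model variant implied by use_T/use_M/share_M
    # (MT-net, T-net, M-net, or plain MAML); exp_string names the checkpoint and summary
    # directory under FLAGS.logdir, so differently-configured runs do not overwrite each other.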
366 | 367 | if FLAGS.use_T and FLAGS.use_M and FLAGS.share_M: 368 | exp_string += 'MTnet' 369 | if FLAGS.use_T and not FLAGS.use_M: 370 | exp_string += 'Tnet' 371 | if not FLAGS.use_T and FLAGS.use_M and FLAGS.share_M: 372 | exp_string += 'Mnet' 373 | if FLAGS.use_T and FLAGS.use_M and not FLAGS.share_M: 374 | exp_string += 'MTnet_noshare' 375 | if not FLAGS.use_T and FLAGS.use_M and not FLAGS.share_M: 376 | exp_string += 'Mnet_noshare' 377 | if not FLAGS.use_T and not FLAGS.use_M: 378 | exp_string += 'MAML' 379 | 380 | if FLAGS.datasource == 'polynomial': 381 | exp_string += 'ord' + str(FLAGS.poly_order) 382 | if FLAGS.num_train_classes != -1: 383 | exp_string += 'ntc' + str(FLAGS.num_train_classes) 384 | if FLAGS.num_filters != 64: 385 | exp_string += 'hidden' + str(FLAGS.num_filters) 386 | if FLAGS.max_pool: 387 | exp_string += 'maxpool' 388 | if FLAGS.stop_grad: 389 | exp_string += 'stopgrad' 390 | if FLAGS.baseline: 391 | exp_string += FLAGS.baseline 392 | if FLAGS.norm == 'batch_norm': 393 | exp_string += 'batchnorm' 394 | elif FLAGS.norm == 'layer_norm': 395 | exp_string += 'layernorm' 396 | elif FLAGS.norm == 'None': 397 | exp_string += 'nonorm' 398 | else: 399 | print('Norm setting not recognized.') 400 | 401 | resume_itr = 0 402 | tf.global_variables_initializer().run() 403 | tf.train.start_queue_runners() 404 | 405 | if FLAGS.resume or not FLAGS.train: 406 | model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string) 407 | if FLAGS.test_iter > 0: 408 | model_file = model_file[:model_file.index('model')] + 'model' + str(FLAGS.test_iter) 409 | if model_file: 410 | ind1 = model_file.index('model') 411 | resume_itr = int(model_file[ind1+5:]) 412 | print("Restoring model weights from " + model_file) 413 | saver.restore(sess, model_file) 414 | 415 | print flags.FLAGS.__flags 416 | print exp_string 417 | 418 | if FLAGS.train: 419 | train(model, saver, sess, exp_string, data_generator, resume_itr) 420 | else: 421 | test(model, saver, sess, exp_string, data_generator, test_num_updates) 422 | 423 | 424 | if __name__ == "__main__": 425 | main() 426 | -------------------------------------------------------------------------------- /maml.py: -------------------------------------------------------------------------------- 1 | """ Code for the MAML algorithm and network definitions. """ 2 | import numpy as np 3 | 4 | try: 5 | import special_grads 6 | except KeyError as e: 7 | print 'WARNING: Cannot define MaxPoolGrad, likely already defined for this version of TensorFlow:', e 8 | import tensorflow as tf 9 | 10 | from tensorflow.python.platform import flags 11 | from utils import mse, xent, conv_block, normalize 12 | 13 | FLAGS = flags.FLAGS 14 | 15 | 16 | class MAML: 17 | def __init__(self, dim_input=1, dim_output=1, test_num_updates=5): 18 | """ must call construct_model() after initializing MAML! 
""" 19 | self.dim_input = dim_input 20 | self.dim_output = dim_output 21 | self.update_lr = FLAGS.update_lr 22 | self.meta_lr = tf.placeholder_with_default(FLAGS.meta_lr, ()) 23 | self.classification = False 24 | self.test_num_updates = test_num_updates 25 | if FLAGS.datasource in ['sinusoid', 'polynomial']: 26 | self.dim_hidden = [FLAGS.dim_hidden, FLAGS.dim_hidden] 27 | if FLAGS.use_T: 28 | self.forward = self.forward_fc_withT 29 | else: 30 | self.forward = self.forward_fc 31 | self.construct_weights = self.construct_fc_weights 32 | self.loss_func = mse 33 | elif FLAGS.datasource == 'omniglot' or FLAGS.datasource == 'miniimagenet': 34 | self.loss_func = xent 35 | self.classification = True 36 | if FLAGS.conv: 37 | self.dim_hidden = FLAGS.num_filters 38 | if FLAGS.use_T: 39 | self.forward = self.forward_conv_withT 40 | else: 41 | self.forward = self.forward_conv 42 | self.construct_weights = self.construct_conv_weights 43 | else: 44 | self.dim_hidden = [256, 128, 64, 64] 45 | self.forward = self.forward_fc 46 | self.construct_weights = self.construct_fc_weights 47 | if FLAGS.datasource == 'miniimagenet': 48 | self.channels = 3 49 | else: 50 | self.channels = 1 51 | self.img_size = int(np.sqrt(self.dim_input / self.channels)) 52 | else: 53 | raise ValueError('Unrecognized data source.') 54 | 55 | def construct_model(self, input_tensors=None, prefix='metatrain_'): 56 | # a: training data for inner gradient, b: test data for meta gradient 57 | if input_tensors is None: 58 | if 'inputa' not in dir(self): 59 | self.inputa = tf.placeholder(tf.float32) 60 | self.inputb = tf.placeholder(tf.float32) 61 | self.labela = tf.placeholder(tf.float32) 62 | self.labelb = tf.placeholder(tf.float32) 63 | else: 64 | self.inputa = input_tensors['inputa'] 65 | self.inputb = input_tensors['inputb'] 66 | self.labela = input_tensors['labela'] 67 | self.labelb = input_tensors['labelb'] 68 | 69 | with tf.variable_scope('model', reuse=None) as training_scope: 70 | self.dropout_probs = {} 71 | if 'weights' in dir(self): 72 | training_scope.reuse_variables() 73 | weights = self.weights 74 | else: 75 | # Define the weights 76 | self.weights = weights = self.construct_weights() 77 | 78 | # outputbs[i] and lossesb[i] is the output and loss after i+1 gradient updates 79 | lossesa, outputas, lossesb, outputbs = [], [], [], [] 80 | accuraciesa, accuraciesb = [], [] 81 | num_updates = max(self.test_num_updates, FLAGS.num_updates) 82 | outputbs = [[]] * num_updates 83 | lossesb = [[]] * num_updates 84 | accuraciesb = [[]] * num_updates 85 | 86 | def task_metalearn(inp, reuse=True): 87 | """ Perform gradient descent for one task in the meta-batch. 
""" 88 | inputa, inputb, labela, labelb = inp 89 | task_outputbs, task_lossesb = [], [] 90 | mse_lossesb = [] 91 | 92 | if self.classification: 93 | task_accuraciesb = [] 94 | 95 | train_keys = list(weights.keys()) 96 | if FLAGS.use_M and FLAGS.share_M: 97 | def make_shared_mask(key): 98 | temperature = FLAGS.temp 99 | logits = weights[key+'_prob'] 100 | logits = tf.stack([logits, tf.zeros(logits.shape)], 1) 101 | U = tf.random_uniform(logits.shape, minval=0, maxval=1) 102 | gumbel = -tf.log(-tf.log(U + 1e-20) + 1e-20) 103 | y = logits + gumbel 104 | gumbel_softmax = tf.nn.softmax(y / temperature) 105 | gumbel_hard = tf.cast(tf.equal(gumbel_softmax, tf.reduce_max(gumbel_softmax, 1, keep_dims=True)), tf.float32) 106 | mask = tf.stop_gradient(gumbel_hard - gumbel_softmax) + gumbel_softmax 107 | return mask[:, 0] 108 | 109 | def get_mask(masks, name): 110 | mask = masks[[k for k in masks.keys() if name[-1] in k][0]] 111 | if 'conv' in name: # Conv 112 | mask = tf.reshape(mask, [1, 1, 1, -1]) 113 | tile_size = weights[name].shape.as_list()[:3] + [1] 114 | mask = tf.tile(mask, tile_size) 115 | elif 'w' in name: # FC 116 | mask = tf.reshape(mask, [1, -1]) 117 | tile_size = weights[name].shape.as_list()[:1] + [1] 118 | mask = tf.tile(mask, tile_size) 119 | elif 'b' in name: # Bias 120 | mask = tf.reshape(mask, [-1]) 121 | return mask 122 | if self.classification: 123 | masks = {k: make_shared_mask(k) for k in ['conv1', 'conv2', 'conv3', 'conv4', 'w5']} 124 | else: 125 | masks = {k: make_shared_mask(k) for k in ['w1', 'w2', 'w3']} 126 | 127 | if FLAGS.use_M and not FLAGS.share_M: 128 | def get_mask_noshare(key): 129 | temperature = FLAGS.temp 130 | logits = weights[key + '_prob'] 131 | logits = tf.stack([logits, tf.zeros(logits.shape)], 1) 132 | U = tf.random_uniform(logits.shape, minval=0, maxval=1) 133 | gumbel = -tf.log(-tf.log(U + 1e-20) + 1e-20) 134 | y = logits + gumbel 135 | gumbel_softmax = tf.nn.softmax(y / temperature) 136 | gumbel_hard = tf.cast(tf.equal(gumbel_softmax, tf.reduce_max(gumbel_softmax, 1, keep_dims=True)), tf.float32) 137 | out = tf.stop_gradient(gumbel_hard - gumbel_softmax) + gumbel_softmax 138 | return tf.reshape(out[:, 0], weights[key].shape) 139 | 140 | train_keys = [k for k in weights.keys() if 'prob' not in k and 'f' not in k] 141 | train_weights = [weights[k] for k in train_keys] 142 | task_outputa = self.forward(inputa, weights, reuse=reuse) # only reuse on the first iter 143 | self.task_outputa = task_outputa 144 | task_lossa = self.loss_func(task_outputa, labela) 145 | grads = tf.gradients(task_lossa, train_weights) 146 | if FLAGS.stop_grad: 147 | grads = [tf.stop_gradient(grad) for grad in grads] 148 | gradients = dict(zip(train_keys, grads)) 149 | 150 | fast_weights = dict(zip(weights.keys(), [weights[key] for key in weights.keys()])) 151 | 152 | def compute_weights(key): 153 | prev_weights = fast_weights[key] 154 | if key not in train_keys: 155 | return prev_weights 156 | if FLAGS.use_M and FLAGS.share_M: 157 | mask = get_mask(masks, key) 158 | new_weights = prev_weights - self.update_lr * mask * gradients[key] 159 | elif FLAGS.use_M and not FLAGS.share_M: 160 | mask = get_mask_noshare(key) 161 | new_weights = prev_weights - self.update_lr * mask * gradients[key] 162 | else: 163 | new_weights = prev_weights - self.update_lr * gradients[key] 164 | return new_weights 165 | 166 | fast_weights = dict(zip( 167 | weights.keys(), [compute_weights(key) for key in weights.keys()])) 168 | 169 | output = self.forward(inputb, fast_weights, reuse=True) 170 | 
task_outputbs.append(output) 171 | loss = self.loss_func(output, labelb) 172 | task_lossesb.append(loss) 173 | 174 | for j in range(num_updates - 1): 175 | output = self.forward(inputa, fast_weights, reuse=True) 176 | loss = self.loss_func(output, labela) 177 | train_weights = [fast_weights[k] for k in train_keys] 178 | grads = tf.gradients(loss, train_weights) 179 | if FLAGS.stop_grad: 180 | grads = [tf.stop_gradient(grad) for grad in grads] 181 | gradients = dict(zip(train_keys, grads)) 182 | 183 | fast_weights = dict(zip( 184 | weights.keys(), [compute_weights(key) for key in weights.keys()])) 185 | 186 | output = self.forward(inputb, fast_weights, reuse=True) 187 | task_outputbs.append(output) 188 | loss = self.loss_func(output, labelb) 189 | task_lossesb.append(loss) 190 | 191 | task_output = [task_outputa, task_outputbs, task_lossa, task_lossesb] 192 | 193 | if self.classification: 194 | task_accuracya = tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(task_outputa), 1), 195 | tf.argmax(labela, 1)) 196 | for j in range(num_updates): 197 | task_accuraciesb.append( 198 | tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(task_outputbs[j]), 1), 199 | tf.argmax(labelb, 1))) 200 | task_output.extend([task_accuracya, task_accuraciesb]) 201 | 202 | return task_output 203 | 204 | if FLAGS.norm is not 'None': 205 | # to initialize the batch norm vars, might want to combine this, and not run idx 0 twice. 206 | unused = task_metalearn((self.inputa[0], self.inputb[0], self.labela[0], self.labelb[0]), False) 207 | 208 | out_dtype = [tf.float32, [tf.float32] * num_updates, tf.float32, [tf.float32] * num_updates] 209 | if self.classification: 210 | out_dtype.extend([tf.float32, [tf.float32] * num_updates]) 211 | result = tf.map_fn(task_metalearn, elems=(self.inputa, self.inputb, self.labela, self.labelb), 212 | dtype=out_dtype, parallel_iterations=FLAGS.meta_batch_size) 213 | if self.classification: 214 | outputas, outputbs, lossesa, lossesb, accuraciesa, accuraciesb = result 215 | else: 216 | outputas, outputbs, lossesa, lossesb = result 217 | 218 | logit_keys = sorted([k for k in weights.keys() if 'prob' in k]) 219 | logit_weights = [-weights[k] for k in logit_keys] 220 | probs = [tf.exp(w) / (1 + tf.exp(w)) for w in logit_weights] 221 | self.total_probs = [tf.reduce_mean(p) for p in probs] 222 | 223 | ## Performance & Optimization 224 | if 'train' in prefix: 225 | self.total_loss1 = total_loss1 = tf.reduce_sum(lossesa) / tf.to_float(FLAGS.meta_batch_size) 226 | self.total_losses2 = total_losses2 = [tf.reduce_sum(lossesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j 227 | in range(num_updates)] 228 | # after the map_fn 229 | self.outputas, self.outputbs = outputas, outputbs 230 | if self.classification: 231 | self.total_accuracy1 = total_accuracy1 = tf.reduce_sum(accuraciesa) / tf.to_float(FLAGS.meta_batch_size) 232 | self.total_accuracies2 = total_accuracies2 = [ 233 | tf.reduce_sum(accuraciesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)] 234 | self.pretrain_op = tf.train.AdamOptimizer(self.meta_lr).minimize( total_loss1) 235 | 236 | if FLAGS.metatrain_iterations > 0: 237 | optimizer = tf.train.AdamOptimizer(self.meta_lr) 238 | loss = self.total_losses2[FLAGS.num_updates - 1] 239 | self.gvs = gvs = optimizer.compute_gradients(loss) 240 | if FLAGS.datasource == 'miniimagenet': 241 | gvs = [(tf.clip_by_value(grad, -10, 10), var) for grad, var in gvs] 242 | self.metatrain_op = optimizer.apply_gradients(gvs) 243 | 244 | else: 245 | self.metaval_total_loss1 = total_loss1 
= tf.reduce_sum(lossesa) / tf.to_float(FLAGS.meta_batch_size) 246 | self.metaval_total_losses2 = total_losses2 = [tf.reduce_sum(lossesb[j]) / tf.to_float(FLAGS.meta_batch_size) 247 | for j in range(num_updates)] 248 | if self.classification: 249 | self.metaval_total_accuracy1 = total_accuracy1 = tf.reduce_sum(accuraciesa) / tf.to_float( 250 | FLAGS.meta_batch_size) 251 | self.metaval_total_accuracies2 = total_accuracies2 = [ 252 | tf.reduce_sum(accuraciesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)] 253 | 254 | ## Summaries 255 | tf.summary.scalar(prefix + 'change probs', tf.reduce_mean(self.total_probs)) 256 | tf.summary.scalar(prefix + 'Pre-update loss', total_loss1) 257 | if self.classification: 258 | tf.summary.scalar(prefix + 'Pre-update accuracy', total_accuracy1) 259 | 260 | for j in range(num_updates): 261 | tf.summary.scalar(prefix + 'Post-update loss, step ' + str(j + 1), total_losses2[j]) 262 | if self.classification: 263 | tf.summary.scalar(prefix + 'Post-update accuracy, step ' + str(j + 1), total_accuracies2[j]) 264 | 265 | for k, v in weights.iteritems(): 266 | tf.summary.histogram(k, v) 267 | if 'prob' in k: 268 | tf.summary.histogram('prob_'+k, tf.nn.softmax(tf.stack([v, tf.zeros(v.shape)], 1))[:, 0]) 269 | 270 | ### Network construction functions (fc networks and conv networks) 271 | def construct_fc_weights(self): 272 | weights = {} 273 | weights['w1'] = tf.Variable(tf.truncated_normal([self.dim_input, self.dim_hidden[0]], stddev=0.01)) 274 | weights['b1'] = tf.Variable(tf.zeros([self.dim_hidden[0]])) 275 | for i in range(1, len(self.dim_hidden)): 276 | weights['w' + str(i + 1)] = tf.Variable( 277 | tf.truncated_normal([self.dim_hidden[i - 1], self.dim_hidden[i]], stddev=0.01)) 278 | weights['b' + str(i + 1)] = tf.Variable(tf.zeros([self.dim_hidden[i]])) 279 | weights['w' + str(len(self.dim_hidden) + 1)] = tf.Variable( 280 | tf.truncated_normal([self.dim_hidden[-1], self.dim_output], stddev=0.01)) 281 | weights['b' + str(len(self.dim_hidden) + 1)] = tf.Variable(tf.zeros([self.dim_output])) 282 | 283 | if FLAGS.use_M and not FLAGS.share_M: 284 | weights['w1_prob'] = tf.Variable(tf.truncated_normal([self.dim_input * self.dim_hidden[0]], stddev=.1)) 285 | weights['b1_prob'] = tf.Variable(tf.truncated_normal([self.dim_hidden[0]], stddev=.1)) 286 | for i in range(1, len(self.dim_hidden)): 287 | weights['w' + str(i + 1) + '_prob'] = tf.Variable( 288 | tf.truncated_normal([self.dim_hidden[i - 1] * self.dim_hidden[i]], stddev=.1)) 289 | weights['b' + str(i + 1) + '_prob'] = tf.Variable( 290 | tf.truncated_normal([self.dim_hidden[i]], stddev=.1)) 291 | weights['w' + str(len(self.dim_hidden) + 1) + '_prob'] = tf.Variable( 292 | tf.truncated_normal([self.dim_hidden[-1] * self.dim_output], stddev=0.1)) 293 | weights['b' + str(len(self.dim_hidden) + 1) + '_prob'] = tf.Variable( 294 | tf.truncated_normal([self.dim_output], stddev=.1)) 295 | elif FLAGS.use_M and FLAGS.share_M: 296 | weights['w1_prob'] = tf.Variable(FLAGS.logit_init * tf.ones([self.dim_hidden[0]])) 297 | for i in range(1, len(self.dim_hidden)): 298 | weights['w' + str(i + 1) + '_prob'] = tf.Variable( 299 | FLAGS.logit_init * tf.ones([self.dim_hidden[i]])) 300 | weights['w' + str(len(self.dim_hidden) + 1) + '_prob'] = tf.Variable( 301 | FLAGS.logit_init * tf.ones([self.dim_output])) 302 | 303 | if FLAGS.use_T: 304 | weights['w1_f'] = tf.Variable(tf.eye(self.dim_hidden[0])) 305 | weights['w2_f'] = tf.Variable(tf.eye(self.dim_hidden[1])) 306 | weights['w3_f'] = tf.Variable(tf.eye(self.dim_output)) 
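        # Note: the '_prob' variables hold the logits for the mask M (per activation when
        # share_M, per weight otherwise), sampled with a straight-through Gumbel-softmax in
        # task_metalearn, and the '_f' matrices implement the T transformation; both are
        # meta-learned only and are never changed by the inner update. Initializing the '_f'
        # matrices to the identity makes the initial forward pass of forward_fc_withT
        # identical to the plain forward_fc network.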
307 |         return weights
308 | 
309 |     def forward_fc(self, inp, weights, reuse=False):
310 |         hidden = normalize(tf.matmul(inp, weights['w1']) + weights['b1'],
311 |                            activation=tf.nn.relu, reuse=reuse, scope='0')
312 |         for i in range(1, len(self.dim_hidden)):
313 |             hidden = normalize(tf.matmul(hidden, weights['w' + str(i + 1)]) + weights['b' + str(i + 1)],
314 |                                activation=tf.nn.relu, reuse=reuse, scope=str(i + 1))
315 |         return tf.matmul(hidden, weights['w' + str(len(self.dim_hidden) + 1)]) + \
316 |                weights['b' + str(len(self.dim_hidden) + 1)]
317 | 
318 |     def forward_fc_withT(self, inp, weights, reuse=False):  # assumes exactly two hidden layers (w1..w3)
319 |         hidden = tf.matmul(tf.matmul(inp, weights['w1']) + weights['b1'], weights['w1_f'])
320 |         hidden = normalize(hidden, activation=tf.nn.relu, reuse=reuse, scope='1')
321 |         hidden = tf.matmul(tf.matmul(hidden, weights['w2']) + weights['b2'], weights['w2_f'])
322 |         hidden = normalize(hidden, activation=tf.nn.relu, reuse=reuse, scope='2')
323 |         hidden = tf.matmul(tf.matmul(hidden, weights['w3']) + weights['b3'], weights['w3_f'])
324 |         return hidden
325 | 
326 |     def construct_conv_weights(self):
327 |         weights = {}
328 |         dtype = tf.float32
329 |         conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype)
330 |         fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype)
331 |         k = 3
332 |         channels = self.channels
333 |         dim_hidden = self.dim_hidden
334 | 
335 |         def get_conv(name, shape):
336 |             return tf.get_variable(name, shape, initializer=conv_initializer, dtype=dtype)
337 | 
338 |         def get_identity(dim, conv=True):  # identity-initialized T matrix: 1x1 conv kernel or dense matrix
339 |             return tf.Variable(tf.eye(dim, batch_shape=[1,1])) if conv \
340 |                 else tf.Variable(tf.eye(dim))
341 | 
342 |         weights['conv1'] = get_conv('conv1', [k, k, channels, self.dim_hidden])
343 |         weights['b1'] = tf.Variable(tf.zeros([self.dim_hidden]))
344 |         weights['conv2'] = get_conv('conv2', [k, k, dim_hidden, self.dim_hidden])
345 |         weights['b2'] = tf.Variable(tf.zeros([self.dim_hidden]))
346 |         weights['conv3'] = get_conv('conv3', [k, k, dim_hidden, self.dim_hidden])
347 |         weights['b3'] = tf.Variable(tf.zeros([self.dim_hidden]))
348 |         weights['conv4'] = get_conv('conv4', [k, k, dim_hidden, self.dim_hidden])
349 |         weights['b4'] = tf.Variable(tf.zeros([self.dim_hidden]))
350 |         if FLAGS.datasource == 'miniimagenet':
351 |             # assumes max pooling
352 |             assert FLAGS.max_pool
353 |             weights['w5'] = tf.get_variable('w5', [self.dim_hidden * 5 * 5, self.dim_output],
354 |                                             initializer=fc_initializer)
355 |             weights['b5'] = tf.Variable(tf.zeros([self.dim_output]), name='b5')
356 | 
357 |             if FLAGS.use_M and not FLAGS.share_M:  # one mask logit per parameter
358 |                 weights['conv1_prob'] = tf.Variable(tf.truncated_normal([k * k * channels * self.dim_hidden], stddev=.01))
359 |                 weights['b1_prob'] = tf.Variable(tf.truncated_normal([self.dim_hidden], stddev=.01))
360 |                 weights['conv2_prob'] = tf.Variable(tf.truncated_normal([k * k * dim_hidden * self.dim_hidden], stddev=.01))
361 |                 weights['b2_prob'] = tf.Variable(tf.truncated_normal([self.dim_hidden], stddev=.01))
362 |                 weights['conv3_prob'] = tf.Variable(tf.truncated_normal([k * k * dim_hidden * self.dim_hidden], stddev=.01))
363 |                 weights['b3_prob'] = tf.Variable(tf.truncated_normal([self.dim_hidden], stddev=.01))
364 |                 weights['conv4_prob'] = tf.Variable(tf.truncated_normal([k * k * dim_hidden * self.dim_hidden], stddev=.01))
365 |                 weights['b4_prob'] = tf.Variable(tf.truncated_normal([self.dim_hidden], stddev=.01))
366 |                 weights['w5_prob'] = tf.Variable(tf.truncated_normal([dim_hidden * 5 * 5 * self.dim_output], stddev=.01))
367 |                 weights['b5_prob'] = tf.Variable(tf.truncated_normal([self.dim_output], stddev=.01))
368 |             if FLAGS.use_M and FLAGS.share_M:  # one mask logit per output channel/unit
369 |                 weights['conv1_prob'] = tf.Variable(FLAGS.logit_init * tf.ones([self.dim_hidden]))
370 |                 weights['conv2_prob'] = tf.Variable(FLAGS.logit_init * tf.ones([self.dim_hidden]))
371 |                 weights['conv3_prob'] = tf.Variable(FLAGS.logit_init * tf.ones([self.dim_hidden]))
372 |                 weights['conv4_prob'] = tf.Variable(FLAGS.logit_init * tf.ones([self.dim_hidden]))
373 |                 weights['w5_prob'] = tf.Variable(FLAGS.logit_init * tf.ones([self.dim_output]))
374 | 
375 |             if FLAGS.use_T:  # identity-initialized T transformations (T-net)
376 |                 weights['conv1_f'] = get_identity(self.dim_hidden, conv=True)
377 |                 weights['conv2_f'] = get_identity(self.dim_hidden, conv=True)
378 |                 weights['conv3_f'] = get_identity(self.dim_hidden, conv=True)
379 |                 weights['conv4_f'] = get_identity(self.dim_hidden, conv=True)
380 |                 weights['w5_f'] = get_identity(self.dim_output, conv=False)
381 |         else:  # e.g. Omniglot: features are globally average-pooled, so 'w5' is [dim_hidden, dim_output]
382 |             weights['w5'] = tf.Variable(tf.random_normal([dim_hidden, self.dim_output]), name='w5')
383 |             weights['b5'] = tf.Variable(tf.zeros([self.dim_output]), name='b5')
384 |             if FLAGS.use_M and not FLAGS.share_M:
385 |                 weights['conv1_prob'] = tf.Variable(tf.truncated_normal([k * k * channels * self.dim_hidden], stddev=.01))
386 |                 weights['conv2_prob'] = tf.Variable(tf.truncated_normal([k * k * dim_hidden * self.dim_hidden], stddev=.01))
387 |                 weights['conv3_prob'] = tf.Variable(tf.truncated_normal([k * k * dim_hidden * self.dim_hidden], stddev=.01))
388 |                 weights['conv4_prob'] = tf.Variable(tf.truncated_normal([k * k * dim_hidden * self.dim_hidden], stddev=.01))
389 |                 weights['w5_prob'] = tf.Variable(tf.truncated_normal([dim_hidden * self.dim_output], stddev=.01))
390 |             if FLAGS.use_M and FLAGS.share_M:
391 |                 weights['conv1_prob'] = tf.Variable(FLAGS.logit_init * tf.ones([self.dim_hidden]))
392 |                 weights['conv2_prob'] = tf.Variable(FLAGS.logit_init * tf.ones([self.dim_hidden]))
393 |                 weights['conv3_prob'] = tf.Variable(FLAGS.logit_init * tf.ones([self.dim_hidden]))
394 |                 weights['conv4_prob'] = tf.Variable(FLAGS.logit_init * tf.ones([self.dim_hidden]))
395 |                 weights['w5_prob'] = tf.Variable(FLAGS.logit_init * tf.ones([self.dim_output]))
396 | 
397 |             if FLAGS.use_T:
398 |                 weights['conv1_f'] = get_identity(self.dim_hidden, conv=True)
399 |                 weights['conv2_f'] = get_identity(self.dim_hidden, conv=True)
400 |                 weights['conv3_f'] = get_identity(self.dim_hidden, conv=True)
401 |                 weights['conv4_f'] = get_identity(self.dim_hidden, conv=True)
402 |                 weights['w5_f'] = get_identity(self.dim_output, conv=False)
403 |         return weights
404 | 
405 |     def forward_conv(self, inp, weights, reuse=False, scope=''):
406 |         # reuse is for the normalization parameters.
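        # Architecture: four conv_block layers; for miniImagenet the final feature
        # map is flattened, otherwise it is globally average-pooled over the spatial
        # dimensions; a single linear head ('w5', 'b5') produces the output logits.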
407 |         channels = self.channels
408 |         inp = tf.reshape(inp, [-1, self.img_size, self.img_size, channels])
409 |         hidden1 = conv_block(inp, weights['conv1'], weights['b1'], reuse, scope + '0')
410 |         hidden2 = conv_block(hidden1, weights['conv2'], weights['b2'], reuse, scope + '1')
411 |         hidden3 = conv_block(hidden2, weights['conv3'], weights['b3'], reuse, scope + '2')
412 |         hidden4 = conv_block(hidden3, weights['conv4'], weights['b4'], reuse, scope + '3')
413 | 
414 |         if FLAGS.datasource == 'miniimagenet':
415 |             # last hidden layer is 6x6x64-ish, reshape to a vector
416 |             hidden4 = tf.reshape(hidden4, [-1, np.prod([int(dim) for dim in hidden4.get_shape()[1:]])])
417 |         else:
418 |             hidden4 = tf.reduce_mean(hidden4, [1, 2])
419 |         return tf.matmul(hidden4, weights['w5']) + weights['b5']
420 | 
421 |     def forward_conv_withT(self, inp, weights, reuse=False, scope=''):
422 |         # reuse is for the normalization parameters.
423 |         def conv_tout(inp, cweight, bweight, rweight, reuse, scope, activation=tf.nn.relu, max_pool_pad='VALID',
424 |                       residual=False):
425 |             stride, no_stride = [1, 2, 2, 1], [1, 1, 1, 1]
426 |             if FLAGS.max_pool:
427 |                 conv_output = tf.nn.conv2d(inp, cweight, no_stride, 'SAME') + bweight
428 |             else:
429 |                 conv_output = tf.nn.conv2d(inp, cweight, stride, 'SAME') + bweight
430 |             conv_output = tf.nn.conv2d(conv_output, rweight, no_stride, 'SAME')  # apply the 1x1 T transformation ('conv*_f')
431 |             normed = normalize(conv_output, activation, reuse, scope)
432 |             if FLAGS.max_pool:
433 |                 normed = tf.nn.max_pool(normed, stride, stride, max_pool_pad)
434 |             return normed
435 | 
436 |         channels = self.channels
437 |         inp = tf.reshape(inp, [-1, self.img_size, self.img_size, channels])
438 |         hidden1 = conv_tout(inp, weights['conv1'], weights['b1'], weights['conv1_f'], reuse, scope + '0')
439 |         hidden2 = conv_tout(hidden1, weights['conv2'], weights['b2'], weights['conv2_f'], reuse, scope + '1')
440 |         hidden3 = conv_tout(hidden2, weights['conv3'], weights['b3'], weights['conv3_f'], reuse, scope + '2')
441 |         hidden4 = conv_tout(hidden3, weights['conv4'], weights['b4'], weights['conv4_f'], reuse, scope + '3')
442 | 
443 |         if FLAGS.datasource == 'miniimagenet':
444 |             # last hidden layer is 6x6x64-ish, reshape to a vector
445 |             hidden4 = tf.reshape(hidden4, [-1, np.prod([int(dim) for dim in hidden4.get_shape()[1:]])])
446 |         else:
447 |             hidden4 = tf.reduce_mean(hidden4, [1, 2])
448 |         hidden5 = tf.matmul(hidden4, weights['w5']) + weights['b5']
449 |         return tf.matmul(hidden5, weights['w5_f'])
450 | 
--------------------------------------------------------------------------------
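The MT-net-specific weights in maml.py reduce to two small ideas: a '*_prob' logit v is turned
into an update probability with softmax([v, 0]), which equals sigmoid(v) (this is exactly what the
'prob_*' histogram summary computes), and a '*_f' matrix is a learned linear transformation applied
after each affine layer, initialized to the identity so the initial forward pass is unchanged.
The following is a minimal NumPy sketch, illustrative only and not part of the repository; the
40-unit layer width and the toy inputs are arbitrary choices.

import numpy as np

def mask_probability(logit):
    # softmax over the pair [logit, 0]; the first entry equals sigmoid(logit)
    e = np.exp([logit, 0.0])
    return float(e[0] / e.sum())

def tnet_fc_layer(x, W, b, T):
    # affine map, then the learned transformation T, then ReLU
    # (the normalize() wrapper used in maml.py is omitted here)
    return np.maximum((x.dot(W) + b).dot(T), 0.0)

np.random.seed(0)
x = np.random.randn(4, 1)            # toy batch with dim_input = 1, as in the sine task
W = 0.01 * np.random.randn(1, 40)    # 'w1'-style small-normal init (40 units assumed)
b = np.zeros(40)                     # 'b1'
T = np.eye(40)                       # 'w1_f': identity at initialization

h_plain = np.maximum(x.dot(W) + b, 0.0)
h_tnet = tnet_fc_layer(x, W, b, T)
print(np.allclose(h_plain, h_tnet))  # True: the T-net layer equals the plain layer at init
print(mask_probability(0.0))         # 0.5: a zero logit gives a 50% update probability

At initialization the two layers agree exactly; meta-training then moves T away from the identity,
while the mask probabilities govern how likely each weight is to receive a task-specific update.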