├── __init__.py
├── mtnet-fig.png
├── mtnet-subspace.png
├── experiments
│   ├── sine.sh
│   ├── omniglot.sh
│   ├── polynomial.sh
│   └── miniimagenet.sh
├── data
│   ├── omniglot_resized
│   │   └── resize_images.py
│   └── miniImagenet
│       └── proc_images.py
├── special_grads.py
├── LICENSE
├── poly_generator.py
├── utils.py
├── README.md
├── data_generator.py
├── main.py
└── maml.py
/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/mtnet-fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoonholee/MT-net/HEAD/mtnet-fig.png
--------------------------------------------------------------------------------
/mtnet-subspace.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoonholee/MT-net/HEAD/mtnet-subspace.png
--------------------------------------------------------------------------------
/experiments/sine.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | for lr in .01 .04 .1
4 | do
5 | python main.py \
6 | --datasource=sinusoid --metatrain_iterations=60000 \
7 | --meta_batch_size=4 --update_lr=$lr --norm=None --resume=True \
8 | --update_batch_size=10 --use_T=True --use_M=True --share_M=True \
9 | --logdir=logs/sine
10 | done
11 |
12 | # For example, to use T-net:
13 | # --use_T=True --use_M=False --share_M=False
14 | #
15 | # Original MAML is recovered by using:
16 | # --use_T=False --use_M=False --share_M=False
17 |
--------------------------------------------------------------------------------
/experiments/omniglot.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Omniglot 5-way with MT-net
4 | python main.py \
5 | --datasource=omniglot --metatrain_iterations=40000 \
6 |   --meta_batch_size=32 --update_batch_size=1 \
7 |   --num_classes=5 --num_updates=1 --logdir=logs/omniglot5way \
8 | --update_lr=.4 --use_T=True --use_M=True --share_M=True
9 |
10 | # Omniglot 20-way with MT-net
11 | python main.py \
12 | --datasource=omniglot --metatrain_iterations=40000 \
13 |   --meta_batch_size=16 --update_batch_size=1 \
14 | --num_classes=20 --num_updates=1 --logdir=logs/omniglot20way \
15 | --update_lr=.1 --use_T=True --use_M=True --share_M=True
16 |
--------------------------------------------------------------------------------
/experiments/polynomial.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | python main.py \
4 | --datasource=polynomial --metatrain_iterations=60000 --update_batch_size=10 \
5 | --meta_batch_size=4 --norm=None --logdir=logs/poly --poly_order=0 \
6 | --use_T=True --use_M=True --share_M=True
7 |
8 | python main.py \
9 | --datasource=polynomial --metatrain_iterations=60000 --update_batch_size=10 \
10 | --meta_batch_size=4 --norm=None --logdir=logs/poly --poly_order=1 \
11 | --use_T=True --use_M=True --share_M=True
12 |
13 | python main.py \
14 | --datasource=polynomial --metatrain_iterations=60000 --update_batch_size=10 \
15 | --meta_batch_size=4 --norm=None --logdir=logs/poly --poly_order=2 \
16 | --use_T=True --use_M=True --share_M=True
17 |
--------------------------------------------------------------------------------
/data/omniglot_resized/resize_images.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage instructions:
3 | First download the omniglot dataset
4 | and put the contents of both images_background and images_evaluation in data/omniglot/ (without the root folder)
5 |
6 | Then, run the following:
7 | cd data/
8 | cp -r omniglot/* omniglot_resized/
9 | cd omniglot_resized/
10 | python resize_images.py
11 | """
12 | from PIL import Image
13 | import glob
14 |
15 | image_path = '*/*/'
16 |
17 | all_images = glob.glob(image_path + '*')
18 |
19 | i = 0
20 |
21 | for image_file in all_images:
22 | im = Image.open(image_file)
23 | im = im.resize((28,28), resample=Image.LANCZOS)
24 | im.save(image_file)
25 | i += 1
26 |
27 | if i % 200 == 0:
28 | print(i)
29 |
30 |
--------------------------------------------------------------------------------
/special_grads.py:
--------------------------------------------------------------------------------
1 | """ Code for second derivatives not implemented in TensorFlow library. """
2 | from tensorflow.python.framework import ops
3 | from tensorflow.python.ops import array_ops
4 | from tensorflow.python.ops import gen_nn_ops
5 |
6 | @ops.RegisterGradient("MaxPoolGrad")
7 | def _MaxPoolGradGrad(op, grad):
8 | gradient = gen_nn_ops._max_pool_grad(op.inputs[0], op.outputs[0],
9 | grad, op.get_attr("ksize"), op.get_attr("strides"),
10 | padding=op.get_attr("padding"), data_format=op.get_attr("data_format"))
11 | gradgrad1 = array_ops.zeros(shape = array_ops.shape(op.inputs[1]), dtype=gradient.dtype)
12 | gradgrad2 = array_ops.zeros(shape = array_ops.shape(op.inputs[2]), dtype=gradient.dtype)
13 | return (gradient, gradgrad1, gradgrad2)
14 |
--------------------------------------------------------------------------------
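When `--max_pool=True` and `--stop_grad=False`, the meta-gradient has to differentiate through a max-pool a second time, which is what the gradient registered above enables. A minimal sketch of such a double-gradient graph, assuming a TensorFlow 1.x graph-mode setup (shapes and variable names here are illustrative, not part of the repo):

```
import tensorflow as tf
import special_grads  # registers the MaxPoolGrad gradient defined above

x = tf.random_normal([1, 8, 8, 1])
w = tf.Variable(tf.truncated_normal([3, 3, 1, 1], stddev=0.1))
conv = tf.nn.conv2d(x, w, [1, 1, 1, 1], 'SAME')
pooled = tf.nn.max_pool(conv, [1, 2, 2, 1], [1, 2, 2, 1], 'VALID')
loss = tf.reduce_sum(pooled)

grad = tf.gradients(loss, w)[0]                  # backward pass creates a MaxPoolGrad op
gradgrad = tf.gradients(tf.reduce_sum(grad), w)  # second derivative; needs the gradient registered above
```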
/experiments/miniimagenet.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # miniImagenet with MT-nets and hyperparameters from MAML
4 | python main.py \
5 | --datasource=miniimagenet --metatrain_iterations=60000 \
6 | --meta_batch_size=4 --update_batch_size=1 \
7 | --num_updates=5 --logdir=logs/miniimagenet5way \
8 | --update_lr=.01 --resume=True --num_filters=32 --max_pool=True \
9 | --use_T=True --use_M=True --share_M=True
10 |
11 | # works well even with single gradient step
12 | python main.py \
13 | --datasource=miniimagenet --metatrain_iterations=60000 \
14 | --meta_batch_size=4 --update_batch_size=1 \
15 | --num_updates=1 --logdir=logs/miniimagenet5way \
16 | --update_lr=.4 --resume=True --num_filters=32 --max_pool=True \
17 | --use_T=True --use_M=True --share_M=True
18 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Chelsea Finn
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/data/miniImagenet/proc_images.py:
--------------------------------------------------------------------------------
1 | """
2 | Script for converting the csv datafiles into one directory of images per class (which is how the MAML code loads the data).
3 |
4 | Acquire miniImagenet from Ravi & Larochelle '17, along with the train, val, and test csv files. Put the
5 | csv files in the miniImagenet directory and put the images in the directory 'miniImagenet/images/'.
6 | Then run this script from the miniImagenet directory:
7 | cd data/miniImagenet/
8 | python proc_images.py
9 | """
10 |
11 | from __future__ import print_function
12 | import csv
13 | import glob
14 | import os
15 |
16 | from PIL import Image
17 |
18 | path_to_images = 'images/'
19 |
20 | all_images = glob.glob(path_to_images + '*')
21 |
22 | # Resize images
23 | for i, image_file in enumerate(all_images):
24 | im = Image.open(image_file)
25 | im = im.resize((84, 84), resample=Image.LANCZOS)
26 | im.save(image_file)
27 | if i % 500 == 0:
28 | print(i)
29 |
30 | # Put in correct directory
31 | for datatype in ['train', 'val', 'test']:
32 | os.system('mkdir ' + datatype)
33 |
34 | with open(datatype + '.csv', 'r') as f:
35 | reader = csv.reader(f, delimiter=',')
36 | last_label = ''
37 | for i, row in enumerate(reader):
38 | if i == 0: # skip the headers
39 | continue
40 | label = row[1]
41 | image_name = row[0]
42 | if label != last_label:
43 | cur_dir = datatype + '/' + label + '/'
44 | os.system('mkdir ' + cur_dir)
45 | last_label = label
46 | os.system('mv images/' + image_name + ' ' + cur_dir)
47 |
--------------------------------------------------------------------------------
/poly_generator.py:
--------------------------------------------------------------------------------
1 | """ Code for generating polynomials. """
2 | import numpy as np
3 | from tensorflow.python.platform import flags
4 |
5 | FLAGS = flags.FLAGS
6 |
7 |
8 | class PolyDataGenerator(object):
9 | def __init__(self, num_samples_per_class, batch_size, config={}):
10 | assert FLAGS.datasource == 'polynomial'
11 | self.batch_size = batch_size
12 | self.num_samples_per_class = num_samples_per_class
13 | self.num_classes = 1 # by default 1 (only relevant for classification problems)
14 | self.poly_order = FLAGS.poly_order
15 |
16 | self.generate = self.generate_polynomial_batch
17 | self.input_range = config.get('input_range', [-2.0, 2.0])
18 | self.coeff_range = config.get('coeff_range', [-1.0, 1.0])
19 | self.dim_input = 1
20 | self.dim_output = 1
21 |
22 | def generate_polynomial_batch(self):
23 | coeffs = np.random.uniform(self.coeff_range[0], self.coeff_range[1], [self.batch_size, self.poly_order+1])
24 | outputs = np.zeros([self.batch_size, self.num_samples_per_class, self.dim_output])
25 | init_inputs = np.zeros([self.batch_size, self.num_samples_per_class, self.dim_input])
26 | polynomial = np.polynomial.polynomial.polyval
27 |
28 | for func in range(self.batch_size):
29 | init_inputs[func] = np.random.uniform(
30 | self.input_range[0], self.input_range[1], [self.num_samples_per_class, 1])
31 | func_coeffs = coeffs[func] # [c0, c1,...,]
32 | for i in range(self.poly_order + 1):
33 | func_coeffs[i] /= (2 ** i)
34 | outputs[func] = polynomial(init_inputs[func], func_coeffs)
35 |
36 | return init_inputs, outputs
37 |
38 |
--------------------------------------------------------------------------------
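Stripped of the FLAGS machinery, `generate_polynomial_batch` amounts to the following NumPy sampling. This is a sketch with illustrative sizes; the coefficient and input ranges mirror the defaults above:

```
import numpy as np

batch_size, num_samples, poly_order = 4, 10, 2    # illustrative values
coeffs = np.random.uniform(-1.0, 1.0, [batch_size, poly_order + 1])
coeffs /= 2.0 ** np.arange(poly_order + 1)        # coefficient c_i is damped by 2**i, as above
x = np.random.uniform(-2.0, 2.0, [batch_size, num_samples, 1])
y = np.stack([np.polynomial.polynomial.polyval(x[b], coeffs[b])
              for b in range(batch_size)])
print(x.shape, y.shape)                           # (4, 10, 1) (4, 10, 1)
```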
/utils.py:
--------------------------------------------------------------------------------
1 | """ Utility functions. """
2 | import numpy as np
3 | import os
4 | import random
5 | import tensorflow as tf
6 |
7 | from tensorflow.contrib.layers.python import layers as tf_layers
8 | from tensorflow.python.platform import flags
9 |
10 | FLAGS = flags.FLAGS
11 |
12 | ## Image helper
13 | def get_images(paths, labels, nb_samples=None, shuffle=True):
14 | if nb_samples is not None:
15 | sampler = lambda x: random.sample(x, nb_samples)
16 | else:
17 | sampler = lambda x: x
18 | images = [(i, os.path.join(path, image)) \
19 | for i, path in zip(labels, paths) \
20 | for image in sampler(os.listdir(path))]
21 | if shuffle:
22 | random.shuffle(images)
23 | return images
24 |
25 | ## Network helpers
26 | def conv_block(inp, cweight, bweight, reuse, scope, activation=tf.nn.relu, max_pool_pad='VALID', residual=False):
27 |     """ Perform conv, batch norm, nonlinearity, and max pool. """
28 | stride, no_stride = [1,2,2,1], [1,1,1,1]
29 |
30 | if FLAGS.max_pool:
31 | conv_output = tf.nn.conv2d(inp, cweight, no_stride, 'SAME') + bweight
32 | else:
33 | conv_output = tf.nn.conv2d(inp, cweight, stride, 'SAME') + bweight
34 | normed = normalize(conv_output, activation, reuse, scope)
35 | if FLAGS.max_pool:
36 | normed = tf.nn.max_pool(normed, stride, stride, max_pool_pad)
37 | return normed
38 |
39 | def normalize(inp, activation, reuse, scope):
40 | if FLAGS.norm == 'batch_norm':
41 | return tf_layers.batch_norm(inp, activation_fn=activation, reuse=reuse, scope=scope)
42 | elif FLAGS.norm == 'layer_norm':
43 | return tf_layers.layer_norm(inp, activation_fn=activation, reuse=reuse, scope=scope)
44 | elif FLAGS.norm == 'None':
45 | return activation(inp)
46 |
47 | ## Loss functions
48 | def mse(pred, label):
49 | pred = tf.reshape(pred, [-1])
50 | label = tf.reshape(label, [-1])
51 | return tf.reduce_mean(tf.square(pred-label))
52 |
53 | def xent(pred, label):
54 | # Note - with tf version <=0.12, this loss has incorrect 2nd derivatives
55 | return tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=label) / FLAGS.update_batch_size
56 |
57 |
58 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MT-net
2 |
3 | Code accompanying the paper [Gradient-Based Meta-Learning with Learned Layerwise Metric and Subspace (Yoonho Lee and Seungjin Choi, ICML 2018)](https://arxiv.org/abs/1801.05558).
4 | It includes code for running the experiments in the paper (few-shot sine wave regression, Omniglot and miniImagenet few-shot classification).
5 |
6 | ## Abstract
7 | 
8 |
9 | Gradient-based meta-learning methods leverage gradient descent to learn the commonalities among various tasks. While previous such methods have been successful in meta-learning tasks, they resort to simple gradient descent during meta-testing. Our primary contribution is the **MT-net**, which enables the meta-learner to learn on each layer's activation space a subspace that the task-specific learner performs gradient descent on. Additionally, a task-specific learner of an *MT-net* performs gradient descent with respect to a meta-learned distance metric, which warps the activation space to be more sensitive to task identity. We demonstrate that the dimension of this learned subspace reflects the complexity of the task-specific learner's adaptation task, and also that our model is less sensitive to the choice of initial learning rates than previous gradient-based meta-learning methods. Our method achieves state-of-the-art or comparable performance on few-shot classification and regression tasks.
10 |
11 | ### Data
12 | For the Omniglot and MiniImagenet data, see the usage instructions in `data/omniglot_resized/resize_images.py` and `data/miniImagenet/proc_images.py` respectively.
13 |
14 | ### Usage
15 | To run the code, see the usage instructions at the top of `main.py`.
16 |
17 | For MT-nets, set `use_T`, `use_M`, `share_M` to `True`.
18 |
19 | For T-nets, set `use_T` to `True` and `use_M` to `False`.
20 |
21 | ## Reference
22 |
23 | If you found the provided code useful, please cite our work.
24 |
25 | ```
26 | @inproceedings{lee2018gradient,
27 | title={Gradient-based meta-learning with learned layerwise metric and subspace},
28 | author={Lee, Yoonho and Choi, Seungjin},
29 | booktitle={International Conference on Machine Learning},
30 | pages={2933--2942},
31 | year={2018}
32 | }
33 | ```
34 |
35 | ---
36 |
37 | This codebase is based on the repository for [MAML](https://github.com/cbfinn/maml).
38 |
--------------------------------------------------------------------------------
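To make the Usage section of the README above concrete, an illustrative MT-net training run and the matching evaluation command could look like the following (hyperparameter values are placeholders taken from `experiments/omniglot.sh`, not a prescription):

```
# Train an MT-net (illustrative hyperparameters)
python main.py \
  --datasource=omniglot --num_classes=5 --update_batch_size=1 \
  --meta_batch_size=32 --num_updates=1 --update_lr=.4 \
  --use_T=True --use_M=True --share_M=True --logdir=logs/omniglot5way

# Evaluate the same run on the test set
python main.py \
  --datasource=omniglot --num_classes=5 --update_batch_size=1 \
  --meta_batch_size=32 --num_updates=1 --update_lr=.4 \
  --use_T=True --use_M=True --share_M=True --logdir=logs/omniglot5way \
  --train=False --test_set=True
```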
/data_generator.py:
--------------------------------------------------------------------------------
1 | """ Code for loading data. """
2 | import numpy as np
3 | import os
4 | import random
5 | import tensorflow as tf
6 |
7 | from tensorflow.python.platform import flags
8 | from utils import get_images
9 |
10 | FLAGS = flags.FLAGS
11 |
12 | class DataGenerator(object):
13 | """
14 | Data Generator capable of generating batches of sinusoid or Omniglot data.
15 | A "class" is considered a class of omniglot digits or a particular sinusoid function.
16 | """
17 | def __init__(self, num_samples_per_class, batch_size, config={}):
18 | """
19 | Args:
20 | num_samples_per_class: num samples to generate per class in one batch
21 | batch_size: size of meta batch size (e.g. number of functions)
22 | """
23 | self.batch_size = batch_size
24 | self.num_samples_per_class = num_samples_per_class
25 | self.num_classes = 1 # by default 1 (only relevant for classification problems)
26 |
27 | if FLAGS.datasource == 'sinusoid':
28 | self.generate = self.generate_sinusoid_batch
29 | self.amp_range = config.get('amp_range', [0.1, 5.0])
30 | self.phase_range = config.get('phase_range', [0, np.pi])
31 | self.input_range = config.get('input_range', [-5.0, 5.0])
32 | self.freq_range = config.get('freq_range', [0.8, 1.2])
33 | self.dim_input = 1
34 | self.dim_output = 1
35 | elif 'omniglot' in FLAGS.datasource:
36 | self.num_classes = config.get('num_classes', FLAGS.num_classes)
37 | self.img_size = config.get('img_size', (28, 28))
38 | self.dim_input = np.prod(self.img_size)
39 | self.dim_output = self.num_classes
40 | # data that is pre-resized using PIL with lanczos filter
41 | data_folder = config.get('data_folder', './data/omniglot_resized')
42 |
43 | character_folders = [os.path.join(data_folder, family, character) \
44 | for family in os.listdir(data_folder) \
45 | if os.path.isdir(os.path.join(data_folder, family)) \
46 | for character in os.listdir(os.path.join(data_folder, family))]
47 | random.seed(1)
48 | random.shuffle(character_folders)
49 | num_val = 100
50 | num_train = config.get('num_train', 1200) - num_val
51 | self.metatrain_character_folders = character_folders[:num_train]
52 | if FLAGS.test_set:
53 | self.metaval_character_folders = character_folders[num_train:num_train+num_val]
54 | else:
55 | self.metaval_character_folders = character_folders[num_train+num_val:]
56 | self.rotations = config.get('rotations', [0, 90, 180, 270])
57 | elif FLAGS.datasource == 'miniimagenet':
58 | self.num_classes = config.get('num_classes', FLAGS.num_classes)
59 | self.img_size = config.get('img_size', (84, 84))
60 | self.dim_input = np.prod(self.img_size)*3
61 | self.dim_output = self.num_classes
62 | metatrain_folder = config.get('metatrain_folder', './data/miniImagenet/train')
63 | if FLAGS.test_set:
64 | metaval_folder = config.get('metaval_folder', './data/miniImagenet/test')
65 | else:
66 | metaval_folder = config.get('metaval_folder', './data/miniImagenet/val')
67 |
68 | metatrain_folders = [os.path.join(metatrain_folder, label) \
69 | for label in os.listdir(metatrain_folder) \
70 | if os.path.isdir(os.path.join(metatrain_folder, label)) \
71 | ]
72 | metaval_folders = [os.path.join(metaval_folder, label) \
73 | for label in os.listdir(metaval_folder) \
74 | if os.path.isdir(os.path.join(metaval_folder, label)) \
75 | ]
76 | self.metatrain_character_folders = metatrain_folders
77 | self.metaval_character_folders = metaval_folders
78 | self.rotations = config.get('rotations', [0])
79 | else:
80 | raise ValueError('Unrecognized data source')
81 |
82 |
83 | def make_data_tensor(self, train=True):
84 | if train:
85 | folders = self.metatrain_character_folders
86 | folders = folders[:FLAGS.num_train_classes]
87 | # number of tasks, not number of meta-iterations. (divide by metabatch size to measure)
88 | num_total_batches = 200000 if not FLAGS.debug else 32
89 | else:
90 | folders = self.metaval_character_folders
91 | num_total_batches = 600 if not FLAGS.debug else 32
92 |
93 | # make list of files
94 | print('Generating filenames')
95 | all_filenames = []
96 | for _ in range(num_total_batches):
97 | sampled_character_folders = random.sample(folders, self.num_classes)
98 | random.shuffle(sampled_character_folders)
99 | labels_and_images = get_images(sampled_character_folders, range(self.num_classes), nb_samples=self.num_samples_per_class, shuffle=False)
100 | # make sure the above isn't randomized order
101 | labels = [li[0] for li in labels_and_images]
102 | filenames = [li[1] for li in labels_and_images]
103 | all_filenames.extend(filenames)
104 |
105 | # make queue for tensorflow to read from
106 | filename_queue = tf.train.string_input_producer(tf.convert_to_tensor(all_filenames), shuffle=False)
107 | print('Generating image processing ops')
108 | image_reader = tf.WholeFileReader()
109 | _, image_file = image_reader.read(filename_queue)
110 | if FLAGS.datasource == 'miniimagenet':
111 | image = tf.image.decode_jpeg(image_file, channels=3)
112 | image.set_shape((self.img_size[0], self.img_size[1], 3))
113 | image = tf.reshape(image, [self.dim_input])
114 | image = tf.cast(image, tf.float32) / 255.0
115 | else:
116 | image = tf.image.decode_png(image_file)
117 | image.set_shape((self.img_size[0],self.img_size[1],1))
118 | image = tf.reshape(image, [self.dim_input])
119 | image = tf.cast(image, tf.float32) / 255.0
120 | image = 1.0 - image # invert
121 | num_preprocess_threads = 1
122 | # TODO: enable this to be set to >1
123 | min_queue_examples = 256
124 | examples_per_batch = self.num_classes * self.num_samples_per_class
125 | batch_image_size = self.batch_size * examples_per_batch
126 | print('Batching images')
127 | images = tf.train.batch(
128 | [image],
129 | batch_size=batch_image_size,
130 | num_threads=num_preprocess_threads,
131 | capacity=min_queue_examples + 3 * batch_image_size,
132 | )
133 | all_image_batches, all_label_batches = [], []
134 | print('Manipulating image data to be right shape')
135 | for i in range(self.batch_size):
136 | image_batch = images[i*examples_per_batch:(i+1)*examples_per_batch]
137 |
138 | if FLAGS.datasource == 'omniglot':
139 | # omniglot augments the dataset by rotating digits to create new classes
140 | # get rotation per class (e.g. 0,1,2,0,0 if there are 5 classes)
141 | rotations = tf.multinomial(tf.log([[1., 1., 1., 1.]]), self.num_classes)
142 | label_batch = tf.convert_to_tensor(labels)
143 | new_list, new_label_list = [], []
144 | for k in range(self.num_samples_per_class):
145 | class_idxs = tf.range(0, self.num_classes)
146 | class_idxs = tf.random_shuffle(class_idxs)
147 |
148 | true_idxs = class_idxs*self.num_samples_per_class + k
149 | new_list.append(tf.gather(image_batch,true_idxs))
150 | if FLAGS.datasource == 'omniglot': # and FLAGS.train:
151 | new_list[-1] = tf.stack([tf.reshape(tf.image.rot90(
152 | tf.reshape(new_list[-1][ind], [self.img_size[0],self.img_size[1],1]),
153 | k=tf.cast(rotations[0,class_idxs[ind]], tf.int32)), (self.dim_input,))
154 | for ind in range(self.num_classes)])
155 | new_label_list.append(tf.gather(label_batch, true_idxs))
156 | new_list = tf.concat(new_list, 0) # has shape [self.num_classes*self.num_samples_per_class, self.dim_input]
157 | new_label_list = tf.concat(new_label_list, 0)
158 | all_image_batches.append(new_list)
159 | all_label_batches.append(new_label_list)
160 | all_image_batches = tf.stack(all_image_batches)
161 | all_label_batches = tf.stack(all_label_batches)
162 | all_label_batches = tf.one_hot(all_label_batches, self.num_classes)
163 | return all_image_batches, all_label_batches
164 |
165 | def generate_sinusoid_batch(self, train=True, input_idx=None):
166 |         # Note: the train arg is not used here (but it is used by the omniglot method).
167 |         # input_idx is used during qualitative testing -- the number of examples used for the grad update
168 | amp = np.random.uniform(self.amp_range[0], self.amp_range[1], [self.batch_size])
169 | phase = np.random.uniform(self.phase_range[0], self.phase_range[1], [self.batch_size])
170 | freq = np.random.uniform(self.freq_range[0], self.freq_range[1], [self.batch_size])
171 | outputs = np.zeros([self.batch_size, self.num_samples_per_class, self.dim_output])
172 | init_inputs = np.zeros([self.batch_size, self.num_samples_per_class, self.dim_input])
173 | for func in range(self.batch_size):
174 | init_inputs[func] = np.random.uniform(self.input_range[0], self.input_range[1], [self.num_samples_per_class, 1])
175 | if input_idx is not None:
176 | init_inputs[:, input_idx:, 0] = np.linspace(
177 | self.input_range[0], self.input_range[1],
178 | num=self.num_samples_per_class-input_idx, retstep=False)
179 | outputs[func] = amp[func] * np.sin(freq[func] * init_inputs[func]-phase[func])
180 | return init_inputs, outputs, amp, phase
181 |
182 |
--------------------------------------------------------------------------------
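For reference, `generate_sinusoid_batch` above reduces to the following NumPy sampling (a sketch with an illustrative batch size and shot count; the amplitude, phase, frequency, and input ranges mirror the defaults in `__init__`):

```
import numpy as np

batch_size, K = 4, 10                        # 4 tasks per meta-batch, 10 samples per task
amp = np.random.uniform(0.1, 5.0, batch_size)
phase = np.random.uniform(0.0, np.pi, batch_size)
freq = np.random.uniform(0.8, 1.2, batch_size)
x = np.random.uniform(-5.0, 5.0, [batch_size, K, 1])
y = amp[:, None, None] * np.sin(freq[:, None, None] * x - phase[:, None, None])
print(x.shape, y.shape)                      # (4, 10, 1) (4, 10, 1)
```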
/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage Instructions:
3 | Scripts with hyperparameters are in experiments/
4 |
5 | To run evaluation, use the '--train=False' flag and the '--test_set=True' flag to use the test set.
6 | """
7 |
8 | import csv
9 | import numpy as np
10 | import pickle
11 | import random
12 | import tensorflow as tf
13 |
14 | from data_generator import DataGenerator
15 | from poly_generator import PolyDataGenerator
16 | from maml import MAML
17 | from tensorflow.python.platform import flags
18 |
19 | FLAGS = flags.FLAGS
20 |
21 | ## Dataset/method options
22 | flags.DEFINE_string('datasource', 'sinusoid', 'sinusoid or omniglot or miniimagenet')
23 | flags.DEFINE_integer('num_classes', 5, 'number of classes used in classification (e.g. 5-way classification).')
24 | flags.DEFINE_integer('num_train_classes', -1, 'number of classes to train on (-1 for all).')
25 | # oracle means task id is input (only suitable for sinusoid)
26 | flags.DEFINE_string('baseline', None, 'oracle, or None')
27 |
28 | ## Training options
29 | flags.DEFINE_integer('pretrain_iterations', 0, 'number of pre-training iterations.')
30 | flags.DEFINE_integer('metatrain_iterations', 15000, 'number of metatraining iterations.') # 15k for omniglot, 50k for sinusoid
31 | flags.DEFINE_integer('meta_batch_size', 25, 'number of tasks sampled per meta-update')
32 | flags.DEFINE_float('meta_lr', 0.001, 'the base learning rate of the generator')
33 | flags.DEFINE_integer('update_batch_size', 5, 'number of examples used for inner gradient update (K for K-shot learning).')
34 | flags.DEFINE_float('update_lr', 1e-3, 'step size alpha for inner gradient update.') # 0.1 for omniglot
35 | flags.DEFINE_integer('num_updates', 1, 'number of inner gradient updates during training.')
36 | flags.DEFINE_integer('poly_order', 1, 'order of polynomial to generate')
37 |
38 | ## Model options
39 | #flags.DEFINE_string('mod', '', 'modifications to original paper. None, split, both')
40 | flags.DEFINE_bool('use_T', False, 'whether or not to use transformation matrix T')
41 | flags.DEFINE_bool('use_M', False, 'whether or not to use mask M')
42 | flags.DEFINE_bool('share_M', False, 'only effective if use_M is true, whether or not to '
43 |                                     'share masks between weights '
44 |                                     'that contribute to the same activation')
45 | flags.DEFINE_float('temp', 1, 'temperature for gumbel-softmax')
46 | flags.DEFINE_float('logit_init', 0, 'initial logit')
47 | flags.DEFINE_string('norm', 'batch_norm', 'batch_norm, layer_norm, or None')
48 | flags.DEFINE_integer('dim_hidden', 40, 'dimension of fc layer')
49 | flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- use 32 for '
50 |                                          'miniimagenet, 64 for omniglot.')
51 | flags.DEFINE_bool('conv', True, 'whether or not to use a convolutional network, only applicable in some cases')
52 | flags.DEFINE_bool('max_pool', False, 'Whether or not to use max pooling rather than strided convolutions')
53 | flags.DEFINE_bool('stop_grad', False, 'if True, do not use second derivatives in meta-optimization (for speed)')
54 |
55 | ## Logging, saving, and testing options
56 | flags.DEFINE_bool('log', True, 'if false, do not log summaries, for debugging code.')
57 | flags.DEFINE_string('logdir', '/tmp/data', 'directory for summaries and checkpoints.')
58 | flags.DEFINE_bool('debug', False, 'debug mode. uses less data for fast evaluation.')
59 | flags.DEFINE_bool('resume', True, 'resume training if there is a model available')
60 | flags.DEFINE_bool('train', True, 'True to train, False to test.')
61 | flags.DEFINE_integer('test_iter', -1, 'iteration to load model (-1 for latest model)')
62 | flags.DEFINE_bool('test_set', False, 'Set to true to test on the the test set, False for the validation set.')
63 | flags.DEFINE_integer('train_update_batch_size', -1, 'number of examples used for gradient update during training (use if you want to test with a different number).')
64 | flags.DEFINE_float('train_update_lr', -1, 'value of inner gradient step size during training (use if you want to test with a different value).') # 0.1 for omniglot
65 |
66 |
67 | def train(model, saver, sess, exp_string, data_generator, resume_itr=0):
68 | SUMMARY_INTERVAL = 100
69 | SAVE_INTERVAL = 1000
70 | if FLAGS.debug:
71 | SUMMARY_INTERVAL = PRINT_INTERVAL = 10
72 | TEST_PRINT_INTERVAL = PRINT_INTERVAL*5
73 | elif FLAGS.datasource in ['sinusoid', 'polynomial']:
74 | PRINT_INTERVAL = 1000
75 | TEST_PRINT_INTERVAL = PRINT_INTERVAL*5
76 | else:
77 | PRINT_INTERVAL = 100
78 | TEST_PRINT_INTERVAL = PRINT_INTERVAL*5
79 |
80 | if FLAGS.log:
81 | train_writer = tf.summary.FileWriter(FLAGS.logdir + '/' + exp_string, sess.graph)
82 | print('Done initializing, starting training.')
83 | prelosses, postlosses = [], []
84 |
85 | num_classes = data_generator.num_classes # for classification, 1 otherwise
86 | multitask_weights, reg_weights = [], []
87 |
88 | for itr in range(resume_itr, FLAGS.pretrain_iterations + FLAGS.metatrain_iterations):
89 | feed_dict = {}
90 | if FLAGS.datasource == 'sinusoid':
91 | batch_x, batch_y, amp, phase = data_generator.generate()
92 |
93 | if FLAGS.baseline == 'oracle':
94 | batch_x = np.concatenate([batch_x, np.zeros([batch_x.shape[0], batch_x.shape[1], 2])], 2)
95 | for i in range(FLAGS.meta_batch_size):
96 | batch_x[i, :, 1] = amp[i]
97 | batch_x[i, :, 2] = phase[i]
98 |
99 | inputa = batch_x[:, :num_classes*FLAGS.update_batch_size, :]
100 | labela = batch_y[:, :num_classes*FLAGS.update_batch_size, :]
101 | inputb = batch_x[:, num_classes*FLAGS.update_batch_size:, :] # b used for testing
102 | labelb = batch_y[:, num_classes*FLAGS.update_batch_size:, :]
103 | feed_dict = {model.inputa: inputa, model.inputb: inputb, model.labela: labela, model.labelb: labelb}
104 |
105 | elif FLAGS.datasource == 'polynomial':
106 | batch_x, batch_y = data_generator.generate()
107 | inputa = batch_x[:, :num_classes*FLAGS.update_batch_size, :]
108 | labela = batch_y[:, :num_classes*FLAGS.update_batch_size, :]
109 | inputb = batch_x[:, num_classes*FLAGS.update_batch_size:, :] # b used for testing
110 | labelb = batch_y[:, num_classes*FLAGS.update_batch_size:, :]
111 | feed_dict = {model.inputa: inputa, model.inputb: inputb, model.labela: labela, model.labelb: labelb}
112 |
113 |
114 | if itr < FLAGS.pretrain_iterations:
115 | input_tensors = [model.pretrain_op]
116 | else:
117 | input_tensors = [model.metatrain_op]
118 |
119 | if itr % SUMMARY_INTERVAL == 0 or itr % PRINT_INTERVAL == 0:
120 | input_tensors.extend([model.summ_op, model.total_loss1,
121 | model.total_losses2[FLAGS.num_updates-1]])
122 | if model.classification:
123 | input_tensors.extend([model.total_accuracy1, model.total_accuracies2[FLAGS.num_updates-1]])
124 |
125 | result = sess.run(input_tensors, feed_dict)
126 |
127 | if itr % SUMMARY_INTERVAL == 0:
128 | prelosses.append(result[-2])
129 | if FLAGS.log:
130 | train_writer.add_summary(result[1], itr)
131 | postlosses.append(result[-1])
132 |
133 | if itr != 0 and itr % PRINT_INTERVAL == 0:
134 | if itr < FLAGS.pretrain_iterations:
135 | print_str = 'Pretrain Iteration ' + str(itr)
136 | else:
137 | print_str = 'Iteration ' + str(itr - FLAGS.pretrain_iterations)
138 | print_str += ': ' + str(np.mean(prelosses)) + ', ' + str(np.mean(postlosses))
139 | print(print_str)
140 | #print sess.run(model.total_probs)
141 | prelosses, postlosses = [], []
142 |
143 | if itr != 0 and itr % SAVE_INTERVAL == 0:
144 | saver.save(sess, FLAGS.logdir + '/' + exp_string + '/model' + str(itr))
145 |
146 | # sinusoid is infinite data, so no need to test on meta-validation set.
147 | if itr != 0 and itr % TEST_PRINT_INTERVAL == 0 and FLAGS.datasource not in ['sinusoid', 'polynomial']:
148 | if 'generate' not in dir(data_generator):
149 | feed_dict = {}
150 | if model.classification:
151 | input_tensors = [model.metaval_total_accuracy1,
152 | model.metaval_total_accuracies2[FLAGS.num_updates-1], model.summ_op]
153 | else:
154 | input_tensors = [model.metaval_total_loss1,
155 | model.metaval_total_losses2[FLAGS.num_updates-1], model.summ_op]
156 | else:
157 | batch_x, batch_y, amp, phase = data_generator.generate(train=False)
158 | inputa = batch_x[:, :num_classes*FLAGS.update_batch_size, :]
159 | inputb = batch_x[:, num_classes*FLAGS.update_batch_size:, :]
160 | labela = batch_y[:, :num_classes*FLAGS.update_batch_size, :]
161 | labelb = batch_y[:, num_classes*FLAGS.update_batch_size:, :]
162 | feed_dict = {model.inputa: inputa, model.inputb: inputb,
163 | model.labela: labela, model.labelb: labelb, model.meta_lr: 0.0}
164 | if model.classification:
165 | input_tensors = [model.total_accuracy1, model.total_accuracies2[FLAGS.num_updates-1]]
166 | else:
167 | input_tensors = [model.total_loss1, model.total_losses2[FLAGS.num_updates-1]]
168 |
169 | result = sess.run(input_tensors, feed_dict)
170 | print('Validation results: ' + str(result[0]) + ', ' + str(result[1]))
171 |
172 | saver.save(sess, FLAGS.logdir + '/' + exp_string + '/model' + str(itr))
173 |
174 |
175 | def test(model, saver, sess, exp_string, data_generator, test_num_updates=None):
176 | num_classes = data_generator.num_classes # for classification, 1 otherwise
177 |
178 | np.random.seed(1)
179 | random.seed(1)
180 |
181 | metaval_accuracies = []
182 |
183 | if FLAGS.datasource == 'miniimagenet':
184 | NUM_TEST_POINTS = 4000
185 | elif FLAGS.datasource == 'polynomial':
186 | NUM_TEST_POINTS = 20
187 | else:
188 | NUM_TEST_POINTS = 600
189 | for point_n in range(NUM_TEST_POINTS):
190 | if 'generate' not in dir(data_generator):
191 | feed_dict = {model.meta_lr: 0.0}
192 | elif FLAGS.datasource == 'sinusoid':
193 | batch_x, batch_y, amp, phase = data_generator.generate(train=False)
194 |
195 | if FLAGS.baseline == 'oracle': # NOTE - this flag is specific to sinusoid
196 | batch_x = np.concatenate([batch_x, np.zeros([batch_x.shape[0], batch_x.shape[1], 2])], 2)
197 | batch_x[0, :, 1] = amp[0]
198 | batch_x[0, :, 2] = phase[0]
199 |
200 | inputa = batch_x[:, :num_classes*FLAGS.update_batch_size, :]
201 | inputb = batch_x[:, num_classes*FLAGS.update_batch_size:, :]
202 | labela = batch_y[:, :num_classes*FLAGS.update_batch_size, :]
203 | labelb = batch_y[:, num_classes*FLAGS.update_batch_size:, :]
204 |
205 | feed_dict = {model.inputa: inputa, model.inputb: inputb,
206 | model.labela: labela, model.labelb: labelb, model.meta_lr: 0.0}
207 | elif FLAGS.datasource == 'polynomial':
208 | batch_x, batch_y = data_generator.generate()
209 | inputa = batch_x[:, :num_classes*FLAGS.update_batch_size, :]
210 | inputb = batch_x[:, num_classes*FLAGS.update_batch_size:, :]
211 | labela = batch_y[:, :num_classes*FLAGS.update_batch_size, :]
212 | labelb = batch_y[:, num_classes*FLAGS.update_batch_size:, :]
213 | feed_dict = {model.inputa: inputa, model.inputb: inputb,
214 | model.labela: labela, model.labelb: labelb, model.meta_lr: 0.0}
215 |
216 | ########## plotting code
217 | import matplotlib.pyplot as plt
218 | from matplotlib import rc
219 | import matplotlib
220 | matplotlib.rcParams.update({'font.size': 25})
221 | fig, ax = plt.subplots()
222 | fig.set_size_inches(15, 10)
223 | plt.plot(inputa.flatten(), labela.flatten(), 'ro')
224 | plt.plot(inputb.flatten(), labelb.flatten(), 'r,')
225 | outputbs = sess.run(model.outputbs, feed_dict)
226 | plt.plot(inputb.flatten(), outputbs[0].flatten(), color='#bfbfbf', marker=',', linestyle='None')
227 | plt.plot(inputb.flatten(), outputbs[1].flatten(), color='#666666', marker=',', linestyle='None')
228 | plt.plot(inputb.flatten(), outputbs[9].flatten(), color='#000000', marker=',', linestyle='None')
229 | plt.title('Polynomial order ' + str(FLAGS.poly_order))
230 | plt.legend()
231 | axes = plt.gca()
232 | axes.set_xlim([-2, 2])
233 | axes.set_ylim([-5.1, 5.1])
234 | plt.savefig(FLAGS.logdir + '/' + exp_string + '/' + str(point_n) + '.png')
235 | #plt.savefig(str(point_n) + '.png')
236 | plt.cla()
237 |
238 | if model.classification:
239 | result = sess.run([model.metaval_total_accuracy1] + model.metaval_total_accuracies2, feed_dict)
240 | else:
241 | result = sess.run([model.total_loss1] + model.total_losses2, feed_dict)
242 | metaval_accuracies.append(result)
243 |
244 | metaval_accuracies = np.array(metaval_accuracies)
245 | means = np.mean(metaval_accuracies, 0)
246 | stds = np.std(metaval_accuracies, 0)
247 | ci95 = 1.96*stds/np.sqrt(NUM_TEST_POINTS)
248 |
249 | print('Mean validation accuracy/loss, stddev, and confidence intervals')
250 | print((means, stds, ci95))
251 | filename = FLAGS.logdir + '/' + exp_string + '/' + 'test_ubs' + str(FLAGS.update_batch_size) + \
252 | '_stepsize' + str(FLAGS.update_lr) + '_testiter' + str(FLAGS.test_iter)
253 | with open(filename + '.pkl', 'w') as f:
254 | pickle.dump({'mses': metaval_accuracies}, f)
255 | with open(filename + '.csv', 'w') as f:
256 | writer = csv.writer(f, delimiter=',')
257 | writer.writerow(['update'+str(i) for i in range(len(means))])
258 | writer.writerow(means)
259 | writer.writerow(stds)
260 | writer.writerow(ci95)
261 |
262 |
263 | def main():
264 | if FLAGS.datasource in ['sinusoid', 'polynomial']:
265 | if FLAGS.train:
266 | test_num_updates = 5
267 | else:
268 | test_num_updates = 10
269 | elif FLAGS.datasource == 'miniimagenet':
270 | if FLAGS.train:
271 | test_num_updates = 1 # eval on at least one update during training
272 | else:
273 | test_num_updates = 10
274 | else:
275 | test_num_updates = 10
276 |
277 | if not FLAGS.train:
278 | orig_meta_batch_size = FLAGS.meta_batch_size
279 | # always use meta batch size of 1 when testing.
280 | FLAGS.meta_batch_size = 1
281 |
282 | if FLAGS.datasource == 'sinusoid':
283 | #data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size)
284 | # Use 10 val samples (meta-SGD, 4.1 paragraph 2 first line)
285 | data_generator = DataGenerator(FLAGS.update_batch_size+10, FLAGS.meta_batch_size)
286 | elif FLAGS.datasource == 'polynomial':
287 | if FLAGS.train:
288 | data_generator = PolyDataGenerator(FLAGS.update_batch_size+10, FLAGS.meta_batch_size)
289 | else:
290 | data_generator = PolyDataGenerator(4000, FLAGS.meta_batch_size)
291 | elif FLAGS.metatrain_iterations == 0 and FLAGS.datasource == 'miniimagenet':
292 | assert FLAGS.meta_batch_size == 1
293 | assert FLAGS.update_batch_size == 1
294 | data_generator = DataGenerator(1, FLAGS.meta_batch_size) # only use one datapoint,
295 | elif FLAGS.datasource == 'miniimagenet': # TODO - use 15 val examples for imagenet?
296 | if FLAGS.train:
297 | data_generator = DataGenerator(FLAGS.update_batch_size+15, FLAGS.meta_batch_size) # only use one datapoint for testing to save memory
298 | else:
299 | data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size) # only use one datapoint for testing to save memory
300 | else:
301 | assert FLAGS.datasource == 'omniglot'
302 | data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size) # only use one datapoint for testing to save memory
303 |
304 | dim_output = data_generator.dim_output
305 | if FLAGS.baseline == 'oracle':
306 | assert FLAGS.datasource == 'sinusoid'
307 | dim_input = 3
308 | FLAGS.pretrain_iterations += FLAGS.metatrain_iterations
309 | FLAGS.metatrain_iterations = 0
310 | else:
311 | dim_input = data_generator.dim_input
312 |
313 | if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'omniglot':
314 | tf_data_load = True
315 | num_classes = data_generator.num_classes
316 |
317 | if FLAGS.train: # only construct training model if needed
318 | random.seed(5)
319 | image_tensor, label_tensor = data_generator.make_data_tensor()
320 | inputa = tf.slice(image_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1])
321 | inputb = tf.slice(image_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1])
322 | labela = tf.slice(label_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1])
323 | labelb = tf.slice(label_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1])
324 | input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb}
325 |
326 | random.seed(6)
327 | image_tensor, label_tensor = data_generator.make_data_tensor(train=False)
328 | inputa = tf.slice(image_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1])
329 | inputb = tf.slice(image_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1])
330 | labela = tf.slice(label_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1])
331 | labelb = tf.slice(label_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1])
332 | metaval_input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb}
333 | else:
334 | input_tensors = None
335 | tf_data_load = False
336 |
337 | model = MAML(dim_input, dim_output, test_num_updates=test_num_updates)
338 | if FLAGS.train or not tf_data_load:
339 | model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
340 | if tf_data_load:
341 | model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_')
342 | model.summ_op = tf.summary.merge_all()
343 |
344 | saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=3)
345 |
346 | sess = tf.InteractiveSession()
347 |
348 | if not FLAGS.train:
349 | # change to original meta batch size when loading model.
350 | FLAGS.meta_batch_size = orig_meta_batch_size
351 |
352 | if FLAGS.train_update_batch_size == -1:
353 | FLAGS.train_update_batch_size = FLAGS.update_batch_size
354 | if FLAGS.train_update_lr == -1:
355 | FLAGS.train_update_lr = FLAGS.update_lr
356 |
357 | exp_string = 'cls_'+str(FLAGS.num_classes)+\
358 | '.mbs_'+str(FLAGS.meta_batch_size) + \
359 | '.ubs_' + str(FLAGS.train_update_batch_size) + \
360 | '.numstep' + str(FLAGS.num_updates) + \
361 | '.updatelr' + str(FLAGS.train_update_lr) + \
362 | '.temp' + str(FLAGS.temp)
363 |
364 | if FLAGS.debug:
365 | exp_string += '!DEBUG!'
366 |
367 | if FLAGS.use_T and FLAGS.use_M and FLAGS.share_M:
368 | exp_string += 'MTnet'
369 | if FLAGS.use_T and not FLAGS.use_M:
370 | exp_string += 'Tnet'
371 | if not FLAGS.use_T and FLAGS.use_M and FLAGS.share_M:
372 | exp_string += 'Mnet'
373 | if FLAGS.use_T and FLAGS.use_M and not FLAGS.share_M:
374 | exp_string += 'MTnet_noshare'
375 | if not FLAGS.use_T and FLAGS.use_M and not FLAGS.share_M:
376 | exp_string += 'Mnet_noshare'
377 | if not FLAGS.use_T and not FLAGS.use_M:
378 | exp_string += 'MAML'
379 |
380 | if FLAGS.datasource == 'polynomial':
381 | exp_string += 'ord' + str(FLAGS.poly_order)
382 | if FLAGS.num_train_classes != -1:
383 | exp_string += 'ntc' + str(FLAGS.num_train_classes)
384 | if FLAGS.num_filters != 64:
385 | exp_string += 'hidden' + str(FLAGS.num_filters)
386 | if FLAGS.max_pool:
387 | exp_string += 'maxpool'
388 | if FLAGS.stop_grad:
389 | exp_string += 'stopgrad'
390 | if FLAGS.baseline:
391 | exp_string += FLAGS.baseline
392 | if FLAGS.norm == 'batch_norm':
393 | exp_string += 'batchnorm'
394 | elif FLAGS.norm == 'layer_norm':
395 | exp_string += 'layernorm'
396 | elif FLAGS.norm == 'None':
397 | exp_string += 'nonorm'
398 | else:
399 | print('Norm setting not recognized.')
400 |
401 | resume_itr = 0
402 | tf.global_variables_initializer().run()
403 | tf.train.start_queue_runners()
404 |
405 | if FLAGS.resume or not FLAGS.train:
406 | model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string)
407 | if FLAGS.test_iter > 0:
408 | model_file = model_file[:model_file.index('model')] + 'model' + str(FLAGS.test_iter)
409 | if model_file:
410 | ind1 = model_file.index('model')
411 | resume_itr = int(model_file[ind1+5:])
412 | print("Restoring model weights from " + model_file)
413 | saver.restore(sess, model_file)
414 |
415 |     print(flags.FLAGS.__flags)
416 |     print(exp_string)
417 |
418 | if FLAGS.train:
419 | train(model, saver, sess, exp_string, data_generator, resume_itr)
420 | else:
421 | test(model, saver, sess, exp_string, data_generator, test_num_updates)
422 |
423 |
424 | if __name__ == "__main__":
425 | main()
426 |
--------------------------------------------------------------------------------
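Note that checkpoints and summaries are written under `FLAGS.logdir/<exp_string>`, where `exp_string` is assembled from the flags as shown in `main()` above. For example, the 5-way Omniglot MT-net run in `experiments/omniglot.sh` would (roughly, assuming the default `temp=1.0` and `norm=batch_norm`) end up in a directory like:

```
logs/omniglot5way/cls_5.mbs_32.ubs_1.numstep1.updatelr0.4.temp1.0MTnetbatchnorm
```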
/maml.py:
--------------------------------------------------------------------------------
1 | """ Code for the MAML algorithm and network definitions. """
2 | import numpy as np
3 |
4 | try:
5 | import special_grads
6 | except KeyError as e:
7 |     print('WARNING: Cannot define MaxPoolGrad, likely already defined for this version of TensorFlow:', e)
8 | import tensorflow as tf
9 |
10 | from tensorflow.python.platform import flags
11 | from utils import mse, xent, conv_block, normalize
12 |
13 | FLAGS = flags.FLAGS
14 |
15 |
16 | class MAML:
17 | def __init__(self, dim_input=1, dim_output=1, test_num_updates=5):
18 | """ must call construct_model() after initializing MAML! """
19 | self.dim_input = dim_input
20 | self.dim_output = dim_output
21 | self.update_lr = FLAGS.update_lr
22 | self.meta_lr = tf.placeholder_with_default(FLAGS.meta_lr, ())
23 | self.classification = False
24 | self.test_num_updates = test_num_updates
25 | if FLAGS.datasource in ['sinusoid', 'polynomial']:
26 | self.dim_hidden = [FLAGS.dim_hidden, FLAGS.dim_hidden]
27 | if FLAGS.use_T:
28 | self.forward = self.forward_fc_withT
29 | else:
30 | self.forward = self.forward_fc
31 | self.construct_weights = self.construct_fc_weights
32 | self.loss_func = mse
33 | elif FLAGS.datasource == 'omniglot' or FLAGS.datasource == 'miniimagenet':
34 | self.loss_func = xent
35 | self.classification = True
36 | if FLAGS.conv:
37 | self.dim_hidden = FLAGS.num_filters
38 | if FLAGS.use_T:
39 | self.forward = self.forward_conv_withT
40 | else:
41 | self.forward = self.forward_conv
42 | self.construct_weights = self.construct_conv_weights
43 | else:
44 | self.dim_hidden = [256, 128, 64, 64]
45 | self.forward = self.forward_fc
46 | self.construct_weights = self.construct_fc_weights
47 | if FLAGS.datasource == 'miniimagenet':
48 | self.channels = 3
49 | else:
50 | self.channels = 1
51 | self.img_size = int(np.sqrt(self.dim_input / self.channels))
52 | else:
53 | raise ValueError('Unrecognized data source.')
54 |
55 | def construct_model(self, input_tensors=None, prefix='metatrain_'):
56 | # a: training data for inner gradient, b: test data for meta gradient
57 | if input_tensors is None:
58 | if 'inputa' not in dir(self):
59 | self.inputa = tf.placeholder(tf.float32)
60 | self.inputb = tf.placeholder(tf.float32)
61 | self.labela = tf.placeholder(tf.float32)
62 | self.labelb = tf.placeholder(tf.float32)
63 | else:
64 | self.inputa = input_tensors['inputa']
65 | self.inputb = input_tensors['inputb']
66 | self.labela = input_tensors['labela']
67 | self.labelb = input_tensors['labelb']
68 |
69 | with tf.variable_scope('model', reuse=None) as training_scope:
70 | self.dropout_probs = {}
71 | if 'weights' in dir(self):
72 | training_scope.reuse_variables()
73 | weights = self.weights
74 | else:
75 | # Define the weights
76 | self.weights = weights = self.construct_weights()
77 |
78 | # outputbs[i] and lossesb[i] is the output and loss after i+1 gradient updates
79 | lossesa, outputas, lossesb, outputbs = [], [], [], []
80 | accuraciesa, accuraciesb = [], []
81 | num_updates = max(self.test_num_updates, FLAGS.num_updates)
82 | outputbs = [[]] * num_updates
83 | lossesb = [[]] * num_updates
84 | accuraciesb = [[]] * num_updates
85 |
86 | def task_metalearn(inp, reuse=True):
87 | """ Perform gradient descent for one task in the meta-batch. """
88 | inputa, inputb, labela, labelb = inp
89 | task_outputbs, task_lossesb = [], []
90 | mse_lossesb = []
91 |
92 | if self.classification:
93 | task_accuraciesb = []
94 |
95 | train_keys = list(weights.keys())
96 | if FLAGS.use_M and FLAGS.share_M:
97 | def make_shared_mask(key):
98 | temperature = FLAGS.temp
99 | logits = weights[key+'_prob']
100 | logits = tf.stack([logits, tf.zeros(logits.shape)], 1)
101 | U = tf.random_uniform(logits.shape, minval=0, maxval=1)
102 | gumbel = -tf.log(-tf.log(U + 1e-20) + 1e-20)
103 | y = logits + gumbel
104 | gumbel_softmax = tf.nn.softmax(y / temperature)
105 | gumbel_hard = tf.cast(tf.equal(gumbel_softmax, tf.reduce_max(gumbel_softmax, 1, keep_dims=True)), tf.float32)
106 | mask = tf.stop_gradient(gumbel_hard - gumbel_softmax) + gumbel_softmax
107 | return mask[:, 0]
108 |
109 | def get_mask(masks, name):
110 | mask = masks[[k for k in masks.keys() if name[-1] in k][0]]
111 | if 'conv' in name: # Conv
112 | mask = tf.reshape(mask, [1, 1, 1, -1])
113 | tile_size = weights[name].shape.as_list()[:3] + [1]
114 | mask = tf.tile(mask, tile_size)
115 | elif 'w' in name: # FC
116 | mask = tf.reshape(mask, [1, -1])
117 | tile_size = weights[name].shape.as_list()[:1] + [1]
118 | mask = tf.tile(mask, tile_size)
119 | elif 'b' in name: # Bias
120 | mask = tf.reshape(mask, [-1])
121 | return mask
122 | if self.classification:
123 | masks = {k: make_shared_mask(k) for k in ['conv1', 'conv2', 'conv3', 'conv4', 'w5']}
124 | else:
125 | masks = {k: make_shared_mask(k) for k in ['w1', 'w2', 'w3']}
126 |
127 | if FLAGS.use_M and not FLAGS.share_M:
128 | def get_mask_noshare(key):
129 | temperature = FLAGS.temp
130 | logits = weights[key + '_prob']
131 | logits = tf.stack([logits, tf.zeros(logits.shape)], 1)
132 | U = tf.random_uniform(logits.shape, minval=0, maxval=1)
133 | gumbel = -tf.log(-tf.log(U + 1e-20) + 1e-20)
134 | y = logits + gumbel
135 | gumbel_softmax = tf.nn.softmax(y / temperature)
136 | gumbel_hard = tf.cast(tf.equal(gumbel_softmax, tf.reduce_max(gumbel_softmax, 1, keep_dims=True)), tf.float32)
137 | out = tf.stop_gradient(gumbel_hard - gumbel_softmax) + gumbel_softmax
138 | return tf.reshape(out[:, 0], weights[key].shape)
139 |
140 | train_keys = [k for k in weights.keys() if 'prob' not in k and 'f' not in k]
141 | train_weights = [weights[k] for k in train_keys]
142 | task_outputa = self.forward(inputa, weights, reuse=reuse) # only reuse on the first iter
143 | self.task_outputa = task_outputa
144 | task_lossa = self.loss_func(task_outputa, labela)
145 | grads = tf.gradients(task_lossa, train_weights)
146 | if FLAGS.stop_grad:
147 | grads = [tf.stop_gradient(grad) for grad in grads]
148 | gradients = dict(zip(train_keys, grads))
149 |
150 | fast_weights = dict(zip(weights.keys(), [weights[key] for key in weights.keys()]))
151 |
152 | def compute_weights(key):
153 | prev_weights = fast_weights[key]
154 | if key not in train_keys:
155 | return prev_weights
156 | if FLAGS.use_M and FLAGS.share_M:
157 | mask = get_mask(masks, key)
158 | new_weights = prev_weights - self.update_lr * mask * gradients[key]
159 | elif FLAGS.use_M and not FLAGS.share_M:
160 | mask = get_mask_noshare(key)
161 | new_weights = prev_weights - self.update_lr * mask * gradients[key]
162 | else:
163 | new_weights = prev_weights - self.update_lr * gradients[key]
164 | return new_weights
165 |
166 | fast_weights = dict(zip(
167 | weights.keys(), [compute_weights(key) for key in weights.keys()]))
168 |
169 | output = self.forward(inputb, fast_weights, reuse=True)
170 | task_outputbs.append(output)
171 | loss = self.loss_func(output, labelb)
172 | task_lossesb.append(loss)
173 |
174 | for j in range(num_updates - 1):
175 | output = self.forward(inputa, fast_weights, reuse=True)
176 | loss = self.loss_func(output, labela)
177 | train_weights = [fast_weights[k] for k in train_keys]
178 | grads = tf.gradients(loss, train_weights)
179 | if FLAGS.stop_grad:
180 | grads = [tf.stop_gradient(grad) for grad in grads]
181 | gradients = dict(zip(train_keys, grads))
182 |
183 | fast_weights = dict(zip(
184 | weights.keys(), [compute_weights(key) for key in weights.keys()]))
185 |
186 | output = self.forward(inputb, fast_weights, reuse=True)
187 | task_outputbs.append(output)
188 | loss = self.loss_func(output, labelb)
189 | task_lossesb.append(loss)
190 |
191 | task_output = [task_outputa, task_outputbs, task_lossa, task_lossesb]
192 |
193 | if self.classification:
194 | task_accuracya = tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(task_outputa), 1),
195 | tf.argmax(labela, 1))
196 | for j in range(num_updates):
197 | task_accuraciesb.append(
198 | tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(task_outputbs[j]), 1),
199 | tf.argmax(labelb, 1)))
200 | task_output.extend([task_accuracya, task_accuraciesb])
201 |
202 | return task_output
203 |
204 |         if FLAGS.norm != 'None':
205 | # to initialize the batch norm vars, might want to combine this, and not run idx 0 twice.
206 | unused = task_metalearn((self.inputa[0], self.inputb[0], self.labela[0], self.labelb[0]), False)
207 |
208 | out_dtype = [tf.float32, [tf.float32] * num_updates, tf.float32, [tf.float32] * num_updates]
209 | if self.classification:
210 | out_dtype.extend([tf.float32, [tf.float32] * num_updates])
211 | result = tf.map_fn(task_metalearn, elems=(self.inputa, self.inputb, self.labela, self.labelb),
212 | dtype=out_dtype, parallel_iterations=FLAGS.meta_batch_size)
213 | if self.classification:
214 | outputas, outputbs, lossesa, lossesb, accuraciesa, accuraciesb = result
215 | else:
216 | outputas, outputbs, lossesa, lossesb = result
217 |
218 | logit_keys = sorted([k for k in weights.keys() if 'prob' in k])
219 | logit_weights = [-weights[k] for k in logit_keys]
220 | probs = [tf.exp(w) / (1 + tf.exp(w)) for w in logit_weights]
221 | self.total_probs = [tf.reduce_mean(p) for p in probs]
222 |
223 | ## Performance & Optimization
224 | if 'train' in prefix:
225 | self.total_loss1 = total_loss1 = tf.reduce_sum(lossesa) / tf.to_float(FLAGS.meta_batch_size)
226 | self.total_losses2 = total_losses2 = [tf.reduce_sum(lossesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j
227 | in range(num_updates)]
228 | # after the map_fn
229 | self.outputas, self.outputbs = outputas, outputbs
230 | if self.classification:
231 | self.total_accuracy1 = total_accuracy1 = tf.reduce_sum(accuraciesa) / tf.to_float(FLAGS.meta_batch_size)
232 | self.total_accuracies2 = total_accuracies2 = [
233 | tf.reduce_sum(accuraciesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)]
234 | self.pretrain_op = tf.train.AdamOptimizer(self.meta_lr).minimize( total_loss1)
235 |
236 | if FLAGS.metatrain_iterations > 0:
237 | optimizer = tf.train.AdamOptimizer(self.meta_lr)
238 | loss = self.total_losses2[FLAGS.num_updates - 1]
239 | self.gvs = gvs = optimizer.compute_gradients(loss)
240 | if FLAGS.datasource == 'miniimagenet':
241 | gvs = [(tf.clip_by_value(grad, -10, 10), var) for grad, var in gvs]
242 | self.metatrain_op = optimizer.apply_gradients(gvs)
243 |
244 | else:
245 | self.metaval_total_loss1 = total_loss1 = tf.reduce_sum(lossesa) / tf.to_float(FLAGS.meta_batch_size)
246 | self.metaval_total_losses2 = total_losses2 = [tf.reduce_sum(lossesb[j]) / tf.to_float(FLAGS.meta_batch_size)
247 | for j in range(num_updates)]
248 | if self.classification:
249 | self.metaval_total_accuracy1 = total_accuracy1 = tf.reduce_sum(accuraciesa) / tf.to_float(
250 | FLAGS.meta_batch_size)
251 | self.metaval_total_accuracies2 = total_accuracies2 = [
252 | tf.reduce_sum(accuraciesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)]
253 |
254 | ## Summaries
255 | tf.summary.scalar(prefix + 'change probs', tf.reduce_mean(self.total_probs))
256 | tf.summary.scalar(prefix + 'Pre-update loss', total_loss1)
257 | if self.classification:
258 | tf.summary.scalar(prefix + 'Pre-update accuracy', total_accuracy1)
259 |
260 | for j in range(num_updates):
261 | tf.summary.scalar(prefix + 'Post-update loss, step ' + str(j + 1), total_losses2[j])
262 | if self.classification:
263 | tf.summary.scalar(prefix + 'Post-update accuracy, step ' + str(j + 1), total_accuracies2[j])
264 |
265 |         for k, v in weights.items():
266 | tf.summary.histogram(k, v)
267 | if 'prob' in k:
268 | tf.summary.histogram('prob_'+k, tf.nn.softmax(tf.stack([v, tf.zeros(v.shape)], 1))[:, 0])
269 |
270 | ### Network construction functions (fc networks and conv networks)
271 | def construct_fc_weights(self):
272 | weights = {}
273 | weights['w1'] = tf.Variable(tf.truncated_normal([self.dim_input, self.dim_hidden[0]], stddev=0.01))
274 | weights['b1'] = tf.Variable(tf.zeros([self.dim_hidden[0]]))
275 | for i in range(1, len(self.dim_hidden)):
276 | weights['w' + str(i + 1)] = tf.Variable(
277 | tf.truncated_normal([self.dim_hidden[i - 1], self.dim_hidden[i]], stddev=0.01))
278 | weights['b' + str(i + 1)] = tf.Variable(tf.zeros([self.dim_hidden[i]]))
279 | weights['w' + str(len(self.dim_hidden) + 1)] = tf.Variable(
280 | tf.truncated_normal([self.dim_hidden[-1], self.dim_output], stddev=0.01))
281 | weights['b' + str(len(self.dim_hidden) + 1)] = tf.Variable(tf.zeros([self.dim_output]))
282 |
283 | if FLAGS.use_M and not FLAGS.share_M:
284 | weights['w1_prob'] = tf.Variable(tf.truncated_normal([self.dim_input * self.dim_hidden[0]], stddev=.1))
285 | weights['b1_prob'] = tf.Variable(tf.truncated_normal([self.dim_hidden[0]], stddev=.1))
286 | for i in range(1, len(self.dim_hidden)):
287 | weights['w' + str(i + 1) + '_prob'] = tf.Variable(
288 | tf.truncated_normal([self.dim_hidden[i - 1] * self.dim_hidden[i]], stddev=.1))
289 | weights['b' + str(i + 1) + '_prob'] = tf.Variable(
290 | tf.truncated_normal([self.dim_hidden[i]], stddev=.1))
291 | weights['w' + str(len(self.dim_hidden) + 1) + '_prob'] = tf.Variable(
292 | tf.truncated_normal([self.dim_hidden[-1] * self.dim_output], stddev=0.1))
293 | weights['b' + str(len(self.dim_hidden) + 1) + '_prob'] = tf.Variable(
294 | tf.truncated_normal([self.dim_output], stddev=.1))
295 | elif FLAGS.use_M and FLAGS.share_M:
296 | weights['w1_prob'] = tf.Variable(FLAGS.logit_init * tf.ones([self.dim_hidden[0]]))
297 | for i in range(1, len(self.dim_hidden)):
298 | weights['w' + str(i + 1) + '_prob'] = tf.Variable(
299 | FLAGS.logit_init * tf.ones([self.dim_hidden[i]]))
300 | weights['w' + str(len(self.dim_hidden) + 1) + '_prob'] = tf.Variable(
301 | FLAGS.logit_init * tf.ones([self.dim_output]))
302 |
303 | if FLAGS.use_T:
304 | weights['w1_f'] = tf.Variable(tf.eye(self.dim_hidden[0]))
305 | weights['w2_f'] = tf.Variable(tf.eye(self.dim_hidden[1]))
306 | weights['w3_f'] = tf.Variable(tf.eye(self.dim_output))
307 | return weights
308 |
309 | def forward_fc(self, inp, weights, reuse=False):
310 | hidden = normalize(tf.matmul(inp, weights['w1']) + weights['b1'],
311 | activation=tf.nn.relu, reuse=reuse, scope='0')
312 | for i in range(1, len(self.dim_hidden)):
313 | hidden = normalize(tf.matmul(hidden, weights['w' + str(i + 1)]) + weights['b' + str(i + 1)],
314 | activation=tf.nn.relu, reuse=reuse, scope=str(i + 1))
315 | return tf.matmul(hidden, weights['w' + str(len(self.dim_hidden) + 1)]) + \
316 | weights['b' + str(len(self.dim_hidden) + 1)]
317 |
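    |     # Fully connected forward pass with transformations: each layer's linear output is
    |     # multiplied by its 'w*_f' matrix. This path hard-codes two hidden layers (w1-w3).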
318 | def forward_fc_withT(self, inp, weights, reuse=False):
319 | hidden = tf.matmul(tf.matmul(inp, weights['w1']) + weights['b1'], weights['w1_f'])
320 | hidden = normalize(hidden, activation=tf.nn.relu, reuse=reuse, scope='1')
321 | hidden = tf.matmul(tf.matmul(hidden, weights['w2']) + weights['b2'], weights['w2_f'])
322 | hidden = normalize(hidden, activation=tf.nn.relu, reuse=reuse, scope='2')
323 | hidden = tf.matmul(tf.matmul(hidden, weights['w3']) + weights['b3'], weights['w3_f'])
324 | return hidden
325 |
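    |     # Convolutional weights. With use_M, the mask logits are one per parameter when
    |     # share_M is False, and one per output filter/unit when share_M is True.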
326 | def construct_conv_weights(self):
327 | weights = {}
328 | dtype = tf.float32
329 | conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype)
330 | fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype)
331 | k = 3
332 | channels = self.channels
333 | dim_hidden = self.dim_hidden
334 |
335 | def get_conv(name, shape):
336 | return tf.get_variable(name, shape, initializer=conv_initializer, dtype=dtype)
337 |
338 | def get_identity(dim, conv=True):
339 | return tf.Variable(tf.eye(dim, batch_shape=[1,1])) if conv \
340 | else tf.Variable(tf.eye(dim))
341 |
342 | weights['conv1'] = get_conv('conv1', [k, k, channels, self.dim_hidden])
343 | weights['b1'] = tf.Variable(tf.zeros([self.dim_hidden]))
344 | weights['conv2'] = get_conv('conv2', [k, k, dim_hidden, self.dim_hidden])
345 | weights['b2'] = tf.Variable(tf.zeros([self.dim_hidden]))
346 | weights['conv3'] = get_conv('conv3', [k, k, dim_hidden, self.dim_hidden])
347 | weights['b3'] = tf.Variable(tf.zeros([self.dim_hidden]))
348 | weights['conv4'] = get_conv('conv4', [k, k, dim_hidden, self.dim_hidden])
349 | weights['b4'] = tf.Variable(tf.zeros([self.dim_hidden]))
350 | if FLAGS.datasource == 'miniimagenet':
351 | # assumes max pooling
352 | assert FLAGS.max_pool
353 | weights['w5'] = tf.get_variable('w5', [self.dim_hidden * 5 * 5, self.dim_output],
354 | initializer=fc_initializer)
355 | weights['b5'] = tf.Variable(tf.zeros([self.dim_output]), name='b5')
356 |
357 | if FLAGS.use_M and not FLAGS.share_M:
358 | weights['conv1_prob'] = tf.Variable(tf.truncated_normal([k * k * channels * self.dim_hidden], stddev=.01))
359 | weights['b1_prob'] = tf.Variable(tf.truncated_normal([self.dim_hidden], stddev=.01))
360 | weights['conv2_prob'] = tf.Variable(tf.truncated_normal([k * k * dim_hidden * self.dim_hidden], stddev=.01))
361 | weights['b2_prob'] = tf.Variable(tf.truncated_normal([self.dim_hidden], stddev=.01))
362 | weights['conv3_prob'] = tf.Variable(tf.truncated_normal([k * k * dim_hidden * self.dim_hidden], stddev=.01))
363 | weights['b3_prob'] = tf.Variable(tf.truncated_normal([self.dim_hidden], stddev=.01))
364 | weights['conv4_prob'] = tf.Variable(tf.truncated_normal([k * k * dim_hidden * self.dim_hidden], stddev=.01))
365 | weights['b4_prob'] = tf.Variable(tf.truncated_normal([self.dim_hidden], stddev=.01))
366 |                 weights['w5_prob'] = tf.Variable(tf.truncated_normal([dim_hidden * 5 * 5 * self.dim_output], stddev=.01))
367 | weights['b5_prob'] = tf.Variable(tf.truncated_normal([self.dim_output], stddev=.01))
368 | if FLAGS.use_M and FLAGS.share_M:
369 | weights['conv1_prob'] = tf.Variable(FLAGS.logit_init * tf.ones([self.dim_hidden]))
370 | weights['conv2_prob'] = tf.Variable(FLAGS.logit_init * tf.ones([self.dim_hidden]))
371 | weights['conv3_prob'] = tf.Variable(FLAGS.logit_init * tf.ones([self.dim_hidden]))
372 | weights['conv4_prob'] = tf.Variable(FLAGS.logit_init * tf.ones([self.dim_hidden]))
373 | weights['w5_prob'] = tf.Variable(FLAGS.logit_init * tf.ones([self.dim_output]))
374 |
375 | if FLAGS.use_T:
376 | weights['conv1_f'] = get_identity(self.dim_hidden, conv=True)
377 | weights['conv2_f'] = get_identity(self.dim_hidden, conv=True)
378 | weights['conv3_f'] = get_identity(self.dim_hidden, conv=True)
379 | weights['conv4_f'] = get_identity(self.dim_hidden, conv=True)
380 | weights['w5_f'] = get_identity(self.dim_output, conv=False)
381 | else:
382 | weights['w5'] = tf.Variable(tf.random_normal([dim_hidden, self.dim_output]), name='w5')
383 | weights['b5'] = tf.Variable(tf.zeros([self.dim_output]), name='b5')
384 | if FLAGS.use_M and not FLAGS.share_M:
385 | weights['conv1_prob'] = tf.Variable(tf.truncated_normal([k * k * channels * self.dim_hidden], stddev=.01))
386 | weights['conv2_prob'] = tf.Variable(tf.truncated_normal([k * k * dim_hidden * self.dim_hidden], stddev=.01))
387 | weights['conv3_prob'] = tf.Variable(tf.truncated_normal([k * k * dim_hidden * self.dim_hidden], stddev=.01))
388 | weights['conv4_prob'] = tf.Variable(tf.truncated_normal([k * k * dim_hidden * self.dim_hidden], stddev=.01))
389 | weights['w5_prob'] = tf.Variable(tf.truncated_normal([dim_hidden * self.dim_output], stddev=.01))
390 | if FLAGS.use_M and FLAGS.share_M:
391 | weights['conv1_prob'] = tf.Variable(FLAGS.logit_init * tf.ones([self.dim_hidden]))
392 | weights['conv2_prob'] = tf.Variable(FLAGS.logit_init * tf.ones([self.dim_hidden]))
393 | weights['conv3_prob'] = tf.Variable(FLAGS.logit_init * tf.ones([self.dim_hidden]))
394 | weights['conv4_prob'] = tf.Variable(FLAGS.logit_init * tf.ones([self.dim_hidden]))
395 | weights['w5_prob'] = tf.Variable(FLAGS.logit_init * tf.ones([self.dim_output]))
396 |
397 | if FLAGS.use_T:
398 | weights['conv1_f'] = get_identity(self.dim_hidden, conv=True)
399 | weights['conv2_f'] = get_identity(self.dim_hidden, conv=True)
400 | weights['conv3_f'] = get_identity(self.dim_hidden, conv=True)
401 | weights['conv4_f'] = get_identity(self.dim_hidden, conv=True)
402 | weights['w5_f'] = get_identity(self.dim_output, conv=False)
403 | return weights
404 |
405 | def forward_conv(self, inp, weights, reuse=False, scope=''):
406 | # reuse is for the normalization parameters.
407 | channels = self.channels
408 | inp = tf.reshape(inp, [-1, self.img_size, self.img_size, channels])
409 | hidden1 = conv_block(inp, weights['conv1'], weights['b1'], reuse, scope + '0')
410 | hidden2 = conv_block(hidden1, weights['conv2'], weights['b2'], reuse, scope + '1')
411 | hidden3 = conv_block(hidden2, weights['conv3'], weights['b3'], reuse, scope + '2')
412 | hidden4 = conv_block(hidden3, weights['conv4'], weights['b4'], reuse, scope + '3')
413 |
414 | if FLAGS.datasource == 'miniimagenet':
415 | # last hidden layer is 6x6x64-ish, reshape to a vector
416 | hidden4 = tf.reshape(hidden4, [-1, np.prod([int(dim) for dim in hidden4.get_shape()[1:]])])
417 | else:
418 | hidden4 = tf.reduce_mean(hidden4, [1, 2])
419 | return tf.matmul(hidden4, weights['w5']) + weights['b5']
420 |
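    |     # Convolutional forward pass with transformations: conv_tout applies each layer's
    |     # 'conv*_f' matrix as a 1x1 convolution after the main convolution, and the final
    |     # logits are multiplied by 'w5_f'.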
421 | def forward_conv_withT(self, inp, weights, reuse=False, scope=''):
422 | # reuse is for the normalization parameters.
423 | def conv_tout(inp, cweight, bweight, rweight, reuse, scope, activation=tf.nn.relu, max_pool_pad='VALID',
424 | residual=False):
425 | stride, no_stride = [1, 2, 2, 1], [1, 1, 1, 1]
426 | if FLAGS.max_pool:
427 | conv_output = tf.nn.conv2d(inp, cweight, no_stride, 'SAME') + bweight
428 | else:
429 | conv_output = tf.nn.conv2d(inp, cweight, stride, 'SAME') + bweight
430 | conv_output = tf.nn.conv2d(conv_output, rweight, no_stride, 'SAME')
431 | normed = normalize(conv_output, activation, reuse, scope)
432 | if FLAGS.max_pool:
433 | normed = tf.nn.max_pool(normed, stride, stride, max_pool_pad)
434 | return normed
435 |
436 | channels = self.channels
437 | inp = tf.reshape(inp, [-1, self.img_size, self.img_size, channels])
438 | hidden1 = conv_tout(inp, weights['conv1'], weights['b1'], weights['conv1_f'], reuse, scope + '0')
439 | hidden2 = conv_tout(hidden1, weights['conv2'], weights['b2'], weights['conv2_f'], reuse, scope + '1')
440 | hidden3 = conv_tout(hidden2, weights['conv3'], weights['b3'], weights['conv3_f'], reuse, scope + '2')
441 | hidden4 = conv_tout(hidden3, weights['conv4'], weights['b4'], weights['conv4_f'], reuse, scope + '3')
442 |
443 | if FLAGS.datasource == 'miniimagenet':
444 | # last hidden layer is 6x6x64-ish, reshape to a vector
445 | hidden4 = tf.reshape(hidden4, [-1, np.prod([int(dim) for dim in hidden4.get_shape()[1:]])])
446 | else:
447 | hidden4 = tf.reduce_mean(hidden4, [1, 2])
448 | hidden5 = tf.matmul(hidden4, weights['w5']) + weights['b5']
449 | return tf.matmul(hidden5, weights['w5_f'])
450 |
--------------------------------------------------------------------------------