├── LICENSE ├── README.md ├── data.py ├── data ├── miniImagenet │ ├── test.csv │ ├── train.csv │ └── val.csv └── preprocess_mini_imagenet.py ├── experiment_builder.py ├── meta_matching_network.py ├── saved_models ├── LGM-Net_5way1shot.csv └── LGM-Net_5way5shot.csv ├── storage.py └── train_meta_matching_network.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 likesiwell 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LGM-Net 2 | TensorFlow source code for the following publication: 3 | > LGM-Net: Learning to Generate Matching Networks for Few-Shot Learning 4 | > 5 | > Huaiyu Li, Weiming Dong, Xing Mei, Chongyang Ma, Feiyue Huang, Bao-Gang Hu 6 | > 7 | > In *Proceedings of the 36th International Conference on Machine Learning (ICML 2019)* 8 | 9 | # Requirements 10 | - Python 3.5 11 | - [NumPy](http://www.numpy.org/) 12 | - [SciPy](https://www.scipy.org/) 13 | - [tqdm](https://pypi.python.org/pypi/tqdm) 14 | - [TensorFlow 1.4](https://www.tensorflow.org/install/) 15 | - [OpenCV 3.2.0](https://opencv.org/) 16 | 17 | ## Preparation 18 | Set the path to the resized (84x84) miniImageNet images (`resizetargetpath` in `load_miniImageNet`) in `data.py`; these images can be generated with `data/preprocess_mini_imagenet.py`. 19 | 20 | ## Train 21 | ``` 22 | python train_meta_matching_network.py --way 5 --shot 1 23 | 24 | python train_meta_matching_network.py --way 5 --shot 5 25 | ``` 26 | ## Test 27 | ``` 28 | python train_meta_matching_network.py --way 5 --shot 1 --is_test True --ckp checkpoint_id 29 | ``` 30 | 31 | ## Acknowledgements 32 | Thanks to [Antreas Antoniou](https://github.com/AntreasAntoniou/) for his [Matching Networks implementation](https://github.com/AntreasAntoniou/MatchingNetworks), parts of which were used in this implementation. 
33 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | from scipy.ndimage import rotate 5 | import scipy.misc 6 | from tqdm import tqdm 7 | import pandas as pd 8 | import time 9 | import matplotlib.pyplot as plt 10 | 11 | def load_miniImageNet(): 12 | """ 13 | :return: train, val, test tensor with shape (X, 600, 84, 84, 3) 14 | """ 15 | print("Loading resized MiniImageNet from jpg images") 16 | # resizetargetpath = '/home/hy/DataSets/miniImageNet/resizedminiImages/' 17 | resizetargetpath = './DataSets/miniImageNet/resizedminiImages/' 18 | csv_file_dir = './data/miniImagenet' 19 | def data_loader(csv_file): 20 | data = pd.read_csv(csv_file, sep=',') 21 | data = data.filename.tolist() 22 | img_list = [] 23 | for file in tqdm(data): 24 | img = scipy.misc.imread(resizetargetpath+file).astype(np.float32) 25 | img_list.append(img) 26 | imgs = np.concatenate(img_list, axis=0) 27 | imgs = np.reshape(imgs, [-1, 600, 84, 84, 3]) 28 | return imgs 29 | start = time.time() 30 | train = data_loader(os.path.join(csv_file_dir, 'train.csv')) 31 | val = data_loader(os.path.join(csv_file_dir, 'val.csv')) 32 | 33 | # fake data to accelerate testing without loading data 34 | # train = np.ones((64, 600, 84, 84, 3)) 35 | # val = np.ones((16, 600, 84, 84, 3)) 36 | test = data_loader(os.path.join(csv_file_dir, 'test.csv')) 37 | print("Loading from raw images data cost %.5f s" % (time.time()-start)) 38 | print(train.shape, val.shape, test.shape) 39 | 40 | return train, val, test 41 | 42 | 43 | class MiniImageNetDataSet(): 44 | def __init__(self, batch_size, classes_per_set=20, samples_per_class=5, seed=2591, shuffle_classes=False): 45 | ''' 46 | Construct a N-shot MiniImageNet Dataset 47 | :param batch_size: 48 | :param classes_per_set: 49 | :param samples_per_class: 50 | :param seed: 51 | :param shuffle_classes: 52 
| e.g. For a 20-way, 1-shot learning task, use classes_per_set=20 and samples_per_class=1 53 | For a 5-way, 10-shot learning task, use classes_per_set=5 and samples_per_class=10 54 | ''' 55 | np.random.seed(seed) 56 | self.x_train, self.x_val, self.x_test = load_miniImageNet() 57 | if shuffle_classes: 58 | class_ids = np.arange(self.x_train.shape[0]) 59 | np.random.shuffle(class_ids) 60 | self.x_train = self.x_train[class_ids] 61 | class_ids = np.arange(self.x_val.shape[0]) 62 | np.random.shuffle(class_ids) 63 | self.x_val = self.x_val[class_ids] 64 | class_ids = np.arange(self.x_test.shape[0]) 65 | np.random.shuffle(class_ids) 66 | self.x_test = self.x_test[class_ids] 67 | 68 | # self.mean = np.mean(list(self.x_train)+list(self.x_val)) 69 | self.mean = 113.77 # precomputed 70 | # self.std = np.std(list(self.x_train)+list(self.x_val)) 71 | self.std = 70.1899 # precomputed 72 | print("mean ", self.mean, " std ", self.std) 73 | self.batch_size = batch_size 74 | self.n_classes = 100 75 | self.classes_per_set = classes_per_set 76 | self.samples_per_class = samples_per_class 77 | self.indexes = {"train": 0, "val": 0, "test": 0} 78 | self.datasets = {"train": self.x_train, "val": self.x_val, "test": self.x_test} 79 | 80 | def preprocess_batch(self, x_batch): 81 | x_batch = (x_batch-self.mean) / self.std 82 | return x_batch 83 | 84 | def sample_new_batch(self, data_pack): 85 | """ 86 | Collect batches data for N-shot learning 87 | :param data_pack: Data pack to use (any one of train, val, test) 88 | :return: A list with [support_set_x, support_set_y, target_x, target_y] ready to be fed to our networks 89 | """ 90 | support_set_x = np.zeros((self.batch_size, self.classes_per_set, self.samples_per_class, data_pack.shape[2], 91 | data_pack.shape[3], data_pack.shape[4]), dtype=np.float32) 92 | support_set_y = np.zeros((self.batch_size, self.classes_per_set, self.samples_per_class), dtype=np.float32) 93 | target_x = np.zeros((self.batch_size, data_pack.shape[2], 
data_pack.shape[3], data_pack.shape[4]), 94 | dtype=np.float32) 95 | target_y = np.zeros((self.batch_size, ), dtype=np.float32) 96 | # for each task, there is only one target image for test, for example, 5-way-1-shot, 97 | # support set contains 5 images and target set contains 1 image. 98 | for i in range(self.batch_size): 99 | # Each idx in batch contains a task 100 | classes_idx = np.arange(data_pack.shape[0]) 101 | samples_idx = np.arange(data_pack.shape[1]) 102 | # not select replicate samples 103 | choose_classes = np.random.choice(classes_idx, size=self.classes_per_set, replace=False) 104 | choose_label = np.random.choice(self.classes_per_set, size=1) 105 | choose_samples = np.random.choice(samples_idx, size=self.samples_per_class+1, replace=False) 106 | 107 | # select out the chosen classes as the task labels, make sure the images are correct 108 | x_temp = data_pack[choose_classes] 109 | x_temp = x_temp[:, choose_samples] 110 | y_temp = np.arange(self.classes_per_set) 111 | support_set_x[i] = x_temp[:, :-1] 112 | support_set_y[i] = np.expand_dims(y_temp[:], axis=1) 113 | # the target of the one-shot learning task, only choose one labels 114 | target_x[i] = x_temp[choose_label, -1] 115 | target_y[i] = y_temp[choose_label] 116 | 117 | return support_set_x, support_set_y, target_x, target_y 118 | 119 | def get_batch(self, dataset_name, augment=False): 120 | """ 121 | Gets next batch from the dataset with name. 
122 | :param dataset_name: The name of the dataset (one of "train", "val", "test") 123 | :return: 124 | """ 125 | x_support_set, y_support_set, x_target, y_target = self.sample_new_batch(self.datasets[dataset_name]) 126 | if augment: 127 | # todo image data augmentation 128 | # k = np.random.randint(0, 4, size=(self.batch_size, self.classes_per_set)) 129 | k = np.random.choice(a=[-1, -0.25, 0, 0.25, 1], size=(self.batch_size, self.classes_per_set), replace=True) 130 | 131 | x_augmented_support_set = [] 132 | x_augmented_target_set = [] 133 | for b in range(self.batch_size): 134 | temp_class_support = [] 135 | 136 | for c in range(self.classes_per_set): 137 | x_temp_support_set = self.rotate_batch(x_support_set[b, c], axis=(1, 2), k=k[b, c]) 138 | if y_target[b] == y_support_set[b, c, 0]: 139 | x_temp_target = self.rotate_batch(x_target[b], axis=(0, 1), k=k[b, c]) 140 | 141 | temp_class_support.append(x_temp_support_set) 142 | 143 | x_augmented_support_set.append(temp_class_support) 144 | x_augmented_target_set.append(x_temp_target) 145 | 146 | x_support_set = np.array(x_augmented_support_set) 147 | x_target = np.array(x_augmented_target_set) 148 | x_support_set = self.preprocess_batch(x_support_set) 149 | x_target = self.preprocess_batch(x_target) 150 | 151 | return x_support_set, y_support_set, x_target, y_target 152 | 153 | def rotate_batch(self, x_batch, axis, k): 154 | # print(x_batch.shape, axis, k) 155 | # x_batch = rotate(x_batch, k*90, reshape=False, axes=axis, mode="nearest") 156 | x_batch = rotate(x_batch, k*45, reshape=False, axes=axis, mode="nearest") 157 | return x_batch 158 | 159 | def get_train_batch(self, augment=False): 160 | 161 | """ 162 | Get next training batch 163 | :return: Next training batch 164 | """ 165 | return self.get_batch("train", augment) 166 | 167 | def get_test_batch(self, augment=False): 168 | 169 | """ 170 | Get next test batch 171 | :return: Next test_batch 172 | """ 173 | return self.get_batch("test", augment) 174 | 175 | def 
get_val_batch(self, augment=False): 176 | 177 | """ 178 | Get next val batch 179 | :return: Next val batch 180 | """ 181 | return self.get_batch("val", augment) 182 | 183 | if __name__ == '__main__': 184 | mini = MiniImageNetDataSet(batch_size=32) 185 | x_support_set, y_support_set, x_target, y_target = mini.get_test_batch(augment=True) 186 | print(np.min(x_support_set), np.max(x_support_set)) # -1.62089 2.01211 187 | print(x_support_set.shape, y_support_set.shape, x_target.shape, y_target.shape) 188 | 189 | for i in range(10): 190 | for img in x_support_set[i, 2]: 191 | plt.imshow(img) 192 | plt.show() 193 | -------------------------------------------------------------------------------- /data/preprocess_mini_imagenet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import csv 3 | import glob, os 4 | from tqdm import tqdm 5 | from shutil import copyfile 6 | import cv2 7 | 8 | pathImageNet = './DataSets/ImageNet/' 9 | pathminiImageNet = './DataSets/miniImageNet/allimages' 10 | targetpath = './DataSets/miniImageNet/miniImages' 11 | resizetargetpath = './DataSets/miniImageNet/resizedminiImages' 12 | csv_file_dir = './miniImagenet' 13 | filesCSVSachinRavi = [os.path.join(csv_file_dir, 'train.csv'), 14 | os.path.join(csv_file_dir, 'val.csv'), 15 | os.path.join(csv_file_dir, 'test.csv')] 16 | 17 | for filename in filesCSVSachinRavi: 18 | with open(filename) as csvfile: 19 | csv_reader = csv.reader(csvfile, delimiter=',') 20 | next(csv_reader, None) 21 | images = {} 22 | print('Reading IDs....') 23 | for row in tqdm(csv_reader): 24 | if row[1] in images.keys(): 25 | images[row[1]].append(row[0]) 26 | else: 27 | images[row[1]] = [row[0]] 28 | 29 | print('Writing photos....') 30 | for c in tqdm(images.keys()): # Iterate over all the classes 31 | lst_files = [] 32 | for file in glob.glob(pathminiImageNet + "/*"+c+"*"): 33 | lst_files.append(file) # the absolute path 34 | # TODO: Sort by name of by index number 
of the image??? 35 | # I sort by the number of the image 36 | lst_index = [int(i[i.rfind('_')+1:i.rfind('.')]) for i in lst_files] 37 | index_sorted = sorted(range(len(lst_index)), key=lst_index.__getitem__) 38 | # print(images[c]) 39 | # 40 | # # Now iterate 41 | index_selected = [int(i[i.index('.') - 4:i.index('.')]) for i in images[c]] 42 | selected_images = np.array(index_sorted)[np.array(index_selected) - 1] 43 | # print("selected Images", selected_images) 44 | for i in np.arange(len(selected_images)): 45 | # read file and resize to 84x84x3 46 | im = cv2.imread(os.path.join(pathImageNet,lst_files[selected_images[i]])) 47 | im_resized = cv2.resize(im, (84, 84), interpolation=cv2.INTER_AREA) 48 | cv2.imwrite(os.path.join(resizetargetpath, images[c][i]), im_resized) 49 | copyfile(os.path.join(pathminiImageNet, lst_files[selected_images[i]]), os.path.join(targetpath, images[c][i])) 50 | -------------------------------------------------------------------------------- /experiment_builder.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tqdm 3 | from meta_matching_network import MetaMatchingNetwork 4 | from scipy.stats import mode 5 | import numpy as np 6 | 7 | class ExperimentBuilder: 8 | 9 | def __init__(self, data): 10 | """ 11 | Initializes an ExperimentBuilder object. The ExperimentBuilder object takes care of setting up our experiment 12 | and provides helper functions such as run_training_epoch and run_validation_epoch to simplify out training 13 | and evaluation procedures. 
14 | :param data: A data provider class 15 | """ 16 | self.data = data 17 | 18 | def build_experiment(self, batch_size, classes_per_set, samples_per_class, init_lr = 1e-3): 19 | """ 20 | :param batch_size: The experiment batch size 21 | :param classes_per_set: An integer indicating the number of classes per support set 22 | :param samples_per_class: An integer indicating the number of samples per class 23 | :param init_lr: The initial learning rate 24 | :return: some ops 25 | """ 26 | 27 | height, width, channels = self.data.x_train.shape[2], self.data.x_train.shape[3], self.data.x_train.shape[4] # (84, 84, 3) 28 | 29 | ## Construct placeholders 30 | self.support_set_images = tf.placeholder(tf.float32, [batch_size, classes_per_set, samples_per_class, height, width, 31 | channels], 'support_set_images') 32 | self.support_set_labels = tf.placeholder(tf.int32, [batch_size, classes_per_set, samples_per_class], 'support_set_labels') 33 | self.target_image = tf.placeholder(tf.float32, [batch_size, height, width, channels], 'target_image') 34 | self.target_label = tf.placeholder(tf.int32, [batch_size], 'target_label') 35 | self.training_phase = tf.placeholder(tf.bool, name='training-flag') 36 | self.rotate_flag = tf.placeholder(tf.bool, name='rotate-flag') 37 | self.keep_prob = tf.placeholder(tf.float32, name='dropout-prob') 38 | self.current_learning_rate = init_lr # 1e-3 39 | self.learning_rate = tf.placeholder(tf.float32, name='learning-rate-set') 40 | 41 | ## 42 | self.one_shot_learner = MetaMatchingNetwork(batch_size=batch_size, support_set_images=self.support_set_images, 43 | support_set_labels=self.support_set_labels, 44 | target_image=self.target_image, target_label=self.target_label, 45 | keep_prob=self.keep_prob, 46 | is_training=self.training_phase, rotate_flag=self.rotate_flag, 47 | num_classes_per_set=classes_per_set, 48 | num_samples_per_class=samples_per_class, learning_rate=self.learning_rate) 49 | 50 | _, self.losses, self.c_error_opt_op = 
self.one_shot_learner.init_train() 51 | init = tf.global_variables_initializer() 52 | self.total_train_iter = 0 53 | return self.one_shot_learner, self.losses, self.c_error_opt_op, init 54 | 55 | def run_training_epoch(self, total_train_batches, sess): 56 | """ 57 | Runs one training epoch 58 | :param total_train_batches: Number of batches to train on 59 | :param sess: Session object 60 | :return: mean_training_categorical_crossentropy_loss and mean_training_accuracy 61 | """ 62 | total_c_loss = 0. 63 | total_accuracy = 0. 64 | with tqdm.tqdm(total=total_train_batches) as pbar: 65 | 66 | for i in range(total_train_batches): # train epoch 67 | x_support_set, y_support_set, x_target, y_target = self.data.get_train_batch(augment=True) 68 | _, c_loss_value, acc = sess.run( 69 | [self.c_error_opt_op, self.losses['losses'], self.losses['accuracy']], 70 | feed_dict={self.keep_prob: 1.0, self.support_set_images: x_support_set, 71 | self.support_set_labels: y_support_set, self.target_image: x_target, self.target_label: y_target, 72 | self.training_phase: True, self.rotate_flag: False, self.learning_rate: self.current_learning_rate}) 73 | 74 | iter_out = "train_loss: {}, train_accuracy: {}, current lr: {}".format(c_loss_value, acc, self.current_learning_rate) 75 | pbar.set_description(iter_out) 76 | 77 | pbar.update(1) 78 | total_c_loss += c_loss_value 79 | total_accuracy += acc 80 | self.total_train_iter += 1 81 | if self.total_train_iter % 1500 == 0: 82 | self.current_learning_rate /= 1.11111 83 | self.current_learning_rate = max(1e-6, self.current_learning_rate) 84 | # set a lower bound of the learning rate, 1e-6 is reasonable, 1e-8 is too small 85 | print("Change learning rate to ", self.current_learning_rate) 86 | 87 | total_c_loss = total_c_loss / total_train_batches 88 | total_accuracy = total_accuracy / total_train_batches 89 | return total_c_loss, total_accuracy, self.current_learning_rate 90 | 91 | # some check functions for debugging 92 | def 
run_check_gradient(self, sess): 93 | x_support_set, y_support_set, x_target, y_target = self.data.get_train_batch(augment=True) 94 | feed_dict = {self.keep_prob: 1.0, self.support_set_images: x_support_set, 95 | self.support_set_labels: y_support_set, self.target_image: x_target, self.target_label: y_target, 96 | self.training_phase: True, self.rotate_flag: False, self.learning_rate: self.current_learning_rate} 97 | self.one_shot_learner.check_gradients_magnitude(sess, feed_dict=feed_dict) 98 | 99 | def run_check_tensor(self, sess): 100 | x_support_set, y_support_set, x_target, y_target = self.data.get_train_batch(augment=True) 101 | feed_dict = {self.keep_prob: 1.0, self.support_set_images: x_support_set, 102 | self.support_set_labels: y_support_set, self.target_image: x_target, self.target_label: y_target, 103 | self.training_phase: True, self.rotate_flag: False, self.learning_rate: self.current_learning_rate} 104 | self.one_shot_learner.check_tensors_magnitude(sess, feed_dict=feed_dict) 105 | 106 | def run_check_g(self, sess): 107 | self.one_shot_learner.check_g(sess) 108 | 109 | 110 | def run_check_genweights(self, sess): 111 | """ 112 | To generate the weights distribution of a task 113 | :param sess: 114 | :return: 115 | """ 116 | x_support_set, y_support_set, x_target, y_target = self.data.get_test_batch(augment=True) 117 | 118 | emb_list = [] 119 | label_list = [] 120 | for i in range(10): 121 | feed_dict = {self.keep_prob: 1.0, self.support_set_images: x_support_set, 122 | self.support_set_labels: y_support_set, self.target_image: x_target, self.target_label: y_target, 123 | self.training_phase: False, self.rotate_flag: False, self.learning_rate: self.current_learning_rate} 124 | tasks_gen_weights_list = self.losses['tasks_gen_weights_list'] # shape is batchsize x 4( four tensor for batchsize tasks) 125 | tgws = sess.run(tasks_gen_weights_list, feed_dict=feed_dict) 126 | tw_np_list = [] 127 | for tgw in tgws: 128 | tw = np.concatenate([tgw[0].reshape(-1), 
tgw[1].reshape(-1), tgw[2].reshape(-1), tgw[3].reshape(-1)]) 129 | tw_np_list.append(tw) 130 | 131 | emb = np.array(tw_np_list) 132 | label = np.arange(len(tgws)) 133 | emb_list.append(emb) 134 | label_list.append(label) 135 | 136 | embs = np.array(emb_list) 137 | labels = np.array(label_list) 138 | print("get result shape ", embs.shape, labels.shape) 139 | np.savez("data.npz", embs = embs, labels=labels) 140 | # todo better print some predictions results, to see the performance, we plot good predicted weights 141 | # tsne is not influenced by order, hence, here, we just give it a label, and store several inference 142 | 143 | 144 | def run_validation_epoch(self, total_val_batches, sess): 145 | """ 146 | Runs one validation epoch 147 | :param total_val_batches: Number of batches to train on 148 | :param sess: Session object 149 | :return: mean_validation_categorical_crossentropy_loss and mean_validation_accuracy 150 | """ 151 | total_val_c_loss = 0. 152 | total_val_accuracy = 0. 153 | 154 | with tqdm.tqdm(total=total_val_batches) as pbar: 155 | for i in range(total_val_batches): # validation epoch 156 | x_support_set, y_support_set, x_target, y_target = self.data.get_val_batch(augment=True) 157 | c_loss_value, acc = sess.run( 158 | [self.losses['losses'], self.losses['accuracy']], 159 | feed_dict={self.keep_prob: 1.0, self.support_set_images: x_support_set, 160 | self.support_set_labels: y_support_set, self.target_image: x_target, self.target_label: y_target, 161 | self.training_phase: False, self.rotate_flag: False}) 162 | 163 | iter_out = "val_loss: {}, val_accuracy: {}".format(c_loss_value, acc) 164 | pbar.set_description(iter_out) 165 | pbar.update(1) 166 | 167 | total_val_c_loss += c_loss_value 168 | total_val_accuracy += acc 169 | 170 | total_val_c_loss = total_val_c_loss / total_val_batches 171 | total_val_accuracy = total_val_accuracy / total_val_batches 172 | 173 | return total_val_c_loss, total_val_accuracy 174 | 175 | def run_testing_epoch(self, 
total_test_batches, sess): 176 | """ 177 | Runs one testing epoch 178 | :param total_test_batches: Number of batches to train on 179 | :param sess: Session object 180 | :return: mean_testing_categorical_crossentropy_loss and mean_testing_accuracy 181 | """ 182 | total_test_c_loss = 0. 183 | total_test_accuracy = 0. 184 | with tqdm.tqdm(total=total_test_batches) as pbar: 185 | for i in range(total_test_batches): 186 | x_support_set, y_support_set, x_target, y_target = self.data.get_test_batch(augment=True) 187 | c_loss_value, acc = sess.run( 188 | [self.losses['losses'], self.losses['accuracy']], 189 | feed_dict={self.keep_prob: 1.0, self.support_set_images: x_support_set, 190 | self.support_set_labels: y_support_set, self.target_image: x_target, 191 | self.target_label: y_target, 192 | self.training_phase: False, self.rotate_flag: False}) 193 | 194 | iter_out = "test_loss: {}, test_accuracy: {}".format(c_loss_value, acc) 195 | pbar.set_description(iter_out) 196 | pbar.update(1) 197 | 198 | total_test_c_loss += c_loss_value 199 | total_test_accuracy += acc 200 | total_test_c_loss = total_test_c_loss / total_test_batches 201 | total_test_accuracy = total_test_accuracy / total_test_batches 202 | return total_test_c_loss, total_test_accuracy 203 | 204 | 205 | def run_ensemble_testing_epoch(self, total_test_batches, sess): 206 | """ 207 | :param total_test_batches: 208 | :param sess: 209 | :return: 210 | """ 211 | print("********* run_ensemble_testing_epoch") 212 | total_es_test_accuracy = 0. # ensemble model 213 | total_sg_test_accuracy = 0. 
# single model 214 | n_ensembles = 10 # 20 215 | with tqdm.tqdm(total=total_test_batches) as pbar: 216 | for i in range(total_test_batches): 217 | x_support_set, y_support_set, x_target, y_target = self.data.get_test_batch(augment=True) 218 | t_list = [] 219 | for i in range(n_ensembles): 220 | preds = sess.run(self.losses['preds'], feed_dict={self.keep_prob: 1.0, self.support_set_images: x_support_set, 221 | self.support_set_labels: y_support_set, 222 | self.target_image: x_target, 223 | self.target_label: y_target, 224 | self.training_phase: False, self.rotate_flag: False}) 225 | t_preds = np.argmax(preds, axis=1) 226 | t_list.append(t_preds) 227 | 228 | ens_preds_st = np.stack(t_list, axis=0) 229 | ens_preds = mode(ens_preds_st, axis=0)[0][0] 230 | one_acc = np.mean(t_preds==y_target) 231 | ens_acc = np.mean(ens_preds==y_target) 232 | 233 | # print("Ensemble prediction {}".format(ens_preds_st)) 234 | # print("e {}".format(ens_preds)) 235 | # print("y {}".format(y_target.astype(np.int32))) 236 | 237 | iter_out = "Ensemble test_accuracy: {}, single model test accuracy: {}".format(ens_acc, one_acc) 238 | # print(iter_out) 239 | pbar.set_description(iter_out) 240 | pbar.update(1) 241 | 242 | total_es_test_accuracy += ens_acc 243 | total_sg_test_accuracy += one_acc 244 | 245 | total_sg_test_accuracy = total_sg_test_accuracy / total_test_batches 246 | total_es_test_accuracy = total_es_test_accuracy / total_test_batches 247 | return total_sg_test_accuracy, total_es_test_accuracy 248 | -------------------------------------------------------------------------------- /meta_matching_network.py: -------------------------------------------------------------------------------- 1 | """ 2 | In this file, I reimplement the function, and change the APIs to make it flexible to change functions. 3 | Since the whole architecture has clear architecture, so here it will not be too difficult. 
4 | """ 5 | import tensorflow as tf 6 | import tensorflow.contrib.rnn as rnn 7 | from tensorflow.python.ops.nn_ops import max_pool, avg_pool 8 | import numpy as np 9 | 10 | def print_params(vars_list, name=None): 11 | print("#"*30) 12 | if name is not None: 13 | print("The variables of ", name) 14 | for var in vars_list: 15 | print(var.name, var.get_shape()) 16 | print("#"*30) 17 | def leaky_relu(x, leak=0.2, name='leaky_relu'): 18 | return tf.maximum(x, x * leak, name=name) 19 | def relu(x, name='relu'): 20 | return tf.nn.relu(x, name=name) 21 | 22 | def normalization(inputs, training, type='layer_norm'): 23 | """ 24 | :param inputs: 25 | :param training: 26 | :param type: 'batch_norm' , 'instance_norm' , 'layer_norm' 27 | :return: 28 | """ 29 | if type == 'batch_norm': 30 | return tf.contrib.layers.batch_norm(inputs, updates_collections=None, decay=0.99, 31 | scale=True, center=True, is_training=training) 32 | elif type == 'instance_norm': 33 | return tf.contrib.layers.instance_norm(inputs, center=True, scale=True) 34 | elif type == 'layer_norm': 35 | return tf.contrib.layers.layer_norm(inputs, center=True, scale=True) 36 | 37 | class DistanceNetwork: 38 | def __init__(self, metric='cosine'): 39 | """ 40 | :param metric: 'cosine', 'euclidean' 41 | 'cosine' is better, but also can use 'euclidean 42 | """ 43 | self.reuse = False 44 | self.metric = metric 45 | 46 | def __call__(self, support_set, input_image, name='_distance', training=False): 47 | """ 48 | This module calculates the cosine distance between each of the support set embeddings and the target 49 | image embeddings. 
50 | :param support_set: The embeddings of the support set images, tensor of shape [batch_size, spc, 64] (32,5,576) 51 | :param input_image: The embedding of the target image, tensor of shape [batch_size, 64] (32, 576) 52 | :param name: Name of the op to appear on the graph 53 | :param training: Flag indicating training or evaluation (True/False) 54 | :return: A tensor with cosine similarities of shape [batch_size, sequence_length, 1] 55 | """ 56 | print("In DistanceNetwork, using ", self.metric) 57 | if self.metric == 'cosine': 58 | with tf.name_scope(self.metric+name): 59 | input_image = tf.expand_dims(input_image, axis=1) 60 | norm_s = tf.nn.l2_normalize(support_set, dim=2) 61 | norm_i = tf.nn.l2_normalize(input_image, dim=2) 62 | similarities = tf.reduce_sum(tf.multiply(norm_s, norm_i), axis=2) 63 | elif self.metric == 'euclidean': 64 | with tf.name_scope(self.metric + name): 65 | # euclidean distance should use negative one to be similarities, large distance means different 66 | input_image = tf.expand_dims(input_image, axis=1) 67 | similarities = -tf.reduce_mean(tf.square(support_set - input_image), axis=2) 68 | else: 69 | raise TypeError("Choose distance metrics from cosine, euclidean ") 70 | 71 | print("Similarities ", similarities) # (32, 5) 72 | return similarities 73 | 74 | class AttentionalClassify: 75 | def __init__(self): 76 | self.reuse = False 77 | 78 | def __call__(self, similarities, support_set_y, name, training=False): 79 | """ 80 | Produces pdfs over the support set classes for the target set image. 
81 | n*k is sequence length 82 | :param similarities: A tensor with cosine similarities of size [ batch_size, n*k] (32, 20) 83 | :param support_set_y: A tensor with the one hot vectors of the targets for each support set image 84 | [batch_size, n*k, num_classes] (32, nk, 5) 85 | :param name: The name of the op to appear on tf graph 86 | :param training: Flag indicating training or evaluation stage (True/False) 87 | :return: Softmax pdf 88 | """ 89 | print("In AttentionalClassify") 90 | print(similarities.get_shape(), support_set_y.get_shape()) 91 | with tf.name_scope('attentional-classification' + name), tf.variable_scope('attentional-classification', 92 | reuse=self.reuse): 93 | softmax_similarities = tf.nn.softmax(similarities) # (32,5) 94 | preds = tf.squeeze(tf.matmul(tf.expand_dims(softmax_similarities, 1), support_set_y)) # (32,5) 95 | 96 | self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='attentional-classification') 97 | return preds 98 | 99 | ### Meta Network Classes 100 | 101 | """Fully Connected Meta Network, with reparameterization tricks """ 102 | class MetaNetwork: 103 | """ 104 | A general version of Meta Network, which can generate weights for both MLP and CNN with bias... 
105 | Without sharing parameters for each kernel; This will requires more meta weights 106 | """ 107 | def __init__(self): 108 | self.reuse = False 109 | 110 | def __call__(self, inputs, context, out_size=64, kernel_size=[3, 3], name='Meta'): 111 | """ 112 | :param inputs: (nk+1, 6,6,128) tensor containing all samples in a task 113 | :param context: (64,) context vector for a task 114 | :param out_size: 115 | :param kernel_size: 116 | :param name: 117 | :return: 118 | """ 119 | print("In meta network, inputs shape:{}, context shape: {}".format(inputs.get_shape(), context.get_shape())) 120 | # (6, 11, 11, 128), context shape: (64,) 121 | 122 | inputs_shape_list = inputs.get_shape().as_list() 123 | c_dim = context.get_shape().as_list()[-1] 124 | # split the context into mean and variance predicted by task context encoder 125 | z_dim = c_dim // 2 126 | c_mu = context[:z_dim] 127 | c_log_var = context[z_dim:] 128 | 129 | 130 | if len(inputs_shape_list) == 4: 131 | is_CNN = True 132 | else: 133 | is_CNN = False 134 | 135 | if is_CNN == True: 136 | assert kernel_size[0] == kernel_size[1] 137 | f_size = kernel_size[0] # filter size 138 | in_size = inputs_shape_list[-1] # input channel number 64 139 | 140 | M = f_size*f_size*in_size 141 | N = out_size 142 | wt_shape = [M+1, N] # weights tensor shape, with bias 143 | else: 144 | M = inputs_shape_list[-1] 145 | N = out_size 146 | wt_shape = [M+1, N] 147 | 148 | with tf.variable_scope("MetaNetwork_" + name, reuse=self.reuse): 149 | 150 | with tf.variable_scope("z_signal"): 151 | z_signal = tf.random_normal(shape=[1, z_dim], name='z_signal') 152 | 153 | # reparameterization trick 154 | z_c_mu = tf.expand_dims(c_mu, axis=0) 155 | z_c_log_var = tf.expand_dims(c_log_var, axis=0) 156 | print(z_c_mu.get_shape(), z_c_log_var.get_shape(), z_signal.get_shape()) 157 | z_c = z_c_mu + tf.exp(z_c_log_var/2)*z_signal 158 | 159 | with tf.variable_scope("meta_weights"): 160 | w1 = tf.get_variable('w1', [z_dim, (M+1)*N], 
initializer=tf.glorot_uniform_initializer()) 161 | b1 = tf.get_variable('b1', [(M+1)*N], initializer=tf.constant_initializer(0.0)) 162 | final = tf.matmul(z_c, w1) + b1 # (N, M+1) 163 | meta_weights = final[0, :M*N] 164 | meta_bias = final[0, M*N:] 165 | print("Meta weights ", meta_weights.get_shape(), meta_bias.get_shape()) 166 | 167 | if is_CNN: 168 | meta_weights = tf.transpose(tf.reshape(meta_weights, (out_size, in_size, f_size, f_size))) 169 | else: 170 | meta_weights = tf.transpose(tf.reshape(meta_weights, (out_size, M))) 171 | 172 | # print("meta weights ", meta_weights, meta_bias) 173 | with tf.variable_scope("normalize_weights"): 174 | if is_CNN: 175 | meta_weights = tf.nn.l2_normalize(meta_weights, dim=[0, 1, 2]) # exp0 176 | else: 177 | meta_weights = tf.nn.l2_normalize(meta_weights, dim=[0]) # exp0 178 | return meta_weights, meta_bias 179 | 180 | 181 | class MetaConvolution: 182 | """ 183 | Meta Convolutional Network 184 | """ 185 | def __init__(self): 186 | self.reuse=False 187 | self.metanet = MetaNetwork() 188 | def __call__(self, inputs, context, filters, kernel_size, training=False, name='meta_conv'): 189 | """ 190 | :param inputs: A convolutional Tensor (nk+1, 6, 6, 128) 191 | :param context: a vector represent corresponding task context, which is (64, ) tensor 192 | :param filters: meta network output channel number 193 | :param kernel_size: 194 | :param training: 195 | :param keep_prob: In fact, this is a placeholder 196 | :return: 197 | """ 198 | # print("inputs ", inputs.get_shape()) 199 | # print("context ", context.get_shape()) 200 | meta_conv_w, meta_conv_b = self.metanet(inputs, context, out_size=filters, kernel_size=kernel_size, name=name) 201 | tf.add_to_collection('meta_conv_w', meta_conv_w) 202 | tf.add_to_collection('meta_conv_b', meta_conv_b) 203 | outputs = tf.nn.conv2d(inputs, meta_conv_w, strides=[1, 1, 1, 1], padding='SAME') + meta_conv_b 204 | self.reuse = True 205 | return outputs, meta_conv_w, meta_conv_b 206 | 207 | 208 | ### 
Task encoder Classes 209 | 210 | class TaskTransformer: 211 | def __init__(self): 212 | self.reuse = False 213 | self.layer_sizes = [64, 64, 64] # 64->32 214 | def __call__(self, task_embedding, training=False, keep_prob=1.0): 215 | """ 216 | 217 | :param task_images: images from a task 218 | :param training: 219 | :return: 220 | """ 221 | with tf.variable_scope("TaskTransformer", reuse=self.reuse): 222 | # 11*11 223 | with tf.variable_scope('t_conv1'): 224 | te = tf.layers.conv2d(task_embedding, self.layer_sizes[0], [3, 3], strides=(1, 1), 225 | padding='SAME') 226 | te = relu(te, name='relu') 227 | te = normalization(te, training=training, type='batch_norm') 228 | te = max_pool(te, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 229 | # 6*6 230 | with tf.variable_scope('t_conv2'): 231 | te = tf.layers.conv2d(te, self.layer_sizes[1], [3, 3], strides=(1, 1), padding='SAME') 232 | te = relu(te, name='relu') 233 | te = normalization(te, training=training, type='batch_norm') 234 | te = max_pool(te, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 235 | # 3*3 236 | with tf.variable_scope("t_conv3"): 237 | te = tf.layers.conv2d(te, self.layer_sizes[2], [3, 3], strides=(1, 1), padding='SAME') 238 | te = tf.reduce_mean(te, axis=[1, 2]) 239 | 240 | self.reuse = True 241 | return te 242 | 243 | class TaskContextEncoder: 244 | def __init__(self, batch_size, method='mean'): 245 | """ 246 | :param layer_sizes: A list containing the neuron numbers per layer e.g. 
[100, 100, 100] returns a 3 layer, 100 247 | neuron bid-LSTM 248 | [32] 249 | :param batch_size: The experiments batch size, useless here 250 | """ 251 | self.reuse = False 252 | self.tasktrans = TaskTransformer() 253 | self.batch_size = batch_size 254 | self.method = method 255 | 256 | def __call__(self, support_set_embeddings, training=False, name='TaskContext'): 257 | """ 258 | :param support_set_embeddings: a list of tensor (bs, k*n, w, h, c) 259 | :param training: 260 | :param name: 261 | :return: 262 | """ 263 | [bs, kn, w, h, c] = support_set_embeddings.get_shape().as_list() 264 | support_set_embeddings = tf.reshape(support_set_embeddings, shape=[bs*kn, w, h, c]) 265 | # feature transformer 266 | with tf.variable_scope(name_or_scope=name, reuse=self.reuse): 267 | if self.method == 'mean': 268 | t_context = self.tasktrans(support_set_embeddings, training=training) # (bs*kn, w1,h1,c1) 269 | t_context = tf.reshape(t_context, shape=[bs, kn, -1]) 270 | t_context = tf.reduce_mean(t_context, axis=1) # (bs, num_features) 271 | print("t_context shape ", t_context.get_shape()) # (32, 64) 272 | elif self.method == 'bilstm': 273 | ## todo add bilstm implementation, previous implementation fails 274 | pass 275 | else: 276 | raise TypeError("No Such Methods, please use mean") 277 | self.reuse = True 278 | self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=name) 279 | print_params(self.variables, name='Task Context Module') 280 | return t_context 281 | 282 | # feature extractor, all meta convolutions 283 | class Classifier: 284 | def __init__(self, batch_size): 285 | """ 286 | Fully Convolutional Network using meta convolution 287 | :param batch_size: Batch size for experiment 288 | :param layer_sizes: A list of length 4 containing the layer sizes 289 | :param num_channels: Number of channels of images 290 | """ 291 | self.reuse = False 292 | self.batch_size = batch_size 293 | self.meta_conv = MetaConvolution() 294 | self.layer_sizes = [64, 64] 295 
| assert len(self.layer_sizes) == 2, "layer_sizes should be a list of length 2" 296 | 297 | def __call__(self, image_embedding, task_context, training=False, keep_prob=1.0): 298 | """ 299 | Runs the CNN producing the embeddings and the gradients. 300 | :param image_input: Image input to produce embeddings for. 301 | :param training: A flag indicating training or evaluation 302 | :param keep_prob: A tf placeholder of type tf.float32 indicating the amount of dropout applied 303 | :return: Embeddings of size [batch_size, 64] 304 | """ 305 | print("task_context shape ", task_context.get_shape()) 306 | with tf.variable_scope('Classifier', reuse=self.reuse): 307 | # 11*11 308 | with tf.variable_scope("meta_conv1"): 309 | m_conv1, m_conv1_w, m_conv1_b = self.meta_conv(image_embedding, task_context, self.layer_sizes[0], [3, 3], training=training) 310 | m_conv1 = relu(m_conv1, name='outputs') 311 | m_conv1 = max_pool(m_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], 312 | padding='SAME') 313 | # 6*6 314 | with tf.variable_scope("meta_conv2"): 315 | m_conv2, m_conv2_w, m_conv2_b = self.meta_conv(m_conv1, task_context, self.layer_sizes[1], [3, 3], training=training) 316 | m_conv2 = tf.contrib.layers.flatten(m_conv2) 317 | print("m_conv2 ", m_conv2.get_shape()) 318 | gen_weights_list = [m_conv1_w, m_conv1_b, m_conv2_w, m_conv2_b] 319 | 320 | 321 | self.reuse = True 322 | self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Classifier') 323 | # print_params(self.variables, name="Feature Extractor") 324 | return m_conv2, gen_weights_list 325 | 326 | # feature extractor 327 | class Extractor: 328 | def __init__(self): 329 | """ 330 | Builds a meta CNN to produce embeddings, the final layer weights are generated via meta network 331 | :param layer_sizes: A list of length 4 containing the layer sizes 332 | """ 333 | self.reuse = False 334 | self.layer_sizes = [64, 64, 64, 64] 335 | assert len(self.layer_sizes) == 4, "layer_sizes should be a list of length 4" 
336 | 337 | def __call__(self, support_target_images, training=False, keep_prob=1.0): 338 | """ 339 | Runs the CNN producing the embeddings and the gradients. 340 | :param image_input: Image input to produce embeddings for. [batch_size, 28, 28, 1] 341 | :param training: A flag indicating training or evaluation 342 | :param keep_prob: A tf placeholder of type tf.float32 indicating the amount of dropout applied 343 | :return: Embeddings of size [batch_size, 64] 344 | """ 345 | [bs, kn, w, h, c] = support_target_images.get_shape().as_list() 346 | support_target_images = tf.reshape(support_target_images, shape=[bs*kn, w, h, c]) 347 | with tf.variable_scope('extractor', reuse=self.reuse): 348 | 349 | with tf.variable_scope('conv_layers'): 350 | # 84*84 351 | with tf.variable_scope('g_conv1'): 352 | g_conv1_encoder = tf.layers.conv2d(support_target_images, self.layer_sizes[0], [3, 3], strides=(1, 1), 353 | padding='SAME') 354 | g_conv1_encoder = tf.contrib.layers.batch_norm(g_conv1_encoder, updates_collections=None, decay=0.99, 355 | scale=True, center=True, is_training=training) 356 | 357 | g_conv1_encoder = relu(g_conv1_encoder, name='outputs') 358 | g_conv1_encoder = max_pool(g_conv1_encoder, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], 359 | padding='SAME') 360 | g_conv1_encoder = tf.nn.dropout(g_conv1_encoder, keep_prob=keep_prob) 361 | # 42*42 362 | with tf.variable_scope('g_conv2'): 363 | g_conv2_encoder = tf.layers.conv2d(g_conv1_encoder, self.layer_sizes[1], [3, 3], strides=(1, 1), 364 | padding='SAME') 365 | g_conv2_encoder = tf.contrib.layers.batch_norm(g_conv2_encoder, updates_collections=None, decay=0.99, 366 | scale=True, center=True, is_training=training) 367 | 368 | g_conv2_encoder = relu(g_conv2_encoder, name='outputs') 369 | g_conv2_encoder = max_pool(g_conv2_encoder, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], 370 | padding='SAME') 371 | 372 | # 21*21 373 | with tf.variable_scope('g_conv3'): 374 | g_conv3_encoder = tf.layers.conv2d(g_conv2_encoder, 
self.layer_sizes[2], [3, 3], strides=(1, 1), 375 | padding='SAME') 376 | g_conv3_encoder = tf.contrib.layers.batch_norm(g_conv3_encoder, updates_collections=None, decay=0.99, 377 | scale=True, center=True, is_training=training) 378 | 379 | g_conv3_encoder = relu(g_conv3_encoder, name='outputs') 380 | g_conv3_encoder = max_pool(g_conv3_encoder, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], 381 | padding='SAME') 382 | # 11*11 383 | with tf.variable_scope('g_conv4'): 384 | g_conv4_encoder = tf.layers.conv2d(g_conv3_encoder, self.layer_sizes[3], [3, 3], strides=(1, 1), 385 | padding='SAME') 386 | g_conv4_encoder = tf.contrib.layers.batch_norm(g_conv4_encoder, updates_collections=None, decay=0.99, 387 | scale=True, center=True, is_training=training) 388 | g_conv4_encoder = relu(g_conv4_encoder, name='outputs') # ? 389 | 390 | 391 | self.reuse = True 392 | self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='conv_layers') 393 | [bskn, we, he, ce] = g_conv4_encoder.get_shape().as_list() 394 | embeddings = tf.reshape(g_conv4_encoder, [bs, kn, we, he, ce]) 395 | 396 | # print_params(self.variables, name="Feature Extractor") 397 | return embeddings 398 | 399 | 400 | class MetaMatchingNetwork: 401 | def __init__(self, support_set_images, support_set_labels, target_image, target_label, keep_prob, 402 | batch_size=32, is_training=False, learning_rate=0.001, rotate_flag=False, num_classes_per_set=5, 403 | num_samples_per_class=1, task_method="mean"): 404 | 405 | """ 406 | Builds a matching network, the training and evaluation ops as well as data augmentation routines. 
407 | :param support_set_images: A tensor containing the support set images [batch_size, sequence_size, 28, 28, 1] 408 | :param support_set_labels: A tensor containing the support set labels [batch_size, sequence_size, 1] 409 | :param target_image: A tensor containing the target image (image to produce label for) [batch_size, 28, 28, 1] 410 | :param target_label: A tensor containing the target label [batch_size, 1] 411 | :param keep_prob: A tf placeholder of type tf.float32 denotes the amount of dropout to be used 412 | :param batch_size: The batch size for the experiment 413 | :param num_channels: Number of channels of the images 414 | :param is_training: Flag indicating whether we are training or evaluating 415 | :param rotate_flag: Flag indicating whether to rotate the images; This is useless!!!!!!!!!!!!!! 416 | :param num_classes_per_set: Integer indicating the number of classes per set 417 | :param num_samples_per_class: Integer indicating the number of samples per class 418 | :param task_method: Choose from "mean" 419 | """ 420 | 421 | self.batch_size = batch_size 422 | self.Classifier = Classifier(self.batch_size) 423 | self.tce = TaskContextEncoder(batch_size=self.batch_size, method=task_method) 424 | self.dn = DistanceNetwork(metric='cosine') 425 | self.extractor = Extractor() 426 | self.classify = AttentionalClassify() 427 | self.support_set_images = support_set_images 428 | self.support_set_labels = support_set_labels 429 | self.target_image = target_image 430 | self.target_label = target_label 431 | 432 | self.keep_prob = keep_prob 433 | self.is_training = is_training 434 | self.rotate_flag = rotate_flag 435 | self.num_classes_per_set = num_classes_per_set 436 | self.num_samples_per_class = num_samples_per_class 437 | self.learning_rate = learning_rate 438 | self.tensor_list = [] 439 | 440 | def loss(self): 441 | """ 442 | Builds tf graph for Matching Networks, produces losses and summary statistics. 
443 | :return: 444 | """ 445 | with tf.name_scope("losses"): 446 | [b, num_classes, spc] = self.support_set_labels.get_shape().as_list() 447 | 448 | self.support_set_labels_ = tf.reshape(self.support_set_labels, shape=(b, num_classes * spc)) 449 | self.support_set_labels_ = tf.one_hot(self.support_set_labels_, self.num_classes_per_set) # one hot encode 450 | 451 | [b, num_classes, spc, h, w, c] = self.support_set_images.get_shape().as_list() 452 | self.support_set_images_ = tf.reshape(self.support_set_images, shape=(b, num_classes*spc, h, w, c)) 453 | 454 | ## zero step: extractor feature embeddings 455 | self.target_image_ = tf.expand_dims(self.target_image, axis=1) #(b, 1, h, w, c) 456 | ## merge support set and target set, in order to share the feature extractors 457 | support_target_images = tf.concat([self.support_set_images_, self.target_image_], axis=1) #(b, n*k+1, h, w, c) 458 | print("+++ support_target images ", support_target_images.get_shape()) # (32, 6, 84, 84, 3) 459 | print("+++ support_target images [:-1]", support_target_images[:, :-1].get_shape()) # (32, 5, 84, 84, 3) 460 | support_target_embeddings = self.extractor(support_target_images, training=self.is_training, keep_prob=self.keep_prob) 461 | print("+++", support_target_embeddings.get_shape()) # (32, 6, 6, 6 , 96) the last dimension is feature dimension 462 | 463 | ## first step: generate task feature representation by using support set features 464 | task_contexts = self.tce(support_target_embeddings[:, :-1], training=self.is_training) # (bs, num_task_features) (32, 64) 465 | 466 | ## second step: transform images via conditional meta task convolution 467 | trans_support_images_list = [] 468 | trans_target_images_list = [] 469 | tasks_gen_weights_list = [] # todo test generated weights distribution 470 | for i, (tc, ste) in enumerate(zip(tf.unstack(task_contexts), tf.unstack(support_target_embeddings))): 471 | print("============ In task instance ", i) 472 | # support task image embeddings 
for one task 473 | steb, gen_weights_list = self.Classifier(image_embedding=ste, task_context=tc, training=self.is_training, keep_prob=self.keep_prob) # (6, 4608) 474 | trans_support_images_list.append(steb[:-1]) 475 | trans_target_images_list.append(steb[-1]) 476 | tasks_gen_weights_list.append(gen_weights_list) 477 | 478 | trans_support = tf.stack(trans_support_images_list) 479 | trans_target = tf.stack(trans_target_images_list) 480 | print("=="*10) # shape error 481 | print("trans support set shape and target shape ", trans_support.get_shape(), trans_target.get_shape()) # (32, 5, 4608) (32, 4608) 482 | 483 | similarities = self.dn(support_set=trans_support, input_image=trans_target, name="distance_calculation", 484 | training=self.is_training) #get similarity between support set embeddings and target 485 | 486 | preds = self.classify(similarities, support_set_y=self.support_set_labels_, name='classify', training=self.is_training) 487 | 488 | if self.batch_size == 1: 489 | print("If preds is batchsize = 1, reshape it to avoid shape error.") 490 | preds = tf.reshape(preds, shape=(self.batch_size, preds.get_shape().as_list()[-1])) 491 | print("preds shape ", preds.get_shape(), tf.argmax(preds, 1).get_shape()) # (bs, num_classes) 492 | print("target label shape ", self.target_label.get_shape()) 493 | # produce predictions for target probabilities 494 | correct_prediction = tf.equal(tf.argmax(preds, 1), tf.cast(self.target_label, tf.int64)) 495 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 496 | targets = tf.one_hot(self.target_label, self.num_classes_per_set) 497 | print("targets shape one hot ", targets.get_shape()) 498 | crossentropy_loss = tf.reduce_mean(-tf.reduce_sum(targets * tf.log(preds), 499 | reduction_indices=[1])) 500 | print(crossentropy_loss) 501 | 502 | tf.add_to_collection('crossentropy_losses', crossentropy_loss) 503 | tf.add_to_collection('accuracy', accuracy) 504 | 505 | # todo why return like this, rather than a better 
string keyworkds? 506 | return { 507 | 'losses': tf.add_n(tf.get_collection('crossentropy_losses'), name='total_classification_loss'), 508 | 'accuracy': tf.add_n(tf.get_collection('accuracy'), name='accuracy'), 509 | 'preds': preds, # added for ensemble training 510 | 't_label': self.target_label, 511 | 'tasks_gen_weights_list': tasks_gen_weights_list 512 | } 513 | 514 | 515 | def test_ensemble(self, M =1): 516 | """ 517 | Test using the simpliest ensemble methods: max voting 518 | But this implemetation is not used, because it is complicated, we just test run the same task instance for 519 | several times and max voting the results. In experiemnt_builder.py. 520 | :return: 521 | """ 522 | with tf.name_scope("losses"): 523 | [b, num_classes, spc] = self.support_set_labels.get_shape().as_list() 524 | print("data type ", self.support_set_labels.dtype) 525 | self.support_set_labels_ = tf.reshape(self.support_set_labels, shape=(b, num_classes * spc)) 526 | print("data type ", self.support_set_labels.dtype) 527 | self.support_set_labels_ = tf.one_hot(self.support_set_labels_, self.num_classes_per_set) # one hot encode 528 | 529 | [b, num_classes, spc, h, w, c] = self.support_set_images.get_shape().as_list() 530 | support_set_images_ = tf.reshape(self.support_set_images, shape=(b, num_classes*spc, h, w, c)) 531 | 532 | ## zero step: extractor feature embeddings 533 | target_image_ = tf.expand_dims(self.target_image, axis=1) #(b, 1, h, w, c) 534 | ## merge support set and target set, in order to share the feature extractors 535 | support_target_images = tf.concat([support_set_images_, target_image_], axis=1) #(b, n*k+1, h, w, c) 536 | print("+++ support_target images ", support_target_images.get_shape()) # (32, 6, 84, 84, 3) 537 | print("+++ support_target images [:-1]", support_target_images[:, :-1].get_shape()) # (32, 5, 84, 84, 3) 538 | support_target_embeddings = self.extractor(support_target_images, training=self.is_training, keep_prob=self.keep_prob) 539 | 
print("+++", support_target_embeddings.get_shape()) # (32, 6, 6, 6 , 96) the last dimension is feature dimension 540 | 541 | ## first step: generate task feature representation 542 | task_contexts = self.tce(support_target_embeddings[:, :-1], training=self.is_training) # (bs, num_task_features) (32, 64) 543 | 544 | ## second step: transform images via conditional meta task convolution 545 | ## todo In order to generate ensemble weights for the same task instance, we just need to run generation network several times 546 | ensemble_preds = [] 547 | for m in range(M): 548 | trans_support_images_list = [] 549 | trans_target_images_list = [] 550 | for i, (tc, ste) in enumerate(zip(tf.unstack(task_contexts), tf.unstack(support_target_embeddings))): 551 | print("============ In task instance ", i) 552 | # support task image embeddings for one task 553 | steb = self.Classifier(image_embedding=ste, task_context=tc, training=self.is_training, keep_prob=self.keep_prob) #(6, 4608) 554 | trans_support_images_list.append(steb[:-1]) 555 | trans_target_images_list.append(steb[-1]) 556 | 557 | 558 | trans_support = tf.stack(trans_support_images_list) 559 | trans_target = tf.stack(trans_target_images_list) 560 | print("==" * 10) # shape error 561 | print("trans support set shape and target shape ", trans_support.get_shape(), trans_target.get_shape()) 562 | 563 | similarities = self.dn(support_set=trans_support, input_image=trans_target, name="distance_calculation", 564 | training=self.is_training) # get similarity between support set embeddings and target 565 | 566 | preds = self.classify(similarities, 567 | support_set_y=self.support_set_labels_, name='classify', training=self.is_training) 568 | print("preds shape ", preds.get_shape()) # (bs, num_classes) 569 | ensemble_preds.append(tf.arg_max(preds, 1)) 570 | 571 | ensemble_preds = tf.stack(ensemble_preds) 572 | 573 | 574 | return ensemble_preds 575 | 576 | 577 | def train(self, losses): 578 | """ 579 | Builds the train op 580 | 
:param losses: A dictionary containing the losses 581 | :param learning_rate: Learning rate to be used for Adam 582 | :param beta1: Beta1 to be used for Adam 583 | :return: 584 | """ 585 | c_opt = tf.train.AdamOptimizer(beta1=0.9, learning_rate=self.learning_rate) 586 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) # Needed for correct batch norm usage 587 | with tf.control_dependencies(update_ops): # Needed for correct batch norm usage 588 | train_variables = tf.trainable_variables() # all variables 589 | c_error_opt_op = c_opt.minimize(losses['losses'], 590 | var_list=train_variables) 591 | print_params(train_variables, "All trainable variables") 592 | 593 | return c_error_opt_op, train_variables 594 | 595 | 596 | def init_train(self): 597 | """ 598 | Get all ops, as well as all losses. 599 | :return: 600 | """ 601 | losses = self.loss() 602 | c_error_opt_op, trainable_variables = self.train(losses) 603 | summary = tf.summary.merge_all() # summary is not used 604 | 605 | # construct gradient check operation 606 | check_var_list = trainable_variables 607 | print_params(check_var_list, "check_var_list") 608 | grads_list = tf.gradients(losses['losses'], check_var_list) 609 | 610 | # print_params(grads_list, "gradient_list") 611 | self.grad_var_dict = {'var': check_var_list, 'grad': grads_list} 612 | 613 | return summary, losses, c_error_opt_op 614 | 615 | 616 | def check_gradients_magnitude(self, sess, feed_dict): 617 | """ 618 | Using self.all_trainable_variables and self.losses to compute the gradients of 619 | :param sess: 620 | :param feed_dict: 621 | :return: 622 | """ 623 | print("check gradients") 624 | print("name, grad norm, mean, std, max, min | var norm, mean, std, max, min") 625 | grad_values = sess.run(self.grad_var_dict['grad'], feed_dict=feed_dict) 626 | var_values = sess.run(self.grad_var_dict['var'], feed_dict=feed_dict) 627 | for var, g_value, v_value in zip(self.grad_var_dict['var'], grad_values, var_values): 628 | print(var.name, 
np.linalg.norm(g_value), np.mean(g_value), np.std(g_value), np.max(g_value), np.min(g_value), "|", 629 | np.linalg.norm(v_value), np.mean(v_value), np.std(v_value), np.max(v_value), np.min(v_value)) 630 | 631 | def check_tensors_magnitude(self, sess, feed_dict): 632 | print("check meta convolution weights================") 633 | tensors = sess.run(self.tensor_list, feed_dict=feed_dict) 634 | for t, t_v in zip(self.tensor_list, tensors): 635 | print(t.name, np.linalg.norm(t_v), np.mean(t_v), np.std(t_v), np.max(t_v), np.min(t_v)) 636 | 637 | def check_g(self, sess): 638 | g1 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="Classifier/meta_conv1/MetaNetwork_meta_conv/normalize_weights/") 639 | g2 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="Classifier/meta_conv2/MetaNetwork_meta_conv/normalize_weights/") 640 | g1_, g2_ = sess.run([g1, g2]) 641 | print("g1 ", g1_) 642 | print("g2 ", g2_) 643 | -------------------------------------------------------------------------------- /saved_models/LGM-Net_5way1shot.csv: -------------------------------------------------------------------------------- 1 | "Experimental details: 5way1shot learning problems, with 32 tasks per task batch" 2 | epoch,train_c_loss,train_c_accuracy,val_loss,val_accuracy,test_c_loss,test_c_accuracy,learning_rate 3 | 0,1.5445495754480363,0.3121875,|,1.6010234627723694,0.26925,|,0.001 4 | 1,1.499713549375534,0.34771875,|,1.5617561650276184,0.3005,|,0.0009000009000009 5 | 2,1.4878948799371718,0.36425,|,1.541407859325409,0.31575,|,0.0008100016200024299 6 | 3,1.4694296985864639,0.383125,|,1.5346404581069946,0.3255,|,0.0008100016200024299 7 | 4,1.4319936710596084,0.417125,|,1.4182707328796387,0.413875,|,0.0007290021870043739 8 | 5,1.3397139768600463,0.47459375,|,1.398400444984436,0.419375,|,0.0006561026244065608 9 | 6,1.2927515919208528,0.504125,|,1.3127874517440796,0.465,|,0.0006561026244065608 10 | 7,1.2591973766088487,0.521875,|,1.3165641841888427,0.4665,|,0.0005904929524588572 11 
| 8,1.230527029633522,0.53996875,|,1.2442341322898864,0.517375,|,0.0005314441886571602 12 | 9,1.2000443918704986,0.5525,|,1.2156013498306275,0.521875,|,0.0005314441886571602 13 | 10,1.1692824005484581,0.57296875,|,1.1878551177978516,0.533375,|,0.00047830024809169223 14 | 11,1.1498141718506814,0.58465625,|,1.2198719878196715,0.5195,|,0.0004304706537531768 15 | 12,1.1299580749869347,0.5935625,|,1.1591106853485107,0.554875,|,0.0004304706537531768 16 | 13,1.117333487033844,0.608,|,1.2265379204750062,0.51425,|,0.00038742397580183486 17 | 14,1.1051008509397506,0.61390625,|,1.1441001234054566,0.58,|,0.00034868192690357824 18 | 15,1.0985774109959603,0.6211875,|,1.1319306268692018,0.57875,|,0.00034868192690357824 19 | 16,1.0924408220648765,0.6229375,|,1.1046866605281829,0.59675,|,0.00031381404802726846 20 | 17,1.0873432452082634,0.62784375,|,1.1410838220119477,0.566875,|,0.00028243292565746723 21 | 18,1.0818726527094842,0.62896875,|,1.114662103652954,0.59175,|,0.00028243292565746723 22 | 19,1.0768992357254028,0.628375,|,1.0985932338237763,0.601,|,0.00025418988728160777 23 | 20,1.0735189628005029,0.63665625,|,1.0863552334308624,0.6125,|,0.00022877112732457431 24 | 21,1.0662873446941377,0.63971875,|,1.0946032395362855,0.601375,|,0.00022877112732457431 25 | 22,1.0663728016018867,0.64096875,|,1.0941046109199524,0.610125,|,0.00020589422048633736 26 | 23,1.0613981426954269,0.64653125,|,1.0779522473812104,0.613125,|,0.00018530498374268737 27 | 24,1.0587388240098954,0.64903125,|,1.080456080675125,0.624375,|,0.00018530498374268737 28 | 25,1.05669124430418,0.6449375,|,1.0736862227916717,0.617,|,0.00016677465214307076 29 | 26,1.0569503492712975,0.647625,|,1.0762848591804504,0.616375,|,0.0001500973370261007 30 | 27,1.0533107498288155,0.6515625,|,1.0822871241569518,0.617375,|,0.0001500973370261007 31 | 28,1.0466531674861907,0.66365625,|,1.0705348944664002,0.634625,|,0.00013508773841122904 32 | 29,1.0487820727825166,0.6585,|,1.0733948266506195,0.625,|,0.00012157908614919228 33 | 
30,1.046616067171097,0.65921875,|,1.0532318587303162,0.636125,|,0.00012157908614919228 34 | 31,1.0443749601244927,0.6676875,|,1.0608603971004487,0.642375,|,0.00010942128695556 35 | 32,1.0404647966623306,0.67728125,|,1.0522499492168427,0.64425,|,9.847925673926074e-05 36 | 33,1.0376997192502022,0.67825,|,1.0515476005077362,0.658,|,9.847925673926074e-05 37 | 34,1.034475837826729,0.6885625,|,1.0506402542591096,0.666625,|,8.863141969675436e-05 38 | 35,1.031707742869854,0.693875,|,1.0642974443435669,0.6545,|,7.976835749543641e-05 39 | 36,1.030346744298935,0.69478125,|,1.058531328678131,0.6655,|,7.976835749543641e-05 40 | 37,1.0260019425749778,0.702375,|,1.0416859436035155,0.685,|,7.179159353748631e-05 41 | 38,1.0228070938587188,0.7036875,|,1.0462041351795197,0.673125,|,6.461249879623648e-05 42 | 39,1.0191641378998757,0.7050625,|,1.0447719173431396,0.673,|,6.461249879623648e-05 43 | 40,1.0178477314114571,0.70365625,|,1.0245844991207123,0.683875,|,5.815130706791989e-05 44 | 41,1.0164055927991866,0.705,|,1.029972083568573,0.6855,|,5.23362286973566e-05 45 | 42,1.012796747148037,0.70909375,|,1.0342148439884187,0.666875,|,5.23362286973566e-05 46 | 43,1.008345476925373,0.70934375,|,1.022366582632065,0.689875,|,4.710265293027387e-05 47 | 44,1.0077247146368027,0.70825,|,1.0165497152805327,0.69025,|,4.239243002967651e-05 48 | 45,1.0067220974564552,0.70778125,|,1.0156841025352479,0.6865,|,4.239243002967651e-05 49 | 46,1.0070675688385964,0.71015625,|,1.0142420885562897,0.691125,|,3.815322517993404e-05 50 | 47,1.0025808094143867,0.7103125,|,1.0182862040996552,0.681125,|,3.433793699987763e-05 51 | 48,1.0052349282503128,0.70984375,|,1.0130113682746886,0.6835,|,3.433793699987763e-05 52 | 49,1.0029132783412933,0.71040625,|,1.0123639512062073,0.68875,|,3.090417420406407e-05 53 | 50,1.0010166531801223,0.71109375,|,1.0098220613002777,0.686625,|,2.781378459744226e-05 54 | 51,0.9993769236207009,0.71059375,|,1.005019094467163,0.687,|,2.781378459744226e-05 55 | 
52,1.0007721087932586,0.71078125,|,1.0091100773811341,0.688,|,2.5032431170129203e-05 56 | 53,0.9988831198215484,0.708125,|,1.0127237865924834,0.6845,|,2.2529210582326862e-05 57 | 54,0.9997893477082253,0.71103125,|,1.0044245841503143,0.697375,|,2.2529210582326862e-05 58 | 55,0.9977093946933746,0.7103125,|,1.000846139907837,0.70125,|,2.0276309800403976e-05 59 | 56,0.997953902065754,0.711125,|,1.0006545417308808,0.688125,|,1.8248697069060645e-05 60 | 57,0.9922989777326584,0.71821875,|,1.0011090092658996,0.70225,|,1.8248697069060645e-05 61 | 58,0.9931831551194191,0.71359375,|,1.0033869709968566,0.689125,|,1.6423843785998366e-05 62 | 59,0.9931818928718567,0.70953125,|,0.9995859611034393,0.695875,|,1.4781474188872717e-05 63 | 60,0.9922252687811851,0.7169375,|,1.002436219930649,0.6925,|,1.4781474188872717e-05 64 | 61,0.9935459222197532,0.71471875,|,1.0084782528877259,0.680125,|,1.330334007332552e-05 65 | 62,0.9918156101107597,0.711,|,1.0031493334770203,0.689125,|,1.1973018039011007e-05 66 | 63,0.9899841905832291,0.71525,|,1.002759815454483,0.6805,|,1.1973018039011007e-05 67 | 64,0.9894641071557999,0.714875,|,1.000965327501297,0.6925,|,1.0775727010836917e-05 68 | 65,0.9936699901223183,0.7075625,|,0.9981847593784332,0.698625,|,9.698164007917233e-06 69 | 66,0.9910290696024895,0.7119375,|,0.9972658684253692,0.696375,|,9.698164007917233e-06 70 | 67,0.9891880717277527,0.71775,|,0.9938144853115082,0.704125,|,8.728356335481845e-06 71 | 68,0.9911971803307533,0.71196875,|,1.0007078433036805,0.692125,|,7.855528557462218e-06 72 | 69,0.9905816400051117,0.71440625,|,1.0013626449108124,0.686625,|,7.855528557462218e-06 73 | 70,0.9880280638337136,0.71584375,|,0.994016453742981,0.69525,|,7.069982771698768e-06 74 | 71,0.989318107008934,0.7125625,|,1.0014776418209075,0.694375,|,6.362990857519748e-06 75 | 72,0.9872357841134072,0.718625,|,0.9965774624347686,0.691625,|,6.362990857519748e-06 76 | 73,0.9891435065865517,0.713125,|,0.9994966144561768,0.695375,|,5.7266974984652716e-06 77 | 
74,0.9885274119973183,0.71678125,|,0.9981614172458648,0.687375,|,5.154032902651647e-06 78 | 75,0.9867126668095588,0.715625,|,0.9958910403251648,0.69275,|,5.154032902651647e-06 79 | 76,0.9856268544197082,0.72075,|,0.9973285841941834,0.69,|,4.6386342510207334e-06 80 | 77,0.9872510530948639,0.71375,|,0.9979269418716431,0.697875,|,4.174775000693661e-06 81 | 78,0.9895659605264664,0.7139375,|,0.9974048404693604,0.690375,|,4.174775000693661e-06 82 | 79,0.9855795170068741,0.71915625,|,0.9935031626224518,0.70175,|,3.7573012579255527e-06 83 | 80,0.9865176743268966,0.71346875,|,0.9988768138885498,0.6955,|,3.381574513707511e-06 84 | 81,0.9880614857673645,0.7208125,|,0.9980690622329712,0.694375,|,3.381574513707511e-06 85 | 82,0.9851745100021362,0.71790625,|,0.99214475440979,0.6935,|,3.043420105756866e-06 86 | 83,0.9869118260741234,0.7123125,|,0.9976359667778015,0.695375,|,2.7390808342620134e-06 87 | 84,0.9872344358563423,0.71909375,|,0.9979701068401337,0.6895,|,2.7390808342620134e-06 88 | 85,0.9872423234581947,0.71571875,|,0.9959442427158356,0.693375,|,2.465175216011028e-06 89 | 86,0.9880361961722374,0.71590625,|,0.9932924673557282,0.699375,|,2.218659913069838e-06 90 | 87,0.9883375192284584,0.7120625,|,0.9934261329174042,0.696125,|,2.218659913069838e-06 91 | 88,0.9849764352440834,0.720625,|,0.9994953334331512,0.6905,|,1.9967959185587727e-06 92 | 89,0.9871260563731193,0.71659375,|,0.9954492030143738,0.6935,|,1.7971181238210192e-06 93 | 90,0.9861876389980316,0.71521875,|,0.9964611933231354,0.691,|,1.7971181238210192e-06 94 | 91,0.9883370707035065,0.7141875,|,0.994308352470398,0.6925,|,1.617407928846846e-06 95 | 92,0.986999629676342,0.71459375,|,0.9947234945297241,0.69625,|,1.455668591630753e-06 96 | 93,0.9856248727440834,0.7175,|,0.9960993909835816,0.692125,|,1.455668591630753e-06 97 | 94,0.9876667613387108,0.7176875,|,0.99408451628685,0.698125,|,1.3101030425707202e-06 98 | 95,0.9872552558779717,0.71153125,|,0.9936875529289245,0.69525,|,1.1790939174075655e-06 99 | 
96,0.987247347176075,0.71546875,|,0.996220009803772,0.6965,|,1.1790939174075655e-06 100 | 97,0.9833368999958039,0.7183125,|,0.9924579939842224,0.6945,|,1.0611855868523958e-06 101 | 98,0.9866373932361603,0.7138125,|,0.9953383429050445,0.69075,|,1e-06 102 | 99,0.9864456036686897,0.71671875,|,0.9935038115978241,0.7015,|,1e-06 103 | 100,0.9882224181890488,0.71353125,|,0.9983947892189026,0.683375,|,1e-06 104 | 101,0.9864632025957107,0.71496875,|,0.9914444577693939,0.700125,|,1e-06 105 | 102,0.9879327654838562,0.710875,|,0.9993712475299835,0.690875,|,1e-06 106 | 103,0.9857791796326637,0.7131875,|,0.9947507474422455,0.697,|,1e-06 107 | 104,0.9856277720332146,0.7165,|,1.0001007421016692,0.695,|,1e-06 108 | 105,0.9851196827888489,0.71503125,|,0.9916622264385223,0.697875,|,1e-06 109 | 106,0.9857503755688667,0.7165625,|,0.9915952196121216,0.701,|,1e-06 110 | 107,0.986286814570427,0.71284375,|,0.9990080921649933,0.691875,|,1e-06 111 | 108,0.987551983654499,0.71275,|,0.9939578359127045,0.694875,|,1e-06 112 | 109,0.9852879179120064,0.71671875,|,0.9969532618522644,0.692625,|,1e-06 113 | 110,0.985292276442051,0.71384375,|,0.994267863035202,0.697875,|,1e-06 114 | 111,0.9869120596051216,0.712625,|,0.9962990345954895,0.694875,|,1e-06 115 | 116 | -------------------------------------------------------------------------------- /saved_models/LGM-Net_5way5shot.csv: -------------------------------------------------------------------------------- 1 | epoch,train_c_loss,train_c_accuracy,val_loss,val_accuracy,test_c_loss,test_c_accuracy,learning_rate 2 | 0,1.4794143481254578,0.37790477108210324,|,1.530621942838033,0.33092064311107,|,0.001 3 | 1,1.4384986971616744,0.42523810613900426,|,1.4905219535827636,0.35187302374839785,|,0.0009000009000009 4 | 2,1.417018749833107,0.4417619166933,|,1.5266384665171306,0.33676191267371175,|,0.0008100016200024299 5 | 3,1.3848891097307205,0.47709525138139725,|,1.3964020309448242,0.4518095357815425,|,0.0008100016200024299 6 | 
4,1.3117292701005936,0.528571444362402,|,1.3239590708414715,0.48292064932982126,|,0.0007290021870043739 7 | 5,1.2527057031393052,0.5641428754627704,|,1.266130963007609,0.5146666829983393,|,0.0006561026244065608 8 | 6,1.2135665994882583,0.5775238274335861,|,1.253574998219808,0.5334603346983592,|,0.0006561026244065608 9 | 7,1.1812826035618782,0.59247620908916,|,1.3123717575073242,0.5172063652276992,|,0.0005904929524588572 10 | 8,1.151051842212677,0.6182381139993668,|,1.1657802772521972,0.596571447134018,|,0.0005314441886571602 11 | 9,1.1286021245718003,0.6296666859090329,|,1.1871331020991007,0.5659682716925939,|,0.0005314441886571602 12 | 10,1.113659391105175,0.6359047807753087,|,1.1205877335866292,0.6162539876302083,|,0.00047830024809169223 13 | 11,1.1008097408413886,0.6494285898804665,|,1.1486652447382608,0.6073650981585185,|,0.0004304706537531768 14 | 12,1.0894447775483131,0.652619065463543,|,1.154638181845347,0.5969523995717366,|,0.0004304706537531768 15 | 13,1.0815625839233398,0.6638571609407663,|,1.1705582456588746,0.5756190656820933,|,0.00038742397580183486 16 | 14,1.074609143614769,0.6700952562093735,|,1.0941525882085164,0.6238730347156525,|,0.00034868192690357824 17 | 15,1.066640758395195,0.6724285894930363,|,1.1052334322929382,0.6195555752913157,|,0.00034868192690357824 18 | 16,1.062277755498886,0.6824285891056061,|,1.0794332574208578,0.632253986954689,|,0.00031381404802726846 19 | 17,1.0589765034914016,0.6804285894036293,|,1.091760593255361,0.6353016054630279,|,0.00028243292565746723 20 | 18,1.0523238880634307,0.6819523987472057,|,1.0990214296976726,0.625142876068751,|,0.00028243292565746723 21 | 19,1.0519083257913588,0.6870000175833703,|,1.138675019423167,0.6110476386547089,|,0.00025418988728160777 22 | 20,1.0473543484807015,0.6958095411062241,|,1.0929494552612304,0.642158748070399,|,0.00022877112732457431 23 | 21,1.0414988352060317,0.6935714460015296,|,1.0498243436813355,0.6681904942194621,|,0.00022877112732457431 24 | 
22,1.0405406198501588,0.6981428744494915,|,1.1237690176963806,0.6038095417022705,|,0.00020589422048633736 25 | 23,1.0361715241074563,0.701761921852827,|,1.0494243599573772,0.6741587481498719,|,0.00018530498374268737 26 | 24,1.0353034226894378,0.7035238266289234,|,1.084030127843221,0.6469841453234355,|,0.00018530498374268737 27 | 25,1.0325324236154556,0.7055238265097141,|,1.0854471696217856,0.6471111296017965,|,0.00016677465214307076 28 | 26,1.0339876180291176,0.7040952550470829,|,1.0402559378941854,0.680253986120224,|,0.0001500973370261007 29 | 27,1.0273289981484413,0.7111904927194118,|,1.0464825895627339,0.6765714464187622,|,0.0001500973370261007 30 | 28,1.0206399053931237,0.7150476355850697,|,1.0677538121541341,0.6670476366678874,|,0.00013508773841122904 31 | 29,1.0236930932998658,0.7158095403313637,|,1.0442526073455811,0.6797460496425629,|,0.00012157908614919228 32 | 30,1.0238992015719415,0.7155714449882508,|,1.0360242910385131,0.680127001841863,|,0.00012157908614919228 33 | 31,1.0202175869941712,0.7176666828095913,|,1.0326108633677165,0.6754285895824432,|,0.00010942128695556 34 | 32,1.0159290164113044,0.7232381113767624,|,1.0333374813397724,0.6893968427975973,|,9.847925673926074e-05 35 | 33,1.014186650276184,0.7272857300043106,|,1.0384174467722576,0.6817777953942616,|,9.847925673926074e-05 36 | 34,1.0131457862854003,0.7268095397055149,|,1.061963483651479,0.6662857331434886,|,8.863141969675436e-05 37 | 35,1.015402841746807,0.7264285874664783,|,1.0251935385068258,0.6961270011266073,|,7.976835749543641e-05 38 | 36,1.013168230175972,0.7243809684216976,|,1.0401739360491435,0.6816508119106293,|,7.976835749543641e-05 39 | 37,1.0090720003247262,0.7269523967802525,|,1.0358487337430318,0.6821587478319804,|,7.179159353748631e-05 40 | 38,1.0069879306554794,0.7340476345419884,|,1.0533957619667054,0.6683174784183502,|,6.461249879623648e-05 41 | 39,1.0074534192085267,0.7340952536165715,|,1.0336142470041911,0.6848254144191742,|,6.461249879623648e-05 42 | 
40,1.0077199192643165,0.7366190629601479,|,1.0204657464027405,0.700063509384791,|,5.815130706791989e-05 43 | 41,1.0086787326931954,0.7295238252282142,|,1.0293830089569092,0.6901587476730346,|,5.23362286973566e-05 44 | 42,1.00710616004467,0.7365238249003887,|,1.0310692043304444,0.6893968430360158,|,5.23362286973566e-05 45 | 43,1.0047642387151718,0.7381904916763306,|,1.022451763788859,0.6962539854844412,|,4.710265293027387e-05 46 | 44,1.004776228070259,0.7359047774076461,|,1.0285942800839742,0.6864762081305186,|,4.239243002967651e-05 47 | 45,1.0049628773927688,0.736476205766201,|,1.0198976454734803,0.7024762071768442,|,4.239243002967651e-05 48 | 46,1.003290397465229,0.7345238249599934,|,1.0179931246439615,0.7046349376837413,|,3.815322517993404e-05 49 | 47,1.0039572761058808,0.7398571581840515,|,1.0260509300231933,0.6928254141807556,|,3.433793699987763e-05 50 | 48,1.0032896136641503,0.7381904915571212,|,1.0241142010688782,0.6905397000312805,|,3.433793699987763e-05 51 | 49,1.0009824174046515,0.7423333485424518,|,1.016299942970276,0.6982857314745585,|,3.090417420406407e-05 52 | 50,1.0021914989352225,0.7429523960053921,|,1.0219259794553122,0.6918095413843791,|,2.781378459744226e-05 53 | 51,0.9987834879755974,0.7450000148415565,|,1.014530428727468,0.7069206516742707,|,2.781378459744226e-05 54 | 52,1.0018584983944894,0.7366666820049286,|,1.0158259228070576,0.6999365250269572,|,2.5032431170129203e-05 55 | 53,1.001616661787033,0.737000015437603,|,1.0143155024846395,0.7094603340625762,|,2.2529210582326862e-05 56 | 54,0.998037013053894,0.746809538692236,|,1.0153810903231304,0.6993016045888265,|,2.2529210582326862e-05 57 | 55,1.0019707735180854,0.7357619203031063,|,1.0117404109636943,0.7051428740819295,|,2.0276309800403976e-05 58 | 56,0.9988146208524704,0.745428586423397,|,1.0181722151438395,0.6968889060815175,|,1.8248697069060645e-05 59 | 57,0.997735205411911,0.7450000149309635,|,1.0146190516153972,0.7042539850076039,|,1.8248697069060645e-05 60 | 
58,1.0006219379305838,0.7371428725123406,|,1.0148262224197389,0.7005714457829794,|,1.6423843785998366e-05 61 | 59,0.999027640581131,0.7481428719460964,|,1.015043905099233,0.7003174778620402,|,1.4781474188872717e-05 62 | 60,0.997166543841362,0.7546190620660782,|,1.019954841295878,0.6958730332056682,|,1.4781474188872717e-05 63 | 61,0.9980602219700814,0.7440000151097774,|,1.0125636410713197,0.7113650958538056,|,1.330334007332552e-05 64 | 62,0.9995630125999451,0.7420476342439651,|,1.0149098768234253,0.7022222392559051,|,1.1973018039011007e-05 65 | 63,0.9981341549158096,0.7454762054383754,|,1.0123273126284282,0.7050158898830414,|,1.1973018039011007e-05 66 | 64,0.9975689168572426,0.7453333483040333,|,1.0165470487276713,0.7008254137833914,|,1.0775727010836917e-05 67 | 65,0.9975241376161575,0.7433333484828473,|,1.018069540341695,0.6957460490862528,|,9.698164007917233e-06 68 | 66,0.9974723815917969,0.7431428723037243,|,1.0140311024983724,0.7005714455445607,|,9.698164007917233e-06 69 | 67,0.9972024328112602,0.7464762053489685,|,1.0152619152069091,0.7004444613456726,|,8.728356335481845e-06 70 | 68,0.9968251070380211,0.7445238244831562,|,1.0154408113161724,0.6938412872950236,|,7.855528557462218e-06 71 | 69,1.0013459552526474,0.7434762054681778,|,1.0154476375579835,0.6952381128072739,|,7.855528557462218e-06 72 | 70,0.9961092925071716,0.7468571577966213,|,1.0151561986605326,0.6958730335235596,|,7.069982771698768e-06 73 | 71,0.9939195781946182,0.752000014692545,|,1.0161791604359944,0.7064127152760824,|,6.362990857519748e-06 74 | 72,0.9955473267436028,0.7468095387518406,|,1.0154679689407349,0.7043809694449107,|,6.362990857519748e-06 75 | 73,0.9988683499097825,0.7446666816473008,|,1.0139406994183857,0.6990476361910503,|,5.7266974984652716e-06 76 | 74,0.9980757756233215,0.7447143007218838,|,1.0159306480089823,0.7028571595350901,|,5.154032902651647e-06 77 | 75,0.9967172985076904,0.751857157498598,|,1.0101400394439697,0.7051428737640381,|,5.154032902651647e-06 78 | 
76,0.9983731025457382,0.744095253109932,|,1.0191668395996094,0.7026031911373138,|,4.6386342510207334e-06 79 | 77,0.9972426945567131,0.7439523958861828,|,1.0137869313557943,0.7037460487683614,|,4.174775000693661e-06 80 | 78,0.9958418239951133,0.7475238244831562,|,1.01552632188797,0.7055238262812297,|,4.174775000693661e-06 81 | 79,0.9946665861606598,0.7448571577966213,|,1.0185982538859049,0.6940952552159627,|,3.7573012579255527e-06 82 | 80,0.9945237467885018,0.7507143003046512,|,1.0156020121574403,0.6981587470372518,|,3.381574513707511e-06 83 | 81,0.9983999510407447,0.7454762054681778,|,1.0115817187627156,0.7081904927889506,|,3.381574513707511e-06 84 | 82,0.9978004472851754,0.7451904910802841,|,1.0162606894175212,0.6947301756540935,|,3.043420105756866e-06 85 | 83,0.9954412258863449,0.747761919438839,|,1.0129364191691081,0.7071746202309926,|,2.7390808342620134e-06 86 | 84,0.9974803032875061,0.7484285862147808,|,1.0130406432151795,0.7078095404307048,|,2.7390808342620134e-06 87 | 85,0.9963329938650132,0.7473333482444287,|,1.0123944587707518,0.7013333497842152,|,2.465175216011028e-06 88 | 86,0.9968853131532669,0.7491904909312725,|,1.0152585169474284,0.709079381386439,|,2.218659913069838e-06 89 | 87,0.9967688482999801,0.7497619194090366,|,1.0130866349538168,0.7028571597735087,|,2.218659913069838e-06 90 | 88,0.9955883451104164,0.7477619194984436,|,1.0087193665504455,0.7103492232958476,|,1.9967959185587727e-06 91 | 89,0.9963952634334564,0.7499047765731811,|,1.0119079387982686,0.7052698576450348,|,1.7971181238210192e-06 92 | 90,0.9940502173900604,0.749285728931427,|,1.0134299133618672,0.7033650964101156,|,1.7971181238210192e-06 93 | 91,0.9967742608189583,0.7468571577966213,|,1.0143127336502076,0.7045079535643259,|,1.617407928846846e-06 94 | 92,0.9993314532637596,0.7480000147819519,|,1.0140020945866903,0.7055238266785939,|,1.455668591630753e-06 95 | 93,0.9970857136249542,0.7457143004536628,|,1.0146337610880534,0.7050158898035686,|,1.455668591630753e-06 96 | 
94,0.9961443133354188,0.749476205199957,|,1.011101448535919,0.7071746197541555,|,1.3101030425707202e-06 97 | 95,0.9979928250908852,0.7442381103634834,|,1.0089339674313864,0.7097143024603526,|,1.1790939174075655e-06 98 | 96,0.9988329073190689,0.7500476338267327,|,1.0113442328770956,0.7076825562318166,|,1.1790939174075655e-06 99 | 97,0.9937769343852997,0.7515714432299138,|,1.012561622619629,0.7052698578834534,|,1.0611855868523958e-06 100 | 98,0.9956936491727829,0.7499047764539719,|,1.0098556286493938,0.7069206515947978,|,1e-06 101 | 99,0.9971122528910636,0.7433809674978257,|,1.0148551165262858,0.7062857313156128,|,1e-06 102 | 100,0.9955465306043625,0.7473809671998024,|,1.0122335311571757,0.7069206516742707,|,1e-06 103 | 101,0.9973875126838684,0.7449523957967759,|,1.0085430471102397,0.7040000166098277,|,1e-06 104 | 102,0.9960169767737389,0.7478095386326313,|,1.014762407620748,0.7097143026192984,|,1e-06 105 | 103,0.9968047832846642,0.7469047768712044,|,1.0123548758824665,0.7014603346188863,|,1e-06 106 | 104,0.9961887421011925,0.7440000150799752,|,1.0126547096570333,0.7059047785600027,|,1e-06 107 | 105,0.9937586798667908,0.7481428720355033,|,1.0146249049504599,0.7048889058430989,|,1e-06 108 | 106,0.9968654531240463,0.7458571577072144,|,1.0059006803830464,0.7084444614251455,|,1e-06 109 | 107,0.9949011039137841,0.749523824185133,|,1.0084359774589537,0.716698429107666,|,1e-06 110 | 108,0.9923029226064682,0.753714300274849,|,1.014968035697937,0.7040000169277191,|,1e-06 111 | 109,0.9976140464544296,0.7463809674084186,|,1.0085040259361266,0.7045079535643259,|,1e-06 112 | 110,0.9952442250847816,0.7496666814684868,|,1.010713170369466,0.7067936675548554,|,1e-06 113 | 111,0.9943325238227845,0.7447619197070598,|,1.0150070544878642,0.7046349376837413,|,1e-06 114 | 112,0.9955504981279373,0.7501904909312725,|,1.0113677722613017,0.7080635092258454,|,1e-06 115 | 113,0.9955820127129554,0.7462381098866463,|,1.0154381992022197,0.6940952554543813,|,1e-06 116 | 
114,0.9952063945531845,0.7489523957073688,|,1.013160351594289,0.7097143026192984,|,1e-06 117 | 115,0.9968226038813591,0.7469047767221928,|,1.0136601546605428,0.7078095406691234,|,1e-06 118 | 116,0.9975639362335205,0.7470952529907227,|,1.0127662817637126,0.6995555723508199,|,1e-06 119 | 117,0.9957504984140396,0.7466190625429153,|,1.0111416562398274,0.6994285881519318,|,1e-06 120 | 118,0.9972883882522583,0.749095252752304,|,1.0139512082735698,0.7013333506584167,|,1e-06 121 | -------------------------------------------------------------------------------- /storage.py: -------------------------------------------------------------------------------- 1 | import csv 2 | 3 | def save_statistics(experiment_name, line_to_add): 4 | with open("{}.csv".format(experiment_name), 'a') as f: 5 | writer = csv.writer(f) 6 | writer.writerow(line_to_add) 7 | 8 | def load_statistics(experiment_name): 9 | data_dict = dict() 10 | with open("{}.csv".format(experiment_name), 'r') as f: 11 | lines = f.readlines() 12 | data_labels = lines[0].replace("\n", "").split(",") 13 | del lines[0] 14 | 15 | for label in data_labels: 16 | data_dict[label] = [] 17 | 18 | for line in lines: 19 | data = line.replace("\n", "").split(",") 20 | for key, item in zip(data_labels, data): 21 | data_dict[key].append(item) 22 | return data_dict 23 | -------------------------------------------------------------------------------- /train_meta_matching_network.py: -------------------------------------------------------------------------------- 1 | from meta_matching_network import * 2 | from experiment_builder import ExperimentBuilder 3 | import tensorflow.contrib.slim as slim 4 | import data as dataset 5 | import tqdm 6 | from storage import save_statistics 7 | import tensorflow as tf 8 | 9 | import argparse 10 | 11 | 12 | if __name__ == '__main__': 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--way', type=int, default=5, 15 | help='classes per set (default: 5)') 16 | 
parser.add_argument('--shot', type=int, default=1, 17 | help='samples per class (default: 1)') 18 | parser.add_argument('--is_test', type=bool, default=False, help="Select FALSE for training, and True for testing") 19 | parser.add_argument('--ckp', type=int, default=-1, 20 | help='Select corresponding checkpoint for testing (default: -1)') 21 | opt = parser.parse_args() 22 | print(opt) 23 | 24 | tf.reset_default_graph() 25 | # Experiment Setup 26 | sp = 1 # split 27 | batch_size = int(32 // sp) # default 32 for 5way1shot 28 | classes_per_set = opt.way #20 29 | samples_per_class = opt.shot 30 | # N-way, K-shot 31 | continue_from_epoch = opt.ckp # use -1 to start from scratch 32 | logs_path = "one_shot_outputs/" 33 | experiment_name = "LGM-Net_{}way{}shot".format(classes_per_set, samples_per_class) 34 | 35 | # Experiment builder 36 | data = dataset.MiniImageNetDataSet(batch_size=batch_size, classes_per_set=classes_per_set, 37 | samples_per_class=samples_per_class, shuffle_classes=True) 38 | experiment = ExperimentBuilder(data) 39 | one_shot_miniImagenet, losses, c_error_opt_op, init = experiment.build_experiment(batch_size, 40 | classes_per_set, 41 | samples_per_class) 42 | total_epochs = 120 43 | total_train_batches = 1000 44 | total_val_batches = int(250 * sp) 45 | total_test_batches = int(250 * sp) 46 | 47 | 48 | logs="{}way{}shot learning problems, with {} tasks per task batch".format(classes_per_set, samples_per_class, batch_size) 49 | save_statistics(experiment_name, ["Experimental details: {}".format(logs)]) 50 | save_statistics(experiment_name, ["epoch", "train_c_loss", "train_c_accuracy", "val_loss", "val_accuracy", 51 | "test_c_loss", "test_c_accuracy", "learning_rate"]) 52 | 53 | 54 | # Experiment initialization and running 55 | config = tf.ConfigProto() 56 | config.gpu_options.allow_growth = True 57 | config.gpu_options.visible_device_list = "0" 58 | with tf.Session(config=config) as sess: 59 | sess.run(init) 60 | saver = tf.train.Saver(max_to_keep=5) 61 
| if continue_from_epoch != -1: #load checkpoint if needed 62 | print("Loading from checkpoint") 63 | checkpoint = "saved_models/{}_{}.ckpt".format(experiment_name, continue_from_epoch) 64 | variables_to_restore = [] 65 | tf.logging.info("The variables to restore") 66 | for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES): 67 | print(var.name, var.get_shape()) 68 | variables_to_restore.append(var) 69 | 70 | tf.logging.info('Fine-tuning from %s' % checkpoint) 71 | fine_tune = slim.assign_from_checkpoint_fn(checkpoint, variables_to_restore, ignore_missing_vars=True) 72 | fine_tune(sess) 73 | 74 | if opt.is_test: 75 | total_val_c_loss, total_val_accuracy = experiment.run_validation_epoch(total_val_batches=total_val_batches, sess=sess) 76 | print("Validating : val_loss: {}, val_accuracy: {}".format(total_val_c_loss, total_val_accuracy)) 77 | total_test_c_loss, total_test_accuracy = experiment.run_testing_epoch(total_test_batches=total_test_batches, sess=sess) 78 | print("Testing: test_loss: {}, test_accuracy: {}".format(total_test_c_loss, total_test_accuracy)) 79 | total_sg_test_accuracy, total_es_test_accuracy = experiment.run_ensemble_testing_epoch(total_test_batches=total_test_batches, sess=sess) 80 | print("Testing Ensemble: single accuracy {}, ensemble accuracy: {}".format(total_sg_test_accuracy, total_es_test_accuracy)) 81 | else: 82 | with tqdm.tqdm(total=total_epochs) as pbar_e: 83 | for e in range(0, total_epochs): 84 | total_c_loss, total_accuracy, lr = experiment.run_training_epoch(total_train_batches=total_train_batches,sess=sess) 85 | print("Epoch {}: train_loss: {}, train_accuracy: {}".format(e, total_c_loss, total_accuracy)) 86 | 87 | total_val_c_loss, total_val_accuracy = experiment.run_validation_epoch(total_val_batches=total_val_batches, sess=sess) 88 | print("Epoch {}: val_loss: {}, val_accuracy: {}".format(e, total_val_c_loss, total_val_accuracy)) 89 | 90 | total_test_c_loss, total_test_accuracy = 
experiment.run_testing_epoch(total_test_batches=total_test_batches, sess=sess) 91 | print("Epoch {}: test_loss: {}, test_accuracy: {}".format(e, total_test_c_loss, total_test_accuracy)) 92 | 93 | save_statistics(experiment_name, [e, total_c_loss, total_accuracy, total_val_c_loss, total_val_accuracy, 94 | total_test_c_loss, total_test_accuracy, 'lr: {}'.format(lr)]) 95 | 96 | save_path = saver.save(sess, "saved_models/{}_{}.ckpt".format(experiment_name, e)) 97 | pbar_e.update(1) 98 | --------------------------------------------------------------------------------