├── .gitignore ├── pytorch ├── models │ ├── __init__.py │ ├── conv2d_mtl.py │ ├── mtl.py │ └── resnet_mtl.py ├── trainer │ ├── __init__.py │ ├── pre.py │ └── meta.py ├── utils │ ├── __init__.py │ ├── gpu_tools.py │ └── misc.py ├── dataloader │ ├── __init__.py │ ├── samplers.py │ └── dataset_loader.py ├── run_pre.py ├── run_meta.py ├── README.md └── main.py ├── tensorflow ├── utils │ ├── __init__.py │ └── misc.py ├── models │ ├── __init__.py │ ├── pre_model.py │ ├── resnet12.py │ ├── meta_model.py │ └── resnet18.py ├── trainer │ ├── __init__.py │ ├── pre.py │ └── meta.py ├── data_generator │ ├── __init__.py │ ├── pre_data_generator.py │ └── meta_data_generator.py ├── README.md ├── run_experiment.py └── main.py ├── LICENSE └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | # File types 2 | *.pyc 3 | *.npy 4 | 5 | # File 6 | .DS_Store 7 | -------------------------------------------------------------------------------- /pytorch/models/__init__.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Yaoyao Liu 3 | ## Tianjin University 4 | ## liuyaoyao@tju.edu.cn 5 | ## Copyright (c) 2019 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | -------------------------------------------------------------------------------- /pytorch/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Yaoyao Liu 3 | ## Tianjin University 4 | ## liuyaoyao@tju.edu.cn 5 | ## Copyright (c) 2019 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## 
LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | -------------------------------------------------------------------------------- /pytorch/utils/__init__.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Yaoyao Liu 3 | ## Tianjin University 4 | ## liuyaoyao@tju.edu.cn 5 | ## Copyright (c) 2019 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | -------------------------------------------------------------------------------- /tensorflow/utils/__init__.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Yaoyao Liu 3 | ## Tianjin University 4 | ## liuyaoyao@tju.edu.cn 5 | ## Copyright (c) 2019 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | -------------------------------------------------------------------------------- /pytorch/dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Yaoyao Liu 3 | ## Tianjin University 4 | ## liuyaoyao@tju.edu.cn 5 | ## Copyright (c) 2019 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 
-------------------------------------------------------------------------------- /tensorflow/models/__init__.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Yaoyao Liu 3 | ## Tianjin University 4 | ## liuyaoyao@tju.edu.cn 5 | ## Copyright (c) 2019 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | -------------------------------------------------------------------------------- /tensorflow/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Yaoyao Liu 3 | ## Tianjin University 4 | ## liuyaoyao@tju.edu.cn 5 | ## Copyright (c) 2019 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | -------------------------------------------------------------------------------- /tensorflow/data_generator/__init__.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Yaoyao Liu 3 | ## Tianjin University 4 | ## liuyaoyao@tju.edu.cn 5 | ## Copyright (c) 2019 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | -------------------------------------------------------------------------------- /pytorch/utils/gpu_tools.py: 
-------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Yaoyao Liu 3 | ## Tianjin University 4 | ## liuyaoyao@tju.edu.cn 5 | ## Copyright (c) 2019 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | """ Tools for GPU. """ 11 | import os 12 | import torch 13 | 14 | def set_gpu(cuda_device): 15 | os.environ['CUDA_VISIBLE_DEVICES'] = cuda_device 16 | print('Using gpu:', cuda_device) 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Yaoyao Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /pytorch/run_pre.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Yaoyao Liu 3 | ## Tianjin University 4 | ## liuyaoyao@tju.edu.cn 5 | ## Copyright (c) 2019 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | """ Generate commands for pre-train phase. """ 11 | import os 12 | 13 | def run_exp(lr=0.1, gamma=0.2, step_size=30): 14 | max_epoch = 110 15 | shot = 1 16 | query = 15 17 | way = 5 18 | gpu = 1 19 | base_lr = 0.01 20 | 21 | the_command = 'python3 main.py' \ 22 | + ' --pre_max_epoch=' + str(max_epoch) \ 23 | + ' --shot=' + str(shot) \ 24 | + ' --train_query=' + str(query) \ 25 | + ' --way=' + str(way) \ 26 | + ' --pre_step_size=' + str(step_size) \ 27 | + ' --pre_gamma=' + str(gamma) \ 28 | + ' --gpu=' + str(gpu) \ 29 | + ' --base_lr=' + str(base_lr) \ 30 | + ' --pre_lr=' + str(lr) \ 31 | + ' --phase=pre_train' 32 | 33 | os.system(the_command) 34 | 35 | run_exp(lr=0.1, gamma=0.2, step_size=30) 36 | -------------------------------------------------------------------------------- /pytorch/dataloader/samplers.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Yaoyao Liu 3 | ## Modified from: https://github.com/Sha-Lab/FEAT 4 | ## Tianjin University 5 | ## 
liuyaoyao@tju.edu.cn 6 | ## Copyright (c) 2019 7 | ## 8 | ## This source code is licensed under the MIT-style license found in the 9 | ## LICENSE file in the root directory of this source tree 10 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 11 | """ Sampler for dataloader. """ 12 | import torch 13 | import numpy as np 14 | 15 | class CategoriesSampler(): 16 | """The class to generate episodic data""" 17 | def __init__(self, label, n_batch, n_cls, n_per): 18 | self.n_batch = n_batch 19 | self.n_cls = n_cls 20 | self.n_per = n_per 21 | 22 | label = np.array(label) 23 | self.m_ind = [] 24 | for i in range(max(label) + 1): 25 | ind = np.argwhere(label == i).reshape(-1) 26 | ind = torch.from_numpy(ind) 27 | self.m_ind.append(ind) 28 | 29 | def __len__(self): 30 | return self.n_batch 31 | def __iter__(self): 32 | for i_batch in range(self.n_batch): 33 | batch = [] 34 | classes = torch.randperm(len(self.m_ind))[:self.n_cls] 35 | for c in classes: 36 | l = self.m_ind[c] 37 | pos = torch.randperm(len(l))[:self.n_per] 38 | batch.append(l[pos]) 39 | batch = torch.stack(batch).t().reshape(-1) 40 | yield batch 41 | -------------------------------------------------------------------------------- /pytorch/run_meta.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Yaoyao Liu 3 | ## Tianjin University 4 | ## liuyaoyao@tju.edu.cn 5 | ## Copyright (c) 2019 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | """ Generate commands for meta-train phase. 
""" 11 | import os 12 | 13 | def run_exp(num_batch=1000, shot=1, query=15, lr1=0.0001, lr2=0.001, base_lr=0.01, update_step=10, gamma=0.5): 14 | max_epoch = 100 15 | way = 5 16 | step_size = 10 17 | gpu = 1 18 | 19 | the_command = 'python3 main.py' \ 20 | + ' --max_epoch=' + str(max_epoch) \ 21 | + ' --num_batch=' + str(num_batch) \ 22 | + ' --shot=' + str(shot) \ 23 | + ' --train_query=' + str(query) \ 24 | + ' --way=' + str(way) \ 25 | + ' --meta_lr1=' + str(lr1) \ 26 | + ' --meta_lr2=' + str(lr2) \ 27 | + ' --step_size=' + str(step_size) \ 28 | + ' --gamma=' + str(gamma) \ 29 | + ' --gpu=' + str(gpu) \ 30 | + ' --base_lr=' + str(base_lr) \ 31 | + ' --update_step=' + str(update_step) 32 | 33 | os.system(the_command + ' --phase=meta_train') 34 | os.system(the_command + ' --phase=meta_eval') 35 | 36 | run_exp(num_batch=100, shot=1, query=15, lr1=0.0001, lr2=0.001, base_lr=0.01, update_step=100, gamma=0.5) 37 | run_exp(num_batch=100, shot=5, query=15, lr1=0.0001, lr2=0.001, base_lr=0.01, update_step=100, gamma=0.5) 38 | -------------------------------------------------------------------------------- /pytorch/utils/misc.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Yaoyao Liu 3 | ## Modified from: https://github.com/Sha-Lab/FEAT 4 | ## Tianjin University 5 | ## liuyaoyao@tju.edu.cn 6 | ## Copyright (c) 2019 7 | ## 8 | ## This source code is licensed under the MIT-style license found in the 9 | ## LICENSE file in the root directory of this source tree 10 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 11 | """ Additional utility functions. """ 12 | import os 13 | import time 14 | import pprint 15 | import torch 16 | import numpy as np 17 | import torch.nn.functional as F 18 | 19 | def ensure_path(path): 20 | """The function to make log path. 21 | Args: 22 | path: the generated saving path. 
23 | """ 24 | if os.path.exists(path): 25 | pass 26 | else: 27 | os.mkdir(path) 28 | 29 | class Averager(): 30 | """The class to calculate the average.""" 31 | def __init__(self): 32 | self.n = 0 33 | self.v = 0 34 | 35 | def add(self, x): 36 | self.v = (self.v * self.n + x) / (self.n + 1) 37 | self.n += 1 38 | 39 | def item(self): 40 | return self.v 41 | 42 | def count_acc(logits, label): 43 | """The function to calculate the . 44 | Args: 45 | logits: input logits. 46 | label: ground truth labels. 47 | Return: 48 | The output accuracy. 49 | """ 50 | pred = F.softmax(logits, dim=1).argmax(dim=1) 51 | if torch.cuda.is_available(): 52 | return (pred == label).type(torch.cuda.FloatTensor).mean().item() 53 | return (pred == label).type(torch.FloatTensor).mean().item() 54 | 55 | class Timer(): 56 | """The class for timer.""" 57 | def __init__(self): 58 | self.o = time.time() 59 | 60 | def measure(self, p=1): 61 | x = (time.time() - self.o) / p 62 | x = int(x) 63 | if x >= 3600: 64 | return '{:.1f}h'.format(x / 3600) 65 | if x >= 60: 66 | return '{}m'.format(round(x / 60)) 67 | return '{}s'.format(x) 68 | 69 | _utils_pp = pprint.PrettyPrinter() 70 | 71 | def pprint(x): 72 | _utils_pp.pprint(x) 73 | 74 | def compute_confidence_interval(data): 75 | """The function to calculate the . 76 | Args: 77 | data: input records 78 | label: ground truth labels. 79 | Return: 80 | m: mean value 81 | pm: confidence interval. 
82 | """ 83 | a = 1.0 * np.array(data) 84 | m = np.mean(a) 85 | std = np.std(a) 86 | pm = 1.96 * (std / np.sqrt(len(a))) 87 | return m, pm 88 | -------------------------------------------------------------------------------- /tensorflow/data_generator/pre_data_generator.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Yaoyao Liu 3 | ## Modified from: https://github.com/cbfinn/maml 4 | ## Tianjin University 5 | ## liuyaoyao@tju.edu.cn 6 | ## Copyright (c) 2019 7 | ## 8 | ## This source code is licensed under the MIT-style license found in the 9 | ## LICENSE file in the root directory of this source tree 10 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 11 | 12 | """ Data generator for pre-train phase. """ 13 | import numpy as np 14 | import os 15 | import random 16 | import tensorflow as tf 17 | from tqdm import trange 18 | 19 | from tensorflow.python.platform import flags 20 | from utils.misc import get_pretrain_images 21 | 22 | FLAGS = flags.FLAGS 23 | 24 | class PreDataGenerator(object): 25 | """The class to generate episodes for pre-train phase.""" 26 | def __init__(self): 27 | self.num_classes = FLAGS.way_num 28 | self.img_size = (FLAGS.img_size, FLAGS.img_size) 29 | self.dim_input = np.prod(self.img_size)*3 30 | self.pretrain_class_num = FLAGS.pretrain_class_num 31 | self.pretrain_batch_size = FLAGS.pretrain_batch_size 32 | pretrain_folder = FLAGS.pretrain_folders 33 | 34 | pretrain_folders = [os.path.join(pretrain_folder, label) for label in os.listdir(pretrain_folder) if os.path.isdir(os.path.join(pretrain_folder, label))] 35 | self.pretrain_character_folders = pretrain_folders 36 | 37 | def make_data_tensor(self): 38 | """The function to make tensor for the tensorflow model.""" 39 | print('Generating pre-training data') 40 | all_filenames_and_labels = [] 41 | folders = 
self.pretrain_character_folders 42 | 43 | for idx, path in enumerate(folders): 44 | all_filenames_and_labels += get_pretrain_images(path, idx) 45 | random.shuffle(all_filenames_and_labels) 46 | all_labels = [li[0] for li in all_filenames_and_labels] 47 | all_filenames = [li[1] for li in all_filenames_and_labels] 48 | filename_queue = tf.train.string_input_producer(tf.convert_to_tensor(all_filenames), shuffle=False) 49 | label_queue = tf.train.slice_input_producer([tf.convert_to_tensor(all_labels)], shuffle=False) 50 | image_reader = tf.WholeFileReader() 51 | _, image_file = image_reader.read(filename_queue) 52 | 53 | image = tf.image.decode_jpeg(image_file, channels=3) 54 | image.set_shape((self.img_size[0],self.img_size[1],3)) 55 | image = tf.reshape(image, [self.dim_input]) 56 | image = tf.cast(image, tf.float32) / 255.0 57 | 58 | num_preprocess_threads = 1 59 | min_queue_examples = 256 60 | batch_image_size = self.pretrain_batch_size 61 | image_batch, label_batch = tf.train.batch([image, label_queue], batch_size = batch_image_size, num_threads=num_preprocess_threads,capacity=min_queue_examples + 3 * batch_image_size) 62 | label_batch = tf.one_hot(tf.reshape(label_batch, [-1]), self.pretrain_class_num) 63 | return image_batch, label_batch 64 | -------------------------------------------------------------------------------- /tensorflow/models/pre_model.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Yaoyao Liu 3 | ## Tianjin University 4 | ## liuyaoyao@tju.edu.cn 5 | ## Copyright (c) 2019 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | """ Models for pre-train phase. 
""" 12 | import tensorflow as tf 13 | from tensorflow.python.platform import flags 14 | 15 | FLAGS = flags.FLAGS 16 | 17 | def MakePreModel(): 18 | """The function to make pre model. 19 | Arg: 20 | Pre-train model class. 21 | """ 22 | if FLAGS.backbone_arch=='resnet12': 23 | try:#python2 24 | from resnet12 import Models 25 | except ImportError:#python3 26 | from models.resnet12 import Models 27 | elif FLAGS.backbone_arch=='resnet18': 28 | try:#python2 29 | from resnet18 import Models 30 | except ImportError:#python3 31 | from models.resnet18 import Models 32 | else: 33 | print('Please set the correct backbone') 34 | 35 | class PreModel(Models): 36 | """The class for pre-train model.""" 37 | def construct_pretrain_model(self, input_tensors=None, is_val=False): 38 | """The function to construct pre-train model. 39 | Args: 40 | input_tensors: the input tensor to construct pre-train model. 41 | is_val: whether the model is for validation. 42 | """ 43 | self.input = input_tensors['pretrain_input'] 44 | self.label = input_tensors['pretrain_label'] 45 | with tf.variable_scope('pretrain-model', reuse=None) as training_scope: 46 | self.weights = weights = self.construct_resnet_weights() 47 | self.fc_weights = fc_weights = self.construct_fc_weights() 48 | 49 | if is_val is False: 50 | self.pretrain_task_output = self.forward_fc(self.forward_pretrain_resnet(self.input, weights, reuse=False), fc_weights) 51 | self.pretrain_task_loss = self.pretrain_loss_func(self.pretrain_task_output, self.label) 52 | optimizer = tf.train.AdamOptimizer(self.pretrain_lr) 53 | self.pretrain_op = optimizer.minimize(self.pretrain_task_loss, var_list=weights.values()+fc_weights.values()) 54 | self.pretrain_task_accuracy = tf.reduce_mean(tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax( \ 55 | self.pretrain_task_output), 1), tf.argmax(self.label, 1))) 56 | tf.summary.scalar('pretrain train loss', self.pretrain_task_loss) 57 | tf.summary.scalar('pretrain train accuracy', 
self.pretrain_task_accuracy) 58 | else: 59 | self.pretrain_task_output_val = self.forward_fc(self.forward_pretrain_resnet(self.input, weights, reuse=True), fc_weights) 60 | self.pretrain_task_accuracy_val = tf.reduce_mean(tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax( \ 61 | self.pretrain_task_output_val), 1), tf.argmax(self.label, 1))) 62 | tf.summary.scalar('pretrain val accuracy', self.pretrain_task_accuracy_val) 63 | 64 | return PreModel() 65 | -------------------------------------------------------------------------------- /pytorch/dataloader/dataset_loader.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Yaoyao Liu 3 | ## Modified from: https://github.com/Sha-Lab/FEAT 4 | ## Tianjin University 5 | ## liuyaoyao@tju.edu.cn 6 | ## Copyright (c) 2019 7 | ## 8 | ## This source code is licensed under the MIT-style license found in the 9 | ## LICENSE file in the root directory of this source tree 10 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 11 | """ Dataloader for all datasets. 
""" 12 | import os.path as osp 13 | import os 14 | from PIL import Image 15 | from torch.utils.data import Dataset 16 | from torchvision import transforms 17 | import numpy as np 18 | 19 | class DatasetLoader(Dataset): 20 | """The class to load the dataset""" 21 | def __init__(self, setname, args, train_aug=False): 22 | # Set the path according to train, val and test 23 | if setname=='train': 24 | THE_PATH = osp.join(args.dataset_dir, 'train') 25 | label_list = os.listdir(THE_PATH) 26 | elif setname=='test': 27 | THE_PATH = osp.join(args.dataset_dir, 'test') 28 | label_list = os.listdir(THE_PATH) 29 | elif setname=='val': 30 | THE_PATH = osp.join(args.dataset_dir, 'val') 31 | label_list = os.listdir(THE_PATH) 32 | else: 33 | raise ValueError('Wrong setname.') 34 | 35 | # Generate empty list for data and label 36 | data = [] 37 | label = [] 38 | 39 | # Get folders' name 40 | folders = [osp.join(THE_PATH, the_label) for the_label in label_list if os.path.isdir(osp.join(THE_PATH, the_label))] 41 | 42 | # Get the images' paths and labels 43 | for idx, this_folder in enumerate(folders): 44 | this_folder_images = os.listdir(this_folder) 45 | for image_path in this_folder_images: 46 | data.append(osp.join(this_folder, image_path)) 47 | label.append(idx) 48 | 49 | # Set data, label and class number to be accessable from outside 50 | self.data = data 51 | self.label = label 52 | self.num_class = len(set(label)) 53 | 54 | # Transformation 55 | if train_aug: 56 | image_size = 80 57 | self.transform = transforms.Compose([ 58 | transforms.Resize(92), 59 | transforms.RandomResizedCrop(88), 60 | transforms.CenterCrop(image_size), 61 | transforms.RandomHorizontalFlip(), 62 | transforms.ToTensor(), 63 | transforms.Normalize(np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]), 64 | np.array([x / 255.0 for x in [63.0, 62.1, 66.7]]))]) 65 | else: 66 | image_size = 80 67 | self.transform = transforms.Compose([ 68 | transforms.Resize(92), 69 | transforms.CenterCrop(image_size), 70 | 
transforms.ToTensor(), 71 | transforms.Normalize(np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]), 72 | np.array([x / 255.0 for x in [63.0, 62.1, 66.7]]))]) 73 | 74 | 75 | def __len__(self): 76 | return len(self.data) 77 | 78 | def __getitem__(self, i): 79 | path, label = self.data[i], self.label[i] 80 | image = self.transform(Image.open(path).convert('RGB')) 81 | return image, label 82 | -------------------------------------------------------------------------------- /pytorch/README.md: -------------------------------------------------------------------------------- 1 | # Meta-Transfer Learning PyTorch 2 | [![Python](https://img.shields.io/badge/python-3.5-blue.svg?style=flat-square&logo=python&color=3776AB)](https://www.python.org/) 3 | [![PyTorch](https://img.shields.io/badge/pytorch-0.4.0-%237732a8?style=flat-square&logo=PyTorch&color=EE4C2C)](https://pytorch.org/) 4 | 5 | #### Summary 6 | 7 | * [Installation](#installation) 8 | * [Project Architecture](#project-architecture) 9 | * [Running Experiments](#running-experiments) 10 | 11 | 12 | ## Installation 13 | 14 | In order to run this repository, we advise you to install python 3.5 and PyTorch 0.4.0 with Anaconda. 15 | 16 | You may download Anaconda and read the installation instruction on their official website: 17 | 18 | 19 | Create a new environment and install PyTorch and torchvision on it: 20 | 21 | ```bash 22 | conda create --name mtl-pytorch python=3.5 23 | conda activate mtl-pytorch 24 | conda install pytorch=0.4.0 25 | conda install torchvision -c pytorch 26 | ``` 27 | 28 | Install other requirements: 29 | ```bash 30 | pip install tqdm tensorboardX miniimagenettools 31 | ``` 32 | 33 | Clone this repository: 34 | 35 | ```bash 36 | git clone https://github.com/yaoyao-liu/meta-transfer-learning.git 37 | cd meta-transfer-learning/pytorch 38 | ``` 39 | 40 | ## Project Architecture 41 | 42 | ``` 43 | . 
44 | ├── dataloader 45 | | ├── dataset_loader.py # data loader for all datasets 46 | | └── samplers.py # samplers for meta train 47 | ├── models 48 | | ├── mtl.py # meta-transfer class 49 | | ├── resnet_mtl.py # resnet class 50 | | └── conv2d_mtl.py # meta-transfer convolution class 51 | ├── trainer 52 | | ├── pre.py # pre-train trainer class 53 | | └── meta.py # meta-train trainer class 54 | ├── utils 55 | | ├── gpu_tools.py # GPU tool functions 56 | | └── misc.py # miscellaneous tool functions 57 | ├── main.py # the python file with main function and parameter settings 58 | ├── run_pre.py # the script to run pre-train phase 59 | └── run_meta.py # the script to run meta-train and meta-test phases 60 | ``` 61 | 62 | ## Running Experiments 63 | 64 | Run the pre-train phase: 65 | ```bash 66 | python run_pre.py 67 | ``` 68 | Run the meta-train and meta-test phases: 69 | ```bash 70 | python run_meta.py 71 | ``` 72 | 73 | ### Hyperparameters and Options 74 | Hyperparameters and options are defined in `main.py`.
75 | 76 | - `model_type` The network architecture 77 | - `dataset` Meta dataset 78 | - `phase` pre-train, meta-train or meta-eval 79 | - `seed` Manual seed for PyTorch, "0" means using random seed 80 | - `gpu` GPU id 81 | - `dataset_dir` Directory for the images 82 | - `max_epoch` Epoch number for meta-train phase 83 | - `num_batch` The number of different tasks used for meta-train 84 | - `shot` Shot number, how many samples for one class in a task 85 | - `way` Way number, how many classes in a task 86 | - `train_query` The number of training samples for each class in a task 87 | - `val_query` The number of test samples for each class in a task 88 | - `meta_lr1` Learning rate for SS weights 89 | - `meta_lr2` Learning rate for FC weights 90 | - `base_lr` Learning rate for the inner loop 91 | - `update_step` The number of updates for the inner loop 92 | - `step_size` The number of epochs between reductions of the meta learning rates 93 | - `gamma` Gamma for the meta-train learning rate decay 94 | - `init_weights` The pretrained weights for meta-train phase 95 | - `eval_weights` The meta-trained weights for meta-eval phase 96 | - `meta_label` Additional label for meta-train 97 | - `pre_max_epoch` Epoch number for pre-train phase 98 | - `pre_batch_size` Batch size for pre-train phase 99 | - `pre_lr` Learning rate for pre-train phase 100 | - `pre_gamma` Gamma for the pre-train learning rate decay 101 | - `pre_step_size` The number of epochs between reductions of the pre-train learning rate 102 | - `pre_custom_weight_decay` Weight decay for the optimizer during pre-train 103 | 104 | -------------------------------------------------------------------------------- /tensorflow/trainer/pre.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Yaoyao Liu 3 | ## Modified from: https://github.com/cbfinn/maml 4 | ## Tianjin University 5 | ## liuyaoyao@tju.edu.cn 6 | ##
Copyright (c) 2019 7 | ## 8 | ## This source code is licensed under the MIT-style license found in the 9 | ## LICENSE file in the root directory of this source tree 10 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 11 | 12 | """ Trainer for pre-train phase. """ 13 | import os 14 | import numpy as np 15 | import tensorflow as tf 16 | 17 | from tqdm import trange 18 | from data_generator.pre_data_generator import PreDataGenerator 19 | from models.pre_model import MakePreModel 20 | from tensorflow.python.platform import flags 21 | 22 | FLAGS = flags.FLAGS 23 | 24 | class PreTrainer: 25 | """The class that contains the code for the pre-train phase""" 26 | def __init__(self): 27 | # This class defines the pre-train phase trainer 28 | print('Generating pre-training classes') 29 | 30 | # Generate Pre-train Data Tensors 31 | pre_train_data_generator = PreDataGenerator() 32 | pretrain_input, pretrain_label = pre_train_data_generator.make_data_tensor() 33 | pre_train_input_tensors = {'pretrain_input': pretrain_input, 'pretrain_label': pretrain_label} 34 | 35 | # Build Pre-train Model 36 | self.model = MakePreModel() 37 | self.model.construct_pretrain_model(input_tensors=pre_train_input_tensors) 38 | self.model.pretrain_summ_op = tf.summary.merge_all() 39 | 40 | # Start the TensorFlow Session 41 | if FLAGS.full_gpu_memory_mode: 42 | gpu_config = tf.ConfigProto() 43 | gpu_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_rate 44 | self.sess = tf.InteractiveSession(config=gpu_config) 45 | else: 46 | self.sess = tf.InteractiveSession() 47 | 48 | # Initialize and Start the Pre-train Phase 49 | tf.global_variables_initializer().run() 50 | tf.train.start_queue_runners() 51 | self.pre_train() 52 | 53 | def pre_train(self): 54 | # Load Parameters from FLAGS 55 | pretrain_iterations = FLAGS.pretrain_iterations 56 | weights_save_dir_base = FLAGS.pretrain_dir 57 | pre_save_str = FLAGS.pre_string 58 | 59 | # Build Pre-train Log Folder 60 | 
weights_save_dir = os.path.join(weights_save_dir_base, pre_save_str) 61 | if not os.path.exists(weights_save_dir): 62 | os.mkdir(weights_save_dir) 63 | pretrain_writer = tf.summary.FileWriter(weights_save_dir, self.sess.graph) 64 | pre_lr = FLAGS.pre_lr 65 | 66 | print('Start pre-train phase') 67 | print('Pre-train Hyper parameters: ' + pre_save_str) 68 | 69 | # Start the iterations 70 | for itr in trange(pretrain_iterations): 71 | # Generate the Feed Dict and Run the Optimizer 72 | feed_dict = {self.model.pretrain_lr: pre_lr} 73 | input_tensors = [self.model.pretrain_op, self.model.pretrain_summ_op] 74 | input_tensors.extend([self.model.pretrain_task_loss, self.model.pretrain_task_accuracy]) 75 | result = self.sess.run(input_tensors, feed_dict) 76 | 77 | # Print Results during Training 78 | if (itr!=0) and itr % FLAGS.pre_print_step == 0: 79 | print_str = '[*] Pre Loss: ' + str(result[-2]) + ', Pre Acc: ' + str(result[-1]) 80 | print(print_str) 81 | 82 | # Write the TensorFlow Summery 83 | if itr % FLAGS.pre_sum_step == 0: 84 | pretrain_writer.add_summary(result[1], itr) 85 | 86 | # Decrease the Learning Rate after Some Iterations 87 | if (itr!=0) and itr % FLAGS.pre_lr_dropstep == 0: 88 | pre_lr = pre_lr * 0.5 89 | if FLAGS.pre_lr_stop and pre_lr < FLAGS.min_pre_lr: 90 | pre_lr = FLAGS.min_pre_lr 91 | 92 | # Save Pre-train Model 93 | if (itr!=0) and itr % FLAGS.pre_save_step == 0: 94 | print('Saving pretrain weights to npy') 95 | weights = self.sess.run(self.model.weights) 96 | np.save(os.path.join(weights_save_dir, "weights_{}.npy".format(itr)), weights) 97 | -------------------------------------------------------------------------------- /pytorch/models/conv2d_mtl.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Yaoyao Liu 3 | ## Modified from: https://github.com/pytorch/pytorch 4 | ## Tianjin University 5 | ## 
## liuyaoyao@tju.edu.cn
## Copyright (c) 2019
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
""" MTL CONV layers. """
import math
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
from torch.nn.modules.utils import _pair

class _ConvNdMtl(Module):
    """Base class for meta-transfer (scaling-and-shifting) convolutions.

    ``weight`` and ``bias`` are frozen (``requires_grad=False``). The trainable
    ``mtl_weight`` holds one scale per (filter, input-channel-group) pair and
    is broadcast over the kernel; ``mtl_bias`` holds one shift per output
    channel. Parameters are registered in the order weight, mtl_weight, bias,
    mtl_bias.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation, transposed, output_padding, groups, bias):
        super(_ConvNdMtl, self).__init__()
        if in_channels % groups:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups

        # Transposed convolutions lay weights out as (in, out/groups, *k);
        # regular ones as (out, in/groups, *k).
        if transposed:
            filter_dims = (in_channels, out_channels // groups)
        else:
            filter_dims = (out_channels, in_channels // groups)
        self.weight = Parameter(torch.Tensor(*filter_dims, *kernel_size))
        self.mtl_weight = Parameter(torch.ones(*filter_dims, 1, 1))
        self.weight.requires_grad = False

        if not bias:
            self.register_parameter('bias', None)
            self.register_parameter('mtl_bias', None)
        else:
            self.bias = Parameter(torch.Tensor(out_channels))
            self.bias.requires_grad = False
            self.mtl_bias = Parameter(torch.zeros(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize all parameters.

        The frozen weight/bias are drawn from U(-b, b) with
        b = 1 / sqrt(in_channels * prod(kernel_size)); the MTL scale is reset
        to 1 and the MTL shift to 0. The degenerate uniform_(1, 1) /
        uniform_(0, 0) calls are kept (instead of fill_) because they may
        advance the RNG state, which matters for seeded runs.
        """
        fan_in = self.in_channels
        for extent in self.kernel_size:
            fan_in *= extent
        bound = 1. / math.sqrt(fan_in)
        self.weight.data.uniform_(-bound, bound)
        self.mtl_weight.data.uniform_(1, 1)
        if self.bias is None:
            return
        self.bias.data.uniform_(-bound, bound)
        self.mtl_bias.data.uniform_(0, 0)

    def extra_repr(self):
        """Describe the layer; non-default options are appended only when set."""
        parts = ['{in_channels}, {out_channels}, kernel_size={kernel_size}',
                 ', stride={stride}']
        if self.padding != (0,) * len(self.padding):
            parts.append(', padding={padding}')
        if self.dilation != (1,) * len(self.dilation):
            parts.append(', dilation={dilation}')
        if self.output_padding != (0,) * len(self.output_padding):
            parts.append(', output_padding={output_padding}')
        if self.groups != 1:
            parts.append(', groups={groups}')
        if self.bias is None:
            parts.append(', bias=False')
        return ''.join(parts).format(**self.__dict__)

class Conv2dMtl(_ConvNdMtl):
    """2-D meta-transfer convolution.

    Computes F.conv2d with weight * mtl_weight (scale broadcast over the
    kernel) and bias + mtl_bias (shift), when a bias exists.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv2dMtl, self).__init__(
            in_channels, out_channels, _pair(kernel_size), _pair(stride),
            _pair(padding), _pair(dilation), False, _pair(0), groups, bias)

    def forward(self, inp):
        scaled_weight = self.weight.mul(self.mtl_weight.expand(self.weight.shape))
        shifted_bias = None if self.bias is None else self.bias + self.mtl_bias
        return F.conv2d(inp, scaled_weight, shifted_bias, self.stride,
                        self.padding, self.dilation, self.groups)

# --------------------------------------------------------------------------------
# /pytorch/main.py:
# --------------------------------------------------------------------------------
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Yaoyao Liu
## Tianjin University
## liuyaoyao@tju.edu.cn
## Copyright (c) 2019
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
""" Main function for this repo. """
import argparse
import torch
from utils.misc import pprint
from utils.gpu_tools import set_gpu
from trainer.meta import MetaTrainer
from trainer.pre import PreTrainer

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Basic parameters
    parser.add_argument('--model_type', type=str, default='ResNet', choices=['ResNet']) # The network architecture
    # The default 'MiniImageNet' was missing from `choices`, so passing the default
    # value explicitly (--dataset MiniImageNet) was rejected by argparse. Both
    # spellings are accepted now; the default itself is unchanged.
    parser.add_argument('--dataset', type=str, default='MiniImageNet',
                        choices=['MiniImageNet', 'miniImageNet', 'tieredImageNet', 'FC100']) # Dataset
    parser.add_argument('--phase', type=str, default='meta_train', choices=['pre_train', 'meta_train', 'meta_eval']) # Phase
    parser.add_argument('--seed', type=int, default=0) # Manual seed for PyTorch, "0" means using random seed
    parser.add_argument('--gpu', default='1') # GPU id
    parser.add_argument('--dataset_dir', type=str, default='./data/mini/') # Dataset folder

    # Parameters for meta-train phase
    parser.add_argument('--max_epoch', type=int, default=100) # Epoch number for meta-train phase
    parser.add_argument('--num_batch', type=int, default=100) # The number for different tasks used for meta-train
    parser.add_argument('--shot', type=int, default=1) # Shot number, how many samples for one class in a task
    parser.add_argument('--way', type=int, default=5) # Way number, how many classes in a task
    parser.add_argument('--train_query', type=int, default=15) # The number of training samples for each class in a task
    parser.add_argument('--val_query', type=int, default=15) # The number of test samples for each class in a task
    parser.add_argument('--meta_lr1', type=float, default=0.0001) # Learning rate for SS weights
    parser.add_argument('--meta_lr2', type=float, default=0.001) # Learning rate for FC weights
    parser.add_argument('--base_lr', type=float, default=0.01) # Learning rate for the inner loop
    parser.add_argument('--update_step', type=int, default=50) # The number of updates for the inner loop
    parser.add_argument('--step_size', type=int, default=10) # The number of epochs to reduce the meta learning rates
    parser.add_argument('--gamma', type=float, default=0.5) # Gamma for the meta-train learning rate decay
    parser.add_argument('--init_weights', type=str, default=None) # The pre-trained weights for meta-train phase
    parser.add_argument('--eval_weights', type=str, default=None) # The meta-trained weights for meta-eval phase
    parser.add_argument('--meta_label', type=str, default='exp1') # Additional label for meta-train

    # Parameters for pre-train phase
    parser.add_argument('--pre_max_epoch', type=int, default=100) # Epoch number for pre-train phase
    parser.add_argument('--pre_batch_size', type=int, default=128) # Batch size for pre-train phase
    parser.add_argument('--pre_lr', type=float, default=0.1) # Learning rate for pre-train phase
    parser.add_argument('--pre_gamma', type=float, default=0.2) # Gamma for the pre-train learning rate decay
    parser.add_argument('--pre_step_size', type=int, default=30) # The number of epochs to reduce the pre-train learning rate
    parser.add_argument('--pre_custom_momentum', type=float, default=0.9) # Momentum for the optimizer during pre-train
    parser.add_argument('--pre_custom_weight_decay', type=float, default=0.0005) # Weight decay for the optimizer during pre-train

    # Set and print the parameters
    args = parser.parse_args()
    pprint(vars(args))

    # Set the GPU id
    set_gpu(args.gpu)

    # Set manual seed for PyTorch. seed == 0 keeps cuDNN autotuning (fast,
    # non-deterministic); any other seed trades speed for reproducibility.
    if args.seed == 0:
        print ('Using random seed.')
        torch.backends.cudnn.benchmark = True
    else:
        print ('Using manual seed:', args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Start trainer for pre-train, meta-train or meta-eval
    if args.phase == 'meta_train':
        trainer = MetaTrainer(args)
        trainer.train()
    elif args.phase == 'meta_eval':
        trainer = MetaTrainer(args)
        trainer.eval()
    elif args.phase == 'pre_train':
        trainer = PreTrainer(args)
        trainer.train()
    else:
        raise ValueError('Please set correct phase.')

# --------------------------------------------------------------------------------
# /tensorflow/README.md:
# --------------------------------------------------------------------------------
# 1 | # Meta-Transfer Learning TensorFlow 2 | [![Python](https://img.shields.io/badge/python-2.7%20%7C%203.5-blue.svg?style=flat-square&logo=python&color=3776AB)](https://www.python.org/) 3 | [![TensorFlow](https://img.shields.io/badge/tensorflow-1.3.0-orange.svg?style=flat-square&logo=tensorflow&color=FF6F00)](https://github.com/y2l/meta-transfer-learning/tree/master/tensorflow) 4 | 5 | #### Summary 6 | 7 | * [Installation](#installation) 8 | * [Project Architecture](#project-architecture) 9 | * [Running Experiments](#running-experiments) 10 | 11 | 12 | ## Installation 13 | 14 | In order to run this repository, we advise you to install python 2.7 or 3.5 and TensorFlow 1.3.0 with Anaconda.
15 | 16 | You may download Anaconda and read the installation instructions on their official website: 17 | 18 | 19 | Create a new environment and install TensorFlow in it: 20 | 21 | ```bash 22 | conda create --name mtl-tf python=2.7 23 | conda activate mtl-tf 24 | conda install tensorflow-gpu=1.3.0 25 | ``` 26 | 27 | Install other requirements: 28 | ```bash 29 | pip install scipy tqdm opencv-python pillow matplotlib miniimagenettools 30 | ``` 31 | 32 | Clone this repository: 33 | 34 | ```bash 35 | git clone https://github.com/yaoyao-liu/meta-transfer-learning.git 36 | cd meta-transfer-learning/tensorflow 37 | ``` 38 | 39 | ## Project Architecture 40 | 41 | ``` 42 | . 43 | ├── data_generator # dataset generator 44 | | ├── pre_data_generator.py # data generator for pre-train phase 45 | | └── meta_data_generator.py # data generator for meta-train phase 46 | ├── models # tensorflow model files 47 | | ├── resnet12.py # resnet12 class 48 | | ├── resnet18.py # resnet18 class 49 | | ├── pre_model.py # pre-train model class 50 | | └── meta_model.py # meta-train model class 51 | ├── trainer # tensorflow trainer files 52 | | ├── pre.py # pre-train trainer class 53 | | └── meta.py # meta-train trainer class 54 | ├── utils # a series of tools used in this repo 55 | | └── misc.py # miscellaneous tool functions 56 | ├── main.py # the python file with main function and parameter settings 57 | └── run_experiment.py # the script to run the whole experiment 58 | ``` 59 | 60 | ## Running Experiments 61 | 62 | ### Training from Scratch 63 | Run pre-train phase: 64 | ```bash 65 | python run_experiment.py PRE 66 | ``` 67 | Run meta-train and meta-test phases: 68 | ```bash 69 | python run_experiment.py META 70 | ``` 71 | 72 | ### Hyperparameters and Options 73 | You may edit the `run_experiment.py` file to change the hyperparameters and options.
74 | 75 | - `LOG_DIR` Name of the folder to save the log files 76 | - `GPU_ID` GPU device id 77 | - `NET_ARCH` Network backbone (resnet12 or resnet18) 78 | - `PRE_TRA_LABEL` Additional label for pre-train model 79 | - `PRE_TRA_ITER_MAX` Iteration number for the pre-train phase 80 | - `PRE_TRA_DROP` Dropout keep rate for the pre-train phase 81 | - `PRE_DROP_STEP` Number of iterations between pre-train learning rate reductions 82 | - `PRE_LR` Pre-train learning rate 83 | - `SHOT_NUM` Number of samples per class (shot number) 84 | - `WAY_NUM` Number of classes (way number) for the few-shot tasks 85 | - `MAX_ITER` Iteration number for meta-train phase 86 | - `META_BATCH_SIZE` Meta batch size 87 | - `PRE_ITER` Iteration number for the pre-train model used in the meta-train phase 88 | - `UPDATE_NUM` Epoch number for the base learner 89 | - `SAVE_STEP` Iteration number to save the meta model 90 | - `META_LR` Meta learning rate 91 | - `META_LR_MIN` Meta learning rate min value 92 | - `LR_DROP_STEP` Number of iterations between meta learning rate reductions 93 | - `BASE_LR` Base learning rate 94 | - `PRE_TRA_DIR` Directory for the pre-train phase images 95 | - `META_TRA_DIR` Directory for the meta-train images 96 | - `META_VAL_DIR` Directory for the meta-validation images 97 | - `META_TES_DIR` Directory for the meta-test images 98 | 99 | The file `run_experiment.py` is just a script to generate commands for `main.py`. If you want to change other settings, please see the comments and descriptions in `main.py`. 100 | 101 | ### Using Downloaded Models 102 | In the default setting, you first run the pre-train phase (`python run_experiment.py PRE`) and then the meta-train phase (`python run_experiment.py META`). If you want to use the model pretrained by us, you may download it from the following link. To run experiments with the downloaded model, please make sure you are using python 2.7.
# 103 | 104 | Comparison of the original paper and the open-source code in terms of test set accuracy: 105 | 106 | | (%) | 𝑚𝑖𝑛𝑖 1-shot | 𝑚𝑖𝑛𝑖 5-shot | FC100 1-shot | FC100 5-shot | 107 | | ---------------------- | ------------ | ------------ | ------------ | ------------ | 108 | | `MTL Paper` | `60.2 ± 1.8` | `74.3 ± 0.9` | `43.6 ± 1.8` | `55.4 ± 0.9` | 109 | | `This Repo` | `60.8 ± 1.8` | `74.3 ± 0.9` | `44.3 ± 1.8` | `56.8 ± 1.0` | 110 | 111 | Download models: [\[Google Drive\]](https://drive.google.com/drive/folders/1MzH2enwLKuzmODYAEATnyiP_602zrdrE?usp=sharing) 112 | 113 | Move the downloaded npy files to `./logs/download_weights` (e.g. 𝑚𝑖𝑛𝑖ImageNet, 1-shot): 114 | ```bash 115 | mkdir -p ./logs/download_weights 116 | mv ~/downloads/mini-1shot/*.npy ./logs/download_weights 117 | ``` 118 | 119 | Run meta-train with downloaded model: 120 | ```bash 121 | python run_experiment.py META_LOAD 122 | ``` 123 | 124 | Run meta-test with downloaded model: 125 | ```bash 126 | python run_experiment.py TEST_LOAD 127 | ``` 128 | 129 |
# --------------------------------------------------------------------------------
# /pytorch/models/mtl.py:
# --------------------------------------------------------------------------------
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Yaoyao Liu
## Tianjin University
## liuyaoyao@tju.edu.cn
## Copyright (c) 2019
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
""" Model for meta-transfer learning. """
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.resnet_mtl import ResNetMtl

class BaseLearner(nn.Module):
    """The class for inner loop.

    A single linear classifier ``F.linear(x, fc1_w, fc1_b)`` with
    ``fc1_w`` of shape (args.way, z_dim) (Kaiming-normal init) and ``fc1_b``
    of shape (args.way,) (zero init). Both live in ``self.vars`` so callers
    can substitute "fast weights" via ``forward(x, the_vars)``.
    """
    def __init__(self, args, z_dim):
        super().__init__()
        self.args = args
        self.z_dim = z_dim
        self.vars = nn.ParameterList()
        self.fc1_w = nn.Parameter(torch.ones([self.args.way, self.z_dim]))
        torch.nn.init.kaiming_normal_(self.fc1_w)
        self.vars.append(self.fc1_w)
        self.fc1_b = nn.Parameter(torch.zeros(self.args.way))
        self.vars.append(self.fc1_b)

    def forward(self, input_x, the_vars=None):
        """Classify ``input_x``.

        Args:
          input_x: features whose last dimension is z_dim.
          the_vars: optional [weight, bias] sequence used instead of
            ``self.vars`` (the inner-loop fast weights).
        Returns:
          logits of size ``args.way`` in the last dimension.
        """
        if the_vars is None:
            the_vars = self.vars
        fc1_w = the_vars[0]
        fc1_b = the_vars[1]
        return F.linear(input_x, fc1_w, fc1_b)

    def parameters(self, recurse=True):
        # Overrides nn.Module.parameters to return the ParameterList itself.
        # `recurse` is accepted (and ignored; this module has no children) so
        # callers using the standard nn.Module signature keep working.
        return self.vars

class MtlLearner(nn.Module):
    """The class for outer loop.

    mode 'meta' uses the MTL encoder (``ResNetMtl()``); any other mode
    ('pre', 'preval') uses a plain encoder (``ResNetMtl(mtl=False)``) plus the
    pre-train classification head ``pre_fc``.
    """
    def __init__(self, args, mode='meta', num_cls=64):
        super().__init__()
        self.args = args
        self.mode = mode
        self.update_lr = args.base_lr
        self.update_step = args.update_step
        # Encoder feature dimension; both classifier heads consume 640-d features.
        z_dim = 640
        self.base_learner = BaseLearner(args, z_dim)

        if self.mode == 'meta':
            self.encoder = ResNetMtl()
        else:
            self.encoder = ResNetMtl(mtl=False)
            self.pre_fc = nn.Sequential(nn.Linear(z_dim, 1000), nn.ReLU(), nn.Linear(1000, num_cls))

    def forward(self, inp):
        """The function to forward the model.
        Args:
          inp: input images for mode 'pre'; a (data_shot, label_shot,
            data_query) tuple for modes 'meta' and 'preval'.
        Returns:
          the outputs of MTL model.
        Raises:
          ValueError: if ``self.mode`` is not 'pre', 'meta' or 'preval'.
        """
        if self.mode == 'pre':
            return self.pretrain_forward(inp)
        elif self.mode == 'meta':
            data_shot, label_shot, data_query = inp
            return self.meta_forward(data_shot, label_shot, data_query)
        elif self.mode == 'preval':
            data_shot, label_shot, data_query = inp
            return self.preval_forward(data_shot, label_shot, data_query)
        else:
            raise ValueError('Please set the correct mode.')

    def pretrain_forward(self, inp):
        """The function to forward pretrain phase.
        Args:
          inp: input images.
        Returns:
          the outputs of pretrain model (num_cls logits).
        """
        return self.pre_fc(self.encoder(inp))

    def _inner_loop(self, embedding_shot, label_shot, embedding_query, lr, steps):
        """Adapt the base learner on the shot set, then classify the query set.

        Performs ``max(steps, 1)`` plain gradient-descent steps of size ``lr``
        on the cross-entropy of the shot set. ``torch.autograd.grad`` is
        called without ``create_graph``, so the inner gradients are treated as
        constants (first-order update). Only the final fast weights are applied
        to the query embeddings.
        """
        fast_weights = list(self.base_learner.parameters())
        for _ in range(max(steps, 1)):
            logits = self.base_learner(embedding_shot, fast_weights)
            loss = F.cross_entropy(logits, label_shot)
            grad = torch.autograd.grad(loss, fast_weights)
            fast_weights = [weight - lr * g for g, weight in zip(grad, fast_weights)]
        return self.base_learner(embedding_query, fast_weights)

    def meta_forward(self, data_shot, label_shot, data_query):
        """The function to forward meta-train phase.
        Args:
          data_shot: train images for the task
          label_shot: train labels for the task
          data_query: test images for the task.
        Returns:
          logits_q: the predictions for the test samples.
        """
        # Query is encoded before shot, as before (order affects BatchNorm
        # running statistics in training mode).
        embedding_query = self.encoder(data_query)
        embedding_shot = self.encoder(data_shot)
        return self._inner_loop(embedding_shot, label_shot, embedding_query,
                                self.update_lr, self.update_step)

    def preval_forward(self, data_shot, label_shot, data_query, update_lr=0.01, update_step=100):
        """The function to forward meta-validation during pretrain phase.
        Args:
          data_shot: train images for the task
          label_shot: train labels for the task
          data_query: test images for the task.
          update_lr: inner-loop step size (default 0.01, the previous constant).
          update_step: number of inner-loop steps (default 100, the previous constant).
        Returns:
          logits_q: the predictions for the test samples.
        """
        embedding_query = self.encoder(data_query)
        embedding_shot = self.encoder(data_shot)
        return self._inner_loop(embedding_shot, label_shot, embedding_query,
                                update_lr, update_step)

# --------------------------------------------------------------------------------
# /tensorflow/run_experiment.py:
# --------------------------------------------------------------------------------
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Yaoyao Liu
## Tianjin University
## liuyaoyao@tju.edu.cn
## Copyright (c) 2019
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

""" Generate commands for main.py """
import os
import sys

def run_experiment(PHASE='META'):
    """The function to generate commands to run the experiments.
    Arg:
      PHASE: the phase for MTL. 'PRE' means pre-train phase, 'META' means
        meta-train and meta-test phases, 'META_LOAD' / 'TEST_LOAD' run
        meta-train / meta-test with downloaded weights (--load_saved_weights=True).
    Raises:
      ValueError: if PHASE is not one of the four values above.
    """
    # Some important options
    # Please note that not all the options are shown here. For more detailed options, please edit main.py

    # Basic options
    LOG_DIR = 'experiment_results' # Name of the folder to save the log files
    GPU_ID = 1 # GPU device id
    NET_ARCH = 'resnet12' # Network backbone (resnet12 or resnet18)

    # Pre-train phase options
    PRE_TRA_LABEL = 'normal' # Additional label for pre-train model
    PRE_TRA_ITER_MAX = 20000 # Iteration number for the pre-train phase
    PRE_TRA_DROP = 0.9 # Dropout keep rate for the pre-train phase
    PRE_DROP_STEP = 5000 # Iteration number for the pre-train learning rate reducing
    PRE_LR = 0.001 # Pre-train learning rate

    # Meta options
    SHOT_NUM = 1 # Shot number for the few-shot tasks
    WAY_NUM = 5 # Class number for the few-shot tasks
    MAX_ITER = 20000 # Iteration number for meta-train
    META_BATCH_SIZE = 2 # Meta batch size
    PRE_ITER = 10000 # Iteration number for the pre-train model used in the meta-train phase
    UPDATE_NUM = 20 # Epoch number for the base learning
    SAVE_STEP = 100 # Iteration number to save the meta model
    META_LR = 0.001 # Meta learning rate
    META_LR_MIN = 0.0001 # Meta learning rate min value
    LR_DROP_STEP = 1000 # The iteration number for the meta learning rate reducing
    BASE_LR = 0.001 # Base learning rate

    # Data directories
    PRE_TRA_DIR = './data/mini-imagenet/train' # Directory for the pre-train phase images
    META_TRA_DIR = './data/mini-imagenet/train' # Directory for the meta-train images
    META_VAL_DIR = './data/mini-imagenet/val' # Directory for the meta-validation images
    META_TES_DIR = './data/mini-imagenet/test' # Directory for the meta-test images

    # Generate the base command for main.py
    base_command = 'python main.py' \
        + ' --backbone_arch=' + str(NET_ARCH) \
        + ' --metatrain_iterations=' + str(MAX_ITER) \
        + ' --meta_batch_size=' + str(META_BATCH_SIZE) \
        + ' --shot_num=' + str(SHOT_NUM) \
        + ' --meta_lr=' + str(META_LR) \
        + ' --min_meta_lr=' + str(META_LR_MIN) \
        + ' --base_lr=' + str(BASE_LR)\
        + ' --train_base_epoch_num=' + str(UPDATE_NUM) \
        + ' --way_num=' + str(WAY_NUM) \
        + ' --exp_log_label=' + LOG_DIR \
        + ' --pretrain_dropout_keep=' + str(PRE_TRA_DROP) \
        + ' --activation=leaky_relu' \
        + ' --pre_lr=' + str(PRE_LR)\
        + ' --pre_lr_dropstep=' + str(PRE_DROP_STEP) \
        + ' --meta_save_step=' + str(SAVE_STEP) \
        + ' --lr_drop_step=' + str(LR_DROP_STEP) \
        + ' --pretrain_folders=' + PRE_TRA_DIR \
        + ' --pretrain_label=' + PRE_TRA_LABEL \
        + ' --device_id=' + str(GPU_ID) \
        + ' --metatrain_dir=' + META_TRA_DIR \
        + ' --metaval_dir=' + META_VAL_DIR \
        + ' --metatest_dir=' + META_TES_DIR

    def process_test_command(TEST_STEP, in_command):
        """The function to adapt the base command to the meta-test phase.
        Args:
          TEST_STEP: the iteration number for the meta model to be loaded.
          in_command: the input base command.
        Return:
          Processed command.
        """
        output_test_command = in_command \
            + ' --phase=meta' \
            + ' --pretrain_iterations=' + str(PRE_ITER) \
            + ' --metatrain=False' \
            + ' --test_iter=' + str(TEST_STEP)
        return output_test_command

    if PHASE == 'PRE':
        print('****** Start Pre-train Phase ******')
        pre_command = base_command + ' --phase=pre' + ' --pretrain_iterations=' + str(PRE_TRA_ITER_MAX)
        os.system(pre_command)
    elif PHASE == 'META':
        print('****** Start Meta-train Phase ******')
        meta_train_command = base_command + ' --phase=meta' + ' --pretrain_iterations=' + str(PRE_ITER)
        os.system(meta_train_command)

        print('****** Start Meta-test Phase ******')
        # Meta-test the models for iterations 0, SAVE_STEP, 2*SAVE_STEP, ...
        # below MAX_ITER (SAVE_STEP is also passed to main.py as --meta_save_step).
        for idx in range(0, MAX_ITER, SAVE_STEP):
            print('[*] Runing meta-test, load model for ' + str(idx) + ' iterations')
            test_command = process_test_command(idx, base_command)
            os.system(test_command)
    elif PHASE == 'META_LOAD':
        print('****** Start Meta-train Phase with Downloaded Weights ******')
        meta_train_command = base_command + ' --phase=meta' + ' --pretrain_iterations=' + str(PRE_ITER) + ' --load_saved_weights=True'
        os.system(meta_train_command)
    elif PHASE == 'TEST_LOAD':
        print('****** Start Meta-test Phase with Downloaded Weights ******')
        test_command = process_test_command(0, base_command) + ' --load_saved_weights=True'
        os.system(test_command)
    else:
        # Previously an unknown phase silently did nothing.
        raise ValueError("Unknown PHASE '" + str(PHASE) + "'; expected PRE, META, META_LOAD or TEST_LOAD")

if __name__ == '__main__':
    # Launch experiments only when run as a script, never on import, and fail
    # with a usage message instead of an IndexError when the phase is missing.
    if len(sys.argv) < 2:
        sys.exit('Usage: python run_experiment.py {PRE,META,META_LOAD,TEST_LOAD}')
    THE_INPUT_PHASE = sys.argv[1]
    run_experiment(PHASE=THE_INPUT_PHASE)

# --------------------------------------------------------------------------------
# /README.md:
# --------------------------------------------------------------------------------
# 1 | # Meta-Transfer Learning for Few-Shot Learning 2 | [![LICENSE](https://img.shields.io/badge/license-MIT-green?style=flat-square)](https://github.com/y2l/meta-transfer-learning-tensorflow/blob/master/LICENSE) 3 |
[![Python](https://img.shields.io/badge/python-2.7%20%7C%203.5-blue.svg?style=flat-square&logo=python&color=3776AB)](https://www.python.org/) 4 | [![TensorFlow](https://img.shields.io/badge/tensorflow-1.3.0-orange.svg?style=flat-square&logo=tensorflow&color=FF6F00)](https://github.com/y2l/meta-transfer-learning/tree/master/tensorflow) 5 | [![PyTorch](https://img.shields.io/badge/pytorch-0.4.0-%237732a8?style=flat-square&logo=PyTorch&color=EE4C2C)](https://pytorch.org/) 6 | [![Citations](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/yaoyao-liu/google-scholar/google-scholar-stats/gs_data_shieldsio_mtl.json&logo=Google%20Scholar&color=5087ec&style=flat-square&label=citations)](https://scholar.google.com/citations?view_op=view_citation&hl=en&user=Uf9GqRsAAAAJ&citation_for_view=Uf9GqRsAAAAJ:bEWYMUwI8FkC) 7 | 10 | 11 | This repository contains the TensorFlow and PyTorch implementation for the [CVPR 2019](http://cvpr2019.thecvf.com/) Paper ["Meta-Transfer Learning for Few-Shot Learning"](http://openaccess.thecvf.com/content_CVPR_2019/papers/Sun_Meta-Transfer_Learning_for_Few-Shot_Learning_CVPR_2019_paper.pdf) by [Qianru Sun](https://qianrusun1015.github.io),\* [Yaoyao Liu](https://people.mpi-inf.mpg.de/~yaliu/),\* [Tat-Seng Chua](https://www.chuatatseng.com/), and [Bernt Schiele](https://www.mpi-inf.mpg.de/departments/computer-vision-and-multimodal-computing/people/bernt-schiele/) (\*=equal contribution). 12 | 13 | If you have any questions on this repository or the related paper, feel free to [create an issue](https://github.com/yaoyao-liu/meta-transfer-learning/issues/new) or [send me an email](mailto:yaoyao.liu+github@mpi-inf.mpg.de). 
14 | 15 | #### Summary 16 | 17 | * [Introduction](#introduction) 18 | * [Getting Started](#getting-started) 19 | * [Datasets](#datasets) 20 | * [Performance](#performance) 21 | * [Citation](#citation) 22 | * [Acknowledgements](#acknowledgements) 23 | 24 | 25 | ## Introduction 26 | 27 | Meta-learning has been proposed as a framework to address the challenging few-shot learning setting. The key idea is to leverage a large number of similar few-shot tasks in order to learn how to adapt a base-learner to a new task for which only a few labeled samples are available. As deep neural networks (DNNs) tend to overfit using a few samples only, meta-learning typically uses shallow neural networks (SNNs), thus limiting its effectiveness. In this paper we propose a novel few-shot learning method called ***meta-transfer learning (MTL)*** which learns to adapt a ***deep NN*** for ***few shot learning tasks***. Specifically, meta refers to training multiple tasks, and transfer is achieved by learning scaling and shifting functions of DNN weights for each task. We conduct experiments using (5-class, 1-shot) and (5-class, 5-shot) recognition tasks on two challenging few-shot learning benchmarks: 𝑚𝑖𝑛𝑖ImageNet and Fewshot-CIFAR100. 28 | 29 |

30 | 31 |

32 | 33 | > Figure: Meta-Transfer Learning. (a) Parameter-level fine-tuning (FT) is a conventional meta-training operation, e.g. in MAML. Its update works for all neuron parameters, 𝑊 and 𝑏. (b) Our neuron-level scaling and shifting (SS) operations in meta-transfer learning. They reduce the number of learning parameters and avoid overfitting problems. In addition, they keep large-scale trained parameters (in yellow) frozen, preventing “catastrophic forgetting”. 34 | 35 | ## Getting Started 36 | 37 | Please see `README.md` files in the corresponding folders: 38 | 39 | * TensorFlow: [\[Document\]](https://github.com/y2l/meta-transfer-learning/blob/master/tensorflow/README.md) 40 | * PyTorch: [\[Document\]](https://github.com/y2l/meta-transfer-learning/blob/master/pytorch/README.md) 41 | 42 | ## Datasets 43 | 44 | Directly download processed images: [\[Download Page\]](https://mtl.yyliu.net/download/) 45 | 46 | ### 𝒎𝒊𝒏𝒊ImageNet 47 | 48 | The 𝑚𝑖𝑛𝑖ImageNet dataset was proposed by [Vinyals et al.](http://papers.nips.cc/paper/6385-matching-networks-for-one-shot-learning.pdf) for few-shot learning evaluation. Its complexity is high due to the use of ImageNet images but requires fewer resources and infrastructure than running on the full [ImageNet dataset](https://arxiv.org/pdf/1409.0575.pdf). In total, there are 100 classes with 600 samples of 84×84 color images per class. These 100 classes are divided into 64, 16, and 20 classes respectively for sampling tasks for meta-training, meta-validation, and meta-test. To generate this dataset from ImageNet, you may use the repository [𝑚𝑖𝑛𝑖ImageNet tools](https://github.com/y2l/mini-imagenet-tools). 49 | 50 | ### Fewshot-CIFAR100 51 | 52 | Fewshot-CIFAR100 (FC100) is based on the popular object classification dataset CIFAR100. The splits were 53 | proposed by [TADAM](https://arxiv.org/pdf/1805.10123.pdf). 
It offers a more challenging scenario with lower image resolution and more challenging meta-training/test splits that are separated according to object super-classes. It contains 100 object classes and each class has 600 samples of 32 × 32 color images. The 100 classes belong to 20 super-classes. Meta-training data are from 60 classes belonging to 12 super-classes. The meta-validation and meta-test sets each contain 20 classes belonging to 4 super-classes. 54 | 55 | ### 𝒕𝒊𝒆𝒓𝒆𝒅ImageNet 56 | 57 | The [𝑡𝑖𝑒𝑟𝑒𝑑ImageNet](https://arxiv.org/pdf/1803.00676.pdf) dataset is a larger subset of ILSVRC-12 with 608 classes (779,165 images) grouped into 34 higher-level nodes in the ImageNet human-curated hierarchy. To generate this dataset from ImageNet, you may use the repository [𝑡𝑖𝑒𝑟𝑒𝑑ImageNet tools](https://github.com/y2l/tiered-imagenet-tools). 58 | 59 | 60 | ## Performance 61 | 62 | | (%) | 𝑚𝑖𝑛𝑖 1-shot | 𝑚𝑖𝑛𝑖 5-shot | FC100 1-shot | FC100 5-shot | 63 | | ---------------------- | ------------ | ------------ | ------------ | ------------ | 64 | | `MTL Paper` | `60.2 ± 1.8` | `74.3 ± 0.9` | `43.6 ± 1.8` | `55.4 ± 0.9` | 65 | | `TensorFlow` | `60.8 ± 1.8` | `74.3 ± 0.9` | `44.3 ± 1.8` | `56.8 ± 1.0` | 66 | * The performance of the PyTorch version is still being verified.
67 | 68 | ## Citation 69 | 70 | Please cite our paper if it is helpful to your work: 71 | 72 | ```bibtex 73 | @inproceedings{SunLCS2019MTL, 74 | author = {Qianru Sun and 75 | Yaoyao Liu and 76 | Tat{-}Seng Chua and 77 | Bernt Schiele}, 78 | title = {Meta-Transfer Learning for Few-Shot Learning}, 79 | booktitle = {{IEEE} Conference on Computer Vision and Pattern Recognition, {CVPR} 80 | 2019, Long Beach, CA, USA, June 16-20, 2019}, 81 | pages = {403--412}, 82 | publisher = {Computer Vision Foundation / {IEEE}}, 83 | year = {2019} 84 | } 85 | ``` 86 | 87 | ## Acknowledgements 88 | 89 | Our implementations use the source code from the following repositories and users: 90 | 91 | * [Model-Agnostic Meta-Learning](https://github.com/cbfinn/maml) 92 | 93 | * [Optimization as a Model for Few-Shot Learning](https://github.com/gitabcworld/FewShotLearning) 94 | 95 | * [Learning Embedding Adaptation for Few-Shot Learning](https://github.com/Sha-Lab/FEAT) 96 | 97 | * [dragen1860/MAML-Pytorch](https://github.com/dragen1860/MAML-Pytorch) 98 | 99 | * [@icoz69](https://github.com/icoz69) 100 | 101 | * [@CookieLau](https://github.com/CookieLau) 102 | -------------------------------------------------------------------------------- /pytorch/models/resnet_mtl.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Yaoyao Liu 3 | ## Modified from: https://github.com/Sha-Lab/FEAT 4 | ## Tianjin University 5 | ## liuyaoyao@tju.edu.cn 6 | ## Copyright (c) 2019 7 | ## 8 | ## This source code is licensed under the MIT-style license found in the 9 | ## LICENSE file in the root directory of this source tree 10 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 11 | """ ResNet with MTL. 
""" 12 | import torch.nn as nn 13 | from models.conv2d_mtl import Conv2dMtl 14 | 15 | def conv3x3(in_planes, out_planes, stride=1): 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 17 | padding=1, bias=False) 18 | 19 | class BasicBlock(nn.Module): 20 | expansion = 1 21 | 22 | def __init__(self, inplanes, planes, stride=1, downsample=None): 23 | super(BasicBlock, self).__init__() 24 | self.conv1 = conv3x3(inplanes, planes, stride) 25 | self.bn1 = nn.BatchNorm2d(planes) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.conv2 = conv3x3(planes, planes) 28 | self.bn2 = nn.BatchNorm2d(planes) 29 | self.downsample = downsample 30 | self.stride = stride 31 | 32 | def forward(self, x): 33 | residual = x 34 | 35 | out = self.conv1(x) 36 | out = self.bn1(out) 37 | out = self.relu(out) 38 | 39 | out = self.conv2(out) 40 | out = self.bn2(out) 41 | 42 | if self.downsample is not None: 43 | residual = self.downsample(x) 44 | 45 | out += residual 46 | out = self.relu(out) 47 | 48 | return out 49 | 50 | class Bottleneck(nn.Module): 51 | expansion = 4 52 | 53 | def __init__(self, inplanes, planes, stride=1, downsample=None): 54 | super(Bottleneck, self).__init__() 55 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 56 | self.bn1 = nn.BatchNorm2d(planes) 57 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 58 | padding=1, bias=False) 59 | self.bn2 = nn.BatchNorm2d(planes) 60 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 61 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 62 | self.relu = nn.ReLU(inplace=True) 63 | self.downsample = downsample 64 | self.stride = stride 65 | 66 | def forward(self, x): 67 | residual = x 68 | 69 | out = self.conv1(x) 70 | out = self.bn1(out) 71 | out = self.relu(out) 72 | 73 | out = self.conv2(out) 74 | out = self.bn2(out) 75 | out = self.relu(out) 76 | 77 | out = self.conv3(out) 78 | out = self.bn3(out) 79 | 80 | if self.downsample is not None: 
81 | residual = self.downsample(x) 82 | 83 | out += residual 84 | out = self.relu(out) 85 | 86 | return out 87 | 88 | def conv3x3mtl(in_planes, out_planes, stride=1): 89 | return Conv2dMtl(in_planes, out_planes, kernel_size=3, stride=stride, 90 | padding=1, bias=False) 91 | 92 | 93 | class BasicBlockMtl(nn.Module): 94 | expansion = 1 95 | 96 | def __init__(self, inplanes, planes, stride=1, downsample=None): 97 | super(BasicBlockMtl, self).__init__() 98 | self.conv1 = conv3x3mtl(inplanes, planes, stride) 99 | self.bn1 = nn.BatchNorm2d(planes) 100 | self.relu = nn.ReLU(inplace=True) 101 | self.conv2 = conv3x3mtl(planes, planes) 102 | self.bn2 = nn.BatchNorm2d(planes) 103 | self.downsample = downsample 104 | self.stride = stride 105 | 106 | def forward(self, x): 107 | residual = x 108 | 109 | out = self.conv1(x) 110 | out = self.bn1(out) 111 | out = self.relu(out) 112 | 113 | out = self.conv2(out) 114 | out = self.bn2(out) 115 | 116 | if self.downsample is not None: 117 | residual = self.downsample(x) 118 | 119 | out += residual 120 | out = self.relu(out) 121 | 122 | return out 123 | 124 | 125 | class BottleneckMtl(nn.Module): 126 | expansion = 4 127 | 128 | def __init__(self, inplanes, planes, stride=1, downsample=None): 129 | super(BottleneckMtl, self).__init__() 130 | self.conv1 = Conv2dMtl(inplanes, planes, kernel_size=1, bias=False) 131 | self.bn1 = nn.BatchNorm2d(planes) 132 | self.conv2 = Conv2dMtl(planes, planes, kernel_size=3, stride=stride, 133 | padding=1, bias=False) 134 | self.bn2 = nn.BatchNorm2d(planes) 135 | self.conv3 = Conv2dMtl(planes, planes * self.expansion, kernel_size=1, bias=False) 136 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 137 | self.relu = nn.ReLU(inplace=True) 138 | self.downsample = downsample 139 | self.stride = stride 140 | 141 | def forward(self, x): 142 | residual = x 143 | 144 | out = self.conv1(x) 145 | out = self.bn1(out) 146 | out = self.relu(out) 147 | 148 | out = self.conv2(out) 149 | out = self.bn2(out) 150 | out = 
self.relu(out) 151 | 152 | out = self.conv3(out) 153 | out = self.bn3(out) 154 | 155 | if self.downsample is not None: 156 | residual = self.downsample(x) 157 | 158 | out += residual 159 | out = self.relu(out) 160 | 161 | return out 162 | 163 | class ResNetMtl(nn.Module): 164 | 165 | def __init__(self, layers=[4, 4, 4], mtl=True): 166 | super(ResNetMtl, self).__init__() 167 | if mtl: 168 | self.Conv2d = Conv2dMtl 169 | block = BasicBlockMtl 170 | else: 171 | self.Conv2d = nn.Conv2d 172 | block = BasicBlock 173 | cfg = [160, 320, 640] 174 | self.inplanes = iChannels = int(cfg[0]/2) 175 | self.conv1 = self.Conv2d(3, iChannels, kernel_size=3, stride=1, padding=1) 176 | self.bn1 = nn.BatchNorm2d(iChannels) 177 | self.relu = nn.ReLU(inplace=True) 178 | self.layer1 = self._make_layer(block, cfg[0], layers[0], stride=2) 179 | self.layer2 = self._make_layer(block, cfg[1], layers[1], stride=2) 180 | self.layer3 = self._make_layer(block, cfg[2], layers[2], stride=2) 181 | self.avgpool = nn.AvgPool2d(10, stride=1) 182 | 183 | for m in self.modules(): 184 | if isinstance(m, self.Conv2d): 185 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 186 | elif isinstance(m, nn.BatchNorm2d): 187 | nn.init.constant_(m.weight, 1) 188 | nn.init.constant_(m.bias, 0) 189 | 190 | def _make_layer(self, block, planes, blocks, stride=1): 191 | downsample = None 192 | if stride != 1 or self.inplanes != planes * block.expansion: 193 | downsample = nn.Sequential( 194 | self.Conv2d(self.inplanes, planes * block.expansion, 195 | kernel_size=1, stride=stride, bias=False), 196 | nn.BatchNorm2d(planes * block.expansion), 197 | ) 198 | 199 | layers = [] 200 | layers.append(block(self.inplanes, planes, stride, downsample)) 201 | self.inplanes = planes * block.expansion 202 | for i in range(1, blocks): 203 | layers.append(block(self.inplanes, planes)) 204 | return nn.Sequential(*layers) 205 | 206 | def forward(self, x): 207 | x = self.conv1(x) 208 | x = self.bn1(x) 209 | x = 
self.relu(x) 210 | x = self.layer1(x) 211 | x = self.layer2(x) 212 | x = self.layer3(x) 213 | 214 | x = self.avgpool(x) 215 | x = x.view(x.size(0), -1) 216 | 217 | return x 218 | 219 | -------------------------------------------------------------------------------- /tensorflow/main.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Yaoyao Liu 3 | ## Tianjin University 4 | ## liuyaoyao@tju.edu.cn 5 | ## Copyright (c) 2019 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | import os 12 | from tensorflow.python.platform import flags 13 | from trainer.meta import MetaTrainer 14 | from trainer.pre import PreTrainer 15 | 16 | FLAGS = flags.FLAGS 17 | 18 | ### Basic options 19 | flags.DEFINE_integer('img_size', 84, 'image size') 20 | flags.DEFINE_integer('device_id', 0, 'GPU device ID to run the job.') 21 | flags.DEFINE_float('gpu_rate', 0.9, 'the parameter for the full_gpu_memory_mode') 22 | flags.DEFINE_string('phase', 'meta', 'pre or meta') 23 | flags.DEFINE_string('exp_log_label', 'experiment_results', 'directory for summaries and checkpoints') 24 | flags.DEFINE_string('logdir_base', './logs/', 'directory for logs') 25 | flags.DEFINE_bool('full_gpu_memory_mode', False, 'in this mode, the code occupies GPU memory in advance') 26 | flags.DEFINE_string('backbone_arch', 'resnet12', 'network backbone') 27 | 28 | ### Pre-train phase options 29 | flags.DEFINE_integer('pre_lr_dropstep', 5000, 'the step number to drop pre_lr') 30 | flags.DEFINE_integer('pretrain_class_num', 64, 'number of classes used in the pre-train phase') 31 | flags.DEFINE_integer('pretrain_batch_size', 64, 'batch_size for the pre-train phase') 32 | flags.DEFINE_integer('pretrain_iterations', 
30000, 'number of pretraining iterations.') 33 | flags.DEFINE_integer('pre_sum_step', 10, 'the step number to summary during pretraining') 34 | flags.DEFINE_integer('pre_save_step', 1000, 'the step number to save the pretrain model') 35 | flags.DEFINE_integer('pre_print_step', 1000, 'the step number to print the pretrain results') 36 | flags.DEFINE_float('pre_lr', 0.001, 'the pretrain learning rate') 37 | flags.DEFINE_float('min_pre_lr', 0.0001, 'the pretrain learning rate min') 38 | flags.DEFINE_float('pretrain_dropout_keep', 0.9, 'the dropout keep parameter in the pre-train phase') 39 | flags.DEFINE_string('pretrain_folders', './data/mini-imagenet/train', 'directory for pre-train data') 40 | flags.DEFINE_string('pretrain_label', 'mini_normal', 'additional label for the pre-train log folder') 41 | flags.DEFINE_bool('pre_lr_stop', False, 'whether stop decrease the pre_lr when it is low') 42 | 43 | ### Meta phase options 44 | flags.DEFINE_integer('way_num', 5, 'number of classes (e.g. 5-way classification)') 45 | flags.DEFINE_integer('shot_num', 1, 'number of examples per class (K for K-shot learning)') 46 | flags.DEFINE_integer('metatrain_epite_sample_num', 15, 'number of meta train episode-test samples') 47 | flags.DEFINE_integer('metatest_epite_sample_num', 0, 'number of meta test episode-test samples, 0 means metatest_epite_sample_num=shot_num') 48 | flags.DEFINE_integer('meta_sum_step', 10, 'the step number to summary during meta-training') 49 | flags.DEFINE_integer('meta_save_step', 500, 'the step number to save the model') 50 | flags.DEFINE_integer('meta_intrain_val_sample', 600, 'the number of samples used for val during meta-train') 51 | flags.DEFINE_integer('meta_print_step', 100, 'the step number to print the meta-train results') 52 | flags.DEFINE_integer('meta_val_print_step', 100, 'the step number to print the meta-val results during meta-training') 53 | flags.DEFINE_integer('metatrain_iterations', 15000, 'number of meta-train iterations.') 54 | 
flags.DEFINE_integer('meta_batch_size', 2, 'number of tasks sampled per meta-update') 55 | flags.DEFINE_integer('train_base_epoch_num', 20, 'number of inner gradient updates during training.') 56 | flags.DEFINE_integer('test_base_epoch_num', 100, 'number of inner gradient updates during test.') 57 | flags.DEFINE_integer('lr_drop_step', 5000, 'the step number to drop meta_lr') 58 | flags.DEFINE_integer('test_iter', 1000, 'iteration to load model') 59 | flags.DEFINE_float('meta_lr', 0.001, 'the meta learning rate of the generator') 60 | flags.DEFINE_float('lr_drop_rate', 0.5, 'the step number to drop meta_lr') 61 | flags.DEFINE_float('min_meta_lr', 0.0001, 'the min meta learning rate of the generator') 62 | flags.DEFINE_float('base_lr', 1e-3, 'step size alpha for inner gradient update.') 63 | flags.DEFINE_string('metatrain_dir', './data/mini-imagenet/train', 'directory for meta-train set') 64 | flags.DEFINE_string('metaval_dir', './data/mini-imagenet/val', 'directory for meta-val set') 65 | flags.DEFINE_string('metatest_dir', './data/mini-imagenet/test', 'directory for meta-test set') 66 | flags.DEFINE_string('activation', 'leaky_relu', 'leaky_relu, relu, or None') 67 | flags.DEFINE_string('norm', 'batch_norm', 'batch_norm, layer_norm, or None') 68 | flags.DEFINE_bool('metatrain', True, 'is this the meta-train phase') 69 | flags.DEFINE_bool('base_augmentation', True, 'whether do data augmentation during base learning') 70 | flags.DEFINE_bool('redo_init', True, 're-build the initialization weights') 71 | flags.DEFINE_bool('load_saved_weights', False, 'load the downloaded weights') 72 | 73 | # Generate experiment key words string 74 | exp_string = 'arch(' + FLAGS.backbone_arch + ')' 75 | exp_string += '.cls(' + str(FLAGS.way_num) + ')' 76 | exp_string += '.shot(' + str(FLAGS.shot_num) + ')' 77 | exp_string += '.meta_batch(' + str(FLAGS.meta_batch_size) + ')' 78 | exp_string += '.base_epoch(' + str(FLAGS.train_base_epoch_num) + ')' 79 | exp_string += '.meta_lr(' + 
str(FLAGS.meta_lr) + ')' 80 | exp_string += '.base_lr(' + str(FLAGS.base_lr) + ')' 81 | exp_string += '.pre_iterations(' + str(FLAGS.pretrain_iterations) + ')' 82 | exp_string += '.pre_dropout(' + str(FLAGS.pretrain_dropout_keep) + ')' 83 | exp_string += '.acti(' + str(FLAGS.activation) + ')' 84 | exp_string += '.lr_drop_step(' + str(FLAGS.lr_drop_step) + ')' 85 | exp_string += '.lr_drop_rate(' + str(FLAGS.lr_drop_rate) + ')' 86 | exp_string += '.pre_label(' + str(FLAGS.pretrain_label) + ')' 87 | 88 | if FLAGS.norm == 'batch_norm': 89 | exp_string += '.norm(batch)' 90 | elif FLAGS.norm == 'layer_norm': 91 | exp_string += '.norm(layer)' 92 | elif FLAGS.norm == 'None': 93 | exp_string += '.norm(none)' 94 | else: 95 | raise Exception('Norm setting is not recognized') 96 | 97 | FLAGS.exp_string = exp_string 98 | print('Parameters: ' + exp_string) 99 | 100 | # Generate pre-train key words string 101 | pre_save_str = 'arch(' + FLAGS.backbone_arch + ')' 102 | pre_save_str += '.pre_lr(' + str(FLAGS.pre_lr) + ')' 103 | pre_save_str += '.pre_lrdrop(' + str(FLAGS.pre_lr_dropstep) + ')' 104 | pre_save_str += '.pre_class(' + str(FLAGS.pretrain_class_num) + ')' 105 | pre_save_str += '.pre_batch(' + str(FLAGS.pretrain_batch_size) + ')' 106 | pre_save_str += '.pre_dropout(' + str(FLAGS.pretrain_dropout_keep) + ')' 107 | if FLAGS.pre_lr_stop: 108 | pre_save_str += '.pre_lr_stop(True)' 109 | else: 110 | pre_save_str += '.pre_lr_stop(False)' 111 | pre_save_str += '.pre_label(' + FLAGS.pretrain_label + ')' 112 | FLAGS.pre_string = pre_save_str 113 | 114 | # Generate log folders 115 | FLAGS.logdir = FLAGS.logdir_base + FLAGS.exp_log_label 116 | FLAGS.pretrain_dir = FLAGS.logdir_base + 'pretrain_weights' 117 | 118 | if not os.path.exists(FLAGS.logdir_base): 119 | os.mkdir(FLAGS.logdir_base) 120 | if not os.path.exists(FLAGS.logdir): 121 | os.mkdir(FLAGS.logdir) 122 | if not os.path.exists(FLAGS.pretrain_dir): 123 | os.mkdir(FLAGS.pretrain_dir) 124 | 125 | # If FLAGS.redo_init is true, 
delete the previous intialization weights. 126 | if FLAGS.redo_init: 127 | if not os.path.exists('./logs/init_weights'): 128 | os.system('rm -r ./logs/init_weights') 129 | print('Init weights have been deleted') 130 | else: 131 | print('No init weights') 132 | 133 | def main(): 134 | # Set GPU device id 135 | print('Using GPU ' + str(FLAGS.device_id)) 136 | os.environ['CUDA_VISIBLE_DEVICES'] = str(FLAGS.device_id) 137 | # Select pre-train phase or meta-learning phase 138 | if FLAGS.phase=='pre': 139 | trainer = PreTrainer() 140 | elif FLAGS.phase=='meta': 141 | trainer = MetaTrainer() 142 | else: 143 | raise Exception('Please set correct phase') 144 | 145 | if __name__ == "__main__": 146 | main() 147 | -------------------------------------------------------------------------------- /tensorflow/data_generator/meta_data_generator.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Yaoyao Liu 3 | ## Modified from: https://github.com/cbfinn/maml 4 | ## Tianjin University 5 | ## liuyaoyao@tju.edu.cn 6 | ## Copyright (c) 2019 7 | ## 8 | ## This source code is licensed under the MIT-style license found in the 9 | ## LICENSE file in the root directory of this source tree 10 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 11 | 12 | """ Data generator for meta-learning. 
""" 13 | import numpy as np 14 | import os 15 | import random 16 | import tensorflow as tf 17 | 18 | from tqdm import trange 19 | from tensorflow.python.platform import flags 20 | from utils.misc import get_images, process_batch, process_batch_augmentation 21 | 22 | FLAGS = flags.FLAGS 23 | 24 | class MetaDataGenerator(object): 25 | """The class to generate data lists and episodes for meta-train and meta-test.""" 26 | def __init__(self): 27 | # Set the base folder to save the data lists 28 | filename_dir = FLAGS.logdir_base + 'processed_data/' 29 | if not os.path.exists(filename_dir): 30 | os.mkdir(filename_dir) 31 | 32 | # Set the detailed folder name for saving the data lists 33 | self.this_setting_filename_dir = filename_dir + 'shot(' + str(FLAGS.shot_num) + ').way(' + str(FLAGS.way_num) \ 34 | + ').metatr_epite(' + str(FLAGS.metatrain_epite_sample_num) + ').metate_epite(' + str(FLAGS.metatest_epite_sample_num) + ')/' 35 | if not os.path.exists(self.this_setting_filename_dir): 36 | os.mkdir(self.this_setting_filename_dir) 37 | 38 | def generate_data(self, data_type='train'): 39 | """The function to generate the data lists. 40 | Arg: 41 | data_type: the phase for meta-learning. 
42 | """ 43 | if data_type=='train': 44 | metatrain_folder = FLAGS.metatrain_dir 45 | folders = [os.path.join(metatrain_folder, label) \ 46 | for label in os.listdir(metatrain_folder) \ 47 | if os.path.isdir(os.path.join(metatrain_folder, label)) \ 48 | ] 49 | num_total_batches = FLAGS.metatrain_iterations * FLAGS.meta_batch_size + 10 50 | num_samples_per_class = FLAGS.shot_num + FLAGS.metatrain_epite_sample_num 51 | 52 | elif data_type=='test': 53 | metatest_folder = FLAGS.metatest_dir 54 | folders = [os.path.join(metatest_folder, label) \ 55 | for label in os.listdir(metatest_folder) \ 56 | if os.path.isdir(os.path.join(metatest_folder, label)) \ 57 | ] 58 | num_total_batches = 600 59 | if FLAGS.metatest_epite_sample_num==0: 60 | num_samples_per_class = FLAGS.shot_num*2 61 | else: 62 | num_samples_per_class = FLAGS.shot_num + FLAGS.metatest_epite_sample_num 63 | elif data_type=='val': 64 | metaval_folder = FLAGS.metaval_dir 65 | folders = [os.path.join(metaval_folder, label) \ 66 | for label in os.listdir(metaval_folder) \ 67 | if os.path.isdir(os.path.join(metaval_folder, label)) \ 68 | ] 69 | num_total_batches = 600 70 | if FLAGS.metatest_epite_sample_num==0: 71 | num_samples_per_class = FLAGS.shot_num*2 72 | else: 73 | num_samples_per_class = FLAGS.shot_num + FLAGS.metatest_epite_sample_num 74 | else: 75 | raise Exception('Please check data list type') 76 | 77 | task_num = FLAGS.way_num * num_samples_per_class 78 | epitr_sample_num = FLAGS.shot_num 79 | 80 | if not os.path.exists(self.this_setting_filename_dir+'/' + data_type + '_data.npy'): 81 | print('Generating ' + data_type + ' data') 82 | data_list = [] 83 | for epi_idx in trange(num_total_batches): 84 | sampled_character_folders = random.sample(folders, FLAGS.way_num) 85 | random.shuffle(sampled_character_folders) 86 | labels_and_images = get_images(sampled_character_folders, \ 87 | range(FLAGS.way_num), nb_samples=num_samples_per_class, shuffle=False) 88 | labels = [li[0] for li in labels_and_images] 89 
| filenames = [li[1] for li in labels_and_images] 90 | this_task_tr_filenames = [] 91 | this_task_tr_labels = [] 92 | this_task_te_filenames = [] 93 | this_task_te_labels = [] 94 | for class_idx in range(FLAGS.way_num): 95 | this_class_filenames = filenames[class_idx*num_samples_per_class:(class_idx+1)*num_samples_per_class] 96 | this_class_label = labels[class_idx*num_samples_per_class:(class_idx+1)*num_samples_per_class] 97 | this_task_tr_filenames += this_class_filenames[0:epitr_sample_num] 98 | this_task_tr_labels += this_class_label[0:epitr_sample_num] 99 | this_task_te_filenames += this_class_filenames[epitr_sample_num:] 100 | this_task_te_labels += this_class_label[epitr_sample_num:] 101 | 102 | this_batch_data = {'filenamea': this_task_tr_filenames, 'filenameb': this_task_te_filenames, 'labela': this_task_tr_labels, \ 103 | 'labelb': this_task_te_labels} 104 | data_list.append(this_batch_data) 105 | 106 | np.save(self.this_setting_filename_dir+'/' + data_type + '_data.npy', data_list) 107 | print('The ' + data_type + ' data are saved') 108 | else: 109 | print('The ' + data_type + ' data have already been created') 110 | 111 | def load_data(self, data_type='test'): 112 | """The function to load the data lists. 113 | Arg: 114 | data_type: the phase for meta-learning. 115 | """ 116 | data_list = np.load(self.this_setting_filename_dir+'/' + data_type + '_data.npy', allow_pickle=True, encoding="latin1") 117 | if data_type=='train': 118 | self.train_data = data_list 119 | elif data_type=='test': 120 | self.test_data = data_list 121 | elif data_type=='val': 122 | self.val_data = data_list 123 | else: 124 | print('[Error] Please check data list type') 125 | 126 | def load_episode(self, index, data_type='train'): 127 | """The function to load the episodes. 128 | Args: 129 | index: the index for the episodes. 130 | data_type: the phase for meta-learning. 
131 | """ 132 | if data_type=='train': 133 | data_list = self.train_data 134 | epite_sample_num = FLAGS.metatrain_epite_sample_num 135 | elif data_type=='test': 136 | data_list = self.test_data 137 | if FLAGS.metatest_epite_sample_num==0: 138 | epite_sample_num = FLAGS.shot_num 139 | else: 140 | epite_sample_num = FLAGS.metatest_episode_test_sample 141 | elif data_type=='val': 142 | data_list = self.val_data 143 | if FLAGS.metatest_epite_sample_num==0: 144 | epite_sample_num = FLAGS.shot_num 145 | else: 146 | epite_sample_num = FLAGS.metatest_episode_test_sample 147 | else: 148 | raise Exception('Please check data list type') 149 | 150 | dim_input = FLAGS.img_size * FLAGS.img_size * 3 151 | epitr_sample_num = FLAGS.shot_num 152 | 153 | this_episode = data_list[index] 154 | this_task_tr_filenames = this_episode['filenamea'] 155 | this_task_te_filenames = this_episode['filenameb'] 156 | this_task_tr_labels = this_episode['labela'] 157 | this_task_te_labels = this_episode['labelb'] 158 | 159 | if FLAGS.metatrain is False and FLAGS.base_augmentation: 160 | this_inputa, this_labela = process_batch_augmentation(this_task_tr_filenames, \ 161 | this_task_tr_labels, dim_input, epitr_sample_num) 162 | this_inputb, this_labelb = process_batch(this_task_te_filenames, \ 163 | this_task_te_labels, dim_input, epite_sample_num) 164 | else: 165 | this_inputa, this_labela = process_batch(this_task_tr_filenames, \ 166 | this_task_tr_labels, dim_input, epitr_sample_num) 167 | this_inputb, this_labelb = process_batch(this_task_te_filenames, \ 168 | this_task_te_labels, dim_input, epite_sample_num) 169 | return this_inputa, this_labela, this_inputb, this_labelb 170 | -------------------------------------------------------------------------------- /pytorch/trainer/pre.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Yaoyao Liu 3 | ## Tianjin University 4 
| ## liuyaoyao@tju.edu.cn 5 | ## Copyright (c) 2019 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | """ Trainer for pretrain phase. """ 11 | import os.path as osp 12 | import os 13 | import tqdm 14 | import torch 15 | import torch.nn.functional as F 16 | from torch.utils.data import DataLoader 17 | from dataloader.samplers import CategoriesSampler 18 | from models.mtl import MtlLearner 19 | from utils.misc import Averager, Timer, count_acc, ensure_path 20 | from tensorboardX import SummaryWriter 21 | from dataloader.dataset_loader import DatasetLoader as Dataset 22 | 23 | class PreTrainer(object): 24 | """The class that contains the code for the pretrain phase.""" 25 | def __init__(self, args): 26 | # Set the folder to save the records and checkpoints 27 | log_base_dir = './logs/' 28 | if not osp.exists(log_base_dir): 29 | os.mkdir(log_base_dir) 30 | pre_base_dir = osp.join(log_base_dir, 'pre') 31 | if not osp.exists(pre_base_dir): 32 | os.mkdir(pre_base_dir) 33 | save_path1 = '_'.join([args.dataset, args.model_type]) 34 | save_path2 = 'batchsize' + str(args.pre_batch_size) + '_lr' + str(args.pre_lr) + '_gamma' + str(args.pre_gamma) + '_step' + \ 35 | str(args.pre_step_size) + '_maxepoch' + str(args.pre_max_epoch) 36 | args.save_path = pre_base_dir + '/' + save_path1 + '_' + save_path2 37 | ensure_path(args.save_path) 38 | 39 | # Set args to be shareable in the class 40 | self.args = args 41 | 42 | # Load pretrain set 43 | self.trainset = Dataset('train', self.args, train_aug=True) 44 | self.train_loader = DataLoader(dataset=self.trainset, batch_size=args.pre_batch_size, shuffle=True, num_workers=8, pin_memory=True) 45 | 46 | # Load meta-val set 47 | self.valset = Dataset('val', self.args) 48 | self.val_sampler = CategoriesSampler(self.valset.label, 600, self.args.way, self.args.shot + 
self.args.val_query) 49 | self.val_loader = DataLoader(dataset=self.valset, batch_sampler=self.val_sampler, num_workers=8, pin_memory=True) 50 | 51 | # Set pretrain class number 52 | num_class_pretrain = self.trainset.num_class 53 | 54 | # Build pretrain model 55 | self.model = MtlLearner(self.args, mode='pre', num_cls=num_class_pretrain) 56 | 57 | # Set optimizer 58 | self.optimizer = torch.optim.SGD([{'params': self.model.encoder.parameters(), 'lr': self.args.pre_lr}, \ 59 | {'params': self.model.pre_fc.parameters(), 'lr': self.args.pre_lr}], \ 60 | momentum=self.args.pre_custom_momentum, nesterov=True, weight_decay=self.args.pre_custom_weight_decay) 61 | # Set learning rate scheduler 62 | self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=self.args.pre_step_size, \ 63 | gamma=self.args.pre_gamma) 64 | 65 | # Set model to GPU 66 | if torch.cuda.is_available(): 67 | torch.backends.cudnn.benchmark = True 68 | self.model = self.model.cuda() 69 | 70 | def save_model(self, name): 71 | """The function to save checkpoints. 
72 | Args: 73 | name: the name for saved checkpoint 74 | """ 75 | torch.save(dict(params=self.model.encoder.state_dict()), osp.join(self.args.save_path, name + '.pth')) 76 | 77 | def train(self): 78 | """The function for the pre-train phase.""" 79 | 80 | # Set the pretrain log 81 | trlog = {} 82 | trlog['args'] = vars(self.args) 83 | trlog['train_loss'] = [] 84 | trlog['val_loss'] = [] 85 | trlog['train_acc'] = [] 86 | trlog['val_acc'] = [] 87 | trlog['max_acc'] = 0.0 88 | trlog['max_acc_epoch'] = 0 89 | 90 | # Set the timer 91 | timer = Timer() 92 | # Set global count to zero 93 | global_count = 0 94 | # Set tensorboardX 95 | writer = SummaryWriter(comment=self.args.save_path) 96 | 97 | # Start pretrain 98 | for epoch in range(1, self.args.pre_max_epoch + 1): 99 | # Update learning rate 100 | self.lr_scheduler.step() 101 | # Set the model to train mode 102 | self.model.train() 103 | self.model.mode = 'pre' 104 | # Set averager classes to record training losses and accuracies 105 | train_loss_averager = Averager() 106 | train_acc_averager = Averager() 107 | 108 | # Using tqdm to read samples from train loader 109 | tqdm_gen = tqdm.tqdm(self.train_loader) 110 | for i, batch in enumerate(tqdm_gen, 1): 111 | # Update global count number 112 | global_count = global_count + 1 113 | if torch.cuda.is_available(): 114 | data, _ = [_.cuda() for _ in batch] 115 | else: 116 | data = batch[0] 117 | label = batch[1] 118 | if torch.cuda.is_available(): 119 | label = label.type(torch.cuda.LongTensor) 120 | else: 121 | label = label.type(torch.LongTensor) 122 | # Output logits for model 123 | logits = self.model(data) 124 | # Calculate train loss 125 | loss = F.cross_entropy(logits, label) 126 | # Calculate train accuracy 127 | acc = count_acc(logits, label) 128 | # Write the tensorboardX records 129 | writer.add_scalar('data/loss', float(loss), global_count) 130 | writer.add_scalar('data/acc', float(acc), global_count) 131 | # Print loss and accuracy for this step 132 | 
tqdm_gen.set_description('Epoch {}, Loss={:.4f} Acc={:.4f}'.format(epoch, loss.item(), acc)) 133 | 134 | # Add loss and accuracy for the averagers 135 | train_loss_averager.add(loss.item()) 136 | train_acc_averager.add(acc) 137 | 138 | # Loss backwards and optimizer updates 139 | self.optimizer.zero_grad() 140 | loss.backward() 141 | self.optimizer.step() 142 | 143 | # Update the averagers 144 | train_loss_averager = train_loss_averager.item() 145 | train_acc_averager = train_acc_averager.item() 146 | 147 | # Start validation for this epoch, set model to eval mode 148 | self.model.eval() 149 | self.model.mode = 'preval' 150 | 151 | # Set averager classes to record validation losses and accuracies 152 | val_loss_averager = Averager() 153 | val_acc_averager = Averager() 154 | 155 | # Generate the labels for test 156 | label = torch.arange(self.args.way).repeat(self.args.val_query) 157 | if torch.cuda.is_available(): 158 | label = label.type(torch.cuda.LongTensor) 159 | else: 160 | label = label.type(torch.LongTensor) 161 | label_shot = torch.arange(self.args.way).repeat(self.args.shot) 162 | if torch.cuda.is_available(): 163 | label_shot = label_shot.type(torch.cuda.LongTensor) 164 | else: 165 | label_shot = label_shot.type(torch.LongTensor) 166 | 167 | # Print previous information 168 | if epoch % 10 == 0: 169 | print('Best Epoch {}, Best Val acc={:.4f}'.format(trlog['max_acc_epoch'], trlog['max_acc'])) 170 | # Run meta-validation 171 | for i, batch in enumerate(self.val_loader, 1): 172 | if torch.cuda.is_available(): 173 | data, _ = [_.cuda() for _ in batch] 174 | else: 175 | data = batch[0] 176 | p = self.args.shot * self.args.way 177 | data_shot, data_query = data[:p], data[p:] 178 | logits = self.model((data_shot, label_shot, data_query)) 179 | loss = F.cross_entropy(logits, label) 180 | acc = count_acc(logits, label) 181 | val_loss_averager.add(loss.item()) 182 | val_acc_averager.add(acc) 183 | 184 | # Update validation averagers 185 | val_loss_averager = 
val_loss_averager.item() 186 | val_acc_averager = val_acc_averager.item() 187 | # Write the tensorboardX records 188 | writer.add_scalar('data/val_loss', float(val_loss_averager), epoch) 189 | writer.add_scalar('data/val_acc', float(val_acc_averager), epoch) 190 | # Print loss and accuracy for this epoch 191 | print('Epoch {}, Val, Loss={:.4f} Acc={:.4f}'.format(epoch, val_loss_averager, val_acc_averager)) 192 | 193 | # Update best saved model 194 | if val_acc_averager > trlog['max_acc']: 195 | trlog['max_acc'] = val_acc_averager 196 | trlog['max_acc_epoch'] = epoch 197 | self.save_model('max_acc') 198 | # Save model every 10 epochs 199 | if epoch % 10 == 0: 200 | self.save_model('epoch'+str(epoch)) 201 | 202 | # Update the logs 203 | trlog['train_loss'].append(train_loss_averager) 204 | trlog['train_acc'].append(train_acc_averager) 205 | trlog['val_loss'].append(val_loss_averager) 206 | trlog['val_acc'].append(val_acc_averager) 207 | 208 | # Save log 209 | torch.save(trlog, osp.join(self.args.save_path, 'trlog')) 210 | 211 | if epoch % 10 == 0: 212 | print('Running Time: {}, Estimated Time: {}'.format(timer.measure(), timer.measure(epoch / self.args.max_epoch))) 213 | writer.close() 214 | -------------------------------------------------------------------------------- /tensorflow/utils/misc.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Yaoyao Liu 3 | ## Modified from: https://github.com/cbfinn/maml 4 | ## Tianjin University 5 | ## liuyaoyao@tju.edu.cn 6 | ## Copyright (c) 2019 7 | ## 8 | ## This source code is licensed under the MIT-style license found in the 9 | ## LICENSE file in the root directory of this source tree 10 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 11 | 12 | """ Additional utility functions. 
""" 13 | import numpy as np 14 | import os 15 | import cv2 16 | import random 17 | import tensorflow as tf 18 | 19 | from matplotlib.pyplot import imread 20 | from tensorflow.contrib.layers.python import layers as tf_layers 21 | from tensorflow.python.platform import flags 22 | 23 | FLAGS = flags.FLAGS 24 | 25 | def get_smallest_k_index(input_, k): 26 | """The function to get the smallest k items' indices. 27 | Args: 28 | input_: the list to be processed. 29 | k: the number of indices to return. 30 | Return: 31 | The index list with k dimensions. 32 | """ 33 | input_copy = np.copy(input_) 34 | k_list = [] 35 | for idx in range(k): 36 | this_index = np.argmin(input_copy) 37 | k_list.append(this_index) 38 | input_copy[this_index]=np.max(input_copy) 39 | return k_list 40 | 41 | def one_hot(inp): 42 | """The function to make the input to one-hot vectors. 43 | Arg: 44 | inp: the input numpy array. 45 | Return: 46 | The reorganized one-shot array. 47 | """ 48 | n_class = inp.max() + 1 49 | n_sample = inp.shape[0] 50 | out = np.zeros((n_sample, n_class)) 51 | for idx in range(n_sample): 52 | out[idx, inp[idx]] = 1 53 | return out 54 | 55 | def one_hot_class(inp, n_class): 56 | """The function to make the input to n-class one-hot vectors. 57 | Args: 58 | inp: the input numpy array. 59 | n_class: the number of classes. 60 | Return: 61 | The reorganized n-class one-shot array. 62 | """ 63 | n_sample = inp.shape[0] 64 | out = np.zeros((n_sample, n_class)) 65 | for idx in range(n_sample): 66 | out[idx, inp[idx]] = 1 67 | return out 68 | 69 | def process_batch(input_filename_list, input_label_list, dim_input, batch_sample_num): 70 | """The function to process a part of an episode. 71 | Args: 72 | input_filename_list: the image files' directory list. 73 | input_label_list: the image files' corressponding label list. 74 | dim_input: the dimension number of the images. 75 | batch_sample_num: the sample number of the inputed images. 
76 | Returns: 77 | img_array: the numpy array of processed images. 78 | label_array: the numpy array of processed labels. 79 | """ 80 | new_path_list = [] 81 | new_label_list = [] 82 | for k in range(batch_sample_num): 83 | class_idxs = list(range(0, FLAGS.way_num)) 84 | random.shuffle(class_idxs) 85 | for class_idx in class_idxs: 86 | true_idx = class_idx*batch_sample_num + k 87 | new_path_list.append(input_filename_list[true_idx]) 88 | new_label_list.append(input_label_list[true_idx]) 89 | 90 | img_list = [] 91 | for filepath in new_path_list: 92 | this_img = imread(filepath) 93 | this_img = np.reshape(this_img, [-1, dim_input]) 94 | this_img = this_img / 255.0 95 | img_list.append(this_img) 96 | 97 | img_array = np.array(img_list).reshape([FLAGS.way_num*batch_sample_num, dim_input]) 98 | label_array = one_hot(np.array(new_label_list)).reshape([FLAGS.way_num*batch_sample_num, -1]) 99 | return img_array, label_array 100 | 101 | def process_batch_augmentation(input_filename_list, input_label_list, dim_input, batch_sample_num): 102 | """The function to process a part of an episode. All the images will be augmented by flipping. 103 | Args: 104 | input_filename_list: the image files' directory list. 105 | input_label_list: the image files' corressponding label list. 106 | dim_input: the dimension number of the images. 107 | batch_sample_num: the sample number of the inputed images. 108 | Returns: 109 | img_array: the numpy array of processed images. 110 | label_array: the numpy array of processed labels. 
111 | """ 112 | new_path_list = [] 113 | new_label_list = [] 114 | for k in range(batch_sample_num): 115 | class_idxs = list(range(0, FLAGS.way_num)) 116 | random.shuffle(class_idxs) 117 | for class_idx in class_idxs: 118 | true_idx = class_idx*batch_sample_num + k 119 | new_path_list.append(input_filename_list[true_idx]) 120 | new_label_list.append(input_label_list[true_idx]) 121 | 122 | img_list = [] 123 | img_list_h = [] 124 | for filepath in new_path_list: 125 | this_img = imread(filepath) 126 | this_img_h = cv2.flip(this_img, 1) 127 | this_img = np.reshape(this_img, [-1, dim_input]) 128 | this_img = this_img / 255.0 129 | img_list.append(this_img) 130 | this_img_h = np.reshape(this_img_h, [-1, dim_input]) 131 | this_img_h = this_img_h / 255.0 132 | img_list_h.append(this_img_h) 133 | 134 | img_list_all = img_list + img_list_h 135 | label_list_all = new_label_list + new_label_list 136 | 137 | img_array = np.array(img_list_all).reshape([FLAGS.way_num*batch_sample_num*2, dim_input]) 138 | label_array = one_hot(np.array(label_list_all)).reshape([FLAGS.way_num*batch_sample_num*2, -1]) 139 | return img_array, label_array 140 | 141 | 142 | def get_images(paths, labels, nb_samples=None, shuffle=True): 143 | """The function to get the image files' directories with given class labels. 144 | Args: 145 | paths: the base path for the images. 146 | labels: the class name labels. 147 | nb_samples: the number of samples. 148 | shuffle: whether shuffle the generated image list. 149 | Return: 150 | The list for the image files' directories. 
151 | """ 152 | if nb_samples is not None: 153 | sampler = lambda x: random.sample(x, nb_samples) 154 | else: 155 | sampler = lambda x: x 156 | images = [(i, os.path.join(path, image)) \ 157 | for i, path in zip(labels, paths) \ 158 | for image in sampler(os.listdir(path))] 159 | if shuffle: 160 | random.shuffle(images) 161 | return images 162 | 163 | def get_pretrain_images(path, label): 164 | """The function to get the image files' directories for pre-train phase. 165 | Args: 166 | paths: the base path for the images. 167 | labels: the class name labels. 168 | is_val: whether the images are for the validation phase during pre-training. 169 | Return: 170 | The list for the image files' directories. 171 | """ 172 | images = [] 173 | for image in os.listdir(path): 174 | images.append((label, os.path.join(path, image))) 175 | return images 176 | 177 | def get_images_tc(paths, labels, nb_samples=None, shuffle=True, is_val=False): 178 | """The function to get the image files' directories with given class labels for pre-train phase. 179 | Args: 180 | paths: the base path for the images. 181 | labels: the class name labels. 182 | nb_samples: the number of samples. 183 | shuffle: whether shuffle the generated image list. 184 | is_val: whether the images are for the validation phase during pre-training. 185 | Return: 186 | The list for the image files' directories. 
187 | """ 188 | if nb_samples is not None: 189 | sampler = lambda x: random.sample(x, nb_samples) 190 | else: 191 | sampler = lambda x: x 192 | if is_val is False: 193 | images = [(i, os.path.join(path, image)) \ 194 | for i, path in zip(labels, paths) \ 195 | for image in sampler(os.listdir(path)[0:500])] 196 | else: 197 | images = [(i, os.path.join(path, image)) \ 198 | for i, path in zip(labels, paths) \ 199 | for image in sampler(os.listdir(path)[500:])] 200 | if shuffle: 201 | random.shuffle(images) 202 | return images 203 | 204 | 205 | ## Network helpers 206 | 207 | def leaky_relu(x, leak=0.1): 208 | """The leaky relu function. 209 | Args: 210 | x: the input feature maps. 211 | leak: the parameter for leaky relu. 212 | Return: 213 | The feature maps processed by non-liner layer. 214 | """ 215 | return tf.maximum(x, leak*x) 216 | 217 | def resnet_conv_block(inp, cweight, bweight, reuse, scope, activation=leaky_relu): 218 | """The function to forward a conv layer. 219 | Args: 220 | inp: the input feature maps. 221 | cweight: the filters' weights for this conv layer. 222 | bweight: the biases' weights for this conv layer. 223 | reuse: whether reuse the variables for the batch norm. 224 | scope: the label for this conv layer. 225 | activation: the activation function for this conv layer. 226 | Return: 227 | The processed feature maps. 228 | """ 229 | stride, no_stride = [1,2,2,1], [1,1,1,1] 230 | 231 | if FLAGS.activation == 'leaky_relu': 232 | activation = leaky_relu 233 | elif FLAGS.activation == 'relu': 234 | activation = tf.nn.relu 235 | else: 236 | activation = None 237 | 238 | conv_output = tf.nn.conv2d(inp, cweight, no_stride, 'SAME') + bweight 239 | normed = normalize(conv_output, activation, reuse, scope) 240 | 241 | return normed 242 | 243 | def resnet_nob_conv_block(inp, cweight, reuse, scope): 244 | """The function to forward a conv layer without biases, normalization and non-liner layer. 245 | Args: 246 | inp: the input feature maps. 
247 | cweight: the filters' weights for this conv layer. 248 | reuse: whether reuse the variables for the batch norm. 249 | scope: the label for this conv layer. 250 | Return: 251 | The processed feature maps. 252 | """ 253 | stride, no_stride = [1,2,2,1], [1,1,1,1] 254 | conv_output = tf.nn.conv2d(inp, cweight, no_stride, 'SAME') 255 | return conv_output 256 | 257 | def normalize(inp, activation, reuse, scope): 258 | """The function to forward the normalization. 259 | Args: 260 | inp: the input feature maps. 261 | reuse: whether reuse the variables for the batch norm. 262 | scope: the label for this conv layer. 263 | activation: the activation function for this conv layer. 264 | Return: 265 | The processed feature maps. 266 | """ 267 | if FLAGS.norm == 'batch_norm': 268 | return tf_layers.batch_norm(inp, activation_fn=activation, reuse=reuse, scope=scope) 269 | elif FLAGS.norm == 'layer_norm': 270 | return tf_layers.layer_norm(inp, activation_fn=activation, reuse=reuse, scope=scope) 271 | elif FLAGS.norm == 'None': 272 | if activation is not None: 273 | return activation(inp) 274 | return inp 275 | else: 276 | raise ValueError('Please set correct normalization.') 277 | 278 | ## Loss functions 279 | 280 | def mse(pred, label): 281 | """The MSE loss function. 282 | Args: 283 | pred: the predictions. 284 | label: the ground truth labels. 285 | Return: 286 | The Loss. 287 | """ 288 | pred = tf.reshape(pred, [-1]) 289 | label = tf.reshape(label, [-1]) 290 | return tf.reduce_mean(tf.square(pred-label)) 291 | 292 | def softmaxloss(pred, label): 293 | """The softmax cross entropy loss function. 294 | Args: 295 | pred: the predictions. 296 | label: the ground truth labels. 297 | Return: 298 | The Loss. 299 | """ 300 | return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=label)) 301 | 302 | def xent(pred, label): 303 | """The softmax cross entropy loss function. The losses will be normalized by the shot number. 
304 | Args: 305 | pred: the predictions. 306 | label: the ground truth labels. 307 | Return: 308 | The Loss. 309 | Note: with tf version <=0.12, this loss has incorrect 2nd derivatives 310 | """ 311 | return tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=label) / FLAGS.shot_num 312 | -------------------------------------------------------------------------------- /tensorflow/models/resnet12.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Yaoyao Liu 3 | ## Modified from: https://github.com/cbfinn/maml 4 | ## Tianjin University 5 | ## liuyaoyao@tju.edu.cn 6 | ## Copyright (c) 2019 7 | ## 8 | ## This source code is licensed under the MIT-style license found in the 9 | ## LICENSE file in the root directory of this source tree 10 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 11 | 12 | """ ResNet-12 class. """ 13 | import numpy as np 14 | import tensorflow as tf 15 | from tensorflow.python.platform import flags 16 | from utils.misc import mse, softmaxloss, xent, resnet_conv_block, resnet_nob_conv_block 17 | 18 | FLAGS = flags.FLAGS 19 | 20 | class Models: 21 | """The class that contains the code for the basic resnet models and SS weights""" 22 | def __init__(self): 23 | # Set the dimension number for the input feature maps 24 | self.dim_input = FLAGS.img_size * FLAGS.img_size * 3 25 | # Set the dimension number for the outputs 26 | self.dim_output = FLAGS.way_num 27 | # Load base learning rates from FLAGS 28 | self.update_lr = FLAGS.base_lr 29 | # Load the pre-train phase class number from FLAGS 30 | self.pretrain_class_num = FLAGS.pretrain_class_num 31 | # Set the initial meta learning rate 32 | self.meta_lr = tf.placeholder_with_default(FLAGS.meta_lr, ()) 33 | # Set the initial pre-train learning rate 34 | self.pretrain_lr = tf.placeholder_with_default(FLAGS.pre_lr, ()) 35 | 36 | # Set the 
default objective functions for meta-train and pre-train 37 | self.loss_func = xent 38 | self.pretrain_loss_func = softmaxloss 39 | 40 | # Set the default channel number to 3 41 | self.channels = 3 42 | # Load the image size from FLAGS 43 | self.img_size = FLAGS.img_size 44 | 45 | def process_ss_weights(self, weights, ss_weights, label): 46 | """The function to process the scaling operation 47 | Args: 48 | weights: the weights for the resnet. 49 | ss_weights: the weights for scaling and shifting operation. 50 | label: the label to indicate which layer we are operating. 51 | Return: 52 | The processed weights for the new resnet. 53 | """ 54 | [dim0, dim1] = weights[label].get_shape().as_list()[0:2] 55 | this_ss_weights = tf.tile(ss_weights[label], multiples=[dim0, dim1, 1, 1]) 56 | return tf.multiply(weights[label], this_ss_weights) 57 | 58 | def forward_pretrain_resnet(self, inp, weights, reuse=False, scope=''): 59 | """The function to forward the resnet during pre-train phase 60 | Args: 61 | inp: input feature maps. 62 | weights: input resnet weights. 63 | reuse: reuse the batch norm weights or not. 64 | scope: the label to indicate which layer we are processing. 65 | Return: 66 | The processed feature maps. 67 | """ 68 | inp = tf.reshape(inp, [-1, self.img_size, self.img_size, self.channels]) 69 | net = self.pretrain_block_forward(inp, weights, 'block1', reuse, scope) 70 | net = self.pretrain_block_forward(net, weights, 'block2', reuse, scope) 71 | net = self.pretrain_block_forward(net, weights, 'block3', reuse, scope) 72 | net = self.pretrain_block_forward(net, weights, 'block4', reuse, scope) 73 | net = tf.nn.avg_pool(net, [1,5,5,1], [1,5,5,1], 'VALID') 74 | net = tf.reshape(net, [-1, np.prod([int(dim) for dim in net.get_shape()[1:]])]) 75 | return net 76 | 77 | def forward_resnet(self, inp, weights, ss_weights, reuse=False, scope=''): 78 | """The function to forward the resnet during meta-train phase 79 | Args: 80 | inp: input feature maps. 
81 | weights: input resnet weights. 82 | ss_weights: input scaling weights. 83 | reuse: reuse the batch norm weights or not. 84 | scope: the label to indicate which layer we are processing. 85 | Return: 86 | The processed feature maps. 87 | """ 88 | inp = tf.reshape(inp, [-1, self.img_size, self.img_size, self.channels]) 89 | net = self.block_forward(inp, weights, ss_weights, 'block1', reuse, scope) 90 | net = self.block_forward(net, weights, ss_weights, 'block2', reuse, scope) 91 | net = self.block_forward(net, weights, ss_weights, 'block3', reuse, scope) 92 | net = self.block_forward(net, weights, ss_weights, 'block4', reuse, scope) 93 | net = tf.nn.avg_pool(net, [1,5,5,1], [1,5,5,1], 'VALID') 94 | net = tf.reshape(net, [-1, np.prod([int(dim) for dim in net.get_shape()[1:]])]) 95 | return net 96 | 97 | def forward_fc(self, inp, fc_weights): 98 | """The function to forward the fc layer 99 | Args: 100 | inp: input feature maps. 101 | fc_weights: input fc weights. 102 | Return: 103 | The processed feature maps. 104 | """ 105 | net = tf.matmul(inp, fc_weights['w5']) + fc_weights['b5'] 106 | return net 107 | 108 | def pretrain_block_forward(self, inp, weights, block, reuse, scope): 109 | """The function to forward a resnet block during pre-train phase 110 | Args: 111 | inp: input feature maps. 112 | weights: input resnet weights. 113 | block: the string to indicate which block we are processing. 114 | reuse: reuse the batch norm weights or not. 115 | scope: the label to indicate which layer we are processing. 116 | Return: 117 | The processed feature maps. 
118 | """ 119 | net = resnet_conv_block(inp, weights[block + '_conv1'], weights[block + '_bias1'], reuse, scope+block+'0') 120 | net = resnet_conv_block(net, weights[block + '_conv2'], weights[block + '_bias2'], reuse, scope+block+'1') 121 | net = resnet_conv_block(net, weights[block + '_conv3'], weights[block + '_bias3'], reuse, scope+block+'2') 122 | res = resnet_nob_conv_block(inp, weights[block + '_conv_res'], reuse, scope+block+'res') 123 | net = net + res 124 | net = tf.nn.max_pool(net, [1,2,2,1], [1,2,2,1], 'VALID') 125 | net = tf.nn.dropout(net, keep_prob=FLAGS.pretrain_dropout_keep) 126 | return net 127 | 128 | def block_forward(self, inp, weights, ss_weights, block, reuse, scope): 129 | """The function to forward a resnet block during meta-train phase 130 | Args: 131 | inp: input feature maps. 132 | weights: input resnet weights. 133 | ss_weights: input scaling weights. 134 | block: the string to indicate which block we are processing. 135 | reuse: reuse the batch norm weights or not. 136 | scope: the label to indicate which layer we are processing. 137 | Return: 138 | The processed feature maps. 139 | """ 140 | net = resnet_conv_block(inp, self.process_ss_weights(weights, ss_weights, block + '_conv1'), \ 141 | ss_weights[block + '_bias1'], reuse, scope+block+'0') 142 | net = resnet_conv_block(net, self.process_ss_weights(weights, ss_weights, block + '_conv2'), \ 143 | ss_weights[block + '_bias2'], reuse, scope+block+'1') 144 | net = resnet_conv_block(net, self.process_ss_weights(weights, ss_weights, block + '_conv3'), \ 145 | ss_weights[block + '_bias3'], reuse, scope+block+'2') 146 | res = resnet_nob_conv_block(inp, weights[block + '_conv_res'], reuse, scope+block+'res') 147 | net = net + res 148 | net = tf.nn.max_pool(net, [1,2,2,1], [1,2,2,1], 'VALID') 149 | net = tf.nn.dropout(net, keep_prob=1) 150 | return net 151 | 152 | def construct_fc_weights(self): 153 | """The function to construct fc weights. 154 | Return: 155 | The fc weights. 
156 | """ 157 | dtype = tf.float32 158 | fc_weights = {} 159 | fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype) 160 | if FLAGS.phase=='pre': 161 | fc_weights['w5'] = tf.get_variable('fc_w5', [512, FLAGS.pretrain_class_num], initializer=fc_initializer) 162 | fc_weights['b5'] = tf.Variable(tf.zeros([FLAGS.pretrain_class_num]), name='fc_b5') 163 | else: 164 | fc_weights['w5'] = tf.get_variable('fc_w5', [512, self.dim_output], initializer=fc_initializer) 165 | fc_weights['b5'] = tf.Variable(tf.zeros([self.dim_output]), name='fc_b5') 166 | return fc_weights 167 | 168 | def construct_resnet_weights(self): 169 | """The function to construct resnet weights. 170 | Return: 171 | The resnet weights. 172 | """ 173 | weights = {} 174 | dtype = tf.float32 175 | conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype) 176 | fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype) 177 | weights = self.construct_residual_block_weights(weights, 3, 3, 64, conv_initializer, dtype, 'block1') 178 | weights = self.construct_residual_block_weights(weights, 3, 64, 128, conv_initializer, dtype, 'block2') 179 | weights = self.construct_residual_block_weights(weights, 3, 128, 256, conv_initializer, dtype, 'block3') 180 | weights = self.construct_residual_block_weights(weights, 3, 256, 512, conv_initializer, dtype, 'block4') 181 | weights['w5'] = tf.get_variable('w5', [512, FLAGS.pretrain_class_num], initializer=fc_initializer) 182 | weights['b5'] = tf.Variable(tf.zeros([FLAGS.pretrain_class_num]), name='b5') 183 | return weights 184 | 185 | def construct_residual_block_weights(self, weights, k, last_dim_hidden, dim_hidden, conv_initializer, dtype, scope='block0'): 186 | """The function to construct one block of the resnet weights. 187 | Args: 188 | weights: the resnet weight list. 189 | k: the dimension number of the convolution kernel. 190 | last_dim_hidden: the hidden dimension number of last block. 
191 | dim_hidden: the hidden dimension number of the block. 192 | conv_initializer: the convolution initializer. 193 | dtype: the dtype for numpy. 194 | scope: the label to indicate which block we are processing. 195 | Return: 196 | The resnet block weights. 197 | """ 198 | weights[scope + '_conv1'] = tf.get_variable(scope + '_conv1', [k, k, last_dim_hidden, dim_hidden], \ 199 | initializer=conv_initializer, dtype=dtype) 200 | weights[scope + '_bias1'] = tf.Variable(tf.zeros([dim_hidden]), name=scope + '_bias1') 201 | weights[scope + '_conv2'] = tf.get_variable(scope + '_conv2', [k, k, dim_hidden, dim_hidden], \ 202 | initializer=conv_initializer, dtype=dtype) 203 | weights[scope + '_bias2'] = tf.Variable(tf.zeros([dim_hidden]), name=scope + '_bias2') 204 | weights[scope + '_conv3'] = tf.get_variable(scope + '_conv3', [k, k, dim_hidden, dim_hidden], \ 205 | initializer=conv_initializer, dtype=dtype) 206 | weights[scope + '_bias3'] = tf.Variable(tf.zeros([dim_hidden]), name=scope + '_bias3') 207 | weights[scope + '_conv_res'] = tf.get_variable(scope + '_conv_res', [1, 1, last_dim_hidden, dim_hidden], \ 208 | initializer=conv_initializer, dtype=dtype) 209 | return weights 210 | 211 | def construct_resnet_ss_weights(self): 212 | """The function to construct ss weights. 213 | Return: 214 | The ss weights. 215 | """ 216 | ss_weights = {} 217 | ss_weights = self.construct_residual_block_ss_weights(ss_weights, 3, 64, 'block1') 218 | ss_weights = self.construct_residual_block_ss_weights(ss_weights, 64, 128, 'block2') 219 | ss_weights = self.construct_residual_block_ss_weights(ss_weights, 128, 256, 'block3') 220 | ss_weights = self.construct_residual_block_ss_weights(ss_weights, 256, 512, 'block4') 221 | return ss_weights 222 | 223 | def construct_residual_block_ss_weights(self, ss_weights, last_dim_hidden, dim_hidden, scope='block0'): 224 | """The function to construct one block ss weights. 225 | Args: 226 | ss_weights: the ss weight list. 
227 | last_dim_hidden: the hidden dimension number of last block. 228 | dim_hidden: the hidden dimension number of the block. 229 | scope: the label to indicate which block we are processing. 230 | Return: 231 | The ss block weights. 232 | """ 233 | ss_weights[scope + '_conv1'] = tf.Variable(tf.ones([1, 1, last_dim_hidden, dim_hidden]), name=scope + '_conv1') 234 | ss_weights[scope + '_bias1'] = tf.Variable(tf.zeros([dim_hidden]), name=scope + '_bias1') 235 | ss_weights[scope + '_conv2'] = tf.Variable(tf.ones([1, 1, dim_hidden, dim_hidden]), name=scope + '_conv2') 236 | ss_weights[scope + '_bias2'] = tf.Variable(tf.zeros([dim_hidden]), name=scope + '_bias2') 237 | ss_weights[scope + '_conv3'] = tf.Variable(tf.ones([1, 1, dim_hidden, dim_hidden]), name=scope + '_conv3') 238 | ss_weights[scope + '_bias3'] = tf.Variable(tf.zeros([dim_hidden]), name=scope + '_bias3') 239 | return ss_weights 240 | -------------------------------------------------------------------------------- /tensorflow/models/meta_model.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Yaoyao Liu 3 | ## Modified from: https://github.com/cbfinn/maml 4 | ## Tianjin University 5 | ## liuyaoyao@tju.edu.cn 6 | ## Copyright (c) 2019 7 | ## 8 | ## This source code is licensed under the MIT-style license found in the 9 | ## LICENSE file in the root directory of this source tree 10 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 11 | 12 | """ Models for meta-learning. """ 13 | import tensorflow as tf 14 | from tensorflow.python.platform import flags 15 | from utils.misc import mse, softmaxloss, xent, resnet_conv_block, resnet_nob_conv_block 16 | 17 | FLAGS = flags.FLAGS 18 | 19 | def MakeMetaModel(): 20 | """The function to make meta model. 21 | Arg: 22 | Meta-train model class. 
23 | """ 24 | if FLAGS.backbone_arch=='resnet12': 25 | try:#python2 26 | from resnet12 import Models 27 | except ImportError:#python3 28 | from models.resnet12 import Models 29 | elif FLAGS.backbone_arch=='resnet18': 30 | try:#python2 31 | from resnet18 import Models 32 | except ImportError:#python3 33 | from models.resnet18 import Models 34 | else: 35 | print('Please set the correct backbone') 36 | 37 | class MetaModel(Models): 38 | """The class for the meta models. This class is inheritance from Models, so some variables are in the Models class.""" 39 | def construct_model(self): 40 | """The function to construct meta-train model.""" 41 | # Set the placeholder for the input episode 42 | self.inputa = tf.placeholder(tf.float32) # episode train images 43 | self.inputb = tf.placeholder(tf.float32) # episode test images 44 | self.labela = tf.placeholder(tf.float32) # episode train labels 45 | self.labelb = tf.placeholder(tf.float32) # episode test labels 46 | 47 | with tf.variable_scope('meta-model', reuse=None) as training_scope: 48 | # construct the model weights 49 | self.ss_weights = ss_weights = self.construct_resnet_ss_weights() 50 | self.weights = weights = self.construct_resnet_weights() 51 | self.fc_weights = fc_weights = self.construct_fc_weights() 52 | 53 | # Load base epoch number from FLAGS 54 | num_updates = FLAGS.train_base_epoch_num 55 | 56 | def task_metalearn(inp, reuse=True): 57 | """The function to process one episode in a meta-batch. 58 | Args: 59 | inp: the input episode. 60 | reuse: whether reuse the variables for the normalization. 61 | Returns: 62 | A serious outputs like losses and accuracies. 
63 | """ 64 | # Seperate inp to different variables 65 | inputa, inputb, labela, labelb = inp 66 | # Generate empty list to record losses 67 | lossa_list = [] # Base train loss list 68 | lossb_list = [] # Base test loss list 69 | 70 | # Embed the input images to embeddings with ss weights 71 | emb_outputa = self.forward_resnet(inputa, weights, ss_weights, reuse=reuse) # Embed episode train 72 | emb_outputb = self.forward_resnet(inputb, weights, ss_weights, reuse=True) # Embed episode test 73 | 74 | # Run the first epoch of the base learning 75 | # Forward fc layer for episode train 76 | outputa = self.forward_fc(emb_outputa, fc_weights) 77 | # Calculate base train loss 78 | lossa = self.loss_func(outputa, labela) 79 | # Record base train loss 80 | lossa_list.append(lossa) 81 | # Forward fc layer for episode test 82 | outputb = self.forward_fc(emb_outputb, fc_weights) 83 | # Calculate base test loss 84 | lossb = self.loss_func(outputb, labelb) 85 | # Record base test loss 86 | lossb_list.append(lossb) 87 | # Calculate the gradients for the fc layer 88 | grads = tf.gradients(lossa, list(fc_weights.values())) 89 | gradients = dict(zip(fc_weights.keys(), grads)) 90 | # Use graient descent to update the fc layer 91 | fast_fc_weights = dict(zip(fc_weights.keys(), [fc_weights[key] - \ 92 | self.update_lr*gradients[key] for key in fc_weights.keys()])) 93 | 94 | for j in range(num_updates - 1): 95 | # Run the following base epochs, these are similar to the first base epoch 96 | lossa = self.loss_func(self.forward_fc(emb_outputa, fast_fc_weights), labela) 97 | lossa_list.append(lossa) 98 | lossb = self.loss_func(self.forward_fc(emb_outputb, fast_fc_weights), labelb) 99 | lossb_list.append(lossb) 100 | grads = tf.gradients(lossa, list(fast_fc_weights.values())) 101 | gradients = dict(zip(fast_fc_weights.keys(), grads)) 102 | fast_fc_weights = dict(zip(fast_fc_weights.keys(), [fast_fc_weights[key] - \ 103 | self.update_lr*gradients[key] for key in fast_fc_weights.keys()])) 104 
| 105 | # Calculate final episode test predictions 106 | outputb = self.forward_fc(emb_outputb, fast_fc_weights) 107 | # Calculate the final episode test loss, it is the loss for the episode on meta-train 108 | final_lossb = self.loss_func(outputb, labelb) 109 | # Calculate the final episode test accuarcy 110 | accb = tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(outputb), 1), tf.argmax(labelb, 1)) 111 | 112 | # Reorganize all the outputs to a list 113 | task_output = [final_lossb, lossb_list, lossa_list, accb] 114 | 115 | return task_output 116 | 117 | # Initial the batch normalization weights 118 | if FLAGS.norm is not None: 119 | unused = task_metalearn((self.inputa[0], self.inputb[0], self.labela[0], self.labelb[0]), False) 120 | 121 | # Set the dtype of the outputs 122 | out_dtype = [tf.float32, [tf.float32]*num_updates, [tf.float32]*num_updates, tf.float32] 123 | 124 | # Run two episodes for a meta batch using parallel setting 125 | result = tf.map_fn(task_metalearn, elems=(self.inputa, self.inputb, self.labela, self.labelb), \ 126 | dtype=out_dtype, parallel_iterations=FLAGS.meta_batch_size) 127 | # Seperate the outputs to different variables 128 | lossb, lossesb, lossesa, accsb = result 129 | 130 | # Set the variables to output from the tensorflow graph 131 | self.total_loss = total_loss = tf.reduce_sum(lossb) / tf.to_float(FLAGS.meta_batch_size) 132 | self.total_accuracy = total_accuracy = tf.reduce_sum(accsb) / tf.to_float(FLAGS.meta_batch_size) 133 | self.total_lossa = total_lossa = [tf.reduce_sum(lossesa[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)] 134 | self.total_lossb = total_lossb = [tf.reduce_sum(lossesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)] 135 | 136 | # Set the meta-train optimizer 137 | optimizer = tf.train.AdamOptimizer(self.meta_lr) 138 | self.metatrain_op = optimizer.minimize(total_loss, var_list=list(ss_weights.values()) + list(fc_weights.values())) 139 | 140 | # Set the 
tensorboard 141 | self.training_summaries = [] 142 | self.training_summaries.append(tf.summary.scalar('Meta Train Loss', (total_loss / tf.to_float(FLAGS.metatrain_epite_sample_num)))) 143 | self.training_summaries.append(tf.summary.scalar('Meta Train Accuracy', total_accuracy)) 144 | for j in range(num_updates): 145 | self.training_summaries.append(tf.summary.scalar('Base Train Loss Step' + str(j+1), total_lossa[j])) 146 | for j in range(num_updates): 147 | self.training_summaries.append(tf.summary.scalar('Base Val Loss Step' + str(j+1), total_lossb[j])) 148 | 149 | self.training_summ_op = tf.summary.merge(self.training_summaries) 150 | 151 | self.input_val_loss = tf.placeholder(tf.float32) 152 | self.input_val_acc = tf.placeholder(tf.float32) 153 | self.val_summaries = [] 154 | self.val_summaries.append(tf.summary.scalar('Meta Val Loss', self.input_val_loss)) 155 | self.val_summaries.append(tf.summary.scalar('Meta Val Accuracy', self.input_val_acc)) 156 | self.val_summ_op = tf.summary.merge(self.val_summaries) 157 | 158 | def construct_test_model(self): 159 | """The function to construct meta-test model.""" 160 | # Set the placeholder for the input episode 161 | self.inputa = tf.placeholder(tf.float32) 162 | self.inputb = tf.placeholder(tf.float32) 163 | self.labela = tf.placeholder(tf.float32) 164 | self.labelb = tf.placeholder(tf.float32) 165 | 166 | with tf.variable_scope('meta-test-model', reuse=None) as training_scope: 167 | # construct the model weights 168 | self.ss_weights = ss_weights = self.construct_resnet_ss_weights() 169 | self.weights = weights = self.construct_resnet_weights() 170 | self.fc_weights = fc_weights = self.construct_fc_weights() 171 | 172 | # Load test base epoch number from FLAGS 173 | num_updates = FLAGS.test_base_epoch_num 174 | 175 | def task_metalearn(inp, reuse=True): 176 | """The function to process one episode in a meta-batch. 177 | Args: 178 | inp: the input episode. 
179 | reuse: whether reuse the variables for the normalization. 180 | Returns: 181 | A serious outputs like losses and accuracies. 182 | """ 183 | # Seperate inp to different variables 184 | inputa, inputb, labela, labelb = inp 185 | # Generate empty list to record accuracies 186 | accb_list = [] 187 | 188 | # Embed the input images to embeddings with ss weights 189 | emb_outputa = self.forward_resnet(inputa, weights, ss_weights, reuse=reuse) 190 | emb_outputb = self.forward_resnet(inputb, weights, ss_weights, reuse=True) 191 | 192 | # This part is similar to the meta-train function, you may refer to the comments above 193 | outputa = self.forward_fc(emb_outputa, fc_weights) 194 | lossa = self.loss_func(outputa, labela) 195 | grads = tf.gradients(lossa, list(fc_weights.values())) 196 | gradients = dict(zip(fc_weights.keys(), grads)) 197 | fast_fc_weights = dict(zip(fc_weights.keys(), [fc_weights[key] - \ 198 | self.update_lr*gradients[key] for key in fc_weights.keys()])) 199 | outputb = self.forward_fc(emb_outputb, fast_fc_weights) 200 | accb = tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(outputb), 1), tf.argmax(labelb, 1)) 201 | accb_list.append(accb) 202 | 203 | for j in range(num_updates - 1): 204 | lossa = self.loss_func(self.forward_fc(emb_outputa, fast_fc_weights), labela) 205 | grads = tf.gradients(lossa, list(fast_fc_weights.values())) 206 | gradients = dict(zip(fast_fc_weights.keys(), grads)) 207 | fast_fc_weights = dict(zip(fast_fc_weights.keys(), [fast_fc_weights[key] - \ 208 | self.update_lr*gradients[key] for key in fast_fc_weights.keys()])) 209 | outputb = self.forward_fc(emb_outputb, fast_fc_weights) 210 | accb = tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(outputb), 1), tf.argmax(labelb, 1)) 211 | accb_list.append(accb) 212 | 213 | lossb = self.loss_func(outputb, labelb) 214 | 215 | task_output = [lossb, accb, accb_list] 216 | 217 | return task_output 218 | 219 | if FLAGS.norm is not None: 220 | unused = task_metalearn((self.inputa[0], 
self.inputb[0], self.labela[0], self.labelb[0]), False) 221 | 222 | out_dtype = [tf.float32, tf.float32, [tf.float32]*num_updates] 223 | 224 | result = tf.map_fn(task_metalearn, elems=(self.inputa, self.inputb, self.labela, self.labelb), \ 225 | dtype=out_dtype, parallel_iterations=FLAGS.meta_batch_size) 226 | lossesb, accsb, accsb_list = result 227 | 228 | self.metaval_total_loss = total_loss = tf.reduce_sum(lossesb) 229 | self.metaval_total_accuracy = total_accuracy = tf.reduce_sum(accsb) 230 | self.metaval_total_accuracies = total_accuracies =[tf.reduce_sum(accsb_list[j]) for j in range(num_updates)] 231 | 232 | return MetaModel() 233 | -------------------------------------------------------------------------------- /pytorch/trainer/meta.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Yaoyao Liu 3 | ## Modified from: https://github.com/Sha-Lab/FEAT 4 | ## Tianjin University 5 | ## liuyaoyao@tju.edu.cn 6 | ## Copyright (c) 2019 7 | ## 8 | ## This source code is licensed under the MIT-style license found in the 9 | ## LICENSE file in the root directory of this source tree 10 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 11 | """ Trainer for meta-train phase. 
""" 12 | import os.path as osp 13 | import os 14 | import tqdm 15 | import numpy as np 16 | import torch 17 | import torch.nn.functional as F 18 | from torch.utils.data import DataLoader 19 | from dataloader.samplers import CategoriesSampler 20 | from models.mtl import MtlLearner 21 | from utils.misc import Averager, Timer, count_acc, compute_confidence_interval, ensure_path 22 | from tensorboardX import SummaryWriter 23 | from dataloader.dataset_loader import DatasetLoader as Dataset 24 | 25 | class MetaTrainer(object): 26 | """The class that contains the code for the meta-train phase and meta-eval phase.""" 27 | def __init__(self, args): 28 | # Set the folder to save the records and checkpoints 29 | log_base_dir = './logs/' 30 | if not osp.exists(log_base_dir): 31 | os.mkdir(log_base_dir) 32 | meta_base_dir = osp.join(log_base_dir, 'meta') 33 | if not osp.exists(meta_base_dir): 34 | os.mkdir(meta_base_dir) 35 | save_path1 = '_'.join([args.dataset, args.model_type, 'MTL']) 36 | save_path2 = 'shot' + str(args.shot) + '_way' + str(args.way) + '_query' + str(args.train_query) + \ 37 | '_step' + str(args.step_size) + '_gamma' + str(args.gamma) + '_lr1' + str(args.meta_lr1) + '_lr2' + str(args.meta_lr2) + \ 38 | '_batch' + str(args.num_batch) + '_maxepoch' + str(args.max_epoch) + \ 39 | '_baselr' + str(args.base_lr) + '_updatestep' + str(args.update_step) + \ 40 | '_stepsize' + str(args.step_size) + '_' + args.meta_label 41 | args.save_path = meta_base_dir + '/' + save_path1 + '_' + save_path2 42 | ensure_path(args.save_path) 43 | 44 | # Set args to be shareable in the class 45 | self.args = args 46 | 47 | # Load meta-train set 48 | self.trainset = Dataset('train', self.args) 49 | self.train_sampler = CategoriesSampler(self.trainset.label, self.args.num_batch, self.args.way, self.args.shot + self.args.train_query) 50 | self.train_loader = DataLoader(dataset=self.trainset, batch_sampler=self.train_sampler, num_workers=8, pin_memory=True) 51 | 52 | # Load meta-val set 53 
| self.valset = Dataset('val', self.args) 54 | self.val_sampler = CategoriesSampler(self.valset.label, 600, self.args.way, self.args.shot + self.args.val_query) 55 | self.val_loader = DataLoader(dataset=self.valset, batch_sampler=self.val_sampler, num_workers=8, pin_memory=True) 56 | 57 | # Build meta-transfer learning model 58 | self.model = MtlLearner(self.args) 59 | 60 | # Set optimizer 61 | self.optimizer = torch.optim.Adam([{'params': filter(lambda p: p.requires_grad, self.model.encoder.parameters())}, \ 62 | {'params': self.model.base_learner.parameters(), 'lr': self.args.meta_lr2}], lr=self.args.meta_lr1) 63 | # Set learning rate scheduler 64 | self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=self.args.step_size, gamma=self.args.gamma) 65 | 66 | # load pretrained model without FC classifier 67 | self.model_dict = self.model.state_dict() 68 | if self.args.init_weights is not None: 69 | pretrained_dict = torch.load(self.args.init_weights)['params'] 70 | else: 71 | pre_base_dir = osp.join(log_base_dir, 'pre') 72 | pre_save_path1 = '_'.join([args.dataset, args.model_type]) 73 | pre_save_path2 = 'batchsize' + str(args.pre_batch_size) + '_lr' + str(args.pre_lr) + '_gamma' + str(args.pre_gamma) + '_step' + \ 74 | str(args.pre_step_size) + '_maxepoch' + str(args.pre_max_epoch) 75 | pre_save_path = pre_base_dir + '/' + pre_save_path1 + '_' + pre_save_path2 76 | pretrained_dict = torch.load(osp.join(pre_save_path, 'max_acc.pth'))['params'] 77 | pretrained_dict = {'encoder.'+k: v for k, v in pretrained_dict.items()} 78 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in self.model_dict} 79 | print(pretrained_dict.keys()) 80 | self.model_dict.update(pretrained_dict) 81 | self.model.load_state_dict(self.model_dict) 82 | 83 | # Set model to GPU 84 | if torch.cuda.is_available(): 85 | torch.backends.cudnn.benchmark = True 86 | self.model = self.model.cuda() 87 | 88 | def save_model(self, name): 89 | """The function to save 
checkpoints. 90 | Args: 91 | name: the name for saved checkpoint 92 | """ 93 | torch.save(dict(params=self.model.state_dict()), osp.join(self.args.save_path, name + '.pth')) 94 | 95 | def train(self): 96 | """The function for the meta-train phase.""" 97 | 98 | # Set the meta-train log 99 | trlog = {} 100 | trlog['args'] = vars(self.args) 101 | trlog['train_loss'] = [] 102 | trlog['val_loss'] = [] 103 | trlog['train_acc'] = [] 104 | trlog['val_acc'] = [] 105 | trlog['max_acc'] = 0.0 106 | trlog['max_acc_epoch'] = 0 107 | 108 | # Set the timer 109 | timer = Timer() 110 | # Set global count to zero 111 | global_count = 0 112 | # Set tensorboardX 113 | writer = SummaryWriter(comment=self.args.save_path) 114 | 115 | # Generate the labels for train set of the episodes 116 | label_shot = torch.arange(self.args.way).repeat(self.args.shot) 117 | if torch.cuda.is_available(): 118 | label_shot = label_shot.type(torch.cuda.LongTensor) 119 | else: 120 | label_shot = label_shot.type(torch.LongTensor) 121 | 122 | # Start meta-train 123 | for epoch in range(1, self.args.max_epoch + 1): 124 | # Update learning rate 125 | self.lr_scheduler.step() 126 | # Set the model to train mode 127 | self.model.train() 128 | # Set averager classes to record training losses and accuracies 129 | train_loss_averager = Averager() 130 | train_acc_averager = Averager() 131 | 132 | # Generate the labels for test set of the episodes during meta-train updates 133 | label = torch.arange(self.args.way).repeat(self.args.train_query) 134 | if torch.cuda.is_available(): 135 | label = label.type(torch.cuda.LongTensor) 136 | else: 137 | label = label.type(torch.LongTensor) 138 | 139 | # Using tqdm to read samples from train loader 140 | tqdm_gen = tqdm.tqdm(self.train_loader) 141 | for i, batch in enumerate(tqdm_gen, 1): 142 | # Update global count number 143 | global_count = global_count + 1 144 | if torch.cuda.is_available(): 145 | data, _ = [_.cuda() for _ in batch] 146 | else: 147 | data = batch[0] 148 | p 
= self.args.shot * self.args.way 149 | data_shot, data_query = data[:p], data[p:] 150 | # Output logits for model 151 | logits = self.model((data_shot, label_shot, data_query)) 152 | # Calculate meta-train loss 153 | loss = F.cross_entropy(logits, label) 154 | # Calculate meta-train accuracy 155 | acc = count_acc(logits, label) 156 | # Write the tensorboardX records 157 | writer.add_scalar('data/loss', float(loss), global_count) 158 | writer.add_scalar('data/acc', float(acc), global_count) 159 | # Print loss and accuracy for this step 160 | tqdm_gen.set_description('Epoch {}, Loss={:.4f} Acc={:.4f}'.format(epoch, loss.item(), acc)) 161 | 162 | # Add loss and accuracy for the averagers 163 | train_loss_averager.add(loss.item()) 164 | train_acc_averager.add(acc) 165 | 166 | # Loss backwards and optimizer updates 167 | self.optimizer.zero_grad() 168 | loss.backward() 169 | self.optimizer.step() 170 | 171 | # Update the averagers 172 | train_loss_averager = train_loss_averager.item() 173 | train_acc_averager = train_acc_averager.item() 174 | 175 | # Start validation for this epoch, set model to eval mode 176 | self.model.eval() 177 | 178 | # Set averager classes to record validation losses and accuracies 179 | val_loss_averager = Averager() 180 | val_acc_averager = Averager() 181 | 182 | # Generate the labels for test set of the episodes during meta-val for this epoch 183 | label = torch.arange(self.args.way).repeat(self.args.val_query) 184 | if torch.cuda.is_available(): 185 | label = label.type(torch.cuda.LongTensor) 186 | else: 187 | label = label.type(torch.LongTensor) 188 | 189 | # Print previous information 190 | if epoch % 10 == 0: 191 | print('Best Epoch {}, Best Val Acc={:.4f}'.format(trlog['max_acc_epoch'], trlog['max_acc'])) 192 | # Run meta-validation 193 | for i, batch in enumerate(self.val_loader, 1): 194 | if torch.cuda.is_available(): 195 | data, _ = [_.cuda() for _ in batch] 196 | else: 197 | data = batch[0] 198 | p = self.args.shot * self.args.way 199 
| data_shot, data_query = data[:p], data[p:] 200 | logits = self.model((data_shot, label_shot, data_query)) 201 | loss = F.cross_entropy(logits, label) 202 | acc = count_acc(logits, label) 203 | 204 | val_loss_averager.add(loss.item()) 205 | val_acc_averager.add(acc) 206 | 207 | # Update validation averagers 208 | val_loss_averager = val_loss_averager.item() 209 | val_acc_averager = val_acc_averager.item() 210 | # Write the tensorboardX records 211 | writer.add_scalar('data/val_loss', float(val_loss_averager), epoch) 212 | writer.add_scalar('data/val_acc', float(val_acc_averager), epoch) 213 | # Print loss and accuracy for this epoch 214 | print('Epoch {}, Val, Loss={:.4f} Acc={:.4f}'.format(epoch, val_loss_averager, val_acc_averager)) 215 | 216 | # Update best saved model 217 | if val_acc_averager > trlog['max_acc']: 218 | trlog['max_acc'] = val_acc_averager 219 | trlog['max_acc_epoch'] = epoch 220 | self.save_model('max_acc') 221 | # Save model every 10 epochs 222 | if epoch % 10 == 0: 223 | self.save_model('epoch'+str(epoch)) 224 | 225 | # Update the logs 226 | trlog['train_loss'].append(train_loss_averager) 227 | trlog['train_acc'].append(train_acc_averager) 228 | trlog['val_loss'].append(val_loss_averager) 229 | trlog['val_acc'].append(val_acc_averager) 230 | 231 | # Save log 232 | torch.save(trlog, osp.join(self.args.save_path, 'trlog')) 233 | 234 | if epoch % 10 == 0: 235 | print('Running Time: {}, Estimated Time: {}'.format(timer.measure(), timer.measure(epoch / self.args.max_epoch))) 236 | 237 | writer.close() 238 | 239 | def eval(self): 240 | """The function for the meta-eval phase.""" 241 | # Load the logs 242 | trlog = torch.load(osp.join(self.args.save_path, 'trlog')) 243 | 244 | # Load meta-test set 245 | test_set = Dataset('test', self.args) 246 | sampler = CategoriesSampler(test_set.label, 600, self.args.way, self.args.shot + self.args.val_query) 247 | loader = DataLoader(test_set, batch_sampler=sampler, num_workers=8, pin_memory=True) 248 | 249 | # 
Set test accuracy recorder 250 | test_acc_record = np.zeros((600,)) 251 | 252 | # Load model for meta-test phase 253 | if self.args.eval_weights is not None: 254 | self.model.load_state_dict(torch.load(self.args.eval_weights)['params']) 255 | else: 256 | self.model.load_state_dict(torch.load(osp.join(self.args.save_path, 'max_acc' + '.pth'))['params']) 257 | # Set model to eval mode 258 | self.model.eval() 259 | 260 | # Set accuracy averager 261 | ave_acc = Averager() 262 | 263 | # Generate labels 264 | label = torch.arange(self.args.way).repeat(self.args.val_query) 265 | if torch.cuda.is_available(): 266 | label = label.type(torch.cuda.LongTensor) 267 | else: 268 | label = label.type(torch.LongTensor) 269 | label_shot = torch.arange(self.args.way).repeat(self.args.shot) 270 | if torch.cuda.is_available(): 271 | label_shot = label_shot.type(torch.cuda.LongTensor) 272 | else: 273 | label_shot = label_shot.type(torch.LongTensor) 274 | 275 | # Start meta-test 276 | for i, batch in enumerate(loader, 1): 277 | if torch.cuda.is_available(): 278 | data, _ = [_.cuda() for _ in batch] 279 | else: 280 | data = batch[0] 281 | k = self.args.way * self.args.shot 282 | data_shot, data_query = data[:k], data[k:] 283 | logits = self.model((data_shot, label_shot, data_query)) 284 | acc = count_acc(logits, label) 285 | ave_acc.add(acc) 286 | test_acc_record[i-1] = acc 287 | if i % 100 == 0: 288 | print('batch {}: {:.2f}({:.2f})'.format(i, ave_acc.item() * 100, acc * 100)) 289 | 290 | # Calculate the confidence interval, update the logs 291 | m, pm = compute_confidence_interval(test_acc_record) 292 | print('Val Best Epoch {}, Acc {:.4f}, Test Acc {:.4f}'.format(trlog['max_acc_epoch'], trlog['max_acc'], ave_acc.item())) 293 | print('Test Acc {:.4f} + {:.4f}'.format(m, pm)) 294 | -------------------------------------------------------------------------------- /tensorflow/trainer/meta.py: -------------------------------------------------------------------------------- 1 | 
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Yaoyao Liu 3 | ## Modified from: https://github.com/cbfinn/maml 4 | ## Tianjin University 5 | ## liuyaoyao@tju.edu.cn 6 | ## Copyright (c) 2019 7 | ## 8 | ## This source code is licensed under the MIT-style license found in the 9 | ## LICENSE file in the root directory of this source tree 10 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 11 | 12 | """ Trainer for meta-learning. """ 13 | import os 14 | import csv 15 | import pickle 16 | import random 17 | import numpy as np 18 | import tensorflow as tf 19 | 20 | from tqdm import trange 21 | from data_generator.meta_data_generator import MetaDataGenerator 22 | from models.meta_model import MakeMetaModel 23 | from tensorflow.python.platform import flags 24 | from utils.misc import process_batch 25 | 26 | FLAGS = flags.FLAGS 27 | 28 | class MetaTrainer: 29 | """The class that contains the code for the meta-train and meta-test.""" 30 | def __init__(self): 31 | # Remove the saved datalist for a new experiment 32 | os.system('rm -r ./logs/processed_data/*') 33 | data_generator = MetaDataGenerator() 34 | if FLAGS.metatrain: 35 | # Build model for meta-train phase 36 | print('Building meta-train model') 37 | self.model = MakeMetaModel() 38 | self.model.construct_model() 39 | print('Meta-train model is built') 40 | # Start tensorflow session 41 | self.start_session() 42 | # Generate data for meta-train phase 43 | if FLAGS.load_saved_weights: 44 | random.seed(5) 45 | data_generator.generate_data(data_type='train') 46 | if FLAGS.load_saved_weights: 47 | random.seed(7) 48 | data_generator.generate_data(data_type='test') 49 | if FLAGS.load_saved_weights: 50 | random.seed(9) 51 | data_generator.generate_data(data_type='val') 52 | else: 53 | # Build model for meta-test phase 54 | print('Building meta-test mdoel') 55 | self.model = MakeMetaModel() 56 | self.model.construct_test_model() 57 | 
self.model.summ_op = tf.summary.merge_all() 58 | print('Meta-test model is built') 59 | # Start tensorflow session 60 | self.start_session() 61 | # Generate data for meta-test phase 62 | if FLAGS.load_saved_weights: 63 | random.seed(7) 64 | data_generator.generate_data(data_type='test') 65 | # Load the experiment setting string from FLAGS 66 | exp_string = FLAGS.exp_string 67 | 68 | # Global initialization and starting queue 69 | tf.global_variables_initializer().run() 70 | tf.train.start_queue_runners() 71 | 72 | if FLAGS.metatrain: 73 | # Process initialization weights for meta-train 74 | init_dir = FLAGS.logdir_base + 'init_weights/' 75 | if not os.path.exists(init_dir): 76 | os.mkdir(init_dir) 77 | pre_save_str = FLAGS.pre_string 78 | this_init_dir = init_dir + pre_save_str + '.pre_iter(' + str(FLAGS.pretrain_iterations) + ')/' 79 | if not os.path.exists(this_init_dir): 80 | # If there is no saved initialization weights for meta-train, load pre-train model and save initialization weights 81 | os.mkdir(this_init_dir) 82 | if FLAGS.load_saved_weights: 83 | print('Loading downloaded pretrain weights') 84 | weights = np.load('logs/download_weights/weights.npy', allow_pickle=True, encoding="latin1").tolist() 85 | else: 86 | print('Loading pretrain weights') 87 | weights_save_dir_base = FLAGS.pretrain_dir 88 | weights_save_dir = os.path.join(weights_save_dir_base, pre_save_str) 89 | weights = np.load(os.path.join(weights_save_dir, "weights_{}.npy".format(FLAGS.pretrain_iterations)), \ 90 | allow_pickle=True, encoding="latin1").tolist() 91 | bais_list = [bais_item for bais_item in weights.keys() if '_bias' in bais_item] 92 | # Assign the bias weights to ss model in order to train them during meta-train 93 | for bais_key in bais_list: 94 | self.sess.run(tf.assign(self.model.ss_weights[bais_key], weights[bais_key])) 95 | # Assign pretrained weights to tensorflow variables 96 | for key in weights.keys(): 97 | self.sess.run(tf.assign(self.model.weights[key], 
weights[key])) 98 | print('Pretrain weights loaded, saving init weights') 99 | # Load and save init weights for the model 100 | new_weights = self.sess.run(self.model.weights) 101 | ss_weights = self.sess.run(self.model.ss_weights) 102 | fc_weights = self.sess.run(self.model.fc_weights) 103 | np.save(this_init_dir + 'weights_init.npy', new_weights) 104 | np.save(this_init_dir + 'ss_weights_init.npy', ss_weights) 105 | np.save(this_init_dir + 'fc_weights_init.npy', fc_weights) 106 | else: 107 | # If the initialization weights are already generated, load the previous saved ones 108 | # This process is deactivate in the default settings, you may activate this for ablative study 109 | print('Loading previous saved init weights') 110 | weights = np.load(this_init_dir + 'weights_init.npy', allow_pickle=True, encoding="latin1").tolist() 111 | ss_weights = np.load(this_init_dir + 'ss_weights_init.npy', allow_pickle=True, encoding="latin1").tolist() 112 | fc_weights = np.load(this_init_dir + 'fc_weights_init.npy', allow_pickle=True, encoding="latin1").tolist() 113 | for key in weights.keys(): 114 | self.sess.run(tf.assign(self.model.weights[key], weights[key])) 115 | for key in ss_weights.keys(): 116 | self.sess.run(tf.assign(self.model.ss_weights[key], ss_weights[key])) 117 | for key in fc_weights.keys(): 118 | self.sess.run(tf.assign(self.model.fc_weights[key], fc_weights[key])) 119 | print('Init weights loaded') 120 | else: 121 | # Load the saved meta model for meta-test phase 122 | if FLAGS.load_saved_weights: 123 | # Load the downloaded weights 124 | weights = np.load('./logs/download_weights/weights.npy', allow_pickle=True, encoding="latin1").tolist() 125 | ss_weights = np.load('./logs/download_weights/ss_weights.npy', allow_pickle=True, encoding="latin1").tolist() 126 | fc_weights = np.load('./logs/download_weights/fc_weights.npy', allow_pickle=True, encoding="latin1").tolist() 127 | else: 128 | # Load the saved weights of meta-train 129 | weights = 
np.load(FLAGS.logdir + '/' + exp_string + '/weights_' + str(FLAGS.test_iter) + '.npy', \ 130 | allow_pickle=True, encoding="latin1").tolist() 131 | ss_weights = np.load(FLAGS.logdir + '/' + exp_string + '/ss_weights_' + str(FLAGS.test_iter) + '.npy', \ 132 | allow_pickle=True, encoding="latin1").tolist() 133 | fc_weights = np.load(FLAGS.logdir + '/' + exp_string + '/fc_weights_' + str(FLAGS.test_iter) + '.npy', \ 134 | allow_pickle=True, encoding="latin1").tolist() 135 | # Assign the weights to the tensorflow variables 136 | for key in weights.keys(): 137 | self.sess.run(tf.assign(self.model.weights[key], weights[key])) 138 | for key in ss_weights.keys(): 139 | self.sess.run(tf.assign(self.model.ss_weights[key], ss_weights[key])) 140 | for key in fc_weights.keys(): 141 | self.sess.run(tf.assign(self.model.fc_weights[key], fc_weights[key])) 142 | print('Weights loaded') 143 | if FLAGS.load_saved_weights: 144 | print('Meta test using downloaded model') 145 | else: 146 | print('Test iter: ' + str(FLAGS.test_iter)) 147 | 148 | if FLAGS.metatrain: 149 | self.train(data_generator) 150 | else: 151 | self.test(data_generator) 152 | 153 | def start_session(self): 154 | """The function to start tensorflow session.""" 155 | if FLAGS.full_gpu_memory_mode: 156 | gpu_config = tf.ConfigProto() 157 | gpu_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_rate 158 | self.sess = tf.InteractiveSession(config=gpu_config) 159 | else: 160 | self.sess = tf.InteractiveSession() 161 | 162 | def train(self, data_generator): 163 | """The function for the meta-train phase 164 | Arg: 165 | data_generator: the data generator class for this phase 166 | """ 167 | # Load the experiment setting string from FLAGS 168 | exp_string = FLAGS.exp_string 169 | # Generate tensorboard file writer 170 | train_writer = tf.summary.FileWriter(FLAGS.logdir + '/' + exp_string, self.sess.graph) 171 | print('Start meta-train phase') 172 | # Generate empty list to record losses and accuracies 173 | 
loss_list, acc_list = [], [] 174 | # Load the meta learning rate from FLAGS 175 | train_lr = FLAGS.meta_lr 176 | # Load data for meta-train and meta validation 177 | data_generator.load_data(data_type='train') 178 | data_generator.load_data(data_type='val') 179 | 180 | for train_idx in trange(FLAGS.metatrain_iterations): 181 | # Load the episodes for this meta batch 182 | inputa = [] 183 | labela = [] 184 | inputb = [] 185 | labelb = [] 186 | for meta_batch_idx in range(FLAGS.meta_batch_size): 187 | this_episode = data_generator.load_episode(index=train_idx*FLAGS.meta_batch_size+meta_batch_idx, data_type='train') 188 | inputa.append(this_episode[0]) 189 | labela.append(this_episode[1]) 190 | inputb.append(this_episode[2]) 191 | labelb.append(this_episode[3]) 192 | inputa = np.array(inputa) 193 | labela = np.array(labela) 194 | inputb = np.array(inputb) 195 | labelb = np.array(labelb) 196 | 197 | # Generate feed dict for the tensorflow graph 198 | feed_dict = {self.model.inputa: inputa, self.model.inputb: inputb, \ 199 | self.model.labela: labela, self.model.labelb: labelb, self.model.meta_lr: train_lr} 200 | 201 | # Set the variables to load from the tensorflow graph 202 | input_tensors = [self.model.metatrain_op] # The meta train optimizer 203 | input_tensors.extend([self.model.total_loss]) # The meta train loss 204 | input_tensors.extend([self.model.total_accuracy]) # The meta train accuracy 205 | input_tensors.extend([self.model.training_summ_op]) # The tensorboard summary operation 206 | 207 | # run this meta-train iteration 208 | result = self.sess.run(input_tensors, feed_dict) 209 | 210 | # record losses, accuracies and tensorboard 211 | loss_list.append(result[1]) 212 | acc_list.append(result[2]) 213 | train_writer.add_summary(result[3], train_idx) 214 | 215 | # print meta-train information on the screen after several iterations 216 | if (train_idx!=0) and train_idx % FLAGS.meta_print_step == 0: 217 | print_str = 'Iteration:' + str(train_idx) 218 | print_str 
+= ' Loss:' + str(np.mean(loss_list)) + ' Acc:' + str(np.mean(acc_list)) 219 | print(print_str) 220 | loss_list, acc_list = [], [] 221 | 222 | # Save the model during meta-teain 223 | if train_idx % FLAGS.meta_save_step == 0: 224 | weights = self.sess.run(self.model.weights) 225 | ss_weights = self.sess.run(self.model.ss_weights) 226 | fc_weights = self.sess.run(self.model.fc_weights) 227 | np.save(FLAGS.logdir + '/' + exp_string + '/weights_' + str(train_idx) + '.npy', weights) 228 | np.save(FLAGS.logdir + '/' + exp_string + '/ss_weights_' + str(train_idx) + '.npy', ss_weights) 229 | np.save(FLAGS.logdir + '/' + exp_string + '/fc_weights_' + str(train_idx) + '.npy', fc_weights) 230 | 231 | # Run the meta-validation during meta-train 232 | if train_idx % FLAGS.meta_val_print_step == 0: 233 | test_loss = [] 234 | test_accs = [] 235 | for test_itr in range(FLAGS.meta_intrain_val_sample): 236 | this_episode = data_generator.load_episode(index=test_itr, data_type='val') 237 | test_inputa = this_episode[0][np.newaxis, :] 238 | test_labela = this_episode[1][np.newaxis, :] 239 | test_inputb = this_episode[2][np.newaxis, :] 240 | test_labelb = this_episode[3][np.newaxis, :] 241 | 242 | test_feed_dict = {self.model.inputa: test_inputa, self.model.inputb: test_inputb, \ 243 | self.model.labela: test_labela, self.model.labelb: test_labelb, \ 244 | self.model.meta_lr: 0.0} 245 | test_input_tensors = [self.model.total_loss, self.model.total_accuracy] 246 | test_result = self.sess.run(test_input_tensors, test_feed_dict) 247 | test_loss.append(test_result[0]) 248 | test_accs.append(test_result[1]) 249 | 250 | valsum_feed_dict = {self.model.input_val_loss: \ 251 | np.mean(test_loss)*np.float(FLAGS.meta_batch_size)/np.float(FLAGS.shot_num), \ 252 | self.model.input_val_acc: np.mean(test_accs)*np.float(FLAGS.meta_batch_size)} 253 | valsum = self.sess.run(self.model.val_summ_op, valsum_feed_dict) 254 | train_writer.add_summary(valsum, train_idx) 255 | print_str = '[***] Val Loss:' + 
str(np.mean(test_loss)*FLAGS.meta_batch_size) + \ 256 | ' Val Acc:' + str(np.mean(test_accs)*FLAGS.meta_batch_size) 257 | print(print_str) 258 | 259 | # Reduce the meta learning rate to half after several iterations 260 | if (train_idx!=0) and train_idx % FLAGS.lr_drop_step == 0: 261 | train_lr = train_lr * FLAGS.lr_drop_rate 262 | if train_lr < 0.1 * FLAGS.meta_lr: 263 | train_lr = 0.1 * FLAGS.meta_lr 264 | print('Train LR: {}'.format(train_lr)) 265 | 266 | # Save the final model 267 | weights = self.sess.run(self.model.weights) 268 | ss_weights = self.sess.run(self.model.ss_weights) 269 | fc_weights = self.sess.run(self.model.fc_weights) 270 | np.save(FLAGS.logdir + '/' + exp_string + '/weights_' + str(train_idx+1) + '.npy', weights) 271 | np.save(FLAGS.logdir + '/' + exp_string + '/ss_weights_' + str(train_idx+1) + '.npy', ss_weights) 272 | np.save(FLAGS.logdir + '/' + exp_string + '/fc_weights_' + str(train_idx+1) + '.npy', fc_weights) 273 | 274 | def test(self, data_generator): 275 | """The function for the meta-test phase 276 | Arg: 277 | data_generator: the data generator class for this phase 278 | """ 279 | # Set meta-test episode number 280 | NUM_TEST_POINTS = 600 281 | # Load the experiment setting string from FLAGS 282 | exp_string = FLAGS.exp_string 283 | print('Start meta-test phase') 284 | np.random.seed(1) 285 | # Generate empty list to record accuracies 286 | metaval_accuracies = [] 287 | # Load data for meta-test 288 | data_generator.load_data(data_type='test') 289 | for test_idx in trange(NUM_TEST_POINTS): 290 | # Load one episode for meta-test 291 | this_episode = data_generator.load_episode(index=test_idx, data_type='test') 292 | inputa = this_episode[0][np.newaxis, :] 293 | labela = this_episode[1][np.newaxis, :] 294 | inputb = this_episode[2][np.newaxis, :] 295 | labelb = this_episode[3][np.newaxis, :] 296 | feed_dict = {self.model.inputa: inputa, self.model.inputb: inputb, \ 297 | self.model.labela: labela, self.model.labelb: labelb, 
self.model.meta_lr: 0.0} 298 | result = self.sess.run(self.model.metaval_total_accuracies, feed_dict) 299 | metaval_accuracies.append(result) 300 | # Calculate the mean accuarcies and the confidence intervals 301 | metaval_accuracies = np.array(metaval_accuracies) 302 | means = np.mean(metaval_accuracies, 0) 303 | stds = np.std(metaval_accuracies, 0) 304 | ci95 = 1.96*stds/np.sqrt(NUM_TEST_POINTS) 305 | 306 | # Print the meta-test results 307 | print('Test accuracies and confidence intervals') 308 | print((means, ci95)) 309 | 310 | # Save the meta-test results in the csv files 311 | if not FLAGS.load_saved_weights: 312 | out_filename = FLAGS.logdir +'/'+ exp_string + '/' + 'result_' + str(FLAGS.shot_num) + 'shot_' + str(FLAGS.test_iter) + '.csv' 313 | out_pkl = FLAGS.logdir +'/'+ exp_string + '/' + 'result_' + str(FLAGS.shot_num) + 'shot_' + str(FLAGS.test_iter) + '.pkl' 314 | with open(out_pkl, 'wb') as f: 315 | pickle.dump({'mses': metaval_accuracies}, f) 316 | with open(out_filename, 'w') as f: 317 | writer = csv.writer(f, delimiter=',') 318 | writer.writerow(['update'+str(i) for i in range(len(means))]) 319 | writer.writerow(means) 320 | writer.writerow(stds) 321 | writer.writerow(ci95) 322 | -------------------------------------------------------------------------------- /tensorflow/models/resnet18.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Yaoyao Liu 3 | ## Modified from: https://github.com/cbfinn/maml 4 | ## Tianjin University 5 | ## liuyaoyao@tju.edu.cn 6 | ## Copyright (c) 2019 7 | ## 8 | ## This source code is licensed under the MIT-style license found in the 9 | ## LICENSE file in the root directory of this source tree 10 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 11 | 12 | """ ResNet-18 class. 
""" 13 | import numpy as np 14 | import tensorflow as tf 15 | from tensorflow.python.platform import flags 16 | from utils.misc import mse, softmaxloss, xent, resnet_conv_block, resnet_nob_conv_block 17 | 18 | FLAGS = flags.FLAGS 19 | 20 | class Models: 21 | """The class that contains the code for the basic resnet models and SS weights""" 22 | def __init__(self): 23 | # Set the dimension number for the input feature maps 24 | self.dim_input = FLAGS.img_size * FLAGS.img_size * 3 25 | # Set the dimension number for the outputs 26 | self.dim_output = FLAGS.way_num 27 | # Load base learning rates from FLAGS 28 | self.update_lr = FLAGS.base_lr 29 | # Load the pre-train phase class number from FLAGS 30 | self.pretrain_class_num = FLAGS.pretrain_class_num 31 | # Set the initial meta learning rate 32 | self.meta_lr = tf.placeholder_with_default(FLAGS.meta_lr, ()) 33 | # Set the initial pre-train learning rate 34 | self.pretrain_lr = tf.placeholder_with_default(FLAGS.pre_lr, ()) 35 | 36 | # Set the default objective functions for meta-train and pre-train 37 | self.loss_func = xent 38 | self.pretrain_loss_func = softmaxloss 39 | 40 | # Set the default channel number to 3 41 | self.channels = 3 42 | # Load the image size from FLAGS 43 | self.img_size = FLAGS.img_size 44 | 45 | def process_ss_weights(self, weights, ss_weights, label): 46 | """The function to process the scaling operation 47 | Args: 48 | weights: the weights for the resnet. 49 | ss_weights: the weights for scaling and shifting operation. 50 | label: the label to indicate which layer we are operating. 51 | Return: 52 | The processed weights for the new resnet. 
53 | """ 54 | [dim0, dim1] = weights[label].get_shape().as_list()[0:2] 55 | this_ss_weights = tf.tile(ss_weights[label], multiples=[dim0, dim1, 1, 1]) 56 | return tf.multiply(weights[label], this_ss_weights) 57 | 58 | def forward_pretrain_resnet(self, inp, weights, reuse=False, scope=''): 59 | """The function to forward the resnet during pre-train phase 60 | Args: 61 | inp: input feature maps. 62 | weights: input resnet weights. 63 | reuse: reuse the batch norm weights or not. 64 | scope: the label to indicate which layer we are processing. 65 | Return: 66 | The processed feature maps. 67 | """ 68 | inp = tf.reshape(inp, [-1, self.img_size, self.img_size, self.channels]) 69 | inp = tf.image.resize_images(inp, size=[224,224], method=tf.image.ResizeMethod.BILINEAR) 70 | net = self.pretrain_first_block_forward(inp, weights, 'block0_1', reuse, scope) 71 | 72 | net = self.pretrain_block_forward(net, weights, 'block1_1', reuse, scope) 73 | net = self.pretrain_block_forward(net, weights, 'block1_2', reuse, scope, block_last_layer=True) 74 | 75 | net = self.pretrain_block_forward(net, weights, 'block2_1', reuse, scope) 76 | net = self.pretrain_block_forward(net, weights, 'block2_2', reuse, scope, block_last_layer=True) 77 | 78 | net = self.pretrain_block_forward(net, weights, 'block3_1', reuse, scope) 79 | net = self.pretrain_block_forward(net, weights, 'block3_2', reuse, scope, block_last_layer=True) 80 | 81 | net = self.pretrain_block_forward(net, weights, 'block4_1', reuse, scope) 82 | net = self.pretrain_block_forward(net, weights, 'block4_2', reuse, scope, block_last_layer=True) 83 | 84 | net = tf.nn.avg_pool(net, [1,7,7,1], [1,7,7,1], 'SAME') 85 | net = tf.reshape(net, [-1, np.prod([int(dim) for dim in net.get_shape()[1:]])]) 86 | return net 87 | 88 | def forward_resnet(self, inp, weights, ss_weights, reuse=False, scope=''): 89 | """The function to forward the resnet during meta-train phase 90 | Args: 91 | inp: input feature maps. 92 | weights: input resnet weights. 
93 | ss_weights: input scaling weights. 94 | reuse: reuse the batch norm weights or not. 95 | scope: the label to indicate which layer we are processing. 96 | Return: 97 | The processed feature maps. 98 | """ 99 | inp = tf.reshape(inp, [-1, self.img_size, self.img_size, self.channels]) 100 | inp = tf.image.resize_images(inp, size=[224,224], method=tf.image.ResizeMethod.BILINEAR) 101 | net = self.first_block_forward(inp, weights, ss_weights, 'block0_1', reuse, scope) 102 | 103 | net = self.block_forward(net, weights, ss_weights, 'block1_1', reuse, scope) 104 | net = self.block_forward(net, weights, ss_weights, 'block1_2', reuse, scope, block_last_layer=True) 105 | 106 | net = self.block_forward(net, weights, ss_weights, 'block2_1', reuse, scope) 107 | net = self.block_forward(net, weights, ss_weights, 'block2_2', reuse, scope, block_last_layer=True) 108 | 109 | net = self.block_forward(net, weights, ss_weights, 'block3_1', reuse, scope) 110 | net = self.block_forward(net, weights, ss_weights, 'block3_2', reuse, scope, block_last_layer=True) 111 | 112 | net = self.block_forward(net, weights, ss_weights, 'block4_1', reuse, scope) 113 | net = self.block_forward(net, weights, ss_weights, 'block4_2', reuse, scope, block_last_layer=True) 114 | 115 | net = tf.nn.avg_pool(net, [1,7,7,1], [1,7,7,1], 'SAME') 116 | net = tf.reshape(net, [-1, np.prod([int(dim) for dim in net.get_shape()[1:]])]) 117 | return net 118 | 119 | def forward_fc(self, inp, fc_weights): 120 | """The function to forward the fc layer 121 | Args: 122 | inp: input feature maps. 123 | fc_weights: input fc weights. 124 | Return: 125 | The processed feature maps. 126 | """ 127 | net = tf.matmul(inp, fc_weights['w5']) + fc_weights['b5'] 128 | return net 129 | 130 | def pretrain_block_forward(self, inp, weights, block, reuse, scope, block_last_layer=False): 131 | """The function to forward a resnet block during pre-train phase 132 | Args: 133 | inp: input feature maps. 134 | weights: input resnet weights. 
135 | block: the string to indicate which block we are processing. 136 | reuse: reuse the batch norm weights or not. 137 | scope: the label to indicate which layer we are processing. 138 | block_last_layer: whether it is the last layer of this block. 139 | Return: 140 | The processed feature maps. 141 | """ 142 | net = resnet_conv_block(inp, weights[block + '_conv1'], weights[block + '_bias1'], reuse, scope+block+'0') 143 | net = resnet_conv_block(net, weights[block + '_conv2'], weights[block + '_bias2'], reuse, scope+block+'1') 144 | res = resnet_nob_conv_block(inp, weights[block + '_conv_res'], reuse, scope+block+'res') 145 | net = net + res 146 | if block_last_layer: 147 | net = tf.nn.max_pool(net, [1,2,2,1], [1,2,2,1], 'SAME') 148 | net = tf.nn.dropout(net, keep_prob=FLAGS.pretrain_dropout_keep) 149 | return net 150 | 151 | def block_forward(self, inp, weights, ss_weights, block, reuse, scope, block_last_layer=False): 152 | """The function to forward a resnet block during meta-train phase 153 | Args: 154 | inp: input feature maps. 155 | weights: input resnet weights. 156 | ss_weights: input scaling weights. 157 | block: the string to indicate which block we are processing. 158 | reuse: reuse the batch norm weights or not. 159 | scope: the label to indicate which layer we are processing. 160 | block_last_layer: whether it is the last layer of this block. 161 | Return: 162 | The processed feature maps. 
163 | """ 164 | net = resnet_conv_block(inp, self.process_ss_weights(weights, ss_weights, block + '_conv1'), \ 165 | ss_weights[block + '_bias1'], reuse, scope+block+'0') 166 | net = resnet_conv_block(net, self.process_ss_weights(weights, ss_weights, block + '_conv2'), \ 167 | ss_weights[block + '_bias2'], reuse, scope+block+'1') 168 | res = resnet_nob_conv_block(inp, weights[block + '_conv_res'], reuse, scope+block+'res') 169 | net = net + res 170 | if block_last_layer: 171 | net = tf.nn.max_pool(net, [1,2,2,1], [1,2,2,1], 'SAME') 172 | net = tf.nn.dropout(net, keep_prob=1) 173 | return net 174 | 175 | def pretrain_first_block_forward(self, inp, weights, block, reuse, scope): 176 | """The function to forward the first resnet block during pre-train phase 177 | Args: 178 | inp: input feature maps. 179 | weights: input resnet weights. 180 | block: the string to indicate which block we are processing. 181 | reuse: reuse the batch norm weights or not. 182 | scope: the label to indicate which layer we are processing. 183 | Return: 184 | The processed feature maps. 185 | """ 186 | net = resnet_conv_block(inp, weights[block + '_conv1'], weights[block + '_bias1'], reuse, scope+block+'0') 187 | net = tf.nn.max_pool(net, [1,3,3,1], [1,2,2,1], 'SAME') 188 | net = tf.nn.dropout(net, keep_prob=FLAGS.pretrain_dropout_keep) 189 | return net 190 | 191 | def first_block_forward(self, inp, weights, ss_weights, block, reuse, scope, block_last_layer=False): 192 | """The function to forward the first resnet block during meta-train phase 193 | Args: 194 | inp: input feature maps. 195 | weights: input resnet weights. 196 | block: the string to indicate which block we are processing. 197 | reuse: reuse the batch norm weights or not. 198 | scope: the label to indicate which layer we are processing. 199 | Return: 200 | The processed feature maps. 
201 | """ 202 | net = resnet_conv_block(inp, self.process_ss_weights(weights, ss_weights, block + '_conv1'), \ 203 | ss_weights[block + '_bias1'], reuse, scope+block+'0') 204 | net = tf.nn.max_pool(net, [1,3,3,1], [1,2,2,1], 'SAME') 205 | net = tf.nn.dropout(net, keep_prob=1) 206 | return net 207 | 208 | def construct_fc_weights(self): 209 | """The function to construct fc weights. 210 | Return: 211 | The fc weights. 212 | """ 213 | dtype = tf.float32 214 | fc_weights = {} 215 | fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype) 216 | if FLAGS.phase=='pre': 217 | fc_weights['w5'] = tf.get_variable('fc_w5', [512, FLAGS.pretrain_class_num], initializer=fc_initializer) 218 | fc_weights['b5'] = tf.Variable(tf.zeros([FLAGS.pretrain_class_num]), name='fc_b5') 219 | else: 220 | fc_weights['w5'] = tf.get_variable('fc_w5', [512, self.dim_output], initializer=fc_initializer) 221 | fc_weights['b5'] = tf.Variable(tf.zeros([self.dim_output]), name='fc_b5') 222 | return fc_weights 223 | 224 | def construct_resnet_weights(self): 225 | """The function to construct resnet weights. 226 | Return: 227 | The resnet weights. 
228 | """ 229 | weights = {} 230 | dtype = tf.float32 231 | conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype) 232 | fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype) 233 | weights = self.construct_first_block_weights(weights, 7, 3, 64, conv_initializer, dtype, 'block0_1') 234 | 235 | weights = self.construct_residual_block_weights(weights, 3, 64, 64, conv_initializer, dtype, 'block1_1') 236 | weights = self.construct_residual_block_weights(weights, 3, 64, 64, conv_initializer, dtype, 'block1_2') 237 | 238 | weights = self.construct_residual_block_weights(weights, 3, 64, 128, conv_initializer, dtype, 'block2_1') 239 | weights = self.construct_residual_block_weights(weights, 3, 128, 128, conv_initializer, dtype, 'block2_2') 240 | 241 | weights = self.construct_residual_block_weights(weights, 3, 128, 256, conv_initializer, dtype, 'block3_1') 242 | weights = self.construct_residual_block_weights(weights, 3, 256, 256, conv_initializer, dtype, 'block3_2') 243 | 244 | weights = self.construct_residual_block_weights(weights, 3, 256, 512, conv_initializer, dtype, 'block4_1') 245 | weights = self.construct_residual_block_weights(weights, 3, 512, 512, conv_initializer, dtype, 'block4_2') 246 | 247 | weights['w5'] = tf.get_variable('w5', [512, FLAGS.pretrain_class_num], initializer=fc_initializer) 248 | weights['b5'] = tf.Variable(tf.zeros([FLAGS.pretrain_class_num]), name='b5') 249 | return weights 250 | 251 | def construct_residual_block_weights(self, weights, k, last_dim_hidden, dim_hidden, conv_initializer, dtype, scope='block0'): 252 | """The function to construct one block of the resnet weights. 253 | Args: 254 | weights: the resnet weight list. 255 | k: the dimension number of the convolution kernel. 256 | last_dim_hidden: the hidden dimension number of last block. 257 | dim_hidden: the hidden dimension number of the block. 258 | conv_initializer: the convolution initializer. 259 | dtype: the dtype for numpy. 
260 | scope: the label to indicate which block we are processing. 261 | Return: 262 | The resnet block weights. 263 | """ 264 | weights[scope + '_conv1'] = tf.get_variable(scope + '_conv1', [k, k, last_dim_hidden, dim_hidden], \ 265 | initializer=conv_initializer, dtype=dtype) 266 | weights[scope + '_bias1'] = tf.Variable(tf.zeros([dim_hidden]), name=scope + '_bias1') 267 | weights[scope + '_conv2'] = tf.get_variable(scope + '_conv2', [k, k, dim_hidden, dim_hidden], \ 268 | initializer=conv_initializer, dtype=dtype) 269 | weights[scope + '_bias2'] = tf.Variable(tf.zeros([dim_hidden]), name=scope + '_bias2') 270 | weights[scope + '_conv_res'] = tf.get_variable(scope + '_conv_res', [1, 1, last_dim_hidden, dim_hidden], \ 271 | initializer=conv_initializer, dtype=dtype) 272 | return weights 273 | 274 | def construct_first_block_weights(self, weights, k, last_dim_hidden, dim_hidden, conv_initializer, dtype, scope='block0'): 275 | """The function to construct the first block of the resnet weights. 276 | Args: 277 | weights: the resnet weight list. 278 | k: the dimension number of the convolution kernel. 279 | last_dim_hidden: the hidden dimension number of last block. 280 | dim_hidden: the hidden dimension number of the block. 281 | conv_initializer: the convolution initializer. 282 | dtype: the dtype for numpy. 283 | scope: the label to indicate which block we are processing. 284 | Return: 285 | The resnet block weights. 286 | """ 287 | weights[scope + '_conv1'] = tf.get_variable(scope + '_conv1', [k, k, last_dim_hidden, dim_hidden], \ 288 | initializer=conv_initializer, dtype=dtype) 289 | weights[scope + '_bias1'] = tf.Variable(tf.zeros([dim_hidden]), name=scope + '_bias1') 290 | return weights 291 | 292 | def construct_first_block_ss_weights(self, ss_weights, last_dim_hidden, dim_hidden, scope='block0'): 293 | """The function to construct first block's ss weights. 294 | Return: 295 | The ss weights. 
296 | """ 297 | ss_weights[scope + '_conv1'] = tf.Variable(tf.ones([1, 1, last_dim_hidden, dim_hidden]), name=scope + '_conv1') 298 | ss_weights[scope + '_bias1'] = tf.Variable(tf.zeros([dim_hidden]), name=scope + '_bias1') 299 | return ss_weights 300 | 301 | def construct_resnet_ss_weights(self): 302 | """The function to construct ss weights. 303 | Return: 304 | The ss weights. 305 | """ 306 | ss_weights = {} 307 | ss_weights = self.construct_first_block_ss_weights(ss_weights, 3, 64, 'block0_1') 308 | 309 | ss_weights = self.construct_residual_block_ss_weights(ss_weights, 64, 64, 'block1_1') 310 | ss_weights = self.construct_residual_block_ss_weights(ss_weights, 64, 64, 'block1_2') 311 | 312 | ss_weights = self.construct_residual_block_ss_weights(ss_weights, 64, 128, 'block2_1') 313 | ss_weights = self.construct_residual_block_ss_weights(ss_weights, 128, 128, 'block2_2') 314 | 315 | ss_weights = self.construct_residual_block_ss_weights(ss_weights, 128, 256, 'block3_1') 316 | ss_weights = self.construct_residual_block_ss_weights(ss_weights, 256, 256, 'block3_2') 317 | 318 | ss_weights = self.construct_residual_block_ss_weights(ss_weights, 256, 512, 'block4_1') 319 | ss_weights = self.construct_residual_block_ss_weights(ss_weights, 512, 512, 'block4_2') 320 | 321 | return ss_weights 322 | 323 | def construct_residual_block_ss_weights(self, ss_weights, last_dim_hidden, dim_hidden, scope='block0'): 324 | """The function to construct one block ss weights. 325 | Args: 326 | ss_weights: the ss weight list. 327 | last_dim_hidden: the hidden dimension number of last block. 328 | dim_hidden: the hidden dimension number of the block. 329 | scope: the label to indicate which block we are processing. 330 | Return: 331 | The ss block weights. 
332 | """ 333 | ss_weights[scope + '_conv1'] = tf.Variable(tf.ones([1, 1, last_dim_hidden, dim_hidden]), name=scope + '_conv1') 334 | ss_weights[scope + '_bias1'] = tf.Variable(tf.zeros([dim_hidden]), name=scope + '_bias1') 335 | ss_weights[scope + '_conv2'] = tf.Variable(tf.ones([1, 1, dim_hidden, dim_hidden]), name=scope + '_conv2') 336 | ss_weights[scope + '_bias2'] = tf.Variable(tf.zeros([dim_hidden]), name=scope + '_bias2') 337 | return ss_weights 338 | --------------------------------------------------------------------------------