├── .gitignore
├── pytorch
├── models
│ ├── __init__.py
│ ├── conv2d_mtl.py
│ ├── mtl.py
│ └── resnet_mtl.py
├── trainer
│ ├── __init__.py
│ ├── pre.py
│ └── meta.py
├── utils
│ ├── __init__.py
│ ├── gpu_tools.py
│ └── misc.py
├── dataloader
│ ├── __init__.py
│ ├── samplers.py
│ └── dataset_loader.py
├── run_pre.py
├── run_meta.py
├── README.md
└── main.py
├── tensorflow
├── utils
│ ├── __init__.py
│ └── misc.py
├── models
│ ├── __init__.py
│ ├── pre_model.py
│ ├── resnet12.py
│ ├── meta_model.py
│ └── resnet18.py
├── trainer
│ ├── __init__.py
│ ├── pre.py
│ └── meta.py
├── data_generator
│ ├── __init__.py
│ ├── pre_data_generator.py
│ └── meta_data_generator.py
├── README.md
├── run_experiment.py
└── main.py
├── LICENSE
└── README.md
/.gitignore:
--------------------------------------------------------------------------------
1 | # File types
2 | *.pyc
3 | *.npy
4 |
5 | # File
6 | .DS_Store
7 |
--------------------------------------------------------------------------------
/pytorch/models/__init__.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Yaoyao Liu
3 | ## Tianjin University
4 | ## liuyaoyao@tju.edu.cn
5 | ## Copyright (c) 2019
6 | ##
7 | ## This source code is licensed under the MIT-style license found in the
8 | ## LICENSE file in the root directory of this source tree
9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10 |
--------------------------------------------------------------------------------
/pytorch/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Yaoyao Liu
3 | ## Tianjin University
4 | ## liuyaoyao@tju.edu.cn
5 | ## Copyright (c) 2019
6 | ##
7 | ## This source code is licensed under the MIT-style license found in the
8 | ## LICENSE file in the root directory of this source tree
9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10 |
--------------------------------------------------------------------------------
/pytorch/utils/__init__.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Yaoyao Liu
3 | ## Tianjin University
4 | ## liuyaoyao@tju.edu.cn
5 | ## Copyright (c) 2019
6 | ##
7 | ## This source code is licensed under the MIT-style license found in the
8 | ## LICENSE file in the root directory of this source tree
9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10 |
--------------------------------------------------------------------------------
/tensorflow/utils/__init__.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Yaoyao Liu
3 | ## Tianjin University
4 | ## liuyaoyao@tju.edu.cn
5 | ## Copyright (c) 2019
6 | ##
7 | ## This source code is licensed under the MIT-style license found in the
8 | ## LICENSE file in the root directory of this source tree
9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10 |
--------------------------------------------------------------------------------
/pytorch/dataloader/__init__.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Yaoyao Liu
3 | ## Tianjin University
4 | ## liuyaoyao@tju.edu.cn
5 | ## Copyright (c) 2019
6 | ##
7 | ## This source code is licensed under the MIT-style license found in the
8 | ## LICENSE file in the root directory of this source tree
9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10 |
--------------------------------------------------------------------------------
/tensorflow/models/__init__.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Yaoyao Liu
3 | ## Tianjin University
4 | ## liuyaoyao@tju.edu.cn
5 | ## Copyright (c) 2019
6 | ##
7 | ## This source code is licensed under the MIT-style license found in the
8 | ## LICENSE file in the root directory of this source tree
9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10 |
--------------------------------------------------------------------------------
/tensorflow/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Yaoyao Liu
3 | ## Tianjin University
4 | ## liuyaoyao@tju.edu.cn
5 | ## Copyright (c) 2019
6 | ##
7 | ## This source code is licensed under the MIT-style license found in the
8 | ## LICENSE file in the root directory of this source tree
9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10 |
--------------------------------------------------------------------------------
/tensorflow/data_generator/__init__.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Yaoyao Liu
3 | ## Tianjin University
4 | ## liuyaoyao@tju.edu.cn
5 | ## Copyright (c) 2019
6 | ##
7 | ## This source code is licensed under the MIT-style license found in the
8 | ## LICENSE file in the root directory of this source tree
9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10 |
--------------------------------------------------------------------------------
/pytorch/utils/gpu_tools.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Yaoyao Liu
3 | ## Tianjin University
4 | ## liuyaoyao@tju.edu.cn
5 | ## Copyright (c) 2019
6 | ##
7 | ## This source code is licensed under the MIT-style license found in the
8 | ## LICENSE file in the root directory of this source tree
9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10 | """ Tools for GPU. """
11 | import os
12 | import torch
13 |
def set_gpu(cuda_device):
    """Select the GPU(s) visible to CUDA for this process.

    Args:
      cuda_device: the GPU id(s), e.g. '0', '0,1' or the integer 1.
    """
    # os.environ only accepts strings, so an integer id is converted first.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(cuda_device)
    print('Using gpu:', cuda_device)
17 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Yaoyao Liu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/pytorch/run_pre.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Yaoyao Liu
3 | ## Tianjin University
4 | ## liuyaoyao@tju.edu.cn
5 | ## Copyright (c) 2019
6 | ##
7 | ## This source code is licensed under the MIT-style license found in the
8 | ## LICENSE file in the root directory of this source tree
9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10 | """ Generate commands for pre-train phase. """
11 | import os
12 |
def run_exp(lr=0.1, gamma=0.2, step_size=30):
    """Build the pre-train command line and run it through the shell.

    Args:
      lr: value passed as --pre_lr.
      gamma: value passed as --pre_gamma.
      step_size: value passed as --pre_step_size.
    """
    max_epoch = 110
    shot = 1
    query = 15
    way = 5
    gpu = 1
    base_lr = 0.01

    # (option name, value) pairs in the order they appear on the command line.
    options = [
        ('pre_max_epoch', max_epoch),
        ('shot', shot),
        ('train_query', query),
        ('way', way),
        ('pre_step_size', step_size),
        ('pre_gamma', gamma),
        ('gpu', gpu),
        ('base_lr', base_lr),
        ('pre_lr', lr),
        ('phase', 'pre_train'),
    ]
    the_command = 'python3 main.py' + ''.join(
        ' --' + name + '=' + str(value) for name, value in options)

    os.system(the_command)


# Only launch the experiment when this file is executed as a script, so that
# importing the module does not start a training run as a side effect.
if __name__ == '__main__':
    run_exp(lr=0.1, gamma=0.2, step_size=30)
36 |
--------------------------------------------------------------------------------
/pytorch/dataloader/samplers.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Yaoyao Liu
3 | ## Modified from: https://github.com/Sha-Lab/FEAT
4 | ## Tianjin University
5 | ## liuyaoyao@tju.edu.cn
6 | ## Copyright (c) 2019
7 | ##
8 | ## This source code is licensed under the MIT-style license found in the
9 | ## LICENSE file in the root directory of this source tree
10 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
11 | """ Sampler for dataloader. """
12 | import torch
13 | import numpy as np
14 |
class CategoriesSampler():
    """Batch sampler that yields the sample indices of one episode at a time."""

    def __init__(self, label, n_batch, n_cls, n_per):
        """Group the sample indices by class label.

        Args:
          label: sequence of integer class labels, one per sample.
          n_batch: number of episodes produced per iteration pass.
          n_cls: number of classes drawn for each episode.
          n_per: number of samples drawn for each selected class.
        """
        self.n_batch = n_batch
        self.n_cls = n_cls
        self.n_per = n_per

        label = np.array(label)
        # m_ind[c] holds the positions of every sample whose label equals c.
        self.m_ind = [
            torch.from_numpy(np.argwhere(label == cls_id).reshape(-1))
            for cls_id in range(max(label) + 1)
        ]

    def __len__(self):
        return self.n_batch

    def __iter__(self):
        for _ in range(self.n_batch):
            chosen_classes = torch.randperm(len(self.m_ind))[:self.n_cls]
            per_class = []
            for cls_id in chosen_classes:
                pool = self.m_ind[cls_id]
                picks = torch.randperm(len(pool))[:self.n_per]
                per_class.append(pool[picks])
            # Stack to (n_cls, n_per), transpose and flatten so the output
            # cycles through the classes: c0, c1, ..., c0, c1, ...
            yield torch.stack(per_class).t().reshape(-1)
41 |
--------------------------------------------------------------------------------
/pytorch/run_meta.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Yaoyao Liu
3 | ## Tianjin University
4 | ## liuyaoyao@tju.edu.cn
5 | ## Copyright (c) 2019
6 | ##
7 | ## This source code is licensed under the MIT-style license found in the
8 | ## LICENSE file in the root directory of this source tree
9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10 | """ Generate commands for meta-train phase. """
11 | import os
12 |
def run_exp(num_batch=1000, shot=1, query=15, lr1=0.0001, lr2=0.001, base_lr=0.01, update_step=10, gamma=0.5):
    """Build the meta-train command line, then run meta-train and meta-eval.

    Args:
      num_batch: value passed as --num_batch.
      shot: value passed as --shot.
      query: value passed as --train_query.
      lr1: value passed as --meta_lr1.
      lr2: value passed as --meta_lr2.
      base_lr: value passed as --base_lr.
      update_step: value passed as --update_step.
      gamma: value passed as --gamma.
    """
    max_epoch = 100
    way = 5
    step_size = 10
    gpu = 1

    # (option name, value) pairs in the order they appear on the command line.
    options = [
        ('max_epoch', max_epoch),
        ('num_batch', num_batch),
        ('shot', shot),
        ('train_query', query),
        ('way', way),
        ('meta_lr1', lr1),
        ('meta_lr2', lr2),
        ('step_size', step_size),
        ('gamma', gamma),
        ('gpu', gpu),
        ('base_lr', base_lr),
        ('update_step', update_step),
    ]
    the_command = 'python3 main.py' + ''.join(
        ' --' + name + '=' + str(value) for name, value in options)

    # The same settings are used for training and for the evaluation after it.
    os.system(the_command + ' --phase=meta_train')
    os.system(the_command + ' --phase=meta_eval')


# Only launch the experiments when this file is executed as a script, so that
# importing the module does not start training runs as a side effect.
if __name__ == '__main__':
    run_exp(num_batch=100, shot=1, query=15, lr1=0.0001, lr2=0.001, base_lr=0.01, update_step=100, gamma=0.5)
    run_exp(num_batch=100, shot=5, query=15, lr1=0.0001, lr2=0.001, base_lr=0.01, update_step=100, gamma=0.5)
38 |
--------------------------------------------------------------------------------
/pytorch/utils/misc.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Yaoyao Liu
3 | ## Modified from: https://github.com/Sha-Lab/FEAT
4 | ## Tianjin University
5 | ## liuyaoyao@tju.edu.cn
6 | ## Copyright (c) 2019
7 | ##
8 | ## This source code is licensed under the MIT-style license found in the
9 | ## LICENSE file in the root directory of this source tree
10 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
11 | """ Additional utility functions. """
12 | import os
13 | import time
14 | import pprint
15 | import torch
16 | import numpy as np
17 | import torch.nn.functional as F
18 |
def ensure_path(path):
    """The function to make log path.

    Creates the directory, including any missing parent directories, when
    nothing exists at `path` yet. An existing path is left untouched.

    Args:
      path: the generated saving path.
    """
    if not os.path.exists(path):
        # exist_ok avoids a failure when another process creates the
        # directory between the check above and this call.
        os.makedirs(path, exist_ok=True)
28 |
class Averager():
    """Running mean over the values passed to add()."""

    def __init__(self):
        # n: number of values seen so far; v: their current mean.
        self.n = 0
        self.v = 0

    def add(self, x):
        """Fold one more value into the running mean."""
        weighted_total = self.v * self.n + x
        self.n += 1
        self.v = weighted_total / self.n

    def item(self):
        """Return the current mean (0 before any value was added)."""
        return self.v
41 |
def count_acc(logits, label):
    """The function to calculate the classification accuracy.

    Args:
      logits: input logits, one row of class scores per sample.
      label: ground truth labels.
    Return:
      The output accuracy as a Python float.
    """
    # Softmax is monotonic, so the arg-max can be taken on the logits directly.
    pred = logits.argmax(dim=1)
    # .float() keeps the result on the tensors' own device, so CPU inputs are
    # not pushed to the GPU merely because CUDA is available.
    return (pred == label).float().mean().item()
54 |
class Timer():
    """Stopwatch that reports the time elapsed since it was created."""

    def __init__(self):
        # Start time in seconds, as returned by time.time().
        self.o = time.time()

    def measure(self, p=1):
        """Return the elapsed time divided by p as a short string.

        The value is truncated to whole seconds and then formatted as hours
        with one decimal ('1.5h'), rounded minutes ('12m') or seconds ('42s').
        """
        seconds = int((time.time() - self.o) / p)
        if seconds >= 3600:
            text = '{:.1f}h'.format(seconds / 3600)
        elif seconds >= 60:
            text = '{}m'.format(round(seconds / 60))
        else:
            text = '{}s'.format(seconds)
        return text
68 |
# Shared printer instance. It is created before the function below rebinds
# the module-level name `pprint` from the standard-library module to itself.
_utils_pp = pprint.PrettyPrinter()


def pprint(x):
    """Pretty-print x with the module's shared PrettyPrinter."""
    printer = _utils_pp
    printer.pprint(x)
73 |
def compute_confidence_interval(data):
    """The function to calculate the mean and its confidence interval.

    The half-width is 1.96 times the standard error, where the standard
    deviation is the one returned by np.std with its default arguments.

    Args:
      data: input records
    Return:
      m: mean value
      pm: confidence interval.
    """
    records = 1.0 * np.array(data)
    mean_value = np.mean(records)
    spread = np.std(records)
    half_width = 1.96 * (spread / np.sqrt(len(records)))
    return mean_value, half_width
88 |
--------------------------------------------------------------------------------
/tensorflow/data_generator/pre_data_generator.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Yaoyao Liu
3 | ## Modified from: https://github.com/cbfinn/maml
4 | ## Tianjin University
5 | ## liuyaoyao@tju.edu.cn
6 | ## Copyright (c) 2019
7 | ##
8 | ## This source code is licensed under the MIT-style license found in the
9 | ## LICENSE file in the root directory of this source tree
10 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
11 |
12 | """ Data generator for pre-train phase. """
13 | import numpy as np
14 | import os
15 | import random
16 | import tensorflow as tf
17 | from tqdm import trange
18 |
19 | from tensorflow.python.platform import flags
20 | from utils.misc import get_pretrain_images
21 |
22 | FLAGS = flags.FLAGS
23 |
class PreDataGenerator(object):
    """The class to generate the batched input tensors for pre-train phase."""

    def __init__(self):
        """Read the pre-train settings from FLAGS and list the class folders."""
        self.num_classes = FLAGS.way_num
        self.img_size = (FLAGS.img_size, FLAGS.img_size)
        # Length of one flattened 3-channel image.
        self.dim_input = np.prod(self.img_size)*3
        self.pretrain_class_num = FLAGS.pretrain_class_num
        self.pretrain_batch_size = FLAGS.pretrain_batch_size

        # Every sub-directory of the pre-train folder is kept as one class.
        root = FLAGS.pretrain_folders
        class_folders = []
        for entry in os.listdir(root):
            candidate = os.path.join(root, entry)
            if os.path.isdir(candidate):
                class_folders.append(candidate)
        self.pretrain_character_folders = class_folders

    def make_data_tensor(self):
        """The function to make tensor for the tensorflow model.

        Return:
          image_batch: batch of flattened images scaled by 1/255.
          label_batch: one-hot labels with pretrain_class_num classes.
        """
        print('Generating pre-training data')
        samples = []
        for class_idx, folder in enumerate(self.pretrain_character_folders):
            samples.extend(get_pretrain_images(folder, class_idx))
        random.shuffle(samples)
        # NOTE(review): assumes get_pretrain_images yields (label, filename)
        # pairs - confirm against utils.misc.
        labels = [sample[0] for sample in samples]
        filenames = [sample[1] for sample in samples]

        # Both queues are unshuffled; the pairing of image and label relies on
        # them being consumed in the same order.
        filename_tensor = tf.convert_to_tensor(filenames)
        filename_queue = tf.train.string_input_producer(filename_tensor, shuffle=False)
        label_tensor = tf.convert_to_tensor(labels)
        label_queue = tf.train.slice_input_producer([label_tensor], shuffle=False)
        reader = tf.WholeFileReader()
        _, raw_image = reader.read(filename_queue)

        height, width = self.img_size
        image = tf.image.decode_jpeg(raw_image, channels=3)
        image.set_shape((height, width, 3))
        image = tf.reshape(image, [self.dim_input])
        image = tf.cast(image, tf.float32) / 255.0

        batch_size = self.pretrain_batch_size
        image_batch, label_batch = tf.train.batch(
            [image, label_queue],
            batch_size=batch_size,
            num_threads=1,
            capacity=256 + 3 * batch_size)
        flat_labels = tf.reshape(label_batch, [-1])
        label_batch = tf.one_hot(flat_labels, self.pretrain_class_num)
        return image_batch, label_batch
64 |
--------------------------------------------------------------------------------
/tensorflow/models/pre_model.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Yaoyao Liu
3 | ## Tianjin University
4 | ## liuyaoyao@tju.edu.cn
5 | ## Copyright (c) 2019
6 | ##
7 | ## This source code is licensed under the MIT-style license found in the
8 | ## LICENSE file in the root directory of this source tree
9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10 |
11 | """ Models for pre-train phase. """
12 | import tensorflow as tf
13 | from tensorflow.python.platform import flags
14 |
15 | FLAGS = flags.FLAGS
16 |
def MakePreModel():
    """The function to make pre model.

    The backbone module is chosen by FLAGS.backbone_arch and its Models class
    is used as the base class of the returned model.

    Return:
      Pre-train model class instance.
    Raises:
      ValueError: if FLAGS.backbone_arch is neither 'resnet12' nor 'resnet18'.
    """
    if FLAGS.backbone_arch == 'resnet12':
        try:  # python2
            from resnet12 import Models
        except ImportError:  # python3
            from models.resnet12 import Models
    elif FLAGS.backbone_arch == 'resnet18':
        try:  # python2
            from resnet18 import Models
        except ImportError:  # python3
            from models.resnet18 import Models
    else:
        # Fail here with a clear message instead of hitting an unbound
        # `Models` name in the class statement below.
        raise ValueError('Please set the correct backbone')

    class PreModel(Models):
        """The class for pre-train model."""
        def construct_pretrain_model(self, input_tensors=None, is_val=False):
            """The function to construct pre-train model.
            Args:
              input_tensors: dict with the keys 'pretrain_input' and
                'pretrain_label'.
              is_val: whether the model is for validation.
            """
            self.input = input_tensors['pretrain_input']
            self.label = input_tensors['pretrain_label']
            with tf.variable_scope('pretrain-model', reuse=None) as training_scope:
                self.weights = weights = self.construct_resnet_weights()
                self.fc_weights = fc_weights = self.construct_fc_weights()

                if is_val is False:
                    self.pretrain_task_output = self.forward_fc(self.forward_pretrain_resnet(self.input, weights, reuse=False), fc_weights)
                    self.pretrain_task_loss = self.pretrain_loss_func(self.pretrain_task_output, self.label)
                    # NOTE(review): self.pretrain_lr is presumably defined by
                    # the backbone's Models class - confirm.
                    optimizer = tf.train.AdamOptimizer(self.pretrain_lr)
                    # dict.values() is a view on Python 3 and cannot be joined
                    # with '+', so both are turned into lists first.
                    trainable_vars = list(weights.values()) + list(fc_weights.values())
                    self.pretrain_op = optimizer.minimize(self.pretrain_task_loss, var_list=trainable_vars)
                    self.pretrain_task_accuracy = tf.reduce_mean(tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax( \
                        self.pretrain_task_output), 1), tf.argmax(self.label, 1)))
                    tf.summary.scalar('pretrain train loss', self.pretrain_task_loss)
                    tf.summary.scalar('pretrain train accuracy', self.pretrain_task_accuracy)
                else:
                    self.pretrain_task_output_val = self.forward_fc(self.forward_pretrain_resnet(self.input, weights, reuse=True), fc_weights)
                    self.pretrain_task_accuracy_val = tf.reduce_mean(tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax( \
                        self.pretrain_task_output_val), 1), tf.argmax(self.label, 1)))
                    tf.summary.scalar('pretrain val accuracy', self.pretrain_task_accuracy_val)

    return PreModel()
65 |
--------------------------------------------------------------------------------
/pytorch/dataloader/dataset_loader.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Yaoyao Liu
3 | ## Modified from: https://github.com/Sha-Lab/FEAT
4 | ## Tianjin University
5 | ## liuyaoyao@tju.edu.cn
6 | ## Copyright (c) 2019
7 | ##
8 | ## This source code is licensed under the MIT-style license found in the
9 | ## LICENSE file in the root directory of this source tree
10 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
11 | """ Dataloader for all datasets. """
12 | import os.path as osp
13 | import os
14 | from PIL import Image
15 | from torch.utils.data import Dataset
16 | from torchvision import transforms
17 | import numpy as np
18 |
class DatasetLoader(Dataset):
    """Image dataset read from a '<dataset_dir>/<setname>/<class folder>' tree."""

    def __init__(self, setname, args, train_aug=False):
        """Collect the image paths and labels and build the transform.

        Args:
          setname: one of 'train', 'val' or 'test'.
          args: object whose dataset_dir attribute is the dataset root.
          train_aug: whether to add the random crop and flip steps.
        Raises:
          ValueError: if setname is not one of the three split names.
        """
        # The split name doubles as the sub-directory name.
        if setname not in ('train', 'test', 'val'):
            raise ValueError('Wrong setname.')
        THE_PATH = osp.join(args.dataset_dir, setname)

        # One class per sub-directory; plain files in the split folder are skipped.
        folders = []
        for entry in os.listdir(THE_PATH):
            candidate = osp.join(THE_PATH, entry)
            if os.path.isdir(candidate):
                folders.append(candidate)

        # The label of an image is the position of its folder in the listing.
        data = []
        label = []
        for idx, class_folder in enumerate(folders):
            for file_name in os.listdir(class_folder):
                data.append(osp.join(class_folder, file_name))
                label.append(idx)

        # Set data, label and class number to be accessible from outside
        self.data = data
        self.label = label
        self.num_class = len(set(label))

        # Both pipelines share the resize, centre crop, tensor conversion and
        # normalisation steps; augmentation adds a random crop and a flip.
        image_size = 80
        channel_mean = np.array([x / 255.0 for x in [125.3, 123.0, 113.9]])
        channel_std = np.array([x / 255.0 for x in [63.0, 62.1, 66.7]])
        steps = [transforms.Resize(92)]
        if train_aug:
            steps.append(transforms.RandomResizedCrop(88))
        steps.append(transforms.CenterCrop(image_size))
        if train_aug:
            steps.append(transforms.RandomHorizontalFlip())
        steps.append(transforms.ToTensor())
        steps.append(transforms.Normalize(channel_mean, channel_std))
        self.transform = transforms.Compose(steps)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        """Return the transformed RGB image and its label for index i."""
        image_file = self.data[i]
        target = self.label[i]
        picture = Image.open(image_file).convert('RGB')
        return self.transform(picture), target
82 |
--------------------------------------------------------------------------------
/pytorch/README.md:
--------------------------------------------------------------------------------
1 | # Meta-Transfer Learning PyTorch
2 | [Python](https://www.python.org/)
3 | [PyTorch](https://pytorch.org/)
4 |
5 | #### Summary
6 |
7 | * [Installation](#installation)
8 | * [Project Architecture](#project-architecture)
9 | * [Running Experiments](#running-experiments)
10 |
11 |
12 | ## Installation
13 |
14 | In order to run this repository, we advise you to install python 3.5 and PyTorch 0.4.0 with Anaconda.
15 |
16 | You may download Anaconda and read the installation instruction on their official website:
17 |
18 | <https://www.anaconda.com/download/>
19 | Create a new environment and install PyTorch and torchvision on it:
20 |
21 | ```bash
22 | conda create --name mtl-pytorch python=3.5
23 | conda activate mtl-pytorch
24 | conda install pytorch=0.4.0
25 | conda install torchvision -c pytorch
26 | ```
27 |
28 | Install other requirements:
29 | ```bash
30 | pip install tqdm tensorboardX miniimagenettools
31 | ```
32 |
33 | Clone this repository:
34 |
35 | ```bash
36 | git clone https://github.com/yaoyao-liu/meta-transfer-learning.git
37 | cd meta-transfer-learning/pytorch
38 | ```
39 |
40 | ## Project Architecture
41 |
42 | ```
43 | .
44 | ├── dataloader
45 | | ├── dataset_loader.py # data loader for all datasets
46 | | └── samplers.py # samplers for meta train
47 | ├── models
48 | | ├── mtl.py # meta-transfer class
49 | | ├── resnet_mtl.py # resnet class
50 | | └── conv2d_mtl.py # meta-transfer convolution class
51 | ├── trainer
52 | | ├── pre.py # pre-train trainer class
53 | | └── meta.py # meta-train trainer class
54 | ├── utils
55 | | ├── gpu_tools.py # GPU tool functions
56 | | └── misc.py # miscellaneous tool functions
57 | ├── main.py # the python file with main function and parameter settings
58 | ├── run_pre.py # the script to run pre-train phase
59 | └── run_meta.py # the script to run meta-train and meta-test phases
60 | ```
61 |
62 | ## Running Experiments
63 |
64 | Run pretrain phase:
65 | ```bash
66 | python run_pre.py
67 | ```
68 | Run meta-train and meta-test phase:
69 | ```bash
70 | python run_meta.py
71 | ```
72 |
73 | ### Hyperparameters and Options
74 | Hyperparameters and options in `main.py`.
75 |
76 | - `model_type` The network architecture
77 | - `dataset` Meta dataset
78 | - `phase` pre-train, meta-train or meta-eval
79 | - `seed` Manual seed for PyTorch, "0" means using random seed
80 | - `gpu` GPU id
81 | - `dataset_dir` Directory for the images
82 | - `max_epoch` Epoch number for meta-train phase
83 | - `num_batch` The number of different tasks used for meta-train
84 | - `shot` Shot number, how many samples for one class in a task
85 | - `way` Way number, how many classes in a task
86 | - `train_query` The number of training samples for each class in a task
87 | - `val_query` The number of test samples for each class in a task
88 | - `meta_lr1` Learning rate for SS weights
89 | - `meta_lr2` Learning rate for FC weights
90 | - `base_lr` Learning rate for the inner loop
91 | - `update_step` The number of updates for the inner loop
92 | - `step_size` The number of epochs to reduce the meta learning rates
93 | - `gamma` Gamma for the meta-train learning rate decay
94 | - `init_weights` The pretrained weights for meta-train phase
95 | - `eval_weights` The meta-trained weights for meta-eval phase
96 | - `meta_label` Additional label for meta-train
97 | - `pre_max_epoch` Epoch number for pre-train phase
98 | - `pre_batch_size` Batch size for pre-train phase
99 | - `pre_lr` Learning rate for pre-train phase
100 | - `pre_gamma` Gamma for the pre-train learning rate decay
101 | - `pre_step_size` The number of epochs to reduce the pre-train learning rate
102 | - `pre_custom_weight_decay` Weight decay for the optimizer during pre-train
103 |
104 |
--------------------------------------------------------------------------------
/tensorflow/trainer/pre.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Yaoyao Liu
3 | ## Modified from: https://github.com/cbfinn/maml
4 | ## Tianjin University
5 | ## liuyaoyao@tju.edu.cn
6 | ## Copyright (c) 2019
7 | ##
8 | ## This source code is licensed under the MIT-style license found in the
9 | ## LICENSE file in the root directory of this source tree
10 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
11 |
12 | """ Trainer for pre-train phase. """
13 | import os
14 | import numpy as np
15 | import tensorflow as tf
16 |
17 | from tqdm import trange
18 | from data_generator.pre_data_generator import PreDataGenerator
19 | from models.pre_model import MakePreModel
20 | from tensorflow.python.platform import flags
21 |
22 | FLAGS = flags.FLAGS
23 |
class PreTrainer:
    """The class that contains the code for the pre-train phase"""
    def __init__(self):
        """Build the data tensors, model and session, then run pre-training.

        Note that constructing this class starts the whole pre-train loop.
        """
        # This class defines the pre-train phase trainer
        print('Generating pre-training classes')

        # Generate Pre-train Data Tensors
        pre_train_data_generator = PreDataGenerator()
        pretrain_input, pretrain_label = pre_train_data_generator.make_data_tensor()
        pre_train_input_tensors = {'pretrain_input': pretrain_input, 'pretrain_label': pretrain_label}

        # Build Pre-train Model
        self.model = MakePreModel()
        self.model.construct_pretrain_model(input_tensors=pre_train_input_tensors)
        self.model.pretrain_summ_op = tf.summary.merge_all()

        # Start the TensorFlow Session; when the flag is set, the per-process
        # GPU memory fraction is limited to FLAGS.gpu_rate.
        if FLAGS.full_gpu_memory_mode:
            gpu_config = tf.ConfigProto()
            gpu_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_rate
            self.sess = tf.InteractiveSession(config=gpu_config)
        else:
            self.sess = tf.InteractiveSession()

        # Initialize and Start the Pre-train Phase
        tf.global_variables_initializer().run()
        tf.train.start_queue_runners()
        self.pre_train()

    def pre_train(self):
        """Run the pre-train iterations, logging summaries and saving weights."""
        # Load Parameters from FLAGS
        pretrain_iterations = FLAGS.pretrain_iterations
        weights_save_dir_base = FLAGS.pretrain_dir
        pre_save_str = FLAGS.pre_string

        # Build Pre-train Log Folder; makedirs also creates the base directory
        # when it does not exist yet, where mkdir would fail.
        weights_save_dir = os.path.join(weights_save_dir_base, pre_save_str)
        if not os.path.exists(weights_save_dir):
            os.makedirs(weights_save_dir)
        pretrain_writer = tf.summary.FileWriter(weights_save_dir, self.sess.graph)
        pre_lr = FLAGS.pre_lr

        print('Start pre-train phase')
        print('Pre-train Hyper parameters: ' + pre_save_str)

        # Start the iterations
        for itr in trange(pretrain_iterations):
            # Generate the Feed Dict and Run the Optimizer
            feed_dict = {self.model.pretrain_lr: pre_lr}
            input_tensors = [self.model.pretrain_op, self.model.pretrain_summ_op]
            input_tensors.extend([self.model.pretrain_task_loss, self.model.pretrain_task_accuracy])
            # result holds: [train op output, summary, loss, accuracy]
            result = self.sess.run(input_tensors, feed_dict)

            # Print Results during Training
            if (itr!=0) and itr % FLAGS.pre_print_step == 0:
                print_str = '[*] Pre Loss: ' + str(result[-2]) + ', Pre Acc: ' + str(result[-1])
                print(print_str)

            # Write the TensorFlow Summary
            if itr % FLAGS.pre_sum_step == 0:
                pretrain_writer.add_summary(result[1], itr)

            # Halve the Learning Rate after Some Iterations, optionally
            # clamping it at FLAGS.min_pre_lr
            if (itr!=0) and itr % FLAGS.pre_lr_dropstep == 0:
                pre_lr = pre_lr * 0.5
                if FLAGS.pre_lr_stop and pre_lr < FLAGS.min_pre_lr:
                    pre_lr = FLAGS.min_pre_lr

            # Save Pre-train Model
            if (itr!=0) and itr % FLAGS.pre_save_step == 0:
                print('Saving pretrain weights to npy')
                weights = self.sess.run(self.model.weights)
                np.save(os.path.join(weights_save_dir, "weights_{}.npy".format(itr)), weights)
97 |
--------------------------------------------------------------------------------
/pytorch/models/conv2d_mtl.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Yaoyao Liu
3 | ## Modified from: https://github.com/pytorch/pytorch
4 | ## Tianjin University
5 | ## liuyaoyao@tju.edu.cn
6 | ## Copyright (c) 2019
7 | ##
8 | ## This source code is licensed under the MIT-style license found in the
9 | ## LICENSE file in the root directory of this source tree
10 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
11 | """ MTL CONV layers. """
12 | import math
13 | import torch
14 | import torch.nn.functional as F
15 | from torch.nn.parameter import Parameter
16 | from torch.nn.modules.module import Module
17 | from torch.nn.modules.utils import _pair
18 |
19 | class _ConvNdMtl(Module):
20 | """The class for meta-transfer convolution"""
21 | def __init__(self, in_channels, out_channels, kernel_size, stride,
22 | padding, dilation, transposed, output_padding, groups, bias):
23 | super(_ConvNdMtl, self).__init__()
24 | if in_channels % groups != 0:
25 | raise ValueError('in_channels must be divisible by groups')
26 | if out_channels % groups != 0:
27 | raise ValueError('out_channels must be divisible by groups')
28 | self.in_channels = in_channels
29 | self.out_channels = out_channels
30 | self.kernel_size = kernel_size
31 | self.stride = stride
32 | self.padding = padding
33 | self.dilation = dilation
34 | self.transposed = transposed
35 | self.output_padding = output_padding
36 | self.groups = groups
37 | if transposed:
38 | self.weight = Parameter(torch.Tensor(
39 | in_channels, out_channels // groups, *kernel_size))
40 | self.mtl_weight = Parameter(torch.ones(in_channels, out_channels // groups, 1, 1))
41 | else:
42 | self.weight = Parameter(torch.Tensor(
43 | out_channels, in_channels // groups, *kernel_size))
44 | self.mtl_weight = Parameter(torch.ones(out_channels, in_channels // groups, 1, 1))
45 | self.weight.requires_grad=False
46 | if bias:
47 | self.bias = Parameter(torch.Tensor(out_channels))
48 | self.bias.requires_grad=False
49 | self.mtl_bias = Parameter(torch.zeros(out_channels))
50 | else:
51 | self.register_parameter('bias', None)
52 | self.register_parameter('mtl_bias', None)
53 | self.reset_parameters()
54 |
55 | def reset_parameters(self):
56 | n = self.in_channels
57 | for k in self.kernel_size:
58 | n *= k
59 | stdv = 1. / math.sqrt(n)
60 | self.weight.data.uniform_(-stdv, stdv)
61 | self.mtl_weight.data.uniform_(1, 1)
62 | if self.bias is not None:
63 | self.bias.data.uniform_(-stdv, stdv)
64 | self.mtl_bias.data.uniform_(0, 0)
65 |
66 | def extra_repr(self):
67 | s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
68 | ', stride={stride}')
69 | if self.padding != (0,) * len(self.padding):
70 | s += ', padding={padding}'
71 | if self.dilation != (1,) * len(self.dilation):
72 | s += ', dilation={dilation}'
73 | if self.output_padding != (0,) * len(self.output_padding):
74 | s += ', output_padding={output_padding}'
75 | if self.groups != 1:
76 | s += ', groups={groups}'
77 | if self.bias is None:
78 | s += ', bias=False'
79 | return s.format(**self.__dict__)
80 |
class Conv2dMtl(_ConvNdMtl):
    """2-D meta-transfer convolution.

    The kernel used in the forward pass is ``weight * mtl_weight`` and the
    bias is ``bias + mtl_bias`` (when a bias exists).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        # Normalise scalar arguments to 2-tuples; a 2-D layer is never
        # transposed and has zero output padding.
        super(Conv2dMtl, self).__init__(
            in_channels, out_channels, _pair(kernel_size), _pair(stride),
            _pair(padding), _pair(dilation), False, _pair(0), groups, bias)

    def forward(self, inp):
        """Convolve ``inp`` with the scaled kernel and the shifted bias."""
        scale = self.mtl_weight.expand(self.weight.shape)
        scaled_weight = self.weight.mul(scale)
        shifted_bias = None if self.bias is None else self.bias + self.mtl_bias
        return F.conv2d(inp, scaled_weight, shifted_bias, self.stride,
                        self.padding, self.dilation, self.groups)
102 |
--------------------------------------------------------------------------------
/pytorch/main.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Yaoyao Liu
3 | ## Tianjin University
4 | ## liuyaoyao@tju.edu.cn
5 | ## Copyright (c) 2019
6 | ##
7 | ## This source code is licensed under the MIT-style license found in the
8 | ## LICENSE file in the root directory of this source tree
9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10 | """ Main function for this repo. """
11 | import argparse
12 | import torch
13 | from utils.misc import pprint
14 | from utils.gpu_tools import set_gpu
15 | from trainer.meta import MetaTrainer
16 | from trainer.pre import PreTrainer
17 |
if __name__ == '__main__':
    # Command-line entry point: parse the options, configure GPU / seeding,
    # then dispatch to the trainer selected by --phase.
    parser = argparse.ArgumentParser()
    # Basic parameters
    parser.add_argument('--model_type', type=str, default='ResNet', choices=['ResNet']) # The network architecture
    # The default is spelled 'MiniImageNet', which was not among the accepted
    # choices, so passing the default value explicitly was rejected. The
    # legacy spelling is accepted too; the default itself is left unchanged.
    parser.add_argument('--dataset', type=str, default='MiniImageNet', choices=['miniImageNet', 'MiniImageNet', 'tieredImageNet', 'FC100']) # Dataset
    parser.add_argument('--phase', type=str, default='meta_train', choices=['pre_train', 'meta_train', 'meta_eval']) # Phase
    parser.add_argument('--seed', type=int, default=0) # Manual seed for PyTorch, "0" means using random seed
    parser.add_argument('--gpu', default='1') # GPU id
    parser.add_argument('--dataset_dir', type=str, default='./data/mini/') # Dataset folder

    # Parameters for meta-train phase
    parser.add_argument('--max_epoch', type=int, default=100) # Epoch number for meta-train phase
    parser.add_argument('--num_batch', type=int, default=100) # The number for different tasks used for meta-train
    parser.add_argument('--shot', type=int, default=1) # Shot number, how many samples for one class in a task
    parser.add_argument('--way', type=int, default=5) # Way number, how many classes in a task
    parser.add_argument('--train_query', type=int, default=15) # The number of training samples for each class in a task
    parser.add_argument('--val_query', type=int, default=15) # The number of test samples for each class in a task
    parser.add_argument('--meta_lr1', type=float, default=0.0001) # Learning rate for SS weights
    parser.add_argument('--meta_lr2', type=float, default=0.001) # Learning rate for FC weights
    parser.add_argument('--base_lr', type=float, default=0.01) # Learning rate for the inner loop
    parser.add_argument('--update_step', type=int, default=50) # The number of updates for the inner loop
    parser.add_argument('--step_size', type=int, default=10) # The number of epochs to reduce the meta learning rates
    parser.add_argument('--gamma', type=float, default=0.5) # Gamma for the meta-train learning rate decay
    parser.add_argument('--init_weights', type=str, default=None) # The pre-trained weights for meta-train phase
    parser.add_argument('--eval_weights', type=str, default=None) # The meta-trained weights for meta-eval phase
    parser.add_argument('--meta_label', type=str, default='exp1') # Additional label for meta-train

    # Parameters for pre-train phase
    parser.add_argument('--pre_max_epoch', type=int, default=100) # Epoch number for pre-train phase
    parser.add_argument('--pre_batch_size', type=int, default=128) # Batch size for pre-train phase
    parser.add_argument('--pre_lr', type=float, default=0.1) # Learning rate for pre-train phase
    parser.add_argument('--pre_gamma', type=float, default=0.2) # Gamma for the pre-train learning rate decay
    parser.add_argument('--pre_step_size', type=int, default=30) # The number of epochs to reduce the pre-train learning rate
    parser.add_argument('--pre_custom_momentum', type=float, default=0.9) # Momentum for the optimizer during pre-train
    parser.add_argument('--pre_custom_weight_decay', type=float, default=0.0005) # Weight decay for the optimizer during pre-train

    # Set and print the parameters
    args = parser.parse_args()
    pprint(vars(args))

    # Set the GPU id
    set_gpu(args.gpu)

    # Set manual seed for PyTorch
    if args.seed == 0:
        # No fixed seed: let cuDNN pick its algorithms by benchmarking.
        print ('Using random seed.')
        torch.backends.cudnn.benchmark = True
    else:
        # Fixed seed: seed CPU and CUDA generators and request
        # deterministic cuDNN behaviour.
        print ('Using manual seed:', args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Start trainer for pre-train, meta-train or meta-eval
    if args.phase == 'meta_train':
        trainer = MetaTrainer(args)
        trainer.train()
    elif args.phase == 'meta_eval':
        trainer = MetaTrainer(args)
        trainer.eval()
    elif args.phase == 'pre_train':
        trainer = PreTrainer(args)
        trainer.train()
    else:
        raise ValueError('Please set correct phase.')
84 |
--------------------------------------------------------------------------------
/tensorflow/README.md:
--------------------------------------------------------------------------------
1 | # Meta-Transfer Learning TensorFlow
2 | [](https://www.python.org/)
3 | [](https://github.com/y2l/meta-transfer-learning/tree/master/tensorflow)
4 |
5 | #### Summary
6 |
7 | * [Installation](#installation)
8 | * [Project Architecture](#project-architecture)
9 | * [Running Experiments](#running-experiments)
10 |
11 |
12 | ## Installation
13 |
14 | In order to run this repository, we advise you to install python 2.7 or 3.5 and TensorFlow 1.3.0 with Anaconda.
15 |
16 | You may download Anaconda and read the installation instructions on their official website:
17 | <https://www.anaconda.com/download/>
18 |
19 | Create a new environment and install tensorflow on it:
20 |
21 | ```bash
22 | conda create --name mtl-tf python=2.7
23 | conda activate mtl-tf
24 | conda install tensorflow-gpu=1.3.0
25 | ```
26 |
27 | Install other requirements:
28 | ```bash
29 | pip install scipy tqdm opencv-python pillow matplotlib miniimagenettools
30 | ```
31 |
32 | Clone this repository:
33 |
34 | ```bash
35 | git clone https://github.com/yaoyao-liu/meta-transfer-learning.git
36 | cd meta-transfer-learning/tensorflow
37 | ```
38 |
39 | ## Project Architecture
40 |
41 | ```
42 | .
43 | ├── data_generator # dataset generator
44 | | ├── pre_data_generator.py # data generator for pre-train phase
45 | | └── meta_data_generator.py # data generator for meta-train phase
46 | ├── models # tensorflow model files
47 | | ├── resnet12.py # resnet12 class
48 | | ├── resnet18.py # resnet18 class
49 | | ├── pre_model.py # pre-train model class
50 | | └── meta_model.py # meta-train model class
51 | ├── trainer # tensorflow trainer files
52 | | ├── pre.py # pre-train trainer class
53 | | └── meta.py # meta-train trainer class
54 | ├── utils # a series of tools used in this repo
55 | | └── misc.py # miscellaneous tool functions
56 | ├── main.py # the python file with main function and parameter settings
57 | └── run_experiment.py # the script to run the whole experiment
58 | ```
59 |
60 | ## Running Experiments
61 |
62 | ### Training from Scratch
63 | Run pre-train phase:
64 | ```bash
65 | python run_experiment.py PRE
66 | ```
67 | Run meta-train and meta-test phase:
68 | ```bash
69 | python run_experiment.py META
70 | ```
71 |
72 | ### Hyperparameters and Options
73 | You may edit the `run_experiment.py` file to change the hyperparameters and options.
74 |
75 | - `LOG_DIR` Name of the folder to save the log files
76 | - `GPU_ID` GPU device id
77 | - `NET_ARCH` Network backbone (resnet12 or resnet18)
78 | - `PRE_TRA_LABEL` Additional label for pre-train model
79 | - `PRE_TRA_ITER_MAX` Iteration number for the pre-train phase
80 | - `PRE_TRA_DROP` Dropout keep rate for the pre-train phase
81 | - `PRE_DROP_STEP` Iteration number for the pre-train learning rate reducing
82 | - `PRE_LR` Pre-train learning rate
83 | - `SHOT_NUM` Sample number for each class
84 | - `WAY_NUM` Class number for the few-shot tasks
85 | - `MAX_ITER` Iteration number for meta-train phase
86 | - `META_BATCH_SIZE` Meta batch size
87 | - `PRE_ITER` Iteration number for the pre-train model used in the meta-train phase
88 | - `UPDATE_NUM` Epoch number for the base learning
89 | - `SAVE_STEP` Iteration number to save the meta model
90 | - `META_LR` Meta learning rate
91 | - `META_LR_MIN` Meta learning rate min value
92 | - `LR_DROP_STEP` Iteration number for the meta learning rate reducing
93 | - `BASE_LR` Base learning rate
94 | - `PRE_TRA_DIR` Directory for the pre-train phase images
95 | - `META_TRA_DIR` Directory for the meta-train images
96 | - `META_VAL_DIR` Directory for the meta-validation images
97 | - `META_TES_DIR` Directory for the meta-test images
98 |
99 | The file `run_experiment.py` is just a script to generate commands for `main.py`. If you want to change other settings, please see the comments and descriptions in `main.py`.
100 |
101 | ### Using Downloaded Models
102 | In the default setting, if you run `python run_experiment.py`, the pre-train phase is conducted before the meta-train phase starts. If you want to use the model pre-trained by us, you may download it from the link below. To run experiments with the downloaded model, please make sure you are using Python 2.7.
103 |
104 | Comparison of the original paper and the open-source code in terms of test set accuracy:
105 |
106 | | (%) | 𝑚𝑖𝑛𝑖 1-shot | 𝑚𝑖𝑛𝑖 5-shot | FC100 1-shot | FC100 5-shot |
107 | | ---------------------- | ------------ | ------------ | ------------ | ------------ |
108 | | `MTL Paper` | `60.2 ± 1.8` | `74.3 ± 0.9` | `43.6 ± 1.8` | `55.4 ± 0.9` |
109 | | `This Repo` | `60.8 ± 1.8` | `74.3 ± 0.9` | `44.3 ± 1.8` | `56.8 ± 1.0` |
110 |
111 | Download models: [\[Google Drive\]](https://drive.google.com/drive/folders/1MzH2enwLKuzmODYAEATnyiP_602zrdrE?usp=sharing)
112 |
113 | Move the downloaded npy files to `./logs/download_weights` (e.g. 𝑚𝑖𝑛𝑖ImageNet, 1-shot):
114 | ```bash
115 | mkdir -p ./logs/download_weights
116 | mv ~/downloads/mini-1shot/*.npy ./logs/download_weights
117 | ```
118 |
119 | Run meta-train with downloaded model:
120 | ```bash
121 | python run_experiment.py META_LOAD
122 | ```
123 |
124 | Run meta-test with downloaded model:
125 | ```bash
126 | python run_experiment.py TEST_LOAD
127 | ```
128 |
129 |
--------------------------------------------------------------------------------
/pytorch/models/mtl.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Yaoyao Liu
3 | ## Tianjin University
4 | ## liuyaoyao@tju.edu.cn
5 | ## Copyright (c) 2019
6 | ##
7 | ## This source code is licensed under the MIT-style license found in the
8 | ## LICENSE file in the root directory of this source tree
9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10 | """ Model for meta-transfer learning. """
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | from models.resnet_mtl import ResNetMtl
15 |
class BaseLearner(nn.Module):
    """Linear classifier used in the inner loop.

    Holds a weight of shape (args.way, z_dim) and a bias of shape
    (args.way,). Both are registered as attributes (``fc1_w`` / ``fc1_b``)
    and, in that order, in the ``vars`` ParameterList.

    Args:
        args: options object; only ``args.way`` is read here.
        z_dim: dimensionality of the input features.
    """

    def __init__(self, args, z_dim):
        super().__init__()
        self.args = args
        self.z_dim = z_dim
        self.vars = nn.ParameterList()
        self.fc1_w = nn.Parameter(torch.ones([self.args.way, self.z_dim]))
        torch.nn.init.kaiming_normal_(self.fc1_w)
        self.vars.append(self.fc1_w)
        self.fc1_b = nn.Parameter(torch.zeros(self.args.way))
        self.vars.append(self.fc1_b)

    def forward(self, input_x, the_vars=None):
        """Apply the linear layer.

        Args:
            input_x: input features.
            the_vars: optional [weight, bias] sequence to use instead of the
                module's own parameters (e.g. adapted fast weights).
        Returns:
            the output of ``F.linear(input_x, weight, bias)``.
        """
        if the_vars is None:
            the_vars = self.vars
        fc1_w = the_vars[0]
        fc1_b = the_vars[1]
        return F.linear(input_x, fc1_w, fc1_b)

    def parameters(self, recurse=True):
        """Return the ParameterList [fc1_w, fc1_b].

        ``recurse`` is accepted only for signature compatibility with
        ``nn.Module.parameters`` and is ignored: the same list is returned
        either way.
        """
        return self.vars
39 |
class MtlLearner(nn.Module):
    """The class for outer loop.

    Combines a ResNetMtl encoder with a BaseLearner classifier. ``mode``
    selects what ``forward`` does:
      * 'pre'    - classification through ``pre_fc`` (pre-train phase);
      * 'meta'   - inner-loop adaptation using args.base_lr and
                   args.update_step;
      * 'preval' - inner-loop adaptation with a fixed learning rate of 0.01
                   and 100 steps (validation during pre-train).

    Args:
        args: options object; ``base_lr`` and ``update_step`` are read here,
            and it is passed on to BaseLearner.
        mode: initial mode, see above.
        num_cls: number of output classes of ``pre_fc`` (only built when
            mode is not 'meta').
    """

    # Fixed inner-loop settings used by preval_forward.
    _PREVAL_LR = 0.01
    _PREVAL_STEPS = 100

    def __init__(self, args, mode='meta', num_cls=64):
        super().__init__()
        self.args = args
        self.mode = mode
        self.update_lr = args.base_lr
        self.update_step = args.update_step
        z_dim = 640
        self.base_learner = BaseLearner(args, z_dim)

        if self.mode == 'meta':
            self.encoder = ResNetMtl()
        else:
            self.encoder = ResNetMtl(mtl=False)
            self.pre_fc = nn.Sequential(nn.Linear(640, 1000), nn.ReLU(), nn.Linear(1000, num_cls))

    def forward(self, inp):
        """The function to forward the model.
        Args:
          inp: input images ('pre' mode) or a (data_shot, label_shot,
               data_query) triple ('meta' and 'preval' modes).
        Returns:
          the outputs of MTL model.
        Raises:
          ValueError: if self.mode is not one of 'pre', 'meta', 'preval'.
        """
        if self.mode == 'pre':
            return self.pretrain_forward(inp)
        if self.mode == 'meta':
            data_shot, label_shot, data_query = inp
            return self.meta_forward(data_shot, label_shot, data_query)
        if self.mode == 'preval':
            data_shot, label_shot, data_query = inp
            return self.preval_forward(data_shot, label_shot, data_query)
        raise ValueError('Please set the correct mode.')

    def pretrain_forward(self, inp):
        """The function to forward pretrain phase.
        Args:
          inp: input images.
        Returns:
          the outputs of pretrain model.
        """
        return self.pre_fc(self.encoder(inp))

    def _adapt_and_predict(self, data_shot, label_shot, data_query, lr, num_steps):
        """Adapt the base learner on the shot samples, then classify the queries.

        Runs ``max(1, num_steps)`` gradient-descent updates of the base-learner
        weights on the cross-entropy loss of the shot samples (at least one
        update is always performed), and returns the query logits computed
        with the final weights. Gradients are taken with
        ``torch.autograd.grad`` without ``create_graph``.
        """
        # Query is encoded before shot, as in the original forward passes.
        embedding_query = self.encoder(data_query)
        embedding_shot = self.encoder(data_shot)

        # Step 0 starts from the base learner's own parameters; every later
        # step starts from the previous fast weights.
        fast_weights = self.base_learner.parameters()
        for _ in range(max(1, num_steps)):
            logits = self.base_learner(embedding_shot, fast_weights)
            loss = F.cross_entropy(logits, label_shot)
            grad = torch.autograd.grad(loss, fast_weights)
            fast_weights = [w - lr * g for g, w in zip(grad, fast_weights)]

        # Only the final weights matter for the prediction, so the query is
        # classified once instead of after every update.
        return self.base_learner(embedding_query, fast_weights)

    def meta_forward(self, data_shot, label_shot, data_query):
        """The function to forward meta-train phase.
        Args:
          data_shot: train images for the task
          label_shot: train labels for the task
          data_query: test images for the task.
        Returns:
          logits_q: the predictions for the test samples.
        """
        return self._adapt_and_predict(data_shot, label_shot, data_query,
                                       self.update_lr, self.update_step)

    def preval_forward(self, data_shot, label_shot, data_query):
        """The function to forward meta-validation during pretrain phase.
        Args:
          data_shot: train images for the task
          label_shot: train labels for the task
          data_query: test images for the task.
        Returns:
          logits_q: the predictions for the test samples.
        """
        return self._adapt_and_predict(data_shot, label_shot, data_query,
                                       self._PREVAL_LR, self._PREVAL_STEPS)
133 |
--------------------------------------------------------------------------------
/tensorflow/run_experiment.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Yaoyao Liu
3 | ## Tianjin University
4 | ## liuyaoyao@tju.edu.cn
5 | ## Copyright (c) 2019
6 | ##
7 | ## This source code is licensed under the MIT-style license found in the
8 | ## LICENSE file in the root directory of this source tree
9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10 |
11 | """ Generate commands for main.py """
12 | import os
13 | import sys
14 |
def run_experiment(PHASE='META'):
    """The function to generate commands to run the experiments.

    Builds a ``python main.py ...`` command line from the options below and
    executes it with ``os.system``.

    Arg:
      PHASE: the phase for MTL.
        'PRE'       - pre-train phase;
        'META'      - meta-train, then meta-test for every saved step;
        'META_LOAD' - meta-train with ``--load_saved_weights=True``;
        'TEST_LOAD' - meta-test with ``--load_saved_weights=True``.
    Raises:
      ValueError: if PHASE is not one of the values above (previously such a
        call silently did nothing).
    """
    # Some important options
    # Please note that not all the options are shown here. For more detailed options, please edit main.py

    # Basic options
    LOG_DIR = 'experiment_results' # Name of the folder to save the log files
    GPU_ID = 1 # GPU device id
    NET_ARCH = 'resnet12' # Network backbone (resnet12 or resnet18)

    # Pre-train phase options
    PRE_TRA_LABEL = 'normal' # Additional label for pre-train model
    PRE_TRA_ITER_MAX = 20000 # Iteration number for the pre-train phase
    PRE_TRA_DROP = 0.9 # Dropout keep rate for the pre-train phase
    PRE_DROP_STEP = 5000 # Iteration number for the pre-train learning rate reducing
    PRE_LR = 0.001 # Pre-train learning rate

    # Meta options
    SHOT_NUM = 1 # Shot number for the few-shot tasks
    WAY_NUM = 5 # Class number for the few-shot tasks
    MAX_ITER = 20000 # Iteration number for meta-train
    META_BATCH_SIZE = 2 # Meta batch size
    PRE_ITER = 10000 # Iteration number for the pre-train model used in the meta-train phase
    UPDATE_NUM = 20 # Epoch number for the base learning
    SAVE_STEP = 100 # Iteration number to save the meta model
    META_LR = 0.001 # Meta learning rate
    META_LR_MIN = 0.0001 # Meta learning rate min value
    LR_DROP_STEP = 1000 # The iteration number for the meta learning rate reducing
    BASE_LR = 0.001 # Base learning rate

    # Data directories
    PRE_TRA_DIR = './data/mini-imagenet/train' # Directory for the pre-train phase images
    META_TRA_DIR = './data/mini-imagenet/train' # Directory for the meta-train images
    META_VAL_DIR = './data/mini-imagenet/val' # Directory for the meta-validation images
    META_TES_DIR = './data/mini-imagenet/test' # Directory for the meta-test images

    # Generate the base command for main.py.
    # The (flag, value) order is the order of the flags on the command line.
    base_options = [
        ('backbone_arch', NET_ARCH),
        ('metatrain_iterations', MAX_ITER),
        ('meta_batch_size', META_BATCH_SIZE),
        ('shot_num', SHOT_NUM),
        ('meta_lr', META_LR),
        ('min_meta_lr', META_LR_MIN),
        ('base_lr', BASE_LR),
        ('train_base_epoch_num', UPDATE_NUM),
        ('way_num', WAY_NUM),
        ('exp_log_label', LOG_DIR),
        ('pretrain_dropout_keep', PRE_TRA_DROP),
        ('activation', 'leaky_relu'),
        ('pre_lr', PRE_LR),
        ('pre_lr_dropstep', PRE_DROP_STEP),
        ('meta_save_step', SAVE_STEP),
        ('lr_drop_step', LR_DROP_STEP),
        ('pretrain_folders', PRE_TRA_DIR),
        ('pretrain_label', PRE_TRA_LABEL),
        ('device_id', GPU_ID),
        ('metatrain_dir', META_TRA_DIR),
        ('metaval_dir', META_VAL_DIR),
        ('metatest_dir', META_TES_DIR),
    ]
    base_command = 'python main.py' + ''.join(
        ' --' + flag + '=' + str(value) for flag, value in base_options)

    def process_test_command(TEST_STEP, in_command):
        """The function to adapt the base command to the meta-test phase.
        Args:
          TEST_STEP: the iteration number for the meta model to be loaded.
          in_command: the input base command.
        Return:
          Processed command.
        """
        return in_command \
            + ' --phase=meta' \
            + ' --pretrain_iterations=' + str(PRE_ITER) \
            + ' --metatrain=False' \
            + ' --test_iter=' + str(TEST_STEP)

    if PHASE == 'PRE':
        print('****** Start Pre-train Phase ******')
        pre_command = base_command + ' --phase=pre' + ' --pretrain_iterations=' + str(PRE_TRA_ITER_MAX)
        os.system(pre_command)

    elif PHASE == 'META':
        print('****** Start Meta-train Phase ******')
        meta_train_command = base_command + ' --phase=meta' + ' --pretrain_iterations=' + str(PRE_ITER)
        os.system(meta_train_command)

        print('****** Start Meta-test Phase ******')
        # One meta-test run per save step: 0, SAVE_STEP, 2*SAVE_STEP, ... (< MAX_ITER)
        for idx in range(0, MAX_ITER, SAVE_STEP):
            print('[*] Runing meta-test, load model for ' + str(idx) + ' iterations')
            test_command = process_test_command(idx, base_command)
            os.system(test_command)

    elif PHASE == 'META_LOAD':
        print('****** Start Meta-train Phase with Downloaded Weights ******')
        meta_train_command = base_command + ' --phase=meta' + ' --pretrain_iterations=' + str(PRE_ITER) + ' --load_saved_weights=True'
        os.system(meta_train_command)

    elif PHASE == 'TEST_LOAD':
        print('****** Start Meta-test Phase with Downloaded Weights ******')
        test_command = process_test_command(0, base_command) + ' --load_saved_weights=True'
        os.system(test_command)

    else:
        raise ValueError(
            "Unknown PHASE '" + str(PHASE) + "'; expected one of "
            "'PRE', 'META', 'META_LOAD', 'TEST_LOAD'.")
120 |
if __name__ == '__main__':
    # Script entry point. The phase is the first command-line argument; when
    # it is omitted, fall back to run_experiment's default phase instead of
    # failing with an IndexError on sys.argv[1].
    if len(sys.argv) > 1:
        THE_INPUT_PHASE = sys.argv[1]
        run_experiment(PHASE=THE_INPUT_PHASE)
    else:
        run_experiment()
123 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Meta-Transfer Learning for Few-Shot Learning
2 | [](https://github.com/y2l/meta-transfer-learning-tensorflow/blob/master/LICENSE)
3 | [](https://www.python.org/)
4 | [](https://github.com/y2l/meta-transfer-learning/tree/master/tensorflow)
5 | [](https://pytorch.org/)
6 | [](https://scholar.google.com/citations?view_op=view_citation&hl=en&user=Uf9GqRsAAAAJ&citation_for_view=Uf9GqRsAAAAJ:bEWYMUwI8FkC)
7 |
10 |
11 | This repository contains the TensorFlow and PyTorch implementation for the [CVPR 2019](http://cvpr2019.thecvf.com/) Paper ["Meta-Transfer Learning for Few-Shot Learning"](http://openaccess.thecvf.com/content_CVPR_2019/papers/Sun_Meta-Transfer_Learning_for_Few-Shot_Learning_CVPR_2019_paper.pdf) by [Qianru Sun](https://qianrusun1015.github.io),\* [Yaoyao Liu](https://people.mpi-inf.mpg.de/~yaliu/),\* [Tat-Seng Chua](https://www.chuatatseng.com/), and [Bernt Schiele](https://www.mpi-inf.mpg.de/departments/computer-vision-and-multimodal-computing/people/bernt-schiele/) (\*=equal contribution).
12 |
13 | If you have any questions on this repository or the related paper, feel free to [create an issue](https://github.com/yaoyao-liu/meta-transfer-learning/issues/new) or [send me an email](mailto:yaoyao.liu+github@mpi-inf.mpg.de).
14 |
15 | #### Summary
16 |
17 | * [Introduction](#introduction)
18 | * [Getting Started](#getting-started)
19 | * [Datasets](#datasets)
20 | * [Performance](#performance)
21 | * [Citation](#citation)
22 | * [Acknowledgements](#acknowledgements)
23 |
24 |
25 | ## Introduction
26 |
27 | Meta-learning has been proposed as a framework to address the challenging few-shot learning setting. The key idea is to leverage a large number of similar few-shot tasks in order to learn how to adapt a base-learner to a new task for which only a few labeled samples are available. As deep neural networks (DNNs) tend to overfit using a few samples only, meta-learning typically uses shallow neural networks (SNNs), thus limiting its effectiveness. In this paper we propose a novel few-shot learning method called ***meta-transfer learning (MTL)*** which learns to adapt a ***deep NN*** for ***few shot learning tasks***. Specifically, meta refers to training multiple tasks, and transfer is achieved by learning scaling and shifting functions of DNN weights for each task. We conduct experiments using (5-class, 1-shot) and (5-class, 5-shot) recognition tasks on two challenging few-shot learning benchmarks: 𝑚𝑖𝑛𝑖ImageNet and Fewshot-CIFAR100.
28 |
29 |
30 |
31 |
32 |
33 | > Figure: Meta-Transfer Learning. (a) Parameter-level fine-tuning (FT) is a conventional meta-training operation, e.g. in MAML. Its update works for all neuron parameters, 𝑊 and 𝑏. (b) Our neuron-level scaling and shifting (SS) operations in meta-transfer learning. They reduce the number of learning parameters and avoid overfitting problems. In addition, they keep large-scale trained parameters (in yellow) frozen, preventing “catastrophic forgetting”.
34 |
35 | ## Getting Started
36 |
37 | Please see `README.md` files in the corresponding folders:
38 |
39 | * TensorFlow: [\[Document\]](https://github.com/y2l/meta-transfer-learning/blob/master/tensorflow/README.md)
40 | * PyTorch: [\[Document\]](https://github.com/y2l/meta-transfer-learning/blob/master/pytorch/README.md)
41 |
42 | ## Datasets
43 |
44 | Directly download processed images: [\[Download Page\]](https://mtl.yyliu.net/download/)
45 |
46 | ### 𝒎𝒊𝒏𝒊ImageNet
47 |
48 | The 𝑚𝑖𝑛𝑖ImageNet dataset was proposed by [Vinyals et al.](http://papers.nips.cc/paper/6385-matching-networks-for-one-shot-learning.pdf) for few-shot learning evaluation. Its complexity is high due to the use of ImageNet images but requires fewer resources and infrastructure than running on the full [ImageNet dataset](https://arxiv.org/pdf/1409.0575.pdf). In total, there are 100 classes with 600 samples of 84×84 color images per class. These 100 classes are divided into 64, 16, and 20 classes respectively for sampling tasks for meta-training, meta-validation, and meta-test. To generate this dataset from ImageNet, you may use the repository [𝑚𝑖𝑛𝑖ImageNet tools](https://github.com/y2l/mini-imagenet-tools).
49 |
50 | ### Fewshot-CIFAR100
51 |
52 | Fewshot-CIFAR100 (FC100) is based on the popular object classification dataset CIFAR100. The splits were
53 | proposed by [TADAM](https://arxiv.org/pdf/1805.10123.pdf). It offers a more challenging scenario with lower image resolution and more challenging meta-training/test splits that are separated according to object super-classes. It contains 100 object classes and each class has 600 samples of 32 × 32 color images. The 100 classes belong to 20 super-classes. Meta-training data are from 60 classes belonging to 12 super-classes. Meta-validation and meta-test sets contain 20 classes belonging to 4 super-classes, respectively.
54 |
55 | ### 𝒕𝒊𝒆𝒓𝒆𝒅ImageNet
56 |
57 | The [𝑡𝑖𝑒𝑟𝑒𝑑ImageNet](https://arxiv.org/pdf/1803.00676.pdf) dataset is a larger subset of ILSVRC-12 with 608 classes (779,165 images) grouped into 34 higher-level nodes in the ImageNet human-curated hierarchy. To generate this dataset from ImageNet, you may use the repository [𝑡𝑖𝑒𝑟𝑒𝑑ImageNet tools](https://github.com/y2l/tiered-imagenet-tools).
58 |
59 |
60 | ## Performance
61 |
62 | | (%) | 𝑚𝑖𝑛𝑖 1-shot | 𝑚𝑖𝑛𝑖 5-shot | FC100 1-shot | FC100 5-shot |
63 | | ---------------------- | ------------ | ------------ | ------------ | ------------ |
64 | | `MTL Paper` | `60.2 ± 1.8` | `74.3 ± 0.9` | `43.6 ± 1.8` | `55.4 ± 0.9` |
65 | | `TensorFlow` | `60.8 ± 1.8` | `74.3 ± 0.9` | `44.3 ± 1.8` | `56.8 ± 1.0` |
66 | * The performance of the PyTorch version is still being verified.
67 |
68 | ## Citation
69 |
70 | Please cite our paper if it is helpful to your work:
71 |
72 | ```bibtex
73 | @inproceedings{SunLCS2019MTL,
74 | author = {Qianru Sun and
75 | Yaoyao Liu and
76 | Tat{-}Seng Chua and
77 | Bernt Schiele},
78 | title = {Meta-Transfer Learning for Few-Shot Learning},
79 | booktitle = {{IEEE} Conference on Computer Vision and Pattern Recognition, {CVPR}
80 | 2019, Long Beach, CA, USA, June 16-20, 2019},
81 | pages = {403--412},
82 | publisher = {Computer Vision Foundation / {IEEE}},
83 | year = {2019}
84 | }
85 | ```
86 |
87 | ## Acknowledgements
88 |
89 | Our implementations use the source code from the following repositories and users:
90 |
91 | * [Model-Agnostic Meta-Learning](https://github.com/cbfinn/maml)
92 |
93 | * [Optimization as a Model for Few-Shot Learning](https://github.com/gitabcworld/FewShotLearning)
94 |
95 | * [Learning Embedding Adaptation for Few-Shot Learning](https://github.com/Sha-Lab/FEAT)
96 |
97 | * [dragen1860/MAML-Pytorch](https://github.com/dragen1860/MAML-Pytorch)
98 |
99 | * [@icoz69](https://github.com/icoz69)
100 |
101 | * [@CookieLau](https://github.com/CookieLau)
102 |
--------------------------------------------------------------------------------
/pytorch/models/resnet_mtl.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Yaoyao Liu
3 | ## Modified from: https://github.com/Sha-Lab/FEAT
4 | ## Tianjin University
5 | ## liuyaoyao@tju.edu.cn
6 | ## Copyright (c) 2019
7 | ##
8 | ## This source code is licensed under the MIT-style license found in the
9 | ## LICENSE file in the root directory of this source tree
10 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
11 | """ ResNet with MTL. """
12 | import torch.nn as nn
13 | from models.conv2d_mtl import Conv2dMtl
14 |
def conv3x3(in_planes, out_planes, stride=1):
    """The function to build a 3x3 convolution layer without bias.
    Args:
      in_planes: the number of input channels.
      out_planes: the number of output channels.
      stride: the stride of the convolution.
    Return:
      The nn.Conv2d layer with kernel size 3 and padding 1.
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)
18 |
class BasicBlock(nn.Module):
    """The residual block made of two 3x3 convolutions (plain nn.Conv2d version)."""
    # Output channels are planes * expansion
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """The function to build the layers of the block.
        Args:
          inplanes: the number of input channels.
          planes: the number of output channels.
          stride: the stride of the first convolution.
          downsample: the optional module applied to the input on the shortcut branch.
        """
        super(BasicBlock, self).__init__()
        # First conv-bn pair, the only place where the stride is applied
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        # Second conv-bn pair keeps the resolution
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """The function to run the block.
        Arg:
          x: the input feature map.
        Return:
          The activated sum of the main branch and the shortcut branch.
        """
        # Main branch
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Shortcut branch: the raw input unless a downsample module was given
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
49 |
class Bottleneck(nn.Module):
    """The bottleneck residual block: 1x1, 3x3 and 1x1 convolutions (plain nn.Conv2d version)."""
    # Output channels are planes * expansion
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """The function to build the layers of the block.
        Args:
          inplanes: the number of input channels.
          planes: the number of channels inside the bottleneck.
          stride: the stride of the 3x3 convolution.
          downsample: the optional module applied to the input on the shortcut branch.
        """
        super(Bottleneck, self).__init__()
        out_planes = planes * self.expansion
        # 1x1 convolution that maps the input to the bottleneck width
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # 3x3 convolution, the only place where the stride is applied
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # 1x1 convolution that expands the channels
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """The function to run the block.
        Arg:
          x: the input feature map.
        Return:
          The activated sum of the main branch and the shortcut branch.
        """
        # Main branch
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Shortcut branch: the raw input unless a downsample module was given
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
87 |
def conv3x3mtl(in_planes, out_planes, stride=1):
    """The function to build a 3x3 MTL convolution layer without bias.
    Args:
      in_planes: the number of input channels.
      out_planes: the number of output channels.
      stride: the stride of the convolution.
    Return:
      The Conv2dMtl layer with kernel size 3 and padding 1.
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return Conv2dMtl(in_planes, out_planes, **conv_options)
91 |
92 |
class BasicBlockMtl(nn.Module):
    """The residual block made of two 3x3 convolutions (Conv2dMtl version)."""
    # Output channels are planes * expansion
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """The function to build the layers of the block.
        Args:
          inplanes: the number of input channels.
          planes: the number of output channels.
          stride: the stride of the first convolution.
          downsample: the optional module applied to the input on the shortcut branch.
        """
        super(BasicBlockMtl, self).__init__()
        # First conv-bn pair, the only place where the stride is applied
        self.conv1 = conv3x3mtl(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        # Second conv-bn pair keeps the resolution
        self.conv2 = conv3x3mtl(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """The function to run the block.
        Arg:
          x: the input feature map.
        Return:
          The activated sum of the main branch and the shortcut branch.
        """
        # Main branch
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Shortcut branch: the raw input unless a downsample module was given
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
123 |
124 |
class BottleneckMtl(nn.Module):
    """The bottleneck residual block: 1x1, 3x3 and 1x1 convolutions (Conv2dMtl version)."""
    # Output channels are planes * expansion
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """The function to build the layers of the block.
        Args:
          inplanes: the number of input channels.
          planes: the number of channels inside the bottleneck.
          stride: the stride of the 3x3 convolution.
          downsample: the optional module applied to the input on the shortcut branch.
        """
        super(BottleneckMtl, self).__init__()
        out_planes = planes * self.expansion
        # 1x1 convolution that maps the input to the bottleneck width
        self.conv1 = Conv2dMtl(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # 3x3 convolution, the only place where the stride is applied
        self.conv2 = Conv2dMtl(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # 1x1 convolution that expands the channels
        self.conv3 = Conv2dMtl(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """The function to run the block.
        Arg:
          x: the input feature map.
        Return:
          The activated sum of the main branch and the shortcut branch.
        """
        # Main branch
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Shortcut branch: the raw input unless a downsample module was given
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
162 |
class ResNetMtl(nn.Module):
    """The ResNet feature extractor with three residual stages.

    The stages are built from BasicBlockMtl/Conv2dMtl when mtl is True and from
    BasicBlock/nn.Conv2d otherwise. The forward pass returns flattened features;
    no classifier is attached here.
    """

    def __init__(self, layers=(4, 4, 4), mtl=True):
        """The function to build the network.
        Args:
          layers: the number of residual blocks in each of the three stages
            (any indexable sequence of three ints; lists are still accepted).
          mtl: whether to use the MTL convolution and block classes.
        """
        super(ResNetMtl, self).__init__()
        if mtl:
            self.Conv2d = Conv2dMtl
            block = BasicBlockMtl
        else:
            self.Conv2d = nn.Conv2d
            block = BasicBlock
        # Output widths of the three stages; the stem uses half of the first width
        cfg = [160, 320, 640]
        self.inplanes = iChannels = cfg[0] // 2
        self.conv1 = self.Conv2d(3, iChannels, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(iChannels)
        self.relu = nn.ReLU(inplace=True)
        # Every stage halves the spatial resolution (stride 2)
        self.layer1 = self._make_layer(block, cfg[0], layers[0], stride=2)
        self.layer2 = self._make_layer(block, cfg[1], layers[1], stride=2)
        self.layer3 = self._make_layer(block, cfg[2], layers[2], stride=2)
        # The 10x10 window covers the whole feature map for 80x80 inputs
        # (stride-1 stem followed by three stride-2 stages); larger inputs
        # leave more than one spatial position after pooling.
        self.avgpool = nn.AvgPool2d(10, stride=1)

        # Only instances of the selected convolution class are re-initialized here
        for m in self.modules():
            if isinstance(m, self.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """The function to build one residual stage.
        Args:
          block: the residual block class.
          planes: the channel width passed to each block.
          blocks: the number of blocks in the stage.
          stride: the stride of the first block.
        Return:
          The nn.Sequential module that contains the blocks of the stage.
        """
        out_planes = planes * block.expansion
        downsample = None
        # A projection shortcut is required when the first block changes the
        # resolution or the number of channels
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                self.Conv2d(self.inplanes, out_planes,
                            kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        # The remaining blocks take the stage output width as input
        self.inplanes = out_planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        """The function to extract features.
        Arg:
          x: the input image batch (the stem convolution expects 3 channels).
        Return:
          The pooled features flattened to (batch size, -1).
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)

        return x
218 |
219 |
--------------------------------------------------------------------------------
/tensorflow/main.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Yaoyao Liu
3 | ## Tianjin University
4 | ## liuyaoyao@tju.edu.cn
5 | ## Copyright (c) 2019
6 | ##
7 | ## This source code is licensed under the MIT-style license found in the
8 | ## LICENSE file in the root directory of this source tree
9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10 |
11 | import os
12 | from tensorflow.python.platform import flags
13 | from trainer.meta import MetaTrainer
14 | from trainer.pre import PreTrainer
15 |
# Shared flag container; every option below is read through FLAGS.<name>.
# NOTE: the DEFINE_* calls keep the positional (name, default, help) form on purpose.
FLAGS = flags.FLAGS

### Basic options
flags.DEFINE_integer('img_size', 84, 'image size')
flags.DEFINE_integer('device_id', 0, 'GPU device ID to run the job.')
flags.DEFINE_float('gpu_rate', 0.9, 'the parameter for the full_gpu_memory_mode')
flags.DEFINE_string('phase', 'meta', 'pre or meta')
flags.DEFINE_string('exp_log_label', 'experiment_results', 'directory for summaries and checkpoints')
flags.DEFINE_string('logdir_base', './logs/', 'directory for logs')
flags.DEFINE_bool('full_gpu_memory_mode', False, 'in this mode, the code occupies GPU memory in advance')
flags.DEFINE_string('backbone_arch', 'resnet12', 'network backbone')

### Pre-train phase options
flags.DEFINE_integer('pre_lr_dropstep', 5000, 'the step number to drop pre_lr')
flags.DEFINE_integer('pretrain_class_num', 64, 'number of classes used in the pre-train phase')
flags.DEFINE_integer('pretrain_batch_size', 64, 'batch_size for the pre-train phase')
flags.DEFINE_integer('pretrain_iterations', 30000, 'number of pretraining iterations.')
flags.DEFINE_integer('pre_sum_step', 10, 'the step number to summary during pretraining')
flags.DEFINE_integer('pre_save_step', 1000, 'the step number to save the pretrain model')
flags.DEFINE_integer('pre_print_step', 1000, 'the step number to print the pretrain results')
flags.DEFINE_float('pre_lr', 0.001, 'the pretrain learning rate')
flags.DEFINE_float('min_pre_lr', 0.0001, 'the pretrain learning rate min')
flags.DEFINE_float('pretrain_dropout_keep', 0.9, 'the dropout keep parameter in the pre-train phase')
flags.DEFINE_string('pretrain_folders', './data/mini-imagenet/train', 'directory for pre-train data')
flags.DEFINE_string('pretrain_label', 'mini_normal', 'additional label for the pre-train log folder')
flags.DEFINE_bool('pre_lr_stop', False, 'whether stop decrease the pre_lr when it is low')

### Meta phase options
flags.DEFINE_integer('way_num', 5, 'number of classes (e.g. 5-way classification)')
flags.DEFINE_integer('shot_num', 1, 'number of examples per class (K for K-shot learning)')
flags.DEFINE_integer('metatrain_epite_sample_num', 15, 'number of meta train episode-test samples')
flags.DEFINE_integer('metatest_epite_sample_num', 0, 'number of meta test episode-test samples, 0 means metatest_epite_sample_num=shot_num')
flags.DEFINE_integer('meta_sum_step', 10, 'the step number to summary during meta-training')
flags.DEFINE_integer('meta_save_step', 500, 'the step number to save the model')
flags.DEFINE_integer('meta_intrain_val_sample', 600, 'the number of samples used for val during meta-train')
flags.DEFINE_integer('meta_print_step', 100, 'the step number to print the meta-train results')
flags.DEFINE_integer('meta_val_print_step', 100, 'the step number to print the meta-val results during meta-training')
flags.DEFINE_integer('metatrain_iterations', 15000, 'number of meta-train iterations.')
flags.DEFINE_integer('meta_batch_size', 2, 'number of tasks sampled per meta-update')
flags.DEFINE_integer('train_base_epoch_num', 20, 'number of inner gradient updates during training.')
flags.DEFINE_integer('test_base_epoch_num', 100, 'number of inner gradient updates during test.')
flags.DEFINE_integer('lr_drop_step', 5000, 'the step number to drop meta_lr')
flags.DEFINE_integer('test_iter', 1000, 'iteration to load model')
flags.DEFINE_float('meta_lr', 0.001, 'the meta learning rate of the generator')
# The help text used to repeat the lr_drop_step description; this flag is the rate, not the step
flags.DEFINE_float('lr_drop_rate', 0.5, 'the rate to drop meta_lr')
flags.DEFINE_float('min_meta_lr', 0.0001, 'the min meta learning rate of the generator')
flags.DEFINE_float('base_lr', 1e-3, 'step size alpha for inner gradient update.')
flags.DEFINE_string('metatrain_dir', './data/mini-imagenet/train', 'directory for meta-train set')
flags.DEFINE_string('metaval_dir', './data/mini-imagenet/val', 'directory for meta-val set')
flags.DEFINE_string('metatest_dir', './data/mini-imagenet/test', 'directory for meta-test set')
flags.DEFINE_string('activation', 'leaky_relu', 'leaky_relu, relu, or None')
flags.DEFINE_string('norm', 'batch_norm', 'batch_norm, layer_norm, or None')
flags.DEFINE_bool('metatrain', True, 'is this the meta-train phase')
flags.DEFINE_bool('base_augmentation', True, 'whether do data augmentation during base learning')
flags.DEFINE_bool('redo_init', True, 're-build the initialization weights')
flags.DEFINE_bool('load_saved_weights', False, 'load the downloaded weights')
72 |
# Generate experiment key words string: 'key(value)' items joined by dots
exp_fields = [
    ('arch', FLAGS.backbone_arch),
    ('cls', FLAGS.way_num),
    ('shot', FLAGS.shot_num),
    ('meta_batch', FLAGS.meta_batch_size),
    ('base_epoch', FLAGS.train_base_epoch_num),
    ('meta_lr', FLAGS.meta_lr),
    ('base_lr', FLAGS.base_lr),
    ('pre_iterations', FLAGS.pretrain_iterations),
    ('pre_dropout', FLAGS.pretrain_dropout_keep),
    ('acti', FLAGS.activation),
    ('lr_drop_step', FLAGS.lr_drop_step),
    ('lr_drop_rate', FLAGS.lr_drop_rate),
    ('pre_label', FLAGS.pretrain_label),
]
exp_string = '.'.join('{}({})'.format(key, value) for key, value in exp_fields)

# Short tag for each accepted normalization setting; anything else is rejected
norm_tags = {'batch_norm': 'batch', 'layer_norm': 'layer', 'None': 'none'}
if FLAGS.norm not in norm_tags:
    raise Exception('Norm setting is not recognized')
exp_string += '.norm(' + norm_tags[FLAGS.norm] + ')'

FLAGS.exp_string = exp_string
print('Parameters: ' + exp_string)

# Generate pre-train key words string in the same 'key(value)' format
pre_fields = [
    ('arch', FLAGS.backbone_arch),
    ('pre_lr', FLAGS.pre_lr),
    ('pre_lrdrop', FLAGS.pre_lr_dropstep),
    ('pre_class', FLAGS.pretrain_class_num),
    ('pre_batch', FLAGS.pretrain_batch_size),
    ('pre_dropout', FLAGS.pretrain_dropout_keep),
    ('pre_lr_stop', 'True' if FLAGS.pre_lr_stop else 'False'),
    ('pre_label', FLAGS.pretrain_label),
]
pre_save_str = '.'.join('{}({})'.format(key, value) for key, value in pre_fields)
FLAGS.pre_string = pre_save_str
113 |
# Generate log folders
FLAGS.logdir = FLAGS.logdir_base + FLAGS.exp_log_label
FLAGS.pretrain_dir = FLAGS.logdir_base + 'pretrain_weights'

# The base folder comes first because the two others are created inside it
for log_folder in (FLAGS.logdir_base, FLAGS.logdir, FLAGS.pretrain_dir):
    if not os.path.exists(log_folder):
        os.mkdir(log_folder)

# If FLAGS.redo_init is true, delete the previous initialization weights.
# The check used to be inverted ('if not os.path.exists'), so the folder was
# only "removed" when it did not exist and stale weights were never deleted.
# NOTE(review): the path is hard-coded and ignores FLAGS.logdir_base - confirm
# it matches the folder where the trainer stores the init weights.
if FLAGS.redo_init:
    if os.path.exists('./logs/init_weights'):
        os.system('rm -r ./logs/init_weights')
        print('Init weights have been deleted')
    else:
        print('No init weights')
132 |
def main():
    """The function to select the GPU and build the trainer for the chosen phase."""
    # Set GPU device id
    device_id = str(FLAGS.device_id)
    print('Using GPU ' + device_id)
    os.environ['CUDA_VISIBLE_DEVICES'] = device_id
    # Select pre-train phase or meta-learning phase
    trainer_classes = {'pre': PreTrainer, 'meta': MetaTrainer}
    if FLAGS.phase not in trainer_classes:
        raise Exception('Please set correct phase')
    # Only the constructor is called here; presumably it launches the whole
    # phase - TODO confirm against the trainer classes.
    trainer_classes[FLAGS.phase]()

if __name__ == "__main__":
    main()
147 |
--------------------------------------------------------------------------------
/tensorflow/data_generator/meta_data_generator.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Yaoyao Liu
3 | ## Modified from: https://github.com/cbfinn/maml
4 | ## Tianjin University
5 | ## liuyaoyao@tju.edu.cn
6 | ## Copyright (c) 2019
7 | ##
8 | ## This source code is licensed under the MIT-style license found in the
9 | ## LICENSE file in the root directory of this source tree
10 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
11 |
12 | """ Data generator for meta-learning. """
13 | import numpy as np
14 | import os
15 | import random
16 | import tensorflow as tf
17 |
18 | from tqdm import trange
19 | from tensorflow.python.platform import flags
20 | from utils.misc import get_images, process_batch, process_batch_augmentation
21 |
22 | FLAGS = flags.FLAGS
23 |
class MetaDataGenerator(object):
    """The class to generate data lists and episodes for meta-train and meta-test."""
    def __init__(self):
        """The function to create the folders in which the data lists are saved."""
        # Set the base folder to save the data lists
        filename_dir = FLAGS.logdir_base + 'processed_data/'
        if not os.path.exists(filename_dir):
            os.mkdir(filename_dir)

        # Set the detailed folder name for saving the data lists
        self.this_setting_filename_dir = filename_dir + 'shot(' + str(FLAGS.shot_num) + ').way(' + str(FLAGS.way_num) \
            + ').metatr_epite(' + str(FLAGS.metatrain_epite_sample_num) + ').metate_epite(' + str(FLAGS.metatest_epite_sample_num) + ')/'
        if not os.path.exists(self.this_setting_filename_dir):
            os.mkdir(self.this_setting_filename_dir)

    @staticmethod
    def _eval_epite_sample_num():
        """The function to get the episode-test sample number per class for meta-val and meta-test.
        Return:
          FLAGS.shot_num when FLAGS.metatest_epite_sample_num is 0, otherwise the flag value.
        """
        if FLAGS.metatest_epite_sample_num == 0:
            return FLAGS.shot_num
        return FLAGS.metatest_epite_sample_num

    def _data_list_path(self, data_type):
        """The function to get the path of the saved data list for one phase."""
        return self.this_setting_filename_dir + '/' + data_type + '_data.npy'

    def generate_data(self, data_type='train'):
        """The function to generate the data lists.
        Arg:
          data_type: the phase for meta-learning ('train', 'val' or 'test').
        Raise:
          Exception: when data_type is not one of the three phases.
        """
        if data_type == 'train':
            source_folder = FLAGS.metatrain_dir
            num_total_batches = FLAGS.metatrain_iterations * FLAGS.meta_batch_size + 10
            num_samples_per_class = FLAGS.shot_num + FLAGS.metatrain_epite_sample_num
        elif data_type in ('test', 'val'):
            # Meta-val and meta-test only differ in the source folder
            source_folder = FLAGS.metatest_dir if data_type == 'test' else FLAGS.metaval_dir
            num_total_batches = 600
            num_samples_per_class = FLAGS.shot_num + self._eval_epite_sample_num()
        else:
            raise Exception('Please check data list type')

        # Every sub-folder of the source folder is one class
        folders = [os.path.join(source_folder, label)
                   for label in os.listdir(source_folder)
                   if os.path.isdir(os.path.join(source_folder, label))]
        epitr_sample_num = FLAGS.shot_num
        data_list_path = self._data_list_path(data_type)

        # An existing list is never regenerated
        if os.path.exists(data_list_path):
            print('The ' + data_type + ' data have already been created')
            return

        print('Generating ' + data_type + ' data')
        data_list = []
        for _ in trange(num_total_batches):
            sampled_character_folders = random.sample(folders, FLAGS.way_num)
            random.shuffle(sampled_character_folders)
            labels_and_images = get_images(sampled_character_folders,
                                           range(FLAGS.way_num), nb_samples=num_samples_per_class, shuffle=False)
            labels = [li[0] for li in labels_and_images]
            filenames = [li[1] for li in labels_and_images]
            this_task_tr_filenames = []
            this_task_tr_labels = []
            this_task_te_filenames = []
            this_task_te_labels = []
            # The slicing below relies on get_images returning the samples
            # grouped by class, num_samples_per_class entries per class
            for class_idx in range(FLAGS.way_num):
                start = class_idx * num_samples_per_class
                stop = start + num_samples_per_class
                this_class_filenames = filenames[start:stop]
                this_class_label = labels[start:stop]
                # The first shot_num samples form the episode-train split,
                # the rest form the episode-test split
                this_task_tr_filenames += this_class_filenames[:epitr_sample_num]
                this_task_tr_labels += this_class_label[:epitr_sample_num]
                this_task_te_filenames += this_class_filenames[epitr_sample_num:]
                this_task_te_labels += this_class_label[epitr_sample_num:]

            data_list.append({'filenamea': this_task_tr_filenames, 'filenameb': this_task_te_filenames,
                              'labela': this_task_tr_labels, 'labelb': this_task_te_labels})

        np.save(data_list_path, data_list)
        print('The ' + data_type + ' data are saved')

    def load_data(self, data_type='test'):
        """The function to load the data lists.
        Arg:
          data_type: the phase for meta-learning ('train', 'val' or 'test').
        """
        # The list stores Python dicts, hence allow_pickle
        data_list = np.load(self._data_list_path(data_type), allow_pickle=True, encoding="latin1")
        if data_type == 'train':
            self.train_data = data_list
        elif data_type == 'test':
            self.test_data = data_list
        elif data_type == 'val':
            self.val_data = data_list
        else:
            print('[Error] Please check data list type')

    def load_episode(self, index, data_type='train'):
        """The function to load the episodes.
        Args:
          index: the index for the episodes.
          data_type: the phase for meta-learning ('train', 'val' or 'test').
        Return:
          The episode-train inputs and labels, then the episode-test inputs and labels.
        Raise:
          Exception: when data_type is not one of the three phases.
        """
        if data_type == 'train':
            data_list = self.train_data
            epite_sample_num = FLAGS.metatrain_epite_sample_num
        elif data_type == 'test':
            data_list = self.test_data
            # This used to read the undefined flag 'metatest_episode_test_sample'
            epite_sample_num = self._eval_epite_sample_num()
        elif data_type == 'val':
            data_list = self.val_data
            epite_sample_num = self._eval_epite_sample_num()
        else:
            raise Exception('Please check data list type')

        dim_input = FLAGS.img_size * FLAGS.img_size * 3
        epitr_sample_num = FLAGS.shot_num

        this_episode = data_list[index]
        this_task_tr_filenames = this_episode['filenamea']
        this_task_te_filenames = this_episode['filenameb']
        this_task_tr_labels = this_episode['labela']
        this_task_te_labels = this_episode['labelb']

        # Augmentation is only applied to the episode-train samples, and only
        # outside of the meta-train phase
        if FLAGS.metatrain is False and FLAGS.base_augmentation:
            this_inputa, this_labela = process_batch_augmentation(this_task_tr_filenames,
                                                                  this_task_tr_labels, dim_input, epitr_sample_num)
        else:
            this_inputa, this_labela = process_batch(this_task_tr_filenames,
                                                     this_task_tr_labels, dim_input, epitr_sample_num)
        this_inputb, this_labelb = process_batch(this_task_te_filenames,
                                                 this_task_te_labels, dim_input, epite_sample_num)
        return this_inputa, this_labela, this_inputb, this_labelb
170 |
--------------------------------------------------------------------------------
/pytorch/trainer/pre.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Yaoyao Liu
3 | ## Tianjin University
4 | ## liuyaoyao@tju.edu.cn
5 | ## Copyright (c) 2019
6 | ##
7 | ## This source code is licensed under the MIT-style license found in the
8 | ## LICENSE file in the root directory of this source tree
9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10 | """ Trainer for pretrain phase. """
11 | import os.path as osp
12 | import os
13 | import tqdm
14 | import torch
15 | import torch.nn.functional as F
16 | from torch.utils.data import DataLoader
17 | from dataloader.samplers import CategoriesSampler
18 | from models.mtl import MtlLearner
19 | from utils.misc import Averager, Timer, count_acc, ensure_path
20 | from tensorboardX import SummaryWriter
21 | from dataloader.dataset_loader import DatasetLoader as Dataset
22 |
class PreTrainer(object):
    """The class that contains the code for the pretrain phase."""
    def __init__(self, args):
        """The function to set up the folders, data loaders, model and optimizer.
        Arg:
          args: the options of the experiment; args.save_path is set here.
        """
        # Set the folder to save the records and checkpoints
        log_base_dir = './logs/'
        if not osp.exists(log_base_dir):
            os.mkdir(log_base_dir)
        pre_base_dir = osp.join(log_base_dir, 'pre')
        if not osp.exists(pre_base_dir):
            os.mkdir(pre_base_dir)
        save_path1 = '_'.join([args.dataset, args.model_type])
        save_path2 = 'batchsize' + str(args.pre_batch_size) + '_lr' + str(args.pre_lr) + '_gamma' + str(args.pre_gamma) + '_step' + \
            str(args.pre_step_size) + '_maxepoch' + str(args.pre_max_epoch)
        args.save_path = pre_base_dir + '/' + save_path1 + '_' + save_path2
        ensure_path(args.save_path)

        # Set args to be shareable in the class
        self.args = args

        # Load pretrain set
        self.trainset = Dataset('train', self.args, train_aug=True)
        self.train_loader = DataLoader(dataset=self.trainset, batch_size=args.pre_batch_size, shuffle=True, num_workers=8, pin_memory=True)

        # Load meta-val set
        self.valset = Dataset('val', self.args)
        self.val_sampler = CategoriesSampler(self.valset.label, 600, self.args.way, self.args.shot + self.args.val_query)
        self.val_loader = DataLoader(dataset=self.valset, batch_sampler=self.val_sampler, num_workers=8, pin_memory=True)

        # Set pretrain class number
        num_class_pretrain = self.trainset.num_class

        # Build pretrain model
        self.model = MtlLearner(self.args, mode='pre', num_cls=num_class_pretrain)

        # Set optimizer
        self.optimizer = torch.optim.SGD([{'params': self.model.encoder.parameters(), 'lr': self.args.pre_lr}, \
            {'params': self.model.pre_fc.parameters(), 'lr': self.args.pre_lr}], \
            momentum=self.args.pre_custom_momentum, nesterov=True, weight_decay=self.args.pre_custom_weight_decay)
        # Set learning rate scheduler
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=self.args.pre_step_size, \
            gamma=self.args.pre_gamma)

        # Set model to GPU
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            self.model = self.model.cuda()

    @staticmethod
    def _to_long(tensor):
        """The function to cast a tensor to LongTensor, on GPU when CUDA is available."""
        if torch.cuda.is_available():
            return tensor.type(torch.cuda.LongTensor)
        return tensor.type(torch.LongTensor)

    @staticmethod
    def _to_device(tensor):
        """The function to move a tensor to GPU when CUDA is available."""
        if torch.cuda.is_available():
            return tensor.cuda()
        return tensor

    def save_model(self, name):
        """The function to save checkpoints.
        Args:
          name: the name for saved checkpoint
        """
        # Only the encoder weights are stored
        torch.save(dict(params=self.model.encoder.state_dict()), osp.join(self.args.save_path, name + '.pth'))

    def train(self):
        """The function for the pre-train phase."""

        # Set the pretrain log
        trlog = {}
        trlog['args'] = vars(self.args)
        trlog['train_loss'] = []
        trlog['val_loss'] = []
        trlog['train_acc'] = []
        trlog['val_acc'] = []
        trlog['max_acc'] = 0.0
        trlog['max_acc_epoch'] = 0

        # Set the timer
        timer = Timer()
        # Set global count to zero
        global_count = 0
        # Set tensorboardX
        writer = SummaryWriter(comment=self.args.save_path)

        # Generate the labels for validation; they only depend on the options
        label_query = self._to_long(torch.arange(self.args.way).repeat(self.args.val_query))
        label_shot = self._to_long(torch.arange(self.args.way).repeat(self.args.shot))
        # The first shot * way samples of a validation batch are the shot samples
        p = self.args.shot * self.args.way

        # Start pretrain
        for epoch in range(1, self.args.pre_max_epoch + 1):
            # Update learning rate
            # NOTE(review): PyTorch >= 1.1 expects scheduler.step() after the
            # optimizer updates of the epoch; the call is kept at the start of
            # the epoch so the learning rate schedule is unchanged.
            self.lr_scheduler.step()
            # Set the model to train mode
            self.model.train()
            self.model.mode = 'pre'
            # Set averager classes to record training losses and accuracies
            train_loss_averager = Averager()
            train_acc_averager = Averager()

            # Using tqdm to read samples from train loader
            tqdm_gen = tqdm.tqdm(self.train_loader)
            for batch in tqdm_gen:
                # Update global count number
                global_count = global_count + 1
                data = self._to_device(batch[0])
                label = self._to_long(batch[1])
                # Output logits for model
                logits = self.model(data)
                # Calculate train loss
                loss = F.cross_entropy(logits, label)
                # Calculate train accuracy
                acc = count_acc(logits, label)
                # Write the tensorboardX records
                writer.add_scalar('data/loss', float(loss), global_count)
                writer.add_scalar('data/acc', float(acc), global_count)
                # Print loss and accuracy for this step
                tqdm_gen.set_description('Epoch {}, Loss={:.4f} Acc={:.4f}'.format(epoch, loss.item(), acc))

                # Add loss and accuracy for the averagers
                train_loss_averager.add(loss.item())
                train_acc_averager.add(acc)

                # Loss backwards and optimizer updates
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            # Read the epoch averages from the averagers
            train_loss = train_loss_averager.item()
            train_acc = train_acc_averager.item()

            # Start validation for this epoch, set model to eval mode
            self.model.eval()
            self.model.mode = 'preval'

            # Set averager classes to record validation losses and accuracies
            val_loss_averager = Averager()
            val_acc_averager = Averager()

            # Print previous information
            if epoch % 10 == 0:
                print('Best Epoch {}, Best Val acc={:.4f}'.format(trlog['max_acc_epoch'], trlog['max_acc']))
            # Run meta-validation
            for batch in self.val_loader:
                data = self._to_device(batch[0])
                data_shot, data_query = data[:p], data[p:]
                logits = self.model((data_shot, label_shot, data_query))
                loss = F.cross_entropy(logits, label_query)
                acc = count_acc(logits, label_query)
                val_loss_averager.add(loss.item())
                val_acc_averager.add(acc)

            # Read the validation averages
            val_loss = val_loss_averager.item()
            val_acc = val_acc_averager.item()
            # Write the tensorboardX records
            writer.add_scalar('data/val_loss', float(val_loss), epoch)
            writer.add_scalar('data/val_acc', float(val_acc), epoch)
            # Print loss and accuracy for this epoch
            print('Epoch {}, Val, Loss={:.4f} Acc={:.4f}'.format(epoch, val_loss, val_acc))

            # Update best saved model
            if val_acc > trlog['max_acc']:
                trlog['max_acc'] = val_acc
                trlog['max_acc_epoch'] = epoch
                self.save_model('max_acc')
            # Save model every 10 epochs
            if epoch % 10 == 0:
                self.save_model('epoch'+str(epoch))

            # Update the logs
            trlog['train_loss'].append(train_loss)
            trlog['train_acc'].append(train_acc)
            trlog['val_loss'].append(val_loss)
            trlog['val_acc'].append(val_acc)

            # Save log
            torch.save(trlog, osp.join(self.args.save_path, 'trlog'))

            if epoch % 10 == 0:
                # The progress ratio uses pre_max_epoch, the bound of this loop
                # (it used to read args.max_epoch)
                print('Running Time: {}, Estimated Time: {}'.format(timer.measure(), timer.measure(epoch / self.args.pre_max_epoch)))
        writer.close()
214 |
--------------------------------------------------------------------------------
/tensorflow/utils/misc.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Yaoyao Liu
3 | ## Modified from: https://github.com/cbfinn/maml
4 | ## Tianjin University
5 | ## liuyaoyao@tju.edu.cn
6 | ## Copyright (c) 2019
7 | ##
8 | ## This source code is licensed under the MIT-style license found in the
9 | ## LICENSE file in the root directory of this source tree
10 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
11 |
12 | """ Additional utility functions. """
13 | import numpy as np
14 | import os
15 | import cv2
16 | import random
17 | import tensorflow as tf
18 |
19 | from matplotlib.pyplot import imread
20 | from tensorflow.contrib.layers.python import layers as tf_layers
21 | from tensorflow.python.platform import flags
22 |
23 | FLAGS = flags.FLAGS
24 |
def get_smallest_k_index(input_, k):
    """The function to get the smallest k items' indices.
    Args:
      input_: the one-dimensional list (or array) to be processed.
      k: the number of indices to return.
    Return:
      A list with the indices of the k smallest items, ordered by ascending
      value. Equal values keep their original order. Every index appears at
      most once, so at most len(input_) indices are returned.
    """
    # The earlier implementation overwrote each picked item with the current
    # maximum. As soon as a picked item tied with that maximum, argmin returned
    # an already-picked index again (e.g. [1, 3] with k=2 produced [0, 0]).
    # A stable sort visits every index exactly once and keeps the
    # first-occurrence order for ties, which matches repeated argmin calls.
    order = np.argsort(np.asarray(input_), kind='mergesort')
    # max() keeps a non-positive k returning an empty list, as range(k) did.
    return list(order[:max(k, 0)])
40 |
def one_hot(inp):
    """The function to convert integer labels to one-hot vectors.
    Arg:
      inp: the input numpy array of integer class indices.
    Return:
      A float array with one row per input item. The number of columns is the
      largest label in inp plus one, and each row holds a single 1 at the
      column given by the label.
    """
    num_classes = inp.max() + 1
    num_samples = inp.shape[0]
    encoded = np.zeros((num_samples, num_classes))
    # Iterating the array walks its first axis, one label per output row
    for row, class_id in enumerate(inp):
        encoded[row, class_id] = 1
    return encoded
54 |
def one_hot_class(inp, n_class):
    """The function to convert integer labels to n-class one-hot vectors.
    Args:
      inp: the input numpy array of integer class indices.
      n_class: the number of classes, i.e. the width of the output.
    Return:
      A float array of shape (number of input items, n_class) where each row
      holds a single 1 at the column given by the label.
    """
    num_samples = inp.shape[0]
    encoded = np.zeros((num_samples, n_class))
    # Iterating the array walks its first axis, one label per output row
    for row, class_id in enumerate(inp):
        encoded[row, class_id] = 1
    return encoded
68 |
def process_batch(input_filename_list, input_label_list, dim_input, batch_sample_num):
    """The function to load and reorder the images of one part of an episode.
    Args:
      input_filename_list: the list of image file paths.
      input_label_list: the list of labels matching input_filename_list.
      dim_input: the flattened dimension number of one image.
      batch_sample_num: the number of samples per class.
    Returns:
      img_array: the numpy array of loaded images, one flattened image per row.
      label_array: the numpy array of one-hot labels, one row per image.
    """
    way_num = FLAGS.way_num
    ordered_paths, ordered_labels = [], []
    # Item k of class c is read from position c * batch_sample_num + k, so the
    # input lists are assumed to be grouped by class - TODO confirm with caller.
    for sample_pos in range(batch_sample_num):
        # Visit the classes in a fresh random order for every sample position
        shuffled_classes = list(range(0, way_num))
        random.shuffle(shuffled_classes)
        for this_class in shuffled_classes:
            flat_idx = this_class * batch_sample_num + sample_pos
            ordered_paths.append(input_filename_list[flat_idx])
            ordered_labels.append(input_label_list[flat_idx])

    # Flatten each image and divide its pixel values by 255
    images = [np.reshape(imread(path), [-1, dim_input]) / 255.0 for path in ordered_paths]

    total = way_num * batch_sample_num
    img_array = np.array(images).reshape([total, dim_input])
    label_array = one_hot(np.array(ordered_labels)).reshape([total, -1])
    return img_array, label_array
100 |
def process_batch_augmentation(input_filename_list, input_label_list, dim_input, batch_sample_num):
    """The function to load one part of an episode and augment it by flipping.
    Every image is returned twice: first all original images, then all images
    flipped with cv2.flip (flip code 1, i.e. a horizontal flip).
    Args:
      input_filename_list: the list of image file paths.
      input_label_list: the list of labels matching input_filename_list.
      dim_input: the flattened dimension number of one image.
      batch_sample_num: the number of samples per class.
    Returns:
      img_array: the numpy array of loaded images, originals followed by flips.
      label_array: the numpy array of one-hot labels, repeated for the flips.
    """
    way_num = FLAGS.way_num
    ordered_paths, ordered_labels = [], []
    # Item k of class c is read from position c * batch_sample_num + k, so the
    # input lists are assumed to be grouped by class - TODO confirm with caller.
    for sample_pos in range(batch_sample_num):
        # Visit the classes in a fresh random order for every sample position
        shuffled_classes = list(range(0, way_num))
        random.shuffle(shuffled_classes)
        for this_class in shuffled_classes:
            flat_idx = this_class * batch_sample_num + sample_pos
            ordered_paths.append(input_filename_list[flat_idx])
            ordered_labels.append(input_label_list[flat_idx])

    plain_images, flipped_images = [], []
    for path in ordered_paths:
        raw = imread(path)
        # Flip before flattening, while the image still has its 2-D layout
        mirrored = cv2.flip(raw, 1)
        plain_images.append(np.reshape(raw, [-1, dim_input]) / 255.0)
        flipped_images.append(np.reshape(mirrored, [-1, dim_input]) / 255.0)

    # The flipped copies double the number of rows
    total = way_num * batch_sample_num * 2
    img_array = np.array(plain_images + flipped_images).reshape([total, dim_input])
    label_array = one_hot(np.array(ordered_labels + ordered_labels)).reshape([total, -1])
    return img_array, label_array
140 |
141 |
def get_images(paths, labels, nb_samples=None, shuffle=True):
    """The function to list the image files of the given class folders.
    Args:
      paths: the class folders to list, one per label.
      labels: the class labels, paired with paths by position.
      nb_samples: the number of files to draw per folder; None keeps them all.
      shuffle: whether to shuffle the generated list.
    Return:
      The list of (label, file path) tuples.
    """
    def pick(filenames):
        # Keep the whole listing, or draw nb_samples names without replacement
        if nb_samples is None:
            return filenames
        return random.sample(filenames, nb_samples)

    images = []
    for label, folder in zip(labels, paths):
        for filename in pick(os.listdir(folder)):
            images.append((label, os.path.join(folder, filename)))
    if shuffle:
        random.shuffle(images)
    return images
162 |
def get_pretrain_images(path, label):
    """The function to list the image files of one class folder for pre-train phase.
    Args:
      path: the folder that holds the images of one class.
      label: the class label attached to every file in the folder.
    Return:
      The list of (label, file path) tuples, in os.listdir order.
    """
    return [(label, os.path.join(path, filename)) for filename in os.listdir(path)]
176 |
def get_images_tc(paths, labels, nb_samples=None, shuffle=True, is_val=False):
    """The function to list the image files of the given class folders for pre-train phase.
    The first 500 entries of each folder listing are used for training and the
    remaining entries for validation.
    Args:
      paths: the class folders to list, one per label.
      labels: the class labels, paired with paths by position.
      nb_samples: the number of files to draw per folder; None keeps them all.
      shuffle: whether to shuffle the generated list.
      is_val: whether the images are for the validation phase during pre-training.
    Return:
      The list of (label, file path) tuples.
    """
    def pick(filenames):
        # Keep the whole slice, or draw nb_samples names without replacement
        if nb_samples is None:
            return filenames
        return random.sample(filenames, nb_samples)

    # Only the literal False selects the training slice; any other value
    # (including None or 0) selects the validation slice.
    # NOTE(review): the split relies on os.listdir order, which the Python
    # documentation describes as arbitrary - confirm it is stable enough here.
    if is_val is False:
        split = slice(0, 500)
    else:
        split = slice(500, None)

    images = []
    for label, folder in zip(labels, paths):
        for filename in pick(os.listdir(folder)[split]):
            images.append((label, os.path.join(folder, filename)))
    if shuffle:
        random.shuffle(images)
    return images
203 |
204 |
205 | ## Network helpers
206 |
def leaky_relu(x, leak=0.1):
    """The leaky relu function.
    Args:
      x: the input feature maps.
      leak: the factor applied to x to build the second operand of the maximum.
    Return:
      The element-wise maximum of x and leak * x.
    """
    scaled = leak * x
    return tf.maximum(x, scaled)
216 |
def resnet_conv_block(inp, cweight, bweight, reuse, scope, activation=leaky_relu):
    """The function to forward a conv layer with biases, normalization and activation.
    Args:
      inp: the input feature maps.
      cweight: the filters' weights for this conv layer.
      bweight: the biases' weights for this conv layer.
      reuse: whether reuse the variables for the normalization.
      scope: the label for this conv layer.
      activation: kept for the signature; the value passed in is always
        replaced by the choice made through FLAGS.activation.
    Return:
      The processed feature maps.
    """
    # FLAGS.activation decides the non-linearity; any value other than
    # 'leaky_relu' or 'relu' disables it (None).
    available_activations = {'leaky_relu': leaky_relu, 'relu': tf.nn.relu}
    activation = available_activations.get(FLAGS.activation)

    # Stride-1 convolution with 'SAME' padding, followed by the bias
    unit_strides = [1, 1, 1, 1]
    conv_output = tf.nn.conv2d(inp, cweight, unit_strides, 'SAME') + bweight
    return normalize(conv_output, activation, reuse, scope)
242 |
def resnet_nob_conv_block(inp, cweight, reuse, scope):
    """The function to forward a conv layer without biases, normalization and non-linear layer.
    Args:
      inp: the input feature maps.
      cweight: the filters' weights for this conv layer.
      reuse: not used by this function.
      scope: not used by this function.
    Return:
      The feature maps after a stride-1 convolution with 'SAME' padding.
    """
    unit_strides = [1, 1, 1, 1]
    return tf.nn.conv2d(inp, cweight, unit_strides, 'SAME')
256 |
def normalize(inp, activation, reuse, scope):
    """The function to forward the normalization selected by FLAGS.norm.
    Args:
      inp: the input feature maps.
      activation: the activation function for this layer, or None.
      reuse: whether reuse the variables for the normalization layer.
      scope: the label for this layer.
    Return:
      The processed feature maps.
    Raise:
      ValueError: if FLAGS.norm is not 'batch_norm', 'layer_norm' or 'None'.
    """
    norm_type = FLAGS.norm
    # The string 'None' means no normalization: only the activation is applied
    if norm_type == 'None':
        return inp if activation is None else activation(inp)
    if norm_type == 'batch_norm':
        return tf_layers.batch_norm(inp, activation_fn=activation, reuse=reuse, scope=scope)
    if norm_type == 'layer_norm':
        return tf_layers.layer_norm(inp, activation_fn=activation, reuse=reuse, scope=scope)
    raise ValueError('Please set correct normalization.')
277 |
278 | ## Loss functions
279 |
def mse(pred, label):
    """The MSE loss function.
    Args:
      pred: the predictions.
      label: the ground truth labels.
    Return:
      The mean of the squared differences over all flattened elements.
    """
    flat_pred = tf.reshape(pred, [-1])
    flat_label = tf.reshape(label, [-1])
    squared_error = tf.square(flat_pred - flat_label)
    return tf.reduce_mean(squared_error)
291 |
def softmaxloss(pred, label):
    """The softmax cross entropy loss function, averaged over all samples.
    Args:
      pred: the predictions (logits).
      label: the ground truth labels.
    Return:
      The mean softmax cross entropy loss.
    """
    per_sample_loss = tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=label)
    return tf.reduce_mean(per_sample_loss)
301 |
def xent(pred, label):
    """The softmax cross entropy loss function. The losses will be normalized by the shot number.
    Args:
      pred: the predictions (logits).
      label: the ground truth labels.
    Return:
      The softmax cross entropy losses divided by FLAGS.shot_num; no mean is
      taken over the samples.
    Note: with tf version <=0.12, this loss has incorrect 2nd derivatives
    """
    per_sample_loss = tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=label)
    return per_sample_loss / FLAGS.shot_num
312 |
--------------------------------------------------------------------------------
/tensorflow/models/resnet12.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Yaoyao Liu
3 | ## Modified from: https://github.com/cbfinn/maml
4 | ## Tianjin University
5 | ## liuyaoyao@tju.edu.cn
6 | ## Copyright (c) 2019
7 | ##
8 | ## This source code is licensed under the MIT-style license found in the
9 | ## LICENSE file in the root directory of this source tree
10 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
11 |
12 | """ ResNet-12 class. """
13 | import numpy as np
14 | import tensorflow as tf
15 | from tensorflow.python.platform import flags
16 | from utils.misc import mse, softmaxloss, xent, resnet_conv_block, resnet_nob_conv_block
17 |
18 | FLAGS = flags.FLAGS
19 |
class Models:
    """The class that contains the code for the basic resnet models and SS weights"""

    def __init__(self):
        # Load the image size from FLAGS and use 3 channels
        self.img_size = FLAGS.img_size
        self.channels = 3
        # Dimension number of one flattened input image
        self.dim_input = self.img_size * self.img_size * self.channels
        # Dimension number for the outputs: one per class of an episode
        self.dim_output = FLAGS.way_num
        # Load the base learning rate and the pre-train class number from FLAGS
        self.update_lr = FLAGS.base_lr
        self.pretrain_class_num = FLAGS.pretrain_class_num
        # Learning rates that default to the FLAGS values but can be fed
        self.meta_lr = tf.placeholder_with_default(FLAGS.meta_lr, ())
        self.pretrain_lr = tf.placeholder_with_default(FLAGS.pre_lr, ())
        # Default objective functions for meta-train and pre-train
        self.loss_func = xent
        self.pretrain_loss_func = softmaxloss

    def process_ss_weights(self, weights, ss_weights, label):
        """The function to process the scaling operation
        Args:
          weights: the weights for the resnet.
          ss_weights: the weights for scaling and shifting operation.
          label: the label to indicate which layer we are operating.
        Return:
          The processed weights for the new resnet.
        """
        conv_kernel = weights[label]
        kernel_height, kernel_width = conv_kernel.get_shape().as_list()[0:2]
        # Repeat the scaling weights over the two leading kernel dimensions
        tiled_scale = tf.tile(ss_weights[label], multiples=[kernel_height, kernel_width, 1, 1])
        return tf.multiply(conv_kernel, tiled_scale)

    def forward_pretrain_resnet(self, inp, weights, reuse=False, scope=''):
        """The function to forward the resnet during pre-train phase
        Args:
          inp: input feature maps.
          weights: input resnet weights.
          reuse: reuse the batch norm weights or not.
          scope: the label to indicate which layer we are processing.
        Return:
          The processed feature maps, flattened to one row per image.
        """
        net = tf.reshape(inp, [-1, self.img_size, self.img_size, self.channels])
        for block_name in ('block1', 'block2', 'block3', 'block4'):
            net = self.pretrain_block_forward(net, weights, block_name, reuse, scope)
        # 5x5 average pooling with stride 5, then flatten
        net = tf.nn.avg_pool(net, [1,5,5,1], [1,5,5,1], 'VALID')
        flat_dim = np.prod([int(dim) for dim in net.get_shape()[1:]])
        return tf.reshape(net, [-1, flat_dim])

    def forward_resnet(self, inp, weights, ss_weights, reuse=False, scope=''):
        """The function to forward the resnet during meta-train phase
        Args:
          inp: input feature maps.
          weights: input resnet weights.
          ss_weights: input scaling weights.
          reuse: reuse the batch norm weights or not.
          scope: the label to indicate which layer we are processing.
        Return:
          The processed feature maps, flattened to one row per image.
        """
        net = tf.reshape(inp, [-1, self.img_size, self.img_size, self.channels])
        for block_name in ('block1', 'block2', 'block3', 'block4'):
            net = self.block_forward(net, weights, ss_weights, block_name, reuse, scope)
        # 5x5 average pooling with stride 5, then flatten
        net = tf.nn.avg_pool(net, [1,5,5,1], [1,5,5,1], 'VALID')
        flat_dim = np.prod([int(dim) for dim in net.get_shape()[1:]])
        return tf.reshape(net, [-1, flat_dim])

    def forward_fc(self, inp, fc_weights):
        """The function to forward the fc layer
        Args:
          inp: input feature maps.
          fc_weights: input fc weights, with keys 'w5' and 'b5'.
        Return:
          The processed feature maps.
        """
        projected = tf.matmul(inp, fc_weights['w5'])
        return projected + fc_weights['b5']

    def pretrain_block_forward(self, inp, weights, block, reuse, scope):
        """The function to forward a resnet block during pre-train phase
        Args:
          inp: input feature maps.
          weights: input resnet weights.
          block: the string to indicate which block we are processing.
          reuse: reuse the batch norm weights or not.
          scope: the label to indicate which layer we are processing.
        Return:
          The processed feature maps.
        """
        layer_keys = (('_conv1', '_bias1', '0'), ('_conv2', '_bias2', '1'), ('_conv3', '_bias3', '2'))
        net = inp
        # Three stacked conv layers on the main path
        for conv_key, bias_key, layer_tag in layer_keys:
            net = resnet_conv_block(net, weights[block + conv_key], weights[block + bias_key], \
                reuse, scope+block+layer_tag)
        # Shortcut path: a bias-free conv applied to the block input
        shortcut = resnet_nob_conv_block(inp, weights[block + '_conv_res'], reuse, scope+block+'res')
        net = net + shortcut
        net = tf.nn.max_pool(net, [1,2,2,1], [1,2,2,1], 'VALID')
        return tf.nn.dropout(net, keep_prob=FLAGS.pretrain_dropout_keep)

    def block_forward(self, inp, weights, ss_weights, block, reuse, scope):
        """The function to forward a resnet block during meta-train phase
        Args:
          inp: input feature maps.
          weights: input resnet weights.
          ss_weights: input scaling weights.
          block: the string to indicate which block we are processing.
          reuse: reuse the batch norm weights or not.
          scope: the label to indicate which layer we are processing.
        Return:
          The processed feature maps.
        """
        layer_keys = (('_conv1', '_bias1', '0'), ('_conv2', '_bias2', '1'), ('_conv3', '_bias3', '2'))
        net = inp
        # Main path: the conv kernels are scaled by the SS weights and the
        # biases are taken from the SS weights
        for conv_key, bias_key, layer_tag in layer_keys:
            scaled_kernel = self.process_ss_weights(weights, ss_weights, block + conv_key)
            net = resnet_conv_block(net, scaled_kernel, ss_weights[block + bias_key], \
                reuse, scope+block+layer_tag)
        # Shortcut path: uses the plain resnet weights without scaling
        shortcut = resnet_nob_conv_block(inp, weights[block + '_conv_res'], reuse, scope+block+'res')
        net = net + shortcut
        net = tf.nn.max_pool(net, [1,2,2,1], [1,2,2,1], 'VALID')
        return tf.nn.dropout(net, keep_prob=1)

    def construct_fc_weights(self):
        """The function to construct fc weights.
        Return:
          The fc weights, a dict with the keys 'w5' and 'b5'.
        """
        fc_initializer = tf.contrib.layers.xavier_initializer(dtype=tf.float32)
        # The pre-train phase classifies over all pre-train classes, every
        # other phase over the classes of one episode
        if FLAGS.phase=='pre':
            output_dim = FLAGS.pretrain_class_num
        else:
            output_dim = self.dim_output
        fc_weights = {}
        fc_weights['w5'] = tf.get_variable('fc_w5', [512, output_dim], initializer=fc_initializer)
        fc_weights['b5'] = tf.Variable(tf.zeros([output_dim]), name='fc_b5')
        return fc_weights

    def construct_resnet_weights(self):
        """The function to construct resnet weights.
        Return:
          The resnet weights.
        """
        dtype = tf.float32
        conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype)
        fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype)
        # (block label, input channels, output channels) for the four blocks
        block_plan = (('block1', 3, 64), ('block2', 64, 128), ('block3', 128, 256), ('block4', 256, 512))
        weights = {}
        for block_name, in_channels, out_channels in block_plan:
            weights = self.construct_residual_block_weights(weights, 3, in_channels, out_channels, \
                conv_initializer, dtype, block_name)
        weights['w5'] = tf.get_variable('w5', [512, FLAGS.pretrain_class_num], initializer=fc_initializer)
        weights['b5'] = tf.Variable(tf.zeros([FLAGS.pretrain_class_num]), name='b5')
        return weights

    def construct_residual_block_weights(self, weights, k, last_dim_hidden, dim_hidden, conv_initializer, dtype, scope='block0'):
        """The function to construct one block of the resnet weights.
        Args:
          weights: the resnet weight dict, extended in place.
          k: the dimension number of the convolution kernel.
          last_dim_hidden: the hidden dimension number of last block.
          dim_hidden: the hidden dimension number of the block.
          conv_initializer: the convolution initializer.
          dtype: the dtype for the convolution variables.
          scope: the label to indicate which block we are processing.
        Return:
          The resnet block weights.
        """
        # Only the first conv layer takes the previous block's channel number
        layer_plan = (('_conv1', '_bias1', last_dim_hidden), \
            ('_conv2', '_bias2', dim_hidden), \
            ('_conv3', '_bias3', dim_hidden))
        for conv_key, bias_key, in_channels in layer_plan:
            weights[scope + conv_key] = tf.get_variable(scope + conv_key, [k, k, in_channels, dim_hidden], \
                initializer=conv_initializer, dtype=dtype)
            weights[scope + bias_key] = tf.Variable(tf.zeros([dim_hidden]), name=scope + bias_key)
        # 1x1 conv for the shortcut path
        weights[scope + '_conv_res'] = tf.get_variable(scope + '_conv_res', [1, 1, last_dim_hidden, dim_hidden], \
            initializer=conv_initializer, dtype=dtype)
        return weights

    def construct_resnet_ss_weights(self):
        """The function to construct ss weights.
        Return:
          The ss weights.
        """
        # (block label, input channels, output channels) for the four blocks
        block_plan = (('block1', 3, 64), ('block2', 64, 128), ('block3', 128, 256), ('block4', 256, 512))
        ss_weights = {}
        for block_name, in_channels, out_channels in block_plan:
            ss_weights = self.construct_residual_block_ss_weights(ss_weights, in_channels, out_channels, block_name)
        return ss_weights

    def construct_residual_block_ss_weights(self, ss_weights, last_dim_hidden, dim_hidden, scope='block0'):
        """The function to construct one block ss weights.
        Args:
          ss_weights: the ss weight dict, extended in place.
          last_dim_hidden: the hidden dimension number of last block.
          dim_hidden: the hidden dimension number of the block.
          scope: the label to indicate which block we are processing.
        Return:
          The ss block weights.
        """
        # Scaling weights start at one and shifting weights (biases) at zero
        layer_plan = (('_conv1', '_bias1', last_dim_hidden), \
            ('_conv2', '_bias2', dim_hidden), \
            ('_conv3', '_bias3', dim_hidden))
        for conv_key, bias_key, in_channels in layer_plan:
            ss_weights[scope + conv_key] = tf.Variable(tf.ones([1, 1, in_channels, dim_hidden]), name=scope + conv_key)
            ss_weights[scope + bias_key] = tf.Variable(tf.zeros([dim_hidden]), name=scope + bias_key)
        return ss_weights
240 |
--------------------------------------------------------------------------------
/tensorflow/models/meta_model.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Yaoyao Liu
3 | ## Modified from: https://github.com/cbfinn/maml
4 | ## Tianjin University
5 | ## liuyaoyao@tju.edu.cn
6 | ## Copyright (c) 2019
7 | ##
8 | ## This source code is licensed under the MIT-style license found in the
9 | ## LICENSE file in the root directory of this source tree
10 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
11 |
12 | """ Models for meta-learning. """
13 | import tensorflow as tf
14 | from tensorflow.python.platform import flags
15 | from utils.misc import mse, softmaxloss, xent, resnet_conv_block, resnet_nob_conv_block
16 |
17 | FLAGS = flags.FLAGS
18 |
19 | def MakeMetaModel():
20 | """The function to make meta model.
21 | Arg:
22 | Meta-train model class.
23 | """
24 | if FLAGS.backbone_arch=='resnet12':
25 | try:#python2
26 | from resnet12 import Models
27 | except ImportError:#python3
28 | from models.resnet12 import Models
29 | elif FLAGS.backbone_arch=='resnet18':
30 | try:#python2
31 | from resnet18 import Models
32 | except ImportError:#python3
33 | from models.resnet18 import Models
34 | else:
35 | print('Please set the correct backbone')
36 |
37 | class MetaModel(Models):
38 | """The class for the meta models. This class is inheritance from Models, so some variables are in the Models class."""
39 | def construct_model(self):
40 | """The function to construct meta-train model."""
41 | # Set the placeholder for the input episode
42 | self.inputa = tf.placeholder(tf.float32) # episode train images
43 | self.inputb = tf.placeholder(tf.float32) # episode test images
44 | self.labela = tf.placeholder(tf.float32) # episode train labels
45 | self.labelb = tf.placeholder(tf.float32) # episode test labels
46 |
47 | with tf.variable_scope('meta-model', reuse=None) as training_scope:
48 | # construct the model weights
49 | self.ss_weights = ss_weights = self.construct_resnet_ss_weights()
50 | self.weights = weights = self.construct_resnet_weights()
51 | self.fc_weights = fc_weights = self.construct_fc_weights()
52 |
53 | # Load base epoch number from FLAGS
54 | num_updates = FLAGS.train_base_epoch_num
55 |
56 | def task_metalearn(inp, reuse=True):
57 | """The function to process one episode in a meta-batch.
58 | Args:
59 | inp: the input episode.
60 | reuse: whether reuse the variables for the normalization.
61 | Returns:
62 | A serious outputs like losses and accuracies.
63 | """
64 | # Seperate inp to different variables
65 | inputa, inputb, labela, labelb = inp
66 | # Generate empty list to record losses
67 | lossa_list = [] # Base train loss list
68 | lossb_list = [] # Base test loss list
69 |
70 | # Embed the input images to embeddings with ss weights
71 | emb_outputa = self.forward_resnet(inputa, weights, ss_weights, reuse=reuse) # Embed episode train
72 | emb_outputb = self.forward_resnet(inputb, weights, ss_weights, reuse=True) # Embed episode test
73 |
74 | # Run the first epoch of the base learning
75 | # Forward fc layer for episode train
76 | outputa = self.forward_fc(emb_outputa, fc_weights)
77 | # Calculate base train loss
78 | lossa = self.loss_func(outputa, labela)
79 | # Record base train loss
80 | lossa_list.append(lossa)
81 | # Forward fc layer for episode test
82 | outputb = self.forward_fc(emb_outputb, fc_weights)
83 | # Calculate base test loss
84 | lossb = self.loss_func(outputb, labelb)
85 | # Record base test loss
86 | lossb_list.append(lossb)
87 | # Calculate the gradients for the fc layer
88 | grads = tf.gradients(lossa, list(fc_weights.values()))
89 | gradients = dict(zip(fc_weights.keys(), grads))
90 | # Use graient descent to update the fc layer
91 | fast_fc_weights = dict(zip(fc_weights.keys(), [fc_weights[key] - \
92 | self.update_lr*gradients[key] for key in fc_weights.keys()]))
93 |
94 | for j in range(num_updates - 1):
95 | # Run the following base epochs, these are similar to the first base epoch
96 | lossa = self.loss_func(self.forward_fc(emb_outputa, fast_fc_weights), labela)
97 | lossa_list.append(lossa)
98 | lossb = self.loss_func(self.forward_fc(emb_outputb, fast_fc_weights), labelb)
99 | lossb_list.append(lossb)
100 | grads = tf.gradients(lossa, list(fast_fc_weights.values()))
101 | gradients = dict(zip(fast_fc_weights.keys(), grads))
102 | fast_fc_weights = dict(zip(fast_fc_weights.keys(), [fast_fc_weights[key] - \
103 | self.update_lr*gradients[key] for key in fast_fc_weights.keys()]))
104 |
105 | # Calculate final episode test predictions
106 | outputb = self.forward_fc(emb_outputb, fast_fc_weights)
107 | # Calculate the final episode test loss, it is the loss for the episode on meta-train
108 | final_lossb = self.loss_func(outputb, labelb)
109 | # Calculate the final episode test accuarcy
110 | accb = tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(outputb), 1), tf.argmax(labelb, 1))
111 |
112 | # Reorganize all the outputs to a list
113 | task_output = [final_lossb, lossb_list, lossa_list, accb]
114 |
115 | return task_output
116 |
117 | # Initial the batch normalization weights
118 | if FLAGS.norm is not None:
119 | unused = task_metalearn((self.inputa[0], self.inputb[0], self.labela[0], self.labelb[0]), False)
120 |
121 | # Set the dtype of the outputs
122 | out_dtype = [tf.float32, [tf.float32]*num_updates, [tf.float32]*num_updates, tf.float32]
123 |
124 | # Run two episodes for a meta batch using parallel setting
125 | result = tf.map_fn(task_metalearn, elems=(self.inputa, self.inputb, self.labela, self.labelb), \
126 | dtype=out_dtype, parallel_iterations=FLAGS.meta_batch_size)
127 | # Seperate the outputs to different variables
128 | lossb, lossesb, lossesa, accsb = result
129 |
130 | # Set the variables to output from the tensorflow graph
131 | self.total_loss = total_loss = tf.reduce_sum(lossb) / tf.to_float(FLAGS.meta_batch_size)
132 | self.total_accuracy = total_accuracy = tf.reduce_sum(accsb) / tf.to_float(FLAGS.meta_batch_size)
133 | self.total_lossa = total_lossa = [tf.reduce_sum(lossesa[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)]
134 | self.total_lossb = total_lossb = [tf.reduce_sum(lossesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)]
135 |
136 | # Set the meta-train optimizer
137 | optimizer = tf.train.AdamOptimizer(self.meta_lr)
138 | self.metatrain_op = optimizer.minimize(total_loss, var_list=list(ss_weights.values()) + list(fc_weights.values()))
139 |
140 | # Set the tensorboard
141 | self.training_summaries = []
142 | self.training_summaries.append(tf.summary.scalar('Meta Train Loss', (total_loss / tf.to_float(FLAGS.metatrain_epite_sample_num))))
143 | self.training_summaries.append(tf.summary.scalar('Meta Train Accuracy', total_accuracy))
144 | for j in range(num_updates):
145 | self.training_summaries.append(tf.summary.scalar('Base Train Loss Step' + str(j+1), total_lossa[j]))
146 | for j in range(num_updates):
147 | self.training_summaries.append(tf.summary.scalar('Base Val Loss Step' + str(j+1), total_lossb[j]))
148 |
149 | self.training_summ_op = tf.summary.merge(self.training_summaries)
150 |
151 | self.input_val_loss = tf.placeholder(tf.float32)
152 | self.input_val_acc = tf.placeholder(tf.float32)
153 | self.val_summaries = []
154 | self.val_summaries.append(tf.summary.scalar('Meta Val Loss', self.input_val_loss))
155 | self.val_summaries.append(tf.summary.scalar('Meta Val Accuracy', self.input_val_acc))
156 | self.val_summ_op = tf.summary.merge(self.val_summaries)
157 |
def construct_test_model(self):
    """The function to construct meta-test model.

    Creates the episode placeholders and the model weights, builds the
    base-learning graph for every episode with tf.map_fn, and stores the
    summed query loss and accuracies on self as metaval_total_loss,
    metaval_total_accuracy and metaval_total_accuracies.
    """
    # Placeholders that receive the episodes of one meta-batch
    self.inputa = tf.placeholder(tf.float32)
    self.inputb = tf.placeholder(tf.float32)
    self.labela = tf.placeholder(tf.float32)
    self.labelb = tf.placeholder(tf.float32)

    with tf.variable_scope('meta-test-model', reuse=None):
        # Build the model weights and keep a handle to each group on self
        ss_weights = self.construct_resnet_ss_weights()
        self.ss_weights = ss_weights
        weights = self.construct_resnet_weights()
        self.weights = weights
        fc_weights = self.construct_fc_weights()
        self.fc_weights = fc_weights

        # Number of base-learning steps used at test time
        num_updates = FLAGS.test_base_epoch_num

        def task_metalearn(inp, reuse=True):
            """The function to process one episode in a meta-batch.
            Args:
              inp: the input episode (inputa, inputb, labela, labelb).
              reuse: whether reuse the variables for the normalization.
            Returns:
              A list [lossb, accb, accb_list]: the loss and accuracy on the
              'b' split after the last base step, and the accuracy after
              every base step.
            """
            # Separate inp into its four parts
            inputa, inputb, labela, labelb = inp

            # Embed the input images with the resnet and the ss weights
            emb_outputa = self.forward_resnet(inputa, weights, ss_weights, reuse=reuse)
            emb_outputb = self.forward_resnet(inputb, weights, ss_weights, reuse=True)

            def base_step(current_fc_weights):
                """Return the FC weights after one gradient step on the 'a' split loss."""
                step_loss = self.loss_func(self.forward_fc(emb_outputa, current_fc_weights), labela)
                step_grads = tf.gradients(step_loss, list(current_fc_weights.values()))
                grad_by_key = dict(zip(current_fc_weights.keys(), step_grads))
                updated = [current_fc_weights[key] - self.update_lr * grad_by_key[key]
                           for key in current_fc_weights.keys()]
                return dict(zip(current_fc_weights.keys(), updated))

            def evaluate_b(current_fc_weights):
                """Return the logits and the accuracy on the 'b' split for the given FC weights."""
                logits_b = self.forward_fc(emb_outputb, current_fc_weights)
                accuracy_b = tf.contrib.metrics.accuracy(
                    tf.argmax(tf.nn.softmax(logits_b), 1), tf.argmax(labelb, 1))
                return logits_b, accuracy_b

            # The first base step starts from the meta-learned FC weights
            fast_fc_weights = base_step(fc_weights)
            outputb, accb = evaluate_b(fast_fc_weights)
            accb_list = [accb]

            # The remaining base steps continue from the already adapted weights
            for _ in range(num_updates - 1):
                fast_fc_weights = base_step(fast_fc_weights)
                outputb, accb = evaluate_b(fast_fc_weights)
                accb_list.append(accb)

            # Loss on the 'b' split after the last base step
            lossb = self.loss_func(outputb, labelb)

            return [lossb, accb, accb_list]

        # Run the first episode once with reuse=False before tf.map_fn, which calls with reuse=True
        if FLAGS.norm is not None:
            unused = task_metalearn((self.inputa[0], self.inputb[0], self.labela[0], self.labelb[0]), False)

        out_dtype = [tf.float32, tf.float32, [tf.float32] * num_updates]

        lossesb, accsb, accsb_list = tf.map_fn(
            task_metalearn,
            elems=(self.inputa, self.inputb, self.labela, self.labelb),
            dtype=out_dtype,
            parallel_iterations=FLAGS.meta_batch_size)

    # Sum the per-episode results over the meta-batch
    self.metaval_total_loss = tf.reduce_sum(lossesb)
    self.metaval_total_accuracy = tf.reduce_sum(accsb)
    self.metaval_total_accuracies = [tf.reduce_sum(accsb_list[step]) for step in range(num_updates)]
231 |
232 | return MetaModel()
233 |
--------------------------------------------------------------------------------
/pytorch/trainer/meta.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Yaoyao Liu
3 | ## Modified from: https://github.com/Sha-Lab/FEAT
4 | ## Tianjin University
5 | ## liuyaoyao@tju.edu.cn
6 | ## Copyright (c) 2019
7 | ##
8 | ## This source code is licensed under the MIT-style license found in the
9 | ## LICENSE file in the root directory of this source tree
10 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
11 | """ Trainer for meta-train phase. """
12 | import os.path as osp
13 | import os
14 | import tqdm
15 | import numpy as np
16 | import torch
17 | import torch.nn.functional as F
18 | from torch.utils.data import DataLoader
19 | from dataloader.samplers import CategoriesSampler
20 | from models.mtl import MtlLearner
21 | from utils.misc import Averager, Timer, count_acc, compute_confidence_interval, ensure_path
22 | from tensorboardX import SummaryWriter
23 | from dataloader.dataset_loader import DatasetLoader as Dataset
24 |
class MetaTrainer(object):
    """The class that contains the code for the meta-train phase and meta-eval phase."""

    # Number of episodes sampled for meta-validation and for meta-test
    NUM_EVAL_EPISODES = 600

    def __init__(self, args):
        """Set up the output folder, data loaders, model, optimizer and pre-trained weights.
        Args:
          args: the options of the experiment; `args.save_path` is filled in here.
        """
        # Set the folder to save the records and checkpoints
        log_base_dir = './logs/'
        meta_base_dir = osp.join(log_base_dir, 'meta')
        # makedirs(exist_ok=True) creates both levels and does not fail when the folder
        # appears between an existence check and the creation
        os.makedirs(meta_base_dir, exist_ok=True)
        save_path1 = '_'.join([args.dataset, args.model_type, 'MTL'])
        # The step size appears twice ('_step' and '_stepsize'); the name is kept unchanged
        # so that checkpoints written by earlier runs are still found
        save_path2 = 'shot' + str(args.shot) + '_way' + str(args.way) + '_query' + str(args.train_query) + \
            '_step' + str(args.step_size) + '_gamma' + str(args.gamma) + '_lr1' + str(args.meta_lr1) + '_lr2' + str(args.meta_lr2) + \
            '_batch' + str(args.num_batch) + '_maxepoch' + str(args.max_epoch) + \
            '_baselr' + str(args.base_lr) + '_updatestep' + str(args.update_step) + \
            '_stepsize' + str(args.step_size) + '_' + args.meta_label
        args.save_path = meta_base_dir + '/' + save_path1 + '_' + save_path2
        ensure_path(args.save_path)

        # Set args to be shareable in the class
        self.args = args

        # Load meta-train set
        self.trainset = Dataset('train', self.args)
        self.train_sampler = CategoriesSampler(self.trainset.label, self.args.num_batch, self.args.way,
                                               self.args.shot + self.args.train_query)
        self.train_loader = DataLoader(dataset=self.trainset, batch_sampler=self.train_sampler,
                                       num_workers=8, pin_memory=True)

        # Load meta-val set
        self.valset = Dataset('val', self.args)
        self.val_sampler = CategoriesSampler(self.valset.label, self.NUM_EVAL_EPISODES, self.args.way,
                                             self.args.shot + self.args.val_query)
        self.val_loader = DataLoader(dataset=self.valset, batch_sampler=self.val_sampler,
                                     num_workers=8, pin_memory=True)

        # Build meta-transfer learning model
        self.model = MtlLearner(self.args)

        # Set optimizer: the encoder group uses meta_lr1, the base learner group uses meta_lr2
        self.optimizer = torch.optim.Adam(
            [{'params': filter(lambda p: p.requires_grad, self.model.encoder.parameters())},
             {'params': self.model.base_learner.parameters(), 'lr': self.args.meta_lr2}],
            lr=self.args.meta_lr1)
        # Set learning rate scheduler
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=self.args.step_size,
                                                            gamma=self.args.gamma)

        # Load pretrained model without FC classifier
        self.model_dict = self.model.state_dict()
        if self.args.init_weights is not None:
            pretrained_path = self.args.init_weights
        else:
            # Rebuild the folder name used by the pre-train phase from the pre-train options
            pre_base_dir = osp.join(log_base_dir, 'pre')
            pre_save_path1 = '_'.join([args.dataset, args.model_type])
            pre_save_path2 = 'batchsize' + str(args.pre_batch_size) + '_lr' + str(args.pre_lr) + '_gamma' + str(args.pre_gamma) + '_step' + \
                str(args.pre_step_size) + '_maxepoch' + str(args.pre_max_epoch)
            pre_save_path = pre_base_dir + '/' + pre_save_path1 + '_' + pre_save_path2
            pretrained_path = osp.join(pre_save_path, 'max_acc.pth')
        pretrained_dict = self._load_checkpoint(pretrained_path)['params']
        # The saved keys are prefixed with 'encoder.' and only the keys present in this model are kept
        pretrained_dict = {'encoder.' + k: v for k, v in pretrained_dict.items()}
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in self.model_dict}
        print(pretrained_dict.keys())
        self.model_dict.update(pretrained_dict)
        self.model.load_state_dict(self.model_dict)

        # Set model to GPU
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            self.model = self.model.cuda()

    @staticmethod
    def _load_checkpoint(path):
        """Load a file written by torch.save, mapping the tensors to CPU when no GPU is available."""
        map_location = None if torch.cuda.is_available() else 'cpu'
        return torch.load(path, map_location=map_location)

    @staticmethod
    def _episode_labels(way, repeat):
        """Return the LongTensor 0..way-1 repeated `repeat` times, on GPU when one is available."""
        labels = torch.arange(way).repeat(repeat)
        if torch.cuda.is_available():
            return labels.type(torch.cuda.LongTensor)
        return labels.type(torch.LongTensor)

    def _split_episode(self, batch):
        """Move the data of one batch to the GPU if available and split it into shot and query parts."""
        data = batch[0]
        if torch.cuda.is_available():
            data = data.cuda()
        p = self.args.shot * self.args.way
        return data[:p], data[p:]

    def _validate(self, label_shot, label_query):
        """Run the model over the whole meta-val loader.
        Args:
          label_shot: the labels of the shot samples of an episode.
          label_query: the labels of the query samples of an episode.
        Returns:
          The loss and the accuracy reported by the two Averager objects.
        """
        loss_averager = Averager()
        acc_averager = Averager()
        for batch in self.val_loader:
            data_shot, data_query = self._split_episode(batch)
            logits = self.model((data_shot, label_shot, data_query))
            loss_averager.add(F.cross_entropy(logits, label_query).item())
            acc_averager.add(count_acc(logits, label_query))
        return loss_averager.item(), acc_averager.item()

    def save_model(self, name):
        """The function to save checkpoints.
        Args:
          name: the name for saved checkpoint
        """
        torch.save(dict(params=self.model.state_dict()), osp.join(self.args.save_path, name + '.pth'))

    def train(self):
        """The function for the meta-train phase."""

        # Set the meta-train log
        trlog = {
            'args': vars(self.args),
            'train_loss': [],
            'val_loss': [],
            'train_acc': [],
            'val_acc': [],
            'max_acc': 0.0,
            'max_acc_epoch': 0,
        }

        # Set the timer
        timer = Timer()
        # Set global count to zero
        global_count = 0
        # Set tensorboardX
        writer = SummaryWriter(comment=self.args.save_path)

        # The labels only depend on the options, so they are generated once for all epochs
        label_shot = self._episode_labels(self.args.way, self.args.shot)
        train_label = self._episode_labels(self.args.way, self.args.train_query)
        val_label = self._episode_labels(self.args.way, self.args.val_query)

        try:
            # Start meta-train
            for epoch in range(1, self.args.max_epoch + 1):
                # Update learning rate
                # NOTE(review): PyTorch >= 1.1 expects scheduler.step() after optimizer.step();
                # the call is kept at its original place so the schedule of existing runs is
                # unchanged - confirm against the PyTorch version in use.
                self.lr_scheduler.step()
                # Set the model to train mode
                self.model.train()
                # Set averager classes to record training losses and accuracies
                train_loss_averager = Averager()
                train_acc_averager = Averager()

                # Using tqdm to read samples from train loader
                tqdm_gen = tqdm.tqdm(self.train_loader)
                for batch in tqdm_gen:
                    # Update global count number
                    global_count = global_count + 1
                    data_shot, data_query = self._split_episode(batch)
                    # Output logits for model
                    logits = self.model((data_shot, label_shot, data_query))
                    # Calculate meta-train loss and accuracy
                    loss = F.cross_entropy(logits, train_label)
                    acc = count_acc(logits, train_label)
                    # Write the tensorboardX records
                    writer.add_scalar('data/loss', float(loss), global_count)
                    writer.add_scalar('data/acc', float(acc), global_count)
                    # Print loss and accuracy for this step
                    tqdm_gen.set_description('Epoch {}, Loss={:.4f} Acc={:.4f}'.format(epoch, loss.item(), acc))

                    # Add loss and accuracy for the averagers
                    train_loss_averager.add(loss.item())
                    train_acc_averager.add(acc)

                    # Loss backwards and optimizer updates
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                # Read the training results of this epoch from the averagers
                train_loss = train_loss_averager.item()
                train_acc = train_acc_averager.item()

                # Start validation for this epoch, set model to eval mode
                self.model.eval()

                # Print previous information
                if epoch % 10 == 0:
                    print('Best Epoch {}, Best Val Acc={:.4f}'.format(trlog['max_acc_epoch'], trlog['max_acc']))
                # Run meta-validation
                val_loss, val_acc = self._validate(label_shot, val_label)
                # Write the tensorboardX records
                writer.add_scalar('data/val_loss', float(val_loss), epoch)
                writer.add_scalar('data/val_acc', float(val_acc), epoch)
                # Print loss and accuracy for this epoch
                print('Epoch {}, Val, Loss={:.4f} Acc={:.4f}'.format(epoch, val_loss, val_acc))

                # Update best saved model
                if val_acc > trlog['max_acc']:
                    trlog['max_acc'] = val_acc
                    trlog['max_acc_epoch'] = epoch
                    self.save_model('max_acc')
                # Save model every 10 epochs
                if epoch % 10 == 0:
                    self.save_model('epoch' + str(epoch))

                # Update the logs
                trlog['train_loss'].append(train_loss)
                trlog['train_acc'].append(train_acc)
                trlog['val_loss'].append(val_loss)
                trlog['val_acc'].append(val_acc)

                # Save log
                torch.save(trlog, osp.join(self.args.save_path, 'trlog'))

                if epoch % 10 == 0:
                    print('Running Time: {}, Estimated Time: {}'.format(timer.measure(), timer.measure(epoch / self.args.max_epoch)))
        finally:
            # Close the writer even when training is interrupted, so the records are not lost
            writer.close()

    def eval(self):
        """The function for the meta-eval phase."""
        # Load the logs
        trlog = self._load_checkpoint(osp.join(self.args.save_path, 'trlog'))

        # Load meta-test set
        test_set = Dataset('test', self.args)
        sampler = CategoriesSampler(test_set.label, self.NUM_EVAL_EPISODES, self.args.way,
                                    self.args.shot + self.args.val_query)
        loader = DataLoader(test_set, batch_sampler=sampler, num_workers=8, pin_memory=True)

        # Set test accuracy recorder, one entry for each sampled episode
        test_acc_record = np.zeros((self.NUM_EVAL_EPISODES,))

        # Load model for meta-test phase
        if self.args.eval_weights is not None:
            weights_path = self.args.eval_weights
        else:
            weights_path = osp.join(self.args.save_path, 'max_acc' + '.pth')
        self.model.load_state_dict(self._load_checkpoint(weights_path)['params'])
        # Set model to eval mode
        self.model.eval()

        # Set accuracy averager
        ave_acc = Averager()

        # Generate labels
        label = self._episode_labels(self.args.way, self.args.val_query)
        label_shot = self._episode_labels(self.args.way, self.args.shot)

        # Start meta-test
        for i, batch in enumerate(loader, 1):
            data_shot, data_query = self._split_episode(batch)
            logits = self.model((data_shot, label_shot, data_query))
            acc = count_acc(logits, label)
            ave_acc.add(acc)
            test_acc_record[i - 1] = acc
            if i % 100 == 0:
                print('batch {}: {:.2f}({:.2f})'.format(i, ave_acc.item() * 100, acc * 100))

        # Calculate the confidence interval, update the logs
        m, pm = compute_confidence_interval(test_acc_record)
        print('Val Best Epoch {}, Acc {:.4f}, Test Acc {:.4f}'.format(trlog['max_acc_epoch'], trlog['max_acc'], ave_acc.item()))
        print('Test Acc {:.4f} + {:.4f}'.format(m, pm))
294 |
--------------------------------------------------------------------------------
/tensorflow/trainer/meta.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Yaoyao Liu
3 | ## Modified from: https://github.com/cbfinn/maml
4 | ## Tianjin University
5 | ## liuyaoyao@tju.edu.cn
6 | ## Copyright (c) 2019
7 | ##
8 | ## This source code is licensed under the MIT-style license found in the
9 | ## LICENSE file in the root directory of this source tree
10 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
11 |
12 | """ Trainer for meta-learning. """
13 | import os
14 | import csv
15 | import pickle
16 | import random
17 | import numpy as np
18 | import tensorflow as tf
19 |
20 | from tqdm import trange
21 | from data_generator.meta_data_generator import MetaDataGenerator
22 | from models.meta_model import MakeMetaModel
23 | from tensorflow.python.platform import flags
24 | from utils.misc import process_batch
25 |
26 | FLAGS = flags.FLAGS
27 |
class MetaTrainer:
    """The class that contains the code for the meta-train and meta-test."""
    def __init__(self):
        """Build the model, load its weights and run the phase selected by FLAGS.metatrain."""
        # Remove the saved datalist for a new experiment
        self._clear_processed_data('./logs/processed_data')
        data_generator = MetaDataGenerator()
        if FLAGS.metatrain:
            # Build model for meta-train phase
            print('Building meta-train model')
            self.model = MakeMetaModel()
            self.model.construct_model()
            print('Meta-train model is built')
            # Start tensorflow session
            self.start_session()
            # Generate data for meta-train phase; the seeds are only fixed when the
            # downloaded weights are used
            if FLAGS.load_saved_weights:
                random.seed(5)
            data_generator.generate_data(data_type='train')
            if FLAGS.load_saved_weights:
                random.seed(7)
            data_generator.generate_data(data_type='test')
            if FLAGS.load_saved_weights:
                random.seed(9)
            data_generator.generate_data(data_type='val')
        else:
            # Build model for meta-test phase
            print('Building meta-test mdoel')
            self.model = MakeMetaModel()
            self.model.construct_test_model()
            self.model.summ_op = tf.summary.merge_all()
            print('Meta-test model is built')
            # Start tensorflow session
            self.start_session()
            # Generate data for meta-test phase
            if FLAGS.load_saved_weights:
                random.seed(7)
            data_generator.generate_data(data_type='test')
        # Load the experiment setting string from FLAGS
        exp_string = FLAGS.exp_string

        # Global initialization and starting queue
        tf.global_variables_initializer().run()
        tf.train.start_queue_runners()

        if FLAGS.metatrain:
            self._load_metatrain_init_weights()
        else:
            self._load_metatest_weights(exp_string)

        if FLAGS.metatrain:
            self.train(data_generator)
        else:
            self.test(data_generator)

    @staticmethod
    def _clear_processed_data(directory):
        """Delete every non-hidden entry inside `directory` and keep the directory itself.

        This replaces the shell command 'rm -r <directory>/*', so it does not depend on a
        shell and does nothing when the directory is missing or empty.
        """
        import glob
        import shutil
        for entry in glob.glob(os.path.join(directory, '*')):
            if os.path.isdir(entry) and not os.path.islink(entry):
                shutil.rmtree(entry)
            else:
                os.remove(entry)

    @staticmethod
    def _load_weight_dict(path):
        """Load a weight dictionary that was stored with np.save."""
        # NOTE: allow_pickle=True unpickles the file content, so only weight files from a
        # trusted source should be loaded here
        return np.load(path, allow_pickle=True, encoding="latin1").tolist()

    @staticmethod
    def _accuracy_statistics(accuracies, num_points):
        """Return the mean, the standard deviation and the 95% confidence interval along axis 0.
        Args:
          accuracies: the accuracies with one row for each episode.
          num_points: the number of episodes used for the confidence interval.
        """
        means = np.mean(accuracies, 0)
        stds = np.std(accuracies, 0)
        ci95 = 1.96 * stds / np.sqrt(num_points)
        return means, stds, ci95

    def _assign_weights(self, variables, values, keys=None):
        """Assign the numpy `values` to the tensorflow `variables` for `keys` (default: all keys of values)."""
        for key in (values.keys() if keys is None else keys):
            self.sess.run(tf.assign(variables[key], values[key]))

    def _save_weights(self, exp_string, step):
        """Save the three weight groups of the model with `step` in the file names."""
        weights = self.sess.run(self.model.weights)
        ss_weights = self.sess.run(self.model.ss_weights)
        fc_weights = self.sess.run(self.model.fc_weights)
        np.save(FLAGS.logdir + '/' + exp_string + '/weights_' + str(step) + '.npy', weights)
        np.save(FLAGS.logdir + '/' + exp_string + '/ss_weights_' + str(step) + '.npy', ss_weights)
        np.save(FLAGS.logdir + '/' + exp_string + '/fc_weights_' + str(step) + '.npy', fc_weights)

    def _load_metatrain_init_weights(self):
        """Process initialization weights for meta-train."""
        init_dir = FLAGS.logdir_base + 'init_weights/'
        pre_save_str = FLAGS.pre_string
        this_init_dir = init_dir + pre_save_str + '.pre_iter(' + str(FLAGS.pretrain_iterations) + ')/'
        init_files = [this_init_dir + name for name in
                      ('weights_init.npy', 'ss_weights_init.npy', 'fc_weights_init.npy')]
        # The saved files are checked instead of the folder, so a folder left behind by an
        # interrupted run is filled again instead of failing on the missing files
        if not all(os.path.exists(path) for path in init_files):
            # If there is no saved initialization weights for meta-train, load pre-train model
            # and save initialization weights
            if not os.path.isdir(this_init_dir):
                os.makedirs(this_init_dir)
            if FLAGS.load_saved_weights:
                print('Loading downloaded pretrain weights')
                weights = self._load_weight_dict('logs/download_weights/weights.npy')
            else:
                print('Loading pretrain weights')
                weights_save_dir_base = FLAGS.pretrain_dir
                weights_save_dir = os.path.join(weights_save_dir_base, pre_save_str)
                weights = self._load_weight_dict(
                    os.path.join(weights_save_dir, "weights_{}.npy".format(FLAGS.pretrain_iterations)))
            bias_keys = [key for key in weights.keys() if '_bias' in key]
            # Assign the bias weights to ss model in order to train them during meta-train
            self._assign_weights(self.model.ss_weights, weights, bias_keys)
            # Assign pretrained weights to tensorflow variables
            self._assign_weights(self.model.weights, weights)
            print('Pretrain weights loaded, saving init weights')
            # Load and save init weights for the model
            new_weights = self.sess.run(self.model.weights)
            ss_weights = self.sess.run(self.model.ss_weights)
            fc_weights = self.sess.run(self.model.fc_weights)
            np.save(this_init_dir + 'weights_init.npy', new_weights)
            np.save(this_init_dir + 'ss_weights_init.npy', ss_weights)
            np.save(this_init_dir + 'fc_weights_init.npy', fc_weights)
        else:
            # If the initialization weights are already generated, load the previous saved ones
            # This process is deactivated in the default settings, you may activate this for ablative study
            print('Loading previous saved init weights')
            weights = self._load_weight_dict(this_init_dir + 'weights_init.npy')
            ss_weights = self._load_weight_dict(this_init_dir + 'ss_weights_init.npy')
            fc_weights = self._load_weight_dict(this_init_dir + 'fc_weights_init.npy')
            self._assign_weights(self.model.weights, weights)
            self._assign_weights(self.model.ss_weights, ss_weights)
            self._assign_weights(self.model.fc_weights, fc_weights)
            print('Init weights loaded')

    def _load_metatest_weights(self, exp_string):
        """Load the saved meta model for meta-test phase.
        Args:
          exp_string: the experiment setting string used in the folder name of the saved weights.
        """
        if FLAGS.load_saved_weights:
            # Load the downloaded weights
            weights = self._load_weight_dict('./logs/download_weights/weights.npy')
            ss_weights = self._load_weight_dict('./logs/download_weights/ss_weights.npy')
            fc_weights = self._load_weight_dict('./logs/download_weights/fc_weights.npy')
        else:
            # Load the saved weights of meta-train
            weights = self._load_weight_dict(
                FLAGS.logdir + '/' + exp_string + '/weights_' + str(FLAGS.test_iter) + '.npy')
            ss_weights = self._load_weight_dict(
                FLAGS.logdir + '/' + exp_string + '/ss_weights_' + str(FLAGS.test_iter) + '.npy')
            fc_weights = self._load_weight_dict(
                FLAGS.logdir + '/' + exp_string + '/fc_weights_' + str(FLAGS.test_iter) + '.npy')
        # Assign the weights to the tensorflow variables
        self._assign_weights(self.model.weights, weights)
        self._assign_weights(self.model.ss_weights, ss_weights)
        self._assign_weights(self.model.fc_weights, fc_weights)
        print('Weights loaded')
        if FLAGS.load_saved_weights:
            print('Meta test using downloaded model')
        else:
            print('Test iter: ' + str(FLAGS.test_iter))

    def start_session(self):
        """The function to start tensorflow session."""
        if FLAGS.full_gpu_memory_mode:
            # Limit the GPU memory of this process to the fraction given by FLAGS.gpu_rate
            gpu_config = tf.ConfigProto()
            gpu_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_rate
            self.sess = tf.InteractiveSession(config=gpu_config)
        else:
            self.sess = tf.InteractiveSession()

    def _run_validation(self, data_generator, train_writer, train_idx):
        """Run the meta-validation during meta-train and write its summary.
        Args:
          data_generator: the data generator class for this phase.
          train_writer: the tensorboard file writer.
          train_idx: the current meta-train iteration, used as the summary step.
        """
        test_loss = []
        test_accs = []
        for test_itr in range(FLAGS.meta_intrain_val_sample):
            this_episode = data_generator.load_episode(index=test_itr, data_type='val')
            # Add a leading axis so that one episode forms a meta-batch of size one
            test_inputa = this_episode[0][np.newaxis, :]
            test_labela = this_episode[1][np.newaxis, :]
            test_inputb = this_episode[2][np.newaxis, :]
            test_labelb = this_episode[3][np.newaxis, :]

            test_feed_dict = {self.model.inputa: test_inputa, self.model.inputb: test_inputb,
                              self.model.labela: test_labela, self.model.labelb: test_labelb,
                              self.model.meta_lr: 0.0}
            test_input_tensors = [self.model.total_loss, self.model.total_accuracy]
            test_result = self.sess.run(test_input_tensors, test_feed_dict)
            test_loss.append(test_result[0])
            test_accs.append(test_result[1])

        # float() replaces np.float, which is no longer available in recent NumPy releases
        valsum_feed_dict = {
            self.model.input_val_loss:
                np.mean(test_loss) * float(FLAGS.meta_batch_size) / float(FLAGS.shot_num),
            self.model.input_val_acc: np.mean(test_accs) * float(FLAGS.meta_batch_size)}
        valsum = self.sess.run(self.model.val_summ_op, valsum_feed_dict)
        train_writer.add_summary(valsum, train_idx)
        print_str = '[***] Val Loss:' + str(np.mean(test_loss) * FLAGS.meta_batch_size) + \
            ' Val Acc:' + str(np.mean(test_accs) * FLAGS.meta_batch_size)
        print(print_str)

    def train(self, data_generator):
        """The function for the meta-train phase
        Arg:
          data_generator: the data generator class for this phase
        """
        # Load the experiment setting string from FLAGS
        exp_string = FLAGS.exp_string
        # Generate tensorboard file writer
        train_writer = tf.summary.FileWriter(FLAGS.logdir + '/' + exp_string, self.sess.graph)
        print('Start meta-train phase')
        # Generate empty list to record losses and accuracies
        loss_list, acc_list = [], []
        # Load the meta learning rate from FLAGS
        train_lr = FLAGS.meta_lr
        # Load data for meta-train and meta validation
        data_generator.load_data(data_type='train')
        data_generator.load_data(data_type='val')

        # The tensors to run in every iteration: the meta train optimizer, the meta train
        # loss, the meta train accuracy and the tensorboard summary operation
        input_tensors = [self.model.metatrain_op, self.model.total_loss,
                         self.model.total_accuracy, self.model.training_summ_op]

        for train_idx in trange(FLAGS.metatrain_iterations):
            # Load the episodes for this meta batch
            episodes = [
                data_generator.load_episode(index=train_idx * FLAGS.meta_batch_size + meta_batch_idx,
                                            data_type='train')
                for meta_batch_idx in range(FLAGS.meta_batch_size)]
            inputa = np.array([episode[0] for episode in episodes])
            labela = np.array([episode[1] for episode in episodes])
            inputb = np.array([episode[2] for episode in episodes])
            labelb = np.array([episode[3] for episode in episodes])

            # Generate feed dict for the tensorflow graph
            feed_dict = {self.model.inputa: inputa, self.model.inputb: inputb,
                         self.model.labela: labela, self.model.labelb: labelb,
                         self.model.meta_lr: train_lr}

            # Run this meta-train iteration
            result = self.sess.run(input_tensors, feed_dict)

            # Record losses, accuracies and tensorboard
            loss_list.append(result[1])
            acc_list.append(result[2])
            train_writer.add_summary(result[3], train_idx)

            # Print meta-train information on the screen after several iterations
            if (train_idx != 0) and train_idx % FLAGS.meta_print_step == 0:
                print_str = 'Iteration:' + str(train_idx)
                print_str += ' Loss:' + str(np.mean(loss_list)) + ' Acc:' + str(np.mean(acc_list))
                print(print_str)
                loss_list, acc_list = [], []

            # Save the model during meta-train
            if train_idx % FLAGS.meta_save_step == 0:
                self._save_weights(exp_string, train_idx)

            # Run the meta-validation during meta-train
            if train_idx % FLAGS.meta_val_print_step == 0:
                self._run_validation(data_generator, train_writer, train_idx)

            # Reduce the meta learning rate after several iterations, but not below 10% of the initial one
            if (train_idx != 0) and train_idx % FLAGS.lr_drop_step == 0:
                train_lr = train_lr * FLAGS.lr_drop_rate
                if train_lr < 0.1 * FLAGS.meta_lr:
                    train_lr = 0.1 * FLAGS.meta_lr
                print('Train LR: {}'.format(train_lr))

        # Save the final model; FLAGS.metatrain_iterations equals the last index plus one and
        # is also defined when the loop did not run
        self._save_weights(exp_string, FLAGS.metatrain_iterations)
        # Flush the remaining summaries to disk
        train_writer.close()

    def test(self, data_generator):
        """The function for the meta-test phase
        Arg:
          data_generator: the data generator class for this phase
        """
        # Set meta-test episode number
        NUM_TEST_POINTS = 600
        # Load the experiment setting string from FLAGS
        exp_string = FLAGS.exp_string
        print('Start meta-test phase')
        np.random.seed(1)
        # Generate empty list to record accuracies
        metaval_accuracies = []
        # Load data for meta-test
        data_generator.load_data(data_type='test')
        for test_idx in trange(NUM_TEST_POINTS):
            # Load one episode for meta-test
            this_episode = data_generator.load_episode(index=test_idx, data_type='test')
            # Add a leading axis so that one episode forms a meta-batch of size one
            inputa = this_episode[0][np.newaxis, :]
            labela = this_episode[1][np.newaxis, :]
            inputb = this_episode[2][np.newaxis, :]
            labelb = this_episode[3][np.newaxis, :]
            feed_dict = {self.model.inputa: inputa, self.model.inputb: inputb,
                         self.model.labela: labela, self.model.labelb: labelb,
                         self.model.meta_lr: 0.0}
            result = self.sess.run(self.model.metaval_total_accuracies, feed_dict)
            metaval_accuracies.append(result)
        # Calculate the mean accuracies and the confidence intervals
        metaval_accuracies = np.array(metaval_accuracies)
        means, stds, ci95 = self._accuracy_statistics(metaval_accuracies, NUM_TEST_POINTS)

        # Print the meta-test results
        print('Test accuracies and confidence intervals')
        print((means, ci95))

        # Save the meta-test results in the csv files
        if not FLAGS.load_saved_weights:
            result_prefix = FLAGS.logdir + '/' + exp_string + '/' + 'result_' + str(FLAGS.shot_num) + \
                'shot_' + str(FLAGS.test_iter)
            out_filename = result_prefix + '.csv'
            out_pkl = result_prefix + '.pkl'
            with open(out_pkl, 'wb') as f:
                pickle.dump({'mses': metaval_accuracies}, f)
            with open(out_filename, 'w') as f:
                writer = csv.writer(f, delimiter=',')
                writer.writerow(['update' + str(i) for i in range(len(means))])
                writer.writerow(means)
                writer.writerow(stds)
                writer.writerow(ci95)
322 |
--------------------------------------------------------------------------------
/tensorflow/models/resnet18.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Yaoyao Liu
3 | ## Modified from: https://github.com/cbfinn/maml
4 | ## Tianjin University
5 | ## liuyaoyao@tju.edu.cn
6 | ## Copyright (c) 2019
7 | ##
8 | ## This source code is licensed under the MIT-style license found in the
9 | ## LICENSE file in the root directory of this source tree
10 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
11 |
12 | """ ResNet-18 class. """
13 | import numpy as np
14 | import tensorflow as tf
15 | from tensorflow.python.platform import flags
16 | from utils.misc import mse, softmaxloss, xent, resnet_conv_block, resnet_nob_conv_block
17 |
18 | FLAGS = flags.FLAGS
19 |
20 | class Models:
21 | """The class that contains the code for the basic resnet models and SS weights"""
def __init__(self):
    """Read the model settings from FLAGS and create the learning rate placeholders."""
    image_size = FLAGS.img_size

    # Input feature maps are square images with three channels
    self.dim_input = image_size * image_size * 3
    # One output for every class of an episode
    self.dim_output = FLAGS.way_num
    # Base learning rate
    self.update_lr = FLAGS.base_lr
    # Number of classes in the pre-train phase
    self.pretrain_class_num = FLAGS.pretrain_class_num
    # The meta and the pre-train learning rates can be fed at run time;
    # without a feed they take the values given in FLAGS
    self.meta_lr = tf.placeholder_with_default(FLAGS.meta_lr, ())
    self.pretrain_lr = tf.placeholder_with_default(FLAGS.pre_lr, ())

    # Default objective functions for meta-train and pre-train
    self.loss_func = xent
    self.pretrain_loss_func = softmaxloss

    # Default channel number
    self.channels = 3
    # Image size from FLAGS
    self.img_size = image_size
44 |
45 | def process_ss_weights(self, weights, ss_weights, label):
46 | """The function to process the scaling operation
47 | Args:
48 | weights: the weights for the resnet.
49 | ss_weights: the weights for scaling and shifting operation.
50 | label: the label to indicate which layer we are operating.
51 | Return:
52 | The processed weights for the new resnet.
53 | """
54 | [dim0, dim1] = weights[label].get_shape().as_list()[0:2]
55 | this_ss_weights = tf.tile(ss_weights[label], multiples=[dim0, dim1, 1, 1])
56 | return tf.multiply(weights[label], this_ss_weights)
57 |
58 | def forward_pretrain_resnet(self, inp, weights, reuse=False, scope=''):
59 | """The function to forward the resnet during pre-train phase
60 | Args:
61 | inp: input feature maps.
62 | weights: input resnet weights.
63 | reuse: reuse the batch norm weights or not.
64 | scope: the label to indicate which layer we are processing.
65 | Return:
66 | The processed feature maps.
67 | """
68 | inp = tf.reshape(inp, [-1, self.img_size, self.img_size, self.channels])
69 | inp = tf.image.resize_images(inp, size=[224,224], method=tf.image.ResizeMethod.BILINEAR)
70 | net = self.pretrain_first_block_forward(inp, weights, 'block0_1', reuse, scope)
71 |
72 | net = self.pretrain_block_forward(net, weights, 'block1_1', reuse, scope)
73 | net = self.pretrain_block_forward(net, weights, 'block1_2', reuse, scope, block_last_layer=True)
74 |
75 | net = self.pretrain_block_forward(net, weights, 'block2_1', reuse, scope)
76 | net = self.pretrain_block_forward(net, weights, 'block2_2', reuse, scope, block_last_layer=True)
77 |
78 | net = self.pretrain_block_forward(net, weights, 'block3_1', reuse, scope)
79 | net = self.pretrain_block_forward(net, weights, 'block3_2', reuse, scope, block_last_layer=True)
80 |
81 | net = self.pretrain_block_forward(net, weights, 'block4_1', reuse, scope)
82 | net = self.pretrain_block_forward(net, weights, 'block4_2', reuse, scope, block_last_layer=True)
83 |
84 | net = tf.nn.avg_pool(net, [1,7,7,1], [1,7,7,1], 'SAME')
85 | net = tf.reshape(net, [-1, np.prod([int(dim) for dim in net.get_shape()[1:]])])
86 | return net
87 |
88 | def forward_resnet(self, inp, weights, ss_weights, reuse=False, scope=''):
89 | """The function to forward the resnet during meta-train phase
90 | Args:
91 | inp: input feature maps.
92 | weights: input resnet weights.
93 | ss_weights: input scaling weights.
94 | reuse: reuse the batch norm weights or not.
95 | scope: the label to indicate which layer we are processing.
96 | Return:
97 | The processed feature maps.
98 | """
99 | inp = tf.reshape(inp, [-1, self.img_size, self.img_size, self.channels])
100 | inp = tf.image.resize_images(inp, size=[224,224], method=tf.image.ResizeMethod.BILINEAR)
101 | net = self.first_block_forward(inp, weights, ss_weights, 'block0_1', reuse, scope)
102 |
103 | net = self.block_forward(net, weights, ss_weights, 'block1_1', reuse, scope)
104 | net = self.block_forward(net, weights, ss_weights, 'block1_2', reuse, scope, block_last_layer=True)
105 |
106 | net = self.block_forward(net, weights, ss_weights, 'block2_1', reuse, scope)
107 | net = self.block_forward(net, weights, ss_weights, 'block2_2', reuse, scope, block_last_layer=True)
108 |
109 | net = self.block_forward(net, weights, ss_weights, 'block3_1', reuse, scope)
110 | net = self.block_forward(net, weights, ss_weights, 'block3_2', reuse, scope, block_last_layer=True)
111 |
112 | net = self.block_forward(net, weights, ss_weights, 'block4_1', reuse, scope)
113 | net = self.block_forward(net, weights, ss_weights, 'block4_2', reuse, scope, block_last_layer=True)
114 |
115 | net = tf.nn.avg_pool(net, [1,7,7,1], [1,7,7,1], 'SAME')
116 | net = tf.reshape(net, [-1, np.prod([int(dim) for dim in net.get_shape()[1:]])])
117 | return net
118 |
119 | def forward_fc(self, inp, fc_weights):
120 | """The function to forward the fc layer
121 | Args:
122 | inp: input feature maps.
123 | fc_weights: input fc weights.
124 | Return:
125 | The processed feature maps.
126 | """
127 | net = tf.matmul(inp, fc_weights['w5']) + fc_weights['b5']
128 | return net
129 |
130 | def pretrain_block_forward(self, inp, weights, block, reuse, scope, block_last_layer=False):
131 | """The function to forward a resnet block during pre-train phase
132 | Args:
133 | inp: input feature maps.
134 | weights: input resnet weights.
135 | block: the string to indicate which block we are processing.
136 | reuse: reuse the batch norm weights or not.
137 | scope: the label to indicate which layer we are processing.
138 | block_last_layer: whether it is the last layer of this block.
139 | Return:
140 | The processed feature maps.
141 | """
142 | net = resnet_conv_block(inp, weights[block + '_conv1'], weights[block + '_bias1'], reuse, scope+block+'0')
143 | net = resnet_conv_block(net, weights[block + '_conv2'], weights[block + '_bias2'], reuse, scope+block+'1')
144 | res = resnet_nob_conv_block(inp, weights[block + '_conv_res'], reuse, scope+block+'res')
145 | net = net + res
146 | if block_last_layer:
147 | net = tf.nn.max_pool(net, [1,2,2,1], [1,2,2,1], 'SAME')
148 | net = tf.nn.dropout(net, keep_prob=FLAGS.pretrain_dropout_keep)
149 | return net
150 |
151 | def block_forward(self, inp, weights, ss_weights, block, reuse, scope, block_last_layer=False):
152 | """The function to forward a resnet block during meta-train phase
153 | Args:
154 | inp: input feature maps.
155 | weights: input resnet weights.
156 | ss_weights: input scaling weights.
157 | block: the string to indicate which block we are processing.
158 | reuse: reuse the batch norm weights or not.
159 | scope: the label to indicate which layer we are processing.
160 | block_last_layer: whether it is the last layer of this block.
161 | Return:
162 | The processed feature maps.
163 | """
164 | net = resnet_conv_block(inp, self.process_ss_weights(weights, ss_weights, block + '_conv1'), \
165 | ss_weights[block + '_bias1'], reuse, scope+block+'0')
166 | net = resnet_conv_block(net, self.process_ss_weights(weights, ss_weights, block + '_conv2'), \
167 | ss_weights[block + '_bias2'], reuse, scope+block+'1')
168 | res = resnet_nob_conv_block(inp, weights[block + '_conv_res'], reuse, scope+block+'res')
169 | net = net + res
170 | if block_last_layer:
171 | net = tf.nn.max_pool(net, [1,2,2,1], [1,2,2,1], 'SAME')
172 | net = tf.nn.dropout(net, keep_prob=1)
173 | return net
174 |
175 | def pretrain_first_block_forward(self, inp, weights, block, reuse, scope):
176 | """The function to forward the first resnet block during pre-train phase
177 | Args:
178 | inp: input feature maps.
179 | weights: input resnet weights.
180 | block: the string to indicate which block we are processing.
181 | reuse: reuse the batch norm weights or not.
182 | scope: the label to indicate which layer we are processing.
183 | Return:
184 | The processed feature maps.
185 | """
186 | net = resnet_conv_block(inp, weights[block + '_conv1'], weights[block + '_bias1'], reuse, scope+block+'0')
187 | net = tf.nn.max_pool(net, [1,3,3,1], [1,2,2,1], 'SAME')
188 | net = tf.nn.dropout(net, keep_prob=FLAGS.pretrain_dropout_keep)
189 | return net
190 |
191 | def first_block_forward(self, inp, weights, ss_weights, block, reuse, scope, block_last_layer=False):
192 | """The function to forward the first resnet block during meta-train phase
193 | Args:
194 | inp: input feature maps.
195 | weights: input resnet weights.
196 | block: the string to indicate which block we are processing.
197 | reuse: reuse the batch norm weights or not.
198 | scope: the label to indicate which layer we are processing.
199 | Return:
200 | The processed feature maps.
201 | """
202 | net = resnet_conv_block(inp, self.process_ss_weights(weights, ss_weights, block + '_conv1'), \
203 | ss_weights[block + '_bias1'], reuse, scope+block+'0')
204 | net = tf.nn.max_pool(net, [1,3,3,1], [1,2,2,1], 'SAME')
205 | net = tf.nn.dropout(net, keep_prob=1)
206 | return net
207 |
208 | def construct_fc_weights(self):
209 | """The function to construct fc weights.
210 | Return:
211 | The fc weights.
212 | """
213 | dtype = tf.float32
214 | fc_weights = {}
215 | fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype)
216 | if FLAGS.phase=='pre':
217 | fc_weights['w5'] = tf.get_variable('fc_w5', [512, FLAGS.pretrain_class_num], initializer=fc_initializer)
218 | fc_weights['b5'] = tf.Variable(tf.zeros([FLAGS.pretrain_class_num]), name='fc_b5')
219 | else:
220 | fc_weights['w5'] = tf.get_variable('fc_w5', [512, self.dim_output], initializer=fc_initializer)
221 | fc_weights['b5'] = tf.Variable(tf.zeros([self.dim_output]), name='fc_b5')
222 | return fc_weights
223 |
224 | def construct_resnet_weights(self):
225 | """The function to construct resnet weights.
226 | Return:
227 | The resnet weights.
228 | """
229 | weights = {}
230 | dtype = tf.float32
231 | conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype)
232 | fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype)
233 | weights = self.construct_first_block_weights(weights, 7, 3, 64, conv_initializer, dtype, 'block0_1')
234 |
235 | weights = self.construct_residual_block_weights(weights, 3, 64, 64, conv_initializer, dtype, 'block1_1')
236 | weights = self.construct_residual_block_weights(weights, 3, 64, 64, conv_initializer, dtype, 'block1_2')
237 |
238 | weights = self.construct_residual_block_weights(weights, 3, 64, 128, conv_initializer, dtype, 'block2_1')
239 | weights = self.construct_residual_block_weights(weights, 3, 128, 128, conv_initializer, dtype, 'block2_2')
240 |
241 | weights = self.construct_residual_block_weights(weights, 3, 128, 256, conv_initializer, dtype, 'block3_1')
242 | weights = self.construct_residual_block_weights(weights, 3, 256, 256, conv_initializer, dtype, 'block3_2')
243 |
244 | weights = self.construct_residual_block_weights(weights, 3, 256, 512, conv_initializer, dtype, 'block4_1')
245 | weights = self.construct_residual_block_weights(weights, 3, 512, 512, conv_initializer, dtype, 'block4_2')
246 |
247 | weights['w5'] = tf.get_variable('w5', [512, FLAGS.pretrain_class_num], initializer=fc_initializer)
248 | weights['b5'] = tf.Variable(tf.zeros([FLAGS.pretrain_class_num]), name='b5')
249 | return weights
250 |
251 | def construct_residual_block_weights(self, weights, k, last_dim_hidden, dim_hidden, conv_initializer, dtype, scope='block0'):
252 | """The function to construct one block of the resnet weights.
253 | Args:
254 | weights: the resnet weight list.
255 | k: the dimension number of the convolution kernel.
256 | last_dim_hidden: the hidden dimension number of last block.
257 | dim_hidden: the hidden dimension number of the block.
258 | conv_initializer: the convolution initializer.
259 | dtype: the dtype for numpy.
260 | scope: the label to indicate which block we are processing.
261 | Return:
262 | The resnet block weights.
263 | """
264 | weights[scope + '_conv1'] = tf.get_variable(scope + '_conv1', [k, k, last_dim_hidden, dim_hidden], \
265 | initializer=conv_initializer, dtype=dtype)
266 | weights[scope + '_bias1'] = tf.Variable(tf.zeros([dim_hidden]), name=scope + '_bias1')
267 | weights[scope + '_conv2'] = tf.get_variable(scope + '_conv2', [k, k, dim_hidden, dim_hidden], \
268 | initializer=conv_initializer, dtype=dtype)
269 | weights[scope + '_bias2'] = tf.Variable(tf.zeros([dim_hidden]), name=scope + '_bias2')
270 | weights[scope + '_conv_res'] = tf.get_variable(scope + '_conv_res', [1, 1, last_dim_hidden, dim_hidden], \
271 | initializer=conv_initializer, dtype=dtype)
272 | return weights
273 |
274 | def construct_first_block_weights(self, weights, k, last_dim_hidden, dim_hidden, conv_initializer, dtype, scope='block0'):
275 | """The function to construct the first block of the resnet weights.
276 | Args:
277 | weights: the resnet weight list.
278 | k: the dimension number of the convolution kernel.
279 | last_dim_hidden: the hidden dimension number of last block.
280 | dim_hidden: the hidden dimension number of the block.
281 | conv_initializer: the convolution initializer.
282 | dtype: the dtype for numpy.
283 | scope: the label to indicate which block we are processing.
284 | Return:
285 | The resnet block weights.
286 | """
287 | weights[scope + '_conv1'] = tf.get_variable(scope + '_conv1', [k, k, last_dim_hidden, dim_hidden], \
288 | initializer=conv_initializer, dtype=dtype)
289 | weights[scope + '_bias1'] = tf.Variable(tf.zeros([dim_hidden]), name=scope + '_bias1')
290 | return weights
291 |
292 | def construct_first_block_ss_weights(self, ss_weights, last_dim_hidden, dim_hidden, scope='block0'):
293 | """The function to construct first block's ss weights.
294 | Return:
295 | The ss weights.
296 | """
297 | ss_weights[scope + '_conv1'] = tf.Variable(tf.ones([1, 1, last_dim_hidden, dim_hidden]), name=scope + '_conv1')
298 | ss_weights[scope + '_bias1'] = tf.Variable(tf.zeros([dim_hidden]), name=scope + '_bias1')
299 | return ss_weights
300 |
301 | def construct_resnet_ss_weights(self):
302 | """The function to construct ss weights.
303 | Return:
304 | The ss weights.
305 | """
306 | ss_weights = {}
307 | ss_weights = self.construct_first_block_ss_weights(ss_weights, 3, 64, 'block0_1')
308 |
309 | ss_weights = self.construct_residual_block_ss_weights(ss_weights, 64, 64, 'block1_1')
310 | ss_weights = self.construct_residual_block_ss_weights(ss_weights, 64, 64, 'block1_2')
311 |
312 | ss_weights = self.construct_residual_block_ss_weights(ss_weights, 64, 128, 'block2_1')
313 | ss_weights = self.construct_residual_block_ss_weights(ss_weights, 128, 128, 'block2_2')
314 |
315 | ss_weights = self.construct_residual_block_ss_weights(ss_weights, 128, 256, 'block3_1')
316 | ss_weights = self.construct_residual_block_ss_weights(ss_weights, 256, 256, 'block3_2')
317 |
318 | ss_weights = self.construct_residual_block_ss_weights(ss_weights, 256, 512, 'block4_1')
319 | ss_weights = self.construct_residual_block_ss_weights(ss_weights, 512, 512, 'block4_2')
320 |
321 | return ss_weights
322 |
323 | def construct_residual_block_ss_weights(self, ss_weights, last_dim_hidden, dim_hidden, scope='block0'):
324 | """The function to construct one block ss weights.
325 | Args:
326 | ss_weights: the ss weight list.
327 | last_dim_hidden: the hidden dimension number of last block.
328 | dim_hidden: the hidden dimension number of the block.
329 | scope: the label to indicate which block we are processing.
330 | Return:
331 | The ss block weights.
332 | """
333 | ss_weights[scope + '_conv1'] = tf.Variable(tf.ones([1, 1, last_dim_hidden, dim_hidden]), name=scope + '_conv1')
334 | ss_weights[scope + '_bias1'] = tf.Variable(tf.zeros([dim_hidden]), name=scope + '_bias1')
335 | ss_weights[scope + '_conv2'] = tf.Variable(tf.ones([1, 1, dim_hidden, dim_hidden]), name=scope + '_conv2')
336 | ss_weights[scope + '_bias2'] = tf.Variable(tf.zeros([dim_hidden]), name=scope + '_bias2')
337 | return ss_weights
338 |
--------------------------------------------------------------------------------