├── LICENSE ├── README.md ├── data ├── art_painting_features.hdf5 ├── art_painting_train_features.hdf5 ├── art_painting_val_features.hdf5 ├── cartoon_features.hdf5 ├── cartoon_train_features.hdf5 ├── cartoon_val_features.hdf5 ├── photo_features.hdf5 ├── photo_train_features.hdf5 ├── photo_val_features.hdf5 ├── sketch_features.hdf5 ├── sketch_train_features.hdf5 └── sketch_val_features.hdf5 ├── data_reader.py ├── main_baseline.py ├── main_mldg.py ├── mlp.py ├── model.py ├── ops.py ├── run_baseline.sh ├── run_mldg.sh └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 DL 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## MLDG 2 | This is a code sample for the paper "Learning to Generalize: Meta-Learning for Domain Generalization" https://arxiv.org/pdf/1710.03463.pdf 3 | 4 | 5 | This code is the MLP version of MLDG with one hidden layer; its inputs are deep features extracted from the PACS dataset. 6 | The baseline is a sanity check: the same MLP trained on all source domains without the meta-train/meta-val split. 7 | 8 | 9 | 10 | ## Requirements 11 | Python 2.7 12 | 13 | PyTorch 0.3.1 14 | 15 | ## Run the baseline 16 | Please download the data first (deep features extracted with an ImageNet-pretrained ResNet-18), then run 17 | 18 | sh run_baseline.sh 'data_root/' # data_root is the folder you downloaded the data to. 19 | 20 | ## Run the MLDG 21 | 22 | sh run_mldg.sh 'data_root/' 23 | 24 | ## Bibtex 25 | ``` 26 | @inproceedings{Li2018MLDG, 27 | title={Learning to Generalize: Meta-Learning for Domain Generalization}, 28 | author={Li, Da and Yang, Yongxin and Song, Yi-Zhe and Hospedales, Timothy}, 29 | booktitle={AAAI Conference on Artificial Intelligence}, 30 | year={2018} 31 | } 32 | ``` 33 | 34 | ## Your own data 35 | Please tune 'meta_step_size' and 'meta_val_beta' for your own data; these parameters correspond to 'alpha' and 'beta' in the paper, respectively, and should be tuned for each setting.
36 | -------------------------------------------------------------------------------- /data/art_painting_features.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HAHA-DL/MLDG/19134860db3b02b511ea682d1cd353350322602f/data/art_painting_features.hdf5 -------------------------------------------------------------------------------- /data/art_painting_train_features.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HAHA-DL/MLDG/19134860db3b02b511ea682d1cd353350322602f/data/art_painting_train_features.hdf5 -------------------------------------------------------------------------------- /data/art_painting_val_features.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HAHA-DL/MLDG/19134860db3b02b511ea682d1cd353350322602f/data/art_painting_val_features.hdf5 -------------------------------------------------------------------------------- /data/cartoon_features.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HAHA-DL/MLDG/19134860db3b02b511ea682d1cd353350322602f/data/cartoon_features.hdf5 -------------------------------------------------------------------------------- /data/cartoon_train_features.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HAHA-DL/MLDG/19134860db3b02b511ea682d1cd353350322602f/data/cartoon_train_features.hdf5 -------------------------------------------------------------------------------- /data/cartoon_val_features.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HAHA-DL/MLDG/19134860db3b02b511ea682d1cd353350322602f/data/cartoon_val_features.hdf5 -------------------------------------------------------------------------------- 
/data/photo_features.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HAHA-DL/MLDG/19134860db3b02b511ea682d1cd353350322602f/data/photo_features.hdf5 -------------------------------------------------------------------------------- /data/photo_train_features.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HAHA-DL/MLDG/19134860db3b02b511ea682d1cd353350322602f/data/photo_train_features.hdf5 -------------------------------------------------------------------------------- /data/photo_val_features.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HAHA-DL/MLDG/19134860db3b02b511ea682d1cd353350322602f/data/photo_val_features.hdf5 -------------------------------------------------------------------------------- /data/sketch_features.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HAHA-DL/MLDG/19134860db3b02b511ea682d1cd353350322602f/data/sketch_features.hdf5 -------------------------------------------------------------------------------- /data/sketch_train_features.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HAHA-DL/MLDG/19134860db3b02b511ea682d1cd353350322602f/data/sketch_train_features.hdf5 -------------------------------------------------------------------------------- /data/sketch_val_features.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HAHA-DL/MLDG/19134860db3b02b511ea682d1cd353350322602f/data/sketch_val_features.hdf5 -------------------------------------------------------------------------------- /data_reader.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | 
4 | from utils import unfold_label, shuffle_data 5 | 6 | 7 | class BatchImageGenerator: 8 | def __init__(self, flags, stage, file_path, b_unfold_label): 9 | 10 | if stage not in ['train', 'val', 'test']: 11 | assert ValueError('invalid stage!') 12 | 13 | self.configuration(flags, stage, file_path) 14 | self.load_data(b_unfold_label) 15 | 16 | def configuration(self, flags, stage, file_path): 17 | self.batch_size = flags.batch_size 18 | self.current_index = -1 19 | self.file_path = file_path 20 | self.stage = stage 21 | self.shuffled = False 22 | 23 | def normalize(self, inputs): 24 | 25 | # the mean and std used for the normalization of 26 | # the inputs for the pytorch pretrained model 27 | mean = [0.485, 0.456, 0.406] 28 | std = [0.229, 0.224, 0.225] 29 | 30 | # norm to [0, 1] 31 | inputs = inputs / 255.0 32 | 33 | inputs_norm = [] 34 | for item in inputs: 35 | item = np.transpose(item, (2, 0, 1)) 36 | item_norm = [] 37 | for c, m, s in zip(item, mean, std): 38 | c = np.subtract(c, m) 39 | c = np.divide(c, s) 40 | item_norm.append(c) 41 | 42 | item_norm = np.stack(item_norm) 43 | inputs_norm.append(item_norm) 44 | 45 | inputs_norm = np.stack(inputs_norm) 46 | 47 | return inputs_norm 48 | 49 | def load_data(self, b_unfold_label): 50 | file_path = self.file_path 51 | f = h5py.File(file_path, "r") 52 | self.images = np.array(f['images']) 53 | self.labels = np.array(f['labels']) 54 | f.close() 55 | 56 | # shift the labels to start from 0 57 | self.labels -= np.min(self.labels) 58 | 59 | if b_unfold_label: 60 | self.labels = unfold_label(labels=self.labels, classes=len(np.unique(self.labels))) 61 | assert len(self.images) == len(self.labels) 62 | 63 | self.file_num_train = len(self.labels) 64 | print('data num loaded:', self.file_num_train) 65 | 66 | if self.stage is 'train': 67 | self.images, self.labels = shuffle_data(samples=self.images, labels=self.labels) 68 | 69 | def get_images_labels_batch(self): 70 | 71 | images = [] 72 | labels = [] 73 | for index in 
range(self.batch_size): 74 | self.current_index += 1 75 | 76 | # void over flow 77 | if self.current_index > self.file_num_train - 1: 78 | self.current_index %= self.file_num_train 79 | 80 | self.images, self.labels = shuffle_data(samples=self.images, labels=self.labels) 81 | 82 | images.append(self.images[self.current_index]) 83 | labels.append(self.labels[self.current_index]) 84 | 85 | images = np.stack(images) 86 | labels = np.stack(labels) 87 | 88 | return images, labels 89 | -------------------------------------------------------------------------------- /main_baseline.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from model import ModelBaseline 4 | 5 | 6 | def main(): 7 | main_arg_parser = argparse.ArgumentParser(description="parser") 8 | subparsers = main_arg_parser.add_subparsers(title="subcommands", dest="subcommand") 9 | 10 | train_arg_parser = subparsers.add_parser("train", help="parser for training arguments") 11 | train_arg_parser.add_argument("--test_every", type=int, default=50, 12 | help="number of test every steps") 13 | train_arg_parser.add_argument("--batch_size", type=int, default=128, 14 | help="batch size for training, default is 64") 15 | train_arg_parser.add_argument("--num_classes", type=int, default=10, 16 | help="number of classes") 17 | train_arg_parser.add_argument("--step_size", type=int, default=1, 18 | help="number of step size to decay the lr") 19 | train_arg_parser.add_argument("--inner_loops", type=int, default=200000, 20 | help="number of classes") 21 | train_arg_parser.add_argument("--unseen_index", type=int, default=0, 22 | help="index of unseen domain") 23 | train_arg_parser.add_argument("--lr", type=float, default=0.0001, 24 | help='learning rate of the model') 25 | train_arg_parser.add_argument("--weight_decay", type=float, default=0.00005, 26 | help='weight decay') 27 | train_arg_parser.add_argument("--momentum", type=float, default=0.9, 28 | 
help='momentum') 29 | train_arg_parser.add_argument("--logs", type=str, default='logs/', 30 | help='logs folder to write log') 31 | train_arg_parser.add_argument("--model_path", type=str, default='', 32 | help='folder for saving model') 33 | train_arg_parser.add_argument("--state_dict", type=str, default='', 34 | help='model of pre trained') 35 | train_arg_parser.add_argument("--data_root", type=str, default='', 36 | help='folder root of the data') 37 | train_arg_parser.add_argument("--debug", type=bool, default=False, 38 | help='whether for debug mode or not') 39 | args = main_arg_parser.parse_args() 40 | 41 | model_obj = ModelBaseline(flags=args) 42 | model_obj.train(flags=args) 43 | 44 | # after training, we should test the held out domain 45 | model_obj.heldout_test(flags=args) 46 | 47 | 48 | if __name__ == "__main__": 49 | main() 50 | -------------------------------------------------------------------------------- /main_mldg.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from model import ModelMLDG 4 | 5 | 6 | def main(): 7 | main_arg_parser = argparse.ArgumentParser(description="parser") 8 | subparsers = main_arg_parser.add_subparsers(title="subcommands", dest="subcommand") 9 | 10 | train_arg_parser = subparsers.add_parser("train", help="parser for training arguments") 11 | train_arg_parser.add_argument("--test_every", type=int, default=50, 12 | help="number of test every steps") 13 | train_arg_parser.add_argument("--batch_size", type=int, default=128, 14 | help="batch size for training, default is 64") 15 | train_arg_parser.add_argument("--num_classes", type=int, default=10, 16 | help="number of classes") 17 | train_arg_parser.add_argument("--step_size", type=int, default=1, 18 | help="number of classes") 19 | train_arg_parser.add_argument("--inner_loops", type=int, default=200000, 20 | help="number of classes") 21 | train_arg_parser.add_argument("--unseen_index", type=int, default=0, 22 | 
help="index of unseen domain") 23 | train_arg_parser.add_argument("--lr", type=float, default=0.0001, 24 | help='learning rate of the model') 25 | train_arg_parser.add_argument("--meta_step_size", type=float, default=0.0001, 26 | help='meta step size') 27 | train_arg_parser.add_argument("--meta_val_beta", type=float, default=0.0001, 28 | help='the strength of the meta val loss') 29 | train_arg_parser.add_argument("--weight_decay", type=float, default=0.00005, 30 | help='weight decay') 31 | train_arg_parser.add_argument("--momentum", type=float, default=0.9, 32 | help='momentum') 33 | train_arg_parser.add_argument("--logs", type=str, default='logs/', 34 | help='logs folder to write log') 35 | train_arg_parser.add_argument("--model_path", type=str, default='', 36 | help='folder for saving model') 37 | train_arg_parser.add_argument("--state_dict", type=str, default='', 38 | help='model of pre trained') 39 | train_arg_parser.add_argument("--data_root", type=str, default='', 40 | help='folder root of the data') 41 | train_arg_parser.add_argument("--stop_gradient", type=bool, default=False, 42 | help='whether stop gradient of the first order gradient') 43 | train_arg_parser.add_argument("--debug", type=bool, default=False, 44 | help='whether for debug mode or not') 45 | args = main_arg_parser.parse_args() 46 | 47 | model_obj = ModelMLDG(flags=args) 48 | model_obj.train(flags=args) 49 | 50 | # after training, we should test the held out domain 51 | model_obj.heldout_test(flags=args) 52 | 53 | 54 | if __name__ == "__main__": 55 | main() 56 | -------------------------------------------------------------------------------- /mlp.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from ops import linear 7 | 8 | 9 | class MLP(nn.Module): 10 | def __init__(self, num_classes=1000): 11 | super(MLP, self).__init__() 12 | self.fc1 = nn.Linear(512, 512) 13 | 
self.fc2 = nn.Linear(512, num_classes) 14 | 15 | # when you add the convolution and batch norm, below will be useful 16 | for m in self.modules(): 17 | if isinstance(m, nn.Conv2d): 18 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 19 | m.weight.data.normal_(0, math.sqrt(2. / n)) 20 | elif isinstance(m, nn.BatchNorm2d): 21 | m.weight.data.fill_(1) 22 | m.bias.data.zero_() 23 | 24 | def forward(self, x, meta_loss=None, meta_step_size=None, stop_gradient=False): 25 | 26 | x = linear(inputs=x, 27 | weight=self.fc1.weight, 28 | bias=self.fc1.bias, 29 | meta_loss=meta_loss, 30 | meta_step_size=meta_step_size, 31 | stop_gradient=stop_gradient) 32 | 33 | x = F.relu(x, inplace=True) 34 | 35 | x = linear(inputs=x, 36 | weight=self.fc2.weight, 37 | bias=self.fc2.bias, 38 | meta_loss=meta_loss, 39 | meta_step_size=meta_step_size, 40 | stop_gradient=stop_gradient) 41 | 42 | end_points = {'Predictions': F.softmax(input=x, dim=-1)} 43 | 44 | return x, end_points 45 | 46 | 47 | def MLPNet(**kwargs): 48 | model = MLP(**kwargs) 49 | return model 50 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torch.utils.model_zoo as model_zoo 6 | from sklearn.metrics import accuracy_score 7 | from torch.autograd import Variable 8 | from torch.optim import lr_scheduler 9 | 10 | import mlp 11 | from data_reader import BatchImageGenerator 12 | from utils import sgd, crossentropyloss, fix_seed, write_log, compute_accuracy 13 | 14 | 15 | class ModelBaseline: 16 | def __init__(self, flags): 17 | 18 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 19 | 20 | # fix the random seed or not 21 | fix_seed() 22 | 23 | self.setup_path(flags) 24 | 25 | self.network = mlp.MLPNet(num_classes=flags.num_classes) 26 | 27 | self.network = self.network.cuda() 28 | 29 | print(self.network) 30 | 
print('flags:', flags) 31 | 32 | if not os.path.exists(flags.logs): 33 | os.mkdir(flags.logs) 34 | 35 | flags_log = os.path.join(flags.logs, 'flags_log.txt') 36 | write_log(flags, flags_log) 37 | 38 | self.load_state_dict(flags.state_dict) 39 | 40 | self.configure(flags) 41 | 42 | def setup_path(self, flags): 43 | 44 | root_folder = flags.data_root 45 | train_data = ['art_painting_train_features.hdf5', 46 | 'cartoon_train_features.hdf5', 47 | 'photo_train_features.hdf5', 48 | 'sketch_train_features.hdf5'] 49 | 50 | val_data = ['art_painting_val_features.hdf5', 51 | 'cartoon_val_features.hdf5', 52 | 'photo_val_features.hdf5', 53 | 'sketch_val_features.hdf5'] 54 | 55 | test_data = ['art_painting_features.hdf5', 56 | 'cartoon_features.hdf5', 57 | 'photo_features.hdf5', 58 | 'sketch_features.hdf5'] 59 | 60 | self.train_paths = [] 61 | for data in train_data: 62 | path = os.path.join(root_folder, data) 63 | self.train_paths.append(path) 64 | 65 | self.val_paths = [] 66 | for data in val_data: 67 | path = os.path.join(root_folder, data) 68 | self.val_paths.append(path) 69 | 70 | unseen_index = flags.unseen_index 71 | 72 | self.unseen_data_path = os.path.join(root_folder, test_data[unseen_index]) 73 | self.train_paths.remove(self.train_paths[unseen_index]) 74 | self.val_paths.remove(self.val_paths[unseen_index]) 75 | 76 | if not os.path.exists(flags.logs): 77 | os.mkdir(flags.logs) 78 | 79 | flags_log = os.path.join(flags.logs, 'path_log.txt') 80 | write_log(str(self.train_paths), flags_log) 81 | write_log(str(self.val_paths), flags_log) 82 | write_log(str(self.unseen_data_path), flags_log) 83 | 84 | self.batImageGenTrains = [] 85 | for train_path in self.train_paths: 86 | batImageGenTrain = BatchImageGenerator(flags=flags, file_path=train_path, stage='train', 87 | b_unfold_label=False) 88 | self.batImageGenTrains.append(batImageGenTrain) 89 | 90 | self.batImageGenVals = [] 91 | for val_path in self.val_paths: 92 | batImageGenVal = BatchImageGenerator(flags=flags, 
file_path=val_path, stage='val', 93 | b_unfold_label=True) 94 | self.batImageGenVals.append(batImageGenVal) 95 | 96 | def load_state_dict(self, state_dict=''): 97 | 98 | if state_dict: 99 | try: 100 | tmp = torch.load(state_dict) 101 | pretrained_dict = tmp['state'] 102 | except: 103 | pretrained_dict = model_zoo.load_url(state_dict) 104 | 105 | model_dict = self.network.state_dict() 106 | # 1. filter out unnecessary keys 107 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if 108 | k in model_dict and v.size() == model_dict[k].size()} 109 | # 2. overwrite entries in the existing state dict 110 | model_dict.update(pretrained_dict) 111 | # 3. load the new state dict 112 | self.network.load_state_dict(model_dict) 113 | 114 | def heldout_test(self, flags): 115 | 116 | # load the best model in the validation data 117 | model_path = os.path.join(flags.model_path, 'best_model.tar') 118 | self.load_state_dict(state_dict=model_path) 119 | 120 | # test 121 | batImageGenTest = BatchImageGenerator(flags=flags, file_path=self.unseen_data_path, stage='test', 122 | b_unfold_label=False) 123 | test_images = batImageGenTest.images 124 | 125 | threshold = 100 126 | n_slices_test = len(test_images) / threshold 127 | indices_test = [] 128 | for per_slice in range(n_slices_test - 1): 129 | indices_test.append(len(test_images) * (per_slice + 1) / n_slices_test) 130 | test_image_splits = np.split(test_images, indices_or_sections=indices_test) 131 | 132 | # Verify the splits are correct 133 | test_image_splits_2_whole = np.concatenate(test_image_splits) 134 | assert np.all(test_images == test_image_splits_2_whole) 135 | 136 | # split the test data into splits and test them one by one 137 | predictions = [] 138 | self.network.eval() 139 | for test_image_split in test_image_splits: 140 | images_test = Variable(torch.from_numpy(np.array(test_image_split, dtype=np.float32))).cuda() 141 | outputs, end_points = self.network(images_test) 142 | 143 | pred = end_points['Predictions'] 
144 | pred = pred.cpu().data.numpy() 145 | predictions.append(pred) 146 | 147 | # concatenate the test predictions first 148 | predictions = np.concatenate(predictions) 149 | 150 | # accuracy 151 | accuracy = accuracy_score(y_true=batImageGenTest.labels, 152 | y_pred=np.argmax(predictions, -1)) 153 | 154 | flags_log = os.path.join(flags.logs, 'heldout_test_log.txt') 155 | write_log(accuracy, flags_log) 156 | 157 | def configure(self, flags): 158 | 159 | for name, para in self.network.named_parameters(): 160 | print(name, para.size()) 161 | 162 | self.optimizer = sgd(parameters=self.network.parameters(), 163 | lr=flags.lr, 164 | weight_decay=flags.weight_decay, 165 | momentum=flags.momentum) 166 | 167 | self.scheduler = lr_scheduler.StepLR(optimizer=self.optimizer, step_size=flags.step_size, gamma=0.1) 168 | self.loss_fn = crossentropyloss() 169 | 170 | def train(self, flags): 171 | self.network.train() 172 | 173 | self.best_accuracy_val = -1 174 | 175 | for ite in range(flags.inner_loops): 176 | 177 | self.scheduler.step(epoch=ite) 178 | 179 | total_loss = 0.0 180 | for index in range(len(self.batImageGenTrains)): 181 | images_train, labels_train = self.batImageGenTrains[index].get_images_labels_batch() 182 | 183 | inputs, labels = torch.from_numpy( 184 | np.array(images_train, dtype=np.float32)), torch.from_numpy( 185 | np.array(labels_train, dtype=np.float32)) 186 | 187 | # wrap the inputs and labels in Variable 188 | inputs, labels = Variable(inputs, requires_grad=False).cuda(), \ 189 | Variable(labels, requires_grad=False).long().cuda() 190 | 191 | outputs, _ = self.network(x=inputs) 192 | 193 | # loss 194 | loss = self.loss_fn(outputs, labels) 195 | total_loss += loss 196 | 197 | # init the grad to zeros first 198 | self.optimizer.zero_grad() 199 | 200 | # backward your network 201 | total_loss.backward() 202 | 203 | # optimize the parameters 204 | self.optimizer.step() 205 | 206 | print( 207 | 'ite:', ite, 'loss:', total_loss.cpu().data.numpy()[0], 'lr:', 208 
| self.scheduler.get_lr()[0]) 209 | 210 | flags_log = os.path.join(flags.logs, 'loss_log.txt') 211 | write_log( 212 | str(total_loss.cpu().data.numpy()[0]), 213 | flags_log) 214 | 215 | del total_loss, outputs 216 | 217 | if ite % flags.test_every == 0 and ite is not 0 or flags.debug: 218 | self.test_workflow(self.batImageGenVals, flags, ite) 219 | 220 | def test_workflow(self, batImageGenVals, flags, ite): 221 | 222 | accuracies = [] 223 | for count, batImageGenVal in enumerate(batImageGenVals): 224 | accuracy_val = self.test(batImageGenTest=batImageGenVal, flags=flags, ite=ite, 225 | log_dir=flags.logs, log_prefix='val_index_{}'.format(count)) 226 | 227 | accuracies.append(accuracy_val) 228 | 229 | mean_acc = np.mean(accuracies) 230 | 231 | if mean_acc > self.best_accuracy_val: 232 | self.best_accuracy_val = mean_acc 233 | 234 | f = open(os.path.join(flags.logs, 'Best_val.txt'), mode='a') 235 | f.write('ite:{}, best val accuracy:{}\n'.format(ite, self.best_accuracy_val)) 236 | f.close() 237 | 238 | if not os.path.exists(flags.model_path): 239 | os.mkdir(flags.model_path) 240 | 241 | outfile = os.path.join(flags.model_path, 'best_model.tar') 242 | torch.save({'ite': ite, 'state': self.network.state_dict()}, outfile) 243 | 244 | def test(self, flags, ite, log_prefix, log_dir='logs/', batImageGenTest=None): 245 | 246 | # switch on the network test mode 247 | self.network.eval() 248 | 249 | if batImageGenTest is None: 250 | batImageGenTest = BatchImageGenerator(flags=flags, file_path='', stage='test', b_unfold_label=True) 251 | 252 | images_test = batImageGenTest.images 253 | labels_test = batImageGenTest.labels 254 | 255 | threshold = 50 256 | if len(images_test) > threshold: 257 | 258 | n_slices_test = len(images_test) / threshold 259 | indices_test = [] 260 | for per_slice in range(n_slices_test - 1): 261 | indices_test.append(len(images_test) * (per_slice + 1) / n_slices_test) 262 | test_image_splits = np.split(images_test, indices_or_sections=indices_test) 263 | 
264 | # Verify the splits are correct 265 | test_image_splits_2_whole = np.concatenate(test_image_splits) 266 | assert np.all(images_test == test_image_splits_2_whole) 267 | 268 | # split the test data into splits and test them one by one 269 | test_image_preds = [] 270 | for test_image_split in test_image_splits: 271 | images_test = Variable(torch.from_numpy(np.array(test_image_split, dtype=np.float32))).cuda() 272 | outputs, end_points = self.network(images_test) 273 | 274 | predictions = end_points['Predictions'] 275 | predictions = predictions.cpu().data.numpy() 276 | test_image_preds.append(predictions) 277 | 278 | # concatenate the test predictions first 279 | predictions = np.concatenate(test_image_preds) 280 | else: 281 | images_test = Variable(torch.from_numpy(np.array(images_test, dtype=np.float32))).cuda() 282 | outputs, end_points = self.network(images_test) 283 | 284 | predictions = end_points['Predictions'] 285 | predictions = predictions.cpu().data.numpy() 286 | 287 | accuracy = compute_accuracy(predictions=predictions, labels=labels_test) 288 | print('----------accuracy test----------:', accuracy) 289 | 290 | if not os.path.exists(log_dir): 291 | os.mkdir(log_dir) 292 | 293 | log_path = os.path.join(log_dir, '{}.txt'.format(log_prefix)) 294 | write_log(str('ite:{}, accuracy:{}'.format(ite, accuracy)), log_path=log_path) 295 | 296 | # switch on the network train mode after test 297 | self.network.train() 298 | 299 | return accuracy 300 | 301 | 302 | class ModelMLDG(ModelBaseline): 303 | def __init__(self, flags): 304 | 305 | ModelBaseline.__init__(self, flags) 306 | 307 | def train(self, flags): 308 | self.network.train() 309 | 310 | self.best_accuracy_val = -1 311 | 312 | for ite in range(flags.inner_loops): 313 | 314 | self.scheduler.step(epoch=ite) 315 | 316 | # select the validation domain for meta val 317 | index_val = np.random.choice(a=np.arange(0, len(self.batImageGenTrains)), size=1)[0] 318 | batImageMetaVal = 
self.batImageGenTrains[index_val] 319 | 320 | meta_train_loss = 0.0 321 | # get the inputs and labels from the data reader 322 | for index in range(len(self.batImageGenTrains)): 323 | 324 | if index == index_val: 325 | continue 326 | 327 | images_train, labels_train = self.batImageGenTrains[index].get_images_labels_batch() 328 | 329 | inputs_train, labels_train = torch.from_numpy( 330 | np.array(images_train, dtype=np.float32)), torch.from_numpy( 331 | np.array(labels_train, dtype=np.float32)) 332 | 333 | # wrap the inputs and labels in Variable 334 | inputs_train, labels_train = Variable(inputs_train, requires_grad=False).cuda(), \ 335 | Variable(labels_train, requires_grad=False).long().cuda() 336 | 337 | # forward with the adapted parameters 338 | outputs_train, _ = self.network(x=inputs_train) 339 | 340 | # loss 341 | loss = self.loss_fn(outputs_train, labels_train) 342 | meta_train_loss += loss 343 | 344 | image_val, labels_val = batImageMetaVal.get_images_labels_batch() 345 | inputs_val, labels_val = torch.from_numpy( 346 | np.array(image_val, dtype=np.float32)), torch.from_numpy( 347 | np.array(labels_val, dtype=np.float32)) 348 | 349 | # wrap the inputs and labels in Variable 350 | inputs_val, labels_val = Variable(inputs_val, requires_grad=False).cuda(), \ 351 | Variable(labels_val, requires_grad=False).long().cuda() 352 | 353 | # forward with the adapted parameters 354 | outputs_val, _ = self.network(x=inputs_val, 355 | meta_loss=meta_train_loss, 356 | meta_step_size=flags.meta_step_size, 357 | stop_gradient=flags.stop_gradient) 358 | 359 | meta_val_loss = self.loss_fn(outputs_val, labels_val) 360 | 361 | total_loss = meta_train_loss + meta_val_loss * flags.meta_val_beta 362 | 363 | # init the grad to zeros first 364 | self.optimizer.zero_grad() 365 | 366 | # backward your network 367 | total_loss.backward() 368 | 369 | # optimize the parameters 370 | self.optimizer.step() 371 | 372 | print( 373 | 'ite:', ite, 374 | 'meta_train_loss:', 
meta_train_loss.cpu().data.numpy()[0], 375 | 'meta_val_loss:', meta_val_loss.cpu().data.numpy()[0], 376 | 'lr:', 377 | self.scheduler.get_lr()[0]) 378 | 379 | flags_log = os.path.join(flags.logs, 'loss_log.txt') 380 | write_log( 381 | str(meta_train_loss.cpu().data.numpy()[0]) + '\t' + str(meta_val_loss.cpu().data.numpy()[0]), 382 | flags_log) 383 | 384 | del total_loss, outputs_val, outputs_train 385 | 386 | if ite % flags.test_every == 0 and ite is not 0 or flags.debug: 387 | self.test_workflow(self.batImageGenVals, flags, ite) 388 | -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import torch.autograd as autograd 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | 5 | 6 | def linear(inputs, weight, bias, meta_step_size=0.001, meta_loss=None, stop_gradient=False): 7 | if meta_loss is not None: 8 | 9 | if not stop_gradient: 10 | grad_weight = autograd.grad(meta_loss, weight, create_graph=True)[0] 11 | 12 | if bias is not None: 13 | grad_bias = autograd.grad(meta_loss, bias, create_graph=True)[0] 14 | bias_adapt = bias - grad_bias * meta_step_size 15 | else: 16 | bias_adapt = bias 17 | 18 | else: 19 | grad_weight = Variable(autograd.grad(meta_loss, weight, create_graph=True)[0].data, requires_grad=False) 20 | 21 | if bias is not None: 22 | grad_bias = Variable(autograd.grad(meta_loss, bias, create_graph=True)[0].data, requires_grad=False) 23 | bias_adapt = bias - grad_bias * meta_step_size 24 | else: 25 | bias_adapt = bias 26 | 27 | return F.linear(inputs, 28 | weight - grad_weight * meta_step_size, 29 | bias_adapt) 30 | else: 31 | return F.linear(inputs, weight, bias) 32 | 33 | 34 | def conv2d(inputs, weight, bias, meta_step_size=0.001, stride=1, padding=0, dilation=1, groups=1, meta_loss=None, 35 | stop_gradient=False): 36 | if meta_loss is not None: 37 | 38 | if not stop_gradient: 39 | grad_weight = 
autograd.grad(meta_loss, weight, create_graph=True)[0] 40 | 41 | if bias is not None: 42 | grad_bias = autograd.grad(meta_loss, bias, create_graph=True)[0] 43 | bias_adapt = bias - grad_bias * meta_step_size 44 | else: 45 | bias_adapt = bias 46 | 47 | else: 48 | grad_weight = Variable(autograd.grad(meta_loss, weight, create_graph=True)[0].data, 49 | requires_grad=False) 50 | if bias is not None: 51 | grad_bias = Variable(autograd.grad(meta_loss, bias, create_graph=True)[0].data, requires_grad=False) 52 | bias_adapt = bias - grad_bias * meta_step_size 53 | else: 54 | bias_adapt = bias 55 | 56 | return F.conv2d(inputs, 57 | weight - grad_weight * meta_step_size, 58 | bias_adapt, stride, 59 | padding, 60 | dilation, groups) 61 | else: 62 | return F.conv2d(inputs, weight, bias, stride, padding, dilation, groups) 63 | 64 | 65 | def relu(inputs): 66 | return F.threshold(inputs, 0, 0, inplace=True) 67 | 68 | 69 | def maxpool(inputs, kernel_size, stride=None, padding=0): 70 | return F.max_pool2d(inputs, kernel_size, stride, padding=padding) 71 | -------------------------------------------------------------------------------- /run_baseline.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | times=5 4 | for j in `seq 1 $times` 5 | do 6 | max=3 7 | for i in `seq 0 $max` 8 | do 9 | python main_baseline.py train \ 10 | --lr=5e-4 \ 11 | --num_classes=7 \ 12 | --test_every=500 \ 13 | --logs='run_'$j'/logs_'$i'/' \ 14 | --batch_size=64 \ 15 | --model_path='run_'$j'/models_'$i'/' \ 16 | --unseen_index=$i \ 17 | --inner_loops=45001 \ 18 | --step_size=15000 \ 19 | --state_dict='' \ 20 | --data_root=$1 21 | done 22 | done -------------------------------------------------------------------------------- /run_mldg.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | times=5 4 | for j in `seq 1 $times` 5 | do 6 | max=3 7 | for i in `seq 0 $max` 8 | do 9 | python 
main_mldg.py train \ 10 | --lr=5e-4 \ 11 | --num_classes=7 \ 12 | --test_every=500 \ 13 | --logs='run_'$j'/logs_mldg_'$i'/' \ 14 | --batch_size=64 \ 15 | --model_path='run_'$j'/models_mldg_'$i'/' \ 16 | --unseen_index=$i \ 17 | --inner_loops=45001 \ 18 | --step_size=15000 \ 19 | --state_dict='' \ 20 | --data_root=$1 \ 21 | --meta_step_size=5e-1 \ 22 | --meta_val_beta=1.0 23 | done 24 | done 25 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.optim as optim 4 | from sklearn.metrics import accuracy_score 5 | 6 | 7 | def unfold_label(labels, classes): 8 | new_labels = [] 9 | 10 | assert len(np.unique(labels)) == classes 11 | # minimum value of labels 12 | mini = np.min(labels) 13 | 14 | for index in range(len(labels)): 15 | dump = np.full(shape=[classes], fill_value=0).astype(np.int8) 16 | _class = int(labels[index]) - mini 17 | dump[_class] = 1 18 | new_labels.append(dump) 19 | 20 | return np.array(new_labels) 21 | 22 | 23 | def shuffle_data(samples, labels): 24 | num = len(labels) 25 | shuffle_index = np.random.permutation(np.arange(num)) 26 | shuffled_samples = samples[shuffle_index] 27 | shuffled_labels = labels[shuffle_index] 28 | return shuffled_samples, shuffled_labels 29 | 30 | 31 | def shuffle_list(li): 32 | np.random.shuffle(li) 33 | return li 34 | 35 | 36 | def shuffle_list_with_ind(li): 37 | shuffle_index = np.random.permutation(np.arange(len(li))) 38 | li = li[shuffle_index] 39 | return li, shuffle_index 40 | 41 | 42 | def num_flat_features(x): 43 | size = x.size()[1:] # all dimensions except the batch dimension 44 | num_features = 1 45 | for s in size: 46 | num_features *= s 47 | return num_features 48 | 49 | 50 | def crossentropyloss(): 51 | loss_fn = torch.nn.CrossEntropyLoss() 52 | return loss_fn 53 | 54 | 55 | def mseloss(): 56 | loss_fn = torch.nn.MSELoss() 57 | return 
loss_fn 58 | 59 | 60 | def sgd(parameters, lr, weight_decay=0.00005, momentum=0.9): 61 | opt = optim.SGD(params=parameters, lr=lr, momentum=momentum, weight_decay=weight_decay) 62 | return opt 63 | 64 | 65 | def fix_seed(): 66 | # torch.manual_seed(1108) 67 | # np.random.seed(1108) 68 | pass 69 | 70 | 71 | def write_log(log, log_path): 72 | f = open(log_path, mode='a') 73 | f.write(str(log)) 74 | f.write('\n') 75 | f.close() 76 | 77 | 78 | def compute_accuracy(predictions, labels): 79 | accuracy = accuracy_score(y_true=np.argmax(labels, axis=-1), y_pred=np.argmax(predictions, axis=-1)) 80 | return accuracy 81 | --------------------------------------------------------------------------------