├── OmniglotBuilder.py ├── README.MD ├── architecture.png ├── data └── data.npy ├── data_loader.py ├── gen_datatset.py ├── mainOmniglot.py ├── matching_networks.py └── result.png /OmniglotBuilder.py: -------------------------------------------------------------------------------- 1 | # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | # Created by: BoyuanJiang 3 | # College of Information Science & Electronic Engineering,ZheJiang University 4 | # Email: ginger188@gmail.com 5 | # Copyright (c) 2017 6 | 7 | # @Time :17-8-29 16:20 8 | # @FILE :OmniglotBuilder.py 9 | # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | from matching_networks import MatchingNetwork 12 | # from MatchingNetwork import MatchingNetwork 13 | import torch 14 | import tqdm 15 | from torch.autograd import Variable 16 | import torch.backends.cudnn as cudnn 17 | from torch.optim.lr_scheduler import ReduceLROnPlateau 18 | 19 | 20 | class OmniglotBuilder: 21 | def __init__(self, data): 22 | """ 23 | Initializes the experiment 24 | :param data: 25 | """ 26 | self.data = data 27 | 28 | def build_experiment(self, batch_size, num_channels, lr, image_size, classes_per_set, samples_per_class, keep_prob, 29 | fce, optim, weight_decay, use_cuda): 30 | """ 31 | 32 | :param batch_size: 33 | :param num_channels: 34 | :param lr: 35 | :param image_size: 36 | :param classes_per_set: 37 | :param samples_per_class: 38 | :param keep_prob: 39 | :param fce: 40 | :param optim: 41 | :param weight_decay: 42 | :param use_cuda: 43 | :return: 44 | """ 45 | self.classes_per_set = classes_per_set 46 | self.sample_per_class = samples_per_class 47 | self.keep_prob = keep_prob 48 | self.batch_size = batch_size 49 | self.lr = lr 50 | self.image_size = image_size 51 | self.optim = optim 52 | self.wd = weight_decay 53 | self.isCuadAvailable = torch.cuda.is_available() 54 | self.use_cuda = use_cuda 55 | self.matchNet = MatchingNetwork(keep_prob, batch_size, num_channels, self.lr, fce, classes_per_set, 56 | samples_per_class, image_size, self.isCuadAvailable & self.use_cuda) 57 | self.total_iter = 0 58 | if self.isCuadAvailable & self.use_cuda: 59 | cudnn.benchmark = True # set True to speedup 60 | torch.cuda.manual_seed_all(2017) 61 | self.matchNet.cuda() 62 | self.total_train_iter = 0 63 | self.optimizer = self._create_optimizer(self.matchNet, self.lr) 64 | self.scheduler = ReduceLROnPlateau(self.optimizer, 'min',verbose=True) 65 | 66 | def run_training_epoch(self, total_train_batches): 67 | """ 68 | Run the training epoch 69 | :param total_train_batches: Number of batches to train on 70 | :return: 71 | """ 72 | total_c_loss = 0.0 73 | total_accuracy = 0.0 74 | # optimizer = self._create_optimizer(self.matchNet, self.lr) 75 | 76 | with tqdm.tqdm(total=total_train_batches) as pbar: 77 | for i in range(total_train_batches): 78 | x_support_set, y_support_set, x_target, y_target = self.data.get_train_batch(True) 79 | x_support_set = Variable(torch.from_numpy(x_support_set)).float() 80 | y_support_set = Variable(torch.from_numpy(y_support_set), requires_grad=False).long() 81 | x_target = Variable(torch.from_numpy(x_target)).float() 82 | y_target = Variable(torch.from_numpy(y_target), requires_grad=False).squeeze().long() 83 | 84 | # convert to one hot encoding 85 | y_support_set = y_support_set.unsqueeze(2) 86 | sequence_length = y_support_set.size()[1] 87 | batch_size = y_support_set.size()[0] 88 | y_support_set_one_hot = Variable( 89 | torch.zeros(batch_size, sequence_length, self.classes_per_set).scatter_(2, 90 | y_support_set.data, 91 | 1), requires_grad=False) 92 | 93 | # reshape channels and change order 94 | size = x_support_set.size() 95 | x_support_set = x_support_set.permute(0, 1, 4, 2, 3) 96 | x_target = x_target.permute(0, 3, 1, 2) 97 | if self.isCuadAvailable & self.use_cuda: 98 | acc, c_loss = self.matchNet(x_support_set.cuda(), y_support_set_one_hot.cuda(), x_target.cuda(), 99 | y_target.cuda()) 100 | else: 101 | acc, c_loss = self.matchNet(x_support_set, y_support_set_one_hot, x_target, y_target) 102 | 103 | # optimize process 104 | self.optimizer.zero_grad() 105 | c_loss.backward() 106 | self.optimizer.step() 107 | 108 | # TODO: update learning rate? 109 | 110 | iter_out = "tr_loss: {}, tr_accuracy: {}".format(c_loss.data[0], acc.data[0]) 111 | pbar.set_description(iter_out) 112 | pbar.update(1) 113 | total_c_loss += c_loss.data[0] 114 | total_accuracy += acc.data[0] 115 | # self.total_train_iter+=1 116 | 117 | total_c_loss = total_c_loss / total_train_batches 118 | total_accuracy = total_accuracy / total_train_batches 119 | return total_c_loss, total_accuracy 120 | 121 | def _create_optimizer(self, model, lr): 122 | # setup optimizer 123 | if self.optim == "adam": 124 | optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=self.wd) 125 | elif self.optim == "sgd": 126 | optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, dampening=0.9, weight_decay=self.wd) 127 | else: 128 | raise Exception("Not a valid optimizer offered: {0}".format(self.optim)) 129 | return optimizer 130 | 131 | def _adjust_learning_rate(self, optimizer): 132 | """ 133 | Update the learning rate after some epochs 134 | :param optimizer: 135 | :return: 136 | """ 137 | 138 | def run_val_epoch(self, total_val_batches): 139 | """ 140 | Run the training epoch 141 | :param total_train_batches: Number of batches to train on 142 | :return: 143 | """ 144 | total_c_loss = 0.0 145 | total_accuracy = 0.0 146 | 147 | with tqdm.tqdm(total=total_val_batches) as pbar: 148 | for i in range(total_val_batches): 149 | x_support_set, y_support_set, x_target, y_target = self.data.get_val_batch(False) 150 | x_support_set = Variable(torch.from_numpy(x_support_set)).float() 151 | y_support_set = Variable(torch.from_numpy(y_support_set), requires_grad=False).long() 152 | x_target = Variable(torch.from_numpy(x_target)).float() 153 | y_target = Variable(torch.from_numpy(y_target), requires_grad=False).squeeze().long() 154 | 155 | # convert to one hot encoding 156 | y_support_set = y_support_set.unsqueeze(2) 157 | sequence_length = y_support_set.size()[1] 158 | batch_size = y_support_set.size()[0] 159 | y_support_set_one_hot = Variable( 160 | torch.zeros(batch_size, sequence_length, self.classes_per_set).scatter_(2, 161 | y_support_set.data, 162 | 1), requires_grad=False) 163 | 164 | # reshape channels and change order 165 | size = x_support_set.size() 166 | x_support_set = x_support_set.permute(0, 1, 4, 2, 3) 167 | x_target = x_target.permute(0, 3, 1, 2) 168 | if self.isCuadAvailable & self.use_cuda: 169 | acc, c_loss = self.matchNet(x_support_set.cuda(), y_support_set_one_hot.cuda(), x_target.cuda(), 170 | y_target.cuda()) 171 | else: 172 | acc, c_loss = self.matchNet(x_support_set, y_support_set_one_hot, x_target, y_target) 173 | 174 | # TODO: update learning rate? 175 | 176 | iter_out = "val_loss: {}, val_accuracy: {}".format(c_loss.data[0], acc.data[0]) 177 | pbar.set_description(iter_out) 178 | pbar.update(1) 179 | total_c_loss += c_loss.data[0] 180 | total_accuracy += acc.data[0] 181 | # self.total_train_iter+=1 182 | 183 | total_c_loss = total_c_loss / total_val_batches 184 | total_accuracy = total_accuracy / total_val_batches 185 | self.scheduler.step(total_c_loss) 186 | return total_c_loss, total_accuracy 187 | 188 | def run_test_epoch(self, total_test_batches): 189 | """ 190 | Run the training epoch 191 | :param total_train_batches: Number of batches to train on 192 | :return: 193 | """ 194 | total_c_loss = 0.0 195 | total_accuracy = 0.0 196 | 197 | with tqdm.tqdm(total=total_test_batches) as pbar: 198 | for i in range(total_test_batches): 199 | x_support_set, y_support_set, x_target, y_target = self.data.get_test_batch(False) 200 | x_support_set = Variable(torch.from_numpy(x_support_set)).float() 201 | y_support_set = Variable(torch.from_numpy(y_support_set), requires_grad=False).long() 202 | x_target = Variable(torch.from_numpy(x_target)).float() 203 | y_target = Variable(torch.from_numpy(y_target), requires_grad=False).squeeze().long() 204 | 205 | # convert to one hot encoding 206 | y_support_set = y_support_set.unsqueeze(2) 207 | sequence_length = y_support_set.size()[1] 208 | batch_size = y_support_set.size()[0] 209 | y_support_set_one_hot = Variable( 210 | torch.zeros(batch_size, sequence_length, self.classes_per_set).scatter_(2, 211 | y_support_set.data, 212 | 1), requires_grad=False) 213 | 214 | # reshape channels and change order 215 | size = x_support_set.size() 216 | x_support_set = x_support_set.permute(0, 1, 4, 2, 3) 217 | x_target = x_target.permute(0, 3, 1, 2) 218 | if self.isCuadAvailable & self.use_cuda: 219 | acc, c_loss = self.matchNet(x_support_set.cuda(), y_support_set_one_hot.cuda(), x_target.cuda(), 220 | y_target.cuda()) 221 | else: 222 | acc, c_loss = self.matchNet(x_support_set, y_support_set_one_hot, x_target, y_target) 223 | 224 | # TODO: update learning rate? 225 | 226 | iter_out = "val_loss: {}, val_accuracy: {}".format(c_loss.data[0], acc.data[0]) 227 | pbar.set_description(iter_out) 228 | pbar.update(1) 229 | total_c_loss += c_loss.data[0] 230 | total_accuracy += acc.data[0] 231 | # self.total_train_iter+=1 232 | 233 | total_c_loss = total_c_loss / total_test_batches 234 | total_accuracy = total_accuracy / total_test_batches 235 | return total_c_loss, total_accuracy -------------------------------------------------------------------------------- /README.MD: -------------------------------------------------------------------------------- 1 | This is the pytorch implement of **[Matching Networks for One Shot Learning](https://arxiv.org/pdf/1606.04080.pdf)** 2 | ![architecture](https://github.com/BoyuanJiang/matching-networks-pytorch/blob/master/architecture.png) 3 | ## Train 4 | If you want to train the model,simply run the code 5 | ``` 6 | python mainOmniglot.py 7 | ``` 8 | 9 | You can set *fce = True* if you want use Full Context Embeddings.You can use the follow set in mainOmniglot.py for a 5-way one shot learning. 10 | ``` 11 | classes_per_set = 5 12 | samples_per_class = 1 13 | ``` 14 | 15 | ## Result 16 | After about 30 epoches,you can achieve about 97% accuracy on train set and 96% on val and test set. 17 | ![result](https://github.com/BoyuanJiang/matching-networks-pytorch/blob/master/result.png) 18 | 19 | ## Acknowledgement 20 | This work refer to [gitabcworld](https://github.com/gitabcworld/)'s implement,detail at[https://github.com/AntreasAntoniou/MatchingNetworks](https://github.com/AntreasAntoniou/MatchingNetworks). -------------------------------------------------------------------------------- /architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyuanJiang/matching-networks-pytorch/3f7bc55ce531143004cc4bc6138a3485d4886794/architecture.png -------------------------------------------------------------------------------- /data/data.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyuanJiang/matching-networks-pytorch/3f7bc55ce531143004cc4bc6138a3485d4886794/data/data.npy -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | # Created by: BoyuanJiang 3 | # College of Information Science & Electronic Engineering,ZheJiang University 4 | # Email: ginger188@gmail.com 5 | # Copyright (c) 2017 6 | 7 | # @Time :17-8-27 10:46 8 | # @FILE :data_loader.py 9 | # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | import numpy as np 12 | 13 | 14 | class OmniglotNShotDataset(): 15 | def __init__(self, batch_size, classes_per_set=20, samples_per_class=1, seed=2017, shuffle=True, use_cache=True): 16 | """ 17 | Construct N-shot dataset 18 | :param batch_size: Experiment batch_size 19 | :param classes_per_set: Integer indicating the number of classes per set 20 | :param samples_per_class: Integer indicating samples per class 21 | :param seed: seed for random function 22 | :param shuffle: if shuffle the dataset 23 | :param use_cache: if true,cache dataset to memory.It can speedup the train but require larger memory 24 | """ 25 | np.random.seed(seed) 26 | self.x = np.load('data/data.npy') 27 | self.x = np.reshape(self.x, newshape=(self.x.shape[0], self.x.shape[1], 28, 28, 1)) 28 | if shuffle: 29 | np.random.shuffle(self.x) 30 | self.x_train, self.x_val, self.x_test = self.x[:1200], self.x[1200:1411], self.x[1411:] 31 | # self.mean = np.mean(list(self.x_train) + list(self.x_val)) 32 | self.x_train = self.processes_batch(self.x_train, np.mean(self.x_train), np.std(self.x_train)) 33 | self.x_test = self.processes_batch(self.x_test, np.mean(self.x_test), np.std(self.x_test)) 34 | self.x_val = self.processes_batch(self.x_val, np.mean(self.x_val), np.std(self.x_val)) 35 | # self.std = np.std(list(self.x_train) + list(self.x_val)) 36 | self.batch_size = batch_size 37 | self.n_classes = self.x.shape[0] 38 | self.classes_per_set = classes_per_set 39 | self.samples_per_class = samples_per_class 40 | self.indexes = {"train": 0, "val": 0, "test": 0} 41 | self.datatset = {"train": self.x_train, "val": self.x_val, "test": self.x_test} 42 | self.use_cache = use_cache 43 | if self.use_cache: 44 | self.cached_datatset = {"train": self.load_data_cache(self.x_train), 45 | "val": self.load_data_cache(self.x_val), 46 | "test": self.load_data_cache(self.x_test)} 47 | 48 | def processes_batch(self, x_batch, mean, std): 49 | """ 50 | Normalizes a batch images 51 | :param x_batch: a batch images 52 | :return: normalized images 53 | """ 54 | return (x_batch - mean) / std 55 | 56 | def _sample_new_batch(self, data_pack): 57 | """ 58 | Collect 1000 batches data for N-shot learning 59 | :param data_pack: one of(train,test,val) dataset shape[classes_num,20,28,28,1] 60 | :return: A list with [support_set_x,support_set_y,target_x,target_y] ready to be fed to our networks 61 | """ 62 | support_set_x = np.zeros((self.batch_size, self.classes_per_set, self.samples_per_class, data_pack.shape[2], 63 | data_pack.shape[3], data_pack.shape[4]), np.float32) 64 | 65 | support_set_y = np.zeros((self.batch_size, self.classes_per_set, self.samples_per_class), np.int32) 66 | target_x = np.zeros((self.batch_size, data_pack.shape[2], data_pack.shape[3], data_pack.shape[4]), np.float32) 67 | target_y = np.zeros((self.batch_size, 1), np.int32) 68 | 69 | for i in range(self.batch_size): 70 | classes_idx = np.arange(data_pack.shape[0]) 71 | samples_idx = np.arange(data_pack.shape[1]) 72 | choose_classes = np.random.choice(classes_idx, size=self.classes_per_set, replace=False) 73 | choose_label = np.random.choice(self.classes_per_set, size=1) 74 | choose_samples = np.random.choice(samples_idx, size=self.samples_per_class + 1, replace=False) 75 | 76 | x_temp = data_pack[choose_classes] 77 | x_temp = x_temp[:, choose_samples] 78 | y_temp = np.arange(self.classes_per_set) 79 | support_set_x[i] = x_temp[:, :-1] 80 | support_set_y[i] = np.expand_dims(y_temp[:], axis=1) 81 | target_x[i] = x_temp[choose_label, -1] 82 | target_y[i] = y_temp[choose_label] 83 | 84 | return support_set_x, support_set_y, target_x, target_y 85 | 86 | def _rotate_data(self, image, k): 87 | """ 88 | Rotates one image by self.k * 90 degrees counter-clockwise 89 | :param image: Image to rotate 90 | :return: Rotated Image 91 | """ 92 | return np.rot90(image, k) 93 | 94 | def _rotate_batch(self, batch_images, k): 95 | """ 96 | Rotates a whole image batch 97 | :param batch_images: A batch of images 98 | :param k: integer degree of rotation counter-clockwise 99 | :return: The rotated batch of images 100 | """ 101 | batch_size = batch_images.shape[0] 102 | for i in np.arange(batch_size): 103 | batch_images[i] = self._rotate_data(batch_images[i], k) 104 | return batch_images 105 | 106 | def _get_batch(self, dataset_name, augment=False): 107 | """ 108 | Get next batch from the dataset with name. 109 | :param dataset_name: The name of dataset(one of "train","val","test") 110 | :param augment: if rotate the images 111 | :return: a batch images 112 | """ 113 | if self.use_cache: 114 | support_set_x, support_set_y, target_x, target_y = self._get_batch_from_cache(dataset_name) 115 | else: 116 | support_set_x, support_set_y, target_x, target_y = self._sample_new_batch(self.datatset[dataset_name]) 117 | if augment: 118 | k = np.random.randint(0, 4, size=(self.batch_size, self.classes_per_set)) 119 | a_support_set_x = [] 120 | a_target_x = [] 121 | for b in range(self.batch_size): 122 | temp_class_set = [] 123 | for c in range(self.classes_per_set): 124 | temp_class_set_x = self._rotate_batch(support_set_x[b, c], k=k[b, c]) 125 | if target_y[b] == support_set_y[b, c, 0]: 126 | temp_target_x = self._rotate_data(target_x[b], k=k[b, c]) 127 | temp_class_set.append(temp_class_set_x) 128 | a_support_set_x.append(temp_class_set) 129 | a_target_x.append(temp_target_x) 130 | support_set_x = np.array(a_support_set_x) 131 | target_x = np.array(a_target_x) 132 | support_set_x = support_set_x.reshape((support_set_x.shape[0], support_set_x.shape[1] * support_set_x.shape[2], 133 | support_set_x.shape[3], support_set_x.shape[4], support_set_x.shape[5])) 134 | support_set_y = support_set_y.reshape(support_set_y.shape[0], support_set_y.shape[1] * support_set_y.shape[2]) 135 | return support_set_x, support_set_y, target_x, target_y 136 | 137 | def get_train_batch(self, augment=False): 138 | return self._get_batch("train", augment) 139 | 140 | def get_val_batch(self, augment=False): 141 | return self._get_batch("val", augment) 142 | 143 | def get_test_batch(self, augment=False): 144 | return self._get_batch("test", augment) 145 | 146 | def load_data_cache(self, data_pack, argument=True): 147 | """ 148 | cache the dataset in memory 149 | :param data_pack: shape[classes_num,20,28,28,1] 150 | :return: 151 | """ 152 | cached_dataset = [] 153 | classes_idx = np.arange(data_pack.shape[0]) 154 | samples_idx = np.arange(data_pack.shape[1]) 155 | for _ in range(1000): 156 | support_set_x = np.zeros((self.batch_size, self.classes_per_set, self.samples_per_class, data_pack.shape[2], 157 | data_pack.shape[3], data_pack.shape[4]), np.float32) 158 | 159 | support_set_y = np.zeros((self.batch_size, self.classes_per_set, self.samples_per_class), np.int32) 160 | target_x = np.zeros((self.batch_size, data_pack.shape[2], data_pack.shape[3], data_pack.shape[4]), 161 | np.float32) 162 | target_y = np.zeros((self.batch_size, 1), np.int32) 163 | for i in range(self.batch_size): 164 | choose_classes = np.random.choice(classes_idx, size=self.classes_per_set, replace=False) 165 | choose_label = np.random.choice(self.classes_per_set, size=1) 166 | choose_samples = np.random.choice(samples_idx, size=self.samples_per_class + 1, replace=False) 167 | 168 | x_temp = data_pack[choose_classes] 169 | x_temp = x_temp[:, choose_samples] 170 | y_temp = np.arange(self.classes_per_set) 171 | support_set_x[i] = x_temp[:, :-1] 172 | support_set_y[i] = np.expand_dims(y_temp[:], axis=1) 173 | target_x[i] = x_temp[choose_label, -1] 174 | target_y[i] = y_temp[choose_label] 175 | cached_dataset.append([support_set_x, support_set_y, target_x, target_y]) 176 | return cached_dataset 177 | 178 | def _get_batch_from_cache(self, dataset_name): 179 | """ 180 | 181 | :param dataset_name: 182 | :return: 183 | """ 184 | if self.indexes[dataset_name] >= len(self.cached_datatset[dataset_name]): 185 | self.indexes[dataset_name] = 0 186 | self.cached_datatset[dataset_name] = self.load_data_cache(self.datatset[dataset_name]) 187 | next_batch = self.cached_datatset[dataset_name][self.indexes[dataset_name]] 188 | self.indexes[dataset_name] += 1 189 | x_support_set, y_support_set, x_target, y_target = next_batch 190 | return x_support_set, y_support_set, x_target, y_target 191 | -------------------------------------------------------------------------------- /gen_datatset.py: -------------------------------------------------------------------------------- 1 | # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | # Created by: BoyuanJiang 3 | # College of Information Science & Electronic Engineering,ZheJiang University 4 | # Email: ginger188@gmail.com 5 | # Copyright (c) 2017 6 | 7 | # @Time :17-8-27 10:18 8 | # @FILE :gen_datatset.py 9 | # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | import numpy as np 12 | from scipy import misc 13 | import os 14 | 15 | dataset = [] 16 | examples = [] 17 | # images_background 18 | data_root = "./data/" 19 | alphabets = os.listdir(data_root + "images_background") 20 | for alphabet in alphabets: 21 | characters = os.listdir(os.path.join(data_root, "images_background", alphabet)) 22 | for character in characters: 23 | files = os.listdir(os.path.join(data_root, "images_background", alphabet, character)) 24 | examples = [] 25 | for img_file in files: 26 | img = misc.imresize( 27 | misc.imread(os.path.join(data_root, "images_background", alphabet, character, img_file)), [28, 28]) 28 | # img = (np.float32(img) / 255.).flatten() 29 | examples.append(img) 30 | dataset.append(examples) 31 | 32 | # images_evaluation 33 | data_root = "./data/" 34 | alphabets = os.listdir(data_root + "images_evaluation") 35 | for alphabet in alphabets: 36 | characters = os.listdir(os.path.join(data_root, "images_evaluation", alphabet)) 37 | for character in characters: 38 | files = os.listdir(os.path.join(data_root, "images_evaluation", alphabet, character)) 39 | examples = [] 40 | for img_file in files: 41 | img = misc.imresize( 42 | misc.imread(os.path.join(data_root, "images_evaluation", alphabet, character, img_file)), [28, 28]) 43 | # img = (np.float32(img) / 255.).flatten() 44 | examples.append(img) 45 | dataset.append(examples) 46 | 47 | np.save(data_root + "dataset.npy", np.asarray(dataset)) 48 | -------------------------------------------------------------------------------- /mainOmniglot.py: -------------------------------------------------------------------------------- 1 | # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | # Created by: BoyuanJiang 3 | # College of Information Science & Electronic Engineering,ZheJiang University 4 | # Email: ginger188@gmail.com 5 | # Copyright (c) 2017 6 | 7 | # @Time :17-8-29 22:26 8 | # @FILE :mainOmniglot.py 9 | # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | 12 | from data_loader import OmniglotNShotDataset 13 | from OmniglotBuilder import OmniglotBuilder 14 | import tqdm 15 | 16 | # Experiment setup 17 | batch_size = 20 18 | fce = True 19 | classes_per_set = 20 20 | samples_per_class = 1 21 | channels = 1 22 | # Training setup 23 | total_epochs = 100 24 | total_train_batches = 1000 25 | total_val_batches = 250 26 | total_test_batches = 500 27 | best_val_acc = 0.0 28 | 29 | data = OmniglotNShotDataset(batch_size=batch_size, classes_per_set=classes_per_set, 30 | samples_per_class=samples_per_class, seed=2017, shuffle=True, use_cache=False) 31 | obj_oneShotBuilder = OmniglotBuilder(data) 32 | obj_oneShotBuilder.build_experiment(batch_size=batch_size, num_channels=1, lr=1e-3, image_size=28, classes_per_set=20, 33 | samples_per_class=1, keep_prob=0.0, fce=True, optim="adam", weight_decay=0, 34 | use_cuda=True) 35 | 36 | with tqdm.tqdm(total=total_train_batches) as pbar_e: 37 | for e in range(total_epochs): 38 | total_c_loss, total_accuracy = obj_oneShotBuilder.run_training_epoch(total_train_batches) 39 | print("Epoch {}: train_loss:{} train_accuracy:{}".format(e, total_c_loss, total_accuracy)) 40 | total_val_c_loss, total_val_accuracy = obj_oneShotBuilder.run_val_epoch(total_val_batches) 41 | print("Epoch {}: val_loss:{} val_accuracy:{}".format(e, total_val_c_loss, total_val_accuracy)) 42 | if total_val_accuracy>best_val_acc: 43 | best_val_acc = total_val_accuracy 44 | total_test_c_loss, total_test_accuracy = obj_oneShotBuilder.run_test_epoch(total_test_batches) 45 | print("Epoch {}: test_loss:{} test_accuracy:{}".format(e, total_test_c_loss, total_test_accuracy)) 46 | pbar_e.update(1) 47 | -------------------------------------------------------------------------------- /matching_networks.py: -------------------------------------------------------------------------------- 1 | # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | # Created by: BoyuanJiang 3 | # College of Information Science & Electronic Engineering,ZheJiang University 4 | # Email: ginger188@gmail.com 5 | # Copyright (c) 2017 6 | 7 | # @Time :17-8-27 21:25 8 | # @FILE :matching_networks.py 9 | # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | import torch 12 | import torch.nn as nn 13 | import math 14 | import numpy as np 15 | import torch.nn.functional as F 16 | from torch.autograd import Variable 17 | 18 | 19 | def convLayer(in_channels, out_channels, keep_prob=0.0): 20 | """3*3 convolution with padding,ever time call it the output size become half""" 21 | cnn_seq = nn.Sequential( 22 | nn.Conv2d(in_channels, out_channels, 3, 1, 1), 23 | nn.ReLU(True), 24 | nn.BatchNorm2d(out_channels), 25 | nn.MaxPool2d(kernel_size=2, stride=2), 26 | nn.Dropout(keep_prob) 27 | ) 28 | return cnn_seq 29 | 30 | 31 | class Classifier(nn.Module): 32 | def __init__(self, layer_size=64, num_channels=1, keep_prob=1.0, image_size=28): 33 | super(Classifier, self).__init__() 34 | """ 35 | Build a CNN to produce embeddings 36 | :param layer_size:64(default) 37 | :param num_channels: 38 | :param keep_prob: 39 | :param image_size: 40 | """ 41 | self.layer1 = convLayer(num_channels, layer_size, keep_prob) 42 | self.layer2 = convLayer(layer_size, layer_size, keep_prob) 43 | self.layer3 = convLayer(layer_size, layer_size, keep_prob) 44 | self.layer4 = convLayer(layer_size, layer_size, keep_prob) 45 | 46 | finalSize = int(math.floor(image_size / (2 * 2 * 2 * 2))) 47 | self.outSize = finalSize * finalSize * layer_size 48 | 49 | def forward(self, image_input): 50 | """ 51 | Use CNN defined above 52 | :param image_input: 53 | :return: 54 | """ 55 | x = self.layer1(image_input) 56 | x = self.layer2(x) 57 | x = self.layer3(x) 58 | x = self.layer4(x) 59 | x = x.view(x.size()[0], -1) 60 | return x 61 | 62 | 63 | class AttentionalClassify(nn.Module): 64 | def __init__(self): 65 | super(AttentionalClassify, self).__init__() 66 | 67 | def forward(self, similarities, support_set_y): 68 | """ 69 | Products pdfs over the support set classes for the target set image. 70 | :param similarities: A tensor with cosine similarites of size[batch_size,sequence_length] 71 | :param support_set_y:[batch_size,sequence_length,classes_num] 72 | :return: Softmax pdf shape[batch_size,classes_num] 73 | """ 74 | softmax = nn.Softmax() 75 | softmax_similarities = softmax(similarities) 76 | preds = softmax_similarities.unsqueeze(1).bmm(support_set_y).squeeze() 77 | return preds 78 | 79 | 80 | class DistanceNetwork(nn.Module): 81 | """ 82 | This model calculates the cosine distance between each of the support set embeddings and the target image embeddings. 83 | """ 84 | 85 | def __init__(self): 86 | super(DistanceNetwork, self).__init__() 87 | 88 | def forward(self, support_set, input_image): 89 | """ 90 | forward implement 91 | :param support_set:the embeddings of the support set images.shape[sequence_length,batch_size,64] 92 | :param input_image: the embedding of the target image,shape[batch_size,64] 93 | :return:shape[batch_size,sequence_length] 94 | """ 95 | eps = 1e-10 96 | similarities = [] 97 | for support_image in support_set: 98 | sum_support = torch.sum(torch.pow(support_image, 2), 1) 99 | support_manitude = sum_support.clamp(eps, float("inf")).rsqrt() 100 | dot_product = input_image.unsqueeze(1).bmm(support_image.unsqueeze(2)).squeeze() 101 | cosine_similarity = dot_product * support_manitude 102 | similarities.append(cosine_similarity) 103 | similarities = torch.stack(similarities) 104 | return similarities.t() 105 | 106 | 107 | class BidirectionalLSTM(nn.Module): 108 | def __init__(self, layer_size, batch_size, vector_dim,use_cuda): 109 | super(BidirectionalLSTM, self).__init__() 110 | """ 111 | Initial a muti-layer Bidirectional LSTM 112 | :param layer_size: a list of each layer'size 113 | :param batch_size: 114 | :param vector_dim: 115 | """ 116 | self.batch_size = batch_size 117 | self.hidden_size = layer_size[0] 118 | self.vector_dim = vector_dim 119 | self.num_layer = len(layer_size) 120 | self.use_cuda = use_cuda 121 | self.lstm = nn.LSTM(input_size=self.vector_dim, num_layers=self.num_layer, hidden_size=self.hidden_size, 122 | bidirectional=True) 123 | self.hidden = self.init_hidden(self.use_cuda) 124 | 125 | def init_hidden(self,use_cuda): 126 | if use_cuda: 127 | return (Variable(torch.zeros(self.lstm.num_layers * 2, self.batch_size, self.lstm.hidden_size),requires_grad=False).cuda(), 128 | Variable(torch.zeros(self.lstm.num_layers * 2, self.batch_size, self.lstm.hidden_size),requires_grad=False).cuda()) 129 | else: 130 | return (Variable(torch.zeros(self.lstm.num_layers * 2, self.batch_size, self.lstm.hidden_size),requires_grad=False), 131 | Variable(torch.zeros(self.lstm.num_layers * 2, self.batch_size, self.lstm.hidden_size),requires_grad=False)) 132 | 133 | def repackage_hidden(self,h): 134 | """Wraps hidden states in new Variables, to detach them from their history.""" 135 | if type(h) == Variable: 136 | return Variable(h.data) 137 | else: 138 | return tuple(self.repackage_hidden(v) for v in h) 139 | 140 | def forward(self, inputs): 141 | # self.hidden = self.init_hidden(self.use_cuda) 142 | self.hidden = self.repackage_hidden(self.hidden) 143 | output, self.hidden = self.lstm(inputs, self.hidden) 144 | return output 145 | 146 | 147 | class MatchingNetwork(nn.Module): 148 | def __init__(self, keep_prob, batch_size=32, num_channels=1, learning_rate=1e-3, fce=False, num_classes_per_set=20, \ 149 | num_samples_per_class=1, image_size=28, use_cuda=True): 150 | """ 151 | This is our main network 152 | :param keep_prob: dropout rate 153 | :param batch_size: 154 | :param num_channels: 155 | :param learning_rate: 156 | :param fce: Flag indicating whether to use full context embeddings(i.e. apply an LSTM on the CNN embeddings) 157 | :param num_classes_per_set: 158 | :param num_samples_per_class: 159 | :param image_size: 160 | """ 161 | super(MatchingNetwork, self).__init__() 162 | self.batch_size = batch_size 163 | self.keep_prob = keep_prob 164 | self.num_channels = num_channels 165 | self.learning_rate = learning_rate 166 | self.fce = fce 167 | self.num_classes_per_set = num_classes_per_set 168 | self.num_samples_per_class = num_samples_per_class 169 | self.image_size = image_size 170 | self.g = Classifier(layer_size=64, num_channels=num_channels, keep_prob=keep_prob, image_size=image_size) 171 | self.dn = DistanceNetwork() 172 | self.classify = AttentionalClassify() 173 | if self.fce: 174 | self.lstm = BidirectionalLSTM(layer_size=[32], batch_size=self.batch_size, vector_dim=self.g.outSize,use_cuda=use_cuda) 175 | 176 | def forward(self, support_set_images, support_set_y_one_hot, target_image, target_y): 177 | """ 178 | Main process of the network 179 | :param support_set_images: shape[batch_size,sequence_length,num_channels,image_size,image_size] 180 | :param support_set_y_one_hot: shape[batch_size,sequence_length,num_classes_per_set] 181 | :param target_image: shape[batch_size,num_channels,image_size,image_size] 182 | :param target_y: 183 | :return: 184 | """ 185 | # produce embeddings for support set images 186 | encoded_images = [] 187 | for i in np.arange(support_set_images.size(1)): 188 | gen_encode = self.g(support_set_images[:, i, :, :]) 189 | encoded_images.append(gen_encode) 190 | 191 | # produce embeddings for target images 192 | gen_encode = self.g(target_image) 193 | encoded_images.append(gen_encode) 194 | output = torch.stack(encoded_images) 195 | 196 | # use fce? 197 | if self.fce: 198 | outputs = self.lstm(output) 199 | 200 | # get similarities between support set embeddings and target 201 | similarites = self.dn(support_set=output[:-1], input_image=output[-1]) 202 | 203 | # produce predictions for target probabilities 204 | preds = self.classify(similarites, support_set_y=support_set_y_one_hot) 205 | 206 | # calculate the accuracy 207 | values, indices = preds.max(1) 208 | accuracy = torch.mean((indices.squeeze() == target_y).float()) 209 | crossentropy_loss = F.cross_entropy(preds, target_y.long()) 210 | 211 | return accuracy, crossentropy_loss 212 | -------------------------------------------------------------------------------- /result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyuanJiang/matching-networks-pytorch/3f7bc55ce531143004cc4bc6138a3485d4886794/result.png --------------------------------------------------------------------------------