├── LICENSE
├── README.md
├── cnn.py
├── data_utils.py
├── memory.py
├── omniglot.py
└── siamese.py

/LICENSE:
--------------------------------------------------------------------------------
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

[Standard Apache License 2.0 text: definitions; copyright and patent grants;
redistribution conditions; contribution, trademark, warranty, and liability
terms. See http://www.apache.org/licenses/LICENSE-2.0 for the full license.]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# LSH_Memory
One-Shot Learning using Nearest-Neighbor Search (NNS) and Locality-Sensitive Hashing (LSH)

Run `data_utils.py` to download and pre-process the Omniglot dataset.

# Author: [Ryan Spring](https://github.com/rdspring1)

# Implementation
1. [Siamese Neural Networks for One-shot Image Recognition](https://www.cs.cmu.edu/~rsalakhu/papers/oneshot1.pdf)
2. [Learning to Remember Rare Events](https://arxiv.org/abs/1703.03129)

# Other References
1. [One-shot Learning with Memory-Augmented Neural Networks](https://arxiv.org/abs/1605.06065.pdf)
2. [Matching Networks for One Shot Learning](https://arxiv.org/abs/1606.04080.pdf)
    * [Andrej Karpathy's Notes - Matching Networks](https://github.com/karpathy/paper-notes/blob/master/matching_networks.md)
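# Example
A minimal sketch of sampling one training episode after running `data_utils.py` (the episode settings below match `cnn.py`; `train_omni.pkl` is the file `data_utils.py` writes):

```python
import omniglot

trainset = omniglot.OmniglotDataset('train_omni.pkl')
loader = trainset.sample_episode_batch(
    episode_length=30, episode_width=5, batch_size=16, N=1)

x, y = next(loader)         # one episode batch
print(len(x), x[0].size())  # 30 steps, each a (16, 1, 28, 28) image tensor
print(len(y), y[0].size())  # 30 steps, each a (16,) label tensor
```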
--------------------------------------------------------------------------------
/cnn.py:
--------------------------------------------------------------------------------
import os
import sys
import numpy as np

import torch
import torchvision
import torchvision.transforms as transforms

from torch.autograd import Variable
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

import omniglot
import memory

class Net(nn.Module):
    def __init__(self, input_shape):
        super(Net, self).__init__()
        # Constants
        kernel = 3
        pad = int((kernel-1)/2.0)
        p = 0.3

        ch, row, col = input_shape
        self.conv1 = nn.Conv2d(ch, 32, kernel, padding=(pad, pad))
        self.conv2 = nn.Conv2d(32, 32, kernel, padding=(pad, pad))
        self.conv3 = nn.Conv2d(32, 64, kernel, padding=(pad, pad))
        self.conv4 = nn.Conv2d(64, 64, kernel, padding=(pad, pad))
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(row // 4 * col // 4 * 64, 128)
        self.dropout = nn.Dropout(p)

    def forward(self, x, predict):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        if not predict:
            # apply dropout only while training
            x = self.dropout(x)
        return x
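# Shape check (illustrative, not part of the original script): the two 2x2
# max-pools reduce 28x28 inputs to 7x7, so fc1 consumes 7*7*64 = 3136
# features and emits the 128-dimensional embedding used as the memory key.
#
#   net = Net(input_shape=(1, 28, 28))
#   out = net(Variable(torch.randn(4, 1, 28, 28)), True)
#   assert out.size() == (4, 128)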
memory_size = 8192
batch_size = 16
key_dim = 128
episode_length = 30
episode_width = 5
validation_frequency = 20
DATA_FILE_FORMAT = os.path.join(os.getcwd(), '%s_omni.pkl')

train_filepath = DATA_FILE_FORMAT % 'train'
trainset = omniglot.OmniglotDataset(train_filepath)
trainloader = trainset.sample_episode_batch(episode_length, episode_width, batch_size, N=100000)

test_filepath = DATA_FILE_FORMAT % 'test'
testset = omniglot.OmniglotDataset(test_filepath)

#torch.cuda.set_device(1)
net = Net(input_shape=(1,28,28))
mem = memory.Memory(memory_size, key_dim)
net.add_module("memory", mem)
net.cuda()

optimizer = optim.Adam(net.parameters(), lr=1e-4, eps=1e-4)

cumulative_loss = 0
counter = 0
for i, data in enumerate(trainloader, 0):
    # erase memory before each training episode
    mem.build()
    x, y = data
    for xx, yy in zip(x, y):
        optimizer.zero_grad()
        xx_cuda, yy_cuda = Variable(xx.cuda()), Variable(yy.cuda())
        embed = net(xx_cuda, False)
        yy_hat, softmax_embed, loss = mem.query(embed, yy_cuda, False)
        loss.backward()
        optimizer.step()
        cumulative_loss += loss.data[0]
        counter += 1

    if i % validation_frequency == 0:
        # validation
        correct = []
        correct_by_k_shot = dict((k, list()) for k in range(episode_width + 1))
        testloader = testset.sample_episode_batch(episode_length, episode_width, batch_size=1, N=50)

        for data in testloader:
            # erase memory before each validation episode
            mem.build()

            x, y = data
            y_hat = []
            for xx, yy in zip(x, y):
                xx_cuda, yy_cuda = Variable(xx.cuda()), Variable(yy.cuda())
                query = net(xx_cuda, True)
                yy_hat, embed, loss = mem.query(query, yy_cuda, True)
                y_hat.append(yy_hat)
                correct.append(float(torch.equal(yy_hat.cpu(), torch.unsqueeze(yy, dim=1))))

            # compute per-shot accuracies: the k-shot bucket holds predictions
            # made the (k+1)-th time a label is seen within the episode
            seen_count = [0 for idx in range(episode_width)]
            # loop over episode steps
            for yy, yy_hat in zip(y, y_hat):
                count = seen_count[yy[0] % episode_width]
                if count < (episode_width + 1):
                    correct_by_k_shot[count].append(float(torch.equal(yy_hat.cpu(), torch.unsqueeze(yy, dim=1))))
                seen_count[yy[0] % episode_width] += 1

        print("episode batch: {0:d} average loss: {1:.6f}".format(i, (cumulative_loss / (counter))))
        print("validation overall accuracy {0:f}".format(np.mean(correct)))
        for idx in range(episode_width + 1):
            print("{0:d}-shot: {1:.3f}".format(idx, np.mean(correct_by_k_shot[idx])))
        cumulative_loss = 0
        counter = 0
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
import _pickle as pickle
import logging
import os
import subprocess

import numpy as np
from scipy.misc import imresize
from scipy.misc import imrotate
from scipy.ndimage import imread
import tensorflow as tf


REPO_LOCATION = 'https://github.com/brendenlake/omniglot.git'
REPO_DIR = os.path.join(os.getcwd(), 'omniglot')
DATA_DIR = os.path.join(REPO_DIR, 'python')
TRAIN_DIR = os.path.join(DATA_DIR, 'images_background')
TEST_DIR = os.path.join(DATA_DIR, 'images_evaluation')
DATA_FILE_FORMAT = os.path.join(os.getcwd(), '%s_omni.pkl')

TEST_ROTATIONS = False  # augment testing data with rotations
IMAGE_ORIGINAL_SIZE = 105
IMAGE_NEW_SIZE = 28

def crawl_directory(directory, augment_with_rotations=False, first_label=0):
    """Crawls the data directory and returns images, labels, and file info."""
    label_idx = first_label
    images = []
    labels = []
    info = []

    # traverse root directory
    for root, _, files in os.walk(directory):
        logging.info('Reading files from %s', root)

        for file_name in files:
            full_file_name = os.path.join(root, file_name)
            img = imread(full_file_name, flatten=True)
            for idx, angle in enumerate([0, 90, 180, 270]):
                if not augment_with_rotations and idx > 0:
                    break

                images.append(imrotate(img, angle))
                labels.append(label_idx + idx)
                info.append(full_file_name)

        # each Omniglot character directory holds exactly 20 drawings, so a
        # new label block starts only after a full character has been read
        if len(files) == 20:
            label_idx += 4 if augment_with_rotations else 1
    return images, labels, info
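# Label bookkeeping example (illustrative directory layout): with
# augment_with_rotations=True, the 20 images of the first character yield
# labels 0-3 (rotations 0, 90, 180, 270), the second character yields
# labels 4-7, and so on; without augmentation each character maps to a
# single label.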
def resize_images(images, new_width, new_height):
    """Resize images to new dimensions."""
    resized_images = np.zeros([images.shape[0], new_width, new_height], dtype=np.float32)

    for idx in range(images.shape[0]):
        resized_images[idx, :, :] = imresize(images[idx, :, :],
                                             [new_width, new_height],
                                             interp='bilinear',
                                             mode=None)
    return resized_images


def write_datafiles(directory, write_file,
                    resize=True, rotate=False,
                    new_width=IMAGE_NEW_SIZE, new_height=IMAGE_NEW_SIZE,
                    first_label=0):
    """Load and preprocess images from a directory and write them to a file.

    Args:
      directory: Directory of alphabet sub-directories.
      write_file: Filename to write to.
      resize: Whether to resize the images.
      rotate: Whether to augment the dataset with rotations.
      new_width: New resize width.
      new_height: New resize height.
      first_label: Label to start with.

    Returns:
      Number of new labels created.
    """

    # these are the default sizes for Omniglot:
    imgwidth = IMAGE_ORIGINAL_SIZE
    imgheight = IMAGE_ORIGINAL_SIZE

    logging.info('Reading the data.')
    images, labels, info = crawl_directory(directory, augment_with_rotations=rotate, first_label=first_label)

    images_np = np.zeros([len(images), imgwidth, imgheight], dtype=np.bool)
    labels_np = np.zeros([len(labels)], dtype=np.uint32)
    for idx in range(len(images)):
        images_np[idx, :, :] = images[idx]
        labels_np[idx] = labels[idx]

    if resize:
        logging.info('Resizing images.')
        resized_images = resize_images(images_np, new_width, new_height)

        logging.info('Writing resized data in float32 format.')
        data = {'images': resized_images,
                'labels': labels_np,
                'info': info}
        # write in binary mode; pickle emits bytes
        with tf.gfile.GFile(write_file, 'wb') as f:
            pickle.dump(data, f)
    else:
        logging.info('Writing original sized data in boolean format.')
        data = {'images': images_np,
                'labels': labels_np,
                'info': info}
        with tf.gfile.GFile(write_file, 'wb') as f:
            pickle.dump(data, f)

    return len(np.unique(labels_np))

def maybe_download_data():
    """Download the Omniglot repo if it does not exist."""
    if os.path.exists(REPO_DIR):
        logging.info('It appears that the Git repo already exists.')
    else:
        logging.info('It appears that the Git repo does not exist.')
        logging.info('Cloning now.')

        subprocess.check_output('git clone %s' % REPO_LOCATION, shell=True)

    if os.path.exists(TRAIN_DIR):
        logging.info('It appears that the train data has already been unzipped.')
    else:
        logging.info('It appears that the train data has not been unzipped.')
        logging.info('Unzipping now.')

        subprocess.check_output('unzip %s.zip -d %s' % (TRAIN_DIR, DATA_DIR), shell=True)

    if os.path.exists(TEST_DIR):
        logging.info('It appears that the test data has already been unzipped.')
    else:
        logging.info('It appears that the test data has not been unzipped.')
        logging.info('Unzipping now.')

        subprocess.check_output('unzip %s.zip -d %s' % (TEST_DIR, DATA_DIR), shell=True)

def preprocess_omniglot():
    """Download and prepare the raw Omniglot data.

    Downloads the data from GitHub if it does not exist.
    Then loads the images, augments with rotations if desired,
    resizes the images, and writes them to a pickle file.
    """

    maybe_download_data()

    directory = TRAIN_DIR
    write_file = DATA_FILE_FORMAT % 'train'
    num_labels = write_datafiles(directory, write_file, resize=True, rotate=True, new_width=IMAGE_NEW_SIZE, new_height=IMAGE_NEW_SIZE)

    directory = TEST_DIR
    write_file = DATA_FILE_FORMAT % 'test'
    write_datafiles(directory, write_file, resize=True, rotate=False, new_width=IMAGE_NEW_SIZE, new_height=IMAGE_NEW_SIZE)

preprocess_omniglot()
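# Sketch (assumes preprocess_omniglot() above has run): the pickle can be
# inspected with the same keys the datasets read back.
#
#   with open('train_omni.pkl', 'rb') as f:
#       data = pickle.load(f)
#   print(data['images'].shape)  # (num_images, 28, 28) after resizing
#   print(data['labels'].shape)  # (num_images,) uint32 class ids
#   print(data['info'][0])       # source filename of the first image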
--------------------------------------------------------------------------------
/memory.py:
--------------------------------------------------------------------------------
import torch
import torch.autograd as ag
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import functools

def random_uniform(shape, low, high, cuda):
    x = torch.rand(*shape)
    result_cpu = (high - low) * x + low
    if cuda:
        return result_cpu.cuda()
    else:
        return result_cpu

def multiply(x):
    return functools.reduce(lambda x,y: x*y, x, 1)

def flatten(x):
    """ Flatten matrix into a vector """
    count = multiply(x.size())
    return x.resize_(count)

def index(batch_size, x):
    idx = torch.arange(0, batch_size).long()
    idx = torch.unsqueeze(idx, -1)
    return torch.cat((idx, x), dim=1)

def MemoryLoss(positive, negative, margin):
    """
    Calculate the average memory loss
    positive - positive cosine similarity
    negative - negative cosine similarity
    margin - hinge margin between positive and negative scores
    """
    assert(positive.size() == negative.size())
    dist_hinge = torch.clamp(negative - positive + margin, min=0.0)
    loss = torch.mean(dist_hinge)
    return loss

"""
Softmax Temperature -
+ Assume we have K elements at distance x. One element is at distance x+a.
+ p = e^{t(x+a)} / (K*e^{t*x} + e^{t(x+a)}) = e^{t*a} / (K + e^{t*a})
+ For 20% probability: e^{t*a} = 0.2*K  ->  t = ln(0.2*K) / a
"""
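# Worked example of the temperature rule above (illustrative numbers): with
# K = 256 candidates and a similarity gap of a = 0.1, the rule gives
# t = ln(0.2 * 256) / 0.1 = ln(51.2) / 0.1 ~= 39.4, close to the default
# inverse_temp of 40 used by Memory below.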
class Memory(nn.Module):
    def __init__(self, memory_size, key_dim, top_k = 256, inverse_temp = 40, age_noise=8.0, margin = 0.1):
        super(Memory, self).__init__()
        # Constants
        self.memory_size = memory_size
        self.key_dim = key_dim
        self.top_k = min(top_k, memory_size)
        self.softmax_temperature = max(1.0, math.log(0.2 * top_k) / inverse_temp)
        self.age_noise = age_noise
        self.margin = margin

        # Parameters
        self.build()
        self.query_proj = nn.Linear(key_dim, key_dim)

    def build(self):
        self.keys = F.normalize(random_uniform((self.memory_size, self.key_dim), -0.001, 0.001, cuda=True), dim=1)
        self.keys_var = ag.Variable(self.keys, requires_grad=False)
        self.values = torch.zeros(self.memory_size, 1).long().cuda()
        self.age = torch.zeros(self.memory_size, 1).cuda()

    def predict(self, x):
        batch_size, dims = x.size()
        query = F.normalize(self.query_proj(x), dim=1)

        # Find the k-nearest neighbors of the query
        scores = torch.matmul(query, torch.t(self.keys_var))
        cosine_similarity, topk_indices_var = torch.topk(scores, self.top_k, dim=1)

        # softmax of cosine similarities - embedding
        softmax_score = F.softmax(self.softmax_temperature * cosine_similarity)

        # retrieve memory values - prediction
        y_hat_indices = topk_indices_var.data[:, 0]
        y_hat = self.values[y_hat_indices]

        return y_hat, softmax_score

    def query(self, x, y, predict=False):
        """
        Compute the nearest neighbor of the input queries.

        Arguments:
            x: A normalized matrix of queries of size (batch_size x key_dim)
            y: A matrix of correct labels (batch_size x 1)
        Returns:
            y_hat, A (batch_size x 1) matrix
                - the nearest neighbor to the query in memory
            softmax_score, A (batch_size x 1) matrix
                - A normalized score measuring the similarity between query and nearest neighbor
            loss - average loss for the memory module
        """
        batch_size, dims = x.size()
        query = F.normalize(self.query_proj(x), dim=1)
        #query = F.normalize(torch.matmul(x, self.query_proj), dim=1)

        # Find the k-nearest neighbors of the query
        scores = torch.matmul(query, torch.t(self.keys_var))
        cosine_similarity, topk_indices_var = torch.topk(scores, self.top_k, dim=1)

        # softmax of cosine similarities - embedding
        softmax_score = F.softmax(self.softmax_temperature * cosine_similarity)

        # retrieve memory values - prediction
        topk_indices = topk_indices_var.detach().data
        y_hat_indices = topk_indices[:, 0]
        y_hat = self.values[y_hat_indices]

        loss = None
        if not predict:
            # Loss Function
            # topk_indices = (batch_size x topk)
            # topk_values = (batch_size x topk x value_size)

            # collect the memory values corresponding to the topk scores
            batch_size, topk_size = topk_indices.size()
            flat_topk = flatten(topk_indices)
            flat_topk_values = self.values[topk_indices]
            topk_values = flat_topk_values.resize_(batch_size, topk_size)

            correct_mask = torch.eq(topk_values, torch.unsqueeze(y.data, dim=1)).float()
            correct_mask_var = ag.Variable(correct_mask, requires_grad=False)
            pos_score, pos_idx = torch.topk(torch.mul(cosine_similarity, correct_mask_var), 1, dim=1)
            neg_score, neg_idx = torch.topk(torch.mul(cosine_similarity, 1-correct_mask_var), 1, dim=1)

            # zero-out correct scores if there are no correct values in topk values
            mask = 1.0 - torch.eq(torch.sum(correct_mask_var, dim=1), 0.0).float()
            pos_score = torch.mul(pos_score, torch.unsqueeze(mask, dim=1))

            #print(pos_score, neg_score)
            loss = MemoryLoss(pos_score, neg_score, self.margin)

        # Update memory
        self.update(query, y, y_hat, y_hat_indices)

        return y_hat, softmax_score, loss

    def update(self, query, y, y_hat, y_hat_indices):
        batch_size, dims = query.size()

        # 1) Untouched: Increment the age of all memory slots by 1
        self.age += 1

        # Divide batch by correctness
        result = torch.squeeze(torch.eq(y_hat, torch.unsqueeze(y.data, dim=1))).float()
        incorrect_examples = torch.squeeze(torch.nonzero(1-result))
        correct_examples = torch.squeeze(torch.nonzero(result))

        incorrect = len(incorrect_examples.size()) > 0
        correct = len(correct_examples.size()) > 0

        # 2) Correct: if V[n1] = v
        # Update Key k[n1] <- normalize(q + K[n1]), Reset Age A[n1] <- 0
        if correct:
            correct_indices = y_hat_indices[correct_examples]
            correct_keys = self.keys[correct_indices]
            correct_query = query.data[correct_examples]

            new_correct_keys = F.normalize(correct_keys + correct_query, dim=1)
            self.keys[correct_indices] = new_correct_keys
            self.age[correct_indices] = 0

        # 3) Incorrect: if V[n1] != v
        # Select the item with the oldest age, Add a random offset - n' = argmax_i(A[i]) + r_i
        # K[n'] <- q, V[n'] <- v, A[n'] <- 0
        if incorrect:
            incorrect_size = incorrect_examples.size()[0]
            incorrect_query = query.data[incorrect_examples]
            incorrect_values = y.data[incorrect_examples]

            age_with_noise = self.age + random_uniform((self.memory_size, 1), -self.age_noise, self.age_noise, cuda=True)
            topk_values, topk_indices = torch.topk(age_with_noise, incorrect_size, dim=0)
            oldest_indices = torch.squeeze(topk_indices)

            self.keys[oldest_indices] = incorrect_query
            self.values[oldest_indices] = incorrect_values
            self.age[oldest_indices] = 0
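# Minimal usage sketch (illustrative; assumes a CUDA device, since build()
# allocates the key/value/age buffers on the GPU):
#
#   mem = Memory(memory_size=8192, key_dim=128)
#   mem.build()                                       # reset the memory
#   q = ag.Variable(torch.randn(16, 128).cuda())      # query embeddings
#   y = ag.Variable(torch.arange(0, 16).long().cuda())
#   y_hat, softmax_score, loss = mem.query(q, y, predict=False)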
--------------------------------------------------------------------------------
/omniglot.py:
--------------------------------------------------------------------------------
import os
import random
import _pickle as pickle
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader
import torch.utils.data.sampler as sampler
from torchvision import transforms, utils

def random_index(seed, N):
    """ Args: seed - initial index, N - maximum index
        Return: A random index in [0, N) other than seed
    """
    offset = random.randint(1, N-1)
    idx = (seed + offset) % N
    assert(seed != idx)
    return idx

class OmniglotDataset(Dataset):
    """Omniglot dataset."""

    def __init__(self, filepath):
        """
        Args:
            filepath (string): path to data file
            Data format - list of characters, list of images, (row, col, ch) numpy array normalized between (0.0, 1.0)
            Omniglot dataset - Each language contains a set of characters; Each character is defined by 20 different images
        """
        with open(filepath, "rb") as f:
            processed_data = pickle.load(f)

        self.data = dict()
        for image, label in zip(processed_data['images'], processed_data['labels']):
            if label not in self.data:
                self.data[label] = list()
            img = np.expand_dims(image, axis=0).astype('float32')
            #img /= 255.0
            self.data[label].append(img)
        self.num_categories = len(self.data)
        self.category_size = len(self.data[processed_data['labels'][0]])

    def sample_episode_batch(self, episode_length, episode_width, batch_size, N):
        """Generates a random batch for training or validation.

        Structures each element of the batch as an 'episode'.
        Each episode contains episode_length examples and
        episode_width distinct labels.

        Args:
            episode_length: Number of examples in each episode.
            episode_width: Distinct number of labels in each episode.
            batch_size: Batch size (number of episodes).
            N: Number of episode batches to generate.

        Returns:
            A tuple (x, y) where x is a list of batches of examples
            with size episode_length and y is a list of batches of labels.
            xx = (batch_size, example), yy = (batch_size,)
        """
        for rnd in range(N):
            episodes_x = [list() for _ in range(episode_length)]
            episodes_y = [list() for _ in range(episode_length)]
            assert(self.num_categories >= episode_width)

            for b in range(batch_size):
                episode_labels = random.sample(self.data.keys(), episode_width)

                # Evenly divide episode_length among the episode_width labels
                remainder = episode_length % episode_width
                remainders = [0] * (episode_width - remainder) + [1] * remainder
                quotient = int((episode_length - remainder) / episode_width)
                episode_x = [random.sample(self.data[label], r + quotient) for label, r in zip(episode_labels, remainders)]
                assert(quotient+1 <= self.category_size)

                # Arrange episode so that each distinct label is seen before moving to the 2nd showing
                # Concatenate class episodes together into a single list
                episode = sum([[(example, label_id, example_id) for example_id, example in enumerate(examples_per_label)] for label_id, examples_per_label in enumerate(episode_x)], list())
                random.shuffle(episode)
                episode.sort(key=lambda elem: elem[2])
                assert len(episode) == episode_length

                # During training, the sets of labels for each episode in the batch are considered distinct
                # The memory is not emptied between training episodes
                for idx in range(episode_length):
                    episodes_x[idx].append(episode[idx][0])
                    episodes_y[idx].append(episode[idx][1] + b * episode_width)

            yield ([torch.from_numpy(np.array(xx)) for xx in episodes_x],
                   [torch.from_numpy(np.array(yy)) for yy in episodes_y])
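# Worked example of the split above (illustrative numbers): with
# episode_length=30 and episode_width=5, remainder is 0 and quotient is 6,
# so each of the 5 sampled labels contributes exactly 6 examples; with
# episode_length=7 and episode_width=3, remainders is [0, 0, 1] and the
# labels contribute 2, 2, and 3 examples respectively.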
class SiameseDataset(Dataset):
    """Siamese Omniglot dataset."""

    def __init__(self, filepath):
        """
        Args:
            filepath (string): path to data file
            Data format - list of characters, list of images, (row, col, ch) numpy array normalized between (0.0, 1.0)
            Omniglot dataset - Each language contains a set of characters; Each character is defined by 20 different images
        """
        with open(filepath, "rb") as f:
            processed_data = pickle.load(f)

        self.data = dict()
        for image, label in zip(processed_data['images'], processed_data['labels']):
            if label not in self.data:
                self.data[label] = list()
            img = np.expand_dims(image, axis=0).astype('float32')
            img /= 255.0
            self.data[label].append(img)
        self.num_categories = len(self.data)
        self.category_size = len(self.data[processed_data['labels'][0]])

    def __len__(self):
        return self.num_categories

    def __getitem__(self, idx):
        raise NotImplementedError

class TrainSiameseDataset(SiameseDataset):
    def __init__(self, filepath):
        super(TrainSiameseDataset, self).__init__(filepath)

    def __getitem__(self, idx):
        index, same = idx

        if same:
            imageset = self.data[index]
            selected = random.sample(imageset, 2)
            images = [torch.from_numpy(image) for image in selected]
        else:
            left_imageset = self.data[index]
            right_imageset = self.data[random_index(index, self.num_categories)]
            left_img = random.sample(left_imageset, 1)
            right_img = random.sample(right_imageset, 1)
            images = [torch.from_numpy(image) for image in (left_img + right_img)]

        label = int(same)
        sample = [images, label]
        return sample

class TestSiameseDataset(SiameseDataset):
    def __init__(self, filepath):
        super(TestSiameseDataset, self).__init__(filepath)

    def __getitem__(self, idx):
        """ Args: [test_image, same] = idx
            test_image = (test_category, test_category_image)
            same (bool) = whether the support image comes from the same category
        """
        test_id, same = idx
        category, index = test_id
        test_img = self.data[category][index]

        support_idx = random_index(index, self.category_size)
        if same:
            support_img = self.data[category][support_idx]
        else:
            support_category = random_index(category, self.num_categories)
            support_img = self.data[support_category][support_idx]

        selected = (test_img, support_img)
        images = [torch.from_numpy(image) for image in selected]
        label = int(same)
        sample = [images, label]
        return sample

class SiameseSampler(sampler.Sampler):
    """Samples elements for Siamese Network training."""

    def __init__(self, data_source, rnd, batch_size, sampler_type):
        """ Args: data_source - the Siamese dataset to sample from
                rnd - number of iterations
                batch_size - size of batch
                sampler_type - (test = 1) OR (train = 0)
                split (int) - index to switch from same (label=1) to different (label=0)
        """
        self.data_source = data_source
        self.rnd = rnd
        self.batch_size = batch_size
        self.sampler_type = sampler_type
        self.split = 1 if sampler_type else int(batch_size/2)

    def __len__(self):
        return self.batch_size * self.rnd

    def __iter__(self):
        if self.sampler_type:
            pos = self.generate_test()

        batch_index = 0
        for idx in range(self.batch_size * self.rnd):
            if not self.sampler_type:
                pos = random.randint(0, len(self.data_source)-1)

            if batch_index < self.split:
                yield (pos, True)
            else:
                yield (pos, False)

            batch_index += 1
            if batch_index == self.batch_size:
                batch_index = 0
                if self.sampler_type:
                    pos = self.generate_test()

    def generate_test(self):
        category = random.randint(0, self.data_source.num_categories-1)
        index = random.randint(0, self.data_source.category_size-1)
        return (category, index)
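# Sketch (illustrative): the sampler emits (index, same) pairs that the
# dataset __getitem__ methods consume; for training, the first half of each
# batch is same-class (label 1) and the second half different-class (label 0).
#
#   ds = TrainSiameseDataset('train_omni.pkl')
#   s = SiameseSampler(ds, rnd=1, batch_size=4, sampler_type=False)
#   list(s)  # e.g. [(421, True), (87, True), (305, False), (960, False)]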
--------------------------------------------------------------------------------
/siamese.py:
--------------------------------------------------------------------------------
import os

import torch
import torchvision
import torchvision.transforms as transforms

from torch.autograd import Variable
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
import omniglot

class Net(nn.Module):
    def __init__(self, input_shape):
        super(Net, self).__init__()
        ch, row, col = input_shape
        kernel = 3
        pad = int((kernel-1)/2.0)

        self.predict = nn.Linear(128, 2)

        self.convolution = nn.Sequential(
            nn.Conv2d(ch, 64, kernel, padding=pad),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel, padding=pad),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, kernel, padding=pad),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel, padding=pad),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2,2)
        )

        self.fc = nn.Sequential(
            nn.Linear(row // 4 * col // 4 * 128, 128),
            nn.Sigmoid()
        )

    def embed(self, x):
        x = self.convolution(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def forward(self, x, y):
        embed_x = self.embed(x)
        embed_y = self.embed(y)
        l1_distance = torch.abs(embed_x - embed_y)
        result = self.predict(l1_distance)
        return result

epochs = 1000
rnd = 1000
M = 32   # training batch size
N = 20   # N-way one-shot evaluation
K = 250  # number of one-shot evaluation trials
DATA_FILE_FORMAT = os.path.join(os.getcwd(), '%s_omni.pkl')

train_filepath = DATA_FILE_FORMAT % 'train'
train_set = omniglot.TrainSiameseDataset(train_filepath)
train_sampler = omniglot.SiameseSampler(train_set, rnd, M, False)
trainloader = torch.utils.data.DataLoader(train_set, batch_size=M, sampler=train_sampler, num_workers=4)

test_filepath = DATA_FILE_FORMAT % 'test'
test_set = omniglot.TestSiameseDataset(test_filepath)
test_sampler = omniglot.SiameseSampler(test_set, K, N, True)
testloader = torch.utils.data.DataLoader(test_set, batch_size=N, shuffle=False, sampler=test_sampler, num_workers=4)

#torch.cuda.set_device(1)
net = Net(input_shape=(1,28,28))
net.cuda()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=1e-3)

for epoch in range(epochs):
    running_loss = 0
    for i, data in enumerate(trainloader, 0):
        optimizer.zero_grad()
        inputs, labels = data
        left, right = inputs
        left, right, labels = Variable(left.cuda()), Variable(right.cuda()), Variable(labels.cuda())
        y_hat = net(left, right)
        loss = criterion(y_hat, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.data[0]
        if i == len(trainloader)-1:
            print("[{0:d}, {1:5d}] loss: {2:.3f}".format((epoch+1), (i+1), (running_loss / len(trainloader))))
            running_loss = 0.0

print('Finished Training')

total = 0
correct = 0
print("Evaluating model on {0} unique {1}-way one-shot learning tasks ...".format(K, N))
for i, data in enumerate(testloader, 0):
    inputs, labels = data
    x, y = inputs
    x, y = Variable(x.cuda()), Variable(y.cuda())
    labels = labels.cuda()
    y_hat = net(x, y)
    _, predicted = torch.max(y_hat.data, 1)
    # a trial counts as correct only if all N pair predictions are right:
    # the single same-class pair and the N-1 different-class pairs
    if torch.eq(predicted, labels).sum() == N:
        correct += 1
    total += 1

print('Accuracy {0}% for {1}-way one-shot learning: {2}'.format(100 * correct / total, N, correct))
--------------------------------------------------------------------------------