├── .github
│   └── FUNDING.yml
├── loss.png
├── make_dataset.py
├── model.py
├── mydataset.py
├── readme.md
├── requirements.txt
└── train.py

/.github/FUNDING.yml:
--------------------------------------------------------------------------------
custom: https://github.com/fangpin#sponsorships
--------------------------------------------------------------------------------
/loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fangpin/siamese-pytorch/750cfef127c905b00b984d939ad1a028b9e4ad3f/loss.png
--------------------------------------------------------------------------------
/make_dataset.py:
--------------------------------------------------------------------------------
import shutil
import sys
import os


# Flatten the Omniglot alphabet/character hierarchy: create one output
# directory per character and copy every drawer's sample into it.
data_path_read = sys.argv[1]
data_path_write = sys.argv[2]

for alphabet in os.listdir(data_path_read):
    alphabet_path = os.path.join(data_path_read, alphabet)
    # directory-name prefix derived from the output path minus its last two characters
    path_write1 = data_path_write[:-2] + '-' + alphabet
    for character in os.listdir(alphabet_path):
        character_path = os.path.join(alphabet_path, character)
        path_write2 = path_write1 + '-' + character
        os.makedirs(os.path.join(data_path_write, path_write2))
        for drawer in os.listdir(character_path):
            drawer_path = os.path.join(character_path, drawer)
            shutil.copyfile(drawer_path, os.path.join(data_path_write, path_write2, drawer))
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class Siamese(nn.Module):

    def __init__(self):
        super(Siamese, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, 10),    # 64@96*96
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),         # 64@48*48
            nn.Conv2d(64, 128, 7),
            nn.ReLU(),               # 128@42*42
            nn.MaxPool2d(2),         # 128@21*21
            nn.Conv2d(128, 128, 4),
            nn.ReLU(),               # 128@18*18
            nn.MaxPool2d(2),         # 128@9*9
            nn.Conv2d(128, 256, 4),
            nn.ReLU(),               # 256@6*6
        )
        self.linear = nn.Sequential(nn.Linear(9216, 4096), nn.Sigmoid())
        self.out = nn.Linear(4096, 1)

    def forward_one(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x

    def forward(self, x1, x2):
        out1 = self.forward_one(x1)
        out2 = self.forward_one(x2)
        dis = torch.abs(out1 - out2)
        out = self.out(dis)
        # return raw logits; BCEWithLogitsLoss in train.py applies the sigmoid
        return out


# for test
if __name__ == '__main__':
    net = Siamese()
    print(net)
    print(list(net.parameters()))
--------------------------------------------------------------------------------
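A quick way to check the shape comments in model.py is to run a random batch through the network. This is a minimal sketch, not repo code, and it assumes the 105x105 Omniglot input size used by the paper; with that input, the conv stack flattens to 256 * 6 * 6 = 9216, matching nn.Linear(9216, 4096):

```python
import torch

from model import Siamese

# 105 -> 96 -> 48 -> 42 -> 21 -> 18 -> 9 -> 6 through the conv/pool stack,
# so the flattened feature vector has 256 * 6 * 6 = 9216 elements.
net = Siamese()
x1 = torch.rand(4, 1, 105, 105)  # batch of 4 single-channel images
x2 = torch.rand(4, 1, 105, 105)
logits = net(x1, x2)
print(logits.shape)  # torch.Size([4, 1]) -- one similarity logit per pair
```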
/mydataset.py:
--------------------------------------------------------------------------------
import os
import random

import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image


class OmniglotTrain(Dataset):

    def __init__(self, dataPath, transform=None):
        super(OmniglotTrain, self).__init__()
        np.random.seed(0)
        self.transform = transform
        self.datas, self.num_classes = self.loadToMem(dataPath)

    def loadToMem(self, dataPath):
        print("begin loading training dataset to memory")
        datas = {}
        degrees = [0, 90, 180, 270]
        idx = 0
        # each (character, rotation) combination is treated as its own class
        for degree in degrees:
            for alphaPath in os.listdir(dataPath):
                for charPath in os.listdir(os.path.join(dataPath, alphaPath)):
                    datas[idx] = []
                    for samplePath in os.listdir(os.path.join(dataPath, alphaPath, charPath)):
                        filePath = os.path.join(dataPath, alphaPath, charPath, samplePath)
                        datas[idx].append(Image.open(filePath).rotate(degree).convert('L'))
                    idx += 1
        print("finish loading training dataset to memory")
        return datas, idx

    def __len__(self):
        # pairs are sampled at random in __getitem__, so the length is just a
        # large number bounding one epoch
        return 21000000

    def __getitem__(self, index):
        # odd index: positive pair from the same class
        if index % 2 == 1:
            label = 1.0
            idx1 = random.randint(0, self.num_classes - 1)
            image1 = random.choice(self.datas[idx1])
            image2 = random.choice(self.datas[idx1])
        # even index: negative pair from two different classes
        else:
            label = 0.0
            idx1 = random.randint(0, self.num_classes - 1)
            idx2 = random.randint(0, self.num_classes - 1)
            while idx1 == idx2:
                idx2 = random.randint(0, self.num_classes - 1)
            image1 = random.choice(self.datas[idx1])
            image2 = random.choice(self.datas[idx2])

        if self.transform:
            image1 = self.transform(image1)
            image2 = self.transform(image2)
        return image1, image2, torch.from_numpy(np.array([label], dtype=np.float32))


class OmniglotTest(Dataset):

    def __init__(self, dataPath, transform=None, times=200, way=20):
        np.random.seed(1)
        super(OmniglotTest, self).__init__()
        self.transform = transform
        self.times = times
        self.way = way
        self.img1 = None
        self.c1 = None
        self.datas, self.num_classes = self.loadToMem(dataPath)

    def loadToMem(self, dataPath):
        print("begin loading test dataset to memory")
        datas = {}
        idx = 0
        for alphaPath in os.listdir(dataPath):
            for charPath in os.listdir(os.path.join(dataPath, alphaPath)):
                datas[idx] = []
                for samplePath in os.listdir(os.path.join(dataPath, alphaPath, charPath)):
                    filePath = os.path.join(dataPath, alphaPath, charPath, samplePath)
                    datas[idx].append(Image.open(filePath).convert('L'))
                idx += 1
        print("finish loading test dataset to memory")
        return datas, idx

    def __len__(self):
        return self.times * self.way

    def __getitem__(self, index):
        idx = index % self.way
        # the first pair of each way-sized block shares a class (the target);
        # the remaining way-1 pairs use distractor classes
        if idx == 0:
            self.c1 = random.randint(0, self.num_classes - 1)
            self.img1 = random.choice(self.datas[self.c1])
            img2 = random.choice(self.datas[self.c1])
        else:
            c2 = random.randint(0, self.num_classes - 1)
            while self.c1 == c2:
                c2 = random.randint(0, self.num_classes - 1)
            img2 = random.choice(self.datas[c2])

        img1 = self.img1
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        return img1, img2


# test
if __name__ == '__main__':
    omniglotTrain = OmniglotTrain('./images_background')
    print(omniglotTrain)
--------------------------------------------------------------------------------
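A minimal usage sketch for the training dataset (the dataset path is an assumption taken from the readme, not repo code). Because the label depends only on the parity of the sampled index and train.py uses shuffle=False, consecutive indices yield alternating negative/positive pairs:

```python
from torch.utils.data import DataLoader
from torchvision import transforms

from mydataset import OmniglotTrain

# "./omniglot/python/images_background" is an assumed local path (see readme)
train_set = OmniglotTrain("./omniglot/python/images_background",
                          transform=transforms.ToTensor())
train_loader = DataLoader(train_set, batch_size=8, shuffle=False)

img1, img2, label = next(iter(train_loader))
print(img1.shape)      # torch.Size([8, 1, 105, 105])
print(label.view(-1))  # alternating 0/1: index parity decides the pair type
```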
/readme.md:
--------------------------------------------------------------------------------
# Siamese Networks for One-Shot Learning

A reimplementation of the [original paper](https://www.cs.cmu.edu/~rsalakhu/papers/oneshot1.pdf) in PyTorch, with
training and testing on the [Omniglot dataset](https://github.com/brendenlake/omniglot).

## requirements
- pytorch
- torchvision
- python 3.5+
- python-gflags

See requirements.txt

## run steps
- download the dataset
```
git clone https://github.com/brendenlake/omniglot.git
cd omniglot/python
unzip images_evaluation.zip
unzip images_background.zip
cd ../..
# set up a directory for saving models
mkdir models
```
- train and test by running
```shell
python3 train.py --train_path omniglot/python/images_background \
                 --test_path omniglot/python/images_evaluation \
                 --gpu_ids 0 \
                 --model_path models
```

## experiment results
The loss value is sampled after every 200 batches.
![img](https://github.com/fangpin/siamese-network/blob/master/loss.png)

My final precision is 89.5%, a little lower than the paper's result (92%).

The small gap is probably caused by a few differences between my implementation and the paper's:

- optimizer and learning rate

  Instead of SGD with momentum, I simply use Adam.

- parameter initialization and per-layer settings

  Instead of the paper's layer-specific initialization methods, learning rates, and regularization rates, I use PyTorch's defaults and keep them the same for every layer.
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
cycler==0.10.0
kiwisolver==1.0.1
matplotlib==3.0.2
numpy==1.16.1
Pillow==5.4.1
pyparsing==2.3.1
python-dateutil==2.8.0
python-gflags==3.1.2
six==1.12.0
torch==1.0.1.post2
torchvision==0.2.1
--------------------------------------------------------------------------------
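The precision figure above comes from 20-way one-shot trials: one query image is scored against 20 candidates, and the trial counts as correct when the same-class candidate (index 0 in each OmniglotTest block) receives the highest logit. A standalone sketch of that decision rule, using random tensors as stand-ins for real images:

```python
import torch

from model import Siamese

# One 20-way trial: the query repeated 20 times vs. 20 candidates, where
# the candidate at index 0 is from the same class as the query.
way = 20
net = Siamese()
net.eval()

query = torch.rand(way, 1, 105, 105)       # the same test image, repeated
candidates = torch.rand(way, 1, 105, 105)  # candidates[0] is the true match

with torch.no_grad():
    logits = net(query, candidates)        # shape [way, 1]
pred = torch.argmax(logits).item()
print("correct" if pred == 0 else "wrong")
```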
/train.py:
--------------------------------------------------------------------------------
import os
import pickle
import sys
import time
from collections import deque

import gflags
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

from model import Siamese
from mydataset import OmniglotTrain, OmniglotTest


if __name__ == '__main__':

    Flags = gflags.FLAGS
    gflags.DEFINE_bool("cuda", True, "use cuda")
    gflags.DEFINE_string("train_path", "/home/data/pin/data/omniglot/images_background", "path of training folder")
    gflags.DEFINE_string("test_path", "/home/data/pin/data/omniglot/images_evaluation", "path of testing folder")
    gflags.DEFINE_integer("way", 20, "how many ways in N-way one-shot learning")
    gflags.DEFINE_integer("times", 400, "number of samples to test accuracy")
    gflags.DEFINE_integer("workers", 4, "number of dataLoader workers")
    gflags.DEFINE_integer("batch_size", 128, "batch size")
    gflags.DEFINE_float("lr", 0.00006, "learning rate")
    gflags.DEFINE_integer("show_every", 10, "show result after every show_every iterations")
    gflags.DEFINE_integer("save_every", 100, "save model after every save_every iterations")
    gflags.DEFINE_integer("test_every", 100, "test model after every test_every iterations")
    gflags.DEFINE_integer("max_iter", 50000, "number of iterations before stopping")
    gflags.DEFINE_string("model_path", "/home/data/pin/model/siamese", "path to store models")
    gflags.DEFINE_string("gpu_ids", "0,1,2,3", "gpu ids used to train")

    Flags(sys.argv)

    data_transforms = transforms.Compose([
        transforms.RandomAffine(15),
        transforms.ToTensor()
    ])

    os.environ["CUDA_VISIBLE_DEVICES"] = Flags.gpu_ids
    print("use gpu:", Flags.gpu_ids, "to train.")

    trainSet = OmniglotTrain(Flags.train_path, transform=data_transforms)
    testSet = OmniglotTest(Flags.test_path, transform=transforms.ToTensor(), times=Flags.times, way=Flags.way)
    # each test batch is one N-way one-shot trial
    testLoader = DataLoader(testSet, batch_size=Flags.way, shuffle=False, num_workers=Flags.workers)
    # the training set samples pairs at random, so shuffling is unnecessary
    trainLoader = DataLoader(trainSet, batch_size=Flags.batch_size, shuffle=False, num_workers=Flags.workers)

    loss_fn = torch.nn.BCEWithLogitsLoss()  # mean reduction by default
    net = Siamese()

    # multi gpu
    if len(Flags.gpu_ids.split(",")) > 1:
        net = torch.nn.DataParallel(net)

    if Flags.cuda:
        net.cuda()

    net.train()

    optimizer = torch.optim.Adam(net.parameters(), lr=Flags.lr)
    optimizer.zero_grad()

    train_loss = []
    loss_val = 0
    time_start = time.time()
    queue = deque(maxlen=20)  # keeps the last 20 test precisions

    for batch_id, (img1, img2, label) in enumerate(trainLoader, 1):
        if batch_id > Flags.max_iter:
            break
        if Flags.cuda:
            img1, img2, label = img1.cuda(), img2.cuda(), label.cuda()
        optimizer.zero_grad()
        output = net(img1, img2)
        loss = loss_fn(output, label)
        loss_val += loss.item()
        loss.backward()
        optimizer.step()
        if batch_id % Flags.show_every == 0:
            print('[%d]\tloss:\t%.5f\ttime lapsed:\t%.2f s' % (batch_id, loss_val / Flags.show_every, time.time() - time_start))
            loss_val = 0
            time_start = time.time()
        if batch_id % Flags.save_every == 0:
            torch.save(net.state_dict(), Flags.model_path + '/model-inter-' + str(batch_id + 1) + ".pt")
        if batch_id % Flags.test_every == 0:
            right, error = 0, 0
            for _, (test1, test2) in enumerate(testLoader, 1):
                if Flags.cuda:
                    test1, test2 = test1.cuda(), test2.cuda()
                with torch.no_grad():
                    output = net(test1, test2).cpu().numpy()
                # index 0 holds the same-class pair, so the trial is correct
                # when it receives the highest similarity logit
                pred = np.argmax(output)
                if pred == 0:
                    right += 1
                else:
                    error += 1
            print('*' * 70)
            print('[%d]\tTest set\tcorrect:\t%d\terror:\t%d\tprecision:\t%f' % (batch_id, right, error, right * 1.0 / (right + error)))
            print('*' * 70)
            queue.append(right * 1.0 / (right + error))
            train_loss.append(loss_val)

    with open('train_loss', 'wb') as f:
        pickle.dump(train_loss, f)

    # average the last (up to) 20 test precisions
    acc = 0.0
    for d in queue:
        acc += d
    print("#" * 70)
    print("final accuracy: ", acc / len(queue))
--------------------------------------------------------------------------------
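Since train.py only saves state_dicts, a checkpoint can be reloaded for inference along these lines. This is a sketch, not repo code: the file name is a hypothetical example following the save pattern above, and the key renaming is only needed when the model was wrapped in DataParallel, which prefixes parameter names with "module.":

```python
import torch

from model import Siamese

net = Siamese()
# "models/model-inter-101.pt" is a hypothetical checkpoint produced by the
# save_every logic in train.py (saved at batch_id 100 as batch_id + 1)
state = torch.load("models/model-inter-101.pt", map_location="cpu")
# strip the "module." prefix that DataParallel adds to parameter names
state = {k.replace("module.", "", 1): v for k, v in state.items()}
net.load_state_dict(state)
net.eval()
```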