├── README.md ├── images └── visdom.png ├── train.py ├── triplet_image_loader.py ├── triplet_mnist_loader.py └── tripletnet.py /README.md: -------------------------------------------------------------------------------- 1 | # A PyTorch Implementation for Triplet Networks 2 | 3 | This repository contains a [PyTorch](http://pytorch.org/) implementation for triplet networks. 4 | 5 | The code provides two different ways to load triplets for the network. First, it contains a simple [MNIST Loader](https://github.com/andreasveit/triplet-network-pytorch/blob/master/triplet_mnist_loader.py) that generates triplets from the MNIST class labels. Second, this repository provides a [Triplet Loader](https://github.com/andreasveit/triplet-network-pytorch/blob/master/triplet_image_loader.py) that loads images from folders, provided a [list of triplets](https://github.com/andreasveit/triplet-network-pytorch/blob/master/triplet_image_loader.py#L22). 6 | 7 | ### Example usage: 8 | 9 | ```sh 10 | $ python train.py 11 | ``` 12 | ### Tracking experiments with Visdom 13 | 14 | This repository allows you to track experiments with [visdom](https://github.com/facebookresearch/visdom). You can use the [VisdomLinePlotter](https://github.com/andreasveit/triplet-network-pytorch/blob/master/train.py#L216) to plot training progress. 15 | 16 | 17 | 18 | If this implementation is useful to you and your project, please also consider citing or acknowledging this code repository. 
19 | -------------------------------------------------------------------------------- /images/visdom.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andreasveit/triplet-network-pytorch/14e64764ec33d0a6e54508cee24ef7e7fe3a8366/images/visdom.png -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import shutil 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | from torchvision import datasets, transforms 10 | from torch.autograd import Variable 11 | import torch.backends.cudnn as cudnn 12 | from triplet_mnist_loader import MNIST_t 13 | from triplet_image_loader import TripletImageLoader 14 | from tripletnet import Tripletnet 15 | from visdom import Visdom 16 | import numpy as np 17 | 18 | # Training settings 19 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 20 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 21 | help='input batch size for training (default: 64)') 22 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 23 | help='input batch size for testing (default: 1000)') 24 | parser.add_argument('--epochs', type=int, default=10, metavar='N', 25 | help='number of epochs to train (default: 10)') 26 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR', 27 | help='learning rate (default: 0.01)') 28 | parser.add_argument('--momentum', type=float, default=0.5, metavar='M', 29 | help='SGD momentum (default: 0.5)') 30 | parser.add_argument('--no-cuda', action='store_true', default=False, 31 | help='enables CUDA training') 32 | parser.add_argument('--seed', type=int, default=1, metavar='S', 33 | help='random seed (default: 1)') 34 | 
parser.add_argument('--log-interval', type=int, default=20, metavar='N', 35 | help='how many batches to wait before logging training status') 36 | parser.add_argument('--margin', type=float, default=0.2, metavar='M', 37 | help='margin for triplet loss (default: 0.2)') 38 | parser.add_argument('--resume', default='', type=str, 39 | help='path to latest checkpoint (default: none)') 40 | parser.add_argument('--name', default='TripletNet', type=str, 41 | help='name of experiment') 42 | 43 | best_acc = 0 44 | 45 | 46 | def main(): 47 | global args, best_acc 48 | args = parser.parse_args() 49 | args.cuda = not args.no_cuda and torch.cuda.is_available() 50 | torch.manual_seed(args.seed) 51 | if args.cuda: 52 | torch.cuda.manual_seed(args.seed) 53 | global plotter 54 | plotter = VisdomLinePlotter(env_name=args.name) 55 | 56 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} 57 | train_loader = torch.utils.data.DataLoader( 58 | MNIST_t('../data', train=True, download=True, 59 | transform=transforms.Compose([ 60 | transforms.ToTensor(), 61 | transforms.Normalize((0.1307,), (0.3081,)) 62 | ])), 63 | batch_size=args.batch_size, shuffle=True, **kwargs) 64 | test_loader = torch.utils.data.DataLoader( 65 | MNIST_t('../data', train=False, transform=transforms.Compose([ 66 | transforms.ToTensor(), 67 | transforms.Normalize((0.1307,), (0.3081,)) 68 | ])), 69 | batch_size=args.batch_size, shuffle=True, **kwargs) 70 | 71 | class Net(nn.Module): 72 | def __init__(self): 73 | super(Net, self).__init__() 74 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 75 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 76 | self.conv2_drop = nn.Dropout2d() 77 | self.fc1 = nn.Linear(320, 50) 78 | self.fc2 = nn.Linear(50, 10) 79 | 80 | def forward(self, x): 81 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 82 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 83 | x = x.view(-1, 320) 84 | x = F.relu(self.fc1(x)) 85 | x = F.dropout(x, training=self.training) 86 | return 
self.fc2(x) 87 | 88 | model = Net() 89 | tnet = Tripletnet(model) 90 | if args.cuda: 91 | tnet.cuda() 92 | 93 | # optionally resume from a checkpoint 94 | if args.resume: 95 | if os.path.isfile(args.resume): 96 | print("=> loading checkpoint '{}'".format(args.resume)) 97 | checkpoint = torch.load(args.resume) 98 | args.start_epoch = checkpoint['epoch'] 99 | best_prec1 = checkpoint['best_prec1'] 100 | tnet.load_state_dict(checkpoint['state_dict']) 101 | print("=> loaded checkpoint '{}' (epoch {})" 102 | .format(args.resume, checkpoint['epoch'])) 103 | else: 104 | print("=> no checkpoint found at '{}'".format(args.resume)) 105 | 106 | cudnn.benchmark = True 107 | 108 | criterion = torch.nn.MarginRankingLoss(margin = args.margin) 109 | optimizer = optim.SGD(tnet.parameters(), lr=args.lr, momentum=args.momentum) 110 | 111 | n_parameters = sum([p.data.nelement() for p in tnet.parameters()]) 112 | print(' + Number of params: {}'.format(n_parameters)) 113 | 114 | for epoch in range(1, args.epochs + 1): 115 | # train for one epoch 116 | train(train_loader, tnet, criterion, optimizer, epoch) 117 | # evaluate on validation set 118 | acc = test(test_loader, tnet, criterion, epoch) 119 | 120 | # remember best acc and save checkpoint 121 | is_best = acc > best_acc 122 | best_acc = max(acc, best_acc) 123 | save_checkpoint({ 124 | 'epoch': epoch + 1, 125 | 'state_dict': tnet.state_dict(), 126 | 'best_prec1': best_acc, 127 | }, is_best) 128 | 129 | def train(train_loader, tnet, criterion, optimizer, epoch): 130 | losses = AverageMeter() 131 | accs = AverageMeter() 132 | emb_norms = AverageMeter() 133 | 134 | # switch to train mode 135 | tnet.train() 136 | for batch_idx, (data1, data2, data3) in enumerate(train_loader): 137 | if args.cuda: 138 | data1, data2, data3 = data1.cuda(), data2.cuda(), data3.cuda() 139 | data1, data2, data3 = Variable(data1), Variable(data2), Variable(data3) 140 | 141 | # compute output 142 | dista, distb, embedded_x, embedded_y, embedded_z = tnet(data1, 
data2, data3) 143 | # 1 means, dista should be larger than distb 144 | target = torch.FloatTensor(dista.size()).fill_(1) 145 | if args.cuda: 146 | target = target.cuda() 147 | target = Variable(target) 148 | 149 | loss_triplet = criterion(dista, distb, target) 150 | loss_embedd = embedded_x.norm(2) + embedded_y.norm(2) + embedded_z.norm(2) 151 | loss = loss_triplet + 0.001 * loss_embedd 152 | 153 | # measure accuracy and record loss 154 | acc = accuracy(dista, distb) 155 | losses.update(loss_triplet.data[0], data1.size(0)) 156 | accs.update(acc, data1.size(0)) 157 | emb_norms.update(loss_embedd.data[0]/3, data1.size(0)) 158 | 159 | # compute gradient and do optimizer step 160 | optimizer.zero_grad() 161 | loss.backward() 162 | optimizer.step() 163 | 164 | if batch_idx % args.log_interval == 0: 165 | print('Train Epoch: {} [{}/{}]\t' 166 | 'Loss: {:.4f} ({:.4f}) \t' 167 | 'Acc: {:.2f}% ({:.2f}%) \t' 168 | 'Emb_Norm: {:.2f} ({:.2f})'.format( 169 | epoch, batch_idx * len(data1), len(train_loader.dataset), 170 | losses.val, losses.avg, 171 | 100. * accs.val, 100. 
* accs.avg, emb_norms.val, emb_norms.avg)) 172 | # log avg values to somewhere 173 | plotter.plot('acc', 'train', epoch, accs.avg) 174 | plotter.plot('loss', 'train', epoch, losses.avg) 175 | plotter.plot('emb_norms', 'train', epoch, emb_norms.avg) 176 | 177 | def test(test_loader, tnet, criterion, epoch): 178 | losses = AverageMeter() 179 | accs = AverageMeter() 180 | 181 | # switch to evaluation mode 182 | tnet.eval() 183 | for batch_idx, (data1, data2, data3) in enumerate(test_loader): 184 | if args.cuda: 185 | data1, data2, data3 = data1.cuda(), data2.cuda(), data3.cuda() 186 | data1, data2, data3 = Variable(data1), Variable(data2), Variable(data3) 187 | 188 | # compute output 189 | dista, distb, _, _, _ = tnet(data1, data2, data3) 190 | target = torch.FloatTensor(dista.size()).fill_(1) 191 | if args.cuda: 192 | target = target.cuda() 193 | target = Variable(target) 194 | test_loss = criterion(dista, distb, target).data[0] 195 | 196 | # measure accuracy and record loss 197 | acc = accuracy(dista, distb) 198 | accs.update(acc, data1.size(0)) 199 | losses.update(test_loss, data1.size(0)) 200 | 201 | print('\nTest set: Average loss: {:.4f}, Accuracy: {:.2f}%\n'.format( 202 | losses.avg, 100. 
* accs.avg)) 203 | plotter.plot('acc', 'test', epoch, accs.avg) 204 | plotter.plot('loss', 'test', epoch, losses.avg) 205 | return accs.avg 206 | 207 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 208 | """Saves checkpoint to disk""" 209 | directory = "runs/%s/"%(args.name) 210 | if not os.path.exists(directory): 211 | os.makedirs(directory) 212 | filename = directory + filename 213 | torch.save(state, filename) 214 | if is_best: 215 | shutil.copyfile(filename, 'runs/%s/'%(args.name) + 'model_best.pth.tar') 216 | 217 | class VisdomLinePlotter(object): 218 | """Plots to Visdom""" 219 | def __init__(self, env_name='main'): 220 | self.viz = Visdom() 221 | self.env = env_name 222 | self.plots = {} 223 | def plot(self, var_name, split_name, x, y): 224 | if var_name not in self.plots: 225 | self.plots[var_name] = self.viz.line(X=np.array([x,x]), Y=np.array([y,y]), env=self.env, opts=dict( 226 | legend=[split_name], 227 | title=var_name, 228 | xlabel='Epochs', 229 | ylabel=var_name 230 | )) 231 | else: 232 | self.viz.updateTrace(X=np.array([x]), Y=np.array([y]), env=self.env, win=self.plots[var_name], name=split_name) 233 | 234 | class AverageMeter(object): 235 | """Computes and stores the average and current value""" 236 | def __init__(self): 237 | self.reset() 238 | 239 | def reset(self): 240 | self.val = 0 241 | self.avg = 0 242 | self.sum = 0 243 | self.count = 0 244 | 245 | def update(self, val, n=1): 246 | self.val = val 247 | self.sum += val * n 248 | self.count += n 249 | self.avg = self.sum / self.count 250 | 251 | def accuracy(dista, distb): 252 | margin = 0 253 | pred = (dista - distb - margin).cpu().data 254 | return (pred > 0).sum()*1.0/dista.size()[0] 255 | 256 | if __name__ == '__main__': 257 | main() 258 | -------------------------------------------------------------------------------- /triplet_image_loader.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os 3 | 
import os.path 4 | 5 | import torch.utils.data 6 | import torchvision.transforms as transforms 7 | 8 | def default_image_loader(path): 9 | return Image.open(path).convert('RGB') 10 | 11 | class TripletImageLoader(torch.utils.data.Dataset): 12 | def __init__(self, base_path, filenames_filename, triplets_file_name, transform=None, 13 | loader=default_image_loader): 14 | """ filenames_filename: A text file with each line containing the path to an image e.g., 15 | images/class1/sample.jpg 16 | triplets_file_name: A text file with each line containing three integers, 17 | where integer i refers to the i-th image in the filenames file. 18 | For a line of intergers 'a b c', a triplet is defined such that image a is more 19 | similar to image c than it is to image b, e.g., 20 | 0 2017 42 """ 21 | self.base_path = base_path 22 | self.filenamelist = [] 23 | for line in open(filenames_filename): 24 | self.filenamelist.append(line.rstrip('\n')) 25 | triplets = [] 26 | for line in open(triplets_file_name): 27 | triplets.append((line.split()[0], line.split()[1], line.split()[2])) # anchor, far, close 28 | self.triplets = triplets 29 | self.transform = transform 30 | self.loader = loader 31 | 32 | def __getitem__(self, index): 33 | path1, path2, path3 = self.triplets[index] 34 | img1 = self.loader(os.path.join(self.base_path,self.filenamelist[int(path1)])) 35 | img2 = self.loader(os.path.join(self.base_path,self.filenamelist[int(path2)])) 36 | img3 = self.loader(os.path.join(self.base_path,self.filenamelist[int(path3)])) 37 | if self.transform is not None: 38 | img1 = self.transform(img1) 39 | img2 = self.transform(img2) 40 | img3 = self.transform(img3) 41 | 42 | return img1, img2, img3 43 | 44 | def __len__(self): 45 | return len(self.triplets) 46 | -------------------------------------------------------------------------------- /triplet_mnist_loader.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 
import torch.utils.data as data 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import errno 7 | import torch 8 | import json 9 | import codecs 10 | import numpy as np 11 | import csv 12 | 13 | 14 | class MNIST_t(data.Dataset): 15 | urls = [ 16 | 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz', 17 | 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz', 18 | 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz', 19 | 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz', 20 | ] 21 | raw_folder = 'raw' 22 | processed_folder = 'processed' 23 | training_file = 'training.pt' 24 | test_file = 'test.pt' 25 | train_triplet_file = 'train_triplets.txt' 26 | test_triplet_file = 'test_triplets.txt' 27 | 28 | def __init__(self, root, n_train_triplets=50000, n_test_triplets=10000, train=True, transform=None, target_transform=None, download=False): 29 | self.root = root 30 | 31 | self.transform = transform 32 | self.train = train # training set or test set 33 | 34 | if download: 35 | self.download() 36 | 37 | if not self._check_exists(): 38 | raise RuntimeError('Dataset not found.' 
+ 39 | ' You can use download=True to download it') 40 | 41 | if self.train: 42 | self.train_data, self.train_labels = torch.load( 43 | os.path.join(root, self.processed_folder, self.training_file)) 44 | self.make_triplet_list(n_train_triplets) 45 | triplets = [] 46 | for line in open(os.path.join(root, self.processed_folder, self.train_triplet_file)): 47 | triplets.append((int(line.split()[0]), int(line.split()[1]), int(line.split()[2]))) # anchor, close, far 48 | self.triplets_train = triplets 49 | else: 50 | self.test_data, self.test_labels = torch.load(os.path.join(root, self.processed_folder, self.test_file)) 51 | self.make_triplet_list(n_test_triplets) 52 | triplets = [] 53 | for line in open(os.path.join(root, self.processed_folder, self.test_triplet_file)): 54 | triplets.append((int(line.split()[0]), int(line.split()[1]), int(line.split()[2]))) # anchor, close, far 55 | self.triplets_test = triplets 56 | 57 | 58 | def __getitem__(self, index): 59 | if self.train: 60 | idx1, idx2, idx3 = self.triplets_train[index] 61 | img1, img2, img3 = self.train_data[idx1], self.train_data[idx2], self.train_data[idx3] 62 | else: 63 | idx1, idx2, idx3 = self.triplets_test[index] 64 | img1, img2, img3 = self.test_data[idx1], self.test_data[idx2], self.test_data[idx3] 65 | 66 | # doing this so that it is consistent with all other datasets 67 | # to return a PIL Image 68 | img1 = Image.fromarray(img1.numpy(), mode='L') 69 | img2 = Image.fromarray(img2.numpy(), mode='L') 70 | img3 = Image.fromarray(img3.numpy(), mode='L') 71 | 72 | if self.transform is not None: 73 | img1 = self.transform(img1) 74 | img2 = self.transform(img2) 75 | img3 = self.transform(img3) 76 | 77 | return img1, img2, img3 78 | 79 | def __len__(self): 80 | if self.train: 81 | return len(self.triplets_train) 82 | else: 83 | return len(self.triplets_test) 84 | 85 | def _check_exists(self): 86 | return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \ 87 | 
os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file)) 88 | 89 | def _check_triplets_exists(self): 90 | return os.path.exists(os.path.join(self.root, self.processed_folder, self.train_triplet_file)) and \ 91 | os.path.exists(os.path.join(self.root, self.processed_folder, self.test_triplet_file)) 92 | 93 | def download(self): 94 | from six.moves import urllib 95 | import gzip 96 | 97 | if self._check_exists(): 98 | return 99 | 100 | # download files 101 | try: 102 | os.makedirs(os.path.join(self.root, self.raw_folder)) 103 | os.makedirs(os.path.join(self.root, self.processed_folder)) 104 | except OSError as e: 105 | if e.errno == errno.EEXIST: 106 | pass 107 | else: 108 | raise 109 | 110 | for url in self.urls: 111 | print('Downloading ' + url) 112 | data = urllib.request.urlopen(url) 113 | filename = url.rpartition('/')[2] 114 | file_path = os.path.join(self.root, self.raw_folder, filename) 115 | with open(file_path, 'wb') as f: 116 | f.write(data.read()) 117 | with open(file_path.replace('.gz', ''), 'wb') as out_f, \ 118 | gzip.GzipFile(file_path) as zip_f: 119 | out_f.write(zip_f.read()) 120 | os.unlink(file_path) 121 | 122 | # process and save as torch files 123 | print('Processing...') 124 | 125 | training_set = ( 126 | read_image_file(os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')), 127 | read_label_file(os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte')) 128 | ) 129 | test_set = ( 130 | read_image_file(os.path.join(self.root, self.raw_folder, 't10k-images-idx3-ubyte')), 131 | read_label_file(os.path.join(self.root, self.raw_folder, 't10k-labels-idx1-ubyte')) 132 | ) 133 | with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f: 134 | torch.save(training_set, f) 135 | with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f: 136 | torch.save(test_set, f) 137 | 138 | print('Done!') 139 | 140 | def make_triplet_list(self, ntriplets): 141 | 
142 | if self._check_triplets_exists(): 143 | return 144 | print('Processing Triplet Generation ...') 145 | if self.train: 146 | np_labels = self.train_labels.numpy() 147 | filename = self.train_triplet_file 148 | else: 149 | np_labels = self.test_labels.numpy() 150 | filename = self.test_triplet_file 151 | triplets = [] 152 | for class_idx in range(10): 153 | a = np.random.choice(np.where(np_labels==class_idx)[0], int(ntriplets/10), replace=True) 154 | b = np.random.choice(np.where(np_labels==class_idx)[0], int(ntriplets/10), replace=True) 155 | while np.any((a-b)==0): 156 | np.random.shuffle(b) 157 | c = np.random.choice(np.where(np_labels!=class_idx)[0], int(ntriplets/10), replace=True) 158 | 159 | for i in range(a.shape[0]): 160 | triplets.append([int(a[i]), int(c[i]), int(b[i])]) 161 | 162 | with open(os.path.join(self.root, self.processed_folder, filename), "w") as f: 163 | writer = csv.writer(f, delimiter=' ') 164 | writer.writerows(triplets) 165 | print('Done!') 166 | 167 | 168 | 169 | 170 | 171 | def get_int(b): 172 | return int(codecs.encode(b, 'hex'), 16) 173 | 174 | 175 | def parse_byte(b): 176 | if isinstance(b, str): 177 | return ord(b) 178 | return b 179 | 180 | 181 | def read_label_file(path): 182 | with open(path, 'rb') as f: 183 | data = f.read() 184 | assert get_int(data[:4]) == 2049 185 | length = get_int(data[4:8]) 186 | labels = [parse_byte(b) for b in data[8:]] 187 | assert len(labels) == length 188 | return torch.LongTensor(labels) 189 | 190 | 191 | def read_image_file(path): 192 | with open(path, 'rb') as f: 193 | data = f.read() 194 | assert get_int(data[:4]) == 2051 195 | length = get_int(data[4:8]) 196 | num_rows = get_int(data[8:12]) 197 | num_cols = get_int(data[12:16]) 198 | images = [] 199 | idx = 16 200 | for l in range(length): 201 | img = [] 202 | images.append(img) 203 | for r in range(num_rows): 204 | row = [] 205 | img.append(row) 206 | for c in range(num_cols): 207 | row.append(parse_byte(data[idx])) 208 | idx += 1 209 | 
assert len(images) == length 210 | return torch.ByteTensor(images).view(-1, 28, 28) 211 | -------------------------------------------------------------------------------- /tripletnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class Tripletnet(nn.Module): 6 | def __init__(self, embeddingnet): 7 | super(Tripletnet, self).__init__() 8 | self.embeddingnet = embeddingnet 9 | 10 | def forward(self, x, y, z): 11 | embedded_x = self.embeddingnet(x) 12 | embedded_y = self.embeddingnet(y) 13 | embedded_z = self.embeddingnet(z) 14 | dist_a = F.pairwise_distance(embedded_x, embedded_y, 2) 15 | dist_b = F.pairwise_distance(embedded_x, embedded_z, 2) 16 | return dist_a, dist_b, embedded_x, embedded_y, embedded_z 17 | --------------------------------------------------------------------------------