├── datasets ├── __init__.py ├── GTSRB.py ├── TT100K.py └── listdataset.py ├── README.md ├── loss.py ├── data_transform.py ├── model.py └── main.py /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .GTSRB import gtsrb_data 2 | from .TT100K import tt100k_data 3 | 4 | __all__ = ('gtsrb_data', 'tt100k_data') 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## QuadNet 2 | PyTorch implementation on [Deep Quadruplet Networks (AAAI 2018)](https://arxiv.org/pdf/1712.01907.pdf) 3 | 4 | 5 | ### Datasets 6 | + [GTSRB](http://benchmark.ini.rub.de/?section=gtsrb&subsection=dataset) 7 | + [TT100K](http://cg.cs.tsinghua.edu.cn/traffic-sign/) 8 | 9 | 10 | ### Please cite our paper in your publications if it helps your research: 11 | ``` 12 | @InProceedings{kim2017co, 13 | title={Co-domain Embedding using Deep Quadruplet Networks for Unseen Traffic Sign Recognition}, 14 | author={Kim, Junsik and Lee, Seokju and Oh, Tae-Hyun and Kweon, In So}, 15 | Booktitle = {Proceedings of the 32th AAAI Conference on Artificial Intelligence}, 16 | year={2018} 17 | } 18 | ``` 19 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import torch.nn.functional as F 5 | 6 | import pdb 7 | 8 | 9 | class ContrastiveLoss(torch.nn.Module): 10 | """ 11 | Contrastive loss function. 12 | Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf 13 | """ 14 | 15 | def __init__(self, margin=2.0): 16 | super(ContrastiveLoss, self).__init__() 17 | self.margin = margin 18 | 19 | def forward(self, output1, output2, label): 20 | euclidean_distance = F.pairwise_distance(output1, output2) 21 | loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) + 22 | (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)) 23 | 24 | 25 | return loss_contrastive 26 | 27 | 28 | class HingeMLoss(torch.nn.Module): 29 | """ 30 | Hinge margin loss function. 31 | Based on: https://arxiv.org/pdf/1712.01907.pdf 32 | """ 33 | 34 | def __init__(self, margin_push=5.0, margin_pull=1.0): 35 | super(HingeMLoss, self).__init__() 36 | self.margin_push = margin_push # label=1 37 | self.margin_pull = margin_pull # label=0 38 | 39 | def forward(self, output1, output2, label): 40 | euclidean_distance = F.pairwise_distance(output1, output2) 41 | loss_hinge = torch.mean( 42 | (1-label) * torch.pow(torch.clamp(euclidean_distance - self.margin_pull, min=0.0), 2) + 43 | (label) * torch.pow(torch.clamp(self.margin_push - euclidean_distance, min=0.0), 2) 44 | ) 45 | 46 | 47 | return loss_hinge -------------------------------------------------------------------------------- /data_transform.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import math 4 | import random 5 | from PIL import Image, ImageOps 6 | import numpy as np 7 | import numbers 8 | import types 9 | import scipy.ndimage as ndimage 10 | 11 | # Seokju added 12 | import matplotlib.pyplot as plt 13 | import collections 14 | import pdb 15 | 16 | '''Set of tranform random routines that takes both input and target as arguments, 17 | in order to have random but coherent transformations. 18 | inputs are PIL Image pairs and targets are ndarrays''' 19 | 20 | 21 | # Seokju added 22 | class PILScale(object): 23 | """Rescale the input PIL.Image to the given size. 24 | 25 | Args: 26 | size (sequence or int): Desired output size. If size is a sequence like 27 | (w, h), output size will be matched to this. If size is an int, 28 | smaller edge of the image will be matched to this number. 29 | i.e, if height > width, then image will be rescaled to 30 | (size * height / width, size) 31 | interpolation (int, optional): Desired interpolation. Default is 32 | ``PIL.Image.BILINEAR`` 33 | """ 34 | 35 | def __init__(self, size, interpolation=Image.BILINEAR): 36 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) 37 | self.size = size 38 | self.interpolation = interpolation 39 | 40 | def __call__(self, img): 41 | """ 42 | Args: 43 | img (PIL.Image): Image to be scaled. 44 | 45 | Returns: 46 | PIL.Image: Rescaled image. 47 | """ 48 | if isinstance(self.size, int): 49 | w, h = img.size 50 | if (w <= h and w == self.size) or (h <= w and h == self.size): 51 | return img 52 | if w < h: 53 | ow = self.size 54 | oh = int(self.size * h / w) 55 | return img.resize((ow, oh), self.interpolation) 56 | else: 57 | oh = self.size 58 | ow = int(self.size * w / h) 59 | return img.resize((ow, oh), self.interpolation) 60 | else: 61 | # pdb.set_trace() 62 | return img.resize(self.size, self.interpolation) 63 | 64 | -------------------------------------------------------------------------------- /datasets/GTSRB.py: -------------------------------------------------------------------------------- 1 | # Seokju Lee 2017.12.27 2 | """ 3 | Load GTSRB dataset 4 | img, label 5 | """ 6 | import os.path 7 | import random 8 | import glob 9 | import math 10 | from .listdataset import ListDataset 11 | import torch 12 | import pdb 13 | import matplotlib.pyplot as plt 14 | from torch.utils.serialization import load_lua 15 | import numpy as np 16 | 17 | 18 | def make_dataset(base, imfile, gtfile, split=100): 19 | ''' 20 | Will make list of image path and label 21 | 'img.png / label' 22 | ''' 23 | images = [] 24 | labels = [] 25 | 26 | 27 | with open(imfile, 'r') as txtfile: 28 | for line in txtfile: 29 | elem = line.split('\n') 30 | images.append([os.path.join(base, elem[0])]) 31 | with open(gtfile, 'r') as txtfile: 32 | for line in txtfile: 33 | elem = line.split('\n') 34 | labels.append([int(elem[0])]) 35 | 36 | output = np.concatenate((np.array(images), np.array(labels)), axis=1).tolist() 37 | # pdb.set_trace() 38 | 39 | assert(len(output) > 0) 40 | random.shuffle(output) 41 | 42 | split_index = int(math.floor(len(output)*split/100)) 43 | assert(split_index >= 0 and split_index <= len(output)) 44 | # pdb.set_trace() 45 | return output[:split_index] if split_index < len(output) else output 46 | 47 | 48 | def make_tempset(tempfile): 49 | ''' 50 | Will make list of image path and label 51 | 'img.png / label' 52 | ''' 53 | output = [] 54 | 55 | temp_list = sorted( os.listdir(tempfile) ) 56 | for i in range(len(temp_list)): 57 | output.append([os.path.join(tempfile, temp_list[i]), str(i)]) 58 | 59 | assert(len(output) > 0) 60 | random.shuffle(output) 61 | 62 | return output 63 | 64 | 65 | def gtsrb_data(base, tr, tr_gt, te, te_gt, tp_tr, tp_te, transform=None, split=100, should_invert=False): 66 | train_list = make_dataset(base, tr, tr_gt) 67 | test_list = make_dataset(base, te, te_gt, split) 68 | temp_list_tr = make_tempset(tp_tr) 69 | temp_list_te = make_tempset(tp_te) 70 | # pdb.set_trace() 71 | 72 | train_dataset = ListDataset(train_list, temp_list_tr, transform, should_invert) 73 | test_dataset = ListDataset(test_list, temp_list_te, transform, should_invert) 74 | 75 | return train_dataset, test_dataset 76 | -------------------------------------------------------------------------------- /datasets/TT100K.py: -------------------------------------------------------------------------------- 1 | # Seokju Lee 2017.12.27 2 | """ 3 | Load TT100K dataset 4 | img, label 5 | """ 6 | import os.path 7 | import random 8 | import glob 9 | import math 10 | from .listdataset import ListDataset 11 | import torch 12 | import pdb 13 | import matplotlib.pyplot as plt 14 | from torch.utils.serialization import load_lua 15 | import numpy as np 16 | 17 | 18 | def make_dataset(base, imfile, gtfile, split=100): 19 | ''' 20 | Will make list of image path and label 21 | 'img.png / label' 22 | ''' 23 | images = [] 24 | labels = [] 25 | 26 | 27 | with open(imfile, 'r') as txtfile: 28 | for line in txtfile: 29 | elem = line.split('\n') 30 | images.append([os.path.join(base, elem[0])]) 31 | with open(gtfile, 'r') as txtfile: 32 | for line in txtfile: 33 | elem = line.split('\n') 34 | labels.append([int(elem[0])]) 35 | 36 | output = np.concatenate((np.array(images), np.array(labels)), axis=1).tolist() 37 | # pdb.set_trace() 38 | 39 | assert(len(output) > 0) 40 | random.shuffle(output) 41 | 42 | split_index = int(math.floor(len(output)*split/100)) 43 | assert(split_index >= 0 and split_index <= len(output)) 44 | # pdb.set_trace() 45 | return output[:split_index] if split_index < len(output) else output 46 | 47 | 48 | def make_tempset(tempfile): 49 | ''' 50 | Will make list of image path and label 51 | 'img.png / label' 52 | ''' 53 | output = [] 54 | 55 | temp_list = sorted( os.listdir(tempfile) ) 56 | for i in range(len(temp_list)): 57 | output.append([os.path.join(tempfile, temp_list[i]), str(i+1)]) 58 | 59 | assert(len(output) > 0) 60 | random.shuffle(output) 61 | 62 | return output 63 | 64 | 65 | def tt100k_data(base, tr, tr_gt, te, te_gt, tp_tr, tp_te, transform=None, split=100, should_invert=False): 66 | train_list = make_dataset(base, tr, tr_gt) 67 | test_list = make_dataset(base, te, te_gt, split) 68 | temp_list_tr = make_tempset(tp_tr) 69 | temp_list_te = make_tempset(tp_te) 70 | # pdb.set_trace() 71 | 72 | train_dataset = ListDataset(train_list, temp_list_tr, transform, should_invert) 73 | test_dataset = ListDataset(test_list, temp_list_te, transform, should_invert) 74 | 75 | return train_dataset, test_dataset -------------------------------------------------------------------------------- /datasets/listdataset.py: -------------------------------------------------------------------------------- 1 | # Seokju Lee 2018.03.29 2 | """ 3 | Load siamese list 4 | """ 5 | import torch.utils.data as data 6 | import os 7 | import os.path 8 | from scipy.ndimage import imread 9 | import numpy as np 10 | import pdb 11 | import matplotlib.pyplot as plt 12 | from torch.utils.serialization import load_lua 13 | import torch 14 | from scipy import ndimage, misc 15 | from PIL import Image 16 | import random 17 | 18 | 19 | class ListDataset(data.Dataset): 20 | def __init__(self, path_list, temp_list, transform=None, should_invert=False): 21 | self.path_list = path_list 22 | self.temp_list = temp_list 23 | self.transform = transform 24 | self.should_invert = should_invert 25 | # pdb.set_trace() 26 | 27 | def __getitem__(self, index): 28 | RA_list = self.path_list[index] 29 | 30 | # Randomly pick Real-B 31 | while True: 32 | #keep looping till the different class real image is found 33 | RB_list = random.choice(self.path_list) 34 | if RB_list[1] != RA_list[1]: 35 | break 36 | 37 | # Pick Temp-A 38 | while True: 39 | #keep looping till the Temp-A is found 40 | TA_list = random.choice(self.temp_list) 41 | if TA_list[1] == RA_list[1]: 42 | break 43 | 44 | # Pick Temp-B 45 | while True: 46 | #keep looping till the Temp-B is found 47 | TB_list = random.choice(self.temp_list) 48 | if TB_list[1] == RB_list[1]: 49 | break 50 | 51 | RA = Image.open(RA_list[0]) 52 | RB = Image.open(RB_list[0]) 53 | TA = Image.open(TA_list[0]) 54 | TB = Image.open(TB_list[0]) 55 | 56 | if self.should_invert: 57 | RA = PIL.ImageOps.invert(RA) 58 | RB = PIL.ImageOps.invert(RB) 59 | TA = PIL.ImageOps.invert(TA) 60 | TB = PIL.ImageOps.invert(TB) 61 | 62 | if self.transform is not None: 63 | RA = self.transform(RA) 64 | RB = self.transform(RB) 65 | TA = self.transform(TA) 66 | TB = self.transform(TB) 67 | # pdb.set_trace() 68 | 69 | return RA, RB, TA, TB, \ 70 | torch.from_numpy(np.array( [float(RA_list[1])] )), \ 71 | torch.from_numpy(np.array( [float(RB_list[1])] )) 72 | 73 | 74 | 75 | def __len__(self): 76 | return len(self.path_list) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import pdb 5 | from collections import OrderedDict 6 | 7 | 8 | 9 | class QuadNet(nn.Module): 10 | def __init__(self, num_classes=10): 11 | super(QuadNet, self).__init__() 12 | 13 | self.featReal = nn.Sequential( 14 | nn.Conv2d(3, 100, kernel_size=7, padding=0), 15 | nn.ReLU(inplace=True), 16 | nn.MaxPool2d(2,2), 17 | nn.Conv2d(100, 150, kernel_size=4, padding=0), 18 | nn.ReLU(inplace=True), 19 | nn.MaxPool2d(2,2), 20 | nn.Conv2d(150, 250, kernel_size=4, padding=0), 21 | nn.ReLU(inplace=True), 22 | nn.MaxPool2d(2,2), 23 | ) 24 | 25 | self.fcReal = nn.Sequential( 26 | nn.Linear(3*3*250, 300), 27 | nn.ReLU(inplace=True), 28 | nn.Linear(300, num_classes), 29 | ) 30 | 31 | self.featTemp = nn.Sequential( 32 | nn.Conv2d(3, 100, kernel_size=7, padding=0), 33 | nn.ReLU(inplace=True), 34 | nn.MaxPool2d(2,2), 35 | nn.Conv2d(100, 150, kernel_size=4, padding=0), 36 | nn.ReLU(inplace=True), 37 | nn.MaxPool2d(2,2), 38 | nn.Conv2d(150, 250, kernel_size=4, padding=0), 39 | nn.ReLU(inplace=True), 40 | nn.MaxPool2d(2,2), 41 | ) 42 | 43 | self.fcTemp = nn.Sequential( 44 | nn.Linear(3*3*250, 300), 45 | nn.ReLU(inplace=True), 46 | nn.Linear(300, num_classes), 47 | ) 48 | 49 | def init_weights(self): 50 | for m in self.modules(): 51 | if isinstance(m, nn.Conv2d): 52 | nn.init.xavier_uniform(m.weight.data) 53 | if m.bias is not None: 54 | m.bias.data.zero_() 55 | 56 | def forward_real(self, x): 57 | x = self.featReal(x) 58 | x = x.view(-1, 3*3*250) 59 | x = self.fcReal(x) 60 | return x 61 | 62 | def forward_temp(self, x): 63 | x = self.featTemp(x) 64 | x = x.view(-1, 3*3*250) 65 | x = self.fcTemp(x) 66 | return x 67 | 68 | def forward(self, realA, realB, tempA, tempB): 69 | RA = self.forward_real(realA) 70 | RB = self.forward_real(realB) 71 | TA = self.forward_temp(tempA) 72 | TB = self.forward_temp(tempB) 73 | return RA, RB, TA, TB 74 | 75 | 76 | class QuadNetSingle(nn.Module): 77 | def __init__(self, num_classes=10): 78 | super(QuadNetSingle, self).__init__() 79 | 80 | self.conv1 = nn.Conv2d(3, 100, kernel_size=7, padding=0) 81 | self.pool = nn.MaxPool2d(2,2) 82 | self.conv2 = nn.Conv2d(100, 150, kernel_size=4, padding=0) 83 | self.conv3 = nn.Conv2d(150, 250, kernel_size=4, padding=0) 84 | 85 | self.fc1 = nn.Linear(3*3*250, 300) 86 | self.fc2 = nn.Linear(300, num_classes) 87 | 88 | self.relu = nn.ReLU(inplace=True) 89 | 90 | def init_weights(self): 91 | for m in self.modules(): 92 | if isinstance(m, nn.Conv2d): 93 | nn.init.xavier_uniform(m.weight.data) 94 | if m.bias is not None: 95 | m.bias.data.zero_() 96 | 97 | def forward_once(self, x): 98 | x = self.conv1(x) # 100x42x42 99 | x = self.relu(x) # 100x42x42 100 | x1 = self.pool(x) # 100x21x21 101 | 102 | x = self.conv2(x1) # 150x18x18 103 | x = self.relu(x) # 150x18x18 104 | x2 = self.pool(x) # 150x9x9 105 | 106 | x = self.conv3(x2) # 250x6x6 107 | x = self.relu(x) # 250x6x6 108 | x3 = self.pool(x) # 250x3x3 109 | 110 | xv = x3.view(-1, 3*3*250) 111 | xfc1 = self.relu(self.fc1(xv)) 112 | output = self.fc2(xfc1) 113 | 114 | return output 115 | 116 | def forward(self, realA, realB, tempA, tempB): 117 | RA = self.forward_once(realA) 118 | RB = self.forward_once(realB) 119 | TA = self.forward_once(tempA) 120 | TB = self.forward_once(tempB) 121 | return RA, RB, TA, TB -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Seokju Lee, 2018.04.11 3 | PyTorch implementation of "Co-domain Embedding using Deep Quadruplet Networks for Unseen Traffic Sign Recognition (AAAI-18)" 4 | Base codes from "https://github.com/SeokjuLee/SiameseNet-TS" 5 | 6 | v1: QuadNetSingle (IdsiaNet, single tower) 7 | v2: QuadNet (IdsiaNet, two tower) 8 | 9 | python main.py -a QuadNetSingle 10 | python main.py -a QuadNetSingle --evaluate --pretrained /media/rcv/SSD1/git/QuadNet/gtsrb_data/Wed-Apr-11-17:12/100epochs,b50,lr0.0005/model_best.pth.tar 11 | 12 | python main.py 13 | python main.py --evaluate --pretrained /media/rcv/SSD1/git/QuadNet/gtsrb_data/Wed-Apr-11-19:49/100epochs,b50,lr0.0005/model_best.pth.tar 14 | python main.py --evaluate --pretrained /media/rcv/SSD1/git/QuadNet/gtsrb_data/Wed-Apr-11-22:58/100epochs,b50,lr0.0005/model_best.pth.tar 15 | 16 | python main.py --dataset tt100k_data 17 | 18 | 19 | ps aux | grep python 20 | ''' 21 | 22 | import argparse 23 | import os 24 | import torchvision 25 | import torchvision.datasets as dset 26 | import torchvision.transforms as transforms 27 | from torch.utils.data import DataLoader,Dataset 28 | import matplotlib.pyplot as plt 29 | import torchvision.utils 30 | import numpy as np 31 | import random 32 | from PIL import Image 33 | import torch 34 | from torch.autograd import Variable 35 | import PIL.ImageOps 36 | import torch.nn as nn 37 | from torch import optim 38 | import torch.nn.functional as F 39 | from model import QuadNet 40 | from model import QuadNetSingle 41 | from loss import ContrastiveLoss 42 | from loss import HingeMLoss 43 | import data_transform 44 | import datetime 45 | import datasets 46 | import csv 47 | import time 48 | import shutil 49 | import progressbar 50 | 51 | import pdb 52 | 53 | 54 | 55 | # random.seed(1) 56 | dataset_names = sorted(name for name in datasets.__all__) 57 | 58 | parser = argparse.ArgumentParser(description='PyTorch SiameseNet Training on several datasets') 59 | parser.add_argument('--lr', '--learning-rate', default=0.0005, type=float, 60 | metavar='LR', help='initial learning rate') 61 | parser.add_argument('-j', '--workers', default=0, type=int, metavar='N', 62 | help='number of data loading workers (default: 0)') 63 | parser.add_argument('-b', '--batch-size', default=50, type=int, 64 | metavar='N', help='mini-batch size (default: 8)') 65 | parser.add_argument('-e', '--epochs', default=100, type=int, metavar='N', 66 | help='number of total epochs to run (default: 300') 67 | parser.add_argument('--dataset', metavar='DATASET', default='gtsrb_data', 68 | choices=dataset_names, 69 | help='dataset type : ' + 70 | ' | '.join(dataset_names) + 71 | ' (default: gtsrb_data)') 72 | parser.add_argument('--pretrained', dest='pretrained', default = None, 73 | help='path to pre-trained model') 74 | parser.add_argument('--evaluate', dest='evaluate', action='store_true', 75 | help='evaluate model on validation set') 76 | parser.add_argument('--log-summary', default = 'progress_log_summary.csv', 77 | help='csv where to save per-epoch train and test stats') 78 | parser.add_argument('--log-full', default = 'progress_log_full.csv', 79 | help='csv where to save per-gradient descent train stats') 80 | parser.add_argument('--data-parallel', default=None, 81 | help='Use nn.DataParallel() model') 82 | parser.add_argument('--print-freq', '-p', default=100, type=int, 83 | metavar='N', help='print frequency (default: 10)') 84 | parser.add_argument('--arch','-a', default='QuadNet', 85 | help='select architecture') 86 | 87 | 88 | args = parser.parse_args() 89 | 90 | if args.dataset == 'gtsrb_data': 91 | ### GTSRB 92 | class Config(): 93 | base_path = "/media/rcv/SSD1/Logo_oneshot/GTSRB" 94 | tr_im_path = "/media/rcv/SSD1/Logo_oneshot/GTSRB/Experiment02-22-43/train_impaths.txt" 95 | tr_gt_path = "/media/rcv/SSD1/Logo_oneshot/GTSRB/Experiment02-22-43/train_imclasses.txt" 96 | te_im_path = "/media/rcv/SSD1/Logo_oneshot/GTSRB/Experiment02-22-43/test_impaths.txt" 97 | te_gt_path = "/media/rcv/SSD1/Logo_oneshot/GTSRB/Experiment02-22-43/test_imclasses.txt" 98 | tr_tmp_path = "/media/rcv/SSD1/Logo_oneshot/GTSRB/GTSRB_template_ordered" 99 | te_tmp_path = "/media/rcv/SSD1/Logo_oneshot/GTSRB/GTSRB_template_ordered" 100 | 101 | elif args.dataset == 'tt100k_data': 102 | ### TT100K 103 | class Config(): 104 | base_path = "/media/rcv/SSD1/Logo_oneshot/TT100K" 105 | tr_im_path = "/media/rcv/SSD1/Logo_oneshot/TT100K/exp02_exist_classes_only/train_impaths.txt" 106 | tr_gt_path = "/media/rcv/SSD1/Logo_oneshot/TT100K/exp02_exist_classes_only/train_imclasses.txt" 107 | te_im_path = "/media/rcv/SSD1/Logo_oneshot/TT100K/exp02_exist_classes_only/val_impaths.txt" 108 | te_gt_path = "/media/rcv/SSD1/Logo_oneshot/TT100K/exp02_exist_classes_only/val_imclasses.txt" 109 | tr_tmp_path = "/media/rcv/SSD1/Logo_oneshot/TT100K/TT100K_template_ordered" 110 | te_tmp_path = "/media/rcv/SSD1/Logo_oneshot/TT100K/TT100K_template_ordered" 111 | 112 | 113 | BEST_TEST_LOSS = -1 114 | 115 | 116 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 117 | torch.save(state, os.path.join(save_path,filename)) 118 | if is_best: 119 | shutil.copyfile(os.path.join(save_path,filename), os.path.join(save_path,'model_best.pth.tar')) 120 | 121 | 122 | def imshow(img,text=None,should_save=False): 123 | npimg = img.numpy() 124 | plt.axis("off") 125 | if text: 126 | plt.text(60, 8, text, style='italic',fontweight='bold', 127 | bbox={'facecolor':'white', 'alpha':0.8, 'pad':10}) 128 | plt.imshow(np.transpose(npimg, (1, 2, 0))) 129 | plt.ion() 130 | plt.show() 131 | 132 | def show_plot(iteration,loss): 133 | plt.plot(iteration,loss) 134 | plt.show() 135 | 136 | 137 | 138 | def main(): 139 | global args, BEST_TEST_LOSS, save_path 140 | 141 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 142 | std=[0.229, 0.224, 0.225]) 143 | input_transform = transforms.Compose([ 144 | data_transform.PILScale((48,48)), 145 | transforms.ToTensor(), 146 | # normalize 147 | ]) 148 | 149 | print("=> fetching image/label pairs in '{}'".format(args.dataset)) 150 | train_set, test_set = datasets.__dict__[args.dataset]( 151 | Config.base_path, 152 | Config.tr_im_path, 153 | Config.tr_gt_path, 154 | Config.te_im_path, 155 | Config.te_gt_path, 156 | Config.tr_tmp_path, 157 | Config.te_tmp_path, 158 | transform=input_transform, 159 | split=100, 160 | should_invert=False, 161 | ) 162 | 163 | print('{} samples found, {} train samples and {} test samples '.format(len(test_set)+len(train_set), 164 | len(train_set), 165 | len(test_set))) 166 | 167 | 168 | ### Data visualization ##################### 169 | # vis_dataloader = DataLoader(train_set, 170 | # shuffle=True, 171 | # num_workers=args.workers, 172 | # batch_size=args.batch_size) 173 | # dataiter = iter(vis_dataloader) 174 | # example_batch = next(dataiter) 175 | # concatenated = torch.cat((example_batch[0],example_batch[2]),0) 176 | # imshow(torchvision.utils.make_grid(concatenated)) 177 | # print(example_batch[4].numpy()) 178 | # pdb.set_trace() 179 | ############################################# 180 | 181 | 182 | train_loader = DataLoader(train_set, 183 | shuffle=True, 184 | num_workers=args.workers, 185 | batch_size=args.batch_size) 186 | test_loader = DataLoader(test_set, 187 | shuffle=False, 188 | num_workers=args.workers, 189 | batch_size=args.batch_size) 190 | 191 | if args.arch == 'QuadNet': 192 | net = QuadNet().cuda() 193 | if args.arch == 'QuadNetSingle': 194 | net = QuadNetSingle().cuda() 195 | 196 | if args.pretrained: 197 | print("=> Use pre-trained model") 198 | weights = torch.load(args.pretrained) 199 | net.load_state_dict(weights['state_dict']) 200 | else: 201 | print("=> Randomly initialize model") 202 | net.init_weights() 203 | 204 | 205 | criterion = HingeMLoss() 206 | optimizer = optim.Adam( net.parameters(), lr=args.lr ) 207 | 208 | if args.data_parallel: 209 | net = torch.nn.DataParallel(net).cuda() 210 | 211 | 212 | if args.evaluate: 213 | eval(args.dataset, test_set, net, input_transform) 214 | 215 | 216 | ### Visualize testset with pretrained model ##################### 217 | # while True: 218 | # test_dataloader = DataLoader(test_set, shuffle=True, num_workers=args.workers, batch_size=1) 219 | # dataiter = iter(test_dataloader) 220 | 221 | # for i in range(1): 222 | # ra, rb, ta, tb, _, _ = next(dataiter) 223 | # concatenated = torch.cat((ra,rb),0) 224 | # ra = Variable(ra).cuda() 225 | # rb = Variable(rb).cuda() 226 | # ta = Variable(ta).cuda() 227 | # tb = Variable(tb).cuda() 228 | 229 | # RA,RB,TA,TB = net(ra,rb,ta,tb) 230 | # euclidean_distance = F.pairwise_distance(RA, RB) 231 | # imshow(torchvision.utils.make_grid(concatenated),'Dissimilarity: {:.2f}'.format(euclidean_distance.cpu().data.numpy()[0][0])) 232 | # pdb.set_trace() 233 | ################################################################## 234 | 235 | 236 | save_path = '{}epochs,b{},lr{}'.format( 237 | args.epochs, 238 | args.batch_size, 239 | args.lr) 240 | timestamp = datetime.datetime.now().strftime("%a-%b-%d-%H:%M") 241 | save_path = os.path.join(timestamp,save_path) 242 | save_path = os.path.join(args.dataset,save_path) 243 | print('=> will save everything to {}'.format(save_path)) 244 | if not os.path.exists(save_path): 245 | os.makedirs(save_path) 246 | 247 | 248 | with open(os.path.join(save_path,args.log_summary), 'w') as csvfile: # save every epoch 249 | writer = csv.writer(csvfile, delimiter='\t') 250 | writer.writerow(['tr_loss','te_loss']) 251 | with open(os.path.join(save_path,args.log_full), 'w') as csvfile: # save every iter 252 | writer = csv.writer(csvfile, delimiter='\t') 253 | writer.writerow(['tr_loss_iter']) 254 | 255 | 256 | for epoch in range(0, args.epochs): 257 | adjust_learning_rate(optimizer, epoch) 258 | 259 | train_loss = train(train_loader, net, criterion=criterion, optimizer=optimizer, epoch=epoch) 260 | test_loss = test(test_loader, net, criterion=criterion, epoch=epoch) 261 | 262 | if BEST_TEST_LOSS < 0: 263 | BEST_TEST_LOSS = test_loss 264 | is_best = test_loss < BEST_TEST_LOSS 265 | BEST_TEST_LOSS = min(test_loss, BEST_TEST_LOSS) 266 | 267 | 268 | 269 | ### Save checkpoints 270 | if args.data_parallel: 271 | save_checkpoint({ 272 | 'epoch': epoch + 1, 273 | 'state_dict': net.module.state_dict(), # args.data_parallel = True 274 | 'BEST_TEST_LOSS': BEST_TEST_LOSS, 275 | }, is_best 276 | ) 277 | else: 278 | save_checkpoint({ 279 | 'epoch': epoch + 1, 280 | 'state_dict': net.state_dict(), # args.data_parallel = False 281 | 'BEST_TEST_LOSS': BEST_TEST_LOSS, 282 | }, is_best 283 | ) 284 | if (epoch+1)%10 == 0: 285 | ckptname = 'ckpt_e%04d.pth.tar' %(epoch+1) 286 | if args.data_parallel: 287 | torch.save({ 288 | 'epoch': epoch + 1, 289 | 'state_dict': net.module.state_dict(), # args.data_parallel = True 290 | 'BEST_TEST_LOSS': BEST_TEST_LOSS, 291 | }, os.path.join(save_path,ckptname)) 292 | else: 293 | torch.save({ 294 | 'epoch': epoch + 1, 295 | 'state_dict': net.state_dict(), # args.data_parallel = False 296 | 'BEST_TEST_LOSS': BEST_TEST_LOSS, 297 | }, os.path.join(save_path,ckptname)) 298 | 299 | ### Save epoch logs 300 | with open(os.path.join(save_path,args.log_summary), 'a') as csvfile: 301 | writer = csv.writer(csvfile, delimiter='\t') 302 | writer.writerow([train_loss, test_loss]) 303 | 304 | 305 | 306 | def train(train_loader, net, criterion, optimizer, epoch): 307 | batch_time = AverageMeter() 308 | data_time = AverageMeter() 309 | losses = AverageMeter() 310 | 311 | net.train() 312 | end = time.time() 313 | 314 | for i, data in enumerate(train_loader, 0): 315 | # if i>10: break; 316 | ra, rb, ta, tb, la, lb = data 317 | ra, rb, ta, tb = Variable(ra).cuda(), Variable(rb).cuda(), Variable(ta).cuda(), Variable(tb).cuda() 318 | la, lb = Variable(la).cuda(), Variable(lb).cuda() 319 | data_time.update(time.time() - end) 320 | 321 | # Assign labels: push (different class) and pull (same class) 322 | push = Variable(torch.FloatTensor([1])).cuda().resize(1,1).expand(ra.size(0),1) 323 | pull = Variable(torch.FloatTensor([0])).cuda().resize(1,1).expand(ra.size(0),1) 324 | 325 | RA, RB, TA, TB = net(ra,rb,ta,tb) 326 | # pdb.set_trace() 327 | loss_TATB = criterion(TA,TB,push) 328 | loss_TARA = criterion(TA,RA,pull) 329 | loss_TBRB = criterion(TB,RB,pull) 330 | loss_TARB = criterion(TA,RB,push) 331 | loss_TBRA = criterion(TB,RA,push) 332 | loss = loss_TATB + loss_TARA + loss_TBRB + loss_TARB + loss_TBRA 333 | losses.update(loss.data[0], ra.size(0)) 334 | 335 | optimizer.zero_grad() 336 | loss.backward() 337 | optimizer.step() 338 | 339 | # if epoch > 2 and i % args.print_freq and loss.data[0] > 3 == 0: 340 | # ED = F.pairwise_distance(output1, output2) 341 | # imgVisCat = torch.cat((img0.data.cpu(), img1.data.cpu()),0) 342 | # imshow(torchvision.utils.make_grid(imgVisCat)) 343 | # pdb.set_trace() 344 | 345 | batch_time.update(time.time() - end) 346 | end = time.time() 347 | 348 | if i % args.print_freq == 0: 349 | print('[{0}|{1}/{2}] ' 350 | 'loss:{loss.val:.3f}({loss.avg:.3f}) ' 351 | 'Batch time: {batch_time.val:.3f}({batch_time.avg:.3f}) ' 352 | 'Data time: {data_time.val:.3f}({data_time.avg:.3f})'.format( 353 | epoch, i, len(train_loader), 354 | loss=losses, 355 | batch_time=batch_time, data_time=data_time)) 356 | with open(os.path.join(save_path,args.log_full), 'a') as csvfile: 357 | writer = csv.writer(csvfile, delimiter='\t') 358 | writer.writerow([loss.data[0]]) 359 | 360 | return losses.avg 361 | 362 | 363 | def test(test_loader, net, criterion, epoch): 364 | losses = AverageMeter() 365 | 366 | net.eval() 367 | 368 | for i, data in enumerate(test_loader, 0): 369 | # if i>10: break; 370 | ra, rb, ta, tb, la, lb = data 371 | ra, rb, ta, tb = Variable(ra).cuda(), Variable(rb).cuda(), Variable(ta).cuda(), Variable(tb).cuda() 372 | la, lb = Variable(la).cuda(), Variable(lb).cuda() 373 | 374 | # Assign labels: push (different class) and pull (same class) 375 | push = Variable(torch.FloatTensor([1])).cuda().resize(1,1).expand(ra.size(0),1) 376 | pull = Variable(torch.FloatTensor([0])).cuda().resize(1,1).expand(ra.size(0),1) 377 | 378 | RA, RB, TA, TB = net(ra,rb,ta,tb) 379 | # pdb.set_trace() 380 | loss_TATB = criterion(TA,TB,push) 381 | loss_TARA = criterion(TA,RA,pull) 382 | loss_TBRB = criterion(TB,RB,pull) 383 | loss_TARB = criterion(TA,RB,push) 384 | loss_TBRA = criterion(TB,RA,push) 385 | loss = loss_TATB + loss_TARA + loss_TBRB + loss_TARB + loss_TBRA 386 | losses.update(loss.data[0], ra.size(0)) 387 | 388 | 389 | if i % args.print_freq == 0: 390 | print('Test: [{0}/{1}] ' 391 | 'loss: {loss.val:.3f}({loss.avg:.3f})'.format( 392 | i, len(test_loader), 393 | loss=losses)) 394 | 395 | print(' * loss: {loss.avg:.3f}\t'.format(loss=losses)) 396 | 397 | return losses.avg 398 | 399 | 400 | 401 | 402 | def eval(db, test_set, net, input_transform): 403 | net.eval() 404 | 405 | test_loader = DataLoader(test_set, 406 | shuffle=False, 407 | num_workers=args.workers, 408 | batch_size=400) 409 | 410 | tmp_list = sorted( os.listdir(Config.te_tmp_path) ) 411 | tmp = [] 412 | for i in range(len(tmp_list)): 413 | img = Image.open( os.path.join(Config.te_tmp_path, tmp_list[i]) ) 414 | img = Variable(input_transform(img)).cuda() 415 | tmp.append(img) 416 | 417 | result_table = [] 418 | bar = progressbar.ProgressBar(max_value=len(test_loader)) 419 | for i, data in enumerate(test_loader, 0): 420 | bar.update(i) 421 | ra, _, _, _, la, _ = data 422 | ra, la = Variable(ra).cuda(), Variable(la).cuda() 423 | ED = np.zeros( (ra.size(0), len(tmp_list)) ) 424 | 425 | for tt in range(len(tmp_list)): 426 | ta = tmp[tt].unsqueeze(0).repeat(ra.size(0),1,1,1) 427 | RA, _, TA, _ = net(ra, ra, ta, ta) 428 | for bs in range(ra.size(0)): 429 | ED[bs][tt] = F.pairwise_distance(RA[bs].unsqueeze(0), TA[bs].unsqueeze(0)).cpu().data.numpy()[0][0] 430 | 431 | for bs in range(ra.size(0)): 432 | if db == 'gtsrb_data': 433 | result_table.append([int(la.cpu().data.numpy()[bs][0]), ED[bs].argmin()]) 434 | elif db == 'tt100k_data': 435 | result_table.append([int(la.cpu().data.numpy()[bs][0]), ED[bs].argmin()+1]) 436 | 437 | results = np.array(result_table) 438 | 439 | 440 | if db == 'gtsrb_data': 441 | seenList = [1,2,3,4,5,7,8,9,10,11,12,13,14,15,17,18,25,26,31,33,35,38] 442 | unseenList = [0,6,16,19,20,21,22,23,24,27,28,29,30,32,34,36,37,39,40,41,42] 443 | elif db == 'tt100k_data': 444 | seenList = range(1, 25) 445 | unseenList = range(25, 35) 446 | 447 | seenScore = np.zeros(len(seenList)) 448 | unseenScore = np.zeros(len(unseenList)) 449 | 450 | for i in range(len(seenList)): 451 | idx = np.where(results[:,0]==seenList[i])[0] 452 | num_wrong = np.count_nonzero(results[idx,0] - results[idx,1]) 453 | seenScore[i] = float((len(idx) - num_wrong)) / float(len(idx)) 454 | 455 | for i in range(len(unseenList)): 456 | idx = np.where(results[:,0]==unseenList[i])[0] 457 | num_wrong = np.count_nonzero(results[idx,0] - results[idx,1]) 458 | unseenScore[i] = float((len(idx) - num_wrong)) / float(len(idx)) 459 | 460 | print('seen:', seenScore.mean(), 'unseen:', unseenScore.mean()) 461 | pdb.set_trace() 462 | 463 | 464 | 465 | 466 | class AverageMeter(object): 467 | """Computes and stores the average and current value""" 468 | def __init__(self): 469 | self.reset() 470 | 471 | def reset(self): 472 | self.val = 0 473 | self.avg = 0 474 | self.sum = 0 475 | self.count = 0 476 | 477 | def update(self, val, n=1): 478 | self.val = val 479 | self.sum += val * n 480 | self.count += n 481 | self.avg = self.sum / self.count 482 | 483 | 484 | 485 | def adjust_learning_rate(optimizer, epoch): 486 | """Sets the learning rate to the initial LR decayed by 2 periodically""" 487 | if (epoch+1) % 10 == 0: 488 | for param_group in optimizer.param_groups: 489 | param_group['lr'] = param_group['lr']/2 490 | 491 | 492 | 493 | if __name__ == '__main__': 494 | main() --------------------------------------------------------------------------------