├── .gitignore
├── LICENSE
├── README.md
├── evaluate.py
├── net.py
└── train.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
/data/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Fan Jingbo

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pytorch_deephash

## Introduction

This is a PyTorch implementation of [Deep Learning of Binary Hash Codes for Fast Image Retrieval](https://github.com/kevinlin311tw/caffe-cvprw15); it achieves more than 93% mAP on the CIFAR-10 dataset. A minimal retrieval sketch is included at the end of this README.

## Environment

> PyTorch 1.4.0
>
> torchvision 0.5.0
>
> tqdm
>
> numpy


## Training

```bash
python train.py
```

Trained models are saved to the `model` folder by default; each model file is named after its test accuracy.

## Evaluation

```bash
python evaluate.py --pretrained {your saved model name in the model folder}
```

## Tips

1. On Windows, keep `num_workers` at zero.

2. There are additional arguments; list them with `-h` or see the code.
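## Retrieval sketch

Retrieval with the learned codes reduces to Hamming-distance ranking. The sketch below is illustrative only: `query_code` and `db_codes` are hypothetical stand-ins for the 0/1 codes that `evaluate.py` produces.

```python
import numpy as np

# Hypothetical inputs: one query code and a database of codes,
# each a {0, 1} vector of length 48 (the default number of bits).
rng = np.random.RandomState(0)
query_code = rng.randint(0, 2, size=48)
db_codes = rng.randint(0, 2, size=(10000, 48))

# Hamming distance = number of differing bits.
dists = np.count_nonzero(db_codes != query_code, axis=1)

# Rank the database by distance; the nearest codes come first.
ranking = np.argsort(dists)
print(ranking[:10])  # indices of the 10 closest database items
```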
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
import argparse
import os
import time

import numpy as np
import torch
from torchvision import datasets, transforms
from tqdm import tqdm

from net import AlexNetPlusLatent

parser = argparse.ArgumentParser(description='Deep Hashing evaluate mAP')
parser.add_argument('--pretrained', type=float, default=0, metavar='pretrained_model',
                    help='test-accuracy name of the saved model in the model folder '
                         '(default: 0, which reuses cached codes from ./result)')
parser.add_argument('--bits', type=int, default=48, metavar='bts',
                    help='binary bits')
args = parser.parse_args()


def load_data():
    # One deterministic transform for both splits: the codes must be
    # reproducible, so there is no augmentation here.
    transform = transforms.Compose(
        [transforms.Resize(227),
         transforms.ToTensor(),
         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
    trainset = datasets.CIFAR10(root='./data', train=True, download=True,
                                transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=100,
                                              shuffle=False, num_workers=0)

    testset = datasets.CIFAR10(root='./data', train=False, download=True,
                               transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                             shuffle=False, num_workers=0)
    return trainloader, testloader


def binary_output(dataloader):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Use device: " + str(device))
    net = AlexNetPlusLatent(args.bits)
    net.load_state_dict(torch.load('./model/{}'.format(args.pretrained),
                                   map_location=device))
    net.to(device)
    net.eval()  # disable dropout so the codes are deterministic
    full_batch_output = torch.empty(0, device=device)
    full_batch_label = torch.empty(0, dtype=torch.long, device=device)
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs, _ = net(inputs)
            full_batch_output = torch.cat((full_batch_output, outputs), 0)
            full_batch_label = torch.cat((full_batch_label, targets), 0)
    # Binarize the sigmoid activations by rounding at 0.5.
    return torch.round(full_batch_output), full_batch_label


def evaluate(trn_binary, trn_label, tst_binary, tst_label):
    classes = np.max(tst_label) + 1
    # Sample 100 query images per class from the test set (fixed seeds for reproducibility).
    tst_sample_binary = []
    tst_sample_label = []
    for i in range(classes):
        idx = np.random.RandomState(seed=i).permutation(np.where(tst_label == i)[0])[:100]
        tst_sample_binary.append(tst_binary[idx])
        tst_sample_label.append(np.full(100, i))
    tst_sample_binary = np.concatenate(tst_sample_binary)
    tst_sample_label = np.concatenate(tst_sample_label)
    query_times = tst_sample_binary.shape[0]
    trainset_len = trn_binary.shape[0]
    AP = np.zeros(query_times)
    precision_radius = np.zeros(query_times)
    Ns = np.arange(1, trainset_len + 1)
    sum_tp = np.zeros(trainset_len)
    total_time_start = time.time()
    with tqdm(total=query_times, desc="Query") as pbar:
        for i in range(query_times):
            query_label = tst_sample_label[i]
            query_binary = tst_sample_binary[i, :]
            # Hamming distance: the count of differing bits (dividing by the
            # code length is unnecessary, since it does not change the ranking).
            query_result = np.count_nonzero(query_binary != trn_binary, axis=1)
            sort_indices = np.argsort(query_result)
            buffer_yes = np.equal(query_label, trn_label[sort_indices]).astype(int)
            P = np.cumsum(buffer_yes) / Ns
            # Precision within Hamming radius 2: precision at the last rank
            # whose distance is still <= 2.
            precision_radius[i] = P[np.where(np.sort(query_result) > 2)[0][0] - 1]
            AP[i] = np.sum(P * buffer_yes) / np.sum(buffer_yes)
            sum_tp = sum_tp + np.cumsum(buffer_yes)
            pbar.set_postfix({'Average Precision': '{0:1.5f}'.format(AP[i])})
            pbar.update(1)
    precision_at_k = sum_tp / Ns / query_times
    index = [k - 1 for k in [100, 200, 400, 600, 800, 1000]]
    print('precision at k:', precision_at_k[index])
    print('precision within Hamming radius 2:', np.mean(precision_radius))
    print('mAP:', np.mean(AP))
    print('Total query time:', time.time() - total_time_start)


if __name__ == "__main__":
    if os.path.exists('./result/train_binary') and os.path.exists('./result/train_label') and \
            os.path.exists('./result/test_binary') and os.path.exists('./result/test_label') and \
            args.pretrained == 0:
        train_binary = torch.load('./result/train_binary')
        train_label = torch.load('./result/train_label')
        test_binary = torch.load('./result/test_binary')
        test_label = torch.load('./result/test_label')
    else:
        trainloader, testloader = load_data()
        train_binary, train_label = binary_output(trainloader)
        test_binary, test_label = binary_output(testloader)
        if not os.path.isdir('result'):
            os.mkdir('result')
        torch.save(train_binary, './result/train_binary')
        torch.save(train_label, './result/train_label')
        torch.save(test_binary, './result/test_binary')
        torch.save(test_label, './result/test_label')

    train_binary = np.asarray(train_binary.cpu().numpy(), np.int32)
    train_label = train_label.cpu().numpy()
    test_binary = np.asarray(test_binary.cpu().numpy(), np.int32)
    test_label = test_label.cpu().numpy()

    evaluate(train_binary, train_label, test_binary, test_label)
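For reference, the `AP[i]` line above computes average precision from the cumulative precision `P` over the ranked training set. A toy check with made-up relevance flags (illustrative only, not part of the repo):

```python
import numpy as np

# Hypothetical ranked retrieval result: 1 = relevant, 0 = not relevant.
buffer_yes = np.array([1, 0, 1, 1, 0])
Ns = np.arange(1, len(buffer_yes) + 1)

# Precision at each cut-off, then averaged over the relevant positions.
P = np.cumsum(buffer_yes) / Ns            # [1.0, 0.5, 0.667, 0.75, 0.6]
AP = np.sum(P * buffer_yes) / np.sum(buffer_yes)
print(AP)  # (1.0 + 2/3 + 0.75) / 3 ≈ 0.806
```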
--------------------------------------------------------------------------------
/net.py:
--------------------------------------------------------------------------------
import os

import torch.nn as nn
from torchvision import models

# Cache the pretrained weights under ./models instead of the default torch hub directory.
os.environ['TORCH_HOME'] = 'models'
alexnet_model = models.alexnet(pretrained=True)


class AlexNetPlusLatent(nn.Module):
    """AlexNet with a sigmoid latent layer inserted between fc7 and the
    classifier, following the CVPRW'15 deep-hashing architecture."""

    def __init__(self, bits):
        super(AlexNetPlusLatent, self).__init__()
        self.bits = bits
        # Pretrained convolutional features.
        self.features = nn.Sequential(*list(alexnet_model.features.children()))
        # Pretrained fully connected layers, minus the final 1000-way classifier.
        self.remain = nn.Sequential(*list(alexnet_model.classifier.children())[:-1])
        # Latent layer: 4096 -> bits, squashed into (0, 1) so it can be rounded to a code.
        self.Linear1 = nn.Linear(4096, self.bits)
        self.sigmoid = nn.Sigmoid()
        # Classification head on top of the latent layer (CIFAR-10 has 10 classes).
        self.Linear2 = nn.Linear(self.bits, 10)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.remain(x)
        x = self.Linear1(x)
        features = self.sigmoid(x)
        result = self.Linear2(features)
        return features, result
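A minimal sketch of getting codes out of the network (the dummy batch is made up; instantiating the model downloads the pretrained AlexNet weights):

```python
import torch

from net import AlexNetPlusLatent

net = AlexNetPlusLatent(bits=48)
net.eval()  # dropout off, so the codes are deterministic

# A dummy batch of four 227x227 RGB images.
dummy = torch.randn(4, 3, 227, 227)
with torch.no_grad():
    latent, logits = net(dummy)

codes = torch.round(latent)  # threshold the sigmoid activations at 0.5
print(codes.shape)   # torch.Size([4, 48])
print(logits.shape)  # torch.Size([4, 10])
```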
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse
import os
import shutil

import torch
import torch.nn as nn
import torch.optim.lr_scheduler
from torchvision import datasets, transforms
from tqdm import tqdm

from net import AlexNetPlusLatent

parser = argparse.ArgumentParser(description='Deep Hashing')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--epoch', type=int, default=128, metavar='epoch',
                    help='number of training epochs (default: 128)')
parser.add_argument('--pretrained', type=str, default='', metavar='pretrained_model',
                    help='name of a saved model to evaluate (default: train from scratch)')
parser.add_argument('--bits', type=int, default=48, metavar='bts',
                    help='binary bits')
parser.add_argument('--path', type=str, default='model', metavar='P',
                    help='directory to save models in')
args = parser.parse_args()


def init_dataset():
    transform_train = transforms.Compose(
        [transforms.Resize(256),
         transforms.RandomCrop(227),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
    transform_test = transforms.Compose(
        [transforms.Resize(227),
         transforms.ToTensor(),
         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
    trainset = datasets.CIFAR10(root='./data', train=True, download=True,
                                transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                              shuffle=True, num_workers=0)

    testset = datasets.CIFAR10(root='./data', train=False, download=True,
                               transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                             shuffle=False, num_workers=0)
    return trainloader, testloader


def train(epoch_num):
    print('\nEpoch: %d' % epoch_num)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    with tqdm(total=len(trainloader), desc="Training") as pbar:
        for batch_idx, (inputs, targets) in enumerate(trainloader):
            inputs, targets = inputs.to(device), targets.to(device)
            _, outputs = net(inputs)
            loss = softmaxloss(outputs, targets)
            optimizer4nn.zero_grad()
            loss.backward()
            optimizer4nn.step()
            train_loss += loss.item()  # accumulate the already-computed loss
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += predicted.eq(targets.data).sum()
            pbar.set_postfix({'loss': '{0:1.5f}'.format(loss.item()),
                              'accuracy': '{:.2%}'.format(correct.item() / total)})
            pbar.update(1)
    return train_loss / (batch_idx + 1)


def test():
    net.eval()
    with torch.no_grad():
        test_loss = 0
        correct = 0
        total = 0
        with tqdm(total=len(testloader), desc="Testing") as pbar:
            for batch_idx, (inputs, targets) in enumerate(testloader):
                inputs, targets = inputs.to(device), targets.to(device)
                _, outputs = net(inputs)
                loss = softmaxloss(outputs, targets)
                test_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += targets.size(0)
                correct += predicted.eq(targets.data).sum()
                pbar.set_postfix({'loss': '{0:1.5f}'.format(loss.item()),
                                  'accuracy': '{:.2%}'.format(correct.item() / total)})
                pbar.update(1)
    acc = 100 * int(correct) / int(total)
    # Save only after the final epoch; the file is named after its test accuracy.
    if epoch == args.epoch:
        print('Saving')
        if not os.path.isdir(args.path):
            os.mkdir(args.path)
        torch.save(net.state_dict(), './{}/{}'.format(args.path, acc))


if __name__ == '__main__':
    torch.cuda.empty_cache()  # helps on Windows; a no-op when CUDA is unused
    trainloader, testloader = init_dataset()
    net = AlexNetPlusLatent(args.bits)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Use device: " + str(device))
    net.to(device)
    softmaxloss = nn.CrossEntropyLoss().to(device)
    optimizer4nn = torch.optim.SGD(net.parameters(), lr=args.lr,
                                   momentum=args.momentum, weight_decay=0.0005)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer4nn,
                                                     milestones=[args.epoch], gamma=0.1)
    start_epoch = 1
    epoch = 0  # lets test() skip saving when only evaluating a pretrained model
    if args.pretrained:
        net.load_state_dict(torch.load('./{}/{}'.format(args.path, args.pretrained),
                                       map_location=device))
        test()
    else:
        if os.path.isdir(args.path):
            shutil.rmtree(args.path)  # start from a clean model directory
        for epoch in range(start_epoch, start_epoch + args.epoch):
            train(epoch)
            test()
            scheduler.step()
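Note on the schedule above: `MultiStepLR` multiplies the learning rate by `gamma` once the epoch count reaches each milestone, and the only milestone here is `args.epoch` itself, so the decay only triggers at the very last epoch. A standalone toy run with a hypothetical milestone of 8 shows the mechanics:

```python
import torch

# Toy illustration of MultiStepLR; the values here are made up.
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=0.01, momentum=0.9)
sched = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[8], gamma=0.1)

for epoch in range(1, 11):
    opt.step()   # in real training this follows loss.backward()
    sched.step()
    print(epoch, opt.param_groups[0]['lr'])
# prints 0.01 for epochs 1-7, then 0.001 from epoch 8 onward
```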
--------------------------------------------------------------------------------