├── README.md
├── model.py
├── utils.py
└── main.py

/README.md:
--------------------------------------------------------------------------------
# DSH-pytorch
A PyTorch implementation of the paper [Deep Supervised Hashing for Fast Image Retrieval](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Liu_Deep_Supervised_Hashing_CVPR_2016_paper.pdf).

Thanks to [this TensorFlow implementation](https://github.com/yg33717/DSH_tensorflow).

## Result
I have only tested it once, at the time the code was pushed, and it seems able to reproduce the results reported in the paper.

More results will be added.

## How to Run
The code is written with readability in mind. You can run it with
```
python main.py
```
The command-line interface has a nice-looking help message:
```
python main.py -h
```
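For reference, here is a fuller invocation spelling out the options defined in `main.py` (the values shown are simply the defaults, and the dataset path is only an example):
```
python main.py --cifar ../dataset/cifar --batchSize 256 --binary_bits 12 --alpha 0.01 --lr 0.001 --niter 500 --outf checkpoints
```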
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.init as init


class DSH(nn.Module):
    def __init__(self, num_binary):
        super().__init__()
        self.conv = nn.Sequential(
            # 3 x 32 x 32 -> 32 x 15 x 15
            nn.Conv2d(3, 32, kernel_size=5, padding=2),  # same padding
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            # 32 x 15 x 15 -> 32 x 7 x 7
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(kernel_size=3, stride=2),

            # 32 x 7 x 7 -> 64 x 3 x 3
            nn.Conv2d(32, 64, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(kernel_size=3, stride=2),
        )
        self.fc = nn.Sequential(
            nn.Linear(64 * 3 * 3, 500),
            nn.ReLU(inplace=True),

            nn.Linear(500, num_binary)
        )

        # Xavier initialization for all conv and linear layers
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                init.xavier_normal(m.weight.data)
                m.bias.data.fill_(0)

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import time
from functools import wraps
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.datasets.cifar import CIFAR10
import torchvision.transforms as transforms


def init_cifar_dataloader(root, batchSize):
    """build CIFAR-10 train and test DataLoaders"""
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])

    train_loader = DataLoader(CIFAR10(root, train=True, download=True, transform=transform_train),
                              batch_size=batchSize, shuffle=True, num_workers=4, pin_memory=True)
    print(f'train set: {len(train_loader.dataset)}')
    test_loader = DataLoader(CIFAR10(root, train=False, download=True, transform=transform_test),
                             batch_size=batchSize * 8, shuffle=False, num_workers=4, pin_memory=True)
    print(f'val set: {len(test_loader.dataset)}')

    return train_loader, test_loader


def timing(f):
    """print the wall-clock time used by function f"""

    @wraps(f)
    def wrapper(*args, **kwargs):
        time_start = time.time()
        ret = f(*args, **kwargs)
        print(f'total time = {time.time() - time_start:.4f}')
        return ret

    return wrapper


def compute_result(dataloader, net):
    """forward all batches through net and return sign-binarized codes with their labels"""
    bs, clses = [], []
    net.eval()
    for img, cls in dataloader:
        clses.append(cls)
        bs.append(net(Variable(img.cuda(), volatile=True)).data.cpu())
    return torch.sign(torch.cat(bs)), torch.cat(clses)


@timing
def compute_mAP(trn_binary, tst_binary, trn_label, tst_label):
    """
    compute mAP, using each test image as a query against the training set
    adapted from https://github.com/flyingpot/pytorch_deephash
    """
    # .long() is not in-place, so rebind the casted tensors
    trn_binary, tst_binary, trn_label, tst_label = \
        (x.long() for x in (trn_binary, tst_binary, trn_label, tst_label))

    AP = []
    Ns = torch.arange(1, trn_binary.size(0) + 1)
    for i in range(tst_binary.size(0)):
        query_label, query_binary = tst_label[i], tst_binary[i]
        # rank training samples by Hamming distance to the query
        _, query_result = torch.sum((query_binary != trn_binary).long(), dim=1).sort()
        correct = (query_label == trn_label[query_result]).float()
        P = torch.cumsum(correct, dim=0) / Ns
        AP.append(torch.sum(P * correct) / torch.sum(correct))
    mAP = torch.mean(torch.Tensor(AP))
    return mAP


def choose_gpu(i_gpu):
    """choose current CUDA device"""
    torch.cuda.set_device(i_gpu)
    cudnn.benchmark = True


def feed_random_seed(seed=np.random.randint(1, 10000)):
    """seed numpy and torch RNGs (the default seed is drawn once, at import time)"""
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
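A minimal retrieval sketch built on `model.py` and `utils.py`. This is illustrative only: it assumes a CUDA device and the pre-0.4 style PyTorch API that the rest of the code uses, and the checkpoint path below is hypothetical.

```python
import torch
from model import DSH
from utils import init_cifar_dataloader, compute_result

# Hypothetical checkpoint path and settings; adjust to your own run.
net = DSH(num_binary=12).cuda()
net.load_state_dict(torch.load('checkpoints/050.pth',
                               map_location=lambda storage, location: storage))

train_loader, test_loader = init_cifar_dataloader('../dataset/cifar', 256)
trn_binary, trn_label = compute_result(train_loader, net)  # +/-1 codes for the gallery
tst_binary, tst_label = compute_result(test_loader, net)   # +/-1 codes for the queries

# Rank the training set for the first test query by Hamming distance.
query_binary, query_label = tst_binary[0], tst_label[0]
hamming = (query_binary != trn_binary).long().sum(dim=1)
_, order = hamming.sort()
print('query label:', query_label)
print('labels of the 10 nearest training images:', trn_label[order[:10]])
```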
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os
import argparse
import torch
import torch.optim as optim
from tensorboardX import SummaryWriter

from model import *
from utils import *


def hashing_loss(b, cls, m, alpha):
    """
    compute the DSH pairwise loss, automatically considering all n^2 pairs in the batch

    b: relaxed binary codes, shape (n, num_binary)
    cls: class labels, shape (n,)
    m: margin for dissimilar pairs
    alpha: weight of the quantization regularizer
    """
    # y = 1 for dissimilar pairs, 0 for similar pairs
    y = (cls.unsqueeze(0) != cls.unsqueeze(1)).float().view(-1)
    # squared Euclidean distance between every pair of codes
    dist = ((b.unsqueeze(0) - b.unsqueeze(1)) ** 2).sum(dim=2).view(-1)
    # pull similar pairs together; push dissimilar pairs at least m apart
    loss = (1 - y) / 2 * dist + y / 2 * (m - dist).clamp(min=0)

    # quantization regularizer encouraging outputs close to +/-1
    loss = loss.mean() + alpha * (b.abs() - 1).abs().sum(dim=1).mean() * 2

    return loss


def train(epoch, dataloader, net, optimizer, m, alpha):
    accum_loss = 0
    net.train()
    for i, (img, cls) in enumerate(dataloader):
        img, cls = [Variable(x.cuda()) for x in (img, cls)]

        net.zero_grad()
        b = net(img)
        loss = hashing_loss(b, cls, m, alpha)

        loss.backward()
        optimizer.step()
        accum_loss += loss.data[0]

        print(f'[{epoch}][{i}/{len(dataloader)}] loss: {loss.data[0]:.4f}')
    return accum_loss / len(dataloader)


def test(epoch, dataloader, net, m, alpha):
    accum_loss = 0
    net.eval()
    for img, cls in dataloader:
        img, cls = [Variable(x.cuda(), volatile=True) for x in (img, cls)]

        b = net(img)
        loss = hashing_loss(b, cls, m, alpha)
        accum_loss += loss.data[0]

    accum_loss /= len(dataloader)
    print(f'[{epoch}] val loss: {accum_loss:.4f}')
    return accum_loss


def main():
    parser = argparse.ArgumentParser(description='train DSH')
    parser.add_argument('--cifar', default='../dataset/cifar', help='path to cifar')
    parser.add_argument('--weights', default='', help='path to weights (to continue training)')
    parser.add_argument('--outf', default='checkpoints', help='folder to output model checkpoints')
    parser.add_argument('--checkpoint', type=int, default=50, help='save a checkpoint and evaluate mAP every N epochs')

    parser.add_argument('--batchSize', type=int, default=256, help='input batch size')
    parser.add_argument('--ngpu', type=int, default=0, help='index of the GPU to use')

    parser.add_argument('--binary_bits', type=int, default=12, help='length of the binary hash code')
    parser.add_argument('--alpha', type=float, default=0.01, help='weight of the quantization regularizer')

    parser.add_argument('--niter', type=int, default=500, help='number of epochs to train for')
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate')

    opt = parser.parse_args()
    print(opt)

    os.makedirs(opt.outf, exist_ok=True)
    choose_gpu(opt.ngpu)
    feed_random_seed()
    train_loader, test_loader = init_cifar_dataloader(opt.cifar, opt.batchSize)
    logger = SummaryWriter()

    # setup net
    net = DSH(opt.binary_bits)
    resume_epoch = 0
    print(net)
    if opt.weights:
        print(f'loading weights from {opt.weights}')
        resume_epoch = int(os.path.basename(opt.weights)[:-4])  # checkpoints are named <epoch>.pth
        net.load_state_dict(torch.load(opt.weights, map_location=lambda storage, location: storage))

    net.cuda()

    # setup optimizer
    optimizer = optim.Adam(net.parameters(), lr=opt.lr, weight_decay=0.004)

    for epoch in range(resume_epoch, opt.niter):
        # the margin m is set to twice the code length
        train_loss = train(epoch, train_loader, net, optimizer, 2 * opt.binary_bits, opt.alpha)
        logger.add_scalar('train_loss', train_loss, epoch)

        test_loss = test(epoch, test_loader, net, 2 * opt.binary_bits, opt.alpha)
        logger.add_scalar('test_loss', test_loss, epoch)

        if epoch % opt.checkpoint == 0:
            # compute mAP by querying the training set with test-set images
            trn_binary, trn_label = compute_result(train_loader, net)
            tst_binary, tst_label = compute_result(test_loader, net)
            mAP = compute_mAP(trn_binary, tst_binary, trn_label, tst_label)
            print(f'[{epoch}] retrieval mAP: {mAP:.4f}')
            logger.add_scalar('retrieval_mAP', mAP, epoch)

            # save checkpoints
            torch.save(net.state_dict(), os.path.join(opt.outf, f'{epoch:03d}.pth'))


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
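To make the pairwise construction in `hashing_loss` concrete, here is a small self-contained sanity check. The 2-bit codes and labels below are made up purely for illustration; with two samples the broadcasted difference builds all 2 x 2 = 4 pairs, including the zero-distance self-pairs.

```python
import torch

# Two hypothetical 2-bit relaxed codes from the same class.
b = torch.tensor([[1., -1.],
                  [1.,  1.]])
cls = torch.tensor([0, 0])

# y = 1 for dissimilar pairs, 0 for similar pairs (here: all pairs are similar).
y = (cls.unsqueeze(0) != cls.unsqueeze(1)).float().view(-1)
# Squared Euclidean distance of every pair: [0, 4, 4, 0].
dist = ((b.unsqueeze(0) - b.unsqueeze(1)) ** 2).sum(dim=2).view(-1)

m = 2 * 2  # margin = 2 * code length, as used in main()
loss = (1 - y) / 2 * dist + y / 2 * (m - dist).clamp(min=0)
print(loss)  # tensor([0., 2., 2., 0.]) -- similar pairs are penalized for being far apart
```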