├── README.md
├── metric.py
├── model.py
├── result
│   ├── structure.png
│   └── vis.png
├── test.py
├── train.py
├── utils.py
└── vis.py
/README.md: -------------------------------------------------------------------------------- 1 | # ACNet 2 | 3 | A PyTorch implementation of ACNet based on the TCSVT 2023 paper 4 | [ACNet: Approaching-and-Centralizing Network for Zero-Shot Sketch-Based Image Retrieval](https://ieeexplore.ieee.org/document/10052737). 5 | 6 | ![Network Architecture](result/structure.png) 7 | 8 | ## Requirements 9 | 10 | - [Anaconda](https://www.anaconda.com/download/) 11 | - [PyTorch](https://pytorch.org) 12 | 13 | ``` 14 | conda install pytorch=1.10.0 torchvision cudatoolkit=11.3 -c pytorch 15 | ``` 16 | 17 | - [Timm](https://rwightman.github.io/pytorch-image-models/) 18 | 19 | ``` 20 | pip install timm 21 | ``` 22 | 23 | - [PyTorch Metric Learning](https://kevinmusgrave.github.io/pytorch-metric-learning/) 24 | 25 | ``` 26 | conda install pytorch-metric-learning -c metric-learning -c pytorch 27 | ``` 28 | 29 | ## Dataset 30 | 31 | The [Sketchy Extended](http://sketchy.eye.gatech.edu) and 32 | [TU-Berlin Extended](http://cybertron.cg.tu-berlin.de/eitz/projects/classifysketch/) datasets are used in this repo. You 33 | can download them from the official websites, or from 34 | [Google Drive](https://drive.google.com/drive/folders/1lce41k7cGNUOwzt-eswCeahDLWG6Cdk0?usp=sharing). The expected data directory 35 | structure is as follows: 36 | 37 | ``` 38 | ├──sketchy 39 | ├── train 40 | ├── sketch 41 | ├── airplane 42 | ├── n02691156_58-1.jpg 43 | └── ... 44 | ... 45 | ├── photo 46 | same structure as sketch 47 | ├── val 48 | same structure as train 49 | ... 50 | ├──tuberlin 51 | same structure as sketchy 52 | ... 53 | ``` 54 | 55 | ## Usage 56 | 57 | ### Train Model 58 | 59 | ``` 60 | python train.py --data_name tuberlin 61 | optional arguments: 62 | --data_root Datasets root path [default value is '/data'] 63 | --data_name Dataset name [default value is 'sketchy'](choices=['sketchy', 'tuberlin']) 64 | --backbone_type Backbone type [default value is 'resnet50'](choices=['resnet50', 'vgg16']) 65 | --emb_dim Embedding dim [default value is 512] 66 | --batch_size Number of images in each mini-batch [default value is 64] 67 | --epochs Number of epochs over the model to train [default value is 10] 68 | --warmup Number of warmups over the extractor to train [default value is 1] 69 | --save_root Result saved root path [default value is 'result'] 70 | ``` 71 | 72 | ### Test Model 73 | 74 | ``` 75 | python test.py --num 8 76 | optional arguments: 77 | --data_root Datasets root path [default value is '/data'] 78 | --query_name Query image name [default value is '/data/sketchy/val/sketch/cow/n01887787_591-14.jpg'] 79 | --data_base Queried database [default value is 'result/sketchy_resnet50_512_vectors.pth'] 80 | --num Retrieval number [default value is 5] 81 | --save_root Result saved root path [default value is 'result'] 82 | ``` 83 | 84 | ## Benchmarks 85 | 86 | The models are trained on one NVIDIA GTX TITAN (12G) GPU. `Adam` is used to optimize the models, with `lr` set to `1e-5` 87 | for the backbone, `1e-3` for the generator and `1e-4` for the discriminator. All other hyper-parameters are left at their default values.
| Backbone | Dim  | Sketchy Ext. mAP@200 | Sketchy Ext. mAP@all | Sketchy Ext. P@100 | Sketchy Ext. P@200 | TU-Berlin Ext. mAP@200 | TU-Berlin Ext. mAP@all | TU-Berlin Ext. P@100 | TU-Berlin Ext. P@200 | Download |
| -------- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | -------- |
| VGG16    | 64   | 32.6 | 38.0 | 48.7 | 44.7 | 39.8 | 37.1 | 50.6 | 48.0 | MEGA     |
| VGG16    | 512  | 38.3 | 42.2 | 53.3 | 49.3 | 47.2 | 43.9 | 58.1 | 55.3 | MEGA     |
| VGG16    | 4096 | 40.0 | 43.2 | 54.6 | 50.8 | 51.7 | 47.9 | 62.3 | 59.3 | MEGA     |
| ResNet50 | 64   | 43.0 | 46.0 | 56.8 | 52.7 | 47.5 | 44.9 | 57.2 | 54.9 | MEGA     |
| ResNet50 | 512  | 51.7 | 55.9 | 64.3 | 60.8 | 57.7 | 57.7 | 65.8 | 64.4 | MEGA     |
| ResNet50 | 4096 | 51.1 | 55.7 | 63.8 | 60.0 | 57.3 | 58.6 | 64.6 | 63.5 | MEGA     |
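The metrics above are the ones produced by `metric.py`: every validation sketch is used as a query against all gallery photos, ranked by cosine similarity, and the per-query scores are averaged and reported as percentages. The snippet below is a minimal sketch of P@k and average precision for a single query, with illustrative toy values; for brevity it uses plain (non-interpolated) AP, whereas `eval_ap` in `metric.py` applies VOC-style interpolation via `voc_ap`.

```
import numpy as np

# Toy example (illustrative values only): similarity scores of one sketch
# query against six gallery photos, and whether each photo shares the
# query's class label.
scores = np.array([0.9, 0.2, 0.75, 0.4, 0.85, 0.1])
relevant = np.array([True, False, False, True, True, False])

rank = np.argsort(-scores)   # gallery indices sorted by descending similarity
hits = relevant[rank]        # relevance flags in ranked order

# P@k: fraction of the top-k retrieved photos that are relevant
# (this is what eval_precision computes).
k = 3
p_at_k = hits[:k].sum() / k

# Non-interpolated AP: precision at each rank where a relevant photo is
# retrieved, averaged over the number of relevant photos. For the @200
# variants the ranking is truncated to the top 200 before this step.
cum_tp = np.cumsum(hits)
prec_at_rank = cum_tp / (np.arange(len(hits)) + 1)
ap = (prec_at_rank * hits).sum() / relevant.sum()

print('P@{} = {:.3f}, AP = {:.3f}'.format(k, p_at_k, ap))  # P@3 = 0.667, AP = 0.917
```

mAP@200 and mAP@all in the table are the mean of this per-query AP over all sketch queries (with and without truncation to the top 200 results), and P@100 / P@200 are the mean per-query precision at the corresponding cut-offs.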
190 | 191 | ## Results 192 | 193 | ![vis](result/vis.png) 194 | 195 | ## Citing ACNet 196 | 197 | If you find ACNet helpful, please consider citing: 198 | ``` 199 | @article{ren2023acnet, 200 | title={ACNet: Approaching-and-Centralizing Network for Zero-Shot Sketch-Based Image Retrieval}, 201 | author={Ren, Hao and Zheng, Ziqiang and Wu, Yang and Lu, Hong and Yang, Yang and Shan, Ying and Yeung, Sai-Kit}, 202 | journal={IEEE Transactions on Circuits and Systems for Video Technology}, 203 | year={2023}, 204 | publisher={IEEE} 205 | } 206 | ``` 207 | -------------------------------------------------------------------------------- /metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial.distance import cdist 3 | 4 | 5 | def sake_metric(predicted_features_gallery, gt_labels_gallery, predicted_features_query, gt_labels_query, k=None): 6 | if k is None: 7 | k = {'precision': 100, 'map': predicted_features_gallery.shape[0]} 8 | if k['precision'] is None: 9 | k['precision'] = 100 10 | if k['map'] is None: 11 | k['map'] = predicted_features_gallery.shape[0] 12 | 13 | scores = -cdist(predicted_features_query, predicted_features_gallery, metric='cosine') 14 | gt_labels_query = gt_labels_query.flatten() 15 | gt_labels_gallery = gt_labels_gallery.flatten() 16 | aps = map_sake(gt_labels_gallery, predicted_features_query, gt_labels_query, scores, k=k['map']) 17 | prec = prec_sake(gt_labels_gallery, predicted_features_query, gt_labels_query, scores, k=k['precision']) 18 | return sum(aps) / len(aps), prec 19 | 20 | 21 | def map_sake(gt_labels_gallery, predicted_features_query, gt_labels_query, scores, k=None): 22 | mAP_ls = [[] for _ in range(len(np.unique(gt_labels_query)))] 23 | mean_mAP = [] 24 | for fi in range(predicted_features_query.shape[0]): 25 | mapi = eval_ap(gt_labels_query[fi], scores[fi], gt_labels_gallery, top=k) 26 | mAP_ls[gt_labels_query[fi]].append(mapi) 27 | mean_mAP.append(mapi) 28 | return mean_mAP 29 | 30 | 31 | def prec_sake(gt_labels_gallery, predicted_features_query, gt_labels_query, scores, k=None): 32 | # compute precision for two modalities 33 | prec_ls = [[] for _ in range(len(np.unique(gt_labels_query)))] 34 | mean_prec = [] 35 | for fi in range(predicted_features_query.shape[0]): 36 | prec = eval_precision(gt_labels_query[fi], scores[fi], gt_labels_gallery, top=k) 37 | prec_ls[gt_labels_query[fi]].append(prec) 38 | mean_prec.append(prec) 39 | return np.nanmean(mean_prec) 40 | 41 | 42 | def eval_ap(inst_id, scores, gt_labels, top=None): 43 | pos_flag = gt_labels == inst_id 44 | tot = scores.shape[0] 45 | tot_pos = np.sum(pos_flag) 46 | 47 | sort_idx = np.argsort(-scores) 48 | tp = pos_flag[sort_idx] 49 | fp = np.logical_not(tp) 50 | 51 | if top is not None: 52 | top = min(top, tot) 53 | tp = tp[:top] 54 | fp = fp[:top] 55 | tot_pos = min(top, tot_pos) 56 | 57 | fp = np.cumsum(fp) 58 | tp = np.cumsum(tp) 59 | try: 60 | rec = tp / tot_pos 61 | prec = tp / (tp + fp) 62 | except: 63 | return np.nan 64 | 65 | ap = voc_ap(rec, prec) 66 | return ap 67 | 68 | 69 | def voc_ap(rec, prec): 70 | mrec = np.append(0, rec) 71 | mrec = np.append(mrec, 1) 72 | 73 | mpre = np.append(0, prec) 74 | mpre = np.append(mpre, 0) 75 | 76 | for ii in range(len(mpre) - 2, -1, -1): 77 | mpre[ii] = max(mpre[ii], mpre[ii + 1]) 78 | 79 | msk = [i != j for i, j in zip(mrec[1:], mrec[0:-1])] 80 | ap = np.sum((mrec[1:][msk] - mrec[0:-1][msk]) * mpre[1:][msk]) 81 | return ap 82 | 83 | 84 | def eval_precision(inst_id, scores, 
gt_labels, top=100): 85 | pos_flag = gt_labels == inst_id 86 | tot = scores.shape[0] 87 | top = min(top, tot) 88 | sort_idx = np.argsort(-scores) 89 | return np.sum(pos_flag[sort_idx][:top]) / top 90 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import timm 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_channels): 8 | super(ResidualBlock, self).__init__() 9 | 10 | self.conv = nn.Sequential(nn.Conv2d(in_channels, in_channels, 3, padding=1, padding_mode='reflect'), 11 | nn.InstanceNorm2d(in_channels), nn.ReLU(inplace=True), 12 | nn.Conv2d(in_channels, in_channels, 3, padding=1, padding_mode='reflect'), 13 | nn.InstanceNorm2d(in_channels)) 14 | 15 | def forward(self, x): 16 | return x + self.conv(x) 17 | 18 | 19 | class Generator(nn.Module): 20 | def __init__(self, in_channels=64, num_block=9): 21 | super(Generator, self).__init__() 22 | 23 | # in conv 24 | self.in_conv = nn.Sequential(nn.Conv2d(3, in_channels, 7, padding=3, padding_mode='reflect'), 25 | nn.InstanceNorm2d(in_channels), nn.ReLU(inplace=True)) 26 | 27 | # down sample 28 | down_sample = [] 29 | for _ in range(2): 30 | out_channels = in_channels * 2 31 | down_sample += [nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1), 32 | nn.InstanceNorm2d(out_channels), nn.ReLU(inplace=True)] 33 | in_channels = out_channels 34 | self.down_sample = nn.Sequential(*down_sample) 35 | 36 | # conv blocks 37 | self.convs = nn.Sequential(*[ResidualBlock(in_channels) for _ in range(num_block)]) 38 | 39 | # up sample 40 | up_sample = [] 41 | for _ in range(2): 42 | out_channels = in_channels // 2 43 | up_sample += [nn.ConvTranspose2d(in_channels, out_channels, 3, stride=2, padding=1, output_padding=1), 44 | nn.InstanceNorm2d(out_channels), nn.ReLU(inplace=True)] 45 | in_channels = out_channels 46 | self.up_sample = nn.Sequential(*up_sample) 47 | 48 | # out conv 49 | self.out_conv = nn.Sequential(nn.Conv2d(in_channels, 3, 7, padding=3, padding_mode='reflect'), nn.Tanh()) 50 | 51 | def forward(self, x): 52 | x = self.in_conv(x) 53 | x = self.down_sample(x) 54 | x = self.convs(x) 55 | x = self.up_sample(x) 56 | out = self.out_conv(x) 57 | return out 58 | 59 | 60 | class Discriminator(nn.Module): 61 | def __init__(self, in_channels=64): 62 | super(Discriminator, self).__init__() 63 | 64 | self.conv1 = nn.Sequential(nn.Conv2d(3, in_channels, 4, stride=2, padding=1), nn.LeakyReLU(0.2, inplace=True)) 65 | 66 | self.conv2 = nn.Sequential(nn.Conv2d(in_channels, in_channels * 2, 4, stride=2, padding=1), 67 | nn.InstanceNorm2d(in_channels * 2), nn.LeakyReLU(0.2, inplace=True)) 68 | 69 | self.conv3 = nn.Sequential(nn.Conv2d(in_channels * 2, in_channels * 4, 4, stride=2, padding=1), 70 | nn.InstanceNorm2d(in_channels * 4), nn.LeakyReLU(0.2, inplace=True)) 71 | 72 | self.conv4 = nn.Sequential(nn.Conv2d(in_channels * 4, in_channels * 8, 4, padding=1), 73 | nn.InstanceNorm2d(in_channels * 8), nn.LeakyReLU(0.2, inplace=True)) 74 | 75 | self.conv5 = nn.Conv2d(in_channels * 8, 1, 4, padding=1) 76 | 77 | def forward(self, x): 78 | x = self.conv1(x) 79 | x = self.conv2(x) 80 | x = self.conv3(x) 81 | x = self.conv4(x) 82 | out = self.conv5(x) 83 | return out 84 | 85 | 86 | class Extractor(nn.Module): 87 | def __init__(self, backbone_type, emb_dim): 88 | super(Extractor, self).__init__() 89 | 90 | # backbone 91 | model_name = 'resnet50' if 
backbone_type == 'resnet50' else 'vgg16' 92 | self.backbone = timm.create_model(model_name, pretrained=True, num_classes=emb_dim, global_pool='max') 93 | 94 | def forward(self, x): 95 | x = self.backbone(x) 96 | out = F.normalize(x, dim=-1) 97 | return out 98 | 99 | 100 | def set_bn_eval(m): 101 | classname = m.__class__.__name__ 102 | if classname.find('BatchNorm2d') != -1: 103 | m.eval() -------------------------------------------------------------------------------- /result/structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leftthomas/ACNet/60387b1c0429282840fba0fbf76f372b7187f02b/result/structure.png -------------------------------------------------------------------------------- /result/vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leftthomas/ACNet/60387b1c0429282840fba0fbf76f372b7187f02b/result/vis.png -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | 5 | import torch 6 | from PIL import Image, ImageDraw 7 | 8 | from utils import DomainDataset 9 | 10 | if __name__ == '__main__': 11 | parser = argparse.ArgumentParser(description='Test Model') 12 | parser.add_argument('--data_root', default='/data', type=str, help='Datasets root path') 13 | parser.add_argument('--query_name', default='/data/sketchy/val/sketch/cow/n01887787_591-14.jpg', type=str, 14 | help='Query image name') 15 | parser.add_argument('--data_base', default='result/sketchy_resnet50_512_vectors.pth', type=str, 16 | help='Queried database') 17 | parser.add_argument('--num', default=5, type=int, help='Retrieval number') 18 | parser.add_argument('--save_root', default='result', type=str, help='Result saved root path') 19 | 20 | opt = parser.parse_args() 21 | 22 | data_root, query_name, data_base, retrieval_num = opt.data_root, opt.query_name, opt.data_base, opt.num 23 | save_root, data_name = opt.save_root, data_base.split('/')[-1].split('_')[0] 24 | 25 | vectors = torch.load(data_base) 26 | val_data = DomainDataset(data_root, data_name, split='val') 27 | 28 | if query_name not in val_data.images: 29 | raise FileNotFoundError('{} not found'.format(query_name)) 30 | query_index = val_data.images.index(query_name) 31 | query_image = Image.open(query_name).resize((224, 224), resample=Image.BILINEAR) 32 | query_label = val_data.labels[query_index] 33 | query_feature = vectors[query_index] 34 | 35 | gallery_images, gallery_labels = [], [] 36 | for i, domain in enumerate(val_data.domains): 37 | if domain == 0: 38 | gallery_images.append(val_data.images[i]) 39 | gallery_labels.append(val_data.labels[i]) 40 | gallery_features = vectors[torch.tensor(val_data.domains) == 0] 41 | 42 | sim_matrix = query_feature.unsqueeze(0).mm(gallery_features.t()).squeeze() 43 | idx = sim_matrix.topk(k=retrieval_num, dim=-1)[1] 44 | 45 | result_path = '{}/{}'.format(save_root, query_name.split('/')[-1].split('.')[0]) 46 | if os.path.exists(result_path): 47 | shutil.rmtree(result_path) 48 | os.mkdir(result_path) 49 | query_image.save('{}/query.jpg'.format(result_path)) 50 | for num, index in enumerate(idx): 51 | retrieval_image = Image.open(gallery_images[index.item()]).resize((224, 224), resample=Image.BILINEAR) 52 | draw = ImageDraw.Draw(retrieval_image) 53 | retrieval_label = gallery_labels[index.item()] 54 | 
retrieval_status = retrieval_label == query_label 55 | retrieval_sim = sim_matrix[index.item()].item() 56 | if retrieval_status: 57 | draw.rectangle((0, 0, 223, 223), outline='green', width=8) 58 | else: 59 | draw.rectangle((0, 0, 223, 223), outline='red', width=8) 60 | retrieval_image.save('{}/retrieval_{}_{}.jpg'.format(result_path, num + 1, '%.4f' % retrieval_sim)) 61 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | from pytorch_metric_learning.losses import NormalizedSoftmaxLoss 9 | from torch import nn 10 | from torch.backends import cudnn 11 | from torch.optim import Adam 12 | from torch.utils.data.dataloader import DataLoader 13 | from tqdm import tqdm 14 | 15 | from model import Extractor, Discriminator, Generator, set_bn_eval 16 | from utils import DomainDataset, compute_metric 17 | 18 | # for reproducibility 19 | random.seed(1) 20 | np.random.seed(1) 21 | torch.manual_seed(1) 22 | cudnn.deterministic = True 23 | cudnn.benchmark = False 24 | 25 | 26 | # train for one epoch 27 | def train(backbone, data_loader): 28 | backbone.train() 29 | # fix bn on backbone 30 | backbone.apply(set_bn_eval) 31 | generator.train() 32 | discriminator.train() 33 | total_extractor_loss, total_generator_loss, total_identity_loss, total_discriminator_loss = 0.0, 0.0, 0.0, 0.0 34 | total_num, train_bar = 0, tqdm(data_loader, dynamic_ncols=True) 35 | for sketch, photo, label in train_bar: 36 | sketch, photo, label = sketch.cuda(), photo.cuda(), label.cuda() 37 | 38 | optimizer_generator.zero_grad() 39 | optimizer_extractor.zero_grad() 40 | 41 | # generator # 42 | fake = generator(sketch) 43 | pred_fake = discriminator(fake) 44 | 45 | # generator loss 46 | target_fake = torch.ones(pred_fake.size(), device=pred_fake.device) 47 | gg_loss = adversarial_criterion(pred_fake, target_fake) 48 | total_generator_loss += gg_loss.item() * sketch.size(0) 49 | # identity loss 50 | ii_loss = identity_criterion(generator(photo), photo) 51 | total_identity_loss += ii_loss.item() * sketch.size(0) 52 | 53 | # extractor # 54 | sketch_proj = backbone(sketch) 55 | photo_proj = backbone(photo) 56 | fake_proj = backbone(fake) 57 | 58 | # extractor loss 59 | class_loss = (class_criterion(sketch_proj, label) + class_criterion(photo_proj, label) + 60 | class_criterion(fake_proj, label)) / 3 61 | total_extractor_loss += class_loss.item() * sketch.size(0) 62 | 63 | (gg_loss + 0.1 * ii_loss + 10 * class_loss).backward() 64 | 65 | optimizer_generator.step() 66 | optimizer_extractor.step() 67 | 68 | # discriminator loss # 69 | optimizer_discriminator.zero_grad() 70 | pred_photo = discriminator(photo) 71 | target_photo = torch.ones(pred_photo.size(), device=pred_photo.device) 72 | pred_fake = discriminator(fake.detach()) 73 | target_fake = torch.zeros(pred_fake.size(), device=pred_fake.device) 74 | adversarial_loss = (adversarial_criterion(pred_photo, target_photo) + 75 | adversarial_criterion(pred_fake, target_fake)) / 2 76 | total_discriminator_loss += adversarial_loss.item() * sketch.size(0) 77 | 78 | adversarial_loss.backward() 79 | optimizer_discriminator.step() 80 | 81 | total_num += sketch.size(0) 82 | 83 | e_loss = total_extractor_loss / total_num 84 | g_loss = total_generator_loss / total_num 85 | i_loss = total_identity_loss / total_num 86 | d_loss = total_discriminator_loss / 
total_num 87 | train_bar.set_description('Train Epoch: [{}/{}] E-Loss: {:.4f} G-Loss: {:.4f} I-Loss: {:.4f} D-Loss: {:.4f}' 88 | .format(epoch, epochs, e_loss, g_loss, i_loss, d_loss)) 89 | 90 | return e_loss, g_loss, i_loss, d_loss 91 | 92 | 93 | # val for one epoch 94 | def val(backbone, encoder, data_loader): 95 | backbone.eval() 96 | encoder.eval() 97 | vectors, domains, labels = [], [], [] 98 | with torch.no_grad(): 99 | for img, domain, label in tqdm(data_loader, desc='Feature extracting', dynamic_ncols=True): 100 | img = img.cuda() 101 | photo = img[domain == 0] 102 | sketch = img[domain == 1] 103 | photo_emb = backbone(photo) 104 | sketch_emb = backbone(encoder(sketch)) 105 | emb = torch.cat((photo_emb, sketch_emb), dim=0) 106 | vectors.append(emb.cpu()) 107 | label = torch.cat((label[domain == 0], label[domain == 1]), dim=0) 108 | labels.append(label) 109 | domain = torch.cat((domain[domain == 0], domain[domain == 1]), dim=0) 110 | domains.append(domain) 111 | vectors = torch.cat(vectors, dim=0) 112 | domains = torch.cat(domains, dim=0) 113 | labels = torch.cat(labels, dim=0) 114 | acc = compute_metric(vectors, domains, labels) 115 | results['P@100'].append(acc['P@100'] * 100) 116 | results['P@200'].append(acc['P@200'] * 100) 117 | results['mAP@200'].append(acc['mAP@200'] * 100) 118 | results['mAP@all'].append(acc['mAP@all'] * 100) 119 | print('Val Epoch: [{}/{}] | P@100:{:.1f}% | P@200:{:.1f}% | mAP@200:{:.1f}% | mAP@all:{:.1f}%' 120 | .format(epoch, epochs, acc['P@100'] * 100, acc['P@200'] * 100, acc['mAP@200'] * 100, 121 | acc['mAP@all'] * 100)) 122 | return acc['precise'], vectors 123 | 124 | 125 | if __name__ == '__main__': 126 | parser = argparse.ArgumentParser(description='Train Model') 127 | # common args 128 | parser.add_argument('--data_root', default='/data', type=str, help='Datasets root path') 129 | parser.add_argument('--data_name', default='sketchy', type=str, choices=['sketchy', 'tuberlin'], 130 | help='Dataset name') 131 | parser.add_argument('--backbone_type', default='resnet50', type=str, choices=['resnet50', 'vgg16'], 132 | help='Backbone type') 133 | parser.add_argument('--emb_dim', default=512, type=int, help='Embedding dim') 134 | parser.add_argument('--batch_size', default=64, type=int, help='Number of images in each mini-batch') 135 | parser.add_argument('--epochs', default=10, type=int, help='Number of epochs over the model to train') 136 | parser.add_argument('--warmup', default=1, type=int, help='Number of warmups over the extractor to train') 137 | parser.add_argument('--save_root', default='result', type=str, help='Result saved root path') 138 | 139 | # args parse 140 | args = parser.parse_args() 141 | data_root, data_name, backbone_type, emb_dim = args.data_root, args.data_name, args.backbone_type, args.emb_dim 142 | batch_size, epochs, warmup, save_root = args.batch_size, args.epochs, args.warmup, args.save_root 143 | 144 | # data prepare 145 | train_data = DomainDataset(data_root, data_name, split='train') 146 | val_data = DomainDataset(data_root, data_name, split='val') 147 | train_loader = DataLoader(train_data, batch_size=batch_size // 2, shuffle=True, num_workers=8) 148 | val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=8) 149 | 150 | # model define 151 | extractor = Extractor(backbone_type, emb_dim).cuda() 152 | generator = Generator(in_channels=8, num_block=8).cuda() 153 | discriminator = Discriminator(in_channels=8).cuda() 154 | 155 | # loss setup 156 | class_criterion = 
NormalizedSoftmaxLoss(len(train_data.classes), emb_dim).cuda() 157 | adversarial_criterion = nn.MSELoss() 158 | identity_criterion = nn.L1Loss() 159 | # optimizer config 160 | optimizer_extractor = Adam([{'params': extractor.parameters()}, {'params': class_criterion.parameters(), 161 | 'lr': 1e-3}], lr=1e-5) 162 | optimizer_generator = Adam(generator.parameters(), lr=1e-3, betas=(0.5, 0.999)) 163 | optimizer_discriminator = Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.999)) 164 | 165 | # training loop 166 | results = {'extractor_loss': [], 'generator_loss': [], 'identity_loss': [], 'discriminator_loss': [], 167 | 'precise': [], 'P@100': [], 'P@200': [], 'mAP@200': [], 'mAP@all': []} 168 | save_name_pre = '{}_{}_{}'.format(data_name, backbone_type, emb_dim) 169 | if not os.path.exists(save_root): 170 | os.makedirs(save_root) 171 | best_precise = 0.0 172 | for epoch in range(1, epochs + 1): 173 | 174 | # warmup, not update the parameters of extractor, except the final fc layer 175 | for param in list(extractor.backbone.parameters())[:-2]: 176 | param.requires_grad = False if epoch <= warmup else True 177 | 178 | extractor_loss, generator_loss, identity_loss, discriminator_loss = train(extractor, train_loader) 179 | results['extractor_loss'].append(extractor_loss) 180 | results['generator_loss'].append(generator_loss) 181 | results['identity_loss'].append(identity_loss) 182 | results['discriminator_loss'].append(discriminator_loss) 183 | precise, features = val(extractor, generator, val_loader) 184 | results['precise'].append(precise * 100) 185 | 186 | # save statistics 187 | data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1)) 188 | data_frame.to_csv('{}/{}_results.csv'.format(save_root, save_name_pre), index_label='epoch') 189 | 190 | if precise > best_precise: 191 | best_precise = precise 192 | torch.save(extractor.state_dict(), '{}/{}_extractor.pth'.format(save_root, save_name_pre)) 193 | torch.save(generator.state_dict(), '{}/{}_generator.pth'.format(save_root, save_name_pre)) 194 | torch.save(discriminator.state_dict(), '{}/{}_discriminator.pth'.format(save_root, save_name_pre)) 195 | torch.save(class_criterion.state_dict(), '{}/{}_proxies.pth'.format(save_root, save_name_pre)) 196 | torch.save(features, '{}/{}_vectors.pth'.format(save_root, save_name_pre)) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import random 4 | 5 | from PIL import Image 6 | from torch.utils.data.dataset import Dataset 7 | from torchvision import transforms 8 | from torchvision.transforms import InterpolationMode 9 | 10 | from metric import sake_metric 11 | 12 | 13 | def get_transform(split='train'): 14 | if split == 'train': 15 | return transforms.Compose([ 16 | transforms.Resize((224, 224), interpolation=InterpolationMode.BILINEAR), 17 | transforms.RandomHorizontalFlip(p=0.5), 18 | transforms.ToTensor(), 19 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 20 | else: 21 | return transforms.Compose([ 22 | transforms.Resize((224, 224), interpolation=InterpolationMode.BILINEAR), 23 | transforms.ToTensor(), 24 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 25 | 26 | 27 | class DomainDataset(Dataset): 28 | def __init__(self, data_root, data_name, split='train'): 29 | super(DomainDataset, self).__init__() 30 | 31 | images = [] 32 | for classes in os.listdir(os.path.join(data_root, data_name, split, 'sketch')): 33 
| sketches = glob.glob(os.path.join(data_root, data_name, split, 'sketch', str(classes), '*.jpg')) 34 | photos = glob.glob(os.path.join(data_root, data_name, split, 'photo', str(classes), '*.jpg')) 35 | # only consider the classes which photo images >= 400 for tuberlin dataset 36 | if len(photos) < 400 and data_name == 'tuberlin' and split == 'val': 37 | pass 38 | else: 39 | images += sketches 40 | # only append sketches for train 41 | if split == 'val': 42 | images += photos 43 | self.images = sorted(images) 44 | self.transform = get_transform(split) 45 | 46 | self.domains, self.labels, self.classes = [], [], {} 47 | i = 0 48 | for img in self.images: 49 | domain, label = os.path.dirname(img).split('/')[-2:] 50 | self.domains.append(0 if domain == 'photo' else 1) 51 | if label not in self.classes: 52 | self.classes[label] = i 53 | i += 1 54 | self.labels.append(self.classes[label]) 55 | # store photos for each class to easy sample for sketch in training period 56 | if split == 'train': 57 | self.refs = {} 58 | for key, value in self.classes.items(): 59 | self.refs[value] = sorted(glob.glob(os.path.join(data_root, data_name, split, 'photo', key, '*.jpg'))) 60 | 61 | self.split = split 62 | 63 | def __getitem__(self, index): 64 | img = Image.open(self.images[index]) 65 | img = self.transform(img) 66 | label = self.labels[index] 67 | if self.split == 'val': 68 | domain = self.domains[index] 69 | return img, domain, label 70 | else: 71 | ref = Image.open(random.choice(self.refs[label])) 72 | ref = self.transform(ref) 73 | return img, ref, label 74 | 75 | def __len__(self): 76 | return len(self.images) 77 | 78 | 79 | def compute_metric(vectors, domains, labels): 80 | acc = {} 81 | 82 | photo_vectors = vectors[domains == 0].numpy() 83 | sketch_vectors = vectors[domains == 1].numpy() 84 | photo_labels = labels[domains == 0].numpy() 85 | sketch_labels = labels[domains == 1].numpy() 86 | map_all, p_100 = sake_metric(photo_vectors, photo_labels, sketch_vectors, sketch_labels) 87 | map_200, p_200 = sake_metric(photo_vectors, photo_labels, sketch_vectors, sketch_labels, 88 | {'precision': 200, 'map': 200}) 89 | 90 | acc['P@100'], acc['P@200'], acc['mAP@200'], acc['mAP@all'] = p_100, p_200, map_200, map_all 91 | # the mean value is chosen as the representative of precise 92 | acc['precise'] = (acc['P@100'] + acc['P@200'] + acc['mAP@200'] + acc['mAP@all']) / 4 93 | return acc 94 | 95 | -------------------------------------------------------------------------------- /vis.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import shutil 5 | 6 | import torch 7 | from PIL import Image 8 | from torchvision.transforms import ToPILImage 9 | from tqdm import tqdm 10 | 11 | from model import Generator 12 | from utils import get_transform 13 | 14 | if __name__ == '__main__': 15 | parser = argparse.ArgumentParser(description='Vis Generator') 16 | parser.add_argument('--sketch_name', default='/data/sketchy/val/sketch/cow', type=str, 17 | help='Sketch image name') 18 | parser.add_argument('--generator_name', default='result/sketchy_resnet50_512_generator.pth', type=str, 19 | help='Generator name') 20 | parser.add_argument('--save_root', default='result', type=str, help='Result saved root path') 21 | 22 | opt = parser.parse_args() 23 | 24 | sketch_names, generator_name, save_root = opt.sketch_name, opt.generator_name, opt.save_root 25 | 26 | generator = Generator(in_channels=8, num_block=8) 27 | 
generator.load_state_dict(torch.load(generator_name, map_location='cpu')) 28 | generator = generator.cuda() 29 | generator.eval() 30 | 31 | sketch_names = glob.glob('{}/*.jpg'.format(sketch_names)) 32 | 33 | for sketch_name in tqdm(sketch_names): 34 | sketch = get_transform('val')(Image.open(sketch_name)).unsqueeze(dim=0).cuda() 35 | with torch.no_grad(): 36 | photo = generator(sketch) 37 | 38 | result_path = '{}/{}'.format(save_root, os.path.basename(sketch_name).split('.')[0]) 39 | if os.path.exists(result_path): 40 | shutil.rmtree(result_path) 41 | os.mkdir(result_path) 42 | 43 | Image.open(sketch_name).resize((224, 224), resample=Image.BILINEAR).save('{}/sketch.jpg'.format(result_path)) 44 | ToPILImage()((((photo.squeeze(dim=0) + 1.0) / 2) * 255).byte().cpu()).save('{}/photo.jpg'.format(result_path)) 45 | --------------------------------------------------------------------------------