├── README.md
├── metric.py
├── model.py
├── result
│   ├── structure.png
│   └── vis.png
├── test.py
├── train.py
├── utils.py
└── vis.py
/README.md:
--------------------------------------------------------------------------------
# ACNet

A PyTorch implementation of ACNet based on TCSVT 2023 paper
[ACNet: Approaching-and-Centralizing Network for Zero-Shot Sketch-Based Image Retrieval](https://ieeexplore.ieee.org/document/10052737).

![Network Architecture](result/structure.png)

## Requirements

- [Anaconda](https://www.anaconda.com/download/)
- [PyTorch](https://pytorch.org)

```
conda install pytorch=1.10.0 torchvision cudatoolkit=11.3 -c pytorch
```

- [Timm](https://rwightman.github.io/pytorch-image-models/)

```
pip install timm
```

- [Pytorch Metric Learning](https://kevinmusgrave.github.io/pytorch-metric-learning/)

```
conda install pytorch-metric-learning -c metric-learning -c pytorch
```

## Dataset

[Sketchy Extended](http://sketchy.eye.gatech.edu) and
[TU-Berlin Extended](http://cybertron.cg.tu-berlin.de/eitz/projects/classifysketch/) datasets are used in this repo.
You can download them from the official websites, or from
[Google Drive](https://drive.google.com/drive/folders/1lce41k7cGNUOwzt-eswCeahDLWG6Cdk0?usp=sharing). The data directory
structure is shown as follows:

```
├── sketchy
│   ├── train
│   │   ├── sketch
│   │   │   ├── airplane
│   │   │   │   ├── n02691156_58-1.jpg
│   │   │   │   └── ...
│   │   │   └── ...
│   │   └── photo
│   │       └── same structure as sketch
│   └── val
│       └── same structure as train
└── tuberlin
    └── same structure as sketchy
```

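For reference, `DomainDataset` in `utils.py` discovers images by globbing this layout. A minimal sketch of the path
pattern it assumes (`/data` is just the default `--data_root`, and the file names are illustrative):

```
import glob
import os

# DomainDataset walks {data_root}/{data_name}/{split}/{domain}/{class}/*.jpg
pattern = os.path.join('/data', 'sketchy', 'train', 'sketch', 'airplane', '*.jpg')
print(glob.glob(pattern))  # e.g. ['/data/sketchy/train/sketch/airplane/n02691156_58-1.jpg', ...]
```
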
## Usage

### Train Model

```
python train.py --data_name tuberlin
optional arguments:
--data_root                   Datasets root path [default value is '/data']
--data_name                   Dataset name [default value is 'sketchy'](choices=['sketchy', 'tuberlin'])
--backbone_type               Backbone type [default value is 'resnet50'](choices=['resnet50', 'vgg16'])
--emb_dim                     Embedding dim [default value is 512]
--batch_size                  Number of images in each mini-batch [default value is 64]
--epochs                      Number of epochs to train the model [default value is 10]
--warmup                      Number of warm-up epochs for the extractor [default value is 1]
--save_root                   Result saved root path [default value is 'result']
```

### Test Model

```
python test.py --num 8
optional arguments:
--data_root                   Datasets root path [default value is '/data']
--query_name                  Query image name [default value is '/data/sketchy/val/sketch/cow/n01887787_591-14.jpg']
--data_base                   Queried database [default value is 'result/sketchy_resnet50_512_vectors.pth']
--num                         Retrieval number [default value is 5]
--save_root                   Result saved root path [default value is 'result']
```

## Benchmarks

The models are trained on one NVIDIA GTX TITAN (12G) GPU. `Adam` is used to optimize the models, with `lr` set to `1e-5`
for the backbone, `1e-3` for the generator and `1e-4` for the discriminator. All other hyper-parameters are the default values.

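A minimal sketch of this optimizer configuration, mirroring `train.py` (the `nn.Linear` stand-ins below replace the
real `Extractor`, `NormalizedSoftmaxLoss`, `Generator` and `Discriminator` purely for illustration):

```
import torch.nn as nn
from torch.optim import Adam

# stand-in modules; train.py uses Extractor, NormalizedSoftmaxLoss, Generator and Discriminator
extractor, class_criterion = nn.Linear(8, 4), nn.Linear(4, 2)
generator, discriminator = nn.Linear(8, 8), nn.Linear(8, 1)

# backbone at 1e-5 with the class proxies at 1e-3; the GAN parts use betas=(0.5, 0.999)
optimizer_extractor = Adam([{'params': extractor.parameters()},
                            {'params': class_criterion.parameters(), 'lr': 1e-3}], lr=1e-5)
optimizer_generator = Adam(generator.parameters(), lr=1e-3, betas=(0.5, 0.999))
optimizer_discriminator = Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.999))
```
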
<table>
  <thead>
    <tr>
      <th rowspan="2">Backbone</th>
      <th rowspan="2">Dim</th>
      <th colspan="4">Sketchy Extended</th>
      <th colspan="4">TU-Berlin Extended</th>
      <th rowspan="2">Download</th>
    </tr>
    <tr>
      <th>mAP@200</th>
      <th>mAP@all</th>
      <th>P@100</th>
      <th>P@200</th>
      <th>mAP@200</th>
      <th>mAP@all</th>
      <th>P@100</th>
      <th>P@200</th>
    </tr>
  </thead>
  <tbody>
    <tr><td>VGG16</td><td>64</td><td>32.6</td><td>38.0</td><td>48.7</td><td>44.7</td><td>39.8</td><td>37.1</td><td>50.6</td><td>48.0</td><td>MEGA</td></tr>
    <tr><td>VGG16</td><td>512</td><td>38.3</td><td>42.2</td><td>53.3</td><td>49.3</td><td>47.2</td><td>43.9</td><td>58.1</td><td>55.3</td><td>MEGA</td></tr>
    <tr><td>VGG16</td><td>4096</td><td>40.0</td><td>43.2</td><td>54.6</td><td>50.8</td><td>51.7</td><td>47.9</td><td>62.3</td><td>59.3</td><td>MEGA</td></tr>
    <tr><td>ResNet50</td><td>64</td><td>43.0</td><td>46.0</td><td>56.8</td><td>52.7</td><td>47.5</td><td>44.9</td><td>57.2</td><td>54.9</td><td>MEGA</td></tr>
    <tr><td>ResNet50</td><td>512</td><td>51.7</td><td>55.9</td><td>64.3</td><td>60.8</td><td>57.7</td><td>57.7</td><td>65.8</td><td>64.4</td><td>MEGA</td></tr>
    <tr><td>ResNet50</td><td>4096</td><td>51.1</td><td>55.7</td><td>63.8</td><td>60.0</td><td>57.3</td><td>58.6</td><td>64.6</td><td>63.5</td><td>MEGA</td></tr>
  </tbody>
</table>

## Results

![vis](result/vis.png)

## Citing ACNet

If you find ACNet helpful, please consider citing:
```
@article{ren2023acnet,
  title={ACNet: Approaching-and-Centralizing Network for Zero-Shot Sketch-Based Image Retrieval},
  author={Ren, Hao and Zheng, Ziqiang and Wu, Yang and Lu, Hong and Yang, Yang and Shan, Ying and Yeung, Sai-Kit},
  journal={IEEE Transactions on Circuits and Systems for Video Technology},
  year={2023},
  publisher={IEEE}
}
```

--------------------------------------------------------------------------------
/metric.py:
--------------------------------------------------------------------------------
import numpy as np
from scipy.spatial.distance import cdist


def sake_metric(predicted_features_gallery, gt_labels_gallery, predicted_features_query, gt_labels_query, k=None):
    if k is None:
        k = {'precision': 100, 'map': predicted_features_gallery.shape[0]}
    if k['precision'] is None:
        k['precision'] = 100
    if k['map'] is None:
        k['map'] = predicted_features_gallery.shape[0]

    # negate the cosine distance so that a higher score means a better match
    scores = -cdist(predicted_features_query, predicted_features_gallery, metric='cosine')
    gt_labels_query = gt_labels_query.flatten()
    gt_labels_gallery = gt_labels_gallery.flatten()
    aps = map_sake(gt_labels_gallery, predicted_features_query, gt_labels_query, scores, k=k['map'])
    prec = prec_sake(gt_labels_gallery, predicted_features_query, gt_labels_query, scores, k=k['precision'])
    return sum(aps) / len(aps), prec


def map_sake(gt_labels_gallery, predicted_features_query, gt_labels_query, scores, k=None):
    mAP_ls = [[] for _ in range(len(np.unique(gt_labels_query)))]
    mean_mAP = []
    for fi in range(predicted_features_query.shape[0]):
        mapi = eval_ap(gt_labels_query[fi], scores[fi], gt_labels_gallery, top=k)
        mAP_ls[gt_labels_query[fi]].append(mapi)
        mean_mAP.append(mapi)
    return mean_mAP


def prec_sake(gt_labels_gallery, predicted_features_query, gt_labels_query, scores, k=None):
    # compute precision for two modalities
    prec_ls = [[] for _ in range(len(np.unique(gt_labels_query)))]
    mean_prec = []
    for fi in range(predicted_features_query.shape[0]):
        prec = eval_precision(gt_labels_query[fi], scores[fi], gt_labels_gallery, top=k)
        prec_ls[gt_labels_query[fi]].append(prec)
        mean_prec.append(prec)
    return np.nanmean(mean_prec)


def eval_ap(inst_id, scores, gt_labels, top=None):
    pos_flag = gt_labels == inst_id
    tot = scores.shape[0]
    tot_pos = np.sum(pos_flag)

    sort_idx = np.argsort(-scores)
    tp = pos_flag[sort_idx]
    fp = np.logical_not(tp)

    if top is not None:
        top = min(top, tot)
        tp = tp[:top]
        fp = fp[:top]
        tot_pos = min(top, tot_pos)

    # a query without any positive in the gallery has an undefined AP
    if tot_pos == 0:
        return np.nan

    fp = np.cumsum(fp)
    tp = np.cumsum(tp)
    rec = tp / tot_pos
    prec = tp / (tp + fp)

    ap = voc_ap(rec, prec)
    return ap


def voc_ap(rec, prec):
    mrec = np.append(0, rec)
    mrec = np.append(mrec, 1)

    mpre = np.append(0, prec)
    mpre = np.append(mpre, 0)

    # make the precision envelope monotonically decreasing from right to left
    for ii in range(len(mpre) - 2, -1, -1):
        mpre[ii] = max(mpre[ii], mpre[ii + 1])

    # sum the areas of the rectangles where recall changes
    msk = [i != j for i, j in zip(mrec[1:], mrec[0:-1])]
    ap = np.sum((mrec[1:][msk] - mrec[0:-1][msk]) * mpre[1:][msk])
    return ap


def eval_precision(inst_id, scores, gt_labels, top=100):
    pos_flag = gt_labels == inst_id
    tot = scores.shape[0]
    top = min(top, tot)
    sort_idx = np.argsort(-scores)
    return np.sum(pos_flag[sort_idx][:top]) / top
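
# --- hedged usage sketch (not part of the original file) ---
# sake_metric takes gallery/query feature matrices plus integer labels and returns
# (mAP over all queries, mean precision@k); the labels here are kept contiguous (0..4)
# on purpose, because map_sake indexes its per-class buckets by the raw label value.
if __name__ == '__main__':
    rng = np.random.default_rng(0)
    gallery_feats = rng.normal(size=(100, 8))
    query_feats = rng.normal(size=(20, 8))
    gallery_labels = np.repeat(np.arange(5), 20)
    query_labels = np.tile(np.arange(5), 4)
    mean_ap, prec = sake_metric(gallery_feats, gallery_labels, query_feats, query_labels)
    print('mAP@all: {:.4f} | P@100: {:.4f}'.format(mean_ap, prec))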
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import timm
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super(ResidualBlock, self).__init__()

        self.conv = nn.Sequential(nn.Conv2d(in_channels, in_channels, 3, padding=1, padding_mode='reflect'),
                                  nn.InstanceNorm2d(in_channels), nn.ReLU(inplace=True),
                                  nn.Conv2d(in_channels, in_channels, 3, padding=1, padding_mode='reflect'),
                                  nn.InstanceNorm2d(in_channels))

    def forward(self, x):
        return x + self.conv(x)


class Generator(nn.Module):
    def __init__(self, in_channels=64, num_block=9):
        super(Generator, self).__init__()

        # in conv
        self.in_conv = nn.Sequential(nn.Conv2d(3, in_channels, 7, padding=3, padding_mode='reflect'),
                                     nn.InstanceNorm2d(in_channels), nn.ReLU(inplace=True))

        # down sample
        down_sample = []
        for _ in range(2):
            out_channels = in_channels * 2
            down_sample += [nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1),
                            nn.InstanceNorm2d(out_channels), nn.ReLU(inplace=True)]
            in_channels = out_channels
        self.down_sample = nn.Sequential(*down_sample)

        # conv blocks
        self.convs = nn.Sequential(*[ResidualBlock(in_channels) for _ in range(num_block)])

        # up sample
        up_sample = []
        for _ in range(2):
            out_channels = in_channels // 2
            up_sample += [nn.ConvTranspose2d(in_channels, out_channels, 3, stride=2, padding=1, output_padding=1),
                          nn.InstanceNorm2d(out_channels), nn.ReLU(inplace=True)]
            in_channels = out_channels
        self.up_sample = nn.Sequential(*up_sample)

        # out conv
        self.out_conv = nn.Sequential(nn.Conv2d(in_channels, 3, 7, padding=3, padding_mode='reflect'), nn.Tanh())

    def forward(self, x):
        x = self.in_conv(x)
        x = self.down_sample(x)
        x = self.convs(x)
        x = self.up_sample(x)
        out = self.out_conv(x)
        return out


class Discriminator(nn.Module):
    def __init__(self, in_channels=64):
        super(Discriminator, self).__init__()

        self.conv1 = nn.Sequential(nn.Conv2d(3, in_channels, 4, stride=2, padding=1), nn.LeakyReLU(0.2, inplace=True))

        self.conv2 = nn.Sequential(nn.Conv2d(in_channels, in_channels * 2, 4, stride=2, padding=1),
                                   nn.InstanceNorm2d(in_channels * 2), nn.LeakyReLU(0.2, inplace=True))

        self.conv3 = nn.Sequential(nn.Conv2d(in_channels * 2, in_channels * 4, 4, stride=2, padding=1),
                                   nn.InstanceNorm2d(in_channels * 4), nn.LeakyReLU(0.2, inplace=True))

        self.conv4 = nn.Sequential(nn.Conv2d(in_channels * 4, in_channels * 8, 4, padding=1),
                                   nn.InstanceNorm2d(in_channels * 8), nn.LeakyReLU(0.2, inplace=True))

        self.conv5 = nn.Conv2d(in_channels * 8, 1, 4, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        out = self.conv5(x)
        return out


class Extractor(nn.Module):
    def __init__(self, backbone_type, emb_dim):
        super(Extractor, self).__init__()

        # backbone
        model_name = 'resnet50' if backbone_type == 'resnet50' else 'vgg16'
        self.backbone = timm.create_model(model_name, pretrained=True, num_classes=emb_dim, global_pool='max')

    def forward(self, x):
        x = self.backbone(x)
        out = F.normalize(x, dim=-1)
        return out


def set_bn_eval(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm2d') != -1:
        m.eval()
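

# --- hedged shape-check sketch (not part of the original file) ---
# with 224x224 inputs, the generator is shape-preserving (two stride-2 downsamples
# mirrored by two stride-2 upsamples) and the PatchGAN-style discriminator emits
# a grid of patch logits rather than a single score.
if __name__ == '__main__':
    import torch

    g = Generator(in_channels=8, num_block=8)
    d = Discriminator(in_channels=8)
    x = torch.randn(2, 3, 224, 224)
    print(g(x).shape)  # torch.Size([2, 3, 224, 224])
    print(d(x).shape)  # torch.Size([2, 1, 26, 26])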
--------------------------------------------------------------------------------
/result/structure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leftthomas/ACNet/60387b1c0429282840fba0fbf76f372b7187f02b/result/structure.png
--------------------------------------------------------------------------------
/result/vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leftthomas/ACNet/60387b1c0429282840fba0fbf76f372b7187f02b/result/vis.png
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import argparse
import os
import shutil

import torch
from PIL import Image, ImageDraw

from utils import DomainDataset

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Test Model')
    parser.add_argument('--data_root', default='/data', type=str, help='Datasets root path')
    parser.add_argument('--query_name', default='/data/sketchy/val/sketch/cow/n01887787_591-14.jpg', type=str,
                        help='Query image name')
    parser.add_argument('--data_base', default='result/sketchy_resnet50_512_vectors.pth', type=str,
                        help='Queried database')
    parser.add_argument('--num', default=5, type=int, help='Retrieval number')
    parser.add_argument('--save_root', default='result', type=str, help='Result saved root path')

    opt = parser.parse_args()

    data_root, query_name, data_base, retrieval_num = opt.data_root, opt.query_name, opt.data_base, opt.num
    save_root, data_name = opt.save_root, data_base.split('/')[-1].split('_')[0]

    vectors = torch.load(data_base)
    val_data = DomainDataset(data_root, data_name, split='val')

    if query_name not in val_data.images:
        raise FileNotFoundError('{} not found'.format(query_name))
    query_index = val_data.images.index(query_name)
    query_image = Image.open(query_name).resize((224, 224), resample=Image.BILINEAR)
    query_label = val_data.labels[query_index]
    query_feature = vectors[query_index]

    # keep only the photo domain as the retrieval gallery
    gallery_images, gallery_labels = [], []
    for i, domain in enumerate(val_data.domains):
        if domain == 0:
            gallery_images.append(val_data.images[i])
            gallery_labels.append(val_data.labels[i])
    gallery_features = vectors[torch.tensor(val_data.domains) == 0]

    # cosine similarity (the features are L2-normalized), then take the top-k gallery entries
    sim_matrix = query_feature.unsqueeze(0).mm(gallery_features.t()).squeeze()
    idx = sim_matrix.topk(k=retrieval_num, dim=-1)[1]

    result_path = '{}/{}'.format(save_root, query_name.split('/')[-1].split('.')[0])
    if os.path.exists(result_path):
        shutil.rmtree(result_path)
    os.mkdir(result_path)
    query_image.save('{}/query.jpg'.format(result_path))
    for num, index in enumerate(idx):
        retrieval_image = Image.open(gallery_images[index.item()]).resize((224, 224), resample=Image.BILINEAR)
        draw = ImageDraw.Draw(retrieval_image)
        retrieval_label = gallery_labels[index.item()]
        retrieval_status = retrieval_label == query_label
        retrieval_sim = sim_matrix[index.item()].item()
        # a green border marks a correct retrieval, a red border an incorrect one
        if retrieval_status:
            draw.rectangle((0, 0, 223, 223), outline='green', width=8)
        else:
            draw.rectangle((0, 0, 223, 223), outline='red', width=8)
        retrieval_image.save('{}/retrieval_{}_{}.jpg'.format(result_path, num + 1, '%.4f' % retrieval_sim))
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse
import os
import random

import numpy as np
import pandas as pd
import torch
from pytorch_metric_learning.losses import NormalizedSoftmaxLoss
from torch import nn
from torch.backends import cudnn
from torch.optim import Adam
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from model import Extractor, Discriminator, Generator, set_bn_eval
from utils import DomainDataset, compute_metric

# for reproducibility
random.seed(1)
np.random.seed(1)
torch.manual_seed(1)
cudnn.deterministic = True
cudnn.benchmark = False


# train for one epoch
def train(backbone, data_loader):
    backbone.train()
    # fix bn on backbone
    backbone.apply(set_bn_eval)
    generator.train()
    discriminator.train()
    total_extractor_loss, total_generator_loss, total_identity_loss, total_discriminator_loss = 0.0, 0.0, 0.0, 0.0
    total_num, train_bar = 0, tqdm(data_loader, dynamic_ncols=True)
    for sketch, photo, label in train_bar:
        sketch, photo, label = sketch.cuda(), photo.cuda(), label.cuda()

        optimizer_generator.zero_grad()
        optimizer_extractor.zero_grad()

        # generator #
        fake = generator(sketch)
        pred_fake = discriminator(fake)

        # generator loss
        target_fake = torch.ones(pred_fake.size(), device=pred_fake.device)
        gg_loss = adversarial_criterion(pred_fake, target_fake)
        total_generator_loss += gg_loss.item() * sketch.size(0)
        # identity loss
        ii_loss = identity_criterion(generator(photo), photo)
        total_identity_loss += ii_loss.item() * sketch.size(0)

        # extractor #
        sketch_proj = backbone(sketch)
        photo_proj = backbone(photo)
        fake_proj = backbone(fake)

        # extractor loss
        class_loss = (class_criterion(sketch_proj, label) + class_criterion(photo_proj, label) +
                      class_criterion(fake_proj, label)) / 3
        total_extractor_loss += class_loss.item() * sketch.size(0)

        (gg_loss + 0.1 * ii_loss + 10 * class_loss).backward()

        optimizer_generator.step()
        optimizer_extractor.step()

        # discriminator loss #
        optimizer_discriminator.zero_grad()
        pred_photo = discriminator(photo)
        target_photo = torch.ones(pred_photo.size(), device=pred_photo.device)
        pred_fake = discriminator(fake.detach())
        target_fake = torch.zeros(pred_fake.size(), device=pred_fake.device)
        adversarial_loss = (adversarial_criterion(pred_photo, target_photo) +
                            adversarial_criterion(pred_fake, target_fake)) / 2
        total_discriminator_loss += adversarial_loss.item() * sketch.size(0)

        adversarial_loss.backward()
        optimizer_discriminator.step()

        total_num += sketch.size(0)

        e_loss = total_extractor_loss / total_num
        g_loss = total_generator_loss / total_num
        i_loss = total_identity_loss / total_num
        d_loss = total_discriminator_loss / total_num
        train_bar.set_description('Train Epoch: [{}/{}] E-Loss: {:.4f} G-Loss: {:.4f} I-Loss: {:.4f} D-Loss: {:.4f}'
                                  .format(epoch, epochs, e_loss, g_loss, i_loss, d_loss))

    return e_loss, g_loss, i_loss, d_loss


# val for one epoch
def val(backbone, encoder, data_loader):
    backbone.eval()
    encoder.eval()
    vectors, domains, labels = [], [], []
    with torch.no_grad():
        for img, domain, label in tqdm(data_loader, desc='Feature extracting', dynamic_ncols=True):
            img = img.cuda()
            photo = img[domain == 0]
            sketch = img[domain == 1]
            photo_emb = backbone(photo)
            # sketches are first translated to the photo domain by the generator
            sketch_emb = backbone(encoder(sketch))
            emb = torch.cat((photo_emb, sketch_emb), dim=0)
            vectors.append(emb.cpu())
            label = torch.cat((label[domain == 0], label[domain == 1]), dim=0)
            labels.append(label)
            domain = torch.cat((domain[domain == 0], domain[domain == 1]), dim=0)
            domains.append(domain)
    vectors = torch.cat(vectors, dim=0)
    domains = torch.cat(domains, dim=0)
    labels = torch.cat(labels, dim=0)
    acc = compute_metric(vectors, domains, labels)
    results['P@100'].append(acc['P@100'] * 100)
    results['P@200'].append(acc['P@200'] * 100)
    results['mAP@200'].append(acc['mAP@200'] * 100)
    results['mAP@all'].append(acc['mAP@all'] * 100)
    print('Val Epoch: [{}/{}] | P@100:{:.1f}% | P@200:{:.1f}% | mAP@200:{:.1f}% | mAP@all:{:.1f}%'
          .format(epoch, epochs, acc['P@100'] * 100, acc['P@200'] * 100, acc['mAP@200'] * 100,
                  acc['mAP@all'] * 100))
    return acc['precise'], vectors


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train Model')
    # common args
    parser.add_argument('--data_root', default='/data', type=str, help='Datasets root path')
    parser.add_argument('--data_name', default='sketchy', type=str, choices=['sketchy', 'tuberlin'],
                        help='Dataset name')
    parser.add_argument('--backbone_type', default='resnet50', type=str, choices=['resnet50', 'vgg16'],
                        help='Backbone type')
    parser.add_argument('--emb_dim', default=512, type=int, help='Embedding dim')
    parser.add_argument('--batch_size', default=64, type=int, help='Number of images in each mini-batch')
    parser.add_argument('--epochs', default=10, type=int, help='Number of epochs to train the model')
    parser.add_argument('--warmup', default=1, type=int, help='Number of warm-up epochs for the extractor')
    parser.add_argument('--save_root', default='result', type=str, help='Result saved root path')

    # args parse
    args = parser.parse_args()
    data_root, data_name, backbone_type, emb_dim = args.data_root, args.data_name, args.backbone_type, args.emb_dim
    batch_size, epochs, warmup, save_root = args.batch_size, args.epochs, args.warmup, args.save_root

    # data prepare
    train_data = DomainDataset(data_root, data_name, split='train')
    val_data = DomainDataset(data_root, data_name, split='val')
    # each train sample yields a (sketch, photo) pair, so halve the batch size
    train_loader = DataLoader(train_data, batch_size=batch_size // 2, shuffle=True, num_workers=8)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=8)

    # model define
    extractor = Extractor(backbone_type, emb_dim).cuda()
    generator = Generator(in_channels=8, num_block=8).cuda()
    discriminator = Discriminator(in_channels=8).cuda()

    # loss setup
    class_criterion = NormalizedSoftmaxLoss(len(train_data.classes), emb_dim).cuda()
    adversarial_criterion = nn.MSELoss()
    identity_criterion = nn.L1Loss()

    # optimizer config
    optimizer_extractor = Adam([{'params': extractor.parameters()},
                                {'params': class_criterion.parameters(), 'lr': 1e-3}], lr=1e-5)
    optimizer_generator = Adam(generator.parameters(), lr=1e-3, betas=(0.5, 0.999))
    optimizer_discriminator = Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.999))

    # training loop
    results = {'extractor_loss': [], 'generator_loss': [], 'identity_loss': [], 'discriminator_loss': [],
               'precise': [], 'P@100': [], 'P@200': [], 'mAP@200': [], 'mAP@all': []}
    save_name_pre = '{}_{}_{}'.format(data_name, backbone_type, emb_dim)
    if not os.path.exists(save_root):
        os.makedirs(save_root)
    best_precise = 0.0
    for epoch in range(1, epochs + 1):

        # during warm-up, do not update the parameters of the extractor, except the final fc layer
        for param in list(extractor.backbone.parameters())[:-2]:
            param.requires_grad = False if epoch <= warmup else True

        extractor_loss, generator_loss, identity_loss, discriminator_loss = train(extractor, train_loader)
        results['extractor_loss'].append(extractor_loss)
        results['generator_loss'].append(generator_loss)
        results['identity_loss'].append(identity_loss)
        results['discriminator_loss'].append(discriminator_loss)
        precise, features = val(extractor, generator, val_loader)
        results['precise'].append(precise * 100)

        # save statistics
        data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1))
        data_frame.to_csv('{}/{}_results.csv'.format(save_root, save_name_pre), index_label='epoch')

        if precise > best_precise:
            best_precise = precise
            torch.save(extractor.state_dict(), '{}/{}_extractor.pth'.format(save_root, save_name_pre))
            torch.save(generator.state_dict(), '{}/{}_generator.pth'.format(save_root, save_name_pre))
            torch.save(discriminator.state_dict(), '{}/{}_discriminator.pth'.format(save_root, save_name_pre))
            torch.save(class_criterion.state_dict(), '{}/{}_proxies.pth'.format(save_root, save_name_pre))
            torch.save(features, '{}/{}_vectors.pth'.format(save_root, save_name_pre))
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import glob
import os
import random

from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode

from metric import sake_metric


def get_transform(split='train'):
    if split == 'train':
        return transforms.Compose([
            transforms.Resize((224, 224), interpolation=InterpolationMode.BILINEAR),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    else:
        return transforms.Compose([
            transforms.Resize((224, 224), interpolation=InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


class DomainDataset(Dataset):
    def __init__(self, data_root, data_name, split='train'):
        super(DomainDataset, self).__init__()

        images = []
        for classes in os.listdir(os.path.join(data_root, data_name, split, 'sketch')):
            sketches = glob.glob(os.path.join(data_root, data_name, split, 'sketch', str(classes), '*.jpg'))
            photos = glob.glob(os.path.join(data_root, data_name, split, 'photo', str(classes), '*.jpg'))
            # for the tuberlin val split, only consider classes with at least 400 photo images
            if len(photos) < 400 and data_name == 'tuberlin' and split == 'val':
                pass
            else:
                images += sketches
                # only append photos for val; train samples photos through self.refs instead
                if split == 'val':
                    images += photos
        self.images = sorted(images)
        self.transform = get_transform(split)

        self.domains, self.labels, self.classes = [], [], {}
        i = 0
        for img in self.images:
            domain, label = os.path.dirname(img).split('/')[-2:]
            self.domains.append(0 if domain == 'photo' else 1)
            if label not in self.classes:
                self.classes[label] = i
                i += 1
            self.labels.append(self.classes[label])
        # store the photos of each class, to easily sample a paired photo for each sketch during training
        if split == 'train':
            self.refs = {}
            for key, value in self.classes.items():
                self.refs[value] = sorted(glob.glob(os.path.join(data_root, data_name, split, 'photo', key, '*.jpg')))

        self.split = split

    def __getitem__(self, index):
        img = Image.open(self.images[index])
        img = self.transform(img)
        label = self.labels[index]
        if self.split == 'val':
            domain = self.domains[index]
            return img, domain, label
        else:
            # train: pair the sketch with a random photo of the same class
            ref = Image.open(random.choice(self.refs[label]))
            ref = self.transform(ref)
            return img, ref, label

    def __len__(self):
        return len(self.images)


def compute_metric(vectors, domains, labels):
    acc = {}

    photo_vectors = vectors[domains == 0].numpy()
    sketch_vectors = vectors[domains == 1].numpy()
    photo_labels = labels[domains == 0].numpy()
    sketch_labels = labels[domains == 1].numpy()
    map_all, p_100 = sake_metric(photo_vectors, photo_labels, sketch_vectors, sketch_labels)
    map_200, p_200 = sake_metric(photo_vectors, photo_labels, sketch_vectors, sketch_labels,
                                 {'precision': 200, 'map': 200})

    acc['P@100'], acc['P@200'], acc['mAP@200'], acc['mAP@all'] = p_100, p_200, map_200, map_all
    # the mean of the four metrics is used as the model-selection score
    acc['precise'] = (acc['P@100'] + acc['P@200'] + acc['mAP@200'] + acc['mAP@all']) / 4
    return acc
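
# --- hedged usage sketch (not part of the original file) ---
# both transforms map a PIL image to a tensor normalized into [-1, 1], which
# matches the Tanh output range of the generator in model.py.
if __name__ == '__main__':
    dummy = Image.new('RGB', (256, 256), color='white')
    tensor = get_transform('val')(dummy)
    print(tensor.shape, tensor.min().item(), tensor.max().item())  # torch.Size([3, 224, 224]) 1.0 1.0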
--------------------------------------------------------------------------------
/vis.py:
--------------------------------------------------------------------------------
import argparse
import glob
import os
import shutil

import torch
from PIL import Image
from torchvision.transforms import ToPILImage
from tqdm import tqdm

from model import Generator
from utils import get_transform

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Vis Generator')
    parser.add_argument('--sketch_name', default='/data/sketchy/val/sketch/cow', type=str,
                        help='Sketch image name')
    parser.add_argument('--generator_name', default='result/sketchy_resnet50_512_generator.pth', type=str,
                        help='Generator name')
    parser.add_argument('--save_root', default='result', type=str, help='Result saved root path')

    opt = parser.parse_args()

    sketch_names, generator_name, save_root = opt.sketch_name, opt.generator_name, opt.save_root

    generator = Generator(in_channels=8, num_block=8)
    generator.load_state_dict(torch.load(generator_name, map_location='cpu'))
    generator = generator.cuda()
    generator.eval()

    sketch_names = glob.glob('{}/*.jpg'.format(sketch_names))

    for sketch_name in tqdm(sketch_names):
        sketch = get_transform('val')(Image.open(sketch_name)).unsqueeze(dim=0).cuda()
        with torch.no_grad():
            photo = generator(sketch)

        result_path = '{}/{}'.format(save_root, os.path.basename(sketch_name).split('.')[0])
        if os.path.exists(result_path):
            shutil.rmtree(result_path)
        os.mkdir(result_path)

        Image.open(sketch_name).resize((224, 224), resample=Image.BILINEAR).save('{}/sketch.jpg'.format(result_path))
        # map the generator output from [-1, 1] back to [0, 255] before saving
        ToPILImage()((((photo.squeeze(dim=0) + 1.0) / 2) * 255).byte().cpu()).save('{}/photo.jpg'.format(result_path))
--------------------------------------------------------------------------------