Ayan Kumar Bhunia · Pinaki Nath Chowdhury · Aneeshan Sain · Yongxin Yang · Tao (Tony) Xiang · Yi-Zhe Song

SketchX, Centre for Vision Speech and Signal Processing, University of Surrey, United Kingdom

Published at CVPR 2021

Abstract

A fundamental challenge faced by existing Fine-Grained Sketch-Based Image Retrieval (FG-SBIR) models is data scarcity -- model performance is largely bottlenecked by the lack of sketch-photo pairs. Whilst the number of photos can be easily scaled, each corresponding sketch still needs to be individually produced. In this paper, we aim to mitigate this upper bound on sketch data, and study whether unlabelled photos alone (of which there are many) can be cultivated for performance gains. In particular, we introduce a novel semi-supervised framework for cross-modal retrieval that can additionally leverage large-scale unlabelled photos to account for data scarcity. At the centre of our semi-supervision design is a sequential photo-to-sketch generation model that aims to generate paired sketches for unlabelled photos. Importantly, we further introduce a discriminator-guided mechanism to guard against unfaithful generation, together with a distillation-loss-based regularizer to provide tolerance against noisy training samples. Last but not least, we treat generation and retrieval as two conjugate problems, where a joint learning procedure is devised for each module to mutually benefit from the other. Extensive experiments show that our semi-supervised model yields a significant performance boost over state-of-the-art supervised alternatives, as well as over existing methods that can exploit unlabelled photos for FG-SBIR.
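
To make the two noise-handling components above concrete, here is a minimal sketch (not the authors' released code) of how a discriminator-weighted triplet loss and a distillation regularizer could combine for the retrieval model. The objects `fgsbir`, `teacher`, `discriminator` and their embedding methods are illustrative assumptions, not names from the paper's codebase.

import torch
import torch.nn.functional as F

def noisy_pair_retrieval_loss(fgsbir, teacher, discriminator,
                              photos, gen_sketches, neg_photos,
                              margin=0.2, distill_weight=1.0):
    # Embed the generated sketch (anchor), its source photo (positive)
    # and a different photo (negative) into the joint retrieval space.
    a = fgsbir.sketch_embed(gen_sketches)
    p = fgsbir.photo_embed(photos)
    n = fgsbir.photo_embed(neg_photos)

    # Instance-wise weight in [0, 1]: how plausible each generated sketch
    # looks to the discriminator (no gradient flows into the weights).
    with torch.no_grad():
        w = torch.sigmoid(discriminator(gen_sketches)).squeeze(-1)

    triplet = F.relu(margin + F.pairwise_distance(a, p) - F.pairwise_distance(a, n))
    weighted_triplet = (w * triplet).mean()

    # Distillation regularizer: keep the student's photo embeddings close
    # to those of a teacher trained on clean labelled pairs only.
    with torch.no_grad():
        target = teacher.photo_embed(photos)
    distill = F.mse_loss(p, target)

    return weighted_triplet + distill_weight * distill

The key idea is that both terms act per instance: an implausible generated sketch contributes little to the triplet term, while the distillation term stops noisy pairs from dragging the embedding space away from what clean supervision established.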

Framework

Our framework: an FG-SBIR model leverages large-scale unlabelled photos, via a sequential photo-to-sketch generation model, alongside labelled pairs. Discriminator-guided instance-wise weighting and a distillation loss guard against the noisy generated data. Simultaneously, the photo-to-sketch generation model learns by taking rewards from the FG-SBIR model and the discriminator via policy gradient (over both labelled and unlabelled data), together with a supervised VAE loss over labelled data. Note that rasterization (vector-to-raster conversion) is a non-differentiable operation.
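
Because rasterization blocks gradient flow from the retrieval and discriminator rewards back to the stroke generator, the generator is updated with policy gradients. A minimal REINFORCE-style sketch of this update, again assuming hypothetical `fgsbir` and `discriminator` modules and precomputed per-sample stroke log-probabilities:

import torch
import torch.nn.functional as F

def generator_reward_loss(log_probs, rendered_sketches, photos,
                          fgsbir, discriminator):
    # log_probs: summed log-likelihood of each sampled stroke sequence,
    # shape (batch,); rendered_sketches: rasterized sketches (no gradient).
    with torch.no_grad():
        s = fgsbir.sketch_embed(rendered_sketches)
        p = fgsbir.photo_embed(photos)
        retrieval_reward = F.cosine_similarity(s, p, dim=-1)  # matches its photo?
        realism_reward = torch.sigmoid(discriminator(rendered_sketches)).squeeze(-1)
        reward = retrieval_reward + realism_reward
        reward = reward - reward.mean()  # mean baseline for variance reduction

    # REINFORCE: raise the likelihood of stroke sequences with high reward.
    return -(reward * log_probs).mean()

On labelled photos this policy-gradient term is combined with the supervised VAE loss; on unlabelled photos it is the generator's only training signal.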

Short Presentation

[Video]

Bibtex

Citation: Ayan Kumar Bhunia, Pinaki Nath Chowdhury, Aneeshan Sain, Yongxin Yang, Tao Xiang and Yi-Zhe Song. "More Photos are All You Need: Semi-Supervised Learning for Fine-Grained Sketch Based Image Retrieval." In CVPR 2021.

@InProceedings{bhunia_semifgsbir,
  author    = {Ayan Kumar Bhunia and Pinaki Nath Chowdhury and Aneeshan Sain and Yongxin Yang and Tao Xiang and Yi-Zhe Song},
  title     = {More Photos are All You Need: Semi-Supervised Learning for Fine-Grained Sketch Based Image Retrieval},
  booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  month     = {June},
  year      = {2021}
}
--------------------------------------------------------------------------------
/Photo_to_Sketch_2D_Attention/CVPR_18_Baseline/base_model.py:
--------------------------------------------------------------------------------
import math
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch import optim

from Image_Networks import *
from Sketch_Networks import *
from utils import *
from dataset import get_imageOnly_dataloader, get_sketchOnly_dataloader, get_dataloader
from rasterize import rasterize_relative, to_stroke_list, batch_rasterize_relative

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Photo2Sketch_Base(nn.Module):
    """Photo-to-sketch base model: CNN encoder/decoder for photos and a
    sketch-rnn style RNN encoder/decoder for vector sketches."""

    def __init__(self, hp):
        super(Photo2Sketch_Base, self).__init__()
        self.Image_Encoder = EncoderCNN()
        self.Image_Decoder = DecoderCNN()
        self.Sketch_Encoder = EncoderRNN(hp)
        self.Sketch_Decoder = DecoderRNN(hp)
        self.hp = hp
        self.apply(weights_init_normal)

    def pretrain_SketchBranch(self, iteration=100000):
        """Pre-train the sketch VAE branch (RNN encoder + decoder) on sketches only."""

        dataloader = get_sketchOnly_dataloader(self.hp)
        self.hp.max_seq_len = self.hp.sketch_rnn_max_seq_len
        self.Sketch_Encoder.train()
        self.Sketch_Decoder.train()
        self.train_sketch_params = list(self.Sketch_Encoder.parameters()) + list(self.Sketch_Decoder.parameters())
        self.sketch_optimizer = optim.Adam(self.train_sketch_params, self.hp.learning_rate)
        self.visualizer = Visualizer()

        for step in range(iteration):

            batch, lengths = dataloader.train_batch()

            self.sketch_optimizer.zero_grad()

            # Exponentially decayed learning rate and annealed KL weight,
            # following the original sketch-rnn training schedule.
            curr_learning_rate = ((self.hp.learning_rate - self.hp.min_learning_rate) *
                                  (self.hp.decay_rate) ** step + self.hp.min_learning_rate)
            curr_kl_weight = (self.hp.kl_weight - (self.hp.kl_weight - self.hp.kl_weight_start) *
                              (self.hp.kl_decay_rate) ** step)

            post_dist = self.Sketch_Encoder(batch, lengths)

            z_vector = post_dist.rsample()  # reparameterised sample, keeps gradients
            start_token = torch.stack([torch.Tensor([0, 0, 1, 0, 0])] * self.hp.batch_size_sketch_rnn).unsqueeze(0).to(device)
            batch_init = torch.cat([start_token, batch], 0)
            z_stack = torch.stack([z_vector] * (self.hp.sketch_rnn_max_seq_len + 1))
            inputs = torch.cat([batch_init, z_stack], 2)  # condition every step on z

            output, _ = self.Sketch_Decoder(inputs, z_vector, lengths + 1)

            end_token = torch.stack([torch.Tensor([0, 0, 0, 0, 1])] * batch.shape[1]).unsqueeze(0).to(device)
            batch = torch.cat([batch, end_token], 0)
            x_target = batch.permute(1, 0, 2)  # (Seq_Len, Batch, Feature) -> (Batch, Seq_Len, Feature)

            #################### Loss Calculation ########################################
            ##############################################################################
            recons_loss = sketch_reconstruction_loss(output, x_target)

            prior_distribution = torch.distributions.Normal(torch.zeros_like(post_dist.mean),
                                                            torch.ones_like(post_dist.stddev))
            kl_cost = torch.max(torch.distributions.kl_divergence(post_dist, prior_distribution).mean(),
                                torch.tensor(self.hp.kl_tolerance).to(device))
            loss = recons_loss + curr_kl_weight * kl_cost

            #################### Update Gradient #########################################
            ##############################################################################
            set_learninRate(self.sketch_optimizer, curr_learning_rate)
            loss.backward()
            nn.utils.clip_grad_norm_(self.train_sketch_params, self.hp.grad_clip)
            self.sketch_optimizer.step()

            if (step + 1) % 5 == 0:
                print('Step:{} ** KL_Loss:{} '
                      '** Recons_Loss:{} ** Total_loss:{}'.format(step, kl_cost.item(),
                                                                  recons_loss.item(), loss.item()))

                data = {}
                data['Reconstruction_Loss'] = recons_loss
                data['KL_Loss'] = kl_cost
                data['Total Loss'] = loss
                self.visualizer.plot_scalars(data, step)

            if (step + 1) % self.hp.eval_freq_iter == 0:

                batch_input, batch_gen_strokes = self.sketch_generation_deterministic(dataloader)
                # batch_input, batch_gen_strokes = self.sketch_generation_sample(dataloader)

                batch_redraw = batch_rasterize_relative(batch_gen_strokes)

                if batch_input is not None:
                    # Save input and reconstruction side by side for inspection.
                    batch_input_redraw = batch_rasterize_relative(batch_input)
                    batch = []
                    for a, b in zip(batch_input_redraw, batch_redraw):
                        batch.append(torch.cat((a, 1. - b), dim=-1))
                    batch = torch.stack(batch).float()
                else:
                    batch = batch_redraw.float()

                torchvision.utils.save_image(batch, './pretrain_sketch_Viz/deterministic/batch_reconstruction_' + str(step) + '_.jpg',
                                             nrow=round(math.sqrt(len(batch))))

                torch.save(self.Sketch_Encoder.state_dict(), './pretrain_models/Sketch_Encoder.pth')
                torch.save(self.Sketch_Decoder.state_dict(), './pretrain_models/Sketch_Decoder.pth')

                self.Sketch_Encoder.train()
                self.Sketch_Decoder.train()

    def sketch_generation_deterministic(self, dataloader, number_of_sample=64, condition=True):
        """Greedy (deterministic) decoding: reconstruct validation sketches when
        condition=True, otherwise decode from a random latent vector."""

        self.Sketch_Encoder.eval()
        self.Sketch_Decoder.eval()

        batch, lengths = dataloader.valid_batch(number_of_sample)
        if condition:
            post_dist = self.Sketch_Encoder(batch, lengths)
            z_vector = post_dist.sample()
        else:
            z_vector = torch.randn(number_of_sample, 128).to(device)

        start_token = torch.Tensor([0, 0, 1, 0, 0]).view(-1, 5).to(device)
        start_token = torch.stack([start_token] * number_of_sample, dim=1)
        state = start_token
        hidden_cell = None

        batch_gen_strokes = []
        for i_seq in range(self.hp.average_seq_len):
            decoder_input = torch.cat([state, z_vector.unsqueeze(0)], 2)
            state, hidden_cell = self.Sketch_Decoder(decoder_input, z_vector, hidden_cell=hidden_cell,
                                                     isTrain=False, get_deterministic=True)
            batch_gen_strokes.append(state.squeeze(0))

        batch_gen_strokes = torch.stack(batch_gen_strokes, dim=1)

        if condition:
            return batch.permute(1, 0, 2), batch_gen_strokes
        else:
            return None, batch_gen_strokes

    def sketch_generation_sample(self, dataloader, number_of_sample=64, condition=True):
        """Stochastic decoding: sample one stroke sequence at a time from the
        decoder's mixture output (see sample_next_state in utils)."""

        self.Sketch_Encoder.eval()
        self.Sketch_Decoder.eval()

        batch_gen_strokes = []
        batch_input = []

        for i_x in range(number_of_sample):
            batch, lengths = dataloader.valid_batch(1)

            if condition:
                post_dist = self.Sketch_Encoder(batch, lengths)
                z_vector = post_dist.sample()
            else:
                z_vector = torch.randn(1, 128).to(device)

            start_token = torch.Tensor([0, 0, 1, 0, 0]).view(1, 1, -1).to(device)
            state = start_token
            hidden_cell = None
            gen_strokes = []
            for i in range(self.hp.sketch_rnn_max_seq_len):
                decoder_input = torch.cat([state, z_vector.unsqueeze(0)], 2)
                output, hidden_cell = self.Sketch_Decoder(decoder_input, z_vector, hidden_cell=hidden_cell,
                                                          isTrain=False, get_deterministic=False)
                state, next_state = sample_next_state(output, self.hp)
                gen_strokes.append(next_state)

            gen_strokes = torch.stack(gen_strokes)
            batch_gen_strokes.append(gen_strokes)
            batch_input.append(batch.squeeze(1))

        batch_gen_strokes = torch.stack(batch_gen_strokes, dim=1)
        batch_input = torch.stack(batch_input, dim=1)

        if condition:
            return batch_input.permute(1, 0, 2), batch_gen_strokes.permute(1, 0, 2)
        else:
            return None, batch_gen_strokes.permute(1, 0, 2)

    def pretrain_ImageBranch(self, epoch=200):
        """Pre-train the image VAE branch (CNN encoder + decoder) on photos only."""

        image_dataloader = get_imageOnly_dataloader()
        self.Image_Encoder.train()
        self.Image_Decoder.train()
        self.train_image_params = list(self.Image_Encoder.parameters()) + list(self.Image_Decoder.parameters())
        self.image_optimizer = optim.Adam(self.train_image_params, self.hp.learning_rate)
        step = 0
        self.visualizer = Visualizer()

        for i_epoch in range(epoch):

            for _, batch_sample in enumerate(image_dataloader, 0):

                step = step + 1
                self.image_optimizer.zero_grad()

                batch_image = batch_sample[0].to(device)
                post_dist = self.Image_Encoder(batch_image)
                z_vector = post_dist.rsample()
                recons_batch_image = self.Image_Decoder(z_vector)

                # batch_image_normalized = transfer_ImageNomralization(batch_image, 'to_Gen')
                batch_image_normalized = batch_image
                # Per-image summed MSE reconstruction loss.
                recons_loss = F.mse_loss(batch_image_normalized, recons_batch_image, reduction='sum') / batch_image.shape[0]

                prior_distribution = torch.distributions.Normal(torch.zeros_like(post_dist.mean),
                                                                torch.ones_like(post_dist.stddev))
                kl_cost = torch.distributions.kl_divergence(post_dist, prior_distribution).sum(1).mean()

                loss = recons_loss + kl_cost

                loss.backward()
                nn.utils.clip_grad_norm_(self.train_image_params, self.hp.grad_clip)
                self.image_optimizer.step()

                if (step + 1) % 20 == 0:
                    print('Step:{} ** KL_Loss:{} '
                          '** Recons_Loss:{} ** Total_loss:{}'.format(step, kl_cost.item(),
                                                                      recons_loss.item(), loss.item()))

                    data = {}
                    data['Reconstruction_Loss'] = recons_loss
                    data['KL_Loss'] = kl_cost
                    data['Total Loss'] = loss
                    self.visualizer.plot_scalars(data, step)

                    data = {}
                    data['Input_Image'] = batch_image
                    data['Recons_Image'] = recons_batch_image
                    sample_z = torch.randn_like(z_vector)
                    data['Sampled_Image'] = self.Image_Decoder(sample_z)
                    self.visualizer.vis_image(data, step)

                if (step + 1) % self.hp.eval_freq_iter == 0:
                    saved_tensor = torch.cat([batch_image_normalized, recons_batch_image], dim=0)
                    torchvision.utils.save_image(saved_tensor, './pretrain_image_Viz/' + str(step) + '.jpg', normalize=True)
                    torch.save(self.Image_Encoder.state_dict(), './pretrain_models/Image_Encoder' + str(step) + '.pth')
                    torch.save(self.Image_Decoder.state_dict(), './pretrain_models/Image_Decoder' + str(step) + '.pth')

    def pretrain_SketchBranch_ShoeV2(self, iteration=10000):
        """Pre-train the sketch VAE branch on the ShoeV2 dataset."""

        self.hp.batchsize = 100
        dataloader_Train, dataloader_Test = get_dataloader(self.hp)

        self.Sketch_Encoder.train()
        self.Sketch_Decoder.train()

        self.train_sketch_params = list(self.Sketch_Encoder.parameters()) + list(self.Sketch_Decoder.parameters())
        self.sketch_optimizer = optim.Adam(self.train_sketch_params, self.hp.learning_rate)

        self.visualizer = Visualizer()

        step = 0

        for i_epoch in range(2000):

            for batch_data in dataloader_Train:

                batch = batch_data['relative_fivePoint'].to(device).permute(1, 0, 2).float()  # (Seq_Len, Batch, Feature)
                lengths = batch_data['sketch_length'].to(device) - 1  # relative coordinates have one fewer point
                step += 1

                self.sketch_optimizer.zero_grad()

                # Exponentially decayed learning rate and annealed KL weight.
                curr_learning_rate = ((self.hp.learning_rate - self.hp.min_learning_rate) *
                                      (self.hp.decay_rate) ** step + self.hp.min_learning_rate)
                curr_kl_weight = (self.hp.kl_weight - (self.hp.kl_weight - self.hp.kl_weight_start) *
                                  (self.hp.kl_decay_rate) ** step)

                post_dist = self.Sketch_Encoder(batch, lengths)

                z_vector = post_dist.rsample()
                start_token = torch.stack([torch.Tensor([0, 0, 1, 0, 0])] * batch.shape[1]).unsqueeze(0).to(device)
                batch_init = torch.cat([start_token, batch], 0)
                z_stack = torch.stack([z_vector] * (self.hp.max_seq_len + 1))
                inputs = torch.cat([batch_init, z_stack], 2)

                output, _ = self.Sketch_Decoder(inputs, z_vector, lengths + 1)

                end_token = torch.stack([torch.Tensor([0, 0, 0, 0, 1])] * batch.shape[1]).unsqueeze(0).to(device)
                batch = torch.cat([batch, end_token], 0)
                x_target = batch.permute(1, 0, 2)  # (Seq_Len, Batch, Feature) -> (Batch, Seq_Len, Feature)

                #################### Loss Calculation ########################################
                ##############################################################################
                recons_loss = sketch_reconstruction_loss(output, x_target)

                prior_distribution = torch.distributions.Normal(torch.zeros_like(post_dist.mean),
                                                                torch.ones_like(post_dist.stddev))
                kl_cost = torch.max(torch.distributions.kl_divergence(post_dist, prior_distribution).mean(),
                                    torch.tensor(self.hp.kl_tolerance).to(device))
                loss = recons_loss + curr_kl_weight * kl_cost

                #################### Update Gradient #########################################
                ##############################################################################
                set_learninRate(self.sketch_optimizer, curr_learning_rate)
                loss.backward()
                nn.utils.clip_grad_norm_(self.train_sketch_params, self.hp.grad_clip)
                self.sketch_optimizer.step()

                if (step + 1) % 5 == 0:
                    print('Step:{} ** KL_Loss:{} '
                          '** Recons_Loss:{} ** Total_loss:{}'.format(step, kl_cost.item(),
                                                                      recons_loss.item(), loss.item()))
                    data = {}
                    data['Reconstruction_Loss'] = recons_loss
                    data['KL_Loss'] = kl_cost
                    data['Total Loss'] = loss
                    self.visualizer.plot_scalars(data, step)

                if (step - 1) % 1000 == 0:

                    # Draw sketch-to-sketch reconstructions for visual inspection.
                    start_token = torch.Tensor([0, 0, 1, 0, 0]).view(-1, 5).to(device)
                    start_token = torch.stack([start_token] * z_vector.shape[0], dim=1)
                    state = start_token
                    hidden_cell = None

                    batch_gen_strokes = []
                    for i_seq in range(self.hp.average_seq_len):
                        decoder_input = torch.cat([state, z_vector.unsqueeze(0)], 2)
                        state, hidden_cell = self.Sketch_Decoder(decoder_input, z_vector, hidden_cell=hidden_cell,
                                                                 isTrain=False,
                                                                 get_deterministic=True)
                        batch_gen_strokes.append(state.squeeze(0))

                    sketch2sketch_gen = torch.stack(batch_gen_strokes, dim=1)
                    sketch_vector_gt = batch.permute(1, 0, 2)

                    sketch_vector_gt_draw = batch_rasterize_relative(sketch_vector_gt).to(device)
                    sketch2sketch_gen_draw = batch_rasterize_relative(sketch2sketch_gen).to(device)

                    batch_redraw = []
                    for a, b in zip(sketch_vector_gt_draw, sketch2sketch_gen_draw):
                        batch_redraw.append(torch.cat((a, 1. - b), dim=-1))

                    torchvision.utils.save_image(torch.stack(batch_redraw),
                                                 './pretrain_sketch_Viz/ShoeV2/redraw_{}.jpg'.format(step),
                                                 nrow=8)

                    torch.save(self.Sketch_Encoder.state_dict(), './pretrain_models/ShoeV2/Sketch_Encoder.pth')
                    torch.save(self.Sketch_Decoder.state_dict(), './pretrain_models/ShoeV2/Sketch_Decoder.pth')

                    self.Sketch_Encoder.train()
                    self.Sketch_Decoder.train()

    def freeze_weights(self):
        """Freeze all parameters (e.g. when the model is used for inference only)."""
        for p in self.parameters():
            p.requires_grad = False

    def Unfreeze_weights(self):
        """Make all parameters trainable again."""
        for p in self.parameters():
            p.requires_grad = True

--------------------------------------------------------------------------------