├── sample_images ├── logo.png ├── example.jpg ├── outline.png ├── framework.png ├── outline.jpeg └── sketchxlogo.jpg ├── .idea ├── misc.xml ├── vcs.xml ├── inspectionProfiles │ └── profiles_settings.xml ├── .gitignore ├── modules.xml ├── semisupervised-FGSBIR.iml └── deployment.xml ├── Photo_to_Sketch_2D_Attention ├── base_model.py ├── CVPR_18_Baseline │ ├── preprocess │ │ └── preprocess_rdp.py │ ├── Image_Networks.py │ ├── utils.py │ ├── Sketch_Networks.py │ ├── main.py │ ├── model.py │ ├── rasterize.py │ ├── dataset.py │ └── base_model.py ├── main.py ├── Image_Networks.py ├── model.py ├── rasterize.py ├── Sketch_Networks.py ├── dataset.py └── utils.py ├── README.md ├── LICENSE └── index.html /sample_images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AyanKumarBhunia/semisupervised-FGSBIR/HEAD/sample_images/logo.png -------------------------------------------------------------------------------- /sample_images/example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AyanKumarBhunia/semisupervised-FGSBIR/HEAD/sample_images/example.jpg -------------------------------------------------------------------------------- /sample_images/outline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AyanKumarBhunia/semisupervised-FGSBIR/HEAD/sample_images/outline.png -------------------------------------------------------------------------------- /sample_images/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AyanKumarBhunia/semisupervised-FGSBIR/HEAD/sample_images/framework.png -------------------------------------------------------------------------------- /sample_images/outline.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AyanKumarBhunia/semisupervised-FGSBIR/HEAD/sample_images/outline.jpeg -------------------------------------------------------------------------------- /sample_images/sketchxlogo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AyanKumarBhunia/semisupervised-FGSBIR/HEAD/sample_images/sketchxlogo.jpg -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Datasource local storage ignored files 5 | /../../../../:\PhD Works\semisupervised-FGSBIR\.idea/dataSources/ 6 | /dataSources.local.xml 7 | # Editor-based HTTP Client requests 8 | /httpRequests/ 9 | -------------------------------------------------------------------------------- 
/.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/semisupervised-FGSBIR.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /Photo_to_Sketch_2D_Attention/base_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from Image_Networks import * 3 | from Sketch_Networks import * 4 | from torch import optim 5 | import torch 6 | import time 7 | import torch.nn.functional as F 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | from utils import * 10 | import torchvision 11 | from dataset import get_sketchOnly_dataloader, get_dataloader 12 | from rasterize import rasterize_relative, to_stroke_list 13 | import math 14 | from rasterize import batch_rasterize_relative 15 | 16 | 17 | 18 | class Photo2Sketch_Base(nn.Module): 19 | 20 | def __init__(self, hp): 21 | super(Photo2Sketch_Base, self).__init__() 22 | self.Image_Encoder = EncoderCNN() 23 | # self.Image_Decoder = DecoderCNN() 24 | # self.Sketch_Encoder = EncoderRNN(hp) 25 | self.Sketch_Decoder = DecoderRNN2D(hp) 26 | self.hp = hp 27 | # self.apply(weights_init_normal) 28 | 29 | def freeze_weights(self): 30 | for name, x in self.named_parameters(): 31 | x.requires_grad = False 32 | 33 | 34 | def Unfreeze_weights(self): 35 | for name, x in self.named_parameters(): 36 | x.requires_grad = True 37 | 38 | 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /Photo_to_Sketch_2D_Attention/CVPR_18_Baseline/preprocess/preprocess_rdp.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import numpy as np 4 | from bresenham import bresenham 5 | from rasterize import rasterize_Sketch 6 | from PIL import Image 7 | from rdp import rdp 8 | 9 | if __name__ == '__main__': 10 | coordinate_path = os.path.join('/home/media/On_the_Fly/Code_ALL/Final_Dataset/ShoeV2/ShoeV2_Coordinate') 11 | # coordinate_path = r'E:\sketchPool\Sketch_Classification\TU_Berlin' 12 | with open(coordinate_path, 'rb') as fp: 13 | Coordinate = pickle.load(fp) 14 | 15 | # for key in Coordinate.keys(): 16 | key_list = list(Coordinate.keys()) 17 | 18 | for i_rdp in [3, 4]: #reversed(range(11)): 19 | rdp_simplified = {} 20 | max_points_old = [] 21 | max_points_new = [] 22 | 23 | if not os.path.exists(str(i_rdp)): 24 | os.makedirs(str(i_rdp)) 25 | 26 | for num, key in enumerate(key_list): 27 | 28 | print(i_rdp, num, key) 29 | sketch_points = Coordinate[key] 30 | sketch_points_orig = sketch_points 31 | 32 | sketch_points = sketch_points.astype(np.float) 33 | # sketch_points[:, :2] = sketch_points[:, :2] / np.array([800, 800]) 34 | # sketch_points[:, :2] = sketch_points[:, :2] * 256 35 | sketch_points = np.round(sketch_points) 36 | 37 | all_strokes = 
np.split(sketch_points, np.where(sketch_points[:, 2])[0] + 1, axis=0)[:-1] 38 | 39 | max_points_old.append(sketch_points_orig.shape) 40 | 41 | sketch_img_orig = rasterize_Sketch(sketch_points_orig) 42 | sketch_img_orig = Image.fromarray(sketch_img_orig).convert('RGB') 43 | # sketch_img_orig.show() 44 | sketch_img_orig.save(str(i_rdp) + '/' + str(num)+'.jpg') 45 | 46 | sketch_points_sampled_new = [] 47 | for stroke in all_strokes: 48 | stroke_new = rdp(stroke[:, :2], epsilon=i_rdp, algo="iter") 49 | stroke_new = np.hstack((stroke_new, np.zeros((stroke_new.shape[0], 1)))) 50 | stroke_new[-1, -1] = 1. 51 | # print(stroke_new.shape, stroke.shape) 52 | sketch_points_sampled_new.append(stroke_new) 53 | sketch_points_new = np.vstack(sketch_points_sampled_new) 54 | 55 | max_points_new.append(sketch_points_new.shape[0]) 56 | # print(sketch_points_orig.shape, sketch_points_new.shape) 57 | 58 | 59 | sketch_img_orig = rasterize_Sketch(sketch_points_new) 60 | sketch_img_orig = Image.fromarray(sketch_img_orig).convert('RGB') 61 | # sketch_img_orig.show() 62 | sketch_img_orig.save(str(i_rdp) + '/' + str(num)+'Low_.jpg') 63 | 64 | # if sketch_points_new.shape[0] > 200: 65 | # combined_image = np.concatenate(( sketch_img_orig, sketch_img_rdp), axis=1) 66 | # combined_image = Image.fromarray(combined_image).convert('RGB') 67 | # combined_image.save('./saved_folder2/image_' + str(num) + '_@' + 68 | # str(sketch_points_orig.shape[0]) + 69 | # '_' + 70 | # str(sketch_points_new.shape[0]) + '.jpg') 71 | 72 | rdp_simplified[key] = sketch_points_new 73 | 74 | # print(sketch_img.shape) 75 | 76 | print('Max number of Points Old: {}'.format(max(max_points_old))) 77 | print('Max number of Points New: {}'.format(max(max_points_new))) 78 | 79 | with open('ShoeV2_RDP_' + str(i_rdp), 'wb') as fp: 80 | pickle.dump(rdp_simplified, fp) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # More Photos are All You Need: Semi-Supervised Learning for Fine-Grained Sketch Based Image Retrieval, CVPR 2021. 2 | **Ayan Kumar Bhunia**, Pinaki Nath Chowdhury, Aneeshan Sain, Yongxin Yang, Tao Xiang, Yi-Zhe Song, “More Photos are All You Need: Semi-Supervised Learning for Fine-Grained Sketch Based Image Retrieval”, IEEE Conf. on Computer Vision and Pattern Recognition (**CVPR**), 2021. 3 | 4 | ## SketchX_ShoeV2/ChairV2 Dataset: [Download](https://drive.google.com/file/d/1frltfiEd9ymnODZFHYrbg741kfys1rq1/view?usp=sharing) 5 | 6 | ## Abstract 7 | A fundamental challenge faced by existing Fine-Grained Sketch-Based Image Retrieval (FG-SBIR) models is data scarcity -- model performances are largely bottlenecked by the lack of sketch-photo pairs. Whilst the number of photos can be easily scaled, each corresponding sketch still needs to be individually produced. In this paper, we aim to mitigate such an upper-bound on sketch data, and study whether unlabelled photos alone (of which there are many) can be cultivated for performance gains. In particular, we introduce a novel semi-supervised framework for cross-modal retrieval that can additionally leverage large-scale unlabelled photos to account for data scarcity. At the centre of our semi-supervision design is a sequential photo-to-sketch generation model that aims to generate paired sketches for unlabelled photos.
Importantly, we further introduce a discriminator-guided mechanism to guard against unfaithful generation, together with a distillation-loss-based regularizer to provide tolerance against noisy training samples. Last but not least, we treat generation and retrieval as two conjugate problems, where a joint learning procedure is devised for each module to mutually benefit from the other. Extensive experiments show that our semi-supervised model yields a significant performance boost over the state-of-the-art supervised alternatives, as well as existing methods that can exploit unlabelled photos for FG-SBIR. 8 | 9 | ## Outline 10 | ![Outline](./sample_images/outline.jpeg) 11 | 12 | **Figure:** Our proposed method additionally leverages large-scale photos without any manually labelled paired sketches to improve FG-SBIR performance. Moreover, we show that the two conjugate processes, photo-to-sketch generation and fine-grained SBIR, can improve each other through joint training. 13 | 14 | ## Joint Architecture 15 | 16 | ![Framework](./sample_images/framework.png) 17 | **Figure:** Our framework: an FG-SBIR model leverages large-scale unlabelled photos using a sequential photo-to-sketch generation model along with labelled pairs. Discriminator-guided instance-wise weighting and a distillation loss are used to guard against the noisy generated data. Simultaneously, the photo-to-sketch generation model learns by taking rewards from the FG-SBIR model and the discriminator via policy gradient (over both labelled and unlabelled data), together with a supervised VAE loss over the labelled data. Note that rasterization (vector to raster format) is a non-differentiable operation. 18 | 19 | 20 | ## Examples 21 | ![Framework](./sample_images/example.jpg) 22 | **Figure:** Qualitative results of our photo-to-sketch generation process, where the sketch is shown with its attention map at progressive instances. 23 | 24 | 25 | ## Citation 26 | 27 | If you find this article useful in your research, please consider citing: 28 | ``` 29 | @InProceedings{semi-fgsbir, 30 | author = {Ayan Kumar Bhunia and Pinaki Nath Chowdhury and Aneeshan Sain and Yongxin Yang and Tao Xiang and Yi-Zhe Song}, 31 | title = {More Photos are All You Need: Semi-Supervised Learning for Fine-Grained Sketch Based Image Retrieval}, 32 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 33 | month = {June}, 34 | year = {2021} 35 | } 36 | ``` 37 | ## Work done at [SketchX Lab](http://sketchx.ai/), CVSSP, University of Surrey.
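The “sequential” generator referenced above operates on Sketch-RNN style five-point sketches, where each step is `(dx, dy, p_continue, p_end_of_stroke, p_end_of_sketch)`; generated vector sketches are then rasterized (a non-differentiable step) before reaching the retrieval model. Below is a minimal, self-contained illustration of that format, mirroring `to_FivePoint` from `Photo_to_Sketch_2D_Attention/rasterize.py` (the toy coordinates are placeholders, not the authors' data):

```python
import numpy as np

def to_five_point(stroke3, max_seq_len=130):
    """Convert a (N, 3) array of (dx, dy, pen_lift) points into padded stroke-5 format,
    mirroring rasterize.to_FivePoint."""
    n = len(stroke3)
    seq = np.zeros((max_seq_len, 5))
    seq[:n, :2] = stroke3[:, :2]   # relative pen offsets
    seq[:n, 3] = stroke3[:, 2]     # pen lifted -> end of current stroke
    seq[:n, 2] = 1 - seq[:n, 3]    # otherwise the pen stays on paper
    seq[n - 1:, 4] = 1             # end-of-sketch flag, also used for padding
    seq[n - 1, 2:4] = 0
    return seq

# A toy square drawn with two strokes (placeholder coordinates).
toy = np.array([[5, 0, 0], [0, 5, 1], [-5, 0, 0], [0, -5, 1]], dtype=float)
print(to_five_point(toy, max_seq_len=6))
```

In `Photo_to_Sketch_2D_Attention/model.py`, the target sequences are terminated with the end token `[0, 0, 0, 0, 1]` before the reconstruction loss is computed.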
38 | -------------------------------------------------------------------------------- /Photo_to_Sketch_2D_Attention/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from model import Photo2Sketch 3 | from dataset import get_dataloader 4 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 5 | import argparse 6 | import random 7 | from matplotlib import pyplot as plt 8 | from rasterize import batch_rasterize_relative 9 | from torchvision.utils import save_image 10 | import time 11 | from dataset import get_sketchOnly_dataloader 12 | 13 | 14 | if __name__ == "__main__": 15 | 16 | parser = argparse.ArgumentParser(description='Photo2Sketch') 17 | 18 | parser.add_argument('--setup', type=str, default='QMUL', help='QuickDraw vs QMUL') 19 | parser.add_argument('--batchsize', type=int, default=1) 20 | parser.add_argument('--nThreads', type=int, default=8) 21 | 22 | parser.add_argument('--max_epoch', type=int, default=1) 23 | parser.add_argument('--eval_freq_iter', type=int, default=1000) 24 | 25 | 26 | parser.add_argument('--enc_rnn_size', default=256) 27 | parser.add_argument('--dec_rnn_size', default=512) 28 | parser.add_argument('--z_size', default=128) 29 | 30 | parser.add_argument('--num_mixture', default=20) 31 | parser.add_argument('--input_dropout_prob', default=0.9) 32 | parser.add_argument('--output_dropout_prob', default=0.9) 33 | parser.add_argument('--batch_size_sketch_rnn', default=100) 34 | 35 | parser.add_argument('--kl_weight_start', default=0.01) 36 | parser.add_argument('--kl_decay_rate', default=0.99995) 37 | parser.add_argument('--kl_tolerance', default=0.2) 38 | parser.add_argument('--kl_weight', default=1.0) 39 | 40 | parser.add_argument('--learning_rate', default=0.0001) 41 | parser.add_argument('--decay_rate', default=0.9999) 42 | parser.add_argument('--min_learning_rate', default=0.00001) 43 | parser.add_argument('--grad_clip', default=1.) 
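    # The schedule-related settings above are consumed inside model.Image2Sketch_Train:
    #   lr(step)        = (learning_rate - min_learning_rate) * decay_rate ** step + min_learning_rate
    #   kl_weight(step) = kl_weight - (kl_weight - kl_weight_start) * kl_decay_rate ** step
    # and the KL divergence term is clamped from below by kl_tolerance before being weighted.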
44 | 45 | hp = parser.parse_args() 46 | 47 | 48 | print(hp) 49 | model = Photo2Sketch(hp) 50 | model.to(device) 51 | model.load_state_dict(torch.load('./modelCVPR21/QMUL/model_photo2Sketch_QMUL_2Dattention_8000_.pth')) 52 | 53 | 54 | step = 0 55 | current_loss = 1e+10 56 | 57 | """ Model Training Image2Sketch """ 58 | 59 | if hp.setup == 'QuickDraw': 60 | 61 | dataloader = get_sketchOnly_dataloader(hp) 62 | 63 | for step in range(100000*2): 64 | sample = dataloader.train_batch() 65 | 66 | rgb_image = sample['photo'].to(device) 67 | sketch_vector = sample['sketch_vector'].to(device) # Seq_Len, Batch, Feature 68 | length_sketch = sample['length'].to(device) 69 | 70 | sup_p2s_loss, kl_cost_rgb, total_loss = model.Image2Sketch_Train(rgb_image, sketch_vector, length_sketch, step) 71 | 72 | if total_loss.item() < current_loss: 73 | torch.save(model.state_dict(), './modelCVPR21/model_photo2Sketch_Pretraining2D_K.pth') 74 | current_loss = total_loss.item() 75 | 76 | elif hp.setup == 'QMUL': 77 | 78 | dataloader_Train, dataloader_Test = get_dataloader(hp) 79 | 80 | for i_epoch in range(hp.max_epoch): 81 | for z_num, batch_data in enumerate(dataloader_Train): 82 | 83 | rgb_image = batch_data['photo'].to(device) 84 | sketch_vector = batch_data['sketch_vector'].to(device).permute(1, 0, 2).float() # Seq_Len, Batch, Feature 85 | length_sketch = batch_data['length'].to(device) - 1 # TODO: Relative coord has one less 86 | sketch_name = batch_data['sketch_path'][0] 87 | 88 | sup_p2s_loss, kl_cost_rgb, total_loss = model.Image2Sketch_Train(rgb_image, sketch_vector, 89 | length_sketch, step, sketch_name) 90 | step += 1 91 | 92 | print(z_num) 93 | # if total_loss.item() < current_loss: 94 | # torch.save(model.state_dict(), './modelCVPR21/model_photo2Sketch_QMUL_2Dattention_Pre.pth') 95 | # current_loss = total_loss.item() 96 | 97 | # if step % 1000 == 0: 98 | # torch.save(model.state_dict(), './modelCVPR21/QMUL/model_photo2Sketch_QMUL_2Dattention_' + str(step) + '_.pth') 99 | # current_loss = total_loss.item() -------------------------------------------------------------------------------- /Photo_to_Sketch_2D_Attention/CVPR_18_Baseline/Image_Networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as backbone_ 4 | 5 | 6 | class EncoderCNN(nn.Module): 7 | def __init__(self, hp=None): 8 | super(EncoderCNN, self).__init__() 9 | self.feature = Unet_Encoder(in_channels=3) 10 | self.fc_mu = nn.Linear(512, 128) 11 | self.fc_std = nn.Linear(512, 128) 12 | 13 | def forward(self, x): 14 | x = self.feature(x) 15 | mean = self.fc_mu(x) 16 | log_var = self.fc_std(x) 17 | posterior_dist = torch.distributions.Normal(mean, torch.exp(0.5 * log_var)) 18 | return posterior_dist 19 | 20 | class DecoderCNN(nn.Module): 21 | def __init__(self, hp=None): 22 | super(DecoderCNN, self).__init__() 23 | self.model = Unet_Decoder(out_channels=3) 24 | def forward(self, x): 25 | return self.model(x) 26 | 27 | 28 | class Unet_Encoder(nn.Module): 29 | def __init__(self, in_channels=3): 30 | super(Unet_Encoder, self).__init__() 31 | 32 | self.down_1 = Unet_DownBlock(in_channels, 32, normalize=False) 33 | self.down_2 = Unet_DownBlock(32, 64) 34 | self.down_3 = Unet_DownBlock(64, 128) 35 | self.down_4 = Unet_DownBlock(128, 256) 36 | self.down_5 = Unet_DownBlock(256, 256) 37 | self.linear_encoder = nn.Linear(256 * 8 * 8, 512) 38 | self.dropout = nn.Dropout(0.5) 39 | 40 | def forward(self, x): 41 | x = self.down_1(x) 42 | x = 
self.down_2(x) 43 | x = self.down_3(x) 44 | x = self.down_4(x) 45 | x = self.down_5(x) 46 | x = torch.flatten(x, start_dim=1) 47 | x = self.linear_encoder(x) 48 | x = self.dropout(x) 49 | return x 50 | 51 | 52 | class Unet_Decoder(nn.Module): 53 | def __init__(self, out_channels=3): 54 | super(Unet_Decoder, self).__init__() 55 | self.linear_1 = nn.Linear(128, 8*8*256) 56 | self.dropout = nn.Dropout(0.5) 57 | self.deconv_1 = Unet_UpBlock(256, 256) 58 | self.deconv_2 = Unet_UpBlock(256, 128) 59 | self.deconv_3 = Unet_UpBlock(128, 64) 60 | self.deconv_4 = Unet_UpBlock(64, 32) 61 | self.final_image = nn.Sequential(*[nn.ConvTranspose2d(32, out_channels, 62 | kernel_size=4, stride=2, 63 | padding=1), nn.Tanh()]) 64 | 65 | def forward(self, x): 66 | x = self.linear_1(x) 67 | x = x.view(-1, 256, 8, 8) 68 | x = self.dropout(x) 69 | x = self.deconv_1(x) 70 | x = self.deconv_2(x) 71 | x = self.deconv_3(x) 72 | x = self.deconv_4(x) 73 | x = self.final_image(x) 74 | return x 75 | 76 | 77 | class Unet_UpBlock(nn.Module): 78 | def __init__(self, inner_nc, outer_nc): 79 | super(Unet_UpBlock, self).__init__() 80 | layers = [ 81 | nn.ConvTranspose2d(inner_nc, outer_nc, 4, 2, 1, bias=True), 82 | nn.InstanceNorm2d(outer_nc), 83 | nn.ReLU(inplace=True), 84 | ] 85 | self.model = nn.Sequential(*layers) 86 | 87 | def forward(self, x): 88 | return self.model(x) 89 | 90 | 91 | class Unet_DownBlock(nn.Module): 92 | def __init__(self, inner_nc, outer_nc, normalize=True): 93 | super(Unet_DownBlock, self).__init__() 94 | layers = [nn.Conv2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=True)] 95 | if normalize: 96 | layers.append(nn.InstanceNorm2d(outer_nc)) 97 | layers.append(nn.LeakyReLU(0.2, True)) 98 | self.model = nn.Sequential(*layers) 99 | 100 | def forward(self, x): 101 | return self.model(x) 102 | 103 | 104 | class VGG_encoder(nn.Module): 105 | def __init__(self, hp): 106 | super(VGG_encoder, self).__init__() 107 | self.feature = backbone_.vgg16(pretrained=True).features 108 | self.pool_method = nn.AdaptiveMaxPool2d(1) 109 | self.dropout = nn.Dropout(0.5) 110 | 111 | def forward(self, x): 112 | x = self.backbone(input) 113 | x = self.pool_method(x).view(-1, 512) 114 | x = self.dropout(x) 115 | return x 116 | 117 | 118 | def weights_init_normal(m): 119 | classname = m.__class__.__name__ 120 | if classname.find("Conv") != -1: 121 | torch.nn.init.normal_(m.weight.data, 0.0, 0.02) 122 | if hasattr(m, "bias") and m.bias is not None: 123 | torch.nn.init.constant_(m.bias.data, 0.0) 124 | elif classname.find("BatchNorm2d") != -1: 125 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 126 | torch.nn.init.constant_(m.bias.data, 0.0) 127 | 128 | if __name__ == '__main__': 129 | 130 | pass -------------------------------------------------------------------------------- /Photo_to_Sketch_2D_Attention/Image_Networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as backbone_ 4 | 5 | 6 | 7 | 8 | class EncoderCNN(nn.Module): 9 | def __init__(self, hp=None): 10 | super(EncoderCNN, self).__init__() 11 | self.feature = backbone_.vgg16(pretrained=True).features 12 | self.pool_method = nn.AdaptiveMaxPool2d(1) 13 | self.fc_mu = nn.Linear(512, 128) 14 | self.fc_std = nn.Linear(512, 128) 15 | 16 | def forward(self, x): 17 | backbone_feature = self.feature(x) 18 | x = torch.flatten(self.pool_method(backbone_feature), start_dim=1) 19 | mean = self.fc_mu(x) 20 | log_var = self.fc_std(x) 21 | posterior_dist = 
torch.distributions.Normal(mean, torch.exp(0.5 * log_var)) 22 | return backbone_feature, posterior_dist 23 | 24 | 25 | # class EncoderCNN(nn.Module): 26 | # def __init__(self, hp=None): 27 | # super(EncoderCNN, self).__init__() 28 | # self.feature = Unet_Encoder(in_channels=3) 29 | # self.fc_mu = nn.Linear(512, 128) 30 | # self.fc_std = nn.Linear(512, 128) 31 | # 32 | # def forward(self, x): 33 | # x = self.feature(x) 34 | # mean = self.fc_mu(x) 35 | # log_var = self.fc_std(x) 36 | # posterior_dist = torch.distributions.Normal(mean, torch.exp(0.5 * log_var)) 37 | # return posterior_dist 38 | 39 | class DecoderCNN(nn.Module): 40 | def __init__(self, hp=None): 41 | super(DecoderCNN, self).__init__() 42 | self.model = Unet_Decoder(out_channels=3) 43 | def forward(self, x): 44 | return self.model(x) 45 | 46 | 47 | class Unet_Encoder(nn.Module): 48 | def __init__(self, in_channels=3): 49 | super(Unet_Encoder, self).__init__() 50 | 51 | self.down_1 = Unet_DownBlock(in_channels, 32, normalize=False) 52 | self.down_2 = Unet_DownBlock(32, 64) 53 | self.down_3 = Unet_DownBlock(64, 128) 54 | self.down_4 = Unet_DownBlock(128, 256) 55 | self.down_5 = Unet_DownBlock(256, 256) 56 | self.linear_encoder = nn.Linear(256 * 8 * 8, 512) 57 | self.dropout = nn.Dropout(0.5) 58 | 59 | def forward(self, x): 60 | x = self.down_1(x) 61 | x = self.down_2(x) 62 | x = self.down_3(x) 63 | x = self.down_4(x) 64 | x = self.down_5(x) 65 | x = torch.flatten(x, start_dim=1) 66 | x = self.linear_encoder(x) 67 | x = self.dropout(x) 68 | return x 69 | 70 | 71 | class Unet_Decoder(nn.Module): 72 | def __init__(self, out_channels=3): 73 | super(Unet_Decoder, self).__init__() 74 | self.linear_1 = nn.Linear(128, 8*8*256) 75 | self.dropout = nn.Dropout(0.5) 76 | self.deconv_1 = Unet_UpBlock(256, 256) 77 | self.deconv_2 = Unet_UpBlock(256, 128) 78 | self.deconv_3 = Unet_UpBlock(128, 64) 79 | self.deconv_4 = Unet_UpBlock(64, 32) 80 | self.final_image = nn.Sequential(*[nn.ConvTranspose2d(32, out_channels, 81 | kernel_size=4, stride=2, 82 | padding=1), nn.Tanh()]) 83 | 84 | def forward(self, x): 85 | x = self.linear_1(x) 86 | x = x.view(-1, 256, 8, 8) 87 | x = self.dropout(x) 88 | x = self.deconv_1(x) 89 | x = self.deconv_2(x) 90 | x = self.deconv_3(x) 91 | x = self.deconv_4(x) 92 | x = self.final_image(x) 93 | return x 94 | 95 | 96 | class Unet_UpBlock(nn.Module): 97 | def __init__(self, inner_nc, outer_nc): 98 | super(Unet_UpBlock, self).__init__() 99 | layers = [ 100 | nn.ConvTranspose2d(inner_nc, outer_nc, 4, 2, 1, bias=True), 101 | nn.InstanceNorm2d(outer_nc), 102 | nn.ReLU(inplace=True), 103 | ] 104 | self.model = nn.Sequential(*layers) 105 | 106 | def forward(self, x): 107 | return self.model(x) 108 | 109 | 110 | class Unet_DownBlock(nn.Module): 111 | def __init__(self, inner_nc, outer_nc, normalize=True): 112 | super(Unet_DownBlock, self).__init__() 113 | layers = [nn.Conv2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=True)] 114 | if normalize: 115 | layers.append(nn.InstanceNorm2d(outer_nc)) 116 | layers.append(nn.LeakyReLU(0.2, True)) 117 | self.model = nn.Sequential(*layers) 118 | 119 | def forward(self, x): 120 | return self.model(x) 121 | 122 | 123 | class VGG_encoder(nn.Module): 124 | def __init__(self, hp): 125 | super(VGG_encoder, self).__init__() 126 | self.feature = backbone_.vgg16(pretrained=True).features 127 | self.pool_method = nn.AdaptiveMaxPool2d(1) 128 | self.dropout = nn.Dropout(0.5) 129 | 130 | def forward(self, x): 131 | x = self.backbone(input) 132 | x = self.pool_method(x).view(-1, 512) 133 
| x = self.dropout(x) 134 | return x 135 | 136 | 137 | def weights_init_normal(m): 138 | classname = m.__class__.__name__ 139 | if classname.find("Conv") != -1: 140 | torch.nn.init.normal_(m.weight.data, 0.0, 0.02) 141 | if hasattr(m, "bias") and m.bias is not None: 142 | torch.nn.init.constant_(m.bias.data, 0.0) 143 | elif classname.find("BatchNorm2d") != -1: 144 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 145 | torch.nn.init.constant_(m.bias.data, 0.0) 146 | 147 | if __name__ == '__main__': 148 | 149 | pass -------------------------------------------------------------------------------- /Photo_to_Sketch_2D_Attention/CVPR_18_Baseline/utils.py: -------------------------------------------------------------------------------- 1 | from matplotlib import pyplot as plt 2 | import torch 3 | use_cuda = True 4 | from IPython.display import SVG, display 5 | import numpy as np 6 | import svgwrite 7 | from six.moves import xrange 8 | import math 9 | import torch.nn as nn 10 | import torchvision 11 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 12 | import torch.nn.functional as F 13 | import torch 14 | import torchvision 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | from torch.utils.tensorboard import SummaryWriter 17 | import os 18 | import shutil 19 | 20 | def to_normal_strokes(big_stroke): 21 | """Convert from stroke-5 format (from sketch-rnn paper) back to stroke-3.""" 22 | 23 | l = 0 24 | for i in range(len(big_stroke)): 25 | if big_stroke[i, 4] > 0: 26 | l = i 27 | break 28 | if l == 0: 29 | l = len(big_stroke)-1 30 | result = np.zeros((l+1, 3)) 31 | result[:, 0:2] = big_stroke[0:l+1, 0:2] 32 | result[:, 2] = big_stroke[0:l+1, 3] 33 | result[-1, -1] = 1. 34 | return result 35 | 36 | def get_bounds(data, factor=10): 37 | """Return bounds of data.""" 38 | min_x = 0 39 | max_x = 0 40 | min_y = 0 41 | max_y = 0 42 | 43 | abs_x = 0 44 | abs_y = 0 45 | for i in range(len(data)): 46 | x = float(data[i, 0]) / factor 47 | y = float(data[i, 1]) / factor 48 | abs_x += x 49 | abs_y += y 50 | min_x = min(min_x, abs_x) 51 | min_y = min(min_y, abs_y) 52 | max_x = max(max_x, abs_x) 53 | max_y = max(max_y, abs_y) 54 | 55 | return (min_x, max_x, min_y, max_y) 56 | 57 | 58 | 59 | def transfer_ImageNomralization(x, type='to_Gen'): 60 | # https://discuss.pytorch.org/t/how-to-normalize-multidimensional-tensor/65304 61 | #to_Gen (-1, 1) vs to_Recog (ImageNet Normalize) 62 | if type == 'to_Gen': 63 | # First Unnormalize 64 | mean = torch.tensor([-0.485/0.229, -0.456/0.224, -0.406/0.225]).to(device) 65 | std = torch.tensor([1/0.229, 1/0.224, 1/0.225]).to(device) 66 | x.sub_(mean[None, :, None, None]).div_(std[None, :, None, None]) 67 | # Then Normalize Again 68 | mean = torch.tensor([0.5, 0.5, 0.5]).to(device) 69 | std = torch.tensor([0.5, 0.5, 0.5]).to(device) 70 | x.sub_(mean[None, :, None, None]).div_(std[None, :, None, None]) 71 | 72 | elif type == 'to_Recog': 73 | # First Unnormalize 74 | mean = torch.tensor([-1.0, -1.0, -1.0]).to(device) 75 | std = torch.tensor([1/0.5, 1/0.5, 1/0.5]).to(device) 76 | x.sub_(mean[None, :, None, None]).div_(std[None, :, None, None]) 77 | # Then Normalize Again 78 | mean = torch.tensor([0.485, 0.456, 0.406]).to(device) 79 | std = torch.tensor([0.229, 0.224, 0.225]).to(device) 80 | x.sub_(mean[None, :, None, None]).div_(std[None, :, None, None]) 81 | return x 82 | 83 | def sample_next_state(output, hp, temperature =0.01): 84 | 85 | def adjust_temp(pi_pdf): 86 | pi_pdf = np.log(pi_pdf)/temperature 87 | pi_pdf -= 
pi_pdf.max() 88 | pi_pdf = np.exp(pi_pdf) 89 | pi_pdf /= pi_pdf.sum() 90 | return pi_pdf 91 | 92 | [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen_logits] = output 93 | # get mixture indices: 94 | o_pi = o_pi.data[0,:].cpu().numpy() 95 | o_pi = adjust_temp(o_pi) 96 | pi_idx = np.random.choice(hp.num_mixture, p=o_pi) 97 | # get pen state: 98 | o_pen = F.softmax(o_pen_logits, dim=-1) 99 | o_pen = o_pen.data[0,:].cpu().numpy() 100 | pen = adjust_temp(o_pen) 101 | pen_idx = np.random.choice(3, p=pen) 102 | # get mixture params: 103 | o_mu1 = o_mu1.data[0,pi_idx].item() 104 | o_mu2 = o_mu2.data[0,pi_idx].item() 105 | o_sigma1 = o_sigma1.data[0,pi_idx].item() 106 | o_sigma2 = o_sigma2.data[0,pi_idx].item() 107 | o_corr = o_corr.data[0,pi_idx].item() 108 | x,y = sample_bivariate_normal(o_mu1,o_mu2,o_sigma1,o_sigma2,o_corr, temperature = temperature, greedy=False) 109 | next_state = torch.zeros(5) 110 | next_state[0] = x 111 | next_state[1] = y 112 | next_state[pen_idx+2] = 1 113 | return next_state.to(device).view(1,1,-1), next_state 114 | 115 | 116 | def sample_bivariate_normal(mu_x, mu_y, sigma_x, sigma_y, rho_xy, temperature = 0.2, greedy=False): 117 | # inputs must be floats 118 | if greedy: 119 | return mu_x, mu_y 120 | mean = [mu_x, mu_y] 121 | sigma_x *= np.sqrt(temperature) #confusion 122 | sigma_y *= np.sqrt(temperature) #confusion 123 | cov = [[sigma_x * sigma_x, rho_xy * sigma_x * sigma_y], \ 124 | [rho_xy * sigma_x * sigma_y, sigma_y * sigma_y]] 125 | x = np.random.multivariate_normal(mean, cov, 1) 126 | return x[0][0], x[0][1] 127 | 128 | class Visualizer: 129 | def __init__(self, name = 'Photo2Sketch'): 130 | 131 | if os.path.exists('Tensorboard_' + name): 132 | shutil.rmtree('Tensorboard_' + name) 133 | 134 | self.writer = SummaryWriter('Tensorboard_' + name) 135 | 136 | self.mean = torch.tensor([-1.0, -1.0, -1.0]).to(device) 137 | self.std = torch.tensor([1 / 0.5, 1 / 0.5, 1 / 0.5]).to(device) 138 | 139 | def vis_image(self, visularize, step): 140 | for keys, value in visularize.items(): 141 | #print(keys,value.size()) 142 | value.sub_(self.mean[None, :, None, None]).div_(self.std[None, :, None, None]) 143 | visularize[keys] = torchvision.utils.make_grid(value) 144 | self.writer.add_image('{}'.format(keys), visularize[keys], step) 145 | 146 | 147 | def plot_scalars(self, scalars, step): 148 | 149 | for keys, value in scalars.items(): 150 | #print(keys,value.size()) 151 | self.writer.add_scalar('{}'.format(keys), scalars[keys], step) -------------------------------------------------------------------------------- /Photo_to_Sketch_2D_Attention/CVPR_18_Baseline/Sketch_Networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 4 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 5 | # from utils import * 6 | import torch.nn as nn 7 | import numpy as np 8 | 9 | class EncoderRNN(nn.Module): 10 | def __init__(self, hp): 11 | super(EncoderRNN, self).__init__() 12 | self.lstm = nn.LSTM(5, hp.enc_rnn_size, dropout=hp.input_dropout_prob, bidirectional=True) 13 | self.fc_mu = nn.Linear(2*hp.enc_rnn_size, hp.z_size) 14 | self.fc_sigma = nn.Linear(2*hp.enc_rnn_size, hp.z_size) 15 | 16 | def forward(self, x, Seq_Len=None): 17 | x = pack_padded_sequence(x, Seq_Len, enforce_sorted=False) 18 | _, (h_n, _) = self.lstm(x.float()) 19 | h_n = h_n.permute(1,0,2).reshape(h_n.shape[1], -1) 20 | mean = 
self.fc_mu(h_n) 21 | log_var = self.fc_sigma(h_n) 22 | posterior_dist = torch.distributions.Normal(mean, torch.exp(0.5 * log_var)) 23 | return posterior_dist 24 | 25 | 26 | 27 | class DecoderRNN(nn.Module): 28 | def __init__(self, hp): 29 | super(DecoderRNN, self).__init__() 30 | self.fc_hc = nn.Linear(hp.z_size, 2 * hp.dec_rnn_size) 31 | self.lstm = nn.LSTM(hp.z_size + 5, hp.dec_rnn_size, dropout=hp.output_dropout_prob) 32 | self.fc_params = nn.Linear(hp.dec_rnn_size, 6 * hp.num_mixture + 3) 33 | self.hp = hp 34 | 35 | def forward(self, inputs, z_vector, seq_len = None, hidden_cell=None, isTrain = True, get_deterministic = True): 36 | 37 | self.training = isTrain 38 | if hidden_cell is None: 39 | hidden, cell = torch.split(F.tanh(self.fc_hc(z_vector)), self.hp.dec_rnn_size, 1) 40 | hidden_cell = (hidden.unsqueeze(0).contiguous(), cell.unsqueeze(0).contiguous()) 41 | 42 | if seq_len is None: 43 | # seq_len = torch.tensor([1]).type(torch.int64).to(device) 44 | seq_len = torch.ones(inputs.shape[1]).type(torch.int64).to(device) 45 | 46 | inputs = pack_padded_sequence(inputs, seq_len, enforce_sorted=False) 47 | outputs, (hidden, cell) = self.lstm(inputs, hidden_cell) 48 | outputs, _ = pad_packed_sequence(outputs) 49 | 50 | if self.training: 51 | if outputs.shape[0] != (self.hp.max_seq_len + 1): 52 | pad = torch.zeros(outputs.shape[-1]).repeat(self.hp.max_seq_len + 1 - outputs.shape[0], outputs.shape[1], 1).cuda() 53 | outputs = torch.cat((outputs, pad), dim=0) 54 | y_output = self.fc_params(outputs.permute(1,0,2)) 55 | else: 56 | y_output = self.fc_params(hidden.permute(1,0,2)) 57 | 58 | z_pen_logits = y_output[:, :, 0:3] 59 | z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr = torch.chunk(y_output[:, :, 3:], 6, 2) 60 | z_pi = F.softmax(z_pi, dim=-1) 61 | z_sigma1 = torch.exp(z_sigma1) 62 | z_sigma2 = torch.exp(z_sigma2) 63 | z_corr = torch.tanh(z_corr) 64 | 65 | 66 | if (not self.training) and get_deterministic: 67 | batch_size = z_pi.shape[0] 68 | z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr, z_pen_logits = z_pi.reshape(-1, 20), z_mu1.reshape(-1, 20), z_mu2.reshape(-1, 20), \ 69 | z_sigma1.reshape(-1, 20), z_sigma2.reshape(-1, 20), z_corr.reshape(-1, 20), z_pen_logits.reshape(-1, 3) 70 | 71 | recons_output = torch.zeros(batch_size, 5).to(device) 72 | z_pi_idx = z_pi.argmax(dim=-1) 73 | z_pen_idx = z_pen_logits.argmax(-1) 74 | recons_output[:, 0] = z_mu1[range(z_mu1.shape[0]), z_pi_idx] 75 | recons_output[:, 1] = z_mu2[range(z_mu2.shape[0]), z_pi_idx] 76 | 77 | recons_output[range(z_mu1.shape[0]), z_pen_idx + 2] = 1. 
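            # Deterministic decoding: take the mean of the most probable mixture component
            # for (dx, dy) and the argmax pen state, instead of sampling from the GMM.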
78 | 79 | return recons_output.unsqueeze(0).data, (hidden, cell) 80 | 81 | return [z_pi.reshape(-1, 20), z_mu1.reshape(-1, 20), z_mu2.reshape(-1, 20), \ 82 | z_sigma1.reshape(-1, 20), z_sigma2.reshape(-1, 20), z_corr.reshape(-1, 20), z_pen_logits.reshape(-1, 3)], (hidden, cell) 83 | 84 | 85 | def torch_2d_normal(x1, x2, mu1, mu2, s1, s2, rho): 86 | """Returns result of eq # 24 of http://arxiv.org/abs/1308.0850.""" 87 | norm1 = x1 - mu1 88 | norm2 = x2 - mu2 89 | s1s2 = s1 * s2 90 | 91 | z_1 = (norm1 / s1) ** 2 92 | z_2 = (norm2 / s2) ** 2 93 | z1_z2 = (norm1 * norm2) / s1s2 94 | 95 | z = z_1 + z_2 - 2 * rho * z1_z2 96 | neg_rho = 1 - rho ** 2 97 | result = torch.exp(-z / (2 * neg_rho)) 98 | denom = 2 * np.pi * s1s2 * torch.sqrt(neg_rho) 99 | return result / denom 100 | 101 | 102 | def sketch_reconstruction_loss(output, x_input): 103 | # x_input = 104 | # Ouput = Predicted 123 parameters from decoder = Batch*Max_seq_len, 20 105 | [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen_logits] = output 106 | [x1_data, x2_data, eos_data, eoc_data, cont_data] = torch.chunk(x_input.reshape(-1, 5), 5, 1) 107 | pen_data = torch.cat([eos_data, eoc_data, cont_data], 1) 108 | mask = 1.0 - pen_data[:, 2] # use training data for this 109 | 110 | result0 = torch_2d_normal(x1_data, x2_data, o_mu1, o_mu2, o_sigma1, o_sigma2, 111 | o_corr) 112 | epsilon = 1e-6 113 | 114 | result1 = torch.sum(result0 * o_pi, dim=1) # ? unsqueeae(-1) ?? 115 | result1 = -torch.log(result1 + epsilon) # avoid log(0) 116 | 117 | result2 = F.cross_entropy(o_pen_logits, pen_data.argmax(1), reduction='none') 118 | 119 | result = mask * result1 + mask * result2 120 | # result = result1 + result2 121 | 122 | return result.mean() 123 | 124 | def set_learninRate(optimizer, curr_learning_rate): 125 | for g in optimizer.param_groups: 126 | g['lr'] = curr_learning_rate 127 | -------------------------------------------------------------------------------- /Photo_to_Sketch_2D_Attention/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from Image_Networks import * 3 | from Sketch_Networks import * 4 | from torch import optim 5 | import torch 6 | import time 7 | import torch.nn.functional as F 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | from utils import * 10 | import torchvision 11 | from dataset import get_sketchOnly_dataloader 12 | from rasterize import rasterize_relative, to_stroke_list 13 | import math 14 | from rasterize import batch_rasterize_relative 15 | from base_model import Photo2Sketch_Base 16 | from torchvision.utils import save_image 17 | import os 18 | 19 | 20 | class Photo2Sketch(Photo2Sketch_Base): 21 | def __init__(self, hp): 22 | 23 | Photo2Sketch_Base.__init__(self, hp) 24 | self.train_params = self.parameters() 25 | self.optimizer = optim.Adam(self.train_params, hp.learning_rate, betas=(0.5, 0.999)) 26 | # self.visualizer = Visualizer() 27 | 28 | def Image2Sketch_Train(self, rgb_image, sketch_vector, length_sketch, step, sketch_name): 29 | 30 | self.train() 31 | self.optimizer.zero_grad() 32 | 33 | curr_learning_rate = ((self.hp.learning_rate - self.hp.min_learning_rate) * 34 | (self.hp.decay_rate) ** step + self.hp.min_learning_rate) 35 | curr_kl_weight = (self.hp.kl_weight - (self.hp.kl_weight - self.hp.kl_weight_start) * 36 | (self.hp.kl_decay_rate) ** step) 37 | 38 | 39 | """ Encoding the Input """ 40 | backbone_feature, rgb_encoded_dist = self.Image_Encoder(rgb_image) 41 | rgb_encoded_dist_z_vector = 
rgb_encoded_dist.rsample() 42 | 43 | """ Ditribution Matching Loss """ 44 | prior_distribution = torch.distributions.Normal(torch.zeros_like(rgb_encoded_dist.mean), 45 | torch.ones_like(rgb_encoded_dist.stddev)) 46 | 47 | kl_cost_rgb = torch.max(torch.distributions.kl_divergence(rgb_encoded_dist, prior_distribution).mean(), torch.tensor(self.hp.kl_tolerance).to(device)) 48 | 49 | ############################################################## 50 | ############################################################## 51 | """ Cross Modal the Decoding """ 52 | ############################################################## 53 | ############################################################## 54 | 55 | photo2sketch_output = self.Sketch_Decoder(backbone_feature, rgb_encoded_dist_z_vector, sketch_vector, length_sketch + 1) 56 | 57 | end_token = torch.stack([torch.tensor([0, 0, 0, 0, 1])] * rgb_image.shape[0]).unsqueeze(0).to(device).float() 58 | batch = torch.cat([sketch_vector, end_token], 0) 59 | x_target = batch.permute(1, 0, 2) # batch-> Seq_Len, Batch, Feature_dim 60 | 61 | sup_p2s_loss = sketch_reconstruction_loss(photo2sketch_output, x_target) #TODO: Photo to Sketch Loss 62 | 63 | loss = sup_p2s_loss + curr_kl_weight*kl_cost_rgb 64 | 65 | set_learninRate(self.optimizer, curr_learning_rate) 66 | loss.backward() 67 | nn.utils.clip_grad_norm_(self.train_params, self.hp.grad_clip) 68 | self.optimizer.step() 69 | 70 | print('Step:{} ** sup_p2s_loss:{} ** kl_cost_rgb:{} ** Total_loss:{}'.format(step, sup_p2s_loss, 71 | kl_cost_rgb, loss)) 72 | 73 | 74 | if step%5 == 0: 75 | 76 | data = {} 77 | data['Reconstrcution_Loss'] = sup_p2s_loss 78 | data['KL_Loss'] = kl_cost_rgb 79 | data['Total Loss'] = loss 80 | 81 | self.visualizer.plot_scalars(data, step) 82 | 83 | 84 | if step%1 == 0: 85 | 86 | folder_name = os.path.join('./CVPR_SSL/' + '_'.join(sketch_name.split('/')[-1].split('_')[:-1])) 87 | if not os.path.exists(folder_name): 88 | os.makedirs(folder_name) 89 | 90 | sketch_vector_gt = sketch_vector.permute(1, 0, 2) 91 | 92 | save_sketch(sketch_vector_gt[0], sketch_name) 93 | 94 | 95 | with torch.no_grad(): 96 | photo2sketch_gen, attention_plot = \ 97 | self.Sketch_Decoder(backbone_feature, rgb_encoded_dist_z_vector, sketch_vector, length_sketch+1, isTrain=False) 98 | 99 | sketch_vector_gt = sketch_vector.permute(1, 0, 2) 100 | 101 | 102 | for num, len in enumerate(length_sketch): 103 | photo2sketch_gen[num, len:, 4 ] = 1.0 104 | photo2sketch_gen[num, len:, 2:4] = 0.0 105 | 106 | save_sketch_gen(photo2sketch_gen[0], sketch_name) 107 | 108 | sketch_vector_gt_draw = batch_rasterize_relative(sketch_vector_gt) 109 | photo2sketch_gen_draw = batch_rasterize_relative(photo2sketch_gen) 110 | 111 | batch_redraw = [] 112 | plot_attention = showAttention(attention_plot, rgb_image, sketch_vector_gt_draw, photo2sketch_gen_draw, sketch_name) 113 | # max_image = 5 114 | # for a, b, c, d in zip(sketch_vector_gt_draw[:max_image], rgb_image.cpu()[:max_image], 115 | # photo2sketch_gen_draw[:max_image], plot_attention[:max_image]): 116 | # batch_redraw.append(torch.cat((1. - a, b, 1. 
- c, d), dim=-1)) 117 | # 118 | # torchvision.utils.save_image(torch.stack(batch_redraw), './Redraw_Photo2Sketch_' 119 | # + self.hp.setup + '/redraw_{}.jpg'.format(step), 120 | # nrow=1, normalize=False) 121 | 122 | # data = {'attention_1': [], 'attention_2':[]} 123 | # for x in attention_plot: 124 | # data['attention_1'].append(x[0]) 125 | # data['attention_2'].append(x[2]) 126 | # 127 | # data['attention_1'] = torch.stack(data['attention_1']) 128 | # data['attention_2'] = torch.stack(data['attention_2']) 129 | # 130 | # self.visualizer.vis_image(data, step) 131 | 132 | 133 | 134 | # return sup_p2s_loss, kl_cost_rgb, loss 135 | 136 | return 0, 0, 0 137 | 138 | 139 | 140 | -------------------------------------------------------------------------------- /Photo_to_Sketch_2D_Attention/CVPR_18_Baseline/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from model import Photo2Sketch 3 | from dataset import get_dataloader 4 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 5 | import argparse 6 | import random 7 | from matplotlib import pyplot as plt 8 | from rasterize import batch_rasterize_relative 9 | from torchvision.utils import save_image 10 | import time 11 | 12 | if __name__ == "__main__": 13 | 14 | parser = argparse.ArgumentParser(description='Photo2Sketch') 15 | # parser.add_argument('--backbone_name', type=str, default='Resnet', help='VGG / InceptionV3/ Resnet') 16 | # parser.add_argument('--pool_method', type=str, default='AdaptiveAvgPool2d', help='AdaptiveMaxPool2d / AdaptiveAvgPool2d / AvgPool2d') 17 | parser.add_argument('--batchsize', type=int, default=32) 18 | parser.add_argument('--nThreads', type=int, default=8) 19 | 20 | # parser.add_argument('--learning_rate', type=float, default=0.0001) 21 | parser.add_argument('--max_epoch', type=int, default=200) 22 | parser.add_argument('--eval_freq_iter', type=int, default=1000) 23 | 24 | 25 | parser.add_argument('--enc_rnn_size', default=256) 26 | parser.add_argument('--dec_rnn_size', default=512) 27 | parser.add_argument('--z_size', default=128) 28 | 29 | parser.add_argument('--num_mixture', default=20) 30 | parser.add_argument('--input_dropout_prob', default=0.9) 31 | parser.add_argument('--output_dropout_prob', default=0.9) 32 | parser.add_argument('--batch_size_sketch_rnn', default=100) 33 | 34 | parser.add_argument('--kl_weight_start', default=0.01) 35 | parser.add_argument('--kl_decay_rate', default=0.99995) 36 | parser.add_argument('--kl_tolerance', default=0.2) 37 | parser.add_argument('--kl_weight', default=1.0) 38 | 39 | parser.add_argument('--learning_rate', default=0.001) 40 | parser.add_argument('--decay_rate', default=0.9999) 41 | parser.add_argument('--min_learning_rate', default=0.00001) 42 | parser.add_argument('--grad_clip', default=1.) 
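    # Unlike the 2D-attention variant, this baseline's Image2Sketch_Train (see model.py)
    # keeps a fixed Adam learning rate (default 1e-3 here vs 1e-4 there) and weights the
    # two KL terms by a constant 0.01 rather than an annealed schedule.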
43 | 44 | # parser.add_argument('--sketch_rnn_max_seq_len', default=200) 45 | 46 | hp = parser.parse_args() 47 | 48 | print(hp) 49 | model = Photo2Sketch(hp) 50 | model.to(device) 51 | 52 | # """ Load Pretrained Model """ 53 | # model.Sketch_Encoder.load_state_dict(torch.load('./pretrain_models/Sketch_Encoder.pth', map_location=device)) 54 | # model.Sketch_Decoder.load_state_dict(torch.load('./pretrain_models/Sketch_Decoder.pth', map_location=device)) 55 | 56 | """ Model Pretraining """ 57 | # model.pretrain_SketchBranch(iteration=100000) 58 | # model.pretrain_ImageBranch() 59 | model.pretrain_SketchBranch_ShoeV2() 60 | 61 | """ Load Pretrained Model """ 62 | model.Image_Encoder.load_state_dict(torch.load('./pretrain_models/Image_Encoder.pth', map_location=device)) 63 | model.Image_Decoder.load_state_dict(torch.load('./pretrain_models/Image_Decoder.pth', map_location=device)) 64 | model.Sketch_Encoder.load_state_dict(torch.load('./pretrain_models/Sketch_Encoder.pth', map_location=device)) 65 | model.Sketch_Decoder.load_state_dict(torch.load('./pretrain_models/Sketch_Decoder.pth', map_location=device)) 66 | 67 | """ Model Training Image2Sketch """ 68 | dataloader_Train, dataloader_Test = get_dataloader(hp) 69 | step = 0 70 | loss_best = 0 71 | 72 | 73 | for i_epoch in range(hp.max_epoch): 74 | for batch_data in dataloader_Train: 75 | rgb_image = batch_data['positive_img'].to(device) 76 | sketch_vector = batch_data['relative_fivePoint'].to(device).permute(1, 0, 2).float() # Seq_Len, Batch, Feature 77 | length_sketch = batch_data['sketch_length'].to(device) -1 #TODO: Relative coord has one less 78 | 79 | sup_p2s_loss, sup_s2p_loss, KL_1, KL_2, \ 80 | short_p2p, short_s2s, total_loss = model.Image2Sketch_Train(rgb_image, sketch_vector, length_sketch, step) 81 | 82 | print('Step:{} ** sup_p2s_loss:{} ** sup_s2p_loss:{} ** KL_1:{} ** KL_2:{} ' 83 | '** short_p2p:{} ** short_s2s:{} ** Total_loss:{}'.format(step, sup_p2s_loss, sup_s2p_loss, KL_1, KL_2, 84 | short_p2p, short_s2s, total_loss)) 85 | 86 | # print(batch_data['sketch_img'].shape) 87 | # 88 | # start_time = time.time() 89 | # save_image(1. - batch_rasterize_relative(batch_data['relative_fivePoint']), 'a.jpg') 90 | # print('Time:{}'.format(time.time() - start_time)) 91 | # 92 | # start_time = time.time() 93 | # save_image(1. 
- batch_rasterize_relative(batch_data['relative_coordinate']), 'b.jpg') 94 | # print('Time:{}'.format(time.time() - start_time)) 95 | # 96 | # 97 | # save_image(batch_data['sketch_img'], 'c.jpg') 98 | 99 | 100 | 101 | 102 | # for i_epoch in range(hp.max_epoch): 103 | # for batch_data in dataloader_Train: 104 | # kl_cost, recons_loss, loss, curr_kl_weight = model.train_model(batch_data, step) 105 | # step = step + 1 106 | # print('Step:{} ** Current_KL:{} ** KL_Loss:{} ' 107 | # '** Recons_Loss:{} ** Total_loss:{}'.format(step, curr_kl_weight, 108 | # kl_cost, recons_loss, loss)) 109 | # if (step + 1) % hp.eval_freq_iter == 0: 110 | # rand_int = random.randint(1, 21) 111 | # for i_num, batch_data in enumerate(dataloader_Test): 112 | # if i_num > rand_int: 113 | # break 114 | # 115 | # kl_cost, recons_loss, loss, curr_kl_weight = \ 116 | # model.test_model(batch_data, step) 117 | # 118 | # print('### Evaluation ### \n Step:{} ** Current_KL:{} ** KL_Loss:{} ' 119 | # '** Recons_Loss:{} ** Total_loss:{}'.format(step, curr_kl_weight, 120 | # kl_cost, recons_loss, loss)) 121 | # if loss < loss_best: 122 | # loss_best = loss 123 | # print('### Model Updated ###') 124 | # torch.save(model.state_dict(), 'model_best.pth') 125 | 126 | -------------------------------------------------------------------------------- /Photo_to_Sketch_2D_Attention/CVPR_18_Baseline/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from Image_Networks import * 3 | from Sketch_Networks import * 4 | from torch import optim 5 | import torch 6 | import time 7 | import torch.nn.functional as F 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | from utils import * 10 | import torchvision 11 | from dataset import get_imageOnly_dataloader, get_sketchOnly_dataloader 12 | from rasterize import rasterize_relative, to_stroke_list 13 | import math 14 | from rasterize import batch_rasterize_relative 15 | from base_model import Photo2Sketch_Base 16 | from torchvision.utils import save_image 17 | 18 | class Photo2Sketch(Photo2Sketch_Base): 19 | def __init__(self, hp): 20 | 21 | Photo2Sketch_Base.__init__(self, hp) 22 | self.train_params = self.parameters() 23 | self.main_optimizer = optim.Adam(self.train_params, hp.learning_rate) 24 | 25 | 26 | def Image2Sketch_Train(self, rgb_image, sketch_vector, length_sketch, step): 27 | 28 | self.train() 29 | self.main_optimizer.zero_grad() 30 | 31 | """ Encoding the Input """ 32 | sketch_encoded_dist = self.Sketch_Encoder(sketch_vector, length_sketch) 33 | sketch_encoded_z_vector = sketch_encoded_dist.rsample() 34 | 35 | rgb_encoded_dist = self.Image_Encoder(rgb_image) 36 | rgb_encoded_dist_z_vector = rgb_encoded_dist.rsample() 37 | 38 | """ Ditribution Matching Loss """ 39 | prior_distribution = torch.distributions.Normal(torch.zeros_like(sketch_encoded_dist.mean), 40 | torch.ones_like(sketch_encoded_dist.stddev)) 41 | kl_cost_1 = torch.distributions.kl_divergence(sketch_encoded_dist, prior_distribution).sum() 42 | kl_cost_2 = torch.distributions.kl_divergence(rgb_encoded_dist, prior_distribution).sum() 43 | 44 | 45 | ############################################################## 46 | """ Cross Modal the Decoding """ 47 | ############################################################## 48 | 49 | """ a) Photo to Sketch """ 50 | start_token = torch.stack([torch.tensor([0, 0, 1, 0, 0])] *rgb_image.shape[0]).unsqueeze(0).float().to(device) 51 | batch_init = torch.cat([start_token, sketch_vector], 0) 52 | 
z_stack = torch.stack([rgb_encoded_dist_z_vector] * (self.hp.max_seq_len + 1)) 53 | inputs = torch.cat([batch_init, z_stack], 2) 54 | 55 | photo2sketch_output, _ = self.Sketch_Decoder(inputs, rgb_encoded_dist_z_vector, length_sketch + 1) 56 | 57 | end_token = torch.stack([torch.tensor([0, 0, 0, 0, 1])] * rgb_image.shape[0]).unsqueeze(0).to(device).float() 58 | batch = torch.cat([sketch_vector, end_token], 0) 59 | x_target = batch.permute(1, 0, 2) # batch-> Seq_Len, Batch, Feature_dim 60 | 61 | sup_p2s_loss = sketch_reconstruction_loss(photo2sketch_output, x_target) #TODO: Photo to Sketch Loss 62 | 63 | 64 | """ b) Sketch to Photo """ 65 | cross_recons_photo = self.Image_Decoder(sketch_encoded_z_vector) 66 | # sup_s2p_loss = F.mse_loss(rgb_image, cross_recons_photo, reduction='sum')/rgb_image.shape[0] #TODO: Sketch 2 Photo Loss 67 | sup_s2p_loss = F.mse_loss(rgb_image, cross_recons_photo) 68 | 69 | ############################################################## 70 | """ Self Modal the Decoding """ 71 | ############################################################## 72 | """ a) Photo to photo """ 73 | self_recons_photo = self.Image_Decoder(rgb_encoded_dist_z_vector) 74 | # short_p2p_loss = F.mse_loss(rgb_image, self_recons_photo, reduction='sum')/rgb_image.shape[0] 75 | short_p2p_loss = F.mse_loss(rgb_image, self_recons_photo) 76 | 77 | """ a) Sketch to Sketch """ 78 | start_token = torch.stack([torch.Tensor([0, 0, 1, 0, 0])] * rgb_image.shape[0]).unsqueeze(0).to(device).float() 79 | batch_init = torch.cat([start_token, sketch_vector], 0) 80 | z_stack = torch.stack([sketch_encoded_z_vector] * (self.hp.max_seq_len + 1)) 81 | inputs = torch.cat([batch_init, z_stack], 2) 82 | 83 | sketch2sketch_output, _ = self.Sketch_Decoder(inputs, sketch_encoded_z_vector, length_sketch + 1) 84 | 85 | end_token = torch.stack([torch.Tensor([0, 0, 0, 0, 1])] * rgb_image.shape[0]).unsqueeze(0).to(device).float() 86 | batch = torch.cat([sketch_vector, end_token], 0) 87 | x_target = batch.permute(1, 0, 2) # batch-> Seq_Len, Batch, Feature_dim 88 | 89 | short_s2s_loss = sketch_reconstruction_loss(sketch2sketch_output, x_target) # TODO: Photo to Sketch Loss 90 | 91 | loss = sup_p2s_loss + sup_s2p_loss + short_p2p_loss + short_s2s_loss + 0.01*(kl_cost_1 + kl_cost_2) 92 | 93 | loss.backward() 94 | nn.utils.clip_grad_norm(self.train_params, self.hp.grad_clip) 95 | self.main_optimizer.step() 96 | 97 | 98 | if step%1000 == 0: 99 | 100 | """ Draw Photo to Sketch """ 101 | start_token = torch.Tensor([0, 0, 1, 0, 0]).view(-1, 5).to(device) 102 | start_token = torch.stack([start_token] * rgb_encoded_dist_z_vector.shape[0], dim=1) 103 | state = start_token 104 | hidden_cell = None 105 | 106 | batch_gen_strokes = [] 107 | for i_seq in range(self.hp.max_seq_len): 108 | input = torch.cat([state, rgb_encoded_dist_z_vector.unsqueeze(0)], 2) 109 | state, hidden_cell = self.Sketch_Decoder(input, rgb_encoded_dist_z_vector, hidden_cell=hidden_cell, isTrain=False, 110 | get_deterministic=True) 111 | batch_gen_strokes.append(state.squeeze(0)) 112 | photo2sketch_gen = torch.stack(batch_gen_strokes, dim=1) 113 | 114 | """ Draw Sketch to Sketch """ 115 | start_token = torch.Tensor([0, 0, 1, 0, 0]).view(-1, 5).to(device) 116 | start_token = torch.stack([start_token] * sketch_encoded_z_vector.shape[0], dim=1) 117 | state = start_token 118 | hidden_cell = None 119 | 120 | batch_gen_strokes = [] 121 | for i_seq in range(self.hp.max_seq_len): 122 | input = torch.cat([state, sketch_encoded_z_vector.unsqueeze(0)], 2) 123 | state, hidden_cell = 
self.Sketch_Decoder(input, sketch_encoded_z_vector, hidden_cell=hidden_cell, isTrain=False, 124 | get_deterministic=True) 125 | batch_gen_strokes.append(state.squeeze(0)) 126 | sketch2sketch_gen = torch.stack(batch_gen_strokes, dim=1) 127 | 128 | sketch_vector_gt = sketch_vector.permute(1, 0, 2) 129 | 130 | sketch_vector_gt_draw = batch_rasterize_relative(sketch_vector_gt).to(device) 131 | photo2sketch_gen_draw = batch_rasterize_relative(photo2sketch_gen).to(device) 132 | sketch2sketch_gen_draw = batch_rasterize_relative(sketch2sketch_gen).to(device) 133 | 134 | batch_redraw = [] 135 | for a, b, c, d, e ,f in zip(sketch_vector_gt_draw, rgb_image, photo2sketch_gen_draw, sketch2sketch_gen_draw, self_recons_photo, cross_recons_photo): 136 | batch_redraw.append(torch.cat((1. - a, b, 1. - c, 1. - d, e, f), dim=-1)) 137 | 138 | torchvision.utils.save_image(torch.stack(batch_redraw), './Redraw_Photo2Sketch/redraw_{}.jpg'.format(step), 139 | nrow=6) 140 | 141 | 142 | return sup_p2s_loss, sup_s2p_loss, short_p2p_loss, short_s2s_loss, kl_cost_1, kl_cost_2, loss 143 | 144 | 145 | 146 | -------------------------------------------------------------------------------- /Photo_to_Sketch_2D_Attention/rasterize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from bresenham import bresenham 3 | import scipy.ndimage 4 | from PIL import Image 5 | from matplotlib import pyplot as plt 6 | import torch 7 | from utils import to_normal_strokes 8 | 9 | 10 | def mydrawPNG(vector_image, Side = 256): 11 | 12 | raster_image = np.zeros((int(Side), int(Side)), dtype=np.float32) 13 | initX, initY = int(vector_image[0, 0]), int(vector_image[0, 1]) 14 | stroke_bbox = [] 15 | stroke_cord_buffer = [] 16 | pixel_length = 0 17 | 18 | for i in range(0, len(vector_image)): 19 | if i > 0: 20 | if vector_image[i - 1, 2] == 1: 21 | initX, initY = int(vector_image[i, 0]), int(vector_image[i, 1]) 22 | 23 | cordList = list(bresenham(initX, initY, int(vector_image[i, 0]), int(vector_image[i, 1]))) 24 | pixel_length += len(cordList) 25 | stroke_cord_buffer.extend([list(i) for i in cordList]) 26 | 27 | for cord in cordList: 28 | if (cord[0] > 0 and cord[1] > 0) and (cord[0] < Side and cord[1] < Side): 29 | raster_image[cord[1], cord[0]] = 255.0 30 | initX, initY = int(vector_image[i, 0]), int(vector_image[i, 1]) 31 | 32 | if vector_image[i, 2] == 1: 33 | min_x = np.array(stroke_cord_buffer)[:, 0].min() 34 | min_y = np.array(stroke_cord_buffer)[:, 1].min() 35 | max_x = np.array(stroke_cord_buffer)[:, 0].max() 36 | max_y = np.array(stroke_cord_buffer)[:, 1].max() 37 | stroke_bbox.append([min_x, min_y, max_x, max_y]) 38 | stroke_cord_buffer = [] 39 | 40 | raster_image = scipy.ndimage.binary_dilation(raster_image) * 255.0 41 | #utils.image_boxes(Image.fromarray(raster_image).convert('RGB'), stroke_bbox).show() 42 | return raster_image, stroke_bbox 43 | 44 | 45 | def preprocess(sketch_points, side = 256.0): 46 | sketch_points = sketch_points.astype(np.float) 47 | sketch_points[:, :2] = sketch_points[:, :2] / np.array([256, 256]) 48 | sketch_points[:,:2] = sketch_points[:,:2] * side 49 | sketch_points = np.round(sketch_points) 50 | return sketch_points 51 | 52 | def rasterize_Sketch(sketch_points): 53 | sketch_points = preprocess(sketch_points) 54 | raster_images, _ = mydrawPNG(sketch_points) 55 | return raster_images 56 | 57 | def to_delXY(sketch): 58 | new_skech = sketch.copy() 59 | new_skech[:-1,:2] = new_skech[1:,:2] - new_skech[:-1,:2] 60 | new_skech[:-1, 2] = new_skech[1:, 2] 
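    # Absolute (x, y, pen) points become relative offsets; each row takes the pen flag of
    # the following point, and the final (now redundant) row is dropped on return.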
61 | return new_skech[:-1,:] 62 | 63 | 64 | def to_Absolute(sketch, start_point=(0,0)): 65 | new_skech = sketch.copy() 66 | origin = np.array([start_point[0], start_point[1], 0]) 67 | new_skech = np.vstack((origin, new_skech)) # add the implicit origin 68 | new_skech[:, :2] = np.cumsum(new_skech[:, :2], axis=0) 69 | return new_skech 70 | 71 | 72 | 73 | def toStrokeList(sketch): 74 | return np.split(sketch, np.where(sketch[:, 2])[0] + 1, axis=0)[:-1] 75 | 76 | 77 | def to_FivePoint(sketch, max_seq_len=130): 78 | len_seq = len(sketch[:, 0]) 79 | new_seq = np.zeros((max_seq_len, 5)) 80 | new_seq[0:len_seq, :2] = sketch[:, :2] 81 | new_seq[0:len_seq, 3] = sketch[:, 2] 82 | new_seq[0:len_seq, 2] = 1 - new_seq[0:len_seq, 3] 83 | new_seq[(len_seq - 1):, 4] = 1 84 | new_seq[(len_seq - 1), 2:4] = 0 85 | return new_seq 86 | 87 | 88 | def to_stroke_list(sketch): 89 | ## sketch: an `.npz` style sketch from QuickDraw 90 | sketch = np.vstack((np.array([0, 0, 0]), sketch)) 91 | sketch[:,:2] = np.cumsum(sketch[:,:2], axis=0) 92 | 93 | # range normalization 94 | xmin, xmax = sketch[:,0].min(), sketch[:,0].max() 95 | ymin, ymax = sketch[:,1].min(), sketch[:,1].max() 96 | 97 | sketch[:,0] = ((sketch[:,0] - xmin) / float(xmax - xmin)) * (255.-60.) + 30. 98 | sketch[:,1] = ((sketch[:,1] - ymin) / float(ymax - ymin)) * (255.-60.) + 30. 99 | sketch = sketch.astype(np.int64) 100 | 101 | stroke_list = np.split(sketch[:,:2], np.where(sketch[:,2])[0] + 1, axis=0) 102 | 103 | if stroke_list[-1].size == 0: 104 | stroke_list = stroke_list[:-1] 105 | 106 | if len(stroke_list) == 0: 107 | stroke_list = [sketch[:, :2]] 108 | # print('error') 109 | 110 | return stroke_list 111 | 112 | 113 | def rasterize_relative(stroke_list, fig, xlim=[0,255], ylim=[0,255]): 114 | # Usage: image = rasterize_relative(to_stroke_list(to_normal_strokes(data['relative_fivePoint'][0])), canvas) 115 | # fig = plt.figure(frameon=False, figsize=(2.56, 2.56)) 116 | for stroke in stroke_list: 117 | stroke = stroke[:,:2].astype(np.int64) 118 | plt.plot(stroke[:,0], stroke[:,1]) 119 | plt.xlim(*xlim) 120 | plt.ylim(*ylim) 121 | 122 | plt.gca().invert_yaxis(); plt.axis('off') 123 | fig.canvas.draw() 124 | X = np.array(fig.canvas.renderer._renderer) 125 | plt.gca().cla() 126 | X = X[...,:3] / 255. 127 | X = X.mean(2) 128 | X[X == 1.] = 0.; X[X > 0.] 
= 255.0 129 | sketch_img = Image.fromarray(X).convert('RGB') 130 | # plt.close(fig) 131 | return sketch_img 132 | 133 | def mydrawPNG_from_list(vector_image, Side = 256): 134 | 135 | raster_image = np.zeros((int(Side), int(Side)), dtype=np.float32) 136 | 137 | for stroke in vector_image: 138 | initX, initY = int(stroke[0, 0]), int(stroke[0, 1]) 139 | 140 | for i_pos in range(1, len(stroke)): 141 | cordList = list(bresenham(initX, initY, int(stroke[i_pos, 0]), int(stroke[i_pos, 1]))) 142 | for cord in cordList: 143 | if (cord[0] > 0 and cord[1] > 0) and (cord[0] <= Side and cord[1] <= Side): 144 | raster_image[cord[1], cord[0]] = 255.0 145 | else: 146 | print('error') 147 | initX, initY = int(stroke[i_pos, 0]), int(stroke[i_pos, 1]) 148 | 149 | raster_image = scipy.ndimage.binary_dilation(raster_image) * 255.0 150 | 151 | return Image.fromarray(raster_image).convert('RGB') 152 | 153 | 154 | def batch_rasterize_relative(sketch): 155 | 156 | def to_stroke_list(sketch): 157 | ## sketch: an `.npz` style sketch from QuickDraw 158 | sketch = np.vstack((np.array([0, 0, 0]), sketch)) 159 | sketch[:, :2] = np.cumsum(sketch[:, :2], axis=0) 160 | 161 | # range normalization 162 | xmin, xmax = sketch[:, 0].min(), sketch[:, 0].max() 163 | ymin, ymax = sketch[:, 1].min(), sketch[:, 1].max() 164 | 165 | sketch[:, 0] = ((sketch[:, 0] - xmin) / float(xmax - xmin)) * (255. - 60.) + 30. 166 | sketch[:, 1] = ((sketch[:, 1] - ymin) / float(ymax - ymin)) * (255. - 60.) + 30. 167 | sketch = sketch.astype(np.int64) 168 | 169 | stroke_list = np.split(sketch[:, :2], np.where(sketch[:, 2])[0] + 1, axis=0) 170 | 171 | if stroke_list[-1].size == 0: 172 | stroke_list = stroke_list[:-1] 173 | 174 | if len(stroke_list) == 0: 175 | stroke_list = [sketch[:, :2]] 176 | # print('error') 177 | return stroke_list 178 | 179 | batch_redraw = [] 180 | if sketch.shape[-1] == 5: 181 | for data in sketch: 182 | # image = rasterize_relative(to_stroke_list(to_normal_strokes(data.cpu().numpy())), canvas) 183 | image = mydrawPNG_from_list(to_stroke_list(to_normal_strokes(data.cpu().numpy()))) 184 | batch_redraw.append(torch.from_numpy(np.array(image)).permute(2, 0, 1)) 185 | elif sketch.shape[-1] == 3: 186 | for data in sketch: 187 | # image = rasterize_relative(to_stroke_list(data.cpu().numpy()), canvas) 188 | image = mydrawPNG_from_list(to_stroke_list(data.cpu().numpy())) 189 | batch_redraw.append(torch.from_numpy(np.array(image)).permute(2, 0, 1)) 190 | 191 | return torch.stack(batch_redraw).float() 192 | 193 | 194 | 195 | 196 | 197 | -------------------------------------------------------------------------------- /Photo_to_Sketch_2D_Attention/CVPR_18_Baseline/rasterize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from bresenham import bresenham 3 | import scipy.ndimage 4 | from PIL import Image 5 | from matplotlib import pyplot as plt 6 | import torch 7 | from utils import to_normal_strokes 8 | 9 | 10 | def mydrawPNG(vector_image, Side = 256): 11 | 12 | raster_image = np.zeros((int(Side), int(Side)), dtype=np.float32) 13 | initX, initY = int(vector_image[0, 0]), int(vector_image[0, 1]) 14 | stroke_bbox = [] 15 | stroke_cord_buffer = [] 16 | pixel_length = 0 17 | 18 | for i in range(0, len(vector_image)): 19 | if i > 0: 20 | if vector_image[i - 1, 2] == 1: 21 | initX, initY = int(vector_image[i, 0]), int(vector_image[i, 1]) 22 | 23 | cordList = list(bresenham(initX, initY, int(vector_image[i, 0]), int(vector_image[i, 1]))) 24 | pixel_length += len(cordList) 25 | 
stroke_cord_buffer.extend([list(i) for i in cordList]) 26 | 27 | for cord in cordList: 28 | if (cord[0] > 0 and cord[1] > 0) and (cord[0] < Side and cord[1] < Side): 29 | raster_image[cord[1], cord[0]] = 255.0 30 | initX, initY = int(vector_image[i, 0]), int(vector_image[i, 1]) 31 | 32 | if vector_image[i, 2] == 1: 33 | min_x = np.array(stroke_cord_buffer)[:, 0].min() 34 | min_y = np.array(stroke_cord_buffer)[:, 1].min() 35 | max_x = np.array(stroke_cord_buffer)[:, 0].max() 36 | max_y = np.array(stroke_cord_buffer)[:, 1].max() 37 | stroke_bbox.append([min_x, min_y, max_x, max_y]) 38 | stroke_cord_buffer = [] 39 | 40 | raster_image = scipy.ndimage.binary_dilation(raster_image) * 255.0 41 | #utils.image_boxes(Image.fromarray(raster_image).convert('RGB'), stroke_bbox).show() 42 | return raster_image, stroke_bbox 43 | 44 | 45 | def preprocess(sketch_points, side = 256.0): 46 | sketch_points = sketch_points.astype(np.float) 47 | sketch_points[:, :2] = sketch_points[:, :2] / np.array([256, 256]) 48 | sketch_points[:,:2] = sketch_points[:,:2] * side 49 | sketch_points = np.round(sketch_points) 50 | return sketch_points 51 | 52 | def rasterize_Sketch(sketch_points): 53 | sketch_points = preprocess(sketch_points) 54 | raster_images, _ = mydrawPNG(sketch_points) 55 | return raster_images 56 | 57 | def to_delXY(sketch): 58 | new_skech = sketch.copy() 59 | new_skech[:-1,:2] = new_skech[1:,:2] - new_skech[:-1,:2] 60 | new_skech[:-1, 2] = new_skech[1:, 2] 61 | return new_skech[:-1,:] 62 | 63 | 64 | def to_Absolute(sketch, start_point=(0,0)): 65 | new_skech = sketch.copy() 66 | origin = np.array([start_point[0], start_point[1], 0]) 67 | new_skech = np.vstack((origin, new_skech)) # add the implicit origin 68 | new_skech[:, :2] = np.cumsum(new_skech[:, :2], axis=0) 69 | return new_skech 70 | 71 | 72 | 73 | def toStrokeList(sketch): 74 | return np.split(sketch, np.where(sketch[:, 2])[0] + 1, axis=0)[:-1] 75 | 76 | 77 | def to_FivePoint(sketch, max_seq_len=130): 78 | len_seq = len(sketch[:, 0]) 79 | new_seq = np.zeros((max_seq_len, 5)) 80 | new_seq[0:len_seq, :2] = sketch[:, :2] 81 | new_seq[0:len_seq, 3] = sketch[:, 2] 82 | new_seq[0:len_seq, 2] = 1 - new_seq[0:len_seq, 3] 83 | new_seq[(len_seq - 1):, 4] = 1 84 | new_seq[(len_seq - 1), 2:4] = 0 85 | return new_seq 86 | 87 | 88 | def to_stroke_list(sketch): 89 | ## sketch: an `.npz` style sketch from QuickDraw 90 | sketch = np.vstack((np.array([0, 0, 0]), sketch)) 91 | sketch[:,:2] = np.cumsum(sketch[:,:2], axis=0) 92 | 93 | # range normalization 94 | xmin, xmax = sketch[:,0].min(), sketch[:,0].max() 95 | ymin, ymax = sketch[:,1].min(), sketch[:,1].max() 96 | 97 | sketch[:,0] = ((sketch[:,0] - xmin) / float(xmax - xmin)) * (255.-60.) + 30. 98 | sketch[:,1] = ((sketch[:,1] - ymin) / float(ymax - ymin)) * (255.-60.) + 30. 
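# --- annotation (not part of the original source) ---
# The two scaling lines above map the cumulative (x, y) coordinates into the
# range [30, 225], so every stroke fits inside the 256 x 256 raster canvas with
# roughly a 30-pixel margin on each side before drawing.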
99 | sketch = sketch.astype(np.int64) 100 | 101 | stroke_list = np.split(sketch[:,:2], np.where(sketch[:,2])[0] + 1, axis=0) 102 | 103 | if stroke_list[-1].size == 0: 104 | stroke_list = stroke_list[:-1] 105 | 106 | if len(stroke_list) == 0: 107 | stroke_list = [sketch[:, :2]] 108 | # print('error') 109 | 110 | return stroke_list 111 | 112 | 113 | def rasterize_relative(stroke_list, fig, xlim=[0,255], ylim=[0,255]): 114 | # Usage: image = rasterize_relative(to_stroke_list(to_normal_strokes(data['relative_fivePoint'][0])), canvas) 115 | # fig = plt.figure(frameon=False, figsize=(2.56, 2.56)) 116 | for stroke in stroke_list: 117 | stroke = stroke[:,:2].astype(np.int64) 118 | plt.plot(stroke[:,0], stroke[:,1]) 119 | plt.xlim(*xlim) 120 | plt.ylim(*ylim) 121 | 122 | plt.gca().invert_yaxis(); plt.axis('off') 123 | fig.canvas.draw() 124 | X = np.array(fig.canvas.renderer._renderer) 125 | plt.gca().cla() 126 | X = X[...,:3] / 255. 127 | X = X.mean(2) 128 | X[X == 1.] = 0.; X[X > 0.] = 255.0 129 | sketch_img = Image.fromarray(X).convert('RGB') 130 | # plt.close(fig) 131 | return sketch_img 132 | 133 | def mydrawPNG_from_list(vector_image, Side = 256): 134 | 135 | raster_image = np.zeros((int(Side), int(Side)), dtype=np.float32) 136 | 137 | for stroke in vector_image: 138 | initX, initY = int(stroke[0, 0]), int(stroke[0, 1]) 139 | 140 | for i_pos in range(1, len(stroke)): 141 | cordList = list(bresenham(initX, initY, int(stroke[i_pos, 0]), int(stroke[i_pos, 1]))) 142 | for cord in cordList: 143 | if (cord[0] > 0 and cord[1] > 0) and (cord[0] <= Side and cord[1] <= Side): 144 | raster_image[cord[1], cord[0]] = 255.0 145 | else: 146 | print('error') 147 | initX, initY = int(stroke[i_pos, 0]), int(stroke[i_pos, 1]) 148 | 149 | raster_image = scipy.ndimage.binary_dilation(raster_image) * 255.0 150 | 151 | return Image.fromarray(raster_image).convert('RGB') 152 | 153 | 154 | def batch_rasterize_relative(sketch): 155 | 156 | def to_stroke_list(sketch): 157 | ## sketch: an `.npz` style sketch from QuickDraw 158 | sketch = np.vstack((np.array([0, 0, 0]), sketch)) 159 | sketch[:, :2] = np.cumsum(sketch[:, :2], axis=0) 160 | 161 | # range normalization 162 | xmin, xmax = sketch[:, 0].min(), sketch[:, 0].max() 163 | ymin, ymax = sketch[:, 1].min(), sketch[:, 1].max() 164 | 165 | sketch[:, 0] = ((sketch[:, 0] - xmin) / float(xmax - xmin)) * (255. - 60.) + 30. 166 | sketch[:, 1] = ((sketch[:, 1] - ymin) / float(ymax - ymin)) * (255. - 60.) + 30. 
167 | sketch = sketch.astype(np.int64) 168 | 169 | stroke_list = np.split(sketch[:, :2], np.where(sketch[:, 2])[0] + 1, axis=0) 170 | 171 | if stroke_list[-1].size == 0: 172 | stroke_list = stroke_list[:-1] 173 | 174 | if len(stroke_list) == 0: 175 | stroke_list = [sketch[:, :2]] 176 | # print('error') 177 | return stroke_list 178 | 179 | batch_redraw = [] 180 | if sketch.shape[-1] == 5: 181 | for data in sketch: 182 | # image = rasterize_relative(to_stroke_list(to_normal_strokes(data.cpu().numpy())), canvas) 183 | image = mydrawPNG_from_list(to_stroke_list(to_normal_strokes(data.cpu().numpy()))) 184 | batch_redraw.append(torch.from_numpy(np.array(image)).permute(2, 0, 1)) 185 | elif sketch.shape[-1] == 3: 186 | for data in sketch: 187 | # image = rasterize_relative(to_stroke_list(data.cpu().numpy()), canvas) 188 | image = mydrawPNG_from_list(to_stroke_list(data.cpu().numpy())) 189 | batch_redraw.append(torch.from_numpy(np.array(image)).permute(2, 0, 1)) 190 | 191 | return torch.stack(batch_redraw).float() 192 | 193 | 194 | 195 | # def rasterize_relative_V(stroke_list, fig, xlim=[0,255], ylim=[0,255]): 196 | # # Usage: image = rasterize_relative(to_stroke_list(to_normal_strokes(data['relative_fivePoint'][0])), canvas) 197 | # 198 | # for stroke in stroke_list: 199 | # stroke = stroke[:,:2].astype(np.int64) 200 | # plt.plot(stroke[:,0], stroke[:,1]) 201 | # plt.xlim(*xlim) 202 | # plt.ylim(*ylim) 203 | # 204 | # plt.gca().invert_yaxis(); plt.axis('off') 205 | # fig.canvas.draw() 206 | # 207 | # # w, h = fig.canvas.get_width_height() 208 | # X = np.fromstring(fig.canvas.tostring_argb(), dtype=np.uint8) 209 | # X.shape = (256, 256, 4) 210 | # 211 | # # canvas.tostring_argb give pixmap in ARGB mode. Roll the ALPHA channel to have it in RGBA mode 212 | # X = np.roll(X, 3, axis=2) 213 | # X = Image.frombytes("RGBA", (256, 256), X.tostring()) 214 | # X = np.array(X.convert('RGB')).mean(2)/255. 215 | # X[X == 1.] = 0.; X[X > 0.] = 255.0 216 | # sketch_img = Image.fromarray(X).convert('RGB') 217 | # return sketch_img 218 | 219 | # 220 | # 221 | # 222 | # 223 | # X = np.array(fig.canvas.renderer._renderer) 224 | # plt.gca().cla() 225 | # X = X[...,:3] / 255. 226 | # X = X.mean(2) 227 | # X[X == 1.] = 0.; X[X > 0.] = 255.0 228 | # sketch_img = Image.fromarray(X).convert('RGB') 229 | # plt.close(fig) 230 | # return sketch_img 231 | 232 | # def toAbsolute(sketch): 233 | # new_skech = sketch.copy() 234 | # origin = np.array([0, 0, 0]) 235 | # new_skech = np.vstack((origin, new_skech)) # add the implicit origin 236 | # new_skech[:, :2] = np.cumsum(new_skech[:, :2], axis=0) 237 | # return new_skech 238 | # def to_delXY(sketch): 239 | # new_skech = sketch.copy() 240 | # new_skech[:,:2] = new_skech[:,:2] - new_skech[0,:2] 241 | # new_skech[1:,:2] -= new_skech[:-1,:2] 242 | # new_skech = new_skech[1:,:] 243 | # return new_skech 244 | # stroke_list = toStrokeList(sketch_abs) # convenient structure for drawing 245 | # for stroke in stroke_list: 246 | # stroke = stroke[:,:-1] 247 | # plt.plot(stroke[:,0], stroke[:,1]) 248 | # plt.axis('off') 249 | # plt.gca().invert_yaxis() 250 | # plt.show() 251 | 252 | 253 | 254 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Photo_to_Sketch_2D_Attention/Sketch_Networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 4 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 5 | # from utils import * 6 | import torch.nn as nn 7 | import numpy as np 8 | 9 | 10 | class DecoderRNN2D(nn.Module): 11 | def __init__(self, hp): 12 | super(DecoderRNN2D, self).__init__() 13 | self.fc_hc = nn.Linear(hp.z_size, 2 * hp.dec_rnn_size) 14 | self.lstm = nn.LSTM(hp.dec_rnn_size + 5, hp.dec_rnn_size) 15 | self.fc_params = nn.Linear(hp.dec_rnn_size, 6 * hp.num_mixture + 3) 16 | self.hp = hp 17 | self.attention_cell = AttentionCell2D(hp.dec_rnn_size) 18 | 19 | 20 | def forward(self, backbone_feature, z_vector, sketch_vector=None, seq_len=None, isTrain=True): 21 | 22 | batch_size = z_vector.shape[0] 23 | start_token = torch.stack([torch.tensor([0, 0, 1, 0, 0])] * batch_size).unsqueeze(0).float().to(device) 24 | 25 | 26 | self.training = isTrain 27 | output_hiddens = torch.FloatTensor(batch_size, self.hp.max_seq_len + 1, self.hp.dec_rnn_size).fill_(0).to(device) 28 | 29 | 30 | hidden, cell = torch.split(F.tanh(self.fc_hc(z_vector)), self.hp.dec_rnn_size, 1) 31 | hidden_cell = (hidden.unsqueeze(0).contiguous(), cell.unsqueeze(0).contiguous()) 32 | 33 | 34 | 35 | if self.training: 36 | batch_init = torch.cat([start_token, sketch_vector], 0) 37 | num_steps = sketch_vector.shape[0] + 1 38 | for i in range(num_steps): 39 | state_point = batch_init[i, :, ] 40 | att_feature, _ = self.attention_cell.forward(backbone_feature, hidden_cell[0].squeeze(0)) 41 | concat_context = torch.cat([att_feature, state_point], 1).unsqueeze(0) # batch_size x (num_channel + num_embedding) 42 | _, hidden_cell = self.lstm(concat_context, hidden_cell) 43 | output_hiddens[:, i, :] = hidden_cell[0].squeeze(0) # LSTM hidden index (0: hidden, 1: Cell) 44 | 45 | y_output = self.fc_params(output_hiddens) 46 | 47 | """ Split the data""" 48 | z_pen_logits = y_output[:, :, 0:3] 49 | z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr = torch.chunk(y_output[:, :, 3:], 6, 2) 50 | z_pi = F.softmax(z_pi, dim=-1) 51 | z_sigma1 = torch.exp(z_sigma1) 52 | z_sigma2 = 
torch.exp(z_sigma2) 53 | z_corr = torch.tanh(z_corr) 54 | 55 | return [z_pi.reshape(-1, 20), z_mu1.reshape(-1, 20), z_mu2.reshape(-1, 20), \ 56 | z_sigma1.reshape(-1, 20), z_sigma2.reshape(-1, 20), z_corr.reshape(-1, 20), z_pen_logits.reshape(-1, 3)] 57 | 58 | else: 59 | batch_gen_strokes = [] 60 | state_point = start_token.squeeze(0) # [GO] token 61 | num_steps = sketch_vector.shape[0] + 1 62 | # batch_init = torch.cat([start_token, sketch_vector], 0) 63 | attention_plot = [] 64 | 65 | for i in range(num_steps): 66 | 67 | att_feature, attention = self.attention_cell.forward(backbone_feature, hidden_cell[0].squeeze(0)) 68 | attention_plot.append(attention.view(batch_size, 1, 8, 8)) 69 | # state_point = batch_init[i, :, :] 70 | concat_context = torch.cat([att_feature, state_point], 1).unsqueeze(0) # batch_size x (num_channel + num_embedding) 71 | _, hidden_cell = self.lstm(concat_context, hidden_cell) 72 | y_output = self.fc_params(hidden_cell[0].permute(1, 0, 2)) 73 | 74 | 75 | """ Split the data to get next output """ 76 | z_pen_logits = y_output[:, :, 0:3] 77 | z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr = torch.chunk(y_output[:, :, 3:], 6, 2) 78 | z_pi = F.softmax(z_pi, dim=-1) 79 | z_sigma1 = torch.exp(z_sigma1) 80 | z_sigma2 = torch.exp(z_sigma2) 81 | z_corr = torch.tanh(z_corr) 82 | 83 | batch_size = z_pi.shape[0] 84 | z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, \ 85 | z_corr, z_pen_logits = z_pi.reshape(-1, 20), z_mu1.reshape(-1,20), \ 86 | z_mu2.reshape(-1, 20), z_sigma1.reshape(-1,20), \ 87 | z_sigma2.reshape(-1, 20), z_corr.reshape(-1, 20), z_pen_logits.reshape(-1, 3) 88 | 89 | recons_output = torch.zeros(batch_size, 5).to(device) 90 | z_pi_idx = z_pi.argmax(dim=-1) 91 | z_pen_idx = z_pen_logits.argmax(-1) 92 | recons_output[:, 0] = z_mu1[range(z_mu1.shape[0]), z_pi_idx] 93 | recons_output[:, 1] = z_mu2[range(z_mu2.shape[0]), z_pi_idx] 94 | 95 | recons_output[range(z_mu1.shape[0]), z_pen_idx + 2] = 1. 
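# --- annotation (not part of the original source) ---
# The block above is greedy, deterministic decoding of the mixture-density output:
# the mixture component with the largest weight (z_pi_idx) supplies the (dx, dy)
# means, and the pen state is the argmax of the three pen logits, written as a
# one-hot into columns 2-4 of the 5-point stroke format.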
96 | 97 | state_point = recons_output.data 98 | batch_gen_strokes.append(state_point) 99 | 100 | return torch.stack(batch_gen_strokes, dim=1), attention_plot 101 | 102 | 103 | 104 | class AttentionCell2D(nn.Module): 105 | 106 | def __init__(self, hidden_size): 107 | super(AttentionCell2D, self).__init__() 108 | self.featrue_layers = 512 109 | self.hidden_dim_de = hidden_size 110 | self.embedding_size = 256 111 | 112 | self.conv_h = nn.Linear(self.hidden_dim_de, self.embedding_size) 113 | self.conv_f = nn.Conv2d(self.featrue_layers, 114 | self.embedding_size, kernel_size=3, padding=1) 115 | 116 | self.conv_att = nn.Linear(self.embedding_size, 1) 117 | # self.dropout = nn.Dropout(p=0.5) 118 | 119 | def forward(self, conv_f, h):#conv_f[10,512,6,40],h:[10,512] 120 | 121 | g_em = self.conv_h(h) #10, 512 122 | g_em = g_em.unsqueeze(-1).permute(0, 2, 1) #[10, 1, 256] 123 | 124 | x_em = self.conv_f(conv_f) #[10, 256, 8, 25] 125 | x_em = x_em.view(x_em.shape[0], -1, conv_f.shape[2] * conv_f.shape[3]) #[10, 256, 200] 126 | x_em = x_em.permute(0, 2, 1) #[10, 200, 256] 127 | 128 | feat = torch.tanh(x_em + g_em) #[10, 200, 256] 129 | alpha = F.softmax(self.conv_att(feat), dim=1) # [10, 200, 1] 130 | alpha = alpha.permute(0, 2, 1) # [10, 1, 200] 131 | 132 | orgfeat_embed = conv_f.view(conv_f.shape[0], -1, conv_f.shape[2] * conv_f.shape[3]) #[10, 512, 200] 133 | orgfeat_embed = orgfeat_embed.permute(0, 2, 1) #[10, 200, 512] 134 | 135 | att_out = torch.bmm(alpha, orgfeat_embed).squeeze(1) # [50, 1, 64] x [50, 64, 512] -> [50, 1, 512] 136 | 137 | return att_out, alpha 138 | 139 | class EncoderRNN(nn.Module): 140 | def __init__(self, hp): 141 | super(EncoderRNN, self).__init__() 142 | self.lstm = nn.LSTM(5, hp.enc_rnn_size, dropout=hp.input_dropout_prob, bidirectional=True) 143 | self.fc_mu = nn.Linear(2*hp.enc_rnn_size, hp.z_size) 144 | self.fc_sigma = nn.Linear(2*hp.enc_rnn_size, hp.z_size) 145 | 146 | def forward(self, x, Seq_Len=None): 147 | x = pack_padded_sequence(x, Seq_Len, enforce_sorted=False) 148 | _, (h_n, _) = self.lstm(x.float()) 149 | h_n = h_n.permute(1,0,2).reshape(h_n.shape[1], -1) 150 | mean = self.fc_mu(h_n) 151 | log_var = self.fc_sigma(h_n) 152 | posterior_dist = torch.distributions.Normal(mean, torch.exp(0.5 * log_var)) 153 | return posterior_dist 154 | 155 | 156 | 157 | class DecoderRNN(nn.Module): 158 | def __init__(self, hp): 159 | super(DecoderRNN, self).__init__() 160 | self.fc_hc = nn.Linear(hp.z_size, 2 * hp.dec_rnn_size) 161 | self.lstm = nn.LSTM(hp.z_size + 5, hp.dec_rnn_size, dropout=hp.output_dropout_prob) 162 | self.fc_params = nn.Linear(hp.dec_rnn_size, 6 * hp.num_mixture + 3) 163 | self.hp = hp 164 | 165 | def forward(self, inputs, z_vector, seq_len = None, hidden_cell=None, isTrain = True, get_deterministic = True): 166 | 167 | self.training = isTrain 168 | if hidden_cell is None: 169 | hidden, cell = torch.split(F.tanh(self.fc_hc(z_vector)), self.hp.dec_rnn_size, 1) 170 | hidden_cell = (hidden.unsqueeze(0).contiguous(), cell.unsqueeze(0).contiguous()) 171 | 172 | if seq_len is None: 173 | # seq_len = torch.tensor([1]).type(torch.int64).to(device) 174 | seq_len = torch.ones(inputs.shape[1]).type(torch.int64).to(device) 175 | 176 | inputs = pack_padded_sequence(inputs, seq_len, enforce_sorted=False) 177 | outputs, (hidden, cell) = self.lstm(inputs, hidden_cell) 178 | outputs, _ = pad_packed_sequence(outputs) 179 | 180 | if self.training: 181 | if outputs.shape[0] != (self.hp.max_seq_len + 1): 182 | pad = torch.zeros(outputs.shape[-1]).repeat(self.hp.max_seq_len + 1 - 
outputs.shape[0], outputs.shape[1], 1).cuda() 183 | outputs = torch.cat((outputs, pad), dim=0) 184 | y_output = self.fc_params(outputs.permute(1,0,2)) 185 | else: 186 | y_output = self.fc_params(hidden.permute(1,0,2)) 187 | 188 | z_pen_logits = y_output[:, :, 0:3] 189 | z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr = torch.chunk(y_output[:, :, 3:], 6, 2) 190 | z_pi = F.softmax(z_pi, dim=-1) 191 | z_sigma1 = torch.exp(z_sigma1) 192 | z_sigma2 = torch.exp(z_sigma2) 193 | z_corr = torch.tanh(z_corr) 194 | 195 | 196 | if (not self.training) and get_deterministic: 197 | batch_size = z_pi.shape[0] 198 | z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr, z_pen_logits = z_pi.reshape(-1, 20), z_mu1.reshape(-1, 20), z_mu2.reshape(-1, 20), \ 199 | z_sigma1.reshape(-1, 20), z_sigma2.reshape(-1, 20), z_corr.reshape(-1, 20), z_pen_logits.reshape(-1, 3) 200 | 201 | recons_output = torch.zeros(batch_size, 5).to(device) 202 | z_pi_idx = z_pi.argmax(dim=-1) 203 | z_pen_idx = z_pen_logits.argmax(-1) 204 | recons_output[:, 0] = z_mu1[range(z_mu1.shape[0]), z_pi_idx] 205 | recons_output[:, 1] = z_mu2[range(z_mu2.shape[0]), z_pi_idx] 206 | 207 | recons_output[range(z_mu1.shape[0]), z_pen_idx + 2] = 1. 208 | 209 | return recons_output.unsqueeze(0).data, (hidden, cell) 210 | 211 | return [z_pi.reshape(-1, 20), z_mu1.reshape(-1, 20), z_mu2.reshape(-1, 20), \ 212 | z_sigma1.reshape(-1, 20), z_sigma2.reshape(-1, 20), z_corr.reshape(-1, 20), z_pen_logits.reshape(-1, 3)], (hidden, cell) 213 | 214 | 215 | def torch_2d_normal(x1, x2, mu1, mu2, s1, s2, rho): 216 | """Returns result of eq # 24 of http://arxiv.org/abs/1308.0850.""" 217 | norm1 = x1 - mu1 218 | norm2 = x2 - mu2 219 | s1s2 = s1 * s2 220 | 221 | z_1 = (norm1 / s1) ** 2 222 | z_2 = (norm2 / s2) ** 2 223 | z1_z2 = (norm1 * norm2) / s1s2 224 | 225 | z = z_1 + z_2 - 2 * rho * z1_z2 226 | neg_rho = 1 - rho ** 2 227 | result = torch.exp(-z / (2 * neg_rho)) 228 | denom = 2 * np.pi * s1s2 * torch.sqrt(neg_rho) 229 | return result / denom 230 | 231 | 232 | def sketch_reconstruction_loss(output, x_input): 233 | # x_input = 234 | # Ouput = Predicted 123 parameters from decoder = Batch*Max_seq_len, 20 235 | [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen_logits] = output 236 | [x1_data, x2_data, eos_data, eoc_data, cont_data] = torch.chunk(x_input.reshape(-1, 5), 5, 1) 237 | pen_data = torch.cat([eos_data, eoc_data, cont_data], 1) 238 | mask = 1.0 - pen_data[:, 2] # use training data for this 239 | 240 | result0 = torch_2d_normal(x1_data, x2_data, o_mu1, o_mu2, o_sigma1, o_sigma2, 241 | o_corr) 242 | epsilon = 1e-6 243 | 244 | result1 = torch.sum(result0 * o_pi, dim=1) # ? unsqueeae(-1) ?? 
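# --- annotation (not part of the original source) ---
# result0 holds the bivariate-normal density of the true offset under each of the
# 20 mixture components (eq. 24 of arXiv:1308.0850, see torch_2d_normal above);
# weighting by o_pi and summing gives the mixture likelihood, whose negative log
# is taken on the next line. The pen-state term below uses cross-entropy, and both
# terms are multiplied by `mask`, which zeroes steps once the sketch has ended.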
245 | result1 = -torch.log(result1 + epsilon) # avoid log(0) 246 | 247 | if torch.isnan(result1).any(): 248 | print('Catched') 249 | 250 | result2 = F.cross_entropy(o_pen_logits, pen_data.argmax(1), reduction='none') 251 | 252 | result = mask * result1 + mask * result2 253 | # result = result1 + result2 254 | 255 | return result.mean() 256 | 257 | def set_learninRate(optimizer, curr_learning_rate): 258 | for g in optimizer.param_groups: 259 | g['lr'] = curr_learning_rate 260 | -------------------------------------------------------------------------------- /Photo_to_Sketch_2D_Attention/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 3 | import pickle 4 | import torch.utils.data as data 5 | import torchvision.transforms as transforms 6 | import os 7 | from random import randint 8 | from PIL import Image 9 | import random 10 | random.seed(9001) 11 | 12 | from rasterize import * 13 | import argparse 14 | import numpy as np 15 | import torchvision 16 | from matplotlib import pyplot as plt 17 | from utils import * 18 | import time 19 | import torchvision.transforms.functional as TF 20 | from torchvision.utils import save_image 21 | 22 | 23 | class Photo2Sketch_Dataset(data.Dataset): 24 | 25 | def __init__(self, hp, mode): 26 | super(Photo2Sketch_Dataset, self).__init__() 27 | 28 | self.hp = hp 29 | self.mode = mode 30 | hp.root_dir = '/home/media/On_the_Fly/Code_ALL/Final_Dataset' 31 | hp.dataset_name = 'ShoeV2' 32 | hp.seq_len_threshold = 251 33 | 34 | self.root_dir = os.path.join(hp.root_dir, hp.dataset_name) 35 | 36 | with open('./preprocess/ShoeV2_RDP_3', 'rb') as fp: 37 | self.Coordinate = pickle.load(fp) 38 | 39 | 40 | seq_len_threshold = 81 41 | coordinate_refine = {} 42 | seq_len = [] 43 | for key in self.Coordinate.keys(): 44 | if len(self.Coordinate[key]) < seq_len_threshold: 45 | coordinate_refine[key] = self.Coordinate[key] 46 | seq_len.append(len(self.Coordinate[key])) 47 | self.Coordinate = coordinate_refine 48 | hp.max_seq_len = max(seq_len) 49 | hp.average_seq_len = int(np.round(np.mean(seq_len) + 0.5*np.std(seq_len))) 50 | 51 | # greater_than_average = 0 52 | # for seq in seq_len: 53 | # if seq > self.hp.average_len: 54 | # greater_than_average +=1 55 | 56 | self.Train_Sketch = [x for x in self.Coordinate if ('train' in x) and (len(self.Coordinate[x]) < seq_len_threshold)] # separating trains 57 | self.Test_Sketch = [x for x in self.Coordinate if ('test' in x) and (len(self.Coordinate[x]) < seq_len_threshold)] # separating tests 58 | 59 | self.train_transform = get_transform('Train') 60 | self.test_transform = get_transform('Test') 61 | 62 | # # seq_len = [] 63 | # # for key in self.Coordinate.keys(): 64 | # # seq_len += [len(self.Coordinate[key])] 65 | # # plt.hist(seq_len) 66 | # # plt.savefig('histogram of number of Coordinate Points.png') 67 | # # plt.close() 68 | # # hp.max_seq_len = max(seq_len) 69 | # hp.max_seq_len = 130 70 | 71 | 72 | """" Preprocess offset coordinates """ 73 | self.Offset_Coordinate = {} 74 | for key in self.Coordinate.keys(): 75 | self.Offset_Coordinate[key] = to_delXY(self.Coordinate[key]) 76 | data = [] 77 | for sample in self.Offset_Coordinate.values(): 78 | data.extend(sample[:, 0]) 79 | data.extend(sample[:, 1]) 80 | data = np.array(data) 81 | scale_factor = np.std(data) 82 | 83 | for key in self.Coordinate.keys(): 84 | self.Offset_Coordinate[key][:, :2] /= scale_factor 85 | 86 | """" <<< Preprocess offset 
coordinates >>> """ 87 | """" <<< Done >>> """ 88 | 89 | 90 | 91 | def __getitem__(self, item): 92 | 93 | if self.mode == 'Train': 94 | sketch_path = self.Train_Sketch[item] 95 | 96 | positive_sample = '_'.join(self.Train_Sketch[item].split('/')[-1].split('_')[:-1]) 97 | positive_path = os.path.join(self.root_dir, 'photo', positive_sample + '.png') 98 | positive_img = Image.open(positive_path).convert('RGB') 99 | 100 | sketch_abs = self.Coordinate[sketch_path] 101 | sketch_delta = self.Offset_Coordinate[sketch_path] 102 | 103 | sketch_img = rasterize_Sketch(sketch_abs) 104 | sketch_img = Image.fromarray(sketch_img).convert('RGB') 105 | 106 | # sketch_img = TF.hflip(sketch_img) 107 | # positive_img = TF.hflip(positive_img) 108 | sketch_img = self.train_transform(sketch_img) 109 | positive_img = self.train_transform(positive_img) 110 | 111 | #################################### #################################### #################################### 112 | 113 | absolute_coordinate = np.zeros((self.hp.max_seq_len, 3)) 114 | relative_coordinate = np.zeros((self.hp.max_seq_len, 3)) 115 | absolute_coordinate[:sketch_abs.shape[0], :] = sketch_abs 116 | relative_coordinate[:sketch_delta.shape[0], :] = sketch_delta 117 | #################################### #################################### #################################### 118 | 119 | # sample = {'sketch_img': sketch_img, 120 | # 'sketch_path': sketch_path, 121 | # 'absolute_coordinate':absolute_coordinate, 122 | # 'relative_coordinate': relative_coordinate, 123 | # 'sketch_length': int(len(sketch_abs)), 124 | # 'absolute_fivePoint': to_FivePoint(sketch_abs, self.hp.max_seq_len), 125 | # 'relative_fivePoint': to_FivePoint(sketch_delta, self.hp.max_seq_len), 126 | # 'positive_img': positive_img, 127 | # 'positive_path': positive_sample} 128 | 129 | sample = {'sketch_path': sketch_path, 'length': int(len(sketch_abs)), 130 | 'sketch_vector': to_FivePoint(sketch_delta, self.hp.max_seq_len), 131 | 'photo': positive_img} 132 | 133 | 134 | elif self.mode == 'Test': 135 | 136 | sketch_path = self.Test_Sketch[item] 137 | 138 | positive_sample = '_'.join(self.Test_Sketch[item].split('/')[-1].split('_')[:-1]) 139 | positive_path = os.path.join(self.root_dir, 'photo', positive_sample + '.png') 140 | 141 | sketch_abs = self.Coordinate[sketch_path] 142 | sketch_delta = self.Offset_Coordinate[sketch_path] 143 | 144 | sketch_img = rasterize_Sketch(sketch_abs) 145 | sketch_img = Image.fromarray(sketch_img).convert('RGB') 146 | 147 | sketch_img = self.test_transform(sketch_img) 148 | positive_img = self.test_transform(Image.open(positive_path).convert('RGB')) 149 | 150 | #################################### #################################### #################################### 151 | 152 | absolute_coordinate = np.zeros((self.hp.max_seq_len, 3)) 153 | relative_coordinate = np.zeros((self.hp.max_seq_len, 3)) 154 | absolute_coordinate[:sketch_abs.shape[0], :] = sketch_abs 155 | relative_coordinate[:sketch_delta.shape[0], :] = sketch_delta 156 | #################################### #################################### #################################### 157 | 158 | # sample = {'sketch_img': sketch_img, 159 | # 'sketch_path': sketch_path, 160 | # 'absolute_coordinate':absolute_coordinate, 161 | # 'relative_coordinate': relative_coordinate, 162 | # 'sketch_length': int(len(sketch_abs)), 163 | # 'absolute_fivePoint': to_FivePoint(sketch_abs, self.hp.max_seq_len), 164 | # 'relative_fivePoint': to_FivePoint(sketch_delta, self.hp.max_seq_len), 165 | # 
'positive_img': positive_img, 166 | # 'positive_path': positive_sample} 167 | 168 | sample = { 'sketch_path': sketch_path, 169 | 'length': int(len(sketch_abs)), 170 | 'sketch_vector': to_FivePoint(sketch_delta, self.hp.max_seq_len), 171 | 'photo': positive_img} 172 | 173 | return sample 174 | 175 | 176 | 177 | def __len__(self): 178 | if self.mode == 'Train': 179 | return len(self.Train_Sketch) 180 | elif self.mode == 'Test': 181 | return len(self.Test_Sketch) 182 | 183 | 184 | 185 | def get_dataloader(hp): 186 | 187 | dataset_Train = Photo2Sketch_Dataset(hp, mode = 'Train') 188 | 189 | 190 | dataset_Test = Photo2Sketch_Dataset(hp, mode = 'Test') 191 | 192 | dataset_Train = torch.utils.data.ConcatDataset([dataset_Train, dataset_Test]) 193 | 194 | dataloader_Train = data.DataLoader(dataset_Train, batch_size=hp.batchsize, shuffle=False, 195 | num_workers=int(hp.nThreads)) 196 | 197 | dataloader_Test = data.DataLoader(dataset_Test, batch_size=hp.batchsize, shuffle=False, 198 | num_workers=int(hp.nThreads)) 199 | 200 | return dataloader_Train, dataloader_Test 201 | 202 | 203 | def get_transform(type): 204 | transform_list = [] 205 | if type is 'Train': 206 | transform_list.extend([transforms.Resize(256)]) 207 | elif type is 'Test': 208 | transform_list.extend([transforms.Resize(256)]) 209 | # transform_list.extend( 210 | # [transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])]) 211 | transform_list.extend( 212 | [transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 213 | 214 | return transforms.Compose(transform_list) 215 | 216 | 217 | 218 | 219 | 220 | class get_sketchOnly_dataloader(object): 221 | 222 | def __init__(self, hp): 223 | dataset = np.load('sketchrnn_shoe.npz', encoding='latin1', allow_pickle=True) 224 | self.hp = hp 225 | data_train = dataset['train'] 226 | data_valid = dataset['valid'] 227 | 228 | # hp.sketch_rnn_max_seq_len = self.max_size(np.concatenate((data_train, data_valid))) 229 | sizes = [len(seq) for seq in np.concatenate((data_train, data_valid))] 230 | hp.sketch_rnn_max_seq_len = max(sizes) 231 | hp.max_seq_len = hp.sketch_rnn_max_seq_len 232 | hp.average_seq_len = int(np.round(np.mean(sizes) + np.std(sizes))) 233 | 234 | 235 | data_train = self.purify(data_train) 236 | self.data_train = self.normalize(data_train) 237 | 238 | 239 | data_valid = self.purify(data_valid) 240 | self.data_valid = self.normalize(data_valid) 241 | 242 | self.normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 243 | 244 | # self.mean = torch.tensor([0.5, 0.5, 0.5]) 245 | # self.std = torch.tensor([0.5, 0.5, 0.5]) 246 | 247 | 248 | self.mean = torch.tensor([0.485, 0.456, 0.406]) 249 | self.std = torch.tensor([0.229, 0.224, 0.225]) 250 | 251 | 252 | 253 | def purify(self, strokes): 254 | """removes to small or too long sequences + removes large gaps""" 255 | data = [] 256 | for seq in strokes: 257 | if seq.shape[0] <= self.hp.sketch_rnn_max_seq_len and seq.shape[0] > 10: 258 | seq = np.minimum(seq, 1000) 259 | seq = np.maximum(seq, -1000) 260 | seq = np.array(seq, dtype=np.float32) 261 | data.append(seq) 262 | return data 263 | 264 | def max_size(self, data): 265 | """larger sequence length in the data set""" 266 | sizes = [len(seq) for seq in data] 267 | self.hp.average_len = np.round(np.mean(sizes) + np.std(sizes)) 268 | 269 | # greater_than_average = 0 270 | # for seq in data: 271 | # if len(seq) > self.hp.average_len: 272 | # greater_than_average +=1 273 | 274 | return max(sizes) 275 | 
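# --- annotation (not part of the original source): hypothetical usage sketch ---
# Assuming an `hp` namespace carrying the hyper-parameters referenced above and
# `sketchrnn_shoe.npz` on disk, the loader is driven via the train_batch /
# valid_batch methods defined further down, e.g.:
#     loader = get_sketchOnly_dataloader(hp)
#     sample = loader.train_batch(batch_size_sketch_rnn=50)
#     sample['sketch_vector']  # (max_seq_len, 50, 5) five-point stroke tensor
#     sample['photo']          # (50, 3, 256, 256) rasterized, normalized images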
276 | 277 | def calculate_normalizing_scale_factor(self, strokes): 278 | """Calculate the normalizing factor explained in appendix of sketch-rnn.""" 279 | data = [] 280 | for i in range(len(strokes)): 281 | for j in range(len(strokes[i])): 282 | data.append(strokes[i][j, 0]) 283 | data.append(strokes[i][j, 1]) 284 | data = np.array(data) 285 | return np.std(data) 286 | 287 | def normalize(self, strokes): 288 | """Normalize entire dataset (delta_x, delta_y) by the scaling factor.""" 289 | data = [] 290 | scale_factor = self.calculate_normalizing_scale_factor(strokes) 291 | for seq in strokes: 292 | seq[:, 0:2] /= scale_factor 293 | data.append(seq) 294 | return data 295 | 296 | def train_batch(self, batch_size_sketch_rnn=50): 297 | batch_idx = np.random.choice(len(self.data_train), batch_size_sketch_rnn) 298 | batch_sequences = [self.data_train[idx] for idx in batch_idx] 299 | strokes = [] 300 | lengths = [] 301 | indice = 0 302 | for seq in batch_sequences: 303 | len_seq = len(seq[:, 0]) 304 | new_seq = np.zeros((self.hp.sketch_rnn_max_seq_len, 5)) 305 | new_seq[0:len_seq, :2] = seq[:, :2] 306 | new_seq[0:len_seq, 3] = seq[:, 2] 307 | new_seq[0:len_seq, 2] = 1 - new_seq[0:len_seq, 3] 308 | new_seq[(len_seq-1):, 4] = 1 309 | new_seq[(len_seq - 1), 2:4] = 0 310 | lengths.append(len(seq[:, 0])) 311 | strokes.append(new_seq) 312 | indice += 1 313 | 314 | batch = torch.from_numpy(np.stack(strokes, 1)).to(device).float() 315 | batch_image = 1. - batch_rasterize_relative(batch.permute(1, 0, 2))/255. 316 | batch_image.sub_(self.mean[None, :, None, None]).div_(self.std[None, :, None, None]) 317 | 318 | sample = {'length': torch.tensor(lengths).type(torch.int64), 319 | 'sketch_vector': batch, 320 | 'photo': batch_image} 321 | 322 | return sample 323 | 324 | 325 | def valid_batch(self, batch_size_sketch_rnn=100): 326 | batch_idx = np.random.choice(len(self.data_valid), batch_size_sketch_rnn) 327 | batch_sequences = [self.data_valid[idx] for idx in batch_idx] 328 | strokes = [] 329 | lengths = [] 330 | indice = 0 331 | for seq in batch_sequences: 332 | len_seq = len(seq[:, 0]) 333 | new_seq = np.zeros((self.hp.sketch_rnn_max_seq_len, 5)) 334 | new_seq[0:len_seq, :2] = seq[:, :2] 335 | new_seq[0:len_seq, 3] = seq[:, 2] 336 | new_seq[0:len_seq, 2] = 1 - new_seq[0:len_seq, 3] 337 | new_seq[len_seq:, 4] = 1 338 | lengths.append(len(seq[:, 0])) 339 | strokes.append(new_seq) 340 | indice += 1 341 | 342 | batch = torch.from_numpy(np.stack(strokes, 1)).to(device).float() 343 | batch_image = 1. - batch_rasterize_relative(batch.permute(1, 0, 2))/255. 
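# --- annotation (not part of the original source) ---
# `batch` is stacked time-major as (max_seq_len, batch, 5); permute(1, 0, 2) makes
# it batch-major for the rasterizer, which returns white strokes on a black
# background, so `1. - image / 255.` inverts to black strokes on white before the
# ImageNet mean/std normalization applied in-place on the next line.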
344 | batch_image.sub_(self.mean[None, :, None, None]).div_(self.std[None, :, None, None]) 345 | 346 | sample = {'length': torch.tensor(lengths).type(torch.int64), 347 | 'sketch_vector': batch, 348 | 'photo': batch_image} 349 | 350 | return sample 351 | 352 | 353 | if __name__ == '__main__': 354 | pass 355 | -------------------------------------------------------------------------------- /Photo_to_Sketch_2D_Attention/CVPR_18_Baseline/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 3 | import pickle 4 | import torch.utils.data as data 5 | import torchvision.transforms as transforms 6 | import os 7 | from random import randint 8 | from PIL import Image 9 | import random 10 | random.seed(9001) 11 | 12 | from rasterize import * 13 | import argparse 14 | import numpy as np 15 | import torchvision 16 | from matplotlib import pyplot as plt 17 | from utils import * 18 | import time 19 | import torchvision.transforms.functional as TF 20 | 21 | class Photo2Sketch_Dataset(data.Dataset): 22 | def __init__(self, hp, mode): 23 | super(Photo2Sketch_Dataset, self).__init__() 24 | 25 | self.hp = hp 26 | self.mode = mode 27 | hp.root_dir = '/home/media/On_the_Fly/Code_ALL/Final_Dataset' 28 | hp.dataset_name = 'ShoeV2' 29 | hp.seq_len_threshold = 251 30 | 31 | # coordinate_path = os.path.join(hp.root_dir, hp.dataset_name , hp.dataset_name + '_Coordinate') 32 | self.root_dir = os.path.join(hp.root_dir, hp.dataset_name) 33 | 34 | with open('./preprocess/ShoeV2_RDP_3', 'rb') as fp: 35 | self.Coordinate = pickle.load(fp) 36 | 37 | coordinate_refine = {} 38 | seq_len = [] 39 | for key in self.Coordinate.keys(): 40 | if len(self.Coordinate[key]) < 81: 41 | coordinate_refine[key] = self.Coordinate[key] 42 | seq_len.append(len(self.Coordinate[key])) 43 | self.Coordinate = coordinate_refine 44 | hp.max_seq_len = max(seq_len) 45 | hp.average_seq_len = int(np.round(np.mean(seq_len) + np.std(seq_len))) 46 | 47 | # greater_than_average = 0 48 | # for seq in seq_len: 49 | # if seq > self.hp.average_len: 50 | # greater_than_average +=1 51 | 52 | self.Train_Sketch = [x for x in self.Coordinate if ('train' in x) and (len(self.Coordinate[x]) <130)] # separating trains 53 | self.Test_Sketch = [x for x in self.Coordinate if ('test' in x) and (len(self.Coordinate[x]) <130)] # separating tests 54 | 55 | self.train_transform = get_transform('Train') 56 | self.test_transform = get_transform('Test') 57 | 58 | # # seq_len = [] 59 | # # for key in self.Coordinate.keys(): 60 | # # seq_len += [len(self.Coordinate[key])] 61 | # # plt.hist(seq_len) 62 | # # plt.savefig('histogram of number of Coordinate Points.png') 63 | # # plt.close() 64 | # # hp.max_seq_len = max(seq_len) 65 | # hp.max_seq_len = 130 66 | 67 | 68 | """" Preprocess offset coordinates """ 69 | self.Offset_Coordinate = {} 70 | for key in self.Coordinate.keys(): 71 | self.Offset_Coordinate[key] = to_delXY(self.Coordinate[key]) 72 | data = [] 73 | for sample in self.Offset_Coordinate.values(): 74 | data.extend(sample[:, 0]) 75 | data.extend(sample[:, 1]) 76 | data = np.array(data) 77 | scale_factor = np.std(data) 78 | 79 | for key in self.Coordinate.keys(): 80 | self.Offset_Coordinate[key][:, :2] /= scale_factor 81 | 82 | """" <<< Preprocess offset coordinates >>> """ 83 | """" <<< Done >>> """ 84 | 85 | 86 | 87 | def __getitem__(self, item): 88 | 89 | if self.mode == 'Train': 90 | sketch_path = self.Train_Sketch[item] 91 | 92 | 
positive_sample = '_'.join(self.Train_Sketch[item].split('/')[-1].split('_')[:-1]) 93 | positive_path = os.path.join(self.root_dir, 'photo', positive_sample + '.png') 94 | positive_img = Image.open(positive_path).convert('RGB') 95 | 96 | sketch_abs = self.Coordinate[sketch_path] 97 | sketch_delta = self.Offset_Coordinate[sketch_path] 98 | 99 | sketch_img = rasterize_Sketch(sketch_abs) 100 | sketch_img = Image.fromarray(sketch_img).convert('RGB') 101 | 102 | sketch_img = TF.hflip(sketch_img) 103 | positive_img = TF.hflip(positive_img) 104 | sketch_img = self.train_transform(sketch_img) 105 | positive_img = self.train_transform(positive_img) 106 | 107 | #################################### #################################### #################################### 108 | 109 | absolute_coordinate = np.zeros((self.hp.max_seq_len, 3)) 110 | relative_coordinate = np.zeros((self.hp.max_seq_len, 3)) 111 | absolute_coordinate[:sketch_abs.shape[0], :] = sketch_abs 112 | relative_coordinate[:sketch_delta.shape[0], :] = sketch_delta 113 | #################################### #################################### #################################### 114 | 115 | sample = {'sketch_img': sketch_img, 116 | 'sketch_path': sketch_path, 117 | 'absolute_coordinate':absolute_coordinate, 118 | 'relative_coordinate': relative_coordinate, 119 | 'sketch_length': int(len(sketch_abs)), 120 | 'absolute_fivePoint': to_FivePoint(sketch_abs, self.hp.max_seq_len), 121 | 'relative_fivePoint': to_FivePoint(sketch_delta, self.hp.max_seq_len), 122 | 'positive_img': positive_img, 123 | 'positive_path': positive_sample} 124 | 125 | 126 | elif self.mode == 'Test': 127 | 128 | sketch_path = self.Test_Sketch[item] 129 | 130 | positive_sample = '_'.join(self.Test_Sketch[item].split('/')[-1].split('_')[:-1]) 131 | positive_path = os.path.join(self.root_dir, 'photo', positive_sample + '.png') 132 | 133 | sketch_abs = self.Coordinate[sketch_path] 134 | sketch_delta = self.Offset_Coordinate[sketch_path] 135 | 136 | sketch_img = rasterize_Sketch(sketch_abs) 137 | sketch_img = Image.fromarray(sketch_img).convert('RGB') 138 | 139 | sketch_img = self.test_transform(sketch_img) 140 | positive_img = self.test_transform(Image.open(positive_path).convert('RGB')) 141 | 142 | #################################### #################################### #################################### 143 | 144 | absolute_coordinate = np.zeros((self.hp.max_seq_len, 3)) 145 | relative_coordinate = np.zeros((self.hp.max_seq_len, 3)) 146 | absolute_coordinate[:sketch_abs.shape[0], :] = sketch_abs 147 | relative_coordinate[:sketch_delta.shape[0], :] = sketch_delta 148 | #################################### #################################### #################################### 149 | 150 | sample = {'sketch_img': sketch_img, 151 | 'sketch_path': sketch_path, 152 | 'absolute_coordinate':absolute_coordinate, 153 | 'relative_coordinate': relative_coordinate, 154 | 'sketch_length': int(len(sketch_abs)), 155 | 'absolute_fivePoint': to_FivePoint(sketch_abs), 156 | 'relative_fivePoint': to_FivePoint(sketch_delta), 157 | 'positive_img': positive_img, 158 | 'positive_path': positive_sample} 159 | 160 | return sample 161 | 162 | 163 | 164 | def __len__(self): 165 | if self.mode == 'Train': 166 | return len(self.Train_Sketch) 167 | elif self.mode == 'Test': 168 | return len(self.Test_Sketch) 169 | 170 | 171 | 172 | def get_dataloader(hp): 173 | 174 | dataset_Train = Photo2Sketch_Dataset(hp, mode = 'Train') 175 | dataloader_Train = data.DataLoader(dataset_Train, 
batch_size=hp.batchsize, shuffle=False, 176 | num_workers=int(hp.nThreads)) 177 | 178 | dataset_Test = Photo2Sketch_Dataset(hp, mode = 'Test') 179 | dataloader_Test = data.DataLoader(dataset_Test, batch_size=hp.batchsize, shuffle=False, 180 | num_workers=int(hp.nThreads)) 181 | 182 | return dataloader_Train, dataloader_Test 183 | 184 | 185 | def get_transform(type): 186 | transform_list = [] 187 | if type is 'Train': 188 | transform_list.extend([transforms.Resize(256)]) 189 | elif type is 'Test': 190 | transform_list.extend([transforms.Resize(256)]) 191 | transform_list.extend( 192 | [transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])]) 193 | return transforms.Compose(transform_list) 194 | 195 | 196 | def get_imageOnly_dataloader(dataroot = '/home/media/CVPR_2021/Sketch_SelfSupervised/ut-zap50k-images-square'): 197 | dataset = torchvision.datasets.ImageFolder(root=dataroot, 198 | transform=transforms.Compose([ 199 | transforms.Resize(260), 200 | transforms.CenterCrop(256), 201 | transforms.ToTensor(), 202 | transforms.Normalize((0.5, 0.5, 0.5), 203 | (0.5, 0.5, 0.5)), 204 | ])) 205 | assert dataset 206 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, 207 | shuffle=True, 208 | num_workers=8) 209 | return dataloader 210 | 211 | 212 | class get_sketchOnly_dataloader(object): 213 | 214 | def __init__(self, hp): 215 | dataset = np.load('sketchrnn_shoe.npz', encoding='latin1', allow_pickle=True) 216 | # /home/media/Siggraph/pytorch-sketchRNN/cat.npz sketchrnn_shoe.npz 217 | self.hp = hp 218 | data_train = dataset['train'] 219 | data_valid = dataset['valid'] 220 | 221 | 222 | # hp.sketch_rnn_max_seq_len = self.max_size(np.concatenate((data_train, data_valid))) 223 | sizes = [len(seq) for seq in np.concatenate((data_train, data_valid))] 224 | hp.sketch_rnn_max_seq_len = max(sizes) 225 | hp.average_seq_len = int(np.round(np.mean(sizes) + np.std(sizes))) 226 | 227 | 228 | 229 | data_train = self.purify(data_train) 230 | self.data_train = self.normalize(data_train) 231 | 232 | 233 | data_valid = self.purify(data_valid) 234 | self.data_valid = self.normalize(data_valid) 235 | 236 | 237 | def purify(self, strokes): 238 | """removes to small or too long sequences + removes large gaps""" 239 | data = [] 240 | for seq in strokes: 241 | if seq.shape[0] <= self.hp.sketch_rnn_max_seq_len and seq.shape[0] > 10: 242 | seq = np.minimum(seq, 1000) 243 | seq = np.maximum(seq, -1000) 244 | seq = np.array(seq, dtype=np.float32) 245 | data.append(seq) 246 | return data 247 | 248 | def max_size(self, data): 249 | """larger sequence length in the data set""" 250 | sizes = [len(seq) for seq in data] 251 | self.hp.average_len = np.round(np.mean(sizes) + np.std(sizes)) 252 | 253 | # greater_than_average = 0 254 | # for seq in data: 255 | # if len(seq) > self.hp.average_len: 256 | # greater_than_average +=1 257 | 258 | return max(sizes) 259 | 260 | 261 | def calculate_normalizing_scale_factor(self, strokes): 262 | """Calculate the normalizing factor explained in appendix of sketch-rnn.""" 263 | data = [] 264 | for i in range(len(strokes)): 265 | for j in range(len(strokes[i])): 266 | data.append(strokes[i][j, 0]) 267 | data.append(strokes[i][j, 1]) 268 | data = np.array(data) 269 | return np.std(data) 270 | 271 | def normalize(self, strokes): 272 | """Normalize entire dataset (delta_x, delta_y) by the scaling factor.""" 273 | data = [] 274 | scale_factor = self.calculate_normalizing_scale_factor(strokes) 275 | for seq in strokes: 276 | seq[:, 0:2] /= scale_factor 277 | 
data.append(seq) 278 | return data 279 | 280 | def train_batch(self, batch_size_sketch_rnn=100): 281 | batch_idx = np.random.choice(len(self.data_train), batch_size_sketch_rnn) 282 | batch_sequences = [self.data_train[idx] for idx in batch_idx] 283 | strokes = [] 284 | lengths = [] 285 | indice = 0 286 | for seq in batch_sequences: 287 | len_seq = len(seq[:, 0]) 288 | new_seq = np.zeros((self.hp.sketch_rnn_max_seq_len, 5)) 289 | new_seq[0:len_seq, :2] = seq[:, :2] 290 | new_seq[0:len_seq, 3] = seq[:, 2] 291 | new_seq[0:len_seq, 2] = 1 - new_seq[0:len_seq, 3] 292 | new_seq[(len_seq-1):, 4] = 1 293 | new_seq[(len_seq - 1), 2:4] = 0 294 | lengths.append(len(seq[:, 0])) 295 | strokes.append(new_seq) 296 | indice += 1 297 | 298 | batch = torch.from_numpy(np.stack(strokes, 1)).to(device).float() 299 | return batch, torch.tensor(lengths).type(torch.int64).to(device) 300 | 301 | 302 | def valid_batch(self, batch_size_sketch_rnn=100): 303 | batch_idx = np.random.choice(len(self.data_valid), batch_size_sketch_rnn) 304 | batch_sequences = [self.data_valid[idx] for idx in batch_idx] 305 | strokes = [] 306 | lengths = [] 307 | indice = 0 308 | for seq in batch_sequences: 309 | len_seq = len(seq[:, 0]) 310 | new_seq = np.zeros((self.hp.sketch_rnn_max_seq_len, 5)) 311 | new_seq[0:len_seq, :2] = seq[:, :2] 312 | new_seq[0:len_seq, 3] = seq[:, 2] 313 | new_seq[0:len_seq, 2] = 1 - new_seq[0:len_seq, 3] 314 | new_seq[len_seq:, 4] = 1 315 | lengths.append(len(seq[:, 0])) 316 | strokes.append(new_seq) 317 | indice += 1 318 | 319 | batch = torch.from_numpy(np.stack(strokes, 1)).to(device).float() 320 | return batch, torch.tensor(lengths).type(torch.int64).to(device) 321 | 322 | 323 | if __name__ == '__main__': 324 | pass 325 | # ########## Module Testing #################### 326 | # parser = argparse.ArgumentParser(description='rgb2sketch') 327 | # parser.add_argument('--dataset_name', type=str, default='ShoeV2') 328 | # hp = parser.parse_args() 329 | # hp.batchsize = 64 330 | # hp.nThreads = 8 331 | # hp.splitTrain = 0.7 332 | # dataset_Train = Photo2Sketch_Dataset(hp, mode='Train') 333 | # dataloader_Train = data.DataLoader(dataset_Train, batch_size=hp.batchsize, shuffle=False, 334 | # num_workers=int(hp.nThreads)) 335 | # 336 | # canvas = plt.figure(frameon=False, figsize=(2.56, 2.56)) 337 | # 338 | # for data in dataloader_Train: 339 | # print(data['sketch_img'].shape) 340 | # print(data['absolute_coordinate'].shape) 341 | # print(data['relative_coordinate'].shape) 342 | # print(data['absolute_fivePoint'].shape) 343 | # print(data['relative_fivePoint'].shape) 344 | # torchvision.utils.save_image(data['sketch_img'], 'sketch_img.jpg', normalize=True) 345 | # torchvision.utils.save_image(data['positive_img'], 'positive_img.jpg', normalize=True) 346 | # 347 | # 348 | # batch_redraw = [] 349 | # 350 | # for data in data['relative_fivePoint']: 351 | # image = rasterize_relative(to_stroke_list(to_normal_strokes(data)), canvas) 352 | # batch_redraw.append(torch.from_numpy(np.array(image)).permute(2, 0, 1)) 353 | # 354 | # torchvision.utils.save_image(torch.stack(batch_redraw).float(), 'batch_redraw.jpg', normalize=True) 355 | # 356 | # 357 | # # image = rasterize_relative(to_stroke_list(to_normal_strokes(data['relative_fivePoint'][7])), canvas) 358 | # # image.save('a.jpg') 359 | # # image.show() 360 | 361 | -------------------------------------------------------------------------------- /Photo_to_Sketch_2D_Attention/utils.py: -------------------------------------------------------------------------------- 1 | from 
matplotlib import pyplot as plt 2 | import torch 3 | use_cuda = True 4 | from IPython.display import SVG, display 5 | import numpy as np 6 | import svgwrite 7 | from six.moves import xrange 8 | import math 9 | import torch.nn as nn 10 | import torchvision 11 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 12 | import torch.nn.functional as F 13 | import torch 14 | import torchvision 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | # from torch.utils.tensorboard import SummaryWriter 17 | import os 18 | import shutil 19 | import cv2 20 | import imageio 21 | 22 | def to_normal_strokes(big_stroke): 23 | """Convert from stroke-5 format (from sketch-rnn paper) back to stroke-3.""" 24 | l = 0 25 | for i in range(len(big_stroke)): 26 | if big_stroke[i, 4] > 0: 27 | l = i 28 | break 29 | if l == 0: 30 | l = len(big_stroke)-1 31 | result = np.zeros((l+1, 3)) 32 | result[:, 0:2] = big_stroke[0:l+1, 0:2] 33 | result[:, 2] = big_stroke[0:l+1, 3] 34 | result[-1, -1] = 1. 35 | return result 36 | 37 | 38 | def get_bounds(data, factor=10): 39 | """Return bounds of data.""" 40 | min_x = 0 41 | max_x = 0 42 | min_y = 0 43 | max_y = 0 44 | 45 | abs_x = 0 46 | abs_y = 0 47 | for i in range(len(data)): 48 | x = float(data[i, 0]) / factor 49 | y = float(data[i, 1]) / factor 50 | abs_x += x 51 | abs_y += y 52 | min_x = min(min_x, abs_x) 53 | min_y = min(min_y, abs_y) 54 | max_x = max(max_x, abs_x) 55 | max_y = max(max_y, abs_y) 56 | 57 | return (min_x, max_x, min_y, max_y) 58 | 59 | 60 | 61 | def transfer_ImageNomralization(x, type='to_Gen'): 62 | # https://discuss.pytorch.org/t/how-to-normalize-multidimensional-tensor/65304 63 | #to_Gen (-1, 1) vs to_Recog (ImageNet Normalize) 64 | if type == 'to_Gen': 65 | # First Unnormalize 66 | mean = torch.tensor([-0.485/0.229, -0.456/0.224, -0.406/0.225]).to(device) 67 | std = torch.tensor([1/0.229, 1/0.224, 1/0.225]).to(device) 68 | x.sub_(mean[None, :, None, None]).div_(std[None, :, None, None]) 69 | # Then Normalize Again 70 | mean = torch.tensor([0.5, 0.5, 0.5]).to(device) 71 | std = torch.tensor([0.5, 0.5, 0.5]).to(device) 72 | x.sub_(mean[None, :, None, None]).div_(std[None, :, None, None]) 73 | 74 | elif type == 'to_Recog': 75 | # First Unnormalize 76 | mean = torch.tensor([-1.0, -1.0, -1.0]).to(device) 77 | std = torch.tensor([1/0.5, 1/0.5, 1/0.5]).to(device) 78 | x.sub_(mean[None, :, None, None]).div_(std[None, :, None, None]) 79 | # Then Normalize Again 80 | mean = torch.tensor([0.485, 0.456, 0.406]).to(device) 81 | std = torch.tensor([0.229, 0.224, 0.225]).to(device) 82 | x.sub_(mean[None, :, None, None]).div_(std[None, :, None, None]) 83 | return x 84 | 85 | def sample_next_state(output, hp, temperature =0.01): 86 | 87 | def adjust_temp(pi_pdf): 88 | pi_pdf = np.log(pi_pdf)/temperature 89 | pi_pdf -= pi_pdf.max() 90 | pi_pdf = np.exp(pi_pdf) 91 | pi_pdf /= pi_pdf.sum() 92 | return pi_pdf 93 | 94 | [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen_logits] = output 95 | # get mixture indices: 96 | o_pi = o_pi.data[0,:].cpu().numpy() 97 | o_pi = adjust_temp(o_pi) 98 | pi_idx = np.random.choice(hp.num_mixture, p=o_pi) 99 | # get pen state: 100 | o_pen = F.softmax(o_pen_logits, dim=-1) 101 | o_pen = o_pen.data[0,:].cpu().numpy() 102 | pen = adjust_temp(o_pen) 103 | pen_idx = np.random.choice(3, p=pen) 104 | # get mixture params: 105 | o_mu1 = o_mu1.data[0,pi_idx].item() 106 | o_mu2 = o_mu2.data[0,pi_idx].item() 107 | o_sigma1 = o_sigma1.data[0,pi_idx].item() 108 | o_sigma2 = o_sigma2.data[0,pi_idx].item() 
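# (editor note) with the mixture component pi_idx chosen above, the next (dx, dy) offset is drawn from that component's bivariate Gaussian below; sample_bivariate_normal scales both sigmas by sqrt(temperature), so a small temperature makes the sampling nearly deterministic.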
109 | o_corr = o_corr.data[0,pi_idx].item() 110 | x,y = sample_bivariate_normal(o_mu1,o_mu2,o_sigma1,o_sigma2,o_corr, temperature = temperature, greedy=False) 111 | next_state = torch.zeros(5) 112 | next_state[0] = x 113 | next_state[1] = y 114 | next_state[pen_idx+2] = 1 115 | return next_state.to(device).view(1,1,-1), next_state 116 | 117 | 118 | def sample_bivariate_normal(mu_x, mu_y, sigma_x, sigma_y, rho_xy, temperature = 0.2, greedy=False): 119 | # inputs must be floats 120 | if greedy: 121 | return mu_x, mu_y 122 | mean = [mu_x, mu_y] 123 | sigma_x *= np.sqrt(temperature) #confusion 124 | sigma_y *= np.sqrt(temperature) #confusion 125 | cov = [[sigma_x * sigma_x, rho_xy * sigma_x * sigma_y], \ 126 | [rho_xy * sigma_x * sigma_y, sigma_y * sigma_y]] 127 | x = np.random.multivariate_normal(mean, cov, 1) 128 | return x[0][0], x[0][1] 129 | 130 | class Visualizer: 131 | def __init__(self, name = 'Photo2Sketch'): 132 | 133 | # if os.path.exists('Tensorboard_' + name): 134 | # shutil.rmtree('Tensorboard_' + name) 135 | 136 | self.writer = SummaryWriter() 137 | 138 | self.mean = torch.tensor([-1.0, -1.0, -1.0]).to(device) 139 | self.std = torch.tensor([1 / 0.5, 1 / 0.5, 1 / 0.5]).to(device) 140 | 141 | def vis_image(self, visularize, step, normalize=False): 142 | for keys, value in visularize.items(): 143 | #print(keys,value.size()) 144 | if normalize: 145 | value.sub_(self.mean[None, :, None, None]).div_(self.std[None, :, None, None]) 146 | visularize[keys] = torchvision.utils.make_grid(value) 147 | self.writer.add_image('{}'.format(keys), visularize[keys], step) 148 | 149 | 150 | def plot_scalars(self, scalars, step): 151 | 152 | for keys, value in scalars.items(): 153 | #print(keys,value.size()) 154 | self.writer.add_scalar('{}'.format(keys), scalars[keys], step) 155 | 156 | 157 | def showAttention(attention_plot, sketch_img, sketch_vector_gt_draw, photo2sketch_gen_draw, sketch_name): 158 | # Set up figure with colorbar 159 | mean = torch.tensor([-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225]).to('cpu') 160 | std = torch.tensor([1 / 0.229, 1 / 0.224, 1 / 0.225]).to('cpu') 161 | 162 | 163 | folder_name = os.path.join('./CVPR_SSL/' + '_'.join(sketch_name.split('/')[-1].split('_')[:-1])) 164 | if not os.path.exists(folder_name): 165 | os.makedirs(folder_name) 166 | 167 | 168 | # sketch_vector_gt_draw = sketch_vector_gt_draw.sub_(mean[None, :, None, None]).div_(std[None, :, None, None]) 169 | sketch_vector_gt_draw = sketch_vector_gt_draw.squeeze(0).permute(1, 2, 0).numpy() 170 | sketch_vector_gt_draw = cv2.resize(np.float32(np.uint8(255. * (1. - sketch_vector_gt_draw))), (256, 256)) 171 | 172 | # photo2sketch_gen_draw = photo2sketch_gen_draw.sub_(mean[None, :, None, None]).div_(std[None, :, None, None]) 173 | photo2sketch_gen_draw = photo2sketch_gen_draw.squeeze(0).permute(1, 2, 0).numpy() 174 | photo2sketch_gen_draw = cv2.resize(np.float32(np.uint8(255. * (1. 
-photo2sketch_gen_draw))), (256, 256)) 175 | 176 | imageio.imwrite(folder_name + '/sketch_' + 'GT.jpg', sketch_vector_gt_draw) 177 | imageio.imwrite(folder_name + '/sketch_' + 'Gen.jpg', photo2sketch_gen_draw) 178 | 179 | attention_dictionary = {} 180 | for num, val in enumerate(sketch_img): 181 | attention_dictionary[num] = [] 182 | val = val.cpu() 183 | x = val.unsqueeze(0) 184 | x.sub_(mean[None, :, None, None]).div_(std[None, :, None, None]) 185 | x = x.squeeze(0) 186 | attention_dictionary[num].append(x) 187 | 188 | 189 | alpha = 0.5 190 | for atten_num, x_data in enumerate(attention_plot): 191 | for num, per_image_x in enumerate(x_data): 192 | 193 | attention = per_image_x.squeeze(0).cpu().numpy() 194 | 195 | # attention[attention < 0.01] = 0 196 | # attention = attention / attention.sum() 197 | # attention = np.clip(attention / attention.max() * 255., 0, 255).astype(np.uint8) 198 | 199 | heatmap = cv2.applyColorMap(np.uint8(255 * attention), cv2.COLORMAP_JET) 200 | heatmap = cv2.resize(np.float32(heatmap), (256, 256)) 201 | 202 | # heatmap = heatmap**2 203 | 204 | 205 | # image = 255. - attention_dictionary[num][0].permute(1, 2, 0).numpy() 206 | 207 | # mean = torch.tensor([-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225]).to('cpu') 208 | # std = torch.tensor([1 / 0.229, 1 / 0.224, 1 / 0.225]).to('cpu') 209 | # x = attention_dictionary[num][0].unsqueeze(0) 210 | # x.sub_(mean[None, :, None, None]).div_(std[None, :, None, None]) 211 | # x = x.squeeze(0) 212 | image = attention_dictionary[num][0].permute(1, 2, 0).numpy() 213 | image = cv2.resize(np.float32(np.uint8(255. * image)), (256, 256)) 214 | 215 | # image preprocess 216 | 217 | imageio.imwrite(folder_name + '/RGB.jpg', image) 218 | 219 | heat_map_overlay = cv2.addWeighted(heatmap, alpha, image, 1 - alpha, 0) 220 | imageio.imwrite(folder_name + '/' + sketch_name.split('/')[-1] + '_' + str(atten_num) + '.jpg', heat_map_overlay) 221 | 222 | 223 | heat_map_tensor = torch.from_numpy(heat_map_overlay).permute(2, 0, 1) 224 | # heat_map_tensor = attention_dictionary[num][0] + torch.from_numpy(heatmap).permute(2, 0, 1) 225 | # heat_map_tensor = heat_map_tensor / heat_map_tensor.max() 226 | attention_dictionary[num].append(heat_map_tensor) 227 | 228 | # plot_attention = [] 229 | # for num, val in enumerate(sketch_img): 230 | # image = torch.stack(attention_dictionary[num][1:], dim=0).permute(1, 2, 0, 3).reshape(3, 256, -1) 231 | # image.add_(-image.min()).div_(image.max() - image.min()+ 1e-5) 232 | # plot_attention.append(image) 233 | 234 | # return torch.stack(plot_attention) 235 | return None 236 | 237 | # def showAttention(attention_plot, sketch_img): 238 | # # Set up figure with colorbar 239 | # 240 | # attention_dictionary = {} 241 | # for num, val in enumerate(sketch_img): 242 | # attention_dictionary[num] = [] 243 | # attention_dictionary[num].append(val.cpu()) 244 | # alpha = 0.3 245 | # for x_data in attention_plot: 246 | # for num, per_image_x in enumerate(x_data): 247 | # attention = per_image_x.squeeze(0).cpu().numpy() 248 | # heatmap = cv2.applyColorMap(np.uint8(255 * attention), cv2.COLORMAP_JET) 249 | # heatmap = cv2.resize(np.float32(heatmap), (256, 256)) 250 | # 251 | # image = 255. 
- attention_dictionary[num][0].permute(1, 2, 0).numpy() 252 | # # mean = torch.tensor([-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225]).to(device) 253 | # # std = torch.tensor([1 / 0.229, 1 / 0.224, 1 / 0.225]).to(device) 254 | # # x = attention_dictionary[num][0].unsqueeze(0) 255 | # # x.sub_(mean[None, :, None, None]).div_(std[None, :, None, None]) 256 | # # x = x.squeeze(0) 257 | # # image = x.permute(1, 2, 0).numpy() 258 | # # image = cv2.resize(np.float32(np.uint8(255. * image)), (256, 256)) 259 | # 260 | # heat_map_overlay = cv2.addWeighted(heatmap, alpha, image, 1 - alpha, 0) 261 | # heat_map_tensor = torch.from_numpy(heat_map_overlay).permute(2, 0, 1) 262 | # # heat_map_tensor = attention_dictionary[num][0] + torch.from_numpy(heatmap).permute(2, 0, 1) 263 | # # heat_map_tensor = heat_map_tensor / heat_map_tensor.max() 264 | # attention_dictionary[num].append(heat_map_tensor) 265 | # 266 | # plot_attention = [] 267 | # for num, val in enumerate(sketch_img): 268 | # image = torch.stack(attention_dictionary[num][1:], dim=0).permute(1, 2, 0, 3).reshape(3, 256, -1) 269 | # image.add_(-image.min()).div_(image.max() - image.min()+ 1e-5) 270 | # plot_attention.append(image) 271 | # 272 | # return torch.stack(plot_attention) 273 | 274 | 275 | # 276 | # 277 | # 278 | # fig = plt.figure() 279 | # ax = fig.add_subplot(111) 280 | # cax = ax.matshow(attentions.numpy(), cmap='bone') 281 | # fig.colorbar(cax) 282 | # 283 | # # Set up axes 284 | # ax.set_xticklabels([''] + input_sentence.split(' ') + 285 | # [''], rotation=90) 286 | # ax.set_yticklabels([''] + output_words) 287 | # 288 | # # Show label at every tick 289 | # ax.xaxis.set_major_locator(ticker.MultipleLocator(1)) 290 | # ax.yaxis.set_major_locator(ticker.MultipleLocator(1)) 291 | 292 | def to_stroke_list(sketch): 293 | ## sketch: an `.npz` style sketch from QuickDraw 294 | sketch = np.vstack((np.array([0, 0, 0]), sketch)) 295 | sketch[:, :2] = np.cumsum(sketch[:, :2], axis=0) 296 | 297 | # range normalization 298 | xmin, xmax = sketch[:, 0].min(), sketch[:, 0].max() 299 | ymin, ymax = sketch[:, 1].min(), sketch[:, 1].max() 300 | 301 | sketch[:, 0] = ((sketch[:, 0] - xmin) / float(xmax - xmin)) * (255. - 60.) + 30. 302 | sketch[:, 1] = ((sketch[:, 1] - ymin) / float(ymax - ymin)) * (255. - 60.) + 30. 
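# (editor note) after the cumulative sum the sketch is in absolute coordinates, rescaled to roughly [30, 225] so strokes keep a margin inside the 256x256 canvas; the split below cuts the point list at every pen-lift flag to recover the individual strokes.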
303 | sketch = sketch.astype(np.int64) 304 | 305 | stroke_list = np.split(sketch[:, :2], np.where(sketch[:, 2])[0] + 1, axis=0) 306 | 307 | if stroke_list[-1].size == 0: 308 | stroke_list = stroke_list[:-1] 309 | 310 | if len(stroke_list) == 0: 311 | stroke_list = [sketch[:, :2]] 312 | # print('error') 313 | return stroke_list 314 | 315 | def save_sketch(sketch_vector, sketch_name): 316 | stroke_list = to_stroke_list(to_normal_strokes(sketch_vector.cpu().numpy())) 317 | 318 | points = np.sum([len(x) for x in stroke_list]) 319 | point_list = [len(x) for x in stroke_list] 320 | 321 | folder_name = os.path.join('./CVPR_SSL/' + '_'.join(sketch_name.split('/')[-1].split('_')[:-1]), sketch_name.split('/')[-1]) 322 | if not os.path.exists(folder_name): 323 | os.makedirs(folder_name) 324 | 325 | fig = plt.figure(frameon=False, figsize=(2.56, 2.56)) 326 | xlim = [0, 255] 327 | ylim = [0, 255] 328 | x_count = 0 329 | count = 0 330 | for stroke in stroke_list: 331 | stroke_buffer = np.array(stroke[0]) 332 | for x_num in range(len(stroke)): 333 | x_count = x_count + 1 334 | stroke_buffer = np.vstack((stroke_buffer, stroke[x_num, :2])) 335 | if x_count % 5 == 0: 336 | 337 | plt.plot(stroke_buffer[:, 0], stroke_buffer[:, 1], '.', linestyle='solid', linewidth=1.0, markersize=5) 338 | plt.gca().invert_yaxis(); 339 | plt.axis('off') 340 | 341 | plt.savefig(folder_name + '/sketch_' + str(count) + 'points_.jpg', bbox_inches='tight', 342 | pad_inches=0, dpi=1200) 343 | count = count + 1 344 | plt.gca().invert_yaxis(); 345 | 346 | 347 | plt.plot(stroke_buffer[:, 0], stroke_buffer[:, 1], '.', linestyle='solid', linewidth=1.0, markersize=5) 348 | 349 | 350 | 351 | 352 | 353 | 354 | def save_sketch_gen(sketch_vector, sketch_name): 355 | stroke_list = to_stroke_list(to_normal_strokes(sketch_vector.cpu().numpy())) 356 | 357 | folder_name = os.path.join('./CVPR_SSL/' + '_'.join(sketch_name.split('/')[-1].split('_')[:-1]), sketch_name.split('/')[-1]+'GEN') 358 | if not os.path.exists(folder_name): 359 | os.makedirs(folder_name) 360 | 361 | 362 | fig = plt.figure(frameon=False, figsize=(2.56, 2.56)) 363 | xlim = [0, 255] 364 | ylim = [0, 255] 365 | x_count = 0 366 | count = 0 367 | for stroke in stroke_list: 368 | stroke_buffer = np.array(stroke[0]) 369 | for x_num in range(len(stroke)): 370 | x_count = x_count + 1 371 | stroke_buffer = np.vstack((stroke_buffer, stroke[x_num, :2])) 372 | if x_count % 5 == 0: 373 | 374 | plt.plot(stroke_buffer[:, 0], stroke_buffer[:, 1], '.', linestyle='solid', linewidth=1.0, markersize=5) 375 | plt.gca().invert_yaxis(); 376 | plt.axis('off') 377 | 378 | plt.savefig(folder_name + '/sketch_' + str(count) + 'points_.jpg', bbox_inches='tight', 379 | pad_inches=0, dpi=1200) 380 | count = count + 1 381 | plt.gca().invert_yaxis(); 382 | 383 | 384 | plt.plot(stroke_buffer[:, 0], stroke_buffer[:, 1], '.', linestyle='solid', linewidth=1.0, markersize=5) -------------------------------------------------------------------------------- /index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 119 | 120 | 121 | 122 | 123 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | More Photos are All You Need: Semi-Supervised Learning for Fine-Grained Sketch Based Image Retrieval 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 150 |
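A small usage sketch for the stroke-format helpers defined in utils.py above (to_normal_strokes and to_stroke_list): a padded stroke-5 sequence is converted back to stroke-3 and then split into absolute-coordinate strokes. The toy square below is invented test data, not something shipped with the repository.

import numpy as np
from utils import to_normal_strokes, to_stroke_list

# Toy stroke-5 sequence: (dx, dy, pen_down, pen_up, end_of_sketch).
big_stroke = np.array([
    [10,   0, 1, 0, 0],
    [ 0,  10, 1, 0, 0],
    [-10,  0, 0, 1, 0],   # pen is lifted after this point, ending the first stroke
    [ 0, -10, 1, 0, 0],
    [ 0,   0, 0, 0, 1],   # end-of-sketch token (padding would follow in a real batch)
], dtype=np.float32)

stroke3 = to_normal_strokes(big_stroke)   # stroke-3: (dx, dy, pen_lift), trailing padding removed
stroke_list = to_stroke_list(stroke3)     # list of absolute (x, y) arrays, one per stroke
print(len(stroke_list), [s.shape for s in stroke_list])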
148 |

149 |
151 | 152 | 153 | 154 | 155 |
More Photos are All You Need:
156 |
Semi-Supervised Learning for Fine-Grained Sketch Based Image Retrieval

157 | 158 | 159 | 161 | 162 | 164 | 165 | 166 | 168 | 169 | 170 | 172 | 173 | 174 | 175 | 177 | 178 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 |
160 |
Ayan Kumar Bhunia
163 |
Pinaki Nath Chowdhury
167 |
Aneeshan Sain
171 |
Yongxin Yang
176 |
Tao (Tony) Xiang
179 |
Yi-Zhe Song
196 | 197 | 198 | 199 |
SketchX, Centre for Vision, Speech and Signal Processing,
University of Surrey, United Kingdom

200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 |
Published at CVPR 2021

208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 |
[Paper]
[GitHub]

218 | 219 | 220 | 221 | 222 | 223 | 224 | 229 |
225 | 226 | 227 |

228 |
230 |
231 | 232 |
233 | A fundamental challenge faced by existing Fine-Grained Sketch-Based Image Retrieval (FG-SBIR) models is data scarcity -- model performance is largely bottlenecked by the lack of sketch-photo pairs. Whilst the number of photos can be easily scaled, each corresponding sketch still needs to be individually produced. In this paper, we aim to mitigate such an upper bound on sketch data, and study whether unlabelled photos alone (of which there are many) can be cultivated for performance gain. In particular, we introduce a novel semi-supervised framework for cross-modal retrieval that can additionally leverage large-scale unlabelled photos to account for data scarcity. At the centre of our semi-supervision design is a sequential photo-to-sketch generation model that aims to generate paired sketches for unlabelled photos. Importantly, we further introduce a discriminator-guided mechanism to guard against unfaithful generation, together with a distillation-loss-based regularizer to provide tolerance against noisy training samples. Last but not least, we treat generation and retrieval as two conjugate problems, where a joint learning procedure is devised for each module to mutually benefit from the other. Extensive experiments show that our semi-supervised model yields a significant performance boost over state-of-the-art supervised alternatives, as well as over existing methods that can exploit unlabelled photos for FG-SBIR.
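To make the two noise-handling ingredients described above concrete, the sketch below shows, in PyTorch-style code, how a discriminator score could weight each generated (photo, pseudo-sketch) pair in the retrieval triplet loss, with a distillation term against a labelled-only teacher. The function name, loss form, and all tensor names are illustrative assumptions, not the repository's actual implementation.

import torch.nn.functional as F

def weighted_semi_supervised_loss(anchor_feat, pos_feat, neg_feat,
                                  disc_score, student_logits, teacher_logits,
                                  margin=0.2, temperature=4.0, lambda_distill=1.0):
    """Illustrative only: instance-weighted triplet loss plus a distillation regularizer."""
    # Per-instance triplet loss so that each generated pair can be re-weighted individually.
    per_instance = F.triplet_margin_loss(anchor_feat, pos_feat, neg_feat,
                                         margin=margin, reduction='none')
    # Unfaithful generations (low discriminator score) contribute less to the update.
    triplet = (disc_score.detach() * per_instance).mean()
    # Distillation: keep the similarity distribution close to a teacher trained on
    # labelled pairs only, providing tolerance against noisy pseudo pairs.
    distill = F.kl_div(F.log_softmax(student_logits / temperature, dim=-1),
                       F.softmax(teacher_logits.detach() / temperature, dim=-1),
                       reduction='batchmean') * temperature ** 2
    return triplet + lambda_distill * distill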
234 |

235 | 236 |

Framework

237 |
238 | Our framework: an FG-SBIR model leverages large-scale unlabelled photos using a sequential photo-to-sketch generation model along with labelled pairs. Discriminator-guided instance-wise weighting and a distillation loss are used to guard against the noisy generated data. Simultaneously, the photo-to-sketch generation model learns by taking reward from the FG-SBIR model and the discriminator via policy gradient (over both labelled and unlabelled data), together with a supervised VAE loss over the labelled data. Note that rasterization (vector to raster format) is a non-differentiable operation. 239 |
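Because rasterization breaks the gradient path, the generation model described above can only receive the retrieval and discriminator signals as scalar rewards. The following is a minimal REINFORCE-style sketch of such an update, written under assumed names (log_probs, retrieval_reward, disc_reward); it is an illustration, not the repository's exact training code.

def policy_gradient_step(log_probs, retrieval_reward, disc_reward, optimizer, baseline=None):
    """Illustrative REINFORCE-style update for the photo-to-sketch generator."""
    # Rasterization is non-differentiable, so the rewards enter only as detached scalars.
    reward = retrieval_reward + disc_reward
    if baseline is not None:
        reward = reward - baseline                   # optional variance reduction
    loss = -(reward.detach() * log_probs).mean()     # REINFORCE: maximise expected reward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()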

240 | 241 |

242 | 245 |
243 |

244 |
246 |

247 | 248 | 249 | 250 |

Short Presentation

251 | 252 | 255 |
253 | 254 |
256 |

257 | 258 | 259 |

Bibtex

260 | 261 | 271 | 273 | 292 | 293 | 294 | 296 | 298 | 301 | 302 |
262 | 263 | 264 |
265 | 266 | 267 | 268 | 269 |
270 |
272 | 274 | 275 | 276 |

Citation
 
More Photos are All You Need: Semi-Supervised Learning for Fine-Grained Sketch Based Image Retrieval. In CVPR 2021.

277 | 278 | [Bibtex] 279 |
280 |
281 | 
282 | @InProceedings{bhunia_semifgsbir,
283 | author = {Ayan Kumar Bhunia and Pinaki Nath Chowdhury and Aneeshan Sain and Yongxin Yang and Tao Xiang and Yi-Zhe Song},
284 | title = {More Photos are All You Need: Semi-Supervised Learning for Fine-Grained Sketch Based Image Retrieval},
285 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
286 | month = {June},
287 | year = {2021}
288 | }
289 |                 
290 |
291 |
295 | 297 | 299 | 300 |
303 | 304 | 316 | 317 | 330 | 331 | 332 | 333 | 334 | 335 | 336 | -------------------------------------------------------------------------------- /Photo_to_Sketch_2D_Attention/CVPR_18_Baseline/base_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from Image_Networks import * 3 | from Sketch_Networks import * 4 | from torch import optim 5 | import torch 6 | import time 7 | import torch.nn.functional as F 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | from utils import * 10 | import torchvision 11 | from dataset import get_imageOnly_dataloader, get_sketchOnly_dataloader, get_dataloader 12 | from rasterize import rasterize_relative, to_stroke_list 13 | import math 14 | from rasterize import batch_rasterize_relative 15 | 16 | 17 | 18 | class Photo2Sketch_Base(nn.Module): 19 | 20 | def __init__(self, hp): 21 | super(Photo2Sketch_Base, self).__init__() 22 | self.Image_Encoder = EncoderCNN() 23 | self.Image_Decoder = DecoderCNN() 24 | self.Sketch_Encoder = EncoderRNN(hp) 25 | self.Sketch_Decoder = DecoderRNN(hp) 26 | self.hp = hp 27 | self.apply(weights_init_normal) 28 | 29 | def pretrain_SketchBranch(self, iteration = 100000): 30 | 31 | dataloader = get_sketchOnly_dataloader(self.hp) 32 | self.hp.max_seq_len = self.hp.sketch_rnn_max_seq_len 33 | self.Sketch_Encoder.train() 34 | self.Sketch_Decoder.train() 35 | self.train_sketch_params = list(self.Sketch_Encoder.parameters()) + list(self.Sketch_Decoder.parameters()) 36 | self.sketch_optimizer = optim.Adam(self.train_sketch_params, self.hp.learning_rate) 37 | self.visalizer = Visualizer() 38 | 39 | for step in range(iteration): 40 | 41 | batch, lengths = dataloader.train_batch() 42 | 43 | self.sketch_optimizer.zero_grad() 44 | 45 | curr_learning_rate = ((self.hp.learning_rate - self.hp.min_learning_rate) * 46 | (self.hp.decay_rate) ** step + self.hp.min_learning_rate) 47 | curr_kl_weight = (self.hp.kl_weight - (self.hp.kl_weight - self.hp.kl_weight_start) * 48 | (self.hp.kl_decay_rate) ** step) 49 | 50 | post_dist = self.Sketch_Encoder(batch, lengths) 51 | 52 | z_vector = post_dist.rsample() 53 | start_token = torch.stack([torch.Tensor([0, 0, 1, 0, 0])] * self.hp.batch_size_sketch_rnn).unsqueeze(0).to( 54 | device) 55 | batch_init = torch.cat([start_token, batch], 0) 56 | z_stack = torch.stack([z_vector] * (self.hp.sketch_rnn_max_seq_len + 1)) 57 | inputs = torch.cat([batch_init, z_stack], 2) 58 | 59 | output, _ = self.Sketch_Decoder(inputs, z_vector, lengths + 1) 60 | 61 | end_token = torch.stack([torch.Tensor([0, 0, 0, 0, 1])] * batch.shape[1]).unsqueeze(0).to(device) 62 | batch = torch.cat([batch, end_token], 0) 63 | x_target = batch.permute(1, 0, 2) # batch-> Seq_Len, Batch, Feature_dim 64 | 65 | #################### Loss Calculation ######################################## 66 | ############################################################################## 67 | recons_loss = sketch_reconstruction_loss(output, x_target) 68 | 69 | prior_distribution = torch.distributions.Normal(torch.zeros_like(post_dist.mean), 70 | torch.ones_like(post_dist.stddev)) 71 | kl_cost = torch.max(torch.distributions.kl_divergence(post_dist, prior_distribution).mean(), 72 | torch.tensor(self.hp.kl_tolerance).to(device)) 73 | loss = recons_loss + curr_kl_weight * kl_cost 74 | 75 | #################### Update Gradient ######################################## 76 | ############################################################################# 77 | 
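# (editor note) set_learninRate (presumably a small helper in the baseline's utils) writes the exponentially decayed curr_learning_rate computed above into the optimizer's param_groups before the backward pass.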
set_learninRate(self.sketch_optimizer, curr_learning_rate) 78 | loss.backward() 79 | nn.utils.clip_grad_norm(self.train_sketch_params, self.hp.grad_clip) 80 | self.sketch_optimizer.step() 81 | 82 | if (step + 1) % 5 == 0: 83 | print('Step:{} ** KL_Loss:{} ' 84 | '** Recons_Loss:{} ** Total_loss:{}'.format(step, kl_cost.item(), 85 | recons_loss.item(), loss.item())) 86 | 87 | data = {} 88 | data['Reconstrcution_Loss'] = recons_loss 89 | data['KL_Loss'] = kl_cost 90 | data['Total Loss'] = loss 91 | self.visalizer.plot_scalars(data, step) 92 | 93 | if (step + 1) % self.hp.eval_freq_iter == 0: 94 | 95 | batch_input, batch_gen_strokes = self.sketch_generation_deterministic(dataloader) 96 | # batch_input, batch_gen_strokes = self.sketch_generation_sample(dataloader) 97 | 98 | batch_redraw = batch_rasterize_relative(batch_gen_strokes) 99 | 100 | if batch_input is not None: 101 | batch_input_redraw = batch_rasterize_relative(batch_input) 102 | batch = [] 103 | for a, b in zip(batch_input_redraw, batch_redraw): 104 | batch.append(torch.cat((a, 1. - b), dim=-1)) 105 | batch = torch.stack(batch).float() 106 | else: 107 | batch = batch_redraw.float() 108 | 109 | torchvision.utils.save_image(batch, './pretrain_sketch_Viz/deterministic/batch_rceonstruction_' + str(step) + '_.jpg', 110 | nrow=round(math.sqrt(len(batch)))) 111 | 112 | torch.save(self.Sketch_Encoder.state_dict(), './pretrain_models/Sketch_Encoder.pth') 113 | torch.save(self.Sketch_Decoder.state_dict(), './pretrain_models/Sketch_Decoder.pth') 114 | 115 | self.Sketch_Encoder.train() 116 | self.Sketch_Decoder.train() 117 | 118 | 119 | 120 | def sketch_generation_deterministic(self, dataloader, number_of_sample=64, condition = True): 121 | 122 | self.Sketch_Encoder.eval() 123 | self.Sketch_Decoder.eval() 124 | 125 | batch, lengths = dataloader.valid_batch(number_of_sample) 126 | if condition: 127 | post_dist = self.Sketch_Encoder(batch, lengths) 128 | z_vector = post_dist.sample() 129 | else: 130 | z_vector = torch.randn(number_of_sample, 128).to(device) 131 | 132 | start_token = torch.Tensor([0, 0, 1, 0, 0]).view(-1, 5).to(device) 133 | start_token = torch.stack([start_token] * number_of_sample, dim=1) 134 | state = start_token 135 | hidden_cell = None 136 | 137 | batch_gen_strokes = [] 138 | for i_seq in range(self.hp.average_seq_len): 139 | input = torch.cat([state, z_vector.unsqueeze(0)], 2) 140 | state, hidden_cell = self.Sketch_Decoder(input, z_vector, hidden_cell=hidden_cell, isTrain=False, get_deterministic=True) 141 | batch_gen_strokes.append(state.squeeze(0)) 142 | 143 | batch_gen_strokes = torch.stack(batch_gen_strokes, dim=1) 144 | 145 | if condition: 146 | return batch.permute(1, 0, 2), batch_gen_strokes 147 | else: 148 | return None, batch_gen_strokes 149 | 150 | 151 | def sketch_generation_sample(self, dataloader, number_of_sample=64, condition = True): 152 | 153 | self.Sketch_Encoder.eval() 154 | self.Sketch_Decoder.eval() 155 | 156 | batch_gen_strokes = [] 157 | batch_input = [] 158 | 159 | for i_x in range(number_of_sample): 160 | batch, lengths = dataloader.valid_batch(1) 161 | 162 | if condition: 163 | post_dist = self.Sketch_Encoder(batch, lengths) 164 | z_vector = post_dist.sample() 165 | else: 166 | z_vector = torch.randn(1,128).to(device) 167 | 168 | start_token = torch.Tensor([0,0,1,0,0]).view(1, 1, -1).to(device) 169 | state = start_token 170 | hidden_cell = None 171 | gen_strokes = [] 172 | for i in range(self.hp.sketch_rnn_max_seq_len): 173 | input = torch.cat([state, z_vector.unsqueeze(0)],2) 174 | output, 
hidden_cell = self.Sketch_Decoder(input, z_vector, hidden_cell = hidden_cell, isTrain = False, get_deterministic=False) 175 | state, next_state = sample_next_state(output, self.hp) 176 | gen_strokes.append(next_state) 177 | 178 | gen_strokes = torch.stack(gen_strokes) 179 | batch_gen_strokes.append(gen_strokes) 180 | batch_input.append(batch.squeeze(1)) 181 | 182 | batch_gen_strokes = torch.stack(batch_gen_strokes, dim=1) 183 | batch_input = torch.stack(batch_input, dim=1) 184 | 185 | if condition: 186 | return batch_input.permute(1, 0, 2), batch_gen_strokes.permute(1, 0, 2) 187 | else: 188 | return None, batch_gen_strokes.permute(1, 0, 2) 189 | 190 | 191 | def pretrain_ImageBranch(self, epoch = 200): 192 | 193 | image_dataloader = get_imageOnly_dataloader() 194 | self.Image_Encoder.train() 195 | self.Image_Decoder.train() 196 | self.train_image_params = list(self.Image_Encoder.parameters()) + list(self.Image_Decoder.parameters()) 197 | self.image_optimizer = optim.Adam(self.train_image_params, self.hp.learning_rate) 198 | step = 0 199 | self.visalizer = Visualizer() 200 | 201 | for i_epoch in range(epoch): 202 | 203 | for _, batch_sample in enumerate(image_dataloader, 0): 204 | 205 | step = step + 1 206 | self.image_optimizer.zero_grad() 207 | 208 | batch_image = batch_sample[0].to(device) 209 | post_dist = self.Image_Encoder(batch_image) 210 | z_vector = post_dist.rsample() 211 | recons_batch_image = self.Image_Decoder(z_vector) 212 | 213 | # batch_image_normalized = transfer_ImageNomralization(batch_image, 'to_Gen') 214 | batch_image_normalized = batch_image 215 | recons_loss = F.mse_loss(batch_image_normalized, recons_batch_image, reduction='sum')/batch_image.shape[0] 216 | # recons_loss = F.mse_loss(batch_image_normalized, recons_batch_image) 217 | 218 | prior_distribution = torch.distributions.Normal(torch.zeros_like(post_dist.mean), torch.ones_like(post_dist.stddev)) 219 | kl_cost = torch.distributions.kl_divergence(post_dist, prior_distribution).sum(1).mean() 220 | 221 | loss = recons_loss + kl_cost 222 | 223 | # log_var = torch.log(post_dist.stddev**2) 224 | # loss_matrx = 1 + log_var - post_dist.loc ** 2 - log_var.exp() 225 | # loss_matrx_sum = torch.sum(loss_matrx, dim=1) 226 | # kld_loss = torch.mean(-0.5 * loss_matrx_sum, dim=0) 227 | 228 | loss.backward() 229 | nn.utils.clip_grad_norm(self.train_image_params, self.hp.grad_clip) 230 | self.image_optimizer.step() 231 | 232 | 233 | 234 | if (step + 1) % 20 == 0: 235 | print('Step:{} ** KL_Loss:{} ' 236 | '** Recons_Loss:{} ** Total_loss:{}'.format(step, kl_cost.item(), 237 | recons_loss.item(), loss.item())) 238 | 239 | data = {} 240 | data['Reconstrcution_Loss'] = recons_loss 241 | data['KL_Loss'] = kl_cost 242 | data['Total Loss'] = loss 243 | self.visalizer.plot_scalars(data, step) 244 | 245 | data = {} 246 | data['Input_Image'] = batch_image 247 | data['Recons_Image'] = recons_batch_image 248 | sample_z = torch.randn_like(z_vector) 249 | data['Sampled_Image'] = self.Image_Decoder(sample_z) 250 | self.visalizer.vis_image(data, step) 251 | 252 | 253 | if (step + 1) % self.hp.eval_freq_iter == 0: 254 | saved_tensor = torch.cat([batch_image_normalized, recons_batch_image], dim=0) 255 | torchvision.utils.save_image(saved_tensor, './pretrain_image_Viz/'+ str(step) + '.jpg', normalize=True) 256 | torch.save(self.Image_Encoder.state_dict(), './pretrain_models/Image_Encoder' + str(step) + '.pth') 257 | torch.save(self.Image_Decoder.state_dict(), './pretrain_models/Image_Decoder' + str(step) + '.pth') 258 | 259 | def 
pretrain_SketchBranch_ShoeV2(self, iteration = 10000): 260 | 261 | self.hp.batchsize = 100 262 | dataloader_Train, dataloader_Test = get_dataloader(self.hp) 263 | 264 | self.Sketch_Encoder.train() 265 | self.Sketch_Decoder.train() 266 | 267 | self.train_sketch_params = list(self.Sketch_Encoder.parameters()) + list(self.Sketch_Decoder.parameters()) 268 | self.sketch_optimizer = optim.Adam(self.train_sketch_params, self.hp.learning_rate) 269 | 270 | self.visalizer = Visualizer() 271 | 272 | step =0 273 | 274 | for i_epoch in range(2000): 275 | 276 | for batch_data in dataloader_Train: 277 | 278 | batch = batch_data['relative_fivePoint'].to(device).permute(1, 0, 2).float() # Seq_Len, Batch, Feature 279 | lengths = batch_data['sketch_length'].to(device) - 1 # TODO: Relative coord has one less 280 | step += 1 281 | # batch, lengths = dataloader.train_batch() 282 | 283 | self.sketch_optimizer.zero_grad() 284 | 285 | curr_learning_rate = ((self.hp.learning_rate - self.hp.min_learning_rate) * 286 | (self.hp.decay_rate) ** step + self.hp.min_learning_rate) 287 | curr_kl_weight = (self.hp.kl_weight - (self.hp.kl_weight - self.hp.kl_weight_start) * 288 | (self.hp.kl_decay_rate) ** step) 289 | 290 | post_dist = self.Sketch_Encoder(batch, lengths) 291 | 292 | z_vector = post_dist.rsample() 293 | start_token = torch.stack([torch.Tensor([0, 0, 1, 0, 0])] * batch.shape[1]).unsqueeze(0).to(device) 294 | batch_init = torch.cat([start_token, batch], 0) 295 | z_stack = torch.stack([z_vector] * (self.hp.max_seq_len + 1)) 296 | inputs = torch.cat([batch_init, z_stack], 2) 297 | 298 | output, _ = self.Sketch_Decoder(inputs, z_vector, lengths + 1) 299 | 300 | end_token = torch.stack([torch.Tensor([0, 0, 0, 0, 1])] * batch.shape[1]).unsqueeze(0).to(device) 301 | batch = torch.cat([batch, end_token], 0) 302 | x_target = batch.permute(1, 0, 2) # batch-> Seq_Len, Batch, Feature_dim 303 | 304 | #################### Loss Calculation ######################################## 305 | ############################################################################## 306 | recons_loss = sketch_reconstruction_loss(output, x_target) 307 | 308 | prior_distribution = torch.distributions.Normal(torch.zeros_like(post_dist.mean), 309 | torch.ones_like(post_dist.stddev)) 310 | kl_cost = torch.max(torch.distributions.kl_divergence(post_dist, prior_distribution).mean(), 311 | torch.tensor(self.hp.kl_tolerance).to(device)) 312 | loss = recons_loss + curr_kl_weight * kl_cost 313 | 314 | #################### Update Gradient ######################################## 315 | ############################################################################# 316 | set_learninRate(self.sketch_optimizer, curr_learning_rate) 317 | loss.backward() 318 | nn.utils.clip_grad_norm(self.train_sketch_params, self.hp.grad_clip) 319 | self.sketch_optimizer.step() 320 | 321 | if (step + 1) % 5 == 0: 322 | print('Step:{} ** KL_Loss:{} ' 323 | '** Recons_Loss:{} ** Total_loss:{}'.format(step, kl_cost.item(), 324 | recons_loss.item(), loss.item())) 325 | data = {} 326 | data['Reconstrcution_Loss'] = recons_loss 327 | data['KL_Loss'] = kl_cost 328 | data['Total Loss'] = loss 329 | self.visalizer.plot_scalars(data, step) 330 | 331 | if (step -1) % 1000 == 0: 332 | 333 | """ Draw Sketch to Sketch """ 334 | start_token = torch.Tensor([0, 0, 1, 0, 0]).view(-1, 5).to(device) 335 | start_token = torch.stack([start_token] * z_vector.shape[0], dim=1) 336 | state = start_token 337 | hidden_cell = None 338 | 339 | batch_gen_strokes = [] 340 | for i_seq in 
range(self.hp.average_seq_len): 341 | input = torch.cat([state, z_vector.unsqueeze(0)], 2) 342 | state, hidden_cell = self.Sketch_Decoder(input, z_vector, hidden_cell=hidden_cell, 343 | isTrain=False, 344 | get_deterministic=True) 345 | batch_gen_strokes.append(state.squeeze(0)) 346 | 347 | sketch2sketch_gen = torch.stack(batch_gen_strokes, dim=1) 348 | sketch_vector_gt = batch.permute(1, 0, 2) 349 | 350 | sketch_vector_gt_draw = batch_rasterize_relative(sketch_vector_gt).to(device) 351 | sketch2sketch_gen_draw = batch_rasterize_relative(sketch2sketch_gen).to(device) 352 | 353 | batch_redraw = [] 354 | for a, b in zip(sketch_vector_gt_draw, sketch2sketch_gen_draw): 355 | batch_redraw.append(torch.cat((a, 1.- b), dim=-1)) 356 | 357 | torchvision.utils.save_image(torch.stack(batch_redraw), 358 | './pretrain_sketch_Viz/ShoeV2/redraw_{}.jpg'.format(step), 359 | nrow=8) 360 | 361 | torch.save(self.Sketch_Encoder.state_dict(), './pretrain_models/ShoeV2/Sketch_Encoder.pth') 362 | torch.save(self.Sketch_Decoder.state_dict(), './pretrain_models/ShoeV2/Sketch_Decoder.pth') 363 | 364 | self.Sketch_Encoder.train() 365 | self.Sketch_Decoder.train() 366 | 367 | 368 | 369 | def freeze_weights(self): 370 | for name, x in self.named_parameters(): 371 | x.requires_grad = False 372 | 373 | 374 | def Unfreeze_weights(self): 375 | for name, x in self.named_parameters(): 376 | x.requires_grad = True 377 | 378 | 379 | 380 | 381 | 382 | --------------------------------------------------------------------------------
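A hedged usage sketch for the baseline above: the hyper-parameter names are the ones Photo2Sketch_Base actually reads in the methods shown, but the default values are illustrative guesses, and any extra fields required by EncoderRNN/DecoderRNN (e.g. hidden sizes, num_mixture) are omitted.

import argparse
from base_model import Photo2Sketch_Base

parser = argparse.ArgumentParser(description='Photo2Sketch pretraining (illustrative driver)')
parser.add_argument('--dataset_name', type=str, default='ShoeV2')
parser.add_argument('--batchsize', type=int, default=64)
parser.add_argument('--nThreads', type=int, default=8)
parser.add_argument('--learning_rate', type=float, default=1e-3)        # value is a guess
parser.add_argument('--min_learning_rate', type=float, default=1e-5)    # value is a guess
parser.add_argument('--decay_rate', type=float, default=0.9999)         # value is a guess
parser.add_argument('--kl_weight', type=float, default=0.5)             # value is a guess
parser.add_argument('--kl_weight_start', type=float, default=0.01)      # value is a guess
parser.add_argument('--kl_decay_rate', type=float, default=0.99995)     # value is a guess
parser.add_argument('--kl_tolerance', type=float, default=0.2)          # value is a guess
parser.add_argument('--grad_clip', type=float, default=1.0)             # value is a guess
parser.add_argument('--eval_freq_iter', type=int, default=1000)         # value is a guess
parser.add_argument('--batch_size_sketch_rnn', type=int, default=100)
hp = parser.parse_args()

model = Photo2Sketch_Base(hp).to('cuda')
model.pretrain_SketchBranch(iteration=100000)    # VAE pretraining of the sketch branch
# model.pretrain_ImageBranch(epoch=200)          # or pretrain the image VAE instead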