├── imgs
│   ├── network.png
│   ├── framework.png
│   ├── photo_synthesis.png
│   └── sketch_synthesis.png
├── data.py
├── adain.py
├── compute_fsim.m
├── LICENSE
├── vgg.py
├── util.py
├── README.md
├── fusion.py
├── dataset.py
├── test.py
├── inference.py
├── resnet.py
├── modules.py
├── train.py
├── model.py
└── networks.py
/imgs/network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tony0720/DCNP/HEAD/imgs/network.png -------------------------------------------------------------------------------- /imgs/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tony0720/DCNP/HEAD/imgs/framework.png -------------------------------------------------------------------------------- /imgs/photo_synthesis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tony0720/DCNP/HEAD/imgs/photo_synthesis.png -------------------------------------------------------------------------------- /imgs/sketch_synthesis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tony0720/DCNP/HEAD/imgs/sketch_synthesis.png -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | 3 | from dataset import DatasetFromFolder 4 | 5 | 6 | def get_training_set(root_dir): 7 | train_dir = join(root_dir, "train") 8 | 9 | return DatasetFromFolder(train_dir) 10 | 11 | 12 | def get_test_set(root_dir): 13 | test_dir = join(root_dir, "test") 14 | 15 | return DatasetFromFolder(test_dir) 16 | -------------------------------------------------------------------------------- /adain.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | def calc_mean_std(feat, eps=1e-5): 4 | # eps is a small value added to the variance to avoid divide-by-zero.
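# The input is expected to be a 4-D feature map of shape (N, C, H, W); the mean and
# std are computed per sample and per channel over the flattened spatial positions
# and reshaped to (N, C, 1, 1) so they broadcast against the input tensor.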
5 | size = feat.size() 6 | assert (len(size) == 4) 7 | N, C = size[:2] 8 | feat_var = feat.view(N, C, -1).var(dim=2) + eps 9 | feat_std = feat_var.sqrt().view(N, C, 1, 1) 10 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) 11 | return feat_mean, feat_std 12 | 13 | # AdaIN 14 | def adaptive_instance_normalization(content_feat, style_feat): 15 | assert (content_feat.size()[:2] == style_feat.size()[:2]) 16 | size = content_feat.size() 17 | style_mean, style_std = calc_mean_std(style_feat) 18 | content_mean, content_std = calc_mean_std(content_feat) 19 | 20 | normalized_feat = (content_feat - content_mean.expand( 21 | size)) / content_std.expand(size) 22 | return normalized_feat * style_std.expand(size) + style_mean.expand(size) 23 | -------------------------------------------------------------------------------- /compute_fsim.m: -------------------------------------------------------------------------------- 1 | probe_path = '.\CUFS\DCNP\Photo'; 2 | gallery_path = '.\CUFS\GroundTruth\Photo'; 3 | 4 | probe_list = readImageNames(probe_path); 5 | 6 | fsim_value = 0.; 7 | fsim_sum = 0.; 8 | fsim_average = 0.; 9 | 10 | for i = 1 : length(probe_list) 11 | 12 | probe = imread(fullfile(probe_path, probe_list(i).name)); 13 | 14 | [height width ch] = size(probe); 15 | if ch == 3 16 | probe = rgb2gray(probe); 17 | end 18 | probe = double(probe); 19 | 20 | gallery = imread(fullfile(gallery_path, probe_list(i).name)); 21 | 22 | 23 | [height width ch] = size(gallery); 24 | if ch == 3 25 | gallery = rgb2gray(gallery); 26 | end 27 | gallery = double(gallery); 28 | 29 | fsim_value = fsim(probe, gallery); 30 | %fprintf('\n fsim: %f\n', fsim_value) 31 | fsim_sum = fsim_sum + fsim_value; 32 | end 33 | 34 | fsim_average = fsim_sum / length(probe_list); 35 | fprintf('\nAverage fsim is %f\n', fsim_average); 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Tony0720 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
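For reference, a minimal usage sketch of the AdaIN helper defined in adain.py above (assumes adain.py is importable from the working directory; the tensor shapes are only illustrative):

```
import torch

from adain import adaptive_instance_normalization

# Two random 4-D feature maps of shape (batch, channels, height, width).
content = torch.randn(2, 64, 32, 32)
style = torch.randn(2, 64, 32, 32)

# Re-normalizes the content features so that, per sample and per channel,
# they carry the mean/std statistics of the style features.
out = adaptive_instance_normalization(content, style)
print(out.shape)  # torch.Size([2, 64, 32, 32])
```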
22 | -------------------------------------------------------------------------------- /vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | 5 | 6 | class VGG19(torch.nn.Module): 7 | def __init__(self, requires_grad=False): 8 | super().__init__() 9 | vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features 10 | self.slice1 = torch.nn.Sequential() 11 | self.slice2 = torch.nn.Sequential() 12 | self.slice3 = torch.nn.Sequential() 13 | self.slice4 = torch.nn.Sequential() 14 | self.slice5 = torch.nn.Sequential() 15 | for x in range(1): 16 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 17 | for x in range(1, 6): 18 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 19 | for x in range(6, 11): 20 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 21 | for x in range(11, 20): 22 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 23 | for x in range(20, 29): 24 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 25 | if not requires_grad: 26 | for param in self.parameters(): 27 | param.requires_grad = False 28 | 29 | def forward(self, X): 30 | c11 = self.slice1(X) 31 | c21 = self.slice2(c11) 32 | c31 = self.slice3(c21) 33 | c41 = self.slice4(c31) 34 | c51 = self.slice5(c41) 35 | out = [c11, c21, c31, c41, c51] 36 | return out -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch 4 | 5 | 6 | def is_image_file(filename): 7 | return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"]) 8 | 9 | 10 | def load_img(filepath): 11 | img = Image.open(filepath).convert('RGB') 12 | H,W=img.size 13 | img = img.resize((256,256), Image.BICUBIC) 14 | return img,H,W 15 | 16 | 17 | def save_img(image_tensor, H,W,filename): 18 | image_numpy = image_tensor.float().numpy() 19 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 20 | image_numpy = image_numpy.clip(0, 255) 21 | image_numpy = image_numpy.astype(np.uint8) 22 | image_pil = Image.fromarray(image_numpy) 23 | image_pil = image_pil.resize((H,W), Image.BICUBIC) 24 | image_pil.save(filename) 25 | #print("Image saved as {}".format(filename)) 26 | 27 | 28 | def save_feature_map(feature_map, filename): 29 | image_numpy = feature_map.float().numpy() 30 | image_numpy = (image_numpy + 1) / 2.0 * 255.0 31 | image_numpy = image_numpy.clip(0, 255) 32 | image_numpy = image_numpy.astype(np.uint8) 33 | image_pil = Image.fromarray(image_numpy) 34 | image_pil = image_pil.resize((200, 250), Image.BICUBIC) 35 | image_pil.save(filename) 36 | 37 | 38 | def tensor2img(image_tensor): 39 | image_numpy = image_tensor.float().numpy() 40 | image_numpy = (image_numpy + 1) / 2.0 * 255.0 41 | image_numpy = image_numpy.clip(0, 255) 42 | image_numpy = image_numpy.astype(np.uint8) 43 | return image_numpy 44 | 45 | def get_attention(fake_label, real_label): 46 | error_map = torch.abs(fake_label.detach() - real_label.detach()) * 0.5 47 | return error_map 48 | 49 | def get_facial_label(facial_tensor): 50 | facial_label = torch.argmax(facial_tensor, 1).unsqueeze(1) 51 | facial_label = (facial_label.float() / 18.0 - 0.5) * 2.0 52 | return facial_label 53 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | ## DCNP 2 | 3 | PyTorch code for "Dual Conditional Normalization Pyramid Network 4 | for Face Photo-Sketch Synthesis". 5 | 6 | ![framework](/imgs/framework.png) 7 | 8 | ![network](/imgs/network.png) 9 | 10 | ### Requirements 11 | 12 | + Ubuntu 18.04 13 | + Anaconda (Python, NumPy, PIL, etc.) 14 | + PyTorch 1.7.1 15 | + torchvision 0.8.2 16 | 17 | ### Prepare data 18 | 19 | 1. Create the folder '/data/'. 20 | 21 | 2. Download the datasets from [Google Drive](https://drive.google.com/file/d/1K9EXuHCu2zeQ1WP2JVhAc3yWfN0rIjtE/view?usp=sharing) and put them into '/data'. 22 | 23 | ### Inference: 24 | 25 | 1. Create the folder '/checkpoint/pretrained/'. 26 | 27 | 2. Download the pre-trained models from [Google Drive](https://drive.google.com/file/d/1_S3Iy22RLfeG9dCCBq8tsbwTUmvpeFyh/view?usp=sharing) and put them into '/checkpoint/pretrained/'. 28 | 29 | 3. Choose the dataset from ['cuhk', 'ar', 'xmwvts', 'cuhk_feret', 'WildSketch']. 30 | 31 | 4. Run: 32 | 33 | ``` 34 | python inference.py 35 | ``` 36 | 5. Check the results under './result/pretrained/'. 37 | 38 | ### Train: 39 | 40 | 1. Choose the dataset from ['cuhk', 'ar', 'xmwvts', 'cuhk_feret', 'WildSketch'] and set your output_path. 41 | 42 | 2. Run: 43 | 44 | ``` 45 | python train.py 46 | ``` 47 | 48 | ### Test: 49 | 50 | 1. Choose the dataset from ['cuhk', 'ar', 'xmwvts', 'cuhk_feret', 'WildSketch'] and make sure your output_path matches the name used at the training stage. 51 | 52 | 2. Run: 53 | 54 | ``` 55 | python test.py 56 | ``` 57 | 58 | 3. Check the results under './result/'. 59 | 60 | ### Results 61 | 62 | Our final results can be downloaded [here](https://drive.google.com/file/d/1iLesbjhFp5oYkOTSKzwgO_wUvTZ61Z9-/view?usp=sharing). 63 | 64 | ![photo_synthesis](/imgs/photo_synthesis.png) 65 | ![sketch_synthesis](/imgs/sketch_synthesis.png) 66 | 67 | ### Evaluation 68 | 69 | MATLAB is required to compute the FSIM metric with [compute_fsim.m](https://github.com/Tony0720/Dual-Conditional-Normalization-Pyramid-Network-for-Face-Photo-Sketch-Synthesis/blob/main/compute_fsim.m). 70 | For FID evaluation, refer to [pytorch-fid](https://github.com/mseitzer/pytorch-fid). 71 | For LPIPS evaluation, refer to [PerceptualSimilarity](https://github.com/richzhang/PerceptualSimilarity). 72 | 73 | ### Acknowledgments 74 | 75 | * This code builds heavily on **[pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix)**. Thanks for open-sourcing!
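### Example commands

The dataset and output folder can also be passed on the command line instead of editing the script defaults; the flags below are the ones defined via argparse in train.py, test.py, and inference.py, and 'cuhk' / 'exp1' are simply their default values, used here as placeholders:

```
python train.py --dataset cuhk --output_path exp1
python test.py --dataset cuhk --output_path exp1 --nepochs 200
python inference.py --dataset cuhk
```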
76 | -------------------------------------------------------------------------------- /fusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class DBFM(nn.Module): 5 | def __init__(self, nf): 6 | super(DBFM, self).__init__() 7 | self.conv_a = nn.Conv2d(nf, nf, 1, 1, 0, bias=True) 8 | self.conv_gate_a = nn.Conv2d(nf, nf, 1, 1, 0, bias=True) 9 | self.conv_b = nn.Conv2d(nf, nf, 1, 1, 0, bias=True) 10 | self.conv_gate_b = nn.Conv2d(nf, nf, 1, 1, 0, bias=True) 11 | self.sigmoid = nn.Sigmoid() 12 | self.conv = nn.Sequential( 13 | nn.Conv2d(nf, nf, 3, 1, 1, bias=True), 14 | nn.BatchNorm2d(nf), 15 | nn.ReLU(inplace=True), 16 | nn.Conv2d(nf, nf, 3, 1, 1, bias=True), 17 | nn.BatchNorm2d(nf), 18 | nn.ReLU(inplace=True) 19 | ) 20 | self.global_att = nn.Sequential( 21 | nn.AdaptiveAvgPool2d(1), 22 | nn.Conv2d(nf, int(nf//4), kernel_size=1, stride=1, padding=0), 23 | # nn.BatchNorm2d(inter_channels), 24 | nn.ReLU(inplace=True), 25 | nn.Conv2d(int(nf//4),nf, kernel_size=1, stride=1, padding=0), 26 | # nn.BatchNorm2d(channels), 27 | ) 28 | 29 | def forward(self, a, b): 30 | a = self.conv_a(a) 31 | g_a = self.conv_gate_a(a) 32 | g_a = self.sigmoid(g_a) 33 | b = self.conv_b(b) 34 | g_b = self.conv_gate_b(b) 35 | g_b = self.sigmoid(g_b) 36 | x = self.global_att(a+b) 37 | wei = self.sigmoid(x) 38 | a_gff = wei*(1+g_a)*a+(1-wei)*(1+g_b)*b 39 | out = self.conv(a_gff) 40 | 41 | return out 42 | 43 | class AFF(nn.Module): 44 | ''' 45 | 多特征融合 AFF 46 | ''' 47 | 48 | def __init__(self, channels=64, r=4): 49 | super(AFF, self).__init__() 50 | inter_channels = int(channels // r) 51 | 52 | self.local_att = nn.Sequential( 53 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 54 | nn.BatchNorm2d(inter_channels), 55 | nn.ReLU(inplace=True), 56 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 57 | nn.BatchNorm2d(channels), 58 | ) 59 | 60 | self.global_att = nn.Sequential( 61 | nn.AdaptiveAvgPool2d(1), 62 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 63 | # nn.BatchNorm2d(inter_channels), 64 | nn.ReLU(inplace=True), 65 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 66 | # nn.BatchNorm2d(channels), 67 | ) 68 | self.sigmoid = nn.Sigmoid() 69 | 70 | def forward(self, x, residual): 71 | xa = x + residual 72 | xl = self.local_att(xa) 73 | xg = self.global_att(xa) 74 | 75 | 76 | xlg = xl + xg 77 | wei = self.sigmoid(xlg) 78 | 79 | xo = 2 * x * wei + 2 * residual * (1 - wei) 80 | return xo -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from os import listdir 2 | from os.path import join 3 | import random 4 | 5 | from PIL import Image 6 | import torch 7 | import torch.utils.data as data 8 | import torchvision.transforms as transforms 9 | 10 | from util import is_image_file, load_img 11 | 12 | 13 | class DatasetFromFolder(data.Dataset): 14 | def __init__(self, image_dir): 15 | super(DatasetFromFolder, self).__init__() 16 | self.a_path = join(image_dir, "a") 17 | self.b_path = join(image_dir, "b") 18 | self.image_filenames = [x for x in listdir(self.a_path) if is_image_file(x)] 19 | self.random_filenames = self.image_filenames[:] 20 | random.shuffle(self.random_filenames) 21 | # self.b_image_filenames = [x for x in listdir(self.b_path) if is_image_file(x)] 22 | 23 | transform_list = [transforms.ToTensor(), 24 | 
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 25 | 26 | self.transform = transforms.Compose(transform_list) 27 | 28 | # def __getitem__(self, index): 29 | # # Load Image 30 | # a = load_img(join(self.a_path, self.image_filenames[index])) 31 | # a = self.transform(a) 32 | # b = load_img(join(self.b_path, self.image_filenames[index])) 33 | # b = self.transform(b) 34 | 35 | # return a, b 36 | 37 | def __getitem__(self, index): 38 | a = Image.open(join(self.a_path, self.image_filenames[index])).convert('RGB') 39 | # index_b = random.randint(0, len(self.b_image_filenames) - 1) 40 | b = Image.open(join(self.b_path, self.image_filenames[index])).convert('RGB') 41 | a = a.resize((286,286), Image.BICUBIC) 42 | b = b.resize((286,286), Image.BICUBIC) 43 | 44 | a = transforms.ToTensor()(a) 45 | b = transforms.ToTensor()(b) 46 | w_offset = random.randint(0, max(0, 286 - 256 - 1)) 47 | h_offset = random.randint(0, max(0, 286 - 256 - 1)) 48 | 49 | a = a[:, h_offset:h_offset + 256, w_offset:w_offset + 256] 50 | b = b[:, h_offset:h_offset + 256, w_offset:w_offset + 256] 51 | 52 | a = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(a) 53 | b = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(b) 54 | 55 | a0 = Image.open(join(self.a_path, self.random_filenames[index])).convert('RGB') 56 | # index_b = random.randint(0, len(self.b_image_filenames) - 1) 57 | b0 = Image.open(join(self.b_path, self.random_filenames[index])).convert('RGB') 58 | a0 = a0.resize((286,286), Image.BICUBIC) 59 | b0 = b0.resize((286,286), Image.BICUBIC) 60 | 61 | a0 = transforms.ToTensor()(a0) 62 | b0 = transforms.ToTensor()(b0) 63 | w_offset = random.randint(0, max(0, 286 - 256 - 1)) 64 | h_offset = random.randint(0, max(0, 286 - 256 - 1)) 65 | 66 | a0 = a0[:, h_offset:h_offset + 256, w_offset:w_offset + 256] 67 | b0 = b0[:, h_offset:h_offset + 256, w_offset:w_offset + 256] 68 | 69 | a0 = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(a0) 70 | b0 = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(b0) 71 | 72 | return a, b, a0, b0 73 | 74 | def __len__(self): 75 | return len(self.image_filenames) 76 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import random 5 | import torch 6 | import torchvision.transforms as transforms 7 | from util import is_image_file, load_img, save_img 8 | from model import BiSeNet 9 | 10 | # Testing settings 11 | parser = argparse.ArgumentParser(description='Residual Nets') 12 | parser.add_argument('--dataset',type=str,default='cuhk', help='[cuhk, ar, xmwvts, cuhk_feret, WildSketch]') 13 | parser.add_argument('--output_path', type=str,default='exp1', help='output path') 14 | parser.add_argument('--nepochs', type=int, default=200, help='saved model of which epochs') 15 | parser.add_argument('--cuda', type=str,default=True,help='use cuda') 16 | opt = parser.parse_args() 17 | print(opt) 18 | 19 | device = torch.device("cuda:0" if opt.cuda else "cpu") 20 | 21 | a2b_model_path = "./checkpoint/{}/netG_a2b_model_epoch_{}.pth".format(opt.output_path, opt.nepochs) 22 | b2a_model_path = "./checkpoint/{}/netG_b2a_model_epoch_{}.pth".format(opt.output_path, opt.nepochs) 23 | net_g_a2b = torch.load(a2b_model_path).to(device) 24 | net_g_b2a = torch.load(b2a_model_path).to(device) 25 | 26 | n_classes = 19 27 | parsing_net = BiSeNet(n_classes=n_classes).to(device) 28 | 
parsing_net.load_state_dict(torch.load('face_parsing_bisenet.pth')) 29 | parsing_net.eval() 30 | for param in parsing_net.parameters(): 31 | param.requires_grad = False 32 | 33 | a_dir = "data/{}/test/a/".format(opt.dataset) 34 | b_dir = "data/{}/test/b/".format(opt.dataset) 35 | 36 | a_image_filenames = [x for x in os.listdir(a_dir) if is_image_file(x)] 37 | b_image_filenames = [x for x in os.listdir(b_dir) if is_image_file(x)] 38 | 39 | transform_list = [transforms.ToTensor(), 40 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 41 | 42 | transform = transforms.Compose(transform_list) 43 | 44 | if not os.path.exists('./result'): 45 | os.mkdir('./result') 46 | if not os.path.exists(os.path.join('./result', opt.output_path)): 47 | os.mkdir(os.path.join('./result', opt.output_path)) 48 | os.mkdir(os.path.join('./result', opt.output_path,'a2b')) 49 | os.mkdir(os.path.join('./result', opt.output_path,'b2a')) 50 | 51 | 52 | for image_name in a_image_filenames: 53 | img_a,Ha,Wa = load_img(a_dir + image_name) 54 | img_a = transform(img_a) 55 | img_a = img_a.unsqueeze(0).to(device) 56 | a_pfea = parsing_net(img_a.detach())[0] 57 | m=random.choice([x for x in os.listdir(b_dir) if is_image_file(x)]) 58 | img_b,_,_ = load_img(b_dir + m) 59 | img_b = transform(img_b) 60 | img_b = img_b.unsqueeze(0).to(device) 61 | b_gen = net_g_a2b(img_a,img_b,a_pfea) 62 | b_gen = b_gen.detach().squeeze(0).cpu() 63 | save_img(b_gen,Ha,Wa, "./result/{}/{}/{}".format(opt.output_path, 'a2b', image_name)) 64 | del img_a,img_b,b_gen,a_pfea 65 | torch.cuda.empty_cache() 66 | 67 | for image_name in b_image_filenames: 68 | img_b,Hb,Wb = load_img(b_dir + image_name) 69 | img_b = transform(img_b) 70 | img_b = img_b.unsqueeze(0).to(device) 71 | b_pfea = parsing_net(img_b.detach())[0] 72 | n=random.choice([x for x in os.listdir(a_dir) if is_image_file(x)]) 73 | img_a,_,_ = load_img(a_dir + n) 74 | img_a = transform(img_a) 75 | img_a = img_a.unsqueeze(0).to(device) 76 | a_gen = net_g_b2a(img_b,img_a,b_pfea) 77 | a_gen = a_gen.detach().squeeze(0).cpu() 78 | save_img(a_gen,Hb,Wb, "./result/{}/{}/{}".format(opt.output_path, 'b2a', image_name)) 79 | del img_a,img_b,a_gen,b_pfea 80 | torch.cuda.empty_cache() 81 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import random 5 | import torch 6 | import torchvision.transforms as transforms 7 | from util import is_image_file, load_img, save_img 8 | from model import BiSeNet 9 | 10 | # Testing settings 11 | parser = argparse.ArgumentParser(description='Residual Nets') 12 | parser.add_argument('--dataset',type=str,default='cuhk', help='[cuhk, ar, xmwvts, cuhk_feret, WildSketch]') 13 | parser.add_argument('--output_path', type=str,default='pretrained', help='output path') 14 | parser.add_argument('--cuda', type=str,default=True,help='use cuda') 15 | opt = parser.parse_args() 16 | print(opt) 17 | 18 | device = torch.device("cuda:0" if opt.cuda else "cpu") 19 | 20 | a2b_model_path = "./checkpoint/pretrained/{}_a2b.pth".format(opt.dataset) 21 | b2a_model_path = "./checkpoint/pretrained/{}_b2a.pth".format(opt.dataset) 22 | net_g_a2b = torch.load(a2b_model_path).to(device) 23 | net_g_b2a = torch.load(b2a_model_path).to(device) 24 | n_classes = 19 25 | parsing_net = BiSeNet(n_classes=n_classes).to(device) 26 | parsing_net.load_state_dict(torch.load('face_parsing_bisenet.pth')) 27 | 
parsing_net.eval() 28 | for param in parsing_net.parameters(): 29 | param.requires_grad = False 30 | 31 | a_dir = "data/{}/test/a/".format(opt.dataset) 32 | b_dir = "data/{}/test/b/".format(opt.dataset) 33 | 34 | a_image_filenames = [x for x in os.listdir(a_dir) if is_image_file(x)] 35 | b_image_filenames = [x for x in os.listdir(b_dir) if is_image_file(x)] 36 | 37 | transform_list = [transforms.ToTensor(), 38 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 39 | 40 | transform = transforms.Compose(transform_list) 41 | 42 | if not os.path.exists('./result'): 43 | os.mkdir('./result') 44 | if not os.path.exists(os.path.join('./result', opt.output_path)): 45 | os.mkdir(os.path.join('./result', opt.output_path)) 46 | if not os.path.exists(os.path.join('./result', opt.output_path, opt.dataset)): 47 | os.mkdir(os.path.join('./result', opt.output_path, opt.dataset)) 48 | os.mkdir(os.path.join('./result', opt.output_path, opt.dataset, 'a2b')) 49 | os.mkdir(os.path.join('./result', opt.output_path, opt.dataset, 'b2a')) 50 | 51 | 52 | for image_name in a_image_filenames: 53 | img_a,Ha,Wa = load_img(a_dir + image_name) 54 | img_a = transform(img_a) 55 | img_a = img_a.unsqueeze(0).to(device) 56 | a_pfea = parsing_net(img_a.detach())[0] 57 | m=random.choice([x for x in os.listdir(b_dir) if is_image_file(x)]) 58 | img_b,_,_ = load_img(b_dir + m) 59 | img_b = transform(img_b) 60 | img_b = img_b.unsqueeze(0).to(device) 61 | b_gen = net_g_a2b(img_a,img_b,a_pfea) 62 | b_gen = b_gen.detach().squeeze(0).cpu() 63 | save_img(b_gen,Ha,Wa, "./result/{}/{}/{}/{}".format(opt.output_path, opt.dataset, 'a2b', image_name)) 64 | del img_a,img_b,b_gen,a_pfea 65 | torch.cuda.empty_cache() 66 | 67 | for image_name in b_image_filenames: 68 | img_b,Hb,Wb = load_img(b_dir + image_name) 69 | img_b = transform(img_b) 70 | img_b = img_b.unsqueeze(0).to(device) 71 | b_pfea = parsing_net(img_b.detach())[0] 72 | n=random.choice([x for x in os.listdir(a_dir) if is_image_file(x)]) 73 | img_a,_,_ = load_img(a_dir + n) 74 | img_a = transform(img_a) 75 | img_a = img_a.unsqueeze(0).to(device) 76 | a_gen = net_g_b2a(img_b,img_a,b_pfea) 77 | a_gen = a_gen.detach().squeeze(0).cpu() 78 | save_img(a_gen,Hb,Wb, "./result/{}/{}/{}/{}".format(opt.output_path, opt.dataset , 'b2a', image_name)) 79 | del img_a,img_b,a_gen,b_pfea 80 | torch.cuda.empty_cache() -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.utils.model_zoo as modelzoo 8 | 9 | # from modules.bn import InPlaceABNSync as BatchNorm2d 10 | 11 | resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' 12 | 13 | 14 | def conv3x3(in_planes, out_planes, stride=1): 15 | """3x3 convolution with padding""" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 17 | padding=1, bias=False) 18 | 19 | 20 | class BasicBlock(nn.Module): 21 | def __init__(self, in_chan, out_chan, stride=1): 22 | super(BasicBlock, self).__init__() 23 | self.conv1 = conv3x3(in_chan, out_chan, stride) 24 | self.bn1 = nn.BatchNorm2d(out_chan) 25 | self.conv2 = conv3x3(out_chan, out_chan) 26 | self.bn2 = nn.BatchNorm2d(out_chan) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.downsample = None 29 | if in_chan != out_chan or stride != 1: 30 | self.downsample = nn.Sequential( 31 | nn.Conv2d(in_chan, 
out_chan, 32 | kernel_size=1, stride=stride, bias=False), 33 | nn.BatchNorm2d(out_chan), 34 | ) 35 | 36 | def forward(self, x): 37 | residual = self.conv1(x) 38 | residual = F.relu(self.bn1(residual)) 39 | residual = self.conv2(residual) 40 | residual = self.bn2(residual) 41 | 42 | shortcut = x 43 | if self.downsample is not None: 44 | shortcut = self.downsample(x) 45 | 46 | out = shortcut + residual 47 | out = self.relu(out) 48 | return out 49 | 50 | 51 | def create_layer_basic(in_chan, out_chan, bnum, stride=1): 52 | layers = [BasicBlock(in_chan, out_chan, stride=stride)] 53 | for i in range(bnum-1): 54 | layers.append(BasicBlock(out_chan, out_chan, stride=1)) 55 | return nn.Sequential(*layers) 56 | 57 | 58 | class Resnet18(nn.Module): 59 | def __init__(self): 60 | super(Resnet18, self).__init__() 61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 62 | bias=False) 63 | self.bn1 = nn.BatchNorm2d(64) 64 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 65 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) 66 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) 67 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) 68 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) 69 | self.init_weight() 70 | 71 | def forward(self, x): 72 | x = self.conv1(x) 73 | x = F.relu(self.bn1(x)) 74 | x = self.maxpool(x) 75 | 76 | x = self.layer1(x) 77 | feat8 = self.layer2(x) # 1/8 78 | feat16 = self.layer3(feat8) # 1/16 79 | feat32 = self.layer4(feat16) # 1/32 80 | return feat8, feat16, feat32 81 | 82 | def init_weight(self): 83 | state_dict = modelzoo.load_url(resnet18_url) 84 | self_state_dict = self.state_dict() 85 | for k, v in state_dict.items(): 86 | if 'fc' in k: continue 87 | self_state_dict.update({k: v}) 88 | self.load_state_dict(self_state_dict) 89 | 90 | def get_params(self): 91 | wd_params, nowd_params = [], [] 92 | for name, module in self.named_modules(): 93 | if isinstance(module, (nn.Linear, nn.Conv2d)): 94 | wd_params.append(module.weight) 95 | if not module.bias is None: 96 | nowd_params.append(module.bias) 97 | elif isinstance(module, nn.BatchNorm2d): 98 | nowd_params += list(module.parameters()) 99 | return wd_params, nowd_params 100 | 101 | 102 | if __name__ == "__main__": 103 | net = Resnet18() 104 | x = torch.randn(16, 3, 224, 224) 105 | out = net(x) 106 | print(out[0].size()) 107 | print(out[1].size()) 108 | print(out[2].size()) 109 | net.get_params() 110 | -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from functools import reduce 6 | import torch.nn.utils.spectral_norm as spectral_norm 7 | from adain import adaptive_instance_normalization 8 | 9 | # Define a basic residual block 10 | class ResBlock(nn.Module): 11 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): 12 | super(ResBlock, self).__init__() 13 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) 14 | 15 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): 16 | conv_block = [] 17 | p = 0 18 | if padding_type == 'reflect': 19 | conv_block += [nn.ReflectionPad2d(1)] 20 | elif padding_type == 'replicate': 21 | conv_block += [nn.ReplicationPad2d(1)] 22 | elif padding_type == 'zero': 23 | p = 1 24 | else: 25 | raise 
NotImplementedError('padding [%s] is not implemented' % padding_type) 26 | 27 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), 28 | norm_layer(dim), 29 | nn.ReLU(True)] 30 | if use_dropout: 31 | conv_block += [nn.Dropout(0.5)] 32 | 33 | p = 0 34 | if padding_type == 'reflect': 35 | conv_block += [nn.ReflectionPad2d(1)] 36 | elif padding_type == 'replicate': 37 | conv_block += [nn.ReplicationPad2d(1)] 38 | elif padding_type == 'zero': 39 | p = 1 40 | else: 41 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 42 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), 43 | norm_layer(dim)] 44 | 45 | return nn.Sequential(*conv_block) 46 | 47 | def forward(self, x): 48 | out = x + self.conv_block(x) 49 | return nn.ReLU(True)(out) 50 | # return out 51 | 52 | 53 | class SPADEResBlk(nn.Module): 54 | def __init__(self, fin, fout, seg_fin): 55 | super().__init__() 56 | # Attributes 57 | self.learned_shortcut = (fin != fout) 58 | fmiddle = min(fin, fout) 59 | 60 | # create conv layers 61 | self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1) 62 | self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1) 63 | if self.learned_shortcut: 64 | self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False) 65 | 66 | # apply spectral norm if specified 67 | self.conv_0 = spectral_norm(self.conv_0) 68 | self.conv_1 = spectral_norm(self.conv_1) 69 | if self.learned_shortcut: 70 | self.conv_s = spectral_norm(self.conv_s) 71 | 72 | # define normalization layers 73 | self.norm_0 = SPADE(fin, seg_fin) 74 | self.norm_1 = SPADE(fmiddle, seg_fin) 75 | if self.learned_shortcut: 76 | self.norm_s = SPADE(fin, seg_fin) 77 | 78 | # note the resnet block with SPADE also takes in |seg|, 79 | # the semantic segmentation map as input 80 | def forward(self, x, seg): 81 | x_s = self.shortcut(x, seg) 82 | 83 | dx = self.conv_0(self.actvn(self.norm_0(x, seg))) 84 | dx = self.conv_1(self.actvn(self.norm_1(dx, seg))) 85 | 86 | out = x_s + dx 87 | 88 | return out 89 | 90 | def shortcut(self, x, seg): 91 | if self.learned_shortcut: 92 | x_s = self.conv_s(self.norm_s(x, seg)) 93 | else: 94 | x_s = x 95 | return x_s 96 | 97 | def actvn(self, x): 98 | return F.leaky_relu(x, 2e-1) 99 | 100 | class SPADE(nn.Module): 101 | def __init__(self, cin, seg_dim): 102 | super().__init__() 103 | self.conv = nn.Sequential( 104 | nn.Conv2d(seg_dim, 128, kernel_size=3, stride=1, padding=1), 105 | nn.ReLU(), 106 | ) 107 | self.alpha = nn.Conv2d(128, cin, 108 | kernel_size=3, stride=1, padding=1) 109 | self.beta = nn.Conv2d(128, cin, 110 | kernel_size=3, stride=1, padding=1) 111 | 112 | @staticmethod 113 | def PN(x): 114 | ''' 115 | positional normalization: normalize each positional vector along the channel dimension 116 | ''' 117 | assert len(x.shape) == 4, 'Only works for 4D(image) tensor' 118 | x = x - x.mean(dim=1, keepdim=True) 119 | x_norm = x.norm(dim=1, keepdim=True) + 1e-6 120 | x = x / x_norm 121 | return x 122 | 123 | def DPN(self, x, s): 124 | h, w = x.shape[2], x.shape[3] 125 | s = F.interpolate(s, (h, w), mode='bilinear', align_corners = False) 126 | s = self.conv(s) 127 | a = self.alpha(s) 128 | b = self.beta(s) 129 | return x * (1 + a) + b 130 | 131 | def forward(self, x, s): 132 | x_out = self.DPN(self.PN(x), s) 133 | return x_out 134 | 135 | class ResAdaIN(nn.Module): 136 | def __init__(self, dim, norm_layer=nn.BatchNorm2d, use_bias=True): 137 | super(ResAdaIN, self).__init__() 138 | self.pad1 = nn.ReflectionPad2d(1) 139 | self.conv1 = 
nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=use_bias) 140 | self.norm1 = norm_layer(dim) 141 | self.relu = nn.ReLU(True) 142 | self.pad2 = nn.ReflectionPad2d(1) 143 | self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=use_bias) 144 | self.norm2 = norm_layer(dim) 145 | 146 | def forward(self, x, style): 147 | 148 | residual = x 149 | x = adaptive_instance_normalization(x,style) 150 | x = self.pad1(x) 151 | x = self.conv1(x) 152 | x = self.norm1(x) 153 | x = self.relu(x) 154 | x = self.pad2(x) 155 | x = self.conv2(x) 156 | x = self.norm2(x) 157 | out = x + residual 158 | 159 | return nn.ReLU(True)(out) 160 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import torch.nn.functional as F 8 | from torch.utils.data import DataLoader 9 | import torch.backends.cudnn as cudnn 10 | from data import get_training_set, get_test_set 11 | from util import tensor2img, get_facial_label, get_attention, save_feature_map 12 | from networks import define_G, define_D, GANLoss, VGGLoss, get_scheduler, update_learning_rate 13 | from model import BiSeNet 14 | from torch.autograd import Variable 15 | 16 | # Training settings 17 | parser = argparse.ArgumentParser(description='Residual Nets') 18 | parser.add_argument('--dataset', type=str,default='cuhk', help='[cuhk, ar, xmwvts, cuhk_feret, WildSketch]') 19 | parser.add_argument('--output_path', type=str,default='exp1', help='output path') 20 | parser.add_argument('--batchSize', type=int, default=1, help='training batch size') 21 | parser.add_argument('--testBatchSize', type=int, default=1, help='testing batch size') 22 | parser.add_argument('--input_nc', type=int, default=3, help='input image channels') 23 | parser.add_argument('--output_nc', type=int, default=3, help='output image channels') 24 | parser.add_argument('--ngf', type=int, default=64, help='generator filters in first conv layer') 25 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count') 26 | parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate') 27 | parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero') 28 | parser.add_argument('--ndf', type=int, default=64, help='discriminator filters in first conv layer') 29 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') 30 | parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau|cosine') 31 | parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') 32 | parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5') 33 | parser.add_argument('--vgg', type=str,default=True, help='use vgg loss?') 34 | parser.add_argument('--cuda', type=str,default=True, help='use cuda?') 35 | parser.add_argument('--visual', action='store_true', help='visualize the result?') 36 | parser.add_argument('--threads', type=int, default=4, help='number of threads for data loader to use') 37 | parser.add_argument('--seed', type=int, default=123, help='random seed to use. 
Default=123') 38 | 39 | 40 | 41 | opt = parser.parse_args() 42 | 43 | print(opt) 44 | 45 | if opt.cuda and not torch.cuda.is_available(): 46 | raise Exception("No GPU found, please run without --cuda") 47 | 48 | cudnn.benchmark = True 49 | 50 | torch.manual_seed(opt.seed) 51 | if opt.cuda: 52 | torch.cuda.manual_seed(opt.seed) 53 | 54 | print('===> Loading datasets') 55 | root_path = "data/" 56 | train_set = get_training_set(root_path + opt.dataset) 57 | # test_set = get_test_set(root_path + opt.dataset) 58 | training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True) 59 | # testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.testBatchSize, shuffle=False) 60 | 61 | device = torch.device("cuda:0" if opt.cuda else "cpu") 62 | 63 | n_classes = 19 64 | parsing_net = BiSeNet(n_classes=n_classes).to(device) 65 | parsing_net.load_state_dict(torch.load('face_parsing_bisenet.pth')) 66 | parsing_net.eval() 67 | for param in parsing_net.parameters(): 68 | param.requires_grad = False 69 | 70 | net_g_a2b = define_G(opt.input_nc, opt.output_nc, opt.ngf, 'instance', False, 'normal', 0.02, gpu_ids=device) 71 | net_g_b2a = define_G(opt.input_nc, opt.output_nc, opt.ngf, 'instance', False, 'normal', 0.02, gpu_ids=device) 72 | net_d_a2b = define_D(opt.input_nc, opt.ndf, 'basic', gpu_ids=device) 73 | net_d_b2a = define_D(opt.input_nc, opt.ndf, 'basic', gpu_ids=device) 74 | 75 | # setup optimizer 76 | criterionL1 = nn.L1Loss().to(device) 77 | criterionMSE = nn.MSELoss().to(device) 78 | criterionGAN = GANLoss().to(device) 79 | if opt.vgg: 80 | criterionVGG = VGGLoss().to(device) 81 | 82 | optimizer_net_g_a2b = optim.Adam(filter(lambda p: p.requires_grad,net_g_a2b.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) 83 | optimizer_net_g_b2a = optim.Adam(filter(lambda p: p.requires_grad,net_g_b2a.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) 84 | optimizer_net_d_a2b = optim.Adam(filter(lambda p: p.requires_grad,net_d_a2b.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) 85 | optimizer_net_d_b2a = optim.Adam(filter(lambda p: p.requires_grad,net_d_b2a.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) 86 | 87 | torch.autograd.set_detect_anomaly(True) 88 | for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1): 89 | # train 90 | for iteration, batch in enumerate(training_data_loader, 1): 91 | # a: photo, b: sketch 92 | 93 | # forward 94 | real_a, real_b = batch[0].to(device), batch[1].to(device) 95 | ref_a, ref_b = batch[2].to(device), batch[3].to(device) 96 | 97 | real_label_a = parsing_net(real_a.detach())[0] 98 | real_label_b = parsing_net(real_b.detach())[0] 99 | 100 | fake_b = net_g_a2b(real_a,ref_b,real_label_a) 101 | fake_a = net_g_b2a(real_b,ref_a,real_label_b) 102 | 103 | fake_label_a = parsing_net(fake_a.detach())[0] 104 | fake_label_b = parsing_net(fake_b.detach())[0] 105 | 106 | rec_b = net_g_a2b(fake_a,ref_b,fake_label_a) 107 | rec_a = net_g_b2a(fake_b,ref_a,fake_label_b) 108 | 109 | ## train net_d_a2b 110 | optimizer_net_d_a2b.zero_grad() 111 | # train with fake 112 | pred_fake_a2b = net_d_a2b.forward(fake_b.detach()) 113 | loss_d_a2b_fake = criterionGAN(pred_fake_a2b, False) 114 | # train with real 115 | pred_real_a2b = net_d_a2b.forward(real_b) 116 | loss_d_a2b_real = criterionGAN(pred_real_a2b, True) 117 | # Combined D loss 118 | loss_d_a2b = (loss_d_a2b_fake + loss_d_a2b_real) * 0.5 119 | loss_d_a2b.backward() 120 | optimizer_net_d_a2b.step() 121 | 122 | ## train net_d_b2a 123 | 
optimizer_net_d_b2a.zero_grad() 124 | # train with fake 125 | pred_fake_b2a = net_d_b2a.forward(fake_a.detach()) 126 | loss_d_b2a_fake = criterionGAN(pred_fake_b2a, False) 127 | # train with real 128 | pred_real_b2a = net_d_b2a.forward(real_a) 129 | loss_d_b2a_real = criterionGAN(pred_real_b2a, True) 130 | # Combined D loss 131 | loss_d_b2a = (loss_d_b2a_fake + loss_d_b2a_real) * 0.5 132 | loss_d_b2a.backward() 133 | optimizer_net_d_b2a.step() 134 | 135 | # train net_g 136 | optimizer_net_g_a2b.zero_grad() 137 | optimizer_net_g_b2a.zero_grad() 138 | 139 | # gan loss 140 | pred_a2b = net_d_a2b.forward(fake_b) 141 | loss_g_a2b = criterionGAN(pred_a2b, True) 142 | # vgg loss 143 | loss_g_vgg_a2b = criterionVGG(fake_b, real_b)*0.5 144 | 145 | # gan loss 146 | pred_b2a = net_d_b2a.forward(fake_a) 147 | loss_g_b2a = criterionGAN(pred_b2a, True) 148 | # vgg loss 149 | loss_g_vgg_b2a = criterionVGG(fake_a, real_a)*0.5 150 | 151 | loss_rec_a = criterionVGG(rec_a, real_a) * 0.5 152 | loss_rec_b = criterionVGG(rec_b, real_b) * 0.5 153 | 154 | loss_g = loss_g_a2b + loss_g_b2a + loss_g_vgg_a2b + loss_g_vgg_b2a + loss_rec_a + loss_rec_b 155 | loss_g.backward() 156 | 157 | optimizer_net_g_a2b.step() 158 | optimizer_net_g_b2a.step() 159 | 160 | print("===> Epoch[{}]({}/{}): real_score: {:.4f}, fake_score: {:.4f}, g_gan_loss: {:.4f}".format( 161 | epoch, iteration, len(training_data_loader), pred_real_a2b.data.mean().item(), pred_fake_a2b.data.mean().item(), loss_g.item())) 162 | 163 | # update_learning_rate(a2b_scheduler, optimizer_net_g_a2b) 164 | # update_learning_rate(b2a_scheduler, optimizer_net_g_b2a) 165 | 166 | 167 | 168 | # checkpoint 169 | if epoch % 2 == 0: 170 | if not os.path.exists("./checkpoint"): 171 | os.mkdir("./checkpoint") 172 | if not os.path.exists(os.path.join("./checkpoint", opt.output_path)): 173 | os.mkdir(os.path.join("./checkpoint", opt.output_path)) 174 | net_g_a2b_model_out_path = "./checkpoint/{}/netG_a2b_model_epoch_{}.pth".format(opt.output_path, epoch) 175 | net_g_b2a_model_out_path = "./checkpoint/{}/netG_b2a_model_epoch_{}.pth".format(opt.output_path, epoch) 176 | net_d_a2b_model_out_path = "./checkpoint/{}/netD_a2b_model_epoch_{}.pth".format(opt.output_path, epoch) 177 | net_d_b2a_model_out_path = "./checkpoint/{}/netD_b2a_model_epoch_{}.pth".format(opt.output_path, epoch) 178 | torch.save(net_g_a2b, net_g_a2b_model_out_path) 179 | torch.save(net_g_b2a, net_g_b2a_model_out_path) 180 | torch.save(net_d_a2b, net_d_a2b_model_out_path) 181 | torch.save(net_d_b2a, net_d_b2a_model_out_path) 182 | print("Checkpoint saved to {}".format("./checkpoint" + opt.output_path)) 183 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torchvision 9 | 10 | from resnet import Resnet18 11 | # from modules.bn import InPlaceABNSync as BatchNorm2d 12 | 13 | 14 | class ConvBNReLU(nn.Module): 15 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs): 16 | super(ConvBNReLU, self).__init__() 17 | self.conv = nn.Conv2d(in_chan, 18 | out_chan, 19 | kernel_size = ks, 20 | stride = stride, 21 | padding = padding, 22 | bias = False) 23 | self.bn = nn.BatchNorm2d(out_chan) 24 | self.init_weight() 25 | 26 | def forward(self, x): 27 | x = self.conv(x) 28 | x = F.relu(self.bn(x)) 29 | return x 30 | 31 | def 
init_weight(self): 32 | for ly in self.children(): 33 | if isinstance(ly, nn.Conv2d): 34 | nn.init.kaiming_normal_(ly.weight, a=1) 35 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 36 | 37 | class BiSeNetOutput(nn.Module): 38 | def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs): 39 | super(BiSeNetOutput, self).__init__() 40 | self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) 41 | self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False) 42 | self.init_weight() 43 | 44 | def forward(self, x): 45 | x = self.conv(x) 46 | x = self.conv_out(x) 47 | return x 48 | 49 | def init_weight(self): 50 | for ly in self.children(): 51 | if isinstance(ly, nn.Conv2d): 52 | nn.init.kaiming_normal_(ly.weight, a=1) 53 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 54 | 55 | def get_params(self): 56 | wd_params, nowd_params = [], [] 57 | for name, module in self.named_modules(): 58 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 59 | wd_params.append(module.weight) 60 | if not module.bias is None: 61 | nowd_params.append(module.bias) 62 | elif isinstance(module, nn.BatchNorm2d): 63 | nowd_params += list(module.parameters()) 64 | return wd_params, nowd_params 65 | 66 | 67 | class AttentionRefinementModule(nn.Module): 68 | def __init__(self, in_chan, out_chan, *args, **kwargs): 69 | super(AttentionRefinementModule, self).__init__() 70 | self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) 71 | self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False) 72 | self.bn_atten = nn.BatchNorm2d(out_chan) 73 | self.sigmoid_atten = nn.Sigmoid() 74 | self.init_weight() 75 | 76 | def forward(self, x): 77 | feat = self.conv(x) 78 | atten = F.avg_pool2d(feat, feat.size()[2:]) 79 | atten = self.conv_atten(atten) 80 | atten = self.bn_atten(atten) 81 | atten = self.sigmoid_atten(atten) 82 | out = torch.mul(feat, atten) 83 | return out 84 | 85 | def init_weight(self): 86 | for ly in self.children(): 87 | if isinstance(ly, nn.Conv2d): 88 | nn.init.kaiming_normal_(ly.weight, a=1) 89 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 90 | 91 | 92 | class ContextPath(nn.Module): 93 | def __init__(self, *args, **kwargs): 94 | super(ContextPath, self).__init__() 95 | self.resnet = Resnet18() 96 | self.arm16 = AttentionRefinementModule(256, 128) 97 | self.arm32 = AttentionRefinementModule(512, 128) 98 | self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 99 | self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 100 | self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) 101 | 102 | self.init_weight() 103 | 104 | def forward(self, x): 105 | H0, W0 = x.size()[2:] 106 | feat8, feat16, feat32 = self.resnet(x) 107 | H8, W8 = feat8.size()[2:] 108 | H16, W16 = feat16.size()[2:] 109 | H32, W32 = feat32.size()[2:] 110 | 111 | avg = F.avg_pool2d(feat32, feat32.size()[2:]) 112 | avg = self.conv_avg(avg) 113 | avg_up = F.interpolate(avg, (H32, W32), mode='nearest') 114 | 115 | feat32_arm = self.arm32(feat32) 116 | feat32_sum = feat32_arm + avg_up 117 | feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest') 118 | feat32_up = self.conv_head32(feat32_up) 119 | 120 | feat16_arm = self.arm16(feat16) 121 | feat16_sum = feat16_arm + feat32_up 122 | feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest') 123 | feat16_up = self.conv_head16(feat16_up) 124 | 125 | return feat8, feat16_up, feat32_up # x8, x8, x16 126 | 127 | def init_weight(self): 128 | for ly in self.children(): 
129 | if isinstance(ly, nn.Conv2d): 130 | nn.init.kaiming_normal_(ly.weight, a=1) 131 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 132 | 133 | def get_params(self): 134 | wd_params, nowd_params = [], [] 135 | for name, module in self.named_modules(): 136 | if isinstance(module, (nn.Linear, nn.Conv2d)): 137 | wd_params.append(module.weight) 138 | if not module.bias is None: 139 | nowd_params.append(module.bias) 140 | elif isinstance(module, nn.BatchNorm2d): 141 | nowd_params += list(module.parameters()) 142 | return wd_params, nowd_params 143 | 144 | 145 | ### This is not used, since I replace this with the resnet feature with the same size 146 | class SpatialPath(nn.Module): 147 | def __init__(self, *args, **kwargs): 148 | super(SpatialPath, self).__init__() 149 | self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3) 150 | self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) 151 | self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) 152 | self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0) 153 | self.init_weight() 154 | 155 | def forward(self, x): 156 | feat = self.conv1(x) 157 | feat = self.conv2(feat) 158 | feat = self.conv3(feat) 159 | feat = self.conv_out(feat) 160 | return feat 161 | 162 | def init_weight(self): 163 | for ly in self.children(): 164 | if isinstance(ly, nn.Conv2d): 165 | nn.init.kaiming_normal_(ly.weight, a=1) 166 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 167 | 168 | def get_params(self): 169 | wd_params, nowd_params = [], [] 170 | for name, module in self.named_modules(): 171 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 172 | wd_params.append(module.weight) 173 | if not module.bias is None: 174 | nowd_params.append(module.bias) 175 | elif isinstance(module, nn.BatchNorm2d): 176 | nowd_params += list(module.parameters()) 177 | return wd_params, nowd_params 178 | 179 | 180 | class FeatureFusionModule(nn.Module): 181 | def __init__(self, in_chan, out_chan, *args, **kwargs): 182 | super(FeatureFusionModule, self).__init__() 183 | self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) 184 | self.conv1 = nn.Conv2d(out_chan, 185 | out_chan//4, 186 | kernel_size = 1, 187 | stride = 1, 188 | padding = 0, 189 | bias = False) 190 | self.conv2 = nn.Conv2d(out_chan//4, 191 | out_chan, 192 | kernel_size = 1, 193 | stride = 1, 194 | padding = 0, 195 | bias = False) 196 | self.relu = nn.ReLU(inplace=True) 197 | self.sigmoid = nn.Sigmoid() 198 | self.init_weight() 199 | 200 | def forward(self, fsp, fcp): 201 | fcat = torch.cat([fsp, fcp], dim=1) 202 | feat = self.convblk(fcat) 203 | atten = F.avg_pool2d(feat, feat.size()[2:]) 204 | atten = self.conv1(atten) 205 | atten = self.relu(atten) 206 | atten = self.conv2(atten) 207 | atten = self.sigmoid(atten) 208 | feat_atten = torch.mul(feat, atten) 209 | feat_out = feat_atten + feat 210 | return feat_out 211 | 212 | def init_weight(self): 213 | for ly in self.children(): 214 | if isinstance(ly, nn.Conv2d): 215 | nn.init.kaiming_normal_(ly.weight, a=1) 216 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 217 | 218 | def get_params(self): 219 | wd_params, nowd_params = [], [] 220 | for name, module in self.named_modules(): 221 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 222 | wd_params.append(module.weight) 223 | if not module.bias is None: 224 | nowd_params.append(module.bias) 225 | elif isinstance(module, nn.BatchNorm2d): 226 | nowd_params += list(module.parameters()) 227 | return wd_params, nowd_params 228 | 229 | 
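# Note: in this repo the BiSeNet below is used only as a frozen face-parsing network.
# train.py, test.py, and inference.py load 'face_parsing_bisenet.pth', freeze all of its
# parameters, and take parsing_net(img)[0] -- the 19-channel parsing logits upsampled
# to the input resolution -- as the semantic prior for the generator; TransferNet2 in
# networks.py applies softmax over the channel dimension before conditioning its
# SPADE blocks on this map.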
230 | class BiSeNet(nn.Module): 231 | def __init__(self, n_classes, *args, **kwargs): 232 | super(BiSeNet, self).__init__() 233 | self.cp = ContextPath() 234 | ## here self.sp is deleted 235 | self.ffm = FeatureFusionModule(256, 256) 236 | self.conv_out = BiSeNetOutput(256, 256, n_classes) 237 | self.conv_out16 = BiSeNetOutput(128, 64, n_classes) 238 | self.conv_out32 = BiSeNetOutput(128, 64, n_classes) 239 | self.init_weight() 240 | 241 | def forward(self, x): 242 | H, W = x.size()[2:] 243 | feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature 244 | feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature 245 | feat_fuse = self.ffm(feat_sp, feat_cp8) 246 | 247 | feat_out = self.conv_out(feat_fuse) 248 | feat_out16 = self.conv_out16(feat_cp8) 249 | feat_out32 = self.conv_out32(feat_cp16) 250 | 251 | feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True) 252 | feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True) 253 | feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True) 254 | return feat_out, feat_out16, feat_out32 255 | 256 | def init_weight(self): 257 | for ly in self.children(): 258 | if isinstance(ly, nn.Conv2d): 259 | nn.init.kaiming_normal_(ly.weight, a=1) 260 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 261 | 262 | def get_params(self): 263 | wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], [] 264 | for name, child in self.named_children(): 265 | child_wd_params, child_nowd_params = child.get_params() 266 | if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput): 267 | lr_mul_wd_params += child_wd_params 268 | lr_mul_nowd_params += child_nowd_params 269 | else: 270 | wd_params += child_wd_params 271 | nowd_params += child_nowd_params 272 | return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params 273 | 274 | 275 | if __name__ == "__main__": 276 | net = BiSeNet(19) 277 | net.cuda() 278 | net.eval() 279 | in_ten = torch.randn(16, 3, 640, 480).cuda() 280 | out, out16, out32 = net(in_ten) 281 | print(out.shape) 282 | 283 | net.get_params() 284 | -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | #from typing_extensions import Concatenate 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn import init 6 | import functools 7 | from torch.optim import lr_scheduler 8 | import math 9 | from modules import ResBlock,SPADEResBlk,ResAdaIN 10 | from vgg import VGG19 11 | from fusion import AFF, DBFM 12 | from torch.autograd import Variable 13 | 14 | 15 | def get_norm_layer(norm_type='instance'): 16 | if norm_type == 'batch': 17 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True) 18 | elif norm_type == 'instance': 19 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) 20 | elif norm_type == 'none': 21 | norm_layer = None 22 | else: 23 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 24 | return norm_layer 25 | 26 | 27 | def get_scheduler(optimizer, opt): 28 | if opt.lr_policy == 'lambda': 29 | def lambda_rule(epoch): 30 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) 31 | return lr_l 32 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 33 | elif opt.lr_policy == 'step': 34 | scheduler = 
lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 35 | elif opt.lr_policy == 'plateau': 36 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 37 | elif opt.lr_policy == 'cosine': 38 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0) 39 | else: 40 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 41 | return scheduler 42 | 43 | 44 | # update learning rate (called once every epoch) 45 | def update_learning_rate(scheduler, optimizer): 46 | scheduler.step() 47 | lr = optimizer.param_groups[0]['lr'] 48 | print('learning rate = %.7f' % lr) 49 | 50 | 51 | def init_weights(net, init_type='normal', gain=0.02): 52 | def init_func(m): 53 | classname = m.__class__.__name__ 54 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 55 | if init_type == 'normal': 56 | init.normal_(m.weight.data, 0.0, gain) 57 | elif init_type == 'xavier': 58 | init.xavier_normal_(m.weight.data, gain=gain) 59 | elif init_type == 'kaiming': 60 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 61 | elif init_type == 'orthogonal': 62 | init.orthogonal_(m.weight.data, gain=gain) 63 | else: 64 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 65 | if hasattr(m, 'bias') and m.bias is not None: 66 | init.constant_(m.bias.data, 0.0) 67 | elif classname.find('BatchNorm2d') != -1: 68 | init.normal_(m.weight.data, 1.0, gain) 69 | init.constant_(m.bias.data, 0.0) 70 | 71 | print('initialize network with %s' % init_type) 72 | net.apply(init_func) 73 | 74 | 75 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids='cuda:0'): 76 | net.to(gpu_ids) 77 | init_weights(net, init_type, gain=init_gain) 78 | return net 79 | 80 | 81 | def define_G(input_nc, output_nc, ngf, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids='cuda:0'): 82 | net = None 83 | norm_layer = get_norm_layer(norm_type=norm) 84 | net = TransferNet2(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9) 85 | 86 | return init_net(net, init_type, init_gain, gpu_ids) 87 | 88 | 89 | class TransferNet2(nn.Module): 90 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=9, padding_type='reflect'): 91 | assert(n_blocks >= 0) 92 | super(TransferNet2, self).__init__() 93 | self.input_nc = input_nc 94 | self.output_nc = output_nc 95 | self.ngf = ngf 96 | self.n_blocks = n_blocks 97 | if type(norm_layer) == functools.partial: 98 | use_bias = norm_layer.func == nn.InstanceNorm2d 99 | else: 100 | use_bias = norm_layer == nn.InstanceNorm2d 101 | 102 | self.p_in = Inconv(input_nc, ngf, norm_layer, use_bias) 103 | self.p_down1 = Down(ngf, ngf * 2, norm_layer, use_bias) 104 | self.p_down2 = Down(ngf * 2, ngf * 4, norm_layer, use_bias) 105 | self.p_down3 = Down(ngf * 4, ngf * 8, norm_layer, use_bias) 106 | self.spade1 = SPADEResBlk(128,128,19) 107 | self.spade2 = SPADEResBlk(256,256,19) 108 | self.spade3 = SPADEResBlk(512,512,19) 109 | 110 | self.resblock1=nn.Sequential( 111 | ResAdaIN(ngf * 2), 112 | ResAdaIN(ngf * 2), 113 | ResAdaIN(ngf * 2), 114 | ResAdaIN(ngf * 2), 115 | ResAdaIN(ngf * 2), 116 | ResAdaIN(ngf * 2), 117 | ResAdaIN(ngf * 2), 118 | ResAdaIN(ngf * 2), 119 | ResAdaIN(ngf * 2) 120 | ) 121 | 122 | self.resblock2=nn.Sequential( 123 | ResAdaIN(ngf * 4), 124 | ResAdaIN(ngf * 4), 125 | ResAdaIN(ngf * 4), 126 | ResAdaIN(ngf * 4), 
127 |             ResAdaIN(ngf * 4),
128 |             ResAdaIN(ngf * 4),
129 |             ResAdaIN(ngf * 4),
130 |             ResAdaIN(ngf * 4),
131 |             ResAdaIN(ngf * 4)
132 |         )
133 | 
134 |         self.resblock3=nn.Sequential(
135 |             ResAdaIN(ngf * 8),
136 |             ResAdaIN(ngf * 8),
137 |             ResAdaIN(ngf * 8),
138 |             ResAdaIN(ngf * 8),
139 |             ResAdaIN(ngf * 8),
140 |             ResAdaIN(ngf * 8),
141 |             ResAdaIN(ngf * 8),
142 |             ResAdaIN(ngf * 8),
143 |             ResAdaIN(ngf * 8)
144 |         )
145 | 
146 |         self.up1_p = Up(ngf * 8, ngf * 4, norm_layer, use_bias)
147 |         self.up2_p = Up(ngf * 4, ngf * 2, norm_layer, use_bias)
148 |         self.up3_p = Up(ngf * 2, ngf, norm_layer, use_bias)
149 |         self.out_conv_p = Outconv(ngf, output_nc)
150 | 
151 |         self.cat1 = AFF(channels=ngf*4)
152 |         self.cat2 = AFF(channels=ngf*2)
153 |         self.fusion1 = DBFM(128)
154 |         self.fusion2 = DBFM(256)
155 |         self.fusion3 = DBFM(512)
156 | 
157 |     def forward(self, photo, sketch, mask_p):
158 |         mask_p = torch.softmax(mask_p,1)
159 | 
160 |         fp_64 = self.p_in(photo)
161 |         fs_64 = self.p_in(sketch)
162 |         fp_128 = self.p_down1(fp_64)
163 |         fs_128 = self.p_down1(fs_64)
164 |         fp_256 = self.p_down2(fp_128)
165 |         fs_256 = self.p_down2(fs_128)
166 |         fp_512 = self.p_down3(fp_256)
167 |         fs_512 = self.p_down3(fs_256)
168 | 
169 |         for i in range(9):
170 |             fp_128 = self.resblock1[i](fp_128,fs_128)
171 |         fp_128_spade = self.spade1(fs_128,mask_p)
172 |         fp_128_ = self.fusion1(fp_128,fp_128_spade)
173 | 
174 |         for i in range(9):
175 |             fp_256 = self.resblock2[i](fp_256,fs_256)
176 |         fp_256_spade = self.spade2(fs_256,mask_p)
177 |         fp_256_ = self.fusion2(fp_256,fp_256_spade)
178 | 
179 |         for i in range(9):
180 |             fp_512 = self.resblock3[i](fp_512,fs_512)
181 |         fp_512_spade = self.spade3(fs_512,mask_p)
182 |         fp_512_ = self.fusion3(fp_512,fp_512_spade)
183 | 
184 |         fp_256__ = self.up1_p(fp_512_)
185 |         fp_256_ = self.cat1(fp_256_,fp_256__)
186 | 
187 |         fp_128__ = self.up2_p(fp_256_)
188 |         fp_128_ = self.cat2(fp_128_,fp_128__)
189 | 
190 |         fp_64 = self.up3_p(fp_128_)
191 |         fake_sketch = self.out_conv_p(fp_64)
192 | 
193 |         return fake_sketch
194 | 
195 | 
196 | class Inconv(nn.Module):
197 |     def __init__(self, in_ch, out_ch, norm_layer, use_bias):
198 |         super(Inconv, self).__init__()
199 |         self.inconv = nn.Sequential(
200 |             nn.ReflectionPad2d(3),
201 |             nn.Conv2d(in_ch, out_ch, kernel_size=7, padding=0,
202 |                       bias=use_bias),
203 |             norm_layer(out_ch),
204 |             nn.ReLU(True)
205 |         )
206 | 
207 |     def forward(self, x):
208 |         x = self.inconv(x)
209 |         return x
210 | 
211 | 
212 | class Down(nn.Module):
213 |     def __init__(self, in_ch, out_ch, norm_layer, use_bias):
214 |         super(Down, self).__init__()
215 |         self.down = nn.Sequential(
216 |             nn.Conv2d(in_ch, out_ch, kernel_size=3,
217 |                       stride=2, padding=1, bias=use_bias),
218 |             norm_layer(out_ch),
219 |             nn.ReLU(True)
220 |         )
221 | 
222 |     def forward(self, x):
223 |         x = self.down(x)
224 |         return x
225 | 
226 | 
227 | class Up(nn.Module):
228 |     def __init__(self, in_ch, out_ch, norm_layer, use_bias):
229 |         super(Up, self).__init__()
230 |         self.up = nn.Sequential(
231 |             #nn.Upsample(scale_factor=2, mode='nearest'),
232 |             #nn.Conv2d(in_ch, out_ch,
233 |             #          kernel_size=3, stride=1,
234 |             #          padding=1, bias=use_bias),
235 |             nn.ConvTranspose2d(in_ch, out_ch,
236 |                                kernel_size=3, stride=2,
237 |                                padding=1, output_padding=1,
238 |                                bias=use_bias),
239 |             norm_layer(out_ch),
240 |             nn.ReLU(True)
241 |         )
242 | 
243 |     def forward(self, x):
244 |         x = self.up(x)
245 |         return x
246 | 
247 | 
248 | class Outconv(nn.Module):
249 |     def __init__(self, in_ch, out_ch):
250 |         super(Outconv, self).__init__()
251 |         self.outconv = nn.Sequential(
252 |             nn.ReflectionPad2d(3),
253 |             nn.Conv2d(in_ch, out_ch, kernel_size=7, padding=0),
254 |             nn.Tanh()
255 |         )
256 | 
257 |     def forward(self, x):
258 |         x = self.outconv(x)
259 |         return x
260 | 
261 | 
262 | def define_D(input_nc, ndf, netD,
263 |              n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', init_gain=0.02, gpu_ids='cuda:0'):
264 |     net = None
265 |     norm_layer = get_norm_layer(norm_type=norm)
266 | 
267 |     if netD == 'basic':
268 |         net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
269 |     elif netD == 'n_layers':
270 |         net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
271 |     else:
272 |         raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
273 | 
274 |     return init_net(net, init_type, init_gain, gpu_ids)
275 | 
276 | 
277 | 
278 | # Defines the PatchGAN discriminator with the specified arguments.
279 | class NLayerDiscriminator(nn.Module):
280 |     def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
281 |         super(NLayerDiscriminator, self).__init__()
282 |         if type(norm_layer) == functools.partial:
283 |             use_bias = norm_layer.func == nn.InstanceNorm2d
284 |         else:
285 |             use_bias = norm_layer == nn.InstanceNorm2d
286 | 
287 |         kw = 4
288 |         padw = 1
289 |         sequence = [
290 |             nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
291 |             nn.LeakyReLU(0.2, True)
292 |         ]
293 | 
294 |         nf_mult = 1
295 |         nf_mult_prev = 1
296 |         for n in range(1, n_layers):
297 |             nf_mult_prev = nf_mult
298 |             nf_mult = min(2**n, 8)
299 |             sequence += [
300 |                 nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
301 |                           kernel_size=kw, stride=2, padding=padw, bias=use_bias),
302 |                 norm_layer(ndf * nf_mult),
303 |                 nn.LeakyReLU(0.2, True)
304 |             ]
305 | 
306 |         nf_mult_prev = nf_mult
307 |         nf_mult = min(2**n_layers, 8)
308 |         sequence += [
309 |             nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
310 |                       kernel_size=kw, stride=1, padding=padw, bias=use_bias),
311 |             norm_layer(ndf * nf_mult),
312 |             nn.LeakyReLU(0.2, True)
313 |         ]
314 | 
315 |         sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
316 | 
317 |         if use_sigmoid:
318 |             sequence += [nn.Sigmoid()]
319 | 
320 |         self.model = nn.Sequential(*sequence)
321 | 
322 |     def forward(self, input):
323 |         return self.model(input)
324 | 
325 | class GANLoss(nn.Module):
326 |     def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
327 |         super(GANLoss, self).__init__()
328 |         self.register_buffer('real_label', torch.tensor(target_real_label))
329 |         self.register_buffer('fake_label', torch.tensor(target_fake_label))
330 |         if use_lsgan:
331 |             self.loss = nn.MSELoss()
332 |         else:
333 |             self.loss = nn.BCELoss()
334 | 
335 |     def get_target_tensor(self, input, target_is_real):
336 |         if target_is_real:
337 |             target_tensor = self.real_label
338 |         else:
339 |             target_tensor = self.fake_label
340 |         return target_tensor.expand_as(input)
341 | 
342 |     def __call__(self, input, target_is_real):
343 |         target_tensor = self.get_target_tensor(input, target_is_real)
344 |         return self.loss(input, target_tensor)
345 | 
346 | 
347 | class VGGLoss(nn.Module):
348 |     def __init__(self):
349 |         super(VGGLoss, self).__init__()
350 |         self.vgg = VGG19()
351 |         self.criterion = nn.L1Loss()
352 |         self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
353 | 
354 |     def forward(self, x, y):
355 |         x_vgg, y_vgg = self.vgg(x), self.vgg(y)
356 |         loss = 0
357 |         for i in range(len(x_vgg)):
358 |             loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
359 |         return loss
--------------------------------------------------------------------------------
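The snippet below is a minimal, hypothetical usage sketch of the building blocks defined in networks.py (define_G / TransferNet2, define_D / NLayerDiscriminator, GANLoss, VGGLoss). It is not the repository's training loop (see train.py for that); the 256x256 input size, the 19-channel mask_p (chosen to match SPADEResBlk(..., 19)), and the optimizer settings are assumptions made only for illustration.

# Hypothetical usage sketch -- not part of the repository and not train.py.
# Assumes 256x256 RGB inputs and a 19-channel parsing map for mask_p.
import torch
from networks import define_G, define_D, GANLoss, VGGLoss

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Generator (TransferNet2) and PatchGAN discriminator, both moved to the device
# and weight-initialized by init_net.
netG = define_G(input_nc=3, output_nc=3, ngf=64, norm='batch', gpu_ids=device)
netD = define_D(input_nc=3, ndf=64, netD='basic', norm='batch', gpu_ids=device)

criterionGAN = GANLoss(use_lsgan=True).to(device)   # LSGAN -> MSE on patch outputs
criterionVGG = VGGLoss().to(device)                 # weighted L1 over VGG19 features

photo  = torch.randn(1, 3, 256, 256, device=device)   # content photo
sketch = torch.randn(1, 3, 256, 256, device=device)   # reference / target-domain sketch
mask_p = torch.randn(1, 19, 256, 256, device=device)  # parsing map; softmaxed inside forward

fake_sketch = netG(photo, sketch, mask_p)

# Generator objective: adversarial term plus perceptual (VGG) term.
loss_G = criterionGAN(netD(fake_sketch), True) + criterionVGG(fake_sketch, sketch)

optimizer_G = torch.optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_G.zero_grad()
loss_G.backward()
optimizer_G.step()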