├── .DS_Store
├── .gitattributes
├── .gitignore
├── README.md
├── data
│   ├── dataloader.py
│   └── train_fileList.txt
├── demo.py
├── model
│   ├── losses.py
│   ├── networks.py
│   └── networks_without_coarse.py
├── test.py
├── train.py
└── utils
    └── model_storage.py

/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cydiachen/FSRNET_pytorch/e792839db4b6c49204842ae42975da74fac96285/.DS_Store
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 | 
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 | 
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | 
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 | 
50 | # Translations
51 | *.mo
52 | *.pot
53 | 
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 | 
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 | 
63 | # Scrapy stuff:
64 | .scrapy
65 | 
66 | # Sphinx documentation
67 | docs/_build/
68 | 
69 | # PyBuilder
70 | target/
71 | 
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 | 
75 | # pyenv
76 | .python-version
77 | 
78 | # celery beat schedule file
79 | celerybeat-schedule
80 | 
81 | # SageMath parsed files
82 | *.sage.py
83 | 
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 | 
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 | 
97 | # Rope project settings
98 | .ropeproject
99 | 
100 | # mkdocs documentation
101 | /site
102 | 
103 | # mypy
104 | .mypy_cache/
105 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FSRNet Pytorch
2 | Dear friends,
3 | Thank you for following this implementation of FSRNet (CVPR 2018 oral paper).
4 | 
5 | I spent the whole summer as an intern at Iluvatar.ai. Now that I am back at school, I have time to complete the project.
6 | 
7 | I have completely rewritten train.py and the rest of the model code. Pretrained weights are being uploaded to BaiduNetDisk together with the training set.
8 | 
9 | Based on [WaveletSRNet](https://github.com/hhb072/WaveletSRNet/), I adapted the code to the FSRNet network structure.
10 | 
11 | ## Prerequisites
12 | 
13 | * Python 3.6
14 | * PyTorch 1.0 or newer (PyTorch > 0.4 should work)
15 | * matplotlib
16 | * skimage
17 | 
18 | ## Train
19 | 
20 | Change the options in train.py to set the dataset directory.
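For example, training can be launched with the options defined in train.py (the values below are its defaults; point --image_dir and --image_list at your own copy of the data, and add --pretrained <checkpoint> to resume from a saved model):

    python train.py --image_dir ./data/CelebA-HQ-img/ --image_list ./data/train_fileList.txt \
        --batch_size 14 --lr 2.5e-4 --epochs 100 --model_path ./weights/ --result_dir ./result/

Note that --image_list is currently declared with type=int in train.py, so its type may need to be changed to str before the path will parse.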
I am using CelebAHQ-MASK as the training set. The GroundTruth is generated by zllrunning/face-parsing.PyTorch(https://github.com/zllrunning/face-parsing.PyTorch) with pretrained model. 21 | 22 | Dataset Link: https://pan.baidu.com/s/1HEECUyKI5GOSrd7NPlm-ow 密码:z2ud 23 | 24 | ## Test 25 | 26 | ON GOING :| 27 | PYTHON AND NOTEBOOK WILL BE PROVIDED. 28 | Pretrained Weights:链接:https://pan.baidu.com/s/1ZkgABGefsMjO6XhhvlBzRA 密码:libl 29 | 30 | ## Result 31 | 32 | ## Citation 33 | If you find FSRNet useful in your research, please consider citing (* indicates equal contributions): 34 | 35 | @inproceedings{CT-FSRNet-2018, 36 | title={FSRNet: End-to-End Learning Face Super-Resolution with Facial Priors}, 37 | author={Chen, Yu* and Tai, Ying* and Liu, Xiaoming and Shen, Chunhua and Yang, Jian }, 38 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 39 | year={2018} 40 | } 41 | 42 | 43 | -------------------------------------------------------------------------------- /data/dataloader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.utils.data as data 4 | from os import listdir 5 | import os 6 | from os.path import join 7 | from PIL import Image, ImageOps 8 | import random 9 | import torchvision.transforms as transforms 10 | import cv2 11 | import numpy as np 12 | from torch.autograd import Variable 13 | import matplotlib.pyplot as plt 14 | 15 | 16 | def loadFromFile(path, datasize): 17 | if path is None: 18 | return None, None 19 | 20 | # print("Load from file %s" % path) 21 | f = open(path) 22 | data = [] 23 | for idx in range(0, datasize): 24 | line = f.readline() 25 | line = line[:-1] 26 | data.append(line) 27 | f.close() 28 | return data 29 | 30 | 31 | def load_lr_hr_prior(file_path, input_height=128, input_width=128, output_height=128, output_width=128, is_mirror=False, 32 | is_gray=True, scale=8.0, is_scale_back=True, is_parsing_map=True): 33 | if input_width is None: 34 | input_width = input_height 35 | if output_width is None: 36 | output_width = output_height 37 | 38 | # print(file_path) 39 | 40 | img = cv2.imread(file_path) 41 | # img = Image.open(file_path) 42 | 43 | if is_gray is False: 44 | b, g, r = cv2.split(img) 45 | img = cv2.merge([r, g, b]) 46 | if is_gray is True: 47 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 48 | 49 | if is_mirror and random.randint(0, 1) is 0: 50 | img = ImageOps.mirror(img) 51 | 52 | if input_height is not None: 53 | img = cv2.resize(img, (input_width, input_height), interpolation=cv2.INTER_CUBIC) 54 | 55 | if is_parsing_map: 56 | str = ['skin.png','lbrow.png','rbrow.png','leye.png','reye.png','lear.png','rear.png','nose.png','mouth','ulip.png','llip.png'] 57 | 58 | hms = np.zeros((64, 64, len(str))) 59 | 60 | for i in range(len(str)): 61 | (onlyfilePath, img_name) = os.path.split(file_path) 62 | full_name = onlyfilePath + "/Parsing_Maps/" + img_name[:-4] + "_"+ str[i] 63 | hm = cv2.imread(full_name, cv2.IMREAD_GRAYSCALE) 64 | hm_resized = cv2.resize(hm, (64, 64), interpolation=cv2.INTER_CUBIC) / 255.0 65 | hms[:, :, i] = hm_resized 66 | 67 | img = cv2.resize(img, (output_width, output_height), interpolation=cv2.INTER_CUBIC) 68 | img_lr = cv2.resize(img, (int(output_width / scale), int(output_height / scale)), interpolation=cv2.INTER_CUBIC) 69 | 70 | if is_scale_back: 71 | img_lr = cv2.resize(img_lr, (output_width, output_height), interpolation=cv2.INTER_CUBIC) 72 | return img_lr, img, hms 73 | else: 74 | return img_lr, 
img, hms 75 | 76 | 77 | class ImageDatasetFromFile(data.Dataset): 78 | def __init__(self, image_list, img_path, input_height=128, input_width=128, output_height=128, output_width=128, 79 | is_mirror=False, is_gray=False, upscale=8.0, is_scale_back=True, is_parsing_map=True): 80 | super(ImageDatasetFromFile, self).__init__() 81 | 82 | self.image_filenames = image_list 83 | self.upscale = upscale 84 | self.is_mirror = is_mirror 85 | self.img_path = img_path 86 | self.input_height = input_height 87 | self.input_width = input_width 88 | self.output_height = output_height 89 | self.output_width = output_width 90 | self.is_scale_back = is_scale_back 91 | self.is_gray = is_gray 92 | self.is_parsing_map = is_parsing_map 93 | 94 | self.input_transform = transforms.Compose([ 95 | transforms.ToTensor()]) 96 | 97 | def __getitem__(self, idx): 98 | 99 | if self.is_mirror: 100 | is_mirror = random.randint(0, 1) is 0 101 | else: 102 | is_mirror = False 103 | 104 | image_filenames = loadFromFile(self.image_filenames, len(open(self.image_filenames, 'r').readlines())) 105 | fullpath = join(self.img_path, image_filenames[idx]) 106 | 107 | lr, hr, pm = load_lr_hr_prior(fullpath, 108 | self.input_height, self.input_width, self.output_height, self.output_width, 109 | self.is_mirror, self.is_gray, self.upscale, self.is_scale_back, 110 | self.is_parsing_map) 111 | 112 | input = self.input_transform(lr) 113 | target = self.input_transform(hr) 114 | parsing_map = self.input_transform(pm) 115 | 116 | return input, target, parsing_map 117 | 118 | def __len__(self): 119 | return len(open(self.image_filenames, 'rU').readlines()) 120 | 121 | 122 | # demo_dataset = ImageDatasetFromFile("/home/cydia/文档/毕业设计/make_Face_boundary/81_landmarks/fileList.txt", 123 | # "/home/cydia/图片/sample/") 124 | # 125 | # train_data_loader = data.DataLoader(dataset=demo_dataset, batch_size=1, num_workers=8) 126 | 127 | if __name__ == '__main__': 128 | for titer, batch in enumerate(train_data_loader): 129 | input, target, heatmaps = Variable(batch[0]), Variable(batch[1]), Variable(batch[2]) 130 | 131 | Input = input.permute(0, 2, 3, 1).cpu().data.numpy() 132 | Target = target.permute(0, 2, 3, 1).cpu().data.numpy() 133 | Parsing_maps = heatmaps.permute(0, 2, 3, 1).cpu().data.numpy() 134 | 135 | plt.figure("Input Image") 136 | plt.imshow(Input[0, :, :, :]) 137 | plt.axis('on') 138 | plt.title('image') 139 | plt.show() 140 | 141 | plt.figure("Target Image") 142 | plt.imshow(Target[0, :, :, :]) 143 | plt.axis('on') 144 | plt.title('Target') 145 | plt.show() 146 | 147 | plt.figure("HMS") 148 | plt.imshow(Parsing_maps[0, :, :, 0]) 149 | plt.axis('on') 150 | plt.title('OMS') 151 | plt.show() 152 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cydiachen/FSRNET_pytorch/e792839db4b6c49204842ae42975da74fac96285/demo.py -------------------------------------------------------------------------------- /model/losses.py: -------------------------------------------------------------------------------- 1 | import torch, sys 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import functools 5 | import torch.autograd as autograd 6 | import numpy as np 7 | import torchvision.models as models 8 | import util.util as util 9 | from util.image_pool import ImagePool 10 | from torch.autograd import Variable 11 | ############################################################################### 12 | # 
Functions 13 | ############################################################################### 14 | 15 | class ContentLoss(): 16 | def initialize(self, loss): 17 | self.criterion = loss 18 | 19 | def get_loss(self, fakeIm, realIm): 20 | return self.criterion(fakeIm, realIm) 21 | 22 | class PerceptualLoss(): 23 | def contentFunc(self): 24 | conv_5_3_layer = 14 #layer index of relu5_3 in vgg-16 25 | cnn = models.vgg19(pretrained=True).features 26 | cnn = cnn.cuda() 27 | model = nn.Sequential() 28 | model = model.cuda() 29 | for i,layer in enumerate(list(cnn)): 30 | model.add_module(str(i),layer) 31 | if i == conv_5_3_layer: 32 | break 33 | print(model) 34 | #sys.exit(0) 35 | return model 36 | 37 | def initialize(self, loss): 38 | self.criterion = loss 39 | self.contentFunc = self.contentFunc() 40 | 41 | def get_loss(self, fakeIm, realIm): 42 | f_fake = self.contentFunc.forward(fakeIm) 43 | f_real = self.contentFunc.forward(realIm) 44 | f_real_no_grad = f_real.detach() 45 | loss = self.criterion(f_fake, f_real_no_grad) 46 | return loss 47 | 48 | class PriorLoss(nn.Module): 49 | def initialize(self, loss): 50 | self.criterion = loss 51 | 52 | def get_loss(self, net, fakeIm, realIm): 53 | realIm_no_grad = realIm.detach() 54 | loss = self.criterion(fakeIm, realIm_no_grad) 55 | return loss 56 | 57 | class GANLoss(nn.Module): 58 | def __init__(self, use_l1=True, target_real_label=1.0, target_fake_label=0.0, 59 | tensor=torch.FloatTensor): 60 | super(GANLoss, self).__init__() 61 | self.real_label = target_real_label 62 | self.fake_label = target_fake_label 63 | self.real_label_var = None 64 | self.fake_label_var = None 65 | self.Tensor = tensor 66 | if use_l1: 67 | self.loss = nn.L1Loss() 68 | else: 69 | self.loss = nn.BCELoss() 70 | 71 | def get_target_tensor(self, input, target_is_real): 72 | target_tensor = None 73 | if target_is_real: 74 | create_label = ((self.real_label_var is None) or 75 | (self.real_label_var.numel() != input.numel())) 76 | if create_label: 77 | real_tensor = self.Tensor(input.size()).fill_(self.real_label) 78 | self.real_label_var = Variable(real_tensor, requires_grad=False) 79 | target_tensor = self.real_label_var 80 | else: 81 | create_label = ((self.fake_label_var is None) or 82 | (self.fake_label_var.numel() != input.numel())) 83 | if create_label: 84 | fake_tensor = self.Tensor(input.size()).fill_(self.fake_label) 85 | self.fake_label_var = Variable(fake_tensor, requires_grad=False) 86 | target_tensor = self.fake_label_var 87 | return target_tensor 88 | 89 | def __call__(self, input, target_is_real): 90 | target_tensor = self.get_target_tensor(input, target_is_real) 91 | return self.loss(input, target_tensor) 92 | 93 | class DiscLoss(): 94 | def name(self): 95 | return 'DiscLoss' 96 | 97 | def initialize(self, opt, tensor): 98 | self.criterionGAN = GANLoss(use_l1=False, tensor=tensor) 99 | self.fake_AB_pool = ImagePool(opt.pool_size) 100 | 101 | def get_g_loss(self, net, realA, fakeB): 102 | # First, G(A) should fake the discriminator 103 | pred_fake = net.forward(fakeB) 104 | return self.criterionGAN(pred_fake, 1) 105 | 106 | def get_loss(self, net, realA, fakeB, realB): 107 | # Fake 108 | # stop backprop to the generator by detaching fake_B 109 | # Generated Image Disc Output should be close to zero 110 | self.pred_fake = net.forward(fakeB.detach()) 111 | self.loss_D_fake = self.criterionGAN(self.pred_fake, 0) 112 | 113 | # Real 114 | self.pred_real = net.forward(realB) 115 | self.loss_D_real = self.criterionGAN(self.pred_real, 1) 116 | 117 | # Combined loss 118 | 
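# Averaging the fake and real terms keeps the combined loss on the same scale as each
# individual term (the convention used in the pix2pix/CycleGAN reference code).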
self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 119 | return self.loss_D 120 | 121 | class DiscLossLS(DiscLoss): 122 | def name(self): 123 | return 'DiscLossLS' 124 | 125 | def initialize(self, opt, tensor): 126 | DiscLoss.initialize(self, opt, tensor) 127 | self.criterionGAN = GANLoss(use_l1=True, tensor=tensor) 128 | 129 | def get_g_loss(self,net, realA, fakeB): 130 | return DiscLoss.get_g_loss(self,net, realA, fakeB) 131 | 132 | def get_loss(self, net, realA, fakeB, realB): 133 | return DiscLoss.get_loss(self, net, realA, fakeB, realB) 134 | 135 | class DiscLossWGANGP(DiscLossLS): 136 | def name(self): 137 | return 'DiscLossWGAN-GP' 138 | 139 | def initialize(self, opt, tensor): 140 | DiscLossLS.initialize(self, opt, tensor) 141 | self.LAMBDA = 10 142 | 143 | def get_g_loss(self, net, realA, fakeB): 144 | # First, G(A) should fake the discriminator 145 | self.D_fake = net.forward(fakeB) 146 | return -self.D_fake.mean() 147 | 148 | def calc_gradient_penalty(self, netD, real_data, fake_data): 149 | alpha = torch.rand(1, 1) 150 | alpha = alpha.expand(real_data.size()) 151 | alpha = alpha.cuda() 152 | 153 | interpolates = alpha * real_data + ((1 - alpha) * fake_data) 154 | 155 | interpolates = interpolates.cuda() 156 | interpolates = Variable(interpolates, requires_grad=True) 157 | 158 | disc_interpolates = netD.forward(interpolates) 159 | 160 | gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates, 161 | grad_outputs=torch.ones(disc_interpolates.size()).cuda(), 162 | create_graph=True, retain_graph=True, only_inputs=True)[0] 163 | 164 | gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.LAMBDA 165 | return gradient_penalty 166 | 167 | def get_loss(self, net, realA, fakeB, realB): 168 | self.D_fake = net.forward(fakeB.detach()) 169 | self.D_fake = self.D_fake.mean() 170 | 171 | # Real 172 | self.D_real = net.forward(realB) 173 | self.D_real = self.D_real.mean() 174 | # Combined loss 175 | self.loss_D = self.D_fake - self.D_real 176 | gradient_penalty = self.calc_gradient_penalty(net, realB.data, fakeB.data) 177 | return self.loss_D + gradient_penalty 178 | 179 | 180 | def init_loss(opt, tensor): 181 | disc_loss = None 182 | content_loss = None 183 | perceptual_loss = None 184 | prior_loss = None 185 | perceptual_loss = PerceptualLoss() 186 | perceptual_loss.initialize(nn.MSELoss()) 187 | content_loss = ContentLoss() 188 | content_loss.initialize(nn.MSELoss()) 189 | 190 | if opt.gan_type == 'wgan-gp': 191 | disc_loss = DiscLossWGANGP() 192 | elif opt.gan_type == 'lsgan': 193 | disc_loss = DiscLossLS() 194 | elif opt.gan_type == 'gan': 195 | disc_loss = DiscLoss() 196 | else: 197 | raise ValueError("GAN [%s] not recognized." 
% opt.gan_type) 198 | disc_loss.initialize(opt, tensor) 199 | 200 | prior_loss = PriorLoss() 201 | prior_loss.initialize(nn.MSELoss()) 202 | 203 | return disc_loss, perceptual_loss, content_loss, prior_loss 204 | -------------------------------------------------------------------------------- /model/networks.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import init 5 | import functools 6 | from torch.autograd import Variable 7 | import numpy as np 8 | ############################################################################### 9 | # Functions 10 | ############################################################################### 11 | 12 | 13 | def weights_init(m): 14 | classname = m.__class__.__name__ 15 | if classname.find('Conv') != -1: 16 | m.weight.data.normal_(0.0, 0.02) 17 | if hasattr(m.bias, 'data'): 18 | m.bias.data.fill_(0) 19 | elif classname.find('BatchNorm2d') != -1: 20 | m.weight.data.normal_(1.0, 0.02) 21 | m.bias.data.fill_(0) 22 | 23 | def get_norm_layer(norm_type='instance'): 24 | if norm_type == 'batch': 25 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True) 26 | elif norm_type == 'instance': 27 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) 28 | else: 29 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 30 | return norm_layer 31 | 32 | def define_coarse_SR_Encoder(norm_layer): 33 | coarse_SR_Encoder = [nn.ReflectionPad2d(1), 34 | nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=0, bias=False), 35 | norm_layer(64), 36 | nn.ReLU(True)] 37 | for i in range(3): 38 | coarse_SR_Encoder += [ResnetBlock(64, 'reflect', norm_layer, False, False)] 39 | coarse_SR_Encoder += [nn.ReflectionPad2d(1), 40 | nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=0, bias=False), 41 | nn.Tanh()] 42 | coarse_SR_Encoder = nn.Sequential(*coarse_SR_Encoder) 43 | return coarse_SR_Encoder 44 | 45 | def define_fine_SR_Encoder(norm_layer): 46 | fine_SR_Encoder = [nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False), 47 | norm_layer(64), 48 | nn.ReLU(True)] 49 | for i in range(12): 50 | fine_SR_Encoder += [ResnetBlock(64, 'reflect', norm_layer, False, False)] 51 | fine_SR_Encoder += [nn.ReflectionPad2d(1), 52 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0, bias=False), 53 | nn.Tanh()] 54 | fine_SR_Encoder = nn.Sequential(*fine_SR_Encoder) 55 | return fine_SR_Encoder 56 | 57 | def define_prior_Estimation_Network(norm_layer): 58 | prior_Estimation_Network = [nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False), 59 | norm_layer(64), 60 | nn.ReLU(True)] 61 | prior_Estimation_Network += [Residual(64, 128)] 62 | for i in range(2): 63 | prior_Estimation_Network += [ResnetBlock(128, 'reflect', norm_layer, False, False)] 64 | for i in range(2): 65 | prior_Estimation_Network += [HourGlassBlock(128, 3, norm_layer)] 66 | prior_Estimation_Network = nn.Sequential(*prior_Estimation_Network) 67 | return prior_Estimation_Network 68 | 69 | def define_fine_SR_Decoder(norm_layer): 70 | fine_SR_Decoder = [nn.Conv2d(192, 64, kernel_size=3, stride=1, padding=1, bias=False), 71 | norm_layer(64), 72 | nn.ReLU(True)] 73 | fine_SR_Decoder += [nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False), 74 | norm_layer(64), 75 | nn.ReLU(True)] 76 | for i in range(3): 77 | fine_SR_Decoder += [ResnetBlock(64, 'reflect', norm_layer, False, False)] 78 | fine_SR_Decoder += [nn.ReflectionPad2d(1), 79 
| nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=0, bias=False), 80 | nn.Tanh()] 81 | fine_SR_Decoder = nn.Sequential(*fine_SR_Decoder) 82 | return fine_SR_Decoder 83 | 84 | def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, gpu_ids=[], use_parallel=True, learn_residual=False): 85 | netG = None 86 | use_gpu = len(gpu_ids) > 0 87 | norm_layer = get_norm_layer(norm_type=norm) 88 | 89 | if use_gpu: 90 | assert(torch.cuda.is_available()) 91 | 92 | netG = Generator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, gpu_ids=gpu_ids, use_parallel=use_parallel, learn_residual=learn_residual) 93 | 94 | if len(gpu_ids) > 0: 95 | netG.cuda(gpu_ids[0]) 96 | netG.apply(weights_init) 97 | return netG 98 | 99 | 100 | def define_D(input_nc, ndf, which_model_netD, 101 | n_layers_D=3, norm='batch', use_sigmoid=False, gpu_ids=[], use_parallel = True): 102 | netD = None 103 | use_gpu = len(gpu_ids) > 0 104 | norm_layer = get_norm_layer(norm_type=norm) 105 | 106 | if use_gpu: 107 | assert(torch.cuda.is_available()) 108 | if which_model_netD == 'basic': 109 | netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids, use_parallel=use_parallel) 110 | elif which_model_netD == 'n_layers': 111 | netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids, use_parallel=use_parallel) 112 | else: 113 | raise NotImplementedError('Discriminator model name [%s] is not recognized' % 114 | which_model_netD) 115 | if use_gpu: 116 | netD.cuda(gpu_ids[0]) 117 | netD.apply(weights_init) 118 | return netD 119 | 120 | 121 | def print_network(net): 122 | num_params = 0 123 | for param in net.parameters(): 124 | num_params += param.numel() 125 | print(net) 126 | print('Total number of parameters: %d' % num_params) 127 | 128 | 129 | ############################################################################## 130 | # Classes 131 | ############################################################################## 132 | 133 | 134 | # Defines the generator that consists of Resnet blocks between a few 135 | # downsampling/upsampling operations. 136 | # Code and idea originally from Justin Johnson's architecture. 
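# A rough shape walk-through of the generator below, assuming the 128x128
# bicubically re-upscaled inputs produced by data/dataloader.py (is_scale_back=True):
#   coarse_SR_Encoder:        3 x 128 x 128 -> 3 x 128 x 128   coarse SR image
#   fine_SR_Encoder:          3 x 128 x 128 -> 64 x 64 x 64    stride-2 feature encoder
#   prior_Estimation_Network: 3 x 128 x 128 -> 128 x 64 x 64   hourglass prior features
#   torch.cat(...):           (64 + 128) = 192 channels at 64 x 64
#   fine_SR_Decoder:          192 x 64 x 64 -> 3 x 128 x 128   ConvTranspose2d upsampling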
137 | # https://github.com/jcjohnson/fast-neural-style/ 138 | class Generator(nn.Module): 139 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, gpu_ids=[], use_parallel=True, learn_residual=False, padding_type='reflect'): 140 | assert(n_blocks >= 0) 141 | super(Generator, self).__init__() 142 | self.input_nc = input_nc 143 | self.output_nc = output_nc 144 | self.ngf = ngf 145 | self.gpu_ids = gpu_ids 146 | self.use_parallel = use_parallel 147 | self.learn_residual = learn_residual 148 | if type(norm_layer) == functools.partial: 149 | use_bias = norm_layer.func == nn.InstanceNorm2d 150 | else: 151 | use_bias = norm_layer == nn.InstanceNorm2d 152 | 153 | self.coarse_SR_Encoder = define_coarse_SR_Encoder(norm_layer) 154 | self.fine_SR_Encoder = define_fine_SR_Encoder(norm_layer) 155 | self.prior_Estimation_Network = define_prior_Estimation_Network(norm_layer) 156 | self.fine_SR_Decoder = define_fine_SR_Decoder(norm_layer) 157 | 158 | def forward(self, input, is_hr=False): 159 | if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor) and self.use_parallel: 160 | if is_hr == True: 161 | heatmaps = nn.parallel.data_parallel(self.prior_Estimation_Network, input, self.gpu_ids) 162 | return heatmaps 163 | else: 164 | coarse_HR = nn.parallel.data_parallel(self.coarse_SR_Encoder, input, self.gpu_ids) 165 | parsing = nn.parallel.data_parallel(self.fine_SR_Encoder, coarse_HR, self.gpu_ids) 166 | heatmaps = nn.parallel.data_parallel(self.prior_Estimation_Network, coarse_HR, self.gpu_ids) 167 | concatenation = torch.cat((parsing, heatmaps), 1) 168 | output = nn.parallel.data_parallel(self.fine_SR_Decoder, concatenation, self.gpu_ids) 169 | else: 170 | if is_hr == True: 171 | heatmaps = self.prior_Estimation_Network(input) 172 | return heatmaps 173 | else: 174 | coarse_HR = self.coarse_SR_Encoder(input) 175 | parsing = self.fine_SR_Encoder(coarse_HR) 176 | heatmaps = self.prior_Estimation_Network(coarse_HR) 177 | concatenation = torch.cat((parsing, heatmaps), 1) 178 | output = self.fine_SR_Decoder(concatenation) 179 | 180 | if self.learn_residual: 181 | output = input + output 182 | output = torch.clamp(output, min = -1, max = 1) 183 | return coarse_HR, heatmaps, output 184 | 185 | #Define a hourglass block 186 | class HourGlassBlock(nn.Module): 187 | def __init__(self, dim, n, norm_layer): 188 | super(HourGlassBlock, self).__init__() 189 | self._dim = dim 190 | self._n = n 191 | self._norm_layer = norm_layer 192 | self._init_layers(self._dim, self._n, self._norm_layer) 193 | 194 | def _init_layers(self, dim, n, norm_layer): 195 | setattr(self, 'res'+str(n)+'_1', Residual(dim, dim)) 196 | setattr(self, 'pool'+str(n)+'_1', nn.MaxPool2d(2,2)) 197 | setattr(self, 'res'+str(n)+'_2', Residual(dim, dim)) 198 | if n > 1: 199 | self._init_layers(dim, n-1, norm_layer) 200 | else: 201 | self.res_center = Residual(dim, dim) 202 | setattr(self,'res'+str(n)+'_3', Residual(dim, dim)) 203 | setattr(self,'unsample'+str(n), nn.Upsample(scale_factor=2)) 204 | 205 | def _forward(self, x, dim, n): 206 | up1 = x 207 | up1 = eval('self.res'+str(n)+'_1')(up1) 208 | low1 = eval('self.pool'+str(n)+'_1')(x) 209 | low1 = eval('self.res'+str(n)+'_2')(low1) 210 | if n > 1: 211 | low2 = self._forward(low1, dim, n-1) 212 | else: 213 | low2 = self.res_center(low1) 214 | low3 = low2 215 | low3 = eval('self.'+'res'+str(n)+'_3')(low3) 216 | up2 = eval('self.'+'unsample'+str(n)).forward(low3) 217 | out = up1 + up2 218 | return out 219 | 220 | def forward(self, x): 221 | 
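# Entry point: run the recursive hourglass computation starting from the full depth self._n.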
return self._forward(x, self._dim, self._n) 222 | 223 | class Residual(nn.Module): 224 | def __init__(self, ins, outs): 225 | super(Residual, self).__init__() 226 | self.convBlock = nn.Sequential( 227 | nn.BatchNorm2d(ins), 228 | nn.ReLU(inplace=True), 229 | nn.Conv2d(ins,outs/2,1), 230 | nn.BatchNorm2d(outs/2), 231 | nn.ReLU(inplace=True), 232 | nn.Conv2d(outs/2,outs/2,3,1,1), 233 | nn.BatchNorm2d(outs/2), 234 | nn.ReLU(inplace=True), 235 | nn.Conv2d(outs/2,outs,1) 236 | ) 237 | if ins != outs: 238 | self.skipConv = nn.Conv2d(ins,outs,1) 239 | self.ins = ins 240 | self.outs = outs 241 | def forward(self, x): 242 | residual = x 243 | x = self.convBlock(x) 244 | if self.ins != self.outs: 245 | residual = self.skipConv(residual) 246 | x += residual 247 | return x 248 | 249 | # Define a resnet block 250 | class ResnetBlock(nn.Module): 251 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, updimension=False): 252 | super(ResnetBlock, self).__init__() 253 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, updimension) 254 | 255 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, updimension): 256 | conv_block = [] 257 | p = 0 258 | if padding_type == 'reflect': 259 | conv_block += [nn.ReflectionPad2d(1)] 260 | elif padding_type == 'replicate': 261 | conv_block += [nn.ReplicationPad2d(1)] 262 | elif padding_type == 'zero': 263 | p = 1 264 | else: 265 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 266 | in_chan = dim 267 | if updimension == True: 268 | out_chan = in_chan * 2 269 | else: 270 | out_chan = dim 271 | conv_block += [nn.Conv2d(in_chan, out_chan, kernel_size=3, padding=p, bias=use_bias), 272 | norm_layer(dim), 273 | nn.ReLU(True)] 274 | 275 | if use_dropout: 276 | conv_block += [nn.Dropout(0.5)] 277 | 278 | p = 0 279 | if padding_type == 'reflect': 280 | conv_block += [nn.ReflectionPad2d(1)] 281 | elif padding_type == 'replicate': 282 | conv_block += [nn.ReplicationPad2d(1)] 283 | elif padding_type == 'zero': 284 | p = 1 285 | else: 286 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 287 | conv_block += [nn.Conv2d(in_chan, out_chan, kernel_size=3, padding=p, bias=use_bias), 288 | norm_layer(dim)] 289 | 290 | return nn.Sequential(*conv_block) 291 | 292 | def forward(self, x): 293 | out = x + self.conv_block(x) 294 | return out 295 | 296 | class NLayerDiscriminator(nn.Module): 297 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[], use_parallel = True): 298 | super(NLayerDiscriminator, self).__init__() 299 | self.gpu_ids = gpu_ids 300 | self.use_parallel = use_parallel 301 | if type(norm_layer) == functools.partial: 302 | use_bias = norm_layer.func == nn.InstanceNorm2d 303 | else: 304 | use_bias = norm_layer == nn.InstanceNorm2d 305 | 306 | kw = 4 307 | padw = int(np.ceil((kw-1)/2)) 308 | sequence = [ 309 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 310 | nn.LeakyReLU(0.2, True) 311 | ] 312 | 313 | nf_mult = 1 314 | nf_mult_prev = 1 315 | for n in range(1, n_layers): 316 | nf_mult_prev = nf_mult 317 | nf_mult = min(2**n, 8) 318 | sequence += [ 319 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 320 | kernel_size=kw, stride=2, padding=padw, bias=use_bias), 321 | norm_layer(ndf * nf_mult), 322 | nn.LeakyReLU(0.2, True) 323 | ] 324 | 325 | nf_mult_prev = nf_mult 326 | nf_mult = min(2**n_layers, 8) 327 | sequence += [ 328 | nn.Conv2d(ndf * nf_mult_prev, ndf 
* nf_mult, 329 | kernel_size=kw, stride=1, padding=padw, bias=use_bias), 330 | norm_layer(ndf * nf_mult), 331 | nn.LeakyReLU(0.2, True) 332 | ] 333 | 334 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] 335 | 336 | if use_sigmoid: 337 | sequence += [nn.Sigmoid()] 338 | 339 | self.model = nn.Sequential(*sequence) 340 | 341 | def forward(self, input): 342 | if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor) and self.use_parallel: 343 | return nn.parallel.data_parallel(self.model, input, self.gpu_ids) 344 | else: 345 | return self.model(input) 346 | -------------------------------------------------------------------------------- /model/networks_without_coarse.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import init 5 | import functools 6 | from torch.autograd import Variable 7 | import numpy as np 8 | ############################################################################### 9 | # Functions 10 | ############################################################################### 11 | 12 | 13 | def weights_init(m): 14 | classname = m.__class__.__name__ 15 | if classname.find('Conv') != -1: 16 | m.weight.data.normal_(0.0, 0.02) 17 | if hasattr(m.bias, 'data'): 18 | m.bias.data.fill_(0) 19 | elif classname.find('BatchNorm2d') != -1: 20 | m.weight.data.normal_(1.0, 0.02) 21 | m.bias.data.fill_(0) 22 | 23 | def get_norm_layer(norm_type='instance'): 24 | if norm_type == 'batch': 25 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True) 26 | elif norm_type == 'instance': 27 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) 28 | else: 29 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 30 | return norm_layer 31 | 32 | def define_coarse_SR_Encoder(norm_layer): 33 | coarse_SR_Encoder = [nn.ReflectionPad2d(1), 34 | nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=0, bias=False), 35 | norm_layer(64), 36 | nn.ReLU(True)] 37 | for i in range(3): 38 | coarse_SR_Encoder += [ResnetBlock(64, 'reflect', norm_layer, False, False)] 39 | coarse_SR_Encoder += [nn.ReflectionPad2d(1), 40 | nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=0, bias=False), 41 | nn.Tanh()] 42 | coarse_SR_Encoder = nn.Sequential(*coarse_SR_Encoder) 43 | return coarse_SR_Encoder 44 | 45 | def define_fine_SR_Encoder(norm_layer): 46 | fine_SR_Encoder = [nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False), 47 | norm_layer(64), 48 | nn.ReLU(True)] 49 | for i in range(12): 50 | fine_SR_Encoder += [ResnetBlock(64, 'reflect', norm_layer, False, False)] 51 | fine_SR_Encoder += [nn.ReflectionPad2d(1), 52 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0, bias=False), 53 | nn.Tanh()] 54 | fine_SR_Encoder = nn.Sequential(*fine_SR_Encoder) 55 | return fine_SR_Encoder 56 | 57 | def define_prior_Estimation_Network(norm_layer): 58 | prior_Estimation_Network = [nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False), 59 | norm_layer(64), 60 | nn.ReLU(True)] 61 | prior_Estimation_Network += [Residual(64, 128)] 62 | for i in range(2): 63 | prior_Estimation_Network += [ResnetBlock(128, 'reflect', norm_layer, False, False)] 64 | for i in range(2): 65 | prior_Estimation_Network += [HourGlassBlock(128, 3, norm_layer)] 66 | prior_Estimation_Network = nn.Sequential(*prior_Estimation_Network) 67 | return prior_Estimation_Network 68 | 69 | def define_fine_SR_Decoder(norm_layer): 70 | fine_SR_Decoder = 
[nn.Conv2d(192, 64, kernel_size=3, stride=1, padding=1, bias=False), 71 | norm_layer(64), 72 | nn.ReLU(True)] 73 | fine_SR_Decoder += [nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False), 74 | norm_layer(64), 75 | nn.ReLU(True)] 76 | for i in range(3): 77 | fine_SR_Decoder += [ResnetBlock(64, 'reflect', norm_layer, False, False)] 78 | fine_SR_Decoder += [nn.ReflectionPad2d(1), 79 | nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=0, bias=False), 80 | nn.Tanh()] 81 | fine_SR_Decoder = nn.Sequential(*fine_SR_Decoder) 82 | return fine_SR_Decoder 83 | 84 | def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, gpu_ids=[], use_parallel=True, learn_residual=False): 85 | netG = None 86 | use_gpu = len(gpu_ids) > 0 87 | norm_layer = get_norm_layer(norm_type=norm) 88 | 89 | if use_gpu: 90 | assert(torch.cuda.is_available()) 91 | 92 | netG = Generator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, gpu_ids=gpu_ids, use_parallel=use_parallel, learn_residual=learn_residual) 93 | 94 | if len(gpu_ids) > 0: 95 | netG.cuda(gpu_ids[0]) 96 | netG.apply(weights_init) 97 | return netG 98 | 99 | 100 | def define_D(input_nc, ndf, which_model_netD, 101 | n_layers_D=3, norm='batch', use_sigmoid=False, gpu_ids=[], use_parallel = True): 102 | netD = None 103 | use_gpu = len(gpu_ids) > 0 104 | norm_layer = get_norm_layer(norm_type=norm) 105 | 106 | if use_gpu: 107 | assert(torch.cuda.is_available()) 108 | if which_model_netD == 'basic': 109 | netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids, use_parallel=use_parallel) 110 | elif which_model_netD == 'n_layers': 111 | netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids, use_parallel=use_parallel) 112 | else: 113 | raise NotImplementedError('Discriminator model name [%s] is not recognized' % 114 | which_model_netD) 115 | if use_gpu: 116 | netD.cuda(gpu_ids[0]) 117 | netD.apply(weights_init) 118 | return netD 119 | 120 | 121 | def print_network(net): 122 | num_params = 0 123 | for param in net.parameters(): 124 | num_params += param.numel() 125 | print(net) 126 | print('Total number of parameters: %d' % num_params) 127 | 128 | 129 | ############################################################################## 130 | # Classes 131 | ############################################################################## 132 | 133 | 134 | # Defines the generator that consists of Resnet blocks between a few 135 | # downsampling/upsampling operations. 136 | # Code and idea originally from Justin Johnson's architecture. 
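# Note: in this "without coarse" variant, forward() below uses the network input directly
# as coarse_HR (the coarse_SR_Encoder is constructed but bypassed), so the fine encoder and
# prior estimator operate on the pre-upscaled input image itself.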
137 | # https://github.com/jcjohnson/fast-neural-style/ 138 | class Generator(nn.Module): 139 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, gpu_ids=[], use_parallel=True, learn_residual=False, padding_type='reflect'): 140 | assert(n_blocks >= 0) 141 | super(Generator, self).__init__() 142 | self.input_nc = input_nc 143 | self.output_nc = output_nc 144 | self.ngf = ngf 145 | self.gpu_ids = gpu_ids 146 | self.use_parallel = use_parallel 147 | self.learn_residual = learn_residual 148 | if type(norm_layer) == functools.partial: 149 | use_bias = norm_layer.func == nn.InstanceNorm2d 150 | else: 151 | use_bias = norm_layer == nn.InstanceNorm2d 152 | 153 | self.coarse_SR_Encoder = define_coarse_SR_Encoder(norm_layer) 154 | self.fine_SR_Encoder = define_fine_SR_Encoder(norm_layer) 155 | self.prior_Estimation_Network = define_prior_Estimation_Network(norm_layer) 156 | self.fine_SR_Decoder = define_fine_SR_Decoder(norm_layer) 157 | 158 | def forward(self, input, is_hr=False): 159 | if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor) and self.use_parallel: 160 | if is_hr == True: 161 | heatmaps = nn.parallel.data_parallel(self.prior_Estimation_Network, input, self.gpu_ids) 162 | return heatmaps 163 | else: 164 | coarse_HR = input 165 | #coarse_HR = nn.parallel.data_parallel(self.coarse_SR_Encoder, input, self.gpu_ids) 166 | parsing = nn.parallel.data_parallel(self.fine_SR_Encoder, coarse_HR, self.gpu_ids) 167 | heatmaps = nn.parallel.data_parallel(self.prior_Estimation_Network, coarse_HR, self.gpu_ids) 168 | concatenation = torch.cat((parsing, heatmaps), 1) 169 | output = nn.parallel.data_parallel(self.fine_SR_Decoder, concatenation, self.gpu_ids) 170 | else: 171 | if is_hr == True: 172 | heatmaps = self.prior_Estimation_Network(input) 173 | return heatmaps 174 | else: 175 | coarse_HR = input 176 | #coarse_HR = self.coarse_SR_Encoder(input) 177 | parsing = self.fine_SR_Encoder(coarse_HR) 178 | heatmaps = self.prior_Estimation_Network(coarse_HR) 179 | concatenation = torch.cat((parsing, heatmaps), 1) 180 | output = self.fine_SR_Decoder(concatenation) 181 | 182 | if self.learn_residual: 183 | output = input + output 184 | output = torch.clamp(output, min = -1, max = 1) 185 | return coarse_HR, heatmaps, output 186 | 187 | #Define a hourglass block 188 | class HourGlassBlock(nn.Module): 189 | def __init__(self, dim, n, norm_layer): 190 | super(HourGlassBlock, self).__init__() 191 | self._dim = dim 192 | self._n = n 193 | self._norm_layer = norm_layer 194 | self._init_layers(self._dim, self._n, self._norm_layer) 195 | 196 | def _init_layers(self, dim, n, norm_layer): 197 | setattr(self, 'res'+str(n)+'_1', Residual(dim, dim)) 198 | setattr(self, 'pool'+str(n)+'_1', nn.MaxPool2d(2,2)) 199 | setattr(self, 'res'+str(n)+'_2', Residual(dim, dim)) 200 | if n > 1: 201 | self._init_layers(dim, n-1, norm_layer) 202 | else: 203 | self.res_center = Residual(dim, dim) 204 | setattr(self,'res'+str(n)+'_3', Residual(dim, dim)) 205 | setattr(self,'unsample'+str(n), nn.Upsample(scale_factor=2)) 206 | 207 | def _forward(self, x, dim, n): 208 | up1 = x 209 | up1 = eval('self.res'+str(n)+'_1')(up1) 210 | low1 = eval('self.pool'+str(n)+'_1')(x) 211 | low1 = eval('self.res'+str(n)+'_2')(low1) 212 | if n > 1: 213 | low2 = self._forward(low1, dim, n-1) 214 | else: 215 | low2 = self.res_center(low1) 216 | low3 = low2 217 | low3 = eval('self.'+'res'+str(n)+'_3')(low3) 218 | up2 = eval('self.'+'unsample'+str(n)).forward(low3) 219 | out = up1 + up2 220 | 
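# Hourglass output at this depth: the identity branch (up1) plus the upsampled
# lower-resolution branch (up2).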
return out 221 | 222 | def forward(self, x): 223 | return self._forward(x, self._dim, self._n) 224 | 225 | class Residual(nn.Module): 226 | def __init__(self, ins, outs): 227 | super(Residual, self).__init__() 228 | self.convBlock = nn.Sequential( 229 | nn.BatchNorm2d(ins), 230 | nn.ReLU(inplace=True), 231 | nn.Conv2d(ins,outs/2,1), 232 | nn.BatchNorm2d(outs/2), 233 | nn.ReLU(inplace=True), 234 | nn.Conv2d(outs/2,outs/2,3,1,1), 235 | nn.BatchNorm2d(outs/2), 236 | nn.ReLU(inplace=True), 237 | nn.Conv2d(outs/2,outs,1) 238 | ) 239 | if ins != outs: 240 | self.skipConv = nn.Conv2d(ins,outs,1) 241 | self.ins = ins 242 | self.outs = outs 243 | def forward(self, x): 244 | residual = x 245 | x = self.convBlock(x) 246 | if self.ins != self.outs: 247 | residual = self.skipConv(residual) 248 | x += residual 249 | return x 250 | 251 | # Define a resnet block 252 | class ResnetBlock(nn.Module): 253 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, updimension=False): 254 | super(ResnetBlock, self).__init__() 255 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, updimension) 256 | 257 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, updimension): 258 | conv_block = [] 259 | p = 0 260 | if padding_type == 'reflect': 261 | conv_block += [nn.ReflectionPad2d(1)] 262 | elif padding_type == 'replicate': 263 | conv_block += [nn.ReplicationPad2d(1)] 264 | elif padding_type == 'zero': 265 | p = 1 266 | else: 267 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 268 | in_chan = dim 269 | if updimension == True: 270 | out_chan = in_chan * 2 271 | else: 272 | out_chan = dim 273 | conv_block += [nn.Conv2d(in_chan, out_chan, kernel_size=3, padding=p, bias=use_bias), 274 | norm_layer(dim), 275 | nn.ReLU(True)] 276 | 277 | if use_dropout: 278 | conv_block += [nn.Dropout(0.5)] 279 | 280 | p = 0 281 | if padding_type == 'reflect': 282 | conv_block += [nn.ReflectionPad2d(1)] 283 | elif padding_type == 'replicate': 284 | conv_block += [nn.ReplicationPad2d(1)] 285 | elif padding_type == 'zero': 286 | p = 1 287 | else: 288 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 289 | conv_block += [nn.Conv2d(in_chan, out_chan, kernel_size=3, padding=p, bias=use_bias), 290 | norm_layer(dim)] 291 | 292 | return nn.Sequential(*conv_block) 293 | 294 | def forward(self, x): 295 | out = x + self.conv_block(x) 296 | return out 297 | 298 | class NLayerDiscriminator(nn.Module): 299 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[], use_parallel = True): 300 | super(NLayerDiscriminator, self).__init__() 301 | self.gpu_ids = gpu_ids 302 | self.use_parallel = use_parallel 303 | if type(norm_layer) == functools.partial: 304 | use_bias = norm_layer.func == nn.InstanceNorm2d 305 | else: 306 | use_bias = norm_layer == nn.InstanceNorm2d 307 | 308 | kw = 4 309 | padw = int(np.ceil((kw-1)/2)) 310 | sequence = [ 311 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 312 | nn.LeakyReLU(0.2, True) 313 | ] 314 | 315 | nf_mult = 1 316 | nf_mult_prev = 1 317 | for n in range(1, n_layers): 318 | nf_mult_prev = nf_mult 319 | nf_mult = min(2**n, 8) 320 | sequence += [ 321 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 322 | kernel_size=kw, stride=2, padding=padw, bias=use_bias), 323 | norm_layer(ndf * nf_mult), 324 | nn.LeakyReLU(0.2, True) 325 | ] 326 | 327 | nf_mult_prev = nf_mult 328 | nf_mult = min(2**n_layers, 8) 329 | 
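# A final stride-1 conv block followed by a 1-channel conv produces the
# PatchGAN-style patch predictions.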
sequence += [ 330 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 331 | kernel_size=kw, stride=1, padding=padw, bias=use_bias), 332 | norm_layer(ndf * nf_mult), 333 | nn.LeakyReLU(0.2, True) 334 | ] 335 | 336 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] 337 | 338 | if use_sigmoid: 339 | sequence += [nn.Sigmoid()] 340 | 341 | self.model = nn.Sequential(*sequence) 342 | 343 | def forward(self, input): 344 | if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor) and self.use_parallel: 345 | return nn.parallel.data_parallel(self.model, input, self.gpu_ids) 346 | else: 347 | return self.model(input) 348 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cydiachen/FSRNET_pytorch/e792839db4b6c49204842ae42975da74fac96285/test.py -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | # Using this code to force the usage of any specific GPUs 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 4 | import argparse 5 | import os 6 | import random 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.parallel 10 | import torch.backends.cudnn as cudnn 11 | import torch.optim as optim 12 | import torch.utils.data 13 | import torchvision.datasets as dset 14 | import torch.utils.data as data 15 | import time 16 | import numpy as np 17 | import torchvision.utils as vutils 18 | from torch.autograd import Variable 19 | from math import log10 20 | import torchvision 21 | import cv2 22 | import skimage 23 | import scipy.io 24 | import glob 25 | import matplotlib.image as mpimg 26 | import matplotlib.pyplot as plt 27 | from models import losses 28 | from model.networks import * 29 | from utils.model_storage import save_checkpoint 30 | from data.dataloader import * 31 | 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument("--pretrained", default="", type=str, help="path to pretrained model (default: none)") 34 | parser.add_argument("--lr", default="2.5e-4", type=float, help="The learning rate of our network") 35 | parser.add_argument("--save_freq", default="2", type=float, help="The intervals of our model storage intervals") 36 | parser.add_argument("--iter_freq", default="2", type=float, help="The intervals of our model's evaluation intervals") 37 | parser.add_argument("--result_dir", default="./result/", type=str, help="The path of our result images") 38 | parser.add_argument("--model_path", default="./weights/", type=str, help="The path to store our model") 39 | parser.add_argument("--epochs", default="100", type=int, help="The path to store our model") 40 | parser.add_argument("--start_epoch", default="0", type=int, help="The path to store our model") 41 | parser.add_argument("--batch_size", default="14", type=int, help="The path to store our batch_size") 42 | parser.add_argument("--image_dir", default="./data/CelebA-HQ-img/", type=str, help="The path to store our batch_size") 43 | parser.add_argument("--image_list", default="./data/train_fileList.txt", type=int, help="The path to store our batch_size") 44 | 45 | global opt,model 46 | opt = parser.parse_args() 47 | start_time = time.time() 48 | 49 | demo_dataset = ImageDatasetFromFile( 50 | opt.image_list, 51 | opt.image_dir) 52 | train_data_loader = data.DataLoader(dataset=demo_dataset, 
batch_size=opt.batch_size, num_workers=8, drop_last=True, 53 | pin_memory=True) 54 | 55 | fsrnet = define_G(input_nc = 3, output_nc = 3) 56 | criterion_MSE = nn.MSELoss() 57 | 58 | if torch.cuda.is_available(): 59 | fsrnet = fsrnet.cuda() 60 | criterion_MSE = criterion_MSE.cuda() 61 | 62 | optimizerG = optim.RMSprop(fsrnet.parameters(),lr = opt.lr) 63 | 64 | if opt.pretrained: 65 | if os.path.isfile(opt.pretrained): 66 | print("=> loading model '{}'".format(opt.pretrained)) 67 | weights = torch.load(opt.pretrained) 68 | 69 | # debug 70 | pretrained_dict = weights['model'].state_dict() 71 | model_dict = fsrnet.state_dict() 72 | 73 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 74 | model_dict.update(pretrained_dict) 75 | 76 | # 3. load the new state dict 77 | fsrnet.load_state_dict(model_dict) 78 | else: 79 | print("=> no model found at '{}'".format(opt.pretrained)) 80 | 81 | for epoch in range(opt.start_epoch,opt.epochs): 82 | if epoch % opt.save_freq == 0: 83 | #model, epoch, model_path, iteration, prefix="" 84 | save_checkpoint(fsrnet, epoch, opt.model_path,0, prefix='_ParsingMaps_') 85 | 86 | for iteration, batch in enumerate(train_data_loader): 87 | input, target, heatmaps = Variable(batch[0]), Variable(batch[1]), Variable(batch[2]) 88 | heatmaps = heatmaps.type_as(input) 89 | 90 | if torch.cuda.is_available(): 91 | input = input.cuda() 92 | heatmaps = heatmaps.cuda() 93 | target = target.cuda() 94 | 95 | upscaled,boundaries,reconstructed = fsrnet(input) 96 | fsrnet.zero_grad() 97 | 98 | loss_us = criterion_MSE(upscaled,target) 99 | loss_hm = criterion_MSE(boundaries,heatmaps) 100 | loss_final = criterion_MSE(reconstructed,target) 101 | g_loss = loss_us + loss_hm + loss_final 102 | g_loss.backward() 103 | optimizerG.step() 104 | 105 | info = "===> Epoch[{}]({}/{}): time: {:4.4f}:\n".format(epoch, iteration, len(demo_dataset) // 16, 106 | time.time() - start_time) 107 | info += "Total_loss: {:.4f}, Basic Upscale Loss:{:.4f}, Prior Estimation Loss:{:.4f}, Final Reconstruction Loss: {:.4f}\n".format( 108 | g_loss.float(), loss_us.float(), loss_hm.float(), loss_final.float()) 109 | 110 | print(info) 111 | 112 | if epoch % opt.iter_freq == 0: 113 | # model, epoch, model_path, iteration, prefix="" 114 | if not os.path.isdir(opt.result_dir + '%04d_Coarse_SR_network' % epoch): 115 | os.makedirs(opt.result_dir + '%04d_Coarse_SR_network' % epoch) 116 | if not os.path.isdir(opt.result_dir + '%04d_Prior_Estimation' % epoch): 117 | os.makedirs(opt.result_dir + '%04d_Prior_Estimation' % epoch) 118 | if not os.path.isdir(opt.result_dir + '%04d_Final_SR_reconstruction' % epoch): 119 | os.makedirs(opt.result_dir + '%04d_Final_SR_reconstruction' % epoch) 120 | 121 | final_output = reconstructed.permute(0,2,3,1).detach().cpu().numpy() 122 | final_output_0 = final_output[0,:,:,:] 123 | 124 | estimated_boundary = boundaries.permute(0,2,3,1).detach().cpu().numpy() 125 | estimated_boundary_0 = estimated_boundary[0,:,:,0] 126 | 127 | output = upscaled.permute(0,2,3,1).detach().cpu().numpy() 128 | output_0 = output[0,:,:,:] 129 | 130 | scipy.misc.toimage(output_0 * 131 | 255, high=255, low=0, cmin=0, cmax=255).save( 132 | opt.result_dir + '%04d_Coarse_SR_network/%d.jpg' % (epoch, iteration)) 133 | scipy.misc.toimage(estimated_boundary_0 * 255, high=255, low=0, cmin=0, cmax=255).save( 134 | opt.result_dir + '%04d_Prior_Estimation/%d.jpg' % (epoch, iteration)) 135 | scipy.misc.toimage(final_output_0 * 255, high=255, low=0, cmin=0, cmax=255).save( 136 | opt.result_dir + 
'%04d_Final_SR_reconstruction/%d.jpg' % (epoch, iteration))
--------------------------------------------------------------------------------
/utils/model_storage.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | 
4 | def save_checkpoint(model, epoch, model_path, iteration, prefix=""):
5 |     if not os.path.isdir(model_path):
6 |         os.makedirs(model_path)
7 |     model_out_path = model_path + prefix + "model_epoch_{}_iter_{}.pth".format(epoch, iteration)
8 |     state = {"epoch": epoch, "model": model}
9 |     torch.save(state, model_out_path)
10 |     print("Checkpoint saved to {}".format(model_out_path))
11 | 
12 | # TODO: check later whether hyperparameters need to be saved separately
13 | 
14 | def save_Hyperparameter():
15 |     pass
--------------------------------------------------------------------------------
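Checkpoints written by save_checkpoint store the whole model object under the "model" key, and train.py restores weights by copying that object's state_dict into a freshly built network. Below is a minimal standalone loading sketch along the same lines; the checkpoint filename is a placeholder, and ngf/which_model_netG are passed explicitly because define_G requires them even though they do not change the layer sizes.

import torch
from model.networks import define_G

# Rebuild the generator (3-channel input/output, as in train.py), then copy weights
# from a saved checkpoint. 'fsrnet' is only a placeholder label for which_model_netG.
fsrnet = define_G(input_nc=3, output_nc=3, ngf=64, which_model_netG='fsrnet')
checkpoint = torch.load("./weights/_ParsingMaps_model_epoch_10_iter_0.pth", map_location="cpu")
pretrained_dict = checkpoint["model"].state_dict()
model_dict = fsrnet.state_dict()
model_dict.update({k: v for k, v in pretrained_dict.items() if k in model_dict})
fsrnet.load_state_dict(model_dict)
fsrnet.eval()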