├── figure
│   ├── readme.txt
│   ├── figure1.png
│   ├── figure2.png
│   ├── figure3.png
│   └── figure4.png
├── logs
│   └── readme.txt
├── models
│   └── readme.txt
├── Result
│   └── readme.txt
├── Test
│   ├── readme.txt
│   ├── 0905.png
│   └── 0922.png
├── data.py
├── distributions.py
├── prepare_images.py
├── README.md
├── image_utils.py
├── datasets.py
├── test.py
├── main_denoiser.py
├── test_enhance.py
├── main_GAN.py
└── modules.py
/figure/readme.txt: -------------------------------------------------------------------------------- 1 | description 2 | -------------------------------------------------------------------------------- /logs/readme.txt: -------------------------------------------------------------------------------- 1 | training logs here 2 | -------------------------------------------------------------------------------- /models/readme.txt: -------------------------------------------------------------------------------- 1 | pre-trained models here 2 | -------------------------------------------------------------------------------- /Result/readme.txt: -------------------------------------------------------------------------------- 1 | Results of super-resolved images 2 | -------------------------------------------------------------------------------- /Test/readme.txt: -------------------------------------------------------------------------------- 1 | put input low-resolution images here 2 | -------------------------------------------------------------------------------- /Test/0905.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Holmes-Alan/dSRVAE/HEAD/Test/0905.png -------------------------------------------------------------------------------- /Test/0922.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Holmes-Alan/dSRVAE/HEAD/Test/0922.png -------------------------------------------------------------------------------- /figure/figure1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Holmes-Alan/dSRVAE/HEAD/figure/figure1.png -------------------------------------------------------------------------------- /figure/figure2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Holmes-Alan/dSRVAE/HEAD/figure/figure2.png -------------------------------------------------------------------------------- /figure/figure3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Holmes-Alan/dSRVAE/HEAD/figure/figure3.png -------------------------------------------------------------------------------- /figure/figure4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Holmes-Alan/dSRVAE/HEAD/figure/figure4.png -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | from torchvision.transforms import Compose, ToTensor, Normalize 3 | from datasets import DatasetFromFolderEval, DatasetFromFolder 4 | 5 | def transform(): 6 | return Compose([ 7 | ToTensor(), 8 | # Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 9 | ]) 10 | 11 | def get_training_set(data_dir, upscale_factor, patch_size, data_augmentation): 12 | hr_dir = join(data_dir, 'HR') 13 | lr_dir = join(data_dir, 'LR') 14 |
return DatasetFromFolder(hr_dir, lr_dir, patch_size, upscale_factor, data_augmentation, 15 | transform=transform()) 16 | 17 | def get_eval_set(lr_dir, upscale_factor): 18 | return DatasetFromFolderEval(lr_dir, upscale_factor, 19 | transform=transform()) 20 | 21 | -------------------------------------------------------------------------------- /distributions.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.utils.data 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | min_epsilon = 1e-5 8 | max_epsilon = 1.-1e-5 9 | #======================================================================================================================= 10 | def log_Normal_diag(x, mean, log_var, average=False, dim=None): 11 | log_normal = -0.5 * ( log_var + torch.pow( x - mean, 2 ) / torch.exp( log_var ) ) 12 | if average: 13 | return torch.mean( log_normal, dim ) 14 | else: 15 | return torch.sum( log_normal, dim ) 16 | 17 | def log_Normal_standard(x, average=False, dim=None): 18 | log_normal = -0.5 * torch.pow( x , 2 ) 19 | if average: 20 | return torch.mean( log_normal, dim ) 21 | else: 22 | return torch.sum( log_normal, dim ) 23 | 24 | def log_Bernoulli(x, mean, average=False, dim=None): 25 | probs = torch.clamp( mean, min=min_epsilon, max=max_epsilon ) 26 | log_bernoulli = x * torch.log( probs ) + (1. - x ) * torch.log( 1. - probs ) 27 | if average: 28 | return torch.mean( log_bernoulli, dim ) 29 | else: 30 | return torch.sum( log_bernoulli, dim ) 31 | 32 | def logisticCDF(x, u, s): 33 | return 1. / ( 1. + torch.exp( -(x-u) / s ) ) 34 | 35 | def sigmoid(x): 36 | return 1. / ( 1. + torch.exp( -x ) ) 37 | 38 | def log_Logistic_256(x, mean, logvar, average=False, reduce=True, dim=None): 39 | bin_size = 1. / 256. 40 | 41 | # implementation like https://github.com/openai/iaf/blob/master/tf_utils/distributions.py#L28 42 | scale = torch.exp(logvar) 43 | x = (torch.floor(x / bin_size) * bin_size - mean) / scale 44 | cdf_plus = torch.sigmoid(x + bin_size/scale) 45 | cdf_minus = torch.sigmoid(x) 46 | 47 | # calculate final log-likelihood for an image 48 | log_logist_256 = - torch.log(cdf_plus - cdf_minus + 1.e-7) 49 | 50 | if reduce: 51 | if average: 52 | return torch.mean(log_logist_256, dim) 53 | else: 54 | return torch.sum(log_logist_256, dim) 55 | else: 56 | return log_logist_256 57 | 58 | 59 | def log_Logistic_512(x, mean, logvar, average=False, reduce=True, dim=None): 60 | bin_size = 1. / 512. 
61 | 62 | # implementation like https://github.com/openai/iaf/blob/master/tf_utils/distributions.py#L28 63 | scale = torch.exp(logvar) 64 | x = (torch.floor(x / bin_size) * bin_size - mean) / scale 65 | cdf_plus = torch.sigmoid(x + bin_size/scale) 66 | cdf_minus = torch.sigmoid(x) 67 | 68 | # calculate final log-likelihood for an image 69 | log_logist_512 = - torch.log(cdf_plus - cdf_minus + 1.e-7) 70 | 71 | if reduce: 72 | if average: 73 | return torch.mean(log_logist_512, dim) 74 | else: 75 | return torch.sum(log_logist_512, dim) 76 | else: 77 | return log_logist_512 -------------------------------------------------------------------------------- /prepare_images.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | 4 | import torch.nn as nn 5 | import torch 6 | 7 | 8 | class ImageSplitter: 9 | # key points: 10 | # Boarder padding and over-lapping img splitting to avoid the instability of edge value 11 | # Thanks Waifu2x's autorh nagadomi for suggestions (https://github.com/nagadomi/waifu2x/issues/238) 12 | 13 | def __init__(self, patch_size, scale_factor, stride): 14 | self.patch_size = patch_size 15 | self.scale_factor = scale_factor 16 | self.stride = stride 17 | self.height = 0 18 | self.width = 0 19 | 20 | def split_img_tensor(self, img_tensor): 21 | # resize image and convert them into tensor 22 | batch, channel, height, width = img_tensor.size() 23 | self.height = height 24 | self.width = width 25 | 26 | side = min(height, width, self.patch_size) 27 | delta = self.patch_size - side 28 | Z = torch.zeros([batch, channel, height+delta, width+delta]) 29 | Z[:, :, delta//2:height+delta//2, delta//2:width+delta//2] = img_tensor 30 | batch, channel, new_height, new_width = Z.size() 31 | 32 | patch_box = [] 33 | 34 | # split image into over-lapping pieces 35 | for i in range(0, new_height, self.stride): 36 | for j in range(0, new_width, self.stride): 37 | x = min(new_height, i + self.patch_size) 38 | y = min(new_width, j + self.patch_size) 39 | part = Z[:, :, x-self.patch_size:x, y-self.patch_size:y] 40 | 41 | patch_box.append(part) 42 | 43 | patch_tensor = torch.cat(patch_box, dim=0) 44 | return patch_tensor 45 | 46 | def merge_img_tensor(self, list_img_tensor): 47 | img_tensors = copy.copy(list_img_tensor) 48 | 49 | patch_size = self.patch_size * self.scale_factor 50 | stride = self.stride * self.scale_factor 51 | height = self.height * self.scale_factor 52 | width = self.width * self.scale_factor 53 | side = min(height, width, patch_size) 54 | delta = patch_size - side 55 | new_height = delta + height 56 | new_width = delta + width 57 | out = torch.zeros((1, 3, new_height, new_width)) 58 | mask = torch.zeros((1, 3, new_height, new_width)) 59 | 60 | for i in range(0, new_height, stride): 61 | for j in range(0, new_width, stride): 62 | x = min(new_height, i + patch_size) 63 | y = min(new_width, j + patch_size) 64 | mask_patch = torch.zeros((1, 3, new_height, new_width)) 65 | out_patch = torch.zeros((1, 3, new_height, new_width)) 66 | mask_patch[:, :, (x - patch_size):x, (y - patch_size):y] = 1.0 67 | out_patch[:, :, (x - patch_size):x, (y - patch_size):y] = img_tensors.pop(0) 68 | mask = mask + mask_patch 69 | out = out + out_patch 70 | 71 | out = out / mask 72 | 73 | out = out[:, :, delta//2:new_height - delta//2, delta//2:new_width - delta//2] 74 | 75 | return out 76 | 77 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # dSRVAE (Generative Variational AutoEncoder for Real Image Super-Resolution) 2 | 3 | By Zhi-Song Liu, Li-Wen Wang, Chu-Tak Li, Marie-Paule Cani and Wan-Chi Siu 4 | 5 | This repo provides simple testing code, pretrained models and a demo of the network strategy. 6 | 7 | We propose a joint image denoising and super-resolution model based on a generative Variational AutoEncoder (dSRVAE). 8 | 9 | We participated in the CVPRW [NTIRE2020 Real Image Super-Resolution Challenge](https://data.vision.ee.ethz.ch/cvl/ntire20/). 10 | 11 | Please check our [paper](https://arxiv.org/abs/2004.12811). 12 | 13 | # BibTex 14 | 15 | @InProceedings{Liu2020dsrvae, 16 | author = {Zhi-Song Liu and Wan-Chi Siu and Li-Wen Wang and Chu-Tak Li and Marie-Paule Cani and Yui-Lam Chan}, 17 | title = {Unsupervised Real Image Super-Resolution via Generative Variational AutoEncoder}, 18 | booktitle = {IEEE Conference on Computer Vision and Pattern Recognition Workshops (CVPRW)}, 19 | month = {June}, 20 | year = {2020} 21 | } 22 | 23 | # For the proposed dSRVAE model, we claim the following points: 24 | 25 | • It is the first work to use a Variational AutoEncoder for image denoising. 26 | 27 | • A Super-Resolution Sub-Network (SRSN) is then attached as a small overhead to the denoising autoencoder (DAE), forming the proposed dSRVAE that outputs super-resolved images. 28 | 29 | # Dependencies 30 | Python > 3.0 31 | OpenCV library 32 | PyTorch > 1.0 33 | NVIDIA GPU + CUDA 34 | pytorch-gan-metrics 35 | 36 | # Complete Architecture 37 | The complete architecture is shown as follows: 38 | 39 | ![network](/figure/figure1.png) 40 | 41 | # Implementation 42 | ## 1. Quick testing 43 | --------------------------------------- 44 | 1. Download the pre-trained models from 45 | 46 | https://drive.google.com/open?id=1SUZGE04vw5_yDYiw6PJ4sbHAOIEV6TJ7 47 | 48 | and copy them to the folder "models". 49 | 50 | 2. Copy your images to the folder "Test" and run 51 | ```sh 52 | $ python test.py 53 | ``` 54 | The SR images will be saved in the folder "Result". 55 | 3. For self-ensemble, run 56 | ```sh 57 | $ python test_enhance.py 58 | ``` 59 | 4. GAN feature evaluation 60 | ```python 61 | # download statistics.npz from http://bioinf.jku.at/research/ttur/ 62 | from pytorch_gan_metrics import get_inception_score, get_fid 63 | 64 | images = ... # [N, 3, H, W] normalized to [0, 1] 65 | IS, IS_std = get_inception_score(images) # Inception Score 66 | FID = get_fid(images, 'path/to/statistics.npz') # Frechet Inception Distance 67 | ``` 68 | 69 | 70 | ## 2. Testing for NTIRE 2020 71 | --------------------------------------- 72 | 73 | ### s1. Testing images for the NTIRE2020 Real World Super-Resolution Challenge - Track 1: Image Processing Artifacts can be downloaded from the following link: 74 | 75 | https://drive.google.com/open?id=10ZutE-0idGFW0KUyfZ5-2aVSiA-1qUCV 76 | 77 | ### s2. Testing images for the NTIRE2020 Real World Super-Resolution Challenge - Track 2: Smartphone Images can be downloaded from the following link: 78 | 79 | https://drive.google.com/open?id=1_R4kRO_029g-HNAzPobo4-xwp86bMZLW 80 | 81 | ### s3. Validation images for the NTIRE2020 Real World Super-Resolution Challenge - Track 1 and Track 2 can be downloaded from the following link: 82 | 83 | https://drive.google.com/open?id=1nKEJ4N2V-0NFicfJxm8AJqsjXoGMYjMp 84 | 85 | ## 3. Training 86 | --------------------------- 87 | ### s1. Download the training images from NTIRE2020. 88 | 89 | https://competitions.codalab.org/competitions/22220#learn_the_details 90 | 91 | 92 | ### s2.
Start training on PyTorch 93 | 1. Train the Denoising VAE by running 94 | ```sh 95 | $ python main_denoiser.py 96 | ``` 97 | 2. Train the super-resolution SRSN overhead by running 98 | ```sh 99 | $ python main_GAN.py 100 | ``` 101 | --------------------------- 102 | 103 | # Partial image visual comparison 104 | 105 | ## 1. Visualization comparison 106 | Results of 4x image SR on the Track 1 dataset: 107 | ![figure2](/figure/figure2.png) 108 | ![figure3](/figure/figure3.png) 109 | ![figure4](/figure/figure4.png) 110 | 111 | 112 | # Reference 113 | You may check our recent work on [General image super-resolution using VAE](https://github.com/Holmes-Alan/SR-VAE) 114 | 115 | You may also check our work on [Reference based face SR using VAE](https://github.com/Holmes-Alan/RefSR) 116 | 117 | You may also check our work on [Reference based General image SR using VAE](https://github.com/Holmes-Alan/RefVAE) 118 | 119 | Special thanks to Jakub M. Tomczak for the [VAE with a VampPrior](https://github.com/jmtomczak/vae_vampprior) code, which we follow for the KL loss calculation. 120 | -------------------------------------------------------------------------------- /image_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def reduce_image(img, scale): 5 | batch, channels, height, width = img.size() 6 | reduced_img = torch.zeros(batch, channels * scale * scale, height // scale, width // scale).cuda() 7 | 8 | for x in range(scale): 9 | for y in range(scale): 10 | for c in range(channels): 11 | reduced_img[:, c + channels * (y + scale * x), :, :] = img[:, c, x::scale, y::scale] 12 | return reduced_img 13 | 14 | 15 | def reconstruct_image(features, scale): 16 | batch, channels, height, width = features.size() 17 | img_channels = channels // (scale**2) 18 | reconstructed_img = torch.zeros(batch, img_channels, height * scale, width * scale).cuda() 19 | 20 | for x in range(scale): 21 | for y in range(scale): 22 | for c in range(img_channels): 23 | f_channel = c + img_channels * (y + scale * x) 24 | reconstructed_img[:, c, x::scale, y::scale] = features[:, f_channel, :, :] 25 | return reconstructed_img 26 | 27 | 28 | def patchify_tensor(features, patch_size, overlap=10): 29 | batch_size, channels, height, width = features.size() 30 | 31 | effective_patch_size = patch_size - overlap 32 | n_patches_height = (height // effective_patch_size) 33 | n_patches_width = (width // effective_patch_size) 34 | 35 | if n_patches_height * effective_patch_size < height: 36 | n_patches_height += 1 37 | if n_patches_width * effective_patch_size < width: 38 | n_patches_width += 1 39 | 40 | patches = [] 41 | for b in range(batch_size): 42 | for h in range(n_patches_height): 43 | for w in range(n_patches_width): 44 | patch_start_height = min(h * effective_patch_size, height - patch_size) 45 | patch_start_width = min(w * effective_patch_size, width - patch_size) 46 | patches.append(features[b:b+1, :, 47 | patch_start_height: patch_start_height + patch_size, 48 | patch_start_width: patch_start_width + patch_size]) 49 | return torch.cat(patches, 0) 50 | 51 | 52 | def recompose_tensor(patches, full_height, full_width, overlap=10): 53 | 54 | batch_size, channels, patch_size, _ = patches.size() 55 | effective_patch_size = patch_size - overlap 56 | n_patches_height = (full_height // effective_patch_size) 57 | n_patches_width = (full_width // effective_patch_size) 58 | 59 | if n_patches_height * effective_patch_size < full_height: 60 | n_patches_height += 1 61 | if
n_patches_width * effective_patch_size < full_width: 62 | n_patches_width += 1 63 | 64 | n_patches = n_patches_height * n_patches_width 65 | if batch_size % n_patches != 0: 66 | print("Error: The number of patches provided to the recompose function does not match the number of patches in each image.") 67 | final_batch_size = batch_size // n_patches 68 | 69 | blending_in = torch.linspace(0.1, 1.0, overlap) 70 | blending_out = torch.linspace(1.0, 0.1, overlap) 71 | middle_part = torch.ones(patch_size - 2 * overlap) 72 | blending_profile = torch.cat([blending_in, middle_part, blending_out], 0) 73 | 74 | horizontal_blending = blending_profile[None].repeat(patch_size, 1) 75 | vertical_blending = blending_profile[:, None].repeat(1, patch_size) 76 | blending_patch = horizontal_blending * vertical_blending 77 | 78 | blending_image = torch.zeros(1, channels, full_height, full_width) 79 | for h in range(n_patches_height): 80 | for w in range(n_patches_width): 81 | patch_start_height = min(h * effective_patch_size, full_height - patch_size) 82 | patch_start_width = min(w * effective_patch_size, full_width - patch_size) 83 | blending_image[0, :, patch_start_height: patch_start_height + patch_size, patch_start_width: patch_start_width + patch_size] += blending_patch[None] 84 | 85 | recomposed_tensor = torch.zeros(final_batch_size, channels, full_height, full_width) 86 | if patches.is_cuda: 87 | blending_patch = blending_patch.cuda() 88 | blending_image = blending_image.cuda() 89 | recomposed_tensor = recomposed_tensor.cuda() 90 | patch_index = 0 91 | for b in range(final_batch_size): 92 | for h in range(n_patches_height): 93 | for w in range(n_patches_width): 94 | patch_start_height = min(h * effective_patch_size, full_height - patch_size) 95 | patch_start_width = min(w * effective_patch_size, full_width - patch_size) 96 | recomposed_tensor[b, :, patch_start_height: patch_start_height + patch_size, patch_start_width: patch_start_width + patch_size] += patches[patch_index] * blending_patch 97 | patch_index += 1 98 | recomposed_tensor /= blending_image 99 | 100 | return recomposed_tensor -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import torch 3 | import numpy as np 4 | import os 5 | from os import listdir 6 | from os.path import join 7 | from PIL import Image, ImageOps 8 | import random 9 | from random import randrange 10 | 11 | 12 | def is_image_file(filename): 13 | return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"]) 14 | 15 | 16 | def load_img(filepath): 17 | img = Image.open(filepath).convert('RGB') 18 | # y, _, _ = img.split() 19 | return img 20 | 21 | 22 | def modcrop(im, modulo): 23 | (h, w) = im.size 24 | new_h = h//modulo*modulo 25 | new_w = w//modulo*modulo 26 | ih = h - new_h 27 | iw = w - new_w 28 | ims = im.crop((0, 0, h - ih, w - iw)) 29 | return ims 30 | 31 | def rescale_img(img_in, scale): 32 | size_in = img_in.size 33 | new_size_in = tuple([int(x * scale) for x in size_in]) 34 | img_in = img_in.resize(new_size_in, resample=Image.BICUBIC) 35 | return img_in 36 | 37 | 38 | def get_patch(img_in, img_tar, img_ref, patch_size, scale, ix=-1, iy=-1): 39 | (ih, iw) = img_in.size 40 | #(th, tw) = (scale * ih, scale * iw) 41 | 42 | patch_mult = scale # if len(scale) > 1 else 1 43 | tp = patch_mult * patch_size 44 | ip = tp // scale 45 | 46 | if ix == -1: 47 | ix = random.randrange(0, iw - tp + 1) 
48 | if iy == -1: 49 | iy = random.randrange(0, ih - tp + 1) 50 | 51 | (tx, ty) = (scale * ix, scale * iy) 52 | 53 | out_in = img_in.crop((iy, ix, iy + ip, ix + ip)) 54 | out_tar = img_tar.crop((iy, ix, iy + ip, ix + ip)) 55 | out_ref = img_ref.crop((iy, ix, iy + tp, ix + tp)) 56 | #img_bic = img_bic.crop((ty, tx, ty + tp, tx + tp)) 57 | 58 | #info_patch = { 59 | # 'ix': ix, 'iy': iy, 'ip': ip, 'tx': tx, 'ty': ty, 'tp': tp} 60 | 61 | return out_in, out_tar, out_ref 62 | 63 | 64 | def augment(img_in, img_tar, img_ref, flip_h=True, rot=True): 65 | info_aug = {'flip_h': False, 'flip_v': False, 'trans': False} 66 | 67 | if random.random() < 0.5 and flip_h: 68 | img_in = ImageOps.flip(img_in) 69 | img_tar = ImageOps.flip(img_tar) 70 | img_ref = ImageOps.flip(img_ref) 71 | #img_bic = ImageOps.flip(img_bic) 72 | info_aug['flip_h'] = True 73 | 74 | if rot: 75 | if random.random() < 0.5: 76 | img_in = ImageOps.mirror(img_in) 77 | img_tar = ImageOps.mirror(img_tar) 78 | img_ref = ImageOps.mirror(img_ref) 79 | #img_bic = ImageOps.mirror(img_bic) 80 | info_aug['flip_v'] = True 81 | if random.random() < 0.5: 82 | img_in = img_in.rotate(180) 83 | img_tar = img_tar.rotate(180) 84 | img_ref = img_ref.rotate(180) 85 | #img_bic = img_bic.rotate(180) 86 | info_aug['trans'] = True 87 | 88 | return img_in, img_tar, img_ref, info_aug 89 | 90 | 91 | class DatasetFromFolder(data.Dataset): 92 | def __init__(self, HR_dir, LR_dir, patch_size, upscale_factor, data_augmentation, transform=None): 93 | super(DatasetFromFolder, self).__init__() 94 | self.hr_image_filenames = [join(HR_dir, x) for x in listdir(HR_dir) if is_image_file(x)] # uncomment it 95 | self.lr_image_filenames = [join(LR_dir, x) for x in listdir(LR_dir) if is_image_file(x)] 96 | self.patch_size = patch_size 97 | self.upscale_factor = upscale_factor 98 | self.transform = transform 99 | self.data_augmentation = data_augmentation 100 | 101 | def __getitem__(self, index): 102 | 103 | target = load_img(self.hr_image_filenames[index]) 104 | target = modcrop(target, 16) 105 | ref = load_img(self.hr_image_filenames[len(self.hr_image_filenames)-index]) 106 | ref = modcrop(ref, 16) 107 | name = self.hr_image_filenames[index] 108 | input = load_img(self.lr_image_filenames[index]) 109 | 110 | input, target, ref = get_patch(input, target, ref, self.patch_size, self.upscale_factor) 111 | 112 | if self.data_augmentation: 113 | input, target, ref, _ = augment(input, target, ref) 114 | 115 | if self.transform: 116 | input = self.transform(input) 117 | target = self.transform(target) 118 | ref = self.transform(ref) 119 | 120 | return input, target, ref 121 | 122 | def __len__(self): 123 | return len(self.hr_image_filenames) # modify the hr to lr 124 | 125 | 126 | class DatasetFromFolderEval(data.Dataset): 127 | def __init__(self, lr_dir, upscale_factor, transform=None): 128 | super(DatasetFromFolderEval, self).__init__() 129 | self.image_filenames = [join(lr_dir, x) for x in listdir(lr_dir) if is_image_file(x)] 130 | self.upscale_factor = upscale_factor 131 | self.transform = transform 132 | 133 | def __getitem__(self, index): 134 | input = load_img(self.image_filenames[index]) 135 | _, file = os.path.split(self.image_filenames[index]) 136 | 137 | bicubic = rescale_img(input, self.upscale_factor) 138 | 139 | if self.transform: 140 | #input = self.transform(input) 141 | bicubic = self.transform(bicubic) 142 | 143 | return bicubic, file 144 | 145 | def __len__(self): 146 | return len(self.image_filenames) 147 | 
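The following is a minimal usage sketch (not part of the original repo) showing how the dataset classes above are typically consumed through `get_training_set`/`get_eval_set` from data.py; the folder names and parameter values mirror the defaults used by the training and testing scripts later in this repo, and are assumptions otherwise.
```python
# Usage sketch (assumed wiring, not original repo code): data.py datasets in DataLoaders.
from torch.utils.data import DataLoader
from data import get_training_set, get_eval_set

# Training pairs: random LR/HR patches plus a reference HR crop, read from <data_dir>/LR and <data_dir>/HR.
train_set = get_training_set('/data/NTIRE2020', upscale_factor=4, patch_size=128, data_augmentation=True)
train_loader = DataLoader(train_set, batch_size=6, shuffle=True, num_workers=6)

# Evaluation: each image in "Test" is bicubic-upscaled by upscale_factor and returned with its file name.
eval_set = get_eval_set('Test', upscale_factor=4)
eval_loader = DataLoader(eval_set, batch_size=1, shuffle=False)

for bicubic, filename in eval_loader:
    print(filename[0], tuple(bicubic.shape))  # e.g. a (1, 3, 4*H, 4*W) tensor in [0, 1]
```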
-------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | 4 | import os 5 | import torch 6 | from modules import VAE_SR, VAE_denoise_vali 7 | import torch.nn.functional as F 8 | import torchvision.transforms as transforms 9 | from torch.utils.data import DataLoader 10 | import numpy as np 11 | from os.path import join 12 | import time 13 | from collections import OrderedDict 14 | import math 15 | from datasets import is_image_file 16 | from image_utils import * 17 | from PIL import Image, ImageOps 18 | from os import listdir 19 | import torch.utils.data as utils 20 | from torch.autograd import Variable 21 | import os 22 | # Training settings 23 | parser = argparse.ArgumentParser(description='PyTorch Super Res Example') 24 | parser.add_argument('--upscale_factor', type=int, default=4, help="super resolution upscale factor") 25 | parser.add_argument('--testBatchSize', type=int, default=64, help='testing batch size') 26 | parser.add_argument('--gpu_mode', type=bool, default=True) 27 | parser.add_argument('--chop_forward', type=bool, default=True) 28 | parser.add_argument('--patch_size', type=int, default=128, help='0 to use original frame size') 29 | parser.add_argument('--stride', type=int, default=8, help='0 to use original patch size') 30 | parser.add_argument('--threads', type=int, default=6, help='number of threads for data loader to use') 31 | parser.add_argument('--seed', type=int, default=123, help='random seed to use. Default=123') 32 | parser.add_argument('--gpus', default=2, type=int, help='number of gpu') 33 | parser.add_argument('--input', type=str, default='Test', help='Location to input images') 34 | parser.add_argument('--model_type', type=str, default='VAE') 35 | parser.add_argument('--output', default='Result', help='Location to save SR results') 36 | parser.add_argument('--model_denoiser', default='models/VAE_denoiser.pth', help='pretrained denoising model') 37 | parser.add_argument('--model_SR', default='models/VAE_SR.pth', help='pretrained SR model') 38 | 39 | opt = parser.parse_args() 40 | 41 | gpus_list = range(opt.gpus) 42 | print(opt) 43 | 44 | cuda = opt.gpu_mode 45 | if cuda and not torch.cuda.is_available(): 46 | raise Exception("No GPU found, please run without --cuda") 47 | 48 | torch.manual_seed(opt.seed) 49 | if cuda: 50 | torch.cuda.manual_seed(opt.seed) 51 | 52 | print('===> Building model ', opt.model_type) 53 | 54 | 55 | denoiser = VAE_denoise_vali(input_dim=3, dim=32, feat_size=8, z_dim=512, prior='standard') 56 | model = VAE_SR(input_dim=3, dim=64, scale_factor=opt.upscale_factor) 57 | 58 | denoiser = torch.nn.DataParallel(denoiser, device_ids=gpus_list) 59 | model = torch.nn.DataParallel(model, device_ids=gpus_list) 60 | if cuda: 61 | denoiser = denoiser.cuda(gpus_list[0]) 62 | model = model.cuda(gpus_list[0]) 63 | 64 | 65 | print('===> Loading datasets') 66 | 67 | if os.path.exists(opt.model_denoiser): 68 | # denoiser.load_state_dict(torch.load(opt.model_denoiser, map_location=lambda storage, loc: storage)) 69 | pretrained_dict = torch.load(opt.model_denoiser, map_location=lambda storage, loc: storage) 70 | model_dict = denoiser.state_dict() 71 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 72 | model_dict.update(pretrained_dict) 73 | denoiser.load_state_dict(model_dict) 74 | print('Pre-trained Denoiser model is loaded.') 75 | 76 | if 
os.path.exists(opt.model_SR): 77 | model.load_state_dict(torch.load(opt.model_SR, map_location=lambda storage, loc: storage)) 78 | print('Pre-trained SR model is loaded.') 79 | 80 | def eval(): 81 | 82 | denoiser.eval() 83 | model.eval() 84 | 85 | LR_image = [join(opt.input, x) for x in listdir(opt.input) if is_image_file(x)] 86 | SR_image = [join(opt.output, x) for x in listdir(opt.input) if is_image_file(x)] 87 | 88 | 89 | for i in range(LR_image.__len__()): 90 | t0 = time.time() 91 | 92 | LR = Image.open(LR_image[i]).convert('RGB') 93 | with torch.no_grad(): 94 | prediction = chop_forward(LR) 95 | 96 | t1 = time.time() 97 | print("===> Processing: %s || Timer: %.4f sec." % (str(i), (t1 - t0))) 98 | 99 | prediction = prediction * 255.0 100 | prediction = prediction.clamp(0, 255) 101 | 102 | Image.fromarray(np.uint8(prediction)).save(SR_image[i]) 103 | 104 | 105 | 106 | transform = transforms.Compose([ 107 | transforms.ToTensor(), # range [0, 255] -> [0.0,1.0] 108 | ] 109 | ) 110 | 111 | 112 | def chop_forward(img): 113 | 114 | 115 | img = transform(img).unsqueeze(0) 116 | 117 | testset = utils.TensorDataset(img) 118 | test_dataloader = utils.DataLoader(testset, num_workers=opt.threads, 119 | drop_last=False, batch_size=opt.testBatchSize, shuffle=False) 120 | 121 | for iteration, batch in enumerate(test_dataloader, 1): 122 | input = Variable(batch[0]).cuda(gpus_list[0]) 123 | batch_size, channels, img_height, img_width = input.size() 124 | 125 | lowres_patches = patchify_tensor(input, patch_size=opt.patch_size, overlap=opt.stride) 126 | 127 | n_patches = lowres_patches.size(0) 128 | out_box = [] 129 | with torch.no_grad(): 130 | for p in range(n_patches): 131 | LR_input = lowres_patches[p:p + 1] 132 | std_z = torch.from_numpy(np.random.normal(0, 1, (input.shape[0], 512))).float() 133 | z = Variable(std_z, requires_grad=False).cuda(gpus_list[0]) 134 | Denoise_LR = denoiser(LR_input, z) 135 | SR = model(Denoise_LR) 136 | out_box.append(SR) 137 | 138 | out_box = torch.cat(out_box, 0) 139 | SR = recompose_tensor(out_box, opt.upscale_factor * img_height, opt.upscale_factor * img_width, 140 | overlap=opt.upscale_factor * opt.stride) 141 | 142 | 143 | SR = SR.data[0].cpu().permute(1, 2, 0) 144 | 145 | return SR 146 | 147 | 148 | 149 | 150 | ##Eval Start!!!! 
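# eval() scans every image in opt.input ("Test"), splits it into overlapping patches with patchify_tensor,
# denoises each patch with the VAE denoiser (conditioned on a random z ~ N(0, 1) of dimension 512), super-resolves
# it with VAE_SR, blends the patches back together with recompose_tensor, and saves the 4x result to opt.output ("Result").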
151 | eval() 152 | -------------------------------------------------------------------------------- /main_denoiser.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | from math import log10 4 | 5 | import os 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | import torch.backends.cudnn as cudnn 10 | from torch.autograd import Variable 11 | from torch.utils.data import DataLoader 12 | from modules import VAE_denoise, VGGFeatureExtractor 13 | import torch.nn.functional as F 14 | from data import get_training_set 15 | import socket 16 | import time 17 | 18 | 19 | # Training settings 20 | parser = argparse.ArgumentParser(description='PyTorch Super Res Example') 21 | parser.add_argument('--upscale_factor', type=int, default=1, help="super resolution upscale factor") 22 | parser.add_argument('--batchSize', type=int, default=64, help='training batch size') 23 | parser.add_argument('--nEpochs', type=int, default=5000, help='number of epochs to train for') 24 | parser.add_argument('--snapshots', type=int, default=10, help='Snapshots') 25 | parser.add_argument('--start_iter', type=int, default=1, help='Starting Epoch') 26 | parser.add_argument('--lr', type=float, default=1e-4, help='Learning Rate. Default=0.01') 27 | parser.add_argument('--gpu_mode', type=bool, default=True) 28 | parser.add_argument('--threads', type=int, default=6, help='number of threads for data loader to use') 29 | parser.add_argument('--seed', type=int, default=123, help='random seed to use. Default=123') 30 | parser.add_argument('--gpus', default=2, type=int, help='number of gpu') 31 | parser.add_argument('--data_dir', type=str, default='/data/NTIRE2020') 32 | parser.add_argument('--data_augmentation', type=bool, default=True) 33 | parser.add_argument('--model_type', type=str, default='VAE') 34 | parser.add_argument('--patch_size', type=int, default=128, help='Size of cropped LR image') 35 | parser.add_argument('--pretrained_sr', default='VAE_epoch_160.pth', help='sr pretrained base model') 36 | parser.add_argument('--pretrained', type=bool, default=False) 37 | parser.add_argument('--save_folder', default='models/', help='Location to save checkpoint models') 38 | parser.add_argument('--log_folder', default='logs/', help='Location to save checkpoint models') 39 | 40 | opt = parser.parse_args() 41 | gpus_list = range(opt.gpus) 42 | hostname = str(socket.gethostname()) 43 | cudnn.benchmark = True 44 | print(opt) 45 | 46 | 47 | #writer = SummaryWriter('./logs/{0}'.format(opt.log_folder)) 48 | 49 | def train(epoch): 50 | epoch_loss = 0 51 | model.train() 52 | for iteration, batch in enumerate(training_data_loader, 1): 53 | input, target = Variable(batch[0]), Variable(batch[1]) 54 | if cuda: 55 | input = input.cuda(gpus_list[0]) 56 | target = target.cuda(gpus_list[0]) 57 | 58 | optimizer.zero_grad() 59 | t0 = time.time() 60 | 61 | 62 | HR_feat = HR_feat_extractor(target).detach() 63 | Denoise_LR, KL = model(HR_feat, input) 64 | KL_loss = torch.sum(KL) 65 | # Reconstruction loss 66 | SR_loss = L1_criterion(Denoise_LR, target) 67 | 68 | 69 | loss = SR_loss + KL_loss 70 | 71 | t1 = time.time() 72 | epoch_loss += loss.data 73 | loss.backward() 74 | optimizer.step() 75 | 76 | print("===> Epoch[{}]({}/{}): SR_recon: {:.4f} KL_loss: {:.4f} || Timer: {:.4f} sec.".format(epoch, iteration, 77 | len(training_data_loader), SR_loss.data, KL_loss.data, 78 | (t1 - t0))) 79 | 80 | print("===> Epoch {} Complete: Avg. 
Loss: {:.4f}".format(epoch, epoch_loss / len(training_data_loader))) 81 | 82 | 83 | def print_network(net): 84 | num_params = 0 85 | for param in net.parameters(): 86 | num_params += param.numel() 87 | print(net) 88 | print('Total number of parameters: %d' % num_params) 89 | 90 | 91 | def checkpoint(epoch): 92 | model_out_path = opt.save_folder + opt.model_type + "_epoch_{}.pth".format( 93 | epoch) 94 | torch.save(model.state_dict(), model_out_path) 95 | print("Checkpoint saved to {}".format(model_out_path)) 96 | 97 | 98 | cuda = opt.gpu_mode 99 | if cuda and not torch.cuda.is_available(): 100 | raise Exception("No GPU found, please run without --cuda") 101 | 102 | torch.manual_seed(opt.seed) 103 | if cuda: 104 | torch.cuda.manual_seed(opt.seed) 105 | 106 | print('===> Loading datasets') 107 | train_set = get_training_set(opt.data_dir, opt.upscale_factor, opt.patch_size, 108 | opt.data_augmentation) 109 | training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True) 110 | 111 | print('===> Building model ', opt.model_type) 112 | 113 | model = VAE_denoise(input_dim=3, dim=32, feat_size=8, z_dim=512, prior='standard') 114 | 115 | HR_feat_extractor = VGGFeatureExtractor(feature_layer=36, use_bn=False, use_input_norm=True, device='cuda') 116 | 117 | model = torch.nn.DataParallel(model) 118 | HR_feat_extractor = torch.nn.DataParallel(HR_feat_extractor) 119 | 120 | L1_criterion = nn.L1Loss() #sum for VAE 121 | L2_criterion = nn.MSELoss() 122 | 123 | 124 | print('---------- Networks architecture -------------') 125 | print_network(model) 126 | print('----------------------------------------------') 127 | 128 | if opt.pretrained: 129 | model_name = os.path.join(opt.save_folder + opt.pretrained_sr) 130 | if os.path.exists(model_name): 131 | model.load_state_dict(torch.load(model_name, map_location=lambda storage, loc: storage)) 132 | print('Pre-trained SR model is loaded.') 133 | 134 | if cuda: 135 | model = model.cuda(gpus_list[0]) 136 | HR_feat_extractor = HR_feat_extractor.cuda(gpus_list[0]) 137 | L1_criterion = L1_criterion.cuda(gpus_list[0]) 138 | 139 | optimizer = optim.Adam(model.parameters(), lr=opt.lr, betas=(0.9, 0.999), eps=1e-8) 140 | 141 | for epoch in range(opt.start_iter, opt.nEpochs + 1): 142 | train(epoch) 143 | 144 | # learning rate is decayed by a factor of 10 every half of total epochs 145 | if (epoch + 1) % (opt.nEpochs / 2) == 0: 146 | for param_group in optimizer.param_groups: 147 | param_group['lr'] /= 10.0 148 | print('Learning rate decay: lr={}'.format(optimizer.param_groups[0]['lr'])) 149 | 150 | if epoch % (opt.snapshots) == 0: 151 | checkpoint(epoch) 152 | -------------------------------------------------------------------------------- /test_enhance.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | 4 | import os 5 | import torch 6 | from modules import VAE_SR, VAE_denoise_vali, VGGFeatureExtractor 7 | import torch.nn.functional as F 8 | import torchvision.transforms as transforms 9 | from torch.utils.data import DataLoader 10 | import numpy as np 11 | from os.path import join 12 | import time 13 | from collections import OrderedDict 14 | import math 15 | from datasets import is_image_file 16 | from image_utils import * 17 | from PIL import Image, ImageOps 18 | from os import listdir 19 | from prepare_images import * 20 | import torch.utils.data as utils 21 | from torch.autograd import Variable 22 | import os 23 | # 
Training settings 24 | parser = argparse.ArgumentParser(description='PyTorch Super Res Example') 25 | parser.add_argument('--upscale_factor', type=int, default=4, help="super resolution upscale factor") 26 | parser.add_argument('--testBatchSize', type=int, default=64, help='testing batch size') 27 | parser.add_argument('--gpu_mode', type=bool, default=True) 28 | parser.add_argument('--chop_forward', type=bool, default=True) 29 | parser.add_argument('--patch_size', type=int, default=128, help='0 to use original frame size') 30 | parser.add_argument('--stride', type=int, default=8, help='0 to use original patch size') 31 | parser.add_argument('--threads', type=int, default=6, help='number of threads for data loader to use') 32 | parser.add_argument('--seed', type=int, default=123, help='random seed to use. Default=123') 33 | parser.add_argument('--gpus', default=2, type=int, help='number of gpu') 34 | parser.add_argument('--image_dataset', type=str, default='Test') 35 | parser.add_argument('--model_type', type=str, default='VAE') 36 | parser.add_argument('--output', default='Result', help='Location to save checkpoint models') 37 | parser.add_argument('--model_denoiser', default='models/VAE_denoiser.pth', help='sr pretrained base model') 38 | parser.add_argument('--model_SR', default='models/VAE_SR.pth', help='feature sr pretrained base model') 39 | 40 | opt = parser.parse_args() 41 | 42 | gpus_list = range(opt.gpus) 43 | print(opt) 44 | 45 | cuda = opt.gpu_mode 46 | if cuda and not torch.cuda.is_available(): 47 | raise Exception("No GPU found, please run without --cuda") 48 | 49 | torch.manual_seed(opt.seed) 50 | if cuda: 51 | torch.cuda.manual_seed(opt.seed) 52 | 53 | print('===> Building model ', opt.model_type) 54 | 55 | 56 | denoiser = VAE_denoise_vali(input_dim=3, dim=32, feat_size=8, z_dim=512, prior='standard') 57 | model = VAE_SR(input_dim=3, dim=64, scale_factor=opt.upscale_factor) 58 | 59 | denoiser = torch.nn.DataParallel(denoiser, device_ids=gpus_list) 60 | model = torch.nn.DataParallel(model, device_ids=gpus_list) 61 | if cuda: 62 | denoiser = denoiser.cuda(gpus_list[0]) 63 | model = model.cuda(gpus_list[0]) 64 | 65 | 66 | print('===> Loading datasets') 67 | 68 | if os.path.exists(opt.model_denoiser): 69 | # denoiser.load_state_dict(torch.load(opt.model_denoiser, map_location=lambda storage, loc: storage)) 70 | pretrained_dict = torch.load(opt.model_denoiser, map_location=lambda storage, loc: storage) 71 | model_dict = denoiser.state_dict() 72 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 73 | model_dict.update(pretrained_dict) 74 | denoiser.load_state_dict(model_dict) 75 | print('Pre-trained Denoiser model is loaded.') 76 | 77 | if os.path.exists(opt.model_SR): 78 | model.load_state_dict(torch.load(opt.model_SR, map_location=lambda storage, loc: storage)) 79 | print('Pre-trained SR model is loaded.') 80 | 81 | def eval(): 82 | denoiser.eval() 83 | model.eval() 84 | 85 | LR_image = [join(opt.input, x) for x in listdir(opt.input) if is_image_file(x)] 86 | SR_image = [join(opt.output, x) for x in listdir(opt.input) if is_image_file(x)] 87 | avg_psnr_predicted = 0.0 88 | 89 | for i in range(LR_image.__len__()): 90 | t0 = time.time() 91 | 92 | LR = Image.open(LR_image[i]).convert('RGB') 93 | LR_90 = LR.transpose(Image.ROTATE_90) 94 | LR_180 = LR.transpose(Image.ROTATE_180) 95 | LR_270 = LR.transpose(Image.ROTATE_270) 96 | LR_f = LR.transpose(Image.FLIP_LEFT_RIGHT) 97 | LR_90f = LR_90.transpose(Image.FLIP_LEFT_RIGHT) 98 | LR_180f = 
LR_180.transpose(Image.FLIP_LEFT_RIGHT) 99 | LR_270f = LR_270.transpose(Image.FLIP_LEFT_RIGHT) 100 | 101 | with torch.no_grad(): 102 | pred = chop_forward(LR) 103 | pred_90 = chop_forward(LR_90) 104 | pred_180 = chop_forward(LR_180) 105 | pred_270 = chop_forward(LR_270) 106 | pred_f = chop_forward(LR_f) 107 | pred_90f = chop_forward(LR_90f) 108 | pred_180f = chop_forward(LR_180f) 109 | pred_270f = chop_forward(LR_270f) 110 | 111 | pred_90 = np.rot90(pred_90, 3) 112 | pred_180 = np.rot90(pred_180, 2) 113 | pred_270 = np.rot90(pred_270, 1) 114 | pred_f = np.fliplr(pred_f) 115 | pred_90f = np.rot90(np.fliplr(pred_90f), 3) 116 | pred_180f = np.rot90(np.fliplr(pred_180f), 2) 117 | pred_270f = np.rot90(np.fliplr(pred_270f), 1) 118 | prediction = (pred + pred_90 + pred_180 + pred_270 + pred_f + pred_90f + pred_180f + pred_270f) * 255.0 / 8.0 119 | 120 | 121 | t1 = time.time() 122 | print("===> Processing: %s || Timer: %.4f sec." % (str(i), (t1 - t0))) 123 | 124 | prediction = prediction.clip(0, 255) 125 | 126 | Image.fromarray(np.uint8(prediction)).save(SR_image[i]) 127 | 128 | 129 | 130 | def chop_forward(img): 131 | 132 | 133 | img = transform(img).unsqueeze(0) 134 | 135 | testset = utils.TensorDataset(img) 136 | test_dataloader = utils.DataLoader(testset, num_workers=opt.threads, 137 | drop_last=False, batch_size=opt.testBatchSize, shuffle=False) 138 | 139 | for iteration, batch in enumerate(test_dataloader, 1): 140 | input = Variable(batch[0]).cuda(gpus_list[0]) 141 | batch_size, channels, img_height, img_width = input.size() 142 | 143 | lowres_patches = patchify_tensor(input, patch_size=opt.patch_size, overlap=opt.stride) 144 | 145 | n_patches = lowres_patches.size(0) 146 | out_box = [] 147 | with torch.no_grad(): 148 | for p in range(n_patches): 149 | LR_input = lowres_patches[p:p + 1] 150 | std_z = torch.from_numpy(np.random.normal(0, 1, (input.shape[0], 512))).float() 151 | z = Variable(std_z, requires_grad=False).cuda(gpus_list[0]) 152 | Denoise_LR = denoiser(LR_input, z) 153 | SR = model(Denoise_LR) 154 | out_box.append(SR) 155 | 156 | out_box = torch.cat(out_box, 0) 157 | SR = recompose_tensor(out_box, opt.upscale_factor * img_height, opt.upscale_factor * img_width, 158 | overlap=opt.upscale_factor * opt.stride) 159 | 160 | 161 | SR = SR.data[0].cpu().permute(1, 2, 0).numpy() 162 | 163 | return SR 164 | 165 | 166 | 167 | 168 | ##Eval Start!!!! 
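# eval() performs an 8-way self-ensemble: the LR image, its 90/180/270-degree rotations, and the mirrored versions
# of all four are each denoised and super-resolved by chop_forward, mapped back to the original orientation with
# np.rot90/np.fliplr, averaged, clipped to [0, 255], and saved to opt.output.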
169 | eval() 170 | -------------------------------------------------------------------------------- /main_GAN.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | from math import log10 4 | 5 | import os 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | import torch.backends.cudnn as cudnn 10 | from torch.autograd import Variable 11 | from tensorboardX import SummaryWriter 12 | from torch.utils.data import DataLoader 13 | import torch.nn.functional as F 14 | from modules import VAE_denoise_vali, discriminator, VAE_SR, VGGFeatureExtractor 15 | from data import get_training_set 16 | import pdb 17 | import socket 18 | import numpy as np 19 | 20 | 21 | # Training settings 22 | parser = argparse.ArgumentParser(description='PyTorch Super Res Example') 23 | parser.add_argument('--upscale_factor', type=int, default=4, help="super resolution upscale factor") 24 | parser.add_argument('--batchSize', type=int, default=6, help='training batch size') 25 | parser.add_argument('--pretrained_iter', type=int, default=1000, help='number of epochs to train for') 26 | parser.add_argument('--nEpochs', type=int, default=1000, help='number of epochs to train for') 27 | parser.add_argument('--snapshots', type=int, default=5, help='Snapshots') 28 | parser.add_argument('--start_iter', type=int, default=1, help='Starting Epoch') 29 | parser.add_argument('--lr', type=float, default=1e-4, help='Learning Rate. Default=0.01') 30 | parser.add_argument('--gpu_mode', type=bool, default=True) 31 | parser.add_argument('--threads', type=int, default=6, help='number of threads for data loader to use') 32 | parser.add_argument('--seed', type=int, default=123, help='random seed to use. 
Default=123') 33 | parser.add_argument('--gpus', default=2, type=int, help='number of gpu') 34 | parser.add_argument('--data_dir', type=str, default='/data/NTIRE2020') 35 | parser.add_argument('--data_augmentation', type=bool, default=True) 36 | parser.add_argument('--patch_size', type=int, default=128, help='Size of cropped LR image') 37 | parser.add_argument('--pretrained_sr', default='VAE_good_v5.pth', help='sr pretrained base model') 38 | parser.add_argument('--pretrained_D', default='GAN_discriminator_110.pth', help='sr pretrained base model') 39 | parser.add_argument('--model_type', default='GAN', help='model name') 40 | parser.add_argument('--pretrained', type=bool, default=True) 41 | parser.add_argument('--pretrain_flag', default=False, help='pretrain generator') 42 | parser.add_argument('--save_folder', default='models/', help='Location to save checkpoint models') 43 | parser.add_argument('--log_folder', default='logs/', help='Location to save checkpoint models') 44 | 45 | opt = parser.parse_args() 46 | gpus_list = range(opt.gpus) 47 | hostname = str(socket.gethostname()) 48 | cudnn.benchmark = True 49 | print(opt) 50 | 51 | 52 | 53 | def train(epoch): 54 | G_epoch_loss = 0 55 | D_epoch_loss = 0 56 | adv_epoch_loss = 0 57 | vgg_epoch_loss = 0 58 | recon_epoch_loss = 0 59 | G.train() 60 | D.train() 61 | feat_extractor.eval() 62 | denoiser.eval() 63 | for iteration, batch in enumerate(training_data_loader, 1): 64 | input, target, ref = batch[0], batch[1], batch[2] 65 | minibatch = input.size()[0] 66 | real_label = torch.ones(minibatch) 67 | fake_label = torch.zeros(minibatch) 68 | 69 | if cuda: 70 | input = Variable(input).cuda(gpus_list[0]) 71 | target = Variable(target).cuda(gpus_list[0]) 72 | ref = Variable(ref).cuda(gpus_list[0]) 73 | real_label = Variable(real_label).cuda(gpus_list[0]) 74 | fake_label = Variable(fake_label).cuda(gpus_list[0]) 75 | 76 | down = torch.nn.Upsample(scale_factor=0.25, mode='bicubic') 77 | up = torch.nn.Upsample(scale_factor=4, mode='bicubic') 78 | 79 | # Reset gradient 80 | for p in D.parameters(): 81 | p.requires_grad = False 82 | 83 | 84 | G_optimizer.zero_grad() 85 | down_ref = down(ref) 86 | with torch.no_grad(): 87 | std_z = torch.from_numpy(np.random.normal(0, 1, (input.shape[0], 512))).float() 88 | z = Variable(std_z, requires_grad=False).cuda(gpus_list[0]) 89 | Denoise_LR = denoiser(input, z) 90 | 91 | SR = G(Denoise_LR) 92 | SR_tar = G(target) 93 | 94 | SR_feat = feat_extractor(SR).detach() 95 | SR_tar_feat = feat_extractor(SR_tar).detach() 96 | Tar_feat = feat_extractor(target).detach() 97 | 98 | 99 | D_fake_decision_1 = D(SR) 100 | D_fake_decision_2 = D(SR_tar) 101 | D_real_decision = D(ref).detach() 102 | 103 | GAN_loss = (BCE_loss(D_fake_decision_1, real_label) 104 | + BCE_loss(D_fake_decision_2, real_label) 105 | + BCE_loss(D_real_decision, fake_label)) / 3.0 106 | 107 | 108 | recon_loss = (L1_loss(down(SR), target) + L1_loss(SR_tar, SR)) / 2.0 109 | vgg_loss = (L1_loss(down(SR_feat), Tar_feat) + L1_loss(SR_tar_feat, SR_feat)) / 2.0 110 | 111 | G_loss = 1.0 * vgg_loss + 1.0 * recon_loss + 0.0001 * GAN_loss 112 | 113 | G_loss.backward() 114 | G_optimizer.step() 115 | 116 | # Reset gradient 117 | for p in D.parameters(): 118 | p.requires_grad = True 119 | 120 | D_optimizer.zero_grad() 121 | 122 | # Train discriminator with real data 123 | D_real_decision = D(ref) 124 | # Train discriminator with fake data 125 | D_fake_decision_1 = D(SR_tar.detach()) 126 | D_fake_decision_2 = D(SR.detach()) 127 | 128 | 129 | Dis_loss = 
(BCE_loss(D_real_decision, real_label) 130 | + BCE_loss(D_fake_decision_1, fake_label) 131 | + BCE_loss(D_fake_decision_2, fake_label)) / 3.0 132 | 133 | # Back propagation 134 | D_loss = Dis_loss 135 | D_loss.backward() 136 | D_optimizer.step() 137 | 138 | # log 139 | G_epoch_loss += G_loss.data 140 | D_epoch_loss += D_loss.data 141 | adv_epoch_loss += (GAN_loss.data) 142 | recon_epoch_loss += (recon_loss.data) 143 | vgg_epoch_loss += (vgg_loss.data) 144 | 145 | writer.add_scalars('Train_Loss', {'G_loss': G_loss.data, 146 | 'D_loss': D_loss.data, 147 | 'VGG_loss': vgg_loss.data, 148 | 'Adv_loss': GAN_loss.data, 149 | 'Recon_loss': recon_loss.data 150 | }, epoch, iteration) 151 | print( 152 | "===> Epoch[{}]({}/{}): G_loss: {:.4f} || D_loss: {:.4f} || Adv: {:.4f} || Recon_Loss: {:.4f} || VGG_Loss: {:.4f}".format( 153 | epoch, iteration, 154 | len(training_data_loader), G_loss.data, D_loss.data, GAN_loss.data, recon_loss.data, vgg_loss.data)) 155 | print( 156 | "===> Epoch {} Complete: Avg. G_loss: {:.4f} D_loss: {:.4f} Recon_loss: {:.4f} Adv: {:.4f}".format( 157 | epoch, G_epoch_loss / len(training_data_loader), D_epoch_loss / len(training_data_loader), 158 | recon_epoch_loss / len(training_data_loader), 159 | adv_epoch_loss / len(training_data_loader))) 160 | 161 | 162 | def print_network(net): 163 | num_params = 0 164 | for param in net.parameters(): 165 | num_params += param.numel() 166 | print(net) 167 | print('Total number of parameters: %d' % num_params) 168 | 169 | 170 | def checkpoint(epoch, pretrained_flag=False): 171 | if pretrained_flag: 172 | model_out_G = opt.save_folder + opt.model_type + "_pretrain_{}.pth".format(epoch) 173 | torch.save(G.state_dict(), model_out_G) 174 | print("Checkpoint saved to {}".format(model_out_G)) 175 | else: 176 | model_out_G = opt.save_folder + opt.model_type + "_generator_{}.pth".format(epoch) 177 | model_out_D = opt.save_folder + opt.model_type + "_discriminator_{}.pth".format(epoch) 178 | torch.save(G.state_dict(), model_out_G) 179 | torch.save(D.state_dict(), model_out_D) 180 | print("Checkpoint saved to {} and {}".format(model_out_G, model_out_D)) 181 | 182 | 183 | cuda = opt.gpu_mode 184 | if cuda and not torch.cuda.is_available(): 185 | raise Exception("No GPU found, please run without --cuda") 186 | 187 | torch.manual_seed(opt.seed) 188 | if cuda: 189 | torch.cuda.manual_seed(opt.seed) 190 | 191 | print('===> Loading datasets') 192 | train_set = get_training_set(opt.data_dir, opt.upscale_factor, opt.patch_size, 193 | opt.data_augmentation) 194 | training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True) 195 | 196 | print('===> Building model') 197 | 198 | denoiser = VAE_denoise_vali(input_dim=3, dim=32, feat_size=8, z_dim=512, prior='standard') 199 | G = VAE_SR(input_dim=3, dim=64, scale_factor=opt.upscale_factor) 200 | D = discriminator(num_channels=3, base_filter=64, image_size=opt.patch_size * opt.upscale_factor) 201 | feat_extractor = VGGFeatureExtractor(feature_layer=34, use_bn=False, use_input_norm=True, device='cuda') 202 | 203 | 204 | denoiser = torch.nn.DataParallel(denoiser, device_ids=gpus_list) 205 | G = torch.nn.DataParallel(G, device_ids=gpus_list) 206 | D = torch.nn.DataParallel(D, device_ids=gpus_list) 207 | feat_extractor = torch.nn.DataParallel(feat_extractor, device_ids=gpus_list) 208 | 209 | 210 | L1_loss = nn.L1Loss() 211 | BCE_loss = nn.BCEWithLogitsLoss() 212 | 213 | 214 | print('---------- Generator architecture -------------') 215 | print_network(G) 216 | 
print('---------- Discriminator architecture -------------') 217 | print_network(D) 218 | print('----------------------------------------------') 219 | 220 | model_denoiser = os.path.join(opt.save_folder + 'VAE_denoiser.pth') 221 | denoiser.load_state_dict(torch.load(model_denoiser, map_location=lambda storage, loc: storage)) 222 | print('Pre-trained Denoiser model is loaded.') 223 | 224 | if opt.pretrained: 225 | model_G = os.path.join(opt.save_folder + opt.pretrained_sr) 226 | model_D = os.path.join(opt.save_folder + opt.pretrained_D) 227 | if os.path.exists(model_G): 228 | G.load_state_dict(torch.load(model_G, map_location=lambda storage, loc: storage)) 229 | print('Pre-trained Generator model is loaded.') 230 | if os.path.exists(model_D): 231 | D.load_state_dict(torch.load(model_D, map_location=lambda storage, loc: storage)) 232 | print('Pre-trained Discriminator model is loaded.') 233 | 234 | if cuda: 235 | denoiser = denoiser.cuda(gpus_list[0]) 236 | G = G.cuda(gpus_list[0]) 237 | D = D.cuda(gpus_list[0]) 238 | HR_feat_extractor = HR_feat_extractor.cuda(gpus_list[0]) 239 | feat_extractor = feat_extractor.cuda(gpus_list[0]) 240 | L1_loss = L1_loss.cuda(gpus_list[0]) 241 | BCE_loss = BCE_loss.cuda(gpus_list[0]) 242 | Lap_loss = Lap_loss.cuda(gpus_list[0]) 243 | 244 | 245 | G_optimizer = optim.Adam(G.parameters(), lr=opt.lr, weight_decay=0, betas=(0.9, 0.999), eps=1e-8) 246 | D_optimizer = optim.Adam(D.parameters(), lr=opt.lr, weight_decay=0, betas=(0.9, 0.999), eps=1e-8) 247 | 248 | if opt.pretrain_flag: 249 | print('Pre-training starts.') 250 | for epoch in range(1, opt.pretrained_iter + 1): 251 | train_pretrained(epoch) 252 | 253 | if epoch % 100 == 0: 254 | checkpoint(epoch, pretrained_flag=True) 255 | print('Pre-training finished.') 256 | 257 | writer = SummaryWriter(opt.log_folder) 258 | for epoch in range(opt.start_iter, opt.nEpochs + 1): 259 | train(epoch) 260 | 261 | if (epoch + 1) % (opt.nEpochs / 2) == 0: 262 | for param_group in G_optimizer.param_groups: 263 | param_group['lr'] /= 10.0 264 | print('G: Learning rate decay: lr={}'.format(G_optimizer.param_groups[0]['lr'])) 265 | for param_group in D_optimizer.param_groups: 266 | param_group['lr'] /= 10.0 267 | print('D: Learning rate decay: lr={}'.format(D_optimizer.param_groups[0]['lr'])) 268 | 269 | if (epoch + 1) % (opt.snapshots) == 0: 270 | checkpoint(epoch + 1) 271 | -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.distributions.normal import Normal 5 | from torch.distributions import kl_divergence 6 | from distributions import log_Bernoulli, log_Normal_diag, log_Normal_standard, log_Logistic_256 7 | from torch.autograd import Variable 8 | import torchvision 9 | import numpy as np 10 | 11 | 12 | 13 | 14 | 15 | 16 | class ConvBlock(torch.nn.Module): 17 | def __init__(self, input_size, output_size, kernel_size, stride, padding, bias=True): 18 | super(ConvBlock, self).__init__() 19 | 20 | self.conv = torch.nn.Conv2d(input_size, output_size, kernel_size, stride, padding, bias=bias) 21 | 22 | self.act = torch.nn.PReLU() 23 | 24 | def forward(self, x): 25 | out = self.conv(x) 26 | 27 | return self.act(out) 28 | 29 | 30 | class DeconvBlock(torch.nn.Module): 31 | def __init__(self, input_size, output_size, kernel_size, stride, padding, bias=True): 32 | super(DeconvBlock, self).__init__() 33 | 34 | self.deconv = 
torch.nn.ConvTranspose2d(input_size, output_size, kernel_size, stride, padding, bias=bias) 35 | 36 | self.act = torch.nn.PReLU() 37 | 38 | def forward(self, x): 39 | out = self.deconv(x) 40 | 41 | return self.act(out) 42 | 43 | 44 | class UpBlock(torch.nn.Module): 45 | def __init__(self, input_size, output_size, kernel_size, stride, padding): 46 | super(UpBlock, self).__init__() 47 | 48 | self.conv1 = DeconvBlock(input_size, output_size, kernel_size, stride, padding, bias=True) 49 | self.conv2 = ConvBlock(output_size, output_size, kernel_size, stride, padding, bias=True) 50 | self.conv3 = DeconvBlock(output_size, output_size, kernel_size, stride, padding, bias=True) 51 | self.local_weight1 = ConvBlock(input_size, output_size, kernel_size=1, stride=1, padding=0, bias=True) 52 | self.local_weight2 = ConvBlock(output_size, output_size, kernel_size=1, stride=1, padding=0, bias=True) 53 | 54 | def forward(self, x): 55 | hr = self.conv1(x) 56 | lr = self.conv2(hr) 57 | residue = self.local_weight1(x) - lr 58 | h_residue = self.conv3(residue) 59 | hr_weight = self.local_weight2(hr) 60 | return hr_weight + h_residue 61 | 62 | 63 | class DownBlock(torch.nn.Module): 64 | def __init__(self, input_size, output_size, kernel_size, stride, padding): 65 | super(DownBlock, self).__init__() 66 | 67 | self.conv1 = ConvBlock(input_size, output_size, kernel_size, stride, padding, bias=True) 68 | self.conv2 = DeconvBlock(output_size, output_size, kernel_size, stride, padding, bias=True) 69 | self.conv3 = ConvBlock(output_size, output_size, kernel_size, stride, padding, bias=True) 70 | self.local_weight1 = ConvBlock(input_size, output_size, kernel_size=1, stride=1, padding=0, bias=True) 71 | self.local_weight2 = ConvBlock(output_size, output_size, kernel_size=1, stride=1, padding=0, bias=True) 72 | 73 | def forward(self, x): 74 | lr = self.conv1(x) 75 | hr = self.conv2(lr) 76 | residue = self.local_weight1(x) - hr 77 | l_residue = self.conv3(residue) 78 | lr_weight = self.local_weight2(lr) 79 | return lr_weight + l_residue 80 | 81 | class ResnetBlock(torch.nn.Module): 82 | def __init__(self, num_filter, kernel_size=3, stride=1, padding=1, bias=True): 83 | super(ResnetBlock, self).__init__() 84 | self.conv1 = torch.nn.Conv2d(num_filter, num_filter, kernel_size, stride, padding, bias=bias) 85 | self.conv2 = torch.nn.Conv2d(num_filter, num_filter, kernel_size, stride, padding, bias=bias) 86 | 87 | self.act1 = torch.nn.ReLU(inplace=True) 88 | self.act2 = torch.nn.ReLU(inplace=True) 89 | 90 | 91 | def forward(self, x): 92 | 93 | out = self.act1(x) 94 | out = self.conv1(out) 95 | 96 | out = self.act2(out) 97 | out = self.conv2(out) 98 | 99 | out = out + x 100 | 101 | return out 102 | 103 | 104 | 105 | 106 | class VAE_denoise(nn.Module): 107 | def __init__(self, input_dim, dim, feat_size, z_dim, prior): 108 | super(VAE_denoise, self).__init__() 109 | 110 | self.LR_feat = nn.Sequential( 111 | ConvBlock(input_dim, 2*dim, 3, 1, 1), 112 | ConvBlock(2*dim, 2*dim, 3, 1, 1), 113 | ConvBlock(2*dim, dim, 3, 1, 1), 114 | ) 115 | 116 | self.denoise_feat = nn.Sequential( 117 | ConvBlock(2*input_dim, 2*dim, 3, 1, 1), 118 | ConvBlock(2*dim, 2*dim, 3, 1, 1), 119 | ConvBlock(2*dim, dim, 3, 1, 1), 120 | ) 121 | 122 | self.decoder = nn.Sequential( 123 | ConvBlock(4 * dim, 4 * dim, 1, 1, 0), 124 | DeconvBlock(4 * dim, 4 * dim, 6, 4, 1), 125 | DeconvBlock(4 * dim, 2 * dim, 6, 4, 1), 126 | ConvBlock(2 * dim, dim, 3, 1, 1), 127 | ) 128 | 129 | self.SR_recon = nn.Sequential( 130 | ResnetBlock(dim, 3, 1, 1), 131 | ResnetBlock(dim, 3, 1, 1), 
132 | ResnetBlock(dim, 3, 1, 1), 133 | ) 134 | 135 | 136 | self.SR_mu = nn.Sequential( 137 | nn.Conv2d(dim, input_dim, 3, 1, 1), 138 | ) 139 | 140 | self.SR_final = nn.Sequential( 141 | nn.Conv2d(dim, input_dim, 3, 1, 1), 142 | ) 143 | 144 | 145 | self.prior = prior 146 | self.feat_size = feat_size 147 | 148 | self.VAE_encoder = nn.Sequential( 149 | nn.Linear(8192, 4096), 150 | nn.Sigmoid() 151 | ) 152 | 153 | self.q_z_mu = nn.Linear(4096, z_dim) 154 | self.q_z_logvar = nn.Sequential( 155 | nn.Linear(4096, z_dim), 156 | nn.Hardtanh(min_val=-6., max_val=2.), 157 | ) 158 | 159 | 160 | self.VAE_decoder = nn.Sequential( 161 | nn.Linear(z_dim, 4096), 162 | nn.Sigmoid(), 163 | nn.Linear(4096, 8192), 164 | nn.Sigmoid(), 165 | ) 166 | 167 | for m in self.modules(): 168 | class_name = m.__class__.__name__ 169 | if class_name.find('Conv2d') != -1: 170 | torch.nn.init.kaiming_normal_(m.weight) 171 | if m.bias is not None: 172 | m.bias.data.zero_() 173 | elif class_name.find('ConvTranspose2d') != -1: 174 | torch.nn.init.kaiming_normal_(m.weight) 175 | if m.bias is not None: 176 | m.bias.data.zero_() 177 | elif class_name.find('Linear') != -1: 178 | torch.nn.init.kaiming_normal_(m.weight) 179 | if m.bias is not None: 180 | m.bias.data.zero_() 181 | 182 | def log_p_z(self, z, prior): 183 | if prior == 'standard': 184 | log_prior = log_Normal_standard(z, dim=1) 185 | 186 | else: 187 | raise Exception('Wrong name of the prior!') 188 | 189 | return log_prior 190 | 191 | def reparameterize(self, mu, logvar, flag=0): 192 | if flag == 0: 193 | std = logvar.mul(0.5).exp_() 194 | eps = torch.cuda.FloatTensor(std.size()).normal_() 195 | eps = Variable(eps) 196 | z = eps.mul(std).add_(mu) 197 | else: 198 | std = logvar.mul(0.5).exp_() 199 | eps = torch.from_numpy(np.random.normal(0, 0.05, size=(std.size(0), 1, std.size(2), std.size(3)))).float() 200 | eps = Variable(eps).cuda() 201 | eps = eps.repeat(1, 3, 1, 1) 202 | z = eps.mul(std).add_(mu) 203 | 204 | return z 205 | 206 | def encode(self, HR_feat): 207 | 208 | x = self.VAE_encoder(HR_feat.view(HR_feat.size(0), -1)) 209 | z_q_mu = self.q_z_mu(x) 210 | z_q_logvar = self.q_z_logvar(x) 211 | 212 | return z_q_mu, z_q_logvar 213 | 214 | def decode(self, LR, z_q): 215 | up = torch.nn.Upsample(scale_factor=4, mode='bicubic') 216 | LR_feat = self.LR_feat(LR) 217 | dec_feat = self.VAE_decoder(z_q) 218 | dec_feat = dec_feat.view(dec_feat.size(0), -1, self.feat_size, self.feat_size) 219 | 220 | mu_feat = self.decoder(dec_feat) 221 | 222 | com_feat = LR_feat - mu_feat 223 | SR_feat = self.SR_recon(com_feat) 224 | Denoise_LR = LR - self.SR_mu(SR_feat) 225 | 226 | return Denoise_LR 227 | 228 | 229 | def forward(self, HR_feat, LR): 230 | z_q_mu, z_q_logvar = self.encode(HR_feat) 231 | 232 | # reparameterize 233 | z_q = self.reparameterize(z_q_mu, z_q_logvar, flag=0) 234 | # prior 235 | log_p_z = self.log_p_z(z_q, self.prior) 236 | # KL 237 | log_q_z = log_Normal_diag(z_q, z_q_mu, z_q_logvar, dim=1) 238 | KL = -(log_p_z - log_q_z) 239 | KL = torch.sum(KL) 240 | 241 | Denoise_LR = self.decode(LR, z_q) 242 | 243 | 244 | return Denoise_LR, KL 245 | 246 | 247 | 248 | class VAE_denoise_vali(nn.Module): 249 | def __init__(self, input_dim, dim, feat_size, z_dim, prior): 250 | super(VAE_denoise_vali, self).__init__() 251 | 252 | self.LR_feat = nn.Sequential( 253 | ConvBlock(input_dim, 2*dim, 3, 1, 1), 254 | ConvBlock(2*dim, 2*dim, 3, 1, 1), 255 | ConvBlock(2*dim, dim, 3, 1, 1), 256 | ) 257 | 258 | self.decoder = nn.Sequential( 259 | ConvBlock(4 * dim, 4 * dim, 1, 1, 0), 260 | 
DeconvBlock(4 * dim, 4 * dim, 6, 4, 1), 261 | DeconvBlock(4 * dim, 2 * dim, 6, 4, 1), 262 | ConvBlock(2 * dim, dim, 3, 1, 1), 263 | ) 264 | 265 | self.SR_recon = nn.Sequential( 266 | ResnetBlock(dim, 3, 1, 1), 267 | ResnetBlock(dim, 3, 1, 1), 268 | ResnetBlock(dim, 3, 1, 1), 269 | ) 270 | 271 | self.SR_mu = nn.Sequential( 272 | nn.Conv2d(dim, input_dim, 3, 1, 1), 273 | ) 274 | self.prior = prior 275 | self.feat_size = feat_size 276 | 277 | self.VAE_encoder = nn.Sequential( 278 | nn.Linear(8192, 4096), 279 | nn.Sigmoid() 280 | ) 281 | 282 | self.q_z_mu = nn.Linear(4096, z_dim) 283 | self.q_z_logvar = nn.Sequential( 284 | nn.Linear(4096, z_dim), 285 | nn.Hardtanh(min_val=-6., max_val=2.), 286 | ) 287 | 288 | self.VAE_decoder = nn.Sequential( 289 | nn.Linear(z_dim, 4096), 290 | nn.Sigmoid(), 291 | nn.Linear(4096, 8192), 292 | nn.Sigmoid(), 293 | ) 294 | 295 | for m in self.modules(): 296 | class_name = m.__class__.__name__ 297 | if class_name.find('Conv2d') != -1: 298 | torch.nn.init.kaiming_normal_(m.weight) 299 | if m.bias is not None: 300 | m.bias.data.zero_() 301 | elif class_name.find('ConvTranspose2d') != -1: 302 | torch.nn.init.kaiming_normal_(m.weight) 303 | if m.bias is not None: 304 | m.bias.data.zero_() 305 | elif class_name.find('Linear') != -1: 306 | torch.nn.init.kaiming_normal_(m.weight) 307 | if m.bias is not None: 308 | m.bias.data.zero_() 309 | 310 | 311 | def decode(self, LR, z_q): 312 | up = torch.nn.Upsample(scale_factor=4, mode='bicubic') 313 | LR_feat = self.LR_feat(LR) 314 | dec_feat = self.VAE_decoder(z_q) 315 | dec_feat = dec_feat.view(dec_feat.size(0), -1, self.feat_size, self.feat_size) 316 | 317 | mu_feat = self.decoder(dec_feat) 318 | 319 | com_feat = LR_feat - mu_feat 320 | SR_feat = self.SR_recon(com_feat) 321 | Denoise_LR = LR - self.SR_mu(SR_feat) 322 | 323 | return Denoise_LR 324 | 325 | 326 | def forward(self, LR, z_q): 327 | 328 | Denoise_LR = self.decode(LR, z_q) 329 | 330 | return Denoise_LR 331 | 332 | 333 | class VAE_SR(nn.Module): 334 | def __init__(self, input_dim, dim, scale_factor): 335 | super(VAE_SR, self).__init__() 336 | self.up = torch.nn.Upsample(scale_factor=4, mode='bicubic') 337 | self.LR_feat = ConvBlock(input_dim, dim, 3, 1, 1) 338 | self.feat = nn.Sequential( 339 | ResnetBlock(dim, 3, 1, 1, bias=True), 340 | ResnetBlock(dim, 3, 1, 1, bias=True), 341 | ResnetBlock(dim, 3, 1, 1, bias=True), 342 | ResnetBlock(dim, 3, 1, 1, bias=True), 343 | ) 344 | self.recon = nn.Sequential( 345 | nn.Conv2d(dim, input_dim, 3, 1, 1) 346 | ) 347 | 348 | for m in self.modules(): 349 | class_name = m.__class__.__name__ 350 | if class_name.find('Conv2d') != -1: 351 | torch.nn.init.kaiming_normal_(m.weight) 352 | if m.bias is not None: 353 | m.bias.data.zero_() 354 | elif class_name.find('ConvTranspose2d') != -1: 355 | torch.nn.init.kaiming_normal_(m.weight) 356 | if m.bias is not None: 357 | m.bias.data.zero_() 358 | elif class_name.find('Linear') != -1: 359 | torch.nn.init.kaiming_normal_(m.weight) 360 | if m.bias is not None: 361 | m.bias.data.zero_() 362 | 363 | 364 | def forward(self, LR): 365 | LR_feat = self.LR_feat(self.up(LR)) 366 | LR_feat = self.feat(LR_feat) 367 | SR = self.recon(LR_feat) 368 | 369 | return SR 370 | 371 | 372 | 373 | 374 | class discriminator(nn.Module): 375 | def __init__(self, num_channels, base_filter, image_size): 376 | super(discriminator, self).__init__() 377 | self.image_size = image_size 378 | 379 | self.input_conv = ConvBlock(num_channels, base_filter, 3, 1, 1)#512 380 | self.conv_blocks = nn.Sequential( 381 | 
nn.MaxPool2d(4, 4, 0), 382 | nn.BatchNorm2d(base_filter), 383 | ConvBlock(base_filter, base_filter, 3, 1, 1),#128 384 | nn.MaxPool2d(4,4,0), 385 | nn.BatchNorm2d(base_filter), 386 | ConvBlock(base_filter, base_filter * 2, 3, 1, 1),#32 387 | ConvBlock(base_filter * 2, base_filter * 2, 4, 2, 1),#16 388 | nn.BatchNorm2d(base_filter * 2), 389 | ConvBlock(base_filter * 2, base_filter * 4, 3, 1, 1), 390 | ConvBlock(base_filter * 4, base_filter * 4, 4, 2, 1),#8 391 | nn.BatchNorm2d(base_filter * 4), 392 | ConvBlock(base_filter * 4, base_filter * 8, 3, 1, 1), 393 | ConvBlock(base_filter * 8, base_filter * 8, 4, 2, 1),#4 394 | nn.BatchNorm2d(base_filter * 8), 395 | ) 396 | 397 | self.classifier = nn.Sequential( 398 | nn.Linear(512 * 4 * 4, 100), 399 | nn.ReLU(), 400 | # nn.BatchNorm1d(100), 401 | nn.Linear(100, 1), 402 | ) 403 | 404 | for m in self.modules(): 405 | classname = m.__class__.__name__ 406 | if classname.find('Conv2d') != -1: 407 | torch.nn.init.kaiming_normal_(m.weight) 408 | if m.bias is not None: 409 | m.bias.data.zero_() 410 | elif classname.find('ConvTranspose2d') != -1: 411 | torch.nn.init.kaiming_normal_(m.weight) 412 | if m.bias is not None: 413 | m.bias.data.zero_() 414 | 415 | def forward(self, x): 416 | out = self.input_conv(x) 417 | out = self.conv_blocks(out) 418 | out = out.view(out.size()[0], -1) 419 | out = self.classifier(out).view(-1) 420 | return out 421 | 422 | 423 | 424 | 425 | class VGGFeatureExtractor(nn.Module): 426 | def __init__(self, 427 | feature_layer=34, 428 | use_bn=False, 429 | use_input_norm=True, 430 | device=torch.device('cpu')): 431 | super(VGGFeatureExtractor, self).__init__() 432 | if use_bn: 433 | model = torchvision.models.vgg19_bn(pretrained=True) 434 | else: 435 | model = torchvision.models.vgg19(pretrained=True) 436 | 437 | self.use_input_norm = use_input_norm 438 | if self.use_input_norm: 439 | mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) 440 | # [0.485-1, 0.456-1, 0.406-1] if input in range [-1,1] 441 | std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) 442 | # [0.229*2, 0.224*2, 0.225*2] if input in range [-1,1] 443 | self.register_buffer('mean', mean) 444 | self.register_buffer('std', std) 445 | self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)]) 446 | # No need to BP to variable 447 | for param in self.parameters(): 448 | param.requires_grad = False 449 | # self.act = nn.Sigmoid() 450 | 451 | def forward(self, x): 452 | if self.use_input_norm: 453 | x = (x - self.mean) / self.std 454 | output = self.features(x) 455 | 456 | return output 457 | 458 | --------------------------------------------------------------------------------
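
The outer loop in main_GAN.py divides both optimizers' learning rates by 10 whenever (epoch + 1) is a multiple of opt.nEpochs / 2, so with the loop bounds shown above the decay fires twice: once at the midpoint and once again at the last epoch. The rehearsal below only replays that arithmetic; opt.lr, opt.start_iter and opt.nEpochs are placeholder values, not the defaults from the argument parser.

# Placeholder values; the real opt.lr, opt.start_iter and opt.nEpochs come from main_GAN.py's parser.
lr, start_iter, n_epochs = 1e-4, 1, 1000
for epoch in range(start_iter, n_epochs + 1):
    if (epoch + 1) % (n_epochs / 2) == 0:   # float modulo, exactly as in the original loop
        lr /= 10.0
print(lr)  # ~1e-6: the decay triggers at epoch 499 and again at epoch 999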
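
At inference time VAE_denoise_vali removes noise from a low-resolution image given a latent sample, and VAE_SR up-samples the cleaned result by 4x. The smoke test below only checks that the shapes line up; the hyper-parameters (dim=32, feat_size=8, z_dim=512, 128x128 LR patches) are assumptions chosen so that the Linear(8192, ...) layers and the two 4x DeconvBlocks in the decoder agree, not values read from the test scripts.

import torch
from modules import VAE_denoise_vali, VAE_SR

# Assumed sizes: 8192 = (4 * dim) * feat_size * feat_size with dim=32, feat_size=8.
denoiser = VAE_denoise_vali(input_dim=3, dim=32, feat_size=8, z_dim=512, prior='standard').eval()
sr_net = VAE_SR(input_dim=3, dim=32, scale_factor=4).eval()

with torch.no_grad():
    lr_img = torch.rand(1, 3, 128, 128)   # noisy low-resolution patch in [0, 1]
    z = torch.randn(1, 512)               # latent code drawn from the standard-normal prior
    clean = denoiser(lr_img, z)           # VAE_decoder -> decoder -> residual subtraction from the input
    sr = sr_net(clean)                    # bicubic 4x up-sampling followed by ResNet refinement
    print(clean.shape, sr.shape)          # (1, 3, 128, 128) and (1, 3, 512, 512)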
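
VAE_denoise.forward estimates its KL term with a single reparameterized sample, KL = sum(log q(z|x) - log p(z)), built from the helpers in distributions.py. The fragment below reproduces just that computation with arbitrary placeholder shapes (4x512 is not the model's latent size).

import torch
from distributions import log_Normal_diag, log_Normal_standard

mu = torch.zeros(4, 512)              # q(z|x) mean, the role played by q_z_mu
logvar = torch.full((4, 512), -2.0)   # q(z|x) log-variance, the role played by q_z_logvar
z = mu + torch.randn_like(mu) * (0.5 * logvar).exp()   # reparameterization trick

# Same quantity as KL in VAE_denoise.forward: -(log p(z) - log q(z|x)), summed over the batch.
# Both helpers drop the -0.5*log(2*pi) constant, which cancels in the difference.
kl = torch.sum(log_Normal_diag(z, mu, logvar, dim=1) - log_Normal_standard(z, dim=1))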
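
The discriminator ends in a fixed nn.Linear(512 * 4 * 4, 100) classifier, so its convolutional trunk has to reduce the input to a 4x4 map with 512 channels: with base_filter=64 that happens exactly for 512x512 inputs (two 4x max-pools followed by three stride-2 convolutions). The constructor arguments below are assumptions about how main_GAN.py instantiates the class, used only to confirm the shape bookkeeping.

import torch
from modules import discriminator

D = discriminator(num_channels=3, base_filter=64, image_size=512)

with torch.no_grad():
    score = D(torch.rand(1, 3, 512, 512))   # 512 -> /4 -> /4 -> /2 -> /2 -> /2 = 4
    print(score.shape)                      # torch.Size([1]); a raw score, no sigmoid applied inside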
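
VGGFeatureExtractor freezes a torchvision VGG19 up to feature_layer and normalizes its input with the ImageNet statistics, which is the standard ingredient of a perceptual (feature-space) loss. The pairing with nn.L1Loss below is an illustrative choice, not necessarily the loss combination used in main_GAN.py.

import torch
import torch.nn as nn
from modules import VGGFeatureExtractor

device = torch.device('cpu')
# Instantiating this downloads the torchvision VGG19 weights on first use.
vgg = VGGFeatureExtractor(feature_layer=34, use_bn=False,
                          use_input_norm=True, device=device).eval()
l1 = nn.L1Loss()

sr = torch.rand(1, 3, 256, 256, requires_grad=True)   # stand-in for a generator output in [0, 1]
hr = torch.rand(1, 3, 256, 256)                       # stand-in for the ground-truth image
perceptual = l1(vgg(sr), vgg(hr).detach())            # compare deep VGG features rather than pixels
perceptual.backward()                                 # gradients reach sr only; the VGG weights are frozen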