├── LICENSE
├── README.md
├── data
│   ├── generate_train_srresnet.m
│   ├── modcrop.m
│   └── store2hdf5.m
├── dataset.py
├── demo.py
├── eval.py
├── main_srresnet.py
├── model
│   └── model_srresnet.pth
├── result
│   └── result.png
├── srresnet.py
└── testsets
    ├── Set14
    │   ├── baboon.mat
    │   ├── barbara.mat
    │   ├── bridge.mat
    │   ├── coastguard.mat
    │   ├── comic.mat
    │   ├── face.mat
    │   ├── flowers.mat
    │   ├── foreman.mat
    │   ├── lenna.mat
    │   ├── man.mat
    │   ├── monarch.mat
    │   ├── pepper.mat
    │   ├── ppt3.mat
    │   └── zebra.mat
    └── Set5
        ├── baby_GT.mat
        ├── bird_GT.mat
        ├── butterfly_GT.mat
        ├── head_GT.mat
        └── woman_GT.mat


/LICENSE:
--------------------------------------------------------------------------------
The MIT License (MIT)

Copyright (c) 2017- Jiu XU
Copyright (c) 2017- Rakuten, Inc
Copyright (c) 2017- Rakuten Institute of Technology

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PyTorch SRResNet
PyTorch implementation of the paper [Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network](https://arxiv.org/abs/1609.04802).

## Usage
### Training
```
usage: main_srresnet.py [-h] [--batchSize BATCHSIZE] [--nEpochs NEPOCHS]
               [--lr LR] [--step STEP] [--cuda] [--resume RESUME]
               [--start-epoch START_EPOCH] [--threads THREADS]
               [--pretrained PRETRAINED] [--vgg_loss] [--gpus GPUS]

optional arguments:
  -h, --help            show this help message and exit
  --batchSize BATCHSIZE
                        training batch size
  --nEpochs NEPOCHS     number of epochs to train for
  --lr LR               Learning Rate. Default=1e-4
  --step STEP           Sets the learning rate to the initial LR decayed by a
                        factor of 10 every n epochs, Default: n=200
  --cuda                Use cuda?
  --resume RESUME       Path to checkpoint (default: none)
  --start-epoch START_EPOCH
                        Manual epoch number (useful on restarts)
  --threads THREADS     Number of threads for data loader to use, Default: 0
  --pretrained PRETRAINED
                        path to pretrained model (default: none)
  --vgg_loss            Use content loss?
  --gpus GPUS           gpu ids (default: 0)
```
An example of training usage is shown as follows:
```
python main_srresnet.py --cuda --vgg_loss --gpus 0
```

### Demo
```
usage: demo.py [-h] [--cuda] [--model MODEL] [--image IMAGE]
               [--dataset DATASET] [--scale SCALE] [--gpus GPUS]

optional arguments:
  -h, --help            show this help message and exit
  --cuda                use cuda?
  --model MODEL         model path
  --image IMAGE         image name
  --dataset DATASET     dataset name
  --scale SCALE         scale factor, Default: 4
  --gpus GPUS           gpu ids (default: 0)
```
The test images (Set5, Set14) were converted to .mat format with MATLAB to simplify image reading.
An example of usage is shown as follows:
```
python demo.py --model model/model_srresnet.pth --dataset Set5 --image butterfly_GT --scale 4 --cuda
```

### Eval
```
usage: eval.py [-h] [--cuda] [--model MODEL] [--dataset DATASET]
               [--scale SCALE] [--gpus GPUS]

optional arguments:
  -h, --help            show this help message and exit
  --cuda                use cuda?
  --model MODEL         model path
  --dataset DATASET     dataset name, Default: Set5
  --scale SCALE         scale factor, Default: 4
  --gpus GPUS           gpu ids (default: 0)
```
The test images were converted to .mat format with MATLAB. Because PSNR is evaluated on the Y channel only, eval.py calls MATLAB's rgb2ycbcr function from Python to convert the RGB output to YCbCr. You therefore need to install the MATLAB Engine API for Python so that `import matlab.engine` works.
An example of usage is shown as follows:
```
python eval.py --model model/model_srresnet.pth --dataset Set5 --cuda
```

### Prepare Training dataset
- Please refer to [Code for Data Generation](https://github.com/twtygqyy/pytorch-SRResNet/tree/master/data) for creating training files.
- Data augmentation includes flipping, rotation, and downsizing.
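
For readers without MATLAB, the patch-generation logic of `data/generate_train_srresnet.m` can be approximated in NumPy. The sketch below is a hypothetical helper, not part of this repository: it mirrors the script's 3 flip modes × 4 rotations, `modcrop`, and 96×96 label patches taken with stride 48, but it replaces MATLAB's bicubic `imresize` with simple 4×4 block averaging (an assumption made to stay dependency-free) and skips the 0.7/0.5 downsizing pass, so its patches will not match the MATLAB output exactly.

```python
import numpy as np

SCALE = 4         # upscaling factor, as in the MATLAB script
SIZE_LABEL = 96   # high-resolution (label) patch size
STRIDE = 48       # step between neighbouring patches


def modcrop(img, modulo):
    """Crop height and width down to multiples of `modulo` (mirrors data/modcrop.m)."""
    h, w = img.shape[:2]
    return img[: h - h % modulo, : w - w % modulo]


def augmentations(img):
    """Yield the 12 flip/rotation variants enumerated by the MATLAB loops.

    Flip modes: vertical flip, horizontal flip, no flip; each variant is
    then rotated counterclockwise by 0, 90, 180 and 270 degrees.
    """
    for flipped in (np.flipud(img), np.fliplr(img), img):
        for k in range(4):
            yield np.rot90(flipped, k)


def downscale_box(patch, scale):
    """Hypothetical stand-in for bicubic downscaling: average scale x scale blocks."""
    h, w, c = patch.shape
    return patch.reshape(h // scale, scale, w // scale, scale, c).mean(axis=(1, 3))


def extract_pairs(img):
    """Return (inputs, labels) of shape (N, H, W, 3) from one RGB image in [0, 1]."""
    inputs, labels = [], []
    for aug in augmentations(img):
        label_img = modcrop(aug, SCALE)
        h, w = label_img.shape[:2]
        for x in range(0, h - SIZE_LABEL + 1, STRIDE):
            for y in range(0, w - SIZE_LABEL + 1, STRIDE):
                sub_label = label_img[x:x + SIZE_LABEL, y:y + SIZE_LABEL]
                labels.append(sub_label)
                inputs.append(downscale_box(sub_label, SCALE))
    return np.stack(inputs), np.stack(labels)
```

For a 150×200 RGB image, each of the 12 flip/rotation variants yields 6 patches, so `extract_pairs` returns 72 pairs with shapes (72, 24, 24, 3) and (72, 96, 96, 3). These arrays would still need to be transposed to N×C×H×W and written to the `data` and `label` datasets of an HDF5 file before `DatasetFromHdf5` could read them.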


### Performance
- We provide a pretrained model trained on the [291](http://cv.snu.ac.kr/research/VDSR/train_data.zip) images with data augmentation.
- Instance Normalization is used instead of Batch Normalization for better performance.
- So far, PSNR is still below the values reported in the paper; any suggestions are welcome.

PSNR (dB) on the Y channel for ×4 upscaling:

| Dataset | SRResNet Paper | SRResNet PyTorch |
| :-------------:|:--------------:|:---------------:|
| Set5           | 32.05          | **31.80**       |
| Set14          | 28.49          | **28.25**       |
| BSD100         | 27.58          | **27.51**       |

### Result
From left to right: ground truth, bicubic, and SRResNet.

![Ground truth, bicubic, and SRResNet](result/result.png)
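
The PSNR values above follow `eval.py`: only the luminance (Y) channel is compared, `opt.scale` pixels are shaved from each border, and the Y channel of the network output comes from MATLAB's `rgb2ycbcr` via the MATLAB Engine. If the engine is not available, the NumPy sketch below may serve as an approximation. It assumes that `rgb2ycbcr` applies the ITU-R BT.601 studio-range conversion Y = 16 + 65.481 R + 128.553 G + 24.966 B for RGB in [0, 1]; the helper names `rgb_to_y` and `psnr` are hypothetical, and small differences from the MATLAB path should be expected, so its numbers are not directly comparable with the table.

```python
import math

import numpy as np


def rgb_to_y(img_rgb_255):
    """Y channel of an H x W x 3 RGB image in [0, 255], on the [16, 235] scale.

    Assumes the BT.601 studio-range coefficients described above.
    """
    rgb = np.asarray(img_rgb_255, dtype=np.float64) / 255.0
    return 16.0 + rgb @ np.array([65.481, 128.553, 24.966])


def psnr(pred, gt, shave_border=0):
    """Same formula as PSNR() in demo.py and eval.py (peak value 255)."""
    h, w = pred.shape[:2]
    pred = pred[shave_border:h - shave_border, shave_border:w - shave_border]
    gt = gt[shave_border:h - shave_border, shave_border:w - shave_border]
    rmse = math.sqrt(np.mean((pred - gt) ** 2))
    if rmse == 0:
        return 100  # identical images: return a sentinel instead of infinity
    return 20 * math.log10(255.0 / rmse)
```

In `eval.py`, a hypothetical drop-in would replace the MATLAB Engine calls with `im_h_y = rgb_to_y(im_h)`, where `im_h` is the clipped H×W×3 output in [0, 255]; the ground-truth `im_gt_y` would still be read precomputed from the .mat files.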

--------------------------------------------------------------------------------
/data/generate_train_srresnet.m:
--------------------------------------------------------------------------------
clear;
close all;
folder = 'path/to/train/folder';

savepath = 'srresnet_x4.h5';

%% scale factors
scale = 4;

size_label = 96;
size_input = size_label/scale;
stride = 48;

%% downsizing
downsizes = [1,0.7,0.5];

data = zeros(size_input, size_input, 3, 1);
label = zeros(size_label, size_label, 3, 1);

count = 0;
margain = 0;

%% generate data
filepaths = [];
filepaths = [filepaths; dir(fullfile(folder, '*.jpg'))];
filepaths = [filepaths; dir(fullfile(folder, '*.bmp'))];
filepaths = [filepaths; dir(fullfile(folder, '*.png'))];

length(filepaths)

for i = 1 : length(filepaths)
    for flip = 1: 3            % 1: vertical flip, 2: horizontal flip, 3: no flip
        for degree = 1 : 4     % rotate by 0, 90, 180 or 270 degrees
            for downsize = 1 : length(downsizes)
                image = imread(fullfile(folder,filepaths(i).name));
                if flip == 1
                    image = flipdim(image ,1);
                end
                if flip == 2
                    image = flipdim(image ,2);
                end

                image = imrotate(image, 90 * (degree - 1));
                image = imresize(image,downsizes(downsize),'bicubic');

                % only RGB images are used; grayscale images are skipped
                if size(image,3)==3
                    %image = rgb2ycbcr(image);
                    image = im2double(image);
                    im_label = modcrop(image, scale);
                    [hei,wid, c] = size(im_label);

                    filepaths(i).name
                    for x = 1 + margain : stride : hei-size_label+1 - margain
                        for y = 1 + margain :stride : wid-size_label+1 - margain
                            subim_label = im_label(x : x+size_label-1, y : y+size_label-1, :);
                            % the low-resolution input is the bicubic-downscaled label patch
                            subim_input = imresize(subim_label,1/scale,'bicubic');
                            % figure;
                            % imshow(subim_input);
                            % figure;
                            % imshow(subim_label);
                            count=count+1;
                            data(:, :, :, count) = subim_input;
                            label(:, :, :, count) = subim_label;
                        end
                    end
                end
            end
        end
    end
end

order = randperm(count);
data = data(:, :, :, order);
label = label(:, :, :, order);

%% writing to HDF5
chunksz = 64;
created_flag = false;
totalct = 0;

% samples beyond the last full chunk (mod(count, chunksz) of them) are not written
for batchno = 1:floor(count/chunksz)
    batchno
    last_read=(batchno-1)*chunksz;
    batchdata = data(:,:,:,last_read+1:last_read+chunksz);
    batchlabs = label(:,:,:,last_read+1:last_read+chunksz);
    startloc = struct('dat',[1,1,1,totalct+1], 'lab', [1,1,1,totalct+1]);
    curr_dat_sz = store2hdf5(savepath, batchdata, batchlabs, ~created_flag, startloc, chunksz);
    created_flag = true;
    totalct = curr_dat_sz(end);
end

h5disp(savepath);
--------------------------------------------------------------------------------
/data/modcrop.m:
--------------------------------------------------------------------------------
function imgs = modcrop(imgs, modulo)
if size(imgs,3)==1
    sz = size(imgs);
    sz = sz - mod(sz, modulo);
    imgs = imgs(1:sz(1), 1:sz(2));
else
    tmpsz = size(imgs);
    sz = tmpsz(1:2);
    sz = sz - mod(sz, modulo);
    imgs = imgs(1:sz(1), 1:sz(2),:);
end


--------------------------------------------------------------------------------
/data/store2hdf5.m:
--------------------------------------------------------------------------------
function [curr_dat_sz, curr_lab_sz] = store2hdf5(filename, data, labels, create, startloc, chunksz)
  % *data* is a W*H*C*N matrix of images; it should be normalized (e.g. to lie between 0 and 1) beforehand
  % *labels* is a D*N matrix of labels (D labels per sample); here it holds W*H*C*N label patches
  % *create* [0/1] specifies whether to create a new file or to append to a previously created one; appending is useful for storing data in batches when a dataset is too big to be held in memory (default: 1)
  % *startloc* (point at which to start writing data). By default,
  % if create=1 (create mode), startloc.dat=[1 1 1 1], and startloc.lab=[1 1];
  % if create=0 (append mode), startloc.dat=[1 1 1 K+1], and startloc.lab=[1 K+1]; where K is the current number of samples stored in the HDF5 file
  % *chunksz* (used only in create mode) specifies the number of samples stored per chunk (see the HDF5 documentation on chunking) for creating HDF5 files with unbounded maximum size; larger chunks allow faster read/write operations

  % verify that format is right
  dat_dims=size(data);
  lab_dims=size(labels);
  num_samples=dat_dims(end);

  assert(lab_dims(end)==num_samples, 'Number of samples should be matched between data and labels');

  if ~exist('create','var')
    create=true;
  end


  if create
    %fprintf('Creating dataset with %d samples\n', num_samples);
    if ~exist('chunksz', 'var')
      chunksz=1000;
    end
    if exist(filename, 'file')
      fprintf('Warning: replacing existing file %s \n', filename);
      delete(filename);
    end
    h5create(filename, '/data', [dat_dims(1:end-1) Inf], 'Datatype', 'single', 'ChunkSize', [dat_dims(1:end-1) chunksz]); % width, height, channels, number
    h5create(filename, '/label', [lab_dims(1:end-1) Inf], 'Datatype', 'single', 'ChunkSize', [lab_dims(1:end-1) chunksz]); % width, height, channels, number
    if ~exist('startloc','var')
      startloc.dat=[ones(1,length(dat_dims)-1), 1];
      startloc.lab=[ones(1,length(lab_dims)-1), 1];
    end
  else  % append mode
    if ~exist('startloc','var')
      info=h5info(filename);
      prev_dat_sz=info.Datasets(1).Dataspace.Size;
      prev_lab_sz=info.Datasets(2).Dataspace.Size;
      % isequal gives the scalar logical that assert requires
      assert(isequal(prev_dat_sz(1:end-1), dat_dims(1:end-1)), 'Data dimensions must match existing dimensions in dataset');
      assert(isequal(prev_lab_sz(1:end-1), lab_dims(1:end-1)), 'Label dimensions must match existing dimensions in dataset');
      startloc.dat=[ones(1,length(dat_dims)-1), prev_dat_sz(end)+1];
      startloc.lab=[ones(1,length(lab_dims)-1), prev_lab_sz(end)+1];
    end
  end

  if ~isempty(data)
    h5write(filename, '/data', single(data), startloc.dat, size(data));
    h5write(filename, '/label', single(labels), startloc.lab, size(labels));
  end

  if nargout
    info=h5info(filename);
    curr_dat_sz=info.Datasets(1).Dataspace.Size;
    curr_lab_sz=info.Datasets(2).Dataspace.Size;
  end
end

--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import torch.utils.data as data
import torch
import h5py

class DatasetFromHdf5(data.Dataset):
    def __init__(self, file_path):
        super(DatasetFromHdf5, self).__init__()
        # open read-only; h5py sees MATLAB's dimensions in reverse order, i.e. (N, C, size, size)
        hf = h5py.File(file_path, "r")
        self.data = hf.get("data")
        self.target = hf.get("label")

    def __getitem__(self, index):
        return torch.from_numpy(self.data[index,:,:,:]).float(), torch.from_numpy(self.target[index,:,:,:]).float()

    def __len__(self):
        return self.data.shape[0]
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
import argparse, os
import torch
from torch.autograd import Variable
import numpy as np
import time, math
import scipy.io as sio
import matplotlib.pyplot as plt

parser = argparse.ArgumentParser(description="PyTorch SRResNet Demo")
parser.add_argument("--cuda", action="store_true", help="use cuda?")
parser.add_argument("--model", default="model/model_srresnet.pth", type=str, help="model path")
parser.add_argument("--image", default="butterfly_GT", type=str, help="image name")
parser.add_argument("--dataset", default="Set5", type=str, help="dataset name")
parser.add_argument("--scale", default=4, type=int, help="scale factor, Default: 4")
parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)")

def PSNR(pred, gt, shave_border=0):
    height, width = pred.shape[:2]
    pred = pred[shave_border:height - shave_border, shave_border:width - shave_border]
    gt = gt[shave_border:height - shave_border, shave_border:width - shave_border]
    imdff = pred - gt
    rmse = math.sqrt(np.mean(imdff ** 2))
    if rmse == 0:
        return 100
    return 20 * math.log10(255.0 / rmse)

opt = parser.parse_args()
cuda = opt.cuda

if cuda:
    print("=> use gpu id: '{}'".format(opt.gpus))
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
    if not torch.cuda.is_available():
        raise Exception("No GPU found or Wrong gpu id, please run without --cuda")

# map_location loads GPU-saved tensors onto the CPU first, so the demo also runs on CPU-only machines
model = torch.load(opt.model, map_location=lambda storage, loc: storage)["model"]

# read the .mat file once and pick out ground truth, bicubic and low-resolution images
mat = sio.loadmat("testsets/" + opt.dataset + "/" + opt.image + ".mat")
im_gt = mat['im_gt']
im_b = mat['im_b']
im_l = mat['im_l']

im_gt = im_gt.astype(float).astype(np.uint8)
im_b = im_b.astype(float).astype(np.uint8)
im_l = im_l.astype(float).astype(np.uint8)

# HWC uint8 -> 1 x C x H x W float tensor in [0, 1]
im_input = im_l.astype(np.float32).transpose(2,0,1)
im_input = im_input.reshape(1,im_input.shape[0],im_input.shape[1],im_input.shape[2])
im_input = Variable(torch.from_numpy(im_input/255.).float())

if cuda:
    model = model.cuda()
    im_input = im_input.cuda()
else:
    model = model.cpu()

start_time = time.time()
out = model(im_input)
if cuda:
    torch.cuda.synchronize()  # CUDA kernels run asynchronously; wait before stopping the timer
elapsed_time = time.time() - start_time

out = out.cpu()

im_h = out.data[0].numpy().astype(np.float32)

im_h = im_h*255.
im_h[im_h<0] = 0
im_h[im_h>255.] = 255.
im_h = im_h.transpose(1,2,0)

print("Dataset=",opt.dataset)
print("Scale=",opt.scale)
print("It takes {}s for processing".format(elapsed_time))

fig = plt.figure()
ax = plt.subplot(131)
ax.imshow(im_gt)
ax.set_title("GT")

ax = plt.subplot(132)
ax.imshow(im_b)
ax.set_title("Input(Bicubic)")

ax = plt.subplot(133)
ax.imshow(im_h.astype(np.uint8))
ax.set_title("Output(SRResNet)")
plt.show()

--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
import matlab.engine
import argparse, os
import torch
from torch.autograd import Variable
import numpy as np
import time, math, glob
import scipy.io as sio

parser = argparse.ArgumentParser(description="PyTorch SRResNet Eval")
parser.add_argument("--cuda", action="store_true", help="use cuda?")
parser.add_argument("--model", default="model/model_srresnet.pth", type=str, help="model path")
parser.add_argument("--dataset", default="Set5", type=str, help="dataset name, Default: Set5")
parser.add_argument("--scale", default=4, type=int, help="scale factor, Default: 4")
parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)")

def PSNR(pred, gt, shave_border=0):
    height, width = pred.shape[:2]
    pred = pred[shave_border:height - shave_border, shave_border:width - shave_border]
    gt = gt[shave_border:height - shave_border, shave_border:width - shave_border]
    imdff = pred - gt
    rmse = math.sqrt(np.mean(imdff ** 2))
    if rmse == 0:
        return 100
    return 20 * math.log10(255.0 / rmse)

opt = parser.parse_args()
cuda = opt.cuda
eng = matlab.engine.start_matlab()

if cuda:
    print("=> use gpu id: '{}'".format(opt.gpus))
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
    if not torch.cuda.is_available():
        raise Exception("No GPU found or Wrong gpu id, please run without --cuda")

# map_location loads GPU-saved tensors onto the CPU first; the model is then moved once, before the loop
model = torch.load(opt.model, map_location=lambda storage, loc: storage)["model"]
if cuda:
    model = model.cuda()
else:
    model = model.cpu()

image_list = glob.glob("./testsets/" + opt.dataset + "/*.*")

avg_psnr_predicted = 0.0
avg_psnr_bicubic = 0.0
avg_elapsed_time = 0.0

for image_name in image_list:
    print("Processing ", image_name)
    mat = sio.loadmat(image_name)
    im_gt_y = mat['im_gt_y']
    im_b_y = mat['im_b_y']
    im_l = mat['im_l']

    im_gt_y = im_gt_y.astype(float)
    im_b_y = im_b_y.astype(float)
    im_l = im_l.astype(float)

    psnr_bicubic = PSNR(im_gt_y, im_b_y,shave_border=opt.scale)
    avg_psnr_bicubic += psnr_bicubic

    im_input = im_l.astype(np.float32).transpose(2,0,1)
    im_input = im_input.reshape(1,im_input.shape[0],im_input.shape[1],im_input.shape[2])
    im_input = Variable(torch.from_numpy(im_input/255.).float())

    if cuda:
        im_input = im_input.cuda()

    start_time = time.time()
    HR_4x = model(im_input)
    if cuda:
        torch.cuda.synchronize()  # wait for asynchronous CUDA kernels before stopping the timer
    elapsed_time = time.time() - start_time
    avg_elapsed_time += elapsed_time

    HR_4x = HR_4x.cpu()

    im_h = HR_4x.data[0].numpy().astype(np.float32)

    im_h = im_h*255.
    im_h = np.clip(im_h, 0., 255.)
    im_h = im_h.transpose(1,2,0).astype(np.float32)

    # MATLAB's rgb2ycbcr takes RGB in [0, 1]; the result is rescaled to the 0-255 range used by PSNR()
    im_h_matlab = matlab.double((im_h / 255.).tolist())
    im_h_ycbcr = eng.rgb2ycbcr(im_h_matlab)
    im_h_ycbcr = np.array(im_h_ycbcr._data).reshape(im_h_ycbcr.size, order='F').astype(np.float32) * 255.
    im_h_y = im_h_ycbcr[:,:,0]

    psnr_predicted = PSNR(im_gt_y, im_h_y,shave_border=opt.scale)
    avg_psnr_predicted += psnr_predicted

print("Scale=", opt.scale)
print("Dataset=", opt.dataset)
print("PSNR_predicted=", avg_psnr_predicted/len(image_list))
print("PSNR_bicubic=", avg_psnr_bicubic/len(image_list))
print("It takes average {}s for processing".format(avg_elapsed_time/len(image_list)))

--------------------------------------------------------------------------------
/main_srresnet.py:
--------------------------------------------------------------------------------
import argparse, os
import torch
import math, random
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from srresnet import _NetG
from dataset import DatasetFromHdf5
from torchvision import models
import torch.utils.model_zoo as model_zoo

# Training settings
parser = argparse.ArgumentParser(description="PyTorch SRResNet")
parser.add_argument("--batchSize", type=int, default=16, help="training batch size")
parser.add_argument("--nEpochs", type=int, default=500, help="number of epochs to train for")
parser.add_argument("--lr", type=float, default=1e-4, help="Learning Rate. Default=1e-4")
parser.add_argument("--step", type=int, default=200, help="Sets the learning rate to the initial LR decayed by a factor of 10 every n epochs, Default: n=200")
parser.add_argument("--cuda", action="store_true", help="Use cuda?")
parser.add_argument("--resume", default="", type=str, help="Path to checkpoint (default: none)")
parser.add_argument("--start-epoch", default=1, type=int, help="Manual epoch number (useful on restarts)")
parser.add_argument("--threads", type=int, default=0, help="Number of threads for data loader to use, Default: 0")
parser.add_argument("--pretrained", default="", type=str, help="path to pretrained model (default: none)")
parser.add_argument("--vgg_loss", action="store_true", help="Use content loss?")
parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)")

def main():

    global opt, model, netContent
    opt = parser.parse_args()
    print(opt)

    cuda = opt.cuda
    if cuda:
        print("=> use gpu id: '{}'".format(opt.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
        if not torch.cuda.is_available():
            raise Exception("No GPU found or Wrong gpu id, please run without --cuda")

    opt.seed = random.randint(1, 10000)
    print("Random Seed: ", opt.seed)
    torch.manual_seed(opt.seed)
    if cuda:
        torch.cuda.manual_seed(opt.seed)

    cudnn.benchmark = True

    print("===> Loading datasets")
    train_set = DatasetFromHdf5("/path/to/your/hdf5/data/like/rgb_srresnet_x4.h5")
    training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads,
        batch_size=opt.batchSize, shuffle=True)

    if opt.vgg_loss:
        print('===> Loading VGG model')
        netVGG = models.vgg19()
        netVGG.load_state_dict(model_zoo.load_url('https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'))
        class _content_model(nn.Module):
            def __init__(self):
                super(_content_model, self).__init__()
                # VGG19 convolutional features without the final max-pooling layer
                self.feature = nn.Sequential(*list(netVGG.features.children())[:-1])

            def forward(self, x):
                out = self.feature(x)
                return out

        netContent = _content_model()

    print("===> Building model")
    model = _NetG()
    # summed (not averaged) squared error; equivalent to the deprecated size_average=False
    criterion = nn.MSELoss(reduction="sum")

    print("===> Setting GPU")
    if cuda:
        model = model.cuda()
        criterion = criterion.cuda()
        if opt.vgg_loss:
            netContent = netContent.cuda()

    # optionally resume from a checkpoint
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            opt.start_epoch = checkpoint["epoch"] + 1
            model.load_state_dict(checkpoint["model"].state_dict())
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # optionally copy weights from a checkpoint
    if opt.pretrained:
        if os.path.isfile(opt.pretrained):
            print("=> loading model '{}'".format(opt.pretrained))
            weights = torch.load(opt.pretrained)
            model.load_state_dict(weights['model'].state_dict())
        else:
            print("=> no model found at '{}'".format(opt.pretrained))

    print("===> Setting Optimizer")
    optimizer = optim.Adam(model.parameters(), lr=opt.lr)

    print("===> Training")
    for epoch in range(opt.start_epoch, opt.nEpochs + 1):
        train(training_data_loader, optimizer, model, criterion, epoch)
        save_checkpoint(model, epoch)

def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every opt.step epochs"""
    lr = opt.lr * (0.1 ** (epoch // opt.step))
    return lr

def train(training_data_loader, optimizer, model, criterion, epoch):

    lr = adjust_learning_rate(optimizer, epoch-1)

    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

    print("Epoch={}, lr={}".format(epoch, optimizer.param_groups[0]["lr"]))
    model.train()

    for iteration, batch in enumerate(training_data_loader, 1):

        input, target = Variable(batch[0]), Variable(batch[1], requires_grad=False)

        if opt.cuda:
            input = input.cuda()
            target = target.cuda()

        output = model(input)
        loss = criterion(output, target)

        if opt.vgg_loss:
            content_input = netContent(output)
            content_target = netContent(target)
            content_target = content_target.detach()
            content_loss = criterion(content_input, content_target)

        optimizer.zero_grad()

        if opt.vgg_loss:
            netContent.zero_grad()
            # keep the generator graph alive: loss.backward() below traverses it again
            content_loss.backward(retain_graph=True)

        loss.backward()

        optimizer.step()

        if iteration%100 == 0:
            if opt.vgg_loss:
                print("===> Epoch[{}]({}/{}): Loss: {:.5} Content_loss {:.5}".format(epoch, iteration, len(training_data_loader), loss.item(), content_loss.item()))
            else:
                print("===> Epoch[{}]({}/{}): Loss: {:.5}".format(epoch, iteration, len(training_data_loader), loss.item()))

def save_checkpoint(model, epoch):
    model_out_path = "checkpoint/" + "model_epoch_{}.pth".format(epoch)
    state = {"epoch": epoch ,"model": model}
    if not os.path.exists("checkpoint/"):
        os.makedirs("checkpoint/")

    torch.save(state, model_out_path)

    print("Checkpoint saved to {}".format(model_out_path))

if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/model/model_srresnet.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/model/model_srresnet.pth
--------------------------------------------------------------------------------
/result/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/result/result.png
--------------------------------------------------------------------------------
/srresnet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import math

class _Residual_Block(nn.Module):
    def __init__(self):
        super(_Residual_Block, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
        self.in1 = nn.InstanceNorm2d(64, affine=True)
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
        self.in2 = nn.InstanceNorm2d(64, affine=True)

    def forward(self, x):
        identity_data = x
        output = self.relu(self.in1(self.conv1(x)))
        output = self.in2(self.conv2(output))
        output = torch.add(output,identity_data)
        return output

class _NetG(nn.Module):
    def __init__(self):
        super(_NetG, self).__init__()

        self.conv_input = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=9, stride=1, padding=4, bias=False)
        self.relu = nn.LeakyReLU(0.2, inplace=True)

        self.residual = self.make_layer(_Residual_Block, 16)

        self.conv_mid = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_mid = nn.InstanceNorm2d(64, affine=True)

        self.upscale4x = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.2, inplace=True),
        )

43 | self.conv_output = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=9, stride=1, padding=4, bias=False) 44 | 45 | for m in self.modules(): 46 | if isinstance(m, nn.Conv2d): 47 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 48 | m.weight.data.normal_(0, math.sqrt(2. / n)) 49 | if m.bias is not None: 50 | m.bias.data.zero_() 51 | 52 | def make_layer(self, block, num_of_layer): 53 | layers = [] 54 | for _ in range(num_of_layer): 55 | layers.append(block()) 56 | return nn.Sequential(*layers) 57 | 58 | def forward(self, x): 59 | out = self.relu(self.conv_input(x)) 60 | residual = out 61 | out = self.residual(out) 62 | out = self.bn_mid(self.conv_mid(out)) 63 | out = torch.add(out,residual) 64 | out = self.upscale4x(out) 65 | out = self.conv_output(out) 66 | return out 67 | 68 | class _NetD(nn.Module): 69 | def __init__(self): 70 | super(_NetD, self).__init__() 71 | 72 | self.features = nn.Sequential( 73 | 74 | # input is (3) x 96 x 96 75 | nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False), 76 | nn.LeakyReLU(0.2, inplace=True), 77 | 78 | # state size. (64) x 96 x 96 79 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False), 80 | nn.BatchNorm2d(64), 81 | nn.LeakyReLU(0.2, inplace=True), 82 | 83 | # state size. (64) x 96 x 96 84 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False), 85 | nn.BatchNorm2d(128), 86 | nn.LeakyReLU(0.2, inplace=True), 87 | 88 | # state size. (64) x 48 x 48 89 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False), 90 | nn.BatchNorm2d(128), 91 | nn.LeakyReLU(0.2, inplace=True), 92 | 93 | # state size. (128) x 48 x 48 94 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False), 95 | nn.BatchNorm2d(256), 96 | nn.LeakyReLU(0.2, inplace=True), 97 | 98 | # state size. 
(256) x 24 x 24 99 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=4, stride=2, padding=1, bias=False), 100 | nn.BatchNorm2d(256), 101 | nn.LeakyReLU(0.2, inplace=True), 102 | 103 | # state size. (256) x 12 x 12 104 | nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1, bias=False), 105 | nn.BatchNorm2d(512), 106 | nn.LeakyReLU(0.2, inplace=True), 107 | 108 | # state size. (512) x 12 x 12 109 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=4, stride=2, padding=1, bias=False), 110 | nn.BatchNorm2d(512), 111 | nn.LeakyReLU(0.2, inplace=True), 112 | ) 113 | 114 | self.LeakyReLU = nn.LeakyReLU(0.2, inplace=True) 115 | self.fc1 = nn.Linear(512 * 6 * 6, 1024) 116 | self.fc2 = nn.Linear(1024, 1) 117 | self.sigmoid = nn.Sigmoid() 118 | 119 | for m in self.modules(): 120 | if isinstance(m, nn.Conv2d): 121 | m.weight.data.normal_(0.0, 0.02) 122 | elif isinstance(m, nn.BatchNorm2d): 123 | m.weight.data.normal_(1.0, 0.02) 124 | m.bias.data.fill_(0) 125 | 126 | def forward(self, input): 127 | 128 | out = self.features(input) 129 | 130 | # state size. (512) x 6 x 6 131 | out = out.view(out.size(0), -1) 132 | 133 | # state size. (512 x 6 x 6) 134 | out = self.fc1(out) 135 | 136 | # state size. 
(1024) 137 | out = self.LeakyReLU(out) 138 | 139 | out = self.fc2(out) 140 | out = self.sigmoid(out) 141 | return out.view(-1, 1).squeeze(1) -------------------------------------------------------------------------------- /testsets/Set14/baboon.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set14/baboon.mat -------------------------------------------------------------------------------- /testsets/Set14/barbara.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set14/barbara.mat -------------------------------------------------------------------------------- /testsets/Set14/bridge.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set14/bridge.mat -------------------------------------------------------------------------------- /testsets/Set14/coastguard.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set14/coastguard.mat -------------------------------------------------------------------------------- /testsets/Set14/comic.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set14/comic.mat -------------------------------------------------------------------------------- /testsets/Set14/face.mat: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set14/face.mat
--------------------------------------------------------------------------------
/testsets/Set14/flowers.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set14/flowers.mat
--------------------------------------------------------------------------------
/testsets/Set14/foreman.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set14/foreman.mat
--------------------------------------------------------------------------------
/testsets/Set14/lenna.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set14/lenna.mat
--------------------------------------------------------------------------------
/testsets/Set14/man.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set14/man.mat
--------------------------------------------------------------------------------
/testsets/Set14/monarch.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set14/monarch.mat
--------------------------------------------------------------------------------
/testsets/Set14/pepper.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set14/pepper.mat
--------------------------------------------------------------------------------
/testsets/Set14/ppt3.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set14/ppt3.mat
--------------------------------------------------------------------------------
/testsets/Set14/zebra.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set14/zebra.mat
--------------------------------------------------------------------------------
/testsets/Set5/baby_GT.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set5/baby_GT.mat
--------------------------------------------------------------------------------
/testsets/Set5/bird_GT.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set5/bird_GT.mat
--------------------------------------------------------------------------------
/testsets/Set5/butterfly_GT.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set5/butterfly_GT.mat
--------------------------------------------------------------------------------
/testsets/Set5/head_GT.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set5/head_GT.mat
--------------------------------------------------------------------------------
/testsets/Set5/woman_GT.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set5/woman_GT.mat
--------------------------------------------------------------------------------
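The `# state size.` comments in the discriminator at the end of `srresnet.py` can be checked against the output-size formula given in the PyTorch `nn.Conv2d` documentation, `(H_in + 2*padding - dilation*(kernel_size - 1) - 1) // stride + 1`. The following is a minimal pure-Python sketch, not code from this repository: `conv2d_out`, `trace_shapes`, and `TAIL_CONVS` are hypothetical names, square inputs are assumed, and only the three convolutions listed after the commented (256) x 24 x 24 state are traced (BatchNorm2d and LeakyReLU do not change the shape, so they are skipped).

```python
# Hypothetical helper sketch, not code from this repository: re-derive the
# "# state size." comments for the discriminator's last three convolutions.
from typing import List, Tuple


def conv2d_out(size: int, kernel: int, stride: int, padding: int, dilation: int = 1) -> int:
    """Spatial output size of torch.nn.Conv2d along one axis (floor division)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# (in_channels, out_channels, kernel_size, stride, padding), copied from the
# last three nn.Conv2d layers of the discriminator's `features` block.
TAIL_CONVS: List[Tuple[int, int, int, int, int]] = [
    (256, 256, 4, 2, 1),
    (256, 512, 3, 1, 1),
    (512, 512, 4, 2, 1),
]


def trace_shapes(channels: int, size: int,
                 layers: List[Tuple[int, int, int, int, int]]) -> List[Tuple[int, int]]:
    """Return (channels, height/width) before the first layer and after each one.

    Assumes square feature maps; raises ValueError if a layer's in_channels
    does not match the channel count produced by the previous layer.
    """
    states = [(channels, size)]
    for in_ch, out_ch, kernel, stride, padding in layers:
        if in_ch != channels:
            raise ValueError(f"expected {channels} input channels, layer takes {in_ch}")
        channels, size = out_ch, conv2d_out(size, kernel, stride, padding)
        states.append((channels, size))
    return states


if __name__ == "__main__":
    # Start from the commented state (256) x 24 x 24.
    states = trace_shapes(256, 24, TAIL_CONVS)
    for c, hw in states:
        print(f"({c}) x {hw} x {hw}")
    c, hw = states[-1]
    # Flattened feature count; fc1 is declared as nn.Linear(512 * 6 * 6, 1024).
    print("fc1 in_features:", c * hw * hw)
```

Running it prints the states (256) x 24 x 24, (256) x 12 x 12, (512) x 12 x 12 and (512) x 6 x 6, followed by an `fc1` input size of 18432, which agrees with the comments and with `512 * 6 * 6`. With kernel 4, stride 2 and padding 1, any even input size is exactly halved, which is why each strided layer takes 24 to 12 and 12 to 6.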