├── LICENSE
├── README.md
├── data
│   ├── generate_train_srresnet.m
│   ├── modcrop.m
│   └── store2hdf5.m
├── dataset.py
├── demo.py
├── eval.py
├── main_srresnet.py
├── model
│   └── model_srresnet.pth
├── result
│   └── result.png
├── srresnet.py
└── testsets
    ├── Set14
    │   ├── baboon.mat
    │   ├── barbara.mat
    │   ├── bridge.mat
    │   ├── coastguard.mat
    │   ├── comic.mat
    │   ├── face.mat
    │   ├── flowers.mat
    │   ├── foreman.mat
    │   ├── lenna.mat
    │   ├── man.mat
    │   ├── monarch.mat
    │   ├── pepper.mat
    │   ├── ppt3.mat
    │   └── zebra.mat
    └── Set5
        ├── baby_GT.mat
        ├── bird_GT.mat
        ├── butterfly_GT.mat
        ├── head_GT.mat
        └── woman_GT.mat
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2017- Jiu XU
4 | Copyright (c) 2017- Rakuten, Inc
5 | Copyright (c) 2017- Rakuten Institute of Technology
6 |
7 | Permission is hereby granted, free of charge, to any person obtaining a copy
8 | of this software and associated documentation files (the "Software"), to deal
9 | in the Software without restriction, including without limitation the rights
10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | copies of the Software, and to permit persons to whom the Software is
12 | furnished to do so, subject to the following conditions:
13 |
14 | The above copyright notice and this permission notice shall be included in all
15 | copies or substantial portions of the Software.
16 |
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch SRResNet
2 | Implementation of the paper ["Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network"](https://arxiv.org/abs/1609.04802) in PyTorch.
3 |
4 | ## Usage
5 | ### Training
6 | ```
7 | usage: main_srresnet.py [-h] [--batchSize BATCHSIZE] [--nEpochs NEPOCHS]
8 | [--lr LR] [--step STEP] [--cuda] [--resume RESUME]
9 | [--start-epoch START_EPOCH] [--threads THREADS]
10 | [--pretrained PRETRAINED] [--vgg_loss] [--gpus GPUS]
11 |
12 | optional arguments:
13 | -h, --help show this help message and exit
14 | --batchSize BATCHSIZE
15 | training batch size
16 | --nEpochs NEPOCHS number of epochs to train for
17 | --lr LR Learning Rate. Default=1e-4
18 | --step STEP Sets the learning rate to the initial LR decayed by
19 | a factor of 10 every n epochs, Default: n=200
20 | --cuda Use cuda?
21 | --resume RESUME Path to checkpoint (default: none)
22 | --start-epoch START_EPOCH
23 | Manual epoch number (useful on restarts)
24 | --threads THREADS Number of threads for data loader to use, Default: 0
25 | --pretrained PRETRAINED
26 | path to pretrained model (default: none)
27 | --vgg_loss Use content loss?
28 | --gpus GPUS gpu ids (default: 0)
29 | ```
30 | An example of training usage is shown as follows:
31 | ```
32 | python main_srresnet.py --cuda --vgg_loss --gpus 0
33 | ```
34 |
35 | ### Demo
36 | ```
37 | usage: demo.py [-h] [--cuda] [--model MODEL] [--image IMAGE]
38 | [--dataset DATASET] [--scale SCALE] [--gpus GPUS]
39 |
40 | optional arguments:
41 | -h, --help show this help message and exit
42 | --cuda use cuda?
43 | --model MODEL model path
44 | --image IMAGE image name
45 | --dataset DATASET dataset name
46 | --scale SCALE scale factor, Default: 4
47 | --gpus GPUS gpu ids (default: 0)
48 | ```
49 | We converted the Set5 test images to .mat format with MATLAB to simplify image reading.
50 | An example of usage is shown as follows:
51 | ```
52 | python demo.py --model model/model_srresnet.pth --dataset Set5 --image butterfly_GT --scale 4 --cuda
53 | ```
54 |
55 | ### Eval
56 | ```
57 | usage: eval.py [-h] [--cuda] [--model MODEL] [--dataset DATASET]
58 | [--scale SCALE] [--gpus GPUS]
59 |
60 | optional arguments:
61 | -h, --help show this help message and exit
62 | --cuda use cuda?
63 | --model MODEL model path
64 | --dataset DATASET dataset name, Default: Set5
65 | --scale SCALE scale factor, Default: 4
66 | --gpus GPUS gpu ids (default: 0)
67 | ```
68 | We converted the Set5 test images to .mat format with MATLAB. Since PSNR is evaluated on the Y channel only, eval.py imports MATLAB into Python and uses its rgb2ycbcr function to convert the RGB output to YCbCr. You will need to set up the MATLAB Engine API for Python so that the matlab module can be imported.
69 | An example of usage is shown as follows:
70 | ```
71 | python eval.py --model model/model_srresnet.pth --dataset Set5 --cuda
72 | ```
73 |
74 | ### Prepare Training dataset
75 | - Please refer to [Code for Data Generation](https://github.com/twtygqyy/pytorch-SRResNet/tree/master/data) for creating the training files.
76 | - Data augmentation, including flipping, rotation and downsizing, is applied.
77 |
78 |
79 | ### Performance
80 | - We provide a pretrained model trained on [291](http://cv.snu.ac.kr/research/VDSR/train_data.zip) images with data augmentation
81 | - Instance Normalization is applied instead of Batch Normalization for better performance
82 | - So far, PSNR performance is not as good as reported in the paper; any suggestions are welcome.
83 |
84 | | Dataset | SRResNet Paper (PSNR, dB) | SRResNet PyTorch (PSNR, dB) |
85 | | :-------------:|:--------------:|:---------------:|
86 | | Set5 | 32.05 | **31.80** |
87 | | Set14 | 28.49 | **28.25** |
88 | | BSD100 | 27.58 | **27.51** |
89 |
90 | ### Result
91 | From left to right: ground truth, bicubic and SRResNet.
92 |
93 |
94 |
95 |
--------------------------------------------------------------------------------
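For reference, the Y-channel PSNR described in the README can be written out as below. This is a sketch assuming the BT.601 coefficients that MATLAB documents for `rgb2ycbcr` (8-bit range) and the border cropping done by `PSNR(..., shave_border=opt.scale)` in eval.py; the symbols are ours, not the author's. Here \(\Omega_s\) is the pixel grid with \(s\) = `--scale` pixels removed from every border, and eval.py returns 100 when the RMSE is 0.

```latex
\[
Y = 16 + \frac{65.481\,R + 128.553\,G + 24.966\,B}{255}, \qquad R, G, B \in [0, 255]
\]
\[
\mathrm{PSNR} = 20 \log_{10} \frac{255}{\sqrt{\frac{1}{|\Omega_s|} \sum_{(i,j) \in \Omega_s} \bigl(\hat{Y}_{ij} - Y_{ij}\bigr)^2}}
\]
```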
/data/generate_train_srresnet.m:
--------------------------------------------------------------------------------
1 | clear;
2 | close all;
3 | folder = 'path/to/train/folder';
4 |
5 | savepath = 'srresnet_x4.h5';
6 |
7 | %% scale factors
8 | scale = 4;
9 |
10 | size_label = 96;
11 | size_input = size_label/scale;
12 | stride = 48;
13 |
14 | %% downsizing
15 | downsizes = [1,0.7,0.5];
16 |
17 | data = zeros(size_input, size_input, 3, 1);
18 | label = zeros(size_label, size_label, 3, 1);
19 |
20 | count = 0;
21 | margain = 0;
22 |
23 | %% generate data
24 | filepaths = [];
25 | filepaths = [filepaths; dir(fullfile(folder, '*.jpg'))];
26 | filepaths = [filepaths; dir(fullfile(folder, '*.bmp'))];
27 | filepaths = [filepaths; dir(fullfile(folder, '*.png'))];
28 |
29 | length(filepaths)
30 |
31 | for i = 1 : length(filepaths)
32 | for flip = 1: 3
33 | for degree = 1 : 4
34 | for downsize = 1 : length(downsizes)
35 | image = imread(fullfile(folder,filepaths(i).name));
36 | if flip == 1
37 | image = flipdim(image ,1);
38 | end
39 | if flip == 2
40 | image = flipdim(image ,2);
41 | end
42 |
43 | image = imrotate(image, 90 * (degree - 1));
44 | image = imresize(image,downsizes(downsize),'bicubic');
45 |
46 | if size(image,3)==3
47 | %image = rgb2ycbcr(image);
48 | image = im2double(image);
49 | im_label = modcrop(image, scale);
50 | [hei,wid, c] = size(im_label);
51 |
52 | filepaths(i).name
53 | for x = 1 + margain : stride : hei-size_label+1 - margain
54 | for y = 1 + margain :stride : wid-size_label+1 - margain
55 | subim_label = im_label(x : x+size_label-1, y : y+size_label-1, :);
56 | subim_input = imresize(subim_label,1/scale,'bicubic');
57 | % figure;
58 | % imshow(subim_input);
59 | % figure;
60 | % imshow(subim_label);
61 | count=count+1;
62 | data(:, :, :, count) = subim_input;
63 | label(:, :, :, count) = subim_label;
64 | end
65 | end
66 | end
67 | end
68 | end
69 | end
70 | end
71 |
72 | order = randperm(count);
73 | data = data(:, :, :, order);
74 | label = label(:, :, :, order);
75 |
76 | %% writing to HDF5
77 | chunksz = 64;
78 | created_flag = false;
79 | totalct = 0;
80 |
81 | for batchno = 1:floor(count/chunksz)
82 | batchno
83 | last_read=(batchno-1)*chunksz;
84 | batchdata = data(:,:,:,last_read+1:last_read+chunksz);
85 | batchlabs = label(:,:,:,last_read+1:last_read+chunksz);
86 | startloc = struct('dat',[1,1,1,totalct+1], 'lab', [1,1,1,totalct+1]);
87 | curr_dat_sz = store2hdf5(savepath, batchdata, batchlabs, ~created_flag, startloc, chunksz);
88 | created_flag = true;
89 | totalct = curr_dat_sz(end);
90 | end
91 |
92 | h5disp(savepath);
--------------------------------------------------------------------------------
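The nested `flip` / `degree` / `downsize` loops mean each RGB training image is visited in 3 flip modes (rows flipped, columns flipped, unflipped) × 4 rotations × 3 scales = 36 versions before patches are cut. The NumPy sketch below reproduces only the flip/rotation part, assuming `np.rot90` matches the counter-clockwise `imrotate` for multiples of 90 degrees; the bicubic `imresize` downsizing is deliberately left out. `flip_rotate_variants` and `DOWNSIZES` are hypothetical names.

```python
import numpy as np

# Mirrors `downsizes = [1,0.7,0.5]` in the MATLAB script (not applied here).
DOWNSIZES = (1, 0.7, 0.5)


def flip_rotate_variants(image: np.ndarray):
    """Yield the 3 x 4 flip/rotation variants visited by the MATLAB loops.

    flip == 1 flips rows (flipdim(image, 1)), flip == 2 flips columns
    (flipdim(image, 2)), flip == 3 leaves the image unflipped. Each result is
    then rotated by 0, 90, 180 and 270 degrees counter-clockwise.
    """
    for flip in (1, 2, 3):
        if flip == 1:
            base = np.flip(image, axis=0)
        elif flip == 2:
            base = np.flip(image, axis=1)
        else:
            base = image
        for degree in range(1, 5):
            # np.rot90 rotates in the (row, column) plane, like imrotate(.., 90*k)
            yield np.rot90(base, k=degree - 1)


if __name__ == "__main__":
    img = np.arange(2 * 3 * 3).reshape(2, 3, 3)
    variants = list(flip_rotate_variants(img))
    print(len(variants))                   # 12
    print(len(variants) * len(DOWNSIZES))  # 36
```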
/data/modcrop.m:
--------------------------------------------------------------------------------
1 | function imgs = modcrop(imgs, modulo)
2 | if size(imgs,3)==1
3 | sz = size(imgs);
4 | sz = sz - mod(sz, modulo);
5 | imgs = imgs(1:sz(1), 1:sz(2));
6 | else
7 | tmpsz = size(imgs);
8 | sz = tmpsz(1:2);
9 | sz = sz - mod(sz, modulo);
10 | imgs = imgs(1:sz(1), 1:sz(2),:);
11 | end
12 |
13 |
--------------------------------------------------------------------------------
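`modcrop` trims the image so both sides are multiples of `scale`, and the script then slides a `size_label` × `size_label` window with step `stride`. A NumPy sketch of that patch grid follows, assuming `margain = 0` as in the script; in MATLAB each label patch is additionally bicubic-downscaled by `1/scale` (to 24 × 24) to form the input, which this sketch omits. `label_patches` is a hypothetical name.

```python
import numpy as np


def modcrop(img: np.ndarray, modulo: int) -> np.ndarray:
    """Crop height and width down to multiples of `modulo` (as data/modcrop.m does)."""
    h, w = img.shape[:2]
    return img[: h - h % modulo, : w - w % modulo, ...]


def label_patches(img: np.ndarray, size_label: int = 96, stride: int = 48,
                  scale: int = 4) -> list:
    """Cut HR label patches on the same grid as the MATLAB x/y loops.

    MATLAB's `1 : stride : hei-size_label+1` is 1-based and inclusive, which
    corresponds to `range(0, hei - size_label + 1, stride)` in 0-based Python.
    """
    im_label = modcrop(img, scale)
    hei, wid = im_label.shape[:2]
    patches = []
    for x in range(0, hei - size_label + 1, stride):
        for y in range(0, wid - size_label + 1, stride):
            patches.append(im_label[x:x + size_label, y:y + size_label, ...])
    return patches


if __name__ == "__main__":
    image = np.zeros((201, 303, 3))
    print(modcrop(image, 4).shape)    # (200, 300, 3)
    print(len(label_patches(image)))  # 15 (3 rows x 5 columns of patches)
```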
/data/store2hdf5.m:
--------------------------------------------------------------------------------
1 | function [curr_dat_sz, curr_lab_sz] = store2hdf5(filename, data, labels, create, startloc, chunksz)
2 | % *data* is a W*H*C*N matrix of images, which should be normalized (e.g. to lie between 0 and 1) beforehand
3 | % *labels* is a D*N matrix of labels (D labels per sample)
4 | % *create* [0/1] specifies whether to create a new file or to append to a previously created one, useful for storing data in batches when a dataset is too big to be held in memory (default: 1)
5 | % *startloc* (point at which to start writing data). By default,
6 | % if create=1 (create mode), startloc.dat=[1 1 1 1], and startloc.lab=[1 1];
7 | % if create=0 (append mode), startloc.dat=[1 1 1 K+1], and startloc.lab = [1 K+1]; where K is the current number of samples stored in the HDF5 file
8 | % chunksz (used only in create mode), specifies number of samples to be stored per chunk (see HDF5 documentation on chunking) for creating HDF5 files with unbounded maximum size - TLDR; higher chunk sizes allow faster read-write operations
9 |
10 | % verify that format is right
11 | dat_dims=size(data);
12 | lab_dims=size(labels);
13 | num_samples=dat_dims(end);
14 |
15 | assert(lab_dims(end)==num_samples, 'Number of samples should be matched between data and labels');
16 |
17 | if ~exist('create','var')
18 | create=true;
19 | end
20 |
21 |
22 | if create
23 | %fprintf('Creating dataset with %d samples\n', num_samples);
24 | if ~exist('chunksz', 'var')
25 | chunksz=1000;
26 | end
27 | if exist(filename, 'file')
28 | fprintf('Warning: replacing existing file %s \n', filename);
29 | delete(filename);
30 | end
31 | h5create(filename, '/data', [dat_dims(1:end-1) Inf], 'Datatype', 'single', 'ChunkSize', [dat_dims(1:end-1) chunksz]); % width, height, channels, number
32 | h5create(filename, '/label', [lab_dims(1:end-1) Inf], 'Datatype', 'single', 'ChunkSize', [lab_dims(1:end-1) chunksz]); % width, height, channels, number
33 | if ~exist('startloc','var')
34 | startloc.dat=[ones(1,length(dat_dims)-1), 1];
35 | startloc.lab=[ones(1,length(lab_dims)-1), 1];
36 | end
37 | else % append mode
38 | if ~exist('startloc','var')
39 | info=h5info(filename);
40 | prev_dat_sz=info.Datasets(1).Dataspace.Size;
41 | prev_lab_sz=info.Datasets(2).Dataspace.Size;
42 | assert(isequal(prev_dat_sz(1:end-1), dat_dims(1:end-1)), 'Data dimensions must match existing dimensions in dataset');
43 | assert(isequal(prev_lab_sz(1:end-1), lab_dims(1:end-1)), 'Label dimensions must match existing dimensions in dataset');
44 | startloc.dat=[ones(1,length(dat_dims)-1), prev_dat_sz(end)+1];
45 | startloc.lab=[ones(1,length(lab_dims)-1), prev_lab_sz(end)+1];
46 | end
47 | end
48 |
49 | if ~isempty(data)
50 | h5write(filename, '/data', single(data), startloc.dat, size(data));
51 | h5write(filename, '/label', single(labels), startloc.lab, size(labels));
52 | end
53 |
54 | if nargout
55 | info=h5info(filename);
56 | curr_dat_sz=info.Datasets(1).Dataspace.Size;
57 | curr_lab_sz=info.Datasets(2).Dataspace.Size;
58 | end
59 | end
60 |
--------------------------------------------------------------------------------
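As used by generate_train_srresnet.m, `store2hdf5` is called once per batch of `chunksz = 64` samples: the first call creates the file, and later calls append at `startloc = totalct + 1`. Because the loop runs `floor(count/chunksz)` times, up to `chunksz - 1` shuffled patches at the end are never written. The pure-Python sketch below reproduces only that bookkeeping, not the HDF5 I/O, and assumes each append grows the last dimension by exactly `chunksz`. `hdf5_write_plan` is a hypothetical name.

```python
def hdf5_write_plan(count: int, chunksz: int = 64):
    """Mirror the batching loop in generate_train_srresnet.m.

    Returns (plan, dropped): plan is a list of (startloc, n) pairs, where
    startloc is the 1-based sample index passed as startloc.dat(end), and
    dropped is the number of samples left out by the floor() in the loop.
    """
    plan = []
    totalct = 0
    for _ in range(count // chunksz):
        plan.append((totalct + 1, chunksz))
        totalct += chunksz  # store2hdf5 reports the new size of the last dimension
    dropped = count - totalct
    return plan, dropped


if __name__ == "__main__":
    plan, dropped = hdf5_write_plan(150)
    print(plan)     # [(1, 64), (65, 64)]
    print(dropped)  # 22
```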
/dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | import torch
3 | import h5py
4 |
5 | class DatasetFromHdf5(data.Dataset):
6 | def __init__(self, file_path):
7 | super(DatasetFromHdf5, self).__init__()
8 | hf = h5py.File(file_path, "r")
9 | self.data = hf.get("data")
10 | self.target = hf.get("label")
11 |
12 | def __getitem__(self, index):
13 | return torch.from_numpy(self.data[index,:,:,:]).float(), torch.from_numpy(self.target[index,:,:,:]).float()
14 |
15 | def __len__(self):
16 | return self.data.shape[0]
--------------------------------------------------------------------------------
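MATLAB writes `data` as `size_input × size_input × 3 × N` in column-major order, and row-major readers such as h5py report HDF5 dimensions in reverse. That is why `self.data[index,:,:,:]` is already channels-first, as PyTorch convolutions expect; a side effect is that the two spatial axes come out swapped relative to MATLAB. For square patches and a fully convolutional network this is probably harmless. The sketch simulates the layout with NumPy rather than h5py, assuming h5py exposes the bytes unchanged; `matlab_array_as_seen_by_h5py` is a hypothetical name.

```python
import numpy as np


def matlab_array_as_seen_by_h5py(matlab_array: np.ndarray) -> np.ndarray:
    """Simulate reading a MATLAB-written (column-major) array from row-major code.

    MATLAB stores elements with the first index varying fastest; a row-major
    reader sees the same element order with the dimension order reversed.
    """
    flat = matlab_array.ravel(order="F")           # element order on disk
    return flat.reshape(matlab_array.shape[::-1])  # shape as h5py reports it


if __name__ == "__main__":
    h, w, c, n = 4, 5, 3, 2  # MATLAB: size(data) == [h w c n]
    data_matlab = np.random.rand(h, w, c, n)
    data_h5py = matlab_array_as_seen_by_h5py(data_matlab)
    print(data_h5py.shape)  # (2, 3, 5, 4): N, C, W, H
    # One sample, as DatasetFromHdf5 would return it: channels-first, spatially transposed
    sample = data_h5py[0]
    print(np.array_equal(sample, data_matlab[:, :, :, 0].transpose(2, 1, 0)))  # True
```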
/demo.py:
--------------------------------------------------------------------------------
1 | import argparse, os
2 | import torch
3 | from torch.autograd import Variable
4 | import numpy as np
5 | import time, math
6 | import scipy.io as sio
7 | import matplotlib.pyplot as plt
8 |
9 | parser = argparse.ArgumentParser(description="PyTorch SRResNet Demo")
10 | parser.add_argument("--cuda", action="store_true", help="use cuda?")
11 | parser.add_argument("--model", default="model/model_srresnet.pth", type=str, help="model path")
12 | parser.add_argument("--image", default="butterfly_GT", type=str, help="image name")
13 | parser.add_argument("--dataset", default="Set5", type=str, help="dataset name")
14 | parser.add_argument("--scale", default=4, type=int, help="scale factor, Default: 4")
15 | parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)")
16 |
17 | def PSNR(pred, gt, shave_border=0):
18 | height, width = pred.shape[:2]
19 | pred = pred[shave_border:height - shave_border, shave_border:width - shave_border]
20 | gt = gt[shave_border:height - shave_border, shave_border:width - shave_border]
21 | imdff = pred - gt
22 | rmse = math.sqrt(np.mean(imdff ** 2))
23 | if rmse == 0:
24 | return 100
25 | return 20 * math.log10(255.0 / rmse)
26 |
27 | opt = parser.parse_args()
28 | cuda = opt.cuda
29 |
30 | if cuda:
31 | print("=> use gpu id: '{}'".format(opt.gpus))
32 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
33 | if not torch.cuda.is_available():
34 | raise Exception("No GPU found or Wrong gpu id, please run without --cuda")
35 |
36 | model = torch.load(opt.model, map_location=lambda storage, loc: storage)["model"]
37 |
38 | im_gt = sio.loadmat("testsets/" + opt.dataset + "/" + opt.image + ".mat")['im_gt']
39 | im_b = sio.loadmat("testsets/" + opt.dataset + "/" + opt.image + ".mat")['im_b']
40 | im_l = sio.loadmat("testsets/" + opt.dataset + "/" + opt.image + ".mat")['im_l']
41 |
42 | im_gt = im_gt.astype(float).astype(np.uint8)
43 | im_b = im_b.astype(float).astype(np.uint8)
44 | im_l = im_l.astype(float).astype(np.uint8)
45 |
46 | im_input = im_l.astype(np.float32).transpose(2,0,1)
47 | im_input = im_input.reshape(1,im_input.shape[0],im_input.shape[1],im_input.shape[2])
48 | im_input = Variable(torch.from_numpy(im_input/255.).float())
49 |
50 | if cuda:
51 | model = model.cuda()
52 | im_input = im_input.cuda()
53 | else:
54 | model = model.cpu()
55 |
56 | start_time = time.time()
57 | out = model(im_input)
58 | elapsed_time = time.time() - start_time
59 |
60 | out = out.cpu()
61 |
62 | im_h = out.data[0].numpy().astype(np.float32)
63 |
64 | im_h = im_h*255.
65 | im_h[im_h<0] = 0
66 | im_h[im_h>255.] = 255.
67 | im_h = im_h.transpose(1,2,0)
68 |
69 | print("Dataset=",opt.dataset)
70 | print("Scale=",opt.scale)
71 | print("It takes {}s for processing".format(elapsed_time))
72 |
73 | fig = plt.figure()
74 | ax = plt.subplot(131)
75 | ax.imshow(im_gt)
76 | ax.set_title("GT")
77 |
78 | ax = plt.subplot(132)
79 | ax.imshow(im_b)
80 | ax.set_title("Input(Bicubic)")
81 |
82 | ax = plt.subplot(133)
83 | ax.imshow(im_h.astype(np.uint8))
84 | ax.set_title("Output(SRResNet)")
85 | plt.show()
86 |
--------------------------------------------------------------------------------
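demo.py converts the low-resolution `HxWx3` uint8 image into a `1x3xHxW` float tensor in [0, 1], then maps the network output back by scaling by 255, clipping to [0, 255] and transposing to `HxWx3`. The NumPy sketch below isolates that round trip so it can be checked without a model. One deliberate difference: it rounds to the nearest integer, whereas demo.py truncates with `astype(np.uint8)`. The function names are hypothetical.

```python
import numpy as np


def to_network_input(im_l: np.ndarray) -> np.ndarray:
    """HxWx3 uint8 image -> 1x3xHxW float32 array in [0, 1], as demo.py prepares it."""
    x = im_l.astype(np.float32).transpose(2, 0, 1)[np.newaxis]
    return x / np.float32(255.0)


def to_uint8_image(out: np.ndarray) -> np.ndarray:
    """1x3xHxW network output (roughly in [0, 1]) -> HxWx3 uint8 image.

    Clips to [0, 255] like demo.py, but rounds instead of truncating.
    """
    im_h = np.clip(out[0] * 255.0, 0.0, 255.0).transpose(1, 2, 0)
    return np.rint(im_h).astype(np.uint8)


if __name__ == "__main__":
    im = (np.arange(6 * 8 * 3) * 7 % 256).astype(np.uint8).reshape(6, 8, 3)
    x = to_network_input(im)
    print(x.shape, x.dtype)                       # (1, 3, 6, 8) float32
    print(np.array_equal(to_uint8_image(x), im))  # True
```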
/eval.py:
--------------------------------------------------------------------------------
1 | import matlab.engine
2 | import argparse, os
3 | import torch
4 | from torch.autograd import Variable
5 | import numpy as np
6 | import time, math, glob
7 | import scipy.io as sio
8 | import cv2
9 |
10 | parser = argparse.ArgumentParser(description="PyTorch SRResNet Eval")
11 | parser.add_argument("--cuda", action="store_true", help="use cuda?")
12 | parser.add_argument("--model", default="model/model_srresnet.pth", type=str, help="model path")
13 | parser.add_argument("--dataset", default="Set5", type=str, help="dataset name, Default: Set5")
14 | parser.add_argument("--scale", default=4, type=int, help="scale factor, Default: 4")
15 | parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)")
16 |
17 | def PSNR(pred, gt, shave_border=0):
18 | height, width = pred.shape[:2]
19 | pred = pred[shave_border:height - shave_border, shave_border:width - shave_border]
20 | gt = gt[shave_border:height - shave_border, shave_border:width - shave_border]
21 | imdff = pred - gt
22 | rmse = math.sqrt(np.mean(imdff ** 2))
23 | if rmse == 0:
24 | return 100
25 | return 20 * math.log10(255.0 / rmse)
26 |
27 | opt = parser.parse_args()
28 | cuda = opt.cuda
29 | eng = matlab.engine.start_matlab()
30 |
31 | if cuda:
32 | print("=> use gpu id: '{}'".format(opt.gpus))
33 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
34 | if not torch.cuda.is_available():
35 | raise Exception("No GPU found or Wrong gpu id, please run without --cuda")
36 |
37 | model = torch.load(opt.model, map_location=lambda storage, loc: storage)["model"]
38 |
39 | image_list = glob.glob("./testsets/" + opt.dataset + "/*.*")
40 |
41 | avg_psnr_predicted = 0.0
42 | avg_psnr_bicubic = 0.0
43 | avg_elapsed_time = 0.0
44 |
45 | for image_name in image_list:
46 | print("Processing ", image_name)
47 | im_gt_y = sio.loadmat(image_name)['im_gt_y']
48 | im_b_y = sio.loadmat(image_name)['im_b_y']
49 | im_l = sio.loadmat(image_name)['im_l']
50 |
51 | im_gt_y = im_gt_y.astype(float)
52 | im_b_y = im_b_y.astype(float)
53 | im_l = im_l.astype(float)
54 |
55 | psnr_bicubic = PSNR(im_gt_y, im_b_y,shave_border=opt.scale)
56 | avg_psnr_bicubic += psnr_bicubic
57 |
58 | im_input = im_l.astype(np.float32).transpose(2,0,1)
59 | im_input = im_input.reshape(1,im_input.shape[0],im_input.shape[1],im_input.shape[2])
60 | im_input = Variable(torch.from_numpy(im_input/255.).float())
61 |
62 | if cuda:
63 | model = model.cuda()
64 | im_input = im_input.cuda()
65 | else:
66 | model = model.cpu()
67 |
68 | start_time = time.time()
69 | HR_4x = model(im_input)
70 | elapsed_time = time.time() - start_time
71 | avg_elapsed_time += elapsed_time
72 |
73 | HR_4x = HR_4x.cpu()
74 |
75 | im_h = HR_4x.data[0].numpy().astype(np.float32)
76 |
77 | im_h = im_h*255.
78 | im_h = np.clip(im_h, 0., 255.)
79 | im_h = im_h.transpose(1,2,0).astype(np.float32)
80 |
81 | im_h_matlab = matlab.double((im_h / 255.).tolist())
82 | im_h_ycbcr = eng.rgb2ycbcr(im_h_matlab)
83 | im_h_ycbcr = np.array(im_h_ycbcr._data).reshape(im_h_ycbcr.size, order='F').astype(np.float32) * 255.
84 | im_h_y = im_h_ycbcr[:,:,0]
85 |
86 | psnr_predicted = PSNR(im_gt_y, im_h_y,shave_border=opt.scale)
87 | avg_psnr_predicted += psnr_predicted
88 |
89 | print("Scale=", opt.scale)
90 | print("Dataset=", opt.dataset)
91 | print("PSNR_predicted=", avg_psnr_predicted/len(image_list))
92 | print("PSNR_bicubic=", avg_psnr_bicubic/len(image_list))
93 | print("It takes average {}s for processing".format(avg_elapsed_time/len(image_list)))
94 |
--------------------------------------------------------------------------------
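eval.py starts a MATLAB engine only to call `rgb2ycbcr` and keep the Y channel. A NumPy-only alternative is sketched below, assuming the BT.601 coefficients MATLAB documents for `rgb2ycbcr`; because the ground-truth `im_gt_y` in the .mat files was produced in MATLAB, small PSNR differences from rounding or precision are possible, and this has not been checked against the reported numbers. `rgb_to_y` and `psnr` are hypothetical names; `psnr` follows the same steps as `PSNR()` in eval.py. In eval.py it could be used as `psnr(im_gt_y, rgb_to_y(im_h), shave_border=opt.scale)`, where `im_h` is the clipped `HxWx3` output in [0, 255].

```python
import math

import numpy as np


def rgb_to_y(rgb: np.ndarray) -> np.ndarray:
    """Luma (Y) of an HxWx3 RGB image in [0, 255] using BT.601 coefficients.

    Output lies in [16, 235], matching 255 * Y from MATLAB's rgb2ycbcr on
    double input in [0, 1].
    """
    rgb = rgb.astype(np.float64)
    return 16.0 + (65.481 * rgb[..., 0] + 128.553 * rgb[..., 1]
                   + 24.966 * rgb[..., 2]) / 255.0


def psnr(pred: np.ndarray, gt: np.ndarray, shave_border: int = 0) -> float:
    """Crop `shave_border` pixels from each side, then return 20*log10(255/RMSE)."""
    height, width = pred.shape[:2]
    pred = pred[shave_border:height - shave_border, shave_border:width - shave_border]
    gt = gt[shave_border:height - shave_border, shave_border:width - shave_border]
    diff = pred.astype(np.float64) - gt.astype(np.float64)
    rmse = math.sqrt(np.mean(diff ** 2))
    if rmse == 0:
        return 100.0  # same sentinel as eval.py for identical images
    return 20 * math.log10(255.0 / rmse)
```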
/main_srresnet.py:
--------------------------------------------------------------------------------
1 | import argparse, os
2 | import torch
3 | import math, random
4 | import torch.backends.cudnn as cudnn
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | from torch.autograd import Variable
8 | from torch.utils.data import DataLoader
9 | from srresnet import _NetG
10 | from dataset import DatasetFromHdf5
11 | from torchvision import models
12 | import torch.utils.model_zoo as model_zoo
13 |
14 | # Training settings
15 | parser = argparse.ArgumentParser(description="PyTorch SRResNet")
16 | parser.add_argument("--batchSize", type=int, default=16, help="training batch size")
17 | parser.add_argument("--nEpochs", type=int, default=500, help="number of epochs to train for")
18 | parser.add_argument("--lr", type=float, default=1e-4, help="Learning Rate. Default=1e-4")
19 | parser.add_argument("--step", type=int, default=200, help="Sets the learning rate to the initial LR decayed by a factor of 10 every n epochs, Default: n=200")
20 | parser.add_argument("--cuda", action="store_true", help="Use cuda?")
21 | parser.add_argument("--resume", default="", type=str, help="Path to checkpoint (default: none)")
22 | parser.add_argument("--start-epoch", default=1, type=int, help="Manual epoch number (useful on restarts)")
23 | parser.add_argument("--threads", type=int, default=0, help="Number of threads for data loader to use, Default: 0")
24 | parser.add_argument("--pretrained", default="", type=str, help="path to pretrained model (default: none)")
25 | parser.add_argument("--vgg_loss", action="store_true", help="Use content loss?")
26 | parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)")
27 |
28 | def main():
29 |
30 | global opt, model, netContent
31 | opt = parser.parse_args()
32 | print(opt)
33 |
34 | cuda = opt.cuda
35 | if cuda:
36 | print("=> use gpu id: '{}'".format(opt.gpus))
37 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
38 | if not torch.cuda.is_available():
39 | raise Exception("No GPU found or Wrong gpu id, please run without --cuda")
40 |
41 | opt.seed = random.randint(1, 10000)
42 | print("Random Seed: ", opt.seed)
43 | torch.manual_seed(opt.seed)
44 | if cuda:
45 | torch.cuda.manual_seed(opt.seed)
46 |
47 | cudnn.benchmark = True
48 |
49 | print("===> Loading datasets")
50 | train_set = DatasetFromHdf5("/path/to/your/hdf5/data/like/rgb_srresnet_x4.h5")
51 | training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, \
52 | batch_size=opt.batchSize, shuffle=True)
53 |
54 | if opt.vgg_loss:
55 | print('===> Loading VGG model')
56 | netVGG = models.vgg19()
57 | netVGG.load_state_dict(model_zoo.load_url('https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'))
58 | class _content_model(nn.Module):
59 | def __init__(self):
60 | super(_content_model, self).__init__()
61 | self.feature = nn.Sequential(*list(netVGG.features.children())[:-1])
62 |
63 | def forward(self, x):
64 | out = self.feature(x)
65 | return out
66 |
67 | netContent = _content_model()
68 |
69 | print("===> Building model")
70 | model = _NetG()
71 | criterion = nn.MSELoss(reduction="sum")
72 |
73 | print("===> Setting GPU")
74 | if cuda:
75 | model = model.cuda()
76 | criterion = criterion.cuda()
77 | if opt.vgg_loss:
78 | netContent = netContent.cuda()
79 |
80 | # optionally resume from a checkpoint
81 | if opt.resume:
82 | if os.path.isfile(opt.resume):
83 | print("=> loading checkpoint '{}'".format(opt.resume))
84 | checkpoint = torch.load(opt.resume)
85 | opt.start_epoch = checkpoint["epoch"] + 1
86 | model.load_state_dict(checkpoint["model"].state_dict())
87 | else:
88 | print("=> no checkpoint found at '{}'".format(opt.resume))
89 |
90 | # optionally copy weights from a checkpoint
91 | if opt.pretrained:
92 | if os.path.isfile(opt.pretrained):
93 | print("=> loading model '{}'".format(opt.pretrained))
94 | weights = torch.load(opt.pretrained)
95 | model.load_state_dict(weights['model'].state_dict())
96 | else:
97 | print("=> no model found at '{}'".format(opt.pretrained))
98 |
99 | print("===> Setting Optimizer")
100 | optimizer = optim.Adam(model.parameters(), lr=opt.lr)
101 |
102 | print("===> Training")
103 | for epoch in range(opt.start_epoch, opt.nEpochs + 1):
104 | train(training_data_loader, optimizer, model, criterion, epoch)
105 | save_checkpoint(model, epoch)
106 |
107 | def adjust_learning_rate(optimizer, epoch):
108 | """Sets the learning rate to the initial LR decayed by a factor of 10 every opt.step epochs"""
109 | lr = opt.lr * (0.1 ** (epoch // opt.step))
110 | return lr
111 |
112 | def train(training_data_loader, optimizer, model, criterion, epoch):
113 |
114 | lr = adjust_learning_rate(optimizer, epoch-1)
115 |
116 | for param_group in optimizer.param_groups:
117 | param_group["lr"] = lr
118 |
119 | print("Epoch={}, lr={}".format(epoch, optimizer.param_groups[0]["lr"]))
120 | model.train()
121 |
122 | for iteration, batch in enumerate(training_data_loader, 1):
123 |
124 | input, target = Variable(batch[0]), Variable(batch[1], requires_grad=False)
125 |
126 | if opt.cuda:
127 | input = input.cuda()
128 | target = target.cuda()
129 |
130 | output = model(input)
131 | loss = criterion(output, target)
132 |
133 | if opt.vgg_loss:
134 | content_input = netContent(output)
135 | content_target = netContent(target)
136 | content_target = content_target.detach()
137 | content_loss = criterion(content_input, content_target)
138 |
139 | optimizer.zero_grad()
140 |
141 | if opt.vgg_loss:
142 | netContent.zero_grad()
143 | content_loss.backward(retain_graph=True)
144 |
145 | loss.backward()
146 |
147 | optimizer.step()
148 |
149 | if iteration%100 == 0:
150 | if opt.vgg_loss:
151 | print("===> Epoch[{}]({}/{}): Loss: {:.5} Content_loss {:.5}".format(epoch, iteration, len(training_data_loader), loss.item(), content_loss.item()))
152 | else:
153 | print("===> Epoch[{}]({}/{}): Loss: {:.5}".format(epoch, iteration, len(training_data_loader), loss.item()))
154 |
155 | def save_checkpoint(model, epoch):
156 | model_out_path = "checkpoint/" + "model_epoch_{}.pth".format(epoch)
157 | state = {"epoch": epoch ,"model": model}
158 | if not os.path.exists("checkpoint/"):
159 | os.makedirs("checkpoint/")
160 |
161 | torch.save(state, model_out_path)
162 |
163 | print("Checkpoint saved to {}".format(model_out_path))
164 |
165 | if __name__ == "__main__":
166 | main()
167 |
--------------------------------------------------------------------------------
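`train()` calls `adjust_learning_rate(optimizer, epoch - 1)`, so for a 1-based epoch the rate is `opt.lr * 0.1 ** ((epoch - 1) // opt.step)`. With the defaults (`--lr 1e-4`, `--step 200`, `--nEpochs 500`) that gives 1e-4 for epochs 1 to 200, 1e-5 for 201 to 400 and 1e-6 for 401 to 500. The small sketch below (hypothetical name `step_lr`) makes the schedule easy to check without building the model.

```python
def step_lr(base_lr: float, epoch: int, step: int) -> float:
    """Learning rate used for a 1-based `epoch` by train() in main_srresnet.py.

    Equivalent to adjust_learning_rate(optimizer, epoch - 1): the rate drops
    by a factor of 10 every `step` epochs.
    """
    return base_lr * (0.1 ** ((epoch - 1) // step))


if __name__ == "__main__":
    for epoch in (1, 200, 201, 400, 401, 500):
        print(epoch, step_lr(1e-4, epoch, step=200))
```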
/model/model_srresnet.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/model/model_srresnet.pth
--------------------------------------------------------------------------------
/result/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/result/result.png
--------------------------------------------------------------------------------
/srresnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 |
5 | class _Residual_Block(nn.Module):
6 | def __init__(self):
7 | super(_Residual_Block, self).__init__()
8 |
9 | self.conv1 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
10 | self.in1 = nn.InstanceNorm2d(64, affine=True)
11 | self.relu = nn.LeakyReLU(0.2, inplace=True)
12 | self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
13 | self.in2 = nn.InstanceNorm2d(64, affine=True)
14 |
15 | def forward(self, x):
16 | identity_data = x
17 | output = self.relu(self.in1(self.conv1(x)))
18 | output = self.in2(self.conv2(output))
19 | output = torch.add(output,identity_data)
20 | return output
21 |
22 | class _NetG(nn.Module):
23 | def __init__(self):
24 | super(_NetG, self).__init__()
25 |
26 | self.conv_input = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=9, stride=1, padding=4, bias=False)
27 | self.relu = nn.LeakyReLU(0.2, inplace=True)
28 |
29 | self.residual = self.make_layer(_Residual_Block, 16)
30 |
31 | self.conv_mid = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
32 | self.bn_mid = nn.InstanceNorm2d(64, affine=True)
33 |
34 | self.upscale4x = nn.Sequential(
35 | nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
36 | nn.PixelShuffle(2),
37 | nn.LeakyReLU(0.2, inplace=True),
38 | nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
39 | nn.PixelShuffle(2),
40 | nn.LeakyReLU(0.2, inplace=True),
41 | )
42 |
43 | self.conv_output = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=9, stride=1, padding=4, bias=False)
44 |
45 | for m in self.modules():
46 | if isinstance(m, nn.Conv2d):
47 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
48 | m.weight.data.normal_(0, math.sqrt(2. / n))
49 | if m.bias is not None:
50 | m.bias.data.zero_()
51 |
52 | def make_layer(self, block, num_of_layer):
53 | layers = []
54 | for _ in range(num_of_layer):
55 | layers.append(block())
56 | return nn.Sequential(*layers)
57 |
58 | def forward(self, x):
59 | out = self.relu(self.conv_input(x))
60 | residual = out
61 | out = self.residual(out)
62 | out = self.bn_mid(self.conv_mid(out))
63 | out = torch.add(out,residual)
64 | out = self.upscale4x(out)
65 | out = self.conv_output(out)
66 | return out
67 |
68 | class _NetD(nn.Module):
69 | def __init__(self):
70 | super(_NetD, self).__init__()
71 |
72 | self.features = nn.Sequential(
73 |
74 | # input is (3) x 96 x 96
75 | nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False),
76 | nn.LeakyReLU(0.2, inplace=True),
77 |
78 | # state size. (64) x 96 x 96
79 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False),
80 | nn.BatchNorm2d(64),
81 | nn.LeakyReLU(0.2, inplace=True),
82 |
83 | # state size. (64) x 48 x 48
84 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False),
85 | nn.BatchNorm2d(128),
86 | nn.LeakyReLU(0.2, inplace=True),
87 |
88 | # state size. (128) x 48 x 48
89 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False),
90 | nn.BatchNorm2d(128),
91 | nn.LeakyReLU(0.2, inplace=True),
92 |
93 | # state size. (128) x 24 x 24
94 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
95 | nn.BatchNorm2d(256),
96 | nn.LeakyReLU(0.2, inplace=True),
97 |
98 | # state size. (256) x 24 x 24
99 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=4, stride=2, padding=1, bias=False),
100 | nn.BatchNorm2d(256),
101 | nn.LeakyReLU(0.2, inplace=True),
102 |
103 | # state size. (256) x 12 x 12
104 | nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1, bias=False),
105 | nn.BatchNorm2d(512),
106 | nn.LeakyReLU(0.2, inplace=True),
107 |
108 | # state size. (512) x 12 x 12
109 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=4, stride=2, padding=1, bias=False),
110 | nn.BatchNorm2d(512),
111 | nn.LeakyReLU(0.2, inplace=True),
112 | )
113 |
114 | self.LeakyReLU = nn.LeakyReLU(0.2, inplace=True)
115 | self.fc1 = nn.Linear(512 * 6 * 6, 1024)
116 | self.fc2 = nn.Linear(1024, 1)
117 | self.sigmoid = nn.Sigmoid()
118 |
119 | for m in self.modules():
120 | if isinstance(m, nn.Conv2d):
121 | m.weight.data.normal_(0.0, 0.02)
122 | elif isinstance(m, nn.BatchNorm2d):
123 | m.weight.data.normal_(1.0, 0.02)
124 | m.bias.data.fill_(0)
125 |
126 | def forward(self, input):
127 |
128 | out = self.features(input)
129 |
130 | # state size. (512) x 6 x 6
131 | out = out.view(out.size(0), -1)
132 |
133 | # state size. (512 x 6 x 6)
134 | out = self.fc1(out)
135 |
136 | # state size. (1024)
137 | out = self.LeakyReLU(out)
138 |
139 | out = self.fc2(out)
140 | out = self.sigmoid(out)
141 | return out.view(-1, 1).squeeze(1)
--------------------------------------------------------------------------------
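Two shape rules explain srresnet.py: each stride-2 `Conv2d(kernel_size=4, padding=1)` in `_NetD` halves the spatial size (96 → 48 → 24 → 12 → 6, hence `nn.Linear(512 * 6 * 6, 1024)`), and each `Conv2d(64→256)` + `PixelShuffle(2)` pair in `_NetG.upscale4x` trades a factor of 4 in channels for 2× height and width. The NumPy sketch below checks both, assuming the standard Conv2d output-size formula with dilation 1 and the channel-to-space ordering PyTorch documents for `PixelShuffle`; it is an illustration, not the repo's code, and the function names are hypothetical.

```python
import numpy as np


def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of nn.Conv2d with dilation 1: floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) of the eight convolutions in _NetD.features
NETD_CONVS = [(3, 1, 1), (4, 2, 1)] * 4


def netd_feature_size(size: int = 96) -> int:
    """Spatial size after _NetD.features for a size x size input."""
    for kernel, stride, padding in NETD_CONVS:
        size = conv_out(size, kernel, stride, padding)
    return size


def pixel_shuffle(x: np.ndarray, r: int) -> np.ndarray:
    """Single-sample PixelShuffle: (C*r*r, H, W) -> (C, H*r, W*r).

    out[c, h*r + i, w*r + j] = x[c*r*r + i*r + j, h, w]
    """
    c_r2, h, w = x.shape
    c = c_r2 // (r * r)
    return x.reshape(c, r, r, h, w).transpose(0, 3, 1, 4, 2).reshape(c, h * r, w * r)


if __name__ == "__main__":
    print(netd_feature_size(96))  # 6
    # First upscaling stage of _NetG on a 24x24 feature map: 256 channels -> 64
    print(pixel_shuffle(np.zeros((256, 24, 24)), 2).shape)  # (64, 48, 48)
```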
/testsets/Set14/baboon.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set14/baboon.mat
--------------------------------------------------------------------------------
/testsets/Set14/barbara.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set14/barbara.mat
--------------------------------------------------------------------------------
/testsets/Set14/bridge.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set14/bridge.mat
--------------------------------------------------------------------------------
/testsets/Set14/coastguard.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set14/coastguard.mat
--------------------------------------------------------------------------------
/testsets/Set14/comic.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set14/comic.mat
--------------------------------------------------------------------------------
/testsets/Set14/face.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set14/face.mat
--------------------------------------------------------------------------------
/testsets/Set14/flowers.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set14/flowers.mat
--------------------------------------------------------------------------------
/testsets/Set14/foreman.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set14/foreman.mat
--------------------------------------------------------------------------------
/testsets/Set14/lenna.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set14/lenna.mat
--------------------------------------------------------------------------------
/testsets/Set14/man.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set14/man.mat
--------------------------------------------------------------------------------
/testsets/Set14/monarch.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set14/monarch.mat
--------------------------------------------------------------------------------
/testsets/Set14/pepper.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set14/pepper.mat
--------------------------------------------------------------------------------
/testsets/Set14/ppt3.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set14/ppt3.mat
--------------------------------------------------------------------------------
/testsets/Set14/zebra.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set14/zebra.mat
--------------------------------------------------------------------------------
/testsets/Set5/baby_GT.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set5/baby_GT.mat
--------------------------------------------------------------------------------
/testsets/Set5/bird_GT.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set5/bird_GT.mat
--------------------------------------------------------------------------------
/testsets/Set5/butterfly_GT.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set5/butterfly_GT.mat
--------------------------------------------------------------------------------
/testsets/Set5/head_GT.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set5/head_GT.mat
--------------------------------------------------------------------------------
/testsets/Set5/woman_GT.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/d715729c8805f59dccd4a89acb7af11cb7b1534a/testsets/Set5/woman_GT.mat
--------------------------------------------------------------------------------
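Every test image above is pinned to a single commit through its raw URL. A minimal sketch for rebuilding those URLs and fetching the files, assuming the raw URLs stay reachable and serve the `.mat` files directly; `RAW_BASE`, `TESTSETS`, `raw_urls` and `download` are hypothetical names, and the local `testsets/<name>/` layout simply mirrors the repository tree.

```python
from pathlib import Path
from typing import List
from urllib.request import urlretrieve

# Repository and commit copied from the raw URLs in the listing above.
RAW_BASE = ("https://raw.githubusercontent.com/twtygqyy/pytorch-SRResNet/"
            "d715729c8805f59dccd4a89acb7af11cb7b1534a")

# File stems per test set, as listed above.
TESTSETS = {
    "Set5": ["baby_GT", "bird_GT", "butterfly_GT", "head_GT", "woman_GT"],
    "Set14": ["baboon", "barbara", "bridge", "coastguard", "comic", "face",
              "flowers", "foreman", "lenna", "man", "monarch", "pepper",
              "ppt3", "zebra"],
}


def raw_urls(dataset: str) -> List[str]:
    """Build the pinned raw URL of every .mat file in one test set."""
    return [f"{RAW_BASE}/testsets/{dataset}/{stem}.mat" for stem in TESTSETS[dataset]]


def download(dataset: str, root: str = "testsets") -> List[Path]:
    """Download missing files into <root>/<dataset>/ and return their paths."""
    out_dir = Path(root) / dataset
    out_dir.mkdir(parents=True, exist_ok=True)
    paths = []
    for url in raw_urls(dataset):
        target = out_dir / url.rsplit("/", 1)[1]
        if not target.exists():
            urlretrieve(url, str(target))  # requires network access
        paths.append(target)
    return paths
```

`download("Set5")` only fetches files that are not already present. To inspect one afterwards, `scipy.io.loadmat` returns a dict keyed by MATLAB variable name; the variable names stored in these files are not shown in this listing, so check `.keys()` before use (files saved in MATLAB v7.3 format would need an HDF5 reader such as `h5py` instead).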