├── LICENSE
├── README.md
├── Set5
│   ├── baby_GT.bmp
│   ├── baby_GT_scale_2.bmp
│   ├── baby_GT_scale_3.bmp
│   ├── baby_GT_scale_4.bmp
│   ├── bird_GT.bmp
│   ├── bird_GT_scale_2.bmp
│   ├── bird_GT_scale_3.bmp
│   ├── bird_GT_scale_4.bmp
│   ├── butterfly_GT.bmp
│   ├── butterfly_GT_scale_2.bmp
│   ├── butterfly_GT_scale_3.bmp
│   ├── butterfly_GT_scale_4.bmp
│   ├── head_GT.bmp
│   ├── head_GT_scale_2.bmp
│   ├── head_GT_scale_3.bmp
│   ├── head_GT_scale_4.bmp
│   ├── woman_GT.bmp
│   ├── woman_GT_scale_2.bmp
│   ├── woman_GT_scale_3.bmp
│   └── woman_GT_scale_4.bmp
├── Set5_mat
│   ├── baby_GT_x2.mat
│   ├── baby_GT_x3.mat
│   ├── baby_GT_x4.mat
│   ├── bird_GT_x2.mat
│   ├── bird_GT_x3.mat
│   ├── bird_GT_x4.mat
│   ├── butterfly_GT_x2.mat
│   ├── butterfly_GT_x3.mat
│   ├── butterfly_GT_x4.mat
│   ├── head_GT_x2.mat
│   ├── head_GT_x3.mat
│   ├── head_GT_x4.mat
│   ├── woman_GT_x2.mat
│   ├── woman_GT_x3.mat
│   └── woman_GT_x4.mat
├── VDSR-Demo.ipynb
├── data
│   ├── generate_test_mat.m
│   ├── generate_train.m
│   ├── modcrop.m
│   ├── store2hdf5.m
│   └── train.h5
├── dataset.py
├── demo.py
├── eval.py
├── main_vdsr.py
├── model
│   └── model_epoch_50.pth
├── result
│   ├── input.bmp
│   └── output.bmp
└── vdsr.py
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2017- Jiu XU
4 | Copyright (c) 2017- Rakuten, Inc
5 | Copyright (c) 2017- Rakuten Institute of Technology
6 |
7 | Permission is hereby granted, free of charge, to any person obtaining a copy
8 | of this software and associated documentation files (the "Software"), to deal
9 | in the Software without restriction, including without limitation the rights
10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | copies of the Software, and to permit persons to whom the Software is
12 | furnished to do so, subject to the following conditions:
13 |
14 | The above copyright notice and this permission notice shall be included in all
15 | copies or substantial portions of the Software.
16 |
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch VDSR
2 | Implementation of the CVPR 2016 paper "[Accurate Image Super-Resolution Using Very Deep Convolutional Networks](http://cv.snu.ac.kr/research/VDSR/)"
3 | in PyTorch.
4 |
5 | ## Usage
6 | ### Training
7 | ```
8 | usage: main_vdsr.py [-h] [--batchSize BATCHSIZE] [--nEpochs NEPOCHS] [--lr LR]
9 |                     [--step STEP] [--cuda] [--resume RESUME]
10 |                     [--start-epoch START_EPOCH] [--clip CLIP] [--threads THREADS]
11 |                     [--momentum MOMENTUM] [--weight-decay WEIGHT_DECAY]
12 |                     [--pretrained PRETRAINED] [--gpus GPUS]
13 |
14 | optional arguments:
15 |   -h, --help            Show this help message and exit
16 |   --batchSize           Training batch size. Default: 128
17 |   --nEpochs             Number of epochs to train for. Default: 50
18 |   --lr                  Learning rate. Default: 0.1
19 |   --step                Decay the learning rate every n epochs. Default: n=10
20 |   --cuda                Use cuda
21 |   --resume              Path to checkpoint
22 |   --clip                Clipping gradients. Default: 0.4
23 |   --threads             Number of threads for the data loader to use. Default: 1
24 |   --momentum            Momentum. Default: 0.9
25 |   --weight-decay        Weight decay. Default: 1e-4
26 |   --pretrained PRETRAINED
27 |                         Path to pretrained model (default: none)
28 |   --gpus GPUS           GPU ids (default: 0)
29 | ```
30 | An example of training usage is shown as follows:
31 | ```
32 | python main_vdsr.py --cuda --gpus 0
33 | ```
34 |
35 | ### Evaluation
36 | ```
37 | usage: eval.py [-h] [--cuda] [--model MODEL] [--dataset DATASET]
38 |                [--gpus GPUS]
39 |
40 | PyTorch VDSR Eval
41 |
42 | optional arguments:
43 |   -h, --help         show this help message and exit
44 |   --cuda             use cuda?
45 |   --model MODEL      model path, Default: model/model_epoch_50.pth
46 |   --dataset DATASET  dataset name, Default: Set5
47 |   --gpus GPUS        gpu ids (default: 0)
48 | ```
49 | An example of evaluation usage is shown as follows:
50 | ```
51 | python eval.py --cuda --dataset Set5
52 | ```
53 |
54 | ### Demo
55 | ```
56 | usage: demo.py [-h] [--cuda] [--model MODEL] [--image IMAGE] [--scale SCALE] [--gpus GPUS]
57 |
58 | optional arguments:
59 | -h, --help Show this help message and exit
60 | --cuda Use cuda
61 | --model Model path. Default=model/model_epoch_50.pth
62 | --image Image name. Default=butterfly_GT
63 | --scale Scale factor, Default: 4
64 | --gpus GPUS gpu ids (default: 0)
65 | ```
66 | An example of usage is shown as follows:
67 | ```
68 | python demo.py --model model/model_epoch_50.pth --image butterfly_GT --scale 4 --cuda
69 | ```
70 |
71 | ### Prepare Training dataset
72 | - We provide a simple HDF5-format training sample in the data folder with 'data' and 'label' keys. The training data is generated with MATLAB bicubic interpolation; please refer to [Code for Data Generation](https://github.com/twtygqyy/pytorch-vdsr/tree/master/data) for creating training files.
73 |
74 | ### Performance
75 | - We provide a pretrained VDSR model trained on the [291](https://drive.google.com/open?id=1Rt3asDLuMgLuJvPA1YrhyjWhb97Ly742) images with data augmentation.
76 | - No bias is used in this implementation, and gradient clipping is implemented differently from the paper.
77 | - Performance in PSNR (dB) on Set5:
78 |
79 | | Scale | VDSR Paper | VDSR PyTorch|
80 | | ------------- |:-------------:| -----:|
81 | | 2x | 37.53 | 37.65 |
82 | | 3x | 33.66 | 33.77|
83 | | 4x | 31.35 | 31.45 |
84 |
85 | ### Result
86 | From left to right: ground truth, bicubic, and VDSR.
--------------------------------------------------------------------------------
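The training options `--lr`, `--step` and `--clip` above control a step learning-rate schedule and gradient clipping, and the README notes that its clipping differs from the paper. For reference, the VDSR paper describes *adjustable* gradient clipping: each gradient element is clipped to [-θ/γ, θ/γ], where θ is a fixed threshold and γ is the current learning rate, so each update γ·g stays bounded by θ even while γ is large. A minimal NumPy sketch of both pieces follows. It is not the code in main_vdsr.py (which is only partly shown here); the helper names are hypothetical, and the decay factor `gamma=0.1` per `step` epochs is an assumption.

```python
import numpy as np


def step_lr(base_lr, epoch, step=10, gamma=0.1):
    """Step schedule: multiply base_lr by gamma once every `step` epochs.

    Epochs are counted from 1 (main_vdsr.py's --start-epoch default).
    gamma=0.1 is an ASSUMED decay factor, not read from main_vdsr.py.
    """
    return base_lr * gamma ** ((epoch - 1) // step)


def clip_adjustable(grads, theta, lr):
    """Adjustable gradient clipping as described in the VDSR paper:
    clip every gradient element to [-theta/lr, theta/lr]."""
    bound = theta / lr
    return [np.clip(g, -bound, bound) for g in grads]


if __name__ == "__main__":
    for epoch in (1, 10, 11, 21):
        print(epoch, step_lr(0.1, epoch))
```

With the main_vdsr.py defaults (`--lr 0.1`, `--clip 0.4`), the paper's rule would bound each gradient element by 0.4 / 0.1 = 4 during the first `step` epochs, and the bound grows as the learning rate decays. A norm-based alternative available in PyTorch is `torch.nn.utils.clip_grad_norm_`.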
/Set5/baby_GT.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/baby_GT.bmp
--------------------------------------------------------------------------------
/Set5/baby_GT_scale_2.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/baby_GT_scale_2.bmp
--------------------------------------------------------------------------------
/Set5/baby_GT_scale_3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/baby_GT_scale_3.bmp
--------------------------------------------------------------------------------
/Set5/baby_GT_scale_4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/baby_GT_scale_4.bmp
--------------------------------------------------------------------------------
/Set5/bird_GT.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/bird_GT.bmp
--------------------------------------------------------------------------------
/Set5/bird_GT_scale_2.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/bird_GT_scale_2.bmp
--------------------------------------------------------------------------------
/Set5/bird_GT_scale_3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/bird_GT_scale_3.bmp
--------------------------------------------------------------------------------
/Set5/bird_GT_scale_4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/bird_GT_scale_4.bmp
--------------------------------------------------------------------------------
/Set5/butterfly_GT.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/butterfly_GT.bmp
--------------------------------------------------------------------------------
/Set5/butterfly_GT_scale_2.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/butterfly_GT_scale_2.bmp
--------------------------------------------------------------------------------
/Set5/butterfly_GT_scale_3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/butterfly_GT_scale_3.bmp
--------------------------------------------------------------------------------
/Set5/butterfly_GT_scale_4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/butterfly_GT_scale_4.bmp
--------------------------------------------------------------------------------
/Set5/head_GT.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/head_GT.bmp
--------------------------------------------------------------------------------
/Set5/head_GT_scale_2.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/head_GT_scale_2.bmp
--------------------------------------------------------------------------------
/Set5/head_GT_scale_3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/head_GT_scale_3.bmp
--------------------------------------------------------------------------------
/Set5/head_GT_scale_4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/head_GT_scale_4.bmp
--------------------------------------------------------------------------------
/Set5/woman_GT.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/woman_GT.bmp
--------------------------------------------------------------------------------
/Set5/woman_GT_scale_2.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/woman_GT_scale_2.bmp
--------------------------------------------------------------------------------
/Set5/woman_GT_scale_3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/woman_GT_scale_3.bmp
--------------------------------------------------------------------------------
/Set5/woman_GT_scale_4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/woman_GT_scale_4.bmp
--------------------------------------------------------------------------------
/Set5_mat/baby_GT_x2.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5_mat/baby_GT_x2.mat
--------------------------------------------------------------------------------
/Set5_mat/baby_GT_x3.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5_mat/baby_GT_x3.mat
--------------------------------------------------------------------------------
/Set5_mat/baby_GT_x4.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5_mat/baby_GT_x4.mat
--------------------------------------------------------------------------------
/Set5_mat/bird_GT_x2.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5_mat/bird_GT_x2.mat
--------------------------------------------------------------------------------
/Set5_mat/bird_GT_x3.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5_mat/bird_GT_x3.mat
--------------------------------------------------------------------------------
/Set5_mat/bird_GT_x4.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5_mat/bird_GT_x4.mat
--------------------------------------------------------------------------------
/Set5_mat/butterfly_GT_x2.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5_mat/butterfly_GT_x2.mat
--------------------------------------------------------------------------------
/Set5_mat/butterfly_GT_x3.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5_mat/butterfly_GT_x3.mat
--------------------------------------------------------------------------------
/Set5_mat/butterfly_GT_x4.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5_mat/butterfly_GT_x4.mat
--------------------------------------------------------------------------------
/Set5_mat/head_GT_x2.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5_mat/head_GT_x2.mat
--------------------------------------------------------------------------------
/Set5_mat/head_GT_x3.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5_mat/head_GT_x3.mat
--------------------------------------------------------------------------------
/Set5_mat/head_GT_x4.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5_mat/head_GT_x4.mat
--------------------------------------------------------------------------------
/Set5_mat/woman_GT_x2.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5_mat/woman_GT_x2.mat
--------------------------------------------------------------------------------
/Set5_mat/woman_GT_x3.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5_mat/woman_GT_x3.mat
--------------------------------------------------------------------------------
/Set5_mat/woman_GT_x4.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5_mat/woman_GT_x4.mat
--------------------------------------------------------------------------------
/data/generate_test_mat.m:
--------------------------------------------------------------------------------
1 | clear;close all;
2 | %% settings
3 | folder = 'Set5';
4 |
5 | %% generate data
6 | filepaths = [];
7 | filepaths = [filepaths; dir(fullfile(folder, '*.bmp'))];
8 |
9 | scale = [2, 3, 4];
10 |
11 | for i = 1 : length(filepaths)
12 |     im_raw = imread(fullfile(folder,filepaths(i).name));
13 |     for s = 1 : length(scale)
14 |         % crop a fresh copy of the original for each scale, so the crops do not accumulate
15 |         im_gt = double(modcrop(im_raw, scale(s)));
16 |         im_gt_ycbcr = rgb2ycbcr(im_gt / 255.0);
17 |         im_gt_y = im_gt_ycbcr(:,:,1) * 255.0;
18 |         im_l_ycbcr = imresize(im_gt_ycbcr,1/scale(s),'bicubic');
19 |         im_b_ycbcr = imresize(im_l_ycbcr,scale(s),'bicubic');
20 |         im_l_y = im_l_ycbcr(:,:,1) * 255.0;
21 |         im_l = ycbcr2rgb(im_l_ycbcr) * 255.0;
22 |         im_b_y = im_b_ycbcr(:,:,1) * 255.0;
23 |         im_b = ycbcr2rgb(im_b_ycbcr) * 255.0;
24 |         last = length(filepaths(i).name)-4;
25 |         filename = sprintf('Set5_mat/%s_x%s.mat',filepaths(i).name(1 : last),num2str(scale(s)));
26 |         save(filename, 'im_gt_y', 'im_b_y', 'im_gt', 'im_b', 'im_l_ycbcr', 'im_l_y', 'im_l');
27 |     end
28 | end
29 |
--------------------------------------------------------------------------------
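generate_test_mat.m extracts luminance with MATLAB's `rgb2ycbcr`, which follows ITU-R BT.601 "studio swing": for RGB in [0, 255], Y = 16 + (65.481·R + 128.553·G + 24.966·B) / 255, so Y lies in [16, 235]. demo.py instead works on PIL-style `YCbCr` images, which use the full-range JPEG/JFIF luma Y = 0.299·R + 0.587·G + 0.114·B in [0, 255]. The Y values, and therefore the PSNR figures, from the two paths are not directly comparable. A minimal NumPy sketch of both formulas (function names are hypothetical; PIL's own conversion uses fixed-point arithmetic, so it may differ from `y_jfif` by about one gray level):

```python
import numpy as np


def y_matlab_bt601(rgb):
    """Luma as MATLAB rgb2ycbcr computes it (studio swing, 16..235),
    for an (..., 3) RGB array with values in [0, 255]."""
    rgb = np.asarray(rgb, dtype=np.float64) / 255.0
    r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2]
    return 16.0 + 65.481 * r + 128.553 * g + 24.966 * b


def y_jfif(rgb):
    """Full-range (JPEG/JFIF) luma, the convention behind PIL's "YCbCr" mode."""
    rgb = np.asarray(rgb, dtype=np.float64)
    return 0.299 * rgb[..., 0] + 0.587 * rgb[..., 1] + 0.114 * rgb[..., 2]
```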
/data/generate_train.m:
--------------------------------------------------------------------------------
1 | clear;close all;
2 |
3 | folder = 'path/to/train/folder';
4 |
5 | savepath = 'train.h5';
6 | size_input = 41;
7 | size_label = 41;
8 | stride = 41;
9 |
10 | %% scale factors
11 | scale = [2,3,4];
12 | %% downsizing
13 | downsizes = [1,0.7,0.5];
14 |
15 | %% initialization
16 | data = zeros(size_input, size_input, 1, 1);
17 | label = zeros(size_label, size_label, 1, 1);
18 |
19 | count = 0;
20 | margain = 0;
21 |
22 | %% generate data
23 | filepaths = [];
24 | filepaths = [filepaths; dir(fullfile(folder, '*.jpg'))];
25 | filepaths = [filepaths; dir(fullfile(folder, '*.bmp'))];
26 |
27 | for i = 1 : length(filepaths)
28 |     for flip = 1: 3
29 |         for degree = 1 : 4
30 |             for s = 1 : length(scale)
31 |                 for downsize = 1 : length(downsizes)
32 |                     image = imread(fullfile(folder,filepaths(i).name));
33 |
34 |                     if flip == 1
35 |                         image = flipdim(image ,1);
36 |                     end
37 |                     if flip == 2
38 |                         image = flipdim(image ,2);
39 |                     end
40 |
41 |                     image = imrotate(image, 90 * (degree - 1));
42 |
43 |                     image = imresize(image,downsizes(downsize),'bicubic');
44 |
45 |                     if size(image,3)==3
46 |                         image = rgb2ycbcr(image);
47 |                         image = im2double(image(:, :, 1));
48 |
49 |                         im_label = modcrop(image, scale(s));
50 |                         [hei,wid] = size(im_label);
51 |                         im_input = imresize(imresize(im_label,1/scale(s),'bicubic'),[hei,wid],'bicubic');
52 |                         filepaths(i).name
53 |                         for x = 1 : stride : hei-size_input+1
54 |                             for y = 1 :stride : wid-size_input+1
55 |
56 |                                 subim_input = im_input(x : x+size_input-1, y : y+size_input-1);
57 |                                 subim_label = im_label(x : x+size_label-1, y : y+size_label-1);
58 |
59 |                                 count=count+1;
60 |
61 |                                 data(:, :, 1, count) = subim_input;
62 |                                 label(:, :, 1, count) = subim_label;
63 |                             end
64 |                         end
65 |                     end
66 |                 end
67 |             end
68 |         end
69 |     end
70 | end
71 |
72 | order = randperm(count);
73 | data = data(:, :, 1, order);
74 | label = label(:, :, 1, order);
75 |
76 | %% writing to HDF5
77 | chunksz = 64;
78 | created_flag = false;
79 | totalct = 0;
80 |
81 | for batchno = 1:floor(count/chunksz)
82 |     batchno
83 |     last_read=(batchno-1)*chunksz;
84 |     batchdata = data(:,:,1,last_read+1:last_read+chunksz);
85 |     batchlabs = label(:,:,1,last_read+1:last_read+chunksz);
86 |
87 |     startloc = struct('dat',[1,1,1,totalct+1], 'lab', [1,1,1,totalct+1]);
88 |     curr_dat_sz = store2hdf5(savepath, batchdata, batchlabs, ~created_flag, startloc, chunksz);
89 |     created_flag = true;
90 |     totalct = curr_dat_sz(end);
91 | end
92 |
93 | h5disp(savepath);
94 |
--------------------------------------------------------------------------------
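generate_train.m builds each training pair by modcropping the luminance of an RGB image, bicubic down- and up-sampling it by the scale factor, and cutting non-overlapping 41×41 patches (`size_input = stride = 41`). Each source image yields 3 flip settings × 4 rotations × 3 scales × 3 downsizing factors = 108 variants before patching, grayscale images are skipped, and only `floor(count/64)` full chunks of 64 patches are written, so up to 63 trailing patches are dropped. The NumPy sketch below reproduces only the patch-cutting loops, with 0-based instead of MATLAB's 1-based indices; the function name is hypothetical and the bicubic resizing is left out.

```python
import numpy as np


def extract_patches(im_input, im_label, size=41, stride=41):
    """Cut aligned (input, label) patches like the x/y loops in generate_train.m.

    MATLAB's `for x = 1 : stride : hei-size_input+1` becomes
    range(0, hei - size + 1, stride) with 0-based indices.
    Returns two arrays of shape (N, size, size).
    """
    if im_input.shape != im_label.shape:
        raise ValueError("input and label must have the same shape")
    hei, wid = im_label.shape
    inputs, labels = [], []
    for x in range(0, hei - size + 1, stride):      # rows (outer loop, as in MATLAB)
        for y in range(0, wid - size + 1, stride):  # columns (inner loop)
            inputs.append(im_input[x:x + size, y:y + size])
            labels.append(im_label[x:x + size, y:y + size])
    if not inputs:  # image smaller than one patch: nothing to extract
        empty = np.empty((0, size, size), dtype=im_label.dtype)
        return empty, empty.copy()
    return np.stack(inputs), np.stack(labels)
```

Per dimension this gives floor((L - 41) / 41) + 1 patches, so a 100×130 luminance image yields 2 × 3 = 6 pairs.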
/data/modcrop.m:
--------------------------------------------------------------------------------
1 | function imgs = modcrop(imgs, modulo)
2 | if size(imgs,3)==1
3 |     sz = size(imgs);
4 |     sz = sz - mod(sz, modulo);
5 |     imgs = imgs(1:sz(1), 1:sz(2));
6 | else
7 |     tmpsz = size(imgs);
8 |     sz = tmpsz(1:2);
9 |     sz = sz - mod(sz, modulo);
10 |     imgs = imgs(1:sz(1), 1:sz(2),:);
11 | end
12 |
13 |
--------------------------------------------------------------------------------
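modcrop trims the bottom rows and right-hand columns so that both spatial sizes are multiples of the scale factor; this way, bicubic down-sampling by 1/scale followed by up-sampling by scale returns an image of exactly the same size as the label. A NumPy equivalent, as a sketch (the Python function only mirrors the MATLAB name and is not part of the repository):

```python
import numpy as np


def modcrop(img, modulo):
    """Crop H and W of an (H, W) or (H, W, C) array down to multiples of `modulo`."""
    h, w = img.shape[:2]
    h -= h % modulo
    w -= w % modulo
    return img[:h, :w, ...]  # the Ellipsis keeps any channel axis untouched
```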
/data/store2hdf5.m:
--------------------------------------------------------------------------------
1 | function [curr_dat_sz, curr_lab_sz] = store2hdf5(filename, data, labels, create, startloc, chunksz)
2 | % *data* is a W*H*C*N matrix of images, which should be normalized (e.g. to lie between 0 and 1) beforehand
3 | % *labels* is a D*N matrix of labels (D labels per sample); generate_train.m passes a W*H*C*N image matrix here, like *data*
4 | % *create* [0/1] specifies whether to create a new file or to append to a previously created file, useful for storing data in batches when a dataset is too big to be held in memory (default: 1)
5 | % *startloc* (point at which to start writing data). By default,
6 | % if create=1 (create mode), startloc.dat=[1 1 1 1], and startloc.lab=[1 1];
7 | % if create=0 (append mode), startloc.dat=[1 1 1 K+1], and startloc.lab = [1 K+1]; where K is the current number of samples stored in the HDF5 file
8 | % *chunksz* (used only in create mode) specifies the number of samples to be stored per chunk (see HDF5 documentation on chunking) for creating HDF5 files with unbounded maximum size - TLDR; higher chunk sizes allow faster read-write operations
9 |
10 | % verify that format is right
11 | dat_dims=size(data);
12 | lab_dims=size(labels);
13 | num_samples=dat_dims(end);
14 |
15 | assert(lab_dims(end)==num_samples, 'Number of samples should be matched between data and labels');
16 |
17 | if ~exist('create','var')
18 |     create=true;
19 | end
20 |
21 |
22 | if create
23 |     %fprintf('Creating dataset with %d samples\n', num_samples);
24 |     if ~exist('chunksz', 'var')
25 |         chunksz=1000;
26 |     end
27 |     if exist(filename, 'file')
28 |         fprintf('Warning: replacing existing file %s \n', filename);
29 |         delete(filename);
30 |     end
31 |     h5create(filename, '/data', [dat_dims(1:end-1) Inf], 'Datatype', 'single', 'ChunkSize', [dat_dims(1:end-1) chunksz]); % width, height, channels, number
32 |     h5create(filename, '/label', [lab_dims(1:end-1) Inf], 'Datatype', 'single', 'ChunkSize', [lab_dims(1:end-1) chunksz]); % width, height, channels, number
33 |     if ~exist('startloc','var')
34 |         startloc.dat=[ones(1,length(dat_dims)-1), 1];
35 |         startloc.lab=[ones(1,length(lab_dims)-1), 1];
36 |     end
37 | else % append mode
38 |     if ~exist('startloc','var')
39 |         info=h5info(filename);
40 |         prev_dat_sz=info.Datasets(1).Dataspace.Size;
41 |         prev_lab_sz=info.Datasets(2).Dataspace.Size;
42 |         assert(isequal(prev_dat_sz(1:end-1), dat_dims(1:end-1)), 'Data dimensions must match existing dimensions in dataset');
43 |         assert(isequal(prev_lab_sz(1:end-1), lab_dims(1:end-1)), 'Label dimensions must match existing dimensions in dataset');
44 |         startloc.dat=[ones(1,length(dat_dims)-1), prev_dat_sz(end)+1];
45 |         startloc.lab=[ones(1,length(lab_dims)-1), prev_lab_sz(end)+1];
46 |     end
47 | end
48 |
49 | if ~isempty(data)
50 |     h5write(filename, '/data', single(data), startloc.dat, size(data));
51 |     h5write(filename, '/label', single(labels), startloc.lab, size(labels));
52 | end
53 |
54 | if nargout
55 |     info=h5info(filename);
56 |     curr_dat_sz=info.Datasets(1).Dataspace.Size;
57 |     curr_lab_sz=info.Datasets(2).Dataspace.Size;
58 | end
59 | end
60 |
--------------------------------------------------------------------------------
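store2hdf5 writes MATLAB arrays, which are stored column-major, with dimensions W*H*C*N. A row-major HDF5 reader such as h5py reports the same dataset with its dimensions reversed, so dataset.py sees `data` and `label` as (N, C, H, W) arrays, which is why it indexes samples with `[index,:,:,:]` and takes `shape[0]` as the length. Reversing every axis also transposes each 41×41 patch; since inputs and labels are transposed identically, the training pairs stay aligned. The NumPy sketch below only simulates this by reading Fortran-ordered memory back in C order; it does not call h5py, and the function name is hypothetical.

```python
import numpy as np


def as_row_major_reader_sees(matlab_array):
    """Simulate how a C-order reader (e.g. h5py) sees an array written by MATLAB.

    MATLAB stores elements column-major; reading the same element sequence
    row-major with the dimensions reversed is equivalent to transposing all axes.
    """
    raw = np.asarray(matlab_array).ravel(order="F")          # element order in MATLAB's memory
    return raw.reshape(matlab_array.shape[::-1], order="C")  # row-major view, reversed dims
```

For a training chunk of 64 single-channel 41×41 patches, MATLAB's `[41 41 1 64]` shows up as `(64, 1, 41, 41)`.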
/data/train.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/data/train.h5
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | import torch
3 | import h5py
4 |
5 | class DatasetFromHdf5(data.Dataset):
6 |     def __init__(self, file_path):
7 |         super(DatasetFromHdf5, self).__init__()
8 |         hf = h5py.File(file_path, 'r')  # open read-only; h5py versions before 3.0 defaulted to append mode
9 |         self.data = hf.get('data')
10 |         self.target = hf.get('label')
11 |
12 |     def __getitem__(self, index):
13 |         return torch.from_numpy(self.data[index,:,:,:]).float(), torch.from_numpy(self.target[index,:,:,:]).float()
14 |
15 |     def __len__(self):
16 |         return self.data.shape[0]
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import argparse, os
2 | import torch
3 | from torch.autograd import Variable
4 | from PIL import Image
5 | import numpy as np
6 | import time, math
7 | import matplotlib.pyplot as plt
8 |
9 |
10 | parser = argparse.ArgumentParser(description="PyTorch VDSR Demo")
11 | parser.add_argument("--cuda", action="store_true", help="use cuda?")
12 | parser.add_argument("--model", default="model/model_epoch_50.pth", type=str, help="model path")
13 | parser.add_argument("--image", default="butterfly_GT", type=str, help="image name")
14 | parser.add_argument("--scale", default=4, type=int, help="scale factor, Default: 4")
15 | parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)")
16 |
17 | def PSNR(pred, gt, shave_border=0):
18 |     height, width = pred.shape[:2]
19 |     pred = pred[shave_border:height - shave_border, shave_border:width - shave_border]
20 |     gt = gt[shave_border:height - shave_border, shave_border:width - shave_border]
21 |     imdff = pred - gt
22 |     rmse = math.sqrt(np.mean(imdff ** 2))
23 |     if rmse == 0:
24 |         return 100
25 |     return 20 * math.log10(255.0 / rmse)
26 |
27 | def colorize(y, ycbcr):
28 |     img = np.zeros((y.shape[0], y.shape[1], 3), np.uint8)
29 |     img[:,:,0] = y
30 |     img[:,:,1] = ycbcr[:,:,1]
31 |     img[:,:,2] = ycbcr[:,:,2]
32 |     img = Image.fromarray(img, "YCbCr").convert("RGB")
33 |     return img
34 |
35 | opt = parser.parse_args()
36 | cuda = opt.cuda
37 |
38 | if cuda:
39 |     print("=> use gpu id: '{}'".format(opt.gpus))
40 |     os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
41 |     if not torch.cuda.is_available():
42 |         raise Exception("No GPU found or Wrong gpu id, please run without --cuda")
43 |
44 |
45 | model = torch.load(opt.model, map_location=lambda storage, loc: storage)["model"]
46 |
47 | im_gt_ycbcr = np.array(Image.open("Set5/" + opt.image + ".bmp").convert("YCbCr"))
48 | im_b_ycbcr = np.array(Image.open("Set5/" + opt.image + "_scale_" + str(opt.scale) + ".bmp").convert("YCbCr"))
49 |
50 | im_gt_y = im_gt_ycbcr[:,:,0].astype(float)
51 | im_b_y = im_b_ycbcr[:,:,0].astype(float)
52 |
53 | psnr_bicubic = PSNR(im_gt_y, im_b_y,shave_border=opt.scale)
54 |
55 | im_input = im_b_y/255.
56 |
57 | im_input = Variable(torch.from_numpy(im_input).float()).view(1, -1, im_input.shape[0], im_input.shape[1])
58 |
59 | if cuda:
60 |     model = model.cuda()
61 |     im_input = im_input.cuda()
62 | else:
63 |     model = model.cpu()
64 |
65 | start_time = time.time()
66 | out = model(im_input)
67 | elapsed_time = time.time() - start_time
68 |
69 | out = out.cpu()
70 |
71 | im_h_y = out.data[0].numpy().astype(np.float32)
72 |
73 | im_h_y = im_h_y * 255.
74 | im_h_y[im_h_y < 0] = 0
75 | im_h_y[im_h_y > 255.] = 255.
76 |
77 | psnr_predicted = PSNR(im_gt_y, im_h_y[0,:,:], shave_border=opt.scale)
78 |
79 | im_h = colorize(im_h_y[0,:,:], im_b_ycbcr)
80 | im_gt = Image.fromarray(im_gt_ycbcr, "YCbCr").convert("RGB")
81 | im_b = Image.fromarray(im_b_ycbcr, "YCbCr").convert("RGB")
82 |
83 | print("Scale=",opt.scale)
84 | print("PSNR_predicted=", psnr_predicted)
85 | print("PSNR_bicubic=", psnr_bicubic)
86 | print("It takes {}s for processing".format(elapsed_time))
87 |
88 | fig = plt.figure()
89 | ax = plt.subplot(131)
90 | ax.imshow(im_gt)
91 | ax.set_title("GT")
92 |
93 | ax = plt.subplot(132)
94 | ax.imshow(im_b)
95 | ax.set_title("Input(bicubic)")
96 |
97 | ax = plt.subplot(133)
98 | ax.imshow(im_h)
99 | ax.set_title("Output(vdsr)")
100 | plt.show()
101 |
--------------------------------------------------------------------------------
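demo.py feeds the network only the luminance channel of the bicubic image: Y is scaled to [0, 1], reshaped into a (1, 1, H, W) batch with `view(1, -1, H, W)`, and the output is scaled back and clipped to [0, 255] before PSNR and colorization (where the bicubic Cb/Cr channels are reused). The NumPy sketch below isolates that pre- and post-processing around a placeholder `model_fn`; `model_fn` and the helper names are hypothetical stand-ins for the PyTorch model and are not part of the repository.

```python
import numpy as np


def to_network_input(y):
    """Scale a Y channel in [0, 255] to [0, 1] and add batch and channel axes: (1, 1, H, W)."""
    y = np.asarray(y, dtype=np.float32) / 255.0
    return y.reshape(1, 1, y.shape[0], y.shape[1])


def from_network_output(out):
    """Undo the scaling, clip to the valid pixel range, and drop the batch/channel axes."""
    y = np.asarray(out, dtype=np.float32)[0, 0] * 255.0
    return np.clip(y, 0.0, 255.0)


def super_resolve_y(y_bicubic, model_fn):
    """Run a (hypothetical) model_fn mapping a (1, 1, H, W) array to the same shape."""
    return from_network_output(model_fn(to_network_input(y_bicubic)))
```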
/eval.py:
--------------------------------------------------------------------------------
1 | import argparse, os
2 | import torch
3 | from torch.autograd import Variable
4 | import numpy as np
5 | import time, math, glob
6 | import scipy.io as sio
7 |
8 | parser = argparse.ArgumentParser(description="PyTorch VDSR Eval")
9 | parser.add_argument("--cuda", action="store_true", help="use cuda?")
10 | parser.add_argument("--model", default="model/model_epoch_50.pth", type=str, help="model path")
11 | parser.add_argument("--dataset", default="Set5", type=str, help="dataset name, Default: Set5")
12 | parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)")
13 |
14 | def PSNR(pred, gt, shave_border=0):
15 |     height, width = pred.shape[:2]
16 |     pred = pred[shave_border:height - shave_border, shave_border:width - shave_border]
17 |     gt = gt[shave_border:height - shave_border, shave_border:width - shave_border]
18 |     imdff = pred - gt
19 |     rmse = math.sqrt(np.mean(imdff ** 2))
20 |     if rmse == 0:
21 |         return 100
22 |     return 20 * math.log10(255.0 / rmse)
23 |
24 | opt = parser.parse_args()
25 | cuda = opt.cuda
26 |
27 | if cuda:
28 |     print("=> use gpu id: '{}'".format(opt.gpus))
29 |     os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
30 |     if not torch.cuda.is_available():
31 |         raise Exception("No GPU found or Wrong gpu id, please run without --cuda")
32 |
33 | model = torch.load(opt.model, map_location=lambda storage, loc: storage)["model"]
34 |
35 | scales = [2,3,4]
36 |
37 | image_list = glob.glob(opt.dataset+"_mat/*.*")
38 |
39 | for scale in scales:
40 |     avg_psnr_predicted = 0.0
41 |     avg_psnr_bicubic = 0.0
42 |     avg_elapsed_time = 0.0
43 |     count = 0.0
44 |     for image_name in image_list:
45 |         if image_name.endswith("_x{}.mat".format(scale)):  # match only <name>_x<scale>.mat, as written by generate_test_mat.m
46 |             count += 1
47 |             print("Processing ", image_name)
48 |             mat = sio.loadmat(image_name)
49 |             im_gt_y, im_b_y = mat['im_gt_y'], mat['im_b_y']
50 |
51 |             im_gt_y = im_gt_y.astype(float)
52 |             im_b_y = im_b_y.astype(float)
53 |
54 |             psnr_bicubic = PSNR(im_gt_y, im_b_y,shave_border=scale)
55 |             avg_psnr_bicubic += psnr_bicubic
56 |
57 |             im_input = im_b_y/255.
58 |
59 |             im_input = Variable(torch.from_numpy(im_input).float()).view(1, -1, im_input.shape[0], im_input.shape[1])
60 |
61 |             if cuda:
62 |                 model = model.cuda()
63 |                 im_input = im_input.cuda()
64 |             else:
65 |                 model = model.cpu()
66 |
67 |             start_time = time.time()
68 |             HR = model(im_input)
69 |             elapsed_time = time.time() - start_time
70 |             avg_elapsed_time += elapsed_time
71 |
72 |             HR = HR.cpu()
73 |
74 |             im_h_y = HR.data[0].numpy().astype(np.float32)
75 |
76 |             im_h_y = im_h_y * 255.
77 |             im_h_y[im_h_y < 0] = 0
78 |             im_h_y[im_h_y > 255.] = 255.
79 |             im_h_y = im_h_y[0,:,:]
80 |
81 |             psnr_predicted = PSNR(im_gt_y, im_h_y,shave_border=scale)
82 |             avg_psnr_predicted += psnr_predicted
83 |
84 |     print("Scale=", scale)
85 |     print("Dataset=", opt.dataset)
86 |     print("PSNR_predicted=", avg_psnr_predicted/count)
87 |     print("PSNR_bicubic=", avg_psnr_bicubic/count)
88 |     print("It takes average {}s for processing".format(avg_elapsed_time/count))
89 |
--------------------------------------------------------------------------------
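Both eval.py and demo.py report PSNR on the Y channel, 20·log10(255 / RMSE), after shaving `scale` pixels from every border (a common super-resolution convention for excluding boundary effects), and eval.py averages it over all images of a given scale. As a worked example, a prediction that is off by exactly 5 gray levels everywhere has RMSE = 5 and PSNR = 20·log10(51) ≈ 34.15 dB. The sketch below restates the repository's `PSNR` function in NumPy so the arithmetic can be checked; the lowercase name is hypothetical, and the explicit float64 cast guards against uint8 wraparound (the scripts already cast with `astype(float)` before calling it).

```python
import math

import numpy as np


def psnr(pred, gt, shave_border=0):
    """PSNR in dB for images in the [0, 255] range, ignoring `shave_border` pixels per side."""
    h, w = pred.shape[:2]
    pred = pred[shave_border:h - shave_border, shave_border:w - shave_border]
    gt = gt[shave_border:h - shave_border, shave_border:w - shave_border]
    rmse = math.sqrt(np.mean((pred.astype(np.float64) - gt.astype(np.float64)) ** 2))
    if rmse == 0:
        return 100.0  # same sentinel the repository's PSNR returns for identical images
    return 20 * math.log10(255.0 / rmse)
```

Errors confined to the shaved border do not affect the score, which is why a residual mismatch along the image edges does not show up in the reported numbers.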
/main_vdsr.py:
--------------------------------------------------------------------------------
1 | import argparse, os
2 | import torch
3 | import random
4 | import torch.backends.cudnn as cudnn
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | from torch.autograd import Variable
8 | from torch.utils.data import DataLoader
9 | from vdsr import Net
10 | from dataset import DatasetFromHdf5
11 |
12 | # Training settings
13 | parser = argparse.ArgumentParser(description="PyTorch VDSR")
14 | parser.add_argument("--batchSize", type=int, default=128, help="Training batch size")
15 | parser.add_argument("--nEpochs", type=int, default=50, help="Number of epochs to train for")
16 | parser.add_argument("--lr", type=float, default=0.1, help="Learning Rate. Default=0.1")
17 | parser.add_argument("--step", type=int, default=10, help="Sets the learning rate to the initial LR decayed by a factor of 10 every n epochs, Default: n=10")
18 | parser.add_argument("--cuda", action="store_true", help="Use cuda?")
19 | parser.add_argument("--resume", default="", type=str, help="Path to checkpoint (default: none)")
20 | parser.add_argument("--start-epoch", default=1, type=int, help="Manual epoch number (useful on restarts)")
21 | parser.add_argument("--clip", type=float, default=0.4, help="Clipping Gradients. Default=0.4")
22 | parser.add_argument("--threads", type=int, default=1, help="Number of threads for data loader to use, Default: 1")
23 | parser.add_argument("--momentum", default=0.9, type=float, help="Momentum, Default: 0.9")
24 | parser.add_argument("--weight-decay", "--wd", default=1e-4, type=float, help="Weight decay, Default: 1e-4")
25 | parser.add_argument('--pretrained', default='', type=str, help='path to pretrained model (default: none)')
26 | parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)")
27 |
28 | def main():
29 |     global opt, model
30 |     opt = parser.parse_args()
31 |     print(opt)
32 |
33 |     cuda = opt.cuda
34 |     if cuda:
35 |         print("=> use gpu id: '{}'".format(opt.gpus))
36 |         os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
37 |         if not torch.cuda.is_available():
38 |             raise Exception("No GPU found or wrong GPU id, please run without --cuda")
39 |
40 |     opt.seed = random.randint(1, 10000)
41 |     print("Random Seed: ", opt.seed)
42 |     torch.manual_seed(opt.seed)
43 |     if cuda:
44 |         torch.cuda.manual_seed(opt.seed)
45 |
46 |     cudnn.benchmark = True
47 |
48 |     print("===> Loading datasets")
49 |     train_set = DatasetFromHdf5("data/train.h5")
50 |     training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)
51 |
52 |     print("===> Building model")
53 |     model = Net()
54 |     criterion = nn.MSELoss(reduction="sum")  # summed (not averaged) squared error, as size_average=False did
55 |
56 |     print("===> Setting GPU")
57 |     if cuda:
58 |         model = model.cuda()
59 |         criterion = criterion.cuda()
60 |
61 |     # optionally resume from a checkpoint
62 |     if opt.resume:
63 |         if os.path.isfile(opt.resume):
64 |             print("=> loading checkpoint '{}'".format(opt.resume))
65 |             checkpoint = torch.load(opt.resume)
66 |             opt.start_epoch = checkpoint["epoch"] + 1
67 |             model.load_state_dict(checkpoint["model"].state_dict())
68 |         else:
69 |             print("=> no checkpoint found at '{}'".format(opt.resume))
70 |
71 |     # optionally copy weights from a checkpoint
72 |     if opt.pretrained:
73 |         if os.path.isfile(opt.pretrained):
74 |             print("=> loading model '{}'".format(opt.pretrained))
75 |             weights = torch.load(opt.pretrained)
76 |             model.load_state_dict(weights['model'].state_dict())
77 |         else:
78 |             print("=> no model found at '{}'".format(opt.pretrained))
79 |
80 |     print("===> Setting Optimizer")
81 |     optimizer = optim.SGD(model.parameters(), lr=opt.lr, momentum=opt.momentum, weight_decay=opt.weight_decay)
82 |
83 |     print("===> Training")
84 |     for epoch in range(opt.start_epoch, opt.nEpochs + 1):
85 |         train(training_data_loader, optimizer, model, criterion, epoch)
86 |         save_checkpoint(model, epoch)
87 |
88 | def adjust_learning_rate(optimizer, epoch):
89 |     """Return the initial LR decayed by a factor of 10 every opt.step epochs."""
90 |     lr = opt.lr * (0.1 ** (epoch // opt.step))
91 |     return lr
92 |
93 | def train(training_data_loader, optimizer, model, criterion, epoch):
94 |     lr = adjust_learning_rate(optimizer, epoch-1)
95 |
96 |     for param_group in optimizer.param_groups:
97 |         param_group["lr"] = lr
98 |
99 |     print("Epoch = {}, lr = {}".format(epoch, optimizer.param_groups[0]["lr"]))
100 |
101 |     model.train()
102 |
103 |     for iteration, batch in enumerate(training_data_loader, 1):
104 |         input, target = Variable(batch[0]), Variable(batch[1], requires_grad=False)
105 |
106 |         if opt.cuda:
107 |             input = input.cuda()
108 |             target = target.cuda()
109 |
110 |         loss = criterion(model(input), target)
111 |         optimizer.zero_grad()
112 |         loss.backward()
113 |         nn.utils.clip_grad_norm_(model.parameters(), opt.clip)
114 |         optimizer.step()
115 |
116 |         if iteration % 100 == 0:
117 |             print("===> Epoch[{}]({}/{}): Loss: {:.10f}".format(epoch, iteration, len(training_data_loader), loss.item()))
118 |
119 | def save_checkpoint(model, epoch):
120 |     model_out_path = "checkpoint/" + "model_epoch_{}.pth".format(epoch)
121 |     state = {"epoch": epoch, "model": model}
122 |     if not os.path.exists("checkpoint/"):
123 |         os.makedirs("checkpoint/")
124 |
125 |     torch.save(state, model_out_path)
126 |
127 |     print("Checkpoint saved to {}".format(model_out_path))
128 |
129 | if __name__ == "__main__":
130 |     main()
--------------------------------------------------------------------------------
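main_vdsr.py uses a step decay schedule: `train()` calls `adjust_learning_rate(optimizer, epoch-1)`, which returns `opt.lr * 0.1 ** (epoch // opt.step)` and writes it into every parameter group. The standalone sketch below reproduces that arithmetic for a 1-based epoch; `step_decay_lr` is a hypothetical name, and `base_lr` and `step` stand in for `opt.lr` and `opt.step` with the script's defaults (`--lr 0.1`, `--step 10`).

```python
def step_decay_lr(epoch: int, base_lr: float = 0.1, step: int = 10) -> float:
    """Learning rate for a 1-based `epoch`, mirroring train()'s call
    adjust_learning_rate(optimizer, epoch - 1): divide by 10 every `step` epochs."""
    return base_lr * (0.1 ** ((epoch - 1) // step))


if __name__ == "__main__":
    for epoch in (1, 10, 11, 20, 21, 41, 50):
        print(epoch, step_decay_lr(epoch))
```

With the defaults and `--nEpochs 50`, epochs 1 to 10 train at 0.1, epochs 11 to 20 at 0.01, and so on down to 1e-5 for epochs 41 to 50. Because the rate is recomputed from the epoch number at the start of every epoch, a run restarted with `--resume` picks up the rate that matches its `start_epoch`.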
/model/model_epoch_50.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/model/model_epoch_50.pth
--------------------------------------------------------------------------------
/result/input.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/result/input.bmp
--------------------------------------------------------------------------------
/result/output.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/result/output.bmp
--------------------------------------------------------------------------------
/vdsr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from math import sqrt
4 |
5 | class Conv_ReLU_Block(nn.Module):
6 |     def __init__(self):
7 |         super(Conv_ReLU_Block, self).__init__()
8 |         self.conv = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
9 |         self.relu = nn.ReLU(inplace=True)
10 |
11 |     def forward(self, x):
12 |         return self.relu(self.conv(x))
13 |
14 | class Net(nn.Module):
15 |     def __init__(self):
16 |         super(Net, self).__init__()
17 |         self.residual_layer = self.make_layer(Conv_ReLU_Block, 18)
18 |         self.input = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
19 |         self.output = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, stride=1, padding=1, bias=False)
20 |         self.relu = nn.ReLU(inplace=True)
21 |
22 |         for m in self.modules():
23 |             if isinstance(m, nn.Conv2d):
24 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
25 |                 m.weight.data.normal_(0, sqrt(2. / n))
26 |
27 |     def make_layer(self, block, num_of_layer):
28 |         layers = []
29 |         for _ in range(num_of_layer):
30 |             layers.append(block())
31 |         return nn.Sequential(*layers)
32 |
33 |     def forward(self, x):
34 |         residual = x
35 |         out = self.relu(self.input(x))
36 |         out = self.residual_layer(out)
37 |         out = self.output(out)
38 |         out = torch.add(out, residual)
39 |         return out
40 |
--------------------------------------------------------------------------------
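`Net` stacks a 1-to-64 input convolution, 18 `Conv_ReLU_Block`s (64-to-64) and a 64-to-1 output convolution, all 3×3 with stride 1, padding 1 and no bias, then adds the input image back (global residual learning). The plain-Python sketch below derives three numbers from those hyperparameters: the total weight count, the receptive field of the 20 stacked convolutions, and the standard deviation that `normal_(0, sqrt(2. / n))` uses with `n = kernel_h * kernel_w * out_channels`. The constants and helper names are illustrative and not part of the repository.

```python
import math

KERNEL = 3
# (in_channels, out_channels) of each conv layer in Net, in forward order.
LAYERS = [(1, 64)] + [(64, 64)] * 18 + [(64, 1)]


def weight_count(layers, kernel=KERNEL):
    """Total conv weights; every layer uses bias=False, so there are no bias terms."""
    return sum(c_in * c_out * kernel * kernel for c_in, c_out in layers)


def receptive_field(num_layers, kernel=KERNEL):
    """Receptive field of stacked stride-1 convolutions: each adds kernel - 1 pixels."""
    return 1 + num_layers * (kernel - 1)


def init_std(c_out, kernel=KERNEL):
    """Std used by Net.__init__: sqrt(2 / (kernel * kernel * out_channels))."""
    return math.sqrt(2.0 / (kernel * kernel * c_out))


if __name__ == "__main__":
    print("conv layers:", len(LAYERS))
    print("weights:", weight_count(LAYERS))
    print("receptive field:", receptive_field(len(LAYERS)))
    print("init std, 64 output channels:", init_std(64))
    print("init std, 1 output channel:", init_std(1))
```

This gives 20 convolutional layers, 664,704 weights and a 41×41 receptive field. Because the initialization uses the layer's output channels, the single-channel output layer starts with a much larger standard deviation (about 0.47) than the hidden layers (about 0.059).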