├── LICENSE ├── README.md ├── Set5 ├── baby_GT.bmp ├── baby_GT_scale_2.bmp ├── baby_GT_scale_3.bmp ├── baby_GT_scale_4.bmp ├── bird_GT.bmp ├── bird_GT_scale_2.bmp ├── bird_GT_scale_3.bmp ├── bird_GT_scale_4.bmp ├── butterfly_GT.bmp ├── butterfly_GT_scale_2.bmp ├── butterfly_GT_scale_3.bmp ├── butterfly_GT_scale_4.bmp ├── head_GT.bmp ├── head_GT_scale_2.bmp ├── head_GT_scale_3.bmp ├── head_GT_scale_4.bmp ├── woman_GT.bmp ├── woman_GT_scale_2.bmp ├── woman_GT_scale_3.bmp └── woman_GT_scale_4.bmp ├── Set5_mat ├── baby_GT_x2.mat ├── baby_GT_x3.mat ├── baby_GT_x4.mat ├── bird_GT_x2.mat ├── bird_GT_x3.mat ├── bird_GT_x4.mat ├── butterfly_GT_x2.mat ├── butterfly_GT_x3.mat ├── butterfly_GT_x4.mat ├── head_GT_x2.mat ├── head_GT_x3.mat ├── head_GT_x4.mat ├── woman_GT_x2.mat ├── woman_GT_x3.mat └── woman_GT_x4.mat ├── VDSR-Demo.ipynb ├── data ├── generate_test_mat.m ├── generate_train.m ├── modcrop.m ├── store2hdf5.m └── train.h5 ├── dataset.py ├── demo.py ├── eval.py ├── main_vdsr.py ├── model └── model_epoch_50.pth ├── result ├── input.bmp └── output.bmp └── vdsr.py /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2017- Jiu XU 4 | Copyright (c) 2017- Rakuten, Inc 5 | Copyright (c) 2017- Rakuten Institute of Technology 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 
16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch VDSR 2 | Implementation of the CVPR 2016 paper ["Accurate Image Super-Resolution Using 3 | Very Deep Convolutional Networks"](http://cv.snu.ac.kr/research/VDSR/) in PyTorch 4 | 5 | ## Usage 6 | ### Training 7 | ``` 8 | usage: main_vdsr.py [-h] [--batchSize BATCHSIZE] [--nEpochs NEPOCHS] [--lr LR] 9 | [--step STEP] [--cuda] [--resume RESUME] 10 | [--start-epoch START_EPOCH] [--clip CLIP] [--threads THREADS] 11 | [--momentum MOMENTUM] [--weight-decay WEIGHT_DECAY] 12 | [--pretrained PRETRAINED] [--gpus GPUS] 13 | 14 | optional arguments: 15 | -h, --help Show this help message and exit 16 | --batchSize Training batch size 17 | --nEpochs Number of epochs to train for 18 | --lr Learning rate. Default=0.1 19 | --step Learning rate decay, Default: n=10 epochs 20 | --cuda Use cuda 21 | --resume Path to checkpoint 22 | --clip Clipping Gradients. 
Default=0.4 23 | --threads Number of threads for data loader to use Default=1 24 | --momentum Momentum, Default: 0.9 25 | --weight-decay Weight decay, Default: 1e-4 26 | --pretrained PRETRAINED 27 | path to pretrained model (default: none) 28 | --gpus GPUS gpu ids (default: 0) 29 | ``` 30 | An example of training usage is shown as follows: 31 | ``` 32 | python main_vdsr.py --cuda --gpus 0 33 | ``` 34 | 35 | ### Evaluation 36 | ``` 37 | usage: eval.py [-h] [--cuda] [--model MODEL] [--dataset DATASET] 38 | [--gpus GPUS] 39 | 40 | PyTorch VDSR Eval 41 | 42 | optional arguments: 43 | -h, --help show this help message and exit 44 | --cuda use cuda? 45 | --model MODEL model path 46 | --dataset DATASET dataset name, Default: Set5 47 | --gpus GPUS gpu ids (default: 0) 48 | ``` 49 | An example of evaluation usage is shown as follows: 50 | ``` 51 | python eval.py --cuda --dataset Set5 52 | ``` 53 | 54 | ### Demo 55 | ``` 56 | usage: demo.py [-h] [--cuda] [--model MODEL] [--image IMAGE] [--scale SCALE] [--gpus GPUS] 57 | 58 | optional arguments: 59 | -h, --help Show this help message and exit 60 | --cuda Use cuda 61 | --model Model path. Default=model/model_epoch_50.pth 62 | --image Image name. Default=butterfly_GT 63 | --scale Scale factor, Default: 4 64 | --gpus GPUS gpu ids (default: 0) 65 | ``` 66 | An example of demo usage is shown as follows: 67 | ``` 68 | python demo.py --model model/model_epoch_50.pth --image butterfly_GT --scale 4 --cuda 69 | ``` 70 | 71 | ### Prepare Training dataset 72 | - We provide a simple HDF5 training sample in the `data` folder with 'data' and 'label' keys. The training data is generated with MATLAB bicubic interpolation; please refer to [Code for Data Generation](https://github.com/twtygqyy/pytorch-vdsr/tree/master/data) to create training files. 
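
The patch layout used by the data-generation scripts can be sketched in a few lines of Python. This is only an illustration of what `data/modcrop.m` and the patch loop in `data/generate_train.m` appear to do (crop each image so its size is divisible by the scale, then cut 41x41 patches on a stride-41 grid). The function names `modcrop_size` and `patch_origins` are hypothetical and are not part of the repository.

```python
def modcrop_size(height, width, scale):
    """Largest (height, width) not exceeding the input that is divisible by scale."""
    return height - height % scale, width - width % scale


def patch_origins(height, width, patch=41, stride=41):
    """Top-left corners (0-based) of every full patch on a regular grid.

    Mirrors the 1-based MATLAB loops `1 : stride : hei-size_input+1`.
    """
    return [
        (y, x)
        for y in range(0, height - patch + 1, stride)
        for x in range(0, width - patch + 1, stride)
    ]


if __name__ == "__main__":
    # Assumed example: a 256x255 image prepared for scale 3.
    h, w = modcrop_size(256, 255, scale=3)
    print("cropped size:", h, w)
    print("number of 41x41 patches:", len(patch_origins(h, w)))
```

Because the stride equals the patch size, the patches do not overlap, and any remainder at the right and bottom edges is dropped.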
73 | 74 | ### Performance 75 | - We provide a pretrained VDSR model trained on [291](https://drive.google.com/open?id=1Rt3asDLuMgLuJvPA1YrhyjWhb97Ly742) images with data augmentation 76 | - No bias is used in this implementation, and gradient clipping is implemented differently from the paper 77 | - PSNR (dB) on Set5 78 | 79 | | Scale | VDSR Paper | VDSR PyTorch | 80 | | ------------- |:-------------:| -----:| 81 | | 2x | 37.53 | 37.65 | 82 | | 3x | 33.66 | 33.77 | 83 | | 4x | 31.35 | 31.45 | 84 | 85 | ### Result 86 | From left to right: ground truth, bicubic, and VDSR 87 |

88 | 89 | 90 | 91 |
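
The PSNR figures reported above come from the `PSNR` helper in `eval.py`, which compares luminance (Y) channels after shaving a border of `scale` pixels. A dependency-free sketch of the same calculation is given below. It assumes plain nested lists instead of NumPy arrays, and the `peak` parameter is an addition for illustration. The return value of 100 for identical images follows the repository's convention.

```python
import math


def psnr(pred, gt, shave_border=0, peak=255.0):
    """PSNR in dB between two equally sized 2-D lists of pixel values."""
    height, width = len(pred), len(pred[0])
    rows = range(shave_border, height - shave_border)
    cols = range(shave_border, width - shave_border)
    # Squared error over the region that remains after shaving the border.
    squared = [(pred[r][c] - gt[r][c]) ** 2 for r in rows for c in cols]
    rmse = math.sqrt(sum(squared) / len(squared))
    if rmse == 0:
        return 100.0  # convention used in eval.py for identical images
    return 20 * math.log10(peak / rmse)


if __name__ == "__main__":
    gt = [[10.0, 20.0], [30.0, 40.0]]
    pred = [[11.0, 21.0], [31.0, 41.0]]  # every pixel is off by one
    print(psnr(pred, gt))
```

A uniform error of one grey level gives an RMSE of 1, so the result equals `20 * log10(255)`.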

92 | -------------------------------------------------------------------------------- /Set5/baby_GT.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/baby_GT.bmp -------------------------------------------------------------------------------- /Set5/baby_GT_scale_2.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/baby_GT_scale_2.bmp -------------------------------------------------------------------------------- /Set5/baby_GT_scale_3.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/baby_GT_scale_3.bmp -------------------------------------------------------------------------------- /Set5/baby_GT_scale_4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/baby_GT_scale_4.bmp -------------------------------------------------------------------------------- /Set5/bird_GT.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/bird_GT.bmp -------------------------------------------------------------------------------- /Set5/bird_GT_scale_2.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/bird_GT_scale_2.bmp -------------------------------------------------------------------------------- /Set5/bird_GT_scale_3.bmp: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/bird_GT_scale_3.bmp -------------------------------------------------------------------------------- /Set5/bird_GT_scale_4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/bird_GT_scale_4.bmp -------------------------------------------------------------------------------- /Set5/butterfly_GT.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/butterfly_GT.bmp -------------------------------------------------------------------------------- /Set5/butterfly_GT_scale_2.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/butterfly_GT_scale_2.bmp -------------------------------------------------------------------------------- /Set5/butterfly_GT_scale_3.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/butterfly_GT_scale_3.bmp -------------------------------------------------------------------------------- /Set5/butterfly_GT_scale_4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/butterfly_GT_scale_4.bmp -------------------------------------------------------------------------------- /Set5/head_GT.bmp: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/head_GT.bmp -------------------------------------------------------------------------------- /Set5/head_GT_scale_2.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/head_GT_scale_2.bmp -------------------------------------------------------------------------------- /Set5/head_GT_scale_3.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/head_GT_scale_3.bmp -------------------------------------------------------------------------------- /Set5/head_GT_scale_4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/head_GT_scale_4.bmp -------------------------------------------------------------------------------- /Set5/woman_GT.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/woman_GT.bmp -------------------------------------------------------------------------------- /Set5/woman_GT_scale_2.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/woman_GT_scale_2.bmp -------------------------------------------------------------------------------- /Set5/woman_GT_scale_3.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/woman_GT_scale_3.bmp 
-------------------------------------------------------------------------------- /Set5/woman_GT_scale_4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5/woman_GT_scale_4.bmp -------------------------------------------------------------------------------- /Set5_mat/baby_GT_x2.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5_mat/baby_GT_x2.mat -------------------------------------------------------------------------------- /Set5_mat/baby_GT_x3.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5_mat/baby_GT_x3.mat -------------------------------------------------------------------------------- /Set5_mat/baby_GT_x4.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5_mat/baby_GT_x4.mat -------------------------------------------------------------------------------- /Set5_mat/bird_GT_x2.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5_mat/bird_GT_x2.mat -------------------------------------------------------------------------------- /Set5_mat/bird_GT_x3.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5_mat/bird_GT_x3.mat -------------------------------------------------------------------------------- /Set5_mat/bird_GT_x4.mat: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5_mat/bird_GT_x4.mat -------------------------------------------------------------------------------- /Set5_mat/butterfly_GT_x2.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5_mat/butterfly_GT_x2.mat -------------------------------------------------------------------------------- /Set5_mat/butterfly_GT_x3.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5_mat/butterfly_GT_x3.mat -------------------------------------------------------------------------------- /Set5_mat/butterfly_GT_x4.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5_mat/butterfly_GT_x4.mat -------------------------------------------------------------------------------- /Set5_mat/head_GT_x2.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5_mat/head_GT_x2.mat -------------------------------------------------------------------------------- /Set5_mat/head_GT_x3.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5_mat/head_GT_x3.mat -------------------------------------------------------------------------------- /Set5_mat/head_GT_x4.mat: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5_mat/head_GT_x4.mat -------------------------------------------------------------------------------- /Set5_mat/woman_GT_x2.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5_mat/woman_GT_x2.mat -------------------------------------------------------------------------------- /Set5_mat/woman_GT_x3.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5_mat/woman_GT_x3.mat -------------------------------------------------------------------------------- /Set5_mat/woman_GT_x4.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/Set5_mat/woman_GT_x4.mat -------------------------------------------------------------------------------- /data/generate_test_mat.m: -------------------------------------------------------------------------------- 1 | clear;close all; 2 | %% settings 3 | folder = 'Set5'; 4 | 5 | %% generate data 6 | filepaths = []; 7 | filepaths = [filepaths; dir(fullfile(folder, '*.bmp'))]; 8 | 9 | scale = [2, 3, 4]; 10 | 11 | for i = 1 : length(filepaths) 12 | im_gt = imread(fullfile(folder,filepaths(i).name)); 13 | for s = 1 : length(scale) 14 | im_gt = modcrop(im_gt, scale(s)); 15 | im_gt = double(im_gt); 16 | im_gt_ycbcr = rgb2ycbcr(im_gt / 255.0); 17 | im_gt_y = im_gt_ycbcr(:,:,1) * 255.0; 18 | im_l_ycbcr = imresize(im_gt_ycbcr,1/scale(s),'bicubic'); 19 | im_b_ycbcr = imresize(im_l_ycbcr,scale(s),'bicubic'); 20 | im_l_y = im_l_ycbcr(:,:,1) * 255.0; 21 | im_l = ycbcr2rgb(im_l_ycbcr) * 255.0; 22 | im_b_y = im_b_ycbcr(:,:,1) * 255.0; 23 | im_b = ycbcr2rgb(im_b_ycbcr) * 
255.0; 24 | last = length(filepaths(i).name)-4; 25 | filename = sprintf('Set5_mat/%s_x%s.mat',filepaths(i).name(1 : last),num2str(scale(s))); 26 | save(filename, 'im_gt_y', 'im_b_y', 'im_gt', 'im_b', 'im_l_ycbcr', 'im_l_y', 'im_l'); 27 | end 28 | end 29 | -------------------------------------------------------------------------------- /data/generate_train.m: -------------------------------------------------------------------------------- 1 | clear;close all; 2 | 3 | folder = 'path/to/train/folder'; 4 | 5 | savepath = 'train.h5'; 6 | size_input = 41; 7 | size_label = 41; 8 | stride = 41; 9 | 10 | %% scale factors 11 | scale = [2,3,4]; 12 | %% downsizing 13 | downsizes = [1,0.7,0.5]; 14 | 15 | %% initialization 16 | data = zeros(size_input, size_input, 1, 1); 17 | label = zeros(size_label, size_label, 1, 1); 18 | 19 | count = 0; 20 | margain = 0; 21 | 22 | %% generate data 23 | filepaths = []; 24 | filepaths = [filepaths; dir(fullfile(folder, '*.jpg'))]; 25 | filepaths = [filepaths; dir(fullfile(folder, '*.bmp'))]; 26 | 27 | for i = 1 : length(filepaths) 28 | for flip = 1: 3 29 | for degree = 1 : 4 30 | for s = 1 : length(scale) 31 | for downsize = 1 : length(downsizes) 32 | image = imread(fullfile(folder,filepaths(i).name)); 33 | 34 | if flip == 1 35 | image = flipdim(image ,1); 36 | end 37 | if flip == 2 38 | image = flipdim(image ,2); 39 | end 40 | 41 | image = imrotate(image, 90 * (degree - 1)); 42 | 43 | image = imresize(image,downsizes(downsize),'bicubic'); 44 | 45 | if size(image,3)==3 46 | image = rgb2ycbcr(image); 47 | image = im2double(image(:, :, 1)); 48 | 49 | im_label = modcrop(image, scale(s)); 50 | [hei,wid] = size(im_label); 51 | im_input = imresize(imresize(im_label,1/scale(s),'bicubic'),[hei,wid],'bicubic'); 52 | filepaths(i).name 53 | for x = 1 : stride : hei-size_input+1 54 | for y = 1 :stride : wid-size_input+1 55 | 56 | subim_input = im_input(x : x+size_input-1, y : y+size_input-1); 57 | subim_label = im_label(x : x+size_label-1, y : 
y+size_label-1); 58 | 59 | count=count+1; 60 | 61 | data(:, :, 1, count) = subim_input; 62 | label(:, :, 1, count) = subim_label; 63 | end 64 | end 65 | end 66 | end 67 | end 68 | end 69 | end 70 | end 71 | 72 | order = randperm(count); 73 | data = data(:, :, 1, order); 74 | label = label(:, :, 1, order); 75 | 76 | %% writing to HDF5 77 | chunksz = 64; 78 | created_flag = false; 79 | totalct = 0; 80 | 81 | for batchno = 1:floor(count/chunksz) 82 | batchno 83 | last_read=(batchno-1)*chunksz; 84 | batchdata = data(:,:,1,last_read+1:last_read+chunksz); 85 | batchlabs = label(:,:,1,last_read+1:last_read+chunksz); 86 | 87 | startloc = struct('dat',[1,1,1,totalct+1], 'lab', [1,1,1,totalct+1]); 88 | curr_dat_sz = store2hdf5(savepath, batchdata, batchlabs, ~created_flag, startloc, chunksz); 89 | created_flag = true; 90 | totalct = curr_dat_sz(end); 91 | end 92 | 93 | h5disp(savepath); 94 | -------------------------------------------------------------------------------- /data/modcrop.m: -------------------------------------------------------------------------------- 1 | function imgs = modcrop(imgs, modulo) 2 | if size(imgs,3)==1 3 | sz = size(imgs); 4 | sz = sz - mod(sz, modulo); 5 | imgs = imgs(1:sz(1), 1:sz(2)); 6 | else 7 | tmpsz = size(imgs); 8 | sz = tmpsz(1:2); 9 | sz = sz - mod(sz, modulo); 10 | imgs = imgs(1:sz(1), 1:sz(2),:); 11 | end 12 | 13 | -------------------------------------------------------------------------------- /data/store2hdf5.m: -------------------------------------------------------------------------------- 1 | function [curr_dat_sz, curr_lab_sz] = store2hdf5(filename, data, labels, create, startloc, chunksz) 2 | % *data* is W*H*C*N matrix of images should be normalized (e.g. 
to lie between 0 and 1) beforehand 3 | % *label* is D*N matrix of labels (D labels per sample) 4 | % *create* [0/1] specifies whether to create file newly or to append to previously created file, useful to store information in batches when a dataset is too big to be held in memory (default: 1) 5 | % *startloc* (point at which to start writing data). By default, 6 | % if create=1 (create mode), startloc.data=[1 1 1 1], and startloc.lab=[1 1]; 7 | % if create=0 (append mode), startloc.data=[1 1 1 K+1], and startloc.lab = [1 K+1]; where K is the current number of samples stored in the HDF 8 | % chunksz (used only in create mode), specifies number of samples to be stored per chunk (see HDF5 documentation on chunking) for creating HDF5 files with unbounded maximum size - TLDR; higher chunk sizes allow faster read-write operations 9 | 10 | % verify that format is right 11 | dat_dims=size(data); 12 | lab_dims=size(labels); 13 | num_samples=dat_dims(end); 14 | 15 | assert(lab_dims(end)==num_samples, 'Number of samples should be matched between data and labels'); 16 | 17 | if ~exist('create','var') 18 | create=true; 19 | end 20 | 21 | 22 | if create 23 | %fprintf('Creating dataset with %d samples\n', num_samples); 24 | if ~exist('chunksz', 'var') 25 | chunksz=1000; 26 | end 27 | if exist(filename, 'file') 28 | fprintf('Warning: replacing existing file %s \n', filename); 29 | delete(filename); 30 | end 31 | h5create(filename, '/data', [dat_dims(1:end-1) Inf], 'Datatype', 'single', 'ChunkSize', [dat_dims(1:end-1) chunksz]); % width, height, channels, number 32 | h5create(filename, '/label', [lab_dims(1:end-1) Inf], 'Datatype', 'single', 'ChunkSize', [lab_dims(1:end-1) chunksz]); % width, height, channels, number 33 | if ~exist('startloc','var') 34 | startloc.dat=[ones(1,length(dat_dims)-1), 1]; 35 | startloc.lab=[ones(1,length(lab_dims)-1), 1]; 36 | end 37 | else % append mode 38 | if ~exist('startloc','var') 39 | info=h5info(filename); 40 | 
prev_dat_sz=info.Datasets(1).Dataspace.Size; 41 | prev_lab_sz=info.Datasets(2).Dataspace.Size; 42 | assert(prev_dat_sz(1:end-1)==dat_dims(1:end-1), 'Data dimensions must match existing dimensions in dataset'); 43 | assert(prev_lab_sz(1:end-1)==lab_dims(1:end-1), 'Label dimensions must match existing dimensions in dataset'); 44 | startloc.dat=[ones(1,length(dat_dims)-1), prev_dat_sz(end)+1]; 45 | startloc.lab=[ones(1,length(lab_dims)-1), prev_lab_sz(end)+1]; 46 | end 47 | end 48 | 49 | if ~isempty(data) 50 | h5write(filename, '/data', single(data), startloc.dat, size(data)); 51 | h5write(filename, '/label', single(labels), startloc.lab, size(labels)); 52 | end 53 | 54 | if nargout 55 | info=h5info(filename); 56 | curr_dat_sz=info.Datasets(1).Dataspace.Size; 57 | curr_lab_sz=info.Datasets(2).Dataspace.Size; 58 | end 59 | end 60 | -------------------------------------------------------------------------------- /data/train.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/data/train.h5 -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import torch 3 | import h5py 4 | 5 | class DatasetFromHdf5(data.Dataset): 6 | def __init__(self, file_path): 7 | super(DatasetFromHdf5, self).__init__() 8 | hf = h5py.File(file_path, "r") # open read-only; the training file is never modified 9 | self.data = hf.get('data') 10 | self.target = hf.get('label') 11 | 12 | def __getitem__(self, index): 13 | return torch.from_numpy(self.data[index,:,:,:]).float(), torch.from_numpy(self.target[index,:,:,:]).float() 14 | 15 | def __len__(self): 16 | return self.data.shape[0] -------------------------------------------------------------------------------- /demo.py: 
-------------------------------------------------------------------------------- 1 | import argparse, 
os 2 | import torch 3 | from torch.autograd import Variable 4 | # scipy.ndimage.imread was removed in SciPy 1.2; images are loaded with PIL below 5 | from PIL import Image 6 | import numpy as np 7 | import time, math 8 | import matplotlib.pyplot as plt 9 | 10 | parser = argparse.ArgumentParser(description="PyTorch VDSR Demo") 11 | parser.add_argument("--cuda", action="store_true", help="use cuda?") 12 | parser.add_argument("--model", default="model/model_epoch_50.pth", type=str, help="model path") 13 | parser.add_argument("--image", default="butterfly_GT", type=str, help="image name") 14 | parser.add_argument("--scale", default=4, type=int, help="scale factor, Default: 4") 15 | parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)") 16 | 17 | def PSNR(pred, gt, shave_border=0): 18 | height, width = pred.shape[:2] 19 | pred = pred[shave_border:height - shave_border, shave_border:width - shave_border] 20 | gt = gt[shave_border:height - shave_border, shave_border:width - shave_border] 21 | imdff = pred - gt 22 | rmse = math.sqrt(np.mean(imdff ** 2)) 23 | if rmse == 0: 24 | return 100 25 | return 20 * math.log10(255.0 / rmse) 26 | 27 | def colorize(y, ycbcr): 28 | img = np.zeros((y.shape[0], y.shape[1], 3), np.uint8) 29 | img[:,:,0] = y 30 | img[:,:,1] = ycbcr[:,:,1] 31 | img[:,:,2] = ycbcr[:,:,2] 32 | img = Image.fromarray(img, "YCbCr").convert("RGB") 33 | return img 34 | 35 | opt = parser.parse_args() 36 | cuda = opt.cuda 37 | 38 | if cuda: 39 | print("=> use gpu id: '{}'".format(opt.gpus)) 40 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus 41 | if not torch.cuda.is_available(): 42 | raise Exception("No GPU found or Wrong gpu id, please run without --cuda") 43 | 44 | 45 | model = torch.load(opt.model, map_location=lambda storage, loc: storage)["model"] 46 | 47 | im_gt_ycbcr = np.array(Image.open("Set5/" + opt.image + ".bmp").convert("YCbCr")) 48 | im_b_ycbcr = np.array(Image.open("Set5/" + opt.image + "_scale_" + str(opt.scale) + 
".bmp").convert("YCbCr")) 49 | 50 | im_gt_y = im_gt_ycbcr[:,:,0].astype(float) 51 | im_b_y = 
im_b_ycbcr[:,:,0].astype(float) 52 | 53 | psnr_bicubic = PSNR(im_gt_y, im_b_y,shave_border=opt.scale) 54 | 55 | im_input = im_b_y/255. 56 | 57 | im_input = Variable(torch.from_numpy(im_input).float()).view(1, -1, im_input.shape[0], im_input.shape[1]) 58 | 59 | if cuda: 60 | model = model.cuda() 61 | im_input = im_input.cuda() 62 | else: 63 | model = model.cpu() 64 | 65 | start_time = time.time() 66 | out = model(im_input) 67 | elapsed_time = time.time() - start_time 68 | 69 | out = out.cpu() 70 | 71 | im_h_y = out.data[0].numpy().astype(np.float32) 72 | 73 | im_h_y = im_h_y * 255. 74 | im_h_y[im_h_y < 0] = 0 75 | im_h_y[im_h_y > 255.] = 255. 76 | 77 | psnr_predicted = PSNR(im_gt_y, im_h_y[0,:,:], shave_border=opt.scale) 78 | 79 | im_h = colorize(im_h_y[0,:,:], im_b_ycbcr) 80 | im_gt = Image.fromarray(im_gt_ycbcr, "YCbCr").convert("RGB") 81 | im_b = Image.fromarray(im_b_ycbcr, "YCbCr").convert("RGB") 82 | 83 | print("Scale=",opt.scale) 84 | print("PSNR_predicted=", psnr_predicted) 85 | print("PSNR_bicubic=", psnr_bicubic) 86 | print("It takes {}s for processing".format(elapsed_time)) 87 | 88 | fig = plt.figure() 89 | ax = plt.subplot(131) # integer position; string specs are rejected by recent Matplotlib 90 | ax.imshow(im_gt) 91 | ax.set_title("GT") 92 | 93 | ax = plt.subplot(132) 94 | ax.imshow(im_b) 95 | ax.set_title("Input(bicubic)") 96 | 97 | ax = plt.subplot(133) 98 | ax.imshow(im_h) 99 | ax.set_title("Output(vdsr)") 100 | plt.show() 101 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import argparse, os 2 | import torch 3 | from torch.autograd import Variable 4 | import numpy as np 5 | import time, math, glob 6 | import scipy.io as sio 7 | 8 | parser = argparse.ArgumentParser(description="PyTorch VDSR Eval") 9 | parser.add_argument("--cuda", action="store_true", help="use cuda?") 10 | parser.add_argument("--model", 
default="model/model_epoch_50.pth", type=str, help="model path") 11 | 
parser.add_argument("--dataset", default="Set5", type=str, help="dataset name, Default: Set5") 12 | parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)") 13 | 14 | def PSNR(pred, gt, shave_border=0): 15 | height, width = pred.shape[:2] 16 | pred = pred[shave_border:height - shave_border, shave_border:width - shave_border] 17 | gt = gt[shave_border:height - shave_border, shave_border:width - shave_border] 18 | imdff = pred - gt 19 | rmse = math.sqrt(np.mean(imdff ** 2)) 20 | if rmse == 0: 21 | return 100 22 | return 20 * math.log10(255.0 / rmse) 23 | 24 | opt = parser.parse_args() 25 | cuda = opt.cuda 26 | 27 | if cuda: 28 | print("=> use gpu id: '{}'".format(opt.gpus)) 29 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus 30 | if not torch.cuda.is_available(): 31 | raise Exception("No GPU found or Wrong gpu id, please run without --cuda") 32 | 33 | model = torch.load(opt.model, map_location=lambda storage, loc: storage)["model"] 34 | 35 | scales = [2,3,4] 36 | 37 | image_list = glob.glob(opt.dataset+"_mat/*.*") 38 | 39 | for scale in scales: 40 | avg_psnr_predicted = 0.0 41 | avg_psnr_bicubic = 0.0 42 | avg_elapsed_time = 0.0 43 | count = 0.0 44 | for image_name in image_list: 45 | if image_name.endswith("_x{}.mat".format(scale)): # match the scale suffix, not any digit in the path (e.g. the "4" in "Set14") 46 | count += 1 47 | print("Processing ", image_name) 48 | mat = sio.loadmat(image_name) # load the file once 49 | im_gt_y, im_b_y = mat['im_gt_y'], mat['im_b_y'] 50 | 51 | im_gt_y = im_gt_y.astype(float) 52 | im_b_y = im_b_y.astype(float) 53 | 54 | psnr_bicubic = PSNR(im_gt_y, im_b_y,shave_border=scale) 55 | avg_psnr_bicubic += psnr_bicubic 56 | 57 | im_input = im_b_y/255. 
58 | 59 | im_input = Variable(torch.from_numpy(im_input).float()).view(1, -1, im_input.shape[0], im_input.shape[1]) 60 | 61 | if cuda: 62 | model = model.cuda() 63 | im_input = im_input.cuda() 64 | else: 65 | model = model.cpu() 66 | 67 | start_time = time.time() 68 | HR = model(im_input) 69 | elapsed_time = time.time() - start_time 70 | avg_elapsed_time += elapsed_time 71 | 72 | HR = HR.cpu() 73 | 74 | im_h_y = HR.data[0].numpy().astype(np.float32) 75 | 76 | im_h_y = im_h_y * 255. 77 | im_h_y[im_h_y < 0] = 0 78 | im_h_y[im_h_y > 255.] = 255. 79 | im_h_y = im_h_y[0,:,:] 80 | 81 | psnr_predicted = PSNR(im_gt_y, im_h_y,shave_border=scale) 82 | avg_psnr_predicted += psnr_predicted 83 | 84 | print("Scale=", scale) 85 | print("Dataset=", opt.dataset) 86 | print("PSNR_predicted=", avg_psnr_predicted/count) 87 | print("PSNR_bicubic=", avg_psnr_bicubic/count) 88 | print("It takes average {}s for processing".format(avg_elapsed_time/count)) 89 | -------------------------------------------------------------------------------- /main_vdsr.py: -------------------------------------------------------------------------------- 1 | import argparse, os 2 | import torch 3 | import random 4 | import torch.backends.cudnn as cudnn 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | from torch.autograd import Variable 8 | from torch.utils.data import DataLoader 9 | from vdsr import Net 10 | from dataset import DatasetFromHdf5 11 | 12 | # Training settings 13 | parser = argparse.ArgumentParser(description="PyTorch VDSR") 14 | parser.add_argument("--batchSize", type=int, default=128, help="Training batch size") 15 | parser.add_argument("--nEpochs", type=int, default=50, help="Number of epochs to train for") 16 | parser.add_argument("--lr", type=float, default=0.1, help="Learning Rate. 
Default=0.1") 17 | parser.add_argument("--step", type=int, default=10, help="Sets the learning rate to the initial LR decayed by a factor of 10 every n epochs, Default: n=10") 18 | parser.add_argument("--cuda", action="store_true", help="Use cuda?") 19 | parser.add_argument("--resume", default="", type=str, help="Path to checkpoint (default: none)") 20 | parser.add_argument("--start-epoch", default=1, type=int, help="Manual epoch number (useful on restarts)") 21 | parser.add_argument("--clip", type=float, default=0.4, help="Clipping Gradients. Default=0.4") 22 | parser.add_argument("--threads", type=int, default=1, help="Number of threads for data loader to use, Default: 1") 23 | parser.add_argument("--momentum", default=0.9, type=float, help="Momentum, Default: 0.9") 24 | parser.add_argument("--weight-decay", "--wd", default=1e-4, type=float, help="Weight decay, Default: 1e-4") 25 | parser.add_argument('--pretrained', default='', type=str, help='path to pretrained model (default: none)') 26 | parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)") 27 | 28 | def main(): 29 | global opt, model 30 | opt = parser.parse_args() 31 | print(opt) 32 | 33 | cuda = opt.cuda 34 | if cuda: 35 | print("=> use gpu id: '{}'".format(opt.gpus)) 36 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus 37 | if not torch.cuda.is_available(): 38 | raise Exception("No GPU found or Wrong gpu id, please run without --cuda") 39 | 40 | opt.seed = random.randint(1, 10000) 41 | print("Random Seed: ", opt.seed) 42 | torch.manual_seed(opt.seed) 43 | if cuda: 44 | torch.cuda.manual_seed(opt.seed) 45 | 46 | cudnn.benchmark = True 47 | 48 | print("===> Loading datasets") 49 | train_set = DatasetFromHdf5("data/train.h5") 50 | training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True) 51 | 52 | print("===> Building model") 53 | model = Net() 54 | criterion = nn.MSELoss(size_average=False) 55 | 56 | print("===> Setting GPU") 57
| if cuda: 58 | model = model.cuda() 59 | criterion = criterion.cuda() 60 | 61 | # optionally resume from a checkpoint 62 | if opt.resume: 63 | if os.path.isfile(opt.resume): 64 | print("=> loading checkpoint '{}'".format(opt.resume)) 65 | checkpoint = torch.load(opt.resume) 66 | opt.start_epoch = checkpoint["epoch"] + 1 67 | model.load_state_dict(checkpoint["model"].state_dict()) 68 | else: 69 | print("=> no checkpoint found at '{}'".format(opt.resume)) 70 | 71 | # optionally copy weights from a checkpoint 72 | if opt.pretrained: 73 | if os.path.isfile(opt.pretrained): 74 | print("=> loading model '{}'".format(opt.pretrained)) 75 | weights = torch.load(opt.pretrained) 76 | model.load_state_dict(weights['model'].state_dict()) 77 | else: 78 | print("=> no model found at '{}'".format(opt.pretrained)) 79 | 80 | print("===> Setting Optimizer") 81 | optimizer = optim.SGD(model.parameters(), lr=opt.lr, momentum=opt.momentum, weight_decay=opt.weight_decay) 82 | 83 | print("===> Training") 84 | for epoch in range(opt.start_epoch, opt.nEpochs + 1): 85 | train(training_data_loader, optimizer, model, criterion, epoch) 86 | save_checkpoint(model, epoch) 87 | 88 | def adjust_learning_rate(optimizer, epoch): 89 | """Sets the learning rate to the initial LR decayed by a factor of 10 every opt.step epochs""" 90 | lr = opt.lr * (0.1 ** (epoch // opt.step)) 91 | return lr 92 | 93 | def train(training_data_loader, optimizer, model, criterion, epoch): 94 | lr = adjust_learning_rate(optimizer, epoch-1) 95 | 96 | for param_group in optimizer.param_groups: 97 | param_group["lr"] = lr 98 | 99 | print("Epoch = {}, lr = {}".format(epoch, optimizer.param_groups[0]["lr"])) 100 | 101 | model.train() 102 | 103 | for iteration, batch in enumerate(training_data_loader, 1): 104 | input, target = Variable(batch[0]), Variable(batch[1], requires_grad=False) 105 | 106 | if opt.cuda: 107 | input = input.cuda() 108 | target = target.cuda() 109 | 110 | loss = criterion(model(input), target) 111 | optimizer.zero_grad() 112
| loss.backward() 113 | nn.utils.clip_grad_norm(model.parameters(),opt.clip) 114 | optimizer.step() 115 | 116 | if iteration%100 == 0: 117 | print("===> Epoch[{}]({}/{}): Loss: {:.10f}".format(epoch, iteration, len(training_data_loader), loss.data[0])) 118 | 119 | def save_checkpoint(model, epoch): 120 | model_out_path = "checkpoint/" + "model_epoch_{}.pth".format(epoch) 121 | state = {"epoch": epoch ,"model": model} 122 | if not os.path.exists("checkpoint/"): 123 | os.makedirs("checkpoint/") 124 | 125 | torch.save(state, model_out_path) 126 | 127 | print("Checkpoint saved to {}".format(model_out_path)) 128 | 129 | if __name__ == "__main__": 130 | main() -------------------------------------------------------------------------------- /model/model_epoch_50.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/model/model_epoch_50.pth -------------------------------------------------------------------------------- /result/input.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/result/input.bmp -------------------------------------------------------------------------------- /result/output.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twtygqyy/pytorch-vdsr/514b021044018baf909e79f48392783daa592888/result/output.bmp -------------------------------------------------------------------------------- /vdsr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from math import sqrt 4 | 5 | class Conv_ReLU_Block(nn.Module): 6 | def __init__(self): 7 | super(Conv_ReLU_Block, self).__init__() 8 | self.conv = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, 
stride=1, padding=1, bias=False) 9 | self.relu = nn.ReLU(inplace=True) 10 | 11 | def forward(self, x): 12 | return self.relu(self.conv(x)) 13 | 14 | class Net(nn.Module): 15 | def __init__(self): 16 | super(Net, self).__init__() 17 | self.residual_layer = self.make_layer(Conv_ReLU_Block, 18) 18 | self.input = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False) 19 | self.output = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, stride=1, padding=1, bias=False) 20 | self.relu = nn.ReLU(inplace=True) 21 | 22 | for m in self.modules(): 23 | if isinstance(m, nn.Conv2d): 24 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 25 | m.weight.data.normal_(0, sqrt(2. / n)) 26 | 27 | def make_layer(self, block, num_of_layer): 28 | layers = [] 29 | for _ in range(num_of_layer): 30 | layers.append(block()) 31 | return nn.Sequential(*layers) 32 | 33 | def forward(self, x): 34 | residual = x 35 | out = self.relu(self.input(x)) 36 | out = self.residual_layer(out) 37 | out = self.output(out) 38 | out = torch.add(out,residual) 39 | return out 40 | --------------------------------------------------------------------------------
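The numbers implied by `vdsr.py` can be checked with a small standalone sketch. It uses only the standard library and does not import the repository code. The helper names `he_std` and `receptive_field` are hypothetical, introduced here for illustration; the formulas mirror the `sqrt(2. / n)` initialization in `Net.__init__` and the usual receptive-field arithmetic for stacked stride-1 convolutions.

```python
from math import sqrt


def he_std(kernel_size: int, out_channels: int) -> float:
    """Std used for conv weights in Net.__init__: sqrt(2 / n), n = k * k * out_channels."""
    n = kernel_size * kernel_size * out_channels
    return sqrt(2.0 / n)


def receptive_field(num_conv_layers: int, kernel_size: int = 3) -> int:
    """Receptive field of a stack of stride-1 convolutions with the same kernel size."""
    return 1 + num_conv_layers * (kernel_size - 1)


# Net stacks one input conv, 18 Conv_ReLU_Block convs, and one output conv.
depth = 1 + 18 + 1

print("conv layers:", depth)
print("receptive field:", receptive_field(depth))
print("std for 64-channel convs:", he_std(3, 64))
print("std for the 1-channel output conv:", he_std(3, 1))
```

Because the initialization loop uses `m.out_channels`, the final 1-channel convolution gets a noticeably larger standard deviation than the 64-channel layers.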