├── install-instructions ├── testing.PNG ├── training.PNG ├── create_train.PNG ├── qualitative.PNG ├── quantitative.PNG ├── create_train_result.PNG └── network_architecture.PNG ├── .gitignore ├── matlab_scripts ├── modcrop.m ├── generate_test_video.m ├── store2hdf5.m └── generate_train_video.m ├── loss.py ├── colorize.py ├── pytorch_ssim.py ├── test.py ├── SR_datasets.py ├── train.py ├── model.py ├── README.md └── solver.py /install-instructions/testing.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thangvubk/video-super-resolution/HEAD/install-instructions/testing.PNG -------------------------------------------------------------------------------- /install-instructions/training.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thangvubk/video-super-resolution/HEAD/install-instructions/training.PNG -------------------------------------------------------------------------------- /install-instructions/create_train.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thangvubk/video-super-resolution/HEAD/install-instructions/create_train.PNG -------------------------------------------------------------------------------- /install-instructions/qualitative.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thangvubk/video-super-resolution/HEAD/install-instructions/qualitative.PNG -------------------------------------------------------------------------------- /install-instructions/quantitative.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thangvubk/video-super-resolution/HEAD/install-instructions/quantitative.PNG -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .ipynb_checkpoints 3 | preprocessed_data/ 4 | data/ 5 | Results/ 6 | *.m~ 7 | *.ipynb 8 | check_point/ 9 | results/ 10 | -------------------------------------------------------------------------------- /install-instructions/create_train_result.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thangvubk/video-super-resolution/HEAD/install-instructions/create_train_result.PNG -------------------------------------------------------------------------------- /install-instructions/network_architecture.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thangvubk/video-super-resolution/HEAD/install-instructions/network_architecture.PNG -------------------------------------------------------------------------------- /matlab_scripts/modcrop.m: -------------------------------------------------------------------------------- 1 | function img = modcrop(img, scale) 2 | % The img size should be divisible by the scale, to align interpolation 3 | sz = size(img); 4 | sz = sz - mod(sz, scale); 5 | img = img(1:sz(1), 1:sz(2)); 6 | end 7 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import pytorch_ssim 2 | import torch.nn as nn 3 | 4 | MSE_and_SSIM_model = ['VRES', 'VRES5', 'VRES10', 'VRES7', 'VRES15'] 5 | 6 | 7 | 
def get_loss_fn(model_name): 8 | if model_name in MSE_and_SSIM_model: 9 | return MSE_and_SSIM_loss() 10 | else: 11 | return nn.MSELoss() 12 | 13 | 14 | class MSE_and_SSIM_loss(nn.Module): 15 | def __init__(self, alpha=0.9): 16 | super(MSE_and_SSIM_loss, self).__init__() 17 | self.MSE = nn.MSELoss() 18 | self.SSIM = pytorch_ssim.SSIM() 19 | self.alpha = alpha 20 | 21 | def forward(self, img1, img2): 22 | loss = (self.alpha*self.MSE(img1, img2) 23 | + (1 - self.alpha)*(1 - self.SSIM(img1, img2))) 24 | return loss 25 | -------------------------------------------------------------------------------- /colorize.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import os.path as osp 4 | import argparse 5 | 6 | import cv2 7 | from tqdm import tqdm 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--org-img-path', default='./data/test/IndMya/') 11 | parser.add_argument('--out-img-path', default='./results/VRES/3x') 12 | args = parser.parse_args() 13 | 14 | save_img_path = osp.join(args.out_img_path, 'color') 15 | if not osp.exists(save_img_path): 16 | os.makedirs(save_img_path) 17 | org_imgs = glob.glob(osp.join(args.org_img_path, '*')) 18 | out_imgs = glob.glob(osp.join(args.out_img_path, '*.png')) 19 | org_imgs.sort() 20 | out_imgs.sort() 21 | assert len(org_imgs) == len(out_imgs) 22 | 23 | for i in tqdm(range(len(org_imgs))): 24 | # Get center frame (frame3) of org img 25 | org_img = glob.glob(osp.join(org_imgs[i], '*f3*')) 26 | org_img = cv2.imread(org_img[0]) 27 | org_img_ycbcr = cv2.cvtColor(org_img, cv2.COLOR_BGR2YCR_CB) 28 | 29 | # Get output img in grayscale 30 | out_img = cv2.imread(out_imgs[i], 0) 31 | 32 | # Merge the output Y channel with the original color channels 33 | h, w = out_img.shape 34 | save_img = org_img_ycbcr[:h, :w, :] 35 | save_img[:, :, 0] = out_img 36 | save_img = cv2.cvtColor(save_img, cv2.COLOR_YCrCb2BGR) 37 | save_img_name = osp.join(save_img_path, osp.basename(out_imgs[i])) 38 | cv2.imwrite(save_img_name, save_img) 39 | -------------------------------------------------------------------------------- /matlab_scripts/generate_test_video.m: -------------------------------------------------------------------------------- 1 | clear; close all; 2 | 3 | %% Configuration 4 | % NOTE: you can modify this part 5 | test_set = 'IndMya'; %(IndMya, vid4/city, vid4/walk, vid4/calendar, vid4/foliage) 6 | scale = 3; % (2, 3, 4) 7 | 8 | %% Create save path for high resolution and low resolution images based on config 9 | % NOTE: you should NOT modify the following parts 10 | disp(sprintf('%10s: %s', 'Test set', test_set)); 11 | disp(sprintf('%10s: %d', 'Scale', scale)); 12 | 13 | scale_dir = strcat(int2str(scale), 'x'); 14 | 15 | % example 16 | % read_path = '../data/test/myanmar/' 17 | % save_path = '../preprocessed_data/test/myanmar/3x/' 18 | read_path = fullfile('../data', 'test', test_set); 19 | save_path = fullfile('../preprocessed_data', 'test', test_set, scale_dir); 20 | 21 | if exist(save_path, 'dir') ~= 7 22 | mkdir(save_path) 23 | end 24 | 25 | is_init_data = true; 26 | 27 | % get folders in read_path 28 | dirs = dir(read_path); 29 | 30 | 31 | count = 0; 32 | for i_dir = 1 : length(dirs) 33 | scene_dir = dirs(i_dir).name; 34 | if scene_dir(1) ~= 's' % valid folders begin with 's' 35 | continue 36 | end 37 | 38 | disp(sprintf('processing dir: %s', scene_dir)); 39 | 40 | count = count + 1; 41 | 42 | filepaths = dir(fullfile(read_path, scene_dir, '*.bmp')); 43 | 44 | for i = 1 : length(filepaths) 45 | image = imread(fullfile(read_path, 
scene_dir, filepaths(i).name)); 46 | if size(image, 3) == 3 47 | image_ycbcr = rgb2ycbcr(image); 48 | image_y = image_ycbcr(:, :, 1); 49 | end 50 | hr_im = im2double(image_y); 51 | hr_im = modcrop(hr_im, scale); 52 | [hei, wid] = size(hr_im); 53 | lr_im = imresize(hr_im,1/scale,'bicubic'); 54 | lr_im = imresize(lr_im ,[hei, wid],'bicubic'); 55 | 56 | if is_init_data 57 | data = zeros(hei, wid, 5, 1); 58 | label = zeros(hei, wid, 5, 1); 59 | is_init_data = false; 60 | end 61 | 62 | data(:, :, i, count) = lr_im; 63 | label(:, :, i, count) = hr_im; 64 | end 65 | end 66 | 67 | %% writing to HDF5 68 | chunksz = 2; 69 | created_flag = false; 70 | totalct = 0; 71 | 72 | for batchno = 1:floor((count)/chunksz) 73 | last_read=(batchno-1)*chunksz; 74 | batchdata = data(:,:,:,last_read+1:last_read+chunksz); 75 | batchlabs = label(:,:,:,last_read+1:last_read+chunksz); 76 | 77 | startloc = struct('dat',[1,1,1,totalct+1], 'lab', [1,1,1,totalct+1]); 78 | curr_dat_sz = store2hdf5(fullfile(save_path, 'dataset.h5'), batchdata, batchlabs, ~created_flag, startloc, chunksz); 79 | created_flag = true; 80 | totalct = curr_dat_sz(end); 81 | end 82 | h5disp(fullfile(save_path, 'dataset.h5')); 83 | 84 | -------------------------------------------------------------------------------- /pytorch_ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 15 | return window 16 | 17 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 28 | 29 | C1 = 0.01**2 30 | C2 = 0.03**2 31 | 32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 33 | 34 | if size_average: 35 | return ssim_map.mean() 36 | else: 37 | return ssim_map.mean(1).mean(1).mean(1) 38 | 39 | class SSIM(torch.nn.Module): 40 | def __init__(self, window_size = 11, size_average = True): 41 | super(SSIM, self).__init__() 42 | self.window_size = window_size 43 | self.size_average = size_average 44 | self.channel = 1 45 | self.window = create_window(window_size, self.channel) 46 | 47 | def forward(self, img1, img2): 48 | (_, channel, _, _) = img1.size() 49 | 50 | if channel == self.channel and self.window.data.type() == img1.data.type(): 51 | window = self.window 52 | else: 53 | window = create_window(self.window_size, channel) 54 | 55 | if img1.is_cuda: 56 | window = window.cuda(img1.get_device()) 57 | window = window.type_as(img1) 58 | 59 | self.window = 
window 60 | self.channel = channel 61 | 62 | 63 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 64 | 65 | def ssim(img1, img2, window_size = 11, size_average = True): 66 | (_, channel, _, _) = img1.size() 67 | window = create_window(window_size, channel) 68 | 69 | if img1.is_cuda: 70 | window = window.cuda(img1.get_device()) 71 | window = window.type_as(img1) 72 | 73 | return _ssim(img1, img2, window, window_size, channel, size_average) 74 | -------------------------------------------------------------------------------- /matlab_scripts/store2hdf5.m: -------------------------------------------------------------------------------- 1 | function [curr_dat_sz, curr_lab_sz] = store2hdf5(filename, data, labels, create, startloc, chunksz) 2 | % *data* is W*H*C*N matrix of images should be normalized (e.g. to lie between 0 and 1) beforehand 3 | % *label* is D*N matrix of labels (D labels per sample) 4 | % *create* [0/1] specifies whether to create file newly or to append to previously created file, useful to store information in batches when a dataset is too big to be held in memory (default: 1) 5 | % *startloc* (point at which to start writing data). By default, 6 | % if create=1 (create mode), startloc.data=[1 1 1 1], and startloc.lab=[1 1]; 7 | % if create=0 (append mode), startloc.data=[1 1 1 K+1], and startloc.lab = [1 K+1]; where K is the current number of samples stored in the HDF 8 | % chunksz (used only in create mode), specifies number of samples to be stored per chunk (see HDF5 documentation on chunking) for creating HDF5 files with unbounded maximum size - TLDR; higher chunk sizes allow faster read-write operations 9 | 10 | % verify that format is right 11 | dat_dims=size(data); 12 | lab_dims=size(labels); 13 | num_samples=dat_dims(end); 14 | 15 | assert(lab_dims(end)==num_samples, 'Number of samples should be matched between data and labels'); 16 | 17 | if ~exist('create','var') 18 | create=true; 19 | end 20 | 21 | 22 | if create 23 | %fprintf('Creating dataset with %d samples\n', num_samples); 24 | if ~exist('chunksz', 'var') 25 | chunksz=1000; 26 | end 27 | if exist(filename, 'file') 28 | fprintf('Warning: replacing existing file %s \n', filename); 29 | delete(filename); 30 | end 31 | h5create(filename, '/data', [dat_dims(1:end-1) Inf], 'Datatype', 'single', 'ChunkSize', [dat_dims(1:end-1) chunksz]); % width, height, channels, number 32 | h5create(filename, '/label', [lab_dims(1:end-1) Inf], 'Datatype', 'single', 'ChunkSize', [lab_dims(1:end-1) chunksz]); % width, height, channels, number 33 | if ~exist('startloc','var') 34 | startloc.dat=[ones(1,length(dat_dims)-1), 1]; 35 | startloc.lab=[ones(1,length(lab_dims)-1), 1]; 36 | end 37 | else % append mode 38 | if ~exist('startloc','var') 39 | info=h5info(filename); 40 | prev_dat_sz=info.Datasets(1).Dataspace.Size; 41 | prev_lab_sz=info.Datasets(2).Dataspace.Size; 42 | assert(prev_dat_sz(1:end-1)==dat_dims(1:end-1), 'Data dimensions must match existing dimensions in dataset'); 43 | assert(prev_lab_sz(1:end-1)==lab_dims(1:end-1), 'Label dimensions must match existing dimensions in dataset'); 44 | startloc.dat=[ones(1,length(dat_dims)-1), prev_dat_sz(end)+1]; 45 | startloc.lab=[ones(1,length(lab_dims)-1), prev_lab_sz(end)+1]; 46 | end 47 | end 48 | 49 | if ~isempty(data) 50 | h5write(filename, '/data', single(data), startloc.dat, size(data)); 51 | h5write(filename, '/label', single(labels), startloc.lab, size(labels)); 52 | end 53 | 54 | if nargout 55 | info=h5info(filename); 56 | 
curr_dat_sz=info.Datasets(1).Dataspace.Size; 57 | curr_lab_sz=info.Datasets(2).Dataspace.Size; 58 | end 59 | end 60 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import cv2 5 | import numpy as np 6 | from SR_datasets import DatasetFactory 7 | from model import ModelFactory 8 | from solver import Solver 9 | 10 | 11 | description = 'Video Super Resolution pytorch implementation' 12 | parser = argparse.ArgumentParser(description=description) 13 | parser.add_argument('-m', '--model', metavar='M', type=str, default='VRES', 14 | help='network architecture. Default VRES') 15 | parser.add_argument('--model_path', 16 | default='./check_point/VRES/3x/best_model.pt') 17 | parser.add_argument('-s', '--scale', metavar='S', type=int, default=3, 18 | help='interpolation scale. Default 3') 19 | parser.add_argument('--test-set', metavar='NAME', type=str, default='IndMya', 20 | help='dataset for testing. Default IndMya') 21 | args = parser.parse_args() 22 | 23 | 24 | def get_full_path(scale, test_set): 25 | """ 26 | Get full path of data based on configs and target path 27 | example: preprocessed_data/test/IndMya/3x 28 | """ 29 | scale_path = str(scale) + 'x' 30 | return os.path.join('preprocessed_data/test', test_set, scale_path) 31 | 32 | 33 | def display_config(): 34 | print('############################################################') 35 | print('# Video Super Resolution - Pytorch implementation #') 36 | print('# by Thang Vu (thangvubk@gmail.com) #') 37 | print('############################################################') 38 | print('') 39 | print('-------YOUR SETTINGS-------') 40 | for arg in vars(args): 41 | print("%15s: %s" % (str(arg), str(getattr(args, arg)))) 42 | print('') 43 | 44 | 45 | def export(scale, model_name, stats, outputs): 46 | path = os.path.join('results', model_name, str(scale) + 'x') 47 | 48 | if not os.path.exists(path): 49 | os.makedirs(path) 50 | 51 | for i, img in enumerate(outputs): 52 | img_name = os.path.join(path, model_name + '_output%03d.png' % i) 53 | cv2.imwrite(img_name, img) 54 | 55 | with open(os.path.join(path, model_name + '.txt'), 'w') as f: 56 | psnrs, ssims, proc_time = stats 57 | f.write('\t\tPSNR\tSSIM\tTime\n') 58 | for i in range(len(psnrs)): 59 | print('Img%d: PSNR: %.3f SSIM: %.3f Time: %.4f' 60 | % (i, psnrs[i], ssims[i], proc_time[i])) 61 | f.write('Img%d:\t%.3f\t%.3f\t%.4f\n' 62 | % (i, psnrs[i], ssims[i], proc_time[i])) 63 | print('Average test psnr: %.3fdB' % np.mean(psnrs)) 64 | print('Average test ssim: %.3f' % np.mean(ssims)) 65 | print('Finish!!!') 66 | 67 | 68 | def main(): 69 | display_config() 70 | 71 | dataset_root = get_full_path(args.scale, args.test_set) 72 | 73 | print('Constructing dataset...') 74 | dataset_factory = DatasetFactory() 75 | test_dataset = dataset_factory.create_dataset(args.model, 76 | dataset_root) 77 | 78 | model_factory = ModelFactory() 79 | model = model_factory.create_model(args.model) 80 | 81 | check_point = os.path.join( 82 | 'check_point', model.name, str(args.scale) + 'x') 83 | solver = Solver(model, check_point) 84 | 85 | print('Testing...') 86 | stats, outputs = solver.test(test_dataset, args.model_path) 87 | export(args.scale, model.name, stats, outputs) 88 | 89 | 90 | if __name__ == '__main__': 91 | main() 92 | -------------------------------------------------------------------------------- /matlab_scripts/generate_train_video.m: 
-------------------------------------------------------------------------------- 1 | clear; close all; 2 | 3 | %% Configuration 4 | % NOTE: you can modify this part 5 | train_set = 'train'; 6 | scale = 3; 7 | hr_size = 48; 8 | stride = 36; 9 | 10 | %% Create save path for high resolution and low resolution images based on config 11 | % NOTE: you should NOT modify the following parts 12 | use_upscale_interpolation = true; 13 | disp(sprintf('%10s: %s', 'Train set', train_set)); 14 | disp(sprintf('%10s: %d', 'Scale', scale)); 15 | 16 | scale_dir = strcat(int2str(scale), 'x'); 17 | 18 | % example: 19 | % read_path = '../data/train' 20 | % save_path = '../preprocessed_data/train/3x/' 21 | read_path = fullfile('../data', train_set); 22 | save_path = fullfile('../preprocessed_data', train_set, scale_dir); 23 | 24 | if exist(save_path, 'dir') ~= 7 25 | mkdir(save_path) 26 | end 27 | 28 | % count variables to order the data 29 | base_count = 0; 30 | count = 0; 31 | 32 | data = zeros(hr_size, hr_size, 5, 1); 33 | label = zeros(hr_size, hr_size, 5, 1); 34 | 35 | dirs = dir(read_path); 36 | for i_dir = 1 : length(dirs) 37 | is_switch_dir = true; 38 | scene_dir = dirs(i_dir).name; 39 | if scene_dir(1) ~= 's' % valid folders begin with 's' 40 | continue 41 | end 42 | disp(sprintf('processing dir: %s', scene_dir)); 43 | 44 | filepaths = dir(fullfile(read_path, scene_dir, '*.bmp')); 45 | 46 | for i = 1 : length(filepaths) 47 | % when switching dirs, add count to base_count 48 | if is_switch_dir 49 | base_count = base_count + count; 50 | is_switch_dir = false; 51 | end 52 | 53 | % reset count 54 | count = 0; 55 | 56 | image = imread(fullfile(read_path, scene_dir, filepaths(i).name)); 57 | if size(image, 3) == 3 58 | image_ycbcr = rgb2ycbcr(image); 59 | image_y = image_ycbcr(:, :, 1); 60 | end 61 | hr_im = im2double(image_y); 62 | hr_im = modcrop(hr_im, scale); 63 | [hei, wid] = size(hr_im); 64 | lr_im = imresize(hr_im,1/scale,'bicubic'); 65 | 66 | if use_upscale_interpolation 67 | lr_im = imresize(lr_im ,[hei, wid],'bicubic'); 68 | end 69 | 70 | for h = 1 : stride : hei - hr_size + 1 71 | for w = 1 : stride : wid - hr_size + 1 72 | 73 | hr_sub_im = hr_im(h:hr_size+h-1, w:hr_size+w-1); 74 | 75 | if use_upscale_interpolation 76 | lr_sub_im = lr_im(h:hr_size+h-1, w:hr_size+w-1); 77 | else 78 | lr_sub_im = lr_im(uint32((h-1)/scale + 1):uint32((hr_size+h-1)/scale), uint32((w-1)/scale+1):uint32((hr_size+w-1)/scale)); 79 | end 80 | 81 | count = count + 1; 82 | 83 | data(:, :, i, base_count + count) = lr_sub_im; 84 | label(:, :, i, base_count + count) = hr_sub_im; 85 | end 86 | end 87 | 88 | 89 | end 90 | end 91 | 92 | 93 | %% writing to HDF5 94 | chunksz = 32; 95 | created_flag = false; 96 | totalct = 0; 97 | 98 | for batchno = 1:floor((base_count + count)/chunksz) 99 | last_read=(batchno-1)*chunksz; 100 | batchdata = data(:,:,:,last_read+1:last_read+chunksz); 101 | batchlabs = label(:,:,:,last_read+1:last_read+chunksz); 102 | 103 | startloc = struct('dat',[1,1,1,totalct+1], 'lab', [1,1,1,totalct+1]); 104 | curr_dat_sz = store2hdf5(fullfile(save_path, 'dataset.h5'), batchdata, batchlabs, ~created_flag, startloc, chunksz); 105 | created_flag = true; 106 | totalct = curr_dat_sz(end); 107 | end 108 | h5disp(fullfile(save_path, 'dataset.h5')); 109 | 110 | -------------------------------------------------------------------------------- /SR_datasets.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | import numpy as np 4 | import h5py 5 | 
import torch 6 | 7 | from torch.utils.data import Dataset 8 | 9 | 10 | class DatasetFactory(object): 11 | 12 | def create_dataset(self, name, root, scale=3): 13 | if name == 'VSRCNN': 14 | return VSRCNN_dataset(root) 15 | elif name == 'VRES': 16 | return VRES_dataset(root) 17 | elif name == 'MFCNN': 18 | return MFCNN_dataset(root) 19 | elif name == 'VRES3D': 20 | return VRES3D_dataset(root) 21 | elif name == 'VRES10': 22 | return VRES10_dataset(root) 23 | elif name == 'VRES5': 24 | return VRES5_dataset(root) 25 | elif name == 'VRES15': 26 | return VRES15_dataset(root) 27 | elif name == 'VRES7': 28 | return VRES7_dataset(root) 29 | else: 30 | raise Exception('Unknown dataset {}'.format(name)) 31 | 32 | 33 | class VRES_dataset(Dataset): 34 | 35 | def __init__(self, root): 36 | root = os.path.join(root, 'dataset.h5') 37 | f = h5py.File(root) 38 | self.low_res_imgs = f.get('data') 39 | self.high_res_imgs = f.get('label') 40 | 41 | self.low_res_imgs = np.array(self.low_res_imgs) 42 | self.high_res_imgs = np.array(self.high_res_imgs) 43 | 44 | def __len__(self): 45 | return self.high_res_imgs.shape[0] 46 | 47 | def __getitem__(self, idx): 48 | center = 2 49 | 50 | low_res_imgs = self.low_res_imgs[idx] 51 | high_res_imgs = self.high_res_imgs[idx] 52 | 53 | # h5 in matlab is (H, W, C) 54 | # h5 in python is (C, W, H) 55 | # we need to transpose to (C, H, W) 56 | low_res_imgs = low_res_imgs.transpose(0, 2, 1) 57 | high_res_imgs = high_res_imgs.transpose(0, 2, 1) 58 | 59 | high_res_img = high_res_imgs[center] 60 | high_res_img = high_res_img[np.newaxis, :, :] 61 | 62 | low_res_imgs -= 0.5 63 | high_res_img -= 0.5 64 | 65 | # transform np image to torch tensor 66 | low_res_imgs = torch.Tensor(low_res_imgs) 67 | high_res_img = torch.Tensor(high_res_img) 68 | 69 | return low_res_imgs, high_res_img 70 | 71 | 72 | class VSRCNN_dataset(VRES_dataset): 73 | def __getitem__(self, idx): 74 | center = 2 75 | low_res_imgs = self.low_res_imgs[idx] 76 | high_res_imgs = self.high_res_imgs[idx] 77 | 78 | # h5 in matlab is (H, W, C) 79 | # h5 in python is (C, W, H) 80 | # we need to transpose to (C, H, W) 81 | low_res_imgs = low_res_imgs.transpose(0, 2, 1) 82 | high_res_imgs = high_res_imgs.transpose(0, 2, 1) 83 | 84 | low_res_img = low_res_imgs[center] 85 | high_res_img = high_res_imgs[center] 86 | 87 | low_res_img = low_res_img[np.newaxis, :, :] 88 | high_res_img = high_res_img[np.newaxis, :, :] 89 | 90 | low_res_img -= 0.5 91 | high_res_img -= 0.5 92 | 93 | # transform np image to torch tensor 94 | low_res_img = torch.Tensor(low_res_img) 95 | high_res_img = torch.Tensor(high_res_img) 96 | 97 | return low_res_img, high_res_img 98 | 99 | 100 | class MFCNN_dataset(VRES_dataset): 101 | pass 102 | 103 | 104 | class VRES3D_dataset(VRES_dataset): 105 | pass 106 | 107 | 108 | class VRES10_dataset(VRES_dataset): 109 | pass 110 | 111 | 112 | class VRES5_dataset(VRES_dataset): 113 | pass 114 | 115 | 116 | class VRES15_dataset(VRES_dataset): 117 | pass 118 | 119 | 120 | class VRES7_dataset(VRES_dataset): 121 | pass 122 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | from SR_datasets import DatasetFactory 5 | from model import ModelFactory 6 | from solver import Solver 7 | from loss import get_loss_fn 8 | 9 | 10 | description = 'Video Super Resolution pytorch implementation' 11 | 12 | parser = argparse.ArgumentParser(description=description) 13 | 14 | 
parser.add_argument('-m', '--model', metavar='M', type=str, default='VRES', 15 | help='network architecture. Default VRES') 16 | parser.add_argument('-s', '--scale', metavar='S', type=int, default=3, 17 | help='interpolation scale. Default 3') 18 | parser.add_argument('--train-set', metavar='T', type=str, default='train', 19 | help='data set for training. Default train') 20 | parser.add_argument('--val-set', metavar='V', type=str, default='test/IndMya', 21 | help='data set for validation. Default test/IndMya') 22 | parser.add_argument('-b', '--batch-size', metavar='B', type=int, default=100, 23 | help='batch size used for training. Default 100') 24 | parser.add_argument('-l', '--learning-rate', metavar='L', type=float, 25 | default=1e-3, help='learning rate. Default 1e-3') 26 | parser.add_argument('-n', '--num-epochs', metavar='N', type=int, default=50, 27 | help='number of training epochs. Default 50') 28 | parser.add_argument('-f', '--fine-tune', dest='fine_tune', action='store_true', 29 | help='fine tune the model under check_point dir,\ 30 | instead of training from scratch. Default False') 31 | parser.add_argument('-v', '--verbose', dest='verbose', action='store_true', 32 | help='print training information. Default False') 33 | 34 | args = parser.parse_args() 35 | 36 | 37 | def get_full_path(scale, train_set): 38 | """ 39 | Get full path of data based on configs and target path 40 | example: preprocessed_data/train/3x 41 | """ 42 | scale_path = str(scale) + 'x' 43 | return os.path.join('preprocessed_data', train_set, scale_path) 44 | 45 | 46 | def display_config(): 47 | print('############################################################') 48 | print('# Video Super Resolution - Pytorch implementation #') 49 | print('# by Thang Vu (thangvubk@gmail.com) #') 50 | print('############################################################') 51 | print('') 52 | print('-------YOUR SETTINGS-------') 53 | for arg in vars(args): 54 | print("%15s: %s" % (str(arg), str(getattr(args, arg)))) 55 | print('') 56 | 57 | 58 | def main(): 59 | display_config() 60 | 61 | train_root = get_full_path(args.scale, args.train_set) 62 | val_root = get_full_path(args.scale, args.val_set) 63 | 64 | print('Constructing dataset...') 65 | dataset_factory = DatasetFactory() 66 | train_dataset = dataset_factory.create_dataset(args.model, 67 | train_root) 68 | val_dataset = dataset_factory.create_dataset(args.model, 69 | val_root) 70 | 71 | model_factory = ModelFactory() 72 | model = model_factory.create_model(args.model) 73 | loss_fn = get_loss_fn(model.name) 74 | 75 | check_point = os.path.join( 76 | 'check_point', model.name, str(args.scale) + 'x') 77 | 78 | solver = Solver( 79 | model, check_point, loss_fn=loss_fn, batch_size=args.batch_size, 80 | num_epochs=args.num_epochs, learning_rate=args.learning_rate, 81 | fine_tune=args.fine_tune, verbose=args.verbose) 82 | 83 | print('Training...') 84 | solver.train(train_dataset, val_dataset) 85 | 86 | 87 | if __name__ == '__main__': 88 | main() 89 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | 7 | class ModelFactory(object): 8 | 9 | def create_model(self, model_name): 10 | if model_name == 'VSRCNN': 11 | return VSRCNN() 12 | elif model_name == 'VRES': 13 | return VRES() 14 | elif model_name == 'MFCNN': 15 | return MFCNN() 16 | elif model_name == 
'VRES10': 17 | return VRES10() 18 | elif model_name == 'VRES5': 19 | return VRES5() 20 | elif model_name == 'VRES15': 21 | return VRES15() 22 | elif model_name == 'VRES7': 23 | return VRES7() 24 | else: 25 | raise Exception('unknown model {}'.format(model_name)) 26 | 27 | 28 | class VSRCNN(nn.Module): 29 | """ 30 | Model for SRCNN 31 | 32 | LR -> Conv1 -> ReLU -> Conv2 -> ReLU -> Conv3 -> HR 33 | 34 | Args: 35 | - C1, C2, C3: num output channels for Conv1, Conv2, and Conv3 36 | - F1, F2, F3: filter sizes 37 | """ 38 | def __init__(self, 39 | C1=64, C2=32, C3=1, 40 | F1=9, F2=1, F3=5): 41 | super(VSRCNN, self).__init__() 42 | self.name = 'VSRCNN' 43 | self.conv1 = nn.Conv2d(1, C1, F1, padding=4, bias=False) 44 | self.conv2 = nn.Conv2d(C1, C2, F2) 45 | self.conv3 = nn.Conv2d(C2, C3, F3, padding=2, bias=False) 46 | 47 | def forward(self, x): 48 | x = F.relu(self.conv1(x)) 49 | x = F.relu(self.conv2(x)) 50 | x = self.conv3(x) 51 | return x 52 | 53 | 54 | class VRES(nn.Module): 55 | def __init__(self): 56 | super(VRES, self).__init__() 57 | self.name = 'VRES' 58 | self.conv_first = nn.Conv2d(5, 64, 3, padding=1, bias=False) 59 | self.conv_next = nn.Conv2d(64, 64, 3, padding=1, bias=False) 60 | self.conv_last = nn.Conv2d(64, 1, 3, padding=1, bias=False) 61 | self.residual_layer = self.make_layer(Conv_ReLU_Block, 18) 62 | self.relu = nn.ReLU(inplace=True) 63 | 64 | # He (Kaiming) initialization: normal with std sqrt(2/n) 65 | for m in self.modules(): 66 | if isinstance(m, nn.Conv2d): 67 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 68 | m.weight.data.normal_(0, math.sqrt(2. / n)) 69 | 70 | def make_layer(self, block, num_of_layer): 71 | layers = [] 72 | for _ in range(num_of_layer): 73 | layers.append(block()) 74 | return nn.Sequential(*layers) 75 | 76 | def forward(self, x): 77 | center = 2 # center of the 5 input frames; added back as a global residual 78 | res = x[:, center, :, :] 79 | res = res.unsqueeze(1) 80 | out = self.relu(self.conv_first(x)) 81 | out = self.residual_layer(out) 82 | out = self.conv_last(out) 83 | out = torch.add(out, res) 84 | return out 85 | 86 | 87 | class Conv_ReLU_Block(nn.Module): 88 | def __init__(self): 89 | super(Conv_ReLU_Block, self).__init__() 90 | self.conv = nn.Conv2d(64, 64, 3, padding=1, bias=False) 91 | self.relu = nn.ReLU(inplace=True) 92 | 93 | def forward(self, x): 94 | return self.relu(self.conv(x)) 95 | 96 | 97 | class MFCNN(nn.Module): 98 | def __init__(self): 99 | super(MFCNN, self).__init__() 100 | self.name = 'MFCNN' 101 | self.conv1 = nn.Conv2d(5, 32, 9, padding=4, bias=False) 102 | self.conv2 = nn.Conv2d(32, 32, 5, padding=2, bias=False) 103 | self.conv3 = nn.Conv2d(32, 64, 5, padding=2, bias=False) 104 | self.conv4 = nn.Conv2d(64, 32, 3, padding=1, bias=False) 105 | self.conv5 = nn.Conv2d(32, 16, 3, padding=1, bias=False) 106 | self.conv6 = nn.Conv2d(16, 1, 3, padding=1, bias=False) 107 | 108 | def forward(self, x): 109 | x = F.relu(self.conv1(x)) 110 | x = F.relu(self.conv2(x)) 111 | x = F.relu(self.conv3(x)) 112 | x = F.relu(self.conv4(x)) 113 | x = F.relu(self.conv5(x)) 114 | x = self.conv6(x) 115 | return x 116 | 117 | 118 | class VRES10(VRES): 119 | def __init__(self): 120 | super(VRES10, self).__init__() 121 | self.name = 'VRES10' 122 | self.residual_layer = self.make_layer(Conv_ReLU_Block, 8) 123 | 124 | 125 | class VRES5(VRES): 126 | def __init__(self): 127 | super(VRES5, self).__init__() 128 | self.name = 'VRES5' 129 | self.residual_layer = self.make_layer(Conv_ReLU_Block, 3) 130 | 131 | 132 | class VRES15(VRES): 133 | def __init__(self): 134 | super(VRES15, self).__init__() 135 | self.name = 'VRES15' 136 | 
self.residual_layer = self.make_layer(Conv_ReLU_Block, 13) 137 | 138 | 139 | class VRES7(VRES): 140 | def __init__(self): 141 | super(VRES7, self).__init__() 142 | self.name = 'VRES7' 143 | self.residual_layer = self.make_layer(Conv_ReLU_Block, 5) 144 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Video Super Resolution, SRCNN, MFCNN, VDCN (ours) benchmark comparison 2 | - Report: [pdf](https://drive.google.com/file/d/1A6mHsTWZZhWai8evuEjS-HEGmB2q49fh/view) 3 | 4 | This is a pytorch implementation of the video super resolution algorithms [SRCNN](http://personal.ie.cuhk.edu.hk/~ccloy/files/eccv_2014_deepresolution.pdf), [MFCNN](http://cs231n.stanford.edu/reports/2016/pdfs/212_Report.pdf), and [VDCN](https://drive.google.com/open?id=1A6mHsTWZZhWai8evuEjS-HEGmB2q49fh) (ours). This project was done for one of my courses and aims to improve the performance of the baselines (SRCNN, MFCNN). 5 | 6 | To run this project you need to set up the environment, download the dataset, and run the scripts to process the data; then you can train and test the network models. I will show you step by step how to run this project, and I hope it is clear enough :D. 7 | ## Prerequisite 8 | I tested this project on a Core i7, 64GB RAM, and a Titan X GPU. Because it uses a big dataset, you should have a reasonably strong CPU/GPU and about 16 to 24GB of RAM. 9 | ## Environment 10 | - Pytorch 1.0 11 | - tqdm 12 | - h5py 13 | - cv2 14 | ## Dataset 15 | First, download the dataset from this [link](https://drive.google.com/open?id=1-5eKvxDnIqrXE3ABSk6RcPwMrgsKeCsw) and put it in this project. FYI, the training set (IndMya trainset) is taken from the India and Myanmar videos on the [Harmonics](https://www.harmonicinc.com/free-4k-demo-footage/) website. The test sets include IndMya and vid4 (city, walk, foliage, and calendar). After the download completes, unzip it. You should see the data under ``video-super-resolution/data/train/``. 16 | ## Process data 17 | The data is processed by MATLAB scripts; the reason is that MATLAB's interpolation implementation is different from Python's. To do that, open MATLAB and run 18 | ``` 19 | $ cd matlab_scripts/ 20 | $ generate_train_video 21 | ``` 22 | When the script is running, you should see the output as follows 23 | 24 | ![create_train](https://github.com/thangvubk/video-super-resolution/blob/master/install-instructions/create_train.PNG) 25 | 26 | After the script finishes, you should see something like 27 | 28 | ![create_train_result](https://github.com/thangvubk/video-super-resolution/blob/master/install-instructions/create_train_result.PNG) 29 | 30 | As you can see, we have a dataset of ``data`` and ``label``. The train dataset will be stored at ``video-super-resolution/preprocessed_data/train/3x/dataset.h5`` 31 | 32 | Do the same for the test set: 33 | ``` 34 | $ generate_test_video 35 | ``` 36 | > NOTE: If you want to train and test the network with a different dataset or frame up-scale factor, modify the dataset and scale variables in the ``generate_test_video`` and ``generate_train_video`` scripts (see the scripts for instructions). 
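To sanity-check the generated ``dataset.h5`` from Python (a small illustrative snippet, not part of the repo; the shapes assume the default 48-pixel patches and 5-frame inputs), you can open it with h5py:
```
import h5py

# Path produced by generate_train_video with the default config
with h5py.File('preprocessed_data/train/3x/dataset.h5', 'r') as f:
    print(f['data'].shape)   # low-res patches, expected (N, 5, 48, 48)
    print(f['label'].shape)  # high-res patches, expected (N, 5, 48, 48)
```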
37 | 38 | 39 | ## Pretrained model 40 | | Method | Scale | Download | 41 | |:------:|:-----:|----------| 42 | | VRES | 3 | [model](https://drive.google.com/file/d/1unaOdmkw9vM8hHptExqxMflOggJ9M67R/view?usp=sharing) | 43 | ## Execute the code 44 | To train the network: 45 | ```python train.py --verbose``` 46 | 47 | you should see something like 48 | 49 | ![train](https://github.com/thangvubk/video-super-resolution/blob/master/install-instructions/training.PNG) 50 | To test the network: 51 | ```python test.py``` 52 | 53 | you should see something like 54 | 55 | ![test](https://github.com/thangvubk/video-super-resolution/blob/master/install-instructions/testing.PNG) 56 | 57 | The experiment results will be saved in results/ 58 | >NOTE: That is the simplest way to train and test the model; all the settings take default values. You can add options for training and testing. For example, to train the ``MFCNN`` model with initial learning rate 1e-2, 100 epochs, batch size 64, scale factor 3, and verbose output: ``python train.py -m MFCNN -l 1e-2 -n 100 -b 64 -s 3 --verbose``. See ``python train.py --help`` and ``python test.py --help`` for detailed information. 59 | 60 | ## Benchmark comparisons 61 | Our network architecture is shown in the figure below. It uses 5 consecutive low-resolution frames as the input and produces the high-resolution center frame. 62 | 63 | ![network_architecture](https://github.com/thangvubk/video-super-resolution/blob/master/install-instructions/network_architecture.PNG) 64 | 65 | Benchmark comparisons on the vid4 dataset 66 | 67 | Quantitative: 68 | ![quantity](https://github.com/thangvubk/video-super-resolution/blob/master/install-instructions/quantitative.PNG) 69 | 70 | Qualitative: 71 | ![quality](https://github.com/thangvubk/video-super-resolution/blob/master/install-instructions/qualitative.PNG) 72 | 73 | See our report [VDCN](https://drive.google.com/open?id=1A6mHsTWZZhWai8evuEjS-HEGmB2q49fh) for more comparisons. 74 | 75 | ## Project explanation 76 | - ``train.py``: where you can start training the network 77 | - ``test.py``: where you can start testing the network 78 | - ``model.py``: declares SRCNN, MFCNN, and our model with different network depths (default 20 layers). Note that our network is named VRES in the code. 79 | - ``SR_datasets.py``: declares the dataset for each model 80 | - ``solver.py``: encapsulates all the logic to train the network 81 | - ``pytorch_ssim.py``: pytorch implementation of the SSIM loss (with autograd), cloned from this [repo](https://github.com/Po-Hsun-Su/pytorch-ssim) 82 | - ``loss.py``: loss functions for the models 83 | 84 | ## TODO 85 | Upload pretrained models 86 | 87 | ## Building your own model 88 | To create your own model you need to define a new network architecture and a new dataset class. See ``model.py`` and ``SR_datasets.py`` for the idea, and the sketch below :D. 89 | 90 | I hope my instructions are clear enough for you. If you have any problem, you can contact me through thangvubk@gmail.com or use the issue tab. If you are interested in this project, you are very welcome. Many Thanks. 
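As a minimal sketch of the "Building your own model" step (the ``MyNet``/``MYNET`` names are illustrative, not part of the repo), a new model follows the same pattern as the existing ones in ``model.py``: take the 5 stacked low-resolution frames as a 5-channel input, output the single high-resolution center frame, and register the class in ``ModelFactory.create_model`` (with a matching branch in ``DatasetFactory.create_dataset``):
```
import torch.nn as nn
import torch.nn.functional as F

class MyNet(nn.Module):  # hypothetical example model
    def __init__(self):
        super(MyNet, self).__init__()
        self.name = 'MYNET'  # used to build the check_point/ and results/ paths
        # 5 input channels: the 5 consecutive low-resolution frames
        self.conv1 = nn.Conv2d(5, 32, 3, padding=1, bias=False)
        self.conv2 = nn.Conv2d(32, 1, 3, padding=1, bias=False)

    def forward(self, x):
        # predict the high-resolution center frame
        return self.conv2(F.relu(self.conv1(x)))
```
If the new model consumes 5 frames like VRES, its dataset class can simply subclass ``VRES_dataset``, the same way ``MFCNN_dataset`` does.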
91 | -------------------------------------------------------------------------------- /solver.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | import time 4 | from shutil import copyfile 5 | 6 | import torch 7 | import torch.optim as optim 8 | import torch.optim.lr_scheduler as lr_scheduler 9 | import torch.nn as nn 10 | import math 11 | from tqdm import tqdm 12 | from torch.autograd import Variable 13 | from torch.utils.data import DataLoader 14 | import pytorch_ssim 15 | 16 | 17 | class Solver(object): 18 | """ 19 | A Solver encapsulates all the logic necessary for training super-resolution models. 20 | The Solver accepts both training and validation data and labels so it can 21 | periodically check the PSNR during training. 22 | 23 | To train a model, you will first construct a Solver instance and 24 | pass the model, datasets, and various options (optimizer, loss_fn, 25 | batch_size, etc.) to the constructor. 26 | 27 | After the train() method is called, the best model is saved into the 28 | 'check_point' dir, which is used at testing time. 29 | 30 | """ 31 | def __init__(self, model, check_point, **kwargs): 32 | """ 33 | Construct a new Solver instance 34 | 35 | Required arguments 36 | - model: a torch nn module describing the neural network architecture 37 | - check_point: dir to save the trained model for testing or finetuning 38 | 39 | Optional arguments: 40 | - num_epochs: number of epochs to run during training 41 | - batch_size: batch size for train phase 42 | - optimizer: update rule for model parameters 43 | - loss_fn: loss function for the model 44 | - fine_tune: fine tune the model in check_point dir instead of training 45 | from scratch 46 | - verbose: print training information 47 | - print_every: period of statistics printing 48 | """ 49 | self.model = model 50 | self.check_point = check_point 51 | self.num_epochs = kwargs.pop('num_epochs', 10) 52 | self.batch_size = kwargs.pop('batch_size', 128) 53 | self.learning_rate = kwargs.pop('learning_rate', 1e-4) 54 | self.optimizer = optim.Adam( 55 | model.parameters(), 56 | lr=self.learning_rate, weight_decay=1e-6) 57 | self.scheduler = lr_scheduler.StepLR( 58 | self.optimizer, step_size=20, gamma=0.5) 59 | self.loss_fn = kwargs.pop('loss_fn', nn.MSELoss()) 60 | self.fine_tune = kwargs.pop('fine_tune', False) 61 | self.verbose = kwargs.pop('verbose', False) 62 | self.print_every = kwargs.pop('print_every', 10) 63 | 64 | self._reset() 65 | 66 | def _reset(self): 67 | """ Initialize some book-keeping variables, don't call it manually""" 68 | self.use_gpu = torch.cuda.is_available() 69 | if self.use_gpu: 70 | self.model = self.model.cuda() 71 | 72 | def _epoch_step(self, dataset, epoch): 73 | """ Perform 1 training 'epoch' on the 'dataset'""" 74 | dataloader = DataLoader(dataset, batch_size=self.batch_size, 75 | shuffle=True, num_workers=4) 76 | 77 | num_batches = len(dataset)//self.batch_size 78 | 79 | running_loss = 0 80 | for i, (input_batch, label_batch) in enumerate(tqdm(dataloader)): 81 | 82 | # Wrap with torch Variable 83 | input_batch, label_batch = self._wrap_variable(input_batch, 84 | label_batch, 85 | self.use_gpu) 86 | 87 | # zero the grad 88 | self.optimizer.zero_grad() 89 | 90 | # Forward 91 | output_batch = self.model(input_batch) 92 | loss = self.loss_fn(output_batch, label_batch) 93 | 94 | running_loss += loss.item() 95 | 96 | # Backward + update, with gradient clipping to stabilize training 97 | loss.backward() 98 | nn.utils.clip_grad_norm_(self.model.parameters(), 0.4) 99 | self.optimizer.step() 100 | 101 | 
average_loss = running_loss/num_batches 102 | if self.verbose: 103 | print('Epoch %5d, loss %.5f' % (epoch, average_loss)) 104 | 105 | def _wrap_variable(self, input_batch, label_batch, use_gpu): 106 | if use_gpu: 107 | input_batch, label_batch = (Variable(input_batch.cuda()), 108 | Variable(label_batch.cuda())) 109 | else: 110 | input_batch, label_batch = (Variable(input_batch), 111 | Variable(label_batch)) 112 | return input_batch, label_batch 113 | 114 | def _compute_PSNR(self, imgs1, imgs2): 115 | """Compute the PSNR between two image arrays and return the PSNR sum""" 116 | N = imgs1.size()[0] 117 | imdiff = imgs1 - imgs2 118 | imdiff = imdiff.view(N, -1) 119 | rmse = torch.sqrt(torch.mean(imdiff**2, dim=1)) 120 | psnr = 20*torch.log(255/rmse)/math.log(10) # psnr = 20*log10(255/rmse) 121 | psnr = torch.sum(psnr) 122 | return psnr 123 | 124 | def _check_PSNR(self, dataset, is_test=False): 125 | """ 126 | Get the output of the model with the input being 'dataset', then 127 | compute the PSNR between output and label. 128 | 129 | If 'is_test' is True, the psnr and output of each image are also 130 | returned for statistics and output image generation at test phase 131 | """ 132 | 133 | dataloader = DataLoader(dataset, batch_size=1, 134 | shuffle=False, num_workers=4) 135 | 136 | avr_psnr = 0 137 | avr_ssim = 0 138 | 139 | # book keeping variables for test phase 140 | psnrs = [] # psnr for each image 141 | ssims = [] # ssim for each image 142 | proc_time = [] # processing time 143 | outputs = [] # output for each image 144 | 145 | for batch, (input_batch, label_batch) in enumerate(dataloader): 146 | input_batch, label_batch = self._wrap_variable(input_batch, 147 | label_batch, 148 | self.use_gpu) 149 | if is_test: 150 | start = time.time() 151 | output_batch = self.model(input_batch) 152 | elapsed_time = time.time() - start 153 | else: 154 | output_batch = self.model(input_batch) 155 | 156 | # ssim is calculated with the normalized (range [0, 1]) image 157 | ssim = pytorch_ssim.ssim( 158 | output_batch + 0.5, label_batch + 0.5, size_average=False) 159 | ssim = torch.sum(ssim).item() 160 | avr_ssim += ssim 161 | 162 | # calculate PSNR 163 | output = output_batch.data 164 | label = label_batch.data 165 | 166 | output = (output + 0.5)*255 167 | label = (label + 0.5)*255 168 | 169 | output = output.squeeze(dim=1) 170 | label = label.squeeze(dim=1) 171 | 172 | psnr = self._compute_PSNR(output, label) 173 | psnr = psnr.item() 174 | avr_psnr += psnr 175 | 176 | # save psnrs and outputs for stats and image generation at test time 177 | if is_test: 178 | psnrs.append(psnr) 179 | ssims.append(ssim) 180 | proc_time.append(elapsed_time) 181 | np_output = output.cpu().numpy() 182 | outputs.append(np_output[0]) 183 | 184 | epoch_size = len(dataset) 185 | avr_psnr /= epoch_size 186 | avr_ssim /= epoch_size 187 | stats = (psnrs, ssims, proc_time) 188 | 189 | return avr_psnr, avr_ssim, stats, outputs 190 | 191 | def train(self, train_dataset, val_dataset): 192 | """ 193 | Train on 'train_dataset'. 194 | If 'fine_tune' is True, we finetune the model under the 'check_point' dir 195 | instead of training from scratch. 196 | 197 | The best model is saved under check_point, which is used 198 | for the test phase or finetuning 199 | """ 200 | 201 | # check fine_tuning option (train() saves 'best_model.pt', so resume from it) 202 | model_path = os.path.join(self.check_point, 'best_model.pt') 203 | if self.fine_tune and not os.path.exists(model_path): 204 | raise Exception('Cannot find %s.' 
% model_path) 205 | elif self.fine_tune and os.path.exists(model_path): 206 | if self.verbose: 207 | print('Loading %s for finetuning.' % model_path) 208 | self.model = torch.load(model_path) 209 | self.optimizer = optim.Adam( 210 | self.model.parameters(), lr=self.learning_rate) 211 | 212 | # capture best model 213 | best_val_psnr = -1 214 | 215 | # Train the model 216 | for epoch in range(self.num_epochs): 217 | self._epoch_step(train_dataset, epoch) 218 | self.scheduler.step() 219 | 220 | if self.verbose: 221 | print('Validating PSNR...') 222 | 223 | # compute validation PSNR and SSIM on the val dataset 224 | val_psnr, val_ssim, _, _ = self._check_PSNR(val_dataset) 225 | 226 | if self.verbose: 227 | print('Val PSNR: %.3fdB. Val ssim: %.3f' 228 | % (val_psnr, val_ssim)) 229 | 230 | # write the model to hard disk for testing 231 | print('Saving model') 232 | if not os.path.exists(self.check_point): 233 | os.makedirs(self.check_point) 234 | model_path = os.path.join(self.check_point, 'epoch{}.pt'.format(epoch)) 235 | torch.save(self.model, model_path) 236 | if best_val_psnr < val_psnr: 237 | print('Copy best model') 238 | target_path = os.path.join(self.check_point, 'best_model.pt') 239 | copyfile(model_path, target_path) 240 | best_val_psnr = val_psnr 241 | print('') 242 | 243 | def test(self, dataset, model_path): 244 | """ 245 | Load the model stored at 'model_path' from the training phase, 246 | then return the PSNR/SSIM stats and outputs on the test samples. 247 | """ 248 | if not os.path.exists(model_path): 249 | raise Exception('Cannot find %s.' % model_path) 250 | 251 | self.model = torch.load(model_path) 252 | _, _, stats, outputs = self._check_PSNR(dataset, is_test=True) 253 | return stats, outputs 254 | --------------------------------------------------------------------------------
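For reference, the Solver above composes with the datasets and models as follows — a minimal programmatic sketch that simply mirrors what ``train.py`` already does (paths assume the default 3x preprocessing):
```
from SR_datasets import DatasetFactory
from model import ModelFactory
from loss import get_loss_fn
from solver import Solver

# Build the datasets, model, and loss the same way train.py does
train_set = DatasetFactory().create_dataset('VRES', 'preprocessed_data/train/3x')
val_set = DatasetFactory().create_dataset('VRES', 'preprocessed_data/test/IndMya/3x')
model = ModelFactory().create_model('VRES')

solver = Solver(model, 'check_point/VRES/3x',
                loss_fn=get_loss_fn(model.name),
                num_epochs=50, batch_size=100,
                learning_rate=1e-3, verbose=True)
solver.train(train_set, val_set)
```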