├── .gitignore
├── README.md
├── approx_huber_loss.py
├── data
│   ├── __init__.py
│   ├── benchmark.py
│   ├── benchmark_video.py
│   ├── cdvl100.py
│   ├── cdvl_video.py
│   ├── common.py
│   ├── div2k.py
│   ├── srdata.py
│   └── vsrdata.py
├── logger
│   └── logger.py
├── main.py
├── model
│   ├── __init__.py
│   ├── espcn.py
│   ├── espcn_modified.py
│   ├── espcn_multiframe.py
│   ├── espcn_multiframe2.py
│   ├── motioncompensator.py
│   └── vespcn.py
├── option.py
├── template.py
├── test.py
├── tools
│   ├── imgresize.m
│   └── videocapture.m
├── trainer.py
├── trainer_mc.py
├── trainer_vsr.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | .ipynb_checkpoints
27 |
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other info into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *.cover
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 |
56 | # Sphinx documentation
57 | docs/_build/
58 |
59 | # PyBuilder
60 | target/
61 |
62 | # PyTorch
63 | *.pt
64 | *.pdf
65 | *.txt
66 | *.swp
67 | *.vscode
68 | *.xml
69 | *.iml
70 |
71 | # Image
72 | *.png
73 | *.jpg
74 | *.jpeg
75 |
76 | # Floydhub
77 | .DS_Store
78 | .floydexpt
79 | .floydignore
80 | floyd.yml
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VESPCN-PyTorch
2 | PyTorch implementation of ESPCN [1]/VESPCN [2].
3 |
4 | ## **How to run the code**
5 | 1. Add your own template to template.py, specifying the parameters used to run the code
6 | (in particular, set the task (Image/MC/Video) and point the training/test dataset directories at your own filesystem).
7 | 2. Add your model to the ./model/ directory (the filename should be lower case).
8 | 3. Run "python3 main.py --template $(your template) --model $(model you want to train)" to train.
9 | 4. To add options for further test benchmark datasets, modify ./data/__init__.py.
10 | 5. For additional details, refer to [3] (we have borrowed most of the implementation details from there).
11 |
12 | ## **TODO list**
13 | - [x] Implement the SISR ESPCN network
14 | - [x] Make the dataloader for video SR
15 | - [x] Complete the motion compensation network
16 | - [x] Join the ESPCN to the motion compensation network
17 |
18 | ## **References**
19 | [1] W. Shi et al., “Real-time single image and video super-resolution using an efficient sub-pixel convolutional neural network,” IEEE CVPR 2016.
20 |
21 | [2] J. Caballero et al., “Real-Time Video Super-Resolution with Spatio-Temporal Networks and Motion Compensation,” IEEE CVPR 2017.
22 | 23 | [3] https://github.com/thstkdgus35/EDSR-PyTorch (borrowed the overall code structure) 24 | -------------------------------------------------------------------------------- /approx_huber_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | # TODO: Is the first channel flow with respect to x? 7 | # TODO: Fix 'mean' issues 8 | 9 | 10 | class Approx_Huber_Loss(nn.Module): 11 | def __init__(self, args): 12 | super(Approx_Huber_Loss, self).__init__() 13 | self.device = torch.device('cpu' if args.cpu else 'cuda') 14 | self.sobel_filter_X = np.array([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).reshape((1, 1, 3, 3)) 15 | self.sobel_filter_Y = np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).reshape((1, 1, 3, 3)) 16 | self.sobel_filter_X = torch.from_numpy(self.sobel_filter_X).float().to(self.device) 17 | self.sobel_filter_Y = torch.from_numpy(self.sobel_filter_Y).float().to(self.device) 18 | self.epsilon = torch.Tensor([0.01]).float().to(self.device) 19 | 20 | def forward(self, flow): 21 | flow_X = flow[:, 0:1] 22 | flow_Y = flow[:, 1:] 23 | grad_X = F.conv2d(flow_X, self.sobel_filter_X, bias=None, stride=1, padding=1) 24 | grad_Y = F.conv2d(flow_Y, self.sobel_filter_Y, bias=None, stride=1, padding=1) 25 | huber = torch.sqrt(self.epsilon + torch.sum(grad_X.pow(2)+grad_Y.pow(2))) 26 | return huber 27 | 28 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | from importlib import import_module 4 | 5 | from torch.utils.data import DataLoader 6 | 7 | class Data: 8 | def __init__(self, args): 9 | self.args = args 10 | self.data_train = args.data_train 11 | self.data_test = args.data_test 12 | 13 | list_benchmarks = ['Set5', 'Set14', 'B100', 'Urban100'] 14 | benchmark = self.data_test in list_benchmarks 15 | 16 | list_benchmarks_video = ['Vid4'] 17 | benchmark_video = self.data_test in list_benchmarks_video 18 | if not self.args.test_only: 19 | m_train = import_module('data.' + self.data_train.lower()) 20 | trainset = getattr(m_train, self.data_train)(self.args) 21 | self.loader_train = DataLoader( 22 | trainset, 23 | batch_size=self.args.batch_size, 24 | shuffle=True, 25 | pin_memory=not self.args.cpu 26 | ) 27 | else: 28 | self.loader_train = None 29 | 30 | if benchmark: 31 | m_test = import_module('data.benchmark') 32 | testset = getattr(m_test, 'Benchmark')(self.args, name=args.data_test, train=False) 33 | elif benchmark_video: 34 | m_test = import_module('data.benchmark_video') 35 | testset = getattr(m_test, 'Benchmark_video')(self.args, name=args.data_test, train=False) 36 | else: 37 | class_name = self.data_test 38 | m_test = import_module('data.' 
+ class_name.lower()) 39 | testset = getattr(m_test, class_name)(self.args, train=False) 40 | 41 | self.loader_test = DataLoader(testset, batch_size=1, shuffle=False, pin_memory=not self.args.cpu) 42 | 43 | -------------------------------------------------------------------------------- /data/benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | from data import srdata 5 | 6 | import numpy as np 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class Benchmark(srdata.SRData): 12 | """ 13 | Data generator for benchmark tasks 14 | """ 15 | def __init__(self, args, name='', train=False): 16 | super(Benchmark, self).__init__( 17 | args, name=name, train=train 18 | ) 19 | 20 | def _set_filesystem(self, dir_data): 21 | if self.args.template == 'SY': 22 | self.apath = os.path.join(dir_data, 'benchmark', self.name) 23 | self.dir_hr = os.path.join(self.apath, 'HR') 24 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic', 'X{}'.format(self.args.scale)) 25 | ################################################ 26 | # # 27 | # Fill in your directory with your own template# 28 | # # 29 | ################################################ 30 | elif self.args.template == "JH": 31 | self.apath = os.path.join(dir_data, self.name) 32 | self.dir_hr = os.path.join(self.apath, 'HR') 33 | self.dir_lr = os.path.join(self.apath, 'LR') 34 | 35 | def _load_file(self, idx): 36 | lr, hr, filename = super(Benchmark, self)._load_file(idx=idx) 37 | if self.name == 'Set14': 38 | if lr.ndim == 2: 39 | lr = np.repeat(np.expand_dims(lr, axis=2), 3, axis=2) 40 | hr = np.repeat(np.expand_dims(hr, axis=2), 3, axis=2) 41 | 42 | return lr, hr, filename 43 | -------------------------------------------------------------------------------- /data/benchmark_video.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | from data import srdata 5 | from data import vsrdata 6 | 7 | import numpy as np 8 | 9 | import torch 10 | import torch.utils.data as data 11 | 12 | class Benchmark_video(vsrdata.VSRData): 13 | """ 14 | Data generator for benchmark tasks 15 | """ 16 | def __init__(self, args, name='', train=False): 17 | super(Benchmark_video, self).__init__( 18 | args, name=name, train=train 19 | ) 20 | 21 | def _set_filesystem(self, dir_data): 22 | if self.args.template == 'SY': 23 | self.apath = os.path.join(dir_data, 'benchmark', self.name) 24 | self.dir_hr = os.path.join(self.apath, 'HR') 25 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic', 'X{}'.format(self.args.scale)) 26 | ################################################ 27 | # # 28 | # Fill in your directory with your own template# 29 | # # 30 | ################################################ 31 | elif self.args.template == "JH_Video" or self.args.template == "JH_MC": 32 | self.apath = os.path.join(dir_data, self.name) 33 | self.dir_hr = os.path.join(self.apath, 'HR') 34 | self.dir_lr = os.path.join(self.apath, 'LR') 35 | print("test video path (HR):", self.dir_hr) 36 | print("test video path (LR):", self.dir_lr) 37 | 38 | def _load_file(self, idx): 39 | lr, hr, filename = super(Benchmark_video, self)._load_file(idx=idx) 40 | if self.name == 'Set14': 41 | if lr.ndim == 2: 42 | lr = np.repeat(np.expand_dims(lr, axis=2), 3, axis=2) 43 | hr = np.repeat(np.expand_dims(hr, axis=2), 3, axis=2) 44 | 45 | return lr, hr, filename 46 | 
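Note on adding datasets: as the README mentions, data/__init__.py (above) resolves a dataset by importing data.<name in lower case> and instantiating the class named exactly <name>, so a new train/test set only needs a module/class pair that follows this convention. A minimal sketch of the dispatch (the dataset name 'MyVid' and module data/myvid.py are illustrative, not files in this repo):

    from importlib import import_module

    name = 'MyVid'                                      # would come from args.data_train / args.data_test
    module = import_module('data.' + name.lower())      # expects data/myvid.py to exist
    dataset = getattr(module, name)(args, train=False)  # expects class MyVid (subclassing SRData or VSRData) inside it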
-------------------------------------------------------------------------------- /data/cdvl100.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | from data import srdata 5 | 6 | import numpy as np 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class CDVL100(srdata.SRData): 12 | def __init__(self, args, name='CDVL100', train=True): 13 | super(CDVL100, self).__init__( 14 | args, name=name, train=train 15 | ) 16 | 17 | def _set_filesystem(self, dir_data): 18 | ################################################ 19 | # # 20 | # Fill in your directory with your own template# 21 | # # 22 | ################################################ 23 | if self.args.template == "SY": 24 | super(CDVL100, self)._set_filesystem(dir_data) 25 | 26 | if self.args.template == "JH": 27 | print("Loading CDVL100") 28 | self.apath = os.path.join(dir_data, self.name) 29 | self.dir_hr = os.path.join(self.apath, 'HR') 30 | self.dir_lr = os.path.join(self.apath, 'LR') 31 | -------------------------------------------------------------------------------- /data/cdvl_video.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | from data import vsrdata 5 | 6 | import numpy as np 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | # Data loader for CDVL videos 12 | class CDVL_VIDEO(vsrdata.VSRData): 13 | 14 | def __init__(self, args, name='CDVL', train=True): 15 | super(CDVL_VIDEO, self).__init__(args, name=name, train=train) 16 | 17 | 18 | def _scan(self): 19 | names_hr, names_lr = super(CDVL_VIDEO, self)._scan() 20 | names_hr = names_hr[self.begin - 1:self.end] 21 | names_lr = names_lr[self.begin - 1:self.end] 22 | 23 | return names_hr, names_lr 24 | 25 | def _set_filesystem(self, dir_data): 26 | ################################################ 27 | # # 28 | # Fill in your directory with your own template# 29 | # # 30 | ################################################ 31 | if self.args.template == "SY": 32 | super(CDVL_VIDEO, self)._set_filesystem(dir_data) 33 | 34 | if self.args.template == "JH_Video" or self.args.template == "JH_MC": 35 | print("Loading CDVL videos") 36 | self.apath = os.path.join(dir_data, self.name) 37 | self.dir_hr = os.path.join(self.apath, 'HR') 38 | self.dir_lr = os.path.join(self.apath, 'LR') 39 | print("Train video path (HR):", self.dir_hr) 40 | print("Train video path (LR):", self.dir_lr) 41 | -------------------------------------------------------------------------------- /data/common.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import skimage.io as sio 5 | import skimage.color as sc 6 | import skimage.transform as st 7 | 8 | import torch 9 | from torchvision import transforms 10 | 11 | """ 12 | Repository for common functions required for manipulating data 13 | """ 14 | 15 | 16 | def get_patch(*args, patch_size=17, scale=1): 17 | """ 18 | Get patch from an image 19 | """ 20 | ih, iw, _ = args[0].shape 21 | 22 | ip = patch_size 23 | tp = scale * ip 24 | 25 | ix = random.randrange(0, iw - ip + 1) 26 | iy = random.randrange(0, ih - ip + 1) 27 | tx, ty = scale * ix, scale * iy 28 | 29 | ret = [ 30 | args[0][iy:iy + ip, ix:ix + ip, :], 31 | *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]] 32 | ] 33 | 34 | return ret 35 | 36 | 37 | def set_channel(*args, n_channels=3): 38 | def _set_channel(img): 39 | if img.ndim == 2: 
40 | img = np.expand_dims(img, axis=2) 41 | 42 | c = img.shape[2] 43 | if n_channels == 1 and c == 3: 44 | img = sc.rgb2ycbcr(img) 45 | elif n_channels == 3 and c == 1: 46 | img = np.concatenate([img] * n_channels, 2) 47 | 48 | return img 49 | 50 | return [_set_channel(a) for a in args] 51 | 52 | 53 | def np2Tensor(*args, rgb_range=255, n_colors=1): 54 | def _np2Tensor(img): 55 | # NHWC -> NCHW 56 | if img.shape[2] == 3 and n_colors == 3: 57 | mean_RGB = np.array([123.68, 116.779, 103.939]) 58 | img = img.astype('float64') - mean_RGB 59 | elif img.shape[2] == 3 and n_colors == 1: 60 | mean_YCbCr = np.array([109, 0, 0]) 61 | img = img.astype('float64') - mean_YCbCr 62 | else: 63 | mean_YCbCr = np.array([109]) 64 | img = img.astype('float64') - mean_YCbCr 65 | 66 | np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) 67 | tensor = torch.from_numpy(np_transpose).float() 68 | tensor.mul_(rgb_range / 255) 69 | 70 | return tensor 71 | 72 | return [_np2Tensor(a) for a in args] 73 | 74 | def augment(*args, hflip=True, rot=True): 75 | hflip = hflip and random.random() < 0.5 76 | vflip = rot and random.random() < 0.5 77 | rot90 = rot and random.random() < 0.5 78 | 79 | def _augment(img): 80 | if hflip: img = img[:, ::-1, :] 81 | if vflip: img = img[::-1, :, :] 82 | if rot90: img = np.rot90(img) 83 | 84 | return img 85 | 86 | return [_augment(a) for a in args] 87 | 88 | -------------------------------------------------------------------------------- /data/div2k.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import srdata 3 | 4 | class DIV2K(srdata.SRData): 5 | def __init__(self, args, name='DIV2K', train=True): 6 | super(DIV2K, self).__init__( 7 | args, name=name, train=train 8 | ) 9 | 10 | def _scan(self): 11 | names_hr, names_lr = super(DIV2K, self)._scan() 12 | names_hr = names_hr[self.begin - 1:self.end] 13 | names_lr = [n[self.begin - 1:self.end] for n in names_lr] 14 | 15 | return names_hr, names_lr 16 | 17 | def _set_filesystem(self, dir_data): 18 | if self.args.template == 'SY': 19 | self.apath = os.path.join(dir_data, self.name) 20 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') 21 | self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic', 'X{}'.format(self.args.scale)) 22 | ################################################ 23 | # # 24 | # Fill in your directory with your own template# 25 | # # 26 | ################################################ 27 | elif self.args.template == 'JH': 28 | print("Loading DIV2K") 29 | self.dir_hr = os.path.join(dir_data, 'DIV2K') 30 | self.dir_lr = os.path.join(dir_data, 'DIV2K_LR') 31 | #print(self.dir_hr) 32 | #print(self.dir_lr) 33 | 34 | -------------------------------------------------------------------------------- /data/srdata.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import imageio 4 | import skimage.color as sc 5 | import numpy as np 6 | from scipy import misc 7 | 8 | from data import common 9 | 10 | import torch 11 | import torch.utils.data as data 12 | 13 | 14 | class SRData(data.Dataset): 15 | def __init__(self, args, name='', train=True): 16 | self.args = args 17 | self.name = name 18 | self.train = train 19 | self.split = 'train' if train else 'test' 20 | self.do_eval = True 21 | self.scale = args.scale 22 | self.idx_scale = 0 23 | 24 | data_range = [r.split('-') for r in args.data_range.split('/')] 25 | if train: 26 | data_range = data_range[0] 27 | else: 28 | if args.test_only 
and len(data_range) == 1: 29 | data_range = data_range[0] 30 | else: 31 | data_range = data_range[1] 32 | self.begin, self.end = list(map(lambda x: int(x), data_range)) 33 | 34 | if train: 35 | self._set_filesystem(args.dir_data) 36 | else: 37 | self._set_filesystem(args.dir_data_test) 38 | 39 | self.images_hr, self.images_lr = self._scan() 40 | if train and args.process: 41 | print('Loading image dataset...') 42 | self.data_hr, self.data_lr = self._load(self.images_hr, self.images_lr) 43 | 44 | if train: 45 | self.repeat = args.test_every // (len(self.images_hr) // args.batch_size) 46 | 47 | # Below functions as used to prepare images 48 | def _scan(self): 49 | """ 50 | Returns a list of image directories 51 | """ 52 | names_hr = sorted(glob.glob(os.path.join(self.dir_hr, '*.png'))) 53 | names_lr = sorted(glob.glob(os.path.join(self.dir_lr, '*.png'))) 54 | 55 | return names_hr, names_lr 56 | 57 | def _load(self, names_hr, names_lr): 58 | data_lr = [imageio.imread(filename) for filename in names_lr] 59 | data_hr = [imageio.imread(filename) for filename in names_hr] 60 | return data_hr, data_lr 61 | 62 | def _set_filesystem(self, dir_data): 63 | self.apath = os.path.join(dir_data, self.name) 64 | if self.args.template == 'SY': 65 | self.dir_hr = os.path.join(self.apath, 'HR') 66 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic', 'X{}'.format(self.args.scale)) 67 | ################################################ 68 | # # 69 | # Fill in your directory with your own template# 70 | # # 71 | ################################################ 72 | elif self.args.template == "JH": 73 | self.apath = os.path.join(dir_data, self.name) 74 | self.dir_hr = os.path.join(self.apath, 'HR') 75 | self.dir_lr = os.path.join(self.apath, 'LR') 76 | 77 | def __getitem__(self, idx): 78 | if self.train and self.args.process: 79 | lr, hr, filename = self._load_file_from_loaded_data(idx) 80 | else: 81 | lr, hr, filename = self._load_file(idx) 82 | if self.train: 83 | lr_extend = hr 84 | else: 85 | lr_extend = misc.imresize(lr, size=self.args.scale*100, interp='bicubic') 86 | lr, lr_extend, hr = self.get_patch(lr, lr_extend, hr) 87 | lr, lr_extend, hr = common.set_channel(lr, lr_extend, hr, n_channels=self.args.n_colors) 88 | lr_tensor, lre_tensor, hr_tensor = common.np2Tensor( 89 | lr, lr_extend, hr, rgb_range=self.args.rgb_range, n_colors=self.args.n_colors 90 | ) 91 | return lr_tensor, lre_tensor, hr_tensor, filename 92 | 93 | def __len__(self): 94 | if self.train: 95 | return len(self.images_hr) * self.repeat 96 | else: 97 | return len(self.images_hr) 98 | 99 | def _get_index(self, idx): 100 | if self.train: 101 | return idx % len(self.images_hr) 102 | else: 103 | return idx 104 | 105 | def _load_file(self, idx): 106 | """ 107 | Read image from given image directory 108 | """ 109 | idx = self._get_index(idx) 110 | f_hr = self.images_hr[idx] 111 | f_lr = self.images_lr[idx] 112 | 113 | filename, _ = os.path.splitext(os.path.basename(f_hr)) 114 | hr = imageio.imread(f_hr) 115 | lr = imageio.imread(f_lr) 116 | 117 | return lr, hr, filename 118 | 119 | def _load_file_from_loaded_data(self, idx): 120 | idx = self._get_index(idx) 121 | hr = self.data_hr[idx] 122 | lr = self.data_lr[idx] 123 | filename = os.path.splitext(os.path.split(self.images_hr[idx])[-1])[0] 124 | 125 | return lr, hr, filename 126 | 127 | def get_patch(self, lr, lr_extend, hr): 128 | """ 129 | Returns patches for multiple scales 130 | """ 131 | scale = self.scale 132 | if self.train: 133 | lr, lr_extend, hr = common.get_patch( 134 | lr, 
lr_extend, hr, 135 | patch_size=self.args.patch_size, 136 | scale=scale, 137 | ) 138 | if not self.args.no_augment: 139 | lr, lr_extend, hr = common.augment(lr, lr_extend, hr) 140 | else: 141 | ih, iw = lr.shape[:2] 142 | lr_extend = lr_extend[0:ih * scale, 0:iw * scale] 143 | hr = hr[0:ih * scale, 0:iw * scale] 144 | 145 | return lr, lr_extend, hr 146 | -------------------------------------------------------------------------------- /data/vsrdata.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import glob 4 | import time 5 | import skimage.color as sc 6 | from data import common 7 | import pickle 8 | import numpy as np 9 | import imageio 10 | import random 11 | import torch 12 | import torch.utils.data as data 13 | import cv2 14 | 15 | 16 | class VSRData(data.Dataset): 17 | def __init__(self, args, name='', train=True): 18 | self.args = args 19 | self.name = name 20 | self.train = train 21 | self.scale = args.scale 22 | self.idx_scale = 0 23 | self.n_seq = args.n_sequence 24 | print("n_seq:", args.n_sequence) 25 | print("n_frames_per_video:", args.n_frames_per_video) 26 | # self.image_range : need to make it flexible in the test area 27 | self.img_range = 30 28 | self.n_frames_video = [] 29 | data_range = [r.split('-') for r in args.data_range.split('/')] 30 | if train: 31 | data_range = data_range[0] 32 | else: 33 | if args.test_only and len(data_range) == 1: 34 | data_range = data_range[0] 35 | else: 36 | data_range = data_range[1] 37 | self.begin, self.end = list(map(lambda x: int(x), data_range)) 38 | 39 | if train: 40 | self._set_filesystem(args.dir_data) 41 | else: 42 | self._set_filesystem(args.dir_data_test) 43 | 44 | self.images_hr, self.images_lr = self._scan() 45 | self.num_video = len(self.images_hr) 46 | print("Number of videos to load:", self.num_video) 47 | if train: 48 | self.repeat = args.test_every // max((self.num_video // self.args.batch_size), 1) 49 | if args.process: 50 | self.data_hr, self.data_lr = self._load(self.num_video) 51 | 52 | # Below functions as used to prepare images 53 | def _scan(self): 54 | """ 55 | Returns a list of image directories 56 | """ 57 | if self.train: 58 | # training datasets are labeled as .../Video*/HR/*.png 59 | vid_hr_names = sorted(glob.glob(os.path.join(self.dir_hr, 'Video*'))) 60 | vid_lr_names = sorted(glob.glob(os.path.join(self.dir_lr, 'Video*'))) 61 | else: 62 | vid_hr_names = sorted(glob.glob(os.path.join(self.dir_hr, '*'))) 63 | vid_lr_names = sorted(glob.glob(os.path.join(self.dir_lr, '*'))) 64 | 65 | assert len(vid_hr_names) == len(vid_lr_names) 66 | 67 | names_hr = [] 68 | names_lr = [] 69 | 70 | if self.train: 71 | for vid_hr_name, vid_lr_name in zip(vid_hr_names, vid_lr_names): 72 | start = random.randint(0, self.img_range - self.args.n_frames_per_video) 73 | hr_dir_names = sorted(glob.glob(os.path.join(vid_hr_name, '*.png')))[start: start+self.args.n_frames_per_video] 74 | lr_dir_names = sorted(glob.glob(os.path.join(vid_lr_name, '*.png')))[start: start+self.args.n_frames_per_video] 75 | names_hr.append(hr_dir_names) 76 | names_lr.append(lr_dir_names) 77 | self.n_frames_video.append(len(hr_dir_names)) 78 | else: 79 | for vid_hr_name, vid_lr_name in zip(vid_hr_names, vid_lr_names): 80 | hr_dir_names = sorted(glob.glob(os.path.join(vid_hr_name, '*.png'))) 81 | lr_dir_names = sorted(glob.glob(os.path.join(vid_lr_name, '*.png'))) 82 | names_hr.append(hr_dir_names) 83 | names_lr.append(lr_dir_names) 84 | self.n_frames_video.append(len(hr_dir_names)) 85 | 86 
| return names_hr, names_lr 87 | 88 | def _load(self, n_videos): 89 | data_lr = [] 90 | data_hr = [] 91 | for idx in range(n_videos): 92 | if idx % 10 == 0: 93 | print("Loading video %d" %idx) 94 | lrs, hrs, _ = self._load_file(idx) 95 | hrs = np.array([imageio.imread(hr_name) for hr_name in self.images_hr[idx]]) 96 | lrs = np.array([imageio.imread(lr_name) for lr_name in self.images_lr[idx]]) 97 | data_lr.append(lrs) 98 | data_hr.append(hrs) 99 | #data_lr = common.set_channel(*data_lr, n_channels=self.args.n_colors) 100 | #data_hr = common.set_channel(*data_hr, n_channels=self.args.n_colors) 101 | return data_hr, data_lr 102 | 103 | def _set_filesystem(self, dir_data): 104 | self.apath = os.path.join(dir_data, self.name) 105 | if self.args.template == 'SY': 106 | self.dir_hr = os.path.join(self.apath, 'HR') 107 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic', 'X{}'.format(self.args.scale)) 108 | ################################################ 109 | # # 110 | # Fill in your directory with your own template# 111 | # # 112 | ################################################ 113 | elif self.args.template == "JH": 114 | self.apath = os.path.join(dir_data, self.name) 115 | print("apath:", self.apath) 116 | self.dir_hr = os.path.join(self.apath, 'HR') 117 | self.dir_lr = os.path.join(self.apath, 'LR') 118 | else: 119 | # This is just for testing: must fix later! 120 | self.dir_hr = os.path.join(self.apath, 'HR_big') 121 | self.dir_lr = os.path.join(self.apath, 'LR_big') 122 | 123 | def __getitem__(self, idx): 124 | if self.args.process: 125 | lrs, hrs, filenames = self._load_file_from_loaded_data(idx) 126 | else: 127 | lrs, hrs, filenames = self._load_file(idx) 128 | 129 | patches = [self.get_patch(lr, hr) for lr, hr in zip(lrs, hrs)] 130 | lrs = np.array([patch[0] for patch in patches]) 131 | hrs = np.array([patch[1] for patch in patches]) 132 | lrs = np.array(common.set_channel(*lrs, n_channels=self.args.n_colors)) 133 | hrs = np.array(common.set_channel(*hrs, n_channels=self.args.n_colors)) 134 | lr_tensors = common.np2Tensor(*lrs, rgb_range=self.args.rgb_range, n_colors=self.args.n_colors) 135 | hr_tensors = common.np2Tensor(*hrs, rgb_range=self.args.rgb_range, n_colors=self.args.n_colors) 136 | return torch.stack(lr_tensors), torch.stack(hr_tensors), filenames 137 | 138 | def __len__(self): 139 | if self.train: 140 | return len(self.images_hr) * self.repeat 141 | else: 142 | # if test, call all possible video sequence fragments 143 | return sum(self.n_frames_video) - (self.n_seq - 1) * len(self.n_frames_video) 144 | 145 | def _get_index(self, idx): 146 | if self.train: 147 | return idx % self.num_video 148 | else: 149 | return idx 150 | 151 | def _find_video_num(self, idx, n_frame): 152 | for i, j in enumerate(n_frame): 153 | if idx < j: 154 | return i, idx 155 | else: 156 | idx -= j 157 | 158 | def _load_file(self, idx): 159 | """ 160 | Read image from given image directory 161 | Return: n_seq * H * W * C numpy array and list of corresponding filenames 162 | """ 163 | 164 | if self.train: 165 | f_hrs = self.images_hr[idx] 166 | f_lrs = self.images_lr[idx] 167 | start = self._get_index(random.randint(0, self.n_frames_video[idx] - self.n_seq)) 168 | filenames = [os.path.splitext(os.path.basename(file))[0] for file in f_hrs[start:start+self.n_seq]] 169 | hrs = np.array([imageio.imread(hr_name) for hr_name in f_hrs[start:start+self.n_seq]]) 170 | lrs = np.array([imageio.imread(lr_name) for lr_name in f_lrs[start:start+self.n_seq]]) 171 | 172 | else: 173 | n_poss_frames = [n - self.n_seq 
+ 1 for n in self.n_frames_video] 174 | video_idx, frame_idx = self._find_video_num(idx, n_poss_frames) 175 | f_hrs = self.images_hr[video_idx][frame_idx:frame_idx+self.n_seq] 176 | f_lrs = self.images_lr[video_idx][frame_idx:frame_idx+self.n_seq] 177 | filenames = [os.path.split(os.path.dirname(file))[-1] + '.' + os.path.splitext(os.path.basename(file))[0] for file in f_hrs] 178 | hrs = np.array([imageio.imread(hr_name) for hr_name in f_hrs]) 179 | lrs = np.array([imageio.imread(lr_name) for lr_name in f_lrs]) 180 | return lrs, hrs, filenames 181 | 182 | def _load_file_from_loaded_data(self, idx): 183 | idx = self._get_index(idx) 184 | 185 | if self.train: 186 | start = self._get_index(random.randint(0, self.n_frames_video[idx] - self.n_seq)) 187 | hrs = self.data_hr[idx][start:start+self.n_seq] 188 | lrs = self.data_lr[idx][start:start+self.n_seq] 189 | filenames = [os.path.splitext(os.path.split(name)[-1])[0] for name in self.images_hr[idx]] 190 | 191 | else: 192 | n_poss_frames = [n - self.n_seq + 1 for n in self.n_frames_video] 193 | video_idx, frame_idx = self._find_video_num(idx, n_poss_frames) 194 | f_hrs = self.images_hr[video_idx][frame_idx:frame_idx+self.n_seq] 195 | hrs = self.data_hr[video_idx][frame_idx:frame_idx+self.n_seq] 196 | lrs = self.data_lr[video_idx][frame_idx:frame_idx+self.n_seq] 197 | filenames = [os.path.split(os.path.dirname(file))[-1] + '.' + os.path.splitext(os.path.basename(file))[0] for file in f_hrs] 198 | 199 | return lrs, hrs, filenames 200 | 201 | def get_patch(self, lr, hr): 202 | """ 203 | Returns patches for multiple scales 204 | """ 205 | scale = self.scale 206 | if self.train: 207 | patch_size = self.args.patch_size - (self.args.patch_size % 4) 208 | lr, hr = common.get_patch( 209 | lr, 210 | hr, 211 | patch_size=patch_size, 212 | scale=scale, 213 | ) 214 | if not self.args.no_augment: 215 | lr, hr = common.augment(lr, hr) 216 | else: 217 | ih, iw = lr.shape[:2] 218 | ih -= ih % 4 219 | iw -= iw % 4 220 | lr = lr[:ih, :iw] 221 | hr = hr[:ih * scale, :iw * scale] 222 | 223 | return lr, hr 224 | -------------------------------------------------------------------------------- /logger/logger.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import imageio 3 | import numpy as np 4 | import os 5 | import datetime 6 | from scipy import misc 7 | import skimage.color as sc 8 | 9 | import matplotlib 10 | matplotlib.use('Agg') 11 | from matplotlib import pyplot as plt 12 | 13 | 14 | class Logger: 15 | def __init__(self, args): 16 | self.args = args 17 | self.psnr_log = torch.Tensor() 18 | self.loss_log = torch.Tensor() 19 | 20 | if args.load == '.': 21 | if args.save == '.': 22 | args.save = datetime.datetime.now().strftime('%Y%m%d_%H:%M') 23 | self.dir = 'experiment/' + args.save 24 | else: 25 | self.dir = 'experiment/' + args.load 26 | if not os.path.exists(self.dir): 27 | args.load = '.' 28 | else: 29 | self.loss_log = torch.load(self.dir + '/loss_log.pt') 30 | self.psnr_log = torch.load(self.dir + '/psnr_log.pt') 31 | print('Continue from epoch {}...'.format(len(self.psnr_log))) 32 | 33 | if args.reset: 34 | os.system('rm -rf {}'.format(self.dir)) 35 | args.load = '.' 
36 | 37 | if not os.path.exists(self.dir): 38 | os.makedirs(self.dir) 39 | if not os.path.exists(self.dir + '/model'): 40 | os.makedirs(self.dir + '/model') 41 | if not os.path.exists(self.dir + '/result/'+self.args.data_test): 42 | print("Creating dir for saving images...", self.dir + '/result/'+self.args.data_test) 43 | os.makedirs(self.dir + '/result/'+self.args.data_test) 44 | 45 | print('Save Path : {}'.format(self.dir)) 46 | 47 | open_type = 'a' if os.path.exists(self.dir + '/log.txt') else 'w' 48 | self.log_file = open(self.dir + '/log.txt', open_type) 49 | with open(self.dir + '/config.txt', open_type) as f: 50 | f.write('From epoch {}...'.format(len(self.psnr_log)) + '\n\n') 51 | for arg in vars(args): 52 | f.write('{}: {}\n'.format(arg, getattr(args, arg))) 53 | f.write('\n') 54 | 55 | def write_log(self, log): 56 | print(log) 57 | self.log_file.write(log + '\n') 58 | 59 | def save(self, trainer, epoch, is_best): 60 | trainer.model.save(self.dir, is_best) 61 | torch.save(self.loss_log, os.path.join(self.dir, 'loss_log.pt')) 62 | torch.save(self.psnr_log, os.path.join(self.dir, 'psnr_log.pt')) 63 | torch.save(trainer.optimizer.state_dict(), os.path.join(self.dir, 'optimizer.pt')) 64 | self.plot_loss_log(epoch) 65 | self.plot_psnr_log(epoch) 66 | 67 | def save_images(self, filename, save_list, scale): 68 | if self.args.task == 'Image': 69 | filename = '{}/result/{}/{}_x{}_'.format(self.dir, self.args.data_test, filename, scale) 70 | postfix = ['LR', 'HR', 'SR'] 71 | elif self.args.task == 'MC': 72 | f = filename.split('.') 73 | filename = '{}/result/{}/{}/{}_'.format(self.dir, self.args.data_test, f[0], f[1]) 74 | if not os.path.exists(os.path.dirname(filename)): 75 | os.makedirs(os.path.dirname(filename)) 76 | postfix = ['f1', 'f2', 'f2c'] 77 | elif self.args.task == 'Video': 78 | f = filename.split('.') 79 | filename = '{}/result/{}/{}/{}_'.format(self.dir, self.args.data_test, f[0], f[1]) 80 | if not os.path.exists(os.path.dirname(filename)): 81 | os.makedirs(os.path.dirname(filename)) 82 | postfix = ['LR', 'HR', 'SR'] 83 | for img, post in zip(save_list, postfix): 84 | img = img[0].data.mul(255 / self.args.rgb_range) 85 | img = np.transpose(img.cpu().numpy(), (1, 2, 0)) 86 | if img.shape[2] == 1: 87 | img = img.squeeze(axis=2) 88 | elif img.shape[2] == 3 and self.args.n_colors == 1: 89 | img = sc.ycbcr2rgb(img.astype('float')).clip(0, 1) 90 | img = (255 * img).round().astype('uint8') 91 | #img = img[:,:,0].round().astype('uint8') 92 | if post == 'LR': 93 | img = misc.imresize(img, size=self.args.scale*100, interp='bicubic') 94 | imageio.imwrite('{}{}.png'.format(filename, post), img) 95 | 96 | def start_log(self, train=True): 97 | if train: 98 | self.loss_log = torch.cat((self.loss_log, torch.zeros(1))) 99 | else: 100 | self.psnr_log = torch.cat((self.psnr_log, torch.zeros(1))) 101 | 102 | def report_log(self, item, train=True): 103 | if train: 104 | self.loss_log[-1] += item 105 | else: 106 | self.psnr_log[-1] += item 107 | 108 | def end_log(self, n_div, train=True): 109 | if train: 110 | self.loss_log[-1].div_(n_div) 111 | else: 112 | self.psnr_log[-1].div_(n_div) 113 | 114 | def plot_loss_log(self, epoch): 115 | axis = np.linspace(1, epoch, epoch) 116 | fig = plt.figure() 117 | plt.title('Loss Graph') 118 | plt.plot(axis, self.loss_log.numpy()) 119 | plt.legend() 120 | plt.xlabel('Epochs') 121 | plt.ylabel('Loss') 122 | plt.grid(True) 123 | plt.savefig(os.path.join(self.dir, 'loss.pdf')) 124 | plt.close(fig) 125 | 126 | def plot_psnr_log(self, epoch): 127 | axis = 
np.linspace(1, epoch, epoch) 128 | fig = plt.figure() 129 | plt.title('PSNR Graph') 130 | plt.plot(axis, self.psnr_log.numpy()) 131 | plt.legend() 132 | plt.xlabel('Epochs') 133 | plt.ylabel('PSNR') 134 | plt.grid(True) 135 | plt.savefig(os.path.join(self.dir, 'psnr.pdf')) 136 | plt.close(fig) 137 | 138 | def done(self): 139 | self.log_file.close() 140 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import data 4 | import model 5 | from option import args 6 | from trainer import Trainer 7 | from trainer_mc import Trainer_MC 8 | from trainer_vsr import Trainer_VSR 9 | from logger import logger 10 | 11 | torch.manual_seed(args.seed) 12 | chkp = logger.Logger(args) 13 | if args.task == 'MC': 14 | print("Selected task: MC") 15 | model = model.Model(args, chkp) 16 | loader = data.Data(args) 17 | t = Trainer_MC(args, loader, model, chkp) 18 | while not t.terminate(): 19 | t.train() 20 | t.test() 21 | 22 | elif args.task == 'Video': 23 | print("Selected task: Video") 24 | model = model.Model(args, chkp) 25 | loader = data.Data(args) 26 | t = Trainer_VSR(args, loader, model, chkp) 27 | while not t.terminate(): 28 | t.train() 29 | t.test() 30 | 31 | elif args.task == 'Image': 32 | print("Selected task: Image") 33 | loader = data.Data(args) 34 | model = model.Model(args, chkp) 35 | t = Trainer(args, loader, model, chkp) 36 | while not t.terminate(): 37 | t.train() 38 | t.test() 39 | 40 | else: 41 | print('Please Enter Appropriate Task Type!!!') 42 | 43 | chkp.done() 44 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib import import_module 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class Model(nn.Module): 9 | def __init__(self, args, ckp): 10 | super(Model, self).__init__() 11 | #print('Making model...') 12 | self.args = args 13 | self.scale = args.scale 14 | self.cpu = args.cpu 15 | self.device = torch.device('cpu' if args.cpu else 'cuda') 16 | self.n_GPUs = args.n_GPUs 17 | self.ckp = ckp 18 | 19 | module = import_module('model.' 
+ args.model.lower()) 20 | self.model = module.make_model(args).to(self.device) 21 | if not args.cpu and args.n_GPUs > 1: 22 | self.model = nn.DataParallel(self.model, range(args.n_GPUs)) 23 | 24 | self.load( 25 | ckp.dir, 26 | pre_train=args.pre_train, 27 | resume=args.resume, 28 | cpu=args.cpu 29 | ) 30 | print(self.get_model(), file=ckp.log_file) 31 | 32 | def forward(self, *args): 33 | return self.model(*args) 34 | 35 | def get_model(self): 36 | if self.n_GPUs == 1: 37 | return self.model 38 | else: 39 | return self.model.module 40 | 41 | def state_dict(self, **kwargs): 42 | target = self.get_model() 43 | return target.state_dict(**kwargs) 44 | 45 | def save(self, apath, is_best=False, filename=''): 46 | target = self.get_model() 47 | filename = 'model_{}'.format(filename) 48 | torch.save( 49 | target.state_dict(), 50 | os.path.join(apath, 'model', '{}latest.pt'.format(filename)) 51 | ) 52 | if is_best: 53 | torch.save( 54 | target.state_dict(), 55 | os.path.join(apath, 'model', '{}best.pt'.format(filename)) 56 | ) 57 | 58 | def load(self, apath, pre_train='.', resume=False, cpu=False): 59 | if cpu: 60 | kwargs = {'map_location': lambda storage, loc: storage} 61 | else: 62 | kwargs = {} 63 | 64 | if pre_train != '.': 65 | print('Loading model from {}'.format(pre_train)) 66 | self.get_model().load_state_dict( 67 | torch.load(pre_train, **kwargs), 68 | strict=False 69 | ) 70 | 71 | elif resume: 72 | print('Loading model from {}'.format(os.path.join(apath, 'model', 'model_latest.pt'))) 73 | self.get_model().load_state_dict( 74 | torch.load( 75 | os.path.join(apath, 'model', 'model_latest.pt'), 76 | **kwargs 77 | ), 78 | strict=False 79 | ) 80 | elif self.args.test_only: 81 | self.get_model().load_state_dict( 82 | torch.load( 83 | os.path.join(apath, 'model', 'model_best.pt'), 84 | **kwargs 85 | ), 86 | strict=False 87 | ) 88 | 89 | -------------------------------------------------------------------------------- /model/espcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def make_model(args): 6 | return ESPCN(args) 7 | 8 | class ESPCN(nn.Module): 9 | #upscale_factor -> args 10 | def __init__(self, args): 11 | super(ESPCN, self).__init__() 12 | print("Creating ESPCN (x%d)" %args.scale) 13 | self.conv1 = nn.Conv2d(args.n_colors, 64, kernel_size=5, padding=2) 14 | self.conv2 = nn.Conv2d(64, 32, kernel_size=3, padding=1) 15 | self.conv3 = nn.Conv2d(32, args.n_colors * args.scale * args.scale, kernel_size=3, padding=1) 16 | self.pixel_shuffle = nn.PixelShuffle(args.scale) 17 | self.conv4 = nn.Conv2d(args.n_colors, args.n_colors, kernel_size=1, padding=0) 18 | self.relu = nn.ReLU() 19 | self.tanh = nn.Tanh() 20 | 21 | def forward(self, x): 22 | x = self.relu(self.conv1(x)) 23 | x = self.relu(self.conv2(x)) 24 | x = self.relu(self.conv3(x)) 25 | x = self.pixel_shuffle(x) 26 | x = self.conv4(x) 27 | return x 28 | -------------------------------------------------------------------------------- /model/espcn_modified.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def make_model(args): 6 | return ESPCN_modified(args) 7 | 8 | class ESPCN_modified(nn.Module): 9 | #upscale_factor -> args 10 | def __init__(self, args): 11 | super(ESPCN_modified, self).__init__() 12 | print("Creating modified ESPCN (x%d)" %args.scale) 13 | self.conv1 = nn.Conv2d(args.n_colors, 
64, kernel_size = 5, padding = 2) 14 | self.conv1_1 = nn.Conv2d(64, 64, kernel_size = 3, padding = 1) 15 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size = 3, padding = 1) 16 | self.conv1_3 = nn.Conv2d(64, 64, kernel_size = 3, padding = 1) 17 | self.conv2 = nn.Conv2d(64, 32, kernel_size = 3, padding = 1) 18 | self.conv3 = nn.Conv2d(32, args.n_colors * args.scale * args.scale, kernel_size = 3, padding = 1) 19 | self.pixel_shuffle = nn.PixelShuffle(args.scale) 20 | self.conv4 = nn.Conv2d(args.n_colors, args.n_colors, kernel_size = 1, padding = 0) 21 | self.relu = nn.ReLU() 22 | self.tanh = nn.Tanh() 23 | 24 | def forward(self, x): 25 | x = self.relu(self.conv1(x)) 26 | x = self.relu(self.conv1_1(x)) 27 | x = self.relu(self.conv1_2(x)) 28 | x = self.relu(self.conv1_3(x)) 29 | x = self.relu(self.conv2(x)) 30 | x = self.relu(self.conv3(x)) 31 | x = self.pixel_shuffle(x) 32 | #x = self.tanh(x) 33 | x = self.conv4(x) 34 | return x 35 | -------------------------------------------------------------------------------- /model/espcn_multiframe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def make_model(args): 6 | return ESPCN_multiframe(args) 7 | 8 | class ESPCN_multiframe(nn.Module): 9 | #upscale_factor -> args 10 | def __init__(self, args): 11 | super(ESPCN_multiframe, self).__init__() 12 | self.name = 'ESPCN_mf' 13 | print("Creating ESPCN multiframe (x%d)" %args.scale) 14 | ''' 15 | self.network = [nn.Conv2d(args.n_colors*args.n_sequence, 24, kernel_size = 3, padding =1), nn.ReLU(True)] 16 | for i in range(0,3): 17 | self.network.extend([nn.Conv2d(24, 24, kernel_size = 3, padding =1), nn.ReLU(True)]) 18 | 19 | self.network.extend([nn.Conv2d(24, args.n_colors * args.scale * args.scale, kernel_size = 3, padding =1), nn.ReLU(True)]) 20 | self.network.extend([nn.PixelShuffle(args.scale)]) 21 | self.network.extend([nn.Conv2d(args.n_colors, args.n_colors, kernel_size = 1, padding = 0)]) 22 | ''' 23 | self.network = [nn.Conv2d(args.n_colors*args.n_sequence, 64, kernel_size = 5, padding =2), nn.ReLU(True)] 24 | self.network.extend([nn.Conv2d(64, 48, kernel_size = 3, padding =1), nn.ReLU(True)]) 25 | self.network.extend([nn.Conv2d(48, 32, kernel_size = 3, padding =1), nn.ReLU(True)]) 26 | self.network.extend([nn.Conv2d(32, 24, kernel_size = 3, padding =1), nn.ReLU(True)]) 27 | self.network.extend([nn.Conv2d(24, args.n_colors * args.scale * args.scale, kernel_size = 3, padding =1), nn.ReLU(True)]) 28 | self.network.extend([nn.PixelShuffle(args.scale)]) 29 | self.network.extend([nn.Conv2d(args.n_colors, args.n_colors, kernel_size = 1, padding = 0)]) 30 | 31 | self.net = nn.Sequential(*self.network) 32 | 33 | 34 | def forward(self, x): 35 | if isinstance(x, list): 36 | # squeeze frames n_sequence * [N, 1, n_colors, H, W] -> n_sequence * [N, n_colors, H, W] 37 | lr_frames_squeezed = [torch.squeeze(frame, dim = 1) for frame in x] 38 | # concatenate frames n_sequence * [N, n_colors, H, W] -> [N, n_sequence * n_colors, H, W] 39 | x = torch.cat(lr_frames_squeezed, dim = 1) 40 | 41 | return self.net(x) -------------------------------------------------------------------------------- /model/espcn_multiframe2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def make_model(args): 7 | return ESPCN_multiframe2(args) 8 | 9 | 10 | class ESPCN_multiframe2(nn.Module): 11 | # Add 
Residual connection! 12 | def __init__(self, args): 13 | super(ESPCN_multiframe2, self).__init__() 14 | print("Creating ESPCN multiframe2 (x%d)" % args.scale) 15 | network = [nn.Conv2d(args.n_colors * args.n_sequence, 64, kernel_size=3, padding=1), nn.ReLU(True)] 16 | network.extend([nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(True)]) 17 | network.extend([nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(True)]) 18 | network.extend([nn.Conv2d(64, 32, kernel_size=3, padding=1), nn.ReLU(True)]) 19 | network.extend([nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(True)]) 20 | network.extend([nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(True)]) 21 | network.extend([nn.Conv2d(32, 20, kernel_size=3, padding=1), nn.ReLU(True)]) 22 | network.extend( 23 | [nn.Conv2d(20, args.n_colors * args.scale * args.scale, kernel_size=3, padding=1), nn.ReLU(True)]) 24 | network.extend([nn.PixelShuffle(args.scale)]) 25 | network.extend([nn.Conv2d(args.n_colors, args.n_colors, kernel_size=1, padding=0)]) 26 | 27 | self.net = nn.Sequential(*network) 28 | 29 | def forward(self, x): 30 | if isinstance(x, list): 31 | # squeeze frames n_sequence * [N, 1, n_colors, H, W] -> n_sequence * [N, n_colors, H, W] 32 | lr_frames_squeezed = [torch.squeeze(frame, dim = 1) for frame in x] 33 | # concatenate frames n_sequence * [N, n_colors, H, W] -> [N, n_sequence * n_colors, H, W] 34 | x = torch.cat(lr_frames_squeezed, dim = 1) 35 | 36 | return self.net(x) 37 | -------------------------------------------------------------------------------- /model/motioncompensator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | def make_model(args): 7 | return MotionCompensator(args) 8 | 9 | class MotionCompensator(nn.Module): 10 | def __init__(self, args): 11 | self.device = 'cuda' 12 | if args.cpu: 13 | self.device = 'cpu' 14 | super(MotionCompensator, self).__init__() 15 | print("Creating Motion compensator") 16 | 17 | def _gconv(in_channels, out_channels, kernel_size=3, groups=1, stride=1, bias=True): 18 | return nn.Conv2d(in_channels*groups, out_channels*groups, kernel_size, groups=groups, stride=stride, 19 | padding=(kernel_size // 2), bias=bias) 20 | 21 | # Coarse flow 22 | coarse_flow = [_gconv(2, 24, kernel_size=5, groups=args.n_colors, stride=2), nn.ReLU(inplace=True)] 23 | coarse_flow.extend([_gconv(24, 24, kernel_size=3, groups=args.n_colors), nn.ReLU(True)]) 24 | coarse_flow.extend([_gconv(24, 24, kernel_size=5, groups=args.n_colors, stride=2), nn.ReLU(True)]) 25 | coarse_flow.extend([_gconv(24, 24, kernel_size=3, groups=args.n_colors), nn.ReLU(True)]) 26 | coarse_flow.extend([_gconv(24, 32, kernel_size=3, groups=args.n_colors), nn.Tanh()]) 27 | coarse_flow.extend([nn.PixelShuffle(4)]) 28 | 29 | self.C_flow = nn.Sequential(*coarse_flow) 30 | 31 | # Fine flow 32 | fine_flow = [_gconv(5, 24, kernel_size=5, groups=args.n_colors, stride=2), nn.ReLU(inplace=True)] 33 | for _ in range(3): 34 | fine_flow.extend([_gconv(24, 24, kernel_size=3, groups=args.n_colors), nn.ReLU(True)]) 35 | fine_flow.extend([_gconv(24, 8, kernel_size=3, groups=args.n_colors), nn.Tanh()]) 36 | fine_flow.extend([nn.PixelShuffle(2)]) 37 | 38 | self.F_flow = nn.Sequential(*fine_flow) 39 | 40 | def forward(self, frame_1, frame_2): 41 | # Create identity flow 42 | x = np.linspace(-1, 1, frame_1.shape[3]) 43 | y = np.linspace(-1, 1, frame_1.shape[2]) 44 | xv, yv = np.meshgrid(x, y) 45 | id_flow = 
np.expand_dims(np.stack([xv, yv], axis=-1), axis=0) 46 | self.identity_flow = torch.from_numpy(id_flow).float().to(self.device) 47 | 48 | # Coarse flow 49 | coarse_in = torch.cat((frame_1, frame_2), dim=1) 50 | coarse_out = self.C_flow(coarse_in) 51 | coarse_out[:,0] /= frame_1.shape[3] 52 | coarse_out[:,1] /= frame_2.shape[2] 53 | frame_2_compensated_coarse = self.warp(frame_2, coarse_out) 54 | 55 | # Fine flow 56 | fine_in = torch.cat((frame_1, frame_2, frame_2_compensated_coarse, coarse_out), dim=1) 57 | fine_out = self.F_flow(fine_in) 58 | fine_out[:,0] /= frame_1.shape[3] 59 | fine_out[:,1] /= frame_2.shape[2] 60 | flow = (coarse_out + fine_out) 61 | 62 | frame_2_compensated = self.warp(frame_2, flow) 63 | 64 | return frame_2_compensated, flow 65 | 66 | def warp(self, img, flow): 67 | # https://discuss.pytorch.org/t/solved-how-to-do-the-interpolating-of-optical-flow/5019 68 | # permute flow N C H W -> N H W C 69 | img_compensated = F.grid_sample(img, (-flow.permute(0,2,3,1)+self.identity_flow).clamp(-1,1), padding_mode='border') 70 | return img_compensated 71 | -------------------------------------------------------------------------------- /model/vespcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from model.motioncompensator import make_model as make_mc 6 | from model.espcn_multiframe2 import make_model as make_espcn 7 | from approx_huber_loss import Approx_Huber_Loss 8 | 9 | def make_model(args): 10 | return VESPCN(args) 11 | 12 | class VESPCN(nn.Module): 13 | def __init__(self, args): 14 | self.name = 'VESPCN' 15 | self.device = 'cuda' 16 | if args.cpu: 17 | self.device = 'cpu' 18 | super(VESPCN, self).__init__() 19 | print("Creating VESPCN") 20 | 21 | self.mseloss = nn.MSELoss() 22 | self.huberloss = Approx_Huber_Loss(args) 23 | self.motionCompensator = make_mc(args) 24 | self.espcn = make_espcn(args) 25 | 26 | #self.motionCompensator.load_ 27 | self.motionCompensator.load_state_dict(torch.load('./experiment/model_best_mc.pt'), strict=False) 28 | self.espcn.load_state_dict(torch.load('./experiment/model_best_espcn.pt'), strict=False) 29 | #self.espcn.load_state_dict(torch.load('./experiment/ESPCN_multiframe/model/model_best.pt'), strict=False) 30 | 31 | def forward(self, frame_list): 32 | # squeeze frames n_sequence * [N, 1, n_colors, H, W] -> n_sequence * [N, n_colors, H, W] 33 | frame_list = [torch.squeeze(frame, dim = 1) for frame in frame_list] 34 | 35 | frame1 = frame_list[0] 36 | frame2 = frame_list[1] 37 | frame3 = frame_list[2] 38 | 39 | frame1_compensated, flow1 = self.motionCompensator(frame2, frame1) 40 | frame3_compensated, flow2 = self.motionCompensator(frame2, frame3) 41 | 42 | loss_mc_mse = self.mseloss(frame1_compensated, frame2) + self.mseloss(frame3_compensated, frame2) 43 | loss_mc_huber = self.huberloss(flow1) + self.huberloss(flow2) 44 | 45 | #print(frame1_compensated.shape, frame2.shape, frame3_compensated.shape) 46 | # n_sequence * [N, n_colors, H, W] -> [N, n_sequence * n_colors, H, W] 47 | lr_frames_cat = torch.cat((frame1_compensated, frame2, frame3_compensated), dim = 1) 48 | #print(lr_frames_cat.shape) 49 | return self.espcn(lr_frames_cat), loss_mc_mse, loss_mc_huber 50 | -------------------------------------------------------------------------------- /option.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import template 3 | 4 | parser = 
argparse.ArgumentParser(description='EDSR and MDSR') 5 | 6 | parser.add_argument('--debug', action='store_true', 7 | help='Enables debug mode') 8 | parser.add_argument('--template', default='.', 9 | help='You can set various templates in option.py') 10 | parser.add_argument('--task', type=str, default='Video', 11 | help='Type of task (Image/Video/MC)') 12 | 13 | # Hardware specifications 14 | parser.add_argument('--n_threads', type=int, default=6, 15 | help='number of threads for data loading') 16 | parser.add_argument('--cpu', action='store_true', 17 | help='use cpu only') 18 | parser.add_argument('--n_GPUs', type=int, default=1, 19 | help='number of GPUs') 20 | parser.add_argument('--seed', type=int, default=1, 21 | help='random seed') 22 | 23 | # Data specifications 24 | parser.add_argument('--dir_data', type=str, default='../../Dataset', 25 | help='dataset directory') 26 | parser.add_argument('--dir_data_test', type=str, default='../../Dataset', 27 | help='dataset directory') 28 | parser.add_argument('--dir_demo', type=str, default='../test', 29 | help='demo image directory') 30 | parser.add_argument('--data_train', type=str, default='DIV2K', 31 | help='train dataset name') 32 | parser.add_argument('--data_test', type=str, default='Set5', 33 | help='test dataset name') 34 | parser.add_argument('--data_range', type=str, default='1-90/91-100', 35 | help='train/test data range') 36 | parser.add_argument('--process', action='store_true', 37 | help='if True, load all dataset at once at RAM') 38 | parser.add_argument('--scale', type=str, default=3, 39 | help='super resolution scale') 40 | parser.add_argument('--patch_size', type=int, default=20, 41 | help='output patch size') 42 | parser.add_argument('--rgb_range', type=int, default=1, 43 | help='maximum value of RGB') 44 | parser.add_argument('--n_colors', type=int, default=1, 45 | help='number of color channels to use') 46 | parser.add_argument('--no_augment', action='store_true', 47 | help='do not use data augmentation') 48 | 49 | # Video SR parameters 50 | parser.add_argument('--n_sequence', type=int, default=3, 51 | help='length of image sequence per video') 52 | parser.add_argument('--n_frames_per_video', type=int, default=30, 53 | help='number of frames per video to load') 54 | 55 | 56 | # Model specifications 57 | parser.add_argument('--model', default='ESPCN', 58 | help='model name') 59 | parser.add_argument('--pre_train', type=str, default='.', 60 | help='pre-trained model directory') 61 | 62 | 63 | # Training specifications 64 | parser.add_argument('--reset', action='store_true', 65 | help='reset the training') 66 | parser.add_argument('--test_every', type=int, default=1000, 67 | help='do test per every N batches') 68 | parser.add_argument('--epochs', type=int, default=1000, 69 | help='number of epochs to train') 70 | parser.add_argument('--batch_size', type=int, default=16, 71 | help='input batch size for training') 72 | parser.add_argument('--test_only', action='store_true', 73 | help='set this option to test the model') 74 | 75 | 76 | # Optimization specifications 77 | parser.add_argument('--lr', type=float, default=1e-4, 78 | help='learning rate') 79 | parser.add_argument('--lr_decay', type=int, default=200, 80 | help='learning rate decay per N epochs') 81 | parser.add_argument('--decay_type', type=str, default='step', 82 | help='learning rate decay type') 83 | parser.add_argument('--gamma', type=float, default=0.5, 84 | help='learning rate decay factor for step decay') 85 | parser.add_argument('--beta1', type=float, 
default=0.9, 86 | help='ADAM beta1') 87 | parser.add_argument('--beta2', type=float, default=0.999, 88 | help='ADAM beta2') 89 | parser.add_argument('--epsilon', type=float, default=1e-8, 90 | help='ADAM epsilon for numerical stability') 91 | parser.add_argument('--weight_decay', type=float, default=0, 92 | help='weight decay') 93 | parser.add_argument('--lambd', type=float, default=0.0005, 94 | help='coefficient for modified huber loss') 95 | parser.add_argument('--beta', type=float, default=0.005, 96 | help='coefficient for motioncompensation mse loss') 97 | 98 | 99 | 100 | # Log specifications 101 | parser.add_argument('--save', type=str, default='save_path', 102 | help='file name to save') 103 | parser.add_argument('--load', type=str, default='.', 104 | help='file name to load') 105 | parser.add_argument('--resume', action='store_true', 106 | help='resume from the latest if true') 107 | parser.add_argument('--print_every', type=int, default=100, 108 | help='how many batches to wait before logging training status') 109 | parser.add_argument('--save_images', default=True, action='store_false', 110 | help='save images') 111 | 112 | args = parser.parse_args() 113 | template.set_template(args) 114 | 115 | if args.epochs == 0: 116 | args.epochs = 1e8 117 | 118 | for arg in vars(args): 119 | if vars(args)[arg] == 'True': 120 | vars(args)[arg] = True 121 | elif vars(args)[arg] == 'False': 122 | vars(args)[arg] = False 123 | -------------------------------------------------------------------------------- /template.py: -------------------------------------------------------------------------------- 1 | def set_template(args): 2 | if args.template == 'SY': 3 | args.data_train = 'CDVL_VIDEO' 4 | args.data_range = '1-16/90-100' 5 | args.dir_data = '../../Dataset' 6 | args.data_test = 'Vid4' 7 | args.dir_data_test = '../../Dataset' 8 | args.process = True 9 | args.lr = 5e-4 10 | if args.task == 'MC': 11 | args.model = 'MotionCompensator' 12 | args.n_sequence = 2 13 | 14 | elif args.task == 'Image': 15 | args.model = 'ESPCN' 16 | 17 | elif args.task == 'Video': 18 | args.model = 'ESPCN_multiframe2' 19 | args.patch_size = 17 20 | args.n_sequence = 3 21 | 22 | elif args.template == 'JH': 23 | args.task = "Image" 24 | args.save = args.model 25 | args.data_train = 'CDVL100' 26 | args.dir_data = '/home/johnyi/deeplearning/research/SISR_Datasets/train' 27 | args.data_test = 'Set5' 28 | args.dir_data_test = '/home/johnyi/deeplearning/research/SISR_Datasets/test' 29 | args.process = True 30 | elif args.template == 'JH_Video': 31 | args.task = "Video" 32 | args.save = args.model 33 | args.test_every = 1000 34 | args.n_sequence = 3 35 | args.n_frames_per_video = 15 36 | args.data_range = '1-135/91-100' 37 | args.data_train = 'CDVL_VIDEO' 38 | args.dir_data = '/home/johnyi/deeplearning/research/VSR_Datasets/train' 39 | args.data_test = 'Vid4' 40 | args.dir_data_test = '/home/johnyi/deeplearning/research/VSR_Datasets/test' 41 | args.process = True 42 | elif args.template == 'JH_MC': 43 | args.task = "MC" 44 | args.model = "MotionCompensator" 45 | args.save = args.model 46 | args.n_sequence = 2 47 | args.n_frames_per_video = 15 48 | args.data_range = '1-5/91-100' 49 | args.epochs = 1000 50 | args.data_train = 'CDVL_VIDEO' 51 | args.dir_data = '/home/johnyi/deeplearning/research/VSR_Datasets/train' 52 | args.data_test = 'Vid4' 53 | args.dir_data_test = '/home/johnyi/deeplearning/research/VSR_Datasets/test' 54 | args.process = True 55 | else: 56 | # TODO: Download train/test data & modify args for real testing 57 | 
args.batch_size = 2 58 | args.epochs = 1000 59 | args.dir_data = '/Users/junhokim/videoSR/data' 60 | args.dir_data_test = '/Users/junhokim/videoSR/data' 61 | args.process = True 62 | args.n_sequence = 4 63 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # Below are imports just for testing 2 | from option import args 3 | import cv2, imageio 4 | from data.vsrdata import VSRData 5 | """ 6 | Original test setting: 5 videos with 30 frames inside each directory 7 | length of frame sequence = 4 8 | batch_size = 2 9 | """ 10 | ''' 11 | if __name__ == '__main__': 12 | if args.template == 'SY': 13 | vsr = VSRData(args, name='CDVL_Video', train=False) 14 | else: 15 | vsr = VSRData(args) 16 | print(len(vsr.data_hr)) # 5 17 | print(len(vsr.data_lr)) # 5 18 | print(vsr.data_hr[1].shape) # (4,1080,1920,3) 19 | print(vsr.data_lr[1].shape) # (4, 360, 640, 3) 20 | img_samples = [] 21 | for i in range(args.n_sequence): 22 | imageio.imwrite('hr_{}.jpg'.format(i), vsr.data_hr[0][i, :]) 23 | imageio.imwrite('lr_{}.jpg'.format(i), vsr.data_lr[0][i, :]) 24 | print(len(vsr[0][0])) # 4 25 | print(vsr[0][0][0].shape) # torch.Size([3,17,17]) 26 | print(len(vsr[0][1])) # 4 27 | print(vsr[0][1][0].shape) # torch.Size([3,51,51]) 28 | print(vsr[0][2]) # ['00001', '00002', '00003', '00004'] 29 | ''' 30 | 31 | 32 | import torch.nn.functional as F 33 | import torch.optim as optim 34 | import torch 35 | import torchvision 36 | from PIL import Image 37 | import numpy as np 38 | 39 | img = Image.open('./frame2.png') 40 | img = np.array(img) 41 | img = np.array([img]).astype("float64") 42 | b = torch.from_numpy(img).double() 43 | b = b.permute(0, 3, 1, 2) 44 | print(b.size()) 45 | ''' 46 | flow = np.zeros((img.shape[1], img.shape[2], 2)) 47 | for i in range(0, img.shape[1]): 48 | for j in range(0, img.shape[2]): 49 | flow[i,j,:] = [i, j] 50 | flow[:,:,0] = flow[:,:,0]/img.shape[1] 51 | flow[:,:,1] = flow[:,:,1]/img.shape[2] 52 | flow = np.concatenate((flow[:,:,1:], flow[:,:,0:1]), axis=2) 53 | flow = (flow - 0.5) * 2 54 | flow = np.array([flow]) 55 | flow = torch.from_numpy(flow).double() 56 | ''' 57 | # Create identity flow 58 | x = np.linspace(-1, 1, img.shape[2]) - 0.01 59 | y = np.linspace(-1, 1, img.shape[1]) - 0.01 60 | xv, yv = np.meshgrid(x, y) 61 | id_flow = np.expand_dims(np.stack([xv, yv], axis=-1), axis=0) 62 | flow = torch.from_numpy(id_flow).double() 63 | 64 | compensated = F.grid_sample(b, flow) 65 | out = compensated.permute(0,2,3,1) 66 | print(compensated.shape) 67 | print(out.shape) 68 | out = np.round(out.numpy()).astype("uint8")[0] 69 | print(np.max(out), np.min(out)) 70 | imageio.imwrite("./out.png", out) -------------------------------------------------------------------------------- /tools/imgresize.m: -------------------------------------------------------------------------------- 1 | function imgresize() 2 | folders = dir('./HR'); 3 | for i = 3:numel(folders) 4 | foldername = folders(i).name; 5 | folderpath = ['./HR/', foldername]; 6 | files = dir(folderpath); 7 | newfoldpath = ['./LR_bicubic/X3/', foldername]; 8 | mkdir(newfoldpath); 9 | for j = 3:numel(files) 10 | fname = files(j).name; 11 | disp(fname) 12 | fullname = fullfile(folderpath, fname); 13 | img = imread(fullname); 14 | lrimg = imresize(img, 1/3); 15 | newfpath = fullfile(newfoldpath, fname); 16 | imwrite(lrimg, newfpath); 17 | end 18 | end 19 | end 
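For setups without MATLAB, a rough Python counterpart of imgresize.m is sketched below. It is an illustration only: the file name tools/imgresize.py is hypothetical, it assumes Pillow is installed, and PIL's bicubic kernel does not exactly match MATLAB's imresize, so the resulting LR frames (and hence PSNR figures) may differ slightly.

# tools/imgresize.py (hypothetical): bicubic 1/3 downscale of ./HR/*/ into ./LR_bicubic/X3/*/
import os
from PIL import Image

def imgresize(hr_root='./HR', lr_root='./LR_bicubic/X3', scale=3):
    for foldername in sorted(os.listdir(hr_root)):
        folderpath = os.path.join(hr_root, foldername)
        if not os.path.isdir(folderpath):
            continue
        newfoldpath = os.path.join(lr_root, foldername)
        os.makedirs(newfoldpath, exist_ok=True)
        for fname in sorted(os.listdir(folderpath)):
            img = Image.open(os.path.join(folderpath, fname))
            # Floor division differs from MATLAB's size rounding when the
            # dimensions are not multiples of scale
            lrimg = img.resize((img.width // scale, img.height // scale), Image.BICUBIC)
            lrimg.save(os.path.join(newfoldpath, fname))

if __name__ == '__main__':
    imgresize()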
-------------------------------------------------------------------------------- /tools/videocapture.m: -------------------------------------------------------------------------------- 1 | function videocapture(videoname, time) 2 | videon = ['video', videoname, '.avi']; 3 | v = VideoReader(videon); 4 | 5 | v.CurrentTime = time; 6 | x = v.CurrentTime; 7 | subfolder = ['Video', videoname]; 8 | mkdir('./HR', subfolder); 9 | for ii = 1:30 10 | img = readFrame(v); 11 | filename = [sprintf('%05d',int32(ii+x * 30)) '.png']; 12 | fullname = fullfile('./HR',subfolder, filename); 13 | imwrite(img,fullname) 14 | end 15 | end 16 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import time 4 | import imageio 5 | import decimal 6 | 7 | import numpy as np 8 | from scipy import misc 9 | 10 | import torch 11 | import torch.optim as optim 12 | import torch.optim.lr_scheduler as lrs 13 | import torch.nn as nn 14 | from tqdm import tqdm 15 | 16 | import utils 17 | 18 | 19 | class Trainer: 20 | def __init__(self, args, loader, my_model, ckp): 21 | self.args = args 22 | self.scale = args.scale 23 | self.device = torch.device('cpu' if self.args.cpu else 'cuda') 24 | self.loader_train = loader.loader_train 25 | self.loader_test = loader.loader_test 26 | self.model = my_model 27 | self.optimizer = self.make_optimizer() 28 | self.scheduler = self.make_scheduler() 29 | self.ckp = ckp 30 | self.loss = nn.MSELoss() 31 | 32 | if args.load != '.': 33 | self.optimizer.load_state_dict(torch.load(os.path.join(ckp.dir, 'optimizer.pt'))) 34 | for _ in range(len(ckp.psnr_log)): 35 | self.scheduler.step() 36 | 37 | def make_optimizer(self): 38 | kwargs = {'lr': self.args.lr, 'weight_decay': self.args.weight_decay} 39 | return optim.Adam(self.model.parameters(), **kwargs) 40 | 41 | def make_scheduler(self): 42 | kwargs = {'step_size': self.args.lr_decay, 'gamma': self.args.gamma} 43 | return lrs.StepLR(self.optimizer, **kwargs) 44 | 45 | def train(self): 46 | self.scheduler.step() 47 | epoch = self.scheduler.last_epoch + 1 48 | lr = self.scheduler.get_lr()[0] 49 | 50 | self.ckp.write_log('Epoch {:3d} with Lr {:.2e}'.format(epoch, decimal.Decimal(lr))) 51 | 52 | self.model.train() 53 | self.ckp.start_log() 54 | for batch, (lr, _, hr, _) in enumerate(self.loader_train): 55 | if self.args.n_colors == 1 and lr.size()[1] == 3: 56 | lr = lr[:, 0:1, :, :] 57 | hr = hr[:, 0:1, :, :] 58 | 59 | lr = lr.to(self.device) 60 | hr = hr.to(self.device) 61 | 62 | self.optimizer.zero_grad() 63 | sr = self.model(lr) 64 | loss = self.loss(sr, hr) 65 | self.ckp.report_log(loss.item()) 66 | loss.backward() 67 | self.optimizer.step() 68 | 69 | if (batch+1) % self.args.print_every == 0: 70 | self.ckp.write_log('[{}/{}]\tLoss : {:.5f}'.format( 71 | (batch + 1) * self.args.batch_size, len(self.loader_train.dataset), 72 | self.ckp.loss_log[-1] / (batch + 1))) 73 | 74 | self.ckp.end_log(len(self.loader_train)) 75 | 76 | def test(self): 77 | epoch = self.scheduler.last_epoch + 1 78 | self.ckp.write_log('\nEvaluation:') 79 | self.model.eval() 80 | self.ckp.start_log(train=False) 81 | with torch.no_grad(): 82 | tqdm_test = tqdm(self.loader_test, ncols=80) 83 | bic_PSNR = 0 84 | for idx_img, (lr, lre, hr, filename) in enumerate(tqdm_test): 85 | ycbcr_flag = False 86 | if self.args.n_colors == 1 and lr.size()[1] == 3: 87 | # If n_colors is 1, split image into Y,Cb,Cr 88 | ycbcr_flag = True 89 | 
sr_cbcr = lre[:, 1:, :, :].to(self.device) 90 | lre = lre[:, 0:1, :, :] 91 | lr_cbcr = lr[:, 1:, :, :].to(self.device) 92 | lr = lr[:, 0:1, :, :] 93 | hr_cbcr = hr[:, 1:, :, :].to(self.device) 94 | hr = hr[:, 0:1, :, :] 95 | 96 | filename = filename[0] 97 | lre = lre.to(self.device) 98 | lr = lr.to(self.device) 99 | hr = hr.to(self.device) 100 | sr = self.model(lr) 101 | PSNR = utils.calc_psnr(self.args, sr, hr) 102 | bic_PSNR += utils.calc_psnr(self.args, lre, hr) 103 | self.ckp.report_log(PSNR, train=False) 104 | lr, hr, sr = utils.postprocess(lr, hr, sr, 105 | rgb_range=self.args.rgb_range, 106 | ycbcr_flag=ycbcr_flag, device=self.device) 107 | 108 | if ycbcr_flag: 109 | lr = torch.cat((lr, lr_cbcr), dim=1) 110 | hr = torch.cat((hr, hr_cbcr), dim=1) 111 | sr = torch.cat((sr, sr_cbcr), dim=1) 112 | 113 | save_list = [lr, hr, sr] 114 | if self.args.save_images: 115 | self.ckp.save_images(filename, save_list, self.args.scale) 116 | 117 | self.ckp.end_log(len(self.loader_test), train=False) 118 | best = self.ckp.psnr_log.max(0) 119 | self.ckp.write_log('[{}]\taverage PSNR: {:.3f} (Best: {:.3f} @epoch {})'.format( 120 | self.args.data_test, self.ckp.psnr_log[-1], 121 | best[0], best[1] + 1)) 122 | self.ckp.write_log('Bicubic PSNR: {:.3f}'.format(bic_PSNR / len(self.loader_test))) 123 | if not self.args.test_only: 124 | self.ckp.save(self, epoch, is_best=(best[1] + 1 == epoch)) 125 | 126 | def terminate(self): 127 | if self.args.test_only: 128 | self.test() 129 | return True 130 | else: 131 | epoch = self.scheduler.last_epoch + 1 132 | return epoch >= self.args.epochs -------------------------------------------------------------------------------- /trainer_mc.py: -------------------------------------------------------------------------------- 1 | import time 2 | import decimal 3 | import numpy as np 4 | 5 | import os 6 | import torch 7 | import torch.optim as optim 8 | import torch.optim.lr_scheduler as lrs 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from tqdm import tqdm 12 | from approx_huber_loss import Approx_Huber_Loss 13 | 14 | import utils 15 | 16 | class Trainer_MC: 17 | def __init__(self, args, loader, my_model, ckp): 18 | self.args = args 19 | self.scale = args.scale 20 | self.device = torch.device('cpu' if self.args.cpu else 'cuda') 21 | self.loader_train = loader.loader_train 22 | self.loader_test = loader.loader_test 23 | self.model = my_model 24 | self.optimizer = self.make_optimizer() 25 | self.scheduler = self.make_scheduler() 26 | self.ckp = ckp 27 | self.loss = nn.MSELoss() 28 | self.flow_loss = Approx_Huber_Loss(args) 29 | 30 | if args.load != '.': 31 | self.optimizer.load_state_dict(torch.load(os.path.join(ckp.dir, 'optimizer.pt'))) 32 | for _ in range(len(ckp.psnr_log)): 33 | self.scheduler.step() 34 | 35 | def set_loader(self, new_loader): 36 | self.loader_train = new_loader.loader_train 37 | self.loader_test = new_loader.loader_test 38 | 39 | def make_optimizer(self): 40 | kwargs = {'lr': self.args.lr, 'weight_decay': self.args.weight_decay} 41 | return optim.Adam(self.model.parameters(), **kwargs) 42 | 43 | def make_scheduler(self): 44 | kwargs = {'step_size': self.args.lr_decay, 'gamma': self.args.gamma} 45 | return lrs.StepLR(self.optimizer, **kwargs) 46 | 47 | def train(self): 48 | print("Motion compensation training") 49 | self.scheduler.step() 50 | epoch = self.scheduler.last_epoch + 1 51 | lr = self.scheduler.get_lr()[0] 52 | 53 | self.ckp.write_log('Epoch {:3d} with Lr {:.2e}'.format(epoch, decimal.Decimal(lr))) 54 | 55 | self.model.train() 56 | self.ckp.start_log()
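        # Each training batch is a stack of consecutive LR frames of shape
        # [B, n_sequence, C, H, W]; with n_sequence = 2 the pair unpacked below
        # is (frame1, frame2). frame2 is warped toward frame1 by the motion
        # compensator, and the loss is the warp MSE plus args.lambd times the
        # approximate Huber penalty on the flow gradients.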
57 | for batch, (lr, _, _) in enumerate(self.loader_train): 58 | # tensor size of lr : B*n*C*H*W (H, W = args.patch_size) 59 | self.optimizer.zero_grad() 60 | if self.args.n_colors == 1 and lr.size()[-3] == 3: 61 | lr = lr[:, :, 0:1, :, :] 62 | lr = lr.to(self.device) 63 | frame1, frame2 = lr[:, 0], lr[:, 1] 64 | frame2_compensated, flow = self.model(frame1, frame2) 65 | loss = self.loss(frame2_compensated, frame1) + self.args.lambd * self.flow_loss(flow) 66 | 67 | self.ckp.report_log(loss.item()) # TODO: Check logging issues for Huber loss 68 | loss.backward() 69 | self.optimizer.step() 70 | 71 | if (batch + 1) % self.args.print_every == 0: 72 | self.ckp.write_log('[{}/{}]\tLoss : {:.5f}'.format( 73 | (batch + 1) * self.args.batch_size, len(self.loader_train.dataset), 74 | self.ckp.loss_log[-1] / (batch + 1))) 75 | 76 | self.ckp.end_log(len(self.loader_train)) 77 | 78 | def test(self): 79 | epoch = self.scheduler.last_epoch + 1 80 | self.ckp.write_log('\nEvaluation:') 81 | self.model.eval() 82 | self.ckp.start_log(train=False) 83 | with torch.no_grad(): 84 | tqdm_test = tqdm(self.loader_test, ncols=80) 85 | for idx_img, (lr, _, filename) in enumerate(tqdm_test): 86 | ycbcr_flag = False 87 | filename = filename[0][0] 88 | lr = lr.to(self.device) 89 | frame1, frame2 = lr[:, 0], lr[:, 1] 90 | if self.args.n_colors == 1 and lr.size()[-3] == 3: 91 | ycbcr_flag = True 92 | frame1_cbcr = frame1[:, 1:] 93 | frame2_cbcr = frame2[:, 1:] 94 | frame1 = frame1[:, 0:1] 95 | frame2 = frame2[:, 0:1] 96 | 97 | frame2_compensated, flow = self.model(frame1, frame2) 98 | 99 | PSNR = utils.calc_psnr(self.args, frame1, frame2_compensated) 100 | self.ckp.report_log(PSNR, train=False) 101 | frame1, frame2, frame2c = utils.postprocess(frame1, frame2, frame2_compensated, 102 | rgb_range=self.args.rgb_range, ycbcr_flag=ycbcr_flag, device=self.device) 103 | 104 | if ycbcr_flag: 105 | frame1 = torch.cat((frame1, frame1_cbcr), dim=1) 106 | frame2 = torch.cat((frame2, frame2_cbcr), dim=1) 107 | frame2_cbcr_c = F.grid_sample(frame2_cbcr, flow.permute(0, 2, 3, 1), padding_mode='border') 108 | frame2c = torch.cat((frame2c, frame2_cbcr_c), dim=1) 109 | 110 | save_list = [frame1, frame2, frame2c] 111 | if self.args.save_images and idx_img%10 == 0: 112 | self.ckp.save_images(filename, save_list, self.args.scale) 113 | 114 | self.ckp.end_log(len(self.loader_test), train=False) 115 | best = self.ckp.psnr_log.max(0) 116 | self.ckp.write_log('[{}]\taverage PSNR: {:.3f} (Best: {:.3f} @epoch {})'.format( 117 | self.args.data_test, self.ckp.psnr_log[-1], 118 | best[0], best[1] + 1)) 119 | if not self.args.test_only: 120 | self.ckp.save(self, epoch, is_best=(best[1] + 1 == epoch)) 121 | 122 | def terminate(self): 123 | if self.args.test_only: 124 | self.test() 125 | return True 126 | else: 127 | epoch = self.scheduler.last_epoch + 1 128 | return epoch >= self.args.epochs 129 | -------------------------------------------------------------------------------- /trainer_vsr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import time 4 | import imageio 5 | import decimal 6 | 7 | import numpy as np 8 | from scipy import misc 9 | 10 | import torch 11 | import torch.optim as optim 12 | import torch.optim.lr_scheduler as lrs 13 | import torch.nn as nn 14 | from tqdm import tqdm 15 | 16 | import utils 17 | 18 | 19 | class Trainer_VSR: 20 | def __init__(self, args, loader, my_model, ckp): 21 | self.args = args 22 | self.scale = args.scale 23 | self.device = 
torch.device('cpu' if self.args.cpu else 'cuda') 24 | self.loader_train = loader.loader_train 25 | self.loader_test = loader.loader_test 26 | self.model = my_model 27 | self.optimizer = self.make_optimizer() 28 | self.scheduler = self.make_scheduler() 29 | self.ckp = ckp 30 | self.loss = nn.MSELoss() 31 | 32 | if args.load != '.': 33 | self.optimizer.load_state_dict(torch.load(os.path.join(ckp.dir, 'optimizer.pt'))) 34 | for _ in range(len(ckp.psnr_log)): 35 | self.scheduler.step() 36 | 37 | def set_loader(self, new_loader): 38 | self.loader_train = new_loader.loader_train 39 | self.loader_test = new_loader.loader_test 40 | 41 | def make_optimizer(self): 42 | kwargs = {'lr': self.args.lr, 'weight_decay': self.args.weight_decay} 43 | return optim.Adam(self.model.parameters(), **kwargs) 44 | 45 | def make_scheduler(self): 46 | kwargs = {'step_size': self.args.lr_decay, 'gamma': self.args.gamma} 47 | return lrs.StepLR(self.optimizer, **kwargs) 48 | 49 | def train(self): 50 | print("VSR training") 51 | self.scheduler.step() 52 | epoch = self.scheduler.last_epoch + 1 53 | lr = self.scheduler.get_lr()[0] 54 | 55 | self.ckp.write_log('Epoch {:3d} with Lr {:.2e}'.format(epoch, decimal.Decimal(lr))) 56 | 57 | self.model.train() 58 | self.ckp.start_log() 59 | for batch, (lr, hr, _) in enumerate(self.loader_train): 60 | # lr: [batch_size, n_seq, 3, patch_size, patch_size] 61 | if self.args.n_colors == 1 and lr.size()[2] == 3: 62 | lr = lr[:, :, 0:1, :, :] 63 | hr = hr[:, :, 0:1, :, :] 64 | 65 | # Divide LR frame sequence [N, n_sequence, n_colors, H, W] -> n_sequence * [N, 1, n_colors, H, W] 66 | lr = list(torch.split(lr, 1, dim=1)) # split size 1 -> one chunk per frame 67 | 68 | # target frame = middle HR frame [N, n_colors, H, W] 69 | hr = hr[:, int(hr.shape[1]/2), :, :, :] 70 | 71 | #lr = lr.to(self.device) 72 | lr = [x.to(self.device) for x in lr] 73 | hr = hr.to(self.device) 74 | 75 | self.optimizer.zero_grad() 76 | # output frame = single HR frame [N, n_colors, H, W] 77 | if self.model.get_model().name == 'ESPCN_mf': 78 | sr = self.model(lr) 79 | loss = self.loss(sr, hr) 80 | elif self.model.get_model().name == 'VESPCN': 81 | sr, loss_mc_mse, loss_mc_huber = self.model(lr) 82 | loss_mc = self.args.beta * loss_mc_mse + self.args.lambd * loss_mc_huber 83 | loss_espcn = self.loss(sr, hr) 84 | loss = loss_espcn + loss_mc 85 | 86 | self.ckp.report_log(loss.item()) 87 | loss.backward() 88 | self.optimizer.step() 89 | 90 | if (batch + 1) % self.args.print_every == 0: 91 | self.ckp.write_log('[{}/{}]\tLoss : {:.5f}'.format( 92 | (batch + 1) * self.args.batch_size, len(self.loader_train.dataset), 93 | self.ckp.loss_log[-1] / (batch + 1))) 94 | if self.model.get_model().name == 'VESPCN': print(loss_mc.item(), loss_espcn.item()) # these losses only exist on the VESPCN path 95 | self.ckp.end_log(len(self.loader_train)) 96 | 97 | def test(self): 98 | epoch = self.scheduler.last_epoch + 1 99 | self.ckp.write_log('\nEvaluation:') 100 | self.model.eval() 101 | self.ckp.start_log(train=False) 102 | with torch.no_grad(): 103 | tqdm_test = tqdm(self.loader_test, ncols=80) 104 | for idx_img, (lr, hr, filename) in enumerate(tqdm_test): 105 | ycbcr_flag = False 106 | filename = filename[0][0] 107 | # lr: [batch_size, n_seq, 3, patch_size, patch_size] 108 | if self.args.n_colors == 1 and lr.size()[2] == 3: 109 | # If n_colors is 1, split image into Y,Cb,Cr 110 | ycbcr_flag = True 111 | # for CbCr, select the middle frame 112 | lr_center_y = lr[:, int(hr.shape[1]/2), 0:1, :, :].to(self.device) 113 | lr_cbcr = lr[:, int(hr.shape[1]/2), 1:, :, :].to(self.device) 114 | hr_cbcr = hr[:, int(hr.shape[1]/2), 1:, :, 
:].to(self.device) 115 | # extract Y channels (lr should be the frame group, hr should be the center frame) 116 | lr = lr[:, :, 0:1, :, :] 117 | hr = hr[:, int(hr.shape[1]/2), 0:1, :, :] 118 | 119 | # Divide LR frame sequence [N, n_sequence, n_colors, H, W] -> n_sequence * [N, 1, n_colors, H, W] 120 | lr = list(torch.split(lr, 1, dim=1)) # split size 1 -> one chunk per frame 121 | 122 | #lr = lr.to(self.device) 123 | lr = [x.to(self.device) for x in lr] 124 | hr = hr.to(self.device) 125 | 126 | # output frame = single HR frame [N, n_colors, H, W] 127 | if self.model.get_model().name == 'ESPCN_mf': 128 | sr = self.model(lr) 129 | elif self.model.get_model().name == 'VESPCN': 130 | sr, _, _ = self.model(lr) 131 | 132 | PSNR = utils.calc_psnr(self.args, sr, hr) 133 | self.ckp.report_log(PSNR, train=False) 134 | hr, sr = utils.postprocess(hr, sr, rgb_range=self.args.rgb_range, 135 | ycbcr_flag=ycbcr_flag, device=self.device) 136 | 137 | if self.args.save_images and idx_img % 30 == 0: 138 | if ycbcr_flag: 139 | [lr_center_y] = utils.postprocess(lr_center_y, rgb_range=self.args.rgb_range, 140 | ycbcr_flag=ycbcr_flag, device=self.device) 141 | lr = torch.cat((lr_center_y, lr_cbcr), dim=1) 142 | hr = torch.cat((hr, hr_cbcr), dim=1) 143 | sr = torch.cat((sr, hr_cbcr), dim=1) # NOTE: reuses ground-truth chroma for the saved SR image (visualization only) 144 | 145 | save_list = [lr, hr, sr] 146 | 147 | self.ckp.save_images(filename, save_list, self.args.scale) 148 | 149 | self.ckp.end_log(len(self.loader_test), train=False) 150 | best = self.ckp.psnr_log.max(0) 151 | self.ckp.write_log('[{}]\taverage PSNR: {:.3f} (Best: {:.3f} @epoch {})'.format( 152 | self.args.data_test, self.ckp.psnr_log[-1], 153 | best[0], best[1] + 1)) 154 | if not self.args.test_only: 155 | self.ckp.save(self, epoch, is_best=(best[1] + 1 == epoch)) 156 | 157 | def terminate(self): 158 | if self.args.test_only: 159 | self.test() 160 | return True 161 | else: 162 | epoch = self.scheduler.last_epoch + 1 163 | return epoch >= self.args.epochs 164 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import math 6 | 7 | 8 | def postprocess(*images, rgb_range, ycbcr_flag, device): 9 | def _postprocess(img, rgb_coefficient, ycbcr_flag, device): 10 | if ycbcr_flag: 11 | mean_YCbCr = torch.Tensor([109]).to(device) 12 | out = (img.mul(rgb_coefficient) + mean_YCbCr).clamp(16, 235).div(rgb_coefficient) 13 | elif img.shape[1] == 3: # channel dimension of the NCHW input 14 | mean_RGB = torch.Tensor([123.68, 116.779, 103.939]).to(device) 15 | mean_RGB = mean_RGB.reshape([1, 3, 1, 1]) 16 | out = (img.mul(rgb_coefficient) + mean_RGB).clamp(0, 255).round().div(rgb_coefficient) 17 | else: 18 | mean_YCbCr = torch.Tensor([109]).to(device) 19 | out = (img.mul(rgb_coefficient) + mean_YCbCr).clamp(0, 255).round() 20 | out.div_(rgb_coefficient) 21 | 22 | return out 23 | 24 | rgb_coefficient = 255 / rgb_range 25 | return [_postprocess(img, rgb_coefficient, ycbcr_flag, device) for img in images] 26 | 27 | ''' 28 | def calc_PSNR(img1, img2): 29 | # assume RGB image 30 | target_data = np.array(img1, dtype=np.float64) 31 | ref_data = np.array(img2, dtype=np.float64) 32 | diff = ref_data - target_data 33 | diff = diff.flatten('C') 34 | rmse = math.sqrt(np.mean(diff ** 2.)) 35 | if rmse == 0: 36 | return 100 37 | else: 38 | return 20*math.log10(255.0/rmse) 39 | ''' 40 | 41 | def calc_psnr(args, x, y): 42 | if isinstance(x, torch.Tensor): 43 | diff = (x - y).data 44 | shave = 2 + 
args.scale 45 | valid = diff[:, :, shave:-shave, shave:-shave] 46 | if args.n_colors == 3: 47 | convert = valid.new(1, 3, 1, 1) 48 | convert[0, 0, 0, 0] = 65.738 49 | convert[0, 1, 0, 0] = 129.057 50 | convert[0, 2, 0, 0] = 25.064 51 | valid.mul_(convert).div_(256) 52 | valid = valid.sum(dim=1) 53 | mse = valid.div(args.rgb_range).pow(2).mean() 54 | if mse == 0: 55 | mse = 1e-10 56 | # print('mse :', mse) 57 | return -10 * math.log10(mse) 58 | 59 | elif isinstance(x, np.ndarray): 60 | diff = (x - y) 61 | if diff.ndim == 4: 62 | diff = np.transpose(np.squeeze(diff, axis=0), (1, 2, 0)) 63 | shave = 2 + args.scale 64 | valid = diff[shave:-shave, shave:-shave, :] 65 | if args.n_colors == 3: 66 | valid[:, :, 0] *= 65.738 67 | valid[:, :, 1] *= 129.057 68 | valid[:, :, 2] *= 25.064 69 | valid = valid.sum(axis=2) / 256 70 | mse = (valid ** 2).mean() 71 | if mse == 0: 72 | mse = 1e-10 73 | # print('mse :', mse) 74 | return -10 * math.log10(mse) 75 | --------------------------------------------------------------------------------
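calc_psnr above measures Y-channel PSNR: for three-channel input it projects the difference onto luma with the BT.601 weights (65.738, 129.057, 25.064)/256, shaves a border of 2 + scale pixels, and returns -10*log10(MSE) on the rgb_range-normalized residual. Below is a minimal usage sketch, assuming it is run from the repository root so that utils is importable; the SimpleNamespace merely stands in for the parsed options from option.py, and the values are illustrative.

# Minimal usage sketch for utils.calc_psnr -- illustrative values only.
import torch
from types import SimpleNamespace

import utils

args = SimpleNamespace(scale=3, n_colors=1, rgb_range=255)

hr = torch.rand(1, 1, 100, 100) * 255   # fake HR Y channel in [0, 255]
sr = hr + torch.randn_like(hr)          # fake SR output close to HR
print(utils.calc_psnr(args, sr, hr))    # high PSNR, since sr is close to hr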