├── user-imgs
│   └── teaser.jpg
├── requirements.txt
├── configs
│   ├── Bistro_train.txt
│   └── Bistro_test.txt
├── src
│   ├── script
│   │   ├── BistroX4_train.bat
│   │   └── BistroX4_test.bat
│   ├── loss
│   │   ├── ssim_loss.py
│   │   ├── temporal_loss.py
│   │   ├── vgg.py
│   │   └── __init__.py
│   ├── main.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── base_dataset.py
│   │   ├── data_loader.py
│   │   ├── data_utils.py
│   │   ├── view_dataset.py
│   │   └── irradiance_dataset.py
│   ├── model
│   │   ├── ConvLSTM.py
│   │   ├── ssim.py
│   │   ├── sr_model.py
│   │   ├── __init__.py
│   │   └── sr_common.py
│   ├── remodulation.py
│   ├── option.py
│   ├── trainer.py
│   └── utility.py
├── LICENSE
├── README.md
└── .gitignore
--------------------------------------------------------------------------------
/user-imgs/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Riga2/NSRD/HEAD/user-imgs/teaser.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.24.3
2 | opencv-python==4.7.0.72
3 | Pillow==9.5.0
4 | tqdm==4.65.0
5 | matplotlib
6 | imageio
7 | scikit-image
8 | configparser
--------------------------------------------------------------------------------
/configs/Bistro_train.txt:
--------------------------------------------------------------------------------
1 | root_dir = ../dataset/Bistro/train
2 | grain = 100
3 | total_folder_num = 60
4 | test_folder = []
5 | valid_folder_num = 6
6 | test_only = False
7 | save = Bistro_X4
8 | resume = 1
9 | save_results = False
--------------------------------------------------------------------------------
/src/script/BistroX4_train.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | echo Running training on BistroX4...
3 | 
4 | set train_config=../configs/Bistro_train.txt
5 | 
6 | python main.py --config %train_config%
7 | 
8 | echo Training done.
9 | pause
10 | 
--------------------------------------------------------------------------------
/src/loss/ssim_loss.py:
--------------------------------------------------------------------------------
1 | from model import ssim
2 | import torch.nn as nn
3 | 
4 | 
5 | class SSIM(nn.Module):
6 |     def __init__(self):
7 |         super(SSIM, self).__init__()
8 | 
9 |     def forward(self, sr, hr):
10 |         return 1 - ssim.ssim(sr, hr)
--------------------------------------------------------------------------------
/configs/Bistro_test.txt:
--------------------------------------------------------------------------------
1 | root_dir = ../dataset/Bistro/test
2 | grain = 300
3 | total_folder_num = 4
4 | test_folder = [1]    # Currently only supports testing one folder at a time
5 | valid_folder_num = 0
6 | test_only = True
7 | save = Bistro_X4
8 | resume = -1
9 | save_results = True
--------------------------------------------------------------------------------
/src/loss/temporal_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 | 
5 | class Temporal_Loss(nn.Module):
6 |     def __init__(self):
7 |         super(Temporal_Loss, self).__init__()
8 | 
9 |     def forward(self, sr_pre, sr_cur):
10 |         return F.smooth_l1_loss(sr_pre, sr_cur)
--------------------------------------------------------------------------------
/src/script/BistroX4_test.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | echo Running test on BistroX4...
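REM Note: configs/Bistro_test.txt sets test_folder = [1]; since testing supports
REM only one folder at a time, the loop below overrides it with --test_folder
REM for each test folder in turn.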
3 | 4 | set folders=0 1 2 3 5 | set test_config=../configs/Bistro_test.txt 6 | 7 | for %%i in (%folders%) do ( 8 | echo Testing folder %%i ... 9 | python main.py --config %test_config% --test_folder %%i 10 | ) 11 | 12 | echo All tests executed. 13 | pause 14 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import utility 4 | import data 5 | import model 6 | import loss 7 | from option import args 8 | from trainer import Trainer 9 | 10 | torch.manual_seed(args.seed) 11 | 12 | 13 | def main(): 14 | global model 15 | checkpoint = utility.checkpoint(args) 16 | if checkpoint.ok: 17 | loader = data.Data(args) 18 | _model = model.Model(args, checkpoint) 19 | _loss = loss.Loss(args, checkpoint) if not args.test_only else None 20 | t = Trainer(args, loader, _model, _loss, checkpoint) 21 | while not t.terminate(): 22 | t.train() 23 | t.test() 24 | 25 | checkpoint.done() 26 | 27 | if __name__ == '__main__': 28 | main() 29 | -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | from data import data_loader 3 | import os 4 | 5 | class Data: 6 | def __init__(self, args): 7 | self.sr_content = args.sr_content 8 | self.loader_train = None 9 | module_name = self.sr_content 10 | m = import_module('data.' + module_name.lower() + '_dataset') 11 | if not args.test_only: 12 | train_datasets = m.make_model(args, train=True) 13 | valid_datasets = m.make_model(args, train=False) 14 | self.loader_train = data_loader.RenderingDataLoader(args, train_datasets) 15 | self.loader_valid = self.loader_train.split_validation(valid_datasets) 16 | else: 17 | test_datasets = m.make_model(args, train=False) 18 | self.loader_valid = data_loader.RenderingDataLoader(args, test_datasets) 19 | 20 | -------------------------------------------------------------------------------- /src/loss/vgg.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision.models as models 6 | 7 | class VGG(nn.Module): 8 | def __init__(self, conv_index): 9 | super(VGG, self).__init__() 10 | vgg_features = models.vgg19(pretrained=True).features 11 | modules = [m for m in vgg_features] 12 | if conv_index.find('22') >= 0: 13 | self.vgg = nn.Sequential(*modules[:8]) 14 | elif conv_index.find('54') >= 0: 15 | self.vgg = nn.Sequential(*modules[:35]) 16 | 17 | for p in self.parameters(): 18 | p.requires_grad = False 19 | 20 | def forward(self, sr, hr): 21 | def _forward(x): 22 | x = self.vgg(x) 23 | return x 24 | 25 | vgg_sr = _forward(sr) 26 | with torch.no_grad(): 27 | vgg_hr = _forward(hr.detach()) 28 | 29 | loss = F.mse_loss(vgg_sr, vgg_hr) 30 | 31 | return loss 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Riga2 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | 
copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/model/ConvLSTM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def depth_wise_conv(in_feats, out_feats, kernel, bias=True): 6 | return nn.Sequential( 7 | nn.Conv2d(in_feats, in_feats, kernel_size=kernel, padding=(kernel // 2), groups=in_feats, bias=bias), 8 | nn.Conv2d(in_feats, out_feats, kernel_size=1) 9 | ) 10 | 11 | 12 | class ConvLSTM(nn.Module): 13 | 14 | def __init__(self, input_size, hidden_size, kernel_size): 15 | super(ConvLSTM, self).__init__() 16 | 17 | self.input_size = input_size 18 | self.hidden_size = hidden_size 19 | pad = kernel_size // 2 20 | 21 | self.Gates = nn.Conv2d(input_size + hidden_size, 4 * hidden_size, kernel_size, padding=pad) 22 | # depth_wise_conv(input_size + hidden_size, 4 * hidden_size, kernel_size) 23 | 24 | def forward(self, input_, prev_state=None): 25 | # get batch and spatial sizes 26 | batch_size = input_.data.size()[0] 27 | spatial_size = input_.data.size()[2:] 28 | 29 | # generate empty prev_state, if None is provided 30 | if prev_state is None: 31 | state_size = [batch_size, self.hidden_size] + list(spatial_size) 32 | prev_state = ( 33 | torch.zeros(state_size).to(input_.device), 34 | torch.zeros(state_size).to(input_.device) 35 | ) 36 | 37 | prev_hidden, prev_cell = prev_state 38 | 39 | # data size is [batch, channel, height, width] 40 | stacked_inputs = torch.cat((input_, prev_hidden), 1) 41 | gates = self.Gates(stacked_inputs) 42 | 43 | # chunk across channel dimension 44 | in_gate, remember_gate, out_gate, cell_gate = gates.chunk(4, 1) 45 | 46 | # apply sigmoid non linearity 47 | in_gate = torch.sigmoid(in_gate) 48 | remember_gate = torch.sigmoid(remember_gate) 49 | out_gate = torch.sigmoid(out_gate) 50 | 51 | # apply tanh non linearity 52 | cell_gate = torch.tanh(cell_gate) 53 | 54 | # compute current cell and hidden state 55 | cell = (remember_gate * prev_cell) + (in_gate * cell_gate) 56 | hidden = out_gate * torch.tanh(cell) 57 | 58 | return hidden, cell 59 | -------------------------------------------------------------------------------- /src/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1" 3 | import cv2 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | from data import data_utils 7 | 8 | class BaseDataset(Dataset): 9 | def __init__(self, args=None): 10 | super(BaseDataset, self).__init__() 11 | self.args = args 12 | self.scale = args.scale 13 | self.number_previous_frames = args.num_pre_frames 14 | self.hr_size = args.gt_size 15 | self.lr_size = (args.gt_size[0] // 
self.scale, args.gt_size[1] // self.scale) 16 | 17 | self.gt_dir = os.path.join(args.root_dir, 'GT') 18 | self.lr_dir = os.path.join(args.root_dir, 'X' + str(args.scale)) 19 | 20 | self.useNormal, self.useDepth = args.use_normal, args.use_depth 21 | self.depth_dirname = args.depth_dirname 22 | self.normal_dirname = args.normal_dirname 23 | self.mv_dirname = args.mv_dirname 24 | self.ocmv_dirname = args.ocmv_dirname 25 | self.grain = args.grain 26 | self.total_folder_num = args.total_folder_num 27 | 28 | def load_Normal_Unity(self, folder_index, file_index, ext='.exr'): 29 | filename = str(file_index) + ext 30 | lr_file_path = os.path.join(self.lr_dir, str(folder_index), self.normal_dirname, filename) 31 | lr = data_utils.getFromExr(lr_file_path)[:, :, :3] 32 | return lr 33 | 34 | def load_Depth_Unity(self, folder_index, file_index, ext='.exr'): 35 | filename = str(file_index) + ext 36 | lr_file_path = os.path.join(self.lr_dir, str(folder_index), self.depth_dirname, filename) 37 | lr = data_utils.getFromExr(lr_file_path)[:, :, 0][:, :, None] 38 | return lr 39 | 40 | def load_MV(self, folder_index, file_index, ext='.exr'): 41 | filename = str(file_index) + ext 42 | lr_file_path = os.path.join(self.lr_dir, str(folder_index), self.mv_dirname, filename) 43 | # lr = data_utils.getFromBin(lr_file_path, self.lr_size[0], self.lr_size[1])[:, :, :2] 44 | lr = data_utils.getFromExr(lr_file_path)[:, :, :2] 45 | lr[:, :, 0] = lr[:, :, 0] * self.lr_size[1] 46 | lr[:, :, 1] = lr[:, :, 1] * self.lr_size[0] 47 | return lr 48 | 49 | def load_OCMV(self, folder_index, file_index, ext='.exr'): 50 | filename = str(file_index) + ext 51 | lr_file_path = os.path.join(self.lr_dir, str(folder_index), self.ocmv_dirname, filename) 52 | # lr = data_utils.getFromBin(lr_file_path, self.lr_size[0], self.lr_size[1])[:, :, :2] 53 | lr = data_utils.getFromExr(lr_file_path)[:, :, :2] 54 | lr[:, :, 0] = lr[:, :, 0] * self.lr_size[1] 55 | lr[:, :, 1] = lr[:, :, 1] * self.lr_size[0] 56 | return lr 57 | 58 | -------------------------------------------------------------------------------- /src/model/ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | 8 | def gaussian(window_size, sigma): 9 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 10 | return gauss / gauss.sum() 11 | 12 | 13 | def create_window(window_size, channel): 14 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 15 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 16 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 17 | return window 18 | 19 | 20 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 21 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 22 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 23 | 24 | mu1_sq = mu1.pow(2) 25 | mu2_sq = mu2.pow(2) 26 | mu1_mu2 = mu1 * mu2 27 | 28 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 29 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 30 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 31 | 32 | C1 = 0.01 ** 2 33 | C2 = 0.03 ** 2 34 | 35 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / 
((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
36 | 
37 |     if size_average:
38 |         return ssim_map.mean()
39 |     else:
40 |         return ssim_map.mean(1).mean(1).mean(1)
41 | 
42 | 
43 | class SSIM(torch.nn.Module):
44 |     def __init__(self, window_size=11, size_average=True):
45 |         super(SSIM, self).__init__()
46 |         self.window_size = window_size
47 |         self.size_average = size_average
48 |         self.channel = 1
49 |         self.window = create_window(window_size, self.channel)
50 | 
51 |     def forward(self, img1, img2):
52 |         (_, channel, _, _) = img1.size()
53 | 
54 |         if channel == self.channel and self.window.data.type() == img1.data.type():
55 |             window = self.window
56 |         else:
57 |             window = create_window(self.window_size, channel)
58 | 
59 |             if img1.is_cuda:
60 |                 window = window.cuda(img1.get_device())
61 |             window = window.type_as(img1)
62 | 
63 |             self.window = window
64 |             self.channel = channel
65 | 
66 |         return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
67 | 
68 | 
69 | def ssim(img1, img2, window_size=11, size_average=True):
70 |     (_, channel, _, _) = img1.size()
71 |     window = create_window(window_size, channel)
72 | 
73 |     if img1.is_cuda:
74 |         window = window.cuda(img1.get_device())
75 |     window = window.type_as(img1)
76 | 
77 |     return _ssim(img1, img2, window, window_size, channel, size_average)
78 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Neural Super-Resolution for Real-time Rendering with Radiance Demodulation (CVPR 2024)
2 | 
3 | ### [Paper](https://arxiv.org/abs/2308.06699) | [Project Page](https://riga2.github.io/nsrd/)
4 | 
5 | __Note__: Since the dataset is quite large, we have currently uploaded three scenes to **OneDrive**: [Bistro](https://mailsdueducn-my.sharepoint.com/:u:/g/personal/cuteloong_mail_sdu_edu_cn/EWaiWyZbRxNLgWw7wTV6IBYB2NXLimZ4JcvKhlWpZQ6Jzg?e=hc9scu), [Pica](https://mailsdueducn-my.sharepoint.com/:u:/g/personal/202215216_mail_sdu_edu_cn/ESRf0I6mKH1PqL9wpOW-q7YBYygrKDF9q13piFd1Xyce9g?e=sNhN72) and [San_M](https://mailsdueducn-my.sharepoint.com/:u:/g/personal/202215216_mail_sdu_edu_cn/Eek4usuLcUZGlHCnpzPD63YBukX5FlXELUbTzWa3_zsx1Q?e=nFAwRR).
6 | 
7 | If you are in mainland China, you can also access the dataset on [Baidu Cloud Disk](https://pan.baidu.com/s/1GJZ34keRFvGqnJ1Wgg0RHw?pwd=riga).
8 | 
9 | ![Teaser](https://github.com/Riga2/NSRD/blob/main/user-imgs/teaser.jpg)
10 | 
11 | ### Installation
12 | 
13 | Tested on Windows + CUDA 11.3 + PyTorch 1.12.1
14 | 
15 | Install the environment:
16 | 
17 | ```bazaar
18 | git clone https://github.com/riga2/NSRD.git
19 | cd NSRD
20 | conda create -n NSRD python=3.9
21 | conda activate NSRD
22 | pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
23 | pip install -r requirements.txt
24 | ```
25 | 
26 | The following training and testing steps use the Bistro scene (X4) as an example.
27 | 
28 | ### Training
29 | 1. Make a folder named "dataset", then download the dataset and put it inside.
30 | ```bazaar
31 | |--configs
32 | |--dataset
33 |     |--Bistro
34 |         |--train
35 |             |---GT
36 |                 |--0
37 |                 |--1
38 |                 ...
39 |             |---X4
40 |                 |--0
41 |                 |--1
42 |                 ...
43 |         |---test
44 |             |---GT
45 |                 ...
46 |             |---X4
47 |                 ...
48 | ```
49 | 2. Use the Anaconda Prompt to run the following commands to train. The trained model is stored in "experiment\Bistro_X4\model".
50 | ```bazaar
51 | cd src
52 | .\script\BistroX4_train.bat
53 | ```
54 | 
55 | ### Testing
56 | 1. Run the following commands to perform super-resolution on the LR lighting components. The SR results are stored in "experiment\Bistro_X4\sr_results_x4".
57 | ```bazaar
58 | cd src
59 | .\script\BistroX4_test.bat
60 | ```
61 | 2. Run the following commands to perform remodulation on the SR lighting components. The final results are stored in "experiment\Bistro_X4\final_results_x4".
62 | ```bazaar
63 | cd src
64 | python remodulation.py --exp_dir ../experiment/Bistro_X4 --gt_dir ../dataset/Bistro/test/GT
65 | ```
66 | 
67 | ### Citation
68 | ```
69 | @inproceedings{li2024neural,
70 |   title={Neural Super-Resolution for Real-time Rendering with Radiance Demodulation},
71 |   author={Li, Jia and Chen, Ziling and Wu, Xiaolong and Wang, Lu and Wang, Beibei and Zhang, Lei},
72 |   booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
73 |   pages={4357--4367},
74 |   year={2024}
75 | }
76 | ```
77 | 
--------------------------------------------------------------------------------
/src/data/data_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import numpy as np
4 | from torch.utils.data import DataLoader
5 | from torch.utils.data.sampler import SubsetRandomSampler, SequentialSampler, Sampler
6 | import copy
7 | from typing import Callable
8 | import random
9 | 
10 | class SubsetSequenceSampler(Sampler):
11 |     def __init__(self, indices):
12 |         super(SubsetSequenceSampler, self).__init__(indices)
13 |         self.indices = indices
14 | 
15 |     def __iter__(self):
16 |         return iter(self.indices)
17 | 
18 |     def __len__(self):
19 |         return len(self.indices)
20 | 
21 | 
22 | class RenderingDataLoader(DataLoader):
23 |     def __init__(self, args, dataset):
24 |         self.args = args
25 |         self.dataset = dataset
26 |         self.batch_size = 1 if args.test_only else args.batch_size
27 |         self.number_previous_frames = args.num_pre_frames
28 |         self.num_frames_samples = args.num_frames_samples
29 |         self.test_every = args.test_every
30 |         self.n_samples = args.total_folder_num * args.grain
31 | 
32 |         self.valid_range = []
33 |         if args.test_only is False:
34 |             valid_folders = np.random.choice(np.arange(0, args.total_folder_num), args.valid_folder_num, replace=False)
35 |             for i in valid_folders:
36 |                 self.valid_range += range(i * args.grain, (i+1) * args.grain)
37 | 
38 |         self.test_range = []
39 |         if args.test_folder is not None:
40 |             for i in args.test_folder:
41 |                 self.test_range += range(i * args.grain, (i+1) * args.grain)
42 | 
43 |         self.idx_train = []
44 |         for idx in range(self.n_samples):
45 |             if (idx not in self.valid_range) and (idx not in self.test_range):
46 |                 idx_mod = idx % args.grain
47 |                 if (idx_mod >= self.number_previous_frames) and (idx_mod <= args.grain - self.num_frames_samples):
48 |                     self.idx_train.append(idx)
49 | 
50 |         self.idx_valid = []
51 |         for idx in self.valid_range:
52 |             idx_mod = idx % args.grain
53 |             if (idx_mod >= self.number_previous_frames) and (idx_mod <= args.grain - self.num_frames_samples):
54 |                 self.idx_valid.append(idx)
55 | 
56 |         self.idx_test = []
57 |         for idx in self.test_range:
58 |             idx_mod = idx % args.grain
59 |             if (idx_mod >= self.number_previous_frames) and (idx_mod <= args.grain - self.num_frames_samples):
60 |                 self.idx_test.append(idx)
61 | 
62 |         self.sampler, self.valid_sampler = self._split_sampler()
63 | 
64 |         init_kwargs = {
65 |             'dataset': dataset,
66 |             'batch_size': self.batch_size,
67 |             'num_workers':
args.n_threads 68 | } 69 | super().__init__(sampler=self.sampler, **init_kwargs, drop_last=True) 70 | 71 | def _split_sampler(self): 72 | # For test 73 | if len(self.valid_range) == 0: 74 | train_sampler = SubsetSequenceSampler(self.idx_test) 75 | return train_sampler, None 76 | 77 | repeat = (self.batch_size * self.test_every) // len(self.idx_train) 78 | train_idx = np.repeat(self.idx_train, repeat) 79 | 80 | np.random.seed(0) 81 | np.random.shuffle(train_idx) 82 | 83 | train_sampler = SubsetRandomSampler(train_idx) 84 | valid_sampler = SubsetSequenceSampler(self.idx_valid) 85 | self.n_samples = len(train_idx) 86 | 87 | return train_sampler, valid_sampler 88 | 89 | def split_validation(self, valid_dataset): 90 | if self.valid_sampler is None: 91 | return None 92 | else: 93 | valid_dataloader = DataLoader(sampler=self.valid_sampler, dataset=valid_dataset, batch_size=1, num_workers=self.num_workers, drop_last=True) 94 | return valid_dataloader 95 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # user settings 2 | dataset/ 3 | experiment/ 4 | src/check/* 5 | .idea/ 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | cover/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | .pybuilder/ 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | # For a library or package, you might want to ignore these files since the code is 93 | # intended to run in multiple environments; otherwise, check them in: 94 | # .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/#use-with-ide 116 | .pdm.toml 117 | 118 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 119 | __pypackages__/ 120 | 121 | # Celery stuff 122 | celerybeat-schedule 123 | celerybeat.pid 124 | 125 | # SageMath parsed files 126 | *.sage.py 127 | 128 | # Environments 129 | .env 130 | .venv 131 | env/ 132 | venv/ 133 | ENV/ 134 | env.bak/ 135 | venv.bak/ 136 | 137 | # Spyder project settings 138 | .spyderproject 139 | .spyproject 140 | 141 | # Rope project settings 142 | .ropeproject 143 | 144 | # mkdocs documentation 145 | /site 146 | 147 | # mypy 148 | .mypy_cache/ 149 | .dmypy.json 150 | dmypy.json 151 | 152 | # Pyre type checker 153 | .pyre/ 154 | 155 | # pytype static type analyzer 156 | .pytype/ 157 | 158 | # Cython debug symbols 159 | cython_debug/ 160 | 161 | # PyCharm 162 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 163 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 164 | # and can be added to the global gitignore or merged into this file. For a more nuclear 165 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 166 | #.idea/ 167 | -------------------------------------------------------------------------------- /src/remodulation.py: -------------------------------------------------------------------------------- 1 | import configargparse 2 | import os 3 | os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1" 4 | import cv2 5 | import numpy as np 6 | from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim 7 | from tqdm import tqdm 8 | 9 | def getFromExr(path): 10 | bgr = cv2.imread(path, cv2.IMREAD_UNCHANGED) 11 | rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) 12 | return rgb 13 | 14 | def gamma_correct(data): 15 | data = np.where(data < 0.0031, data * 12.92, np.power(data, 1.0/2.4) * 1.055 - 0.055) 16 | return data 17 | 18 | def sRGB(linear): 19 | return np.where(linear < 0.0031308, 20 | 12.92 * linear, 21 | 1.055 * np.power(linear, 1/2.4) - 0.055) 22 | 23 | def toneMapTev(color): 24 | color[:, :, 0] = sRGB(color[:, :, 0]) 25 | color[:, :, 1] = sRGB(color[:, :, 1]) 26 | color[:, :, 2] = sRGB(color[:, :, 2]) 27 | 28 | return np.clip(color, 0.0, 1.0) 29 | 30 | 31 | def calc_psnr_ssim(res_dir, gt_dir): 32 | total_ssim = 0 33 | total_psnr = 0 34 | count = 0 35 | for name in tqdm(os.listdir(res_dir)): 36 | png_path = os.path.join(res_dir, name) 37 | gt_path = os.path.join(gt_dir, name) 38 | 39 | res_png = getFromExr(png_path) 40 | gt_png = getFromExr(gt_path) 41 | 42 | total_psnr += psnr(gt_png, res_png) 43 | total_ssim += ssim(gt_png, res_png, multichannel=True) 44 | count += 1 45 | 46 | print("avg_psnr : {}".format(total_psnr / count)) 47 | print("avg_ssim : {}".format(total_ssim / count)) 48 | 49 | def remodulation(exp_dir, gt_dir): 50 | avg_psnr, avg_ssim = 0, 0 51 | sr_res_dir = os.path.join(exp_dir, 'sr_results_x4') 52 | img_save_dir = os.path.join(exp_dir, 'final_results_x4') 53 | folder_num = len(os.listdir(sr_res_dir)) 54 | for ind in range(folder_num): 55 | 
os.makedirs(os.path.join(img_save_dir, f'{ind}'), exist_ok=True)
56 | 
57 |         cur_res_dir = os.path.join(sr_res_dir, f'{ind}')
58 |         cur_res_lst = os.listdir(cur_res_dir)
59 |         cur_num = len(cur_res_lst)
60 |         cur_psnr, cur_ssim = 0, 0
61 |         for name in tqdm(cur_res_lst):
62 |             res_path = os.path.join(cur_res_dir, name)
63 |             png_name = name.split('.')[0] + '.png'
64 | 
65 |             irr = getFromExr(res_path)
66 |             irr[irr < 0] = 0
67 |             brdf = getFromExr(os.path.join(gt_dir, f'{ind}', "BRDF", name))
68 |             emiss_sky = getFromExr(os.path.join(gt_dir, f'{ind}', 'Emission_Sky', name))
69 |             emiss_sky_mask = ((abs(emiss_sky[:, :, 0]) >= 1e-4) | (abs(emiss_sky[:, :, 1]) >= 1e-4) | (abs(emiss_sky[:, :, 2]) >= 1e-4))[:, :, np.newaxis]
70 | 
71 |             sr_img = brdf * irr
72 |             sr_img = np.where(emiss_sky_mask, emiss_sky, sr_img)
73 |             sr_img = (toneMapTev(sr_img)*255).astype(np.uint8)
74 | 
75 |             gt_img = getFromExr(os.path.join(gt_dir, f'{ind}', "View_PNG", png_name))
76 |             cur_psnr += psnr(gt_img, sr_img)
77 |             cur_ssim += ssim(gt_img, sr_img, win_size=11, channel_axis=2, data_range=255)
78 | 
79 |             save_path = os.path.join(img_save_dir, f'{ind}', png_name)  # write into final_results_x4, as created above and stated in the README
80 |             cv2.imwrite(save_path, sr_img[:, :, ::-1])
81 | 
82 |         cur_psnr /= cur_num
83 |         cur_ssim /= cur_num
84 |         avg_psnr += cur_psnr
85 |         avg_ssim += cur_ssim
86 | 
87 |     avg_psnr /= folder_num
88 |     avg_ssim /= folder_num
89 | 
90 |     print("Avg_psnr: {}".format(avg_psnr))
91 |     print("Avg_ssim: {}".format(avg_ssim))
92 | 
93 | 
94 | if __name__ == '__main__':
95 |     parser = configargparse.ArgumentParser()
96 |     parser.add_argument('--exp_dir', type=str, default=r"../experiment/Bistro_X4",
97 |                         help='experiment dir')
98 |     parser.add_argument('--gt_dir', type=str, default=r"../dataset/Bistro/test/GT",
99 |                         help='ground truth dir, which contains View_PNG, BRDF and Emission_Sky.')
100 |     args = parser.parse_args()
101 | 
102 |     remodulation(args.exp_dir, args.gt_dir)
103 | 
104 | 
105 | 
--------------------------------------------------------------------------------
/src/model/sr_model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from model import sr_common as sr_com
4 | import torch.nn.functional as F
5 | from model.ConvLSTM import ConvLSTM
6 | 
7 | def make_model(args):
8 |     if args.num_pre_frames != 0:
9 |         return SRNet(args)
10 | 
11 | class SRNet(nn.Module):
12 |     def __init__(self, args=None, conv=sr_com.default_conv):
13 |         super(SRNet, self).__init__()
14 |         self.scale = args.scale
15 |         self.act = nn.ReLU(True)
16 |         self.in_dims = args.input_total_dims
17 |         self.n_feats = args.n_feats
18 |         self.num_previous = args.num_pre_frames
19 |         self.total_feats = self.n_feats * (self.num_previous + 1) + self.n_feats
20 | 
21 |         self.conv1 = conv(self.in_dims, self.n_feats, 3)
22 |         self.gate_conv = sr_com.GatedConv2dWithActivation(self.in_dims + 1, self.n_feats, 3, padding=1)
23 | 
24 |         feats = 64
25 |         self.unps = nn.PixelUnshuffle(self.scale)
26 |         self.conv2 = conv(3 * self.scale * self.scale, self.n_feats, 3)
27 |         self.convLSTM = ConvLSTM(self.total_feats, feats, kernel_size=3)
28 | 
29 |         # U-shaped reconstruction module
30 |         self.encoder_1 = nn.Sequential(
31 |             sr_com.RCAB(conv, feats, 3, 16),
32 |             sr_com.RCAB(conv, feats, 3, 16),
33 |             sr_com.RCAB(conv, feats, 3, 16)
34 |         )
35 | 
36 |         self.encoder_2 = nn.Sequential(
37 |             sr_com.RCAB(conv, feats, 3, 16),
38 |             sr_com.RCAB(conv, feats, 3, 16)
39 |         )
40 | 
41 |         self.center = nn.Sequential(
42 |             sr_com.RCAB(conv, feats, 3, 16),
43 |             sr_com.RCAB(conv, feats, 3, 16),
44 |             sr_com.RCAB(conv, feats,
3, 16) 45 | ) 46 | 47 | self.decoder_2 = nn.Sequential( 48 | conv(feats * 2, feats, 3), 49 | sr_com.RCAB(conv, feats, 3, 16), 50 | sr_com.RCAB(conv, feats, 3, 16) 51 | ) 52 | 53 | self.decoder_1 = nn.Sequential( 54 | conv(feats * 2, feats, 3), 55 | sr_com.RCAB(conv, feats, 3, 16), 56 | sr_com.RCAB(conv, feats, 3, 16), 57 | sr_com.RCAB(conv, feats, 3, 16) 58 | ) 59 | 60 | self.pooling = nn.MaxPool2d(2) 61 | self.upsize = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 62 | self.conv3 = conv(feats, self.n_feats, 3) 63 | 64 | self.upsampling = nn.Sequential( 65 | conv(self.n_feats, args.output_dims * self.scale * self.scale, 3), 66 | nn.PixelShuffle(self.scale) 67 | ) 68 | 69 | def crop_tensor(self, actual: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 70 | # https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py 71 | diffY = int(actual.size()[2] - target.size()[2]) 72 | diffX = int(actual.size()[3] - target.size()[3]) 73 | x = F.pad(target, [diffX // 2, diffX - diffX // 2, 74 | diffY // 2, diffY - diffY // 2]) 75 | return x 76 | 77 | def forward(self, x_tuple): 78 | x, last_sr, prev_state = x_tuple 79 | 80 | # split cur frames 81 | x_cur_frames = x[:, :self.in_dims, :, :] 82 | x_cur_frames = self.conv1(x_cur_frames) 83 | 84 | # the last channel of previous frame is mask 85 | x_pre_frames = [] 86 | ch_pre_frames = self.in_dims + 1 87 | 88 | # split previous frames 89 | for i in range(0, self.num_previous): 90 | st, ed = self.in_dims + i * ch_pre_frames, self.in_dims + (i + 1) * ch_pre_frames 91 | pre_frame = x[:, st:ed, :, :] 92 | pre_frame = self.gate_conv(pre_frame) 93 | x_pre_frames.append(pre_frame) 94 | x_pre_frame = torch.cat(x_pre_frames, dim=1) 95 | 96 | # last sr 97 | last_ups = self.unps(last_sr) 98 | last_in = self.conv2(last_ups) 99 | 100 | # path both 101 | x_all = torch.cat((x_cur_frames, x_pre_frame, last_in), dim=1) 102 | state = self.convLSTM(x_all, prev_state) 103 | 104 | x_encoder1 = self.encoder_1(state[0]) 105 | x_encoder1_pool = self.pooling(x_encoder1) 106 | x_encoder2 = self.encoder_2(x_encoder1_pool) 107 | x_encoder2_pool = self.pooling(x_encoder2) 108 | 109 | x_center = self.center(x_encoder2_pool) 110 | 111 | x_center_up = self.crop_tensor(x_encoder2, self.upsize(x_center)) 112 | x_decoder2 = self.decoder_2(torch.cat((x_center_up, x_encoder2), dim=1)) 113 | 114 | x_decoder2_up = self.crop_tensor(x_encoder1, self.upsize(x_decoder2)) 115 | x_decoder1 = self.decoder_1(torch.cat((x_decoder2_up, x_encoder1), dim=1)) 116 | 117 | x_in = self.conv3(x_decoder1) 118 | x_res = self.upsampling(x_in) 119 | 120 | return x_res, state 121 | 122 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib import import_module 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.parallel as P 7 | import torch.utils.model_zoo 8 | 9 | 10 | class Model(nn.Module): 11 | def __init__(self, args, ckp): 12 | super(Model, self).__init__() 13 | print('Making model...') 14 | 15 | self.scale = args.scale 16 | self.idx_scale = 0 17 | self.chop = args.chop 18 | self.precision = args.precision 19 | self.cpu = args.cpu 20 | self.device = torch.device('cpu' if args.cpu else 'cuda') 21 | self.n_GPUs = args.n_GPUs 22 | self.save_models = args.save_models 23 | 24 | module = import_module('model.' 
+ args.model.lower())
25 |         self.model = module.make_model(args).to(self.device)
26 |         if args.precision == 'half':
27 |             self.model.half()
28 | 
29 |         self.load(
30 |             ckp.get_path('model'),
31 |             resume=args.resume,
32 |             cpu=args.cpu
33 |         )
34 |         print(self.model, file=ckp.log_file)
35 | 
36 |     def forward(self, x):
37 |         if self.training:
38 |             if self.n_GPUs > 1:
39 |                 return P.data_parallel(self.model, x, range(self.n_GPUs))
40 |             else:
41 |                 return self.model(x)
42 |         else:
43 |             if self.chop:
44 |                 forward_function = self.forward_chop
45 |             else:
46 |                 forward_function = self.model.forward
47 | 
48 |             return forward_function(x)
49 | 
50 |     def save(self, apath, epoch, is_best=False):
51 |         save_dirs = [os.path.join(apath, 'model_latest.pt')]
52 | 
53 |         if is_best:
54 |             save_dirs.append(os.path.join(apath, 'model_best.pt'))
55 |         if self.save_models:
56 |             save_dirs.append(
57 |                 os.path.join(apath, 'model_{}.pt'.format(epoch))
58 |             )
59 | 
60 |         for s in save_dirs:
61 |             torch.save(self.model.state_dict(), s)
62 | 
63 |     def load(self, apath, resume=-1, cpu=False):
64 |         load_from = None
65 |         kwargs = {}
66 |         if cpu:
67 |             kwargs = {'map_location': lambda storage, loc: storage}
68 | 
69 |         if resume == -1:
70 |             load_from = torch.load(
71 |                 os.path.join(apath, 'model_best.pt'),
72 |                 **kwargs
73 |             )
74 |         elif resume == 0:
75 |             load_from = torch.load(
76 |                 os.path.join(apath, 'model_latest.pt'),
77 |                 **kwargs
78 |             )
79 | 
80 |         if load_from:
81 |             self.model.load_state_dict(load_from, strict=False)
82 | 
83 |     def forward_chop(self, *args, shave=10, min_size=160000):
84 |         scale = self.scale  # args.scale is an int in this repo; the inherited EDSR line referenced an undefined self.input_large and indexed self.scale
85 |         n_GPUs = min(self.n_GPUs, 4)
86 |         # height, width
87 |         h, w = args[0].size()[-2:]
88 | 
89 |         top = slice(0, h//2 + shave)
90 |         bottom = slice(h - h//2 - shave, h)
91 |         left = slice(0, w//2 + shave)
92 |         right = slice(w - w//2 - shave, w)
93 |         x_chops = [torch.cat([
94 |             a[..., top, left],
95 |             a[..., top, right],
96 |             a[..., bottom, left],
97 |             a[..., bottom, right]
98 |         ]) for a in args]
99 | 
100 |         y_chops = []
101 |         if h * w < 4 * min_size:
102 |             for i in range(0, 4, n_GPUs):
103 |                 x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops]
104 |                 y = P.data_parallel(self.model, *x, range(n_GPUs))
105 |                 if not isinstance(y, list): y = [y]
106 |                 if not y_chops:
107 |                     y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y]
108 |                 else:
109 |                     for y_chop, _y in zip(y_chops, y):
110 |                         y_chop.extend(_y.chunk(n_GPUs, dim=0))
111 |         else:
112 |             for p in zip(*x_chops):
113 |                 y = self.forward_chop(*p, shave=shave, min_size=min_size)
114 |                 if not isinstance(y, list): y = [y]
115 |                 if not y_chops:
116 |                     y_chops = [[_y] for _y in y]
117 |                 else:
118 |                     for y_chop, _y in zip(y_chops, y): y_chop.append(_y)
119 | 
120 |         h *= scale
121 |         w *= scale
122 |         top = slice(0, h//2)
123 |         bottom = slice(h - h//2, h)
124 |         bottom_r = slice(h//2 - h, None)
125 |         left = slice(0, w//2)
126 |         right = slice(w - w//2, w)
127 |         right_r = slice(w//2 - w, None)
128 | 
129 |         # batch size, number of color channels
130 |         b, c = y_chops[0][0].size()[:-2]
131 |         y = [y_chop[0].new(b, c, h, w) for y_chop in y_chops]
132 |         for y_chop, _y in zip(y_chops, y):
133 |             _y[..., top, left] = y_chop[0][..., top, left]
134 |             _y[..., top, right] = y_chop[1][..., top, right_r]
135 |             _y[..., bottom, left] = y_chop[2][..., bottom_r, left]
136 |             _y[..., bottom, right] = y_chop[3][..., bottom_r, right_r]
137 | 
138 |         if len(y) == 1: y = y[0]
139 | 
140 |         return y
141 | 
--------------------------------------------------------------------------------
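Note: `trainer.py` and `utility.py` are imported by `main.py` but their sources are not part of this listing. For orientation, `SRNet.forward` in `src/model/sr_model.py` consumes a tuple `(x, last_sr, prev_state)` and returns `(sr, state)`, so the network is driven recurrently, one frame at a time, with the previous SR output and the ConvLSTM state fed back in. Below is a minimal, hypothetical driver loop sketched against that signature; `run_sequence` and `lr_frames` are illustrative names, not repository code.

```python
# Minimal sketch (not repository code): driving SRNet frame by frame.
# `last_sr` is the previous super-resolved frame; SRNet pixel-unshuffles it
# back to LR resolution internally (self.unps / self.conv2), and `state` is
# the ConvLSTM (hidden, cell) pair threaded across frames.
import torch

@torch.no_grad()
def run_sequence(model, lr_frames, scale=4, out_ch=3):
    b, _, h, w = lr_frames[0].shape
    device = lr_frames[0].device
    # Zero history for the first frame; ConvLSTM likewise zero-initializes
    # its hidden/cell state when prev_state is None.
    last_sr = torch.zeros(b, out_ch, h * scale, w * scale, device=device)
    state = None
    outputs = []
    for x in lr_frames:
        # x stacks the current LR frame (+ normal/depth channels) with the
        # warped previous frames and their occlusion masks, as assembled by
        # view_dataset.py / irradiance_dataset.py.
        sr, state = model((x, last_sr, state))
        last_sr = sr
        outputs.append(sr)
    return outputs
```

At x4, the zero-initialized `last_sr` has 3 channels at HR resolution; `nn.PixelUnshuffle(4)` turns it into 48 channels at LR resolution, matching `self.conv2 = conv(3 * self.scale * self.scale, self.n_feats, 3)`.

--------------------------------------------------------------------------------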
/src/data/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1" 3 | import cv2 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import random 8 | import imageio 9 | import torch.nn.functional as F 10 | import time 11 | 12 | def linearUpsample(img, scale): 13 | return cv2.resize(img, dsize=None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR) 14 | 15 | def save2Exr(img, path): 16 | cv2.imwrite(path, img[:,:,::-1]) 17 | 18 | def getFromBin(path, ih, iw): 19 | return np.fromfile(path, dtype=np.float32).reshape(ih, iw, -1).transpose(0, 1, 2) 20 | 21 | def getFromExr(path): 22 | bgr = cv2.imread(path, cv2.IMREAD_UNCHANGED) 23 | if bgr is None: 24 | print(path) 25 | rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) 26 | return rgb 27 | 28 | def bin2exr(pt_path, ih, iw, save_path): 29 | irr = getFromBin(pt_path, ih, iw)[:, :, ::-1] 30 | cv2.imwrite(save_path, irr) 31 | 32 | def get_patch(lr, hr, patch_size=256, scale=2): 33 | ih, iw = lr.shape[1:] 34 | tp = scale * patch_size 35 | ip = tp // scale 36 | ix = random.randrange(0, iw - ip + 1) 37 | iy = random.randrange(0, ih - ip + 1) 38 | tx, ty = scale * ix, scale * iy 39 | ret = [ 40 | lr[:, iy:iy + ip, ix:ix + ip], 41 | hr[:, ty:ty + tp, tx:tx + tp] 42 | ] 43 | return ret 44 | 45 | def np2Tensor(np_file): 46 | return torch.from_numpy(np_file).permute(2, 0, 1).float() 47 | 48 | def tensorFromNp(np_file): 49 | np_file = np_file.transpose(2, 0, 1) 50 | return torch.FloatTensor(np_file) 51 | 52 | def mv_mask(ocmv, mv, gate = 0.1): 53 | delta = ocmv - mv 54 | mask = torch.where(torch.abs(delta) < gate, False, True) 55 | x, y = mask[0, :, :], mask[1, :, :] 56 | mask = torch.where(((x) | (y)), 1, 0) 57 | return mask 58 | 59 | def backward_warp_motion(pre: torch.Tensor, motion: torch.Tensor, cur: torch.Tensor) -> torch.Tensor: 60 | # see: https://discuss.pytorch.org/t/image-warping-for-backward-flow-using-forward-flow-matrix-optical-flow/99298 61 | # input image is: [batch, channel, height, width] 62 | # st = time.time() 63 | index_batch, number_channels, height, width = pre.size() 64 | grid_x = torch.arange(width).view(1, -1).repeat(height, 1) 65 | grid_y = torch.arange(height).view(-1, 1).repeat(1, width) 66 | grid_x = grid_x.view(1, 1, height, width).repeat(index_batch, 1, 1, 1) 67 | grid_y = grid_y.view(1, 1, height, width).repeat(index_batch, 1, 1, 1) 68 | # 69 | grid = torch.cat((grid_x, grid_y), 1).float().cuda() 70 | # grid is: [batch, channel (2), height, width] 71 | vgrid = grid - motion 72 | # Grid values must be normalised positions in [-1, 1] 73 | vgrid_x = vgrid[:, 0, :, :] 74 | vgrid_y = vgrid[:, 1, :, :] 75 | vgrid[:, 0, :, :] = (vgrid_x / width) * 2.0 - 1.0 76 | vgrid[:, 1, :, :] = (vgrid_y / height) * 2.0 - 1.0 77 | # swapping grid dimensions in order to match the input of grid_sample. 
78 | # that is: [batch, output_height, output_width, grid_pos (2)] 79 | vgrid = vgrid.permute((0, 2, 3, 1)) 80 | warped = F.grid_sample(pre, vgrid, align_corners=True) 81 | 82 | # return warped 83 | oox, ooy = torch.split((vgrid < -1) | (vgrid > 1), 1, dim=3) 84 | oo = (oox | ooy).permute(0, 3, 1, 2) 85 | # ed = time.time() 86 | # print('warp {}'.format(ed-st)) 87 | return torch.where(oo, cur, warped) 88 | 89 | def warp(x, flow, mode='bilinear', padding_mode='border'): 90 | """ Backward warp `x` according to `flow` 91 | 92 | Both x and flow are pytorch tensor in shape `nchw` and `n2hw` 93 | 94 | Reference: 95 | https://github.com/sniklaus/pytorch-spynet/blob/master/run.py#L41 96 | """ 97 | 98 | n, c, h, w = x.size() 99 | 100 | # create mesh grid 101 | iu = torch.linspace(-1.0, 1.0, w).view(1, 1, 1, w).expand(n, -1, h, -1) 102 | iv = torch.linspace(-1.0, 1.0, h).view(1, 1, h, 1).expand(n, -1, -1, w) 103 | grid = torch.cat([iu, iv], 1).to(flow.device) 104 | 105 | # normalize flow to [-1, 1] 106 | flow = torch.cat([ 107 | flow[:, 0:1, ...] / ((w - 1.0) / 2.0), 108 | flow[:, 1:2, ...] / ((h - 1.0) / 2.0)], dim=1) 109 | 110 | # add flow to grid and reshape to nhw2 111 | grid = (grid - flow).permute(0, 2, 3, 1) 112 | 113 | # bilinear sampling 114 | # Note: `align_corners` is set to `True` by default for PyTorch version < 1.4.0 115 | if int(''.join(torch.__version__.split('.')[:2])) >= 14: 116 | output = F.grid_sample( 117 | x, grid, mode=mode, padding_mode=padding_mode, align_corners=True) 118 | else: 119 | output = F.grid_sample(x, grid, mode=mode, padding_mode=padding_mode) 120 | 121 | return output 122 | 123 | def toneMapping(data): 124 | return data / (data + 1.0) 125 | 126 | def ldr2hdr(data): 127 | return data / (1.0 - data) 128 | 129 | def chromatic(data): 130 | return 0.5 * data 131 | 132 | def adjust(data): 133 | contrast = 0.0 134 | brightness = 0.0 135 | contrastFactor = (259.0 * (contrast * 256.0 + 255.0)) / (255.0 * (259.0 - 256.0 * contrast)) 136 | data = (data - 0.5) * contrastFactor + 0.5 + brightness 137 | return data 138 | 139 | 140 | def gamma_correct(data): 141 | r = 1.0 / 2.2 142 | return np.power(data, r) 143 | 144 | -------------------------------------------------------------------------------- /src/data/view_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import torch 5 | from data import data_utils 6 | import random 7 | from data.base_dataset import BaseDataset 8 | import imageio 9 | import time 10 | import pickle 11 | warped_save_dir = r'../experiment/warped_img' 12 | 13 | def make_model(args, train=True): 14 | return ViewDataset(args, train) 15 | 16 | class ViewDataset(BaseDataset): 17 | def __init__(self, args, train=True): 18 | super(ViewDataset, self).__init__(args) 19 | self.args = args 20 | self.train = train 21 | self.name = args.data_name 22 | self.crop_size = args.crop_size if train else None 23 | self.num_frames_samples = args.num_frames_samples if train else 1 24 | self.upsample = torch.nn.Upsample(scale_factor=self.scale, mode='bilinear', align_corners=True) 25 | self.view_dirname = args.view_dirname 26 | 27 | def load_View_HR(self, folder_index, file_index, ext='.png'): 28 | filename = str(file_index) + ext 29 | hr_file_path = os.path.join(self.gt_dir, str(folder_index), self.view_dirname, filename) 30 | hr = data_utils.getFromExr(hr_file_path).astype(np.float32) / 255.0 31 | return hr 32 | 33 | def load_View_LR(self, folder_index, file_index, 
ext='.png'): 34 | filename = str(file_index) + ext 35 | lr_file_path = os.path.join(self.lr_dir, str(folder_index), self.view_dirname, filename) 36 | lr = data_utils.getFromExr(lr_file_path).astype(np.float32) / 255.0 37 | return lr 38 | 39 | def __getitem__(self, item): 40 | folder_index = item // self.grain 41 | file_index = item % self.grain 42 | 43 | HR_lst = [] 44 | LR_lst = [] 45 | MV_up_lst = [] 46 | OCMV_up_lst = [] 47 | Mask_up_lst = [] 48 | 49 | lr_sh, lr_sw, hr_sh, hr_sw = 0, 0, 0, 0 50 | lr_eh, lr_ew = self.lr_size 51 | hr_eh, hr_ew = self.hr_size 52 | if self.crop_size is not None: 53 | ih, iw = self.lr_size 54 | tp = self.scale * self.crop_size 55 | ip = tp // self.scale 56 | ix = random.randrange(0, iw - ip + 1) 57 | iy = random.randrange(0, ih - ip + 1) 58 | tx, ty = self.scale * ix, self.scale * iy 59 | lr_sh, lr_sw, hr_sh, hr_sw = iy, ix, ty, tx 60 | lr_eh, lr_ew, hr_eh, hr_ew = iy+ip, ix+ip, ty+tp, tx+tp 61 | 62 | for index in range(file_index, file_index + self.num_frames_samples): 63 | # HR 64 | hr = self.load_View_HR(folder_index, index) 65 | hr = data_utils.np2Tensor(hr) 66 | HR_lst.append(hr[:, hr_sh:hr_eh, hr_sw:hr_ew]) 67 | 68 | # LR 69 | lr_files, mvs, ocmvs = [], [], [] 70 | for idx in range(index, index - self.number_previous_frames - 1, -1): 71 | file = self.load_View_LR(folder_index, idx) 72 | if self.useNormal: 73 | normal = self.load_Normal_Unity(folder_index, idx) 74 | file = np.concatenate((file, normal), axis=2) 75 | if self.useDepth: 76 | depth = self.load_Depth_Unity(folder_index, idx) 77 | file = np.concatenate((file, depth), axis=2) 78 | file = data_utils.np2Tensor(file) 79 | lr_files.append(file[:, lr_sh:lr_eh, lr_sw:lr_ew]) 80 | 81 | # mv and ocmv 82 | if (idx != index - self.number_previous_frames): 83 | mv = data_utils.np2Tensor(self.load_MV(folder_index, idx)) 84 | ocmv = data_utils.np2Tensor(self.load_OCMV(folder_index, idx)) 85 | mvs.append(mv[:, lr_sh:lr_eh, lr_sw:lr_ew]) 86 | ocmvs.append(ocmv[:, lr_sh:lr_eh, lr_sw:lr_ew]) 87 | if (idx == index): 88 | mv_up = self.upsample(mv[None, :])[0] 89 | ocmv_up = self.upsample(ocmv[None, :])[0] 90 | mask_up = 1 - data_utils.mv_mask(ocmv_up, mv_up)[None, :] 91 | MV_up_lst.append(mv_up[:, hr_sh:hr_eh, hr_sw:hr_ew]) 92 | OCMV_up_lst.append(ocmv_up[:, hr_sh:hr_eh, hr_sw:hr_ew]) 93 | Mask_up_lst.append(mask_up[:, hr_sh:hr_eh, hr_sw:hr_ew]) 94 | 95 | # pre frames do Warp 96 | for i in range(self.number_previous_frames, 0, -1): 97 | for j in range(0, i-1): 98 | mvs[i-1] += mvs[j] 99 | ocmvs[i-1] += ocmvs[j] 100 | lr_files[i] = data_utils.backward_warp_motion(lr_files[i][None, :].cuda(), mvs[i-1][None, :].cuda(), 101 | lr_files[0][None, :].cuda())[0].cpu() 102 | 103 | # pre frames cat mask 104 | for i in range(0, self.number_previous_frames): 105 | mask = data_utils.mv_mask(ocmvs[i], mvs[i])[None, :] 106 | lr_files[i+1] = torch.cat((lr_files[i+1], mask), dim=0) 107 | 108 | lr = torch.cat(lr_files, dim=0) 109 | LR_lst.append(lr) 110 | 111 | return LR_lst, HR_lst, MV_up_lst, Mask_up_lst, str(file_index) 112 | 113 | def __len__(self): 114 | return self.total_folder_num * self.grain -------------------------------------------------------------------------------- /src/model/sr_common.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | def depth_wise_conv(in_feats, out_feats, kernel, bias=True): 6 | return nn.Sequential( 7 | nn.Conv2d(in_feats, in_feats, kernel_size=kernel, padding=(kernel // 2), groups=in_feats, bias=bias), 8 | 
nn.Conv2d(in_feats, out_feats, kernel_size=1)
9 |     )
10 | 
11 | 
12 | def default_conv(in_feats, out_feats, kernel_size, bias=True):
13 |     return nn.Conv2d(
14 |         in_feats, out_feats, kernel_size,
15 |         padding=(kernel_size // 2), bias=bias)
16 | 
17 | 
18 | # Channel Attention (CA) Layer
19 | class CALayer(nn.Module):
20 |     def __init__(self, channel, reduction=16):
21 |         super(CALayer, self).__init__()
22 |         # global average pooling: feature --> point
23 |         self.avg_pool = nn.AdaptiveAvgPool2d(1)
24 |         # feature channel downscale and upscale --> channel weight
25 |         self.conv_du = nn.Sequential(
26 |             nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
27 |             nn.ReLU(inplace=True),
28 |             nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
29 |             nn.Sigmoid()
30 |         )
31 | 
32 |     def forward(self, x):
33 |         y = self.avg_pool(x)
34 |         y = self.conv_du(y)
35 |         return x * y
36 | 
37 | 
38 | # Residual Channel Attention Block (RCAB)
39 | class RCAB(nn.Module):
40 |     def __init__(
41 |             self, conv, n_feat, kernel_size, reduction,
42 |             bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
43 |         super(RCAB, self).__init__()
44 |         modules_body = []
45 |         for i in range(2):
46 |             modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
47 |             if bn: modules_body.append(nn.BatchNorm2d(n_feat))
48 |             if i == 0: modules_body.append(act)
49 |         modules_body.append(CALayer(n_feat, reduction))
50 |         self.body = nn.Sequential(*modules_body)
51 |         self.res_scale = res_scale
52 | 
53 |     def forward(self, x):
54 |         res = self.body(x)
55 |         # res = self.body(x).mul(self.res_scale)
56 |         res += x
57 |         return res
58 | 
59 | 
60 | class CA(nn.Module):
61 |     def __init__(self, channel, reduction=16):
62 |         super(CA, self).__init__()
63 |         self.avg_pool = nn.AdaptiveAvgPool2d(1)
64 |         self.max_pool = nn.AdaptiveMaxPool2d(1)  # was AdaptiveAvgPool2d, a copy-paste slip: this branch is the max-pooled path
65 | 
66 |         self.fc = nn.Sequential(
67 |             nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
68 |             nn.ReLU(inplace=True),
69 |             nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
70 |         )
71 | 
72 |         self.sigmoid = nn.Sigmoid()
73 | 
74 |     def forward(self, x):
75 |         avg_out = self.fc(self.avg_pool(x))
76 |         max_out = self.fc(self.max_pool(x))
77 |         out = avg_out + max_out
78 |         return x * self.sigmoid(out)
79 | 
80 | 
81 | class SA(nn.Module):
82 |     def __init__(self, kernel_size=7, bn=True):
83 |         super(SA, self).__init__()
84 |         self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
85 |         self.bn = nn.BatchNorm2d(1, eps=1e-5, momentum=0.01, affine=True) if bn else None
86 |         self.sigmoid = nn.Sigmoid()
87 | 
88 |     def forward(self, x):
89 |         avg_out = torch.mean(x, dim=1, keepdim=True)
90 |         max_out, _ = torch.max(x, dim=1, keepdim=True)
91 |         out = torch.cat((avg_out, max_out), dim=1)
92 |         out = self.conv(out)
93 |         if self.bn is not None:
94 |             out = self.bn(out)
95 |         return x * self.sigmoid(out)
96 | 
97 | 
98 | ## CBAM Block
99 | class CBAM(nn.Module):
100 |     def __init__(
101 |             self, conv, n_feat, kernel_size, reduction,
102 |             bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
103 |         super(CBAM, self).__init__()
104 |         modules_body = []
105 |         for i in range(2):
106 |             modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
107 |             if bn: modules_body.append(nn.BatchNorm2d(n_feat))
108 |             if i == 0: modules_body.append(act)
109 |         modules_body.append(CALayer(n_feat, reduction))
110 |         modules_body.append(SA())
111 |         self.body = nn.Sequential(*modules_body)
112 |         self.res_scale = res_scale
113 | 
114 |     def forward(self, x):
115 |         res = self.body(x)
116 |         # res =
self.body(x).mul(self.res_scale) 117 | res += x 118 | return res 119 | 120 | 121 | class GatedConv2dWithActivation(torch.nn.Module): 122 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1, dilation=1, groups=1, bias=True, 123 | batch_norm=True, activation=torch.nn.LeakyReLU(0.2, inplace=True)): 124 | super(GatedConv2dWithActivation, self).__init__() 125 | self.batch_norm = batch_norm 126 | self.activation = activation 127 | self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) 128 | self.mask_conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, 129 | bias) 130 | self.batch_norm2d = torch.nn.BatchNorm2d(out_channels) 131 | self.sigmoid = torch.nn.Sigmoid() 132 | 133 | for m in self.modules(): 134 | if isinstance(m, nn.Conv2d): 135 | nn.init.kaiming_normal_(m.weight) 136 | 137 | def gated(self, mask): 138 | # return torch.clamp(mask, -1, 1) 139 | return self.sigmoid(mask) 140 | 141 | def forward(self, input): 142 | x = self.conv2d(input) 143 | mask = self.mask_conv2d(input) 144 | if self.activation is not None: 145 | x = self.activation(x) * self.gated(mask) 146 | else: 147 | x = x * self.gated(mask) 148 | if self.batch_norm: 149 | return self.batch_norm2d(x) 150 | else: 151 | return x 152 | 153 | 154 | -------------------------------------------------------------------------------- /src/data/irradiance_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import torch 5 | from data import data_utils 6 | import random 7 | from data.base_dataset import BaseDataset 8 | import imageio 9 | import time 10 | import pickle 11 | warped_save_dir = r'../experiment/warped_img' 12 | 13 | def make_model(args, train=True): 14 | return IrradianceDataset(args, train) 15 | 16 | class IrradianceDataset(BaseDataset): 17 | def __init__(self, args, train=True): 18 | super(IrradianceDataset, self).__init__(args) 19 | self.args = args 20 | self.train = train 21 | self.name = args.data_name 22 | self.crop_size = args.crop_size if train else None 23 | self.num_frames_samples = args.num_frames_samples if train else 1 24 | self.upsample = torch.nn.Upsample(scale_factor=self.scale, mode='bilinear', align_corners=True) 25 | self.irradiance_dirname = args.irradiance_dirname 26 | 27 | def load_Irradiance_HR(self, folder_index, file_index, ext='.exr'): 28 | filename = str(file_index) + ext 29 | hr_file_path = os.path.join(self.gt_dir, str(folder_index), self.irradiance_dirname, filename) 30 | hr = data_utils.getFromExr(hr_file_path) 31 | return hr 32 | 33 | def load_Irradiance_LR(self, folder_index, file_index, ext='.exr'): 34 | filename = str(file_index) + ext 35 | lr_file_path = os.path.join(self.lr_dir, str(folder_index), self.irradiance_dirname, filename) 36 | lr = data_utils.getFromExr(lr_file_path) 37 | return lr 38 | 39 | def __getitem__(self, item): 40 | folder_index = item // self.grain 41 | file_index = item % self.grain 42 | 43 | HR_lst = [] 44 | LR_lst = [] 45 | MV_up_lst = [] 46 | OCMV_up_lst = [] 47 | Mask_up_lst = [] 48 | 49 | lr_sh, lr_sw, hr_sh, hr_sw = 0, 0, 0, 0 50 | lr_eh, lr_ew = self.lr_size 51 | hr_eh, hr_ew = self.hr_size 52 | if self.crop_size is not None: 53 | ih, iw = self.lr_size 54 | tp = self.scale * self.crop_size 55 | ip = tp // self.scale 56 | ix = random.randrange(0, iw - ip + 1) 57 | iy = random.randrange(0, ih - ip + 1) 58 | tx, ty = self.scale * ix, 
self.scale * iy 59 | lr_sh, lr_sw, hr_sh, hr_sw = iy, ix, ty, tx 60 | lr_eh, lr_ew, hr_eh, hr_ew = iy+ip, ix+ip, ty+tp, tx+tp 61 | 62 | for index in range(file_index, file_index + self.num_frames_samples): 63 | # HR 64 | hr = self.load_Irradiance_HR(folder_index, index) 65 | # hdr -> ldr 66 | # hr = data_utils.toneMapping(hr) 67 | hr = data_utils.np2Tensor(hr) 68 | HR_lst.append(hr[:, hr_sh:hr_eh, hr_sw:hr_ew]) 69 | 70 | # LR 71 | lr_files, mvs, ocmvs = [], [], [] 72 | for idx in range(index, index - self.number_previous_frames - 1, -1): 73 | file = self.load_Irradiance_LR(folder_index, idx) 74 | # hdr -> ldr 75 | # file = data_utils.toneMapping(file) 76 | if self.useNormal: 77 | normal = self.load_Normal_Unity(folder_index, idx) 78 | file = np.concatenate((file, normal), axis=2) 79 | if self.useDepth: 80 | depth = self.load_Depth_Unity(folder_index, idx) 81 | file = np.concatenate((file, depth), axis=2) 82 | file = data_utils.np2Tensor(file) 83 | lr_files.append(file[:, lr_sh:lr_eh, lr_sw:lr_ew]) 84 | 85 | # mv and ocmv 86 | if (idx != index - self.number_previous_frames): 87 | mv = data_utils.np2Tensor(self.load_MV(folder_index, idx)) 88 | ocmv = data_utils.np2Tensor(self.load_OCMV(folder_index, idx)) 89 | mvs.append(mv[:, lr_sh:lr_eh, lr_sw:lr_ew]) 90 | ocmvs.append(ocmv[:, lr_sh:lr_eh, lr_sw:lr_ew]) 91 | if (idx == index): 92 | mv_up = self.upsample(mv[None, :])[0] 93 | ocmv_up = self.upsample(ocmv[None, :])[0] 94 | mask_up = 1 - data_utils.mv_mask(ocmv_up, mv_up)[None, :] 95 | MV_up_lst.append(mv_up[:, hr_sh:hr_eh, hr_sw:hr_ew]) 96 | OCMV_up_lst.append(ocmv_up[:, hr_sh:hr_eh, hr_sw:hr_ew]) 97 | Mask_up_lst.append(mask_up[:, hr_sh:hr_eh, hr_sw:hr_ew]) 98 | 99 | # pre frames do Warp 100 | for i in range(self.number_previous_frames, 0, -1): 101 | for j in range(0, i-1): 102 | mvs[i-1] += mvs[j] 103 | ocmvs[i-1] += ocmvs[j] 104 | lr_files[i] = data_utils.backward_warp_motion(lr_files[i][None, :].cuda(), mvs[i-1][None, :].cuda(), 105 | lr_files[0][None, :].cuda())[0].cpu() 106 | 107 | # pre frames cat mask 108 | for i in range(0, self.number_previous_frames): 109 | mask = data_utils.mv_mask(ocmvs[i], mvs[i])[None, :] 110 | lr_files[i+1] = torch.cat((lr_files[i+1], mask), dim=0) 111 | 112 | lr = torch.cat(lr_files, dim=0) 113 | LR_lst.append(lr) 114 | 115 | # for i in range(1, self.num_frames_samples-1): 116 | # tmp_mv = MV_up_lst[i] 117 | # tmp_ocmv = OCMV_up_lst[i] 118 | # for j in range(0, i): 119 | # tmp_mv += MV_up_lst[j] 120 | # tmp_ocmv += OCMV_up_lst[i] 121 | # tmp_mask = 1 - data_utils.mv_mask(tmp_ocmv, tmp_mv)[None, :] 122 | # MV_up_lst.append(tmp_mv) 123 | # Mask_up_lst.append(tmp_mask) 124 | 125 | # return LR_lst[0], HR_lst[0], str(file_index) 126 | return LR_lst, HR_lst, MV_up_lst, Mask_up_lst, str(file_index) 127 | 128 | def __len__(self): 129 | return self.total_folder_num * self.grain -------------------------------------------------------------------------------- /src/loss/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib import import_module 3 | 4 | import matplotlib 5 | 6 | matplotlib.use('Agg') 7 | import matplotlib.pyplot as plt 8 | 9 | import numpy as np 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | 15 | class Loss(nn.modules.loss._Loss): 16 | def __init__(self, args, ckp): 17 | super(Loss, self).__init__() 18 | print('Preparing loss function:') 19 | 20 | self.n_GPUs = args.n_GPUs 21 | self.loss = [] 22 | self.loss_module = nn.ModuleList() 23 | for loss in args.loss.split('+'): 
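            # args.loss is a '+'-separated list of weighted terms, each written as
            # "<weight>*<type>", e.g. "1*L1+0.1*Temporal" (illustrative spec, not the
            # repo default); the branches below map each <type> to its loss module.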
24 | weight, loss_type = loss.split('*') 25 | if loss_type == 'MSE': 26 | loss_function = nn.MSELoss() 27 | elif loss_type == 'L1': 28 | loss_function = nn.SmoothL1Loss() 29 | elif loss_type.find('VGG') >= 0: 30 | module = import_module('loss.vgg') 31 | loss_function = getattr(module, 'VGG')( 32 | loss_type[3:] 33 | ) 34 | elif loss_type.find('GAN') >= 0: 35 | module = import_module('loss.adversarial') 36 | loss_function = getattr(module, 'Adversarial')( 37 | args, 38 | loss_type 39 | ) 40 | elif loss_type.find('Temporal') >= 0: 41 | module = import_module('loss.temporal_loss') 42 | loss_function = getattr(module, 'Temporal_Loss')() 43 | elif loss_type.find('SSIM') >= 0: 44 | module = import_module('loss.ssim_loss') 45 | loss_function = getattr(module, 'SSIM')() 46 | 47 | self.loss.append({ 48 | 'type': loss_type, 49 | 'weight': float(weight), 50 | 'function': loss_function}) 51 | if loss_type.find('GAN') >= 0: 52 | self.loss.append({'type': 'DIS', 'weight': 1, 'function': None}) 53 | 54 | if len(self.loss) > 1: 55 | self.loss.append({'type': 'Total', 'weight': 0, 'function': None}) 56 | 57 | for l in self.loss: 58 | if l['function'] is not None: 59 | print('{:.3f} * {}'.format(l['weight'], l['type'])) 60 | self.loss_module.append(l['function']) 61 | 62 | self.log = torch.Tensor() 63 | 64 | device = torch.device('cpu' if args.cpu else 'cuda') 65 | self.loss_module.to(device) 66 | if args.precision == 'half': self.loss_module.half() 67 | if not args.cpu and args.n_GPUs > 1: 68 | self.loss_module = nn.DataParallel( 69 | self.loss_module, range(args.n_GPUs) 70 | ) 71 | 72 | if args.load != '': self.load(ckp.dir, cpu=args.cpu) 73 | 74 | # loss(sr_cur, hr, sr_warped, sr_merge, mask_up) 75 | def forward(self, sr_cur, hr, sr_warped=None, mask_up=None, needTem=True): 76 | losses = [] 77 | for i, l in enumerate(self.loss): 78 | if l['type'] == 'Temporal' and needTem: 79 | sr_warped_masked, merge_masked = sr_warped * mask_up, sr_cur * mask_up 80 | loss = l['function'](sr_warped_masked, merge_masked) 81 | effective_loss = l['weight'] * loss 82 | losses.append(effective_loss) 83 | self.log[-1, i] += effective_loss.item() 84 | elif l['function'] is not None: 85 | loss = l['function'](sr_cur, hr) 86 | effective_loss = l['weight'] * loss 87 | losses.append(effective_loss) 88 | self.log[-1, i] += effective_loss.item() 89 | elif l['type'] == 'DIS': 90 | self.log[-1, i] += self.loss[i - 1]['function'].loss 91 | 92 | loss_sum = sum(losses) 93 | if len(self.loss) > 1: 94 | self.log[-1, -1] += loss_sum.item() 95 | 96 | return loss_sum 97 | 98 | def step(self): 99 | for l in self.get_loss_module(): 100 | if hasattr(l, 'scheduler'): 101 | l.scheduler.step() 102 | 103 | def start_log(self): 104 | self.log = torch.cat((self.log, torch.zeros(1, len(self.loss)))) 105 | 106 | def end_log(self, n_batches): 107 | self.log[-1].div_(n_batches) 108 | 109 | def display_loss(self, batch): 110 | n_samples = batch + 1 111 | log = [] 112 | for l, c in zip(self.loss, self.log[-1]): 113 | log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples)) 114 | 115 | return ''.join(log) 116 | 117 | def plot_loss(self, apath, epoch): 118 | axis = np.linspace(1, epoch, epoch) 119 | for i, l in enumerate(self.loss): 120 | label = '{} Loss'.format(l['type']) 121 | fig = plt.figure() 122 | plt.title(label) 123 | plt.plot(axis, self.log[:, i].numpy(), label=label) 124 | plt.legend() 125 | plt.xlabel('Epochs') 126 | plt.ylabel('Loss') 127 | plt.grid(True) 128 | plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type']))) 129 | 
plt.close(fig)
130 | 
131 |     def get_loss_module(self):
132 |         if self.n_GPUs == 1:
133 |             return self.loss_module
134 |         else:
135 |             return self.loss_module.module
136 | 
137 |     def save(self, apath):
138 |         torch.save(self.state_dict(), os.path.join(apath, 'loss.pt'))
139 |         torch.save(self.log, os.path.join(apath, 'loss_log.pt'))
140 | 
141 |     def load(self, apath, cpu=False):
142 |         if cpu:
143 |             kwargs = {'map_location': lambda storage, loc: storage}
144 |         else:
145 |             kwargs = {}
146 | 
147 |         self.load_state_dict(torch.load(
148 |             os.path.join(apath, 'loss.pt'),
149 |             **kwargs
150 |         ))
151 |         self.log = torch.load(os.path.join(apath, 'loss_log.pt'))
152 |         for l in self.get_loss_module():
153 |             if hasattr(l, 'scheduler'):
154 |                 for _ in range(len(self.log)): l.scheduler.step()
155 | 
--------------------------------------------------------------------------------
/src/option.py:
--------------------------------------------------------------------------------
1 | import configargparse
2 | 
3 | parser = configargparse.ArgumentParser(description='NSRD')
4 | 
5 | # Config file
6 | parser.add_argument('--config', is_config_file=True, default='../configs/Bistro_train.txt',
7 |                     help='config file path')
8 | 
9 | # Hardware specifications
10 | parser.add_argument('--n_threads', type=int, default=6,
11 |                     help='number of threads for data loading')
12 | parser.add_argument('--cpu', action='store_true',
13 |                     help='use cpu only')
14 | parser.add_argument('--n_GPUs', type=int, default=1,
15 |                     help='number of GPUs')
16 | parser.add_argument('--seed', type=int, default=1,
17 |                     help='random seed')
18 | 
19 | # Data specifications
20 | parser.add_argument('--root_dir', type=str, default=r"../dataset/Bistro/train",
21 |                     help='rendering data_dir')
22 | parser.add_argument('--data_name', type=str, default='Bistro',
23 |                     help='test dataset name')
24 | parser.add_argument('--view_dirname', type=str, default='View',
25 |                     help='view dirname')
26 | parser.add_argument('--irradiance_dirname', type=str, default='Irradiance',
27 |                     help='irradiance dirname')
28 | parser.add_argument('--depth_dirname', type=str, default='Depth',
29 |                     help='depth dirname')
30 | parser.add_argument('--normal_dirname', type=str, default='Normal',
31 |                     help='normal dirname')
32 | parser.add_argument('--mv_dirname', type=str, default='MV',
33 |                     help='LR mv dirname')
34 | parser.add_argument('--ocmv_dirname', type=str, default='OCMV',
35 |                     help='LR ocmv dirname')
36 | parser.add_argument('--brdf_dirname', type=str, default='BRDF',
37 |                     help='LR brdf dirname')
38 | 
39 | parser.add_argument('--use_normal', type=bool, default=True,
40 |                     help='use normal')
41 | parser.add_argument('--use_depth', type=bool, default=True,
42 |                     help='use depth')
43 | parser.add_argument('--num_frames_samples', type=int, default=4,
44 |                     help='number of consecutive frames sampled per training sequence')
45 | parser.add_argument('--num_pre_frames', type=int, default=2,
46 |                     help='number of previous frames')
47 | parser.add_argument('--input_total_dims', type=int, default=3+1+3,
48 |                     help='total input channels (irradiance + depth + normal)')
49 | parser.add_argument('--output_dims', type=int, default=3,
50 |                     help='output dims')
51 | parser.add_argument('--sr_content', type=str, default='Irradiance',
52 |                     help='choose which content to super-resolve '
53 |                          '(Irradiance, View, LMDB)')
54 | 
55 | parser.add_argument('--gt_size', type=tuple, default=(1080, 1920),
56 |                     help='GT size (ih, iw)')
57 | parser.add_argument('--scale', type=int, default=4,
58 |                     help='super resolution scale')
59 | parser.add_argument('--crop_size', type=int, default=96,
60 |                     help='LR input patch size (the HR patch is scale * crop_size)')
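# With the defaults scale=4 and crop_size=96 this samples 96x96 LR patches
# aligned with 384x384 HR patches; see the random-crop logic in
# data/irradiance_dataset.py (tp = scale * crop_size, ip = tp // scale).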
61 | parser.add_argument('--chop', action='store_true',
62 |                     help='enable memory-efficient forward')
63 | parser.add_argument('--grain', type=int, default=100,
64 |                     help='file number in a folder')
65 | parser.add_argument('--total_folder_num', type=int, default=60,
66 |                     help='total folder num')
67 | parser.add_argument('--test_folder', type=int, action="append",
68 |                     help='test folder index')
69 | parser.add_argument('--valid_folder_num', type=int, default=6,
70 |                     help='valid folder number')
71 | 
72 | # Model specifications
73 | parser.add_argument('--model', default='sr_model',
74 |                     help='model name')
75 | parser.add_argument('--extend', type=str, default='.',
76 |                     help='pre-trained model directory')
77 | parser.add_argument('--n_feats', type=int, default=32,
78 |                     help='number of feature maps')
79 | parser.add_argument('--precision', type=str, default='single',
80 |                     choices=('single', 'half'),
81 |                     help='FP precision for test (single | half)')
82 | 
83 | # Training specifications
84 | parser.add_argument('--reset', action='store_true',
85 |                     help='reset the training')
86 | parser.add_argument('--test_every', type=int, default=1000,
87 |                     help='do test every N batches')
88 | parser.add_argument('--epochs', type=int, default=200,
89 |                     help='number of epochs to train')
90 | parser.add_argument('--batch_size', type=int, default=8,
91 |                     help='input batch size for training')
92 | parser.add_argument('--test_only', action='store_true', default=False,
93 |                     help='set this option to test the model')
94 | 
95 | # Optimization specifications
96 | parser.add_argument('--lr', type=float, default=0.0005,
97 |                     help='learning rate')
98 | parser.add_argument('--decay', type=str, default='100',
99 |                     help='learning rate decay milestones (epochs), joined by "-"')
100 | parser.add_argument('--gamma', type=float, default=0.5,
101 |                     help='learning rate decay factor for step decay')
102 | parser.add_argument('--optimizer', default='ADAM',
103 |                     choices=('SGD', 'ADAM', 'RMSprop'),
104 |                     help='optimizer to use (SGD | ADAM | RMSprop)')
105 | parser.add_argument('--momentum', type=float, default=0.9,
106 |                     help='SGD momentum')
107 | parser.add_argument('--betas', type=tuple, default=(0.9, 0.999),
108 |                     help='ADAM beta')
109 | parser.add_argument('--epsilon', type=float, default=1e-8,
110 |                     help='ADAM epsilon for numerical stability')
111 | parser.add_argument('--weight_decay', type=float, default=0,
112 |                     help='weight decay')
113 | parser.add_argument('--gclip', type=float, default=0,
114 |                     help='gradient clipping threshold (0 = no clipping)')
115 | 
116 | # Loss specifications
117 | parser.add_argument('--loss', type=str, default='1.0*L1+1.0*SSIM+1.0*Temporal',
118 |                     help='loss function configuration')
119 | parser.add_argument('--skip_threshold', type=float, default=1e8,
120 |                     help='skip batches whose error exceeds this threshold')
121 | 
122 | # Log specifications
123 | parser.add_argument('--save', type=str, default='Bistro_X4',
124 |                     help='file name to save')
125 | parser.add_argument('--load', type=str, default='',
126 |                     help='file name to load')
127 | parser.add_argument('--resume', type=int, default=1,
128 |                     help='resume from specific checkpoint: -1 -> best, 0 -> latest, else -> None')
129 | parser.add_argument('--save_models', action='store_true',
130 |                     help='save all intermediate models')
131 | parser.add_argument('--print_every', type=int, default=100,
132 |                     help='how many batches to wait before logging training status')
133 | parser.add_argument('--save_results', action='store_true', default=False,
134 |                     help='save 
output results') 135 | parser.add_argument('--save_gt', action='store_true', 136 | help='save low-resolution and high-resolution images together') 137 | 138 | args = parser.parse_args() 139 | 140 | for arg in vars(args): 141 | if vars(args)[arg] == 'True': 142 | vars(args)[arg] = True 143 | elif vars(args)[arg] == 'False': 144 | vars(args)[arg] = False 145 | 146 | 147 | -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" 3 | import math 4 | import time 5 | from decimal import Decimal 6 | import numpy as np 7 | import imageio 8 | import utility 9 | 10 | import torch 11 | import torch.nn.utils as utils 12 | from tqdm import tqdm 13 | from data import data_utils 14 | 15 | class Trainer(): 16 | def __init__(self, args, loader, my_model, my_loss, ckp): 17 | super(Trainer, self).__init__() 18 | self.args = args 19 | self.scale = args.scale 20 | self.gt_size = args.gt_size 21 | self.batch_size = args.batch_size 22 | self.ckp = ckp 23 | self.model = my_model 24 | self.num_frames_samples = args.num_frames_samples 25 | self.train_loader = loader.loader_train 26 | self.valid_loader = loader.loader_valid 27 | self.loss = my_loss 28 | self.optimizer = utility.make_optimizer(args, self.model) 29 | 30 | if self.args.load != '': 31 | self.optimizer.load(ckp.dir, epoch=len(ckp.log)) 32 | 33 | self.error_last = 1e8 34 | 35 | def train(self): 36 | self.loss.step() 37 | epoch = self.optimizer.get_last_epoch() + 1 38 | lr = self.optimizer.get_lr() 39 | 40 | self.ckp.write_log( 41 | '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr)) 42 | ) 43 | self.loss.start_log() 44 | self.model.train() 45 | 46 | timer_data, timer_model = utility.timer(), utility.timer() 47 | # TEMP 48 | for batch, (LR_lst, HR_lst, MV_up_lst, Mask_up_lst, _) in enumerate(self.train_loader): 49 | self.optimizer.zero_grad() 50 | 51 | b, c, h, w = HR_lst[0].size() 52 | zero_tensor = torch.zeros(b, c, h, w, dtype=torch.float32) 53 | lr0, zero_tensor, hr0 = self.prepare(LR_lst[0], zero_tensor, HR_lst[0]) 54 | 55 | sr_pre, lstm_state = self.model((lr0, zero_tensor, None)) 56 | lstm_state = utility.repackage_hidden(lstm_state) 57 | loss = self.loss(sr_pre, hr0, needTem=False) 58 | 59 | for i in range(1, self.num_frames_samples): 60 | sr_pre = sr_pre.detach() 61 | sr_pre.requires_grad = False 62 | 63 | lr, hr, mv_up, mask_up = self.prepare(LR_lst[i], HR_lst[i], MV_up_lst[i], Mask_up_lst[i]) 64 | 65 | timer_data.hold() 66 | timer_model.tic() 67 | 68 | sr_pre_warped = data_utils.warp(sr_pre, mv_up) 69 | sr_cur, lstm_state = self.model((lr, sr_pre_warped, lstm_state)) 70 | lstm_state = utility.repackage_hidden(lstm_state) 71 | 72 | loss += self.loss(sr_cur, hr, sr_pre_warped, mask_up, needTem=True) 73 | sr_pre = sr_cur 74 | loss.backward() 75 | self.optimizer.step() 76 | timer_model.hold() 77 | if (batch + 1) % self.args.print_every == 0: 78 | self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format( 79 | (batch + 1) * self.args.batch_size, 80 | self.train_loader.n_samples, 81 | self.loss.display_loss(batch), 82 | timer_model.release(), 83 | timer_data.release())) 84 | 85 | timer_data.tic() 86 | 87 | self.loss.end_log(len(self.train_loader)) 88 | self.error_last = self.loss.log[-1, -1] 89 | self.optimizer.schedule() 90 | 91 | def test(self): 92 | torch.set_grad_enabled(False) 93 | 94 | epoch = self.optimizer.get_last_epoch() 95 | 
self.ckp.write_log('\nEvaluation:')
96 |         self.ckp.add_log(
97 |             torch.zeros(1, 3)
98 |         )
99 |         self.model.eval()
100 | 
101 |         timer_test = utility.timer()
102 |         run_model_time = 0
103 |         flag = 0
104 |         if self.args.save_results: self.ckp.begin_background()
105 | 
106 |         pre_sr = torch.zeros(1, 3, self.gt_size[0], self.gt_size[1],
107 |                              dtype=torch.float32).cuda()
108 |         lstm_state = None
109 |         os.makedirs(".\\check", exist_ok=True)
110 |         for index, (LR_lst, HR_lst, MV_up_lst, Mask_up_lst, filename) in tqdm(enumerate(self.valid_loader)):
111 |             lr, hr, mv_up, mask_up = self.prepare(LR_lst[0], HR_lst[0], MV_up_lst[0], Mask_up_lst[0])
112 |             if index == 0:
113 |                 pre_sr, lstm_state = self.model((lr, pre_sr, lstm_state))
114 |                 lstm_state = utility.repackage_hidden(lstm_state)
115 |                 continue
116 |             t1 = time.time()
117 |             sr_pre_warped = data_utils.warp(pre_sr, mv_up)
118 |             cur_sr, lstm_state = self.model((lr, sr_pre_warped, lstm_state))
119 |             lstm_state = utility.repackage_hidden(lstm_state)
120 |             t2 = time.time()
121 |             run_model_time += (t2 - t1)
122 |             if self.args.sr_content == "View":
123 |                 sr = utility.quantize_img(cur_sr)
124 |                 sr_last = utility.quantize_img(pre_sr)
125 |                 if flag < 2:
126 |                     data_utils.save2Exr(np.array(sr[0, :3, :, :].permute(1, 2, 0).detach().cpu()) * 255,
127 |                                         ".\\check\\sr_" + str(flag) + ".png")
128 |                     data_utils.save2Exr(np.array(hr[0, :3, :, :].permute(1, 2, 0).detach().cpu()) * 255,
129 |                                         ".\\check\\gt_" + str(flag) + ".png")
130 |                     flag += 1
131 |             else:
132 |                 sr = utility.quantize(cur_sr)
133 |                 sr_last = utility.quantize(pre_sr)
134 |                 if flag < 2:
135 |                     data_utils.save2Exr(np.array(sr[0, :3, :, :].permute(1, 2, 0).detach().cpu()),
136 |                                         ".\\check\\sr_" + str(flag) + ".exr")
137 |                     data_utils.save2Exr(np.array(hr[0, :3, :, :].permute(1, 2, 0).detach().cpu()),
138 |                                         ".\\check\\gt_" + str(flag) + ".exr")
139 |                     flag += 1
140 | 
141 |             pre_sr = cur_sr
142 |             save_list = [sr]
143 |             assert not torch.isnan(sr).any(), "sr is nan!"
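            # Both validation metrics are "lower is better": val_ssim is the
            # SSIM dissimilarity (1 - SSIM) and val_temporal is the masked L1
            # distance between the current SR frame and the previous SR frame
            # warped by the upsampled motion vectors. Their sum (column 2) is
            # what checkpoint selection minimizes via self.ckp.log.min(0) below.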
144 |             val_ssim = 1.0 - utility.calc_ssim(sr, hr).cpu()
145 |             warped_sr = data_utils.warp(sr_last, mv_up)
146 |             val_temporal = utility.calc_tempory(warped_sr, sr, mask_up).cpu()
147 | 
148 |             self.ckp.log[-1, 0] += val_ssim
149 |             self.ckp.log[-1, 1] += val_temporal
150 |             self.ckp.log[-1, 2] += val_temporal + val_ssim
151 | 
152 |             if self.args.save_gt:
153 |                 save_list.extend([lr, hr])
154 | 
155 |             if self.args.save_results:
156 |                 self.ckp.save_results(self.valid_loader, filename[0], save_list, self.scale)
157 | 
158 |         self.ckp.log[-1] /= (len(self.valid_loader) - 1)
159 |         best = self.ckp.log.min(0)
160 |         self.ckp.write_log(
161 |             '[{} x{}]\tSSIM: {:.6f}, Temporal: {:.6f}, Total: {:.6f} (Best: {:.6f} @epoch {})'.format(
162 |                 self.valid_loader.dataset.name,
163 |                 self.scale,
164 |                 self.ckp.log[-1][0],
165 |                 self.ckp.log[-1][1],
166 |                 self.ckp.log[-1][2],
167 |                 best[0][2],
168 |                 best[1][2] + 1
169 |             )
170 |         )
171 | 
172 |         self.ckp.write_log('Run model time {:.5f}s\n'.format(run_model_time))
173 |         self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc()))
174 |         self.ckp.write_log('Saving...')
175 | 
176 |         if self.args.save_results:
177 |             self.ckp.end_background()
178 | 
179 |         if not self.args.test_only:
180 |             self.ckp.save(self, epoch, is_best=(not torch.isnan(best[0][2]) and best[1][2] + 1 == epoch))
181 | 
182 |         self.ckp.write_log(
183 |             'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
184 |         )
185 | 
186 |         torch.set_grad_enabled(True)
187 | 
188 |     def prepare(self, *args):
189 |         device = torch.device('cpu' if self.args.cpu else 'cuda')
190 |         def _prepare(tensor):
191 |             if self.args.precision == 'half': tensor = tensor.half()
192 |             return tensor.to(device)
193 | 
194 |         return [_prepare(a) for a in args]
195 | 
196 |     def terminate(self):
197 |         if self.args.test_only:
198 |             self.test()
199 |             return True
200 |         else:
201 |             epoch = self.optimizer.get_last_epoch() + 1
202 |             return epoch >= self.args.epochs
--------------------------------------------------------------------------------
/src/utility.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import time
4 | import datetime
5 | import threading
6 | from queue import Queue
7 | import matplotlib
8 | 
9 | matplotlib.use('Agg')
10 | import matplotlib.pyplot as plt
11 | from model.ssim import ssim
12 | import numpy as np
13 | 
14 | import torch
15 | import torch.optim as optim
16 | import torch.optim.lr_scheduler as lrs
17 | import torch.nn.functional as F
18 | from data.data_utils import save2Exr
19 | 
20 | 
21 | class timer():
22 |     def __init__(self):
23 |         self.acc = 0
24 |         self.tic()
25 | 
26 |     def tic(self):
27 |         self.t0 = time.time()
28 | 
29 |     def toc(self, restart=False):
30 |         diff = time.time() - self.t0
31 |         if restart: self.t0 = time.time()
32 |         return diff
33 | 
34 |     def hold(self):
35 |         self.acc += self.toc()
36 | 
37 |     def release(self):
38 |         ret = self.acc
39 |         self.acc = 0
40 | 
41 |         return ret
42 | 
43 |     def reset(self):
44 |         self.acc = 0
45 | 
46 | 
47 | class checkpoint():
48 |     def __init__(self, args):
49 |         self.args = args
50 |         self.ok = True
51 |         self.log = torch.Tensor()
52 |         now = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
53 | 
54 |         if not args.load:
55 |             if not args.save:
56 |                 args.save = now
57 |             self.dir = os.path.join('..', 'experiment', args.save)
58 |         else:
59 |             self.dir = os.path.join('..', 'experiment', args.load)
60 |             if os.path.exists(self.dir):
61 |                 self.log = torch.load(self.get_path('ssim_log.pt'))
62 |                 print('Continue from epoch 
{}...'.format(len(self.log))) 63 | else: 64 | args.load = '' 65 | 66 | if args.reset: 67 | os.system('rm -rf ' + self.dir) 68 | args.load = '' 69 | 70 | os.makedirs(self.dir, exist_ok=True) 71 | os.makedirs(self.get_path('model'), exist_ok=True) 72 | if args.test_only: 73 | assert len(args.test_folder) == 1, "Currently only supports testing one folder at a time!" 74 | self.save_dir = 'sr_results_x{}/{}'.format(args.scale, args.test_folder[0]) 75 | os.makedirs(self.get_path(self.save_dir), exist_ok=True) 76 | 77 | open_type = 'a' if os.path.exists(self.get_path('log.txt')) else 'w' 78 | self.log_file = open(self.get_path('log.txt'), open_type) 79 | with open(self.get_path('config.txt'), open_type) as f: 80 | f.write(now + '\n\n') 81 | for arg in vars(args): 82 | f.write('{}: {}\n'.format(arg, getattr(args, arg))) 83 | f.write('\n') 84 | 85 | self.n_processes = 4 86 | 87 | def get_path(self, *subdir): 88 | return os.path.join(self.dir, *subdir) 89 | 90 | def save(self, trainer, epoch, is_best=False): 91 | trainer.model.save(self.get_path('model'), epoch, is_best=is_best) 92 | trainer.loss.save(self.dir) 93 | trainer.loss.plot_loss(self.dir, epoch) 94 | 95 | self.plot_ssim(epoch) 96 | trainer.optimizer.save(self.dir) 97 | torch.save(self.log, self.get_path('ssim_log.pt')) 98 | 99 | def add_log(self, log): 100 | self.log = torch.cat([self.log, log]) 101 | 102 | def write_log(self, log, refresh=False): 103 | print(log) 104 | self.log_file.write(log + '\n') 105 | if refresh: 106 | self.log_file.close() 107 | self.log_file = open(self.get_path('log.txt'), 'a') 108 | 109 | def done(self): 110 | self.log_file.close() 111 | 112 | def plot_psnr(self, epoch): 113 | axis = np.linspace(1, epoch, epoch) 114 | label = 'SR on {}'.format(self.args.data_name) 115 | fig = plt.figure() 116 | plt.title(label) 117 | plt.plot( 118 | axis, 119 | self.log[:].numpy(), 120 | label='Scale {}'.format(self.args.scale) 121 | ) 122 | plt.legend() 123 | plt.xlabel('Epochs') 124 | plt.ylabel('PSNR') 125 | plt.grid(True) 126 | plt.savefig(self.get_path('test_{}.pdf'.format(self.args.data_name))) 127 | plt.close(fig) 128 | 129 | def plot_ssim(self, epoch): 130 | axis = np.linspace(1, epoch, epoch) 131 | label = 'SR on {}'.format(self.args.data_name) 132 | fig = plt.figure() 133 | plt.title(label) 134 | plt.plot( 135 | axis, 136 | (self.log[:]).numpy(), 137 | label='Scale {}'.format(self.args.scale) 138 | ) 139 | plt.legend() 140 | plt.xlabel('Epochs') 141 | plt.ylabel('SSIM') 142 | plt.grid(True) 143 | plt.savefig(self.get_path('test_{}.pdf'.format(self.args.data_name))) 144 | plt.close(fig) 145 | 146 | def begin_background(self): 147 | self.queue = Queue() 148 | 149 | def bg_target(queue): 150 | while True: 151 | if not queue.empty(): 152 | filename, tensor = queue.get() 153 | if filename is None: break 154 | im_np = tensor.numpy() 155 | save2Exr(im_np, filename) 156 | # im_np.tofile(filename) 157 | # imageio.imwrite(filename, tensor.numpy()) 158 | 159 | self.process = [ 160 | threading.Thread(target=bg_target, args=(self.queue,)) \ 161 | for _ in range(self.n_processes) 162 | ] 163 | 164 | for p in self.process: p.start() 165 | 166 | def end_background(self): 167 | for _ in range(self.n_processes): self.queue.put((None, None)) 168 | while not self.queue.empty(): time.sleep(1) 169 | for p in self.process: p.join() 170 | 171 | def save_results(self, dataset, filename, save_list, scale): 172 | if self.args.save_results: 173 | filename = self.get_path( 174 | self.save_dir, 175 | '{}'.format(filename) 176 | ) 177 | 178 | if 
self.args.sr_content == 'View': 179 | for v in save_list: 180 | tensor_cpu = v[0].permute(1, 2, 0).cpu() * 255.0 181 | self.queue.put(('{}.png'.format(filename), tensor_cpu)) 182 | else: 183 | for v in save_list: 184 | tensor_cpu = v[0].permute(1, 2, 0).cpu() 185 | self.queue.put(('{}.exr'.format(filename), tensor_cpu)) 186 | 187 | 188 | def quantize(tensor): 189 | return torch.clamp(tensor, min=0.0) 190 | 191 | 192 | def quantize_img(tensor): 193 | return torch.clamp(tensor, min=0.0, max=1.0) 194 | 195 | 196 | def hdr2ldr(tensor): 197 | def adjust(data): 198 | contrast = 0.0 199 | brightness = 0.0 200 | contrastFactor = (259.0 * (contrast * 256.0 + 255.0)) / (255.0 * (259.0 - 256.0 * contrast)) 201 | data = (data - 0.5) * contrastFactor + 0.5 + brightness 202 | return data 203 | 204 | chromatic = 0.5 * tensor 205 | adj = adjust(chromatic) 206 | tonemapping = adj / (adj + 1.0) 207 | gamma = torch.pow(tonemapping, 1.0 / 2.2) 208 | return gamma 209 | 210 | 211 | def calc_psnr(sr, hr, scale, rgb_range): 212 | if hr.nelement() == 1: return 0 213 | 214 | diff = (sr - hr) / rgb_range 215 | shave = scale + 6 216 | 217 | valid = diff[..., shave:-shave, shave:-shave] 218 | mse = valid.pow(2).mean() 219 | 220 | return -10 * math.log10(mse) 221 | 222 | 223 | def calc_mse(sr, hr): 224 | return F.mse_loss(sr, hr) 225 | 226 | 227 | def calc_ssim(sr, hr): 228 | return ssim(sr, hr) 229 | 230 | 231 | def calc_tempory(warped_sr, merge_sr, noc_mask): 232 | return F.l1_loss(warped_sr * noc_mask, merge_sr * noc_mask) 233 | 234 | 235 | def repackage_hidden(h): 236 | """Wraps hidden states in new Variables, to detach them from their history.""" 237 | if isinstance(h, torch.Tensor): 238 | return h.detach() 239 | else: 240 | return tuple(repackage_hidden(v) for v in h) 241 | 242 | 243 | def make_optimizer(args, target): 244 | ''' 245 | make optimizer and scheduler together 246 | ''' 247 | # optimizer 248 | trainable = filter(lambda x: x.requires_grad, target.parameters()) 249 | kwargs_optimizer = {'lr': args.lr, 'weight_decay': args.weight_decay} 250 | 251 | if args.optimizer == 'SGD': 252 | optimizer_class = optim.SGD 253 | kwargs_optimizer['momentum'] = args.momentum 254 | elif args.optimizer == 'ADAM': 255 | optimizer_class = optim.Adam 256 | kwargs_optimizer['betas'] = args.betas 257 | kwargs_optimizer['eps'] = args.epsilon 258 | elif args.optimizer == 'RMSprop': 259 | optimizer_class = optim.RMSprop 260 | kwargs_optimizer['eps'] = args.epsilon 261 | 262 | # scheduler 263 | milestones = list(map(lambda x: int(x), args.decay.split('-'))) 264 | kwargs_scheduler = {'milestones': milestones, 'gamma': args.gamma} 265 | scheduler_class = lrs.MultiStepLR 266 | 267 | class CustomOptimizer(optimizer_class): 268 | def __init__(self, *args, **kwargs): 269 | super(CustomOptimizer, self).__init__(*args, **kwargs) 270 | 271 | def _register_scheduler(self, scheduler_class, **kwargs): 272 | self.scheduler = scheduler_class(self, **kwargs) 273 | 274 | def save(self, save_dir): 275 | torch.save(self.state_dict(), self.get_dir(save_dir)) 276 | 277 | def load(self, load_dir, epoch=1): 278 | self.load_state_dict(torch.load(self.get_dir(load_dir))) 279 | if epoch > 1: 280 | for _ in range(epoch): self.scheduler.step() 281 | 282 | def get_dir(self, dir_path): 283 | return os.path.join(dir_path, 'optimizer.pt') 284 | 285 | def schedule(self): 286 | self.scheduler.step() 287 | 288 | def get_lr(self): 289 | return self.scheduler.get_last_lr()[0] 290 | 291 | def get_last_epoch(self): 292 | return self.scheduler.last_epoch 293 | 294 | 
optimizer = CustomOptimizer(trainable, **kwargs_optimizer)
295 |     optimizer._register_scheduler(scheduler_class, **kwargs_scheduler)
296 |     return optimizer
297 | 
--------------------------------------------------------------------------------