├── user-imgs
│   └── teaser.jpg
├── requirements.txt
├── configs
│   ├── Bistro_train.txt
│   └── Bistro_test.txt
├── src
│   ├── script
│   │   ├── BistroX4_train.bat
│   │   └── BistroX4_test.bat
│   ├── loss
│   │   ├── ssim_loss.py
│   │   ├── temporal_loss.py
│   │   ├── vgg.py
│   │   └── __init__.py
│   ├── main.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── base_dataset.py
│   │   ├── data_loader.py
│   │   ├── data_utils.py
│   │   ├── view_dataset.py
│   │   └── irradiance_dataset.py
│   ├── model
│   │   ├── ConvLSTM.py
│   │   ├── ssim.py
│   │   ├── sr_model.py
│   │   ├── __init__.py
│   │   └── sr_common.py
│   ├── remodulation.py
│   ├── option.py
│   ├── trainer.py
│   └── utility.py
├── LICENSE
├── README.md
└── .gitignore
/user-imgs/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Riga2/NSRD/HEAD/user-imgs/teaser.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.24.3
2 | opencv-python==4.7.0.72
3 | Pillow==9.5.0
4 | tqdm==4.65.0
5 | matplotlib
6 | imageio
7 | scikit-image
8 | configargparse
--------------------------------------------------------------------------------
/configs/Bistro_train.txt:
--------------------------------------------------------------------------------
1 | root_dir = ../dataset/Bistro/train
2 | grain = 100
3 | total_folder_num = 60
4 | test_folder = []
5 | valid_folder_num = 6
6 | test_only = False
7 | save = Bistro_X4
8 | resume = 1
9 | save_results = False
--------------------------------------------------------------------------------
/src/script/BistroX4_train.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | echo Running training on BistroX4...
3 |
4 | set train_config=../configs/Bistro_train.txt
5 |
6 | python main.py --config %train_config%
7 |
8 | echo Training done.
9 | pause
10 |
--------------------------------------------------------------------------------
/src/loss/ssim_loss.py:
--------------------------------------------------------------------------------
1 | from model import ssim
2 | import torch.nn as nn
3 |
4 |
5 | class SSIM(nn.Module):
6 | def __init__(self):
7 | super(SSIM, self).__init__()
8 |
9 | def forward(self, sr, hr):
10 | return 1 - ssim.ssim(sr, hr)
11 |
--------------------------------------------------------------------------------
/configs/Bistro_test.txt:
--------------------------------------------------------------------------------
1 | root_dir = ../dataset/Bistro/test
2 | grain = 300
3 | total_folder_num = 4
4 | test_folder = [1] # Currently only supports testing one folder at a time
5 | valid_folder_num = 0
6 | test_only = True
7 | save = Bistro_X4
8 | resume = -1
9 | save_results = True
--------------------------------------------------------------------------------
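For reference, here is a minimal sketch of how such `key = value` files can be consumed. The real parser lives in src/option.py, which is not part of this dump, so the construction below is an assumption based on main.py's `--config` flag and the configargparse import in src/remodulation.py:

```python
# Hypothetical sketch only: src/option.py (the real parser) is not shown in this dump.
import configargparse

parser = configargparse.ArgumentParser(ignore_unknown_config_file_keys=True)
parser.add_argument('--config', is_config_file=True, help='path to a key = value config file')
parser.add_argument('--root_dir', type=str)
parser.add_argument('--grain', type=int)
args = parser.parse_args(['--config', '../configs/Bistro_test.txt'])
print(args.root_dir, args.grain)   # ../dataset/Bistro/test 300
```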
/src/loss/temporal_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 |
5 | class Temporal_Loss(nn.Module):
6 | def __init__(self):
7 | super(Temporal_Loss, self).__init__()
8 |
9 | def forward(self, sr_pre, sr_cur):
10 | return F.smooth_l1_loss(sr_pre, sr_cur)
11 |
--------------------------------------------------------------------------------
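A minimal usage sketch of the temporal loss above (toy tensors; in the full pipeline the previous SR frame is presumably warped to the current frame before this comparison):

```python
import torch
from loss.temporal_loss import Temporal_Loss

criterion = Temporal_Loss()
sr_prev = torch.rand(1, 3, 64, 64)   # (warped) previous SR frame, toy data
sr_cur = torch.rand(1, 3, 64, 64)    # current SR frame, toy data
print(criterion(sr_prev, sr_cur).item())   # smooth L1 between consecutive frames
```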
/src/script/BistroX4_test.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | echo Running test on BistroX4...
3 |
4 | set folders=0 1 2 3
5 | set test_config=../configs/Bistro_test.txt
6 |
7 | for %%i in (%folders%) do (
8 | echo Testing folder %%i ...
9 | python main.py --config %test_config% --test_folder %%i
10 | )
11 |
12 | echo All tests executed.
13 | pause
14 |
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import utility
4 | import data
5 | import model
6 | import loss
7 | from option import args
8 | from trainer import Trainer
9 |
10 | torch.manual_seed(args.seed)
11 |
12 |
13 | def main():
14 | global model
15 | checkpoint = utility.checkpoint(args)
16 | if checkpoint.ok:
17 | loader = data.Data(args)
18 | _model = model.Model(args, checkpoint)
19 | _loss = loss.Loss(args, checkpoint) if not args.test_only else None
20 | t = Trainer(args, loader, _model, _loss, checkpoint)
21 | while not t.terminate():
22 | t.train()
23 | t.test()
24 |
25 | checkpoint.done()
26 |
27 | if __name__ == '__main__':
28 | main()
29 |
--------------------------------------------------------------------------------
/src/data/__init__.py:
--------------------------------------------------------------------------------
1 | from importlib import import_module
2 | from data import data_loader
3 | import os
4 |
5 | class Data:
6 | def __init__(self, args):
7 | self.sr_content = args.sr_content
8 | self.loader_train = None
9 | module_name = self.sr_content
10 | m = import_module('data.' + module_name.lower() + '_dataset')
11 | if not args.test_only:
12 | train_datasets = m.make_model(args, train=True)
13 | valid_datasets = m.make_model(args, train=False)
14 | self.loader_train = data_loader.RenderingDataLoader(args, train_datasets)
15 | self.loader_valid = self.loader_train.split_validation(valid_datasets)
16 | else:
17 | test_datasets = m.make_model(args, train=False)
18 | self.loader_valid = data_loader.RenderingDataLoader(args, test_datasets)
19 |
20 |
--------------------------------------------------------------------------------
/src/loss/vgg.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torchvision.models as models
6 |
7 | class VGG(nn.Module):
8 | def __init__(self, conv_index):
9 | super(VGG, self).__init__()
10 | vgg_features = models.vgg19(pretrained=True).features
11 | modules = [m for m in vgg_features]
12 | if conv_index.find('22') >= 0:
13 | self.vgg = nn.Sequential(*modules[:8])
14 | elif conv_index.find('54') >= 0:
15 | self.vgg = nn.Sequential(*modules[:35])
16 |
17 | for p in self.parameters():
18 | p.requires_grad = False
19 |
20 | def forward(self, sr, hr):
21 | def _forward(x):
22 | x = self.vgg(x)
23 | return x
24 |
25 | vgg_sr = _forward(sr)
26 | with torch.no_grad():
27 | vgg_hr = _forward(hr.detach())
28 |
29 | loss = F.mse_loss(vgg_sr, vgg_hr)
30 |
31 | return loss
32 |
--------------------------------------------------------------------------------
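A minimal usage sketch of the perceptual loss above (the sizes are illustrative; constructing `VGG` downloads torchvision's pretrained VGG-19 weights on first use):

```python
import torch
from loss.vgg import VGG

criterion = VGG('22')             # keep VGG-19 features up to the '22' slice
sr = torch.rand(1, 3, 64, 64)     # super-resolved image, toy data
hr = torch.rand(1, 3, 64, 64)     # ground-truth image, toy data
print(criterion(sr, hr).item())   # MSE between the two feature maps
```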
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Riga2
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/src/model/ConvLSTM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | def depth_wise_conv(in_feats, out_feats, kernel, bias=True):
6 | return nn.Sequential(
7 | nn.Conv2d(in_feats, in_feats, kernel_size=kernel, padding=(kernel // 2), groups=in_feats, bias=bias),
8 | nn.Conv2d(in_feats, out_feats, kernel_size=1)
9 | )
10 |
11 |
12 | class ConvLSTM(nn.Module):
13 |
14 | def __init__(self, input_size, hidden_size, kernel_size):
15 | super(ConvLSTM, self).__init__()
16 |
17 | self.input_size = input_size
18 | self.hidden_size = hidden_size
19 | pad = kernel_size // 2
20 |
21 | self.Gates = nn.Conv2d(input_size + hidden_size, 4 * hidden_size, kernel_size, padding=pad)
22 | # depth_wise_conv(input_size + hidden_size, 4 * hidden_size, kernel_size)
23 |
24 | def forward(self, input_, prev_state=None):
25 | # get batch and spatial sizes
26 | batch_size = input_.data.size()[0]
27 | spatial_size = input_.data.size()[2:]
28 |
29 | # generate empty prev_state, if None is provided
30 | if prev_state is None:
31 | state_size = [batch_size, self.hidden_size] + list(spatial_size)
32 | prev_state = (
33 | torch.zeros(state_size).to(input_.device),
34 | torch.zeros(state_size).to(input_.device)
35 | )
36 |
37 | prev_hidden, prev_cell = prev_state
38 |
39 | # data size is [batch, channel, height, width]
40 | stacked_inputs = torch.cat((input_, prev_hidden), 1)
41 | gates = self.Gates(stacked_inputs)
42 |
43 | # chunk across channel dimension
44 | in_gate, remember_gate, out_gate, cell_gate = gates.chunk(4, 1)
45 |
46 | # apply sigmoid non linearity
47 | in_gate = torch.sigmoid(in_gate)
48 | remember_gate = torch.sigmoid(remember_gate)
49 | out_gate = torch.sigmoid(out_gate)
50 |
51 | # apply tanh non linearity
52 | cell_gate = torch.tanh(cell_gate)
53 |
54 | # compute current cell and hidden state
55 | cell = (remember_gate * prev_cell) + (in_gate * cell_gate)
56 | hidden = out_gate * torch.tanh(cell)
57 |
58 | return hidden, cell
59 |
--------------------------------------------------------------------------------
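A minimal usage sketch of the cell above; the shapes are illustrative assumptions:

```python
import torch
from model.ConvLSTM import ConvLSTM

cell = ConvLSTM(input_size=8, hidden_size=16, kernel_size=3)
x = torch.rand(2, 8, 32, 32)   # [batch, channels, height, width]
h, c = cell(x)                 # first frame: zero-initialized hidden/cell state
h, c = cell(x, (h, c))         # later frames: carry the recurrent state forward
print(h.shape)                 # torch.Size([2, 16, 32, 32])
```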
/src/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1"
3 | import cv2
4 | import numpy as np
5 | from torch.utils.data import Dataset
6 | from data import data_utils
7 |
8 | class BaseDataset(Dataset):
9 | def __init__(self, args=None):
10 | super(BaseDataset, self).__init__()
11 | self.args = args
12 | self.scale = args.scale
13 | self.number_previous_frames = args.num_pre_frames
14 | self.hr_size = args.gt_size
15 | self.lr_size = (args.gt_size[0] // self.scale, args.gt_size[1] // self.scale)
16 |
17 | self.gt_dir = os.path.join(args.root_dir, 'GT')
18 | self.lr_dir = os.path.join(args.root_dir, 'X' + str(args.scale))
19 |
20 | self.useNormal, self.useDepth = args.use_normal, args.use_depth
21 | self.depth_dirname = args.depth_dirname
22 | self.normal_dirname = args.normal_dirname
23 | self.mv_dirname = args.mv_dirname
24 | self.ocmv_dirname = args.ocmv_dirname
25 | self.grain = args.grain
26 | self.total_folder_num = args.total_folder_num
27 |
28 | def load_Normal_Unity(self, folder_index, file_index, ext='.exr'):
29 | filename = str(file_index) + ext
30 | lr_file_path = os.path.join(self.lr_dir, str(folder_index), self.normal_dirname, filename)
31 | lr = data_utils.getFromExr(lr_file_path)[:, :, :3]
32 | return lr
33 |
34 | def load_Depth_Unity(self, folder_index, file_index, ext='.exr'):
35 | filename = str(file_index) + ext
36 | lr_file_path = os.path.join(self.lr_dir, str(folder_index), self.depth_dirname, filename)
37 | lr = data_utils.getFromExr(lr_file_path)[:, :, 0][:, :, None]
38 | return lr
39 |
40 | def load_MV(self, folder_index, file_index, ext='.exr'):
41 | filename = str(file_index) + ext
42 | lr_file_path = os.path.join(self.lr_dir, str(folder_index), self.mv_dirname, filename)
43 | # lr = data_utils.getFromBin(lr_file_path, self.lr_size[0], self.lr_size[1])[:, :, :2]
44 | lr = data_utils.getFromExr(lr_file_path)[:, :, :2]
45 | lr[:, :, 0] = lr[:, :, 0] * self.lr_size[1]
46 | lr[:, :, 1] = lr[:, :, 1] * self.lr_size[0]
47 | return lr
48 |
49 | def load_OCMV(self, folder_index, file_index, ext='.exr'):
50 | filename = str(file_index) + ext
51 | lr_file_path = os.path.join(self.lr_dir, str(folder_index), self.ocmv_dirname, filename)
52 | # lr = data_utils.getFromBin(lr_file_path, self.lr_size[0], self.lr_size[1])[:, :, :2]
53 | lr = data_utils.getFromExr(lr_file_path)[:, :, :2]
54 | lr[:, :, 0] = lr[:, :, 0] * self.lr_size[1]
55 | lr[:, :, 1] = lr[:, :, 1] * self.lr_size[0]
56 | return lr
57 |
58 |
--------------------------------------------------------------------------------
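Note that `load_MV` and `load_OCMV` above convert motion vectors from normalized screen space into pixel units. A toy illustration (the LR resolution is an assumption for the example):

```python
import numpy as np

lr_h, lr_w = 270, 480                 # assumed LR resolution (height, width)
mv = np.array([[[0.01, -0.02]]])      # normalized (u, v) motion as stored in the EXR
mv_px = mv * np.array([lr_w, lr_h])   # channel 0 scaled by width, channel 1 by height
print(mv_px)                          # [[[ 4.8 -5.4]]]
```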
/src/model/ssim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import numpy as np
5 | from math import exp
6 |
7 |
8 | def gaussian(window_size, sigma):
9 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
10 | return gauss / gauss.sum()
11 |
12 |
13 | def create_window(window_size, channel):
14 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
15 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
16 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
17 | return window
18 |
19 |
20 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
21 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
22 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
23 |
24 | mu1_sq = mu1.pow(2)
25 | mu2_sq = mu2.pow(2)
26 | mu1_mu2 = mu1 * mu2
27 |
28 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
29 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
30 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
31 |
32 | C1 = 0.01 ** 2
33 | C2 = 0.03 ** 2
34 |
35 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
36 |
37 | if size_average:
38 | return ssim_map.mean()
39 | else:
40 | return ssim_map.mean(1).mean(1).mean(1)
41 |
42 |
43 | class SSIM(torch.nn.Module):
44 | def __init__(self, window_size=11, size_average=True):
45 | super(SSIM, self).__init__()
46 | self.window_size = window_size
47 | self.size_average = size_average
48 | self.channel = 1
49 | self.window = create_window(window_size, self.channel)
50 |
51 | def forward(self, img1, img2):
52 | (_, channel, _, _) = img1.size()
53 |
54 | if channel == self.channel and self.window.data.type() == img1.data.type():
55 | window = self.window
56 | else:
57 | window = create_window(self.window_size, channel)
58 |
59 | if img1.is_cuda:
60 | window = window.cuda(img1.get_device())
61 | window = window.type_as(img1)
62 |
63 | self.window = window
64 | self.channel = channel
65 |
66 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
67 |
68 |
69 | def ssim(img1, img2, window_size=11, size_average=True):
70 | (_, channel, _, _) = img1.size()
71 | window = create_window(window_size, channel)
72 |
73 | if img1.is_cuda:
74 | window = window.cuda(img1.get_device())
75 | window = window.type_as(img1)
76 |
77 | return _ssim(img1, img2, window, window_size, channel, size_average)
78 |
--------------------------------------------------------------------------------
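A quick sanity check of the functional entry point above (toy tensors; identical images yield an SSIM of 1):

```python
import torch
from model.ssim import ssim

a = torch.rand(1, 3, 64, 64)
print(ssim(a, a).item())        # 1.0 for identical images
print(ssim(a, 1 - a).item())    # noticeably lower for dissimilar images
```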
/README.md:
--------------------------------------------------------------------------------
1 | ## Neural Super-Resolution for Real-time Rendering with Radiance Demodulation (CVPR 2024)
2 |
3 | ### [Paper](https://arxiv.org/abs/2308.06699) | [Project Page](https://riga2.github.io/nsrd/)
4 |
5 | __Note__: Since the dataset is quite large, we have currently uploaded three scenes to **OneDrive**: [Bistro](https://mailsdueducn-my.sharepoint.com/:u:/g/personal/cuteloong_mail_sdu_edu_cn/EWaiWyZbRxNLgWw7wTV6IBYB2NXLimZ4JcvKhlWpZQ6Jzg?e=hc9scu), [Pica](https://mailsdueducn-my.sharepoint.com/:u:/g/personal/202215216_mail_sdu_edu_cn/ESRf0I6mKH1PqL9wpOW-q7YBYygrKDF9q13piFd1Xyce9g?e=sNhN72) and [San_M](https://mailsdueducn-my.sharepoint.com/:u:/g/personal/202215216_mail_sdu_edu_cn/Eek4usuLcUZGlHCnpzPD63YBukX5FlXELUbTzWa3_zsx1Q?e=nFAwRR).
6 |
7 | If you are in mainland China, you can also access the dataset on [Baidu Cloud Disk](https://pan.baidu.com/s/1GJZ34keRFvGqnJ1Wgg0RHw?pwd=riga).
8 |
9 | 
10 |
11 | ### Installation
12 |
13 | Tested on Windows + CUDA 11.3 + PyTorch 1.12.1
14 |
15 | Set up the environment:
16 |
17 | ```bash
18 | git clone https://github.com/riga2/NSRD.git
19 | cd NSRD
20 | conda create -n NSRD python=3.9
21 | conda activate NSRD
22 | pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
23 | pip install -r requirements.txt
24 | ```
25 |
26 | The following training and testing steps use the Bistro scene (X4) as an example.
27 |
28 | ### Training
29 | 1. Create a folder named "dataset", then download the dataset and place it inside.
30 | ```text
31 | |--configs
32 | |--dataset
33 |    |--Bistro
34 |       |--train
35 |          |--GT
36 |             |--0
37 |             |--1
38 |             ...
39 |          |--X4
40 |             |--0
41 |             |--1
42 |             ...
43 |       |--test
44 |          |--GT
45 |          ...
46 |          |--X4
47 |          ...
48 | ```
49 | 2. Use the Anaconda Prompt to run the following commands for training. The trained model is stored in "experiment\Bistro_X4\model".
50 | ```bash
51 | cd src
52 | .\script\BistroX4_train.bat
53 | ```
54 |
55 | ### Testing
56 | 1. Run the following commands to perform super-resolution on the LR lighting components. The SR results are stored in "experiment\Bistro_X4\sr_results_x4".
57 | ```bash
58 | cd src
59 | .\script\BistroX4_test.bat
60 | ```
61 | 2. Run the following commands to perform remodulation on the SR lighting components. The final results are stored in "experiment\Bistro_X4\final_results_x4".
62 | ```bash
63 | cd src
64 | python remodulation.py --exp_dir ../experiment/Bistro_X4 --gt_dir ../dataset/Bistro/test/GT
65 | ```
66 |
67 | ### Citation
68 | ```
69 | @inproceedings{li2024neural,
70 | title={Neural Super-Resolution for Real-time Rendering with Radiance Demodulation},
71 | author={Li, Jia and Chen, Ziling and Wu, Xiaolong and Wang, Lu and Wang, Beibei and Zhang, Lei},
72 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
73 | pages={4357--4367},
74 | year={2024}
75 | }
76 | ```
77 |
--------------------------------------------------------------------------------
/src/data/data_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | from torch.utils.data import DataLoader
5 | from torch.utils.data.sampler import SubsetRandomSampler, SequentialSampler, Sampler
6 | import copy
7 | from typing import Callable
8 | import random
9 |
10 | class SubsetSequenceSampler(Sampler):
11 | def __init__(self, indices):
12 | super(SubsetSequenceSampler, self).__init__(indices)
13 | self.indices = indices
14 |
15 | def __iter__(self):
16 | return iter(self.indices)
17 |
18 | def __len__(self):
19 | return len(self.indices)
20 |
21 |
22 | class RenderingDataLoader(DataLoader):
23 | def __init__(self, args, dataset):
24 | self.args = args
25 | self.dataset = dataset
26 | self.batch_size = 1 if args.test_only else args.batch_size
27 | self.number_previous_frames = args.num_pre_frames
28 | self.num_frames_samples = args.num_frames_samples
29 | self.test_every = args.test_every
30 | self.n_samples = args.total_folder_num * args.grain
31 |
32 | self.valid_range = []
33 | if args.test_only is False:
34 | valid_folders = np.random.choice(np.arange(0, args.total_folder_num), args.valid_folder_num, replace=False)
35 | for i in valid_folders:
36 | self.valid_range += range(i * args.grain, (i+1) * args.grain)
37 |
38 | self.test_range = []
39 | if args.test_folder is not None:
40 | for i in args.test_folder:
41 | self.test_range += range(i * args.grain, (i+1) * args.grain)
42 |
43 | self.idx_train = []
44 | for idx in range(self.n_samples):
45 | if (idx not in self.valid_range) and (idx not in self.test_range):
46 | idx_mod = idx % args.grain
47 | if (idx_mod >= self.number_previous_frames) and (idx_mod <= args.grain - self.num_frames_samples):
48 | self.idx_train.append(idx)
49 |
50 | self.idx_valid = []
51 | for idx in self.valid_range:
52 | idx_mod = idx % args.grain
53 | if (idx_mod >= self.number_previous_frames) and (idx_mod <= args.grain - self.num_frames_samples):
54 | self.idx_valid.append(idx)
55 |
56 | self.idx_test = []
57 | for idx in self.test_range:
58 | idx_mod = idx % args.grain
59 | if (idx_mod >= self.number_previous_frames) and (idx_mod <= args.grain - self.num_frames_samples):
60 | self.idx_test.append(idx)
61 |
62 | self.sampler, self.valid_sampler = self._split_sampler()
63 |
64 | init_kwargs = {
65 | 'dataset': dataset,
66 | 'batch_size': self.batch_size,
67 | 'num_workers': args.n_threads
68 | }
69 | super().__init__(sampler=self.sampler, **init_kwargs, drop_last=True)
70 |
71 | def _split_sampler(self):
72 | # For test
73 | if len(self.valid_range) == 0:
74 | train_sampler = SubsetSequenceSampler(self.idx_test)
75 | return train_sampler, None
76 |
77 | repeat = (self.batch_size * self.test_every) // len(self.idx_train)
78 | train_idx = np.repeat(self.idx_train, repeat)
79 |
80 | np.random.seed(0)
81 | np.random.shuffle(train_idx)
82 |
83 | train_sampler = SubsetRandomSampler(train_idx)
84 | valid_sampler = SubsetSequenceSampler(self.idx_valid)
85 | self.n_samples = len(train_idx)
86 |
87 | return train_sampler, valid_sampler
88 |
89 | def split_validation(self, valid_dataset):
90 | if self.valid_sampler is None:
91 | return None
92 | else:
93 | valid_dataloader = DataLoader(sampler=self.valid_sampler, dataset=valid_dataset, batch_size=1, num_workers=self.num_workers, drop_last=True)
94 | return valid_dataloader
95 |
--------------------------------------------------------------------------------
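A worked example of the per-folder index window used above. With `grain = 100`, `num_pre_frames = 2` and `num_frames_samples = 3` (illustrative values), each folder contributes indices 2..97, so every sample has enough previous frames behind it and future frames ahead of it:

```python
grain, num_pre_frames, num_frames_samples = 100, 2, 3
usable = [i for i in range(grain)
          if num_pre_frames <= i % grain <= grain - num_frames_samples]
print(usable[0], usable[-1], len(usable))   # 2 97 96
```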
/.gitignore:
--------------------------------------------------------------------------------
1 | # user settings
2 | dataset/
3 | experiment/
4 | src/check/*
5 | .idea/
6 |
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | share/python-wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | MANIFEST
34 |
35 | # PyInstaller
36 | # Usually these files are written by a python script from a template
37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest
39 | *.spec
40 |
41 | # Installer logs
42 | pip-log.txt
43 | pip-delete-this-directory.txt
44 |
45 | # Unit test / coverage reports
46 | htmlcov/
47 | .tox/
48 | .nox/
49 | .coverage
50 | .coverage.*
51 | .cache
52 | nosetests.xml
53 | coverage.xml
54 | *.cover
55 | *.py,cover
56 | .hypothesis/
57 | .pytest_cache/
58 | cover/
59 |
60 | # Translations
61 | *.mo
62 | *.pot
63 |
64 | # Django stuff:
65 | *.log
66 | local_settings.py
67 | db.sqlite3
68 | db.sqlite3-journal
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 |
80 | # PyBuilder
81 | .pybuilder/
82 | target/
83 |
84 | # Jupyter Notebook
85 | .ipynb_checkpoints
86 |
87 | # IPython
88 | profile_default/
89 | ipython_config.py
90 |
91 | # pyenv
92 | # For a library or package, you might want to ignore these files since the code is
93 | # intended to run in multiple environments; otherwise, check them in:
94 | # .python-version
95 |
96 | # pipenv
97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
100 | # install all needed dependencies.
101 | #Pipfile.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/#use-with-ide
116 | .pdm.toml
117 |
118 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
119 | __pypackages__/
120 |
121 | # Celery stuff
122 | celerybeat-schedule
123 | celerybeat.pid
124 |
125 | # SageMath parsed files
126 | *.sage.py
127 |
128 | # Environments
129 | .env
130 | .venv
131 | env/
132 | venv/
133 | ENV/
134 | env.bak/
135 | venv.bak/
136 |
137 | # Spyder project settings
138 | .spyderproject
139 | .spyproject
140 |
141 | # Rope project settings
142 | .ropeproject
143 |
144 | # mkdocs documentation
145 | /site
146 |
147 | # mypy
148 | .mypy_cache/
149 | .dmypy.json
150 | dmypy.json
151 |
152 | # Pyre type checker
153 | .pyre/
154 |
155 | # pytype static type analyzer
156 | .pytype/
157 |
158 | # Cython debug symbols
159 | cython_debug/
160 |
161 | # PyCharm
162 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
163 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
164 | # and can be added to the global gitignore or merged into this file. For a more nuclear
165 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
166 | #.idea/
167 |
--------------------------------------------------------------------------------
/src/remodulation.py:
--------------------------------------------------------------------------------
1 | import configargparse
2 | import os
3 | os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1"
4 | import cv2
5 | import numpy as np
6 | from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim
7 | from tqdm import tqdm
8 |
9 | def getFromExr(path):
10 | bgr = cv2.imread(path, cv2.IMREAD_UNCHANGED)
11 | rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
12 | return rgb
13 |
14 | def gamma_correct(data):
15 | data = np.where(data < 0.0031, data * 12.92, np.power(data, 1.0/2.4) * 1.055 - 0.055)
16 | return data
17 |
18 | def sRGB(linear):
19 | return np.where(linear < 0.0031308,
20 | 12.92 * linear,
21 | 1.055 * np.power(linear, 1/2.4) - 0.055)
22 |
23 | def toneMapTev(color):
24 | color[:, :, 0] = sRGB(color[:, :, 0])
25 | color[:, :, 1] = sRGB(color[:, :, 1])
26 | color[:, :, 2] = sRGB(color[:, :, 2])
27 |
28 | return np.clip(color, 0.0, 1.0)
29 |
30 |
31 | def calc_psnr_ssim(res_dir, gt_dir):
32 | total_ssim = 0
33 | total_psnr = 0
34 | count = 0
35 | for name in tqdm(os.listdir(res_dir)):
36 | png_path = os.path.join(res_dir, name)
37 | gt_path = os.path.join(gt_dir, name)
38 |
39 | res_png = getFromExr(png_path)
40 | gt_png = getFromExr(gt_path)
41 |
42 | total_psnr += psnr(gt_png, res_png)
43 |         total_ssim += ssim(gt_png, res_png, channel_axis=2)
44 | count += 1
45 |
46 | print("avg_psnr : {}".format(total_psnr / count))
47 | print("avg_ssim : {}".format(total_ssim / count))
48 |
49 | def remodulation(exp_dir, gt_dir):
50 | avg_psnr, avg_ssim = 0, 0
51 | sr_res_dir = os.path.join(exp_dir, 'sr_results_x4')
52 | img_save_dir = os.path.join(exp_dir, 'final_results_x4')
53 | folder_num = len(os.listdir(sr_res_dir))
54 | for ind in range(folder_num):
55 | os.makedirs(os.path.join(img_save_dir, f'{ind}'), exist_ok=True)
56 |
57 | cur_res_dir = os.path.join(sr_res_dir, f'{ind}')
58 | cur_res_lst = os.listdir(cur_res_dir)
59 | cur_num = len(cur_res_lst)
60 | cur_psnr, cur_ssim = 0, 0
61 | for name in tqdm(cur_res_lst):
62 | res_path = os.path.join(cur_res_dir, name)
63 | png_name = name.split('.')[0] + '.png'
64 |
65 | irr = getFromExr(res_path)
66 | irr[irr < 0] = 0
67 | brdf = getFromExr(os.path.join(gt_dir, f'{ind}', "BRDF", name))
68 | emiss_sky = getFromExr(os.path.join(gt_dir, f'{ind}', 'Emission_Sky', name))
69 | emiss_sky_mask = ((abs(emiss_sky[:, :, 0]) >= 1e-4) | (abs(emiss_sky[:, :, 1]) >= 1e-4) | (abs(emiss_sky[:, :, 2]) >= 1e-4))[:, :, np.newaxis]
70 |
71 | sr_img = brdf * irr
72 | sr_img = np.where(emiss_sky_mask, emiss_sky, sr_img)
73 | sr_img = (toneMapTev(sr_img)*255).astype(np.uint8)
74 |
75 | gt_img = getFromExr(os.path.join(gt_dir, f'{ind}', "View_PNG", png_name))
76 | cur_psnr += psnr(gt_img, sr_img)
77 | cur_ssim += ssim(gt_img, sr_img, win_size=11, channel_axis=2, data_range=255)
78 |
79 |             save_path = os.path.join(img_save_dir, f'{ind}', png_name)
80 | cv2.imwrite(save_path, sr_img[:, :, ::-1])
81 |
82 | cur_psnr /= cur_num
83 | cur_ssim /= cur_num
84 | avg_psnr += cur_psnr
85 | avg_ssim += cur_ssim
86 |
87 | avg_psnr /= folder_num
88 | avg_ssim /= folder_num
89 |
90 | print("Avg_pnsr: {}".format(avg_psnr))
91 | print("Avg_ssim: {}".format(avg_ssim))
92 |
93 |
94 | if __name__ == '__main__':
95 | parser = configargparse.ArgumentParser()
96 | parser.add_argument('--exp_dir', type=str, default=r"../experiment/Bistro_X4",
97 | help='experiment dir')
98 | parser.add_argument('--gt_dir', type=str, default=r"../dataset/Bistro/test/GT",
99 |                         help='ground truth dir, which contains View_PNG, BRDF and Emission_Sky.')
100 | args = parser.parse_args()
101 |
102 | remodulation(args.exp_dir, args.gt_dir)
103 |
104 |
105 |
--------------------------------------------------------------------------------
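A per-pixel sketch of the remodulation step above, with toy values standing in for the real buffers: the SR irradiance is multiplied by the BRDF buffer, and emissive/sky pixels are passed through unchanged (the full script then tone maps to sRGB and writes a PNG):

```python
import numpy as np

irr = np.full((1, 1, 3), 0.5, dtype=np.float32)        # SR irradiance (lighting), toy value
brdf = np.array([[[0.2, 0.4, 0.6]]], dtype=np.float32)
emiss = np.zeros((1, 1, 3), dtype=np.float32)          # no emission/sky at this pixel

out = brdf * irr                                       # remodulate: material x lighting
mask = (np.abs(emiss) >= 1e-4).any(axis=2, keepdims=True)
out = np.where(mask, emiss, out)                       # emissive/sky pixels override
print(out)                                             # [[[0.1 0.2 0.3]]]
```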
/src/model/sr_model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from model import sr_common as sr_com
4 | import torch.nn.functional as F
5 | from model.ConvLSTM import ConvLSTM
6 |
7 | def make_model(args):
8 |     assert args.num_pre_frames != 0, 'SRNet expects at least one previous frame (num_pre_frames > 0)'
9 |     return SRNet(args)
10 |
11 | class SRNet(nn.Module):
12 | def __init__(self, args=None, conv=sr_com.default_conv):
13 | super(SRNet, self).__init__()
14 | self.scale = args.scale
15 | self.act = nn.ReLU(True)
16 | self.in_dims = args.input_total_dims
17 | self.n_feats = args.n_feats
18 | self.num_previous = args.num_pre_frames
19 | self.total_feats = self.n_feats * (self.num_previous + 1) + self.n_feats
20 |
21 | self.conv1 = conv(self.in_dims, self.n_feats, 3)
22 | self.gate_conv = sr_com.GatedConv2dWithActivation(self.in_dims + 1, self.n_feats, 3, padding=1)
23 |
24 | feats = 64
25 | self.unps = nn.PixelUnshuffle(self.scale)
26 | self.conv2 = conv(3 * self.scale * self.scale, self.n_feats, 3)
27 | self.convLSTM = ConvLSTM(self.total_feats, feats, kernel_size=3)
28 |
29 | # U-shaped reconstruction module
30 | self.encoder_1 = nn.Sequential(
31 | sr_com.RCAB(conv, feats, 3, 16),
32 | sr_com.RCAB(conv, feats, 3, 16),
33 | sr_com.RCAB(conv, feats, 3, 16)
34 | )
35 |
36 | self.encoder_2 = nn.Sequential(
37 | sr_com.RCAB(conv, feats, 3, 16),
38 | sr_com.RCAB(conv, feats, 3, 16)
39 | )
40 |
41 | self.center = nn.Sequential(
42 | sr_com.RCAB(conv, feats, 3, 16),
43 | sr_com.RCAB(conv, feats, 3, 16),
44 | sr_com.RCAB(conv, feats, 3, 16)
45 | )
46 |
47 | self.decoder_2 = nn.Sequential(
48 | conv(feats * 2, feats, 3),
49 | sr_com.RCAB(conv, feats, 3, 16),
50 | sr_com.RCAB(conv, feats, 3, 16)
51 | )
52 |
53 | self.decoder_1 = nn.Sequential(
54 | conv(feats * 2, feats, 3),
55 | sr_com.RCAB(conv, feats, 3, 16),
56 | sr_com.RCAB(conv, feats, 3, 16),
57 | sr_com.RCAB(conv, feats, 3, 16)
58 | )
59 |
60 | self.pooling = nn.MaxPool2d(2)
61 | self.upsize = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
62 | self.conv3 = conv(feats, self.n_feats, 3)
63 |
64 | self.upsampling = nn.Sequential(
65 | conv(self.n_feats, args.output_dims * self.scale * self.scale, 3),
66 | nn.PixelShuffle(self.scale)
67 | )
68 |
69 | def crop_tensor(self, actual: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
70 | # https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py
71 | diffY = int(actual.size()[2] - target.size()[2])
72 | diffX = int(actual.size()[3] - target.size()[3])
73 | x = F.pad(target, [diffX // 2, diffX - diffX // 2,
74 | diffY // 2, diffY - diffY // 2])
75 | return x
76 |
77 | def forward(self, x_tuple):
78 | x, last_sr, prev_state = x_tuple
79 |
80 | # split cur frames
81 | x_cur_frames = x[:, :self.in_dims, :, :]
82 | x_cur_frames = self.conv1(x_cur_frames)
83 |
84 | # the last channel of previous frame is mask
85 | x_pre_frames = []
86 | ch_pre_frames = self.in_dims + 1
87 |
88 | # split previous frames
89 | for i in range(0, self.num_previous):
90 | st, ed = self.in_dims + i * ch_pre_frames, self.in_dims + (i + 1) * ch_pre_frames
91 | pre_frame = x[:, st:ed, :, :]
92 | pre_frame = self.gate_conv(pre_frame)
93 | x_pre_frames.append(pre_frame)
94 | x_pre_frame = torch.cat(x_pre_frames, dim=1)
95 |
96 | # last sr
97 | last_ups = self.unps(last_sr)
98 | last_in = self.conv2(last_ups)
99 |
100 |         # fuse current-frame features, warped previous-frame features, and the last SR frame
101 | x_all = torch.cat((x_cur_frames, x_pre_frame, last_in), dim=1)
102 | state = self.convLSTM(x_all, prev_state)
103 |
104 | x_encoder1 = self.encoder_1(state[0])
105 | x_encoder1_pool = self.pooling(x_encoder1)
106 | x_encoder2 = self.encoder_2(x_encoder1_pool)
107 | x_encoder2_pool = self.pooling(x_encoder2)
108 |
109 | x_center = self.center(x_encoder2_pool)
110 |
111 | x_center_up = self.crop_tensor(x_encoder2, self.upsize(x_center))
112 | x_decoder2 = self.decoder_2(torch.cat((x_center_up, x_encoder2), dim=1))
113 |
114 | x_decoder2_up = self.crop_tensor(x_encoder1, self.upsize(x_decoder2))
115 | x_decoder1 = self.decoder_1(torch.cat((x_decoder2_up, x_encoder1), dim=1))
116 |
117 | x_in = self.conv3(x_decoder1)
118 | x_res = self.upsampling(x_in)
119 |
120 | return x_res, state
121 |
122 |
--------------------------------------------------------------------------------
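A hypothetical smoke test for SRNet. The attribute names match what `SRNet.__init__` reads, but option.py is not included in this dump, so the values below are illustrative assumptions chosen to be self-consistent:

```python
import argparse
import torch
from model import sr_model

# Stand-in for option.py's args; values are assumptions, not the project's defaults.
args = argparse.Namespace(scale=4, input_total_dims=7, n_feats=32,
                          num_pre_frames=2, output_dims=3)
net = sr_model.make_model(args)

b, h, w = 1, 32, 32
x = torch.rand(b, 7 + 2 * (7 + 1), h, w)   # current frame + 2 previous frames (each with a mask channel)
last_sr = torch.rand(b, 3, 4 * h, 4 * w)   # previous SR output at HR resolution
out, state = net((x, last_sr, None))       # state is the ConvLSTM (hidden, cell) pair
print(out.shape)                           # torch.Size([1, 3, 128, 128])
```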
/src/model/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from importlib import import_module
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.parallel as P
7 | import torch.utils.model_zoo
8 |
9 |
10 | class Model(nn.Module):
11 | def __init__(self, args, ckp):
12 | super(Model, self).__init__()
13 | print('Making model...')
14 |
15 | self.scale = args.scale
16 | self.idx_scale = 0
17 | self.chop = args.chop
18 | self.precision = args.precision
19 | self.cpu = args.cpu
20 | self.device = torch.device('cpu' if args.cpu else 'cuda')
21 | self.n_GPUs = args.n_GPUs
22 | self.save_models = args.save_models
23 |
24 | module = import_module('model.' + args.model.lower())
25 | self.model = module.make_model(args).to(self.device)
26 | if args.precision == 'half':
27 | self.model.half()
28 |
29 | self.load(
30 | ckp.get_path('model'),
31 | resume=args.resume,
32 | cpu=args.cpu
33 | )
34 | print(self.model, file=ckp.log_file)
35 |
36 | def forward(self, x):
37 | if self.training:
38 | if self.n_GPUs > 1:
39 | return P.data_parallel(self.model, x, range(self.n_GPUs))
40 | else:
41 | return self.model(x)
42 | else:
43 | if self.chop:
44 | forward_function = self.forward_chop
45 | else:
46 | forward_function = self.model.forward
47 |
48 | return forward_function(x)
49 |
50 | def save(self, apath, epoch, is_best=False):
51 | save_dirs = [os.path.join(apath, 'model_latest.pt')]
52 |
53 | if is_best:
54 | save_dirs.append(os.path.join(apath, 'model_best.pt'))
55 | if self.save_models:
56 | save_dirs.append(
57 | os.path.join(apath, 'model_{}.pt'.format(epoch))
58 | )
59 |
60 | for s in save_dirs:
61 | torch.save(self.model.state_dict(), s)
62 |
63 | def load(self, apath, resume=-1, cpu=False):
64 | load_from = None
65 | kwargs = {}
66 | if cpu:
67 | kwargs = {'map_location': lambda storage, loc: storage}
68 |
69 | if resume == -1:
70 | load_from = torch.load(
71 | os.path.join(apath, 'model_best.pt'),
72 | **kwargs
73 | )
74 | elif resume == 0:
75 | load_from = torch.load(
76 | os.path.join(apath, 'model_latest.pt'),
77 | **kwargs
78 | )
79 |
80 | if load_from:
81 | self.model.load_state_dict(load_from, strict=False)
82 |
83 | def forward_chop(self, *args, shave=10, min_size=160000):
84 |         scale = self.scale  # args.scale is a single int here; the EDSR-style multi-scale / input_large fields are not defined in this repo
85 | n_GPUs = min(self.n_GPUs, 4)
86 | # height, width
87 | h, w = args[0].size()[-2:]
88 |
89 | top = slice(0, h//2 + shave)
90 | bottom = slice(h - h//2 - shave, h)
91 | left = slice(0, w//2 + shave)
92 | right = slice(w - w//2 - shave, w)
93 | x_chops = [torch.cat([
94 | a[..., top, left],
95 | a[..., top, right],
96 | a[..., bottom, left],
97 | a[..., bottom, right]
98 | ]) for a in args]
99 |
100 | y_chops = []
101 | if h * w < 4 * min_size:
102 | for i in range(0, 4, n_GPUs):
103 | x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops]
104 | y = P.data_parallel(self.model, *x, range(n_GPUs))
105 | if not isinstance(y, list): y = [y]
106 | if not y_chops:
107 | y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y]
108 | else:
109 | for y_chop, _y in zip(y_chops, y):
110 | y_chop.extend(_y.chunk(n_GPUs, dim=0))
111 | else:
112 | for p in zip(*x_chops):
113 | y = self.forward_chop(*p, shave=shave, min_size=min_size)
114 | if not isinstance(y, list): y = [y]
115 | if not y_chops:
116 | y_chops = [[_y] for _y in y]
117 | else:
118 | for y_chop, _y in zip(y_chops, y): y_chop.append(_y)
119 |
120 | h *= scale
121 | w *= scale
122 | top = slice(0, h//2)
123 | bottom = slice(h - h//2, h)
124 | bottom_r = slice(h//2 - h, None)
125 | left = slice(0, w//2)
126 | right = slice(w - w//2, w)
127 | right_r = slice(w//2 - w, None)
128 |
129 | # batch size, number of color channels
130 | b, c = y_chops[0][0].size()[:-2]
131 | y = [y_chop[0].new(b, c, h, w) for y_chop in y_chops]
132 | for y_chop, _y in zip(y_chops, y):
133 | _y[..., top, left] = y_chop[0][..., top, left]
134 | _y[..., top, right] = y_chop[1][..., top, right_r]
135 | _y[..., bottom, left] = y_chop[2][..., bottom_r, left]
136 | _y[..., bottom, right] = y_chop[3][..., bottom_r, right_r]
137 |
138 | if len(y) == 1: y = y[0]
139 |
140 | return y
141 |
--------------------------------------------------------------------------------
/src/data/data_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1"
3 | import cv2
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import random
8 | import imageio
9 | import torch.nn.functional as F
10 | import time
11 |
12 | def linearUpsample(img, scale):
13 | return cv2.resize(img, dsize=None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
14 |
15 | def save2Exr(img, path):
16 | cv2.imwrite(path, img[:,:,::-1])
17 |
18 | def getFromBin(path, ih, iw):
19 |     return np.fromfile(path, dtype=np.float32).reshape(ih, iw, -1)
20 |
21 | def getFromExr(path):
22 | bgr = cv2.imread(path, cv2.IMREAD_UNCHANGED)
23 | if bgr is None:
24 |         raise FileNotFoundError('Failed to read image: ' + path)
25 | rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
26 | return rgb
27 |
28 | def bin2exr(pt_path, ih, iw, save_path):
29 | irr = getFromBin(pt_path, ih, iw)[:, :, ::-1]
30 | cv2.imwrite(save_path, irr)
31 |
32 | def get_patch(lr, hr, patch_size=256, scale=2):
33 | ih, iw = lr.shape[1:]
34 | tp = scale * patch_size
35 | ip = tp // scale
36 | ix = random.randrange(0, iw - ip + 1)
37 | iy = random.randrange(0, ih - ip + 1)
38 | tx, ty = scale * ix, scale * iy
39 | ret = [
40 | lr[:, iy:iy + ip, ix:ix + ip],
41 | hr[:, ty:ty + tp, tx:tx + tp]
42 | ]
43 | return ret
44 |
45 | def np2Tensor(np_file):
46 | return torch.from_numpy(np_file).permute(2, 0, 1).float()
47 |
48 | def tensorFromNp(np_file):
49 | np_file = np_file.transpose(2, 0, 1)
50 | return torch.FloatTensor(np_file)
51 |
52 | def mv_mask(ocmv, mv, gate = 0.1):
53 | delta = ocmv - mv
54 | mask = torch.where(torch.abs(delta) < gate, False, True)
55 | x, y = mask[0, :, :], mask[1, :, :]
56 | mask = torch.where(((x) | (y)), 1, 0)
57 | return mask
58 |
59 | def backward_warp_motion(pre: torch.Tensor, motion: torch.Tensor, cur: torch.Tensor) -> torch.Tensor:
60 | # see: https://discuss.pytorch.org/t/image-warping-for-backward-flow-using-forward-flow-matrix-optical-flow/99298
61 | # input image is: [batch, channel, height, width]
62 | # st = time.time()
63 | index_batch, number_channels, height, width = pre.size()
64 | grid_x = torch.arange(width).view(1, -1).repeat(height, 1)
65 | grid_y = torch.arange(height).view(-1, 1).repeat(1, width)
66 | grid_x = grid_x.view(1, 1, height, width).repeat(index_batch, 1, 1, 1)
67 | grid_y = grid_y.view(1, 1, height, width).repeat(index_batch, 1, 1, 1)
68 | #
69 | grid = torch.cat((grid_x, grid_y), 1).float().cuda()
70 | # grid is: [batch, channel (2), height, width]
71 | vgrid = grid - motion
72 | # Grid values must be normalised positions in [-1, 1]
73 | vgrid_x = vgrid[:, 0, :, :]
74 | vgrid_y = vgrid[:, 1, :, :]
75 | vgrid[:, 0, :, :] = (vgrid_x / width) * 2.0 - 1.0
76 | vgrid[:, 1, :, :] = (vgrid_y / height) * 2.0 - 1.0
77 | # swapping grid dimensions in order to match the input of grid_sample.
78 | # that is: [batch, output_height, output_width, grid_pos (2)]
79 | vgrid = vgrid.permute((0, 2, 3, 1))
80 | warped = F.grid_sample(pre, vgrid, align_corners=True)
81 |
82 | # return warped
83 | oox, ooy = torch.split((vgrid < -1) | (vgrid > 1), 1, dim=3)
84 | oo = (oox | ooy).permute(0, 3, 1, 2)
85 | # ed = time.time()
86 | # print('warp {}'.format(ed-st))
87 | return torch.where(oo, cur, warped)
88 |
89 | def warp(x, flow, mode='bilinear', padding_mode='border'):
90 | """ Backward warp `x` according to `flow`
91 |
92 | Both x and flow are pytorch tensor in shape `nchw` and `n2hw`
93 |
94 | Reference:
95 | https://github.com/sniklaus/pytorch-spynet/blob/master/run.py#L41
96 | """
97 |
98 | n, c, h, w = x.size()
99 |
100 | # create mesh grid
101 | iu = torch.linspace(-1.0, 1.0, w).view(1, 1, 1, w).expand(n, -1, h, -1)
102 | iv = torch.linspace(-1.0, 1.0, h).view(1, 1, h, 1).expand(n, -1, -1, w)
103 | grid = torch.cat([iu, iv], 1).to(flow.device)
104 |
105 | # normalize flow to [-1, 1]
106 | flow = torch.cat([
107 | flow[:, 0:1, ...] / ((w - 1.0) / 2.0),
108 | flow[:, 1:2, ...] / ((h - 1.0) / 2.0)], dim=1)
109 |
110 | # add flow to grid and reshape to nhw2
111 | grid = (grid - flow).permute(0, 2, 3, 1)
112 |
113 | # bilinear sampling
114 | # Note: `align_corners` is set to `True` by default for PyTorch version < 1.4.0
115 | if int(''.join(torch.__version__.split('.')[:2])) >= 14:
116 | output = F.grid_sample(
117 | x, grid, mode=mode, padding_mode=padding_mode, align_corners=True)
118 | else:
119 | output = F.grid_sample(x, grid, mode=mode, padding_mode=padding_mode)
120 |
121 | return output
122 |
123 | def toneMapping(data):
124 | return data / (data + 1.0)
125 |
126 | def ldr2hdr(data):
127 | return data / (1.0 - data)
128 |
129 | def chromatic(data):
130 | return 0.5 * data
131 |
132 | def adjust(data):
133 | contrast = 0.0
134 | brightness = 0.0
135 | contrastFactor = (259.0 * (contrast * 256.0 + 255.0)) / (255.0 * (259.0 - 256.0 * contrast))
136 | data = (data - 0.5) * contrastFactor + 0.5 + brightness
137 | return data
138 |
139 |
140 | def gamma_correct(data):
141 | r = 1.0 / 2.2
142 | return np.power(data, r)
143 |
144 |
--------------------------------------------------------------------------------
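A quick CPU sanity check of `warp` above: a zero flow leaves the sampling grid unchanged, so the output reproduces the input exactly:

```python
import torch
from data import data_utils

x = torch.rand(1, 3, 16, 16)
flow = torch.zeros(1, 2, 16, 16)          # zero motion vectors
y = data_utils.warp(x, flow)
print(torch.allclose(x, y, atol=1e-6))    # True
```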
/src/data/view_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import numpy as np
4 | import torch
5 | from data import data_utils
6 | import random
7 | from data.base_dataset import BaseDataset
8 | import imageio
9 | import time
10 | import pickle
11 | warped_save_dir = r'../experiment/warped_img'
12 |
13 | def make_model(args, train=True):
14 | return ViewDataset(args, train)
15 |
16 | class ViewDataset(BaseDataset):
17 | def __init__(self, args, train=True):
18 | super(ViewDataset, self).__init__(args)
19 | self.args = args
20 | self.train = train
21 | self.name = args.data_name
22 | self.crop_size = args.crop_size if train else None
23 | self.num_frames_samples = args.num_frames_samples if train else 1
24 | self.upsample = torch.nn.Upsample(scale_factor=self.scale, mode='bilinear', align_corners=True)
25 | self.view_dirname = args.view_dirname
26 |
27 | def load_View_HR(self, folder_index, file_index, ext='.png'):
28 | filename = str(file_index) + ext
29 | hr_file_path = os.path.join(self.gt_dir, str(folder_index), self.view_dirname, filename)
30 | hr = data_utils.getFromExr(hr_file_path).astype(np.float32) / 255.0
31 | return hr
32 |
33 | def load_View_LR(self, folder_index, file_index, ext='.png'):
34 | filename = str(file_index) + ext
35 | lr_file_path = os.path.join(self.lr_dir, str(folder_index), self.view_dirname, filename)
36 | lr = data_utils.getFromExr(lr_file_path).astype(np.float32) / 255.0
37 | return lr
38 |
39 | def __getitem__(self, item):
40 | folder_index = item // self.grain
41 | file_index = item % self.grain
42 |
43 | HR_lst = []
44 | LR_lst = []
45 | MV_up_lst = []
46 | OCMV_up_lst = []
47 | Mask_up_lst = []
48 |
49 | lr_sh, lr_sw, hr_sh, hr_sw = 0, 0, 0, 0
50 | lr_eh, lr_ew = self.lr_size
51 | hr_eh, hr_ew = self.hr_size
52 | if self.crop_size is not None:
53 | ih, iw = self.lr_size
54 | tp = self.scale * self.crop_size
55 | ip = tp // self.scale
56 | ix = random.randrange(0, iw - ip + 1)
57 | iy = random.randrange(0, ih - ip + 1)
58 | tx, ty = self.scale * ix, self.scale * iy
59 | lr_sh, lr_sw, hr_sh, hr_sw = iy, ix, ty, tx
60 | lr_eh, lr_ew, hr_eh, hr_ew = iy+ip, ix+ip, ty+tp, tx+tp
61 |
62 | for index in range(file_index, file_index + self.num_frames_samples):
63 | # HR
64 | hr = self.load_View_HR(folder_index, index)
65 | hr = data_utils.np2Tensor(hr)
66 | HR_lst.append(hr[:, hr_sh:hr_eh, hr_sw:hr_ew])
67 |
68 | # LR
69 | lr_files, mvs, ocmvs = [], [], []
70 | for idx in range(index, index - self.number_previous_frames - 1, -1):
71 | file = self.load_View_LR(folder_index, idx)
72 | if self.useNormal:
73 | normal = self.load_Normal_Unity(folder_index, idx)
74 | file = np.concatenate((file, normal), axis=2)
75 | if self.useDepth:
76 | depth = self.load_Depth_Unity(folder_index, idx)
77 | file = np.concatenate((file, depth), axis=2)
78 | file = data_utils.np2Tensor(file)
79 | lr_files.append(file[:, lr_sh:lr_eh, lr_sw:lr_ew])
80 |
81 | # mv and ocmv
82 | if (idx != index - self.number_previous_frames):
83 | mv = data_utils.np2Tensor(self.load_MV(folder_index, idx))
84 | ocmv = data_utils.np2Tensor(self.load_OCMV(folder_index, idx))
85 | mvs.append(mv[:, lr_sh:lr_eh, lr_sw:lr_ew])
86 | ocmvs.append(ocmv[:, lr_sh:lr_eh, lr_sw:lr_ew])
87 | if (idx == index):
88 | mv_up = self.upsample(mv[None, :])[0]
89 | ocmv_up = self.upsample(ocmv[None, :])[0]
90 | mask_up = 1 - data_utils.mv_mask(ocmv_up, mv_up)[None, :]
91 | MV_up_lst.append(mv_up[:, hr_sh:hr_eh, hr_sw:hr_ew])
92 | OCMV_up_lst.append(ocmv_up[:, hr_sh:hr_eh, hr_sw:hr_ew])
93 | Mask_up_lst.append(mask_up[:, hr_sh:hr_eh, hr_sw:hr_ew])
94 |
95 |             # warp previous frames onto the current frame; motion vectors are accumulated across intermediate frames first
96 | for i in range(self.number_previous_frames, 0, -1):
97 | for j in range(0, i-1):
98 | mvs[i-1] += mvs[j]
99 | ocmvs[i-1] += ocmvs[j]
100 | lr_files[i] = data_utils.backward_warp_motion(lr_files[i][None, :].cuda(), mvs[i-1][None, :].cuda(),
101 | lr_files[0][None, :].cuda())[0].cpu()
102 |
103 |             # append each warped previous frame's disocclusion mask as an extra channel
104 | for i in range(0, self.number_previous_frames):
105 | mask = data_utils.mv_mask(ocmvs[i], mvs[i])[None, :]
106 | lr_files[i+1] = torch.cat((lr_files[i+1], mask), dim=0)
107 |
108 | lr = torch.cat(lr_files, dim=0)
109 | LR_lst.append(lr)
110 |
111 | return LR_lst, HR_lst, MV_up_lst, Mask_up_lst, str(file_index)
112 |
113 | def __len__(self):
114 | return self.total_folder_num * self.grain
--------------------------------------------------------------------------------
/src/model/sr_common.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 |
4 |
5 | def depth_wise_conv(in_feats, out_feats, kernel, bias=True):
6 | return nn.Sequential(
7 | nn.Conv2d(in_feats, in_feats, kernel_size=kernel, padding=(kernel // 2), groups=in_feats, bias=bias),
8 | nn.Conv2d(in_feats, out_feats, kernel_size=1)
9 | )
10 |
11 |
12 | def default_conv(in_feats, out_feats, kernel_size, bias=True):
13 | return nn.Conv2d(
14 | in_feats, out_feats, kernel_size,
15 | padding=(kernel_size // 2), bias=bias)
16 |
17 |
18 | # Channel Attention (CA) Layer
19 | class CALayer(nn.Module):
20 | def __init__(self, channel, reduction=16):
21 | super(CALayer, self).__init__()
22 | # global average pooling: feature --> point
23 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
24 | # feature channel downscale and upscale --> channel weight
25 | self.conv_du = nn.Sequential(
26 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
27 | nn.ReLU(inplace=True),
28 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
29 | nn.Sigmoid()
30 | )
31 |
32 | def forward(self, x):
33 | y = self.avg_pool(x)
34 | y = self.conv_du(y)
35 | return x * y
36 |
37 |
38 | # Residual Channel Attention Block (RCAB)
39 | class RCAB(nn.Module):
40 | def __init__(
41 | self, conv, n_feat, kernel_size, reduction,
42 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
43 | super(RCAB, self).__init__()
44 | modules_body = []
45 | for i in range(2):
46 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
47 | if bn: modules_body.append(nn.BatchNorm2d(n_feat))
48 | if i == 0: modules_body.append(act)
49 | modules_body.append(CALayer(n_feat, reduction))
50 | self.body = nn.Sequential(*modules_body)
51 | self.res_scale = res_scale
52 |
53 | def forward(self, x):
54 | res = self.body(x)
55 | # res = self.body(x).mul(self.res_scale)
56 | res += x
57 | return res
58 |
59 |
60 | class CA(nn.Module):
61 | def __init__(self, channel, reduction=16):
62 | super(CA, self).__init__()
63 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
64 |         self.max_pool = nn.AdaptiveMaxPool2d(1)
65 |
66 | self.fc = nn.Sequential(
67 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
68 | nn.ReLU(inplace=True),
69 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
70 | )
71 |
72 | self.sigmoid = nn.Sigmoid()
73 |
74 | def forward(self, x):
75 | avg_out = self.fc(self.avg_pool(x))
76 | max_out = self.fc(self.max_pool(x))
77 | out = avg_out + max_out
78 | return x * self.sigmoid(out)
79 |
80 |
81 | class SA(nn.Module):
82 | def __init__(self, kernel_size=7, bn=True):
83 | super(SA, self).__init__()
84 | self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
85 | self.bn = nn.BatchNorm2d(1, eps=1e-5, momentum=0.01, affine=True) if bn else None
86 | self.sigmoid = nn.Sigmoid()
87 |
88 | def forward(self, x):
89 | avg_out = torch.mean(x, dim=1, keepdim=True)
90 | max_out, _ = torch.max(x, dim=1, keepdim=True)
91 | out = torch.cat((avg_out, max_out), dim=1)
92 | out = self.conv(out)
93 | if self.bn is not None:
94 | out = self.bn(out)
95 | return x * self.sigmoid(out)
96 |
97 |
98 | ## CBAM Block
99 | class CBAM(nn.Module):
100 | def __init__(
101 | self, conv, n_feat, kernel_size, reduction,
102 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
103 | super(CBAM, self).__init__()
104 | modules_body = []
105 | for i in range(2):
106 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
107 | if bn: modules_body.append(nn.BatchNorm2d(n_feat))
108 | if i == 0: modules_body.append(act)
109 | modules_body.append(CALayer(n_feat, reduction))
110 | modules_body.append(SA())
111 | self.body = nn.Sequential(*modules_body)
112 | self.res_scale = res_scale
113 |
114 | def forward(self, x):
115 | res = self.body(x)
116 | # res = self.body(x).mul(self.res_scale)
117 | res += x
118 | return res
119 |
120 |
121 | class GatedConv2dWithActivation(torch.nn.Module):
122 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1, dilation=1, groups=1, bias=True,
123 | batch_norm=True, activation=torch.nn.LeakyReLU(0.2, inplace=True)):
124 | super(GatedConv2dWithActivation, self).__init__()
125 | self.batch_norm = batch_norm
126 | self.activation = activation
127 | self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
128 | self.mask_conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups,
129 | bias)
130 | self.batch_norm2d = torch.nn.BatchNorm2d(out_channels)
131 | self.sigmoid = torch.nn.Sigmoid()
132 |
133 | for m in self.modules():
134 | if isinstance(m, nn.Conv2d):
135 | nn.init.kaiming_normal_(m.weight)
136 |
137 | def gated(self, mask):
138 | # return torch.clamp(mask, -1, 1)
139 | return self.sigmoid(mask)
140 |
141 | def forward(self, input):
142 | x = self.conv2d(input)
143 | mask = self.mask_conv2d(input)
144 | if self.activation is not None:
145 | x = self.activation(x) * self.gated(mask)
146 | else:
147 | x = x * self.gated(mask)
148 | if self.batch_norm:
149 | return self.batch_norm2d(x)
150 | else:
151 | return x
152 |
153 |
154 |
--------------------------------------------------------------------------------
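A minimal shape check for the RCAB block above (toy tensor; the residual design preserves the input shape):

```python
import torch
from model import sr_common as sr_com

block = sr_com.RCAB(sr_com.default_conv, n_feat=64, kernel_size=3, reduction=16)
x = torch.rand(1, 64, 32, 32)
print(block(x).shape)   # torch.Size([1, 64, 32, 32])
```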
/src/data/irradiance_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import numpy as np
4 | import torch
5 | from data import data_utils
6 | import random
7 | from data.base_dataset import BaseDataset
8 | import imageio
9 | import time
10 | import pickle
11 | warped_save_dir = r'../experiment/warped_img'
12 |
13 | def make_model(args, train=True):
14 | return IrradianceDataset(args, train)
15 |
16 | class IrradianceDataset(BaseDataset):
17 | def __init__(self, args, train=True):
18 | super(IrradianceDataset, self).__init__(args)
19 | self.args = args
20 | self.train = train
21 | self.name = args.data_name
22 | self.crop_size = args.crop_size if train else None
23 | self.num_frames_samples = args.num_frames_samples if train else 1
24 | self.upsample = torch.nn.Upsample(scale_factor=self.scale, mode='bilinear', align_corners=True)
25 | self.irradiance_dirname = args.irradiance_dirname
26 |
27 | def load_Irradiance_HR(self, folder_index, file_index, ext='.exr'):
28 | filename = str(file_index) + ext
29 | hr_file_path = os.path.join(self.gt_dir, str(folder_index), self.irradiance_dirname, filename)
30 | hr = data_utils.getFromExr(hr_file_path)
31 | return hr
32 |
33 | def load_Irradiance_LR(self, folder_index, file_index, ext='.exr'):
34 | filename = str(file_index) + ext
35 | lr_file_path = os.path.join(self.lr_dir, str(folder_index), self.irradiance_dirname, filename)
36 | lr = data_utils.getFromExr(lr_file_path)
37 | return lr
38 |
39 | def __getitem__(self, item):
40 | folder_index = item // self.grain
41 | file_index = item % self.grain
42 |
43 | HR_lst = []
44 | LR_lst = []
45 | MV_up_lst = []
46 | OCMV_up_lst = []
47 | Mask_up_lst = []
48 |
49 | lr_sh, lr_sw, hr_sh, hr_sw = 0, 0, 0, 0
50 | lr_eh, lr_ew = self.lr_size
51 | hr_eh, hr_ew = self.hr_size
52 | if self.crop_size is not None:
53 | ih, iw = self.lr_size
54 | tp = self.scale * self.crop_size
55 | ip = tp // self.scale
56 | ix = random.randrange(0, iw - ip + 1)
57 | iy = random.randrange(0, ih - ip + 1)
58 | tx, ty = self.scale * ix, self.scale * iy
59 | lr_sh, lr_sw, hr_sh, hr_sw = iy, ix, ty, tx
60 | lr_eh, lr_ew, hr_eh, hr_ew = iy+ip, ix+ip, ty+tp, tx+tp
61 |
62 | for index in range(file_index, file_index + self.num_frames_samples):
63 | # HR
64 | hr = self.load_Irradiance_HR(folder_index, index)
65 | # hdr -> ldr
66 | # hr = data_utils.toneMapping(hr)
67 | hr = data_utils.np2Tensor(hr)
68 | HR_lst.append(hr[:, hr_sh:hr_eh, hr_sw:hr_ew])
69 |
70 | # LR
71 | lr_files, mvs, ocmvs = [], [], []
72 | for idx in range(index, index - self.number_previous_frames - 1, -1):
73 | file = self.load_Irradiance_LR(folder_index, idx)
74 | # hdr -> ldr
75 | # file = data_utils.toneMapping(file)
76 | if self.useNormal:
77 | normal = self.load_Normal_Unity(folder_index, idx)
78 | file = np.concatenate((file, normal), axis=2)
79 | if self.useDepth:
80 | depth = self.load_Depth_Unity(folder_index, idx)
81 | file = np.concatenate((file, depth), axis=2)
82 | file = data_utils.np2Tensor(file)
83 | lr_files.append(file[:, lr_sh:lr_eh, lr_sw:lr_ew])
84 |
85 | # mv and ocmv
86 | if (idx != index - self.number_previous_frames):
87 | mv = data_utils.np2Tensor(self.load_MV(folder_index, idx))
88 | ocmv = data_utils.np2Tensor(self.load_OCMV(folder_index, idx))
89 | mvs.append(mv[:, lr_sh:lr_eh, lr_sw:lr_ew])
90 | ocmvs.append(ocmv[:, lr_sh:lr_eh, lr_sw:lr_ew])
91 | if (idx == index):
92 | mv_up = self.upsample(mv[None, :])[0]
93 | ocmv_up = self.upsample(ocmv[None, :])[0]
94 | mask_up = 1 - data_utils.mv_mask(ocmv_up, mv_up)[None, :]
95 | MV_up_lst.append(mv_up[:, hr_sh:hr_eh, hr_sw:hr_ew])
96 | OCMV_up_lst.append(ocmv_up[:, hr_sh:hr_eh, hr_sw:hr_ew])
97 | Mask_up_lst.append(mask_up[:, hr_sh:hr_eh, hr_sw:hr_ew])
98 |
 99 |             # warp each previous frame to the current frame using the accumulated motion vectors
100 | for i in range(self.number_previous_frames, 0, -1):
101 | for j in range(0, i-1):
102 | mvs[i-1] += mvs[j]
103 | ocmvs[i-1] += ocmvs[j]
104 | lr_files[i] = data_utils.backward_warp_motion(lr_files[i][None, :].cuda(), mvs[i-1][None, :].cuda(),
105 | lr_files[0][None, :].cuda())[0].cpu()
106 |
107 |             # concatenate an occlusion-mask channel onto each warped previous frame
108 | for i in range(0, self.number_previous_frames):
109 | mask = data_utils.mv_mask(ocmvs[i], mvs[i])[None, :]
110 | lr_files[i+1] = torch.cat((lr_files[i+1], mask), dim=0)
111 |
112 | lr = torch.cat(lr_files, dim=0)
113 | LR_lst.append(lr)
114 |
115 | # for i in range(1, self.num_frames_samples-1):
116 | # tmp_mv = MV_up_lst[i]
117 | # tmp_ocmv = OCMV_up_lst[i]
118 | # for j in range(0, i):
119 | # tmp_mv += MV_up_lst[j]
120 | # tmp_ocmv += OCMV_up_lst[i]
121 | # tmp_mask = 1 - data_utils.mv_mask(tmp_ocmv, tmp_mv)[None, :]
122 | # MV_up_lst.append(tmp_mv)
123 | # Mask_up_lst.append(tmp_mask)
124 |
125 | # return LR_lst[0], HR_lst[0], str(file_index)
126 | return LR_lst, HR_lst, MV_up_lst, Mask_up_lst, str(file_index)
127 |
128 | def __len__(self):
129 | return self.total_folder_num * self.grain
--------------------------------------------------------------------------------
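In `__getitem__`, the random crop keeps the low- and high-resolution patches aligned by drawing the crop origin on the LR grid and multiplying it by the SR scale for the HR window. A minimal sketch of that alignment, using illustrative sizes (270x480 LR, x4 scale) rather than the dataset's real ones:

import random

def aligned_crop_windows(lr_h, lr_w, crop_size, scale):
    # LR patch is crop_size x crop_size; the HR patch is scale times larger.
    ip, tp = crop_size, scale * crop_size
    ix = random.randrange(0, lr_w - ip + 1)
    iy = random.randrange(0, lr_h - ip + 1)
    # Scaling the LR origin keeps both windows over the same scene region:
    # lr[iy:iy+ip, ix:ix+ip] corresponds to hr[ty:ty+tp, tx:tx+tp].
    tx, ty = scale * ix, scale * iy
    return (iy, ix, iy + ip, ix + ip), (ty, tx, ty + tp, tx + tp)

lr_win, hr_win = aligned_crop_windows(270, 480, crop_size=96, scale=4)
print(lr_win, hr_win)
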
/src/loss/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from importlib import import_module
3 |
4 | import matplotlib
5 |
6 | matplotlib.use('Agg')
7 | import matplotlib.pyplot as plt
8 |
9 | import numpy as np
10 |
11 | import torch
12 | import torch.nn as nn
13 |
14 |
15 | class Loss(nn.modules.loss._Loss):
16 | def __init__(self, args, ckp):
17 | super(Loss, self).__init__()
18 | print('Preparing loss function:')
19 |
20 | self.n_GPUs = args.n_GPUs
21 | self.loss = []
22 | self.loss_module = nn.ModuleList()
23 | for loss in args.loss.split('+'):
24 | weight, loss_type = loss.split('*')
25 | if loss_type == 'MSE':
26 | loss_function = nn.MSELoss()
27 | elif loss_type == 'L1':
28 | loss_function = nn.SmoothL1Loss()
29 | elif loss_type.find('VGG') >= 0:
30 | module = import_module('loss.vgg')
31 | loss_function = getattr(module, 'VGG')(
32 | loss_type[3:]
33 | )
34 | elif loss_type.find('GAN') >= 0:
35 | module = import_module('loss.adversarial')
36 | loss_function = getattr(module, 'Adversarial')(
37 | args,
38 | loss_type
39 | )
40 | elif loss_type.find('Temporal') >= 0:
41 | module = import_module('loss.temporal_loss')
42 | loss_function = getattr(module, 'Temporal_Loss')()
43 | elif loss_type.find('SSIM') >= 0:
44 | module = import_module('loss.ssim_loss')
45 | loss_function = getattr(module, 'SSIM')()
46 |
47 | self.loss.append({
48 | 'type': loss_type,
49 | 'weight': float(weight),
50 | 'function': loss_function})
51 | if loss_type.find('GAN') >= 0:
52 | self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})
53 |
54 | if len(self.loss) > 1:
55 | self.loss.append({'type': 'Total', 'weight': 0, 'function': None})
56 |
57 | for l in self.loss:
58 | if l['function'] is not None:
59 | print('{:.3f} * {}'.format(l['weight'], l['type']))
60 | self.loss_module.append(l['function'])
61 |
62 | self.log = torch.Tensor()
63 |
64 | device = torch.device('cpu' if args.cpu else 'cuda')
65 | self.loss_module.to(device)
66 | if args.precision == 'half': self.loss_module.half()
67 | if not args.cpu and args.n_GPUs > 1:
68 | self.loss_module = nn.DataParallel(
69 | self.loss_module, range(args.n_GPUs)
70 | )
71 |
72 | if args.load != '': self.load(ckp.dir, cpu=args.cpu)
73 |
 74 |     # called as: loss(sr_cur, hr, sr_warped, mask_up, needTem)
75 | def forward(self, sr_cur, hr, sr_warped=None, mask_up=None, needTem=True):
76 | losses = []
77 | for i, l in enumerate(self.loss):
 78 |             if l['type'] == 'Temporal':
 79 |                 if not needTem: continue  # no temporal term for the first frame
 80 |                 sr_warped_masked, merge_masked = sr_warped * mask_up, sr_cur * mask_up
 81 |                 effective_loss = l['weight'] * l['function'](sr_warped_masked, merge_masked)
 82 |                 losses.append(effective_loss)
 83 |                 self.log[-1, i] += effective_loss.item()
84 | elif l['function'] is not None:
85 | loss = l['function'](sr_cur, hr)
86 | effective_loss = l['weight'] * loss
87 | losses.append(effective_loss)
88 | self.log[-1, i] += effective_loss.item()
89 | elif l['type'] == 'DIS':
90 | self.log[-1, i] += self.loss[i - 1]['function'].loss
91 |
92 | loss_sum = sum(losses)
93 | if len(self.loss) > 1:
94 | self.log[-1, -1] += loss_sum.item()
95 |
96 | return loss_sum
97 |
98 | def step(self):
99 | for l in self.get_loss_module():
100 | if hasattr(l, 'scheduler'):
101 | l.scheduler.step()
102 |
103 | def start_log(self):
104 | self.log = torch.cat((self.log, torch.zeros(1, len(self.loss))))
105 |
106 | def end_log(self, n_batches):
107 | self.log[-1].div_(n_batches)
108 |
109 | def display_loss(self, batch):
110 | n_samples = batch + 1
111 | log = []
112 | for l, c in zip(self.loss, self.log[-1]):
113 | log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples))
114 |
115 | return ''.join(log)
116 |
117 | def plot_loss(self, apath, epoch):
118 | axis = np.linspace(1, epoch, epoch)
119 | for i, l in enumerate(self.loss):
120 | label = '{} Loss'.format(l['type'])
121 | fig = plt.figure()
122 | plt.title(label)
123 | plt.plot(axis, self.log[:, i].numpy(), label=label)
124 | plt.legend()
125 | plt.xlabel('Epochs')
126 | plt.ylabel('Loss')
127 | plt.grid(True)
128 | plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type'])))
129 | plt.close(fig)
130 |
131 | def get_loss_module(self):
132 | if self.n_GPUs == 1:
133 | return self.loss_module
134 | else:
135 | return self.loss_module.module
136 |
137 | def save(self, apath):
138 | torch.save(self.state_dict(), os.path.join(apath, 'loss.pt'))
139 | torch.save(self.log, os.path.join(apath, 'loss_log.pt'))
140 |
141 | def load(self, apath, cpu=False):
142 | if cpu:
143 | kwargs = {'map_location': lambda storage, loc: storage}
144 | else:
145 | kwargs = {}
146 |
147 | self.load_state_dict(torch.load(
148 | os.path.join(apath, 'loss.pt'),
149 | **kwargs
150 | ))
151 | self.log = torch.load(os.path.join(apath, 'loss_log.pt'))
152 | for l in self.get_loss_module():
153 | if hasattr(l, 'scheduler'):
154 | for _ in range(len(self.log)): l.scheduler.step()
155 |
--------------------------------------------------------------------------------
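The constructor builds the composite objective from the `--loss` string: terms are separated by `+` and each term is `weight*type`. A minimal sketch of that decomposition, applied to the default spec from option.py:

# How Loss.__init__ decomposes the default '--loss' specification.
spec = '1.0*L1+1.0*SSIM+1.0*Temporal'
terms = []
for term in spec.split('+'):
    weight, loss_type = term.split('*')
    terms.append((float(weight), loss_type))
print(terms)  # [(1.0, 'L1'), (1.0, 'SSIM'), (1.0, 'Temporal')]
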
/src/option.py:
--------------------------------------------------------------------------------
1 | import configargparse
2 |
3 | parser = configargparse.ArgumentParser(description='NSRD')
4 |
5 | # Config file
6 | parser.add_argument('--config', is_config_file=True, default='../configs/Bistro_train.txt',
7 | help='config file path')
8 |
9 | # Hardware specifications
10 | parser.add_argument('--n_threads', type=int, default=6,
11 | help='number of threads for data loading')
12 | parser.add_argument('--cpu', action='store_true',
13 | help='use cpu only')
14 | parser.add_argument('--n_GPUs', type=int, default=1,
15 | help='number of GPUs')
16 | parser.add_argument('--seed', type=int, default=1,
17 | help='random seed')
18 |
19 | # Data specifications
20 | parser.add_argument('--root_dir', type=str, default=r"../dataset/Bistro/train",
21 | help='rendering data_dir')
22 | parser.add_argument('--data_name', type=str, default='Bistro',
23 | help='test dataset name')
24 | parser.add_argument('--view_dirname', type=str, default='View',
25 | help='view dirname')
26 | parser.add_argument('--irradiance_dirname', type=str, default='Irradiance',
27 | help='Irradiance dirname')
28 | parser.add_argument('--depth_dirname', type=str, default='Depth',
29 | help='depth dirname')
30 | parser.add_argument('--normal_dirname', type=str, default='Normal',
31 | help='normal dirname')
32 | parser.add_argument('--mv_dirname', type=str, default='MV',
33 | help='LR mv dirname')
34 | parser.add_argument('--ocmv_dirname', type=str, default='OCMV',
35 | help='LR ocmv dirname')
36 | parser.add_argument('--brdf_dirname', type=str, default='BRDF',
 37 |                     help='LR brdf dirname')
38 |
39 | parser.add_argument('--use_normal', type=bool, default=True,
40 | help='use normal')
41 | parser.add_argument('--use_depth', type=bool, default=True,
42 | help='use depth')
43 | parser.add_argument('--num_frames_samples', type=int, default=4,
 44 |                     help='number of consecutive frames sampled per training sequence')
45 | parser.add_argument('--num_pre_frames', type=int, default=2,
46 | help='number of previous frames')
47 | parser.add_argument('--input_total_dims', type=int, default=3+1+3,
 48 |                     help='total input channels (3 irradiance + 1 depth + 3 normal)')
49 | parser.add_argument('--output_dims', type=int, default=3,
50 | help='output dims')
51 | parser.add_argument('--sr_content', type=str, default='Irradiance',
 52 |                     help='choose which content to super-resolve '
 53 |                          '(Irradiance, View, LMDB)')
54 |
55 | parser.add_argument('--gt_size', type=tuple, default=(1080, 1920),
56 | help='GT size(ih, iw)')
57 | parser.add_argument('--scale', type=int, default=4,
58 | help='super resolution scale')
59 | parser.add_argument('--crop_size', type=int, default=96,
60 | help='output patch size')
61 | parser.add_argument('--chop', action='store_true',
62 | help='enable memory-efficient forward')
63 | parser.add_argument('--grain', type=int, default=100,
64 | help='file number in a folder')
65 | parser.add_argument('--total_folder_num', type=int, default=60,
66 | help='total folder num')
67 | parser.add_argument('--test_folder', type=int, action="append",
68 | help='test folder index')
69 | parser.add_argument('--valid_folder_num', type=int, default=6,
70 | help='valid folder number')
71 |
72 | # Model specifications
73 | parser.add_argument('--model', default='sr_model',
74 | help='model name')
75 | parser.add_argument('--extend', type=str, default='.',
76 | help='pre-trained model directory')
77 | parser.add_argument('--n_feats', type=int, default=32,
78 | help='number of feature maps')
79 | parser.add_argument('--precision', type=str, default='single',
80 | choices=('single', 'half'),
81 | help='FP precision for test (single | half)')
82 |
83 | # Training specifications
84 | parser.add_argument('--reset', action='store_true',
85 | help='reset the training')
86 | parser.add_argument('--test_every', type=int, default=1000,
87 | help='do test per every N batches')
88 | parser.add_argument('--epochs', type=int, default=200,
89 | help='number of epochs to train')
90 | parser.add_argument('--batch_size', type=int, default=8,
91 | help='input batch size for training')
92 | parser.add_argument('--test_only', action='store_true', default=False,
93 | help='set this option to test the model')
94 |
95 | # Optimization specifications
96 | parser.add_argument('--lr', type=float, default=0.0005,
97 | help='learning rate')
98 | parser.add_argument('--decay', type=str, default='100',
99 | help='learning rate decay type')
100 | parser.add_argument('--gamma', type=float, default=0.5,
101 | help='learning rate decay factor for step decay')
102 | parser.add_argument('--optimizer', default='ADAM',
103 | choices=('SGD', 'ADAM', 'RMSprop'),
104 | help='optimizer to use (SGD | ADAM | RMSprop)')
105 | parser.add_argument('--momentum', type=float, default=0.9,
106 | help='SGD momentum')
107 | parser.add_argument('--betas', type=tuple, default=(0.9, 0.999),
108 | help='ADAM beta')
109 | parser.add_argument('--epsilon', type=float, default=1e-8,
110 | help='ADAM epsilon for numerical stability')
111 | parser.add_argument('--weight_decay', type=float, default=0,
112 | help='weight decay')
113 | parser.add_argument('--gclip', type=float, default=0,
114 | help='gradient clipping threshold (0 = no clipping)')
115 |
116 | # Loss specifications
117 | parser.add_argument('--loss', type=str, default='1.0*L1+1.0*SSIM+1.0*Temporal',
118 | help='loss function configuration')
119 | parser.add_argument('--skip_threshold', type=float, default=1e8,
120 |                     help='skip batches whose error exceeds this threshold')
121 |
122 | # Log specifications
123 | parser.add_argument('--save', type=str, default='Bistro_X4',
124 | help='file name to save')
125 | parser.add_argument('--load', type=str, default='',
126 | help='file name to load')
127 | parser.add_argument('--resume', type=int, default=1,
128 | help='resume from specific checkpoint: -1 -> best, 0 -> latest, else -> None')
129 | parser.add_argument('--save_models', action='store_true',
130 | help='save all intermediate models')
131 | parser.add_argument('--print_every', type=int, default=100,
132 | help='how many batches to wait before logging training status')
133 | parser.add_argument('--save_results', action='store_true', default=False,
134 | help='save output results')
135 | parser.add_argument('--save_gt', action='store_true',
136 | help='save low-resolution and high-resolution images together')
137 |
138 | args = parser.parse_args()
139 |
140 | for arg in vars(args):
141 | if vars(args)[arg] == 'True':
142 | vars(args)[arg] = True
143 | elif vars(args)[arg] == 'False':
144 | vars(args)[arg] = False
145 |
146 |
147 |
--------------------------------------------------------------------------------
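Note that the `type=bool` arguments (`--use_normal`, `--use_depth`) rely on the clean-up loop at the bottom of the file, because `bool('False')` evaluates to `True` when the value arrives as a string from a config file. A common alternative, shown here as a sketch rather than what the repository does, is an explicit converter:

import argparse

def str2bool(v):
    # bool('False') is True, so parse the text explicitly.
    if isinstance(v, bool):
        return v
    if v.lower() in ('true', 'yes', '1'):
        return True
    if v.lower() in ('false', 'no', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got %r' % v)

# hypothetical usage:
# parser.add_argument('--use_normal', type=str2bool, default=True)
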
/src/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
3 | import math
4 | import time
5 | from decimal import Decimal
6 | import numpy as np
7 | import imageio
8 | import utility
9 |
10 | import torch
11 | import torch.nn.utils as utils
12 | from tqdm import tqdm
13 | from data import data_utils
14 |
15 | class Trainer():
16 | def __init__(self, args, loader, my_model, my_loss, ckp):
17 | super(Trainer, self).__init__()
18 | self.args = args
19 | self.scale = args.scale
20 | self.gt_size = args.gt_size
21 | self.batch_size = args.batch_size
22 | self.ckp = ckp
23 | self.model = my_model
24 | self.num_frames_samples = args.num_frames_samples
25 | self.train_loader = loader.loader_train
26 | self.valid_loader = loader.loader_valid
27 | self.loss = my_loss
28 | self.optimizer = utility.make_optimizer(args, self.model)
29 |
30 | if self.args.load != '':
31 | self.optimizer.load(ckp.dir, epoch=len(ckp.log))
32 |
33 | self.error_last = 1e8
34 |
35 | def train(self):
36 | self.loss.step()
37 | epoch = self.optimizer.get_last_epoch() + 1
38 | lr = self.optimizer.get_lr()
39 |
40 | self.ckp.write_log(
41 | '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))
42 | )
43 | self.loss.start_log()
44 | self.model.train()
45 |
46 | timer_data, timer_model = utility.timer(), utility.timer()
 47 |
48 | for batch, (LR_lst, HR_lst, MV_up_lst, Mask_up_lst, _) in enumerate(self.train_loader):
49 | self.optimizer.zero_grad()
50 |
51 | b, c, h, w = HR_lst[0].size()
52 | zero_tensor = torch.zeros(b, c, h, w, dtype=torch.float32)
53 | lr0, zero_tensor, hr0 = self.prepare(LR_lst[0], zero_tensor, HR_lst[0])
54 |
55 | sr_pre, lstm_state = self.model((lr0, zero_tensor, None))
56 | lstm_state = utility.repackage_hidden(lstm_state)
57 | loss = self.loss(sr_pre, hr0, needTem=False)
58 |
59 | for i in range(1, self.num_frames_samples):
60 | sr_pre = sr_pre.detach()
61 | sr_pre.requires_grad = False
62 |
63 | lr, hr, mv_up, mask_up = self.prepare(LR_lst[i], HR_lst[i], MV_up_lst[i], Mask_up_lst[i])
64 |
65 | timer_data.hold()
66 | timer_model.tic()
67 |
68 | sr_pre_warped = data_utils.warp(sr_pre, mv_up)
69 | sr_cur, lstm_state = self.model((lr, sr_pre_warped, lstm_state))
70 | lstm_state = utility.repackage_hidden(lstm_state)
71 |
72 | loss += self.loss(sr_cur, hr, sr_pre_warped, mask_up, needTem=True)
73 | sr_pre = sr_cur
74 | loss.backward()
75 | self.optimizer.step()
76 | timer_model.hold()
77 | if (batch + 1) % self.args.print_every == 0:
78 | self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format(
79 | (batch + 1) * self.args.batch_size,
80 | self.train_loader.n_samples,
81 | self.loss.display_loss(batch),
82 | timer_model.release(),
83 | timer_data.release()))
84 |
85 | timer_data.tic()
86 |
87 | self.loss.end_log(len(self.train_loader))
88 | self.error_last = self.loss.log[-1, -1]
89 | self.optimizer.schedule()
90 |
91 | def test(self):
92 | torch.set_grad_enabled(False)
93 |
94 | epoch = self.optimizer.get_last_epoch()
95 | self.ckp.write_log('\nEvaluation:')
96 | self.ckp.add_log(
97 | torch.zeros(1, 3)
98 | )
99 | self.model.eval()
100 |
101 | timer_test = utility.timer()
102 | run_model_time = 0
103 | flag = 0
104 | if self.args.save_results: self.ckp.begin_background()
105 |
106 | pre_sr = torch.zeros(1, 3, self.gt_size[0], self.gt_size[1],
107 | dtype=torch.float32).cuda()
108 | lstm_state = None
109 |         os.makedirs(".\\check", exist_ok=True)
110 | for index, (LR_lst, HR_lst, MV_up_lst, Mask_up_lst, filename) in tqdm(enumerate(self.valid_loader)):
111 | lr, hr, mv_up, mask_up = self.prepare(LR_lst[0], HR_lst[0], MV_up_lst[0], Mask_up_lst[0])
112 | if index == 0:
113 | pre_sr, lstm_state = self.model((lr, pre_sr, lstm_state))
114 | lstm_state = utility.repackage_hidden(lstm_state)
115 | continue
116 | t1 = time.time()
117 | sr_pre_warped = data_utils.warp(pre_sr, mv_up)
118 | cur_sr, lstm_state = self.model((lr, sr_pre_warped, lstm_state))
119 | lstm_state = utility.repackage_hidden(lstm_state)
120 | t2 = time.time()
121 | run_model_time += (t2 - t1)
122 | if self.args.sr_content == "View":
123 | sr = utility.quantize_img(cur_sr)
124 | sr_last = utility.quantize_img(pre_sr)
125 | if flag < 2:
126 | data_utils.save2Exr(np.array(sr[0, :3, :, :].permute(1, 2, 0).detach().cpu()) * 255,
127 | ".\\check\\sr_" + str(flag) + ".png")
128 | data_utils.save2Exr(np.array(hr[0, :3, :, :].permute(1, 2, 0).detach().cpu()) * 255,
129 | ".\\check\\gt_" + str(flag) + ".png")
130 | flag += 1
131 | else:
132 | sr = utility.quantize(cur_sr)
133 | sr_last = utility.quantize(pre_sr)
134 | if flag < 2:
135 | data_utils.save2Exr(np.array(sr[0, :3, :, :].permute(1, 2, 0).detach().cpu()),
136 | ".\\check\\sr_" + str(flag) + ".exr")
137 | data_utils.save2Exr(np.array(hr[0, :3, :, :].permute(1, 2, 0).detach().cpu()),
138 | ".\\check\\gt_" + str(flag) + ".exr")
139 | flag += 1
140 |
141 | pre_sr = cur_sr
142 | save_list = [sr]
143 |             assert not torch.isnan(sr).any(), "sr is nan!"
144 |             val_ssim = 1.0 - utility.calc_ssim(sr, hr).cpu()
145 |             warped_sr = data_utils.warp(sr_last, mv_up)
146 |             val_temporal = utility.calc_temporal(warped_sr, sr, mask_up).cpu()
147 |
148 |             self.ckp.log[-1, 0] += val_ssim
149 |             self.ckp.log[-1, 1] += val_temporal
150 |             self.ckp.log[-1, 2] += val_temporal + val_ssim
151 |
152 | if self.args.save_gt:
153 | save_list.extend([lr, hr])
154 |
155 | if self.args.save_results:
156 | self.ckp.save_results(self.valid_loader, filename[0], save_list, self.scale)
157 |
158 | self.ckp.log[-1] /= (len(self.valid_loader) - 1)
159 | best = self.ckp.log.min(0)
160 | self.ckp.write_log(
161 |             '[{} x{}]\t1-SSIM: {:.6f}, Temporal: {:.6f}, Total: {:.6f} (Best: {:.6f} @epoch {})'.format(
162 | self.valid_loader.dataset.name,
163 | self.scale,
164 | self.ckp.log[-1][0],
165 | self.ckp.log[-1][1],
166 | self.ckp.log[-1][2],
167 | best[0][2],
168 | best[1][2] + 1
169 | )
170 | )
171 |
172 | self.ckp.write_log('Run model time {:.5f}s\n'.format(run_model_time))
173 | self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc()))
174 | self.ckp.write_log('Saving...')
175 |
176 | if self.args.save_results:
177 | self.ckp.end_background()
178 |
179 | if not self.args.test_only:
180 |             self.ckp.save(self, epoch, is_best=(not torch.isnan(best[0][2]) and best[1][2] + 1 == epoch))
181 |
182 | self.ckp.write_log(
183 | 'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
184 | )
185 |
186 | torch.set_grad_enabled(True)
187 |
188 | def prepare(self, *args):
189 | device = torch.device('cpu' if self.args.cpu else 'cuda')
190 | def _prepare(tensor):
191 | if self.args.precision == 'half': tensor = tensor.half()
192 | return tensor.to(device)
193 |
194 | return [_prepare(a) for a in args]
195 |
196 | def terminate(self):
197 | if self.args.test_only:
198 | self.test()
199 | return True
200 | else:
201 | epoch = self.optimizer.get_last_epoch() + 1
202 | return epoch >= self.args.epochs
--------------------------------------------------------------------------------
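Both `train` and `test` run the same per-frame recurrence: warp the previous SR output into the current frame with the upsampled motion vectors, feed it to the model alongside the current LR input, and detach the ConvLSTM state so backpropagation is truncated at frame boundaries. Stripped of logging and loss bookkeeping, one step looks roughly like this (a sketch; `model`, `warp`, and the tensors are placeholders for the real objects):

def recurrent_step(model, warp, lr, mv_up, sr_prev, lstm_state):
    # Reproject the last super-resolved frame into the current view.
    sr_prev_warped = warp(sr_prev, mv_up)
    sr_cur, lstm_state = model((lr, sr_prev_warped, lstm_state))
    # Truncate backprop at the frame boundary; utility.repackage_hidden
    # does the same but also handles arbitrarily nested tuples.
    lstm_state = tuple(h.detach() for h in lstm_state)
    return sr_cur, lstm_state
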
/src/utility.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import time
4 | import datetime
5 | import threading
6 | from queue import Queue
7 | import matplotlib
 8 | import shutil
9 | matplotlib.use('Agg')
10 | import matplotlib.pyplot as plt
11 | from model.ssim import ssim
12 | import numpy as np
13 |
14 | import torch
15 | import torch.optim as optim
16 | import torch.optim.lr_scheduler as lrs
17 | import torch.nn.functional as F
18 | from data.data_utils import save2Exr
19 |
20 |
21 | class timer():
22 | def __init__(self):
23 | self.acc = 0
24 | self.tic()
25 |
26 | def tic(self):
27 | self.t0 = time.time()
28 |
29 | def toc(self, restart=False):
30 | diff = time.time() - self.t0
31 | if restart: self.t0 = time.time()
32 | return diff
33 |
34 | def hold(self):
35 | self.acc += self.toc()
36 |
37 | def release(self):
38 | ret = self.acc
39 | self.acc = 0
40 |
41 | return ret
42 |
43 | def reset(self):
44 | self.acc = 0
45 |
46 |
47 | class checkpoint():
48 | def __init__(self, args):
49 | self.args = args
50 | self.ok = True
51 | self.log = torch.Tensor()
 52 |         now = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')  # ':' is invalid in Windows paths
53 |
54 | if not args.load:
55 | if not args.save:
56 | args.save = now
57 | self.dir = os.path.join('..', 'experiment', args.save)
58 | else:
59 | self.dir = os.path.join('..', 'experiment', args.load)
60 | if os.path.exists(self.dir):
61 | self.log = torch.load(self.get_path('ssim_log.pt'))
62 | print('Continue from epoch {}...'.format(len(self.log)))
63 | else:
64 | args.load = ''
65 |
66 | if args.reset:
 67 |             shutil.rmtree(self.dir, ignore_errors=True)  # portable replacement for 'rm -rf'
68 | args.load = ''
69 |
70 | os.makedirs(self.dir, exist_ok=True)
71 | os.makedirs(self.get_path('model'), exist_ok=True)
72 | if args.test_only:
73 | assert len(args.test_folder) == 1, "Currently only supports testing one folder at a time!"
74 | self.save_dir = 'sr_results_x{}/{}'.format(args.scale, args.test_folder[0])
75 | os.makedirs(self.get_path(self.save_dir), exist_ok=True)
76 |
77 | open_type = 'a' if os.path.exists(self.get_path('log.txt')) else 'w'
78 | self.log_file = open(self.get_path('log.txt'), open_type)
79 | with open(self.get_path('config.txt'), open_type) as f:
80 | f.write(now + '\n\n')
81 | for arg in vars(args):
82 | f.write('{}: {}\n'.format(arg, getattr(args, arg)))
83 | f.write('\n')
84 |
85 | self.n_processes = 4
86 |
87 | def get_path(self, *subdir):
88 | return os.path.join(self.dir, *subdir)
89 |
90 | def save(self, trainer, epoch, is_best=False):
91 | trainer.model.save(self.get_path('model'), epoch, is_best=is_best)
92 | trainer.loss.save(self.dir)
93 | trainer.loss.plot_loss(self.dir, epoch)
94 |
95 | self.plot_ssim(epoch)
96 | trainer.optimizer.save(self.dir)
97 | torch.save(self.log, self.get_path('ssim_log.pt'))
98 |
99 | def add_log(self, log):
100 | self.log = torch.cat([self.log, log])
101 |
102 | def write_log(self, log, refresh=False):
103 | print(log)
104 | self.log_file.write(log + '\n')
105 | if refresh:
106 | self.log_file.close()
107 | self.log_file = open(self.get_path('log.txt'), 'a')
108 |
109 | def done(self):
110 | self.log_file.close()
111 |
112 | def plot_psnr(self, epoch):
113 | axis = np.linspace(1, epoch, epoch)
114 | label = 'SR on {}'.format(self.args.data_name)
115 | fig = plt.figure()
116 | plt.title(label)
117 | plt.plot(
118 | axis,
119 | self.log[:].numpy(),
120 | label='Scale {}'.format(self.args.scale)
121 | )
122 | plt.legend()
123 | plt.xlabel('Epochs')
124 | plt.ylabel('PSNR')
125 | plt.grid(True)
126 | plt.savefig(self.get_path('test_{}.pdf'.format(self.args.data_name)))
127 | plt.close(fig)
128 |
129 | def plot_ssim(self, epoch):
130 | axis = np.linspace(1, epoch, epoch)
131 | label = 'SR on {}'.format(self.args.data_name)
132 | fig = plt.figure()
133 | plt.title(label)
134 | plt.plot(
135 | axis,
136 | (self.log[:]).numpy(),
137 | label='Scale {}'.format(self.args.scale)
138 | )
139 | plt.legend()
140 | plt.xlabel('Epochs')
141 | plt.ylabel('SSIM')
142 | plt.grid(True)
143 | plt.savefig(self.get_path('test_{}.pdf'.format(self.args.data_name)))
144 | plt.close(fig)
145 |
146 | def begin_background(self):
147 | self.queue = Queue()
148 |
149 | def bg_target(queue):
150 |             while True:
151 |                 # block until an item arrives instead of busy-waiting
152 |                 filename, tensor = queue.get()
153 |                 if filename is None: break
154 |                 im_np = tensor.numpy()
155 |                 save2Exr(im_np, filename)
156 |
157 |
158 |
159 | self.process = [
160 | threading.Thread(target=bg_target, args=(self.queue,)) \
161 | for _ in range(self.n_processes)
162 | ]
163 |
164 | for p in self.process: p.start()
165 |
166 | def end_background(self):
167 | for _ in range(self.n_processes): self.queue.put((None, None))
168 | while not self.queue.empty(): time.sleep(1)
169 | for p in self.process: p.join()
170 |
171 | def save_results(self, dataset, filename, save_list, scale):
172 | if self.args.save_results:
173 | filename = self.get_path(
174 | self.save_dir,
175 | '{}'.format(filename)
176 | )
177 |
178 | if self.args.sr_content == 'View':
179 | for v in save_list:
180 | tensor_cpu = v[0].permute(1, 2, 0).cpu() * 255.0
181 | self.queue.put(('{}.png'.format(filename), tensor_cpu))
182 | else:
183 | for v in save_list:
184 | tensor_cpu = v[0].permute(1, 2, 0).cpu()
185 | self.queue.put(('{}.exr'.format(filename), tensor_cpu))
186 |
187 |
188 | def quantize(tensor):
189 | return torch.clamp(tensor, min=0.0)
190 |
191 |
192 | def quantize_img(tensor):
193 | return torch.clamp(tensor, min=0.0, max=1.0)
194 |
195 |
196 | def hdr2ldr(tensor):
197 | def adjust(data):
198 | contrast = 0.0
199 | brightness = 0.0
200 | contrastFactor = (259.0 * (contrast * 256.0 + 255.0)) / (255.0 * (259.0 - 256.0 * contrast))
201 | data = (data - 0.5) * contrastFactor + 0.5 + brightness
202 | return data
203 |
204 | chromatic = 0.5 * tensor
205 | adj = adjust(chromatic)
206 | tonemapping = adj / (adj + 1.0)
207 | gamma = torch.pow(tonemapping, 1.0 / 2.2)
208 | return gamma
209 |
210 |
211 | def calc_psnr(sr, hr, scale, rgb_range):
212 | if hr.nelement() == 1: return 0
213 |
214 | diff = (sr - hr) / rgb_range
215 | shave = scale + 6
216 |
217 | valid = diff[..., shave:-shave, shave:-shave]
218 | mse = valid.pow(2).mean()
219 |
220 | return -10 * math.log10(mse)
221 |
222 |
223 | def calc_mse(sr, hr):
224 | return F.mse_loss(sr, hr)
225 |
226 |
227 | def calc_ssim(sr, hr):
228 | return ssim(sr, hr)
229 |
230 |
231 | def calc_temporal(warped_sr, merge_sr, noc_mask):
232 | return F.l1_loss(warped_sr * noc_mask, merge_sr * noc_mask)
233 |
234 |
235 | def repackage_hidden(h):
236 |     """Detach hidden states from their gradient history (handles nested tuples)."""
237 | if isinstance(h, torch.Tensor):
238 | return h.detach()
239 | else:
240 | return tuple(repackage_hidden(v) for v in h)
241 |
242 |
243 | def make_optimizer(args, target):
244 | '''
245 | make optimizer and scheduler together
246 | '''
247 | # optimizer
248 | trainable = filter(lambda x: x.requires_grad, target.parameters())
249 | kwargs_optimizer = {'lr': args.lr, 'weight_decay': args.weight_decay}
250 |
251 | if args.optimizer == 'SGD':
252 | optimizer_class = optim.SGD
253 | kwargs_optimizer['momentum'] = args.momentum
254 | elif args.optimizer == 'ADAM':
255 | optimizer_class = optim.Adam
256 | kwargs_optimizer['betas'] = args.betas
257 | kwargs_optimizer['eps'] = args.epsilon
258 | elif args.optimizer == 'RMSprop':
259 | optimizer_class = optim.RMSprop
260 | kwargs_optimizer['eps'] = args.epsilon
261 |
262 | # scheduler
263 | milestones = list(map(lambda x: int(x), args.decay.split('-')))
264 | kwargs_scheduler = {'milestones': milestones, 'gamma': args.gamma}
265 | scheduler_class = lrs.MultiStepLR
266 |
267 | class CustomOptimizer(optimizer_class):
268 | def __init__(self, *args, **kwargs):
269 | super(CustomOptimizer, self).__init__(*args, **kwargs)
270 |
271 | def _register_scheduler(self, scheduler_class, **kwargs):
272 | self.scheduler = scheduler_class(self, **kwargs)
273 |
274 | def save(self, save_dir):
275 | torch.save(self.state_dict(), self.get_dir(save_dir))
276 |
277 | def load(self, load_dir, epoch=1):
278 | self.load_state_dict(torch.load(self.get_dir(load_dir)))
279 | if epoch > 1:
280 | for _ in range(epoch): self.scheduler.step()
281 |
282 | def get_dir(self, dir_path):
283 | return os.path.join(dir_path, 'optimizer.pt')
284 |
285 | def schedule(self):
286 | self.scheduler.step()
287 |
288 | def get_lr(self):
289 | return self.scheduler.get_last_lr()[0]
290 |
291 | def get_last_epoch(self):
292 | return self.scheduler.last_epoch
293 |
294 | optimizer = CustomOptimizer(trainable, **kwargs_optimizer)
295 | optimizer._register_scheduler(scheduler_class, **kwargs_scheduler)
296 | return optimizer
297 |
--------------------------------------------------------------------------------
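`make_optimizer` wires the scheduler into the optimizer object itself, so the trainer drives both through one handle (`step()` for parameters, `schedule()` for the learning rate). A minimal usage sketch, with the `args` fields mirroring the option.py defaults:

import types
import torch.nn as nn
from utility import make_optimizer  # assumes src/ is on sys.path

args = types.SimpleNamespace(
    lr=0.0005, weight_decay=0, optimizer='ADAM',
    betas=(0.9, 0.999), epsilon=1e-8, decay='100', gamma=0.5)

model = nn.Conv2d(3, 3, 3)
optimizer = make_optimizer(args, model)
optimizer.step()      # update parameters as usual
optimizer.schedule()  # advance the attached MultiStepLR scheduler
print(optimizer.get_lr(), optimizer.get_last_epoch())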