├── .DS_Store ├── .gitignore ├── LICENSE ├── README.md ├── calc_metrics.py ├── dataset ├── HDR.py └── HDRpatches.py ├── models └── NHDRRNet.py ├── test.py ├── test.sh ├── train.py ├── train.sh └── utils ├── HDRutils.py ├── configs.py ├── dataprocessor.py ├── dataset.py ├── loss.py ├── metrics.py └── solvers.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Galaxies99/NHDRRNet-pytorch/b20aae987e586a6cf9c9c52fc07b0884ce6fdf37/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # dataset 2 | /data 3 | /patches 4 | /samples 5 | 6 | # model storage 7 | /checkpoint 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | pip-wheel-metadata/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 103 | __pypackages__/ 104 | 105 | # Celery stuff 106 | celerybeat-schedule 107 | celerybeat.pid 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | .dmypy.json 134 | dmypy.json 135 | 136 | # Pyre type checker 137 | .pyre/ 138 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Tony Fang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NHDRRNet-pytorch 2 | 📷 NHDRRNet (TIP'20) implementation using PyTorch framework 3 | 4 | ## Introduction 5 | 6 | This repository is a PyTorch implementation of NHDRRNet [1]. The authors did not release their code, so we created this repository to reimplement NHDRRNet in PyTorch. 7 | 8 | ## Requirements 9 | 10 | + PyTorch 1.4+ 11 | + CUDA 10.1+ 12 | + OpenCV 13 | + numpy, tqdm, scipy, etc. 14 | 15 | ## Getting Started 16 | 17 | ### Download Dataset 18 | 19 | The Kalantari Dataset can be downloaded from https://www.robots.ox.ac.uk/~szwu/storage/hdr/kalantari_dataset.zip [2]. 20 | 21 | ### Dataset Model Selection 22 | 23 | Two dataset implementations are provided in the `dataset` folder. `HDRpatches.py` pre-generates patches in the `patches` folder, which takes roughly 200 GB of disk space but runs faster. `HDR.py` (default) opens each image file only when it is needed, which saves disk space. Feel free to choose the method you prefer. 24 | 25 | ### Configs Modifications 26 | 27 | + You may modify the arguments in `Configs()` to match your own environment; for detailed argument descriptions, see `utils/configs.py`. A short example is sketched below. 28 | + You may modify the arguments of NHDRRNet to train a better model; for detailed argument descriptions, see the config dictionary in `models/NHDRRNet.py`. 
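For example, here is a minimal sketch of overriding a few defaults through keyword arguments (the argument names come from `utils/configs.py`; the values are only illustrative, adjust them to your machine):

```python
from utils.configs import Configs

# Override only what differs from the defaults in utils/configs.py;
# every key that is not passed keeps its default value.
configs = Configs(
    data_path='data',            # root folder containing train/ and test/
    batch_size=8,                # lower this if you run out of GPU memory
    learning_rate=1e-4,
    epoch=100,
    checkpoint_dir='checkpoint',
    sample_dir='samples',
    multigpu=False,
)
```

Note that `train.py` and `test.py` construct `Configs()` without arguments, so either edit the defaults in `utils/configs.py` or pass keyword arguments as above at the point where `Configs()` is instantiated.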
29 | 30 | ### Train 31 | 32 | ```bash 33 | python train.py 34 | ``` 35 | 36 | ### Test 37 | 38 | First, make sure that you have models (`checkpoint.tar`) under `checkpoint_dir` (which is defined in `Configs()`). 39 | 40 | ```bash 41 | python test.py 42 | ``` 43 | 44 | **Note**. `test.py` will dump the result images in `sample` folder. 45 | 46 | ### Tone-mapping (post-processing) 47 | 48 | Generated HDR images are in `.hdr` format, which may not be properly displayed in your image viewer directly. You may use [Photomatix](https://www.hdrsoft.com/) for tonemapping [2]: 49 | 50 | - Download [Photomatix](https://www.hdrsoft.com/) free trial, which won't expire. 51 | - Load the generated `.hdr` file in Photomatix. 52 | - Adjust the parameter settings. You may refer to pre-defined styles, such as `Detailed` and `Painterly2`. 53 | - Save your final image in `.tif` or `.jpg`. 54 | 55 | ## Reference 56 | 57 | [1] Yan, Qingsen, et al. "Deep hdr imaging via a non-local network." *IEEE Transactions on Image Processing* 29 (2020): 4308-4322. 58 | 59 | [2] elliottwu/DeepHDR repository: https://github.com/elliottwu/DeepHDR -------------------------------------------------------------------------------- /calc_metrics.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import argparse 4 | from utils.dataprocessor import * 5 | from utils.metrics import * 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--gt_path', type=str, default='data/kalantari_dataset/test') 9 | parser.add_argument('--test_path', type=str, default='') 10 | 11 | configs = parser.parse_args() 12 | file_path = configs.test_path 13 | gt_path = configs.gt_path 14 | 15 | dirs = [] 16 | for dir in os.listdir(file_path): 17 | if os.path.isdir(os.path.join(file_path, dir)): 18 | dirs.append(dir) 19 | 20 | dirs = sorted(dirs) 21 | 22 | psnr = PSNR() 23 | ssim = SSIM() 24 | 25 | total_psnr = 0 26 | total_ssim = 0 27 | 28 | for dir in dirs: 29 | gt_file = os.path.join(os.path.join(gt_path, dir), 'ref_hdr_aligned.hdr') 30 | my_file = os.path.join(os.path.join(file_path, dir), 'hdr.hdr') 31 | hdr = get_image(my_file) 32 | h, w, _ = hdr.shape 33 | hdr_gt = get_image(gt_file, [h, w], True) 34 | hdr_gt = inverse_transform(hdr_gt) 35 | hdr = inverse_transform(hdr) 36 | print('------------------------------------------') 37 | print('scene ', dir) 38 | cur_psnr = psnr(hdr, hdr_gt) 39 | cur_ssim = ssim(hdr, hdr_gt) 40 | print('PSNR:', cur_psnr) 41 | print('SSIM:', cur_ssim) 42 | total_psnr += cur_psnr 43 | total_ssim += cur_ssim 44 | 45 | print('******************************************') 46 | print('Final Report:') 47 | print(' Average PSNR: ', total_psnr / len(dirs)) 48 | print(' Average SSIM: ', total_ssim / len(dirs)) 49 | -------------------------------------------------------------------------------- /dataset/HDR.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | from glob import glob 4 | import numpy as np 5 | from utils.dataprocessor import * 6 | from utils.HDRutils import * 7 | from torch.utils.data import Dataset 8 | 9 | 10 | class KalantariDataset(Dataset): 11 | def __init__(self, configs, 12 | input_name='input_*_aligned.tif', 13 | ref_name='ref_*_aligned.tif', 14 | input_exp_name='input_exp.txt', 15 | ref_exp_name='ref_exp.txt', 16 | ref_hdr_name='ref_hdr_aligned.hdr'): 17 | super().__init__() 18 | print('====> Start preparing training data.') 19 | 20 | # Some basic information 21 | 
self.filepath = os.path.join(configs.data_path, 'train') 22 | self.scene_dirs = [scene_dir for scene_dir in os.listdir(self.filepath) 23 | if os.path.isdir(os.path.join(self.filepath, scene_dir))] 24 | self.scene_dirs = sorted(self.scene_dirs) 25 | self.num_scenes = len(self.scene_dirs) 26 | self.patch_size = configs.patch_size 27 | self.image_size = configs.image_size 28 | self.patch_stride = configs.patch_stride 29 | self.num_shots = configs.num_shots 30 | self.input_name = input_name 31 | self.ref_name = ref_name 32 | self.input_exp_name = input_exp_name 33 | self.ref_exp_name = ref_exp_name 34 | self.ref_hdr_name = ref_hdr_name 35 | self.total_count = 0 36 | # Count the number of patches in each trainning image 37 | self.count = [] 38 | for i, scene_dir in enumerate(self.scene_dirs): 39 | cur_scene_dir = os.path.join(self.filepath, scene_dir) 40 | in_LDR_paths = sorted(glob(os.path.join(cur_scene_dir, input_name))) 41 | tmp_img = get_image(in_LDR_paths[0]).astype(np.float32) 42 | h, w, c = tmp_img.shape 43 | if h < self.patch_size[0] or w < self.patch_size[1]: 44 | raise AttributeError('The size of some trainning images are smaller than the patch size.') 45 | h_count = np.ceil(h / self.patch_stride) 46 | w_count = np.ceil(w / self.patch_stride) 47 | self.count.append(h_count * w_count) 48 | self.total_count = self.total_count + h_count * w_count 49 | self.count = np.array(self.count).astype(int) 50 | self.total_count = int(self.total_count) 51 | 52 | print('====> Finish preparing training data!') 53 | 54 | def __len__(self): 55 | return self.total_count 56 | 57 | def __getitem__(self, index): 58 | # Find the corresponding image 59 | idx_beg = 0 60 | cur_scene_dir = "" 61 | scene_idx = -1 62 | scene_posidx = -1 63 | for i, scene_dir in enumerate(self.scene_dirs): 64 | idx_end = idx_beg + self.count[i] 65 | if idx_beg <= index < idx_end: 66 | cur_scene_dir = os.path.join(self.filepath, scene_dir) 67 | scene_idx = i 68 | scene_posidx = index - idx_beg 69 | break 70 | idx_beg = idx_end 71 | if scene_idx == -1: 72 | raise ValueError('Index out of bound') 73 | 74 | in_LDR_paths = sorted(glob(os.path.join(cur_scene_dir, self.input_name))) 75 | tmp_img = get_image(in_LDR_paths[0]) 76 | h, w, c = tmp_img.shape 77 | 78 | # Count the indices of h and w 79 | h_count = np.ceil(h / self.patch_stride) 80 | w_count = np.ceil(w / self.patch_stride) 81 | h_idx = int(scene_posidx / w_count) 82 | w_idx = int(scene_posidx - h_idx * w_count) 83 | 84 | # Count the up, down, left, right of the patch 85 | h_up = h_idx * self.patch_stride 86 | h_down = h_idx * self.patch_stride + self.patch_size[0] 87 | if h_down > h: 88 | h_up = h - self.patch_size[0] 89 | h_down = h 90 | 91 | w_left = w_idx * self.patch_stride 92 | w_right = w_idx * self.patch_stride + self.patch_size[1] 93 | if w_right > w: 94 | w_left = w - self.patch_size[1] 95 | w_right = w 96 | 97 | # Get the input images 98 | in_LDR = np.zeros((self.patch_size[0], self.patch_size[1], c * self.num_shots)) 99 | for j, in_LDR_path in enumerate(in_LDR_paths): 100 | in_LDR[:, :, j * c:(j + 1) * c] = get_image(in_LDR_path)[h_up:h_down, w_left:w_right, :] 101 | in_LDR = np.array(in_LDR).astype(np.float32) 102 | 103 | in_exp_path = os.path.join(cur_scene_dir, self.input_exp_name) 104 | in_exp = np.array(open(in_exp_path).read().split('\n')[:self.num_shots]).astype(np.float32) 105 | 106 | ref_HDR = get_image(os.path.join(cur_scene_dir, self.ref_hdr_name))[h_up:h_down, w_left:w_right, :] 107 | 108 | ref_LDR_paths = sorted(glob(os.path.join(cur_scene_dir, 
self.ref_name))) 109 | ref_LDR = np.zeros((self.patch_size[0], self.patch_size[1], c * self.num_shots)) 110 | for j, ref_LDR_path in enumerate(ref_LDR_paths): 111 | ref_LDR[:, :, j * c:(j + 1) * c] = get_image(ref_LDR_path)[h_up:h_down, w_left:w_right, :] 112 | ref_LDR = np.array(ref_LDR).astype(np.float32) 113 | 114 | ref_exp_path = os.path.join(cur_scene_dir, self.ref_exp_name) 115 | ref_exp = np.array(open(ref_exp_path).read().split('\n')[:self.num_shots]).astype(np.float32) 116 | 117 | # Make some random transformation. 118 | distortions = np.random.uniform(0.0, 1.0, 2) 119 | # Horizontal flip 120 | if distortions[0] < 0.5: 121 | in_LDR = np.flip(in_LDR, axis=1) 122 | ref_LDR = np.flip(ref_LDR, axis=1) 123 | ref_HDR = np.flip(ref_HDR, axis=1) 124 | 125 | # Rotation 126 | k = int(distortions[1] * 4 + 0.5) 127 | in_LDR = np.rot90(in_LDR, k) 128 | ref_LDR = np.rot90(ref_LDR, k) 129 | ref_HDR = np.rot90(ref_HDR, k) 130 | in_exp = 2 ** in_exp 131 | ref_exp = 2 ** ref_exp 132 | 133 | in_HDR = LDR2HDR_batch(in_LDR, in_exp) 134 | 135 | # In pytorch, channels is in axis=1, so ijk -> kij 136 | in_LDR = np.einsum("ijk->kij", in_LDR) 137 | ref_LDR = np.einsum("ijk->kij", ref_LDR) 138 | in_HDR = np.einsum("ijk->kij", in_HDR) 139 | ref_HDR = np.einsum("ijk->kij", ref_HDR) 140 | return in_LDR.copy().astype(np.float32), ref_LDR.copy().astype(np.float32), \ 141 | in_HDR.copy().astype(np.float32), ref_HDR.copy().astype(np.float32), \ 142 | in_exp.copy().astype(np.float32), ref_exp.copy().astype(np.float32) 143 | 144 | 145 | class KalantariTestDataset(Dataset): 146 | def __init__(self, configs, 147 | input_name= 'input_*_aligned.tif', 148 | input_exp_name = 'input_exp.txt', 149 | ref_hdr_name = 'ref_hdr_aligned.hdr'): 150 | super().__init__() 151 | print('====> Start preparing testing data.') 152 | self.filepath = os.path.join(configs.data_path, 'test') 153 | self.scene_dirs = [scene_dir for scene_dir in os.listdir(self.filepath) 154 | if os.path.isdir(os.path.join(self.filepath, scene_dir))] 155 | self.scene_dirs = sorted(self.scene_dirs) 156 | self.num_scenes = len(self.scene_dirs) 157 | self.patch_size = configs.patch_size 158 | self.patch_stride = configs.patch_stride 159 | self.num_shots = configs.num_shots 160 | self.sample_path = configs.sample_dir 161 | self.input_name = input_name 162 | self.input_exp_name = input_exp_name 163 | self.ref_hdr_name = ref_hdr_name 164 | print('====> Finish preparing testing data!') 165 | 166 | def __len__(self): 167 | return self.num_scenes 168 | 169 | def __getitem__(self, index): 170 | scene_dir = self.scene_dirs[index] 171 | scene_path = os.path.join(self.filepath, scene_dir) 172 | sample_path = os.path.join(self.sample_path, scene_dir) 173 | LDR_path = os.path.join(scene_path, self.input_name) 174 | exp_path = os.path.join(scene_path, self.input_exp_name) 175 | ref_HDR_path = os.path.join(scene_path, self.ref_hdr_name) 176 | in_LDR, in_HDRs, in_exp, ref_HDR = get_input(LDR_path, exp_path, ref_HDR_path) 177 | in_LDR = np.einsum("ijk->kij", in_LDR) 178 | in_HDRs = np.einsum("ijk->kij", in_HDRs) 179 | ref_HDR = np.einsum("ijk->kij", ref_HDR) 180 | return sample_path, \ 181 | in_LDR.copy().astype(np.float32), \ 182 | in_HDRs.copy().astype(np.float32), \ 183 | in_exp.copy().astype(np.float32), \ 184 | ref_HDR.copy().astype(np.float32) 185 | 186 | -------------------------------------------------------------------------------- /dataset/HDRpatches.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | from 
glob import glob 4 | import numpy as np 5 | from utils.dataprocessor import * 6 | from utils.HDRutils import * 7 | from torch.utils.data import Dataset 8 | 9 | 10 | 11 | class KalantariDataset(Dataset): 12 | def __init__(self, configs, 13 | input_name= 'input_*_aligned.tif', 14 | ref_name = 'ref_*_aligned.tif', 15 | input_exp_name = 'input_exp.txt', 16 | ref_exp_name = 'ref_exp.txt', 17 | ref_hdr_name = 'ref_hdr_aligned.hdr'): 18 | super().__init__() 19 | print('====> Start preparing training data.') 20 | filepath = os.path.join(configs.data_path, 'train') 21 | self.scene_dirs = [scene_dir for scene_dir in os.listdir(filepath) 22 | if os.path.isdir(os.path.join(filepath, scene_dir))] 23 | self.scene_dirs = sorted(self.scene_dirs) 24 | self.num_scenes = len(self.scene_dirs) 25 | 26 | self.patch_size = configs.patch_size 27 | self.patch_stride = configs.patch_stride 28 | self.num_shots = configs.num_shots 29 | self.patch_path = configs.patch_dir 30 | self.count = len(os.listdir(self.patch_path)) 31 | 32 | if self.count == 0: 33 | for i, scene_dir in enumerate(self.scene_dirs): 34 | print(i) 35 | cur_scene_dir = os.path.join(filepath, scene_dir) 36 | in_LDR_paths = sorted(glob(os.path.join(cur_scene_dir, input_name))) 37 | tmp_img = get_image(in_LDR_paths[0]) 38 | h, w, c = tmp_img.shape 39 | in_LDRs = np.zeros((h, w, c * self.num_shots)) 40 | for j, in_LDR_path in enumerate(in_LDR_paths): 41 | in_LDRs[:, :, j * c:(j + 1) * c] = get_image(in_LDR_path) 42 | in_LDRs = in_LDRs.astype(np.float32) 43 | in_exps_path = os.path.join(cur_scene_dir, input_exp_name) 44 | in_exps = np.array(open(in_exps_path).read().split('\n')[:self.num_shots]).astype(np.float32) 45 | ref_HDR = get_image(os.path.join(cur_scene_dir, ref_hdr_name)) 46 | ref_LDR_paths = sorted(glob(os.path.join(cur_scene_dir, ref_name))) 47 | ref_LDRs = np.zeros((h, w, c * self.num_shots)) 48 | for j, ref_LDR_path in enumerate(ref_LDR_paths): 49 | ref_LDRs[:, :, j * c:(j + 1) * c] = get_image(ref_LDR_path) 50 | ref_exps_path = os.path.join(cur_scene_dir, ref_exp_name) 51 | ref_exps = np.array(open(ref_exps_path).read().split('\n')[:self.num_shots]).astype(np.float32) 52 | 53 | # Cut the images into several patches 54 | # Store patches into files to save memory. 
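# Tiling note: the loops below slide a patch_size window over the image with step patch_stride; when the height/width is not an exact multiple of the patch size, the extra `if h % self.patch_size[0]` / `if w % self.patch_size[1]` branches append one final row/column (and corner) of windows aligned with the bottom/right border so the image borders are still covered. Each window's coordinates and the source arrays are passed to store_patch, which pickles one '<count>.pkl' entry into the patch_dir folder.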
55 | for _h in range(0, h - self.patch_size[0] + 1, self.patch_stride): 56 | for _w in range(0, w - self.patch_size[1] + 1, self.patch_stride): 57 | store_patch(_h, _h + self.patch_size[0], _w, _w + self.patch_size[1], in_LDRs, in_exps, 58 | ref_HDR, ref_LDRs, ref_exps, self.patch_path, self.count) 59 | self.count += 1 60 | if h % self.patch_size[0]: 61 | for _w in range(0, w - self.patch_size[1] + 1, self.patch_stride): 62 | store_patch(h - self.patch_size[0], h, _w, _w + self.patch_size[1], in_LDRs, in_exps, 63 | ref_HDR, ref_LDRs, ref_exps, self.patch_path, self.count) 64 | self.count += 1 65 | if w % self.patch_size[1]: 66 | for _h in range(0, h - self.patch_size[0] + 1, self.patch_stride): 67 | store_patch(_h, _h + self.patch_size[0], w - self.patch_size[1], w, in_LDRs, in_exps, 68 | ref_HDR, ref_LDRs, ref_exps, self.patch_path, self.count) 69 | self.count += 1 70 | if h % self.patch_size[0] and w % self.patch_size[1]: 71 | store_patch(h - self.patch_size[0], h, w - self.patch_size[1], w, in_LDRs, in_exps, 72 | ref_HDR, ref_LDRs, ref_exps, self.patch_path, self.count) 73 | self.count += 1 74 | print('====> Finish preparing training data!') 75 | 76 | def __len__(self): 77 | return self.count 78 | 79 | def __getitem__(self, index): 80 | distortions = np.random.uniform(0.0, 1.0, 2) 81 | 82 | data = get_patch_from_file(self.patch_path, index) 83 | in_LDR = data['in_LDR'] 84 | ref_LDR = data['ref_LDR'] 85 | ref_HDR = data['ref_HDR'] 86 | in_exp = data['in_exp'] 87 | ref_exp = data['ref_exp'] 88 | 89 | # Horizontal flip 90 | if distortions[0] < 0.5: 91 | in_LDR = np.flip(in_LDR, axis=1) 92 | ref_LDR = np.flip(ref_LDR, axis=1) 93 | ref_HDR = np.flip(ref_HDR, axis=1) 94 | 95 | # Rotation 96 | k = int(distortions[1] * 4 + 0.5) 97 | in_LDR = np.rot90(in_LDR, k) 98 | ref_LDR = np.rot90(ref_LDR, k) 99 | ref_HDR = np.rot90(ref_HDR, k) 100 | in_exp = 2 ** in_exp 101 | ref_exp = 2 ** ref_exp 102 | 103 | in_HDR = LDR2HDR_batch(in_LDR, in_exp) 104 | 105 | # In pytorch, channels is in axis=1, so ijk -> kij 106 | in_LDR = np.einsum("ijk->kij", in_LDR) 107 | ref_LDR = np.einsum("ijk->kij", ref_LDR) 108 | in_HDR = np.einsum("ijk->kij", in_HDR) 109 | ref_HDR = np.einsum("ijk->kij", ref_HDR) 110 | return in_LDR.copy().astype(np.float32), ref_LDR.copy().astype(np.float32), \ 111 | in_HDR.copy().astype(np.float32), ref_HDR.copy().astype(np.float32), \ 112 | in_exp.copy().astype(np.float32), ref_exp.copy().astype(np.float32) 113 | 114 | 115 | class KalantariTestDataset(Dataset): 116 | def __init__(self, configs, 117 | input_name= 'input_*_aligned.tif', 118 | input_exp_name = 'input_exp.txt', 119 | ref_hdr_name = 'ref_hdr_aligned.hdr'): 120 | super().__init__() 121 | print('====> Start preparing testing data.') 122 | self.filepath = os.path.join(configs.data_path, 'test') 123 | self.scene_dirs = [scene_dir for scene_dir in os.listdir(self.filepath) 124 | if os.path.isdir(os.path.join(self.filepath, scene_dir))] 125 | self.scene_dirs = sorted(self.scene_dirs) 126 | self.num_scenes = len(self.scene_dirs) 127 | self.patch_size = configs.patch_size 128 | self.patch_stride = configs.patch_stride 129 | self.num_shots = configs.num_shots 130 | self.sample_path = configs.sample_dir 131 | self.input_name = input_name 132 | self.input_exp_name = input_exp_name 133 | self.ref_hdr_name = ref_hdr_name 134 | print('====> Finish preparing testing data!') 135 | 136 | def __len__(self): 137 | return self.num_scenes 138 | 139 | def __getitem__(self, index): 140 | scene_dir = self.scene_dirs[index] 141 | scene_path = 
os.path.join(self.filepath, scene_dir) 142 | sample_path = os.path.join(self.sample_path, scene_dir) 143 | LDR_path = os.path.join(scene_path, self.input_name) 144 | exp_path = os.path.join(scene_path, self.input_exp_name) 145 | ref_HDR_path = os.path.join(scene_path, self.ref_hdr_name) 146 | in_LDRs, in_HDRs, in_exps, ref_HDRs = get_input(LDR_path, exp_path, ref_HDR_path) 147 | in_LDRs = np.einsum("ijk->kij", in_LDRs) 148 | in_HDRs = np.einsum("ijk->kij", in_HDRs) 149 | ref_HDRs = np.einsum("ijk->kij", ref_HDRs) 150 | return sample_path, \ 151 | in_LDRs.copy().astype(np.float32), \ 152 | in_HDRs.copy().astype(np.float32), \ 153 | in_exps.copy().astype(np.float32), \ 154 | ref_HDRs.copy().astype(np.float32) 155 | 156 | -------------------------------------------------------------------------------- /models/NHDRRNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision.models import resnet 5 | 6 | config={ 7 | 'in_channel': 6, 8 | 'hidden_dim': 32, 9 | 'encoder_kernel_size': 3, 10 | 'encoder_stride': 2, 11 | 'triple_pass_filter': 256 12 | } 13 | 14 | 15 | class PaddedConv2d(nn.Module): 16 | def __init__(self, input_channels, output_channels, ks, stride): 17 | super().__init__() 18 | # Custom Padding Calculation 19 | if isinstance(ks, tuple): 20 | k_h, k_w = ks 21 | else: 22 | k_h = ks 23 | k_w = ks 24 | if isinstance(stride, tuple): 25 | s_h, s_w = stride 26 | else: 27 | s_h = stride 28 | s_w = stride 29 | pad_h, pad_w = k_h - s_h, k_w - s_w 30 | pad_up, pad_left = pad_h // 2, pad_w // 2 31 | pad_down, pad_right= pad_h - pad_up, pad_w - pad_left 32 | self.pad = nn.ZeroPad2d([pad_left, pad_right, pad_up, pad_down]) 33 | self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=ks, stride=stride, bias=True) 34 | 35 | def forward(self, x): 36 | x = self.pad(x) 37 | x = self.conv(x) 38 | return x 39 | 40 | 41 | class NHDRRNet(nn.Module): 42 | def __init__(self) -> None: 43 | super(NHDRRNet, self).__init__() 44 | # self.filter = config.filter 45 | # self.encoder_kernel = config.encoder_kernel 46 | # self.decoder_kernel = config.decoder_kernel 47 | # self.triple_pass_filter = config.triple_pass_filter 48 | self.c_dim = 3 49 | 50 | self.encoder_1 = [] 51 | self.encoder_2 = [] 52 | self.encoder_3 = [] 53 | self.encoder_1.append(self._make_encoder(config['in_channel'], config['hidden_dim'])) 54 | self.encoder_1.append(self._make_encoder(config['hidden_dim'], config['hidden_dim'] * 2)) 55 | self.encoder_1.append(self._make_encoder(config['hidden_dim'] * 2, config['hidden_dim'] * 4)) 56 | self.encoder_1.append(self._make_encoder(config['hidden_dim'] * 4, config['hidden_dim'] * 8)) 57 | self.encoder_1 = nn.ModuleList(self.encoder_1) 58 | 59 | self.encoder_2.append(self._make_encoder(config['in_channel'], config['hidden_dim'])) 60 | self.encoder_2.append(self._make_encoder(config['hidden_dim'], config['hidden_dim'] * 2)) 61 | self.encoder_2.append(self._make_encoder(config['hidden_dim'] * 2, config['hidden_dim'] * 4)) 62 | self.encoder_2.append(self._make_encoder(config['hidden_dim'] * 4, config['hidden_dim'] * 8)) 63 | self.encoder_2 = nn.ModuleList(self.encoder_2) 64 | 65 | self.encoder_3.append(self._make_encoder(config['in_channel'], config['hidden_dim'])) 66 | self.encoder_3.append(self._make_encoder(config['hidden_dim'], config['hidden_dim'] * 2)) 67 | self.encoder_3.append(self._make_encoder(config['hidden_dim'] * 2, config['hidden_dim'] * 4)) 68 | 
self.encoder_3.append(self._make_encoder(config['hidden_dim'] * 4, config['hidden_dim'] * 8)) 69 | self.encoder_3 = nn.ModuleList(self.encoder_3) 70 | 71 | self.final_encoder = nn.Sequential( 72 | PaddedConv2d(config['hidden_dim'] * 8 * 3, config['triple_pass_filter'], 3, 1), 73 | nn.BatchNorm2d(config['triple_pass_filter'], momentum=0.9), 74 | nn.ReLU() 75 | ) 76 | self.triple_list = [] 77 | for i in range(10): 78 | self.triple_list.append(nn.ModuleList(self._make_triple_pass_layer())) 79 | self.triple_list = nn.ModuleList(self.triple_list) 80 | self.avgpool = nn.AdaptiveAvgPool2d((16, 16)) 81 | self.theta_conv = PaddedConv2d(config['triple_pass_filter'], 128, 1, 1) 82 | self.phi_conv = PaddedConv2d(config['triple_pass_filter'], 128, 1, 1) 83 | self.g_conv = PaddedConv2d(config['triple_pass_filter'], 128, 1, 1) 84 | self.theta_phi_g_conv = PaddedConv2d(config['triple_pass_filter']//2, config['triple_pass_filter'], 1, 1) 85 | self.decoder1 = self._make_decoder(config['triple_pass_filter'] * 2, config['hidden_dim'] * 4) 86 | self.decoder2 = self._make_decoder(config['hidden_dim'] * 4 * 4, config['hidden_dim'] * 2) 87 | self.decoder3 = self._make_decoder(config['hidden_dim'] * 2 * 4, config['hidden_dim']) 88 | self.decoder_final = nn.Sequential( 89 | nn.ConvTranspose2d(config['hidden_dim'] * 4, config['hidden_dim'], 4, 2, 1, bias=True), 90 | nn.BatchNorm2d(config['hidden_dim']), 91 | nn.LeakyReLU() 92 | ) 93 | self.final = nn.Sequential( 94 | PaddedConv2d(config['hidden_dim'], 3, 3, 1), 95 | nn.Tanh() 96 | ) 97 | 98 | def _make_encoder(self, in_c, out): 99 | encoder = nn.Sequential( 100 | PaddedConv2d(in_c, out, config['encoder_kernel_size'], config['encoder_stride']), 101 | nn.BatchNorm2d(out, momentum=0.9), 102 | nn.ReLU() 103 | ) 104 | return encoder 105 | # def _make_encoder(self): 106 | # encoder = nn.Sequential( 107 | # PaddedConv2d(config['in_channel'], config['hidden_dim'], config['encoder_kernel_size'], config['encoder_stride']), 108 | # nn.BatchNorm2d(), 109 | # nn.ReLU(), 110 | # PaddedConv2d(config['hidden_dim'], config['hidden_dim'] * 2, config['encoder_kernel_size'], config['encoder_stride']), 111 | # nn.BatchNorm2d(), 112 | # nn.ReLU(), 113 | # PaddedConv2d(config['hidden_dim'] * 2, config['hidden_dim'] * 4, config['encoder_kernel_size'], config['encoder_stride']), 114 | # nn.BatchNorm2d(), 115 | # nn.ReLU(), 116 | # PaddedConv2d(config['hidden_dim'] * 4, config['hidden_dim'] * 8, config['encoder_kernel_size'], config['encoder_stride']), 117 | # nn.BatchNorm2d(), 118 | # nn.ReLU() 119 | # ) 120 | # return encoder 121 | 122 | def _make_decoder(self, in_c, out): 123 | decoder = nn.Sequential( 124 | nn.ConvTranspose2d(in_c, out, 4, 2, 1, bias=True), 125 | nn.BatchNorm2d(out), 126 | nn.LeakyReLU() 127 | ) 128 | return decoder 129 | 130 | def _make_triple_pass_layer(self): 131 | return [PaddedConv2d(config['triple_pass_filter'], config['triple_pass_filter'], 1, 1), 132 | PaddedConv2d(config['triple_pass_filter'], config['triple_pass_filter'], 3, 1), 133 | PaddedConv2d(config['triple_pass_filter'], config['triple_pass_filter'], 5, 1), 134 | PaddedConv2d(config['triple_pass_filter'] * 3, config['triple_pass_filter'], 3, 1)] 135 | 136 | def triplepass(self, x, i): 137 | x1 = F.relu(self.triple_list[i][0](x)) 138 | x2 = F.relu(self.triple_list[i][1](x)) 139 | x3 = F.relu(self.triple_list[i][2](x)) 140 | x3 = torch.cat([x1,x2,x3], dim=1) 141 | x4 = self.triple_list[i][3](x3) 142 | x5 = x4 + x 143 | 144 | return x5 145 | 146 | def global_non_local(self, x): 147 | b, c, h, w = x.shape 148 
| theta = self.theta_conv(x).reshape(b, c//2, h * w).permute(0, 2, 1).contiguous() 149 | phi = self.phi_conv(x).reshape(b, c//2, h * w) 150 | g = self.g_conv(x).reshape(b, c//2, h * w).permute(0, 2, 1).contiguous() 151 | 152 | theta_phi = F.softmax(torch.matmul(theta, phi),dim=-1) 153 | theta_phi_g = torch.matmul(theta_phi, g) 154 | theta_phi_g = theta_phi_g.permute(0, 2, 1).contiguous().reshape(b, c//2, h, w) 155 | 156 | theta_phi_g = self.theta_phi_g_conv(theta_phi_g) 157 | 158 | output = theta_phi_g + x 159 | 160 | return output 161 | 162 | def forward(self, in_LDR, in_HDR): 163 | image1 = torch.cat([in_LDR[:, 0:self.c_dim, :, :], in_HDR[:, 0:self.c_dim, :, :]], 1) 164 | image2 = torch.cat([in_LDR[:, self.c_dim:self.c_dim * 2, :, :], in_HDR[:, self.c_dim:self.c_dim * 2, :, :]], 1) 165 | image3 = torch.cat([in_LDR[:, self.c_dim * 2:self.c_dim * 3, :, :], in_HDR[:, self.c_dim * 2:self.c_dim * 3, :, :]], 1) 166 | 167 | # if debug: 168 | # print('image1: {}, image2: {}, image3: {}'.format(image1.shape, image2.shape, image3.shape)) 169 | 170 | # encoding 171 | x1_32 = self.encoder_1[0](image1) 172 | x1_64 = self.encoder_1[1](x1_32) 173 | x1_128 = self.encoder_1[2](x1_64) 174 | x1 = self.encoder_1[3](x1_128) 175 | 176 | # if debug: 177 | # print('x1_32: {}, x1_64: {}, x1_128: {}, x1: {}'.format(x1_32.shape, x1_64.shape, x1_128.shape, x1.shape)) 178 | 179 | x2_32 = self.encoder_2[0](image2) 180 | x2_64 = self.encoder_2[1](x2_32) 181 | x2_128 = self.encoder_2[2](x2_64) 182 | x2 = self.encoder_2[3](x2_128) 183 | 184 | x3_32 = self.encoder_3[0](image3) 185 | x3_64 = self.encoder_3[1](x3_32) 186 | x3_128 = self.encoder_3[2](x3_64) 187 | x3 = self.encoder_3[3](x3_128) 188 | 189 | 190 | # merging 191 | x_cat = torch.cat([x1, x2, x3], dim=1) 192 | encoder_final = self.final_encoder(x_cat) 193 | 194 | tpl_out = self.triplepass(encoder_final, 0) 195 | for i in range(1,9): 196 | tpl_out = self.triplepass(tpl_out, i) 197 | 198 | glb_out = self.avgpool(encoder_final) 199 | glb_out = self.global_non_local(glb_out) 200 | required_size = [encoder_final.shape[2], encoder_final.shape[3]] 201 | glb_out = F.interpolate(glb_out, size=required_size) 202 | 203 | # decoding 204 | out_512 = torch.cat([tpl_out, glb_out], dim=1) 205 | out_128 = self.decoder1(out_512) 206 | out_128 = torch.cat([out_128, x1_128, x2_128, x3_128], dim=1) 207 | out_64 = self.decoder2(out_128) 208 | out_64 = torch.cat([out_64, x1_64, x2_64, x3_64], dim=1) 209 | out_32 = self.decoder3(out_64) 210 | out_32 = torch.cat([out_32, x1_32, x2_32, x3_32], dim=1) 211 | out = self.decoder_final(out_32) 212 | out = self.final(out) 213 | 214 | return out 215 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import optim 4 | from torch.utils.data import DataLoader 5 | from utils.loss import HDRLoss 6 | from utils.HDRutils import tonemap 7 | from utils.dataprocessor import dump_sample 8 | from dataset.HDR import KalantariTestDataset 9 | from models.NHDRRNet import NHDRRNet 10 | from utils.configs import Configs 11 | 12 | 13 | # Get configurations 14 | configs = Configs() 15 | 16 | # Load dataset 17 | test_dataset = KalantariTestDataset(configs=configs) 18 | test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True) 19 | 20 | # Build NHDRRNet model from configs 21 | model = NHDRRNet() 22 | if configs.multigpu is False: 23 | device = torch.device('cuda:0' if 
torch.cuda.is_available() else 'cpu') 24 | model.to(device) 25 | else: 26 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 27 | if device == torch.device('cpu'): 28 | raise EnvironmentError('No GPUs, cannot initialize multigpu training.') 29 | model.to(device) 30 | 31 | # Define optimizer 32 | optimizer = optim.Adam(model.parameters(), betas=(configs.beta1, configs.beta2), lr=configs.learning_rate) 33 | 34 | # Define Criterion 35 | criterion = HDRLoss() 36 | 37 | # Read checkpoints 38 | checkpoint_file = configs.checkpoint_dir + '/checkpoint.tar' 39 | if os.path.isfile(checkpoint_file): 40 | checkpoint = torch.load(checkpoint_file) 41 | model.load_state_dict(checkpoint['model_state_dict']) 42 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 43 | start_epoch = checkpoint['epoch'] 44 | print("Load checkpoint %s (epoch %d)", checkpoint_file, start_epoch) 45 | else: 46 | raise ModuleNotFoundError('No checkpoint files.') 47 | 48 | if configs.multigpu is True: 49 | model = torch.nn.DataParallel(model) 50 | 51 | 52 | def test_one_epoch(): 53 | model.eval() 54 | mean_loss = 0 55 | count = 0 56 | for idx, data in enumerate(test_dataloader): 57 | sample_path, in_LDRs, in_HDRs, in_exps, ref_HDRs = data 58 | sample_path = sample_path[0] 59 | in_LDRs = in_LDRs.to(device) 60 | in_HDRs = in_HDRs.to(device) 61 | ref_HDRs = ref_HDRs.to(device) 62 | # Forward 63 | with torch.no_grad(): 64 | res = model(in_LDRs, in_HDRs) 65 | 66 | # Compute loss 67 | with torch.no_grad(): 68 | loss = criterion(tonemap(res), tonemap(ref_HDRs)) 69 | 70 | dump_sample(sample_path, res.cpu().detach().numpy()) 71 | 72 | print('--------------- Test Batch %d ---------------' % (idx + 1)) 73 | print('loss: %.12f' % loss.item()) 74 | mean_loss += loss.item() 75 | count += 1 76 | 77 | mean_loss = mean_loss / count 78 | return mean_loss 79 | 80 | 81 | def test(): 82 | loss = test_one_epoch() 83 | print('mean test loss: %.12f' % loss) 84 | 85 | 86 | if __name__ == '__main__': 87 | test() 88 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python test.py -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from torch import optim 5 | from torch.utils.data import DataLoader 6 | from utils.solvers import PolyLR 7 | from utils.loss import HDRLoss 8 | from utils.HDRutils import tonemap 9 | from utils.dataprocessor import dump_sample 10 | from dataset.HDR import KalantariDataset, KalantariTestDataset 11 | from models.NHDRRNet import NHDRRNet 12 | from utils.configs import Configs 13 | import random 14 | import numpy as np 15 | 16 | 17 | def setup_seed(seed=0): 18 | random.seed(seed) 19 | os.environ['PYTHONHASHSEED'] = str(seed) 20 | np.random.seed(seed) 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed(seed) 23 | torch.cuda.manual_seed_all(seed) 24 | torch.backends.cudnn.benchmark = False 25 | torch.backends.cudnn.deterministic = True 26 | 27 | 28 | # Get configurations 29 | configs = Configs() 30 | 31 | # Load Data & build dataset 32 | train_dataset = KalantariDataset(configs=configs) 33 | train_dataloader = DataLoader(train_dataset, batch_size=configs.batch_size, shuffle=True) 34 | 35 | test_dataset = KalantariTestDataset(configs=configs) 36 | test_dataloader = DataLoader(test_dataset, 
batch_size=1, shuffle=True) 37 | 38 | 39 | # Build NHDRRNet model from configs 40 | model = NHDRRNet() 41 | if configs.multigpu is False: 42 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 43 | model.to(device) 44 | else: 45 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 46 | if device == torch.device('cpu'): 47 | raise EnvironmentError('No GPUs, cannot initialize multigpu training.') 48 | model.to(device) 49 | 50 | # Define optimizer 51 | optimizer = optim.Adam(model.parameters(), betas=(configs.beta1, configs.beta2), lr=configs.learning_rate) 52 | 53 | # Define Criterion 54 | criterion = HDRLoss() 55 | 56 | # Define Scheduler 57 | lr_scheduler = PolyLR(optimizer, max_iter=configs.epoch, power=0.9) 58 | 59 | # Read checkpoints 60 | start_epoch = 0 61 | checkpoint_file = configs.checkpoint_dir + '/checkpoint.tar' 62 | if os.path.isfile(checkpoint_file): 63 | checkpoint = torch.load(checkpoint_file) 64 | model.load_state_dict(checkpoint['model_state_dict']) 65 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 66 | start_epoch = checkpoint['epoch'] 67 | lr_scheduler.load_state_dict(checkpoint['scheduler']) 68 | print("Load checkpoint %s (epoch %d)", checkpoint_file, start_epoch) 69 | 70 | 71 | if configs.multigpu is True: 72 | model = torch.nn.DataParallel(model) 73 | 74 | 75 | def train_one_epoch(): 76 | model.train() 77 | for idx, data in enumerate(train_dataloader): 78 | in_LDRs, ref_LDRs, in_HDRs, ref_HDRs, in_exps, ref_exps = data 79 | in_LDRs = in_LDRs.to(device) 80 | in_HDRs = in_HDRs.to(device) 81 | ref_HDRs = ref_HDRs.to(device) 82 | # Forward 83 | result = model(in_LDRs, in_HDRs) 84 | # Backward 85 | loss = criterion(tonemap(result), tonemap(ref_HDRs)) 86 | loss.backward() 87 | optimizer.step() 88 | optimizer.zero_grad() 89 | 90 | print('--------------- Train Batch %d ---------------' % (idx + 1)) 91 | print('loss: %.12f' % loss.item()) 92 | 93 | 94 | def eval_one_epoch(): 95 | model.eval() 96 | mean_loss = 0 97 | count = 0 98 | for idx, data in enumerate(test_dataloader): 99 | sample_path, in_LDRs, in_HDRs, in_exps, ref_HDRs = data 100 | sample_path = sample_path[0] 101 | in_LDRs = in_LDRs.to(device) 102 | in_HDRs = in_HDRs.to(device) 103 | ref_HDRs = ref_HDRs.to(device) 104 | # Forward 105 | with torch.no_grad(): 106 | res = model(in_LDRs, in_HDRs) 107 | # Compute loss 108 | with torch.no_grad(): 109 | loss = criterion(tonemap(res), tonemap(ref_HDRs)) 110 | dump_sample(sample_path, res.cpu().detach().numpy()) 111 | print('--------------- Eval Batch %d ---------------' % (idx + 1)) 112 | print('loss: %.12f' % loss.item()) 113 | mean_loss += loss.item() 114 | count += 1 115 | 116 | mean_loss = mean_loss / count 117 | return mean_loss 118 | 119 | 120 | def train(start_epoch): 121 | global cur_epoch 122 | for epoch in range(start_epoch, configs.epoch): 123 | cur_epoch = epoch 124 | print('**************** Epoch %d ****************' % (epoch + 1)) 125 | print('learning rate: %f' % (lr_scheduler.get_last_lr()[0])) 126 | train_one_epoch() 127 | loss = eval_one_epoch() 128 | lr_scheduler.step() 129 | if configs.multigpu is False: 130 | save_dict = {'epoch': epoch + 1, 'loss': loss, 131 | 'optimizer_state_dict': optimizer.state_dict(), 132 | 'model_state_dict': model.state_dict(), 133 | 'scheduler': lr_scheduler.state_dict() 134 | } 135 | else: 136 | save_dict = {'epoch': epoch + 1, 'loss': loss, 137 | 'optimizer_state_dict': optimizer.state_dict(), 138 | 'model_state_dict': model.module.state_dict(), 139 | 'scheduler': 
lr_scheduler.state_dict() 140 | } 141 | torch.save(save_dict, os.path.join(configs.checkpoint_dir, 'checkpoint.tar')) 142 | torch.save(save_dict, os.path.join(configs.checkpoint_dir, 'checkpoint' + str(epoch) + '.tar')) 143 | print('mean eval loss: %.12f' % loss) 144 | 145 | 146 | if __name__ == '__main__': 147 | train(start_epoch) 148 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python train.py -------------------------------------------------------------------------------- /utils/HDRutils.py: -------------------------------------------------------------------------------- 1 | # Ref: https://github.com/elliottwu/DeepHDR/ 2 | import numpy as np 3 | import torch 4 | import cv2 5 | 6 | MU = 5000. # tunemapping parameter 7 | GAMMA = 2.2 # LDR&HDR domain transform parameter 8 | 9 | 10 | def LDR2HDR(img, expo): # input/output 0~1 11 | return (((img+1)/2.)**GAMMA / expo) *2.-1 12 | 13 | 14 | def LDR2HDR_batch(imgs, expos): # input/output 0~1 15 | return np.concatenate([LDR2HDR(imgs[:, :, 0:3], expos[0]), 16 | LDR2HDR(imgs[:, :, 3:6], expos[1]), 17 | LDR2HDR(imgs[:, :, 6:9], expos[2])], axis=2) 18 | 19 | 20 | def HDR2LDR(imgs, expo): # input/output 0~1 21 | return (np.clip(((imgs+1)/2.*expo),0,1)**(1/GAMMA)) *2.-1 22 | 23 | 24 | def transform_LDR(image, im_size=(256, 256)): 25 | out = image.astype(np.float32) 26 | out = cv2.resize(out, im_size) 27 | return out/127.5 - 1. 28 | 29 | 30 | def transform_HDR(image, im_size=(256, 256)): 31 | out = cv2.resize(image, im_size) 32 | return out*2. - 1. 33 | 34 | 35 | def tonemap(images): # input/output 0~1 36 | return torch.log(1 + MU * (images + 1) / 2.) / np.log(1 + MU) * 2. - 1 37 | 38 | 39 | def tonemap_np(images): # input/output 0~1 40 | return np.log(1 + MU * (images + 1) / 2.) / np.log(1 + MU) * 2. 
- 1 41 | -------------------------------------------------------------------------------- /utils/configs.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | class Configs(object): 4 | def __init__(self, data_path='data', 5 | **kwargs): 6 | self.data_path = data_path 7 | self.epoch = kwargs.get('epoch', 50) 8 | self.learning_rate = kwargs.get('learning_rate', 0.001) 9 | self.beta1 = kwargs.get('beta1', 0.9) 10 | self.beta2 = kwargs.get('beta2', 0.999) 11 | self.load_size = kwargs.get('load_size', 250) 12 | self.patch_size = kwargs.get('patch_size', (256, 256)) 13 | self.image_size = kwargs.get('image_size', (256, 256)) 14 | self.patch_stride = kwargs.get('patch_stride', 64) 15 | self.patch_dir = kwargs.get('patch_dir', 'patches') 16 | self.batch_size = kwargs.get('batch_size', 32) 17 | self.c_dim = kwargs.get('c_dim', 3) 18 | self.num_shots = kwargs.get('num_shots', 3) 19 | self.checkpoint_dir = kwargs.get('checkpoint_dir', 'checkpoint') 20 | self.sample_dir = kwargs.get('sample_dir', 'samples') 21 | self.log_dir = kwargs.get('log_dir', 'logs') 22 | self.multigpu = kwargs.get('multigpu', False) 23 | if not os.path.exists(self.checkpoint_dir): 24 | os.makedirs(self.checkpoint_dir) 25 | if not os.path.exists(self.sample_dir): 26 | os.makedirs(self.sample_dir) 27 | if not os.path.exists(self.patch_dir): 28 | os.makedirs(self.patch_dir) 29 | -------------------------------------------------------------------------------- /utils/dataprocessor.py: -------------------------------------------------------------------------------- 1 | # Ref: https://github.com/elliottwu/DeepHDR 2 | import os 3 | import cv2 4 | from glob import glob 5 | import pickle 6 | import numpy as np 7 | from utils.HDRutils import * 8 | 9 | 10 | def imread(path): 11 | if path[-4:] == '.hdr': 12 | img = cv2.imread(path, -1) 13 | else: 14 | img = cv2.imread(path)/255. 
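# note: cv2.imread returns channels in BGR order; the `[..., ::-1]` on the return line below flips them to RGB, and for `.hdr` inputs the `-1` flag above (cv2.IMREAD_UNCHANGED) loads the float32 radiance values without the /255 scaling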
15 | return img.astype(np.float32)[..., ::-1] 16 | 17 | 18 | def radiance_writer(out_path, image): 19 | with open(out_path, "wb") as f: 20 | f.write(b"#?RADIANCE\n# Made with Python & Numpy\nFORMAT=32-bit_rle_rgbe\n\n") 21 | f.write(b"-Y %d +X %d\n" %(image.shape[0], image.shape[1])) 22 | 23 | brightest = np.maximum(np.maximum(image[..., 0], image[..., 1]), image[..., 2]) 24 | mantissa = np.zeros_like(brightest) 25 | exponent = np.zeros_like(brightest) 26 | np.frexp(brightest, mantissa, exponent) 27 | scaled_mantissa = mantissa * 255.0 / brightest 28 | rgbe = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8) 29 | rgbe[..., 0:3] = np.around(image[..., 0:3] * scaled_mantissa[..., None]) 30 | rgbe[..., 3] = np.around(exponent + 128) 31 | 32 | rgbe.flatten().tofile(f) 33 | 34 | 35 | def store_patch(h1, h2, w1, w2, in_LDRs, in_exps, ref_HDR, ref_LDRs, ref_exps, save_path, save_id): 36 | res = { 37 | 'in_LDR': in_LDRs, 38 | 'ref_LDR': ref_LDRs, 39 | 'ref_HDR': ref_HDR, 40 | 'in_exp': in_exps, 41 | 'ref_exp': ref_exps, 42 | } 43 | with open(save_path + '/' + str(save_id) + '.pkl', 'wb') as pkl_file: 44 | pickle.dump(res, pkl_file) 45 | 46 | 47 | def get_patch_from_file(pkl_path, pkl_id): 48 | with open(pkl_path + '/' + str(pkl_id) + '.pkl', 'rb') as pkl_file: 49 | res = pickle.load(pkl_file) 50 | return res 51 | 52 | 53 | # always return RGB, float32, range 0~1 54 | def get_image(image_path, image_size=None, is_crop=False): 55 | if is_crop: 56 | assert (image_size is not None), "the crop size must be specified" 57 | return transform(imread(image_path), image_size, is_crop) 58 | 59 | 60 | def merge(images, size): 61 | h, w = images.shape[1], images.shape[2] 62 | img = np.zeros((int(h * size[0]), int(w * size[1]), 3)) 63 | for idx, image in enumerate(images): 64 | i = idx % size[1] 65 | j = idx // size[1] 66 | img[j*h:j*h+h, i*w:i*w+w, :] = image 67 | return img 68 | 69 | 70 | def center_crop(x, image_size): 71 | crop_h, crop_w = image_size 72 | h, w = x.shape[:2] 73 | j = int(round((h - crop_h)/2.)) 74 | i = int(round((w - crop_w)/2.)) 75 | return cv2.resize(x[max(0, j):min(h, j+crop_h), max(0, i):min(w, i+crop_w)], (crop_w, crop_h)) 76 | 77 | 78 | def transform(image, image_size, is_crop): 79 | if is_crop: 80 | out = center_crop(image, image_size) 81 | elif image_size is not None: 82 | out = cv2.resize(image, image_size) 83 | else: 84 | out = image 85 | out = out*2. - 1 86 | return out.astype(np.float32) 87 | 88 | 89 | def inverse_transform(images): 90 | return (images + 1) / 2 91 | 92 | 93 | # get input 94 | def get_input(LDR_path, exp_path, ref_HDR_path): 95 | in_LDR_paths = sorted(glob(LDR_path)) 96 | ns = len(in_LDR_paths) 97 | tmp_img = cv2.imread(in_LDR_paths[0]).astype(np.float32) 98 | h, w, c = tmp_img.shape 99 | h = h // 16 * 16 100 | w = w // 16 * 16 101 | 102 | in_exps = np.array(open(exp_path).read().split('\n')[:ns]).astype(np.float32) 103 | in_LDRs = np.zeros((h, w, c * ns), dtype=np.float32) 104 | in_HDRs = np.zeros((h, w, c * ns), dtype=np.float32) 105 | 106 | for i, image_path in enumerate(in_LDR_paths): 107 | img = get_image(image_path, image_size=[h, w], is_crop=True) 108 | in_LDRs[:, :, c * i:c * (i + 1)] = img 109 | in_HDRs[:, :, c * i:c * (i + 1)] = LDR2HDR(img, 2. 
** in_exps[i]) 110 | 111 | ref_HDR = get_image(ref_HDR_path, image_size=[h, w], is_crop=True) 112 | return in_LDRs, in_HDRs, in_exps, ref_HDR 113 | 114 | 115 | def dump_sample(sample_path, img): 116 | img = img[0] 117 | h, w, _ = img.shape 118 | if not os.path.exists(sample_path): 119 | os.makedirs(sample_path) 120 | file_path = sample_path + '/hdr.hdr' 121 | img = inverse_transform(img) 122 | img = np.einsum('ijk->jki', img) 123 | radiance_writer(file_path, img) 124 | -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | # Ref: https://github.com/elliottwu/DeepHDR 2 | import os 3 | import cv2 4 | from glob import glob 5 | import pickle 6 | import numpy as np 7 | from utils.HDRutils import * 8 | 9 | 10 | def imread(path): 11 | if path[-4:] == '.hdr': 12 | img = cv2.imread(path, -1) 13 | else: 14 | img = cv2.imread(path)/255. 15 | return img.astype(np.float32)[..., ::-1] 16 | 17 | 18 | def imsave(images, size, path): 19 | if path[-4:] == '.hdr': 20 | return radiance_writer(path, merge(images, size)) 21 | else: 22 | return cv2.imwrite(path, merge(images, size)[..., ::-1]*255.) 23 | 24 | 25 | def radiance_writer(out_path, image): 26 | with open(out_path, "wb") as f: 27 | f.write(b"#?RADIANCE\n# Made with Python & Numpy\nFORMAT=32-bit_rle_rgbe\n\n") 28 | f.write(b"-Y %d +X %d\n" %(image.shape[0], image.shape[1])) 29 | 30 | brightest = np.maximum(np.maximum(image[..., 0], image[..., 1]), image[..., 2]) 31 | mantissa = np.zeros_like(brightest) 32 | exponent = np.zeros_like(brightest) 33 | np.frexp(brightest, mantissa, exponent) 34 | scaled_mantissa = mantissa * 255.0 / brightest 35 | rgbe = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8) 36 | rgbe[..., 0:3] = np.around(image[..., 0:3] * scaled_mantissa[..., None]) 37 | rgbe[..., 3] = np.around(exponent + 128) 38 | 39 | rgbe.flatten().tofile(f) 40 | 41 | 42 | def store_patch(h1, h2, w1, w2, in_LDRs, in_exps, ref_HDR, ref_LDRs, ref_exps, save_path, save_id): 43 | in_LDRs_patch = in_LDRs[h1:h2, w1:w2, :] 44 | in_LDRs_patch_1 = in_LDRs_patch[:, :, 2::-1] 45 | in_LDRs_patch_2 = in_LDRs_patch[:, :, 5:2:-1] 46 | in_LDRs_patch_3 = in_LDRs_patch[:, :, 8:5:-1] 47 | in_LDRs_patch = np.concatenate([in_LDRs_patch_1, in_LDRs_patch_2, in_LDRs_patch_3], axis=2) 48 | ref_HDR_patch = ref_HDR[h1:h2, w1:w2, ::-1] 49 | ref_LDRs_patch = ref_LDRs[h1:h2, w1:w2, :] 50 | ref_LDRs_patch_1 = ref_LDRs_patch[:, :, 2::-1] 51 | ref_LDRs_patch_2 = ref_LDRs_patch[:, :, 5:2:-1] 52 | ref_LDRs_patch_3 = ref_LDRs_patch[:, :, 8:5:-1] 53 | ref_LDRs_patch = np.concatenate([ref_LDRs_patch_1, ref_LDRs_patch_2, ref_LDRs_patch_3], axis=2) 54 | 55 | res = { 56 | 'in_LDR': in_LDRs_patch, 57 | 'ref_LDR': ref_LDRs_patch, 58 | 'ref_HDR': ref_HDR_patch, 59 | 'in_exp': in_exps, 60 | 'ref_exp': ref_exps, 61 | } 62 | with open(save_path + '/' + str(save_id) + '.pkl', 'wb') as pkl_file: 63 | pickle.dump(res, pkl_file) 64 | 65 | 66 | def get_patch_from_file(pkl_path, pkl_id): 67 | with open(pkl_path + '/' + str(pkl_id) + '.pkl', 'rb') as pkl_file: 68 | res = pickle.load(pkl_file) 69 | return res 70 | 71 | 72 | # always return RGB, float32, range 0~1 73 | def get_image(image_path, image_size=None, is_crop=False): 74 | if is_crop: 75 | assert (image_size is not None), "the crop size must be specified" 76 | return transform(imread(image_path), image_size, is_crop) 77 | 78 | 79 | def save_images(images, size, image_path): 80 | return 
imsave(inverse_transform(images), size, image_path) 81 | 82 | 83 | def merge_images(images, size): 84 | return inverse_transform(images) 85 | 86 | 87 | def merge(images, size): 88 | h, w = images.shape[1], images.shape[2] 89 | img = np.zeros((int(h * size[0]), int(w * size[1]), 3)) 90 | for idx, image in enumerate(images): 91 | i = idx % size[1] 92 | j = idx // size[1] 93 | img[j*h:j*h+h, i*w:i*w+w, :] = image 94 | return img 95 | 96 | 97 | def center_crop(x, image_size): 98 | crop_h, crop_w = image_size 99 | h, w = x.shape[:2] 100 | j = int(round((h - crop_h)/2.)) 101 | i = int(round((w - crop_w)/2.)) 102 | return cv2.resize(x[max(0, j):min(h, j+crop_h), max(0, i):min(w, i+crop_w)], (crop_w, crop_h)) 103 | 104 | 105 | def transform(image, image_size, is_crop): 106 | if is_crop: 107 | out = center_crop(image, image_size) 108 | elif image_size is not None: 109 | out = cv2.resize(image, image_size) 110 | else: 111 | out = image 112 | out = out*2. - 1 113 | return out.astype(np.float32) 114 | 115 | 116 | def inverse_transform(images): 117 | return (images + 1) / 2 118 | 119 | 120 | # get input 121 | def get_input(LDR_path, exp_path, ref_HDR_path): 122 | in_LDR_paths = sorted(glob(LDR_path)) 123 | ns = len(in_LDR_paths) 124 | tmp_img = cv2.imread(in_LDR_paths[0]).astype(np.float32) 125 | h, w, c = tmp_img.shape 126 | resize = 16 127 | h = h // 16 * 16 128 | w = w // 16 * 16 129 | # h = h // 8 * 8 130 | # w = w // 8 * 8 131 | 132 | in_exps = np.array(open(exp_path).read().split('\n')[:ns]).astype(np.float32) 133 | in_LDRs = np.zeros((h, w, c * ns), dtype=np.float32) 134 | in_HDRs = np.zeros((h, w, c * ns), dtype=np.float32) 135 | 136 | for i, image_path in enumerate(in_LDR_paths): 137 | img = get_image(image_path, image_size=[h, w], is_crop=True) 138 | in_LDRs[:, :, c * i:c * (i + 1)] = img 139 | in_HDRs[:, :, c * i:c * (i + 1)] = LDR2HDR(img, 2. 
** in_exps[i]) 140 | 141 | ref_HDR = get_image(ref_HDR_path, image_size=[h, w], is_crop=True) 142 | return in_LDRs, in_HDRs, in_exps, ref_HDR 143 | 144 | 145 | def dump_sample(sample_path, img): 146 | img = img[0] 147 | h, w, _ = img.shape 148 | if not os.path.exists(sample_path): 149 | os.makedirs(sample_path) 150 | file_path = sample_path + '/hdr.hdr' 151 | img = inverse_transform(img) 152 | img = np.einsum('ijk->jki', img) 153 | radiance_writer(file_path, img) 154 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | 5 | 6 | class HDRLoss(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def forward(self, out_img, ref_img): 11 | return torch.mean((out_img - ref_img) ** 2) 12 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import cv2 4 | 5 | 6 | class PSNR(): 7 | def __init__(self, range=1): 8 | self.range = range 9 | 10 | def __call__(self, img1, img2): 11 | mse = np.mean((img1 - img2) ** 2) 12 | return 20 * math.log10(self.range / math.sqrt(mse)) 13 | 14 | 15 | class SSIM(): 16 | def __init__(self, range=1): 17 | self.range = range 18 | 19 | def __call__(self, img1, img2): 20 | if not img1.shape == img2.shape: 21 | raise ValueError("Input images must have the same dimensions.") 22 | if img1.ndim == 2: # Grey or Y-channel image 23 | return self._ssim(img1, img2) 24 | elif img1.ndim == 3: 25 | if img1.shape[2] == 3: 26 | ssims = [] 27 | for i in range(3): 28 | ssims.append(self._ssim(img1, img2)) 29 | return np.array(ssims).mean() 30 | elif img1.shape[2] == 1: 31 | return self._ssim(np.squeeze(img1), np.squeeze(img2)) 32 | else: 33 | raise ValueError("Wrong input image dimensions.") 34 | 35 | def _ssim(self, img1, img2): 36 | C1 = (0.01 * self.range) ** 2 37 | C2 = (0.03 * self.range) ** 2 38 | 39 | img1 = img1.astype(np.float64) 40 | img2 = img2.astype(np.float64) 41 | kernel = cv2.getGaussianKernel(11, 1.5) 42 | window = np.outer(kernel, kernel.transpose()) 43 | 44 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 45 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 46 | mu1_sq = mu1 ** 2 47 | mu2_sq = mu2 ** 2 48 | mu1_mu2 = mu1 * mu2 49 | sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq 50 | sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq 51 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 52 | 53 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( 54 | (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) 55 | ) 56 | return ssim_map.mean() 57 | -------------------------------------------------------------------------------- /utils/solvers.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import LambdaLR 2 | 3 | 4 | class LambdaStepLR(LambdaLR): 5 | def __init__(self, optimizer, lr_lambda, last_step=-1): 6 | super(LambdaStepLR, self).__init__(optimizer, lr_lambda, last_step) 7 | 8 | @property 9 | def last_step(self): 10 | return self.last_epoch 11 | 12 | @last_step.setter 13 | def last_step(self, v): 14 | self.last_epoch = v 15 | 16 | 17 | class PolyLR(LambdaStepLR): 18 | def __init__(self, optimizer, max_iter, power=0.9, last_step=-1): 19 | 
super(PolyLR, self).__init__(optimizer, lambda s: (1 - s / (max_iter + 1))**power, last_step) 20 | 21 | 22 | class SquaredLR(LambdaStepLR): 23 | def __init__(self, optimizer, max_iter, last_step=-1): 24 | super(SquaredLR, self).__init__(optimizer, lambda s: (1 - s / (max_iter + 1))**2, last_step) --------------------------------------------------------------------------------
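For reference, a minimal, self-contained sketch of how these schedulers are driven, mirroring the per-epoch usage in `train.py` (run from the repository root; the tiny model and optimizer below are placeholders, not part of the repository):

```python
import torch
from torch import optim
from utils.solvers import PolyLR

# Placeholder model/optimizer, only here so the scheduler has something to drive.
model = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# As in train.py: polynomial decay over `epoch` steps, calling step() once per epoch.
scheduler = PolyLR(optimizer, max_iter=50, power=0.9)

for epoch in range(50):
    # ... train_one_epoch() / eval_one_epoch() would run here ...
    optimizer.step()      # stands in for the real per-batch updates
    scheduler.step()
    print(epoch, scheduler.get_last_lr()[0])
```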