├── unsupervised_hdr ├── tools │ ├── __init__.py │ ├── evaluation.py │ ├── io.py │ └── helper_functions.py ├── __init__.py ├── data │ ├── __init__.py │ └── dataset.py ├── models │ ├── loss.py │ ├── __init__.py │ ├── network_blocks.py │ ├── encoder.py │ └── decoder.py └── core.py ├── requirements.txt ├── resources ├── hdr_diff.jpg ├── 2fstop_low_gt.jpg ├── 4fstop_low_gt.jpg ├── ground_truth.jpg ├── 2fstop_high_diff.jpg ├── 2fstop_high_gt.jpg ├── 2fstop_low_diff.jpg ├── 4fstop_high_diff.jpg ├── 4fstop_high_gt.jpg ├── 4fstop_low_diff.jpg ├── 2fstop_low_predict.jpg ├── 4fstop_low_predict.jpg ├── 2fstop_high_predict.jpg ├── 4fstop_high_predict.jpg └── opencv_reconstructed.jpg ├── setup.py ├── LICENSE.txt ├── .gitignore └── README.md /unsupervised_hdr/tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.9.0 2 | timm>=0.5.4 3 | tqdm>=4.61.2 -------------------------------------------------------------------------------- /unsupervised_hdr/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import UnsupervisedHDRModel 2 | -------------------------------------------------------------------------------- /unsupervised_hdr/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import LDRDataset 2 | -------------------------------------------------------------------------------- /unsupervised_hdr/tools/evaluation.py: -------------------------------------------------------------------------------- 1 | # TODO: Reimplement PU-PSNR, PU-SSIM, and HDR-VDP-2.2 2 | -------------------------------------------------------------------------------- /resources/hdr_diff.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tattaka/unsupervised-hdr-imaging/HEAD/resources/hdr_diff.jpg -------------------------------------------------------------------------------- /resources/2fstop_low_gt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tattaka/unsupervised-hdr-imaging/HEAD/resources/2fstop_low_gt.jpg -------------------------------------------------------------------------------- /resources/4fstop_low_gt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tattaka/unsupervised-hdr-imaging/HEAD/resources/4fstop_low_gt.jpg -------------------------------------------------------------------------------- /resources/ground_truth.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tattaka/unsupervised-hdr-imaging/HEAD/resources/ground_truth.jpg -------------------------------------------------------------------------------- /resources/2fstop_high_diff.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tattaka/unsupervised-hdr-imaging/HEAD/resources/2fstop_high_diff.jpg -------------------------------------------------------------------------------- /resources/2fstop_high_gt.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/tattaka/unsupervised-hdr-imaging/HEAD/resources/2fstop_high_gt.jpg -------------------------------------------------------------------------------- /resources/2fstop_low_diff.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tattaka/unsupervised-hdr-imaging/HEAD/resources/2fstop_low_diff.jpg -------------------------------------------------------------------------------- /resources/4fstop_high_diff.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tattaka/unsupervised-hdr-imaging/HEAD/resources/4fstop_high_diff.jpg -------------------------------------------------------------------------------- /resources/4fstop_high_gt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tattaka/unsupervised-hdr-imaging/HEAD/resources/4fstop_high_gt.jpg -------------------------------------------------------------------------------- /resources/4fstop_low_diff.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tattaka/unsupervised-hdr-imaging/HEAD/resources/4fstop_low_diff.jpg -------------------------------------------------------------------------------- /resources/2fstop_low_predict.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tattaka/unsupervised-hdr-imaging/HEAD/resources/2fstop_low_predict.jpg -------------------------------------------------------------------------------- /resources/4fstop_low_predict.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tattaka/unsupervised-hdr-imaging/HEAD/resources/4fstop_low_predict.jpg -------------------------------------------------------------------------------- /resources/2fstop_high_predict.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tattaka/unsupervised-hdr-imaging/HEAD/resources/2fstop_high_predict.jpg -------------------------------------------------------------------------------- /resources/4fstop_high_predict.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tattaka/unsupervised-hdr-imaging/HEAD/resources/4fstop_high_predict.jpg -------------------------------------------------------------------------------- /resources/opencv_reconstructed.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tattaka/unsupervised-hdr-imaging/HEAD/resources/opencv_reconstructed.jpg -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | 4 | def _requires_from_file(filename): 5 | return open(filename).read().splitlines() 6 | 7 | 8 | setup( 9 | name="unsupervised_hdr", 10 | version="0.1.0", 11 | license="MIT License", 12 | description="A Library for unsupervised exposure changes from LDR video", 13 | author="tattaka", 14 | url="https://github.com/tattaka/unsupervised-hdr-imaging", 15 | packages=find_packages(), 16 | install_requires=_requires_from_file("requirements.txt"), 17 | ) 18 | 
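A small note on `_requires_from_file` in `setup.py` above: it forwards every raw line of `requirements.txt` to `install_requires`, so any blank line or `#` comment added to that file later would be passed along verbatim. A slightly more defensive variant (just a sketch, not part of the repository) could look like this:

```python
def _requires_from_file(filename: str) -> list:
    # Read requirements.txt, skipping blank lines and comment lines so that
    # only real requirement specifiers reach install_requires.
    with open(filename) as f:
        return [
            line.strip()
            for line in f
            if line.strip() and not line.strip().startswith("#")
        ]
```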
-------------------------------------------------------------------------------- /unsupervised_hdr/models/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class ImageSpaceLoss(nn.Module): 7 | def __init__(self, lam: int = 5) -> None: 8 | super().__init__() 9 | self.lam = lam 10 | self.l1_loss = nn.L1Loss() 11 | # TODO: Try revisiting L1 loss(https://arxiv.org/abs/2201.10084) 12 | 13 | def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: 14 | loss = self.l1_loss(x1, x2) + self.lam * ( 15 | 1 - torch.mean(F.cosine_similarity(x1, x2, dim=1)) 16 | ) 17 | return loss 18 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright 2022 tattaka666@gmail.com 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
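A quick note on `ImageSpaceLoss` from `unsupervised_hdr/models/loss.py` above: it combines a pixel-wise L1 term with a color-direction term, `l1(x1, x2) + lam * (1 - mean(cosine_similarity(x1, x2, dim=1)))`, so identical images give zero loss and a pure brightness rescaling is penalized only through the L1 term. A tiny sanity check with hypothetical random tensors:

```python
import torch

from unsupervised_hdr.models.loss import ImageSpaceLoss

criterion = ImageSpaceLoss(lam=5)
x = torch.rand(2, 3, 64, 64)

print(criterion(x, x).item())        # ~0.0: both terms vanish for identical images
print(criterion(x, 0.5 * x).item())  # only the L1 term contributes; the per-pixel color
                                     # direction is unchanged, so cosine similarity stays 1
```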
8 | -------------------------------------------------------------------------------- /unsupervised_hdr/models/__init__.py: -------------------------------------------------------------------------------- 1 | import timm 2 | from torch import nn 3 | 4 | from .decoder import SimpleDecoder 5 | from .encoder import SimpleEncoder 6 | from .loss import ImageSpaceLoss 7 | 8 | encoder_factory = {"SimpleEncoder": SimpleEncoder} 9 | decoder_factory = {"SimpleDecoder": SimpleDecoder} 10 | 11 | 12 | class EncoderDecoderModel(nn.Module): 13 | def __init__( 14 | self, encoder: str, decoder: str, encoder_pretrained: bool = False 15 | ) -> None: 16 | super().__init__() 17 | if encoder == "SimpleEncoder": 18 | self.encoder = SimpleEncoder() 19 | else: 20 | self.encoder = timm.create_model( 21 | encoder, 22 | features_only=True, 23 | out_indices=(0, 1, 2, 3), 24 | pretrained=encoder_pretrained, 25 | ) 26 | if decoder == "SimpleDecoder": 27 | self.decoder = SimpleDecoder(feature_info=self.encoder.feature_info) 28 | 29 | def forward(self, x): 30 | return self.decoder(self.encoder(x)) 31 | -------------------------------------------------------------------------------- /unsupervised_hdr/tools/io.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from typing import List 4 | 5 | import cv2 6 | import numpy as np 7 | from tqdm.auto import tqdm 8 | 9 | 10 | def write_hdr_images(dir_path: str, image_list: List[np.ndarray]) -> None: 11 | os.makedirs(dir_path, exist_ok=True) 12 | for i, image in enumerate(tqdm(image_list)): 13 | cv2.imwrite( 14 | os.path.join(dir_path, str(i).zfill(len(str(len(image_list)))) + ".hdr"), 15 | image, 16 | ) 17 | 18 | 19 | def write_exr_images(dir_path: str, image_list: List[np.ndarray]) -> None: 20 | os.makedirs(dir_path, exist_ok=True) 21 | raise NotImplementedError 22 | 23 | 24 | def write_hdr_to_mp4(video_path: str, image_list: List[np.ndarray], fps: float) -> None: 25 | image_size = image_list[0].shape[:2][::-1] 26 | fourcc = cv2.VideoWriter_fourcc("m", "p", "4", "v") 27 | video = cv2.VideoWriter(video_path, fourcc, fps, image_size) 28 | if not video.isOpened(): 29 | print("can't be opened video file") 30 | sys.exit() 31 | for image in tqdm(image_list): 32 | image = (image * 255).clip(0, 255).astype(np.uint8) 33 | video.write(image) 34 | video.release() 35 | -------------------------------------------------------------------------------- /unsupervised_hdr/data/dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from torch.utils import data 7 | 8 | 9 | def exposure(img: torch.Tensor, v: float) -> torch.Tensor: 10 | return (img * (2**v) ** (1 / 2.2)).clamp(0, 255).to(torch.uint8).to(torch.float32) 11 | 12 | 13 | class LDRDataset(data.Dataset): 14 | def __init__( 15 | self, video_path: str, train: bool = False, image_size: Tuple[int, int] = None 16 | ) -> None: 17 | self.video_path = video_path 18 | cap = cv2.VideoCapture(self.video_path) 19 | if not cap.isOpened(): 20 | raise FileNotFoundError 21 | self.cap_frame_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 22 | self.train = train 23 | self.image_size = image_size 24 | 25 | def __len__(self) -> int: 26 | return self.cap_frame_num 27 | 28 | def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: 29 | cap = cv2.VideoCapture(self.video_path) 30 | cap.set(cv2.CAP_PROP_POS_FRAMES, index) 31 | ret, frame = cap.read() 32 | if 
self.image_size is not None: 33 | frame = cv2.resize(frame, dsize=self.image_size) 34 | frame = torch.as_tensor(frame.transpose(2, 0, 1), dtype=torch.float32) 35 | if self.train: 36 | s = np.random.rand() * 0.25 37 | else: 38 | s = 0 39 | frame = exposure(frame, s) 40 | frame_high = exposure(frame, 2 + s) 41 | return {"Ib": frame / 255, "Ih": frame_high / 255} 42 | -------------------------------------------------------------------------------- /unsupervised_hdr/models/network_blocks.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import List 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | @dataclass 9 | class FeatureInfo: 10 | c: List[int] = field(default_factory=list) 11 | r: List[int] = field(default_factory=list) 12 | 13 | def channels(self): 14 | return self.c 15 | 16 | def reduction(self): 17 | return self.r 18 | 19 | 20 | class ConcatPool2d(nn.Module): 21 | def __init__( 22 | self, 23 | kernel_size: int, 24 | stride: int = None, 25 | padding: int = 0, 26 | ceil_mode: bool = False, 27 | ): 28 | super().__init__() 29 | self.mp = nn.MaxPool2d( 30 | kernel_size=kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode 31 | ) 32 | self.ap = nn.AvgPool2d( 33 | kernel_size=kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode 34 | ) 35 | 36 | def forward(self, x): 37 | return torch.cat([self.mp(x), self.ap(x)], 1) 38 | 39 | 40 | class Conv2dReLU(nn.Module): 41 | def __init__( 42 | self, 43 | in_channels: int, 44 | out_channels: int, 45 | kernel_size: int, 46 | padding: int = 0, 47 | stride: int = 1, 48 | use_batchnorm: bool = False, 49 | **batchnorm_params 50 | ): 51 | 52 | super().__init__() 53 | 54 | layers = [ 55 | nn.Conv2d( 56 | in_channels, 57 | out_channels, 58 | kernel_size, 59 | stride=stride, 60 | padding=padding, 61 | bias=not (use_batchnorm), 62 | ), 63 | nn.ReLU(inplace=True), 64 | ] 65 | 66 | if use_batchnorm: 67 | layers.insert(1, nn.BatchNorm2d(out_channels, **batchnorm_params)) 68 | 69 | self.block = nn.Sequential(*layers) 70 | 71 | def forward(self, x): 72 | return self.block(x) 73 | -------------------------------------------------------------------------------- /unsupervised_hdr/models/encoder.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from .network_blocks import ConcatPool2d, Conv2dReLU, FeatureInfo 7 | 8 | 9 | class SimpleEncoder(nn.Module): 10 | def __init__(self) -> None: 11 | super().__init__() 12 | self.conv = nn.ModuleList() 13 | self.feature_info = FeatureInfo() 14 | self.conv.append( 15 | Conv2dReLU( 16 | in_channels=3, 17 | out_channels=64, 18 | kernel_size=3, 19 | stride=1, 20 | padding=1, 21 | use_batchnorm=False, 22 | ) 23 | ) 24 | self.feature_info.c.append(64) 25 | self.feature_info.r.append(1) 26 | self.conv.append( 27 | nn.Sequential( 28 | ConcatPool2d(kernel_size=2), 29 | Conv2dReLU( 30 | in_channels=128, 31 | out_channels=128, 32 | kernel_size=3, 33 | stride=1, 34 | padding=1, 35 | use_batchnorm=False, 36 | ), 37 | ) 38 | ) 39 | self.feature_info.c.append(128) 40 | self.feature_info.r.append(2) 41 | self.conv.append( 42 | nn.Sequential( 43 | ConcatPool2d(kernel_size=2), 44 | Conv2dReLU( 45 | in_channels=256, 46 | out_channels=256, 47 | kernel_size=3, 48 | stride=1, 49 | padding=1, 50 | use_batchnorm=False, 51 | ), 52 | ) 53 | ) 54 | self.feature_info.c.append(256) 55 | self.feature_info.r.append(4) 56 | 
self.conv.append( 57 | nn.Sequential( 58 | ConcatPool2d(kernel_size=2), 59 | Conv2dReLU( 60 | in_channels=512, 61 | out_channels=512, 62 | kernel_size=3, 63 | stride=1, 64 | padding=1, 65 | use_batchnorm=False, 66 | ), 67 | ) 68 | ) 69 | self.feature_info.c.append(512) 70 | self.feature_info.r.append(8) 71 | 72 | def forward(self, x: torch.Tensor) -> List[torch.Tensor]: 73 | y = [x] 74 | for m in self.conv: 75 | y.append(m(y[-1])) 76 | del y[0] 77 | return y 78 | 79 | 80 | # if __name__ == "__main__": 81 | # import torch 82 | 83 | # img = torch.randn(2, 3, 128, 128) 84 | # e = SimpleEncoder() 85 | # for y in e(img): 86 | # print(y.shape) 87 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | .ipynb 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | 135 | # pytype static type analyzer 136 | .pytype/ 137 | 138 | # Cython debug symbols 139 | cython_debug/ -------------------------------------------------------------------------------- /unsupervised_hdr/models/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from .network_blocks import Conv2dReLU, FeatureInfo 6 | 7 | 8 | class SimpleDecoder(nn.Module): 9 | def __init__(self, feature_info: FeatureInfo) -> None: 10 | super().__init__() 11 | self.encoder_channel = feature_info.channels() 12 | self.encoder_reduction = feature_info.reduction() 13 | self.conv = nn.ModuleList() 14 | self.conv.append( 15 | Conv2dReLU( 16 | in_channels=self.encoder_channel[-1], 17 | out_channels=self.encoder_channel[-2], 18 | kernel_size=3, 19 | stride=1, 20 | padding=1, 21 | use_batchnorm=False, 22 | ), 23 | ) 24 | self.conv.append( 25 | Conv2dReLU( 26 | in_channels=self.encoder_channel[-2] * 2, 27 | out_channels=self.encoder_channel[-3], 28 | kernel_size=3, 29 | stride=1, 30 | padding=1, 31 | use_batchnorm=False, 32 | ), 33 | ) 34 | self.conv.append( 35 | Conv2dReLU( 36 | in_channels=self.encoder_channel[-3] * 2, 37 | out_channels=self.encoder_channel[-4], 38 | kernel_size=3, 39 | stride=1, 40 | padding=1, 41 | use_batchnorm=False, 42 | ), 43 | ) 44 | self.last_conv = nn.Sequential( 45 | Conv2dReLU( 46 | in_channels=self.encoder_channel[-4] * 2, 47 | out_channels=self.encoder_channel[-4] * 2, 48 | kernel_size=3, 49 | stride=1, 50 | padding=1, 51 | use_batchnorm=False, 52 | ), 53 | nn.Conv2d( 54 | in_channels=self.encoder_channel[-4] * 2, out_channels=3, kernel_size=1 55 | ), 56 | nn.Sigmoid(), 57 | ) 58 | 59 | def forward(self, x: torch.Tensor) -> torch.Tensor: 60 | feat = x[-1] 61 | for i, m in enumerate(self.conv): 62 | feat = F.interpolate(m(feat), scale_factor=2, mode="bilinear") 63 | feat = torch.cat([feat, x[-i - 2]], dim=1) 64 | if self.encoder_reduction[0] > 1: 65 | feat = F.interpolate( 66 | feat, scale_factor=self.encoder_reduction[0], mode="bilinear" 67 | ) 68 | return self.last_conv(feat) 69 | 70 | 71 | # if __name__ == "__main__": 72 | # from encoder import SimpleEncoder 73 | 74 | # img = torch.randn(2, 3, 128, 128) 75 | # e = SimpleEncoder() 76 | # e_out = e(img) 77 | # d = SimpleDecoder(e.feature_info) 78 | # d_out = d(e_out) 79 | # print(d_out.shape) 80 | -------------------------------------------------------------------------------- /unsupervised_hdr/tools/helper_functions.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Union 2 | 3 | import numpy as np 4 | import torch 5 | 6 | TORCH_MAJOR = int(torch.__version__.split(".")[0]) 7 | TORCH_MINOR = int(torch.__version__.split(".")[1]) 8 | if TORCH_MAJOR == 1 and TORCH_MINOR < 8: 9 | from torch._six import container_abcs, int_classes, string_classes 10 | 
else: 11 | import collections.abc as container_abcs 12 | 13 | int_classes = int 14 | string_classes = str 15 | 16 | 17 | def input2dict(input: Any, default_key: str) -> Dict[str, Any]: 18 | processed = {} 19 | if not isinstance(input, dict) and input is not None: 20 | processed[default_key] = input 21 | return processed 22 | else: 23 | return input 24 | 25 | 26 | def input2list(input: Any) -> List[Any]: 27 | if input is not None: 28 | if not isinstance(input, list): 29 | processed = [input] 30 | return processed 31 | else: 32 | return input 33 | else: 34 | return [] 35 | 36 | 37 | def list2dict(input: List[Any]) -> Dict[str, Any]: 38 | if isinstance(input, list): 39 | output = {i: v for i, v in enumerate(input)} 40 | return output 41 | else: 42 | raise TypeError("input is not a list!") 43 | 44 | 45 | def concat_data(all_data): 46 | elem = all_data[0] 47 | # Concatenate a list of samples along the batch dimension, handling tensors, 48 | # numpy arrays, scalars, mappings, and sequences recursively. 49 | if isinstance(elem, torch.Tensor): 50 | out = None 51 | if torch.utils.data.get_worker_info() is not None: 52 | # If we're in a background process, concatenate directly into a 53 | # shared memory tensor to avoid an extra copy 54 | numel = sum([x.numel() for x in all_data]) 55 | storage = elem.storage()._new_shared(numel) 56 | out = elem.new(storage) 57 | try: 58 | return torch.cat(all_data, out=out) 59 | except RuntimeError: 60 | return all_data 61 | elif ( 62 | type(elem).__module__ == "numpy" 63 | and type(elem).__name__ != "str_" 64 | and type(elem).__name__ != "string_" 65 | ): 66 | if type(elem).__name__ == "ndarray": 67 | # array of string classes and object 68 | return np.concatenate(all_data) 69 | elif elem.shape == (): # scalars 70 | return all_data 71 | elif isinstance(elem, float): 72 | return all_data 73 | elif isinstance(elem, int_classes): 74 | return all_data 75 | elif isinstance(elem, string_classes): 76 | return all_data 77 | elif isinstance(elem, container_abcs.Mapping): 78 | # e.g. a list of dicts such as [{1: torch.Tensor(), 2: torch.Tensor()}, {1: torch.Tensor(), 2: torch.Tensor()}, ...] 79 | return {key: concat_data([d[key] for d in all_data]) for key in elem} 80 | elif isinstance(elem, container_abcs.Sequence): 81 | # [[[1, 2, 3, 4], [1, 2, 3, 4]], [[1, 2, 3, 4], [1, 2, 3, 4]]] 82 | # -> [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]] 83 | result = [] 84 | [result.extend(d) for d in all_data] 85 | return result 86 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Unsupervised HDR Imaging 2 | ===== 3 | 4 | This repository is an unofficial PyTorch implementation of ["Unsupervised HDR Imaging: What Can Be Learned from a Single 8-bit Video?"](https://arxiv.org/abs/2202.05522). 5 | The main idea of the paper is that a single video already contains a variety of exposures, so a network can be trained to map higher exposures to lower ones. 6 | 7 | Unlike the original paper, this implementation does not sample frames during training and uses avg+max pooling instead of mixed pooling.
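To make the training signal concrete, the sketch below (hypothetical tensors; it mirrors `exposure()` in `unsupervised_hdr/data/dataset.py` and the ratio target used in `core.py`) shows how a darker/brighter pair is synthesized from a single LDR frame:

```python
import torch

from unsupervised_hdr.data.dataset import exposure

frame = torch.rand(3, 64, 64) * 255       # one LDR frame in [0, 255] (any content works for illustration)
Ib = exposure(frame, 0.0)                 # "base" exposure (shift s = 0, as at prediction time)
Ih = exposure(Ib, 2.0)                    # the same frame pushed 2 f-stops brighter
Ib, Ih = Ib / 255, Ih / 255
target_delta = (Ib + 1e-6) / (Ih + 1e-6)  # per-pixel ratio the network is trained to predict from Ih
```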
8 | 9 | ## Demo 10 | ||-4 fstop|-2 fstop|0 fstop|2 fstop|4 fstop| 11 | |---|---|---|---|---|---| 12 | |Ground Truth|![4fstop_low_gt.jpg](resources/4fstop_low_gt.jpg)|![2fstop_low_gt.jpg](resources/2fstop_low_gt.jpg)|![ground_truth.jpg](resources/ground_truth.jpg)|![2fstop_high_gt.jpg](resources/2fstop_high_gt.jpg)|![4fstop_high_gt.jpg](resources/4fstop_high_gt.jpg)| 13 | |Predict|![4fstop_low_predict.jpg](resources/4fstop_low_predict.jpg)|![2fstop_low_predict.jpg](resources/2fstop_low_predict.jpg)|![opencv_reconstructed.jpg](resources/opencv_reconstructed.jpg)|![2fstop_high_predict.jpg](resources/2fstop_high_predict.jpg)|![4fstop_high_predict.jpg](resources/4fstop_high_predict.jpg)| 14 | |Diff|![4fstop_low_diff.jpg](resources/4fstop_low_diff.jpg)|![2fstop_low_diff.jpg](resources/2fstop_low_diff.jpg)|![hdr_diff.jpg](resources/hdr_diff.jpg)|![2fstop_high_diff.jpg](resources/2fstop_high_diff.jpg)|![4fstop_high_diff.jpg](resources/4fstop_high_diff.jpg)| 15 | 16 | These images were generated using [example/christmas_tree.ipynb](example/christmas_tree.ipynb). 17 | `0 fstop (Predict)` is reconstructed with OpenCV's MergeMertens method. 18 | ## Usage 19 | ``` python 20 | from unsupervised_hdr import UnsupervisedHDRModel 21 | model = UnsupervisedHDRModel( 22 | video_path=VIDEO_PATH, # only mp4 format is supported 23 | encoder_lr=1e-4, 24 | decoder_lr=1e-4, 25 | num_worker=1, 26 | device_ids=0, # int (device id), None (use all GPUs), or list of device ids, e.g. [0, 1] 27 | output_dir=LOGDIR_PATH) 28 | model.fit(max_epoch=64, batch_size=1) 29 | out = model.predict( 30 | frame_idx=None, # None (all frames), int (a single frame index), or list of frame indices, e.g. [0, 100, 200] 31 | batch_size=1) 32 | # out: { 33 | # "hdr_image": HDR frames reconstructed with OpenCV's MergeMertens, shape (frames, h, w, c), 34 | # "exposure_list": [-4 fstop, -2 fstop, 0 fstop (input), 2 fstop, 4 fstop] 35 | # } 36 | ``` 37 | ## TODO 38 | * [x] Multi-GPU training 39 | * [ ] Early stopping 40 | * [ ] Deal with convolution artifacts (e.g. a different decoder) 41 | * [ ] More data augmentation (e.g. scale, crop) 42 | * [ ] Support other pooling types for the default model (only avg+max pooling is supported now) 43 | * [x] Support ImageNet-pretrained timm encoders 44 | * [ ] Freeze batch normalization layers 45 | * Note: input image size must be a multiple of 16 46 | * [ ] Support other pretraining methods (e.g. SimSiam) 47 | * [ ] Evaluate transfer learning 48 | 49 | ## Requirements 50 | * PyTorch>=1.9.0 51 | * timm>=0.5.4 52 | * tqdm>=4.61.2 53 | 54 | ## Installation 55 | Clone: 56 | `$ git clone https://github.com/tattaka/unsupervised-hdr-imaging.git` 57 | 58 | Install with pip: 59 | `$ cd unsupervised-hdr-imaging` 60 | `$ pip install .` 61 | 62 | ## License 63 | This repository is under the [MIT license](https://en.wikipedia.org/wiki/MIT_License).
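As a supplement to the Usage section above: the prediction output can be written back to disk with the helpers in `unsupervised_hdr/tools/io.py`. The snippet below is only a sketch (file names, fps, and a single-GPU setup are assumptions) that saves the Mertens-fused frames both as `.hdr` files and as an 8-bit mp4 preview:

```python
from unsupervised_hdr import UnsupervisedHDRModel
from unsupervised_hdr.tools import io as hdr_io

model = UnsupervisedHDRModel(video_path="input.mp4", output_dir="./logs")
model.fit(max_epoch=64, batch_size=1)
out = model.predict(batch_size=1)  # reloads ./logs/checkpoints/best_loss.pth internally

# out["hdr_image"]: float32 frames of shape (frames, h, w, 3) in BGR channel order
hdr_io.write_hdr_images("./hdr_frames", list(out["hdr_image"]))
hdr_io.write_hdr_to_mp4("./reconstructed.mp4", list(out["hdr_image"]), fps=30.0)
```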
64 | -------------------------------------------------------------------------------- /unsupervised_hdr/core.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | import re 5 | import warnings 6 | from typing import Dict, List, Tuple, Union 7 | 8 | import cv2 9 | import numpy as np 10 | import torch 11 | import torch.distributed as dist 12 | import torch.multiprocessing as mp 13 | import torch.optim as optim 14 | from torch import nn 15 | from torch.nn.parallel import DistributedDataParallel as DDP 16 | from torch.utils.data import DataLoader 17 | from torch.utils.data.distributed import DistributedSampler 18 | from tqdm.auto import tqdm 19 | 20 | from .data import LDRDataset 21 | from .models import EncoderDecoderModel, ImageSpaceLoss 22 | from .tools import helper_functions 23 | 24 | warnings.simplefilter("ignore") 25 | TORCH_MAJOR = int(torch.__version__.split(".")[0]) 26 | TORCH_MINOR = int(torch.__version__.split(".")[1]) 27 | if TORCH_MAJOR == 1 and TORCH_MINOR < 8: 28 | from torch._six import container_abcs, int_classes, string_classes 29 | else: 30 | import collections.abc as container_abcs 31 | 32 | int_classes = int 33 | string_classes = str 34 | 35 | 36 | def setup(rank, world_size): 37 | os.environ["MASTER_ADDR"] = "localhost" 38 | os.environ["MASTER_PORT"] = "12355" 39 | 40 | # initialize the process group 41 | dist.init_process_group("gloo", rank=rank, world_size=world_size) 42 | 43 | 44 | def cleanup(): 45 | dist.destroy_process_group() 46 | 47 | 48 | def seed_worker(worker_id): 49 | worker_seed = torch.initial_seed() % 2**32 50 | np.random.seed(worker_seed) 51 | random.seed(worker_seed) 52 | 53 | 54 | class UnsupervisedHDRModel: 55 | def __init__( 56 | self, 57 | video_path: str, 58 | checkpoint_path: str = None, 59 | encoder: str = "SimpleEncoder", 60 | decoder: str = "SimpleDecoder", 61 | encoder_pretrained: bool = None, 62 | encoder_lr: float = 1e-4, 63 | decoder_lr: float = 1e-4, 64 | num_worker: int = 4, 65 | device_ids: Union[str, int, List[int]] = None, 66 | output_dir: str = "./", 67 | seed: int = 0, 68 | ) -> None: 69 | self.video_path = video_path 70 | self.encoder_lr = encoder_lr 71 | self.decoder_lr = decoder_lr 72 | self.num_worker = num_worker 73 | self.output_dir = output_dir 74 | self.seed = seed 75 | self.max_epoch = None 76 | self.iterator = 0 77 | self.epoch = 0 78 | 79 | seed = 0 80 | 81 | random.seed(seed) 82 | np.random.seed(seed) 83 | torch.manual_seed(seed) 84 | torch.backends.cudnn.benchmark = False 85 | torch.backends.cudnn.deterministic = True 86 | g = torch.Generator() 87 | g.manual_seed(seed) 88 | 89 | if device_ids is None: 90 | self.device_ids = list(range(torch.cuda.device_count())) 91 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 92 | else: 93 | if device_ids == "cpu": 94 | self.device_ids = list() 95 | self.device = torch.device("cpu") 96 | else: 97 | self.device_ids = helper_functions.input2list(device_ids) 98 | self.device = torch.device(f"cuda:{self.device_ids[0]}") 99 | self.build_model(encoder, decoder, encoder_pretrained) 100 | self.configure_optimizers() 101 | self.train_dataloader = None 102 | self.predict_dataloader = None 103 | self.initialize_logger() 104 | if checkpoint_path: 105 | self.load_checkpoint(checkpoint_path) 106 | 107 | def build_model( 108 | self, encoder: str, decoder: str, encoder_pretrained: bool = False 109 | ) -> None: 110 | self.model = EncoderDecoderModel(encoder, decoder, encoder_pretrained) 111 | 
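# The MSE below supervises the predicted exposure-ratio map against (Ib + 1e-6) / (Ih + 1e-6),
# while ImageSpaceLoss compares the re-exposed frame Ih * delta with the darker frame Ib (see loss() below).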
self.mse_loss = nn.MSELoss() 112 | self.image_space_loss = ImageSpaceLoss() 113 | 114 | def build_dataset( 115 | self, 116 | rank: int, 117 | world_size: int, 118 | train: bool, 119 | batch_size: int = 1, 120 | image_size: Tuple[int, int] = (512, 512), 121 | ) -> DataLoader: 122 | # TODO: dataloader 123 | dataset = LDRDataset(self.video_path, train=train, image_size=image_size) 124 | dataloader = DataLoader( 125 | dataset, 126 | batch_size=batch_size, 127 | shuffle=(rank == -1 and train), 128 | sampler=None 129 | if rank == -1 130 | else DistributedSampler(dataset, rank=rank, num_replicas=world_size), 131 | num_workers=self.num_worker, 132 | worker_init_fn=seed_worker, 133 | pin_memory=True, 134 | drop_last=train, 135 | ) 136 | return dataloader 137 | 138 | def initialize_logger(self) -> None: 139 | self.train_results = {} 140 | self.lr = {} 141 | if not os.path.exists(self.output_dir): 142 | os.makedirs(self.output_dir) 143 | 144 | if not os.path.exists(os.path.join(self.output_dir, "checkpoints")): 145 | os.makedirs(os.path.join(self.output_dir, "checkpoints")) 146 | # TODO: save hyperparameters 147 | self.train_results["best_loss"] = float("inf") 148 | self.result_info = "" 149 | 150 | def forward(self, image: torch.Tensor) -> Dict[str, torch.Tensor]: 151 | """Forward pass. Returns the predicted exposure-ratio map and the re-exposed image.""" 152 | outputs = {} 153 | outputs["pred_delta"] = self.model(image) 154 | outputs["pred_Ib"] = image * outputs["pred_delta"] 155 | return outputs 156 | 157 | def loss( 158 | self, outputs: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor] 159 | ) -> Dict[str, torch.Tensor]: 160 | losses = {} 161 | losses["loss_delta"] = self.mse_loss( 162 | (batch["Ib"] + 1e-6) / (batch["Ih"] + 1e-6), outputs["pred_delta"] 163 | ) 164 | losses["loss_image"] = self.image_space_loss(batch["Ib"], outputs["pred_Ib"]) 165 | losses["loss"] = losses["loss_delta"] + losses["loss_image"] 166 | return losses 167 | 168 | def training_step( 169 | self, batch: Dict[str, torch.Tensor], batch_idx: int = None 170 | ) -> Dict[str, torch.Tensor]: 171 | step_output = {} 172 | outputs = self.forward(batch["Ih"]) 173 | train_loss = self.loss(outputs, batch) 174 | step_output.update(train_loss) 175 | self.optimizer.zero_grad(); train_loss["loss"].backward()  # clear stale gradients before the backward pass 176 | self.optimizer.step() 177 | return step_output 178 | 179 | def training_one_epoch(self, rank=-1, device=None) -> Dict[str, float]: 180 | self.model.train() 181 | losses = {} 182 | with tqdm( 183 | self.train_dataloader, 184 | position=0, 185 | leave=True, 186 | ascii=" ##", 187 | dynamic_ncols=True, 188 | disable=rank > 0, 189 | ) as t: 190 | for batch_idx, batch_data in enumerate(t): 191 | batch_data = self.cuda(batch_data, device=device if rank < 0 else rank) 192 | step_output = self.training_step(batch_data) 193 | t.set_description("Epoch %i Training" % self.epoch) 194 | print_losses = {} 195 | for key in step_output: 196 | print_losses[key] = step_output[key].item() 197 | if key in losses.keys(): 198 | losses[key].append(step_output[key].item()) 199 | else: 200 | losses[key] = [step_output[key].item()] 201 | t.set_postfix(ordered_dict=dict(**print_losses)) 202 | self.iterator += 1 203 | if rank <= 0: 204 | for key in losses: 205 | losses[key] = sum(losses[key]) / len(losses[key]) 206 | if len(losses.values()) > 1 and not ("loss" in losses.keys()): 207 | losses["loss"] = sum(losses.values()) 208 | for key in losses: 209 | self.train_results[key] = losses[key] 210 | if self.train_results["best_loss"] >= self.train_results["loss"]: 211 | self.train_results["best_loss"] = self.train_results["loss"] 212 |
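# A new best epoch: save_checkpoint(metric="loss") writes checkpoints/best_loss.pth, which predict() reloads later; last.pth is refreshed separately below.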
self.save_checkpoint(metric="loss") 213 | if self.epoch != 0: 214 | self.save_checkpoint() 215 | self.lr = {} 216 | self.lr["lr"] = [group["lr"] for group in self.optimizer.param_groups][0] 217 | self.result_info = "" 218 | for result_key, result_value in zip( 219 | self.train_results.keys(), self.train_results.values() 220 | ): 221 | self.result_info = ( 222 | self.result_info 223 | + result_key 224 | + ":" 225 | + str(round(result_value, 4)) 226 | + " " 227 | ) 228 | for lr_key, lr_value in zip(self.lr.keys(), self.lr.values()): 229 | self.result_info = ( 230 | self.result_info + lr_key + ":" + str(round(lr_value, 4)) + " " 231 | ) 232 | print("Epoch %i" % self.epoch, self.result_info) 233 | self.epoch += 1 234 | return losses 235 | 236 | def fit_single( 237 | self, rank: int, max_epoch: int, world_size: int, batch_size: int = 1 238 | ) -> None: 239 | # TODO: early stopping 240 | self.train_dataloader = self.build_dataset( 241 | rank=rank, 242 | world_size=world_size, 243 | train=True, 244 | batch_size=batch_size, 245 | ) 246 | use_ddp = world_size > 1 247 | if use_ddp: 248 | setup(rank, world_size) 249 | self.model = self.model.to(rank) 250 | self.model = DDP(self.model, device_ids=[rank]) 251 | else: 252 | self.model = self.model.to(self.device) 253 | self.max_epoch = max_epoch 254 | for _ in range(max_epoch): 255 | self.training_one_epoch(rank=rank, device=self.device) 256 | if use_ddp: 257 | cleanup() 258 | 259 | def fit(self, max_epoch: int, batch_size: int = 1) -> None: 260 | if len(self.device_ids) > 1: 261 | mp.spawn( 262 | self.fit_single, 263 | args=(max_epoch, len(self.device_ids), batch_size), 264 | nprocs=len(self.device_ids), 265 | join=True, 266 | ) 267 | else: 268 | self.fit_single(-1, max_epoch, len(self.device_ids), batch_size) 269 | 270 | def predict_step( 271 | self, batch: Dict[str, torch.Tensor], batch_idx: int = None 272 | ) -> Dict[str, np.ndarray]: 273 | with torch.no_grad(): 274 | delta = self.model(batch["Ib"]) # (bs, 3, h, w) 275 | Il_2 = batch["Ib"] * delta 276 | delta_2l = self.model(Il_2) 277 | Il_4 = Il_2 * delta_2l 278 | Ih_2 = batch["Ib"] / delta 279 | delta_2h = self.model(Ih_2) 280 | Ih_4 = Ih_2 / delta_2h 281 | exposure_list = [ 282 | Il_4, 283 | Il_2, 284 | batch["Ib"], 285 | Ih_2, 286 | Ih_4, 287 | ] 288 | exposure_list = np.stack( 289 | [ 290 | (img.clone().detach().cpu().numpy().clip(0, 1) * 255).astype(np.uint8) 291 | for img in exposure_list 292 | ] 293 | ).transpose( 294 | 1, 0, 3, 4, 2 295 | ) # (bs, 5, h, w, 3) 296 | 297 | merge_mertens = cv2.createMergeMertens() 298 | hdr_image = [merge_mertens.process(img_set) for img_set in exposure_list] 299 | output = {"exposure_list": exposure_list, "hdr_image": hdr_image} 300 | return output # BGR 301 | 302 | def predict( 303 | self, 304 | frame_idx: Union[int, List[int]] = None, 305 | batch_size: int = 1, 306 | image_size=None, 307 | ) -> Dict[str, np.ndarray]: 308 | self.predict_dataloader = self.build_dataset( 309 | rank=-1, 310 | world_size=1, 311 | train=False, 312 | batch_size=batch_size, 313 | image_size=image_size, 314 | ) 315 | best_checkpoint = os.path.join(self.output_dir, "checkpoints", "best_loss.pth") 316 | print(f"Start loading best checkpoint from {best_checkpoint}") 317 | self.load_checkpoint(best_checkpoint) 318 | self.cuda(self.model, device=self.device) 319 | self.model.eval() 320 | print("Finish loading!") 321 | output = {"exposure_list": [], "hdr_image": []} 322 | if frame_idx is None: 323 | for batch in tqdm(self.predict_dataloader): 324 | batch = self.cuda(batch, self.device) 
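# predict_step returns five simulated exposures per frame (-4, -2, 0, +2, +4 f-stops) and their OpenCV MergeMertens fusion, all in BGR channel order.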
325 | output_batch = self.predict_step(batch) 326 | output["exposure_list"].append(output_batch["exposure_list"]) 327 | output["hdr_image"].append(output_batch["hdr_image"]) 328 | else: 329 | frame_idx = helper_functions.input2list(frame_idx) 330 | for i in tqdm( 331 | range(math.ceil(len(frame_idx) / self.predict_dataloader.batch_size)) 332 | ): 333 | batch = [] 334 | for f in frame_idx[ 335 | i 336 | * self.predict_dataloader.batch_size : (i + 1) 337 | * self.predict_dataloader.batch_size 338 | ]: 339 | batch.append(self.predict_dataloader.dataset[f]) 340 | batch = self.cuda(batch, self.device) 341 | if len(batch) == 1: 342 | batch = {k: b.unsqueeze(0) for k, b in batch[0].items()} 343 | else: 344 | batch = helper_functions.concat_data(batch) 345 | output_batch = self.predict_step(batch) 346 | output["exposure_list"].append(output_batch["exposure_list"]) 347 | output["hdr_image"].append(output_batch["hdr_image"]) 348 | output["exposure_list"] = np.concatenate(output["exposure_list"]) 349 | output["hdr_image"] = np.concatenate(output["hdr_image"]) 350 | return output 351 | 352 | def load_checkpoint(self, checkpoint_path): 353 | self.model.load_state_dict( 354 | torch.load( 355 | checkpoint_path, 356 | map_location=torch.device("cpu"), 357 | )["model_state_dict"] 358 | ) 359 | 360 | def save_checkpoint(self, metric=None): 361 | checkpoint = {} 362 | if metric is None: 363 | file_path = os.path.join(self.output_dir, "checkpoints", "last.pth") 364 | else: 365 | file_path = os.path.join( 366 | self.output_dir, "checkpoints", "best_" + metric + ".pth" 367 | ) 368 | checkpoint = { 369 | "best_epoch": self.epoch, 370 | "best_" + metric: self.train_results["best_" + metric], 371 | } 372 | checkpoint["model_state_dict"] = ( 373 | self.model.module.state_dict() 374 | if isinstance(self.model, DDP) 375 | else self.model.state_dict() 376 | ) 377 | 378 | torch.save(checkpoint, file_path) 379 | 380 | def configure_optimizers(self) -> optim.Optimizer: 381 | 382 | self.optimizer = optim.Adam( 383 | [ 384 | {"params": self.model.encoder.parameters(), "lr": self.encoder_lr}, 385 | {"params": self.model.decoder.parameters(), "lr": self.decoder_lr}, 386 | ], 387 | lr=self.decoder_lr, 388 | weight_decay=0.0001, 389 | ) 390 | 391 | def cuda(self, x, device=None): 392 | np_str_obj_array_pattern = re.compile(r"[SaUO]") 393 | if torch.cuda.is_available(): 394 | if isinstance(x, torch.Tensor): 395 | x = x.cuda(non_blocking=True, device=device) 396 | return x 397 | elif isinstance(x, nn.Module): 398 | x = x.cuda(device=device) 399 | return x 400 | elif isinstance(x, np.ndarray): 401 | if x.shape == (): 402 | if np_str_obj_array_pattern.search(x.dtype.str) is not None: 403 | return x 404 | return self.cuda(torch.as_tensor(x), device=device) 405 | return self.cuda(torch.from_numpy(x), device=device) 406 | elif isinstance(x, float): 407 | return self.cuda(torch.tensor(x, dtype=torch.float64), device=device) 408 | elif isinstance(x, int_classes): 409 | return self.cuda(torch.tensor(x), device=device) 410 | elif isinstance(x, string_classes): 411 | return x 412 | elif isinstance(x, container_abcs.Mapping): 413 | return {key: self.cuda(x[key], device=device) for key in x} 414 | elif isinstance(x, container_abcs.Sequence): 415 | return [ 416 | self.cuda(np.array(xi), device=device) 417 | if isinstance(xi, container_abcs.Sequence) 418 | else self.cuda(xi, device=device) 419 | for xi in x 420 | ] 421 | 422 | def to_cpu(self, x): 423 | if isinstance(x, torch.Tensor) and x.device != "cpu": 424 | return 
x.clone().detach().cpu() 425 | elif isinstance(x, np.ndarray): 426 | return x 427 | elif isinstance(x, container_abcs.Mapping): 428 | return {key: self.to_cpu(x[key]) for key in x} 429 | elif isinstance(x, container_abcs.Sequence): 430 | return [self.to_cpu(xi) for xi in x] 431 | else: 432 | return x 433 | --------------------------------------------------------------------------------
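As a final usage note: since `UnsupervisedHDRModel` accepts a `checkpoint_path` and `predict()` reloads `checkpoints/best_loss.pth` from `output_dir`, a previously trained run can be reused for inference only. The sketch below (paths are placeholders, a GPU is assumed) predicts a few selected frames without retraining:

```python
from unsupervised_hdr import UnsupervisedHDRModel

model = UnsupervisedHDRModel(
    video_path="input.mp4",
    checkpoint_path="./logs/checkpoints/best_loss.pth",
    output_dir="./logs",
)
out = model.predict(
    frame_idx=[0, 100, 200],
    batch_size=1,
    image_size=(512, 512),  # must be a multiple of 16 for the default model
)
print(out["hdr_image"].shape)      # (3, 512, 512, 3): Mertens-fused frames
print(out["exposure_list"].shape)  # (3, 5, 512, 512, 3): five exposures per selected frame
```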