├── .gitignore ├── LICENSE ├── README.md ├── imgs └── title.webp ├── requirements.txt └── src ├── __init__.py ├── config ├── aug │ ├── none.yaml │ └── resize512.yaml ├── config.yaml ├── ds │ ├── test.yaml │ ├── train.yaml │ └── valid.yaml └── runtime │ ├── bilateral_upsample_net.default.yaml │ ├── hist_unet.default.yaml │ ├── lcdpnet.default.yaml │ └── lcdpnet.release.yaml ├── data ├── __init__.py ├── augmentation.py └── img_dataset.py ├── env.yaml ├── globalenv.py ├── model ├── __init__.py ├── arch │ ├── __init__.py │ ├── drconv.py │ ├── hist.py │ ├── nonlocal_block_embedded_gaussian.py │ └── unet_based │ │ ├── __init__.py │ │ └── hist_unet.py ├── basemodel.py ├── basic_loss.py ├── bilateralupsamplenet.py ├── lcdpnet.py └── single_net_basemodel.py ├── test.py ├── train.py └── utils ├── __init__.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | imgs/title.png 9 | .DS_Store 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 hywang99 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | [`🌐 Website`](https://whyy.site/paper/lcdp) · [`📃 Paper`](https://www.cs.cityu.edu.hk/~rynson/papers/eccv22b.pdf) · [`🗃️ Dataset`](https://drive.google.com/drive/folders/10Reaq-N0DiZiFpSrZ8j5g3g0EJes4JiS?usp=sharing) 12 |
13 | 14 | 15 | **Abstract:** Existing image enhancement methods are typically designed to address either the over- or under-exposure problem in the input image. When the illumination of the input image contains both over- and under-exposure problems, these existing methods may not work well. We observe from the image statistics that the local color distributions (LCDs) of an image suffering from both problems tend to vary across different regions of the image, depending on the local illuminations. Based on this observation, we propose in this paper to exploit these LCDs as a prior for locating and enhancing the two types of regions (i.e., over-/under-exposed regions). First, we leverage the LCDs to represent these regions, and propose a novel local color distribution embedded (LCDE) module to formulate LCDs in multi-scales to model the correlations across different regions. Second, we propose a dual-illumination learning mechanism to enhance the two types of regions. Third, we construct a new dataset to facilitate the learning process, by following the camera image signal processing (ISP) pipeline to render standard RGB images with both under-/over-exposures from raw data. Extensive experiments demonstrate that the proposed method outperforms existing state-of-the-art methods quantitatively and qualitatively. 16 | 17 | ## 📻 News 18 | 19 | - 2023.7.21: If you are interested in low-light enhancement and NeRF, please check out my latest ICCV 2023 work, [LLNeRF](https://github.com/onpix/LLNeRF)! 🔥🔥🔥 20 | - 2023.7.21: Update README. 21 | - 2023.2.7: Merge the `tar.gz` files of our dataset into a single `7z` file. 22 | - 2023.2.8: Update package versions in `requirements.txt`. 23 | - 2023.2.8: Upload `env.yaml`. 24 | 25 | ## 🔥 Our Model 26 | 27 | ![Our model](https://hywang99.github.io/images/lcdpnet/arch.png) 28 | 29 | 30 | ## ⚙️ Setup 31 | 32 | 1. Clone the repository: `git clone https://github.com/onpix/LCDPNet.git` 33 | 2. Enter the directory: `cd LCDPNet` 34 | 3. Install the required packages: `pip install -r requirements.txt` 35 | 36 | We also provide `env.yaml` for quickly installing packages. Note that you may need to modify the env name to avoid overwriting your existing environment, or modify the cudatoolkit and cudnn versions in `env.yaml` to match your local CUDA version. 37 | 38 | ## ⌨️ How to run 39 | 40 | To train our model: 41 | 42 | 1. Prepare data: Modify `src/config/ds/train.yaml` and `src/config/ds/valid.yaml`. 43 | 2. Modify configs in `src/config`. Note that we use `hydra` for config management. 44 | 3. Run: `python src/train.py name=<experiment_name> num_epoch=200 log_every=2000 valid_every=20` 45 | 46 | To test our model: 47 | 48 | 1. Prepare data: Modify `src/config/ds/test.yaml`. 49 | 2. Run: `python src/test.py checkpoint_path=<path_to_checkpoint>` 50 | 51 | ## 📂 Dataset & Pretrained Model 52 | 53 | The LCDP Dataset is available here: [[Google drive]](https://drive.google.com/drive/folders/10Reaq-N0DiZiFpSrZ8j5g3g0EJes4JiS?usp=sharing). Please unzip `lcdp_dataset.7z`. The training and test images are: 54 | 55 | | | Train | Test | 56 | | ----- | ------------- | ------------------ | 57 | | Input | `input/*.png` | `test-input/*.png` | 58 | | GT | `gt/*.png` | `test-gt/*.png` | 59 | 60 | We provide two pretrained models, `pretrained_models/trained_on_ours.ckpt` and `pretrained_models/trained_on_MSEC.ckpt`, for researchers to reproduce the results in Tables 1 and 2 of our paper. Note that we train `pretrained_models/trained_on_MSEC.ckpt` on the Expert C subset of the MSEC dataset, which contains both over- and under-exposed images.
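For example, a minimal way to reproduce the numbers below with the released checkpoints, assuming the dataset is unpacked to `lcdp_dataset/` and the checkpoints are placed under `pretrained_models/` (both paths are assumptions; adjust them to your own layout):

```bash
# Hypothetical paths: adjust to wherever you unpacked lcdp_dataset.7z and the checkpoints.
# 1. Point src/config/ds/test.yaml at the test split, e.g.:
#      input: [ lcdp_dataset/test-input/* ]
#      GT:    [ lcdp_dataset/test-gt/* ]
# 2. Run the test script with the released checkpoint:
python src/test.py checkpoint_path=pretrained_models/trained_on_ours.ckpt
```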
61 | 62 | | Filename | Training data | Testing data | Test PSNR | Test SSIM | 63 | | -------------------- | ------------------------------------------------------------ | ---------------------------- | --------- | --------- | 64 | | trained_on_ours.ckpt | Ours | Our testing data | 23.239 | 0.842 | 65 | | trained_on_MSEC.ckpt | [MSEC](https://github.com/mahmoudnafifi/Exposure_Correction) | MSEC testing data (Expert C) | 22.295 | 0.855 | 66 | 67 | Our model is lightweight. Experiments show that increasing the model size will further improve the quality of the results. To train a bigger model, increase the values in `runtime.bilateral_upsample_net.hist_unet.channel_nums`. 68 | 69 | ## 🔗 Cite This Paper 70 | 71 | If you find our work or code helpful, or your research benefits from this repo, please cite our paper: 72 | 73 | ```bibtex 74 | @inproceedings{wang2022lcdp, 75 | title = {Local Color Distributions Prior for Image Enhancement}, 76 | author = {Wang, Haoyuan and Xu, Ke and Lau, Rynson W.H.}, 77 | booktitle = {Proceedings of the European Conference on Computer Vision (ECCV)}, 78 | year = {2022} 79 | } 80 | ``` -------------------------------------------------------------------------------- /imgs/title.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onpix/LCDPNet/4faa0d98e8ff45f53a3569dd005a74353995b335/imgs/title.webp -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ipdb == 0.13.9 2 | hydra-core == 1.1.1 3 | pytorch_lightning == 1.7.6 4 | scikit-image == 0.19.2 5 | kornia == 0.6.7 6 | wandb == 0.13.8 7 | opencv-python == 4.5.5.64 8 | matplotlib == 3.5.1 9 | torchvision == 0.13.1 10 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onpix/LCDPNet/4faa0d98e8ff45f53a3569dd005a74353995b335/src/__init__.py -------------------------------------------------------------------------------- /src/config/aug/none.yaml: -------------------------------------------------------------------------------- 1 | # @package aug 2 | 3 | 4 | # crop size: 5 | # - false for not cropping; 6 | # - int n for cropping a square [n, n]; 7 | # - [h, w] for cropping a specific size. 8 | crop: false 9 | 10 | 11 | # downsample factor. 12 | # - value > 1 for downsample; 13 | # - 0.x for upsample; 14 | # - [h, w] for specific size; 15 | # - false for not resizing. 16 | downsample: false 17 | 18 | h-flip: false 19 | v-flip: false 20 | 21 | lightness_adjust: false 22 | contrast_adjust: false 23 | 24 | resize_divisible_n: false # align the input image size to a number divisible by N to avoid size-mismatch errors. -------------------------------------------------------------------------------- /src/config/aug/resize512.yaml: -------------------------------------------------------------------------------- 1 | # @package aug 2 | 3 | # crop size: 4 | # - false for not cropping; 5 | # - int n for cropping a square [n, n]; 6 | # - [h, w] for cropping a specific size 7 | crop: false 8 | 9 | 10 | # downsample factor. 11 | # - value > 1 for downsample; 12 | # - 0.x for upsample; 13 | # - [h, w] for specific size; 14 | # - false for not resizing. 
15 | downsample: [ 512, 512 ] 16 | 17 | h-flip: true 18 | v-flip: true -------------------------------------------------------------------------------- /src/config/config.yaml: -------------------------------------------------------------------------------- 1 | project: default_proj 2 | name: default_name # name of experiment 3 | comment: false 4 | debug: false 5 | val_debug_step_nums: 2 # sanity check num 6 | gpu: -1 # number of gpus to use (int) or which GPUs to use (list or str) 7 | backend: ddp # gpu accelerator. value: ddp or none 8 | runtime_precision: 16 9 | amp_backend: native # native or apex 10 | amp_level: O1 11 | dataloader_num_worker: 5 12 | mode: train 13 | logger: tb 14 | 15 | # frequently changed configs: 16 | num_epoch: 1000 17 | valid_every: 10 # validate every N EPOCHS 18 | savemodel_every: 4 # run ModelCheckpoint every N EPOCHS 19 | log_every: 100 # log your message, curve or images every N STEPS 20 | batchsize: 16 21 | valid_batchsize: 1 22 | lr: 1e-4 23 | checkpoint_path: null 24 | 25 | checkpoint_monitor: loss 26 | resume_training: true 27 | monitor_mode: min 28 | early_stop: false 29 | valid_ratio: 0.1 30 | 31 | flags: { } 32 | 33 | defaults: 34 | - aug: resize512 35 | - ds@train_ds: train 36 | - ds@test_ds: test 37 | - ds@valid_ds: valid 38 | - runtime: lcdpnet.release 39 | 40 | hydra: 41 | run: 42 | dir: ./ 43 | -------------------------------------------------------------------------------- /src/config/ds/test.yaml: -------------------------------------------------------------------------------- 1 | # @package ds 2 | class: img_dataset 3 | name: lcdp_data.test 4 | input: 5 | - your_dataset_path/test-input/* 6 | GT: 7 | - your_dataset_path/test-gt/* 8 | -------------------------------------------------------------------------------- /src/config/ds/train.yaml: -------------------------------------------------------------------------------- 1 | # @package ds 2 | class: img_dataset 3 | name: lcdp_data.train 4 | input: 5 | - your_dataset_path/input/* 6 | GT: 7 | - your_dataset_path/gt/* 8 | -------------------------------------------------------------------------------- /src/config/ds/valid.yaml: -------------------------------------------------------------------------------- 1 | # @package ds 2 | class: img_dataset 3 | name: lcdp_data.valid 4 | input: 5 | - your_dataset_path/valid-input/* 6 | GT: 7 | - your_dataset_path/valid-gt/* 8 | -------------------------------------------------------------------------------- /src/config/runtime/bilateral_upsample_net.default.yaml: -------------------------------------------------------------------------------- 1 | # @package runtime 2 | 3 | modelname: bilateral_upsample_net 4 | predict_illumination: false 5 | loss: # when self_supervised is true, this option will be ignored. 6 | mse: 1.0 7 | cos: 0.1 8 | ltv: 0.1 # only matters when predict_illumination is true. 9 | 10 | luma_bins: 8 11 | channel_multiplier: 1 12 | spatial_bin: 16 13 | batch_norm: true 14 | low_resolution: 256 15 | coeffs_type: matrix # selected from [matrix, gamma, retinex] 16 | 17 | conv_type: conv 18 | 19 | # choose from: [ori, hist-unet] 20 | backbone: ori 21 | 22 | # type: false or int. 23 | # if N, do : illu_map **= N to adjust the brightness of the output. 
24 | illu_map_power: false 25 | 26 | # only work when using hist-unet 27 | defaults: 28 | - hist_unet.default@hist_unet 29 | -------------------------------------------------------------------------------- /src/config/runtime/hist_unet.default.yaml: -------------------------------------------------------------------------------- 1 | n_bins: 8 2 | hist_as_guide: false 3 | channel_nums: false 4 | encoder_use_hist: false 5 | guide_feature_from_hist: false 6 | region_num: 8 7 | use_gray_hist: false 8 | conv_type: drconv 9 | down_ratio: 2 10 | hist_conv_trainable: false 11 | drconv_position: [ 1,1 ] -------------------------------------------------------------------------------- /src/config/runtime/lcdpnet.default.yaml: -------------------------------------------------------------------------------- 1 | # @package runtime 2 | 3 | modelname: lcdpnet 4 | use_wavelet: false 5 | use_attn_map: false 6 | use_non_local: false 7 | 8 | how_to_fuse: cnn-weights 9 | 10 | backbone: unet 11 | 12 | # choose from : [conv, drconv] 13 | conv_type: conv 14 | 15 | # the output of the backbone is the illu_net. 16 | # If false, the output of the backbone is directly to darken & brighten input. 17 | backbone_out_illu: true 18 | 19 | illumap_channel: 3 # 1 or 3 20 | 21 | # 2 branches share the same weights to predict the illu map. 22 | share_weights: true 23 | 24 | # only work when using hist-unet 25 | n_bins: 8 26 | hist_as_guide: false 27 | 28 | loss: 29 | ltv: 0 # ltv applied on the FINAL OUTPUT. 30 | cos: 0.5 # cos similarity loss 31 | weighted_loss: 0 # weighted loss instead of l1loss 32 | tvloss1: 0.01 # tvloss applied on the illumination 33 | tvloss2: 0.01 # tvloss applied on the inverse illumination 34 | tvloss1_new: 0 35 | tvloss2_new: 0 36 | 37 | l1_loss: 1.0 # default pixel-wise l1 loss 38 | ssim_loss: 0 39 | psnr_loss: 0 40 | illumap_loss: 0 # constraint illumap1 + illumap2 -> 1 41 | hist_loss: 0 42 | inter_hist_loss: 0 43 | vgg_loss: 0 44 | 45 | defaults: 46 | - bilateral_upsample_net.default@bilateral_upsample_net 47 | - hist_unet.default@hist_unet -------------------------------------------------------------------------------- /src/config/runtime/lcdpnet.release.yaml: -------------------------------------------------------------------------------- 1 | # @package runtime 2 | 3 | defaults: 4 | - lcdpnet.default 5 | 6 | loss: 7 | ltv: 0 # ltv applied on the FINAL OUTPUT. 
8 | cos: 0 # cos similarity loss 9 | cos2: 0.5 # cos with sigmoid 10 | tvloss1: 0 # tvloss applied on the illumination 11 | tvloss2: 0 # tvloss applied on the inverse illumination 12 | l1_loss: 1.0 # default pixel-wise l1 loss 13 | tvloss1_new: 0.01 14 | tvloss2_new: 0.01 15 | 16 | backbone: bilateral_upsample_net 17 | 18 | bilateral_upsample_net: 19 | backbone: hist-unet 20 | hist_unet: 21 | guide_feature_from_hist: true 22 | region_num: 2 23 | drconv_position: [ 0,1 ] 24 | channel_nums: [ 8,16,32,64,128 ] -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onpix/LCDPNet/4faa0d98e8ff45f53a3569dd005a74353995b335/src/data/__init__.py -------------------------------------------------------------------------------- /src/data/augmentation.py: -------------------------------------------------------------------------------- 1 | # import cv2 2 | from collections.abc import Iterable 3 | 4 | import numpy as np 5 | import torchvision.transforms.functional as F 6 | from torchvision import transforms 7 | 8 | from globalenv import * 9 | 10 | 11 | class Downsample: 12 | def __init__(self, downsample_factor=None): 13 | self.downsample_factor = downsample_factor 14 | 15 | if isinstance(self.downsample_factor, Iterable): 16 | # should be [h, w] 17 | assert len(downsample_factor) == 2 18 | 19 | def __call__(self, img): 20 | ''' 21 | img: passed by the previous transforms. PIL iamge or np.ndarray 22 | ''' 23 | origin_h = img.size[1] 24 | origin_w = img.size[0] 25 | if isinstance(self.downsample_factor, Iterable): 26 | # pass [h,w] 27 | if -1 in self.downsample_factor: 28 | # automatic calculate the output size: 29 | h_scale = origin_h / self.downsample_factor[0] 30 | w_scale = origin_w / self.downsample_factor[1] 31 | 32 | # choose the correct one 33 | scale = max(w_scale, h_scale) 34 | new_size = [ 35 | int(origin_h / scale), # H 36 | int(origin_w / scale) # W 37 | ] 38 | else: 39 | new_size = self.downsample_factor # [H, W] 40 | 41 | elif type(self.downsample_factor + 0.1) == float: 42 | # pass a number as scale factor 43 | # PIL.Image, cv2.resize and torchvision.transforms.Resize all accepts [W, H] 44 | new_size = [ 45 | int(img.size[1] / self.downsample_factor), # H 46 | int(img.size[0] / self.downsample_factor) # W 47 | ] 48 | else: 49 | raise RuntimeError(f'ERR: Wrong config aug.downsample: {self.downsample_factor}') 50 | 51 | img = img.resize(new_size[::-1]) # reverse passed [h, w] to [w, h] 52 | return img 53 | 54 | def __repr__(self): 55 | return self.__class__.__name__ + f'({self.downsample_factor})' 56 | 57 | 58 | def get_value(d, k): 59 | if k in d and d[k]: 60 | return d[k] 61 | else: 62 | return False 63 | 64 | 65 | def parseAugmentation(opt): 66 | ''' 67 | return: pytorch composed transform 68 | ''' 69 | aug_config = opt[AUGMENTATION] 70 | aug_list = [transforms.ToPILImage(), ] 71 | 72 | # the order is fixed: 73 | augmentaionFactory = { 74 | DOWNSAMPLE: Downsample(aug_config[DOWNSAMPLE]) 75 | if get_value(aug_config, DOWNSAMPLE) else None, 76 | CROP: transforms.RandomCrop(aug_config[CROP]) 77 | if get_value(aug_config, CROP) else None, 78 | HORIZON_FLIP: transforms.RandomHorizontalFlip(), 79 | VERTICAL_FLIP: transforms.RandomVerticalFlip(), 80 | } 81 | 82 | for k, v in augmentaionFactory.items(): 83 | if get_value(aug_config, k): 84 | aug_list.append(v) 85 | 86 | aug_list.append(transforms.ToTensor()) 87 | print('Dataset 
augmentation:') 88 | print(aug_list) 89 | return transforms.Compose(aug_list) 90 | 91 | 92 | if __name__ == '__main__': 93 | pass 94 | -------------------------------------------------------------------------------- /src/data/img_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import random 3 | from collections.abc import Iterable 4 | from glob import glob 5 | 6 | import cv2 7 | import ipdb 8 | import numpy as np 9 | import torch 10 | import torchvision 11 | from pytorch_lightning import LightningDataModule 12 | 13 | from globalenv import * 14 | from .augmentation import parseAugmentation 15 | 16 | 17 | def parse_item_list_txt(txtpath): 18 | ''' 19 | Parse txt file containing all file paths. Each line one file. 20 | ''' 21 | txt = Path(txtpath) 22 | basedir = txt.parent 23 | content = txt.open().read().splitlines() 24 | 25 | sample = content[0] 26 | if sample.split('/')[0] in str(basedir): 27 | raise NotImplementedError('Not implemented: file path in txt and basedir have comman str.') 28 | else: 29 | assert (basedir / sample).exists() 30 | return [str(basedir / x) for x in content] 31 | 32 | 33 | def load_from_glob_list(globs): 34 | if type(globs) == str: 35 | if globs.endswith('.txt'): 36 | print(f'Parse txt file: {globs}') 37 | return parse_item_list_txt(globs) 38 | else: 39 | return sorted(glob(globs)) 40 | 41 | elif isinstance(globs, Iterable): 42 | # if iamges are glob lists, sort EACH list ** individually ** 43 | res = [] 44 | for g in globs: 45 | assert not g.endswith('.txt'), 'TXT file should not in glob list.' 46 | res.extend(sorted(glob(g))) 47 | return res 48 | else: 49 | ipdb.set_trace() 50 | raise TypeError( 51 | f'ERR: `ds.GT` or `ds.input` has wrong type: expect `str` or `list` but get {type(globs)}') 52 | 53 | 54 | def augment_one_img(img, seed, transform=None): 55 | img = img.astype(np.uint8) 56 | random.seed(seed) 57 | torch.manual_seed(seed) 58 | if transform: 59 | img = transform(img) 60 | return img 61 | 62 | 63 | class ImagesDataset(torch.utils.data.Dataset): 64 | 65 | def __init__(self, opt, ds_type=TRAIN_DATA, transform=None, batchsize=None): 66 | """Initialisation for the Dataset object 67 | transform: PyTorch image transformations to apply to the images 68 | """ 69 | self.transform = transform 70 | self.opt = opt 71 | 72 | gt_globs = opt[ds_type][GT] 73 | input_globs = opt[ds_type][INPUT] 74 | self.have_gt = True if gt_globs else False 75 | 76 | print(f'{ds_type} - GT Directory path: [yellow]{gt_globs}[/yellow]') 77 | print(f'{ds_type} - Input Directory path: [yellow]{input_globs}[/yellow]') 78 | 79 | # load input images: 80 | self.input_list = load_from_glob_list(input_globs) 81 | 82 | # load GT images: 83 | if self.have_gt: 84 | self.gt_list = load_from_glob_list(gt_globs) 85 | try: 86 | assert len(self.input_list) == len(self.gt_list) 87 | except: 88 | ipdb.set_trace() 89 | raise AssertionError( 90 | f'In [{ds_type}]: len(input_images) ({len(self.input_list)}) != len(gt_images) ({len(self.gt_list)})! ') 91 | 92 | # import ipdb; ipdb.set_trace() 93 | print( 94 | f'{ds_type} Dataset length: {self.__len__()}, batch num: {self.__len__() // batchsize}') 95 | 96 | if self.__len__() == 0: 97 | print(f'Error occured! 
Your ds is: TYPE={ds_type}, config:') 98 | print(opt[ds_type]) 99 | raise RuntimeError(f'[ Err ] Dataset input nums is 0!') 100 | 101 | def __len__(self): 102 | return (len(self.input_list)) 103 | 104 | def debug_save_item(self, input, gt): 105 | # home = os.environ['HOME'] 106 | util.saveTensorAsImg(input, 'i.png') 107 | util.saveTensorAsImg(gt, 'o.png') 108 | 109 | def __getitem__(self, idx): 110 | """Returns a pair of images with the given identifier. This is lazy loading 111 | of data into memory. Only those image pairs needed for the current batch 112 | are loaded. 113 | 114 | :param idx: image pair identifier 115 | :returns: dictionary containing input and output images and their identifier 116 | :rtype: dictionary 117 | 118 | """ 119 | res_item = {INPUT_FPATH: self.input_list[idx]} 120 | 121 | # different seed for different item, but same for GT and INPUT in one item: 122 | # the "seed of seed" is fixed for reproducing 123 | # random.seed(GLOBAL_SEED) 124 | seed = random.randint(0, 100000) 125 | input_img = cv2.imread(self.input_list[idx])[:, :, [2, 1, 0]] 126 | if self.have_gt and self.gt_list[idx].endswith('.hdr'): 127 | input_img = torch.Tensor(input_img / 255).permute(2, 0, 1) 128 | else: 129 | input_img = augment_one_img(input_img, seed, transform=self.transform) 130 | res_item[INPUT] = input_img 131 | 132 | if self.have_gt: 133 | res_item[GT_FPATH] = self.gt_list[idx] 134 | 135 | if res_item[GT_FPATH].endswith('.hdr'): 136 | # gt may be HDR 137 | # do not augment HDR image. 138 | gt_img = cv2.imread(self.gt_list[idx], flags=cv2.IMREAD_ANYDEPTH)[:, :, [2, 1, 0]] 139 | gt_img = torch.Tensor(np.log10(gt_img + 1)).permute(2, 0, 1) 140 | else: 141 | gt_img = cv2.imread(self.gt_list[idx])[:, :, [2, 1, 0]] 142 | gt_img = augment_one_img(gt_img, seed, transform=self.transform) 143 | 144 | res_item[GT] = gt_img 145 | assert res_item[GT].shape == res_item[INPUT].shape 146 | 147 | return res_item 148 | 149 | 150 | class DataModule(LightningDataModule): 151 | def __init__(self, opt, apply_test_transform=False, apply_valid_transform=False): 152 | super().__init__() 153 | self.opt = opt 154 | self.transform = parseAugmentation(opt) 155 | # self.transform = None 156 | if apply_test_transform: 157 | self.test_transform = self.transform 158 | else: 159 | self.test_transform = torchvision.transforms.ToTensor() 160 | 161 | if apply_valid_transform: 162 | self.valid_transform = self.transform 163 | else: 164 | self.valid_transform = torchvision.transforms.ToTensor() 165 | 166 | self.training_dataset = None 167 | self.valid_dataset = None 168 | self.test_dataset = None 169 | 170 | def prepare_data(self): 171 | # download, split, etc... 172 | # only called on 1 GPU/TPU in distributed 173 | ... 174 | 175 | def setup(self, stage): 176 | opt = self.opt 177 | # from train.py 178 | if stage == "fit": 179 | # if opt[TRAIN_DATA]: 180 | assert opt[TRAIN_DATA][INPUT] 181 | self.training_dataset = ImagesDataset(opt, ds_type=TRAIN_DATA, transform=self.transform, batchsize=opt.batchsize) 182 | 183 | # valid data provided: 184 | if opt[VALID_DATA] and opt[VALID_DATA][INPUT]: 185 | self.valid_dataset = ImagesDataset(opt, ds_type=VALID_DATA, transform=self.valid_transform, batchsize=opt.valid_batchsize) 186 | 187 | # no valid data, splt from training data: 188 | elif opt[VALID_RATIO]: 189 | print(f'Split valid dataset from training data. 
Ratio: {opt[VALID_RATIO]}') 190 | valid_size = int(opt[VALID_RATIO] * len(self.training_dataset)) 191 | train_size = len(self.training_dataset) - valid_size 192 | torch.manual_seed(233) 193 | self.training_dataset, self.valid_dataset = torch.utils.data.random_split(self.training_dataset, [ 194 | train_size, valid_size 195 | ]) 196 | print( 197 | f'Update - training data: {len(self.training_dataset)}; valid data: {len(self.valid_dataset)}') 198 | 199 | # testing phase 200 | # if stage == 'test': 201 | if opt[TEST_DATA] and opt[TEST_DATA][INPUT]: 202 | self.test_dataset = ImagesDataset(opt, ds_type=TEST_DATA, transform=self.test_transform, batchsize=1) 203 | 204 | def train_dataloader(self): 205 | if self.training_dataset: 206 | trainloader = torch.utils.data.DataLoader( 207 | self.training_dataset, 208 | batch_size=self.opt[BATCHSIZE], 209 | num_workers=self.opt[DATALOADER_N], 210 | shuffle=True, 211 | drop_last=True, 212 | pin_memory=True 213 | ) 214 | return trainloader 215 | 216 | def val_dataloader(self): 217 | if self.valid_dataset: 218 | return torch.utils.data.DataLoader( 219 | self.valid_dataset, 220 | batch_size=self.opt[VALID_BATCHSIZE], 221 | shuffle=False, 222 | num_workers=self.opt[DATALOADER_N] 223 | ) 224 | 225 | def test_dataloader(self): 226 | if self.test_dataset: 227 | return torch.utils.data.DataLoader( 228 | self.test_dataset, 229 | batch_size=1, 230 | shuffle=False, 231 | num_workers=self.opt[DATALOADER_N], 232 | pin_memory=True 233 | ) 234 | 235 | def teardown(self, stage): 236 | # clean up after fit or test 237 | # called on every process in DDP 238 | ... 239 | 240 | 241 | if __name__ == '__main__': 242 | ... 243 | -------------------------------------------------------------------------------- /src/env.yaml: -------------------------------------------------------------------------------- 1 | name: base 2 | channels: 3 | - pytorch 4 | - bottler 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _anaconda_depends=2020.07=py38_0 9 | - _ipyw_jlab_nb_ext_conf=0.1.0=py38_0 10 | - _libgcc_mutex=0.1=main 11 | - alabaster=0.7.12=pyhd3eb1b0_0 12 | - anaconda=custom=py38_1 13 | - anaconda-client=1.7.2=py38_0 14 | - anaconda-navigator=2.0.3=py38_0 15 | - anaconda-project=0.9.1=pyhd3eb1b0_1 16 | - appdirs=1.4.4=py_0 17 | - argh=0.26.2=py38_0 18 | - argon2-cffi=20.1.0=py38h27cfd23_1 19 | - asn1crypto=1.4.0=py_0 20 | - astroid=2.5=py38h06a4308_1 21 | - astropy=4.2.1=py38h27cfd23_1 22 | - async_generator=1.10=pyhd3eb1b0_0 23 | - atomicwrites=1.4.0=py_0 24 | - attrs=20.3.0=pyhd3eb1b0_0 25 | - autopep8=1.5.6=pyhd3eb1b0_0 26 | - babel=2.9.0=pyhd3eb1b0_0 27 | - backcall=0.2.0=pyhd3eb1b0_0 28 | - backports=1.0=pyhd3eb1b0_2 29 | - backports.functools_lru_cache=1.6.4=pyhd3eb1b0_0 30 | - backports.shutil_get_terminal_size=1.0.0=pyhd3eb1b0_3 31 | - backports.tempfile=1.0=pyhd3eb1b0_1 32 | - backports.weakref=1.0.post1=py_1 33 | - beautifulsoup4=4.9.3=pyha847dfd_0 34 | - bitarray=2.1.0=py38h27cfd23_1 35 | - bkcharts=0.2=py38_0 36 | - black=19.10b0=py_0 37 | - blas=1.0=mkl 38 | - bleach=3.3.0=pyhd3eb1b0_0 39 | - blosc=1.21.0=h8c45485_0 40 | - bokeh=2.3.2=py38h06a4308_0 41 | - boto=2.49.0=py38_0 42 | - bottleneck=1.3.2=py38heb32a55_1 43 | - brotlipy=0.7.0=py38h27cfd23_1003 44 | - bzip2=1.0.8=h7b6447c_0 45 | - c-ares=1.17.1=h27cfd23_0 46 | - ca-certificates=2022.4.26=h06a4308_0 47 | - cairo=1.16.0=hf32fb01_1 48 | - certifi=2021.10.8=py38h06a4308_2 49 | - cffi=1.14.5=py38h261ae71_0 50 | - chardet=4.0.0=py38h06a4308_1003 51 | - cloudpickle=1.6.0=py_0 52 | - clyent=1.2.2=py38_1 53 | - 
colorama=0.4.4=pyhd3eb1b0_0 54 | - conda=4.12.0=py38h06a4308_0 55 | - conda-build=3.21.4=py38h06a4308_0 56 | - conda-content-trust=0.1.1=pyhd3eb1b0_0 57 | - conda-env=2.6.0=1 58 | - conda-package-handling=1.7.3=py38h27cfd23_1 59 | - conda-repo-cli=1.0.4=pyhd3eb1b0_0 60 | - conda-token=0.3.0=pyhd3eb1b0_0 61 | - conda-verify=3.4.2=py_1 62 | - contextlib2=0.6.0.post1=py_0 63 | - cryptography=3.4.7=py38hd23ed53_0 64 | - cudatoolkit=10.2.89=hfd86e86_1 65 | - cudatoolkit-dev=10.1.243=h516909a_3 66 | - curl=7.71.1=hbc83047_1 67 | - cython=0.29.23=py38h2531618_0 68 | - cytoolz=0.11.0=py38h7b6447c_0 69 | - dask=2021.4.0=pyhd3eb1b0_0 70 | - dask-core=2021.4.0=pyhd3eb1b0_0 71 | - dbus=1.13.18=hb2f20db_0 72 | - decorator=5.0.6=pyhd3eb1b0_0 73 | - defusedxml=0.7.1=pyhd3eb1b0_0 74 | - diff-match-patch=20200713=py_0 75 | - distributed=2021.4.1=py38h06a4308_0 76 | - docutils=0.17.1=py38h06a4308_1 77 | - entrypoints=0.3=py38_0 78 | - et_xmlfile=1.0.1=py_1001 79 | - expat=2.3.0=h2531618_2 80 | - fastcache=1.1.0=py38h7b6447c_0 81 | - ffmpeg=4.2.2=h20bf706_0 82 | - flake8=3.9.0=pyhd3eb1b0_0 83 | - flask=1.1.2=pyhd3eb1b0_0 84 | - fontconfig=2.13.1=h6c09931_0 85 | - freetype=2.10.4=h5ab3b9f_0 86 | - fribidi=1.0.10=h7b6447c_0 87 | - fsspec=2021.8.1=pyhd3eb1b0_0 88 | - future=0.18.2=py38_1 89 | - get_terminal_size=1.0.0=haa9412d_0 90 | - gevent=21.1.2=py38h27cfd23_1 91 | - glib=2.68.1=h36276a3_0 92 | - glob2=0.7=pyhd3eb1b0_0 93 | - gmp=6.2.1=h2531618_2 94 | - gmpy2=2.0.8=py38hd5f6e3b_3 95 | - gnutls=3.6.15=he1e5248_0 96 | - graphite2=1.3.14=h23475e2_0 97 | - greenlet=1.0.0=py38h2531618_2 98 | - gst-plugins-base=1.14.0=h8213a91_2 99 | - gstreamer=1.14.0=h28cd5cc_2 100 | - h5py=2.10.0=py38h7918eee_0 101 | - harfbuzz=2.8.0=h6f93f22_0 102 | - hdf5=1.10.4=hb1b8bf9_0 103 | - heapdict=1.0.1=py_0 104 | - html5lib=1.1=py_0 105 | - icu=58.2=he6710b0_3 106 | - idna=2.10=pyhd3eb1b0_0 107 | - imagesize=1.2.0=pyhd3eb1b0_0 108 | - importlib-metadata=3.10.0=py38h06a4308_0 109 | - importlib_metadata=3.10.0=hd3eb1b0_0 110 | - iniconfig=1.1.1=pyhd3eb1b0_0 111 | - intel-openmp=2021.2.0=h06a4308_610 112 | - intervaltree=3.1.0=py_0 113 | - ipykernel=5.3.4=py38h5ca1d4c_0 114 | - ipython=7.22.0=py38hb070fc8_0 115 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 116 | - ipywidgets=7.6.3=pyhd3eb1b0_1 117 | - isort=5.8.0=pyhd3eb1b0_0 118 | - itsdangerous=1.1.0=pyhd3eb1b0_0 119 | - jbig=2.1=hdba287a_0 120 | - jdcal=1.4.1=py_0 121 | - jedi=0.17.2=py38h06a4308_1 122 | - jeepney=0.6.0=pyhd3eb1b0_0 123 | - joblib=1.0.1=pyhd3eb1b0_0 124 | - jpeg=9b=h024ee3a_2 125 | - json5=0.9.5=py_0 126 | - jsonschema=3.2.0=py_2 127 | - jupyter=1.0.0=py38_7 128 | - jupyter-packaging=0.7.12=pyhd3eb1b0_0 129 | - jupyter_client=6.1.12=pyhd3eb1b0_0 130 | - jupyter_console=6.4.0=pyhd3eb1b0_0 131 | - jupyter_core=4.7.1=py38h06a4308_0 132 | - jupyterlab_pygments=0.1.2=py_0 133 | - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1 134 | - keyring=22.3.0=py38h06a4308_0 135 | - krb5=1.18.2=h173b8e3_0 136 | - lame=3.100=h7b6447c_0 137 | - lazy-object-proxy=1.6.0=py38h27cfd23_0 138 | - lcms2=2.12=h3be6417_0 139 | - ld_impl_linux-64=2.33.1=h53a641e_7 140 | - libarchive=3.4.2=h62408e4_0 141 | - libcurl=7.71.1=h20c2e04_1 142 | - libedit=3.1.20210216=h27cfd23_1 143 | - libev=4.33=h7b6447c_0 144 | - libffi=3.3=he6710b0_2 145 | - libgcc-ng=9.1.0=hdf63c60_0 146 | - libgfortran-ng=7.3.0=hdf63c60_0 147 | - libiconv=1.15=h63c8f33_5 148 | - libidn2=2.3.2=h7f8727e_0 149 | - liblief=0.10.1=he6710b0_0 150 | - libllvm10=10.0.1=hbcb73fb_5 151 | - libllvm9=9.0.1=h4a3c616_1 152 | - libopus=1.3.1=h7b6447c_0 153 
| - libpng=1.6.37=hbc83047_0 154 | - libsodium=1.0.18=h7b6447c_0 155 | - libspatialindex=1.9.3=h2531618_0 156 | - libssh2=1.9.0=h1ba5d50_1 157 | - libstdcxx-ng=9.1.0=hdf63c60_0 158 | - libtasn1=4.16.0=h27cfd23_0 159 | - libtiff=4.2.0=h85742a9_0 160 | - libtool=2.4.6=h7b6447c_1005 161 | - libunistring=0.9.10=h27cfd23_0 162 | - libuuid=1.0.3=h1bed415_2 163 | - libuv=1.40.0=h7b6447c_0 164 | - libvpx=1.7.0=h439df22_0 165 | - libwebp-base=1.2.0=h27cfd23_0 166 | - libxcb=1.14=h7b6447c_0 167 | - libxml2=2.9.10=hb55368b_3 168 | - libxslt=1.1.34=hc22bd24_0 169 | - llvmlite=0.36.0=py38h612dafd_4 170 | - locket=0.2.1=py38h06a4308_1 171 | - lxml=4.6.3=py38h9120a33_0 172 | - lz4-c=1.9.3=h2531618_0 173 | - lzo=2.10=h7b6447c_2 174 | - matplotlib-base=3.3.4=py38h62a2d02_0 175 | - mccabe=0.6.1=py38_1 176 | - mkl=2021.4.0=h06a4308_640 177 | - mkl-service=2.3.0=py38h27cfd23_1 178 | - mkl_fft=1.3.0=py38h42c9631_2 179 | - mkl_random=1.2.1=py38ha9443f7_2 180 | - mock=4.0.3=pyhd3eb1b0_0 181 | - more-itertools=8.7.0=pyhd3eb1b0_0 182 | - mpc=1.1.0=h10f8cd9_1 183 | - mpfr=4.0.2=hb69a4c5_1 184 | - mpmath=1.2.1=py38h06a4308_0 185 | - msgpack-python=1.0.2=py38hff7bd54_1 186 | - multipledispatch=0.6.0=py38_0 187 | - mypy_extensions=0.4.3=py38_0 188 | - navigator-updater=0.2.1=py38_0 189 | - nbclient=0.5.3=pyhd3eb1b0_0 190 | - ncurses=6.2=he6710b0_1 191 | - nest-asyncio=1.5.1=pyhd3eb1b0_0 192 | - nettle=3.7.3=hbbd107a_1 193 | - nltk=3.6.1=pyhd3eb1b0_0 194 | - nose=1.3.7=pyhd3eb1b0_1006 195 | - notebook=6.3.0=py38h06a4308_0 196 | - numba=0.53.1=py38ha9443f7_0 197 | - numexpr=2.7.3=py38h22e1b3c_1 198 | - numpy-base=1.21.5=py38hf524024_2 199 | - numpydoc=1.1.0=pyhd3eb1b0_1 200 | - nvidiacub=1.10.0=0 201 | - olefile=0.46=py_0 202 | - openh264=2.1.0=hd408876_0 203 | - openpyxl=3.0.7=pyhd3eb1b0_0 204 | - openssl=1.1.1n=h7f8727e_0 205 | - pandas=1.2.4=py38h2531618_0 206 | - pandoc=2.12=h06a4308_0 207 | - pandocfilters=1.4.3=py38h06a4308_1 208 | - pango=1.45.3=hd140c19_0 209 | - parso=0.7.0=py_0 210 | - partd=1.2.0=pyhd3eb1b0_0 211 | - patchelf=0.12=h2531618_1 212 | - path=15.1.2=py38h06a4308_0 213 | - path.py=12.5.0=0 214 | - pathlib2=2.3.5=py38h06a4308_2 215 | - pathspec=0.7.0=py_0 216 | - pathtools=0.1.2=pyhd3eb1b0_1 217 | - patsy=0.5.1=py38_0 218 | - pcre=8.44=he6710b0_0 219 | - pep8=1.7.1=py38_0 220 | - pexpect=4.8.0=pyhd3eb1b0_3 221 | - pickleshare=0.7.5=pyhd3eb1b0_1003 222 | - pixman=0.40.0=h7b6447c_0 223 | - pkginfo=1.7.0=py38h06a4308_0 224 | - pluggy=0.13.1=py38h06a4308_0 225 | - ply=3.11=py38_0 226 | - prometheus_client=0.10.1=pyhd3eb1b0_0 227 | - prompt-toolkit=3.0.17=pyh06a4308_0 228 | - prompt_toolkit=3.0.17=hd3eb1b0_0 229 | - psutil=5.8.0=py38h27cfd23_1 230 | - ptyprocess=0.7.0=pyhd3eb1b0_2 231 | - py=1.10.0=pyhd3eb1b0_0 232 | - py-lief=0.10.1=py38h403a769_0 233 | - pycodestyle=2.6.0=pyhd3eb1b0_0 234 | - pycosat=0.6.3=py38h7b6447c_1 235 | - pycparser=2.20=py_2 236 | - pycurl=7.43.0.6=py38h1ba5d50_0 237 | - pydocstyle=6.0.0=pyhd3eb1b0_0 238 | - pyerfa=1.7.3=py38h27cfd23_0 239 | - pyflakes=2.2.0=pyhd3eb1b0_0 240 | - pygments=2.8.1=pyhd3eb1b0_0 241 | - pylint=2.7.4=py38h06a4308_1 242 | - pyls-black=0.4.6=hd3eb1b0_0 243 | - pyls-spyder=0.3.2=pyhd3eb1b0_0 244 | - pyodbc=4.0.30=py38he6710b0_0 245 | - pyopenssl=20.0.1=pyhd3eb1b0_1 246 | - pyqt=5.9.2=py38h05f1152_4 247 | - pyrsistent=0.17.3=py38h7b6447c_0 248 | - pysocks=1.7.1=py38h06a4308_0 249 | - pytables=3.6.1=py38h9fd0a39_0 250 | - pytest=6.2.3=py38h06a4308_2 251 | - python=3.8.8=hdb3f193_5 252 | - python-jsonrpc-server=0.4.0=py_0 253 | - 
python-language-server=0.36.2=pyhd3eb1b0_0 254 | - python-libarchive-c=2.9=pyhd3eb1b0_1 255 | - python_abi=3.8=2_cp38 256 | - pytz=2021.1=pyhd3eb1b0_0 257 | - pyxdg=0.27=pyhd3eb1b0_0 258 | - pyzmq=20.0.0=py38h2531618_1 259 | - qdarkstyle=2.8.1=py_0 260 | - qt=5.9.7=h5867ecd_1 261 | - qtawesome=1.0.2=pyhd3eb1b0_0 262 | - qtconsole=5.0.3=pyhd3eb1b0_0 263 | - qtpy=1.9.0=py_0 264 | - readline=8.1=h27cfd23_0 265 | - regex=2021.4.4=py38h27cfd23_0 266 | - requests=2.25.1=pyhd3eb1b0_0 267 | - ripgrep=12.1.1=0 268 | - rope=0.18.0=py_0 269 | - rtree=0.9.7=py38h06a4308_1 270 | - ruamel_yaml=0.15.100=py38h27cfd23_0 271 | - scikit-learn=0.24.1=py38ha9443f7_0 272 | - seaborn=0.11.1=pyhd3eb1b0_0 273 | - secretstorage=3.3.1=py38h06a4308_0 274 | - setuptools=52.0.0=py38h06a4308_0 275 | - simplegeneric=0.8.1=py38_2 276 | - singledispatch=3.6.1=pyhd3eb1b0_1001 277 | - sip=4.19.13=py38he6710b0_0 278 | - snappy=1.1.8=he6710b0_0 279 | - sniffio=1.2.0=py38h06a4308_1 280 | - snowballstemmer=2.1.0=pyhd3eb1b0_0 281 | - sortedcollections=2.1.0=pyhd3eb1b0_0 282 | - sortedcontainers=2.3.0=pyhd3eb1b0_0 283 | - soupsieve=2.2.1=pyhd3eb1b0_0 284 | - sphinx=4.0.1=pyhd3eb1b0_0 285 | - sphinxcontrib=1.0=py38_1 286 | - sphinxcontrib-applehelp=1.0.2=pyhd3eb1b0_0 287 | - sphinxcontrib-devhelp=1.0.2=pyhd3eb1b0_0 288 | - sphinxcontrib-htmlhelp=1.0.3=pyhd3eb1b0_0 289 | - sphinxcontrib-jsmath=1.0.1=pyhd3eb1b0_0 290 | - sphinxcontrib-qthelp=1.0.3=pyhd3eb1b0_0 291 | - sphinxcontrib-serializinghtml=1.1.4=pyhd3eb1b0_0 292 | - sphinxcontrib-websupport=1.2.4=py_0 293 | - spyder=4.2.5=py38h06a4308_0 294 | - spyder-kernels=1.10.2=py38h06a4308_0 295 | - sqlalchemy=1.4.15=py38h27cfd23_0 296 | - sqlite=3.35.4=hdfb4753_0 297 | - statsmodels=0.12.2=py38h27cfd23_0 298 | - sympy=1.8=py38h06a4308_0 299 | - tbb=2020.3=hfd86e86_0 300 | - tblib=1.7.0=py_0 301 | - terminado=0.9.4=py38h06a4308_0 302 | - testpath=0.4.4=pyhd3eb1b0_0 303 | - textdistance=4.2.1=pyhd3eb1b0_0 304 | - threadpoolctl=2.1.0=pyh5ca1d4c_0 305 | - three-merge=0.1.1=pyhd3eb1b0_0 306 | - tk=8.6.10=hbc83047_0 307 | - toml=0.10.2=pyhd3eb1b0_0 308 | - toolz=0.11.1=pyhd3eb1b0_0 309 | - torchaudio=0.9.1=py38 310 | - tornado=6.1=py38h27cfd23_0 311 | - typed-ast=1.4.2=py38h27cfd23_1 312 | - ujson=4.0.2=py38h2531618_0 313 | - unicodecsv=0.14.1=py38_0 314 | - unixodbc=2.3.9=h7b6447c_0 315 | - urllib3=1.26.4=pyhd3eb1b0_0 316 | - watchdog=1.0.2=py38h06a4308_1 317 | - wcwidth=0.2.5=py_0 318 | - webencodings=0.5.1=py38_1 319 | - werkzeug=1.0.1=pyhd3eb1b0_0 320 | - wheel=0.36.2=pyhd3eb1b0_0 321 | - widgetsnbextension=3.5.1=py38_0 322 | - wrapt=1.12.1=py38h7b6447c_1 323 | - wurlitzer=2.1.0=py38h06a4308_0 324 | - x264=1!157.20191217=h7b6447c_0 325 | - xlrd=2.0.1=pyhd3eb1b0_0 326 | - xlsxwriter=1.3.8=pyhd3eb1b0_0 327 | - xlwt=1.3.0=py38_0 328 | - xmltodict=0.12.0=py_0 329 | - xz=5.2.5=h7b6447c_0 330 | - yaml=0.2.5=h7b6447c_0 331 | - yapf=0.31.0=pyhd3eb1b0_0 332 | - zeromq=4.3.4=h2531618_0 333 | - zict=2.0.0=pyhd3eb1b0_0 334 | - zipp=3.4.1=pyhd3eb1b0_0 335 | - zlib=1.2.11=h7b6447c_3 336 | - zope=1.0=py38_1 337 | - zope.event=4.5.0=py38_0 338 | - zope.interface=5.3.0=py38h27cfd23_0 339 | - zstd=1.4.5=h9ceee32_0 340 | - pip: 341 | - absl-py==0.13.0 342 | - addict==2.4.0 343 | - aiohttp==3.7.4.post0 344 | - aiosignal==1.2.0 345 | - albumentations==1.1.0 346 | - altair==4.2.0 347 | - antlr4-python3-runtime==4.9.3 348 | - anyio==3.6.1 349 | - anykeystore==0.2 350 | - apex==0.1 351 | - astor==0.8.1 352 | - async-timeout==3.0.1 353 | - audioread==2.1.9 354 | - augly==0.1.9 355 | - av==10.0.0 356 | - 
backports-zoneinfo==0.2.1 357 | - base58==2.1.1 358 | - bcrypt==4.0.1 359 | - blinker==1.4 360 | - blis==0.7.5 361 | - cachetools==4.2.2 362 | - catalogue==2.0.6 363 | - chex==0.1.4 364 | - clean-fid==0.1.29 365 | - click==8.1.3 366 | - clip==1.0 367 | - commonmark==0.9.1 368 | - configargparse==1.5.3 369 | - configparser==5.0.2 370 | - cryptacular==1.6.2 371 | - cycler==0.11.0 372 | - cymem==2.0.6 373 | - deepdiff==5.7.0 374 | - diffusers==0.3.0 375 | - distlib==0.3.4 376 | - dm-pix==0.3.3 377 | - dm-tree==0.1.7 378 | - docker-pycreds==0.4.0 379 | - dominate==2.6.0 380 | - easydict==1.9 381 | - editdistance==0.6.0 382 | - einops==0.4.1 383 | - etils==0.7.1 384 | - exifread==3.0.0 385 | - facexlib==0.2.5 386 | - fairscale==0.4.9 387 | - fairseq==0.9.0 388 | - fastai==2.5.3 389 | - fastapi==0.85.0 390 | - fastcore==1.3.27 391 | - fastdownload==0.0.5 392 | - fastjsonschema==2.16.1 393 | - fastprogress==1.0.0 394 | - ffmpy==0.3.0 395 | - filelock==3.7.1 396 | - filterpy==1.4.5 397 | - fire==0.4.0 398 | - flax==0.6.1 399 | - font-roboto==0.0.1 400 | - fonts==0.0.3 401 | - fonttools==4.31.2 402 | - frozenlist==1.3.0 403 | - ftfy==6.1.1 404 | - fvcore==0.1.5.post20220305 405 | - gdown==4.5.1 406 | - gfpgan==1.3.8 407 | - gin-config==0.5.0 408 | - gitdb==4.0.7 409 | - gitpython==3.1.18 410 | - google-auth==1.35.0 411 | - google-auth-oauthlib==0.4.6 412 | - gradio==3.4.1 413 | - grpcio==1.40.0 414 | - h11==0.12.0 415 | - htcondor==9.2.0 416 | - httpcore==0.15.0 417 | - httpimport==0.7.2 418 | - httpx==0.23.0 419 | - huggingface-hub==0.10.1 420 | - hupper==1.10.3 421 | - hydra-core==1.1.1 422 | - imageio==2.16.1 423 | - imageio-ffmpeg==0.4.5 424 | - imgaug==0.4.0 425 | - importlib-resources==5.2.2 426 | - install==1.3.5 427 | - iopath==0.1.9 428 | - ipdb==0.13.9 429 | - jax==0.3.23 430 | - jaxlib==0.3.22 431 | - jinja2==3.1.2 432 | - jsonmerge==1.8.0 433 | - jsonpatch==1.32 434 | - jsonpointer==2.2 435 | - jupyter-server==1.18.1 436 | - jupyter-server-mathjax==0.2.6 437 | - jupyterlab==3.4.5 438 | - jupyterlab-git==0.38.0 439 | - jupyterlab-server==2.15.1 440 | - kiwisolver==1.4.0 441 | - kornia==0.6.7 442 | - kornia-moons==0.2.0 443 | - lark==1.1.2 444 | - librosa==0.8.1 445 | - linkify-it-py==1.0.3 446 | - lmdb==1.2.1 447 | - lpips==0.1.4 448 | - markdown==3.3.4 449 | - markdown-it-py==2.1.0 450 | - markupsafe==2.1.1 451 | - matplotlib==3.5.1 452 | - mdit-py-plugins==0.3.1 453 | - mdurl==0.1.2 454 | - mediapy==1.1.0 455 | - mistune==2.0.4 456 | - mrcfile==1.3.0 457 | - multidict==5.1.0 458 | - murmurhash==1.0.6 459 | - natsort==8.1.0 460 | - nbclassic==0.4.3 461 | - nbconvert==7.0.0 462 | - nbdime==3.1.1 463 | - nbformat==5.4.0 464 | - networkx==2.7.1 465 | - ninja==1.10.2.3 466 | - nlpaug==1.1.3 467 | - notebook-shim==0.1.0 468 | - numpy==1.23.3 469 | - oauthlib==3.1.1 470 | - omegaconf==2.2.3 471 | - opencv-python==4.5.5.64 472 | - opencv-python-headless==4.5.4.58 473 | - opt-einsum==3.3.0 474 | - optax==0.1.3 475 | - ordered-set==4.0.2 476 | - orjson==3.8.0 477 | - packaging==21.3 478 | - paramiko==2.11.0 479 | - pastedeploy==2.1.1 480 | - patchify==0.2.3 481 | - pathlib==1.0.1 482 | - pathy==0.6.1 483 | - pbkdf2==1.3 484 | - piexif==1.1.3 485 | - pillow==9.0.1 486 | - pip==22.3.1 487 | - plaster==1.0 488 | - plaster-pastedeploy==0.7 489 | - platformdirs==2.5.2 490 | - plotly==5.3.1 491 | - plyfile==0.7.4 492 | - pooch==1.5.2 493 | - portalocker==2.3.2 494 | - preshed==3.0.6 495 | - promise==2.3 496 | - protobuf==3.17.3 497 | - pyarrow==7.0.0 498 | - pyasn1==0.4.8 499 | - pyasn1-modules==0.2.8 
500 | - pycryptodome==3.15.0 501 | - pydantic==1.8.2 502 | - pydeck==0.7.1 503 | - pydeprecate==0.3.1 504 | - pydub==0.25.1 505 | - pyexiftool==0.5.4 506 | - pympler==1.0.1 507 | - pynacl==1.5.0 508 | - pyparsing==3.0.7 509 | - pyramid==2.0 510 | - pyramid-mailer==0.15.1 511 | - python-dateutil==2.8.2 512 | - python-magic==0.4.24 513 | - python-multipart==0.0.4 514 | - python3-openid==3.2.0 515 | - pytorch-fid==0.2.1 516 | - pytorch-lightning==1.7.6 517 | - pytz-deprecation-shim==0.1.0.post0 518 | - pywavelets==1.3.0 519 | - pyyaml==6.0 520 | - qudida==0.0.4 521 | - rawpy==0.17.3 522 | - ray==1.13.0 523 | - realesrgan==0.3.0 524 | - repoze-sendmail==4.4.1 525 | - requests-oauthlib==1.3.0 526 | - resampy==0.2.2 527 | - resize-right==0.0.2 528 | - rfc3986==1.5.0 529 | - rich==11.2.0 530 | - rich-cli==1.8.0 531 | - rich-rst==1.1.7 532 | - rsa==4.7.2 533 | - sacrebleu==2.0.0 534 | - scikit-image==0.19.2 535 | - scikit-video==1.1.11 536 | - scipy==1.8.0 537 | - semver==2.13.0 538 | - send2trash==1.8.0 539 | - sentry-sdk==1.3.1 540 | - setproctitle==1.2.2 541 | - shapely==1.7.1 542 | - shortuuid==1.0.1 543 | - simdkalman==1.0.2 544 | - six==1.16.0 545 | - smart-open==5.2.1 546 | - smmap==4.0.0 547 | - soundfile==0.10.3.post1 548 | - spacy==3.1.4 549 | - spacy-legacy==3.0.8 550 | - srsly==2.4.2 551 | - starlette==0.20.4 552 | - streamlit==1.3.1 553 | - subprocess32==3.5.4 554 | - svox2-csrc==0.0.1.dev0+sphtexcub.lincolor.fast 555 | - tabulate==0.8.9 556 | - tb-nightly==2.11.0a20221012 557 | - tenacity==8.0.1 558 | - tensorboard==2.10.1 559 | - tensorboard-data-server==0.6.1 560 | - tensorboard-plugin-wit==1.8.0 561 | - tensorboardx==2.4.1 562 | - termcolor==1.1.0 563 | - test-tube==0.7.5 564 | - textual==0.1.18 565 | - thinc==8.0.12 566 | - thop==0.0.31-2005241907 567 | - tifffile==2022.3.16 568 | - timm==0.6.7 569 | - tinycss2==1.1.1 570 | - tinycudann==1.6 571 | - tl2==0.0.8 572 | - tokenizers==0.12.1 573 | - torch==1.9.0 574 | - torch-ema==0.3 575 | - torch-scatter==2.0.9 576 | - torchdiffeq==0.2.3 577 | - torchfile==0.1.0 578 | - torchmetrics==0.9.2 579 | - torchsummary==1.5.1 580 | - torchvision==0.10.0 581 | - tqdm==4.64.0 582 | - traitlets==5.3.0 583 | - transaction==3.0.1 584 | - transformers==4.19.2 585 | - translationstring==1.4 586 | - trimesh==3.11.2 587 | - tsmoothie==1.0.4 588 | - typeguard==2.13.3 589 | - typer==0.4.0 590 | - typing-extensions==4.3.0 591 | - tzdata==2021.5 592 | - tzlocal==4.1 593 | - uc-micro-py==1.0.1 594 | - uvicorn==0.18.3 595 | - validators==0.18.2 596 | - velruse==1.1.1 597 | - venusian==3.0.0 598 | - virtualenv==20.15.1 599 | - visdom==0.1.8.9 600 | - vren==1.0 601 | - wand==0.6.7 602 | - wandb==0.13.8 603 | - warmup-scheduler==0.3 604 | - wasabi==0.8.2 605 | - webob==1.8.7 606 | - websocket-client==1.2.3 607 | - websockets==10.3 608 | - wtforms==3.0.1 609 | - wtforms-recaptcha==0.3.2 610 | - yacs==0.1.8 611 | - yarl==1.6.3 612 | - yaspin==2.1.0 613 | - zope-deprecation==4.4.0 614 | - zope-sqlalchemy==1.6 615 | prefix: /home/grads/hywang26/anaconda3 616 | -------------------------------------------------------------------------------- /src/globalenv.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | SRC_PATH = Path(__file__).absolute().parent 5 | ROOT_PATH = SRC_PATH.parent 6 | sys.path.append(str(SRC_PATH)) 7 | 8 | LOGGER_BUFFER_LOCK = False 9 | SPLIT = '————————————————————————————————————————————————————' 10 | 11 | GLOBAL_SEED = 233 12 | TEST_RESULT_DIRNAME = 
'test_result' 13 | TRAIN_LOG_DIRNAME = 'log' 14 | CONFIG_DIR = 'config' 15 | CONFIG_FILEPATH = 'config/config.yaml' 16 | LMDB_DIRPATH = ROOT_PATH / 'lmdb' 17 | METRICS_LOG_DIRPATH = ROOT_PATH / 'metrics_log' 18 | OPT_FILENAME = 'CONFIG.yaml' 19 | LOG_FILENAME = 'run.log' 20 | LOG_TIME_FORMAT = '%Y-%m-%d_%H:%M:%S' 21 | INPUT = 'input' 22 | OUTPUT = 'output' 23 | GT = 'GT' 24 | STRING_FALSE = 'False' 25 | SKIP_FLAG = 'q' 26 | DEFAULTS = 'defaults' 27 | HYDRA = 'hydra' 28 | 29 | INPUT_FPATH = 'input_fpath' 30 | GT_FPATH = 'gt_fpath' 31 | 32 | DEBUG = 'debug' 33 | BACKEND = 'backend' 34 | CHECKPOINT_PATH = 'checkpoint_path' 35 | LOG_DIRPATH = 'log_dirpath' 36 | IMG_DIRPATH = 'img_dirpath' 37 | DATALOADER_N = 'dataloader_num_worker' 38 | VAL_DEBUG_STEP_NUMS = 'val_debug_step_nums' 39 | VALID_EVERY = 'valid_every' 40 | LOG_EVERY = 'log_every' 41 | AUGMENTATION = 'aug' 42 | RUNTIME_PRECISION = 'runtime_precision' 43 | NUM_EPOCH = 'num_epoch' 44 | NAME = 'name' 45 | LOSS = 'loss' 46 | TRAIN_DATA = 'train_ds' 47 | VALID_DATA = 'valid_ds' 48 | TEST_DATA = 'test_ds' 49 | GPU = 'gpu' 50 | RUNTIME = 'runtime' 51 | CLASS = 'class' 52 | MODELNAME = 'modelname' 53 | BATCHSIZE = 'batchsize' 54 | VALID_BATCHSIZE = 'valid_batchsize' 55 | LR = 'lr' 56 | CHECKPOINT_MONITOR = 'checkpoint_monitor' 57 | MONITOR_MODE = 'monitor_mode' 58 | COMMENT = 'comment' 59 | EARLY_STOP = 'early_stop' 60 | AMP_BACKEND = 'amp_backend' 61 | AMP_LEVEL = 'amp_level' 62 | VALID_RATIO = 'valid_ratio' 63 | 64 | LTV_LOSS = 'ltv' 65 | COS_LOSS = 'cos' 66 | SSIM_LOSS = 'ssim_loss' 67 | L1_LOSS = 'l1_loss' 68 | COLOR_LOSS = 'l_color' 69 | SPATIAL_LOSS = 'l_spa' 70 | EXPOSURE_LOSS = 'l_exp' 71 | WEIGHTED_LOSS = 'weighted_loss' 72 | PSNR_LOSS = 'psnr_loss' 73 | HIST_LOSS = 'hist_loss' 74 | INTER_HIST_LOSS = 'inter_hist_loss' 75 | VGG_LOSS = 'vgg_loss' 76 | 77 | PSNR = 'psnr' 78 | SSIM = 'ssim' 79 | 80 | VERTICAL_FLIP = 'v-flip' 81 | HORIZON_FLIP = 'h-flip' 82 | DOWNSAMPLE = 'downsample' 83 | RESIZE_DIVISIBLE_N = 'resize_divisible_n' 84 | CROP = 'crop' 85 | LIGHTNESS_ADJUST = 'lightness_adjust' 86 | CONTRAST_ADJUST = 'contrast_adjust' 87 | 88 | BUNET = 'bilateral_upsample_net' 89 | UNET = 'unet' 90 | HIST_UNET = 'hist_unet' 91 | PREDICT_ILLUMINATION = 'predict_illumination' 92 | FILTERS = 'filters' 93 | 94 | MODE = 'mode' 95 | COLOR_SPACE = 'color_space' 96 | BETA1 = 'beta1' 97 | BETA2 = 'beta2' 98 | LAMBDA_SMOOTH = 'lambda_smooth' 99 | LAMBDA_MONOTONICITY = 'lambda_monotonicity' 100 | MSE = 'mse' 101 | L2_LOSS = 'l2_loss' 102 | TV_CONS = 'tv_cons' 103 | MN_CONS = 'mv_cons' 104 | WEIGHTS_NORM = 'wnorm' 105 | TEST_PTH = 'test_pth' 106 | 107 | LUMA_BINS = 'luma_bins' 108 | CHANNEL_MULTIPLIER = 'channel_multiplier' 109 | SPATIAL_BIN = 'spatial_bin' 110 | BATCH_NORM = 'batch_norm' 111 | NET_INPUT_SIZE = 'net_input_size' 112 | LOW_RESOLUTION = 'low_resolution' 113 | ONNX_EXPORTING_MODE = 'onnx_exporting_mode' 114 | SELF_SUPERVISED = 'self_supervised' 115 | COEFFS_TYPE = 'coeffs_type' 116 | ILLU_MAP_POWER = 'illu_map_power' 117 | GAMMA = 'gamma' 118 | MATRIX = 'matrix' 119 | GUIDEMAP = 'guidemap' 120 | USE_HSV = 'use_hsv' 121 | 122 | USE_WAVELET = 'use_wavelet' 123 | NON_LOCAL = 'use_non_local' 124 | USE_ATTN_MAP = 'use_attn_map' 125 | ILLUMAP_CHANNEL = 'illumap_channel' 126 | HOW_TO_FUSE = 'how_to_fuse' 127 | SHARE_WEIGHTS = 'share_weights' 128 | BACKBONE = 'backbone' 129 | ARCH = 'arch' 130 | N_BINS = 'n_bins' 131 | BACKBONE_OUT_ILLU = 'backbone_out_illu' 132 | CONV_TYPE = 'conv_type' 133 | HIST_AS_GUIDE_ = 'hist_as_guide' 134 | 
ENCODER_USE_HIST = 'encoder_use_hist' 135 | GUIDE_FEATURE_FROM_HIST = 'guide_feature_from_hist' 136 | NC = 'channel_nums' 137 | 138 | ILLU_MAP = 'illu_map' 139 | INVERSE_ILLU_MAP = 'inverse_illu_map' 140 | BRIGHTEN_INPUT = 'brighten_input' 141 | DARKEN_INPUT = 'darken_input' 142 | 143 | TRAIN = 'train' 144 | TEST = 'test' 145 | VALID = 'valid' 146 | ONNX = 'onnx' 147 | CONDOR = 'condor' 148 | IMAGES = 'images' 149 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onpix/LCDPNet/4faa0d98e8ff45f53a3569dd005a74353995b335/src/model/__init__.py -------------------------------------------------------------------------------- /src/model/arch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onpix/LCDPNet/4faa0d98e8ff45f53a3569dd005a74353995b335/src/model/arch/__init__.py -------------------------------------------------------------------------------- /src/model/arch/drconv.py: -------------------------------------------------------------------------------- 1 | import ipdb 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Function 6 | 7 | 8 | class asign_index(torch.autograd.Function): 9 | @staticmethod 10 | def forward(ctx, kernel, guide_feature): 11 | ctx.save_for_backward(kernel, guide_feature) 12 | guide_mask = torch.zeros_like(guide_feature).scatter_(1, guide_feature.argmax(dim=1, keepdim=True), 13 | 1).unsqueeze(2) # B x 3 x 1 x 25 x 25 14 | return torch.sum(kernel * guide_mask, dim=1) 15 | 16 | @staticmethod 17 | def backward(ctx, grad_output): 18 | kernel, guide_feature = ctx.saved_tensors 19 | guide_mask = torch.zeros_like(guide_feature).scatter_(1, guide_feature.argmax(dim=1, keepdim=True), 20 | 1).unsqueeze(2) # B x 3 x 1 x 25 x 25 21 | grad_kernel = grad_output.clone().unsqueeze(1) * guide_mask # B x 3 x 256 x 25 x 25 22 | grad_guide = grad_output.clone().unsqueeze(1) * kernel # B x 3 x 256 x 25 x 25 23 | grad_guide = grad_guide.sum(dim=2) # B x 3 x 25 x 25 24 | softmax = F.softmax(guide_feature, 1) # B x 3 x 25 x 25 25 | grad_guide = softmax * (grad_guide - (softmax * grad_guide).sum(dim=1, keepdim=True)) # B x 3 x 25 x 25 26 | return grad_kernel, grad_guide 27 | 28 | 29 | def xcorr_slow(x, kernel, kwargs): 30 | """for loop to calculate cross correlation 31 | """ 32 | batch = x.size()[0] 33 | out = [] 34 | for i in range(batch): 35 | px = x[i] 36 | pk = kernel[i] 37 | px = px.view(1, px.size()[0], px.size()[1], px.size()[2]) 38 | pk = pk.view(-1, px.size()[1], pk.size()[1], pk.size()[2]) 39 | po = F.conv2d(px, pk, **kwargs) 40 | out.append(po) 41 | out = torch.cat(out, 0) 42 | return out 43 | 44 | 45 | def xcorr_fast(x, kernel, kwargs): 46 | """group conv2d to calculate cross correlation 47 | """ 48 | batch = kernel.size()[0] 49 | pk = kernel.view(-1, x.size()[1], kernel.size()[2], kernel.size()[3]) 50 | px = x.view(1, -1, x.size()[2], x.size()[3]) 51 | po = F.conv2d(px, pk, **kwargs, groups=batch) 52 | po = po.view(batch, -1, po.size()[2], po.size()[3]) 53 | return po 54 | 55 | 56 | class Corr(Function): 57 | @staticmethod 58 | def symbolic(g, x, kernel, groups): 59 | return g.op("Corr", x, kernel, groups_i=groups) 60 | 61 | @staticmethod 62 | def forward(self, x, kernel, groups, kwargs): 63 | """group conv2d to calculate cross correlation 64 | """ 65 | batch = x.size(0) 66 | 
channel = x.size(1) 67 | x = x.view(1, -1, x.size(2), x.size(3)) 68 | kernel = kernel.view(-1, channel // groups, kernel.size(2), kernel.size(3)) 69 | out = F.conv2d(x, kernel, **kwargs, groups=groups * batch) 70 | out = out.view(batch, -1, out.size(2), out.size(3)) 71 | return out 72 | 73 | 74 | class Correlation(nn.Module): 75 | use_slow = True 76 | 77 | def __init__(self, use_slow=None): 78 | super(Correlation, self).__init__() 79 | if use_slow is not None: 80 | self.use_slow = use_slow 81 | else: 82 | self.use_slow = Correlation.use_slow 83 | 84 | def extra_repr(self): 85 | if self.use_slow: return "xcorr_slow" 86 | return "xcorr_fast" 87 | 88 | def forward(self, x, kernel, **kwargs): 89 | if self.training: 90 | if self.use_slow: 91 | return xcorr_slow(x, kernel, kwargs) 92 | else: 93 | return xcorr_fast(x, kernel, kwargs) 94 | else: 95 | return Corr.apply(x, kernel, 1, kwargs) 96 | 97 | 98 | class DRConv2d(nn.Module): 99 | def __init__(self, in_channels, out_channels, kernel_size, region_num=8, guide_input_channel=False, **kwargs): 100 | super(DRConv2d, self).__init__() 101 | self.region_num = region_num 102 | self.guide_input_channel = guide_input_channel 103 | 104 | self.conv_kernel = nn.Sequential( 105 | nn.AdaptiveAvgPool2d((kernel_size, kernel_size)), 106 | nn.Conv2d(in_channels, region_num * region_num, kernel_size=1), 107 | nn.Sigmoid(), 108 | nn.Conv2d(region_num * region_num, region_num * in_channels * out_channels, kernel_size=1, 109 | groups=region_num) 110 | ) 111 | if guide_input_channel: 112 | # get guide feature from a user input tensor. 113 | self.conv_guide = nn.Conv2d(guide_input_channel, region_num, kernel_size=kernel_size, **kwargs) 114 | else: 115 | self.conv_guide = nn.Conv2d(in_channels, region_num, kernel_size=kernel_size, **kwargs) 116 | 117 | self.corr = Correlation(use_slow=False) 118 | self.kwargs = kwargs 119 | self.asign_index = asign_index.apply 120 | 121 | def forward(self, input, guide_input=None): 122 | kernel = self.conv_kernel(input) 123 | # kernel = kernel.view(kernel.size(0), -1, kernel.size(2), kernel.size(3)) # B x (r*in*out) x W X H 124 | output = self.corr(input, kernel, **self.kwargs) # B x (r*out) x W x H 125 | output = output.view(output.size(0), self.region_num, -1, output.size(2), output.size(3)) # B x r x out x W x H 126 | if self.guide_input_channel: 127 | guide_feature = self.conv_guide(guide_input) 128 | else: 129 | guide_feature = self.conv_guide(input) 130 | self.guide_feature = guide_feature 131 | # self.guide_feature = torch.zeros_like(guide_feature).scatter_(1, guide_feature.argmax(dim=1, keepdim=True), 1).unsqueeze(2) # B x 3 x 1 x 25 x 25 132 | output = self.asign_index(output, guide_feature) 133 | return output 134 | 135 | 136 | class HistDRConv2d(DRConv2d): 137 | def forward(self, input, histmap): 138 | """ 139 | use histmap as guide feature directly. 
140 | histmap.shape: [bs, n_bins, h, w] 141 | """ 142 | histmap.requires_grad_(False) 143 | 144 | kernel = self.conv_kernel(input) 145 | output = self.corr(input, kernel, **self.kwargs) # B x (r*out) x W x H 146 | output = output.view(output.size(0), self.region_num, -1, output.size(2), output.size(3)) # B x r x out x W x H 147 | output = self.asign_index(output, histmap) 148 | return output 149 | 150 | 151 | if __name__ == '__main__': 152 | B = 16 153 | in_channels = 256 154 | out_channels = 512 155 | size = 89 156 | conv = DRConv2d(in_channels, out_channels, kernel_size=3, region_num=8).cuda() 157 | conv.train() 158 | input = torch.ones(B, in_channels, size, size).cuda() 159 | output = conv(input) 160 | print(input.shape, output.shape) 161 | 162 | # flops, params 163 | from thop import profile 164 | 165 | 166 | class Conv2d(nn.Module): 167 | def __init__(self): 168 | super(Conv2d, self).__init__() 169 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3) 170 | 171 | def forward(self, input): 172 | return self.conv(input) 173 | 174 | 175 | ipdb.set_trace() 176 | conv2 = Conv2d().cuda() 177 | conv2.train() 178 | print(input.shape, conv2(input).shape) 179 | flops2, params2 = profile(conv2, inputs=(input,)) 180 | flops, params = profile(conv, inputs=(input,)) 181 | 182 | print('[ * ] DRconv FLOPs = ' + str(flops / 1000 ** 3) + 'G') 183 | print('[ * ] DRconv Params Num = ' + str(params / 1000 ** 2) + 'M') 184 | 185 | print('[ * ] Conv FLOPs = ' + str(flops2 / 1000 ** 3) + 'G') 186 | print('[ * ] Conv Params Num = ' + str(params2 / 1000 ** 2) + 'M') 187 | -------------------------------------------------------------------------------- /src/model/arch/hist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_gray(img): 5 | r = img[:, 0, ...] 6 | g = img[:, 1, ...] 7 | b = img[:, 2, ...] 8 | return (0.299 * r + 0.587 * g + 0.114 * b).unsqueeze(1) 9 | 10 | 11 | def pack_tensor(x, n_bins): 12 | # pack tensor: transform gt_hist.shape [n_bins, bs, c, h, w] -> [bs*c, b_bins, h, w] 13 | # merge dim 1 (bs) and dim 2 (channel). 14 | return x.reshape(n_bins, -1, *x.shape[-2:]).permute(1, 0, 2, 3) 15 | 16 | 17 | def get_hist(img, n_bins, grayscale=False): 18 | """ 19 | Given a img (shape: bs, c, h, w), 20 | return the SOFT histogram map (shape: n_bins, bs, c, h, w) 21 | or (shape: n_bins, bs, h, w) when grayscale=True. 22 | """ 23 | if grayscale: 24 | img = get_gray(img) 25 | return torch.stack([ 26 | torch.nn.functional.relu(1 - torch.abs(img - (2 * b - 1) / float(2 * n_bins)) * float(n_bins)) 27 | for b in range(1, n_bins + 1) 28 | ]) 29 | 30 | 31 | def get_hist_conv(n_bins, kernel_size=2, train=False): 32 | """ 33 | Return a conv kernel. 34 | The kernel is used to apply on the histogram map, shrinking the scale of the hist-map. 35 | """ 36 | conv = torch.nn.Conv2d(n_bins, n_bins, kernel_size, kernel_size, bias=False, groups=1) 37 | conv.weight.data.zero_() 38 | for i in range(conv.weight.shape[1]): 39 | alpha = kernel_size ** 2 40 | # alpha = 1 41 | conv.weight.data[i, i, ...] 
= torch.ones(kernel_size, kernel_size) / alpha 42 | if not train: 43 | conv.requires_grad_(False) 44 | return conv 45 | -------------------------------------------------------------------------------- /src/model/arch/nonlocal_block_embedded_gaussian.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import ipdb 5 | 6 | 7 | class _NonLocalBlockND(nn.Module): 8 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample='pool', bn_layer=True): 9 | """ 10 | :param in_channels: 11 | :param inter_channels: 12 | :param dimension: 13 | :param sub_sample: 'pool' or 'bilinear' or False 14 | :param bn_layer: 15 | """ 16 | 17 | super(_NonLocalBlockND, self).__init__() 18 | 19 | assert dimension in [1, 2, 3] 20 | 21 | self.dimension = dimension 22 | self.sub_sample = sub_sample 23 | 24 | self.in_channels = in_channels 25 | self.inter_channels = inter_channels 26 | 27 | if self.inter_channels is None: 28 | self.inter_channels = in_channels // 2 29 | if self.inter_channels == 0: 30 | self.inter_channels = 1 31 | 32 | if dimension == 3: 33 | conv_nd = nn.Conv3d 34 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 35 | bn = nn.BatchNorm3d 36 | elif dimension == 2: 37 | conv_nd = nn.Conv2d 38 | if sub_sample == 'pool': 39 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 40 | elif sub_sample == 'bilinear': 41 | max_pool_layer = nn.UpsamplingBilinear2d([16, 16]) 42 | else: 43 | raise NotImplementedError(f'[ ERR ] Unknown down sample method: {sub_sample}') 44 | bn = nn.BatchNorm2d 45 | else: 46 | conv_nd = nn.Conv1d 47 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 48 | bn = nn.BatchNorm1d 49 | 50 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 51 | kernel_size=1, stride=1, padding=0) 52 | 53 | if bn_layer: 54 | self.W = nn.Sequential( 55 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 56 | kernel_size=1, stride=1, padding=0), 57 | bn(self.in_channels) 58 | ) 59 | nn.init.constant_(self.W[1].weight, 0) 60 | nn.init.constant_(self.W[1].bias, 0) 61 | else: 62 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 63 | kernel_size=1, stride=1, padding=0) 64 | nn.init.constant_(self.W.weight, 0) 65 | nn.init.constant_(self.W.bias, 0) 66 | 67 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 68 | kernel_size=1, stride=1, padding=0) 69 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 70 | kernel_size=1, stride=1, padding=0) 71 | 72 | if sub_sample: 73 | self.g = nn.Sequential(self.g, max_pool_layer) 74 | self.phi = nn.Sequential(self.phi, max_pool_layer) 75 | 76 | def forward(self, x, return_nl_map=False): 77 | """ 78 | :param x: (b, c, t, h, w) 79 | :param return_nl_map: if True return z, nl_map, else only return z. 
80 | :return: 81 | """ 82 | 83 | batch_size = x.size(0) 84 | 85 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 86 | g_x = g_x.permute(0, 2, 1) 87 | 88 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 89 | theta_x = theta_x.permute(0, 2, 1) 90 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 91 | f = torch.matmul(theta_x, phi_x) 92 | f_div_C = F.softmax(f, dim=-1) 93 | 94 | y = torch.matmul(f_div_C, g_x) 95 | y = y.permute(0, 2, 1).contiguous() 96 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 97 | W_y = self.W(y) 98 | z = W_y + x 99 | 100 | if return_nl_map: 101 | return z, f_div_C 102 | return z 103 | 104 | 105 | class NONLocalBlock1D(_NonLocalBlockND): 106 | def __init__(self, in_channels, inter_channels=None, sub_sample='pool', bn_layer=True): 107 | super(NONLocalBlock1D, self).__init__(in_channels, 108 | inter_channels=inter_channels, 109 | dimension=1, sub_sample=sub_sample, 110 | bn_layer=bn_layer) 111 | 112 | 113 | class NONLocalBlock2D(_NonLocalBlockND): 114 | def __init__(self, in_channels, inter_channels=None, sub_sample='pool', bn_layer=True): 115 | super(NONLocalBlock2D, self).__init__(in_channels, 116 | inter_channels=inter_channels, 117 | dimension=2, sub_sample=sub_sample, 118 | bn_layer=bn_layer, ) 119 | 120 | 121 | class NONLocalBlock3D(_NonLocalBlockND): 122 | def __init__(self, in_channels, inter_channels=None, sub_sample='pool', bn_layer=True): 123 | super(NONLocalBlock3D, self).__init__(in_channels, 124 | inter_channels=inter_channels, 125 | dimension=3, sub_sample=sub_sample, 126 | bn_layer=bn_layer, ) 127 | 128 | 129 | if __name__ == '__main__': 130 | import torch 131 | 132 | # for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]: 133 | # img = torch.zeros(2, 3, 20) 134 | # net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 135 | # out = net(img) 136 | # print(out.size()) 137 | 138 | # img = torch.zeros(2, 3, 20, 20) 139 | # net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 140 | # out = net(img) 141 | # print(out.size()) 142 | 143 | # img = torch.randn(2, 3, 8, 20, 20) 144 | # net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 145 | # out = net(img) 146 | # print(out.size()) 147 | 148 | img = torch.zeros(4, 16, 20, 20).cuda() 149 | net = NONLocalBlock2D(16, sub_sample=True, bn_layer=False).cuda() 150 | out = net(img) 151 | ipdb.set_trace() 152 | -------------------------------------------------------------------------------- /src/model/arch/unet_based/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onpix/LCDPNet/4faa0d98e8ff45f53a3569dd005a74353995b335/src/model/arch/unet_based/__init__.py -------------------------------------------------------------------------------- /src/model/arch/unet_based/hist_unet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import ipdb 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from model.arch.drconv import DRConv2d, HistDRConv2d 8 | from model.arch.hist import get_hist, get_hist_conv, pack_tensor 9 | 10 | 11 | class DoubleConv(nn.Module): 12 | def __init__(self, in_channels, out_channels, mid_channels=None): 13 | super().__init__() 14 | if not mid_channels: 15 | mid_channels = out_channels 16 | self.double_conv = nn.Sequential( 17 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, 
padding=1), 18 | nn.BatchNorm2d(mid_channels), 19 | nn.ReLU(inplace=True), 20 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), 21 | nn.BatchNorm2d(out_channels), 22 | nn.ReLU(inplace=True) 23 | ) 24 | 25 | def forward(self, x): 26 | return self.double_conv(x) 27 | 28 | 29 | class DRDoubleConv(nn.Module): 30 | def __init__(self, in_channels, out_channels, mid_channels=None, **kargs): 31 | super().__init__() 32 | if not mid_channels: 33 | mid_channels = out_channels 34 | 35 | self.double_conv = nn.Sequential( 36 | DRConv2d(in_channels, mid_channels, kernel_size=3, region_num=REGION_NUM_, padding=1, **kargs), 37 | nn.BatchNorm2d(mid_channels), 38 | nn.ReLU(inplace=True), 39 | DRConv2d(mid_channels, out_channels, kernel_size=3, region_num=REGION_NUM_, padding=1, **kargs), 40 | nn.BatchNorm2d(out_channels), 41 | nn.ReLU(inplace=True) 42 | ) 43 | assert len(DRCONV_POSITION_) == 2 44 | assert DRCONV_POSITION_[0] or DRCONV_POSITION_[1] 45 | if DRCONV_POSITION_[0] == 0: 46 | print('[ WARN ] Use Conv in DRDoubleConv[0] instead of DRconv.') 47 | self.double_conv[0] = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1) 48 | if DRCONV_POSITION_[1] == 0: 49 | print('[ WARN ] Use Conv in DRDoubleConv[3] instead of DRconv.') 50 | self.double_conv[3] = nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1) 51 | 52 | def forward(self, x): 53 | res = self.double_conv(x) 54 | self.guide_features = [] 55 | if DRCONV_POSITION_[0]: 56 | self.guide_features.append(self.double_conv[0].guide_feature) 57 | if DRCONV_POSITION_[1]: 58 | self.guide_features.append(self.double_conv[3].guide_feature) 59 | return res 60 | 61 | 62 | class HistDRDoubleConv(nn.Module): 63 | def __init__(self, in_channels, out_channels, mid_channels=None): 64 | super().__init__() 65 | if not mid_channels: 66 | mid_channels = out_channels 67 | self.conv1 = HistDRConv2d(in_channels, mid_channels, kernel_size=3, region_num=REGION_NUM_, padding=1) 68 | self.inter1 = nn.Sequential( 69 | nn.BatchNorm2d(mid_channels), 70 | nn.ReLU(inplace=True) 71 | ) 72 | self.conv2 = HistDRConv2d(mid_channels, out_channels, kernel_size=3, region_num=REGION_NUM_, padding=1) 73 | self.inter2 = nn.Sequential( 74 | nn.BatchNorm2d(out_channels), 75 | nn.ReLU(inplace=True) 76 | ) 77 | 78 | def forward(self, x, histmap): 79 | y = self.conv1(x, histmap) 80 | y = self.inter1(y) 81 | y = self.conv2(y, histmap) 82 | return self.inter2(y) 83 | 84 | 85 | class HistGuidedDRDoubleConv(nn.Module): 86 | def __init__(self, in_channels, out_channels, mid_channels=None, **kargs): 87 | super().__init__() 88 | assert len(DRCONV_POSITION_) == 2 89 | assert DRCONV_POSITION_[0] or DRCONV_POSITION_[1] 90 | 91 | if not mid_channels: 92 | mid_channels = out_channels 93 | if DRCONV_POSITION_[0]: 94 | self.conv1 = DRConv2d(in_channels, mid_channels, kernel_size=3, region_num=REGION_NUM_, padding=1, **kargs) 95 | else: 96 | print('[ WARN ] Use Conv in HistGuidedDRDoubleConv[0] instead of DRconv.') 97 | self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1) 98 | 99 | self.inter1 = nn.Sequential( 100 | nn.BatchNorm2d(mid_channels), 101 | nn.ReLU(inplace=True) 102 | ) 103 | if DRCONV_POSITION_[1]: 104 | self.conv2 = DRConv2d(mid_channels, out_channels, kernel_size=3, region_num=REGION_NUM_, padding=1, **kargs) 105 | else: 106 | print('[ WARN ] Use Conv in HistGuidedDRDoubleConv[0] instead of DRconv.') 107 | self.conv2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1) 108 | 109 | self.inter2 = nn.Sequential( 110 | 
nn.BatchNorm2d(out_channels), 111 | nn.ReLU(inplace=True) 112 | ) 113 | 114 | def forward(self, x, histmap): 115 | if DRCONV_POSITION_[0]: 116 | y = self.conv1(x, histmap) 117 | else: 118 | y = self.conv1(x) 119 | y = self.inter1(y) 120 | 121 | if DRCONV_POSITION_[1]: 122 | y = self.conv2(y, histmap) 123 | else: 124 | y = self.conv2(y) 125 | 126 | # self.guide_features = [self.conv1.guide_feature, self.conv2.guide_feature] 127 | self.guide_features = [] 128 | if DRCONV_POSITION_[0]: 129 | self.guide_features.append(self.conv1.guide_feature) 130 | if DRCONV_POSITION_[1]: 131 | self.guide_features.append(self.conv2.guide_feature) 132 | 133 | return self.inter2(y) 134 | 135 | 136 | class Up(nn.Module): 137 | def __init__(self, in_channels, out_channels, bilinear=True, **kargs): 138 | super().__init__() 139 | self.up = nn.Upsample(scale_factor=DOWN_RATIO_, mode='bilinear', align_corners=True) 140 | if CONV_TYPE_ == 'drconv': 141 | if HIST_AS_GUIDE_: 142 | self.conv = HistDRDoubleConv(in_channels, out_channels, in_channels // 2) 143 | elif GUIDE_FEATURE_FROM_HIST_: 144 | self.conv = HistGuidedDRDoubleConv(in_channels, out_channels, in_channels // 2, **kargs) 145 | else: 146 | self.conv = DRDoubleConv(in_channels, out_channels, in_channels // 2) 147 | # elif CONV_TYPE_ == 'dconv': 148 | # self.conv = HistDyDoubleConv(in_channels, out_channels, in_channels // 2) 149 | 150 | def forward(self, x1, x2, histmap): 151 | """ 152 | histmap: shape [bs, c * n_bins, h, w] 153 | """ 154 | x1 = self.up(x1) 155 | 156 | # input is CHW 157 | diffY = x2.size()[2] - x1.size()[2] 158 | diffX = x2.size()[3] - x1.size()[3] 159 | 160 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 161 | diffY // 2, diffY - diffY // 2]) 162 | 163 | if HIST_AS_GUIDE_ or GUIDE_FEATURE_FROM_HIST_ or CONV_TYPE_ == 'dconv': 164 | x = torch.cat([x2, x1], dim=1) 165 | res = self.conv(x, histmap) 166 | else: 167 | x = torch.cat([x2, x1, histmap], dim=1) 168 | res = self.conv(x) 169 | self.guide_features = self.conv.guide_features 170 | return res 171 | 172 | 173 | class Down(nn.Module): 174 | def __init__(self, in_channels, out_channels, use_hist=False): 175 | super().__init__() 176 | self.use_hist = use_hist 177 | if not use_hist: 178 | self.maxpool_conv = nn.Sequential( 179 | nn.MaxPool2d(DOWN_RATIO_), 180 | DoubleConv(in_channels, out_channels) 181 | ) 182 | else: 183 | if HIST_AS_GUIDE_: 184 | # self.maxpool_conv = nn.Sequential( 185 | # nn.MaxPool2d(2), 186 | # HistDRDoubleConv(in_channels, out_channels, in_channels // 2) 187 | # ) 188 | raise NotImplementedError() 189 | elif GUIDE_FEATURE_FROM_HIST_: 190 | self.maxpool = nn.MaxPool2d(DOWN_RATIO_) 191 | self.conv = HistGuidedDRDoubleConv(in_channels, out_channels, in_channels // 2) 192 | else: 193 | self.maxpool_conv = nn.Sequential( 194 | nn.MaxPool2d(DOWN_RATIO_), 195 | DRDoubleConv(in_channels, out_channels, in_channels // 2) 196 | ) 197 | 198 | def forward(self, x, histmap=None): 199 | if GUIDE_FEATURE_FROM_HIST_ and self.use_hist: 200 | x = self.maxpool(x) 201 | return self.conv(x, histmap) 202 | elif self.use_hist: 203 | return self.maxpool_conv(torch.cat([x, histmap], axis=1)) 204 | else: 205 | return self.maxpool_conv(x) 206 | 207 | 208 | class OutConv(nn.Module): 209 | def __init__(self, in_channels, out_channels): 210 | super(OutConv, self).__init__() 211 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 212 | 213 | def forward(self, x): 214 | return self.conv(x) 215 | 216 | 217 | class HistUNet(nn.Module): 218 | def __init__(self, 219 | in_channels=3, 220 | 
out_channels=3, 221 | bilinear=True, 222 | n_bins=8, 223 | hist_as_guide=False, 224 | channel_nums=None, 225 | hist_conv_trainable=False, 226 | encoder_use_hist=False, 227 | guide_feature_from_hist=False, 228 | region_num=8, 229 | use_gray_hist=False, 230 | conv_type='drconv', 231 | down_ratio=1, 232 | drconv_position=[1, 1], 233 | ): 234 | super().__init__() 235 | C_NUMS = [16, 32, 64, 128, 256] 236 | if channel_nums: 237 | C_NUMS = channel_nums 238 | self.maxpool = nn.MaxPool2d(2) 239 | self.n_bins = n_bins 240 | self.encoder_use_hist = encoder_use_hist 241 | self.use_gray_hist = use_gray_hist 242 | self.hist_conv_trainable = hist_conv_trainable 243 | 244 | global HIST_AS_GUIDE_, GUIDE_FEATURE_FROM_HIST_, REGION_NUM_, CONV_TYPE_, DOWN_RATIO_, DRCONV_POSITION_ 245 | HIST_AS_GUIDE_ = hist_as_guide 246 | GUIDE_FEATURE_FROM_HIST_ = guide_feature_from_hist 247 | REGION_NUM_ = region_num 248 | CONV_TYPE_ = conv_type 249 | DOWN_RATIO_ = down_ratio 250 | DRCONV_POSITION_ = drconv_position 251 | 252 | if hist_conv_trainable: 253 | self.hist_conv1 = get_hist_conv(n_bins * in_channels, down_ratio, train=True) 254 | self.hist_conv2 = get_hist_conv(n_bins * in_channels, down_ratio, train=True) 255 | self.hist_conv3 = get_hist_conv(n_bins * in_channels, down_ratio, train=True) 256 | else: 257 | self.hist_conv = get_hist_conv(n_bins, down_ratio) 258 | 259 | factor = 2 if bilinear else 1 260 | self.inc = DoubleConv(in_channels, C_NUMS[0]) 261 | if hist_as_guide or guide_feature_from_hist or conv_type == 'dconv': 262 | extra_c_num = 0 263 | elif use_gray_hist: 264 | extra_c_num = n_bins 265 | else: 266 | extra_c_num = n_bins * in_channels 267 | 268 | if guide_feature_from_hist: 269 | kargs = { 270 | 'guide_input_channel': n_bins if use_gray_hist else n_bins * in_channels 271 | } 272 | else: 273 | kargs = {} 274 | 275 | if encoder_use_hist: 276 | encoder_extra_c_num = extra_c_num 277 | else: 278 | encoder_extra_c_num = 0 279 | 280 | self.down1 = Down(C_NUMS[0] + encoder_extra_c_num, C_NUMS[1], use_hist=encoder_use_hist) 281 | self.down2 = Down(C_NUMS[1] + encoder_extra_c_num, C_NUMS[2], use_hist=encoder_use_hist) 282 | self.down3 = Down(C_NUMS[2] + encoder_extra_c_num, C_NUMS[3], use_hist=encoder_use_hist) 283 | self.down4 = Down(C_NUMS[3] + encoder_extra_c_num, C_NUMS[4] // factor, use_hist=encoder_use_hist) 284 | 285 | self.up1 = Up(C_NUMS[4] + extra_c_num, C_NUMS[3] // factor, bilinear, **kargs) 286 | self.up2 = Up(C_NUMS[3] + extra_c_num, C_NUMS[2] // factor, bilinear, **kargs) 287 | self.up3 = Up(C_NUMS[2] + extra_c_num, C_NUMS[1] // factor, bilinear, **kargs) 288 | self.up4 = Up(C_NUMS[1] + extra_c_num, C_NUMS[0], bilinear, **kargs) 289 | self.outc = OutConv(C_NUMS[0], out_channels) 290 | 291 | def forward(self, x): 292 | # ipdb.set_trace() 293 | # get histograms 294 | # (`get_hist` return shape: n_bins, bs, c, h, w). 
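        # Shape walk-through (a sketch, actual sizes depend on the config): an input
        # [bs, c, h, w] gives a soft histogram map [n_bins, bs, c, h, w];
        # `pack_tensor` flattens it to [bs * c, n_bins, h, w], and each pass through
        # the hist conv (a per-bin averaging kernel unless `hist_conv_trainable` is
        # set) shrinks the spatial size by `down_ratio`, producing
        # hist_down2 / hist_down4 / hist_down8 for the successive decoder levels.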
295 | if HIST_AS_GUIDE_ or self.use_gray_hist: 296 | histmap = get_hist(x, self.n_bins, grayscale=True) 297 | else: 298 | histmap = get_hist(x, self.n_bins) 299 | 300 | bs = x.shape[0] 301 | histmap = pack_tensor(histmap, self.n_bins).detach() # out: [bs * c, n_bins, h, w] 302 | if not self.hist_conv_trainable: 303 | hist_down2 = self.hist_conv(histmap) 304 | hist_down4 = self.hist_conv(hist_down2) 305 | hist_down8 = self.hist_conv(hist_down4) 306 | 307 | # [bs * c, b_bins, h, w] -> [bs, c*b_bins, h, w] 308 | for item in [histmap, hist_down2, hist_down4, hist_down8]: 309 | item.data = item.reshape(bs, -1, *item.shape[-2:]) 310 | else: 311 | histmap = histmap.reshape(bs, -1, *histmap.shape[-2:]) 312 | hist_down2 = self.hist_conv1(histmap) 313 | hist_down4 = self.hist_conv2(hist_down2) 314 | hist_down8 = self.hist_conv3(hist_down4) # [bs, n_bins * c, h/n, w/n] 315 | 316 | # forward 317 | encoder_hists = [None, ] * 4 318 | if self.encoder_use_hist: 319 | encoder_hists = [histmap, hist_down2, hist_down4, hist_down8] 320 | 321 | x1 = self.inc(x) 322 | x2 = self.down1(x1, encoder_hists[0]) # x2: 16 323 | x3 = self.down2(x2, encoder_hists[1]) # x3: 24 324 | x4 = self.down3(x3, encoder_hists[2]) # x4: 32 325 | x5 = self.down4(x4, encoder_hists[3]) # x5: 32 326 | 327 | # always apply hist in decoder: 328 | # ipdb.set_trace() 329 | x = self.up1(x5, x4, hist_down8) # [x5, x4]: 32 + 32 330 | x = self.up2(x, x3, hist_down4) # [x4, x3]: 331 | x = self.up3(x, x2, hist_down2) 332 | x = self.up4(x, x1, histmap) 333 | 334 | self.guide_features = [layer.guide_features for layer in [ 335 | self.up1, 336 | self.up2, 337 | self.up3, 338 | self.up4, 339 | ]] 340 | 341 | logits = self.outc(x) 342 | return logits 343 | 344 | 345 | if __name__ == '__main__': 346 | model = HistUNet() 347 | x = torch.rand(4, 3, 512, 512) 348 | model(x) 349 | ipdb.set_trace() 350 | -------------------------------------------------------------------------------- /src/model/basemodel.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import pathlib 4 | from collections.abc import Iterable 5 | 6 | import pytorch_lightning as pl 7 | import torchvision 8 | import wandb 9 | 10 | import utils.util as util 11 | 12 | try: 13 | from thop import profile 14 | except Exception as e: 15 | print('ERR: import thop failed, skip. error msg:') 16 | print(e) 17 | 18 | from globalenv import * 19 | 20 | global LOGGER_BUFFER_LOCK 21 | 22 | 23 | class BaseModel(pl.core.LightningModule): 24 | def __init__(self, opt, running_modes): 25 | ''' 26 | logger_img_group_names: images group names in wandb logger. recommand: ['train', 'valid'] 27 | ''' 28 | 29 | super().__init__() 30 | self.save_hyperparameters(dict(opt)) 31 | print('Running initialization for BaseModel') 32 | 33 | if IMG_DIRPATH in opt: 34 | # in training mode. 35 | # if in test mode, configLogging is not called. 
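        # Resulting layout (illustrative): opt[IMG_DIRPATH] is the run's image log
        # root, and the branches below create "<IMG_DIRPATH>/train" and
        # "<IMG_DIRPATH>/valid" for the image batches logged during training and
        # validation.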
36 | if TRAIN in running_modes: 37 | self.train_img_dirpath = osp.join(opt[IMG_DIRPATH], TRAIN) 38 | util.mkdir(self.train_img_dirpath) 39 | if VALID in running_modes and (len(opt[VALID_DATA].keys()) > 1 or opt[VALID_RATIO]): 40 | self.valid_img_dirpath = osp.join(opt[IMG_DIRPATH], VALID) 41 | util.mkdir(self.valid_img_dirpath) 42 | 43 | self.opt = opt 44 | self.learning_rate = self.opt[LR] 45 | 46 | self.MODEL_WATCHED = False # for wandb watching model 47 | self.global_valid_step = 0 48 | self.iogt = {} # a dict, saving input, output and gt batch 49 | 50 | assert isinstance(running_modes, Iterable) 51 | self.logger_image_buffer = {k: [] for k in running_modes} 52 | 53 | def show_flops_and_param_num(self, inputs): 54 | # inputs: arguments of `forward()` 55 | try: 56 | flops, params = profile(self, inputs=inputs) 57 | print('[ * ] FLOPs = ' + str(flops / 1000 ** 3) + 'G') 58 | print('[ * ] Params Num = ' + str(params / 1000 ** 2) + 'M') 59 | except Exception as e: 60 | print(f'Err occured while calculating flops: {str(e)}') 61 | 62 | # def get_progress_bar_dict(self): 63 | # items = super().get_progress_bar_dict() 64 | # items.pop("v_num", None) 65 | # # items.pop("loss", None) 66 | # return items 67 | 68 | def build_test_res_dir(self): 69 | assert self.opt[CHECKPOINT_PATH] 70 | modelpath = pathlib.Path(self.opt[CHECKPOINT_PATH]) 71 | 72 | # only `test_ds` is supported when testing. 73 | ds_type = TEST_DATA 74 | runtime_dirname = f'{self.opt.runtime.modelname}_{modelpath.parent.name}_{modelpath.name}@{self.opt.test_ds.name}' 75 | dirpath = modelpath.parent / TEST_RESULT_DIRNAME 76 | 77 | if (dirpath / runtime_dirname).exists(): 78 | if len(os.listdir(dirpath / runtime_dirname)) == 0: 79 | # an existing but empty dir 80 | pass 81 | else: 82 | try: 83 | input_str = input( 84 | f'[ WARN ] Result directory "{runtime_dirname}" exists. Press ENTER to overwrite or input suffix ' 85 | f'to create a new one:\n> New name: {runtime_dirname}.') 86 | except Exception as e: 87 | print( 88 | f'[ WARN ] Excepion {e} occured, ignore input and set `input_str` empty.') 89 | input_str = '' 90 | if input_str == '': 91 | print( 92 | f"[ WARN ] Overwrite result_dir: {runtime_dirname}") 93 | pass 94 | else: 95 | runtime_dirname += '.' + input_str 96 | # fname += '.new' 97 | 98 | dirpath /= runtime_dirname 99 | util.mkdir(dirpath) 100 | print('TEST - Result save path:') 101 | print(str(dirpath)) 102 | 103 | util.save_opt(dirpath, self.opt) 104 | return str(dirpath) 105 | 106 | @staticmethod 107 | def save_img_batch(batch, dirpath, fname, save_num=1): 108 | util.mkdir(dirpath) 109 | imgpath = osp.join(dirpath, fname) 110 | 111 | # If you want to visiual a single image, call .unsqueeze(0) 112 | assert len(batch.shape) == 4 113 | torchvision.utils.save_image(batch[:save_num], imgpath) 114 | 115 | def calc_and_log_losses(self, loss_lambda_map): 116 | logged_losses = {} 117 | loss = 0 118 | for loss_name, loss_weight in self.opt[RUNTIME][LOSS].items(): 119 | if loss_weight: 120 | current = loss_lambda_map[loss_name]() 121 | if current != None: 122 | current *= loss_weight 123 | logged_losses[loss_name] = current 124 | loss += current 125 | 126 | logged_losses[LOSS] = loss 127 | self.log_dict(logged_losses) 128 | return loss 129 | 130 | def log_images_dict(self, mode, input_fname, img_batch_dict, gt_fname=None): 131 | """ 132 | log input, output and gt images to local disk and remote wandb logger. 
133 | mode: TRAIN or VALID 134 | """ 135 | if self.opt[DEBUG]: 136 | return 137 | 138 | global LOGGER_BUFFER_LOCK 139 | if LOGGER_BUFFER_LOCK and self.opt.logger == 'wandb': 140 | # buffer is used by other GPU-thread. 141 | # print('Buffer locked!') 142 | return 143 | 144 | assert mode in [TRAIN, VALID] 145 | if mode == VALID: 146 | local_dirpath = self.valid_img_dirpath 147 | step = self.global_valid_step 148 | if self.global_valid_step == 0: 149 | print( 150 | 'WARN: Found global_valid_step=0. Maybe you foget to increase `self.global_valid_step` in `self.validation_step`?') 151 | # log_step = step # to avoid valid log step = train log step 152 | elif mode == TRAIN: 153 | local_dirpath = self.train_img_dirpath 154 | step = self.global_step 155 | # log_step = None 156 | 157 | if step % self.opt[LOG_EVERY] == 0: 158 | suffiix = f'_epoch{self.current_epoch}_step{step}.png' 159 | input_fname = osp.basename(input_fname) + suffiix 160 | 161 | if gt_fname: 162 | gt_fname = osp.basename(gt_fname) + suffiix 163 | 164 | # ****** public buffer opration ****** 165 | LOGGER_BUFFER_LOCK = True 166 | for name, batch in img_batch_dict.items(): 167 | if batch is None or batch is False: 168 | # image is None or False, skip. 169 | continue 170 | 171 | # save local image: 172 | fname = input_fname 173 | if name == GT and gt_fname: 174 | fname = gt_fname 175 | self.save_img_batch( 176 | batch, 177 | # e.g. ../train_log/train/output 178 | osp.join(local_dirpath, name), 179 | fname) 180 | 181 | # save remote image: 182 | if self.opt.logger == 'wandb': 183 | self.add_img_to_buffer(mode, batch, mode, name, fname) 184 | else: 185 | # tb logger 186 | self.logger.experiment.add_image(f'{mode}/{name}', batch[0], step) 187 | 188 | if self.opt.logger == 'wandb': 189 | self.commit_logger_buffer(mode) 190 | 191 | # self.buffer_img_step += 1 192 | LOGGER_BUFFER_LOCK = False 193 | # ****** public buffer opration ****** 194 | 195 | def add_img_to_buffer(self, group_name, batch, *caption): 196 | if len(batch.shape) == 3: 197 | # when input is not a batch: 198 | batch = batch.unsqueeze(0) 199 | 200 | self.logger_image_buffer[group_name].append( 201 | wandb.Image(batch[0], caption='-'.join(caption)) 202 | ) 203 | 204 | def commit_logger_buffer(self, groupname, **kwargs): 205 | assert self.logger 206 | self.logger.experiment.log({ 207 | groupname: self.logger_image_buffer[groupname] 208 | }, **kwargs) 209 | 210 | # clear buffer after each commit for the next commit 211 | self.logger_image_buffer[groupname].clear() 212 | -------------------------------------------------------------------------------- /src/model/basic_loss.py: -------------------------------------------------------------------------------- 1 | import kornia as kn 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torchvision import models, transforms 6 | 7 | from .arch.hist import get_hist, get_hist_conv 8 | 9 | 10 | class HistogramLoss(nn.Module): 11 | def __init__(self, n_bins=8, downscale=16): 12 | super().__init__() 13 | self.n_bins = n_bins 14 | self.hist_conv = get_hist_conv(n_bins, downscale) 15 | 16 | # pack tensor: transform gt_hist.shape [n_bins, bs, c, h, w] -> [bs*c, b_bins, h, w] 17 | # merge dim 1 (bs) and dim 2 (channel). 
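        # Shape sketch (illustrative): for gt of shape [bs, c, h, w] and n_bins=8,
        # get_hist(gt) is [8, bs, c, h, w]; packing gives [bs * c, 8, h, w], and the
        # fixed per-bin averaging conv (kernel = stride = downscale) reduces it to
        # [bs * c, 8, h // downscale, w // downscale] before the MSE is taken.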
18 | self.pack_tensor = lambda x: x.reshape(self.n_bins, -1, *x.shape[-2:]).permute(1, 0, 2, 3) 19 | 20 | def forward(self, output, gt): 21 | gt_hist = get_hist(gt, self.n_bins) 22 | output_hist = get_hist(output, self.n_bins) 23 | 24 | shrink_hist_gt = self.hist_conv(self.pack_tensor(gt_hist)) 25 | shrink_hist_output = self.hist_conv(self.pack_tensor(output_hist)) 26 | 27 | return F.mse_loss(shrink_hist_gt, shrink_hist_output) 28 | 29 | 30 | class IntermediateHistogramLoss(HistogramLoss): 31 | def __init__(self, n_bins=8, downscale=16): 32 | super().__init__(n_bins, downscale) 33 | self.exposure_threshold = 0.5 34 | 35 | def forward(self, img, gt, brighten, darken): 36 | """ 37 | input brighten and darken img, get errors between: 38 | - brighten img & darken region in GT 39 | - darken img & brighten region in GT 40 | """ 41 | bs, c, _, _ = gt.shape 42 | gt_hist = get_hist(gt, self.n_bins) 43 | shrink_hist_gt = self.hist_conv(self.pack_tensor(gt_hist)) 44 | 45 | down_size = shrink_hist_gt.shape[-2:] 46 | shrink_hist_gt = shrink_hist_gt.reshape(bs, c, self.n_bins, *down_size) 47 | down_x = F.interpolate(img, size=down_size) 48 | 49 | # get mask from the input: 50 | over_ixs = down_x > self.exposure_threshold 51 | under_ixs = down_x <= self.exposure_threshold 52 | over_mask = down_x.clone() 53 | over_mask[under_ixs] = 0 54 | over_mask[over_ixs] = 1 55 | over_mask.unsqueeze_(2) 56 | under_mask = down_x.clone() 57 | under_mask[under_ixs] = 1 58 | under_mask[over_ixs] = 0 59 | under_mask.unsqueeze_(2) 60 | 61 | shrink_darken_hist = self.hist_conv(self.pack_tensor(get_hist(darken, self.n_bins))).reshape(bs, c, self.n_bins, 62 | *down_size) 63 | shrink_brighten_hist = self.hist_conv(self.pack_tensor(get_hist(brighten, self.n_bins))).reshape(bs, c, 64 | self.n_bins, 65 | *down_size) 66 | 67 | # [ 046 ] use ssim loss 68 | return 0.5 * kn.losses.ssim_loss((shrink_hist_gt * over_mask).view(-1, c, *down_size), 69 | (shrink_darken_hist * over_mask).view(-1, c, *down_size), 70 | window_size=5) + 0.5 * kn.losses.ssim_loss( 71 | (shrink_hist_gt * under_mask).view(-1, c, *down_size), 72 | (shrink_brighten_hist * under_mask).view(-1, c, *down_size), window_size=5) 73 | 74 | # [ 042 ] use l2 loss 75 | # return 0.5 * F.mse_loss(shrink_hist_gt * over_mask, shrink_darken_hist * over_mask) + 0.5 * F.mse_loss(shrink_hist_gt * under_mask, shrink_brighten_hist * under_mask) 76 | 77 | 78 | class WeightedL1Loss(nn.Module): 79 | def __init__(self): 80 | super().__init__() 81 | 82 | def forward(self, input, output, gt): 83 | bias = 0.1 84 | weights = (torch.abs(input - 0.5) + bias) / 0.5 85 | weights = weights.mean(axis=1).unsqueeze(1).repeat(1, 3, 1, 1) 86 | loss = torch.mean(torch.abs(output - gt) * weights.detach()) 87 | return loss 88 | 89 | 90 | class LTVloss(nn.Module): 91 | def __init__(self, alpha=1.2, beta=1.5, eps=1e-4): 92 | super(LTVloss, self).__init__() 93 | self.alpha = alpha 94 | self.beta = beta 95 | self.eps = eps 96 | 97 | def forward(self, origin, illumination, weight): 98 | ''' 99 | origin: one batch of input data. shape [batchsize, 3, h, w] 100 | illumination: one batch of predicted illumination data. if predicted_illumination 101 | is False, then use the output (predicted result) of the network. 
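        Roughly, the returned value is
            weight * mean( w_x * (d_x illumination)^2 + w_y * (d_y illumination)^2 ) / 2,
        with w = beta / (|d log(luminance of origin)|^alpha + eps), so illumination
        gradients are penalised most where the input luminance is locally smooth.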
102 | ''' 103 | 104 | # # re-normalize origin to 0 ~ 1 105 | # origin = (input_ - input_.min().item()) / (input_.max().item() - input_.min().item()) 106 | 107 | I = origin[:, 0:1, :, :] * 0.299 + origin[:, 1:2, :, :] * \ 108 | 0.587 + origin[:, 2:3, :, :] * 0.114 109 | L = torch.log(I + self.eps) 110 | dx = L[:, :, :-1, :-1] - L[:, :, :-1, 1:] 111 | dy = L[:, :, :-1, :-1] - L[:, :, 1:, :-1] 112 | 113 | dx = self.beta / (torch.pow(torch.abs(dx), self.alpha) + self.eps) 114 | dy = self.beta / (torch.pow(torch.abs(dy), self.alpha) + self.eps) 115 | 116 | x_loss = dx * \ 117 | ((illumination[:, :, :-1, :-1] - illumination[:, :, :-1, 1:]) ** 2) 118 | y_loss = dy * \ 119 | ((illumination[:, :, :-1, :-1] - illumination[:, :, 1:, :-1]) ** 2) 120 | tvloss = torch.mean(x_loss + y_loss) / 2.0 121 | 122 | return tvloss * weight 123 | 124 | 125 | class L_TV(nn.Module): 126 | def __init__(self, TVLoss_weight=1): 127 | super(L_TV, self).__init__() 128 | self.TVLoss_weight = TVLoss_weight 129 | 130 | def forward(self, x): 131 | batch_size = x.size()[0] 132 | h_x = x.size()[2] 133 | w_x = x.size()[3] 134 | count_h = (x.size()[2] - 1) * x.size()[3] 135 | count_w = x.size()[2] * (x.size()[3] - 1) 136 | h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum() 137 | w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum() 138 | return self.TVLoss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size 139 | -------------------------------------------------------------------------------- /src/model/bilateralupsamplenet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from globalenv import * 7 | from .arch.drconv import DRConv2d 8 | from model.arch.unet_based.hist_unet import HistUNet 9 | from .basic_loss import LTVloss 10 | from .single_net_basemodel import SingleNetBaseModel 11 | 12 | 13 | class LitModel(SingleNetBaseModel): 14 | def __init__(self, opt): 15 | super().__init__(opt, BilateralUpsampleNet(opt[RUNTIME]), [TRAIN, VALID]) 16 | low_res = opt[RUNTIME][LOW_RESOLUTION] 17 | 18 | self.down_sampler = lambda x: F.interpolate(x, size=(low_res, low_res), mode='bicubic', align_corners=False) 19 | self.use_illu = opt[RUNTIME][PREDICT_ILLUMINATION] 20 | 21 | self.mse = torch.nn.MSELoss() 22 | self.ltv = LTVloss() 23 | self.cos = torch.nn.CosineSimilarity(1, 1e-8) 24 | 25 | self.net.train() 26 | 27 | def training_step(self, batch, batch_idx): 28 | input_batch, gt_batch, output_batch = super().training_step_forward(batch, batch_idx) 29 | loss_lambda_map = { 30 | MSE: lambda: self.mse(output_batch, gt_batch), 31 | COS_LOSS: lambda: (1 - self.cos(output_batch, gt_batch).mean()) * 0.5, 32 | LTV_LOSS: lambda: self.ltv(input_batch, self.net.illu_map, 1) if self.use_illu else None, 33 | } 34 | 35 | # logging: 36 | loss = self.calc_and_log_losses(loss_lambda_map) 37 | self.log_training_iogt_img(batch, extra_img_dict={ 38 | PREDICT_ILLUMINATION: self.net.illu_map, 39 | GUIDEMAP: self.net.guidemap 40 | }) 41 | return loss 42 | 43 | def validation_step(self, batch, batch_idx): 44 | super().validation_step(batch, batch_idx) 45 | 46 | def test_step(self, batch, batch_ix): 47 | super().test_step(batch, batch_ix) 48 | 49 | def forward(self, x): 50 | low_res_x = self.down_sampler(x) 51 | return self.net(low_res_x, x) 52 | 53 | 54 | class ConvBlock(nn.Module): 55 | def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, use_bias=True, 
activation=nn.ReLU, 56 | batch_norm=False): 57 | super(ConvBlock, self).__init__() 58 | conv_type = OPT['conv_type'] 59 | if conv_type == 'conv': 60 | self.conv = nn.Conv2d(int(inc), int(outc), kernel_size, padding=padding, stride=stride, bias=use_bias) 61 | elif conv_type.startswith('drconv'): 62 | region_num = int(conv_type.replace('drconv', '')) 63 | self.conv = DRConv2d(int(inc), int(outc), kernel_size, region_num=region_num, padding=padding, 64 | stride=stride) 65 | print(f'[ WARN ] Using DRconv2d(n_region={region_num}) instead of Conv2d in BilateralUpsampleNet.') 66 | else: 67 | raise NotImplementedError() 68 | 69 | self.activation = activation() if activation else None 70 | self.bn = nn.BatchNorm2d(outc) if batch_norm else None 71 | 72 | def forward(self, x): 73 | x = self.conv(x) 74 | if self.bn: 75 | x = self.bn(x) 76 | if self.activation: 77 | x = self.activation(x) 78 | return x 79 | 80 | 81 | class FC(nn.Module): 82 | def __init__(self, inc, outc, activation=nn.ReLU, batch_norm=False): 83 | super(FC, self).__init__() 84 | self.fc = nn.Linear(int(inc), int(outc), bias=(not batch_norm)) 85 | self.activation = activation() if activation else None 86 | self.bn = nn.BatchNorm1d(outc) if batch_norm else None 87 | 88 | def forward(self, x): 89 | x = self.fc(x) 90 | if self.bn: 91 | x = self.bn(x) 92 | if self.activation: 93 | x = self.activation(x) 94 | return x 95 | 96 | 97 | class SliceNode(nn.Module): 98 | def __init__(self, opt): 99 | super(SliceNode, self).__init__() 100 | self.opt = opt 101 | 102 | def forward(self, bilateral_grid, guidemap): 103 | # bilateral_grid shape: Nx12x8x16x16 104 | device = bilateral_grid.get_device() 105 | N, _, H, W = guidemap.shape 106 | hg, wg = torch.meshgrid([torch.arange(0, H), torch.arange(0, W)]) # [0,511] HxW 107 | if device >= 0: 108 | hg = hg.to(device) 109 | wg = wg.to(device) 110 | 111 | hg = hg.float().repeat(N, 1, 1).unsqueeze(3) / (H - 1) * 2 - 1 # norm to [-1,1] NxHxWx1 112 | wg = wg.float().repeat(N, 1, 1).unsqueeze(3) / (W - 1) * 2 - 1 # norm to [-1,1] NxHxWx1 113 | guidemap = guidemap * 2 - 1 114 | guidemap = guidemap.permute(0, 2, 3, 1).contiguous() 115 | guidemap_guide = torch.cat([wg, hg, guidemap], dim=3).unsqueeze(1) 116 | 117 | # guidemap shape: [N, 1 (D), H, W] 118 | # bilateral_grid shape: [N, 12 (c), 8 (d), 16 (h), 16 (w)], which is considered as a 3D space: [8, 16, 16] 119 | # guidemap_guide shape: [N, 1 (D), H, W, 3], which is considered as a 3D space: [1, H, W] 120 | # coeff shape: [N, 12 (c), 1 (D), H, W] 121 | 122 | # in F.grid_sample, gird is guidemap_guide, input is bilateral_grid 123 | # guidemap_guide[N, D, H, W] is a 3-vector . but: 124 | # x -> W, y -> H, z -> D in bilater_grid 125 | # What does it really do: 126 | # [ 1 ] For pixel in guidemap_guide[D, H, W], get , and: 127 | # [ 2 ] Normalize from [-1, 1] to [0, w - 1], [0, h - 1], [0, d - 1], respectively. 128 | # [ 3 ] Locate pixel in bilateral_grid at position [N, :, z, y, x]. 129 | # [ 4 ] Interplate using the neighbor values as the output affine matrix. 
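        # Worked example (illustrative sizes): bilateral_grid [N, 12, 8, 16, 16]
        # sampled with guidemap_guide [N, 1, H, W, 3] yields [N, 12, 1, H, W];
        # the squeeze(2) below returns the per-pixel affine coefficients [N, 12, H, W].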
130 | 131 | # Force them have the same type for fp16 training : 132 | guidemap_guide = guidemap_guide.type_as(bilateral_grid) 133 | # bilateral_grid = bilateral_grid.type_as(guidemap_guide) 134 | coeff = F.grid_sample(bilateral_grid, guidemap_guide, 'bilinear', align_corners=True) 135 | return coeff.squeeze(2) 136 | 137 | 138 | class ApplyCoeffs(nn.Module): 139 | def __init__(self): 140 | super(ApplyCoeffs, self).__init__() 141 | 142 | def forward(self, coeff, full_res_input): 143 | ''' 144 | coeff shape: [bs, 12, h, w] 145 | input shape: [bs, 3, h, w] 146 | Affine: 147 | r = a11*r + a12*g + a13*b + a14 148 | g = a21*r + a22*g + a23*b + a24 149 | ... 150 | ''' 151 | R = torch.sum(full_res_input * coeff[:, 0:3, :, :], dim=1, keepdim=True) + coeff[:, 3:4, :, :] 152 | G = torch.sum(full_res_input * coeff[:, 4:7, :, :], dim=1, keepdim=True) + coeff[:, 7:8, :, :] 153 | B = torch.sum(full_res_input * coeff[:, 8:11, :, :], dim=1, keepdim=True) + coeff[:, 11:12, :, :] 154 | 155 | return torch.cat([R, G, B], dim=1) 156 | 157 | 158 | class ApplyCoeffsGamma(nn.Module): 159 | def __init__(self): 160 | super(ApplyCoeffsGamma, self).__init__() 161 | print('[ WARN ] Use alter methods indtead of affine matrix.') 162 | 163 | def forward(self, x_r, x): 164 | ''' 165 | coeff shape: [bs, 12, h, w] 166 | apply zeroDCE curve. 167 | ''' 168 | 169 | # [ 008 ] single iteration alpha map: 170 | # coeff channel num: 3 171 | # return x + x_r * (torch.pow(x, 2) - x) 172 | 173 | # [ 009 ] 8 iteratoins: 174 | # coeff channel num: 24 175 | r1, r2, r3, r4, r5, r6, r7, r8 = torch.split(x_r, 3, dim=1) 176 | x = x + r1 * (torch.pow(x, 2) - x) 177 | x = x + r2 * (torch.pow(x, 2) - x) 178 | x = x + r3 * (torch.pow(x, 2) - x) 179 | enhance_image_1 = x + r4 * (torch.pow(x, 2) - x) 180 | x = enhance_image_1 + r5 * (torch.pow(enhance_image_1, 2) - enhance_image_1) 181 | x = x + r6 * (torch.pow(x, 2) - x) 182 | x = x + r7 * (torch.pow(x, 2) - x) 183 | enhance_image = x + r8 * (torch.pow(x, 2) - x) 184 | r = torch.cat([r1, r2, r3, r4, r5, r6, r7, r8], 1) 185 | return enhance_image 186 | 187 | # [ 014 ] use illu map: 188 | # coeff channel num: 3 189 | # return x / (torch.where(x_r < x, x, x_r) + 1e-7) 190 | 191 | # [ 015 ] use HSV and only affine V channel: 192 | # coeff channel num: 3 193 | # V = torch.sum(x * x_r, dim=1, keepdim=True) + x_r 194 | # return torch.cat([x[:, 0:2, ...], V], dim=1) 195 | 196 | 197 | class ApplyCoeffsRetinex(nn.Module): 198 | def __init__(self): 199 | super().__init__() 200 | print('[ WARN ] Use alter methods indtead of affine matrix.') 201 | 202 | def forward(self, x_r, x): 203 | ''' 204 | coeff shape: [bs, 12, h, w] 205 | apply division of illumap. 
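        Under the retinex assumption I = R * L, the reflectance is recovered as
        R = I / (max(I, L) + 1e-7); the torch.where below clamps the predicted
        illumination so it never falls below the input before dividing.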
206 | ''' 207 | 208 | # [ 014 ] use illu map: 209 | # coeff channel num: 3 210 | return x / (torch.where(x_r < x, x, x_r) + 1e-7) 211 | 212 | 213 | class GuideNet(nn.Module): 214 | def __init__(self, params=None, out_channel=1): 215 | super(GuideNet, self).__init__() 216 | self.params = params 217 | self.conv1 = ConvBlock(3, 16, kernel_size=1, padding=0, batch_norm=True) 218 | self.conv2 = ConvBlock(16, out_channel, kernel_size=1, padding=0, activation=nn.Sigmoid) # nn.Tanh 219 | 220 | def forward(self, x): 221 | return self.conv2(self.conv1(x)) # .squeeze(1) 222 | 223 | 224 | class LowResNet(nn.Module): 225 | def __init__(self, coeff_dim=12, opt=None): 226 | super(LowResNet, self).__init__() 227 | self.params = opt 228 | self.coeff_dim = coeff_dim 229 | 230 | lb = opt[LUMA_BINS] 231 | cm = opt[CHANNEL_MULTIPLIER] 232 | sb = opt[SPATIAL_BIN] 233 | bn = opt[BATCH_NORM] 234 | nsize = opt[LOW_RESOLUTION] 235 | 236 | self.relu = nn.ReLU() 237 | 238 | # splat features 239 | n_layers_splat = int(np.log2(nsize / sb)) 240 | self.splat_features = nn.ModuleList() 241 | prev_ch = 3 242 | for i in range(n_layers_splat): 243 | use_bn = bn if i > 0 else False 244 | self.splat_features.append(ConvBlock(prev_ch, cm * (2 ** i) * lb, 3, stride=2, batch_norm=use_bn)) 245 | prev_ch = splat_ch = cm * (2 ** i) * lb 246 | 247 | # global features 248 | n_layers_global = int(np.log2(sb / 4)) 249 | # print(n_layers_global) 250 | self.global_features_conv = nn.ModuleList() 251 | self.global_features_fc = nn.ModuleList() 252 | for i in range(n_layers_global): 253 | self.global_features_conv.append(ConvBlock(prev_ch, cm * 8 * lb, 3, stride=2, batch_norm=bn)) 254 | prev_ch = cm * 8 * lb 255 | 256 | n_total = n_layers_splat + n_layers_global 257 | prev_ch = prev_ch * (nsize / 2 ** n_total) ** 2 258 | self.global_features_fc.append(FC(prev_ch, 32 * cm * lb, batch_norm=bn)) 259 | self.global_features_fc.append(FC(32 * cm * lb, 16 * cm * lb, batch_norm=bn)) 260 | self.global_features_fc.append(FC(16 * cm * lb, 8 * cm * lb, activation=None, batch_norm=bn)) 261 | 262 | # local features 263 | self.local_features = nn.ModuleList() 264 | self.local_features.append(ConvBlock(splat_ch, 8 * cm * lb, 3, batch_norm=bn)) 265 | self.local_features.append(ConvBlock(8 * cm * lb, 8 * cm * lb, 3, activation=None, use_bias=False)) 266 | 267 | # predicton 268 | self.conv_out = ConvBlock(8 * cm * lb, lb * coeff_dim, 1, padding=0, activation=None) 269 | 270 | def forward(self, lowres_input): 271 | params = self.params 272 | bs = lowres_input.shape[0] 273 | lb = params[LUMA_BINS] 274 | cm = params[CHANNEL_MULTIPLIER] 275 | sb = params[SPATIAL_BIN] 276 | 277 | x = lowres_input 278 | for layer in self.splat_features: 279 | x = layer(x) 280 | splat_features = x 281 | 282 | for layer in self.global_features_conv: 283 | x = layer(x) 284 | x = x.view(bs, -1) 285 | for layer in self.global_features_fc: 286 | x = layer(x) 287 | global_features = x 288 | 289 | x = splat_features 290 | for layer in self.local_features: 291 | x = layer(x) 292 | local_features = x 293 | 294 | # shape: bs x 64 x 16 x 16 295 | fusion_grid = local_features 296 | 297 | # shape: bs x 64 x 1 x 1 298 | fusion_global = global_features.view(bs, 8 * cm * lb, 1, 1) 299 | fusion = self.relu(fusion_grid + fusion_global) 300 | 301 | x = self.conv_out(fusion) 302 | 303 | # reshape channel dimension -> bilateral grid dimensions: 304 | # [bs, 96, 16, 16] -> [bs, 12, 8, 16, 16] 305 | y = torch.stack(torch.split(x, self.coeff_dim, 1), 2) 306 | return y 307 | 308 | 309 | class 
LowResHistUNet(HistUNet): 310 | def __init__(self, coeff_dim=12, opt=None): 311 | super(LowResHistUNet, self).__init__( 312 | in_channels=3, 313 | out_channels=coeff_dim * opt[LUMA_BINS], 314 | bilinear=True, 315 | **opt[HIST_UNET] 316 | ) 317 | self.coeff_dim = coeff_dim 318 | print('[[ WARN ]] Using HistUNet in BilateralUpsampleNet as backbone') 319 | 320 | def forward(self, x): 321 | y = super(LowResHistUNet, self).forward(x) 322 | y = torch.stack(torch.split(y, self.coeff_dim, 1), 2) 323 | return y 324 | 325 | 326 | class BilateralUpsampleNet(nn.Module): 327 | def __init__(self, opt): 328 | super(BilateralUpsampleNet, self).__init__() 329 | self.opt = opt 330 | global OPT 331 | OPT = opt 332 | self.guide = GuideNet(params=opt) 333 | self.slice = SliceNode(opt) 334 | self.build_coeffs_network(opt) 335 | 336 | def build_coeffs_network(self, opt): 337 | # Choose backbone: 338 | if opt[BACKBONE] == 'ori': 339 | Backbone = LowResNet 340 | elif opt[BACKBONE] == 'hist-unet': 341 | Backbone = LowResHistUNet 342 | else: 343 | raise NotImplementedError() 344 | 345 | # How to apply coeffs: 346 | # ─────────────────────────────────────────────────────────────────── 347 | if opt[COEFFS_TYPE] == MATRIX: 348 | self.coeffs = Backbone(opt=opt) 349 | self.apply_coeffs = ApplyCoeffs() 350 | 351 | elif opt[COEFFS_TYPE] == GAMMA: 352 | print('[[ WARN ]] HDRPointwiseNN use COEFFS_TYPE: GAMMA.') 353 | 354 | # [ 008 ] change affine matrix -> other methods (alpha map, illu map) 355 | self.coeffs = Backbone(opt=opt, coeff_dim=24) 356 | self.apply_coeffs = ApplyCoeffsGamma() 357 | 358 | elif opt[COEFFS_TYPE] == 'retinex': 359 | print('[[ WARN ]] HDRPointwiseNN use COEFFS_TYPE: retinex.') 360 | self.coeffs = Backbone(opt=opt, coeff_dim=3) 361 | self.apply_coeffs = ApplyCoeffsRetinex() 362 | 363 | else: 364 | raise NotImplementedError(f'[ ERR ] coeff type {opt[COEFFS_TYPE]} unkown.') 365 | # ───────────────────────────────────────────────────────────────────────────── 366 | 367 | def forward(self, lowres, fullres): 368 | bilateral_grid = self.coeffs(lowres) 369 | try: 370 | self.guide_features = self.coeffs.guide_features 371 | except: 372 | ... 
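        # Pipeline sketch (shapes are illustrative): the low-res backbone predicts
        # a bilateral grid [N, coeff_dim, luma_bins, sb, sb]; `self.guide` maps the
        # full-res input to a single-channel guide map; `self.slice` samples the
        # grid at every full-res pixel to obtain per-pixel coefficients, which
        # `apply_coeffs` then applies to the full-res image.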
373 | guide = self.guide(fullres) 374 | self.guidemap = guide 375 | 376 | slice_coeffs = self.slice(bilateral_grid, guide) 377 | out = self.apply_coeffs(slice_coeffs, fullres) 378 | 379 | # use illu map: 380 | self.slice_coeffs = slice_coeffs 381 | # if self.opt[PREDICT_ILLUMINATION]: 382 | # 383 | # power = self.opt[ILLU_MAP_POWER] 384 | # if power: 385 | # assert type(power + 0.1) == float 386 | # out = out.pow(power) 387 | # 388 | # out = out.clamp(fullres, torch.ones_like(out)) 389 | # # out = torch.where(out < fullres, fullres, out) 390 | # self.illu_map = out 391 | # out = fullres / (out + 1e-7) 392 | # else: 393 | self.illu_map = None 394 | 395 | if self.opt[PREDICT_ILLUMINATION]: 396 | return fullres / (out.clamp(fullres, torch.ones_like(out)) + 1e-7) 397 | else: 398 | return out 399 | -------------------------------------------------------------------------------- /src/model/lcdpnet.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import kornia as kn 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torchvision.utils 8 | import utils.util as util 9 | 10 | from globalenv import * 11 | from .arch.nonlocal_block_embedded_gaussian import NONLocalBlock2D 12 | from .basic_loss import L_TV, WeightedL1Loss, HistogramLoss, IntermediateHistogramLoss, LTVloss 13 | # from .hdrunet import tanh_L1Loss 14 | from .single_net_basemodel import SingleNetBaseModel 15 | 16 | 17 | class tanh_L1Loss(nn.Module): 18 | def __init__(self): 19 | super(tanh_L1Loss, self).__init__() 20 | 21 | def forward(self, x, y): 22 | loss = torch.mean(torch.abs(torch.tanh(x) - torch.tanh(y))) 23 | return loss 24 | 25 | 26 | class LitModel(SingleNetBaseModel): 27 | def __init__(self, opt): 28 | super().__init__(opt, DeepWBNet(opt[RUNTIME]), [TRAIN, VALID]) 29 | # self.pixel_loss = torch.nn.MSELoss() 30 | 31 | # [ 008-u1 ] use tanh L1 loss 32 | self.pixel_loss = tanh_L1Loss() 33 | 34 | # [ 019 ] use log L1 loss 35 | # self.pixel_loss = LogL2Loss() 36 | 37 | # [ 028 ] use weighted L1 loss. 
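        # (The weighted L1 below scales the pixel error by |input - 0.5|-based
        # weights, so badly exposed regions contribute more to the loss.)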
38 | self.weighted_loss = WeightedL1Loss() 39 | self.tvloss = L_TV() 40 | self.ltv2 = LTVloss() 41 | self.cos = torch.nn.CosineSimilarity(1, 1e-8) 42 | self.histloss = HistogramLoss() 43 | # self.vggloss = VGGLoss(shift=2) 44 | # self.vggloss.train() 45 | self.inter_histloss = IntermediateHistogramLoss() 46 | 47 | def training_step(self, batch, batch_idx): 48 | input_batch, gt_batch, output_batch = super().training_step_forward(batch, batch_idx) 49 | # print('[*] Now running:', batch[INPUT].shape, batch[GT].shape, output_batch.shape, batch[INPUT_FPATH], batch[GT_FPATH]) 50 | 51 | loss_lambda_map = { 52 | L1_LOSS: lambda: self.pixel_loss(output_batch, gt_batch), 53 | COS_LOSS: lambda: (1 - self.cos(output_batch, gt_batch).mean()) * 0.5, 54 | COS_LOSS + '2': lambda: 1 - F.sigmoid(self.cos(output_batch, gt_batch).mean()), 55 | LTV_LOSS: lambda: self.tvloss(output_batch), 56 | 'tvloss1': lambda: self.tvloss(self.net.res[ILLU_MAP]), 57 | 'tvloss2': lambda: self.tvloss(self.net.res[INVERSE_ILLU_MAP]), 58 | 59 | 'tvloss1_new': lambda: self.ltv2(input_batch, self.net.res[ILLU_MAP], 1), 60 | 'tvloss2_new': lambda: self.ltv2(1 - input_batch, self.net.res[INVERSE_ILLU_MAP], 1), 61 | 'illumap_loss': lambda: F.mse_loss(self.net.res[ILLU_MAP], 1 - self.net.res[INVERSE_ILLU_MAP]), 62 | WEIGHTED_LOSS: lambda: self.weighted_loss(input_batch.detach(), output_batch, gt_batch), 63 | SSIM_LOSS: lambda: kn.losses.ssim_loss(output_batch, gt_batch, window_size=5), 64 | PSNR_LOSS: lambda: kn.losses.psnr_loss(output_batch, gt_batch, max_val=1.0), 65 | HIST_LOSS: lambda: self.histloss(output_batch, gt_batch), 66 | INTER_HIST_LOSS: lambda: self.inter_histloss( 67 | input_batch, gt_batch, self.net.res[BRIGHTEN_INPUT], self.net.res[DARKEN_INPUT]), 68 | VGG_LOSS: lambda: self.vggloss(input_batch, gt_batch), 69 | } 70 | loss = self.calc_and_log_losses(loss_lambda_map) 71 | 72 | # logging images: 73 | self.log_training_iogt_img(batch) 74 | return loss 75 | 76 | def validation_step(self, batch, batch_idx): 77 | super().validation_step(batch, batch_idx) 78 | 79 | def test_step(self, batch, batch_ix): 80 | super().test_step(batch, batch_ix) 81 | 82 | # save intermidiate results 83 | for k, v in self.net.res.items(): 84 | dirpath = Path(self.opt[IMG_DIRPATH]) / k 85 | fname = osp.basename(batch[INPUT_FPATH][0]) 86 | if 'illu' in k: 87 | util.mkdir(dirpath) 88 | torchvision.utils.save_image(v[0].unsqueeze(1), dirpath / fname) 89 | elif k == 'guide_features': 90 | # v.shape: [bs, region_num, h, w] 91 | util.mkdir(dirpath) 92 | max_size = v[-1][-1].shape[-2:] 93 | final = [] 94 | for level_guide in v: 95 | gs = [F.interpolate(g, max_size) for g in level_guide] 96 | final.extend(gs) 97 | # import ipdb 98 | # ipdb.set_trace() 99 | region_num = final[0].shape[1] 100 | final = torch.stack(final).argmax(axis=2).float() / region_num 101 | # ipdb.set_trace() 102 | torchvision.utils.save_image(final, dirpath / fname) 103 | else: 104 | self.save_img_batch(v, dirpath, fname) 105 | 106 | 107 | class DeepWBNet(nn.Module): 108 | def build_illu_net(self): 109 | # if self.opt[BACKBONE] == 'unet': 110 | # if self.opt[USE_ATTN_MAP]: 111 | # return UNet( 112 | # self.opt, 113 | # in_channels=4, 114 | # out_channels=1, 115 | # wavelet=self.opt[USE_WAVELET], 116 | # non_local=self.opt[NON_LOCAL] 117 | # ) 118 | # else: 119 | # return UNet(self.opt, out_channels=self.opt[ILLUMAP_CHANNEL], wavelet=self.opt[USE_WAVELET]) 120 | 121 | from .bilateralupsamplenet import BilateralUpsampleNet 122 | return BilateralUpsampleNet(self.opt[BUNET]) 123 | # 124 | 
# elif self.opt[BACKBONE] == 'ynet': 125 | # from .arch.ynet import YNet 126 | # return YNet() 127 | # 128 | # elif self.opt[BACKBONE] == 'hdrunet': 129 | # from .hdrunet import HDRUNet 130 | # return HDRUNet() 131 | # 132 | # elif self.opt[BACKBONE] == 'hist-unet': 133 | # from model.arch.unet_based.hist_unet import HistUNet 134 | # return HistUNet(**self.opt[HIST_UNET]) 135 | # 136 | # else: 137 | # raise NotImplementedError(f'[[ ERR ]] Unknown backbone arch: {self.opt[BACKBONE]}') 138 | 139 | def backbone_forward(self, net, x): 140 | if self.opt[BACKBONE] in ['unet', 'hdrunet', 'hist-unet']: 141 | return net(x) 142 | 143 | elif self.opt[BACKBONE] == 'ynet': 144 | return net.forward_2input(x, 1 - x) 145 | 146 | elif self.opt[BACKBONE] == BUNET: 147 | low_x = self.down_sampler(x) 148 | res = net(low_x, x) 149 | try: 150 | self.res.update({'guide_features': net.guide_features}) 151 | except: 152 | ... 153 | # print('[yellow]No guide feature found in BilateralUpsampleNet[/yellow]') 154 | return res 155 | 156 | def __init__(self, opt=None): 157 | super(DeepWBNet, self).__init__() 158 | self.opt = opt 159 | self.down_sampler = lambda x: F.interpolate(x, size=(256, 256), mode='bicubic', align_corners=False) 160 | self.illu_net = self.build_illu_net() 161 | 162 | # [ 021 ] use 2 illu nets (do not share weights). 163 | if not opt[SHARE_WEIGHTS]: 164 | self.illu_net2 = self.build_illu_net() 165 | 166 | # self.guide_net = GuideNN(out_channel=3) 167 | if opt[HOW_TO_FUSE] in ['cnn-weights', 'cnn-direct', 'cnn-softmax3']: 168 | # self.out_net = UNet(in_channels=9, wavelet=opt[USE_WAVELET]) 169 | 170 | # [ 008-u1 ] use a simple network 171 | nf = 32 172 | self.out_net = nn.Sequential( 173 | nn.Conv2d(9, nf, 3, 1, 1), 174 | nn.ReLU(inplace=True), 175 | nn.Conv2d(nf, nf, 3, 1, 1), 176 | nn.ReLU(inplace=True), 177 | NONLocalBlock2D(nf, sub_sample='bilinear', bn_layer=False), 178 | nn.Conv2d(nf, nf, 1), 179 | nn.ReLU(inplace=True), 180 | nn.Conv2d(nf, 3, 1), 181 | NONLocalBlock2D(3, sub_sample='bilinear', bn_layer=False), 182 | ) 183 | 184 | elif opt[HOW_TO_FUSE] in ['cnn-color']: 185 | # self.out_net = UNet(in_channels=3, wavelet=opt[USE_WAVELET]) 186 | ... 187 | 188 | if not self.opt[BACKBONE_OUT_ILLU]: 189 | print('[[ WARN ]] Use output of backbone as brighten & darken directly.') 190 | self.res = {} 191 | 192 | def decomp(self, x1, illu_map): 193 | return x1 / (torch.where(illu_map < x1, x1, illu_map.float()) + 1e-7) 194 | 195 | def one_iter(self, x, attn_map, inverse_attn_map): 196 | # used only when USE_ATTN_MAP 197 | x1 = torch.cat((x, attn_map), 1) 198 | inverse_x1 = torch.cat((1 - x, inverse_attn_map), 1) 199 | 200 | illu_map = self.illu_net(x1, attn_map) 201 | inverse_illu_map = self.illu_net(inverse_x1) 202 | return illu_map, inverse_illu_map 203 | 204 | def forward(self, x): 205 | # ────────────────────────────────────────────────────────── 206 | # [ <008 ] use guideNN 207 | # x1 = self.guide_net(x).clamp(0, 1) 208 | 209 | # [ 008 ] use original input 210 | x1 = x 211 | inverse_x1 = 1 - x1 212 | 213 | if self.opt[USE_ATTN_MAP]: 214 | # [ 015 ] use attn map iteration to get illu map 215 | r, g, b = x[:, 0] + 1, x[:, 1] + 1, x[:, 2] + 1 216 | 217 | # init attn map as illumination channel of original input img: 218 | attn_map = (1. 
- (0.299 * r + 0.587 * g + 0.114 * b) / 2.).unsqueeze(1) 219 | inverse_attn_map = 1 - attn_map 220 | for _ in range(3): 221 | inverse_attn_map, attn_map = self.one_iter(x, attn_map, inverse_attn_map) 222 | illu_map, inverse_illu_map = inverse_attn_map, attn_map 223 | 224 | elif self.opt[BACKBONE] == 'ynet': 225 | # [ 024 ] one encoder, 2 decoders. 226 | illu_map, inverse_illu_map = self.backbone_forward(self.illu_net, x1) 227 | 228 | else: 229 | illu_map = self.backbone_forward(self.illu_net, x1) 230 | if self.opt[SHARE_WEIGHTS]: 231 | inverse_illu_map = self.backbone_forward(self.illu_net, inverse_x1) 232 | else: 233 | # [ 021 ] use 2 illu nets 234 | inverse_illu_map = self.backbone_forward(self.illu_net2, inverse_x1) 235 | # ────────────────────────────────────────────────────────── 236 | 237 | if self.opt[BACKBONE_OUT_ILLU]: 238 | brighten_x1 = self.decomp(x1, illu_map) 239 | inverse_x2 = self.decomp(inverse_x1, inverse_illu_map) 240 | else: 241 | brighten_x1 = illu_map 242 | inverse_x2 = inverse_illu_map 243 | darken_x1 = 1 - inverse_x2 244 | # ────────────────────────────────────────────────────────── 245 | 246 | self.res.update({ 247 | ILLU_MAP: illu_map, 248 | INVERSE_ILLU_MAP: inverse_illu_map, 249 | BRIGHTEN_INPUT: brighten_x1, 250 | DARKEN_INPUT: darken_x1, 251 | }) 252 | 253 | # fusion: 254 | # ────────────────────────────────────────────────────────── 255 | if self.opt[HOW_TO_FUSE] == 'cnn-weights': 256 | # [ 009 ] only fuse 2 output image 257 | # fused_x = torch.cat([brighten_x1, darken_x1], dim=1) 258 | 259 | fused_x = torch.cat([x, brighten_x1, darken_x1], dim=1) 260 | 261 | # [ 007 ] get weight-map from UNet, then get output from weight-map 262 | weight_map = self.out_net(fused_x) # <- 3 channels, [ N, 3, H, W ] 263 | w1 = weight_map[:, 0, ...].unsqueeze(1) 264 | w2 = weight_map[:, 1, ...].unsqueeze(1) 265 | w3 = weight_map[:, 2, ...].unsqueeze(1) 266 | out = x * w1 + brighten_x1 * w2 + darken_x1 * w3 267 | 268 | # [ 009 ] only fuse 2 output image 269 | # out = brighten_x1 * w1 + darken_x1 * w2 270 | # ──────────────────────────────────────────────────────────── 271 | 272 | elif self.opt[HOW_TO_FUSE] == 'cnn-softmax3': 273 | fused_x = torch.cat([x, brighten_x1, darken_x1], dim=1) 274 | weight_map = F.softmax(self.out_net(fused_x), dim=1) # <- 3 channels, [ N, 3, H, W ] 275 | w1 = weight_map[:, 0, ...].unsqueeze(1) 276 | w2 = weight_map[:, 1, ...].unsqueeze(1) 277 | w3 = weight_map[:, 2, ...].unsqueeze(1) 278 | out = x * w1 + brighten_x1 * w2 + darken_x1 * w3 279 | 280 | # [ 006 ] get output directly from UNet 281 | elif self.opt[HOW_TO_FUSE] == 'cnn-direct': 282 | fused_x = torch.cat([x, brighten_x1, darken_x1], dim=1) 283 | out = self.out_net(fused_x) 284 | 285 | # [ 016 ] average 2 outputs. 
286 | elif self.opt[HOW_TO_FUSE] == 'avg': 287 | out = 0.5 * brighten_x1 + 0.5 * darken_x1 288 | 289 | # [ 017 ] global color adjust 290 | elif self.opt[HOW_TO_FUSE] == 'cnn-color': 291 | out = 0.5 * brighten_x1 + 0.5 * darken_x1 292 | 293 | # elif self.opt[HOW_TO_FUSE] == 'cnn-residual': 294 | # out = x + 295 | 296 | else: 297 | raise NotImplementedError(f'Unknown fusion method: {self.opt[HOW_TO_FUSE]}') 298 | return out 299 | -------------------------------------------------------------------------------- /src/model/single_net_basemodel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os.path as osp 4 | 5 | import cv2 6 | import ipdb 7 | import torch 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | import torchmetrics 11 | 12 | import utils.util as util 13 | from globalenv import * 14 | from .basemodel import BaseModel 15 | 16 | 17 | # dict_merge = lambda a, b: a.update(b) or a 18 | 19 | 20 | class SingleNetBaseModel(BaseModel): 21 | # for models with only one self.net 22 | def __init__(self, opt, net, running_modes, valid_ssim=False, print_arch=True): 23 | super().__init__(opt, running_modes) 24 | self.net = net 25 | self.net.train() 26 | 27 | # config for SingleNetBaseModel 28 | if print_arch: 29 | print(str(net)) 30 | self.valid_ssim = valid_ssim # whether to compute ssim in validation 31 | self.tonemapper = cv2.createTonemapReinhard(2.2) 32 | 33 | self.psnr_func = torchmetrics.PeakSignalNoiseRatio(data_range=1) 34 | self.ssim_func = torchmetrics.StructuralSimilarityIndexMeasure(data_range=1) 35 | 36 | def configure_optimizers(self): 37 | # self.parameters in LitModel is the same as nn.Module. 38 | # once you add nn.xxxx as a member in __init__, self.parameters will include it.
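# Note: the code below pairs a single Adam optimizer over self.net with a CosineAnnealingLR scheduler (T_max=10); PyTorch Lightning consumes the ([optimizer], [scheduler]) pair returned by configure_optimizers.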
39 | optimizer = optim.Adam(self.net.parameters(), lr=self.learning_rate) 40 | # optimizer = optim.Adam(self.net.parameters(), lr=self.opt[LR]) 41 | 42 | schedular = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10) 43 | return [optimizer], [schedular] 44 | 45 | def forward(self, x): 46 | return self.net(x) 47 | 48 | def training_step_forward(self, batch, batch_idx): 49 | if not self.MODEL_WATCHED and not self.opt[DEBUG] and self.opt.logger == 'wandb': 50 | self.logger.experiment.watch( 51 | self.net, log_freq=self.opt[LOG_EVERY] * 2, log_graph=True) 52 | self.MODEL_WATCHED = True 53 | # self.show_flops_and_param_num([batch[INPUT]]) 54 | 55 | input_batch, gt_batch = batch[INPUT], batch[GT] 56 | output_batch = self(input_batch) 57 | self.iogt = { 58 | INPUT: input_batch, 59 | OUTPUT: output_batch, 60 | GT: gt_batch, 61 | } 62 | return input_batch, gt_batch, output_batch 63 | 64 | def validation_step(self, batch, batch_idx): 65 | input_batch, gt_batch = batch[INPUT], batch[GT] 66 | output_batch = self(input_batch) 67 | 68 | # log psnr 69 | output_ = util.cuda_tensor_to_ndarray(output_batch) 70 | y_ = util.cuda_tensor_to_ndarray(gt_batch) 71 | try: 72 | psnr = util.ImageProcessing.compute_psnr(output_, y_, 1.0) 73 | except: 74 | ipdb.set_trace() 75 | self.log(PSNR, psnr) 76 | 77 | # log SSIM (optional) 78 | if self.valid_ssim: 79 | ssim = util.ImageProcessing.compute_ssim(output_batch, gt_batch) 80 | self.log(SSIM, ssim) 81 | 82 | # log images 83 | if self.global_valid_step % self.opt.log_every == 0: 84 | self.log_images_dict( 85 | VALID, 86 | osp.basename(batch[INPUT_FPATH][0]), 87 | { 88 | INPUT: input_batch, 89 | OUTPUT: output_batch, 90 | GT: gt_batch, 91 | }, 92 | gt_fname=osp.basename(batch[GT_FPATH][0]) 93 | ) 94 | self.global_valid_step += 1 95 | return output_batch 96 | 97 | def log_training_iogt_img(self, batch, extra_img_dict=None): 98 | """ 99 | Only used in training_step 100 | """ 101 | if extra_img_dict: 102 | img_dict = {**self.iogt, **extra_img_dict} 103 | else: 104 | img_dict = self.iogt 105 | 106 | if self.global_step % self.opt.log_every == 0: 107 | self.log_images_dict( 108 | TRAIN, 109 | osp.basename(batch[INPUT_FPATH][0]), 110 | img_dict, 111 | gt_fname=osp.basename(batch[GT_FPATH][0]) 112 | ) 113 | 114 | @staticmethod 115 | def logdomain2hdr(ldr_batch): 116 | return 10 ** ldr_batch - 1 117 | 118 | def on_test_start(self): 119 | self.total_psnr = 0 120 | self.total_ssim = 0 121 | self.global_test_step = 0 122 | 123 | def on_test_end(self): 124 | print( 125 | f'Test step: {self.global_test_step}, Manual PSNR: {self.total_psnr / self.global_test_step}, Manual SSIM: {self.total_ssim / self.global_test_step}') 126 | 127 | def test_step(self, batch, batch_ix): 128 | """ 129 | save test result and calculate PSNR and SSIM for `self.net` (when have GT) 130 | """ 131 | # test without GT image: 132 | self.global_test_step += 1 133 | input_batch = batch[INPUT] 134 | assert input_batch.shape[0] == 1 135 | output_batch = self(input_batch) 136 | save_num = 1 137 | # visualized_batch = torch.cat([input_batch, output_batch]) 138 | # save_num = 2 139 | 140 | # test with GT: 141 | # if GT in batch: 142 | # gt_batch = batch[GT] 143 | # if output_batch.shape != batch[GT].shape: 144 | # print( 145 | # f'[[ WARN ]] output.shape is {output_batch.shape} but GT.shape is {batch[GT].shape}. 
Resize to get PSNR.') 146 | # gt_batch = F.interpolate(batch[GT], output_batch.shape[2:]) 147 | # 148 | # visualized_batch = torch.cat([visualized_batch, gt_batch]) 149 | # save_num = 3 150 | # 151 | # # calculate metrics: 152 | # # psnr = float(self.psnr_func(output_batch, gt_batch).cpu().numpy()) 153 | # # ssim = float(self.ssim_func(output_batch, gt_batch).cpu().numpy()) 154 | # # ipdb.set_trace() 155 | # 156 | # # output_ = util.cuda_tensor_to_ndarray(output_batch) 157 | # # y_ = util.cuda_tensor_to_ndarray(gt_batch) 158 | # # psnr = util.ImageProcessing.compute_psnr(output_, y_, 1.0) 159 | # # ssim = util.ImageProcessing.compute_ssim(output_, y_) 160 | # # self.log_dict({ 161 | # # 'test-' + PSNR: psnr, 162 | # # 'test-' + SSIM: ssim 163 | # # }, prog_bar=True, on_step=True, on_epoch=True, batch_size=1) 164 | # # self.total_psnr += psnr 165 | # # self.total_ssim += ssim 166 | # # print( 167 | # # f'{batch[INPUT_FPATH][0].split("/")[-1]}: psnr: {psnr:.4f}, ssim: {ssim:.4f}, avgpsnr: {self.total_psnr / self.global_test_step:.4f}, avgssim: {self.total_ssim / self.global_test_step:.4f}') 168 | # 169 | # # if batch[GT_FPATH][0].endswith('.hdr'): 170 | # # # output and GT -> HDR domain -> tonemap back to LDR 171 | # # output_vis = util.cuda_tensor_to_ndarray( 172 | # # self.logdomain2hdr(visualized_batch[1]).permute(1, 2, 0)) 173 | # # gt_vis = util.cuda_tensor_to_ndarray( 174 | # # self.logdomain2hdr(visualized_batch[2]).permute(1, 2, 0)) 175 | # # output_ldr = self.tonemapper.process(output_vis) 176 | # # gt_ldr = self.tonemapper.process(gt_vis) 177 | # # visualized_batch[1] = torch.tensor( 178 | # # output_ldr).permute(2, 0, 1).cuda() 179 | # # visualized_batch[2] = torch.tensor( 180 | # # gt_ldr).permute(2, 0, 1).cuda() 181 | # # 182 | # # hdr_outpath = Path(self.opt[IMG_DIRPATH]) / 'hdr_output' 183 | # # util.mkdir(hdr_outpath) 184 | # # cv2.imwrite( 185 | # # str(hdr_outpath / (osp.basename(batch[INPUT_FPATH][0]) + '.hdr')), output_vis) 186 | 187 | # save images 188 | self.save_img_batch( 189 | output_batch, 190 | self.opt[IMG_DIRPATH], 191 | osp.basename(batch[INPUT_FPATH][0]), 192 | save_num=save_num 193 | ) 194 | -------------------------------------------------------------------------------- /src/test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import time 3 | 4 | import hydra 5 | import pytorch_lightning as pl 6 | from omegaconf import open_dict 7 | from pytorch_lightning import Trainer 8 | 9 | from globalenv import * 10 | from utils.util import parse_config 11 | 12 | pl.seed_everything(GLOBAL_SEED) 13 | 14 | 15 | @hydra.main(config_path='config', config_name="config") 16 | def main(opt): 17 | opt = parse_config(opt, TEST) 18 | print('Running config:', opt) 19 | from model.lcdpnet import LitModel as ModelClass 20 | ckpt = opt[CHECKPOINT_PATH] 21 | assert ckpt 22 | model = ModelClass.load_from_checkpoint(ckpt, opt=opt) 23 | # model.opt = opt 24 | with open_dict(opt): 25 | model.opt[IMG_DIRPATH] = model.build_test_res_dir() 26 | opt.mode = 'test' 27 | print(f'Loading model from: {ckpt}') 28 | 29 | from data.img_dataset import DataModule 30 | datamodule = DataModule(opt) 31 | 32 | trainer = Trainer( 33 | gpus=opt[GPU], 34 | strategy=opt[BACKEND], 35 | precision=opt[RUNTIME_PRECISION]) 36 | 37 | beg = time.time() 38 | trainer.test(model, datamodule) 39 | print(f'[ TIMER ] Total time usage: {time.time() - beg}') 40 | print('[ PATH ] The results are in :') 41 | print(model.opt[IMG_DIRPATH]) 42 | 43 | 44 | if 
__name__ == "__main__": 45 | main() 46 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import platform 4 | import time 5 | import warnings 6 | 7 | import hydra 8 | import pytorch_lightning as pl 9 | import torch 10 | from omegaconf import open_dict 11 | from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping 12 | from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger 13 | 14 | import utils.util as util 15 | from globalenv import * 16 | from utils.util import parse_config, init_logging 17 | from data.img_dataset import DataModule 18 | 19 | pl.seed_everything(GLOBAL_SEED) 20 | 21 | 22 | @hydra.main(config_path='config', config_name="config") 23 | def main(config): 24 | try: 25 | print('GPU status info:') 26 | os.system('nvidia-smi') 27 | except: 28 | ... 29 | 30 | # print(config) 31 | opt = parse_config(config, TRAIN) 32 | if opt.name == DEBUG: 33 | opt.debug = True 34 | 35 | if opt.debug: 36 | # ipdb.set_trace() 37 | mylogger = None 38 | if opt.checkpoint_path: 39 | continue_epoch = torch.load( 40 | opt.checkpoint_path, map_location=torch.device('cpu'))['global_step'] 41 | debug_config = { 42 | DATALOADER_N: 0, 43 | NAME: DEBUG, 44 | LOG_EVERY: 1, 45 | VALID_EVERY: 1, 46 | NUM_EPOCH: 2 if not opt.checkpoint_path else continue_epoch + 2 47 | } 48 | opt.update(debug_config) 49 | debug_str = '[red]>>>> [[ WARN ]] You are in debug mode, update configs. <<<<[/red]' 50 | print(f'{debug_str}\n{debug_config}\n{debug_str}') 51 | 52 | else: 53 | # rename the exp 54 | spl = '_' if platform.system() == 'Windows' else ':' 55 | opt.name = f'{opt.runtime.modelname}{spl}{opt.name}@{opt.train_ds.name}' 56 | 57 | # trainer logger. init early to record all console output. 
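# (The TensorBoardLogger below saves logs under ROOT_PATH / 'tb_logs', grouped by opt.name; the WandbLogger imported above could be swapped in here if wandb logging is preferred.)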
58 | mylogger = TensorBoardLogger( 59 | name=opt.name, 60 | save_dir=ROOT_PATH / 'tb_logs', 61 | ) 62 | 63 | # init logging 64 | print('Running config:', opt) 65 | # opt[LOG_DIRPATH], opt.img_dirpath = init_logging(TRAIN, opt) 66 | with open_dict(opt): 67 | opt.log_dirpath, opt.img_dirpath = init_logging(TRAIN, opt) 68 | 69 | # load data 70 | # DataModuleClass = parse_ds_class(opt[TRAIN_DATA][CLASS]) 71 | datamodule = DataModule(opt) 72 | 73 | # callbacks: 74 | callbacks = [] 75 | if opt[EARLY_STOP]: 76 | print( 77 | f'Apply EarlyStopping when `{opt.checkpoint_monitor}` is {opt.monitor_mode}') 78 | callbacks.append(EarlyStopping( 79 | opt.checkpoint_monitor, mode=opt.monitor_mode)) 80 | 81 | # checkpoint callback: 82 | checkpoint_callback = ModelCheckpoint( 83 | dirpath=opt[LOG_DIRPATH], 84 | save_last=True, 85 | save_top_k=5, 86 | mode=opt.monitor_mode, 87 | monitor=opt.checkpoint_monitor, 88 | save_on_train_epoch_end=True, 89 | every_n_epochs=opt.savemodel_every 90 | ) 91 | callbacks.append(checkpoint_callback) 92 | 93 | if opt[AMP_BACKEND] != 'native': 94 | print( 95 | f'WARN: Running in APEX, mode: {opt[AMP_BACKEND]}-{opt[AMP_LEVEL]}') 96 | else: 97 | opt[AMP_LEVEL] = None 98 | 99 | # init trainer: 100 | trainer = pl.Trainer( 101 | gpus=opt[GPU], 102 | max_epochs=opt[NUM_EPOCH], 103 | logger=mylogger, 104 | callbacks=callbacks, 105 | check_val_every_n_epoch=opt[VALID_EVERY], 106 | num_sanity_val_steps=opt[VAL_DEBUG_STEP_NUMS], 107 | strategy=opt[BACKEND], 108 | precision=opt[RUNTIME_PRECISION], 109 | amp_backend=opt[AMP_BACKEND], 110 | amp_level=opt[AMP_LEVEL], 111 | **opt.flags 112 | ) 113 | print('Trainer initialized.') 114 | 115 | # training loop 116 | from model.lcdpnet import LitModel as ModelClass 117 | if opt.checkpoint_path and not opt.resume_training: 118 | print('Load ckpt and train from step 0...') 119 | model = ModelClass.load_from_checkpoint(opt.checkpoint_path, opt=opt) 120 | trainer.fit(model, datamodule) 121 | else: 122 | model = ModelClass(opt) 123 | print(f'Continue training: {opt.checkpoint_path}') 124 | trainer.fit(model, datamodule, ckpt_path=opt.checkpoint_path) 125 | 126 | 127 | if __name__ == "__main__": 128 | main() 129 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onpix/LCDPNet/4faa0d98e8ff45f53a3569dd005a74353995b335/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | import os 4 | import os.path as osp 5 | import smtplib 6 | from email.mime.multipart import MIMEMultipart 7 | from email.mime.text import MIMEText 8 | 9 | import cv2 10 | import ipdb 11 | import numpy as np 12 | import omegaconf 13 | import torch 14 | import yaml 15 | from PIL import Image 16 | from matplotlib.image import imread 17 | from skimage.metrics import structural_similarity as calc_ssim 18 | from torch.autograd import Variable 19 | 20 | from globalenv import * 21 | 22 | 23 | # from model import parse_model_class 24 | 25 | 26 | def update_global_opt(global_opt, valued_opt): 27 | for k, v in valued_opt.items(): 28 | global_opt[k] = v 29 | 30 | 31 | def mkdir(dirpath): 32 | if not osp.exists(dirpath): 33 | print(f'Creating directory: "{dirpath}"') 34 | try: 35 | os.makedirs(dirpath) 36 | except: 37 | ipdb.set_trace() 38 | 
return 39 | # print(f'Directory {dirpath} already exists, skip creating.') 40 | 41 | 42 | def cuda_tensor_to_ndarray(cuda_tensor): 43 | return cuda_tensor.clone().detach().cpu().numpy() 44 | 45 | 46 | def tensor2pil(img): 47 | img = img.squeeze() # * 0.5 + 0.5 48 | return Image.fromarray(img.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()) 49 | 50 | 51 | # 52 | # def calculate_psnr(img1, img2): 53 | # # img1 and img2 have range [0, 255] 54 | # # shape: [H, W, C] 55 | # img1 = img1.astype(np.float64) 56 | # img2 = img2.astype(np.float64) 57 | # mse = np.mean((img1 - img2) ** 2) 58 | # if mse == 0: 59 | # return float('inf') 60 | # return 20 * np.log10(255.0 / np.sqrt(mse)) 61 | # 62 | # 63 | # def ssim_com(img1, img2): 64 | # C1 = (0.01 * 255) ** 2 65 | # C2 = (0.03 * 255) ** 2 66 | # 67 | # img1 = img1.astype(np.float64) 68 | # img2 = img2.astype(np.float64) 69 | # kernel = cv2.getGaussianKernel(11, 1.5) 70 | # window = np.outer(kernel, kernel.transpose()) 71 | # 72 | # mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 73 | # mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 74 | # mu1_sq = mu1 ** 2 75 | # mu2_sq = mu2 ** 2 76 | # mu1_mu2 = mu1 * mu2 77 | # sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq 78 | # sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq 79 | # sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 80 | # 81 | # ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 82 | # (sigma1_sq + sigma2_sq + C2)) 83 | # return ssim_map.mean() 84 | # 85 | # 86 | # def calculate_ssim(img1, img2): 87 | # """ 88 | # calculate SSIM 89 | # the same outputs as MATLAB's 90 | # img1, img2: [0, 255] numpy array. shape: [H, W, C] 91 | # """ 92 | # if not img1.shape == img2.shape: 93 | # raise ValueError('Input images must have the same dimensions.') 94 | # if img1.ndim == 2: 95 | # return ssim_com(img1, img2) 96 | # elif img1.ndim == 3: 97 | # if img1.shape[2] == 3: 98 | # ssims = [] 99 | # for i in range(3): 100 | # ssims.append(ssim_com(img1, img2)) 101 | # return np.array(ssims).mean() 102 | # elif img1.shape[2] == 1: 103 | # return ssim_com(np.squeeze(img1), np.squeeze(img2)) 104 | # else: 105 | # raise ValueError('Wrong input image dimensions.') 106 | 107 | 108 | def save_opt(dirpath, opt): 109 | save_opt_fpath = dirpath / OPT_FILENAME 110 | 111 | with save_opt_fpath.open('w', encoding="utf-8") as f: 112 | yaml.dump(omega2dict(opt), f, default_flow_style=False) 113 | 114 | 115 | def init_logging(mode, opt): 116 | # mode: only for training 117 | assert mode == TRAIN 118 | 119 | log_dirpath = ROOT_PATH / TRAIN_LOG_DIRNAME / opt[RUNTIME][MODELNAME] / opt[NAME] 120 | # + datetime.datetime.now().strftime(LOG_TIME_FORMAT) 121 | img_dirpath = log_dirpath / IMAGES 122 | 123 | mkdir(log_dirpath) 124 | mkdir(img_dirpath) 125 | save_opt(log_dirpath, opt) 126 | 127 | # pl_logger = logging.getLogger("lightning") 128 | # pl_logger.propagate = False 129 | 130 | return str(log_dirpath), str(img_dirpath) 131 | 132 | 133 | def saveTensorAsImg(output, path, downsample_factor=False): 134 | # save image of A BATCH or a SINGLE IMAGE. 
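# A 4-D input is treated as a batch and saved image-by-image via recursion; a 3-D CHW tensor is converted to HWC, scaled to [0, 255], reordered RGB -> BGR and written with cv2.imwrite.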
135 | # input dtype: must be Tensor 136 | output = output.squeeze(0) 137 | 138 | if len(output.shape) == 4: 139 | # a batch 140 | res = [] 141 | for i in range(len(output)): 142 | res.append(saveTensorAsImg(output[i], f'{path}-{i}.png', downsample_factor)) 143 | return res 144 | 145 | # a single image 146 | assert len(output.shape) == 3 147 | outImg = cuda_tensor_to_ndarray( 148 | output.permute(1, 2, 0) 149 | ) * 255.0 150 | outImg = outImg[:, :, [2, 1, 0]].astype(np.uint8) 151 | 152 | if downsample_factor: 153 | assert type(downsample_factor + 0.1) == float 154 | h = outImg.shape[0] 155 | w = outImg.shape[1] 156 | outImg = cv2.resize(outImg, (int(w / downsample_factor), int(h / downsample_factor))).astype(np.uint8) 157 | 158 | cv2.imwrite(path, outImg) 159 | return outImg 160 | 161 | 162 | def parse_config(opt, mode): 163 | def checkField(opt, name, raise_msg): 164 | try: 165 | assert name in opt 166 | except: 167 | raise RuntimeError(raise_msg) 168 | 169 | # check necessary arguments for ALL MODELS and ALL MODES: 170 | # for x in GENERAL_NECESSARY_ARGUMENTS: 171 | # checkField(opt, x, ARGUMENTS_MISSING_ERRS[x]) 172 | 173 | # check necessary arguments for all models for EACH MODE: 174 | # if mode == TRAIN: 175 | # necessaryFields = TRAIN_NECESSARY_ARGUMENTS 176 | # elif mode in [TEST, VALID]: 177 | # necessaryFields = TEST_NECESSARY_ARGUMENTS 178 | # else: 179 | # raise NotImplementedError('[ ERR ] In function [checkConfig]: unknown mode', mode) 180 | # for x in necessaryFields: 181 | # checkField(opt, x, ARGUMENTS_MISSING_ERRS[x]) 182 | 183 | # make sure the model is implemented: 184 | modelname = opt[RUNTIME][MODELNAME] 185 | # assert parse_model_class(modelname) 186 | 187 | # check that fields in the runtime config are the same as the template. 188 | # use `modelname.default.yaml` as template. 189 | runtime_config_dir = SRC_PATH.absolute() / CONFIG_DIR / RUNTIME 190 | template_yml_path = runtime_config_dir / f'{modelname}.default.yaml' 191 | print(f'Check runtime config: use "{template_yml_path}" as template.') 192 | assert template_yml_path.exists() 193 | # for x in load_yml(str(template_yml_path)): 194 | # checkField(opt[RUNTIME], x, f'[ ERR ] Runtime config missing argument: {x}') 195 | 196 | # if type(opt) == omegaconf.DictConfig: 197 | # return omegaconf.OmegaConf.to_container(opt) 198 | 199 | pl_logger = logging.getLogger("lightning") 200 | pl_logger.propagate = False 201 | return opt 202 | 203 | 204 | def omega2dict(opt): 205 | if type(opt) == omegaconf.DictConfig: 206 | return omegaconf.OmegaConf.to_container(opt) 207 | else: 208 | return opt 209 | 210 | 211 | # (Discarded) 212 | def load_yml(ymlpath): 213 | ''' 214 | input config file path (yml file), return config dict.
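The path may be a local file or an http(s) URL (fetched with requests).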
215 | ''' 216 | print(f'* Reading config from: {ymlpath}') 217 | 218 | if ymlpath.startswith('http'): 219 | import requests 220 | ymlContent = requests.get(ymlpath).content 221 | else: 222 | ymlContent = open(ymlpath, 'r').read() 223 | 224 | yml = yaml.load(ymlContent, Loader=yaml.FullLoader) 225 | return yml 226 | 227 | 228 | class ImageProcessing(object): 229 | 230 | @staticmethod 231 | def rgb_to_lab(img, is_training=True): 232 | """ PyTorch implementation of RGB to LAB conversion: https://docs.opencv.org/3.3.0/de/d25/imgproc_color_conversions.html 233 | Based roughly on a similar implementation here: https://github.com/affinelayer/pix2pix-tensorflow/blob/master/pix2pix.py 234 | :param img: RGB image tensor of shape (3, H, W) with values in [0, 1] 235 | :returns: LAB image tensor of the same shape, with all channels rescaled to [0, 1] 236 | :rtype: torch.Tensor 237 | 238 | """ 239 | img = img.permute(2, 1, 0) 240 | shape = img.shape 241 | img = img.contiguous() 242 | img = img.view(-1, 3) 243 | 244 | img = (img / 12.92) * img.le(0.04045).float() + (((torch.clamp(img, 245 | min=0.0001) + 0.055) / 1.055) ** 2.4) * img.gt( 246 | 0.04045).float() 247 | 248 | rgb_to_xyz = Variable(torch.FloatTensor([ # X Y Z 249 | [0.412453, 0.212671, 250 | 0.019334], # R 251 | [0.357580, 0.715160, 252 | 0.119193], # G 253 | [0.180423, 0.072169, 254 | 0.950227], # B 255 | ]), requires_grad=False).type_as(img) 256 | 257 | img = torch.matmul(img, rgb_to_xyz) 258 | img = torch.mul(img, Variable(torch.FloatTensor( 259 | [1 / 0.950456, 1.0, 1 / 1.088754]), requires_grad=False).type_as(img)) 260 | 261 | epsilon = 6 / 29 262 | 263 | img = ((img / (3.0 * epsilon ** 2) + 4.0 / 29.0) * img.le(epsilon ** 3).float()) + \ 264 | (torch.clamp(img, min=0.0001) ** (1.0 / 3.0) * img.gt(epsilon ** 3).float()) 265 | 266 | fxfyfz_to_lab = Variable(torch.FloatTensor([[0.0, 500.0, 0.0], # fx 267 | [116.0, -500.0, 200.0], # fy 268 | [0.0, 0.0, -200.0], # fz 269 | ]), requires_grad=False).type_as(img) 270 | 271 | img = torch.matmul(img, fxfyfz_to_lab) + Variable( 272 | torch.FloatTensor([-16.0, 0.0, 0.0]), requires_grad=False).type_as(img) 273 | 274 | img = img.view(shape) 275 | img = img.permute(2, 1, 0) 276 | 277 | ''' 278 | L_chan: black and white with input range [0, 100] 279 | a_chan/b_chan: color channels with input range ~[-110, 110], not exact 280 | [0, 100] => [0, 1], ~[-110, 110] => [0, 1] 281 | ''' 282 | img[0, :, :] = img[0, :, :] / 100 283 | img[1, :, :] = (img[1, :, :] / 110 + 1) / 2 284 | img[2, :, :] = (img[2, :, :] / 110 + 1) / 2 285 | 286 | img[(img != img).detach()] = 0 287 | 288 | img = img.contiguous() 289 | 290 | return img 291 | 292 | @staticmethod 293 | def swapimdims_3HW_HW3(img): 294 | """Move the image channels to the last dimension of the numpy 295 | multi-dimensional array (3xHxW -> HxWx3) 296 | 297 | :param img: numpy nd array representing the image 298 | :returns: numpy nd array with permuted axes 299 | :rtype: numpy nd array 300 | 301 | """ 302 | if img.ndim == 3: 303 | return np.swapaxes(np.swapaxes(img, 1, 2), 0, 2) 304 | elif img.ndim == 4: 305 | return np.swapaxes(np.swapaxes(img, 2, 3), 1, 3) 306 | 307 | @staticmethod 308 | def swapimdims_HW3_3HW(img): 309 | """Move the image channels to the first dimension of the numpy 310 | multi-dimensional array (HxWx3 -> 3xHxW) 311 | 312 | :param img: numpy nd array representing the image 313 | :returns: numpy nd array with permuted axes 314 | :rtype: numpy nd array 315 | 316 | """ 317 | if img.ndim == 3: 318 | return np.swapaxes(np.swapaxes(img, 0, 2), 1, 2) 319 | elif img.ndim == 4: 320 | return np.swapaxes(np.swapaxes(img, 1, 3), 2, 3) 321 | 322 | @staticmethod 323 | def load_image(img_filepath, normaliser): 324 | """Loads an image from file as a numpy multi-dimensional array 325 | 326 | :param img_filepath: filepath to the image 327 | :returns: image as a multi-dimensional numpy array 328 | :rtype: multi-dimensional numpy array 329 | 330 | """ 331 | img = ImageProcessing.normalise_image( 332 | imread(img_filepath), normaliser) # NB: imread normalises to 0-1 333 | return img 334 | 335 | @staticmethod 336 | def normalise_image(img, normaliser): 337 | """Normalises image data to be a float between 0 and 1 338 | 339 | :param img: Image as a numpy multi-dimensional image array 340 | :returns: Normalised image as a numpy multi-dimensional image array 341 | :rtype: Numpy array 342 | 343 | """ 344 | img = img.astype('float32') / normaliser 345 | return img 346 | 347 | @staticmethod 348 | def compute_mse(original, result): 349 | """Computes the mean squared error between two RGB images represented as multi-dimensional numpy arrays. 350 | 351 | :param original: input RGB image as a numpy array 352 | :param result: target RGB image as a numpy array 353 | :returns: the mean squared error between the input and target images 354 | :rtype: float 355 | 356 | """ 357 | return ((original - result) ** 2).mean() 358 | 359 | @staticmethod 360 | def compute_psnr(image_batchA, image_batchB, max_intensity): 361 | """Computes the average PSNR for a batch of input and output images. 362 | Can be used during training / validation. 363 | 364 | :param image_batchA: numpy nd-array representing the image batch A of shape Bx3xWxH 365 | :param image_batchB: numpy nd-array representing the image batch B of shape Bx3xWxH 366 | :param max_intensity: maximum intensity possible in the image (e.g. 255) 367 | :returns: average PSNR for the batch of images 368 | :rtype: float 369 | 370 | """ 371 | num_images = image_batchA.shape[0] 372 | psnr_val = 0.0 373 | 374 | for i in range(0, num_images): 375 | imageA = image_batchA[i, 0:3, :, :] 376 | imageB = image_batchB[i, 0:3, :, :] 377 | imageB = np.maximum(0, np.minimum(imageB, max_intensity)) 378 | psnr_val += 10 * \ 379 | np.log10(max_intensity ** 2 / 380 | ImageProcessing.compute_mse(imageA, imageB)) 381 | 382 | return psnr_val / num_images 383 | 384 | @staticmethod 385 | def compute_ssim(image_batchA, image_batchB): 386 | """Computes the average SSIM for a batch of input and output images 387 | 388 | :param image_batchA: numpy nd-array representing the image batch A of shape Bx3xWxH 389 | :param image_batchB: numpy nd-array representing the image batch B of shape Bx3xWxH 390 | 391 | :returns: average SSIM for the batch of images 392 | :rtype: float 393 | 394 | """ 395 | num_images = image_batchA.shape[0] 396 | ssim_val = 0.0 397 | 398 | for i in range(0, num_images): 399 | imageA = ImageProcessing.swapimdims_3HW_HW3( 400 | image_batchA[i, 0:3, :, :]) 401 | imageB = ImageProcessing.swapimdims_3HW_HW3( 402 | image_batchB[i, 0:3, :, :]) 403 | ssim_val += calc_ssim(imageA, imageB, data_range=imageA.max() - imageA.min(), multichannel=True, 404 | gaussian_weights=True, win_size=11) 405 | 406 | return ssim_val / num_images 407 | --------------------------------------------------------------------------------